diff options
Diffstat (limited to 'src')
| -rw-r--r-- | src/ast_util.ml | 29 | ||||
| -rw-r--r-- | src/ast_util.mli | 11 | ||||
| -rw-r--r-- | src/isail.ml | 18 | ||||
| -rw-r--r-- | src/pretty_print_sail.ml | 3 | ||||
| -rw-r--r-- | src/sail_lib.ml | 4 | ||||
| -rw-r--r-- | src/specialize.ml | 222 | ||||
| -rw-r--r-- | src/type_check.ml | 5 | ||||
| -rw-r--r-- | src/type_check.mli | 4 | ||||
| -rw-r--r-- | src/util.ml | 1 |
9 files changed, 284 insertions, 13 deletions
diff --git a/src/ast_util.ml b/src/ast_util.ml index a70db3e0..f1add52b 100644 --- a/src/ast_util.ml +++ b/src/ast_util.ml @@ -87,6 +87,8 @@ let mk_qi_id bk kid = in QI_aux (QI_id kopt, Parse_ast.Unknown) +let mk_qi_kopt kopt =QI_aux (QI_id kopt, Parse_ast.Unknown) + let mk_fundef funcls = let tannot_opt = Typ_annot_opt_aux (Typ_annot_opt_none, Parse_ast.Unknown) in let effect_opt = Effect_opt_aux (Effect_opt_pure, Parse_ast.Unknown) in @@ -896,3 +898,30 @@ let construct_pexp (pat,guard,exp,ann) = match guard with | None -> Pat_aux (Pat_exp (pat,exp),ann) | Some guard -> Pat_aux (Pat_when (pat,guard,exp),ann) + +let is_valspec id = function + | DEF_spec (VS_aux (VS_val_spec (_, id', _, _), _)) when Id.compare id id' = 0 -> true + | _ -> false + +let is_fundef id = function + | DEF_fundef (FD_aux (FD_function (_, _, _, FCL_aux (FCL_Funcl (id', _), _) :: _), _)) when Id.compare id' id = 0 -> true + | _ -> false + +let rename_funcl id (FCL_aux (FCL_Funcl (_, pexp), annot)) = FCL_aux (FCL_Funcl (id, pexp), annot) + +let rename_fundef id (FD_aux (FD_function (ropt, topt, eopt, funcls), annot)) = + FD_aux (FD_function (ropt, topt, eopt, List.map (rename_funcl id) funcls), annot) + +let rec split_defs' f defs acc = + match defs with + | [] -> None + | def :: defs when f def -> Some (acc, def, defs) + | def :: defs -> split_defs' f defs (def :: acc) + +let split_defs f (Defs defs) = + match split_defs' f defs [] with + | None -> None + | Some (pre_defs, def, post_defs) -> + Some (Defs (List.rev pre_defs), def, Defs post_defs) + +let append_ast (Defs ast1) (Defs ast2) = Defs (ast1 @ ast2) diff --git a/src/ast_util.mli b/src/ast_util.mli index 69d80ea7..349faac2 100644 --- a/src/ast_util.mli +++ b/src/ast_util.mli @@ -71,6 +71,7 @@ val mk_typschm : typquant -> typ -> typschm val mk_typquant : quant_item list -> typquant val mk_qi_id : base_kind_aux -> kid -> quant_item val mk_qi_nc : n_constraint -> quant_item +val mk_qi_kopt : kinded_id -> quant_item val mk_fexp : id -> unit exp -> unit fexp val mk_fexps : (unit fexp) list -> unit fexps val mk_letbind : unit pat -> unit exp -> unit letbind @@ -280,3 +281,13 @@ val undefined_of_typ : bool -> Ast.l -> (typ -> 'annot) -> typ -> 'annot exp val destruct_pexp : 'a pexp -> 'a pat * ('a exp) option * 'a exp * (Ast.l * 'a) val construct_pexp : 'a pat * ('a exp) option * 'a exp * (Ast.l * 'a) -> 'a pexp + +val is_valspec : id -> 'a def -> bool + +val is_fundef : id -> 'a def -> bool + +val rename_fundef : id -> 'a fundef -> 'a fundef + +val split_defs : ('a def -> bool) -> 'a defs -> ('a defs * 'a def * 'a defs) option + +val append_ast : 'a defs -> 'a defs -> 'a defs diff --git a/src/isail.ml b/src/isail.ml index 629bf35a..ce17561d 100644 --- a/src/isail.ml +++ b/src/isail.ml @@ -186,7 +186,6 @@ let help = function | cmd -> "Either invalid command passed to help, or no documentation for " ^ cmd ^ ". Try :help :help." -let append_ast (Defs ast1) (Defs ast2) = Defs (ast1 @ ast2) type input = Command of string * string | Expression of string | Empty @@ -259,13 +258,16 @@ let handle_input' input = in let ids = Specialize.polymorphic_functions is_kopt !interactive_ast in List.iter (fun id -> print_endline (string_of_id id)) (IdSet.elements ids) - | ":ins" -> - let id = mk_id arg in - let instantiations = Specialize.instantiations_of id !interactive_ast in - let print_instantiation i = - print_endline (Util.string_of_list ", " (fun (kid, uvar) -> string_of_kid kid ^ " => " ^ Type_check.string_of_uvar uvar) (KBindings.bindings i)) - in - List.iter print_instantiation instantiations + | ":spec" -> + let ast, env = Specialize.specialize !interactive_ast !interactive_env in + interactive_ast := ast; + interactive_env := env; + interactive_state := initial_state !interactive_ast + + | ":ast" -> + let chan = open_out arg in + Pretty_print_sail.pp_defs chan !interactive_ast; + close_out chan | ":output" -> let chan = open_out arg in Value.output_redirect chan diff --git a/src/pretty_print_sail.ml b/src/pretty_print_sail.ml index 43be7a00..b55448fa 100644 --- a/src/pretty_print_sail.ml +++ b/src/pretty_print_sail.ml @@ -477,11 +477,12 @@ let doc_typdef (TD_aux(td,_)) = match td with surround 2 0 lbrace (separate_map (comma ^^ break 1) doc_union unions) rbrace] | _ -> string "TYPEDEF" + let doc_spec (VS_aux(v,_)) = let doc_extern ext = let doc_backend b = Util.option_map (fun id -> string (b ^ ":") ^^ space ^^ utf8string ("\"" ^ String.escaped id ^ "\"")) (ext b) in - let docs = Util.option_these (List.map doc_backend ["ocaml"; "lem"]) in + let docs = Util.option_these (List.map doc_backend ["ocaml"; "lem"; "smt"; "interpreter"]) in if docs = [] then empty else equals ^^ space ^^ braces (separate (comma ^^ space) docs) in match v with diff --git a/src/sail_lib.ml b/src/sail_lib.ml index 86c12aae..c78b81a9 100644 --- a/src/sail_lib.ml +++ b/src/sail_lib.ml @@ -298,6 +298,10 @@ let string_of_bit = function | B0 -> "0" | B1 -> "1" +let char_of_bit = function + | B0 -> '0' + | B1 -> '1' + let string_of_hex = function | [B0; B0; B0; B0] -> "0" | [B0; B0; B0; B1] -> "1" diff --git a/src/specialize.ml b/src/specialize.ml index 881881f4..2d32a90c 100644 --- a/src/specialize.ml +++ b/src/specialize.ml @@ -52,10 +52,40 @@ open Ast open Ast_util open Rewriter +let zchar c = + let zc c = "z" ^ String.make 1 c in + if Char.code c <= 41 then zc (Char.chr (Char.code c + 16)) + else if Char.code c <= 47 then zc (Char.chr (Char.code c + 23)) + else if Char.code c <= 57 then String.make 1 c + else if Char.code c <= 64 then zc (Char.chr (Char.code c + 13)) + else if Char.code c <= 90 then String.make 1 c + else if Char.code c <= 94 then zc (Char.chr (Char.code c - 13)) + else if Char.code c <= 95 then "_" + else if Char.code c <= 96 then zc (Char.chr (Char.code c - 13)) + else if Char.code c <= 121 then String.make 1 c + else if Char.code c <= 122 then "zz" + else if Char.code c <= 126 then zc (Char.chr (Char.code c - 39)) + else raise (Invalid_argument "zchar") + +let zencode_string str = "z" ^ List.fold_left (fun s1 s2 -> s1 ^ s2) "" (List.map zchar (Util.string_to_list str)) + +let zencode_upper_string str = "Z" ^ List.fold_left (fun s1 s2 -> s1 ^ s2) "" (List.map zchar (Util.string_to_list str)) + +let is_typ_uvar = function + | Type_check.U_typ _ -> true + | _ -> false + +(* We have to be careful about whether the typechecker has renamed anything returned by instantiation_of. + This part of the typechecker API is a bit ugly. *) +let fix_instantiation instantiation = + let instantiation = KBindings.bindings (KBindings.filter (fun _ uvar -> is_typ_uvar uvar) instantiation) in + let instantiation = List.map (fun (kid, uvar) -> Type_check.orig_kid kid, uvar) instantiation in + List.fold_left (fun m (k, v) -> KBindings.add k v m) KBindings.empty instantiation + (* Returns an IdSet with the function ids that have X-kinded parameters, e.g. val f : forall ('a : X). 'a -> 'a. The first argument specifies what X should be - it should be one of: - is_nat_kopt, is_order_kopt, or is_type_kopt from Ast_util. + is_nat_kopt, is_order_kopt, or is_typ_kopt from Ast_util. *) let rec polymorphic_functions is_kopt (Defs defs) = match defs with @@ -68,6 +98,11 @@ let rec polymorphic_functions is_kopt (Defs defs) = | _ :: defs -> polymorphic_functions is_kopt (Defs defs) | [] -> IdSet.empty +let id_of_instantiation id instantiation = + let string_of_binding (kid, uvar) = string_of_kid kid ^ " => " ^ Type_check.string_of_uvar uvar in + let str = zencode_string (Util.string_of_list ", " string_of_binding (KBindings.bindings instantiation)) ^ "#" in + prepend_id str id + (* Returns a list of all the instantiations of a function id in an ast. *) let rec instantiations_of id ast = @@ -75,7 +110,8 @@ let rec instantiations_of id ast = let inspect_exp = function | E_aux (E_app (id', _), _) as exp when Id.compare id id' = 0 -> - instantiations := Type_check.instantiation_of exp :: !instantiations; + let instantiation = fix_instantiation (Type_check.instantiation_of exp) in + instantiations := instantiation :: !instantiations; exp | exp -> exp in @@ -84,3 +120,185 @@ let rec instantiations_of id ast = let _ = rewrite_defs_base { rewriters_base with rewrite_exp = (fun _ -> fold_exp rewrite_exp) } ast in !instantiations + +let rec rewrite_polymorphic_calls id ast = + print_endline ("Rewriting: " ^ string_of_id id); + let vs_ids = Initial_check.val_spec_ids ast in + + let rewrite_e_aux = function + | E_aux (E_app (id', args), annot) as exp when Id.compare id id' = 0 -> + let instantiation = fix_instantiation (Type_check.instantiation_of exp) in + let spec_id = id_of_instantiation id instantiation in + (* Make sure we only generate specialized calls when we've + specialized the valspec. The valspec may not be generated if + a polymorphic function calls another polymorphic function. + In this case a specialization of the first may require that + the second needs to be specialized again, but this may not + have happened yet. *) + if IdSet.mem spec_id vs_ids then + E_aux (E_app (spec_id, args), annot) + else + exp + | exp -> exp + in + + let rewrite_exp = { id_exp_alg with e_aux = (fun (exp, annot) -> rewrite_e_aux (E_aux (exp, annot))) } in + rewrite_defs_base { rewriters_base with rewrite_exp = (fun _ -> fold_exp rewrite_exp) } ast + +let rec typ_frees ?exs:(exs=KidSet.empty) (Typ_aux (typ_aux, l)) = + match typ_aux with + | Typ_id v -> KidSet.empty + | Typ_var kid when KidSet.mem kid exs -> KidSet.empty + | Typ_var kid -> KidSet.singleton kid + | Typ_tup typs -> List.fold_left KidSet.union KidSet.empty (List.map (typ_frees ~exs:exs) typs) + | Typ_app (f, args) -> List.fold_left KidSet.union KidSet.empty (List.map (typ_arg_frees ~exs:exs) args) + | Typ_exist (kids, nc, typ) -> typ_frees ~exs:(KidSet.of_list kids) typ + | Typ_fn (typ1, typ2, _) -> KidSet.union (typ_frees ~exs:exs typ1) (typ_frees ~exs:exs typ2) +and typ_arg_frees ?exs:(exs=KidSet.empty) (Typ_arg_aux (typ_arg_aux, l)) = + match typ_arg_aux with + | Typ_arg_nexp n -> KidSet.empty + | Typ_arg_typ typ -> typ_frees ~exs:exs typ + | Typ_arg_order ord -> KidSet.empty + +let specialize_id_valspec instantiations id ast = + match split_defs (is_valspec id) ast with + | None -> failwith ("Valspec " ^ string_of_id id ^ " does not exist!") + | Some (pre_ast, vs, post_ast) -> + let typschm, externs, is_cast, annot = match vs with + | DEF_spec (VS_aux (VS_val_spec (typschm, _, externs, is_cast), annot)) -> typschm, externs, is_cast, annot + | _ -> assert false (* unreachable *) + in + let TypSchm_aux (TypSchm_ts (typq, typ), _) = typschm in + + (* Keep track of the specialized ids to avoid generating things twice. *) + let spec_ids = ref IdSet.empty in + + let specialize_instance instantiation = + (* Replace the polymorphic type variables in the type with their concrete instantiation. *) + let typ = Type_check.subst_unifiers instantiation typ in + let frees = KidSet.elements (typ_frees typ) in + + (* Remove type variables from the type quantifier. *) + let kopts, constraints = quant_split typq in + let kopts = List.filter (fun kopt -> not (is_typ_kopt kopt)) kopts in + let typq = mk_typquant (List.map (mk_qi_id BK_type) frees @ List.map mk_qi_kopt kopts @ List.map mk_qi_nc constraints) in + let typschm = mk_typschm typq typ in + + let spec_id = id_of_instantiation id instantiation in + if IdSet.mem spec_id !spec_ids then [] else + begin + spec_ids := IdSet.add spec_id !spec_ids; + print_endline (string_of_id spec_id ^ " : " ^ string_of_typschm typschm); + [DEF_spec (VS_aux (VS_val_spec (typschm, spec_id, externs, is_cast), annot))] + end + in + + let specializations = List.map specialize_instance instantiations |> List.concat in + + append_ast pre_ast (append_ast (Defs (vs :: specializations)) post_ast) + +let specialize_id_fundef instantiations id ast = + match split_defs (is_fundef id) ast with + | None -> ast + | Some (pre_ast, DEF_fundef fundef, post_ast) -> + let fundefs = + List.map (fun i -> DEF_fundef (rename_fundef (id_of_instantiation id i) fundef)) instantiations + in + append_ast pre_ast (append_ast (Defs fundefs) post_ast) + | Some _ -> assert false (* unreachable *) + +let specialize_id_overloads instantiations id (Defs defs) = + let ids = IdSet.of_list (List.map (id_of_instantiation id) instantiations) in + + let rec rewrite_overloads defs = + match defs with + | DEF_overload (overload_id, overloads) :: defs -> + let overloads = List.concat (List.map (fun id' -> if Id.compare id' id = 0 then IdSet.elements ids else [id']) overloads) in + DEF_overload (overload_id, overloads) :: rewrite_overloads defs + | def :: defs -> def :: rewrite_overloads defs + | [] -> [] + in + + Defs (rewrite_overloads defs) + +(* Once we've specialized a definition, it's original valspec should + be unused, unless another polymorphic function called it. We + therefore remove all unused valspecs. Remaining polymorphic + valspecs are then re-specialized. This process is iterated until + the whole spec is specialized. *) +let remove_unused_valspecs ast = + let calls = ref (IdSet.singleton (mk_id "main")) in + let vs_ids = Initial_check.val_spec_ids ast in + + let inspect_exp = function + | E_aux (E_app (call, _), _) as exp -> + calls := IdSet.add call !calls; + exp + | exp -> exp + in + + let rewrite_exp = { id_exp_alg with e_aux = (fun (exp, annot) -> inspect_exp (E_aux (exp, annot))) } in + let _ = rewrite_defs_base { rewriters_base with rewrite_exp = (fun _ -> fold_exp rewrite_exp) } ast in + + let unused = IdSet.filter (fun vs_id -> not (IdSet.mem vs_id !calls)) vs_ids in + + List.iter (fun id -> print_endline (string_of_id id)) (IdSet.elements unused); + + let rec remove_unused (Defs defs) id = + match defs with + | def :: defs when is_fundef id def -> remove_unused (Defs defs) id + | def :: defs when is_valspec id def -> remove_unused (Defs defs) id + | DEF_overload (overload_id, overloads) :: defs -> + begin + match List.filter (fun id' -> Id.compare id id' <> 0) overloads with + | [] -> remove_unused (Defs defs) id + | overloads -> DEF_overload (overload_id, overloads) :: remove_unused (Defs defs) id + end + | def :: defs -> def :: remove_unused (Defs defs) id + | [] -> [] + in + + List.fold_left (fun ast id -> Defs (remove_unused ast id)) ast (IdSet.elements unused) + +let specialize_id id ast = + print_endline ("Specializing: " ^ string_of_id id); + let instantiations = instantiations_of id ast in + + let ast = specialize_id_valspec instantiations id ast in + let ast = specialize_id_fundef instantiations id ast in + specialize_id_overloads instantiations id ast + +(* When we generate specialized versions of functions, we need to + ensure that the types they are specialized to appear before the + function definitions in the AST. Therefore we pull all the type + definitions (and default definitions) to the start of the AST. *) +let reorder_typedefs (Defs defs) = + let tdefs = ref [] in + + let rec filter_typedefs = function + | (DEF_default _ | DEF_type _) as tdef :: defs -> + tdefs := tdef :: !tdefs; + filter_typedefs defs + | def :: defs -> def :: filter_typedefs defs + | [] -> [] + in + + let others = filter_typedefs defs in + Defs (List.rev !tdefs @ others) + +let specialize_ids ids ast = + let ast = List.fold_left (fun ast id -> specialize_id id ast) ast (IdSet.elements ids) in + let ast = reorder_typedefs ast in + let ast, _ = Type_check.check Type_check.initial_env ast in + let ast = List.fold_left (fun ast id -> rewrite_polymorphic_calls id ast) ast (IdSet.elements ids) in + let ast, env = Type_check.check Type_check.initial_env ast in + let ast = remove_unused_valspecs ast in + ast, env + +let rec specialize ast env = + let ids = polymorphic_functions is_typ_kopt ast in + if IdSet.is_empty ids then + ast, env + else + let ast, env = specialize_ids ids ast in + specialize ast env diff --git a/src/type_check.ml b/src/type_check.ml index 27f6d8a2..10be3710 100644 --- a/src/type_check.ml +++ b/src/type_check.ml @@ -567,7 +567,7 @@ end = struct else () | Typ_id id -> typ_error l ("Undefined type " ^ string_of_id id) | Typ_var kid when KBindings.mem kid env.typ_vars -> () - | Typ_var kid -> typ_error l ("Unbound kind identifier " ^ string_of_kid kid) + | Typ_var kid -> typ_error l ("Unbound kind identifier " ^ string_of_kid kid ^ " in type " ^ string_of_typ typ) | Typ_fn (typ_arg, typ_ret, effs) -> wf_typ ~exs:exs env typ_arg; wf_typ ~exs:exs env typ_ret | Typ_tup typs -> List.iter (wf_typ ~exs:exs env) typs | Typ_app (id, args) when bound_typ_id env id -> @@ -3483,7 +3483,8 @@ let check_fundef env (FD_aux (FD_function (recopt, tannotopt, effectopt, funcls) the difference is irrelevant for the typechecker. *) let check_val_spec env (VS_aux (vs, (l, _))) = let (id, quants, typ, env) = match vs with - | VS_val_spec (TypSchm_aux (TypSchm_ts (quants, typ), _), id, ext_opt, is_cast) -> + | VS_val_spec (TypSchm_aux (TypSchm_ts (quants, typ), _) as typschm, id, ext_opt, is_cast) -> + typ_debug ("VS typschm: " ^ string_of_id id ^ ", " ^ string_of_typschm typschm); let env = match ext_opt "smt" with Some op -> Env.add_smt_op id op env | None -> env in Env.wf_typ (add_typquant quants env) typ; typ_debug "CHECKED WELL-FORMED VAL SPEC"; diff --git a/src/type_check.mli b/src/type_check.mli index 3f43492f..68e3ad09 100644 --- a/src/type_check.mli +++ b/src/type_check.mli @@ -164,6 +164,8 @@ end (* Push all the type variables and constraints from a typquant into an environment *) val add_typquant : typquant -> Env.t -> Env.t +val typ_frees : ?exs:KidSet.t -> typ -> KidSet.t + (* When the typechecker creates new type variables it gives them fresh names of the form 'fvXXX#name, where XXX is a number (not necessarily three digits), and name is the original name when the @@ -239,6 +241,8 @@ type uvar = val string_of_uvar : uvar -> string +val subst_unifiers : uvar KBindings.t -> typ -> typ + val unify : l -> Env.t -> typ -> typ -> uvar KBindings.t * kid list * n_constraint option val alpha_equivalent : Env.t -> typ -> typ -> bool diff --git a/src/util.ml b/src/util.ml index 51ed8926..e902b2dd 100644 --- a/src/util.ml +++ b/src/util.ml @@ -394,3 +394,4 @@ let red str = termcode 91 ^ str let cyan str = termcode 96 ^ str let blue str = termcode 94 ^ str let clear str = str ^ termcode 0 + |
