summaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
Diffstat (limited to 'src')
-rw-r--r--src/ast_util.ml29
-rw-r--r--src/ast_util.mli11
-rw-r--r--src/isail.ml18
-rw-r--r--src/pretty_print_sail.ml3
-rw-r--r--src/sail_lib.ml4
-rw-r--r--src/specialize.ml222
-rw-r--r--src/type_check.ml5
-rw-r--r--src/type_check.mli4
-rw-r--r--src/util.ml1
9 files changed, 284 insertions, 13 deletions
diff --git a/src/ast_util.ml b/src/ast_util.ml
index a70db3e0..f1add52b 100644
--- a/src/ast_util.ml
+++ b/src/ast_util.ml
@@ -87,6 +87,8 @@ let mk_qi_id bk kid =
in
QI_aux (QI_id kopt, Parse_ast.Unknown)
+let mk_qi_kopt kopt =QI_aux (QI_id kopt, Parse_ast.Unknown)
+
let mk_fundef funcls =
let tannot_opt = Typ_annot_opt_aux (Typ_annot_opt_none, Parse_ast.Unknown) in
let effect_opt = Effect_opt_aux (Effect_opt_pure, Parse_ast.Unknown) in
@@ -896,3 +898,30 @@ let construct_pexp (pat,guard,exp,ann) =
match guard with
| None -> Pat_aux (Pat_exp (pat,exp),ann)
| Some guard -> Pat_aux (Pat_when (pat,guard,exp),ann)
+
+let is_valspec id = function
+ | DEF_spec (VS_aux (VS_val_spec (_, id', _, _), _)) when Id.compare id id' = 0 -> true
+ | _ -> false
+
+let is_fundef id = function
+ | DEF_fundef (FD_aux (FD_function (_, _, _, FCL_aux (FCL_Funcl (id', _), _) :: _), _)) when Id.compare id' id = 0 -> true
+ | _ -> false
+
+let rename_funcl id (FCL_aux (FCL_Funcl (_, pexp), annot)) = FCL_aux (FCL_Funcl (id, pexp), annot)
+
+let rename_fundef id (FD_aux (FD_function (ropt, topt, eopt, funcls), annot)) =
+ FD_aux (FD_function (ropt, topt, eopt, List.map (rename_funcl id) funcls), annot)
+
+let rec split_defs' f defs acc =
+ match defs with
+ | [] -> None
+ | def :: defs when f def -> Some (acc, def, defs)
+ | def :: defs -> split_defs' f defs (def :: acc)
+
+let split_defs f (Defs defs) =
+ match split_defs' f defs [] with
+ | None -> None
+ | Some (pre_defs, def, post_defs) ->
+ Some (Defs (List.rev pre_defs), def, Defs post_defs)
+
+let append_ast (Defs ast1) (Defs ast2) = Defs (ast1 @ ast2)
diff --git a/src/ast_util.mli b/src/ast_util.mli
index 69d80ea7..349faac2 100644
--- a/src/ast_util.mli
+++ b/src/ast_util.mli
@@ -71,6 +71,7 @@ val mk_typschm : typquant -> typ -> typschm
val mk_typquant : quant_item list -> typquant
val mk_qi_id : base_kind_aux -> kid -> quant_item
val mk_qi_nc : n_constraint -> quant_item
+val mk_qi_kopt : kinded_id -> quant_item
val mk_fexp : id -> unit exp -> unit fexp
val mk_fexps : (unit fexp) list -> unit fexps
val mk_letbind : unit pat -> unit exp -> unit letbind
@@ -280,3 +281,13 @@ val undefined_of_typ : bool -> Ast.l -> (typ -> 'annot) -> typ -> 'annot exp
val destruct_pexp : 'a pexp -> 'a pat * ('a exp) option * 'a exp * (Ast.l * 'a)
val construct_pexp : 'a pat * ('a exp) option * 'a exp * (Ast.l * 'a) -> 'a pexp
+
+val is_valspec : id -> 'a def -> bool
+
+val is_fundef : id -> 'a def -> bool
+
+val rename_fundef : id -> 'a fundef -> 'a fundef
+
+val split_defs : ('a def -> bool) -> 'a defs -> ('a defs * 'a def * 'a defs) option
+
+val append_ast : 'a defs -> 'a defs -> 'a defs
diff --git a/src/isail.ml b/src/isail.ml
index 629bf35a..ce17561d 100644
--- a/src/isail.ml
+++ b/src/isail.ml
@@ -186,7 +186,6 @@ let help = function
| cmd ->
"Either invalid command passed to help, or no documentation for " ^ cmd ^ ". Try :help :help."
-let append_ast (Defs ast1) (Defs ast2) = Defs (ast1 @ ast2)
type input = Command of string * string | Expression of string | Empty
@@ -259,13 +258,16 @@ let handle_input' input =
in
let ids = Specialize.polymorphic_functions is_kopt !interactive_ast in
List.iter (fun id -> print_endline (string_of_id id)) (IdSet.elements ids)
- | ":ins" ->
- let id = mk_id arg in
- let instantiations = Specialize.instantiations_of id !interactive_ast in
- let print_instantiation i =
- print_endline (Util.string_of_list ", " (fun (kid, uvar) -> string_of_kid kid ^ " => " ^ Type_check.string_of_uvar uvar) (KBindings.bindings i))
- in
- List.iter print_instantiation instantiations
+ | ":spec" ->
+ let ast, env = Specialize.specialize !interactive_ast !interactive_env in
+ interactive_ast := ast;
+ interactive_env := env;
+ interactive_state := initial_state !interactive_ast
+
+ | ":ast" ->
+ let chan = open_out arg in
+ Pretty_print_sail.pp_defs chan !interactive_ast;
+ close_out chan
| ":output" ->
let chan = open_out arg in
Value.output_redirect chan
diff --git a/src/pretty_print_sail.ml b/src/pretty_print_sail.ml
index 43be7a00..b55448fa 100644
--- a/src/pretty_print_sail.ml
+++ b/src/pretty_print_sail.ml
@@ -477,11 +477,12 @@ let doc_typdef (TD_aux(td,_)) = match td with
surround 2 0 lbrace (separate_map (comma ^^ break 1) doc_union unions) rbrace]
| _ -> string "TYPEDEF"
+
let doc_spec (VS_aux(v,_)) =
let doc_extern ext =
let doc_backend b = Util.option_map (fun id -> string (b ^ ":") ^^ space ^^
utf8string ("\"" ^ String.escaped id ^ "\"")) (ext b) in
- let docs = Util.option_these (List.map doc_backend ["ocaml"; "lem"]) in
+ let docs = Util.option_these (List.map doc_backend ["ocaml"; "lem"; "smt"; "interpreter"]) in
if docs = [] then empty else equals ^^ space ^^ braces (separate (comma ^^ space) docs)
in
match v with
diff --git a/src/sail_lib.ml b/src/sail_lib.ml
index 86c12aae..c78b81a9 100644
--- a/src/sail_lib.ml
+++ b/src/sail_lib.ml
@@ -298,6 +298,10 @@ let string_of_bit = function
| B0 -> "0"
| B1 -> "1"
+let char_of_bit = function
+ | B0 -> '0'
+ | B1 -> '1'
+
let string_of_hex = function
| [B0; B0; B0; B0] -> "0"
| [B0; B0; B0; B1] -> "1"
diff --git a/src/specialize.ml b/src/specialize.ml
index 881881f4..2d32a90c 100644
--- a/src/specialize.ml
+++ b/src/specialize.ml
@@ -52,10 +52,40 @@ open Ast
open Ast_util
open Rewriter
+let zchar c =
+ let zc c = "z" ^ String.make 1 c in
+ if Char.code c <= 41 then zc (Char.chr (Char.code c + 16))
+ else if Char.code c <= 47 then zc (Char.chr (Char.code c + 23))
+ else if Char.code c <= 57 then String.make 1 c
+ else if Char.code c <= 64 then zc (Char.chr (Char.code c + 13))
+ else if Char.code c <= 90 then String.make 1 c
+ else if Char.code c <= 94 then zc (Char.chr (Char.code c - 13))
+ else if Char.code c <= 95 then "_"
+ else if Char.code c <= 96 then zc (Char.chr (Char.code c - 13))
+ else if Char.code c <= 121 then String.make 1 c
+ else if Char.code c <= 122 then "zz"
+ else if Char.code c <= 126 then zc (Char.chr (Char.code c - 39))
+ else raise (Invalid_argument "zchar")
+
+let zencode_string str = "z" ^ List.fold_left (fun s1 s2 -> s1 ^ s2) "" (List.map zchar (Util.string_to_list str))
+
+let zencode_upper_string str = "Z" ^ List.fold_left (fun s1 s2 -> s1 ^ s2) "" (List.map zchar (Util.string_to_list str))
+
+let is_typ_uvar = function
+ | Type_check.U_typ _ -> true
+ | _ -> false
+
+(* We have to be careful about whether the typechecker has renamed anything returned by instantiation_of.
+ This part of the typechecker API is a bit ugly. *)
+let fix_instantiation instantiation =
+ let instantiation = KBindings.bindings (KBindings.filter (fun _ uvar -> is_typ_uvar uvar) instantiation) in
+ let instantiation = List.map (fun (kid, uvar) -> Type_check.orig_kid kid, uvar) instantiation in
+ List.fold_left (fun m (k, v) -> KBindings.add k v m) KBindings.empty instantiation
+
(* Returns an IdSet with the function ids that have X-kinded
parameters, e.g. val f : forall ('a : X). 'a -> 'a. The first
argument specifies what X should be - it should be one of:
- is_nat_kopt, is_order_kopt, or is_type_kopt from Ast_util.
+ is_nat_kopt, is_order_kopt, or is_typ_kopt from Ast_util.
*)
let rec polymorphic_functions is_kopt (Defs defs) =
match defs with
@@ -68,6 +98,11 @@ let rec polymorphic_functions is_kopt (Defs defs) =
| _ :: defs -> polymorphic_functions is_kopt (Defs defs)
| [] -> IdSet.empty
+let id_of_instantiation id instantiation =
+ let string_of_binding (kid, uvar) = string_of_kid kid ^ " => " ^ Type_check.string_of_uvar uvar in
+ let str = zencode_string (Util.string_of_list ", " string_of_binding (KBindings.bindings instantiation)) ^ "#" in
+ prepend_id str id
+
(* Returns a list of all the instantiations of a function id in an
ast. *)
let rec instantiations_of id ast =
@@ -75,7 +110,8 @@ let rec instantiations_of id ast =
let inspect_exp = function
| E_aux (E_app (id', _), _) as exp when Id.compare id id' = 0 ->
- instantiations := Type_check.instantiation_of exp :: !instantiations;
+ let instantiation = fix_instantiation (Type_check.instantiation_of exp) in
+ instantiations := instantiation :: !instantiations;
exp
| exp -> exp
in
@@ -84,3 +120,185 @@ let rec instantiations_of id ast =
let _ = rewrite_defs_base { rewriters_base with rewrite_exp = (fun _ -> fold_exp rewrite_exp) } ast in
!instantiations
+
+let rec rewrite_polymorphic_calls id ast =
+ print_endline ("Rewriting: " ^ string_of_id id);
+ let vs_ids = Initial_check.val_spec_ids ast in
+
+ let rewrite_e_aux = function
+ | E_aux (E_app (id', args), annot) as exp when Id.compare id id' = 0 ->
+ let instantiation = fix_instantiation (Type_check.instantiation_of exp) in
+ let spec_id = id_of_instantiation id instantiation in
+ (* Make sure we only generate specialized calls when we've
+ specialized the valspec. The valspec may not be generated if
+ a polymorphic function calls another polymorphic function.
+ In this case a specialization of the first may require that
+ the second needs to be specialized again, but this may not
+ have happened yet. *)
+ if IdSet.mem spec_id vs_ids then
+ E_aux (E_app (spec_id, args), annot)
+ else
+ exp
+ | exp -> exp
+ in
+
+ let rewrite_exp = { id_exp_alg with e_aux = (fun (exp, annot) -> rewrite_e_aux (E_aux (exp, annot))) } in
+ rewrite_defs_base { rewriters_base with rewrite_exp = (fun _ -> fold_exp rewrite_exp) } ast
+
+let rec typ_frees ?exs:(exs=KidSet.empty) (Typ_aux (typ_aux, l)) =
+ match typ_aux with
+ | Typ_id v -> KidSet.empty
+ | Typ_var kid when KidSet.mem kid exs -> KidSet.empty
+ | Typ_var kid -> KidSet.singleton kid
+ | Typ_tup typs -> List.fold_left KidSet.union KidSet.empty (List.map (typ_frees ~exs:exs) typs)
+ | Typ_app (f, args) -> List.fold_left KidSet.union KidSet.empty (List.map (typ_arg_frees ~exs:exs) args)
+ | Typ_exist (kids, nc, typ) -> typ_frees ~exs:(KidSet.of_list kids) typ
+ | Typ_fn (typ1, typ2, _) -> KidSet.union (typ_frees ~exs:exs typ1) (typ_frees ~exs:exs typ2)
+and typ_arg_frees ?exs:(exs=KidSet.empty) (Typ_arg_aux (typ_arg_aux, l)) =
+ match typ_arg_aux with
+ | Typ_arg_nexp n -> KidSet.empty
+ | Typ_arg_typ typ -> typ_frees ~exs:exs typ
+ | Typ_arg_order ord -> KidSet.empty
+
+let specialize_id_valspec instantiations id ast =
+ match split_defs (is_valspec id) ast with
+ | None -> failwith ("Valspec " ^ string_of_id id ^ " does not exist!")
+ | Some (pre_ast, vs, post_ast) ->
+ let typschm, externs, is_cast, annot = match vs with
+ | DEF_spec (VS_aux (VS_val_spec (typschm, _, externs, is_cast), annot)) -> typschm, externs, is_cast, annot
+ | _ -> assert false (* unreachable *)
+ in
+ let TypSchm_aux (TypSchm_ts (typq, typ), _) = typschm in
+
+ (* Keep track of the specialized ids to avoid generating things twice. *)
+ let spec_ids = ref IdSet.empty in
+
+ let specialize_instance instantiation =
+ (* Replace the polymorphic type variables in the type with their concrete instantiation. *)
+ let typ = Type_check.subst_unifiers instantiation typ in
+ let frees = KidSet.elements (typ_frees typ) in
+
+ (* Remove type variables from the type quantifier. *)
+ let kopts, constraints = quant_split typq in
+ let kopts = List.filter (fun kopt -> not (is_typ_kopt kopt)) kopts in
+ let typq = mk_typquant (List.map (mk_qi_id BK_type) frees @ List.map mk_qi_kopt kopts @ List.map mk_qi_nc constraints) in
+ let typschm = mk_typschm typq typ in
+
+ let spec_id = id_of_instantiation id instantiation in
+ if IdSet.mem spec_id !spec_ids then [] else
+ begin
+ spec_ids := IdSet.add spec_id !spec_ids;
+ print_endline (string_of_id spec_id ^ " : " ^ string_of_typschm typschm);
+ [DEF_spec (VS_aux (VS_val_spec (typschm, spec_id, externs, is_cast), annot))]
+ end
+ in
+
+ let specializations = List.map specialize_instance instantiations |> List.concat in
+
+ append_ast pre_ast (append_ast (Defs (vs :: specializations)) post_ast)
+
+let specialize_id_fundef instantiations id ast =
+ match split_defs (is_fundef id) ast with
+ | None -> ast
+ | Some (pre_ast, DEF_fundef fundef, post_ast) ->
+ let fundefs =
+ List.map (fun i -> DEF_fundef (rename_fundef (id_of_instantiation id i) fundef)) instantiations
+ in
+ append_ast pre_ast (append_ast (Defs fundefs) post_ast)
+ | Some _ -> assert false (* unreachable *)
+
+let specialize_id_overloads instantiations id (Defs defs) =
+ let ids = IdSet.of_list (List.map (id_of_instantiation id) instantiations) in
+
+ let rec rewrite_overloads defs =
+ match defs with
+ | DEF_overload (overload_id, overloads) :: defs ->
+ let overloads = List.concat (List.map (fun id' -> if Id.compare id' id = 0 then IdSet.elements ids else [id']) overloads) in
+ DEF_overload (overload_id, overloads) :: rewrite_overloads defs
+ | def :: defs -> def :: rewrite_overloads defs
+ | [] -> []
+ in
+
+ Defs (rewrite_overloads defs)
+
+(* Once we've specialized a definition, it's original valspec should
+ be unused, unless another polymorphic function called it. We
+ therefore remove all unused valspecs. Remaining polymorphic
+ valspecs are then re-specialized. This process is iterated until
+ the whole spec is specialized. *)
+let remove_unused_valspecs ast =
+ let calls = ref (IdSet.singleton (mk_id "main")) in
+ let vs_ids = Initial_check.val_spec_ids ast in
+
+ let inspect_exp = function
+ | E_aux (E_app (call, _), _) as exp ->
+ calls := IdSet.add call !calls;
+ exp
+ | exp -> exp
+ in
+
+ let rewrite_exp = { id_exp_alg with e_aux = (fun (exp, annot) -> inspect_exp (E_aux (exp, annot))) } in
+ let _ = rewrite_defs_base { rewriters_base with rewrite_exp = (fun _ -> fold_exp rewrite_exp) } ast in
+
+ let unused = IdSet.filter (fun vs_id -> not (IdSet.mem vs_id !calls)) vs_ids in
+
+ List.iter (fun id -> print_endline (string_of_id id)) (IdSet.elements unused);
+
+ let rec remove_unused (Defs defs) id =
+ match defs with
+ | def :: defs when is_fundef id def -> remove_unused (Defs defs) id
+ | def :: defs when is_valspec id def -> remove_unused (Defs defs) id
+ | DEF_overload (overload_id, overloads) :: defs ->
+ begin
+ match List.filter (fun id' -> Id.compare id id' <> 0) overloads with
+ | [] -> remove_unused (Defs defs) id
+ | overloads -> DEF_overload (overload_id, overloads) :: remove_unused (Defs defs) id
+ end
+ | def :: defs -> def :: remove_unused (Defs defs) id
+ | [] -> []
+ in
+
+ List.fold_left (fun ast id -> Defs (remove_unused ast id)) ast (IdSet.elements unused)
+
+let specialize_id id ast =
+ print_endline ("Specializing: " ^ string_of_id id);
+ let instantiations = instantiations_of id ast in
+
+ let ast = specialize_id_valspec instantiations id ast in
+ let ast = specialize_id_fundef instantiations id ast in
+ specialize_id_overloads instantiations id ast
+
+(* When we generate specialized versions of functions, we need to
+ ensure that the types they are specialized to appear before the
+ function definitions in the AST. Therefore we pull all the type
+ definitions (and default definitions) to the start of the AST. *)
+let reorder_typedefs (Defs defs) =
+ let tdefs = ref [] in
+
+ let rec filter_typedefs = function
+ | (DEF_default _ | DEF_type _) as tdef :: defs ->
+ tdefs := tdef :: !tdefs;
+ filter_typedefs defs
+ | def :: defs -> def :: filter_typedefs defs
+ | [] -> []
+ in
+
+ let others = filter_typedefs defs in
+ Defs (List.rev !tdefs @ others)
+
+let specialize_ids ids ast =
+ let ast = List.fold_left (fun ast id -> specialize_id id ast) ast (IdSet.elements ids) in
+ let ast = reorder_typedefs ast in
+ let ast, _ = Type_check.check Type_check.initial_env ast in
+ let ast = List.fold_left (fun ast id -> rewrite_polymorphic_calls id ast) ast (IdSet.elements ids) in
+ let ast, env = Type_check.check Type_check.initial_env ast in
+ let ast = remove_unused_valspecs ast in
+ ast, env
+
+let rec specialize ast env =
+ let ids = polymorphic_functions is_typ_kopt ast in
+ if IdSet.is_empty ids then
+ ast, env
+ else
+ let ast, env = specialize_ids ids ast in
+ specialize ast env
diff --git a/src/type_check.ml b/src/type_check.ml
index 27f6d8a2..10be3710 100644
--- a/src/type_check.ml
+++ b/src/type_check.ml
@@ -567,7 +567,7 @@ end = struct
else ()
| Typ_id id -> typ_error l ("Undefined type " ^ string_of_id id)
| Typ_var kid when KBindings.mem kid env.typ_vars -> ()
- | Typ_var kid -> typ_error l ("Unbound kind identifier " ^ string_of_kid kid)
+ | Typ_var kid -> typ_error l ("Unbound kind identifier " ^ string_of_kid kid ^ " in type " ^ string_of_typ typ)
| Typ_fn (typ_arg, typ_ret, effs) -> wf_typ ~exs:exs env typ_arg; wf_typ ~exs:exs env typ_ret
| Typ_tup typs -> List.iter (wf_typ ~exs:exs env) typs
| Typ_app (id, args) when bound_typ_id env id ->
@@ -3483,7 +3483,8 @@ let check_fundef env (FD_aux (FD_function (recopt, tannotopt, effectopt, funcls)
the difference is irrelevant for the typechecker. *)
let check_val_spec env (VS_aux (vs, (l, _))) =
let (id, quants, typ, env) = match vs with
- | VS_val_spec (TypSchm_aux (TypSchm_ts (quants, typ), _), id, ext_opt, is_cast) ->
+ | VS_val_spec (TypSchm_aux (TypSchm_ts (quants, typ), _) as typschm, id, ext_opt, is_cast) ->
+ typ_debug ("VS typschm: " ^ string_of_id id ^ ", " ^ string_of_typschm typschm);
let env = match ext_opt "smt" with Some op -> Env.add_smt_op id op env | None -> env in
Env.wf_typ (add_typquant quants env) typ;
typ_debug "CHECKED WELL-FORMED VAL SPEC";
diff --git a/src/type_check.mli b/src/type_check.mli
index 3f43492f..68e3ad09 100644
--- a/src/type_check.mli
+++ b/src/type_check.mli
@@ -164,6 +164,8 @@ end
(* Push all the type variables and constraints from a typquant into an environment *)
val add_typquant : typquant -> Env.t -> Env.t
+val typ_frees : ?exs:KidSet.t -> typ -> KidSet.t
+
(* When the typechecker creates new type variables it gives them fresh
names of the form 'fvXXX#name, where XXX is a number (not
necessarily three digits), and name is the original name when the
@@ -239,6 +241,8 @@ type uvar =
val string_of_uvar : uvar -> string
+val subst_unifiers : uvar KBindings.t -> typ -> typ
+
val unify : l -> Env.t -> typ -> typ -> uvar KBindings.t * kid list * n_constraint option
val alpha_equivalent : Env.t -> typ -> typ -> bool
diff --git a/src/util.ml b/src/util.ml
index 51ed8926..e902b2dd 100644
--- a/src/util.ml
+++ b/src/util.ml
@@ -394,3 +394,4 @@ let red str = termcode 91 ^ str
let cyan str = termcode 96 ^ str
let blue str = termcode 94 ^ str
let clear str = str ^ termcode 0
+