summaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
authorAlasdair Armstrong2018-03-08 18:14:44 +0000
committerAlasdair Armstrong2018-03-09 17:44:13 +0000
commit9570bf932e3ba0269cbed06a49fc8000b45b32a3 (patch)
treeb3f56057b891392c9de3b57e6e6c305a80635b33 /src
parentfccf018feecf914c937dc4cc253a882f482943f2 (diff)
Specialise constructors for polymorphic unions
Also work on making C backend compile RISC-V
Diffstat (limited to 'src')
-rw-r--r--src/c_backend.ml34
-rw-r--r--src/pretty_print_sail.ml6
-rw-r--r--src/sail.ml1
-rw-r--r--src/specialize.ml94
-rw-r--r--src/specialize.mli6
5 files changed, 125 insertions, 16 deletions
diff --git a/src/c_backend.ml b/src/c_backend.ml
index 2542dd42..20cdb4ac 100644
--- a/src/c_backend.ml
+++ b/src/c_backend.ml
@@ -519,6 +519,10 @@ let rec anf (E_aux (e_aux, exp_annot) as exp) =
let alast = anf last in
AE_block (aexps, alast, typ_of exp)
+ | E_assign (LEXP_aux (LEXP_deref dexp, _), exp) ->
+ let gs = gensym () in
+ AE_let (gs, typ_of dexp, anf dexp, AE_assign (gs, typ_of dexp, anf exp), unit_typ)
+
| E_assign (LEXP_aux (LEXP_id id, _), exp)
| E_assign (LEXP_aux (LEXP_cast (_, id), _), exp) ->
let aexp = anf exp in
@@ -748,6 +752,7 @@ let rec ctyp_equal ctyp1 ctyp2 =
| CT_real, CT_real -> true
| CT_vector (d1, ctyp1), CT_vector (d2, ctyp2) -> d1 = d2 && ctyp_equal ctyp1 ctyp2
| CT_list ctyp1, CT_list ctyp2 -> ctyp_equal ctyp1 ctyp2
+ | CT_ref ctyp1, CT_ref ctyp2 -> ctyp_equal ctyp1 ctyp2
| _, _ -> false
(* String representation of ctyps here is only for debugging and
@@ -769,6 +774,7 @@ let rec string_of_ctyp = function
| CT_vector (true, ctyp) -> "vector(dec, " ^ string_of_ctyp ctyp ^ ")"
| CT_vector (false, ctyp) -> "vector(inc, " ^ string_of_ctyp ctyp ^ ")"
| CT_list ctyp -> "list(" ^ string_of_ctyp ctyp ^ ")"
+ | CT_ref ctyp -> "ref(" ^ string_of_ctyp ctyp ^ ")"
(** Convert a sail type into a C-type **)
let rec ctyp_of_typ ctx typ =
@@ -814,6 +820,9 @@ let rec ctyp_of_typ ctx typ =
| Typ_id id when string_of_id id = "string" -> CT_string
| Typ_id id when string_of_id id = "real" -> CT_real
+ | Typ_app (id, [Typ_arg_aux (Typ_arg_typ typ, _)]) when string_of_id id = "register" || string_of_id id = "ref" ->
+ CT_ref (ctyp_of_typ ctx typ)
+
| Typ_id id | Typ_app (id, _) when Bindings.mem id ctx.records -> CT_struct (id, Bindings.find id ctx.records |> Bindings.bindings)
| Typ_id id | Typ_app (id, _) when Bindings.mem id ctx.variants -> CT_variant (id, Bindings.find id ctx.variants |> Bindings.bindings)
| Typ_id id when Bindings.mem id ctx.enums -> CT_enum (id, Bindings.find id ctx.enums |> IdSet.elements)
@@ -830,6 +839,7 @@ let rec is_stack_ctyp ctyp = match ctyp with
| CT_struct (_, fields) -> List.for_all (fun (_, ctyp) -> is_stack_ctyp ctyp) fields
| CT_variant (_, ctors) -> List.for_all (fun (_, ctyp) -> is_stack_ctyp ctyp) ctors
| CT_tup ctyps -> List.for_all is_stack_ctyp ctyps
+ | CT_ref ctyp -> is_stack_ctyp ctyp
let is_stack_typ ctx typ = is_stack_ctyp (ctyp_of_typ ctx typ)
@@ -1286,6 +1296,9 @@ let rec compile_aval ctx = function
| AV_id (id, typ) ->
[], (F_id id, ctyp_of_typ ctx (lvar_typ typ)), []
+ | AV_ref (id, typ) ->
+ [], (F_id id, CT_ref (ctyp_of_typ ctx (lvar_typ typ))), []
+
| AV_lit (L_aux (L_string str, _), typ) ->
[], (F_lit (V_string (String.escaped str)), ctyp_of_typ ctx typ), []
@@ -1435,9 +1448,6 @@ let rec compile_aval ctx = function
(F_id gs, CT_list ctyp),
[iclear (CT_list ctyp) gs]
- | AV_ref _ ->
- c_error "Have AV_ref"
-
let compile_funcall ctx id args typ =
let setup = ref [] in
let cleanup = ref [] in
@@ -1466,7 +1476,7 @@ let compile_funcall ctx id args typ =
(F_id gs, ctyp)
else
c_error ~loc:(id_loc id)
- (Printf.sprintf "Failure when setting up function arguments: %s and %s." (string_of_ctyp have_ctyp) (string_of_ctyp ctyp))
+ (Printf.sprintf "Failure when setting up function %s arguments: %s and %s." (string_of_id id) (string_of_ctyp have_ctyp) (string_of_ctyp ctyp))
in
let sargs = List.map2 setup_arg arg_ctyps args in
@@ -1725,11 +1735,18 @@ let rec compile_aexp ctx = function
(* assign_ctyp is the type of the C variable we are assigning to,
ctyp is the type of the C expression being assigned. These may
be different. *)
+ let pointer_assign ctyp1 ctyp2 =
+ match ctyp1 with
+ | CT_ref ctyp1 -> ctyp_equal ctyp1 ctyp2
+ | _ -> false
+ in
let assign_ctyp = ctyp_of_typ ctx assign_typ in
let setup, ctyp, call, cleanup = compile_aexp ctx aexp in
let comment = "assign " ^ string_of_ctyp assign_ctyp ^ " := " ^ string_of_ctyp ctyp in
if ctyp_equal assign_ctyp ctyp then
setup @ [call (CL_id id)], CT_unit, (fun clexp -> icopy clexp unit_fragment), cleanup
+ else if pointer_assign assign_ctyp ctyp then
+ setup @ [call (CL_addr id)], CT_unit, (fun clexp -> icopy clexp unit_fragment), cleanup
else if not (is_stack_ctyp assign_ctyp) && is_stack_ctyp ctyp then
let gs = gensym () in
setup @ [ icomment comment;
@@ -1850,7 +1867,8 @@ let rec pat_ids (Typ_aux (arg_typ_aux, _) as arg_typ) (P_aux (p_aux, (l, _)) as
| P_wild, _ -> let gs = gensym () in [gs]
| P_var (pat, _), _ -> pat_ids arg_typ pat
| P_typ (_, pat), _ -> pat_ids arg_typ pat
- | _, _ -> c_error ~loc:l ("Cannot compile pattern " ^ string_of_pat pat ^ " to C")
+ | P_app _, _ -> let gs = gensym () in [gs]
+ | _, _ -> c_error ~loc:l ("Cannot compile pattern " ^ string_of_pat pat ^ " : " ^ string_of_typ arg_typ ^ " to C")
(** Compile a sail type definition into a IR one. Most of the
actual work of translating the typedefs into C is done by the code
@@ -2513,7 +2531,7 @@ let codegen_id id = string (sgen_id id)
let upper_sgen_id id = Util.zencode_string (string_of_id id)
let upper_codegen_id id = string (upper_sgen_id id)
-let sgen_ctyp = function
+let rec sgen_ctyp = function
| CT_unit -> "unit"
| CT_bit -> "int"
| CT_bool -> "bool"
@@ -2529,8 +2547,9 @@ let sgen_ctyp = function
| CT_vector _ as v -> Util.zencode_string (string_of_ctyp v)
| CT_string -> "sail_string"
| CT_real -> "real"
+ | CT_ref ctyp -> sgen_ctyp ctyp ^ "*"
-let sgen_ctyp_name = function
+let rec sgen_ctyp_name = function
| CT_unit -> "unit"
| CT_bit -> "int"
| CT_bool -> "bool"
@@ -2546,6 +2565,7 @@ let sgen_ctyp_name = function
| CT_vector _ as v -> Util.zencode_string (string_of_ctyp v)
| CT_string -> "sail_string"
| CT_real -> "real"
+ | CT_ref ctyp -> "ref_" ^ sgen_ctyp_name ctyp
let sgen_cval_param (frag, ctyp) =
match ctyp with
diff --git a/src/pretty_print_sail.ml b/src/pretty_print_sail.ml
index 29284262..c2acda75 100644
--- a/src/pretty_print_sail.ml
+++ b/src/pretty_print_sail.ml
@@ -75,7 +75,7 @@ let rec doc_typ_pat (TP_aux (tpat_aux, _)) =
| TP_wild -> string "_"
| TP_var kid -> doc_kid kid
| TP_app (f, tpats) -> doc_id f ^^ parens (separate_map (comma ^^ space) doc_typ_pat tpats)
-
+
let rec doc_nexp =
let rec atomic_nexp (Nexp_aux (n_aux, _) as nexp) =
match n_aux with
@@ -228,7 +228,7 @@ let rec doc_pat (P_aux (p_aux, _) as pat) =
| P_vector pats -> brackets (separate_map (comma ^^ space) doc_pat pats)
| P_vector_concat pats -> separate_map (space ^^ string "@" ^^ space) doc_pat pats
| P_wild -> string "_"
- | P_as (pat, id) -> separate space [doc_pat pat; string "as"; doc_id id]
+ | P_as (pat, id) -> parens (separate space [doc_pat pat; string "as"; doc_id id])
| P_app (id, pats) -> doc_id id ^^ parens (separate_map (comma ^^ space) doc_pat pats)
| P_list pats -> string "[|" ^^ separate_map (comma ^^ space) doc_pat pats ^^ string "|]"
| _ -> string (string_of_pat pat)
@@ -486,7 +486,7 @@ let doc_spec (VS_aux(v,_)) =
let doc_extern ext =
let doc_backend b = Util.option_map (fun id -> string (b ^ ":") ^^ space ^^
utf8string ("\"" ^ String.escaped id ^ "\"")) (ext b) in
- let docs = Util.option_these (List.map doc_backend ["ocaml"; "lem"; "smt"; "interpreter"]) in
+ let docs = Util.option_these (List.map doc_backend ["ocaml"; "lem"; "smt"; "interpreter"; "c"]) in
if docs = [] then empty else equals ^^ space ^^ braces (separate (comma ^^ space) docs)
in
match v with
diff --git a/src/sail.ml b/src/sail.ml
index 9a5acf1c..2e13bc01 100644
--- a/src/sail.ml
+++ b/src/sail.ml
@@ -271,6 +271,7 @@ let main() =
then
let ast_c = rewrite_ast_c ast in
let ast_c, type_envs = Specialize.specialize ast_c type_envs in
+ let ast_c = Spec_analysis.top_sort_defs ast_c in
C_backend.compile_ast (C_backend.initial_ctx type_envs) ast_c
else ());
(if !(opt_print_lem)
diff --git a/src/specialize.ml b/src/specialize.ml
index 4f8a7e7e..0a4e0fbb 100644
--- a/src/specialize.ml
+++ b/src/specialize.ml
@@ -309,8 +309,8 @@ let specialize_id_overloads instantiations id (Defs defs) =
therefore remove all unused valspecs. Remaining polymorphic
valspecs are then re-specialized. This process is iterated until
the whole spec is specialized. *)
-let remove_unused_valspecs ast =
- let calls = ref (IdSet.of_list [mk_id "main"; mk_id "execute"; mk_id "decode"; mk_id "initialize_registers"]) in
+let remove_unused_valspecs env ast =
+ let calls = ref (IdSet.of_list [mk_id "main"; mk_id "execute"; mk_id "decode"; mk_id "initialize_registers"; mk_id "append_64"]) in
let vs_ids = Initial_check.val_spec_ids ast in
let inspect_exp = function
@@ -328,7 +328,9 @@ let remove_unused_valspecs ast =
let rec remove_unused (Defs defs) id =
match defs with
| def :: defs when is_fundef id def -> remove_unused (Defs defs) id
- | def :: defs when is_valspec id def -> remove_unused (Defs defs) id
+ | def :: defs when is_valspec id def ->
+ prerr_endline ("Removing: " ^ string_of_id id);
+ remove_unused (Defs defs) id
| DEF_overload (overload_id, overloads) :: defs ->
begin
match List.filter (fun id' -> Id.compare id id' <> 0) overloads with
@@ -372,13 +374,93 @@ let specialize_ids ids ast =
let ast, _ = Type_check.check Type_check.initial_env ast in
let ast = List.fold_left (fun ast id -> rewrite_polymorphic_calls id ast) ast (IdSet.elements ids) in
let ast, env = Type_check.check Type_check.initial_env ast in
- let ast = remove_unused_valspecs ast in
+ let ast = remove_unused_valspecs env ast in
ast, env
+(***** Specialising polymorphic variant types, e.g. option *****)
+
+let rec variant_generic_typ id (Defs defs) =
+ match defs with
+ | DEF_type (TD_aux (TD_variant (id', _, typq, _, _), _)) :: _ ->
+ mk_typ (Typ_app (id', List.map (fun kopt -> mk_typ_arg (Typ_arg_typ (mk_typ (Typ_var (kopt_kid kopt))))) (quant_kopts typq)))
+ | _ :: defs -> variant_generic_typ id (Defs defs)
+ | [] -> failwith ("No variant with id " ^ string_of_id id)
+
+let rewrite_polymorphic_constructors id ast =
+ let rewrite_e_aux = function
+ | E_aux (E_app (id', args), annot) as exp when Id.compare id id' = 0 ->
+ let instantiation = fix_instantiation (Type_check.instantiation_of exp) in
+ let spec_id = id_of_instantiation id instantiation in
+ E_aux (E_app (spec_id, args), annot)
+ | exp -> exp
+ in
+ let rewrite_p_aux = function
+ | P_aux (P_app (id', args), annot) as pat when Id.compare id id' = 0 ->
+ begin match Type_check.typ_of_annot annot with
+ | Typ_aux (Typ_app (variant_id, _), _) as typ ->
+ let open Type_check in
+ let instantiation, _, _ = unify (fst annot) (env_of_annot annot)
+ (variant_generic_typ variant_id ast)
+ (typ_of_annot annot)
+ in
+ (* FIXME: What if instantiation only involves U_nexps? *)
+ let instantiation = fix_instantiation instantiation in
+ P_aux (P_app (id_of_instantiation id' instantiation, args), annot)
+ | Typ_aux (Typ_id variant_id, _) -> pat
+ | _ -> failwith ("Union constructor " ^ string_of_pat pat ^ " has non-union type")
+ end
+ | pat -> pat
+ in
+
+ let rewrite_pat = { id_pat_alg with p_aux = (fun (pat, annot) -> rewrite_p_aux (P_aux (pat, annot))) } in
+ let rewrite_exp = { id_exp_alg with pat_alg = rewrite_pat;
+ e_aux = (fun (exp, annot) -> rewrite_e_aux (E_aux (exp, annot))) } in
+ rewrite_defs_base { rewriters_base with rewrite_exp = (fun _ -> fold_exp rewrite_exp);
+ rewrite_pat = (fun _ -> fold_pat rewrite_pat)} ast
+
+let specialize_variants ((Defs defs) as ast) env =
+ let ctors = ref [] in
+
+ let specialize_variant (TD_aux (tdef_aux, annot)) ast env =
+ match tdef_aux with
+ | TD_variant (id, name_scheme, typq, tus, flag) as variant ->
+ let kopts = List.filter (fun kopt -> is_typ_kopt kopt || is_order_kopt kopt) (quant_kopts typq) in
+ if kopts = [] then
+ (* If non-polymorphic, then do nothing. *)
+ TD_aux (variant, annot)
+ else
+ let specialize_tu (Tu_aux (Tu_ty_id (typ, id), annot)) =
+ ctors := id :: !ctors;
+ let is = instantiations_of id ast in
+ let is = List.sort_uniq (fun i1 i2 -> String.compare (string_of_instantiation i1) (string_of_instantiation i2)) is in
+ List.map (fun i -> Tu_aux (Tu_ty_id (Type_check.subst_unifiers i typ, id_of_instantiation id i), annot)) is
+ in
+ (*
+ let kopts, constraints = quant_split typq in
+ let kopts = List.filter (fun kopt -> not (is_typ_kopt kopt || is_order_kopt kopt)) kopts in
+ let typq = mk_typquant (List.map mk_qi_kopt kopts @ List.map mk_qi_nc constraints) in
+ *)
+ TD_aux (TD_variant (id, name_scheme, typq, List.concat (List.map specialize_tu tus), flag), annot)
+ | _ -> assert false
+ in
+
+ let rec specialize_variants' = function
+ | DEF_type (TD_aux (TD_variant _, _) as tdef) :: defs ->
+ DEF_type (specialize_variant tdef ast env) :: specialize_variants' defs
+ | def :: defs ->
+ def :: specialize_variants' defs
+ | [] -> []
+ in
+
+ let ast = Defs (specialize_variants' defs) in
+ let ast = List.fold_left (fun ast id -> rewrite_polymorphic_constructors id ast) ast !ctors in
+ Type_check.check Type_check.initial_env ast
+
let rec specialize ast env =
let ids = polymorphic_functions (fun kopt -> is_typ_kopt kopt || is_order_kopt kopt) ast in
if IdSet.is_empty ids then
- ast, env
+ specialize_variants ast env
else
+ (prerr_endline (Util.string_of_list ", " string_of_id (IdSet.elements ids));
let ast, env = specialize_ids ids ast in
- specialize ast env
+ specialize ast env)
diff --git a/src/specialize.mli b/src/specialize.mli
index 0c4a2495..474d3c9d 100644
--- a/src/specialize.mli
+++ b/src/specialize.mli
@@ -67,3 +67,9 @@ val polymorphic_functions : (kinded_id -> bool) -> 'a defs -> IdSet.t
environment to return if there is no polymorphism to remove, in
which case specialize returns the AST unmodified. *)
val specialize : tannot defs -> Env.t -> tannot defs * Env.t
+
+val specialize_variants : tannot defs -> Env.t -> tannot defs * Env.t
+
+val instantiations_of : id -> tannot defs -> uvar KBindings.t list
+
+val string_of_instantiation : uvar KBindings.t -> string