From 9570bf932e3ba0269cbed06a49fc8000b45b32a3 Mon Sep 17 00:00:00 2001 From: Alasdair Armstrong Date: Thu, 8 Mar 2018 18:14:44 +0000 Subject: Specialise constructors for polymorphic unions Also work on making C backend compile RISC-V --- src/c_backend.ml | 34 ++++++++++++++---- src/pretty_print_sail.ml | 6 ++-- src/sail.ml | 1 + src/specialize.ml | 94 ++++++++++++++++++++++++++++++++++++++++++++---- src/specialize.mli | 6 ++++ 5 files changed, 125 insertions(+), 16 deletions(-) (limited to 'src') diff --git a/src/c_backend.ml b/src/c_backend.ml index 2542dd42..20cdb4ac 100644 --- a/src/c_backend.ml +++ b/src/c_backend.ml @@ -519,6 +519,10 @@ let rec anf (E_aux (e_aux, exp_annot) as exp) = let alast = anf last in AE_block (aexps, alast, typ_of exp) + | E_assign (LEXP_aux (LEXP_deref dexp, _), exp) -> + let gs = gensym () in + AE_let (gs, typ_of dexp, anf dexp, AE_assign (gs, typ_of dexp, anf exp), unit_typ) + | E_assign (LEXP_aux (LEXP_id id, _), exp) | E_assign (LEXP_aux (LEXP_cast (_, id), _), exp) -> let aexp = anf exp in @@ -748,6 +752,7 @@ let rec ctyp_equal ctyp1 ctyp2 = | CT_real, CT_real -> true | CT_vector (d1, ctyp1), CT_vector (d2, ctyp2) -> d1 = d2 && ctyp_equal ctyp1 ctyp2 | CT_list ctyp1, CT_list ctyp2 -> ctyp_equal ctyp1 ctyp2 + | CT_ref ctyp1, CT_ref ctyp2 -> ctyp_equal ctyp1 ctyp2 | _, _ -> false (* String representation of ctyps here is only for debugging and @@ -769,6 +774,7 @@ let rec string_of_ctyp = function | CT_vector (true, ctyp) -> "vector(dec, " ^ string_of_ctyp ctyp ^ ")" | CT_vector (false, ctyp) -> "vector(inc, " ^ string_of_ctyp ctyp ^ ")" | CT_list ctyp -> "list(" ^ string_of_ctyp ctyp ^ ")" + | CT_ref ctyp -> "ref(" ^ string_of_ctyp ctyp ^ ")" (** Convert a sail type into a C-type **) let rec ctyp_of_typ ctx typ = @@ -814,6 +820,9 @@ let rec ctyp_of_typ ctx typ = | Typ_id id when string_of_id id = "string" -> CT_string | Typ_id id when string_of_id id = "real" -> CT_real + | Typ_app (id, [Typ_arg_aux (Typ_arg_typ typ, _)]) when string_of_id id = "register" || string_of_id id = "ref" -> + CT_ref (ctyp_of_typ ctx typ) + | Typ_id id | Typ_app (id, _) when Bindings.mem id ctx.records -> CT_struct (id, Bindings.find id ctx.records |> Bindings.bindings) | Typ_id id | Typ_app (id, _) when Bindings.mem id ctx.variants -> CT_variant (id, Bindings.find id ctx.variants |> Bindings.bindings) | Typ_id id when Bindings.mem id ctx.enums -> CT_enum (id, Bindings.find id ctx.enums |> IdSet.elements) @@ -830,6 +839,7 @@ let rec is_stack_ctyp ctyp = match ctyp with | CT_struct (_, fields) -> List.for_all (fun (_, ctyp) -> is_stack_ctyp ctyp) fields | CT_variant (_, ctors) -> List.for_all (fun (_, ctyp) -> is_stack_ctyp ctyp) ctors | CT_tup ctyps -> List.for_all is_stack_ctyp ctyps + | CT_ref ctyp -> is_stack_ctyp ctyp let is_stack_typ ctx typ = is_stack_ctyp (ctyp_of_typ ctx typ) @@ -1286,6 +1296,9 @@ let rec compile_aval ctx = function | AV_id (id, typ) -> [], (F_id id, ctyp_of_typ ctx (lvar_typ typ)), [] + | AV_ref (id, typ) -> + [], (F_id id, CT_ref (ctyp_of_typ ctx (lvar_typ typ))), [] + | AV_lit (L_aux (L_string str, _), typ) -> [], (F_lit (V_string (String.escaped str)), ctyp_of_typ ctx typ), [] @@ -1435,9 +1448,6 @@ let rec compile_aval ctx = function (F_id gs, CT_list ctyp), [iclear (CT_list ctyp) gs] - | AV_ref _ -> - c_error "Have AV_ref" - let compile_funcall ctx id args typ = let setup = ref [] in let cleanup = ref [] in @@ -1466,7 +1476,7 @@ let compile_funcall ctx id args typ = (F_id gs, ctyp) else c_error ~loc:(id_loc id) - (Printf.sprintf "Failure when setting up function arguments: %s and %s." (string_of_ctyp have_ctyp) (string_of_ctyp ctyp)) + (Printf.sprintf "Failure when setting up function %s arguments: %s and %s." (string_of_id id) (string_of_ctyp have_ctyp) (string_of_ctyp ctyp)) in let sargs = List.map2 setup_arg arg_ctyps args in @@ -1725,11 +1735,18 @@ let rec compile_aexp ctx = function (* assign_ctyp is the type of the C variable we are assigning to, ctyp is the type of the C expression being assigned. These may be different. *) + let pointer_assign ctyp1 ctyp2 = + match ctyp1 with + | CT_ref ctyp1 -> ctyp_equal ctyp1 ctyp2 + | _ -> false + in let assign_ctyp = ctyp_of_typ ctx assign_typ in let setup, ctyp, call, cleanup = compile_aexp ctx aexp in let comment = "assign " ^ string_of_ctyp assign_ctyp ^ " := " ^ string_of_ctyp ctyp in if ctyp_equal assign_ctyp ctyp then setup @ [call (CL_id id)], CT_unit, (fun clexp -> icopy clexp unit_fragment), cleanup + else if pointer_assign assign_ctyp ctyp then + setup @ [call (CL_addr id)], CT_unit, (fun clexp -> icopy clexp unit_fragment), cleanup else if not (is_stack_ctyp assign_ctyp) && is_stack_ctyp ctyp then let gs = gensym () in setup @ [ icomment comment; @@ -1850,7 +1867,8 @@ let rec pat_ids (Typ_aux (arg_typ_aux, _) as arg_typ) (P_aux (p_aux, (l, _)) as | P_wild, _ -> let gs = gensym () in [gs] | P_var (pat, _), _ -> pat_ids arg_typ pat | P_typ (_, pat), _ -> pat_ids arg_typ pat - | _, _ -> c_error ~loc:l ("Cannot compile pattern " ^ string_of_pat pat ^ " to C") + | P_app _, _ -> let gs = gensym () in [gs] + | _, _ -> c_error ~loc:l ("Cannot compile pattern " ^ string_of_pat pat ^ " : " ^ string_of_typ arg_typ ^ " to C") (** Compile a sail type definition into a IR one. Most of the actual work of translating the typedefs into C is done by the code @@ -2513,7 +2531,7 @@ let codegen_id id = string (sgen_id id) let upper_sgen_id id = Util.zencode_string (string_of_id id) let upper_codegen_id id = string (upper_sgen_id id) -let sgen_ctyp = function +let rec sgen_ctyp = function | CT_unit -> "unit" | CT_bit -> "int" | CT_bool -> "bool" @@ -2529,8 +2547,9 @@ let sgen_ctyp = function | CT_vector _ as v -> Util.zencode_string (string_of_ctyp v) | CT_string -> "sail_string" | CT_real -> "real" + | CT_ref ctyp -> sgen_ctyp ctyp ^ "*" -let sgen_ctyp_name = function +let rec sgen_ctyp_name = function | CT_unit -> "unit" | CT_bit -> "int" | CT_bool -> "bool" @@ -2546,6 +2565,7 @@ let sgen_ctyp_name = function | CT_vector _ as v -> Util.zencode_string (string_of_ctyp v) | CT_string -> "sail_string" | CT_real -> "real" + | CT_ref ctyp -> "ref_" ^ sgen_ctyp_name ctyp let sgen_cval_param (frag, ctyp) = match ctyp with diff --git a/src/pretty_print_sail.ml b/src/pretty_print_sail.ml index 29284262..c2acda75 100644 --- a/src/pretty_print_sail.ml +++ b/src/pretty_print_sail.ml @@ -75,7 +75,7 @@ let rec doc_typ_pat (TP_aux (tpat_aux, _)) = | TP_wild -> string "_" | TP_var kid -> doc_kid kid | TP_app (f, tpats) -> doc_id f ^^ parens (separate_map (comma ^^ space) doc_typ_pat tpats) - + let rec doc_nexp = let rec atomic_nexp (Nexp_aux (n_aux, _) as nexp) = match n_aux with @@ -228,7 +228,7 @@ let rec doc_pat (P_aux (p_aux, _) as pat) = | P_vector pats -> brackets (separate_map (comma ^^ space) doc_pat pats) | P_vector_concat pats -> separate_map (space ^^ string "@" ^^ space) doc_pat pats | P_wild -> string "_" - | P_as (pat, id) -> separate space [doc_pat pat; string "as"; doc_id id] + | P_as (pat, id) -> parens (separate space [doc_pat pat; string "as"; doc_id id]) | P_app (id, pats) -> doc_id id ^^ parens (separate_map (comma ^^ space) doc_pat pats) | P_list pats -> string "[|" ^^ separate_map (comma ^^ space) doc_pat pats ^^ string "|]" | _ -> string (string_of_pat pat) @@ -486,7 +486,7 @@ let doc_spec (VS_aux(v,_)) = let doc_extern ext = let doc_backend b = Util.option_map (fun id -> string (b ^ ":") ^^ space ^^ utf8string ("\"" ^ String.escaped id ^ "\"")) (ext b) in - let docs = Util.option_these (List.map doc_backend ["ocaml"; "lem"; "smt"; "interpreter"]) in + let docs = Util.option_these (List.map doc_backend ["ocaml"; "lem"; "smt"; "interpreter"; "c"]) in if docs = [] then empty else equals ^^ space ^^ braces (separate (comma ^^ space) docs) in match v with diff --git a/src/sail.ml b/src/sail.ml index 9a5acf1c..2e13bc01 100644 --- a/src/sail.ml +++ b/src/sail.ml @@ -271,6 +271,7 @@ let main() = then let ast_c = rewrite_ast_c ast in let ast_c, type_envs = Specialize.specialize ast_c type_envs in + let ast_c = Spec_analysis.top_sort_defs ast_c in C_backend.compile_ast (C_backend.initial_ctx type_envs) ast_c else ()); (if !(opt_print_lem) diff --git a/src/specialize.ml b/src/specialize.ml index 4f8a7e7e..0a4e0fbb 100644 --- a/src/specialize.ml +++ b/src/specialize.ml @@ -309,8 +309,8 @@ let specialize_id_overloads instantiations id (Defs defs) = therefore remove all unused valspecs. Remaining polymorphic valspecs are then re-specialized. This process is iterated until the whole spec is specialized. *) -let remove_unused_valspecs ast = - let calls = ref (IdSet.of_list [mk_id "main"; mk_id "execute"; mk_id "decode"; mk_id "initialize_registers"]) in +let remove_unused_valspecs env ast = + let calls = ref (IdSet.of_list [mk_id "main"; mk_id "execute"; mk_id "decode"; mk_id "initialize_registers"; mk_id "append_64"]) in let vs_ids = Initial_check.val_spec_ids ast in let inspect_exp = function @@ -328,7 +328,9 @@ let remove_unused_valspecs ast = let rec remove_unused (Defs defs) id = match defs with | def :: defs when is_fundef id def -> remove_unused (Defs defs) id - | def :: defs when is_valspec id def -> remove_unused (Defs defs) id + | def :: defs when is_valspec id def -> + prerr_endline ("Removing: " ^ string_of_id id); + remove_unused (Defs defs) id | DEF_overload (overload_id, overloads) :: defs -> begin match List.filter (fun id' -> Id.compare id id' <> 0) overloads with @@ -372,13 +374,93 @@ let specialize_ids ids ast = let ast, _ = Type_check.check Type_check.initial_env ast in let ast = List.fold_left (fun ast id -> rewrite_polymorphic_calls id ast) ast (IdSet.elements ids) in let ast, env = Type_check.check Type_check.initial_env ast in - let ast = remove_unused_valspecs ast in + let ast = remove_unused_valspecs env ast in ast, env +(***** Specialising polymorphic variant types, e.g. option *****) + +let rec variant_generic_typ id (Defs defs) = + match defs with + | DEF_type (TD_aux (TD_variant (id', _, typq, _, _), _)) :: _ -> + mk_typ (Typ_app (id', List.map (fun kopt -> mk_typ_arg (Typ_arg_typ (mk_typ (Typ_var (kopt_kid kopt))))) (quant_kopts typq))) + | _ :: defs -> variant_generic_typ id (Defs defs) + | [] -> failwith ("No variant with id " ^ string_of_id id) + +let rewrite_polymorphic_constructors id ast = + let rewrite_e_aux = function + | E_aux (E_app (id', args), annot) as exp when Id.compare id id' = 0 -> + let instantiation = fix_instantiation (Type_check.instantiation_of exp) in + let spec_id = id_of_instantiation id instantiation in + E_aux (E_app (spec_id, args), annot) + | exp -> exp + in + let rewrite_p_aux = function + | P_aux (P_app (id', args), annot) as pat when Id.compare id id' = 0 -> + begin match Type_check.typ_of_annot annot with + | Typ_aux (Typ_app (variant_id, _), _) as typ -> + let open Type_check in + let instantiation, _, _ = unify (fst annot) (env_of_annot annot) + (variant_generic_typ variant_id ast) + (typ_of_annot annot) + in + (* FIXME: What if instantiation only involves U_nexps? *) + let instantiation = fix_instantiation instantiation in + P_aux (P_app (id_of_instantiation id' instantiation, args), annot) + | Typ_aux (Typ_id variant_id, _) -> pat + | _ -> failwith ("Union constructor " ^ string_of_pat pat ^ " has non-union type") + end + | pat -> pat + in + + let rewrite_pat = { id_pat_alg with p_aux = (fun (pat, annot) -> rewrite_p_aux (P_aux (pat, annot))) } in + let rewrite_exp = { id_exp_alg with pat_alg = rewrite_pat; + e_aux = (fun (exp, annot) -> rewrite_e_aux (E_aux (exp, annot))) } in + rewrite_defs_base { rewriters_base with rewrite_exp = (fun _ -> fold_exp rewrite_exp); + rewrite_pat = (fun _ -> fold_pat rewrite_pat)} ast + +let specialize_variants ((Defs defs) as ast) env = + let ctors = ref [] in + + let specialize_variant (TD_aux (tdef_aux, annot)) ast env = + match tdef_aux with + | TD_variant (id, name_scheme, typq, tus, flag) as variant -> + let kopts = List.filter (fun kopt -> is_typ_kopt kopt || is_order_kopt kopt) (quant_kopts typq) in + if kopts = [] then + (* If non-polymorphic, then do nothing. *) + TD_aux (variant, annot) + else + let specialize_tu (Tu_aux (Tu_ty_id (typ, id), annot)) = + ctors := id :: !ctors; + let is = instantiations_of id ast in + let is = List.sort_uniq (fun i1 i2 -> String.compare (string_of_instantiation i1) (string_of_instantiation i2)) is in + List.map (fun i -> Tu_aux (Tu_ty_id (Type_check.subst_unifiers i typ, id_of_instantiation id i), annot)) is + in + (* + let kopts, constraints = quant_split typq in + let kopts = List.filter (fun kopt -> not (is_typ_kopt kopt || is_order_kopt kopt)) kopts in + let typq = mk_typquant (List.map mk_qi_kopt kopts @ List.map mk_qi_nc constraints) in + *) + TD_aux (TD_variant (id, name_scheme, typq, List.concat (List.map specialize_tu tus), flag), annot) + | _ -> assert false + in + + let rec specialize_variants' = function + | DEF_type (TD_aux (TD_variant _, _) as tdef) :: defs -> + DEF_type (specialize_variant tdef ast env) :: specialize_variants' defs + | def :: defs -> + def :: specialize_variants' defs + | [] -> [] + in + + let ast = Defs (specialize_variants' defs) in + let ast = List.fold_left (fun ast id -> rewrite_polymorphic_constructors id ast) ast !ctors in + Type_check.check Type_check.initial_env ast + let rec specialize ast env = let ids = polymorphic_functions (fun kopt -> is_typ_kopt kopt || is_order_kopt kopt) ast in if IdSet.is_empty ids then - ast, env + specialize_variants ast env else + (prerr_endline (Util.string_of_list ", " string_of_id (IdSet.elements ids)); let ast, env = specialize_ids ids ast in - specialize ast env + specialize ast env) diff --git a/src/specialize.mli b/src/specialize.mli index 0c4a2495..474d3c9d 100644 --- a/src/specialize.mli +++ b/src/specialize.mli @@ -67,3 +67,9 @@ val polymorphic_functions : (kinded_id -> bool) -> 'a defs -> IdSet.t environment to return if there is no polymorphism to remove, in which case specialize returns the AST unmodified. *) val specialize : tannot defs -> Env.t -> tannot defs * Env.t + +val specialize_variants : tannot defs -> Env.t -> tannot defs * Env.t + +val instantiations_of : id -> tannot defs -> uvar KBindings.t list + +val string_of_instantiation : uvar KBindings.t -> string -- cgit v1.2.3