diff options
Diffstat (limited to 'src')
| -rw-r--r-- | src/c_backend.ml | 161 | ||||
| -rw-r--r-- | src/parser.mly | 2 | ||||
| -rw-r--r-- | src/specialize.ml | 11 |
3 files changed, 155 insertions, 19 deletions
diff --git a/src/c_backend.ml b/src/c_backend.ml index 934bfdff..3aa287f8 100644 --- a/src/c_backend.ml +++ b/src/c_backend.ml @@ -89,7 +89,7 @@ let rec string_of_fragment ?zencode:(zencode=true) = function | F_current_exception -> "(*current_exception)" and string_of_fragment' ?zencode:(zencode=true) f = match f with - | F_op _ -> "(" ^ string_of_fragment ~zencode:zencode f ^ ")" + | F_op _ | F_unary _ -> "(" ^ string_of_fragment ~zencode:zencode f ^ ")" | _ -> string_of_fragment ~zencode:zencode f (**************************************************************************) @@ -144,6 +144,8 @@ and apat = | AP_id of id | AP_global of id * typ | AP_app of id * apat + | AP_cons of apat * apat + | AP_nil | AP_wild and aval = @@ -331,6 +333,8 @@ let rec anf_pat ?global:(global=false) (P_aux (p_aux, (l, _)) as pat) = | P_app (id, pats) -> AP_app (id, AP_tup (List.map (fun pat -> anf_pat ~global:global pat) pats)) | P_typ (_, pat) -> anf_pat ~global:global pat | P_var (pat, _) -> anf_pat ~global:global pat + | P_cons (hd_pat, tl_pat) -> AP_cons (anf_pat ~global:global hd_pat, anf_pat ~global:global tl_pat) + | P_list pats -> List.fold_right (fun pat apat -> AP_cons (anf_pat ~global:global pat, apat)) pats AP_nil | _ -> c_error ~loc:l ("Could not convert pattern to ANF: " ^ string_of_pat pat) let rec apat_globals = function @@ -368,7 +372,8 @@ let rec anf (E_aux (e_aux, exp_annot) as exp) = let alast = anf last in AE_block (aexps, alast, typ_of exp) - | E_assign (LEXP_aux (LEXP_id id, _), exp) -> + | E_assign (LEXP_aux (LEXP_id id, _), exp) + | E_assign (LEXP_aux (LEXP_cast (_, id), _), exp) -> let aexp = anf exp in AE_assign (id, lvar_typ (Env.lookup_id id (env_of exp)), aexp) @@ -571,9 +576,12 @@ let rec ctyp_equal ctyp1 ctyp2 = | CT_struct (id1, _), CT_struct (id2, _) -> Id.compare id1 id2 = 0 | CT_enum (id1, _), CT_enum (id2, _) -> Id.compare id1 id2 = 0 | CT_variant (id1, _), CT_variant (id2, _) -> Id.compare id1 id2 = 0 - | CT_tup ctyps1, CT_tup ctyps2 -> List.for_all2 ctyp_equal ctyps1 ctyps2 + | CT_tup ctyps1, CT_tup ctyps2 when List.length ctyps1 = List.length ctyps2 -> + List.for_all2 ctyp_equal ctyps1 ctyps2 | CT_string, CT_string -> true | CT_real, CT_real -> true + | CT_vector (d1, ctyp1), CT_vector (d2, ctyp2) -> d1 = d2 && ctyp_equal ctyp1 ctyp2 + | CT_list ctyp1, CT_list ctyp2 -> ctyp_equal ctyp1 ctyp2 | _, _ -> false (* String representation of ctyps here is only for debugging and @@ -592,6 +600,9 @@ let rec string_of_ctyp = function | CT_tup ctyps -> "(" ^ Util.string_of_list ", " string_of_ctyp ctyps ^ ")" | CT_struct (id, _) | CT_enum (id, _) | CT_variant (id, _) -> string_of_id id | CT_string -> "string" + | CT_vector (true, ctyp) -> "vector<dec, " ^ string_of_ctyp ctyp ^ ">" + | CT_vector (false, ctyp) -> "vector<inc, " ^ string_of_ctyp ctyp ^ ">" + | CT_list ctyp -> "list<" ^ string_of_ctyp ctyp ^ ">" (** Convert a sail type into a C-type **) let rec ctyp_of_typ ctx typ = @@ -612,6 +623,9 @@ let rec ctyp_of_typ ctx typ = | _ -> CT_mpz end + | Typ_app (id, [Typ_arg_aux (Typ_arg_typ typ, _)]) when string_of_id id = "list" -> + CT_list (ctyp_of_typ ctx typ) + | Typ_app (id, [Typ_arg_aux (Typ_arg_nexp n, _); Typ_arg_aux (Typ_arg_order ord, _); Typ_arg_aux (Typ_arg_typ (Typ_aux (Typ_id vtyp_id, _)), _)]) @@ -645,7 +659,7 @@ let rec ctyp_of_typ ctx typ = let rec is_stack_ctyp ctyp = match ctyp with | CT_uint64 _ | CT_int64 | CT_bit | CT_unit | CT_bool | CT_enum _ -> true - | CT_bv _ | CT_mpz | CT_real | CT_string -> false + | CT_bv _ | CT_mpz | CT_real | CT_string | CT_list _ | CT_vector _ -> false | CT_struct (_, fields) -> List.for_all (fun (_, ctyp) -> is_stack_ctyp ctyp) fields | CT_variant (_, ctors) -> List.for_all (fun (_, ctyp) -> is_stack_ctyp ctyp) ctors | CT_tup ctyps -> List.for_all is_stack_ctyp ctyps @@ -894,9 +908,27 @@ let rec instr_ctyps (I_aux (instr, aux)) = | I_throw cval | I_jump (cval, _) | I_return cval -> [cval_ctyp cval] | I_comment _ | I_label _ | I_goto _ | I_raw _ | I_match_failure -> [] -let cdef_ctyps = function +let cdef_ctyps ctx = function | CDEF_reg_dec (_, ctyp) -> [ctyp] - | CDEF_fundef (_, _, _, instrs) -> List.concat (List.map instr_ctyps instrs) + | CDEF_fundef (id, _, _, instrs) -> + (* TODO: Move this code to DEF_fundef -> CDEF_fundef translation, and modify bytecode.ott *) + let _, Typ_aux (fn_typ, _) = + try Env.get_val_spec id ctx.tc_env with + | Type_error _ -> + (* If we can't find the function type, then it must be a nullary union constructor. *) + begin match Env.lookup_id id ctx.tc_env with + | Union (typq, typ) -> typq, function_typ unit_typ typ no_effect + | _ -> failwith ("Got function identifier " ^ string_of_id id ^ " which is neither a function nor a constructor.") + end + in + let arg_typs, ret_typ = match fn_typ with + | Typ_fn (Typ_aux (Typ_tup arg_typs, _), ret_typ, _) -> arg_typs, ret_typ + | Typ_fn (arg_typ, ret_typ, _) -> [arg_typ], ret_typ + | _ -> assert false + in + let arg_ctyps, ret_ctyp = List.map (ctyp_of_typ ctx) arg_typs, ctyp_of_typ ctx ret_typ in + ret_ctyp :: arg_ctyps @ List.concat (List.map instr_ctyps instrs) + | CDEF_type tdef -> ctype_def_ctyps tdef | CDEF_let (_, bindings, instrs, cleanup) -> List.map snd bindings @@ -1016,6 +1048,10 @@ let is_ct_tup = function | CT_tup _ -> true | _ -> false +let is_ct_list = function + | CT_list _ -> true + | _ -> false + let rec is_bitvector = function | [] -> true | AV_lit (L_aux (L_zero, _), _) :: avals -> is_bitvector avals @@ -1113,6 +1149,25 @@ let rec compile_aval ctx = function | AV_vector _ -> c_error "Have AV_vector" + | AV_list (avals, Typ_aux (typ, _)) -> + let ctyp = match typ with + | Typ_app (id, [Typ_arg_aux (Typ_arg_typ typ, _)]) when string_of_id id = "list" -> ctyp_of_typ ctx typ + | _ -> c_error "Invalid list type" + in + let gs = gensym () in + let mk_cons aval = + let setup, cval, cleanup = compile_aval ctx aval in + setup @ [ifuncall (CL_id gs) (mk_id ("cons#" ^ string_of_ctyp ctyp)) [cval; (F_id gs, CT_list ctyp)] (CT_list ctyp)] @ cleanup + in + [idecl (CT_list ctyp) gs; + ialloc (CT_list ctyp) gs] + @ List.concat (List.map mk_cons (List.rev avals)), + (F_id gs, CT_list ctyp), + [iclear (CT_list ctyp) gs] + + | AV_ref _ -> + c_error "Have AV_ref" + let compile_funcall ctx id args typ = let setup = ref [] in let cleanup = ref [] in @@ -1207,6 +1262,14 @@ let rec compile_match ctx apat cval case_label = end | AP_wild, _ -> [], [] + | AP_cons (hd_apat, tl_apat), (frag, CT_list ctyp) -> + let hd_setup, hd_cleanup = compile_match ctx hd_apat (F_field (F_unary ("*", frag), "hd"), ctyp) case_label in + let tl_setup, tl_cleanup = compile_match ctx tl_apat (F_field (F_unary ("*", frag), "tl"), CT_list ctyp) case_label in + [ijump (F_op (frag, "==", F_lit "NULL"), CT_bool) case_label] @ hd_setup @ tl_setup, tl_cleanup @ hd_cleanup + | AP_cons _, (_, _) -> c_error "Tried to pattern match cons on non list type" + + | AP_nil, (frag, _) -> [ijump (F_op (frag, "!=", F_lit "NULL"), CT_bool) case_label], [] + let unit_fragment = (F_lit "UNIT", CT_unit) (** GLOBAL: label_counter is used to make sure all labels have unique @@ -1269,13 +1332,14 @@ let rec compile_aexp ctx = function [iblock case_instrs; ilabel case_label] in [icomment "begin match"] - @ aval_setup @ [idecl ctyp case_return_id] + @ aval_setup @ [idecl ctyp case_return_id] @ (if is_stack_ctyp ctyp then [] else [ialloc ctyp case_return_id]) @ List.concat (List.map compile_case cases) @ [imatch_failure ()] @ [ilabel finish_match_label], ctyp, (fun clexp -> icopy clexp (F_id case_return_id, ctyp)), - aval_cleanup + (if is_stack_ctyp ctyp then [] else [iclear ctyp case_return_id]) + @ aval_cleanup @ [icomment "end match"] (* Compile try statement *) @@ -1904,6 +1968,7 @@ let sgen_ctyp = function | CT_struct (id, _) -> "struct " ^ sgen_id id | CT_enum (id, _) -> "enum " ^ sgen_id id | CT_variant (id, _) -> "struct " ^ sgen_id id + | CT_list _ as l -> Util.zencode_string (string_of_ctyp l) | CT_string -> "sail_string" let sgen_ctyp_name = function @@ -1918,6 +1983,7 @@ let sgen_ctyp_name = function | CT_struct (id, _) -> sgen_id id | CT_enum (id, _) -> sgen_id id | CT_variant (id, _) -> sgen_id id + | CT_list _ as l -> Util.zencode_string (string_of_ctyp l) | CT_string -> "sail_string" let sgen_cval_param (frag, ctyp) = @@ -2205,19 +2271,73 @@ let codegen_type_def ctx = function This variable should be reset to empty only when the entire AST has been translated to C. **) -let generated_tuples = ref IdSet.empty +let generated = ref IdSet.empty let codegen_tup ctx ctyps = let id = mk_id ("tuple_" ^ string_of_ctyp (CT_tup ctyps)) in - if IdSet.mem id !generated_tuples then + if IdSet.mem id !generated then empty else - let _, fields = List.fold_left (fun (n, fields) ctyp -> n + 1, Bindings.add (mk_id ("tup" ^ string_of_int n)) ctyp fields) - (0, Bindings.empty) - ctyps - in - generated_tuples := IdSet.add id !generated_tuples; - codegen_type_def ctx (CTD_struct (id, Bindings.bindings fields)) ^^ twice hardline + begin + let _, fields = List.fold_left (fun (n, fields) ctyp -> n + 1, Bindings.add (mk_id ("tup" ^ string_of_int n)) ctyp fields) + (0, Bindings.empty) + ctyps + in + generated := IdSet.add id !generated; + codegen_type_def ctx (CTD_struct (id, Bindings.bindings fields)) ^^ twice hardline + end + +let codegen_node id ctyp = + string (Printf.sprintf "struct node_%s {\n %s hd;\n struct node_%s *tl;\n};\n" (sgen_id id) (sgen_ctyp ctyp) (sgen_id id)) + ^^ string (Printf.sprintf "typedef struct node_%s *%s;" (sgen_id id) (sgen_id id)) + +let codegen_list_init id = + string (Printf.sprintf "void init_%s(%s *rop) { *rop = NULL; }" (sgen_id id) (sgen_id id)) + +let codegen_list_clear id ctyp = + string (Printf.sprintf "void clear_%s(%s *rop) {\n" (sgen_id id) (sgen_id id)) + ^^ string (Printf.sprintf " if (*rop == NULL) return;") + ^^ string (Printf.sprintf " clear_%s(&(*rop)->hd);\n" (sgen_ctyp_name ctyp)) + ^^ string (Printf.sprintf " clear_%s(&(*rop)->tl);\n" (sgen_id id)) + ^^ string " free(*rop);" + ^^ string "}" + +let codegen_list_set id ctyp = + string (Printf.sprintf "void internal_set_%s(%s *rop, const %s op) {\n" (sgen_id id) (sgen_id id) (sgen_id id)) + ^^ string " if (op == NULL) { *rop = NULL; return; };\n" + ^^ string (Printf.sprintf " *rop = malloc(sizeof(struct node_%s));\n" (sgen_id id)) + ^^ string (Printf.sprintf " init_%s(&(*rop)->hd);\n" (sgen_ctyp_name ctyp)) + ^^ string (Printf.sprintf " set_%s(&(*rop)->hd, op->hd);\n" (sgen_ctyp_name ctyp)) + ^^ string (Printf.sprintf " internal_set_%s(&(*rop)->tl, op->tl);\n" (sgen_id id)) + ^^ string "}" + ^^ twice hardline + ^^ string (Printf.sprintf "void set_%s(%s *rop, const %s op) {\n" (sgen_id id) (sgen_id id) (sgen_id id)) + ^^ string (Printf.sprintf " clear_%s(rop);\n" (sgen_id id)) + ^^ string (Printf.sprintf " internal_set_%s(rop, op);\n" (sgen_id id)) + ^^ string "}" + +let codegen_cons id ctyp = + let cons_id = mk_id ("cons#" ^ string_of_ctyp ctyp) in + string (Printf.sprintf "void %s(%s *rop, const %s x, const %s xs) {\n" (sgen_id cons_id) (sgen_id id) (sgen_ctyp ctyp) (sgen_id id)) + ^^ string (Printf.sprintf " *rop = malloc(sizeof(struct node_%s));\n" (sgen_id id)) + ^^ string (Printf.sprintf " init_%s(&(*rop)->hd);\n" (sgen_ctyp_name ctyp)) + ^^ string (Printf.sprintf " set_%s(&(*rop)->hd, x);\n" (sgen_ctyp_name ctyp)) + ^^ string " (*rop)->tl = xs;\n" + ^^ string "}" + +let codegen_list ctx ctyp = + let id = mk_id (string_of_ctyp (CT_list ctyp)) in + if IdSet.mem id !generated then + empty + else + begin + generated := IdSet.add id !generated; + codegen_node id ctyp ^^ twice hardline + ^^ codegen_list_init id ^^ twice hardline + ^^ codegen_list_clear id ctyp ^^ twice hardline + ^^ codegen_list_set id ctyp ^^ twice hardline + ^^ codegen_cons id ctyp ^^ twice hardline + end let codegen_def' ctx = function | CDEF_reg_dec (id, ctyp) -> @@ -2271,10 +2391,17 @@ let codegen_def ctx def = | CT_tup ctyps -> ctyps | _ -> assert false in - let tups = List.filter is_ct_tup (cdef_ctyps def) in + let unlist = function + | CT_list ctyp -> ctyp + | _ -> assert false + in + let tups = List.filter is_ct_tup (cdef_ctyps ctx def) in let tups = List.map (fun ctyp -> codegen_tup ctx (untup ctyp)) tups in + let lists = List.filter is_ct_list (cdef_ctyps ctx def) in + let lists = List.map (fun ctyp -> codegen_list ctx (unlist ctyp)) lists in prerr_endline (Pretty_print_sail.to_string (pp_cdef def)); concat tups + ^^ concat lists ^^ codegen_def' ctx def let compile_ast ctx (Defs defs) = diff --git a/src/parser.mly b/src/parser.mly index b781ea1f..c8cc49a3 100644 --- a/src/parser.mly +++ b/src/parser.mly @@ -658,6 +658,8 @@ pat: { $1 } | pat1 As typ { mk_pat (P_var ($1, $3)) $startpos $endpos } + | pat1 Match typ + { mk_pat (P_var ($1, $3)) $startpos $endpos } pat_list: | pat diff --git a/src/specialize.ml b/src/specialize.ml index 9344e661..efa8783e 100644 --- a/src/specialize.ml +++ b/src/specialize.ml @@ -176,9 +176,16 @@ let specialize_id_fundef instantiations id ast = match split_defs (is_fundef id) ast with | None -> ast | Some (pre_ast, DEF_fundef fundef, post_ast) -> - let fundefs = - List.map (fun i -> DEF_fundef (rename_fundef (id_of_instantiation id i) fundef)) instantiations + let spec_ids = ref IdSet.empty in + let specialize_fundef instantiation = + let spec_id = id_of_instantiation id instantiation in + if IdSet.mem spec_id !spec_ids then [] else + begin + spec_ids := IdSet.add spec_id !spec_ids; + [DEF_fundef (rename_fundef spec_id fundef)] + end in + let fundefs = List.map specialize_fundef instantiations |> List.concat in append_ast pre_ast (append_ast (Defs fundefs) post_ast) | Some _ -> assert false (* unreachable *) |
