summaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
Diffstat (limited to 'src')
-rw-r--r--src/c_backend.ml161
-rw-r--r--src/parser.mly2
-rw-r--r--src/specialize.ml11
3 files changed, 155 insertions, 19 deletions
diff --git a/src/c_backend.ml b/src/c_backend.ml
index 934bfdff..3aa287f8 100644
--- a/src/c_backend.ml
+++ b/src/c_backend.ml
@@ -89,7 +89,7 @@ let rec string_of_fragment ?zencode:(zencode=true) = function
| F_current_exception -> "(*current_exception)"
and string_of_fragment' ?zencode:(zencode=true) f =
match f with
- | F_op _ -> "(" ^ string_of_fragment ~zencode:zencode f ^ ")"
+ | F_op _ | F_unary _ -> "(" ^ string_of_fragment ~zencode:zencode f ^ ")"
| _ -> string_of_fragment ~zencode:zencode f
(**************************************************************************)
@@ -144,6 +144,8 @@ and apat =
| AP_id of id
| AP_global of id * typ
| AP_app of id * apat
+ | AP_cons of apat * apat
+ | AP_nil
| AP_wild
and aval =
@@ -331,6 +333,8 @@ let rec anf_pat ?global:(global=false) (P_aux (p_aux, (l, _)) as pat) =
| P_app (id, pats) -> AP_app (id, AP_tup (List.map (fun pat -> anf_pat ~global:global pat) pats))
| P_typ (_, pat) -> anf_pat ~global:global pat
| P_var (pat, _) -> anf_pat ~global:global pat
+ | P_cons (hd_pat, tl_pat) -> AP_cons (anf_pat ~global:global hd_pat, anf_pat ~global:global tl_pat)
+ | P_list pats -> List.fold_right (fun pat apat -> AP_cons (anf_pat ~global:global pat, apat)) pats AP_nil
| _ -> c_error ~loc:l ("Could not convert pattern to ANF: " ^ string_of_pat pat)
let rec apat_globals = function
@@ -368,7 +372,8 @@ let rec anf (E_aux (e_aux, exp_annot) as exp) =
let alast = anf last in
AE_block (aexps, alast, typ_of exp)
- | E_assign (LEXP_aux (LEXP_id id, _), exp) ->
+ | E_assign (LEXP_aux (LEXP_id id, _), exp)
+ | E_assign (LEXP_aux (LEXP_cast (_, id), _), exp) ->
let aexp = anf exp in
AE_assign (id, lvar_typ (Env.lookup_id id (env_of exp)), aexp)
@@ -571,9 +576,12 @@ let rec ctyp_equal ctyp1 ctyp2 =
| CT_struct (id1, _), CT_struct (id2, _) -> Id.compare id1 id2 = 0
| CT_enum (id1, _), CT_enum (id2, _) -> Id.compare id1 id2 = 0
| CT_variant (id1, _), CT_variant (id2, _) -> Id.compare id1 id2 = 0
- | CT_tup ctyps1, CT_tup ctyps2 -> List.for_all2 ctyp_equal ctyps1 ctyps2
+ | CT_tup ctyps1, CT_tup ctyps2 when List.length ctyps1 = List.length ctyps2 ->
+ List.for_all2 ctyp_equal ctyps1 ctyps2
| CT_string, CT_string -> true
| CT_real, CT_real -> true
+ | CT_vector (d1, ctyp1), CT_vector (d2, ctyp2) -> d1 = d2 && ctyp_equal ctyp1 ctyp2
+ | CT_list ctyp1, CT_list ctyp2 -> ctyp_equal ctyp1 ctyp2
| _, _ -> false
(* String representation of ctyps here is only for debugging and
@@ -592,6 +600,9 @@ let rec string_of_ctyp = function
| CT_tup ctyps -> "(" ^ Util.string_of_list ", " string_of_ctyp ctyps ^ ")"
| CT_struct (id, _) | CT_enum (id, _) | CT_variant (id, _) -> string_of_id id
| CT_string -> "string"
+ | CT_vector (true, ctyp) -> "vector<dec, " ^ string_of_ctyp ctyp ^ ">"
+ | CT_vector (false, ctyp) -> "vector<inc, " ^ string_of_ctyp ctyp ^ ">"
+ | CT_list ctyp -> "list<" ^ string_of_ctyp ctyp ^ ">"
(** Convert a sail type into a C-type **)
let rec ctyp_of_typ ctx typ =
@@ -612,6 +623,9 @@ let rec ctyp_of_typ ctx typ =
| _ -> CT_mpz
end
+ | Typ_app (id, [Typ_arg_aux (Typ_arg_typ typ, _)]) when string_of_id id = "list" ->
+ CT_list (ctyp_of_typ ctx typ)
+
| Typ_app (id, [Typ_arg_aux (Typ_arg_nexp n, _);
Typ_arg_aux (Typ_arg_order ord, _);
Typ_arg_aux (Typ_arg_typ (Typ_aux (Typ_id vtyp_id, _)), _)])
@@ -645,7 +659,7 @@ let rec ctyp_of_typ ctx typ =
let rec is_stack_ctyp ctyp = match ctyp with
| CT_uint64 _ | CT_int64 | CT_bit | CT_unit | CT_bool | CT_enum _ -> true
- | CT_bv _ | CT_mpz | CT_real | CT_string -> false
+ | CT_bv _ | CT_mpz | CT_real | CT_string | CT_list _ | CT_vector _ -> false
| CT_struct (_, fields) -> List.for_all (fun (_, ctyp) -> is_stack_ctyp ctyp) fields
| CT_variant (_, ctors) -> List.for_all (fun (_, ctyp) -> is_stack_ctyp ctyp) ctors
| CT_tup ctyps -> List.for_all is_stack_ctyp ctyps
@@ -894,9 +908,27 @@ let rec instr_ctyps (I_aux (instr, aux)) =
| I_throw cval | I_jump (cval, _) | I_return cval -> [cval_ctyp cval]
| I_comment _ | I_label _ | I_goto _ | I_raw _ | I_match_failure -> []
-let cdef_ctyps = function
+let cdef_ctyps ctx = function
| CDEF_reg_dec (_, ctyp) -> [ctyp]
- | CDEF_fundef (_, _, _, instrs) -> List.concat (List.map instr_ctyps instrs)
+ | CDEF_fundef (id, _, _, instrs) ->
+ (* TODO: Move this code to DEF_fundef -> CDEF_fundef translation, and modify bytecode.ott *)
+ let _, Typ_aux (fn_typ, _) =
+ try Env.get_val_spec id ctx.tc_env with
+ | Type_error _ ->
+ (* If we can't find the function type, then it must be a nullary union constructor. *)
+ begin match Env.lookup_id id ctx.tc_env with
+ | Union (typq, typ) -> typq, function_typ unit_typ typ no_effect
+ | _ -> failwith ("Got function identifier " ^ string_of_id id ^ " which is neither a function nor a constructor.")
+ end
+ in
+ let arg_typs, ret_typ = match fn_typ with
+ | Typ_fn (Typ_aux (Typ_tup arg_typs, _), ret_typ, _) -> arg_typs, ret_typ
+ | Typ_fn (arg_typ, ret_typ, _) -> [arg_typ], ret_typ
+ | _ -> assert false
+ in
+ let arg_ctyps, ret_ctyp = List.map (ctyp_of_typ ctx) arg_typs, ctyp_of_typ ctx ret_typ in
+ ret_ctyp :: arg_ctyps @ List.concat (List.map instr_ctyps instrs)
+
| CDEF_type tdef -> ctype_def_ctyps tdef
| CDEF_let (_, bindings, instrs, cleanup) ->
List.map snd bindings
@@ -1016,6 +1048,10 @@ let is_ct_tup = function
| CT_tup _ -> true
| _ -> false
+let is_ct_list = function
+ | CT_list _ -> true
+ | _ -> false
+
let rec is_bitvector = function
| [] -> true
| AV_lit (L_aux (L_zero, _), _) :: avals -> is_bitvector avals
@@ -1113,6 +1149,25 @@ let rec compile_aval ctx = function
| AV_vector _ ->
c_error "Have AV_vector"
+ | AV_list (avals, Typ_aux (typ, _)) ->
+ let ctyp = match typ with
+ | Typ_app (id, [Typ_arg_aux (Typ_arg_typ typ, _)]) when string_of_id id = "list" -> ctyp_of_typ ctx typ
+ | _ -> c_error "Invalid list type"
+ in
+ let gs = gensym () in
+ let mk_cons aval =
+ let setup, cval, cleanup = compile_aval ctx aval in
+ setup @ [ifuncall (CL_id gs) (mk_id ("cons#" ^ string_of_ctyp ctyp)) [cval; (F_id gs, CT_list ctyp)] (CT_list ctyp)] @ cleanup
+ in
+ [idecl (CT_list ctyp) gs;
+ ialloc (CT_list ctyp) gs]
+ @ List.concat (List.map mk_cons (List.rev avals)),
+ (F_id gs, CT_list ctyp),
+ [iclear (CT_list ctyp) gs]
+
+ | AV_ref _ ->
+ c_error "Have AV_ref"
+
let compile_funcall ctx id args typ =
let setup = ref [] in
let cleanup = ref [] in
@@ -1207,6 +1262,14 @@ let rec compile_match ctx apat cval case_label =
end
| AP_wild, _ -> [], []
+ | AP_cons (hd_apat, tl_apat), (frag, CT_list ctyp) ->
+ let hd_setup, hd_cleanup = compile_match ctx hd_apat (F_field (F_unary ("*", frag), "hd"), ctyp) case_label in
+ let tl_setup, tl_cleanup = compile_match ctx tl_apat (F_field (F_unary ("*", frag), "tl"), CT_list ctyp) case_label in
+ [ijump (F_op (frag, "==", F_lit "NULL"), CT_bool) case_label] @ hd_setup @ tl_setup, tl_cleanup @ hd_cleanup
+ | AP_cons _, (_, _) -> c_error "Tried to pattern match cons on non list type"
+
+ | AP_nil, (frag, _) -> [ijump (F_op (frag, "!=", F_lit "NULL"), CT_bool) case_label], []
+
let unit_fragment = (F_lit "UNIT", CT_unit)
(** GLOBAL: label_counter is used to make sure all labels have unique
@@ -1269,13 +1332,14 @@ let rec compile_aexp ctx = function
[iblock case_instrs; ilabel case_label]
in
[icomment "begin match"]
- @ aval_setup @ [idecl ctyp case_return_id]
+ @ aval_setup @ [idecl ctyp case_return_id] @ (if is_stack_ctyp ctyp then [] else [ialloc ctyp case_return_id])
@ List.concat (List.map compile_case cases)
@ [imatch_failure ()]
@ [ilabel finish_match_label],
ctyp,
(fun clexp -> icopy clexp (F_id case_return_id, ctyp)),
- aval_cleanup
+ (if is_stack_ctyp ctyp then [] else [iclear ctyp case_return_id])
+ @ aval_cleanup
@ [icomment "end match"]
(* Compile try statement *)
@@ -1904,6 +1968,7 @@ let sgen_ctyp = function
| CT_struct (id, _) -> "struct " ^ sgen_id id
| CT_enum (id, _) -> "enum " ^ sgen_id id
| CT_variant (id, _) -> "struct " ^ sgen_id id
+ | CT_list _ as l -> Util.zencode_string (string_of_ctyp l)
| CT_string -> "sail_string"
let sgen_ctyp_name = function
@@ -1918,6 +1983,7 @@ let sgen_ctyp_name = function
| CT_struct (id, _) -> sgen_id id
| CT_enum (id, _) -> sgen_id id
| CT_variant (id, _) -> sgen_id id
+ | CT_list _ as l -> Util.zencode_string (string_of_ctyp l)
| CT_string -> "sail_string"
let sgen_cval_param (frag, ctyp) =
@@ -2205,19 +2271,73 @@ let codegen_type_def ctx = function
This variable should be reset to empty only when the entire AST has
been translated to C. **)
-let generated_tuples = ref IdSet.empty
+let generated = ref IdSet.empty
let codegen_tup ctx ctyps =
let id = mk_id ("tuple_" ^ string_of_ctyp (CT_tup ctyps)) in
- if IdSet.mem id !generated_tuples then
+ if IdSet.mem id !generated then
empty
else
- let _, fields = List.fold_left (fun (n, fields) ctyp -> n + 1, Bindings.add (mk_id ("tup" ^ string_of_int n)) ctyp fields)
- (0, Bindings.empty)
- ctyps
- in
- generated_tuples := IdSet.add id !generated_tuples;
- codegen_type_def ctx (CTD_struct (id, Bindings.bindings fields)) ^^ twice hardline
+ begin
+ let _, fields = List.fold_left (fun (n, fields) ctyp -> n + 1, Bindings.add (mk_id ("tup" ^ string_of_int n)) ctyp fields)
+ (0, Bindings.empty)
+ ctyps
+ in
+ generated := IdSet.add id !generated;
+ codegen_type_def ctx (CTD_struct (id, Bindings.bindings fields)) ^^ twice hardline
+ end
+
+let codegen_node id ctyp =
+ string (Printf.sprintf "struct node_%s {\n %s hd;\n struct node_%s *tl;\n};\n" (sgen_id id) (sgen_ctyp ctyp) (sgen_id id))
+ ^^ string (Printf.sprintf "typedef struct node_%s *%s;" (sgen_id id) (sgen_id id))
+
+let codegen_list_init id =
+ string (Printf.sprintf "void init_%s(%s *rop) { *rop = NULL; }" (sgen_id id) (sgen_id id))
+
+let codegen_list_clear id ctyp =
+ string (Printf.sprintf "void clear_%s(%s *rop) {\n" (sgen_id id) (sgen_id id))
+ ^^ string (Printf.sprintf " if (*rop == NULL) return;")
+ ^^ string (Printf.sprintf " clear_%s(&(*rop)->hd);\n" (sgen_ctyp_name ctyp))
+ ^^ string (Printf.sprintf " clear_%s(&(*rop)->tl);\n" (sgen_id id))
+ ^^ string " free(*rop);"
+ ^^ string "}"
+
+let codegen_list_set id ctyp =
+ string (Printf.sprintf "void internal_set_%s(%s *rop, const %s op) {\n" (sgen_id id) (sgen_id id) (sgen_id id))
+ ^^ string " if (op == NULL) { *rop = NULL; return; };\n"
+ ^^ string (Printf.sprintf " *rop = malloc(sizeof(struct node_%s));\n" (sgen_id id))
+ ^^ string (Printf.sprintf " init_%s(&(*rop)->hd);\n" (sgen_ctyp_name ctyp))
+ ^^ string (Printf.sprintf " set_%s(&(*rop)->hd, op->hd);\n" (sgen_ctyp_name ctyp))
+ ^^ string (Printf.sprintf " internal_set_%s(&(*rop)->tl, op->tl);\n" (sgen_id id))
+ ^^ string "}"
+ ^^ twice hardline
+ ^^ string (Printf.sprintf "void set_%s(%s *rop, const %s op) {\n" (sgen_id id) (sgen_id id) (sgen_id id))
+ ^^ string (Printf.sprintf " clear_%s(rop);\n" (sgen_id id))
+ ^^ string (Printf.sprintf " internal_set_%s(rop, op);\n" (sgen_id id))
+ ^^ string "}"
+
+let codegen_cons id ctyp =
+ let cons_id = mk_id ("cons#" ^ string_of_ctyp ctyp) in
+ string (Printf.sprintf "void %s(%s *rop, const %s x, const %s xs) {\n" (sgen_id cons_id) (sgen_id id) (sgen_ctyp ctyp) (sgen_id id))
+ ^^ string (Printf.sprintf " *rop = malloc(sizeof(struct node_%s));\n" (sgen_id id))
+ ^^ string (Printf.sprintf " init_%s(&(*rop)->hd);\n" (sgen_ctyp_name ctyp))
+ ^^ string (Printf.sprintf " set_%s(&(*rop)->hd, x);\n" (sgen_ctyp_name ctyp))
+ ^^ string " (*rop)->tl = xs;\n"
+ ^^ string "}"
+
+let codegen_list ctx ctyp =
+ let id = mk_id (string_of_ctyp (CT_list ctyp)) in
+ if IdSet.mem id !generated then
+ empty
+ else
+ begin
+ generated := IdSet.add id !generated;
+ codegen_node id ctyp ^^ twice hardline
+ ^^ codegen_list_init id ^^ twice hardline
+ ^^ codegen_list_clear id ctyp ^^ twice hardline
+ ^^ codegen_list_set id ctyp ^^ twice hardline
+ ^^ codegen_cons id ctyp ^^ twice hardline
+ end
let codegen_def' ctx = function
| CDEF_reg_dec (id, ctyp) ->
@@ -2271,10 +2391,17 @@ let codegen_def ctx def =
| CT_tup ctyps -> ctyps
| _ -> assert false
in
- let tups = List.filter is_ct_tup (cdef_ctyps def) in
+ let unlist = function
+ | CT_list ctyp -> ctyp
+ | _ -> assert false
+ in
+ let tups = List.filter is_ct_tup (cdef_ctyps ctx def) in
let tups = List.map (fun ctyp -> codegen_tup ctx (untup ctyp)) tups in
+ let lists = List.filter is_ct_list (cdef_ctyps ctx def) in
+ let lists = List.map (fun ctyp -> codegen_list ctx (unlist ctyp)) lists in
prerr_endline (Pretty_print_sail.to_string (pp_cdef def));
concat tups
+ ^^ concat lists
^^ codegen_def' ctx def
let compile_ast ctx (Defs defs) =
diff --git a/src/parser.mly b/src/parser.mly
index b781ea1f..c8cc49a3 100644
--- a/src/parser.mly
+++ b/src/parser.mly
@@ -658,6 +658,8 @@ pat:
{ $1 }
| pat1 As typ
{ mk_pat (P_var ($1, $3)) $startpos $endpos }
+ | pat1 Match typ
+ { mk_pat (P_var ($1, $3)) $startpos $endpos }
pat_list:
| pat
diff --git a/src/specialize.ml b/src/specialize.ml
index 9344e661..efa8783e 100644
--- a/src/specialize.ml
+++ b/src/specialize.ml
@@ -176,9 +176,16 @@ let specialize_id_fundef instantiations id ast =
match split_defs (is_fundef id) ast with
| None -> ast
| Some (pre_ast, DEF_fundef fundef, post_ast) ->
- let fundefs =
- List.map (fun i -> DEF_fundef (rename_fundef (id_of_instantiation id i) fundef)) instantiations
+ let spec_ids = ref IdSet.empty in
+ let specialize_fundef instantiation =
+ let spec_id = id_of_instantiation id instantiation in
+ if IdSet.mem spec_id !spec_ids then [] else
+ begin
+ spec_ids := IdSet.add spec_id !spec_ids;
+ [DEF_fundef (rename_fundef spec_id fundef)]
+ end
in
+ let fundefs = List.map specialize_fundef instantiations |> List.concat in
append_ast pre_ast (append_ast (Defs fundefs) post_ast)
| Some _ -> assert false (* unreachable *)