summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--lib/flow.sail5
-rw-r--r--src/c_backend.ml367
2 files changed, 306 insertions, 66 deletions
diff --git a/lib/flow.sail b/lib/flow.sail
index 2ca0e1a8..10badc93 100644
--- a/lib/flow.sail
+++ b/lib/flow.sail
@@ -5,7 +5,7 @@ val not_bool = "not" : bool -> bool
val and_bool = "and_bool" : (bool, bool) -> bool
val or_bool = "or_bool" : (bool, bool) -> bool
-val eq_atom = {ocaml: "eq_int", lem: "eq"} : forall 'n 'm. (atom('n), atom('m)) -> bool
+val eq_atom = {ocaml: "eq_int", lem: "eq", c: "eq_int"} : forall 'n 'm. (atom('n), atom('m)) -> bool
val neq_atom = {lem: "neq"} : forall 'n 'm. (atom('n), atom('m)) -> bool
@@ -25,9 +25,10 @@ val lteq_atom_range = "lteq" : forall 'n 'm 'o. (atom('n), range('m, 'o)) -> boo
val gt_atom_range = "gt" : forall 'n 'm 'o. (atom('n), range('m, 'o)) -> bool
val gteq_atom_range = "gteq" : forall 'n 'm 'o. (atom('n), range('m, 'o)) -> bool
+val eq_range = {ocaml: "eq_int", lem: "eq"} : forall 'n 'm 'o 'p. (range('n, 'm), range('o, 'p)) -> bool
val eq_int = {ocaml: "eq_int", lem: "eq"} : (int, int) -> bool
-overload operator == = {eq_atom, eq_int}
+overload operator == = {eq_atom, eq_range, eq_int}
$ifdef TEST
diff --git a/src/c_backend.ml b/src/c_backend.ml
index 2ebd28c8..25c1669b 100644
--- a/src/c_backend.ml
+++ b/src/c_backend.ml
@@ -61,6 +61,7 @@ let zencode_id = function
let lvar_typ = function
| Local (_, typ) -> typ
| Register typ -> typ
+ | Enum typ -> typ
| _ -> assert false
(**************************************************************************)
@@ -103,10 +104,16 @@ type aexp =
| AE_return of aval * typ
| AE_if of aval * aexp * aexp * typ
| AE_field of aval * id * typ
+ | AE_case of aval * (apat * aexp * aexp) list * typ
| AE_record_update of aval * aval Bindings.t * typ
| AE_for of id * aexp * aexp * aexp * order * aexp
| AE_loop of loop * aexp * aexp
+and apat =
+ | AP_tup of apat list
+ | AP_id of id
+ | AP_wild
+
and aval =
| AV_lit of lit * typ
| AV_id of id * lvar
@@ -133,6 +140,10 @@ let rec map_aval f = function
AE_for (id, map_aval f aexp1, map_aval f aexp2, map_aval f aexp3, order, map_aval f aexp4)
| AE_record_update (aval, updates, typ) ->
AE_record_update (f aval, Bindings.map f updates, typ)
+ | AE_field (aval, id, typ) ->
+ AE_field (f aval, id, typ)
+ | AE_case (aval, cases, typ) ->
+ AE_case (f aval, List.map (fun (pat, aexp1, aexp2) -> pat, map_aval f aexp1, map_aval f aexp2) cases, typ)
(* Map over all the functions in an aexp. *)
let rec map_functions f = function
@@ -146,6 +157,8 @@ let rec map_functions f = function
| AE_loop (loop_typ, aexp1, aexp2) -> AE_loop (loop_typ, map_functions f aexp1, map_functions f aexp2)
| AE_for (id, aexp1, aexp2, aexp3, order, aexp4) ->
AE_for (id, map_functions f aexp1, map_functions f aexp2, map_functions f aexp3, order, map_functions f aexp4)
+ | AE_case (aval, cases, typ) ->
+ AE_case (aval, List.map (fun (pat, aexp1, aexp2) -> pat, map_functions f aexp1, map_functions f aexp2) cases, typ)
| AE_field _ | AE_record_update _ | AE_val _ | AE_return _ as v -> v
(* For debugging we provide a pretty printer for ANF expressions. *)
@@ -217,8 +230,20 @@ let rec pp_aexp = function
in
header ^//^ pp_aexp aexp4
| AE_field _ -> string "FIELD"
+ | AE_case (aval, cases, typ) ->
+ pp_annot typ (separate space [string "match"; pp_aval aval; pp_cases cases])
| AE_record_update (_, _, typ) -> pp_annot typ (string "RECORD UPDATE")
+and pp_apat = function
+ | AP_wild -> string "_"
+ | AP_id id -> pp_id id
+ | AP_tup apats -> parens (separate_map (comma ^^ space) pp_apat apats)
+
+and pp_cases cases = surround 2 0 lbrace (separate_map (comma ^^ hardline) pp_case cases) rbrace
+
+and pp_case (apat, guard, body) =
+ separate space [pp_apat apat; string "if"; pp_aexp guard; string "=>"; pp_aexp body]
+
and pp_block = function
| [] -> string "()"
| [aexp] -> pp_aexp aexp
@@ -251,6 +276,12 @@ let rec split_block = function
exp :: exps, last
| [] -> failwith "empty block"
+let anf_pat (P_aux (p_aux, _) as pat) =
+ match p_aux with
+ | P_id id -> AP_id id
+ | P_wild -> AP_wild
+ | _ -> assert false
+
let rec anf (E_aux (e_aux, exp_annot) as exp) =
let to_aval = function
| AE_val v -> (v, fun x -> x)
@@ -260,6 +291,8 @@ let rec anf (E_aux (e_aux, exp_annot) as exp) =
| AE_cast (_, typ)
| AE_if (_, _, _, typ)
| AE_field (_, _, typ)
+ | AE_case (_, _, typ)
+ | AE_record_update (_, _, typ)
as aexp ->
let id = gensym () in
(AV_id (id, Local (Immutable, typ)), fun x -> AE_let (id, typ, aexp, x, typ_of exp))
@@ -375,6 +408,17 @@ let rec anf (E_aux (e_aux, exp_annot) as exp) =
let aval, wrap = to_aval (anf exp) in
wrap (AE_return (aval, typ_of exp))
+ | E_case (match_exp, pexps) ->
+ let match_aval, match_wrap = to_aval (anf match_exp) in
+ let anf_pexp (Pat_aux (pat_aux, _)) =
+ match pat_aux with
+ | Pat_when (pat, guard, body) ->
+ (anf_pat pat, anf guard, anf body)
+ | Pat_exp (pat, body) ->
+ (anf_pat pat, AE_val (AV_lit (mk_lit (L_true), bool_typ)), anf body)
+ in
+ match_wrap (AE_case (match_aval, List.map anf_pexp pexps, typ_of exp))
+
| E_var (LEXP_aux (LEXP_id id, _), binding, body)
| E_let (LB_aux (LB_val (P_aux (P_id id, _), binding), _), body) ->
let env = env_of body in
@@ -456,7 +500,7 @@ let initial_ctx env =
tc_env = env
}
-let ctyp_equal ctyp1 ctyp2 =
+let rec ctyp_equal ctyp1 ctyp2 =
match ctyp1, ctyp2 with
| CT_mpz, CT_mpz -> true
| CT_bv d1, CT_bv d2 -> d1 = d2
@@ -466,9 +510,12 @@ let ctyp_equal ctyp1 ctyp2 =
| CT_unit, CT_unit -> true
| CT_bool, CT_bool -> true
| CT_struct (id1, _), CT_struct (id2, _) -> Id.compare id1 id2 = 0
+ | CT_enum (id1, _), CT_enum (id2, _) -> Id.compare id1 id2 = 0
+ | CT_variant (id1, _), CT_variant (id2, _) -> Id.compare id1 id2 = 0
+ | CT_tup ctyps1, CT_tup ctyps2 -> List.for_all2 ctyp_equal ctyps1 ctyps2
| _, _ -> false
-let string_of_ctyp = function
+let rec string_of_ctyp = function
| CT_mpz -> "mpz_t"
| CT_bv true -> "bv_t<dec>"
| CT_bv false -> "bv_t<inc>"
@@ -478,10 +525,11 @@ let string_of_ctyp = function
| CT_int -> "int"
| CT_unit -> "unit"
| CT_bool -> "bool"
- | CT_struct (id, _) -> string_of_id id
+ | CT_tup ctyps -> "(" ^ Util.string_of_list ", " string_of_ctyp ctyps ^ ")"
+ | CT_struct (id, _) | CT_enum (id, _) | CT_variant (id, _) -> string_of_id id
(* Convert a sail type into a C-type *)
-let ctyp_of_typ ctx (Typ_aux (typ_aux, _) as typ) =
+let rec ctyp_of_typ ctx (Typ_aux (typ_aux, _) as typ) =
match typ_aux with
| Typ_id id when string_of_id id = "bit" -> CT_int
| Typ_id id when string_of_id id = "bool" -> CT_bool
@@ -508,14 +556,21 @@ let ctyp_of_typ ctx (Typ_aux (typ_aux, _) as typ) =
| _ -> CT_bv direction
end
| Typ_id id when string_of_id id = "unit" -> CT_unit
+
| Typ_id id when Bindings.mem id ctx.records -> CT_struct (id, Bindings.find id ctx.records)
+ | Typ_id id when Bindings.mem id ctx.enums -> CT_enum (id, Bindings.find id ctx.enums)
+ | Typ_id id when Bindings.mem id ctx.variants -> CT_variant (id, Bindings.find id ctx.variants)
+
+ | Typ_tup typs -> CT_tup (List.map (ctyp_of_typ ctx) typs)
| _ -> failwith ("No C-type for type " ^ string_of_typ typ)
let rec is_stack_ctyp ctyp = match ctyp with
- | CT_uint64 _ | CT_int64 | CT_int | CT_unit | CT_bool -> true
+ | CT_uint64 _ | CT_int64 | CT_int | CT_unit | CT_bool | CT_enum _ -> true
| CT_bv _ | CT_mpz -> false
| CT_struct (_, fields) -> Bindings.for_all (fun _ ctyp -> is_stack_ctyp ctyp) fields
+ | CT_variant (_, ctors) -> Bindings.for_all (fun _ ctyp -> is_stack_ctyp ctyp) ctors
+ | CT_tup ctyps -> List.for_all is_stack_ctyp ctyps
let is_stack_typ ctx typ = is_stack_ctyp (ctyp_of_typ ctx typ)
@@ -569,7 +624,7 @@ let c_aval ctx = function
begin
match lvar with
| Local (_, typ) when is_stack_typ ctx typ ->
- AV_C_fragment (string_of_id id, typ)
+ AV_C_fragment (Util.zencode_string (string_of_id id), typ)
| _ -> v
end
| AV_tuple avals -> AV_tuple avals
@@ -605,6 +660,15 @@ let analyze_primop' ctx id args typ =
| _ -> no_change
end
+ else if string_of_id id = "eq_range" || string_of_id id = "eq_atom" then
+ begin
+ match List.map (c_aval ctx) args with
+ | [x; y] when is_c_fragment x && is_c_fragment y ->
+ AE_val (AV_C_fragment ("(" ^ c_fragment_string x ^ " == " ^ c_fragment_string y ^ ")", typ))
+ | _ ->
+ no_change
+ end
+
else if string_of_id id = "xor_vec" then
begin
let n, x, y = match typ, args with
@@ -656,6 +720,11 @@ type ctype_def =
| CTD_record of id * ctyp Bindings.t
| CTD_variant of id * ctyp Bindings.t
+let ctype_def_ctyps = function
+ | CTD_enum _ -> []
+ | CTD_record (_, fields) -> List.map snd (Bindings.bindings fields)
+ | CTD_variant (_, ctors) -> List.map snd (Bindings.bindings ctors)
+
type cval =
| CV_id of id * ctyp
| CV_C_fragment of string * ctyp
@@ -664,6 +733,10 @@ let cval_ctyp = function
| CV_id (_, ctyp) -> ctyp
| CV_C_fragment (_, ctyp) -> ctyp
+type clexp =
+ | CL_id of id
+ | CL_field of id * id
+
type instr =
| I_decl of ctyp * id
| I_alloc of ctyp * id
@@ -672,26 +745,51 @@ type instr =
| I_funcall of id * id * cval list * ctyp
| I_convert of id * ctyp * id * ctyp
| I_assign of id * cval
- | I_copy of id * cval
+ | I_copy of clexp * cval
| I_clear of ctyp * id
| I_return of id
+ | I_block of instr list
| I_comment of string
+ | I_label of string
+ | I_goto of string
+ | I_raw of string
type cdef =
| CDEF_reg_dec of ctyp * id
- | CDEF_fundef of id * id list * instr list
+ | CDEF_fundef of id * id option * id list * instr list
| CDEF_type of ctype_def
+let rec instr_ctyps = function
+ | I_decl (ctyp, _) | I_alloc (ctyp, _) | I_clear (ctyp, _) -> [ctyp]
+ | I_init (ctyp, _, cval) -> [ctyp; cval_ctyp cval]
+ | I_if (cval, instrs1, instrs2, ctyp) ->
+ ctyp :: cval_ctyp cval :: List.concat (List.map instr_ctyps instrs1 @ List.map instr_ctyps instrs2)
+ | I_funcall (_, _, cvals, ctyp) ->
+ ctyp :: List.map cval_ctyp cvals
+ | I_convert (_, ctyp1, _, ctyp2) -> [ctyp1; ctyp2]
+ | I_assign (_, cval) | I_copy (_, cval) -> [cval_ctyp cval]
+ | I_block instrs -> List.concat (List.map instr_ctyps instrs)
+ | I_return _ | I_comment _ | I_label _ | I_goto _ | I_raw _ -> []
+
+let cdef_ctyps = function
+ | CDEF_reg_dec (ctyp, _) -> [ctyp]
+ | CDEF_fundef (_, _, _, instrs) -> List.concat (List.map instr_ctyps instrs)
+ | CDEF_type tdef -> ctype_def_ctyps tdef
+
let pp_ctyp ctyp =
string (string_of_ctyp ctyp |> Util.yellow |> Util.clear)
let pp_keyword str =
string ((str |> Util.red |> Util.clear) ^ "$")
-and pp_cval = function
+let pp_cval = function
| CV_id (id, ctyp) -> parens (pp_ctyp ctyp) ^^ (pp_id id)
| CV_C_fragment (str, ctyp) -> parens (pp_ctyp ctyp) ^^ (string (str |> Util.cyan |> Util.clear))
+let pp_clexp = function
+ | CL_id id -> pp_id id
+ | CL_field (id, field) -> pp_id id ^^ string "." ^^ pp_id field
+
let rec pp_instr = function
| I_decl (ctyp, id) ->
parens (pp_ctyp ctyp) ^^ space ^^ pp_id id
@@ -701,6 +799,8 @@ let rec pp_instr = function
^^ pp_keyword "IF" ^^ pp_cval cval
^^ pp_keyword "THEN" ^^ pp_if_block then_instrs
^^ pp_keyword "ELSE" ^^ pp_if_block else_instrs
+ | I_block instrs ->
+ surround 2 0 lbrace (separate_map hardline pp_instr instrs) rbrace
| I_alloc (ctyp, id) ->
pp_keyword "ALLOC" ^^ parens (pp_ctyp ctyp) ^^ space ^^ pp_id id
| I_init (ctyp, id, cval) ->
@@ -715,14 +815,20 @@ let rec pp_instr = function
string "->"; pp_ctyp ctyp1 ]
| I_assign (id, cval) ->
separate space [pp_id id; string ":="; pp_cval cval]
- | I_copy (id, cval) ->
- separate space [string "let"; pp_id id; string "="; pp_cval cval]
+ | I_copy (clexp, cval) ->
+ separate space [string "let"; pp_clexp clexp; string "="; pp_cval cval]
| I_clear (ctyp, id) ->
pp_keyword "CLEAR" ^^ pp_ctyp ctyp ^^ parens (pp_id id)
| I_return id ->
pp_keyword "RETURN" ^^ pp_id id
| I_comment str ->
string ("// " ^ str)
+ | I_label str ->
+ string (str ^ ":")
+ | I_goto str ->
+ pp_keyword "GOTO" ^^ string str
+ | I_raw str ->
+ pp_keyword "C" ^^ string str
let compile_funcall ctx id args typ =
let setup = ref [] in
@@ -780,6 +886,59 @@ let compile_funcall ctx id args typ =
(List.rev !setup, final_ctyp, call, !cleanup)
+let is_ct_enum = function
+ | CT_enum _ -> true
+ | _ -> false
+
+let is_ct_tup = function
+ | CT_tup _ -> true
+ | _ -> false
+
+let rec compile_aval ctx = function
+ | AV_C_fragment (code, typ) ->
+ [], CV_C_fragment (code, ctyp_of_typ ctx typ), []
+
+ | AV_id (id, typ) ->
+ begin
+ match ctyp_of_typ ctx (lvar_typ typ) with
+ | CT_enum (_, elems) when IdSet.mem id elems ->
+ [], CV_C_fragment (Util.zencode_upper_string (string_of_id id), ctyp_of_typ ctx (lvar_typ typ)), []
+ | _ ->
+ [], CV_id (id, ctyp_of_typ ctx (lvar_typ typ)), []
+ end
+
+ | AV_tuple avals ->
+ let elements = List.map (compile_aval ctx) avals in
+ let cvals = List.map (fun (_, cval, _) -> cval) elements in
+ let setup = List.concat (List.map (fun (setup, _, _) -> setup) elements) in
+ let cleanup = List.concat (List.rev (List.map (fun (_, _, cleanup) -> cleanup) elements)) in
+ let tup_ctyp = CT_tup (List.map cval_ctyp cvals) in
+ let gs = gensym () in
+ setup
+ @ [I_decl (tup_ctyp, gs)]
+ @ List.mapi (fun n cval -> I_copy (CL_field (gs, mk_id ("tup" ^ string_of_int n)), cval)) cvals,
+ CV_id (gs, CT_tup (List.map cval_ctyp cvals)),
+ cleanup
+
+let rec compile_match ctx apat cval case_label =
+ match apat, cval with
+ | AP_id pid, CV_C_fragment (code, ctyp) when is_ct_enum ctyp ->
+ [ I_if (CV_C_fragment (Util.zencode_upper_string (string_of_id pid) ^ " != " ^ code, CT_bool), [I_goto case_label], [], CT_unit) ]
+ | AP_id pid, CV_id (id, ctyp) when is_ct_enum ctyp ->
+ [ I_if (CV_C_fragment (Util.zencode_upper_string (string_of_id pid) ^ " != " ^ Util.zencode_string (string_of_id id), CT_bool), [I_goto case_label], [], CT_unit) ]
+ | AP_id pid, CV_C_fragment (code, ctyp) ->
+ [ I_init (ctyp, pid, cval) ]
+ | AP_id pid, CV_id _ ->
+ [ I_decl (cval_ctyp cval, pid); I_copy (CL_id pid, cval) ]
+ | _, _ -> []
+
+let label_counter = ref 0
+
+let label str =
+ let str = str ^ string_of_int !label_counter in
+ incr label_counter;
+ str
+
let rec compile_aexp ctx = function
| AE_let (id, _, binding, body, typ) ->
let setup, ctyp, call, cleanup = compile_aexp ctx binding in
@@ -796,24 +955,48 @@ let rec compile_aexp ctx = function
| AE_app (id, vs, typ) ->
compile_funcall ctx id vs typ
- | AE_val (AV_C_fragment (c, typ)) ->
- let ctyp = ctyp_of_typ ctx typ in
- [], ctyp, (fun id -> I_copy (id, CV_C_fragment (c, ctyp))), []
+ | AE_val aval ->
+ let setup, cval, cleanup = compile_aval ctx aval in
+ setup, cval_ctyp cval, (fun id -> I_copy (CL_id id, cval)), cleanup
- | AE_val (AV_id (id, lvar)) ->
- let ctyp = ctyp_of_typ ctx (lvar_typ lvar) in
- [], ctyp, (fun id' -> I_copy (id', CV_id (id, ctyp))), []
-
- | AE_val (AV_lit (lit, typ)) ->
+ (* Compile case statements *)
+ | AE_case (aval, cases, typ) ->
let ctyp = ctyp_of_typ ctx typ in
- if is_stack_ctyp ctyp then
- assert false
- else
+ let aval_setup, cval, aval_cleanup = compile_aval ctx aval in
+ let case_return_id = gensym () in
+ let finish_match_label = label "finish_match_" in
+ let compile_case (apat, guard, body) =
+ let trivial_guard = match guard with
+ | AE_val (AV_lit (L_aux (L_true, _), _))
+ | AE_val (AV_C_fragment ("true", _)) -> true
+ | _ -> false
+ in
+ let case_label = label "case_" in
+ let destructure = compile_match ctx apat cval case_label in
+ let guard_setup, _, guard_call, guard_cleanup = compile_aexp ctx guard in
+ let body_setup, _, body_call, body_cleanup = compile_aexp ctx body in
let gs = gensym () in
- [I_alloc (ctyp, gs); I_comment "fix literal init"],
- ctyp,
- (fun id -> I_copy (id, CV_id (gs, ctyp))),
- [I_clear (ctyp, gs)]
+ let case_instrs =
+ destructure @ [I_comment "end destructuring"]
+ @ (if not trivial_guard then
+ guard_setup @ [I_decl (CT_bool, gs); guard_call gs] @ guard_cleanup
+ @ [I_if (CV_C_fragment (Printf.sprintf "!%s" (Util.zencode_string (string_of_id gs)), CT_bool), [I_goto case_label], [], CT_unit)]
+ @ [I_comment "end guard"]
+ else [])
+ @ body_setup @ [body_call case_return_id] @ body_cleanup
+ @ [I_goto finish_match_label]
+ in
+ [I_block case_instrs; I_label case_label]
+ in
+ [I_comment "begin match"]
+ @ aval_setup @ [I_decl (ctyp, case_return_id)]
+ @ List.concat (List.map compile_case cases)
+ @ [I_raw "sail_match_failure();"]
+ @ [I_label finish_match_label],
+ ctyp,
+ (fun id -> I_copy (CL_id id, CV_id (case_return_id, ctyp))),
+ aval_cleanup
+ @ [I_comment "end match"]
| AE_if (aval, then_aexp, else_aexp, if_typ) ->
let if_ctyp = ctyp_of_typ ctx if_typ in
@@ -841,7 +1024,7 @@ let rec compile_aexp ctx = function
let gs = gensym () in
[I_alloc (ctyp, gs)] @ setup @ List.map (fun call -> call gs) calls,
ctyp,
- (fun id -> I_copy (id, CV_id (gs, ctyp))),
+ (fun id -> I_copy (CL_id id, CV_id (gs, ctyp))),
cleanup @ [I_clear (ctyp, gs)]
| AE_assign (id, assign_typ, aexp) ->
@@ -853,7 +1036,7 @@ let rec compile_aexp ctx = function
let unit_fragment = CV_C_fragment ("UNIT", CT_unit) in
let comment = "assign " ^ string_of_ctyp assign_ctyp ^ " := " ^ string_of_ctyp ctyp in
if ctyp_equal assign_ctyp ctyp then
- setup @ [call id], CT_unit, (fun id -> I_copy (id, unit_fragment)), cleanup
+ setup @ [call id], CT_unit, (fun id -> I_copy (CL_id id, unit_fragment)), cleanup
else if not (is_stack_ctyp assign_ctyp) && is_stack_ctyp ctyp then
let gs = gensym () in
setup @ [ I_comment comment;
@@ -862,7 +1045,7 @@ let rec compile_aexp ctx = function
I_convert (id, assign_ctyp, gs, ctyp)
],
CT_unit,
- (fun id -> I_copy (id, unit_fragment)),
+ (fun id -> I_copy (CL_id id, unit_fragment)),
cleanup
else
failwith comment
@@ -911,8 +1094,8 @@ let compile_type_def ctx (TD_aux (type_def, _)) =
(* Will be re-written before here, see bitfield.ml *)
| TD_bitfield _ -> failwith "Cannot compile TD_bitfield"
- (* All type abbreviations will be removed before now. TODO: point to where this is done. *)
- | TD_abbrev _ -> failwith "Cannot compile TD_abbrev"
+ (* All type abbreviations are filtered out in compile_def *)
+ | TD_abbrev _ -> assert false
let compile_def ctx = function
| DEF_reg_dec (DEC_aux (DEC_reg (typ, id), _)) ->
@@ -926,25 +1109,27 @@ let compile_def ctx = function
match pexp with
| Pat_aux (Pat_exp (pat, exp), _) ->
let aexp = map_functions (analyze_primop ctx) (c_literals ctx (anf exp)) in
- print_endline (Pretty_print_sail.to_string (pp_aexp aexp));
+ prerr_endline (Pretty_print_sail.to_string (pp_aexp aexp));
let setup, ctyp, call, cleanup = compile_aexp ctx aexp in
let gs = gensym () in
- let instrs =
- if is_stack_ctyp ctyp then
- setup @ [I_decl (ctyp, gs); call gs] @ cleanup @ [I_return gs]
- else
- assert false
- in
- [CDEF_fundef (id, pat_ids pat, instrs)], ctx
+ if is_stack_ctyp ctyp then
+ let instrs = [I_decl (ctyp, gs)] @ setup @ [call gs] @ cleanup @ [I_return gs] in
+ [CDEF_fundef (id, None, pat_ids pat, instrs)], ctx
+ else
+ let instrs = setup @ [call gs] @ cleanup in
+ [CDEF_fundef (id, Some gs, pat_ids pat, instrs)], ctx
| _ -> assert false
end
+ | DEF_type (TD_aux (TD_abbrev _, _)) -> [], ctx
| DEF_type type_def ->
let tdef, ctx = compile_type_def ctx type_def in
[CDEF_type tdef], ctx
| DEF_default _ -> [], ctx
+ | DEF_overload _ -> [], ctx
+
| _ -> assert false
(**************************************************************************)
@@ -965,6 +1150,7 @@ let sgen_ctyp = function
| CT_int64 -> "int64_t"
| CT_mpz -> "mpz_t"
| CT_bv _ -> "bv_t"
+ | CT_tup _ as tup -> "struct " ^ Util.zencode_string ("tuple_" ^ string_of_ctyp tup)
| CT_struct (id, _) -> "struct " ^ sgen_id id
| CT_enum (id, _) -> "enum " ^ sgen_id id
| CT_variant (id, _) -> "struct " ^ sgen_id id
@@ -977,6 +1163,7 @@ let sgen_ctyp_name = function
| CT_int64 -> "int64_t"
| CT_mpz -> "mpz_t"
| CT_bv _ -> "bv_t"
+ | CT_tup _ as tup -> Util.zencode_string ("tuple_" ^ string_of_ctyp tup)
| CT_struct (id, _) -> sgen_id id
| CT_enum (id, _) -> sgen_id id
| CT_variant (id, _) -> sgen_id id
@@ -986,55 +1173,73 @@ let sgen_cval = function
| CV_id (id, _) -> sgen_id id
| _ -> "CVAL??"
-let rec codegen_instr = function
+let sgen_clexp = function
+ | CL_id id -> sgen_id id
+ | CL_field (id, field) -> sgen_id id ^ "." ^ sgen_id field
+
+let rec codegen_instr ctx = function
| I_decl (ctyp, id) ->
- string (Printf.sprintf "%s %s;" (sgen_ctyp ctyp) (sgen_id id))
- | I_copy (id, cval) ->
+ string (Printf.sprintf " %s %s;" (sgen_ctyp ctyp) (sgen_id id))
+ | I_copy (clexp, cval) ->
let ctyp = cval_ctyp cval in
if is_stack_ctyp ctyp then
- string (Printf.sprintf "%s = %s;" (sgen_id id) (sgen_cval cval))
+ string (Printf.sprintf " %s = %s;" (sgen_clexp clexp) (sgen_cval cval))
else
- string (Printf.sprintf "set_%s(%s, %s);" (sgen_ctyp_name ctyp) (sgen_id id) (sgen_cval cval))
+ string (Printf.sprintf " set_%s(%s, %s);" (sgen_ctyp_name ctyp) (sgen_clexp clexp) (sgen_cval cval))
+ | I_if (cval, [then_instr], [], ctyp) ->
+ string (Printf.sprintf " if (%s)" (sgen_cval cval)) ^^ hardline
+ ^^ twice space ^^ codegen_instr ctx then_instr
| I_if (cval, then_instrs, else_instrs, ctyp) ->
- string "if" ^^ space ^^ parens (string (sgen_cval cval)) ^^ space
- ^^ surround 2 0 lbrace (separate_map hardline codegen_instr then_instrs) rbrace
+ string " if" ^^ space ^^ parens (string (sgen_cval cval)) ^^ space
+ ^^ surround 2 0 lbrace (separate_map hardline (codegen_instr ctx) then_instrs) rbrace
^^ space ^^ string "else" ^^ space
- ^^ surround 2 0 lbrace (separate_map hardline codegen_instr else_instrs) rbrace
+ ^^ surround 2 0 lbrace (separate_map hardline (codegen_instr ctx) else_instrs) rbrace
+ | I_block instrs ->
+ string " {"
+ ^^ jump 2 2 (separate_map hardline (codegen_instr ctx) instrs) ^^ hardline
+ ^^ string " }"
| I_funcall (x, f, args, ctyp) ->
let args = Util.string_of_list ", " sgen_cval args in
+ let fname = if Env.is_extern f ctx.tc_env "c" then Env.get_extern f ctx.tc_env "c" else sgen_id f in
if is_stack_ctyp ctyp then
- string (Printf.sprintf "%s = %s(%s);" (sgen_id x) (sgen_id f) args)
+ string (Printf.sprintf " %s = %s(%s);" (sgen_id x) fname args)
else
- string (Printf.sprintf "%s(%s, %s);" (sgen_id f) (sgen_id x) args)
+ string (Printf.sprintf " %s(%s, %s);" fname (sgen_id x) args)
| I_clear (ctyp, id) ->
- string (Printf.sprintf "clear_%s(%s);" (sgen_ctyp_name ctyp) (sgen_id id))
+ string (Printf.sprintf " clear_%s(%s);" (sgen_ctyp_name ctyp) (sgen_id id))
| I_init (ctyp, id, cval) ->
- string (Printf.sprintf "init_%s_of_%s(%s, %s);"
+ string (Printf.sprintf " init_%s_of_%s(%s, %s);"
(sgen_ctyp_name ctyp)
(sgen_ctyp_name (cval_ctyp cval))
(sgen_id id)
(sgen_cval cval))
| I_alloc (ctyp, id) ->
- string (Printf.sprintf "%s %s;" (sgen_ctyp ctyp) (sgen_id id))
+ string (Printf.sprintf " %s %s;" (sgen_ctyp ctyp) (sgen_id id))
^^ hardline
- ^^ string (Printf.sprintf "init_%s(%s);" (sgen_ctyp_name ctyp) (sgen_id id))
+ ^^ string (Printf.sprintf " init_%s(%s);" (sgen_ctyp_name ctyp) (sgen_id id))
| I_convert (x, ctyp1, y, ctyp2) ->
if is_stack_ctyp ctyp1 then
- string (Printf.sprintf "%s = convert_%s_of_%s(%s);"
+ string (Printf.sprintf " %s = convert_%s_of_%s(%s);"
(sgen_id x)
(sgen_ctyp_name ctyp1)
(sgen_ctyp_name ctyp2)
(sgen_id y))
else
- string (Printf.sprintf "convert_%s_of_%s(%s, %s);"
+ string (Printf.sprintf " convert_%s_of_%s(%s, %s);"
(sgen_ctyp_name ctyp1)
(sgen_ctyp_name ctyp2)
(sgen_id x)
(sgen_id y))
| I_return id ->
- string (Printf.sprintf "return %s;" (sgen_id id))
+ string (Printf.sprintf " return %s;" (sgen_id id))
| I_comment str ->
- string ("/* " ^ str ^ " */")
+ string (" /* " ^ str ^ " */")
+ | I_label str ->
+ string (str ^ ": ;")
+ | I_goto str ->
+ string (Printf.sprintf " goto %s;" str)
+ | I_raw str ->
+ string (" " ^ str)
let codegen_type_def ctx = function
| CTD_enum (id, ids) ->
@@ -1104,12 +1309,27 @@ let codegen_type_def ctx = function
rbrace
^^ semi
-let codegen_def ctx = function
+let generated_tuples = ref IdSet.empty
+
+let codegen_tup ctx ctyps =
+ let id = mk_id ("tuple_" ^ string_of_ctyp (CT_tup ctyps)) in
+ if IdSet.mem id !generated_tuples then
+ empty
+ else
+ let _, fields = List.fold_left (fun (n, fields) ctyp -> n + 1, Bindings.add (mk_id ("tup" ^ string_of_int n)) ctyp fields)
+ (0, Bindings.empty)
+ ctyps
+ in
+ generated_tuples := IdSet.add id !generated_tuples;
+ codegen_type_def ctx (CTD_record (id, fields)) ^^ twice hardline
+
+let codegen_def' ctx = function
| CDEF_reg_dec (ctyp, id) ->
string (Printf.sprintf "// register %s" (string_of_id id)) ^^ hardline
^^ string (Printf.sprintf "%s %s;" (sgen_ctyp ctyp) (sgen_id id))
- | CDEF_fundef (id, args, instrs) ->
- List.iter (fun instr -> print_endline (Pretty_print_sail.to_string (pp_instr instr))) instrs;
+
+ | CDEF_fundef (id, ret_arg, args, instrs) ->
+ List.iter (fun instr -> prerr_endline (Pretty_print_sail.to_string (pp_instr instr))) instrs;
let _, Typ_aux (fn_typ, _) = Env.get_val_spec id ctx.tc_env in
let arg_typs, ret_typ = match fn_typ with
| Typ_fn (Typ_aux (Typ_tup arg_typs, _), ret_typ, _) -> arg_typs, ret_typ
@@ -1117,17 +1337,36 @@ let codegen_def ctx = function
| _ -> assert false
in
let arg_ctyps, ret_ctyp = List.map (ctyp_of_typ ctx) arg_typs, ctyp_of_typ ctx ret_typ in
-
let args = Util.string_of_list ", " (fun x -> x) (List.map2 (fun ctyp arg -> sgen_ctyp ctyp ^ " " ^ sgen_id arg) arg_ctyps args) in
-
- string (sgen_ctyp ret_ctyp) ^^ space ^^ codegen_id id ^^ parens (string args) ^^ hardline
+ let function_header =
+ match ret_arg with
+ | None ->
+ assert (is_stack_ctyp ret_ctyp);
+ string (sgen_ctyp ret_ctyp) ^^ space ^^ codegen_id id ^^ parens (string args) ^^ hardline
+ | Some gs ->
+ assert (not (is_stack_ctyp ret_ctyp));
+ string "void" ^^ space ^^ codegen_id id
+ ^^ parens (string (sgen_ctyp ret_ctyp ^ " " ^ sgen_id gs ^ ", ") ^^ string args)
+ ^^ hardline
+ in
+ function_header
^^ string "{"
- ^^ jump 2 2 (separate_map hardline codegen_instr instrs) ^^ hardline
+ ^^ jump 0 2 (separate_map hardline (codegen_instr ctx) instrs) ^^ hardline
^^ string "}"
| CDEF_type ctype_def ->
codegen_type_def ctx ctype_def
+let codegen_def ctx def =
+ let untup = function
+ | CT_tup ctyps -> ctyps
+ | _ -> assert false
+ in
+ let tups = List.filter is_ct_tup (cdef_ctyps def) in
+ let tups = List.map (fun ctyp -> codegen_tup ctx (untup ctyp)) tups in
+ concat tups
+ ^^ codegen_def' ctx def
+
let compile_ast ctx (Defs defs) =
let chunks, ctx = List.fold_left (fun (chunks, ctx) def -> let defs, ctx = compile_def ctx def in defs :: chunks, ctx) ([], ctx) defs in
let cdefs = List.concat (List.rev chunks) in