diff options
| -rw-r--r-- | lib/flow.sail | 5 | ||||
| -rw-r--r-- | src/c_backend.ml | 367 |
2 files changed, 306 insertions, 66 deletions
diff --git a/lib/flow.sail b/lib/flow.sail index 2ca0e1a8..10badc93 100644 --- a/lib/flow.sail +++ b/lib/flow.sail @@ -5,7 +5,7 @@ val not_bool = "not" : bool -> bool val and_bool = "and_bool" : (bool, bool) -> bool val or_bool = "or_bool" : (bool, bool) -> bool -val eq_atom = {ocaml: "eq_int", lem: "eq"} : forall 'n 'm. (atom('n), atom('m)) -> bool +val eq_atom = {ocaml: "eq_int", lem: "eq", c: "eq_int"} : forall 'n 'm. (atom('n), atom('m)) -> bool val neq_atom = {lem: "neq"} : forall 'n 'm. (atom('n), atom('m)) -> bool @@ -25,9 +25,10 @@ val lteq_atom_range = "lteq" : forall 'n 'm 'o. (atom('n), range('m, 'o)) -> boo val gt_atom_range = "gt" : forall 'n 'm 'o. (atom('n), range('m, 'o)) -> bool val gteq_atom_range = "gteq" : forall 'n 'm 'o. (atom('n), range('m, 'o)) -> bool +val eq_range = {ocaml: "eq_int", lem: "eq"} : forall 'n 'm 'o 'p. (range('n, 'm), range('o, 'p)) -> bool val eq_int = {ocaml: "eq_int", lem: "eq"} : (int, int) -> bool -overload operator == = {eq_atom, eq_int} +overload operator == = {eq_atom, eq_range, eq_int} $ifdef TEST diff --git a/src/c_backend.ml b/src/c_backend.ml index 2ebd28c8..25c1669b 100644 --- a/src/c_backend.ml +++ b/src/c_backend.ml @@ -61,6 +61,7 @@ let zencode_id = function let lvar_typ = function | Local (_, typ) -> typ | Register typ -> typ + | Enum typ -> typ | _ -> assert false (**************************************************************************) @@ -103,10 +104,16 @@ type aexp = | AE_return of aval * typ | AE_if of aval * aexp * aexp * typ | AE_field of aval * id * typ + | AE_case of aval * (apat * aexp * aexp) list * typ | AE_record_update of aval * aval Bindings.t * typ | AE_for of id * aexp * aexp * aexp * order * aexp | AE_loop of loop * aexp * aexp +and apat = + | AP_tup of apat list + | AP_id of id + | AP_wild + and aval = | AV_lit of lit * typ | AV_id of id * lvar @@ -133,6 +140,10 @@ let rec map_aval f = function AE_for (id, map_aval f aexp1, map_aval f aexp2, map_aval f aexp3, order, map_aval f aexp4) | AE_record_update (aval, updates, typ) -> AE_record_update (f aval, Bindings.map f updates, typ) + | AE_field (aval, id, typ) -> + AE_field (f aval, id, typ) + | AE_case (aval, cases, typ) -> + AE_case (f aval, List.map (fun (pat, aexp1, aexp2) -> pat, map_aval f aexp1, map_aval f aexp2) cases, typ) (* Map over all the functions in an aexp. *) let rec map_functions f = function @@ -146,6 +157,8 @@ let rec map_functions f = function | AE_loop (loop_typ, aexp1, aexp2) -> AE_loop (loop_typ, map_functions f aexp1, map_functions f aexp2) | AE_for (id, aexp1, aexp2, aexp3, order, aexp4) -> AE_for (id, map_functions f aexp1, map_functions f aexp2, map_functions f aexp3, order, map_functions f aexp4) + | AE_case (aval, cases, typ) -> + AE_case (aval, List.map (fun (pat, aexp1, aexp2) -> pat, map_functions f aexp1, map_functions f aexp2) cases, typ) | AE_field _ | AE_record_update _ | AE_val _ | AE_return _ as v -> v (* For debugging we provide a pretty printer for ANF expressions. *) @@ -217,8 +230,20 @@ let rec pp_aexp = function in header ^//^ pp_aexp aexp4 | AE_field _ -> string "FIELD" + | AE_case (aval, cases, typ) -> + pp_annot typ (separate space [string "match"; pp_aval aval; pp_cases cases]) | AE_record_update (_, _, typ) -> pp_annot typ (string "RECORD UPDATE") +and pp_apat = function + | AP_wild -> string "_" + | AP_id id -> pp_id id + | AP_tup apats -> parens (separate_map (comma ^^ space) pp_apat apats) + +and pp_cases cases = surround 2 0 lbrace (separate_map (comma ^^ hardline) pp_case cases) rbrace + +and pp_case (apat, guard, body) = + separate space [pp_apat apat; string "if"; pp_aexp guard; string "=>"; pp_aexp body] + and pp_block = function | [] -> string "()" | [aexp] -> pp_aexp aexp @@ -251,6 +276,12 @@ let rec split_block = function exp :: exps, last | [] -> failwith "empty block" +let anf_pat (P_aux (p_aux, _) as pat) = + match p_aux with + | P_id id -> AP_id id + | P_wild -> AP_wild + | _ -> assert false + let rec anf (E_aux (e_aux, exp_annot) as exp) = let to_aval = function | AE_val v -> (v, fun x -> x) @@ -260,6 +291,8 @@ let rec anf (E_aux (e_aux, exp_annot) as exp) = | AE_cast (_, typ) | AE_if (_, _, _, typ) | AE_field (_, _, typ) + | AE_case (_, _, typ) + | AE_record_update (_, _, typ) as aexp -> let id = gensym () in (AV_id (id, Local (Immutable, typ)), fun x -> AE_let (id, typ, aexp, x, typ_of exp)) @@ -375,6 +408,17 @@ let rec anf (E_aux (e_aux, exp_annot) as exp) = let aval, wrap = to_aval (anf exp) in wrap (AE_return (aval, typ_of exp)) + | E_case (match_exp, pexps) -> + let match_aval, match_wrap = to_aval (anf match_exp) in + let anf_pexp (Pat_aux (pat_aux, _)) = + match pat_aux with + | Pat_when (pat, guard, body) -> + (anf_pat pat, anf guard, anf body) + | Pat_exp (pat, body) -> + (anf_pat pat, AE_val (AV_lit (mk_lit (L_true), bool_typ)), anf body) + in + match_wrap (AE_case (match_aval, List.map anf_pexp pexps, typ_of exp)) + | E_var (LEXP_aux (LEXP_id id, _), binding, body) | E_let (LB_aux (LB_val (P_aux (P_id id, _), binding), _), body) -> let env = env_of body in @@ -456,7 +500,7 @@ let initial_ctx env = tc_env = env } -let ctyp_equal ctyp1 ctyp2 = +let rec ctyp_equal ctyp1 ctyp2 = match ctyp1, ctyp2 with | CT_mpz, CT_mpz -> true | CT_bv d1, CT_bv d2 -> d1 = d2 @@ -466,9 +510,12 @@ let ctyp_equal ctyp1 ctyp2 = | CT_unit, CT_unit -> true | CT_bool, CT_bool -> true | CT_struct (id1, _), CT_struct (id2, _) -> Id.compare id1 id2 = 0 + | CT_enum (id1, _), CT_enum (id2, _) -> Id.compare id1 id2 = 0 + | CT_variant (id1, _), CT_variant (id2, _) -> Id.compare id1 id2 = 0 + | CT_tup ctyps1, CT_tup ctyps2 -> List.for_all2 ctyp_equal ctyps1 ctyps2 | _, _ -> false -let string_of_ctyp = function +let rec string_of_ctyp = function | CT_mpz -> "mpz_t" | CT_bv true -> "bv_t<dec>" | CT_bv false -> "bv_t<inc>" @@ -478,10 +525,11 @@ let string_of_ctyp = function | CT_int -> "int" | CT_unit -> "unit" | CT_bool -> "bool" - | CT_struct (id, _) -> string_of_id id + | CT_tup ctyps -> "(" ^ Util.string_of_list ", " string_of_ctyp ctyps ^ ")" + | CT_struct (id, _) | CT_enum (id, _) | CT_variant (id, _) -> string_of_id id (* Convert a sail type into a C-type *) -let ctyp_of_typ ctx (Typ_aux (typ_aux, _) as typ) = +let rec ctyp_of_typ ctx (Typ_aux (typ_aux, _) as typ) = match typ_aux with | Typ_id id when string_of_id id = "bit" -> CT_int | Typ_id id when string_of_id id = "bool" -> CT_bool @@ -508,14 +556,21 @@ let ctyp_of_typ ctx (Typ_aux (typ_aux, _) as typ) = | _ -> CT_bv direction end | Typ_id id when string_of_id id = "unit" -> CT_unit + | Typ_id id when Bindings.mem id ctx.records -> CT_struct (id, Bindings.find id ctx.records) + | Typ_id id when Bindings.mem id ctx.enums -> CT_enum (id, Bindings.find id ctx.enums) + | Typ_id id when Bindings.mem id ctx.variants -> CT_variant (id, Bindings.find id ctx.variants) + + | Typ_tup typs -> CT_tup (List.map (ctyp_of_typ ctx) typs) | _ -> failwith ("No C-type for type " ^ string_of_typ typ) let rec is_stack_ctyp ctyp = match ctyp with - | CT_uint64 _ | CT_int64 | CT_int | CT_unit | CT_bool -> true + | CT_uint64 _ | CT_int64 | CT_int | CT_unit | CT_bool | CT_enum _ -> true | CT_bv _ | CT_mpz -> false | CT_struct (_, fields) -> Bindings.for_all (fun _ ctyp -> is_stack_ctyp ctyp) fields + | CT_variant (_, ctors) -> Bindings.for_all (fun _ ctyp -> is_stack_ctyp ctyp) ctors + | CT_tup ctyps -> List.for_all is_stack_ctyp ctyps let is_stack_typ ctx typ = is_stack_ctyp (ctyp_of_typ ctx typ) @@ -569,7 +624,7 @@ let c_aval ctx = function begin match lvar with | Local (_, typ) when is_stack_typ ctx typ -> - AV_C_fragment (string_of_id id, typ) + AV_C_fragment (Util.zencode_string (string_of_id id), typ) | _ -> v end | AV_tuple avals -> AV_tuple avals @@ -605,6 +660,15 @@ let analyze_primop' ctx id args typ = | _ -> no_change end + else if string_of_id id = "eq_range" || string_of_id id = "eq_atom" then + begin + match List.map (c_aval ctx) args with + | [x; y] when is_c_fragment x && is_c_fragment y -> + AE_val (AV_C_fragment ("(" ^ c_fragment_string x ^ " == " ^ c_fragment_string y ^ ")", typ)) + | _ -> + no_change + end + else if string_of_id id = "xor_vec" then begin let n, x, y = match typ, args with @@ -656,6 +720,11 @@ type ctype_def = | CTD_record of id * ctyp Bindings.t | CTD_variant of id * ctyp Bindings.t +let ctype_def_ctyps = function + | CTD_enum _ -> [] + | CTD_record (_, fields) -> List.map snd (Bindings.bindings fields) + | CTD_variant (_, ctors) -> List.map snd (Bindings.bindings ctors) + type cval = | CV_id of id * ctyp | CV_C_fragment of string * ctyp @@ -664,6 +733,10 @@ let cval_ctyp = function | CV_id (_, ctyp) -> ctyp | CV_C_fragment (_, ctyp) -> ctyp +type clexp = + | CL_id of id + | CL_field of id * id + type instr = | I_decl of ctyp * id | I_alloc of ctyp * id @@ -672,26 +745,51 @@ type instr = | I_funcall of id * id * cval list * ctyp | I_convert of id * ctyp * id * ctyp | I_assign of id * cval - | I_copy of id * cval + | I_copy of clexp * cval | I_clear of ctyp * id | I_return of id + | I_block of instr list | I_comment of string + | I_label of string + | I_goto of string + | I_raw of string type cdef = | CDEF_reg_dec of ctyp * id - | CDEF_fundef of id * id list * instr list + | CDEF_fundef of id * id option * id list * instr list | CDEF_type of ctype_def +let rec instr_ctyps = function + | I_decl (ctyp, _) | I_alloc (ctyp, _) | I_clear (ctyp, _) -> [ctyp] + | I_init (ctyp, _, cval) -> [ctyp; cval_ctyp cval] + | I_if (cval, instrs1, instrs2, ctyp) -> + ctyp :: cval_ctyp cval :: List.concat (List.map instr_ctyps instrs1 @ List.map instr_ctyps instrs2) + | I_funcall (_, _, cvals, ctyp) -> + ctyp :: List.map cval_ctyp cvals + | I_convert (_, ctyp1, _, ctyp2) -> [ctyp1; ctyp2] + | I_assign (_, cval) | I_copy (_, cval) -> [cval_ctyp cval] + | I_block instrs -> List.concat (List.map instr_ctyps instrs) + | I_return _ | I_comment _ | I_label _ | I_goto _ | I_raw _ -> [] + +let cdef_ctyps = function + | CDEF_reg_dec (ctyp, _) -> [ctyp] + | CDEF_fundef (_, _, _, instrs) -> List.concat (List.map instr_ctyps instrs) + | CDEF_type tdef -> ctype_def_ctyps tdef + let pp_ctyp ctyp = string (string_of_ctyp ctyp |> Util.yellow |> Util.clear) let pp_keyword str = string ((str |> Util.red |> Util.clear) ^ "$") -and pp_cval = function +let pp_cval = function | CV_id (id, ctyp) -> parens (pp_ctyp ctyp) ^^ (pp_id id) | CV_C_fragment (str, ctyp) -> parens (pp_ctyp ctyp) ^^ (string (str |> Util.cyan |> Util.clear)) +let pp_clexp = function + | CL_id id -> pp_id id + | CL_field (id, field) -> pp_id id ^^ string "." ^^ pp_id field + let rec pp_instr = function | I_decl (ctyp, id) -> parens (pp_ctyp ctyp) ^^ space ^^ pp_id id @@ -701,6 +799,8 @@ let rec pp_instr = function ^^ pp_keyword "IF" ^^ pp_cval cval ^^ pp_keyword "THEN" ^^ pp_if_block then_instrs ^^ pp_keyword "ELSE" ^^ pp_if_block else_instrs + | I_block instrs -> + surround 2 0 lbrace (separate_map hardline pp_instr instrs) rbrace | I_alloc (ctyp, id) -> pp_keyword "ALLOC" ^^ parens (pp_ctyp ctyp) ^^ space ^^ pp_id id | I_init (ctyp, id, cval) -> @@ -715,14 +815,20 @@ let rec pp_instr = function string "->"; pp_ctyp ctyp1 ] | I_assign (id, cval) -> separate space [pp_id id; string ":="; pp_cval cval] - | I_copy (id, cval) -> - separate space [string "let"; pp_id id; string "="; pp_cval cval] + | I_copy (clexp, cval) -> + separate space [string "let"; pp_clexp clexp; string "="; pp_cval cval] | I_clear (ctyp, id) -> pp_keyword "CLEAR" ^^ pp_ctyp ctyp ^^ parens (pp_id id) | I_return id -> pp_keyword "RETURN" ^^ pp_id id | I_comment str -> string ("// " ^ str) + | I_label str -> + string (str ^ ":") + | I_goto str -> + pp_keyword "GOTO" ^^ string str + | I_raw str -> + pp_keyword "C" ^^ string str let compile_funcall ctx id args typ = let setup = ref [] in @@ -780,6 +886,59 @@ let compile_funcall ctx id args typ = (List.rev !setup, final_ctyp, call, !cleanup) +let is_ct_enum = function + | CT_enum _ -> true + | _ -> false + +let is_ct_tup = function + | CT_tup _ -> true + | _ -> false + +let rec compile_aval ctx = function + | AV_C_fragment (code, typ) -> + [], CV_C_fragment (code, ctyp_of_typ ctx typ), [] + + | AV_id (id, typ) -> + begin + match ctyp_of_typ ctx (lvar_typ typ) with + | CT_enum (_, elems) when IdSet.mem id elems -> + [], CV_C_fragment (Util.zencode_upper_string (string_of_id id), ctyp_of_typ ctx (lvar_typ typ)), [] + | _ -> + [], CV_id (id, ctyp_of_typ ctx (lvar_typ typ)), [] + end + + | AV_tuple avals -> + let elements = List.map (compile_aval ctx) avals in + let cvals = List.map (fun (_, cval, _) -> cval) elements in + let setup = List.concat (List.map (fun (setup, _, _) -> setup) elements) in + let cleanup = List.concat (List.rev (List.map (fun (_, _, cleanup) -> cleanup) elements)) in + let tup_ctyp = CT_tup (List.map cval_ctyp cvals) in + let gs = gensym () in + setup + @ [I_decl (tup_ctyp, gs)] + @ List.mapi (fun n cval -> I_copy (CL_field (gs, mk_id ("tup" ^ string_of_int n)), cval)) cvals, + CV_id (gs, CT_tup (List.map cval_ctyp cvals)), + cleanup + +let rec compile_match ctx apat cval case_label = + match apat, cval with + | AP_id pid, CV_C_fragment (code, ctyp) when is_ct_enum ctyp -> + [ I_if (CV_C_fragment (Util.zencode_upper_string (string_of_id pid) ^ " != " ^ code, CT_bool), [I_goto case_label], [], CT_unit) ] + | AP_id pid, CV_id (id, ctyp) when is_ct_enum ctyp -> + [ I_if (CV_C_fragment (Util.zencode_upper_string (string_of_id pid) ^ " != " ^ Util.zencode_string (string_of_id id), CT_bool), [I_goto case_label], [], CT_unit) ] + | AP_id pid, CV_C_fragment (code, ctyp) -> + [ I_init (ctyp, pid, cval) ] + | AP_id pid, CV_id _ -> + [ I_decl (cval_ctyp cval, pid); I_copy (CL_id pid, cval) ] + | _, _ -> [] + +let label_counter = ref 0 + +let label str = + let str = str ^ string_of_int !label_counter in + incr label_counter; + str + let rec compile_aexp ctx = function | AE_let (id, _, binding, body, typ) -> let setup, ctyp, call, cleanup = compile_aexp ctx binding in @@ -796,24 +955,48 @@ let rec compile_aexp ctx = function | AE_app (id, vs, typ) -> compile_funcall ctx id vs typ - | AE_val (AV_C_fragment (c, typ)) -> - let ctyp = ctyp_of_typ ctx typ in - [], ctyp, (fun id -> I_copy (id, CV_C_fragment (c, ctyp))), [] + | AE_val aval -> + let setup, cval, cleanup = compile_aval ctx aval in + setup, cval_ctyp cval, (fun id -> I_copy (CL_id id, cval)), cleanup - | AE_val (AV_id (id, lvar)) -> - let ctyp = ctyp_of_typ ctx (lvar_typ lvar) in - [], ctyp, (fun id' -> I_copy (id', CV_id (id, ctyp))), [] - - | AE_val (AV_lit (lit, typ)) -> + (* Compile case statements *) + | AE_case (aval, cases, typ) -> let ctyp = ctyp_of_typ ctx typ in - if is_stack_ctyp ctyp then - assert false - else + let aval_setup, cval, aval_cleanup = compile_aval ctx aval in + let case_return_id = gensym () in + let finish_match_label = label "finish_match_" in + let compile_case (apat, guard, body) = + let trivial_guard = match guard with + | AE_val (AV_lit (L_aux (L_true, _), _)) + | AE_val (AV_C_fragment ("true", _)) -> true + | _ -> false + in + let case_label = label "case_" in + let destructure = compile_match ctx apat cval case_label in + let guard_setup, _, guard_call, guard_cleanup = compile_aexp ctx guard in + let body_setup, _, body_call, body_cleanup = compile_aexp ctx body in let gs = gensym () in - [I_alloc (ctyp, gs); I_comment "fix literal init"], - ctyp, - (fun id -> I_copy (id, CV_id (gs, ctyp))), - [I_clear (ctyp, gs)] + let case_instrs = + destructure @ [I_comment "end destructuring"] + @ (if not trivial_guard then + guard_setup @ [I_decl (CT_bool, gs); guard_call gs] @ guard_cleanup + @ [I_if (CV_C_fragment (Printf.sprintf "!%s" (Util.zencode_string (string_of_id gs)), CT_bool), [I_goto case_label], [], CT_unit)] + @ [I_comment "end guard"] + else []) + @ body_setup @ [body_call case_return_id] @ body_cleanup + @ [I_goto finish_match_label] + in + [I_block case_instrs; I_label case_label] + in + [I_comment "begin match"] + @ aval_setup @ [I_decl (ctyp, case_return_id)] + @ List.concat (List.map compile_case cases) + @ [I_raw "sail_match_failure();"] + @ [I_label finish_match_label], + ctyp, + (fun id -> I_copy (CL_id id, CV_id (case_return_id, ctyp))), + aval_cleanup + @ [I_comment "end match"] | AE_if (aval, then_aexp, else_aexp, if_typ) -> let if_ctyp = ctyp_of_typ ctx if_typ in @@ -841,7 +1024,7 @@ let rec compile_aexp ctx = function let gs = gensym () in [I_alloc (ctyp, gs)] @ setup @ List.map (fun call -> call gs) calls, ctyp, - (fun id -> I_copy (id, CV_id (gs, ctyp))), + (fun id -> I_copy (CL_id id, CV_id (gs, ctyp))), cleanup @ [I_clear (ctyp, gs)] | AE_assign (id, assign_typ, aexp) -> @@ -853,7 +1036,7 @@ let rec compile_aexp ctx = function let unit_fragment = CV_C_fragment ("UNIT", CT_unit) in let comment = "assign " ^ string_of_ctyp assign_ctyp ^ " := " ^ string_of_ctyp ctyp in if ctyp_equal assign_ctyp ctyp then - setup @ [call id], CT_unit, (fun id -> I_copy (id, unit_fragment)), cleanup + setup @ [call id], CT_unit, (fun id -> I_copy (CL_id id, unit_fragment)), cleanup else if not (is_stack_ctyp assign_ctyp) && is_stack_ctyp ctyp then let gs = gensym () in setup @ [ I_comment comment; @@ -862,7 +1045,7 @@ let rec compile_aexp ctx = function I_convert (id, assign_ctyp, gs, ctyp) ], CT_unit, - (fun id -> I_copy (id, unit_fragment)), + (fun id -> I_copy (CL_id id, unit_fragment)), cleanup else failwith comment @@ -911,8 +1094,8 @@ let compile_type_def ctx (TD_aux (type_def, _)) = (* Will be re-written before here, see bitfield.ml *) | TD_bitfield _ -> failwith "Cannot compile TD_bitfield" - (* All type abbreviations will be removed before now. TODO: point to where this is done. *) - | TD_abbrev _ -> failwith "Cannot compile TD_abbrev" + (* All type abbreviations are filtered out in compile_def *) + | TD_abbrev _ -> assert false let compile_def ctx = function | DEF_reg_dec (DEC_aux (DEC_reg (typ, id), _)) -> @@ -926,25 +1109,27 @@ let compile_def ctx = function match pexp with | Pat_aux (Pat_exp (pat, exp), _) -> let aexp = map_functions (analyze_primop ctx) (c_literals ctx (anf exp)) in - print_endline (Pretty_print_sail.to_string (pp_aexp aexp)); + prerr_endline (Pretty_print_sail.to_string (pp_aexp aexp)); let setup, ctyp, call, cleanup = compile_aexp ctx aexp in let gs = gensym () in - let instrs = - if is_stack_ctyp ctyp then - setup @ [I_decl (ctyp, gs); call gs] @ cleanup @ [I_return gs] - else - assert false - in - [CDEF_fundef (id, pat_ids pat, instrs)], ctx + if is_stack_ctyp ctyp then + let instrs = [I_decl (ctyp, gs)] @ setup @ [call gs] @ cleanup @ [I_return gs] in + [CDEF_fundef (id, None, pat_ids pat, instrs)], ctx + else + let instrs = setup @ [call gs] @ cleanup in + [CDEF_fundef (id, Some gs, pat_ids pat, instrs)], ctx | _ -> assert false end + | DEF_type (TD_aux (TD_abbrev _, _)) -> [], ctx | DEF_type type_def -> let tdef, ctx = compile_type_def ctx type_def in [CDEF_type tdef], ctx | DEF_default _ -> [], ctx + | DEF_overload _ -> [], ctx + | _ -> assert false (**************************************************************************) @@ -965,6 +1150,7 @@ let sgen_ctyp = function | CT_int64 -> "int64_t" | CT_mpz -> "mpz_t" | CT_bv _ -> "bv_t" + | CT_tup _ as tup -> "struct " ^ Util.zencode_string ("tuple_" ^ string_of_ctyp tup) | CT_struct (id, _) -> "struct " ^ sgen_id id | CT_enum (id, _) -> "enum " ^ sgen_id id | CT_variant (id, _) -> "struct " ^ sgen_id id @@ -977,6 +1163,7 @@ let sgen_ctyp_name = function | CT_int64 -> "int64_t" | CT_mpz -> "mpz_t" | CT_bv _ -> "bv_t" + | CT_tup _ as tup -> Util.zencode_string ("tuple_" ^ string_of_ctyp tup) | CT_struct (id, _) -> sgen_id id | CT_enum (id, _) -> sgen_id id | CT_variant (id, _) -> sgen_id id @@ -986,55 +1173,73 @@ let sgen_cval = function | CV_id (id, _) -> sgen_id id | _ -> "CVAL??" -let rec codegen_instr = function +let sgen_clexp = function + | CL_id id -> sgen_id id + | CL_field (id, field) -> sgen_id id ^ "." ^ sgen_id field + +let rec codegen_instr ctx = function | I_decl (ctyp, id) -> - string (Printf.sprintf "%s %s;" (sgen_ctyp ctyp) (sgen_id id)) - | I_copy (id, cval) -> + string (Printf.sprintf " %s %s;" (sgen_ctyp ctyp) (sgen_id id)) + | I_copy (clexp, cval) -> let ctyp = cval_ctyp cval in if is_stack_ctyp ctyp then - string (Printf.sprintf "%s = %s;" (sgen_id id) (sgen_cval cval)) + string (Printf.sprintf " %s = %s;" (sgen_clexp clexp) (sgen_cval cval)) else - string (Printf.sprintf "set_%s(%s, %s);" (sgen_ctyp_name ctyp) (sgen_id id) (sgen_cval cval)) + string (Printf.sprintf " set_%s(%s, %s);" (sgen_ctyp_name ctyp) (sgen_clexp clexp) (sgen_cval cval)) + | I_if (cval, [then_instr], [], ctyp) -> + string (Printf.sprintf " if (%s)" (sgen_cval cval)) ^^ hardline + ^^ twice space ^^ codegen_instr ctx then_instr | I_if (cval, then_instrs, else_instrs, ctyp) -> - string "if" ^^ space ^^ parens (string (sgen_cval cval)) ^^ space - ^^ surround 2 0 lbrace (separate_map hardline codegen_instr then_instrs) rbrace + string " if" ^^ space ^^ parens (string (sgen_cval cval)) ^^ space + ^^ surround 2 0 lbrace (separate_map hardline (codegen_instr ctx) then_instrs) rbrace ^^ space ^^ string "else" ^^ space - ^^ surround 2 0 lbrace (separate_map hardline codegen_instr else_instrs) rbrace + ^^ surround 2 0 lbrace (separate_map hardline (codegen_instr ctx) else_instrs) rbrace + | I_block instrs -> + string " {" + ^^ jump 2 2 (separate_map hardline (codegen_instr ctx) instrs) ^^ hardline + ^^ string " }" | I_funcall (x, f, args, ctyp) -> let args = Util.string_of_list ", " sgen_cval args in + let fname = if Env.is_extern f ctx.tc_env "c" then Env.get_extern f ctx.tc_env "c" else sgen_id f in if is_stack_ctyp ctyp then - string (Printf.sprintf "%s = %s(%s);" (sgen_id x) (sgen_id f) args) + string (Printf.sprintf " %s = %s(%s);" (sgen_id x) fname args) else - string (Printf.sprintf "%s(%s, %s);" (sgen_id f) (sgen_id x) args) + string (Printf.sprintf " %s(%s, %s);" fname (sgen_id x) args) | I_clear (ctyp, id) -> - string (Printf.sprintf "clear_%s(%s);" (sgen_ctyp_name ctyp) (sgen_id id)) + string (Printf.sprintf " clear_%s(%s);" (sgen_ctyp_name ctyp) (sgen_id id)) | I_init (ctyp, id, cval) -> - string (Printf.sprintf "init_%s_of_%s(%s, %s);" + string (Printf.sprintf " init_%s_of_%s(%s, %s);" (sgen_ctyp_name ctyp) (sgen_ctyp_name (cval_ctyp cval)) (sgen_id id) (sgen_cval cval)) | I_alloc (ctyp, id) -> - string (Printf.sprintf "%s %s;" (sgen_ctyp ctyp) (sgen_id id)) + string (Printf.sprintf " %s %s;" (sgen_ctyp ctyp) (sgen_id id)) ^^ hardline - ^^ string (Printf.sprintf "init_%s(%s);" (sgen_ctyp_name ctyp) (sgen_id id)) + ^^ string (Printf.sprintf " init_%s(%s);" (sgen_ctyp_name ctyp) (sgen_id id)) | I_convert (x, ctyp1, y, ctyp2) -> if is_stack_ctyp ctyp1 then - string (Printf.sprintf "%s = convert_%s_of_%s(%s);" + string (Printf.sprintf " %s = convert_%s_of_%s(%s);" (sgen_id x) (sgen_ctyp_name ctyp1) (sgen_ctyp_name ctyp2) (sgen_id y)) else - string (Printf.sprintf "convert_%s_of_%s(%s, %s);" + string (Printf.sprintf " convert_%s_of_%s(%s, %s);" (sgen_ctyp_name ctyp1) (sgen_ctyp_name ctyp2) (sgen_id x) (sgen_id y)) | I_return id -> - string (Printf.sprintf "return %s;" (sgen_id id)) + string (Printf.sprintf " return %s;" (sgen_id id)) | I_comment str -> - string ("/* " ^ str ^ " */") + string (" /* " ^ str ^ " */") + | I_label str -> + string (str ^ ": ;") + | I_goto str -> + string (Printf.sprintf " goto %s;" str) + | I_raw str -> + string (" " ^ str) let codegen_type_def ctx = function | CTD_enum (id, ids) -> @@ -1104,12 +1309,27 @@ let codegen_type_def ctx = function rbrace ^^ semi -let codegen_def ctx = function +let generated_tuples = ref IdSet.empty + +let codegen_tup ctx ctyps = + let id = mk_id ("tuple_" ^ string_of_ctyp (CT_tup ctyps)) in + if IdSet.mem id !generated_tuples then + empty + else + let _, fields = List.fold_left (fun (n, fields) ctyp -> n + 1, Bindings.add (mk_id ("tup" ^ string_of_int n)) ctyp fields) + (0, Bindings.empty) + ctyps + in + generated_tuples := IdSet.add id !generated_tuples; + codegen_type_def ctx (CTD_record (id, fields)) ^^ twice hardline + +let codegen_def' ctx = function | CDEF_reg_dec (ctyp, id) -> string (Printf.sprintf "// register %s" (string_of_id id)) ^^ hardline ^^ string (Printf.sprintf "%s %s;" (sgen_ctyp ctyp) (sgen_id id)) - | CDEF_fundef (id, args, instrs) -> - List.iter (fun instr -> print_endline (Pretty_print_sail.to_string (pp_instr instr))) instrs; + + | CDEF_fundef (id, ret_arg, args, instrs) -> + List.iter (fun instr -> prerr_endline (Pretty_print_sail.to_string (pp_instr instr))) instrs; let _, Typ_aux (fn_typ, _) = Env.get_val_spec id ctx.tc_env in let arg_typs, ret_typ = match fn_typ with | Typ_fn (Typ_aux (Typ_tup arg_typs, _), ret_typ, _) -> arg_typs, ret_typ @@ -1117,17 +1337,36 @@ let codegen_def ctx = function | _ -> assert false in let arg_ctyps, ret_ctyp = List.map (ctyp_of_typ ctx) arg_typs, ctyp_of_typ ctx ret_typ in - let args = Util.string_of_list ", " (fun x -> x) (List.map2 (fun ctyp arg -> sgen_ctyp ctyp ^ " " ^ sgen_id arg) arg_ctyps args) in - - string (sgen_ctyp ret_ctyp) ^^ space ^^ codegen_id id ^^ parens (string args) ^^ hardline + let function_header = + match ret_arg with + | None -> + assert (is_stack_ctyp ret_ctyp); + string (sgen_ctyp ret_ctyp) ^^ space ^^ codegen_id id ^^ parens (string args) ^^ hardline + | Some gs -> + assert (not (is_stack_ctyp ret_ctyp)); + string "void" ^^ space ^^ codegen_id id + ^^ parens (string (sgen_ctyp ret_ctyp ^ " " ^ sgen_id gs ^ ", ") ^^ string args) + ^^ hardline + in + function_header ^^ string "{" - ^^ jump 2 2 (separate_map hardline codegen_instr instrs) ^^ hardline + ^^ jump 0 2 (separate_map hardline (codegen_instr ctx) instrs) ^^ hardline ^^ string "}" | CDEF_type ctype_def -> codegen_type_def ctx ctype_def +let codegen_def ctx def = + let untup = function + | CT_tup ctyps -> ctyps + | _ -> assert false + in + let tups = List.filter is_ct_tup (cdef_ctyps def) in + let tups = List.map (fun ctyp -> codegen_tup ctx (untup ctyp)) tups in + concat tups + ^^ codegen_def' ctx def + let compile_ast ctx (Defs defs) = let chunks, ctx = List.fold_left (fun (chunks, ctx) def -> let defs, ctx = compile_def ctx def in defs :: chunks, ctx) ([], ctx) defs in let cdefs = List.concat (List.rev chunks) in |
