diff options
Diffstat (limited to 'src')
33 files changed, 2591 insertions, 2388 deletions
diff --git a/src/ast_util.ml b/src/ast_util.ml index 3f13e6ad..546faf14 100644 --- a/src/ast_util.ml +++ b/src/ast_util.ml @@ -193,6 +193,10 @@ let rec is_nexp_constant (Nexp_aux (nexp, _)) = match nexp with let rec nexp_simp (Nexp_aux (nexp, l)) = Nexp_aux (nexp_simp_aux nexp, l) and nexp_simp_aux = function + (* (n - (n - m)) often appears in foreach loops *) + | Nexp_minus (nexp1, Nexp_aux (Nexp_minus (nexp2, Nexp_aux (n3,_)),_)) + when nexp_identical nexp1 nexp2 -> + nexp_simp_aux n3 | Nexp_minus (Nexp_aux (Nexp_sum (Nexp_aux (n1, _), nexp2), _), nexp3) when nexp_identical nexp2 nexp3 -> nexp_simp_aux n1 diff --git a/src/c_backend.ml b/src/c_backend.ml index b06cd950..42932285 100644 --- a/src/c_backend.ml +++ b/src/c_backend.ml @@ -61,28 +61,9 @@ let zencode_id = function let lvar_typ = function | Local (_, typ) -> typ | Register typ -> typ + | Enum typ -> typ | _ -> assert false -(* - -1) Conversion to ANF - - tannot defs -> (typ, aexp) cdefs - -2) Primitive operation optimizations - -3) Lowering to low-level imperative language - - (typ, aexp) cdefs -> (ctyp, instr list) cdefs - -4) Low level optimizations (e.g. reducing allocations) - -5) Generation of C code - - (ctyp, instr list) -> string - -*) - (**************************************************************************) (* 1. Conversion to A-normal form (ANF) *) (**************************************************************************) @@ -121,16 +102,25 @@ type aexp = | AE_let of id * typ * aexp * aexp * typ | AE_block of aexp list * aexp * typ | AE_return of aval * typ - | AE_throw of aval | AE_if of aval * aexp * aexp * typ + | AE_field of aval * id * typ + | AE_case of aval * (apat * aexp * aexp) list * typ + | AE_record_update of aval * aval Bindings.t * typ | AE_for of id * aexp * aexp * aexp * order * aexp | AE_loop of loop * aexp * aexp +and apat = + | AP_tup of apat list + | AP_id of id + | AP_wild + and aval = | AV_lit of lit * typ | AV_id of id * lvar | AV_ref of id * lvar | AV_tuple of aval list + | AV_list of aval list * typ + | AV_vector of aval list * typ | AV_C_fragment of string * typ (* Map over all the avals in an aexp. *) @@ -145,6 +135,15 @@ let rec map_aval f = function | AE_return (aval, typ) -> AE_return (f aval, typ) | AE_if (aval, aexp1, aexp2, typ2) -> AE_if (f aval, map_aval f aexp1, map_aval f aexp2, typ2) + | AE_loop (loop_typ, aexp1, aexp2) -> AE_loop (loop_typ, map_aval f aexp1, map_aval f aexp2) + | AE_for (id, aexp1, aexp2, aexp3, order, aexp4) -> + AE_for (id, map_aval f aexp1, map_aval f aexp2, map_aval f aexp3, order, map_aval f aexp4) + | AE_record_update (aval, updates, typ) -> + AE_record_update (f aval, Bindings.map f updates, typ) + | AE_field (aval, id, typ) -> + AE_field (f aval, id, typ) + | AE_case (aval, cases, typ) -> + AE_case (f aval, List.map (fun (pat, aexp1, aexp2) -> pat, map_aval f aexp1, map_aval f aexp2) cases, typ) (* Map over all the functions in an aexp. *) let rec map_functions f = function @@ -155,10 +154,15 @@ let rec map_functions f = function | AE_block (aexps, aexp, typ) -> AE_block (List.map (map_functions f) aexps, map_functions f aexp, typ) | AE_if (aval, aexp1, aexp2, typ) -> AE_if (aval, map_functions f aexp1, map_functions f aexp2, typ) - | AE_val _ | AE_return _ as v -> v + | AE_loop (loop_typ, aexp1, aexp2) -> AE_loop (loop_typ, map_functions f aexp1, map_functions f aexp2) + | AE_for (id, aexp1, aexp2, aexp3, order, aexp4) -> + AE_for (id, map_functions f aexp1, map_functions f aexp2, map_functions f aexp3, order, map_functions f aexp4) + | AE_case (aval, cases, typ) -> + AE_case (aval, List.map (fun (pat, aexp1, aexp2) -> pat, map_functions f aexp1, map_functions f aexp2) cases, typ) + | AE_field _ | AE_record_update _ | AE_val _ | AE_return _ as v -> v (* For debugging we provide a pretty printer for ANF expressions. *) - + let pp_id ?color:(color=Util.green) id = string (string_of_id id |> color |> Util.clear) @@ -179,6 +183,11 @@ let pp_lvar lvar doc = let pp_annot typ doc = string "[" ^^ string (string_of_typ typ |> Util.yellow |> Util.clear) ^^ string "]" ^^ doc +let pp_order = function + | Ord_aux (Ord_inc, _) -> string "inc" + | Ord_aux (Ord_dec, _) -> string "dec" + | _ -> assert false (* Order types have been specialised, so no polymorphism in C backend. *) + let rec pp_aexp = function | AE_val v -> pp_aval v | AE_cast (aexp, typ) -> @@ -205,6 +214,35 @@ let rec pp_aexp = function | AE_block (aexps, aexp, typ) -> pp_annot typ (surround 2 0 lbrace (pp_block (aexps @ [aexp])) rbrace) | AE_return (v, typ) -> pp_annot typ (string "return" ^^ parens (pp_aval v)) + | AE_loop (While, aexp1, aexp2) -> + separate space [string "while"; pp_aexp aexp1; string "do"; pp_aexp aexp2] + | AE_loop (Until, aexp1, aexp2) -> + separate space [string "repeat"; pp_aexp aexp2; string "until"; pp_aexp aexp1] + | AE_for (id, aexp1, aexp2, aexp3, order, aexp4) -> + let header = + string "foreach" ^^ space ^^ + group (parens (separate (break 1) + [ pp_id id; + string "from " ^^ pp_aexp aexp1; + string "to " ^^ pp_aexp aexp2; + string "by " ^^ pp_aexp aexp3; + string "in " ^^ pp_order order ])) + in + header ^//^ pp_aexp aexp4 + | AE_field _ -> string "FIELD" + | AE_case (aval, cases, typ) -> + pp_annot typ (separate space [string "match"; pp_aval aval; pp_cases cases]) + | AE_record_update (_, _, typ) -> pp_annot typ (string "RECORD UPDATE") + +and pp_apat = function + | AP_wild -> string "_" + | AP_id id -> pp_id id + | AP_tup apats -> parens (separate_map (comma ^^ space) pp_apat apats) + +and pp_cases cases = surround 2 0 lbrace (separate_map (comma ^^ hardline) pp_case cases) rbrace + +and pp_case (apat, guard, body) = + separate space [pp_apat apat; string "if"; pp_aexp guard; string "=>"; pp_aexp body] and pp_block = function | [] -> string "()" @@ -215,14 +253,22 @@ and pp_aval = function | AV_lit (lit, typ) -> pp_annot typ (string (string_of_lit lit)) | AV_id (id, lvar) -> pp_lvar lvar (pp_id id) | AV_tuple avals -> parens (separate_map (comma ^^ space) pp_aval avals) + | AV_ref (id, lvar) -> string "ref" ^^ space ^^ pp_lvar lvar (pp_id id) | AV_C_fragment (str, typ) -> pp_annot typ (string (str |> Util.cyan |> Util.clear)) + | AV_vector (avals, typ) -> + pp_annot typ (string "[" ^^ separate_map (comma ^^ space) pp_aval avals ^^ string "]") + | AV_list (avals, typ) -> + pp_annot typ (string "[|" ^^ separate_map (comma ^^ space) pp_aval avals ^^ string "|]") let ae_lit lit typ = AE_val (AV_lit (lit, typ)) +(** GLOBAL: gensym_counter is used to generate fresh identifiers where + needed. It should be safe to reset between top level + definitions. **) let gensym_counter = ref 0 let gensym () = - let id = mk_id ("v" ^ string_of_int !gensym_counter) in + let id = mk_id ("gs#" ^ string_of_int !gensym_counter) in incr gensym_counter; id @@ -233,12 +279,30 @@ let rec split_block = function exp :: exps, last | [] -> failwith "empty block" +let rec anf_pat (P_aux (p_aux, _) as pat) = + match p_aux with + | P_id id -> AP_id id + | P_wild -> AP_wild + | P_tup pats -> AP_tup (List.map anf_pat pats) + | _ -> assert false + let rec anf (E_aux (e_aux, exp_annot) as exp) = let to_aval = function | AE_val v -> (v, fun x -> x) - | AE_app (_, _, typ) | AE_let (_, _, _, _, typ) | AE_return (_, typ) | AE_cast (_, typ) as aexp -> + | AE_app (_, _, typ) + | AE_let (_, _, _, _, typ) + | AE_return (_, typ) + | AE_cast (_, typ) + | AE_if (_, _, _, typ) + | AE_field (_, _, typ) + | AE_case (_, _, typ) + | AE_record_update (_, _, typ) + as aexp -> let id = gensym () in (AV_id (id, Local (Immutable, typ)), fun x -> AE_let (id, typ, aexp, x, typ_of exp)) + | AE_assign _ | AE_block _ | AE_for _ | AE_loop _ as aexp -> + let id = gensym () in + (AV_id (id, Local (Immutable, unit_typ)), fun x -> AE_let (id, unit_typ, aexp, x, typ_of exp)) in match e_aux with | E_lit lit -> ae_lit lit (typ_of exp) @@ -253,6 +317,15 @@ let rec anf (E_aux (e_aux, exp_annot) as exp) = let aexp = anf exp in AE_assign (id, lvar_typ (Env.lookup_id id (env_of exp)), aexp) + | E_loop (loop_typ, cond, exp) -> + let acond = anf cond in + let aexp = anf exp in + AE_loop (loop_typ, acond, aexp) + + | E_for (id, exp1, exp2, exp3, order, body) -> + let aexp1, aexp2, aexp3, abody = anf exp1, anf exp2, anf exp3, anf body in + AE_for (id, aexp1, aexp2, aexp3, order, abody) + | E_if (cond, then_exp, else_exp) -> let cond_val, wrap = to_aval (anf cond) in let then_aexp = anf then_exp in @@ -263,7 +336,34 @@ let rec anf (E_aux (e_aux, exp_annot) as exp) = anf (E_aux (E_app (Id_aux (DeIid op, l), [x; y]), exp_annot)) | E_app_infix (x, Id_aux (DeIid op, l), y) -> anf (E_aux (E_app (Id_aux (Id op, l), [x; y]), exp_annot)) - + + | E_vector exps -> + let aexps = List.map anf exps in + let avals = List.map to_aval aexps in + let wrap = List.fold_left (fun f g x -> f (g x)) (fun x -> x) (List.map snd avals) in + wrap (AE_val (AV_vector (List.map fst avals, typ_of exp))) + + | E_list exps -> + let aexps = List.map anf exps in + let avals = List.map to_aval aexps in + let wrap = List.fold_left (fun f g x -> f (g x)) (fun x -> x) (List.map snd avals) in + wrap (AE_val (AV_list (List.map fst avals, typ_of exp))) + + | E_field (exp, id) -> + let aval, wrap = to_aval (anf exp) in + wrap (AE_field (aval, id, typ_of exp)) + + | E_record_update (exp, FES_aux (FES_Fexps (fexps, _), _)) -> + let anf_fexp (FE_aux (FE_Fexp (id, exp), _)) = + let aval, wrap = to_aval (anf exp) in + (id, aval), wrap + in + let aval, exp_wrap = to_aval (anf exp) in + let fexps = List.map anf_fexp fexps in + let wrap = List.fold_left (fun f g x -> f (g x)) (fun x -> x) (List.map snd fexps) in + let record = List.fold_left (fun r (id, aval) -> Bindings.add id aval r) Bindings.empty (List.map fst fexps) in + exp_wrap (wrap (AE_record_update (aval, record, typ_of exp))) + | E_app (id, exps) -> let aexps = List.map anf exps in let avals = List.map to_aval aexps in @@ -301,17 +401,36 @@ let rec anf (E_aux (e_aux, exp_annot) as exp) = | E_id id -> let lvar = Env.lookup_id id (env_of exp) in - AE_val (AV_id (zencode_id id, lvar)) + AE_val (AV_id (id, lvar)) + + | E_ref id -> + let lvar = Env.lookup_id id (env_of exp) in + AE_val (AV_ref (id, lvar)) | E_return exp -> let aval, wrap = to_aval (anf exp) in wrap (AE_return (aval, typ_of exp)) + | E_case (match_exp, pexps) -> + let match_aval, match_wrap = to_aval (anf match_exp) in + let anf_pexp (Pat_aux (pat_aux, _)) = + match pat_aux with + | Pat_when (pat, guard, body) -> + (anf_pat pat, anf guard, anf body) + | Pat_exp (pat, body) -> + (anf_pat pat, AE_val (AV_lit (mk_lit (L_true), bool_typ)), anf body) + in + match_wrap (AE_case (match_aval, List.map anf_pexp pexps, typ_of exp)) + | E_var (LEXP_aux (LEXP_id id, _), binding, body) + | E_var (LEXP_aux (LEXP_cast (_, id), _), binding, body) | E_let (LB_aux (LB_val (P_aux (P_id id, _), binding), _), body) -> let env = env_of body in let lvar = Env.lookup_id id env in - AE_let (zencode_id id, lvar_typ lvar, anf binding, anf body, typ_of exp) + AE_let (id, lvar_typ lvar, anf binding, anf body, typ_of exp) + + | E_let (LB_aux (LB_val (pat, binding), _), body) -> + anf (E_aux (E_case (binding, [Pat_aux (Pat_exp (pat, body), (Parse_ast.Unknown, None))]), exp_annot)) | E_tuple exps -> let aexps = List.map anf exps in @@ -339,10 +458,8 @@ let rec anf (E_aux (e_aux, exp_annot) as exp) = | E_nondet _ -> (* We don't compile E_nondet nodes *) failwith "encountered E_nondet node when converting to ANF" - - (* + | _ -> failwith ("Cannot convert to ANF: " ^ string_of_exp exp) - *) (**************************************************************************) (* 2. Converting sail types to C types *) @@ -375,8 +492,23 @@ type ctyp = | CT_struct of id * ctyp Bindings.t | CT_enum of id * IdSet.t | CT_variant of id * ctyp Bindings.t - -let ctyp_equal ctyp1 ctyp2 = + | CT_string + +type ctx = + { records : (ctyp Bindings.t) Bindings.t; + enums : IdSet.t Bindings.t; + variants : (ctyp Bindings.t) Bindings.t; + tc_env : Env.t + } + +let initial_ctx env = + { records = Bindings.empty; + enums = Bindings.empty; + variants = Bindings.empty; + tc_env = env + } + +let rec ctyp_equal ctyp1 ctyp2 = match ctyp1, ctyp2 with | CT_mpz, CT_mpz -> true | CT_bv d1, CT_bv d2 -> d1 = d2 @@ -385,9 +517,14 @@ let ctyp_equal ctyp1 ctyp2 = | CT_int64, CT_int64 -> true | CT_unit, CT_unit -> true | CT_bool, CT_bool -> true + | CT_struct (id1, _), CT_struct (id2, _) -> Id.compare id1 id2 = 0 + | CT_enum (id1, _), CT_enum (id2, _) -> Id.compare id1 id2 = 0 + | CT_variant (id1, _), CT_variant (id2, _) -> Id.compare id1 id2 = 0 + | CT_tup ctyps1, CT_tup ctyps2 -> List.for_all2 ctyp_equal ctyps1 ctyps2 + | CT_string, CT_string -> true | _, _ -> false -let string_of_ctyp = function +let rec string_of_ctyp = function | CT_mpz -> "mpz_t" | CT_bv true -> "bv_t<dec>" | CT_bv false -> "bv_t<inc>" @@ -397,9 +534,12 @@ let string_of_ctyp = function | CT_int -> "int" | CT_unit -> "unit" | CT_bool -> "bool" + | CT_tup ctyps -> "(" ^ Util.string_of_list ", " string_of_ctyp ctyps ^ ")" + | CT_struct (id, _) | CT_enum (id, _) | CT_variant (id, _) -> string_of_id id + | CT_string -> "string" (* Convert a sail type into a C-type *) -let ctyp_of_typ (Typ_aux (typ_aux, _) as typ) = +let rec ctyp_of_typ ctx (Typ_aux (typ_aux, _) as typ) = match typ_aux with | Typ_id id when string_of_id id = "bit" -> CT_int | Typ_id id when string_of_id id = "bool" -> CT_bool @@ -426,13 +566,24 @@ let ctyp_of_typ (Typ_aux (typ_aux, _) as typ) = | _ -> CT_bv direction end | Typ_id id when string_of_id id = "unit" -> CT_unit + | Typ_id id when string_of_id id = "string" -> CT_string + + | Typ_id id when Bindings.mem id ctx.records -> CT_struct (id, Bindings.find id ctx.records) + | Typ_id id when Bindings.mem id ctx.enums -> CT_enum (id, Bindings.find id ctx.enums) + | Typ_id id when Bindings.mem id ctx.variants -> CT_variant (id, Bindings.find id ctx.variants) + + | Typ_tup typs -> CT_tup (List.map (ctyp_of_typ ctx) typs) + | _ -> failwith ("No C-type for type " ^ string_of_typ typ) -let is_stack_ctyp ctyp = match ctyp with - | CT_uint64 _ | CT_int64 | CT_int | CT_unit | CT_bool -> true - | CT_bv _ | CT_mpz -> false +let rec is_stack_ctyp ctyp = match ctyp with + | CT_uint64 _ | CT_int64 | CT_int | CT_unit | CT_bool | CT_enum _ -> true + | CT_bv _ | CT_mpz | CT_string _ -> false + | CT_struct (_, fields) -> Bindings.for_all (fun _ ctyp -> is_stack_ctyp ctyp) fields + | CT_variant (_, ctors) -> Bindings.for_all (fun _ ctyp -> is_stack_ctyp ctyp) ctors + | CT_tup ctyps -> List.for_all is_stack_ctyp ctyps -let is_stack_typ typ = is_stack_ctyp (ctyp_of_typ typ) +let is_stack_typ ctx typ = is_stack_ctyp (ctyp_of_typ ctx typ) (**************************************************************************) (* 3. Optimization of primitives and literals *) @@ -445,19 +596,20 @@ let literal_to_cstring (L_aux (l_aux, _) as lit) = | L_hex str when String.length str <= 16 -> let padding = 16 - String.length str in Some ("0x" ^ String.make padding '0' ^ str ^ "ul") - | L_unit -> Some "0" + | L_unit -> Some "UNIT" | L_true -> Some "true" | L_false -> Some "false" | _ -> None -let c_literals = - let c_literal = function - | AV_lit (lit, typ) as v when is_stack_ctyp (ctyp_of_typ typ) -> +let c_literals ctx = + let rec c_literal = function + | AV_lit (lit, typ) as v when is_stack_ctyp (ctyp_of_typ ctx typ) -> begin match literal_to_cstring lit with | Some str -> AV_C_fragment (str, typ) | None -> v end + | AV_tuple avals -> AV_tuple (List.map c_literal avals) | v -> v in map_aval c_literal @@ -471,7 +623,7 @@ let mask m = else failwith "Tried to create a mask literal for a vector greater than 64 bits." -let c_aval = function +let rec c_aval ctx = function | AV_lit (lit, typ) as v -> begin match literal_to_cstring lit with @@ -483,11 +635,11 @@ let c_aval = function | AV_id (id, lvar) as v -> begin match lvar with - | Local (_, typ) when is_stack_typ typ -> - AV_C_fragment (string_of_id id, typ) + | Local (_, typ) when is_stack_typ ctx typ -> + AV_C_fragment (Util.zencode_string (string_of_id id), typ) | _ -> v end - | AV_tuple avals -> AV_tuple avals + | AV_tuple avals -> AV_tuple (List.map (c_aval ctx) avals) let is_c_fragment = function | AV_C_fragment _ -> true @@ -497,12 +649,13 @@ let c_fragment_string = function | AV_C_fragment (str, _) -> str | _ -> assert false -let analyze_primop' id args typ = +let analyze_primop' ctx id args typ = let no_change = AE_app (id, args, typ) in (* primops add_range and add_atom *) if string_of_id id = "add_range" || string_of_id id = "add_atom" then begin + prerr_endline "XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX"; let n, m, x, y = match destruct_range typ, args with | Some (n, m), [x; y] -> n, m, x, y | _ -> failwith ("add_range has incorrect return type or arity ^ " ^ string_of_typ typ) @@ -510,16 +663,27 @@ let analyze_primop' id args typ = match nexp_simp n, nexp_simp m with | Nexp_aux (Nexp_constant n, _), Nexp_aux (Nexp_constant m, _) -> if Big_int.less_equal min_int64 n && Big_int.less_equal m max_int64 then - let x, y = c_aval x, c_aval y in + let x, y = c_aval ctx x, c_aval ctx y in if is_c_fragment x && is_c_fragment y then AE_val (AV_C_fragment (c_fragment_string x ^ " + " ^ c_fragment_string y, typ)) else - no_change + (print_endline "QQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQ"; + no_change) else - no_change + (print_endline "YYYYYYYYYYYYYYYYYYYYYYYYYYYYYYY"; + no_change) | _ -> no_change end + else if string_of_id id = "eq_range" || string_of_id id = "eq_atom" then + begin + match List.map (c_aval ctx) args with + | [x; y] when is_c_fragment x && is_c_fragment y -> + AE_val (AV_C_fragment ("(" ^ c_fragment_string x ^ " == " ^ c_fragment_string y ^ ")", typ)) + | _ -> + no_change + end + else if string_of_id id = "xor_vec" then begin let n, x, y = match typ, args with @@ -529,7 +693,7 @@ let analyze_primop' id args typ = in match nexp_simp n with | Nexp_aux (Nexp_constant n, _) when Big_int.less_equal n (Big_int.of_int 64) -> - let x, y = c_aval x, c_aval y in + let x, y = c_aval ctx x, c_aval ctx y in if is_c_fragment x && is_c_fragment y then AE_val (AV_C_fragment (c_fragment_string x ^ " ^ " ^ c_fragment_string y, typ)) else @@ -546,7 +710,7 @@ let analyze_primop' id args typ = in match nexp_simp n with | Nexp_aux (Nexp_constant n, _) when Big_int.less_equal n (Big_int.of_int 64) -> - let x, y = c_aval x, c_aval y in + let x, y = c_aval ctx x, c_aval ctx y in if is_c_fragment x && is_c_fragment y then AE_val (AV_C_fragment ("(" ^ c_fragment_string x ^ " + " ^ c_fragment_string y ^ ") & " ^ mask n, typ)) else @@ -557,15 +721,46 @@ let analyze_primop' id args typ = else no_change -let analyze_primop id args typ = +let analyze_primop ctx id args typ = let no_change = AE_app (id, args, typ) in - try analyze_primop' id args typ with + try analyze_primop' ctx id args typ with | Failure _ -> no_change (**************************************************************************) (* 4. Conversion to low-level AST *) (**************************************************************************) +(** We now define a low-level AST that is only slightly abstracted + away from C. To be succint in comments we usually refer to this as + LLcode rather than low-level AST repeatedly. + + The general idea is ANF expressions are converted into lists of + instructions (type instr) where allocations and deallocations are + now made explicit. ANF values (aval) are mapped to the cval type, + which is even simpler still. Some things are still more abstract + than in C, so the type definitions follow the sail type definition + structure, just with typ (from ast.ml) replaced with + ctyp. Top-level declarations that have no meaning for the backend + are not included at this level. + + The convention used here is that functions of the form compile_X + compile the type X into types in this AST, so compile_aval maps + avals into cvals. Note that the return types for these functions + are often quite complex, and they usually return some tuple + containing setup instructions (to allocate memory for the + expression), cleanup instructions (to deallocate that memory) and + possibly typing information about what has been translated. **) + +type ctype_def = + | CTD_enum of id * IdSet.t + | CTD_record of id * ctyp Bindings.t + | CTD_variant of id * ctyp Bindings.t + +let ctype_def_ctyps = function + | CTD_enum _ -> [] + | CTD_record (_, fields) -> List.map snd (Bindings.bindings fields) + | CTD_variant (_, ctors) -> List.map snd (Bindings.bindings ctors) + type cval = | CV_id of id * ctyp | CV_C_fragment of string * ctyp @@ -574,21 +769,61 @@ let cval_ctyp = function | CV_id (_, ctyp) -> ctyp | CV_C_fragment (_, ctyp) -> ctyp +type clexp = + | CL_id of id + | CL_field of id * id + | CL_addr of clexp + type instr = | I_decl of ctyp * id | I_alloc of ctyp * id | I_init of ctyp * id * cval | I_if of cval * instr list * instr list * ctyp - | I_funcall of id * id * cval list * ctyp - | I_convert of id * ctyp * id * ctyp + | I_funcall of clexp * id * cval list * ctyp + | I_convert of clexp * ctyp * id * ctyp | I_assign of id * cval + | I_copy of clexp * cval | I_clear of ctyp * id - | I_return of id + | I_return of cval + | I_block of instr list | I_comment of string + | I_label of string + | I_goto of string + | I_raw of string + +let rec map_instrs f instr = + match instr with + | I_decl _ | I_alloc _ | I_init _ -> instr + | I_if (cval, instrs1, instrs2, ctyp) -> + I_if (cval, f (List.map (map_instrs f) instrs1), f (List.map (map_instrs f) instrs2), ctyp) + | I_funcall _ | I_convert _ | I_assign _ | I_copy _ | I_clear _ | I_return _ -> instr + | I_block instrs -> I_block (f (List.map (map_instrs f) instrs)) + | I_comment _ | I_label _ | I_goto _ | I_raw _ -> instr type cdef = | CDEF_reg_dec of ctyp * id - | CDEF_fundef of id * id list * instr list + | CDEF_fundef of id * id option * id list * instr list + | CDEF_type of ctype_def + +let rec instr_ctyps = function + | I_decl (ctyp, _) | I_alloc (ctyp, _) | I_clear (ctyp, _) -> [ctyp] + | I_init (ctyp, _, cval) -> [ctyp; cval_ctyp cval] + | I_if (cval, instrs1, instrs2, ctyp) -> + ctyp :: cval_ctyp cval :: List.concat (List.map instr_ctyps instrs1 @ List.map instr_ctyps instrs2) + | I_funcall (_, _, cvals, ctyp) -> + ctyp :: List.map cval_ctyp cvals + | I_convert (_, ctyp1, _, ctyp2) -> [ctyp1; ctyp2] + | I_assign (_, cval) | I_copy (_, cval) -> [cval_ctyp cval] + | I_block instrs -> List.concat (List.map instr_ctyps instrs) + | I_return cval -> [cval_ctyp cval] + | I_comment _ | I_label _ | I_goto _ | I_raw _ -> [] + +let cdef_ctyps = function + | CDEF_reg_dec (ctyp, _) -> [ctyp] + | CDEF_fundef (_, _, _, instrs) -> List.concat (List.map instr_ctyps instrs) + | CDEF_type tdef -> ctype_def_ctyps tdef + +(* For debugging we define a pretty printer for LLcode instructions *) let pp_ctyp ctyp = string (string_of_ctyp ctyp |> Util.yellow |> Util.clear) @@ -596,10 +831,15 @@ let pp_ctyp ctyp = let pp_keyword str = string ((str |> Util.red |> Util.clear) ^ "$") -and pp_cval = function +let pp_cval = function | CV_id (id, ctyp) -> parens (pp_ctyp ctyp) ^^ (pp_id id) | CV_C_fragment (str, ctyp) -> parens (pp_ctyp ctyp) ^^ (string (str |> Util.cyan |> Util.clear)) +let rec pp_clexp = function + | CL_id id -> pp_id id + | CL_field (id, field) -> pp_id id ^^ string "." ^^ pp_id field + | CL_addr clexp -> string "*" ^^ pp_clexp clexp + let rec pp_instr = function | I_decl (ctyp, id) -> parens (pp_ctyp ctyp) ^^ space ^^ pp_id id @@ -609,64 +849,116 @@ let rec pp_instr = function ^^ pp_keyword "IF" ^^ pp_cval cval ^^ pp_keyword "THEN" ^^ pp_if_block then_instrs ^^ pp_keyword "ELSE" ^^ pp_if_block else_instrs + | I_block instrs -> + surround 2 0 lbrace (separate_map hardline pp_instr instrs) rbrace | I_alloc (ctyp, id) -> pp_keyword "ALLOC" ^^ parens (pp_ctyp ctyp) ^^ space ^^ pp_id id | I_init (ctyp, id, cval) -> pp_keyword "INIT" ^^ pp_ctyp ctyp ^^ parens (pp_id id ^^ string ", " ^^ pp_cval cval) | I_funcall (x, f, args, ctyp2) -> - separate space [ pp_id x; string ":="; + separate space [ pp_clexp x; string ":="; pp_id ~color:Util.red f ^^ parens (separate_map (string ", ") pp_cval args); string "->"; pp_ctyp ctyp2 ] | I_convert (x, ctyp1, y, ctyp2) -> - separate space [ pp_id x; string ":="; + separate space [ pp_clexp x; string ":="; pp_keyword "CONVERT" ^^ pp_ctyp ctyp2 ^^ parens (pp_id y); string "->"; pp_ctyp ctyp1 ] | I_assign (id, cval) -> separate space [pp_id id; string ":="; pp_cval cval] + | I_copy (clexp, cval) -> + separate space [string "let"; pp_clexp clexp; string "="; pp_cval cval] | I_clear (ctyp, id) -> pp_keyword "CLEAR" ^^ pp_ctyp ctyp ^^ parens (pp_id id) - | I_return id -> - pp_keyword "RETURN" ^^ pp_id id + | I_return cval -> + pp_keyword "RETURN" ^^ pp_cval cval | I_comment str -> string ("// " ^ str) - -let compile_funcall env id args typ = + | I_label str -> + string (str ^ ":") + | I_goto str -> + pp_keyword "GOTO" ^^ string str + | I_raw str -> + pp_keyword "C" ^^ string str + +let is_ct_enum = function + | CT_enum _ -> true + | _ -> false + +let is_ct_tup = function + | CT_tup _ -> true + | _ -> false + +let rec compile_aval ctx = function + | AV_C_fragment (code, typ) -> + [], CV_C_fragment (code, ctyp_of_typ ctx typ), [] + + | AV_id (id, typ) -> + begin + match ctyp_of_typ ctx (lvar_typ typ) with + | CT_enum (_, elems) when IdSet.mem id elems -> + [], CV_C_fragment (Util.zencode_upper_string (string_of_id id), ctyp_of_typ ctx (lvar_typ typ)), [] + | _ -> + [], CV_id (id, ctyp_of_typ ctx (lvar_typ typ)), [] + end + + | AV_lit (L_aux (L_string str, _), typ) -> + [], CV_C_fragment ("\"" ^ str ^ "\"", ctyp_of_typ ctx typ), [] + + | AV_lit (L_aux (L_num n, _), typ) when Big_int.less_equal min_int64 n && Big_int.less_equal n max_int64 -> + let gs = gensym () in + [I_decl (CT_mpz, gs); + I_init (CT_mpz, gs, CV_C_fragment (Big_int.to_string n ^ "L", CT_int64))], + CV_id (gs, CT_mpz), + [I_clear (CT_mpz, gs)] + + | AV_lit (L_aux (L_num n, _), typ) -> + let gs = gensym () in + [ I_decl (CT_mpz, gs); + I_init (CT_mpz, gs, CV_C_fragment ("\"" ^ Big_int.to_string n ^ "\"", CT_string)) ], + CV_id (gs, CT_mpz), + [I_clear (CT_mpz, gs)] + + | AV_tuple avals -> + let elements = List.map (compile_aval ctx) avals in + let cvals = List.map (fun (_, cval, _) -> cval) elements in + let setup = List.concat (List.map (fun (setup, _, _) -> setup) elements) in + let cleanup = List.concat (List.rev (List.map (fun (_, _, cleanup) -> cleanup) elements)) in + let tup_ctyp = CT_tup (List.map cval_ctyp cvals) in + let gs = gensym () in + setup + @ [I_decl (tup_ctyp, gs)] + @ List.mapi (fun n cval -> I_copy (CL_field (gs, mk_id ("tup" ^ string_of_int n)), cval)) cvals, + CV_id (gs, CT_tup (List.map cval_ctyp cvals)), + cleanup + +let compile_funcall ctx id args typ = let setup = ref [] in let cleanup = ref [] in - let _, Typ_aux (fn_typ, _) = Env.get_val_spec id env in + let _, Typ_aux (fn_typ, _) = Env.get_val_spec id ctx.tc_env in let arg_typs, ret_typ = match fn_typ with | Typ_fn (Typ_aux (Typ_tup arg_typs, _), ret_typ, _) -> arg_typs, ret_typ | Typ_fn (arg_typ, ret_typ, _) -> [arg_typ], ret_typ | _ -> assert false in - let arg_ctyps, ret_ctyp = List.map ctyp_of_typ arg_typs, ctyp_of_typ ret_typ in - let final_ctyp = ctyp_of_typ typ in + let arg_ctyps, ret_ctyp = List.map (ctyp_of_typ ctx) arg_typs, ctyp_of_typ ctx ret_typ in + let final_ctyp = ctyp_of_typ ctx typ in let setup_arg ctyp aval = - match aval with - | AV_C_fragment (c, typ) -> - if is_stack_ctyp ctyp then - CV_C_fragment (c, ctyp_of_typ typ) - else - let gs = gensym () in - setup := I_decl (ctyp, gs) :: !setup; - setup := I_init (ctyp, gs, CV_C_fragment (c, ctyp_of_typ typ)) :: !setup; - cleanup := I_clear (ctyp, gs) :: !cleanup; - CV_id (gs, ctyp) - | AV_id (id, lvar) -> - let have_ctyp = ctyp_of_typ (lvar_typ lvar) in - if ctyp_equal ctyp have_ctyp then - CV_id (id, ctyp) - else if is_stack_ctyp have_ctyp && not (is_stack_ctyp ctyp) then - let gs = gensym () in - setup := I_decl (ctyp, gs) :: !setup; - setup := I_init (ctyp, gs, CV_id (id, have_ctyp)) :: !setup; - cleanup := I_clear (ctyp, gs) :: !cleanup; - CV_id (gs, ctyp) - else - CV_id (mk_id ("????" ^ string_of_ctyp (ctyp_of_typ (lvar_typ lvar))), ctyp) - | _ -> CV_id (mk_id "???", ctyp) + let arg_setup, cval, arg_cleanup = compile_aval ctx aval in + setup := List.rev arg_setup @ !setup; + cleanup := arg_cleanup @ !cleanup; + let have_ctyp = cval_ctyp cval in + if ctyp_equal ctyp have_ctyp then + cval + else if is_stack_ctyp have_ctyp && not (is_stack_ctyp ctyp) then + let gs = gensym () in + setup := I_decl (ctyp, gs) :: !setup; + setup := I_init (ctyp, gs, cval) :: !setup; + cleanup := I_clear (ctyp, gs) :: !cleanup; + CV_id (gs, ctyp) + else + assert false in let sargs = List.map2 setup_arg arg_ctyps args in @@ -677,7 +969,7 @@ let compile_funcall env id args typ = else if not (is_stack_ctyp ret_ctyp) && is_stack_ctyp final_ctyp then let gs = gensym () in setup := I_alloc (ret_ctyp, gs) :: !setup; - setup := I_funcall (gs, id, sargs, ret_ctyp) :: !setup; + setup := I_funcall (CL_id gs, id, sargs, ret_ctyp) :: !setup; cleanup := I_clear (ret_ctyp, gs) :: !cleanup; fun ret -> I_convert (ret, final_ctyp, gs, ret_ctyp) else @@ -686,128 +978,373 @@ let compile_funcall env id args typ = (List.rev !setup, final_ctyp, call, !cleanup) -let rec compile_aexp env = function +let rec compile_match ctx apat cval case_label = + match apat, cval with + | AP_id pid, CV_C_fragment (code, ctyp) when is_ct_enum ctyp -> + [ I_if (CV_C_fragment (Util.zencode_upper_string (string_of_id pid) ^ " != " ^ code, CT_bool), [I_goto case_label], [], CT_unit) ] + | AP_id pid, CV_id (id, ctyp) when is_ct_enum ctyp -> + [ I_if (CV_C_fragment (Util.zencode_upper_string (string_of_id pid) ^ " != " ^ Util.zencode_string (string_of_id id), CT_bool), [I_goto case_label], [], CT_unit) ] + | AP_id pid, CV_C_fragment (code, ctyp) -> + [ I_decl (cval_ctyp cval, pid); I_copy (CL_id pid, cval) ] + | AP_id pid, CV_id _ -> + [ I_decl (cval_ctyp cval, pid); I_copy (CL_id pid, cval) ] + | AP_tup apats, CV_id (id, ctyp) -> + begin + let get_tup n ctyp = CV_C_fragment (Util.zencode_string (string_of_id id) ^ ".ztup" ^ string_of_int n, ctyp) in + match ctyp with + | CT_tup ctyps -> + fst (List.fold_left2 (fun (instrs, n) apat ctyp -> instrs @ compile_match ctx apat (get_tup n ctyp) case_label, n + 1) ([], 0) apats ctyps) + | _ -> assert false + end + | _, _ -> [] + +let unit_fragment = CV_C_fragment ("UNIT", CT_unit) + +(** GLOBAL: label_counter is used to make sure all labels have unique + names. Like gensym_counter it should be safe to reset between + top-level definitions. **) +let label_counter = ref 0 + +let label str = + let str = str ^ string_of_int !label_counter in + incr label_counter; + str + +let rec compile_aexp ctx = function | AE_let (id, _, binding, body, typ) -> - let setup, ctyp, call, cleanup = compile_aexp env binding in + let setup, ctyp, call, cleanup = compile_aexp ctx binding in let letb1, letb1c = if is_stack_ctyp ctyp then - [I_decl (ctyp, id); call id], [] + [I_decl (ctyp, id); call (CL_id id)], [] else - [I_alloc (ctyp, id); call id], [I_clear (ctyp, id)] + [I_alloc (ctyp, id); call (CL_id id)], [I_clear (ctyp, id)] in let letb2 = setup @ letb1 @ cleanup in - let setup, ctyp, call, cleanup = compile_aexp env body in + let setup, ctyp, call, cleanup = compile_aexp ctx body in letb2 @ setup, ctyp, call, cleanup @ letb1c | AE_app (id, vs, typ) -> - compile_funcall env id vs typ - - | AE_val (AV_C_fragment (c, typ)) -> - let ctyp = ctyp_of_typ typ in - [], ctyp, (fun id -> I_assign (id, CV_C_fragment (c, ctyp))), [] - - | AE_val (AV_id (id, lvar)) -> - let ctyp = ctyp_of_typ (lvar_typ lvar) in - [], ctyp, (fun id' -> I_assign (id', CV_id (id, ctyp))), [] - - | AE_val (AV_lit (lit, typ)) -> - let ctyp = ctyp_of_typ typ in - if is_stack_ctyp ctyp then - assert false - else + compile_funcall ctx id vs typ + + | AE_val aval -> + let setup, cval, cleanup = compile_aval ctx aval in + setup, cval_ctyp cval, (fun clexp -> I_copy (clexp, cval)), cleanup + + (* Compile case statements *) + | AE_case (aval, cases, typ) -> + let ctyp = ctyp_of_typ ctx typ in + let aval_setup, cval, aval_cleanup = compile_aval ctx aval in + let case_return_id = gensym () in + let finish_match_label = label "finish_match_" in + let compile_case (apat, guard, body) = + let trivial_guard = match guard with + | AE_val (AV_lit (L_aux (L_true, _), _)) + | AE_val (AV_C_fragment ("true", _)) -> true + | _ -> false + in + let case_label = label "case_" in + let destructure = compile_match ctx apat cval case_label in + let guard_setup, _, guard_call, guard_cleanup = compile_aexp ctx guard in + let body_setup, _, body_call, body_cleanup = compile_aexp ctx body in let gs = gensym () in - [I_alloc (ctyp, gs); I_comment "fix literal init"], - ctyp, - (fun id -> I_assign (id, CV_id (gs, ctyp))), - [I_clear (ctyp, gs)] + let case_instrs = + destructure @ [I_comment "end destructuring"] + @ (if not trivial_guard then + guard_setup @ [I_decl (CT_bool, gs); guard_call (CL_id gs)] @ guard_cleanup + @ [I_if (CV_C_fragment (Printf.sprintf "!%s" (Util.zencode_string (string_of_id gs)), CT_bool), [I_goto case_label], [], CT_unit)] + @ [I_comment "end guard"] + else []) + @ body_setup @ [body_call (CL_id case_return_id)] @ body_cleanup + @ [I_goto finish_match_label] + in + [I_block case_instrs; I_label case_label] + in + [I_comment "begin match"] + @ aval_setup @ [I_decl (ctyp, case_return_id)] + @ List.concat (List.map compile_case cases) + @ [I_raw "sail_match_failure();"] + @ [I_label finish_match_label], + ctyp, + (fun clexp -> I_copy (clexp, CV_id (case_return_id, ctyp))), + aval_cleanup + @ [I_comment "end match"] | AE_if (aval, then_aexp, else_aexp, if_typ) -> - let if_ctyp = ctyp_of_typ if_typ in + let if_ctyp = ctyp_of_typ ctx if_typ in let compile_branch aexp = - let setup, ctyp, call, cleanup = compile_aexp env aexp in - fun id -> setup @ [call id] @ cleanup + let setup, ctyp, call, cleanup = compile_aexp ctx aexp in + fun clexp -> setup @ [call clexp] @ cleanup in - let setup, ctyp, call, cleanup = compile_aexp env (AE_val aval) in + let setup, ctyp, call, cleanup = compile_aexp ctx (AE_val aval) in let gs = gensym () in - setup @ [I_decl (ctyp, gs); call gs], + setup @ [I_decl (ctyp, gs); call (CL_id gs)], if_ctyp, - (fun id -> I_if (CV_id (gs, ctyp), - compile_branch then_aexp id, - compile_branch else_aexp id, - if_ctyp)), + (fun clexp -> I_if (CV_id (gs, ctyp), + compile_branch then_aexp clexp, + compile_branch else_aexp clexp, + if_ctyp)), cleanup - + + | AE_record_update (aval, fields, typ) -> + let update_field (prev_setup, prev_calls, prev_cleanup) (field, aval) = + let setup, _, call, cleanup = compile_aexp ctx (AE_val aval) in + prev_setup @ setup, call :: prev_calls, cleanup @ prev_cleanup + in + let setup, calls, cleanup = List.fold_left update_field ([], [], []) (Bindings.bindings fields) in + let ctyp = ctyp_of_typ ctx typ in + let gs = gensym () in + [I_alloc (ctyp, gs)] @ setup @ List.map (fun call -> call (CL_id gs)) calls, + ctyp, + (fun clexp -> I_copy (clexp, CV_id (gs, ctyp))), + cleanup @ [I_clear (ctyp, gs)] + | AE_assign (id, assign_typ, aexp) -> (* assign_ctyp is the type of the C variable we are assigning to, ctyp is the type of the C expression being assigned. These may be different. *) - let assign_ctyp = ctyp_of_typ assign_typ in - let setup, ctyp, call, cleanup = compile_aexp env aexp in - let unit_fragment = CV_C_fragment ("0", CT_unit) in + let assign_ctyp = ctyp_of_typ ctx assign_typ in + let setup, ctyp, call, cleanup = compile_aexp ctx aexp in let comment = "assign " ^ string_of_ctyp assign_ctyp ^ " := " ^ string_of_ctyp ctyp in if ctyp_equal assign_ctyp ctyp then - setup @ [call id], CT_unit, (fun id -> I_assign (id, unit_fragment)), cleanup + setup @ [call (CL_id id)], CT_unit, (fun clexp -> I_copy (clexp, unit_fragment)), cleanup else if not (is_stack_ctyp assign_ctyp) && is_stack_ctyp ctyp then let gs = gensym () in setup @ [ I_comment comment; I_decl (ctyp, gs); - call gs; - I_convert (id, assign_ctyp, gs, ctyp) + call (CL_id gs); + I_convert (CL_id id, assign_ctyp, gs, ctyp) ], CT_unit, - (fun id -> I_assign (id, unit_fragment)), + (fun clexp -> I_copy (clexp, unit_fragment)), cleanup else - failwith ("Failure: " ^ comment) - + failwith comment + | AE_block (aexps, aexp, _) -> - let block = compile_block env aexps in - let setup, ctyp, call, cleanup = compile_aexp env aexp in + let block = compile_block ctx aexps in + let setup, ctyp, call, cleanup = compile_aexp ctx aexp in block @ setup, ctyp, call, cleanup - | AE_cast (aexp, typ) -> compile_aexp env aexp - -and compile_block env = function + | AE_loop (While, cond, body) -> + let loop_start_label = label "while_" in + let loop_end_label = label "wend_" in + let cond_setup, _, cond_call, cond_cleanup = compile_aexp ctx cond in + let body_setup, _, body_call, body_cleanup = compile_aexp ctx body in + let gs = gensym () in + let unit_gs = gensym () in + let loop_test = CV_C_fragment (Printf.sprintf "!%s" (Util.zencode_string (string_of_id gs)), CT_bool) in + cond_setup @ [I_decl (CT_bool, gs); I_decl (CT_unit, unit_gs)] + @ [I_label loop_start_label] + @ [I_block ([cond_call (CL_id gs); I_if (loop_test, [I_goto loop_end_label], [], CT_unit)] + @ body_setup + @ [body_call (CL_id unit_gs)] + @ body_cleanup + @ [I_goto loop_start_label])] + @ [I_label loop_end_label], + CT_unit, + (fun clexp -> I_copy (clexp, unit_fragment)), + cond_cleanup + + | AE_cast (aexp, typ) -> compile_aexp ctx aexp + + | AE_return (aval, typ) -> + (* Cleanup info will be re-added by fix_early_return *) + let return_setup, cval, _ = compile_aval ctx aval in + return_setup @ [I_return cval], + CT_unit, + (fun clexp -> I_copy (clexp, unit_fragment)), + [] + + | aexp -> failwith ("Cannot compile ANF expression: " ^ Pretty_print_sail.to_string (pp_aexp aexp)) + +and compile_block ctx = function | [] -> [] | exp :: exps -> - let setup, _, call, cleanup = compile_aexp env exp in - let rest = compile_block env exps in + let setup, _, call, cleanup = compile_aexp ctx exp in + let rest = compile_block ctx exps in let gs = gensym () in - setup @ [I_decl (CT_unit, gs); call gs] @ cleanup @ rest + setup @ [I_decl (CT_unit, gs); call (CL_id gs)] @ cleanup @ rest -let rec pat_ids (P_aux (p_aux, _)) = +let rec pat_ids (P_aux (p_aux, _) as pat) = match p_aux with | P_id id -> [id] | P_tup pats -> List.concat (List.map pat_ids pats) - | _ -> failwith "Bad pattern" + | P_lit (L_aux (L_unit, _)) -> let gs = gensym () in [gs] + | P_wild -> let gs = gensym () in [gs] + | _ -> failwith ("Bad pattern " ^ string_of_pat pat) + +(** Compile a sail type definition into a LLcode one. Most of the + actual work of translating the typedefs into C is done by the code + generator, as it's easy to keep track of structs, tuples and unions + in their sail form at this level, and leave the fiddly details of + how they get mapped to C in the next stage. This function also adds + details of the types it compiles to the context, ctx, which is why + it returns a ctypdef * ctx pair. **) +let compile_type_def ctx (TD_aux (type_def, _)) = + match type_def with + | TD_enum (id, _, ids, _) -> + CTD_enum (id, IdSet.of_list ids), + { ctx with enums = Bindings.add id (IdSet.of_list ids) ctx.enums } + + | TD_record (id, _, _, ctors, _) -> + let ctors = List.fold_left (fun ctors (typ, id) -> Bindings.add id (ctyp_of_typ ctx typ) ctors) Bindings.empty ctors in + CTD_record (id, ctors), + { ctx with records = Bindings.add id ctors ctx.records } + + | TD_variant (id, _, _, tus, _) -> + let compile_tu (Tu_aux (tu_aux, _)) = + match tu_aux with + | Tu_id id -> CT_unit, id + | Tu_ty_id (typ, id) -> ctyp_of_typ ctx typ, id + in + let ctus = List.fold_left (fun ctus (ctyp, id) -> Bindings.add id ctyp ctus) Bindings.empty (List.map compile_tu tus) in + CTD_variant (id, ctus), + { ctx with variants = Bindings.add id ctus ctx.variants } + + (* Will be re-written before here, see bitfield.ml *) + | TD_bitfield _ -> failwith "Cannot compile TD_bitfield" + (* All type abbreviations are filtered out in compile_def *) + | TD_abbrev _ -> assert false + +let instr_split_at f = + let rec instr_split_at' f before = function + | [] -> (List.rev before, []) + | instr :: instrs when f instr -> (List.rev before, instr :: instrs) + | instr :: instrs -> instr_split_at' f (instr :: before) instrs + in + instr_split_at' f [] -let compile_def env = function +let generate_cleanup instrs = + let generate_cleanup' = function + | I_decl (ctyp, id) when not (is_stack_ctyp ctyp) -> [(id, I_clear (ctyp, id))] + | I_alloc (ctyp, id) when not (is_stack_ctyp ctyp) -> [(id, I_clear (ctyp, id))] + | _ -> [] + in + let is_clear ids = function + | I_clear (_, id) -> IdSet.add id ids + | _ -> ids + in + let cleaned = List.fold_left is_clear IdSet.empty instrs in + instrs + |> List.map generate_cleanup' + |> List.concat + |> List.filter (fun (id, _) -> not (IdSet.mem id cleaned)) + |> List.map snd + +(** Functions that have heap-allocated return types are implemented by + passing a pointer a location where the return value should be + stored. The ANF -> LLcode pass for expressions simply outputs an + I_return instruction for any return value, so this function walks + over the LLcode ast for expressions and modifies the return + statements into code that sets that pointer, as well as adds extra + control flow to cleanup heap-allocated variables correctly when a + function terminates early. See the generate_cleanup function for + how this is done. *) +let fix_early_return ret ctx instrs = + let end_function_label = label "end_function_" in + let is_return_recur = function + | I_return _ | I_if _ | I_block _ -> true + | _ -> false + in + let rec rewrite_return pre_cleanup instrs = + match instr_split_at is_return_recur instrs with + | instrs, [] -> instrs + | before, I_block instrs :: after -> + before + @ [I_block (rewrite_return (pre_cleanup @ generate_cleanup before) instrs)] + @ rewrite_return pre_cleanup after + | before, I_if (cval, then_instrs, else_instrs, ctyp) :: after -> + let cleanup = pre_cleanup @ generate_cleanup before in + before + @ [I_if (cval, rewrite_return cleanup then_instrs, rewrite_return cleanup else_instrs, ctyp)] + @ rewrite_return pre_cleanup after + | before, I_return cval :: after -> + let cleanup_label = label "cleanup_" in + let end_cleanup_label = label "end_cleanup_" in + before + @ [I_copy (ret, cval); + I_goto cleanup_label] + (* This is probably dead code until cleanup_label, but how can we be sure there are no jumps into it? *) + @ rewrite_return pre_cleanup after + @ [I_goto end_cleanup_label] + @ [I_label cleanup_label] + @ pre_cleanup + @ generate_cleanup before + @ [I_goto end_function_label] + @ [I_label end_cleanup_label] + | _, _ -> assert false + in + rewrite_return [] instrs + @ [I_label end_function_label] + +(** Compile a Sail toplevel definition into an LLcode definition **) +let compile_def ctx = function | DEF_reg_dec (DEC_aux (DEC_reg (typ, id), _)) -> - [CDEF_reg_dec (ctyp_of_typ typ, id)] + [CDEF_reg_dec (ctyp_of_typ ctx typ, id)], ctx | DEF_reg_dec _ -> failwith "Unsupported register declaration" (* FIXME *) - | DEF_spec _ -> [] + | DEF_spec _ -> [], ctx | DEF_fundef (FD_aux (FD_function (_, _, _, [FCL_aux (FCL_Funcl (id, pexp), _)]), _)) -> begin match pexp with | Pat_aux (Pat_exp (pat, exp), _) -> - let aexp = map_functions analyze_primop (c_literals (anf exp)) in - print_endline (Pretty_print_sail.to_string (pp_aexp aexp)); - let setup, ctyp, call, cleanup = compile_aexp env aexp in + let aexp = map_functions (analyze_primop ctx) (c_literals ctx (anf exp)) in + prerr_endline (Pretty_print_sail.to_string (pp_aexp aexp)); + let setup, ctyp, call, cleanup = compile_aexp ctx aexp in let gs = gensym () in - let instrs = - if is_stack_ctyp ctyp then - setup @ [I_decl (ctyp, gs); call gs] @ cleanup @ [I_return gs] - else - assert false - in - [CDEF_fundef (id, pat_ids pat, instrs)] + if is_stack_ctyp ctyp then + let instrs = [I_decl (ctyp, gs)] @ setup @ [call (CL_id gs)] @ cleanup @ [I_return (CV_id (gs, ctyp))] in + [CDEF_fundef (id, None, pat_ids pat, instrs)], ctx + else + let instrs = setup @ [call (CL_addr (CL_id gs))] @ cleanup in + let instrs = fix_early_return (CL_addr (CL_id gs)) ctx instrs in + [CDEF_fundef (id, Some gs, pat_ids pat, instrs)], ctx | _ -> assert false end - | DEF_default _ -> [] - + (* All abbreviations should expanded by the typechecker, so we don't + need to translate type abbreviations into C typedefs. *) + | DEF_type (TD_aux (TD_abbrev _, _)) -> [], ctx + + | DEF_type type_def -> + let tdef, ctx = compile_type_def ctx type_def in + [CDEF_type tdef], ctx + + (* Only DEF_default that matters is default Order, but all order + polymorphism is specialised by this point. *) + | DEF_default _ -> [], ctx + + (* Overloading resolved by type checker *) + | DEF_overload _ -> [], ctx + + (* Only the parser and sail pretty printer care about this. *) + | DEF_fixity _ -> [], ctx + + | _ -> assert false + +(** To keep things neat we use GCC's local labels extension to limit + the scope of labels. We do this by iterating over all the blocks + and adding a __label__ declaration with all the labels local to + that block. The add_local_labels function is called by the code + generator just before it outputs C. + + See https://gcc.gnu.org/onlinedocs/gcc/Local-Labels.html **) +let add_local_labels' instrs = + let is_label = function + | I_label str -> [str] + | _ -> [] + in + let labels = List.concat (List.map is_label instrs) in + let local_label_decl = I_raw ("__label__ " ^ String.concat ", " labels ^ ";\n") in + if labels = [] then + instrs + else + local_label_decl :: instrs + +let add_local_labels instrs = + match map_instrs add_local_labels' (I_block instrs) with + | I_block instrs -> instrs | _ -> assert false (**************************************************************************) @@ -817,114 +1354,291 @@ let compile_def env = function let sgen_id id = Util.zencode_string (string_of_id id) let codegen_id id = string (sgen_id id) +let upper_sgen_id id = Util.zencode_upper_string (string_of_id id) +let upper_codegen_id id = string (upper_sgen_id id) + let sgen_ctyp = function - | CT_unit -> "int" + | CT_unit -> "unit" + | CT_int -> "int" + | CT_bool -> "bool" + | CT_uint64 _ -> "uint64_t" + | CT_int64 -> "int64_t" + | CT_mpz -> "mpz_t" + | CT_bv _ -> "bv_t" + | CT_tup _ as tup -> "struct " ^ Util.zencode_string ("tuple_" ^ string_of_ctyp tup) + | CT_struct (id, _) -> "struct " ^ sgen_id id + | CT_enum (id, _) -> "enum " ^ sgen_id id + | CT_variant (id, _) -> "struct " ^ sgen_id id + | CT_string -> "sail_string" + +let sgen_ctyp_name = function + | CT_unit -> "unit" | CT_int -> "int" | CT_bool -> "bool" | CT_uint64 _ -> "uint64_t" | CT_int64 -> "int64_t" | CT_mpz -> "mpz_t" | CT_bv _ -> "bv_t" + | CT_tup _ as tup -> Util.zencode_string ("tuple_" ^ string_of_ctyp tup) + | CT_struct (id, _) -> sgen_id id + | CT_enum (id, _) -> sgen_id id + | CT_variant (id, _) -> sgen_id id + | CT_string -> "sail_string" let sgen_cval = function | CV_C_fragment (c, _) -> c - | CV_id (id, _) -> string_of_id id + | CV_id (id, _) -> sgen_id id | _ -> "CVAL??" -let rec codegen_instr = function +let sgen_clexp = function + | CL_id id -> "&" ^ sgen_id id + | CL_field (id, field) -> "&(" ^ sgen_id id ^ "." ^ sgen_id field ^ ")" + | CL_addr (CL_id id) -> sgen_id id + | _ -> assert false + +let sgen_clexp_pure = function + | CL_id id -> sgen_id id + | CL_field (id, field) -> sgen_id id ^ "." ^ sgen_id field + | _ -> assert false + +let rec codegen_instr ctx = function | I_decl (ctyp, id) -> - string (Printf.sprintf "%s %s;" (sgen_ctyp ctyp) (string_of_id id)) - | I_assign (id, cval) -> + string (Printf.sprintf " %s %s;" (sgen_ctyp ctyp) (sgen_id id)) + | I_copy (clexp, cval) -> let ctyp = cval_ctyp cval in if is_stack_ctyp ctyp then - string (Printf.sprintf "%s = %s;" (string_of_id id) (sgen_cval cval)) + string (Printf.sprintf " %s = %s;" (sgen_clexp_pure clexp) (sgen_cval cval)) else - string (Printf.sprintf "set_%s(%s, %s);" (sgen_ctyp ctyp) (string_of_id id) (sgen_cval cval)) + string (Printf.sprintf " set_%s(%s, %s);" (sgen_ctyp_name ctyp) (sgen_clexp clexp) (sgen_cval cval)) + | I_if (cval, [then_instr], [], ctyp) -> + string (Printf.sprintf " if (%s)" (sgen_cval cval)) ^^ hardline + ^^ twice space ^^ codegen_instr ctx then_instr | I_if (cval, then_instrs, else_instrs, ctyp) -> - string "if" ^^ space ^^ parens (string (sgen_cval cval)) ^^ space - ^^ surround 2 0 lbrace (separate_map hardline codegen_instr then_instrs) rbrace + string " if" ^^ space ^^ parens (string (sgen_cval cval)) ^^ space + ^^ surround 2 0 lbrace (separate_map hardline (codegen_instr ctx) then_instrs) (twice space ^^ rbrace) ^^ space ^^ string "else" ^^ space - ^^ surround 2 0 lbrace (separate_map hardline codegen_instr else_instrs) rbrace + ^^ surround 2 0 lbrace (separate_map hardline (codegen_instr ctx) else_instrs) (twice space ^^ rbrace) + | I_block instrs -> + string " {" + ^^ jump 2 2 (separate_map hardline (codegen_instr ctx) instrs) ^^ hardline + ^^ string " }" | I_funcall (x, f, args, ctyp) -> let args = Util.string_of_list ", " sgen_cval args in + let fname = if Env.is_extern f ctx.tc_env "c" then Env.get_extern f ctx.tc_env "c" else sgen_id f in if is_stack_ctyp ctyp then - string (Printf.sprintf "%s = %s(%s);" (string_of_id x) (sgen_id f) args) + string (Printf.sprintf " %s = %s(%s);" (sgen_clexp_pure x) fname args) else - string (Printf.sprintf "%s(%s, %s);" (sgen_id f) (string_of_id x) args) + string (Printf.sprintf " %s(%s, %s);" fname (sgen_clexp x) args) | I_clear (ctyp, id) -> - string (Printf.sprintf "clear_%s(%s);" (sgen_ctyp ctyp) (string_of_id id)) + string (Printf.sprintf " clear_%s(&%s);" (sgen_ctyp_name ctyp) (sgen_id id)) | I_init (ctyp, id, cval) -> - string (Printf.sprintf "init_%s_of_%s(%s, %s);" - (sgen_ctyp ctyp) - (sgen_ctyp (cval_ctyp cval)) - (string_of_id id) + string (Printf.sprintf " init_%s_of_%s(&%s, %s);" + (sgen_ctyp_name ctyp) + (sgen_ctyp_name (cval_ctyp cval)) + (sgen_id id) (sgen_cval cval)) | I_alloc (ctyp, id) -> - string (Printf.sprintf "%s %s;" (sgen_ctyp ctyp) (string_of_id id)) + string (Printf.sprintf " %s %s;" (sgen_ctyp ctyp) (sgen_id id)) ^^ hardline - ^^ string (Printf.sprintf "init_%s(%s);" (sgen_ctyp ctyp) (string_of_id id)) + ^^ string (Printf.sprintf " init_%s(&%s);" (sgen_ctyp_name ctyp) (sgen_id id)) | I_convert (x, ctyp1, y, ctyp2) -> if is_stack_ctyp ctyp1 then - string (Printf.sprintf "%s = convert_%s_of_%s(%s);" - (string_of_id x) - (sgen_ctyp ctyp1) - (sgen_ctyp ctyp2) - (string_of_id y)) + string (Printf.sprintf " %s = convert_%s_of_%s(%s);" + (sgen_clexp_pure x) + (sgen_ctyp_name ctyp1) + (sgen_ctyp_name ctyp2) + (sgen_id y)) else - string (Printf.sprintf "convert_%s_of_%s(%s, %s);" - (sgen_ctyp ctyp1) - (sgen_ctyp ctyp2) - (string_of_id x) - (string_of_id y)) - | I_return id -> - string (Printf.sprintf "return %s;" (string_of_id id)) + string (Printf.sprintf " convert_%s_of_%s(%s, %s);" + (sgen_ctyp_name ctyp1) + (sgen_ctyp_name ctyp2) + (sgen_clexp x) + (sgen_id y)) + | I_return cval -> + string (Printf.sprintf " return %s;" (sgen_cval cval)) | I_comment str -> - string ("/* " ^ str ^ " */") - -let codegen_def env = function + string (" /* " ^ str ^ " */") + | I_label str -> + string (str ^ ": ;") + | I_goto str -> + string (Printf.sprintf " goto %s;" str) + | I_raw str -> + string (" " ^ str) + +let codegen_type_def ctx = function + | CTD_enum (id, ids) -> + string (Printf.sprintf "// enum %s" (string_of_id id)) ^^ hardline + ^^ separate space [string "enum"; codegen_id id; lbrace; separate_map (comma ^^ space) upper_codegen_id (IdSet.elements ids); rbrace ^^ semi] + + | CTD_record (id, ctors) -> + (* Generate a set_T function for every struct T *) + let codegen_set (id, ctyp) = + if is_stack_ctyp ctyp then + string (Printf.sprintf "rop->%s = op.%s;" (sgen_id id) (sgen_id id)) + else + string (Printf.sprintf "set_%s(&rop->%s, op.%s);" (sgen_ctyp_name ctyp) (sgen_id id) (sgen_id id)) + in + let codegen_setter id ctors = + string (let n = sgen_id id in Printf.sprintf "void set_%s(struct %s *rop, const struct %s op)" n n n) ^^ space + ^^ surround 2 0 lbrace + (separate_map hardline codegen_set (Bindings.bindings ctors)) + rbrace + in + (* Generate an init/clear_T function for every struct T *) + let codegen_field_init f (id, ctyp) = + if not (is_stack_ctyp ctyp) then + [string (Printf.sprintf "%s_%s(&op->%s);" f (sgen_ctyp_name ctyp) (sgen_id id))] + else [] + in + let codegen_init f id ctors = + string (let n = sgen_id id in Printf.sprintf "void %s_%s(struct %s *op)" f n n) ^^ space + ^^ surround 2 0 lbrace + (separate hardline (Bindings.bindings ctors |> List.map (codegen_field_init f) |> List.concat)) + rbrace + in + (* Generate the struct and add the generated functions *) + let codegen_ctor (id, ctyp) = + string (sgen_ctyp ctyp) ^^ space ^^ codegen_id id + in + string (Printf.sprintf "// struct %s" (string_of_id id)) ^^ hardline + ^^ string "struct" ^^ space ^^ codegen_id id ^^ space + ^^ surround 2 0 lbrace + (separate_map (semi ^^ hardline) codegen_ctor (Bindings.bindings ctors) ^^ semi) + rbrace + ^^ semi ^^ twice hardline + ^^ codegen_setter id ctors + ^^ twice hardline + ^^ codegen_init "init" id ctors + ^^ twice hardline + ^^ codegen_init "clear" id ctors + + | CTD_variant (id, tus) -> + let codegen_tu (id, ctyp) = + separate space [string "struct"; lbrace; string (sgen_ctyp ctyp); codegen_id id ^^ semi; rbrace] + in + string (Printf.sprintf "// union %s" (string_of_id id)) ^^ hardline + ^^ string "enum" ^^ space + ^^ string ("kind_" ^ sgen_id id) ^^ space + ^^ separate space [lbrace; separate_map (comma ^^ space) (fun id -> string ("Kind_" ^ sgen_id id)) (List.map fst (Bindings.bindings tus)); rbrace ^^ semi] + ^^ hardline ^^ hardline + ^^ string "struct" ^^ space ^^ codegen_id id ^^ space + ^^ surround 2 0 lbrace + (separate space [string "enum"; string ("kind_" ^ sgen_id id); string "kind" ^^ semi] + ^^ hardline + ^^ string "union" ^^ space + ^^ surround 2 0 lbrace + (separate_map (semi ^^ hardline) codegen_tu (Bindings.bindings tus) ^^ semi) + rbrace + ^^ semi) + rbrace + ^^ semi + +(** GLOBAL: because C doesn't have real anonymous tuple types + (anonymous structs don't quite work the way we need) every tuple + type in the spec becomes some generated named struct in C. This is + done in such a way that every possible tuple type has a unique name + associated with it. This global variable keeps track of these + generated struct names, so we never generate two copies of the + struct that is used to represent them in C. + + The way this works is that codegen_def scans each definition's type + annotations for tuple types and generates the required structs + using codegen_type_def before the actual definition is generated by + codegen_def'. + + This variable should be reset to empty only when the entire AST has + been translated to C. **) +let generated_tuples = ref IdSet.empty + +let codegen_tup ctx ctyps = + let id = mk_id ("tuple_" ^ string_of_ctyp (CT_tup ctyps)) in + if IdSet.mem id !generated_tuples then + empty + else + let _, fields = List.fold_left (fun (n, fields) ctyp -> n + 1, Bindings.add (mk_id ("tup" ^ string_of_int n)) ctyp fields) + (0, Bindings.empty) + ctyps + in + generated_tuples := IdSet.add id !generated_tuples; + codegen_type_def ctx (CTD_record (id, fields)) ^^ twice hardline + +let codegen_def' ctx = function | CDEF_reg_dec (ctyp, id) -> - string (Printf.sprintf "%s %s;" (sgen_ctyp ctyp) (sgen_id id)) - | CDEF_fundef (id, args, instrs) -> - List.iter (fun instr -> print_endline (Pretty_print_sail.to_string (pp_instr instr))) instrs; - let _, Typ_aux (fn_typ, _) = Env.get_val_spec id env in + string (Printf.sprintf "// register %s" (string_of_id id)) ^^ hardline + ^^ string (Printf.sprintf "%s %s;" (sgen_ctyp ctyp) (sgen_id id)) + + | CDEF_fundef (id, ret_arg, args, instrs) -> + let instrs = add_local_labels instrs in + List.iter (fun instr -> prerr_endline (Pretty_print_sail.to_string (pp_instr instr))) instrs; + let _, Typ_aux (fn_typ, _) = Env.get_val_spec id ctx.tc_env in let arg_typs, ret_typ = match fn_typ with | Typ_fn (Typ_aux (Typ_tup arg_typs, _), ret_typ, _) -> arg_typs, ret_typ | Typ_fn (arg_typ, ret_typ, _) -> [arg_typ], ret_typ | _ -> assert false in - let arg_ctyps, ret_ctyp = List.map ctyp_of_typ arg_typs, ctyp_of_typ ret_typ in - + let arg_ctyps, ret_ctyp = List.map (ctyp_of_typ ctx) arg_typs, ctyp_of_typ ctx ret_typ in let args = Util.string_of_list ", " (fun x -> x) (List.map2 (fun ctyp arg -> sgen_ctyp ctyp ^ " " ^ sgen_id arg) arg_ctyps args) in - - string (sgen_ctyp ret_ctyp) ^^ space ^^ codegen_id id ^^ parens (string args) ^^ hardline + let function_header = + match ret_arg with + | None -> + assert (is_stack_ctyp ret_ctyp); + string (sgen_ctyp ret_ctyp) ^^ space ^^ codegen_id id ^^ parens (string args) ^^ hardline + | Some gs -> + assert (not (is_stack_ctyp ret_ctyp)); + string "void" ^^ space ^^ codegen_id id + ^^ parens (string (sgen_ctyp ret_ctyp ^ " *" ^ sgen_id gs ^ ", ") ^^ string args) + ^^ hardline + in + function_header ^^ string "{" - ^^ jump 2 2 (separate_map hardline codegen_instr instrs) ^^ hardline + ^^ jump 0 2 (separate_map hardline (codegen_instr ctx) instrs) ^^ hardline ^^ string "}" -let compile_ast env (Defs defs) = - let cdefs = List.concat (List.map (compile_def env) defs) in - let docs = List.map (codegen_def env) cdefs in + | CDEF_type ctype_def -> + codegen_type_def ctx ctype_def + +let codegen_def ctx def = + let untup = function + | CT_tup ctyps -> ctyps + | _ -> assert false + in + let tups = List.filter is_ct_tup (cdef_ctyps def) in + let tups = List.map (fun ctyp -> codegen_tup ctx (untup ctyp)) tups in + concat tups + ^^ codegen_def' ctx def + +let compile_ast ctx (Defs defs) = + let chunks, ctx = List.fold_left (fun (chunks, ctx) def -> let defs, ctx = compile_def ctx def in defs :: chunks, ctx) ([], ctx) defs in + let cdefs = List.concat (List.rev chunks) in + let docs = List.map (codegen_def ctx) cdefs in let preamble = separate hardline [ string "#include \"sail.h\"" ] in + let postamble = separate hardline + [ string "int main(void)"; + string "{"; + string " zmain(UNIT);"; + string "}" ] + in + let hlhl = hardline ^^ hardline in - Pretty_print_sail.to_string (preamble ^^ hlhl ^^ separate hlhl docs) + Pretty_print_sail.to_string (preamble ^^ hlhl ^^ separate hlhl docs ^^ hlhl ^^ postamble) |> print_endline let print_compiled (setup, ctyp, call, cleanup) = List.iter (fun instr -> print_endline (Pretty_print_sail.to_string (pp_instr instr))) setup; - print_endline (Pretty_print_sail.to_string (pp_instr (call (mk_id ("?" ^ string_of_ctyp ctyp))))); + print_endline (Pretty_print_sail.to_string (pp_instr (call (CL_id (mk_id ("?" ^ string_of_ctyp ctyp)))))); List.iter (fun instr -> print_endline (Pretty_print_sail.to_string (pp_instr instr))) cleanup -let compile_exp env exp = +let compile_exp ctx exp = let aexp = anf exp in - let aexp = c_literals aexp in - let aexp = map_functions analyze_primop aexp in + let aexp = c_literals ctx aexp in + let aexp = map_functions (analyze_primop ctx) aexp in print_endline "\n###################### COMPILED ######################\n"; - print_compiled (compile_aexp env aexp); + print_compiled (compile_aexp ctx aexp); print_endline "\n###################### ANF ######################\n"; aexp diff --git a/src/gen_lib/prompt.lem b/src/gen_lib/prompt.lem index 5019c2f7..d398ab52 100644 --- a/src/gen_lib/prompt.lem +++ b/src/gen_lib/prompt.lem @@ -1,172 +1,8 @@ open import Pervasives_extra open import Sail_impl_base open import Sail_values - -type M 'a 'e = outcome 'a 'e - -val return : forall 'a 'e. 'a -> M 'a 'e -let return a = Done a - -val bind : forall 'a 'b 'e. M 'a 'e -> ('a -> M 'b 'e) -> M 'b 'e -let rec bind m f = match m with - | Done a -> f a - | Read_mem descr k -> Read_mem descr (fun v -> let (o,opt) = k v in (bind o f,opt)) - | Read_reg descr k -> Read_reg descr (fun v -> let (o,opt) = k v in (bind o f,opt)) - | Write_memv descr k -> Write_memv descr (fun v -> let (o,opt) = k v in (bind o f,opt)) - | Excl_res k -> Excl_res (fun v -> let (o,opt) = k v in (bind o f,opt)) - | Write_ea descr o_s -> Write_ea descr (let (o,opt) = o_s in (bind o f,opt)) - | Barrier descr o_s -> Barrier descr (let (o,opt) = o_s in (bind o f,opt)) - | Footprint o_s -> Footprint (let (o,opt) = o_s in (bind o f,opt)) - | Write_reg descr o_s -> Write_reg descr (let (o,opt) = o_s in (bind o f,opt)) - | Escape descr -> Escape descr - | Fail descr -> Fail descr - | Error descr -> Error descr - | Exception e -> Exception e - | Internal descr o_s -> Internal descr (let (o,opt) = o_s in (bind o f ,opt)) -end - -let inline (>>=) = bind -val (>>) : forall 'b 'e. M unit 'e -> M 'b 'e -> M 'b 'e -let inline (>>) m n = m >>= fun (_ : unit) -> n - -val exit : forall 'a 'e. unit -> M 'a 'e -let exit () = Fail Nothing - -val assert_exp : forall 'e. bool -> string -> M unit 'e -let assert_exp exp msg = if exp then Done () else Fail (Just msg) - -val throw : forall 'a 'e. 'e -> M 'a 'e -let throw e = Exception e - -val try_catch : forall 'a 'e1 'e2. M 'a 'e1 -> ('e1 -> M 'a 'e2) -> M 'a 'e2 -let rec try_catch m h = match m with - | Done a -> Done a - | Read_mem descr k -> Read_mem descr (fun v -> let (o,opt) = k v in (try_catch o h,opt)) - | Read_reg descr k -> Read_reg descr (fun v -> let (o,opt) = k v in (try_catch o h,opt)) - | Write_memv descr k -> Write_memv descr (fun v -> let (o,opt) = k v in (try_catch o h,opt)) - | Excl_res k -> Excl_res (fun v -> let (o,opt) = k v in (try_catch o h,opt)) - | Write_ea descr o_s -> Write_ea descr (let (o,opt) = o_s in (try_catch o h,opt)) - | Barrier descr o_s -> Barrier descr (let (o,opt) = o_s in (try_catch o h,opt)) - | Footprint o_s -> Footprint (let (o,opt) = o_s in (try_catch o h,opt)) - | Write_reg descr o_s -> Write_reg descr (let (o,opt) = o_s in (try_catch o h,opt)) - | Escape descr -> Escape descr - | Fail descr -> Fail descr - | Error descr -> Error descr - | Exception e -> h e - | Internal descr o_s -> Internal descr (let (o,opt) = o_s in (try_catch o h ,opt)) -end - -(* For early return, we abuse exceptions by throwing and catching - the return value. The exception type is "either 'r 'e", where "Right e" - represents a proper exception and "Left r" an early return of value "r". *) -type MR 'a 'r 'e = M 'a (either 'r 'e) - -val early_return : forall 'a 'r 'e. 'r -> MR 'a 'r 'e -let early_return r = throw (Left r) - -val catch_early_return : forall 'a 'e. MR 'a 'a 'e -> M 'a 'e -let catch_early_return m = - try_catch m - (function - | Left a -> return a - | Right e -> throw e - end) - -(* Lift to monad with early return by wrapping exceptions *) -val liftR : forall 'a 'r 'e. M 'a 'e -> MR 'a 'r 'e -let liftR m = try_catch m (fun e -> throw (Right e)) - -(* Catch exceptions in the presence of early returns *) -val try_catchR : forall 'a 'r 'e1 'e2. MR 'a 'r 'e1 -> ('e1 -> MR 'a 'r 'e2) -> MR 'a 'r 'e2 -let try_catchR m h = - try_catch m - (function - | Left r -> throw (Left r) - | Right e -> h e - end) - - -val read_mem : forall 'a 'b 'e. Bitvector 'a, Bitvector 'b => read_kind -> 'a -> integer -> M 'b 'e -let read_mem rk addr sz = - let addr = address_lifted_of_bitv (bits_of addr) in - let sz = natFromInteger sz in - let k memory_value = - let bitv = of_bits (internal_mem_value memory_value) in - (Done bitv,Nothing) in - Read_mem (rk,addr,sz) k - -val excl_result : forall 'e. unit -> M bool 'e -let excl_result () = - let k successful = (return successful,Nothing) in - Excl_res k - -val write_mem_ea : forall 'a 'e. Bitvector 'a => write_kind -> 'a -> integer -> M unit 'e -let write_mem_ea wk addr sz = - let addr = address_lifted_of_bitv (bits_of addr) in - let sz = natFromInteger sz in - Write_ea (wk,addr,sz) (Done (),Nothing) - -val write_mem_val : forall 'a 'e. Bitvector 'a => 'a -> M bool 'e -let write_mem_val v = - let v = external_mem_value (bits_of v) in - let k successful = (return successful,Nothing) in - Write_memv v k - -val read_reg_aux : forall 'a 'e. Bitvector 'a => reg_name -> M 'a 'e -let read_reg_aux reg = - let k reg_value = - let v = of_bits (internal_reg_value reg_value) in - (Done v,Nothing) in - Read_reg reg k - -let read_reg reg = - read_reg_aux (external_reg_whole reg) -let read_reg_range reg i j = - read_reg_aux (external_reg_slice reg (natFromInteger i,natFromInteger j)) -let read_reg_bit reg i = - read_reg_aux (external_reg_slice reg (natFromInteger i,natFromInteger i)) >>= fun v -> - return (extract_only_element v) -let read_reg_field reg regfield = - read_reg_aux (external_reg_field_whole reg regfield.field_name) -let read_reg_bitfield reg regfield = - read_reg_aux (external_reg_field_whole reg regfield.field_name) >>= fun v -> - return (extract_only_element v) - -let reg_deref = read_reg - -val write_reg_aux : forall 'a 'e. Bitvector 'a => reg_name -> 'a -> M unit 'e -let write_reg_aux reg_name v = - let regval = external_reg_value reg_name (bits_of v) in - Write_reg (reg_name,regval) (Done (), Nothing) - -let write_reg reg v = - write_reg_aux (external_reg_whole reg) v -let write_reg_range reg i j v = - write_reg_aux (external_reg_slice reg (natFromInteger i,natFromInteger j)) v -let write_reg_pos reg i v = - let iN = natFromInteger i in - write_reg_aux (external_reg_slice reg (iN,iN)) [v] -let write_reg_bit = write_reg_pos -let write_reg_field reg regfield v = - write_reg_aux (external_reg_field_whole reg regfield.field_name) v -(*let write_reg_field_bit reg regfield bit = - write_reg_aux (external_reg_field_whole reg regfield.field_name) - (Vector [bit] 0 (is_inc_of_reg reg))*) -let write_reg_field_range reg regfield i j v = - write_reg_aux (external_reg_field_slice reg regfield.field_name (natFromInteger i,natFromInteger j)) v -let write_reg_field_pos reg regfield i v = - write_reg_field_range reg regfield i i [v] -let write_reg_field_bit = write_reg_field_pos - -let write_reg_ref (reg, v) = write_reg reg v - -val barrier : forall 'e. barrier_kind -> M unit 'e -let barrier bk = Barrier bk (Done (), Nothing) - - -val footprint : forall 'e. M unit 'e -let footprint = Footprint (Done (),Nothing) - +open import Prompt_monad +open import {isabelle} `Prompt_monad_extras` val iter_aux : forall 'a 'e. integer -> (integer -> 'a -> M unit 'e) -> list 'a -> M unit 'e let rec iter_aux i f xs = match xs with diff --git a/src/gen_lib/prompt_monad.lem b/src/gen_lib/prompt_monad.lem new file mode 100644 index 00000000..45733caa --- /dev/null +++ b/src/gen_lib/prompt_monad.lem @@ -0,0 +1,168 @@ +open import Pervasives_extra +open import Sail_impl_base +open import Sail_values + +type M 'a 'e = outcome 'a 'e + +val return : forall 'a 'e. 'a -> M 'a 'e +let return a = Done a + +val bind : forall 'a 'b 'e. M 'a 'e -> ('a -> M 'b 'e) -> M 'b 'e +let rec bind m f = match m with + | Done a -> f a + | Read_mem descr k -> Read_mem descr (fun v -> let (o,opt) = k v in (bind o f,opt)) + | Read_reg descr k -> Read_reg descr (fun v -> let (o,opt) = k v in (bind o f,opt)) + | Write_memv descr k -> Write_memv descr (fun v -> let (o,opt) = k v in (bind o f,opt)) + | Excl_res k -> Excl_res (fun v -> let (o,opt) = k v in (bind o f,opt)) + | Write_ea descr o_s -> Write_ea descr (let (o,opt) = o_s in (bind o f,opt)) + | Barrier descr o_s -> Barrier descr (let (o,opt) = o_s in (bind o f,opt)) + | Footprint o_s -> Footprint (let (o,opt) = o_s in (bind o f,opt)) + | Write_reg descr o_s -> Write_reg descr (let (o,opt) = o_s in (bind o f,opt)) + | Escape descr -> Escape descr + | Fail descr -> Fail descr + | Error descr -> Error descr + | Exception e -> Exception e + | Internal descr o_s -> Internal descr (let (o,opt) = o_s in (bind o f ,opt)) +end + +let inline (>>=) = bind +val (>>) : forall 'b 'e. M unit 'e -> M 'b 'e -> M 'b 'e +let inline (>>) m n = m >>= fun (_ : unit) -> n + +val exit : forall 'a 'e. unit -> M 'a 'e +let exit () = Fail Nothing + +val assert_exp : forall 'e. bool -> string -> M unit 'e +let assert_exp exp msg = if exp then Done () else Fail (Just msg) + +val throw : forall 'a 'e. 'e -> M 'a 'e +let throw e = Exception e + +val try_catch : forall 'a 'e1 'e2. M 'a 'e1 -> ('e1 -> M 'a 'e2) -> M 'a 'e2 +let rec try_catch m h = match m with + | Done a -> Done a + | Read_mem descr k -> Read_mem descr (fun v -> let (o,opt) = k v in (try_catch o h,opt)) + | Read_reg descr k -> Read_reg descr (fun v -> let (o,opt) = k v in (try_catch o h,opt)) + | Write_memv descr k -> Write_memv descr (fun v -> let (o,opt) = k v in (try_catch o h,opt)) + | Excl_res k -> Excl_res (fun v -> let (o,opt) = k v in (try_catch o h,opt)) + | Write_ea descr o_s -> Write_ea descr (let (o,opt) = o_s in (try_catch o h,opt)) + | Barrier descr o_s -> Barrier descr (let (o,opt) = o_s in (try_catch o h,opt)) + | Footprint o_s -> Footprint (let (o,opt) = o_s in (try_catch o h,opt)) + | Write_reg descr o_s -> Write_reg descr (let (o,opt) = o_s in (try_catch o h,opt)) + | Escape descr -> Escape descr + | Fail descr -> Fail descr + | Error descr -> Error descr + | Exception e -> h e + | Internal descr o_s -> Internal descr (let (o,opt) = o_s in (try_catch o h ,opt)) +end + +(* For early return, we abuse exceptions by throwing and catching + the return value. The exception type is "either 'r 'e", where "Right e" + represents a proper exception and "Left r" an early return of value "r". *) +type MR 'a 'r 'e = M 'a (either 'r 'e) + +val early_return : forall 'a 'r 'e. 'r -> MR 'a 'r 'e +let early_return r = throw (Left r) + +val catch_early_return : forall 'a 'e. MR 'a 'a 'e -> M 'a 'e +let catch_early_return m = + try_catch m + (function + | Left a -> return a + | Right e -> throw e + end) + +(* Lift to monad with early return by wrapping exceptions *) +val liftR : forall 'a 'r 'e. M 'a 'e -> MR 'a 'r 'e +let liftR m = try_catch m (fun e -> throw (Right e)) + +(* Catch exceptions in the presence of early returns *) +val try_catchR : forall 'a 'r 'e1 'e2. MR 'a 'r 'e1 -> ('e1 -> MR 'a 'r 'e2) -> MR 'a 'r 'e2 +let try_catchR m h = + try_catch m + (function + | Left r -> throw (Left r) + | Right e -> h e + end) + + +val read_mem : forall 'a 'b 'e. Bitvector 'a, Bitvector 'b => read_kind -> 'a -> integer -> M 'b 'e +let read_mem rk addr sz = + let addr = address_lifted_of_bitv (bits_of addr) in + let sz = natFromInteger sz in + let k memory_value = + let bitv = of_bits (internal_mem_value memory_value) in + (Done bitv,Nothing) in + Read_mem (rk,addr,sz) k + +val excl_result : forall 'e. unit -> M bool 'e +let excl_result () = + let k successful = (return successful,Nothing) in + Excl_res k + +val write_mem_ea : forall 'a 'e. Bitvector 'a => write_kind -> 'a -> integer -> M unit 'e +let write_mem_ea wk addr sz = + let addr = address_lifted_of_bitv (bits_of addr) in + let sz = natFromInteger sz in + Write_ea (wk,addr,sz) (Done (),Nothing) + +val write_mem_val : forall 'a 'e. Bitvector 'a => 'a -> M bool 'e +let write_mem_val v = + let v = external_mem_value (bits_of v) in + let k successful = (return successful,Nothing) in + Write_memv v k + +val read_reg_aux : forall 'a 'e. Bitvector 'a => reg_name -> M 'a 'e +let read_reg_aux reg = + let k reg_value = + let v = of_bits (internal_reg_value reg_value) in + (Done v,Nothing) in + Read_reg reg k + +let read_reg reg = + read_reg_aux (external_reg_whole reg) +let read_reg_range reg i j = + read_reg_aux (external_reg_slice reg (natFromInteger i,natFromInteger j)) +let read_reg_bit reg i = + read_reg_aux (external_reg_slice reg (natFromInteger i,natFromInteger i)) >>= fun v -> + return (extract_only_element v) +let read_reg_field reg regfield = + read_reg_aux (external_reg_field_whole reg regfield.field_name) +let read_reg_bitfield reg regfield = + read_reg_aux (external_reg_field_whole reg regfield.field_name) >>= fun v -> + return (extract_only_element v) + +let reg_deref = read_reg + +val write_reg_aux : forall 'a 'e. Bitvector 'a => reg_name -> 'a -> M unit 'e +let write_reg_aux reg_name v = + let regval = external_reg_value reg_name (bits_of v) in + Write_reg (reg_name,regval) (Done (), Nothing) + +let write_reg reg v = + write_reg_aux (external_reg_whole reg) v +let write_reg_range reg i j v = + write_reg_aux (external_reg_slice reg (natFromInteger i,natFromInteger j)) v +let write_reg_pos reg i v = + let iN = natFromInteger i in + write_reg_aux (external_reg_slice reg (iN,iN)) [v] +let write_reg_bit = write_reg_pos +let write_reg_field reg regfield v = + write_reg_aux (external_reg_field_whole reg regfield.field_name) v +(*let write_reg_field_bit reg regfield bit = + write_reg_aux (external_reg_field_whole reg regfield.field_name) + (Vector [bit] 0 (is_inc_of_reg reg))*) +let write_reg_field_range reg regfield i j v = + write_reg_aux (external_reg_field_slice reg regfield.field_name (natFromInteger i,natFromInteger j)) v +let write_reg_field_pos reg regfield i v = + write_reg_field_range reg regfield i i [v] +let write_reg_field_bit = write_reg_field_pos + +let write_reg_ref (reg, v) = write_reg reg v + +val barrier : forall 'e. barrier_kind -> M unit 'e +let barrier bk = Barrier bk (Done (), Nothing) + + +val footprint : forall 'e. M unit 'e +let footprint = Footprint (Done (),Nothing) diff --git a/src/gen_lib/sail_operators.lem b/src/gen_lib/sail_operators.lem index 230ab84e..ada91bd0 100644 --- a/src/gen_lib/sail_operators.lem +++ b/src/gen_lib/sail_operators.lem @@ -5,53 +5,22 @@ open import Sail_values (*** Bit vector operations *) -val concat_vec : forall 'a 'b 'c. Bitvector 'a, Bitvector 'b, Bitvector 'c => 'a -> 'b -> 'c -let concat_vec l r = of_bits (bits_of l ++ bits_of r) +val concat_bv : forall 'a 'b 'c. Bitvector 'a, Bitvector 'b, Bitvector 'c => 'a -> 'b -> 'c +let concat_bv l r = of_bits (bits_of l ++ bits_of r) -val cons_vec : forall 'a 'b 'c. Bitvector 'a, Bitvector 'b => bitU -> 'a -> 'b -let cons_vec b v = of_bits (b :: bits_of v) +val cons_bv : forall 'a 'b 'c. Bitvector 'a, Bitvector 'b => bitU -> 'a -> 'b +let cons_bv b v = of_bits (b :: bits_of v) -let bool_of_vec v = extract_only_element (bits_of v) -let vec_of_bit len b = of_bits (extz_bits len [b]) -let cast_unit_vec b = of_bits [b] +let bool_of_bv v = extract_only_element (bits_of v) +let cast_unit_bv b = of_bits [b] +let bv_of_bit len b = of_bits (extz_bits len [b]) +let int_of_bv sign = if sign then signed else unsigned let most_significant v = match bits_of v with | b :: _ -> b | _ -> failwith "most_significant applied to empty vector" end -let hardware_mod (a: integer) (b:integer) : integer = - if a < 0 && b < 0 - then (abs a) mod (abs b) - else if (a < 0 && b >= 0) - then (a mod b) - b - else a mod b - -(* There are different possible answers for integer divide regarding -rounding behaviour on negative operands. Positive operands always -round down so derive the one we want (trucation towards zero) from -that *) -let hardware_quot (a:integer) (b:integer) : integer = - let q = (abs a) / (abs b) in - if ((a<0) = (b<0)) then - q (* same sign -- result positive *) - else - ~q (* different sign -- result negative *) - -let int_of_vec sign = if sign then signed else unsigned -let vec_of_int len n = of_bits (bits_of_int len n) - -let max_64u = (integerPow 2 64) - 1 -let max_64 = (integerPow 2 63) - 1 -let min_64 = 0 - (integerPow 2 63) -let max_32u = (4294967295 : integer) -let max_32 = (2147483647 : integer) -let min_32 = (0 - 2147483648 : integer) -let max_8 = (127 : integer) -let min_8 = (0 - 128 : integer) -let max_5 = (31 : integer) -let min_5 = (0 - 32 : integer) - let get_max_representable_in sign (n : integer) : integer = if (n = 64) then match sign with | true -> max_64 | false -> max_64u end else if (n=32) then match sign with | true -> max_32 | false -> max_32u end @@ -68,106 +37,106 @@ let get_min_representable_in _ (n : integer) : integer = else if n = 5 then min_5 else 0 - (integerPow 2 (natFromInteger n)) -val bitwise_binop_vec : forall 'a. Bitvector 'a => +val bitwise_binop_bv : forall 'a. Bitvector 'a => (bool -> bool -> bool) -> 'a -> 'a -> 'a -let bitwise_binop_vec op l r = of_bits (binop_bits op (bits_of l) (bits_of r)) +let bitwise_binop_bv op l r = of_bits (binop_bits op (bits_of l) (bits_of r)) -let and_vec = bitwise_binop_vec (&&) -let or_vec = bitwise_binop_vec (||) -let xor_vec = bitwise_binop_vec xor -let not_vec v = of_bits (not_bits (bits_of v)) +let and_bv = bitwise_binop_bv (&&) +let or_bv = bitwise_binop_bv (||) +let xor_bv = bitwise_binop_bv xor +let not_bv v = of_bits (not_bits (bits_of v)) -val arith_op_vec : forall 'a 'b. Bitvector 'a, Bitvector 'b => +val arith_op_bv : forall 'a 'b. Bitvector 'a, Bitvector 'b => (integer -> integer -> integer) -> bool -> integer -> 'a -> 'a -> 'b -let arith_op_vec op sign size l r = - let (l',r') = (int_of_vec sign l, int_of_vec sign r) in +let arith_op_bv op sign size l r = + let (l',r') = (int_of_bv sign l, int_of_bv sign r) in let n = op l' r' in - of_bits (bits_of_int (size * length l) n) + of_int (size * length l) n -let add_vec = arith_op_vec integerAdd false 1 -let addS_vec = arith_op_vec integerAdd true 1 -let sub_vec = arith_op_vec integerMinus false 1 -let mult_vec = arith_op_vec integerMult false 2 -let multS_vec = arith_op_vec integerMult true 2 +let add_bv = arith_op_bv integerAdd false 1 +let addS_bv = arith_op_bv integerAdd true 1 +let sub_bv = arith_op_bv integerMinus false 1 +let mult_bv = arith_op_bv integerMult false 2 +let multS_bv = arith_op_bv integerMult true 2 let inline add_mword = Machine_word.plus let inline sub_mword = Machine_word.minus val mult_mword : forall 'a 'b. Size 'b => mword 'a -> mword 'a -> mword 'b let mult_mword l r = times (zeroExtend l) (zeroExtend r) -val arith_op_vec_int : forall 'a 'b. Bitvector 'a, Bitvector 'b => +val arith_op_bv_int : forall 'a 'b. Bitvector 'a, Bitvector 'b => (integer -> integer -> integer) -> bool -> integer -> 'a -> integer -> 'b -let arith_op_vec_int op sign size l r = - let l' = int_of_vec sign l in +let arith_op_bv_int op sign size l r = + let l' = int_of_bv sign l in let n = op l' r in - of_bits (bits_of_int (size * length l) n) + of_int (size * length l) n -let add_vec_int = arith_op_vec_int integerAdd false 1 -let addS_vec_int = arith_op_vec_int integerAdd true 1 -let sub_vec_int = arith_op_vec_int integerMinus false 1 -let mult_vec_int = arith_op_vec_int integerMult false 2 -let multS_vec_int = arith_op_vec_int integerMult true 2 +let add_bv_int = arith_op_bv_int integerAdd false 1 +let addS_bv_int = arith_op_bv_int integerAdd true 1 +let sub_bv_int = arith_op_bv_int integerMinus false 1 +let mult_bv_int = arith_op_bv_int integerMult false 2 +let multS_bv_int = arith_op_bv_int integerMult true 2 -val arith_op_int_vec : forall 'a 'b. Bitvector 'a, Bitvector 'b => +val arith_op_int_bv : forall 'a 'b. Bitvector 'a, Bitvector 'b => (integer -> integer -> integer) -> bool -> integer -> integer -> 'a -> 'b -let arith_op_int_vec op sign size l r = - let r' = int_of_vec sign r in +let arith_op_int_bv op sign size l r = + let r' = int_of_bv sign r in let n = op l r' in - of_bits (bits_of_int (size * length r) n) + of_int (size * length r) n -let add_int_vec = arith_op_int_vec integerAdd false 1 -let addS_int_vec = arith_op_int_vec integerAdd true 1 -let sub_int_vec = arith_op_int_vec integerMinus false 1 -let mult_int_vec = arith_op_int_vec integerMult false 2 -let multS_int_vec = arith_op_int_vec integerMult true 2 +let add_int_bv = arith_op_int_bv integerAdd false 1 +let addS_int_bv = arith_op_int_bv integerAdd true 1 +let sub_int_bv = arith_op_int_bv integerMinus false 1 +let mult_int_bv = arith_op_int_bv integerMult false 2 +let multS_int_bv = arith_op_int_bv integerMult true 2 -let arith_op_vec_bit op sign (size : integer) l r = - let l' = int_of_vec sign l in +let arith_op_bv_bit op sign (size : integer) l r = + let l' = int_of_bv sign l in let n = op l' (match r with | B1 -> (1 : integer) | _ -> 0 end) in - of_bits (bits_of_int (size * length l) n) + of_int (size * length l) n -let add_vec_bit = arith_op_vec_bit integerAdd false 1 -let addS_vec_bit = arith_op_vec_bit integerAdd true 1 -let sub_vec_bit = arith_op_vec_bit integerMinus true 1 +let add_bv_bit = arith_op_bv_bit integerAdd false 1 +let addS_bv_bit = arith_op_bv_bit integerAdd true 1 +let sub_bv_bit = arith_op_bv_bit integerMinus true 1 -val arith_op_overflow_vec : forall 'a 'b. Bitvector 'a, Bitvector 'b => +val arith_op_overflow_bv : forall 'a 'b. Bitvector 'a, Bitvector 'b => (integer -> integer -> integer) -> bool -> integer -> 'a -> 'a -> ('b * bitU * bitU) -let arith_op_overflow_vec op sign size l r = +let arith_op_overflow_bv op sign size l r = let len = length l in let act_size = len * size in - let (l_sign,r_sign) = (int_of_vec sign l,int_of_vec sign r) in - let (l_unsign,r_unsign) = (int_of_vec false l,int_of_vec false r) in + let (l_sign,r_sign) = (int_of_bv sign l,int_of_bv sign r) in + let (l_unsign,r_unsign) = (int_of_bv false l,int_of_bv false r) in let n = op l_sign r_sign in let n_unsign = op l_unsign r_unsign in - let correct_size = bits_of_int act_size n in + let correct_size = of_int act_size n in let one_more_size_u = bits_of_int (act_size + 1) n_unsign in let overflow = if n <= get_max_representable_in sign len && n >= get_min_representable_in sign len then B0 else B1 in let c_out = most_significant one_more_size_u in - (of_bits correct_size,overflow,c_out) + (correct_size,overflow,c_out) -let add_overflow_vec = arith_op_overflow_vec integerAdd false 1 -let add_overflow_vec_signed = arith_op_overflow_vec integerAdd true 1 -let sub_overflow_vec = arith_op_overflow_vec integerMinus false 1 -let sub_overflow_vec_signed = arith_op_overflow_vec integerMinus true 1 -let mult_overflow_vec = arith_op_overflow_vec integerMult false 2 -let mult_overflow_vec_signed = arith_op_overflow_vec integerMult true 2 +let add_overflow_bv = arith_op_overflow_bv integerAdd false 1 +let add_overflow_bv_signed = arith_op_overflow_bv integerAdd true 1 +let sub_overflow_bv = arith_op_overflow_bv integerMinus false 1 +let sub_overflow_bv_signed = arith_op_overflow_bv integerMinus true 1 +let mult_overflow_bv = arith_op_overflow_bv integerMult false 2 +let mult_overflow_bv_signed = arith_op_overflow_bv integerMult true 2 -val arith_op_overflow_vec_bit : forall 'a 'b. Bitvector 'a, Bitvector 'b => +val arith_op_overflow_bv_bit : forall 'a 'b. Bitvector 'a, Bitvector 'b => (integer -> integer -> integer) -> bool -> integer -> 'a -> bitU -> ('b * bitU * bitU) -let arith_op_overflow_vec_bit op sign size l r_bit = +let arith_op_overflow_bv_bit op sign size l r_bit = let act_size = length l * size in - let l' = int_of_vec sign l in - let l_u = int_of_vec false l in + let l' = int_of_bv sign l in + let l_u = int_of_bv false l in let (n,nu,changed) = match r_bit with | B1 -> (op l' 1, op l_u 1, true) | B0 -> (l',l_u,false) - | BU -> failwith "arith_op_overflow_vec_bit applied to undefined bit" + | BU -> failwith "arith_op_overflow_bv_bit applied to undefined bit" end in - let correct_size = bits_of_int act_size n in + let correct_size = of_int act_size n in let one_larger = bits_of_int (act_size + 1) nu in let overflow = if changed @@ -175,32 +144,35 @@ let arith_op_overflow_vec_bit op sign size l r_bit = if n <= get_max_representable_in sign act_size && n >= get_min_representable_in sign act_size then B0 else B1 else B0 in - (of_bits correct_size,overflow,most_significant one_larger) + (correct_size,overflow,most_significant one_larger) -let add_overflow_vec_bit = arith_op_overflow_vec_bit integerAdd false 1 -let add_overflow_vec_bit_signed = arith_op_overflow_vec_bit integerAdd true 1 -let sub_overflow_vec_bit = arith_op_overflow_vec_bit integerMinus false 1 -let sub_overflow_vec_bit_signed = arith_op_overflow_vec_bit integerMinus true 1 +let add_overflow_bv_bit = arith_op_overflow_bv_bit integerAdd false 1 +let add_overflow_bv_bit_signed = arith_op_overflow_bv_bit integerAdd true 1 +let sub_overflow_bv_bit = arith_op_overflow_bv_bit integerMinus false 1 +let sub_overflow_bv_bit_signed = arith_op_overflow_bv_bit integerMinus true 1 -type shift = LL_shift | RR_shift | LL_rot | RR_rot +type shift = LL_shift | RR_shift | RR_shift_arith | LL_rot | RR_rot -val shift_op_vec : forall 'a. Bitvector 'a => shift -> 'a -> integer -> 'a -let shift_op_vec op v n = +val shift_op_bv : forall 'a. Bitvector 'a => shift -> 'a -> integer -> 'a +let shift_op_bv op v n = match op with | LL_shift -> of_bits (get_bits true v n (length v - 1) ++ repeat [B0] n) | RR_shift -> of_bits (repeat [B0] n ++ get_bits true v 0 (length v - n - 1)) + | RR_shift_arith -> + of_bits (repeat [most_significant v] n ++ get_bits true v 0 (length v - n - 1)) | LL_rot -> of_bits (get_bits true v n (length v - 1) ++ get_bits true v 0 (n - 1)) | RR_rot -> of_bits (get_bits false v 0 (n - 1) ++ get_bits false v n (length v - 1)) end -let shiftl = shift_op_vec LL_shift (*"<<"*) -let shiftr = shift_op_vec RR_shift (*">>"*) -let rotl = shift_op_vec LL_rot (*"<<<"*) -let rotr = shift_op_vec LL_rot (*">>>"*) +let shiftl_bv = shift_op_bv LL_shift (*"<<"*) +let shiftr_bv = shift_op_bv RR_shift (*">>"*) +let arith_shiftr_bv = shift_op_bv RR_shift_arith +let rotl_bv = shift_op_bv LL_rot (*"<<<"*) +let rotr_bv = shift_op_bv LL_rot (*">>>"*) let shiftl_mword w n = Machine_word.shiftLeft w (natFromInteger n) let shiftr_mword w n = Machine_word.shiftRight w (natFromInteger n) @@ -212,11 +184,11 @@ let rec arith_op_no0 (op : integer -> integer -> integer) l r = then Nothing else Just (op l r) -val arith_op_vec_no0 : forall 'a 'b. Bitvector 'a, Bitvector 'b => +val arith_op_bv_no0 : forall 'a 'b. Bitvector 'a, Bitvector 'b => (integer -> integer -> integer) -> bool -> integer -> 'a -> 'a -> 'b -let arith_op_vec_no0 op sign size l r = +let arith_op_bv_no0 op sign size l r = let act_size = length l * size in - let (l',r') = (int_of_vec sign l,int_of_vec sign r) in + let (l',r') = (int_of_bv sign l,int_of_bv sign r) in let n = arith_op_no0 op l' r' in let (representable,n') = match n with @@ -225,80 +197,42 @@ let arith_op_vec_no0 op sign size l r = n' >= get_min_representable_in sign act_size, n') | _ -> (false,0) end in - of_bits (if representable then bits_of_int act_size n' else repeat [BU] act_size) + if representable then (of_int act_size n') else (of_bits (repeat [BU] act_size)) -let mod_vec = arith_op_vec_no0 hardware_mod false 1 -let quot_vec = arith_op_vec_no0 hardware_quot false 1 -let quot_vec_signed = arith_op_vec_no0 hardware_quot true 1 +let mod_bv = arith_op_bv_no0 hardware_mod false 1 +let quot_bv = arith_op_bv_no0 hardware_quot false 1 +let quot_bv_signed = arith_op_bv_no0 hardware_quot true 1 let mod_mword = Machine_word.modulo let quot_mword = Machine_word.unsignedDivide let quot_mword_signed = Machine_word.signedDivide -val arith_op_overflow_vec_no0 : forall 'a 'b. Bitvector 'a, Bitvector 'b => - (integer -> integer -> integer) -> bool -> integer -> 'a -> 'a -> ('b * bitU * bitU) -let arith_op_overflow_vec_no0 op sign size l r = - let rep_size = length r * size in - let act_size = length l * size in - let (l',r') = (int_of_vec sign l,int_of_vec sign r) in - let (l_u,r_u) = (int_of_vec false l,int_of_vec false r) in - let n = arith_op_no0 op l' r' in - let n_u = arith_op_no0 op l_u r_u in - let (representable,n',n_u') = - match (n, n_u) with - | (Just n',Just n_u') -> - ((n' <= get_max_representable_in sign rep_size && - n' >= (get_min_representable_in sign rep_size)), n', n_u') - | _ -> (true,0,0) - end in - let (correct_size,one_more) = - if representable then - (bits_of_int act_size n', bits_of_int (act_size + 1) n_u') - else - (repeat [BU] act_size, repeat [BU] (act_size + 1)) in - let overflow = if representable then B0 else B1 in - (of_bits correct_size,overflow,most_significant one_more) - -let quot_overflow_vec = arith_op_overflow_vec_no0 hardware_quot false 1 -let quot_overflow_vec_signed = arith_op_overflow_vec_no0 hardware_quot true 1 - -let arith_op_vec_int_no0 op sign size l r = - arith_op_vec_no0 op sign size l (of_bits (bits_of_int (length l) r)) - -let quot_vec_int = arith_op_vec_int_no0 hardware_quot false 1 -let mod_vec_int = arith_op_vec_int_no0 hardware_mod false 1 - -let replicate_bits v count = of_bits (repeat v count) -let duplicate bit len = replicate_bits [bit] len - -let lt = (<) -let gt = (>) -let lteq = (<=) -let gteq = (>=) +let arith_op_bv_int_no0 op sign size l r = + arith_op_bv_no0 op sign size l (of_int (length l) r) -val eq : forall 'a. Eq 'a => 'a -> 'a -> bool -let eq l r = (l = r) +let quot_bv_int = arith_op_bv_int_no0 hardware_quot false 1 +let mod_bv_int = arith_op_bv_int_no0 hardware_mod false 1 -val eq_vec : forall 'a. Bitvector 'a => 'a -> 'a -> bool -let eq_vec l r = (unsigned l = unsigned r) +let replicate_bits_bv v count = of_bits (repeat (bits_of v) count) +let duplicate_bit_bv bit len = replicate_bits_bv [bit] len -val neq : forall 'a. Eq 'a => 'a -> 'a -> bool -let neq l r = (l <> r) +val eq_bv : forall 'a. Bitvector 'a => 'a -> 'a -> bool +let eq_bv l r = (unsigned l = unsigned r) -val neq_vec : forall 'a. Bitvector 'a => 'a -> 'a -> bool -let neq_vec l r = (unsigned l <> unsigned r) +val neq_bv : forall 'a. Bitvector 'a => 'a -> 'a -> bool +let neq_bv l r = (unsigned l <> unsigned r) -val ucmp_vec : forall 'a. Bitvector 'a => (integer -> integer -> bool) -> 'a -> 'a -> bool -let ucmp_vec cmp l r = cmp (unsigned l) (unsigned r) +val ucmp_bv : forall 'a. Bitvector 'a => (integer -> integer -> bool) -> 'a -> 'a -> bool +let ucmp_bv cmp l r = cmp (unsigned l) (unsigned r) -val scmp_vec : forall 'a. Bitvector 'a => (integer -> integer -> bool) -> 'a -> 'a -> bool -let scmp_vec cmp l r = cmp (signed l) (signed r) +val scmp_bv : forall 'a. Bitvector 'a => (integer -> integer -> bool) -> 'a -> 'a -> bool +let scmp_bv cmp l r = cmp (signed l) (signed r) -let ult_vec = ucmp_vec (<) -let slt_vec = scmp_vec (<) -let ugt_vec = ucmp_vec (>) -let sgt_vec = scmp_vec (>) -let ulteq_vec = ucmp_vec (<=) -let slteq_vec = scmp_vec (<=) -let ugteq_vec = ucmp_vec (>=) -let sgteq_vec = scmp_vec (>=) +let ult_bv = ucmp_bv (<) +let slt_bv = scmp_bv (<) +let ugt_bv = ucmp_bv (>) +let sgt_bv = scmp_bv (>) +let ulteq_bv = ucmp_bv (<=) +let slteq_bv = scmp_bv (<=) +let ugteq_bv = ucmp_bv (>=) +let sgteq_bv = scmp_bv (>=) diff --git a/src/gen_lib/sail_operators_bitlists.lem b/src/gen_lib/sail_operators_bitlists.lem new file mode 100644 index 00000000..374628a4 --- /dev/null +++ b/src/gen_lib/sail_operators_bitlists.lem @@ -0,0 +1,179 @@ +open import Pervasives_extra +open import Machine_word +open import Sail_impl_base +open import Sail_values +open import Sail_operators + +(* Specialisation of operators to bit lists *) + +val access_vec_inc : list bitU -> integer -> bitU +let access_vec_inc = access_bv_inc + +val access_vec_dec : list bitU -> integer -> bitU +let access_vec_dec = access_bv_dec + +val update_vec_inc : list bitU -> integer -> bitU -> list bitU +let update_vec_inc = update_bv_inc + +val update_vec_dec : list bitU -> integer -> bitU -> list bitU +let update_vec_dec = update_bv_dec + +val subrange_vec_inc : list bitU -> integer -> integer -> list bitU +let subrange_vec_inc = subrange_bv_inc + +val subrange_vec_dec : list bitU -> integer -> integer -> list bitU +let subrange_vec_dec = subrange_bv_dec + +val update_subrange_vec_inc : list bitU -> integer -> integer -> list bitU -> list bitU +let update_subrange_vec_inc = update_subrange_bv_inc + +val update_subrange_vec_dec : list bitU -> integer -> integer -> list bitU -> list bitU +let update_subrange_vec_dec = update_subrange_bv_dec + +val extz_vec : integer -> list bitU -> list bitU +let extz_vec = extz_bv + +val exts_vec : integer -> list bitU -> list bitU +let exts_vec = exts_bv + +val concat_vec : list bitU -> list bitU -> list bitU +let concat_vec = concat_bv + +val cons_vec : bitU -> list bitU -> list bitU +let cons_vec = cons_bv + +val bool_of_vec : mword ty1 -> bitU +let bool_of_vec = bool_of_bv + +val cast_unit_vec : bitU -> mword ty1 +let cast_unit_vec = cast_unit_bv + +val vec_of_bit : integer -> bitU -> list bitU +let vec_of_bit = bv_of_bit + +val msb : list bitU -> bitU +let msb = most_significant + +val int_of_vec : bool -> list bitU -> integer +let int_of_vec = int_of_bv + +val and_vec : list bitU -> list bitU -> list bitU +val or_vec : list bitU -> list bitU -> list bitU +val xor_vec : list bitU -> list bitU -> list bitU +val not_vec : list bitU -> list bitU +let and_vec = and_bv +let or_vec = or_bv +let xor_vec = xor_bv +let not_vec = not_bv + +val add_vec : list bitU -> list bitU -> list bitU +val addS_vec : list bitU -> list bitU -> list bitU +val sub_vec : list bitU -> list bitU -> list bitU +val mult_vec : list bitU -> list bitU -> list bitU +val multS_vec : list bitU -> list bitU -> list bitU +let add_vec = add_bv +let addS_vec = addS_bv +let sub_vec = sub_bv +let mult_vec = mult_bv +let multS_vec = multS_bv + +val add_vec_int : list bitU -> integer -> list bitU +val addS_vec_int : list bitU -> integer -> list bitU +val sub_vec_int : list bitU -> integer -> list bitU +val mult_vec_int : list bitU -> integer -> list bitU +val multS_vec_int : list bitU -> integer -> list bitU +let add_vec_int = add_bv_int +let addS_vec_int = addS_bv_int +let sub_vec_int = sub_bv_int +let mult_vec_int = mult_bv_int +let multS_vec_int = multS_bv_int + +val add_int_vec : integer -> list bitU -> list bitU +val addS_int_vec : integer -> list bitU -> list bitU +val sub_int_vec : integer -> list bitU -> list bitU +val mult_int_vec : integer -> list bitU -> list bitU +val multS_int_vec : integer -> list bitU -> list bitU +let add_int_vec = add_int_bv +let addS_int_vec = addS_int_bv +let sub_int_vec = sub_int_bv +let mult_int_vec = mult_int_bv +let multS_int_vec = multS_int_bv + +val add_vec_bit : list bitU -> bitU -> list bitU +val addS_vec_bit : list bitU -> bitU -> list bitU +val sub_vec_bit : list bitU -> bitU -> list bitU +let add_vec_bit = add_bv_bit +let addS_vec_bit = addS_bv_bit +let sub_vec_bit = sub_bv_bit + +val add_overflow_vec : list bitU -> list bitU -> (list bitU * bitU * bitU) +val add_overflow_vec_signed : list bitU -> list bitU -> (list bitU * bitU * bitU) +val sub_overflow_vec : list bitU -> list bitU -> (list bitU * bitU * bitU) +val sub_overflow_vec_signed : list bitU -> list bitU -> (list bitU * bitU * bitU) +val mult_overflow_vec : list bitU -> list bitU -> (list bitU * bitU * bitU) +val mult_overflow_vec_signed : list bitU -> list bitU -> (list bitU * bitU * bitU) +let add_overflow_vec = add_overflow_bv +let add_overflow_vec_signed = add_overflow_bv_signed +let sub_overflow_vec = sub_overflow_bv +let sub_overflow_vec_signed = sub_overflow_bv_signed +let mult_overflow_vec = mult_overflow_bv +let mult_overflow_vec_signed = mult_overflow_bv_signed + +val add_overflow_vec_bit : list bitU -> bitU -> (list bitU * bitU * bitU) +val add_overflow_vec_bit_signed : list bitU -> bitU -> (list bitU * bitU * bitU) +val sub_overflow_vec_bit : list bitU -> bitU -> (list bitU * bitU * bitU) +val sub_overflow_vec_bit_signed : list bitU -> bitU -> (list bitU * bitU * bitU) +let add_overflow_vec_bit = add_overflow_bv_bit +let add_overflow_vec_bit_signed = add_overflow_bv_bit_signed +let sub_overflow_vec_bit = sub_overflow_bv_bit +let sub_overflow_vec_bit_signed = sub_overflow_bv_bit_signed + +val shiftl : list bitU -> integer -> list bitU +val shiftr : list bitU -> integer -> list bitU +val arith_shiftr : list bitU -> integer -> list bitU +val rotl : list bitU -> integer -> list bitU +val rotr : list bitU -> integer -> list bitU +let shiftl = shiftl_bv +let shiftr = shiftr_bv +let arith_shiftr = arith_shiftr_bv +let rotl = rotl_bv +let rotr = rotr_bv + +val mod_vec : list bitU -> list bitU -> list bitU +val quot_vec : list bitU -> list bitU -> list bitU +val quot_vec_signed : list bitU -> list bitU -> list bitU +let mod_vec = mod_bv +let quot_vec = quot_bv +let quot_vec_signed = quot_bv_signed + +val mod_vec_int : list bitU -> integer -> list bitU +val quot_vec_int : list bitU -> integer -> list bitU +let mod_vec_int = mod_bv_int +let quot_vec_int = quot_bv_int + +val replicate_bits : list bitU -> integer -> list bitU +let replicate_bits = replicate_bits_bv + +val duplicate : bitU -> integer -> list bitU +let duplicate = duplicate_bit_bv + +val eq_vec : list bitU -> list bitU -> bool +val neq_vec : list bitU -> list bitU -> bool +val ult_vec : list bitU -> list bitU -> bool +val slt_vec : list bitU -> list bitU -> bool +val ugt_vec : list bitU -> list bitU -> bool +val sgt_vec : list bitU -> list bitU -> bool +val ulteq_vec : list bitU -> list bitU -> bool +val slteq_vec : list bitU -> list bitU -> bool +val ugteq_vec : list bitU -> list bitU -> bool +val sgteq_vec : list bitU -> list bitU -> bool +let eq_vec = eq_bv +let neq_vec = neq_bv +let ult_vec = ult_bv +let slt_vec = slt_bv +let ugt_vec = ugt_bv +let sgt_vec = sgt_bv +let ulteq_vec = ulteq_bv +let slteq_vec = slteq_bv +let ugteq_vec = ugteq_bv +let sgteq_vec = sgteq_bv diff --git a/src/gen_lib/sail_operators_mwords.lem b/src/gen_lib/sail_operators_mwords.lem index ff25c37b..7fa09b9b 100644 --- a/src/gen_lib/sail_operators_mwords.lem +++ b/src/gen_lib/sail_operators_mwords.lem @@ -2,600 +2,178 @@ open import Pervasives_extra open import Machine_word open import Sail_impl_base open import Sail_values - -(* Translating between a type level number (itself 'n) and an integer *) - -let size_itself_int x = integerFromNat (size_itself x) - -(* NB: the corresponding sail type is forall 'n. atom('n) -> itself('n), - the actual integer is ignored. *) - -val make_the_value : forall 'n. integer -> itself 'n -let inline make_the_value x = the_value - -(*** Bit vector operations *) - -let bitvector_length bs = integerFromNat (word_length bs) - -(*val set_bitvector_start : forall 'a. (integer * bitvector 'a) -> bitvector 'a -let set_bitvector_start (new_start, Bitvector bs _ is_inc) = - Bitvector bs new_start is_inc - -let reset_bitvector_start v = - set_bitvector_start (if (bvget_dir v) then 0 else (bvlength v - 1), v) - -let set_bitvector_start_to_length v = - set_bitvector_start (bvlength v - 1, v) - -let bitvector_concat (Bitvector bs start is_inc, Bitvector bs' _ _) = - Bitvector (word_concat bs bs') start is_inc*) - -let bitvector_concat (bs, bs') = word_concat bs bs' - -let inline (^^^) = bitvector_concat - -val bvslice : forall 'a 'b. Size 'a => bool -> integer -> bitvector 'a -> integer -> integer -> bitvector 'b -let bvslice is_inc start bs i j = - let iN = natFromInteger i in - let jN = natFromInteger j in - let startN = natFromInteger start in - let top = word_length bs - 1 in - let (hi,lo) = if is_inc then (top+startN-iN,top+startN-jN) else (top-startN+iN,top-startN+jN) in - word_extract lo hi bs - -let bitvector_subrange_inc (start, v, i, j) = bvslice true start v i j -let bitvector_subrange_dec (start, v, i, j) = bvslice false start v i j - -let vector_subrange_bl_dec (start, v, i, j) = - let v' = slice (bvec_to_vec false start v) i j in - get_elems v' - -(* this is for the vector slicing introduced in vector-concat patterns: i and j -index into the "raw data", the list of bits. Therefore getting the bit list is -easy, but the start index has to be transformed to match the old vector start -and the direction. *) -val bvslice_raw : forall 'a 'b. Size 'b => bitvector 'a -> integer -> integer -> bitvector 'b -let bvslice_raw bs i j = - let iN = natFromInteger i in - let jN = natFromInteger j in - (*let bits =*) word_extract iN jN bs (*in - let len = integerFromNat (word_length bits) in - Bitvector bits (if is_inc then 0 else len - 1) is_inc*) - -val bvupdate_aux : forall 'a 'b. Size 'a => bool -> integer -> bitvector 'a -> integer -> integer -> list bitU -> bitvector 'a -let bvupdate_aux is_inc start bs i j bs' = - let bits = update_aux is_inc start (List.map to_bitU (bitlistFromWord bs)) i j bs' in - wordFromBitlist (List.map of_bitU bits) - (*let iN = natFromInteger i in - let jN = natFromInteger j in - let startN = natFromInteger start in - let top = word_length bs - 1 in - let (hi,lo) = if is_inc then (top+startN-iN,top+startN-jN) else (top-startN+iN,top-startN+jN) in - word_update bs lo hi bs'*) - -val bvupdate : forall 'a 'b. Size 'a => bool -> integer -> bitvector 'a -> integer -> integer -> bitvector 'b -> bitvector 'a -let bvupdate is_inc start bs i j bs' = - bvupdate_aux is_inc start bs i j (List.map to_bitU (bitlistFromWord bs')) - -val bvaccess : forall 'a. Size 'a => bool -> integer -> bitvector 'a -> integer -> bitU -let bvaccess is_inc start bs n = bool_to_bitU ( - let top = integerFromNat (word_length bs) - 1 in - if is_inc then getBit bs (natFromInteger (top + start - n)) - else getBit bs (natFromInteger (top + n - start))) - -val bvupdate_pos : forall 'a. Size 'a => bool -> integer -> bitvector 'a -> integer -> bitU -> bitvector 'a -let bvupdate_pos is_inc start v n b = - bvupdate_aux is_inc start v n n [b] - -let bitvector_access_inc (start, v, i) = bvaccess true start v i -let bitvector_access_dec (start, v, i) = bvaccess false start v i -let bitvector_update_pos_dec (start, v, i, b) = bvupdate_pos false start v i b -let bitvector_update_subrange_dec (start, v, i, j, v') = bvupdate false start v i j v' - -val extract_only_bit : bitvector ty1 -> bitU -let extract_only_bit elems = - let l = word_length elems in - if l = 1 then - bool_to_bitU (msb elems) - else if l = 0 then - failwith "extract_single_bit called for empty vector" - else - failwith "extract_single_bit called for vector with more bits" - - -let norm_dec v = v (*reset_bitvector_start*) -let adjust_start_index (start, v) = v (*set_bitvector_start (start, v)*) - -let cast_vec_bool v = bitU_to_bool (extract_only_bit v) -let cast_bit_vec_basic (start, len, b) = vec_to_bvec (Vector [b] start false) -let cast_boolvec_bitvec (Vector bs start inc) = - vec_to_bvec (Vector (List.map bool_to_bitU bs) start inc) -let cast_vec_bl v = List.map bool_to_bitU (bitlistFromWord v) -let cast_int_vec n = wordFromInteger n -let cast_bl_vec (start, len, bs) = wordFromBitlist (List.map bitU_to_bool bs) -let cast_bl_svec (start, len, bs) = cast_int_vec (bitlist_to_signed bs) - -let pp_bitu_vector (Vector elems start inc) = - let elems_pp = List.foldl (fun acc elem -> acc ^ showBitU elem) "" elems in - "Vector [" ^ elems_pp ^ "] " ^ show start ^ " " ^ show inc - - -let most_significant v = - if word_length v = 0 then - failwith "most_significant applied to empty vector" - else - bool_to_bitU (msb v) - -let bitwise_not_bitlist = List.map bitwise_not_bit - -let bitwise_not bs = lNot bs - -let bitwise_binop op (bsl, bsr) = (op bsl bsr) - -let bitwise_and x = bitwise_binop lAnd x -let bitwise_or x = bitwise_binop lOr x -let bitwise_xor x = bitwise_binop lXor x - -(*let unsigned bs : integer = unsignedIntegerFromWord bs*) -let unsigned_big = unsigned - -let signed v : integer = signedIntegerFromWord v - -let hardware_mod (a: integer) (b:integer) : integer = - if a < 0 && b < 0 - then (abs a) mod (abs b) - else if (a < 0 && b >= 0) - then (a mod b) - b - else a mod b - -(* There are different possible answers for integer divide regarding -rounding behaviour on negative operands. Positive operands always -round down so derive the one we want (trucation towards zero) from -that *) -let hardware_quot (a:integer) (b:integer) : integer = - let q = (abs a) / (abs b) in - if ((a<0) = (b<0)) then - q (* same sign -- result positive *) - else - ~q (* different sign -- result negative *) - -let quot_signed = hardware_quot - - -let signed_big = signed - -let to_num sign = if sign then signed else unsigned - -let max_64u = (integerPow 2 64) - 1 -let max_64 = (integerPow 2 63) - 1 -let min_64 = 0 - (integerPow 2 63) -let max_32u = (4294967295 : integer) -let max_32 = (2147483647 : integer) -let min_32 = (0 - 2147483648 : integer) -let max_8 = (127 : integer) -let min_8 = (0 - 128 : integer) -let max_5 = (31 : integer) -let min_5 = (0 - 32 : integer) - -let get_max_representable_in sign (n : integer) : integer = - if (n = 64) then match sign with | true -> max_64 | false -> max_64u end - else if (n=32) then match sign with | true -> max_32 | false -> max_32u end - else if (n=8) then max_8 - else if (n=5) then max_5 - else match sign with | true -> integerPow 2 ((natFromInteger n) -1) - | false -> integerPow 2 (natFromInteger n) - end - -let get_min_representable_in _ (n : integer) : integer = - if n = 64 then min_64 - else if n = 32 then min_32 - else if n = 8 then min_8 - else if n = 5 then min_5 - else 0 - (integerPow 2 (natFromInteger n)) - -val to_bin_aux : natural -> list bitU -let rec to_bin_aux x = - if x = 0 then [] - else (if x mod 2 = 1 then B1 else B0) :: to_bin_aux (x / 2) -let to_bin n = List.reverse (to_bin_aux n) - -val pad_zero : list bitU -> integer -> list bitU -let rec pad_zero bits n = - if n = 0 then bits else pad_zero (B0 :: bits) (n -1) - - -let rec add_one_bit_ignore_overflow_aux bits = match bits with - | [] -> [] - | B0 :: bits -> B1 :: bits - | B1 :: bits -> B0 :: add_one_bit_ignore_overflow_aux bits - | BU :: _ -> failwith "add_one_bit_ignore_overflow: undefined bit" -end - -let add_one_bit_ignore_overflow bits = - List.reverse (add_one_bit_ignore_overflow_aux (List.reverse bits)) - -val to_norm_vec : forall 'a. Size 'a => integer -> bitvector 'a -let to_norm_vec (n : integer) = wordFromInteger n -(* - (* Bitvector length is determined by return type *) - let bits = wordFromInteger n in - let len = integerFromNat (word_length bits) in - let start = if is_inc then 0 else len - 1 in - (*if integerFromNat (word_length bits) = len then*) - Bitvector bits start is_inc - (*else - failwith "Vector length mismatch in to_vec"*) -*) - -let to_vec_big = to_norm_vec - -let to_vec_inc (start, len, n) = to_norm_vec n -let to_vec_norm_inc (len, n) = to_norm_vec n -let to_vec_dec (start, len, n) = to_norm_vec n -let to_vec_norm_dec (len, n) = to_norm_vec n - -(* TODO: Think about undefined bit(vector)s *) -let to_vec_undef is_inc (len : integer) = - (* Bitvector *) - (failwith "undefined bitvector") - (* (if is_inc then 0 else len-1) is_inc *) - -let to_vec_inc_undef = to_vec_undef true -let to_vec_dec_undef = to_vec_undef false - -let exts (start, len, vec) = to_norm_vec (signed vec) -val extz : forall 'a 'b. Size 'a, Size 'b => (integer * integer * bitvector 'a) -> bitvector 'b -let extz (start, len, vec) = to_norm_vec (unsigned vec) - -let exts_big (start, len, vec) = to_vec_big (signed_big vec) -let extz_big (start, len, vec) = to_vec_big (unsigned_big vec) - -let quot = hardware_quot -let modulo (l,r) = hardware_mod l r - -(* TODO: this, and the definitions that use it, currently require Size for - to_vec, which I'd rather avoid in favour of library versions; the - double-size results for multiplication may be a problem *) -let arith_op_vec op sign (size : integer) l r = - let (l',r') = (to_num sign l, to_num sign r) in - let n = op l' r' in - to_norm_vec n - - -(* add_vec - * add_vec_signed - * minus_vec - * multiply_vec - * multiply_vec_signed - *) -let add_VVV = arith_op_vec integerAdd false 1 -let addS_VVV = arith_op_vec integerAdd true 1 -let minus_VVV = arith_op_vec integerMinus false 1 -let mult_VVV = arith_op_vec integerMult false 2 -let multS_VVV = arith_op_vec integerMult true 2 - -let mult_vec (l, r) = mult_VVV l r -let mult_svec (l, r) = multS_VVV l r - -let add_vec (l, r) = add_VVV l r -let sub_vec (l, r) = minus_VVV l r - -val arith_op_vec_range : forall 'a 'b. Size 'a, Size 'b => (integer -> integer -> integer) -> bool -> integer -> bitvector 'a -> integer -> bitvector 'b -let arith_op_vec_range op sign size l r = - arith_op_vec op sign size l ((to_norm_vec r) : bitvector 'a) - -(* add_vec_range - * add_vec_range_signed - * minus_vec_range - * mult_vec_range - * mult_vec_range_signed - *) -let add_VIV = arith_op_vec_range integerAdd false 1 -let addS_VIV = arith_op_vec_range integerAdd true 1 -let minus_VIV = arith_op_vec_range integerMinus false 1 -let mult_VIV = arith_op_vec_range integerMult false 2 -let multS_VIV = arith_op_vec_range integerMult true 2 - -let add_vec_int (l, r) = add_VIV l r -let sub_vec_int (l, r) = minus_VIV l r - -val arith_op_range_vec : forall 'a 'b. Size 'a, Size 'b => (integer -> integer -> integer) -> bool -> integer -> integer -> bitvector 'a -> bitvector 'b -let arith_op_range_vec op sign size l r = - arith_op_vec op sign size ((to_norm_vec l) : bitvector 'a) r - -(* add_range_vec - * add_range_vec_signed - * minus_range_vec - * mult_range_vec - * mult_range_vec_signed - *) -let add_IVV = arith_op_range_vec integerAdd false 1 -let addS_IVV = arith_op_range_vec integerAdd true 1 -let minus_IVV = arith_op_range_vec integerMinus false 1 -let mult_IVV = arith_op_range_vec integerMult false 2 -let multS_IVV = arith_op_range_vec integerMult true 2 - -let arith_op_range_vec_range op sign l r = op l (to_num sign r) - -(* add_range_vec_range - * add_range_vec_range_signed - * minus_range_vec_range - *) -let add_IVI x = arith_op_range_vec_range integerAdd false x -let addS_IVI x = arith_op_range_vec_range integerAdd true x -let minus_IVI x = arith_op_range_vec_range integerMinus false x - -let arith_op_vec_range_range op sign l r = op (to_num sign l) r - -(* add_vec_range_range - * add_vec_range_range_signed - * minus_vec_range_range - *) -let add_VII x = arith_op_vec_range_range integerAdd false x -let addS_VII x = arith_op_vec_range_range integerAdd true x -let minus_VII x = arith_op_vec_range_range integerMinus false x - - - -let arith_op_vec_vec_range op sign l r = - let (l',r') = (to_num sign l,to_num sign r) in - op l' r' - -(* add_vec_vec_range - * add_vec_vec_range_signed - *) -let add_VVI x = arith_op_vec_vec_range integerAdd false x -let addS_VVI x = arith_op_vec_vec_range integerAdd true x - -let arith_op_vec_bit op sign (size : integer) l r = - let l' = to_num sign l in - let n = op l' (match r with | B1 -> (1 : integer) | _ -> 0 end) in - to_norm_vec n - -(* add_vec_bit - * add_vec_bit_signed - * minus_vec_bit_signed - *) -let add_VBV x = arith_op_vec_bit integerAdd false 1 x -let addS_VBV x = arith_op_vec_bit integerAdd true 1 x -let minus_VBV x = arith_op_vec_bit integerMinus true 1 x - -(* TODO: these can't be done directly in Lem because of the one_more size calculation -val arith_op_overflow_vec : forall 'a 'b. Size 'a, Size 'b => (integer -> integer -> integer) -> bool -> integer -> bitvector 'a -> bitvector 'a -> bitvector 'b * bitU * bool -let rec arith_op_overflow_vec op sign size (Bitvector _ _ is_inc as l) r = - let len = bvlength l in - let act_size = len * size in - let (l_sign,r_sign) = (to_num sign l,to_num sign r) in - let (l_unsign,r_unsign) = (to_num false l,to_num false r) in - let n = op l_sign r_sign in - let n_unsign = op l_unsign r_unsign in - let correct_size_num = to_vec_ord is_inc (act_size,n) in - let one_more_size_u = to_vec_ord is_inc (act_size + 1,n_unsign) in - let overflow = - if n <= get_max_representable_in sign len && - n >= get_min_representable_in sign len - then B0 else B1 in - let c_out = most_significant one_more_size_u in - (correct_size_num,overflow,c_out) - -(* add_overflow_vec - * add_overflow_vec_signed - * minus_overflow_vec - * minus_overflow_vec_signed - * mult_overflow_vec - * mult_overflow_vec_signed - *) -let addO_VVV = arith_op_overflow_vec integerAdd false 1 -let addSO_VVV = arith_op_overflow_vec integerAdd true 1 -let minusO_VVV = arith_op_overflow_vec integerMinus false 1 -let minusSO_VVV = arith_op_overflow_vec integerMinus true 1 -let multO_VVV = arith_op_overflow_vec integerMult false 2 -let multSO_VVV = arith_op_overflow_vec integerMult true 2 - -val arith_op_overflow_vec_bit : forall 'a 'b. Size 'a, Size 'b => (integer -> integer -> integer) -> bool -> integer -> - bitvector 'a -> bitU -> bitvector 'b * bitU * bool -let rec arith_op_overflow_vec_bit (op : integer -> integer -> integer) sign (size : integer) - (Bitvector _ _ is_inc as l) r_bit = - let act_size = bvlength l * size in - let l' = to_num sign l in - let l_u = to_num false l in - let (n,nu,changed) = match r_bit with - | B1 -> (op l' 1, op l_u 1, true) - | B0 -> (l',l_u,false) - | BU -> failwith "arith_op_overflow_vec_bit applied to undefined bit" - end in -(* | _ -> assert false *) - let correct_size_num = to_vec_ord is_inc (act_size,n) in - let one_larger = to_vec_ord is_inc (act_size + 1,nu) in - let overflow = - if changed - then - if n <= get_max_representable_in sign act_size && n >= get_min_representable_in sign act_size - then B0 else B1 - else B0 in - (correct_size_num,overflow,most_significant one_larger) - -(* add_overflow_vec_bit_signed - * minus_overflow_vec_bit - * minus_overflow_vec_bit_signed - *) -let addSO_VBV = arith_op_overflow_vec_bit integerAdd true 1 -let minusO_VBV = arith_op_overflow_vec_bit integerMinus false 1 -let minusSO_VBV = arith_op_overflow_vec_bit integerMinus true 1 -*) -type shift = LL_shift | RR_shift | LLL_shift - -let shift_op_vec op (bs, (n : integer)) = - let n = natFromInteger n in - match op with - | LL_shift (*"<<"*) -> - shiftLeft bs n - | RR_shift (*">>"*) -> - shiftRight bs n - | LLL_shift (*"<<<"*) -> - rotateLeft n bs - end - -let bitwise_leftshift x = shift_op_vec LL_shift x (*"<<"*) -let bitwise_rightshift x = shift_op_vec RR_shift x (*">>"*) -let bitwise_rotate x = shift_op_vec LLL_shift x (*"<<<"*) - -let shiftl = bitwise_leftshift -let shiftr = bitwise_rightshift - -let rec arith_op_no0 (op : integer -> integer -> integer) l r = - if r = 0 - then Nothing - else Just (op l r) -(* TODO -let rec arith_op_vec_no0 (op : integer -> integer -> integer) sign size ((Bitvector _ start is_inc) as l) r = - let act_size = bvlength l * size in - let (l',r') = (to_num sign l,to_num sign r) in - let n = arith_op_no0 op l' r' in - let (representable,n') = - match n with - | Just n' -> - (n' <= get_max_representable_in sign act_size && - n' >= get_min_representable_in sign act_size, n') - | _ -> (false,0) - end in - if representable - then to_vec_ord is_inc (act_size,n') - else Vector (List.replicate (natFromInteger act_size) BU) start is_inc - -let mod_VVV = arith_op_vec_no0 hardware_mod false 1 -let quot_VVV = arith_op_vec_no0 hardware_quot false 1 -let quotS_VVV = arith_op_vec_no0 hardware_quot true 1 - -let arith_op_overflow_no0_vec op sign size ((Vector _ start is_inc) as l) r = - let rep_size = length r * size in - let act_size = length l * size in - let (l',r') = (to_num sign l,to_num sign r) in - let (l_u,r_u) = (to_num false l,to_num false r) in - let n = arith_op_no0 op l' r' in - let n_u = arith_op_no0 op l_u r_u in - let (representable,n',n_u') = - match (n, n_u) with - | (Just n',Just n_u') -> - ((n' <= get_max_representable_in sign rep_size && - n' >= (get_min_representable_in sign rep_size)), n', n_u') - | _ -> (true,0,0) - end in - let (correct_size_num,one_more) = - if representable then - (to_vec_ord is_inc (act_size,n'),to_vec_ord is_inc (act_size + 1,n_u')) - else - (Vector (List.replicate (natFromInteger act_size) BU) start is_inc, - Vector (List.replicate (natFromInteger (act_size + 1)) BU) start is_inc) in - let overflow = if representable then B0 else B1 in - (correct_size_num,overflow,most_significant one_more) - -let quotO_VVV = arith_op_overflow_no0_vec hardware_quot false 1 -let quotSO_VVV = arith_op_overflow_no0_vec hardware_quot true 1 - -let arith_op_vec_range_no0 op sign size (Vector _ _ is_inc as l) r = - arith_op_vec_no0 op sign size l (to_vec_ord is_inc (length l,r)) - -let mod_VIV = arith_op_vec_range_no0 hardware_mod false 1 -*) - -let duplicate (bit, length) = - vec_to_bvec (Vector (repeat [bit] length) (length - 1) false) - -(* TODO: replace with better native versions *) -let replicate_bits (v, count) = - let v = bvec_to_vec true 0 v in - vec_to_bvec (Vector (repeat (get_elems v) count) ((length v * count) - 1) false) - -let compare_op op (l,r) = (op l r) - -let lt = compare_op (<) -let gt = compare_op (>) -let lteq = compare_op (<=) -let gteq = compare_op (>=) - -let compare_op_vec op sign (l,r) = - let (l',r') = (to_num sign l, to_num sign r) in - compare_op op (l',r') - -let lt_vec x = compare_op_vec (<) true x -let gt_vec x = compare_op_vec (>) true x -let lteq_vec x = compare_op_vec (<=) true x -let gteq_vec x = compare_op_vec (>=) true x - -let lt_vec_signed x = compare_op_vec (<) true x -let gt_vec_signed x = compare_op_vec (>) true x -let lteq_vec_signed x = compare_op_vec (<=) true x -let gteq_vec_signed x = compare_op_vec (>=) true x -let lt_vec_unsigned x = compare_op_vec (<) false x -let gt_vec_unsigned x = compare_op_vec (>) false x -let lteq_vec_unsigned x = compare_op_vec (<=) false x -let gteq_vec_unsigned x = compare_op_vec (>=) false x - -let lt_svec = lt_vec_signed - -let compare_op_vec_range op sign (l,r) = - compare_op op ((to_num sign l),r) - -let lt_vec_range x = compare_op_vec_range (<) true x -let gt_vec_range x = compare_op_vec_range (>) true x -let lteq_vec_range x = compare_op_vec_range (<=) true x -let gteq_vec_range x = compare_op_vec_range (>=) true x - -let compare_op_range_vec op sign (l,r) = - compare_op op (l, (to_num sign r)) - -let lt_range_vec x = compare_op_range_vec (<) true x -let gt_range_vec x = compare_op_range_vec (>) true x -let lteq_range_vec x = compare_op_range_vec (<=) true x -let gteq_range_vec x = compare_op_range_vec (>=) true x - -val eq : forall 'a. Eq 'a => 'a * 'a -> bool -let eq (l,r) = (l = r) -let eq_range (l,r) = (l = r) - -val eq_vec : forall 'a. Size 'a => bitvector 'a * bitvector 'a -> bool -let eq_vec (l,r) = eq (to_num false l, to_num false r) -let eq_bit (l,r) = eq (l, r) -let eq_vec_range (l,r) = eq (to_num false l,r) -let eq_range_vec (l,r) = eq (l, to_num false r) -(*let eq_vec_vec (l,r) = eq (to_num true l, to_num true r)*) - -let neq (l,r) = not (eq (l,r)) -let neq_bit (l,r) = not (eq_bit (l,r)) -let neq_range (l,r) = not (eq_range (l,r)) -let neq_vec (l,r) = not (eq_vec (l,r)) -(*let neq_vec_vec (l,r) = not (eq_vec_vec (l,r))*) -let neq_vec_range (l,r) = not (eq_vec_range (l,r)) -let neq_range_vec (l,r) = not (eq_range_vec (l,r)) - - -val make_indexed_vector : forall 'a. list (integer * 'a) -> 'a -> integer -> integer -> bool -> vector 'a -let make_indexed_vector entries default start length dir = - let length = natFromInteger length in - Vector (List.foldl replace (replicate length default) entries) start dir - -(* -val make_bit_vector_undef : integer -> vector bitU -let make_bitvector_undef length = - Vector (replicate (natFromInteger length) BU) 0 true - *) - -(* let bitwise_not_range_bit n = bitwise_not (to_vec_ord defaultDir n) *) - -(* TODO *) -val mask : forall 'a 'b. Size 'b => (integer * integer * bitvector 'a) -> bitvector 'b -let mask (start, _, w) = (zeroExtend w) - -(* Register operations *) - -(*let update_reg_range reg i j reg_val new_val = bvupdate (reg.reg_is_inc) (reg.reg_start) reg_val i j new_val -let update_reg_pos reg i reg_val bit = bvupdate_pos (reg.reg_is_inc) (reg.reg_start) reg_val i bit -let update_reg_field_range regfield i j reg_val new_val = - let current_field_value = regfield.get_field reg_val in - let new_field_value = bvupdate (regfield.field_is_inc) (regfield.field_start) current_field_value i j new_val in - regfield.set_field reg_val new_field_value -(*let write_reg_field_pos regfield i reg_val bit = - let current_field_value = regfield.get_field reg_val in - let new_field_value = bvupdate_pos (regfield.field_is_inc) (regfield.field_start) current_field_value i bit in - regfield.set_field reg_val new_field_value*)*) +open import Sail_operators + +(* Specialisation of operators to machine words *) + +val access_vec_inc : forall 'a. Size 'a => mword 'a -> integer -> bitU +let access_vec_inc = access_bv_inc + +val access_vec_dec : forall 'a. Size 'a => mword 'a -> integer -> bitU +let access_vec_dec = access_bv_dec + +val update_vec_inc : forall 'a. Size 'a => mword 'a -> integer -> bitU -> mword 'a +let update_vec_inc = update_bv_inc + +val update_vec_dec : forall 'a. Size 'a => mword 'a -> integer -> bitU -> mword 'a +let update_vec_dec = update_bv_dec + +val subrange_vec_inc : forall 'a 'b. Size 'a, Size 'b => mword 'a -> integer -> integer -> mword 'b +let subrange_vec_inc = subrange_bv_inc + +val subrange_vec_dec : forall 'a 'b. Size 'a, Size 'b => mword 'a -> integer -> integer -> mword 'b +let subrange_vec_dec = subrange_bv_dec + +val update_subrange_vec_inc : forall 'a 'b. Size 'a, Size 'b => mword 'a -> integer -> integer -> mword 'b -> mword 'a +let update_subrange_vec_inc = update_subrange_bv_inc + +val update_subrange_vec_dec : forall 'a 'b. Size 'a, Size 'b => mword 'a -> integer -> integer -> mword 'b -> mword 'a +let update_subrange_vec_dec = update_subrange_bv_dec + +val extz_vec : forall 'a 'b. Size 'a, Size 'b => integer -> mword 'a -> mword 'b +let extz_vec = extz_bv + +val exts_vec : forall 'a 'b. Size 'a, Size 'b => integer -> mword 'a -> mword 'b +let exts_vec = exts_bv + +val concat_vec : forall 'a 'b 'c. Size 'a, Size 'b, Size 'c => mword 'a -> mword 'b -> mword 'c +let concat_vec = concat_bv + +val cons_vec : forall 'a 'b 'c. Size 'a, Size 'b => bitU -> mword 'a -> mword 'b +let cons_vec = cons_bv + +val bool_of_vec : mword ty1 -> bitU +let bool_of_vec = bool_of_bv + +val cast_unit_vec : bitU -> mword ty1 +let cast_unit_vec = cast_unit_bv + +val vec_of_bit : forall 'a. Size 'a => integer -> bitU -> mword 'a +let vec_of_bit = bv_of_bit + +val msb : forall 'a. Size 'a => mword 'a -> bitU +let msb = most_significant + +val int_of_vec : forall 'a. Size 'a => bool -> mword 'a -> integer +let int_of_vec = int_of_bv + +val and_vec : forall 'a. Size 'a => mword 'a -> mword 'a -> mword 'a +val or_vec : forall 'a. Size 'a => mword 'a -> mword 'a -> mword 'a +val xor_vec : forall 'a. Size 'a => mword 'a -> mword 'a -> mword 'a +val not_vec : forall 'a. Size 'a => mword 'a -> mword 'a +let and_vec = and_bv +let or_vec = or_bv +let xor_vec = xor_bv +let not_vec = not_bv + +val add_vec : forall 'a. Size 'a => mword 'a -> mword 'a -> mword 'a +val addS_vec : forall 'a. Size 'a => mword 'a -> mword 'a -> mword 'a +val sub_vec : forall 'a. Size 'a => mword 'a -> mword 'a -> mword 'a +val mult_vec : forall 'a 'b. Size 'a, Size 'b => mword 'a -> mword 'a -> mword 'b +val multS_vec : forall 'a 'b. Size 'a, Size 'b => mword 'a -> mword 'a -> mword 'b +let add_vec = add_bv +let addS_vec = addS_bv +let sub_vec = sub_bv +let mult_vec = mult_bv +let multS_vec = multS_bv + +val add_vec_int : forall 'a. Size 'a => mword 'a -> integer -> mword 'a +val addS_vec_int : forall 'a. Size 'a => mword 'a -> integer -> mword 'a +val sub_vec_int : forall 'a. Size 'a => mword 'a -> integer -> mword 'a +val mult_vec_int : forall 'a 'b. Size 'a, Size 'b => mword 'a -> integer -> mword 'b +val multS_vec_int : forall 'a 'b. Size 'a, Size 'b => mword 'a -> integer -> mword 'b +let add_vec_int = add_bv_int +let addS_vec_int = addS_bv_int +let sub_vec_int = sub_bv_int +let mult_vec_int = mult_bv_int +let multS_vec_int = multS_bv_int + +val add_int_vec : forall 'a. Size 'a => integer -> mword 'a -> mword 'a +val addS_int_vec : forall 'a. Size 'a => integer -> mword 'a -> mword 'a +val sub_int_vec : forall 'a. Size 'a => integer -> mword 'a -> mword 'a +val mult_int_vec : forall 'a 'b. Size 'a, Size 'b => integer -> mword 'a -> mword 'b +val multS_int_vec : forall 'a 'b. Size 'a, Size 'b => integer -> mword 'a -> mword 'b +let add_int_vec = add_int_bv +let addS_int_vec = addS_int_bv +let sub_int_vec = sub_int_bv +let mult_int_vec = mult_int_bv +let multS_int_vec = multS_int_bv + +val add_vec_bit : forall 'a. Size 'a => mword 'a -> bitU -> mword 'a +val addS_vec_bit : forall 'a. Size 'a => mword 'a -> bitU -> mword 'a +val sub_vec_bit : forall 'a. Size 'a => mword 'a -> bitU -> mword 'a +let add_vec_bit = add_bv_bit +let addS_vec_bit = addS_bv_bit +let sub_vec_bit = sub_bv_bit + +val add_overflow_vec : forall 'a. Size 'a => mword 'a -> mword 'a -> (mword 'a * bitU * bitU) +val add_overflow_vec_signed : forall 'a. Size 'a => mword 'a -> mword 'a -> (mword 'a * bitU * bitU) +val sub_overflow_vec : forall 'a. Size 'a => mword 'a -> mword 'a -> (mword 'a * bitU * bitU) +val sub_overflow_vec_signed : forall 'a. Size 'a => mword 'a -> mword 'a -> (mword 'a * bitU * bitU) +val mult_overflow_vec : forall 'a. Size 'a => mword 'a -> mword 'a -> (mword 'a * bitU * bitU) +val mult_overflow_vec_signed : forall 'a. Size 'a => mword 'a -> mword 'a -> (mword 'a * bitU * bitU) +let add_overflow_vec = add_overflow_bv +let add_overflow_vec_signed = add_overflow_bv_signed +let sub_overflow_vec = sub_overflow_bv +let sub_overflow_vec_signed = sub_overflow_bv_signed +let mult_overflow_vec = mult_overflow_bv +let mult_overflow_vec_signed = mult_overflow_bv_signed + +val add_overflow_vec_bit : forall 'a. Size 'a => mword 'a -> bitU -> (mword 'a * bitU * bitU) +val add_overflow_vec_bit_signed : forall 'a. Size 'a => mword 'a -> bitU -> (mword 'a * bitU * bitU) +val sub_overflow_vec_bit : forall 'a. Size 'a => mword 'a -> bitU -> (mword 'a * bitU * bitU) +val sub_overflow_vec_bit_signed : forall 'a. Size 'a => mword 'a -> bitU -> (mword 'a * bitU * bitU) +let add_overflow_vec_bit = add_overflow_bv_bit +let add_overflow_vec_bit_signed = add_overflow_bv_bit_signed +let sub_overflow_vec_bit = sub_overflow_bv_bit +let sub_overflow_vec_bit_signed = sub_overflow_bv_bit_signed + +val shiftl : forall 'a. Size 'a => mword 'a -> integer -> mword 'a +val shiftr : forall 'a. Size 'a => mword 'a -> integer -> mword 'a +val arith_shiftr : forall 'a. Size 'a => mword 'a -> integer -> mword 'a +val rotl : forall 'a. Size 'a => mword 'a -> integer -> mword 'a +val rotr : forall 'a. Size 'a => mword 'a -> integer -> mword 'a +let shiftl = shiftl_bv +let shiftr = shiftr_bv +let arith_shiftr = arith_shiftr_bv +let rotl = rotl_bv +let rotr = rotr_bv + +val mod_vec : forall 'a. Size 'a => mword 'a -> mword 'a -> mword 'a +val quot_vec : forall 'a. Size 'a => mword 'a -> mword 'a -> mword 'a +val quot_vec_signed : forall 'a. Size 'a => mword 'a -> mword 'a -> mword 'a +let mod_vec = mod_bv +let quot_vec = quot_bv +let quot_vec_signed = quot_bv_signed + +val mod_vec_int : forall 'a. Size 'a => mword 'a -> integer -> mword 'a +val quot_vec_int : forall 'a. Size 'a => mword 'a -> integer -> mword 'a +let mod_vec_int = mod_bv_int +let quot_vec_int = quot_bv_int + +val replicate_bits : forall 'a 'b. Size 'a, Size 'b => mword 'a -> integer -> mword 'b +let replicate_bits = replicate_bits_bv + +val duplicate : forall 'a. Size 'a => bitU -> integer -> mword 'a +let duplicate = duplicate_bit_bv + +val eq_vec : forall 'a. Size 'a => mword 'a -> mword 'a -> bool +val neq_vec : forall 'a. Size 'a => mword 'a -> mword 'a -> bool +val ult_vec : forall 'a. Size 'a => mword 'a -> mword 'a -> bool +val slt_vec : forall 'a. Size 'a => mword 'a -> mword 'a -> bool +val ugt_vec : forall 'a. Size 'a => mword 'a -> mword 'a -> bool +val sgt_vec : forall 'a. Size 'a => mword 'a -> mword 'a -> bool +val ulteq_vec : forall 'a. Size 'a => mword 'a -> mword 'a -> bool +val slteq_vec : forall 'a. Size 'a => mword 'a -> mword 'a -> bool +val ugteq_vec : forall 'a. Size 'a => mword 'a -> mword 'a -> bool +val sgteq_vec : forall 'a. Size 'a => mword 'a -> mword 'a -> bool +let eq_vec = eq_bv +let neq_vec = neq_bv +let ult_vec = ult_bv +let slt_vec = slt_bv +let ugt_vec = ugt_bv +let sgt_vec = sgt_bv +let ulteq_vec = ulteq_bv +let slteq_vec = slteq_bv +let ugteq_vec = ugteq_bv +let sgteq_vec = sgteq_bv diff --git a/src/gen_lib/sail_values.lem b/src/gen_lib/sail_values.lem index 8aee556d..50dacf5e 100644 --- a/src/gen_lib/sail_values.lem +++ b/src/gen_lib/sail_values.lem @@ -13,6 +13,17 @@ let pow m n = m ** (natFromInteger n) let pow2 n = pow 2 n +let inline lt = (<) +let inline gt = (>) +let inline lteq = (<=) +let inline gteq = (>=) + +val eq : forall 'a. Eq 'a => 'a -> 'a -> bool +let inline eq l r = (l = r) + +val neq : forall 'a. Eq 'a => 'a -> 'a -> bool +let inline neq l r = (l <> r) + (*let add_int l r = integerAdd l r let add_signed l r = integerAdd l r let sub_int l r = integerMinus l r @@ -58,6 +69,34 @@ let rec replace bs (n : integer) b' = match bs with let upper n = n +(* Modulus operation corresponding to quot below -- result + has sign of dividend. *) +let hardware_mod (a: integer) (b:integer) : integer = + let m = (abs a) mod (abs b) in + if a < 0 then ~m else m + +(* There are different possible answers for integer divide regarding +rounding behaviour on negative operands. Positive operands always +round down so derive the one we want (trucation towards zero) from +that *) +let hardware_quot (a:integer) (b:integer) : integer = + let q = (abs a) / (abs b) in + if ((a<0) = (b<0)) then + q (* same sign -- result positive *) + else + ~q (* different sign -- result negative *) + +let max_64u = (integerPow 2 64) - 1 +let max_64 = (integerPow 2 63) - 1 +let min_64 = 0 - (integerPow 2 63) +let max_32u = (4294967295 : integer) +let max_32 = (2147483647 : integer) +let min_32 = (0 - 2147483648 : integer) +let max_8 = (127 : integer) +let min_8 = (0 - 128 : integer) +let max_5 = (31 : integer) +let min_5 = (0 - 32 : integer) + (*** Bits *) type bitU = B0 | B1 | BU @@ -83,7 +122,7 @@ end let bool_of_bitU = function | B0 -> false - | B1 -> true + | B1 -> true | BU -> failwith "bool_of_bitU applied to BU" end @@ -272,34 +311,34 @@ let show_bitlist bs = let inline (^^) = append_list -val slice_list_inc : forall 'a. list 'a -> integer -> integer -> list 'a -let slice_list_inc xs i j = +val subrange_list_inc : forall 'a. list 'a -> integer -> integer -> list 'a +let subrange_list_inc xs i j = let (toJ,_suffix) = List.splitAt (natFromInteger j + 1) xs in let (_prefix,fromItoJ) = List.splitAt (natFromInteger i) toJ in fromItoJ -val slice_list_dec : forall 'a. list 'a -> integer -> integer -> list 'a -let slice_list_dec xs i j = +val subrange_list_dec : forall 'a. list 'a -> integer -> integer -> list 'a +let subrange_list_dec xs i j = let top = (length_list xs) - 1 in - slice_list_inc xs (top - i) (top - j) + subrange_list_inc xs (top - i) (top - j) -val slice_list : forall 'a. bool -> list 'a -> integer -> integer -> list 'a -let slice_list is_inc xs i j = if is_inc then slice_list_inc xs i j else slice_list_dec xs i j +val subrange_list : forall 'a. bool -> list 'a -> integer -> integer -> list 'a +let subrange_list is_inc xs i j = if is_inc then subrange_list_inc xs i j else subrange_list_dec xs i j -val update_slice_list_inc : forall 'a. list 'a -> integer -> integer -> list 'a -> list 'a -let update_slice_list_inc xs i j xs' = +val update_subrange_list_inc : forall 'a. list 'a -> integer -> integer -> list 'a -> list 'a +let update_subrange_list_inc xs i j xs' = let (toJ,suffix) = List.splitAt (natFromInteger j + 1) xs in let (prefix,_fromItoJ) = List.splitAt (natFromInteger i) toJ in prefix ++ xs' ++ suffix -val update_slice_list_dec : forall 'a. list 'a -> integer -> integer -> list 'a -> list 'a -let update_slice_list_dec xs i j xs' = +val update_subrange_list_dec : forall 'a. list 'a -> integer -> integer -> list 'a -> list 'a +let update_subrange_list_dec xs i j xs' = let top = (length_list xs) - 1 in - update_slice_list_inc xs (top - i) (top - j) xs' + update_subrange_list_inc xs (top - i) (top - j) xs' -val update_slice_list : forall 'a. bool -> list 'a -> integer -> integer -> list 'a -> list 'a -let update_slice_list is_inc xs i j xs' = - if is_inc then update_slice_list_inc xs i j xs' else update_slice_list_dec xs i j xs' +val update_subrange_list : forall 'a. bool -> list 'a -> integer -> integer -> list 'a -> list 'a +let update_subrange_list is_inc xs i j xs' = + if is_inc then update_subrange_list_inc xs i j xs' else update_subrange_list_dec xs i j xs' val access_list_inc : forall 'a. list 'a -> integer -> 'a let access_list_inc xs n = List_extra.nth xs (natFromInteger n) @@ -383,11 +422,28 @@ val update_mword : forall 'a. bool -> mword 'a -> integer -> bitU -> mword 'a let update_mword is_inc w n b = if is_inc then update_mword_inc w n b else update_mword_dec w n b +val mword_of_int : forall 'a. Size 'a => integer -> integer -> mword 'a +let mword_of_int len n = + let w = wordFromInteger n in + if (length_mword w = len) then w else failwith "unexpected word length" + +(* Translating between a type level number (itself 'n) and an integer *) + +let size_itself_int x = integerFromNat (size_itself x) + +(* NB: the corresponding sail type is forall 'n. atom('n) -> itself('n), + the actual integer is ignored. *) + +val make_the_value : forall 'n. integer -> itself 'n +let inline make_the_value x = the_value + (*** Bitvectors *) class (Bitvector 'a) val bits_of : 'a -> list bitU val of_bits : list bitU -> 'a + (* The first parameter specifies the desired length of the bitvector *) + val of_int : integer -> integer -> 'a val length : 'a -> integer val unsigned : 'a -> integer val signed : 'a -> integer @@ -401,44 +457,46 @@ end instance forall 'a. BitU 'a => (Bitvector (list 'a)) let bits_of v = List.map to_bitU v let of_bits v = List.map of_bitU v - let length v = length_list v + let of_int len n = List.map of_bitU (bits_of_int len n) + let length = length_list let unsigned v = unsigned_of_bits (List.map to_bitU v) let signed v = signed_of_bits (List.map to_bitU v) let get_bit is_inc v n = to_bitU (access_list is_inc v n) let set_bit is_inc v n b = update_list is_inc v n (of_bitU b) - let get_bits is_inc v i j = List.map to_bitU (slice_list is_inc v i j) - let set_bits is_inc v i j v' = update_slice_list is_inc v i j (List.map of_bitU v') + let get_bits is_inc v i j = List.map to_bitU (subrange_list is_inc v i j) + let set_bits is_inc v i j v' = update_subrange_list is_inc v i j (List.map of_bitU v') end instance forall 'a. Size 'a => (Bitvector (mword 'a)) let bits_of v = List.map to_bitU (bitlistFromWord v) let of_bits v = wordFromBitlist (List.map of_bitU v) + let of_int = mword_of_int let length v = integerFromNat (word_length v) - let unsigned v = unsignedIntegerFromWord v - let signed v = signedIntegerFromWord v + let unsigned = unsignedIntegerFromWord + let signed = signedIntegerFromWord let get_bit = access_mword let set_bit = update_mword let get_bits is_inc v i j = get_bits is_inc (bitlistFromWord v) i j let set_bits is_inc v i j v' = wordFromBitlist (set_bits is_inc (bitlistFromWord v) i j v') end -let access_vec_inc v n = get_bit true v n -let access_vec_dec v n = get_bit false v n +let access_bv_inc v n = get_bit true v n +let access_bv_dec v n = get_bit false v n -let update_vec_inc v n b = set_bit true v n b -let update_vec_dec v n b = set_bit false v n b +let update_bv_inc v n b = set_bit true v n b +let update_bv_dec v n b = set_bit false v n b -let subrange_vec_inc v i j = of_bits (get_bits true v i j) -let subrange_vec_dec v i j = of_bits (get_bits false v i j) +let subrange_bv_inc v i j = of_bits (get_bits true v i j) +let subrange_bv_dec v i j = of_bits (get_bits false v i j) -let update_subrange_vec_inc v i j v' = set_bits true v i j (bits_of v') -let update_subrange_vec_dec v i j v' = set_bits false v i j (bits_of v') +let update_subrange_bv_inc v i j v' = set_bits true v i j (bits_of v') +let update_subrange_bv_dec v i j v' = set_bits false v i j (bits_of v') -val extz_vec : forall 'a 'b. Bitvector 'a, Bitvector 'b => integer -> 'a -> 'b -let extz_vec n v = of_bits (extz_bits n (bits_of v)) +val extz_bv : forall 'a 'b. Bitvector 'a, Bitvector 'b => integer -> 'a -> 'b +let extz_bv n v = of_bits (extz_bits n (bits_of v)) -val exts_vec : forall 'a 'b. Bitvector 'a, Bitvector 'b => integer -> 'a -> 'b -let exts_vec n v = of_bits (exts_bits n (bits_of v)) +val exts_bv : forall 'a 'b. Bitvector 'a, Bitvector 'b => integer -> 'a -> 'b +let exts_bv n v = of_bits (exts_bits n (bits_of v)) (*** Bytes and addresses *) @@ -584,13 +642,13 @@ let rec external_reg_value reg_name v = match reg_name with | Reg _ start size dir -> (start, (if dir = D_increasing then start else (start - (size +1))), dir) - | Reg_slice _ reg_start dir (slice_start, slice_end) -> + | Reg_slice _ reg_start dir (slice_start, _) -> ((if dir = D_increasing then slice_start else (reg_start - slice_start)), slice_start, dir) - | Reg_field _ reg_start dir _ (slice_start, slice_end) -> + | Reg_field _ reg_start dir _ (slice_start, _) -> ((if dir = D_increasing then slice_start else (reg_start - slice_start)), slice_start, dir) - | Reg_f_slice _ reg_start dir _ _ (slice_start, slice_end) -> + | Reg_f_slice _ reg_start dir _ _ (slice_start, _) -> ((if dir = D_increasing then slice_start else (reg_start - slice_start)), slice_start, dir) end in diff --git a/src/gen_lib/state.lem b/src/gen_lib/state.lem index b6852aaf..ac6d55b5 100644 --- a/src/gen_lib/state.lem +++ b/src/gen_lib/state.lem @@ -1,253 +1,8 @@ open import Pervasives_extra open import Sail_impl_base open import Sail_values - -(* 'a is result type *) - -type memstate = map integer memory_byte -type tagstate = map integer bitU -(* type regstate = map string (vector bitU) *) - -type sequential_state 'regs = - <| regstate : 'regs; - memstate : memstate; - tagstate : tagstate; - write_ea : maybe (write_kind * integer * integer); - last_exclusive_operation_was_load : bool|> - -val init_state : forall 'regs. 'regs -> sequential_state 'regs -let init_state regs = - <| regstate = regs; - memstate = Map.empty; - tagstate = Map.empty; - write_ea = Nothing; - last_exclusive_operation_was_load = false |> - -type ex 'e = - | Exit - | Assert of string - | Throw of 'e - -type result 'a 'e = - | Value of 'a - | Exception of (ex 'e) - -(* State, nondeterminism and exception monad with result value type 'a - and exception type 'e. *) -type M 'regs 'a 'e = sequential_state 'regs -> list (result 'a 'e * sequential_state 'regs) - -val return : forall 'regs 'a 'e. 'a -> M 'regs 'a 'e -let return a s = [(Value a,s)] - -val bind : forall 'regs 'a 'b 'e. M 'regs 'a 'e -> ('a -> M 'regs 'b 'e) -> M 'regs 'b 'e -let bind m f (s : sequential_state 'regs) = - List.concatMap (function - | (Value a, s') -> f a s' - | (Exception e, s') -> [(Exception e, s')] - end) (m s) - -let inline (>>=) = bind -val (>>): forall 'regs 'b 'e. M 'regs unit 'e -> M 'regs 'b 'e -> M 'regs 'b 'e -let inline (>>) m n = m >>= fun (_ : unit) -> n - -val throw : forall 'regs 'a 'e. 'e -> M 'regs 'a 'e -let throw e s = [(Exception (Throw e), s)] - -val try_catch : forall 'regs 'a 'e1 'e2. M 'regs 'a 'e1 -> ('e1 -> M 'regs 'a 'e2) -> M 'regs 'a 'e2 -let try_catch m h s = - List.concatMap (function - | (Value a, s') -> return a s' - | (Exception (Throw e), s') -> h e s' - | (Exception Exit, s') -> [(Exception Exit, s')] - | (Exception (Assert msg), s') -> [(Exception (Assert msg), s')] - end) (m s) - -val exit : forall 'regs 'e 'a. unit -> M 'regs 'a 'e -let exit () s = [(Exception Exit, s)] - -val assert_exp : forall 'regs 'e. bool -> string -> M 'regs unit 'e -let assert_exp exp msg s = if exp then [(Value (), s)] else [(Exception (Assert msg), s)] - -(* For early return, we abuse exceptions by throwing and catching - the return value. The exception type is "either 'r 'e", where "Right e" - represents a proper exception and "Left r" an early return of value "r". *) -type MR 'regs 'a 'r 'e = M 'regs 'a (either 'r 'e) - -val early_return : forall 'regs 'a 'r 'e. 'r -> MR 'regs 'a 'r 'e -let early_return r = throw (Left r) - -val catch_early_return : forall 'regs 'a 'e. MR 'regs 'a 'a 'e -> M 'regs 'a 'e -let catch_early_return m = - try_catch m - (function - | Left a -> return a - | Right e -> throw e - end) - -(* Lift to monad with early return by wrapping exceptions *) -val liftR : forall 'a 'r 'regs 'e. M 'regs 'a 'e -> MR 'regs 'a 'r 'e -let liftR m = try_catch m (fun e -> throw (Right e)) - -(* Catch exceptions in the presence of early returns *) -val try_catchR : forall 'regs 'a 'r 'e1 'e2. MR 'regs 'a 'r 'e1 -> ('e1 -> MR 'regs 'a 'r 'e2) -> MR 'regs 'a 'r 'e2 -let try_catchR m h = - try_catch m - (function - | Left r -> throw (Left r) - | Right e -> h e - end) - -val range : integer -> integer -> list integer -let rec range i j = - if j < i then [] - else if i = j then [i] - else i :: range (i+1) j - -val get_reg : forall 'regs 'a. sequential_state 'regs -> register_ref 'regs 'a -> 'a -let get_reg state reg = reg.read_from state.regstate - -val set_reg : forall 'regs 'a. sequential_state 'regs -> register_ref 'regs 'a -> 'a -> sequential_state 'regs -let set_reg state reg v = - <| state with regstate = reg.write_to state.regstate v |> - - -let is_exclusive = function - | Sail_impl_base.Read_plain -> false - | Sail_impl_base.Read_reserve -> true - | Sail_impl_base.Read_acquire -> false - | Sail_impl_base.Read_exclusive -> true - | Sail_impl_base.Read_exclusive_acquire -> true - | Sail_impl_base.Read_stream -> false - | Sail_impl_base.Read_RISCV_acquire -> false - | Sail_impl_base.Read_RISCV_strong_acquire -> false - | Sail_impl_base.Read_RISCV_reserved -> true - | Sail_impl_base.Read_RISCV_reserved_acquire -> true - | Sail_impl_base.Read_RISCV_reserved_strong_acquire -> true - | Sail_impl_base.Read_X86_locked -> true -end - - -val read_mem : forall 'regs 'a 'b 'e. Bitvector 'a, Bitvector 'b => read_kind -> 'a -> integer -> M 'regs 'b 'e -let read_mem read_kind addr sz state = - let addr = unsigned addr in - let addrs = range addr (addr+sz-1) in - let memory_value = List.map (fun addr -> Map_extra.find addr state.memstate) addrs in - let value = of_bits (Sail_values.internal_mem_value memory_value) in - if is_exclusive read_kind - then [(Value value, <| state with last_exclusive_operation_was_load = true |>)] - else [(Value value, state)] - -(* caps are aligned at 32 bytes *) -let cap_alignment = (32 : integer) - -val read_tag : forall 'regs 'a 'e. Bitvector 'a => read_kind -> 'a -> M 'regs bitU 'e -let read_tag read_kind addr state = - let addr = (unsigned addr) / cap_alignment in - let tag = match (Map.lookup addr state.tagstate) with - | Just t -> t - | Nothing -> B0 - end in - if is_exclusive read_kind - then [(Value tag, <| state with last_exclusive_operation_was_load = true |>)] - else [(Value tag, state)] - -val excl_result : forall 'regs 'e. unit -> M 'regs bool 'e -let excl_result () state = - let success = - (Value true, <| state with last_exclusive_operation_was_load = false |>) in - (Value false, state) :: if state.last_exclusive_operation_was_load then [success] else [] - -val write_mem_ea : forall 'regs 'a 'e. Bitvector 'a => write_kind -> 'a -> integer -> M 'regs unit 'e -let write_mem_ea write_kind addr sz state = - [(Value (), <| state with write_ea = Just (write_kind,unsigned addr,sz) |>)] - -val write_mem_val : forall 'a 'regs 'b 'e. Bitvector 'a => 'a -> M 'regs bool 'e -let write_mem_val v state = - let (write_kind,addr,sz) = match state.write_ea with - | Nothing -> failwith "write ea has not been announced yet" - | Just write_ea -> write_ea end in - let addrs = range addr (addr+sz-1) in - let v = external_mem_value (bits_of v) in - let addresses_with_value = List.zip addrs v in - let memstate = List.foldl (fun mem (addr,v) -> Map.insert addr v mem) - state.memstate addresses_with_value in - [(Value true, <| state with memstate = memstate |>)] - -val write_tag : forall 'regs 'e. bitU -> M 'regs bool 'e -let write_tag t state = - let (write_kind,addr,sz) = match state.write_ea with - | Nothing -> failwith "write ea has not been announced yet" - | Just write_ea -> write_ea end in - let taddr = addr / cap_alignment in - let tagstate = Map.insert taddr t state.tagstate in - [(Value true, <| state with tagstate = tagstate |>)] - -val read_reg : forall 'regs 'a 'e. register_ref 'regs 'a -> M 'regs 'a 'e -let read_reg reg state = - let v = reg.read_from state.regstate in - [(Value v,state)] -(*let read_reg_range reg i j state = - let v = slice (get_reg state (name_of_reg reg)) i j in - [(Value (vec_to_bvec v),state)] -let read_reg_bit reg i state = - let v = access (get_reg state (name_of_reg reg)) i in - [(Value v,state)] -let read_reg_field reg regfield = - let (i,j) = register_field_indices reg regfield in - read_reg_range reg i j -let read_reg_bitfield reg regfield = - let (i,_) = register_field_indices reg regfield in - read_reg_bit reg i *) - -let reg_deref = read_reg - -val write_reg : forall 'regs 'a 'e. register_ref 'regs 'a -> 'a -> M 'regs unit 'e -let write_reg reg v state = - [(Value (), <| state with regstate = reg.write_to state.regstate v |>)] - -let write_reg_ref (reg, v) = write_reg reg v - -val update_reg : forall 'regs 'a 'b 'e. register_ref 'regs 'a -> ('a -> 'b -> 'a) -> 'b -> M 'regs unit 'e -let update_reg reg f v state = - let current_value = get_reg state reg in - let new_value = f current_value v in - [(Value (), set_reg state reg new_value)] - -let write_reg_field reg regfield = update_reg reg regfield.set_field - -val update_reg_range : forall 'regs 'a 'b. Bitvector 'a, Bitvector 'b => register_ref 'regs 'a -> integer -> integer -> 'a -> 'b -> 'a -let update_reg_range reg i j reg_val new_val = set_bits (reg.reg_is_inc) reg_val i j (bits_of new_val) -let write_reg_range reg i j = update_reg reg (update_reg_range reg i j) - -let update_reg_pos reg i reg_val x = update_list reg.reg_is_inc reg_val i x -let write_reg_pos reg i = update_reg reg (update_reg_pos reg i) - -let update_reg_bit reg i reg_val bit = set_bit (reg.reg_is_inc) reg_val i (to_bitU bit) -let write_reg_bit reg i = update_reg reg (update_reg_bit reg i) - -let update_reg_field_range regfield i j reg_val new_val = - let current_field_value = regfield.get_field reg_val in - let new_field_value = set_bits (regfield.field_is_inc) current_field_value i j (bits_of new_val) in - regfield.set_field reg_val new_field_value -let write_reg_field_range reg regfield i j = update_reg reg (update_reg_field_range regfield i j) - -let update_reg_field_pos regfield i reg_val x = - let current_field_value = regfield.get_field reg_val in - let new_field_value = update_list regfield.field_is_inc current_field_value i x in - regfield.set_field reg_val new_field_value -let write_reg_field_pos reg regfield i = update_reg reg (update_reg_field_pos regfield i) - -let update_reg_field_bit regfield i reg_val bit = - let current_field_value = regfield.get_field reg_val in - let new_field_value = set_bit (regfield.field_is_inc) current_field_value i (to_bitU bit) in - regfield.set_field reg_val new_field_value -let write_reg_field_bit reg regfield i = update_reg reg (update_reg_field_bit regfield i) - -val barrier : forall 'regs 'e. barrier_kind -> M 'regs unit 'e -let barrier _ = return () - -val footprint : forall 'regs 'e. M 'regs unit 'e -let footprint s = return () s +open import State_monad +open import {isabelle} `State_monad_extras` val iter_aux : forall 'regs 'e 'a. integer -> (integer -> 'a -> M 'regs unit 'e) -> list 'a -> M 'regs unit 'e let rec iter_aux i f xs = match xs with @@ -286,24 +41,30 @@ let rec while_PP vars cond body = val while_PM : forall 'regs 'vars 'e. 'vars -> ('vars -> bool) -> ('vars -> M 'regs 'vars 'e) -> M 'regs 'vars 'e -let rec while_PM vars cond body = +let rec while_PM vars cond body s = if cond vars then - body vars >>= fun vars -> while_PM vars cond body - else return vars + bind (body vars) (fun vars s' -> while_PM vars cond body s') s + else return vars s val while_MP : forall 'regs 'vars 'e. 'vars -> ('vars -> M 'regs bool 'e) -> ('vars -> 'vars) -> M 'regs 'vars 'e -let rec while_MP vars cond body = - cond vars >>= fun cond_val -> - if cond_val then while_MP (body vars) cond body else return vars +let rec while_MP vars cond body s = + bind + (cond vars) + (fun cond_val s' -> + if cond_val then while_MP (body vars) cond body s' else return vars s') s val while_MM : forall 'regs 'vars 'e. 'vars -> ('vars -> M 'regs bool 'e) -> ('vars -> M 'regs 'vars 'e) -> M 'regs 'vars 'e -let rec while_MM vars cond body = - cond vars >>= fun cond_val -> - if cond_val then - body vars >>= fun vars -> while_MM vars cond body - else return vars +let rec while_MM vars cond body s = + bind + (cond vars) + (fun cond_val s' -> + if cond_val then + bind + (body vars) + (fun vars s'' -> while_MM vars cond body s'') s' + else return vars s') s val until_PP : forall 'vars. 'vars -> ('vars -> bool) -> ('vars -> 'vars) -> 'vars let rec until_PP vars cond body = @@ -312,44 +73,28 @@ let rec until_PP vars cond body = val until_PM : forall 'regs 'vars 'e. 'vars -> ('vars -> bool) -> ('vars -> M 'regs 'vars 'e) -> M 'regs 'vars 'e -let rec until_PM vars cond body = - body vars >>= fun vars -> - if (cond vars) then return vars else until_PM vars cond body +let rec until_PM vars cond body s = + bind + (body vars) + (fun vars s' -> + if (cond vars) then return vars s' else until_PM vars cond body s') s val until_MP : forall 'regs 'vars 'e. 'vars -> ('vars -> M 'regs bool 'e) -> ('vars -> 'vars) -> M 'regs 'vars 'e -let rec until_MP vars cond body = +let rec until_MP vars cond body s = let vars = body vars in - cond vars >>= fun cond_val -> - if cond_val then return vars else until_MP vars cond body + bind + (cond vars) + (fun cond_val s' -> + if cond_val then return vars s' else until_MP vars cond body s') s val until_MM : forall 'regs 'vars 'e. 'vars -> ('vars -> M 'regs bool 'e) -> ('vars -> M 'regs 'vars 'e) -> M 'regs 'vars 'e -let rec until_MM vars cond body = - body vars >>= fun vars -> - cond vars >>= fun cond_val -> - if cond_val then return vars else until_MM vars cond body - -(*let write_two_regs r1 r2 bvec state = - let vec = bvec_to_vec bvec in - let is_inc = - let is_inc_r1 = is_inc_of_reg r1 in - let is_inc_r2 = is_inc_of_reg r2 in - let () = ensure (is_inc_r1 = is_inc_r2) - "write_two_regs called with vectors of different direction" in - is_inc_r1 in - - let (size_r1 : integer) = size_of_reg r1 in - let (start_vec : integer) = get_start vec in - let size_vec = length vec in - let r1_v = - if is_inc - then slice vec start_vec (size_r1 - start_vec - 1) - else slice vec start_vec (start_vec - size_r1 - 1) in - let r2_v = - if is_inc - then slice vec (size_r1 - start_vec) (size_vec - start_vec) - else slice vec (start_vec - size_r1) (start_vec - size_vec) in - let state1 = set_reg state (name_of_reg r1) r1_v in - let state2 = set_reg state1 (name_of_reg r2) r2_v in - [(Left (), state2)]*) +let rec until_MM vars cond body s = + bind + (body vars) + (fun vars s' -> + bind + (cond vars) + (fun cond_val s''-> + if cond_val then return vars s'' else until_MM vars cond body s'') s') s diff --git a/src/gen_lib/state_monad.lem b/src/gen_lib/state_monad.lem new file mode 100644 index 00000000..2d8e412e --- /dev/null +++ b/src/gen_lib/state_monad.lem @@ -0,0 +1,250 @@ +open import Pervasives_extra +open import Sail_impl_base +open import Sail_values + +(* 'a is result type *) + +type memstate = map integer memory_byte +type tagstate = map integer bitU +(* type regstate = map string (vector bitU) *) + +type sequential_state 'regs = + <| regstate : 'regs; + memstate : memstate; + tagstate : tagstate; + write_ea : maybe (write_kind * integer * integer); + last_exclusive_operation_was_load : bool|> + +val init_state : forall 'regs. 'regs -> sequential_state 'regs +let init_state regs = + <| regstate = regs; + memstate = Map.empty; + tagstate = Map.empty; + write_ea = Nothing; + last_exclusive_operation_was_load = false |> + +type ex 'e = + | Exit + | Assert of string + | Throw of 'e + +type result 'a 'e = + | Value of 'a + | Exception of (ex 'e) + +(* State, nondeterminism and exception monad with result value type 'a + and exception type 'e. *) +type M 'regs 'a 'e = sequential_state 'regs -> list (result 'a 'e * sequential_state 'regs) + +val return : forall 'regs 'a 'e. 'a -> M 'regs 'a 'e +let return a s = [(Value a,s)] + +val bind : forall 'regs 'a 'b 'e. M 'regs 'a 'e -> ('a -> M 'regs 'b 'e) -> M 'regs 'b 'e +let bind m f (s : sequential_state 'regs) = + List.concatMap (function + | (Value a, s') -> f a s' + | (Exception e, s') -> [(Exception e, s')] + end) (m s) + +let inline (>>=) = bind +val (>>): forall 'regs 'b 'e. M 'regs unit 'e -> M 'regs 'b 'e -> M 'regs 'b 'e +let inline (>>) m n = m >>= fun (_ : unit) -> n + +val throw : forall 'regs 'a 'e. 'e -> M 'regs 'a 'e +let throw e s = [(Exception (Throw e), s)] + +val try_catch : forall 'regs 'a 'e1 'e2. M 'regs 'a 'e1 -> ('e1 -> M 'regs 'a 'e2) -> M 'regs 'a 'e2 +let try_catch m h s = + List.concatMap (function + | (Value a, s') -> return a s' + | (Exception (Throw e), s') -> h e s' + | (Exception Exit, s') -> [(Exception Exit, s')] + | (Exception (Assert msg), s') -> [(Exception (Assert msg), s')] + end) (m s) + +val exit : forall 'regs 'e 'a. unit -> M 'regs 'a 'e +let exit () s = [(Exception Exit, s)] + +val assert_exp : forall 'regs 'e. bool -> string -> M 'regs unit 'e +let assert_exp exp msg s = if exp then [(Value (), s)] else [(Exception (Assert msg), s)] + +(* For early return, we abuse exceptions by throwing and catching + the return value. The exception type is "either 'r 'e", where "Right e" + represents a proper exception and "Left r" an early return of value "r". *) +type MR 'regs 'a 'r 'e = M 'regs 'a (either 'r 'e) + +val early_return : forall 'regs 'a 'r 'e. 'r -> MR 'regs 'a 'r 'e +let early_return r = throw (Left r) + +val catch_early_return : forall 'regs 'a 'e. MR 'regs 'a 'a 'e -> M 'regs 'a 'e +let catch_early_return m = + try_catch m + (function + | Left a -> return a + | Right e -> throw e + end) + +(* Lift to monad with early return by wrapping exceptions *) +val liftR : forall 'a 'r 'regs 'e. M 'regs 'a 'e -> MR 'regs 'a 'r 'e +let liftR m = try_catch m (fun e -> throw (Right e)) + +(* Catch exceptions in the presence of early returns *) +val try_catchR : forall 'regs 'a 'r 'e1 'e2. MR 'regs 'a 'r 'e1 -> ('e1 -> MR 'regs 'a 'r 'e2) -> MR 'regs 'a 'r 'e2 +let try_catchR m h = + try_catch m + (function + | Left r -> throw (Left r) + | Right e -> h e + end) + +val range : integer -> integer -> list integer +let rec range i j = + if j < i then [] + else if i = j then [i] + else i :: range (i+1) j + +val get_reg : forall 'regs 'a. sequential_state 'regs -> register_ref 'regs 'a -> 'a +let get_reg state reg = reg.read_from state.regstate + +val set_reg : forall 'regs 'a. sequential_state 'regs -> register_ref 'regs 'a -> 'a -> sequential_state 'regs +let set_reg state reg v = + <| state with regstate = reg.write_to state.regstate v |> + + +let is_exclusive = function + | Sail_impl_base.Read_plain -> false + | Sail_impl_base.Read_reserve -> true + | Sail_impl_base.Read_acquire -> false + | Sail_impl_base.Read_exclusive -> true + | Sail_impl_base.Read_exclusive_acquire -> true + | Sail_impl_base.Read_stream -> false + | Sail_impl_base.Read_RISCV_acquire -> false + | Sail_impl_base.Read_RISCV_strong_acquire -> false + | Sail_impl_base.Read_RISCV_reserved -> true + | Sail_impl_base.Read_RISCV_reserved_acquire -> true + | Sail_impl_base.Read_RISCV_reserved_strong_acquire -> true + | Sail_impl_base.Read_X86_locked -> true +end + + +val read_mem : forall 'regs 'a 'b 'e. Bitvector 'a, Bitvector 'b => read_kind -> 'a -> integer -> M 'regs 'b 'e +let read_mem read_kind addr sz state = + let addr = unsigned addr in + let addrs = range addr (addr+sz-1) in + let memory_value = List.map (fun addr -> Map_extra.find addr state.memstate) addrs in + let value = of_bits (Sail_values.internal_mem_value memory_value) in + if is_exclusive read_kind + then [(Value value, <| state with last_exclusive_operation_was_load = true |>)] + else [(Value value, state)] + +(* caps are aligned at 32 bytes *) +let cap_alignment = (32 : integer) + +val read_tag : forall 'regs 'a 'e. Bitvector 'a => read_kind -> 'a -> M 'regs bitU 'e +let read_tag read_kind addr state = + let addr = (unsigned addr) / cap_alignment in + let tag = match (Map.lookup addr state.tagstate) with + | Just t -> t + | Nothing -> B0 + end in + if is_exclusive read_kind + then [(Value tag, <| state with last_exclusive_operation_was_load = true |>)] + else [(Value tag, state)] + +val excl_result : forall 'regs 'e. unit -> M 'regs bool 'e +let excl_result () state = + let success = + (Value true, <| state with last_exclusive_operation_was_load = false |>) in + (Value false, state) :: if state.last_exclusive_operation_was_load then [success] else [] + +val write_mem_ea : forall 'regs 'a 'e. Bitvector 'a => write_kind -> 'a -> integer -> M 'regs unit 'e +let write_mem_ea write_kind addr sz state = + [(Value (), <| state with write_ea = Just (write_kind,unsigned addr,sz) |>)] + +val write_mem_val : forall 'a 'regs 'b 'e. Bitvector 'a => 'a -> M 'regs bool 'e +let write_mem_val v state = + let (_,addr,sz) = match state.write_ea with + | Nothing -> failwith "write ea has not been announced yet" + | Just write_ea -> write_ea end in + let addrs = range addr (addr+sz-1) in + let v = external_mem_value (bits_of v) in + let addresses_with_value = List.zip addrs v in + let memstate = List.foldl (fun mem (addr,v) -> Map.insert addr v mem) + state.memstate addresses_with_value in + [(Value true, <| state with memstate = memstate |>)] + +val write_tag : forall 'regs 'e. bitU -> M 'regs bool 'e +let write_tag t state = + let (_,addr,_) = match state.write_ea with + | Nothing -> failwith "write ea has not been announced yet" + | Just write_ea -> write_ea end in + let taddr = addr / cap_alignment in + let tagstate = Map.insert taddr t state.tagstate in + [(Value true, <| state with tagstate = tagstate |>)] + +val read_reg : forall 'regs 'a 'e. register_ref 'regs 'a -> M 'regs 'a 'e +let read_reg reg state = + let v = reg.read_from state.regstate in + [(Value v,state)] +(*let read_reg_range reg i j state = + let v = slice (get_reg state (name_of_reg reg)) i j in + [(Value (vec_to_bvec v),state)] +let read_reg_bit reg i state = + let v = access (get_reg state (name_of_reg reg)) i in + [(Value v,state)] +let read_reg_field reg regfield = + let (i,j) = register_field_indices reg regfield in + read_reg_range reg i j +let read_reg_bitfield reg regfield = + let (i,_) = register_field_indices reg regfield in + read_reg_bit reg i *) + +let reg_deref = read_reg + +val write_reg : forall 'regs 'a 'e. register_ref 'regs 'a -> 'a -> M 'regs unit 'e +let write_reg reg v state = + [(Value (), <| state with regstate = reg.write_to state.regstate v |>)] + +let write_reg_ref (reg, v) = write_reg reg v + +val update_reg : forall 'regs 'a 'b 'e. register_ref 'regs 'a -> ('a -> 'b -> 'a) -> 'b -> M 'regs unit 'e +let update_reg reg f v state = + let current_value = get_reg state reg in + let new_value = f current_value v in + [(Value (), set_reg state reg new_value)] + +let write_reg_field reg regfield = update_reg reg regfield.set_field + +val update_reg_range : forall 'regs 'a 'b. Bitvector 'a, Bitvector 'b => register_ref 'regs 'a -> integer -> integer -> 'a -> 'b -> 'a +let update_reg_range reg i j reg_val new_val = set_bits (reg.reg_is_inc) reg_val i j (bits_of new_val) +let write_reg_range reg i j = update_reg reg (update_reg_range reg i j) + +let update_reg_pos reg i reg_val x = update_list reg.reg_is_inc reg_val i x +let write_reg_pos reg i = update_reg reg (update_reg_pos reg i) + +let update_reg_bit reg i reg_val bit = set_bit (reg.reg_is_inc) reg_val i (to_bitU bit) +let write_reg_bit reg i = update_reg reg (update_reg_bit reg i) + +let update_reg_field_range regfield i j reg_val new_val = + let current_field_value = regfield.get_field reg_val in + let new_field_value = set_bits (regfield.field_is_inc) current_field_value i j (bits_of new_val) in + regfield.set_field reg_val new_field_value +let write_reg_field_range reg regfield i j = update_reg reg (update_reg_field_range regfield i j) + +let update_reg_field_pos regfield i reg_val x = + let current_field_value = regfield.get_field reg_val in + let new_field_value = update_list regfield.field_is_inc current_field_value i x in + regfield.set_field reg_val new_field_value +let write_reg_field_pos reg regfield i = update_reg reg (update_reg_field_pos regfield i) + +let update_reg_field_bit regfield i reg_val bit = + let current_field_value = regfield.get_field reg_val in + let new_field_value = set_bit (regfield.field_is_inc) current_field_value i (to_bitU bit) in + regfield.set_field reg_val new_field_value +let write_reg_field_bit reg regfield i = update_reg reg (update_reg_field_bit regfield i) + +val barrier : forall 'regs 'e. barrier_kind -> M 'regs unit 'e +let barrier _ = return () + +val footprint : forall 'regs 'e. M 'regs unit 'e +let footprint s = return () s diff --git a/src/isail.ml b/src/isail.ml index 07a72bd2..ffeb1442 100644 --- a/src/isail.ml +++ b/src/isail.ml @@ -296,7 +296,7 @@ let handle_input' input = vs_ids := Initial_check.val_spec_ids !interactive_ast | ":compile" -> let exp = Type_check.infer_exp !interactive_env (Initial_check.exp_of_string Ast_util.dec_ord arg) in - let anf = C_backend.compile_exp !interactive_env exp in + let anf = C_backend.compile_exp (C_backend.initial_ctx !interactive_env) exp in print_endline (Pretty_print_sail.to_string (C_backend.pp_aexp anf)) | ":u" | ":unload" -> interactive_ast := Ast.Defs []; diff --git a/src/lexer.mll b/src/lexer.mll index 77fba70b..3538d5cb 100644 --- a/src/lexer.mll +++ b/src/lexer.mll @@ -93,7 +93,14 @@ let mk_operator prec n op = | InfixR, 9 -> Op9r op | _, _ -> assert false -let operators = ref M.empty +let operators = ref + (List.fold_left + (fun r (x, y) -> M.add x y r) + M.empty + [ ("==", mk_operator Infix 4 "=="); + ("/", mk_operator InfixL 7 "/"); + ("%", mk_operator InfixL 7 "%"); + ]) let kw_table = List.fold_left diff --git a/src/monomorphise.ml b/src/monomorphise.ml index a4d404e1..3b8a5073 100644 --- a/src/monomorphise.ml +++ b/src/monomorphise.ml @@ -520,7 +520,12 @@ let nexp_subst_fns substs = | E_id _ | E_lit _ | E_comment _ -> re e - | E_sizeof ne -> re (E_sizeof ne) (* TODO: does this need done? does it appear in type checked code? *) + | E_sizeof ne -> begin + let ne' = subst_nexp substs ne in + match ne' with + | Nexp_aux (Nexp_constant i,l) -> re (E_lit (L_aux (L_num i,l))) + | _ -> re (E_sizeof ne') + end | E_constraint nc -> re (E_constraint (subst_nc substs nc)) | E_internal_exp (l,annot) -> re (E_internal_exp (l, s_tannot annot)) | E_sizeof_internal (l,annot) -> re (E_sizeof_internal (l, s_tannot annot)) @@ -789,14 +794,17 @@ let construct_lit_vector args = | _ -> None in aux [] args +type pat_choice = Parse_ast.l * (int * int * (id * tannot exp) list) + (* We may need to split up a pattern match if (1) we've been told to case split on a variable by the user or analysis, or (2) we monomorphised a constructor that's used in the pattern. *) type split = | NoSplit - | VarSplit of (tannot pat * (* pattern for this case *) - (id * tannot Ast.exp) list * (* substitutions for arguments *) - (Parse_ast.l * (int * (id * tannot exp) list)) list) (* optional locations of case expressions to reduce *) + | VarSplit of (tannot pat * (* pattern for this case *) + (id * tannot Ast.exp) list * (* substitutions for arguments *) + pat_choice list * (* optional locations of constraints/case expressions to reduce *) + (kid * nexp) list) (* substitutions for type variables *) list | ConstrSplit of (tannot pat * nexp KBindings.t) list @@ -896,23 +904,96 @@ let rec freshen_pat_bindings p = FP_aux (FP_Fpat (id, p),(Generated l,annot)), vs in aux p +(* This cuts off function bodies at false assertions that we may have produced + in a wildcard pattern match. It should handle the same assertions that + find_set_assertions does. *) +let stop_at_false_assertions e = + let dummy_value_of_typ typ = + let l = Generated Unknown in + E_aux (E_exit (E_aux (E_lit (L_aux (L_unit,l)),(l,None))),(l,None)) + in + let rec exp (E_aux (e,ann) as ea) = + match e with + | E_block es -> + let rec aux = function + | [] -> [], None + | e::es -> let e,stop = exp e in + match stop with + | Some _ -> [e],stop + | None -> + let es',stop = aux es in + e::es',stop + in let es,stop = aux es in begin + match stop with + | None -> E_aux (E_block es,ann), stop + | Some typ -> + let typ' = typ_of_annot ann in + if Type_check.alpha_equivalent (env_of_annot ann) typ typ' + then E_aux (E_block es,ann), stop + else E_aux (E_block (es@[dummy_value_of_typ typ']),ann), Some typ' + end + | E_nondet es -> + let es,stops = List.split (List.map exp es) in + let stop = List.exists (function Some _ -> true | _ -> false) stops in + let stop = if stop then Some (typ_of_annot ann) else None in + E_aux (E_nondet es,ann), stop + | E_cast (typ,e) -> let e,stop = exp e in + let stop = match stop with Some _ -> Some typ | None -> None in + E_aux (E_cast (typ,e),ann),stop + | E_let (LB_aux (LB_val (p,e1),lbann),e2) -> + let e1,stop = exp e1 in begin + match stop with + | Some _ -> e1,stop + | None -> + let e2,stop = exp e2 in + E_aux (E_let (LB_aux (LB_val (p,e1),lbann),e2),ann), stop + end + | E_assert (E_aux (E_constraint (NC_aux (NC_false,_)),_),_) -> + ea, Some (typ_of_annot ann) + | E_assert (E_aux (E_lit (L_aux (L_false,_)),_),_) -> + ea, Some (typ_of_annot ann) + | _ -> ea, None + in fst (exp e) + (* Use the location pairs in choices to reduce case expressions at the first location to the given case at the second. *) let apply_pat_choices choices = - let rewrite_constraint (NC_aux (nc,l) as nconstr) = E_constraint nconstr (* - Not right now - false cases may not type check + let rec rewrite_ncs (NC_aux (nc,l) as nconstr) = + match nc with + | NC_set (kid,is) -> begin + match List.assoc l choices with + | choice,max,_ -> + NC_aux ((if choice < max then NC_true else NC_false), Generated l) + | exception Not_found -> nconstr + end + | NC_and (nc1,nc2) -> begin + match rewrite_ncs nc1, rewrite_ncs nc2 with + | NC_aux (NC_false,l), _ + | _, NC_aux (NC_false,l) -> NC_aux (NC_false,l) + | nc1,nc2 -> NC_aux (NC_and (nc1,nc2),l) + end + | _ -> nconstr + in + let rec rewrite_assert_cond (E_aux (e,(l,ann)) as exp) = match List.assoc l choices with - | choice,_ -> begin - match nc with - | NC_set (kid,is) -> - E_constraint (NC_aux ((if choice < List.length is then NC_true else NC_false), Generated l)) - | _ -> E_constraint nconstr - end - | exception Not_found -> E_constraint nconstr*) + | choice,max,_ -> + E_aux (E_lit (L_aux ((if choice < max then L_true else L_false (* wildcard *)), + Generated l)),(Generated l,ann)) + | exception Not_found -> + match e with + | E_constraint nc -> E_aux (E_constraint (rewrite_ncs nc),(l,ann)) + | E_app (Id_aux (Id "and_bool",andl), [e1;e2]) -> + E_aux (E_app (Id_aux (Id "and_bool",andl), + [rewrite_assert_cond e1; + rewrite_assert_cond e2]),(l,ann)) + | _ -> exp + in + let rewrite_assert (e1,e2) = + E_assert (rewrite_assert_cond e1, e2) in let rewrite_case (e,cases) = match List.assoc (exp_loc e) choices with - | choice,subst -> + | choice,max,subst -> (match List.nth cases choice with | Pat_aux (Pat_exp (p,E_aux (e,_)),_) -> let dummyannot = (Generated Unknown,None) in @@ -929,10 +1010,11 @@ let apply_pat_choices choices = in let open Rewriter in fold_exp { id_exp_alg with - e_constraint = rewrite_constraint; + e_assert = rewrite_assert; e_case = rewrite_case } -let split_defs continue_anyway splits defs = +let split_defs all_errors splits defs = + let no_errors_happened = ref true in let split_constructors (Defs defs) = let sc_type_union q (Tu_aux (tu,l) as tua) = match tu with @@ -1374,27 +1456,29 @@ let split_defs continue_anyway splits defs = let error = Err_general (pat_l, ("Cannot split type " ^ string_of_typ typ ^ " for variable " ^ v ^ ": " ^ msg)) - in if continue_anyway - then (print_error error; [P_aux (P_id var,(pat_l,annot)),[],[]]) + in if all_errors + then (no_errors_happened := false; + print_error error; + [P_aux (P_id var,(pat_l,annot)),[],[],[]]) else raise (Fatal_error error) in match ty with | Typ_id (Id_aux (Id "bool",_)) -> - [P_aux (P_lit (L_aux (L_true,new_l)),(l,annot)),[var, E_aux (E_lit (L_aux (L_true,new_l)),(new_l,annot))],[]; - P_aux (P_lit (L_aux (L_false,new_l)),(l,annot)),[var, E_aux (E_lit (L_aux (L_false,new_l)),(new_l,annot))],[]] + [P_aux (P_lit (L_aux (L_true,new_l)),(l,annot)),[var, E_aux (E_lit (L_aux (L_true,new_l)),(new_l,annot))],[],[]; + P_aux (P_lit (L_aux (L_false,new_l)),(l,annot)),[var, E_aux (E_lit (L_aux (L_false,new_l)),(new_l,annot))],[],[]] | Typ_id id -> (try (* enumerations *) let ns = Env.get_enum id env in List.map (fun n -> (P_aux (P_id (renew_id n),(l,annot)), - [var,E_aux (E_id (renew_id n),(new_l,annot))],[])) ns + [var,E_aux (E_id (renew_id n),(new_l,annot))],[],[])) ns with Type_error _ -> match id with | Id_aux (Id "bit",_) -> List.map (fun b -> P_aux (P_lit (L_aux (b,new_l)),(l,annot)), - [var,E_aux (E_lit (L_aux (b,new_l)),(new_l, annot))],[]) + [var,E_aux (E_lit (L_aux (b,new_l)),(new_l, annot))],[],[]) [L_zero; L_one] | _ -> cannot ("don't know about type " ^ string_of_id id)) @@ -1404,25 +1488,26 @@ let split_defs continue_anyway splits defs = let lits = make_vectors (Big_int.to_int sz) in List.map (fun lit -> P_aux (P_lit lit,(l,annot)), - [var,E_aux (E_lit lit,(new_l,annot))],[]) lits + [var,E_aux (E_lit lit,(new_l,annot))],[],[]) lits | _ -> cannot ("length not constant, " ^ string_of_nexp len) ) (* set constrained numbers *) | Typ_app (Id_aux (Id "atom",_), [Typ_arg_aux (Typ_arg_nexp (Nexp_aux (value,_) as nexp),_)]) -> begin - let mk_lit i = + let mk_lit kid i = let lit = L_aux (L_num i,new_l) in P_aux (P_lit lit,(l,annot)), - [var,E_aux (E_lit lit,(new_l,annot))],[] + [var,E_aux (E_lit lit,(new_l,annot))],[], + match kid with None -> [] | Some k -> [(k,nconstant i)] in match value with - | Nexp_constant i -> [mk_lit i] + | Nexp_constant i -> [mk_lit None i] | Nexp_var kvar -> let ncs = Env.get_constraints env in let nc = List.fold_left nc_and nc_true ncs in (match extract_set_nc l kvar nc with - | (is,_) -> List.map mk_lit is + | (is,_) -> List.map (mk_lit (Some kvar)) is | exception Reporting_basic.Fatal_error (Reporting_basic.Err_general (_,msg)) -> cannot msg) | _ -> cannot ("unsupport atom nexp " ^ string_of_nexp nexp) end @@ -1456,32 +1541,32 @@ let split_defs continue_anyway splits defs = | h::t -> let t' = match list f t with - | None -> [t,[],[]] + | None -> [t,[],[],[]] | Some t' -> t' in let h' = match f h with - | None -> [h,[],[]] + | None -> [h,[],[],[]] | Some ps -> ps in Some (List.concat - (List.map (fun (h,hsubs,hpchoices) -> - List.map (fun (t,tsubs,tpchoices) -> - (h::t, hsubs@tsubs, hpchoices@tpchoices)) t') h')) + (List.map (fun (h,hsubs,hpchoices,hksubs) -> + List.map (fun (t,tsubs,tpchoices,tksubs) -> + (h::t, hsubs@tsubs, hpchoices@tpchoices, hksubs@tksubs)) t') h')) in let rec spl (P_aux (p,(l,annot))) = let relist f ctx ps = optmap (list f ps) (fun ps -> - List.map (fun (ps,sub,pchoices) -> P_aux (ctx ps,(l,annot)),sub,pchoices) ps) + List.map (fun (ps,sub,pchoices,ksub) -> P_aux (ctx ps,(l,annot)),sub,pchoices,ksub) ps) in let re f p = optmap (spl p) - (fun ps -> List.map (fun (p,sub,pchoices) -> (P_aux (f p,(l,annot)), sub, pchoices)) ps) + (fun ps -> List.map (fun (p,sub,pchoices,ksub) -> (P_aux (f p,(l,annot)), sub, pchoices, ksub)) ps) in let fpat (FP_aux ((FP_Fpat (id,p),annot))) = optmap (spl p) - (fun ps -> List.map (fun (p,sub,pchoices) -> FP_aux (FP_Fpat (id,p), annot), sub, pchoices) ps) + (fun ps -> List.map (fun (p,sub,pchoices,ksub) -> FP_aux (FP_Fpat (id,p), annot), sub, pchoices, ksub) ps) in match p with | P_lit _ @@ -1503,18 +1588,29 @@ let split_defs continue_anyway splits defs = literal as normal, but perform a more careful transformation otherwise *) | Some (Some (pats,l)) -> + let max = List.length pats - 1 in Some (List.mapi (fun i p -> match p with | P_aux (P_lit lit,(pl,pannot)) when (match lit with L_aux (L_undef,_) -> false | _ -> true) -> - p,[id,E_aux (E_lit lit,(Generated pl,pannot))],[l,(i,[])] + let orig_typ = Env.base_typ_of (env_of_annot (l,annot)) (typ_of_annot (l,annot)) in + let kid_subst = match lit, orig_typ with + | L_aux (L_num i,_), + Typ_aux + (Typ_app (Id_aux (Id "atom",_), + [Typ_arg_aux (Typ_arg_nexp + (Nexp_aux (Nexp_var var,_)),_)]),_) -> + [var,nconstant i] + | _ -> [] + in + p,[id,E_aux (E_lit lit,(Generated pl,pannot))],[l,(i,max,[])],kid_subst | _ -> let p',subst = freshen_pat_bindings p in match p' with | P_aux (P_wild,_) -> - P_aux (P_id id,(l,annot)),[],[l,(i,subst)] + P_aux (P_id id,(l,annot)),[],[l,(i,max,subst)],[] | _ -> - P_aux (P_as (p',id),(l,annot)),[],[l,(i,subst)]) + P_aux (P_as (p',id),(l,annot)),[],[l,(i,max,subst)],[]) pats) ) | P_app (id,ps) -> @@ -1533,10 +1629,10 @@ let split_defs continue_anyway splits defs = match spl p1, spl p2 with | None, None -> None | p1', p2' -> - let p1' = match p1' with None -> [p1,[],[]] | Some p1' -> p1' in - let p2' = match p2' with None -> [p2,[],[]] | Some p2' -> p2' in - let ps = List.map (fun (p1',subs1,pchoices1) -> List.map (fun (p2',subs2,pchoices2) -> - P_aux (P_cons (p1',p2'),(l,annot)),subs1@subs2,pchoices1@pchoices2) p2') p1' in + let p1' = match p1' with None -> [p1,[],[],[]] | Some p1' -> p1' in + let p2' = match p2' with None -> [p2,[],[],[]] | Some p2' -> p2' in + let ps = List.map (fun (p1',subs1,pchoices1,ksub1) -> List.map (fun (p2',subs2,pchoices2,ksub2) -> + P_aux (P_cons (p1',p2'),(l,annot)),subs1@subs2,pchoices1@pchoices2,ksub1@ksub2) p2') p1' in Some (List.concat ps) in spl p in @@ -1604,8 +1700,9 @@ let split_defs continue_anyway splits defs = let error = Err_general (l, "Case split is too large (" ^ string_of_int size ^ " > limit " ^ string_of_int size_set_limit ^ ")") - in if continue_anyway - then (print_error error; false) + in if all_errors + then (no_errors_happened := false; + print_error error; false) else raise (Fatal_error error) else true in @@ -1677,9 +1774,11 @@ let split_defs continue_anyway splits defs = | NoSplit -> nosplit | VarSplit patsubsts -> if check_split_size patsubsts (pat_loc p) then - List.map (fun (pat',substs,pchoices) -> - let exp' = subst_exp substs e in + List.map (fun (pat',substs,pchoices,ksubsts) -> + let exp' = nexp_subst_exp (kbindings_from_list ksubsts) e in + let exp' = subst_exp substs exp' in let exp' = apply_pat_choices pchoices exp' in + let exp' = stop_at_false_assertions exp' in Pat_aux (Pat_exp (pat', map_exp exp'),l)) patsubsts else nosplit @@ -1695,11 +1794,14 @@ let split_defs continue_anyway splits defs = | NoSplit -> nosplit | VarSplit patsubsts -> if check_split_size patsubsts (pat_loc p) then - List.map (fun (pat',substs,pchoices) -> - let exp1' = subst_exp substs e1 in + List.map (fun (pat',substs,pchoices,ksubsts) -> + let exp1' = nexp_subst_exp (kbindings_from_list ksubsts) e1 in + let exp1' = subst_exp substs exp1' in let exp1' = apply_pat_choices pchoices exp1' in - let exp2' = subst_exp substs e2 in + let exp2' = nexp_subst_exp (kbindings_from_list ksubsts) e2 in + let exp2' = subst_exp substs exp2' in let exp2' = apply_pat_choices pchoices exp2' in + let exp2' = stop_at_false_assertions exp2' in Pat_aux (Pat_when (pat', map_exp exp1', map_exp exp2'),l)) patsubsts else nosplit @@ -1757,7 +1859,8 @@ let split_defs continue_anyway splits defs = in Defs (List.concat (List.map map_def defs)) in - map_locs splits defs' + let defs'' = map_locs splits defs' in + !no_errors_happened, defs'' @@ -1902,7 +2005,7 @@ let rewrite_size_parameters env (Defs defs) = pat_exp = (fun ((sp,pat),(s,e)) -> KidSet.diff s (tyvars_bound_in_pat pat), Pat_exp (pat,e))} pexp) in - let sizes_funcl fsizes (FCL_aux (FCL_Funcl (id,pexp),(l,_))) = + let exposed_sizes_funcl fnsizes (FCL_aux (FCL_Funcl (id,pexp),(l,_))) = let sizes = size_vars pexp in let pat,guard,exp,pannot = destruct_pexp pexp in let visible_tyvars = @@ -1911,6 +2014,10 @@ let rewrite_size_parameters env (Defs defs) = (Pretty_print_lem.lem_tyvars_of_typ (typ_of exp)) in let expose_tyvars = KidSet.diff sizes visible_tyvars in + KidSet.union fnsizes expose_tyvars + in + let sizes_funcl expose_tyvars fsizes (FCL_aux (FCL_Funcl (id,pexp),(l,_))) = + let pat,guard,exp,pannot = destruct_pexp pexp in let parameters = match pat with | P_aux (P_tup ps,_) -> ps | _ -> [pat] @@ -1942,13 +2049,16 @@ let rewrite_size_parameters env (Defs defs) = let to_change = List.sort ik_compare to_change in match Bindings.find id fsizes with | old -> if List.for_all2 (fun x y -> ik_compare x y = 0) old to_change then fsizes else + let str l = String.concat "," (List.map (fun (i,k) -> string_of_int i ^ "." ^ string_of_kid k) l) in raise (Reporting_basic.err_general l - ("Different size type variables in different clauses of " ^ string_of_id id)) + ("Different size type variables in different clauses of " ^ string_of_id id ^ + " old: " ^ str old ^ " new: " ^ str to_change)) | exception Not_found -> Bindings.add id to_change fsizes in let sizes_def fsizes = function | DEF_fundef (FD_aux (FD_function (_,_,_,funcls),_)) -> - List.fold_left sizes_funcl fsizes funcls + let expose_tyvars = List.fold_left exposed_sizes_funcl KidSet.empty funcls in + List.fold_left (sizes_funcl expose_tyvars) fsizes funcls | _ -> fsizes in let fn_sizes = List.fold_left sizes_def Bindings.empty defs in @@ -2098,6 +2208,17 @@ module ArgSplits = Map.Make (struct end) type arg_splits = match_detail ArgSplits.t +(* Function id, funcl loc for adding splits on sizes in the body when + there's no corresponding argument *) +module ExtraSplits = Map.Make (struct + type t = id * Parse_ast.l + let compare (id,l) (id',l') = + let x = Id.compare id id' in + if x <> 0 then x else + compare l l' +end) +type extra_splits = (match_detail KBindings.t) ExtraSplits.t + (* Arguments that we should look at in callers *) module CallerArgSet = Set.Make (struct type t = id * int @@ -2124,8 +2245,7 @@ module StringSet = Set.Make (struct end) type dependencies = - | Have of arg_splits * CallerArgSet.t * CallerKidSet.t - (* args to split inside fn * caller args to split * caller kids that are bitvector parameters *) + | Have of arg_splits * extra_splits | Unknown of Parse_ast.l * string let string_of_match_detail = function @@ -2138,6 +2258,26 @@ let string_of_argsplits s = string_of_id id ^ "." ^ string_of_loc l ^ string_of_match_detail detail) (ArgSplits.bindings s)) +let string_of_lx lx = + let open Lexing in + Printf.sprintf "%s,%d,%d,%d" lx.pos_fname lx.pos_lnum lx.pos_bol lx.pos_cnum + +let rec simple_string_of_loc = function + | Parse_ast.Unknown -> "Unknown" + | Parse_ast.Int (s,None) -> "Int(" ^ s ^ ",None)" + | Parse_ast.Int (s,Some l) -> "Int(" ^ s ^ ",Some("^simple_string_of_loc l^"))" + | Parse_ast.Generated l -> "Generated(" ^ simple_string_of_loc l ^ ")" + | Parse_ast.Range (lx1,lx2) -> "Range(" ^ string_of_lx lx1 ^ "->" ^ string_of_lx lx2 ^ ")" + +let string_of_extra_splits s = + String.concat ", " + (List.map (fun ((id,l),ks) -> + string_of_id id ^ "." ^ simple_string_of_loc l ^ ":" ^ + (String.concat "," (List.map (fun (kid,detail) -> + string_of_kid kid ^ "." ^ string_of_match_detail detail) + (KBindings.bindings ks)))) + (ExtraSplits.bindings s)) + let string_of_callerset s = String.concat ", " (List.map (fun (id,arg) -> string_of_id id ^ "." ^ string_of_int arg) (CallerArgSet.elements s)) @@ -2147,31 +2287,40 @@ let string_of_callerkidset s = (CallerKidSet.elements s)) let string_of_dep = function - | Have (args,callset,kidset) -> - "Have (" ^ string_of_argsplits args ^ "; " ^ string_of_callerset callset ^ "; " ^ - string_of_callerkidset kidset ^ ")" + | Have (args,extras) -> + "Have (" ^ string_of_argsplits args ^ ";" ^ string_of_extra_splits extras ^ ")" | Unknown (l,msg) -> "Unknown " ^ msg ^ " at " ^ Reporting_basic.loc_to_string l +(* If a callee uses a type variable as a size, does it need to be split in the + current function, or is it also a parameter? (Note that there may be multiple + calls, so more than one parameter can be involved) *) +type call_dep = + | InFun of dependencies + | Parents of CallerKidSet.t + (* Result of analysing the body of a function. The split field gives - the arguments to split based on the body alone, and the failures - field where we couldn't do anything. The other fields are used at - the end for the interprocedural phase. *) + the arguments to split based on the body alone, the extra_splits + field where we want to case split on a size type variable but + there's no corresponding argument so we introduce a case + expression, and the failures field where we couldn't do anything. + The other fields are used at the end for the interprocedural + phase. *) type result = { split : arg_splits; + extra_splits : extra_splits; failures : StringSet.t Failures.t; - (* Dependencies for arguments and type variables of each fn called, so that + (* Dependencies for type variables of each fn called, so that if the fn uses one for a bitvector size we can track it back *) - split_on_call : (dependencies list * dependencies KBindings.t) Bindings.t; (* (arguments, kids) per fn *) - split_in_caller : CallerArgSet.t; + split_on_call : (call_dep KBindings.t) Bindings.t; (* kids per fn *) kid_in_caller : CallerKidSet.t } let empty = { split = ArgSplits.empty; + extra_splits = ExtraSplits.empty; failures = Failures.empty; split_on_call = Bindings.empty; - split_in_caller = CallerArgSet.empty; kid_in_caller = CallerKidSet.empty } @@ -2183,39 +2332,48 @@ let merge_detail _ x y = when l1 = l2 && forall2 pat_eq ps1 ps2 -> x | _ -> Some Total +let opt_merge f _ x y = + match x,y with + | None, _ -> y + | _, None -> x + | Some x, Some y -> Some (f x y) + +let merge_extras = ExtraSplits.merge (opt_merge (KBindings.merge merge_detail)) + let dmerge x y = match x,y with | Unknown (l,s), _ -> Unknown (l,s) | _, Unknown (l,s) -> Unknown (l,s) - | Have (a,c,k), Have (a',c',k') -> - Have (ArgSplits.merge merge_detail a a', CallerArgSet.union c c', CallerKidSet.union k k') - -let dempty = Have (ArgSplits.empty, CallerArgSet.empty, CallerKidSet.empty) + | Have (args,extras), Have (args',extras') -> + Have (ArgSplits.merge merge_detail args args', + merge_extras extras extras') -let dopt_merge k x y = - match x, y with - | None, _ -> y - | _, None -> x - | Some x, Some y -> Some (dmerge x y) +let dempty = Have (ArgSplits.empty, ExtraSplits.empty) let dep_bindings_merge a1 a2 = - Bindings.merge dopt_merge a1 a2 + Bindings.merge (opt_merge dmerge) a1 a2 let dep_kbindings_merge a1 a2 = - KBindings.merge dopt_merge a1 a2 + KBindings.merge (opt_merge dmerge) a1 a2 let call_kid_merge k x y = match x, y with | None, x -> x | x, None -> x - | Some d, Some d' -> Some (dmerge d d') + | Some (InFun deps), Some (Parents _) + | Some (Parents _), Some (InFun deps) + -> Some (InFun deps) + | Some (InFun deps), Some (InFun deps') + -> Some (InFun (dmerge deps deps')) + | Some (Parents fns), Some (Parents fns') + -> Some (Parents (CallerKidSet.union fns fns')) let call_arg_merge k args args' = match args, args' with | None, x -> x | x, None -> x - | Some (args,kdep), Some (args',kdep') - -> Some (List.map2 dmerge args args', KBindings.merge call_kid_merge kdep kdep') + | Some kdep, Some kdep' + -> Some (KBindings.merge call_kid_merge kdep kdep') let failure_merge _ x y = match x, y with @@ -2225,9 +2383,9 @@ let failure_merge _ x y = let merge rs rs' = { split = ArgSplits.merge merge_detail rs.split rs'.split; + extra_splits = merge_extras rs.extra_splits rs'.extra_splits; failures = Failures.merge failure_merge rs.failures rs'.failures; split_on_call = Bindings.merge call_arg_merge rs.split_on_call rs'.split_on_call; - split_in_caller = CallerArgSet.union rs.split_in_caller rs'.split_in_caller; kid_in_caller = CallerKidSet.union rs.kid_in_caller rs'.kid_in_caller } @@ -2316,11 +2474,14 @@ let rec deps_of_nc kid_deps (NC_aux (nc,l)) = let deps_of_typ kid_deps arg_deps typ = deps_of_tyvars kid_deps arg_deps (tyvars_of_typ typ) -let deps_of_uvar kid_deps arg_deps = function - | U_nexp nexp -> deps_of_nexp kid_deps arg_deps nexp +let deps_of_uvar fn_id env arg_deps = function + | U_nexp (Nexp_aux (Nexp_var kid,_)) + when List.exists (fun k -> Kid.compare kid k == 0) env.top_kids -> + Parents (CallerKidSet.singleton (fn_id,kid)) + | U_nexp nexp -> InFun (deps_of_nexp env.kid_deps arg_deps nexp) | U_order _ - | U_effect _ -> dempty - | U_typ typ -> deps_of_typ kid_deps arg_deps typ + | U_effect _ -> InFun dempty + | U_typ typ -> InFun (deps_of_typ env.kid_deps arg_deps typ) let mk_subrange_pattern vannot vstart vend = let (_,len,ord,typ) = vector_typ_args_of (Env.base_typ_of (env_of_annot vannot) (typ_of_annot vannot)) in @@ -2358,27 +2519,27 @@ let mk_subrange_pattern vannot vstart vend = let refine_dependency env (E_aux (e,(l,annot)) as exp) pexps = let check_dep id ctx = match Bindings.find id env.var_deps with - | Have (args,callargs,callkids) -> - if CallerArgSet.is_empty callargs && CallerKidSet.is_empty callkids then - match ArgSplits.bindings args with - | [(id',loc),Total] when Id.compare id id' == 0 -> - (match Util.map_all (function - | Pat_aux (Pat_exp (pat,_),_) -> Some (ctx pat) - | Pat_aux (Pat_when (_,_,_),_) -> None) pexps - with - | Some pats -> - if l = Parse_ast.Unknown then - (Reporting_basic.print_error - (Reporting_basic.Err_general - (l, "No location for pattern match: " ^ string_of_exp exp)); - None) - else - Some (Have (ArgSplits.singleton (id,loc) (Partial (pats,l)),callargs,callkids)) - | None -> None) - | _ -> None - else None - | Unknown _ -> None - | exception Not_found -> None + | Have (args,extras) -> begin + match ArgSplits.bindings args, ExtraSplits.bindings extras with + | [(id',loc),Total], [] when Id.compare id id' == 0 -> + (match Util.map_all (function + | Pat_aux (Pat_exp (pat,_),_) -> Some (ctx pat) + | Pat_aux (Pat_when (_,_,_),_) -> None) pexps + with + | Some pats -> + if l = Parse_ast.Unknown then + (Reporting_basic.print_error + (Reporting_basic.Err_general + (l, "No location for pattern match: " ^ string_of_exp exp)); + None) + else + Some (Have (ArgSplits.singleton (id,loc) (Partial (pats,l)), + ExtraSplits.empty)) + | None -> None) + | _ -> None + end + | Unknown _ -> None + | exception Not_found -> None in match e with | E_id id -> check_dep id (fun x -> x) @@ -2466,15 +2627,14 @@ let rec analyse_exp fn_id env assigns (E_aux (e,(l,annot)) as exp) = let kid_inst = instantiation_of exp in (* Change kids in instantiation to the canonical ones from the type signature *) let kid_inst = KBindings.fold (fun kid -> KBindings.add (orig_kid kid)) kid_inst KBindings.empty in - let kid_deps = KBindings.map (deps_of_uvar env.kid_deps deps) kid_inst in + let kid_deps = KBindings.map (deps_of_uvar fn_id env deps) kid_inst in let rdep,r' = if Id.compare fn_id id == 0 then let bad = Unknown (l,"Recursive call of " ^ string_of_id id) in - let deps = List.map (fun _ -> bad) deps in - let kid_deps = KBindings.map (fun _ -> bad) kid_deps in - bad, { empty with split_on_call = Bindings.singleton id (deps, kid_deps) } + let kid_deps = KBindings.map (fun _ -> InFun bad) kid_deps in + bad, { empty with split_on_call = Bindings.singleton id kid_deps } else - dempty, { empty with split_on_call = Bindings.singleton id (deps, kid_deps) } in + dempty, { empty with split_on_call = Bindings.singleton id kid_deps } in (merge_deps (rdep::eff_dep::deps), assigns, merge r r') | E_tuple es | E_list es -> @@ -2619,19 +2779,26 @@ let rec analyse_exp fn_id env assigns (E_aux (e,(l,annot)) as exp) = let typ = Env.expand_synonyms tenv typ in if is_bitvector_typ typ then let _,size,_,_ = vector_typ_args_of typ in - let size = simplify_size_nexp env tenv size in - match deps_of_nexp env.kid_deps [] size with - | Have (args,caller,caller_kids) -> - { r with - split = ArgSplits.merge merge_detail r.split args; - split_in_caller = CallerArgSet.union r.split_in_caller caller; - kid_in_caller = CallerKidSet.union r.kid_in_caller caller_kids - } - | Unknown (l,msg) -> - { r with - failures = - Failures.add l (StringSet.singleton ("Unable to monomorphise " ^ string_of_nexp size ^ ": " ^ msg)) - r.failures } + let Nexp_aux (size,_) as size_nexp = simplify_size_nexp env tenv size in + let is_tyvar_parameter v = + List.exists (fun k -> Kid.compare k v == 0) env.top_kids + in + match size with + | Nexp_constant _ -> r + | Nexp_var v when is_tyvar_parameter v -> + { r with kid_in_caller = CallerKidSet.add (fn_id,v) r.kid_in_caller } + | _ -> + match deps_of_nexp env.kid_deps [] size_nexp with + | Have (args,extras) -> + { r with + split = ArgSplits.merge merge_detail r.split args; + extra_splits = merge_extras r.extra_splits extras + } + | Unknown (l,msg) -> + { r with + failures = + Failures.add l (StringSet.singleton ("Unable to monomorphise " ^ string_of_nexp size_nexp ^ ": " ^ msg)) + r.failures } else r in (deps, assigns, r) @@ -2662,33 +2829,36 @@ and analyse_lexp fn_id env assigns deps (LEXP_aux (lexp,_)) = | LEXP_field (lexp,_) -> analyse_lexp fn_id env assigns deps lexp -let translate_id (Id_aux (_,l) as id) = - let rec aux l = - match l with - | Range (pos,_) -> Some (id,(pos.Lexing.pos_fname,pos.Lexing.pos_lnum)) - | Generated l -> aux l - | _ -> None - in aux l +let rec translate_loc l = + match l with + | Range (pos,_) -> Some (pos.Lexing.pos_fname,pos.Lexing.pos_lnum) + | Generated l -> translate_loc l + | _ -> None -let initial_env fn_id (TypQ_aux (tq,_)) pat set_assertions = +let initial_env fn_id fn_l (TypQ_aux (tq,_)) pat set_assertions = let pats = match pat with | P_aux (P_tup pats,_) -> pats | _ -> [pat] in - let default_split annot = + (* For the type in an annotation, produce the corresponding tyvar (if any), + and a default case split (a set if there's one, a full case split if not). *) + let kid_of_annot annot = let env = env_of_annot annot in let Typ_aux (typ,_) = Env.base_typ_of env (typ_of_annot annot) in match typ with | Typ_app (Id_aux (Id "atom",_),[Typ_arg_aux (Typ_arg_nexp (Nexp_aux (Nexp_var kid,_)),_)]) -> - (match KBindings.find kid set_assertions with - | (l,is) -> - let l' = Generated l in - let pats = List.map (fun n -> P_aux (P_lit (L_aux (L_num n,l')),(l',snd annot))) is in - let pats = pats @ [P_aux (P_wild,(l',snd annot))] in - Partial (pats,l) - | exception Not_found -> Total) - | _ -> Total + Some kid + | _ -> None + in + let default_split annot kid = + match KBindings.find kid set_assertions with + | (l,is) -> + let l' = Generated l in + let pats = List.map (fun n -> P_aux (P_lit (L_aux (L_num n,l')),(l',annot))) is in + let pats = pats @ [P_aux (P_wild,(l',annot))] in + Partial (pats,l) + | exception Not_found -> Total in let arg i pat = let rec aux (P_aux (p,(l,annot))) = @@ -2706,10 +2876,10 @@ let initial_env fn_id (TypQ_aux (tq,_)) pat set_assertions = | P_as (pat,id) -> begin let s,v,k = aux pat in - match translate_id id with - | Some id' -> - ArgSplits.add id' Total s, - Bindings.add id (Have (ArgSplits.singleton id' Total,CallerArgSet.empty,CallerKidSet.empty)) v, + match translate_loc (id_loc id) with + | Some loc -> + ArgSplits.add (id,loc) Total s, + Bindings.add id (Have (ArgSplits.singleton (id,loc) Total, ExtraSplits.empty)) v, k | None -> s, @@ -2719,12 +2889,16 @@ let initial_env fn_id (TypQ_aux (tq,_)) pat set_assertions = | P_typ (_,pat) -> aux pat | P_id id -> begin - match translate_id id with - | Some id' -> - let s = ArgSplits.singleton id' (default_split (l,annot)) in + match translate_loc (id_loc id) with + | Some loc -> + let kid_opt = kid_of_annot (l,annot) in + let split = Util.option_cases kid_opt (default_split annot) (fun () -> Total) in + let s = ArgSplits.singleton (id,loc) split in s, - Bindings.singleton id (Have (s,CallerArgSet.empty,CallerKidSet.empty)), - KBindings.empty + Bindings.singleton id (Have (s, ExtraSplits.empty)), + (match kid_opt with + | None -> KBindings.empty + | Some kid -> KBindings.singleton kid (Have (s, ExtraSplits.empty))) | None -> ArgSplits.empty, Bindings.singleton id (Unknown (l, ("Unable to give location for " ^ string_of_id id))), @@ -2732,7 +2906,7 @@ let initial_env fn_id (TypQ_aux (tq,_)) pat set_assertions = end | P_var (pat, TP_var kid) -> let s,v,k = aux pat in - s,v,KBindings.add kid (Have (ArgSplits.empty,CallerArgSet.singleton (fn_id,i),CallerKidSet.empty)) k + s,v,KBindings.add kid (Have (s, ExtraSplits.empty)) k | P_app (_,pats) -> of_list pats | P_record (fpats,_) -> of_list (List.map (fun (FP_aux (FP_Fpat (_,p),_)) -> p) fpats) | P_vector pats @@ -2743,20 +2917,31 @@ let initial_env fn_id (TypQ_aux (tq,_)) pat set_assertions = | P_cons (p1,p2) -> of_list [p1;p2] in aux pat in - let quant k = function + let quant = function | QI_aux (QI_id (KOpt_aux ((KOpt_none kid | KOpt_kind (_,kid)),_)),_) -> - KBindings.add kid (Have (ArgSplits.empty,CallerArgSet.empty,CallerKidSet.singleton (fn_id,kid))) k - | QI_aux (QI_const _,_) -> k + Some kid + | QI_aux (QI_const _,_) -> None in - let kid_quant_deps = + let top_kids = match tq with - | TypQ_no_forall -> KBindings.empty - | TypQ_tq qs -> List.fold_left quant KBindings.empty qs + | TypQ_no_forall -> [] + | TypQ_tq qs -> Util.map_filter quant qs in let _,var_deps,kid_deps = split3 (List.mapi arg pats) in let var_deps = List.fold_left dep_bindings_merge Bindings.empty var_deps in - let kid_deps = List.fold_left dep_kbindings_merge kid_quant_deps kid_deps in - let top_kids = List.map fst (KBindings.bindings kid_quant_deps) in + let kid_deps = List.fold_left dep_kbindings_merge KBindings.empty kid_deps in + let note_no_arg kid_deps kid = + if KBindings.mem kid kid_deps then kid_deps + else + (* When there's no argument to case split on for a kid, we'll add a + case expression instead *) + let env = pat_env_of pat in + let split = default_split (Some (env,int_typ,no_effect)) kid in + let extra_splits = ExtraSplits.singleton (fn_id, fn_l) + (KBindings.singleton kid split) in + KBindings.add kid (Have (ArgSplits.empty, extra_splits)) kid_deps + in + let kid_deps = List.fold_left note_no_arg kid_deps top_kids in { top_kids = top_kids; var_deps = var_deps; kid_deps = kid_deps } (* When there's more than one pick the first *) @@ -2767,8 +2952,48 @@ let merge_set_asserts _ x y = let merge_set_asserts_by_kid sets1 sets2 = KBindings.merge merge_set_asserts sets1 sets2 +(* Set constraints in assertions don't always use the set syntax, so we also + handle assert('N == 1 | ...) style set constraints *) +let rec sets_from_assert e = + let set_from_or_exps (E_aux (_,(l,_)) as e) = + let mykid = ref None in + let check_kid kid = + match !mykid with + | None -> mykid := Some kid + | Some kid' -> if Kid.compare kid kid' == 0 then () + else raise Not_found + in + let rec aux (E_aux (e,_)) = + match e with + | E_app (Id_aux (Id "or_bool",_),[e1;e2]) -> + aux e1 @ aux e2 + | E_app (Id_aux (Id "eq_atom",_), + [E_aux (E_sizeof (Nexp_aux (Nexp_var kid,_)),_); + E_aux (E_lit (L_aux (L_num i,_)),_)]) -> + (check_kid kid; [i]) + | _ -> raise Not_found + in try + let is = aux e in + match !mykid with + | None -> KBindings.empty + | Some kid -> KBindings.singleton kid (l,is) + with Not_found -> KBindings.empty + in + let rec sets_from_nc (NC_aux (nc,l)) = + match nc with + | NC_and (nc1,nc2) -> merge_set_asserts_by_kid (sets_from_nc nc1) (sets_from_nc nc2) + | NC_set (kid,is) -> KBindings.singleton kid (l,is) + | _ -> KBindings.empty + in + match e with + | E_aux (E_app (Id_aux (Id "and_bool",_),[e1;e2]),_) -> + merge_set_asserts_by_kid (sets_from_assert e1) (sets_from_assert e2) + | E_aux (E_constraint nc,_) -> sets_from_nc nc + | _ -> set_from_or_exps e + (* Find all the easily reached set assertions in a function body, to use as - case splits *) + case splits. Note that this should be mirrored in stop_at_false_assertions, + above. *) let rec find_set_assertions (E_aux (e,_)) = match e with | E_block es @@ -2781,11 +3006,7 @@ let rec find_set_assertions (E_aux (e,_)) = let kbound = kids_bound_by_pat p in let sets2 = KBindings.filter (fun kid _ -> not (KidSet.mem kid kbound)) sets2 in merge_set_asserts_by_kid sets1 sets2 - | E_assert (E_aux (e1,_),_) -> begin - match e1 with - | E_constraint (NC_aux (NC_set (kid,is),l)) -> KBindings.singleton kid (l,is) - | _ -> KBindings.empty - end + | E_assert (exp1,_) -> sets_from_assert exp1 | _ -> KBindings.empty let print_set_assertions set_assertions = @@ -2802,18 +3023,20 @@ let print_set_assertions set_assertions = let print_result r = let _ = print_endline (" splits: " ^ string_of_argsplits r.split) in let print_kbinding kid dep = - let _ = print_endline (" " ^ string_of_kid kid ^ ": " ^ string_of_dep dep) in + let s = match dep with + | InFun dep -> "InFun " ^ string_of_dep dep + | Parents cks -> string_of_callerkidset cks + in + let _ = print_endline (" " ^ string_of_kid kid ^ ": " ^ s) in () in - let print_binding id (deps,kdep) = + let print_binding id kdep = let _ = print_endline (" " ^ string_of_id id ^ ":") in - let _ = List.iter (fun dep -> print_endline (" " ^ string_of_dep dep)) deps in let _ = KBindings.iter print_kbinding kdep in () in let _ = print_endline " split_on_call: " in let _ = Bindings.iter print_binding r.split_on_call in - let _ = print_endline (" split_in_caller: " ^ string_of_callerset r.split_in_caller) in let _ = print_endline (" kid_in_caller: " ^ string_of_callerkidset r.kid_in_caller) in let _ = print_endline (" failures: \n " ^ (String.concat "\n " @@ -2822,17 +3045,24 @@ let print_result r = (Failures.bindings r.failures)))) in () -let analyse_funcl debug tenv (FCL_aux (FCL_Funcl (id,pexp),_)) = +let analyse_funcl debug tenv (FCL_aux (FCL_Funcl (id,pexp),(l,_))) = let _ = if debug > 2 then print_endline (string_of_id id) else () in let pat,guard,body,_ = destruct_pexp pexp in let (tq,_) = Env.get_val_spec id tenv in let set_assertions = find_set_assertions body in let _ = if debug > 2 then print_set_assertions set_assertions in - let aenv = initial_env id tq pat set_assertions in + let aenv = initial_env id l tq pat set_assertions in let _,_,r = analyse_exp id aenv Bindings.empty body in let r = match guard with | None -> r | Some exp -> let _,_,r' = analyse_exp id aenv Bindings.empty exp in + let r' = + if ExtraSplits.is_empty r'.extra_splits + then r' + else merge r' { empty with failures = + Failures.singleton l (StringSet.singleton + "Case splitting size tyvars in guards not supported") } + in merge r r' in let _ = if debug > 2 then print_result r else () @@ -2844,65 +3074,110 @@ let analyse_def debug env = function | _ -> empty +let detail_to_split = function + | Total -> None + | Partial (pats,l) -> Some (pats,l) + +let argset_to_list splits = + let l = ArgSplits.bindings splits in + let argelt = function + | ((id,(file,loc)),detail) -> ((file,loc),string_of_id id,detail_to_split detail) + in + List.map argelt l + let analyse_defs debug env (Defs defs) = let r = List.fold_left (fun r d -> merge r (analyse_def debug env d)) empty defs in (* Resolve the interprocedural dependencies *) - let rec chase_deps = function - | Have (splits, caller_args, caller_kids) -> - let splits,fails = CallerArgSet.fold add_arg caller_args (splits,Failures.empty) in - let splits,fails = CallerKidSet.fold add_kid caller_kids (splits,fails) in - splits, fails + let rec separate_deps = function + | Have (splits, extras) -> + splits, extras, Failures.empty | Unknown (l,msg) -> - ArgSplits.empty , Failures.singleton l (StringSet.singleton ("Unable to monomorphise dependency: " ^ msg)) + ArgSplits.empty, ExtraSplits.empty, + Failures.singleton l (StringSet.singleton ("Unable to monomorphise dependency: " ^ msg)) and chase_kid_caller (id,kid) = match Bindings.find id r.split_on_call with - | (_,kid_deps) -> begin + | kid_deps -> begin match KBindings.find kid kid_deps with - | deps -> chase_deps deps - | exception Not_found -> ArgSplits.empty,Failures.empty + | InFun deps -> separate_deps deps + | Parents fns -> CallerKidSet.fold add_kid fns (ArgSplits.empty, ExtraSplits.empty, Failures.empty) + | exception Not_found -> ArgSplits.empty,ExtraSplits.empty,Failures.empty end - | exception Not_found -> ArgSplits.empty,Failures.empty - and chase_arg_caller (id,i) = - match Bindings.find id r.split_on_call with - | (arg_deps,_) -> chase_deps (List.nth arg_deps i) - | exception Not_found -> ArgSplits.empty,Failures.empty - and add_arg arg (splits,fails) = - let splits',fails' = chase_arg_caller arg in - ArgSplits.merge merge_detail splits splits', Failures.merge failure_merge fails fails' - and add_kid k (splits,fails) = - let splits',fails' = chase_kid_caller k in - ArgSplits.merge merge_detail splits splits', Failures.merge failure_merge fails fails' + | exception Not_found -> ArgSplits.empty,ExtraSplits.empty,Failures.empty + and add_kid k (splits,extras,fails) = + let splits',extras',fails' = chase_kid_caller k in + ArgSplits.merge merge_detail splits splits', + merge_extras extras extras', + Failures.merge failure_merge fails fails' in let _ = if debug > 1 then print_result r else () in - let splits,fails = CallerArgSet.fold add_arg r.split_in_caller (r.split,r.failures) in - let splits,fails = CallerKidSet.fold add_kid r.kid_in_caller (splits,fails) in + let splits,extras,fails = CallerKidSet.fold add_kid r.kid_in_caller (r.split,r.extra_splits,r.failures) in let _ = if debug > 0 then (print_endline "Final splits:"; - print_endline (string_of_argsplits splits)) + print_endline (string_of_argsplits splits); + print_endline (string_of_extra_splits extras)) else () in - let _ = - if Failures.is_empty fails then () else - begin - Failures.iter (fun l msgs -> - Reporting_basic.print_err false false l "Monomorphisation" (String.concat "\n" (StringSet.elements msgs))) - fails; - raise (Reporting_basic.err_general Unknown "Unable to monomorphise program") - end - in splits + let splits = argset_to_list splits in + if Failures.is_empty fails + then (true,splits,extras) else + begin + Failures.iter (fun l msgs -> + Reporting_basic.print_err false false l "Monomorphisation" (String.concat "\n" (StringSet.elements msgs))) + fails; + (false, splits,extras) + end -let argset_to_list splits = - let l = ArgSplits.bindings splits in - let argelt = function - | ((id,(file,loc)),Total) -> ((file,loc),string_of_id id,None) - | ((id,(file,loc)),Partial (pats,l)) -> ((file,loc),string_of_id id,Some (pats,l)) - in - List.map argelt l end +let fresh_sz_var = + let counter = ref 0 in + fun () -> + let n = !counter in + let () = counter := n+1 in + mk_id ("sz#" ^ string_of_int n) + +let add_extra_splits extras (Defs defs) = + let success = ref true in + let add_to_body extras (E_aux (_,(l,annot)) as e) = + let l' = Generated l in + KBindings.fold (fun kid detail (exp,split_list) -> + let nexp = Nexp_aux (Nexp_var kid,l) in + let var = fresh_sz_var () in + let size_annot = Some (env_of e,atom_typ nexp,no_effect) in + let loc = match Analysis.translate_loc l with + | Some l -> l + | None -> + (Reporting_basic.print_err false false l "Monomorphisation" + "Internal error: bad location for added case"; + ("",0)) + in + let pexps = [Pat_aux (Pat_exp (P_aux (P_id var,(l,size_annot)),exp),(l',annot))] in + E_aux (E_case (E_aux (E_sizeof nexp, (l',size_annot)), pexps),(l',annot)), + ((loc, string_of_id var, Analysis.detail_to_split detail)::split_list) + ) extras (e,[]) + in + let add_to_funcl (FCL_aux (FCL_Funcl (id,Pat_aux (pexp,peannot)),(l,annot))) = + let pexp, splits = + match Analysis.ExtraSplits.find (id,l) extras with + | extras -> + (match pexp with + | Pat_exp (p,e) -> let e',sp = add_to_body extras e in Pat_exp (p,e'), sp + | Pat_when (p,g,e) -> let e',sp = add_to_body extras e in Pat_when (p,g,e'), sp) + | exception Not_found -> pexp, [] + in FCL_aux (FCL_Funcl (id,Pat_aux (pexp,peannot)),(l,annot)), splits + in + let add_to_def = function + | DEF_fundef (FD_aux (FD_function (re,ta,ef,funcls),annot)) -> + let funcls,splits = List.split (List.map add_to_funcl funcls) in + DEF_fundef (FD_aux (FD_function (re,ta,ef,funcls),annot)), List.concat splits + | d -> d, [] + in + let defs, splits = List.split (List.map add_to_def defs) in + !success, Defs defs, List.concat splits + module MonoRewrites = struct @@ -3125,6 +3400,7 @@ type options = { rewrites : bool; rewrite_size_parameters : bool; all_split_errors : bool; + continue_anyway : bool; dump_raw: bool } @@ -3144,13 +3420,26 @@ let monomorphise opts splits env defs = else (defs,env) in (*let _ = Pretty_print.pp_defs stdout defs in*) - let new_splits = + let ok_analysis, new_splits, extra_splits = if opts.auto - then Analysis.argset_to_list (Analysis.analyse_defs opts.debug_analysis env defs) - else [] in + then + let f,r,ex = Analysis.analyse_defs opts.debug_analysis env defs in + if f || opts.all_split_errors || opts.continue_anyway + then f, r, ex + else raise (Reporting_basic.err_general Unknown "Unable to monomorphise program") + else true, [], Analysis.ExtraSplits.empty in let splits = new_splits @ (List.map (fun (loc,id) -> (loc,id,None)) splits) in - let defs = split_defs opts.all_split_errors splits defs in - (* TODO: stop if opts.all_split_errors && something went wrong *) + let ok_extras, defs, extra_splits = add_extra_splits extra_splits defs in + let splits = splits @ extra_splits in + let () = if ok_extras || opts.all_split_errors || opts.continue_anyway + then () + else raise (Reporting_basic.err_general Unknown "Unable to monomorphise program") + in + let ok_split, defs = split_defs opts.all_split_errors splits defs in + let () = if (ok_analysis && ok_extras && ok_split) || opts.continue_anyway + then () + else raise (Reporting_basic.err_general Unknown "Unable to monomorphise program") + in (* TODO: currently doing this because constant propagation leaves numeric literals as int, try to avoid this later; also use final env for DEF_spec case above, because the type checker doesn't store the env at that point :( *) diff --git a/src/monomorphise.mli b/src/monomorphise.mli index 11713511..3e561e32 100644 --- a/src/monomorphise.mli +++ b/src/monomorphise.mli @@ -54,6 +54,7 @@ type options = { rewrites : bool; (* Experimental rewrites for variable-sized operations *) rewrite_size_parameters : bool; (* Make implicit type parameters explicit for (e.g.) lem *) all_split_errors : bool; + continue_anyway : bool; dump_raw: bool } diff --git a/src/pretty_print_lem.ml b/src/pretty_print_lem.ml index 40d373b2..350b5388 100644 --- a/src/pretty_print_lem.ml +++ b/src/pretty_print_lem.ml @@ -1210,7 +1210,8 @@ let doc_mutrec_lem = function let rec doc_fundef_lem (FD_aux(FD_function(r, typa, efa, fcls),fannot) as fd) = match fcls with | [] -> failwith "FD_function with empty function list" - | FCL_aux (FCL_Funcl(id,_),_) :: _ + (* TODO: Move splitting of execute function to the rewriter *) + (*| FCL_aux (FCL_Funcl(id,_),_) :: _ when string_of_id id = "execute" (*|| string_of_id id = "initial_analysis"*) -> let (_,auxiliary_functions,clauses) = List.fold_left @@ -1273,7 +1274,7 @@ let rec doc_fundef_lem (FD_aux(FD_function(r, typa, efa, fcls),fannot) as fd) = auxiliary_functions ^^ hardline ^^ hardline ^^ (prefix 2 1) ((separate space) [string "let" ^^ doc_rec_lem false r ^^ doc_id_lem id;equals;string "function"]) - (clauses ^/^ string "end") + (clauses ^/^ string "end")*) | FCL_aux (FCL_Funcl(id,_),annot) :: _ when not (Env.is_extern id (env_of_annot annot) "lem") -> string "let" ^^ (doc_rec_lem (List.length fcls > 1) r) ^^ (doc_fundef_rhs_lem fd) @@ -1495,15 +1496,15 @@ let pp_defs_lem (types_file,types_modules) (defs_file,defs_modules) d top_line = if !opt_sequential then concat [regstate_def; hardline; hardline; - string ("type MR 'a 'r = State.MR regstate 'a 'r " ^ exc_typ); hardline; - string ("type M 'a = State.M regstate 'a " ^ exc_typ); hardline; + string ("type MR 'a 'r = State_monad.MR regstate 'a 'r " ^ exc_typ); hardline; + string ("type M 'a = State_monad.M regstate 'a " ^ exc_typ); hardline; hardline; register_refs ] else concat [ - string ("type MR 'a 'r = Prompt.MR 'a 'r " ^ exc_typ); hardline; - string ("type M 'a = Prompt.M 'a " ^ exc_typ); hardline + string ("type MR 'a 'r = Prompt_monad.MR 'a 'r " ^ exc_typ); hardline; + string ("type M 'a = Prompt_monad.M 'a " ^ exc_typ); hardline ] ]); (print defs_file) diff --git a/src/process_file.ml b/src/process_file.ml index 1ba8069f..1da893c3 100644 --- a/src/process_file.ml +++ b/src/process_file.ml @@ -93,7 +93,7 @@ let cond_pragma defs = else else_defs := (def :: !else_defs) in - + let rec scan = function | Parse_ast.DEF_pragma ("endif", _, _) :: defs when !depth = 0 -> (List.rev !then_defs, List.rev !else_defs, defs) @@ -108,13 +108,13 @@ let cond_pragma defs = | [] -> failwith "$ifdef or $ifndef never ended" in scan defs - + let rec preprocess = function | [] -> [] | Parse_ast.DEF_pragma ("define", symbol, _) :: defs -> symbols := StringSet.add symbol !symbols; preprocess defs - + | Parse_ast.DEF_pragma ("ifndef", symbol, _) :: defs -> let then_defs, else_defs, defs = cond_pragma defs in if not (StringSet.mem symbol !symbols) then @@ -128,7 +128,7 @@ let rec preprocess = function preprocess (then_defs @ defs) else preprocess (else_defs @ defs) - + | Parse_ast.DEF_pragma ("include", file, l) :: defs -> let len = String.length file in if len = 0 then @@ -151,18 +151,18 @@ let rec preprocess = function let file = Filename.concat sail_dir ("lib/" ^ file) in let (Parse_ast.Defs include_defs) = parse_file file in let include_defs = preprocess include_defs in - include_defs @ preprocess defs + include_defs @ preprocess defs else let help = "Make sure the filename is surrounded by quotes or angle brackets" in (Util.warn ("Skipping bad $include " ^ file ^ ". " ^ help); preprocess defs) | Parse_ast.DEF_pragma (p, arg, _) :: defs -> - (Util.warn ("Bad pragma $" ^ p ^ " " ^ arg); preprocess defs) - + (Util.warn ("Bad pragma $" ^ p ^ " " ^ arg); preprocess defs) + | def :: defs -> def :: preprocess defs let preprocess_ast (Parse_ast.Defs defs) = Parse_ast.Defs (preprocess defs) - + let convert_ast (order : Ast.order) (defs : Parse_ast.defs) : unit Ast.defs = Initial_check.process_ast order defs let load_file_no_check order f = convert_ast order (preprocess_ast (parse_file f)) @@ -188,6 +188,7 @@ let opt_dmono_analysis = ref 0 let opt_auto_mono = ref false let opt_mono_rewrites = ref false let opt_dall_split_errors = ref false +let opt_dmono_continue = ref false let monomorphise_ast locs type_env ast = let open Monomorphise in @@ -197,6 +198,7 @@ let monomorphise_ast locs type_env ast = rewrites = !opt_mono_rewrites; rewrite_size_parameters = !Pretty_print_lem.opt_mwords; all_split_errors = !opt_dall_split_errors; + continue_anyway = !opt_dmono_continue; dump_raw = !opt_ddump_raw_mono_ast } in monomorphise opts locs type_env ast @@ -226,16 +228,22 @@ let output_lem filename libs defs = let generated_line = generated_line filename in let seq_suffix = if !Pretty_print_lem.opt_sequential then "_sequential" else "" in let types_module = (filename ^ "_embed_types" ^ seq_suffix) in - let monad_module = if !Pretty_print_lem.opt_sequential then "State" else "Prompt" in - let operators_module = "Sail_operators" (* if !Pretty_print_lem.opt_mwords then "Sail_operators_mwords" else "Sail_operators" *) in + let monad_modules = + if !Pretty_print_lem.opt_sequential + then ["State_monad"; "State"] + else ["Prompt_monad"; "Prompt"] in + let operators_module = + if !Pretty_print_lem.opt_mwords + then "Sail_operators_mwords" + else "Sail_operators_bitlists" in let libs = List.map (fun lib -> lib ^ seq_suffix) libs in let base_imports = [ "Pervasives_extra"; "Sail_impl_base"; "Sail_values"; - operators_module; - monad_module - ] in + operators_module + ] @ monad_modules + in let ((ot,_, _) as ext_ot) = open_output_with_check_unformatted (filename ^ "_embed_types" ^ seq_suffix ^ ".lem") in let ((o,_, _) as ext_o) = diff --git a/src/process_file.mli b/src/process_file.mli index d8094682..54415621 100644 --- a/src/process_file.mli +++ b/src/process_file.mli @@ -72,6 +72,7 @@ val opt_ddump_rewrite_ast : ((string * int) option) ref val opt_dno_cast : bool ref val opt_ddump_raw_mono_ast : bool ref val opt_dmono_analysis : int ref +val opt_dmono_continue : bool ref val opt_auto_mono : bool ref val opt_mono_rewrites : bool ref val opt_dall_split_errors : bool ref diff --git a/src/rewrites.ml b/src/rewrites.ml index e38169cc..6146e73b 100644 --- a/src/rewrites.ml +++ b/src/rewrites.ml @@ -1203,43 +1203,50 @@ let remove_bitvector_pat (P_aux (_, (l, _)) as pat) = let guard_bitvector_pat = let collect_guards_decls ps rootid t = let (start,_,ord,_) = vector_typ_args_of t in - let rec collect current (guards,dls) idx ps = - let idx' = if is_order_inc ord then Big_int.add idx (Big_int.of_int 1) else Big_int.sub idx (Big_int.of_int 1) in - (match ps with - | pat :: ps' -> - (match pat with - | P_aux (P_lit lit, (l,annot)) -> - let e = E_aux (E_lit lit, (gen_loc l, annot)) in - let current' = (match current with - | Some (l,i,j,lits) -> Some (l,i,idx,lits @ [e]) - | None -> Some (l,idx,idx,[e])) in - collect current' (guards, dls) idx' ps' - | P_aux (P_as (pat',id), (l,annot)) -> - let dl = letbind_bit_exp rootid l t idx id in - collect current (guards, dls @ [dl]) idx (pat' :: ps') - | _ -> - let dls' = (match pat with - | P_aux (P_id id, (l,annot)) -> - dls @ [letbind_bit_exp rootid l t idx id] - | _ -> dls) in - let guards' = (match current with - | Some (l,i,j,lits) -> - guards @ [Some (test_subvec_exp rootid l t i j lits)] - | None -> guards) in - collect None (guards', dls') idx' ps') - | [] -> - let guards' = (match current with - | Some (l,i,j,lits) -> - guards @ [Some (test_subvec_exp rootid l t i j lits)] - | None -> guards) in - (guards',dls)) in - let (guards,dls) = match start with - | Nexp_aux (Nexp_constant s, _) -> - collect None ([],[]) s ps + let start_idx = match start with + | Nexp_aux (Nexp_constant s, _) -> s | _ -> - let (P_aux (_, (l,_))) = pat in raise (Reporting_basic.err_unreachable l "guard_bitvector_pat called on pattern with non-constant start index") in + let add_bit_pat (idx, current, guards, dls) pat = + let idx' = + if is_order_inc ord + then Big_int.add idx (Big_int.of_int 1) + else Big_int.sub idx (Big_int.of_int 1) in + let ids = fst (fold_pat + { (compute_pat_alg IdSet.empty IdSet.union) with + p_id = (fun id -> IdSet.singleton id, P_id id); + p_as = (fun ((ids, pat), id) -> IdSet.add id ids, P_as (pat, id)) } + pat) in + let lits = fst (fold_pat + { (compute_pat_alg [] (@)) with + p_aux = (fun ((lits, paux), (l, annot)) -> + let lits = match paux with + | P_lit lit -> E_aux (E_lit lit, (l, annot)) :: lits + | _ -> lits in + lits, P_aux (paux, (l, annot))) } + pat) in + let add_letbind id dls = dls @ [letbind_bit_exp rootid l t idx id] in + let dls' = IdSet.fold add_letbind ids dls in + let current', guards' = + match current with + | Some (l, i, j, lits') -> + if lits = [] + then None, guards @ [Some (test_subvec_exp rootid l t i j lits')] + else Some (l, i, idx, lits' @ lits), guards + | None -> + begin + match lits with + | E_aux (_, (l, _)) :: _ -> Some (l, idx, idx, lits), guards + | [] -> None, guards + end + in + (idx', current', guards', dls') in + let (_, final, guards, dls) = List.fold_left add_bit_pat (start_idx, None, [], []) ps in + let guards = match final with + | Some (l,i,j,lits) -> + guards @ [Some (test_subvec_exp rootid l t i j lits)] + | None -> guards in let (decls,letbinds) = List.split dls in (compose_guards guards, List.fold_right (@@) decls, letbinds) in @@ -2951,10 +2958,10 @@ let rewrite_defs_lem = [ ("remove_bitvector_pats", rewrite_defs_remove_bitvector_pats); ("remove_numeral_pats", rewrite_defs_remove_numeral_pats); ("guarded_pats", rewrite_defs_guarded_pats); - ("exp_lift_assign", rewrite_defs_exp_lift_assign); (* ("register_ref_writes", rewrite_register_ref_writes); *) ("fix_val_specs", rewrite_fix_val_specs); ("recheck_defs", recheck_defs); + ("exp_lift_assign", rewrite_defs_exp_lift_assign); (* ("constraint", rewrite_constraint); *) (* ("remove_assert", rewrite_defs_remove_assert); *) ("top_sort_defs", top_sort_defs); @@ -3021,15 +3028,25 @@ let rewrite_check_annot = try prerr_endline ("CHECKING: " ^ string_of_exp exp ^ " : " ^ string_of_typ (typ_of exp)); let _ = check_exp (env_of exp) (strip_exp exp) (typ_of exp) in - (if not (alpha_equivalent (env_of exp) (typ_of exp) (Env.expand_synonyms (env_of exp) (typ_of exp))) - then raise (Reporting_basic.err_typ Parse_ast.Unknown "Found synonym in annotation") + let typ1 = typ_of exp in + let typ2 = Env.expand_synonyms (env_of exp) (typ_of exp) in + (if not (alpha_equivalent (env_of exp) typ1 typ2) + then raise (Reporting_basic.err_typ Parse_ast.Unknown + ("Found synonym in annotation " ^ string_of_typ typ1 ^ " vs " ^ string_of_typ typ2)) else ()); exp with Type_error (l, err) -> raise (Reporting_basic.err_typ l (string_of_type_error err)) in + let check_pat pat = + prerr_endline ("CHECKING PAT: " ^ string_of_pat pat ^ " : " ^ string_of_typ (pat_typ_of pat)); + let _, _ = bind_pat_no_guard (pat_env_of pat) (strip_pat pat) (pat_typ_of pat) in + pat + in + let rewrite_exp = { id_exp_alg with e_aux = (fun (exp, annot) -> check_annot (E_aux (exp, annot))) } in - rewrite_defs_base { rewriters_base with rewrite_exp = (fun _ -> fold_exp rewrite_exp) } + rewrite_defs_base { rewriters_base with rewrite_exp = (fun _ -> fold_exp rewrite_exp); + rewrite_pat = (fun _ -> check_pat) } let rewrite_defs_check = [ ("check_annotations", rewrite_check_annot); diff --git a/src/sail.ml b/src/sail.ml index bbe26a0d..dac2f841 100644 --- a/src/sail.ml +++ b/src/sail.ml @@ -152,6 +152,9 @@ let options = Arg.align ([ ( "-dall_split_errors", Arg.Set Process_file.opt_dall_split_errors, " display all case split errors from monomorphisation, rather than one"); + ( "-dmono_continue", + Arg.Set Process_file.opt_dmono_continue, + " continue despite monomorphisation errors"); ( "-verbose", Arg.Set opt_print_verbose, " (debug) pretty-print the input to standard output"); @@ -256,7 +259,8 @@ let main() = (if !(opt_print_c) then let ast_c = rewrite_ast_c ast in - C_backend.compile_ast type_envs ast_c + let ast_c, type_envs = Specialize.specialize ast_c type_envs in + C_backend.compile_ast (C_backend.initial_ctx type_envs) ast_c else ()); (if !(opt_print_lem) then let ast_lem = rewrite_ast_lem ast in diff --git a/src/sail_lib.ml b/src/sail_lib.ml index 7a8dc88c..4d6e32bc 100644 --- a/src/sail_lib.ml +++ b/src/sail_lib.ml @@ -176,6 +176,31 @@ let add_int (x, y) = Big_int.add x y let sub_int (x, y) = Big_int.sub x y let mult (x, y) = Big_int.mul x y let quotient (x, y) = Big_int.div x y + +(* Big_int does not provide divide with rounding towards zero so roll + our own, assuming that division of positive integers rounds down *) +let quot_round_zero (x, y) = + let posX = Big_int.greater_equal x Big_int.zero in + let posY = Big_int.greater_equal y Big_int.zero in + let absX = Big_int.abs x in + let absY = Big_int.abs y in + let q = Big_int.div absX absY in + if posX != posY then + Big_int.negate q + else + q + +(* The corresponding remainder function for above just respects the sign of x *) +let rem_round_zero (x, y) = + let posX = Big_int.greater_equal x Big_int.zero in + let absX = Big_int.abs x in + let absY = Big_int.abs y in + let r = Big_int.modulus absX absY in + if posX then + r + else + Big_int.negate r + let modulus (x, y) = Big_int.modulus x y let negate x = Big_int.negate x @@ -425,6 +450,7 @@ let round_up x = failwith "round_up" (* Num.big_int_of_num (Num.ceiling_num x) * let quotient_real (x, y) = Rational.div x y let mult_real (x, y) = Rational.mul x y (* Num.mult_num x y *) let real_power (x, y) = failwith "real_power" (* Num.power_num x (Num.num_of_big_int y) *) +let int_power (x, y) = Big_int.pow_int x (Big_int.to_int y) let add_real (x, y) = Rational.add x y let sub_real (x, y) = Rational.sub x y diff --git a/src/trace_viewer/.gitignore b/src/trace_viewer/.gitignore deleted file mode 100644 index c1f9aea6..00000000 --- a/src/trace_viewer/.gitignore +++ /dev/null @@ -1,6 +0,0 @@ -*~ -*.js -*.js.map - -# Dependencies -node_modules/ diff --git a/src/trace_viewer/List-add.svg b/src/trace_viewer/List-add.svg deleted file mode 100644 index f8031599..00000000 --- a/src/trace_viewer/List-add.svg +++ /dev/null @@ -1,56 +0,0 @@ -<?xml version="1.0" encoding="UTF-8" standalone="no"?> -<!-- Created with Inkscape (http://www.inkscape.org/) --> -<svg id="svg6431" xmlns:rdf="http://www.w3.org/1999/02/22-rdf-syntax-ns#" xmlns="http://www.w3.org/2000/svg" sodipodi:docname="list-add.svg" xmlns:sodipodi="http://inkscape.sourceforge.net/DTD/sodipodi-0.dtd" height="48px" sodipodi:version="0.32" width="48px" xmlns:cc="http://web.resource.org/cc/" xmlns:xlink="http://www.w3.org/1999/xlink" sodipodi:docbase="/home/jimmac/src/cvs/tango-icon-theme/scalable/actions" xmlns:dc="http://purl.org/dc/elements/1.1/"> - <defs id="defs6433"> - <linearGradient id="linearGradient4975" y2="48.548" gradientUnits="userSpaceOnUse" x2="45.919" gradientTransform="translate(-18.018 -13.571)" y1="36.423" x1="34.893"> - <stop id="stop1324" stop-color="#729fcf" offset="0"/> - <stop id="stop1326" stop-color="#5187d6" offset="1"/> - </linearGradient> - <linearGradient id="linearGradient7922" y2="34.977" gradientUnits="userSpaceOnUse" x2="27.901" y1="22.852" x1="16.875"> - <stop id="stop7918" stop-color="#fff" offset="0"/> - <stop id="stop7920" stop-color="#fff" stop-opacity=".34021" offset="1"/> - </linearGradient> - <radialGradient id="radialGradient2097" gradientUnits="userSpaceOnUse" cy="35.127" cx="23.071" gradientTransform="matrix(.91481 .012650 -.0082150 .21356 2.2539 27.189)" r="10.319"> - <stop id="stop2093" offset="0"/> - <stop id="stop2095" stop-opacity="0" offset="1"/> - </radialGradient> - </defs> - <sodipodi:namedview id="base" bordercolor="#666666" pagecolor="#ffffff" showgrid="false" borderopacity="0.15686275" showguides="true"/> - <metadata id="metadata6436"> - <rdf:RDF> - <cc:Work rdf:about=""> - <dc:format>image/svg+xml</dc:format> - <dc:type rdf:resource="http://purl.org/dc/dcmitype/StillImage"/> - <dc:title>Add</dc:title> - <dc:date>2006-01-04</dc:date> - <dc:creator> - <cc:Agent> - <dc:title>Andreas Nilsson</dc:title> - </cc:Agent> - </dc:creator> - <dc:source>http://tango-project.org</dc:source> - <dc:subject> - <rdf:Bag> - <rdf:li>add</rdf:li> - <rdf:li>plus</rdf:li> - </rdf:Bag> - </dc:subject> - <cc:license rdf:resource="http://creativecommons.org/licenses/by-sa/2.0/"/> - </cc:Work> - <cc:License rdf:about="http://creativecommons.org/licenses/by-sa/2.0/"> - <cc:permits rdf:resource="http://web.resource.org/cc/Reproduction"/> - <cc:permits rdf:resource="http://web.resource.org/cc/Distribution"/> - <cc:requires rdf:resource="http://web.resource.org/cc/Notice"/> - <cc:requires rdf:resource="http://web.resource.org/cc/Attribution"/> - <cc:permits rdf:resource="http://web.resource.org/cc/DerivativeWorks"/> - <cc:requires rdf:resource="http://web.resource.org/cc/ShareAlike"/> - </cc:License> - </rdf:RDF> - </metadata> - <g id="layer1"> - <path id="path1361" sodipodi:rx="10.319340" sodipodi:ry="2.3201940" sodipodi:type="arc" d="m33.278 34.941a10.319 2.3202 0 1 1 -20.638 0 10.319 2.3202 0 1 1 20.638 0z" opacity=".2" transform="matrix(1.5505 0 0 1.293 -11.597 -8.1782)" sodipodi:cy="34.940620" sodipodi:cx="22.958872" fill="url(#radialGradient2097)"/> - <path id="text1314" d="m27.514 37.543v-9.027l9.979-0.04v-6.996h-9.97l-0.009-9.96-7.016 0.011 0.005 9.931-9.99 0.074-0.036 6.969 10.034-0.029 0.007 9.04 6.996 0.027z" sodipodi:nodetypes="ccccccccccccc" stroke="#3465a4" stroke-width="1px" fill="#75a1d0"/> - <path id="path7076" opacity=".40860" d="m26.499 36.534v-9.034h10.002l-0.006-5.025h-9.987v-9.995l-4.995 0.018 0.009 9.977-10.026 0.018-0.027 4.973 10.064 0.009-0.013 9.028 4.979 0.031z" sodipodi:nodetypes="ccccccccccccc" stroke="url(#linearGradient7922)" stroke-width="1px" fill="url(#linearGradient4975)"/> - <path id="path7914" opacity=".31183" d="m11 25c0 1.938 25.984-0.969 25.984-0.031v-3l-9.984 0.031v-9.965h-6v9.965h-10v3z" fill-rule="evenodd" sodipodi:nodetypes="ccccccccc" fill="#fff"/> - </g> -</svg> diff --git a/src/trace_viewer/List-remove.svg b/src/trace_viewer/List-remove.svg deleted file mode 100644 index 18c9a135..00000000 --- a/src/trace_viewer/List-remove.svg +++ /dev/null @@ -1,117 +0,0 @@ -<?xml version="1.0" encoding="UTF-8" standalone="no"?> -<!-- Created with Inkscape (http://www.inkscape.org/) --> -<svg xmlns:dc="http://purl.org/dc/elements/1.1/" xmlns:cc="http://web.resource.org/cc/" xmlns:rdf="http://www.w3.org/1999/02/22-rdf-syntax-ns#" xmlns:svg="http://www.w3.org/2000/svg" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" xmlns:sodipodi="http://inkscape.sourceforge.net/DTD/sodipodi-0.dtd" xmlns:inkscape="http://www.inkscape.org/namespaces/inkscape" width="48px" height="48px" id="svg6431" sodipodi:version="0.32" inkscape:version="0.43+devel" sodipodi:docbase="/home/jimmac/src/cvs/tango-icon-theme/scalable/actions" sodipodi:docname="list-remove.svg"> - <defs id="defs6433"> - <linearGradient inkscape:collect="always" id="linearGradient2091"> - <stop style="stop-color:#000000;stop-opacity:1;" offset="0" id="stop2093"/> - <stop style="stop-color:#000000;stop-opacity:0;" offset="1" id="stop2095"/> - </linearGradient> - <radialGradient inkscape:collect="always" xlink:href="#linearGradient2091" id="radialGradient2097" cx="23.070683" cy="35.127438" fx="23.070683" fy="35.127438" r="10.319340" gradientTransform="matrix(0.914812,1.265023e-2,-8.21502e-3,0.213562,2.253914,27.18889)" gradientUnits="userSpaceOnUse"/> - <linearGradient id="linearGradient7916"> - <stop style="stop-color:#ffffff;stop-opacity:1;" offset="0" id="stop7918"/> - <stop style="stop-color:#ffffff;stop-opacity:0.34020618;" offset="1.0000000" id="stop7920"/> - </linearGradient> - <linearGradient inkscape:collect="always" id="linearGradient8662"> - <stop style="stop-color:#000000;stop-opacity:1;" offset="0" id="stop8664"/> - <stop style="stop-color:#000000;stop-opacity:0;" offset="1" id="stop8666"/> - </linearGradient> - <radialGradient inkscape:collect="always" xlink:href="#linearGradient8662" id="radialGradient1503" gradientUnits="userSpaceOnUse" gradientTransform="matrix(1.000000,0.000000,0.000000,0.536723,-1.018989e-13,16.87306)" cx="24.837126" cy="36.421127" fx="24.837126" fy="36.421127" r="15.644737"/> - <linearGradient inkscape:collect="always" id="linearGradient2847"> - <stop style="stop-color:#3465a4;stop-opacity:1;" offset="0" id="stop2849"/> - <stop style="stop-color:#3465a4;stop-opacity:0;" offset="1" id="stop2851"/> - </linearGradient> - <linearGradient inkscape:collect="always" xlink:href="#linearGradient2847" id="linearGradient1488" gradientUnits="userSpaceOnUse" gradientTransform="matrix(-1.000000,0.000000,0.000000,-1.000000,-1.242480,40.08170)" x1="37.128052" y1="29.729605" x2="37.065414" y2="26.194071"/> - <linearGradient id="linearGradient2831"> - <stop style="stop-color:#3465a4;stop-opacity:1;" offset="0" id="stop2833"/> - <stop id="stop2855" offset="0.33333334" style="stop-color:#5b86be;stop-opacity:1;"/> - <stop style="stop-color:#83a8d8;stop-opacity:0;" offset="1" id="stop2835"/> - </linearGradient> - <linearGradient inkscape:collect="always" xlink:href="#linearGradient2831" id="linearGradient1486" gradientUnits="userSpaceOnUse" gradientTransform="translate(-48.30498,-6.043298)" x1="13.478554" y1="10.612206" x2="15.419417" y2="19.115122"/> - <linearGradient id="linearGradient2380"> - <stop style="stop-color:#b9cfe7;stop-opacity:1" offset="0" id="stop2382"/> - <stop style="stop-color:#729fcf;stop-opacity:1" offset="1" id="stop2384"/> - </linearGradient> - <linearGradient id="linearGradient2682"> - <stop style="stop-color:#3977c3;stop-opacity:1;" offset="0" id="stop2684"/> - <stop style="stop-color:#89aedc;stop-opacity:0;" offset="1" id="stop2686"/> - </linearGradient> - <linearGradient inkscape:collect="always" xlink:href="#linearGradient2682" id="linearGradient2688" x1="36.713837" y1="31.455952" x2="37.124462" y2="24.842253" gradientUnits="userSpaceOnUse" gradientTransform="translate(-48.77039,-5.765705)"/> - <linearGradient inkscape:collect="always" id="linearGradient2690"> - <stop style="stop-color:#c4d7eb;stop-opacity:1;" offset="0" id="stop2692"/> - <stop style="stop-color:#c4d7eb;stop-opacity:0;" offset="1" id="stop2694"/> - </linearGradient> - <linearGradient inkscape:collect="always" xlink:href="#linearGradient2690" id="linearGradient2696" x1="32.647972" y1="30.748846" x2="37.124462" y2="24.842253" gradientUnits="userSpaceOnUse" gradientTransform="translate(-48.77039,-5.765705)"/> - <linearGradient inkscape:collect="always" id="linearGradient2871"> - <stop style="stop-color:#3465a4;stop-opacity:1;" offset="0" id="stop2873"/> - <stop style="stop-color:#3465a4;stop-opacity:1" offset="1" id="stop2875"/> - </linearGradient> - <linearGradient id="linearGradient2402"> - <stop style="stop-color:#729fcf;stop-opacity:1;" offset="0" id="stop2404"/> - <stop style="stop-color:#528ac5;stop-opacity:1;" offset="1" id="stop2406"/> - </linearGradient> - <linearGradient inkscape:collect="always" xlink:href="#linearGradient2797" id="linearGradient1493" gradientUnits="userSpaceOnUse" x1="5.9649176" y1="26.048164" x2="52.854097" y2="26.048164"/> - <linearGradient inkscape:collect="always" id="linearGradient2797"> - <stop style="stop-color:#ffffff;stop-opacity:1;" offset="0" id="stop2799"/> - <stop style="stop-color:#ffffff;stop-opacity:0;" offset="1" id="stop2801"/> - </linearGradient> - <linearGradient inkscape:collect="always" xlink:href="#linearGradient2797" id="linearGradient1491" gradientUnits="userSpaceOnUse" x1="5.9649176" y1="26.048164" x2="52.854097" y2="26.048164"/> - <linearGradient inkscape:collect="always" id="linearGradient7179"> - <stop style="stop-color:#ffffff;stop-opacity:1;" offset="0" id="stop7181"/> - <stop style="stop-color:#ffffff;stop-opacity:0;" offset="1" id="stop7183"/> - </linearGradient> - <linearGradient id="linearGradient2316"> - <stop style="stop-color:#000000;stop-opacity:1;" offset="0" id="stop2318"/> - <stop style="stop-color:#ffffff;stop-opacity:0.65979379;" offset="1" id="stop2320"/> - </linearGradient> - <linearGradient id="linearGradient1322"> - <stop id="stop1324" offset="0.0000000" style="stop-color:#729fcf"/> - <stop id="stop1326" offset="1.0000000" style="stop-color:#5187d6;stop-opacity:1.0000000;"/> - </linearGradient> - <linearGradient inkscape:collect="always" xlink:href="#linearGradient1322" id="linearGradient4975" x1="34.892849" y1="36.422989" x2="45.918697" y2="48.547989" gradientUnits="userSpaceOnUse" gradientTransform="translate(-18.01785,-13.57119)"/> - <linearGradient inkscape:collect="always" xlink:href="#linearGradient7179" id="linearGradient7185" x1="13.435029" y1="13.604306" x2="22.374878" y2="23.554308" gradientUnits="userSpaceOnUse"/> - <linearGradient inkscape:collect="always" xlink:href="#linearGradient7179" id="linearGradient7189" gradientUnits="userSpaceOnUse" x1="13.435029" y1="13.604306" x2="22.374878" y2="23.554308" gradientTransform="matrix(-1.000000,0.000000,0.000000,-1.000000,47.93934,50.02474)"/> - <linearGradient inkscape:collect="always" xlink:href="#linearGradient2380" id="linearGradient7180" gradientUnits="userSpaceOnUse" x1="62.513836" y1="36.061237" x2="15.984863" y2="20.60858"/> - <linearGradient inkscape:collect="always" xlink:href="#linearGradient2871" id="linearGradient7182" gradientUnits="userSpaceOnUse" x1="46.834816" y1="45.264122" x2="45.380436" y2="50.939667"/> - <linearGradient inkscape:collect="always" xlink:href="#linearGradient2402" id="linearGradient7184" gradientUnits="userSpaceOnUse" x1="18.935766" y1="23.667896" x2="53.588622" y2="26.649362"/> - <linearGradient inkscape:collect="always" xlink:href="#linearGradient2871" id="linearGradient7186" gradientUnits="userSpaceOnUse" x1="46.834816" y1="45.264122" x2="45.380436" y2="50.939667"/> - <linearGradient inkscape:collect="always" xlink:href="#linearGradient7916" id="linearGradient7922" x1="16.874998" y1="22.851799" x2="27.900846" y2="34.976799" gradientUnits="userSpaceOnUse"/> - </defs> - <sodipodi:namedview id="base" pagecolor="#ffffff" bordercolor="#666666" borderopacity="0.10980392" inkscape:pageopacity="0.0" inkscape:pageshadow="2" inkscape:zoom="1" inkscape:cx="38.727739" inkscape:cy="26.974252" inkscape:current-layer="layer1" showgrid="false" inkscape:grid-bbox="true" inkscape:document-units="px" inkscape:window-width="1280" inkscape:window-height="949" inkscape:window-x="380" inkscape:window-y="79" inkscape:showpageshadow="false"/> - <metadata id="metadata6436"> - <rdf:RDF> - <cc:Work rdf:about=""> - <dc:format>image/svg+xml</dc:format> - <dc:type rdf:resource="http://purl.org/dc/dcmitype/StillImage"/> - <dc:title>Remove</dc:title> - <dc:date>2006-01-04</dc:date> - <dc:creator> - <cc:Agent> - <dc:title>Andreas Nilsson</dc:title> - </cc:Agent> - </dc:creator> - <dc:source>http://tango-project.org</dc:source> - <dc:subject> - <rdf:Bag> - <rdf:li>remove</rdf:li> - <rdf:li>delete</rdf:li> - </rdf:Bag> - </dc:subject> - <cc:license rdf:resource="http://creativecommons.org/licenses/by-sa/2.0/"/> - </cc:Work> - <cc:License rdf:about="http://creativecommons.org/licenses/by-sa/2.0/"> - <cc:permits rdf:resource="http://web.resource.org/cc/Reproduction"/> - <cc:permits rdf:resource="http://web.resource.org/cc/Distribution"/> - <cc:requires rdf:resource="http://web.resource.org/cc/Notice"/> - <cc:requires rdf:resource="http://web.resource.org/cc/Attribution"/> - <cc:permits rdf:resource="http://web.resource.org/cc/DerivativeWorks"/> - <cc:requires rdf:resource="http://web.resource.org/cc/ShareAlike"/> - </cc:License> - </rdf:RDF> - </metadata> - <g id="layer1" inkscape:label="Layer 1" inkscape:groupmode="layer"> - <path style="font-size:59.901077px;font-style:normal;font-variant:normal;font-weight:bold;font-stretch:normal;text-align:start;line-height:125.00000%;writing-mode:lr-tb;text-anchor:start;fill:#75a1d0;fill-opacity:1.0000000;stroke:#3465a4;stroke-width:1.0000004px;stroke-linecap:butt;stroke-linejoin:miter;stroke-opacity:1.0000000;font-family:Bitstream Vera Sans" d="M 27.514356,28.359472 L 39.633445,28.475543 L 39.633445,21.480219 L 27.523285,21.480219 L 20.502546,21.462362 L 8.5441705,21.489147 L 8.5084565,28.457686 L 20.511475,28.475543 L 27.514356,28.359472 z " id="text1314" sodipodi:nodetypes="ccccccccc"/> - <path style="font-size:59.901077px;font-style:normal;font-variant:normal;font-weight:bold;font-stretch:normal;text-align:start;line-height:125.00000%;writing-mode:lr-tb;text-anchor:start;opacity:0.40860215;fill:url(#linearGradient4975);fill-opacity:1.0000000;stroke:url(#linearGradient7922);stroke-width:1.0000006px;stroke-linecap:butt;stroke-linejoin:miter;stroke-opacity:1.0000000;font-family:Bitstream Vera Sans" d="M 38.579429,27.484113 L 38.588357,22.475309 L 9.5267863,22.493166 L 9.5000003,27.466256 L 38.579429,27.484113 z " id="path7076" sodipodi:nodetypes="ccccc"/> - <path style="fill:#ffffff;fill-opacity:1.0000000;fill-rule:evenodd;stroke:none;stroke-width:1.0000000px;stroke-linecap:butt;stroke-linejoin:miter;stroke-opacity:1.0000000;opacity:0.31182796" d="M 9.0000000,25.000000 C 9.0000000,26.937500 39.125000,24.062500 39.125000,25.000000 L 39.125000,22.000000 L 9.0000000,22.000000 L 9.0000000,25.000000 z " id="path7914" sodipodi:nodetypes="ccccc"/> - <path sodipodi:type="arc" style="opacity:0.10439561;fill:url(#radialGradient2097);fill-opacity:1;stroke:none;stroke-width:3;stroke-linecap:round;stroke-linejoin:round;stroke-miterlimit:4;stroke-dasharray:none;stroke-dashoffset:0;stroke-opacity:1" id="path1361" sodipodi:cx="22.958872" sodipodi:cy="34.94062" sodipodi:rx="10.31934" sodipodi:ry="2.320194" d="M 33.278212 34.94062 A 10.31934 2.320194 0 1 1 12.639532,34.94062 A 10.31934 2.320194 0 1 1 33.278212 34.94062 z" transform="matrix(2.32573,0,0,1.293,-29.39613,-8.178198)" inkscape:r_cx="true" inkscape:r_cy="true"/> - </g> -</svg>
\ No newline at end of file diff --git a/src/trace_viewer/README b/src/trace_viewer/README deleted file mode 100644 index 547a1435..00000000 --- a/src/trace_viewer/README +++ /dev/null @@ -1,11 +0,0 @@ - -To use, first make sure node.js and npm are installed (e.g. via the -Ubuntu package manager), then run the following in this directory: - -> npm install - -> npm run tsc - -> ./node_modules/.bin/electron . - -and point the file selector at a trace produced by sail -ocaml_trace
\ No newline at end of file diff --git a/src/trace_viewer/index.css b/src/trace_viewer/index.css deleted file mode 100644 index 35ebcb23..00000000 --- a/src/trace_viewer/index.css +++ /dev/null @@ -1,86 +0,0 @@ - -body { - background-color: #202020; - color: #DCDCCC; - font-family: monospace; - font-size: 14pt; - font-weight: bold; -} - -img { - height: 30px; -} - -#control { - position: fixed; - bottom: 0px; - left:10%; - right:10%; - width:80%; -} - -#command { - font-size: 16pt; - width: 100%; -} - -.call { - background-color: #313131; - border: 1px; - border-left: 5px; - border-color: rgb(118, 173, 160); - border-style: solid; - padding-top: 2px; - padding-bottom: 2px; - margin: 0px; - min-height: 32px; - display: flex; - align-items: center; -} - -.write { - background-color: #313131; - border: 1px; - border-left: 5px; - border-color: rgb(255, 40, 40); - border-style: solid; - padding-top: 2px; - padding-bottom: 2px; - margin: 0px; - min-height: 32px; - display: flex; - align-items: center; -} - -.load { - background-color: #313131; - color: white; - border: 1px; - border-left: 5px; - border-color: #ff9100; - border-style: solid; - padding-top: 2px; - padding-bottom: 2px; - margin: 0px; - min-height: 32px; - display: flex; - align-items: center; -} - -.read { - background-color: #313131; - border: 1px; - border-left: 5px; - border-color: rgb(107, 199, 47); - border-style: solid; - padding-top: 2px; - padding-bottom: 2px; - margin: 0px; - min-height: 32px; - display: flex; - align-items: center; -} - -.tree { - padding-left: 20px; -}
\ No newline at end of file diff --git a/src/trace_viewer/index.html b/src/trace_viewer/index.html deleted file mode 100644 index 9efcca56..00000000 --- a/src/trace_viewer/index.html +++ /dev/null @@ -1,19 +0,0 @@ -<!DOCTYPE html> -<html> - <head> - <meta charset="UTF-8"> - <title>Sail Trace Viewer</title> - <script type="text/javascript"> - var exports = {} - </script> - <script type="text/javascript" src="index.js"></script> - <link rel="stylesheet" type="text/css" href="index.css"> - </head> - <body> - <div id="container"> - </div> - <div id="control"> - <input type="text" id="command"> - </div> - </body> -</html>
\ No newline at end of file diff --git a/src/trace_viewer/index.ts b/src/trace_viewer/index.ts deleted file mode 100644 index f9b5041b..00000000 --- a/src/trace_viewer/index.ts +++ /dev/null @@ -1,287 +0,0 @@ -import {remote} from "electron" -import fs = require("fs") -const dialog = remote.dialog -const app = remote.app - -let topCallDiv = document.createElement("div") - -const max_arg_length = 5000 - -abstract class Event { - caller: Call - - protected div: HTMLDivElement | null = null - - public hide(): void { - if (this.div != null) { - this.div.remove() - this.div = null - } - } - - protected abstract showText(text: HTMLParagraphElement): void - - public show(): HTMLDivElement { - let callerDiv: HTMLDivElement = (this.caller != null) ? this.caller.show() : topCallDiv - - if (this.div != null) { - return this.div - } else { - this.div = document.createElement("div") - this.div.className = "tree" - callerDiv.appendChild(this.div) - let text = document.createElement("p") - this.showText(text) - this.div.appendChild(text) - return this.div - } - } -} - -class Load extends Event { - loc: string - val: string - - constructor(loc: string, val: string) { - super() - this.loc = loc - this.val = val - } - - protected showText(text: HTMLParagraphElement): void { - text.className = "load" - text.insertAdjacentText('beforeend', this.loc + " " + this.val) - } -} - -class Read extends Event { - reg: string - value: string - - constructor(reg: string, value: string) { - super() - this.reg = reg - this.value = value - } - - public showText(text: HTMLParagraphElement): void { - text.className = "read" - text.insertAdjacentText('beforeend', this.reg + " " + this.value) - } -} - -class Write extends Event { - reg: string - value: string - - constructor(reg: string, value: string) { - super() - this.reg = reg - this.value = value - } - - public showText(text: HTMLParagraphElement): void { - text.className = "write" - text.insertAdjacentText('beforeend', this.reg + " " + this.value) - } -} - -class Call { - fn: string - arg: string - ret: string - callees: (Call | Event)[] = [] - caller: Call - - private div: HTMLDivElement | null = null - - private toggle: boolean = false - private toggleImg: HTMLImageElement | null = null - - constructor(fn: string, arg: string, ret: string) { - this.fn = fn - this.arg = arg - this.ret = ret - } - - public expand() { - if (this.caller != undefined) { - this.caller.expand() - } - this.showChildren() - } - - public iter(f: (call: Call) => void): void { - f(this) - this.callees.forEach((callee) => { - if (callee instanceof Call) { callee.iter(f) } - }) - - } - - public show(): HTMLDivElement { - let callerDiv: HTMLDivElement = (this.caller != null) ? this.caller.show() : topCallDiv - - if (this.div != null) { - return this.div - } else { - this.div = document.createElement("div") - this.div.className = "tree" - callerDiv.appendChild(this.div) - let text = document.createElement("p") - text.className = "call" - if (this.callees.length > 0) { - this.toggleImg = document.createElement("img") - this.toggleImg.src = "List-add.svg" - this.toggleImg.addEventListener('click', () => { - if (this.toggle) { - this.hideChildren() - } else { - this.showChildren() - } - }) - text.appendChild(this.toggleImg) - } - this.toggle = false - let display_arg = this.arg - if (this.arg.length > max_arg_length) { - display_arg = this.arg.slice(0, max_arg_length) - } - let display_ret = this.ret - if (this.ret.length > max_arg_length) { - display_ret = this.ret.slice(0, max_arg_length) - } - - text.insertAdjacentText('beforeend', this.fn + " " + display_arg + " -> " + display_ret) - this.div.appendChild(text) - return this.div - } - } - - public hide(): void { - if (this.toggle == true) { - this.hideChildren() - } - - if (this.div != null) { - this.div.remove() - this.div = null - } - if (this.toggleImg != null) { - this.toggleImg.remove() - this.toggleImg = null - } - } - - public hideChildren(): void { - this.callees.forEach(call => { - call.hide() - }) - - if (this.toggleImg != null) { - this.toggleImg.src = "List-add.svg" - this.toggle = false - } else { - alert("this.toggleImg was null!") - } - } - - public showChildren(): void { - this.callees.forEach(call => { - call.show() - }); - - if (this.toggleImg != null) { - this.toggleImg.src = "List-remove.svg" - this.toggle = true - } else { - alert("this.toggleImg was null!") - } - } - - public appendChild(child: Call | Write | Read | Load): void { - child.caller = this - - this.callees.push(child) - } -} - -document.addEventListener('DOMContentLoaded', () => { - let rootCall = new Call("ROOT", "", "") - topCallDiv.id = "root" - document.getElementById("container")!.appendChild(topCallDiv) - - let commandInput = document.getElementById("command") as HTMLInputElement - - commandInput.addEventListener("keydown", (event) => { - if(event.keyCode == 13) { - let cmd = commandInput.value.split(" ") - commandInput.value = "" - - if (cmd[0] == "function") { - rootCall.iter((call) => { - if (call.fn == cmd[1]) { call.caller.expand() } - }) - } - } - }) - - let files = dialog.showOpenDialog(remote.getCurrentWindow(), {title: "Select log file", defaultPath: app.getAppPath()}) - - if (files == [] || files == undefined) { - dialog.showErrorBox("Error", "No file selected") - app.exit(1) - } - - fs.readFile(files[0], 'utf-8', (err, data) => { - if (err) { - dialog.showErrorBox("Error", "An error occurred when reading the log: " + err.message) - app.exit(1) - } - - let lines = data.split("\n") - // let indents = lines.map(line => line.search(/[^\s]/) / 2) - lines = lines.map(line => line.trim()) - - let stack : Call[] = [rootCall] - - lines.forEach(line => { - if (line.match(/^Call:/)) { - let words = line.slice(6).split(" ") - let call = new Call(words[0], words.slice(1).join(" "), "") - if (stack.length > 0) { - stack[stack.length - 1].appendChild(call) - } - stack.push(call) - } else if (line.match(/^Return:/)) { - let call = stack.pop() - if (call == undefined) { - alert("Unbalanced return") - app.exit(1) - } else { - call.ret = line.slice(8) - } - } else if (line.match(/^Write:/)) { - let words = line.slice(7).split(" ") - let write = new Write(words[0], words.slice(1).join(" ")) - if (stack.length > 0) { - stack[stack.length - 1].appendChild(write) - } - } else if (line.match(/^Read:/)) { - let words = line.slice(6).split(" ") - let read = new Read(words[0], words.slice(1).join(" ")) - if (stack.length > 0) { - stack[stack.length - 1].appendChild(read) - } - } else if (line.match(/^Load:/)) { - let words = line.slice(6).split(" ") - let load = new Load(words[0], words[1]) - if (stack.length > 0) { - stack[stack.length - 1].appendChild(load) - } - } - }) - - rootCall.show() - }) -})
\ No newline at end of file diff --git a/src/trace_viewer/main.ts b/src/trace_viewer/main.ts deleted file mode 100644 index 5cc33452..00000000 --- a/src/trace_viewer/main.ts +++ /dev/null @@ -1,12 +0,0 @@ -import {app, BrowserWindow} from 'electron' - -let win : BrowserWindow | null = null - -app.on('ready', () => { - win = new BrowserWindow({width: 1920, height: 1200}) - win.loadURL('file://' + __dirname + '/index.html') - //win.webContents.openDevTools() - win.on('close', () => { - win = null - }) -})
\ No newline at end of file diff --git a/src/trace_viewer/package.json b/src/trace_viewer/package.json deleted file mode 100644 index e3a88d30..00000000 --- a/src/trace_viewer/package.json +++ /dev/null @@ -1,15 +0,0 @@ -{ - "name": "trace_viewer", - "version": "1.0.0", - "description": "", - "main": "main.js", - "scripts": { - "test": "echo \"Error: no test specified\" && exit 1", - "tsc": "./node_modules/typescript/bin/tsc" - }, - "devDependencies": { - "@types/node": "^8.0.46", - "electron": "1.7.9", - "typescript": "^2.5.3" - } -} diff --git a/src/trace_viewer/tsconfig.json b/src/trace_viewer/tsconfig.json deleted file mode 100644 index e66156b3..00000000 --- a/src/trace_viewer/tsconfig.json +++ /dev/null @@ -1,18 +0,0 @@ -{ - "compileOnSave": true, - "compilerOptions": { - "target": "es5", - "moduleResolution": "node", - "pretty": true, - "newLine": "LF", - "allowSyntheticDefaultImports": true, - "strict": true, - "noUnusedLocals": true, - "noUnusedParameters": true, - "sourceMap": true, - "strictNullChecks": true, - "skipLibCheck": true, - "allowJs": true, - "jsx": "preserve" - } -}
\ No newline at end of file diff --git a/src/type_check.ml b/src/type_check.ml index 9b704a90..2fcfb309 100644 --- a/src/type_check.ml +++ b/src/type_check.ml @@ -746,6 +746,9 @@ end = struct let add_local id mtyp env = begin wf_typ env (snd mtyp); + if Bindings.mem id env.top_val_specs then + typ_error (id_loc id) ("Local variable " ^ string_of_id id ^ " is already bound as a function name") + else (); typ_print ("Adding local binding " ^ string_of_id id ^ " :: " ^ string_of_mtyp mtyp); { env with locals = Bindings.add id mtyp env.locals } end @@ -1283,6 +1286,8 @@ let rec nexp_identical (Nexp_aux (nexp1, _)) (Nexp_aux (nexp2, _)) = | Nexp_minus (n1a, n1b), Nexp_minus (n2a, n2b) -> nexp_identical n1a n2a && nexp_identical n1b n2b | Nexp_exp n1, Nexp_exp n2 -> nexp_identical n1 n2 | Nexp_neg n1, Nexp_neg n2 -> nexp_identical n1 n2 + | Nexp_app (f1, args1), Nexp_app (f2, args2) when List.length args1 = List.length args2 -> + Id.compare f1 f2 = 0 && List.for_all2 nexp_identical args1 args2 | _, _ -> false let ord_identical (Ord_aux (ord1, _)) (Ord_aux (ord2, _)) = @@ -3401,16 +3406,16 @@ and propagate_lexp_effect_aux = function (* 6. Checking toplevel definitions *) (**************************************************************************) -let check_letdef env (LB_aux (letbind, (l, _))) = +let check_letdef orig_env (LB_aux (letbind, (l, _))) = begin match letbind with | LB_val (P_aux (P_typ (typ_annot, pat), _), bind) -> - let checked_bind = crule check_exp env (strip_exp bind) typ_annot in - let tpat, env = bind_pat_no_guard env (strip_pat pat) typ_annot in - [DEF_val (LB_aux (LB_val (P_aux (P_typ (typ_annot, tpat), (l, Some (env, typ_annot, no_effect))), checked_bind), (l, None)))], env + let checked_bind = crule check_exp orig_env (strip_exp bind) typ_annot in + let tpat, env = bind_pat_no_guard orig_env (strip_pat pat) typ_annot in + [DEF_val (LB_aux (LB_val (P_aux (P_typ (typ_annot, tpat), (l, Some (orig_env, typ_annot, no_effect))), checked_bind), (l, None)))], env | LB_val (pat, bind) -> - let inferred_bind = irule infer_exp env (strip_exp bind) in - let tpat, env = bind_pat_no_guard env (strip_pat pat) (typ_of inferred_bind) in + let inferred_bind = irule infer_exp orig_env (strip_exp bind) in + let tpat, env = bind_pat_no_guard orig_env (strip_pat pat) (typ_of inferred_bind) in [DEF_val (LB_aux (LB_val (tpat, inferred_bind), (l, None)))], env end diff --git a/src/util.ml b/src/util.ml index e2dc9b9f..b8670b84 100644 --- a/src/util.ml +++ b/src/util.ml @@ -389,12 +389,16 @@ let rec take n xs = match n, xs with | n, (x :: xs) -> x :: take (n - 1) xs let termcode n = "\x1B[" ^ string_of_int n ^ "m" + let bold str = termcode 1 ^ str + +let red str = termcode 91 ^ str let green str = termcode 92 ^ str let yellow str = termcode 93 ^ str -let red str = termcode 91 ^ str -let cyan str = termcode 96 ^ str let blue str = termcode 94 ^ str +let magenta str = termcode 95 ^ str +let cyan str = termcode 96 ^ str + let clear str = str ^ termcode 0 let zchar c = diff --git a/src/util.mli b/src/util.mli index 2b4d2e93..46d99002 100644 --- a/src/util.mli +++ b/src/util.mli @@ -240,6 +240,7 @@ val red : string -> string val yellow : string -> string val cyan : string -> string val blue : string -> string +val magenta : string -> string val clear : string -> string val warn : string -> unit |
