diff options
52 files changed, 2252 insertions, 482 deletions
diff --git a/lib/prelude.sail b/lib/prelude.sail index bac9532c..795fe8fc 100644 --- a/lib/prelude.sail +++ b/lib/prelude.sail @@ -3,37 +3,50 @@ val cast forall Num 'n, Num 'm, Order 'ord. vector<'n,'m,'ord,bit> -> [|0:2**'m val forall Num 'n, Num 'm, Order 'ord. vector<'n,'m,'ord,bit> -> [|0 - (2**('m - 1)):2**('m - 1) - 1|] effect pure signed -val forall Num 'n, Num 'm. [|0:'n|] -> vector<'m - 1,'m,dec,bit> effect pure to_vec +val extern forall Num 'n, Num 'm. [|0:'n|] -> vector<'m - 1,'m,dec,bit> effect pure to_vec = "to_vec_dec" -val forall Num 'm. int -> vector<'m - 1,'m,dec,bit> effect pure to_svec +val extern forall Num 'm. int -> vector<'m - 1,'m,dec,bit> effect pure to_svec = "to_vec_dec" (* Vector access can't actually be properly polymorphic on vector direction because of the ranges being different for each type, so we overload it instead *) val forall Num 'n, Num 'l, Type 'a, 'l >= 0. (vector<'n,'l,dec,'a>, [|'n - 'l + 1:'n|]) -> 'a effect pure vector_access_dec val forall Num 'n, Num 'l, Type 'a, 'l >= 0. (vector<'n,'l,inc,'a>, [|'n:'n + 'l - 1|]) -> 'a effect pure vector_access_inc +val forall Num 'n, Num 'l, 'l >= 0. (vector<'n,'l,dec,bit>, [|'n - 'l + 1:'n|]) -> bit effect pure bitvector_access_dec +val forall Num 'n, Num 'l, 'l >= 0. (vector<'n,'l,inc,bit>, [|'n:'n + 'l - 1|]) -> bit effect pure bitvector_access_inc -overload vector_access [vector_access_inc; vector_access_dec] +overload vector_access [bitvector_access_inc; bitvector_access_dec; vector_access_inc; vector_access_dec] (* Type safe vector subrange *) +(* vector_subrange(v, m, o) returns the subvector of v with elements with + indices from m up to and *including* o. *) val forall Num 'n, Num 'l, Num 'm, Num 'o, Type 'a, 'l >= 0, 'm <= 'o, 'o <= 'l. - (vector<'n,'l,inc,'a>, [:'m:], [:'o:]) -> vector<'m,'o - 'm,inc,'a> effect pure vector_subrange_inc + (vector<'n,'l,inc,'a>, [:'m:], [:'o:]) -> vector<'m,('o - 'm) + 1,inc,'a> effect pure vector_subrange_inc val forall Num 'n, Num 'l, Num 'm, Num 'o, Type 'a, 'n >= 'm, 'm >= 'o, 'o >= 'n - 'l + 1. - (vector<'n,'l,dec,'a>, [:'m:], [:'o:]) -> vector<'m,'m - ('o - 1),dec,'a> effect pure vector_subrange_dec + (vector<'n,'l,dec,'a>, [:'m:], [:'o:]) -> vector<'m,('m - 'o) + 1,dec,'a> effect pure vector_subrange_dec val forall Num 'n, Num 'l, Order 'ord. (vector<'n,'l,'ord,bit>, int, int) -> list<bit> effect pure vector_subrange_bl -overload vector_subrange [vector_subrange_inc; vector_subrange_dec; vector_subrange_bl] +val forall Num 'n, Num 'l, Num 'm, Num 'o, 'l >= 0, 'm <= 'o, 'o <= 'l. + (vector<'n,'l,inc,bit>, [:'m:], [:'o:]) -> vector<'m,('o - 'm) + 1,inc,bit> effect pure bitvector_subrange_inc + +val forall Num 'n, Num 'l, Num 'm, Num 'o, 'n >= 'm, 'm >= 'o, 'o >= 'n - 'l + 1. + (vector<'n,'l,dec,bit>, [:'m:], [:'o:]) -> vector<'m,('m - 'o) + 1,dec,bit> effect pure bitvector_subrange_dec + +overload vector_subrange [bitvector_subrange_inc; bitvector_subrange_dec; vector_subrange_inc; vector_subrange_dec; vector_subrange_bl] (* Type safe vector append *) -val forall Num 'n1, Num 'l1, Num 'n2, Num 'l2, Order 'o, Type 'a, 'l1 >= 0, 'l2 >= 0. - (vector<'n1,'l1,'o,'a>, vector<'n2,'l2,'o,'a>) -> vector<'l1 + 'l2 - 1,'l1 + 'l2,'o,'a> effect pure vec_append +val extern forall Num 'n1, Num 'l1, Num 'n2, Num 'l2, Order 'o, Type 'a, 'l1 >= 0, 'l2 >= 0. + (vector<'n1,'l1,'o,'a>, vector<'n2,'l2,'o,'a>) -> vector<'l1 + 'l2 - 1,'l1 + 'l2,'o,'a> effect pure vec_append = "vector_concat" val (list<bit>, list<bit>) -> list<bit> effect pure list_append -overload vector_append [vec_append; list_append] +val extern forall Num 'n1, Num 'l1, Num 'n2, Num 'l2, Order 'o, 'l1 >= 0, 'l2 >= 0. + (vector<'n1,'l1,'o,bit>, vector<'n2,'l2,'o,bit>) -> vector<'l1 + 'l2 - 1,'l1 + 'l2,'o,bit> effect pure bitvec_append = "bitvector_concat" + +overload vector_append [bitvec_append; vec_append; list_append] (* Implicit register dereferencing *) val cast forall Type 'a. register<'a> -> 'a effect {rreg} reg_deref @@ -99,12 +112,12 @@ val forall Num 'n, Num 'm, Order 'ord. vector<'n, 'm, 'ord, bit> -> bit effect p (* Arithmetic *) -val forall Num 'n, Num 'm, Num 'o, Num 'p. +val extern forall Num 'n, Num 'm, Num 'o, Num 'p. ([|'n:'m|], [|'o:'p|]) -> [|'n + 'o:'m + 'p|] effect pure add -val (nat, nat) -> nat effect pure add_nat +val extern (nat, nat) -> nat effect pure add_nat = "add" -val (int, int) -> int effect pure add_int +val extern (int, int) -> int effect pure add_int = "add" val forall Num 'n, Num 'o, Order 'ord. (vector<'o, 'n, 'ord, bit>, vector<'o, 'n, 'ord, bit>) -> vector<'o, 'n, 'ord, bit> effect pure add_vec @@ -115,10 +128,10 @@ val forall Num 'n, Num 'o, Order 'ord. val forall Num 'n, Num 'o, Order 'ord. (vector<'o, 'n, 'ord, bit>, vector<'o, 'n, 'ord, bit>) -> (vector<'o, 'n, 'ord, bit>, bit, bit) effect pure add_overflow_vec -val forall Num 'n, Num 'm, Num 'o, Num 'p. +val extern forall Num 'n, Num 'm, Num 'o, Num 'p. ([|'n:'m|], [|'o:'p|]) -> [|'n - 'p:'m - 'o|] effect pure sub -val (int, int) -> int effect pure sub_int +val extern (int, int) -> int effect pure sub_int = "sub" val forall Num 'n, Num 'm, Order 'ord. (vector<'n,'m,'ord,bit>, int) -> vector<'n,'m,'ord,bit> effect pure sub_vec_int @@ -146,7 +159,7 @@ overload (deinfix -) [ sub_int ] -val bool -> bit effect pure bool_to_bit +val extern bool -> bit effect pure bool_to_bit = "bool_to_bitU" val (int, int) -> int effect pure mul_int val forall Num 'n, Num 'o, Order 'ord. @@ -164,10 +177,10 @@ overload (deinfix *_s) [ mul_svec ] -val (bool, bool) -> bool effect pure bool_xor +val extern (bool, bool) -> bool effect pure bool_xor -val forall Num 'n, Num 'o, Order 'ord. - (vector<'o, 'n, 'ord, bit>, vector<'o, 'n, 'ord, bit>) -> vector<'o, 'n, 'ord, bit> effect pure xor_vec +val extern forall Num 'n, Num 'o, Order 'ord. + (vector<'o, 'n, 'ord, bit>, vector<'o, 'n, 'ord, bit>) -> vector<'o, 'n, 'ord, bit> effect pure xor_vec = "bitwise_xor" overload (deinfix ^) [ bool_xor; @@ -189,7 +202,7 @@ overload (deinfix >>) [ ] (* Boolean operators *) -val bool -> bool effect pure bool_not +val extern bool -> bool effect pure bool_not = "not" val (bool, bool) -> bool effect pure bool_or val (bool, bool) -> bool effect pure bool_and @@ -226,24 +239,24 @@ val forall Num 'n, Num 'm, Order 'ord. (vector<'n,'m,'ord,bit>, vector<'n,'m,'or val forall Num 'n, Num 'm, Order 'ord. (vector<'n,'m,'ord,bit>, vector<'n,'m,'ord,bit>) -> bool effect pure lteq_vec val forall Num 'n, Num 'm, Order 'ord. (vector<'n,'m,'ord,bit>, vector<'n,'m,'ord,bit>) -> bool effect pure lt_vec -val (int, int) -> bool effect pure gteq_int -val (int, int) -> bool effect pure gt_int -val (int, int) -> bool effect pure lteq_int -val (int, int) -> bool effect pure lt_int - -val forall Num 'n, Num 'm, Num 'o. ([|'n:'m|], [:'o:]) -> bool effect pure lt_range_atom -val forall Num 'n, Num 'm, Num 'o. ([|'n:'m|], [:'o:]) -> bool effect pure lteq_range_atom -val forall Num 'n, Num 'm, Num 'o. ([|'n:'m|], [:'o:]) -> bool effect pure gt_range_atom -val forall Num 'n, Num 'm, Num 'o. ([|'n:'m|], [:'o:]) -> bool effect pure gteq_range_atom -val forall Num 'n, Num 'm, Num 'o. ([:'n:], [|'m:'o|]) -> bool effect pure lt_atom_range -val forall Num 'n, Num 'm, Num 'o. ([:'n:], [|'m:'o|]) -> bool effect pure lteq_atom_range -val forall Num 'n, Num 'm, Num 'o. ([:'n:], [|'m:'o|]) -> bool effect pure gt_atom_range -val forall Num 'n, Num 'm, Num 'o. ([:'n:], [|'m:'o|]) -> bool effect pure gteq_atom_range - -val forall Num 'n, Num 'm. ([:'n:], [:'m:]) -> bool effect pure lteq_atom_atom -val forall Num 'n, Num 'm. ([:'n:], [:'m:]) -> bool effect pure gteq_atom_atom -val forall Num 'n, Num 'm. ([:'n:], [:'m:]) -> bool effect pure lt_atom_atom -val forall Num 'n, Num 'm. ([:'n:], [:'m:]) -> bool effect pure gt_atom_atom +val extern (int, int) -> bool effect pure gteq_int = "gteq" +val extern (int, int) -> bool effect pure gt_int = "gt" +val extern (int, int) -> bool effect pure lteq_int = "lteq" +val extern (int, int) -> bool effect pure lt_int = "lt" + +val extern forall Num 'n, Num 'm, Num 'o. ([|'n:'m|], [:'o:]) -> bool effect pure lt_range_atom = "lt" +val extern forall Num 'n, Num 'm, Num 'o. ([|'n:'m|], [:'o:]) -> bool effect pure lteq_range_atom = "lteq" +val extern forall Num 'n, Num 'm, Num 'o. ([|'n:'m|], [:'o:]) -> bool effect pure gt_range_atom = "gt" +val extern forall Num 'n, Num 'm, Num 'o. ([|'n:'m|], [:'o:]) -> bool effect pure gteq_range_atom = "gteq" +val extern forall Num 'n, Num 'm, Num 'o. ([:'n:], [|'m:'o|]) -> bool effect pure lt_atom_range = "lt" +val extern forall Num 'n, Num 'm, Num 'o. ([:'n:], [|'m:'o|]) -> bool effect pure lteq_atom_range = "lteq" +val extern forall Num 'n, Num 'm, Num 'o. ([:'n:], [|'m:'o|]) -> bool effect pure gt_atom_range = "gt" +val extern forall Num 'n, Num 'm, Num 'o. ([:'n:], [|'m:'o|]) -> bool effect pure gteq_atom_range = "gteq" + +val extern forall Num 'n, Num 'm. ([:'n:], [:'m:]) -> bool effect pure lteq_atom_atom = "lteq" +val extern forall Num 'n, Num 'm. ([:'n:], [:'m:]) -> bool effect pure gteq_atom_atom = "gteq" +val extern forall Num 'n, Num 'm. ([:'n:], [:'m:]) -> bool effect pure lt_atom_atom = "lt" +val extern forall Num 'n, Num 'm. ([:'n:], [:'m:]) -> bool effect pure gt_atom_atom = "gt" overload (deinfix >=) [gteq_atom_atom; gteq_range_atom; gteq_atom_range; gteq_vec; gteq_int] overload (deinfix >) [gt_atom_atom; gt_vec; gt_int] @@ -264,11 +277,16 @@ val (int, int) -> int effect pure quotient overload (deinfix quot) [quotient] -val (int, int) -> int effect pure modulus +val (int, int) -> int effect pure modulo -overload (deinfix mod) [modulus] +overload (deinfix mod) [modulo] -val forall Num 'n, Num 'm, Order 'ord, Type 'a. vector<'n,'m,'ord,'a> -> [:'m:] effect pure length +val extern forall Num 'n, Num 'm, Order 'ord, Type 'a. vector<'n,'m,'ord,'a> -> [:'m:] effect pure vec_length = "length" +val forall Type 'a. list<'a> -> nat effect pure list_length + +val extern forall Num 'n, Num 'm, Order 'ord. vector<'n,'m,'ord,bit> -> [:'m:] effect pure bitvector_length = "bvlength" + +overload length [bitvector_length; vector_length; list_length] val cast forall Num 'n. [:'n:] -> [|'n|] effect pure upper @@ -276,4 +294,3 @@ typedef option = const union forall Type 'a. { None; 'a Some } - diff --git a/mips_new_tc/mips_extras_embed_sequential.lem b/mips_new_tc/mips_extras_embed_sequential.lem new file mode 100644 index 00000000..ad567598 --- /dev/null +++ b/mips_new_tc/mips_extras_embed_sequential.lem @@ -0,0 +1,51 @@ +open import Pervasives +open import Pervasives_extra +open import Sail_impl_base +open import Sail_values +open import State + +val MEMr : forall 'a 'b. Size 'b => (bitvector 'a * integer) -> M (bitvector 'b) +val MEMr_reserve : forall 'a 'b. Size 'b => (bitvector 'a * integer) -> M (bitvector 'b) +val MEMr_tag : forall 'a 'b. Size 'b => (bitvector 'a * integer) -> M (bitU * bitvector 'b) +val MEMr_tag_reserve : forall 'a 'b. Size 'b => (bitvector 'a * integer) -> M (bitU * bitvector 'b) + +let MEMr (addr,size) = read_mem false Read_plain addr size +let MEMr_reserve (addr,size) = read_mem false Read_reserve addr size + +let MEMr_tag (addr,size) = + read_mem false Read_plain addr size >>= fun v -> + read_tag false Read_plain addr >>= fun t -> + return (t, v) + +let MEMr_tag_reserve (addr,size) = + read_mem false Read_plain addr size >>= fun v -> + read_tag false Read_plain addr >>= fun t -> + return (t, v) + + +val MEMea : forall 'a. (bitvector 'a * integer) -> M unit +val MEMea_conditional : forall 'a. (bitvector 'a * integer) -> M unit +val MEMea_tag : forall 'a. (bitvector 'a * integer) -> M unit +val MEMea_tag_conditional : forall 'a. (bitvector 'a * integer) -> M unit + +let MEMea (addr,size) = write_mem_ea Write_plain addr size +let MEMea_conditional (addr,size) = write_mem_ea Write_conditional addr size + +let MEMea_tag (addr,size) = write_mem_ea Write_plain addr size +let MEMea_tag_conditional (addr,size) = write_mem_ea Write_conditional addr size + + +val MEMval : forall 'a 'b. (bitvector 'a * integer * bitvector 'b) -> M unit +val MEMval_conditional : forall 'a 'b. (bitvector 'a * integer * bitvector 'b) -> M bool +val MEMval_tag : forall 'a 'b. (bitvector 'a * integer * bitU * bitvector 'b) -> M unit +val MEMval_tag_conditional : forall 'a 'b. (bitvector 'a * integer * bitU * bitvector 'b) -> M bool + +let MEMval (_,_,v) = write_mem_val v >>= fun _ -> return () +let MEMval_conditional (_,_,v) = write_mem_val v >>= fun b -> return (if b then true else false) +let MEMval_tag (_,_,t,v) = write_mem_val v >>= fun _ -> write_tag t >>= fun _ -> return () +let MEMval_tag_conditional (_,_,t,v) = write_mem_val v >>= fun b -> write_tag t >>= fun _ -> return (if b then true else false) + +val MEM_sync : unit -> M unit + +let MEM_sync () = barrier Barrier_MIPS_SYNC + diff --git a/mips_new_tc/mips_insts.sail b/mips_new_tc/mips_insts.sail index 1d3c5f4a..96826dae 100644 --- a/mips_new_tc/mips_insts.sail +++ b/mips_new_tc/mips_insts.sail @@ -1136,24 +1136,24 @@ function clause execute (Load(width, signed, linked, base, rt, offset)) = else let pAddr = (TLBTranslate(vAddr, LoadData)) in { - (bit[64]) memResult := if (linked) then + (bit[64]) memResult := if (linked) then { CP0LLBit := 0b1; - CP0LLAddr := pAddr; - switch wordWidthBytes(width) { - case ([:1:]) w -> extendLoad(MEMr_reserve_wrapper(pAddr, w), signed) - case ([:2:]) w -> extendLoad(MEMr_reserve_wrapper(pAddr, w), signed) - case ([:4:]) w -> extendLoad(MEMr_reserve_wrapper(pAddr, w), signed) - case ([:8:]) w -> extendLoad(MEMr_reserve_wrapper(pAddr, w), signed) - } + CP0LLAddr := pAddr; + w := wordWidthBytes(width); + if w == 1 then extendLoad(MEMr_reserve_wrapper(pAddr, 1), signed) + else if w == 2 then extendLoad(MEMr_reserve_wrapper(pAddr, 2), signed) + else if w == 4 then extendLoad(MEMr_reserve_wrapper(pAddr, 4), signed) + else extendLoad(MEMr_reserve_wrapper(pAddr, 8), signed) } else - switch wordWidthBytes(width) { - case ([:1:]) w -> extendLoad(MEMr_wrapper(pAddr, w), signed) - case ([:2:]) w -> extendLoad(MEMr_wrapper(pAddr, w), signed) - case ([:4:]) w -> extendLoad(MEMr_wrapper(pAddr, w), signed) - case ([:8:]) w -> extendLoad(MEMr_wrapper(pAddr, w), signed) - }; + { + w := wordWidthBytes(width); + if w == 1 then extendLoad(MEMr_reserve_wrapper(pAddr, 1), signed) + else if w == 2 then extendLoad(MEMr_reserve_wrapper(pAddr, 2), signed) + else if w == 4 then extendLoad(MEMr_reserve_wrapper(pAddr, 4), signed) + else extendLoad(MEMr_reserve_wrapper(pAddr, 8), signed); + }; wGPR(rt) := memResult } } diff --git a/mips_new_tc/mips_prelude.sail b/mips_new_tc/mips_prelude.sail index dca78b12..6792f546 100644 --- a/mips_new_tc/mips_prelude.sail +++ b/mips_new_tc/mips_prelude.sail @@ -190,7 +190,7 @@ register (TLBEntry) TLBEntry61 register (TLBEntry) TLBEntry62 register (TLBEntry) TLBEntry63 -let (vector <0, 64, inc, (TLBEntry)>) TLBEntries = [ +let (vector <0, 64, inc, (register<TLBEntry>)>) TLBEntries = [ TLBEntry00, TLBEntry01, TLBEntry02, diff --git a/src/ast_util.ml b/src/ast_util.ml index 955164a3..67381c52 100644 --- a/src/ast_util.ml +++ b/src/ast_util.ml @@ -297,6 +297,8 @@ let rec string_of_exp (E_aux (exp, _)) = ^ string_of_exp body | E_assert (test, msg) -> "assert(" ^ string_of_exp test ^ ", " ^ string_of_exp msg ^ ")" | E_exit exp -> "exit " ^ string_of_exp exp + | E_cons (x, xs) -> string_of_exp x ^ " :: " ^ string_of_exp xs + | E_list xs -> "[||" ^ string_of_list ", " string_of_exp xs ^ "||]" | _ -> "INTERNAL" and string_of_pexp (Pat_aux (pexp, _)) = match pexp with @@ -400,6 +402,39 @@ let rec nexp_identical (Nexp_aux (nexp1, _)) (Nexp_aux (nexp2, _)) = | Nexp_neg n1, Nexp_neg n2 -> nexp_identical n1 n2 | _, _ -> false +let rec is_nexp_constant (Nexp_aux (nexp, _)) = match nexp with + | Nexp_id _ | Nexp_var _ -> false + | Nexp_constant _ -> true + | Nexp_times (n1, n2) | Nexp_sum (n1, n2) | Nexp_minus (n1, n2) -> + is_nexp_constant n1 && is_nexp_constant n2 + | Nexp_exp n | Nexp_neg n -> is_nexp_constant n + +let rec simplify_nexp (Nexp_aux (nexp, l)) = + let rewrap n = Nexp_aux (n, l) in + let try_binop op n1 n2 c = (match simplify_nexp n1, simplify_nexp n2 with + | Nexp_aux (Nexp_constant i1, _), Nexp_aux (Nexp_constant i2, _) -> + rewrap (Nexp_constant (op i1 i2)) + | n1, n2 -> rewrap (c n1 n2)) in + match nexp with + | Nexp_times (n1, n2) -> try_binop ( * ) n1 n2 (fun n1 n2 -> Nexp_times (n1, n2)) + | Nexp_sum (n1, n2) -> try_binop ( + ) n1 n2 (fun n1 n2 -> Nexp_sum (n1, n2)) + | Nexp_minus (n1, n2) -> try_binop ( - ) n1 n2 (fun n1 n2 -> Nexp_minus (n1, n2)) + | Nexp_exp n -> + (match simplify_nexp n with + | Nexp_aux (Nexp_constant i, _) -> + rewrap (Nexp_constant (power 2 i)) + | n -> rewrap (Nexp_exp n)) + | Nexp_neg n -> + (match simplify_nexp n with + | Nexp_aux (Nexp_constant i, _) -> + rewrap (Nexp_constant (-i)) + | n -> rewrap (Nexp_neg n)) + | _ -> rewrap nexp + (* | Nexp_sum of nexp * nexp (* sum *) + | Nexp_minus of nexp * nexp (* subtraction *) + | Nexp_exp of nexp (* exponential *) + | Nexp_neg of nexp (* For internal use *) *) + let rec is_number (Typ_aux (t,_)) = match t with | Typ_app (Id_aux (Id "range", _),_) diff --git a/src/ast_util.mli b/src/ast_util.mli index 6e22d173..ae340839 100644 --- a/src/ast_util.mli +++ b/src/ast_util.mli @@ -119,6 +119,8 @@ end val nexp_frees : nexp -> KidSet.t val nexp_identical : nexp -> nexp -> bool +val is_nexp_constant : nexp -> bool +val simplify_nexp : nexp -> nexp val is_number : typ -> bool val is_vector_typ : typ -> bool diff --git a/src/gen_lib/prompt.lem b/src/gen_lib/prompt.lem index 70850dc1..0944f42b 100644 --- a/src/gen_lib/prompt.lem +++ b/src/gen_lib/prompt.lem @@ -78,6 +78,8 @@ let read_reg_bitfield reg regfield = read_reg_aux (external_reg_field_whole reg regfield) >>= fun v -> return (extract_only_element v) +let reg_deref = read_reg + val write_reg_aux : reg_name -> vector bitU -> M unit let write_reg_aux reg_name v = let regval = external_reg_value reg_name v in diff --git a/src/gen_lib/sail_values.lem b/src/gen_lib/sail_values.lem index f148c1ff..b4a15432 100644 --- a/src/gen_lib/sail_values.lem +++ b/src/gen_lib/sail_values.lem @@ -11,6 +11,12 @@ type nn = natural val pow : integer -> integer -> integer let pow m n = m ** (natFromInteger n) +let bool_or (l, r) = (l || r) +let bool_and (l, r) = (l && r) +let bool_xor (l, r) = xor l r + +let list_append (l, r) = l ++ r + let rec replace bs ((n : integer),b') = match bs with | [] -> [] | b :: bs -> @@ -18,6 +24,7 @@ let rec replace bs ((n : integer),b') = match bs with else b :: replace bs (n - 1,b') end +let upper n = n (*** Bits *) type bitU = B0 | B1 | BU @@ -39,6 +46,8 @@ let bitU_to_bool = function | BU -> failwith "to_bool applied to BU" end +let cast_bit_bool = bitU_to_bool + let bit_lifted_of_bitU = function | B0 -> Bitl_zero | B1 -> Bitl_one @@ -196,6 +205,9 @@ let access (Vector bs start is_inc) n = if is_inc then List_extra.nth bs (natFromInteger (n - start)) else List_extra.nth bs (natFromInteger (start - n)) +let vector_access_inc (v, i) = access v i +let vector_access_dec (v, i) = access v i + val update_pos : forall 'a. vector 'a -> integer -> 'a -> vector 'a let update_pos v n b = update_aux v n n [b] @@ -238,9 +250,12 @@ let reset_bitvector_start v = let set_bitvector_start_to_length v = set_bitvector_start (bvlength v - 1) v -let bitvector_concat (Bitvector bs start is_inc) (Bitvector bs' _ _) = +let bitvector_concat (Bitvector bs start is_inc, Bitvector bs' _ _) = Bitvector (word_concat bs bs') start is_inc +let norm_dec = reset_bitvector_start +let adjust_dec = reset_bitvector_start + let inline (^^^) = bitvector_concat val bvslice : forall 'a 'b. bitvector 'a -> integer -> integer -> bitvector 'b @@ -252,6 +267,13 @@ let bvslice (Bitvector bs start is_inc) i j = let subvector_bits = word_extract lo hi bs in Bitvector subvector_bits i is_inc +let bitvector_subrange_inc (v, i, j) = bvslice v i j +let bitvector_subrange_dec (v, i, j) = bvslice v i j + +let vector_subrange_bl (v, i, j) = + let v' = slice (bvec_to_vec v) i j in + get_elems v' + (* this is for the vector slicing introduced in vector-concat patterns: i and j index into the "raw data", the list of bits. Therefore getting the bit list is easy, but the start index has to be transformed to match the old vector start @@ -277,19 +299,20 @@ val bvupdate : forall 'a 'b. bitvector 'a -> integer -> integer -> bitvector 'b let bvupdate v i j (Bitvector bs' _ _) = bvupdate_aux v i j bs' -(* TODO: decide between nat/natural, change either here or in machine_word *) -val getBit' : forall 'a. mword 'a -> nat -> bool -let getBit' w n = getBit w (naturalFromNat n) - val bvaccess : forall 'a. bitvector 'a -> integer -> bitU let bvaccess (Bitvector bs start is_inc) n = bool_to_bitU ( - if is_inc then getBit' bs (natFromInteger (n - start)) - else getBit' bs (natFromInteger (start - n))) + if is_inc then getBit bs (natFromInteger (n - start)) + else getBit bs (natFromInteger (start - n))) val bvupdate_pos : forall 'a. Size 'a => bitvector 'a -> integer -> bitU -> bitvector 'a let bvupdate_pos v n b = bvupdate_aux v n n ((wordFromNatural (if bitU_to_bool b then 1 else 0)) : mword ty1) +let bitvector_access_inc (v, i) = bvaccess v i +let bitvector_access_dec (v, i) = bvaccess v i +let bitvector_update_pos_dec (v, i, b) = bvupdate_pos v i b +let bitvector_update_dec (v, i, j, v') = bvupdate v i j v' + (*** Bit vector operations *) let extract_only_element (Vector elems _ _) = match elems with @@ -308,6 +331,9 @@ let extract_only_bit (Bitvector elems _ _) = else failwith "extract_single_bit called for vector with more bits"*) +let cast_vec_bool v = bitU_to_bool (extract_only_bit v) +let cast_bit_vec b = vec_to_bvec (Vector [b] 0 false) + let pp_bitu_vector (Vector elems start inc) = let elems_pp = List.foldl (fun acc elem -> acc ^ showBitU elem) "" elems in "Vector [" ^ elems_pp ^ "] " ^ show start ^ " " ^ show inc @@ -409,19 +435,25 @@ end let add_one_bit_ignore_overflow bits = List.reverse (add_one_bit_ignore_overflow_aux (List.reverse bits)) -let to_vec is_inc ((len : integer),(n : integer)) = - let start = if is_inc then 0 else len - 1 in +let to_vec is_inc ((n : integer)) = + (* Bitvector length is determined by return type *) let bits = wordFromInteger n in - if integerFromNat (word_length bits) = len then + let len = integerFromNat (word_length bits) in + let start = if is_inc then 0 else len - 1 in + (*if integerFromNat (word_length bits) = len then*) Bitvector bits start is_inc - else - failwith "Vector length mismatch in to_vec" + (*else + failwith "Vector length mismatch in to_vec"*) let to_vec_big = to_vec let to_vec_inc = to_vec true let to_vec_dec = to_vec false +let cast_0_vec = to_vec_dec +let cast_1_vec = to_vec_dec +let cast_01_vec = to_vec_dec + (* TODO: Think about undefined bit(vector)s *) let to_vec_undef is_inc (len : integer) = Bitvector (failwith "undefined bitvector") (if is_inc then 0 else len-1) is_inc @@ -429,19 +461,25 @@ let to_vec_undef is_inc (len : integer) = let to_vec_inc_undef = to_vec_undef true let to_vec_dec_undef = to_vec_undef false -let exts (len, vec) = to_vec (bvget_dir vec) (len,signed vec) -let extz (len, vec) = to_vec (bvget_dir vec) (len,unsigned vec) +let exts (vec) = to_vec (bvget_dir vec) (signed vec) +let extz (vec) = to_vec (bvget_dir vec) (unsigned vec) + +let exts_big (vec) = to_vec_big (bvget_dir vec) (signed_big vec) +let extz_big (vec) = to_vec_big (bvget_dir vec) (unsigned_big vec) -let exts_big (len, vec) = to_vec_big (bvget_dir vec) (len, signed_big vec) -let extz_big (len, vec) = to_vec_big (bvget_dir vec) (len, unsigned_big vec) +let extz_bl (bits) = vec_to_bvec (Vector bits (integerFromNat (List.length bits - 1)) false) -let add = integerAdd -let add_signed = integerAdd -let minus = integerMinus -let multiply = integerMult -let modulo = hardware_mod +let add (l,r) = integerAdd l r +let add_signed (l,r) = integerAdd l r +let sub (l,r) = integerMinus l r +let multiply (l,r) = integerMult l r +let quotient (l,r) = integerDiv l r +let modulo (l,r) = hardware_mod l r let quot = hardware_quot -let power = integerPow +let power (l,r) = integerPow l r + +let sub_int = sub +let mul_int = multiply (* TODO: this, and the definitions that use it, currently require Size for to_vec, which I'd rather avoid in favour of library versions; the @@ -449,7 +487,7 @@ let power = integerPow let arith_op_vec op sign (size : integer) (Bitvector _ _ is_inc as l) r = let (l',r') = (to_num sign l, to_num sign r) in let n = op l' r' in - to_vec is_inc (size * (bvlength l),n) + to_vec is_inc (n) (* add_vec @@ -464,9 +502,15 @@ let minus_VVV = arith_op_vec integerMinus false 1 let mult_VVV = arith_op_vec integerMult false 2 let multS_VVV = arith_op_vec integerMult true 2 +let mul_vec (l, r) = mult_VVV l r +let mul_svec (l, r) = multS_VVV l r + +let add_vec (l, r) = add_VVV l r +let sub_vec (l, r) = minus_VVV l r + val arith_op_vec_range : forall 'a 'b. Size 'a, Size 'b => (integer -> integer -> integer) -> bool -> integer -> bitvector 'a -> integer -> bitvector 'b let arith_op_vec_range op sign size (Bitvector _ _ is_inc as l) r = - arith_op_vec op sign size l ((to_vec is_inc (bvlength l,r)) : bitvector 'a) + arith_op_vec op sign size l ((to_vec is_inc (r)) : bitvector 'a) (* add_vec_range * add_vec_range_signed @@ -480,9 +524,12 @@ let minus_VIV = arith_op_vec_range integerMinus false 1 let mult_VIV = arith_op_vec_range integerMult false 2 let multS_VIV = arith_op_vec_range integerMult true 2 +let add_vec_int (l, r) = add_VIV l r +let sub_vec_int (l, r) = minus_VIV l r + val arith_op_range_vec : forall 'a 'b. Size 'a, Size 'b => (integer -> integer -> integer) -> bool -> integer -> integer -> bitvector 'a -> bitvector 'b let arith_op_range_vec op sign size l (Bitvector _ _ is_inc as r) = - arith_op_vec op sign size ((to_vec is_inc (bvlength r, l)) : bitvector 'a) r + arith_op_vec op sign size ((to_vec is_inc (l)) : bitvector 'a) r (* add_range_vec * add_range_vec_signed @@ -528,10 +575,10 @@ let arith_op_vec_vec_range op sign l r = let add_VVI = arith_op_vec_vec_range integerAdd false let addS_VVI = arith_op_vec_vec_range integerAdd true -let arith_op_vec_bit op sign (size : integer) (Bitvector _ _ is_inc as l)r = +let arith_op_vec_bit op sign (size : integer) (Bitvector _ _ is_inc as l) r = let l' = to_num sign l in let n = op l' (match r with | B1 -> (1 : integer) | _ -> 0 end) in - to_vec is_inc (bvlength l * size,n) + to_vec is_inc (n) (* add_vec_bit * add_vec_bit_signed @@ -610,17 +657,19 @@ let shift_op_vec op (Bitvector bs start is_inc,(n : integer)) = let n = natFromInteger n in match op with | LL_shift (*"<<"*) -> - Bitvector (shiftLeft bs (naturalFromNat n)) start is_inc + Bitvector (shiftLeft bs n) start is_inc | RR_shift (*">>"*) -> - Bitvector (shiftRight bs (naturalFromNat n)) start is_inc + Bitvector (shiftRight bs n) start is_inc | LLL_shift (*"<<<"*) -> - Bitvector (rotateLeft (naturalFromNat n) bs) start is_inc + Bitvector (rotateLeft n bs) start is_inc end let bitwise_leftshift = shift_op_vec LL_shift (*"<<"*) let bitwise_rightshift = shift_op_vec RR_shift (*">>"*) let bitwise_rotate = shift_op_vec LLL_shift (*"<<<"*) +let shiftl = bitwise_leftshift + let rec arith_op_no0 (op : integer -> integer -> integer) l r = if r = 0 then Nothing @@ -681,19 +730,19 @@ let rec repeat xs n = if n = 0 then [] else xs ++ repeat xs (n-1) -(* -let duplicate bit length = - Vector (repeat [bit] length) (if dir then 0 else length - 1) dir - *) -let compare_op op (l,r) = bool_to_bitU (op l r) +let duplicate (bit, length) = + vec_to_bvec (Vector (repeat [bit] length) (length - 1) false) + +let duplicate_to_list (bit, length) = repeat [bit] length + +let compare_op op (l,r) = (op l r) let lt = compare_op (<) let gt = compare_op (>) let lteq = compare_op (<=) let gteq = compare_op (>=) - let compare_op_vec op sign (l,r) = let (l',r') = (to_num sign l, to_num sign r) in compare_op op (l',r') @@ -712,6 +761,8 @@ let gt_vec_unsigned = compare_op_vec (>) false let lteq_vec_unsigned = compare_op_vec (<=) false let gteq_vec_unsigned = compare_op_vec (>=) false +let lt_svec = lt_vec_signed + let compare_op_vec_range op sign (l,r) = compare_op op ((to_num sign l),r) @@ -728,20 +779,23 @@ let gt_range_vec = compare_op_range_vec (>) true let lteq_range_vec = compare_op_range_vec (<=) true let gteq_range_vec = compare_op_range_vec (>=) true -let eq (l,r) = bool_to_bitU (l = r) -let eq_range (l,r) = bool_to_bitU (l = r) -let eq_vec (l,r) = bool_to_bitU (l = r) -let eq_bit (l,r) = bool_to_bitU (l = r) +val eq : forall 'a. Eq 'a => 'a * 'a -> bool +let eq (l,r) = (l = r) +let eq_range (l,r) = (l = r) + +val eq_vec : forall 'a. bitvector 'a * bitvector 'a -> bool +let eq_vec (l,r) = (l = r) +let eq_bit (l,r) = (l = r) let eq_vec_range (l,r) = eq (to_num false l,r) let eq_range_vec (l,r) = eq (l, to_num false r) let eq_vec_vec (l,r) = eq (to_num true l, to_num true r) -let neq (l,r) = bitwise_not_bit (eq (l,r)) -let neq_bit (l,r) = bitwise_not_bit (eq_bit (l,r)) -let neq_range (l,r) = bitwise_not_bit (eq_range (l,r)) -let neq_vec (l,r) = bitwise_not_bit (eq_vec_vec (l,r)) -let neq_vec_range (l,r) = bitwise_not_bit (eq_vec_range (l,r)) -let neq_range_vec (l,r) = bitwise_not_bit (eq_range_vec (l,r)) +let neq (l,r) = not (eq (l,r)) +let neq_bit (l,r) = not (eq_bit (l,r)) +let neq_range (l,r) = not (eq_range (l,r)) +let neq_vec (l,r) = not (eq_vec_vec (l,r)) +let neq_vec_range (l,r) = not (eq_vec_range (l,r)) +let neq_range_vec (l,r) = not (eq_range_vec (l,r)) val make_indexed_vector : forall 'a. list (integer * 'a) -> 'a -> integer -> integer -> bool -> vector 'a @@ -757,9 +811,9 @@ let make_bitvector_undef length = (* let bitwise_not_range_bit n = bitwise_not (to_vec defaultDir n) *) -let mask (n,bv) = - let len = bvlength bv in - bvslice_raw bv (len - n) (len - 1) +(* TODO *) +val mask : forall 'a 'b. Size 'b => bitvector 'a -> bitvector 'b +let mask (Bitvector w i dir) = (Bitvector (zeroExtend w) i dir) val byte_chunks : forall 'a. nat -> list 'a -> list (list 'a) @@ -974,7 +1028,7 @@ let assert' b msg_opt = | Just msg -> msg | Nothing -> "unspecified error" end in - if bitU_to_bool b then () else failwith msg + if b then () else failwith msg (* convert numbers unsafely to naturals *) diff --git a/src/gen_lib/state.lem b/src/gen_lib/state.lem index 709052fe..2e11e8a9 100644 --- a/src/gen_lib/state.lem +++ b/src/gen_lib/state.lem @@ -14,12 +14,28 @@ type sequential_state = <| regstate : regstate; write_ea : maybe (write_kind * integer * integer); last_exclusive_operation_was_load : bool|> -type M 'a = sequential_state -> list ((either 'a string) * sequential_state) +(* State, nondeterminism and exception monad with result type 'a + and exception type 'e. *) +type ME 'a 'e = sequential_state -> list ((either 'a 'e) * sequential_state) -val return : forall 'a. 'a -> M 'a +(* Most of the time, we don't distinguish between different types of exceptions *) +type M 'a = ME 'a unit + +(* For early return, we abuse exceptions by throwing and catching + the return value. The exception type is "maybe 'r", where "Nothing" + represents a proper exception and "Just r" an early return of value "r". *) +type MR 'a 'r = ME 'a (maybe 'r) + +val liftR : forall 'a 'r. M 'a -> MR 'a 'r +let liftR m s = List.map (function + | (Left a, s') -> (Left a, s') + | (Right (), s') -> (Right Nothing, s') + end) (m s) + +val return : forall 'a 'e. 'a -> ME 'a 'e let return a s = [(Left a,s)] -val bind : forall 'a 'b. M 'a -> ('a -> M 'b) -> M 'b +val bind : forall 'a 'b 'e. ME 'a 'e -> ('a -> ME 'b 'e) -> ME 'b 'e let bind m f (s : sequential_state) = List.concatMap (function | (Left a, s') -> f a s' @@ -27,12 +43,23 @@ let bind m f (s : sequential_state) = end) (m s) let inline (>>=) = bind -val (>>): forall 'b. M unit -> M 'b -> M 'b +val (>>): forall 'b 'e. ME unit 'e -> ME 'b 'e -> ME 'b 'e let inline (>>) m n = m >>= fun _ -> n val exit : forall 'e 'a. 'e -> M 'a -let exit _ s = [(Right "exit",s)] +let exit _ s = [(Right (), s)] +val early_return : forall 'r. 'r -> MR unit 'r +let early_return r s = [(Right (Just r), s)] + +val catch_early_return : forall 'a 'r. MR 'a 'a -> M 'a +let catch_early_return m s = + List.map + (function + | (Right (Just a), s') -> (Left a, s') + | (Right Nothing, s') -> (Right (), s') + | (Left a, s') -> (Left a, s') + end) (m s) val range : integer -> integer -> list integer let rec range i j = @@ -126,7 +153,7 @@ val read_reg : forall 'a. Size 'a => register -> M (bitvector 'a) let read_reg reg state = let v = get_reg state (name_of_reg reg) in [(Left (vec_to_bvec v),state)] -let read_reg_range reg i j state = +(*let read_reg_range reg i j state = let v = slice (get_reg state (name_of_reg reg)) i j in [(Left (vec_to_bvec v),state)] let read_reg_bit reg i state = @@ -137,7 +164,9 @@ let read_reg_field reg regfield = read_reg_range reg i j let read_reg_bitfield reg regfield = let (i,_) = register_field_indices reg regfield in - read_reg_bit reg i + read_reg_bit reg i *) + +let reg_deref = read_reg val write_reg : forall 'a. Size 'a => register -> bitvector 'a -> M unit let write_reg reg v state = @@ -172,8 +201,8 @@ val footprint : M unit let footprint = return () -val foreachM_inc : forall 'vars. (integer * integer * integer) -> 'vars -> - (integer -> 'vars -> M 'vars) -> M 'vars +val foreachM_inc : forall 'vars 'e. (integer * integer * integer) -> 'vars -> + (integer -> 'vars -> ME 'vars 'e) -> ME 'vars 'e let rec foreachM_inc (i,stop,by) vars body = if i <= stop then @@ -182,8 +211,8 @@ let rec foreachM_inc (i,stop,by) vars body = else return vars -val foreachM_dec : forall 'vars. (integer * integer * integer) -> 'vars -> - (integer -> 'vars -> M 'vars) -> M 'vars +val foreachM_dec : forall 'vars 'e. (integer * integer * integer) -> 'vars -> + (integer -> 'vars -> ME 'vars 'e) -> ME 'vars 'e let rec foreachM_dec (i,stop,by) vars body = if i >= stop then diff --git a/src/pretty_print_common.ml b/src/pretty_print_common.ml index 57a06e74..bd43c1a7 100644 --- a/src/pretty_print_common.ml +++ b/src/pretty_print_common.ml @@ -46,6 +46,7 @@ open PPrint let pipe = string "|" let arrow = string "->" let dotdot = string ".." +let coloncolon = string "::" let coloneq = string ":=" let lsquarebar = string "[|" let rsquarebar = string "|]" diff --git a/src/pretty_print_lem.ml b/src/pretty_print_lem.ml index 2619cc51..586773ca 100644 --- a/src/pretty_print_lem.ml +++ b/src/pretty_print_lem.ml @@ -160,7 +160,7 @@ let doc_typ_lem, doc_atomic_typ_lem = let tpp = separate_map (space ^^ star ^^ space) (app_typ regtypes false) typs in if atyp_needed then parens tpp else tpp | _ -> app_typ regtypes atyp_needed ty - and app_typ regtypes atyp_needed ((Typ_aux (t, _)) as ty) = match t with + and app_typ regtypes atyp_needed ((Typ_aux (t, l)) as ty) = match t with | Typ_app(Id_aux (Id "vector", _), [ Typ_arg_aux (Typ_arg_nexp n, _); Typ_arg_aux (Typ_arg_nexp m, _); @@ -168,19 +168,17 @@ let doc_typ_lem, doc_atomic_typ_lem = Typ_arg_aux (Typ_arg_typ elem_typ, _)]) -> let tpp = match elem_typ with | Typ_aux (Typ_id (Id_aux (Id "bit",_)),_) -> - let len = match m with - | (Nexp_aux(Nexp_constant i,_)) -> string "ty" ^^ doc_int i - | _ -> doc_nexp m in - string "bitvector" ^^ space ^^ len + (match simplify_nexp m with + | (Nexp_aux(Nexp_constant i,_)) -> string "bitvector ty" ^^ doc_int i + | (Nexp_aux(Nexp_var _, _)) -> separate space [string "bitvector"; doc_nexp m] + | _ -> raise (Reporting_basic.err_unreachable l + "cannot pretty-print bitvector type with non-constant length")) | _ -> string "vector" ^^ space ^^ typ regtypes elem_typ in if atyp_needed then parens tpp else tpp | Typ_app(Id_aux (Id "register", _), [Typ_arg_aux (Typ_arg_typ etyp, _)]) -> - (* TODO: Better distinguish register names and contents? - The former are represented in the Lem library using a type - "register" (without parameters), the latter just using the content - type (e.g. "bitvector ty64"). We assume the latter is meant here - and drop the "register" keyword. *) - fn_typ regtypes atyp_needed etyp + (* TODO: Better distinguish register names and contents? *) + (* fn_typ regtypes atyp_needed etyp *) + (string "register") | Typ_app(Id_aux (Id "range", _),_) -> (string "integer") | Typ_app(Id_aux (Id "implicit", _),_) -> @@ -192,13 +190,13 @@ let doc_typ_lem, doc_atomic_typ_lem = if atyp_needed then parens tpp else tpp | _ -> atomic_typ regtypes atyp_needed ty and atomic_typ regtypes atyp_needed ((Typ_aux (t, _)) as ty) = match t with - | Typ_id (Id_aux (Id "bool",_)) -> string "bitU" - | Typ_id (Id_aux (Id "boolean",_)) -> string "bitU" + | Typ_id (Id_aux (Id "bool",_)) -> string "bool" + | Typ_id (Id_aux (Id "boolean",_)) -> string "bool" | Typ_id (Id_aux (Id "bit",_)) -> string "bitU" | Typ_id (id) -> - if List.exists ((=) (string_of_id id)) regtypes + (*if List.exists ((=) (string_of_id id)) regtypes then string "register" - else doc_id_lem_type id + else*) doc_id_lem_type id | Typ_var v -> doc_var v | Typ_wild -> underscore | Typ_app _ | Typ_tup _ | Typ_fn _ -> @@ -213,46 +211,86 @@ let doc_typ_lem, doc_atomic_typ_lem = | Typ_arg_effect e -> empty in typ', atomic_typ +(* Check for variables in types that would be pretty-printed. + In particular, in case of vector types, only the element type and the + length argument are checked for variables, and the latter only if it is + a bitvector; for other types of vectors, the length is not pretty-printed + in the type, and the start index is never pretty-printed in vector types. *) +let rec contains_t_pp_var (Typ_aux (t,a) as typ) = match t with + | Typ_wild -> true + | Typ_id _ -> false + | Typ_var _ -> true + | Typ_fn (t1,t2,_) -> contains_t_pp_var t1 || contains_t_pp_var t2 + | Typ_tup ts -> List.exists contains_t_pp_var ts + | Typ_app (c,targs) -> + if Ast_util.is_number typ then false + else if is_bitvector_typ typ then + let (_,length,_,_) = vector_typ_args_of typ in + not (is_nexp_constant (simplify_nexp length)) + else List.exists contains_t_arg_pp_var targs +and contains_t_arg_pp_var (Typ_arg_aux (targ, _)) = match targ with + | Typ_arg_typ t -> contains_t_pp_var t + | Typ_arg_nexp nexp -> not (is_nexp_constant (simplify_nexp nexp)) + | _ -> false + let doc_tannot_lem regtypes eff typ = - let ta = doc_typ_lem regtypes typ in - if eff then string " : M " ^^ parens ta - else string " : " ^^ ta + if contains_t_pp_var typ then empty + else + let ta = doc_typ_lem regtypes typ in + if eff then string " : M " ^^ parens ta + else string " : " ^^ ta (* doc_lit_lem gets as an additional parameter the type information from the * expression around it: that's a hack, but how else can we distinguish between * undefined values of different types ? *) -let doc_lit_lem in_pat (L_aux(lit,l)) a = - utf8string (match lit with - | L_unit -> "()" - | L_zero -> "B0" - | L_one -> "B1" - | L_false -> "B0" - | L_true -> "B1" +let doc_lit_lem regtypes in_pat (L_aux(lit,l)) a = + match lit with + | L_unit -> utf8string "()" + | L_zero -> utf8string "B0" + | L_one -> utf8string "B1" + | L_false -> utf8string "false" + | L_true -> utf8string "true" | L_num i -> let ipp = string_of_int i in - if in_pat then "("^ipp^":nn)" - else if i < 0 then "((0"^ipp^"):ii)" - else "("^ipp^":ii)" + utf8string ( + if in_pat then "("^ipp^":nn)" + else if i < 0 then "((0"^ipp^"):ii)" + else "("^ipp^":ii)") | L_hex n -> failwith "Shouldn't happen" (*"(num_to_vec " ^ ("0x" ^ n) ^ ")" (*shouldn't happen*)*) | L_bin n -> failwith "Shouldn't happen" (*"(num_to_vec " ^ ("0b" ^ n) ^ ")" (*shouldn't happen*)*) | L_undef -> (match a with - | Some (_, Typ_aux (t,_), _) -> + | Some (_, (Typ_aux (t,_) as typ), _) -> (match t with | Typ_id (Id_aux (Id "bit", _)) - | Typ_app (Id_aux (Id "register", _),_) -> "UndefinedRegister 0" - | Typ_id (Id_aux (Id "string", _)) -> "\"\"" - | _ -> "(failwith \"undefined value of unsupported type\")") - | _ -> "(failwith \"undefined value of unsupported type\")") - | L_string s -> "\"" ^ s ^ "\"" - | L_real s -> s (* TODO What's the Lem syntax for reals? *)) + | Typ_app (Id_aux (Id "register", _),_) -> utf8string "UndefinedRegister 0" + | Typ_id (Id_aux (Id "string", _)) -> utf8string "\"\"" + | _ -> + parens + ((utf8string "(failwith \"undefined value of unsupported type\")") ^^ + (doc_tannot_lem regtypes false typ))) + | _ -> utf8string "(failwith \"undefined value of unsupported type\")") + | L_string s -> utf8string ("\"" ^ s ^ "\"") + | L_real s -> utf8string s (* TODO What's the Lem syntax for reals? *) (* typ_doc is the doc for the type being quantified *) +let doc_quant_item (QI_aux (qi, _)) = match qi with +| QI_id (KOpt_aux (KOpt_none kid, _)) +| QI_id (KOpt_aux (KOpt_kind (_, kid), _)) -> doc_var kid +| _ -> empty -let doc_typquant_lem (TypQ_aux(tq,_)) typ_doc = typ_doc +let doc_typquant_items_lem (TypQ_aux(tq,_)) = match tq with +| TypQ_tq qs -> separate_map space doc_quant_item qs +| _ -> empty -let doc_typschm_lem regtypes (TypSchm_aux(TypSchm_ts(tq,t),_)) = - (doc_typquant_lem tq (doc_typ_lem regtypes t)) +let doc_typquant_lem (TypQ_aux(tq,_)) typ = match tq with +| TypQ_tq ((_ :: _) as qs) -> + string "forall " ^^ separate_map space doc_quant_item qs ^^ string ". " ^^ typ +| _ -> empty + +let doc_typschm_lem regtypes quants (TypSchm_aux(TypSchm_ts(tq,t),_)) = + if quants then (doc_typquant_lem tq (doc_typ_lem regtypes t)) + else doc_typ_lem regtypes t let is_ctor env id = match Env.lookup_id id env with | Enum _ | Union _ -> true @@ -267,14 +305,17 @@ let rec doc_pat_lem regtypes apat_needed (P_aux (p,(l,annot)) as pa) = match p w (parens (separate_map comma (doc_pat_lem regtypes true) pats)) in if apat_needed then parens ppp else ppp | P_app(id,[]) -> doc_id_lem_ctor id - | P_lit lit -> doc_lit_lem true lit annot + | P_lit lit -> doc_lit_lem regtypes true lit annot | P_wild -> underscore | P_id id -> begin match id with | Id_aux (Id "None",_) -> string "Nothing" (* workaround temporary issue *) | _ -> doc_id_lem id end | P_as(p,id) -> parens (separate space [doc_pat_lem regtypes true p; string "as"; doc_id_lem id]) - | P_typ(typ,p) -> parens (doc_op colon (doc_pat_lem regtypes true p) (doc_typ_lem regtypes typ)) + | P_typ(typ,p) -> + let doc_p = doc_pat_lem regtypes true p in + if contains_t_pp_var typ then doc_p + else parens (doc_op colon doc_p (doc_typ_lem regtypes typ)) | P_vector pats -> let ppp = (separate space) @@ -300,38 +341,23 @@ and contains_bitvector_typ_arg (Typ_arg_aux (targ, _)) = match targ with | Typ_arg_typ t -> contains_bitvector_typ t | _ -> false -let const_nexp (Nexp_aux (nexp,_)) = match nexp with - | Nexp_constant _ -> true - | _ -> false - -(* Check for variables in types that would be pretty-printed. - In particular, in case of vector types, only the element type and the - length argument are checked for variables, and the latter only if it is - a bitvector; for other types of vectors, the length is not pretty-printed - in the type, and the start index is never pretty-printed in vector types. *) -let rec contains_t_pp_var (Typ_aux (t,a) as typ) = match t with - | Typ_wild -> true - | Typ_id _ -> false - | Typ_var _ -> true - | Typ_fn (t1,t2,_) -> contains_t_pp_var t1 || contains_t_pp_var t2 - | Typ_tup ts -> List.exists contains_t_pp_var ts - | Typ_app (c,targs) -> - if is_bitvector_typ typ then - let (_,length,_,_) = vector_typ_args_of typ in - not (const_nexp ((*normalize_nexp*) length)) - else List.exists contains_t_arg_pp_var targs -and contains_t_arg_pp_var (Typ_arg_aux (targ, _)) = match targ with - | Typ_arg_typ t -> contains_t_pp_var t - | Typ_arg_nexp nexp -> not (const_nexp ((*normalize_nexp*) nexp)) - | _ -> false +let contains_early_return exp = + fst (fold_exp + { (Rewriter.compute_exp_alg false (||)) + with e_return = (fun (_, r) -> (true, E_return r)) } exp) let prefix_recordtype = true let report = Reporting_basic.err_unreachable let doc_exp_lem, doc_let_lem = - let rec top_exp regtypes (aexp_needed : bool) (E_aux (e, (l,annot)) as full_exp) = - let expY = top_exp regtypes true in - let expN = top_exp regtypes false in - let expV = top_exp regtypes in + let rec top_exp regtypes (early_ret : bool) (aexp_needed : bool) + (E_aux (e, (l,annot)) as full_exp) = + let expY = top_exp regtypes early_ret true in + let expN = top_exp regtypes early_ret false in + let expV = top_exp regtypes early_ret in + let liftR doc = + if early_ret && effectful (effect_of full_exp) + then separate space [string "liftR"; parens (doc)] + else doc in match e with | E_assign((LEXP_aux(le_act,tannot) as le), e) -> (* can only be register writes *) @@ -343,14 +369,14 @@ let doc_exp_lem, doc_let_lem = if is_bit_typ (typ_of_annot lannot) then raise (report l "indexing a register's (single bit) bitfield not supported") else - (prefix 2 1) + liftR ((prefix 2 1) (string "write_reg_field_range") - (align (doc_lexp_deref_lem regtypes le ^^ space^^ - string_lit (doc_id_lem id) ^/^ expY e2 ^/^ expY e3 ^/^ expY e)) + (align (doc_lexp_deref_lem regtypes early_ret le ^^ space^^ + string_lit (doc_id_lem id) ^/^ expY e2 ^/^ expY e3 ^/^ expY e))) | _ -> - (prefix 2 1) + liftR ((prefix 2 1) (string "write_reg_range") - (align (doc_lexp_deref_lem regtypes le ^^ space ^^ expY e2 ^/^ expY e3 ^/^ expY e)) + (align (doc_lexp_deref_lem regtypes early_ret le ^^ space ^^ expY e2 ^/^ expY e3 ^/^ expY e))) ) | LEXP_vector (le,e2) when is_bit_typ t -> (match le with @@ -358,23 +384,23 @@ let doc_exp_lem, doc_let_lem = if is_bit_typ (typ_of_annot lannot) then raise (report l "indexing a register's (single bit) bitfield not supported") else - (prefix 2 1) + liftR ((prefix 2 1) (string "write_reg_field_bit") - (align (doc_lexp_deref_lem regtypes le ^^ space ^^ doc_id_lem id ^/^ expY e2 ^/^ expY e)) + (align (doc_lexp_deref_lem regtypes early_ret le ^^ space ^^ doc_id_lem id ^/^ expY e2 ^/^ expY e))) | _ -> - (prefix 2 1) + liftR ((prefix 2 1) (string "write_reg_bit") - (doc_lexp_deref_lem regtypes le ^^ space ^^ expY e2 ^/^ expY e) + (doc_lexp_deref_lem regtypes early_ret le ^^ space ^^ expY e2 ^/^ expY e)) ) | LEXP_field (le,id) when is_bit_typ t -> - (prefix 2 1) + liftR ((prefix 2 1) (string "write_reg_bitfield") - (doc_lexp_deref_lem regtypes le ^^ space ^^ string_lit(doc_id_lem id) ^/^ expY e) + (doc_lexp_deref_lem regtypes early_ret le ^^ space ^^ string_lit(doc_id_lem id) ^/^ expY e)) | LEXP_field (le,id) -> - (prefix 2 1) + liftR ((prefix 2 1) (string "write_reg_field") - (doc_lexp_deref_lem regtypes le ^^ space ^^ - string_lit(doc_id_lem id) ^/^ expY e) + (doc_lexp_deref_lem regtypes early_ret le ^^ space ^^ + string_lit(doc_id_lem id) ^/^ expY e)) (* | (LEXP_id id | LEXP_cast (_,id)), t, Alias alias_info -> (match alias_info with | Alias_field(reg,field) -> @@ -389,9 +415,11 @@ let doc_exp_lem, doc_let_lem = string "write_two_regs" ^^ space ^^ string reg1 ^^ space ^^ string reg2 ^^ space ^^ expY e) *) | _ -> - (prefix 2 1) (string "write_reg") (doc_lexp_deref_lem regtypes le ^/^ expY e)) + liftR ((prefix 2 1) (string "write_reg") (doc_lexp_deref_lem regtypes early_ret le ^/^ expY e))) | E_vector_append(le,re) -> - let t = Env.base_typ_of (env_of full_exp) (typ_of full_exp) in + raise (Reporting_basic.err_unreachable l + "E_vector_access should have been rewritten before pretty-printing") + (* let t = Env.base_typ_of (env_of full_exp) (typ_of full_exp) in let (call,ta,aexp_needed) = if is_bitvector_typ t then if not (contains_t_pp_var t) @@ -400,12 +428,12 @@ let doc_exp_lem, doc_let_lem = else ("vector_concat",empty,aexp_needed) in let epp = align (group (separate space [string call;expY le;expY re])) ^^ ta in - if aexp_needed then parens epp else epp + if aexp_needed then parens epp else epp *) | E_cons(le,re) -> doc_op (group (colon^^colon)) (expY le) (expY re) | E_if(c,t,e) -> let (E_aux (_,(_,cannot))) = c in let epp = - separate space [string "if";group (align (string "bitU_to_bool" ^//^ group (expY c)))] ^^ + separate space [string "if";group (expY c)] ^^ break 1 ^^ (prefix 2 1 (string "then") (expN t)) ^^ (break 1) ^^ (prefix 2 1 (string "else") (expN e)) in @@ -413,7 +441,7 @@ let doc_exp_lem, doc_let_lem = | E_for(id,exp1,exp2,exp3,(Ord_aux(order,_)),exp4) -> raise (report l "E_for should have been removed till now") | E_let(leb,e) -> - let epp = let_exp regtypes leb ^^ space ^^ string "in" ^^ hardline ^^ expN e in + let epp = let_exp regtypes early_ret leb ^^ space ^^ string "in" ^^ hardline ^^ expN e in if aexp_needed then parens epp else epp | E_app(f,args) -> begin match f with @@ -438,7 +466,7 @@ let doc_exp_lem, doc_let_lem = (prefix 1 1 (separate space [string "fun";expY id;varspp;arrow]) (expN body)) ) ) - | Id_aux (Id "append",_) -> + (* | Id_aux (Id "append",_) -> let [e1;e2] = args in let epp = align (expY e1 ^^ space ^^ string "++" ^//^ expY e2) in if aexp_needed then parens (align epp) else epp @@ -464,7 +492,7 @@ let doc_exp_lem, doc_let_lem = | Id_aux (Id "bool_not", _) -> let [a] = args in let epp = align (string "~" ^^ expY a) in - if aexp_needed then parens (align epp) else epp + if aexp_needed then parens (align epp) else epp *) | _ -> begin match annot with | Some (env, _, _) when (is_ctor env f) -> @@ -477,24 +505,28 @@ let doc_exp_lem, doc_let_lem = parens (separate_map comma (expV false) args) in if aexp_needed then parens (align epp) else epp | _ -> - let call = (*match annot with - | Base(_,External (Some n),_,_,_,_) -> string n - | _ ->*) doc_id_lem f in + let call = match annot with + | Some (env, _, _) when Env.is_extern f env -> + string (Env.get_extern f env) + | _ -> doc_id_lem f in let argspp = match args with | [arg] -> expV true arg | args -> parens (align (separate_map (comma ^^ break 0) (expV false) args)) in let epp = align (call ^//^ argspp) in let (taepp,aexp_needed) = - let t = Env.base_typ_of (env_of full_exp) (typ_of full_exp) in + let t = (*Env.base_typ_of (env_of full_exp)*) (typ_of full_exp) in let eff = effect_of full_exp in - if contains_bitvector_typ t && not (contains_t_pp_var t) + if contains_bitvector_typ (Env.base_typ_of (env_of full_exp) t) && + not (contains_t_pp_var t) then (align epp ^^ (doc_tannot_lem regtypes (effectful eff) t), true) else (epp, aexp_needed) in - if aexp_needed then parens (align taepp) else taepp + liftR (if aexp_needed then parens (align taepp) else taepp) end end | E_vector_access (v,e) -> - let eff = effect_of full_exp in + raise (Reporting_basic.err_unreachable l + "E_vector_access should have been rewritten before pretty-printing") + (* let eff = effect_of full_exp in let epp = if has_effect eff BE_rreg then separate space [string "read_reg_bit";expY v;expY e] @@ -502,9 +534,11 @@ let doc_exp_lem, doc_let_lem = let tv = typ_of v in let call = if is_bitvector_typ tv then "bvaccess" else "access" in separate space [string call;expY v;expY e] in - if aexp_needed then parens (align epp) else epp + if aexp_needed then parens (align epp) else epp*) | E_vector_subrange (v,e1,e2) -> - let t = Env.base_typ_of (env_of full_exp) (typ_of full_exp) in + raise (Reporting_basic.err_unreachable l + "E_vector_access should have been rewritten before pretty-printing") + (* let t = Env.base_typ_of (env_of full_exp) (typ_of full_exp) in let eff = effect_of full_exp in let (epp,aexp_needed) = if has_effect eff BE_rreg then @@ -519,22 +553,22 @@ let doc_exp_lem, doc_let_lem = then (bepp ^^ doc_tannot_lem regtypes false t, true) else (bepp, aexp_needed) else (string "slice" ^^ space ^^ expY v ^//^ expY e1 ^//^ expY e2, aexp_needed) in - if aexp_needed then parens (align epp) else epp + if aexp_needed then parens (align epp) else epp *) | E_field((E_aux(_,(l,fannot)) as fexp),id) -> let ft = typ_of_annot (l,fannot) in (match fannot with - | Some(env, ftyp, _) when is_regtyp ftyp env -> - let t = Env.base_typ_of (env_of full_exp) (typ_of full_exp) in + | Some(env, (Typ_aux (Typ_id tid, _)), _) + | Some(env, (Typ_aux (Typ_app (Id_aux (Id "register", _), [Typ_arg_aux (Typ_arg_typ (Typ_aux (Typ_id tid, _)), _)]), _)), _) + when Env.is_regtyp tid env -> + let t = (* Env.base_typ_of (env_of full_exp) *) (typ_of full_exp) in let eff = effect_of full_exp in - let field_f = string - (if is_bit_typ t - then "read_reg_bitfield" - else "read_reg_field") in + let field_f = string "get" ^^ underscore ^^ + doc_id_lem tid ^^ underscore ^^ doc_id_lem id in let (ta,aexp_needed) = if contains_bitvector_typ t && not (contains_t_pp_var t) then (doc_tannot_lem regtypes (effectful eff) t, true) else (empty, aexp_needed) in - let epp = field_f ^^ space ^^ (expY fexp) ^^ space ^^ string_lit (doc_id_lem id) in + let epp = field_f ^^ space ^^ (expY fexp) in if aexp_needed then parens (align epp ^^ ta) else (epp ^^ ta) | Some(env, (Typ_aux (Typ_id tid, _)), _) when Env.is_record tid env -> let fname = @@ -554,9 +588,9 @@ let doc_exp_lem, doc_let_lem = let base_typ = Env.base_typ_of env typ in if has_effect eff BE_rreg then let epp = separate space [string "read_reg";doc_id_lem id] in - if contains_bitvector_typ base_typ && not (contains_t_pp_var base_typ) - then parens (epp ^^ doc_tannot_lem regtypes true base_typ) - else epp + if is_bitvector_typ base_typ && not (contains_t_pp_var base_typ) + then liftR (parens (epp ^^ doc_tannot_lem regtypes true base_typ)) + else liftR epp else if is_ctor env id then doc_id_lem_ctor id else doc_id_lem id (*| Base((_,t),Alias alias_info,_,eff,_,_) -> @@ -592,7 +626,7 @@ let doc_exp_lem, doc_let_lem = separate space [string "read_reg_range";string reg;doc_int start;doc_int stop] ^^ ta in if aexp_needed then parens (align epp) else epp )*) - | E_lit lit -> doc_lit_lem false lit annot + | E_lit lit -> doc_lit_lem regtypes false lit annot | E_cast(typ,e) -> expV aexp_needed e (* (match annot with @@ -633,14 +667,14 @@ let doc_exp_lem, doc_let_lem = | _ -> raise (report l "cannot get record type") in let epp = anglebars (space ^^ (align (separate_map (semi_sp ^^ break 1) - (doc_fexp regtypes recordtyp) fexps)) ^^ space) in + (doc_fexp regtypes early_ret recordtyp) fexps)) ^^ space) in if aexp_needed then parens epp else epp | E_record_update(e,(FES_aux(FES_Fexps(fexps,_),_))) -> let recordtyp = match annot with | Some (env, Typ_aux (Typ_id tid,_), _) when Env.is_record tid env -> tid | _ -> raise (report l "cannot get record type") in - anglebars (doc_op (string "with") (expY e) (separate_map semi_sp (doc_fexp regtypes recordtyp) fexps)) + anglebars (doc_op (string "with") (expY e) (separate_map semi_sp (doc_fexp regtypes early_ret recordtyp) fexps)) | E_vector exps -> let t = Env.base_typ_of (env_of full_exp) (typ_of full_exp) in let (start, len, order, etyp) = @@ -653,7 +687,7 @@ let doc_exp_lem, doc_let_lem = | Tapp("vector", [TA_nexp start; TA_nexp len; TA_ord order; TA_typ etyp]) | Tabbrev(_,{t= Tapp("vector", [TA_nexp start; TA_nexp len; TA_ord order; TA_typ etyp])}) ->*) let dir,dir_out = if is_order_inc order then (true,"true") else (false, "false") in - let start = match start with + let start = match simplify_nexp start with | Nexp_aux (Nexp_constant i, _) -> string_of_int i | _ -> if dir then "0" else string_of_int (List.length exps) in let expspp = @@ -685,10 +719,10 @@ let doc_exp_lem, doc_let_lem = if is_vector_typ t then vector_typ_args_of t else raise (Reporting_basic.err_unreachable l "E_vector_indexed of non-vector type") in let dir,dir_out = if is_order_inc order then (true,"true") else (false, "false") in - let start = match start with + let start = match simplify_nexp start with | Nexp_aux (Nexp_constant i, _) -> string_of_int i | _ -> if dir then "0" else string_of_int (List.length iexps) in - let size = match len with + let size = match simplify_nexp len with | Nexp_aux (Nexp_constant i, _)-> string_of_int i | Nexp_aux (Nexp_exp (Nexp_aux (Nexp_constant i, _)), _) -> string_of_int (Util.power 2 i) @@ -769,10 +803,10 @@ let doc_exp_lem, doc_let_lem = pattern-matching on integers *) let epp = group ((separate space [string "match"; only_integers e; string "with"]) ^/^ - (separate_map (break 1) (doc_case regtypes) pexps) ^/^ + (separate_map (break 1) (doc_case regtypes early_ret) pexps) ^/^ (string "end")) in if aexp_needed then parens (align epp) else align epp - | E_exit e -> separate space [string "exit"; expY e;] + | E_exit e -> liftR (separate space [string "exit"; expY e;]) | E_assert (e1,e2) -> let epp = separate space [string "assert'"; expY e1; expY e2] in if aexp_needed then parens (align epp) else align epp @@ -889,46 +923,45 @@ let doc_exp_lem, doc_let_lem = | E_internal_return (e1) -> separate space [string "return"; expY e1;] | E_sizeof nexp -> - (match nexp with - | Nexp_aux (Nexp_constant i, _) -> doc_lit_lem false (L_aux (L_num i, l)) annot + (match simplify_nexp nexp with + | Nexp_aux (Nexp_constant i, _) -> doc_lit_lem regtypes false (L_aux (L_num i, l)) annot | _ -> raise (Reporting_basic.err_unreachable l "pretty-printing non-constant sizeof expressions to Lem not supported")) - | E_return _ -> - raise (Reporting_basic.err_todo l - "pretty-printing early return statements to Lem not yet supported") + | E_return r -> + align (string "early_return" ^//^ expV true r) | E_constraint _ | E_comment _ | E_comment_struc _ -> empty | E_internal_cast _ | E_internal_exp _ | E_sizeof_internal _ | E_internal_exp_user _ -> raise (Reporting_basic.err_unreachable l "unsupported internal expression encountered while pretty-printing") - and let_exp regtypes (LB_aux(lb,_)) = match lb with + and let_exp regtypes early_ret (LB_aux(lb,_)) = match lb with | LB_val_explicit(_,pat,e) | LB_val_implicit(pat,e) -> prefix 2 1 (separate space [string "let"; doc_pat_lem regtypes true pat; equals]) - (top_exp regtypes false e) + (top_exp regtypes early_ret false e) - and doc_fexp regtypes recordtyp (FE_aux(FE_Fexp(id,e),_)) = + and doc_fexp regtypes early_ret recordtyp (FE_aux(FE_Fexp(id,e),_)) = let fname = if prefix_recordtype then (string (string_of_id recordtyp ^ "_")) ^^ doc_id_lem id else doc_id_lem id in - group (doc_op equals fname (top_exp regtypes true e)) + group (doc_op equals fname (top_exp regtypes early_ret true e)) - and doc_case regtypes = function + and doc_case regtypes early_ret = function | Pat_aux(Pat_exp(pat,e),_) -> group (prefix 3 1 (separate space [pipe; doc_pat_lem regtypes false pat;arrow]) - (group (top_exp regtypes false e))) + (group (top_exp regtypes early_ret false e))) | Pat_aux(Pat_when(_,_,_),(l,_)) -> raise (Reporting_basic.err_unreachable l "guarded pattern expression should have been rewritten before pretty-printing") - and doc_lexp_deref_lem regtypes ((LEXP_aux(lexp,(l,annot))) as le) = match lexp with + and doc_lexp_deref_lem regtypes early_ret ((LEXP_aux(lexp,(l,annot))) as le) = match lexp with | LEXP_field (le,id) -> - parens (separate empty [doc_lexp_deref_lem regtypes le;dot;doc_id_lem id]) + parens (separate empty [doc_lexp_deref_lem regtypes early_ret le;dot;doc_id_lem id]) | LEXP_vector(le,e) -> - parens ((separate space) [string "access";doc_lexp_deref_lem regtypes le; - top_exp regtypes true e]) + parens ((separate space) [string "access";doc_lexp_deref_lem regtypes early_ret le; + top_exp regtypes early_ret true e]) | LEXP_id id -> doc_id_lem id | LEXP_cast (typ,id) -> doc_id_lem id | _ -> @@ -947,10 +980,10 @@ let rec doc_range_lem (BF_aux(r,_)) = match r with | BF_range(i1,i2) -> parens (doc_op comma (doc_int i1) (doc_int i2)) | BF_concat(ir1,ir2) -> (doc_range ir1) ^^ comma ^^ (doc_range ir2) -let doc_typdef_lem regtypes (TD_aux(td,_)) = match td with +let doc_typdef_lem regtypes (TD_aux(td, (l, _))) = match td with | TD_abbrev(id,nm,typschm) -> doc_op equals (concat [string "type"; space; doc_id_lem_type id]) - (doc_typschm_lem regtypes typschm) + (doc_typschm_lem regtypes false typschm) | TD_record(id,nm,typq,fs,_) -> let f_pp (typ,fid) = let fname = if prefix_recordtype @@ -960,7 +993,7 @@ let doc_typdef_lem regtypes (TD_aux(td,_)) = match td with let fs_doc = group (separate_map (break 1) f_pp fs) in doc_op equals (concat [string "type"; space; doc_id_lem_type id;]) - (doc_typquant_lem typq (anglebars (space ^^ align fs_doc ^^ space))) + ((*doc_typquant_lem typq*) (anglebars (space ^^ align fs_doc ^^ space))) | TD_variant(id,nm,typq,ar,_) -> (match id with | Id_aux ((Id "read_kind"),_) -> empty @@ -971,13 +1004,14 @@ let doc_typdef_lem regtypes (TD_aux(td,_)) = match td with | Id_aux ((Id "regfp"),_) -> empty | Id_aux ((Id "niafp"),_) -> empty | Id_aux ((Id "diafp"),_) -> empty + | Id_aux ((Id "option"),_) -> empty | _ -> let ar_doc = group (separate_map (break 1) (doc_type_union_lem regtypes) ar) in let typ_pp = (doc_op equals) - (concat [string "type"; space; doc_id_lem_type id;]) - (doc_typquant_lem typq ar_doc) in + (concat [string "type"; space; doc_id_lem_type id; space; doc_typquant_items_lem typq]) + ((*doc_typquant_lem typq*) ar_doc) in let make_id pat id = separate space [string "SIA.Id_aux"; parens (string "SIA.Id " ^^ string_lit (doc_id id)); @@ -1142,7 +1176,7 @@ let doc_typdef_lem regtypes (TD_aux(td,_)) = match td with fromToInterpValuePP ^^ hardline else empty) | TD_register(id,n1,n2,rs) -> - match n1,n2 with + match n1, n2 with | Nexp_aux(Nexp_constant i1,_),Nexp_aux(Nexp_constant i2,_) -> let doc_rid (r,id) = parens (separate comma_sp [string_lit (doc_id_lem id); doc_range_lem r;]) in @@ -1153,25 +1187,64 @@ let doc_typdef_lem regtypes (TD_aux(td,_)) = match td with (string "Register_field" ^^ space ^^ string_lit(doc_id_lem id)) in*) let dir_b = i1 < i2 in let dir = string (if dir_b then "true" else "false") in + let dir_suffix = (if dir_b then "_inc" else "_dec") in + let ord = Ord_aux ((if dir_b then Ord_inc else Ord_dec), Parse_ast.Unknown) in let size = if dir_b then i2-i1 +1 else i1-i2 + 1 in - (doc_op equals) + let vtyp = vector_typ (nconstant i1) (nconstant size) ord bit_typ in + let tannot = doc_tannot_lem regtypes false vtyp in + let doc_field (fr, fid) = + let i, j = match fr with + | BF_aux (BF_single i, _) -> (i, i) + | BF_aux (BF_range (i, j), _) -> (i, j) + | _ -> raise (Reporting_basic.err_unreachable l "unsupported field type") in + let get, set = + "bitvector_subrange" ^ dir_suffix ^ " (reg, " ^ string_of_int i ^ ", " ^ string_of_int j ^ ")", + "bitvector_update" ^ dir_suffix ^ " (reg, " ^ string_of_int i ^ ", " ^ string_of_int j ^ ", v)" in + doc_op equals + (concat [string "let get_"; doc_id_lem id; underscore; doc_id_lem fid; + space; parens (string "reg" ^^ tannot)]) (string get) ^^ + hardline ^^ + doc_op equals + (concat [string "let set_"; doc_id_lem id; underscore; doc_id_lem fid; + space; parens (separate comma_sp [parens (string "reg" ^^ tannot); string "v"])]) (string set) + in + doc_op equals + (concat [string "type";space;doc_id_lem id]) + (doc_typ_lem regtypes vtyp) + ^^ hardline ^^ + doc_op equals (concat [string "let";space;string "build_";doc_id_lem id;space;string "regname"]) (string "Register" ^^ space ^^ align (separate space [string "regname"; doc_int size; doc_int i1; dir; break 0 ^^ brackets (align doc_rids)])) - (*^^ hardline ^^ - separate_map hardline doc_rfield rs *) + ^^ hardline ^^ + doc_op equals + (concat [string "let";space;string "cast_";doc_id_lem id;space;string "reg"]) + (string "reg") + ^^ hardline ^^ + doc_op equals + (concat [string "let";space;string "cast_to_";doc_id_lem id;space;string "reg"]) + (string "reg") + ^^ hardline ^^ + separate_map hardline doc_field rs let doc_rec_lem (Rec_aux(r,_)) = match r with | Rec_nonrec -> space | Rec_rec -> space ^^ string "rec" ^^ space let doc_tannot_opt_lem regtypes (Typ_annot_opt_aux(t,_)) = match t with - | Typ_annot_opt_some(tq,typ) -> doc_typquant_lem tq (doc_typ_lem regtypes typ) + | Typ_annot_opt_some(tq,typ) -> (*doc_typquant_lem tq*) (doc_typ_lem regtypes typ) + +let doc_fun_body_lem regtypes exp = + let early_ret = contains_early_return exp in + let doc_exp = doc_exp_lem regtypes early_ret false exp in + if early_ret + then align (string "catch_early_return" ^//^ parens (doc_exp)) + else doc_exp let doc_funcl_lem regtypes (FCL_aux(FCL_Funcl(id,pat,exp),_)) = group (prefix 3 1 ((doc_pat_lem regtypes false pat) ^^ space ^^ arrow) - (doc_exp_lem regtypes false exp)) + (doc_fun_body_lem regtypes exp)) let get_id = function | [] -> failwith "FD_function with empty list" @@ -1188,7 +1261,7 @@ let rec doc_fundef_lem regtypes (FD_aux(FD_function(r, typa, efa, fcls),fannot)) [(string "let") ^^ (doc_rec_lem r) ^^ (doc_id_lem id); (doc_pat_lem regtypes true pat); equals]) - (doc_exp_lem regtypes false exp) + (doc_fun_body_lem regtypes exp) | _ -> let id = get_id fcls in (* let sep = hardline ^^ pipe ^^ space in *) @@ -1230,7 +1303,7 @@ let rec doc_fundef_lem regtypes (FD_aux(FD_function(r, typa, efa, fcls),fannot)) let named_pat = P_aux (P_app (Id_aux (Id ctor,l),named_argspat),pannot) in let doc_arg idx (P_aux (p,(l,a))) = match p with | P_as (pat,id) -> doc_id_lem id - | P_lit lit -> doc_lit_lem false lit a + | P_lit lit -> doc_lit_lem regtypes false lit a | P_id id -> doc_id_lem id | _ -> string ("arg" ^ string_of_int idx) in let clauses = @@ -1256,37 +1329,33 @@ let rec doc_fundef_lem regtypes (FD_aux(FD_function(r, typa, efa, fcls),fannot)) -let doc_dec_lem (DEC_aux (reg,(l,annot))) = +let doc_dec_lem (DEC_aux (reg, ((l, _) as annot))) = match reg with | DEC_reg(typ,id) -> + let env = env_of_annot annot in (match typ with - | Typ_aux (Typ_app (r, [Typ_arg_aux (Typ_arg_typ rt, _)]), _) - when string_of_id r = "register" && is_vector_typ rt -> - let env = env_of_annot (l,annot) in - let (start, size, order, etyp) = vector_typ_args_of (Env.base_typ_of env rt) in - (match is_bit_typ (Env.base_typ_of env etyp), start, size with - | true, Nexp_aux (Nexp_constant start, _), Nexp_aux (Nexp_constant size, _) -> - let o = if is_order_inc order then "true" else "false" in - (doc_op equals) - (string "let" ^^ space ^^ doc_id_lem id) - (string "Register" ^^ space ^^ - align (separate space [string_lit(doc_id_lem id); - doc_int (size); - doc_int (start); - string o; - string "[]"])) - ^/^ hardline - | _ -> - let (Id_aux (Id name,_)) = id in - failwith ("can't deal with register " ^ name)) - | Typ_aux (Typ_app(r, [Typ_arg_aux (Typ_arg_typ (Typ_aux (Typ_id idt, _)), _)]), _) - when string_of_id r = "register" -> - separate space [string "let";doc_id_lem id;equals; - string "build_" ^^ string (string_of_id idt);string_lit (doc_id_lem id)] ^/^ hardline - | Typ_aux (Typ_id idt, _) -> - separate space [string "let";doc_id_lem id;equals; - string "build_" ^^ string (string_of_id idt);string_lit (doc_id_lem id)] ^/^ hardline - |_-> empty) + | Typ_aux (Typ_id idt, _) when Env.is_regtyp idt env -> + separate space [string "let";doc_id_lem id;equals; + string "build_" ^^ string (string_of_id idt);string_lit (doc_id_lem id)] ^/^ hardline + | _ -> + let rt = Env.base_typ_of env typ in + if is_vector_typ rt then + let (start, size, order, etyp) = vector_typ_args_of rt in + if is_bit_typ etyp && is_nexp_constant start && is_nexp_constant size then + let o = if is_order_inc order then "true" else "false" in + (doc_op equals) + (string "let" ^^ space ^^ doc_id_lem id) + (string "Register" ^^ space ^^ + align (separate space [string_lit(doc_id_lem id); + doc_nexp (size); + doc_nexp (start); + string o; + string "[]"])) + ^/^ hardline + else raise (Reporting_basic.err_unreachable l + ("can't deal with register type " ^ string_of_typ typ)) + else raise (Reporting_basic.err_unreachable l + ("can't deal with register type " ^ string_of_typ typ))) | DEC_alias(id,alspec) -> empty | DEC_typ_alias(typ,id,alspec) -> empty @@ -1306,7 +1375,7 @@ let rec doc_def_lem regtypes def = match def with | DEF_default df -> (empty,empty) | DEF_fundef f_def -> (empty,group (doc_fundef_lem regtypes f_def) ^/^ hardline) - | DEF_val lbind -> (empty,group (doc_let_lem regtypes lbind) ^/^ hardline) + | DEF_val lbind -> (empty,group (doc_let_lem regtypes false lbind) ^/^ hardline) | DEF_scattered sdef -> failwith "doc_def_lem: shoulnd't have DEF_scattered at this point" | DEF_kind _ -> (empty,empty) diff --git a/src/pretty_print_ocaml.ml b/src/pretty_print_ocaml.ml index 66252d94..fc02f568 100644 --- a/src/pretty_print_ocaml.ml +++ b/src/pretty_print_ocaml.ml @@ -741,28 +741,23 @@ let doc_dec_ocaml (DEC_aux (reg,(l,annot))) = | DEC_reg(typ,id) -> if is_vector_typ typ then let (start, size, order, itemt) = vector_typ_args_of typ in - (* (match annot with - | Base((_,t),_,_,_,_,_) -> - (match t.t with - | Tapp("register", [TA_typ {t= Tapp("vector", [TA_nexp start; TA_nexp size; TA_ord order; TA_typ itemt])}]) - | Tapp("register", [TA_typ {t= Tabbrev(_,{t=Tapp("vector", [TA_nexp start; TA_nexp size; TA_ord order; TA_typ itemt])})}]) -> *) - (match is_bit_typ itemt, start, size with - | true, Nexp_aux (Nexp_constant start, _), Nexp_aux (Nexp_constant size, _) -> - let o = if is_order_inc order then string "true" else string "false" in - separate space [string "let"; - doc_id_ocaml id; - equals; - string "Vregister"; - parens (separate comma [separate space [string "ref"; - parens (separate space - [string "Array.make"; - doc_int size; - string "Vzero";])]; - doc_int start; - o; - string_lit (doc_id id); - brackets empty])] - | _ -> empty) + if is_bit_typ itemt && is_nexp_constant start && is_nexp_constant size then + let o = if is_order_inc order then string "true" else string "false" in + separate space [string "let"; + doc_id_ocaml id; + equals; + string "Vregister"; + parens (separate comma [separate space [string "ref"; + parens (separate space + [string "Array.make"; + doc_nexp size; + string "Vzero";])]; + doc_nexp start; + o; + string_lit (doc_id id); + brackets empty])] + else raise (Reporting_basic.err_unreachable l + ("can't deal with register type " ^ string_of_typ typ)) else (match typ with | Typ_aux (Typ_id idt, _) -> @@ -773,7 +768,8 @@ let doc_dec_ocaml (DEC_aux (reg,(l,annot))) = equals; doc_id_ocaml idt; string "None"] - |_-> failwith "type was not handled in register declaration") + |_-> raise (Reporting_basic.err_unreachable l + ("can't deal with register type " ^ string_of_typ typ))) (* | _ -> failwith "annot was not Base") *) | DEC_alias(id,alspec) -> empty (* doc_op equals (string "register alias" ^^ space ^^ doc_id id) (doc_alias alspec) *) diff --git a/src/pretty_print_sail.ml b/src/pretty_print_sail.ml index aff3a976..a63df7ea 100644 --- a/src/pretty_print_sail.ml +++ b/src/pretty_print_sail.ml @@ -117,7 +117,7 @@ let doc_pat, doc_atomic_pat = | P_vector_indexed ipats -> brackets (separate_map comma_sp npat ipats) | P_tup pats -> parens (separate_map comma_sp atomic_pat pats) | P_list pats -> squarebarbars (separate_map semi_sp atomic_pat pats) - | P_cons (pat1, pat2) -> separate space [atomic_pat pat1; string "::"; pat pat2] + | P_cons (pat1, pat2) -> separate space [atomic_pat pat1; coloncolon; pat pat2] | P_app(_, _ :: _) | P_vector_concat _ -> group (parens (pat pa)) and fpat (FP_aux(FP_Fpat(id,fpat),_)) = doc_op equals (doc_id id) (pat fpat) @@ -155,7 +155,7 @@ let doc_exp, doc_let = | E_vector_append(l,r) -> doc_op colon (shift_exp l) (cons_exp r) | E_cons(l,r) -> - doc_op colon (shift_exp l) (cons_exp r) + doc_op coloncolon (shift_exp l) (cons_exp r) | _ -> shift_exp expr and shift_exp ((E_aux(e,_)) as expr) = match e with | E_app_infix(l,(Id_aux(Id (">>" | ">>>" | "<<" | "<<<"),_) as op),r) -> diff --git a/src/process_file.ml b/src/process_file.ml index 91262ce9..6d4bcea0 100644 --- a/src/process_file.ml +++ b/src/process_file.ml @@ -98,6 +98,7 @@ let load_file order env f = let opt_just_check = ref false let opt_ddump_tc_ast = ref false +let opt_ddump_rewrite_ast = ref None let opt_dno_cast = ref false let check_ast (defs : unit Ast.defs) : Type_check.tannot Ast.defs * Type_check.Env.t = @@ -112,10 +113,6 @@ let monomorphise_ast locs ast = let ienv = Type_check.Env.no_casts Type_check.initial_env in Type_check.check ienv ast -let rewrite_ast (defs: Type_check.tannot Ast.defs) = Rewriter.rewrite_defs defs -let rewrite_ast_lem (defs: Type_check.tannot Ast.defs) = Rewriter.rewrite_defs_lem defs -let rewrite_ast_ocaml (defs: Type_check.tannot Ast.defs) = Rewriter.rewrite_defs_ocaml defs - let open_output_with_check file_name = let (temp_file_name, o) = Filename.open_temp_file "ll_temp" "" in let o' = Format.formatter_of_out_channel o in @@ -235,3 +232,20 @@ let output libpath out_arg files = (fun (f, defs) -> output1 libpath out_arg f defs) files + +let rewrite_step defs rewriter = + let defs = rewriter defs in + let _ = match !(opt_ddump_rewrite_ast) with + | Some (f, i) -> + begin + output "" Lem_ast_out [f ^ "_rewrite_" ^ string_of_int i ^ ".sail",defs]; + opt_ddump_rewrite_ast := Some (f, i + 1) + end + | _ -> () in + defs + +let rewrite rewriters defs = List.fold_left rewrite_step defs rewriters + +let rewrite_ast = rewrite [Rewriter.rewrite_defs] +let rewrite_ast_lem = rewrite Rewriter.rewrite_defs_lem +let rewrite_ast_ocaml = rewrite Rewriter.rewrite_defs_ocaml diff --git a/src/process_file.mli b/src/process_file.mli index 7972c689..cd867b0d 100644 --- a/src/process_file.mli +++ b/src/process_file.mli @@ -54,6 +54,7 @@ val load_file : Ast.order -> Type_check.Env.t -> string -> Type_check.tannot Ast val opt_new_parser : bool ref val opt_just_check : bool ref val opt_ddump_tc_ast : bool ref +val opt_ddump_rewrite_ast : ((string * int) option) ref val opt_dno_cast : bool ref type out_type = diff --git a/src/rewriter.ml b/src/rewriter.ml index 8cf682bf..8da8aacf 100644 --- a/src/rewriter.ml +++ b/src/rewriter.ml @@ -66,7 +66,13 @@ let effect_of_fexp (FE_aux (_,(_,a))) = effect_of_annot a let effect_of_fexps (FES_aux (FES_Fexps (fexps,_),_)) = List.fold_left union_effects no_effect (List.map effect_of_fexp fexps) let effect_of_opt_default (Def_val_aux (_,(_,a))) = effect_of_annot a -let effect_of_pexp (Pat_aux (_,(_,a))) = effect_of_annot a +(* The typechecker does not seem to annotate pexps themselves *) +let effect_of_pexp (Pat_aux (pexp,(_,a))) = match a with + | Some (_, _, eff) -> eff + | None -> + (match pexp with + | Pat_exp (_, e) -> effect_of e + | Pat_when (_, g, e) -> union_effects (effect_of g) (effect_of e)) let effect_of_lb (LB_aux (_,(_,a))) = effect_of_annot a let get_loc_exp (E_aux (_,(l,_))) = l @@ -143,8 +149,8 @@ let fix_eff_exp (E_aux (e,((l,_) as annot))) = match snd annot with List.fold_left union_effects (effect_of e) (List.map effect_of_pexp pexps) | E_let (lb,e) -> union_effects (effect_of_lb lb) (effect_of e) | E_assign (lexp,e) -> union_effects (effect_of_lexp lexp) (effect_of e) - | E_exit e -> effect_of e - | E_return e -> effect_of e + | E_exit e -> union_effects eff (effect_of e) + | E_return e -> union_effects eff (effect_of e) | E_sizeof _ | E_sizeof_internal _ | E_constraint _ -> no_effect | E_assert (c,m) -> eff | E_comment _ | E_comment_struc _ -> no_effect @@ -1048,6 +1054,7 @@ let rewrite_sizeof (Defs defs) = Id_aux (Id op, Parse_ast.Unknown), E_aux (e_sizeof nmap nexp2, simple_annot l (atom_typ nexp2)) ) in + let (Nexp_aux (nexp, l) as nexp_aux) = simplify_nexp nexp_aux in (match nexp with | Nexp_constant i -> E_lit (L_aux (L_num i, l)) | Nexp_times (nexp1, nexp2) -> binop nexp1 "*" nexp2 @@ -1058,14 +1065,15 @@ let rewrite_sizeof (Defs defs) = (* Rewrite calls to functions which have had parameters added to pass values of type-level variables; these are added as sizeof expressions first, and then further rewritten as above. *) - let e_app_aux param_map (exp, ((l,_) as annot)) = + let e_app_aux param_map ((exp, exp_orig), ((l,_) as annot)) = let full_exp = E_aux (exp, annot) in + let orig_exp = E_aux (exp_orig, annot) in match exp with | E_app (f, args) -> if Bindings.mem f param_map then (* Retrieve instantiation of the type variables of the called function - for the given parameters in the current environment *) - let inst = instantiation_of full_exp in + for the given parameters in the original environment *) + let inst = instantiation_of orig_exp in let kid_exp kid = begin match KBindings.find kid inst with | U_nexp nexp -> E_aux (E_sizeof nexp, simple_annot l (atom_typ nexp)) @@ -1075,9 +1083,75 @@ let rewrite_sizeof (Defs defs) = " of function " ^ string_of_id f)) end in let kid_exps = List.map kid_exp (KidSet.elements (Bindings.find f param_map)) in - E_aux (E_app (f, kid_exps @ args), annot) - else full_exp - | _ -> full_exp in + (E_aux (E_app (f, kid_exps @ args), annot), orig_exp) + else (full_exp, orig_exp) + | _ -> (full_exp, orig_exp) in + + (* Plug this into a folding algorithm that also keeps around a copy of the + original expressions, which we use to infer instantiations of type variables + in the original environments *) + let copy_exp_alg = + { e_block = (fun es -> let (es, es') = List.split es in (E_block es, E_block es')) + ; e_nondet = (fun es -> let (es, es') = List.split es in (E_nondet es, E_nondet es')) + ; e_id = (fun id -> (E_id id, E_id id)) + ; e_lit = (fun lit -> (E_lit lit, E_lit lit)) + ; e_cast = (fun (typ,(e,e')) -> (E_cast (typ,e), E_cast (typ,e'))) + ; e_app = (fun (id,es) -> let (es, es') = List.split es in (E_app (id,es), E_app (id,es'))) + ; e_app_infix = (fun ((e1,e1'),id,(e2,e2')) -> (E_app_infix (e1,id,e2), E_app_infix (e1',id,e2'))) + ; e_tuple = (fun es -> let (es, es') = List.split es in (E_tuple es, E_tuple es')) + ; e_if = (fun ((e1,e1'),(e2,e2'),(e3,e3')) -> (E_if (e1,e2,e3), E_if (e1',e2',e3'))) + ; e_for = (fun (id,(e1,e1'),(e2,e2'),(e3,e3'),order,(e4,e4')) -> (E_for (id,e1,e2,e3,order,e4), E_for (id,e1',e2',e3',order,e4'))) + ; e_vector = (fun es -> let (es, es') = List.split es in (E_vector es, E_vector es')) + ; e_vector_indexed = (fun (es,(opt2,opt2')) -> let (is, es) = List.split es in let (es, es') = List.split es in let (es, es') = (List.combine is es, List.combine is es') in (E_vector_indexed (es,opt2), E_vector_indexed (es',opt2'))) + ; e_vector_access = (fun ((e1,e1'),(e2,e2')) -> (E_vector_access (e1,e2), E_vector_access (e1',e2'))) + ; e_vector_subrange = (fun ((e1,e1'),(e2,e2'),(e3,e3')) -> (E_vector_subrange (e1,e2,e3), E_vector_subrange (e1',e2',e3'))) + ; e_vector_update = (fun ((e1,e1'),(e2,e2'),(e3,e3')) -> (E_vector_update (e1,e2,e3), E_vector_update (e1',e2',e3'))) + ; e_vector_update_subrange = (fun ((e1,e1'),(e2,e2'),(e3,e3'),(e4,e4')) -> (E_vector_update_subrange (e1,e2,e3,e4), E_vector_update_subrange (e1',e2',e3',e4'))) + ; e_vector_append = (fun ((e1,e1'),(e2,e2')) -> (E_vector_append (e1,e2), E_vector_append (e1',e2'))) + ; e_list = (fun es -> let (es, es') = List.split es in (E_list es, E_list es')) + ; e_cons = (fun ((e1,e1'),(e2,e2')) -> (E_cons (e1,e2), E_cons (e1',e2'))) + ; e_record = (fun (fexps, fexps') -> (E_record fexps, E_record fexps')) + ; e_record_update = (fun ((e1,e1'),(fexp,fexp')) -> (E_record_update (e1,fexp), E_record_update (e1',fexp'))) + ; e_field = (fun ((e1,e1'),id) -> (E_field (e1,id), E_field (e1',id))) + ; e_case = (fun ((e1,e1'),pexps) -> let (pexps, pexps') = List.split pexps in (E_case (e1,pexps), E_case (e1',pexps'))) + ; e_let = (fun ((lb,lb'),(e2,e2')) -> (E_let (lb,e2), E_let (lb',e2'))) + ; e_assign = (fun ((lexp,lexp'),(e2,e2')) -> (E_assign (lexp,e2), E_assign (lexp',e2'))) + ; e_sizeof = (fun nexp -> (E_sizeof nexp, E_sizeof nexp)) + ; e_exit = (fun (e1,e1') -> (E_exit (e1), E_exit (e1'))) + ; e_return = (fun (e1,e1') -> (E_return e1, E_return e1')) + ; e_assert = (fun ((e1,e1'),(e2,e2')) -> (E_assert(e1,e2), E_assert(e1',e2')) ) + ; e_internal_cast = (fun (a,(e1,e1')) -> (E_internal_cast (a,e1), E_internal_cast (a,e1'))) + ; e_internal_exp = (fun a -> (E_internal_exp a, E_internal_exp a)) + ; e_internal_exp_user = (fun (a1,a2) -> (E_internal_exp_user (a1,a2), E_internal_exp_user (a1,a2))) + ; e_comment = (fun c -> (E_comment c, E_comment c)) + ; e_comment_struc = (fun (e,e') -> (E_comment_struc e, E_comment_struc e')) + ; e_internal_let = (fun ((lexp,lexp'), (e2,e2'), (e3,e3')) -> (E_internal_let (lexp,e2,e3), E_internal_let (lexp',e2',e3'))) + ; e_internal_plet = (fun (pat, (e1,e1'), (e2,e2')) -> (E_internal_plet (pat,e1,e2), E_internal_plet (pat,e1',e2'))) + ; e_internal_return = (fun (e,e') -> (E_internal_return e, E_internal_return e')) + ; e_aux = (fun ((e,e'),annot) -> (E_aux (e,annot), E_aux (e',annot))) + ; lEXP_id = (fun id -> (LEXP_id id, LEXP_id id)) + ; lEXP_memory = (fun (id,es) -> let (es, es') = List.split es in (LEXP_memory (id,es), LEXP_memory (id,es'))) + ; lEXP_cast = (fun (typ,id) -> (LEXP_cast (typ,id), LEXP_cast (typ,id))) + ; lEXP_tup = (fun tups -> let (tups,tups') = List.split tups in (LEXP_tup tups, LEXP_tup tups')) + ; lEXP_vector = (fun ((lexp,lexp'),(e2,e2')) -> (LEXP_vector (lexp,e2), LEXP_vector (lexp',e2'))) + ; lEXP_vector_range = (fun ((lexp,lexp'),(e2,e2'),(e3,e3')) -> (LEXP_vector_range (lexp,e2,e3), LEXP_vector_range (lexp',e2',e3'))) + ; lEXP_field = (fun ((lexp,lexp'),id) -> (LEXP_field (lexp,id), LEXP_field (lexp',id))) + ; lEXP_aux = (fun ((lexp,lexp'),annot) -> (LEXP_aux (lexp,annot), LEXP_aux (lexp',annot))) + ; fE_Fexp = (fun (id,(e,e')) -> (FE_Fexp (id,e), FE_Fexp (id,e'))) + ; fE_aux = (fun ((fexp,fexp'),annot) -> (FE_aux (fexp,annot), FE_aux (fexp',annot))) + ; fES_Fexps = (fun (fexps,b) -> let (fexps, fexps') = List.split fexps in (FES_Fexps (fexps,b), FES_Fexps (fexps',b))) + ; fES_aux = (fun ((fexp,fexp'),annot) -> (FES_aux (fexp,annot), FES_aux (fexp',annot))) + ; def_val_empty = (Def_val_empty, Def_val_empty) + ; def_val_dec = (fun (e,e') -> (Def_val_dec e, Def_val_dec e')) + ; def_val_aux = (fun ((defval,defval'),aux) -> (Def_val_aux (defval,aux), Def_val_aux (defval',aux))) + ; pat_exp = (fun (pat,(e,e')) -> (Pat_exp (pat,e), Pat_exp (pat,e'))) + ; pat_when = (fun (pat,(e1,e1'),(e2,e2')) -> (Pat_when (pat,e1,e2), Pat_when (pat,e1',e2'))) + ; pat_aux = (fun ((pexp,pexp'),a) -> (Pat_aux (pexp,a), Pat_aux (pexp',a))) + ; lB_val_explicit = (fun (typ,pat,(e,e')) -> (LB_val_explicit (typ,pat,e), LB_val_explicit (typ,pat,e'))) + ; lB_val_implicit = (fun (pat,(e,e')) -> (LB_val_implicit (pat,e), LB_val_implicit (pat,e'))) + ; lB_aux = (fun ((lb,lb'),annot) -> (LB_aux (lb,annot), LB_aux (lb',annot))) + ; pat_alg = id_pat_alg + } in let rewrite_sizeof_fun params_map (FD_aux (FD_function (rec_opt,tannot,eff,funcls),((l,_) as annot))) = @@ -1086,7 +1160,7 @@ let rewrite_sizeof (Defs defs) = let body_typ = typ_of exp in let nmap = nexps_from_params pat in (* first rewrite calls to other functions... *) - let exp' = fold_exp { id_exp_alg with e_aux = e_app_aux params_map } exp in + let exp' = fst (fold_exp { copy_exp_alg with e_aux = e_app_aux params_map } exp) in (* ... then rewrite sizeof expressions in current function body *) let exp'' = fold_exp { id_exp_alg with e_sizeof = e_sizeof nmap } exp' in (FCL_aux (FCL_Funcl (id,pat,exp''), annot) :: funcls, @@ -1133,9 +1207,10 @@ let rewrite_sizeof (Defs defs) = let rewrite_sizeof_fundef (params_map, defs) = function | DEF_fundef fd -> let (nvars, fd') = rewrite_sizeof_fun params_map fd in + let id = id_of_fundef fd in let params_map' = if KidSet.is_empty nvars then params_map - else Bindings.add (id_of_fundef fd) nvars params_map in + else Bindings.add id nvars params_map in (params_map', defs @ [DEF_fundef fd']) | def -> (params_map, defs @ [def]) in @@ -1947,7 +2022,7 @@ let rewrite_defs_remove_bitvector_pats (Defs defs) = let defvals = List.map (fun lb -> DEF_val lb) letbinds in [DEF_val (LB_aux (LB_val_implicit (pat',exp),a))] @ defvals | d -> [d] in - Defs (List.flatten (List.map rewrite_def defs)) + fst (check initial_env (Defs (List.flatten (List.map rewrite_def defs)))) (* Remove pattern guards by rewriting them to if-expressions within the @@ -1999,7 +2074,7 @@ let rewrite_exp_lift_assign_intro rewriters ((E_aux (exp,((l,_) as annot))) as f | (E_aux(E_assign((LEXP_aux ((LEXP_id id | LEXP_cast (_,id)),_)) as le,e), ((l, Some (env,typ,eff)) as annot)) as exp)::exps -> (match Env.lookup_id id env with - | Unbound -> + | Unbound | Local _ -> let le' = rewriters.rewrite_lexp rewriters le in let e' = rewrite_base e in let exps' = walker exps in @@ -2136,12 +2211,72 @@ let rewrite_defs_separate_numbs defs = rewrite_defs_base rewrite_def = rewrite_def; rewrite_defs = rewrite_defs_base} defs*) -let rewrite_defs_ocaml = - top_sort_defs >> - rewrite_defs_remove_vector_concat >> - rewrite_sizeof >> - rewrite_defs_exp_lift_assign (* >> +let rewrite_defs_early_return = + let is_return (E_aux (exp, _)) = match exp with + | E_return _ -> true + | _ -> false in + + let get_return (E_aux (e, (l, _)) as exp) = match e with + | E_return e -> e + | _ -> exp in + + let e_block es = + (* let rec walker = function + | e :: es -> if is_return e then [e] else e :: walker es + | [] -> [] in + let es = walker es in *) + match es with + | [E_aux (e, _)] -> e + | _ -> E_block es in + + let e_if (e1, e2, e3) = + if is_return e2 && is_return e3 then E_if (e1, get_return e2, get_return e3) + else E_if (e1, e2, e3) in + + let e_case (e, pes) = + let is_return_pexp (Pat_aux (pexp, _)) = match pexp with + | Pat_exp (_, e) | Pat_when (_, _, e) -> is_return e in + let get_return_pexp (Pat_aux (pexp, a)) = match pexp with + | Pat_exp (p, e) -> Pat_aux (Pat_exp (p, get_return e), a) + | Pat_when (p, g, e) -> Pat_aux (Pat_when (p, g, get_return e), a) in + if List.for_all is_return_pexp pes + then E_return (E_aux (E_case (e, List.map get_return_pexp pes), (Parse_ast.Unknown, None))) + else E_case (e, pes) in + + let e_aux (exp, (l, annot)) = + let full_exp = fix_eff_exp (E_aux (exp, (l, annot))) in + match annot with + | Some (env, typ, eff) when is_return full_exp -> + (* Add escape effect annotation, since we use the exception mechanism + of the state monad to implement early return in the Lem backend *) + let annot' = Some (env, typ, union_effects eff (mk_effect [BE_escape])) in + E_aux (exp, (l, annot')) + | _ -> full_exp in + + let rewrite_funcl_early_return _ (FCL_aux (FCL_Funcl (id, pat, exp), a)) = + let exp = fold_exp + { id_exp_alg with e_block = e_block; e_if = e_if; e_case = e_case; + e_aux = e_aux } exp in + let a = match a with + | (l, Some (env, typ, eff)) -> + (l, Some (env, typ, union_effects eff (effect_of exp))) + | _ -> a in + FCL_aux (FCL_Funcl (id, pat, get_return exp), a) in + + let rewrite_fun_early_return rewriters + (FD_aux (FD_function (rec_opt, tannot_opt, effect_opt, funcls), a)) = + FD_aux (FD_function (rec_opt, tannot_opt, effect_opt, + List.map (rewrite_funcl_early_return rewriters) funcls), a) in + + rewrite_defs_base { rewriters_base with rewrite_fun = rewrite_fun_early_return } + +let rewrite_defs_ocaml = [ + top_sort_defs; + rewrite_defs_remove_vector_concat; + rewrite_sizeof; + rewrite_defs_exp_lift_assign (* ; rewrite_defs_separate_numbs *) + ] let rewrite_defs_remove_blocks = let letbind_wild v body = @@ -2398,7 +2533,7 @@ let rewrite_defs_letbind_effects = | E_case (exp1,pexps) -> let newreturn = List.fold_left - (fun b (Pat_aux (_,(_,annot))) -> b || effectful_effs (effect_of_annot annot)) + (fun b pexp -> b || effectful_effs (effect_of_pexp pexp)) false pexps in n_exp_name exp1 (fun exp1 -> n_pexpL newreturn pexps (fun pexps -> @@ -2766,7 +2901,7 @@ let rec rewrite_var_updates ((E_aux (expaux,((l,_) as annot))) as exp) = | _ -> raise (Reporting_basic.err_unreachable l "assignment without effects annotation") in - if not (List.exists (function BE_aux (BE_lset,_) -> true | _ -> false) effs) then + if effectful exp then Same_vars (E_aux (E_assign (lexp,vexp),annot)) else (match lexp with @@ -2968,18 +3103,22 @@ let rewrite_defs_remove_e_assign = ; rewrite_defs = rewrite_defs_base } - -let rewrite_defs_lem = - top_sort_defs >> - rewrite_sizeof >> - rewrite_defs_remove_vector_concat >> - rewrite_defs_remove_bitvector_pats >> - rewrite_defs_guarded_pats >> - rewrite_defs_exp_lift_assign >> - rewrite_defs_remove_blocks >> - rewrite_defs_letbind_effects >> - rewrite_defs_remove_e_assign >> - rewrite_defs_effectful_let_expressions >> - rewrite_defs_remove_superfluous_letbinds >> +let recheck_defs defs = fst (check initial_env defs) + +let rewrite_defs_lem =[ + top_sort_defs; + rewrite_sizeof; + rewrite_defs_remove_vector_concat; + rewrite_defs_remove_bitvector_pats; + rewrite_defs_guarded_pats; + (* recheck_defs; *) + rewrite_defs_early_return; + rewrite_defs_exp_lift_assign; + rewrite_defs_remove_blocks; + rewrite_defs_letbind_effects; + rewrite_defs_remove_e_assign; + rewrite_defs_effectful_let_expressions; + rewrite_defs_remove_superfluous_letbinds; rewrite_defs_remove_superfluous_returns + ] diff --git a/src/rewriter.mli b/src/rewriter.mli index 473456f6..9dbdee3d 100644 --- a/src/rewriter.mli +++ b/src/rewriter.mli @@ -56,8 +56,8 @@ type 'a rewriters = { rewrite_exp : 'a rewriters -> 'a exp -> 'a exp; val rewrite_exp : tannot rewriters -> tannot exp -> tannot exp val rewrite_defs : tannot defs -> tannot defs -val rewrite_defs_ocaml : tannot defs -> tannot defs (*Perform rewrites to exclude AST nodes not supported for ocaml out*) -val rewrite_defs_lem : tannot defs -> tannot defs (*Perform rewrites to exclude AST nodes not supported for lem out*) +val rewrite_defs_ocaml : (tannot defs -> tannot defs) list (*Perform rewrites to exclude AST nodes not supported for ocaml out*) +val rewrite_defs_lem : (tannot defs -> tannot defs) list (*Perform rewrites to exclude AST nodes not supported for lem out*) (* the type of interpretations of pattern-matching expressions *) type ('a,'pat,'pat_aux,'fpat,'fpat_aux) pat_alg = @@ -154,3 +154,10 @@ val fold_exp : ('a,'exp,'exp_aux,'lexp,'lexp_aux,'fexp,'fexp_aux,'fexps,'fexps_a 'pat,'pat_aux,'fpat,'fpat_aux) exp_alg -> 'a exp -> 'exp val id_pat_alg : ('a,'a pat, 'a pat_aux, 'a fpat, 'a fpat_aux) pat_alg + +val compute_exp_alg : 'b -> ('b -> 'b -> 'b) -> + ('a,('b * 'a exp),('b * 'a exp_aux),('b * 'a lexp),('b * 'a lexp_aux),('b * 'a fexp), + ('b * 'a fexp_aux),('b * 'a fexps),('b * 'a fexps_aux), + ('b * 'a opt_default_aux),('b * 'a opt_default),('b * 'a pexp),('b * 'a pexp_aux), + ('b * 'a letbind_aux),('b * 'a letbind), + ('b * 'a pat),('b * 'a pat_aux),('b * 'a fpat),('b * 'a fpat_aux)) exp_alg diff --git a/src/sail.ml b/src/sail.ml index d03060dc..cf366d42 100644 --- a/src/sail.ml +++ b/src/sail.ml @@ -97,6 +97,9 @@ let options = Arg.align ([ ( "-ddump_tc_ast", Arg.Set opt_ddump_tc_ast, " (debug) dump the typechecked ast to stdout"); + ( "-ddump_rewrite_ast", + Arg.String (fun l -> opt_ddump_rewrite_ast := Some (l, 0)), + " <prefix> (debug) dump the ast after each rewriting step to <prefix>_<i>.lem"); ( "-dtc_verbose", Arg.Int (fun verbosity -> Type_check.opt_tc_debug := verbosity), " (debug) verbose typechecker output: 0 is silent"); diff --git a/src/type_check.ml b/src/type_check.ml index bd2db570..6186a431 100644 --- a/src/type_check.ml +++ b/src/type_check.ml @@ -157,6 +157,12 @@ let is_range (Typ_aux (typ_aux, _)) = when string_of_id f = "range" -> Some (n1, n2) | _ -> None +let is_list (Typ_aux (typ_aux, _)) = + match typ_aux with + | Typ_app (f, [Typ_arg_aux (Typ_arg_typ typ, _)]) + when string_of_id f = "list" -> Some typ + | _ -> None + let nconstant c = Nexp_aux (Nexp_constant c, Parse_ast.Unknown) let nminus n1 n2 = Nexp_aux (Nexp_minus (n1, n2), Parse_ast.Unknown) let nsum n1 n2 = Nexp_aux (Nexp_sum (n1, n2), Parse_ast.Unknown) @@ -406,6 +412,7 @@ module Env : sig val add_union_id : id -> typquant * typ -> t -> t val add_flow : id -> (typ -> typ) -> t -> t val get_flow : id -> t -> typ -> typ + val is_register : id -> t -> bool val get_register : id -> t -> typ val add_register : id -> typ -> t -> t val add_regtyp : id -> int -> int -> (index_range * id) list -> t -> t @@ -423,6 +430,9 @@ module Env : sig val get_typ_synonym : id -> t -> t -> typ_arg list -> typ val add_overloads : id -> id list -> t -> t val get_overloads : id -> t -> id list + val is_extern : id -> t -> bool + val add_extern : id -> string -> t -> t + val get_extern : id -> t -> string val get_default_order : t -> order val set_default_order_inc : t -> t val set_default_order_dec : t -> t @@ -453,6 +463,7 @@ end = struct enums : IdSet.t Bindings.t; records : (typquant * (typ * id) list) Bindings.t; accessors : (typquant * typ) Bindings.t; + externs : string Bindings.t; casts : id list; allow_casts : bool; constraints : n_constraint list; @@ -474,6 +485,7 @@ end = struct enums = Bindings.empty; records = Bindings.empty; accessors = Bindings.empty; + externs = Bindings.empty; casts = []; allow_casts = true; constraints = []; @@ -694,6 +706,9 @@ end = struct { env with flow = Bindings.add id (fun typ -> f (get_flow id env typ)) env.flow } end + let is_register id env = + Bindings.mem id env.registers + let get_register id env = try Bindings.find id env.registers with | Not_found -> typ_error (id_loc id) ("No register binding found for " ^ string_of_id id) @@ -706,6 +721,16 @@ end = struct typ_print ("Adding overloads for " ^ string_of_id id ^ " [" ^ string_of_list ", " string_of_id ids ^ "]"); { env with overloads = Bindings.add id ids env.overloads } + let is_extern id env = + Bindings.mem id env.externs + + let add_extern id ext env = + { env with externs = Bindings.add id ext env.externs } + + let get_extern id env = + try Bindings.find id env.externs with + | Not_found -> typ_error (id_loc id) ("No extern binding found for " ^ string_of_id id) + let get_casts env = env.casts let check_index_range cmp f t (BF_aux (ir, l)) = @@ -863,6 +888,19 @@ end = struct | Typ_arg_typ typ -> Typ_arg_aux (Typ_arg_typ (expand_synonyms env typ), l) | arg -> Typ_arg_aux (arg, l) + let get_default_order env = + match env.default_order with + | None -> typ_error Parse_ast.Unknown ("No default order has been set") + | Some ord -> ord + + let set_default_order o env = + match env.default_order with + | None -> { env with default_order = Some (Ord_aux (o, Parse_ast.Unknown)) } + | Some _ -> typ_error Parse_ast.Unknown ("Cannot change default order once already set") + + let set_default_order_inc = set_default_order Ord_inc + let set_default_order_dec = set_default_order Ord_dec + let base_typ_of env typ = let rec aux (Typ_aux (t,a)) = let rewrap t = Typ_aux (t,a) in @@ -876,6 +914,11 @@ end = struct aux rtyp | Typ_app (id, targs) -> rewrap (Typ_app (id, List.map aux_arg targs)) + | Typ_id id when is_regtyp id env -> + let base, top, ranges = get_regtyp id env in + let len = abs(top - base) + 1 in + vector_typ (nconstant base) (nconstant len) (get_default_order env) bit_typ + (* TODO registers with non-default order? non-bitvector registers? *) | t -> rewrap t and aux_arg (Typ_arg_aux (targ,a)) = let rewrap targ = Typ_arg_aux (targ,a) in @@ -884,19 +927,6 @@ end = struct | targ -> rewrap targ in aux (expand_synonyms env typ) - let get_default_order env = - match env.default_order with - | None -> typ_error Parse_ast.Unknown ("No default order has been set") - | Some ord -> ord - - let set_default_order o env = - match env.default_order with - | None -> { env with default_order = Some (Ord_aux (o, Parse_ast.Unknown)) } - | Some _ -> typ_error Parse_ast.Unknown ("Cannot change default order once already set") - - let set_default_order_inc = set_default_order Ord_inc - let set_default_order_dec = set_default_order Ord_dec - end @@ -1837,13 +1867,30 @@ let rec check_exp env (E_aux (exp_aux, (l, ())) as exp : unit exp) (Typ_aux (typ Pat_aux (Pat_when (tpat, checked_guard, crule check_exp env case typ), (l, None)) in annot_exp (E_case (inferred_exp, List.map (fun case -> check_case case typ) cases)) typ + | E_cons (x, xs), _ -> + begin + match is_list (Env.expand_synonyms env typ) with + | Some elem_typ -> + let checked_xs = crule check_exp env xs typ in + let checked_x = crule check_exp env x elem_typ in + annot_exp (E_cons (checked_x, checked_xs)) typ + | None -> typ_error l ("Cons " ^ string_of_exp exp ^ " must have list type, got " ^ string_of_typ typ) + end + | E_list xs, _ -> + begin + match is_list (Env.expand_synonyms env typ) with + | Some elem_typ -> + let checked_xs = List.map (fun x -> crule check_exp env x elem_typ) xs in + annot_exp (E_list checked_xs) typ + | None -> typ_error l ("List " ^ string_of_exp exp ^ " must have list type, got " ^ string_of_typ typ) + end | E_let (LB_aux (letbind, (let_loc, _)), exp), _ -> begin match letbind with | LB_val_explicit (typschm, pat, bind) -> assert false | LB_val_implicit (P_aux (P_typ (ptyp, _), _) as pat, bind) -> let checked_bind = crule check_exp env bind ptyp in - let tpat, env = bind_pat env pat (typ_of checked_bind) in + let tpat, env = bind_pat env pat ptyp in annot_exp (E_let (LB_aux (LB_val_implicit (tpat, checked_bind), (let_loc, None)), crule check_exp env exp typ)) typ | LB_val_implicit (pat, bind) -> let inferred_bind = irule infer_exp env bind in @@ -1897,7 +1944,8 @@ let rec check_exp env (E_aux (exp_aux, (l, ())) as exp : unit exp) (Typ_aux (typ | E_lit (L_aux (L_undef, _) as lit), _ -> annot_exp_effect (E_lit lit) typ (mk_effect [BE_undef]) (* This rule allows registers of type t to be passed by name with type register<t>*) - | E_id reg, Typ_app (id, [Typ_arg_aux (Typ_arg_typ typ, _)]) when string_of_id id = "register" -> + | E_id reg, Typ_app (id, [Typ_arg_aux (Typ_arg_typ typ, _)]) + when string_of_id id = "register" && Env.is_register reg env -> let rtyp = Env.get_register reg env in subtyp l env rtyp typ; annot_exp (E_id reg) typ (* CHECK: is this subtyp the correct way around? *) | E_id id, _ when is_union_id id env -> @@ -1975,7 +2023,7 @@ and type_coercion_unify env (E_aux (_, (l, _)) as annotated_exp) typ = end and bind_pat env (P_aux (pat_aux, (l, ())) as pat) (Typ_aux (typ_aux, _) as typ) = - typ_print ("Binding " ^ string_of_typ typ); + typ_print ("Binding " ^ string_of_pat pat ^ " to " ^ string_of_typ typ); let annot_pat pat typ = P_aux (pat, (l, Some (env, typ, no_effect))) in let switch_typ (P_aux (pat_aux, (l, Some (env, _, eff)))) typ = P_aux (pat_aux, (l, Some (env, typ, eff))) in let bind_tuple_pat (tpats, env) pat typ = @@ -2161,10 +2209,11 @@ and bind_assignment env (LEXP_aux (lexp_aux, _) as lexp) (E_aux (_, (l, ())) as end in let regtyp, inferred_flexp, is_register = infer_flexp flexp in - let eff = if is_register then mk_effect [BE_wreg] else no_effect in typ_debug ("REGTYP: " ^ string_of_typ regtyp ^ " / " ^ string_of_typ (Env.expand_synonyms env regtyp)); match Env.expand_synonyms env regtyp with + | Typ_aux (Typ_app (Id_aux (Id "register", _), [Typ_arg_aux (Typ_arg_typ (Typ_aux (Typ_id regtyp_id, _)), _)]), _) | Typ_aux (Typ_id regtyp_id, _) when Env.is_regtyp regtyp_id env -> + let eff = mk_effect [BE_wreg] in let base, top, ranges = Env.get_regtyp regtyp_id env in let range, _ = try List.find (fun (_, id) -> Id.compare id field = 0) ranges with @@ -2180,6 +2229,7 @@ and bind_assignment env (LEXP_aux (lexp_aux, _) as lexp) (E_aux (_, (l, ())) as let checked_exp = crule check_exp env exp vec_typ in annot_assign (annot_lexp (LEXP_field (annot_lexp_effect inferred_flexp regtyp eff, field)) vec_typ) checked_exp, env | Typ_aux (Typ_id rectyp_id, _) | Typ_aux (Typ_app (rectyp_id, _), _) when Env.is_record rectyp_id env -> + let eff = if is_register then mk_effect [BE_wreg] else no_effect in let (typq, Typ_aux (Typ_fn (rectyp_q, field_typ, _), _)) = Env.get_accessor rectyp_id field env in let unifiers, _, _ (* FIXME *) = try unify l env rectyp_q regtyp with Unification_error (l, m) -> typ_error l ("Unification error: " ^ m) in let field_typ' = subst_unifiers unifiers field_typ in @@ -2282,11 +2332,12 @@ and bind_lexp env (LEXP_aux (lexp_aux, (l, ())) as lexp) typ = (* Not sure about this case... can the left lexp be anything other than an identifier? *) | LEXP_vector (LEXP_aux (LEXP_id v, _), exp) -> begin - let is_immutable, vtyp = match Env.lookup_id v env with + let is_immutable, is_register, vtyp = match Env.lookup_id v env with | Unbound -> typ_error l "Cannot assign to element of unbound vector" | Enum _ -> typ_error l "Cannot vector assign to enumeration element" - | Local (Immutable, vtyp) -> true, vtyp - | Local (Mutable, vtyp) | Register vtyp -> false, vtyp + | Local (Immutable, vtyp) -> true, false, vtyp + | Local (Mutable, vtyp) -> false, false, vtyp + | Register vtyp -> false, true, vtyp in let access = infer_exp (Env.enable_casts env) (E_aux (E_app (mk_id "vector_access", [E_aux (E_id v, (l, ())); exp]), (l, ()))) in let E_aux (E_app (_, [_; inferred_exp]), _) = access in @@ -2294,6 +2345,9 @@ and bind_lexp env (LEXP_aux (lexp_aux, (l, ())) as lexp) typ = | Typ_aux (Typ_app (id, [Typ_arg_aux (Typ_arg_typ deref_typ, _)]), _) when string_of_id id = "register" -> subtyp l env typ deref_typ; annot_lexp (LEXP_vector (annot_lexp_effect (LEXP_id v) vtyp (mk_effect [BE_wreg]), inferred_exp)) typ, env + | _ when not is_immutable && is_register -> + subtyp l env typ (typ_of access); + annot_lexp (LEXP_vector (annot_lexp_effect (LEXP_id v) vtyp (mk_effect [BE_wreg]), inferred_exp)) typ, env | _ when not is_immutable -> subtyp l env typ (typ_of access); annot_lexp (LEXP_vector (annot_lexp (LEXP_id v) vtyp, inferred_exp)) typ, env @@ -2336,26 +2390,28 @@ and infer_exp env (E_aux (exp_aux, (l, ())) as exp) = let inferred_exp = irule infer_exp env exp in match Env.expand_synonyms env (typ_of inferred_exp) with (* Accessing a (bit) field of a register *) - | Typ_aux (Typ_id regtyp, _) when Env.is_regtyp regtyp env -> + | Typ_aux (Typ_app (Id_aux (Id "register", _), [Typ_arg_aux (Typ_arg_typ ((Typ_aux (Typ_id regtyp, _) as regtyp_aux)), _)]), _) + | (Typ_aux (Typ_id regtyp, _) as regtyp_aux) when Env.is_regtyp regtyp env -> let base, top, ranges = Env.get_regtyp regtyp env in let range, _ = try List.find (fun (_, id) -> Id.compare id field = 0) ranges with | Not_found -> typ_error l ("Field " ^ string_of_id field ^ " doesn't exist for register type " ^ string_of_id regtyp) in + let checked_exp = crule check_exp env (strip_exp inferred_exp) regtyp_aux in begin match range, Env.get_default_order env with | BF_aux (BF_single n, _), Ord_aux (Ord_dec, _) -> let vec_typ = dvector_typ env (nconstant n) (nconstant 1) bit_typ in - annot_exp (E_field (inferred_exp, field)) vec_typ + annot_exp (E_field (checked_exp, field)) vec_typ | BF_aux (BF_range (n, m), _), Ord_aux (Ord_dec, _) -> let vec_typ = dvector_typ env (nconstant n) (nconstant (n - m + 1)) bit_typ in - annot_exp (E_field (inferred_exp, field)) vec_typ + annot_exp (E_field (checked_exp, field)) vec_typ | BF_aux (BF_single n, _), Ord_aux (Ord_inc, _) -> let vec_typ = dvector_typ env (nconstant n) (nconstant 1) bit_typ in - annot_exp (E_field (inferred_exp, field)) vec_typ + annot_exp (E_field (checked_exp, field)) vec_typ | BF_aux (BF_range (n, m), _), Ord_aux (Ord_inc, _) -> let vec_typ = dvector_typ env (nconstant n) (nconstant (m - n + 1)) bit_typ in - annot_exp (E_field (inferred_exp, field)) vec_typ + annot_exp (E_field (checked_exp, field)) vec_typ | _, _ -> typ_error l "Invalid register field type" end (* Accessing a field of a record *) @@ -2674,6 +2730,13 @@ and propagate_exp_effect_aux = function let p_lb, eff = propagate_letbind_effect letbind in let p_exp = propagate_exp_effect exp in E_let (p_lb, p_exp), union_effects (effect_of p_exp) eff + | E_cons (x, xs) -> + let p_x = propagate_exp_effect x in + let p_xs = propagate_exp_effect xs in + E_cons (p_x, p_xs), union_effects (effect_of p_x) (effect_of p_xs) + | E_list xs -> + let p_xs = List.map propagate_exp_effect xs in + E_list p_xs, collect_effects p_xs | E_assign (lexp, exp) -> let p_lexp = propagate_lexp_effect lexp in let p_exp = propagate_exp_effect exp in @@ -2908,8 +2971,12 @@ let check_val_spec env (VS_aux (vs, (l, _))) = let (id, quants, typ, env) = match vs with | VS_val_spec (TypSchm_aux (TypSchm_ts (quants, typ), _), id) -> (id, quants, typ, env) | VS_cast_spec (TypSchm_aux (TypSchm_ts (quants, typ), _), id) -> (id, quants, typ, Env.add_cast id env) - | VS_extern_no_rename (TypSchm_aux (TypSchm_ts (quants, typ), _), id) -> (id, quants, typ, env) - | VS_extern_spec (TypSchm_aux (TypSchm_ts (quants, typ), _), id, _) -> (id, quants, typ, env) in + | VS_extern_no_rename (TypSchm_aux (TypSchm_ts (quants, typ), _), id) -> + let env = Env.add_extern id (string_of_id id) env in + (id, quants, typ, env) + | VS_extern_spec (TypSchm_aux (TypSchm_ts (quants, typ), _), id, ext) -> + let env = Env.add_extern id ext env in + (id, quants, typ, env) in [DEF_spec (VS_aux (vs, (l, None)))], Env.add_val_spec id (quants, typ) env let check_default env (DT_aux (ds, l)) = @@ -2957,7 +3024,11 @@ let check_type_union env variant typq (Tu_aux (tu, l)) = let ret_typ = app_typ variant (List.fold_left fold_union_quant [] (quant_items typq)) in match tu with | Tu_id v -> Env.add_union_id v (typq, ret_typ) env - | Tu_ty_id (typ, v) -> Env.add_val_spec v (typq, mk_typ (Typ_fn (typ, ret_typ, no_effect))) env + | Tu_ty_id (typ, v) -> + let typ' = mk_typ (Typ_fn (typ, ret_typ, no_effect)) in + env + |> Env.add_union_id v (typq, typ') + |> Env.add_val_spec v (typq, typ') let mk_synonym typq typ = let kopts, ncs = quant_split typq in @@ -3012,7 +3083,8 @@ let rec check_def env def = | DEF_default default -> check_default env default | DEF_overload (id, ids) -> [DEF_overload (id, ids)], Env.add_overloads id ids env | DEF_reg_dec (DEC_aux (DEC_reg (typ, id), (l, _))) -> - [DEF_reg_dec (DEC_aux (DEC_reg (typ, id), (l, None)))], Env.add_register id typ env + let env = Env.add_register id typ env in + [DEF_reg_dec (DEC_aux (DEC_reg (typ, id), (l, Some (env, typ, no_effect))))], env | DEF_reg_dec (DEC_aux (DEC_alias (id, aspec), (l, annot))) -> cd_err () | DEF_reg_dec (DEC_aux (DEC_typ_alias (typ, id, aspec), (l, tannot))) -> cd_err () | DEF_scattered _ -> raise (Reporting_basic.err_unreachable Parse_ast.Unknown "Scattered given to type checker") diff --git a/src/type_check.mli b/src/type_check.mli index de825d06..0a85624c 100644 --- a/src/type_check.mli +++ b/src/type_check.mli @@ -101,6 +101,10 @@ module Env : sig val get_overloads : id -> t -> id list + val is_extern : id -> t -> bool + + val get_extern : id -> t -> string + (* Lookup id searchs for a specified id in the environment, and returns it's type and what kind of identifier it is, using the lvar type. Returns Unbound if the identifier is unbound, and diff --git a/test/typecheck/pass/add_vec_lit.sail b/test/typecheck/pass/add_vec_lit.sail index be897021..4d662a8d 100644 --- a/test/typecheck/pass/add_vec_lit.sail +++ b/test/typecheck/pass/add_vec_lit.sail @@ -3,7 +3,7 @@ default Order inc val extern forall Num 'n. (bit['n], bit['n]) -> bit['n] effect pure add_vec = "add_vec" val extern forall Num 'n, Num 'm, Num 'o, Num 'p. ([|'n:'m|], [|'o:'p|]) -> [|'n+'o:'m+'p|] effect pure add_range = "add_range" -val cast forall Num 'n. bit['n] -> [|0: 2** 'n - 1|] effect pure cast_vec_range +val cast forall Num 'n. bit['n] -> [|0: 2** 'n - 1|] effect pure unsigned overload (deinfix +) [add_vec; add_range] diff --git a/test/typecheck/pass/arm_FPEXC1.sail b/test/typecheck/pass/arm_FPEXC1.sail index cfae86a1..f711a5ad 100644 --- a/test/typecheck/pass/arm_FPEXC1.sail +++ b/test/typecheck/pass/arm_FPEXC1.sail @@ -1,9 +1,9 @@ default Order dec -val forall Num 'n. (bit['n], int) -> bit effect pure vector_access +val extern forall Num 'n. (bit['n], int) -> bit effect pure vector_access = "bitvector_access_dec" -val forall Num 'n, Num 'm, Num 'o, 'm >= 'o, 'o >= 0, 'n >= 'm + 1. - (bit['n], [:'m:], [:'o:]) -> bit['m - ('o - 1)] effect pure vector_subrange +val extern forall Num 'n, Num 'm, Num 'o, 'm >= 'o, 'o >= 0, 'n >= 'm + 1. + (bit['n], [:'m:], [:'o:]) -> bit['m - ('o - 1)] effect pure vector_subrange = "bitvector_subrange_dec" register vector<32 - 1, 32, dec, bit> _FPEXC32_EL2 diff --git a/test/typecheck/pass/bv_simple_index.sail b/test/typecheck/pass/bv_simple_index.sail index 72e1b094..811b3a5b 100644 --- a/test/typecheck/pass/bv_simple_index.sail +++ b/test/typecheck/pass/bv_simple_index.sail @@ -1,7 +1,7 @@ -val forall Nat 'n, Nat 'l, Type 'a, 'l >= 0. (vector<'n,'l,dec,'a>, [|'n - 'l + 1:'n|]) -> 'a effect pure vector_access_dec -val forall Nat 'n, Nat 'l, Type 'a, 'l >= 0. (vector<'n,'l,inc,'a>, [|'n:'n + 'l - 1|]) -> 'a effect pure vector_access_inc +val forall Nat 'n, Nat 'l, 'l >= 0. (vector<'n,'l,dec,bit>, [|'n - 'l + 1:'n|]) -> bit effect pure bitvector_access_dec +val forall Nat 'n, Nat 'l, 'l >= 0. (vector<'n,'l,inc,bit>, [|'n:'n + 'l - 1|]) -> bit effect pure bitvector_access_inc -overload vector_access [vector_access_inc; vector_access_dec] +overload vector_access [bitvector_access_inc; bitvector_access_dec] val cast bit -> bool effect pure cast_bit_bool diff --git a/test/typecheck/pass/bv_simple_index_bit.sail b/test/typecheck/pass/bv_simple_index_bit.sail index 2ba5b928..46bc19d6 100644 --- a/test/typecheck/pass/bv_simple_index_bit.sail +++ b/test/typecheck/pass/bv_simple_index_bit.sail @@ -1,7 +1,7 @@ -val forall Nat 'n, Nat 'l, Type 'a, 'l >= 0. (vector<'n,'l,dec,'a>, [|'n - 'l + 1:'n|]) -> 'a effect pure vector_access_dec -val forall Nat 'n, Nat 'l, Type 'a, 'l >= 0. (vector<'n,'l,inc,'a>, [|'n:'n + 'l - 1|]) -> 'a effect pure vector_access_inc +val forall Nat 'n, Nat 'l, 'l >= 0. (vector<'n,'l,dec,bit>, [|'n - 'l + 1:'n|]) -> bit effect pure bitvector_access_dec +val forall Nat 'n, Nat 'l, 'l >= 0. (vector<'n,'l,inc,bit>, [|'n:'n + 'l - 1|]) -> bit effect pure bitvector_access_inc -overload vector_access [vector_access_inc; vector_access_dec] +overload vector_access [bitvector_access_inc; bitvector_access_dec] function bit bv ((bit[64]) x) = { diff --git a/test/typecheck/pass/case_simple_constraints.sail b/test/typecheck/pass/case_simple_constraints.sail index f1b87235..335e10ee 100644 --- a/test/typecheck/pass/case_simple_constraints.sail +++ b/test/typecheck/pass/case_simple_constraints.sail @@ -1,9 +1,9 @@ -val forall Nat 'n, Nat 'm. ([:'n + 20:], [:'m:]) -> [:'n + 20 + 'm:] effect pure plus +val extern forall Nat 'n, Nat 'm. ([:'n + 20:], [:'m:]) -> [:'n + 20 + 'm:] effect pure plus = "add" -val forall Nat 'n, 'n <= -10. [:'n:] -> [:'n:] effect pure minus_ten_id +val extern forall Nat 'n, 'n <= -10. [:'n:] -> [:'n:] effect pure minus_ten_id = "id" -val forall Nat 'n, 'n >= 10. [:'n:] -> [:'n:] effect pure ten_id +val extern forall Nat 'n, 'n >= 10. [:'n:] -> [:'n:] effect pure ten_id = "id" val forall Nat 'N, 'N >= 63. [|10:'N|] -> [|10:'N|] effect pure branch diff --git a/test/typecheck/pass/deinfix_plus.sail b/test/typecheck/pass/deinfix_plus.sail index c5a0f1ee..8fc7c00e 100644 --- a/test/typecheck/pass/deinfix_plus.sail +++ b/test/typecheck/pass/deinfix_plus.sail @@ -1,6 +1,6 @@ default Order inc -val extern forall Num 'n. (bit['n], bit['n]) -> bit['n] effect pure bv_add = "bv_add_inc" +val extern forall Num 'n. (bit['n], bit['n]) -> bit['n] effect pure bv_add = "add_vec" overload (deinfix +) [bv_add] diff --git a/test/typecheck/pass/flow_gt1.sail b/test/typecheck/pass/flow_gt1.sail index acfbab68..ddeefd53 100644 --- a/test/typecheck/pass/flow_gt1.sail +++ b/test/typecheck/pass/flow_gt1.sail @@ -1,17 +1,17 @@ default Order inc -val forall Num 'n, Num 'm, Num 'o, Num 'p. ([|'n:'m|], [|'o:'p|]) -> [|'n+'o:'m+'p|] effect pure add_range +val extern forall Num 'n, Num 'm, Num 'o, Num 'p. ([|'n:'m|], [|'o:'p|]) -> [|'n+'o:'m+'p|] effect pure add_range = "add" -val forall Num 'n, Num 'm, Num 'o, Num 'p. ([|'n:'m|], [|'o:'p|]) -> [|'n-'p:'m-'o|] effect pure sub_range +val extern forall Num 'n, Num 'm, Num 'o, Num 'p. ([|'n:'m|], [|'o:'p|]) -> [|'n-'p:'m-'o|] effect pure sub_range = "sub" -val forall Num 'n, Num 'm, Num 'o. ([|'n:'m|], [:'o:]) -> bool effect pure lt_range_atom -val forall Num 'n, Num 'm, Num 'o. ([|'n:'m|], [:'o:]) -> bool effect pure lteq_range_atom -val forall Num 'n, Num 'm, Num 'o. ([|'n:'m|], [:'o:]) -> bool effect pure gt_range_atom -val forall Num 'n, Num 'm, Num 'o. ([|'n:'m|], [:'o:]) -> bool effect pure gteq_range_atom -val forall Num 'n, Num 'm, Num 'o. ([:'n:], [|'m:'o|]) -> bool effect pure lt_atom_range -val forall Num 'n, Num 'm, Num 'o. ([:'n:], [|'m:'o|]) -> bool effect pure lteq_atom_range -val forall Num 'n, Num 'm, Num 'o. ([:'n:], [|'m:'o|]) -> bool effect pure gt_atom_range -val forall Num 'n, Num 'm, Num 'o. ([:'n:], [|'m:'o|]) -> bool effect pure gteq_atom_range +val extern forall Num 'n, Num 'm, Num 'o. ([|'n:'m|], [:'o:]) -> bool effect pure lt_range_atom = "lt" +val extern forall Num 'n, Num 'm, Num 'o. ([|'n:'m|], [:'o:]) -> bool effect pure lteq_range_atom = "lteq" +val extern forall Num 'n, Num 'm, Num 'o. ([|'n:'m|], [:'o:]) -> bool effect pure gt_range_atom = "gt" +val extern forall Num 'n, Num 'm, Num 'o. ([|'n:'m|], [:'o:]) -> bool effect pure gteq_range_atom = "gteq" +val extern forall Num 'n, Num 'm, Num 'o. ([:'n:], [|'m:'o|]) -> bool effect pure lt_atom_range = "lt" +val extern forall Num 'n, Num 'm, Num 'o. ([:'n:], [|'m:'o|]) -> bool effect pure lteq_atom_range = "lteq" +val extern forall Num 'n, Num 'm, Num 'o. ([:'n:], [|'m:'o|]) -> bool effect pure gt_atom_range = "gt" +val extern forall Num 'n, Num 'm, Num 'o. ([:'n:], [|'m:'o|]) -> bool effect pure gteq_atom_range = "gteq" overload (deinfix +) [add_range] overload (deinfix -) [sub_range] diff --git a/test/typecheck/pass/flow_gteq1.sail b/test/typecheck/pass/flow_gteq1.sail index 8918438c..47f7aa0f 100644 --- a/test/typecheck/pass/flow_gteq1.sail +++ b/test/typecheck/pass/flow_gteq1.sail @@ -1,17 +1,17 @@ default Order inc -val forall Num 'n, Num 'm, Num 'o, Num 'p. ([|'n:'m|], [|'o:'p|]) -> [|'n+'o:'m+'p|] effect pure add_range +val extern forall Num 'n, Num 'm, Num 'o, Num 'p. ([|'n:'m|], [|'o:'p|]) -> [|'n+'o:'m+'p|] effect pure add_range = "add" -val forall Num 'n, Num 'm, Num 'o, Num 'p. ([|'n:'m|], [|'o:'p|]) -> [|'n-'p:'m-'o|] effect pure sub_range +val extern forall Num 'n, Num 'm, Num 'o, Num 'p. ([|'n:'m|], [|'o:'p|]) -> [|'n-'p:'m-'o|] effect pure sub_range = "sub" -val forall Num 'n, Num 'm, Num 'o. ([|'n:'m|], [:'o:]) -> bool effect pure lt_range_atom -val forall Num 'n, Num 'm, Num 'o. ([|'n:'m|], [:'o:]) -> bool effect pure lteq_range_atom -val forall Num 'n, Num 'm, Num 'o. ([|'n:'m|], [:'o:]) -> bool effect pure gt_range_atom -val forall Num 'n, Num 'm, Num 'o. ([|'n:'m|], [:'o:]) -> bool effect pure gteq_range_atom -val forall Num 'n, Num 'm, Num 'o. ([:'n:], [|'m:'o|]) -> bool effect pure lt_atom_range -val forall Num 'n, Num 'm, Num 'o. ([:'n:], [|'m:'o|]) -> bool effect pure lteq_atom_range -val forall Num 'n, Num 'm, Num 'o. ([:'n:], [|'m:'o|]) -> bool effect pure gt_atom_range -val forall Num 'n, Num 'm, Num 'o. ([:'n:], [|'m:'o|]) -> bool effect pure gteq_atom_range +val extern forall Num 'n, Num 'm, Num 'o. ([|'n:'m|], [:'o:]) -> bool effect pure lt_range_atom = "lt" +val extern forall Num 'n, Num 'm, Num 'o. ([|'n:'m|], [:'o:]) -> bool effect pure lteq_range_atom = "lteq" +val extern forall Num 'n, Num 'm, Num 'o. ([|'n:'m|], [:'o:]) -> bool effect pure gt_range_atom = "gt" +val extern forall Num 'n, Num 'm, Num 'o. ([|'n:'m|], [:'o:]) -> bool effect pure gteq_range_atom = "gteq" +val extern forall Num 'n, Num 'm, Num 'o. ([:'n:], [|'m:'o|]) -> bool effect pure lt_atom_range = "lt" +val extern forall Num 'n, Num 'm, Num 'o. ([:'n:], [|'m:'o|]) -> bool effect pure lteq_atom_range = "lteq" +val extern forall Num 'n, Num 'm, Num 'o. ([:'n:], [|'m:'o|]) -> bool effect pure gt_atom_range = "gt" +val extern forall Num 'n, Num 'm, Num 'o. ([:'n:], [|'m:'o|]) -> bool effect pure gteq_atom_range = "gteq" overload (deinfix +) [add_range] overload (deinfix -) [sub_range] diff --git a/test/typecheck/pass/flow_lt1.sail b/test/typecheck/pass/flow_lt1.sail index 0f3c1bbc..c210ed7a 100644 --- a/test/typecheck/pass/flow_lt1.sail +++ b/test/typecheck/pass/flow_lt1.sail @@ -1,11 +1,17 @@ default Order inc -val forall Num 'n, Num 'm, Num 'o, Num 'p. ([|'n:'m|], [|'o:'p|]) -> [|'n+'o:'m+'p|] effect pure add_range +val extern forall Num 'n, Num 'm, Num 'o, Num 'p. ([|'n:'m|], [|'o:'p|]) -> [|'n+'o:'m+'p|] effect pure add_range = "add" -val forall Num 'n, Num 'm, Num 'o, Num 'p. ([|'n:'m|], [|'o:'p|]) -> [|'n-'p:'m-'o|] effect pure sub_range +val extern forall Num 'n, Num 'm, Num 'o, Num 'p. ([|'n:'m|], [|'o:'p|]) -> [|'n-'p:'m-'o|] effect pure sub_range = "sub" -val forall Num 'n, Num 'm, Num 'o. ([|'n:'m|], [:'o:]) -> bool effect pure lt_range_atom -val forall Num 'n, Num 'm, Num 'o. ([:'n:], [|'m:'o|]) -> bool effect pure lt_atom_range +val extern forall Num 'n, Num 'm, Num 'o. ([|'n:'m|], [:'o:]) -> bool effect pure lt_range_atom = "lt" +val extern forall Num 'n, Num 'm, Num 'o. ([|'n:'m|], [:'o:]) -> bool effect pure lteq_range_atom = "lteq" +val extern forall Num 'n, Num 'm, Num 'o. ([|'n:'m|], [:'o:]) -> bool effect pure gt_range_atom = "gt" +val extern forall Num 'n, Num 'm, Num 'o. ([|'n:'m|], [:'o:]) -> bool effect pure gteq_range_atom = "gteq" +val extern forall Num 'n, Num 'm, Num 'o. ([:'n:], [|'m:'o|]) -> bool effect pure lt_atom_range = "lt" +val extern forall Num 'n, Num 'm, Num 'o. ([:'n:], [|'m:'o|]) -> bool effect pure lteq_atom_range = "lteq" +val extern forall Num 'n, Num 'm, Num 'o. ([:'n:], [|'m:'o|]) -> bool effect pure gt_atom_range = "gt" +val extern forall Num 'n, Num 'm, Num 'o. ([:'n:], [|'m:'o|]) -> bool effect pure gteq_atom_range = "gteq" overload (deinfix +) [add_range] overload (deinfix -) [sub_range] diff --git a/test/typecheck/pass/flow_lt2.sail b/test/typecheck/pass/flow_lt2.sail index effe0bc4..cccebaa3 100644 --- a/test/typecheck/pass/flow_lt2.sail +++ b/test/typecheck/pass/flow_lt2.sail @@ -1,11 +1,17 @@ default Order inc -val forall Num 'n, Num 'm, Num 'o, Num 'p. ([|'n:'m|], [|'o:'p|]) -> [|'n+'o:'m+'p|] effect pure add_range +val extern forall Num 'n, Num 'm, Num 'o, Num 'p. ([|'n:'m|], [|'o:'p|]) -> [|'n+'o:'m+'p|] effect pure add_range = "add" -val forall Num 'n, Num 'm, Num 'o, Num 'p. ([|'n:'m|], [|'o:'p|]) -> [|'n-'p:'m-'o|] effect pure sub_range +val extern forall Num 'n, Num 'm, Num 'o, Num 'p. ([|'n:'m|], [|'o:'p|]) -> [|'n-'p:'m-'o|] effect pure sub_range = "sub" -val forall Num 'n, Num 'm, Num 'o. ([|'n:'m|], [:'o:]) -> bool effect pure lt_range_atom -val forall Num 'n, Num 'm, Num 'o. ([:'n:], [|'m:'o|]) -> bool effect pure lt_atom_range +val extern forall Num 'n, Num 'm, Num 'o. ([|'n:'m|], [:'o:]) -> bool effect pure lt_range_atom = "lt" +val extern forall Num 'n, Num 'm, Num 'o. ([|'n:'m|], [:'o:]) -> bool effect pure lteq_range_atom = "lteq" +val extern forall Num 'n, Num 'm, Num 'o. ([|'n:'m|], [:'o:]) -> bool effect pure gt_range_atom = "gt" +val extern forall Num 'n, Num 'm, Num 'o. ([|'n:'m|], [:'o:]) -> bool effect pure gteq_range_atom = "gteq" +val extern forall Num 'n, Num 'm, Num 'o. ([:'n:], [|'m:'o|]) -> bool effect pure lt_atom_range = "lt" +val extern forall Num 'n, Num 'm, Num 'o. ([:'n:], [|'m:'o|]) -> bool effect pure lteq_atom_range = "lteq" +val extern forall Num 'n, Num 'm, Num 'o. ([:'n:], [|'m:'o|]) -> bool effect pure gt_atom_range = "gt" +val extern forall Num 'n, Num 'm, Num 'o. ([:'n:], [|'m:'o|]) -> bool effect pure gteq_atom_range = "gteq" overload (deinfix +) [add_range] overload (deinfix -) [sub_range] diff --git a/test/typecheck/pass/flow_lt_assign.sail b/test/typecheck/pass/flow_lt_assign.sail index 4e787741..9601f48f 100644 --- a/test/typecheck/pass/flow_lt_assign.sail +++ b/test/typecheck/pass/flow_lt_assign.sail @@ -1,11 +1,17 @@ default Order inc -val forall Num 'n, Num 'm, Num 'o, Num 'p. ([|'n:'m|], [|'o:'p|]) -> [|'n+'o:'m+'p|] effect pure add_range +val extern forall Num 'n, Num 'm, Num 'o, Num 'p. ([|'n:'m|], [|'o:'p|]) -> [|'n+'o:'m+'p|] effect pure add_range = "add" -val forall Num 'n, Num 'm, Num 'o, Num 'p. ([|'n:'m|], [|'o:'p|]) -> [|'n-'p:'m-'o|] effect pure sub_range +val extern forall Num 'n, Num 'm, Num 'o, Num 'p. ([|'n:'m|], [|'o:'p|]) -> [|'n-'p:'m-'o|] effect pure sub_range = "sub" -val forall Num 'n, Num 'm, Num 'o. ([|'n:'m|], [:'o:]) -> bool effect pure lt_range_atom -val forall Num 'n, Num 'm, Num 'o. ([:'n:], [|'m:'o|]) -> bool effect pure lt_atom_range +val extern forall Num 'n, Num 'm, Num 'o. ([|'n:'m|], [:'o:]) -> bool effect pure lt_range_atom = "lt" +val extern forall Num 'n, Num 'm, Num 'o. ([|'n:'m|], [:'o:]) -> bool effect pure lteq_range_atom = "lteq" +val extern forall Num 'n, Num 'm, Num 'o. ([|'n:'m|], [:'o:]) -> bool effect pure gt_range_atom = "gt" +val extern forall Num 'n, Num 'm, Num 'o. ([|'n:'m|], [:'o:]) -> bool effect pure gteq_range_atom = "gteq" +val extern forall Num 'n, Num 'm, Num 'o. ([:'n:], [|'m:'o|]) -> bool effect pure lt_atom_range = "lt" +val extern forall Num 'n, Num 'm, Num 'o. ([:'n:], [|'m:'o|]) -> bool effect pure lteq_atom_range = "lteq" +val extern forall Num 'n, Num 'm, Num 'o. ([:'n:], [|'m:'o|]) -> bool effect pure gt_atom_range = "gt" +val extern forall Num 'n, Num 'm, Num 'o. ([:'n:], [|'m:'o|]) -> bool effect pure gteq_atom_range = "gteq" overload (deinfix +) [add_range] overload (deinfix -) [sub_range] diff --git a/test/typecheck/pass/flow_lteq1.sail b/test/typecheck/pass/flow_lteq1.sail index d32831a2..ffa4dd8b 100644 --- a/test/typecheck/pass/flow_lteq1.sail +++ b/test/typecheck/pass/flow_lteq1.sail @@ -1,17 +1,17 @@ default Order inc -val forall Num 'n, Num 'm, Num 'o, Num 'p. ([|'n:'m|], [|'o:'p|]) -> [|'n+'o:'m+'p|] effect pure add_range +val extern forall Num 'n, Num 'm, Num 'o, Num 'p. ([|'n:'m|], [|'o:'p|]) -> [|'n+'o:'m+'p|] effect pure add_range = "add" -val forall Num 'n, Num 'm, Num 'o, Num 'p. ([|'n:'m|], [|'o:'p|]) -> [|'n-'p:'m-'o|] effect pure sub_range +val extern forall Num 'n, Num 'm, Num 'o, Num 'p. ([|'n:'m|], [|'o:'p|]) -> [|'n-'p:'m-'o|] effect pure sub_range = "sub" -val forall Num 'n, Num 'm, Num 'o. ([|'n:'m|], [:'o:]) -> bool effect pure lt_range_atom -val forall Num 'n, Num 'm, Num 'o. ([|'n:'m|], [:'o:]) -> bool effect pure lteq_range_atom -val forall Num 'n, Num 'm, Num 'o. ([|'n:'m|], [:'o:]) -> bool effect pure gt_range_atom -val forall Num 'n, Num 'm, Num 'o. ([|'n:'m|], [:'o:]) -> bool effect pure gteq_range_atom -val forall Num 'n, Num 'm, Num 'o. ([:'n:], [|'m:'o|]) -> bool effect pure lt_atom_range -val forall Num 'n, Num 'm, Num 'o. ([:'n:], [|'m:'o|]) -> bool effect pure lteq_atom_range -val forall Num 'n, Num 'm, Num 'o. ([:'n:], [|'m:'o|]) -> bool effect pure gt_atom_range -val forall Num 'n, Num 'm, Num 'o. ([:'n:], [|'m:'o|]) -> bool effect pure gteq_atom_range +val extern forall Num 'n, Num 'm, Num 'o. ([|'n:'m|], [:'o:]) -> bool effect pure lt_range_atom = "lt" +val extern forall Num 'n, Num 'm, Num 'o. ([|'n:'m|], [:'o:]) -> bool effect pure lteq_range_atom = "lteq" +val extern forall Num 'n, Num 'm, Num 'o. ([|'n:'m|], [:'o:]) -> bool effect pure gt_range_atom = "gt" +val extern forall Num 'n, Num 'm, Num 'o. ([|'n:'m|], [:'o:]) -> bool effect pure gteq_range_atom = "gteq" +val extern forall Num 'n, Num 'm, Num 'o. ([:'n:], [|'m:'o|]) -> bool effect pure lt_atom_range = "lt" +val extern forall Num 'n, Num 'm, Num 'o. ([:'n:], [|'m:'o|]) -> bool effect pure lteq_atom_range = "lteq" +val extern forall Num 'n, Num 'm, Num 'o. ([:'n:], [|'m:'o|]) -> bool effect pure gt_atom_range = "gt" +val extern forall Num 'n, Num 'm, Num 'o. ([:'n:], [|'m:'o|]) -> bool effect pure gteq_atom_range = "gteq" overload (deinfix +) [add_range] overload (deinfix -) [sub_range] diff --git a/test/typecheck/pass/let_subtyp_bug.sail b/test/typecheck/pass/let_subtyp_bug.sail new file mode 100644 index 00000000..e2abde2d --- /dev/null +++ b/test/typecheck/pass/let_subtyp_bug.sail @@ -0,0 +1,9 @@ +let ([|5|]) y = 2 + +val unit -> nat effect pure test + +function test() = { + let ([|5|]) x = 2 in + x +} +
\ No newline at end of file diff --git a/test/typecheck/pass/list_cons.sail b/test/typecheck/pass/list_cons.sail new file mode 100644 index 00000000..6f103bf6 --- /dev/null +++ b/test/typecheck/pass/list_cons.sail @@ -0,0 +1 @@ +function list<int> foo ((int) i, (list<int>) l) = i :: l diff --git a/test/typecheck/pass/list_cons2.sail b/test/typecheck/pass/list_cons2.sail new file mode 100644 index 00000000..8c34282b --- /dev/null +++ b/test/typecheck/pass/list_cons2.sail @@ -0,0 +1,7 @@ +function list<int> foo ((int) i, (list<int>) l) = i :: l + +function list<int> bar () = [||||] + +function list<int> baz ((list<int>) l) = l + +function list<int> quux () = baz ([||||]) diff --git a/test/typecheck/pass/list_lit.sail b/test/typecheck/pass/list_lit.sail new file mode 100644 index 00000000..d4febadf --- /dev/null +++ b/test/typecheck/pass/list_lit.sail @@ -0,0 +1,2 @@ + +let (list<int>) xs = [||1,2,3,4,5,6||] diff --git a/test/typecheck/pass/mips_CP0Cause_BD_assign1.sail b/test/typecheck/pass/mips_CP0Cause_BD_assign1.sail index 4dc63e71..7808b2c0 100644 --- a/test/typecheck/pass/mips_CP0Cause_BD_assign1.sail +++ b/test/typecheck/pass/mips_CP0Cause_BD_assign1.sail @@ -1,5 +1,5 @@ -val cast forall Nat 'n, Order 'ord. [:1:] -> vector<'n,1,'ord,bit> effect pure cast_one_bv -val cast forall Nat 'n, Order 'ord. [:0:] -> vector<'n,1,'ord,bit> effect pure cast_zero_bv +val cast forall Nat 'n, Order 'ord. [:1:] -> vector<'n,1,'ord,bit> effect pure cast_1_vec +val cast forall Nat 'n, Order 'ord. [:0:] -> vector<'n,1,'ord,bit> effect pure cast_0_vec default Order dec diff --git a/test/typecheck/pass/mips_CP0Cause_BD_assign2.sail b/test/typecheck/pass/mips_CP0Cause_BD_assign2.sail index b35a0767..26f161e2 100644 --- a/test/typecheck/pass/mips_CP0Cause_BD_assign2.sail +++ b/test/typecheck/pass/mips_CP0Cause_BD_assign2.sail @@ -1,5 +1,5 @@ -val cast forall Nat 'n, Order 'ord. [:1:] -> vector<'n,1,'ord,bit> effect pure cast_one_bv -val cast forall Nat 'n, Order 'ord. [:0:] -> vector<'n,1,'ord,bit> effect pure cast_zero_bv +val cast forall Nat 'n, Order 'ord. [:1:] -> vector<'n,1,'ord,bit> effect pure cast_1_vec +val cast forall Nat 'n, Order 'ord. [:0:] -> vector<'n,1,'ord,bit> effect pure cast_0_vec default Order dec diff --git a/test/typecheck/pass/mips_CP0Cause_access.sail b/test/typecheck/pass/mips_CP0Cause_access.sail index c0e318c4..eb3b9389 100644 --- a/test/typecheck/pass/mips_CP0Cause_access.sail +++ b/test/typecheck/pass/mips_CP0Cause_access.sail @@ -3,10 +3,10 @@ effect pure ADJUST *) -val forall Num 'n, Num 'l, Type 'a, 'l >= 0. (vector<'n,'l,dec,'a>, [|'n - 'l + 1:'n|]) -> 'a effect pure vector_access_dec -val forall Nat 'n, Nat 'l, Type 'a, 'l >= 0. (vector<'n,'l,inc,'a>, [|'n:'n + 'l - 1|]) -> 'a effect pure vector_access_inc +val extern forall Num 'n, Num 'l, 'l >= 0. (vector<'n,'l,dec,bit>, [|'n - 'l + 1:'n|]) -> bit effect pure bitvector_access_dec +val extern forall Nat 'n, Nat 'l, 'l >= 0. (vector<'n,'l,inc,bit>, [|'n:'n + 'l - 1|]) -> bit effect pure bitvector_access_inc -overload vector_access [vector_access_inc; vector_access_dec] +overload vector_access [bitvector_access_inc; bitvector_access_dec] default Order dec diff --git a/test/typecheck/pass/mips_reg_field_bit.sail b/test/typecheck/pass/mips_reg_field_bit.sail index 33560bde..4c37a6e9 100644 --- a/test/typecheck/pass/mips_reg_field_bit.sail +++ b/test/typecheck/pass/mips_reg_field_bit.sail @@ -1,8 +1,8 @@ +default Order dec + val cast forall Nat 'n, Nat 'm, Nat 'o, 'o >= 'm - 1. vector<'n,'m,dec,bit> -> vector<'o,'m,dec,bit> - effect pure ADJUST - -default Order dec + effect pure adjust_dec typedef CauseReg = register bits [ 31 : 0 ] { 31 : BD; (* branch delay *) diff --git a/test/typecheck/pass/mips_reg_field_bv.sail b/test/typecheck/pass/mips_reg_field_bv.sail index 4b82d4de..0ce19b4f 100644 --- a/test/typecheck/pass/mips_reg_field_bv.sail +++ b/test/typecheck/pass/mips_reg_field_bv.sail @@ -1,8 +1,8 @@ +default Order dec + val cast forall Nat 'n, Nat 'm, Nat 'o, 'o >= 'm - 1. vector<'n,'m,dec,bit> -> vector<'o,'m,dec,bit> - effect pure ADJUST - -default Order dec + effect pure adjust_dec typedef CauseReg = register bits [ 31 : 0 ] { 31 : BD; (* branch delay *) diff --git a/test/typecheck/pass/overload_plus.sail b/test/typecheck/pass/overload_plus.sail index 5390a5a4..2aa8ecc5 100644 --- a/test/typecheck/pass/overload_plus.sail +++ b/test/typecheck/pass/overload_plus.sail @@ -1,6 +1,6 @@ default Order inc -val extern forall Nat 'n. (bit['n], bit['n]) -> bit['n] effect pure bv_add = "bv_add_inc" +val extern forall Nat 'n. (bit['n], bit['n]) -> bit['n] effect pure bv_add = "add_vec" overload (deinfix +) [bv_add] diff --git a/test/typecheck/pass/regtyp_vec.sail b/test/typecheck/pass/regtyp_vec.sail index c939cce8..28978882 100644 --- a/test/typecheck/pass/regtyp_vec.sail +++ b/test/typecheck/pass/regtyp_vec.sail @@ -3,12 +3,12 @@ effect pure ADJUST *) -val forall Nat 'n, Nat 'l, Type 'a, 'l >= 0. (vector<'n,'l,dec,'a>, [|'n - 'l + 1:'n|]) -> 'a effect pure vector_access_dec +val forall Nat 'n, Nat 'l, 'l >= 0. (vector<'n,'l,dec,bit>, [|'n - 'l + 1:'n|]) -> bit effect pure bitvector_access_dec (* val forall Nat 'n, Nat 'l, Type 'a, 'l >= 0. (vector<'n,'l,inc,'a>, [|'n:'n + 'l - 1|]) -> 'a effect pure vector_access_inc *) -overload vector_access [vector_access_dec] +overload vector_access [bitvector_access_dec] default Order dec diff --git a/test/typecheck/pass/set_mark.sail b/test/typecheck/pass/set_mark.sail index 59710c46..7bc7370b 100644 --- a/test/typecheck/pass/set_mark.sail +++ b/test/typecheck/pass/set_mark.sail @@ -1,5 +1,5 @@ -val cast forall Num 'n, Num 'm, Order 'ord. [:0:] -> vector<'n,'m,'ord,bit> effect pure cast_zero_bv +val cast forall Num 'n, Num 'm, Order 'ord. [:0:] -> vector<'n,'m,'ord,bit> effect pure cast_0_vec function forall Num 'N, 'N IN {32}. bit['N] Foo32( (bit['N]) x) = x diff --git a/test/typecheck/pass/set_mark2.sail b/test/typecheck/pass/set_mark2.sail index c1433058..cabfb1af 100644 --- a/test/typecheck/pass/set_mark2.sail +++ b/test/typecheck/pass/set_mark2.sail @@ -1,4 +1,4 @@ -val cast forall Num 'n, Num 'm, Order 'ord. [:0:] -> vector<'n,'m,'ord,bit> effect pure cast_zero_bv +val cast forall Num 'n, Num 'm, Order 'ord. [:0:] -> vector<'n,'m,'ord,bit> effect pure cast_0_vec function forall Nat 'N, 'N IN {32, 64}. bit['N] Foo32( (bit['N]) x) = x diff --git a/test/typecheck/pass/vec_pat1.sail b/test/typecheck/pass/vec_pat1.sail index 0a79d701..fe9b4a0a 100644 --- a/test/typecheck/pass/vec_pat1.sail +++ b/test/typecheck/pass/vec_pat1.sail @@ -1,13 +1,16 @@ default Order inc -val extern forall Num 'n. (bit['n], bit['n]) -> bit['n] effect pure bv_add = "bv_add_inc" +val extern forall Num 'n. (bit['n], bit['n]) -> bit['n] effect pure bv_add = "add_vec" -val forall Num 'n, Num 'm, Num 'o, Num 'p, Type 'a. - (vector<'n,'m,inc,'a>, vector<'o,'p,inc,'a>) -> vector<'n,'m + 'p,inc,'a> - effect pure vector_append_inc +val extern forall Num 'n, Num 'l, Num 'm, Num 'o, 'l >= 0, 'm <= 'o, 'o <= 'l. + (vector<'n,'l,inc,bit>, [:'m:], [:'o:]) -> vector<'m,'o + 1 - 'm,inc,bit> effect pure vector_subrange = "bitvector_subrange_inc" + +val forall Num 'n, Num 'm, Num 'o, Num 'p. + (vector<'n,'m,inc,bit>, vector<'o,'p,inc,bit>) -> vector<'n,'m + 'p,inc,bit> + effect pure bitvector_concat overload (deinfix +) [bv_add] -overload vector_append [vector_append_inc] +overload vector_append [bitvector_concat] val (bit[3], bit[3]) -> bit[3] effect pure test diff --git a/test/typecheck/pass/vector_append.sail b/test/typecheck/pass/vector_append.sail index af83c44d..17db3fbd 100644 --- a/test/typecheck/pass/vector_append.sail +++ b/test/typecheck/pass/vector_append.sail @@ -1,7 +1,6 @@ - -val forall Nat 'n1, Nat 'l1, Nat 'n2, Nat 'l2, Order 'o, Type 'a, 'l1 >= 0, 'l2 >= 0. - (vector<'n1,'l1,'o,'a>, vector<'n2,'l2,'o,'a>) -> vector<'n1,'l1 + 'l2,'o,'a> effect pure vector_append +val extern forall Nat 'n1, Nat 'l1, Nat 'n2, Nat 'l2, Order 'o, 'l1 >= 0, 'l2 >= 0. + (vector<'n1,'l1,'o,bit>, vector<'n2,'l2,'o,bit>) -> vector<'n1,'l1 + 'l2,'o,bit> effect pure vector_append = "bitvector_concat" default Order inc @@ -12,4 +11,4 @@ function bit[8] test (v1, v2) = zv := vector_append(v1, v2); zv := v1 : v2; zv -}
\ No newline at end of file +} diff --git a/test/typecheck/pass/vector_append_gen.sail b/test/typecheck/pass/vector_append_gen.sail index ddb027ee..ce63ed87 100644 --- a/test/typecheck/pass/vector_append_gen.sail +++ b/test/typecheck/pass/vector_append_gen.sail @@ -1,8 +1,6 @@ -val forall Nat 'n, Nat 'l, Order 'o, Type 'a, 'l >= 0. (vector<'n,'l,'o,'a>, [|'n:'n + 'l - 1|]) -> 'a effect pure vector_access - -val forall Nat 'n1, Nat 'l1, Nat 'n2, Nat 'l2, Order 'o, Type 'a, 'l1 >= 0, 'l2 >= 0. - (vector<'n1,'l1,'o,'a>, vector<'n2,'l2,'o,'a>) -> vector<'n1,'l1 + 'l2,'o,'a> effect pure vector_append +val extern forall Nat 'n1, Nat 'l1, Nat 'n2, Nat 'l2, Order 'o, 'l1 >= 0, 'l2 >= 0. + (vector<'n1,'l1,'o,bit>, vector<'n2,'l2,'o,bit>) -> vector<'n1,'l1 + 'l2,'o,bit> effect pure vector_append = "bitvector_concat" default Order inc @@ -11,4 +9,4 @@ val forall 'n, 'm, 'n >= 0, 'm >= 0. (bit['n], bit['m]) -> bit['n + 'm] effect p function forall 'n, 'm. bit['n + 'm] test (v1, v2) = { vector_append(v1, v2); -}
\ No newline at end of file +} diff --git a/test/typecheck/pass/vector_subrange_gen.sail b/test/typecheck/pass/vector_subrange_gen.sail index 8857bd18..4ec067de 100644 --- a/test/typecheck/pass/vector_subrange_gen.sail +++ b/test/typecheck/pass/vector_subrange_gen.sail @@ -4,17 +4,18 @@ val forall Nat 'n, Nat 'l, Order 'o, Type 'a, 'l >= 0. (vector<'n,'l,'o,'a>, [|' val forall Nat 'n1, Nat 'l1, Nat 'n2, Nat 'l2, Order 'o, Type 'a, 'l1 >= 0, 'l2 >= 0. (vector<'n1,'l1,'o,'a>, vector<'n2,'l2,'o,'a>) -> vector<'n1,'l1 + 'l2,'o,'a> effect pure vector_append -val forall Nat 'n, Nat 'l, Nat 'm, Nat 'u, Order 'o, Type 'a, 'l >= 0, 'm <= 'u, 'u <= 'l. (vector<'n,'l,'o,'a>, [:'m:], [:'u:]) -> vector<'m,'u - 'm,'o,'a> effect pure vector_subrange +val extern forall Num 'n, Num 'l, Num 'm, Num 'o, 'l >= 0, 'm <= 'o, 'o <= 'l. + (vector<'n,'l,inc,bit>, [:'m:], [:'o:]) -> vector<'m,('o - 'm) + 1,inc,bit> effect pure vector_subrange = "bitvector_subrange_inc" -val forall Nat 'n, Nat 'm. ([:'n:], [:'m:]) -> [:'n - 'm:] effect pure minus +val forall Nat 'n, Nat 'm. ([:'n:], [:'m:]) -> [:'n - 'm:] effect pure sub default Order inc -val forall 'n, 'm, 'n >= 5. bit['n] -> bit['n - 2] effect pure test +val forall 'n, 'm, 'n >= 5. bit['n] -> bit['n - 1] effect pure test -function forall 'n, 'n >= 5. bit['n - 2] test v = +function forall 'n, 'n >= 5. bit['n - 1] test v = { - z := vector_subrange(v, 0, minus(sizeof 'n, 2)); - z := v[0 .. minus(sizeof 'n, 2)]; + z := vector_subrange(v, 0, sub(sizeof 'n, 2)); + z := v[0 .. sub(sizeof 'n, 2)]; z -}
\ No newline at end of file +} diff --git a/test/typecheck/pass/vector_synonym_cast.sail b/test/typecheck/pass/vector_synonym_cast.sail index 72a7e9d0..f1de42e9 100644 --- a/test/typecheck/pass/vector_synonym_cast.sail +++ b/test/typecheck/pass/vector_synonym_cast.sail @@ -1,7 +1,7 @@ typedef vecsyn = vector<0,1,dec,bit> -val cast vector<1,1,dec,bit> -> vector<0,1,dec,bit> effect pure test_cast +val cast vector<1,1,dec,bit> -> vector<0,1,dec,bit> effect pure adjust_dec val vector<1,1,dec,bit> -> vecsyn effect pure test diff --git a/test/typecheck/run_tests.sh b/test/typecheck/run_tests.sh index 8659e60e..073a6251 100755 --- a/test/typecheck/run_tests.sh +++ b/test/typecheck/run_tests.sh @@ -22,6 +22,7 @@ cat $SAILDIR/lib/prelude.sail $MIPS/mips_prelude.sail > $DIR/pass/mips_prelude.s cat $SAILDIR/lib/prelude.sail $MIPS/mips_prelude.sail $MIPS/mips_tlb.sail > $DIR/pass/mips_tlb.sail cat $SAILDIR/lib/prelude.sail $MIPS/mips_prelude.sail $MIPS/mips_tlb.sail $MIPS/mips_wrappers.sail > $DIR/pass/mips_wrappers.sail cat $SAILDIR/lib/prelude.sail $MIPS/mips_prelude.sail $MIPS/mips_tlb.sail $MIPS/mips_wrappers.sail $MIPS/mips_insts.sail $MIPS/mips_epilogue.sail > $DIR/pass/mips_insts.sail +cat $SAILDIR/lib/prelude.sail $MIPS/mips_prelude.sail $MIPS/mips_tlb_stub.sail $MIPS/mips_wrappers.sail $MIPS/mips_insts.sail $MIPS/mips_epilogue.sail > $DIR/pass/mips_notlb.sail pass=0 fail=0 @@ -99,14 +100,20 @@ finish_suite "Expecting fail" function test_lem { for i in `ls $DIR/pass/`; do - if $SAILDIR/sail -lem $DIR/$1/$i 2> /dev/null + # MIPS requires an additional library, Mips_extras_embed. + # It might be useful to allow adding options for specific test cases. + # For now, include the library for all test cases, which doesn't seem to hurt. + if $SAILDIR/sail -lem -lem_lib Mips_extras_embed $DIR/$1/$i 2> /dev/null then green "generated lem for $1/$i" "pass" + cp $MIPS/mips_extras_embed_sequential.lem $DIR/lem/ mv $SAILDIR/${i%%.*}_embed_types.lem $DIR/lem/ mv $SAILDIR/${i%%.*}_embed.lem $DIR/lem/ mv $SAILDIR/${i%%.*}_embed_sequential.lem $DIR/lem/ - if lem -lib $SAILDIR/src/lem_interp -lib $SAILDIR/src/gen_lib/ $DIR/lem/${i%%.*}_embed_types.lem $DIR/lem/${i%%.*}_embed.lem 2> /dev/null + # Test sequential embedding for now + # TODO: Add tests for the free monad + if lem -lib $SAILDIR/src/lem_interp -lib $SAILDIR/src/gen_lib/ $DIR/lem/mips_extras_embed_sequential.lem $DIR/lem/${i%%.*}_embed_types.lem $DIR/lem/${i%%.*}_embed_sequential.lem 2> /dev/null then green "typechecking lem for $1/$i" "pass" else @@ -133,6 +140,8 @@ function test_ocaml { if $SAILDIR/sail -ocaml $DIR/$1/$i 2> /dev/null then green "generated ocaml for $1/$i" "pass" + + rm $SAILDIR/${i%%.*}.ml else red "generated ocaml for $1/$i" "fail" fi diff --git a/x86/x64.sail b/x86/x64.sail new file mode 100644 index 00000000..a5e0710c --- /dev/null +++ b/x86/x64.sail @@ -0,0 +1,1227 @@ +(*========================================================================*) +(* *) +(* This software was developed by the University of Cambridge Computer *) +(* Laboratory as part of the Rigorous Engineering of Mainstream Systems *) +(* (REMS) project, funded by EPSRC grant EP/K008528/1. *) +(* *) +(* Redistribution and use in source and binary forms, with or without *) +(* modification, are permitted provided that the following conditions *) +(* are met: *) +(* 1. Redistributions of source code must retain the above copyright *) +(* notice, this list of conditions and the following disclaimer. *) +(* 2. Redistributions in binary form must reproduce the above copyright *) +(* notice, this list of conditions and the following disclaimer in *) +(* the documentation and/or other materials provided with the *) +(* distribution. *) +(* *) +(* THIS SOFTWARE IS PROVIDED BY THE AUTHOR AND CONTRIBUTORS ``AS IS'' *) +(* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED *) +(* TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A *) +(* PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE AUTHOR OR *) +(* CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, *) +(* SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT *) +(* LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF *) +(* USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND *) +(* ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, *) +(* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT *) +(* OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF *) +(* SUCH DAMAGE. *) +(*========================================================================*) + +default Order dec + +val extern forall Type 'a. ('a, list<'a>) -> bool effect pure ismember +val extern forall Type 'a. list<'a> -> nat effect pure listlength +val extern forall Nat 'n. (bit['n],[|'n|]) -> bit['n] effect pure ASR +val extern forall Nat 'n. (bit['n],[|'n|]) -> bit['n] effect pure LSR +val extern forall Nat 'n. (bit['n],[|'n|]) -> bit['n] effect pure ROR +val extern forall Nat 'n. (bit['n],[|'n|]) -> bit['n] effect pure ROL +val cast bool -> bit effect pure cast_bool_bit +val cast bit -> int effect pure cast_bit_int +val extern forall Num 'n. int -> bit['n] effect pure cast_int_vec +val extern forall 'n, 'm, 'o, 'n <= 0, 'm <= 'o. [|'n:'m|] -> [|0:'o|] effect pure negative_to_zero + +typedef byte = bit[8] +typedef qword = bit[64] +typedef regn = [|15|] +typedef byte_stream = list<byte> +typedef ostream = option<byte_stream> + +(* -------------------------------------------------------------------------- + Registers + -------------------------------------------------------------------------- *) + +(* Program Counter *) + +register qword RIP + +(* General purpose registers *) + +register qword RAX (* 0 *) +register qword RCX (* 1 *) +register qword RDX (* 2 *) +register qword RBX (* 3 *) +register qword RSP (* 4 *) +register qword RBP (* 5 *) +register qword RSI (* 6 *) +register qword RDI (* 7 *) +register qword R8 +register qword R9 +register qword R10 +register qword R11 +register qword R12 +register qword R13 +register qword R14 +register qword R15 + +let (vector<0,16,inc,(register<qword>)>) REG = + [RAX,RCX,RDX,RBX,RSP,RBP,RSI,RDI,R8,R9,R10,R11,R12,R13,R14,R15] + +(* Flags *) + +register bit CF +register bit PF +register bit AF +register bit ZF +register bit SF +register bit OF + +(* -------------------------------------------------------------------------- + Memory + -------------------------------------------------------------------------- *) + +val extern forall Nat 'n. (qword, [|'n|]) -> (bit[8 * 'n]) effect { rmem } MEM + +val extern forall Nat 'n. (qword, [|'n|], bit[8 * 'n]) -> unit effect { wmem } wMEM + +(* -------------------------------------------------------------------------- + Helper functions + -------------------------------------------------------------------------- *) + +(* Instruction addressing modes *) + +typedef size = const union { + bool Sz8; + unit Sz16; + unit Sz32; + unit Sz64; +} + +typedef base = const union { + unit NoBase; + unit RipBase; + regn RegBase; +} + +typedef scale_index = (bit[2],regn) + +typedef rm = const union { + regn Reg; + (option<scale_index>,base,qword) Mem; +} + +typedef dest_src = const union { + (rm,qword) Rm_i; + (rm,regn) Rm_r; + (regn,rm) R_rm; +} + +typedef imm_rm = const union { + rm Rm; + qword Imm; +} + +typedef monop_name = enumerate { Dec; Inc; Not; Neg } + +typedef binop_name = enumerate { + Add; Or; Adc; Sbb; And; Sub; Xor; Cmp; Rol; Ror; Rcl; Rcr; Shl; Shr; Test; Sar +} + +function binop_name opc_to_binop_name ((bit[4]) opc) = + switch opc + { + case 0x0 -> Add + case 0x1 -> Or + case 0x2 -> Adc + case 0x3 -> Sbb + case 0x4 -> And + case 0x5 -> Sub + case 0x6 -> Xor + case 0x7 -> Cmp + case 0x8 -> Rol + case 0x9 -> Ror + case 0xa -> Rcl + case 0xb -> Rcr + case 0xc -> Shl + case 0xd -> Shr + case 0xe -> Test + case 0xf -> Sar + } + +typedef cond = enumerate { + O; NO; B; NB; E; NE; NA; A; S; NS; P; NP; L; NL; NG; G; ALWAYS +} + +function cond bv_to_cond ((bit[4]) v) = + switch v + { + case 0x0 -> O + case 0x1 -> NO + case 0x2 -> B + case 0x3 -> NB + case 0x4 -> E + case 0x5 -> NE + case 0x6 -> NA + case 0x7 -> A + case 0x8 -> S + case 0x9 -> NS + case 0xa -> P + case 0xb -> NP + case 0xc -> L + case 0xd -> NL + case 0xe -> NG + case 0xf -> G + } + +(* Effective addresses *) + +typedef ea = const union { + (size,qword) Ea_i; + (size,regn) Ea_r; + (size,qword) Ea_m; +} + +function qword ea_index ((option<scale_index>) index) = + switch (index) { + case None -> 0x0000000000000000 + case (Some(scale, idx)) -> + let x = (qword) (0x0000000000000001 << scale) in + let y = (qword) (REG[idx]) in + let z = (bit[128]) (x * y) in + z[63 .. 0] + } + +function qword ea_base ((base) b) = + switch b { + case NoBase -> 0x0000000000000000 + case RipBase -> RIP + case (RegBase(b)) -> REG[b] + } + +function ea ea_rm ((size) sz, (rm) r) = + switch r { + case (Reg(n)) -> Ea_r (sz, n) + case (Mem(idx, b, d)) -> Ea_m (sz, ea_index(idx) + (qword) (ea_base(b) + d)) + } + +function ea ea_dest ((size) sz, (dest_src) ds) = + switch ds { + case (Rm_i (v, _)) -> ea_rm (sz, v) + case (Rm_r (v, _)) -> ea_rm (sz, v) + case (R_rm (v, _)) -> Ea_r (sz, v) + } + +function ea ea_src ((size) sz, (dest_src) ds) = + switch ds { + case (Rm_i (_, v)) -> Ea_i (sz, v) + case (Rm_r (_, v)) -> Ea_r (sz, v) + case (R_rm (_, v)) -> ea_rm (sz, v) + } + +function ea ea_imm_rm ((size) sz, (imm_rm) i_rm) = + switch i_rm { + case (Rm (v)) -> ea_rm (sz, v) + case (Imm (v)) -> Ea_i (sz, v) + } + +function qword restrict_size ((size) sz, (qword) imm) = + switch sz { + case (Sz8(_)) -> imm & 0x00000000000000FF + case Sz16 -> imm & 0x000000000000FFFF + case Sz32 -> imm & 0x00000000FFFFFFFF + case Sz64 -> imm + } + +function regn sub4 ((regn) r) = negative_to_zero (r - 4) + +function qword effect { rreg, rmem } EA ((ea) e) = + switch e { + case (Ea_i(sz,i)) -> restrict_size(sz,i) + case (Ea_r((Sz8(have_rex)),r)) -> + if have_rex | r < 4 (* RSP *) | r > 7 (* RDI *) then + REG[r] + else + (REG[sub4 (r)] >> 8) & 0x00000000000000FF + case (Ea_r(sz,r)) -> restrict_size(sz, REG[r]) + case (Ea_m((Sz8(_)),a)) -> EXTZ (MEM(a, 1)) + case (Ea_m(Sz16,a)) -> EXTZ (MEM(a, 2)) + case (Ea_m(Sz32,a)) -> EXTZ (MEM(a, 4)) + case (Ea_m(Sz64,a)) -> MEM(a, 8) + } + +function unit effect { wmem, wreg, escape } wEA ((ea) e, (qword) w) = + switch e { + case (Ea_i(_,_)) -> exit () + case (Ea_r((Sz8(have_rex)),r)) -> + if have_rex | r < 4 (* RSP *) | r > 7 (* RDI *) then + { + (qword) regr := REG[r]; + regr[7 .. 0] := w[7 .. 0]; + REG[r] := regr + } + else + { + (qword) regr := REG[sub4(r)]; + regr[15 .. 8] := (vector<15,8,dec,bit>) (adjust_dec(w[7 .. 0])); + REG[sub4(r)] := regr + } + case (Ea_r(Sz16,r)) -> + { + (qword) regr := REG[r]; + regr[15 .. 8] := w[15 .. 8]; + REG[r] := regr + } + case (Ea_r(Sz32,r)) -> REG[r] := (qword) (EXTZ (w[31 .. 0])) + case (Ea_r(Sz64,r)) -> REG[r] := w + case (Ea_m((Sz8(_)),a)) -> wMEM(a, 1, w[7 .. 0]) + case (Ea_m(Sz16,a)) -> wMEM(a, 2, w[15 .. 0]) + case (Ea_m(Sz32,a)) -> wMEM(a, 4, w[31 .. 0]) + case (Ea_m(Sz64,a)) -> wMEM(a, 8, w) + } + +function (ea, qword, qword) read_dest_src_ea ((size) sz, (dest_src) ds) = + let e = ea_dest (sz, ds) in + (e, EA(e), EA(ea_src(sz, ds))) + +function qword call_dest_from_ea ((ea) e) = + switch e { + case (Ea_i(_, i)) -> RIP + i + case (Ea_r(_, r)) -> REG[r] + case (Ea_m(_, a)) -> MEM(a, 8) + } + +function qword get_ea_address ((ea) e) = + switch e { + case (Ea_i(_, i)) -> 0x0000000000000000 + case (Ea_r(_, r)) -> 0x0000000000000000 + case (Ea_m(_, a)) -> 0x0000000000000000 + } + +function unit jump_to_ea ((ea) e) = RIP := call_dest_from_ea(e) + +(* EFLAG updates *) + +function bit byte_parity ((byte) b) = +{ + (int) acc := 0; + foreach (i from 0 to 7) acc := acc + (int) (b[i]); + (bit) (acc mod 2 == 0) +} + +function [|64|] size_width ((size) sz) = + switch sz { + case (Sz8(_)) -> 8 + case Sz16 -> 16 + case Sz32 -> 32 + case Sz64 -> 64 + } + +function [|63|] size_width_sub1 ((size) sz) = + switch sz { + case (Sz8(_)) -> 7 + case Sz16 -> 15 + case Sz32 -> 31 + case Sz64 -> 63 + } + +(* XXXXX +function bit word_size_msb ((size) sz, (qword) w) = w[size_width(sz) - 1] +*) + +function bit word_size_msb ((size) sz, (qword) w) = w[size_width_sub1(sz)] + +function unit write_PF ((qword) w) = PF := byte_parity (w[7 .. 0]) + +function unit write_SF ((size) sz, (qword) w) = SF := word_size_msb (sz, w) + +function unit write_ZF ((size) sz, (qword) w) = + ZF := (bit) + (switch sz { + case (Sz8(_)) -> w[7 .. 0] == 0x00 + case Sz16 -> w[15 .. 0] == 0x0000 + case Sz32 -> w[31 .. 0] == 0x00000000 + case Sz64 -> w == 0x0000000000000000 + }) + +function unit write_arith_eflags_except_CF_OF ((size) sz, (qword) w) = +{ + AF := undefined; + write_PF(w); + write_SF(sz, w); + write_ZF(sz, w); +} + +function unit write_arith_eflags ((size) sz, (qword) w, (bit) c, (bit) x) = +{ + CF := c; + OF := x; + write_arith_eflags_except_CF_OF (sz, w) +} + +function unit write_logical_eflags ((size) sz, (qword) w) = + write_arith_eflags (sz, w, bitzero, bitzero) + +function unit erase_eflags () = +{ + AF := undefined; + CF := undefined; + OF := undefined; + PF := undefined; + SF := undefined; + ZF := undefined; +} + +(* XXXXX *) +function nat power ((nat) x, ([|64|]) y) = undefined + +function nat value_width ((size) sz) = power (2, size_width(sz)) + +function bit word_signed_overflow_add ((size) sz, (qword) a, (qword) b) = + (bit) (word_size_msb (sz, a) == word_size_msb (sz, b) & + word_size_msb (sz, a + b) != word_size_msb (sz, a)) + +function bit word_signed_overflow_sub ((size) sz, (qword) a, (qword) b) = + (bit) (word_size_msb (sz, a) != word_size_msb (sz, b) & + word_size_msb (sz, a - b) != word_size_msb (sz, a)) + +function (qword, bit, bit) add_with_carry_out ((size) sz, (qword) a, (qword) b) = + (a + b, (bit) ((int) (value_width (sz)) <= unsigned(a) + unsigned(b)), + word_signed_overflow_add (sz, a, b)) + +function (qword, bit, bit) sub_with_borrow ((size) sz, (qword) a, (qword) b) = + (a - b, (bit) (a < b), word_signed_overflow_sub (sz, a, b)) + +function unit write_arith_result ((size) sz, (qword) w, (bit) c, (bit) x, (ea) e) = +{ + write_arith_eflags (sz, w, c, x); + wEA (e) := w; +} + +function unit write_arith_result_no_CF_OF ((size) sz, (qword) w, (ea) e) = +{ + write_arith_eflags_except_CF_OF (sz, w); + wEA (e) := w; +} + +function unit write_logical_result ((size) sz, (qword) w, (ea) e) = +{ + write_arith_eflags_except_CF_OF (sz, w); + wEA (e) := w; +} + +function unit write_result_erase_eflags ((qword) w, (ea) e) = +{ + erase_eflags (); + wEA (e) := w; +} + +function qword effect { escape } sign_extension ((qword) w, (size) size1, (size) size2) = +{ + (qword) x := w; + switch (size1, size2) { + case ((Sz8(_)), Sz16) -> x[15 .. 0] := (bit[16]) (EXTS (w[7 .. 0])) + case ((Sz8(_)), Sz32) -> x[31 .. 0] := (bit[32]) (EXTS (w[7 .. 0])) + case ((Sz8(_)), Sz64) -> x := (qword) (EXTS (w[7 .. 0])) + case (Sz16, Sz32) -> x[31 .. 0] := (bit[32]) (EXTS (w[15 .. 0])) + case (Sz16, Sz64) -> x := (qword) (EXTS (w[15 .. 0])) + case (Sz32, Sz64) -> x := (qword) (EXTS (w[31 .. 0])) + case _ -> undefined + }; + x; +} + +function [|64|] mask_shift ((size) sz, (qword) w) = + if sz == Sz64 then w[5 .. 0] else w[4 .. 0] + +function qword rol ((size) sz, (qword) a, (qword) b) = + switch sz { + case (Sz8(_)) -> EXTZ (ROL (a[7 .. 0], b[2 .. 0])) + case Sz16 -> EXTZ (ROL (a[15 .. 0], b[3 .. 0])) + case Sz32 -> EXTZ (ROL (a[31 .. 0], b[4 .. 0])) + case Sz64 -> ROL (a, b[5 .. 0]) + } + +function qword ror ((size) sz, (qword) a, (qword) b) = + switch sz { + case (Sz8(_)) -> EXTZ (ROR (a[7 .. 0], b[2 .. 0])) + case Sz16 -> EXTZ (ROR (a[15 .. 0], b[3 .. 0])) + case Sz32 -> EXTZ (ROR (a[31 .. 0], b[4 .. 0])) + case Sz64 -> ROR (a, b[5 .. 0]) + } + +function qword sar ((size) sz, (qword) a, (qword) b) = + switch sz { + case (Sz8(_)) -> EXTZ (ASR (a[7 .. 0], b[2 .. 0])) + case Sz16 -> EXTZ (ASR (a[15 .. 0], b[3 .. 0])) + case Sz32 -> EXTZ (ASR (a[31 .. 0], b[4 .. 0])) + case Sz64 -> ASR (a, b[5 .. 0]) + } + +function unit write_binop ((size) sz, (binop_name) bop, (qword) a, (qword) b, (ea) e) = + switch bop { + case Add -> let (w,c,x) = add_with_carry_out (sz, a, b) in + write_arith_result (sz, w, c, x, e) + case Sub -> let (w,c,x) = sub_with_borrow (sz, a, b) in + write_arith_result (sz, w, c, x, e) + case Cmp -> let (w,c,x) = sub_with_borrow (sz, a, b) in + write_arith_eflags (sz, w, c, x) + case Test -> write_logical_eflags (sz, a & b) + case And -> write_logical_result (sz, a & b, e) + case Xor -> write_logical_result (sz, a ^ b, e) + case Or -> write_logical_result (sz, a | b, e) + case Rol -> write_result_erase_eflags (rol (sz, a, b), e) + case Ror -> write_result_erase_eflags (ror (sz, a, b), e) + case Sar -> write_result_erase_eflags (sar (sz, a, b), e) + case Shl -> write_result_erase_eflags (a << mask_shift (sz, b), e) + case Shr -> write_result_erase_eflags (a >> mask_shift (sz, b), e) + case Adc -> + { + let carry = CF in + let (qword) result = a + (qword) (b + carry) in + { + CF := (bit) ((int) (value_width (sz)) <= unsigned(a) + unsigned(b)); + OF := undefined; + write_arith_result_no_CF_OF (sz, result, e); + } + } + case Sbb -> + { + let carry = CF in + let (qword) result = a - (qword) (b + carry) in + { + CF := (bit) (unsigned(a) < unsigned(b) + (int) carry); + OF := undefined; + write_arith_result_no_CF_OF (sz, result, e); + } + } + case _ -> exit () + } + +function unit write_monop ((size) sz, (monop_name) mop, (qword) a, (ea) e) = + switch mop { + case Not -> wEA(e) := ~(a) + case Dec -> write_arith_result_no_CF_OF (sz, a - 1, e) + case Inc -> write_arith_result_no_CF_OF (sz, a + 1, e) + case Neg -> { write_arith_result_no_CF_OF (sz, 0 - a, e); + CF := undefined; + } + } + +function bool read_cond ((cond) c) = + switch c { + case A -> ~(CF) & ~(ZF) + case NB -> ~(CF) + case B -> CF + case NA -> CF | (bit) ZF + case E -> ZF + case G -> ~(ZF) & (SF == OF) + case NL -> SF == OF + case L -> SF != OF + case NG -> ZF | SF != OF + case NE -> ~(ZF) + case NO -> ~(OF) + case NP -> ~(PF) + case NS -> ~(SF) + case O -> OF + case P -> PF + case S -> SF + case ALWAYS -> true + } + +function qword pop_aux () = + let top = MEM(RSP, 8) in + { + RSP := RSP + 8; + top; + } + +function unit push_aux ((qword) w) = +{ + RSP := RSP - 8; + wMEM(RSP, 8) := w; +} + +function unit pop ((rm) r) = wEA (ea_rm (Sz64,r)) := pop_aux() +function unit pop_rip () = RIP := pop_aux() +function unit push ((imm_rm) i) = push_aux (EA (ea_imm_rm (Sz64, i))) +function unit push_rip () = push_aux (RIP) + +function unit drop ((qword) i) = if i[7 ..0] != 0 then () else RSP := RSP + i + +(* -------------------------------------------------------------------------- + Instructions + -------------------------------------------------------------------------- *) + +scattered function unit execute +scattered typedef ast = const union + +val ast -> unit effect {escape, rmem, rreg, undef, wmem, wreg} execute + +(* ========================================================================== + Binop + ========================================================================== *) + +union ast member (binop_name,size,dest_src) Binop + +function clause execute (Binop (bop,sz,ds)) = + let (e, val_dst, val_src) = read_dest_src_ea (sz, ds) in + write_binop (sz, bop, val_dst, val_src, e) + +(* ========================================================================== + CALL + ========================================================================== *) + +union ast member imm_rm CALL + +function clause execute (CALL (i)) = +{ + push_rip(); + jump_to_ea (ea_imm_rm (Sz64, i)) +} + +(* ========================================================================== + CLC + ========================================================================== *) + +union ast member unit CLC + +function clause execute CLC = CF := false + +(* ========================================================================== + CMC + ========================================================================== *) + +union ast member unit CMC + +function clause execute CMC = CF := ~(CF) + +(* ========================================================================== + CMPXCHG + ========================================================================== *) + +union ast member (size,rm,regn) CMPXCHG + +function clause execute (CMPXCHG (sz,r,n)) = + let src = Ea_r(sz, n) in + let acc = Ea_r(sz, 0) in (* RAX *) + let dst = ea_rm(sz, r) in + let val_dst = EA(dst) in + let val_acc = EA(src) in + { + write_binop (sz, Cmp, val_acc, val_dst, src); + if val_acc == val_dst then + wEA(dst) := EA (src) + else + wEA(acc) := val_dst; + } + +(* ========================================================================== + DIV + ========================================================================== *) + +union ast member (size,rm) DIV + +function clause execute (DIV (sz,r)) = + let w = (int) (value_width(sz)) in + let eax = Ea_r(sz, 0) in (* RAX *) + let edx = Ea_r(sz, 2) in (* RDX *) + let n = unsigned(EA(edx)) * w + unsigned(EA(eax)) in + let d = unsigned(EA(ea_rm(sz, r))) in + let q = n quot d in + let m = n mod d in + if d == 0 | w < q then exit () + else + { + wEA(eax) := cast_int_vec(q); + wEA(edx) := cast_int_vec(m); + erase_eflags(); + } + +(* ========================================================================== + Jcc + ========================================================================== *) + +union ast member (cond,qword) Jcc + +function clause execute (Jcc (c,i)) = + if read_cond (c) then RIP := RIP + i else () + +(* ========================================================================== + JMP + ========================================================================== *) + +union ast member rm JMP + +function clause execute (JMP (r)) = RIP := EA (ea_rm (Sz64, r)) + +(* ========================================================================== + LEA + ========================================================================== *) + +union ast member (size,dest_src) LEA + +function clause execute (LEA (sz,ds)) = + let src = ea_src (sz, ds) in + let dst = ea_dest (sz, ds) in + wEA(dst) := get_ea_address (src) + +(* ========================================================================== + LEAVE + ========================================================================== *) + +union ast member unit LEAVE + +function clause execute LEAVE = +{ + RSP := RBP; + pop (Reg (5)); (* RBP *) +} + +(* ========================================================================== + LOOP + ========================================================================== *) + +union ast member (cond,qword) LOOP + +function clause execute (LOOP (c,i)) = +{ + RCX := RCX - 1; + if RCX != 0 & read_cond (c) then RIP := RIP + i else (); +} + +(* ========================================================================== + Monop + ========================================================================== *) + +union ast member (monop_name,size,rm) Monop + +function clause execute (Monop (mop,sz,r)) = + let e = ea_rm (sz, r) in write_monop (sz, mop, EA(e), e) + +(* ========================================================================== + MOV + ========================================================================== *) + +union ast member (cond,size,dest_src) MOV + +function clause execute (MOV (c,sz,ds)) = + if read_cond (c) then + let src = ea_src (sz, ds) in + let dst = ea_dest (sz, ds) in + wEA(dst) := EA(src) + else () + +(* ========================================================================== + MOVSX + ========================================================================== *) + +union ast member (size,dest_src,size) MOVSX + +function clause execute (MOVSX (sz1,ds,sz2)) = + let src = ea_src (sz1, ds) in + let dst = ea_dest (sz2, ds) in + wEA(dst) := sign_extension (EA(src), sz1, sz2) + +(* ========================================================================== + MOVZX + ========================================================================== *) + +union ast member (size,dest_src,size) MOVZX + +function clause execute (MOVZX (sz1,ds,sz2)) = + let src = ea_src (sz1, ds) in + let dst = ea_dest (sz2, ds) in + wEA(dst) := EA(src) + +(* ========================================================================== + MUL + ========================================================================== *) + +union ast member (size,rm) MUL + +function clause execute (MUL (sz,r)) = + let eax = Ea_r (sz, 0) in (* RAX *) + let val_eax = EA(eax) in + let val_src = EA(ea_rm (sz, r)) in + switch sz { + case (Sz8(_)) -> wEA(Ea_r(Sz16,0)) := (val_eax * val_src)[63 .. 0] + case _ -> + let m = val_eax * val_src in + let edx = Ea_r (sz, 2) in (* RDX *) + { + wEA(eax) := m[63 .. 0]; + wEA(edx) := (LSR (m, size_width(sz)))[63 .. 0] + } + } + +(* ========================================================================== + NOP + ========================================================================== *) + +union ast member nat NOP + +function clause execute (NOP (_)) = () + +(* ========================================================================== + POP + ========================================================================== *) + +union ast member rm POP + +function clause execute (POP (r)) = pop(r) + +(* ========================================================================== + PUSH + ========================================================================== *) + +union ast member imm_rm PUSH + +function clause execute (PUSH (i)) = push(i) + +(* ========================================================================== + RET + ========================================================================== *) + +union ast member qword RET + +function clause execute (RET (i)) = +{ + pop_rip(); + drop(i); +} + +(* ========================================================================== + SET + ========================================================================== *) + +union ast member (cond,bool,rm) SET + +function clause execute (SET (c,b,r)) = + wEA(ea_rm(Sz8(b),r)) := if read_cond (c) then 1 else 0 + +(* ========================================================================== + STC + ========================================================================== *) + +union ast member unit STC + +function clause execute STC = CF := true + +(* ========================================================================== + XADD + ========================================================================== *) + +union ast member (size,rm,regn) XADD + +function clause execute (XADD (sz,r,n)) = + let src = Ea_r (sz, n) in + let dst = ea_rm (sz, r) in + let val_src = EA(src) in + let val_dst = EA(dst) in + { + wEA(src) := val_dst; + write_binop (sz, Add, val_src, val_dst, dst); + } + +(* ========================================================================== + XCHG + ========================================================================== *) + +union ast member (size,rm,regn) XCHG + +function clause execute (XCHG (sz,r,n)) = + let src = Ea_r (sz, n) in + let dst = ea_rm (sz, r) in + let val_src = EA(src) in + let val_dst = EA(dst) in + { + wEA(src) := val_dst; + wEA(dst) := val_src; + } + +end ast +end execute + +(* -------------------------------------------------------------------------- + Decoding + -------------------------------------------------------------------------- *) + +function (qword,ostream) oimmediate8 ((ostream) strm) = + switch strm { + case (Some (b :: t)) -> ((qword) (EXTS(b)), Some (t)) + case _ -> ((qword) undefined, (ostream) None) + } + +function (qword,ostream) immediate8 ((byte_stream) strm) = + oimmediate8 (Some (strm)) + +function (qword,ostream) immediate16 ((byte_stream) strm) = + switch strm { + case b1 :: b2 :: t -> ((qword) (EXTS(b2 : b1)), Some (t)) + case _ -> ((qword) undefined, (ostream) None) + } + +function (qword,ostream) immediate32 ((byte_stream) strm) = + switch strm { + case b1 :: b2 :: b3 :: b4 :: t -> + ((qword) (EXTS(b4 : b3 : b2 : b1)), Some (t)) + case _ -> ((qword) undefined, (ostream) None) + } + +function (qword,ostream) immediate64 ((byte_stream) strm) = + switch strm { + case b1 :: b2 :: b3 :: b4 :: b5 :: b6 :: b7 :: b8 :: t -> + ((qword) (EXTS(b8 : b7 : b6 : b5 : b4 : b3 : b2 : b1)), Some (t)) + case _ -> ((qword) undefined, (ostream) None) + } + +function (qword, ostream) immediate ((size) sz, (byte_stream) strm) = + switch sz { + case (Sz8 (_)) -> immediate8 (strm) + case Sz16 -> immediate16 (strm) + case _ -> immediate32 (strm) + } + +function (qword, ostream) oimmediate ((size) sz, (ostream) strm) = + switch strm { + case (Some (s)) -> immediate (sz, s) + case None -> ((qword) undefined, (ostream) None) + } + +function (qword, ostream) full_immediate ((size) sz, (byte_stream) strm) = + if sz == Sz64 then immediate64 (strm) else immediate (sz, strm) + +(* - Parse ModR/M and SIB bytes --------------------------------------------- *) + +typedef REX = register bits [3 : 0] { + 3 : W; + 2 : R; + 1 : X; + 0 : B +} + +function regn rex_reg ((bit[1]) b, (bit[3]) r) = unsigned(b : r) + +function (qword, ostream) read_displacement ((bit[2]) Mod, (byte_stream) strm) = + if Mod == 0b01 + then immediate8 (strm) + else if Mod == 0b10 + then immediate32 (strm) + else (0x0000000000000000, (Some (strm))) + +function (qword, ostream) + read_sib_displacement ((bit[2]) Mod, (byte_stream) strm) = + if Mod == 0b01 then immediate8 (strm) else immediate32 (strm) + +function (rm, ostream) + read_SIB ((REX) rex, (bit[2]) Mod, (byte_stream) strm) = + switch strm { + case ((bit[2]) SS : (bit[3]) Index : (bit[3]) Base) :: strm1 -> + (let bbase = rex_reg (rex.B, Base) in + let index = rex_reg (rex.X, Index) in + let scaled_index = if index == 4 (* RSP *) then + (option<scale_index>) None + else let x = (scale_index) (SS, index) in + Some (x) in + (if bbase == 5 (* RBP *) + then let (displacement, strm2) = + read_sib_displacement (Mod, strm1) in + let bbase = if Mod == 0b00 then NoBase else RegBase (bbase) + in + (Mem (scaled_index, bbase, displacement), strm2) + else let (displacement, strm2) = read_displacement (Mod, strm1) in + (Mem (scaled_index, RegBase (bbase), displacement), strm2))) + case _ -> ((rm) undefined, (ostream) None) + } + +function (regn, rm, ostream) read_ModRM ((REX) rex, (byte_stream) strm) = + switch strm { + case (0b00 : (bit[3]) RegOpc : 0b101) :: strm1 -> + let (displacement, strm2) = immediate32 (strm1) in + (rex_reg (rex.R, RegOpc), Mem (None, RipBase, displacement), strm2) + case (0b11 : (bit[3]) REG : (bit[3]) RM) :: strm1 -> + (rex_reg (rex.R, REG), Reg (rex_reg (rex.B, RM)), Some (strm1)) + case ((bit[2]) Mod : (bit[3]) RegOpc : 0b100) :: strm1 -> + let (sib, strm2) = read_SIB (rex, Mod, strm1) in + (rex_reg (rex.R, RegOpc), sib, strm2) + case ((bit[2]) Mod : (bit[3]) RegOpc : (bit[3]) RM) :: strm1 -> + let (displacement, strm2) = read_displacement (Mod, strm1) in + (rex_reg (rex.R, RegOpc), + Mem (None, RegBase (rex_reg (rex.B, RM)), displacement), + strm2) + case _ -> ((regn) undefined, (rm) undefined, (ostream) None) + } + +function (bit[3], rm, ostream) + read_opcode_ModRM ((REX) rex, (byte_stream) strm) = + let (opcode, r, strm1) = read_ModRM (rex, strm) in + ((bit[3]) (cast_int_vec((int) opcode mod 8)), r, strm1) + +(* - Prefixes --------------------------------------------------------------- *) + +typedef prefix = [|5|] + +function prefix prefix_group ((byte) b) = + switch b { + case 0xf0 -> 1 + case 0xf2 -> 1 + case 0xf3 -> 1 + case 0x26 -> 2 + case 0x2e -> 2 + case 0x36 -> 2 + case 0x3e -> 2 + case 0x64 -> 2 + case 0x65 -> 2 + case 0x66 -> 3 + case 0x67 -> 4 + case _ -> if b[7 .. 4] == 0b0100 then 5 else 0 + } + +typedef atuple = (byte_stream, bool, REX, byte_stream) + +val (list<prefix>, byte_stream, byte_stream) -> option<atuple> effect {undef} read_prefix + +function rec option<atuple> read_prefix + ((list<prefix>) s, (byte_stream) p, (byte_stream) strm) = + switch strm { + case h :: strm1 -> + let group = prefix_group (h) in + if group == 0 then + let x = (p, false, (REX) 0b0000, strm) in Some (x) + else if group == 5 then + let x = (p, true, (REX) (h[3 .. 0]), strm1) in Some (x) + else if ismember (group, s) then + None + else + read_prefix (group :: s, h :: p, strm1) + case _ -> let x = (p, false, (REX) undefined, strm) in Some (x) + } + +function option<atuple> read_prefixes ((byte_stream) strm) = + read_prefix ([||||], [||||], strm) + +function size op_size ((bool) have_rex, (bit[1]) w, (bit[1]) v, (bool) override) = + if v == 1 then + Sz8 (have_rex) + else if w == 1 then + Sz64 + else if override then + Sz16 + else + Sz32 + +function bool is_mem ((rm) r) = + switch r {case (Mem (_, _, _)) -> true case _ -> false} + +(* - Decoder ---------------------------------------------------------------- *) + +function (ast, ostream) decode_aux + ((byte_stream) strm, (bool) have_rex, (REX) rex, (bool) op_size_override) = + switch strm + { + case (0b00 : (bit[3]) opc : 0b0 : (bit[1]) x : (bit[1]) v) :: strm2 -> + let (reg, r, strm3) = read_ModRM (rex, strm2) in + let sz = op_size (have_rex, rex.W, v, op_size_override) in + let binop = opc_to_binop_name (EXTZ (opc)) in + let src_dst = if x == 0 then Rm_r (r, reg) else R_rm (reg, r) in + (Binop (binop, sz, src_dst), strm3) + case (0b00 : (bit[3]) opc : 0b10 : (bit[1]) v) :: strm2 -> + let sz = op_size (have_rex, rex.W, v, op_size_override) in + let binop = opc_to_binop_name (EXTZ (opc)) in + let (imm, strm3) = immediate (sz, strm2) in + (Binop (binop, sz, Rm_i (Reg (0), imm)), strm3) + case (0x5 : (bit[1]) b : (bit[3]) r) :: strm2 -> + let reg = Reg (([|15|]) (rex.B : r)) in + (if b == 0b0 then PUSH (Rm (reg)) else POP (reg), Some (strm2)) + case 0x63 :: strm2 -> + let (reg, r, strm3) = read_ModRM (rex, strm2) in + (MOVSX (Sz32, R_rm (reg, r), Sz64), strm3) + case (0x6 : 0b10 : (bit[1]) b : 0b0) :: strm2 -> + let (imm, strm3) = if b == 1 then immediate8 (strm2) + else immediate32 (strm2) in + (PUSH (Imm (imm)), strm3) + case (0x7 : (bit[4]) c) :: strm2 -> + let (imm, strm3) = immediate8 (strm2) in + (Jcc (bv_to_cond (c), imm), strm3) + case (0x8 : 0b000 : (bit[1]) v) :: strm2 -> + let sz = op_size (have_rex, rex.W, v, op_size_override) in + let (opc, r, strm3) = read_opcode_ModRM (rex, strm2) in + let (imm, strm4) = oimmediate (sz, strm3) in + let binop = opc_to_binop_name (EXTZ (opc)) in + (Binop (binop, sz, Rm_i (r, imm)), strm4) + case 0x83 :: strm2 -> + let sz = op_size (have_rex, rex.W, 1, op_size_override) in + let (opc, r, strm3) = read_opcode_ModRM (rex, strm2) in + let (imm, strm4) = oimmediate (sz, strm3) in + let binop = opc_to_binop_name (EXTZ (opc)) in + (Binop (binop, sz, Rm_i (r, imm)), strm4) + case (0x8 : 0b010 : (bit[1]) v) :: strm2 -> + let sz = op_size (have_rex, rex.W, v, op_size_override) in + let (reg, r, strm3) = read_ModRM (rex, strm2) in + (Binop (Test, sz, Rm_r (r, reg)), strm3) + case (0x8 : 0b011 : (bit[1]) v) :: strm2 -> + let sz = op_size (have_rex, rex.W, v, op_size_override) in + let (reg, r, strm3) = read_ModRM (rex, strm2) in + (XCHG (sz, r, reg), strm3) + case (0x8 : 0b10 : (bit[1]) x : (bit[1]) v) :: strm2 -> + let (reg, r, strm3) = read_ModRM (rex, strm2) in + let sz = op_size (have_rex, rex.W, v, op_size_override) in + let src_dst = if x == 0 then Rm_r (r, reg) else R_rm (reg, r) in + (MOV (ALWAYS, sz, src_dst), strm3) + case 0x8d :: strm2 -> + let sz = op_size (true, rex.W, 1, op_size_override) in + let (reg, r, strm3) = read_ModRM (rex, strm2) in + if is_mem (r) then (LEA (sz, R_rm (reg, r)), strm3) else exit () + case 0x8f :: strm2 -> + let (opc, r, strm3) = read_opcode_ModRM (rex, strm2) in + if opc == 0 then (POP (r), strm3) else exit () + case (0x9 : 0b0 : (bit[3]) r) :: strm2 -> + let sz = op_size (true, rex.W, 1, op_size_override) in + let reg = rex_reg (rex.B, r) in + if reg == 0 then + (NOP (listlength (strm)), Some (strm2)) + else + (XCHG (sz, Reg (0), reg), Some (strm2)) + case (0xa : 0b100 : (bit[1]) v) :: strm2 -> + let sz = op_size (true, rex.W, v, op_size_override) in + let (imm, strm3) = immediate (sz, strm2) in + (Binop (Test, sz, Rm_i (Reg (0), imm)), strm3) + case (0xb : (bit[1]) v : (bit[3]) r) :: strm2 -> + let sz = op_size (have_rex, rex.W, v, op_size_override) in + let (imm, strm3) = full_immediate (sz, strm2) in + let reg = rex_reg (rex.B, r) in + (MOV (ALWAYS, sz, Rm_i (Reg (reg), imm)), strm3) + case (0xc : 0b000 : (bit[1]) v) :: strm2 -> + let sz = op_size (have_rex, rex.W, v, op_size_override) in + let (opc, r, strm3) = read_opcode_ModRM (rex, strm2) in + let (imm, strm4) = oimmediate8 (strm3) in + let binop = opc_to_binop_name (0b1 : opc) in + if opc == 0b110 then exit () + else (Binop (binop, sz, Rm_i (r, imm)), strm4) + case (0xc : 0b001 : (bit[1]) v) :: strm2 -> + if v == 0 then + let (imm, strm3) = immediate16 (strm2) in (RET (imm), strm3) + else + (RET (0), Some (strm2)) + case (0xc : 0b011 : (bit[1]) v) :: strm2 -> + let sz = op_size (have_rex, rex.W, v, op_size_override) in + let (opc, r, strm3) = read_opcode_ModRM (rex, strm2) in + let (imm, strm4) = oimmediate (sz, strm3) in + if opc == 0 then (MOV (ALWAYS, sz, Rm_i (r, imm)), strm4) + else exit () + case 0xc9 :: strm2 -> + (LEAVE, Some (strm2)) + case (0xd : 0b00 : (bit[1]) b : (bit[1]) v) :: strm2 -> + let sz = op_size (have_rex, rex.W, v, op_size_override) in + let (opc, r, strm3) = read_opcode_ModRM (rex, strm2) in + let shift = if b == 0 then Rm_i (r, 1) else Rm_r (r, 1) in + let binop = opc_to_binop_name (0b1 : opc) in + if opc == 0b110 then exit () + else (Binop (binop, sz, shift), strm3) + case (0xe : 0b000 : (bit[1]) b) :: strm2 -> + let (imm, strm3) = immediate8 (strm2) in + let cnd = if b == 0 then NE else E in + (LOOP (cnd, imm), strm3) + case 0xe2 :: strm2 -> + let (imm, strm3) = immediate8 (strm2) in + (LOOP (ALWAYS, imm), strm3) + case 0xe8 :: strm2 -> + let (imm, strm3) = immediate32 (strm2) in + (CALL (Imm (imm)), strm3) + case (0xe : 0b10 : (bit[1]) b : 0b1) :: strm2 -> + let (imm, strm3) = if b == 0 then immediate32 (strm2) + else immediate8 (strm2) in + (Jcc (ALWAYS, imm), strm3) + case 0xf5 :: strm2 -> (CMC, Some (strm2)) + case 0xf8 :: strm2 -> (CLC, Some (strm2)) + case 0xf9 :: strm2 -> (STC, Some (strm2)) + case (0xf : 0b011 : (bit[1]) v) :: strm2 -> + let sz = op_size (have_rex, rex.W, v, op_size_override) in + let (opc, r, strm3) = read_opcode_ModRM (rex, strm2) in + switch opc { + case 0b000 -> let (imm, strm4) = oimmediate (sz, strm3) in + (Binop (Test, sz, Rm_i (r, imm)), strm4) + case 0b010 -> (Monop (Not, sz, r), strm3) + case 0b011 -> (Monop (Neg, sz, r), strm3) + case 0b100 -> (MUL (sz, r), strm3) + case 0b110 -> (DIV (sz, r), strm3) + case _ -> exit () + } + case 0xfe :: strm2 -> + let (opc, r, strm3) = read_opcode_ModRM (rex, strm2) in + switch opc { + case 0b000 -> (Monop (Inc, Sz8 (have_rex), r), strm3) + case 0b001 -> (Monop (Dec, Sz8 (have_rex), r), strm3) + case _ -> exit () + } + case 0xff :: strm2 -> + let sz = op_size (have_rex, rex.W, 1, op_size_override) in + let (opc, r, strm3) = read_opcode_ModRM (rex, strm2) in + switch opc { + case 0b000 -> (Monop (Inc, sz, r), strm3) + case 0b001 -> (Monop (Dec, sz, r), strm3) + case 0b010 -> (CALL (Rm (r)), strm3) + case 0b100 -> (JMP (r), strm3) + case 0b110 -> (PUSH (Rm (r)), strm3) + case _ -> exit () + } + case 0x0f :: opc :: strm2 -> + switch opc { + case 0x1f -> + let (opc, r, strm3) = read_opcode_ModRM (rex, strm2) in + (NOP (listlength (strm)), strm3) + case (0x4 : (bit[4]) c) -> + let sz = op_size (true, rex.W, 1, op_size_override) in + let (reg, r, strm3) = read_ModRM (rex, strm2) in + (MOV (bv_to_cond (c), sz, R_rm (reg, r)), strm3) + case (0x8 : (bit[4]) c) -> + let (imm, strm3) = immediate32 (strm2) in + (Jcc (bv_to_cond (c), imm), strm3) + case (0x9 : (bit[4]) c) -> + let (reg, r, strm3) = read_ModRM (rex, strm2) in + (SET (bv_to_cond (c), have_rex, r), strm3) + case (0xb : 0b000 : (bit[1]) v) -> + let sz = op_size (have_rex, rex.W, v, op_size_override) in + let (reg, r, strm3) = read_ModRM (rex, strm2) in + (CMPXCHG (sz, r, reg), strm3) + case (0xc : 0b000 : (bit[1]) v) -> + let sz = op_size (have_rex, rex.W, v, op_size_override) in + let (reg, r, strm3) = read_ModRM (rex, strm2) in + (XADD (sz, r, reg), strm3) + case (0xb : (bit[1]) s : 0b11 : (bit[1]) v) -> + let sz2 = op_size (have_rex, rex.W, 1, op_size_override) in + let sz = if v == 1 then Sz16 else Sz8 (have_rex) in + let (reg, r, strm3) = read_ModRM (rex, strm2) in + if s == 1 then + (MOVSX (sz, R_rm (reg, r), sz2), strm3) + else + (MOVZX (sz, R_rm (reg, r), sz2), strm3) + case _ -> exit () + } + case _ -> exit () + } + +function (byte_stream, ast, nat) decode ((byte_stream) strm) = + switch read_prefixes (strm) + { + case None -> exit () + case (Some (prefixes, have_rex, rex, strm1)) -> + let op_size_override = ismember (0x66, prefixes) in + if rex.W == 1 & op_size_override | ismember (0x67, prefixes) then + exit () + else + switch decode_aux (strm1, have_rex, rex, op_size_override) { + case (instr, (Some (strm2))) -> (prefixes, instr, listlength (strm2)) + case _ -> exit () + } + } |
