summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--mips_new_tc/mips_extras_embed_sequential.lem27
-rw-r--r--src/ast_util.ml5
-rw-r--r--src/ast_util.mli4
-rw-r--r--src/gen_lib/sail_values.lem8
-rw-r--r--src/gen_lib/state.lem129
-rw-r--r--src/parser.mly2
-rw-r--r--src/pretty_print.mli2
-rw-r--r--src/pretty_print_lem.ml206
-rw-r--r--src/pretty_print_sail.ml1
-rw-r--r--src/process_file.ml65
-rw-r--r--src/rewriter.ml47
-rw-r--r--src/rewriter.mli2
-rw-r--r--src/sail.ml1
-rw-r--r--src/type_check.ml22
-rw-r--r--src/type_check.mli10
-rw-r--r--test/typecheck/pass/mips400.sail310
-rw-r--r--test/typecheck/pass/phantom_num.sail2
-rwxr-xr-xtest/typecheck/run_tests.sh3
18 files changed, 605 insertions, 241 deletions
diff --git a/mips_new_tc/mips_extras_embed_sequential.lem b/mips_new_tc/mips_extras_embed_sequential.lem
index ad567598..f9c6c92c 100644
--- a/mips_new_tc/mips_extras_embed_sequential.lem
+++ b/mips_new_tc/mips_extras_embed_sequential.lem
@@ -4,10 +4,10 @@ open import Sail_impl_base
open import Sail_values
open import State
-val MEMr : forall 'a 'b. Size 'b => (bitvector 'a * integer) -> M (bitvector 'b)
-val MEMr_reserve : forall 'a 'b. Size 'b => (bitvector 'a * integer) -> M (bitvector 'b)
-val MEMr_tag : forall 'a 'b. Size 'b => (bitvector 'a * integer) -> M (bitU * bitvector 'b)
-val MEMr_tag_reserve : forall 'a 'b. Size 'b => (bitvector 'a * integer) -> M (bitU * bitvector 'b)
+val MEMr : forall 'regs 'a 'b. Size 'b => (bitvector 'a * integer) -> M 'regs (bitvector 'b)
+val MEMr_reserve : forall 'regs 'a 'b. Size 'b => (bitvector 'a * integer) -> M 'regs (bitvector 'b)
+val MEMr_tag : forall 'regs 'a 'b. Size 'b => (bitvector 'a * integer) -> M 'regs (bitU * bitvector 'b)
+val MEMr_tag_reserve : forall 'regs 'a 'b. Size 'b => (bitvector 'a * integer) -> M 'regs (bitU * bitvector 'b)
let MEMr (addr,size) = read_mem false Read_plain addr size
let MEMr_reserve (addr,size) = read_mem false Read_reserve addr size
@@ -23,10 +23,10 @@ let MEMr_tag_reserve (addr,size) =
return (t, v)
-val MEMea : forall 'a. (bitvector 'a * integer) -> M unit
-val MEMea_conditional : forall 'a. (bitvector 'a * integer) -> M unit
-val MEMea_tag : forall 'a. (bitvector 'a * integer) -> M unit
-val MEMea_tag_conditional : forall 'a. (bitvector 'a * integer) -> M unit
+val MEMea : forall 'regs 'a. (bitvector 'a * integer) -> M 'regs unit
+val MEMea_conditional : forall 'regs 'a. (bitvector 'a * integer) -> M 'regs unit
+val MEMea_tag : forall 'regs 'a. (bitvector 'a * integer) -> M 'regs unit
+val MEMea_tag_conditional : forall 'regs 'a. (bitvector 'a * integer) -> M 'regs unit
let MEMea (addr,size) = write_mem_ea Write_plain addr size
let MEMea_conditional (addr,size) = write_mem_ea Write_conditional addr size
@@ -35,17 +35,16 @@ let MEMea_tag (addr,size) = write_mem_ea Write_plain addr size
let MEMea_tag_conditional (addr,size) = write_mem_ea Write_conditional addr size
-val MEMval : forall 'a 'b. (bitvector 'a * integer * bitvector 'b) -> M unit
-val MEMval_conditional : forall 'a 'b. (bitvector 'a * integer * bitvector 'b) -> M bool
-val MEMval_tag : forall 'a 'b. (bitvector 'a * integer * bitU * bitvector 'b) -> M unit
-val MEMval_tag_conditional : forall 'a 'b. (bitvector 'a * integer * bitU * bitvector 'b) -> M bool
+val MEMval : forall 'regs 'a 'b. (bitvector 'a * integer * bitvector 'b) -> M 'regs unit
+val MEMval_conditional : forall 'regs 'a 'b. (bitvector 'a * integer * bitvector 'b) -> M 'regs bool
+val MEMval_tag : forall 'regs 'a 'b. (bitvector 'a * integer * bitU * bitvector 'b) -> M 'regs unit
+val MEMval_tag_conditional : forall 'regs 'a 'b. (bitvector 'a * integer * bitU * bitvector 'b) -> M 'regs bool
let MEMval (_,_,v) = write_mem_val v >>= fun _ -> return ()
let MEMval_conditional (_,_,v) = write_mem_val v >>= fun b -> return (if b then true else false)
let MEMval_tag (_,_,t,v) = write_mem_val v >>= fun _ -> write_tag t >>= fun _ -> return ()
let MEMval_tag_conditional (_,_,t,v) = write_mem_val v >>= fun b -> write_tag t >>= fun _ -> return (if b then true else false)
-val MEM_sync : unit -> M unit
+val MEM_sync : forall 'regs. unit -> M 'regs unit
let MEM_sync () = barrier Barrier_MIPS_SYNC
-
diff --git a/src/ast_util.ml b/src/ast_util.ml
index ddd83429..ad14f0f1 100644
--- a/src/ast_util.ml
+++ b/src/ast_util.ml
@@ -50,6 +50,11 @@ let mk_nc nc_aux = NC_aux (nc_aux, Parse_ast.Unknown)
let mk_nexp nexp_aux = Nexp_aux (nexp_aux, Parse_ast.Unknown)
+let mk_exp exp_aux = E_aux (exp_aux, (Parse_ast.Unknown, ()))
+let unaux_exp (E_aux (exp_aux, _)) = exp_aux
+
+let mk_lit lit_aux = L_aux (lit_aux, Parse_ast.Unknown)
+
let rec map_exp_annot f (E_aux (exp, annot)) = E_aux (map_exp_annot_aux f exp, f annot)
and map_exp_annot_aux f = function
| E_block xs -> E_block (List.map (map_exp_annot f) xs)
diff --git a/src/ast_util.mli b/src/ast_util.mli
index 94f3f0cc..b0ccb7b8 100644
--- a/src/ast_util.mli
+++ b/src/ast_util.mli
@@ -46,6 +46,10 @@ open Ast
val mk_nc : n_constraint_aux -> n_constraint
val mk_nexp : nexp_aux -> nexp
+val mk_exp : unit exp_aux -> unit exp
+val mk_lit : lit_aux -> lit
+
+val unaux_exp : 'a exp -> 'a exp_aux
(* Functions to map over the annotations in sub-expressions *)
val map_exp_annot : ('a annot -> 'b annot) -> 'a exp -> 'b exp
diff --git a/src/gen_lib/sail_values.lem b/src/gen_lib/sail_values.lem
index 3616897a..f994ae22 100644
--- a/src/gen_lib/sail_values.lem
+++ b/src/gen_lib/sail_values.lem
@@ -887,6 +887,14 @@ type register =
| UndefinedRegister of integer (* length *)
| RegisterPair of register * register
+type register_ref 'regstate 'a =
+ <| read_from : 'regstate -> 'a;
+ write_to : 'regstate -> 'a -> 'regstate |>
+
+type field_ref 'regtype 'a =
+ <| get_field : 'regtype -> 'a;
+ set_field : 'regtype -> 'a -> 'regtype |>
+
let name_of_reg = function
| Register name _ _ _ _ -> name
| UndefinedRegister _ -> failwith "name_of_reg UndefinedRegister"
diff --git a/src/gen_lib/state.lem b/src/gen_lib/state.lem
index 2e11e8a9..3cbcd4c8 100644
--- a/src/gen_lib/state.lem
+++ b/src/gen_lib/state.lem
@@ -6,53 +6,54 @@ open import Sail_values
type memstate = map integer memory_byte
type tagstate = map integer bitU
-type regstate = map string (vector bitU)
+(* type regstate = map string (vector bitU) *)
-type sequential_state = <| regstate : regstate;
- memstate : memstate;
- tagstate : tagstate;
- write_ea : maybe (write_kind * integer * integer);
- last_exclusive_operation_was_load : bool|>
+type sequential_state 'regs =
+ <| regstate : 'regs;
+ memstate : memstate;
+ tagstate : tagstate;
+ write_ea : maybe (write_kind * integer * integer);
+ last_exclusive_operation_was_load : bool|>
(* State, nondeterminism and exception monad with result type 'a
and exception type 'e. *)
-type ME 'a 'e = sequential_state -> list ((either 'a 'e) * sequential_state)
+type ME 'regs 'a 'e = sequential_state 'regs -> list ((either 'a 'e) * sequential_state 'regs)
(* Most of the time, we don't distinguish between different types of exceptions *)
-type M 'a = ME 'a unit
+type M 'regs 'a = ME 'regs 'a unit
(* For early return, we abuse exceptions by throwing and catching
the return value. The exception type is "maybe 'r", where "Nothing"
represents a proper exception and "Just r" an early return of value "r". *)
-type MR 'a 'r = ME 'a (maybe 'r)
+type MR 'regs 'a 'r = ME 'regs 'a (maybe 'r)
-val liftR : forall 'a 'r. M 'a -> MR 'a 'r
+val liftR : forall 'a 'r 'regs. M 'regs 'a -> MR 'regs 'a 'r
let liftR m s = List.map (function
| (Left a, s') -> (Left a, s')
| (Right (), s') -> (Right Nothing, s')
end) (m s)
-val return : forall 'a 'e. 'a -> ME 'a 'e
+val return : forall 'regs 'a 'e. 'a -> ME 'regs 'a 'e
let return a s = [(Left a,s)]
-val bind : forall 'a 'b 'e. ME 'a 'e -> ('a -> ME 'b 'e) -> ME 'b 'e
-let bind m f (s : sequential_state) =
+val bind : forall 'regs 'a 'b 'e. ME 'regs 'a 'e -> ('a -> ME 'regs 'b 'e) -> ME 'regs 'b 'e
+let bind m f (s : sequential_state 'regs) =
List.concatMap (function
| (Left a, s') -> f a s'
| (Right e, s') -> [(Right e, s')]
end) (m s)
let inline (>>=) = bind
-val (>>): forall 'b 'e. ME unit 'e -> ME 'b 'e -> ME 'b 'e
+val (>>): forall 'regs 'b 'e. ME 'regs unit 'e -> ME 'regs 'b 'e -> ME 'regs 'b 'e
let inline (>>) m n = m >>= fun _ -> n
-val exit : forall 'e 'a. 'e -> M 'a
+val exit : forall 'regs 'e 'a. 'e -> M 'regs 'a
let exit _ s = [(Right (), s)]
-val early_return : forall 'r. 'r -> MR unit 'r
+val early_return : forall 'regs 'r. 'r -> MR 'regs unit 'r
let early_return r s = [(Right (Just r), s)]
-val catch_early_return : forall 'a 'r. MR 'a 'a -> M 'a
+val catch_early_return : forall 'regs 'a 'r. MR 'regs 'a 'a -> M 'regs 'a
let catch_early_return m s =
List.map
(function
@@ -66,15 +67,15 @@ let rec range i j =
if i = j then [i]
else i :: range (i+1) j
-val get_reg : sequential_state -> string -> vector bitU
-let get_reg state reg = Map_extra.find reg state.regstate
+val get_reg : forall 'regs 'a. sequential_state 'regs -> register_ref 'regs 'a -> 'a
+let get_reg state reg = reg.read_from state.regstate
-val set_reg : sequential_state -> string -> vector bitU -> sequential_state
-let set_reg state reg bitv =
- <| state with regstate = Map.insert reg bitv state.regstate |>
+val set_reg : forall 'regs 'a. sequential_state 'regs -> register_ref 'regs 'a -> 'a -> sequential_state 'regs
+let set_reg state reg v =
+ <| state with regstate = reg.write_to state.regstate v |>
-val read_mem : forall 'a 'b. Size 'b => bool -> read_kind -> bitvector 'a -> integer -> M (bitvector 'b)
+val read_mem : forall 'regs 'a 'b. Size 'b => bool -> read_kind -> bitvector 'a -> integer -> M 'regs (bitvector 'b)
let read_mem dir read_kind addr sz state =
let addr = unsigned addr in
let addrs = range addr (addr+sz-1) in
@@ -96,7 +97,7 @@ let read_mem dir read_kind addr sz state =
(* caps are aligned at 32 bytes *)
let cap_alignment = (32 : integer)
-val read_tag : forall 'a. bool -> read_kind -> bitvector 'a -> M bitU
+val read_tag : forall 'regs 'a. bool -> read_kind -> bitvector 'a -> M 'regs bitU
let read_tag dir read_kind addr state =
let addr = (unsigned addr) / cap_alignment in
let tag = match (Map.lookup addr state.tagstate) with
@@ -117,18 +118,18 @@ let read_tag dir read_kind addr state =
then [(Left tag, <| state with last_exclusive_operation_was_load = true |>)]
else [(Left tag, state)]
-val excl_result : unit -> M bool
+val excl_result : forall 'regs. unit -> M 'regs bool
let excl_result () state =
let success =
(Left true, <| state with last_exclusive_operation_was_load = false |>) in
(Left false, state) :: if state.last_exclusive_operation_was_load then [success] else []
-val write_mem_ea : forall 'a. write_kind -> bitvector 'a -> integer -> M unit
+val write_mem_ea : forall 'regs 'a. write_kind -> bitvector 'a -> integer -> M 'regs unit
let write_mem_ea write_kind addr sz state =
let addr = unsigned addr in
[(Left (), <| state with write_ea = Just (write_kind,addr,sz) |>)]
-val write_mem_val : forall 'b. bitvector 'b -> M bool
+val write_mem_val : forall 'regs 'b. bitvector 'b -> M 'regs bool
let write_mem_val v state =
let (write_kind,addr,sz) = match state.write_ea with
| Nothing -> failwith "write ea has not been announced yet"
@@ -140,7 +141,7 @@ let write_mem_val v state =
state.memstate addresses_with_value in
[(Left true, <| state with memstate = memstate |>)]
-val write_tag : bitU -> M bool
+val write_tag : forall 'regs. bitU -> M 'regs bool
let write_tag t state =
let (write_kind,addr,sz) = match state.write_ea with
| Nothing -> failwith "write ea has not been announced yet"
@@ -149,10 +150,10 @@ let write_tag t state =
let tagstate = Map.insert taddr t state.tagstate in
[(Left true, <| state with tagstate = tagstate |>)]
-val read_reg : forall 'a. Size 'a => register -> M (bitvector 'a)
+val read_reg : forall 'regs 'a. register_ref 'regs 'a -> M 'regs 'a
let read_reg reg state =
- let v = get_reg state (name_of_reg reg) in
- [(Left (vec_to_bvec v),state)]
+ let v = reg.read_from state.regstate in
+ [(Left v,state)]
(*let read_reg_range reg i j state =
let v = slice (get_reg state (name_of_reg reg)) i j in
[(Left (vec_to_bvec v),state)]
@@ -168,41 +169,47 @@ let read_reg_bitfield reg regfield =
let reg_deref = read_reg
-val write_reg : forall 'a. Size 'a => register -> bitvector 'a -> M unit
+val write_reg : forall 'regs 'a. register_ref 'regs 'a -> 'a -> M 'regs unit
let write_reg reg v state =
- [(Left (), set_reg state (name_of_reg reg) (bvec_to_vec v))]
+ [(Left (), <| state with regstate = reg.write_to state.regstate v |>)]
+val update_reg : forall 'regs 'a 'b. register_ref 'regs 'a -> ('a -> 'b -> 'a) -> 'b -> M 'regs unit
+let update_reg reg f v state =
+ let current_value = get_reg state reg in
+ let new_value = f current_value v in
+ [(Left (), set_reg state reg new_value)]
let write_reg_range reg i j v state =
- let current_value = get_reg state (name_of_reg reg) in
- let new_value = update current_value i j (bvec_to_vec v) in
- [(Left (), set_reg state (name_of_reg reg) new_value)]
+ let current_value = get_reg state reg in
+ let new_value = bvupdate current_value i j v in
+ [(Left (), set_reg state reg new_value)]
let write_reg_bit reg i bit state =
- let current_value = get_reg state (name_of_reg reg) in
- let new_value = update_pos current_value i bit in
- [(Left (), set_reg state (name_of_reg reg) new_value)]
+ let current_value = get_reg state reg in
+ let new_value = bvupdate_pos current_value i bit in
+ [(Left (), set_reg state reg new_value)]
let write_reg_field reg regfield =
- let (i,j) = register_field_indices reg regfield in
- write_reg_range reg i j
-let write_reg_bitfield reg regfield =
- let (i,_) = register_field_indices reg regfield in
- write_reg_bit reg i
-let write_reg_field_range reg regfield i j v state =
- let (i0,j0) = register_field_indices reg regfield in
- let current_value = get_reg state (name_of_reg reg) in
- let current_field_value = slice current_value i0 j0 in
- let new_field_value = update current_field_value i j (bvec_to_vec v) in
- let new_value = update current_value i j new_field_value in
- [(Left (), set_reg state (name_of_reg reg) new_value)]
-
-
-val barrier : barrier_kind -> M unit
+ update_reg reg regfield.set_field
+let write_reg_field_range reg regfield i j =
+ let upd regval v =
+ let current_field_value = regfield.get_field regval in
+ let new_field_value = bvupdate current_field_value i j v in
+ regfield.set_field regval new_field_value in
+ update_reg reg upd
+let write_reg_field_bit reg regfield i =
+ let upd regval v =
+ let current_field_value = regfield.get_field regval in
+ let new_field_value = bvupdate_pos current_field_value i v in
+ regfield.set_field regval new_field_value in
+ update_reg reg upd
+
+
+val barrier : forall 'regs. barrier_kind -> M 'regs unit
let barrier _ = return ()
-val footprint : M unit
+val footprint : forall 'regs. M 'regs unit
let footprint = return ()
-val foreachM_inc : forall 'vars 'e. (integer * integer * integer) -> 'vars ->
- (integer -> 'vars -> ME 'vars 'e) -> ME 'vars 'e
+val foreachM_inc : forall 'regs 'vars 'e. (integer * integer * integer) -> 'vars ->
+ (integer -> 'vars -> ME 'regs 'vars 'e) -> ME 'regs 'vars 'e
let rec foreachM_inc (i,stop,by) vars body =
if i <= stop
then
@@ -211,8 +218,8 @@ let rec foreachM_inc (i,stop,by) vars body =
else return vars
-val foreachM_dec : forall 'vars 'e. (integer * integer * integer) -> 'vars ->
- (integer -> 'vars -> ME 'vars 'e) -> ME 'vars 'e
+val foreachM_dec : forall 'regs 'vars 'e. (integer * integer * integer) -> 'vars ->
+ (integer -> 'vars -> ME 'regs 'vars 'e) -> ME 'regs 'vars 'e
let rec foreachM_dec (i,stop,by) vars body =
if i >= stop
then
@@ -220,7 +227,7 @@ let rec foreachM_dec (i,stop,by) vars body =
foreachM_dec (i - by,stop,by) vars body
else return vars
-let write_two_regs r1 r2 bvec state =
+(*let write_two_regs r1 r2 bvec state =
let vec = bvec_to_vec bvec in
let is_inc =
let is_inc_r1 = is_inc_of_reg r1 in
@@ -242,4 +249,4 @@ let write_two_regs r1 r2 bvec state =
else slice vec (start_vec - size_r1) (start_vec - size_vec) in
let state1 = set_reg state (name_of_reg r1) r1_v in
let state2 = set_reg state1 (name_of_reg r2) r2_v in
- [(Left (), state2)]
+ [(Left (), state2)]*)
diff --git a/src/parser.mly b/src/parser.mly
index 12e64141..10241137 100644
--- a/src/parser.mly
+++ b/src/parser.mly
@@ -501,6 +501,8 @@ atomic_pat:
{ ploc (P_typ($2,$4)) }
| id
{ ploc (P_app($1,[])) }
+ | tyvar
+ { ploc (P_var $1) }
| Lcurly fpats Rcurly
{ ploc (P_record((fst $2, snd $2))) }
| Lsquare comma_pats Rsquare
diff --git a/src/pretty_print.mli b/src/pretty_print.mli
index 24816206..37de5241 100644
--- a/src/pretty_print.mli
+++ b/src/pretty_print.mli
@@ -52,4 +52,4 @@ val pat_to_string : 'a pat -> string
val pp_lem_defs : Format.formatter -> tannot defs -> unit
val pp_defs_ocaml : out_channel -> tannot defs -> string -> string list -> unit
-val pp_defs_lem : (out_channel * string list) -> (out_channel * string list) -> (out_channel * string list) -> tannot defs -> string -> unit
+val pp_defs_lem : (out_channel * string list) -> (out_channel * string list) -> (out_channel * string list) -> (out_channel * string list) -> tannot defs -> string -> unit
diff --git a/src/pretty_print_lem.ml b/src/pretty_print_lem.ml
index 9b66331b..2971081e 100644
--- a/src/pretty_print_lem.ml
+++ b/src/pretty_print_lem.ml
@@ -140,6 +140,12 @@ let is_regtyp (Typ_aux (typ, _)) env = match typ with
| Typ_id(id) when Env.is_regtyp id env -> true
| _ -> false
+let doc_nexp_lem (Nexp_aux (nexp, l) as full_nexp) = match nexp with
+ | Nexp_constant i -> string ("ty" ^ string_of_int i)
+ | Nexp_var v -> string (string_of_kid v)
+ | _ -> raise (Reporting_basic.err_unreachable l
+ ("cannot pretty-print non-atomic nexp \"" ^ string_of_nexp full_nexp ^ "\""))
+
let doc_typ_lem, doc_atomic_typ_lem =
(* following the structure of parser for precedence *)
let rec typ regtypes ty = fn_typ regtypes true ty
@@ -168,17 +174,19 @@ let doc_typ_lem, doc_atomic_typ_lem =
Typ_arg_aux (Typ_arg_typ elem_typ, _)]) ->
let tpp = match elem_typ with
| Typ_aux (Typ_id (Id_aux (Id "bit",_)),_) ->
- (match simplify_nexp m with
+ string "bitvector " ^^ doc_nexp_lem (simplify_nexp m)
+ (* (match simplify_nexp m with
| (Nexp_aux(Nexp_constant i,_)) -> string "bitvector ty" ^^ doc_int i
| (Nexp_aux(Nexp_var _, _)) -> separate space [string "bitvector"; doc_nexp m]
| _ -> raise (Reporting_basic.err_unreachable l
- "cannot pretty-print bitvector type with non-constant length"))
+ "cannot pretty-print bitvector type with non-constant length")) *)
| _ -> string "vector" ^^ space ^^ typ regtypes elem_typ in
if atyp_needed then parens tpp else tpp
| Typ_app(Id_aux (Id "register", _), [Typ_arg_aux (Typ_arg_typ etyp, _)]) ->
(* TODO: Better distinguish register names and contents? *)
(* fn_typ regtypes atyp_needed etyp *)
- (string "register")
+ let tpp = (string "register_ref regstate " ^^ typ regtypes etyp) in
+ if atyp_needed then parens tpp else tpp
| Typ_app(Id_aux (Id "range", _),_) ->
(string "integer")
| Typ_app(Id_aux (Id "implicit", _),_) ->
@@ -206,7 +214,7 @@ let doc_typ_lem, doc_atomic_typ_lem =
if atyp_needed then parens tpp else tpp
and doc_typ_arg_lem regtypes (Typ_arg_aux(t,_)) = match t with
| Typ_arg_typ t -> app_typ regtypes true t
- | Typ_arg_nexp n -> empty
+ | Typ_arg_nexp n -> doc_nexp_lem (simplify_nexp n)
| Typ_arg_order o -> empty
in typ', atomic_typ
@@ -233,10 +241,10 @@ and contains_t_arg_pp_var (Typ_arg_aux (targ, _)) = match targ with
| _ -> false
let doc_tannot_lem regtypes eff typ =
- if contains_t_pp_var typ then empty
- else
+ (* if contains_t_pp_var typ then empty
+ else *)
let ta = doc_typ_lem regtypes typ in
- if eff then string " : M " ^^ parens ta
+ if eff then string " : _M " ^^ parens ta
else string " : " ^^ ta
(* doc_lit_lem gets as an additional parameter the type information from the
@@ -345,6 +353,13 @@ let contains_early_return exp =
{ (Rewriter.compute_exp_alg false (||))
with e_return = (fun (_, r) -> (true, E_return r)) } exp)
+let typ_id_of (Typ_aux (typ, l)) = match typ with
+ | Typ_id id -> id
+ | Typ_app (register, [Typ_arg_aux (Typ_arg_typ (Typ_aux (Typ_id id, _)), _)])
+ when string_of_id register = "register" -> id
+ | Typ_app (id, _) -> id
+ | _ -> raise (Reporting_basic.err_unreachable l "failed to get type id")
+
let prefix_recordtype = true
let report = Reporting_basic.err_unreachable
let doc_exp_lem, doc_let_lem =
@@ -364,14 +379,18 @@ let doc_exp_lem, doc_let_lem =
(match le_act (*, t, tag*) with
| LEXP_vector_range (le,e2,e3) ->
(match le with
- | LEXP_aux (LEXP_field (le,id), lannot) ->
- if is_bit_typ (typ_of_annot lannot) then
+ | LEXP_aux (LEXP_field ((LEXP_aux (_, lannot) as le),id), fannot) ->
+ if is_bit_typ (typ_of_annot fannot) then
raise (report l "indexing a register's (single bit) bitfield not supported")
else
+ let field_ref =
+ doc_id_lem (typ_id_of (typ_of_annot lannot)) ^^
+ underscore ^^
+ doc_id_lem id in
liftR ((prefix 2 1)
(string "write_reg_field_range")
(align (doc_lexp_deref_lem regtypes early_ret le ^^ space^^
- string_lit (doc_id_lem id) ^/^ expY e2 ^/^ expY e3 ^/^ expY e)))
+ field_ref ^/^ expY e2 ^/^ expY e3 ^/^ expY e)))
| _ ->
liftR ((prefix 2 1)
(string "write_reg_range")
@@ -379,27 +398,35 @@ let doc_exp_lem, doc_let_lem =
)
| LEXP_vector (le,e2) when is_bit_typ t ->
(match le with
- | LEXP_aux (LEXP_field (le,id), lannot) ->
- if is_bit_typ (typ_of_annot lannot) then
+ | LEXP_aux (LEXP_field ((LEXP_aux (_, lannot) as le),id), fannot) ->
+ if is_bit_typ (typ_of_annot fannot) then
raise (report l "indexing a register's (single bit) bitfield not supported")
else
+ let field_ref =
+ doc_id_lem (typ_id_of (typ_of_annot lannot)) ^^
+ underscore ^^
+ doc_id_lem id in
liftR ((prefix 2 1)
(string "write_reg_field_bit")
- (align (doc_lexp_deref_lem regtypes early_ret le ^^ space ^^ doc_id_lem id ^/^ expY e2 ^/^ expY e)))
+ (align (doc_lexp_deref_lem regtypes early_ret le ^^ space ^^ field_ref ^/^ expY e2 ^/^ expY e)))
| _ ->
liftR ((prefix 2 1)
(string "write_reg_bit")
(doc_lexp_deref_lem regtypes early_ret le ^^ space ^^ expY e2 ^/^ expY e))
)
- | LEXP_field (le,id) when is_bit_typ t ->
+ (* | LEXP_field (le,id) when is_bit_typ t ->
liftR ((prefix 2 1)
(string "write_reg_bitfield")
- (doc_lexp_deref_lem regtypes early_ret le ^^ space ^^ string_lit(doc_id_lem id) ^/^ expY e))
- | LEXP_field (le,id) ->
+ (doc_lexp_deref_lem regtypes early_ret le ^^ space ^^ string_lit(doc_id_lem id) ^/^ expY e)) *)
+ | LEXP_field ((LEXP_aux (_, lannot) as le),id) ->
+ let field_ref =
+ doc_id_lem (typ_id_of (typ_of_annot lannot)) ^^
+ underscore ^^
+ doc_id_lem id in
liftR ((prefix 2 1)
(string "write_reg_field")
(doc_lexp_deref_lem regtypes early_ret le ^^ space ^^
- string_lit(doc_id_lem id) ^/^ expY e))
+ field_ref ^/^ expY e))
(* | (LEXP_id id | LEXP_cast (_,id)), t, Alias alias_info ->
(match alias_info with
| Alias_field(reg,field) ->
@@ -561,8 +588,7 @@ let doc_exp_lem, doc_let_lem =
when Env.is_regtyp tid env ->
let t = (* Env.base_typ_of (env_of full_exp) *) (typ_of full_exp) in
let eff = effect_of full_exp in
- let field_f = string "get" ^^ underscore ^^
- doc_id_lem tid ^^ underscore ^^ doc_id_lem id in
+ let field_f = doc_id_lem tid ^^ underscore ^^ doc_id_lem id ^^ dot ^^ string "get_field" in
let (ta,aexp_needed) =
if contains_bitvector_typ t && not (contains_t_pp_var t)
then (doc_tannot_lem regtypes (effectful eff) t, true)
@@ -669,7 +695,8 @@ let doc_exp_lem, doc_let_lem =
(doc_fexp regtypes early_ret recordtyp) fexps)) ^^ space) in
if aexp_needed then parens epp else epp
| E_record_update(e,(FES_aux(FES_Fexps(fexps,_),_))) ->
- let recordtyp = match annot with
+ let (E_aux (_, (_, eannot))) = e in
+ let recordtyp = match eannot with
| Some (env, Typ_aux (Typ_id tid,_), _) when Env.is_record tid env ->
tid
| _ -> raise (report l "cannot get record type") in
@@ -984,15 +1011,41 @@ let doc_typdef_lem regtypes (TD_aux(td, (l, _))) = match td with
doc_op equals (concat [string "type"; space; doc_id_lem_type id])
(doc_typschm_lem regtypes false typschm)
| TD_record(id,nm,typq,fs,_) ->
- let f_pp (typ,fid) =
- let fname = if prefix_recordtype
- then concat [doc_id_lem id;string "_";doc_id_lem_type fid;]
- else doc_id_lem_type fid in
- concat [fname;space;colon;space;doc_typ_lem regtypes typ; semi] in
- let fs_doc = group (separate_map (break 1) f_pp fs) in
+ let fname fid = if prefix_recordtype
+ then concat [doc_id_lem id;string "_";doc_id_lem_type fid;]
+ else doc_id_lem_type fid in
+ let f_pp (typ,fid) =
+ concat [fname fid;space;colon;space;doc_typ_lem regtypes typ; semi] in
+ let rectyp = match typq with
+ | TypQ_aux (TypQ_tq qs, _) ->
+ let quant_item = function
+ | QI_aux (QI_id (KOpt_aux (KOpt_none kid, _)), l)
+ | QI_aux (QI_id (KOpt_aux (KOpt_kind (_, kid), _)), l) ->
+ [Typ_arg_aux (Typ_arg_nexp (Nexp_aux (Nexp_var kid, l)), l)]
+ | _ -> [] in
+ let targs = List.concat (List.map quant_item qs) in
+ mk_typ (Typ_app (id, targs))
+ | TypQ_aux (TypQ_no_forall, _) -> mk_id_typ id in
+ let fs_doc = group (separate_map (break 1) f_pp fs) in
+ let doc_field (ftyp, fid) =
+ let reftyp =
+ mk_typ (Typ_app (Id_aux (Id "field_ref", Parse_ast.Unknown),
+ [mk_typ_arg (Typ_arg_typ rectyp);
+ mk_typ_arg (Typ_arg_typ ftyp)])) in
+ let rfannot = doc_tannot_lem regtypes false reftyp in
+ let get, set =
+ string "rec_val" ^^ dot ^^ fname fid,
+ anglebars (space ^^ string "rec_val with " ^^
+ (doc_op equals (fname fid) (string "v")) ^^ space) in
doc_op equals
- (concat [string "type"; space; doc_id_lem_type id;])
- ((*doc_typquant_lem typq*) (anglebars (space ^^ align fs_doc ^^ space)))
+ (concat [string "let "; parens (concat [doc_id_lem id; underscore; doc_id_lem fid; rfannot])])
+ (anglebars (concat [space;
+ doc_op equals (string "get_field") (parens (doc_op arrow (string "fun rec_val") get)); semi_sp;
+ doc_op equals (string "set_field") (parens (doc_op arrow (string "fun rec_val v") set)); space])) in
+ doc_op equals
+ (separate space [string "type"; doc_id_lem_type id; doc_typquant_items_lem typq])
+ ((*doc_typquant_lem typq*) (anglebars (space ^^ align fs_doc ^^ space))) ^^ hardline ^^
+ separate_map hardline doc_field fs
| TD_variant(id,nm,typq,ar,_) ->
(match id with
| Id_aux ((Id "read_kind"),_) -> empty
@@ -1196,27 +1249,38 @@ let doc_typdef_lem regtypes (TD_aux(td, (l, _))) = match td with
| BF_aux (BF_single i, _) -> (i, i)
| BF_aux (BF_range (i, j), _) -> (i, j)
| _ -> raise (Reporting_basic.err_unreachable l "unsupported field type") in
+ let fsize = if dir_b then j-i+1 else i-j+1 in
+ let ftyp = vector_typ (nconstant i) (nconstant fsize) ord bit_typ in
+ let reftyp =
+ mk_typ (Typ_app (Id_aux (Id "field_ref", Parse_ast.Unknown),
+ [mk_typ_arg (Typ_arg_typ (mk_id_typ id));
+ mk_typ_arg (Typ_arg_typ ftyp)])) in
+ let rfannot = doc_tannot_lem regtypes false reftyp in
let get, set =
"bitvector_subrange" ^ dir_suffix ^ " (reg, " ^ string_of_int i ^ ", " ^ string_of_int j ^ ")",
"bitvector_update" ^ dir_suffix ^ " (reg, " ^ string_of_int i ^ ", " ^ string_of_int j ^ ", v)" in
doc_op equals
- (concat [string "let get_"; doc_id_lem id; underscore; doc_id_lem fid;
- space; parens (string "reg" ^^ tannot)]) (string get) ^^
- hardline ^^
- doc_op equals
+ (concat [string "let "; parens (concat [doc_id_lem id; underscore; doc_id_lem fid; rfannot])])
+ (concat [
+ space; langlebar; string (" get_field = (fun reg -> " ^ get ^ ");"); hardline;
+ space; space; space; string (" set_field = (fun reg v -> " ^ set ^ ") "); ranglebar])
+ (* string " = <|" (*; parens (string "reg" ^^ tannot) *)]) ^^ hardline ^^
+ string (" get_field = (fun reg -> " ^ get ^ ");") ^^ hardline ^^
+ string (" set_field = (fun reg v -> " ^ set ^") |>") *)
+ (* doc_op equals
(concat [string "let set_"; doc_id_lem id; underscore; doc_id_lem fid;
- space; parens (separate comma_sp [parens (string "reg" ^^ tannot); string "v"])]) (string set)
+ space; parens (separate comma_sp [parens (string "reg" ^^ tannot); string "v"])]) (string set) *)
in
doc_op equals
(concat [string "type";space;doc_id_lem id])
(doc_typ_lem regtypes vtyp)
^^ hardline ^^
- doc_op equals
+ (* doc_op equals
(concat [string "let";space;string "build_";doc_id_lem id;space;string "regname"])
(string "Register" ^^ space ^^
align (separate space [string "regname"; doc_int size; doc_int i1; dir;
break 0 ^^ brackets (align doc_rids)]))
- ^^ hardline ^^
+ ^^ hardline ^^ *)
doc_op equals
(concat [string "let";space;string "cast_";doc_id_lem id;space;string "reg"])
(string "reg")
@@ -1331,7 +1395,8 @@ let rec doc_fundef_lem regtypes (FD_aux(FD_function(r, typa, efa, fcls),fannot))
let doc_dec_lem (DEC_aux (reg, ((l, _) as annot))) =
match reg with
| DEC_reg(typ,id) ->
- let env = env_of_annot annot in
+ empty
+ (* let env = env_of_annot annot in
(match typ with
| Typ_aux (Typ_id idt, _) when Env.is_regtyp idt env ->
separate space [string "let";doc_id_lem id;equals;
@@ -1354,7 +1419,7 @@ let doc_dec_lem (DEC_aux (reg, ((l, _) as annot))) =
else raise (Reporting_basic.err_unreachable l
("can't deal with register type " ^ string_of_typ typ))
else raise (Reporting_basic.err_unreachable l
- ("can't deal with register type " ^ string_of_typ typ)))
+ ("can't deal with register type " ^ string_of_typ typ))) *)
| DEC_alias(id,alspec) -> empty
| DEC_typ_alias(typ,id,alspec) -> empty
@@ -1397,9 +1462,50 @@ let find_regtypes (Defs defs) =
| _ -> acc
) [] defs
-let pp_defs_lem (types_file,types_modules) (prompt_file,prompt_modules) (state_file,state_modules) d top_line =
+let find_registers (Defs defs) =
+ List.fold_left
+ (fun acc def ->
+ match def with
+ | DEF_reg_dec (DEC_aux(DEC_reg (typ, id),_)) -> (typ, id) :: acc
+ | _ -> acc
+ ) [] defs
+
+let doc_regstate_lem regtypes registers =
+ let l = Parse_ast.Unknown in
+ let annot = (l, None) in
+ let regstate = match registers with
+ | [] ->
+ TD_abbrev (
+ Id_aux (Id "regstate", l),
+ Name_sect_aux (Name_sect_none, l),
+ TypSchm_aux (TypSchm_ts (TypQ_aux (TypQ_tq [], l), unit_typ), l))
+ | _ ->
+ TD_record (
+ Id_aux (Id "regstate", l),
+ Name_sect_aux (Name_sect_none, l),
+ TypQ_aux (TypQ_tq [], l),
+ registers,
+ false) in
+ concat [
+ doc_typdef_lem regtypes (TD_aux (regstate, annot)); hardline;
+ hardline;
+ string "type _M 'a = M regstate 'a"
+ ]
+
+let doc_register_refs_lem regtypes registers =
+ let doc_register_ref (typ, id) =
+ let idd = doc_id_lem id in
+ let field = if prefix_recordtype then string "regstate_" ^^ idd else idd in
+ concat [string "let "; idd; string " = <|"; hardline;
+ string " read_from = (fun s -> s."; field; string ");"; hardline;
+ string " write_to = (fun s v -> (<| s with "; field; string " = v |>)) |>"] in
+ separate_map hardline doc_register_ref registers
+
+let pp_defs_lem (types_file,types_modules) (types_seq_file,types_seq_modules) (prompt_file,prompt_modules) (state_file,state_modules) d top_line =
let regtypes = find_regtypes d in
let (typdefs,valdefs) = doc_defs_lem regtypes d in
+ let regstate_def = doc_regstate_lem regtypes (find_registers d) in
+ let register_refs = doc_register_refs_lem regtypes (find_registers d) in
(print types_file)
(concat
[string "(*" ^^ (string top_line) ^^ string "*)";hardline;
@@ -1418,6 +1524,30 @@ let pp_defs_lem (types_file,types_modules) (prompt_file,prompt_modules) (state_f
hardline]
else empty;
typdefs]);
+ (print types_seq_file)
+ (concat
+ [string "(*" ^^ (string top_line) ^^ string "*)";hardline;
+ (separate_map hardline)
+ (fun lib -> separate space [string "open import";string lib]) types_seq_modules;hardline;
+ if !print_to_from_interp_value
+ then
+ concat
+ [(separate_map hardline)
+ (fun lib -> separate space [string " import";string lib]) ["Interp";"Interp_ast"];
+ string "open import Deep_shallow_convert";
+ hardline;
+ hardline;
+ string "module SI = Interp"; hardline;
+ string "module SIA = Interp_ast"; hardline;
+ hardline]
+ else empty;
+ typdefs;
+ hardline;
+ hardline;
+ regstate_def;
+ hardline;
+ hardline;
+ register_refs]);
(print prompt_file)
(concat
[string "(*" ^^ (string top_line) ^^ string "*)";hardline;
diff --git a/src/pretty_print_sail.ml b/src/pretty_print_sail.ml
index bb1d7357..b6d33d6d 100644
--- a/src/pretty_print_sail.ml
+++ b/src/pretty_print_sail.ml
@@ -109,6 +109,7 @@ let doc_pat, doc_atomic_pat =
| P_lit lit -> doc_lit lit
| P_wild -> underscore
| P_id id -> doc_id id
+ | P_var kid -> doc_var kid
| P_as(p,id) -> parens (separate space [pat p; string "as"; doc_id id])
| P_typ(typ,p) -> separate space [parens (doc_typ typ); atomic_pat p]
| P_app(id,[]) -> doc_id id
diff --git a/src/process_file.ml b/src/process_file.ml
index 6d4bcea0..1c133ce8 100644
--- a/src/process_file.ml
+++ b/src/process_file.ml
@@ -134,6 +134,31 @@ let close_output_with_check (o, temp_file_name, file_name) =
let generated_line f =
Printf.sprintf "Generated by Sail from %s." f
+let output_lem filename libs libs_seq defs =
+ let generated_line = generated_line filename in
+ let types_module = (filename ^ "_embed_types") in
+ let types_module_sequential = (filename ^ "_embed_types_sequential") in
+ let ((ot,_, _) as ext_ot) =
+ open_output_with_check_unformatted (filename ^ "_embed_types.lem") in
+ let ((ots,_, _) as ext_ots) =
+ open_output_with_check_unformatted (filename ^ "_embed_types_sequential.lem") in
+ let ((o,_, _) as ext_o) =
+ open_output_with_check_unformatted (filename ^ "_embed.lem") in
+ let ((os,_, _) as ext_os) =
+ open_output_with_check_unformatted (filename ^ "_embed_sequential.lem") in
+ (Pretty_print.pp_defs_lem
+ (ot,["Pervasives_extra";"Sail_impl_base";"Sail_values";"Prompt"])
+ (ots,["Pervasives_extra";"Sail_impl_base";"Sail_values";"State"])
+ (o,["Pervasives_extra";"Sail_impl_base";"Sail_values";"Prompt";
+ String.capitalize types_module] @ libs)
+ (os,["Pervasives_extra";"Sail_impl_base";"Sail_values";"State";
+ String.capitalize types_module_sequential] @ libs_seq)
+ defs generated_line);
+ close_output_with_check ext_ot;
+ close_output_with_check ext_ots;
+ close_output_with_check ext_o;
+ close_output_with_check ext_os
+
let output1 libpath out_arg filename defs =
let f' = Filename.basename (Filename.chop_extension filename) in
match out_arg with
@@ -180,43 +205,9 @@ let output1 libpath out_arg filename defs =
close_output_with_check ext_o
end
| Lem_out None ->
- let generated_line = generated_line filename in
- let types_module = (f' ^ "_embed_types") in
- let ((o,_, _) as ext_o) =
- open_output_with_check_unformatted (f' ^ "_embed_types.lem") in
- let ((o',_, _) as ext_o') =
- open_output_with_check_unformatted (f' ^ "_embed.lem") in
- let ((o'',_, _) as ext_o'') =
- open_output_with_check_unformatted (f' ^ "_embed_sequential.lem") in
- (Pretty_print.pp_defs_lem
- (o,["Pervasives_extra";"Sail_impl_base";"Sail_values"])
- (o',["Pervasives_extra";"Sail_impl_base";"Prompt";"Sail_values";
- String.capitalize types_module])
- (o'',["Pervasives_extra";"Sail_impl_base";"State";"Sail_values";
- String.capitalize types_module])
- defs generated_line);
- close_output_with_check ext_o;
- close_output_with_check ext_o';
- close_output_with_check ext_o'';
+ output_lem f' [] [] defs
| Lem_out (Some lib) ->
- let generated_line = generated_line filename in
- let types_module = (f' ^ "_embed_types") in
- let ((o,_, _) as ext_o) =
- open_output_with_check_unformatted (f' ^ "_embed_types.lem") in
- let ((o',_, _) as ext_o') =
- open_output_with_check_unformatted (f' ^ "_embed.lem") in
- let ((o'',_, _) as ext_o'') =
- open_output_with_check_unformatted (f' ^ "_embed_sequential.lem") in
- (Pretty_print.pp_defs_lem
- (o,["Pervasives_extra";"Sail_impl_base";"Sail_values"])
- (o',["Pervasives_extra";"Sail_impl_base";"Prompt";
- "Sail_values";String.capitalize types_module;lib])
- (o'',["Pervasives_extra";"Sail_impl_base";"State";
- "Sail_values";String.capitalize types_module;lib ^ "_sequential"])
- defs generated_line);
- close_output_with_check ext_o;
- close_output_with_check ext_o';
- close_output_with_check ext_o''
+ output_lem f' [lib] [lib ^ "_sequential"] defs
| Ocaml_out None ->
let ((o,temp_file_name, _) as ext_o) = open_output_with_check_unformatted (f' ^ ".ml") in
begin Pretty_print.pp_defs_ocaml o defs (generated_line filename) ["Big_int_Z";"Sail_values"];
@@ -234,6 +225,8 @@ let output libpath out_arg files =
files
let rewrite_step defs rewriter =
+ (* print_endline "=============================== REWRITE STEP";
+ Pretty_print.pp_defs stdout defs; *)
let defs = rewriter defs in
let _ = match !(opt_ddump_rewrite_ast) with
| Some (f, i) ->
diff --git a/src/rewriter.ml b/src/rewriter.ml
index 8da8aacf..ef4a209c 100644
--- a/src/rewriter.ml
+++ b/src/rewriter.ml
@@ -355,7 +355,7 @@ let rewrite_pat rewriters (P_aux (pat,(l,annot))) =
let ps = List.map (fun p -> P_aux (P_lit p, simple_annot l bit_typ))
(vector_string_to_bit_list l lit) in
rewrap (P_vector ps)
- | P_lit _ | P_wild | P_id _ -> rewrap pat
+ | P_lit _ | P_wild | P_id _ | P_var _ -> rewrap pat
| P_as(pat,id) -> rewrap (P_as(rewrite pat, id))
| P_typ(typ,pat) -> rewrap (P_typ(typ, rewrite pat))
| P_app(id ,pats) -> rewrap (P_app(id, List.map rewrite pats))
@@ -629,6 +629,7 @@ type ('a,'pat,'pat_aux,'fpat,'fpat_aux) pat_alg =
; p_as : 'pat * id -> 'pat_aux
; p_typ : Ast.typ * 'pat -> 'pat_aux
; p_id : id -> 'pat_aux
+ ; p_var : kid -> 'pat_aux
; p_app : id * 'pat list -> 'pat_aux
; p_record : 'fpat list * bool -> 'pat_aux
; p_vector : 'pat list -> 'pat_aux
@@ -647,6 +648,7 @@ let rec fold_pat_aux (alg : ('a,'pat,'pat_aux,'fpat,'fpat_aux) pat_alg) : 'a pat
| P_lit lit -> alg.p_lit lit
| P_wild -> alg.p_wild
| P_id id -> alg.p_id id
+ | P_var kid -> alg.p_var kid
| P_as (p,id) -> alg.p_as (fold_pat alg p,id)
| P_typ (typ,p) -> alg.p_typ (typ,fold_pat alg p)
| P_app (id,ps) -> alg.p_app (id,List.map (fold_pat alg) ps)
@@ -676,6 +678,7 @@ let id_pat_alg : ('a,'a pat, 'a pat_aux, 'a fpat, 'a fpat_aux) pat_alg =
; p_as = (fun (pat,id) -> P_as (pat,id))
; p_typ = (fun (typ,pat) -> P_typ (typ,pat))
; p_id = (fun id -> P_id id)
+ ; p_var = (fun kid -> P_var kid)
; p_app = (fun (id,ps) -> P_app (id,ps))
; p_record = (fun (ps,b) -> P_record (ps,b))
; p_vector = (fun ps -> P_vector ps)
@@ -718,6 +721,7 @@ type ('a,'exp,'exp_aux,'lexp,'lexp_aux,'fexp,'fexp_aux,'fexps,'fexps_aux,
; e_let : 'letbind * 'exp -> 'exp_aux
; e_assign : 'lexp * 'exp -> 'exp_aux
; e_sizeof : nexp -> 'exp_aux
+ ; e_constraint : n_constraint -> 'exp_aux
; e_exit : 'exp -> 'exp_aux
; e_return : 'exp -> 'exp_aux
; e_assert : 'exp * 'exp -> 'exp_aux
@@ -786,8 +790,7 @@ let rec fold_exp_aux alg = function
| E_let (letbind,e) -> alg.e_let (fold_letbind alg letbind, fold_exp alg e)
| E_assign (lexp,e) -> alg.e_assign (fold_lexp alg lexp, fold_exp alg e)
| E_sizeof nexp -> alg.e_sizeof nexp
- | E_constraint nc -> raise (Reporting_basic.err_unreachable (Parse_ast.Unknown)
- "E_constraint encountered during rewriting")
+ | E_constraint nc -> alg.e_constraint nc
| E_exit e -> alg.e_exit (fold_exp alg e)
| E_return e -> alg.e_return (fold_exp alg e)
| E_assert(e1,e2) -> alg.e_assert (fold_exp alg e1, fold_exp alg e2)
@@ -860,6 +863,7 @@ let id_exp_alg =
; e_let = (fun (lb,e2) -> E_let (lb,e2))
; e_assign = (fun (lexp,e2) -> E_assign (lexp,e2))
; e_sizeof = (fun nexp -> E_sizeof nexp)
+ ; e_constraint = (fun nc -> E_constraint nc)
; e_exit = (fun e1 -> E_exit (e1))
; e_return = (fun e1 -> E_return e1)
; e_assert = (fun (e1,e2) -> E_assert(e1,e2))
@@ -909,6 +913,7 @@ let compute_pat_alg bot join =
; p_as = (fun ((v,pat),id) -> (v, P_as (pat,id)))
; p_typ = (fun (typ,(v,pat)) -> (v, P_typ (typ,pat)))
; p_id = (fun id -> (bot, P_id id))
+ ; p_var = (fun kid -> (bot, P_var kid))
; p_app = (fun (id,ps) -> split_join (fun ps -> P_app (id,ps)) ps)
; p_record = (fun (ps,b) -> split_join (fun ps -> P_record (ps,b)) ps)
; p_vector = split_join (fun ps -> P_vector ps)
@@ -960,6 +965,7 @@ let compute_exp_alg bot join =
; e_let = (fun ((vl,lb),(v2,e2)) -> (join vl v2, E_let (lb,e2)))
; e_assign = (fun ((vl,lexp),(v2,e2)) -> (join vl v2, E_assign (lexp,e2)))
; e_sizeof = (fun nexp -> (bot, E_sizeof nexp))
+ ; e_constraint = (fun nc -> (bot, E_constraint nc))
; e_exit = (fun (v1,e1) -> (v1, E_exit (e1)))
; e_return = (fun (v1,e1) -> (v1, E_return e1))
; e_assert = (fun ((v1,e1),(v2,e2)) -> (join v1 v2, E_assert(e1,e2)) )
@@ -1074,8 +1080,9 @@ let rewrite_sizeof (Defs defs) =
(* Retrieve instantiation of the type variables of the called function
for the given parameters in the original environment *)
let inst = instantiation_of orig_exp in
+ let inst = KBindings.fold (fun kid uvar b -> KBindings.add (orig_kid kid) uvar b) inst KBindings.empty in
let kid_exp kid = begin
- match KBindings.find kid inst with
+ match KBindings.find (orig_kid kid) inst with
| U_nexp nexp -> E_aux (E_sizeof nexp, simple_annot l (atom_typ nexp))
| _ ->
raise (Reporting_basic.err_unreachable l
@@ -1117,6 +1124,7 @@ let rewrite_sizeof (Defs defs) =
; e_let = (fun ((lb,lb'),(e2,e2')) -> (E_let (lb,e2), E_let (lb',e2')))
; e_assign = (fun ((lexp,lexp'),(e2,e2')) -> (E_assign (lexp,e2), E_assign (lexp',e2')))
; e_sizeof = (fun nexp -> (E_sizeof nexp, E_sizeof nexp))
+ ; e_constraint = (fun nc -> (E_constraint nc, E_constraint nc))
; e_exit = (fun (e1,e1') -> (E_exit (e1), E_exit (e1')))
; e_return = (fun (e1,e1') -> (E_return e1, E_return e1'))
; e_assert = (fun ((e1,e1'),(e2,e2')) -> (E_assert(e1,e2), E_assert(e1',e2')) )
@@ -1277,6 +1285,7 @@ let remove_vector_concat_pat pat =
; p_wild = P_wild
; p_as = (fun (pat,id) -> P_as (pat true,id))
; p_id = (fun id -> P_id id)
+ ; p_var = (fun kid -> P_var kid)
; p_app = (fun (id,ps) -> P_app (id, List.map (fun p -> p false) ps))
; p_record = (fun (fpats,b) -> P_record (fpats, b))
; p_vector = (fun ps -> P_vector (List.map (fun p -> p false) ps))
@@ -1403,6 +1412,7 @@ let remove_vector_concat_pat pat =
; p_as = (fun ((pat,decls),id) -> (P_as (pat,id),decls))
; p_typ = (fun (typ,(pat,decls)) -> (P_typ (typ,pat),decls))
; p_id = (fun id -> (P_id id,[]))
+ ; p_var = (fun kid -> (P_var kid, []))
; p_app = (fun (id,ps) -> let (ps,decls) = List.split ps in
(P_app (id,ps),List.flatten decls))
; p_record = (fun (ps,b) -> let (ps,decls) = List.split ps in
@@ -1775,6 +1785,7 @@ let remove_bitvector_pat pat =
; p_wild = P_wild
; p_as = (fun (pat,id) -> P_as (pat true,id))
; p_id = (fun id -> P_id id)
+ ; p_var = (fun kid -> P_var kid)
; p_app = (fun (id,ps) -> P_app (id, List.map (fun p -> p false) ps))
; p_record = (fun (fpats,b) -> P_record (fpats, b))
; p_vector = (fun ps -> P_vector (List.map (fun p -> p false) ps))
@@ -1923,6 +1934,7 @@ let remove_bitvector_pat pat =
; p_as = (fun ((pat,gdls),id) -> (P_as (pat,id), gdls))
; p_typ = (fun (typ,(pat,gdls)) -> (P_typ (typ,pat), gdls))
; p_id = (fun id -> (P_id id, (None, (fun b -> b), [])))
+ ; p_var = (fun kid -> (P_var kid, (None, (fun b -> b), [])))
; p_app = (fun (id,ps) -> let (ps,gdls) = List.split ps in
(P_app (id,ps), flatten_guards_decls gdls))
; p_record = (fun (ps,b) -> let (ps,gdls) = List.split ps in
@@ -2270,9 +2282,35 @@ let rewrite_defs_early_return =
rewrite_defs_base { rewriters_base with rewrite_fun = rewrite_fun_early_return }
+let rewrite_constraint =
+ let rec rewrite_nc (NC_aux (nc_aux, l)) = mk_exp (rewrite_nc_aux nc_aux)
+ and rewrite_nc_aux = function
+ | NC_bounded_ge (n1, n2) -> E_app_infix (mk_exp (E_sizeof n1), mk_id ">=", mk_exp (E_sizeof n2))
+ | NC_bounded_le (n1, n2) -> E_app_infix (mk_exp (E_sizeof n1), mk_id ">=", mk_exp (E_sizeof n2))
+ | NC_fixed (n1, n2) -> E_app_infix (mk_exp (E_sizeof n1), mk_id "==", mk_exp (E_sizeof n2))
+ | NC_not_equal (n1, n2) -> E_app_infix (mk_exp (E_sizeof n1), mk_id "!=", mk_exp (E_sizeof n2))
+ | NC_and (nc1, nc2) -> E_app_infix (rewrite_nc nc1, mk_id "&", rewrite_nc nc2)
+ | NC_or (nc1, nc2) -> E_app_infix (rewrite_nc nc1, mk_id "|", rewrite_nc nc2)
+ | NC_false -> E_lit (mk_lit L_true)
+ | NC_true -> E_lit (mk_lit L_false)
+ | NC_nat_set_bounded (kid, ints) ->
+ unaux_exp (rewrite_nc (List.fold_left (fun nc int -> nc_or nc (nc_eq (nvar kid) (nconstant int))) nc_true ints))
+ in
+ let rewrite_e_aux (E_aux (e_aux, _) as exp) =
+ match e_aux with
+ | E_constraint nc ->
+ check_exp (env_of exp) (rewrite_nc nc) bool_typ
+ | _ -> exp
+ in
+
+ let rewrite_e_constraint = { id_exp_alg with e_aux = (fun (exp, annot) -> rewrite_e_aux (E_aux (exp, annot))) } in
+
+ rewrite_defs_base { rewriters_base with rewrite_exp = (fun _ -> fold_exp rewrite_e_constraint) }
+
let rewrite_defs_ocaml = [
top_sort_defs;
rewrite_defs_remove_vector_concat;
+ rewrite_constraint;
rewrite_sizeof;
rewrite_defs_exp_lift_assign (* ;
rewrite_defs_separate_numbs *)
@@ -2674,6 +2712,7 @@ let find_updated_vars (E_aux (_,(l,_)) as exp) =
; e_case = (fun (e1,pexps) -> e1 @@ lapp2 pexps)
; e_let = (fun (lb,e2) -> lb @@ e2)
; e_assign = (fun ((ids,acc),e2) -> ([],ids) @@ acc @@ e2)
+ ; e_constraint = (fun nc -> ([],[]))
; e_sizeof = (fun nexp -> ([],[]))
; e_exit = (fun e1 -> ([],[]))
; e_return = (fun e1 -> e1)
diff --git a/src/rewriter.mli b/src/rewriter.mli
index 9dbdee3d..010f1003 100644
--- a/src/rewriter.mli
+++ b/src/rewriter.mli
@@ -66,6 +66,7 @@ type ('a,'pat,'pat_aux,'fpat,'fpat_aux) pat_alg =
; p_as : 'pat * id -> 'pat_aux
; p_typ : Ast.typ * 'pat -> 'pat_aux
; p_id : id -> 'pat_aux
+ ; p_var : kid -> 'pat_aux
; p_app : id * 'pat list -> 'pat_aux
; p_record : 'fpat list * bool -> 'pat_aux
; p_vector : 'pat list -> 'pat_aux
@@ -112,6 +113,7 @@ type ('a,'exp,'exp_aux,'lexp,'lexp_aux,'fexp,'fexp_aux,'fexps,'fexps_aux,
; e_let : 'letbind * 'exp -> 'exp_aux
; e_assign : 'lexp * 'exp -> 'exp_aux
; e_sizeof : nexp -> 'exp_aux
+ ; e_constraint : n_constraint -> 'exp_aux
; e_exit : 'exp -> 'exp_aux
; e_return : 'exp -> 'exp_aux
; e_assert : 'exp * 'exp -> 'exp_aux
diff --git a/src/sail.ml b/src/sail.ml
index cf366d42..dbc0eff4 100644
--- a/src/sail.ml
+++ b/src/sail.ml
@@ -158,6 +158,7 @@ let main() =
else ());
(if !(opt_print_ocaml)
then let ast_ocaml = rewrite_ast_ocaml ast in
+ print_endline "Finished re-writing ocaml";
if !(opt_libs_ocaml) = []
then output "" (Ocaml_out None) [out_name,ast_ocaml]
else output "" (Ocaml_out (Some (List.hd !opt_libs_ocaml))) [out_name,ast_ocaml]
diff --git a/src/type_check.ml b/src/type_check.ml
index ec0b4a58..2fcba97d 100644
--- a/src/type_check.ml
+++ b/src/type_check.ml
@@ -84,6 +84,13 @@ let unaux_nexp (Nexp_aux (nexp, _)) = nexp
let unaux_order (Ord_aux (ord, _)) = ord
let unaux_typ (Typ_aux (typ, _)) = typ
+let orig_kid (Kid_aux (Var v, l) as kid) =
+ try
+ let i = String.rindex v '#' in
+ Kid_aux (Var ("'" ^ String.sub v (i + 1) (String.length v - i - 1)), l)
+ with
+ | Not_found -> kid
+
let mk_typ typ = Typ_aux (typ, Parse_ast.Unknown)
let mk_typ_arg arg = Typ_arg_aux (arg, Parse_ast.Unknown)
let mk_id str = Id_aux (Id str, Parse_ast.Unknown)
@@ -99,6 +106,10 @@ let mk_ord ord_aux = Ord_aux (ord_aux, Parse_ast.Unknown)
let rec nexp_simp (Nexp_aux (nexp, l)) = Nexp_aux (nexp_simp_aux nexp, l)
and nexp_simp_aux = function
+ | Nexp_minus (Nexp_aux (Nexp_sum (Nexp_aux (n1, _), Nexp_aux (Nexp_constant c1, _)), _), Nexp_aux (Nexp_constant c2, _)) when c1 = c2 ->
+ nexp_simp_aux n1
+ | Nexp_sum (Nexp_aux (Nexp_minus (Nexp_aux (n1, _), Nexp_aux (Nexp_constant c1, _)), _), Nexp_aux (Nexp_constant c2, _)) when c1 = c2 ->
+ nexp_simp_aux n1
| Nexp_sum (n1, n2) ->
begin
let (Nexp_aux (n1_simp, _) as n1) = nexp_simp n1 in
@@ -442,7 +453,7 @@ module Env : sig
val enable_casts : t -> t
val add_cast : id -> t -> t
val lookup_id : id -> t -> lvar
- val fresh_kid : t -> kid
+ val fresh_kid : ?kid:kid -> t -> kid
val expand_synonyms : t -> typ -> typ
val base_typ_of : t -> typ -> typ
val empty : t
@@ -493,12 +504,13 @@ end = struct
let counter = ref 0
- let fresh_kid env =
- let fresh = Kid_aux (Var ("'fv" ^ string_of_int !counter), Parse_ast.Unknown) in
+ let fresh_kid ?kid:(kid=mk_kid "") env =
+ let suffix = if Kid.compare kid (mk_kid "") = 0 then "#" else "#" ^ string_of_id (id_of_kid kid) in
+ let fresh = Kid_aux (Var ("'fv" ^ string_of_int !counter ^ suffix), Parse_ast.Unknown) in
incr counter; fresh
let freshen_kid env kid (typq, typ) =
- let fresh = fresh_kid env in
+ let fresh = fresh_kid ~kid:kid env in
(typquant_subst_kid kid fresh typq, typ_subst_kid kid fresh typ)
let freshen_bind env bind =
@@ -1464,7 +1476,7 @@ let rec unify l env typ1 typ2 =
| Typ_arg_nexp n1, Typ_arg_nexp n2 ->
begin
match unify_nexps l env goals (nexp_simp n1) (nexp_simp n2) with
- | Some (kid, unifier) -> KBindings.singleton kid (U_nexp unifier)
+ | Some (kid, unifier) -> KBindings.singleton kid (U_nexp (nexp_simp unifier))
| None -> KBindings.empty
end
| Typ_arg_typ typ1, Typ_arg_typ typ2 -> unify_typ l typ1 typ2
diff --git a/src/type_check.mli b/src/type_check.mli
index d451e4d9..f22a6991 100644
--- a/src/type_check.mli
+++ b/src/type_check.mli
@@ -113,8 +113,10 @@ module Env : sig
val is_union_constructor : id -> t -> bool
- (* Return a fresh kind identifier that doesn't exist in the environment *)
- val fresh_kid : t -> kid
+ (* Return a fresh kind identifier that doesn't exist in the
+ environment. The optional argument bases the new identifer on the
+ old one. *)
+ val fresh_kid : ?kid:kid -> t -> kid
val expand_synonyms : t -> typ -> typ
@@ -138,6 +140,8 @@ end
(* Push all the type variables and constraints from a typquant into an environment *)
val add_typquant : typquant -> Env.t -> Env.t
+val orig_kid : kid -> kid
+
(* Some handy utility functions for constructing types. *)
val mk_typ : typ_aux -> typ
val mk_typ_arg : typ_arg_aux -> typ_arg
@@ -250,6 +254,8 @@ type uvar =
| U_effect of effect
| U_typ of typ
+val string_of_uvar : uvar -> string
+
(* Throws Invalid_argument if the argument is not a E_app expression *)
val instantiation_of : tannot exp -> uvar KBindings.t
diff --git a/test/typecheck/pass/mips400.sail b/test/typecheck/pass/mips400.sail
index 1e8691d9..951f5126 100644
--- a/test/typecheck/pass/mips400.sail
+++ b/test/typecheck/pass/mips400.sail
@@ -1,148 +1,300 @@
(* New typechecker prelude *)
-val cast forall Nat 'n, Nat 'm, Order 'ord. vector<'n,'m,'ord,bit> -> [|0:2**'m - 1|] effect pure unsigned
+val cast forall Num 'n, Num 'm, Order 'ord. vector<'n,'m,'ord,bit> -> [|0:2**'m - 1|] effect pure unsigned
+
+val forall Num 'n, Num 'm, Order 'ord. vector<'n,'m,'ord,bit> -> [|0 - (2**('m - 1)):2**('m - 1) - 1|] effect pure signed
+
+val extern forall Num 'n, Num 'm. [|0:'n|] -> vector<'m - 1,'m,dec,bit> effect pure to_vec = "to_vec_dec"
+
+val extern forall Num 'm. int -> vector<'m - 1,'m,dec,bit> effect pure to_svec = "to_vec_dec"
(* Vector access can't actually be properly polymorphic on vector
direction because of the ranges being different for each type, so
we overload it instead *)
-val forall Nat 'n, Nat 'l, Type 'a, 'l >= 0. (vector<'n,'l,dec,'a>, [|'n - 'l + 1:'n|]) -> 'a effect pure vector_access_dec
-val forall Nat 'n, Nat 'l, Type 'a, 'l >= 0. (vector<'n,'l,inc,'a>, [|'n:'n + 'l - 1|]) -> 'a effect pure vector_access_inc
+val forall Num 'n, Num 'l, Type 'a, 'l >= 0. (vector<'n,'l,dec,'a>, [|'n - 'l + 1:'n|]) -> 'a effect pure vector_access_dec
+val forall Num 'n, Num 'l, Type 'a, 'l >= 0. (vector<'n,'l,inc,'a>, [|'n:'n + 'l - 1|]) -> 'a effect pure vector_access_inc
+val forall Num 'n, Num 'l, 'l >= 0. (vector<'n,'l,dec,bit>, [|'n - 'l + 1:'n|]) -> bit effect pure bitvector_access_dec
+val forall Num 'n, Num 'l, 'l >= 0. (vector<'n,'l,inc,bit>, [|'n:'n + 'l - 1|]) -> bit effect pure bitvector_access_inc
+
+overload vector_access [bitvector_access_inc; bitvector_access_dec; vector_access_inc; vector_access_dec]
(* Type safe vector subrange *)
-val forall Nat 'n, Nat 'l, Nat 'm, Nat 'o, Type 'a, 'l >= 0, 'm <= 'o, 'o <= 'l.
- (vector<'n,'l,inc,'a>, [:'m:], [:'o:]) -> vector<'m,'o - 'm,inc,'a> effect pure vector_subrange_inc
+(* vector_subrange(v, m, o) returns the subvector of v with elements with
+ indices from m up to and *including* o. *)
+val forall Num 'n, Num 'l, Num 'm, Num 'o, Type 'a, 'l >= 0, 'm <= 'o, 'o <= 'l.
+ (vector<'n,'l,inc,'a>, [:'m:], [:'o:]) -> vector<'m,('o - 'm) + 1,inc,'a> effect pure vector_subrange_inc
+
+val forall Num 'n, Num 'l, Num 'm, Num 'o, Type 'a, 'n >= 'm, 'm >= 'o, 'o >= 'n - 'l + 1.
+ (vector<'n,'l,dec,'a>, [:'m:], [:'o:]) -> vector<'m,('m - 'o) + 1,dec,'a> effect pure vector_subrange_dec
-val forall Nat 'n, Nat 'l, Nat 'm, Nat 'o, Type 'a, 'n >= 'm, 'm >= 'o, 'o >= 'n - 'l + 1.
- (vector<'n,'l,dec,'a>, [:'m:], [:'o:]) -> vector<'m,'m - ('o - 1),dec,'a> effect pure vector_subrange_dec
+val forall Num 'n, Num 'l, Order 'ord.
+ (vector<'n,'l,'ord,bit>, int, int) -> list<bit> effect pure vector_subrange_bl
-overload vector_subrange [vector_subrange_inc; vector_subrange_dec]
+val forall Num 'n, Num 'l, Num 'm, Num 'o, 'l >= 0, 'm <= 'o, 'o <= 'l.
+ (vector<'n,'l,inc,bit>, [:'m:], [:'o:]) -> vector<'m,('o - 'm) + 1,inc,bit> effect pure bitvector_subrange_inc
+
+val forall Num 'n, Num 'l, Num 'm, Num 'o, 'n >= 'm, 'm >= 'o, 'o >= 'n - 'l + 1.
+ (vector<'n,'l,dec,bit>, [:'m:], [:'o:]) -> vector<'m,('m - 'o) + 1,dec,bit> effect pure bitvector_subrange_dec
+
+overload vector_subrange [bitvector_subrange_inc; bitvector_subrange_dec; vector_subrange_inc; vector_subrange_dec; vector_subrange_bl]
(* Type safe vector append *)
-val forall Nat 'n1, Nat 'l1, Nat 'n2, Nat 'l2, Order 'o, Type 'a, 'l1 >= 0, 'l2 >= 0.
- (vector<'n1,'l1,'o,'a>, vector<'n2,'l2,'o,'a>) -> vector<'n1,'l1 + 'l2,'o,'a> effect pure vector_append
+val extern forall Num 'n1, Num 'l1, Num 'n2, Num 'l2, Order 'o, Type 'a, 'l1 >= 0, 'l2 >= 0.
+ (vector<'n1,'l1,'o,'a>, vector<'n2,'l2,'o,'a>) -> vector<'l1 + 'l2 - 1,'l1 + 'l2,'o,'a> effect pure vec_append = "vector_concat"
+
+val (list<bit>, list<bit>) -> list<bit> effect pure list_append
+
+val extern forall Num 'n1, Num 'l1, Num 'n2, Num 'l2, Order 'o, 'l1 >= 0, 'l2 >= 0.
+ (vector<'n1,'l1,'o,bit>, vector<'n2,'l2,'o,bit>) -> vector<'l1 + 'l2 - 1,'l1 + 'l2,'o,bit> effect pure bitvec_append = "bitvector_concat"
+
+overload vector_append [bitvec_append; vec_append; list_append]
(* Implicit register dereferencing *)
val cast forall Type 'a. register<'a> -> 'a effect {rreg} reg_deref
-overload vector_access [vector_access_inc; vector_access_dec]
-
(* Bitvector duplication *)
-val forall Nat 'n. (bit, [:'n:]) -> vector<'n - 1,'n,dec,bit> effect pure duplicate
+val forall Num 'n. (bit, [:'n:]) -> vector<'n - 1,'n,dec,bit> effect pure duplicate
-val forall Nat 'n, Nat 'm, Nat 'o, Order 'ord.
+val (bit, int) -> list<bit> effect pure duplicate_to_list
+
+val forall Num 'n, Num 'm, Num 'o, Order 'ord.
(vector<'o,'n,'ord,bit>, [:'m:]) -> vector<'o,'m*'n,'ord,bit> effect pure duplicate_bits
-overload (deinfix ^^) [duplicate; duplicate_bits]
+val forall Num 'n, Num 'o, Order 'ord.
+ (vector<'o,'n,'ord,bit>, int) -> list<bit> effect pure duplicate_bits_to_list
+
+overload (deinfix ^^) [duplicate; duplicate_bits; duplicate_to_list; duplicate_bits_to_list]
(* Bitvector extension *)
-val forall Nat 'n, Nat 'm, Nat 'o, Nat 'p, Order 'ord.
+val forall Num 'n, Num 'm, Num 'o, Num 'p, Order 'ord.
vector<'o, 'n, 'ord, bit> -> vector<'p, 'm, 'ord, bit> effect pure extz
-val forall Nat 'n, Nat 'm, Nat 'o, Nat 'p, Order 'ord.
+val forall Num 'm, Num 'p, Order 'ord.
+ list<bit> -> vector<'p, 'm, 'ord, bit> effect pure extz_bl
+
+val forall Num 'n, Num 'm, Num 'o, Num 'p, Order 'ord.
vector<'o, 'n, 'ord, bit> -> vector<'p, 'm, 'ord, bit> effect pure exts
-overload EXTZ [extz]
-overload EXTS [exts]
+val forall Num 'm, Num 'p, Order 'ord.
+ list<bit> -> vector<'p, 'm, 'ord, bit> effect pure exts_bl
+
+(* If we want an automatic bitvector extension, then this is the function to
+ use, but I've disabled the cast because it hides signedness bugs. *)
+val (*cast*) forall Num 'n, Num 'm, Num 'o, Num 'p, Order 'ord, 'm >= 'n.
+ vector<'o, 'n, 'ord, bit> -> vector<'p, 'm, 'ord, bit> effect pure extzi
-val forall Type 'a, Nat 'n, Nat 'm, Nat 'o, Nat 'p, Order 'ord, 'm >= 'o.
+overload EXTZ [extz; extz_bl]
+overload EXTS [exts; exts_bl]
+
+val forall Type 'a, Num 'n, Num 'm, Num 'o, Num 'p, Order 'ord, 'm >= 'o.
vector<'n, 'm, 'ord, 'a> -> vector<'p, 'o, 'ord, 'a> effect pure mask
(* Adjust the start index of a decreasing bitvector *)
-val cast forall Nat 'n, Nat 'm, Nat 'o, 'n >= 'm - 1, 'o >= 'm - 1.
+val cast forall Num 'n, Num 'm, 'n >= 'm - 1.
+ vector<'n,'m,dec,bit> -> vector<'m - 1,'m,dec,bit>
+ effect pure norm_dec
+
+val cast forall Num 'n, Num 'm, Num 'o, 'n >= 'm - 1, 'o >= 'm - 1.
vector<'n,'m,dec,bit> -> vector<'o,'m,dec,bit>
effect pure adjust_dec
(* Various casts from 0 and 1 to bitvectors *)
-val cast forall Nat 'n, Nat 'l, Order 'ord. [:0:] -> vector<'n,'l,'ord,bit> effect pure cast_0_vec
-val cast forall Nat 'n, Nat 'l, Order 'ord. [:1:] -> vector<'n,'l,'ord,bit> effect pure cast_1_vec
-val cast forall Nat 'n, Nat 'l, Order 'ord. [|0:1|] -> vector<'n,'l,'ord,bit> effect pure cast_01_vec
+val cast forall Num 'n, Num 'l, Order 'ord. [:0:] -> vector<'n,'l,'ord,bit> effect pure cast_0_vec
+val cast forall Num 'n, Num 'l, Order 'ord. [:1:] -> vector<'n,'l,'ord,bit> effect pure cast_1_vec
+val cast forall Num 'n, Num 'l, Order 'ord. [|0:1|] -> vector<'n,'l,'ord,bit> effect pure cast_01_vec
-val cast forall Nat 'n, Order 'ord. vector<'n,1,'ord,bit> -> bool effect pure cast_vec_bool
+val cast forall Num 'n, Order 'ord. vector<'n,1,'ord,bit> -> bool effect pure cast_vec_bool
val cast bit -> bool effect pure cast_bit_bool
+val cast forall Num 'n, Num 'm, 'n >= 'm - 1, 'm >= 1. bit -> vector<'n,'m,dec,bit> effect pure cast_bit_vec
+
(* MSB *)
-val forall Nat 'n, Nat 'm, Order 'ord. vector<'n, 'm, 'ord, bit> -> bit effect pure most_significant
+val forall Num 'n, Num 'm, Order 'ord. vector<'n, 'm, 'ord, bit> -> bit effect pure most_significant
(* Arithmetic *)
-val forall Nat 'n, Nat 'm.
- (atom<'n>, atom<'m>) -> atom<'n+'m> effect pure add
+val extern forall Num 'n, Num 'm, Num 'o, Num 'p.
+ ([|'n:'m|], [|'o:'p|]) -> [|'n + 'o:'m + 'p|] effect pure add
-val forall Nat 'n, Nat 'o, Nat 'p, Order 'ord.
- (vector<'o, 'n, 'ord, bit>, vector<'p, 'n, 'ord, bit>) -> vector<'o, 'n, 'ord, bit> effect pure add_vec
+val extern (nat, nat) -> nat effect pure add_nat = "add"
-val forall Nat 'n, Nat 'o, Nat 'p, Nat 'q, Order 'ord.
- (vector<'o, 'n, 'ord, bit>, vector<'p, 'n, 'ord, bit>) -> range<'q, 2**'n> effect pure add_vec_vec_range
+val extern (int, int) -> int effect pure add_int = "add"
-(* FIXME: the parser is broken for 2**... it's just been hacked to work for this common case *)
-val forall Nat 'n, Nat 'm, Nat 'o, Order 'ord, 'o <= 2** 'm - 1.
- (vector<'n, 'm, 'ord, bit>, atom<'o>) -> vector<'n, 'm, 'ord, bit> effect pure add_vec_range
+val forall Num 'n, Num 'o, Order 'ord.
+ (vector<'o, 'n, 'ord, bit>, vector<'o, 'n, 'ord, bit>) -> vector<'o, 'n, 'ord, bit> effect pure add_vec
-val forall Nat 'n, Nat 'o, Nat 'p, Order 'ord.
- (vector<'o, 'n, 'ord, bit>, vector<'p, 'n, 'ord, bit>) -> (vector<'o, 'n, 'ord, bit>, bit, bit) effect pure add_overflow_vec
+val forall Num 'n, Num 'o, Order 'ord.
+ (vector<'o, 'n, 'ord, bit>, int) -> vector<'o, 'n, 'ord, bit> effect pure add_vec_int
-(* but it doesn't parse this
-val forall Nat 'n, Nat 'm, Nat 'o, Order 'ord, 'o <= 2** 'm - 1.
- (vector<'n, 'm, 'ord, bit>, atom<'o>) -> range<'o, 'o+2** 'm> effect pure add_vec_range_range
- *)
+val forall Num 'n, Num 'o, Order 'ord.
+ (vector<'o, 'n, 'ord, bit>, vector<'o, 'n, 'ord, bit>) -> (vector<'o, 'n, 'ord, bit>, bit, bit) effect pure add_overflow_vec
-val forall Nat 'n, Nat 'm, Nat 'o, Order 'ord, 'o <= 2** 'm - 1.
- (atom<'o>, vector<'n, 'm, 'ord, bit>) -> vector<'n, 'm, 'ord, bit> effect pure add_range_vec
+val extern forall Num 'n, Num 'm, Num 'o, Num 'p.
+ ([|'n:'m|], [|'o:'p|]) -> [|'n - 'p:'m - 'o|] effect pure sub
-(* or this
-val forall Nat 'n, Nat 'm, Nat 'o, Order 'ord, 'o <= 2** 'm - 1.
- (atom<'o>, vector<'n, 'm, 'ord, bit>) -> range<'o, 'o+2**'m-1> effect pure add_range_vec_range
-*)
+val extern (int, int) -> int effect pure sub_int = "sub"
-val forall Nat 'o, Nat 'p, Order 'ord.
- (vector<'o, 'p, 'ord, bit>, bit) -> vector<'o, 'p, 'ord, bit> effect pure add_vec_bit
+val forall Num 'n, Num 'm, Order 'ord.
+ (vector<'n,'m,'ord,bit>, int) -> vector<'n,'m,'ord,bit> effect pure sub_vec_int
-val forall Nat 'o, Nat 'p, Order 'ord.
- (bit, vector<'o, 'p, 'ord, bit>) -> vector<'o, 'p, 'ord, bit> effect pure add_bit_vec
+val forall Num 'n, Num 'o, Order 'ord.
+ (vector<'o, 'n, 'ord, bit>, vector<'o, 'n, 'ord, bit>) -> vector<'o, 'n, 'ord, bit> effect pure sub_vec
-val forall Nat 'n, Nat 'm. ([:'n:], [:'m:]) -> [:'n - 'm:] effect pure sub_exact
-val forall Nat 'n, Nat 'm, Nat 'o, 'o <= 'm - 'n. ([|'n:'m|], [:'o:]) -> [|'n:'m - 'o|] effect pure sub_range
-val forall Nat 'n, Nat 'm, Order 'ord. (vector<'n,'m,'ord,bit>, int) -> vector<'n,'m,'ord,bit> effect pure sub_bv
+val forall Num 'n, Num 'o, Order 'ord.
+ (vector<'o, 'n, 'ord, bit>, vector<'o, 'n, 'ord, bit>) -> (vector<'o, 'n, 'ord, bit>, bit, bit) effect pure sub_underflow_vec
overload (deinfix +) [
- add;
add_vec;
- add_vec_vec_range;
- add_vec_range;
add_overflow_vec;
- add_vec_range_range;
- add_range_vec;
- add_range_vec_range;
- add_vec_bit;
- add_bit_vec;
+ add_vec_int;
+ add;
+ add_nat;
+ add_int
]
overload (deinfix -) [
- sub_exact;
- sub_bv;
- sub_range;
+ sub_vec_int;
+ sub_vec;
+ sub_underflow_vec;
+ sub;
+ sub_int
]
-(* Equality *)
+val extern bool -> bit effect pure bool_to_bit = "bool_to_bitU"
-(* Sail gives a bunch of overloads for equality, but apparantly also
-gives an equality and inequality for any type 'a, so why bother
-overloading? *)
+val (int, int) -> int effect pure mul_int
+val forall Num 'n, Num 'o, Order 'ord.
+ (vector<'o, 'n, 'ord, bit>, vector<'o, 'n, 'ord, bit>) -> vector<2 * 'n - 1, 2 * 'n, 'ord, bit> effect pure mul_vec
-val forall Type 'a. ('a, 'a) -> bool effect pure eq
-val forall Type 'a. ('a, 'a) -> bool effect pure neq
+overload (deinfix * ) [
+ mul_vec;
+ mul_int
+]
+
+val forall Num 'n, Num 'o, Order 'ord.
+ (vector<'o, 'n, 'ord, bit>, vector<'o, 'n, 'ord, bit>) -> vector<2 * 'n - 1, 2 * 'n, 'ord, bit> effect pure mul_svec
+
+overload (deinfix *_s) [
+ mul_svec
+]
+
+val extern (bool, bool) -> bool effect pure bool_xor
+
+val extern forall Num 'n, Num 'o, Order 'ord.
+ (vector<'o, 'n, 'ord, bit>, vector<'o, 'n, 'ord, bit>) -> vector<'o, 'n, 'ord, bit> effect pure xor_vec = "bitwise_xor"
+
+overload (deinfix ^) [
+ bool_xor;
+ xor_vec
+]
+
+val forall Num 'n, Num 'o, Order 'ord.
+ (vector<'o, 'n, 'ord, bit>, int) -> vector<'o, 'n, 'ord, bit> effect pure shiftl
+
+overload (deinfix <<) [
+ shiftl
+]
-overload (deinfix ==) [eq]
-overload (deinfix !=) [neq]
+val forall Num 'n, Num 'o, Order 'ord.
+ (vector<'o, 'n, 'ord, bit>, int) -> vector<'o, 'n, 'ord, bit> effect pure shiftr
+
+overload (deinfix >>) [
+ shiftr
+]
(* Boolean operators *)
-val bool -> bool effect pure bool_not
+val extern bool -> bool effect pure bool_not = "not"
val (bool, bool) -> bool effect pure bool_or
val (bool, bool) -> bool effect pure bool_and
-overload ~ [bool_not]
-overload (deinfix &) [bool_and]
-overload (deinfix |) [bool_or]
+val forall Num 'n, Num 'm, Order 'ord.
+ vector<'n,'m,'ord,bit> -> vector<'n,'m,'ord,bit> effect pure bitwise_not
+
+val forall Num 'n, Num 'm, Order 'ord.
+ (vector<'n,'m,'ord,bit>, vector<'n,'m,'ord,bit>) -> vector<'n,'m,'ord,bit> effect pure bitwise_and
+
+val forall Num 'n, Num 'm, Order 'ord.
+ (vector<'n,'m,'ord,bit>, vector<'n,'m,'ord,bit>) -> vector<'n,'m,'ord,bit> effect pure bitwise_or
+
+overload ~ [bool_not; bitwise_not]
+overload (deinfix &) [bool_and; bitwise_and]
+overload (deinfix |) [bool_or; bitwise_or]
+
+(* Equality *)
+
+val forall Num 'n, Num 'm, Order 'ord. (vector<'n,'m,'ord,bit>, vector<'n,'m,'ord,bit>) -> bool effect pure eq_vec
+
+val forall Type 'a. ('a, 'a) -> bool effect pure eq
+
+val forall Num 'n, Num 'm, Order 'ord. (vector<'n,'m,'ord,bit>, vector<'n,'m,'ord,bit>) -> bool effect pure neq_vec
+
+val forall Type 'a. ('a, 'a) -> bool effect pure neq
+
+function forall Num 'n, Num 'm, Order 'ord. bool neq_vec (v1, v2) = bool_not(eq_vec(v1, v2))
+
+overload (deinfix ==) [eq_vec; eq]
+overload (deinfix !=) [neq_vec; neq]
+
+val forall Num 'n, Num 'm, Order 'ord. (vector<'n,'m,'ord,bit>, vector<'n,'m,'ord,bit>) -> bool effect pure gteq_vec
+val forall Num 'n, Num 'm, Order 'ord. (vector<'n,'m,'ord,bit>, vector<'n,'m,'ord,bit>) -> bool effect pure gt_vec
+val forall Num 'n, Num 'm, Order 'ord. (vector<'n,'m,'ord,bit>, vector<'n,'m,'ord,bit>) -> bool effect pure lteq_vec
+val forall Num 'n, Num 'm, Order 'ord. (vector<'n,'m,'ord,bit>, vector<'n,'m,'ord,bit>) -> bool effect pure lt_vec
+
+val extern (int, int) -> bool effect pure gteq_int = "gteq"
+val extern (int, int) -> bool effect pure gt_int = "gt"
+val extern (int, int) -> bool effect pure lteq_int = "lteq"
+val extern (int, int) -> bool effect pure lt_int = "lt"
+
+val extern forall Num 'n, Num 'm, Num 'o. ([|'n:'m|], [:'o:]) -> bool effect pure lt_range_atom = "lt"
+val extern forall Num 'n, Num 'm, Num 'o. ([|'n:'m|], [:'o:]) -> bool effect pure lteq_range_atom = "lteq"
+val extern forall Num 'n, Num 'm, Num 'o. ([|'n:'m|], [:'o:]) -> bool effect pure gt_range_atom = "gt"
+val extern forall Num 'n, Num 'm, Num 'o. ([|'n:'m|], [:'o:]) -> bool effect pure gteq_range_atom = "gteq"
+val extern forall Num 'n, Num 'm, Num 'o. ([:'n:], [|'m:'o|]) -> bool effect pure lt_atom_range = "lt"
+val extern forall Num 'n, Num 'm, Num 'o. ([:'n:], [|'m:'o|]) -> bool effect pure lteq_atom_range = "lteq"
+val extern forall Num 'n, Num 'm, Num 'o. ([:'n:], [|'m:'o|]) -> bool effect pure gt_atom_range = "gt"
+val extern forall Num 'n, Num 'm, Num 'o. ([:'n:], [|'m:'o|]) -> bool effect pure gteq_atom_range = "gteq"
+
+val extern forall Num 'n, Num 'm. ([:'n:], [:'m:]) -> bool effect pure lteq_atom_atom = "lteq"
+val extern forall Num 'n, Num 'm. ([:'n:], [:'m:]) -> bool effect pure gteq_atom_atom = "gteq"
+val extern forall Num 'n, Num 'm. ([:'n:], [:'m:]) -> bool effect pure lt_atom_atom = "lt"
+val extern forall Num 'n, Num 'm. ([:'n:], [:'m:]) -> bool effect pure gt_atom_atom = "gt"
+
+overload (deinfix >=) [gteq_atom_atom; gteq_range_atom; gteq_atom_range; gteq_vec; gteq_int]
+overload (deinfix >) [gt_atom_atom; gt_vec; gt_int]
+overload (deinfix <=) [lteq_atom_atom; lteq_range_atom; lteq_atom_range; lteq_vec; lteq_int]
+overload (deinfix <) [lt_atom_atom; lt_vec; lt_int]
+
+val forall Num 'n, Num 'm, Order 'ord. (vector<'n,'m,'ord,bit>, vector<'n,'m,'ord,bit>) -> bool effect pure gteq_svec
+val forall Num 'n, Num 'm, Order 'ord. (vector<'n,'m,'ord,bit>, vector<'n,'m,'ord,bit>) -> bool effect pure gt_svec
+val forall Num 'n, Num 'm, Order 'ord. (vector<'n,'m,'ord,bit>, vector<'n,'m,'ord,bit>) -> bool effect pure lteq_svec
+val forall Num 'n, Num 'm, Order 'ord. (vector<'n,'m,'ord,bit>, vector<'n,'m,'ord,bit>) -> bool effect pure lt_svec
+
+overload (deinfix <_s) [lt_svec]
+overload (deinfix <=_s) [lteq_svec]
+overload (deinfix >_s) [gt_svec]
+overload (deinfix >=_s) [gteq_svec]
+
+val (int, int) -> int effect pure quotient
+
+overload (deinfix quot) [quotient]
+
+val (int, int) -> int effect pure modulo
+
+overload (deinfix mod) [modulo]
+
+val extern forall Num 'n, Num 'm, Order 'ord, Type 'a. vector<'n,'m,'ord,'a> -> [:'m:] effect pure vec_length = "length"
+val forall Type 'a. list<'a> -> nat effect pure list_length
+
+val extern forall Num 'n, Num 'm, Order 'ord. vector<'n,'m,'ord,bit> -> [:'m:] effect pure bitvector_length = "bvlength"
+
+overload length [bitvector_length; vector_length; list_length]
+
+val cast forall Num 'n. [:'n:] -> [|'n|] effect pure upper
+
+typedef option = const union forall Type 'a. {
+ None;
+ 'a Some
+}
(* Mips spec starts here *)
@@ -338,7 +490,7 @@ register (TLBEntry) TLBEntry61
register (TLBEntry) TLBEntry62
register (TLBEntry) TLBEntry63
-let (vector <0, 64, inc, (TLBEntry)>) TLBEntries = [
+let (vector <0, 64, inc, (register<TLBEntry>)>) TLBEntries = [
TLBEntry00,
TLBEntry01,
TLBEntry02,
@@ -629,3 +781,5 @@ function unit checkCP0Access () =
SignalException(CpU);
}
}
+
+function forall Type 'o. 'o SignalException ((Exception) ex) = SignalExceptionMIPS(ex, 0x0000000000000000)
diff --git a/test/typecheck/pass/phantom_num.sail b/test/typecheck/pass/phantom_num.sail
index c4ff8b13..c676807d 100644
--- a/test/typecheck/pass/phantom_num.sail
+++ b/test/typecheck/pass/phantom_num.sail
@@ -1,5 +1,5 @@
-val extern (int, int) -> bool effect pure gt_int
+val extern (int, int) -> bool effect pure gt_int = "gt"
(* val cast forall Num 'n, Num 'm. [|'n:'m|] -> int effect pure cast_range_int *)
diff --git a/test/typecheck/run_tests.sh b/test/typecheck/run_tests.sh
index 073a6251..12e6acc0 100755
--- a/test/typecheck/run_tests.sh
+++ b/test/typecheck/run_tests.sh
@@ -109,11 +109,12 @@ function test_lem {
cp $MIPS/mips_extras_embed_sequential.lem $DIR/lem/
mv $SAILDIR/${i%%.*}_embed_types.lem $DIR/lem/
+ mv $SAILDIR/${i%%.*}_embed_types_sequential.lem $DIR/lem/
mv $SAILDIR/${i%%.*}_embed.lem $DIR/lem/
mv $SAILDIR/${i%%.*}_embed_sequential.lem $DIR/lem/
# Test sequential embedding for now
# TODO: Add tests for the free monad
- if lem -lib $SAILDIR/src/lem_interp -lib $SAILDIR/src/gen_lib/ $DIR/lem/mips_extras_embed_sequential.lem $DIR/lem/${i%%.*}_embed_types.lem $DIR/lem/${i%%.*}_embed_sequential.lem 2> /dev/null
+ if lem -lib $SAILDIR/src/lem_interp -lib $SAILDIR/src/gen_lib/ $DIR/lem/mips_extras_embed_sequential.lem $DIR/lem/${i%%.*}_embed_types_sequential.lem $DIR/lem/${i%%.*}_embed_sequential.lem 2> /dev/null
then
green "typechecking lem for $1/$i" "pass"
else