summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorJon French2019-04-15 16:18:18 +0100
committerJon French2019-04-15 16:18:18 +0100
commita9f0b829507e9882efdb59cce4d83ea7e87f5f71 (patch)
tree11cde6c1918bc15f4dda9a8e40afd4a1fe912a0a
parent0f6fd188ca232cb539592801fcbb873d59611d81 (diff)
parent57443173923e87f33713c99dbab9eba7e3db0660 (diff)
Merge branch 'sail2' into rmem_interpreter
-rw-r--r--aarch64/aarch64_extras.lem2
-rwxr-xr-xaarch64/prelude.sail12
-rw-r--r--aarch64_small/armV8.sail2
-rw-r--r--aarch64_small/prelude.sail14
-rw-r--r--language/jib.ott48
-rw-r--r--language/sail.ott64
-rw-r--r--lib/arith.sail30
-rw-r--r--lib/coq/Makefile2
-rw-r--r--lib/coq/Sail2_instr_kinds.v15
-rw-r--r--lib/coq/Sail2_operators_mwords.v1
-rw-r--r--lib/coq/Sail2_prompt.v8
-rw-r--r--lib/coq/Sail2_prompt_monad.v298
-rw-r--r--lib/coq/Sail2_state.v129
-rw-r--r--lib/coq/Sail2_state_lifting.v61
-rw-r--r--lib/coq/Sail2_state_monad.v422
-rw-r--r--lib/coq/Sail2_values.v128
-rw-r--r--lib/mono_rewrites.sail108
-rw-r--r--lib/sail.c21
-rw-r--r--lib/sail.h2
-rw-r--r--lib/smt.sail13
-rw-r--r--lib/vector_dec.sail63
-rw-r--r--opam6
-rw-r--r--src/ast_util.ml32
-rw-r--r--src/ast_util.mli7
-rw-r--r--src/constant_fold.ml39
-rw-r--r--src/constant_propagation.ml241
-rw-r--r--src/constant_propagation_mutrec.ml232
-rw-r--r--src/constraint.ml121
-rw-r--r--src/constraint.mli8
-rw-r--r--src/gen_lib/sail2_operators_bitlists.lem3
-rw-r--r--src/gen_lib/sail2_operators_mwords.lem7
-rw-r--r--src/initial_check.ml203
-rw-r--r--src/initial_check.mli6
-rw-r--r--src/interpreter.ml5
-rw-r--r--src/isail.ml292
-rw-r--r--src/jib/anf.ml14
-rw-r--r--src/jib/anf.mli1
-rw-r--r--src/jib/c_backend.ml239
-rw-r--r--src/jib/c_backend.mli1
-rw-r--r--src/jib/jib_compile.ml242
-rw-r--r--src/jib/jib_compile.mli2
-rw-r--r--src/jib/jib_optimize.ml162
-rw-r--r--src/jib/jib_optimize.mli3
-rw-r--r--src/jib/jib_ssa.ml209
-rw-r--r--src/jib/jib_ssa.mli18
-rw-r--r--src/jib/jib_util.ml253
-rw-r--r--src/latex.ml2
-rw-r--r--src/monomorphise.ml161
-rw-r--r--src/ocaml_backend.ml3
-rw-r--r--src/parse_ast.ml2
-rw-r--r--src/parser.mly98
-rw-r--r--src/pretty_print_common.ml16
-rw-r--r--src/pretty_print_coq.ml455
-rw-r--r--src/pretty_print_lem.ml35
-rw-r--r--src/pretty_print_sail.ml12
-rw-r--r--src/process_file.ml53
-rw-r--r--src/process_file.mli10
-rw-r--r--src/reporting.ml14
-rw-r--r--src/reporting.mli6
-rw-r--r--src/rewrites.ml466
-rw-r--r--src/rewrites.mli29
-rw-r--r--src/sail.ml329
-rw-r--r--src/sail_lib.ml34
-rw-r--r--src/slice.ml56
-rw-r--r--src/slice.mli3
-rw-r--r--src/specialize.ml20
-rw-r--r--src/specialize.mli7
-rw-r--r--src/toFromInterp_backend.ml4
-rw-r--r--src/type_check.ml108
-rw-r--r--src/type_check.mli3
-rw-r--r--src/value.ml55
-rw-r--r--src/value2.lem20
-rwxr-xr-xtest/arm/run_tests.sh2
-rw-r--r--test/arm/test.isail1
-rw-r--r--test/builtins/div_int.sail2
-rw-r--r--test/builtins/div_int2.sail2
-rw-r--r--test/builtins/divmod.sail43
-rw-r--r--test/c/anf_as_pattern.expect1
-rw-r--r--test/c/anf_as_pattern.sail19
-rw-r--r--test/c/anon_rec.expect1
-rw-r--r--test/c/anon_rec.sail12
-rw-r--r--test/c/execute.isail1
-rw-r--r--test/c/flow_restrict.expect1
-rw-r--r--test/c/flow_restrict.sail23
-rw-r--r--test/c/poly_int_record.expect3
-rw-r--r--test/c/poly_int_record.sail21
-rw-r--r--test/c/poly_record.expect1
-rw-r--r--test/c/poly_record.sail18
-rwxr-xr-xtest/c/run_tests.py2
-rw-r--r--test/c/tuple_union.expect42
-rw-r--r--test/c/tuple_union.sail48
-rw-r--r--test/c/unused_poly_ctor.expect1
-rw-r--r--test/c/unused_poly_ctor.sail18
-rw-r--r--test/coq/_CoqProject2
-rw-r--r--test/coq/pass/foreach_using_tyvar.sail11
-rw-r--r--test/coq/pass/rebind.sail10
-rw-r--r--test/coq/pass/unbound_ex_tyvars.sail16
-rw-r--r--test/coq/pass/unpacking.sail16
-rw-r--r--test/coq/skip33
-rw-r--r--test/mono/exint.sail4
-rw-r--r--test/ocaml/bitfield/test.isail1
-rw-r--r--test/ocaml/hello_world/test.isail1
-rw-r--r--test/ocaml/loop/test.isail1
-rw-r--r--test/ocaml/lsl/test.isail1
-rw-r--r--test/ocaml/pattern1/test.isail1
-rw-r--r--test/ocaml/reg_alias/test.isail1
-rw-r--r--test/ocaml/reg_passing/test.isail1
-rw-r--r--test/ocaml/reg_ref/test.isail1
-rwxr-xr-xtest/ocaml/run_tests.sh2
-rw-r--r--test/ocaml/short_circuit/test.isail1
-rw-r--r--test/ocaml/string_equality/test.isail1
-rw-r--r--test/ocaml/string_of_struct/test.isail1
-rw-r--r--test/ocaml/trycatch/test.isail1
-rw-r--r--test/ocaml/types/test.isail1
-rw-r--r--test/ocaml/vec_32_64/test.isail1
-rw-r--r--test/ocaml/void/test.isail1
-rw-r--r--test/typecheck/pass/Replicate.sail3
-rw-r--r--test/typecheck/pass/Replicate/v1.expect6
-rw-r--r--test/typecheck/pass/Replicate/v1.sail3
-rw-r--r--test/typecheck/pass/Replicate/v2.expect6
-rw-r--r--test/typecheck/pass/Replicate/v2.sail3
-rw-r--r--test/typecheck/pass/anon_rec.sail12
-rw-r--r--test/typecheck/pass/existential_ast/v3.expect2
-rw-r--r--test/typecheck/pass/existential_ast3/v1.expect8
-rw-r--r--test/typecheck/pass/existential_ast3/v2.expect8
-rw-r--r--test/typecheck/pass/existential_ast3/v3.expect2
-rw-r--r--test/typecheck/pass/guards.sail3
-rw-r--r--test/typecheck/pass/if_infer/v1.expect4
-rw-r--r--test/typecheck/pass/if_infer/v2.expect4
-rw-r--r--test/typecheck/pass/recursion.sail2
-rw-r--r--test/typecheck/pass/shadow_let.sail14
-rw-r--r--test/typecheck/pass/shadow_let/v1.expect12
-rw-r--r--test/typecheck/pass/shadow_let/v1.sail14
133 files changed, 4107 insertions, 2160 deletions
diff --git a/aarch64/aarch64_extras.lem b/aarch64/aarch64_extras.lem
index d22ece00..b662e230 100644
--- a/aarch64/aarch64_extras.lem
+++ b/aarch64/aarch64_extras.lem
@@ -78,7 +78,7 @@ val write_ram : forall 'rv 'e.
integer -> integer -> list bitU -> list bitU -> list bitU -> monad 'rv unit 'e
let write_ram addrsize size hexRAM address value =
write_mem_ea Write_plain address size >>
- write_mem_val value >>= fun _ ->
+ write_mem Write_plain address size value >>= fun _ ->
return ()
val read_ram : forall 'rv 'e.
diff --git a/aarch64/prelude.sail b/aarch64/prelude.sail
index f4f7dc75..431ad1f7 100755
--- a/aarch64/prelude.sail
+++ b/aarch64/prelude.sail
@@ -284,17 +284,11 @@ val abs_real = {coq: "Rabs", _: "abs_real"} : real -> real
overload abs = {abs_atom, abs_real}
-val quotient_nat = {ocaml: "quotient", lem: "integerDiv", c: "tdiv_int"} : (nat, nat) -> nat
-
val quotient_real = {ocaml: "quotient_real", lem: "realDiv", c: "div_real", coq: "Rdiv"} : (real, real) -> real
-val quotient = {ocaml: "quotient", lem: "integerDiv", c: "tdiv_int", coq: "Z.quot"} : (int, int) -> int
-
-overload operator / = {quotient_nat, quotient, quotient_real}
-
-val modulus = {ocaml: "modulus", lem: "hardware_mod", c: "tmod_int", coq: "Z.rem"} : (int, int) -> int
-
-overload operator % = {modulus}
+overload operator / = {ediv_int, quotient_real}
+overload div = {ediv_int}
+overload operator % = {emod_int}
val Real = {ocaml: "to_real", lem: "realFromInteger", c: "to_real", coq: "IZR"} : int -> real
diff --git a/aarch64_small/armV8.sail b/aarch64_small/armV8.sail
index f125ec72..a9a78900 100644
--- a/aarch64_small/armV8.sail
+++ b/aarch64_small/armV8.sail
@@ -2201,7 +2201,7 @@ function clause execute ( Division((d,n,m,datasize as int('R),_unsigned)) ) = {
if IsZero(operand2) then
result = 0
else
- result = /* ARM: RoundTowardsZero*/ quot (_Int(operand1, _unsigned), _Int(operand2, _unsigned)); /* FIXME: does quot round towards zero? */
+ result = /* ARM: RoundTowardsZero*/ tdiv_int (_Int(operand1, _unsigned), _Int(operand2, _unsigned));
wX(d) = to_bits(result) : (bits('R)) ; /* ARM: result[(datasize-1)..0] */
}
diff --git a/aarch64_small/prelude.sail b/aarch64_small/prelude.sail
index 2dbd2bf4..f97c84a6 100644
--- a/aarch64_small/prelude.sail
+++ b/aarch64_small/prelude.sail
@@ -150,17 +150,9 @@ overload operator ^ = {xor_vec, int_power, concat_str}
val mask : forall 'l 'm, 'l >= 0 & 'm >= 0. (implicit('l), bits('m)) -> bits('l)
-/* put this val spec into Sail lib for "%" */
-
-val mod = {
- smt: "mod",
- ocaml: "modulus",
- lem: "integerMod",
- c: "tmod_int",
- coq: "Z.rem"
-} : forall 'M 'N. (int('M), int('N)) -> {'O, 0 <= 'O & 'O < N . int('O)}
-
-/* overload operator % = {mod_int} */
+overload operator % = {emod_int}
+overload operator / = {ediv_int}
+overload mod = {emod_int}
val print = "print_endline" : string -> unit
diff --git a/language/jib.ott b/language/jib.ott
index e54e2ea5..5f800fcd 100644
--- a/language/jib.ott
+++ b/language/jib.ott
@@ -47,21 +47,27 @@ open import Value2
grammar
+name :: '' ::=
+ | id nat :: :: name
+ | have_exception nat :: :: have_exception
+ | current_exception nat :: :: current_exception
+ | return nat :: :: return
+
% Fragments are small pure snippets of (abstract) C code, mostly
-% expressions, used by the aval and cval types.
+% expressions, used by the aval (ANF) and cval (Jib) types.
fragment :: 'F_' ::=
- | id :: :: id
- | '&' id :: :: ref
- | value :: :: lit
- | have_exception :: :: have_exception
- | current_exception :: :: current_exception
- | fragment op fragment' :: :: op
- | op fragment :: :: unary
- | string ( fragment0 , ... , fragmentn ) :: :: call
- | fragment . string :: :: field
- | string :: :: raw
- | poly fragment :: :: poly
+ | name :: :: id
+ | '&' name :: :: ref
+ | value :: :: lit
+ | fragment != kind id ( ctyp0 , ... , ctypn ) ctyp :: :: ctor_kind
+ | unwrap id ( ctyp0 , ... , ctypn ) fragment :: :: ctor_unwrap
+ | fragment op fragment' :: :: op
+ | op fragment :: :: unary
+ | string ( fragment0 , ... , fragmentn ) :: :: call
+ | fragment . string :: :: field
+ | string :: :: raw
+ | poly fragment :: :: poly
% Note that init / clear are sometimes refered to as create / kill
@@ -129,13 +135,10 @@ cval :: 'CV_' ::=
{{ lem fragment * ctyp }}
clexp :: 'CL_' ::=
- | id : ctyp :: :: id
+ | name : ctyp :: :: id
| clexp . string :: :: field
| * clexp :: :: addr
| clexp . nat :: :: tuple
- | current_exception : ctyp :: :: current_exception
- | have_exception :: :: have_exception
- | return : ctyp :: :: return
| void :: :: void
ctype_def :: 'CTD_' ::=
@@ -152,17 +155,17 @@ instr :: 'I_' ::=
{{ aux _ iannot }}
% The following are the minimal set of instructions output by
% Jib_compile.ml.
- | ctyp id :: :: decl
- | ctyp id = cval :: :: init
+ | ctyp name :: :: decl
+ | ctyp name = cval :: :: init
| jump ( cval ) string :: :: jump
| goto string :: :: goto
| string : :: :: label
| clexp = bool id ( cval0 , ... , cvaln ) :: :: funcall
| clexp = cval :: :: copy
- | clear ctyp id :: :: clear
+ | clear ctyp name :: :: clear
| undefined ctyp :: :: undefined
| match_failure :: :: match_failure
- | end :: :: end
+ | end name :: :: end
% All instructions containing nested instructions can be flattened
% away. try and throw only exist for internal use within
@@ -187,9 +190,8 @@ instr :: 'I_' ::=
| return cval :: :: return
% For optimising away allocations and copying.
- | reset ctyp id :: :: reset
- | ctyp id = cval :: :: reinit
- | alias clexp = cval :: :: alias
+ | reset ctyp name :: :: reset
+ | ctyp name = cval :: :: reinit
cdef :: 'CDEF_' ::=
| register id : ctyp = {
diff --git a/language/sail.ott b/language/sail.ott
index b3df66bb..00a62fe3 100644
--- a/language/sail.ott
+++ b/language/sail.ott
@@ -119,44 +119,16 @@ l :: '' ::= {{ phantom }}
{{ hol () }}
annot :: '' ::=
- {{ phantom }}
- {{ ocaml 'a annot }}
- {{ lem annot 'a }}
- {{ hol unit }}
+ {{ phantom }}
+ {{ ocaml 'a annot }}
+ {{ lem annot 'a }}
+ {{ hol unit }}
id :: '' ::=
{{ com Identifier }}
{{ aux _ l }}
- | x :: :: id
- | ( deinfix x ) :: D :: deIid {{ com remove infix status }}
- | bool :: M :: bool {{ com built in type identifiers }} {{ ichlo (Id "bool") }}
- | bit :: M :: bit {{ ichlo (Id "bit") }}
- | unit :: M :: unit {{ ichlo (Id "unit") }}
- | nat :: M :: nat {{ ichlo (Id "nat") }}
- | int :: M :: int {{ ichlo (Id "int") }}
- | string :: M :: string {{ tex \ottkw{string} }} {{ ichlo (Id "string") }}
- | range :: M :: range {{ ichlo (Id "range") }}
- | atom :: M :: atom {{ ichlo (Id "atom") }}
- | vector :: M :: vector {{ ichlo (Id "vector") }}
- | list :: M :: list {{ ichlo (Id "list") }}
-% | set :: M :: set {{ ichlo (Id "set") }}
- | reg :: M :: reg {{ ichlo (Id "reg") }}
- | to_num :: M :: tonum {{ com built-in function identifiers }} {{ ichlo (Id "to_num") }}
- | to_vec :: M :: tovec {{ ichlo (Id "to_vec") }}
- | msb :: M :: msb {{ ichlo (Id "msb") }}
-% Note: we have just a single namespace. We don't want the same
-% identifier to be reused as a type name or variable, expression
-% variable, and field name. We don't enforce any lexical convention
-% on type variables (or variables of other kinds)
-% We don't enforce a lexical convention on infix operators, as some of the
-% targets use alphabetical infix operators.
-
-% Vector builtins
- | vector_access :: M :: vector_access {{ ichlo (Id "vector_access") }}
- | vector_update :: M :: vector_update {{ ichlo (Id "vector_update") }}
- | vector_update_subrange :: M :: vector_update_subrange {{ ichlo (Id "vector_update_subrange") }}
- | vector_subrange :: M :: vector_subrange {{ ichlo (Id "vector_subrange") }}
- | vector_append :: M :: vector_append {{ ichlo (Id "vector_append") }}
+ | x :: :: id
+ | ( operator x ) :: D :: operator {{ com remove infix status }}
kid :: '' ::=
{{ com kinded IDs: Type, Int, and Order variables }}
@@ -180,23 +152,23 @@ kind :: 'K_' ::=
nexp :: 'Nexp_' ::=
{{ com numeric expression, of kind Int }}
{{ aux _ l }}
- | id :: :: id {{ com abbreviation identifier }}
- | kid :: :: var {{ com variable }}
- | num :: :: constant {{ com constant }}
- | id ( nexp1 , ... , nexpn ) :: :: app {{ com app }}
- | nexp1 * nexp2 :: :: times {{ com product }}
- | nexp1 + nexp2 :: :: sum {{ com sum }}
- | nexp1 - nexp2 :: :: minus {{ com subtraction }}
- | 2** nexp :: :: exp {{ com exponential }}
- | neg nexp :: I :: neg {{ com for internal use only}}
- | ( nexp ) :: S :: paren {{ ichlo [[nexp]] }}
+ | id :: :: id {{ com abbreviation identifier }}
+ | kid :: :: var {{ com variable }}
+ | num :: :: constant {{ com constant }}
+ | id ( nexp1 , ... , nexpn ) :: :: app {{ com app }}
+ | nexp1 * nexp2 :: :: times {{ com product }}
+ | nexp1 + nexp2 :: :: sum {{ com sum }}
+ | nexp1 - nexp2 :: :: minus {{ com subtraction }}
+ | 2 ^ nexp :: :: exp {{ com exponential }}
+ | - nexp :: :: neg {{ com unary negation}}
+ | ( nexp ) :: S :: paren {{ ichlo [[nexp]] }}
order :: 'Ord_' ::=
{{ com vector order specifications, of kind Order }}
{{ aux _ l }}
| kid :: :: var {{ com variable }}
- | inc :: :: inc {{ com increasing }}
- | dec :: :: dec {{ com decreasing }}
+ | inc :: :: inc {{ com increasing }}
+ | dec :: :: dec {{ com decreasing }}
| ( order ) :: S :: paren {{ ichlo [[order]] }}
base_effect :: 'BE_' ::=
diff --git a/lib/arith.sail b/lib/arith.sail
index a1eef9f0..6ddc58aa 100644
--- a/lib/arith.sail
+++ b/lib/arith.sail
@@ -78,28 +78,22 @@ overload shr_int = {_shr32, _shr_int}
// ***** div and mod *****
-val div_int = {
- smt: "div",
- ocaml: "quotient",
- interpreter: "quotient",
- lem: "integerDiv",
- c: "tdiv_int",
- coq: "Z.quot"
+/*! Truncating division (rounds towards zero) */
+val tdiv_int = {
+ ocaml: "tdiv_int",
+ interpreter: "tdiv_int",
+ lem: "integerDiv_t",
+ c: "tdiv_int"
} : (int, int) -> int
-overload operator / = {div_int}
-
-val mod_int = {
- smt: "mod",
- ocaml: "modulus",
- interpreter: "modulus",
- lem: "integerMod",
- c: "tmod_int",
- coq: "Z.rem"
+/*! Remainder for truncating division (has sign of dividend) */
+val tmod_int = {
+ ocaml: "tmod_int",
+ interpreter: "tmod_int",
+ lem: "integerMod_t",
+ c: "tmod_int"
} : (int, int) -> nat
-overload operator % = {mod_int}
-
val abs_int = {
smt : "abs",
ocaml: "abs_int",
diff --git a/lib/coq/Makefile b/lib/coq/Makefile
index 6dd962d1..f763db6f 100644
--- a/lib/coq/Makefile
+++ b/lib/coq/Makefile
@@ -1,6 +1,6 @@
BBV_DIR?=../../../bbv
-SRC=Sail2_prompt_monad.v Sail2_prompt.v Sail2_impl_base.v Sail2_instr_kinds.v Sail2_operators_bitlists.v Sail2_operators_mwords.v Sail2_operators.v Sail2_values.v Sail2_state_monad.v Sail2_state.v Sail2_string.v Sail2_real.v
+SRC=Sail2_prompt_monad.v Sail2_prompt.v Sail2_impl_base.v Sail2_instr_kinds.v Sail2_operators_bitlists.v Sail2_operators_mwords.v Sail2_operators.v Sail2_values.v Sail2_state_monad.v Sail2_state.v Sail2_state_lifting.v Sail2_string.v Sail2_real.v
COQ_LIBS = -R . Sail -R "$(BBV_DIR)/theories" bbv
diff --git a/lib/coq/Sail2_instr_kinds.v b/lib/coq/Sail2_instr_kinds.v
index c6fb866b..338bf10b 100644
--- a/lib/coq/Sail2_instr_kinds.v
+++ b/lib/coq/Sail2_instr_kinds.v
@@ -48,14 +48,13 @@
(* SUCH DAMAGE. *)
(*========================================================================*)
+Require Import DecidableClass.
-(*
-
-class ( EnumerationType 'a )
- val toNat : 'a -> nat
-end
-
+Class EnumerationType (A : Type) := {
+ toNat : A -> nat
+}.
+(*
val enumeration_typeCompare : forall 'a. EnumerationType 'a => 'a -> 'a -> ordering
let ~{ocaml} enumeration_typeCompare e1 e2 :=
compare (toNat e1) (toNat e2)
@@ -89,6 +88,7 @@ Inductive read_kind :=
(* x86 reads *)
| Read_X86_locked (* the read part of a lock'd instruction (rmw) *)
.
+Scheme Equality for read_kind.
(*
instance (Show read_kind)
let show := function
@@ -121,6 +121,7 @@ Inductive write_kind :=
(* x86 writes *)
| Write_X86_locked (* the write part of a lock'd instruction (rmw) *)
.
+Scheme Equality for write_kind.
(*
instance (Show write_kind)
let show := function
@@ -161,6 +162,7 @@ Inductive barrier_kind :=
| Barrier_RISCV_i
(* X86 *)
| Barrier_x86_MFENCE.
+Scheme Equality for barrier_kind.
(*
instance (Show barrier_kind)
@@ -196,6 +198,7 @@ end*)
Inductive trans_kind :=
(* AArch64 *)
| Transaction_start | Transaction_commit | Transaction_abort.
+Scheme Equality for trans_kind.
(*
instance (Show trans_kind)
let show := function
diff --git a/lib/coq/Sail2_operators_mwords.v b/lib/coq/Sail2_operators_mwords.v
index 7e4abe29..ebab269f 100644
--- a/lib/coq/Sail2_operators_mwords.v
+++ b/lib/coq/Sail2_operators_mwords.v
@@ -497,3 +497,4 @@ Definition set_slice_int len n lo (v : mword len) : Z :=
else n.
Definition prerr_bits {a} (s : string) (bs : mword a) : unit := tt.
+Definition print_bits {a} (s : string) (bs : mword a) : unit := tt.
diff --git a/lib/coq/Sail2_prompt.v b/lib/coq/Sail2_prompt.v
index bae8381e..8efd66f0 100644
--- a/lib/coq/Sail2_prompt.v
+++ b/lib/coq/Sail2_prompt.v
@@ -129,11 +129,11 @@ wfR) y)
end.
Definition Zwf_guarded (z:Z) : Acc (Zwf 0) z :=
- match z with
+ Acc_intro _ (fun y H => match z with
| Zpos p => pos_guard_wf p (Zwf_well_founded _) _
- | _ => Zwf_well_founded _ _
- end.
-
+ | Zneg p => pos_guard_wf p (Zwf_well_founded _) _
+ | Z0 => Zwf_well_founded _ _
+ end).
(*val whileM : forall 'rv 'vars 'e. 'vars -> ('vars -> monad 'rv bool 'e) ->
('vars -> monad 'rv 'vars 'e) -> monad 'rv 'vars 'e
diff --git a/lib/coq/Sail2_prompt_monad.v b/lib/coq/Sail2_prompt_monad.v
index 2715b5e7..39567520 100644
--- a/lib/coq/Sail2_prompt_monad.v
+++ b/lib/coq/Sail2_prompt_monad.v
@@ -2,27 +2,28 @@ Require Import String.
(*Require Import Sail_impl_base*)
Require Import Sail2_instr_kinds.
Require Import Sail2_values.
-
-
+Require bbv.Word.
+Import ListNotations.
Definition register_name := string.
Definition address := list bitU.
Inductive monad regval a e :=
| Done : a -> monad regval a e
- (* Read a number : bytes from memory, returned in little endian order *)
- | Read_mem : read_kind -> address -> nat -> (list memory_byte -> monad regval a e) -> monad regval a e
- (* Read the tag : a memory address *)
- | Read_tag : address -> (bitU -> monad regval a e) -> monad regval a e
- (* Tell the system a write is imminent, at address lifted, : size nat *)
- | Write_ea : write_kind -> address -> nat -> monad regval a e -> monad regval a e
+ (* Read a number of bytes from memory, returned in little endian order,
+ with or without a tag. The first nat specifies the address, the second
+ the number of bytes. *)
+ | Read_mem : read_kind -> nat -> nat -> (list memory_byte -> monad regval a e) -> monad regval a e
+ | Read_memt : read_kind -> nat -> nat -> ((list memory_byte * bitU) -> monad regval a e) -> monad regval a e
+ (* Tell the system a write is imminent, at the given address and with the
+ given size. *)
+ | Write_ea : write_kind -> nat -> nat -> monad regval a e -> monad regval a e
(* Request the result : store-exclusive *)
| Excl_res : (bool -> monad regval a e) -> monad regval a e
- (* Request to write memory at last signalled address. Memory value should be 8
- times the size given in ea signal, given in little endian order *)
- | Write_memv : list memory_byte -> (bool -> monad regval a e) -> monad regval a e
- (* Request to write the tag at last signalled address. *)
- | Write_tag : address -> bitU -> (bool -> monad regval a e) -> monad regval a e
+ (* Request to write a memory value of the given size at the given address,
+ with or without a tag. *)
+ | Write_mem : write_kind -> nat -> nat -> list memory_byte -> (bool -> monad regval a e) -> monad regval a e
+ | Write_memt : write_kind -> nat -> nat -> list memory_byte -> bitU -> (bool -> monad regval a e) -> monad regval a e
(* Tell the system to dynamically recalculate dependency footprint *)
| Footprint : monad regval a e -> monad regval a e
(* Request a memory barrier *)
@@ -31,50 +32,70 @@ Inductive monad regval a e :=
| Read_reg : register_name -> (regval -> monad regval a e) -> monad regval a e
(* Request to write register *)
| Write_reg : register_name -> regval -> monad regval a e -> monad regval a e
- | Undefined : (bool -> monad regval a e) -> monad regval a e
- (*Result : a failed assert with possible error message to report*)
+ (* Request to choose a Boolean, e.g. to resolve an undefined bit. The string
+ argument may be used to provide information to the system about what the
+ Boolean is going to be used for. *)
+ | Choose : string -> (bool -> monad regval a e) -> monad regval a e
+ (* Print debugging or tracing information *)
+ | Print : string -> monad regval a e -> monad regval a e
+ (*Result of a failed assert with possible error message to report*)
| Fail : string -> monad regval a e
- | Error : string -> monad regval a e
- (* Exception : type e *)
+ (* Exception of type e *)
| Exception : e -> monad regval a e.
- (* TODO: Reading/writing tags *)
Arguments Done [_ _ _].
Arguments Read_mem [_ _ _].
-Arguments Read_tag [_ _ _].
+Arguments Read_memt [_ _ _].
Arguments Write_ea [_ _ _].
Arguments Excl_res [_ _ _].
-Arguments Write_memv [_ _ _].
-Arguments Write_tag [_ _ _].
+Arguments Write_mem [_ _ _].
+Arguments Write_memt [_ _ _].
Arguments Footprint [_ _ _].
Arguments Barrier [_ _ _].
Arguments Read_reg [_ _ _].
Arguments Write_reg [_ _ _].
-Arguments Undefined [_ _ _].
+Arguments Choose [_ _ _].
+Arguments Print [_ _ _].
Arguments Fail [_ _ _].
-Arguments Error [_ _ _].
Arguments Exception [_ _ _].
+Inductive event {regval} :=
+ | E_read_mem : read_kind -> nat -> nat -> list memory_byte -> event
+ | E_read_memt : read_kind -> nat -> nat -> (list memory_byte * bitU) -> event
+ | E_write_mem : write_kind -> nat -> nat -> list memory_byte -> bool -> event
+ | E_write_memt : write_kind -> nat -> nat -> list memory_byte -> bitU -> bool -> event
+ | E_write_ea : write_kind -> nat -> nat -> event
+ | E_excl_res : bool -> event
+ | E_barrier : barrier_kind -> event
+ | E_footprint : event
+ | E_read_reg : register_name -> regval -> event
+ | E_write_reg : register_name -> regval -> event
+ | E_choose : string -> bool -> event
+ | E_print : string -> event.
+Arguments event : clear implicits.
+
+Definition trace regval := list (event regval).
+
(*val return : forall rv a e. a -> monad rv a e*)
Definition returnm {rv A E} (a : A) : monad rv A E := Done a.
(*val bind : forall rv a b e. monad rv a e -> (a -> monad rv b e) -> monad rv b e*)
Fixpoint bind {rv A B E} (m : monad rv A E) (f : A -> monad rv B E) := match m with
| Done a => f a
- | Read_mem rk a sz k => Read_mem rk a sz (fun v => bind (k v) f)
- | Read_tag a k => Read_tag a (fun v => bind (k v) f)
- | Write_memv descr k => Write_memv descr (fun v => bind (k v) f)
- | Write_tag a t k => Write_tag a t (fun v => bind (k v) f)
- | Read_reg descr k => Read_reg descr (fun v => bind (k v) f)
- | Excl_res k => Excl_res (fun v => bind (k v) f)
- | Undefined k => Undefined (fun v => bind (k v) f)
- | Write_ea wk a sz k => Write_ea wk a sz (bind k f)
- | Footprint k => Footprint (bind k f)
- | Barrier bk k => Barrier bk (bind k f)
- | Write_reg r v k => Write_reg r v (bind k f)
- | Fail descr => Fail descr
- | Error descr => Error descr
- | Exception e => Exception e
+ | Read_mem rk a sz k => Read_mem rk a sz (fun v => bind (k v) f)
+ | Read_memt rk a sz k => Read_memt rk a sz (fun v => bind (k v) f)
+ | Write_mem wk a sz v k => Write_mem wk a sz v (fun v => bind (k v) f)
+ | Write_memt wk a sz v t k => Write_memt wk a sz v t (fun v => bind (k v) f)
+ | Read_reg descr k => Read_reg descr (fun v => bind (k v) f)
+ | Excl_res k => Excl_res (fun v => bind (k v) f)
+ | Choose descr k => Choose descr (fun v => bind (k v) f)
+ | Write_ea wk a sz k => Write_ea wk a sz (bind k f)
+ | Footprint k => Footprint (bind k f)
+ | Barrier bk k => Barrier bk (bind k f)
+ | Write_reg r v k => Write_reg r v (bind k f)
+ | Print msg k => Print msg (bind k f)
+ | Fail descr => Fail descr
+ | Exception e => Exception e
end.
Notation "m >>= f" := (bind m f) (at level 50, left associativity).
@@ -86,8 +107,11 @@ Notation "m >> n" := (bind0 m n) (at level 50, left associativity).
(*val exit : forall rv a e. unit -> monad rv a e*)
Definition exit {rv A E} (_ : unit) : monad rv A E := Fail "exit".
+(*val choose_bool : forall 'rv 'e. string -> monad 'rv bool 'e*)
+Definition choose_bool {rv E} descr : monad rv bool E := Choose descr returnm.
+
(*val undefined_bool : forall 'rv 'e. unit -> monad 'rv bool 'e*)
-Definition undefined_bool {rv e} (_:unit) : monad rv bool e := Undefined returnm.
+Definition undefined_bool {rv e} (_:unit) : monad rv bool e := choose_bool "undefined_bool".
(*val assert_exp : forall rv e. bool -> string -> monad rv unit e*)
Definition assert_exp {rv E} (exp :bool) msg : monad rv unit E :=
@@ -104,21 +128,21 @@ Definition throw {rv A E} e : monad rv A E := Exception e.
(*val try_catch : forall rv a e1 e2. monad rv a e1 -> (e1 -> monad rv a e2) -> monad rv a e2*)
Fixpoint try_catch {rv A E1 E2} (m : monad rv A E1) (h : E1 -> monad rv A E2) := match m with
- | Done a => Done a
- | Read_mem rk a sz k => Read_mem rk a sz (fun v => try_catch (k v) h)
- | Read_tag a k => Read_tag a (fun v => try_catch (k v) h)
- | Write_memv descr k => Write_memv descr (fun v => try_catch (k v) h)
- | Write_tag a t k => Write_tag a t (fun v => try_catch (k v) h)
- | Read_reg descr k => Read_reg descr (fun v => try_catch (k v) h)
- | Excl_res k => Excl_res (fun v => try_catch (k v) h)
- | Undefined k => Undefined (fun v => try_catch (k v) h)
- | Write_ea wk a sz k => Write_ea wk a sz (try_catch k h)
- | Footprint k => Footprint (try_catch k h)
- | Barrier bk k => Barrier bk (try_catch k h)
- | Write_reg r v k => Write_reg r v (try_catch k h)
- | Fail descr => Fail descr
- | Error descr => Error descr
- | Exception e => h e
+ | Done a => Done a
+ | Read_mem rk a sz k => Read_mem rk a sz (fun v => try_catch (k v) h)
+ | Read_memt rk a sz k => Read_memt rk a sz (fun v => try_catch (k v) h)
+ | Write_mem wk a sz v k => Write_mem wk a sz v (fun v => try_catch (k v) h)
+ | Write_memt wk a sz v t k => Write_memt wk a sz v t (fun v => try_catch (k v) h)
+ | Read_reg descr k => Read_reg descr (fun v => try_catch (k v) h)
+ | Excl_res k => Excl_res (fun v => try_catch (k v) h)
+ | Choose descr k => Choose descr (fun v => try_catch (k v) h)
+ | Write_ea wk a sz k => Write_ea wk a sz (try_catch k h)
+ | Footprint k => Footprint (try_catch k h)
+ | Barrier bk k => Barrier bk (try_catch k h)
+ | Write_reg r v k => Write_reg r v (try_catch k h)
+ | Print msg k => Print msg (try_catch k h)
+ | Fail descr => Fail descr
+ | Exception e => h e
end.
(* For early return, we abuse exceptions by throwing and catching
@@ -158,9 +182,23 @@ match x with
| None => Fail msg
end.
+(*val read_memt_bytes : forall 'rv 'a 'b 'e. Bitvector 'a, Bitvector 'b => read_kind -> 'a -> integer -> monad 'rv (list memory_byte * bitU) 'e*)
+Definition read_memt_bytes {rv A E} rk (addr : mword A) sz : monad rv (list memory_byte * bitU) E :=
+ Read_memt rk (Word.wordToNat (get_word addr)) (Z.to_nat sz) returnm.
+
+(*val read_memt : forall 'rv 'a 'b 'e. Bitvector 'a, Bitvector 'b => read_kind -> 'a -> integer -> monad 'rv ('b * bitU) 'e*)
+Definition read_memt {rv A B E} `{ArithFact (B >= 0)} rk (addr : mword A) sz : monad rv (mword B * bitU) E :=
+ bind
+ (read_memt_bytes rk addr sz)
+ (fun '(bytes, tag) =>
+ match of_bits (bits_of_mem_bytes bytes) with
+ | Some v => returnm (v, tag)
+ | None => Fail "bits_of_mem_bytes"
+ end).
+
(*val read_mem_bytes : forall 'rv 'a 'b 'e. Bitvector 'a, Bitvector 'b => read_kind -> 'a -> integer -> monad 'rv (list memory_byte) 'e*)
Definition read_mem_bytes {rv A E} rk (addr : mword A) sz : monad rv (list memory_byte) E :=
- Read_mem rk (bits_of addr) (Z.to_nat sz) returnm.
+ Read_mem rk (Word.wordToNat (get_word addr)) (Z.to_nat sz) returnm.
(*val read_mem : forall 'rv 'a 'b 'e. Bitvector 'a, Bitvector 'b => read_kind -> 'a -> integer -> monad 'rv 'b 'e*)
Definition read_mem {rv A B E} `{ArithFact (B >= 0)} rk (addr : mword A) sz : monad rv (mword B) E :=
@@ -169,50 +207,56 @@ Definition read_mem {rv A B E} `{ArithFact (B >= 0)} rk (addr : mword A) sz : mo
(fun bytes =>
maybe_fail "bits_of_mem_bytes" (of_bits (bits_of_mem_bytes bytes))).
-(*val read_tag : forall rv a e. Bitvector a => a -> monad rv bitU e*)
-Definition read_tag {rv a e} `{Bitvector a} (addr : a) : monad rv bitU e :=
- Read_tag (bits_of addr) returnm.
-
(*val excl_result : forall rv e. unit -> monad rv bool e*)
Definition excl_result {rv e} (_:unit) : monad rv bool e :=
let k successful := (returnm successful) in
Excl_res k.
-Definition write_mem_ea {rv a E} `{Bitvector a} wk (addr: a) sz : monad rv unit E :=
- Write_ea wk (bits_of addr) (Z.to_nat sz) (Done tt).
-
-Definition write_mem_val {rv a e} `{Bitvector a} (v : a) : monad rv bool e := match mem_bytes_of_bits v with
- | Some v => Write_memv v returnm
- | None => Fail "write_mem_val"
-end.
-
-(*val write_tag : forall rv a e. Bitvector 'a => 'a -> bitU -> monad rv bool e*)
-Definition write_tag {rv a e} (addr : mword a) (b : bitU) : monad rv bool e := Write_tag (bits_of addr) b returnm.
+Definition write_mem_ea {rv a E} wk (addr: mword a) sz : monad rv unit E :=
+ Write_ea wk (Word.wordToNat (get_word addr)) (Z.to_nat sz) (Done tt).
+
+(*val write_mem : forall 'rv 'a 'b 'e. Bitvector 'a, Bitvector 'b =>
+ write_kind -> 'a -> integer -> 'b -> monad 'rv bool 'e*)
+Definition write_mem {rv a b E} wk (addr : mword a) sz (v : mword b) : monad rv bool E :=
+ match (mem_bytes_of_bits v, Word.wordToNat (get_word addr)) with
+ | (Some v, addr) =>
+ Write_mem wk addr (Z.to_nat sz) v returnm
+ | _ => Fail "write_mem"
+ end.
+
+(*val write_memt : forall 'rv 'a 'b 'e. Bitvector 'a, Bitvector 'b =>
+ write_kind -> 'a -> integer -> 'b -> bitU -> monad 'rv bool 'e*)
+Definition write_memt {rv a b E} wk (addr : mword a) sz (v : mword b) tag : monad rv bool E :=
+ match (mem_bytes_of_bits v, Word.wordToNat (get_word addr)) with
+ | (Some v, addr) =>
+ Write_memt wk addr (Z.to_nat sz) v tag returnm
+ | _ => Fail "write_mem"
+ end.
Definition read_reg {s rv a e} (reg : register_ref s rv a) : monad rv a e :=
let k v :=
match reg.(of_regval) v with
| Some v => Done v
- | None => Error "read_reg: unrecognised value"
+ | None => Fail "read_reg: unrecognised value"
end
in
Read_reg reg.(name) k.
(* TODO
-val read_reg_range : forall s r rv a e. Bitvector a => register_ref s rv r -> integer -> integer -> monad rv a e
-Definition read_reg_range reg i j :=
- read_reg_aux of_bits (external_reg_slice reg (natFromInteger i,natFromInteger j))
+val read_reg_range : forall 's 'r 'rv 'a 'e. Bitvector 'a => register_ref 's 'rv 'r -> integer -> integer -> monad 'rv 'a 'e
+let read_reg_range reg i j =
+ read_reg_aux of_bits (external_reg_slice reg (nat_of_int i,nat_of_int j))
-Definition read_reg_bit reg i :=
- read_reg_aux (fun v -> v) (external_reg_slice reg (natFromInteger i,natFromInteger i)) >>= fun v ->
- returnm (extract_only_element v)
+let read_reg_bit reg i =
+ read_reg_aux (fun v -> v) (external_reg_slice reg (nat_of_int i,nat_of_int i)) >>= fun v ->
+ return (extract_only_element v)
-Definition read_reg_field reg regfield :=
+let read_reg_field reg regfield =
read_reg_aux (external_reg_field_whole reg regfield)
-Definition read_reg_bitfield reg regfield :=
+let read_reg_bitfield reg regfield =
read_reg_aux (external_reg_field_whole reg regfield) >>= fun v ->
- returnm (extract_only_element v)*)
+ return (extract_only_element v)*)
Definition reg_deref {s rv a e} := @read_reg s rv a e.
@@ -221,27 +265,101 @@ Definition write_reg {s rv a e} (reg : register_ref s rv a) (v : a) : monad rv u
Write_reg reg.(name) (reg.(regval_of) v) (Done tt).
(* TODO
-Definition write_reg reg v :=
+let write_reg reg v =
write_reg_aux (external_reg_whole reg) v
-Definition write_reg_range reg i j v :=
- write_reg_aux (external_reg_slice reg (natFromInteger i,natFromInteger j)) v
-Definition write_reg_pos reg i v :=
- let iN := natFromInteger i in
+let write_reg_range reg i j v =
+ write_reg_aux (external_reg_slice reg (nat_of_int i,nat_of_int j)) v
+let write_reg_pos reg i v =
+ let iN = nat_of_int i in
write_reg_aux (external_reg_slice reg (iN,iN)) [v]
-Definition write_reg_bit := write_reg_pos
-Definition write_reg_field reg regfield v :=
+let write_reg_bit = write_reg_pos
+let write_reg_field reg regfield v =
write_reg_aux (external_reg_field_whole reg regfield.field_name) v
-Definition write_reg_field_bit reg regfield bit :=
+let write_reg_field_bit reg regfield bit =
write_reg_aux (external_reg_field_whole reg regfield.field_name)
(Vector [bit] 0 (is_inc_of_reg reg))
-Definition write_reg_field_range reg regfield i j v :=
- write_reg_aux (external_reg_field_slice reg regfield.field_name (natFromInteger i,natFromInteger j)) v
-Definition write_reg_field_pos reg regfield i v :=
+let write_reg_field_range reg regfield i j v =
+ write_reg_aux (external_reg_field_slice reg regfield.field_name (nat_of_int i,nat_of_int j)) v
+let write_reg_field_pos reg regfield i v =
write_reg_field_range reg regfield i i [v]
-Definition write_reg_field_bit := write_reg_field_pos*)
+let write_reg_field_bit = write_reg_field_pos*)
(*val barrier : forall rv e. barrier_kind -> monad rv unit e*)
Definition barrier {rv e} bk : monad rv unit e := Barrier bk (Done tt).
(*val footprint : forall rv e. unit -> monad rv unit e*)
Definition footprint {rv e} (_ : unit) : monad rv unit e := Footprint (Done tt).
+
+(* Event traces *)
+
+Local Open Scope bool_scope.
+
+(*val emitEvent : forall 'regval 'a 'e. Eq 'regval => monad 'regval 'a 'e -> event 'regval -> maybe (monad 'regval 'a 'e)*)
+Definition emitEvent {Regval A E} `{forall (x y : Regval), Decidable (x = y)} (m : monad Regval A E) (e : event Regval) : option (monad Regval A E) :=
+ match (e, m) with
+ | (E_read_mem rk a sz v, Read_mem rk' a' sz' k) =>
+ if read_kind_beq rk' rk && Nat.eqb a' a && Nat.eqb sz' sz then Some (k v) else None
+ | (E_read_memt rk a sz vt, Read_memt rk' a' sz' k) =>
+ if read_kind_beq rk' rk && Nat.eqb a' a && Nat.eqb sz' sz then Some (k vt) else None
+ | (E_write_mem wk a sz v r, Write_mem wk' a' sz' v' k) =>
+ if write_kind_beq wk' wk && Nat.eqb a' a && Nat.eqb sz' sz && generic_eq v' v then Some (k r) else None
+ | (E_write_memt wk a sz v tag r, Write_memt wk' a' sz' v' tag' k) =>
+ if write_kind_beq wk' wk && Nat.eqb a' a && Nat.eqb sz' sz && generic_eq v' v && generic_eq tag' tag then Some (k r) else None
+ | (E_read_reg r v, Read_reg r' k) =>
+ if generic_eq r' r then Some (k v) else None
+ | (E_write_reg r v, Write_reg r' v' k) =>
+ if generic_eq r' r && generic_eq v' v then Some k else None
+ | (E_write_ea wk a sz, Write_ea wk' a' sz' k) =>
+ if write_kind_beq wk' wk && Nat.eqb a' a && Nat.eqb sz' sz then Some k else None
+ | (E_barrier bk, Barrier bk' k) =>
+ if barrier_kind_beq bk' bk then Some k else None
+ | (E_print m, Print m' k) =>
+ if generic_eq m' m then Some k else None
+ | (E_excl_res v, Excl_res k) => Some (k v)
+ | (E_choose descr v, Choose descr' k) => if generic_eq descr' descr then Some (k v) else None
+ | (E_footprint, Footprint k) => Some k
+ | _ => None
+end.
+
+Definition option_bind {A B : Type} (a : option A) (f : A -> option B) : option B :=
+match a with
+| Some x => f x
+| None => None
+end.
+
+(*val runTrace : forall 'regval 'a 'e. Eq 'regval => trace 'regval -> monad 'regval 'a 'e -> maybe (monad 'regval 'a 'e)*)
+Fixpoint runTrace {Regval A E} `{forall (x y : Regval), Decidable (x = y)} (t : trace Regval) (m : monad Regval A E) : option (monad Regval A E) :=
+match t with
+ | [] => Some m
+ | e :: t' => option_bind (emitEvent m e) (runTrace t')
+end.
+
+(*val final : forall 'regval 'a 'e. monad 'regval 'a 'e -> bool*)
+Definition final {Regval A E} (m : monad Regval A E) : bool :=
+match m with
+ | Done _ => true
+ | Fail _ => true
+ | Exception _ => true
+ | _ => false
+end.
+
+(*val hasTrace : forall 'regval 'a 'e. Eq 'regval => trace 'regval -> monad 'regval 'a 'e -> bool*)
+Definition hasTrace {Regval A E} `{forall (x y : Regval), Decidable (x = y)} (t : trace Regval) (m : monad Regval A E) : bool :=
+match runTrace t m with
+ | Some m => final m
+ | None => false
+end.
+
+(*val hasException : forall 'regval 'a 'e. Eq 'regval => trace 'regval -> monad 'regval 'a 'e -> bool*)
+Definition hasException {Regval A E} `{forall (x y : Regval), Decidable (x = y)} (t : trace Regval) (m : monad Regval A E) :=
+match runTrace t m with
+ | Some (Exception _) => true
+ | _ => false
+end.
+
+(*val hasFailure : forall 'regval 'a 'e. Eq 'regval => trace 'regval -> monad 'regval 'a 'e -> bool*)
+Definition hasFailure {Regval A E} `{forall (x y : Regval), Decidable (x = y)} (t : trace Regval) (m : monad Regval A E) :=
+match runTrace t m with
+ | Some (Fail _) => true
+ | _ => false
+end.
diff --git a/lib/coq/Sail2_state.v b/lib/coq/Sail2_state.v
index 1d5cb342..b73d5013 100644
--- a/lib/coq/Sail2_state.v
+++ b/lib/coq/Sail2_state.v
@@ -3,53 +3,82 @@ Require Import Sail2_values.
Require Import Sail2_prompt_monad.
Require Import Sail2_prompt.
Require Import Sail2_state_monad.
-(*
-(* State monad wrapper around prompt monad *)
-
-val liftState : forall 'regval 'regs 'a 'e. register_accessors 'regs 'regval -> monad 'regval 'a 'e -> monadS 'regs 'a 'e
-let rec liftState ra s = match s with
- | (Done a) -> returnS a
- | (Read_mem rk a sz k) -> bindS (read_mem_bytesS rk a sz) (fun v -> liftState ra (k v))
- | (Read_tag t k) -> bindS (read_tagS t) (fun v -> liftState ra (k v))
- | (Write_memv a k) -> bindS (write_mem_bytesS a) (fun v -> liftState ra (k v))
- | (Write_tagv t k) -> bindS (write_tagS t) (fun v -> liftState ra (k v))
- | (Read_reg r k) -> bindS (read_regvalS ra r) (fun v -> liftState ra (k v))
- | (Excl_res k) -> bindS (excl_resultS ()) (fun v -> liftState ra (k v))
- | (Undefined k) -> bindS (undefined_boolS ()) (fun v -> liftState ra (k v))
- | (Write_ea wk a sz k) -> seqS (write_mem_eaS wk a sz) (liftState ra k)
- | (Write_reg r v k) -> seqS (write_regvalS ra r v) (liftState ra k)
- | (Footprint k) -> liftState ra k
- | (Barrier _ k) -> liftState ra k
- | (Fail descr) -> failS descr
- | (Error descr) -> failS descr
- | (Exception e) -> throwS e
-end
-
-
-val iterS_aux : forall 'rv 'a 'e. integer -> (integer -> 'a -> monadS 'rv unit 'e) -> list 'a -> monadS 'rv unit 'e
-let rec iterS_aux i f xs = match xs with
- | x :: xs -> f i x >>$ iterS_aux (i + 1) f xs
- | [] -> returnS ()
- end
+Import ListNotations.
-declare {isabelle} termination_argument iterS_aux = automatic
+(*val iterS_aux : forall 'rv 'a 'e. integer -> (integer -> 'a -> monadS 'rv unit 'e) -> list 'a -> monadS 'rv unit 'e*)
+Fixpoint iterS_aux {RV A E} i (f : Z -> A -> monadS RV unit E) (xs : list A) :=
+ match xs with
+ | x :: xs => f i x >>$ iterS_aux (i + 1) f xs
+ | [] => returnS tt
+ end.
-val iteriS : forall 'rv 'a 'e. (integer -> 'a -> monadS 'rv unit 'e) -> list 'a -> monadS 'rv unit 'e
-let iteriS f xs = iterS_aux 0 f xs
+(*val iteriS : forall 'rv 'a 'e. (integer -> 'a -> monadS 'rv unit 'e) -> list 'a -> monadS 'rv unit 'e*)
+Definition iteriS {RV A E} (f : Z -> A -> monadS RV unit E) (xs : list A) : monadS RV unit E :=
+ iterS_aux 0 f xs.
-val iterS : forall 'rv 'a 'e. ('a -> monadS 'rv unit 'e) -> list 'a -> monadS 'rv unit 'e
-let iterS f xs = iteriS (fun _ x -> f x) xs
+(*val iterS : forall 'rv 'a 'e. ('a -> monadS 'rv unit 'e) -> list 'a -> monadS 'rv unit 'e*)
+Definition iterS {RV A E} (f : A -> monadS RV unit E) (xs : list A) : monadS RV unit E :=
+ iteriS (fun _ x => f x) xs.
-val foreachS : forall 'a 'rv 'vars 'e.
- list 'a -> 'vars -> ('a -> 'vars -> monadS 'rv 'vars 'e) -> monadS 'rv 'vars 'e
-let rec foreachS xs vars body = match xs with
- | [] -> returnS vars
- | x :: xs ->
- body x vars >>$= fun vars ->
+(*val foreachS : forall 'a 'rv 'vars 'e.
+ list 'a -> 'vars -> ('a -> 'vars -> monadS 'rv 'vars 'e) -> monadS 'rv 'vars 'e*)
+Fixpoint foreachS {A RV Vars E} (xs : list A) (vars : Vars) (body : A -> Vars -> monadS RV Vars E) : monadS RV Vars E :=
+ match xs with
+ | [] => returnS vars
+ | x :: xs =>
+ body x vars >>$= fun vars =>
foreachS xs vars body
-end
+end.
+
+(*val genlistS : forall 'a 'rv 'e. (nat -> monadS 'rv 'a 'e) -> nat -> monadS 'rv (list 'a) 'e*)
+Definition genlistS {A RV E} (f : nat -> monadS RV A E) n : monadS RV (list A) E :=
+ let indices := genlist (fun n => n) n in
+ foreachS indices [] (fun n xs => (f n >>$= (fun x => returnS (xs ++ [x])))).
+
+(*val and_boolS : forall 'rv 'e. monadS 'rv bool 'e -> monadS 'rv bool 'e -> monadS 'rv bool 'e*)
+Definition and_boolS {RV E} (l r : monadS RV bool E) : monadS RV bool E :=
+ l >>$= (fun l => if l then r else returnS false).
+
+(*val or_boolS : forall 'rv 'e. monadS 'rv bool 'e -> monadS 'rv bool 'e -> monadS 'rv bool 'e*)
+Definition or_boolS {RV E} (l r : monadS RV bool E) : monadS RV bool E :=
+ l >>$= (fun l => if l then returnS true else r).
+
+(*val bool_of_bitU_fail : forall 'rv 'e. bitU -> monadS 'rv bool 'e*)
+Definition bool_of_bitU_fail {RV E} (b : bitU) : monadS RV bool E :=
+match b with
+ | B0 => returnS false
+ | B1 => returnS true
+ | BU => failS "bool_of_bitU"
+end.
+
+(*val bool_of_bitU_nondetS : forall 'rv 'e. bitU -> monadS 'rv bool 'e*)
+Definition bool_of_bitU_nondetS {RV E} (b : bitU) : monadS RV bool E :=
+match b with
+ | B0 => returnS false
+ | B1 => returnS true
+ | BU => undefined_boolS tt
+end.
+
+(*val bools_of_bits_nondetS : forall 'rv 'e. list bitU -> monadS 'rv (list bool) 'e*)
+Definition bools_of_bits_nondetS {RV E} bits : monadS RV (list bool) E :=
+ foreachS bits []
+ (fun b bools =>
+ bool_of_bitU_nondetS b >>$= (fun b =>
+ returnS (bools ++ [b]))).
-declare {isabelle} termination_argument foreachS = automatic
+(*val of_bits_nondetS : forall 'rv 'a 'e. Bitvector 'a => list bitU -> monadS 'rv 'a 'e*)
+Definition of_bits_nondetS {RV A E} bits `{ArithFact (A >= 0)} : monadS RV (mword A) E :=
+ bools_of_bits_nondetS bits >>$= (fun bs =>
+ returnS (of_bools bs)).
+
+(*val of_bits_failS : forall 'rv 'a 'e. Bitvector 'a => list bitU -> monadS 'rv 'a 'e*)
+Definition of_bits_failS {RV A E} bits `{ArithFact (A >= 0)} : monadS RV (mword A) E :=
+ maybe_failS "of_bits" (of_bits bits).
+
+(*val mword_nondetS : forall 'rv 'a 'e. Size 'a => unit -> monadS 'rv (mword 'a) 'e
+let mword_nondetS () =
+ bools_of_bits_nondetS (repeat [BU] (integerFromNat size)) >>$= (fun bs ->
+ returnS (wordFromBitlist bs))
val whileS : forall 'rv 'vars 'e. 'vars -> ('vars -> monadS 'rv bool 'e) ->
@@ -67,3 +96,21 @@ let rec untilS vars cond body s =
(cond vars >>$= (fun cond_val s'' ->
if cond_val then returnS vars s'' else untilS vars cond body s'')) s')) s
*)
+(*val choose_boolsS : forall 'rv 'e. nat -> monadS 'rv (list bool) 'e*)
+Definition choose_boolsS {RV E} n : monadS RV (list bool) E :=
+ genlistS (fun _ => choose_boolS tt) n.
+
+(* TODO: Replace by chooseS and prove equivalence to prompt monad version *)
+(*val internal_pickS : forall 'rv 'a 'e. list 'a -> monadS 'rv 'a 'e
+let internal_pickS xs =
+ (* Use sufficiently many nondeterministically chosen bits and convert into an
+ index into the list *)
+ choose_boolsS (List.length xs) >>$= fun bs ->
+ let idx = (natFromNatural (nat_of_bools bs)) mod List.length xs in
+ match index xs idx with
+ | Just x -> returnS x
+ | Nothing -> failS "choose internal_pick"
+ end
+
+
+*)
diff --git a/lib/coq/Sail2_state_lifting.v b/lib/coq/Sail2_state_lifting.v
new file mode 100644
index 00000000..633c0ef7
--- /dev/null
+++ b/lib/coq/Sail2_state_lifting.v
@@ -0,0 +1,61 @@
+Require Import Sail2_values.
+Require Import Sail2_prompt_monad.
+Require Import Sail2_prompt.
+Require Import Sail2_state_monad.
+Import ListNotations.
+
+(* Lifting from prompt monad to state monad *)
+(*val liftState : forall 'regval 'regs 'a 'e. register_accessors 'regs 'regval -> monad 'regval 'a 'e -> monadS 'regs 'a 'e*)
+Fixpoint liftState {Regval Regs A E} (ra : register_accessors Regs Regval) (m : monad Regval A E) : monadS Regs A E :=
+ match m with
+ | (Done a) => returnS a
+ | (Read_mem rk a sz k) => bindS (read_mem_bytesS rk a sz) (fun v => liftState ra (k v))
+ | (Read_memt rk a sz k) => bindS (read_memt_bytesS rk a sz) (fun v => liftState ra (k v))
+ | (Write_mem wk a sz v k) => bindS (write_mem_bytesS wk a sz v) (fun v => liftState ra (k v))
+ | (Write_memt wk a sz v t k) => bindS (write_memt_bytesS wk a sz v t) (fun v => liftState ra (k v))
+ | (Read_reg r k) => bindS (read_regvalS ra r) (fun v => liftState ra (k v))
+ | (Excl_res k) => bindS (excl_resultS tt) (fun v => liftState ra (k v))
+ | (Choose _ k) => bindS (choose_boolS tt) (fun v => liftState ra (k v))
+ | (Write_reg r v k) => seqS (write_regvalS ra r v) (liftState ra k)
+ | (Write_ea _ _ _ k) => liftState ra k
+ | (Footprint k) => liftState ra k
+ | (Barrier _ k) => liftState ra k
+ | (Print _ k) => liftState ra k (* TODO *)
+ | (Fail descr) => failS descr
+ | (Exception e) => throwS e
+end.
+
+Local Open Scope bool_scope.
+
+(*val emitEventS : forall 'regval 'regs 'a 'e. Eq 'regval => register_accessors 'regs 'regval -> event 'regval -> sequential_state 'regs -> maybe (sequential_state 'regs)*)
+Definition emitEventS {Regval Regs} `{forall (x y : Regval), Decidable (x = y)} (ra : register_accessors Regs Regval) (e : event Regval) (s : sequential_state Regs) : option (sequential_state Regs) :=
+match e with
+ | E_read_mem _ addr sz v =>
+ option_bind (get_mem_bytes addr sz s) (fun '(v', _) =>
+ if generic_eq v' v then Some s else None)
+ | E_read_memt _ addr sz (v, tag) =>
+ option_bind (get_mem_bytes addr sz s) (fun '(v', tag') =>
+ if generic_eq v' v && generic_eq tag' tag then Some s else None)
+ | E_write_mem _ addr sz v success =>
+ if success then Some (put_mem_bytes addr sz v B0 s) else None
+ | E_write_memt _ addr sz v tag success =>
+ if success then Some (put_mem_bytes addr sz v tag s) else None
+ | E_read_reg r v =>
+ let (read_reg, _) := ra in
+ option_bind (read_reg r s.(regstate)) (fun v' =>
+ if generic_eq v' v then Some s else None)
+ | E_write_reg r v =>
+ let (_, write_reg) := ra in
+ option_bind (write_reg r v s.(regstate)) (fun rs' =>
+ Some {| regstate := rs'; memstate := s.(memstate); tagstate := s.(tagstate) |})
+ | _ => Some s
+end.
+
+Local Close Scope bool_scope.
+
+(*val runTraceS : forall 'regval 'regs 'a 'e. Eq 'regval => register_accessors 'regs 'regval -> trace 'regval -> sequential_state 'regs -> maybe (sequential_state 'regs)*)
+Fixpoint runTraceS {Regval Regs} `{forall (x y : Regval), Decidable (x = y)} (ra : register_accessors Regs Regval) (t : trace Regval) (s : sequential_state Regs) : option (sequential_state Regs) :=
+match t with
+ | [] => Some s
+ | e :: t' => option_bind (emitEventS ra e s) (runTraceS ra t')
+end.
diff --git a/lib/coq/Sail2_state_monad.v b/lib/coq/Sail2_state_monad.v
index c48db31b..235e4b9e 100644
--- a/lib/coq/Sail2_state_monad.v
+++ b/lib/coq/Sail2_state_monad.v
@@ -1,184 +1,237 @@
Require Import Sail2_instr_kinds.
Require Import Sail2_values.
-(*
-(* 'a is result type *)
-
-type memstate = map integer memory_byte
-type tagstate = map integer bitU
+Require FMapList.
+Require Import OrderedType.
+Require OrderedTypeEx.
+Require Import List.
+Require bbv.Word.
+Import ListNotations.
+
+(* TODO: revisit choice of FMapList *)
+Module NatMap := FMapList.Make(OrderedTypeEx.Nat_as_OT).
+
+Definition Memstate : Type := NatMap.t memory_byte.
+Definition Tagstate : Type := NatMap.t bitU.
(* type regstate = map string (vector bitU) *)
-type sequential_state 'regs =
- <| regstate : 'regs;
- memstate : memstate;
- tagstate : tagstate;
- write_ea : maybe (write_kind * integer * integer);
- last_exclusive_operation_was_load : bool|>
-
-val init_state : forall 'regs. 'regs -> sequential_state 'regs
-let init_state regs =
- <| regstate = regs;
- memstate = Map.empty;
- tagstate = Map.empty;
- write_ea = Nothing;
- last_exclusive_operation_was_load = false |>
-
-type ex 'e =
- | Failure of string
- | Throw of 'e
-
-type result 'a 'e =
- | Value of 'a
- | Ex of (ex 'e)
+Record sequential_state {Regs} :=
+ { regstate : Regs;
+ memstate : Memstate;
+ tagstate : Tagstate }.
+Arguments sequential_state : clear implicits.
+
+(*val init_state : forall 'regs. 'regs -> sequential_state 'regs*)
+Definition init_state {Regs} regs : sequential_state Regs :=
+ {| regstate := regs;
+ memstate := NatMap.empty _;
+ tagstate := NatMap.empty _ |}.
+
+Inductive ex E :=
+ | Failure : string -> ex E
+ | Throw : E -> ex E.
+Arguments Failure {E} _.
+Arguments Throw {E} _.
+
+Inductive result A E :=
+ | Value : A -> result A E
+ | Ex : ex E -> result A E.
+Arguments Value {A} {E} _.
+Arguments Ex {A} {E} _.
(* State, nondeterminism and exception monad with result value type 'a
and exception type 'e. *)
-type monadS 'regs 'a 'e = sequential_state 'regs -> list (result 'a 'e * sequential_state 'regs)
-
-val returnS : forall 'regs 'a 'e. 'a -> monadS 'regs 'a 'e
-let returnS a s = [(Value a,s)]
-
-val bindS : forall 'regs 'a 'b 'e. monadS 'regs 'a 'e -> ('a -> monadS 'regs 'b 'e) -> monadS 'regs 'b 'e
-let bindS m f (s : sequential_state 'regs) =
- List.concatMap (function
- | (Value a, s') -> f a s'
- | (Ex e, s') -> [(Ex e, s')]
- end) (m s)
-
-val seqS: forall 'regs 'b 'e. monadS 'regs unit 'e -> monadS 'regs 'b 'e -> monadS 'regs 'b 'e
-let seqS m n = bindS m (fun (_ : unit) -> n)
-
+(* TODO: the list was originally a set, can we reasonably go back to a set? *)
+Definition monadS Regs a e : Type :=
+ sequential_state Regs -> list (result a e * sequential_state Regs).
+
+(*val returnS : forall 'regs 'a 'e. 'a -> monadS 'regs 'a 'e*)
+Definition returnS {Regs A E} (a:A) : monadS Regs A E := fun s => [(Value a,s)].
+
+(*val bindS : forall 'regs 'a 'b 'e. monadS 'regs 'a 'e -> ('a -> monadS 'regs 'b 'e) -> monadS 'regs 'b 'e*)
+Definition bindS {Regs A B E} (m : monadS Regs A E) (f : A -> monadS Regs B E) : monadS Regs B E :=
+ fun (s : sequential_state Regs) =>
+ List.concat (List.map (fun v => match v with
+ | (Value a, s') => f a s'
+ | (Ex e, s') => [(Ex e, s')]
+ end) (m s)).
+
+(*val seqS: forall 'regs 'b 'e. monadS 'regs unit 'e -> monadS 'regs 'b 'e -> monadS 'regs 'b 'e*)
+Definition seqS {Regs B E} (m : monadS Regs unit E) (n : monadS Regs B E) : monadS Regs B E :=
+ bindS m (fun (_ : unit) => n).
+(*
let inline (>>$=) = bindS
let inline (>>$) = seqS
-
-val chooseS : forall 'regs 'a 'e. list 'a -> monadS 'regs 'a 'e
-let chooseS xs s = List.map (fun x -> (Value x, s)) xs
-
-val readS : forall 'regs 'a 'e. (sequential_state 'regs -> 'a) -> monadS 'regs 'a 'e
-let readS f = (fun s -> returnS (f s) s)
-
-val updateS : forall 'regs 'e. (sequential_state 'regs -> sequential_state 'regs) -> monadS 'regs unit 'e
-let updateS f = (fun s -> returnS () (f s))
-
-val failS : forall 'regs 'a 'e. string -> monadS 'regs 'a 'e
-let failS msg s = [(Ex (Failure msg), s)]
-
-val exitS : forall 'regs 'e 'a. unit -> monadS 'regs 'a 'e
-let exitS () = failS "exit"
-
-val throwS : forall 'regs 'a 'e. 'e -> monadS 'regs 'a 'e
-let throwS e s = [(Ex (Throw e), s)]
-
-val try_catchS : forall 'regs 'a 'e1 'e2. monadS 'regs 'a 'e1 -> ('e1 -> monadS 'regs 'a 'e2) -> monadS 'regs 'a 'e2
-let try_catchS m h s =
- List.concatMap (function
- | (Value a, s') -> returnS a s'
- | (Ex (Throw e), s') -> h e s'
- | (Ex (Failure msg), s') -> [(Ex (Failure msg), s')]
- end) (m s)
-
-val assert_expS : forall 'regs 'e. bool -> string -> monadS 'regs unit 'e
-let assert_expS exp msg = if exp then returnS () else failS msg
+*)
+Notation "m >>$= f" := (bindS m f) (at level 50, left associativity).
+Notation "m >>$ n" := (seqS m n) (at level 50, left associativity).
+
+(*val chooseS : forall 'regs 'a 'e. SetType 'a => list 'a -> monadS 'regs 'a 'e*)
+Definition chooseS {Regs A E} (xs : list A) : monadS Regs A E :=
+ fun s => (List.map (fun x => (Value x, s)) xs).
+
+(*val readS : forall 'regs 'a 'e. (sequential_state 'regs -> 'a) -> monadS 'regs 'a 'e*)
+Definition readS {Regs A E} (f : sequential_state Regs -> A) : monadS Regs A E :=
+ (fun s => returnS (f s) s).
+
+(*val updateS : forall 'regs 'e. (sequential_state 'regs -> sequential_state 'regs) -> monadS 'regs unit 'e*)
+Definition updateS {Regs E} (f : sequential_state Regs -> sequential_state Regs) : monadS Regs unit E :=
+ (fun s => returnS tt (f s)).
+
+(*val failS : forall 'regs 'a 'e. string -> monadS 'regs 'a 'e*)
+Definition failS {Regs A E} msg : monadS Regs A E :=
+ fun s => [(Ex (Failure msg), s)].
+
+(*val choose_boolS : forall 'regval 'regs 'a 'e. unit -> monadS 'regs bool 'e*)
+Definition choose_boolS {Regs E} (_:unit) : monadS Regs bool E :=
+ chooseS [false; true].
+Definition undefined_boolS {Regs E} := @choose_boolS Regs E.
+
+(*val exitS : forall 'regs 'e 'a. unit -> monadS 'regs 'a 'e*)
+Definition exitS {Regs A E} (_:unit) : monadS Regs A E := failS "exit".
+
+(*val throwS : forall 'regs 'a 'e. 'e -> monadS 'regs 'a 'e*)
+Definition throwS {Regs A E} (e : E) :monadS Regs A E :=
+ fun s => [(Ex (Throw e), s)].
+
+(*val try_catchS : forall 'regs 'a 'e1 'e2. monadS 'regs 'a 'e1 -> ('e1 -> monadS 'regs 'a 'e2) -> monadS 'regs 'a 'e2*)
+Definition try_catchS {Regs A E1 E2} (m : monadS Regs A E1) (h : E1 -> monadS Regs A E2) : monadS Regs A E2 :=
+fun s =>
+ List.concat (List.map (fun v => match v with
+ | (Value a, s') => returnS a s'
+ | (Ex (Throw e), s') => h e s'
+ | (Ex (Failure msg), s') => [(Ex (Failure msg), s')]
+ end) (m s)).
+
+(*val assert_expS : forall 'regs 'e. bool -> string -> monadS 'regs unit 'e*)
+Definition assert_expS {Regs E} (exp : bool) (msg : string) : monadS Regs unit E :=
+ if exp then returnS tt else failS msg.
(* For early return, we abuse exceptions by throwing and catching
the return value. The exception type is "either 'r 'e", where "Right e"
represents a proper exception and "Left r" an early return of value "r". *)
-type monadSR 'regs 'a 'r 'e = monadS 'regs 'a (either 'r 'e)
+Definition monadRS Regs A R E := monadS Regs A (sum R E).
-val early_returnS : forall 'regs 'a 'r 'e. 'r -> monadSR 'regs 'a 'r 'e
-let early_returnS r = throwS (Left r)
+(*val early_returnS : forall 'regs 'a 'r 'e. 'r -> monadRS 'regs 'a 'r 'e*)
+Definition early_returnS {Regs A R E} (r : R) : monadRS Regs A R E := throwS (inl r).
-val catch_early_returnS : forall 'regs 'a 'e. monadSR 'regs 'a 'a 'e -> monadS 'regs 'a 'e
-let catch_early_returnS m =
+(*val catch_early_returnS : forall 'regs 'a 'e. monadRS 'regs 'a 'a 'e -> monadS 'regs 'a 'e*)
+Definition catch_early_returnS {Regs A E} (m : monadRS Regs A A E) : monadS Regs A E :=
try_catchS m
- (function
- | Left a -> returnS a
- | Right e -> throwS e
- end)
+ (fun v => match v with
+ | inl a => returnS a
+ | inr e => throwS e
+ end).
(* Lift to monad with early return by wrapping exceptions *)
-val liftSR : forall 'a 'r 'regs 'e. monadS 'regs 'a 'e -> monadSR 'regs 'a 'r 'e
-let liftSR m = try_catchS m (fun e -> throwS (Right e))
+(*val liftRS : forall 'a 'r 'regs 'e. monadS 'regs 'a 'e -> monadRS 'regs 'a 'r 'e*)
+Definition liftRS {A R Regs E} (m : monadS Regs A E) : monadRS Regs A R E :=
+ try_catchS m (fun e => throwS (inr e)).
(* Catch exceptions in the presence of early returns *)
-val try_catchSR : forall 'regs 'a 'r 'e1 'e2. monadSR 'regs 'a 'r 'e1 -> ('e1 -> monadSR 'regs 'a 'r 'e2) -> monadSR 'regs 'a 'r 'e2
-let try_catchSR m h =
+(*val try_catchRS : forall 'regs 'a 'r 'e1 'e2. monadRS 'regs 'a 'r 'e1 -> ('e1 -> monadRS 'regs 'a 'r 'e2) -> monadRS 'regs 'a 'r 'e2*)
+Definition try_catchRS {Regs A R E1 E2} (m : monadRS Regs A R E1) (h : E1 -> monadRS Regs A R E2) : monadRS Regs A R E2 :=
try_catchS m
- (function
- | Left r -> throwS (Left r)
- | Right e -> h e
- end)
+ (fun v => match v with
+ | inl r => throwS (inl r)
+ | inr e => h e
+ end).
+
+(*val maybe_failS : forall 'regs 'a 'e. string -> maybe 'a -> monadS 'regs 'a 'e*)
+Definition maybe_failS {Regs A E} msg (v : option A) : monadS Regs A E :=
+match v with
+ | Some a => returnS a
+ | None => failS msg
+end.
+
+(*val read_tagS : forall 'regs 'a 'e. Bitvector 'a => 'a -> monadS 'regs bitU 'e*)
+Definition read_tagS {Regs A E} (addr : mword A) : monadS Regs bitU E :=
+ let addr := Word.wordToNat (get_word addr) in
+ readS (fun s => opt_def B0 (NatMap.find addr s.(tagstate))).
+
+Fixpoint genlist_acc {A:Type} (f : nat -> A) n acc : list A :=
+ match n with
+ | O => acc
+ | S n' => genlist_acc f n' (f n' :: acc)
+ end.
+Definition genlist {A} f n := @genlist_acc A f n [].
-val read_tagS : forall 'regs 'a 'e. Bitvector 'a => 'a -> monadS 'regs bitU 'e
-let read_tagS addr =
- readS (fun s -> fromMaybe B0 (Map.lookup (unsigned addr) s.tagstate))
(* Read bytes from memory and return in little endian order *)
-val read_mem_bytesS : forall 'regs 'e 'a. Bitvector 'a => read_kind -> 'a -> nat -> monadS 'regs (list memory_byte) 'e
-let read_mem_bytesS read_kind addr sz =
- let addr = unsigned addr in
- let sz = integerFromNat sz in
- let addrs = index_list addr (addr+sz-1) 1 in
- let read_byte s addr = Map.lookup addr s.memstate in
- readS (fun s -> just_list (List.map (read_byte s) addrs)) >>$= (function
- | Just mem_val ->
- updateS (fun s ->
- if read_is_exclusive read_kind
- then <| s with last_exclusive_operation_was_load = true |>
- else s) >>$
- returnS mem_val
- | Nothing -> failS "read_memS"
- end)
-
-val read_memS : forall 'regs 'e 'a 'b. Bitvector 'a, Bitvector 'b => read_kind -> 'a -> integer -> monadS 'regs 'b 'e
-let read_memS rk a sz =
- read_mem_bytesS rk a (natFromInteger sz) >>$= (fun bytes ->
- returnS (bits_of_mem_bytes bytes))
-
-val excl_resultS : forall 'regs 'e. unit -> monadS 'regs bool 'e
-let excl_resultS () =
- readS (fun s -> s.last_exclusive_operation_was_load) >>$= (fun excl_load ->
- updateS (fun s -> <| s with last_exclusive_operation_was_load = false |>) >>$
- chooseS (if excl_load then [false; true] else [false]))
-
-val write_mem_eaS : forall 'regs 'e 'a. Bitvector 'a => write_kind -> 'a -> nat -> monadS 'regs unit 'e
-let write_mem_eaS write_kind addr sz =
- let addr = unsigned addr in
- let sz = integerFromNat sz in
- updateS (fun s -> <| s with write_ea = Just (write_kind, addr, sz) |>)
-
-(* Write little-endian list of bytes to previously announced address *)
-val write_mem_bytesS : forall 'regs 'e. list memory_byte -> monadS 'regs bool 'e
-let write_mem_bytesS v =
- readS (fun s -> s.write_ea) >>$= (function
- | Nothing -> failS "write ea has not been announced yet"
- | Just (_, addr, sz) ->
- let addrs = index_list addr (addr+sz-1) 1 in
- (*let v = external_mem_value (bits_of v) in*)
- let a_v = List.zip addrs v in
- let write_byte mem (addr, v) = Map.insert addr v mem in
- updateS (fun s ->
- <| s with memstate = List.foldl write_byte s.memstate a_v |>) >>$
- returnS true
- end)
-
-val write_mem_valS : forall 'regs 'e 'a. Bitvector 'a => 'a -> monadS 'regs bool 'e
-let write_mem_valS v = match mem_bytes_of_bits v with
- | Just v -> write_mem_bytesS v
- | Nothing -> failS "write_mem_val"
-end
-
-val write_tagS : forall 'regs 'e. bitU -> monadS 'regs bool 'e
-let write_tagS t =
- readS (fun s -> s.write_ea) >>$= (function
- | Nothing -> failS "write ea has not been announced yet"
- | Just (_, addr, _) ->
- (*let taddr = addr / cap_alignment in*)
- updateS (fun s -> <| s with tagstate = Map.insert addr t s.tagstate |>) >>$
- returnS true
- end)
-
-val read_regS : forall 'regs 'rv 'a 'e. register_ref 'regs 'rv 'a -> monadS 'regs 'a 'e
-let read_regS reg = readS (fun s -> reg.read_from s.regstate)
+(*val get_mem_bytes : forall 'regs. nat -> nat -> sequential_state 'regs -> maybe (list memory_byte * bitU)*)
+Definition get_mem_bytes {Regs} addr sz (s : sequential_state Regs) : option (list memory_byte * bitU) :=
+ let addrs := genlist (fun n => addr + n)%nat sz in
+ let read_byte s addr := NatMap.find addr s.(memstate) in
+ let read_tag s addr := opt_def B0 (NatMap.find addr s.(tagstate)) in
+ option_map
+ (fun mem_val => (mem_val, List.fold_left and_bit (List.map (read_tag s) addrs) B1))
+ (just_list (List.map (read_byte s) addrs)).
+
+(*val read_memt_bytesS : forall 'regs 'e. read_kind -> nat -> nat -> monadS 'regs (list memory_byte * bitU) 'e*)
+Definition read_memt_bytesS {Regs E} (_ : read_kind) addr sz : monadS Regs (list memory_byte * bitU) E :=
+ readS (get_mem_bytes addr sz) >>$=
+ maybe_failS "read_memS".
+
+(*val read_mem_bytesS : forall 'regs 'e. read_kind -> nat -> nat -> monadS 'regs (list memory_byte) 'e*)
+Definition read_mem_bytesS {Regs E} (rk : read_kind) addr sz : monadS Regs (list memory_byte) E :=
+ read_memt_bytesS rk addr sz >>$= (fun '(bytes, _) =>
+ returnS bytes).
+
+(*val read_memtS : forall 'regs 'e 'a 'b. Bitvector 'a, Bitvector 'b => read_kind -> 'a -> integer -> monadS 'regs ('b * bitU) 'e*)
+Definition read_memtS {Regs E A B} (rk : read_kind) (a : mword A) sz `{ArithFact (B >= 0)} : monadS Regs (mword B * bitU) E :=
+ let a := Word.wordToNat (get_word a) in
+ read_memt_bytesS rk a (Z.to_nat sz) >>$= (fun '(bytes, tag) =>
+ maybe_failS "bits_of_mem_bytes" (of_bits (bits_of_mem_bytes bytes)) >>$= (fun mem_val =>
+ returnS (mem_val, tag))).
+
+(*val read_memS : forall 'regs 'e 'a 'b. Bitvector 'a, Bitvector 'b => read_kind -> 'a -> integer -> monadS 'regs 'b 'e*)
+Definition read_memS {Regs E A B} rk (a : mword A) sz `{ArithFact (B >= 0)} : monadS Regs (mword B) E :=
+ read_memtS rk a sz >>$= (fun '(bytes, _) =>
+ returnS bytes).
+
+(*val excl_resultS : forall 'regs 'e. unit -> monadS 'regs bool 'e*)
+Definition excl_resultS {Regs E} : unit -> monadS Regs bool E :=
+ (* TODO: This used to be more deterministic, checking a flag in the state
+ whether an exclusive load has occurred before. However, this does not
+ seem very precise; it might be safer to overapproximate the possible
+ behaviours by always making a nondeterministic choice. *)
+ @undefined_boolS Regs E.
+
+(* Write little-endian list of bytes to given address *)
+(*val put_mem_bytes : forall 'regs. nat -> nat -> list memory_byte -> bitU -> sequential_state 'regs -> sequential_state 'regs*)
+Definition put_mem_bytes {Regs} addr sz (v : list memory_byte) (tag : bitU) (s : sequential_state Regs) : sequential_state Regs :=
+ let addrs := genlist (fun n => addr + n)%nat sz in
+ let a_v := List.combine addrs v in
+ let write_byte mem '(addr, v) := NatMap.add addr v mem in
+ let write_tag mem addr := NatMap.add addr tag mem in
+ {| regstate := s.(regstate);
+ memstate := List.fold_left write_byte a_v s.(memstate);
+ tagstate := List.fold_left write_tag addrs s.(tagstate) |}.
+
+(*val write_memt_bytesS : forall 'regs 'e. write_kind -> nat -> nat -> list memory_byte -> bitU -> monadS 'regs bool 'e*)
+Definition write_memt_bytesS {Regs E} (_ : write_kind) addr sz (v : list memory_byte) (t : bitU) : monadS Regs bool E :=
+ updateS (put_mem_bytes addr sz v t) >>$
+ returnS true.
+
+(*val write_mem_bytesS : forall 'regs 'e. write_kind -> nat -> nat -> list memory_byte -> monadS 'regs bool 'e*)
+Definition write_mem_bytesS {Regs E} wk addr sz (v : list memory_byte) : monadS Regs bool E :=
+ write_memt_bytesS wk addr sz v B0.
+
+(*val write_memtS : forall 'regs 'e 'a 'b. Bitvector 'a, Bitvector 'b =>
+ write_kind -> 'a -> integer -> 'b -> bitU -> monadS 'regs bool 'e*)
+Definition write_memtS {Regs E A B} wk (addr : mword A) sz (v : mword B) (t : bitU) : monadS Regs bool E :=
+ match (Word.wordToNat (get_word addr), mem_bytes_of_bits v) with
+ | (addr, Some v) => write_memt_bytesS wk addr (Z.to_nat sz) v t
+ | _ => failS "write_mem"
+ end.
+
+(*val write_memS : forall 'regs 'e 'a 'b. Bitvector 'a, Bitvector 'b =>
+ write_kind -> 'a -> integer -> 'b -> monadS 'regs bool 'e*)
+Definition write_memS {Regs E A B} wk (addr : mword A) sz (v : mword B) : monadS Regs bool E :=
+ write_memtS wk addr sz v B0.
+
+(*val read_regS : forall 'regs 'rv 'a 'e. register_ref 'regs 'rv 'a -> monadS 'regs 'a 'e*)
+Definition read_regS {Regs RV A E} (reg : register_ref Regs RV A) : monadS Regs A E :=
+ readS (fun s => reg.(read_from) s.(regstate)).
(* TODO
let read_reg_range reg i j state =
@@ -194,25 +247,27 @@ let read_reg_bitfield reg regfield =
let (i,_) = register_field_indices reg regfield in
read_reg_bit reg i *)
-val read_regvalS : forall 'regs 'rv 'e.
- register_accessors 'regs 'rv -> string -> monadS 'regs 'rv 'e
-let read_regvalS (read, _) reg =
- readS (fun s -> read reg s.regstate) >>$= (function
- | Just v -> returnS v
- | Nothing -> failS ("read_regvalS " ^ reg)
- end)
-
-val write_regvalS : forall 'regs 'rv 'e.
- register_accessors 'regs 'rv -> string -> 'rv -> monadS 'regs unit 'e
-let write_regvalS (_, write) reg v =
- readS (fun s -> write reg v s.regstate) >>$= (function
- | Just rs' -> updateS (fun s -> <| s with regstate = rs' |>)
- | Nothing -> failS ("write_regvalS " ^ reg)
- end)
-
-val write_regS : forall 'regs 'rv 'a 'e. register_ref 'regs 'rv 'a -> 'a -> monadS 'regs unit 'e
-let write_regS reg v =
- updateS (fun s -> <| s with regstate = reg.write_to v s.regstate |>)
+(*val read_regvalS : forall 'regs 'rv 'e.
+ register_accessors 'regs 'rv -> string -> monadS 'regs 'rv 'e*)
+Definition read_regvalS {Regs RV E} (acc : register_accessors Regs RV) reg : monadS Regs RV E :=
+ let '(read, _) := acc in
+ readS (fun s => read reg s.(regstate)) >>$= (fun v => match v with
+ | Some v => returnS v
+ | None => failS ("read_regvalS " ++ reg)
+ end).
+
+(*val write_regvalS : forall 'regs 'rv 'e.
+ register_accessors 'regs 'rv -> string -> 'rv -> monadS 'regs unit 'e*)
+Definition write_regvalS {Regs RV E} (acc : register_accessors Regs RV) reg (v : RV) : monadS Regs unit E :=
+ let '(_, write) := acc in
+ readS (fun s => write reg v s.(regstate)) >>$= (fun x => match x with
+ | Some rs' => updateS (fun s => {| regstate := rs'; memstate := s.(memstate); tagstate := s.(tagstate) |})
+ | None => failS ("write_regvalS " ++ reg)
+ end).
+
+(*val write_regS : forall 'regs 'rv 'a 'e. register_ref 'regs 'rv 'a -> 'a -> monadS 'regs unit 'e*)
+Definition write_regS {Regs RV A E} (reg : register_ref Regs RV A) (v:A) : monadS Regs unit E :=
+ updateS (fun s => {| regstate := reg.(write_to) v s.(regstate); memstate := s.(memstate); tagstate := s.(tagstate) |}).
(* TODO
val update_reg : forall 'regs 'rv 'a 'b 'e. register_ref 'regs 'rv 'a -> ('a -> 'b -> 'a) -> 'b -> monadS 'regs unit 'e
@@ -250,4 +305,17 @@ let update_reg_field_bit regfield i reg_val bit =
let new_field_value = set_bit (regfield.field_is_inc) current_field_value i (to_bitU bit) in
regfield.set_field reg_val new_field_value
let write_reg_field_bit reg regfield i = update_reg reg (update_reg_field_bit regfield i)*)
-*)
+
+(* TODO Add Show typeclass for value and exception type *)
+(*val show_result : forall 'a 'e. result 'a 'e -> string*)
+Definition show_result {A E} (x : result A E) : string := match x with
+ | Value _ => "Value ()"
+ | Ex (Failure msg) => "Failure " ++ msg
+ | Ex (Throw _) => "Throw"
+end.
+
+(*val prerr_results : forall 'a 'e 's. SetType 's => set (result 'a 'e * 's) -> unit*)
+Definition prerr_results {A E S} (rs : list (result A E * S)) : unit := tt.
+(* let _ = Set.map (fun (r, _) -> let _ = prerr_endline (show_result r) in ()) rs in
+ ()*)
+
diff --git a/lib/coq/Sail2_values.v b/lib/coq/Sail2_values.v
index f11e057a..d1f1a768 100644
--- a/lib/coq/Sail2_values.v
+++ b/lib/coq/Sail2_values.v
@@ -110,6 +110,9 @@ refine ((if Decidable_witness as b return (b = true <-> x = y -> _) then fun H'
* right. intuition.
Defined.
+Instance Decidable_eq_list {A : Type} `(D : forall x y : A, Decidable (x = y)) : forall (x y : list A), Decidable (x = y) :=
+ Decidable_eq_from_dec (list_eq_dec (fun x y => generic_dec x y)).
+
(* Used by generated code that builds Decidable equality instances for records. *)
Ltac cmp_record_field x y :=
let H := fresh "H" in
@@ -144,6 +147,47 @@ unfold pow.
auto using Z.le_refl.
Qed.
+Lemma ZEuclid_div_pos : forall x y, y > 0 -> x >= 0 -> ZEuclid.div x y >= 0.
+intros.
+unfold ZEuclid.div.
+change 0 with (0 * 0).
+apply Zmult_ge_compat; auto with zarith.
+* apply Z.le_ge. apply Z.sgn_nonneg. apply Z.ge_le. auto with zarith.
+* apply Z_div_ge0; auto. apply Z.lt_gt. apply Z.abs_pos. auto with zarith.
+Qed.
+
+Lemma ZEuclid_div_ge : forall x y, y > 0 -> x >= 0 -> x - ZEuclid.div x y >= 0.
+intros.
+unfold ZEuclid.div.
+rewrite Z.sgn_pos; auto with zarith.
+rewrite Z.mul_1_l.
+apply Z.le_ge.
+apply Zle_minus_le_0.
+apply Z.div_le_upper_bound.
+* apply Z.abs_pos. auto with zarith.
+* rewrite Z.mul_comm.
+ assert (0 < Z.abs y). {
+ apply Z.abs_pos.
+ omega.
+ }
+ revert H1.
+ generalize (Z.abs y). intros. nia.
+Qed.
+
+Lemma ZEuclid_div_mod0 : forall x y, y <> 0 ->
+ ZEuclid.modulo x y = 0 ->
+ y * ZEuclid.div x y = x.
+intros x y H1 H2.
+rewrite Zplus_0_r_reverse at 1.
+rewrite <- H2.
+symmetry.
+apply ZEuclid.div_mod.
+assumption.
+Qed.
+
+Hint Resolve ZEuclid_div_pos ZEuclid_div_ge ZEuclid_div_mod0 : sail.
+
+
(*
Definition inline lt := (<)
Definition inline gt := (>)
@@ -416,19 +460,23 @@ Definition binop_bit op x y :=
match (x, y) with
| (BU,_) => BU (*Do we want to do this or to respect | of I and & of B0 rules?*)
| (_,BU) => BU (*Do we want to do this or to respect | of I and & of B0 rules?*)
- | (x,y) => bitU_of_bool (op (bool_of_bitU x) (bool_of_bitU y))
+(* | (x,y) => bitU_of_bool (op (bool_of_bitU x) (bool_of_bitU y))*)
+ | (B0,B0) => bitU_of_bool (op false false)
+ | (B0,B1) => bitU_of_bool (op false true)
+ | (B1,B0) => bitU_of_bool (op true false)
+ | (B1,B1) => bitU_of_bool (op true true)
end.
-(*val and_bit : bitU -> bitU -> bitU
-Definition and_bit := binop_bit (&&)
+(*val and_bit : bitU -> bitU -> bitU*)
+Definition and_bit := binop_bit andb.
-val or_bit : bitU -> bitU -> bitU
-Definition or_bit := binop_bit (||)
+(*val or_bit : bitU -> bitU -> bitU*)
+Definition or_bit := binop_bit orb.
-val xor_bit : bitU -> bitU -> bitU
-Definition xor_bit := binop_bit xor
+(*val xor_bit : bitU -> bitU -> bitU*)
+Definition xor_bit := binop_bit xorb.
-val (&.) : bitU -> bitU -> bitU
+(*val (&.) : bitU -> bitU -> bitU
Definition inline (&.) x y := and_bit x y
val (|.) : bitU -> bitU -> bitU
@@ -1061,8 +1109,8 @@ Ltac unbool_comparisons_goal :=
| |- context [generic_eq _ _ = false] => apply generic_eq_false
| |- context [generic_neq _ _ = true] => apply generic_neq_true
| |- context [generic_neq _ _ = false] => apply generic_neq_false
- | |- context [_ <> true] => rewrite Bool.not_true_iff_false
- | |- context [_ <> false] => rewrite Bool.not_false_iff_true
+ | |- context [_ <> true] => setoid_rewrite Bool.not_true_iff_false
+ | |- context [_ <> false] => setoid_rewrite Bool.not_false_iff_true
end.
(* Split up dependent pairs to get at proofs of properties *)
@@ -1135,7 +1183,7 @@ Qed.
the variable is unused. This is used so that we can use eauto with a low
search bound that doesn't include the exists. (Not terribly happy with
how this works...) *)
-Ltac drop_exists :=
+Ltac drop_Z_exists :=
repeat
match goal with |- @ex Z ?p =>
let a := eval hnf in (p 0) in
@@ -1152,10 +1200,14 @@ repeat
clear xx
end.
*)
+(* For boolean solving we just use plain metavariables *)
+Ltac drop_bool_exists :=
+repeat match goal with |- @ex bool _ => eexists end.
(* The linear solver doesn't like existentials. *)
Ltac destruct_exists :=
- repeat match goal with H:@ex Z _ |- _ => destruct H end.
+ repeat match goal with H:@ex Z _ |- _ => destruct H end;
+ repeat match goal with H:@ex bool _ |- _ => destruct H end.
Ltac prepare_for_solver :=
(*dump_context;*)
@@ -1169,6 +1221,7 @@ Ltac prepare_for_solver :=
destruct_exists;
unbool_comparisons;
unbool_comparisons_goal;
+ repeat match goal with H:and _ _ |- _ => destruct H end;
unfold_In; (* after unbool_comparisons to deal with && and || *)
reduce_list_lengths;
reduce_pow;
@@ -1202,6 +1255,27 @@ match goal with
| _ => tauto
end.
+Lemma or_iff_cong : forall A B C D, A <-> B -> C <-> D -> A \/ C <-> B \/ D.
+intros.
+tauto.
+Qed.
+
+Lemma and_iff_cong : forall A B C D, A <-> B -> C <-> D -> A /\ C <-> B /\ D.
+intros.
+tauto.
+Qed.
+
+Ltac solve_euclid :=
+repeat match goal with |- context [ZEuclid.modulo ?x ?y] =>
+ specialize (ZEuclid.div_mod x y);
+ specialize (ZEuclid.mod_always_pos x y);
+ generalize (ZEuclid.modulo x y);
+ generalize (ZEuclid.div x y);
+ intros
+end;
+nia.
+
+
Ltac solve_arithfact :=
(* Attempt a simple proof first to avoid lengthy preparation steps (especially
as the large proof terms can upset subsequent proofs). *)
@@ -1209,30 +1283,42 @@ intros; (* To solve implications for derive_m *)
try (exact trivial_range);
try fill_in_evar_eq;
try match goal with |- context [projT1 ?X] => apply (ArithFact_self_proof X) end;
+(* Trying reflexivity will fill in more complex metavariable examples than
+ fill_in_evar_eq above, e.g., 8 * n = 8 * ?Goal3 *)
+try (constructor; reflexivity);
try (constructor; omega);
prepare_for_solver;
(*dump_context;*)
+constructor;
+repeat match goal with |- and _ _ => split end;
solve
- [ match goal with |- ArithFact (?x _) => is_evar x; idtac "Warning: unknown constraint"; constructor; exact (I : (fun _ => True) _) end
+ [ match goal with |- (?x _) => is_evar x; idtac "Warning: unknown constraint"; exact (I : (fun _ => True) _) end
| apply ArithFact_mword; assumption
- | constructor; omega with Z
+ | omega with Z
(* Try sail hints before dropping the existential *)
- | constructor; eauto 3 with zarith sail
+ | subst; eauto 3 with zarith sail
(* The datatypes hints give us some list handling, esp In *)
- | constructor; drop_exists; eauto 3 with datatypes zarith sail
- | match goal with |- context [Z.mul] => constructor; nia end
+ | subst; drop_Z_exists; eauto 3 with datatypes zarith sail
+ | subst; match goal with |- context [ZEuclid.div] => solve_euclid
+ | |- context [ZEuclid.modulo] => solve_euclid
+ end
+ | match goal with |- context [Z.mul] => nia end
(* Booleans - and_boolMP *)
- | match goal with |- ArithFact (forall l r:bool, _ -> _ -> exists _ : bool, _) =>
- constructor; intros [|] [|] H1 H2;
+ | drop_bool_exists; solve [eauto using iff_refl, or_iff_cong, and_iff_cong | intuition]
+ | match goal with |- (forall l r:bool, _ -> _ -> exists _ : bool, _) =>
+ intros [|] [|] H1 H2;
repeat match goal with H:?X = ?X -> _ |- _ => specialize (H eq_refl) end;
repeat match goal with H:@ex _ _ |- _ => destruct H end;
bruteforce_bool_exists
end
+(* While firstorder was quite effective at dealing with existentially quantified
+ goals from boolean expressions, it attempts lazy normalization of terms,
+ which blows up on integer comparisons with large constants.
| match goal with |- context [@eq bool _ _] =>
(* Don't use auto for the fallback to keep runtime down *)
firstorder fail
- end
- | constructor; idtac "Unable to solve constraint"; dump_context; fail
+ end*)
+ | idtac "Unable to solve constraint"; dump_context; fail
].
(* Add an indirection so that you can redefine run_solver to fail to get
slow running constraints into proof mode. *)
diff --git a/lib/mono_rewrites.sail b/lib/mono_rewrites.sail
index 9e4010a0..5e20fc71 100644
--- a/lib/mono_rewrites.sail
+++ b/lib/mono_rewrites.sail
@@ -1,23 +1,12 @@
-/* Definitions for use with the -mono_rewrites option */
-
-/* External definitions not in the usual asl prelude */
-
-infix 6 <<
-
-val shiftleft = "shiftl" : forall 'n ('ord : Order).
- (vector('n, 'ord, bit), int) -> vector('n, 'ord, bit) effect pure
-
-overload operator << = {shiftleft}
-
-infix 6 >>
+$ifndef _MONO_REWRITES
+$define _MONO_REWRITES
-val shiftright = "shiftr" : forall 'n ('ord : Order).
- (vector('n, 'ord, bit), int) -> vector('n, 'ord, bit) effect pure
+/* Definitions for use with the -mono_rewrites option */
-overload operator >> = {shiftright}
+$include <arith.sail>
+$include <vector_dec.sail>
-val arith_shiftright = "arith_shiftr" : forall 'n ('ord : Order).
- (vector('n, 'ord, bit), int) -> vector('n, 'ord, bit) effect pure
+/* External definitions not in the usual asl prelude */
val extzv = "extz_vec" : forall 'n 'm. (implicit('m), vector('n, dec, bit)) -> vector('m, dec, bit) effect pure
@@ -30,23 +19,18 @@ val bitvector_cast_out = "zeroExtend" : forall 'n. bits('n) -> bits('n) effect p
/* Definitions for the rewrites */
-val slice_mask : forall 'n, 'n >= 0. (implicit('n), int, int) -> bits('n) effect pure
-function slice_mask(n,i,l) =
- let one : bits('n) = extzv(n, 0b1) in
- ((one << l) - one) << i
-
val is_zero_subrange : forall 'n, 'n >= 0.
(bits('n), int, int) -> bool effect pure
function is_zero_subrange (xs, i, j) = {
- (xs & slice_mask(j, i-j+1)) == extzv(0b0)
+ (xs & slice_mask(j, i-j+1)) == extzv([bitzero] : bits(1))
}
val is_zeros_slice : forall 'n, 'n >= 0.
(bits('n), int, int) -> bool effect pure
function is_zeros_slice (xs, i, l) = {
- (xs & slice_mask(i, l)) == extzv(0b0)
+ (xs & slice_mask(i, l)) == extzv([bitzero] : bits(1))
}
val is_ones_subrange : forall 'n, 'n >= 0.
@@ -69,17 +53,17 @@ val slice_slice_concat : forall 'n 'm 'r, 'n >= 0 & 'm >= 0 & 'r >= 0.
(implicit('r), bits('n), int, int, bits('m), int, int) -> bits('r) effect pure
function slice_slice_concat (r, xs, i, l, ys, i', l') = {
- let xs = (xs & slice_mask(i,l)) >> i in
- let ys = (ys & slice_mask(i',l')) >> i' in
- extzv(r, xs) << l' | extzv(r, ys)
+ let xs = sail_shiftright(xs & slice_mask(i,l), i) in
+ let ys = sail_shiftright(ys & slice_mask(i',l'), i') in
+ sail_shiftleft(extzv(r, xs), l') | extzv(r, ys)
}
val slice_zeros_concat : forall 'n 'p 'q, 'n >= 0 & 'p + 'q >= 0.
(bits('n), int, atom('p), atom('q)) -> bits('p + 'q) effect pure
function slice_zeros_concat (xs, i, l, l') = {
- let xs = (xs & slice_mask(i,l)) >> i in
- extzv(l + l', xs) << l'
+ let xs = sail_shiftright(xs & slice_mask(i,l), i) in
+ sail_shiftleft(extzv(l + l', xs), l')
}
/* Assumes initial vectors are of equal size */
@@ -88,8 +72,8 @@ val subrange_subrange_eq : forall 'n, 'n >= 0.
(bits('n), int, int, bits('n), int, int) -> bool effect pure
function subrange_subrange_eq (xs, i, j, ys, i', j') = {
- let xs = (xs & slice_mask(j,i-j+1)) >> j in
- let ys = (ys & slice_mask(j',i'-j'+1)) >> j' in
+ let xs = sail_shiftright(xs & slice_mask(j,i-j+1), j) in
+ let ys = sail_shiftright(ys & slice_mask(j',i'-j'+1), j') in
xs == ys
}
@@ -97,25 +81,25 @@ val subrange_subrange_concat : forall 'n 'o 'p 'm 'q 'r 's, 's >= 0 & 'n >= 0 &
(implicit('s), bits('n), atom('o), atom('p), bits('m), atom('q), atom('r)) -> bits('s) effect pure
function subrange_subrange_concat (s, xs, i, j, ys, i', j') = {
- let xs = (xs & slice_mask(j,i-j+1)) >> j in
- let ys = (ys & slice_mask(j',i'-j'+1)) >> j' in
- extzv(s, xs) << (i' - j' + 1) | extzv(s, ys)
+ let xs = sail_shiftright(xs & slice_mask(j,i-j+1), j) in
+ let ys = sail_shiftright(ys & slice_mask(j',i'-j'+1), j) in
+ sail_shiftleft(extzv(s, xs), i' - j' + 1) | extzv(s, ys)
}
val place_subrange : forall 'n 'm, 'n >= 0 & 'm >= 0.
(implicit('m), bits('n), int, int, int) -> bits('m) effect pure
function place_subrange(m,xs,i,j,shift) = {
- let xs = (xs & slice_mask(j,i-j+1)) >> j in
- extzv(m, xs) << shift
+ let xs = sail_shiftright(xs & slice_mask(j,i-j+1), j) in
+ sail_shiftleft(extzv(m, xs), shift)
}
val place_slice : forall 'n 'm, 'n >= 0 & 'm >= 0.
(implicit('m), bits('n), int, int, int) -> bits('m) effect pure
function place_slice(m,xs,i,l,shift) = {
- let xs = (xs & slice_mask(i,l)) >> i in
- extzv(m, xs) << shift
+ let xs = sail_shiftright(xs & slice_mask(i,l), i) in
+ sail_shiftleft(extzv(m, xs), shift)
}
val set_slice_zeros : forall 'n, 'n >= 0.
@@ -123,14 +107,14 @@ val set_slice_zeros : forall 'n, 'n >= 0.
function set_slice_zeros(n, xs, i, l) = {
let ys : bits('n) = slice_mask(n, i, l) in
- xs & ~(ys)
+ xs & not_vec(ys)
}
val zext_slice : forall 'n 'm, 'n >= 0 & 'm >= 0.
(implicit('m), bits('n), int, int) -> bits('m) effect pure
function zext_slice(m,xs,i,l) = {
- let xs = (xs & slice_mask(i,l)) >> i in
+ let xs = sail_shiftright(xs & slice_mask(i,l), i) in
extzv(m, xs)
}
@@ -138,7 +122,7 @@ val sext_slice : forall 'n 'm, 'n >= 0 & 'm >= 0.
(implicit('m), bits('n), int, int) -> bits('m) effect pure
function sext_slice(m,xs,i,l) = {
- let xs = arith_shiftright(((xs & slice_mask(i,l)) << ('n - i - l)), 'n - l) in
+ let xs = sail_arith_shiftright(sail_shiftleft((xs & slice_mask(i,l)), ('n - i - l)), 'n - l) in
extsv(m, xs)
}
@@ -146,7 +130,7 @@ val place_slice_signed : forall 'n 'm, 'n >= 0 & 'm >= 0.
(implicit('m), bits('n), int, int, int) -> bits('m) effect pure
function place_slice_signed(m,xs,i,l,shift) = {
- sext_slice(m, xs, i, l) << shift
+ sail_shiftleft(sext_slice(m, xs, i, l), shift)
}
/* This has different names in the aarch64 prelude (UInt) and the other
@@ -157,28 +141,46 @@ val _builtin_unsigned = {
lem: "uint",
interpreter: "uint",
c: "sail_uint"
-} : forall 'n. bits('n) -> range(0, 2 ^ 'n - 1)
+} : forall 'n. bits('n) -> {'m, 0 <= 'm < 2 ^ 'n. int('m)}
+
+/* There are different implementation choices for division and remainder, but
+ they agree on positive values. We use this here to give more precise return
+ types for unsigned_slice and unsigned_subrange. */
-val unsigned_slice : forall 'n, 'n >= 0.
- (bits('n), int, int) -> int effect pure
+val _builtin_mod_nat = {
+ smt: "mod",
+ ocaml: "modulus",
+ lem: "integerMod",
+ c: "tmod_int",
+ coq: "Z.rem"
+} : forall 'n 'm, 'n >= 0 & 'm >= 0. (int('n), int('m)) -> {'r, 0 <= 'r < 'm. int('r)}
+
+/* Below we need the fact that 2 ^ 'n >= 0, so we axiomatise it in the return
+ type of pow2, as SMT solvers tend to have problems with exponentiation. */
+val _builtin_pow2 = "pow2" : forall 'n, 'n >= 0. int('n) -> {'m, 'm == 2 ^ 'n & 'm >= 0. int('m)}
+
+val unsigned_slice : forall 'n 'l, 'n >= 0 & 'l >= 0.
+ (bits('n), int, int('l)) -> {'m, 0 <= 'm < 2 ^ 'l. int('m)} effect pure
function unsigned_slice(xs,i,l) = {
- let xs = (xs & slice_mask(i,l)) >> i in
- _builtin_unsigned(xs)
+ let xs = sail_shiftright(xs & slice_mask(i,l), i) in
+ _builtin_mod_nat(_builtin_unsigned(xs), _builtin_pow2(l))
}
-val unsigned_subrange : forall 'n, 'n >= 0.
- (bits('n), int, int) -> int effect pure
+val unsigned_subrange : forall 'n 'i 'j, 'n >= 0 & ('i - 'j) >= 0.
+ (bits('n), int('i), int('j)) -> {'m, 0 <= 'm < 2 ^ ('i - 'j + 1). int('m)} effect pure
function unsigned_subrange(xs,i,j) = {
- let xs = (xs & slice_mask(j,i-j+1)) >> i in
- _builtin_unsigned(xs)
+ let xs = sail_shiftright(xs & slice_mask(j,i-j+1), i) in
+ _builtin_mod_nat(_builtin_unsigned(xs), _builtin_pow2(i - j + 1))
}
val zext_ones : forall 'n, 'n >= 0. (implicit('n), int) -> bits('n) effect pure
function zext_ones(n, m) = {
- let v : bits('n) = extsv(0b1) in
- v >> (n - m)
+ let v : bits('n) = extsv([bitone] : bits(1)) in
+ sail_shiftright(v, n - m)
}
+
+$endif
diff --git a/lib/sail.c b/lib/sail.c
index 6c71d7ae..2d47939e 100644
--- a/lib/sail.c
+++ b/lib/sail.c
@@ -350,6 +350,27 @@ void mult_int(sail_int *rop, const sail_int op1, const sail_int op2)
mpz_mul(*rop, op1, op2);
}
+
+inline
+void ediv_int(sail_int *rop, const sail_int op1, const sail_int op2)
+{
+ /* GMP doesn't have Euclidean division but we can emulate it using
+ flooring and ceiling division. */
+ if (mpz_sgn(op2) >= 0) {
+ mpz_fdiv_q(*rop, op1, op2);
+ } else {
+ mpz_cdiv_q(*rop, op1, op2);
+ }
+}
+
+inline
+void emod_int(sail_int *rop, const sail_int op1, const sail_int op2)
+{
+ /* The documentation isn't that explicit but I think this is
+ Euclidean mod. */
+ mpz_mod(*rop, op1, op2);
+}
+
inline
void tdiv_int(sail_int *rop, const sail_int op1, const sail_int op2)
{
diff --git a/lib/sail.h b/lib/sail.h
index e06629f0..1c368d2d 100644
--- a/lib/sail.h
+++ b/lib/sail.h
@@ -138,6 +138,8 @@ SAIL_INT_FUNCTION(add_int, sail_int, const sail_int, const sail_int);
SAIL_INT_FUNCTION(sub_int, sail_int, const sail_int, const sail_int);
SAIL_INT_FUNCTION(sub_nat, sail_int, const sail_int, const sail_int);
SAIL_INT_FUNCTION(mult_int, sail_int, const sail_int, const sail_int);
+SAIL_INT_FUNCTION(ediv_int, sail_int, const sail_int, const sail_int);
+SAIL_INT_FUNCTION(emod_int, sail_int, const sail_int, const sail_int);
SAIL_INT_FUNCTION(tdiv_int, sail_int, const sail_int, const sail_int);
SAIL_INT_FUNCTION(tmod_int, sail_int, const sail_int, const sail_int);
SAIL_INT_FUNCTION(fdiv_int, sail_int, const sail_int, const sail_int);
diff --git a/lib/smt.sail b/lib/smt.sail
index d886c127..4d250bef 100644
--- a/lib/smt.sail
+++ b/lib/smt.sail
@@ -3,24 +3,21 @@ $define _SMT
// see http://smtlib.cs.uiowa.edu/theories-Ints.shtml
-val div = {
+/*! Euclidean division */
+val ediv_int = {
ocaml: "quotient",
lem: "integerDiv",
- c: "tdiv_int",
+ c: "ediv_int",
coq: "ediv_with_eq"
} : forall 'n 'm. (int('n), int('m)) -> int(div('n, 'm))
-overload operator / = {div}
-
-val mod = {
+val emod_int = {
ocaml: "modulus",
lem: "integerMod",
- c: "tmod_int",
+ c: "emod_int",
coq: "emod_with_eq"
} : forall 'n 'm. (int('n), int('m)) -> int(mod('n, 'm))
-overload operator % = {mod}
-
val abs_int = {
ocaml: "abs_int",
lem: "abs_int",
diff --git a/lib/vector_dec.sail b/lib/vector_dec.sail
index f31e4ed2..b4014aa6 100644
--- a/lib/vector_dec.sail
+++ b/lib/vector_dec.sail
@@ -15,6 +15,16 @@ val eq_bits = {
overload operator == = {eq_bit, eq_bits}
+val neq_bits = {
+ lem: "neq_vec",
+ c: "neq_bits",
+ coq: "neq_vec"
+} : forall 'n. (vector('n, dec, bit), vector('n, dec, bit)) -> bool
+
+function neq_bits(x, y) = not_bool(eq_bits(x, y))
+
+overload operator != = {neq_bits}
+
val bitvector_length = {coq: "length_mword", _:"length"} : forall 'n. bits('n) -> atom('n)
val vector_length = {
@@ -27,8 +37,6 @@ val vector_length = {
overload length = {bitvector_length, vector_length}
-val sail_zeros = "zeros" : forall 'n. atom('n) -> bits('n)
-
val "print_bits" : forall 'n. (string, bits('n)) -> unit
val "prerr_bits" : forall 'n. (string, bits('n)) -> unit
@@ -126,6 +134,23 @@ val add_bits_int = {
overload operator + = {add_bits, add_bits_int}
+val sub_bits = {
+ ocaml: "sub_vec",
+ lem: "sub_vec",
+ c: "sub_bits",
+ coq: "sub_vec"
+} : forall 'n. (bits('n), bits('n)) -> bits('n)
+
+val not_vec = {c: "not_bits", _: "not_vec"} : forall 'n. bits('n) -> bits('n)
+
+val and_vec = {lem: "and_vec", c: "and_bits", coq: "and_vec", ocaml: "and_vec"} : forall 'n. (bits('n), bits('n)) -> bits('n)
+
+overload operator & = {and_vec}
+
+val or_vec = {lem: "or_vec", c: "or_bits", coq: "or_vec", ocaml: "or_vec"} : forall 'n. (bits('n), bits('n)) -> bits('n)
+
+overload operator | = {or_vec}
+
val vector_subrange = {
ocaml: "subrange",
interpreter: "subrange",
@@ -143,8 +168,37 @@ val vector_update_subrange = {
coq: "update_subrange_vec_dec"
} : forall 'n 'm 'o, 0 <= 'o <= 'm < 'n. (bits('n), atom('m), atom('o), bits('m - ('o - 1))) -> bits('n)
+val sail_shiftleft = "shiftl" : forall 'n ('ord : Order).
+ (vector('n, 'ord, bit), int) -> vector('n, 'ord, bit) effect pure
+
+val sail_shiftright = "shiftr" : forall 'n ('ord : Order).
+ (vector('n, 'ord, bit), int) -> vector('n, 'ord, bit) effect pure
+
+val sail_arith_shiftright = "arith_shiftr" : forall 'n ('ord : Order).
+ (vector('n, 'ord, bit), int) -> vector('n, 'ord, bit) effect pure
+
+val sail_zeros = "zeros" : forall 'n, 'n >= 0. atom('n) -> bits('n)
+
+val sail_ones : forall 'n, 'n >= 0. atom('n) -> bits('n)
+
+function sail_ones(n) = not_vec(sail_zeros(n))
+
// Some ARM specific builtins
+val slice = "slice" : forall 'n 'm 'o, 0 <= 'm & 0 <= 'n.
+ (bits('m), atom('o), atom('n)) -> bits('n)
+
+val replicate_bits = "replicate_bits" : forall 'n 'm. (bits('n), atom('m)) -> bits('n * 'm)
+
+val slice_mask : forall 'n, 'n >= 0. (implicit('n), int, int) -> bits('n) effect pure
+function slice_mask(n,i,l) =
+ if l >= n then {
+ sail_ones(n)
+ } else {
+ let one : bits('n) = sail_mask(n, [bitone] : bits(1)) in
+ sail_shiftleft(sub_bits(sail_shiftleft(one, l), one), i)
+ }
+
val get_slice_int = "get_slice_int" : forall 'w. (atom('w), int, int) -> bits('w)
val set_slice_int = "set_slice_int" : forall 'w. (atom('w), int, int, bits('w)) -> int
@@ -152,11 +206,6 @@ val set_slice_int = "set_slice_int" : forall 'w. (atom('w), int, int, bits('w))
val set_slice_bits = "set_slice" : forall 'n 'm.
(atom('n), atom('m), bits('n), int, bits('m)) -> bits('n)
-val slice = "slice" : forall 'n 'm 'o, 0 <= 'o < 'm & 'o + 'n <= 'm & 0 <= 'n.
- (bits('m), atom('o), atom('n)) -> bits('n)
-
-val replicate_bits = "replicate_bits" : forall 'n 'm. (bits('n), atom('m)) -> bits('n * 'm)
-
/*!
converts a bit vector of length $n$ to an integer in the range $0$ to $2^n - 1$.
*/
diff --git a/opam b/opam
index e457e1d2..c205c1f6 100644
--- a/opam
+++ b/opam
@@ -1,6 +1,6 @@
opam-version: "1.2"
name: "sail"
-version: "0.8"
+version: "0.9"
maintainer: "Sail Devs <cl-sail-dev@lists.cam.ac.uk>"
authors: [
"Alasdair Armstrong"
@@ -28,7 +28,7 @@ depends: [
"ocamlbuild"
"zarith"
"menhir"
- "linenoise"
+ "linenoise" {>= "1.1.0"}
"ott" {>= "0.28"}
"lem" {>= "2018-12-14"}
"linksem" {>= "0.3"}
@@ -38,4 +38,4 @@ depends: [
"base64"
"yojson"
]
-available: [ocaml-version >= "4.06.0"]
+available: [ocaml-version >= "4.06.1"]
diff --git a/src/ast_util.ml b/src/ast_util.ml
index 34345210..386c080a 100644
--- a/src/ast_util.ml
+++ b/src/ast_util.ml
@@ -179,9 +179,9 @@ module Id = struct
let compare id1 id2 =
match (id1, id2) with
| Id_aux (Id x, _), Id_aux (Id y, _) -> String.compare x y
- | Id_aux (DeIid x, _), Id_aux (DeIid y, _) -> String.compare x y
- | Id_aux (Id _, _), Id_aux (DeIid _, _) -> -1
- | Id_aux (DeIid _, _), Id_aux (Id _, _) -> 1
+ | Id_aux (Operator x, _), Id_aux (Operator y, _) -> String.compare x y
+ | Id_aux (Id _, _), Id_aux (Operator _, _) -> -1
+ | Id_aux (Operator _, _), Id_aux (Id _, _) -> 1
end
module Nexp = struct
@@ -360,7 +360,7 @@ let rec constraint_disj (NC_aux (nc_aux, l) as nc) =
let mk_typ typ = Typ_aux (typ, Parse_ast.Unknown)
let mk_typ_arg arg = A_aux (arg, Parse_ast.Unknown)
let mk_kid str = Kid_aux (Var ("'" ^ str), Parse_ast.Unknown)
-let mk_infix_id str = Id_aux (DeIid str, Parse_ast.Unknown)
+let mk_infix_id str = Id_aux (Operator str, Parse_ast.Unknown)
let mk_id_typ id = Typ_aux (Typ_id id, Parse_ast.Unknown)
@@ -725,23 +725,23 @@ let def_loc = function
let string_of_id = function
| Id_aux (Id v, _) -> v
- | Id_aux (DeIid v, _) -> "(operator " ^ v ^ ")"
+ | Id_aux (Operator v, _) -> "(operator " ^ v ^ ")"
let id_of_kid = function
| Kid_aux (Var v, l) -> Id_aux (Id (String.sub v 1 (String.length v - 1)), l)
let kid_of_id = function
| Id_aux (Id v, l) -> Kid_aux (Var ("'" ^ v), l)
- | Id_aux (DeIid v, _) -> assert false
+ | Id_aux (Operator v, _) -> assert false
let prepend_id str = function
| Id_aux (Id v, l) -> Id_aux (Id (str ^ v), l)
- | Id_aux (DeIid v, l) -> Id_aux (DeIid (str ^ v), l)
+ | Id_aux (Operator v, l) -> Id_aux (Operator (str ^ v), l)
let append_id id str =
match id with
| Id_aux (Id v, l) -> Id_aux (Id (v ^ str), l)
- | Id_aux (DeIid v, l) -> Id_aux (DeIid (v ^ str), l)
+ | Id_aux (Operator v, l) -> Id_aux (Operator (v ^ str), l)
let prepend_kid str = function
| Kid_aux (Var v, l) -> Kid_aux (Var ("'" ^ str ^ String.sub v 1 (String.length v - 1)), l)
@@ -839,7 +839,7 @@ and string_of_n_constraint = function
"(" ^ string_of_n_constraint nc1 ^ " & " ^ string_of_n_constraint nc2 ^ ")"
| NC_aux (NC_set (kid, ns), _) ->
string_of_kid kid ^ " in {" ^ string_of_list ", " Big_int.to_string ns ^ "}"
- | NC_aux (NC_app (Id_aux (DeIid op, _), [arg1; arg2]), _) ->
+ | NC_aux (NC_app (Id_aux (Operator op, _), [arg1; arg2]), _) ->
"(" ^ string_of_typ_arg arg1 ^ " " ^ op ^ " " ^ string_of_typ_arg arg2 ^ ")"
| NC_aux (NC_app (id, args), _) -> string_of_id id ^ "(" ^ string_of_list ", " string_of_typ_arg args ^ ")"
| NC_aux (NC_var v, _) -> string_of_kid v
@@ -1174,6 +1174,8 @@ module NC = struct
let compare = nc_compare
end
+module NCMap = Map.Make(NC)
+
module Typ = struct
type t = typ
let compare = typ_compare
@@ -1854,6 +1856,8 @@ and constraint_subst_aux l sv subst = function
| NC_not_equal (n1, n2) -> NC_not_equal (nexp_subst sv subst n1, nexp_subst sv subst n2)
| NC_set (kid, ints) as set_nc ->
begin match subst with
+ | A_aux (A_nexp (Nexp_aux (Nexp_var kid',_)), _) when Kid.compare kid sv = 0 ->
+ NC_set (kid', ints)
| A_aux (A_nexp n, _) when Kid.compare kid sv = 0 ->
nexp_set_to_or l n ints
| _ -> set_nc
@@ -1987,20 +1991,12 @@ let subst_kids_nc, subst_kids_typ, subst_kids_typ_arg =
| A_bool nc -> A_aux (A_bool (subst_kids_nc substs nc), l)
in subst_kids_nc, s_styp, s_starg
-
-let rec simp_loc = function
- | Parse_ast.Unknown -> None
- | Parse_ast.Unique (_, l) -> simp_loc l
- | Parse_ast.Generated l -> simp_loc l
- | Parse_ast.Range (p1, p2) -> Some (p1, p2)
- | Parse_ast.Documented (_, l) -> simp_loc l
-
let before p1 p2 =
let open Lexing in
p1.pos_fname = p2.pos_fname && p1.pos_cnum <= p2.pos_cnum
let subloc sl l =
- match sl, simp_loc l with
+ match sl, Reporting.simp_loc l with
| _, None -> false
| None, _ -> false
| Some (p1a, p1b), Some (p2a, p2b) ->
diff --git a/src/ast_util.mli b/src/ast_util.mli
index 64b39b51..cfbc26fe 100644
--- a/src/ast_util.mli
+++ b/src/ast_util.mli
@@ -265,6 +265,10 @@ module Bindings : sig
include Map.S with type key = id
end
+module NCMap : sig
+ include Map.S with type key = n_constraint
+end
+
module TypMap : sig
include Map.S with type key = typ
end
@@ -490,9 +494,6 @@ val unique : l -> l
val extern_assoc : string -> (string * string) list -> string option
-(** Reduce a location to a pair of positions if possible *)
-val simp_loc : Ast.l -> (Lexing.position * Lexing.position) option
-
(** Try to find the annotation closest to the provided (simplified)
location. Note that this function makes no guarantees about finding
the closest annotation or even finding an annotation at all. This
diff --git a/src/constant_fold.ml b/src/constant_fold.ml
index 6706cc01..14d6550c 100644
--- a/src/constant_fold.ml
+++ b/src/constant_fold.ml
@@ -137,13 +137,17 @@ let fold_to_unit id =
in
IdSet.mem id remove
-let rec is_constant (E_aux (e_aux, _)) =
+let rec is_constant (E_aux (e_aux, _) as exp) =
match e_aux with
| E_lit _ -> true
| E_vector exps -> List.for_all is_constant exps
| E_record fexps -> List.for_all is_constant_fexp fexps
| E_cast (_, exp) -> is_constant exp
| E_tuple exps -> List.for_all is_constant exps
+ | E_id id ->
+ (match Env.lookup_id id (env_of exp) with
+ | Enum _ -> true
+ | _ -> false)
| _ -> false
and is_constant_fexp (FE_aux (FE_Fexp (_, exp), _)) = is_constant exp
@@ -184,21 +188,18 @@ let rec run frame =
- Throws an exception that isn't caught.
*)
-let rec rewrite_constant_function_calls' env ast =
- let rewrite_count = ref 0 in
- let ok () = incr rewrite_count in
- let not_ok () = decr rewrite_count in
-
+let initial_state ast env =
let lstate, gstate =
Interpreter.initial_state ast env safe_primops
in
- let gstate = { gstate with Interpreter.allow_registers = false } in
+ (lstate, { gstate with Interpreter.allow_registers = false })
+let rw_exp ok not_ok istate =
let evaluate e_aux annot =
let initial_monad = Interpreter.return (E_aux (e_aux, annot)) in
try
begin
- let v = run (Interpreter.Step (lazy "", (lstate, gstate), initial_monad, [])) in
+ let v = run (Interpreter.Step (lazy "", istate, initial_monad, [])) in
let exp = exp_of_value v in
try (ok (); Type_check.check_exp (env_of_annot annot) exp (typ_of_annot annot)) with
| Type_error (env, l, err) ->
@@ -242,19 +243,27 @@ let rec rewrite_constant_function_calls' env ast =
| _ -> E_aux (e_aux, annot)
in
- let rw_exp = {
- id_exp_alg with
- e_aux = (fun (e_aux, annot) -> rw_funcall e_aux annot)
+ fold_exp { id_exp_alg with e_aux = (fun (e_aux, annot) -> rw_funcall e_aux annot)}
+
+let rewrite_exp_once = rw_exp (fun _ -> ()) (fun _ -> ())
+
+let rec rewrite_constant_function_calls' ast =
+ let rewrite_count = ref 0 in
+ let ok () = incr rewrite_count in
+ let not_ok () = decr rewrite_count in
+
+ let rw_defs = {
+ rewriters_base with
+ rewrite_exp = (fun _ -> rw_exp ok not_ok (initial_state ast Type_check.initial_env))
} in
- let rw_defs = { rewriters_base with rewrite_exp = (fun _ -> fold_exp rw_exp) } in
let ast = rewrite_defs_base rw_defs ast in
(* We keep iterating until we have no more re-writes to do *)
if !rewrite_count > 0
- then rewrite_constant_function_calls' env ast
+ then rewrite_constant_function_calls' ast
else ast
-let rewrite_constant_function_calls env ast =
+let rewrite_constant_function_calls ast =
if !optimize_constant_fold then
- rewrite_constant_function_calls' env ast
+ rewrite_constant_function_calls' ast
else
ast
diff --git a/src/constant_propagation.ml b/src/constant_propagation.ml
index 33b67008..ce04798a 100644
--- a/src/constant_propagation.ml
+++ b/src/constant_propagation.ml
@@ -111,9 +111,13 @@ let rec is_value (E_aux (e,(l,annot))) =
| E_id id -> is_constructor id
| E_lit _ -> true
| E_tuple es -> List.for_all is_value es
+ | E_record fes ->
+ List.for_all (fun (FE_aux (FE_Fexp (_, e), _)) -> is_value e) fes
| E_app (id,es) -> is_constructor id && List.for_all is_value es
(* We add casts to undefined to keep the type information in the AST *)
| E_cast (typ,E_aux (E_lit (L_aux (L_undef,_)),_)) -> true
+ (* Also keep casts around records, as type inference fails without *)
+ | E_cast (_, (E_aux (E_record _, _) as e')) -> is_value e'
(* TODO: more? *)
| _ -> false
@@ -263,93 +267,6 @@ let lit_eq (L_aux (l1,_)) (L_aux (l2,_)) =
| L_num i1, L_num i2 -> Some (Big_int.equal i1 i2)
| _ -> Some (l1 = l2)
-let try_app (l,ann) (id,args) =
- let new_l = Parse_ast.Generated l in
- let env = env_of_annot (l,ann) in
- let get_overloads f = List.map string_of_id
- (Env.get_overloads (Id_aux (Id f, Parse_ast.Unknown)) env @
- Env.get_overloads (Id_aux (DeIid f, Parse_ast.Unknown)) env) in
- let is_id f = List.mem (string_of_id id) (f :: get_overloads f) in
- if is_id "==" || is_id "!=" then
- match args with
- | [E_aux (E_lit l1,_); E_aux (E_lit l2,_)] ->
- let lit b = if b then L_true else L_false in
- let lit b = lit (if is_id "==" then b else not b) in
- (match lit_eq l1 l2 with
- | None -> None
- | Some b -> Some (E_aux (E_lit (L_aux (lit b,new_l)),(l,ann))))
- | _ -> None
- else if is_id "cast_bit_bool" then
- match args with
- | [E_aux (E_lit L_aux (L_zero,_),_)] -> Some (E_aux (E_lit (L_aux (L_false,new_l)),(l,ann)))
- | [E_aux (E_lit L_aux (L_one ,_),_)] -> Some (E_aux (E_lit (L_aux (L_true ,new_l)),(l,ann)))
- | _ -> None
- else if is_id "UInt" || is_id "unsigned" then
- match args with
- | [E_aux (E_lit L_aux ((L_hex _| L_bin _) as lit,_), _)] ->
- Some (E_aux (E_lit (L_aux (L_num (int_of_str_lit lit),new_l)),(l,ann)))
- | _ -> None
- else if is_id "slice" then
- match args with
- | [E_aux (E_lit (L_aux ((L_hex _| L_bin _),_) as lit), annot);
- E_aux (E_lit L_aux (L_num i,_), _);
- E_aux (E_lit L_aux (L_num len,_), _)] ->
- (match Env.base_typ_of (env_of_annot annot) (typ_of_annot annot) with
- | Typ_aux (Typ_app (_,[_;A_aux (A_order ord,_);_]),_) ->
- (match slice_lit lit i len ord with
- | Some lit' -> Some (E_aux (E_lit lit',(l,ann)))
- | None -> None)
- | _ -> None)
- | _ -> None
- else if is_id "bitvector_concat" then
- match args with
- | [E_aux (E_lit L_aux ((L_hex _| L_bin _) as lit1,_), _);
- E_aux (E_lit L_aux ((L_hex _| L_bin _) as lit2,_), _)] ->
- Some (E_aux (E_lit (L_aux (concat_vec lit1 lit2,new_l)),(l,ann)))
- | _ -> None
- else if is_id "shl_int" then
- match args with
- | [E_aux (E_lit L_aux (L_num i,_),_); E_aux (E_lit L_aux (L_num j,_),_)] ->
- Some (E_aux (E_lit (L_aux (L_num (Big_int.shift_left i (Big_int.to_int j)),new_l)),(l,ann)))
- | _ -> None
- else if is_id "mult_atom" || is_id "mult_int" || is_id "mult_range" then
- match args with
- | [E_aux (E_lit L_aux (L_num i,_),_); E_aux (E_lit L_aux (L_num j,_),_)] ->
- Some (E_aux (E_lit (L_aux (L_num (Big_int.mul i j),new_l)),(l,ann)))
- | _ -> None
- else if is_id "quotient_nat" then
- match args with
- | [E_aux (E_lit L_aux (L_num i,_),_); E_aux (E_lit L_aux (L_num j,_),_)] ->
- Some (E_aux (E_lit (L_aux (L_num (Big_int.div i j),new_l)),(l,ann)))
- | _ -> None
- else if is_id "add_atom" || is_id "add_int" || is_id "add_range" then
- match args with
- | [E_aux (E_lit L_aux (L_num i,_),_); E_aux (E_lit L_aux (L_num j,_),_)] ->
- Some (E_aux (E_lit (L_aux (L_num (Big_int.add i j),new_l)),(l,ann)))
- | _ -> None
- else if is_id "negate_range" then
- match args with
- | [E_aux (E_lit L_aux (L_num i,_),_)] ->
- Some (E_aux (E_lit (L_aux (L_num (Big_int.negate i),new_l)),(l,ann)))
- | _ -> None
- else if is_id "ex_int" then
- match args with
- | [E_aux (E_lit lit,(l,_))] -> Some (E_aux (E_lit lit,(l,ann)))
- | [E_aux (E_cast (_,(E_aux (E_lit (L_aux (L_undef,_)),_) as e)),(l,_))] ->
- Some (reduce_cast (typ_of_annot (l,ann)) e l ann)
- | _ -> None
- else if is_id "vector_access" || is_id "bitvector_access" then
- match args with
- | [E_aux (E_lit L_aux ((L_hex _ | L_bin _) as lit,_),_);
- E_aux (E_lit L_aux (L_num i,_),_)] ->
- let v = int_of_str_lit lit in
- let b = Big_int.bitwise_and (Big_int.shift_right v (Big_int.to_int i)) (Big_int.of_int 1) in
- let lit' = if Big_int.equal b (Big_int.of_int 1) then L_one else L_zero in
- Some (E_aux (E_lit (L_aux (lit',new_l)),(l,ann)))
- | _ -> None
- else None
-
-
let construct_lit_vector args =
let rec aux l = function
| [] -> Some (L_aux (L_bin (String.concat "" (List.rev l)),Unknown))
@@ -361,10 +278,18 @@ let construct_lit_vector args =
(* Add a cast to undefined so that it retains its type, otherwise it can't be
substituted safely *)
let keep_undef_typ value =
- match value with
- | E_aux (E_lit (L_aux (L_undef,lann)),eann) ->
- E_aux (E_cast (typ_of_annot eann,value),(Generated Unknown,snd eann))
- | _ -> value
+ let e_aux (e, ann) =
+ match e with
+ | E_lit (L_aux (L_undef, _)) ->
+ (* Add cast to undefined... *)
+ E_aux (E_cast (typ_of_annot ann, E_aux (e, ann)), ann)
+ | E_cast (typ, E_aux (E_cast (_, e), _)) ->
+ (* ... unless there was a cast already *)
+ E_aux (E_cast (typ, e), ann)
+ | _ -> E_aux (e, ann)
+ in
+ let open Rewriter in
+ fold_exp { id_exp_alg with e_aux = e_aux } value
(* Check whether the current environment with the given kid assignments is
inconsistent (and hence whether the code is dead) *)
@@ -373,8 +298,29 @@ let is_env_inconsistent env ksubsts =
Env.add_constraint (nc_eq (nvar k) nexp) env) ksubsts env in
prove __POS__ env nc_false
+module StringSet = Set.Make(String)
+module StringMap = Map.Make(String)
let const_props defs ref_vars =
+ let const_fold exp =
+ (* Constant-fold function applications with constant arguments *)
+ let interpreter_istate =
+ (* Do not interpret undefined_X functions *)
+ let open Interpreter in
+ let undefined_builtin_ids = ids_of_defs (Defs Initial_check.undefined_builtin_val_specs) in
+ let remove_primop id = StringMap.remove (string_of_id id) in
+ let remove_undefined_primops = IdSet.fold remove_primop undefined_builtin_ids in
+ let (lstate, gstate) = Constant_fold.initial_state defs Type_check.initial_env in
+ (lstate, { gstate with primops = remove_undefined_primops gstate.primops })
+ in
+ try
+ strip_exp exp
+ |> infer_exp (env_of exp)
+ |> Constant_fold.rewrite_exp_once interpreter_istate
+ |> keep_undef_typ
+ with
+ | _ -> exp
+ in
let rec const_prop_exp substs assigns ((E_aux (e,(l,annot))) as exp) =
(* Functions to treat lists and tuples of subexpressions as possibly
non-deterministic: that is, we stop making any assumptions about
@@ -414,7 +360,8 @@ let const_props defs ref_vars =
let e4',_ = const_prop_exp substs assigns e4 in
e1',e2',e3',e4',assigns
in
- let re e assigns = E_aux (e,(l,annot)),assigns in
+ let rewrap e = E_aux (e,(l,annot)) in
+ let re e assigns = rewrap e,assigns in
match e with
(* TODO: are there more circumstances in which we should get rid of these? *)
| E_block [e] -> const_prop_exp substs assigns e
@@ -444,12 +391,7 @@ let const_props defs ref_vars =
| E_app (id,es) ->
let es',assigns = non_det_exp_list es in
let env = Type_check.env_of_annot (l, annot) in
- (match try_app (l,annot) (id,es') with
- | None ->
- (match const_prop_try_fn l env (id,es') with
- | None -> re (E_app (id,es')) assigns
- | Some r -> r,assigns)
- | Some r -> r,assigns)
+ const_prop_try_fn env (id, es') (l, annot), assigns
| E_tuple es ->
let es',assigns = non_det_exp_list es in
re (E_tuple es') assigns
@@ -466,7 +408,7 @@ let const_props defs ref_vars =
let env1 = env_of e1_no_casts in
let is_equal id =
List.exists (fun id' -> Id.compare id id' == 0)
- (Env.get_overloads (Id_aux (DeIid "==", Parse_ast.Unknown))
+ (Env.get_overloads (Id_aux (Operator "==", Parse_ast.Unknown))
env1)
in
let substs_true =
@@ -539,10 +481,33 @@ let const_props defs ref_vars =
let assigned_in = IdSet.union (assigned_vars_in_fexps fes) (assigned_vars e) in
let assigns = isubst_minus_set assigns assigned_in in
let e',_ = const_prop_exp substs assigns e in
- re (E_record_update (e', const_prop_fexps substs assigns fes)) assigns
+ let fes' = const_prop_fexps substs assigns fes in
+ begin
+ match unaux_exp (fst (uncast_exp e')) with
+ | E_record (fes0) ->
+ let apply_fexp (FE_aux (FE_Fexp (id, e), _)) (FE_aux (FE_Fexp (id', e'), ann)) =
+ if Id.compare id id' = 0 then
+ FE_aux (FE_Fexp (id', e), ann)
+ else
+ FE_aux (FE_Fexp (id', e'), ann)
+ in
+ let update_fields fexp = List.map (apply_fexp fexp) in
+ let fes0' = List.fold_right update_fields fes' fes0 in
+ re (E_record fes0') assigns
+ | _ ->
+ re (E_record_update (e', fes')) assigns
+ end
| E_field (e,id) ->
let e',assigns = const_prop_exp substs assigns e in
- re (E_field (e',id)) assigns
+ begin
+ let is_field (FE_aux (FE_Fexp (id', _), _)) = Id.compare id id' = 0 in
+ match unaux_exp e' with
+ | E_record fes0 when List.exists is_field fes0 ->
+ let (FE_aux (FE_Fexp (_, e), _)) = List.find is_field fes0 in
+ re (unaux_exp e) assigns
+ | _ ->
+ re (E_field (e',id)) assigns
+ end
| E_case (e,cases) ->
let e',assigns = const_prop_exp substs assigns e in
(match can_match e' cases substs assigns with
@@ -568,7 +533,7 @@ let const_props defs ref_vars =
let e2',assigns = const_prop_exp substs' assigns e2 in
re (E_let (LB_aux (LB_val (p,e'), annot),
e2')) assigns in
- if is_value e' && not (is_value e) then
+ if is_value e' then
match can_match e' [Pat_aux (Pat_exp (p,e2),(Unknown,empty_tannot))] substs assigns with
| None -> plain ()
| Some (e'',bindings,kbindings) ->
@@ -653,48 +618,36 @@ let const_props defs ref_vars =
| LEXP_field (le,id) -> re (LEXP_field (fst (const_prop_lexp substs assigns le), id))
| LEXP_deref e ->
re (LEXP_deref (fst (const_prop_exp substs assigns e)))
- (* Reduce a function when
- 1. all arguments are values,
- 2. the function is pure,
- 3. the result is a value
- (and 4. the function is not scattered, but that's not terribly important)
- to try and keep execution time and the results managable.
+ (* Try to evaluate function calls with constant arguments via
+ (interpreter-based) constant folding.
+ Boolean connectives are special-cased to support short-circuiting when one
+ argument has a suitable value (even if the other one is not constant).
+ Moreover, calls to a __size function (in particular generated by sizeof
+ rewriting) with a known-constant return type are replaced by that constant;
+ e.g., (length(op : bits(32)) : int(32)) becomes 32 even if op is not constant.
*)
- and const_prop_try_fn l env (id,args) =
- if not (List.for_all is_value args) then
- None
- else
- let (tq,typ) = Env.get_val_spec_orig id env in
- let eff = match typ with
- | Typ_aux (Typ_fn (_,_,eff),_) -> Some eff
- | _ -> None
- in
- let Defs ds = defs in
- match eff, list_extract (function
- | (DEF_fundef (FD_aux (FD_function (_,_,eff,((FCL_aux (FCL_Funcl (id',_),_))::_ as fcls)),_)))
- -> if Id.compare id id' = 0 then Some fcls else None
- | _ -> None) ds with
- | None,_ | _,None -> None
- | Some eff,_ when not (is_pure eff) -> None
- | Some _,Some fcls ->
- let arg = match args with
- | [] -> E_aux (E_lit (L_aux (L_unit,Generated l)),(Generated l,empty_tannot))
- | [e] -> e
- | _ -> E_aux (E_tuple args,(Generated l,empty_tannot)) in
- let cases = List.map (function
- | FCL_aux (FCL_Funcl (_,pexp), ann) -> pexp)
- fcls in
- match can_match_with_env env arg cases (Bindings.empty,KBindings.empty) Bindings.empty with
- | Some (exp,bindings,kbindings) ->
- let substs = bindings_from_list bindings, kbindings_from_list kbindings in
- let result,_ = const_prop_exp substs Bindings.empty exp in
- let result = match result with
- | E_aux (E_return e,_) -> e
- | _ -> result
- in
- if is_value result then Some result else None
- | None -> None
-
+ and const_prop_try_fn env (id, args) (l, annot) =
+ let exp = E_aux (E_app (id, args), (l, annot)) in
+ let rec is_overload_of f =
+ Env.get_overloads f env
+ |> List.exists (fun id' -> Id.compare id id' = 0 || is_overload_of id')
+ in
+ match (string_of_id id, args) with
+ | "and_bool", ([E_aux (E_lit (L_aux (L_false, _)), _) as e_false; _] |
+ [_; E_aux (E_lit (L_aux (L_false, _)), _) as e_false]) ->
+ e_false
+ | "or_bool", ([E_aux (E_lit (L_aux (L_true, _)), _) as e_true; _] |
+ [_; E_aux (E_lit (L_aux (L_true, _)), _) as e_true]) ->
+ e_true
+ | _, _ when List.for_all Constant_fold.is_constant args ->
+ const_fold exp
+ | _, [arg] when is_overload_of (mk_id "__size") ->
+ (match destruct_atom_nexp env (typ_of exp) with
+ | Some (Nexp_aux (Nexp_constant i, _)) ->
+ E_aux (E_lit (mk_lit (L_num i)), (l, annot))
+ | _ -> exp)
+ | _ -> exp
+
and can_match_with_env env (E_aux (e,(l,annot)) as exp0) cases (substs,ksubsts) assigns =
let rec findpat_generic check_pat description assigns = function
| [] -> (Reporting.print_err l "Monomorphisation"
@@ -816,6 +769,8 @@ let const_props defs ref_vars =
(Reporting.print_err l' "Monomorphisation"
"Unexpected kind of pattern for literal"; GiveUp)
in findpat_generic checkpat "literal" assigns cases
+ | E_record _ | E_cast (_, E_aux (E_record _, _)) ->
+ findpat_generic (fun _ -> DoesNotMatch) "record" assigns cases
| _ -> None
and can_match exp =
diff --git a/src/constant_propagation_mutrec.ml b/src/constant_propagation_mutrec.ml
new file mode 100644
index 00000000..285ba45d
--- /dev/null
+++ b/src/constant_propagation_mutrec.ml
@@ -0,0 +1,232 @@
+(**************************************************************************)
+(* Sail *)
+(* *)
+(* Copyright (c) 2013-2017 *)
+(* Kathyrn Gray *)
+(* Shaked Flur *)
+(* Stephen Kell *)
+(* Gabriel Kerneis *)
+(* Robert Norton-Wright *)
+(* Christopher Pulte *)
+(* Peter Sewell *)
+(* Alasdair Armstrong *)
+(* Brian Campbell *)
+(* Thomas Bauereiss *)
+(* Anthony Fox *)
+(* Jon French *)
+(* Dominic Mulligan *)
+(* Stephen Kell *)
+(* Mark Wassell *)
+(* *)
+(* All rights reserved. *)
+(* *)
+(* This software was developed by the University of Cambridge Computer *)
+(* Laboratory as part of the Rigorous Engineering of Mainstream Systems *)
+(* (REMS) project, funded by EPSRC grant EP/K008528/1. *)
+(* *)
+(* Redistribution and use in source and binary forms, with or without *)
+(* modification, are permitted provided that the following conditions *)
+(* are met: *)
+(* 1. Redistributions of source code must retain the above copyright *)
+(* notice, this list of conditions and the following disclaimer. *)
+(* 2. Redistributions in binary form must reproduce the above copyright *)
+(* notice, this list of conditions and the following disclaimer in *)
+(* the documentation and/or other materials provided with the *)
+(* distribution. *)
+(* *)
+(* THIS SOFTWARE IS PROVIDED BY THE AUTHOR AND CONTRIBUTORS ``AS IS'' *)
+(* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED *)
+(* TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A *)
+(* PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE AUTHOR OR *)
+(* CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, *)
+(* SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT *)
+(* LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF *)
+(* USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND *)
+(* ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, *)
+(* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT *)
+(* OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF *)
+(* SUCH DAMAGE. *)
+(**************************************************************************)
+
+open Ast
+open Ast_util
+open Type_check
+open Rewriter
+
+(* Unroll mutually recursive calls, starting with the functions given as
+ targets on the command line, by looking for recursive calls with (some)
+ constant arguments, and creating copies of those functions with the
+ constants propagated in. This may cause branches with mutually recursively
+ calls to disappear, breaking the mutually recursive cycle. *)
+
+let targets = ref ([] : id list)
+
+let rec is_const_exp exp = match unaux_exp exp with
+ | E_lit (L_aux ((L_true | L_false | L_one | L_zero | L_num _), _)) -> true
+ | E_vector es -> List.for_all is_const_exp es && is_bitvector_typ (typ_of exp)
+ | E_record fes -> List.for_all is_const_fexp fes
+ | _ -> false
+and is_const_fexp (FE_aux (FE_Fexp (_, e), _)) = is_const_exp e
+
+let recheck_exp exp = check_exp (env_of exp) (strip_exp exp) (typ_of exp)
+
+(* Name function copy by encoding values of constant arguments *)
+let generate_fun_id id args =
+ let rec suffix exp = match unaux_exp exp with
+ | E_lit (L_aux (L_one, _)) -> "1"
+ | E_lit (L_aux (L_zero, _)) -> "0"
+ | E_lit (L_aux (L_true, _)) -> "T"
+ | E_lit (L_aux (L_false, _)) -> "F"
+ | E_record fes when is_const_exp exp ->
+ let fsuffix (FE_aux (FE_Fexp (id, e), _)) = suffix e
+ in
+ "struct" ^
+ Util.zencode_string (string_of_typ (typ_of exp)) ^
+ "#" ^
+ String.concat "" (List.map fsuffix fes)
+ | E_vector es when is_const_exp exp ->
+ String.concat "" (List.map suffix es)
+ | _ ->
+ if is_const_exp exp
+ then "#" ^ Util.zencode_string (string_of_exp exp)
+ else "v"
+ in
+ append_id id ("#mutrec_" ^ String.concat "" (List.map suffix args))
+
+(* Generate a val spec for a function copy, removing the constant arguments
+ that will be propagated in *)
+let generate_val_spec env id args l annot =
+ match Env.get_val_spec_orig id env with
+ | tq, Typ_aux (Typ_fn (arg_typs, ret_typ, eff), _) ->
+ let orig_ksubst (kid, typ_arg) =
+ match typ_arg with
+ | A_aux ((A_nexp _ | A_bool _), _) -> (orig_kid kid, typ_arg)
+ | _ -> raise (Reporting.err_todo l "Propagation of polymorphic arguments not implemented")
+ in
+ let ksubsts =
+ recheck_exp (E_aux (E_app (id, args), (l, annot)))
+ |> instantiation_of
+ |> KBindings.bindings
+ |> List.map orig_ksubst
+ |> List.fold_left (fun s (v,i) -> KBindings.add v i s) KBindings.empty
+ in
+ let ret_typ' = KBindings.fold typ_subst ksubsts ret_typ in
+ let arg_typs' =
+ List.map (KBindings.fold typ_subst ksubsts) arg_typs
+ |> List.map2 (fun arg typ -> if is_const_exp arg then [] else [typ]) args
+ |> List.concat
+ |> function [] -> [unit_typ] | typs -> typs
+ in
+ let typ' = mk_typ (Typ_fn (arg_typs', ret_typ', eff)) in
+ let tyvars = tyvars_of_typ typ' in
+ let tq' =
+ quant_items tq |>
+ List.filter (fun qi -> KidSet.subset (tyvars_of_quant_item qi) tyvars) |>
+ mk_typquant
+ in
+ let typschm = mk_typschm tq' typ' in
+ mk_val_spec (VS_val_spec (typschm, generate_fun_id id args, [], false)),
+ ksubsts
+ | _, Typ_aux (_, l) ->
+ raise (Reporting.err_unreachable l __POS__ "Function val spec is not a function type")
+
+let const_prop defs substs ksubsts exp =
+ (* Constant_propagation currently only supports nexps for kid substitutions *)
+ let nexp_substs =
+ KBindings.bindings ksubsts
+ |> List.map (function (kid, A_aux (A_nexp n, _)) -> [(kid, n)] | _ -> [])
+ |> List.concat
+ |> List.fold_left (fun s (v,i) -> KBindings.add v i s) KBindings.empty
+ in
+ Constant_propagation.const_prop
+ (Defs defs)
+ (Constant_propagation.referenced_vars exp)
+ (substs, nexp_substs)
+ Bindings.empty
+ exp
+ |> fst
+
+(* Propagate constant arguments into function clause pexp *)
+let prop_args_pexp defs ksubsts args pexp =
+ let pat, guard, exp, annot = destruct_pexp pexp in
+ let pats = match pat with
+ | P_aux (P_tup pats, _) -> pats
+ | _ -> [pat]
+ in
+ let match_arg (E_aux (_, (l, _)) as arg) pat (pats, substs) =
+ if is_const_exp arg then
+ match pat with
+ | P_aux (P_id id, _) -> (pats, Bindings.add id arg substs)
+ | _ ->
+ raise (Reporting.err_todo l
+ ("Unsupported pattern match in propagation of constant arguments: " ^
+ string_of_exp arg ^ " and " ^ string_of_pat pat))
+ else (pat :: pats, substs)
+ in
+ let pats, substs = List.fold_right2 match_arg args pats ([], Bindings.empty) in
+ let exp' = const_prop defs substs ksubsts exp in
+ let pat' = match pats with
+ | [pat] -> pat
+ | _ -> P_aux (P_tup pats, (Parse_ast.Unknown, empty_tannot))
+ in
+ construct_pexp (pat', guard, exp', annot)
+
+let rewrite_defs env (Defs defs) =
+ let rec rewrite = function
+ | [] -> []
+ | DEF_internal_mutrec mutrecs :: ds ->
+ let mutrec_ids = IdSet.of_list (List.map id_of_fundef mutrecs) in
+ let valspecs = ref ([] : unit def list) in
+ let fundefs = ref ([] : unit def list) in
+ (* Try to replace mutually recursive calls that have some constant arguments *)
+ let rec e_app (id, args) (l, annot) =
+ if IdSet.mem id mutrec_ids && List.exists is_const_exp args then
+ let id' = generate_fun_id id args in
+ let args' = match List.filter (fun e -> not (is_const_exp e)) args with
+ | [] -> [infer_exp env (mk_lit_exp L_unit)]
+ | args' -> args'
+ in
+ if not (IdSet.mem id' (ids_of_defs (Defs !valspecs))) then begin
+ (* Generate copy of function with constant arguments propagated in *)
+ let (FD_aux (FD_function (_, _, _, fcls), _)) =
+ List.find (fun fd -> Id.compare id (id_of_fundef fd) = 0) mutrecs
+ in
+ let valspec, ksubsts = generate_val_spec env id args l annot in
+ let const_prop_funcl (FCL_aux (FCL_Funcl (_, pexp), (l, _))) =
+ let pexp' =
+ prop_args_pexp defs ksubsts args pexp
+ |> rewrite_pexp
+ |> strip_pexp
+ in
+ FCL_aux (FCL_Funcl (id', pexp'), (Parse_ast.Generated l, ()))
+ in
+ valspecs := valspec :: !valspecs;
+ let fundef = mk_fundef (List.map const_prop_funcl fcls) in
+ fundefs := fundef :: !fundefs
+ end else ();
+ E_aux (E_app (id', args'), (l, annot))
+ else E_aux (E_app (id, args), (l, annot))
+ and e_aux (e, (l, annot)) =
+ match e with
+ | E_app (id, args) -> e_app (id, args) (l, annot)
+ | _ -> E_aux (e, (l, annot))
+ and rewrite_pexp pexp = fold_pexp { id_exp_alg with e_aux = e_aux } pexp
+ and rewrite_funcl (FCL_aux (FCL_Funcl (id, pexp), a) as funcl) =
+ let pexp' =
+ if List.exists (fun id' -> Id.compare id id' = 0) !targets then
+ let pat, guard, body, annot = destruct_pexp pexp in
+ let body' = const_prop defs Bindings.empty KBindings.empty body in
+ rewrite_pexp (construct_pexp (pat, guard, recheck_exp body', annot))
+ else pexp
+ in FCL_aux (FCL_Funcl (id, pexp'), a)
+ and rewrite_fundef (FD_aux (FD_function (ropt, topt, eopt, fcls), a)) =
+ let fcls' = List.map rewrite_funcl fcls in
+ FD_aux (FD_function (ropt, topt, eopt, fcls'), a)
+ in
+ let mutrecs' = List.map (fun fd -> DEF_fundef (rewrite_fundef fd)) mutrecs in
+ let (Defs fdefs) = fst (check env (Defs (!valspecs @ !fundefs))) in
+ mutrecs' @ fdefs @ rewrite ds
+ | d :: ds ->
+ d :: rewrite ds
+ in
+ Spec_analysis.top_sort_defs (Defs (rewrite defs))
diff --git a/src/constraint.ml b/src/constraint.ml
index 5402f6f7..1bd6a437 100644
--- a/src/constraint.ml
+++ b/src/constraint.ml
@@ -259,7 +259,15 @@ let save_digests () =
DigestMap.iter output !known_problems;
close_out out_chan
-let call_smt' l vars constraints : smt_result =
+let kopt_pair kopt = (kopt_kid kopt, unaux_kind (kopt_kind kopt))
+
+let call_smt' l constraints : smt_result =
+ let vars =
+ kopts_of_constraint constraints
+ |> KOptSet.elements
+ |> List.map kopt_pair
+ |> List.fold_left (fun m (k, v) -> KBindings.add k v m) KBindings.empty
+ in
let problems = [constraints] in
let smt_file, _ = smtlib_of_constraints l vars constraints in
@@ -270,58 +278,67 @@ let call_smt' l vars constraints : smt_result =
let rec input_lines chan = function
| 0 -> []
| n ->
- begin
- let l = input_line chan in
- let ls = input_lines chan (n - 1) in
- l :: ls
- end
+ let l = input_line chan in
+ let ls = input_lines chan (n - 1) in
+ l :: ls
in
let digest = Digest.string smt_file in
- try
- let result = DigestMap.find digest !known_problems in
- result
- with
- | Not_found ->
- begin
- let (input_file, tmp_chan) =
- try Filename.open_temp_file "constraint_" ".smt2" with
- | Sys_error msg -> raise (Reporting.err_general l ("Could not open temp file when calling SMT: " ^ msg))
- in
- output_string tmp_chan smt_file;
- close_out tmp_chan;
- let smt_output =
- try
- let smt_out, smt_in, smt_err = Unix.open_process_full (!opt_solver.command ^ " " ^ input_file) (Unix.environment ()) in
- let smt_output =
- try List.combine problems (input_lines smt_out (List.length problems)) with
- | End_of_file -> List.combine problems ["unknown"]
- in
- let _ = Unix.close_process_full (smt_out, smt_in, smt_err) in
- smt_output
- with
- | exn -> raise (Reporting.err_general l ("Error when calling smt: " ^ Printexc.to_string exn))
- in
- Sys.remove input_file;
- try
- let (problem, _) = List.find (fun (_, result) -> result = "unsat") smt_output in
- known_problems := DigestMap.add digest Unsat !known_problems;
- Unsat
- with
- | Not_found ->
- let unsolved = List.filter (fun (_, result) -> result = "unknown") smt_output in
- if unsolved == []
- then (known_problems := DigestMap.add digest Sat !known_problems; Sat)
- else (known_problems := DigestMap.add digest Unknown !known_problems; Unknown)
- end
-
-let call_smt l vars constraints =
+
+ match DigestMap.find_opt digest !known_problems with
+ | Some result -> result
+ | None ->
+ let (input_file, tmp_chan) =
+ try Filename.open_temp_file "constraint_" ".smt2" with
+ | Sys_error msg -> raise (Reporting.err_general l ("Could not open temp file when calling SMT: " ^ msg))
+ in
+ output_string tmp_chan smt_file;
+ close_out tmp_chan;
+ let status, smt_output =
+ try
+ let smt_out, smt_in, smt_err = Unix.open_process_full (!opt_solver.command ^ " " ^ input_file) (Unix.environment ()) in
+ let smt_output =
+ try List.combine problems (input_lines smt_out (List.length problems)) with
+ | End_of_file -> List.combine problems ["unknown"]
+ in
+ let status = Unix.close_process_full (smt_out, smt_in, smt_err) in
+ status, smt_output
+ with
+ | exn ->
+ raise (Reporting.err_general l ("Error when calling smt: " ^ Printexc.to_string exn))
+ in
+ let _ = match status with
+ | Unix.WEXITED 0 -> ()
+ | Unix.WEXITED n ->
+ raise (Reporting.err_general l ("SMT solver returned unexpected status " ^ string_of_int n))
+ | Unix.WSIGNALED n | Unix.WSTOPPED n ->
+ raise (Reporting.err_general l ("SMT solver killed by signal " ^ string_of_int n))
+ in
+ Sys.remove input_file;
+ try
+ let (problem, _) = List.find (fun (_, result) -> result = "unsat") smt_output in
+ known_problems := DigestMap.add digest Unsat !known_problems;
+ Unsat
+ with
+ | Not_found ->
+ let unsolved = List.filter (fun (_, result) -> result = "unknown") smt_output in
+ if unsolved == []
+ then (known_problems := DigestMap.add digest Sat !known_problems; Sat)
+ else (known_problems := DigestMap.add digest Unknown !known_problems; Unknown)
+
+let call_smt l constraints =
let t = Profile.start_smt () in
- let result = call_smt' l vars constraints in
+ let result = call_smt' l constraints in
Profile.finish_smt t;
result
-let solve_smt l vars constraints var =
+let solve_smt l constraints var =
+ let vars =
+ kopts_of_constraint constraints
+ |> KOptSet.elements
+ |> List.map kopt_pair
+ |> List.fold_left (fun m (k, v) -> KBindings.add k v m) KBindings.empty
+ in
let smt_file, smt_var = smtlib_of_constraints ~get_model:true l vars constraints in
let smt_var = pp_sexpr (smt_var var) in
@@ -360,22 +377,22 @@ let solve_smt l vars constraints var =
with
| Not_found -> None
-let solve_all_smt l vars constraints var =
+let solve_all_smt l constraints var =
let rec aux results =
let constraints = List.fold_left (fun ncs r -> (nc_and ncs (nc_neq (nconstant r) (nvar var)))) constraints results in
- match solve_smt l vars constraints var with
+ match solve_smt l constraints var with
| Some result -> aux (result :: results)
| None ->
- match call_smt l vars constraints with
+ match call_smt l constraints with
| Unsat -> Some results
| _ -> None
in
aux []
-let solve_unique_smt l vars constraints var =
- match solve_smt l vars constraints var with
+let solve_unique_smt l constraints var =
+ match solve_smt l constraints var with
| Some result ->
- begin match call_smt l vars (nc_and constraints (nc_neq (nconstant result) (nvar var))) with
+ begin match call_smt l (nc_and constraints (nc_neq (nconstant result) (nvar var))) with
| Unsat -> Some result
| _ -> None
end
diff --git a/src/constraint.mli b/src/constraint.mli
index b5d6ff6b..34e83964 100644
--- a/src/constraint.mli
+++ b/src/constraint.mli
@@ -61,10 +61,10 @@ type smt_result = Unknown | Sat | Unsat
val load_digests : unit -> unit
val save_digests : unit -> unit
-val call_smt : l -> kind_aux KBindings.t -> n_constraint -> smt_result
+val call_smt : l -> n_constraint -> smt_result
-val solve_smt : l -> kind_aux KBindings.t -> n_constraint -> kid -> Big_int.num option
+val solve_smt : l -> n_constraint -> kid -> Big_int.num option
-val solve_all_smt : l -> kind_aux KBindings.t -> n_constraint -> kid -> Big_int.num list option
+val solve_all_smt : l -> n_constraint -> kid -> Big_int.num list option
-val solve_unique_smt : l -> kind_aux KBindings.t -> n_constraint -> kid -> Big_int.num option
+val solve_unique_smt : l -> n_constraint -> kid -> Big_int.num option
diff --git a/src/gen_lib/sail2_operators_bitlists.lem b/src/gen_lib/sail2_operators_bitlists.lem
index bacf59e7..8b75fa38 100644
--- a/src/gen_lib/sail2_operators_bitlists.lem
+++ b/src/gen_lib/sail2_operators_bitlists.lem
@@ -41,6 +41,9 @@ let zeros len = repeat [B0] len
val vector_truncate : list bitU -> integer -> list bitU
let vector_truncate bs len = extz_bv len bs
+val vector_truncateLSB : list bitU -> integer -> list bitU
+let vector_truncateLSB bs len = take_list len bs
+
val vec_of_bits_maybe : list bitU -> maybe (list bitU)
val vec_of_bits_fail : forall 'rv 'e. list bitU -> monad 'rv (list bitU) 'e
val vec_of_bits_nondet : forall 'rv 'e. list bitU -> monad 'rv (list bitU) 'e
diff --git a/src/gen_lib/sail2_operators_mwords.lem b/src/gen_lib/sail2_operators_mwords.lem
index d47d9b40..181fa149 100644
--- a/src/gen_lib/sail2_operators_mwords.lem
+++ b/src/gen_lib/sail2_operators_mwords.lem
@@ -82,6 +82,13 @@ let zeros _ = Machine_word.wordFromNatural 0
val vector_truncate : forall 'a 'b. Size 'a, Size 'b => mword 'a -> integer -> mword 'b
let vector_truncate w _ = Machine_word.zeroExtend w
+val vector_truncateLSB : forall 'a 'b. Size 'a, Size 'b => mword 'a -> integer -> mword 'b
+let vector_truncateLSB w len =
+ let len = nat_of_int len in
+ let lo = Machine_word.word_length w - len in
+ let hi = lo + len - 1 in
+ Machine_word.word_extract lo hi w
+
val concat_vec : forall 'a 'b 'c. Size 'a, Size 'b, Size 'c => mword 'a -> mword 'b -> mword 'c
let concat_vec = Machine_word.word_concat
diff --git a/src/initial_check.ml b/src/initial_check.ml
index 2aa0c511..28446db2 100644
--- a/src/initial_check.ml
+++ b/src/initial_check.ml
@@ -70,7 +70,7 @@ type ctx = {
let string_of_parse_id_aux = function
| P.Id v -> v
- | P.DeIid v -> v
+ | P.Operator v -> v
let string_of_parse_id (P.Id_aux (id, l)) = string_of_parse_id_aux id
@@ -93,7 +93,7 @@ let to_ast_id (P.Id_aux(id, l)) =
else
Id_aux ((match id with
| P.Id x -> Id x
- | P.DeIid x -> DeIid x),
+ | P.Operator x -> Operator x),
l)
let to_ast_var (P.Kid_aux (P.Var v, l)) = Kid_aux (Var v, l)
@@ -224,7 +224,7 @@ and to_ast_order ctx (P.ATyp_aux (aux, l)) =
and to_ast_constraint ctx (P.ATyp_aux (aux, l) as atyp) =
let aux = match aux with
- | P.ATyp_app (Id_aux (DeIid op, _) as id, [t1; t2]) ->
+ | P.ATyp_app (Id_aux (Operator op, _) as id, [t1; t2]) ->
begin match op with
| "==" -> NC_equal (to_ast_nexp ctx t1, to_ast_nexp ctx t2)
| "!=" -> NC_not_equal (to_ast_nexp ctx t1, to_ast_nexp ctx t2)
@@ -522,52 +522,100 @@ let rec to_ast_range (P.BF_aux(r,l)) = (* TODO add check that ranges are sensibl
| P.BF_concat(ir1,ir2) -> BF_concat(to_ast_range ir1, to_ast_range ir2)),
l)
-let to_ast_type_union ctx (P.Tu_aux (P.Tu_ty_id (atyp, id), l)) =
- let typ = to_ast_typ ctx atyp in
- Tu_aux (Tu_ty_id (typ, to_ast_id id), l)
+let to_ast_type_union ctx = function
+ | P.Tu_aux (P.Tu_ty_id (atyp, id), l) ->
+ let typ = to_ast_typ ctx atyp in
+ Tu_aux (Tu_ty_id (typ, to_ast_id id), l)
+ | P.Tu_aux (_, l) ->
+ raise (Reporting.err_unreachable l __POS__ "Anonymous record type should have been rewritten by now")
let add_constructor id typq ctx =
let kinds = List.map (fun kopt -> unaux_kind (kopt_kind kopt)) (quant_kopts typq) in
{ ctx with type_constructors = Bindings.add id kinds ctx.type_constructors }
-let to_ast_typedef ctx (P.TD_aux (aux, l) : P.type_def) : unit type_def ctx_out =
- let aux, ctx = match aux with
- | P.TD_abbrev (id, typq, kind, typ_arg) ->
- let id = to_ast_id id in
- let typq, typq_ctx = to_ast_typquant ctx typq in
- let kind = to_ast_kind kind in
- let typ_arg = to_ast_typ_arg typq_ctx typ_arg (unaux_kind kind) in
- TD_abbrev (id, typq, typ_arg),
- add_constructor id typq ctx
-
- | P.TD_record (id, typq, fields, _) ->
- let id = to_ast_id id in
- let typq, typq_ctx = to_ast_typquant ctx typq in
- let fields = List.map (fun (atyp, id) -> to_ast_typ typq_ctx atyp, to_ast_id id) fields in
- TD_record (id, typq, fields, false),
- add_constructor id typq ctx
-
- | P.TD_variant (id, typq, arms, _) ->
- let id = to_ast_id id in
- let typq, typq_ctx = to_ast_typquant ctx typq in
- let arms = List.map (to_ast_type_union typq_ctx) arms in
- TD_variant (id, typq, arms, false),
- add_constructor id typq ctx
-
- | P.TD_enum (id, enums, _) ->
- let id = to_ast_id id in
- let enums = List.map to_ast_id enums in
- TD_enum (id, enums, false),
- { ctx with type_constructors = Bindings.add id [] ctx.type_constructors }
+let anon_rec_constructor_typ record_id = function
+ | P.TypQ_aux (P.TypQ_no_forall, l) -> P.ATyp_aux (P.ATyp_id record_id, Generated l)
+ | P.TypQ_aux (P.TypQ_tq quants, l) ->
+ let rec quant_arg = function
+ | P.QI_aux (P.QI_id (P.KOpt_aux (P.KOpt_none v, l)), _) -> [P.ATyp_aux (P.ATyp_var v, Generated l)]
+ | P.QI_aux (P.QI_id (P.KOpt_aux (P.KOpt_kind (_, v), l)), _) -> [P.ATyp_aux (P.ATyp_var v, Generated l)]
+ | P.QI_aux (P.QI_const _, _) -> []
+ in
+ match List.concat (List.map quant_arg quants) with
+ | [] -> P.ATyp_aux (P.ATyp_id record_id, Generated l)
+ | args -> P.ATyp_aux (P.ATyp_app (record_id, args), Generated l)
+
+let rec realise_union_anon_rec_types orig_union arms =
+ match orig_union with
+ | P.TD_variant (union_id, typq, _, flag) ->
+ begin match arms with
+ | [] -> []
+ | arm :: arms ->
+ match arm with
+ | (P.Tu_aux ((P.Tu_ty_id _), _)) -> (None, arm) :: realise_union_anon_rec_types orig_union arms
+ | (P.Tu_aux ((P.Tu_ty_anon_rec (fields, id)), l)) ->
+ let open Parse_ast in
+ let record_str = "_" ^ string_of_parse_id union_id ^ "_" ^ string_of_parse_id id ^ "_record" in
+ let record_id = Id_aux (Id record_str, Generated l) in
+ let new_arm = Tu_aux (Tu_ty_id (anon_rec_constructor_typ record_id typq, id), Generated l) in
+ let new_rec_def = TD_aux (TD_record (record_id, typq, fields, flag), Generated l) in
+ (Some new_rec_def, new_arm) :: (realise_union_anon_rec_types orig_union arms)
+ end
+ | _ ->
+ raise (Reporting.err_unreachable Parse_ast.Unknown __POS__ "Non union type-definition passed to realise_union_anon_rec_typs")
- | P.TD_bitfield (id, typ, ranges) ->
- let id = to_ast_id id in
- let typ = to_ast_typ ctx typ in
- let ranges = List.map (fun (id, range) -> (to_ast_id id, to_ast_range range)) ranges in
- TD_bitfield (id, typ, ranges),
- { ctx with type_constructors = Bindings.add id [] ctx.type_constructors }
- in
- TD_aux (aux, (l, ())), ctx
+let rec to_ast_typedef ctx (P.TD_aux (aux, l) : P.type_def) : unit type_def list ctx_out =
+ match aux with
+ | P.TD_abbrev (id, typq, kind, typ_arg) ->
+ let id = to_ast_id id in
+ let typq, typq_ctx = to_ast_typquant ctx typq in
+ let kind = to_ast_kind kind in
+ let typ_arg = to_ast_typ_arg typq_ctx typ_arg (unaux_kind kind) in
+ [TD_aux (TD_abbrev (id, typq, typ_arg), (l, ()))],
+ add_constructor id typq ctx
+
+ | P.TD_record (id, typq, fields, _) ->
+ let id = to_ast_id id in
+ let typq, typq_ctx = to_ast_typquant ctx typq in
+ let fields = List.map (fun (atyp, id) -> to_ast_typ typq_ctx atyp, to_ast_id id) fields in
+ [TD_aux (TD_record (id, typq, fields, false), (l, ()))],
+ add_constructor id typq ctx
+
+ | P.TD_variant (id, typq, arms, _) as union ->
+ (* First generate auxilliary record types for anonymous records in constructors *)
+ let records_and_arms = realise_union_anon_rec_types union arms in
+ let rec filter_records = function
+ | [] -> []
+ | Some x :: xs -> x :: filter_records xs
+ | None :: xs -> filter_records xs
+ in
+ let generated_records = filter_records (List.map fst records_and_arms) in
+ let generated_records, ctx =
+ List.fold_left (fun (prev, ctx) td -> let td, ctx = to_ast_typedef ctx td in prev @ td, ctx)
+ ([], ctx)
+ generated_records
+ in
+ let arms = List.map snd records_and_arms in
+ let union = Parse_ast.TD_variant (id, typq, arms, false) in
+ (* Now generate the AST union type *)
+ let id = to_ast_id id in
+ let typq, typq_ctx = to_ast_typquant ctx typq in
+ let arms = List.map (to_ast_type_union typq_ctx) arms in
+ [TD_aux (TD_variant (id, typq, arms, false), (l, ()))] @ generated_records,
+ add_constructor id typq ctx
+
+ | P.TD_enum (id, enums, _) ->
+ let id = to_ast_id id in
+ let enums = List.map to_ast_id enums in
+ [TD_aux (TD_enum (id, enums, false), (l, ()))],
+ { ctx with type_constructors = Bindings.add id [] ctx.type_constructors }
+
+ | P.TD_bitfield (id, typ, ranges) ->
+ let id = to_ast_id id in
+ let typ = to_ast_typ ctx typ in
+ let ranges = List.map (fun (id, range) -> (to_ast_id id, to_ast_range range)) ranges in
+ [TD_aux (TD_bitfield (id, typ, ranges), (l, ()))],
+ { ctx with type_constructors = Bindings.add id [] ctx.type_constructors }
let to_ast_rec ctx (P.Rec_aux(r,l): P.rec_opt) : unit rec_opt =
Rec_aux((match r with
@@ -714,44 +762,44 @@ let to_ast_prec = function
| P.InfixL -> InfixL
| P.InfixR -> InfixR
-let to_ast_def ctx def : unit def ctx_out =
+let to_ast_def ctx def : unit def list ctx_out =
match def with
| P.DEF_overload (id, ids) ->
- DEF_overload (to_ast_id id, List.map to_ast_id ids), ctx
+ [DEF_overload (to_ast_id id, List.map to_ast_id ids)], ctx
| P.DEF_fixity (prec, n, op) ->
- DEF_fixity (to_ast_prec prec, n, to_ast_id op), ctx
+ [DEF_fixity (to_ast_prec prec, n, to_ast_id op)], ctx
| P.DEF_type(t_def) ->
- let td, ctx = to_ast_typedef ctx t_def in
- DEF_type td, ctx
+ let tds, ctx = to_ast_typedef ctx t_def in
+ List.map (fun td -> DEF_type td) tds, ctx
| P.DEF_fundef(f_def) ->
let fd = to_ast_fundef ctx f_def in
- DEF_fundef fd, ctx
+ [DEF_fundef fd], ctx
| P.DEF_mapdef(m_def) ->
let md = to_ast_mapdef ctx m_def in
- DEF_mapdef md, ctx
+ [DEF_mapdef md], ctx
| P.DEF_val(lbind) ->
let lb = to_ast_letbind ctx lbind in
- DEF_val lb, ctx
+ [DEF_val lb], ctx
| P.DEF_spec(val_spec) ->
let vs,ctx = to_ast_spec ctx val_spec in
- DEF_spec vs, ctx
+ [DEF_spec vs], ctx
| P.DEF_default(typ_spec) ->
let default,ctx = to_ast_default ctx typ_spec in
- DEF_default default, ctx
+ [DEF_default default], ctx
| P.DEF_reg_dec dec ->
let d = to_ast_dec ctx dec in
- DEF_reg_dec d, ctx
+ [DEF_reg_dec d], ctx
| P.DEF_pragma (pragma, arg, l) ->
- DEF_pragma (pragma, arg, l), ctx
+ [DEF_pragma (pragma, arg, l)], ctx
| P.DEF_internal_mutrec _ ->
(* Should never occur because of remove_mutrec *)
raise (Reporting.err_unreachable P.Unknown __POS__
"Internal mutual block found when processing scattered defs")
| P.DEF_scattered sdef ->
let sdef, ctx = to_ast_scattered ctx sdef in
- DEF_scattered sdef, ctx
+ [DEF_scattered sdef], ctx
| P.DEF_measure (id, pat, exp) ->
- DEF_measure (to_ast_id id, to_ast_pat ctx pat, to_ast_exp ctx exp), ctx
+ [DEF_measure (to_ast_id id, to_ast_pat ctx pat, to_ast_exp ctx exp)], ctx
let rec remove_mutrec = function
| [] -> []
@@ -763,7 +811,7 @@ let rec remove_mutrec = function
let to_ast ctx (P.Defs(defs)) =
let defs = remove_mutrec defs in
let defs, ctx =
- List.fold_left (fun (defs, ctx) def -> let def, ctx = to_ast_def ctx def in (def :: defs, ctx)) ([], ctx) defs
+ List.fold_left (fun (defs, ctx) def -> let def, ctx = to_ast_def ctx def in (def @ defs, ctx)) ([], ctx) defs
in
Defs (List.rev defs), ctx
@@ -834,30 +882,31 @@ let undefined_typschm id typq =
let have_undefined_builtins = ref false
+let undefined_builtin_val_specs =
+ [extern_of_string (mk_id "internal_pick") "forall ('a:Type). list('a) -> 'a effect {undef}";
+ extern_of_string (mk_id "undefined_bool") "unit -> bool effect {undef}";
+ extern_of_string (mk_id "undefined_bit") "unit -> bit effect {undef}";
+ extern_of_string (mk_id "undefined_int") "unit -> int effect {undef}";
+ extern_of_string (mk_id "undefined_nat") "unit -> nat effect {undef}";
+ extern_of_string (mk_id "undefined_real") "unit -> real effect {undef}";
+ extern_of_string (mk_id "undefined_string") "unit -> string effect {undef}";
+ extern_of_string (mk_id "undefined_list") "forall ('a:Type). 'a -> list('a) effect {undef}";
+ extern_of_string (mk_id "undefined_range") "forall 'n 'm. (atom('n), atom('m)) -> range('n,'m) effect {undef}";
+ extern_of_string (mk_id "undefined_vector") "forall 'n ('a:Type) ('ord : Order). (atom('n), 'a) -> vector('n, 'ord,'a) effect {undef}";
+ (* Only used with lem_mwords *)
+ extern_of_string (mk_id "undefined_bitvector") "forall 'n. atom('n) -> vector('n, dec, bit) effect {undef}";
+ extern_of_string (mk_id "undefined_unit") "unit -> unit effect {undef}"]
+
let generate_undefineds vs_ids (Defs defs) =
- let gen_vs id str =
- if (IdSet.mem id vs_ids) then [] else [extern_of_string id str]
- in
let undefined_builtins =
if !have_undefined_builtins then
[]
else
begin
have_undefined_builtins := true;
- List.concat
- [gen_vs (mk_id "internal_pick") "forall ('a:Type). list('a) -> 'a effect {undef}";
- gen_vs (mk_id "undefined_bool") "unit -> bool effect {undef}";
- gen_vs (mk_id "undefined_bit") "unit -> bit effect {undef}";
- gen_vs (mk_id "undefined_int") "unit -> int effect {undef}";
- gen_vs (mk_id "undefined_nat") "unit -> nat effect {undef}";
- gen_vs (mk_id "undefined_real") "unit -> real effect {undef}";
- gen_vs (mk_id "undefined_string") "unit -> string effect {undef}";
- gen_vs (mk_id "undefined_list") "forall ('a:Type). 'a -> list('a) effect {undef}";
- gen_vs (mk_id "undefined_range") "forall 'n 'm. (atom('n), atom('m)) -> range('n,'m) effect {undef}";
- gen_vs (mk_id "undefined_vector") "forall 'n ('a:Type) ('ord : Order). (atom('n), 'a) -> vector('n, 'ord,'a) effect {undef}";
- (* Only used with lem_mwords *)
- gen_vs (mk_id "undefined_bitvector") "forall 'n. atom('n) -> vector('n, dec, bit) effect {undef}";
- gen_vs (mk_id "undefined_unit") "unit -> unit effect {undef}"]
+ List.filter
+ (fun def -> IdSet.is_empty (IdSet.inter vs_ids (ids_of_def def)))
+ undefined_builtin_val_specs
end
in
let undefined_tu = function
@@ -1036,6 +1085,10 @@ let process_ast ?generate:(generate=true) defs =
else
ast
+let ast_of_def_string_with f str =
+ let def = Parser.def_eof Lexer.token (Lexing.from_string str) in
+ process_ast (f (P.Defs [def]))
+
let ast_of_def_string str =
let def = Parser.def_eof Lexer.token (Lexing.from_string str) in
process_ast (P.Defs [def])
diff --git a/src/initial_check.mli b/src/initial_check.mli
index b96a9efb..59c8f0b6 100644
--- a/src/initial_check.mli
+++ b/src/initial_check.mli
@@ -87,6 +87,11 @@ val opt_enum_casts : bool ref
all the loaded files. *)
val have_undefined_builtins : bool ref
+(** Val specs of undefined functions for builtin types that get added to
+ the AST if opt_undefined_gen is set (minus those functions that already
+ exist in the AST). *)
+val undefined_builtin_val_specs : unit def list
+
(** {2 Desugar and process AST } *)
(** If the generate flag is false, then we won't generate any
@@ -98,6 +103,7 @@ val process_ast : ?generate:bool -> Parse_ast.defs -> unit defs
val extern_of_string : id -> string -> unit def
val val_spec_of_string : id -> string -> unit def
val ast_of_def_string : string -> unit defs
+val ast_of_def_string_with : (Parse_ast.defs -> Parse_ast.defs) -> string -> unit defs
val exp_of_string : string -> unit exp
val typ_of_string : string -> typ
val constraint_of_string : string -> n_constraint
diff --git a/src/interpreter.ml b/src/interpreter.ml
index 1ebfdeff..263430f1 100644
--- a/src/interpreter.ml
+++ b/src/interpreter.ml
@@ -545,6 +545,11 @@ let rec step (E_aux (e_aux, annot) as orig_exp) =
let record = coerce_record (value_of_exp exp) in
return (exp_of_value (StringMap.find (string_of_id id) record))
+ | E_var (lexp, exp, E_aux (E_block body, _)) ->
+ wrap (E_block (E_aux (E_assign (lexp, exp), annot) :: body))
+ | E_var (lexp, exp, body) ->
+ wrap (E_block [E_aux (E_assign (lexp, exp), annot); body])
+
| E_assign (lexp, exp) when not (is_value exp) -> step exp >>= fun exp' -> wrap (E_assign (lexp, exp'))
| E_assign (LEXP_aux (LEXP_memory (id, args), _), exp) -> wrap (E_app (id, args @ [exp]))
| E_assign (LEXP_aux (LEXP_field (lexp, id), ul), exp) ->
diff --git a/src/isail.ml b/src/isail.ml
index 4db39123..094ad3df 100644
--- a/src/isail.ml
+++ b/src/isail.ml
@@ -161,7 +161,6 @@ let rec run () =
end
let rec run_steps n =
- print_endline ("step " ^ string_of_int n);
match !current_mode with
| _ when n <= 0 -> ()
| Normal | Emacs -> ()
@@ -196,42 +195,83 @@ let rec run_steps n =
run_steps (n - 1)
end
-let help = function
+let help =
+ let open Printf in
+ let open Util in
+ let color c str = str |> c |> clear in
+ function
| ":t" | ":type" ->
- "(:t | :type) <function name> - Print the type of a function."
+ sprintf "(:t | :type) %s - Print the type of a function."
+ (color yellow "<function name>")
| ":q" | ":quit" ->
"(:q | :quit) - Exit the interpreter."
| ":i" | ":infer" ->
- "(:i | :infer) <expression> - Infer the type of an expression."
+ sprintf "(:i | :infer) %s - Infer the type of an expression."
+ (color yellow "<expression>")
| ":v" | ":verbose" ->
"(:v | :verbose) - Increase the verbosity level, or reset to zero at max verbosity."
+ | ":b" | ":bind" ->
+ sprintf "(:b | :bind) %s : %s - Declare a variable of a specific type"
+ (color yellow "<id>") (color yellow "<type>")
+ | ":let" ->
+ sprintf ":let %s = %s - Bind a variable to expression"
+ (color yellow "<id>") (color yellow "<expression>")
+ | ":def" ->
+ sprintf ":def %s - Evaluate a top-level definition"
+ (color yellow "<definition>")
+ | ":prove" ->
+ sprintf ":prove %s - Try to prove a constraint in the top-level environment"
+ (color yellow "<constraint>")
+ | ":assume" ->
+ sprintf ":assume %s - Add a constraint to the top-level environment"
+ (color yellow "<constraint>")
| ":commands" ->
":commands - List all available commands."
| ":help" ->
- ":help <command> - Get a description of <command>. Commands are prefixed with a colon, e.g. :help :type."
+ sprintf ":help %s - Get a description of <command>. Commands are prefixed with a colon, e.g. %s."
+ (color yellow "<command>") (color green ":help :type")
| ":elf" ->
- ":elf <file> - Load an ELF file."
+ sprintf ":elf %s - Load an ELF file."
+ (color yellow "<file>")
| ":bin" ->
":bin <address> <file> - Load a binary file at the given address."
| ":r" | ":run" ->
"(:r | :run) - Completely evaluate the currently evaluating expression."
| ":s" | ":step" ->
- "(:s | :step) <number> - Perform a number of evaluation steps."
+ sprintf "(:s | :step) %s - Perform a number of evaluation steps."
+ (color yellow "<number>")
| ":n" | ":normal" ->
"(:n | :normal) - Exit evaluation mode back to normal mode."
| ":clear" ->
- ":clear (on|off) - Set whether to clear the screen or not in evaluation mode."
+ sprintf ":clear %s - Set whether to clear the screen or not in evaluation mode."
+ (color yellow "(on|off)")
| ":l" | ":load" -> String.concat "\n"
- [ "(:l | :load) <files> - Load sail files and add their definitions to the interactive environment.";
+ [ sprintf "(:l | :load) %s - Load sail files and add their definitions to the interactive environment."
+ (color yellow "<files>");
"Files containing scattered definitions must be loaded together." ]
| ":u" | ":unload" ->
"(:u | :unload) - Unload all loaded files."
| ":output" ->
- ":output <file> - Redirect evaluating expression output to a file."
+ sprintf ":output %s - Redirect evaluating expression output to a file."
+ (color yellow "<file>")
| ":option" ->
- ":option string - Parse string as if it was an option passed on the command line. Try :option -help."
+ sprintf ":option %s - Parse string as if it was an option passed on the command line. e.g. :option -help."
+ (color yellow "<string>")
+ | ":rewrite" ->
+ sprintf ":rewrite %s - Apply a rewrite to the AST. %s shows all possible rewrites. See also %s"
+ (color yellow "<rewrite> <args>") (color green ":list_rewrites") (color green ":rewrites")
+ | ":rewrites" ->
+ sprintf ":rewrites %s - Apply all rewrites for a specific target, valid targets are lem, coq, ocaml, c, and interpreter"
+ (color yellow "<target>")
+ | ":compile" ->
+ sprintf ":compile %s - Compile AST to a specified target, valid targets are lem, coq, ocaml, c, and ir (intermediate representation)"
+ (color yellow "<target>")
+ | "" ->
+ sprintf "Type %s for a list of commands, and %s %s for information about a specific command"
+ (color green ":commands") (color green ":help") (color yellow "<command>")
| cmd ->
- "Either invalid command passed to help, or no documentation for " ^ cmd ^ ". Try :help :help."
+ sprintf "Either invalid command passed to help, or no documentation for %s. Try %s."
+ (color green cmd) (color green ":help :help")
let format_pos_emacs p1 p2 contents =
let open Lexing in
@@ -249,6 +289,17 @@ let rec emacs_error l contents =
| Parse_ast.Documented (_, l) -> emacs_error l contents
| Parse_ast.Generated l -> emacs_error l contents
+let slice_roots = ref IdSet.empty
+let slice_cuts = ref IdSet.empty
+
+let rec describe_rewrite =
+ let open Rewrites in
+ function
+ | String_rewriter rw -> "<string>" :: describe_rewrite (rw "")
+ | Bool_rewriter rw -> "<bool>" :: describe_rewrite (rw false)
+ | Literal_rewriter rw -> "(ocaml|lem|all)" :: describe_rewrite (rw (fun _ -> true))
+ | Basic_rewriter rw -> []
+
type session = {
id : string;
files : string list
@@ -337,12 +388,12 @@ let handle_input' input =
let exp = Type_check.infer_exp !Interactive.env exp in
pretty_sail stdout (doc_typ (Type_check.typ_of exp));
print_newline ()
- | ":canon" ->
- let typ = Initial_check.typ_of_string arg in
- print_endline (string_of_typ (Type_check.canonicalize !Interactive.env typ))
| ":prove" ->
let nc = Initial_check.constraint_of_string arg in
print_endline (string_of_bool (Type_check.prove __POS__ !Interactive.env nc))
+ | ":assume" ->
+ let nc = Initial_check.constraint_of_string arg in
+ Interactive.env := Type_check.Env.add_constraint nc !Interactive.env
| ":v" | ":verbose" ->
Type_check.opt_tc_debug := (!Type_check.opt_tc_debug + 1) mod 3;
print_endline ("Verbosity: " ^ string_of_int !Type_check.opt_tc_debug)
@@ -354,8 +405,8 @@ let handle_input' input =
else print_endline "Invalid argument for :clear, expected either :clear on or :clear off"
| ":commands" ->
let commands =
- [ "Universal commands - :(t)ype :(i)nfer :(q)uit :(v)erbose :clear :commands :help :output :option";
- "Normal mode commands - :elf :(l)oad :(u)nload";
+ [ "Universal commands - :(t)ype :(i)nfer :(q)uit :(v)erbose :prove :assume :clear :commands :help :output :option";
+ "Normal mode commands - :elf :(l)oad :(u)nload :let :def :(b)ind :rewrite :rewrites :list_rewrites :compile";
"Evaluation mode commands - :(r)un :(s)tep :(n)ormal";
"";
":(c)ommand can be called as either :c or :command." ]
@@ -365,7 +416,11 @@ let handle_input' input =
begin
try
let args = Str.split (Str.regexp " +") arg in
- Arg.parse_argv ~current:(ref 0) (Array.of_list ("sail" :: args)) Sail.options (fun _ -> ()) "";
+ begin match args with
+ | opt :: args ->
+ Arg.parse_argv ~current:(ref 0) (Array.of_list ["sail"; opt; String.concat " " args]) Sail.options (fun _ -> ()) "";
+ | [] -> print_endline "Must provide a valid option"
+ end
with
| Arg.Bad message | Arg.Help message -> print_endline message
end;
@@ -376,16 +431,6 @@ let handle_input' input =
interactive_state := initial_state !Interactive.ast !Interactive.env Value.primops
| ":pretty" ->
print_endline (Pretty_print_sail.to_string (Latex.defs !Interactive.ast))
- | ":compile" ->
- (*
- let open PPrint in
- let open C_backend in
- let ast = Process_file.rewrite_ast_c !Interactive.env !Interactive.ast in
- let ast, env = Specialize.(specialize typ_ord_specialization ast !Interactive.env) in
- let ctx = initial_ctx env in
- interactive_bytecode := bytecode_ast ctx (List.map flatten_cdef) ast
- *)
- ()
| ":ir" ->
print_endline arg;
let open Jib in
@@ -421,7 +466,6 @@ let handle_input' input =
| ":l" | ":load" ->
let files = Util.split_on_char ' ' arg in
let (_, ast, env) = load_files !Interactive.env files in
- let ast = Process_file.rewrite_ast_interpreter !Interactive.env ast in
Interactive.ast := append_ast !Interactive.ast ast;
interactive_state := initial_state !Interactive.ast !Interactive.env Value.primops;
Interactive.env := env;
@@ -444,6 +488,110 @@ let handle_input' input =
(* See initial_check.mli for an explanation of why we need this. *)
Initial_check.have_undefined_builtins := false;
Process_file.clear_symbols ()
+ | ":b" | ":bind" ->
+ let args = Str.split (Str.regexp " +") arg in
+ begin match args with
+ | v :: ":" :: args ->
+ let typ = Initial_check.typ_of_string (String.concat " " args) in
+ let _, env, _ = Type_check.bind_pat !Interactive.env (mk_pat (P_id (mk_id v))) typ in
+ Interactive.env := env
+ | _ -> print_endline "Invalid arguments for :bind"
+ end
+ | ":let" ->
+ let args = Str.split (Str.regexp " +") arg in
+ begin match args with
+ | v :: "=" :: args ->
+ let exp = Initial_check.exp_of_string (String.concat " " args) in
+ let ast, env = Type_check.check !Interactive.env (Defs [DEF_val (mk_letbind (mk_pat (P_id (mk_id v))) exp)]) in
+ Interactive.ast := append_ast !Interactive.ast ast;
+ Interactive.env := env;
+ interactive_state := initial_state !Interactive.ast !Interactive.env Value.primops;
+ | _ -> print_endline "Invalid arguments for :let"
+ end
+ | ":def" ->
+ let ast = Initial_check.ast_of_def_string_with (Process_file.preprocess_ast options) arg in
+ let ast, env = Type_check.check !Interactive.env ast in
+ Interactive.ast := append_ast !Interactive.ast ast;
+ Interactive.env := env;
+ interactive_state := initial_state !Interactive.ast !Interactive.env Value.primops;
+ | ":graph" ->
+ let format = if arg = "" then "svg" else arg in
+ let dotfile, out_chan = Filename.open_temp_file "sail_graph_" ".gz" in
+ let image = Filename.temp_file "sail_graph_" ("." ^ format) in
+ Slice.dot_of_ast out_chan !Interactive.ast;
+ close_out out_chan;
+ let _ = Unix.system (Printf.sprintf "dot -T%s %s -o %s" format dotfile image) in
+ let _ = Unix.system (Printf.sprintf "xdg-open %s" image) in
+ ()
+ | ":slice_roots" ->
+ let args = Str.split (Str.regexp " +") arg in
+ let ids = List.map mk_id args |> IdSet.of_list in
+ Specialize.add_initial_calls ids;
+ slice_roots := IdSet.union ids !slice_roots
+ | ":slice_cuts" ->
+ let args = Str.split (Str.regexp " +") arg in
+ let ids = List.map mk_id args |> IdSet.of_list in
+ slice_cuts := IdSet.union ids !slice_cuts
+ | ":slice" ->
+ let open Slice in
+ let module SliceNodeSet = Set.Make(Slice.Node) in
+ let module G = Graph.Make(Slice.Node) in
+ let g = Slice.graph_of_ast !Interactive.ast in
+ let roots = !slice_roots |> IdSet.elements |> List.map (fun id -> Function id) |> SliceNodeSet.of_list in
+ let cuts = !slice_cuts |> IdSet.elements |> List.map (fun id -> Function id) |> SliceNodeSet.of_list in
+ let g = G.prune roots cuts g in
+ Interactive.ast := Slice.filter_ast !slice_cuts g !Interactive.ast
+ | ":list_rewrites" ->
+ let print_rewrite (name, rw) =
+ print_endline (name ^ " " ^ Util.(String.concat " " (describe_rewrite rw) |> yellow |> clear))
+ in
+ List.sort (fun a b -> String.compare (fst a) (fst b)) Rewrites.all_rewrites
+ |> List.iter print_rewrite
+ | ":rewrite" ->
+ let open Rewrites in
+ let args = Str.split (Str.regexp " +") arg in
+ let rec parse_args rw args =
+ match rw, args with
+ | Basic_rewriter rw, [] -> rw
+ | Bool_rewriter rw, arg :: args -> parse_args (rw (bool_of_string arg)) args
+ | String_rewriter rw, arg :: args -> parse_args (rw arg) args
+ | Literal_rewriter rw, arg :: args ->
+ begin match arg with
+ | "ocaml" -> parse_args (rw rewrite_lit_ocaml) args
+ | "lem" -> parse_args (rw rewrite_lit_lem) args
+ | "all" -> parse_args (rw (fun _ -> true)) args
+ | _ -> failwith "target for literal rewrite must be one of ocaml/lem/all"
+ end
+ | _, _ -> failwith "Invalid arguments to rewrite"
+ in
+ begin match args with
+ | rw :: args ->
+ let rw = List.assoc rw Rewrites.all_rewrites in
+ let rw = parse_args rw args in
+ Interactive.ast := rw !Interactive.env !Interactive.ast;
+ | [] ->
+ failwith "Must provide the name of a rewrite, use :list_rewrites for a list of possible rewrites"
+ end
+ | ":rewrites" ->
+ Interactive.ast := Process_file.rewrite_ast_target arg !Interactive.env !Interactive.ast;
+ interactive_state := initial_state !Interactive.ast !Interactive.env Value.primops
+ | ":prover_regstate" ->
+ let env, ast = prover_regstate (Some arg) !Interactive.ast !Interactive.env in
+ Interactive.env := env;
+ Interactive.ast := ast;
+ interactive_state := initial_state !Interactive.ast !Interactive.env Value.primops
+ | ":recheck" ->
+ let ast, env = Type_check.check Type_check.initial_env !Interactive.ast in
+ Interactive.env := env;
+ Interactive.ast := ast;
+ interactive_state := initial_state !Interactive.ast !Interactive.env Value.primops;
+ vs_ids := val_spec_ids !Interactive.ast
+ | ":compile" ->
+ let out_name = match !opt_file_out with
+ | None -> "out.sail"
+ | Some f -> f ^ ".sail"
+ in
+ target (Some arg) out_name !Interactive.ast !Interactive.env
| _ -> unrecognised_command cmd
end
| Expression str ->
@@ -491,7 +639,7 @@ let handle_input' input =
begin match find_annot_ast sl !Interactive.ast with
| Some annot ->
let msg = String.escaped (string_of_typ (Type_check.typ_of_annot annot)) in
- begin match simp_loc (fst annot) with
+ begin match Reporting.simp_loc (fst annot) with
| Some (p1, p2) ->
print_endline ("(sail-highlight-region "
^ string_of_int (p1.pos_cnum + 1) ^ " " ^ string_of_int (p2.pos_cnum + 1)
@@ -561,6 +709,8 @@ let handle_input' input =
let handle_input input =
try handle_input' input with
+ | Failure str ->
+ print_endline ("Error: " ^ str)
| Type_check.Type_error (env, l, err) ->
print_endline (Type_error.string_of_type_error err)
| Reporting.Fatal_error err ->
@@ -569,26 +719,88 @@ let handle_input input =
print_endline (Printexc.to_string exn)
let () =
- (* Auto complete function names based on val specs *)
- LNoise.set_completion_callback
- begin
+ (* Auto complete function names based on val specs, directories if :load command, or rewrites if :rewrite command *)
+ LNoise.set_completion_callback (
fun line_so_far ln_completions ->
let line_so_far, last_id =
try
- let p = Str.search_backward (Str.regexp "[^a-zA-Z0-9_]") line_so_far (String.length line_so_far - 1) in
+ let p = Str.search_backward (Str.regexp "[^a-zA-Z0-9_/-]") line_so_far (String.length line_so_far - 1) in
Str.string_before line_so_far (p + 1), Str.string_after line_so_far (p + 1)
with
| Not_found -> "", line_so_far
| Invalid_argument _ -> line_so_far, ""
in
+ let n = try String.index line_so_far ' ' with Not_found -> String.length line_so_far in
+ let cmd = Str.string_before line_so_far n in
if last_id <> "" then
- IdSet.elements !vs_ids
- |> List.map string_of_id
- |> List.filter (fun id -> Str.string_match (Str.regexp_string last_id) id 0)
- |> List.map (fun completion -> line_so_far ^ completion)
- |> List.iter (LNoise.add_completion ln_completions)
+ begin match cmd with
+ | ":load" | ":l" ->
+ let dirname, basename = Filename.dirname last_id, Filename.basename last_id in
+ if Sys.file_exists last_id then
+ LNoise.add_completion ln_completions (line_so_far ^ last_id);
+ if (try Sys.is_directory dirname with Sys_error _ -> false) then
+ let contents = Sys.readdir (Filename.concat (Sys.getcwd ()) dirname) in
+ for i = 0 to Array.length contents - 1 do
+ if Str.string_match (Str.regexp_string basename) contents.(i) 0 then
+ let is_dir = (try Sys.is_directory (Filename.concat dirname contents.(i)) with Sys_error _ -> false) in
+ LNoise.add_completion ln_completions
+ (line_so_far ^ Filename.concat dirname contents.(i) ^ (if is_dir then Filename.dir_sep else ""))
+ done
+ | ":rewrite" ->
+ List.map fst Rewrites.all_rewrites
+ |> List.filter (fun opt -> Str.string_match (Str.regexp_string last_id) opt 0)
+ |> List.map (fun completion -> line_so_far ^ completion)
+ |> List.iter (LNoise.add_completion ln_completions)
+ | ":option" ->
+ List.map (fun (opt, _, _) -> opt) options
+ |> List.filter (fun opt -> Str.string_match (Str.regexp_string last_id) opt 0)
+ |> List.map (fun completion -> line_so_far ^ completion)
+ |> List.iter (LNoise.add_completion ln_completions)
+ | _ ->
+ IdSet.elements !vs_ids
+ |> List.map string_of_id
+ |> List.filter (fun id -> Str.string_match (Str.regexp_string last_id) id 0)
+ |> List.map (fun completion -> line_so_far ^ completion)
+ |> List.iter (LNoise.add_completion ln_completions)
+ end
else ()
- end;
+ );
+
+ LNoise.set_hints_callback (
+ fun line_so_far ->
+ let hint str = Some (" " ^ str, LNoise.Yellow, false) in
+ match String.trim line_so_far with
+ | _ when !Interactive.opt_emacs_mode -> None
+ | ":load" | ":l" -> hint "<sail file>"
+ | ":bind" | ":b" -> hint "<id> : <type>"
+ | ":infer" | ":i" -> hint "<expression>"
+ | ":type" | ":t" -> hint "<function id>"
+ | ":let" -> hint "<id> = <expression>"
+ | ":def" -> hint "<definition>"
+ | ":prove" -> hint "<constraint>"
+ | ":assume" -> hint "<constraint>"
+ | ":compile" -> hint "<target>"
+ | ":rewrites" -> hint "<target>"
+ | str ->
+ let args = Str.split (Str.regexp " +") str in
+ match args with
+ | [":rewrite"] -> hint "<rewrite>"
+ | ":rewrite" :: rw :: args ->
+ begin match List.assoc_opt rw Rewrites.all_rewrites with
+ | Some rw ->
+ let hints = describe_rewrite rw in
+ let hints = Util.drop (List.length args) hints in
+ (match hints with [] -> None | _ -> hint (String.concat " " hints))
+ | None -> None
+ end
+ | [":option"] -> hint "<flag>"
+ | [":option"; flag] ->
+ begin match List.find_opt (fun (opt, _, _) -> flag = opt) options with
+ | Some (_, _, help) -> hint (Str.global_replace (Str.regexp " +") " " help)
+ | None -> None
+ end
+ | _ -> None
+ );
(* Read the script file if it is set with the -is option, and excute them *)
begin
diff --git a/src/jib/anf.ml b/src/jib/anf.ml
index 025138d0..0a410249 100644
--- a/src/jib/anf.ml
+++ b/src/jib/anf.ml
@@ -91,6 +91,7 @@ and 'a apat_aux =
| AP_global of id * 'a
| AP_app of id * 'a apat * 'a
| AP_cons of 'a apat * 'a apat
+ | AP_as of 'a apat * id * 'a
| AP_nil of 'a
| AP_wild of 'a
@@ -113,6 +114,7 @@ let rec apat_bindings (AP_aux (apat_aux, _, _)) =
| AP_global (id, _) -> IdSet.empty
| AP_app (id, apat, _) -> apat_bindings apat
| AP_cons (apat1, apat2) -> IdSet.union (apat_bindings apat1) (apat_bindings apat2)
+ | AP_as (apat, id, _) -> IdSet.add id (apat_bindings apat)
| AP_nil _ -> IdSet.empty
| AP_wild _ -> IdSet.empty
@@ -132,6 +134,7 @@ let rec apat_types (AP_aux (apat_aux, _, _)) =
| AP_global (id, _) -> Bindings.empty
| AP_app (id, apat, _) -> apat_types apat
| AP_cons (apat1, apat2) -> (Bindings.merge merge) (apat_types apat1) (apat_types apat2)
+ | AP_as (apat, id, typ) -> Bindings.add id typ (apat_types apat)
| AP_nil _ -> Bindings.empty
| AP_wild _ -> Bindings.empty
@@ -143,6 +146,8 @@ let rec apat_rename from_id to_id (AP_aux (apat_aux, env, l)) =
| AP_global (id, typ) -> AP_global (id, typ)
| AP_app (ctor, apat, typ) -> AP_app (ctor, apat_rename from_id to_id apat, typ)
| AP_cons (apat1, apat2) -> AP_cons (apat_rename from_id to_id apat1, apat_rename from_id to_id apat2)
+ | AP_as (apat, id, typ) when Id.compare id from_id = 0 -> AP_as (apat, to_id, typ)
+ | AP_as (apat, id, typ) -> AP_as (apat, id, typ)
| AP_nil typ -> AP_nil typ
| AP_wild typ -> AP_wild typ
in
@@ -158,7 +163,7 @@ let rec aval_rename from_id to_id = function
| AV_list (avals, typ) -> AV_list (List.map (aval_rename from_id to_id) avals, typ)
| AV_vector (avals, typ) -> AV_vector (List.map (aval_rename from_id to_id) avals, typ)
| AV_record (avals, typ) -> AV_record (Bindings.map (aval_rename from_id to_id) avals, typ)
- | AV_C_fragment (fragment, typ, ctyp) -> AV_C_fragment (frag_rename from_id to_id fragment, typ, ctyp)
+ | AV_C_fragment (fragment, typ, ctyp) -> AV_C_fragment (frag_rename (name from_id) (name to_id) fragment, typ, ctyp)
let rec aexp_rename from_id to_id (AE_aux (aexp, env, l)) =
let recur = aexp_rename from_id to_id in
@@ -382,6 +387,7 @@ and pp_apat (AP_aux (apat_aux, _, _)) =
| AP_app (id, apat, typ) -> pp_annot typ (pp_id id ^^ parens (pp_apat apat))
| AP_nil _ -> string "[||]"
| AP_cons (hd_apat, tl_apat) -> pp_apat hd_apat ^^ string " :: " ^^ pp_apat tl_apat
+ | AP_as (apat, id, ctyp) -> pp_apat apat ^^ string " as " ^^ pp_id id
and pp_cases cases = surround 2 0 lbrace (separate_map (comma ^^ hardline) pp_case cases) rbrace
@@ -445,6 +451,7 @@ let rec anf_pat ?global:(global=false) (P_aux (p_aux, annot) as pat) =
| P_cons (hd_pat, tl_pat) -> mk_apat (AP_cons (anf_pat ~global:global hd_pat, anf_pat ~global:global tl_pat))
| P_list pats -> List.fold_right (fun pat apat -> mk_apat (AP_cons (anf_pat ~global:global pat, apat))) pats (mk_apat (AP_nil (typ_of_pat pat)))
| P_lit (L_aux (L_unit, _)) -> mk_apat (AP_wild (typ_of_pat pat))
+ | P_as (pat, id) -> mk_apat (AP_as (anf_pat ~global:global pat, id, typ_of_pat pat))
| _ ->
raise (Reporting.err_unreachable (fst annot) __POS__
("Could not convert pattern to ANF: " ^ string_of_pat pat))
@@ -456,6 +463,7 @@ let rec apat_globals (AP_aux (aux, _, _)) =
| AP_tup apats -> List.concat (List.map apat_globals apats)
| AP_app (_, apat, _) -> apat_globals apat
| AP_cons (hd_apat, tl_apat) -> apat_globals hd_apat @ apat_globals tl_apat
+ | AP_as (apat, _, _) -> apat_globals apat
let rec anf (E_aux (e_aux, ((l, _) as exp_annot)) as exp) =
let mk_aexp aexp = AE_aux (aexp, env_of_annot exp_annot, l) in
@@ -526,8 +534,8 @@ let rec anf (E_aux (e_aux, ((l, _) as exp_annot)) as exp) =
wrap (mk_aexp (AE_if (cond_val, then_aexp, else_aexp, typ_of exp)))
| E_app_infix (x, Id_aux (Id op, l), y) ->
- anf (E_aux (E_app (Id_aux (DeIid op, l), [x; y]), exp_annot))
- | E_app_infix (x, Id_aux (DeIid op, l), y) ->
+ anf (E_aux (E_app (Id_aux (Operator op, l), [x; y]), exp_annot))
+ | E_app_infix (x, Id_aux (Operator op, l), y) ->
anf (E_aux (E_app (Id_aux (Id op, l), [x; y]), exp_annot))
| E_vector exps ->
diff --git a/src/jib/anf.mli b/src/jib/anf.mli
index 79fb35ca..26b847e2 100644
--- a/src/jib/anf.mli
+++ b/src/jib/anf.mli
@@ -111,6 +111,7 @@ and 'a apat_aux =
| AP_global of id * 'a
| AP_app of id * 'a apat * 'a
| AP_cons of 'a apat * 'a apat
+ | AP_as of 'a apat * id * 'a
| AP_nil of 'a
| AP_wild of 'a
diff --git a/src/jib/c_backend.ml b/src/jib/c_backend.ml
index 846b619f..ee16e2e6 100644
--- a/src/jib/c_backend.ml
+++ b/src/jib/c_backend.ml
@@ -86,7 +86,6 @@ let optimize_primops = ref false
let optimize_hoist_allocations = ref false
let optimize_struct_updates = ref false
let optimize_alias = ref false
-let optimize_experimental = ref false
let c_debug str =
if !c_verbosity > 0 then prerr_endline (Lazy.force str) else ()
@@ -96,7 +95,7 @@ let c_error ?loc:(l=Parse_ast.Unknown) message =
let zencode_id = function
| Id_aux (Id str, l) -> Id_aux (Id (Util.zencode_string str), l)
- | Id_aux (DeIid str, l) -> Id_aux (Id (Util.zencode_string ("op " ^ str)), l)
+ | Id_aux (Operator str, l) -> Id_aux (Id (Util.zencode_string ("op " ^ str)), l)
(**************************************************************************)
(* 2. Converting sail types to C types *)
@@ -310,21 +309,21 @@ let rec c_aval ctx = function
(* We need to check that id's type hasn't changed due to flow typing *)
let _, ctyp' = Bindings.find id ctx.locals in
if ctyp_equal ctyp ctyp' then
- AV_C_fragment (F_id id, typ, ctyp)
+ AV_C_fragment (F_id (name id), typ, ctyp)
else
(* id's type changed due to flow
typing, so it's really still heap allocated! *)
v
with
(* Hack: Assuming global letbindings don't change from flow typing... *)
- Not_found -> AV_C_fragment (F_id id, typ, ctyp)
+ Not_found -> AV_C_fragment (F_id (name id), typ, ctyp)
end
else
v
| Register (_, _, typ) when is_stack_typ ctx typ ->
let ctyp = ctyp_of_typ ctx typ in
if is_stack_ctyp ctyp then
- AV_C_fragment (F_id id, typ, ctyp)
+ AV_C_fragment (F_id (name id), typ, ctyp)
else
v
| _ -> v
@@ -612,24 +611,6 @@ let analyze_primop ctx id args typ =
else
no_change
-let generate_cleanup instrs =
- let generate_cleanup' (I_aux (instr, _)) =
- match instr with
- | I_init (ctyp, id, cval) -> [(id, iclear ctyp id)]
- | I_decl (ctyp, id) -> [(id, iclear ctyp id)]
- | instr -> []
- in
- let is_clear ids = function
- | I_aux (I_clear (_, id), _) -> IdSet.add id ids
- | _ -> ids
- in
- let cleaned = List.fold_left is_clear IdSet.empty instrs in
- instrs
- |> List.map generate_cleanup'
- |> List.concat
- |> List.filter (fun (id, _) -> not (IdSet.mem id cleaned))
- |> List.map snd
-
(** Functions that have heap-allocated return types are implemented by
passing a pointer a location where the return value should be
stored. The ANF -> Sail IR pass for expressions simply outputs an
@@ -643,7 +624,7 @@ let fix_early_heap_return ret ret_ctyp instrs =
let end_function_label = label "end_function_" in
let is_return_recur (I_aux (instr, _)) =
match instr with
- | I_if _ | I_block _ | I_end | I_funcall _ | I_copy _ | I_undefined _ -> true
+ | I_if _ | I_block _ | I_end _ | I_funcall _ | I_copy _ | I_undefined _ -> true
| _ -> false
in
let rec rewrite_return instrs =
@@ -657,15 +638,15 @@ let fix_early_heap_return ret ret_ctyp instrs =
before
@ [iif cval (rewrite_return then_instrs) (rewrite_return else_instrs) ctyp]
@ rewrite_return after
- | before, I_aux (I_funcall (CL_return ctyp, extern, fid, args), aux) :: after ->
+ | before, I_aux (I_funcall (CL_id (Return _, ctyp), extern, fid, args), aux) :: after ->
before
@ [I_aux (I_funcall (CL_addr (CL_id (ret, CT_ref ctyp)), extern, fid, args), aux)]
@ rewrite_return after
- | before, I_aux (I_copy (CL_return ctyp, cval), aux) :: after ->
+ | before, I_aux (I_copy (CL_id (Return _, ctyp), cval), aux) :: after ->
before
@ [I_aux (I_copy (CL_addr (CL_id (ret, CT_ref ctyp)), cval), aux)]
@ rewrite_return after
- | before, I_aux ((I_end | I_undefined _), _) :: after ->
+ | before, I_aux ((I_end _ | I_undefined _), _) :: after ->
before
@ [igoto end_function_label]
@ rewrite_return after
@@ -680,7 +661,7 @@ let fix_early_heap_return ret ret_ctyp instrs =
let fix_early_stack_return ret ret_ctyp instrs =
let is_return_recur (I_aux (instr, _)) =
match instr with
- | I_if _ | I_block _ | I_end | I_funcall _ | I_copy _ -> true
+ | I_if _ | I_block _ | I_end _ | I_funcall _ | I_copy _ -> true
| _ -> false
in
let rec rewrite_return instrs =
@@ -694,15 +675,15 @@ let fix_early_stack_return ret ret_ctyp instrs =
before
@ [iif cval (rewrite_return then_instrs) (rewrite_return else_instrs) ctyp]
@ rewrite_return after
- | before, I_aux (I_funcall (CL_return ctyp, extern, fid, args), aux) :: after ->
+ | before, I_aux (I_funcall (CL_id (Return _, ctyp), extern, fid, args), aux) :: after ->
before
@ [I_aux (I_funcall (CL_id (ret, ctyp), extern, fid, args), aux)]
@ rewrite_return after
- | before, I_aux (I_copy (CL_return ctyp, cval), aux) :: after ->
+ | before, I_aux (I_copy (CL_id (Return _, ctyp), cval), aux) :: after ->
before
@ [I_aux (I_copy (CL_id (ret, ctyp), cval), aux)]
@ rewrite_return after
- | before, I_aux (I_end, _) :: after ->
+ | before, I_aux (I_end _, _) :: after ->
before
@ [ireturn (F_id ret, ret_ctyp)]
@ rewrite_return after
@@ -722,10 +703,10 @@ let rec insert_heap_returns ret_ctyps = function
| None ->
raise (Reporting.err_general (id_loc id) ("Cannot find return type for function " ^ string_of_id id))
| Some ret_ctyp when not (is_stack_ctyp ret_ctyp) ->
- CDEF_fundef (id, Some gs, args, fix_early_heap_return gs ret_ctyp body)
+ CDEF_fundef (id, Some gs, args, fix_early_heap_return (name gs) ret_ctyp body)
:: insert_heap_returns ret_ctyps cdefs
| Some ret_ctyp ->
- CDEF_fundef (id, None, args, fix_early_stack_return gs ret_ctyp (idecl ret_ctyp gs :: body))
+ CDEF_fundef (id, None, args, fix_early_stack_return (name gs) ret_ctyp (idecl ret_ctyp (name gs) :: body))
:: insert_heap_returns ret_ctyps cdefs
end
@@ -766,32 +747,6 @@ let add_local_labels instrs =
(* 5. Optimizations *)
(**************************************************************************)
-let rec instrs_rename from_id to_id =
- let rename id = if Id.compare id from_id = 0 then to_id else id in
- let crename = cval_rename from_id to_id in
- let irename instrs = instrs_rename from_id to_id instrs in
- let lrename = clexp_rename from_id to_id in
- function
- | (I_aux (I_decl (ctyp, new_id), _) :: _) as instrs when Id.compare from_id new_id = 0 -> instrs
- | I_aux (I_decl (ctyp, new_id), aux) :: instrs -> I_aux (I_decl (ctyp, new_id), aux) :: irename instrs
- | I_aux (I_reset (ctyp, id), aux) :: instrs -> I_aux (I_reset (ctyp, rename id), aux) :: irename instrs
- | I_aux (I_init (ctyp, id, cval), aux) :: instrs -> I_aux (I_init (ctyp, rename id, crename cval), aux) :: irename instrs
- | I_aux (I_reinit (ctyp, id, cval), aux) :: instrs -> I_aux (I_reinit (ctyp, rename id, crename cval), aux) :: irename instrs
- | I_aux (I_if (cval, then_instrs, else_instrs, ctyp), aux) :: instrs ->
- I_aux (I_if (crename cval, irename then_instrs, irename else_instrs, ctyp), aux) :: irename instrs
- | I_aux (I_jump (cval, label), aux) :: instrs -> I_aux (I_jump (crename cval, label), aux) :: irename instrs
- | I_aux (I_funcall (clexp, extern, id, cvals), aux) :: instrs ->
- I_aux (I_funcall (lrename clexp, extern, rename id, List.map crename cvals), aux) :: irename instrs
- | I_aux (I_copy (clexp, cval), aux) :: instrs -> I_aux (I_copy (lrename clexp, crename cval), aux) :: irename instrs
- | I_aux (I_alias (clexp, cval), aux) :: instrs -> I_aux (I_alias (lrename clexp, crename cval), aux) :: irename instrs
- | I_aux (I_clear (ctyp, id), aux) :: instrs -> I_aux (I_clear (ctyp, rename id), aux) :: irename instrs
- | I_aux (I_return cval, aux) :: instrs -> I_aux (I_return (crename cval), aux) :: irename instrs
- | I_aux (I_block block, aux) :: instrs -> I_aux (I_block (irename block), aux) :: irename instrs
- | I_aux (I_try_block block, aux) :: instrs -> I_aux (I_try_block (irename block), aux) :: irename instrs
- | I_aux (I_throw cval, aux) :: instrs -> I_aux (I_throw (crename cval), aux) :: irename instrs
- | (I_aux ((I_comment _ | I_raw _ | I_end | I_label _ | I_goto _ | I_match_failure | I_undefined _), _) as instr) :: instrs -> instr :: irename instrs
- | [] -> []
-
let hoist_ctyp = function
| CT_lint | CT_lbits _ | CT_struct _ -> true
| _ -> false
@@ -800,7 +755,7 @@ let hoist_counter = ref 0
let hoist_id () =
let id = mk_id ("gh#" ^ string_of_int !hoist_counter) in
incr hoist_counter;
- id
+ name id
let hoist_allocations recursive_functions = function
| CDEF_fundef (function_id, _, _, _) as cdef when IdSet.mem function_id recursive_functions ->
@@ -871,7 +826,7 @@ let rec specialize_variants ctx prior =
if ctyp_equal ctyp suprema then
[], (unpoly frag, ctyp), []
else
- let gs = gensym () in
+ let gs = ngensym () in
[idecl suprema gs;
icopy l (CL_id (gs, suprema)) (unpoly frag, ctyp)],
(F_id gs, suprema),
@@ -997,26 +952,26 @@ let remove_alias =
let rec scan ctyp id n instrs =
match n, !alias, instrs with
| 0, None, I_aux (I_copy (CL_id (id', ctyp'), (F_id a, ctyp'')), _) :: instrs
- when Id.compare id id' = 0 && ctyp_equal ctyp ctyp' && ctyp_equal ctyp' ctyp'' ->
+ when Name.compare id id' = 0 && ctyp_equal ctyp ctyp' && ctyp_equal ctyp' ctyp'' ->
alias := Some a;
scan ctyp id 1 instrs
| 1, Some a, I_aux (I_copy (CL_id (a', ctyp'), (F_id id', ctyp'')), _) :: instrs
- when Id.compare a a' = 0 && Id.compare id id' = 0 && ctyp_equal ctyp ctyp' && ctyp_equal ctyp' ctyp'' ->
+ when Name.compare a a' = 0 && Name.compare id id' = 0 && ctyp_equal ctyp ctyp' && ctyp_equal ctyp' ctyp'' ->
scan ctyp id 2 instrs
| 1, Some a, instr :: instrs ->
- if IdSet.mem a (instr_ids instr) then
+ if NameSet.mem a (instr_ids instr) then
None
else
scan ctyp id 1 instrs
| 2, Some a, I_aux (I_clear (ctyp', id'), _) :: instrs
- when Id.compare id id' = 0 && ctyp_equal ctyp ctyp' ->
+ when Name.compare id id' = 0 && ctyp_equal ctyp ctyp' ->
scan ctyp id 2 instrs
| 2, Some a, instr :: instrs ->
- if IdSet.mem id (instr_ids instr) then
+ if NameSet.mem id (instr_ids instr) then
None
else
scan ctyp id 2 instrs
@@ -1031,9 +986,9 @@ let remove_alias =
in
let remove_alias id alias = function
| I_aux (I_copy (CL_id (id', _), (F_id alias', _)), _)
- when Id.compare id id' = 0 && Id.compare alias alias' = 0 -> removed
+ when Name.compare id id' = 0 && Name.compare alias alias' = 0 -> removed
| I_aux (I_copy (CL_id (alias', _), (F_id id', _)), _)
- when Id.compare id id' = 0 && Id.compare alias alias' = 0 -> removed
+ when Name.compare id id' = 0 && Name.compare alias alias' = 0 -> removed
| I_aux (I_clear (_, id'), _) -> removed
| instr -> instr
in
@@ -1066,17 +1021,17 @@ let unique_names =
let unique_id () =
let id = mk_id ("u#" ^ string_of_int !unique_counter) in
incr unique_counter;
- id
+ name id
in
let rec opt seen = function
- | I_aux (I_decl (ctyp, id), aux) :: instrs when IdSet.mem id seen ->
+ | I_aux (I_decl (ctyp, id), aux) :: instrs when NameSet.mem id seen ->
let id' = unique_id () in
let instrs', seen = opt seen instrs in
I_aux (I_decl (ctyp, id'), aux) :: instrs_rename id id' instrs', seen
| I_aux (I_decl (ctyp, id), aux) :: instrs ->
- let instrs', seen = opt (IdSet.add id seen) instrs in
+ let instrs', seen = opt (NameSet.add id seen) instrs in
I_aux (I_decl (ctyp, id), aux) :: instrs', seen
| I_aux (I_block block, aux) :: instrs ->
@@ -1103,11 +1058,11 @@ let unique_names =
in
function
| CDEF_fundef (function_id, heap_return, args, body) ->
- [CDEF_fundef (function_id, heap_return, args, fst (opt IdSet.empty body))]
+ [CDEF_fundef (function_id, heap_return, args, fst (opt NameSet.empty body))]
| CDEF_reg_dec (id, ctyp, instrs) ->
- [CDEF_reg_dec (id, ctyp, fst (opt IdSet.empty instrs))]
+ [CDEF_reg_dec (id, ctyp, fst (opt NameSet.empty instrs))]
| CDEF_let (n, bindings, instrs) ->
- [CDEF_let (n, bindings, fst (opt IdSet.empty instrs))]
+ [CDEF_let (n, bindings, fst (opt NameSet.empty instrs))]
| cdef -> [cdef]
(** This optimization looks for patterns of the form
@@ -1135,26 +1090,26 @@ let combine_variables =
scan id 1 instrs
| 1, Some c, I_aux (I_copy (CL_id (id', ctyp'), (F_id c', ctyp'')), _) :: instrs
- when Id.compare c c' = 0 && Id.compare id id' = 0 && ctyp_equal ctyp ctyp' && ctyp_equal ctyp' ctyp'' ->
+ when Name.compare c c' = 0 && Name.compare id id' = 0 && ctyp_equal ctyp ctyp' && ctyp_equal ctyp' ctyp'' ->
scan id 2 instrs
(* Ignore seemingly early clears of x, as this can happen along exception paths *)
| 1, Some c, I_aux (I_clear (_, id'), _) :: instrs
- when Id.compare id id' = 0 ->
+ when Name.compare id id' = 0 ->
scan id 1 instrs
| 1, Some c, instr :: instrs ->
- if IdSet.mem id (instr_ids instr) then
+ if NameSet.mem id (instr_ids instr) then
None
else
scan id 1 instrs
| 2, Some c, I_aux (I_clear (ctyp', c'), _) :: instrs
- when Id.compare c c' = 0 && ctyp_equal ctyp ctyp' ->
+ when Name.compare c c' = 0 && ctyp_equal ctyp ctyp' ->
!combine
| 2, Some c, instr :: instrs ->
- if IdSet.mem c (instr_ids instr) then
+ if NameSet.mem c (instr_ids instr) then
None
else
scan id 2 instrs
@@ -1167,12 +1122,12 @@ let combine_variables =
scan id 0
in
let remove_variable id = function
- | I_aux (I_decl (_, id'), _) when Id.compare id id' = 0 -> removed
- | I_aux (I_clear (_, id'), _) when Id.compare id id' = 0 -> removed
+ | I_aux (I_decl (_, id'), _) when Name.compare id id' = 0 -> removed
+ | I_aux (I_clear (_, id'), _) when Name.compare id id' = 0 -> removed
| instr -> instr
in
let is_not_self_assignment = function
- | I_aux (I_copy (CL_id (id, _), (F_id id', _)), _) when Id.compare id id' = 0 -> false
+ | I_aux (I_copy (CL_id (id, _), (F_id id', _)), _) when Name.compare id id' = 0 -> false
| _ -> true
in
let rec opt = function
@@ -1200,63 +1155,6 @@ let combine_variables =
[CDEF_fundef (function_id, heap_return, args, opt body)]
| cdef -> [cdef]
-(** hoist_alias looks for patterns like
-
- recreate x; y = x; // no furthner mentions of x
-
- Provided x has a certain type, then we can make y an alias to x
- (denoted in the IR as 'alias y = x'). This only works if y also has
- a lifespan that also spans the entire function body. It's possible
- we may need to do a more thorough lifetime evaluation to get this
- to be 100% correct - so it's behind the -Oexperimental flag
- for now. Some benchmarking shows that this kind of optimization
- is very valuable however! *)
-let hoist_alias =
- (* Must return true for a subset of the types hoist_ctyp would return true for. *)
- let is_struct = function
- | CT_struct _ -> true
- | _ -> false
- in
- let pattern heap_return id ctyp instrs =
- let rec scan instrs =
- match instrs with
- (* The only thing that has a longer lifetime than id is the
- function return, so we want to make sure we avoid that
- case. *)
- | (I_aux (I_copy (clexp, (F_id id', ctyp')), aux) as instr) :: instrs
- when not (IdSet.mem heap_return (instr_writes instr)) && Id.compare id id' = 0
- && ctyp_equal (clexp_ctyp clexp) ctyp && ctyp_equal ctyp ctyp' ->
- if List.exists (IdSet.mem id) (List.map instr_ids instrs) then
- instr :: scan instrs
- else
- I_aux (I_alias (clexp, (F_id id', ctyp')), aux) :: instrs
-
- | instr :: instrs -> instr :: scan instrs
- | [] -> []
- in
- scan instrs
- in
- let optimize heap_return =
- let rec opt = function
- | (I_aux (I_reset (ctyp, id), _) as instr) :: instrs when not (is_stack_ctyp ctyp) && is_struct ctyp ->
- instr :: opt (pattern heap_return id ctyp instrs)
-
- | I_aux (I_block block, aux) :: instrs -> I_aux (I_block (opt block), aux) :: opt instrs
- | I_aux (I_try_block block, aux) :: instrs -> I_aux (I_try_block (opt block), aux) :: opt instrs
- | I_aux (I_if (cval, then_instrs, else_instrs, ctyp), aux) :: instrs ->
- I_aux (I_if (cval, opt then_instrs, opt else_instrs, ctyp), aux) :: opt instrs
-
- | instr :: instrs ->
- instr :: opt instrs
- | [] -> []
- in
- opt
- in
- function
- | CDEF_fundef (function_id, Some heap_return, args, body) ->
- [CDEF_fundef (function_id, Some heap_return, args, optimize heap_return body)]
- | cdef -> [cdef]
-
let concatMap f xs = List.concat (List.map f xs)
let optimize recursive_functions cdefs =
@@ -1267,13 +1165,13 @@ let optimize recursive_functions cdefs =
|> (if !optimize_alias then concatMap combine_variables else nothing)
(* We need the runtime to initialize hoisted allocations *)
|> (if !optimize_hoist_allocations && not !opt_no_rts then concatMap (hoist_allocations recursive_functions) else nothing)
- |> (if !optimize_hoist_allocations && !optimize_experimental then concatMap hoist_alias else nothing)
(**************************************************************************)
(* 6. Code generation *)
(**************************************************************************)
let sgen_id id = Util.zencode_string (string_of_id id)
+let sgen_name id = string_of_name id
let codegen_id id = string (sgen_id id)
let sgen_function_id id =
@@ -1286,9 +1184,9 @@ let rec sgen_ctyp = function
| CT_unit -> "unit"
| CT_bit -> "fbits"
| CT_bool -> "bool"
- | CT_fbits _ -> "fbits"
+ | CT_fbits _ -> "uint64_t"
| CT_sbits _ -> "sbits"
- | CT_fint _ -> "mach_int"
+ | CT_fint _ -> "int64_t"
| CT_lint -> "sail_int"
| CT_lbits _ -> "lbits"
| CT_tup _ as tup -> "struct " ^ Util.zencode_string ("tuple_" ^ string_of_ctyp tup)
@@ -1336,23 +1234,23 @@ let sgen_cval_param (frag, ctyp) =
let sgen_cval = function (frag, _) -> string_of_fragment frag
let rec sgen_clexp = function
- | CL_id (id, _) -> "&" ^ sgen_id id
+ | CL_id (Have_exception _, _) -> "have_exception"
+ | CL_id (Current_exception _, _) -> "current_exception"
+ | CL_id (Return _, _) -> assert false
+ | CL_id (Name (id, _), _) -> "&" ^ sgen_id id
| CL_field (clexp, field) -> "&((" ^ sgen_clexp clexp ^ ")->" ^ Util.zencode_string field ^ ")"
| CL_tuple (clexp, n) -> "&((" ^ sgen_clexp clexp ^ ")->ztup" ^ string_of_int n ^ ")"
| CL_addr clexp -> "(*(" ^ sgen_clexp clexp ^ "))"
- | CL_have_exception -> "have_exception"
- | CL_current_exception _ -> "current_exception"
- | CL_return _ -> assert false
| CL_void -> assert false
let rec sgen_clexp_pure = function
- | CL_id (id, _) -> sgen_id id
+ | CL_id (Have_exception _, _) -> "have_exception"
+ | CL_id (Current_exception _, _) -> "current_exception"
+ | CL_id (Return _, _) -> assert false
+ | CL_id (Name (id, _), _) -> sgen_id id
| CL_field (clexp, field) -> sgen_clexp_pure clexp ^ "." ^ Util.zencode_string field
| CL_tuple (clexp, n) -> sgen_clexp_pure clexp ^ ".ztup" ^ string_of_int n
| CL_addr clexp -> "(*(" ^ sgen_clexp_pure clexp ^ "))"
- | CL_have_exception -> "have_exception"
- | CL_current_exception _ -> "current_exception"
- | CL_return _ -> assert false
| CL_void -> assert false
(** Generate instructions to copy from a cval to a clexp. This will
@@ -1397,16 +1295,13 @@ let rec codegen_instr fid ctx (I_aux (instr, (_, l))) =
let open Printf in
match instr with
| I_decl (ctyp, id) when is_stack_ctyp ctyp ->
- ksprintf string " %s %s;" (sgen_ctyp ctyp) (sgen_id id)
+ ksprintf string " %s %s;" (sgen_ctyp ctyp) (sgen_name id)
| I_decl (ctyp, id) ->
- ksprintf string " %s %s;" (sgen_ctyp ctyp) (sgen_id id) ^^ hardline
- ^^ ksprintf string " CREATE(%s)(&%s);" (sgen_ctyp_name ctyp) (sgen_id id)
+ ksprintf string " %s %s;" (sgen_ctyp ctyp) (sgen_name id) ^^ hardline
+ ^^ ksprintf string " CREATE(%s)(&%s);" (sgen_ctyp_name ctyp) (sgen_name id)
| I_copy (clexp, cval) -> codegen_conversion l clexp cval
- | I_alias (clexp, cval) ->
- ksprintf string " %s = %s;" (sgen_clexp_pure clexp) (sgen_cval cval)
-
| I_jump (cval, label) ->
ksprintf string " if (%s) goto %s;" (sgen_cval cval) label
@@ -1488,9 +1383,7 @@ let rec codegen_instr fid ctx (I_aux (instr, (_, l))) =
| "undefined_vector", _ -> Printf.sprintf "UNDEFINED(vector_%s)" (sgen_ctyp_name ctyp)
| fname, _ -> fname
in
- if fname = "sail_assert" && !optimize_experimental then
- empty
- else if fname = "reg_deref" then
+ if fname = "reg_deref" then
if is_stack_ctyp ctyp then
string (Printf.sprintf " %s = *(%s);" (sgen_clexp_pure x) c_args)
else
@@ -1504,7 +1397,7 @@ let rec codegen_instr fid ctx (I_aux (instr, (_, l))) =
| I_clear (ctyp, id) when is_stack_ctyp ctyp ->
empty
| I_clear (ctyp, id) ->
- string (Printf.sprintf " KILL(%s)(&%s);" (sgen_ctyp_name ctyp) (sgen_id id))
+ string (Printf.sprintf " KILL(%s)(&%s);" (sgen_ctyp_name ctyp) (sgen_name id))
| I_init (ctyp, id, cval) ->
codegen_instr fid ctx (idecl ctyp id) ^^ hardline
@@ -1515,9 +1408,9 @@ let rec codegen_instr fid ctx (I_aux (instr, (_, l))) =
^^ codegen_conversion Parse_ast.Unknown (CL_id (id, ctyp)) cval
| I_reset (ctyp, id) when is_stack_ctyp ctyp ->
- string (Printf.sprintf " %s %s;" (sgen_ctyp ctyp) (sgen_id id))
+ string (Printf.sprintf " %s %s;" (sgen_ctyp ctyp) (sgen_name id))
| I_reset (ctyp, id) ->
- string (Printf.sprintf " RECREATE(%s)(&%s);" (sgen_ctyp_name ctyp) (sgen_id id))
+ string (Printf.sprintf " RECREATE(%s)(&%s);" (sgen_ctyp_name ctyp) (sgen_name id))
| I_return cval ->
string (Printf.sprintf " return %s;" (sgen_cval cval))
@@ -1536,24 +1429,24 @@ let rec codegen_instr fid ctx (I_aux (instr, (_, l))) =
| CT_bool -> "false", []
| CT_enum (_, ctor :: _) -> sgen_id ctor, []
| CT_tup ctyps when is_stack_ctyp ctyp ->
- let gs = gensym () in
+ let gs = ngensym () in
let fold (inits, prev) (n, ctyp) =
let init, prev' = codegen_exn_return ctyp in
Printf.sprintf ".ztup%d = %s" n init :: inits, prev @ prev'
in
let inits, prev = List.fold_left fold ([], []) (List.mapi (fun i x -> (i, x)) ctyps) in
- sgen_id gs,
- [Printf.sprintf "struct %s %s = { " (sgen_ctyp_name ctyp) (sgen_id gs)
+ sgen_name gs,
+ [Printf.sprintf "struct %s %s = { " (sgen_ctyp_name ctyp) (sgen_name gs)
^ Util.string_of_list ", " (fun x -> x) inits ^ " };"] @ prev
| CT_struct (id, ctors) when is_stack_ctyp ctyp ->
- let gs = gensym () in
+ let gs = ngensym () in
let fold (inits, prev) (id, ctyp) =
let init, prev' = codegen_exn_return ctyp in
Printf.sprintf ".%s = %s" (sgen_id id) init :: inits, prev @ prev'
in
let inits, prev = List.fold_left fold ([], []) ctors in
- sgen_id gs,
- [Printf.sprintf "struct %s %s = { " (sgen_ctyp_name ctyp) (sgen_id gs)
+ sgen_name gs,
+ [Printf.sprintf "struct %s %s = { " (sgen_ctyp_name ctyp) (sgen_name gs)
^ Util.string_of_list ", " (fun x -> x) inits ^ " };"] @ prev
| ctyp -> c_error ("Cannot create undefined value for type: " ^ string_of_ctyp ctyp)
in
@@ -1575,7 +1468,7 @@ let rec codegen_instr fid ctx (I_aux (instr, (_, l))) =
| I_raw str ->
string (" " ^ str)
- | I_end -> assert false
+ | I_end _ -> assert false
| I_match_failure ->
string (" sail_match_failure(\"" ^ String.escaped (string_of_id fid) ^ "\");")
@@ -2014,13 +1907,13 @@ let is_decl = function
let codegen_decl = function
| I_aux (I_decl (ctyp, id), _) ->
- string (Printf.sprintf "%s %s;" (sgen_ctyp ctyp) (sgen_id id))
+ string (Printf.sprintf "%s %s;" (sgen_ctyp ctyp) (sgen_name id))
| _ -> assert false
let codegen_alloc = function
| I_aux (I_decl (ctyp, id), _) when is_stack_ctyp ctyp -> empty
| I_aux (I_decl (ctyp, id), _) ->
- string (Printf.sprintf " CREATE(%s)(&%s);" (sgen_ctyp_name ctyp) (sgen_id id))
+ string (Printf.sprintf " CREATE(%s)(&%s);" (sgen_ctyp_name ctyp) (sgen_name id))
| _ -> assert false
let codegen_def' ctx = function
@@ -2101,10 +1994,10 @@ let codegen_def' ctx = function
| CDEF_let (number, bindings, instrs) ->
let instrs = add_local_labels instrs in
let setup =
- List.concat (List.map (fun (id, ctyp) -> [idecl ctyp id]) bindings)
+ List.concat (List.map (fun (id, ctyp) -> [idecl ctyp (name id)]) bindings)
in
let cleanup =
- List.concat (List.map (fun (id, ctyp) -> [iclear ctyp id]) bindings)
+ List.concat (List.map (fun (id, ctyp) -> [iclear ctyp (name id)]) bindings)
in
separate_map hardline (fun (id, ctyp) -> string (Printf.sprintf "%s %s;" (sgen_ctyp ctyp) (sgen_id id))) bindings
^^ hardline ^^ string (Printf.sprintf "static void create_letbind_%d(void) " number)
diff --git a/src/jib/c_backend.mli b/src/jib/c_backend.mli
index 7314eb5a..3e8c426b 100644
--- a/src/jib/c_backend.mli
+++ b/src/jib/c_backend.mli
@@ -100,7 +100,6 @@ val optimize_primops : bool ref
val optimize_hoist_allocations : bool ref
val optimize_struct_updates : bool ref
val optimize_alias : bool ref
-val optimize_experimental : bool ref
(** Convert a typ to a IR ctyp *)
val ctyp_of_typ : Jib_compile.ctx -> Ast.typ -> ctyp
diff --git a/src/jib/jib_compile.ml b/src/jib/jib_compile.ml
index 27f833d8..4a72ffff 100644
--- a/src/jib/jib_compile.ml
+++ b/src/jib/jib_compile.ml
@@ -61,6 +61,8 @@ let opt_debug_function = ref ""
let opt_debug_flow_graphs = ref false
let opt_memo_cache = ref false
+let ngensym () = name (gensym ())
+
(**************************************************************************)
(* 4. Conversion to low-level AST *)
(**************************************************************************)
@@ -183,26 +185,26 @@ let rec compile_aval l ctx = function
begin
try
let _, ctyp = Bindings.find id ctx.locals in
- [], (F_id id, ctyp), []
+ [], (F_id (name id), ctyp), []
with
| Not_found ->
- [], (F_id id, ctyp_of_typ ctx (lvar_typ typ)), []
+ [], (F_id (name id), ctyp_of_typ ctx (lvar_typ typ)), []
end
| AV_ref (id, typ) ->
- [], (F_ref id, CT_ref (ctyp_of_typ ctx (lvar_typ typ))), []
+ [], (F_ref (name id), CT_ref (ctyp_of_typ ctx (lvar_typ typ))), []
| AV_lit (L_aux (L_string str, _), typ) ->
[], (F_lit (V_string (String.escaped str)), ctyp_of_typ ctx typ), []
| AV_lit (L_aux (L_num n, _), typ) when Big_int.less_equal (min_int 64) n && Big_int.less_equal n (max_int 64) ->
- let gs = gensym () in
+ let gs = ngensym () in
[iinit CT_lint gs (F_lit (V_int n), CT_fint 64)],
(F_id gs, CT_lint),
[iclear CT_lint gs]
| AV_lit (L_aux (L_num n, _), typ) ->
- let gs = gensym () in
+ let gs = ngensym () in
[iinit CT_lint gs (F_lit (V_string (Big_int.to_string n)), CT_string)],
(F_id gs, CT_lint),
[iclear CT_lint gs]
@@ -214,7 +216,7 @@ let rec compile_aval l ctx = function
| AV_lit (L_aux (L_false, _), _) -> [], (F_lit (V_bool false), CT_bool), []
| AV_lit (L_aux (L_real str, _), _) ->
- let gs = gensym () in
+ let gs = ngensym () in
[iinit CT_real gs (F_lit (V_string str), CT_string)],
(F_id gs, CT_real),
[iclear CT_real gs]
@@ -230,7 +232,7 @@ let rec compile_aval l ctx = function
let setup = List.concat (List.map (fun (setup, _, _) -> setup) elements) in
let cleanup = List.concat (List.rev (List.map (fun (_, _, cleanup) -> cleanup) elements)) in
let tup_ctyp = CT_tup (List.map cval_ctyp cvals) in
- let gs = gensym () in
+ let gs = ngensym () in
setup
@ [idecl tup_ctyp gs]
@ List.mapi (fun n cval -> icopy l (CL_tuple (CL_id (gs, tup_ctyp), n)) cval) cvals,
@@ -240,7 +242,7 @@ let rec compile_aval l ctx = function
| AV_record (fields, typ) ->
let ctyp = ctyp_of_typ ctx typ in
- let gs = gensym () in
+ let gs = ngensym () in
let compile_fields (id, aval) =
let field_setup, cval, field_cleanup = compile_aval l ctx aval in
field_setup
@@ -278,7 +280,7 @@ let rec compile_aval l ctx = function
let bitstring avals = F_lit (V_bits (List.map value_of_aval_bit avals)) in
let first_chunk = bitstring (Util.take (len mod 64) avals) in
let chunks = Util.drop (len mod 64) avals |> chunkify 64 |> List.map bitstring in
- let gs = gensym () in
+ let gs = ngensym () in
[iinit (CT_lbits true) gs (first_chunk, CT_fbits (len mod 64, true))]
@ List.map (fun chunk -> ifuncall (CL_id (gs, CT_lbits true))
(mk_id "append_64")
@@ -295,7 +297,7 @@ let rec compile_aval l ctx = function
| Ord_aux (Ord_dec, _) -> true
| Ord_aux (Ord_var _, _) -> raise (Reporting.err_general l "Polymorphic vector direction found")
in
- let gs = gensym () in
+ let gs = ngensym () in
let ctyp = CT_fbits (len, direction) in
let mask i = V_bits (Util.list_init (63 - i) (fun _ -> Sail2_values.B0) @ [Sail2_values.B1] @ Util.list_init i (fun _ -> Sail2_values.B0)) in
let aval_mask i aval =
@@ -323,7 +325,7 @@ let rec compile_aval l ctx = function
| Ord_aux (Ord_var _, _) -> raise (Reporting.err_general l "Polymorphic vector direction found")
in
let vector_ctyp = CT_vector (direction, ctyp_of_typ ctx typ) in
- let gs = gensym () in
+ let gs = ngensym () in
let aval_set i aval =
let setup, cval, cleanup = compile_aval l ctx aval in
setup
@@ -346,7 +348,7 @@ let rec compile_aval l ctx = function
| Typ_app (id, [A_aux (A_typ typ, _)]) when string_of_id id = "list" -> ctyp_of_typ ctx typ
| _ -> raise (Reporting.err_general l "Invalid list type")
in
- let gs = gensym () in
+ let gs = ngensym () in
let mk_cons aval =
let setup, cval, cleanup = compile_aval l ctx aval in
setup @ [ifuncall (CL_id (gs, CT_list ctyp)) (mk_id ("cons#" ^ string_of_ctyp ctyp)) [cval; (F_id gs, CT_list ctyp)]] @ cleanup
@@ -384,7 +386,7 @@ let compile_funcall l ctx id args typ =
else if ctyp_equal ctyp have_ctyp then
cval
else
- let gs = gensym () in
+ let gs = ngensym () in
setup := iinit ctyp gs cval :: !setup;
cleanup := iclear ctyp gs :: !cleanup;
(F_id gs, ctyp)
@@ -399,7 +401,7 @@ let compile_funcall l ctx id args typ =
if ctyp_equal (clexp_ctyp clexp) ret_ctyp then
ifuncall clexp id setup_args
else
- let gs = gensym () in
+ let gs = ngensym () in
iblock [idecl ret_ctyp gs;
ifuncall (CL_id (gs, ret_ctyp)) id setup_args;
icopy l clexp (F_id gs, ret_ctyp);
@@ -414,30 +416,37 @@ let rec apat_ctyp ctx (AP_aux (apat, _, _)) =
| AP_cons (apat, _) -> CT_list (apat_ctyp ctx apat)
| AP_wild typ | AP_nil typ | AP_id (_, typ) -> ctyp_of_typ ctx typ
| AP_app (_, _, typ) -> ctyp_of_typ ctx typ
+ | AP_as (_, _, typ) -> ctyp_of_typ ctx typ
let rec compile_match ctx (AP_aux (apat_aux, env, l)) cval case_label =
let ctx = { ctx with local_env = env } in
match apat_aux, cval with
| AP_id (pid, _), (frag, ctyp) when Env.is_union_constructor pid ctx.tc_env ->
- [ijump (F_op (F_field (frag, "kind"), "!=", F_lit (V_ctor_kind (string_of_id pid))), CT_bool) case_label],
+ [ijump (F_ctor_kind (frag, pid, [], ctyp), CT_bool) case_label],
[],
ctx
| AP_global (pid, typ), (frag, ctyp) ->
let global_ctyp = ctyp_of_typ ctx typ in
- [icopy l (CL_id (pid, global_ctyp)) cval], [], ctx
+ [icopy l (CL_id (name pid, global_ctyp)) cval], [], ctx
| AP_id (pid, _), (frag, ctyp) when is_ct_enum ctyp ->
begin match Env.lookup_id pid ctx.tc_env with
- | Unbound -> [idecl ctyp pid; icopy l (CL_id (pid, ctyp)) (frag, ctyp)], [], ctx
- | _ -> [ijump (F_op (F_id pid, "!=", frag), CT_bool) case_label], [], ctx
+ | Unbound -> [idecl ctyp (name pid); icopy l (CL_id (name pid, ctyp)) (frag, ctyp)], [], ctx
+ | _ -> [ijump (F_op (F_id (name pid), "!=", frag), CT_bool) case_label], [], ctx
end
| AP_id (pid, typ), _ ->
let ctyp = cval_ctyp cval in
let id_ctyp = ctyp_of_typ ctx typ in
let ctx = { ctx with locals = Bindings.add pid (Immutable, id_ctyp) ctx.locals } in
- [idecl id_ctyp pid; icopy l (CL_id (pid, id_ctyp)) cval], [iclear id_ctyp pid], ctx
+ [idecl id_ctyp (name pid); icopy l (CL_id (name pid, id_ctyp)) cval], [iclear id_ctyp (name pid)], ctx
+
+ | AP_as (apat, id, typ), _ ->
+ let id_ctyp = ctyp_of_typ ctx typ in
+ let instrs, cleanup, ctx = compile_match ctx apat cval case_label in
+ let ctx = { ctx with locals = Bindings.add id (Immutable, id_ctyp) ctx.locals } in
+ instrs @ [idecl id_ctyp (name id); icopy l (CL_id (name id, id_ctyp)) cval], iclear id_ctyp (name id) :: cleanup, ctx
| AP_tup apats, (frag, ctyp) ->
begin
@@ -456,25 +465,21 @@ let rec compile_match ctx (AP_aux (apat_aux, env, l)) cval case_label =
| AP_app (ctor, apat, variant_typ), (frag, ctyp) ->
begin match ctyp with
| CT_variant (_, ctors) ->
- let ctor_c_id = string_of_id ctor in
let ctor_ctyp = Bindings.find ctor (ctor_bindings ctors) in
+ let pat_ctyp = apat_ctyp ctx apat in
(* These should really be the same, something has gone wrong if they are not. *)
if ctyp_equal ctor_ctyp (ctyp_of_typ ctx variant_typ) then
raise (Reporting.err_general l (Printf.sprintf "%s is not the same type as %s" (string_of_ctyp ctor_ctyp) (string_of_ctyp (ctyp_of_typ ctx variant_typ))))
else ();
- let ctor_c_id, ctor_ctyp =
+ let unifiers, ctor_ctyp =
if is_polymorphic ctor_ctyp then
- let unification = List.map ctyp_suprema (ctyp_unify ctor_ctyp (apat_ctyp ctx apat)) in
- (if List.length unification > 0 then
- ctor_c_id ^ "_" ^ Util.string_of_list "_" (fun ctyp -> Util.zencode_string (string_of_ctyp ctyp)) unification
- else
- ctor_c_id),
- ctyp_suprema (apat_ctyp ctx apat)
+ let unifiers = List.map ctyp_suprema (ctyp_unify ctor_ctyp pat_ctyp) in
+ unifiers, ctyp_suprema (apat_ctyp ctx apat)
else
- ctor_c_id, ctor_ctyp
+ [], ctor_ctyp
in
- let instrs, cleanup, ctx = compile_match ctx apat ((F_field (frag, Util.zencode_string ctor_c_id), ctor_ctyp)) case_label in
- [ijump (F_op (F_field (frag, "kind"), "!=", F_lit (V_ctor_kind ctor_c_id)), CT_bool) case_label]
+ let instrs, cleanup, ctx = compile_match ctx apat (F_ctor_unwrap (ctor, unifiers, frag), ctor_ctyp) case_label in
+ [ijump (F_ctor_kind (frag, ctor, unifiers, pat_ctyp), CT_bool) case_label]
@ instrs,
cleanup,
ctx
@@ -507,7 +512,9 @@ let rec compile_aexp ctx (AE_aux (aexp_aux, env, l)) =
let binding_ctyp = ctyp_of_typ { ctx with local_env = body_env } binding_typ in
let setup, call, cleanup = compile_aexp ctx binding in
let letb_setup, letb_cleanup =
- [idecl binding_ctyp id; iblock (setup @ [call (CL_id (id, binding_ctyp))] @ cleanup)], [iclear binding_ctyp id]
+ [idecl binding_ctyp (name id);
+ iblock (setup @ [call (CL_id (name id, binding_ctyp))] @ cleanup)],
+ [iclear binding_ctyp (name id)]
in
let ctx = { ctx with locals = Bindings.add id (mut, binding_ctyp) ctx.locals } in
let setup, call, cleanup = compile_aexp ctx body in
@@ -524,7 +531,7 @@ let rec compile_aexp ctx (AE_aux (aexp_aux, env, l)) =
| AE_case (aval, cases, typ) ->
let ctyp = ctyp_of_typ ctx typ in
let aval_setup, cval, aval_cleanup = compile_aval l ctx aval in
- let case_return_id = gensym () in
+ let case_return_id = ngensym () in
let finish_match_label = label "finish_match_" in
let compile_case (apat, guard, body) =
let trivial_guard = match guard with
@@ -536,13 +543,12 @@ let rec compile_aexp ctx (AE_aux (aexp_aux, env, l)) =
let destructure, destructure_cleanup, ctx = compile_match ctx apat cval case_label in
let guard_setup, guard_call, guard_cleanup = compile_aexp ctx guard in
let body_setup, body_call, body_cleanup = compile_aexp ctx body in
- let gs = gensym () in
+ let gs = ngensym () in
let case_instrs =
- destructure @ [icomment "end destructuring"]
+ destructure
@ (if not trivial_guard then
guard_setup @ [idecl CT_bool gs; guard_call (CL_id (gs, CT_bool))] @ guard_cleanup
@ [iif (F_unary ("!", F_id gs), CT_bool) (destructure_cleanup @ [igoto case_label]) [] CT_unit]
- @ [icomment "end guard"]
else [])
@ body_setup @ [body_call (CL_id (case_return_id, ctyp))] @ body_cleanup @ destructure_cleanup
@ [igoto finish_match_label]
@@ -552,21 +558,19 @@ let rec compile_aexp ctx (AE_aux (aexp_aux, env, l)) =
else
[iblock case_instrs; ilabel case_label]
in
- [icomment "begin match"]
- @ aval_setup @ [idecl ctyp case_return_id]
+ aval_setup @ [idecl ctyp case_return_id]
@ List.concat (List.map compile_case cases)
@ [imatch_failure ()]
@ [ilabel finish_match_label],
(fun clexp -> icopy l clexp (F_id case_return_id, ctyp)),
[iclear ctyp case_return_id]
@ aval_cleanup
- @ [icomment "end match"]
(* Compile try statement *)
| AE_try (aexp, cases, typ) ->
let ctyp = ctyp_of_typ ctx typ in
let aexp_setup, aexp_call, aexp_cleanup = compile_aexp ctx aexp in
- let try_return_id = gensym () in
+ let try_return_id = ngensym () in
let handled_exception_label = label "handled_exception_" in
let fallthrough_label = label "fallthrough_exception_" in
let compile_case (apat, guard, body) =
@@ -576,11 +580,11 @@ let rec compile_aexp ctx (AE_aux (aexp_aux, env, l)) =
| _ -> false
in
let try_label = label "try_" in
- let exn_cval = (F_current_exception, ctyp_of_typ ctx (mk_typ (Typ_id (mk_id "exception")))) in
+ let exn_cval = (F_id current_exception, ctyp_of_typ ctx (mk_typ (Typ_id (mk_id "exception")))) in
let destructure, destructure_cleanup, ctx = compile_match ctx apat exn_cval try_label in
let guard_setup, guard_call, guard_cleanup = compile_aexp ctx guard in
let body_setup, body_call, body_cleanup = compile_aexp ctx body in
- let gs = gensym () in
+ let gs = ngensym () in
let case_instrs =
destructure @ [icomment "end destructuring"]
@ (if not trivial_guard then
@@ -596,11 +600,11 @@ let rec compile_aexp ctx (AE_aux (aexp_aux, env, l)) =
assert (ctyp_equal ctyp (ctyp_of_typ ctx typ));
[idecl ctyp try_return_id;
itry_block (aexp_setup @ [aexp_call (CL_id (try_return_id, ctyp))] @ aexp_cleanup);
- ijump (F_unary ("!", F_have_exception), CT_bool) handled_exception_label]
+ ijump (F_unary ("!", F_id have_exception), CT_bool) handled_exception_label]
@ List.concat (List.map compile_case cases)
@ [igoto fallthrough_label;
ilabel handled_exception_label;
- icopy l CL_have_exception (F_lit (V_bool false), CT_bool);
+ icopy l (CL_id (have_exception, CT_bool)) (F_lit (V_bool false), CT_bool);
ilabel fallthrough_label],
(fun clexp -> icopy l clexp (F_id try_return_id, ctyp)),
[]
@@ -631,7 +635,7 @@ let rec compile_aexp ctx (AE_aux (aexp_aux, env, l)) =
| CT_struct (_, ctors) -> List.fold_left (fun m (k, v) -> Bindings.add k v m) Bindings.empty ctors
| _ -> raise (Reporting.err_general l "Cannot perform record update for non-record type")
in
- let gs = gensym () in
+ let gs = ngensym () in
let compile_fields (id, aval) =
let field_setup, cval, field_cleanup = compile_aval l ctx aval in
field_setup
@@ -650,7 +654,7 @@ let rec compile_aexp ctx (AE_aux (aexp_aux, env, l)) =
| AE_short_circuit (SC_and, aval, aexp) ->
let left_setup, cval, left_cleanup = compile_aval l ctx aval in
let right_setup, call, right_cleanup = compile_aexp ctx aexp in
- let gs = gensym () in
+ let gs = ngensym () in
left_setup
@ [ idecl CT_bool gs;
iif cval
@@ -663,7 +667,7 @@ let rec compile_aexp ctx (AE_aux (aexp_aux, env, l)) =
| AE_short_circuit (SC_or, aval, aexp) ->
let left_setup, cval, left_cleanup = compile_aval l ctx aval in
let right_setup, call, right_cleanup = compile_aexp ctx aexp in
- let gs = gensym () in
+ let gs = ngensym () in
left_setup
@ [ idecl CT_bool gs;
iif cval
@@ -681,7 +685,7 @@ let rec compile_aexp ctx (AE_aux (aexp_aux, env, l)) =
let compile_fields (field_id, aval) =
let field_setup, cval, field_cleanup = compile_aval l ctx aval in
field_setup
- @ [icopy l (CL_field (CL_id (id, ctyp_of_typ ctx typ), string_of_id field_id)) cval]
+ @ [icopy l (CL_field (CL_id (name id, ctyp_of_typ ctx typ), string_of_id field_id)) cval]
@ field_cleanup
in
List.concat (List.map compile_fields (Bindings.bindings fields)),
@@ -695,7 +699,7 @@ let rec compile_aexp ctx (AE_aux (aexp_aux, env, l)) =
| None -> ctyp_of_typ ctx assign_typ
in
let setup, call, cleanup = compile_aexp ctx aexp in
- setup @ [call (CL_id (id, assign_ctyp))], (fun clexp -> icopy l clexp unit_fragment), cleanup
+ setup @ [call (CL_id (name id, assign_ctyp))], (fun clexp -> icopy l clexp unit_fragment), cleanup
| AE_block (aexps, aexp, _) ->
let block = compile_block ctx aexps in
@@ -707,8 +711,8 @@ let rec compile_aexp ctx (AE_aux (aexp_aux, env, l)) =
let loop_end_label = label "wend_" in
let cond_setup, cond_call, cond_cleanup = compile_aexp ctx cond in
let body_setup, body_call, body_cleanup = compile_aexp ctx body in
- let gs = gensym () in
- let unit_gs = gensym () in
+ let gs = ngensym () in
+ let unit_gs = ngensym () in
let loop_test = (F_unary ("!", F_id gs), CT_bool) in
[idecl CT_bool gs; idecl CT_unit unit_gs]
@ [ilabel loop_start_label]
@@ -729,8 +733,8 @@ let rec compile_aexp ctx (AE_aux (aexp_aux, env, l)) =
let loop_end_label = label "until_" in
let cond_setup, cond_call, cond_cleanup = compile_aexp ctx cond in
let body_setup, body_call, body_cleanup = compile_aexp ctx body in
- let gs = gensym () in
- let unit_gs = gensym () in
+ let gs = ngensym () in
+ let unit_gs = ngensym () in
let loop_test = (F_id gs, CT_bool) in
[idecl CT_bool gs; idecl CT_unit unit_gs]
@ [ilabel loop_start_label]
@@ -759,7 +763,7 @@ let rec compile_aexp ctx (AE_aux (aexp_aux, env, l)) =
if ctyp_equal fn_return_ctyp (cval_ctyp cval) then
[ireturn cval]
else
- let gs = gensym () in
+ let gs = ngensym () in
[idecl fn_return_ctyp gs;
icopy l (CL_id (gs, fn_return_ctyp)) cval;
ireturn (F_id gs, fn_return_ctyp)]
@@ -775,7 +779,8 @@ let rec compile_aexp ctx (AE_aux (aexp_aux, env, l)) =
(fun clexp -> icomment "unreachable after throw"),
[]
- | AE_field (aval, id, _) ->
+ | AE_field (aval, id, typ) ->
+ let aval_ctyp = ctyp_of_typ ctx typ in
let setup, cval, cleanup = compile_aval l ctx aval in
let ctyp = match cval_ctyp cval with
| CT_struct (struct_id, fields) ->
@@ -788,8 +793,19 @@ let rec compile_aexp ctx (AE_aux (aexp_aux, env, l)) =
| _ ->
raise (Reporting.err_unreachable l __POS__ "Field access on non-struct type in ANF representation!")
in
+ let unifiers, ctyp =
+ if is_polymorphic ctyp then
+ let unifiers = List.map ctyp_suprema (ctyp_unify ctyp aval_ctyp) in
+ unifiers, ctyp_suprema aval_ctyp
+ else
+ [], ctyp
+ in
+ let field_str = match unifiers with
+ | [] -> Util.zencode_string (string_of_id id)
+ | _ -> Util.zencode_string (string_of_id id ^ "_" ^ Util.string_of_list "_" string_of_ctyp unifiers)
+ in
setup,
- (fun clexp -> icopy l clexp (F_field (fst cval, Util.zencode_string (string_of_id id)), ctyp)),
+ (fun clexp -> icopy l clexp (F_field (fst cval, field_str), ctyp)),
cleanup
| AE_for (loop_var, loop_from, loop_to, loop_step, Ord_aux (ord, _), body) ->
@@ -804,11 +820,11 @@ let rec compile_aexp ctx (AE_aux (aexp_aux, env, l)) =
(* Loop variables *)
let from_setup, from_call, from_cleanup = compile_aexp ctx loop_from in
- let from_gs = gensym () in
+ let from_gs = ngensym () in
let to_setup, to_call, to_cleanup = compile_aexp ctx loop_to in
- let to_gs = gensym () in
+ let to_gs = ngensym () in
let step_setup, step_call, step_cleanup = compile_aexp ctx loop_step in
- let step_gs = gensym () in
+ let step_gs = ngensym () in
let variable_init gs setup call cleanup =
[idecl (CT_fint 64) gs;
iblock (setup @ [call (CL_id (gs, CT_fint 64))] @ cleanup)]
@@ -817,7 +833,9 @@ let rec compile_aexp ctx (AE_aux (aexp_aux, env, l)) =
let loop_start_label = label "for_start_" in
let loop_end_label = label "for_end_" in
let body_setup, body_call, body_cleanup = compile_aexp ctx body in
- let body_gs = gensym () in
+ let body_gs = ngensym () in
+
+ let loop_var = name loop_var in
variable_init from_gs from_setup from_call from_cleanup
@ variable_init to_gs to_setup to_call to_cleanup
@@ -842,7 +860,7 @@ and compile_block ctx = function
| exp :: exps ->
let setup, call, cleanup = compile_aexp ctx exp in
let rest = compile_block ctx exps in
- let gs = gensym () in
+ let gs = ngensym () in
iblock (setup @ [idecl CT_unit gs; call (CL_id (gs, CT_unit))] @ cleanup) :: rest
(** Compile a sail type definition into a IR one. Most of the
@@ -892,14 +910,14 @@ let generate_cleanup instrs =
| instr -> []
in
let is_clear ids = function
- | I_aux (I_clear (_, id), _) -> IdSet.add id ids
+ | I_aux (I_clear (_, id), _) -> NameSet.add id ids
| _ -> ids
in
- let cleaned = List.fold_left is_clear IdSet.empty instrs in
+ let cleaned = List.fold_left is_clear NameSet.empty instrs in
instrs
|> List.map generate_cleanup'
|> List.concat
- |> List.filter (fun (id, _) -> not (IdSet.mem id cleaned))
+ |> List.filter (fun (id, _) -> not (NameSet.mem id cleaned))
|> List.map snd
let fix_exception_block ?return:(return=None) ctx instrs =
@@ -927,8 +945,8 @@ let fix_exception_block ?return:(return=None) ctx instrs =
@ rewrite_exception historic after
| before, I_aux (I_throw cval, (_, l)) :: after ->
before
- @ [icopy l (CL_current_exception (cval_ctyp cval)) cval;
- icopy l CL_have_exception (F_lit (V_bool true), CT_bool)]
+ @ [icopy l (CL_id (current_exception, cval_ctyp cval)) cval;
+ icopy l (CL_id (have_exception, CT_bool)) (F_lit (V_bool true), CT_bool)]
@ generate_cleanup (historic @ before)
@ [igoto end_block_label]
@ rewrite_exception (historic @ before) after
@@ -941,7 +959,7 @@ let fix_exception_block ?return:(return=None) ctx instrs =
if has_effect effects BE_escape then
before
@ [funcall;
- iif (F_have_exception, CT_bool) (generate_cleanup (historic @ before) @ [igoto end_block_label]) [] CT_unit]
+ iif (F_id have_exception, CT_bool) (generate_cleanup (historic @ before) @ [igoto end_block_label]) [] CT_unit]
@ rewrite_exception (historic @ before) after
else
before @ funcall :: rewrite_exception (historic @ before) after
@@ -958,10 +976,10 @@ let rec map_try_block f (I_aux (instr, aux)) =
| I_decl _ | I_reset _ | I_init _ | I_reinit _ -> instr
| I_if (cval, instrs1, instrs2, ctyp) ->
I_if (cval, List.map (map_try_block f) instrs1, List.map (map_try_block f) instrs2, ctyp)
- | I_funcall _ | I_copy _ | I_alias _ | I_clear _ | I_throw _ | I_return _ -> instr
+ | I_funcall _ | I_copy _ | I_clear _ | I_throw _ | I_return _ -> instr
| I_block instrs -> I_block (List.map (map_try_block f) instrs)
| I_try_block instrs -> I_try_block (f (List.map (map_try_block f) instrs))
- | I_comment _ | I_label _ | I_goto _ | I_raw _ | I_jump _ | I_match_failure | I_undefined _ | I_end -> instr
+ | I_comment _ | I_label _ | I_goto _ | I_raw _ | I_jump _ | I_match_failure | I_undefined _ | I_end _ -> instr
in
I_aux (instr, aux)
@@ -979,7 +997,7 @@ let rec compile_arg_pat ctx label (P_aux (p_aux, (l, _)) as pat) ctyp =
| _ ->
let apat = anf_pat pat in
let gs = gensym () in
- let destructure, cleanup, _ = compile_match ctx apat (F_id gs, ctyp) label in
+ let destructure, cleanup, _ = compile_match ctx apat (F_id (name gs), ctyp) label in
(gs, (destructure, cleanup))
let rec compile_arg_pats ctx label (P_aux (p_aux, (l, _)) as pat) ctyps =
@@ -994,10 +1012,10 @@ let rec compile_arg_pats ctx label (P_aux (p_aux, (l, _)) as pat) ctyps =
let arg_id, (destructure, cleanup) = compile_arg_pat ctx label pat (CT_tup ctyps) in
let new_ids = List.map (fun ctyp -> gensym (), ctyp) ctyps in
destructure
- @ [idecl (CT_tup ctyps) arg_id]
- @ List.mapi (fun i (id, ctyp) -> icopy l (CL_tuple (CL_id (arg_id, CT_tup ctyps), i)) (F_id id, ctyp)) new_ids,
+ @ [idecl (CT_tup ctyps) (name arg_id)]
+ @ List.mapi (fun i (id, ctyp) -> icopy l (CL_tuple (CL_id (name arg_id, CT_tup ctyps), i)) (F_id (name id), ctyp)) new_ids,
List.map (fun (id, _) -> id, ([], [])) new_ids,
- [iclear (CT_tup ctyps) arg_id]
+ [iclear (CT_tup ctyps) (name arg_id)]
@ cleanup
let combine_destructure_cleanup xs = List.concat (List.map fst xs), List.concat (List.rev (List.map snd xs))
@@ -1108,7 +1126,7 @@ and compile_def' n total ctx = function
| DEF_reg_dec (DEC_aux (DEC_config (id, typ, exp), _)) ->
let aexp = ctx.optimize_anf ctx (no_shadow IdSet.empty (anf exp)) in
let setup, call, cleanup = compile_aexp ctx aexp in
- let instrs = setup @ [call (CL_id (id, ctyp_of_typ ctx typ))] @ cleanup in
+ let instrs = setup @ [call (CL_id (name id, ctyp_of_typ ctx typ))] @ cleanup in
[CDEF_reg_dec (id, ctyp_of_typ ctx typ, instrs)], ctx
| DEF_reg_dec (DEC_aux (_, (l, _))) ->
@@ -1161,8 +1179,8 @@ and compile_def' n total ctx = function
compiled_args |> List.map snd |> combine_destructure_cleanup |> fix_destructure fundef_label
in
- let instrs = arg_setup @ destructure @ setup @ [call (CL_return ret_ctyp)] @ cleanup @ destructure_cleanup @ arg_cleanup in
- let instrs = fix_early_return (CL_return ret_ctyp) instrs in
+ let instrs = arg_setup @ destructure @ setup @ [call (CL_id (return, ret_ctyp))] @ cleanup @ destructure_cleanup @ arg_cleanup in
+ let instrs = fix_early_return (CL_id (return, ret_ctyp)) instrs in
let instrs = fix_exception ~return:(Some ret_ctyp) ctx instrs in
if Id.compare (mk_id !opt_debug_function) id = 0 then
@@ -1179,10 +1197,15 @@ and compile_def' n total ctx = function
if !opt_debug_flow_graphs then
begin
let instrs = Jib_optimize.(instrs |> optimize_unit |> flatten_instrs) in
- let cfg = Jib_ssa.ssa instrs in
+ let root, _, cfg = Jib_ssa.control_flow_graph instrs in
+ let idom = Jib_ssa.immediate_dominators cfg root in
+ let _, cfg = Jib_ssa.ssa instrs in
let out_chan = open_out (Util.zencode_string (string_of_id id) ^ ".gv") in
Jib_ssa.make_dot out_chan cfg;
close_out out_chan;
+ let out_chan = open_out (Util.zencode_string (string_of_id id) ^ ".dom.gv") in
+ Jib_ssa.make_dominators_dot out_chan idom cfg;
+ close_out out_chan;
end;
[CDEF_fundef (id, None, List.map fst compiled_args, instrs)], orig_ctx
@@ -1206,7 +1229,7 @@ and compile_def' n total ctx = function
let aexp = ctx.optimize_anf ctx (no_shadow IdSet.empty (anf exp)) in
let setup, call, cleanup = compile_aexp ctx aexp in
let apat = anf_pat ~global:true pat in
- let gs = gensym () in
+ let gs = ngensym () in
let end_label = label "let_end_" in
let destructure, destructure_cleanup, _ = compile_match ctx apat (F_id gs, ctyp) end_label in
let gs_setup, gs_cleanup =
@@ -1257,13 +1280,19 @@ let rec specialize_variants ctx prior =
| CT_variant (id, ctors) when Id.compare id var_id = 0 -> CT_variant (id, new_ctors)
| ctyp -> ctyp
in
+ let fix_struct_ctyp struct_id new_fields = function
+ | CT_struct (id, ctors) when Id.compare id struct_id = 0 -> CT_struct (id, new_fields)
+ | ctyp -> ctyp
+ in
- let specialize_constructor ctx ctor_id ctyp =
- function
+ (* specialize_constructor is called on all instructions when we find
+ a constructor in a union type that is polymorphic. It's job is to
+ record all uses of that constructor so we can monomorphise it. *)
+ let specialize_constructor ctx ctor_id ctyp = function
| I_aux (I_funcall (clexp, extern, id, [cval]), ((_, l) as aux)) as instr when Id.compare id ctor_id = 0 ->
(* Work out how each call to a constructor in instantiated and add that to unifications *)
- let unification = List.map ctyp_suprema (ctyp_unify ctyp (cval_ctyp cval)) in
- let mono_id = append_id ctor_id ("_" ^ Util.string_of_list "_" (fun ctyp -> Util.zencode_string (string_of_ctyp ctyp)) unification) in
+ let unifiers = List.map ctyp_suprema (ctyp_unify ctyp (cval_ctyp cval)) in
+ let mono_id = append_id ctor_id ("_" ^ Util.string_of_list "_" (fun ctyp -> string_of_ctyp ctyp) unifiers) in
unifications := Bindings.add mono_id (ctyp_suprema (cval_ctyp cval)) !unifications;
(* We need to cast each cval to it's ctyp_suprema in order to put it in the most general constructor *)
@@ -1273,7 +1302,7 @@ let rec specialize_variants ctx prior =
if ctyp_equal ctyp suprema then
[], (unpoly frag, ctyp), []
else
- let gs = gensym () in
+ let gs = ngensym () in
[idecl suprema gs;
icopy l (CL_id (gs, suprema)) (unpoly frag, ctyp)],
(F_id gs, suprema),
@@ -1297,17 +1326,36 @@ let rec specialize_variants ctx prior =
| I_aux (I_funcall (clexp, extern, id, cvals), ((_, l) as aux)) as instr when Id.compare id ctor_id = 0 ->
Reporting.unreachable l __POS__ "Multiple argument constructor found"
+ (* We have to be careful this is the only place where an F_ctor_kind can appear before calling specialize variants *)
+ | I_aux (I_jump ((F_ctor_kind (_, id, unifiers, pat_ctyp), CT_bool), _), _) as instr when Id.compare id ctor_id = 0 ->
+ let mono_id = append_id ctor_id ("_" ^ Util.string_of_list "_" (fun ctyp -> string_of_ctyp ctyp) unifiers) in
+ unifications := Bindings.add mono_id (ctyp_suprema pat_ctyp) !unifications;
+ instr
+
| instr -> instr
in
+ (* specialize_field performs the same job as specialize_constructor,
+ but for struct fields rather than union constructors. *)
+ let specialize_field ctx field_id ctyp = function
+ | I_aux (I_copy (CL_field (clexp, field_str), cval), aux) when string_of_id field_id = field_str ->
+ let unifiers = List.map ctyp_suprema (ctyp_unify ctyp (cval_ctyp cval)) in
+ let mono_id = append_id field_id ("_" ^ Util.string_of_list "_" (fun ctyp -> string_of_ctyp ctyp) unifiers) in
+ unifications := Bindings.add mono_id (ctyp_suprema (cval_ctyp cval)) !unifications;
+ I_aux (I_copy (CL_field (clexp, string_of_id mono_id), cval), aux)
+
+ | instr -> instr
+ in
+
function
| (CDEF_type (CTD_variant (var_id, ctors)) as cdef) :: cdefs ->
let polymorphic_ctors = List.filter (fun (_, ctyp) -> is_polymorphic ctyp) ctors in
let cdefs =
- List.fold_left (fun cdefs (ctor_id, ctyp) -> List.map (cdef_map_instr (specialize_constructor ctx ctor_id ctyp)) cdefs)
- cdefs
- polymorphic_ctors
+ List.fold_left
+ (fun cdefs (ctor_id, ctyp) -> List.map (cdef_map_instr (specialize_constructor ctx ctor_id ctyp)) cdefs)
+ cdefs
+ polymorphic_ctors
in
let monomorphic_ctors = List.filter (fun (_, ctyp) -> not (is_polymorphic ctyp)) ctors in
@@ -1324,6 +1372,30 @@ let rec specialize_variants ctx prior =
let prior = List.map (cdef_map_ctyp (map_ctyp (fix_variant_ctyp var_id new_ctors))) prior in
specialize_variants ctx (CDEF_type (CTD_variant (var_id, new_ctors)) :: prior) cdefs
+ | (CDEF_type (CTD_struct (struct_id, fields)) as cdef) :: cdefs ->
+ let polymorphic_fields = List.filter (fun (_, ctyp) -> is_polymorphic ctyp) fields in
+
+ let cdefs =
+ List.fold_left
+ (fun cdefs (field_id, ctyp) -> List.map (cdef_map_instr (specialize_field ctx field_id ctyp)) cdefs)
+ cdefs
+ polymorphic_fields
+ in
+
+ let monomorphic_fields = List.filter (fun (_, ctyp) -> not (is_polymorphic ctyp)) fields in
+ let specialized_fields = Bindings.bindings !unifications in
+ let new_fields = monomorphic_fields @ specialized_fields in
+
+ let ctx = {
+ ctx with records = Bindings.add struct_id
+ (List.fold_left (fun m (id, ctyp) -> Bindings.add id ctyp m) !unifications monomorphic_fields)
+ ctx.records
+ } in
+
+ let cdefs = List.map (cdef_map_ctyp (map_ctyp (fix_struct_ctyp struct_id new_fields))) cdefs in
+ let prior = List.map (cdef_map_ctyp (map_ctyp (fix_struct_ctyp struct_id new_fields))) prior in
+ specialize_variants ctx (CDEF_type (CTD_struct (struct_id, new_fields)) :: prior) cdefs
+
| cdef :: cdefs ->
let remove_poly (I_aux (instr, aux)) =
match instr with
diff --git a/src/jib/jib_compile.mli b/src/jib/jib_compile.mli
index f3bd8c76..a0cacc3c 100644
--- a/src/jib/jib_compile.mli
+++ b/src/jib/jib_compile.mli
@@ -63,6 +63,8 @@ val opt_debug_flow_graphs : bool ref
(** Print the IR representation of a specific function. *)
val opt_debug_function : string ref
+val ngensym : unit -> name
+
(** {2 Jib context} *)
(** Context for compiling Sail to Jib. We need to pass a (global)
diff --git a/src/jib/jib_optimize.ml b/src/jib/jib_optimize.ml
index 889e650e..73b175a1 100644
--- a/src/jib/jib_optimize.ml
+++ b/src/jib/jib_optimize.ml
@@ -71,12 +71,6 @@ let optimize_unit instrs =
I_aux (I_copy (CL_void, unit_cval cval), annot)
| _ -> instr
end
- | I_aux (I_alias (clexp, cval), annot) as instr ->
- begin match clexp_ctyp clexp with
- | CT_unit ->
- I_aux (I_alias (CL_void, unit_cval cval), annot)
- | _ -> instr
- end
| instr -> instr
in
let non_pointless_copy (I_aux (aux, annot)) =
@@ -90,7 +84,7 @@ let flat_counter = ref 0
let flat_id () =
let id = mk_id ("local#" ^ string_of_int !flat_counter) in
incr flat_counter;
- id
+ name id
let rec flatten_instrs = function
| I_aux (I_decl (ctyp, decl_id), aux) :: instrs ->
@@ -127,3 +121,157 @@ let flatten_cdef =
CDEF_let (n, bindings, flatten_instrs instrs)
| cdef -> cdef
+
+let unique_per_function_ids cdefs =
+ let unique_id i = function
+ | Name (id, ssa_num) -> Name (append_id id ("#u" ^ string_of_int i), ssa_num)
+ | name -> name
+ in
+ let rec unique_instrs i = function
+ | I_aux (I_decl (ctyp, id), aux) :: rest ->
+ I_aux (I_decl (ctyp, unique_id i id), aux) :: unique_instrs i (instrs_rename id (unique_id i id) rest)
+ | I_aux (I_init (ctyp, id, cval), aux) :: rest ->
+ I_aux (I_init (ctyp, unique_id i id, cval), aux) :: unique_instrs i (instrs_rename id (unique_id i id) rest)
+ | I_aux (I_block instrs, aux) :: rest ->
+ I_aux (I_block (unique_instrs i instrs), aux) :: unique_instrs i rest
+ | I_aux (I_try_block instrs, aux) :: rest ->
+ I_aux (I_try_block (unique_instrs i instrs), aux) :: unique_instrs i rest
+ | I_aux (I_if (cval, then_instrs, else_instrs, ctyp), aux) :: rest ->
+ I_aux (I_if (cval, unique_instrs i then_instrs, unique_instrs i else_instrs, ctyp), aux) :: unique_instrs i rest
+ | instr :: instrs -> instr :: unique_instrs i instrs
+ | [] -> []
+ in
+ let unique_cdef i = function
+ | CDEF_reg_dec (id, ctyp, instrs) -> CDEF_reg_dec (id, ctyp, unique_instrs i instrs)
+ | CDEF_type ctd -> CDEF_type ctd
+ | CDEF_let (n, bindings, instrs) -> CDEF_let (n, bindings, unique_instrs i instrs)
+ | CDEF_spec (id, ctyps, ctyp) -> CDEF_spec (id, ctyps, ctyp)
+ | CDEF_fundef (id, heap_return, args, instrs) -> CDEF_fundef (id, heap_return, args, unique_instrs i instrs)
+ | CDEF_startup (id, instrs) -> CDEF_startup (id, unique_instrs i instrs)
+ | CDEF_finish (id, instrs) -> CDEF_finish (id, unique_instrs i instrs)
+ in
+ List.mapi unique_cdef cdefs
+
+let rec frag_subst id subst = function
+ | F_id id' -> if Name.compare id id' = 0 then subst else F_id id'
+ | F_ref reg_id -> F_ref reg_id
+ | F_lit vl -> F_lit vl
+ | F_op (frag1, op, frag2) -> F_op (frag_subst id subst frag1, op, frag_subst id subst frag2)
+ | F_unary (op, frag) -> F_unary (op, frag_subst id subst frag)
+ | F_call (op, frags) -> F_call (op, List.map (frag_subst id subst) frags)
+ | F_field (frag, field) -> F_field (frag_subst id subst frag, field)
+ | F_raw str -> F_raw str
+ | F_ctor_kind (frag, ctor, unifiers, ctyp) -> F_ctor_kind (frag_subst id subst frag, ctor, unifiers, ctyp)
+ | F_ctor_unwrap (ctor, unifiers, frag) -> F_ctor_unwrap (ctor, unifiers, frag_subst id subst frag)
+ | F_poly frag -> F_poly (frag_subst id subst frag)
+
+let cval_subst id subst (frag, ctyp) = frag_subst id subst frag, ctyp
+
+let rec instrs_subst id subst =
+ function
+ | (I_aux (I_decl (_, id'), _) :: _) as instrs when Name.compare id id' = 0 ->
+ instrs
+
+ | I_aux (I_init (ctyp, id', cval), aux) :: rest when Name.compare id id' = 0 ->
+ I_aux (I_init (ctyp, id', cval_subst id subst cval), aux) :: rest
+
+ | (I_aux (I_reset (_, id'), _) :: _) as instrs when Name.compare id id' = 0 ->
+ instrs
+
+ | I_aux (I_reinit (ctyp, id', cval), aux) :: rest when Name.compare id id' = 0 ->
+ I_aux (I_reinit (ctyp, id', cval_subst id subst cval), aux) :: rest
+
+ | I_aux (instr, aux) :: instrs ->
+ let instrs = instrs_subst id subst instrs in
+ let instr = match instr with
+ | I_decl (ctyp, id') -> I_decl (ctyp, id')
+ | I_init (ctyp, id', cval) -> I_init (ctyp, id', cval_subst id subst cval)
+ | I_jump (cval, label) -> I_jump (cval_subst id subst cval, label)
+ | I_goto label -> I_goto label
+ | I_label label -> I_label label
+ | I_funcall (clexp, extern, fid, args) -> I_funcall (clexp, extern, fid, List.map (cval_subst id subst) args)
+ | I_copy (clexp, cval) -> I_copy (clexp, cval_subst id subst cval)
+ | I_clear (clexp, id') -> I_clear (clexp, id')
+ | I_undefined ctyp -> I_undefined ctyp
+ | I_match_failure -> I_match_failure
+ | I_end id' -> I_end id'
+ | I_if (cval, then_instrs, else_instrs, ctyp) ->
+ I_if (cval_subst id subst cval, instrs_subst id subst then_instrs, instrs_subst id subst else_instrs, ctyp)
+ | I_block instrs -> I_block (instrs_subst id subst instrs)
+ | I_try_block instrs -> I_try_block (instrs_subst id subst instrs)
+ | I_throw cval -> I_throw (cval_subst id subst cval)
+ | I_comment str -> I_comment str
+ | I_raw str -> I_raw str
+ | I_return cval -> I_return cval
+ | I_reset (ctyp, id') -> I_reset (ctyp, id')
+ | I_reinit (ctyp, id', cval) -> I_reinit (ctyp, id', cval_subst id subst cval)
+ in
+ I_aux (instr, aux) :: instrs
+
+ | [] -> []
+
+let rec clexp_subst id subst = function
+ | CL_id (id', ctyp) when Name.compare id id' = 0 ->
+ assert (ctyp_equal ctyp (clexp_ctyp subst));
+ subst
+ | CL_id (id', ctyp) -> CL_id (id', ctyp)
+ | CL_field (clexp, field) -> CL_field (clexp_subst id subst clexp, field)
+ | CL_addr clexp -> CL_addr (clexp_subst id subst clexp)
+ | CL_tuple (clexp, n) -> CL_tuple (clexp_subst id subst clexp, n)
+ | CL_void -> CL_void
+
+let rec find_function fid = function
+ | CDEF_fundef (fid', heap_return, args, body) :: _ when Id.compare fid fid' = 0 ->
+ Some (heap_return, args, body)
+
+ | cdef :: cdefs -> find_function fid cdefs
+
+ | [] -> None
+
+let inline cdefs should_inline instrs =
+ let inlines = ref (-1) in
+
+ let replace_return subst = function
+ | I_aux (I_funcall (clexp, extern, fid, args), aux) ->
+ I_aux (I_funcall (clexp_subst return subst clexp, extern, fid, args), aux)
+ | I_aux (I_copy (clexp, cval), aux) ->
+ I_aux (I_copy (clexp_subst return subst clexp, cval), aux)
+ | instr -> instr
+ in
+
+ let replace_end label = function
+ | I_aux (I_end _, aux) -> I_aux (I_goto label, aux)
+ | I_aux (I_undefined _, aux) -> I_aux (I_goto label, aux)
+ | instr -> instr
+ in
+
+ let rec inline_instr = function
+ | I_aux (I_funcall (clexp, false, function_id, args), aux) as instr when should_inline function_id ->
+ begin match find_function function_id cdefs with
+ | Some (None, ids, body) ->
+ incr inlines;
+ let inline_label = label "end_inline_" in
+ let body = List.fold_right2 instrs_subst (List.map name ids) (List.map fst args) body in
+ let body = List.map (map_instr (replace_end inline_label)) body in
+ let body = List.map (map_instr (replace_return clexp)) body in
+ I_aux (I_block (body @ [ilabel inline_label]), aux)
+ | Some (Some _, ids, body) ->
+ (* Some _ is only introduced by C backend, so we don't
+ expect it at this point. *)
+ raise (Reporting.err_general (snd aux) "Unexpected return method in IR")
+ | None -> instr
+ end
+ | instr -> instr
+ in
+
+ let rec go instrs =
+ if !inlines <> 0 then
+ begin
+ inlines := 0;
+ let instrs = List.map (map_instr inline_instr) instrs in
+ go instrs
+ end
+ else
+ instrs
+ in
+ go instrs
diff --git a/src/jib/jib_optimize.mli b/src/jib/jib_optimize.mli
index beffa81e..78759d08 100644
--- a/src/jib/jib_optimize.mli
+++ b/src/jib/jib_optimize.mli
@@ -61,3 +61,6 @@ val optimize_unit : instr list -> instr list
val flatten_instrs : instr list -> instr list
val flatten_cdef : cdef -> cdef
+val unique_per_function_ids : cdef list -> cdef list
+
+val inline : cdef list -> (Ast.id -> bool) -> instr list -> instr list
diff --git a/src/jib/jib_ssa.ml b/src/jib/jib_ssa.ml
index 1f477696..a086f0b9 100644
--- a/src/jib/jib_ssa.ml
+++ b/src/jib/jib_ssa.ml
@@ -68,6 +68,15 @@ let make ~initial_size () = {
nodes = Array.make initial_size None
}
+let get_vertex graph n = graph.nodes.(n)
+
+let iter_graph f graph =
+ for n = 0 to graph.next - 1 do
+ match graph.nodes.(n) with
+ | Some (x, y, z) -> f x y z
+ | None -> ()
+ done
+
(** Add a vertex to a graph, returning the node index *)
let add_vertex data graph =
let n = graph.next in
@@ -133,8 +142,11 @@ let prune visited graph =
type cf_node =
| CF_label of string
| CF_block of instr list
+ | CF_guard of cval
| CF_start
+let cval_not (f, ctyp) = (F_unary ("!", f), ctyp)
+
let control_flow_graph instrs =
let module StringMap = Map.Make(String) in
let labels = ref StringMap.empty in
@@ -150,14 +162,14 @@ let control_flow_graph instrs =
let cf_split (I_aux (aux, _)) =
match aux with
- | I_block _ | I_label _ | I_goto _ | I_jump _ | I_if _ | I_end | I_match_failure | I_undefined _ -> true
+ | I_block _ | I_label _ | I_goto _ | I_jump _ | I_if _ | I_end _ | I_match_failure | I_undefined _ -> true
| _ -> false
in
let rec cfg preds instrs =
let before, after = instr_split_at cf_split instrs in
let last = match after with
- | I_aux (I_label _, _) :: _ -> []
+ | I_aux ((I_label _ | I_goto _ | I_jump _), _) :: _ -> []
| instr :: _ -> [instr]
| _ -> []
in
@@ -174,7 +186,7 @@ let control_flow_graph instrs =
let e = cfg preds else_instrs in
cfg (t @ e) after
- | I_aux ((I_end | I_match_failure | I_undefined _), _) :: after ->
+ | I_aux ((I_end _ | I_match_failure | I_undefined _), _) :: after ->
cfg [] after
| I_aux (I_goto label, _) :: after ->
@@ -182,8 +194,11 @@ let control_flow_graph instrs =
cfg [] after
| I_aux (I_jump (cval, label), _) :: after ->
- List.iter (fun p -> add_edge p (StringMap.find label !labels) graph) preds;
- cfg preds after
+ let t = add_vertex ([], CF_guard cval) graph in
+ let f = add_vertex ([], CF_guard (cval_not cval)) graph in
+ List.iter (fun p -> add_edge p t graph; add_edge p f graph) preds;
+ add_edge t (StringMap.find label !labels) graph;
+ cfg [f] after
| I_aux (I_label label, _) :: after ->
cfg (StringMap.find label !labels :: preds) after
@@ -351,55 +366,56 @@ let dominance_frontiers graph root idom children =
(**************************************************************************)
type ssa_elem =
- | Phi of Ast.id * Ast.id list
+ | Phi of Jib.name * Jib.ctyp * Jib.name list
+ | Pi of Jib.cval list
let place_phi_functions graph df =
- let defsites = ref Bindings.empty in
+ let defsites = ref NameCTMap.empty in
- let all_vars = ref IdSet.empty in
+ let all_vars = ref NameCTSet.empty in
let rec all_decls = function
- | I_aux (I_decl (_, id), _) :: instrs ->
- IdSet.add id (all_decls instrs)
+ | I_aux ((I_init (ctyp, id, _) | I_decl (ctyp, id)), _) :: instrs ->
+ NameCTSet.add (id, ctyp) (all_decls instrs)
| _ :: instrs -> all_decls instrs
- | [] -> IdSet.empty
+ | [] -> NameCTSet.empty
in
let orig_A n =
match graph.nodes.(n) with
| Some ((_, CF_block instrs), _, _) ->
- let vars = List.fold_left IdSet.union IdSet.empty (List.map instr_writes instrs) in
- let vars = IdSet.diff vars (all_decls instrs) in
- all_vars := IdSet.union vars !all_vars;
+ let vars = List.fold_left NameCTSet.union NameCTSet.empty (List.map instr_typed_writes instrs) in
+ let vars = NameCTSet.diff vars (all_decls instrs) in
+ all_vars := NameCTSet.union vars !all_vars;
vars
- | Some _ -> IdSet.empty
- | None -> IdSet.empty
+ | Some _ -> NameCTSet.empty
+ | None -> NameCTSet.empty
in
- let phi_A = ref Bindings.empty in
+ let phi_A = ref NameCTMap.empty in
for n = 0 to graph.next - 1 do
- IdSet.iter (fun a ->
- let ds = match Bindings.find_opt a !defsites with Some ds -> ds | None -> IntSet.empty in
- defsites := Bindings.add a (IntSet.add n ds) !defsites
+ NameCTSet.iter (fun a ->
+ let ds = match NameCTMap.find_opt a !defsites with Some ds -> ds | None -> IntSet.empty in
+ defsites := NameCTMap.add a (IntSet.add n ds) !defsites
) (orig_A n)
done;
- IdSet.iter (fun a ->
- let workset = ref (Bindings.find a !defsites) in
+ NameCTSet.iter (fun a ->
+ let workset = ref (NameCTMap.find a !defsites) in
while not (IntSet.is_empty !workset) do
let n = IntSet.choose !workset in
workset := IntSet.remove n !workset;
IntSet.iter (fun y ->
- let phi_A_a = match Bindings.find_opt a !phi_A with Some set -> set | None -> IntSet.empty in
+ let phi_A_a = match NameCTMap.find_opt a !phi_A with Some set -> set | None -> IntSet.empty in
if not (IntSet.mem y phi_A_a) then
begin
begin match graph.nodes.(y) with
| Some ((phis, cfnode), preds, succs) ->
- graph.nodes.(y) <- Some ((Phi (a, Util.list_init (IntSet.cardinal preds) (fun _ -> a)) :: phis, cfnode), preds, succs)
+ graph.nodes.(y) <- Some ((Phi (fst a, snd a, Util.list_init (IntSet.cardinal preds) (fun _ -> fst a)) :: phis, cfnode), preds, succs)
| None -> assert false
end;
- phi_A := Bindings.add a (IntSet.add y phi_A_a) !phi_A;
- if not (IdSet.mem a (orig_A y)) then
+ phi_A := NameCTMap.add a (IntSet.add y phi_A_a) !phi_A;
+ if not (NameCTSet.mem a (orig_A y)) then
workset := IntSet.add y !workset
end
) df.(n)
@@ -407,49 +423,53 @@ let place_phi_functions graph df =
) !all_vars
let rename_variables graph root children =
- let counts = ref Bindings.empty in
- let stacks = ref Bindings.empty in
+ let counts = ref NameMap.empty in
+ let stacks = ref NameMap.empty in
let get_count id =
- match Bindings.find_opt id !counts with Some n -> n | None -> 0
+ match NameMap.find_opt id !counts with Some n -> n | None -> 0
in
let top_stack id =
- match Bindings.find_opt id !stacks with Some (x :: _) -> x | (Some [] | None) -> 0
+ match NameMap.find_opt id !stacks with Some (x :: _) -> x | (Some [] | None) -> 0
in
let push_stack id n =
- stacks := Bindings.add id (n :: match Bindings.find_opt id !stacks with Some s -> s | None -> []) !stacks
+ stacks := NameMap.add id (n :: match NameMap.find_opt id !stacks with Some s -> s | None -> []) !stacks
+ in
+
+ let ssa_name i = function
+ | Name (id, _) -> Name (id, i)
+ | Have_exception _ -> Have_exception i
+ | Current_exception _ -> Current_exception i
+ | Return _ -> Return i
in
let rec fold_frag = function
| F_id id ->
let i = top_stack id in
- F_id (append_id id ("_" ^ string_of_int i))
+ F_id (ssa_name i id)
| F_ref id ->
let i = top_stack id in
- F_ref (append_id id ("_" ^ string_of_int i))
+ F_ref (ssa_name i id)
| F_lit vl -> F_lit vl
- | F_have_exception -> F_have_exception
- | F_current_exception -> F_current_exception
| F_op (f1, op, f2) -> F_op (fold_frag f1, op, fold_frag f2)
| F_unary (op, f) -> F_unary (op, fold_frag f)
| F_call (id, fs) -> F_call (id, List.map fold_frag fs)
| F_field (f, field) -> F_field (fold_frag f, field)
| F_raw str -> F_raw str
+ | F_ctor_kind (f, ctor, unifiers, ctyp) -> F_ctor_kind (fold_frag f, ctor, unifiers, ctyp)
+ | F_ctor_unwrap (ctor, unifiers, f) -> F_ctor_unwrap (ctor, unifiers, fold_frag f)
| F_poly f -> F_poly (fold_frag f)
in
let rec fold_clexp = function
| CL_id (id, ctyp) ->
let i = get_count id + 1 in
- counts := Bindings.add id i !counts;
+ counts := NameMap.add id i !counts;
push_stack id i;
- CL_id (append_id id ("_" ^ string_of_int i), ctyp)
+ CL_id (ssa_name i id, ctyp)
| CL_field (clexp, field) -> CL_field (fold_clexp clexp, field)
| CL_addr clexp -> CL_addr (fold_clexp clexp)
| CL_tuple (clexp, n) -> CL_tuple (fold_clexp clexp, n)
- | CL_current_exception ctyp -> CL_current_exception ctyp
- | CL_have_exception -> CL_have_exception
- | CL_return ctyp -> CL_return ctyp
| CL_void -> CL_void
in
@@ -465,15 +485,20 @@ let rename_variables graph root children =
I_copy (fold_clexp clexp, cval)
| I_decl (ctyp, id) ->
let i = get_count id + 1 in
- counts := Bindings.add id i !counts;
+ counts := NameMap.add id i !counts;
push_stack id i;
- I_decl (ctyp, append_id id ("_" ^ string_of_int i))
+ I_decl (ctyp, ssa_name i id)
| I_init (ctyp, id, cval) ->
let cval = fold_cval cval in
let i = get_count id + 1 in
- counts := Bindings.add id i !counts;
+ counts := NameMap.add id i !counts;
push_stack id i;
- I_init (ctyp, append_id id ("_" ^ string_of_int i), cval)
+ I_init (ctyp, ssa_name i id, cval)
+ | I_jump (cval, label) ->
+ I_jump (fold_cval cval, label)
+ | I_end id ->
+ let i = top_stack id in
+ I_end (ssa_name i id)
| instr -> instr
in
I_aux (aux, annot)
@@ -483,24 +508,28 @@ let rename_variables graph root children =
| CF_start -> CF_start
| CF_block instrs -> CF_block (List.map ssa_instr instrs)
| CF_label label -> CF_label label
+ | CF_guard cval -> CF_guard (fold_cval cval)
in
let ssa_ssanode = function
- | Phi (id, args) ->
+ | Phi (id, ctyp, args) ->
let i = get_count id + 1 in
- counts := Bindings.add id i !counts;
+ counts := NameMap.add id i !counts;
push_stack id i;
- Phi (append_id id ("_" ^ string_of_int i), args)
+ Phi (ssa_name i id, ctyp, args)
+ | Pi _ -> assert false (* Should not be introduced at this point *)
in
let fix_phi j = function
- | Phi (id, ids) ->
- Phi (id, List.mapi (fun k a ->
- if k = j then
- let i = top_stack a in
- append_id a ("_" ^ string_of_int i)
- else a)
- ids)
+ | Phi (id, ctyp, ids) ->
+ let fix_arg k a =
+ if k = j then
+ let i = top_stack a in
+ ssa_name i a
+ else a
+ in
+ Phi (id, ctyp, List.mapi fix_arg ids)
+ | Pi _ -> assert false (* Should not be introduced at this point *)
in
let rec rename n =
@@ -529,6 +558,53 @@ let rename_variables graph root children =
in
rename root
+let place_pi_functions graph start idom children =
+ let get_guard = function
+ | CF_guard guard -> [guard]
+ | _ -> []
+ in
+ let get_pi_contents ssanodes =
+ List.concat (List.map (function Pi guards -> guards | _ -> []) ssanodes)
+ in
+
+ let rec go n =
+ begin match graph.nodes.(n) with
+ | Some ((ssa, cfnode), preds, succs) ->
+ let p = idom.(n) in
+ if p <> -1 then
+ begin match graph.nodes.(p) with
+ | Some ((dom_ssa, _), _, _) ->
+ let args = get_guard cfnode @ get_pi_contents dom_ssa in
+ graph.nodes.(n) <- Some ((Pi args :: ssa, cfnode), preds, succs)
+ | None -> assert false
+ end
+ | None -> assert false
+ end;
+ IntSet.iter go children.(n)
+ in
+ go start
+
+(** Remove p nodes. Assumes the graph is acyclic. *)
+let remove_nodes remove_cf graph =
+ for n = 0 to graph.next - 1 do
+ match graph.nodes.(n) with
+ | Some ((_, cfnode), preds, succs) when remove_cf cfnode ->
+ IntSet.iter (fun pred ->
+ match graph.nodes.(pred) with
+ | Some (content, preds', succs') ->
+ graph.nodes.(pred) <- Some (content, preds', IntSet.remove n (IntSet.union succs succs'))
+ | None -> assert false
+ ) preds;
+ IntSet.iter (fun succ ->
+ match graph.nodes.(succ) with
+ | Some (content, preds', succs') ->
+ graph.nodes.(succ) <- Some (content, IntSet.remove n (IntSet.union preds preds'), succs')
+ | None -> assert false
+ ) succs;
+ graph.nodes.(n) <- None
+ | _ -> ()
+ done
+
let ssa instrs =
let start, finish, cfg = control_flow_graph instrs in
let idom = immediate_dominators cfg start in
@@ -536,36 +612,39 @@ let ssa instrs =
let df = dominance_frontiers cfg start idom children in
place_phi_functions cfg df;
rename_variables cfg start children;
- cfg
+ place_pi_functions cfg start idom children;
+ (* remove_guard_nodes (function CF_guard _ -> true | CF_label _ -> true | _ -> false) cfg; *)
+ start, cfg
(* Debugging utilities for outputing Graphviz files. *)
+let string_of_ssainstr = function
+ | Phi (id, ctyp, args) ->
+ string_of_name id ^ " : " ^ string_of_ctyp ctyp ^ " = &phi;(" ^ Util.string_of_list ", " string_of_name args ^ ")"
+ | Pi cvals ->
+ "&pi;(" ^ Util.string_of_list ", " (fun (f, _) -> String.escaped (string_of_fragment ~zencode:false f)) cvals ^ ")"
+
let string_of_phis = function
| [] -> ""
- | phis -> Util.string_of_list "\\l" (fun (Phi (id, args)) -> string_of_id id ^ " = phi(" ^ Util.string_of_list ", " string_of_id args ^ ")") phis ^ "\\l"
+ | phis -> Util.string_of_list "\\l" string_of_ssainstr phis ^ "\\l"
let string_of_node = function
| (phis, CF_label label) -> string_of_phis phis ^ label
| (phis, CF_block instrs) -> string_of_phis phis ^ Util.string_of_list "\\l" (fun instr -> String.escaped (Pretty_print_sail.to_string (pp_instr ~short:true instr))) instrs
| (phis, CF_start) -> string_of_phis phis ^ "START"
+ | (phis, CF_guard cval) -> string_of_phis phis ^ (String.escaped (Pretty_print_sail.to_string (pp_cval cval)))
let vertex_color = function
| (_, CF_start) -> "peachpuff"
| (_, CF_block _) -> "white"
| (_, CF_label _) -> "springgreen"
-
-let edge_color node_from node_to =
- match node_from, node_to with
- | CF_block _, CF_block _ -> "black"
- | CF_label _, CF_block _ -> "red"
- | CF_block _, CF_label _ -> "blue"
- | _, _ -> "deeppink"
+ | (_, CF_guard _) -> "yellow"
let make_dot out_chan graph =
Util.opt_colors := false;
output_string out_chan "digraph DEPS {\n";
let make_node i n =
- output_string out_chan (Printf.sprintf " n%i [label=\"%s\";shape=box;style=filled;fillcolor=%s];\n" i (string_of_node n) (vertex_color n))
+ output_string out_chan (Printf.sprintf " n%i [label=\"%i\\n%s\\l\";shape=box;style=filled;fillcolor=%s];\n" i i (string_of_node n) (vertex_color n))
in
let make_line i s =
output_string out_chan (Printf.sprintf " n%i -> n%i [color=black];\n" i s)
@@ -584,7 +663,7 @@ let make_dominators_dot out_chan idom graph =
Util.opt_colors := false;
output_string out_chan "digraph DOMS {\n";
let make_node i n =
- output_string out_chan (Printf.sprintf " n%i [label=\"%s\";shape=box;style=filled;fillcolor=%s];\n" i (string_of_node n) (vertex_color n))
+ output_string out_chan (Printf.sprintf " n%i [label=\"%i\\n%s\\l\";shape=box;style=filled;fillcolor=%s];\n" i i (string_of_node n) (vertex_color n))
in
let make_line i s =
output_string out_chan (Printf.sprintf " n%i -> n%i [color=black];\n" i s)
diff --git a/src/jib/jib_ssa.mli b/src/jib/jib_ssa.mli
index 3796a114..b146861c 100644
--- a/src/jib/jib_ssa.mli
+++ b/src/jib/jib_ssa.mli
@@ -57,6 +57,12 @@ type 'a array_graph
underlying array. *)
val make : initial_size:int -> unit -> 'a array_graph
+module IntSet : Set.S with type elt = int
+
+val get_vertex : 'a array_graph -> int -> ('a * IntSet.t * IntSet.t) option
+
+val iter_graph : ('a -> IntSet.t -> IntSet.t -> unit) -> 'a array_graph -> unit
+
(** Add a vertex to a graph, returning the index of the inserted
vertex. If the number of vertices exceeds the size of the
underlying array, then it is dynamically resized. *)
@@ -69,17 +75,25 @@ val add_edge : int -> int -> 'a array_graph -> unit
type cf_node =
| CF_label of string
| CF_block of Jib.instr list
+ | CF_guard of Jib.cval
| CF_start
val control_flow_graph : Jib.instr list -> int * int list * ('a list * cf_node) array_graph
+(** [immediate_dominators graph root] will calculate the immediate
+ dominators for a control flow graph with a specified root node. *)
+val immediate_dominators : 'a array_graph -> int -> int array
+
type ssa_elem =
- | Phi of Ast.id * Ast.id list
+ | Phi of Jib.name * Jib.ctyp * Jib.name list
+ | Pi of Jib.cval list
(** Convert a list of instructions into SSA form *)
-val ssa : Jib.instr list -> (ssa_elem list * cf_node) array_graph
+val ssa : Jib.instr list -> int * (ssa_elem list * cf_node) array_graph
(** Output the control-flow graph in graphviz format for
debugging. Can use 'dot -Tpng X.gv -o X.png' to generate a png
image of the graph. *)
val make_dot : out_channel -> (ssa_elem list * cf_node) array_graph -> unit
+
+val make_dominators_dot : out_channel -> int array -> (ssa_elem list * cf_node) array_graph -> unit
diff --git a/src/jib/jib_util.ml b/src/jib/jib_util.ml
index 81cd07ef..904e0209 100644
--- a/src/jib/jib_util.ml
+++ b/src/jib/jib_util.ml
@@ -82,12 +82,12 @@ let ifuncall ?loc:(l=Parse_ast.Unknown) clexp id cvals =
let iextern ?loc:(l=Parse_ast.Unknown) clexp id cvals =
I_aux (I_funcall (clexp, true, id, cvals), (instr_number (), l))
+let icall ?loc:(l=Parse_ast.Unknown) clexp extern id cvals =
+ I_aux (I_funcall (clexp, extern, id, cvals), (instr_number (), l))
+
let icopy l clexp cval =
I_aux (I_copy (clexp, cval), (instr_number (), l))
-let ialias l clexp cval =
- I_aux (I_alias (clexp, cval), (instr_number (), l))
-
let iclear ?loc:(l=Parse_ast.Unknown) ctyp id =
I_aux (I_clear (ctyp, id), (instr_number (), l))
@@ -95,7 +95,7 @@ let ireturn ?loc:(l=Parse_ast.Unknown) cval =
I_aux (I_return cval, (instr_number (), l))
let iend ?loc:(l=Parse_ast.Unknown) () =
- I_aux (I_end, (instr_number (), l))
+ I_aux (I_end (Return (-1)), (instr_number (), l))
let iblock ?loc:(l=Parse_ast.Unknown) instrs =
I_aux (I_block instrs, (instr_number (), l))
@@ -105,11 +105,13 @@ let itry_block ?loc:(l=Parse_ast.Unknown) instrs =
let ithrow ?loc:(l=Parse_ast.Unknown) cval =
I_aux (I_throw cval, (instr_number (), l))
+
let icomment ?loc:(l=Parse_ast.Unknown) str =
I_aux (I_comment str, (instr_number (), l))
let ilabel ?loc:(l=Parse_ast.Unknown) label =
I_aux (I_label label, (instr_number (), l))
+
let igoto ?loc:(l=Parse_ast.Unknown) label =
I_aux (I_goto label, (instr_number (), l))
@@ -125,25 +127,52 @@ let iraw ?loc:(l=Parse_ast.Unknown) str =
let ijump ?loc:(l=Parse_ast.Unknown) cval label =
I_aux (I_jump (cval, label), (instr_number (), l))
+module Name = struct
+ type t = name
+ let compare id1 id2 =
+ match id1, id2 with
+ | Name (x, n), Name (y, m) ->
+ let c1 = Id.compare x y in
+ if c1 = 0 then compare n m else c1
+ | Have_exception n, Have_exception m -> compare n m
+ | Current_exception n, Current_exception m -> compare n m
+ | Return n, Return m -> compare n m
+ | Name _, _ -> 1
+ | _, Name _ -> -1
+ | Have_exception _, _ -> 1
+ | _, Have_exception _ -> -1
+ | Current_exception _, _ -> 1
+ | _, Current_exception _ -> -1
+end
+
+module NameSet = Set.Make(Name)
+module NameMap = Map.Make(Name)
+
+let current_exception = Current_exception (-1)
+let have_exception = Have_exception (-1)
+let return = Return (-1)
+
+let name id = Name (id, -1)
+
let rec frag_rename from_id to_id = function
- | F_id id when Id.compare id from_id = 0 -> F_id to_id
+ | F_id id when Name.compare id from_id = 0 -> F_id to_id
| F_id id -> F_id id
- | F_ref id when Id.compare id from_id = 0 -> F_ref to_id
+ | F_ref id when Name.compare id from_id = 0 -> F_ref to_id
| F_ref id -> F_ref id
| F_lit v -> F_lit v
- | F_have_exception -> F_have_exception
- | F_current_exception -> F_current_exception
| F_call (call, frags) -> F_call (call, List.map (frag_rename from_id to_id) frags)
| F_op (f1, op, f2) -> F_op (frag_rename from_id to_id f1, op, frag_rename from_id to_id f2)
| F_unary (op, f) -> F_unary (op, frag_rename from_id to_id f)
| F_field (f, field) -> F_field (frag_rename from_id to_id f, field)
| F_raw raw -> F_raw raw
+ | F_ctor_kind (f, ctor, unifiers, ctyp) -> F_ctor_kind (frag_rename from_id to_id f, ctor, unifiers, ctyp)
+ | F_ctor_unwrap (ctor, unifiers, f) -> F_ctor_unwrap (ctor, unifiers, frag_rename from_id to_id f)
| F_poly f -> F_poly (frag_rename from_id to_id f)
let cval_rename from_id to_id (frag, ctyp) = (frag_rename from_id to_id frag, ctyp)
let rec clexp_rename from_id to_id = function
- | CL_id (id, ctyp) when Id.compare id from_id = 0 -> CL_id (to_id, ctyp)
+ | CL_id (id, ctyp) when Name.compare id from_id = 0 -> CL_id (to_id, ctyp)
| CL_id (id, ctyp) -> CL_id (id, ctyp)
| CL_field (clexp, field) ->
CL_field (clexp_rename from_id to_id clexp, field)
@@ -151,17 +180,14 @@ let rec clexp_rename from_id to_id = function
CL_addr (clexp_rename from_id to_id clexp)
| CL_tuple (clexp, n) ->
CL_tuple (clexp_rename from_id to_id clexp, n)
- | CL_current_exception ctyp -> CL_current_exception ctyp
- | CL_have_exception -> CL_have_exception
- | CL_return ctyp -> CL_return ctyp
| CL_void -> CL_void
let rec instr_rename from_id to_id (I_aux (instr, aux)) =
let instr = match instr with
- | I_decl (ctyp, id) when Id.compare id from_id = 0 -> I_decl (ctyp, to_id)
+ | I_decl (ctyp, id) when Name.compare id from_id = 0 -> I_decl (ctyp, to_id)
| I_decl (ctyp, id) -> I_decl (ctyp, id)
- | I_init (ctyp, id, cval) when Id.compare id from_id = 0 ->
+ | I_init (ctyp, id, cval) when Name.compare id from_id = 0 ->
I_init (ctyp, to_id, cval_rename from_id to_id cval)
| I_init (ctyp, id, cval) ->
I_init (ctyp, id, cval_rename from_id to_id cval)
@@ -178,9 +204,8 @@ let rec instr_rename from_id to_id (I_aux (instr, aux)) =
I_funcall (clexp_rename from_id to_id clexp, extern, id, List.map (cval_rename from_id to_id) args)
| I_copy (clexp, cval) -> I_copy (clexp_rename from_id to_id clexp, cval_rename from_id to_id cval)
- | I_alias (clexp, cval) -> I_alias (clexp_rename from_id to_id clexp, cval_rename from_id to_id cval)
- | I_clear (ctyp, id) when Id.compare id from_id = 0 -> I_clear (ctyp, to_id)
+ | I_clear (ctyp, id) when Name.compare id from_id = 0 -> I_clear (ctyp, to_id)
| I_clear (ctyp, id) -> I_clear (ctyp, id)
| I_return cval -> I_return (cval_rename from_id to_id cval)
@@ -203,12 +228,13 @@ let rec instr_rename from_id to_id (I_aux (instr, aux)) =
| I_match_failure -> I_match_failure
- | I_end -> I_end
+ | I_end id when Name.compare id from_id = 0 -> I_end to_id
+ | I_end id -> I_end id
- | I_reset (ctyp, id) when Id.compare id from_id = 0 -> I_reset (ctyp, to_id)
+ | I_reset (ctyp, id) when Name.compare id from_id = 0 -> I_reset (ctyp, to_id)
| I_reset (ctyp, id) -> I_reset (ctyp, id)
- | I_reinit (ctyp, id, cval) when Id.compare id from_id = 0 ->
+ | I_reinit (ctyp, id, cval) when Name.compare id from_id = 0 ->
I_reinit (ctyp, to_id, cval_rename from_id to_id cval)
| I_reinit (ctyp, id, cval) ->
I_reinit (ctyp, id, cval_rename from_id to_id cval)
@@ -229,15 +255,24 @@ let string_of_value = function
| V_unit -> "UNIT"
| V_bit Sail2_values.B0 -> "UINT64_C(0)"
| V_bit Sail2_values.B1 -> "UINT64_C(1)"
+ | V_bit Sail2_values.BU -> failwith "Undefined bit found in value"
| V_string str -> "\"" ^ str ^ "\""
- | V_ctor_kind str -> "Kind_" ^ Util.zencode_string str
- | _ -> failwith "Cannot convert value to string"
+
+let string_of_name ?zencode:(zencode=true) =
+ let ssa_num n = if n < 0 then "" else ("/" ^ string_of_int n) in
+ function
+ | Name (id, n) ->
+ (if zencode then Util.zencode_string (string_of_id id) else string_of_id id) ^ ssa_num n
+ | Have_exception n ->
+ "have_exception" ^ ssa_num n
+ | Return n ->
+ "return" ^ ssa_num n
+ | Current_exception n ->
+ "(*current_exception)" ^ ssa_num n
let rec string_of_fragment ?zencode:(zencode=true) = function
- | F_id id when zencode -> Util.zencode_string (string_of_id id)
- | F_id id -> string_of_id id
- | F_ref id when zencode -> "&" ^ Util.zencode_string (string_of_id id)
- | F_ref id -> "&" ^ string_of_id id
+ | F_id id -> string_of_name ~zencode:zencode id
+ | F_ref id -> "&" ^ string_of_name ~zencode:zencode id
| F_lit v -> string_of_value v
| F_call (str, frags) ->
Printf.sprintf "%s(%s)" str (Util.string_of_list ", " (string_of_fragment ~zencode:zencode) frags)
@@ -247,9 +282,21 @@ let rec string_of_fragment ?zencode:(zencode=true) = function
Printf.sprintf "%s %s %s" (string_of_fragment' ~zencode:zencode f1) op (string_of_fragment' ~zencode:zencode f2)
| F_unary (op, f) ->
op ^ string_of_fragment' ~zencode:zencode f
- | F_have_exception -> "have_exception"
- | F_current_exception -> "(*current_exception)"
| F_raw raw -> raw
+ | F_ctor_kind (f, ctor, [], _) ->
+ string_of_fragment' ~zencode:zencode f ^ ".kind"
+ ^ " != Kind_" ^ Util.zencode_string (string_of_id ctor)
+ | F_ctor_kind (f, ctor, unifiers, _) ->
+ string_of_fragment' ~zencode:zencode f ^ ".kind"
+ ^ " != Kind_" ^ Util.zencode_string (string_of_id ctor ^ "_" ^ Util.string_of_list "_" string_of_ctyp unifiers)
+ | F_ctor_unwrap (ctor, [], f) ->
+ Printf.sprintf "%s.%s"
+ (string_of_fragment' ~zencode:zencode f)
+ (Util.zencode_string (string_of_id ctor))
+ | F_ctor_unwrap (ctor, unifiers, f) ->
+ Printf.sprintf "%s.%s"
+ (string_of_fragment' ~zencode:zencode f)
+ (Util.zencode_string (string_of_id ctor ^ "_" ^ Util.string_of_list "_" string_of_ctyp unifiers))
| F_poly f -> string_of_fragment ~zencode:zencode f
and string_of_fragment' ?zencode:(zencode=true) f =
match f with
@@ -284,9 +331,14 @@ and string_of_ctyp = function
constructors in variants and structs. Used for debug output. *)
and full_string_of_ctyp = function
| CT_tup ctyps -> "(" ^ Util.string_of_list ", " full_string_of_ctyp ctyps ^ ")"
- | CT_struct (id, ctors) | CT_variant (id, ctors) ->
+ | CT_struct (id, ctors) ->
"struct " ^ string_of_id id
- ^ "{ "
+ ^ "{"
+ ^ Util.string_of_list ", " (fun (id, ctyp) -> string_of_id id ^ " : " ^ full_string_of_ctyp ctyp) ctors
+ ^ "}"
+ | CT_variant (id, ctors) ->
+ "union " ^ string_of_id id
+ ^ "{"
^ Util.string_of_list ", " (fun (id, ctyp) -> string_of_id id ^ " : " ^ full_string_of_ctyp ctyp) ctors
^ "}"
| CT_vector (true, ctyp) -> "vector(dec, " ^ full_string_of_ctyp ctyp ^ ")"
@@ -407,6 +459,13 @@ let rec ctyp_unify ctyp1 ctyp2 =
| CT_list ctyp1, CT_list ctyp2 -> ctyp_unify ctyp1 ctyp2
+ | CT_struct (id1, fields1), CT_struct (id2, fields2)
+ when Id.compare id1 id2 = 0 && List.length fields1 == List.length fields2 ->
+ if List.for_all2 (fun x y -> x = y) (List.map fst fields1) (List.map fst fields2) then
+ List.concat (List.map2 ctyp_unify (List.map snd fields1) (List.map snd fields2))
+ else
+ raise (Invalid_argument "ctyp_unify")
+
| CT_ref ctyp1, CT_ref ctyp2 -> ctyp_unify ctyp1 ctyp2
| CT_poly, _ -> [ctyp2]
@@ -466,6 +525,9 @@ let rec is_polymorphic = function
let pp_id id =
string (string_of_id id)
+let pp_name id =
+ string (string_of_name ~zencode:false id)
+
let pp_ctyp ctyp =
string (full_string_of_ctyp ctyp |> Util.yellow |> Util.clear)
@@ -476,19 +538,16 @@ let pp_cval (frag, ctyp) =
string (string_of_fragment ~zencode:false frag) ^^ string " : " ^^ pp_ctyp ctyp
let rec pp_clexp = function
- | CL_id (id, ctyp) -> pp_id id ^^ string " : " ^^ pp_ctyp ctyp
+ | CL_id (id, ctyp) -> pp_name id ^^ string " : " ^^ pp_ctyp ctyp
| CL_field (clexp, field) -> parens (pp_clexp clexp) ^^ string "." ^^ string field
| CL_tuple (clexp, n) -> parens (pp_clexp clexp) ^^ string "." ^^ string (string_of_int n)
| CL_addr clexp -> string "*" ^^ pp_clexp clexp
- | CL_current_exception ctyp -> string "current_exception : " ^^ pp_ctyp ctyp
- | CL_have_exception -> string "have_exception"
- | CL_return ctyp -> string "return : " ^^ pp_ctyp ctyp
| CL_void -> string "void"
let rec pp_instr ?short:(short=false) (I_aux (instr, aux)) =
match instr with
| I_decl (ctyp, id) ->
- pp_keyword "var" ^^ pp_id id ^^ string " : " ^^ pp_ctyp ctyp
+ pp_keyword "var" ^^ pp_name id ^^ string " : " ^^ pp_ctyp ctyp
| I_if (cval, then_instrs, else_instrs, ctyp) ->
let pp_if_block = function
| [] -> string "{}"
@@ -508,20 +567,18 @@ let rec pp_instr ?short:(short=false) (I_aux (instr, aux)) =
| I_try_block instrs ->
pp_keyword "try" ^^ surround 2 0 lbrace (separate_map (semi ^^ hardline) pp_instr instrs) rbrace
| I_reset (ctyp, id) ->
- pp_keyword "recreate" ^^ pp_id id ^^ string " : " ^^ pp_ctyp ctyp
+ pp_keyword "recreate" ^^ pp_name id ^^ string " : " ^^ pp_ctyp ctyp
| I_init (ctyp, id, cval) ->
- pp_keyword "create" ^^ pp_id id ^^ string " : " ^^ pp_ctyp ctyp ^^ string " = " ^^ pp_cval cval
+ pp_keyword "create" ^^ pp_name id ^^ string " : " ^^ pp_ctyp ctyp ^^ string " = " ^^ pp_cval cval
| I_reinit (ctyp, id, cval) ->
- pp_keyword "recreate" ^^ pp_id id ^^ string " : " ^^ pp_ctyp ctyp ^^ string " = " ^^ pp_cval cval
+ pp_keyword "recreate" ^^ pp_name id ^^ string " : " ^^ pp_ctyp ctyp ^^ string " = " ^^ pp_cval cval
| I_funcall (x, _, f, args) ->
separate space [ pp_clexp x; string "=";
string (string_of_id f |> Util.green |> Util.clear) ^^ parens (separate_map (string ", ") pp_cval args) ]
| I_copy (clexp, cval) ->
separate space [pp_clexp clexp; string "="; pp_cval cval]
- | I_alias (clexp, cval) ->
- pp_keyword "alias" ^^ separate space [pp_clexp clexp; string "="; pp_cval cval]
| I_clear (ctyp, id) ->
- pp_keyword "kill" ^^ pp_id id ^^ string " : " ^^ pp_ctyp ctyp
+ pp_keyword "kill" ^^ pp_name id ^^ string " : " ^^ pp_ctyp ctyp
| I_return cval ->
pp_keyword "return" ^^ pp_cval cval
| I_throw cval ->
@@ -534,7 +591,7 @@ let rec pp_instr ?short:(short=false) (I_aux (instr, aux)) =
pp_keyword "goto" ^^ string (str |> Util.blue |> Util.clear)
| I_match_failure ->
pp_keyword "match_failure"
- | I_end ->
+ | I_end _ ->
pp_keyword "end"
| I_undefined ctyp ->
pp_keyword "undefined" ^^ pp_ctyp ctyp
@@ -584,55 +641,72 @@ let pp_cdef = function
^^ hardline
let rec fragment_deps = function
- | F_id id | F_ref id -> IdSet.singleton id
- | F_lit _ -> IdSet.empty
+ | F_id id | F_ref id -> NameSet.singleton id
+ | F_lit _ -> NameSet.empty
| F_field (frag, _) | F_unary (_, frag) | F_poly frag -> fragment_deps frag
- | F_call (_, frags) -> List.fold_left IdSet.union IdSet.empty (List.map fragment_deps frags)
- | F_op (frag1, _, frag2) -> IdSet.union (fragment_deps frag1) (fragment_deps frag2)
- | F_current_exception -> IdSet.empty
- | F_have_exception -> IdSet.empty
- | F_raw _ -> IdSet.empty
+ | F_call (_, frags) -> List.fold_left NameSet.union NameSet.empty (List.map fragment_deps frags)
+ | F_op (frag1, _, frag2) -> NameSet.union (fragment_deps frag1) (fragment_deps frag2)
+ | F_ctor_kind (frag, _, _, _) -> fragment_deps frag
+ | F_ctor_unwrap (_, _, frag) -> fragment_deps frag
+ | F_raw _ -> NameSet.empty
let cval_deps = function (frag, _) -> fragment_deps frag
let rec clexp_deps = function
- | CL_id (id, _) -> IdSet.singleton id
+ | CL_id (id, _) -> NameSet.singleton id
| CL_field (clexp, _) -> clexp_deps clexp
| CL_tuple (clexp, _) -> clexp_deps clexp
| CL_addr clexp -> clexp_deps clexp
- | CL_have_exception -> IdSet.empty
- | CL_current_exception _ -> IdSet.empty
- | CL_return _ -> IdSet.empty
- | CL_void -> IdSet.empty
+ | CL_void -> NameSet.empty
(* Return the direct, read/write dependencies of a single instruction *)
let instr_deps = function
- | I_decl (ctyp, id) -> IdSet.empty, IdSet.singleton id
- | I_reset (ctyp, id) -> IdSet.empty, IdSet.singleton id
- | I_init (ctyp, id, cval) | I_reinit (ctyp, id, cval) -> cval_deps cval, IdSet.singleton id
- | I_if (cval, _, _, _) -> cval_deps cval, IdSet.empty
- | I_jump (cval, label) -> cval_deps cval, IdSet.empty
- | I_funcall (clexp, _, _, cvals) -> List.fold_left IdSet.union IdSet.empty (List.map cval_deps cvals), clexp_deps clexp
+ | I_decl (ctyp, id) -> NameSet.empty, NameSet.singleton id
+ | I_reset (ctyp, id) -> NameSet.empty, NameSet.singleton id
+ | I_init (ctyp, id, cval) | I_reinit (ctyp, id, cval) -> cval_deps cval, NameSet.singleton id
+ | I_if (cval, _, _, _) -> cval_deps cval, NameSet.empty
+ | I_jump (cval, label) -> cval_deps cval, NameSet.empty
+ | I_funcall (clexp, _, _, cvals) -> List.fold_left NameSet.union NameSet.empty (List.map cval_deps cvals), clexp_deps clexp
| I_copy (clexp, cval) -> cval_deps cval, clexp_deps clexp
- | I_alias (clexp, cval) -> cval_deps cval, clexp_deps clexp
- | I_clear (_, id) -> IdSet.singleton id, IdSet.empty
- | I_throw cval | I_return cval -> cval_deps cval, IdSet.empty
- | I_block _ | I_try_block _ -> IdSet.empty, IdSet.empty
- | I_comment _ | I_raw _ -> IdSet.empty, IdSet.empty
- | I_label label -> IdSet.empty, IdSet.empty
- | I_goto label -> IdSet.empty, IdSet.empty
- | I_undefined _ -> IdSet.empty, IdSet.empty
- | I_match_failure -> IdSet.empty, IdSet.empty
- | I_end -> IdSet.empty, IdSet.empty
+ | I_clear (_, id) -> NameSet.singleton id, NameSet.empty
+ | I_throw cval | I_return cval -> cval_deps cval, NameSet.empty
+ | I_block _ | I_try_block _ -> NameSet.empty, NameSet.empty
+ | I_comment _ | I_raw _ -> NameSet.empty, NameSet.empty
+ | I_label label -> NameSet.empty, NameSet.empty
+ | I_goto label -> NameSet.empty, NameSet.empty
+ | I_undefined _ -> NameSet.empty, NameSet.empty
+ | I_match_failure -> NameSet.empty, NameSet.empty
+ | I_end id -> NameSet.singleton id, NameSet.empty
+
+module NameCT = struct
+ type t = name * ctyp
+ let compare (n1, ctyp1) (n2, ctyp2) =
+ let c = Name.compare n1 n2 in
+ if c = 0 then CT.compare ctyp1 ctyp2 else c
+end
+
+module NameCTSet = Set.Make(NameCT)
+module NameCTMap = Map.Make(NameCT)
+
+let rec clexp_typed_writes = function
+ | CL_id (id, ctyp) -> NameCTSet.singleton (id, ctyp)
+ | CL_field (clexp, _) -> clexp_typed_writes clexp
+ | CL_tuple (clexp, _) -> clexp_typed_writes clexp
+ | CL_addr clexp -> clexp_typed_writes clexp
+ | CL_void -> NameCTSet.empty
+
+let instr_typed_writes (I_aux (aux, _)) =
+ match aux with
+ | I_decl (ctyp, id) | I_reset (ctyp, id) -> NameCTSet.singleton (id, ctyp)
+ | I_init (ctyp, id, _) | I_reinit (ctyp, id, _) -> NameCTSet.singleton (id, ctyp)
+ | I_funcall (clexp, _, _, _) | I_copy (clexp, _) -> clexp_typed_writes clexp
+ | _ -> NameCTSet.empty
let rec map_clexp_ctyp f = function
| CL_id (id, ctyp) -> CL_id (id, f ctyp)
| CL_field (clexp, field) -> CL_field (map_clexp_ctyp f clexp, field)
| CL_tuple (clexp, n) -> CL_tuple (map_clexp_ctyp f clexp, n)
| CL_addr clexp -> CL_addr (map_clexp_ctyp f clexp)
- | CL_current_exception ctyp -> CL_current_exception (f ctyp)
- | CL_have_exception -> CL_have_exception
- | CL_return ctyp -> CL_return (f ctyp)
| CL_void -> CL_void
let rec map_instr_ctyp f (I_aux (instr, aux)) =
@@ -645,7 +719,6 @@ let rec map_instr_ctyp f (I_aux (instr, aux)) =
| I_funcall (clexp, extern, id, cvals) ->
I_funcall (map_clexp_ctyp f clexp, extern, id, List.map (fun (frag, ctyp) -> frag, f ctyp) cvals)
| I_copy (clexp, (frag, ctyp)) -> I_copy (map_clexp_ctyp f clexp, (frag, f ctyp))
- | I_alias (clexp, (frag, ctyp)) -> I_alias (map_clexp_ctyp f clexp, (frag, f ctyp))
| I_clear (ctyp, id) -> I_clear (f ctyp, id)
| I_return (frag, ctyp) -> I_return (frag, f ctyp)
| I_block instrs -> I_block (List.map (map_instr_ctyp f) instrs)
@@ -654,7 +727,7 @@ let rec map_instr_ctyp f (I_aux (instr, aux)) =
| I_undefined ctyp -> I_undefined (f ctyp)
| I_reset (ctyp, id) -> I_reset (f ctyp, id)
| I_reinit (ctyp1, id, (frag, ctyp2)) -> I_reinit (f ctyp1, id, (frag, f ctyp2))
- | I_end -> I_end
+ | I_end id -> I_end id
| (I_comment _ | I_raw _ | I_label _ | I_goto _ | I_match_failure) as instr -> instr
in
I_aux (instr, aux)
@@ -663,8 +736,8 @@ let rec map_instr_ctyp f (I_aux (instr, aux)) =
let rec map_instr f (I_aux (instr, aux)) =
let instr = match instr with
| I_decl _ | I_init _ | I_reset _ | I_reinit _
- | I_funcall _ | I_copy _ | I_alias _ | I_clear _ | I_jump _ | I_throw _ | I_return _
- | I_comment _ | I_label _ | I_goto _ | I_raw _ | I_match_failure | I_undefined _ | I_end -> instr
+ | I_funcall _ | I_copy _ | I_clear _ | I_jump _ | I_throw _ | I_return _
+ | I_comment _ | I_label _ | I_goto _ | I_raw _ | I_match_failure | I_undefined _ | I_end _ -> instr
| I_if (cval, instrs1, instrs2, ctyp) ->
I_if (cval, List.map (map_instr f) instrs1, List.map (map_instr f) instrs2, ctyp)
| I_block instrs ->
@@ -678,8 +751,8 @@ let rec map_instr f (I_aux (instr, aux)) =
let rec iter_instr f (I_aux (instr, aux)) =
match instr with
| I_decl _ | I_init _ | I_reset _ | I_reinit _
- | I_funcall _ | I_copy _ | I_alias _ | I_clear _ | I_jump _ | I_throw _ | I_return _
- | I_comment _ | I_label _ | I_goto _ | I_raw _ | I_match_failure | I_undefined _ | I_end -> f (I_aux (instr, aux))
+ | I_funcall _ | I_copy _ | I_clear _ | I_jump _ | I_throw _ | I_return _
+ | I_comment _ | I_label _ | I_goto _ | I_raw _ | I_match_failure | I_undefined _ | I_end _ -> f (I_aux (instr, aux))
| I_if (cval, instrs1, instrs2, ctyp) ->
List.iter (iter_instr f) instrs1;
List.iter (iter_instr f) instrs2
@@ -717,10 +790,10 @@ let rec map_instrs f (I_aux (instr, aux)) =
| I_decl _ | I_init _ | I_reset _ | I_reinit _ -> instr
| I_if (cval, instrs1, instrs2, ctyp) ->
I_if (cval, f (List.map (map_instrs f) instrs1), f (List.map (map_instrs f) instrs2), ctyp)
- | I_funcall _ | I_copy _ | I_alias _ | I_clear _ | I_jump _ | I_throw _ | I_return _ -> instr
+ | I_funcall _ | I_copy _ | I_clear _ | I_jump _ | I_throw _ | I_return _ -> instr
| I_block instrs -> I_block (f (List.map (map_instrs f) instrs))
| I_try_block instrs -> I_try_block (f (List.map (map_instrs f) instrs))
- | I_comment _ | I_label _ | I_goto _ | I_raw _ | I_match_failure | I_undefined _ | I_end -> instr
+ | I_comment _ | I_label _ | I_goto _ | I_raw _ | I_match_failure | I_undefined _ | I_end _ -> instr
in
I_aux (instr, aux)
@@ -732,7 +805,7 @@ let map_instrs_list f instrs =
let rec instr_ids (I_aux (instr, _)) =
let reads, writes = instr_deps instr in
- IdSet.union reads writes
+ NameSet.union reads writes
let rec instr_reads (I_aux (instr, _)) =
fst (instr_deps instr)
@@ -764,7 +837,6 @@ let cval_ctyp = function (_, ctyp) -> ctyp
let rec clexp_ctyp = function
| CL_id (_, ctyp) -> ctyp
- | CL_return ctyp -> ctyp
| CL_field (clexp, field) ->
begin match clexp_ctyp clexp with
| CT_struct (id, ctors) ->
@@ -788,8 +860,6 @@ let rec clexp_ctyp = function
end
| ctyp -> failwith ("Bad ctyp for CL_addr " ^ string_of_ctyp ctyp)
end
- | CL_have_exception -> CT_bool
- | CL_current_exception ctyp -> ctyp
| CL_void -> CT_unit
let rec instr_ctyps (I_aux (instr, aux)) =
@@ -805,13 +875,13 @@ let rec instr_ctyps (I_aux (instr, aux)) =
| I_funcall (clexp, _, _, cvals) ->
List.fold_left (fun m ctyp -> CTSet.add ctyp m) CTSet.empty (List.map cval_ctyp cvals)
|> CTSet.add (clexp_ctyp clexp)
- | I_copy (clexp, cval) | I_alias (clexp, cval) ->
+ | I_copy (clexp, cval) ->
CTSet.add (clexp_ctyp clexp) (CTSet.singleton (cval_ctyp cval))
| I_block instrs | I_try_block instrs ->
instrs_ctyps instrs
| I_throw cval | I_jump (cval, _) | I_return cval ->
CTSet.singleton (cval_ctyp cval)
- | I_comment _ | I_label _ | I_goto _ | I_raw _ | I_match_failure | I_end ->
+ | I_comment _ | I_label _ | I_goto _ | I_raw _ | I_match_failure | I_end _ ->
CTSet.empty
and instrs_ctyps instrs = List.fold_left CTSet.union CTSet.empty (List.map instr_ctyps instrs)
@@ -848,12 +918,12 @@ let instr_split_at f =
instr_split_at' f []
let rec instrs_rename from_id to_id =
- let rename id = if Id.compare id from_id = 0 then to_id else id in
+ let rename id = if Name.compare id from_id = 0 then to_id else id in
let crename = cval_rename from_id to_id in
let irename instrs = instrs_rename from_id to_id instrs in
let lrename = clexp_rename from_id to_id in
function
- | (I_aux (I_decl (ctyp, new_id), _) :: _) as instrs when Id.compare from_id new_id = 0 -> instrs
+ | (I_aux (I_decl (ctyp, new_id), _) :: _) as instrs when Name.compare from_id new_id = 0 -> instrs
| I_aux (I_decl (ctyp, new_id), aux) :: instrs -> I_aux (I_decl (ctyp, new_id), aux) :: irename instrs
| I_aux (I_reset (ctyp, id), aux) :: instrs -> I_aux (I_reset (ctyp, rename id), aux) :: irename instrs
| I_aux (I_init (ctyp, id, cval), aux) :: instrs -> I_aux (I_init (ctyp, rename id, crename cval), aux) :: irename instrs
@@ -861,14 +931,13 @@ let rec instrs_rename from_id to_id =
| I_aux (I_if (cval, then_instrs, else_instrs, ctyp), aux) :: instrs ->
I_aux (I_if (crename cval, irename then_instrs, irename else_instrs, ctyp), aux) :: irename instrs
| I_aux (I_jump (cval, label), aux) :: instrs -> I_aux (I_jump (crename cval, label), aux) :: irename instrs
- | I_aux (I_funcall (clexp, extern, id, cvals), aux) :: instrs ->
- I_aux (I_funcall (lrename clexp, extern, rename id, List.map crename cvals), aux) :: irename instrs
+ | I_aux (I_funcall (clexp, extern, function_id, cvals), aux) :: instrs ->
+ I_aux (I_funcall (lrename clexp, extern, function_id, List.map crename cvals), aux) :: irename instrs
| I_aux (I_copy (clexp, cval), aux) :: instrs -> I_aux (I_copy (lrename clexp, crename cval), aux) :: irename instrs
- | I_aux (I_alias (clexp, cval), aux) :: instrs -> I_aux (I_alias (lrename clexp, crename cval), aux) :: irename instrs
| I_aux (I_clear (ctyp, id), aux) :: instrs -> I_aux (I_clear (ctyp, rename id), aux) :: irename instrs
| I_aux (I_return cval, aux) :: instrs -> I_aux (I_return (crename cval), aux) :: irename instrs
| I_aux (I_block block, aux) :: instrs -> I_aux (I_block (irename block), aux) :: irename instrs
| I_aux (I_try_block block, aux) :: instrs -> I_aux (I_try_block (irename block), aux) :: irename instrs
| I_aux (I_throw cval, aux) :: instrs -> I_aux (I_throw (crename cval), aux) :: irename instrs
- | (I_aux ((I_comment _ | I_raw _ | I_end | I_label _ | I_goto _ | I_match_failure | I_undefined _), _) as instr) :: instrs -> instr :: irename instrs
+ | (I_aux ((I_comment _ | I_raw _ | I_end _ | I_label _ | I_goto _ | I_match_failure | I_undefined _), _) as instr) :: instrs -> instr :: irename instrs
| [] -> []
diff --git a/src/latex.ml b/src/latex.ml
index 1806da47..aa786b83 100644
--- a/src/latex.ml
+++ b/src/latex.ml
@@ -300,7 +300,7 @@ let rec read_lines in_chan = function
l :: ls
let latex_loc no_loc l =
- match simp_loc l with
+ match Reporting.simp_loc l with
| Some (p1, p2) ->
begin
let open Lexing in
diff --git a/src/monomorphise.ml b/src/monomorphise.ml
index 9f82bb17..6c82fc72 100644
--- a/src/monomorphise.ml
+++ b/src/monomorphise.ml
@@ -339,7 +339,7 @@ let split_src_type id ty (TypQ_aux (q,ql)) =
in
let wrap = match id with
| Id_aux (Id i,l) -> (fun f -> Id_aux (Id (f i),Generated l))
- | Id_aux (DeIid i,l) -> (fun f -> Id_aux (DeIid (f i),l))
+ | Id_aux (Operator i,l) -> (fun f -> Id_aux (Operator (f i),l))
in
let name_seg = function
| (_,None) -> ""
@@ -442,7 +442,7 @@ let freshen_id =
let () = counter := n + 1 in
match id with
| Id_aux (Id x, l) -> Id_aux (Id (x ^ "#m" ^ string_of_int n),Generated l)
- | Id_aux (DeIid x, l) -> Id_aux (DeIid (x ^ "#m" ^ string_of_int n),Generated l)
+ | Id_aux (Operator x, l) -> Id_aux (Operator (x ^ "#m" ^ string_of_int n),Generated l)
(* TODO: only freshen bindings that might be shadowed *)
let rec freshen_pat_bindings p =
@@ -690,13 +690,19 @@ let split_defs all_errors splits defs =
| Typ_app (Id_aux (Id "vector",_), [A_aux (A_nexp len,_);_;A_aux (A_typ (Typ_aux (Typ_id (Id_aux (Id "bit",_)),_)),_)]) ->
(match len with
- | Nexp_aux (Nexp_constant sz,_) ->
- let lits = make_vectors (Big_int.to_int sz) in
- List.map (fun lit ->
- P_aux (P_lit lit,(l,annot)),
- [var,E_aux (E_lit lit,(new_l,annot))],[],[]) lits
+ | Nexp_aux (Nexp_constant sz,_) when Big_int.greater_equal sz Big_int.zero ->
+ let sz = Big_int.to_int sz in
+ let num_lits = Big_int.pow_int (Big_int.of_int 2) sz in
+ (* Check that split size is within limits before generating the list of literals *)
+ if (Big_int.less_equal num_lits (Big_int.of_int size_set_limit)) then
+ let lits = make_vectors sz in
+ List.map (fun lit ->
+ P_aux (P_lit lit,(l,annot)),
+ [var,E_aux (E_lit lit,(new_l,annot))],[],[]) lits
+ else
+ cannot ("bitvector length outside limit, " ^ string_of_nexp len)
| _ ->
- cannot ("length not constant, " ^ string_of_nexp len)
+ cannot ("length not constant and positive, " ^ string_of_nexp len)
)
(* set constrained numbers *)
| Typ_app (Id_aux (Id "atom",_), [A_aux (A_nexp (Nexp_aux (value,_) as nexp),_)]) ->
@@ -740,7 +746,7 @@ let split_defs all_errors splits defs =
let split_pat vars p =
let id_match = function
| Id_aux (Id x,_) -> (try Some (List.assoc x vars) with Not_found -> None)
- | Id_aux (DeIid x,_) -> (try Some (List.assoc x vars) with Not_found -> None)
+ | Id_aux (Operator x,_) -> (try Some (List.assoc x vars) with Not_found -> None)
in
let rec list f = function
@@ -1289,6 +1295,11 @@ let rewrite_size_parameters env (Defs defs) =
let pat,guard,exp,pannot = destruct_pexp pexp in
let env = env_of_annot (l,ann) in
let _, typ = Env.get_val_spec_orig id env in
+ let already_visible_nexps =
+ NexpSet.union
+ (Pretty_print_lem.lem_nexps_of_typ typ)
+ (Pretty_print_lem.typeclass_nexps typ)
+ in
let types = match typ with
| Typ_aux (Typ_fn (arg_typs,_,_),_) -> List.map (Env.expand_synonyms env) arg_typs
| _ -> raise (Reporting.err_unreachable l __POS__ "Function clause does not have a function type")
@@ -1299,11 +1310,14 @@ let rewrite_size_parameters env (Defs defs) =
Typ_aux (Typ_app(Id_aux (Id "range",_),
[A_aux (A_nexp nexp,_);
A_aux (A_nexp nexp',_)]),_)
- when Nexp.compare nexp nexp' = 0 && not (NexpMap.mem nexp nmap) ->
- NexpMap.add nexp i nmap
+ when Nexp.compare nexp nexp' = 0 && not (NexpMap.mem nexp nmap) &&
+ not (NexpSet.mem nexp already_visible_nexps) ->
+ (* Split integer variables if the nexp is not already available via a bitvector length *)
+ NexpMap.add nexp i nmap
| Typ_aux (Typ_app(Id_aux (Id "atom", _),
[A_aux (A_nexp nexp,_)]), _)
- when not (NexpMap.mem nexp nmap) ->
+ when not (NexpMap.mem nexp nmap) &&
+ not (NexpSet.mem nexp already_visible_nexps) ->
NexpMap.add nexp i nmap
| _ -> nmap
in (i+1,nmap)
@@ -2172,6 +2186,11 @@ let rec analyse_exp fn_id env assigns (E_aux (e,(l,annot)) as exp) =
| E_constraint nc ->
(deps_of_nc env.kid_deps nc, assigns, empty)
in
+ let deps =
+ match destruct_atom_bool (env_of exp) (typ_of exp) with
+ | Some nc -> dmerge deps (deps_of_nc env.kid_deps nc)
+ | None -> deps
+ in
let r =
(* Check for bitvector types with parametrised sizes *)
match destruct_tannot annot with
@@ -2450,11 +2469,14 @@ let rec sets_from_assert e =
| None -> KBindings.empty)
| _ -> KBindings.empty
in
- match e with
- | E_aux (E_app (Id_aux (Id "and_bool",_),[e1;e2]),_) ->
- merge_set_asserts_by_kid (sets_from_assert e1) (sets_from_assert e2)
- | E_aux (E_constraint nc,_) -> sets_from_nc nc
- | _ -> set_from_or_exps e
+ match destruct_atom_bool (env_of e) (typ_of e) with
+ | Some nc -> sets_from_nc nc
+ | None ->
+ match e with
+ | E_aux (E_app (Id_aux (Id "and_bool",_),[e1;e2]),_) ->
+ merge_set_asserts_by_kid (sets_from_assert e1) (sets_from_assert e2)
+ | E_aux (E_constraint nc,_) -> sets_from_nc nc
+ | _ -> set_from_or_exps e
(* Find all the easily reached set assertions in a function body, to use as
case splits. Note that this should be mirrored in stop_at_false_assertions,
@@ -2670,12 +2692,17 @@ let rec rewrite_app env typ (id,args) =
let is_append = is_id env (Id "append") in
let is_subrange = is_id env (Id "vector_subrange") in
let is_slice = is_id env (Id "slice") in
- let is_zeros = is_id env (Id "Zeros") in
+ let is_zeros id =
+ is_id env (Id "Zeros") id || is_id env (Id "zeros") id ||
+ is_id env (Id "sail_zeros") id
+ in
+ let is_ones = is_id env (Id "Ones") in
let is_zero_extend =
is_id env (Id "ZeroExtend") id ||
is_id env (Id "zero_extend") id || is_id env (Id "sail_zero_extend") id ||
is_id env (Id "mips_zero_extend") id
in
+ let is_truncate = is_id env (Id "truncate") id in
let mk_exp e = E_aux (e, (Unknown, empty_tannot)) in
let try_cast_to_typ (E_aux (e,(l, _)) as exp) =
let (size,order,bittyp) = vector_typ_args_of (Env.base_typ_of env typ) in
@@ -2777,6 +2804,17 @@ let rec rewrite_app env typ (id,args) =
(E_aux (E_app (mk_id "slice_slice_concat",
[vector1; start1; length1; vector2; start2; length2]),(Unknown,empty_tannot)))
+ (* variable-slice @ local-var *)
+ | [E_aux (E_app (slice1,
+ [vector1; start1; length1]),_);
+ (E_aux (E_id _,_) as vector2)]
+ when is_slice slice1 && not (is_constant length1) ->
+ let start2 = mk_exp (E_lit (mk_lit (L_num Big_int.zero))) in
+ let length2 = mk_exp (E_app (mk_id "length", [vector2])) in
+ try_cast_to_typ
+ (E_aux (E_app (mk_id "slice_slice_concat",
+ [vector1; start1; length1; vector2; start2; length2]),(Unknown,empty_tannot)))
+
| [E_aux (E_app (append1,
[e1;
E_aux (E_app (slice1, [vector1; start1; length1]),_)]),_);
@@ -2805,13 +2843,24 @@ let rec rewrite_app env typ (id,args) =
[vector1; start1; length1; length2]),(Unknown,empty_tannot))]),
(Unknown,empty_tannot)))
end
+
+ (* known-length @ (known-length @ var-length) *)
+ | [e1; E_aux (E_app (append1, [e2; e3]), _)]
+ when is_append append1 && is_constant_vec_typ env (typ_of e1) &&
+ is_constant_vec_typ env (typ_of e2) &&
+ not (is_constant_vec_typ env (typ_of e3)) ->
+ let (size1,order,bittyp) = vector_typ_args_of (Env.base_typ_of env (typ_of e1)) in
+ let (size2,_,_) = vector_typ_args_of (Env.base_typ_of env (typ_of e2)) in
+ let size12 = nexp_simp (nsum size1 size2) in
+ let tannot12 = mk_tannot env (vector_typ size12 order bittyp) no_effect in
+ E_app (id, [E_aux (E_app (append1, [e1; e2]), (Unknown, tannot12)); e3])
+
| _ -> E_app (id,args)
- else if is_id env (Id "eq_vec") id || is_id env (Id "neq_vec") id then
+ else if is_id env (Id "eq_bits") id || is_id env (Id "neq_bits") id then
(* variable-range == variable_range *)
- let is_subrange = is_id env (Id "vector_subrange") in
let wrap e =
- if is_id env (Id "neq_vec") id
+ if is_id env (Id "neq_bits") id
then E_app (mk_id "not_bool", [mk_exp e])
else e
in
@@ -2867,11 +2916,7 @@ let rec rewrite_app env typ (id,args) =
E_app (mk_id "is_ones_slice", [vector1; start1; len1])
| _ -> E_app (id,args)
- else if is_zero_extend then
- let is_subrange = is_id env (Id "vector_subrange") in
- let is_slice = is_id env (Id "slice") in
- let is_zeros = is_id env (Id "Zeros") in
- let is_ones = is_id env (Id "Ones") in
+ else if is_zero_extend || is_truncate then
let length_arg = List.filter (fun arg -> is_number (typ_of arg)) args in
match List.filter (fun arg -> not (is_number (typ_of arg))) args with
| [E_aux (E_app (append1,
@@ -2881,10 +2926,18 @@ let rec rewrite_app env typ (id,args) =
-> try_cast_to_typ (rewrap (E_app (mk_id "place_subrange", length_arg @ [vector1; start1; end1; len1])))
| [E_aux (E_app (append1,
- [E_aux (E_app (slice1, [vector1; start1; length1]), _);
+ [vector1;
E_aux (E_app (zeros1, [length2]),_)]),_)]
- when is_slice slice1 && is_zeros zeros1 && is_append append1
- -> try_cast_to_typ (rewrap (E_app (mk_id "place_slice", length_arg @ [vector1; start1; length1; length2])))
+ when is_constant_vec_typ env (typ_of vector1) && is_zeros zeros1 && is_append append1
+ -> let (vector1, start1, length1) =
+ match vector1 with
+ | E_aux (E_app (slice1, [vector1; start1; length1]), _) ->
+ (vector1, start1, length1)
+ | _ ->
+ let (length1,_,_) = vector_typ_args_of (Env.base_typ_of env (typ_of vector1)) in
+ (vector1, mk_exp (E_lit (mk_lit (L_num (Big_int.zero)))), mk_exp (E_sizeof length1))
+ in
+ try_cast_to_typ (rewrap (E_app (mk_id "place_slice", length_arg @ [vector1; start1; length1; length2])))
(* If we've already rewritten to slice_slice_concat or subrange_subrange_concat,
we can just drop the zero extension because those functions can do it
@@ -2902,10 +2955,19 @@ let rec rewrite_app env typ (id,args) =
| [E_aux (E_app (ones, [len1]),_)] when is_ones ones ->
try_cast_to_typ (rewrap (E_app (mk_id "zext_ones", length_arg @ [len1])))
+ | [E_aux (E_app (replicate_bits, [E_aux (E_lit (L_aux (L_bin "1", _)), _); len1]), _)]
+ when is_id env (Id "replicate_bits") replicate_bits ->
+ let start1 = mk_exp (E_lit (mk_lit (L_num Big_int.zero))) in
+ try_cast_to_typ (rewrap (E_app (mk_id "slice_mask", length_arg @ [start1; len1])))
+
+ | [E_aux (E_app (zeros, [len1]),_)]
+ | [E_aux (E_cast (_, E_aux (E_app (zeros, [len1]),_)), _)]
+ when is_zeros zeros ->
+ try_cast_to_typ (rewrap (E_app (id, length_arg)))
+
| _ -> E_app (id,args)
else if is_id env (Id "SignExtend") id || is_id env (Id "sign_extend") id then
- let is_slice = is_id env (Id "slice") in
let length_arg = List.filter (fun arg -> is_number (typ_of arg)) args in
match List.filter (fun arg -> not (is_number (typ_of arg))) args with
| [E_aux (E_app (slice1, [vector1; start1; length1]),_)]
@@ -2947,8 +3009,6 @@ let rec rewrite_app env typ (id,args) =
| _ -> E_app (id, args)
else if is_id env (Id "UInt") id || is_id env (Id "unsigned") id then
- let is_slice = is_id env (Id "slice") in
- let is_subrange = is_id env (Id "vector_subrange") in
match args with
| [E_aux (E_app (slice1, [vector1; start1; length1]),_)]
when is_slice slice1 && not (is_constant length1) ->
@@ -3032,7 +3092,7 @@ let check_for_spec env name =
(* These functions add cast functions across case splits, so that when a
bitvector size becomes known in sail, the generated Lem code contains a
function call to change mword 'n to (say) mword ty16, and vice versa. *)
-let make_bitvector_cast_fns cast_name env quant_kids src_typ target_typ =
+let make_bitvector_cast_fns cast_name top_env env quant_kids src_typ target_typ =
let genunk = Generated Unknown in
let fresh =
let counter = ref 0 in
@@ -3056,7 +3116,7 @@ let make_bitvector_cast_fns cast_name env quant_kids src_typ target_typ =
Typ_app (Id_aux (Id "vector",_) as t_id,
[A_aux (A_nexp size',l_size'); t_ord;
A_aux (A_typ (Typ_aux (Typ_id (Id_aux (Id "bit",_)),_)),_) as t_bit]) -> begin
- match simplify_size_nexp env quant_kids size, simplify_size_nexp env quant_kids size' with
+ match simplify_size_nexp env quant_kids size, simplify_size_nexp top_env quant_kids size' with
| Some size, Some size' when Nexp.compare size size' <> 0 ->
let var = fresh () in
let tar_typ' = Typ_aux (Typ_app (t_id, [A_aux (A_nexp size',l_size');t_ord;t_bit]),
@@ -3112,7 +3172,7 @@ let make_bitvector_cast_fns cast_name env quant_kids src_typ target_typ =
(* TODO: bound vars *)
let make_bitvector_env_casts env quant_kids (kid,i) exp =
- let mk_cast var typ exp = (fst (make_bitvector_cast_fns "bitvector_cast_in" env quant_kids typ (subst_kids_typ (KBindings.singleton kid (nconstant i)) typ))) var exp in
+ let mk_cast var typ exp = (fst (make_bitvector_cast_fns "bitvector_cast_in" env env quant_kids typ (subst_kids_typ (KBindings.singleton kid (nconstant i)) typ))) var exp in
let locals = Env.get_locals env in
Bindings.fold (fun var (mut,typ) exp ->
if mut = Immutable then mk_cast var typ exp else exp) locals exp
@@ -3157,7 +3217,7 @@ let make_bitvector_cast_exp cast_name cast_env quant_kids typ target_typ exp =
let tgt_arg_typ = infer_arg_typ (env_of exp) f l target_typ in
E_aux (E_app (f,[aux arg (src_arg_typ, tgt_arg_typ)]),(l,ann))
| _ ->
- (snd (make_bitvector_cast_fns cast_name cast_env quant_kids typ target_typ)) exp
+ (snd (make_bitvector_cast_fns cast_name cast_env (env_of exp) quant_kids typ target_typ)) exp
in
aux exp (typ, target_typ)
@@ -3287,9 +3347,10 @@ let add_bitvector_casts (Defs defs) =
{ id_exp_alg with
e_aux = rewrite_aux } exp
in
- let rewrite_funcl (FCL_aux (FCL_Funcl (id,pexp),fcl_ann)) =
+ let rewrite_funcl (FCL_aux (FCL_Funcl (id,pexp),((l,_) as fcl_ann))) =
let fcl_env = env_of_annot fcl_ann in
let (tq,typ) = Env.get_val_spec_orig id fcl_env in
+ let fun_env = add_typquant l tq fcl_env in
let quant_kids = List.map kopt_kid (List.filter is_int_kopt (quant_kopts tq)) in
let ret_typ =
match typ with
@@ -3300,11 +3361,10 @@ let add_bitvector_casts (Defs defs) =
" is not a function type"))
in
let pat,guard,body,annot = destruct_pexp pexp in
- let body_env = env_of body in
- let body = rewrite_body id quant_kids body_env ret_typ body in
+ let body = rewrite_body id quant_kids fun_env ret_typ body in
(* Also add a cast around the entire function clause body, if necessary *)
let body =
- make_bitvector_cast_exp "bitvector_cast_out" fcl_env quant_kids (fill_in_type body_env (typ_of body)) ret_typ body
+ make_bitvector_cast_exp "bitvector_cast_out" fun_env quant_kids (fill_in_type (env_of body) (typ_of body)) ret_typ body
in
let pexp = construct_pexp (pat,guard,body,annot) in
FCL_aux (FCL_Funcl (id,pexp),fcl_ann)
@@ -3469,7 +3529,7 @@ let rewrite_toplevel_nexps (Defs defs) =
in
(* Changing types in the body confuses simple sizeof rewriting, so turn it
off for now *)
- (* let rewrite_typ_in_body env nexp_map typ =
+ let rewrite_typ_in_body env nexp_map typ =
let rec aux (Typ_aux (t,l) as typ_full) =
match t with
| Typ_tup typs -> Typ_aux (Typ_tup (List.map aux typs),l)
@@ -3514,10 +3574,17 @@ let rewrite_toplevel_nexps (Defs defs) =
| P_typ (typ,p') -> P_aux (P_typ (rewrite_typ_in_body (env_of_annot ann) nexp_map typ,p'),ann)
| _ -> P_aux (p,ann)
in
+ let rewrite_one_lexp nexp_map (lexp, ann) =
+ match lexp with
+ | LEXP_cast (typ, id) ->
+ LEXP_aux (LEXP_cast (rewrite_typ_in_body (env_of_annot ann) nexp_map typ, id), ann)
+ | _ -> LEXP_aux (lexp, ann)
+ in
let rewrite_body nexp_map pexp =
let open Rewriter in
fold_pexp { id_exp_alg with
e_aux = rewrite_one_exp nexp_map;
+ lEXP_aux = rewrite_one_lexp nexp_map;
pat_alg = { id_pat_alg with p_aux = rewrite_one_pat nexp_map }
} pexp
in
@@ -3525,25 +3592,29 @@ let rewrite_toplevel_nexps (Defs defs) =
match Bindings.find id spec_map with
| nexp_map -> FCL_aux (FCL_Funcl (id,rewrite_body nexp_map pexp),ann)
| exception Not_found -> funcl
- in *)
+ in
let rewrite_def spec_map def =
match def with
| DEF_spec vs -> (match rewrite_valspec vs with
| None -> spec_map, def
| Some (id, nexp_map, vs) -> Bindings.add id nexp_map spec_map, DEF_spec vs)
- (* | DEF_fundef (FD_aux (FD_function (recopt,_,eff,funcls),ann)) ->
+ | DEF_fundef (FD_aux (FD_function (recopt,_,eff,funcls),ann)) ->
(* Type annotations on function definitions will have been turned into
valspecs by type checking, so it should be safe to drop them rather
than updating them. *)
let tann = Typ_annot_opt_aux (Typ_annot_opt_none,Generated Unknown) in
spec_map,
- DEF_fundef (FD_aux (FD_function (recopt,tann,eff,List.map (rewrite_funcl spec_map) funcls),ann)) *)
+ DEF_fundef (FD_aux (FD_function (recopt,tann,eff,List.map (rewrite_funcl spec_map) funcls),ann))
| _ -> spec_map, def
in
let _, defs = List.fold_left (fun (spec_map,t) def ->
let spec_map, def = rewrite_def spec_map def in
(spec_map, def::t)) (Bindings.empty, []) defs
- in Defs (List.rev defs)
+ in
+ (* Allow use of div and mod in nexp rewriting during later typechecking passes
+ to help prove equivalences such as (8 * 'n) = 'p8_times_n# *)
+ Type_check.opt_smt_div := true;
+ Defs (List.rev defs)
type options = {
auto : bool;
diff --git a/src/ocaml_backend.ml b/src/ocaml_backend.ml
index 27b5b16e..cc1afaac 100644
--- a/src/ocaml_backend.ml
+++ b/src/ocaml_backend.ml
@@ -463,6 +463,7 @@ let ocaml_funcls ctx =
match Bindings.find id ctx.val_specs with
| Typ_aux (Typ_fn (typs, typ, _), _) -> (typs, typ)
| _ -> failwith "Found val spec which was not a function!"
+ | exception Not_found -> failwith ("No val spec found for " ^ string_of_id id)
in
(* Any remaining type variables after simple_typ rewrite should
indicate Type-polymorphism. If we have it, we need to generate
@@ -578,7 +579,7 @@ let ocaml_string_of_struct ctx id typq fields =
let ocaml_field (typ, id) =
separate space [string (string_of_id id ^ " = \""); string "^"; ocaml_string_typ typ (arg ^^ string "." ^^ zencode ctx id)]
in
- separate space [string "let"; ocaml_string_of id; parens (arg ^^ space ^^ colon ^^ space ^^ zencode ctx id); equals]
+ separate space [string "let"; ocaml_string_of id; parens (arg ^^ space ^^ colon ^^ space ^^ ocaml_typquant typq ^^ space ^^ zencode ctx id); equals]
^//^ (string "\"{" ^^ separate_map (hardline ^^ string "^ \", ") ocaml_field fields ^^ string " ^ \"}\"")
let ocaml_string_of_abbrev ctx id typq typ =
diff --git a/src/parse_ast.ml b/src/parse_ast.ml
index b86d4dd5..fcf921b7 100644
--- a/src/parse_ast.ml
+++ b/src/parse_ast.ml
@@ -109,7 +109,7 @@ kid_aux = (* identifiers with kind, ticked to differntiate from program variabl
type
id_aux = (* Identifier *)
Id of x
- | DeIid of x (* remove infix status *)
+ | Operator of x (* remove infix status *)
type
base_effect =
diff --git a/src/parser.mly b/src/parser.mly
index 9f7e2e0c..1c7d1580 100644
--- a/src/parser.mly
+++ b/src/parser.mly
@@ -69,7 +69,7 @@ let cons_fst h (t,x) = (h::t,x)
let string_of_id = function
| Id_aux (Id str, _) -> str
- | Id_aux (DeIid str, _) -> str
+ | Id_aux (Operator str, _) -> str
let prepend_id str1 = function
| Id_aux (Id str2, loc) -> Id_aux (Id (str1 ^ str2), loc)
@@ -84,8 +84,8 @@ let id_of_kid = function
| Kid_aux (Var v, l) -> Id_aux (Id (String.sub v 1 (String.length v - 1)), l)
let deinfix = function
- | (Id_aux (Id v, l)) -> Id_aux (DeIid v, l)
- | (Id_aux (DeIid v, l)) -> Id_aux (Id v, l)
+ | (Id_aux (Id v, l)) -> Id_aux (Operator v, l)
+ | (Id_aux (Operator v, l)) -> Id_aux (Id v, l)
let mk_effect e n m = BE_aux (e, loc n m)
let mk_typ t n m = ATyp_aux (t, loc n m)
@@ -142,7 +142,7 @@ type lchain =
| LC_lteq
| LC_nexp of atyp
-let tyop op t1 t2 s e = mk_typ (ATyp_app (Id_aux (DeIid op, loc s e), [t1; t2])) s e
+let tyop op t1 t2 s e = mk_typ (ATyp_app (Id_aux (Operator op, loc s e), [t1; t2])) s e
let rec desugar_lchain chain s e =
match chain with
@@ -230,51 +230,51 @@ let rec desugar_rchain chain s e =
id:
| Id { mk_id (Id $1) $startpos $endpos }
- | Op Op0 { mk_id (DeIid $2) $startpos $endpos }
- | Op Op1 { mk_id (DeIid $2) $startpos $endpos }
- | Op Op2 { mk_id (DeIid $2) $startpos $endpos }
- | Op Op3 { mk_id (DeIid $2) $startpos $endpos }
- | Op Op4 { mk_id (DeIid $2) $startpos $endpos }
- | Op Op5 { mk_id (DeIid $2) $startpos $endpos }
- | Op Op6 { mk_id (DeIid $2) $startpos $endpos }
- | Op Op7 { mk_id (DeIid $2) $startpos $endpos }
- | Op Op8 { mk_id (DeIid $2) $startpos $endpos }
- | Op Op9 { mk_id (DeIid $2) $startpos $endpos }
-
- | Op Op0l { mk_id (DeIid $2) $startpos $endpos }
- | Op Op1l { mk_id (DeIid $2) $startpos $endpos }
- | Op Op2l { mk_id (DeIid $2) $startpos $endpos }
- | Op Op3l { mk_id (DeIid $2) $startpos $endpos }
- | Op Op4l { mk_id (DeIid $2) $startpos $endpos }
- | Op Op5l { mk_id (DeIid $2) $startpos $endpos }
- | Op Op6l { mk_id (DeIid $2) $startpos $endpos }
- | Op Op7l { mk_id (DeIid $2) $startpos $endpos }
- | Op Op8l { mk_id (DeIid $2) $startpos $endpos }
- | Op Op9l { mk_id (DeIid $2) $startpos $endpos }
-
- | Op Op0r { mk_id (DeIid $2) $startpos $endpos }
- | Op Op1r { mk_id (DeIid $2) $startpos $endpos }
- | Op Op2r { mk_id (DeIid $2) $startpos $endpos }
- | Op Op3r { mk_id (DeIid $2) $startpos $endpos }
- | Op Op4r { mk_id (DeIid $2) $startpos $endpos }
- | Op Op5r { mk_id (DeIid $2) $startpos $endpos }
- | Op Op6r { mk_id (DeIid $2) $startpos $endpos }
- | Op Op7r { mk_id (DeIid $2) $startpos $endpos }
- | Op Op8r { mk_id (DeIid $2) $startpos $endpos }
- | Op Op9r { mk_id (DeIid $2) $startpos $endpos }
-
- | Op Plus { mk_id (DeIid "+") $startpos $endpos }
- | Op Minus { mk_id (DeIid "-") $startpos $endpos }
- | Op Star { mk_id (DeIid "*") $startpos $endpos }
- | Op EqEq { mk_id (DeIid "==") $startpos $endpos }
- | Op ExclEq { mk_id (DeIid "!=") $startpos $endpos }
- | Op Lt { mk_id (DeIid "<") $startpos $endpos }
- | Op Gt { mk_id (DeIid ">") $startpos $endpos }
- | Op LtEq { mk_id (DeIid "<=") $startpos $endpos }
- | Op GtEq { mk_id (DeIid ">=") $startpos $endpos }
- | Op Amp { mk_id (DeIid "&") $startpos $endpos }
- | Op Bar { mk_id (DeIid "|") $startpos $endpos }
- | Op Caret { mk_id (DeIid "^") $startpos $endpos }
+ | Op Op0 { mk_id (Operator $2) $startpos $endpos }
+ | Op Op1 { mk_id (Operator $2) $startpos $endpos }
+ | Op Op2 { mk_id (Operator $2) $startpos $endpos }
+ | Op Op3 { mk_id (Operator $2) $startpos $endpos }
+ | Op Op4 { mk_id (Operator $2) $startpos $endpos }
+ | Op Op5 { mk_id (Operator $2) $startpos $endpos }
+ | Op Op6 { mk_id (Operator $2) $startpos $endpos }
+ | Op Op7 { mk_id (Operator $2) $startpos $endpos }
+ | Op Op8 { mk_id (Operator $2) $startpos $endpos }
+ | Op Op9 { mk_id (Operator $2) $startpos $endpos }
+
+ | Op Op0l { mk_id (Operator $2) $startpos $endpos }
+ | Op Op1l { mk_id (Operator $2) $startpos $endpos }
+ | Op Op2l { mk_id (Operator $2) $startpos $endpos }
+ | Op Op3l { mk_id (Operator $2) $startpos $endpos }
+ | Op Op4l { mk_id (Operator $2) $startpos $endpos }
+ | Op Op5l { mk_id (Operator $2) $startpos $endpos }
+ | Op Op6l { mk_id (Operator $2) $startpos $endpos }
+ | Op Op7l { mk_id (Operator $2) $startpos $endpos }
+ | Op Op8l { mk_id (Operator $2) $startpos $endpos }
+ | Op Op9l { mk_id (Operator $2) $startpos $endpos }
+
+ | Op Op0r { mk_id (Operator $2) $startpos $endpos }
+ | Op Op1r { mk_id (Operator $2) $startpos $endpos }
+ | Op Op2r { mk_id (Operator $2) $startpos $endpos }
+ | Op Op3r { mk_id (Operator $2) $startpos $endpos }
+ | Op Op4r { mk_id (Operator $2) $startpos $endpos }
+ | Op Op5r { mk_id (Operator $2) $startpos $endpos }
+ | Op Op6r { mk_id (Operator $2) $startpos $endpos }
+ | Op Op7r { mk_id (Operator $2) $startpos $endpos }
+ | Op Op8r { mk_id (Operator $2) $startpos $endpos }
+ | Op Op9r { mk_id (Operator $2) $startpos $endpos }
+
+ | Op Plus { mk_id (Operator "+") $startpos $endpos }
+ | Op Minus { mk_id (Operator "-") $startpos $endpos }
+ | Op Star { mk_id (Operator "*") $startpos $endpos }
+ | Op EqEq { mk_id (Operator "==") $startpos $endpos }
+ | Op ExclEq { mk_id (Operator "!=") $startpos $endpos }
+ | Op Lt { mk_id (Operator "<") $startpos $endpos }
+ | Op Gt { mk_id (Operator ">") $startpos $endpos }
+ | Op LtEq { mk_id (Operator "<=") $startpos $endpos }
+ | Op GtEq { mk_id (Operator ">=") $startpos $endpos }
+ | Op Amp { mk_id (Operator "&") $startpos $endpos }
+ | Op Bar { mk_id (Operator "|") $startpos $endpos }
+ | Op Caret { mk_id (Operator "^") $startpos $endpos }
op0: Op0 { mk_id (Id $1) $startpos $endpos }
op1: Op1 { mk_id (Id $1) $startpos $endpos }
diff --git a/src/pretty_print_common.ml b/src/pretty_print_common.ml
index 3a1deed0..c1680878 100644
--- a/src/pretty_print_common.ml
+++ b/src/pretty_print_common.ml
@@ -76,25 +76,9 @@ let semi_sp = semi ^^ space
let comma_sp = comma ^^ space
let colon_sp = spaces colon
-let doc_var (Kid_aux(Var v,_)) = string v
let doc_int i = string (Big_int.to_string i)
let doc_op symb a b = infix 2 1 symb a b
let doc_unop symb a = prefix 2 1 symb a
-let doc_id (Id_aux(i,_)) =
- match i with
- | Id i -> string i
- | DeIid x ->
- (* add an extra space through empty to avoid a closing-comment
- * token in case of x ending with star. *)
- parens (separate space [string "deinfix"; string x; empty])
-
-(*
-let rec doc_range (BF_aux(r,_)) = match r with
- | BF_single i -> doc_int i
- | BF_range(i1,i2) -> doc_op dotdot (doc_int i1) (doc_int i2)
- | BF_concat(ir1,ir2) -> (doc_range ir1) ^^ comma ^^ (doc_range ir2)
-*)
-
let print ?(len=100) channel doc = ToChannel.pretty 1. len channel doc
let to_buf ?(len=100) buf doc = ToBuffer.pretty 1. len buf doc
diff --git a/src/pretty_print_coq.ml b/src/pretty_print_coq.ml
index ee83c89f..d80dabe9 100644
--- a/src/pretty_print_coq.ml
+++ b/src/pretty_print_coq.ml
@@ -81,8 +81,12 @@ let opt_debug_on : string list ref = ref []
type context = {
early_ret : bool;
- kid_renames : kid KBindings.t; (* Plain tyvar -> tyvar renames *)
- kid_id_renames : id KBindings.t; (* tyvar -> argument renames *)
+ kid_renames : kid KBindings.t; (* Plain tyvar -> tyvar renames,
+ used to avoid variable/type variable name clashes *)
+ (* Note that as well as these kid renames, we also attempt to replace entire
+ n_constraints with equivalent variables in doc_nc_prop and doc_nc_exp. *)
+ kid_id_renames : (id option) KBindings.t; (* tyvar -> argument renames *)
+ kid_id_renames_rev : kid Bindings.t; (* reverse of kid_id_renames *)
bound_nvars : KidSet.t;
build_at_return : string option;
recursive_ids : IdSet.t;
@@ -92,12 +96,24 @@ let empty_ctxt = {
early_ret = false;
kid_renames = KBindings.empty;
kid_id_renames = KBindings.empty;
+ kid_id_renames_rev = Bindings.empty;
bound_nvars = KidSet.empty;
build_at_return = None;
recursive_ids = IdSet.empty;
debug = false;
}
+let add_single_kid_id_rename ctxt id kid =
+ let kir =
+ match Bindings.find_opt id ctxt.kid_id_renames_rev with
+ | Some kid -> KBindings.add kid None ctxt.kid_id_renames
+ | None -> ctxt.kid_id_renames
+ in
+ { ctxt with
+ kid_id_renames = KBindings.add kid (Some id) kir;
+ kid_id_renames_rev = Bindings.add id kid ctxt.kid_id_renames_rev
+ }
+
let debug_depth = ref 0
let rec indent n = match n with
@@ -168,7 +184,7 @@ let rec fix_id remove_tick name = match name with
let string_id (Id_aux(i,_)) =
match i with
| Id i -> fix_id false i
- | DeIid x -> Util.zencode_string ("op " ^ x)
+ | Operator x -> Util.zencode_string ("op " ^ x)
let doc_id id = string (string_id id)
@@ -177,16 +193,17 @@ let doc_id_type (Id_aux(i,_)) =
| Id("int") -> string "Z"
| Id("real") -> string "R"
| Id i -> string (fix_id false i)
- | DeIid x -> string (Util.zencode_string ("op " ^ x))
+ | Operator x -> string (Util.zencode_string ("op " ^ x))
let doc_id_ctor (Id_aux(i,_)) =
match i with
| Id i -> string (fix_id false i)
- | DeIid x -> string (Util.zencode_string ("op " ^ x))
+ | Operator x -> string (Util.zencode_string ("op " ^ x))
let doc_var ctx kid =
match KBindings.find kid ctx.kid_id_renames with
- | id -> doc_id id
+ | Some id -> doc_id id
+ | None -> underscore (* The original id has been shadowed, hope Coq can work it out... TODO: warn? *)
| exception Not_found ->
string (fix_id true (string_of_kid (try KBindings.find kid ctx.kid_renames with Not_found -> kid)))
@@ -371,24 +388,51 @@ let doc_nc_fn id =
| "not" -> string "negb"
| s -> string s
-let merge_bool_count = KBindings.union (fun _ m n -> Some (m+n))
+let merge_kid_count = KBindings.union (fun _ m n -> Some (m+n))
-let rec count_bool_vars (NC_aux (nc,_)) =
+let rec count_nexp_vars (Nexp_aux (nexp,_)) =
+ match nexp with
+ | Nexp_id _
+ | Nexp_constant _
+ -> KBindings.empty
+ | Nexp_var kid -> KBindings.singleton kid 1
+ | Nexp_app (_,nes) ->
+ List.fold_left merge_kid_count KBindings.empty (List.map count_nexp_vars nes)
+ | Nexp_times (n1,n2)
+ | Nexp_sum (n1,n2)
+ | Nexp_minus (n1,n2)
+ -> merge_kid_count (count_nexp_vars n1) (count_nexp_vars n2)
+ | Nexp_exp n
+ | Nexp_neg n
+ -> count_nexp_vars n
+
+let rec count_nc_vars (NC_aux (nc,_)) =
let count_arg (A_aux (arg,_)) =
match arg with
- | A_bool nc -> count_bool_vars nc
- | A_nexp _ | A_typ _ | A_order _ -> KBindings.empty
+ | A_bool nc -> count_nc_vars nc
+ | A_nexp nexp -> count_nexp_vars nexp
+ | A_typ _ | A_order _ -> KBindings.empty
in
match nc with
| NC_or (nc1,nc2)
| NC_and (nc1,nc2)
- -> merge_bool_count (count_bool_vars nc1) (count_bool_vars nc2)
- | NC_var kid -> KBindings.singleton kid 1
- | NC_equal _ | NC_bounded_ge _ | NC_bounded_le _ | NC_not_equal _
- | NC_set _ | NC_true | NC_false
+ -> merge_kid_count (count_nc_vars nc1) (count_nc_vars nc2)
+ | NC_var kid
+ | NC_set (kid,_)
+ -> KBindings.singleton kid 1
+ | NC_equal (n1,n2)
+ | NC_bounded_ge (n1,n2)
+ | NC_bounded_le (n1,n2)
+ | NC_not_equal (n1,n2)
+ -> merge_kid_count (count_nexp_vars n1) (count_nexp_vars n2)
+ | NC_true | NC_false
-> KBindings.empty
| NC_app (_,args) ->
- List.fold_left merge_bool_count KBindings.empty (List.map count_arg args)
+ List.fold_left merge_kid_count KBindings.empty (List.map count_arg args)
+
+(* Simplify some of the complex boolean types created by the Sail type checker,
+ whereever an existentially bound variable is used once in a trivial way,
+ for example exists b, b and exists n, n = 32. *)
type atom_bool_prop =
Bool_boring
@@ -398,13 +442,22 @@ let simplify_atom_bool l kopts nc atom_nc =
(*prerr_endline ("simplify " ^ string_of_n_constraint nc ^ " for bool " ^ string_of_n_constraint atom_nc);*)
let counter = ref 0 in
let is_bound kid = List.exists (fun kopt -> Kid.compare kid (kopt_kid kopt) == 0) kopts in
- let bool_vars = merge_bool_count (count_bool_vars nc) (count_bool_vars atom_nc) in
- let lin_bool_vars = KBindings.filter (fun kid n -> is_bound kid && n = 1) bool_vars in
+ let ty_vars = merge_kid_count (count_nc_vars nc) (count_nc_vars atom_nc) in
+ let lin_ty_vars = KBindings.filter (fun kid n -> is_bound kid && n = 1) ty_vars in
let rec simplify (NC_aux (nc,l) as nc_full) =
let is_ex_var news (NC_aux (nc,_)) =
match nc with
- | NC_var kid when KBindings.mem kid lin_bool_vars -> Some kid
+ | NC_var kid when KBindings.mem kid lin_ty_vars -> Some kid
| NC_var kid when KidSet.mem kid news -> Some kid
+ | NC_equal (Nexp_aux (Nexp_var kid,_), _) when KBindings.mem kid lin_ty_vars -> Some kid
+ | NC_equal (_, Nexp_aux (Nexp_var kid,_)) when KBindings.mem kid lin_ty_vars -> Some kid
+ | NC_bounded_ge (Nexp_aux (Nexp_var kid,_), _) when KBindings.mem kid lin_ty_vars -> Some kid
+ | NC_bounded_ge (_, Nexp_aux (Nexp_var kid,_)) when KBindings.mem kid lin_ty_vars -> Some kid
+ | NC_bounded_le (Nexp_aux (Nexp_var kid,_), _) when KBindings.mem kid lin_ty_vars -> Some kid
+ | NC_bounded_le (_, Nexp_aux (Nexp_var kid,_)) when KBindings.mem kid lin_ty_vars -> Some kid
+ | NC_not_equal (Nexp_aux (Nexp_var kid,_), _) when KBindings.mem kid lin_ty_vars -> Some kid
+ | NC_not_equal (_, Nexp_aux (Nexp_var kid,_)) when KBindings.mem kid lin_ty_vars -> Some kid
+ | NC_set (kid, _::_) when KBindings.mem kid lin_ty_vars -> Some kid
| _ -> None
in
let replace kills vars =
@@ -439,7 +492,10 @@ let simplify_atom_bool l kopts nc atom_nc =
(* We don't currently recurse into general uses of NC_app, but the
"boring" cases we really want to get rid of won't contain
those. *)
- | _ -> KidSet.empty, KidSet.empty, nc_full
+ | _ ->
+ match is_ex_var KidSet.empty nc_full with
+ | Some kid -> replace KidSet.empty [kid]
+ | None -> KidSet.empty, KidSet.empty, nc_full
in
let new_nc, kill_nc, nc = simplify nc in
let new_atom, kill_atom, atom_nc = simplify atom_nc in
@@ -451,13 +507,17 @@ let simplify_atom_bool l kopts nc atom_nc =
in
(*prerr_endline ("now have " ^ string_of_n_constraint nc ^ " for bool " ^ string_of_n_constraint atom_nc);*)
match atom_nc with
- | NC_aux (NC_var kid,_) when KBindings.mem kid lin_bool_vars -> Bool_boring
+ | NC_aux (NC_var kid,_) when KBindings.mem kid lin_ty_vars -> Bool_boring
| NC_aux (NC_var kid,_) when KidSet.mem kid new_kids -> Bool_boring
| _ -> Bool_complex (kopts, nc, atom_nc)
type ex_kind = ExNone | ExGeneral
+let string_of_ex_kind = function
+ | ExNone -> "none"
+ | ExGeneral -> "general"
+
(* Should a Sail type be turned into a dependent pair in Coq?
Optionally takes a variable that we're binding (to avoid trivial cases where
the type is exactly the boolean we're binding), and whether to turn bools
@@ -465,7 +525,7 @@ type ex_kind = ExNone | ExGeneral
let classify_ex_type ctxt env ?binding ?(rawbools=false) (Typ_aux (t,l) as t0) =
let is_binding kid =
match binding, KBindings.find_opt kid ctxt.kid_id_renames with
- | Some id, Some id' when Id.compare id id' == 0 -> true
+ | Some id, Some (Some id') when Id.compare id id' == 0 -> true
| _ -> false
in
let simplify_atom_bool l kopts nc atom_nc =
@@ -490,7 +550,7 @@ let classify_ex_type ctxt env ?binding ?(rawbools=false) (Typ_aux (t,l) as t0) =
| _ -> ExNone,[],t0
(* When making changes here, check whether they affect coq_nvars_of_typ *)
-let rec doc_typ_fns ctx =
+let rec doc_typ_fns ctx env =
(* following the structure of parser for precedence *)
let rec typ ty = fn_typ true ty
and typ' ty = fn_typ false ty
@@ -541,7 +601,7 @@ let rec doc_typ_fns ctx =
braces (separate space
[doc_var ctx var; colon; string "bool";
ampersand;
- doc_arithfact ctx nc])
+ doc_arithfact ctx env nc])
end
| Typ_app(id,args) ->
let tpp = (doc_id_type id) ^^ space ^^ (separate_map space doc_typ_arg args) in
@@ -573,12 +633,12 @@ let rec doc_typ_fns ctx =
begin match nexp, kopts with
| (Nexp_aux (Nexp_var kid,_)), [kopt] when Kid.compare kid (kopt_kid kopt) == 0 ->
braces (separate space [doc_var ctx kid; colon; string "Z";
- ampersand; doc_arithfact ctx nc])
+ ampersand; doc_arithfact ctx env nc])
| _ ->
let var = mk_kid "_atom" in (* TODO collision avoid *)
let nc = nice_and (nc_eq (nvar var) nexp) nc in
braces (separate space [doc_var ctx var; colon; string "Z";
- ampersand; doc_arithfact ctx ~exists:(List.map kopt_kid kopts) nc])
+ ampersand; doc_arithfact ctx env ~exists:(List.map kopt_kid kopts) nc])
end
| Typ_aux (Typ_app (Id_aux (Id "vector",_),
[A_aux (A_nexp m, _);
@@ -601,7 +661,7 @@ let rec doc_typ_fns ctx =
braces (separate space
[doc_var ctx var; colon; tpp;
ampersand;
- doc_arithfact ctx ~exists:(List.map kopt_kid kopts) ?extra:length_constraint_pp nc])
+ doc_arithfact ctx env ~exists:(List.map kopt_kid kopts) ?extra:length_constraint_pp nc])
| Typ_aux (Typ_app (Id_aux (Id "atom_bool",_), [A_aux (A_bool atom_nc,_)]),_) -> begin
match simplify_atom_bool l kopts nc atom_nc with
| Bool_boring -> string "bool"
@@ -611,7 +671,7 @@ let rec doc_typ_fns ctx =
braces (separate space
[doc_var ctx var; colon; string "bool";
ampersand;
- doc_arithfact ctx ~exists:(List.map kopt_kid kopts) nc])
+ doc_arithfact ctx env ~exists:(List.map kopt_kid kopts) nc])
end
| _ ->
raise (Reporting.err_todo l
@@ -642,14 +702,14 @@ let rec doc_typ_fns ctx =
| A_typ t -> app_typ true t
| A_nexp n -> doc_nexp ctx n
| A_order o -> empty
- | A_bool nc -> doc_nc_prop ~top:false ctx nc
+ | A_bool nc -> doc_nc_prop ~top:false ctx env nc
in typ', atomic_typ, doc_typ_arg
-and doc_typ ctx = let f,_,_ = doc_typ_fns ctx in f
-and doc_atomic_typ ctx = let _,f,_ = doc_typ_fns ctx in f
-and doc_typ_arg ctx = let _,_,f = doc_typ_fns ctx in f
+and doc_typ ctx env = let f,_,_ = doc_typ_fns ctx env in f
+and doc_atomic_typ ctx env = let _,f,_ = doc_typ_fns ctx env in f
+and doc_typ_arg ctx env = let _,_,f = doc_typ_fns ctx env in f
-and doc_arithfact ctxt ?(exists = []) ?extra nc =
- let prop = doc_nc_prop ctxt nc in
+and doc_arithfact ctxt env ?(exists = []) ?extra nc =
+ let prop = doc_nc_prop ctxt env nc in
let prop = match extra with
| None -> prop
| Some pp -> separate space [pp; string "/\\"; parens prop]
@@ -662,14 +722,28 @@ and doc_arithfact ctxt ?(exists = []) ?extra nc =
string "ArithFact" ^^ space ^^ parens prop
(* Follows Coq precedence levels *)
-and doc_nc_prop ?(top = true) ctx nc =
+and doc_nc_prop ?(top = true) ctx env nc =
+ let locals = Env.get_locals env |> Bindings.bindings in
+ let nc_id_map =
+ List.fold_left
+ (fun m (v,(_,Typ_aux (typ,_))) ->
+ match typ with
+ | Typ_app (id, [A_aux (A_bool nc,_)]) when string_of_id id = "atom_bool" ->
+ NCMap.add nc v m
+ | _ -> m) NCMap.empty locals
+ in
+ let newnc f nc =
+ match NCMap.find_opt nc nc_id_map with
+ | Some id -> parens (doc_op equals (doc_id id) (string "true"))
+ | None -> f nc
+ in
let rec l85 (NC_aux (nc,_) as nc_full) =
match nc with
- | NC_or (nc1, nc2) -> doc_op (string "\\/") (l80 nc1) (l85 nc2)
+ | NC_or (nc1, nc2) -> doc_op (string "\\/") (newnc l80 nc1) (newnc l85 nc2)
| _ -> l80 nc_full
and l80 (NC_aux (nc,_) as nc_full) =
match nc with
- | NC_and (nc1, nc2) -> doc_op (string "/\\") (l70 nc1) (l80 nc2)
+ | NC_and (nc1, nc2) -> doc_op (string "/\\") (newnc l70 nc1) (newnc l80 nc2)
| _ -> l70 nc_full
and l70 (NC_aux (nc,_) as nc_full) =
match nc with
@@ -685,7 +759,7 @@ and doc_nc_prop ?(top = true) ctx nc =
separate space [string "In"; doc_var ctx kid;
brackets (separate (string "; ")
(List.map (fun i -> string (Nat_big_num.to_string i)) is))]
- | NC_app (f,args) -> separate space (doc_nc_fn_prop f::List.map (doc_typ_arg ctx) args)
+ | NC_app (f,args) -> separate space (doc_nc_fn_prop f::List.map (doc_typ_arg ctx env) args)
| _ -> l0 nc_full
and l0 (NC_aux (nc,_) as nc_full) =
match nc with
@@ -700,10 +774,24 @@ and doc_nc_prop ?(top = true) ctx nc =
| NC_bounded_ge _
| NC_bounded_le _
| NC_not_equal _ -> parens (l85 nc_full)
- in if top then l85 nc else l0 nc
+ in if top then newnc l85 nc else newnc l0 nc
(* Follows Coq precedence levels *)
let rec doc_nc_exp ctx env nc =
+ let locals = Env.get_locals env |> Bindings.bindings in
+ let nc_id_map =
+ List.fold_left
+ (fun m (v,(_,Typ_aux (typ,_))) ->
+ match typ with
+ | Typ_app (id, [A_aux (A_bool nc,_)]) when string_of_id id = "atom_bool" ->
+ NCMap.add nc v m
+ | _ -> m) NCMap.empty locals
+ in
+ let newnc f nc =
+ match NCMap.find_opt nc nc_id_map with
+ | Some id -> doc_id id
+ | None -> f nc
+ in
let nc = Env.expand_constraint_synonyms env nc in
let rec l70 (NC_aux (nc,_) as nc_full) =
match nc with
@@ -713,11 +801,11 @@ let rec doc_nc_exp ctx env nc =
| _ -> l50 nc_full
and l50 (NC_aux (nc,_) as nc_full) =
match nc with
- | NC_or (nc1, nc2) -> doc_op (string "||") (l50 nc1) (l40 nc2)
+ | NC_or (nc1, nc2) -> doc_op (string "||") (newnc l50 nc1) (newnc l40 nc2)
| _ -> l40 nc_full
and l40 (NC_aux (nc,_) as nc_full) =
match nc with
- | NC_and (nc1, nc2) -> doc_op (string "&&") (l40 nc1) (l10 nc2)
+ | NC_and (nc1, nc2) -> doc_op (string "&&") (newnc l40 nc1) (newnc l10 nc2)
| _ -> l10 nc_full
and l10 (NC_aux (nc,_) as nc_full) =
match nc with
@@ -735,7 +823,7 @@ let rec doc_nc_exp ctx env nc =
| NC_bounded_le _
| NC_or _
| NC_and _ -> parens (l70 nc_full)
- in l70 nc
+ in newnc l70 nc
and doc_typ_arg_exp ctx env (A_aux (arg,l)) =
match arg with
| A_nexp nexp -> doc_nexp ctx nexp
@@ -769,7 +857,7 @@ let replace_typ_size ctxt env (Typ_aux (t,a)) =
let doc_tannot ctxt env eff typ =
let of_typ typ =
- let ta = doc_typ ctxt typ in
+ let ta = doc_typ ctxt env typ in
if eff then
if ctxt.early_ret
then string " : MR " ^^ parens ta ^^ string " _"
@@ -842,7 +930,7 @@ let quant_item_id_name ctx (QI_aux (qi,_)) =
let doc_quant_item_constr ctx delimit (QI_aux (qi,_)) =
match qi with
| QI_id _ -> None
- | QI_const nc -> Some (bquote ^^ braces (doc_arithfact ctx nc))
+ | QI_const nc -> Some (bquote ^^ braces (doc_arithfact ctx Env.empty nc))
(* At the moment these are all anonymous - when used we rely on Coq to fill
them in. *)
@@ -904,7 +992,7 @@ let rec typeclass_nexps (Typ_aux(t,l)) =
| Typ_internal_unknown -> unreachable l __POS__ "escaped Typ_internal_unknown"
let doc_typschm ctx quants (TypSchm_aux(TypSchm_ts(tq,t),_)) =
- let pt = doc_typ ctx t in
+ let pt = doc_typ ctx Env.empty t in
if quants then doc_typquant ctx tq pt else pt
let is_ctor env id = match Env.lookup_id id env with
@@ -1076,21 +1164,13 @@ let similar_nexps ctxt env n1 n2 =
| _ -> false
in if same_nexp_shape (nexp_const_eval n1) (nexp_const_eval n2) then true else false
-let constraint_fns = ["Z.leb"; "Z.geb"; "Z.ltb"; "Z.gtb"; "Z.eqb"; "neq_atom"]
+let constraint_fns = ["Z.leb"; "Z.geb"; "Z.ltb"; "Z.gtb"; "Z.eqb"; "neq_int"]
-let condition_produces_constraint exp =
- (* Cheat a little - this isn't quite the right environment for subexpressions
- but will have all of the relevant functions in it. *)
+let condition_produces_constraint ctxt exp =
let env = env_of exp in
- Rewriter.fold_exp
- { (Rewriter.pure_exp_alg false (||)) with
- Rewriter.e_app = fun (f,bs) ->
- List.exists (fun x -> x) bs ||
- (let name = if Env.is_extern f env "coq"
- then Env.get_extern f env "coq"
- else string_id f in
- List.exists (fun id -> String.compare name id == 0) constraint_fns)
- } exp
+ match classify_ex_type ctxt env ~rawbools:true (typ_of exp) with
+ | ExNone, _, _ -> false
+ | ExGeneral, _, _ -> true
(* For most functions whose return types are non-trivial atoms we return a
dependent pair with a proof that the result is the expected integer. This
@@ -1140,31 +1220,63 @@ let is_prefix s s' =
String.sub s' 0 l = s
let merge_new_tyvars ctxt old_env pat new_env =
- let is_new_binding id =
- match Env.lookup_id ~raw:true id old_env with
- | Unbound -> true
- | _ -> false
+ let remove_binding id (m,r) =
+ match Bindings.find_opt id r with
+ | Some kid ->
+ debug ctxt (lazy ("Removing " ^ string_of_kid kid ^ " to " ^ string_of_id id));
+ KBindings.add kid None m, Bindings.remove id r
+ | None -> m,r
+ in
+ let check_kid id kid (m,r) =
+ try
+ let _ = Env.get_typ_var kid old_env in
+ debug ctxt (lazy (" tyvar " ^ string_of_kid kid ^ " already in env"));
+ m,r
+ with _ ->
+ debug ctxt (lazy (" adding tyvar mapping " ^ string_of_kid kid ^ " to " ^ string_of_id id));
+ KBindings.add kid (Some id) m, Bindings.add id kid r
in
- let new_ids = IdSet.filter is_new_binding (pat_ids pat) in
let merge_new_kids id m =
let typ = lvar_typ (Env.lookup_id ~raw:true id new_env) in
debug ctxt (lazy (" considering tyvar mapping for " ^ string_of_id id ^ " at type " ^ string_of_typ typ ));
match destruct_numeric typ, destruct_atom_bool new_env typ with
| Some ([],_,Nexp_aux (Nexp_var kid,_)), _
- | _, Some (NC_aux (NC_var kid,_)) ->
- begin try
- let _ = Env.get_typ_var kid old_env in
- debug ctxt (lazy (" tyvar " ^ string_of_kid kid ^ " already in env"));
- m
- with _ ->
- debug ctxt (lazy (" adding tyvar mapping " ^ string_of_kid kid ^ " to " ^ string_of_id id));
- KBindings.add kid id m
- end
+ | _, Some (NC_aux (NC_var kid,_))
+ -> check_kid id kid m
| _ ->
debug ctxt (lazy (" not suitable type"));
m
in
- { ctxt with kid_id_renames = IdSet.fold merge_new_kids new_ids ctxt.kid_id_renames }
+ let rec merge_pat m (P_aux (p,(l,_))) =
+ match p with
+ | P_lit _ | P_wild
+ -> m
+ | P_not _ -> unreachable l __POS__ "Coq backend doesn't support not patterns"
+ | P_or _ -> unreachable l __POS__ "Coq backend doesn't support or patterns yet"
+ | P_typ (_,p) -> merge_pat m p
+ | P_as (p,id) -> merge_new_kids id (merge_pat m p)
+ | P_id id -> merge_new_kids id m
+ | P_var (p,ty_p) ->
+ begin match p, ty_p with
+ | _, TP_aux (TP_wild,_) -> merge_pat m p
+ | P_aux (P_id id,_), TP_aux (TP_var kid,_) -> check_kid id kid (merge_pat m p)
+ | _ -> merge_pat m p
+ end
+ (* Some of these don't make it through to the backend, but it's obvious what
+ they'd do *)
+ | P_app (_,ps)
+ | P_vector ps
+ | P_vector_concat ps
+ | P_tup ps
+ | P_list ps
+ | P_string_append ps
+ -> List.fold_left merge_pat m ps
+ | P_record (fps,_) -> unreachable l __POS__ "Coq backend doesn't support record patterns properly yet"
+ | P_cons (p1,p2) -> merge_pat (merge_pat m p1) p2
+ in
+ let m,r = IdSet.fold remove_binding (pat_ids pat) (ctxt.kid_id_renames, ctxt.kid_id_renames_rev) in
+ let m,r = merge_pat (m, r) pat in
+ { ctxt with kid_id_renames = m; kid_id_renames_rev = r }
let prefix_recordtype = true
let report = Reporting.err_unreachable
@@ -1317,6 +1429,7 @@ let doc_exp, doc_let =
raise (report l __POS__ "E_loop should have been rewritten before pretty-printing")
| E_let(leb,e) ->
let pat = match leb with LB_aux (LB_val (p,_),_) -> p in
+ let () = debug ctxt (lazy ("Let with pattern " ^ string_of_pat pat)) in
let new_ctxt = merge_new_tyvars ctxt (env_of_annot (l,annot)) pat (env_of e) in
let epp = let_exp ctxt leb ^^ space ^^ string "in" ^^ hardline ^^ top_exp new_ctxt false e in
if aexp_needed then parens epp else epp
@@ -1331,9 +1444,9 @@ let doc_exp, doc_let =
let epp = expY exp in
match is_auto_decomposed_exist ctxt (env_of exp) ~rawbools:true (general_typ_of exp) with
| Some _ ->
- if informative then parens (epp ^^ doc_tannot ctxt (env_of exp) true (general_typ_of exp)) else
- let proj = if effectful (effect_of exp) then "projT1_m" else "projT1" in
- parens (string proj ^/^ epp)
+ if informative
+ then parens (epp ^^ doc_tannot ctxt (env_of exp) true (general_typ_of exp))
+ else parens (string "projT1_m" ^/^ epp)
| None ->
if informative then parens (string "build_trivial_ex" ^/^ epp)
else epp
@@ -1362,6 +1475,7 @@ let doc_exp, doc_let =
in
let combinator = if effectful (effect_of body) then "foreach_ZM" else "foreach_Z" in
let combinator = combinator ^ dir in
+ let body_ctxt = add_single_kid_id_rename ctxt loopvar (mk_kid ("loop_" ^ string_of_id loopvar)) in
let used_vars_body = find_e_ids body in
let body_lambda =
(* Work around indentation issues in Lem when translating
@@ -1384,7 +1498,7 @@ let doc_exp, doc_let =
expY from_exp; expY to_exp; expY step_exp;
expY vartuple])
(parens
- (prefix 2 1 (group body_lambda) (expN body))
+ (prefix 2 1 (group body_lambda) (top_exp body_ctxt false body))
)
)
| _ -> raise (Reporting.err_unreachable l __POS__
@@ -1444,8 +1558,8 @@ let doc_exp, doc_let =
aexp_needed, epp
else
let tannot = separate space [string "MR";
- doc_atomic_typ ctxt false (typ_of full_exp);
- doc_atomic_typ ctxt false (typ_of exp)] in
+ doc_atomic_typ ctxt (env_of full_exp) false (typ_of full_exp);
+ doc_atomic_typ ctxt (env_of exp) false (typ_of exp)] in
true, doc_op colon epp tannot in
if aexp_needed then parens tepp else tepp
| _ -> raise (Reporting.err_unreachable l __POS__
@@ -1606,11 +1720,20 @@ let doc_exp, doc_let =
| _ -> false
in pack,unpack,autocast
in
+ let () =
+ debug ctxt (lazy (" packeff: " ^ string_of_bool packeff ^
+ " unpack: " ^ string_of_bool unpack ^
+ " autocast: " ^ string_of_bool autocast))
+ in
let autocast_id, proj_id =
if effectful eff
then "autocast_m", "projT1_m"
else "autocast", "projT1" in
- let epp = if unpack && not (effectful eff) then string proj_id ^/^ parens epp else epp in
+ (* We need to unpack an existential if it's generated by a pure
+ computation, or if the monadic binding isn't expecting one. *)
+ let epp = if unpack && not (effectful eff && packeff)
+ then string proj_id ^/^ parens epp
+ else epp in
let epp = if autocast then string autocast_id ^^ space ^^ parens epp else epp in
let epp =
if effectful eff && packeff && not unpack
@@ -1667,7 +1790,6 @@ let doc_exp, doc_let =
end
| E_lit lit -> doc_lit lit
| E_cast(typ,e) ->
- let epp = expV true e in
let env = env_of_annot (l,annot) in
let outer_typ = Env.expand_synonyms env (general_typ_of_annot (l,annot)) in
let outer_typ = expand_range_type outer_typ in
@@ -1679,6 +1801,7 @@ let doc_exp, doc_let =
debug ctxt (lazy (" on expr of type " ^ string_of_typ inner_typ));
debug ctxt (lazy (" where type expected is " ^ string_of_typ outer_typ))
in
+ let epp = expV true e in
let outer_ex,_,outer_typ' = classify_ex_type ctxt env outer_typ in
let cast_ex,_,cast_typ' = classify_ex_type ctxt env ~rawbools:true cast_typ in
let inner_ex,_,inner_typ' = classify_ex_type ctxt env inner_typ in
@@ -1692,6 +1815,18 @@ let doc_exp, doc_let =
| _ -> false
in
let effects = effectful (effect_of e) in
+ let autocast =
+ (* We don't currently have a version of autocast under existentials,
+ but they're rare and may be unnecessary *)
+ if effects && outer_ex = ExGeneral then false else autocast
+ in
+ let () =
+ debug ctxt (lazy (" effectful: " ^ string_of_bool effects ^
+ " outer_ex: " ^ string_of_ex_kind outer_ex ^
+ " cast_ex: " ^ string_of_ex_kind cast_ex ^
+ " inner_ex: " ^ string_of_ex_kind inner_ex ^
+ " autocast: " ^ string_of_bool autocast))
+ in
let epp =
if effects then
match inner_ex, cast_ex with
@@ -1839,7 +1974,8 @@ let doc_exp, doc_let =
debug ctxt (lazy ("Internal plet, pattern " ^ string_of_pat pat));
debug ctxt (lazy (" type of e1 " ^ string_of_typ (typ_of e1)))
in
- let new_ctxt = merge_new_tyvars ctxt (env_of_annot (l,annot)) pat (env_of e2) in
+ let outer_env = env_of_annot (l,annot) in
+ let new_ctxt = merge_new_tyvars ctxt outer_env pat (env_of e2) in
match pat, e1, e2 with
| (P_aux (P_wild,_) | P_aux (P_typ (_, P_aux (P_wild, _)), _)),
(E_aux (E_assert (assert_e1,assert_e2),_)), _ ->
@@ -1860,7 +1996,7 @@ let doc_exp, doc_let =
| P_aux (P_typ (typ, P_aux (P_id id,_)),_)
when Util.is_none (is_auto_decomposed_exist ctxt (env_of e1) typ) &&
not (is_enum (env_of e1) id) ->
- separate space [string ">>= fun"; doc_id id; colon; doc_typ ctxt typ; bigarrow]
+ separate space [string ">>= fun"; doc_id id; colon; doc_typ ctxt outer_env typ; bigarrow]
| P_aux (P_typ (typ, P_aux (P_id id,_)),_)
| P_aux (P_typ (typ, P_aux (P_var (P_aux (P_id id,_),_),_)),_)
| P_aux (P_var (P_aux (P_typ (typ, P_aux (P_id id,_)),_),_),_)
@@ -1868,20 +2004,20 @@ let doc_exp, doc_let =
let full_typ = (expand_range_type typ) in
let binder = match classify_ex_type ctxt env1 (Env.expand_synonyms env1 full_typ) with
| ExGeneral, _, _ ->
- squote ^^ parens (separate space [string "existT"; underscore; doc_id id; underscore; colon; doc_typ ctxt typ])
+ squote ^^ parens (separate space [string "existT"; underscore; doc_id id; underscore; colon; doc_typ ctxt outer_env typ])
| ExNone, _, _ ->
- parens (separate space [doc_id id; colon; doc_typ ctxt typ])
+ parens (separate space [doc_id id; colon; doc_typ ctxt outer_env typ])
in separate space [string ">>= fun"; binder; bigarrow]
| P_aux (P_id id,_) ->
let typ = typ_of e1 in
let plain_binder = squote ^^ doc_pat ctxt true true (pat, typ_of e1) in
let binder = match classify_ex_type ctxt env1 ~binding:id (Env.expand_synonyms env1 typ) with
| ExGeneral, _, (Typ_aux (Typ_app (Id_aux (Id "atom_bool",_),_),_) as typ') ->
- squote ^^ parens (separate space [string "existT"; underscore; doc_id id; underscore; colon; doc_typ ctxt typ])
+ squote ^^ parens (separate space [string "existT"; underscore; doc_id id; underscore; colon; doc_typ ctxt outer_env typ])
| ExNone, _, typ' -> begin
match typ' with
| Typ_aux (Typ_app (Id_aux (Id "atom_bool",_),_),_) ->
- squote ^^ parens (separate space [string "existT"; underscore; doc_id id; underscore; colon; doc_typ ctxt typ])
+ squote ^^ parens (separate space [string "existT"; underscore; doc_id id; underscore; colon; doc_typ ctxt outer_env typ])
| _ -> plain_binder
end
| _ -> plain_binder
@@ -1889,7 +2025,9 @@ let doc_exp, doc_let =
| _ ->
separate space [string ">>= fun"; squote ^^ doc_pat ctxt true true (pat, typ_of e1); bigarrow]
in
- infix 0 1 middle (expY e1) (top_exp new_ctxt false e2)
+ let e1_pp = expY e1 in
+ let e2_pp = top_exp new_ctxt false e2 in
+ infix 0 1 middle e1_pp e2_pp
in
if aexp_needed then parens (align epp) else epp
end
@@ -1924,8 +2062,8 @@ let doc_exp, doc_let =
then empty
else separate space
[string ret_monad;
- parens (doc_typ ctxt (typ_of full_exp));
- parens (doc_typ ctxt (typ_of r))] in
+ parens (doc_typ ctxt (env_of full_exp) (typ_of full_exp));
+ parens (doc_typ ctxt (env_of full_exp) (typ_of r))] in
align (parens (string "early_return" ^//^ exp_pp ^//^ ta))
| E_constraint nc -> wrap_parens (doc_nc_exp ctxt (env_of full_exp) nc)
| E_internal_value _ ->
@@ -1933,6 +2071,8 @@ let doc_exp, doc_let =
"unsupported internal expression encountered while pretty-printing")
and if_exp ctxt (elseif : bool) c t e =
let if_pp = string (if elseif then "else if" else "if") in
+ let c_pp = top_exp ctxt true c in
+ let t_pp = top_exp ctxt false t in
let else_pp = match e with
| E_aux (E_if (c', t', e'), _)
| E_aux (E_cast (_, E_aux (E_if (c', t', e'), _)), _) ->
@@ -1945,9 +2085,9 @@ let doc_exp, doc_let =
in
(prefix 2 1
(soft_surround 2 1 if_pp
- ((if condition_produces_constraint c then string "sumbool_of_bool" ^^ space else empty)
- ^^ parens (top_exp ctxt true c)) (string "then"))
- (top_exp ctxt false t)) ^^
+ ((if condition_produces_constraint ctxt c then string "sumbool_of_bool" ^^ space else empty)
+ ^^ parens c_pp) (string "then"))
+ t_pp) ^^
break 1 ^^
else_pp
and let_exp ctxt (LB_aux(lb,_)) = match lb with
@@ -1963,7 +2103,7 @@ let doc_exp, doc_let =
when Util.is_none (is_auto_decomposed_exist ctxt (env_of e) typ) &&
not (is_enum (env_of e) id) ->
prefix 2 1
- (separate space [string "let"; doc_id id; colon; doc_typ ctxt typ; coloneq])
+ (separate space [string "let"; doc_id id; colon; doc_typ ctxt (env_of e) typ; coloneq])
(top_exp ctxt false e)
| LB_val(P_aux (P_typ (typ,pat),_),(E_aux (_,e_ann) as e)) ->
prefix 2 1
@@ -2061,7 +2201,20 @@ let types_used_with_generic_eq defs =
let doc_type_union ctxt typ_name (Tu_aux(Tu_ty_id(typ,id),_)) =
separate space [doc_id_ctor id; colon;
- doc_typ ctxt typ; arrow; typ_name]
+ doc_typ ctxt Env.empty typ; arrow; typ_name]
+
+(* For records and variants we declare the type parameters as implicit
+ so that they're implicit in the constructors. Currently Coq also
+ makes them implicit in the type, so undo that here. *)
+let doc_reset_implicits id_pp typq =
+ let (kopts,ncs) = quant_split typq in
+ let resets = List.map (fun _ -> underscore) kopts in
+ let implicits = List.map (fun _ -> string "{_}") ncs in
+ let args = match implicits with
+ | [] -> [colon; string "clear implicits"]
+ | _ -> resets @ implicits
+ in
+ separate space ([string "Arguments"; id_pp] @ args) ^^ dot
(*
let rec doc_range ctxt (BF_aux(r,_)) = match r with
@@ -2094,7 +2247,7 @@ let doc_typdef generic_eq_types (TD_aux(td, (l, annot))) = match td with
then concat [doc_id id;string "_";doc_id_type fid;]
else doc_id_type fid in
let f_pp (typ,fid) =
- concat [fname fid;space;colon;space;doc_typ empty_ctxt typ; semi] in
+ concat [fname fid;space;colon;space;doc_typ empty_ctxt Env.empty typ; semi] in
let rectyp = match typq with
| TypQ_aux (TypQ_tq qs, _) ->
let quant_item = function
@@ -2138,11 +2291,11 @@ let doc_typdef generic_eq_types (TD_aux(td, (l, annot))) = match td with
string "Defined." ^^ hardline
else empty
in
- let resetimplicit = separate space [string "Arguments"; id_pp; colon; string "clear implicits."] in
+ let reset_implicits_pp = doc_reset_implicits id_pp typq in
doc_op coloneq
(separate space [string "Record"; id_pp; doc_typquant_items empty_ctxt braces typq])
((*doc_typquant typq*) (braces (space ^^ align fs_doc ^^ space))) ^^
- dot ^^ hardline ^^ resetimplicit ^^ hardline ^^ eq_pp ^^ updates_pp
+ dot ^^ hardline ^^ reset_implicits_pp ^^ hardline ^^ eq_pp ^^ updates_pp
| TD_variant(id,typq,ar,_) ->
(match id with
| Id_aux ((Id "read_kind"),_) -> empty
@@ -2162,11 +2315,8 @@ let doc_typdef generic_eq_types (TD_aux(td, (l, annot))) = match td with
(doc_op coloneq)
(concat [string "Inductive"; space; typ_nm])
((*doc_typquant typq*) ar_doc) in
- (* We declared the type parameters as implicit so that they're implicit
- in the constructors. Currently Coq also makes them implicit in the
- type, so undo that here. *)
- let resetimplicit = separate space [string "Arguments"; id_pp; colon; string "clear implicits."] in
- typ_pp ^^ dot ^^ hardline ^^ resetimplicit ^^ hardline ^^ hardline)
+ let reset_implicits_pp = doc_reset_implicits id_pp typq in
+ typ_pp ^^ dot ^^ hardline ^^ reset_implicits_pp ^^ hardline ^^ hardline)
| TD_enum(id,enums,_) ->
(match id with
| Id_aux ((Id "read_kind"),_) -> empty
@@ -2277,7 +2427,7 @@ let rec atom_constraint ctxt (pat, typ) =
(match nexp with
(* When the kid is mapped to the id, we don't need a constraint *)
| Nexp_aux (Nexp_var kid,_)
- when (try Id.compare (KBindings.find kid ctxt.kid_id_renames) id == 0 with _ -> false) ->
+ when (try Id.compare (Util.option_get_exn Not_found (KBindings.find kid ctxt.kid_id_renames)) id == 0 with _ -> false) ->
None
| _ ->
Some (bquote ^^ braces (string "ArithFact" ^^ space ^^
@@ -2318,7 +2468,7 @@ let tyvars_of_typquant (TypQ_aux (tq,_)) =
let mk_kid_renames ids_to_avoid kids =
let map_id = function
| Id_aux (Id i, _) -> Some (fix_id false i)
- | Id_aux (DeIid _, _) -> None
+ | Id_aux (Operator _, _) -> None
in
let ids = StringSet.of_list (Util.map_filter map_id (IdSet.elements ids_to_avoid)) in
let rec check_kid kid (newkids,rebindings) =
@@ -2344,7 +2494,7 @@ let merge_kids_atoms pats =
" but rearranging arguments isn't supported yet") in
gone,map,seen
else
- KidSet.add kid gone, KBindings.add kid id map, KidSet.add kid seen
+ KidSet.add kid gone, KBindings.add kid (Some id) map, KidSet.add kid seen
in
match Type_check.destruct_atom_nexp (env_of_annot ann) typ with
| Some (Nexp_aux (Nexp_var kid,l)) -> merge kid l
@@ -2363,13 +2513,13 @@ let merge_var_patterns map pats =
let map,pats = List.fold_left (fun (map,pats) (pat, typ) ->
match pat with
| P_aux (P_var (P_aux (P_id id,_), TP_aux (TP_var kid,_)),ann) ->
- KBindings.add kid id map, (P_aux (P_id id,ann), typ) :: pats
+ KBindings.add kid (Some id) map, (P_aux (P_id id,ann), typ) :: pats
| _ -> map, (pat,typ)::pats) (map,[]) pats
in map, List.rev pats
type mutrec_pos = NotMutrec | FirstFn | LaterFn
-let doc_funcl mutrec rec_opt (FCL_aux(FCL_Funcl(id, pexp), annot)) =
+let doc_funcl mutrec rec_opt ?rec_set (FCL_aux(FCL_Funcl(id, pexp), annot)) =
let env = env_of_annot annot in
let (tq,typ) = Env.get_val_spec_orig id env in
let (arg_typs, ret_typ, eff) = match typ with
@@ -2392,15 +2542,20 @@ let doc_funcl mutrec rec_opt (FCL_aux(FCL_Funcl(id, pexp), annot)) =
let kid_to_arg_rename, pats = merge_var_patterns kid_to_arg_rename pats in
let kids_used = KidSet.diff bound_kids eliminated_kids in
let is_measured, recursive_ids = match rec_opt with
- (* No mutual recursion in this backend yet; only change recursive
- definitions where we have a measure *)
- | Rec_aux (Rec_measure _,_) -> true, IdSet.singleton id
+ | Rec_aux (Rec_measure _,_) ->
+ true, (match rec_set with None -> IdSet.singleton id | Some s -> s)
| _ -> false, IdSet.empty
in
+ let kir_rev =
+ KBindings.fold
+ (fun kid idopt m -> match idopt with Some id -> Bindings.add id kid m | None -> m)
+ kid_to_arg_rename Bindings.empty
+ in
let ctxt0 =
{ early_ret = contains_early_return exp;
kid_renames = mk_kid_renames ids_to_avoid kids_used;
kid_id_renames = kid_to_arg_rename;
+ kid_id_renames_rev = kir_rev;
bound_nvars = bound_kids;
build_at_return = None; (* filled in below *)
recursive_ids = recursive_ids;
@@ -2420,7 +2575,8 @@ let doc_funcl mutrec rec_opt (FCL_aux(FCL_Funcl(id, pexp), annot)) =
debug ctxt (lazy (" build_ex " ^ match build_ex with Some s -> s ^ " needed" | _ -> "not needed"));
debug ctxt (lazy (if effectful eff then " effectful" else " pure"));
debug ctxt (lazy (" kid_id_renames " ^ String.concat ", " (List.map
- (fun (kid,id) -> string_of_kid kid ^ " |-> " ^ string_of_id id)
+ (fun (kid,id) -> string_of_kid kid ^ " |-> " ^
+ match id with Some id -> string_of_id id | None -> "<>")
(KBindings.bindings kid_to_arg_rename))))
in
(* Put the constraints after pattern matching so that any type variable that's
@@ -2435,11 +2591,12 @@ let doc_funcl mutrec rec_opt (FCL_aux(FCL_Funcl(id, pexp), annot)) =
debug ctxt (lazy (" pattern " ^ string_of_pat pat));
debug ctxt (lazy (" with expanded type " ^ string_of_typ exp_typ))
in
+ (* TODO: probably should provide partial environments to doc_typ *)
match pat_is_plain_binder env pat with
| Some id -> begin
match classify_ex_type ctxt env ~binding:id exp_typ with
| ExNone, _, typ' ->
- parens (separate space [doc_id id; colon; doc_typ ctxt typ'])
+ parens (separate space [doc_id id; colon; doc_typ ctxt Env.empty typ'])
| ExGeneral, _, _ ->
let full_typ = (expand_range_type exp_typ) in
match destruct_exist_plain (Env.expand_synonyms env full_typ) with
@@ -2454,21 +2611,22 @@ let doc_funcl mutrec rec_opt (FCL_aux(FCL_Funcl(id, pexp), annot)) =
[A_aux (A_nexp (Nexp_aux (Nexp_var kid,_)),_)]),_))
when Kid.compare (kopt_kid kopt) kid == 0 && not is_measured ->
(used_a_pattern := true;
- squote ^^ parens (separate space [string "existT"; underscore; doc_id id; underscore; colon; doc_typ ctxt typ]))
+ squote ^^ parens (separate space [string "existT"; underscore; doc_id id; underscore; colon; doc_typ ctxt Env.empty typ]))
| _ ->
- parens (separate space [doc_id id; colon; doc_typ ctxt typ])
+ parens (separate space [doc_id id; colon; doc_typ ctxt Env.empty typ])
end
| None ->
(used_a_pattern := true;
- squote ^^ parens (separate space [doc_pat ctxt true true (pat, exp_typ); colon; doc_typ ctxt typ]))
+ squote ^^ parens (separate space [doc_pat ctxt true true (pat, exp_typ); colon; doc_typ ctxt Env.empty typ]))
in
let patspp = flow_map (break 1) doc_binder pats in
let atom_constrs = Util.map_filter (atom_constraint ctxt) pats in
let atom_constr_pp = separate space atom_constrs in
let retpp =
+ (* TODO: again, probably should provide proper environment *)
if effectful eff
- then string "M" ^^ space ^^ parens (doc_typ ctxt ret_typ)
- else doc_typ ctxt ret_typ
+ then string "M" ^^ space ^^ parens (doc_typ ctxt Env.empty ret_typ)
+ else doc_typ ctxt Env.empty ret_typ
in
let idpp = doc_id id in
let intropp, accpp, measurepp, fixupspp = match rec_opt with
@@ -2535,17 +2693,17 @@ let get_id = function
(* Coq doesn't support multiple clauses for a single function joined
by "and". However, all the funcls should have been merged by the
merge_funcls rewrite now. *)
-let doc_fundef_rhs ?(mutrec=NotMutrec) (FD_aux(FD_function(r, typa, efa, funcls),(l,_))) =
+let doc_fundef_rhs ?(mutrec=NotMutrec) rec_set (FD_aux(FD_function(r, typa, efa, funcls),(l,_))) =
match funcls with
| [] -> unreachable l __POS__ "function with no clauses"
- | [funcl] -> doc_funcl mutrec r funcl
+ | [funcl] -> doc_funcl mutrec r ~rec_set funcl
| (FCL_aux (FCL_Funcl (id,_),_))::_ -> unreachable l __POS__ ("function " ^ string_of_id id ^ " has multiple clauses in backend")
-let doc_mutrec = function
+let doc_mutrec rec_set = function
| [] -> failwith "DEF_internal_mutrec with empty function list"
| fundef::fundefs ->
- doc_fundef_rhs ~mutrec:FirstFn fundef ^^ hardline ^^
- separate_map hardline (doc_fundef_rhs ~mutrec:LaterFn) fundefs ^^ dot
+ doc_fundef_rhs ~mutrec:FirstFn rec_set fundef ^^ hardline ^^
+ separate_map hardline (doc_fundef_rhs ~mutrec:LaterFn rec_set) fundefs ^^ dot
let rec doc_fundef (FD_aux(FD_function(r, typa, efa, fcls),fannot)) =
match fcls with
@@ -2635,7 +2793,7 @@ let doc_regtype_fields (tname, (n1, n2, fields)) =
separate_map hardline doc_field fields
(* Remove some type variables in a similar fashion to merge_kids_atoms *)
-let doc_axiom_typschm typ_env (TypSchm_aux (TypSchm_ts (tqs,typ),l) as ts) =
+let doc_axiom_typschm typ_env l (tqs,typ) =
let typ_env = add_typquant l tqs typ_env in
match typ with
| Typ_aux (Typ_fn (typs, ret_ty, eff),l') ->
@@ -2645,27 +2803,52 @@ let doc_axiom_typschm typ_env (TypSchm_aux (TypSchm_ts (tqs,typ),l) as ts) =
if KidSet.mem kid used then args,used else
KidSet.add kid args, used
| Some _ -> args, used
- | _ -> args, KidSet.union used (tyvars_of_typ typ)
+ | _ ->
+ match Type_check.destruct_atom_bool typ_env typ with
+ | Some (NC_aux (NC_var kid,_)) ->
+ if KidSet.mem kid used then args,used else
+ KidSet.add kid args, used
+ | _ ->
+ args, KidSet.union used (tyvars_of_typ typ)
in
let args, used = List.fold_left check_typ (KidSet.empty, KidSet.empty) typs in
let used = if is_number ret_ty then used else KidSet.union used (tyvars_of_typ ret_ty) in
+ let kopts,constraints = quant_split tqs in
+ let used = List.fold_left (fun used nc -> KidSet.union used (tyvars_of_constraint nc)) used constraints in
let tqs = match tqs with
| TypQ_aux (TypQ_tq qs,l) -> TypQ_aux (TypQ_tq (List.filter (function
- | QI_aux (QI_id kopt,_) when is_int_kopt kopt ->
+ | QI_aux (QI_id kopt,_) ->
let kid = kopt_kid kopt in
KidSet.mem kid used && not (KidSet.mem kid args)
| _ -> true) qs),l)
| _ -> tqs
in
+ let typ_count = ref 0 in
+ let fresh_var () =
+ let n = !typ_count in
+ let () = typ_count := n+1 in
+ string ("x" ^ string_of_int n)
+ in
let doc_typ' typ =
match Type_check.destruct_atom_nexp typ_env typ with
| Some (Nexp_aux (Nexp_var kid,_)) when KidSet.mem kid args ->
parens (doc_var empty_ctxt kid ^^ string " : Z")
- | _ -> parens (underscore ^^ string " : " ^^ doc_typ empty_ctxt typ)
+ (* This case is silly, but useful for tests *)
+ | Some (Nexp_aux (Nexp_constant n,_)) ->
+ let v = fresh_var () in
+ parens (v ^^ string " : Z") ^/^
+ bquote ^^ braces (string "ArithFact " ^^
+ parens (v ^^ string " = " ^^ string (Big_int.to_string n)))
+ | _ ->
+ match Type_check.destruct_atom_bool typ_env typ with
+ | Some (NC_aux (NC_var kid,_)) when KidSet.mem kid args ->
+ parens (doc_var empty_ctxt kid ^^ string " : bool")
+ | _ ->
+ parens (underscore ^^ string " : " ^^ doc_typ empty_ctxt Env.empty typ)
in
let arg_typs_pp = separate space (List.map doc_typ' typs) in
let _, ret_ty = replace_atom_return_type ret_ty in
- let ret_typ_pp = doc_typ empty_ctxt ret_ty in
+ let ret_typ_pp = doc_typ empty_ctxt Env.empty ret_ty in
let ret_typ_pp =
if effectful eff
then string "M" ^^ space ^^ parens ret_typ_pp
@@ -2674,13 +2857,17 @@ let doc_axiom_typschm typ_env (TypSchm_aux (TypSchm_ts (tqs,typ),l) as ts) =
let tyvars_pp, constrs_pp = doc_typquant_items_separate empty_ctxt braces tqs in
string "forall" ^/^ separate space tyvars_pp ^/^
arg_typs_pp ^/^ separate space constrs_pp ^^ comma ^/^ ret_typ_pp
- | _ -> doc_typschm empty_ctxt true ts
+ | _ -> doc_typschm empty_ctxt true (TypSchm_aux (TypSchm_ts (tqs,typ),l))
-let doc_val_spec unimplemented (VS_aux (VS_val_spec(tys,id,_,_),ann)) =
+let doc_val_spec unimplemented (VS_aux (VS_val_spec(_,id,_,_),(l,ann)) as vs) =
if !opt_undef_axioms && IdSet.mem id unimplemented then
- let typ_env = env_of_annot ann in
+ let typ_env = env_of_annot (l,ann) in
+ (* The type checker will expand the type scheme, and we need to look at the
+ environment afterwards to find it. *)
+ let _, next_env = check_val_spec typ_env vs in
+ let tys = Env.get_val_spec id next_env in
group (separate space
- [string "Axiom"; doc_id id; colon; doc_axiom_typschm typ_env tys] ^^ dot) ^/^ hardline
+ [string "Axiom"; doc_id id; colon; doc_axiom_typschm typ_env l tys] ^^ dot) ^/^ hardline
else empty (* Type signatures appear in definitions *)
(* If a top-level value is declared with an existential type, we turn it into
@@ -2698,7 +2885,7 @@ let doc_val pat exp =
in
let typpp = match pat_typ with
| None -> empty
- | Some typ -> space ^^ colon ^^ space ^^ doc_typ empty_ctxt typ
+ | Some typ -> space ^^ colon ^^ space ^^ doc_typ empty_ctxt Env.empty typ
in
let env = env_of exp in
let ctxt = { empty_ctxt with debug = List.mem (string_of_id id) (!opt_debug_on) } in
@@ -2730,7 +2917,7 @@ let rec doc_def unimplemented generic_eq_types def =
| DEF_default df -> empty
| DEF_fundef fdef -> group (doc_fundef fdef) ^/^ hardline
- | DEF_internal_mutrec fundefs -> doc_mutrec fundefs ^/^ hardline
+ | DEF_internal_mutrec fundefs -> doc_mutrec (ids_of_def def) fundefs ^/^ hardline
| DEF_val (LB_aux (LB_val (pat, exp), _)) -> doc_val pat exp
| DEF_scattered sdef -> failwith "doc_def: shoulnd't have DEF_scattered at this point"
| DEF_mapdef (MD_aux (_, (l,_))) -> unreachable l __POS__ "Coq doesn't support mappings"
@@ -2798,6 +2985,8 @@ try
[string "(*" ^^ (string top_line) ^^ string "*)";hardline;
(separate_map hardline)
(fun lib -> separate space [string "Require Import";string lib] ^^ dot) types_modules;hardline;
+ string "Import ListNotations.";
+ hardline;
separate empty (List.map doc_def typdefs); hardline;
hardline;
separate empty (List.map doc_def statedefs); hardline;
diff --git a/src/pretty_print_lem.ml b/src/pretty_print_lem.ml
index eec61874..633d910e 100644
--- a/src/pretty_print_lem.ml
+++ b/src/pretty_print_lem.ml
@@ -69,12 +69,14 @@ type context = {
top_env : Env.t
}
let empty_ctxt = { early_ret = false; bound_nexps = NexpSet.empty; top_env = Env.empty }
-
+
let print_to_from_interp_value = ref false
let langlebar = string "<|"
let ranglebar = string "|>"
let anglebars = enclose langlebar ranglebar
+let doc_var (Kid_aux(Var v,_)) = string v
+
let is_number_char c =
c = '0' || c = '1' || c = '2' || c = '3' || c = '4' || c = '5' ||
c = '6' || c = '7' || c = '8' || c = '9'
@@ -113,7 +115,7 @@ let rec fix_id remove_tick name = match name with
let doc_id_lem (Id_aux(i,_)) =
match i with
| Id i -> string (fix_id false i)
- | DeIid x -> string (Util.zencode_string ("op " ^ x))
+ | Operator x -> string (Util.zencode_string ("op " ^ x))
let doc_id_lem_type (Id_aux(i,_)) =
match i with
@@ -121,7 +123,7 @@ let doc_id_lem_type (Id_aux(i,_)) =
| Id("nat") -> string "ii"
| Id("option") -> string "maybe"
| Id i -> string (fix_id false i)
- | DeIid x -> string (Util.zencode_string ("op " ^ x))
+ | Operator x -> string (Util.zencode_string ("op " ^ x))
let doc_id_lem_ctor (Id_aux(i,_)) =
match i with
@@ -131,11 +133,11 @@ let doc_id_lem_ctor (Id_aux(i,_)) =
| Id("Some") -> string "Just"
| Id("None") -> string "Nothing"
| Id i -> string (fix_id false (String.capitalize_ascii i))
- | DeIid x -> string (Util.zencode_string ("op " ^ x))
+ | Operator x -> string (Util.zencode_string ("op " ^ x))
let deinfix = function
- | Id_aux (Id v, l) -> Id_aux (DeIid v, l)
- | Id_aux (DeIid v, l) -> Id_aux (DeIid v, l)
+ | Id_aux (Id v, l) -> Id_aux (Operator v, l)
+ | Id_aux (Operator v, l) -> Id_aux (Operator v, l)
let doc_var_lem kid = string (fix_id true (string_of_kid kid))
@@ -927,7 +929,9 @@ let doc_exp_lem, doc_let_lem =
let b = match e1 with E_aux (E_if _,_) -> true | _ -> false in
let middle =
match fst (untyp_pat pat) with
- | P_aux (P_wild,_) | P_aux (P_typ (_, P_aux (P_wild, _)), _) -> string ">>"
+ | P_aux (P_wild,_) | P_aux (P_typ (_, P_aux (P_wild, _)), _)
+ when is_unit_typ (typ_of_pat pat) ->
+ string ">>"
| P_aux (P_tup _, _)
when not (IdSet.mem (mk_id "varstup") (find_e_ids e2)) ->
(* Work around indentation issues in Lem when translating
@@ -973,6 +977,10 @@ let doc_exp_lem, doc_let_lem =
| E_aux (E_if (c', t', e'), _)
| E_aux (E_cast (_, E_aux (E_if (c', t', e'), _)), _) ->
if_exp ctxt true c' t' e'
+ (* Special case to prevent current arm decoder becoming a staircase *)
+ (* TODO: replace with smarter pretty printing *)
+ | E_aux (E_internal_plet (pat,exp1,E_aux (E_cast (typ, (E_aux (E_if (_, _, _), _) as exp2)),_)),ann) when Typ.compare typ unit_typ == 0 ->
+ string "else" ^/^ top_exp ctxt false (E_aux (E_internal_plet (pat,exp1,exp2),ann))
| _ -> prefix 2 1 (string "else") (top_exp ctxt false e)
in
(prefix 2 1
@@ -1040,6 +1048,11 @@ let doc_typquant_sorts idpp (TypQ_aux (typq,_)) =
else empty
| TypQ_no_forall -> empty
+let doc_sia_id (Id_aux(i,_)) =
+ match i with
+ | Id i -> string i
+ | Operator x -> string ("operator " ^ x)
+
let doc_typdef_lem env (TD_aux(td, (l, annot))) = match td with
| TD_abbrev(id,typq,A_aux (A_typ typ, _)) ->
let typschm = TypSchm_aux (TypSchm_ts (typq, typ), l) in
@@ -1119,7 +1132,7 @@ let doc_typdef_lem env (TD_aux(td, (l, annot))) = match td with
((*doc_typquant_lem typq*) ar_doc) in
let make_id pat id =
separate space [string "SIA.Id_aux";
- parens (string "SIA.Id " ^^ string_lit (doc_id id));
+ parens (string "SIA.Id " ^^ string_lit (doc_sia_id id));
if pat then underscore else string "SIA.Unknown"] in
let fromInterpValueF = concat [doc_id_lem_type id;string "FromInterpValue"] in
let toInterpValueF = concat [doc_id_lem_type id;string "ToInterpValue"] in
@@ -1155,7 +1168,7 @@ let doc_typdef_lem env (TD_aux(td, (l, annot))) = match td with
[pipe;doc_id_lem_ctor cid;string "v";arrow;
string "SI.V_ctor";
parens (make_id false cid);
- parens (string "SIA.T_id " ^^ string_lit (doc_id id));
+ parens (string "SIA.T_id " ^^ string_lit (doc_sia_id id));
string "SI.C_Union";
parens (string "toInterpValue v")])
ar) ^/^
@@ -1193,7 +1206,7 @@ let doc_typdef_lem env (TD_aux(td, (l, annot))) = match td with
let toInterpValueF = concat [doc_id_lem_type id;string "ToInterpValue"] in
let make_id pat id =
separate space [string "SIA.Id_aux";
- parens (string "SIA.Id " ^^ string_lit (doc_id id));
+ parens (string "SIA.Id " ^^ string_lit (doc_sia_id id));
if pat then underscore else string "SIA.Unknown"] in
let fromInterpValuePP =
(prefix 2 1)
@@ -1246,7 +1259,7 @@ let doc_typdef_lem env (TD_aux(td, (l, annot))) = match td with
[pipe;doc_id_lem_ctor cid;arrow;
string "SI.V_ctor";
parens (make_id false cid);
- parens (string "SIA.T_id " ^^ string_lit (doc_id id));
+ parens (string "SIA.T_id " ^^ string_lit (doc_sia_id id));
parens (string ("SI.C_Enum " ^ string_of_int number));
parens (string "toInterpValue ()")])
(List.combine enums (nats ((List.length enums) - 1)))) ^/^
diff --git a/src/pretty_print_sail.ml b/src/pretty_print_sail.ml
index 9712b62c..aa7294bd 100644
--- a/src/pretty_print_sail.ml
+++ b/src/pretty_print_sail.ml
@@ -61,7 +61,7 @@ let doc_op symb a b = infix 2 1 symb a b
let doc_id (Id_aux (id_aux, _)) =
string (match id_aux with
| Id v -> v
- | DeIid op -> "operator " ^ op)
+ | Operator op -> "operator " ^ op)
let doc_kid kid = string (Ast_util.string_of_kid kid)
@@ -92,7 +92,7 @@ let rec doc_nexp =
let rec atomic_nexp (Nexp_aux (n_aux, _) as nexp) =
match n_aux with
| Nexp_constant c -> string (Big_int.to_string c)
- | Nexp_app (Id_aux (DeIid op, _), [n1; n2]) ->
+ | Nexp_app (Id_aux (Operator op, _), [n1; n2]) ->
separate space [atomic_nexp n1; string op; atomic_nexp n2]
| Nexp_app (id, nexps) -> string (string_of_nexp nexp)
(* This segfaults??!!!!
@@ -172,7 +172,7 @@ and doc_typ ?(simple=false) (Typ_aux (typ_aux, l)) =
match typ_aux with
| Typ_id id -> doc_id id
| Typ_app (id, []) -> doc_id id
- | Typ_app (Id_aux (DeIid str, _), [x; y]) ->
+ | Typ_app (Id_aux (Operator str, _), [x; y]) ->
separate space [doc_typ_arg x; doc_typ_arg y]
| Typ_app (id, typs) when Id.compare id (mk_id "atom") = 0 ->
string "int" ^^ parens (separate_map (string ", ") doc_typ_arg typs)
@@ -574,7 +574,7 @@ let doc_mapdef (MD_aux (MD_mapping (id, typa, mapcls), _)) =
| _ ->
let sep = string "," ^^ hardline in
let clauses = separate_map sep doc_mapcl mapcls in
- string "mapping" ^^ space ^^ doc_id id ^^ space ^^ string "=" ^^ (surround 2 0 lbrace clauses rbrace)
+ string "mapping" ^^ space ^^ doc_id id ^^ space ^^ string "=" ^^ space ^^ (surround 2 0 lbrace clauses rbrace)
let doc_dec (DEC_aux (reg,_)) =
match reg with
@@ -657,8 +657,8 @@ let rec doc_scattered (SD_aux (sd_aux, _)) =
separate space [string "mapping clause"; doc_id id; equals; doc_mapcl mapcl]
| SD_mapping (id, Typ_annot_opt_aux (Typ_annot_opt_none, _)) ->
separate space [string "scattered mapping"; doc_id id]
- | SD_mapping (id, Typ_annot_opt_aux (Typ_annot_opt_some (_, typ), _)) ->
- separate space [string "scattered mapping"; doc_id id; string ":"; doc_typ typ]
+ | SD_mapping (id, Typ_annot_opt_aux (Typ_annot_opt_some (typq, typ), _)) ->
+ separate space [string "scattered mapping"; doc_id id; colon; doc_binding (typq, typ)]
| SD_unioncl (id, tu) ->
separate space [string "union clause"; doc_id id; equals; doc_union tu]
diff --git a/src/process_file.ml b/src/process_file.ml
index 3c2d4a22..dbe6d62d 100644
--- a/src/process_file.ml
+++ b/src/process_file.ml
@@ -126,33 +126,6 @@ let cond_pragma l defs =
in
scan defs
-let astid_to_string (Ast.Id_aux (id, _)) =
- match id with
- | Ast.Id x | Ast.DeIid x -> x
-
-let parseid_to_string (Parse_ast.Id_aux (id, _)) =
- match id with
- | Parse_ast.Id x | Parse_ast.DeIid x -> x
-
-let rec realise_union_anon_rec_types orig_union arms =
- match orig_union with
- | Parse_ast.TD_variant (union_id, typq, _, flag) ->
- begin match arms with
- | [] -> []
- | arm :: arms ->
- match arm with
- | (Parse_ast.Tu_aux ((Parse_ast.Tu_ty_id _), _)) -> (None, arm) :: realise_union_anon_rec_types orig_union arms
- | (Parse_ast.Tu_aux ((Parse_ast.Tu_ty_anon_rec (fields, id)), l)) ->
- let open Parse_ast in
- let record_str = "_" ^ parseid_to_string union_id ^ "_" ^ parseid_to_string id ^ "_record" in
- let record_id = Id_aux (Id record_str, Generated l) in
- let new_arm = Tu_aux ((Tu_ty_id ((ATyp_aux (ATyp_id record_id, Generated l)), id)), Generated l) in
- let new_rec_def = DEF_type (TD_aux (TD_record (record_id, typq, fields, flag), Generated l)) in
- (Some new_rec_def, new_arm) :: (realise_union_anon_rec_types orig_union arms)
- end
- | _ ->
- raise (Reporting.err_unreachable Parse_ast.Unknown __POS__ "Non union type-definition passed to realise_union_anon_rec_typs")
-
let rec preprocess opts = function
| [] -> []
| Parse_ast.DEF_pragma ("define", symbol, _) :: defs ->
@@ -218,20 +191,6 @@ let rec preprocess opts = function
| Parse_ast.DEF_pragma (p, arg, l) :: defs ->
Parse_ast.DEF_pragma (p, arg, l) :: preprocess opts defs
- (* realise any anonymous record arms of variants *)
- | Parse_ast.DEF_type (Parse_ast.TD_aux
- (Parse_ast.TD_variant (id, typq, arms, flag) as union, l)
- ) :: defs ->
- let records_and_arms = realise_union_anon_rec_types union arms in
- let rec filter_records = function [] -> []
- | Some x :: xs -> x :: filter_records xs
- | None :: xs -> filter_records xs
- in
- let generated_records = filter_records (List.map fst records_and_arms) in
- let rewritten_arms = List.map snd records_and_arms in
- let rewritten_union = Parse_ast.TD_variant (id, typq, rewritten_arms, flag) in
- generated_records @ (Parse_ast.DEF_type (Parse_ast.TD_aux (rewritten_union, l))) :: preprocess opts defs
-
| (Parse_ast.DEF_default (Parse_ast.DT_aux (Parse_ast.DT_order (_, Parse_ast.ATyp_aux (atyp, _)), _)) as def) :: defs ->
begin match atyp with
| Parse_ast.ATyp_inc -> symbols := StringSet.add "_DEFAULT_INC" !symbols; def :: preprocess opts defs
@@ -402,14 +361,8 @@ let rewrite env rewriters defs =
| Type_check.Type_error (_, l, err) ->
raise (Reporting.err_typ l (Type_error.string_of_type_error err))
-let rewrite_ast env = rewrite env [("initial", fun _ -> Rewriter.rewrite_defs)]
-let rewrite_ast_lem env = rewrite env Rewrites.rewrite_defs_lem
-let rewrite_ast_coq env = rewrite env Rewrites.rewrite_defs_coq
-let rewrite_ast_ocaml env = rewrite env Rewrites.rewrite_defs_ocaml
-let rewrite_ast_c env ast =
- ast
- |> rewrite env Rewrites.rewrite_defs_c
- |> rewrite env [("constant_fold", fun _ -> Constant_fold.rewrite_constant_function_calls env)]
+let rewrite_ast_initial env = rewrite env [("initial", fun _ -> Rewriter.rewrite_defs)]
+
+let rewrite_ast_target tgt env = rewrite env (Rewrites.rewrite_defs_target tgt)
-let rewrite_ast_interpreter env = rewrite env Rewrites.rewrite_defs_interpreter
let rewrite_ast_check env = rewrite env Rewrites.rewrite_defs_check
diff --git a/src/process_file.mli b/src/process_file.mli
index 0411464b..e144727e 100644
--- a/src/process_file.mli
+++ b/src/process_file.mli
@@ -55,13 +55,9 @@ val parse_file : ?loc:Parse_ast.l -> string -> Parse_ast.defs
val clear_symbols : unit -> unit
val preprocess_ast : (Arg.key * Arg.spec * Arg.doc) list -> Parse_ast.defs -> Parse_ast.defs
-val check_ast: Type_check.Env.t -> unit Ast.defs -> Type_check.tannot Ast.defs * Type_check.Env.t
-val rewrite_ast: Type_check.Env.t -> Type_check.tannot Ast.defs -> Type_check.tannot Ast.defs
-val rewrite_ast_lem : Type_check.Env.t -> Type_check.tannot Ast.defs -> Type_check.tannot Ast.defs
-val rewrite_ast_coq : Type_check.Env.t -> Type_check.tannot Ast.defs -> Type_check.tannot Ast.defs
-val rewrite_ast_ocaml : Type_check.Env.t -> Type_check.tannot Ast.defs -> Type_check.tannot Ast.defs
-val rewrite_ast_c : Type_check.Env.t -> Type_check.tannot Ast.defs -> Type_check.tannot Ast.defs
-val rewrite_ast_interpreter : Type_check.Env.t -> Type_check.tannot Ast.defs -> Type_check.tannot Ast.defs
+val check_ast : Type_check.Env.t -> unit Ast.defs -> Type_check.tannot Ast.defs * Type_check.Env.t
+val rewrite_ast_initial : Type_check.Env.t -> Type_check.tannot Ast.defs -> Type_check.tannot Ast.defs
+val rewrite_ast_target : string -> Type_check.Env.t -> Type_check.tannot Ast.defs -> Type_check.tannot Ast.defs
val rewrite_ast_check : Type_check.Env.t -> Type_check.tannot Ast.defs -> Type_check.tannot Ast.defs
val load_file_no_check : (Arg.key * Arg.spec * Arg.doc) list -> Ast.order -> string -> unit Ast.defs
diff --git a/src/reporting.ml b/src/reporting.ml
index 0bc73ed6..20e44c57 100644
--- a/src/reporting.ml
+++ b/src/reporting.ml
@@ -111,6 +111,20 @@ let loc_to_string ?code:(code=true) l =
format_message (Location (l, Line "")) (buffer_formatter b);
Buffer.contents b
+let rec simp_loc = function
+ | Parse_ast.Unknown -> None
+ | Parse_ast.Unique (_, l) -> simp_loc l
+ | Parse_ast.Generated l -> simp_loc l
+ | Parse_ast.Range (p1, p2) -> Some (p1, p2)
+ | Parse_ast.Documented (_, l) -> simp_loc l
+
+let short_loc_to_string l =
+ match simp_loc l with
+ | None -> "unknown location"
+ | Some (p1, p2) ->
+ Printf.sprintf "%s %d:%d - %d:%d"
+ p1.pos_fname p1.pos_lnum (p1.pos_cnum - p1.pos_bol) p2.pos_lnum (p2.pos_cnum - p2.pos_bol)
+
let print_err l m1 m2 =
print_err_internal (Loc l) m1 m2
diff --git a/src/reporting.mli b/src/reporting.mli
index 2d886111..86399e84 100644
--- a/src/reporting.mli
+++ b/src/reporting.mli
@@ -65,6 +65,12 @@
(** [loc_to_string] includes code from file if code optional argument is true (default) *)
val loc_to_string : ?code:bool -> Parse_ast.l -> string
+(** Reduce a location to a pair of positions if possible *)
+val simp_loc : Ast.l -> (Lexing.position * Lexing.position) option
+
+(** [short_loc_to_string] prints the location as a single line in a simple format *)
+val short_loc_to_string : Parse_ast.l -> string
+
(** [print_err fatal print_loc_source l head mes] prints an error / warning message to
std-err. It starts with printing location information stored in [l]
It then prints "head: mes". If [fatal] is set, the program exists with error-code 1 afterwards.
diff --git a/src/rewrites.ml b/src/rewrites.ml
index 15e6ad05..0107cf62 100644
--- a/src/rewrites.ml
+++ b/src/rewrites.ml
@@ -1018,7 +1018,7 @@ let remove_bitvector_pat (P_aux (_, (l, _)) as pat) =
let mk_exp e_aux = E_aux (e_aux, (l, ())) in
let mk_num_exp i = mk_lit_exp (L_num i) in
let check_eq_exp l r =
- let exp = mk_exp (E_app_infix (l, Id_aux (DeIid "==", Parse_ast.Unknown), r)) in
+ let exp = mk_exp (E_app_infix (l, Id_aux (Operator "==", Parse_ast.Unknown), r)) in
check_exp env exp bool_typ in
let access_bit_exp rootid l typ idx =
@@ -2460,14 +2460,20 @@ let rewrite_defs_letbind_effects env =
k (rewrap (E_throw exp')))
| E_internal_plet _ -> failwith "E_internal_plet should not be here yet" in
- let rewrite_fun _ (FD_aux (FD_function(recopt,tannotopt,effectopt,funcls),fdannot)) =
+ let rewrite_fun _ (FD_aux (FD_function(recopt,tannotopt,effectopt,funcls),fdannot) as fd) =
(* let propagate_funcl_effect (FCL_aux (FCL_Funcl(id, pexp), (l, a))) =
let pexp, eff = propagate_pexp_effect pexp in
FCL_aux (FCL_Funcl(id, pexp), (l, add_effect_annot a eff))
in
let funcls = List.map propagate_funcl_effect funcls in *)
+ let effectful_vs =
+ match Env.get_val_spec (id_of_fundef fd) env with
+ | _, Typ_aux (Typ_fn (_, _, effs), _) -> effectful_effs effs
+ | _, _ -> false
+ | exception Type_error _ -> false
+ in
let effectful_funcl (FCL_aux (FCL_Funcl(_, pexp), _)) = effectful_pexp pexp in
- let newreturn = List.exists effectful_funcl funcls in
+ let newreturn = effectful_vs || List.exists effectful_funcl funcls in
let rewrite_funcl (FCL_aux (FCL_Funcl(id,pexp),annot)) =
let _ = reset_fresh_name_counter () in
FCL_aux (FCL_Funcl (id,n_pexp newreturn pexp (fun x -> x)),annot)
@@ -2761,7 +2767,7 @@ let construct_toplevel_string_append_func env f_id pat =
let mapping_prefix_func =
match mapping_id with
| Id_aux (Id id, _)
- | Id_aux (DeIid id, _) -> id ^ "_matches_prefix"
+ | Id_aux (Operator id, _) -> id ^ "_matches_prefix"
in
let mapping_inner_typ =
match Env.get_val_spec (mk_id mapping_prefix_func) env with
@@ -2937,7 +2943,7 @@ let rec rewrite_defs_pat_string_append env =
let mapping_prefix_func =
match mapping_id with
| Id_aux (Id id, _)
- | Id_aux (DeIid id, _) -> id ^ "_matches_prefix"
+ | Id_aux (Operator id, _) -> id ^ "_matches_prefix"
in
let mapping_inner_typ =
match Env.get_val_spec (mk_id mapping_prefix_func) env with
@@ -3159,7 +3165,7 @@ let rewrite_defs_mapping_patterns env =
let mapping_name =
match mapping_id with
| Id_aux (Id id, _)
- | Id_aux (DeIid id, _) -> id
+ | Id_aux (Operator id, _) -> id
in
let mapping_matches_id = mk_id (mapping_name ^ "_" ^ mapping_direction ^ "_matches") in
@@ -4041,7 +4047,6 @@ let rewrite_defs_realise_mappings _ (Defs defs) =
in
Defs (List.map rewrite_def defs |> List.flatten)
-
(* Rewrite to make all pattern matches in Coq output exhaustive.
Assumes that guards, vector patterns, etc have been rewritten already,
and the scattered functions have been merged.
@@ -4471,8 +4476,9 @@ let rewrite_explicit_measure env (Defs defs) =
Bindings.add id (mpat,mexp) measures
| _ -> measures
in
- let scan_def measures = function
+ let rec scan_def measures = function
| DEF_fundef fd -> scan_function measures fd
+ | DEF_internal_mutrec fds -> List.fold_left scan_function measures fds
| _ -> measures
in
let measures = List.fold_left scan_def Bindings.empty defs in
@@ -4482,7 +4488,7 @@ let rewrite_explicit_measure env (Defs defs) =
(* NB: the Coq backend relies on recognising the #rec# prefix *)
let rec_id = function
| Id_aux (Id id,l)
- | Id_aux (DeIid id,l) -> Id_aux (Id ("#rec#" ^ id),Generated l)
+ | Id_aux (Operator id,l) -> Id_aux (Id ("#rec#" ^ id),Generated l)
in
let limit = mk_id "#reclimit" in
(* Add helper function with extra argument to spec *)
@@ -4505,7 +4511,7 @@ let rewrite_explicit_measure env (Defs defs) =
| exception Not_found -> [vs]
in
(* Add extra argument and assertion to each funcl, and rewrite recursive calls *)
- let rewrite_funcl (FCL_aux (FCL_Funcl (id,pexp),fcl_ann) as fcl) =
+ let rewrite_funcl recset (FCL_aux (FCL_Funcl (id,pexp),fcl_ann) as fcl) =
let loc = Parse_ast.Generated (fst fcl_ann) in
let P_aux (pat,pann),guard,body,ann = destruct_pexp pexp in
let extra_pat = P_aux (P_id limit,(loc,empty_tannot)) in
@@ -4532,15 +4538,15 @@ let rewrite_explicit_measure env (Defs defs) =
let body =
fold_exp { id_exp_alg with
e_app = (fun (f,args) ->
- if Id.compare f id == 0
- then E_app (rec_id id, args@[tick])
+ if IdSet.mem f recset
+ then E_app (rec_id f, args@[tick])
else E_app (f, args))
} body
in
let body = E_aux (E_block [assert_exp; body],(loc,empty_tannot)) in
FCL_aux (FCL_Funcl (rec_id id, construct_pexp (P_aux (pat,pann),guard,body,ann)),fcl_ann)
in
- let rewrite_function (FD_aux (FD_function (r,t,e,fcls),ann) as fd) =
+ let rewrite_function recset (FD_aux (FD_function (r,t,e,fcls),ann) as fd) =
let loc = Parse_ast.Generated (fst ann) in
match fcls with
| FCL_aux (FCL_Funcl (id,_),fcl_ann)::_ -> begin
@@ -4562,15 +4568,16 @@ let rewrite_explicit_measure env (Defs defs) =
| _, P_aux (P_tup ps,_) -> ps
| _, _ -> [measure_pat]
in
- let mk_wrap i (P_aux (p,(l,_))) =
+ let mk_wrap i (P_aux (p,(l,_)) as p_full) =
let id =
match p with
| P_id id
| P_typ (_,(P_aux (P_id id,_))) -> id
+ | P_lit _
| P_wild
| P_typ (_,(P_aux (P_wild,_))) ->
mk_id ("_arg" ^ string_of_int i)
- | _ -> raise (Reporting.err_todo l "Measure patterns can only be identifiers or wildcards")
+ | _ -> raise (Reporting.err_todo l ("Measure patterns can only be identifiers or wildcards, not " ^ string_of_pat p_full))
in
P_aux (P_id id,(loc,empty_tannot)),
E_aux (E_id id,(loc,empty_tannot))
@@ -4588,15 +4595,22 @@ let rewrite_explicit_measure env (Defs defs) =
let new_rec =
Rec_aux (Rec_measure (P_aux (P_tup (List.map (fun _ -> P_aux (P_wild,(loc,empty_tannot))) measure_pats @ [P_aux (P_id limit,(loc,empty_tannot))]),(loc,empty_tannot)), E_aux (E_id limit, (loc,empty_tannot))), loc)
in
- [FD_aux (FD_function (new_rec,t,e,List.map rewrite_funcl fcls),ann);
- FD_aux (FD_function (Rec_aux (Rec_nonrec,loc),t,e,[wrapper]),ann)]
- | exception Not_found -> [fd]
+ FD_aux (FD_function (new_rec,t,e,List.map (rewrite_funcl recset) fcls),ann),
+ [FD_aux (FD_function (Rec_aux (Rec_nonrec,loc),t,e,[wrapper]),ann)]
+ | exception Not_found -> fd,[]
end
- | _ -> [fd]
+ | _ -> fd,[]
in
let rewrite_def = function
| DEF_spec vs -> List.map (fun vs -> DEF_spec vs) (rewrite_spec vs)
- | DEF_fundef fd -> List.map (fun f -> DEF_fundef f) (rewrite_function fd)
+ | DEF_fundef fd ->
+ let fd,extra = rewrite_function (IdSet.singleton (id_of_fundef fd)) fd in
+ List.map (fun f -> DEF_fundef f) (fd::extra)
+ | (DEF_internal_mutrec fds) as d ->
+ let recset = ids_of_def d in
+ let fds,extras = List.split (List.map (rewrite_function recset) fds) in
+ let extras = List.concat extras in
+ (DEF_internal_mutrec fds)::(List.map (fun f -> DEF_fundef f) extras)
| d -> [d]
in
Defs (List.flatten (List.map rewrite_def defs))
@@ -4656,176 +4670,270 @@ let if_mono f env defs =
let if_mwords f env defs =
if !Pretty_print_lem.opt_mwords then f env defs else if_mono f env defs
-let rewrite_defs_lem = [
- ("realise_mappings", rewrite_defs_realise_mappings);
- ("remove_mapping_valspecs", remove_mapping_valspecs);
- ("toplevel_string_append", rewrite_defs_toplevel_string_append);
- ("pat_string_append", rewrite_defs_pat_string_append);
- ("mapping_builtins", rewrite_defs_mapping_patterns);
- ("mono_rewrites", mono_rewrites);
- ("recheck_defs", if_mono recheck_defs);
- ("rewrite_toplevel_nexps", if_mono rewrite_toplevel_nexps);
- ("monomorphise", if_mono monomorphise);
- ("recheck_defs", if_mwords recheck_defs);
- ("add_bitvector_casts", if_mwords (fun _ -> Monomorphise.add_bitvector_casts));
- ("rewrite_atoms_to_singletons", if_mono (fun _ -> Monomorphise.rewrite_atoms_to_singletons));
- ("recheck_defs", if_mwords recheck_defs);
- ("rewrite_undefined", rewrite_undefined_if_gen false);
- ("rewrite_defs_vector_string_pats_to_bit_list", rewrite_defs_vector_string_pats_to_bit_list);
- ("remove_not_pats", rewrite_defs_not_pats);
- ("remove_impossible_int_cases", Constant_propagation.remove_impossible_int_cases);
- ("pat_lits", rewrite_defs_pat_lits rewrite_lit_lem);
- ("vector_concat_assignments", rewrite_vector_concat_assignments);
- ("tuple_assignments", rewrite_tuple_assignments);
- ("simple_assignments", rewrite_simple_assignments);
- ("remove_vector_concat", rewrite_defs_remove_vector_concat);
- ("remove_bitvector_pats", rewrite_defs_remove_bitvector_pats);
- ("remove_numeral_pats", rewrite_defs_remove_numeral_pats);
- ("guarded_pats", rewrite_defs_guarded_pats);
- ("bitvector_exps", rewrite_bitvector_exps);
- (* ("register_ref_writes", rewrite_register_ref_writes); *)
- ("nexp_ids", rewrite_defs_nexp_ids);
- ("fix_val_specs", rewrite_fix_val_specs);
- ("split_execute", rewrite_split_fun_ctor_pats "execute");
- ("recheck_defs", recheck_defs);
- ("exp_lift_assign", rewrite_defs_exp_lift_assign);
- (* ("remove_assert", rewrite_defs_remove_assert); *)
- ("top_sort_defs", fun _ -> top_sort_defs);
- (* ("sizeof", rewrite_sizeof); *)
- ("early_return", rewrite_defs_early_return);
- ("fix_val_specs", rewrite_fix_val_specs);
- (* early_return currently breaks the types *)
- ("recheck_defs", recheck_defs);
- ("remove_blocks", rewrite_defs_remove_blocks);
- ("letbind_effects", rewrite_defs_letbind_effects);
- ("remove_e_assign", rewrite_defs_remove_e_assign);
- ("internal_lets", rewrite_defs_internal_lets);
- ("remove_superfluous_letbinds", rewrite_defs_remove_superfluous_letbinds);
- ("remove_superfluous_returns", rewrite_defs_remove_superfluous_returns);
- ("merge function clauses", merge_funcls);
- ("recheck_defs", recheck_defs)
+type rewriter =
+ | Basic_rewriter of (Env.t -> tannot defs -> tannot defs)
+ | Bool_rewriter of (bool -> rewriter)
+ | String_rewriter of (string -> rewriter)
+ | Literal_rewriter of ((lit -> bool) -> rewriter)
+
+type rewriter_arg =
+ | If_mono_arg
+ | If_mwords_arg
+ | Bool_arg of bool
+ | String_arg of string
+ | Literal_arg of string
+
+let instantiate_rewrite rewriter args =
+ let selector_function = function
+ | "ocaml" -> rewrite_lit_ocaml
+ | "lem" -> rewrite_lit_lem
+ | "all" -> (fun _ -> true)
+ | arg ->
+ raise (Reporting.err_general Parse_ast.Unknown ("No rewrite for literal target \"" ^ arg ^ "\", valid targets are ocaml/lem/all"))
+ in
+ let instantiate rewriter arg =
+ match rewriter, arg with
+ | Basic_rewriter rw, If_mono_arg -> Basic_rewriter (if_mono rw)
+ | Basic_rewriter rw, If_mwords_arg -> Basic_rewriter (if_mwords rw)
+ | Bool_rewriter rw, Bool_arg b -> rw b
+ | String_rewriter rw, String_arg str -> rw str
+ | Literal_rewriter rw, Literal_arg selector -> rw (selector_function selector)
+ | _, _ ->
+ raise (Reporting.err_unreachable Parse_ast.Unknown __POS__ "Invalid rewrite argument")
+ in
+ match List.fold_left instantiate rewriter args with
+ | Basic_rewriter rw -> rw
+ | _ ->
+ raise (Reporting.err_general Parse_ast.Unknown "Rewrite not fully instantiated")
+
+let all_rewrites = [
+ ("no_effect_check", Basic_rewriter (fun _ defs -> opt_no_effects := true; defs));
+ ("recheck_defs", Basic_rewriter recheck_defs);
+ ("recheck_defs_without_effects", Basic_rewriter recheck_defs_without_effects);
+ ("optimize_recheck_defs", Basic_rewriter (fun _ -> Optimize.recheck));
+ ("realise_mappings", Basic_rewriter rewrite_defs_realise_mappings);
+ ("remove_mapping_valspecs", Basic_rewriter remove_mapping_valspecs);
+ ("toplevel_string_append", Basic_rewriter rewrite_defs_toplevel_string_append);
+ ("pat_string_append", Basic_rewriter rewrite_defs_pat_string_append);
+ ("mapping_builtins", Basic_rewriter rewrite_defs_mapping_patterns);
+ ("mono_rewrites", Basic_rewriter mono_rewrites);
+ ("toplevel_nexps", Basic_rewriter rewrite_toplevel_nexps);
+ ("monomorphise", Basic_rewriter monomorphise);
+ ("atoms_to_singletons", Basic_rewriter (fun _ -> Monomorphise.rewrite_atoms_to_singletons));
+ ("add_bitvector_casts", Basic_rewriter (fun _ -> Monomorphise.add_bitvector_casts));
+ ("atoms_to_singletons", Basic_rewriter (fun _ -> Monomorphise.rewrite_atoms_to_singletons));
+ ("remove_impossible_int_cases", Basic_rewriter Constant_propagation.remove_impossible_int_cases);
+ ("const_prop_mutrec", Basic_rewriter Constant_propagation_mutrec.rewrite_defs);
+ ("make_cases_exhaustive", Basic_rewriter MakeExhaustive.rewrite);
+ ("undefined", Bool_rewriter (fun b -> Basic_rewriter (rewrite_undefined_if_gen b)));
+ ("vector_string_pats_to_bit_list", Basic_rewriter rewrite_defs_vector_string_pats_to_bit_list);
+ ("remove_not_pats", Basic_rewriter rewrite_defs_not_pats);
+ ("pattern_literals", Literal_rewriter (fun f -> Basic_rewriter (rewrite_defs_pat_lits f)));
+ ("vector_concat_assignments", Basic_rewriter rewrite_vector_concat_assignments);
+ ("tuple_assignments", Basic_rewriter rewrite_tuple_assignments);
+ ("simple_assignments", Basic_rewriter rewrite_simple_assignments);
+ ("remove_vector_concat", Basic_rewriter rewrite_defs_remove_vector_concat);
+ ("remove_bitvector_pats", Basic_rewriter rewrite_defs_remove_bitvector_pats);
+ ("remove_numeral_pats", Basic_rewriter rewrite_defs_remove_numeral_pats);
+ ("guarded_pats", Basic_rewriter rewrite_defs_guarded_pats);
+ ("bitvector_exps", Basic_rewriter rewrite_bitvector_exps);
+ ("exp_lift_assign", Basic_rewriter rewrite_defs_exp_lift_assign);
+ ("early_return", Basic_rewriter rewrite_defs_early_return);
+ ("nexp_ids", Basic_rewriter rewrite_defs_nexp_ids);
+ ("fix_val_specs", Basic_rewriter rewrite_fix_val_specs);
+ ("remove_blocks", Basic_rewriter rewrite_defs_remove_blocks);
+ ("letbind_effects", Basic_rewriter rewrite_defs_letbind_effects);
+ ("remove_e_assign", Basic_rewriter rewrite_defs_remove_e_assign);
+ ("internal_lets", Basic_rewriter rewrite_defs_internal_lets);
+ ("remove_superfluous_letbinds", Basic_rewriter rewrite_defs_remove_superfluous_letbinds);
+ ("remove_superfluous_returns", Basic_rewriter rewrite_defs_remove_superfluous_returns);
+ ("merge_function_clauses", Basic_rewriter merge_funcls);
+ ("minimise_recursive_functions", Basic_rewriter minimise_recursive_functions);
+ ("move_termination_measures", Basic_rewriter move_termination_measures);
+ ("rewrite_explicit_measure", Basic_rewriter rewrite_explicit_measure);
+ ("simple_types", Basic_rewriter rewrite_simple_types);
+ ("overload_cast", Basic_rewriter rewrite_overload_cast);
+ ("top_sort_defs", Basic_rewriter (fun _ -> top_sort_defs));
+ ("constant_fold", Basic_rewriter (fun _ -> Constant_fold.rewrite_constant_function_calls));
+ ("split", String_rewriter (fun str -> Basic_rewriter (rewrite_split_fun_ctor_pats str)))
+ ]
+
+let rewrites_lem = [
+ ("realise_mappings", []);
+ ("remove_mapping_valspecs", []);
+ ("toplevel_string_append", []);
+ ("pat_string_append", []);
+ ("mapping_builtins", []);
+ ("mono_rewrites", []);
+ ("recheck_defs", [If_mono_arg]);
+ ("undefined", [Bool_arg false]);
+ ("toplevel_nexps", [If_mono_arg]);
+ ("monomorphise", [If_mono_arg]);
+ ("recheck_defs", [If_mwords_arg]);
+ ("add_bitvector_casts", [If_mwords_arg]);
+ ("atoms_to_singletons", [If_mono_arg]);
+ ("recheck_defs", [If_mwords_arg]);
+ ("vector_string_pats_to_bit_list", []);
+ ("remove_not_pats", []);
+ ("remove_impossible_int_cases", []);
+ ("pattern_literals", [Literal_arg "lem"]);
+ ("vector_concat_assignments", []);
+ ("tuple_assignments", []);
+ ("simple_assignments", []);
+ ("remove_vector_concat", []);
+ ("remove_bitvector_pats", []);
+ ("remove_numeral_pats", []);
+ ("guarded_pats", []);
+ ("bitvector_exps", []);
+ (* ("register_ref_writes", rewrite_register_ref_writes); *)
+ ("nexp_ids", []);
+ ("fix_val_specs", []);
+ ("split", [String_arg "execute"]);
+ ("recheck_defs", []);
+ ("top_sort_defs", []);
+ ("const_prop_mutrec", []);
+ ("vector_string_pats_to_bit_list", []);
+ ("exp_lift_assign", []);
+ ("early_return", []);
+ ("fix_val_specs", []);
+ (* early_return currently breaks the types *)
+ ("recheck_defs", []);
+ ("remove_blocks", []);
+ ("letbind_effects", []);
+ ("remove_e_assign", []);
+ ("internal_lets", []);
+ ("remove_superfluous_letbinds", []);
+ ("remove_superfluous_returns", []);
+ ("merge_function_clauses", []);
+ ("recheck_defs", [])
]
-let rewrite_defs_coq = [
- ("realise_mappings", rewrite_defs_realise_mappings);
- ("remove_mapping_valspecs", remove_mapping_valspecs);
- ("toplevel_string_append", rewrite_defs_toplevel_string_append);
- ("pat_string_append", rewrite_defs_pat_string_append);
- ("mapping_builtins", rewrite_defs_mapping_patterns);
- ("rewrite_undefined", rewrite_undefined_if_gen true);
- ("rewrite_defs_vector_string_pats_to_bit_list", rewrite_defs_vector_string_pats_to_bit_list);
- ("remove_not_pats", rewrite_defs_not_pats);
- ("remove_impossible_int_cases", Constant_propagation.remove_impossible_int_cases);
- ("pat_lits", rewrite_defs_pat_lits rewrite_lit_lem);
- ("vector_concat_assignments", rewrite_vector_concat_assignments);
- ("tuple_assignments", rewrite_tuple_assignments);
- ("simple_assignments", rewrite_simple_assignments);
- ("remove_vector_concat", rewrite_defs_remove_vector_concat);
- ("remove_bitvector_pats", rewrite_defs_remove_bitvector_pats);
- ("remove_numeral_pats", rewrite_defs_remove_numeral_pats);
- ("guarded_pats", rewrite_defs_guarded_pats);
- ("bitvector_exps", rewrite_bitvector_exps);
- (* ("register_ref_writes", rewrite_register_ref_writes); *)
- ("nexp_ids", rewrite_defs_nexp_ids);
- ("fix_val_specs", rewrite_fix_val_specs);
- ("split_execute", rewrite_split_fun_ctor_pats "execute");
- ("minimise_recursive_functions", minimise_recursive_functions);
- ("recheck_defs", recheck_defs);
- ("exp_lift_assign", rewrite_defs_exp_lift_assign);
- (* ("remove_assert", rewrite_defs_remove_assert); *)
- ("move_termination_measures", move_termination_measures);
- ("top_sort_defs", fun _ -> top_sort_defs);
- ("early_return", rewrite_defs_early_return);
- (* merge funcls before adding the measure argument so that it doesn't
+let rewrites_coq = [
+ ("realise_mappings", []);
+ ("remove_mapping_valspecs", []);
+ ("toplevel_string_append", []);
+ ("pat_string_append", []);
+ ("mapping_builtins", []);
+ ("undefined", [Bool_arg true]);
+ ("vector_string_pats_to_bit_list", []);
+ ("remove_not_pats", []);
+ ("remove_impossible_int_cases", []);
+ ("pattern_literals", [Literal_arg "lem"]);
+ ("vector_concat_assignments", []);
+ ("tuple_assignments", []);
+ ("simple_assignments", []);
+ ("remove_vector_concat", []);
+ ("remove_bitvector_pats", []);
+ ("remove_numeral_pats", []);
+ ("guarded_pats", []);
+ ("bitvector_exps", []);
+ (* ("register_ref_writes", rewrite_register_ref_writes); *)
+ ("nexp_ids", []);
+ ("fix_val_specs", []);
+ ("split", [String_arg "execute"]);
+ ("minimise_recursive_functions", []);
+ ("recheck_defs", []);
+ ("exp_lift_assign", []);
+ (* ("remove_assert", rewrite_defs_remove_assert); *)
+ ("move_termination_measures", []);
+ ("top_sort_defs", []);
+ ("early_return", []);
+ (* merge funcls before adding the measure argument so that it doesn't
disappear into an internal pattern match *)
- ("merge function clauses", merge_funcls);
- ("recheck_defs_without_effects", recheck_defs_without_effects);
- ("make_cases_exhaustive", MakeExhaustive.rewrite);
- ("rewrite_explicit_measure", rewrite_explicit_measure);
- ("recheck_defs_without_effects", recheck_defs_without_effects);
- ("fix_val_specs", rewrite_fix_val_specs);
- ("remove_blocks", rewrite_defs_remove_blocks);
- ("letbind_effects", rewrite_defs_letbind_effects);
- ("remove_e_assign", rewrite_defs_remove_e_assign);
- ("internal_lets", rewrite_defs_internal_lets);
- ("remove_superfluous_letbinds", rewrite_defs_remove_superfluous_letbinds);
- ("remove_superfluous_returns", rewrite_defs_remove_superfluous_returns);
- ("recheck_defs", recheck_defs)
+ ("merge_function_clauses", []);
+ ("recheck_defs_without_effects", []);
+ ("make_cases_exhaustive", []);
+ ("rewrite_explicit_measure", []);
+ ("recheck_defs_without_effects", []);
+ ("fix_val_specs", []);
+ ("remove_blocks", []);
+ ("letbind_effects", []);
+ ("remove_e_assign", []);
+ ("internal_lets", []);
+ ("remove_superfluous_letbinds", []);
+ ("remove_superfluous_returns", []);
+ ("recheck_defs", [])
]
-let rewrite_defs_ocaml = [
- (* ("undefined", rewrite_undefined); *)
- ("no_effect_check", (fun _ defs -> opt_no_effects := true; defs));
- ("realise_mappings", rewrite_defs_realise_mappings);
- ("toplevel_string_append", rewrite_defs_toplevel_string_append);
- ("pat_string_append", rewrite_defs_pat_string_append);
- ("mapping_builtins", rewrite_defs_mapping_patterns);
- ("rewrite_undefined", rewrite_undefined_if_gen false);
- ("rewrite_defs_vector_string_pats_to_bit_list", rewrite_defs_vector_string_pats_to_bit_list);
- ("pat_lits", rewrite_defs_pat_lits rewrite_lit_ocaml);
- ("vector_concat_assignments", rewrite_vector_concat_assignments);
- ("tuple_assignments", rewrite_tuple_assignments);
- ("simple_assignments", rewrite_simple_assignments);
- ("remove_not_pats", rewrite_defs_not_pats);
- ("remove_vector_concat", rewrite_defs_remove_vector_concat);
- ("remove_bitvector_pats", rewrite_defs_remove_bitvector_pats);
- ("remove_numeral_pats", rewrite_defs_remove_numeral_pats);
- ("exp_lift_assign", rewrite_defs_exp_lift_assign);
- ("top_sort_defs", fun _ -> top_sort_defs);
- ("simple_types", rewrite_simple_types);
- ("overload_cast", rewrite_overload_cast);
- (* ("separate_numbs", rewrite_defs_separate_numbs) *)
+let rewrites_ocaml = [
+ ("no_effect_check", []);
+ ("realise_mappings", []);
+ ("toplevel_string_append", []);
+ ("pat_string_append", []);
+ ("mapping_builtins", []);
+ ("undefined", [Bool_arg false]);
+ ("vector_string_pats_to_bit_list", []);
+ ("pattern_literals", [Literal_arg "ocaml"]);
+ ("vector_concat_assignments", []);
+ ("tuple_assignments", []);
+ ("simple_assignments", []);
+ ("remove_not_pats", []);
+ ("remove_vector_concat", []);
+ ("remove_bitvector_pats", []);
+ ("remove_numeral_pats", []);
+ ("exp_lift_assign", []);
+ ("top_sort_defs", []);
+ ("simple_types", []);
+ ("overload_cast", [])
]
-let opt_separate_execute = ref false
-
-let if_separate f env defs =
- if !opt_separate_execute then f env defs else defs
-
-let rewrite_defs_c = [
- ("no_effect_check", (fun _ defs -> opt_no_effects := true; defs));
-
- (* Remove bidirectional mappings *)
- ("realise_mappings", rewrite_defs_realise_mappings);
- ("toplevel_string_append", rewrite_defs_toplevel_string_append);
- ("pat_string_append", rewrite_defs_pat_string_append);
- ("mapping_builtins", rewrite_defs_mapping_patterns);
-
- (* Monomorphisation *)
- ("mono_rewrites", if_mono mono_rewrites);
- ("recheck_defs", if_mono recheck_defs);
- ("rewrite_toplevel_nexps", if_mono rewrite_toplevel_nexps);
- ("monomorphise", if_mono monomorphise);
- ("rewrite_atoms_to_singletons", if_mono (fun _ -> Monomorphise.rewrite_atoms_to_singletons));
- ("recheck_defs", if_mono recheck_defs);
-
- ("rewrite_undefined", rewrite_undefined_if_gen false);
- ("rewrite_defs_vector_string_pats_to_bit_list", rewrite_defs_vector_string_pats_to_bit_list);
- ("remove_not_pats", rewrite_defs_not_pats);
- ("pat_lits", rewrite_defs_pat_lits (fun _ -> true));
- ("vector_concat_assignments", rewrite_vector_concat_assignments);
- ("tuple_assignments", rewrite_tuple_assignments);
- ("simple_assignments", rewrite_simple_assignments);
- ("remove_vector_concat", rewrite_defs_remove_vector_concat);
- ("remove_bitvector_pats", rewrite_defs_remove_bitvector_pats);
- ("split_execute", if_separate (rewrite_split_fun_ctor_pats "execute"));
- ("exp_lift_assign", rewrite_defs_exp_lift_assign);
- ("merge_function_clauses", merge_funcls);
- ("recheck_defs", fun _ -> Optimize.recheck)
+let rewrites_c = [
+ ("no_effect_check", []);
+ ("realise_mappings", []);
+ ("toplevel_string_append", []);
+ ("pat_string_append", []);
+ ("mapping_builtins", []);
+ ("mono_rewrites", [If_mono_arg]);
+ ("recheck_defs", [If_mono_arg]);
+ ("toplevel_nexps", [If_mono_arg]);
+ ("monomorphise", [If_mono_arg]);
+ ("atoms_to_singletons", [If_mono_arg]);
+ ("recheck_defs", [If_mono_arg]);
+ ("undefined", [Bool_arg false]);
+ ("vector_string_pats_to_bit_list", []);
+ ("remove_not_pats", []);
+ ("pattern_literals", [Literal_arg "all"]);
+ ("vector_concat_assignments", []);
+ ("tuple_assignments", []);
+ ("simple_assignments", []);
+ ("remove_vector_concat", []);
+ ("remove_bitvector_pats", []);
+ ("exp_lift_assign", []);
+ ("merge_function_clauses", []);
+ ("optimize_recheck_defs", []);
+ ("constant_fold", [])
]
-let rewrite_defs_interpreter = [
- ("no_effect_check", (fun _ defs -> opt_no_effects := true; defs));
- ("realise_mappings", rewrite_defs_realise_mappings);
- ("toplevel_string_append", rewrite_defs_toplevel_string_append);
- ("pat_string_append", rewrite_defs_pat_string_append);
- ("mapping_builtins", rewrite_defs_mapping_patterns);
- ("rewrite_undefined", rewrite_undefined_if_gen false);
- ("vector_concat_assignments", rewrite_vector_concat_assignments);
- ("tuple_assignments", rewrite_tuple_assignments);
- ("simple_assignments", rewrite_simple_assignments)
+let rewrites_interpreter = [
+ ("no_effect_check", []);
+ ("realise_mappings", []);
+ ("toplevel_string_append", []);
+ ("pat_string_append", []);
+ ("mapping_builtins", []);
+ ("undefined", [Bool_arg false]);
+ ("vector_concat_assignments", []);
+ ("tuple_assignments", []);
+ ("simple_assignments", [])
]
+let rewrites_target tgt =
+ match tgt with
+ | "coq" -> rewrites_coq
+ | "lem" -> rewrites_lem
+ | "ocaml" -> rewrites_ocaml
+ | "c" -> rewrites_c
+ | "ir" -> rewrites_c
+ | "sail" -> []
+ | "latex" -> []
+ | "interpreter" -> rewrites_interpreter
+ | "tofrominterp" -> rewrites_interpreter
+ | "marshal" -> rewrites_interpreter
+ | _ ->
+ raise (Reporting.err_unreachable Parse_ast.Unknown __POS__ ("Invalid target for rewriting: " ^ tgt))
+
+let rewrite_defs_target tgt =
+ List.map (fun (name, args) -> (name, instantiate_rewrite (List.assoc name all_rewrites) args)) (rewrites_target tgt)
+
let rewrite_check_annot =
let check_annot exp =
try
diff --git a/src/rewrites.mli b/src/rewrites.mli
index 811d52e8..330f10b4 100644
--- a/src/rewrites.mli
+++ b/src/rewrites.mli
@@ -59,7 +59,6 @@ val opt_dmono_analysis : int ref
val opt_auto_mono : bool ref
val opt_dall_split_errors : bool ref
val opt_dmono_continue : bool ref
-val opt_separate_execute : bool ref
(* Generate a fresh id with the given prefix *)
val fresh_id : string -> l -> id
@@ -67,25 +66,31 @@ val fresh_id : string -> l -> id
(* Re-write undefined to functions created by -undefined_gen flag *)
val rewrite_undefined : bool -> Env.t -> tannot defs -> tannot defs
-(* Perform rewrites to exclude AST nodes not supported for ocaml out*)
-val rewrite_defs_ocaml : (string * (Env.t -> tannot defs -> tannot defs)) list
+(* Perform rewrites to create an AST supported for a specific target *)
+val rewrite_defs_target : string -> (string * (Env.t -> tannot defs -> tannot defs)) list
-(* Perform rewrites to exclude AST nodes not supported for interpreter *)
-val rewrite_defs_interpreter : (string * (Env.t -> tannot defs -> tannot defs)) list
+type rewriter =
+ | Basic_rewriter of (Env.t -> tannot defs -> tannot defs)
+ | Bool_rewriter of (bool -> rewriter)
+ | String_rewriter of (string -> rewriter)
+ | Literal_rewriter of ((lit -> bool) -> rewriter)
-(* Perform rewrites to exclude AST nodes not supported for lem out*)
-val rewrite_defs_lem : (string * (Env.t -> tannot defs -> tannot defs)) list
+val rewrite_lit_ocaml : lit -> bool
+val rewrite_lit_lem : lit -> bool
-(* Perform rewrites to exclude AST nodes not supported for coq out*)
-val rewrite_defs_coq : (string * (Env.t -> tannot defs -> tannot defs)) list
+type rewriter_arg =
+ | If_mono_arg
+ | If_mwords_arg
+ | Bool_arg of bool
+ | String_arg of string
+ | Literal_arg of string
+
+val all_rewrites : (string * rewriter) list
(* Warn about matches where we add a default case for Coq because they're not
exhaustive *)
val opt_coq_warn_nonexhaustive : bool ref
-(* Perform rewrites to exclude AST nodes not supported for C compilation *)
-val rewrite_defs_c : (string * (Env.t -> tannot defs -> tannot defs)) list
-
(* This is a special rewriter pass that checks AST invariants without
actually doing any re-writing *)
val rewrite_defs_check : (string * (Env.t -> tannot defs -> tannot defs)) list
diff --git a/src/sail.ml b/src/sail.ml
index d71e23c7..a0fc2e75 100644
--- a/src/sail.ml
+++ b/src/sail.ml
@@ -56,17 +56,8 @@ let lib = ref ([] : string list)
let opt_file_out : string option ref = ref None
let opt_interactive_script : string option ref = ref None
let opt_print_version = ref false
-let opt_print_initial_env = ref false
-let opt_print_verbose = ref false
-let opt_print_lem = ref false
-let opt_print_tofrominterp = ref false
+let opt_target = ref None
let opt_tofrominterp_output_dir : string option ref = ref None
-let opt_print_ocaml = ref false
-let opt_print_c = ref false
-let opt_print_ir = ref false
-let opt_print_latex = ref false
-let opt_print_coq = ref false
-let opt_print_cgen = ref false
let opt_memo_z3 = ref false
let opt_sanity = ref false
let opt_includes_c = ref ([]:string list)
@@ -76,19 +67,18 @@ let opt_libs_coq = ref ([]:string list)
let opt_file_arguments = ref ([]:string list)
let opt_process_elf : string option ref = ref None
let opt_ocaml_generators = ref ([]:string list)
-let opt_marshal_defs = ref false
-let opt_slice = ref ([]:string list)
+
+let set_target name = Arg.Unit (fun _ -> opt_target := Some name)
let options = Arg.align ([
( "-o",
Arg.String (fun f -> opt_file_out := Some f),
"<prefix> select output filename prefix");
( "-i",
- Arg.Tuple [Arg.Set Interactive.opt_interactive; Arg.Set Initial_check.opt_undefined_gen],
+ Arg.Tuple [Arg.Set Interactive.opt_interactive],
" start interactive interpreter");
( "-is",
- Arg.Tuple [Arg.Set Interactive.opt_interactive; Arg.Set Initial_check.opt_undefined_gen;
- Arg.String (fun s -> opt_interactive_script := Some s)],
+ Arg.Tuple [Arg.Set Interactive.opt_interactive; Arg.String (fun s -> opt_interactive_script := Some s)],
"<filename> start interactive interpreter and execute commands in script");
( "-iout",
Arg.String (fun file -> Value.output_redirect (open_out file)),
@@ -100,22 +90,22 @@ let options = Arg.align ([
Arg.Clear Util.opt_warnings,
" do not print warnings");
( "-tofrominterp",
- Arg.Set opt_print_tofrominterp,
+ set_target "tofrominterp",
" output OCaml functions to translate between shallow embedding and interpreter");
( "-tofrominterp_lem",
- Arg.Set ToFromInterp_backend.lem_mode,
- " output embedding translation for the Lem backend rather than the OCaml backend");
+ Arg.Tuple [set_target "tofrominterp"; Arg.Set ToFromInterp_backend.lem_mode],
+ " output embedding translation for the Lem backend rather than the OCaml backend, implies -tofrominterp");
( "-tofrominterp_output_dir",
Arg.String (fun dir -> opt_tofrominterp_output_dir := Some dir),
" set a custom directory to output embedding translation OCaml");
( "-ocaml",
- Arg.Tuple [Arg.Set opt_print_ocaml; Arg.Set Initial_check.opt_undefined_gen],
+ Arg.Tuple [set_target "ocaml"; Arg.Set Initial_check.opt_undefined_gen],
" output an OCaml translated version of the input");
( "-ocaml-nobuild",
Arg.Set Ocaml_backend.opt_ocaml_nobuild,
" do not build generated OCaml");
( "-ocaml_trace",
- Arg.Tuple [Arg.Set opt_print_ocaml; Arg.Set Initial_check.opt_undefined_gen; Arg.Set Ocaml_backend.opt_trace_ocaml],
+ Arg.Tuple [set_target "ocaml"; Arg.Set Initial_check.opt_undefined_gen; Arg.Set Ocaml_backend.opt_trace_ocaml],
" output an OCaml translated version of the input with tracing instrumentation, implies -ocaml");
( "-ocaml_build_dir",
Arg.String (fun dir -> Ocaml_backend.opt_ocaml_build_dir := dir),
@@ -133,7 +123,7 @@ let options = Arg.align ([
Arg.Set Type_check.opt_smt_linearize,
"(experimental) force linearization for constraints involving exponentials");
( "-latex",
- Arg.Tuple [Arg.Set opt_print_latex; Arg.Clear Type_check.opt_expand_valspec],
+ Arg.Tuple [set_target "latex"; Arg.Clear Type_check.opt_expand_valspec],
" pretty print the input to LaTeX");
( "-latex_prefix",
Arg.String (fun prefix -> Latex.opt_prefix := prefix),
@@ -142,13 +132,13 @@ let options = Arg.align ([
Arg.Clear Latex.opt_simple_val,
" print full valspecs in LaTeX output");
( "-marshal",
- Arg.Tuple [Arg.Set opt_marshal_defs; Arg.Set Initial_check.opt_undefined_gen],
+ Arg.Tuple [set_target "marshal"; Arg.Set Initial_check.opt_undefined_gen],
" OCaml-marshal out the rewritten AST to a file");
( "-ir",
- Arg.Set opt_print_ir,
+ set_target "ir",
" print intermediate representation");
( "-c",
- Arg.Tuple [Arg.Set opt_print_c; Arg.Set Initial_check.opt_undefined_gen],
+ Arg.Tuple [set_target "c"; Arg.Set Initial_check.opt_undefined_gen],
" output a C translated version of the input");
( "-c_include",
Arg.String (fun i -> opt_includes_c := i::!opt_includes_c),
@@ -159,9 +149,6 @@ let options = Arg.align ([
( "-c_no_rts",
Arg.Set C_backend.opt_no_rts,
" do not include the Sail runtime" );
- ( "-c_separate_execute",
- Arg.Set Rewrites.opt_separate_execute,
- " separate execute scattered function into multiple functions");
( "-c_prefix",
Arg.String (fun prefix -> C_backend.opt_prefix := prefix),
"<prefix> prefix generated C functions" );
@@ -191,20 +178,14 @@ let options = Arg.align ([
( "-Oconstant_fold",
Arg.Set Constant_fold.optimize_constant_fold,
" apply constant folding optimizations");
- ( "-Oexperimental",
- Arg.Set C_backend.optimize_experimental,
- " turn on additional, experimental optimisations");
( "-static",
Arg.Set C_backend.opt_static,
" make generated C functions static");
( "-trace",
Arg.Tuple [Arg.Set Ocaml_backend.opt_trace_ocaml],
" instrument output with tracing");
- ( "-cgen",
- Arg.Set opt_print_cgen,
- " generate CGEN source");
( "-lem",
- Arg.Set opt_print_lem,
+ set_target "lem",
" output a Lem translated version of the input");
( "-lem_output_dir",
Arg.String (fun dir -> Process_file.opt_lem_output_dir := Some dir),
@@ -222,7 +203,7 @@ let options = Arg.align ([
Arg.Set Pretty_print_lem.opt_mwords,
" use native machine word library for Lem output");
( "-coq",
- Arg.Set opt_print_coq,
+ set_target "coq",
" output a Coq translated version of the input");
( "-coq_output_dir",
Arg.String (fun dir -> Process_file.opt_coq_output_dir := Some dir),
@@ -292,17 +273,20 @@ let options = Arg.align ([
( "-dmono_continue",
Arg.Set Rewrites.opt_dmono_continue,
" continue despite monomorphisation errors");
+ ( "-const_prop_mutrec",
+ Arg.String (fun name -> Constant_propagation_mutrec.targets := Ast_util.mk_id name :: !Constant_propagation_mutrec.targets),
+ " unroll function in a set of mutually recursive functions");
( "-verbose",
Arg.Int (fun verbosity -> Util.opt_verbosity := verbosity),
" produce verbose output");
( "-output_sail",
- Arg.Set opt_print_verbose,
+ set_target "sail",
" print Sail code after type checking and initial rewriting");
( "-ddump_tc_ast",
Arg.Set opt_ddump_tc_ast,
" (debug) dump the typechecked ast to stdout");
( "-ddump_rewrite_ast",
- Arg.String (fun l -> opt_ddump_rewrite_ast := Some (l, 0)),
+ Arg.String (fun l -> opt_ddump_rewrite_ast := Some (l, 0); Specialize.opt_ddump_spec_ast := Some (l, 0)),
"<prefix> (debug) dump the ast after each rewriting step to <prefix>_<i>.lem");
( "-ddump_flow_graphs",
Arg.Set Jib_compile.opt_debug_flow_graphs,
@@ -328,9 +312,6 @@ let options = Arg.align ([
( "-dprofile",
Arg.Set Profile.opt_profile,
" (debug) provide basic profiling information for rewriting passes within Sail");
- ( "-slice",
- Arg.String (fun s -> opt_slice := s::!opt_slice),
- "<id> produce version of input restricted to the given function");
( "-v",
Arg.Set opt_print_version,
" print version");
@@ -377,7 +358,7 @@ let load_files ?check:(check=false) type_envs files =
("out.sail", ast, type_envs)
else
let ast = Scattered.descatter ast in
- let ast = rewrite_ast type_envs ast in
+ let ast = rewrite_ast_initial type_envs ast in
let out_name = match !opt_file_out with
| None when parsed = [] -> "out.sail"
@@ -386,151 +367,139 @@ let load_files ?check:(check=false) type_envs files =
(out_name, ast, type_envs)
-let main() =
+let prover_regstate tgt ast type_envs =
+ match tgt with
+ | Some "coq" ->
+ State.add_regstate_defs true type_envs ast
+ | Some "lem" ->
+ State.add_regstate_defs !Pretty_print_lem.opt_mwords type_envs ast
+ | _ ->
+ type_envs, ast
+
+let target name out_name ast type_envs =
+ match name with
+ | None -> ()
+
+ | Some "sail" ->
+ Pretty_print_sail.pp_defs stdout ast
+
+ | Some "ocaml" ->
+ let ocaml_generator_info =
+ match !opt_ocaml_generators with
+ | [] -> None
+ | _ -> Some (Ocaml_backend.orig_types_for_ocaml_generator ast, !opt_ocaml_generators)
+ in
+ let out = match !opt_file_out with None -> "out" | Some s -> s in
+ Ocaml_backend.ocaml_compile out ast ocaml_generator_info
+
+ | Some "tofrominterp" ->
+ let out = match !opt_file_out with None -> "out" | Some s -> s in
+ ToFromInterp_backend.tofrominterp_output !opt_tofrominterp_output_dir out ast
+
+ | Some "marshal" ->
+ let out_filename = match !opt_file_out with None -> "out" | Some s -> s in
+ let f = open_out_bin (out_filename ^ ".defs") in
+ let remove_prover (l, tannot) =
+ if Type_check.is_empty_tannot tannot then
+ (l, Type_check.empty_tannot)
+ else
+ (l, Type_check.replace_env (Type_check.Env.set_prover None (Type_check.env_of_tannot tannot)) tannot)
+ in
+ Marshal.to_string (Ast_util.map_defs_annot remove_prover ast, Type_check.Env.set_prover None type_envs) [Marshal.Compat_32]
+ |> B64.encode
+ |> output_string f;
+ close_out f
+
+ | Some "c" ->
+ let ast_c, type_envs = Specialize.(specialize typ_ord_specialization ast type_envs) in
+ let ast_c, type_envs =
+ if !opt_specialize_c then
+ Specialize.(specialize' 2 int_specialization ast_c type_envs)
+ else
+ ast_c, type_envs
+ in
+ let close, output_chan = match !opt_file_out with Some f -> true, open_out (f ^ ".c") | None -> false, stdout in
+ Util.opt_warnings := true;
+ C_backend.compile_ast type_envs output_chan (!opt_includes_c) ast_c;
+ flush output_chan;
+ if close then close_out output_chan else ()
+
+ | Some "ir" ->
+ let ast_c, type_envs = Specialize.(specialize typ_ord_specialization ast type_envs) in
+ let ast_c, type_envs = Specialize.(specialize' 2 int_specialization_with_externs ast_c type_envs) in
+ let close, output_chan =
+ match !opt_file_out with
+ | Some f -> Util.opt_colors := false; (true, open_out (f ^ ".ir.sail"))
+ | None -> false, stdout
+ in
+ Util.opt_warnings := true;
+ let cdefs, _ = C_backend.jib_of_ast type_envs ast_c in
+ let str = Pretty_print_sail.to_string PPrint.(separate_map hardline Jib_util.pp_cdef cdefs) in
+ output_string output_chan (str ^ "\n");
+ flush output_chan;
+ if close then close_out output_chan else ()
+
+ | Some "lem" ->
+ output "" (Lem_out (!opt_libs_lem)) [(out_name, type_envs, ast)]
+
+ | Some "coq" ->
+ output "" (Coq_out (!opt_libs_coq)) [(out_name, type_envs, ast)]
+
+ | Some "latex" ->
+ Util.opt_warnings := true;
+ let latex_dir = match !opt_file_out with None -> "sail_latex" | Some s -> s in
+ begin
+ try
+ if not (Sys.is_directory latex_dir) then begin
+ prerr_endline ("Failure: latex output location exists and is not a directory: " ^ latex_dir);
+ exit 1
+ end
+ with Sys_error(_) -> Unix.mkdir latex_dir 0o755
+ end;
+ Latex.opt_directory := latex_dir;
+ let chan = open_out (Filename.concat latex_dir "commands.tex") in
+ output_string chan (Pretty_print_sail.to_string (Latex.defs ast));
+ close_out chan
+
+ | Some t ->
+ raise (Reporting.err_unreachable Parse_ast.Unknown __POS__ ("Undefined target: " ^ t))
+
+let main () =
if !opt_print_version then
print_endline version
else
- let out_name, ast, type_envs = load_files Type_check.initial_env !opt_file_arguments in
- Util.opt_warnings := false; (* Don't show warnings during re-writing for now *)
-
- begin match !opt_process_elf, !opt_file_out with
- | Some elf, Some out ->
- begin
- let open Elf_loader in
- let chan = open_out out in
- load_elf ~writer:(write_file chan) elf;
- output_string chan "elf_entry\n";
- output_string chan (Big_int.to_string !opt_elf_entry ^ "\n");
- close_out chan;
- exit 0
- end
- | Some _, None ->
- prerr_endline "Failure: No output file given for processed ELF (option -o).";
- exit 1
- | None, _ -> ()
- end;
-
- let ocaml_generator_info =
- match !opt_ocaml_generators with
- | [] -> None
- | _ -> Some (Ocaml_backend.orig_types_for_ocaml_generator ast, !opt_ocaml_generators)
- in
-
begin
- (if !(Interactive.opt_interactive)
- then
- (Interactive.ast := Process_file.rewrite_ast_interpreter type_envs ast; Interactive.env := type_envs)
- else ());
- (if !(opt_sanity)
- then
- let _ = rewrite_ast_check type_envs ast in
- ()
- else ());
- (if !(opt_print_verbose)
- then ((Pretty_print_sail.pp_defs stdout) ast)
- else ());
- (match !opt_slice with
- | [] -> ()
- | ids ->
- let ids = List.map Ast_util.mk_id ids in
- let ids = Ast_util.IdSet.of_list ids in
- Pretty_print_sail.pp_defs stdout (Specialize.slice_defs type_envs ast ids));
- (if !(opt_print_ocaml)
- then
- let ast_ocaml = rewrite_ast_ocaml type_envs ast in
- let out = match !opt_file_out with None -> "out" | Some s -> s in
- Ocaml_backend.ocaml_compile out ast_ocaml ocaml_generator_info
- else ());
- (if !opt_print_tofrominterp
- then
- let ast = rewrite_ast_interpreter type_envs ast in
- let out = match !opt_file_out with None -> "out" | Some s -> s in
- ToFromInterp_backend.tofrominterp_output !opt_tofrominterp_output_dir out ast
- else ());
- (if !(opt_print_c)
- then
- let ast_c = rewrite_ast_c type_envs ast in
- let ast_c, type_envs = Specialize.(specialize typ_ord_specialization ast_c type_envs) in
- let ast_c, type_envs =
- if !opt_specialize_c then
- Specialize.(specialize' 2 int_specialization ast_c type_envs)
- else
- ast_c, type_envs
- in
- let output_chan = match !opt_file_out with Some f -> open_out (f ^ ".c") | None -> stdout in
- Util.opt_warnings := true;
- C_backend.compile_ast type_envs output_chan (!opt_includes_c) ast_c;
- close_out output_chan
- else ());
- (if !(opt_print_ir)
- then
- let ast_c = rewrite_ast_c type_envs ast in
- let ast_c, type_envs = Specialize.(specialize typ_ord_specialization ast_c type_envs) in
- let ast_c, type_envs = Specialize.(specialize' 2 int_specialization ast_c type_envs) in
- let output_chan =
- match !opt_file_out with
- | Some f -> Util.opt_colors := false; open_out (f ^ ".ir.sail")
- | None -> stdout
- in
- Util.opt_warnings := true;
- let cdefs, _ = C_backend.jib_of_ast type_envs ast_c in
- let str = Pretty_print_sail.to_string PPrint.(separate_map hardline Jib_util.pp_cdef cdefs) in
- output_string output_chan (str ^ "\n");
- close_out output_chan
- else ());
- (if !(opt_print_cgen)
- then Cgen_backend.output type_envs ast
- else ());
- (if !(opt_print_lem)
- then
- let mwords = !Pretty_print_lem.opt_mwords in
- let type_envs, ast_lem = State.add_regstate_defs mwords type_envs ast in
- let ast_lem = rewrite_ast_lem type_envs ast_lem in
- output "" (Lem_out (!opt_libs_lem)) [out_name,type_envs,ast_lem]
- else ());
- (if !(opt_print_coq)
- then
- let type_envs, ast_coq = State.add_regstate_defs true type_envs ast in
- let ast_coq = rewrite_ast_coq type_envs ast_coq in
- output "" (Coq_out (!opt_libs_coq)) [out_name,type_envs,ast_coq]
- else ());
- (if !(opt_print_latex)
- then
- begin
- Util.opt_warnings := true;
- let latex_dir = match !opt_file_out with None -> "sail_latex" | Some s -> s in
- begin
- try
- if not (Sys.is_directory latex_dir) then begin
- prerr_endline ("Failure: latex output location exists and is not a directory: " ^ latex_dir);
- exit 1
- end
- with Sys_error(_) -> Unix.mkdir latex_dir 0o755
- end;
- Latex.opt_directory := latex_dir;
- let chan = open_out (Filename.concat latex_dir "commands.tex") in
- output_string chan (Pretty_print_sail.to_string (Latex.defs ast));
- close_out chan
- end
- else ());
- (if !(opt_marshal_defs)
- then
+ let out_name, ast, type_envs = load_files Type_check.initial_env !opt_file_arguments in
+ Util.opt_warnings := false; (* Don't show warnings during re-writing for now *)
+
+ begin match !opt_process_elf, !opt_file_out with
+ | Some elf, Some out ->
begin
- let ast_marshal = rewrite_ast_interpreter type_envs ast in
- let out_filename = match !opt_file_out with None -> "out" | Some s -> s in
- let f = open_out_bin (out_filename ^ ".defs") in
- let remove_prover (l, tannot) =
- if Type_check.is_empty_tannot tannot then
- (l, Type_check.empty_tannot)
- else
- (l, Type_check.replace_env (Type_check.Env.set_prover None (Type_check.env_of_tannot tannot)) tannot)
- in
- Marshal.to_string (Ast_util.map_defs_annot remove_prover ast_marshal, Type_check.Env.set_prover None type_envs) [Marshal.Compat_32]
- |> B64.encode
- |> output_string f;
- close_out f
+ let open Elf_loader in
+ let chan = open_out out in
+ load_elf ~writer:(write_file chan) elf;
+ output_string chan "elf_entry\n";
+ output_string chan (Big_int.to_string !opt_elf_entry ^ "\n");
+ close_out chan;
+ exit 0
end
- else ());
+ | Some _, None ->
+ prerr_endline "Failure: No output file given for processed ELF (option -o).";
+ exit 1
+ | None, _ -> ()
+ end;
+
+ if !opt_sanity then
+ ignore (rewrite_ast_check type_envs ast)
+ else ();
+
+ let type_envs, ast = prover_regstate !opt_target ast type_envs in
+ let ast = match !opt_target with Some tgt -> rewrite_ast_target tgt type_envs ast | None -> ast in
+ target !opt_target out_name ast type_envs;
+
+ if !Interactive.opt_interactive then
+ (Interactive.ast := ast; Interactive.env := type_envs)
+ else ();
if !opt_memo_z3 then Constraint.save_digests () else ()
end
diff --git a/src/sail_lib.ml b/src/sail_lib.ml
index d1a21b73..39485769 100644
--- a/src/sail_lib.ml
+++ b/src/sail_lib.ml
@@ -187,36 +187,27 @@ let sint = function
let add_int (x, y) = Big_int.add x y
let sub_int (x, y) = Big_int.sub x y
let mult (x, y) = Big_int.mul x y
+
+(* This is euclidian division from lem *)
let quotient (x, y) = Big_int.div x y
-(* Big_int does not provide divide with rounding towards zero so roll
- our own, assuming that division of positive integers rounds down *)
-let quot_round_zero (x, y) =
- let posX = Big_int.greater_equal x Big_int.zero in
- let posY = Big_int.greater_equal y Big_int.zero in
- let absX = Big_int.abs x in
- let absY = Big_int.abs y in
- let q = Big_int.div absX absY in
- if posX != posY then
- Big_int.negate q
- else
- q
+(* This is the same as tdiv_int, kept for compatibility with old preludes *)
+let quot_round_zero (x, y) =
+ Big_int.integerDiv_t x y
(* The corresponding remainder function for above just respects the sign of x *)
-let rem_round_zero (x, y) =
- let posX = Big_int.greater_equal x Big_int.zero in
- let absX = Big_int.abs x in
- let absY = Big_int.abs y in
- let r = Big_int.modulus absX absY in
- if posX then
- r
- else
- Big_int.negate r
+let rem_round_zero (x, y) =
+ Big_int.integerRem_t x y
+(* Lem provides euclidian modulo by default *)
let modulus (x, y) = Big_int.modulus x y
let negate x = Big_int.negate x
+let tdiv_int (x, y) = Big_int.integerDiv_t x y
+
+let tmod_int (x, y) = Big_int.integerRem_t x y
+
let add_bit_with_carry (x, y, carry) =
match x, y, carry with
| B0, B0, B0 -> B0, B0
@@ -695,6 +686,7 @@ let string_of_zbit = function
| B1 -> "1"
let string_of_znat n = Big_int.to_string n
let string_of_zint n = Big_int.to_string n
+let string_of_zimplicit n = Big_int.to_string n
let string_of_zunit () = "()"
let string_of_zbool = function
| true -> "true"
diff --git a/src/slice.ml b/src/slice.ml
index f50104c4..fa574b7f 100644
--- a/src/slice.ml
+++ b/src/slice.ml
@@ -184,6 +184,7 @@ let add_def_to_graph graph def =
| E_id id ->
begin match Env.lookup_id id env with
| Register _ -> graph := G.add_edge' self (Register id) !graph
+ | Enum _ -> graph := G.add_edge' self (Constructor id) !graph
| _ ->
if IdSet.mem id (Env.get_toplevel_lets env) then
graph := G.add_edge' self (Letbind id) !graph
@@ -248,8 +249,8 @@ let add_def_to_graph graph def =
IdSet.iter (fun ctor_id -> graph := G.add_edge' (Constructor ctor_id) (Type id) !graph) (snd ctor_nodes);
IdSet.iter (fun typ_id -> graph := G.add_edge' (Type id) (Type typ_id) !graph) (fst ctor_nodes);
scan_typquant (Type id) typq
- | TD_enum (id, _, _) ->
- graph := G.add_edges' (Type id) [] !graph
+ | TD_enum (id, ctors, _) ->
+ List.iter (fun ctor_id -> graph := G.add_edge' (Constructor ctor_id) (Type id) !graph) ctors
| TD_bitfield _ ->
Reporting.unreachable l __POS__ "Bitfield should be re-written"
in
@@ -268,7 +269,11 @@ let add_def_to_graph graph def =
IdSet.iter (fun id -> ignore (rewrite_let (rewriters (Letbind id)) lb)) ids
| DEF_type tdef ->
add_type_def_to_graph tdef
- | DEF_pragma _ -> ()
+ | DEF_reg_dec (DEC_aux (DEC_reg (_, _, typ, id), _)) ->
+ IdSet.iter (fun typ_id -> graph := G.add_edge' (Register id) (Type typ_id) !graph) (typ_ids typ)
+ | DEF_reg_dec (DEC_aux (DEC_config (id, typ, exp), _)) ->
+ ignore (rewrite_exp (rewriters (Register id)) exp);
+ IdSet.iter (fun typ_id -> graph := G.add_edge' (Register id) (Type typ_id) !graph) (typ_ids typ)
| _ -> ()
end;
G.fix_leaves !graph
@@ -283,6 +288,51 @@ let rec graph_of_ast (Defs defs) =
| [] -> G.empty
+let id_of_typedef (TD_aux (aux, _)) =
+ match aux with
+ | TD_abbrev (id, _, _) -> id
+ | TD_record (id, _, _, _) -> id
+ | TD_variant (id, _, _, _) -> id
+ | TD_enum (id, _, _) -> id
+ | TD_bitfield (id, _, _) -> id
+
+let id_of_reg_dec (DEC_aux (aux, _)) =
+ match aux with
+ | DEC_reg (_, _, _, id) -> id
+ | DEC_config (id, _, _) -> id
+ | _ -> assert false
+
+
+let filter_ast cuts g (Defs defs) =
+ let rec filter_ast' g =
+ let module NM = Map.Make(Node) in
+ function
+ | DEF_fundef fdef :: defs when IdSet.mem (id_of_fundef fdef) cuts -> filter_ast' g defs
+ | DEF_fundef fdef :: defs when NM.mem (Function (id_of_fundef fdef)) g -> DEF_fundef fdef :: filter_ast' g defs
+ | DEF_fundef _ :: defs -> filter_ast' g defs
+
+ | DEF_reg_dec rdec :: defs when NM.mem (Register (id_of_reg_dec rdec)) g -> DEF_reg_dec rdec :: filter_ast' g defs
+ | DEF_reg_dec _ :: defs -> filter_ast' g defs
+
+ | DEF_spec vs :: defs when NM.mem (Function (id_of_val_spec vs)) g -> DEF_spec vs :: filter_ast' g defs
+ | DEF_spec _ :: defs -> filter_ast' g defs
+
+ | DEF_val (LB_aux (LB_val (pat, exp), _) as lb) :: defs ->
+ let ids = pat_ids pat |> IdSet.elements in
+ if List.exists (fun id -> NM.mem (Letbind id) g) ids then
+ DEF_val lb :: filter_ast' g defs
+ else
+ filter_ast' g defs
+
+ | DEF_type tdef :: defs when NM.mem (Type (id_of_typedef tdef)) g -> DEF_type tdef :: filter_ast' g defs
+ | DEF_type _ :: defs -> filter_ast' g defs
+
+ | def :: defs -> def :: filter_ast' g defs
+
+ | [] -> []
+ in
+ Defs (filter_ast' g defs)
+
let dot_of_ast out_chan ast =
let module G = Graph.Make(Node) in
let module NodeSet = Set.Make(Node) in
diff --git a/src/slice.mli b/src/slice.mli
index 09558ebf..04f140fe 100644
--- a/src/slice.mli
+++ b/src/slice.mli
@@ -49,6 +49,7 @@
(**************************************************************************)
open Ast
+open Ast_util
type node =
| Register of id
@@ -66,3 +67,5 @@ end
val graph_of_ast : Type_check.tannot defs -> Graph.Make(Node).graph
val dot_of_ast : out_channel -> Type_check.tannot defs -> unit
+
+val filter_ast : IdSet.t -> Graph.Make(Node).graph -> Type_check.tannot defs -> Type_check.tannot defs
diff --git a/src/specialize.ml b/src/specialize.ml
index 6b5b108a..607084c8 100644
--- a/src/specialize.ml
+++ b/src/specialize.ml
@@ -53,7 +53,7 @@ open Ast_util
open Rewriter
let opt_ddump_spec_ast = ref None
-
+
let is_typ_ord_arg = function
| A_aux (A_typ _, _) -> true
| A_aux (A_order _, _) -> true
@@ -483,7 +483,7 @@ let specialize_id_overloads instantiations id (Defs defs) =
valspecs are then re-specialized. This process is iterated until
the whole spec is specialized. *)
-let initial_calls = IdSet.of_list
+let initial_calls = ref (IdSet.of_list
[ mk_id "main";
mk_id "__SetConfig";
mk_id "__ListConfig";
@@ -491,10 +491,12 @@ let initial_calls = IdSet.of_list
mk_id "decode";
mk_id "initialize_registers";
mk_id "append_64" (* used to construct bitvector literals in C backend *)
- ]
+ ])
+
+let add_initial_calls ids = initial_calls := IdSet.union ids !initial_calls
-let remove_unused_valspecs ?(initial_calls=initial_calls) env ast =
- let calls = ref initial_calls in
+let remove_unused_valspecs env ast =
+ let calls = ref !initial_calls in
let vs_ids = val_spec_ids ast in
let inspect_exp = function
@@ -527,14 +529,6 @@ let remove_unused_valspecs ?(initial_calls=initial_calls) env ast =
List.fold_left (fun ast id -> Defs (remove_unused ast id)) ast (IdSet.elements unused)
-let slice_defs env (Defs defs) keep_ids =
- let keep = function
- | DEF_fundef fd -> IdSet.mem (id_of_fundef fd) keep_ids
- | _ -> true
- in
- let defs = List.filter keep defs in
- remove_unused_valspecs env (Defs defs) ~initial_calls:keep_ids
-
let specialize_id spec id ast =
let instantiations = instantiations_of spec id ast in
let ast = specialize_id_valspec spec instantiations id ast in
diff --git a/src/specialize.mli b/src/specialize.mli
index 6ec8c2aa..0a64112c 100644
--- a/src/specialize.mli
+++ b/src/specialize.mli
@@ -54,6 +54,8 @@ open Ast
open Ast_util
open Type_check
+val opt_ddump_spec_ast : (string * int) option ref
+
type specialization
(** Only specialize Type- and Ord- kinded polymorphism. *)
@@ -72,6 +74,8 @@ val int_specialization_with_externs : specialization
or some combination of those. *)
val polymorphic_functions : specialization -> 'a defs -> IdSet.t
+val add_initial_calls : IdSet.t -> unit
+
(** specialize returns an AST with all the Order and Type polymorphism
removed, as well as the environment produced by type checking that
AST with [Type_check.initial_env]. The env parameter is the
@@ -88,6 +92,3 @@ val specialize' : int -> specialization -> tannot defs -> Env.t -> tannot defs *
val instantiations_of : specialization -> id -> tannot defs -> typ_arg KBindings.t list
val string_of_instantiation : typ_arg KBindings.t -> string
-
-(** Remove all function definitions except for the given set *)
-val slice_defs : Env.t -> tannot defs -> IdSet.t -> tannot defs
diff --git a/src/toFromInterp_backend.ml b/src/toFromInterp_backend.ml
index b253791d..d65aaf3b 100644
--- a/src/toFromInterp_backend.ml
+++ b/src/toFromInterp_backend.ml
@@ -123,7 +123,7 @@ let frominterp_typedef (TD_aux (td_aux, (l, _))) =
| Id_aux ((Id "diafps"),_) -> empty
(* | Id_aux ((Id "option"),_) -> empty *)
| Id_aux ((Id id_string), _)
- | Id_aux ((DeIid id_string), _) ->
+ | Id_aux ((Operator id_string), _) ->
if !lem_mode && id_string = "option" then empty else
let fromInterpValueName = concat [string (maybe_zencode (string_of_id id)); string "FromInterpValue"] in
let fromFallback = separate space [pipe; underscore; arrow; string "failwith";
@@ -265,7 +265,7 @@ let tointerp_typedef (TD_aux (td_aux, (l, _))) =
| Id_aux ((Id "diafps"),_) -> empty
(* | Id_aux ((Id "option"),_) -> empty *)
| Id_aux ((Id id_string), _)
- | Id_aux ((DeIid id_string), _) ->
+ | Id_aux ((Operator id_string), _) ->
if !lem_mode && id_string = "option" then empty else
let toInterpValueName = concat [string (maybe_zencode (string_of_id id)); string "ToInterpValue"] in
let toInterpValue =
diff --git a/src/type_check.ml b/src/type_check.ml
index e32ac45c..b6aabdc7 100644
--- a/src/type_check.ml
+++ b/src/type_check.ml
@@ -76,6 +76,9 @@ let opt_expand_valspec = ref true
the SMT solver to use non-linear arithmetic. *)
let opt_smt_linearize = ref false
+(* Allow use of div and mod when rewriting nexps *)
+let opt_smt_div = ref false
+
let depth = ref 0
let rec indent n = match n with
@@ -136,8 +139,8 @@ let typ_error env l m = raise (Type_error (env, l, Err_other m))
let typ_raise env l err = raise (Type_error (env, l, err))
let deinfix = function
- | Id_aux (Id v, l) -> Id_aux (DeIid v, l)
- | Id_aux (DeIid v, l) -> Id_aux (DeIid v, l)
+ | Id_aux (Id v, l) -> Id_aux (Operator v, l)
+ | Id_aux (Operator v, l) -> Id_aux (Operator v, l)
let field_name rec_id id =
match rec_id, id with
@@ -177,7 +180,7 @@ let is_atom_bool (Typ_aux (typ_aux, _)) =
let rec strip_id = function
| Id_aux (Id x, _) -> Id_aux (Id x, Parse_ast.Unknown)
- | Id_aux (DeIid x, _) -> Id_aux (DeIid x, Parse_ast.Unknown)
+ | Id_aux (Operator x, _) -> Id_aux (Operator x, Parse_ast.Unknown)
and strip_kid = function
| Kid_aux (Var x, _) -> Kid_aux (Var x, Parse_ast.Unknown)
and strip_base_effect = function
@@ -1200,6 +1203,7 @@ end = struct
{ env with
constraints = List.map (constraint_subst v (arg_kopt (mk_kopt s_k s_v))) env.constraints;
typ_vars = KBindings.add v (l, k) (KBindings.add s_v (s_l, s_k) env.typ_vars);
+ locals = Bindings.map (fun (mut, typ) -> mut, typ_subst v (arg_kopt (mk_kopt s_k s_v)) typ) env.locals;
shadow_vars = KBindings.add v (n + 1) env.shadow_vars
}, Some s_v
end
@@ -1222,7 +1226,7 @@ end = struct
else if KidSet.cardinal power_vars = 1 && !opt_smt_linearize then
let v = KidSet.choose power_vars in
let constrs = List.fold_left nc_and nc_true (get_constraints env) in
- begin match Constraint.solve_all_smt l (get_typ_vars env) constrs v with
+ begin match Constraint.solve_all_smt l constrs v with
| Some solutions ->
typ_print (lazy (Util.("Linearizing " |> red |> clear) ^ string_of_n_constraint constr
^ " for " ^ string_of_kid v ^ " in " ^ Util.string_of_list ", " Big_int.to_string solutions));
@@ -1476,10 +1480,8 @@ which is then a problem we can feed to the constraint solver expecting unsat.
*)
let prove_smt env (NC_aux (_, l) as nc) =
- let vars = Env.get_typ_vars env in
- let vars = KBindings.filter (fun _ k -> match k with K_int | K_bool -> true | _ -> false) vars in
let ncs = Env.get_constraints env in
- match Constraint.call_smt l vars (List.fold_left nc_and (nc_not nc) ncs) with
+ match Constraint.call_smt l (List.fold_left nc_and (nc_not nc) ncs) with
| Constraint.Unsat -> typ_debug (lazy "unsat"); true
| Constraint.Sat -> typ_debug (lazy "sat"); false
| Constraint.Unknown ->
@@ -1487,7 +1489,7 @@ let prove_smt env (NC_aux (_, l) as nc) =
constraints, even when such constraints are irrelevant *)
let ncs' = List.concat (List.map constraint_conj ncs) in
let ncs' = List.filter (fun nc -> KidSet.is_empty (constraint_power_variables nc)) ncs' in
- match Constraint.call_smt l vars (List.fold_left nc_and (nc_not nc) ncs') with
+ match Constraint.call_smt l (List.fold_left nc_and (nc_not nc) ncs') with
| Constraint.Unsat -> typ_debug (lazy "unsat"); true
| Constraint.Sat | Constraint.Unknown -> typ_debug (lazy "sat/unknown"); false
@@ -1501,7 +1503,7 @@ let solve_unique env (Nexp_aux (_, l) as nexp) =
let vars = Env.get_typ_vars env in
let vars = KBindings.filter (fun _ k -> match k with K_int | K_bool -> true | _ -> false) vars in
let constr = List.fold_left nc_and (nc_eq (nvar (mk_kid "solve#")) nexp) (Env.get_constraints env) in
- Constraint.solve_unique_smt l vars constr (mk_kid "solve#")
+ Constraint.solve_unique_smt l constr (mk_kid "solve#")
let debug_pos (file, line, _, _) =
"(" ^ file ^ "/" ^ string_of_int line ^ ") "
@@ -1662,7 +1664,7 @@ let rec unify_typ l env goals (Typ_aux (aux1, _) as typ1) (Typ_aux (aux2, _) as
| Typ_tup typs1, Typ_tup typs2 when List.length typs1 = List.length typs2 ->
List.fold_left (merge_uvars l) KBindings.empty (List.map2 (unify_typ l env goals) typs1 typs2)
- | _, _ -> unify_error l ("Cound not unify " ^ string_of_typ typ1 ^ " and " ^ string_of_typ typ2)
+ | _, _ -> unify_error l ("Could not unify " ^ string_of_typ typ1 ^ " and " ^ string_of_typ typ2)
and unify_typ_arg l env goals (A_aux (aux1, _) as typ_arg1) (A_aux (aux2, _) as typ_arg2) =
match aux1, aux2 with
@@ -1709,7 +1711,7 @@ and unify_order l goals (Ord_aux (aux1, _) as ord1) (Ord_aux (aux2, _) as ord2)
| Ord_var v, _ when KidSet.mem v goals -> KBindings.singleton v (arg_order ord2)
| Ord_inc, Ord_inc -> KBindings.empty
| Ord_dec, Ord_dec -> KBindings.empty
- | _, _ -> unify_error l ("Cound not unify " ^ string_of_order ord1 ^ " and " ^ string_of_order ord2)
+ | _, _ -> unify_error l ("Could not unify " ^ string_of_order ord1 ^ " and " ^ string_of_order ord2)
and unify_nexp l env goals (Nexp_aux (nexp_aux1, _) as nexp1) (Nexp_aux (nexp_aux2, _) as nexp2) =
typ_debug (lazy (Util.("Unify nexp " |> magenta |> clear) ^ string_of_nexp nexp1 ^ " and " ^ string_of_nexp nexp2
@@ -1759,9 +1761,9 @@ and unify_nexp l env goals (Nexp_aux (nexp_aux1, _) as nexp1) (Nexp_aux (nexp_au
mod(m, C) = 0 && C != 0 --> (C * n = m <--> n = m / C)
- to help us unify multiplications and divisions.
+ to help us unify multiplications and divisions. *)
let valid n c = prove __POS__ env (nc_eq (napp (mk_id "mod") [n; c]) (nint 0)) && prove __POS__ env (nc_neq c (nint 0)) in
- if KidSet.is_empty (nexp_frees n1b) && valid nexp2 n1b then
+ (*if KidSet.is_empty (nexp_frees n1b) && valid nexp2 n1b then
unify_nexp l env goals n1a (napp (mk_id "div") [nexp2; n1b])
else if KidSet.is_empty (nexp_frees n1a) && valid nexp2 n1a then
unify_nexp l env goals n1b (napp (mk_id "div") [nexp2; n1a]) *)
@@ -1777,6 +1779,8 @@ and unify_nexp l env goals (Nexp_aux (nexp_aux1, _) as nexp1) (Nexp_aux (nexp_au
unify_nexp l env goals n1b (nconstant (Big_int.div c2 c1))
| _ -> unify_error l ("Cannot unify Int expression " ^ string_of_nexp nexp1 ^ " with " ^ string_of_nexp nexp2)
end
+ | Nexp_var kid when (not (KidSet.mem kid goals)) && valid nexp2 n1a && !opt_smt_div ->
+ unify_nexp l env goals n1b (napp (mk_id "div") [nexp2; n1a])
| _ -> unify_error l ("Cannot unify Int expression " ^ string_of_nexp nexp1 ^ " with " ^ string_of_nexp nexp2)
end
else if KidSet.is_empty (nexp_frees n1b) then
@@ -1784,6 +1788,8 @@ and unify_nexp l env goals (Nexp_aux (nexp_aux1, _) as nexp1) (Nexp_aux (nexp_au
match nexp_aux2 with
| Nexp_times (n2a, n2b) when prove __POS__ env (NC_aux (NC_equal (n1b, n2b), Parse_ast.Unknown)) ->
unify_nexp l env goals n1a n2a
+ | Nexp_var kid when (not (KidSet.mem kid goals)) && valid nexp2 n1b && !opt_smt_div ->
+ unify_nexp l env goals n1a (napp (mk_id "div") [nexp2; n1b])
| _ -> unify_error l ("Cannot unify Int expression " ^ string_of_nexp nexp1 ^ " with " ^ string_of_nexp nexp2)
end
else unify_error l ("Cannot unify Int expression " ^ string_of_nexp nexp1 ^ " with " ^ string_of_nexp nexp2)
@@ -2054,7 +2060,7 @@ let rec subtyp l env typ1 typ2 =
let kids2 = KidSet.elements (KidSet.diff (KidSet.of_list kids2) (nexp_frees nexp2)) in
if not (kids2 = []) then typ_error env l ("Universally quantified constraint generated: " ^ Util.string_of_list ", " string_of_kid kids2) else ();
let vars = KBindings.filter (fun _ k -> match k with K_int | K_bool -> true | _ -> false) (Env.get_typ_vars env) in
- begin match Constraint.call_smt l vars (nc_eq nexp1 nexp2) with
+ begin match Constraint.call_smt l (nc_eq nexp1 nexp2) with
| Constraint.Sat ->
let env = Env.add_constraint (nc_eq nexp1 nexp2) env in
if prove __POS__ env nc2 then
@@ -2187,15 +2193,17 @@ let rec rewrite_sizeof' env (Nexp_aux (aux, l) as nexp) =
let exp = rewrite_sizeof' env nexp in
mk_exp (E_app (mk_id "pow2", [exp]))
+ (* SMT solver div/mod is euclidian, so we must use those versions of
+ div and mod in lib/smt.sail *)
| Nexp_app (id, [nexp1; nexp2]) when string_of_id id = "div" ->
let exp1 = rewrite_sizeof' env nexp1 in
let exp2 = rewrite_sizeof' env nexp2 in
- mk_exp (E_app (mk_id "div", [exp1; exp2]))
+ mk_exp (E_app (mk_id "ediv_int", [exp1; exp2]))
| Nexp_app (id, [nexp1; nexp2]) when string_of_id id = "mod" ->
let exp1 = rewrite_sizeof' env nexp1 in
let exp2 = rewrite_sizeof' env nexp2 in
- mk_exp (E_app (mk_id "mod", [exp1; exp2]))
+ mk_exp (E_app (mk_id "emod_int", [exp1; exp2]))
| Nexp_app _ | Nexp_id _ ->
typ_error env l ("Cannot re-write sizeof(" ^ string_of_nexp nexp ^ ")")
@@ -2626,10 +2634,10 @@ let irule r env exp =
(* This function adds useful assertion messages to asserts missing them *)
-let assert_msg test = function
+let assert_msg = function
| E_aux (E_lit (L_aux (L_string "", _)), (l, _)) ->
let open Reporting in
- locate (fun _ -> l) (mk_lit_exp (L_string (loc_to_string ~code:false l ^ ": " ^ string_of_exp test)))
+ locate (fun _ -> l) (mk_lit_exp (L_string (short_loc_to_string l)))
| msg -> msg
let strip_exp : 'a exp -> unit exp = function exp -> map_exp_annot (fun (l, _) -> (l, ())) exp
@@ -2877,7 +2885,7 @@ and check_block l env exps ret_typ =
let texp, env = bind_assignment env lexp bind in
texp :: check_block l env exps ret_typ
| ((E_aux (E_assert (constr_exp, msg), _) as exp) :: exps) ->
- let msg = assert_msg constr_exp msg in
+ let msg = assert_msg msg in
let constr_exp = crule check_exp env constr_exp bool_typ in
let checked_msg = crule check_exp env msg string_typ in
let env = match assert_constraint env true constr_exp with
@@ -3128,40 +3136,34 @@ and bind_pat env (P_aux (pat_aux, (l, ())) as pat) (Typ_aux (typ_aux, _) as typ)
typ_error env l (Printf.sprintf "Cannot bind tuple pattern %s against non tuple type %s"
(string_of_pat pat) (string_of_typ typ))
end
- | P_app (f, pats) when Env.is_union_constructor f env ->
- begin
- (* Treat Ctor((p, x)) the same as Ctor(p, x) *)
- let pats = match pats with [P_aux (P_tup pats, _)] -> pats | _ -> pats in
- let (typq, ctor_typ) = Env.get_union_id f env in
- let quants = quant_items typq in
- let untuple (Typ_aux (typ_aux, _) as typ) = match typ_aux with
- | Typ_tup typs -> typs
- | _ -> [typ]
- in
- match Env.expand_synonyms env ctor_typ with
- | Typ_aux (Typ_fn ([arg_typ], ret_typ, _), _) ->
- begin
- try
- let goals = quant_kopts typq |> List.map kopt_kid |> KidSet.of_list in
- typ_debug (lazy ("Unifying " ^ string_of_bind (typq, ctor_typ) ^ " for pattern " ^ string_of_typ typ));
- let unifiers = unify l env goals ret_typ typ in
- let arg_typ' = subst_unifiers unifiers arg_typ in
- let quants' = List.fold_left instantiate_quants quants (KBindings.bindings unifiers) in
- if not (List.for_all (solve_quant env) quants') then
- typ_raise env l (Err_unresolved_quants (f, quants', Env.get_locals env, Env.get_constraints env))
- else ();
- let ret_typ' = subst_unifiers unifiers ret_typ in
- let arg_typ', env = bind_existential l None arg_typ' env in
- let tpats, env, guards =
- try List.fold_left2 bind_tuple_pat ([], env, []) pats (untuple arg_typ') with
- | Invalid_argument _ -> typ_error env l "Union constructor pattern arguments have incorrect length"
- in
- annot_pat (P_app (f, List.rev tpats)) typ, env, guards
- with
- | Unification_error (l, m) -> typ_error env l ("Unification error when pattern matching against union constructor: " ^ m)
- end
- | _ -> typ_error env l ("Mal-formed constructor " ^ string_of_id f ^ " with type " ^ string_of_typ ctor_typ)
+ | P_app (f, [pat]) when Env.is_union_constructor f env ->
+ let (typq, ctor_typ) = Env.get_union_id f env in
+ let quants = quant_items typq in
+ begin match Env.expand_synonyms env ctor_typ with
+ | Typ_aux (Typ_fn ([arg_typ], ret_typ, _), _) ->
+ begin
+ try
+ let goals = quant_kopts typq |> List.map kopt_kid |> KidSet.of_list in
+ typ_debug (lazy ("Unifying " ^ string_of_bind (typq, ctor_typ) ^ " for pattern " ^ string_of_typ typ));
+ let unifiers = unify l env goals ret_typ typ in
+ let arg_typ' = subst_unifiers unifiers arg_typ in
+ let quants' = List.fold_left instantiate_quants quants (KBindings.bindings unifiers) in
+ if not (List.for_all (solve_quant env) quants') then
+ typ_raise env l (Err_unresolved_quants (f, quants', Env.get_locals env, Env.get_constraints env))
+ else ();
+ let ret_typ' = subst_unifiers unifiers ret_typ in
+ let arg_typ', env = bind_existential l None arg_typ' env in
+ let tpat, env, guards = bind_pat env pat arg_typ' in
+ annot_pat (P_app (f, [tpat])) typ, env, guards
+ with
+ | Unification_error (l, m) -> typ_error env l ("Unification error when pattern matching against union constructor: " ^ m)
+ end
+ | _ -> typ_error env l ("Mal-formed constructor " ^ string_of_id f ^ " with type " ^ string_of_typ ctor_typ)
end
+
+ | P_app (f, pats) when Env.is_union_constructor f env ->
+ (* Treat Ctor(x, y) as Ctor((x, y)) *)
+ bind_pat env (mk_pat (P_app (f, [mk_pat (P_tup pats)]))) typ
| P_app (f, pats) when Env.is_mapping f env ->
begin
@@ -3750,7 +3752,7 @@ and infer_exp env (E_aux (exp_aux, (l, ())) as exp) =
let vec_typ = dvector_typ env (nint (List.length vec)) (typ_of inferred_item) in
annot_exp (E_vector (inferred_item :: checked_items)) vec_typ
| E_assert (test, msg) ->
- let msg = assert_msg test msg in
+ let msg = assert_msg msg in
let checked_test = crule check_exp env test bool_typ in
let checked_msg = crule check_exp env msg string_typ in
annot_exp_effect (E_assert (checked_test, checked_msg)) unit_typ (mk_effect [BE_escape])
diff --git a/src/type_check.mli b/src/type_check.mli
index 1712be58..2a413238 100644
--- a/src/type_check.mli
+++ b/src/type_check.mli
@@ -77,6 +77,9 @@ val opt_expand_valspec : bool ref
the SMT solver to use non-linear arithmetic. *)
val opt_smt_linearize : bool ref
+(** Allow use of div and mod when rewriting nexps *)
+val opt_smt_div : bool ref
+
(** {2 Type errors} *)
type type_error =
diff --git a/src/value.ml b/src/value.ml
index 843a943b..279d3aba 100644
--- a/src/value.ml
+++ b/src/value.ml
@@ -93,6 +93,33 @@ type value =
with a direct register read. *)
| V_attempted_read of string
+let coerce_bit = function
+ | V_bit b -> b
+ | _ -> assert false
+
+let is_bit = function
+ | V_bit _ -> true
+ | _ -> false
+
+let rec string_of_value = function
+ | V_vector vs when List.for_all is_bit vs -> Sail_lib.string_of_bits (List.map coerce_bit vs)
+ | V_vector vs -> "[" ^ Util.string_of_list ", " string_of_value vs ^ "]"
+ | V_bool true -> "true"
+ | V_bool false -> "false"
+ | V_bit Sail_lib.B0 -> "bitzero"
+ | V_bit Sail_lib.B1 -> "bitone"
+ | V_int n -> Big_int.to_string n
+ | V_tuple vals -> "(" ^ Util.string_of_list ", " string_of_value vals ^ ")"
+ | V_list vals -> "[|" ^ Util.string_of_list ", " string_of_value vals ^ "|]"
+ | V_unit -> "()"
+ | V_string str -> "\"" ^ str ^ "\""
+ | V_ref str -> "ref " ^ str
+ | V_real r -> Sail_lib.string_of_real r
+ | V_ctor (str, vals) -> str ^ "(" ^ Util.string_of_list ", " string_of_value vals ^ ")"
+ | V_record record ->
+ "{" ^ Util.string_of_list ", " (fun (field, v) -> field ^ "=" ^ string_of_value v) (StringMap.bindings record) ^ "}"
+ | V_attempted_read _ -> assert false
+
let rec eq_value v1 v2 =
match v1, v2 with
| V_vector v1s, V_vector v2s when List.length v1s = List.length v2s -> List.for_all2 eq_value v1s v2s
@@ -111,12 +138,7 @@ let rec eq_value v1 v2 =
StringMap.equal eq_value fields1 fields2
| _, _ -> false
-let coerce_bit = function
- | V_bit b -> b
- | _ -> assert false
-
let coerce_ctor = function
- | V_ctor (str, [V_tuple vals]) -> (str, vals)
| V_ctor (str, vals) -> (str, vals)
| _ -> assert false
@@ -383,33 +405,10 @@ let value_replicate_bits = function
| [v1; v2] -> mk_vector (Sail_lib.replicate_bits (coerce_bv v1, coerce_int v2))
| _ -> failwith "value replicate_bits"
-let is_bit = function
- | V_bit _ -> true
- | _ -> false
-
let is_ctor = function
| V_ctor _ -> true
| _ -> false
-let rec string_of_value = function
- | V_vector vs when List.for_all is_bit vs -> Sail_lib.string_of_bits (List.map coerce_bit vs)
- | V_vector vs -> "[" ^ Util.string_of_list ", " string_of_value vs ^ "]"
- | V_bool true -> "true"
- | V_bool false -> "false"
- | V_bit Sail_lib.B0 -> "bitzero"
- | V_bit Sail_lib.B1 -> "bitone"
- | V_int n -> Big_int.to_string n
- | V_tuple vals -> "(" ^ Util.string_of_list ", " string_of_value vals ^ ")"
- | V_list vals -> "[|" ^ Util.string_of_list ", " string_of_value vals ^ "|]"
- | V_unit -> "()"
- | V_string str -> "\"" ^ str ^ "\""
- | V_ref str -> "ref " ^ str
- | V_real r -> Sail_lib.string_of_real r
- | V_ctor (str, vals) -> str ^ "(" ^ Util.string_of_list ", " string_of_value vals ^ ")"
- | V_record record ->
- "{" ^ Util.string_of_list ", " (fun (field, v) -> field ^ "=" ^ string_of_value v) (StringMap.bindings record) ^ "}"
- | V_attempted_read _ -> assert false
-
let value_sign_extend = function
| [v1; v2] -> mk_vector (Sail_lib.sign_extend (coerce_bv v1, coerce_int v2))
| _ -> failwith "value sign_extend"
diff --git a/src/value2.lem b/src/value2.lem
index e8a8262a..caf355b7 100644
--- a/src/value2.lem
+++ b/src/value2.lem
@@ -49,34 +49,14 @@
(*========================================================================*)
open import Pervasives
-open import Assert_extra
-open Map
open import Sail2_values
type vl =
- | V_vector of list vl
- | V_list of list vl
| V_bits of list bitU
| V_bit of bitU
- | V_tuple of list vl
| V_bool of bool
- | V_nondet (* Special nondeterministic boolean *)
| V_unit
| V_int of integer
| V_string of string
- | V_ctor of string * list vl
- | V_ctor_kind of string
- | V_record of list (string * vl)
| V_null (* Used for unitialized values and null pointers in C compilation *)
-
-
-let value_int_op_int op = function
- | [V_int v1; V_int v2] -> V_int (op v1 v2)
- | _ -> V_null
-end
-
-let value_bool_op_int op = function
- | [V_int v1; V_int v2] -> V_bool (op v1 v2)
- | _ -> V_null
-end
diff --git a/test/arm/run_tests.sh b/test/arm/run_tests.sh
index b24cc584..9d7af14f 100755
--- a/test/arm/run_tests.sh
+++ b/test/arm/run_tests.sh
@@ -83,7 +83,7 @@ printf "\nLoading specification into interpreter...\n"
cd $SAILDIR/aarch64
-if $SAILDIR/sail -no_lexp_bounds_check -is $DIR/test.isail no_vector.sail 1> /dev/null 2> /dev/null;
+if $SAILDIR/sail -undefined_gen -no_lexp_bounds_check -is $DIR/test.isail no_vector.sail 1> /dev/null 2> /dev/null;
then
green "loaded no_vector specification" "ok";
diff --git a/test/arm/test.isail b/test/arm/test.isail
index 8775ed8f..f3f4dfa1 100644
--- a/test/arm/test.isail
+++ b/test/arm/test.isail
@@ -1,3 +1,4 @@
+:rewrites interpreter
:elf ../test/arm/test_O2.elf
:output ../test/arm/iresult
initialize_registers()
diff --git a/test/builtins/div_int.sail b/test/builtins/div_int.sail
index fed6de6e..e8da4f4b 100644
--- a/test/builtins/div_int.sail
+++ b/test/builtins/div_int.sail
@@ -5,6 +5,8 @@ $include <flow.sail>
$include <vector_dec.sail>
$include <arith.sail>
+overload div_int = {tdiv_int}
+
function main (() : unit) -> unit = {
assert(div_int(48240160, 8) == 6030020);
assert(div_int(48240168, 8) == 6030021);
diff --git a/test/builtins/div_int2.sail b/test/builtins/div_int2.sail
index d3df278d..8ce97cc0 100644
--- a/test/builtins/div_int2.sail
+++ b/test/builtins/div_int2.sail
@@ -5,6 +5,8 @@ $include <flow.sail>
$include <vector_dec.sail>
$include <arith.sail>
+overload div_int = {tdiv_int}
+
function main (() : unit) -> unit = {
assert(div_int(0, 8) == 0);
assert(div_int(1000, 12) == 83);
diff --git a/test/builtins/divmod.sail b/test/builtins/divmod.sail
new file mode 100644
index 00000000..f9d7e7c5
--- /dev/null
+++ b/test/builtins/divmod.sail
@@ -0,0 +1,43 @@
+default Order dec
+
+$include <exception_basic.sail>
+$include <arith.sail>
+$include <smt.sail>
+
+function main (() : unit) -> unit = {
+ assert(ediv_int( 7 , 5) == 1);
+ assert(ediv_int( 7 , -5) == -1);
+ assert(ediv_int(-7 , 5) == -2);
+ assert(ediv_int(-7 , -5) == 2);
+ assert(ediv_int( 12 , 3) == 4);
+ assert(ediv_int( 12 , -3) == -4);
+ assert(ediv_int(-12 , 3) == -4);
+ assert(ediv_int(-12 , -3) == 4);
+
+ assert(emod_int( 7 , 5) == 2);
+ assert(emod_int( 7 , -5) == 2);
+ assert(emod_int(-7 , 5) == 3);
+ assert(emod_int(-7 , -5) == 3);
+ assert(emod_int( 12 , 3) == 0);
+ assert(emod_int( 12 , -3) == 0);
+ assert(emod_int(-12 , 3) == 0);
+ assert(emod_int(-12 , -3) == 0);
+
+ assert(tdiv_int( 7 , 5) == 1);
+ assert(tdiv_int( 7 , -5) == -1);
+ assert(tdiv_int(-7 , 5) == -1);
+ assert(tdiv_int(-7 , -5) == 1);
+ assert(tdiv_int( 12 , 3) == 4);
+ assert(tdiv_int( 12 , -3) == -4);
+ assert(tdiv_int(-12 , 3) == -4);
+ assert(tdiv_int(-12 , -3) == 4);
+
+ assert(tmod_int( 7 , 5) == 2);
+ assert(tmod_int( 7 , -5) == 2);
+ assert(tmod_int(-7 , 5) == -2);
+ assert(tmod_int(-7 , -5) == -2);
+ assert(tmod_int( 12 , 3) == 0);
+ assert(tmod_int( 12 , -3) == 0);
+ assert(tmod_int(-12 , 3) == 0);
+ assert(tmod_int(-12 , -3) == 0);
+} \ No newline at end of file
diff --git a/test/c/anf_as_pattern.expect b/test/c/anf_as_pattern.expect
new file mode 100644
index 00000000..9766475a
--- /dev/null
+++ b/test/c/anf_as_pattern.expect
@@ -0,0 +1 @@
+ok
diff --git a/test/c/anf_as_pattern.sail b/test/c/anf_as_pattern.sail
new file mode 100644
index 00000000..9b9196b1
--- /dev/null
+++ b/test/c/anf_as_pattern.sail
@@ -0,0 +1,19 @@
+default Order dec
+
+$include <prelude.sail>
+
+val "print_endline" : string -> unit
+
+function test () : unit -> option(int) = {
+ match Some(3) {
+ Some(_) as x => x,
+ _ => None()
+ }
+}
+
+function main() : unit -> unit = {
+ match test() {
+ Some(3) => print_endline("ok"),
+ _ => print_endline("fail")
+ }
+} \ No newline at end of file
diff --git a/test/c/anon_rec.expect b/test/c/anon_rec.expect
new file mode 100644
index 00000000..9766475a
--- /dev/null
+++ b/test/c/anon_rec.expect
@@ -0,0 +1 @@
+ok
diff --git a/test/c/anon_rec.sail b/test/c/anon_rec.sail
new file mode 100644
index 00000000..17dd1e07
--- /dev/null
+++ b/test/c/anon_rec.sail
@@ -0,0 +1,12 @@
+default Order dec
+
+union Foo ('a : Type) = {
+ MkFoo : { field1 : 'a, field2 : int }
+}
+
+val "print_endline" : string -> unit
+
+function main((): unit) -> unit = {
+ let _: Foo(unit) = MkFoo(struct { field1 = (), field2 = 22 });
+ print_endline("ok")
+}
diff --git a/test/c/execute.isail b/test/c/execute.isail
index f4b5ea0f..018dd92c 100644
--- a/test/c/execute.isail
+++ b/test/c/execute.isail
@@ -1,3 +1,4 @@
+:rewrites interpreter
initialize_registers()
:run
main()
diff --git a/test/c/flow_restrict.expect b/test/c/flow_restrict.expect
new file mode 100644
index 00000000..9766475a
--- /dev/null
+++ b/test/c/flow_restrict.expect
@@ -0,0 +1 @@
+ok
diff --git a/test/c/flow_restrict.sail b/test/c/flow_restrict.sail
new file mode 100644
index 00000000..ef2ec412
--- /dev/null
+++ b/test/c/flow_restrict.sail
@@ -0,0 +1,23 @@
+default Order dec
+
+$include <flow.sail>
+$include <exception_basic.sail>
+
+val "print_endline" : string -> unit
+
+register R : bool
+
+function main((): unit) -> unit = {
+ R = false;
+ let 'x = 3180327502475943573495720457203572045720485720458724;
+ y : range(0, 'x) = 1;
+ if R then {
+ assert(constraint('x <= 2));
+ y = 2;
+ let z = y;
+ let x = 2;
+ ()
+ } else {
+ print_endline("ok")
+ }
+}
diff --git a/test/c/poly_int_record.expect b/test/c/poly_int_record.expect
new file mode 100644
index 00000000..a8a10253
--- /dev/null
+++ b/test/c/poly_int_record.expect
@@ -0,0 +1,3 @@
+x = 1
+y = 2
+ok
diff --git a/test/c/poly_int_record.sail b/test/c/poly_int_record.sail
new file mode 100644
index 00000000..ebb18713
--- /dev/null
+++ b/test/c/poly_int_record.sail
@@ -0,0 +1,21 @@
+default Order dec
+
+val "print_endline" : string -> unit
+val "print_int" : (string, int) -> unit
+
+struct S('a: Type) = {
+ field1 : ('a, 'a),
+ field2 : unit
+}
+
+function main((): unit) -> unit = {
+ var s : S(range(0, 3)) = struct { field1 = (0, 3), field2 = () };
+ s.field1 = (1, 2);
+ match s.field1 {
+ (x, y) => {
+ print_int("x = ", x);
+ print_int("y = ", y);
+ }
+ };
+ print_endline("ok");
+}
diff --git a/test/c/poly_record.expect b/test/c/poly_record.expect
new file mode 100644
index 00000000..9766475a
--- /dev/null
+++ b/test/c/poly_record.expect
@@ -0,0 +1 @@
+ok
diff --git a/test/c/poly_record.sail b/test/c/poly_record.sail
new file mode 100644
index 00000000..afe1f144
--- /dev/null
+++ b/test/c/poly_record.sail
@@ -0,0 +1,18 @@
+default Order dec
+
+val "print_endline" : string -> unit
+
+struct S('a: Type) = {
+ field1 : 'a,
+ field2 : unit
+}
+
+function f forall ('a :Type). (s: S('a)) -> unit = {
+ s.field2
+}
+
+function main((): unit) -> unit = {
+ let s : S(unit) = struct { field1 = (), field2 = () };
+ f(s);
+ print_endline("ok");
+}
diff --git a/test/c/run_tests.py b/test/c/run_tests.py
index 2ee44fca..be953749 100755
--- a/test/c/run_tests.py
+++ b/test/c/run_tests.py
@@ -40,7 +40,7 @@ def test_interpreter(name):
basename = os.path.splitext(os.path.basename(filename))[0]
tests[filename] = os.fork()
if tests[filename] == 0:
- step('sail -is execute.isail -iout {}.iresult {}'.format(basename, filename))
+ step('sail -undefined_gen -is execute.isail -iout {}.iresult {}'.format(basename, filename))
step('diff {}.iresult {}.expect'.format(basename, basename))
print '{} {}{}{}'.format(filename, color.PASS, 'ok', color.END)
sys.exit()
diff --git a/test/c/tuple_union.expect b/test/c/tuple_union.expect
new file mode 100644
index 00000000..d8ea9f4f
--- /dev/null
+++ b/test/c/tuple_union.expect
@@ -0,0 +1,42 @@
+y = 1
+z = 2
+y = 1
+z = 2
+y = 1
+z = 2
+
+y = 3
+z = 4
+y = 3
+z = 4
+y = 3
+z = 4
+
+y = 5
+z = 6
+y = 5
+z = 6
+y = 5
+z = 6
+
+y = 7
+z = 8
+y = 7
+z = 8
+y = 7
+z = 8
+
+y = 9
+z = 10
+y = 9
+z = 10
+y = 9
+z = 10
+
+y = 11
+z = 12
+y = 11
+z = 12
+y = 11
+z = 12
+
diff --git a/test/c/tuple_union.sail b/test/c/tuple_union.sail
new file mode 100644
index 00000000..1914038f
--- /dev/null
+++ b/test/c/tuple_union.sail
@@ -0,0 +1,48 @@
+default Order dec
+
+$include <prelude.sail>
+
+val "print_endline" : string -> unit
+
+union U('a: Type) = {
+ Ctor : 'a
+}
+
+type pair = (int, int)
+
+function foo(x: U(pair)) -> unit = {
+ match x {
+ Ctor(y, z) => {
+ print_int("y = ", y);
+ print_int("z = ", z)
+ }
+ };
+ match x {
+ Ctor((y, z)) => {
+ print_int("y = ", y);
+ print_int("z = ", z)
+ }
+ };
+ match x {
+ Ctor(x) => match x {
+ (y, z) => {
+ print_int("y = ", y);
+ print_int("z = ", z)
+ }
+ }
+ };
+ print_endline("")
+}
+
+function main((): unit) -> unit = {
+ foo(Ctor(1, 2));
+ foo(Ctor((3, 4)));
+ let x = (5, 6);
+ foo(Ctor(x));
+ let x = Ctor(7, 8);
+ foo(x);
+ let x = Ctor(((9, 10)));
+ foo(x);
+ let x = (11, 12);
+ foo(Ctor(x));
+}
diff --git a/test/c/unused_poly_ctor.expect b/test/c/unused_poly_ctor.expect
new file mode 100644
index 00000000..e55551e8
--- /dev/null
+++ b/test/c/unused_poly_ctor.expect
@@ -0,0 +1 @@
+y = 0xFFFF
diff --git a/test/c/unused_poly_ctor.sail b/test/c/unused_poly_ctor.sail
new file mode 100644
index 00000000..c752cb33
--- /dev/null
+++ b/test/c/unused_poly_ctor.sail
@@ -0,0 +1,18 @@
+default Order dec
+
+$include <prelude.sail>
+
+val "print_endline" : string -> unit
+
+union U('a: Type) = {
+ Err : 'a,
+ Ok : bits(16)
+}
+
+function main((): unit) -> unit = {
+ let x : U(unit) = Ok(0xFFFF);
+ match x {
+ Err() => print_endline("error"),
+ Ok(y) => print_bits("y = ", y)
+ }
+}
diff --git a/test/coq/_CoqProject b/test/coq/_CoqProject
new file mode 100644
index 00000000..a694372c
--- /dev/null
+++ b/test/coq/_CoqProject
@@ -0,0 +1,2 @@
+-R ../../../bbv/theories bbv
+-R ../../lib/coq/ Sail
diff --git a/test/coq/pass/foreach_using_tyvar.sail b/test/coq/pass/foreach_using_tyvar.sail
new file mode 100644
index 00000000..8aabe00c
--- /dev/null
+++ b/test/coq/pass/foreach_using_tyvar.sail
@@ -0,0 +1,11 @@
+$include <arith.sail>
+
+val f : forall 'n, 'n != 5. int('n) -> unit
+
+val magic : forall 'n. int('n) -> bool effect {rreg}
+
+val g : int -> unit effect {rreg}
+
+function g(x) =
+ foreach (n from 0 to x)
+ if n != 5 & magic(n) then f(n)
diff --git a/test/coq/pass/rebind.sail b/test/coq/pass/rebind.sail
new file mode 100644
index 00000000..247c1d6d
--- /dev/null
+++ b/test/coq/pass/rebind.sail
@@ -0,0 +1,10 @@
+default Order dec
+
+$include <prelude.sail>
+
+val foo : forall 'n, 'n >= 0. (int('n),bits('n)) -> bits(5 + 'n)
+
+function foo(n,x) = {
+ let (n as 'm) = 5 in
+ (append((x : bits('n)),sail_ones(n)) : bits('m + 'n))
+}
diff --git a/test/coq/pass/unbound_ex_tyvars.sail b/test/coq/pass/unbound_ex_tyvars.sail
new file mode 100644
index 00000000..f99b1bd1
--- /dev/null
+++ b/test/coq/pass/unbound_ex_tyvars.sail
@@ -0,0 +1,16 @@
+$include <prelude.sail>
+
+/* We currently produce a rich type for the guard of the if that's
+ visible in the Coq output. The raw Sail type involves unbound type
+ variables that were existentially bound in x, so in order to print
+ out a useful Coq type we now rewrite it in terms of x. */
+
+val isA : unit -> bool effect {rreg}
+val isB : unit -> bool effect {rreg}
+val isC : unit -> bool effect {rreg}
+val foo : bool -> bool effect {rreg}
+
+function foo(b) = {
+ let x = (b | isA()) & isB();
+ if x | isC() then true else false
+}
diff --git a/test/coq/pass/unpacking.sail b/test/coq/pass/unpacking.sail
new file mode 100644
index 00000000..d0143f40
--- /dev/null
+++ b/test/coq/pass/unpacking.sail
@@ -0,0 +1,16 @@
+default Order dec
+
+$include <prelude.sail>
+
+val f : int -> {'n, 'n >= 0. int('n)} effect {rreg}
+val g : int -> {'n, 'n >= 0. int('n)}
+
+val test : unit -> int effect {rreg}
+
+function test() = {
+ let x1 : {'n, 'n >= 0. int('n)} = f(5);
+ let x2 : int = f(6);
+ let y1 : {'n, 'n >= 0. int('n)} = g(7);
+ let y2 : int = g(8);
+ x1 + x2 + y1 + y2
+}
diff --git a/test/coq/skip b/test/coq/skip
index e0096643..49744fce 100644
--- a/test/coq/skip
+++ b/test/coq/skip
@@ -5,11 +5,40 @@ option_tuple.sail
pat_completeness.sail
XXXXX tests which need inline extern definitions adjusted
patternrefinement.sail
-procstate1.sail
vector_subrange_gen.sail
XXXXX currently unsupported use of a bitvector in a parametric vector type
pure_record.sail
pure_record2.sail
pure_record3.sail
vector_access_dec.sail
-vector_access.sail \ No newline at end of file
+vector_access.sail
+XXXXX unsupported existential quantification of a vector length
+bind_typ_var.sail
+XXXXX needs impliciation in constraints fixed
+bool_constraint.sail
+XXXXX needs some smart existential instantiation
+complex_exist_sat.sail
+XXXXX needs name collision avoidance due to type/constructor punning
+constraint_ctor.sail
+XXXXX Complex existential type - probably going to need this for ARM instruction ASTs
+execute_decode_hard.sail
+existential_ast.sail
+existential_ast2.sail
+existential_ast3.sail
+XXXXX Needs an existential witness
+exist1.sail
+exist2.sail
+XXXXX Needs a type synonym expanded - awkward because we don't attach environments everywhere
+exist_synonym.sail
+reg_32_64.sail
+XXXXX Examples where int(...) should be expanded internally, but not yet supported
+exit1.sail
+exit2.sail
+inline_typ.sail
+XXXXX Examples with exponentials that the solver can't handle
+pow_32_64.sail
+XXXXX Register constructor doesn't use expanded type from type checker - need environment for type definition to fix this easily
+reg_mod.sail
+reg_ref.sail
+XXXXX Dodgy division/modulo stuff
+Replicate.sail
diff --git a/test/mono/exint.sail b/test/mono/exint.sail
index 639e7d45..855b689c 100644
--- a/test/mono/exint.sail
+++ b/test/mono/exint.sail
@@ -39,7 +39,7 @@ function test(x) = {
0b00 => n = 1,
0b01 => n = 2,
0b10 => n = 4,
- 0b11 => ()
+ 0b11 => n = 8
};
let 'n2 = ex_int(n) in {
assert(constraint('n2 >= 0));
@@ -54,4 +54,4 @@ function run () = {
test(0b01);
test(0b10);
test(0b11);
-} \ No newline at end of file
+}
diff --git a/test/ocaml/bitfield/test.isail b/test/ocaml/bitfield/test.isail
index b3eb5d41..e5926ff5 100644
--- a/test/ocaml/bitfield/test.isail
+++ b/test/ocaml/bitfield/test.isail
@@ -1,3 +1,4 @@
+:rewrites interpreter
initialize_registers()
:run
:output result
diff --git a/test/ocaml/hello_world/test.isail b/test/ocaml/hello_world/test.isail
index b3eb5d41..e5926ff5 100644
--- a/test/ocaml/hello_world/test.isail
+++ b/test/ocaml/hello_world/test.isail
@@ -1,3 +1,4 @@
+:rewrites interpreter
initialize_registers()
:run
:output result
diff --git a/test/ocaml/loop/test.isail b/test/ocaml/loop/test.isail
index 6a9595e3..009d3eab 100644
--- a/test/ocaml/loop/test.isail
+++ b/test/ocaml/loop/test.isail
@@ -1,3 +1,4 @@
+:rewrites interpreter
:output result
main()
:run
diff --git a/test/ocaml/lsl/test.isail b/test/ocaml/lsl/test.isail
index 6a9595e3..009d3eab 100644
--- a/test/ocaml/lsl/test.isail
+++ b/test/ocaml/lsl/test.isail
@@ -1,3 +1,4 @@
+:rewrites interpreter
:output result
main()
:run
diff --git a/test/ocaml/pattern1/test.isail b/test/ocaml/pattern1/test.isail
index 6a9595e3..009d3eab 100644
--- a/test/ocaml/pattern1/test.isail
+++ b/test/ocaml/pattern1/test.isail
@@ -1,3 +1,4 @@
+:rewrites interpreter
:output result
main()
:run
diff --git a/test/ocaml/reg_alias/test.isail b/test/ocaml/reg_alias/test.isail
index b3eb5d41..e5926ff5 100644
--- a/test/ocaml/reg_alias/test.isail
+++ b/test/ocaml/reg_alias/test.isail
@@ -1,3 +1,4 @@
+:rewrites interpreter
initialize_registers()
:run
:output result
diff --git a/test/ocaml/reg_passing/test.isail b/test/ocaml/reg_passing/test.isail
index b3eb5d41..e5926ff5 100644
--- a/test/ocaml/reg_passing/test.isail
+++ b/test/ocaml/reg_passing/test.isail
@@ -1,3 +1,4 @@
+:rewrites interpreter
initialize_registers()
:run
:output result
diff --git a/test/ocaml/reg_ref/test.isail b/test/ocaml/reg_ref/test.isail
index b3eb5d41..e5926ff5 100644
--- a/test/ocaml/reg_ref/test.isail
+++ b/test/ocaml/reg_ref/test.isail
@@ -1,3 +1,4 @@
+:rewrites interpreter
initialize_registers()
:run
:output result
diff --git a/test/ocaml/run_tests.sh b/test/ocaml/run_tests.sh
index c160ef9f..d077cd80 100755
--- a/test/ocaml/run_tests.sh
+++ b/test/ocaml/run_tests.sh
@@ -96,7 +96,7 @@ cd $DIR
for i in `ls -d */`;
do
cd $DIR/$i;
- if $SAILDIR/sail -no_warn -is test.isail ../prelude.sail `ls *.sail` 1> /dev/null;
+ if $SAILDIR/sail -no_warn -undefined_gen -is test.isail ../prelude.sail `ls *.sail` 1> /dev/null;
then
if diff expect result;
then
diff --git a/test/ocaml/short_circuit/test.isail b/test/ocaml/short_circuit/test.isail
index b3eb5d41..e5926ff5 100644
--- a/test/ocaml/short_circuit/test.isail
+++ b/test/ocaml/short_circuit/test.isail
@@ -1,3 +1,4 @@
+:rewrites interpreter
initialize_registers()
:run
:output result
diff --git a/test/ocaml/string_equality/test.isail b/test/ocaml/string_equality/test.isail
index b3eb5d41..e5926ff5 100644
--- a/test/ocaml/string_equality/test.isail
+++ b/test/ocaml/string_equality/test.isail
@@ -1,3 +1,4 @@
+:rewrites interpreter
initialize_registers()
:run
:output result
diff --git a/test/ocaml/string_of_struct/test.isail b/test/ocaml/string_of_struct/test.isail
index 6a9595e3..009d3eab 100644
--- a/test/ocaml/string_of_struct/test.isail
+++ b/test/ocaml/string_of_struct/test.isail
@@ -1,3 +1,4 @@
+:rewrites interpreter
:output result
main()
:run
diff --git a/test/ocaml/trycatch/test.isail b/test/ocaml/trycatch/test.isail
index b3eb5d41..e5926ff5 100644
--- a/test/ocaml/trycatch/test.isail
+++ b/test/ocaml/trycatch/test.isail
@@ -1,3 +1,4 @@
+:rewrites interpreter
initialize_registers()
:run
:output result
diff --git a/test/ocaml/types/test.isail b/test/ocaml/types/test.isail
index b3eb5d41..e5926ff5 100644
--- a/test/ocaml/types/test.isail
+++ b/test/ocaml/types/test.isail
@@ -1,3 +1,4 @@
+:rewrites interpreter
initialize_registers()
:run
:output result
diff --git a/test/ocaml/vec_32_64/test.isail b/test/ocaml/vec_32_64/test.isail
index b3eb5d41..e5926ff5 100644
--- a/test/ocaml/vec_32_64/test.isail
+++ b/test/ocaml/vec_32_64/test.isail
@@ -1,3 +1,4 @@
+:rewrites interpreter
initialize_registers()
:run
:output result
diff --git a/test/ocaml/void/test.isail b/test/ocaml/void/test.isail
index b3eb5d41..e5926ff5 100644
--- a/test/ocaml/void/test.isail
+++ b/test/ocaml/void/test.isail
@@ -1,3 +1,4 @@
+:rewrites interpreter
initialize_registers()
:run
:output result
diff --git a/test/typecheck/pass/Replicate.sail b/test/typecheck/pass/Replicate.sail
index 03954a9f..291b7e16 100644
--- a/test/typecheck/pass/Replicate.sail
+++ b/test/typecheck/pass/Replicate.sail
@@ -3,6 +3,9 @@ default Order dec
$include <smt.sail>
$include <prelude.sail>
+overload operator / = {ediv_int}
+overload operator % = {emod_int}
+
val Replicate : forall ('M : Int) ('N : Int), 'M >= 1.
(implicit('N), bits('M)) -> bits('N) effect {escape}
diff --git a/test/typecheck/pass/Replicate/v1.expect b/test/typecheck/pass/Replicate/v1.expect
index 92c6d7cd..c40aa5ec 100644
--- a/test/typecheck/pass/Replicate/v1.expect
+++ b/test/typecheck/pass/Replicate/v1.expect
@@ -1,8 +1,8 @@
Type error:
-[Replicate/v1.sail]:11:4-30
-11 | replicate_bits(x, 'N / 'M)
+[Replicate/v1.sail]:14:4-30
+14 | replicate_bits(x, 'N / 'M)
 | ^------------------------^
-  | Tried performing type coercion from vector(('M * div('N, 'M)), dec, bit) to vector('N, dec, bit) on replicate_bits(x, div(__id(N), bitvector_length(x)))
+  | Tried performing type coercion from vector(('M * div('N, 'M)), dec, bit) to vector('N, dec, bit) on replicate_bits(x, ediv_int(__id(N), bitvector_length(x)))
 | Coercion failed because:
 | Mismatched argument types in subtype check
 |
diff --git a/test/typecheck/pass/Replicate/v1.sail b/test/typecheck/pass/Replicate/v1.sail
index 69f2bb6f..55627db5 100644
--- a/test/typecheck/pass/Replicate/v1.sail
+++ b/test/typecheck/pass/Replicate/v1.sail
@@ -3,6 +3,9 @@ default Order dec
$include <smt.sail>
$include <prelude.sail>
+overload operator / = {ediv_int}
+overload operator % = {emod_int}
+
val Replicate : forall ('M : Int) ('N : Int), 'M >= 0.
(implicit('N), bits('M)) -> bits('N) effect {escape}
diff --git a/test/typecheck/pass/Replicate/v2.expect b/test/typecheck/pass/Replicate/v2.expect
index 62992f2c..c2c15c12 100644
--- a/test/typecheck/pass/Replicate/v2.expect
+++ b/test/typecheck/pass/Replicate/v2.expect
@@ -1,8 +1,8 @@
Type error:
-[Replicate/v2.sail]:10:4-30
-10 | replicate_bits(x, 'N / 'M)
+[Replicate/v2.sail]:13:4-30
+13 | replicate_bits(x, 'N / 'M)
 | ^------------------------^
-  | Tried performing type coercion from {('ex80# : Int), true. vector(('M * 'ex80#), dec, bit)} to vector('N, dec, bit) on replicate_bits(x, div_int(__id(N), bitvector_length(x)))
+  | Tried performing type coercion from {('ex118# : Int), true. vector(('M * 'ex118#), dec, bit)} to vector('N, dec, bit) on replicate_bits(x, tdiv_int(__id(N), bitvector_length(x)))
 | Coercion failed because:
 | Mismatched argument types in subtype check
 |
diff --git a/test/typecheck/pass/Replicate/v2.sail b/test/typecheck/pass/Replicate/v2.sail
index e54b0af4..436ef24b 100644
--- a/test/typecheck/pass/Replicate/v2.sail
+++ b/test/typecheck/pass/Replicate/v2.sail
@@ -2,6 +2,9 @@ default Order dec
$include <prelude.sail>
+overload operator / = {tdiv_int}
+overload operator % = {tmod_int}
+
val Replicate : forall ('M : Int) ('N : Int), 'M >= 1.
(implicit('N), bits('M)) -> bits('N) effect {escape}
diff --git a/test/typecheck/pass/anon_rec.sail b/test/typecheck/pass/anon_rec.sail
new file mode 100644
index 00000000..17dd1e07
--- /dev/null
+++ b/test/typecheck/pass/anon_rec.sail
@@ -0,0 +1,12 @@
+default Order dec
+
+union Foo ('a : Type) = {
+ MkFoo : { field1 : 'a, field2 : int }
+}
+
+val "print_endline" : string -> unit
+
+function main((): unit) -> unit = {
+ let _: Foo(unit) = MkFoo(struct { field1 = (), field2 = 22 });
+ print_endline("ok")
+}
diff --git a/test/typecheck/pass/existential_ast/v3.expect b/test/typecheck/pass/existential_ast/v3.expect
index af2cf65f..7bbd59ad 100644
--- a/test/typecheck/pass/existential_ast/v3.expect
+++ b/test/typecheck/pass/existential_ast/v3.expect
@@ -3,5 +3,5 @@ Type error:
26 | Some(Ctor1(a, x, c))
 | ^------------^
 | Could not resolve quantifiers for Ctor1
-  | * datasize('ex157#)
+  | * datasize('ex195#)
 |
diff --git a/test/typecheck/pass/existential_ast3/v1.expect b/test/typecheck/pass/existential_ast3/v1.expect
index e904aa61..24b927a5 100644
--- a/test/typecheck/pass/existential_ast3/v1.expect
+++ b/test/typecheck/pass/existential_ast3/v1.expect
@@ -4,17 +4,17 @@ Type error:
 | ^---------------^
 | Tried performing type coercion from (int(33), range(0, (2 ^ 5 - 1))) to {('d : Int) ('n : Int), (datasize('d) & (0 <= 'n & ('n + 1) <= 'd)). (int('d), int('n))} on (33, unsigned(a))
 | Coercion failed because:
-  | (int(33), int('ex119#)) is not a subtype of (int('ex114#), int('ex115#))
+  | (int(33), int('ex157#)) is not a subtype of (int('ex152#), int('ex153#))
 | [existential_ast3/v1.sail]:17:48-65
 | 17 | if b == 0b0 then (64, unsigned(b @ a)) else (33, unsigned(a));
 |  | ^---------------^
-  |  | 'ex114# bound here
+  |  | 'ex152# bound here
 | [existential_ast3/v1.sail]:17:48-65
 | 17 | if b == 0b0 then (64, unsigned(b @ a)) else (33, unsigned(a));
 |  | ^---------------^
-  |  | 'ex115# bound here
+  |  | 'ex153# bound here
 | [existential_ast3/v1.sail]:17:48-65
 | 17 | if b == 0b0 then (64, unsigned(b @ a)) else (33, unsigned(a));
 |  | ^---------------^
-  |  | 'ex119# bound here
+  |  | 'ex157# bound here
 |
diff --git a/test/typecheck/pass/existential_ast3/v2.expect b/test/typecheck/pass/existential_ast3/v2.expect
index fdd13607..a2c08583 100644
--- a/test/typecheck/pass/existential_ast3/v2.expect
+++ b/test/typecheck/pass/existential_ast3/v2.expect
@@ -4,17 +4,17 @@ Type error:
 | ^---------------^
 | Tried performing type coercion from (int(31), range(0, (2 ^ 5 - 1))) to {('d : Int) ('n : Int), (datasize('d) & (0 <= 'n & ('n + 1) <= 'd)). (int('d), int('n))} on (31, unsigned(a))
 | Coercion failed because:
-  | (int(31), int('ex119#)) is not a subtype of (int('ex114#), int('ex115#))
+  | (int(31), int('ex157#)) is not a subtype of (int('ex152#), int('ex153#))
 | [existential_ast3/v2.sail]:17:48-65
 | 17 | if b == 0b0 then (64, unsigned(b @ a)) else (31, unsigned(a));
 |  | ^---------------^
-  |  | 'ex114# bound here
+  |  | 'ex152# bound here
 | [existential_ast3/v2.sail]:17:48-65
 | 17 | if b == 0b0 then (64, unsigned(b @ a)) else (31, unsigned(a));
 |  | ^---------------^
-  |  | 'ex115# bound here
+  |  | 'ex153# bound here
 | [existential_ast3/v2.sail]:17:48-65
 | 17 | if b == 0b0 then (64, unsigned(b @ a)) else (31, unsigned(a));
 |  | ^---------------^
-  |  | 'ex119# bound here
+  |  | 'ex157# bound here
 |
diff --git a/test/typecheck/pass/existential_ast3/v3.expect b/test/typecheck/pass/existential_ast3/v3.expect
index 2432e632..cf86b765 100644
--- a/test/typecheck/pass/existential_ast3/v3.expect
+++ b/test/typecheck/pass/existential_ast3/v3.expect
@@ -3,5 +3,5 @@ Type error:
25 | Some(Ctor(64, unsigned(0b0 @ b @ a)))
 | ^-----------------------------^
 | Could not resolve quantifiers for Ctor
-  | * (datasize(64) & (0 <= 'ex158# & ('ex158# + 1) <= 64))
+  | * (datasize(64) & (0 <= 'ex196# & ('ex196# + 1) <= 64))
 |
diff --git a/test/typecheck/pass/guards.sail b/test/typecheck/pass/guards.sail
index 4aac2bed..594130a8 100644
--- a/test/typecheck/pass/guards.sail
+++ b/test/typecheck/pass/guards.sail
@@ -1,8 +1,9 @@
default Order dec
$include <prelude.sail>
+$include <smt.sail>
-overload operator / = {quotient}
+overload operator / = {ediv_int}
union T = {C1 : int, C2 : int}
diff --git a/test/typecheck/pass/if_infer/v1.expect b/test/typecheck/pass/if_infer/v1.expect
index a63f28f1..80526204 100644
--- a/test/typecheck/pass/if_infer/v1.expect
+++ b/test/typecheck/pass/if_infer/v1.expect
@@ -5,8 +5,8 @@ Type error:
 | No overloading for vector_access, tried:
 | * bitvector_access
 | Could not resolve quantifiers for bitvector_access
-  | * (0 <= 'ex76# & ('ex76# + 1) <= 3)
+  | * (0 <= 'ex114# & ('ex114# + 1) <= 3)
 | * plain_vector_access
 | Could not resolve quantifiers for plain_vector_access
-  | * (0 <= 'ex79# & ('ex79# + 1) <= 3)
+  | * (0 <= 'ex117# & ('ex117# + 1) <= 3)
 |
diff --git a/test/typecheck/pass/if_infer/v2.expect b/test/typecheck/pass/if_infer/v2.expect
index f37d215f..0b705b50 100644
--- a/test/typecheck/pass/if_infer/v2.expect
+++ b/test/typecheck/pass/if_infer/v2.expect
@@ -5,8 +5,8 @@ Type error:
 | No overloading for vector_access, tried:
 | * bitvector_access
 | Could not resolve quantifiers for bitvector_access
-  | * (0 <= 'ex76# & ('ex76# + 1) <= 4)
+  | * (0 <= 'ex114# & ('ex114# + 1) <= 4)
 | * plain_vector_access
 | Could not resolve quantifiers for plain_vector_access
-  | * (0 <= 'ex79# & ('ex79# + 1) <= 4)
+  | * (0 <= 'ex117# & ('ex117# + 1) <= 4)
 |
diff --git a/test/typecheck/pass/recursion.sail b/test/typecheck/pass/recursion.sail
index 5ca85f53..cd3ca46c 100644
--- a/test/typecheck/pass/recursion.sail
+++ b/test/typecheck/pass/recursion.sail
@@ -2,6 +2,8 @@ default Order dec
$include <prelude.sail>
+overload operator / = {tdiv_int}
+
val log2 : int -> int
function log2(n) =
diff --git a/test/typecheck/pass/shadow_let.sail b/test/typecheck/pass/shadow_let.sail
new file mode 100644
index 00000000..8a30744c
--- /dev/null
+++ b/test/typecheck/pass/shadow_let.sail
@@ -0,0 +1,14 @@
+default Order dec
+
+register R : int
+
+val foo : int(1) -> unit
+val bar : int(2) -> unit
+
+function main((): unit) -> unit = {
+ let 'x : {'z, 'z == 1. int('z)} = 1;
+ let 'y = x;
+ foo(x);
+ let 'x : {'z, 'z == 2. int('z)} = 2;
+ foo(y);
+} \ No newline at end of file
diff --git a/test/typecheck/pass/shadow_let/v1.expect b/test/typecheck/pass/shadow_let/v1.expect
new file mode 100644
index 00000000..3cd21dc0
--- /dev/null
+++ b/test/typecheck/pass/shadow_let/v1.expect
@@ -0,0 +1,12 @@
+Type error:
+[shadow_let/v1.sail]:13:6-7
+13 | bar(y);
+  | ^
+  | Tried performing type coercion from int('_x#1) to int(2) on y
+  | Coercion failed because:
+  | int('_x#1) is not a subtype of int(2)
+  | [shadow_let/v1.sail]:9:6-8
+  | 9 | let 'x : {'z, 'z == 1. int('z)} = 1;
+  |  | ^^
+  |  | '_x#1 bound here
+  |
diff --git a/test/typecheck/pass/shadow_let/v1.sail b/test/typecheck/pass/shadow_let/v1.sail
new file mode 100644
index 00000000..d7dc20a5
--- /dev/null
+++ b/test/typecheck/pass/shadow_let/v1.sail
@@ -0,0 +1,14 @@
+default Order dec
+
+register R : int
+
+val foo : int(1) -> unit
+val bar : int(2) -> unit
+
+function main((): unit) -> unit = {
+ let 'x : {'z, 'z == 1. int('z)} = 1;
+ let 'y = x;
+ foo(x);
+ let 'x : {'z, 'z == 2. int('z)} = 2;
+ bar(y);
+} \ No newline at end of file