summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--lib/prelude.sail16
-rw-r--r--mips/hgen/regs_out_in.hgen5
-rw-r--r--risc-v/hgen/ast.hgen2
-rw-r--r--risc-v/hgen/herdtools_ast_to_shallow_ast.hgen11
-rw-r--r--risc-v/hgen/herdtools_types_to_shallow_types.hgen7
-rw-r--r--risc-v/hgen/lexer.hgen5
-rw-r--r--risc-v/hgen/parser.hgen16
-rw-r--r--risc-v/hgen/pretty.hgen4
-rw-r--r--risc-v/hgen/regs_out_in.hgen5
-rw-r--r--risc-v/hgen/sail_trans_out.hgen6
-rw-r--r--risc-v/hgen/shallow_ast_to_herdtools_ast.hgen6
-rw-r--r--risc-v/hgen/shallow_types_to_herdtools_types.hgen3
-rw-r--r--risc-v/hgen/token_types.hgen2
-rw-r--r--risc-v/hgen/tokens.hgen3
-rw-r--r--risc-v/hgen/trans_sail.hgen18
-rw-r--r--risc-v/hgen/types.hgen7
-rw-r--r--risc-v/hgen/types_sail_trans_out.hgen5
-rw-r--r--risc-v/hgen/types_trans_sail.hgen8
-rw-r--r--risc-v/riscv.sail56
-rw-r--r--risc-v/riscv_extras.lem6
-rw-r--r--risc-v/riscv_extras_embed.lem8
-rw-r--r--risc-v/riscv_extras_embed_sequential.lem9
-rw-r--r--src/gen_lib/sail_operators.lem5
-rw-r--r--src/gen_lib/sail_operators_mwords.lem79
-rw-r--r--src/gen_lib/sail_values.lem3
-rw-r--r--src/gen_lib/state.lem2
-rw-r--r--src/lem_interp/pretty_interp.ml32
-rw-r--r--src/lem_interp/printing_functions.ml3
-rw-r--r--src/lem_interp/printing_functions.mli1
-rw-r--r--src/lem_interp/run_interp.ml45
-rw-r--r--src/lem_interp/sail_impl_base.lem5
-rw-r--r--src/monomorphise.ml1191
-rw-r--r--src/pretty_print_lem.ml41
-rw-r--r--src/process_file.ml3
-rw-r--r--src/process_file.mli1
-rw-r--r--src/sail.ml3
-rw-r--r--src/type_check.mli2
-rw-r--r--test/mono/.gitignore1
-rw-r--r--test/mono/addsubexist.sail75
-rw-r--r--test/mono/fnreduce.sail69
-rw-r--r--test/mono/test.ml1
-rwxr-xr-xtest/mono/test.sh44
-rw-r--r--test/mono/tests4
-rw-r--r--test/mono/union-exist.sail33
-rw-r--r--test/mono/varmatch.sail19
-rw-r--r--test/mono/vector.sail21
46 files changed, 1252 insertions, 639 deletions
diff --git a/lib/prelude.sail b/lib/prelude.sail
index e5b5004f..9a79f81b 100644
--- a/lib/prelude.sail
+++ b/lib/prelude.sail
@@ -125,11 +125,11 @@ val forall Num 'n, Num 'm, Order 'ord. vector<'n, 'm, 'ord, bit> -> bit effect p
(* Arithmetic *)
val extern forall Num 'n, Num 'm, Num 'o, Num 'p.
- ([|'n:'m|], [|'o:'p|]) -> [|'n + 'o:'m + 'p|] effect pure add_range = "add_int"
+ ([|'n:'m|], [|'o:'p|]) -> [|'n + 'o:'m + 'p|] effect pure add_range = "add"
-val extern (nat, nat) -> nat effect pure add_nat = "add_int"
+val extern (nat, nat) -> nat effect pure add_nat = "add"
-val extern (int, int) -> int effect pure add_int
+val extern (int, int) -> int effect pure add_int = "add"
val forall Num 'n, Num 'o, Order 'ord.
(vector<'o, 'n, 'ord, bit>, vector<'o, 'n, 'ord, bit>) -> vector<'o, 'n, 'ord, bit> effect pure add_vec
@@ -144,9 +144,9 @@ val forall Num 'n, Num 'o, Order 'ord.
(vector<'o, 'n, 'ord, bit>, vector<'o, 'n, 'ord, bit>) -> (vector<'o, 'n, 'ord, bit>, bit, bit) effect pure add_overflow_vec
val extern forall Num 'n, Num 'm, Num 'o, Num 'p.
- ([|'n:'m|], [|'o:'p|]) -> [|'n - 'p:'m - 'o|] effect pure sub_range = "sub_int"
+ ([|'n:'m|], [|'o:'p|]) -> [|'n - 'p:'m - 'o|] effect pure sub_range = "sub"
-val extern (int, int) -> int effect pure sub_int = "sub_int"
+val extern (int, int) -> int effect pure sub_int = "sub"
val forall Num 'n, Num 'm, Order 'ord.
(vector<'n,'m,'ord,bit>, int) -> vector<'n,'m,'ord,bit> effect pure sub_vec_int
@@ -177,7 +177,7 @@ overload (deinfix -) [
val extern bool -> bit effect pure bool_to_bit = "bool_to_bitU"
-val extern (int, int) -> int effect pure mult_int
+val (int, int) -> int effect pure mult_int
val extern forall Num 'n, Num 'm. ([:'n:], [:'m:]) -> [:'n * 'm:] effect pure mult_range = "mult_int"
val forall Num 'n, Num 'o, Order 'ord.
(vector<'o, 'n, 'ord, bit>, vector<'o, 'n, 'ord, bit>) -> vector<2 * 'n - 1, 2 * 'n, 'ord, bit> effect pure mult_vec
@@ -296,7 +296,7 @@ val extern (int, int) -> int effect pure min_int = "min"
overload min [min_range_atom; min_int]
-val extern (int, int) -> int effect pure quotient = "quotient_int"
+val extern (int, int) -> int effect pure quotient
overload (deinfix quot) [quotient]
@@ -310,7 +310,7 @@ val extern forall Num 'n. [:'n:] -> [:2** 'n:] effect pure pow2
val extern forall Num 'n, Num 'm, Order 'ord, Type 'a. vector<'n,'m,'ord,'a> -> [:'m:] effect pure vector_length = "length"
val extern forall Type 'a. list<'a> -> nat effect pure list_length
-val extern forall Num 'n, Num 'm, Order 'ord. vector<'n,'m,'ord,bit> -> [:'m:] effect pure bitvector_length = "bitvector_length"
+val extern forall Num 'n, Num 'm, Order 'ord. vector<'n,'m,'ord,bit> -> [:'m:] effect pure bitvector_length = "bvlength"
overload length [bitvector_length; vector_length; list_length]
diff --git a/mips/hgen/regs_out_in.hgen b/mips/hgen/regs_out_in.hgen
deleted file mode 100644
index 8e1fd093..00000000
--- a/mips/hgen/regs_out_in.hgen
+++ /dev/null
@@ -1,5 +0,0 @@
-(* for each instruction instance, identify the role of the registers
- and possible branching: (outputs, inputs, voidstars, branch) *)
-
-| `MIPSAdd ->
- ([], [], [], [Next])
diff --git a/risc-v/hgen/ast.hgen b/risc-v/hgen/ast.hgen
index 2860484e..8983b5ae 100644
--- a/risc-v/hgen/ast.hgen
+++ b/risc-v/hgen/ast.hgen
@@ -10,4 +10,4 @@
| `RISCVADDIW of bit12 * reg * reg
| `RISCVSHIFTW of bit5 * reg * reg * riscvSop
| `RISCVRTYPEW of reg * reg * reg * riscvRopw
-| `RISCVFENCE
+| `RISCVFENCE of bit4 * bit4
diff --git a/risc-v/hgen/herdtools_ast_to_shallow_ast.hgen b/risc-v/hgen/herdtools_ast_to_shallow_ast.hgen
index 46b11310..50026612 100644
--- a/risc-v/hgen/herdtools_ast_to_shallow_ast.hgen
+++ b/risc-v/hgen/herdtools_ast_to_shallow_ast.hgen
@@ -4,14 +4,14 @@
translate_reg "rd" rd,
translate_uop op)
| `RISCVJAL(imm, rd) -> JAL0(
- translate_imm20 "imm" imm,
+ translate_imm21 "imm" imm,
translate_reg "rd" rd)
| `RISCVJALR(imm, rs, rd) -> JALR0(
translate_imm12 "imm" imm,
translate_reg "rs" rd,
translate_reg "rd" rd)
| `RISCVBType(imm, rs2, rs1, op) -> BTYPE(
- translate_imm12 "imm" imm,
+ translate_imm13 "imm" imm,
translate_reg "rs2" rs2,
translate_reg "rs1" rs1,
translate_bop op)
@@ -55,7 +55,6 @@
translate_reg "rs1" rs1,
translate_reg "rd" rd,
translate_ropw op)
-| `RISCVFENCE -> FENCE (
- translate_imm4 "pred" 0,
- translate_imm4 "succ" 0
-)
+| `RISCVFENCE(pred, succ) -> FENCE(
+ translate_imm4 "pred" pred,
+ translate_imm4 "succ" succ)
diff --git a/risc-v/hgen/herdtools_types_to_shallow_types.hgen b/risc-v/hgen/herdtools_types_to_shallow_types.hgen
index c15b3f94..4d8bd87a 100644
--- a/risc-v/hgen/herdtools_types_to_shallow_types.hgen
+++ b/risc-v/hgen/herdtools_types_to_shallow_types.hgen
@@ -57,10 +57,15 @@ let translate_bool name = function
| true -> Sail_values.B1
| false -> Sail_values.B0
+let translate_imm21 name value =
+ Sail_values.to_vec0 is_inc (Nat_big_num.of_int 21,Nat_big_num.of_int value)
let translate_imm20 name value =
Sail_values.to_vec0 is_inc (Nat_big_num.of_int 20,Nat_big_num.of_int value)
+let translate_imm13 name value =
+ Sail_values.to_vec0 is_inc (Nat_big_num.of_int 13,Nat_big_num.of_int value)
+
let translate_imm12 name value =
Sail_values.to_vec0 is_inc (Nat_big_num.of_int 12,Nat_big_num.of_int value)
@@ -70,5 +75,5 @@ let translate_imm6 name value =
let translate_imm5 name value =
Sail_values.to_vec0 is_inc (Nat_big_num.of_int 5,Nat_big_num.of_int value)
-let translate_imm4 name value =
+let translate_imm4 name value =
Sail_values.to_vec0 is_inc (Nat_big_num.of_int 4,Nat_big_num.of_int value)
diff --git a/risc-v/hgen/lexer.hgen b/risc-v/hgen/lexer.hgen
index 1888d42b..5f2c8326 100644
--- a/risc-v/hgen/lexer.hgen
+++ b/risc-v/hgen/lexer.hgen
@@ -1,4 +1,4 @@
-"lwu" , UTYPE { op=RISCVLUI };
+"lui" , UTYPE { op=RISCVLUI };
"auipc" , UTYPE { op=RISCVAUIPC };
"jal", JAL ();
@@ -59,3 +59,6 @@
"sraw", RTYPEW{op=RISCVSRAW};
"fence", FENCE ();
+"r", FENCEOPTION Fence_R;
+"w", FENCEOPTION Fence_W;
+"rw", FENCEOPTION Fence_RW;
diff --git a/risc-v/hgen/parser.hgen b/risc-v/hgen/parser.hgen
index ba780857..37fd8d8d 100644
--- a/risc-v/hgen/parser.hgen
+++ b/risc-v/hgen/parser.hgen
@@ -15,12 +15,22 @@
| LOAD reg COMMA NUM LPAR reg RPAR
{ `RISCVLoad($4, $6, $2, $1.unsigned, $1.width) }
| STORE reg COMMA NUM LPAR reg RPAR
- { `RISCVStore($4, $6, $2, $1.width) }
+ { `RISCVStore($4, $2, $6, $1.width) }
| ADDIW reg COMMA reg COMMA NUM
{ `RISCVADDIW ($6, $4, $2) }
| SHIFTW reg COMMA reg COMMA NUM
{ `RISCVSHIFTW ($6, $4, $2, $1.op) }
| RTYPEW reg COMMA reg COMMA reg
{ `RISCVRTYPEW ($6, $4, $2, $1.op) }
-| FENCE
- { `RISCVFENCE }
+| FENCE FENCEOPTION COMMA FENCEOPTION
+ { match ($2, $4) with
+ | (Fence_RW, Fence_RW) -> `RISCVFENCE (0b0011, 0b0011)
+ | (Fence_R, Fence_RW) -> `RISCVFENCE (0b0010, 0b0011)
+ | (Fence_RW, Fence_W) -> `RISCVFENCE (0b0011, 0b0001)
+ | (Fence_RW, Fence_R) -> failwith "'fence rw,r' is not supported"
+ | (Fence_R, Fence_R) -> failwith "'fence r,r' is not supported"
+ | (Fence_R, Fence_W) -> failwith "'fence r,w' is not supported"
+ | (Fence_W, Fence_RW) -> failwith "'fence w,rw' is not supported"
+ | (Fence_W, Fence_R) -> failwith "'fence w,r' is not supported"
+ | (Fence_W, Fence_W) -> failwith "'fence w,w' is not supported"
+ }
diff --git a/risc-v/hgen/pretty.hgen b/risc-v/hgen/pretty.hgen
index d8572ae2..1da3ef11 100644
--- a/risc-v/hgen/pretty.hgen
+++ b/risc-v/hgen/pretty.hgen
@@ -8,8 +8,8 @@
| `RISCVShiftIop(imm, rs, rd, op) -> sprintf "%s %s, %s, %d" (pp_riscv_sop op) (pp_reg rd) (pp_reg rs) imm
| `RISCVRType (rs2, rs1, rd, op) -> sprintf "%s %s, %s, %s" (pp_riscv_rop op) (pp_reg rd) (pp_reg rs1) (pp_reg rs2)
| `RISCVLoad(imm, rs, rd, unsigned, width) -> sprintf "%s %s, %d(%s)" (pp_riscv_load_op (unsigned, width)) (pp_reg rd) imm (pp_reg rs)
-| `RISCVStore(imm, rs, rd, width) -> sprintf "%s %s, %d(%s)" (pp_riscv_store_op width) (pp_reg rd) imm (pp_reg rs)
+| `RISCVStore(imm, rs2, rs1, width) -> sprintf "%s %s, %d(%s)" (pp_riscv_store_op width) (pp_reg rs2) imm (pp_reg rs1)
| `RISCVADDIW(imm, rs, rd) -> sprintf "addiw %s, %s, %d" (pp_reg rd) (pp_reg rs) imm
| `RISCVSHIFTW(imm, rs, rd, op) -> sprintf "%s %s, %s, %d" (pp_riscv_sop op) (pp_reg rd) (pp_reg rs) imm
| `RISCVRTYPEW(rs2, rs1, rd, op) -> sprintf "%s %s, %s, %s" (pp_riscv_ropw op) (pp_reg rd) (pp_reg rs1) (pp_reg rs2)
-| `RISCVFENCE -> "fence"
+| `RISCVFENCE(pred, succ) -> sprintf "fence %s, %s" (pp_riscv_fence_option pred) (pp_riscv_fence_option succ)
diff --git a/risc-v/hgen/regs_out_in.hgen b/risc-v/hgen/regs_out_in.hgen
deleted file mode 100644
index 8e1fd093..00000000
--- a/risc-v/hgen/regs_out_in.hgen
+++ /dev/null
@@ -1,5 +0,0 @@
-(* for each instruction instance, identify the role of the registers
- and possible branching: (outputs, inputs, voidstars, branch) *)
-
-| `MIPSAdd ->
- ([], [], [], [Next])
diff --git a/risc-v/hgen/sail_trans_out.hgen b/risc-v/hgen/sail_trans_out.hgen
index a9d3159c..dca5bea1 100644
--- a/risc-v/hgen/sail_trans_out.hgen
+++ b/risc-v/hgen/sail_trans_out.hgen
@@ -1,8 +1,8 @@
| ("EBREAK", []) -> `RISCVStopFetching
| ("UTYPE", [imm; rd; op]) -> `RISCVUTYPE(translate_out_simm20 imm, translate_out_ireg rd, translate_out_uop op)
-| ("JAL", [imm; rd]) -> `RISCVJAL(translate_out_simm20 imm, translate_out_ireg rd)
+| ("JAL", [imm; rd]) -> `RISCVJAL(translate_out_simm21 imm, translate_out_ireg rd)
| ("JALR", [imm; rs; rd]) -> `RISCVJALR(translate_out_simm12 imm, translate_out_ireg rs, translate_out_ireg rd)
-| ("BTYPE", [imm; rs2; rs1; op]) -> `RISCVBType(translate_out_simm12 imm, translate_out_ireg rs2, translate_out_ireg rs1, translate_out_bop op)
+| ("BTYPE", [imm; rs2; rs1; op]) -> `RISCVBType(translate_out_simm13 imm, translate_out_ireg rs2, translate_out_ireg rs1, translate_out_bop op)
| ("ITYPE", [imm; rs1; rd; op]) -> `RISCVIType(translate_out_simm12 imm, translate_out_ireg rs1, translate_out_ireg rd, translate_out_iop op)
| ("SHIFTIOP", [imm; rs; rd; op]) -> `RISCVShiftIop(translate_out_imm6 imm, translate_out_ireg rs, translate_out_ireg rd, translate_out_sop op)
| ("RTYPE", [rs2; rs1; rd; op]) -> `RISCVRType (translate_out_ireg rs2, translate_out_ireg rs1, translate_out_ireg rd, translate_out_rop op)
@@ -11,4 +11,4 @@
| ("ADDIW", [imm; rs; rd]) -> `RISCVADDIW(translate_out_simm12 imm, translate_out_ireg rs, translate_out_ireg rd)
| ("SHIFTW", [imm; rs; rd; op]) -> `RISCVSHIFTW(translate_out_imm5 imm, translate_out_ireg rs, translate_out_ireg rd, translate_out_sop op)
| ("RTYPEW", [rs2; rs1; rd; op]) -> `RISCVRTYPEW(translate_out_ireg rs2, translate_out_ireg rs1, translate_out_ireg rd, translate_out_ropw op)
-| ("FENCE", []) -> `RISCVFENCE
+| ("FENCE", [pred; succ]) -> `RISCVFENCE(translate_out_imm4 pred, translate_out_imm4 succ)
diff --git a/risc-v/hgen/shallow_ast_to_herdtools_ast.hgen b/risc-v/hgen/shallow_ast_to_herdtools_ast.hgen
index 9278b92e..6158ebd7 100644
--- a/risc-v/hgen/shallow_ast_to_herdtools_ast.hgen
+++ b/risc-v/hgen/shallow_ast_to_herdtools_ast.hgen
@@ -1,8 +1,8 @@
| EBREAK -> `RISCVStopFetching
| UTYPE( imm, rd, op) -> `RISCVUTYPE(translate_out_simm20 imm, translate_out_ireg rd, translate_out_uop op)
-| JAL0( imm, rd) -> `RISCVJAL(translate_out_simm20 imm, translate_out_ireg rd)
+| JAL0( imm, rd) -> `RISCVJAL(translate_out_simm21 imm, translate_out_ireg rd)
| JALR0( imm, rs, rd) -> `RISCVJALR(translate_out_simm12 imm, translate_out_ireg rs, translate_out_ireg rd)
-| BTYPE( imm, rs2, rs1, op) -> `RISCVBType(translate_out_simm12 imm, translate_out_ireg rs2, translate_out_ireg rs1, translate_out_bop op)
+| BTYPE( imm, rs2, rs1, op) -> `RISCVBType(translate_out_simm13 imm, translate_out_ireg rs2, translate_out_ireg rs1, translate_out_bop op)
| ITYPE( imm, rs1, rd, op) -> `RISCVIType(translate_out_simm12 imm, translate_out_ireg rs1, translate_out_ireg rd, translate_out_iop op)
| SHIFTIOP( imm, rs, rd, op) -> `RISCVShiftIop(translate_out_imm6 imm, translate_out_ireg rs, translate_out_ireg rd, translate_out_sop op)
| RTYPE( rs2, rs1, rd, op) -> `RISCVRType (translate_out_ireg rs2, translate_out_ireg rs1, translate_out_ireg rd, translate_out_rop op)
@@ -11,4 +11,4 @@
| ADDIW( imm, rs, rd) -> `RISCVADDIW(translate_out_simm12 imm, translate_out_ireg rs, translate_out_ireg rd)
| SHIFTW( imm, rs, rd, op) -> `RISCVSHIFTW(translate_out_imm5 imm, translate_out_ireg rs, translate_out_ireg rd, translate_out_sop op)
| RTYPEW( rs2, rs1, rd, op) -> `RISCVRTYPEW(translate_out_ireg rs2, translate_out_ireg rs1, translate_out_ireg rd, translate_out_ropw op)
-| FENCE(pred, succ) -> `RISCVFENCE
+| FENCE( pred, succ) -> `RISCVFENCE(translate_out_imm4 pred, translate_out_imm4 succ)
diff --git a/risc-v/hgen/shallow_types_to_herdtools_types.hgen b/risc-v/hgen/shallow_types_to_herdtools_types.hgen
index d635efde..a891d7d0 100644
--- a/risc-v/hgen/shallow_types_to_herdtools_types.hgen
+++ b/risc-v/hgen/shallow_types_to_herdtools_types.hgen
@@ -64,7 +64,10 @@ let translate_out_bool = function
| Sail_values.B0 -> false
| _ -> failwith "translate_out_bool Undef"
+let translate_out_simm21 imm = translate_out_signed_int imm 21
let translate_out_simm20 imm = translate_out_signed_int imm 20
+let translate_out_simm13 imm = translate_out_signed_int imm 13
let translate_out_simm12 imm = translate_out_signed_int imm 12
let translate_out_imm6 imm = translate_out_int imm
let translate_out_imm5 imm = translate_out_int imm
+let translate_out_imm4 imm = translate_out_int imm
diff --git a/risc-v/hgen/token_types.hgen b/risc-v/hgen/token_types.hgen
index e778d2a9..2980b985 100644
--- a/risc-v/hgen/token_types.hgen
+++ b/risc-v/hgen/token_types.hgen
@@ -11,3 +11,5 @@ type token_ADDIW = unit
type token_SHIFTW = {op : riscvSop }
type token_RTYPEW = {op : riscvRopw }
type token_FENCE = unit
+
+type token_FENCEOPTION = Fence_R | Fence_W | Fence_RW
diff --git a/risc-v/hgen/tokens.hgen b/risc-v/hgen/tokens.hgen
index abe6a6c3..f952cf77 100644
--- a/risc-v/hgen/tokens.hgen
+++ b/risc-v/hgen/tokens.hgen
@@ -10,4 +10,5 @@
%token <RISCVHGenBase.token_ADDIW> ADDIW
%token <RISCVHGenBase.token_SHIFTW> SHIFTW
%token <RISCVHGenBase.token_RTYPEW> RTYPEW
-%token <RISCVHGenBase.token_FENCE> FENCE
+%token <RISCVHGenBase.token_FENCE> FENCE
+%token <RISCVHGenBase.token_FENCEOPTION> FENCEOPTION
diff --git a/risc-v/hgen/trans_sail.hgen b/risc-v/hgen/trans_sail.hgen
index c2e3138b..df22d9dc 100644
--- a/risc-v/hgen/trans_sail.hgen
+++ b/risc-v/hgen/trans_sail.hgen
@@ -10,7 +10,7 @@
| `RISCVJAL(imm, rd) ->
("JAL",
[
- translate_imm20 "imm" imm;
+ translate_imm21 "imm" imm;
translate_reg "rd" rd;
],
[])
@@ -25,7 +25,7 @@
| `RISCVBType(imm, rs2, rs1, op) ->
("BTYPE",
[
- translate_imm12 "imm" imm;
+ translate_imm13 "imm" imm;
translate_reg "rs2" rs2;
translate_reg "rs1" rs1;
translate_bop "op" op;
@@ -68,12 +68,12 @@
translate_width "width" width;
],
[])
-| `RISCVStore(imm, rs, rd, width) ->
+| `RISCVStore(imm, rs2, rs1, width) ->
("STORE",
[
translate_imm12 "imm" imm;
- translate_reg "rs" rs;
- translate_reg "rd" rd;
+ translate_reg "rs2" rs2;
+ translate_reg "rs1" rs1;
translate_width "width" width;
],
[])
@@ -103,4 +103,10 @@
translate_ropw "op" op;
],
[])
-| `RISCVFENCE -> ("FENCE", [], [])
+| `RISCVFENCE(pred, succ) ->
+ ("FENCE",
+ [
+ translate_imm4 "pred" pred;
+ translate_imm4 "succ" succ;
+ ],
+ [])
diff --git a/risc-v/hgen/types.hgen b/risc-v/hgen/types.hgen
index e31b11f8..87fc9b95 100644
--- a/risc-v/hgen/types.hgen
+++ b/risc-v/hgen/types.hgen
@@ -2,6 +2,7 @@ type bit20 = int
type bit12 = int
type bit6 = int
type bit5 = int
+type bit4 = int
type riscvUop = (* upper immediate ops *)
| RISCVLUI
@@ -114,3 +115,9 @@ let pp_riscv_store_op width = match width with
| RISCVWORD -> "sw"
| RISCVDOUBLE -> "sd"
| _ -> failwith "unexpected store op"
+
+let pp_riscv_fence_option = function
+ | 0b0011 -> "rw"
+ | 0b0010 -> "r"
+ | 0b0001 -> "w"
+ | _ -> failwith "unexpected fence option"
diff --git a/risc-v/hgen/types_sail_trans_out.hgen b/risc-v/hgen/types_sail_trans_out.hgen
index e034cf37..e22110d0 100644
--- a/risc-v/hgen/types_sail_trans_out.hgen
+++ b/risc-v/hgen/types_sail_trans_out.hgen
@@ -11,10 +11,13 @@ let translate_out_signed_int inst bits =
let translate_out_ireg ireg = IReg (int_to_ireg (translate_out_int ireg))
+let translate_out_simm21 imm = translate_out_signed_int imm 21
let translate_out_simm20 imm = translate_out_signed_int imm 20
+let translate_out_simm13 imm = translate_out_signed_int imm 13
let translate_out_simm12 imm = translate_out_signed_int imm 12
let translate_out_imm6 imm = translate_out_int imm
let translate_out_imm5 imm = translate_out_int imm
+let translate_out_imm4 imm = translate_out_int imm
let translate_out_bool = function
| (name, Bit, [Bitc_one]) -> true
@@ -80,4 +83,4 @@ let translate_out_ropw op = match translate_out_enum op with
| 2 -> RISCVSLLW
| 3 -> RISCVSRLW
| 4 -> RISCVSRAW
-| _ -> failwith "Unknown ropw in sail translate out" \ No newline at end of file
+| _ -> failwith "Unknown ropw in sail translate out"
diff --git a/risc-v/hgen/types_trans_sail.hgen b/risc-v/hgen/types_trans_sail.hgen
index 9dd36d5e..1bf174fa 100644
--- a/risc-v/hgen/types_trans_sail.hgen
+++ b/risc-v/hgen/types_trans_sail.hgen
@@ -19,15 +19,21 @@ let translate_ropw = translate_enum [RISCVADDW; RISCVSUBW; RISCVSLLW; RISCVSRLW;
let translate_width = translate_enum [RISCVBYTE; RISCVHALF; RISCVWORD; RISCVDOUBLE]
let translate_reg name value =
(name, Bvector (Some 5), bit_list_of_integer 5 (Nat_big_num.of_int (reg_to_int value)))
+let translate_imm21 name value =
+ (name, Bvector (Some 21), bit_list_of_integer 21 (Nat_big_num.of_int value))
let translate_imm20 name value =
- (name, Bvector (Some 26), bit_list_of_integer 26 (Nat_big_num.of_int value))
+ (name, Bvector (Some 20), bit_list_of_integer 20 (Nat_big_num.of_int value))
let translate_imm16 name value =
(name, Bvector (Some 16), bit_list_of_integer 16 (Nat_big_num.of_int value))
+let translate_imm13 name value =
+ (name, Bvector (Some 13), bit_list_of_integer 13 (Nat_big_num.of_int value))
let translate_imm12 name value =
(name, Bvector (Some 12), bit_list_of_integer 12 (Nat_big_num.of_int value))
let translate_imm6 name value =
(name, Bvector (Some 6), bit_list_of_integer 6 (Nat_big_num.of_int value))
let translate_imm5 name value =
(name, Bvector (Some 5), bit_list_of_integer 5 (Nat_big_num.of_int value))
+let translate_imm4 name value =
+ (name, Bvector (Some 4), bit_list_of_integer 4 (Nat_big_num.of_int value))
let translate_bool name value =
(name, Bit, [if value then Bitc_one else Bitc_zero])
diff --git a/risc-v/riscv.sail b/risc-v/riscv.sail
index e464c5f7..4a80adb0 100644
--- a/risc-v/riscv.sail
+++ b/risc-v/riscv.sail
@@ -52,10 +52,17 @@ function unit wGPR((regno) r, (regval) v) =
if (r != 0) then
GPRs[r] := v
+function forall 'a. 'a effect { escape } not_implemented((string) message) =
+{
+ exit message;
+}
+
val extern forall Nat 'n. ( bit[64] , [|'n|] ) -> (bit[8 * 'n]) effect { rmem } MEMr
val extern forall Nat 'n. ( bit[64] , [|'n|]) -> unit effect { eamem } MEMea
val extern forall Nat 'n. ( bit[64] , [|'n|] , bit[8*'n]) -> unit effect { wmv } MEMval
-val extern unit -> unit effect { barr } MEM_sync
+val extern unit -> unit effect { barr } MEM_fence_rw_rw
+val extern unit -> unit effect { barr } MEM_fence_r_rw
+val extern unit -> unit effect { barr } MEM_fence_rw_w
(* Ideally these would be sail builtin *)
function (bit[64]) shift_right_arith64 ((bit[64]) v, (bit[6]) shift) =
@@ -95,10 +102,10 @@ function clause execute (UTYPE(imm, rd, op)) =
} in
wGPR(rd, ret)
-union ast member ((bit[20]), regno) JAL
-function clause decode ((bit[20]) imm : (regno) rd : 0b1101111) = Some (JAL(imm, rd))
+union ast member ((bit[21]), regno) JAL
+function clause decode ((bit[20]) imm : (regno) rd : 0b1101111) = Some (JAL(imm[19] : imm[7..0] : imm[8] : imm[18..13] : imm[12..9] : 0b0, rd))
function clause execute (JAL(imm, rd)) =
- let (bit[64]) offset = EXTS(imm[19] : imm[7..0] : imm[8] : imm[18..13] : imm[12..9] : 0b0) in {
+ let (bit[64]) offset = EXTS(imm) in {
nextPC := PC + offset;
wGPR(rd, PC + 4);
}
@@ -112,13 +119,13 @@ function clause execute (JALR(imm, rs1, rd)) =
wGPR(rd, PC + 4);
}
-union ast member ((bit[12]), regno, regno, bop) BTYPE
-function clause decode ((bit[7]) imm7 : (regno) rs2 : (regno) rs1 : 0b000 : (bit[5]) imm5 : 0b1100011) = Some(BTYPE(imm7[6] : imm5[0] : imm7[5..0] : imm5[4..1], rs2, rs1, BEQ))
-function clause decode ((bit[7]) imm7 : (regno) rs2 : (regno) rs1 : 0b001 : (bit[5]) imm5 : 0b1100011) = Some(BTYPE(imm7[6] : imm5[0] : imm7[5..0] : imm5[4..1], rs2, rs1, BNE))
-function clause decode ((bit[7]) imm7 : (regno) rs2 : (regno) rs1 : 0b100 : (bit[5]) imm5 : 0b1100011) = Some(BTYPE(imm7[6] : imm5[0] : imm7[5..0] : imm5[4..1], rs2, rs1, BLT))
-function clause decode ((bit[7]) imm7 : (regno) rs2 : (regno) rs1 : 0b101 : (bit[5]) imm5 : 0b1100011) = Some(BTYPE(imm7[6] : imm5[0] : imm7[5..0] : imm5[4..1], rs2, rs1, BGE))
-function clause decode ((bit[7]) imm7 : (regno) rs2 : (regno) rs1 : 0b110 : (bit[5]) imm5 : 0b1100011) = Some(BTYPE(imm7[6] : imm5[0] : imm7[5..0] : imm5[4..1], rs2, rs1, BLTU))
-function clause decode ((bit[7]) imm7 : (regno) rs2 : (regno) rs1 : 0b111 : (bit[5]) imm5 : 0b1100011) = Some(BTYPE(imm7[6] : imm5[0] : imm7[5..0] : imm5[4..1], rs2, rs1, BGEU))
+union ast member ((bit[13]), regno, regno, bop) BTYPE
+function clause decode ((bit[7]) imm7 : (regno) rs2 : (regno) rs1 : 0b000 : (bit[5]) imm5 : 0b1100011) = Some(BTYPE(imm7[6] : imm5[0] : imm7[5..0] : imm5[4..1] : 0b0, rs2, rs1, BEQ))
+function clause decode ((bit[7]) imm7 : (regno) rs2 : (regno) rs1 : 0b001 : (bit[5]) imm5 : 0b1100011) = Some(BTYPE(imm7[6] : imm5[0] : imm7[5..0] : imm5[4..1] : 0b0, rs2, rs1, BNE))
+function clause decode ((bit[7]) imm7 : (regno) rs2 : (regno) rs1 : 0b100 : (bit[5]) imm5 : 0b1100011) = Some(BTYPE(imm7[6] : imm5[0] : imm7[5..0] : imm5[4..1] : 0b0, rs2, rs1, BLT))
+function clause decode ((bit[7]) imm7 : (regno) rs2 : (regno) rs1 : 0b101 : (bit[5]) imm5 : 0b1100011) = Some(BTYPE(imm7[6] : imm5[0] : imm7[5..0] : imm5[4..1] : 0b0, rs2, rs1, BGE))
+function clause decode ((bit[7]) imm7 : (regno) rs2 : (regno) rs1 : 0b110 : (bit[5]) imm5 : 0b1100011) = Some(BTYPE(imm7[6] : imm5[0] : imm7[5..0] : imm5[4..1] : 0b0, rs2, rs1, BLTU))
+function clause decode ((bit[7]) imm7 : (regno) rs2 : (regno) rs1 : 0b111 : (bit[5]) imm5 : 0b1100011) = Some(BTYPE(imm7[6] : imm5[0] : imm7[5..0] : imm5[4..1] : 0b0, rs2, rs1, BGEU))
function clause execute (BTYPE(imm, rs2, rs1, op)) =
let rs1_val = rGPR(rs1) in
@@ -132,7 +139,7 @@ function clause execute (BTYPE(imm, rs2, rs1, op)) =
case BGEU -> unsigned(rs1_val) >= unsigned(rs2_val) (* XXX sail missing >=_u *)
} in
if (taken) then
- nextPC := PC + EXTS(imm : 0b0)
+ nextPC := PC + EXTS(imm)
union ast member ((bit[12]), regno, regno, iop) ITYPE
function clause decode ((bit[12]) imm : (regno) rs1 : 0b000 : (regno) rd : 0b0010011) = Some(ITYPE(imm, rs1, rd, ADDI))
@@ -208,15 +215,15 @@ function clause execute(LOAD(imm, rs1, rd, unsigned, width)) =
let (bit[64]) result = if unsigned then
switch (width) {
case BYTE -> EXTZ(MEMr(addr, 1))
- case WORD -> EXTZ(MEMr(addr, 2))
- case HALF -> EXTZ(MEMr(addr, 4))
+ case HALF -> EXTZ(MEMr(addr, 2))
+ case WORD -> EXTZ(MEMr(addr, 4))
case DOUBLE -> MEMr(addr, 8)
}
else
switch (width) {
case BYTE -> EXTS(MEMr(addr, 1))
- case WORD -> EXTS(MEMr(addr, 2))
- case HALF -> EXTS(MEMr(addr, 4))
+ case HALF -> EXTS(MEMr(addr, 2))
+ case WORD -> EXTS(MEMr(addr, 4))
case DOUBLE -> MEMr(addr, 8)
} in
wGPR(rd, result)
@@ -230,15 +237,15 @@ function clause execute (STORE(imm, rs2, rs1, width)) =
let (bit[64]) addr = rGPR(rs1) + EXTS(imm) in {
switch (width) {
case BYTE -> MEMea(addr, 1)
- case WORD -> MEMea(addr, 2)
- case HALF -> MEMea(addr, 4)
+ case HALF -> MEMea(addr, 2)
+ case WORD -> MEMea(addr, 4)
case DOUBLE -> MEMea(addr, 8)
};
let rs2_val = rGPR(rs2) in
switch (width) {
- case BYTE -> MEMval(addr, 1, rs2_val)
- case WORD -> MEMval(addr, 2, rs2_val)
- case HALF -> MEMval(addr, 4, rs2_val)
+ case BYTE -> MEMval(addr, 1, rs2_val[7..0])
+ case HALF -> MEMval(addr, 2, rs2_val[15..0])
+ case WORD -> MEMval(addr, 4, rs2_val[31..0])
case DOUBLE -> MEMval(addr, 8, rs2_val)
}
}
@@ -285,7 +292,12 @@ function clause execute (RTYPEW(rs2, rs1, rd, op)) =
union ast member (bit[4], bit[4]) FENCE
function clause decode (0b0000 : (bit[4]) pred : (bit[4]) succ : 0b00000 : 0b000 : 0b00000 : 0b0001111) = Some(FENCE (pred, succ))
function clause execute (FENCE(pred, succ)) = {
- MEM_sync(); (* XXX use pred and succ *)
+ switch(pred, succ) {
+ case (0b0011, 0b0011) -> MEM_fence_rw_rw()
+ case (0b0010, 0b0011) -> MEM_fence_r_rw()
+ case (0b0011, 0b0001) -> MEM_fence_rw_w()
+ case _ -> not_implemented("unsupported fence")
+ }
}
union ast member unit FENCEI
diff --git a/risc-v/riscv_extras.lem b/risc-v/riscv_extras.lem
index c09c85c2..aa5d8fb8 100644
--- a/risc-v/riscv_extras.lem
+++ b/risc-v/riscv_extras.lem
@@ -53,6 +53,8 @@ let memory_vals : memory_write_vals =
(IState (Interp.add_answer_to_stack interp bit) context)))));
]
-let barrier_functions = [
- ("MEM_sync", Barrier_MIPS_SYNC);
+let barrier_functions =
+ [ ("MEM_fence_rw_rw", Barrier_RISCV_rw_rw);
+ ("MEM_fence_r_rw", Barrier_RISCV_r_rw);
+ ("MEM_fence_rw_w", Barrier_RISCV_rw_w);
]
diff --git a/risc-v/riscv_extras_embed.lem b/risc-v/riscv_extras_embed.lem
index cbc8bd0d..1146d1cd 100644
--- a/risc-v/riscv_extras_embed.lem
+++ b/risc-v/riscv_extras_embed.lem
@@ -22,9 +22,13 @@ val MEMval_conditional : (vector bitU * integer * vector bitU) -> M bitU
let MEMval (_,_,v) = write_mem_val v >>= fun _ -> return ()
let MEMval_conditional (_,_,v) = write_mem_val v >>= fun b -> return (if b then B1 else B0)
-val MEM_sync : unit -> M unit
+val MEM_fence_rw_rw : unit -> M unit
+val MEM_fence_r_rw : unit -> M unit
+val MEM_fence_rw_w : unit -> M unit
-let MEM_sync () = barrier Barrier_Isync
+let MEM_fence_rw_rw () = barrier Barrier_RISCV_rw_rw
+let MEM_fence_r_rw () = barrier Barrier_RISCV_r_rw
+let MEM_fence_rw_w () = barrier Barrier_RISCV_rw_w
let duplicate (bit,len) =
let bits = repeat [bit] len in
diff --git a/risc-v/riscv_extras_embed_sequential.lem b/risc-v/riscv_extras_embed_sequential.lem
index 7fb62161..f6709ff7 100644
--- a/risc-v/riscv_extras_embed_sequential.lem
+++ b/risc-v/riscv_extras_embed_sequential.lem
@@ -23,10 +23,13 @@ val MEMval_conditional : (vector bitU * integer * vector bitU) -> M bitU
let MEMval (_,_,v) = write_mem_val v >>= fun _ -> return ()
let MEMval_conditional (_,_,v) = write_mem_val v >>= fun b -> return (if b then B1 else B0)
-val MEM_sync : unit -> M unit
-
-let MEM_sync () = barrier Barrier_MIPS_SYNC
+val MEM_fence_rw_rw : unit -> M unit
+val MEM_fence_r_rw : unit -> M unit
+val MEM_fence_rw_w : unit -> M unit
+let MEM_fence_rw_rw () = barrier Barrier_RISCV_rw_rw
+let MEM_fence_r_rw () = barrier Barrier_RISCV_r_rw
+let MEM_fence_rw_w () = barrier Barrier_RISCV_rw_w
let duplicate (bit,len) =
let bits = repeat [bit] len in
diff --git a/src/gen_lib/sail_operators.lem b/src/gen_lib/sail_operators.lem
index b94257f0..30c7325e 100644
--- a/src/gen_lib/sail_operators.lem
+++ b/src/gen_lib/sail_operators.lem
@@ -437,11 +437,6 @@ let arith_op_vec_range_no0 op sign size (Vector _ _ is_inc as l) r =
let mod_VIV = arith_op_vec_range_no0 hardware_mod false 1
-val repeat : forall 'a. list 'a -> integer -> list 'a
-let rec repeat xs n =
- if n = 0 then []
- else xs ++ repeat xs (n-1)
-
(* Assumes decreasing bit vectors *)
let duplicate (bit, length) =
Vector (repeat [bit] length) (length - 1) false
diff --git a/src/gen_lib/sail_operators_mwords.lem b/src/gen_lib/sail_operators_mwords.lem
index a1fc1fd7..8fb158de 100644
--- a/src/gen_lib/sail_operators_mwords.lem
+++ b/src/gen_lib/sail_operators_mwords.lem
@@ -118,9 +118,9 @@ let bitwise_not bs = lNot bs
let bitwise_binop op (bsl, bsr) = (op bsl bsr)
-let bitwise_and = bitwise_binop lAnd
-let bitwise_or = bitwise_binop lOr
-let bitwise_xor = bitwise_binop lXor
+let bitwise_and x = bitwise_binop lAnd x
+let bitwise_or x = bitwise_binop lOr x
+let bitwise_xor x = bitwise_binop lXor x
(*let unsigned bs : integer = unsignedIntegerFromWord bs*)
let unsigned_big = unsigned
@@ -311,9 +311,9 @@ let arith_op_range_vec_range op sign l r = op l (to_num sign r)
* add_range_vec_range_signed
* minus_range_vec_range
*)
-let add_IVI = arith_op_range_vec_range integerAdd false
-let addS_IVI = arith_op_range_vec_range integerAdd true
-let minus_IVI = arith_op_range_vec_range integerMinus false
+let add_IVI x = arith_op_range_vec_range integerAdd false x
+let addS_IVI x = arith_op_range_vec_range integerAdd true x
+let minus_IVI x = arith_op_range_vec_range integerMinus false x
let arith_op_vec_range_range op sign l r = op (to_num sign l) r
@@ -321,9 +321,9 @@ let arith_op_vec_range_range op sign l r = op (to_num sign l) r
* add_vec_range_range_signed
* minus_vec_range_range
*)
-let add_VII = arith_op_vec_range_range integerAdd false
-let addS_VII = arith_op_vec_range_range integerAdd true
-let minus_VII = arith_op_vec_range_range integerMinus false
+let add_VII x = arith_op_vec_range_range integerAdd false x
+let addS_VII x = arith_op_vec_range_range integerAdd true x
+let minus_VII x = arith_op_vec_range_range integerMinus false x
@@ -334,8 +334,8 @@ let arith_op_vec_vec_range op sign l r =
(* add_vec_vec_range
* add_vec_vec_range_signed
*)
-let add_VVI = arith_op_vec_vec_range integerAdd false
-let addS_VVI = arith_op_vec_vec_range integerAdd true
+let add_VVI x = arith_op_vec_vec_range integerAdd false x
+let addS_VVI x = arith_op_vec_vec_range integerAdd true x
let arith_op_vec_bit op sign (size : integer) l r =
let l' = to_num sign l in
@@ -346,9 +346,9 @@ let arith_op_vec_bit op sign (size : integer) l r =
* add_vec_bit_signed
* minus_vec_bit_signed
*)
-let add_VBV = arith_op_vec_bit integerAdd false 1
-let addS_VBV = arith_op_vec_bit integerAdd true 1
-let minus_VBV = arith_op_vec_bit integerMinus true 1
+let add_VBV x = arith_op_vec_bit integerAdd false 1 x
+let addS_VBV x = arith_op_vec_bit integerAdd true 1 x
+let minus_VBV x = arith_op_vec_bit integerMinus true 1 x
(* TODO: these can't be done directly in Lem because of the one_more size calculation
val arith_op_overflow_vec : forall 'a 'b. Size 'a, Size 'b => (integer -> integer -> integer) -> bool -> integer -> bitvector 'a -> bitvector 'a -> bitvector 'b * bitU * bool
@@ -426,9 +426,9 @@ let shift_op_vec op (bs, (n : integer)) =
rotateLeft n bs
end
-let bitwise_leftshift = shift_op_vec LL_shift (*"<<"*)
-let bitwise_rightshift = shift_op_vec RR_shift (*">>"*)
-let bitwise_rotate = shift_op_vec LLL_shift (*"<<<"*)
+let bitwise_leftshift x = shift_op_vec LL_shift x (*"<<"*)
+let bitwise_rightshift x = shift_op_vec RR_shift x (*">>"*)
+let bitwise_rotate x = shift_op_vec LLL_shift x (*"<<<"*)
let shiftl = bitwise_leftshift
@@ -491,6 +491,11 @@ let mod_VIV = arith_op_vec_range_no0 hardware_mod false 1
let duplicate (bit, length) =
vec_to_bvec (Vector (repeat [bit] length) (length - 1) false)
+(* TODO: replace with better native versions *)
+let replicate_bits (v, count) =
+ let v = bvec_to_vec true 0 v in
+ vec_to_bvec (Vector (repeat (get_elems v) count) ((length v * count) - 1) false)
+
let compare_op op (l,r) = (op l r)
let lt = compare_op (<)
@@ -502,37 +507,37 @@ let compare_op_vec op sign (l,r) =
let (l',r') = (to_num sign l, to_num sign r) in
compare_op op (l',r')
-let lt_vec = compare_op_vec (<) true
-let gt_vec = compare_op_vec (>) true
-let lteq_vec = compare_op_vec (<=) true
-let gteq_vec = compare_op_vec (>=) true
+let lt_vec x = compare_op_vec (<) true x
+let gt_vec x = compare_op_vec (>) true x
+let lteq_vec x = compare_op_vec (<=) true x
+let gteq_vec x = compare_op_vec (>=) true x
-let lt_vec_signed = compare_op_vec (<) true
-let gt_vec_signed = compare_op_vec (>) true
-let lteq_vec_signed = compare_op_vec (<=) true
-let gteq_vec_signed = compare_op_vec (>=) true
-let lt_vec_unsigned = compare_op_vec (<) false
-let gt_vec_unsigned = compare_op_vec (>) false
-let lteq_vec_unsigned = compare_op_vec (<=) false
-let gteq_vec_unsigned = compare_op_vec (>=) false
+let lt_vec_signed x = compare_op_vec (<) true x
+let gt_vec_signed x = compare_op_vec (>) true x
+let lteq_vec_signed x = compare_op_vec (<=) true x
+let gteq_vec_signed x = compare_op_vec (>=) true x
+let lt_vec_unsigned x = compare_op_vec (<) false x
+let gt_vec_unsigned x = compare_op_vec (>) false x
+let lteq_vec_unsigned x = compare_op_vec (<=) false x
+let gteq_vec_unsigned x = compare_op_vec (>=) false x
let lt_svec = lt_vec_signed
let compare_op_vec_range op sign (l,r) =
compare_op op ((to_num sign l),r)
-let lt_vec_range = compare_op_vec_range (<) true
-let gt_vec_range = compare_op_vec_range (>) true
-let lteq_vec_range = compare_op_vec_range (<=) true
-let gteq_vec_range = compare_op_vec_range (>=) true
+let lt_vec_range x = compare_op_vec_range (<) true x
+let gt_vec_range x = compare_op_vec_range (>) true x
+let lteq_vec_range x = compare_op_vec_range (<=) true x
+let gteq_vec_range x = compare_op_vec_range (>=) true x
let compare_op_range_vec op sign (l,r) =
compare_op op (l, (to_num sign r))
-let lt_range_vec = compare_op_range_vec (<) true
-let gt_range_vec = compare_op_range_vec (>) true
-let lteq_range_vec = compare_op_range_vec (<=) true
-let gteq_range_vec = compare_op_range_vec (>=) true
+let lt_range_vec x = compare_op_range_vec (<) true x
+let gt_range_vec x = compare_op_range_vec (>) true x
+let lteq_range_vec x = compare_op_range_vec (<=) true x
+let gteq_range_vec x = compare_op_range_vec (>=) true x
val eq : forall 'a. Eq 'a => 'a * 'a -> bool
let eq (l,r) = (l = r)
diff --git a/src/gen_lib/sail_values.lem b/src/gen_lib/sail_values.lem
index 48d728bf..906b35a8 100644
--- a/src/gen_lib/sail_values.lem
+++ b/src/gen_lib/sail_values.lem
@@ -41,7 +41,7 @@ let list_append (l, r) = l ++ r
val repeat : forall 'a. list 'a -> integer -> list 'a
let rec repeat xs n =
- if n = 0 then []
+ if n <= 0 then []
else xs ++ repeat xs (n-1)
let duplicate_to_list (bit, length) = repeat [bit] length
@@ -355,7 +355,6 @@ let vec_to_bvec (Vector elems start is_inc) =
(*** Vector operations *)
-
(* Bytes and addresses *)
val byte_chunks : forall 'a. nat -> list 'a -> list (list 'a)
diff --git a/src/gen_lib/state.lem b/src/gen_lib/state.lem
index dc30a17f..1b03c81e 100644
--- a/src/gen_lib/state.lem
+++ b/src/gen_lib/state.lem
@@ -212,7 +212,7 @@ val barrier : forall 'regs. barrier_kind -> M 'regs unit
let barrier _ = return ()
val footprint : forall 'regs. M 'regs unit
-let footprint = return ()
+let footprint s = return () s
val foreachM_inc : forall 'regs 'vars 'e. (integer * integer * integer) -> 'vars ->
diff --git a/src/lem_interp/pretty_interp.ml b/src/lem_interp/pretty_interp.ml
index a51598b3..9f1ea3e3 100644
--- a/src/lem_interp/pretty_interp.ml
+++ b/src/lem_interp/pretty_interp.ml
@@ -127,36 +127,6 @@ let bitvec_to_string l = "0b" ^ collapse_leading (String.concat "" (List.map (fu
;;
-let rec val_to_string_internal ((Interp.LMem (_,_,memory,_)) as mem) = function
- | Interp_ast.V_boxref(n, t) -> val_to_string_internal mem (Pmap.find n memory)
- | Interp_ast.V_lit (L_aux(l,_)) -> sprintf "%s" (lit_to_string l)
- | Interp_ast.V_tuple l ->
- let repr = String.concat ", " (List.map (val_to_string_internal mem) l) in
- sprintf "(%s)" repr
- | Interp_ast.V_list l ->
- let repr = String.concat "; " (List.map (val_to_string_internal mem) l) in
- sprintf "[||%s||]" repr
- | Interp_ast.V_vector (first_index, inc, l) ->
- let last_index = (if (Interp_ast.IInc = inc) then List.length l - 1 else 1 - List.length l) + first_index in
- let repr =
- try bitvec_to_string l
- with Failure _ ->
- sprintf "[%s]" (String.concat "; " (List.map (val_to_string_internal mem) l)) in
- sprintf "%s [%s..%s]" repr (string_of_int first_index) (string_of_int last_index)
- | (Interp_ast.V_vector_sparse(first_index,last_index,inc,l,default) as v) ->
- val_to_string_internal mem (Interp_lib.fill_in_sparse v)
- | Interp_ast.V_record(_, l) ->
- let pp (id, value) = sprintf "%s = %s" (id_to_string id) (val_to_string_internal mem value) in
- let repr = String.concat "; " (List.map pp l) in
- sprintf "{%s}" repr
- | Interp_ast.V_ctor (id,_,_, value) ->
- sprintf "%s %s" (id_to_string id) (val_to_string_internal mem value)
- | Interp_ast.V_register _ | Interp_ast.V_register_alias _ ->
- sprintf "reg-as-value"
- | Interp_ast.V_unknown -> "unknown"
- | Interp_ast.V_track(v,rs) -> (*"tainted by {" ^ (Interp_utilities.list_to_string Interp.string_of_reg_form "," rs) ^ "} --" ^ *) (val_to_string_internal mem v)
-;;
-
(****************************************************************************
* PPrint-based source-to-source pretty printer
****************************************************************************)
@@ -582,7 +552,7 @@ let doc_exp, doc_let =
(* XXX missing case *)
| E_comment _ | E_comment_struc _ -> string ""
| E_internal_value v ->
- string (val_to_string_internal mem v)
+ string (Interp.string_of_value v)
| _-> failwith "internal expression escaped"
and let_exp env mem add_red show_hole_contents (LB_aux(lb,_)) = match lb with
diff --git a/src/lem_interp/printing_functions.ml b/src/lem_interp/printing_functions.ml
index 79a86113..a19256a2 100644
--- a/src/lem_interp/printing_functions.ml
+++ b/src/lem_interp/printing_functions.ml
@@ -49,7 +49,6 @@ open Interp_interface ;;
open Nat_big_num ;;
-let val_to_string_internal = Pretty_interp.val_to_string_internal ;;
let lit_to_string = Pretty_interp.lit_to_string ;;
let id_to_string = Pretty_interp.id_to_string ;;
let loc_to_string = Pretty_interp.loc_to_string ;;
@@ -451,7 +450,7 @@ let local_variables_to_string (IState(stack,_)) =
String.concat ", " (option_map (fun (id,value)->
match id with
| "0" -> None (*Let's not print out the context hole again*)
- | _ -> Some (id ^ "=" ^ val_to_string_internal mem value)) (Pmap.bindings_list env))
+ | _ -> Some (id ^ "=" ^ Interp.string_of_value value)) (Pmap.bindings_list env))
let instr_parm_to_string (name, typ, value) =
name ^"="^
diff --git a/src/lem_interp/printing_functions.mli b/src/lem_interp/printing_functions.mli
index 85744d61..f1a0cd4a 100644
--- a/src/lem_interp/printing_functions.mli
+++ b/src/lem_interp/printing_functions.mli
@@ -10,7 +10,6 @@ val loc_to_string : l -> string
val get_loc : tannot exp -> string
(*interp_interface.value to string*)
val reg_value_to_string : register_value -> string
-val val_to_string_internal : Interp.lmem -> Interp_ast.value -> string
(*(*Force all representations to hex strings instead of a mixture of hex and binary strings*)
val val_to_hex_string : value0 -> string*)
diff --git a/src/lem_interp/run_interp.ml b/src/lem_interp/run_interp.ml
index 6f5ca07a..f61d9aaf 100644
--- a/src/lem_interp/run_interp.ml
+++ b/src/lem_interp/run_interp.ml
@@ -114,33 +114,6 @@ let rec reg_to_string = function
| SubReg (id,r,_) -> sprintf "%s.%s" (reg_to_string r) (id_to_string id)
;;
-let rec val_to_string_internal = function
- | V_boxref(n, t) -> sprintf "boxref %d" n
- | V_lit (L_aux(l,_)) -> sprintf "%s" (lit_to_string l)
- | V_tuple l ->
- let repr = String.concat ", " (List.map val_to_string_internal l) in
- sprintf "(%s)" repr
- | V_list l ->
- let repr = String.concat "; " (List.map val_to_string_internal l) in
- sprintf "[||%s||]" repr
- | V_vector (first_index, inc, l) ->
- let last_index = add_int_big_int (if inc then List.length l - 1 else 1 - List.length l) first_index in
- let repr =
- try bitvec_to_string (* (if inc then l else List.rev l)*) l
- with Failure _ ->
- sprintf "[%s]" (String.concat "; " (List.map val_to_string_internal l)) in
- sprintf "%s [%s..%s]" repr (string_of_big_int first_index) (string_of_big_int last_index)
- | V_record(_, l) ->
- let pp (id, value) = sprintf "%s = %s" (id_to_string id) (val_to_string_internal value) in
- let repr = String.concat "; " (List.map pp l) in
- sprintf "{%s}" repr
- | V_ctor (id,_, value) ->
- sprintf "%s %s" (id_to_string id) (val_to_string_internal value)
- | V_register r ->
- sprintf "reg-as-value %s" (reg_to_string r)
- | V_unknown -> "unknown"
-;;
-
let rec top_frame_exp_state = function
| Top -> raise (Invalid_argument "top_frame_exp")
| Hole_frame(_, e, _, env, mem, Top)
@@ -210,7 +183,7 @@ let id_compare i1 i2 =
module Reg = struct
include Map.Make(struct type t = id let compare = id_compare end)
let to_string id v =
- sprintf "%s -> %s" (id_to_string id) (val_to_string_internal v)
+ sprintf "%s -> %s" (id_to_string id) (string_of_value v)
let find id m =
(* eprintf "reg_find called with %s\n" (id_to_string id);*)
let v = find id m in
@@ -255,7 +228,7 @@ module Mem = struct
v
*)
let to_string idx v =
- sprintf "[%s] -> %s" (string_of_big_int idx) (val_to_string_internal v)
+ sprintf "[%s] -> %s" (string_of_big_int idx) (string_of_value v)
end ;;
@@ -412,7 +385,7 @@ let run
in
let rec loop mode env = function
| Value v ->
- debugf "%s: %s %s\n" (grey name) (blue "return") (val_to_string_internal v);
+ debugf "%s: %s %s\n" (grey name) (blue "return") (string_of_value v);
true, mode, env
| Action (a, s) ->
let (top_exp,(top_env,top_mem)) = top_frame_exp_state s in
@@ -429,25 +402,25 @@ let run
let left = "<-" and right = "->" in
let (mode',env',s) = begin match a with
| Read_reg (reg, sub) ->
- show "read_reg" (reg_to_string reg ^ sub_to_string sub) right (val_to_string_internal return);
+ show "read_reg" (reg_to_string reg ^ sub_to_string sub) right (string_of_value return);
step (),env',s
| Write_reg (reg, sub, value) ->
assert (return = unit_lit);
- show "write_reg" (reg_to_string reg ^ sub_to_string sub) left (val_to_string_internal value);
+ show "write_reg" (reg_to_string reg ^ sub_to_string sub) left (string_of_value value);
step (),env',s
| Read_mem (id, args, sub) ->
- show "read_mem" (id_to_string id ^ val_to_string_internal args ^ sub_to_string sub) right (val_to_string_internal return);
+ show "read_mem" (id_to_string id ^ string_of_value args ^ sub_to_string sub) right (string_of_value return);
step (),env',s
| Write_mem (id, args, sub, value) ->
assert (return = unit_lit);
- show "write_mem" (id_to_string id ^ val_to_string_internal args ^ sub_to_string sub) left (val_to_string_internal value);
+ show "write_mem" (id_to_string id ^ string_of_value args ^ sub_to_string sub) left (string_of_value value);
step (),env',s
(* distinguish single argument for pretty-printing *)
| Call_extern (f, (V_tuple _ as args)) ->
- show "call_lib" (f ^ val_to_string_internal args) right (val_to_string_internal return);
+ show "call_lib" (f ^ string_of_value args) right (string_of_value return);
step (),env',s
| Call_extern (f, arg) ->
- show "call_lib" (sprintf "%s(%s)" f (val_to_string_internal arg)) right (val_to_string_internal return);
+ show "call_lib" (sprintf "%s(%s)" f (string_of_value arg)) right (string_of_value return);
step (),env',s
| Interp.Step _ ->
assert (return = unit_lit);
diff --git a/src/lem_interp/sail_impl_base.lem b/src/lem_interp/sail_impl_base.lem
index 167e7de9..ba939108 100644
--- a/src/lem_interp/sail_impl_base.lem
+++ b/src/lem_interp/sail_impl_base.lem
@@ -465,6 +465,11 @@ type barrier_kind =
| Barrier_TM_COMMIT
(* MIPS barriers *)
| Barrier_MIPS_SYNC
+ (* RISC-V barriers *)
+ | Barrier_RISCV_rw_rw
+ | Barrier_RISCV_r_rw
+ | Barrier_RISCV_rw_w
+
instance (Show barrier_kind)
let show = function
diff --git a/src/monomorphise.ml b/src/monomorphise.ml
index 62b18042..42546ae0 100644
--- a/src/monomorphise.ml
+++ b/src/monomorphise.ml
@@ -3,7 +3,6 @@ open Ast
open Ast_util
open Type_check
-let disable_const_propagation = false
let size_set_limit = 8
let vector_split_limit = 4
@@ -16,23 +15,22 @@ let env_typ_expected l : tannot -> Env.t * typ = function
| None -> raise (Reporting_basic.err_unreachable l "Missing type environment")
| Some (env,ty,_) -> env,ty
-module KSubst = Map.Make(Kid)
-module ISubst = Map.Make(Id)
-let ksubst_from_list = List.fold_left (fun s (v,i) -> KSubst.add v i s) KSubst.empty
-let isubst_from_list = List.fold_left (fun s (v,i) -> ISubst.add v i s) ISubst.empty
+let kbindings_from_list = List.fold_left (fun s (v,i) -> KBindings.add v i s) KBindings.empty
+let bindings_from_list = List.fold_left (fun s (v,i) -> Bindings.add v i s) Bindings.empty
(* union was introduced in 4.03.0, a bit too recently *)
-let isubst_union s1 s2 =
- ISubst.merge (fun _ x y -> match x,y with
+let bindings_union s1 s2 =
+ Bindings.merge (fun _ x y -> match x,y with
| _, (Some x) -> Some x
| (Some x), _ -> Some x
| _, _ -> None) s1 s2
-let subst_src_typ substs t =
- let rec s_snexp (Nexp_aux (ne,l) as nexp) =
+let subst_nexp substs nexp =
+ let rec s_snexp substs (Nexp_aux (ne,l) as nexp) =
let re ne = Nexp_aux (ne,l) in
+ let s_snexp = s_snexp substs in
match ne with
| Nexp_var (Kid_aux (_,l) as kid) ->
- (try KSubst.find kid substs
+ (try KBindings.find kid substs
with Not_found -> nexp)
| Nexp_id _
| Nexp_constant _ -> nexp
@@ -41,23 +39,56 @@ let subst_src_typ substs t =
| Nexp_minus (n1,n2) -> re (Nexp_minus (s_snexp n1, s_snexp n2))
| Nexp_exp ne -> re (Nexp_exp (s_snexp ne))
| Nexp_neg ne -> re (Nexp_neg (s_snexp ne))
- in
- let rec s_styp ((Typ_aux (t,l)) as ty) =
+ in s_snexp substs nexp
+
+let rec subst_nc substs (NC_aux (nc,l) as n_constraint) =
+ let snexp nexp = subst_nexp substs nexp in
+ let snc nc = subst_nc substs nc in
+ let re nc = NC_aux (nc,l) in
+ match nc with
+ | NC_equal (n1,n2) -> re (NC_equal (snexp n1, snexp n2))
+ | NC_bounded_ge (n1,n2) -> re (NC_bounded_ge (snexp n1, snexp n2))
+ | NC_bounded_le (n1,n2) -> re (NC_bounded_le (snexp n1, snexp n2))
+ | NC_not_equal (n1,n2) -> re (NC_not_equal (snexp n1, snexp n2))
+ | NC_set (kid,is) ->
+ begin
+ match KBindings.find kid substs with
+ | Nexp_aux (Nexp_constant i,_) ->
+ if List.mem i is then re NC_true else re NC_false
+ | nexp ->
+ raise (Reporting_basic.err_general l
+ ("Unable to substitute " ^ string_of_nexp nexp ^
+ " into set constraint " ^ string_of_n_constraint n_constraint))
+ | exception Not_found -> n_constraint
+ end
+ | NC_or (nc1,nc2) -> re (NC_or (snc nc1, snc nc2))
+ | NC_and (nc1,nc2) -> re (NC_and (snc nc1, snc nc2))
+ | NC_true
+ | NC_false
+ -> n_constraint
+
+
+
+let subst_src_typ substs t =
+ let rec s_styp substs ((Typ_aux (t,l)) as ty) =
let re t = Typ_aux (t,l) in
match t with
| Typ_wild
| Typ_id _
| Typ_var _
-> ty
- | Typ_fn (t1,t2,e) -> re (Typ_fn (s_styp t1, s_styp t2,e))
- | Typ_tup ts -> re (Typ_tup (List.map s_styp ts))
- | Typ_app (id,tas) -> re (Typ_app (id,List.map s_starg tas))
- and s_starg (Typ_arg_aux (ta,l) as targ) =
+ | Typ_fn (t1,t2,e) -> re (Typ_fn (s_styp substs t1, s_styp substs t2,e))
+ | Typ_tup ts -> re (Typ_tup (List.map (s_styp substs) ts))
+ | Typ_app (id,tas) -> re (Typ_app (id,List.map (s_starg substs) tas))
+ | Typ_exist (kids,nc,t) ->
+ let substs = List.fold_left (fun sub v -> KBindings.remove v sub) substs kids in
+ re (Typ_exist (kids,nc,s_styp substs t))
+ and s_starg substs (Typ_arg_aux (ta,l) as targ) =
match ta with
- | Typ_arg_nexp ne -> Typ_arg_aux (Typ_arg_nexp (s_snexp ne),l)
- | Typ_arg_typ t -> Typ_arg_aux (Typ_arg_typ (s_styp t),l)
+ | Typ_arg_nexp ne -> Typ_arg_aux (Typ_arg_nexp (subst_nexp substs ne),l)
+ | Typ_arg_typ t -> Typ_arg_aux (Typ_arg_typ (s_styp substs t),l)
| Typ_arg_order _ -> targ
- in s_styp t
+ in s_styp substs t
let make_vector_lit sz i =
let f j = if (i lsr (sz-j-1)) mod 2 = 0 then '0' else '1' in
@@ -128,6 +159,111 @@ let rec cross = function
let t' = cross t in
List.concat (List.map (fun y -> List.map (fun l' -> (x,y)::l') t') l)
+let rec cross' = function
+ | [] -> [[]]
+ | (h::t) ->
+ let t' = cross' t in
+ List.concat (List.map (fun x -> List.map (List.cons x) t') h)
+
+let rec cross'' = function
+ | [] -> [[]]
+ | (k,None)::t -> List.map (List.cons (k,None)) (cross'' t)
+ | (k,Some h)::t ->
+ let t' = cross'' t in
+ List.concat (List.map (fun x -> List.map (List.cons (k,Some x)) t') h)
+
+let kidset_bigunion = function
+ | [] -> KidSet.empty
+ | h::t -> List.fold_left KidSet.union h t
+
+(* TODO: deal with non-set constraints, intersections, etc somehow *)
+let extract_set_nc var (NC_aux (_,l) as nc) =
+ let rec aux (NC_aux (nc,l)) =
+ let re nc = NC_aux (nc,l) in
+ match nc with
+ | NC_set (id,is) when Kid.compare id var = 0 -> Some (is,re NC_true)
+ | NC_and (nc1,nc2) ->
+ (match aux nc1, aux nc2 with
+ | None, None -> None
+ | None, Some (is,nc2') -> Some (is, re (NC_and (nc1,nc2')))
+ | Some (is,nc1'), None -> Some (is, re (NC_and (nc1',nc2)))
+ | Some _, Some _ ->
+ raise (Reporting_basic.err_general l ("Multiple set constraints for " ^ string_of_kid var)))
+ | _ -> None
+ in match aux nc with
+ | Some is -> is
+ | None ->
+ raise (Reporting_basic.err_general l ("No set constraint for " ^ string_of_kid var))
+
+let rec peel = function
+ | [], l -> ([], l)
+ | h1::t1, h2::t2 -> let (l1,l2) = peel (t1, t2) in ((h1,h2)::l1,l2)
+ | _,_ -> assert false
+
+let rec split_insts = function
+ | [] -> [],[]
+ | (k,None)::t -> let l1,l2 = split_insts t in l1,k::l2
+ | (k,Some v)::t -> let l1,l2 = split_insts t in (k,v)::l1,l2
+
+let apply_kid_insts kid_insts t =
+ let kid_insts, kids' = split_insts kid_insts in
+ let kid_insts = List.map (fun (v,i) -> (v,Nexp_aux (Nexp_constant i,Generated Unknown))) kid_insts in
+ let subst = kbindings_from_list kid_insts in
+ kids', subst_src_typ subst t
+
+let rec inst_src_type insts (Typ_aux (ty,l) as typ) =
+ match ty with
+ | Typ_wild
+ | Typ_id _
+ | Typ_var _
+ -> insts,typ
+ | Typ_fn _ ->
+ raise (Reporting_basic.err_general l "Function type in constructor")
+ | Typ_tup ts ->
+ let insts,ts =
+ List.fold_right
+ (fun typ (insts,ts) -> let insts,typ = inst_src_type insts typ in insts,typ::ts)
+ ts (insts,[])
+ in insts, Typ_aux (Typ_tup ts,l)
+ | Typ_app (id,args) ->
+ let insts,ts =
+ List.fold_right
+ (fun arg (insts,args) -> let insts,arg = inst_src_typ_arg insts arg in insts,arg::args)
+ args (insts,[])
+ in insts, Typ_aux (Typ_app (id,ts),l)
+ | Typ_exist (kids, nc, t) ->
+ let kid_insts, insts' = peel (kids,insts) in
+ let kids', t' = apply_kid_insts kid_insts t in
+ (* TODO: subst in nc *)
+ match kids' with
+ | [] -> insts', t'
+ | _ -> insts', Typ_aux (Typ_exist (kids', nc, t'), l)
+and inst_src_typ_arg insts (Typ_arg_aux (ta,l) as tyarg) =
+ match ta with
+ | Typ_arg_nexp _
+ | Typ_arg_order _
+ -> insts, tyarg
+ | Typ_arg_typ typ ->
+ let insts', typ' = inst_src_type insts typ in
+ insts', Typ_arg_aux (Typ_arg_typ typ',l)
+
+let rec contains_exist (Typ_aux (ty,_)) =
+ match ty with
+ | Typ_wild
+ | Typ_id _
+ | Typ_var _
+ -> false
+ | Typ_fn (t1,t2,_) -> contains_exist t1 || contains_exist t2
+ | Typ_tup ts -> List.exists contains_exist ts
+ | Typ_app (_,args) -> List.exists contains_exist_arg args
+ | Typ_exist _ -> true
+and contains_exist_arg (Typ_arg_aux (arg,_)) =
+ match arg with
+ | Typ_arg_nexp _
+ | Typ_arg_order _
+ -> false
+ | Typ_arg_typ typ -> contains_exist typ
+
(* Given a type for a constructor, work out which refinements we ought to produce *)
(* TODO collision avoidance *)
let split_src_type id ty (TypQ_aux (q,ql)) =
@@ -146,65 +282,98 @@ let split_src_type id ty (TypQ_aux (q,ql)) =
| Nexp_neg n
-> size_nvars_nexp n
in
- let rec size_nvars_ty (Typ_aux (ty,l)) =
+ (* This was originally written for the general case, but I cut it down to the
+ more manageable prenex-form below *)
+ let rec size_nvars_ty (Typ_aux (ty,l) as typ) =
match ty with
| Typ_wild
| Typ_id _
| Typ_var _
- -> []
+ -> (KidSet.empty,[[],typ])
| Typ_fn _ ->
raise (Reporting_basic.err_general l ("Function type in constructor " ^ i))
- | Typ_tup ts -> List.concat (List.map size_nvars_ty ts)
+ | Typ_tup ts ->
+ let (vars,tys) = List.split (List.map size_nvars_ty ts) in
+ let insttys = List.map (fun x -> let (insts,tys) = List.split x in
+ List.concat insts, Typ_aux (Typ_tup tys,l)) (cross' tys) in
+ (kidset_bigunion vars, insttys)
| Typ_app (Id_aux (Id "vector",_),
[_;Typ_arg_aux (Typ_arg_nexp sz,_);
_;Typ_arg_aux (Typ_arg_typ (Typ_aux (Typ_id (Id_aux (Id "bit",_)),_)),_)]) ->
- size_nvars_nexp sz
+ (KidSet.of_list (size_nvars_nexp sz), [[],typ])
| Typ_app (_, tas) ->
- [] (* We only support sizes for bitvectors mentioned explicitly, not any buried
- inside another type *)
+ (KidSet.empty,[[],typ]) (* We only support sizes for bitvectors mentioned explicitly, not any buried
+ inside another type *)
+ | Typ_exist (kids, nc, t) ->
+ let (vars,tys) = size_nvars_ty t in
+ let find_insts k (insts,nc) =
+ let inst,nc' =
+ if KidSet.mem k vars then
+ let is,nc' = extract_set_nc k nc in
+ Some is,nc'
+ else None,nc
+ in (k,inst)::insts,nc'
+ in
+ let (insts,nc') = List.fold_right find_insts kids ([],nc) in
+ let insts = cross'' insts in
+ let ty_and_inst (inst0,ty) inst =
+ let kids, ty = apply_kid_insts inst ty in
+ let ty =
+ (* Typ_exist is not allowed an empty list of kids *)
+ match kids with
+ | [] -> ty
+ | _ -> Typ_aux (Typ_exist (kids, nc', ty),l)
+ in inst@inst0, ty
+ in
+ let tys = List.concat (List.map (fun instty -> List.map (ty_and_inst instty) insts) tys) in
+ let free = List.fold_left (fun vars k -> KidSet.remove k vars) vars kids in
+ (free,tys)
+ in
+ (* Only single-variable prenex-form for now *)
+ let size_nvars_ty (Typ_aux (ty,l) as typ) =
+ match ty with
+ | Typ_exist (kids,_,t) ->
+ begin
+ match snd (size_nvars_ty typ) with
+ | [] -> []
+ | tys ->
+ (* One level of tuple type is stripped off by the type checker, so
+ add another here *)
+ let tys =
+ List.map (fun (x,ty) ->
+ x, match ty with
+ | Typ_aux (Typ_tup _,_) -> Typ_aux (Typ_tup [ty],Unknown)
+ | _ -> ty) tys in
+ if contains_exist t then
+ raise (Reporting_basic.err_general l
+ "Only prenex types in unions are supported by monomorphisation")
+ else if List.length kids > 1 then
+ raise (Reporting_basic.err_general l
+ "Only single-variable existential types in unions are currently supported by monomorphisation")
+ else tys
+ end
+ | _ -> []
in
- let nvars = List.sort_uniq Kid.compare (size_nvars_ty ty) in
- match nvars with
+ (* TODO: reject universally quantification or monomorphise it *)
+ let variants = size_nvars_ty ty in
+ match variants with
| [] -> None
| sample::__ ->
- (* Only check for constraints if we found a size to constrain *)
- let qs =
- match q with
- | TypQ_no_forall ->
- raise (Reporting_basic.err_general ql
- ("No set constraint for variable " ^ string_of_kid sample ^ " in constructor " ^ i))
- | TypQ_tq qs -> qs
- in
- let find_set (Kid_aux (Var nvar,_) as kid) =
- match list_extract (function
- | QI_aux (QI_const (NC_aux (NC_set (Kid_aux (Var nvar',_),vals),_)),_)
- -> if nvar = nvar' then Some vals else None
- | _ -> None) qs with
- | None ->
- raise (Reporting_basic.err_general ql
- ("No set constraint for variable " ^ nvar ^ " in constructor " ^ i))
- | Some vals -> (kid,vals)
- in
- let nvar_sets = List.map find_set nvars in
- let total_variants = List.fold_left ( * ) 1 (List.map (fun (_,l) -> List.length l) nvar_sets) in
- let () = if total_variants > size_set_limit then
+ let () = if List.length variants > size_set_limit then
raise (Reporting_basic.err_general ql
- (string_of_int total_variants ^ "variants for constructor " ^ i ^
+ (string_of_int (List.length variants) ^ "variants for constructor " ^ i ^
"bigger than limit " ^ string_of_int size_set_limit)) else ()
in
- let variants = cross nvar_sets in
let wrap = match id with
| Id_aux (Id i,l) -> (fun f -> Id_aux (Id (f i),Generated l))
| Id_aux (DeIid i,l) -> (fun f -> Id_aux (DeIid (f i),l))
in
- let name l i = String.concat "_" (i::(List.map (fun (v,i) -> string_of_kid v ^ string_of_int i) l)) in
- Some (List.map (fun l -> (l, wrap (name l))) variants)
-
-(* TODO: maybe fold this into subst_src_typ? *)
-let inst_src_type insts ty =
- let insts = List.map (fun (v,i) -> (v,Nexp_aux (Nexp_constant i,Generated Unknown))) insts in
- let subst = ksubst_from_list insts in
- subst_src_typ subst ty
+ let name_seg = function
+ | (_,None) -> ""
+ | (k,Some i) -> string_of_kid k ^ string_of_int i
+ in
+ let name l i = String.concat "_" (i::(List.map name_seg l)) in
+ Some (List.map (fun (l,ty) -> (l, wrap (name l),ty)) variants)
let reduce_nexp subst ne =
let rec eval (Nexp_aux (ne,_) as nexp) =
@@ -220,62 +389,71 @@ let reduce_nexp subst ne =
string_of_nexp nexp ^ " into concrete value"))
in eval ne
+
+let typ_of_args args =
+ match args with
+ | [E_aux (_,(l,annot))] ->
+ snd (env_typ_expected l annot)
+ | _ ->
+ let tys = List.map (fun (E_aux (_,(l,annot))) -> snd (env_typ_expected l annot)) args in
+ Typ_aux (Typ_tup tys,Unknown)
+
(* Check to see if we need to monomorphise a use of a constructor. Currently
assumes that bitvector sizes are always given as a variable; don't yet handle
more general cases (e.g., 8 * var) *)
-(* TODO: use type checker's instantiation instead *)
-let refine_constructor refinements id substs (E_aux (_,(l,_)) as arg) t =
- let rec derive_vars (Typ_aux (t,_)) (E_aux (e,(l,tannot))) =
- match t with
- | Typ_app (Id_aux (Id "vector",_), [_;Typ_arg_aux (Typ_arg_nexp (Nexp_aux (Nexp_var v,_)),_);_;Typ_arg_aux (Typ_arg_typ (Typ_aux (Typ_id (Id_aux (Id "bit",_)),_)),_)]) ->
- (match tannot with
- | Some (_,Typ_aux (Typ_app (Id_aux (Id "vector",_), [_;Typ_arg_aux (Typ_arg_nexp ne,_);_;Typ_arg_aux (Typ_arg_typ (Typ_aux (Typ_id (Id_aux (Id "bit",_)),_)),_)]),_),_) ->
- [(v,reduce_nexp substs ne)]
- | _ -> [])
- | Typ_wild
- | Typ_var _
- | Typ_id _
- | Typ_fn _
- | Typ_app _
- -> []
- | Typ_tup ts ->
- match e with
- | E_tuple es -> List.concat (List.map2 derive_vars ts es)
- | _ -> [] (* TODO? *)
- in
- try
- let (_,irefinements) = List.find (fun (id',_) -> Id.compare id id' = 0) refinements in
- let vars = List.sort_uniq (fun x y -> Kid.compare (fst x) (fst y)) (derive_vars t arg) in
- try
- Some (List.assoc vars irefinements)
- with Not_found ->
- (Reporting_basic.print_err false true l "Monomorphisation"
- ("Failed to find a monomorphic constructor for " ^ string_of_id id ^ " instance " ^
- match vars with [] -> "<empty>"
- | _ -> String.concat "," (List.map (fun (x,y) -> string_of_kid x ^ "=" ^ string_of_int y) vars)); None)
- with Not_found -> None
+let refine_constructor refinements l env id args =
+ match List.find (fun (id',_) -> Id.compare id id' = 0) refinements with
+ | (_,irefinements) -> begin
+ let (_,constr_ty) = Env.get_val_spec id env in
+ match constr_ty with
+ | Typ_aux (Typ_fn (constr_ty,_,_),_) -> begin
+ let arg_ty = typ_of_args args in
+ match Type_check.destruct_exist env constr_ty with
+ | None -> None
+ | Some (kids,nc,constr_ty) ->
+ let (bindings,_,_) = Type_check.unify l env constr_ty arg_ty in
+ let find_kid kid = try Some (KBindings.find kid bindings) with Not_found -> None in
+ let bindings = List.map find_kid kids in
+ let matches_refinement (mapping,_,_) =
+ List.for_all2
+ (fun v (_,w) ->
+ match v,w with
+ | _,None -> true
+ | Some (U_nexp (Nexp_aux (Nexp_constant n, _))),Some m -> n = m
+ | _,_ -> false) bindings mapping
+ in
+ match List.find matches_refinement irefinements with
+ | (_,new_id,_) -> Some (E_app (new_id,args))
+ | exception Not_found ->
+ (Reporting_basic.print_err false true l "Monomorphisation"
+ ("Unable to refine constructor " ^ string_of_id id);
+ None)
+ end
+ | _ -> None
+ end
+ | exception Not_found -> None
(* Substitute found nexps for variables in an expression, and rename constructors to reflect
specialisation *)
-let nexp_subst_fns substs refinements =
-(*
- let s_t t = typ_subst substs true t in
+(* TODO: kid shadowing *)
+let nexp_subst_fns substs =
+
+ let s_t t = subst_src_typ substs t in
(* let s_typschm (TypSchm_aux (TypSchm_ts (q,t),l)) = TypSchm_aux (TypSchm_ts (q,s_t t),l) in
- hopefully don't need this anyway *)
- let s_typschm tsh = tsh in
+ hopefully don't need this anyway *)(*
+ let s_typschm tsh = tsh in*)
let s_tannot = function
- | Base ((params,t),tag,ranges,effl,effc,bounds) ->
- (* TODO: do other fields need mapped? *)
- Base ((params,s_t t),tag,ranges,effl,effc,bounds)
- | tannot -> tannot
+ | None -> None
+ | Some (env,t,eff) -> Some (env,s_t t,eff) (* TODO: what about env? *)
in
- let rec s_pat (P_aux (p,(l,annot))) =
- let re p = P_aux (p,(l,s_tannot annot)) in
+(* let rec s_pat (P_aux (p,(l,annot))) =
+ let re p = P_aux (p,(l,(*s_tannot*) annot)) in
match p with
| P_lit _ | P_wild | P_id _ -> re p
+ | P_var kid -> re p
| P_as (p',id) -> re (P_as (s_pat p', id))
| P_typ (ty,p') -> re (P_typ (ty,s_pat p'))
| P_app (id,ps) -> re (P_app (id, List.map s_pat ps))
@@ -285,11 +463,12 @@ let nexp_subst_fns substs refinements =
| P_vector_concat ps -> re (P_vector_concat (List.map s_pat ps))
| P_tup ps -> re (P_tup (List.map s_pat ps))
| P_list ps -> re (P_list (List.map s_pat ps))
+ | P_cons (p1,p2) -> re (P_cons (s_pat p1, s_pat p2))
and s_fpat (FP_aux (FP_Fpat (id, p), (l,annot))) =
FP_aux (FP_Fpat (id, s_pat p), (l,s_tannot annot))
in*)
let rec s_exp (E_aux (e,(l,annot))) =
- let re e = E_aux (e,(l,(*s_tannot*) annot)) in
+ let re e = E_aux (e,(l,s_tannot annot)) in
match e with
| E_block es -> re (E_block (List.map s_exp es))
| E_nondet es -> re (E_nondet (List.map s_exp es))
@@ -297,35 +476,18 @@ let nexp_subst_fns substs refinements =
| E_lit _
| E_comment _ -> re e
| E_sizeof ne -> re (E_sizeof ne) (* TODO: does this need done? does it appear in type checked code? *)
- | E_constraint _ -> re e (* TODO: actual substitution if necessary *)
- | E_internal_exp (l,annot) -> re (E_internal_exp (l, (*s_tannot*) annot))
- | E_sizeof_internal (l,annot) -> re (E_sizeof_internal (l, (*s_tannot*) annot))
+ | E_constraint nc -> re (E_constraint (subst_nc substs nc))
+ | E_internal_exp (l,annot) -> re (E_internal_exp (l, s_tannot annot))
+ | E_sizeof_internal (l,annot) -> re (E_sizeof_internal (l, s_tannot annot))
| E_internal_exp_user ((l1,annot1),(l2,annot2)) ->
- re (E_internal_exp_user ((l1, (*s_tannot*) annot1),(l2, (*s_tannot*) annot2)))
+ re (E_internal_exp_user ((l1, s_tannot annot1),(l2, s_tannot annot2)))
| E_cast (t,e') -> re (E_cast (t, s_exp e'))
- | E_app (id,es) ->
- let es' = List.map s_exp es in
- let arg =
- match es' with
- | [] -> E_aux (E_lit (L_aux (L_unit,Unknown)),(l,None))
- | [e] -> e
- | _ -> E_aux (E_tuple es',(l,None))
- in
- let id' =
- let env,_ = env_typ_expected l annot in
- if Env.is_union_constructor id env then
- let (qs,ty) = Env.get_val_spec id env in
- match ty with (Typ_aux (Typ_fn(inty,outty,_),_)) ->
- (match refine_constructor refinements id substs arg inty with
- | None -> id
- | Some id' -> id')
- | _ -> id
- else id
- in re (E_app (id',es'))
+ | E_app (id,es) -> re (E_app (id, List.map s_exp es))
| E_app_infix (e1,id,e2) -> re (E_app_infix (s_exp e1,id,s_exp e2))
| E_tuple es -> re (E_tuple (List.map s_exp es))
| E_if (e1,e2,e3) -> re (E_if (s_exp e1, s_exp e2, s_exp e3))
| E_for (id,e1,e2,e3,ord,e4) -> re (E_for (id,s_exp e1,s_exp e2,s_exp e3,ord,s_exp e4))
+ | E_loop (loop,e1,e2) -> re (E_loop (loop,s_exp e1,s_exp e2))
| E_vector es -> re (E_vector (List.map s_exp es))
| E_vector_access (e1,e2) -> re (E_vector_access (s_exp e1,s_exp e2))
| E_vector_subrange (e1,e2,e3) -> re (E_vector_subrange (s_exp e1,s_exp e2,s_exp e3))
@@ -343,41 +505,42 @@ let nexp_subst_fns substs refinements =
| E_exit e -> re (E_exit (s_exp e))
| E_return e -> re (E_return (s_exp e))
| E_assert (e1,e2) -> re (E_assert (s_exp e1,s_exp e2))
- | E_internal_cast ((l,ann),e) -> re (E_internal_cast ((l,(*s_tannot*) ann),s_exp e))
+ | E_internal_cast ((l,ann),e) -> re (E_internal_cast ((l,s_tannot ann),s_exp e))
| E_comment_struc e -> re (E_comment_struc e)
| E_internal_let (le,e1,e2) -> re (E_internal_let (s_lexp le, s_exp e1, s_exp e2))
| E_internal_plet (p,e1,e2) -> re (E_internal_plet ((*s_pat*) p, s_exp e1, s_exp e2))
| E_internal_return e -> re (E_internal_return (s_exp e))
+ | E_throw e -> re (E_throw (s_exp e))
+ | E_try (e,cases) -> re (E_try (s_exp e, List.map s_pexp cases))
and s_opt_default (Def_val_aux (ed,(l,annot))) =
match ed with
- | Def_val_empty -> Def_val_aux (Def_val_empty,(l,(*s_tannot*) annot))
- | Def_val_dec e -> Def_val_aux (Def_val_dec (s_exp e),(l,(*s_tannot*) annot))
+ | Def_val_empty -> Def_val_aux (Def_val_empty,(l,s_tannot annot))
+ | Def_val_dec e -> Def_val_aux (Def_val_dec (s_exp e),(l,s_tannot annot))
and s_fexps (FES_aux (FES_Fexps (fes,flag), (l,annot))) =
- FES_aux (FES_Fexps (List.map s_fexp fes, flag), (l,(*s_tannot*) annot))
+ FES_aux (FES_Fexps (List.map s_fexp fes, flag), (l,s_tannot annot))
and s_fexp (FE_aux (FE_Fexp (id,e), (l,annot))) =
- FE_aux (FE_Fexp (id,s_exp e),(l,(*s_tannot*) annot))
+ FE_aux (FE_Fexp (id,s_exp e),(l,s_tannot annot))
and s_pexp = function
| (Pat_aux (Pat_exp (p,e),(l,annot))) ->
- Pat_aux (Pat_exp ((*s_pat*) p, s_exp e),(l,(*s_tannot*) annot))
+ Pat_aux (Pat_exp ((*s_pat*) p, s_exp e),(l,s_tannot annot))
| (Pat_aux (Pat_when (p,e1,e2),(l,annot))) ->
- Pat_aux (Pat_when ((*s_pat*) p, s_exp e1, s_exp e2),(l,(*s_tannot*) annot))
+ Pat_aux (Pat_when ((*s_pat*) p, s_exp e1, s_exp e2),(l,s_tannot annot))
and s_letbind (LB_aux (lb,(l,annot))) =
match lb with
- | LB_val (p,e) -> LB_aux (LB_val ((*s_pat*) p,s_exp e), (l,(*s_tannot*) annot))
+ | LB_val (p,e) -> LB_aux (LB_val ((*s_pat*) p,s_exp e), (l,s_tannot annot))
and s_lexp (LEXP_aux (e,(l,annot))) =
- let re e = LEXP_aux (e,(l,(*s_tannot*) annot)) in
+ let re e = LEXP_aux (e,(l,s_tannot annot)) in
match e with
- | LEXP_id _
- | LEXP_cast _
- -> re e
+ | LEXP_id _ -> re e
+ | LEXP_cast (typ,id) -> re (LEXP_cast (s_t typ, id))
| LEXP_memory (id,es) -> re (LEXP_memory (id,List.map s_exp es))
| LEXP_tup les -> re (LEXP_tup (List.map s_lexp les))
| LEXP_vector (le,e) -> re (LEXP_vector (s_lexp le, s_exp e))
| LEXP_vector_range (le,e1,e2) -> re (LEXP_vector_range (s_lexp le, s_exp e1, s_exp e2))
| LEXP_field (le,id) -> re (LEXP_field (s_lexp le, id))
in ((fun x -> x (*s_pat*)),s_exp)
-let nexp_subst_pat substs refinements = fst (nexp_subst_fns substs refinements)
-let nexp_subst_exp substs refinements = snd (nexp_subst_fns substs refinements)
+let nexp_subst_pat substs = fst (nexp_subst_fns substs)
+let nexp_subst_exp substs = snd (nexp_subst_fns substs)
let bindings_from_pat p =
let rec aux_pat (P_aux (p,(l,annot))) =
@@ -390,6 +553,7 @@ let bindings_from_pat p =
| P_typ (_,p) -> aux_pat p
| P_id id ->
if pat_id_is_variable env id then [id] else []
+ | P_var (p,kid) -> aux_pat p
| P_vector ps
| P_vector_concat ps
| P_app (_,ps)
@@ -403,86 +567,7 @@ let bindings_from_pat p =
let remove_bound env pat =
let bound = bindings_from_pat pat in
- List.fold_left (fun sub v -> ISubst.remove v env) env bound
-
-(* Remove explicit existential types from the AST, so that the sizes of
- bitvectors will be filled in throughout.
-
- Problems: there might be other existential types that we want to keep (e.g.
- because they describe conditions needed for a vector index to be in range),
- and inference might not be able to find a sufficiently precise type. *)
-let rec deexist_exp (E_aux (e,(l,(annot : Type_check.tannot))) as exp) =
- let re e = E_aux (e,(l,annot)) in
- match e with
- | E_block es -> re (E_block (List.map deexist_exp es))
- | E_nondet es -> re (E_nondet (List.map deexist_exp es))
- | E_id _
- | E_lit _
- | E_sizeof _
- | E_constraint _
- -> (*Type_check.strip_exp*) exp
- | E_cast (Typ_aux (Typ_exist (kids, nc, ty),l),(E_aux (_,(l',annot')) as e)) ->
-(* let env,_ = env_typ_expected l' annot' in
- let plain_e = deexist_exp e in
- let E_aux (_,(_,annot'')) = Type_check.infer_exp env plain_e in
-*)
- deexist_exp e
- | E_cast (ty,e) -> re (E_cast (ty,deexist_exp e))
- | E_app (id,args) -> re (E_app (id,List.map deexist_exp args))
- | E_app_infix (e1,id,e2) -> re (E_app_infix (deexist_exp e1,id,deexist_exp e2))
- | E_tuple es -> re (E_tuple (List.map deexist_exp es))
- | E_if (e1,e2,e3) -> re (E_if (deexist_exp e1,deexist_exp e2,deexist_exp e3))
- | E_for (id,e1,e2,e3,ord,e4) ->
- re (E_for (id,deexist_exp e1,deexist_exp e2,deexist_exp e3,ord,deexist_exp e4))
- | E_vector es -> re (E_vector (List.map deexist_exp es))
- | E_vector_access (e1,e2) -> re (E_vector_access (deexist_exp e1,deexist_exp e2))
- | E_vector_subrange (e1,e2,e3) -> re (E_vector_subrange (deexist_exp e1,deexist_exp e2,deexist_exp e3))
- | E_vector_update (e1,e2,e3) -> re (E_vector_update (deexist_exp e1,deexist_exp e2,deexist_exp e3))
- | E_vector_update_subrange (e1,e2,e3,e4) ->
- re (E_vector_update_subrange (deexist_exp e1,deexist_exp e2,deexist_exp e3,deexist_exp e4))
- | E_vector_append (e1,e2) -> re (E_vector_append (deexist_exp e1,deexist_exp e2))
- | E_list es -> re (E_list (List.map deexist_exp es))
- | E_cons (e1,e2) -> re (E_cons (deexist_exp e1,deexist_exp e2))
- | E_record _ -> (*Type_check.strip_exp*) exp (* TODO *)
- | E_record_update _ -> (*Type_check.strip_exp*) exp (* TODO *)
- | E_field (e1,fld) -> re (E_field (deexist_exp e1,fld))
- | E_case (e1,cases) -> re (E_case (deexist_exp e1, List.map deexist_pexp cases))
- | E_let (lb,e1) -> re (E_let (deexist_letbind lb, deexist_exp e1))
- | E_assign (le,e1) -> re (E_assign (deexist_lexp le, deexist_exp e1))
- | E_exit e1 -> re (E_exit (deexist_exp e1))
- | E_return e1 -> re (E_return (deexist_exp e1))
- | E_assert (e1,e2) -> re (E_assert (deexist_exp e1,deexist_exp e2))
-and deexist_pexp (Pat_aux (pe,(l,annot))) =
- match pe with
- | Pat_exp (p,e) -> Pat_aux (Pat_exp ((*Type_check.strip_pat*) p,deexist_exp e),(l,annot))
- | Pat_when (p,e1,e2) -> Pat_aux (Pat_when ((*Type_check.strip_pat*) p,deexist_exp e1,deexist_exp e2),(l,annot))
-and deexist_letbind (LB_aux (lb,(l,annot))) =
- match lb with (* TODO, drop tysc if there's an exist? Do they even appear here? *)
- | LB_val (p,e) -> LB_aux (LB_val ((*Type_check.strip_pat*) p,deexist_exp e),(l,annot))
-and deexist_lexp (LEXP_aux (le,(l,annot))) =
- let re le = LEXP_aux (le,(l,annot)) in
- match le with
- | LEXP_id id -> re (LEXP_id id)
- | LEXP_memory (id,es) -> re (LEXP_memory (id,List.map deexist_exp es))
- | LEXP_cast (Typ_aux (Typ_exist _,_),id) -> re (LEXP_id id)
- | LEXP_cast (ty,id) -> re (LEXP_cast (ty,id))
- | LEXP_tup lexps -> re (LEXP_tup (List.map deexist_lexp lexps))
- | LEXP_vector (le,e) -> re (LEXP_vector (deexist_lexp le, deexist_exp e))
- | LEXP_vector_range (le,e1,e2) -> re (LEXP_vector_range (deexist_lexp le, deexist_exp e1, deexist_exp e2))
- | LEXP_field (le,id) -> re (LEXP_field (deexist_lexp le, id))
-
-let deexist_funcl (FCL_aux (FCL_Funcl (id,p,e),(l,annot))) =
- FCL_aux (FCL_Funcl (id, (*Type_check.strip_pat*) p, deexist_exp e),(l,annot))
-
-let deexist_def = function
- | DEF_kind kd -> DEF_kind kd
- | DEF_type td -> DEF_type td
- | DEF_fundef (FD_aux (FD_function (recopt,topt,effopt,fcls),(l,annot))) ->
- DEF_fundef (FD_aux (FD_function (recopt,topt,effopt,List.map deexist_funcl fcls),(l,annot)))
- | x -> x
-
-let deexist (Defs defs) = Defs (List.map deexist_def defs)
-
+ List.fold_left (fun sub v -> Bindings.remove v sub) env bound
(* Attempt simple pattern matches *)
let lit_match = function
@@ -495,22 +580,21 @@ type 'a matchresult =
| DoesNotMatch
| GiveUp
-let can_match (E_aux (e,(l,annot)) as exp0) cases =
- let (env,_) = env_typ_expected l annot in
+let can_match_with_env env (E_aux (e,(l,annot)) as exp0) cases =
let rec findpat_generic check_pat description = function
| [] -> (Reporting_basic.print_err false true l "Monomorphisation"
("Failed to find a case for " ^ description); None)
- | [Pat_aux (Pat_exp (P_aux (P_wild,_),exp),_)] -> Some (exp,[])
+ | [Pat_aux (Pat_exp (P_aux (P_wild,_),exp),_)] -> Some (exp,[],[])
| (Pat_aux (Pat_exp (P_aux (P_typ (_,p),_),exp),ann))::tl ->
findpat_generic check_pat description ((Pat_aux (Pat_exp (p,exp),ann))::tl)
| (Pat_aux (Pat_exp (P_aux (P_id id',_),exp),_))::tl
when pat_id_is_variable env id' ->
- Some (exp, [(id', exp0)])
+ Some (exp, [(id', exp0)], [])
| (Pat_aux (Pat_when _,_))::_ -> None
| (Pat_aux (Pat_exp (p,exp),_))::tl ->
match check_pat p with
| DoesNotMatch -> findpat_generic check_pat description tl
- | DoesMatch subst -> Some (exp,subst)
+ | DoesMatch (subst,ksubst) -> Some (exp,subst,ksubst)
| GiveUp -> None
in
match e with
@@ -520,29 +604,56 @@ let can_match (E_aux (e,(l,annot)) as exp0) cases =
let checkpat = function
| P_aux (P_id id',_)
| P_aux (P_app (id',[]),_) ->
- if Id.compare id id' = 0 then DoesMatch [] else DoesNotMatch
+ if Id.compare id id' = 0 then DoesMatch ([],[]) else DoesNotMatch
| P_aux (_,(l',_)) ->
(Reporting_basic.print_err false true l' "Monomorphisation"
"Unexpected kind of pattern for enumeration"; GiveUp)
in findpat_generic checkpat (string_of_id id) cases
| _ -> None)
- | E_lit (L_aux (lit_e, _)) ->
+ | E_lit (L_aux (lit_e, lit_l)) ->
let checkpat = function
| P_aux (P_lit (L_aux (lit_p, _)),_) ->
- if lit_match (lit_e,lit_p) then DoesMatch [] else DoesNotMatch
+ if lit_match (lit_e,lit_p) then DoesMatch ([],[]) else DoesNotMatch
+ | P_aux (P_var (P_aux (P_id id,_), kid),_) ->
+ begin
+ match lit_e with
+ | L_num i ->
+ DoesMatch ([id, E_aux (e,(l,annot))],
+ [kid,Nexp_aux (Nexp_constant i,Unknown)])
+ | _ ->
+ (Reporting_basic.print_err false true lit_l "Monomorphisation"
+ "Unexpected kind of literal for var match"; GiveUp)
+ end
| P_aux (_,(l',_)) ->
(Reporting_basic.print_err false true l' "Monomorphisation"
- "Unexpected kind of pattern for bit"; GiveUp)
- in findpat_generic checkpat "bit" cases
+ "Unexpected kind of pattern for literal"; GiveUp)
+ in findpat_generic checkpat "literal" cases
| _ -> None
+let can_match (E_aux (_,(l,annot)) as exp0) cases =
+ let (env,_) = env_typ_expected l annot in
+ can_match_with_env env exp0 cases
+
+(* Remove top-level casts from an expression. Useful when we need to look at
+ subexpressions to reduce something, but could break type-checking if we used
+ it everywhere. *)
+let rec drop_casts = function
+ | E_aux (E_cast (_,e),_) -> drop_casts e
+ | exp -> exp
+
+(* TODO: ought to be a big int of some form, but would need L_num to be one *)
+let int_of_lit = function
+ | L_hex hex -> int_of_string ("0x" ^ hex)
+ | L_bin bin -> int_of_string ("0b" ^ bin)
+ | _ -> assert false
-(* Similarly, simple conditionals *)
let lit_eq (L_aux (l1,_)) (L_aux (l2,_)) =
match l1,l2 with
| (L_zero|L_false), (L_zero|L_false)
| (L_one |L_true ), (L_one |L_true)
-> Some true
+ | (L_hex _| L_bin _), (L_hex _|L_bin _)
+ -> Some (int_of_lit l1 = int_of_lit l2)
| L_undef, _ | _, L_undef -> None
| _ -> Some (l1 = l2)
@@ -554,8 +665,8 @@ let neq_fns = [Id "neq_anything"]
let try_app (l,ann) (Id_aux (id,_),args) =
let is_eq = List.mem id eq_fns in
let is_neq = (not is_eq) && List.mem id neq_fns in
+ let new_l = Generated l in
if is_eq || is_neq then
- let new_l = Generated l in
match args with
| [E_aux (E_lit l1,_); E_aux (E_lit l2,_)] ->
let lit b = if b then L_true else L_false in
@@ -564,6 +675,34 @@ let try_app (l,ann) (Id_aux (id,_),args) =
| None -> None
| Some b -> Some (E_aux (E_lit (L_aux (lit b,new_l)),(l,ann))))
| _ -> None
+ else if id = Id "cast_bit_bool" then
+ match args with
+ | [E_aux (E_lit L_aux (L_zero,_),_)] -> Some (E_aux (E_lit (L_aux (L_false,new_l)),(l,ann)))
+ | [E_aux (E_lit L_aux (L_one ,_),_)] -> Some (E_aux (E_lit (L_aux (L_true ,new_l)),(l,ann)))
+ | _ -> None
+ else if id = Id "UInt" then
+ match args with
+ | [E_aux (E_lit L_aux ((L_hex _| L_bin _) as lit,_), _)] ->
+ Some (E_aux (E_lit (L_aux (L_num (int_of_lit lit),new_l)),(l,ann)))
+ | _ -> None
+ else if id = Id "shl_int" then
+ match args with
+ | [E_aux (E_lit L_aux (L_num i,_),_); E_aux (E_lit L_aux (L_num j,_),_)] ->
+ Some (E_aux (E_lit (L_aux (L_num (i lsl j),new_l)),(l,ann)))
+ | _ -> None
+ else if id = Id "ex_int" then
+ match args with
+ | [E_aux (E_lit (L_aux (L_num _,_)),_) as exp] -> Some exp
+ | _ -> None
+ else if id = Id "vector_access" || id = Id "bitvector_access" then
+ match args with
+ | [E_aux (E_lit L_aux ((L_hex _ | L_bin _) as lit,_),_);
+ E_aux (E_lit L_aux (L_num i,_),_)] ->
+ let v = int_of_lit lit in
+ let b = (v lsr i) land 1 in
+ let lit' = if b = 1 then L_one else L_zero in
+ Some (E_aux (E_lit (L_aux (lit',new_l)),(l,ann)))
+ | _ -> None
else None
@@ -579,14 +718,62 @@ let try_app_infix (l,ann) (E_aux (e1,ann1)) (Id_aux (id,_)) (E_aux (e2,ann2)) =
| None -> None)
| _ -> None
+let construct_lit_vector args =
+ let rec aux l = function
+ | [] -> Some (L_aux (L_bin (String.concat "" (List.rev l)),Unknown))
+ | E_aux (E_lit (L_aux ((L_zero | L_one) as lit,_)),_)::t ->
+ aux ((if lit = L_zero then "0" else "1")::l) t
+ | _ -> None
+ in aux [] args
(* We may need to split up a pattern match if (1) we've been told to case split
on a variable by the user, or (2) we monomorphised a constructor that's used
in the pattern. *)
type split =
| NoSplit
- | VarSplit of (tannot pat * (id * tannot Ast.exp)) list
- | ConstrSplit of (tannot pat * nexp KSubst.t) list
+ | VarSplit of (tannot pat * (id * tannot Ast.exp) list) list
+ | ConstrSplit of (tannot pat * nexp KBindings.t) list
+
+let threaded_map f state l =
+ let l',state' =
+ List.fold_left (fun (tl,state) element -> let (el',state') = f state element in (el'::tl,state'))
+ ([],state) l
+ in List.rev l',state'
+
+let isubst_minus subst subst' =
+ Bindings.merge (fun _ x y -> match x,y with (Some a), None -> Some a | _, _ -> None) subst subst'
+
+let isubst_minus_set subst set =
+ IdSet.fold Bindings.remove set subst
+
+let assigned_vars exp =
+ fst (Rewriter.fold_exp
+ { (Rewriter.compute_exp_alg IdSet.empty IdSet.union) with
+ Rewriter.lEXP_id = (fun id -> IdSet.singleton id, LEXP_id id);
+ Rewriter.lEXP_cast = (fun (ty,id) -> IdSet.singleton id, LEXP_cast (ty,id)) }
+ exp)
+
+let assigned_vars_in_fexps (FES_aux (FES_Fexps (fes,_), _)) =
+ List.fold_left
+ (fun vs (FE_aux (FE_Fexp (_,e),_)) -> IdSet.union vs (assigned_vars e))
+ IdSet.empty
+ fes
+
+let assigned_vars_in_pexp (Pat_aux (p,_)) =
+ match p with
+ | Pat_exp (_,e) -> assigned_vars e
+ | Pat_when (p,e1,e2) -> IdSet.union (assigned_vars e1) (assigned_vars e2)
+
+let rec assigned_vars_in_lexp (LEXP_aux (le,_)) =
+ match le with
+ | LEXP_id id
+ | LEXP_cast (_,id) -> IdSet.singleton id
+ | LEXP_tup lexps -> List.fold_left (fun vs le -> IdSet.union vs (assigned_vars_in_lexp le)) IdSet.empty lexps
+ | LEXP_memory (_,es) -> List.fold_left (fun vs e -> IdSet.union vs (assigned_vars e)) IdSet.empty es
+ | LEXP_vector (le,e) -> IdSet.union (assigned_vars_in_lexp le) (assigned_vars e)
+ | LEXP_vector_range (le,e1,e2) ->
+ IdSet.union (assigned_vars_in_lexp le) (IdSet.union (assigned_vars e1) (assigned_vars e2))
+ | LEXP_field (le,_) -> assigned_vars_in_lexp le
let split_defs splits defs =
let split_constructors (Defs defs) =
@@ -598,7 +785,7 @@ let split_defs splits defs =
| None -> ([],[Tu_aux (Tu_ty_id (ty,id),l)])
| Some variants ->
([(id,variants)],
- List.map (fun (insts, id') -> Tu_aux (Tu_ty_id (inst_src_type insts ty,id'),Generated l)) variants))
+ List.map (fun (insts, id', ty) -> Tu_aux (Tu_ty_id (ty,id'),Generated l)) variants))
in
let sc_type_def ((TD_aux (tda,annot)) as td) =
match tda with
@@ -618,60 +805,68 @@ let split_defs splits defs =
let (refinements, defs') = split_constructors defs in
- (* Extract nvar substitution by comparing two types *)
- let build_nexp_subst l t1 t2 = [] (*
- let rec from_types t1 t2 =
- let t1 = match t1.t with Tabbrev(_,t) -> t | _ -> t1 in
- let t2 = match t2.t with Tabbrev(_,t) -> t | _ -> t2 in
- if t1 = t2 then [] else
- match t1.t,t2.t with
- | Tapp (s1,args1), Tapp (s2,args2) ->
- if s1 = s2 then
- List.concat (List.map2 from_args args1 args2)
- else (Reporting_basic.print_err false true l "Monomorphisation"
- "Unexpected type mismatch"; [])
- | Ttup ts1, Ttup ts2 ->
- if List.length ts1 = List.length ts2 then
- List.concat (List.map2 from_types ts1 ts2)
- else (Reporting_basic.print_err false true l "Monomorphisation"
- "Unexpected type mismatch"; [])
- | _ -> []
- and from_args arg1 arg2 =
- match arg1,arg2 with
- | TA_typ t1, TA_typ t2 -> from_types t1 t2
- | TA_nexp n1, TA_nexp n2 -> from_nexps n1 n2
- | _ -> []
- and from_nexps n1 n2 =
- match n1.nexp, n2.nexp with
- | Nvar s, Nvar s' when s = s' -> []
- | Nvar s, _ -> [(s,n2)]
- | Nadd (n3,n4), Nadd (n5,n6)
- | Nsub (n3,n4), Nsub (n5,n6)
- | Nmult (n3,n4), Nmult (n5,n6)
- -> from_nexps n3 n5 @ from_nexps n4 n6
- | N2n (n3,p1), N2n (n4,p2) when p1 = p2 -> from_nexps n3 n4
- | Npow (n3,p1), Npow (n4,p2) when p1 = p2 -> from_nexps n3 n4
- | Nneg n3, Nneg n4 -> from_nexps n3 n4
- | _ -> []
- in match t1,t2 with
- | Base ((_,t1),_,_,_,_,_),Base ((_,t2),_,_,_,_,_) -> from_types t1 t2
- | _ -> []*)
- in
-
- let nexp_substs = ref [] in
-
- (* Constant propogation *)
- let rec const_prop_exp substs ((E_aux (e,(l,annot))) as exp) =
- let re e = E_aux (e,(l,annot)) in
+ (* Constant propogation.
+ Takes maps of immutable/mutable variables to subsitute.
+ Extremely conservative about evaluation order of assignments in
+ subexpressions, dropping assignments rather than committing to
+ any particular order *)
+ let rec const_prop_exp substs assigns ((E_aux (e,(l,annot))) as exp) =
+ (* Functions to treat lists and tuples of subexpressions as possibly
+ non-deterministic: that is, we stop making any assumptions about
+ variables that are assigned to in any of the subexpressions *)
+ let non_det_exp_list es =
+ let assigned_in =
+ List.fold_left (fun vs exp -> IdSet.union vs (assigned_vars exp))
+ IdSet.empty es in
+ let assigns = isubst_minus_set assigns assigned_in in
+ let es' = List.map (fun e -> fst (const_prop_exp substs assigns e)) es in
+ es',assigns
+ in
+ let non_det_exp_2 e1 e2 =
+ let assigned_in_e12 = IdSet.union (assigned_vars e1) (assigned_vars e2) in
+ let assigns = isubst_minus_set assigns assigned_in_e12 in
+ let e1',_ = const_prop_exp substs assigns e1 in
+ let e2',_ = const_prop_exp substs assigns e2 in
+ e1',e2',assigns
+ in
+ let non_det_exp_3 e1 e2 e3 =
+ let assigned_in_e12 = IdSet.union (assigned_vars e1) (assigned_vars e2) in
+ let assigned_in_e123 = IdSet.union assigned_in_e12 (assigned_vars e3) in
+ let assigns = isubst_minus_set assigns assigned_in_e123 in
+ let e1',_ = const_prop_exp substs assigns e1 in
+ let e2',_ = const_prop_exp substs assigns e2 in
+ let e3',_ = const_prop_exp substs assigns e3 in
+ e1',e2',e3',assigns
+ in
+ let non_det_exp_4 e1 e2 e3 e4 =
+ let assigned_in_e12 = IdSet.union (assigned_vars e1) (assigned_vars e2) in
+ let assigned_in_e123 = IdSet.union assigned_in_e12 (assigned_vars e3) in
+ let assigned_in_e1234 = IdSet.union assigned_in_e123 (assigned_vars e4) in
+ let assigns = isubst_minus_set assigns assigned_in_e1234 in
+ let e1',_ = const_prop_exp substs assigns e1 in
+ let e2',_ = const_prop_exp substs assigns e2 in
+ let e3',_ = const_prop_exp substs assigns e3 in
+ let e4',_ = const_prop_exp substs assigns e4 in
+ e1',e2',e3',e4',assigns
+ in
+ let re e assigns = E_aux (e,(l,annot)),assigns in
match e with
(* TODO: are there more circumstances in which we should get rid of these? *)
- | E_block [e] -> const_prop_exp substs e
- | E_block es -> re (E_block (List.map (const_prop_exp substs) es))
- | E_nondet es -> re (E_nondet (List.map (const_prop_exp substs) es))
-
+ | E_block [e] -> const_prop_exp substs assigns e
+ | E_block es ->
+ let es',assigns = threaded_map (const_prop_exp substs) assigns es in
+ re (E_block es') assigns
+ | E_nondet es ->
+ let es',assigns = non_det_exp_list es in
+ re (E_nondet es') assigns
| E_id id ->
- (try ISubst.find id substs
- with Not_found -> exp)
+ let env,_ = env_typ_expected l annot in
+ (try
+ match Env.lookup_id id env with
+ | Local (Immutable,_) -> Bindings.find id substs
+ | Local (Mutable,_) -> Bindings.find id assigns
+ | _ -> exp
+ with Not_found -> exp),assigns
| E_lit _
| E_sizeof _
| E_internal_exp _
@@ -679,97 +874,201 @@ let split_defs splits defs =
| E_internal_exp_user _
| E_comment _
| E_constraint _
- -> exp
- | E_cast (t,e') -> re (E_cast (t, const_prop_exp substs e'))
+ -> exp,assigns
+ | E_cast (t,e') ->
+ let e'',assigns = const_prop_exp substs assigns e' in
+ re (E_cast (t, e'')) assigns
| E_app (id,es) ->
- let es' = List.map (const_prop_exp substs) es in
+ let es',assigns = non_det_exp_list es in
+ let env,_ = env_typ_expected l annot in
(match try_app (l,annot) (id,es') with
| None ->
- (match const_prop_try_fn (id,es') with
- | None -> re (E_app (id,es'))
- | Some r -> r)
- | Some r -> r)
+ (match const_prop_try_fn l env (id,es') with
+ | None ->
+ (let env,_ = env_typ_expected l annot in
+ match Env.is_union_constructor id env, refine_constructor refinements l env id es' with
+ | true, Some exp -> re exp assigns
+ | _,_ -> re (E_app (id,es')) assigns)
+ | Some r -> r,assigns)
+ | Some r -> r,assigns)
| E_app_infix (e1,id,e2) ->
- let e1',e2' = const_prop_exp substs e1,const_prop_exp substs e2 in
+ let e1',e2',assigns = non_det_exp_2 e1 e2 in
(match try_app_infix (l,annot) e1' id e2' with
- | Some exp -> exp
- | None -> re (E_app_infix (e1',id,e2')))
- | E_tuple es -> re (E_tuple (List.map (const_prop_exp substs) es))
+ | Some exp -> exp,assigns
+ | None -> re (E_app_infix (e1',id,e2')) assigns)
+ | E_tuple es ->
+ let es',assigns = non_det_exp_list es in
+ re (E_tuple es') assigns
| E_if (e1,e2,e3) ->
- let e1' = const_prop_exp substs e1 in
- let e2',e3' = const_prop_exp substs e2, const_prop_exp substs e3 in
- (match e1' with
+ let e1',assigns = const_prop_exp substs assigns e1 in
+ let e2',assigns2 = const_prop_exp substs assigns e2 in
+ let e3',assigns3 = const_prop_exp substs assigns e3 in
+ (match drop_casts e1' with
| E_aux (E_lit (L_aux ((L_true|L_false) as lit ,_)),_) ->
- let e' = match lit with L_true -> e2' | _ -> e3' in
- (match e' with E_aux (_,(_,annot')) ->
- nexp_substs := build_nexp_subst l annot annot' @ !nexp_substs;
- e')
- | _ -> re (E_if (e1',e2',e3')))
- | E_for (id,e1,e2,e3,ord,e4) -> re (E_for (id,const_prop_exp substs e1,const_prop_exp substs e2,const_prop_exp substs e3,ord,const_prop_exp (ISubst.remove id substs) e4))
- | E_vector es -> re (E_vector (List.map (const_prop_exp substs) es))
- | E_vector_access (e1,e2) -> re (E_vector_access (const_prop_exp substs e1,const_prop_exp substs e2))
- | E_vector_subrange (e1,e2,e3) -> re (E_vector_subrange (const_prop_exp substs e1,const_prop_exp substs e2,const_prop_exp substs e3))
- | E_vector_update (e1,e2,e3) -> re (E_vector_update (const_prop_exp substs e1,const_prop_exp substs e2,const_prop_exp substs e3))
- | E_vector_update_subrange (e1,e2,e3,e4) -> re (E_vector_update_subrange (const_prop_exp substs e1,const_prop_exp substs e2,const_prop_exp substs e3,const_prop_exp substs e4))
- | E_vector_append (e1,e2) -> re (E_vector_append (const_prop_exp substs e1,const_prop_exp substs e2))
- | E_list es -> re (E_list (List.map (const_prop_exp substs) es))
- | E_cons (e1,e2) -> re (E_cons (const_prop_exp substs e1,const_prop_exp substs e2))
- | E_record fes -> re (E_record (const_prop_fexps substs fes))
- | E_record_update (e,fes) -> re (E_record_update (const_prop_exp substs e, const_prop_fexps substs fes))
- | E_field (e,id) -> re (E_field (const_prop_exp substs e,id))
+ (match lit with L_true -> e2',assigns2 | _ -> e3',assigns3)
+ | _ ->
+ let assigns = isubst_minus_set assigns (assigned_vars e2) in
+ let assigns = isubst_minus_set assigns (assigned_vars e3) in
+ re (E_if (e1',e2',e3')) assigns)
+ | E_for (id,e1,e2,e3,ord,e4) ->
+ (* Treat e1, e2 and e3 (from, to and by) as a non-det tuple *)
+ let e1',e2',e3',assigns = non_det_exp_3 e1 e2 e3 in
+ let assigns = isubst_minus_set assigns (assigned_vars e4) in
+ let e4',_ = const_prop_exp (Bindings.remove id substs) assigns e4 in
+ re (E_for (id,e1',e2',e3',ord,e4')) assigns
+ | E_loop (loop,e1,e2) ->
+ let assigns = isubst_minus_set assigns (IdSet.union (assigned_vars e1) (assigned_vars e2)) in
+ let e1',_ = const_prop_exp substs assigns e1 in
+ let e2',_ = const_prop_exp substs assigns e2 in
+ re (E_loop (loop,e1',e2')) assigns
+ | E_vector es ->
+ let es',assigns = non_det_exp_list es in
+ begin
+ match construct_lit_vector es' with
+ | None -> re (E_vector es') assigns
+ | Some lit -> re (E_lit lit) assigns
+ end
+ | E_vector_access (e1,e2) ->
+ let e1',e2',assigns = non_det_exp_2 e1 e2 in
+ re (E_vector_access (e1',e2')) assigns
+ | E_vector_subrange (e1,e2,e3) ->
+ let e1',e2',e3',assigns = non_det_exp_3 e1 e2 e3 in
+ re (E_vector_subrange (e1',e2',e3')) assigns
+ | E_vector_update (e1,e2,e3) ->
+ let e1',e2',e3',assigns = non_det_exp_3 e1 e2 e3 in
+ re (E_vector_update (e1',e2',e3')) assigns
+ | E_vector_update_subrange (e1,e2,e3,e4) ->
+ let e1',e2',e3',e4',assigns = non_det_exp_4 e1 e2 e3 e4 in
+ re (E_vector_update_subrange (e1',e2',e3',e4')) assigns
+ | E_vector_append (e1,e2) ->
+ let e1',e2',assigns = non_det_exp_2 e1 e2 in
+ re (E_vector_append (e1',e2')) assigns
+ | E_list es ->
+ let es',assigns = non_det_exp_list es in
+ re (E_list es') assigns
+ | E_cons (e1,e2) ->
+ let e1',e2',assigns = non_det_exp_2 e1 e2 in
+ re (E_cons (e1',e2')) assigns
+ | E_record fes ->
+ let assigned_in_fes = assigned_vars_in_fexps fes in
+ let assigns = isubst_minus_set assigns assigned_in_fes in
+ re (E_record (const_prop_fexps substs assigns fes)) assigns
+ | E_record_update (e,fes) ->
+ let assigned_in = IdSet.union (assigned_vars_in_fexps fes) (assigned_vars e) in
+ let assigns = isubst_minus_set assigns assigned_in in
+ let e',_ = const_prop_exp substs assigns e in
+ re (E_record_update (e', const_prop_fexps substs assigns fes)) assigns
+ | E_field (e,id) ->
+ let e',assigns = const_prop_exp substs assigns e in
+ re (E_field (e',id)) assigns
| E_case (e,cases) ->
- let e' = const_prop_exp substs e in
+ let e',assigns = const_prop_exp substs assigns e in
(match can_match e' cases with
- | None -> re (E_case (e', List.map (const_prop_pexp substs) cases))
- | Some (E_aux (_,(_,annot')) as exp,newbindings) ->
- let newbindings_env = isubst_from_list newbindings in
- let substs' = isubst_union substs newbindings_env in
- nexp_substs := build_nexp_subst l annot annot' @ !nexp_substs;
- const_prop_exp substs' exp)
- | E_let (lb,e) ->
- let (lb',substs') = const_prop_letbind substs lb in
- re (E_let (lb', const_prop_exp substs' e))
- | E_assign (le,e) -> re (E_assign (const_prop_lexp substs le, const_prop_exp substs e))
- | E_exit e -> re (E_exit (const_prop_exp substs e))
- | E_return e -> re (E_return (const_prop_exp substs e))
- | E_assert (e1,e2) -> re (E_assert (const_prop_exp substs e1,const_prop_exp substs e2))
- | E_internal_cast (ann,e) -> re (E_internal_cast (ann,const_prop_exp substs e))
- | E_comment_struc e -> re (E_comment_struc e)
+ | None ->
+ let assigned_in =
+ List.fold_left (fun vs pe -> IdSet.union vs (assigned_vars_in_pexp pe))
+ IdSet.empty cases
+ in
+ let assigns' = isubst_minus_set assigns assigned_in in
+ re (E_case (e', List.map (const_prop_pexp substs assigns) cases)) assigns'
+ | Some (E_aux (_,(_,annot')) as exp,newbindings,kbindings) ->
+ let exp = nexp_subst_exp (kbindings_from_list kbindings) exp in
+ let newbindings_env = bindings_from_list newbindings in
+ let substs' = bindings_union substs newbindings_env in
+ const_prop_exp substs' assigns exp)
+ | E_let (lb,e2) ->
+ begin
+ match lb with
+ | LB_aux (LB_val (p,e), annot) ->
+ let e',assigns = const_prop_exp substs assigns e in
+ let substs' = remove_bound substs p in
+ let plain () =
+ let e2',assigns = const_prop_exp substs' assigns e2 in
+ re (E_let (LB_aux (LB_val (p,e'), annot),
+ e2')) assigns in
+ if is_value e' && not (is_value e) then
+ match can_match e' [Pat_aux (Pat_exp (p,e2),(Unknown,None))] with
+ | None -> plain ()
+ | Some (e'',bindings,kbindings) ->
+ let e'' = nexp_subst_exp (kbindings_from_list kbindings) e'' in
+ let bindings = bindings_from_list bindings in
+ let substs'' = bindings_union substs' bindings in
+ const_prop_exp substs'' assigns e''
+ else plain ()
+ end
+ (* TODO maybe - tuple assignments *)
+ | E_assign (le,e) ->
+ let env,_ = env_typ_expected l annot in
+ let assigned_in = IdSet.union (assigned_vars_in_lexp le) (assigned_vars e) in
+ let assigns = isubst_minus_set assigns assigned_in in
+ let le',idopt = const_prop_lexp substs assigns le in
+ let e',_ = const_prop_exp substs assigns e in
+ let assigns =
+ match idopt with
+ | Some id ->
+ begin
+ match Env.lookup_id id env with
+ | Local (Mutable,_) | Unbound ->
+ if is_value e'
+ then Bindings.add id e' assigns
+ else Bindings.remove id assigns
+ | _ -> assigns
+ end
+ | None -> assigns
+ in
+ re (E_assign (le', e')) assigns
+ | E_exit e ->
+ let e',_ = const_prop_exp substs assigns e in
+ re (E_exit e') Bindings.empty
+ | E_throw e ->
+ let e',_ = const_prop_exp substs assigns e in
+ re (E_throw e') Bindings.empty
+ | E_try (e,cases) ->
+ (* TODO: try and preserve *any* assignment info *)
+ let e',_ = const_prop_exp substs assigns e in
+ re (E_case (e', List.map (const_prop_pexp substs Bindings.empty) cases)) Bindings.empty
+ | E_return e ->
+ let e',_ = const_prop_exp substs assigns e in
+ re (E_return e') Bindings.empty
+ | E_assert (e1,e2) ->
+ let e1',e2',assigns = non_det_exp_2 e1 e2 in
+ re (E_assert (e1',e2')) assigns
+ | E_internal_cast (ann,e) ->
+ let e',assigns = const_prop_exp substs assigns e in
+ re (E_internal_cast (ann,e')) assigns
+ (* TODO: should I substitute or anything here? Is it even used? *)
+ | E_comment_struc e -> re (E_comment_struc e) assigns
| E_internal_let _
| E_internal_plet _
| E_internal_return _
-> raise (Reporting_basic.err_unreachable l
"Unexpected internal expression encountered in monomorphisation")
- and const_prop_opt_default substs ((Def_val_aux (ed,annot)) as eda) =
- match ed with
- | Def_val_empty -> eda
- | Def_val_dec e -> Def_val_aux (Def_val_dec (const_prop_exp substs e),annot)
- and const_prop_fexps substs (FES_aux (FES_Fexps (fes,flag), annot)) =
- FES_aux (FES_Fexps (List.map (const_prop_fexp substs) fes, flag), annot)
- and const_prop_fexp substs (FE_aux (FE_Fexp (id,e), annot)) =
- FE_aux (FE_Fexp (id,const_prop_exp substs e),annot)
- and const_prop_pexp substs = function
+ and const_prop_fexps substs assigns (FES_aux (FES_Fexps (fes,flag), annot)) =
+ FES_aux (FES_Fexps (List.map (const_prop_fexp substs assigns) fes, flag), annot)
+ and const_prop_fexp substs assigns (FE_aux (FE_Fexp (id,e), annot)) =
+ FE_aux (FE_Fexp (id,fst (const_prop_exp substs assigns e)),annot)
+ and const_prop_pexp substs assigns = function
| (Pat_aux (Pat_exp (p,e),l)) ->
- Pat_aux (Pat_exp (p,const_prop_exp (remove_bound substs p) e),l)
+ Pat_aux (Pat_exp (p,fst (const_prop_exp (remove_bound substs p) assigns e)),l)
| (Pat_aux (Pat_when (p,e1,e2),l)) ->
let substs' = remove_bound substs p in
- Pat_aux (Pat_when (p, const_prop_exp substs' e1, const_prop_exp substs' e2),l)
- and const_prop_letbind substs (LB_aux (lb,annot)) =
- match lb with
- | LB_val (p,e) ->
- (LB_aux (LB_val (p,const_prop_exp substs e), annot),
- remove_bound substs p)
- and const_prop_lexp substs ((LEXP_aux (e,annot)) as le) =
- let re e = LEXP_aux (e,annot) in
+ let e1',assigns = const_prop_exp substs' assigns e1 in
+ Pat_aux (Pat_when (p, e1', fst (const_prop_exp substs' assigns e2)),l)
+ and const_prop_lexp substs assigns ((LEXP_aux (e,annot)) as le) =
+ let re e = LEXP_aux (e,annot), None in
match e with
- | LEXP_id _ (* shouldn't end up substituting here *)
- | LEXP_cast _
- -> le
- | LEXP_memory (id,es) -> re (LEXP_memory (id,List.map (const_prop_exp substs) es)) (* or here *)
- | LEXP_tup les -> re (LEXP_tup (List.map (const_prop_lexp substs) les))
- | LEXP_vector (le,e) -> re (LEXP_vector (const_prop_lexp substs le, const_prop_exp substs e))
- | LEXP_vector_range (le,e1,e2) -> re (LEXP_vector_range (const_prop_lexp substs le, const_prop_exp substs e1, const_prop_exp substs e2))
- | LEXP_field (le,id) -> re (LEXP_field (const_prop_lexp substs le, id))
+ | LEXP_id id (* shouldn't end up substituting here *)
+ | LEXP_cast (_,id)
+ -> le, Some id
+ | LEXP_memory (id,es) ->
+ re (LEXP_memory (id,List.map (fun e -> fst (const_prop_exp substs assigns e)) es)) (* or here *)
+ | LEXP_tup les -> re (LEXP_tup (List.map (fun le -> fst (const_prop_lexp substs assigns le)) les))
+ | LEXP_vector (le,e) -> re (LEXP_vector (fst (const_prop_lexp substs assigns le), fst (const_prop_exp substs assigns e)))
+ | LEXP_vector_range (le,e1,e2) ->
+ re (LEXP_vector_range (fst (const_prop_lexp substs assigns le),
+ fst (const_prop_exp substs assigns e1),
+ fst (const_prop_exp substs assigns e2)))
+ | LEXP_field (le,id) -> re (LEXP_field (fst (const_prop_lexp substs assigns le), id))
(* Reduce a function when
1. all arguments are values,
2. the function is pure,
@@ -777,7 +1076,7 @@ let split_defs splits defs =
(and 4. the function is not scattered, but that's not terribly important)
to try and keep execution time and the results managable.
*)
- and const_prop_try_fn (id,args) =
+ and const_prop_try_fn l env (id,args) =
if not (List.for_all is_value args) then
None
else
@@ -790,36 +1089,23 @@ let split_defs splits defs =
| Some (eff,_) when not (is_pure eff) -> None
| Some (_,fcls) ->
let arg = match args with
- | [] -> E_aux (E_lit (L_aux (L_unit,Unknown)),(Unknown,None))
+ | [] -> E_aux (E_lit (L_aux (L_unit,Generated l)),(Generated l,None))
| [e] -> e
- | _ -> E_aux (E_tuple args,(Unknown,None)) in
+ | _ -> E_aux (E_tuple args,(Generated l,None)) in
let cases = List.map (function
| FCL_aux (FCL_Funcl (_,pat,exp), ann) -> Pat_aux (Pat_exp (pat,exp),ann))
fcls in
- match can_match arg cases with
- | Some (exp,bindings) ->
- let substs = isubst_from_list bindings in
- let result = const_prop_exp substs exp in
+ match can_match_with_env env arg cases with
+ | Some (exp,bindings,kbindings) ->
+ let substs = bindings_from_list bindings in
+ let result,_ = const_prop_exp substs Bindings.empty exp in
if is_value result then Some result else None
| None -> None
in
- let subst_exp subst exp =
- if disable_const_propagation then
- let (subi,(E_aux (_,subannot) as sube)) = subst in
- let E_aux (e,(l,annot)) = exp in
- let lg = Generated l in
- let id = match subi with Id_aux (i,l) -> Id_aux (i,lg) in
- let p = P_aux (P_id id, subannot) in
- E_aux (E_let (LB_aux (LB_val (p,sube),(lg,annot)), exp),(lg,annot))
- else
- let substs = isubst_from_list [subst] in
- let () = nexp_substs := [] in
- let exp' = const_prop_exp substs exp in
- (* Substitute what we've learned about nvars into the term *)
- let nsubsts = isubst_from_list !nexp_substs in
- let () = nexp_substs := [] in
- nexp_subst_exp nsubsts refinements exp'
+ let subst_exp substs exp =
+ let substs = bindings_from_list substs in
+ fst (const_prop_exp substs Bindings.empty exp)
in
(* Split a variable pattern into every possible value *)
@@ -841,13 +1127,13 @@ let split_defs splits defs =
(* enumerations *)
let ns = Env.get_enum id env in
List.map (fun n -> (P_aux (P_id (renew_id n),(l,annot)),
- (var,E_aux (E_id (renew_id n),(new_l,annot))))) ns
+ [var,E_aux (E_id (renew_id n),(new_l,annot))])) ns
with Type_error _ ->
match id with
| Id_aux (Id "bit",_) ->
List.map (fun b ->
P_aux (P_lit (L_aux (b,new_l)),(l,annot)),
- (var,E_aux (E_lit (L_aux (b,new_l)),(new_l, annot))))
+ [var,E_aux (E_lit (L_aux (b,new_l)),(new_l, annot))])
[L_zero; L_one]
| _ -> cannot ())
| Typ_app (Id_aux (Id "vector",_), [_;Typ_arg_aux (Typ_arg_nexp len,_);_;Typ_arg_aux (Typ_arg_typ (Typ_aux (Typ_id (Id_aux (Id "bit",_)),_)),_)]) ->
@@ -857,7 +1143,7 @@ let split_defs splits defs =
let lits = make_vectors sz in
List.map (fun lit ->
P_aux (P_lit lit,(l,annot)),
- (var,E_aux (E_lit lit,(new_l,annot)))) lits
+ [var,E_aux (E_lit lit,(new_l,annot))]) lits
else
raise (Reporting_basic.err_general l
("Refusing to split vector type of length " ^ string_of_int sz ^
@@ -879,29 +1165,39 @@ let split_defs splits defs =
| Int _ -> []
| Generated l -> [] (* Could do match_l l, but only want to split user-written patterns *)
| Range (p,q) ->
- List.filter (fun ((filename,line),_) ->
- Filename.basename p.Lexing.pos_fname = filename &&
- p.Lexing.pos_lnum <= line && line <= q.Lexing.pos_lnum) ls
+ let matches =
+ List.filter (fun ((filename,line),_) ->
+ Filename.basename p.Lexing.pos_fname = filename &&
+ p.Lexing.pos_lnum <= line && line <= q.Lexing.pos_lnum) ls
+ in List.map snd matches
in
- let split_pat var p =
+ let split_pat vars p =
let id_matches = function
- | Id_aux (Id x,_) -> x = var
- | Id_aux (DeIid x,_) -> x = var
+ | Id_aux (Id x,_) -> List.mem x vars
+ | Id_aux (DeIid x,_) -> List.mem x vars
in
let rec list f = function
| [] -> None
| h::t ->
- match f h with
- | None -> (match list f t with None -> None | Some (l,ps,r) -> Some (h::l,ps,r))
- | Some ps -> Some ([],ps,t)
+ let t' =
+ match list f t with
+ | None -> [t,[]]
+ | Some t' -> t'
+ in
+ let h' =
+ match f h with
+ | None -> [h,[]]
+ | Some ps -> ps
+ in
+ Some (List.concat (List.map (fun (h,hsubs) -> List.map (fun (t,tsubs) -> (h::t,hsubs@tsubs)) t') h'))
in
let rec spl (P_aux (p,(l,annot))) =
let relist f ctx ps =
optmap (list f ps)
- (fun (left,ps,right) ->
- List.map (fun (p,sub) -> P_aux (ctx (left@p::right),(l,annot)),sub) ps)
+ (fun ps ->
+ List.map (fun (ps,sub) -> P_aux (ctx ps,(l,annot)),sub) ps)
in
let re f p =
optmap (spl p)
@@ -916,10 +1212,11 @@ let split_defs splits defs =
match p with
| P_lit _
| P_wild
+ | P_var _
-> None
| P_as (p',id) when id_matches id ->
raise (Reporting_basic.err_general l
- ("Cannot split " ^ var ^ " on 'as' pattern"))
+ ("Cannot split " ^ string_of_id id ^ " on 'as' pattern"))
| P_as (p',id) ->
re (fun p -> P_as (p,id)) p'
| P_typ (t,p') -> re (fun p -> P_typ (t,p)) p'
@@ -940,18 +1237,21 @@ let split_defs splits defs =
| P_list ps ->
relist spl (fun ps -> P_list ps) ps
| P_cons (p1,p2) ->
- match re (fun p' -> P_cons (p',p2)) p1 with
- | Some r -> Some r
- | None -> re (fun p' -> P_cons (p1,p')) p2
+ match spl p1, spl p2 with
+ | None, None -> None
+ | p1', p2' ->
+ let p1' = match p1' with None -> [p1,[]] | Some p1' -> p1' in
+ let p2' = match p2' with None -> [p2,[]] | Some p2' -> p2' in
+ let ps = List.map (fun (p1',subs1) -> List.map (fun (p2',subs2) ->
+ P_aux (P_cons (p1',p2'),(l,annot)),subs1@subs2) p2') p1' in
+ Some (List.concat ps)
in spl p
in
let map_pat_by_loc (P_aux (p,(l,_)) as pat) =
match match_l l with
| [] -> None
- | [(_,var)] -> split_pat var pat
- | lvs -> raise (Reporting_basic.err_general l
- ("Multiple variables to split on: " ^ String.concat ", " (List.map snd lvs)))
+ | vars -> split_pat vars pat
in
let map_pat (P_aux (p,(l,tannot)) as pat) =
match map_pat_by_loc pat with
@@ -959,29 +1259,34 @@ let split_defs splits defs =
| None ->
match p with
| P_app (id,args) ->
- (try
- let (_,variants) = List.find (fun (id',_) -> Id.compare id id' = 0) refinements in
- let env,_ = env_typ_expected l tannot in
- let constr_out_typ =
- match Env.get_val_spec id env with
- | (qs,Typ_aux (Typ_fn(_,outt,_),_)) -> outt
- | _ -> raise (Reporting_basic.err_general l
- ("Constructor " ^ string_of_id id ^ " is not a construtor!"))
- in
- let varmap = build_nexp_subst l constr_out_typ tannot in
- let map_inst (insts,id') =
- let insts = List.map (fun (v,i) ->
- ((match List.assoc (string_of_kid v) varmap with
- | Nexp_aux (Nexp_var s, _) -> s
- | _ -> raise (Reporting_basic.err_general l
- ("Constructor parameter not a variable: " ^ string_of_kid v))),
- Nexp_aux (Nexp_constant i,Generated l)))
- insts in
- P_aux (P_app (id',args),(Generated l,tannot)),
- ksubst_from_list insts
- in
- ConstrSplit (List.map map_inst variants)
- with Not_found -> NoSplit)
+ begin
+ let kid,kid_annot =
+ match args with
+ | [P_aux (P_var (_,kid),ann)] -> kid,ann
+ | _ ->
+ raise (Reporting_basic.err_general l
+ "Pattern match not currently supported by monomorphisation")
+ in match List.find (fun (id',_) -> Id.compare id id' = 0) refinements with
+ | (_,variants) ->
+ let map_inst (insts,id',_) =
+ let insts =
+ match insts with [(v,Some i)] -> [(kid,Nexp_aux (Nexp_constant i, Generated l))]
+ | _ -> assert false
+ in
+(*
+ let insts,_ = split_insts insts in
+ let insts = List.map (fun (v,i) ->
+ (??,
+ Nexp_aux (Nexp_constant i,Generated l)))
+ insts in
+ P_aux (P_app (id',args),(Generated l,tannot)),
+*)
+ P_aux (P_app (id',[P_aux (P_id (id_of_kid kid),kid_annot)]),(Generated l,tannot)),
+ kbindings_from_list insts
+ in
+ ConstrSplit (List.map map_inst variants)
+ | exception Not_found -> NoSplit
+ end
| _ -> NoSplit
in
@@ -991,7 +1296,7 @@ let split_defs splits defs =
| lvs ->
let pvs = bindings_from_pat p in
let pvs = List.map string_of_id pvs in
- let overlap = List.exists (fun (_,v) -> List.mem v pvs) lvs in
+ let overlap = List.exists (fun v -> List.mem v pvs) lvs in
let () =
if overlap then
Reporting_basic.print_err false true l "Monomorphisation"
@@ -1019,6 +1324,7 @@ let split_defs splits defs =
| E_tuple es -> re (E_tuple (List.map map_exp es))
| E_if (e1,e2,e3) -> re (E_if (map_exp e1, map_exp e2, map_exp e3))
| E_for (id,e1,e2,e3,ord,e4) -> re (E_for (id,map_exp e1,map_exp e2,map_exp e3,ord,map_exp e4))
+ | E_loop (loop,e1,e2) -> re (E_loop (loop,map_exp e1,map_exp e2))
| E_vector es -> re (E_vector (List.map map_exp es))
| E_vector_access (e1,e2) -> re (E_vector_access (map_exp e1,map_exp e2))
| E_vector_subrange (e1,e2,e3) -> re (E_vector_subrange (map_exp e1,map_exp e2,map_exp e3))
@@ -1034,6 +1340,8 @@ let split_defs splits defs =
| E_let (lb,e) -> re (E_let (map_letbind lb, map_exp e))
| E_assign (le,e) -> re (E_assign (map_lexp le, map_exp e))
| E_exit e -> re (E_exit (map_exp e))
+ | E_throw e -> re (E_throw e)
+ | E_try (e,cases) -> re (E_try (map_exp e, List.concat (List.map map_pexp cases)))
| E_return e -> re (E_return (map_exp e))
| E_assert (e1,e2) -> re (E_assert (map_exp e1,map_exp e2))
| E_internal_cast (ann,e) -> re (E_internal_cast (ann,map_exp e))
@@ -1054,32 +1362,30 @@ let split_defs splits defs =
(match map_pat p with
| NoSplit -> [Pat_aux (Pat_exp (p,map_exp e),l)]
| VarSplit patsubsts ->
- List.map (fun (pat',subst) ->
- let exp' = subst_exp subst e in
+ List.map (fun (pat',substs) ->
+ let exp' = subst_exp substs e in
Pat_aux (Pat_exp (pat', map_exp exp'),l))
patsubsts
| ConstrSplit patnsubsts ->
List.map (fun (pat',nsubst) ->
- (* Leave refinements to later *)
- let pat' = nexp_subst_pat nsubst [] pat' in
- let exp' = nexp_subst_exp nsubst [] e in
+ let pat' = nexp_subst_pat nsubst pat' in
+ let exp' = nexp_subst_exp nsubst e in
Pat_aux (Pat_exp (pat', map_exp exp'),l)
) patnsubsts)
| Pat_aux (Pat_when (p,e1,e2),l) ->
(match map_pat p with
| NoSplit -> [Pat_aux (Pat_when (p,map_exp e1,map_exp e2),l)]
| VarSplit patsubsts ->
- List.map (fun (pat',subst) ->
- let exp1' = subst_exp subst e1 in
- let exp2' = subst_exp subst e2 in
+ List.map (fun (pat',substs) ->
+ let exp1' = subst_exp substs e1 in
+ let exp2' = subst_exp substs e2 in
Pat_aux (Pat_when (pat', map_exp exp1', map_exp exp2'),l))
patsubsts
| ConstrSplit patnsubsts ->
List.map (fun (pat',nsubst) ->
- (* Leave refinements to later *)
- let pat' = nexp_subst_pat nsubst [] pat' in
- let exp1' = nexp_subst_exp nsubst [] e1 in
- let exp2' = nexp_subst_exp nsubst [] e2 in
+ let pat' = nexp_subst_pat nsubst pat' in
+ let exp1' = nexp_subst_exp nsubst e1 in
+ let exp2' = nexp_subst_exp nsubst e2 in
Pat_aux (Pat_when (pat', map_exp exp1', map_exp exp2'),l)
) patnsubsts)
and map_letbind (LB_aux (lb,annot)) =
@@ -1102,15 +1408,14 @@ let split_defs splits defs =
match map_pat pat with
| NoSplit -> [FCL_aux (FCL_Funcl (id, pat, map_exp exp), annot)]
| VarSplit patsubsts ->
- List.map (fun (pat',subst) ->
- let exp' = subst_exp subst exp in
+ List.map (fun (pat',substs) ->
+ let exp' = subst_exp substs exp in
FCL_aux (FCL_Funcl (id, pat', map_exp exp'), annot))
patsubsts
| ConstrSplit patnsubsts ->
List.map (fun (pat',nsubst) ->
- (* Leave refinements to later *)
- let pat' = nexp_subst_pat nsubst [] pat' in
- let exp' = nexp_subst_exp nsubst [] exp in
+ let pat' = nexp_subst_pat nsubst pat' in
+ let exp' = nexp_subst_exp nsubst exp in
FCL_aux (FCL_Funcl (id, pat', map_exp exp'), annot)
) patnsubsts
in
@@ -1141,5 +1446,5 @@ let split_defs splits defs =
in
Defs (List.concat (List.map map_def defs))
in
- deexist (map_locs splits defs')
+ map_locs splits defs'
diff --git a/src/pretty_print_lem.ml b/src/pretty_print_lem.ml
index 5ca0c1bc..a890e039 100644
--- a/src/pretty_print_lem.ml
+++ b/src/pretty_print_lem.ml
@@ -163,7 +163,7 @@ let doc_typ_lem, doc_atomic_typ_lem =
(*let exc_typ = string "string" in*)
let ret_typ =
if effectful efct
- then separate space [string "M";(*parens exc_typ;*) fn_typ sequential mwords true ret]
+ then separate space [string "_M";(*parens exc_typ;*) fn_typ sequential mwords true ret]
else separate space [fn_typ sequential mwords false ret] in
let tpp = separate space [tup_typ sequential mwords true arg; arrow;ret_typ] in
(* once we have proper excetions we need to know what the exceptions type is *)
@@ -320,10 +320,37 @@ let doc_typquant_items_lem (TypQ_aux(tq,_)) = match tq with
let doc_typquant_lem (TypQ_aux(tq,_)) typ = match tq with
| TypQ_tq ((_ :: _) as qs) ->
string "forall " ^^ separate_map space doc_quant_item qs ^^ string ". " ^^ typ
-| _ -> empty
+| _ -> typ
+
+(* Produce Size type constraints for bitvector sizes when using
+ machine words. Often these will be unnecessary, but this simple
+ approach will do for now. *)
+
+let rec typeclass_vars (Typ_aux(t,_)) = match t with
+| Typ_wild
+| Typ_id _
+| Typ_var _
+ -> []
+| Typ_fn (t1,t2,_) -> typeclass_vars t1 @ typeclass_vars t2
+| Typ_tup ts -> List.concat (List.map typeclass_vars ts)
+| Typ_app (Id_aux (Id "vector",_),
+ [_;Typ_arg_aux (Typ_arg_nexp (Nexp_aux (Nexp_var v,_)),_);
+ _;Typ_arg_aux (Typ_arg_typ (Typ_aux (Typ_id (Id_aux (Id "bit",_)),_)),_)]) ->
+ [v]
+| Typ_app _ -> []
+| Typ_exist (kids,_,t) -> [] (* todo *)
+
+let doc_typclasses_lem mwords t =
+ if mwords then
+ let vars = typeclass_vars t in
+ let vars = List.sort_uniq Kid.compare vars in
+ match vars with
+ | [] -> empty
+ | _ -> separate_map comma_sp (fun var -> string "Size " ^^ doc_var var) vars ^^ string " => "
+ else empty
let doc_typschm_lem sequential mwords quants (TypSchm_aux(TypSchm_ts(tq,t),_)) =
- if quants then (doc_typquant_lem tq (doc_typ_lem sequential mwords t))
+ if quants then (doc_typquant_lem tq (doc_typclasses_lem mwords t ^^ doc_typ_lem sequential mwords t))
else doc_typ_lem sequential mwords t
let is_ctor env id = match Env.lookup_id id env with
@@ -1538,12 +1565,16 @@ let doc_dec_lem sequential (DEC_aux (reg, ((l, _) as annot))) =
| DEC_alias(id,alspec) -> empty
| DEC_typ_alias(typ,id,alspec) -> empty
-let doc_spec_lem mwords (VS_aux (valspec,annot)) = empty
+let doc_spec_lem sequential mwords (VS_aux (valspec,annot)) =
+ match valspec with
+ | VS_val_spec (typschm,id,None,_) ->
+ separate space [string "val"; doc_id_lem id; string ":";doc_typschm_lem sequential mwords true typschm] ^/^ hardline
+ | VS_val_spec (_,_,Some _,_) -> empty
let rec doc_def_lem sequential mwords def =
(* let _ = Pretty_print_sail.pp_defs stderr (Defs [def]) in *)
match def with
- | DEF_spec v_spec -> (doc_spec_lem mwords v_spec,empty)
+ | DEF_spec v_spec -> (empty,doc_spec_lem sequential mwords v_spec)
| DEF_overload _ -> (empty,empty)
| DEF_type t_def -> (group (doc_typdef_lem sequential mwords t_def) ^/^ hardline,empty)
| DEF_reg_dec dec -> (group (doc_dec_lem sequential dec),empty)
diff --git a/src/process_file.ml b/src/process_file.ml
index 8d23fcd2..d35ccf5e 100644
--- a/src/process_file.ml
+++ b/src/process_file.ml
@@ -118,8 +118,11 @@ let check_ast (defs : unit Ast.defs) : Type_check.tannot Ast.defs * Type_check.E
let () = if !opt_just_check then exit 0 else () in
(ast, env)
+let opt_ddump_raw_mono_ast = ref false
+
let monomorphise_ast locs ast =
let ast = Monomorphise.split_defs locs ast in
+ let () = if !opt_ddump_raw_mono_ast then Pretty_print.pp_defs stdout ast else () in
let ienv = Type_check.Env.no_casts Type_check.initial_env in
Type_check.check ienv ast
diff --git a/src/process_file.mli b/src/process_file.mli
index 53a6f3f2..c477d185 100644
--- a/src/process_file.mli
+++ b/src/process_file.mli
@@ -58,6 +58,7 @@ val opt_just_check : bool ref
val opt_ddump_tc_ast : bool ref
val opt_ddump_rewrite_ast : ((string * int) option) ref
val opt_dno_cast : bool ref
+val opt_ddump_raw_mono_ast : bool ref
type out_type =
| Lem_ast_out
diff --git a/src/sail.ml b/src/sail.ml
index e5a5071f..ca121a79 100644
--- a/src/sail.ml
+++ b/src/sail.ml
@@ -96,6 +96,9 @@ let options = Arg.align ([
| [fname;line;var] -> opt_mono_split := ((fname,int_of_string line),var)::!opt_mono_split
| _ -> raise (Arg.Bad (s ^ " not of form <filename>:<line>:<variable>"))),
"<filename>:<line>:<variable> to case split for monomorphisation");
+ ( "-ddump_raw_mono_ast",
+ Arg.Set opt_ddump_raw_mono_ast,
+ " (debug) dump the monomorphised ast before type-checking");
( "-new_parser",
Arg.Set Process_file.opt_new_parser,
" (experimental) use new parser");
diff --git a/src/type_check.mli b/src/type_check.mli
index 9f5771b9..b6b5e75e 100644
--- a/src/type_check.mli
+++ b/src/type_check.mli
@@ -235,6 +235,8 @@ type uvar =
val string_of_uvar : uvar -> string
+val unify : l -> Env.t -> typ -> typ -> uvar KBindings.t * kid list * n_constraint option
+
(* Throws Invalid_argument if the argument is not a E_app expression *)
val instantiation_of : tannot exp -> uvar KBindings.t
diff --git a/test/mono/.gitignore b/test/mono/.gitignore
new file mode 100644
index 00000000..2dd3daa3
--- /dev/null
+++ b/test/mono/.gitignore
@@ -0,0 +1 @@
+test-out \ No newline at end of file
diff --git a/test/mono/addsubexist.sail b/test/mono/addsubexist.sail
new file mode 100644
index 00000000..f59f596e
--- /dev/null
+++ b/test/mono/addsubexist.sail
@@ -0,0 +1,75 @@
+(* Adapted from hand-written ARM model *)
+
+default Order dec
+typedef boolean = bit
+typedef reg_size = bit[5]
+typedef reg_index = [|31|]
+
+val reg_size -> reg_index effect pure UInt_reg
+val unit -> unit effect pure (* probably not pure *) ReservedValue
+function forall Nat 'N. bit['N] NOT((bit['N]) x) = ~(x)
+val forall Nat 'M, Nat 'N. bit['M] -> bit['N] effect pure ZeroExtend
+val forall Nat 'N. (bit['N], bit['N], bit) -> (bit['N],bit[4]) effect pure AddWithCarry
+val forall Nat 'N, 'N IN {8,16,32,64}. (*broken? implicit<'N>*)unit -> bit['N] effect {rreg} rSP
+val forall Nat 'N, 'N IN {8,16,32,64}. ((*ditto implicit<'N>,*)reg_index) -> bit['N] effect {rreg}rX
+val (unit, bit[4]) -> unit effect {wreg} wPSTATE_NZCV
+val forall Nat 'N, 'N IN {32,64}. (unit, bit['N]) -> unit effect {rreg,wreg} wSP
+val forall Nat 'N, 'N IN {32,64}. (reg_index,bit['N]) -> unit effect {wreg} wX
+
+typedef ast = const union
+{
+ (exist 'R, 'R in {32,64}. (reg_index,reg_index,[:'R:],boolean,boolean,bit['R])) AddSubImmediate;
+}
+
+val ast -> unit effect {rreg,wreg(*,rmem,barr,eamem,wmv,escape*)} execute
+scattered function unit execute
+
+val bit[32] -> option<(ast)> effect pure decodeAddSubtractImmediate
+
+(* ADD/ADDS (immediate) *)
+(* SUB/SUBS (immediate) *)
+function option<(ast)> effect pure decodeAddSubtractImmediate ([sf]:[op]:[S]:0b10001:(bit[2]) shift:(bit[12]) imm12:(reg_size) Rn:(reg_size) Rd) =
+{
+ (reg_index) d := UInt_reg(Rd);
+ (reg_index) n := UInt_reg(Rn);
+ let (exist 'R, 'R in {32,64}. [:'R:]) 'datasize = if sf then 64 else 32 in {
+ (boolean) sub_op := op;
+ (boolean) setflags := S;
+ (bit['datasize]) imm := 0; (* ARM: uninitialized *)
+
+ switch shift {
+ case 0b00 -> imm := ZeroExtend(imm12)
+ case 0b01 -> imm := ZeroExtend(imm12 : (0b0 ^^ 12))
+ case [bitone,_] -> ReservedValue()
+ };
+
+ Some(AddSubImmediate( (d,n,datasize,sub_op,setflags,imm) ));
+}}
+
+function clause execute (AddSubImmediate('datasize)) = {
+switch datasize {
+case (d,n,datasize,sub_op,setflags,imm) ->
+{
+ (bit['datasize]) operand1 := if n == 31 then rSP() else rX(n);
+ (bit['datasize]) operand2 := imm;
+ (bit) carry_in := bitzero; (* ARM: uninitialized *)
+
+ if sub_op then {
+ operand2 := NOT(operand2);
+ carry_in := bitone;
+ }
+ else
+ carry_in := bitzero;
+
+ let (result,nzcv) = (AddWithCarry(operand1, operand2, carry_in)) in {
+
+ if setflags then
+ wPSTATE_NZCV() := nzcv;
+
+ if (d == 31 & ~(setflags)) then
+ wSP() := result
+ else
+ wX(d) := result;
+ }
+}}}
+end execute
diff --git a/test/mono/fnreduce.sail b/test/mono/fnreduce.sail
new file mode 100644
index 00000000..f39fb87d
--- /dev/null
+++ b/test/mono/fnreduce.sail
@@ -0,0 +1,69 @@
+(* Test constant propagation part of monomorphisation involving
+ functions. We should reduce a function application when the
+ arguments are suitable values, the function is pure and the result
+ is a value.
+ *)
+
+typedef AnEnum = enumerate { One; Two; Three }
+
+typedef EnumlikeUnion = const union { First; Second }
+
+typedef ProperUnion = const union {
+ (AnEnum) Stuff;
+ (EnumlikeUnion) Nonsense;
+}
+
+function AnEnum canReduce ((AnEnum) x) = {
+ switch (x) {
+ case One -> Two
+ case x -> x
+ }
+}
+
+function nat cannotReduce ((AnEnum) x) = {
+ let (nat) y = switch (x) { case One -> 1 case _ -> 5 } in
+ 2 + y
+}
+
+function AnEnum effect {rreg} fakeUnpure ((AnEnum) x) = {
+ switch (x) {
+ case One -> Two
+ case x -> x
+ }
+}
+
+function EnumlikeUnion canReduceUnion ((AnEnum) x) = {
+ switch (x) {
+ case One -> First
+ case _ -> Second
+ }
+}
+
+function ProperUnion canReduceUnion2 ((AnEnum) x) = {
+ switch (x) {
+ case One -> Nonsense(First)
+ case y -> Stuff(y)
+ }
+}
+
+(* FIXME LATER: once effect handling is in place we should get an error
+ because this isn't pure *)
+
+val AnEnum -> (AnEnum,nat,AnEnum,EnumlikeUnion,ProperUnion) effect pure test
+
+function test (x) = {
+ let a = canReduce(x) in
+ let b = cannotReduce(x) in
+ let c = fakeUnpure(x) in
+ let d = canReduceUnion(x) in
+ let e = canReduceUnion2(x) in
+ (a,b,c,d,e)
+}
+
+val unit -> bool effect pure run
+
+function run () = {
+ test(One) == (Two,3,Two,First,Nonsense(First)) &
+ test(Two) == (Two,7,Two,Second,Stuff(Two)) &
+ test(Three) == (Three,7,Three,Second,Stuff(Three))
+}
diff --git a/test/mono/test.ml b/test/mono/test.ml
new file mode 100644
index 00000000..f99abfb8
--- /dev/null
+++ b/test/mono/test.ml
@@ -0,0 +1 @@
+if Testout_embed_sequential.run() then print_endline "OK" else print_endline "Failed";;
diff --git a/test/mono/test.sh b/test/mono/test.sh
new file mode 100755
index 00000000..2a5aa80b
--- /dev/null
+++ b/test/mono/test.sh
@@ -0,0 +1,44 @@
+#!/bin/bash
+
+set -e
+
+DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" && pwd )"
+SAILDIR="$DIR/../.."
+LEMDIR="$DIR/../../../lem"
+OUTDIR="$DIR/test-out"
+ZARITH="$LEMDIR/ocaml-lib/dependencies/zarith"
+
+if [ ! -d "$OUTDIR" ]; then
+ mkdir -- "$OUTDIR"
+fi
+cd "$OUTDIR"
+
+TESTONLY="$1"
+if [ -n "$TESTONLY" ]; then shift; fi
+
+LOG="$DIR/log"
+date > "$LOG"
+
+exec 3< "$DIR/tests"
+set +e
+
+while read -u 3 TEST ARGS; do
+ if [ -z "$TESTONLY" -o "$TEST" = "$TESTONLY" ]; then
+# echo "$TEST ocaml"
+# rm -f -- "$OUTDIR"/*
+# "$SAILDIR/sail" -ocaml "$SAILDIR/lib/prelude.sail" "$SAILDIR/lib/prelude_wrappers.sail" "$DIR/$TEST" -o "$OUTDIR/testout" $ARGS
+# cp -- "$SAILDIR"/src/gen_lib/sail_values.ml .
+# cp -- "$DIR"/test.ml .
+# ocamlc -I "$ZARITH" "$ZARITH/zarith.cma" -dllpath "$ZARITH" -I "$LEMDIR/ocaml-lib" "$LEMDIR/ocaml-lib/extract.cma" -I "$SAILDIR/src/_build/lem_interp" "$SAILDIR/src/_build/lem_interp/extract.cma" sail_values.ml testout.ml test.ml -o test
+# ./test
+
+ echo "$TEST lem - ocaml" | tee -a -- "$LOG"
+ rm -f -- "$OUTDIR"/*
+ "$SAILDIR/sail" -lem -lem_sequential -lem_mwords "$SAILDIR/lib/prelude.sail" "$SAILDIR/lib/prelude_wrappers.sail" "$DIR/$TEST".sail -o "$OUTDIR/testout" $ARGS $@ &>> "$LOG" && \
+ "$LEMDIR/bin/lem" -ocaml -lib "$SAILDIR/src/lem_interp" "$SAILDIR/src/gen_lib/sail_values.lem" "$SAILDIR/src/gen_lib/sail_operators_mwords.lem" "$SAILDIR/src/gen_lib/state.lem" testout_embed_types_sequential.lem testout_embed_sequential.lem -outdir "$OUTDIR" &>> "$LOG" && \
+ cp -- "$DIR"/test.ml "$OUTDIR" && \
+ ocamlc -I "$ZARITH" "$ZARITH/zarith.cma" -dllpath "$ZARITH" -I "$LEMDIR/ocaml-lib" "$LEMDIR/ocaml-lib/extract.cma" -I "$SAILDIR/src/_build/lem_interp" "$SAILDIR/src/_build/lem_interp/extract.cma" sail_values.ml sail_operators_mwords.ml state.ml testout_embed_types_sequential.ml testout_embed_sequential.ml test.ml -o test &>> "$LOG" && \
+ ./test |& tee -a -- "$LOG" || \
+ (echo "Failed:"; echo; tail -- "$LOG"; echo; echo)
+ fi
+done
diff --git a/test/mono/tests b/test/mono/tests
new file mode 100644
index 00000000..0825c686
--- /dev/null
+++ b/test/mono/tests
@@ -0,0 +1,4 @@
+fnreduce -mono-split fnreduce.sail:43:x
+varmatch -mono-split varmatch.sail:7:x
+vector -mono-split vector.sail:7:sel
+union-exist -mono-split union-exist.sail:9:v
diff --git a/test/mono/union-exist.sail b/test/mono/union-exist.sail
new file mode 100644
index 00000000..74ab429a
--- /dev/null
+++ b/test/mono/union-exist.sail
@@ -0,0 +1,33 @@
+default Order dec
+
+typedef myunion = const union {
+ (exist 'n, 'n in {8,16}. ([:'n:],bit['n])) MyConstr;
+}
+
+val bit[2] -> myunion effect pure make
+
+function make(v) =
+ (* Can't mention these below without running into exp/nexp parsing conflict! *)
+ let eight = 8 in let sixteen = 16 in
+ switch v {
+ case 0b00 -> MyConstr( ( eight, 0x12) )
+ case 0b01 -> MyConstr( (sixteen,0x1234) )
+ case 0b10 -> MyConstr( ( eight, 0x56) )
+ case 0b11 -> MyConstr( (sixteen,0x5678) )
+ }
+
+val myunion -> bit[32] effect pure use
+
+function use(MyConstr('n)) = {
+ switch n {
+ case (n,v) -> extz(v)
+ }
+}
+val unit -> bool effect pure run
+
+function run () = {
+ use(make(0b00)) == 0x00000012 &
+ use(make(0b01)) == 0x00001234 &
+ use(make(0b10)) == 0x00000056 &
+ use(make(0b11)) == 0x00005678
+}
diff --git a/test/mono/varmatch.sail b/test/mono/varmatch.sail
new file mode 100644
index 00000000..7d2b9b73
--- /dev/null
+++ b/test/mono/varmatch.sail
@@ -0,0 +1,19 @@
+(* Check that when we case split on a variable that the constant propagation
+ handles the default case correctly. *)
+
+typedef AnEnum = enumerate { One; Two; Three }
+
+function AnEnum foo((AnEnum) x) = {
+ switch (x) {
+ case One -> Two
+ case y -> y
+ }
+}
+
+val unit -> bool effect pure run
+
+function run () = {
+ foo(One) == Two &
+ foo(Two) == Two &
+ foo(Three) == Three
+}
diff --git a/test/mono/vector.sail b/test/mono/vector.sail
new file mode 100644
index 00000000..03f36da5
--- /dev/null
+++ b/test/mono/vector.sail
@@ -0,0 +1,21 @@
+(* Check case splitting on a vector *)
+
+default Order dec
+
+val bit[32] -> nat effect pure test
+
+function nat test((bit[2]) sel : (bit[30]) _) = {
+ switch (sel) {
+ case 0b00 -> 5
+ case 0b10 -> 1
+ case _ -> 0
+ }
+}
+
+val unit -> bool effect pure run
+
+function run () = {
+ test(0x0f353533) == 5 &
+ test(0x84534656) == 1 &
+ test(0xf3463903) == 0
+}