summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--LICENCE (renamed from src/LICENCE)17
-rw-r--r--aarch64/main.sail2
-rw-r--r--aarch64/mono/aarch64_extras.lem6
-rw-r--r--aarch64/mono/aarch64_integer_crc.sail13
-rw-r--r--aarch64/mono/aarch64_memory_exclusive_pair.sail97
-rw-r--r--aarch64/mono/aarch64_memory_exclusive_single.sail97
-rw-r--r--aarch64/mono/extra_constraints.sail56
-rw-r--r--aarch64/no_vector.isail6
-rw-r--r--aarch64/no_vector/spec.sail2
-rw-r--r--aarch64/prelude.sail190
-rw-r--r--language/bytecode.ott2
-rw-r--r--lib/elf.sail10
-rw-r--r--lib/flow.sail2
-rw-r--r--lib/vector_dec.sail12
-rw-r--r--src/ast_util.ml18
-rw-r--r--src/ast_util.mli6
-rw-r--r--src/bitfield.ml40
-rw-r--r--src/c_backend.ml420
-rw-r--r--src/monomorphise.ml365
-rw-r--r--src/pattern_completeness.ml7
-rw-r--r--src/pretty_print_lem.ml42
-rw-r--r--src/pretty_print_sail.ml2
-rw-r--r--src/rewrites.ml1
-rw-r--r--src/sail.ml2
-rw-r--r--src/specialize.ml7
-rw-r--r--test/c/gvector.expect3
-rw-r--r--test/c/gvector.sail20
-rw-r--r--test/c/sail.h344
-rw-r--r--test/ocaml/bitfield/bitfield.sail9
-rw-r--r--test/ocaml/bitfield/expect2
30 files changed, 1520 insertions, 280 deletions
diff --git a/src/LICENCE b/LICENCE
index c777e037..451ce6c3 100644
--- a/src/LICENCE
+++ b/LICENCE
@@ -1,6 +1,13 @@
Sail
-Copyright (c) 2013-2017
+Sail and the Sail architecture models here, comprising all files and
+directories except the PPrint library, are subject to the BSD
+two-clause licence below.
+
+The PPrint library, in src/pprint, is subject to the CeCILL-C free
+software licence agreement therein.
+
+Copyright (c) 2013-2018
Kathyrn Gray
Shaked Flur
Stephen Kell
@@ -19,10 +26,10 @@ Copyright (c) 2013-2017
All rights reserved.
-This software was developed by the University of Cambridge Computer
-Laboratory and the University of Edinburgh as part of the Rigorous
-Engineering of Mainstream Systems (REMS) project, funded by EPSRC
-grant EP/K008528/1.
+This software was developed by the above within the Rigorous
+Engineering of Mainstream Systems (REMS) project, partly funded by
+EPSRC grant EP/K008528/1, at the Universities of Cambridge and
+Edinburgh.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions
diff --git a/aarch64/main.sail b/aarch64/main.sail
index fd590b2c..e9e2f84f 100644
--- a/aarch64/main.sail
+++ b/aarch64/main.sail
@@ -10,7 +10,7 @@ function fetch_and_execute () =
decode(instr);
} catch {
Error_See("HINT") => (),
- exn => throw(exn)
+ _ => exit(())
};
if __BranchTaken then __BranchTaken = false else _PC = _PC + 4
}
diff --git a/aarch64/mono/aarch64_extras.lem b/aarch64/mono/aarch64_extras.lem
index 4a32ad44..8af9ee5d 100644
--- a/aarch64/mono/aarch64_extras.lem
+++ b/aarch64/mono/aarch64_extras.lem
@@ -5,6 +5,12 @@ open import Sail_operators_mwords
open import Prompt_monad
open import State
+type ty512
+instance (Size ty512) let size = 512 end
+declare isabelle target_rep type ty512 = `512`
+type ty1024
+instance (Size ty1024) let size = 1024 end
+declare isabelle target_rep type ty1024 = `1024`
type ty2048
instance (Size ty2048) let size = 2048 end
declare isabelle target_rep type ty2048 = `2048`
diff --git a/aarch64/mono/aarch64_integer_crc.sail b/aarch64/mono/aarch64_integer_crc.sail
new file mode 100644
index 00000000..729a05a1
--- /dev/null
+++ b/aarch64/mono/aarch64_integer_crc.sail
@@ -0,0 +1,13 @@
+val aarch64_integer_crc : forall ('size : Int).
+ (bool, int, int, int, atom('size)) -> unit effect {escape, undef, rreg, wreg}
+
+function aarch64_integer_crc (crc32c, d, m, n, size) = {
+ assert(constraint('size in {8,16,32,64}));
+ if ~(HaveCRCExt()) then UnallocatedEncoding() else ();
+ acc : bits(32) = aget_X(n);
+ val_name : bits('size) = aget_X(m);
+ poly : bits(32) = __GetSlice_int(32, if crc32c then 517762881 else 79764919, 0);
+ tempacc : bits('size + 32) = BitReverse(acc) @ Zeros(size);
+ tempval : bits('size + 32) = BitReverse(val_name) @ Zeros(32);
+ aset_X(d, BitReverse(Poly32Mod2(tempacc ^ tempval, poly)))
+}
diff --git a/aarch64/mono/aarch64_memory_exclusive_pair.sail b/aarch64/mono/aarch64_memory_exclusive_pair.sail
new file mode 100644
index 00000000..27897eb3
--- /dev/null
+++ b/aarch64/mono/aarch64_memory_exclusive_pair.sail
@@ -0,0 +1,97 @@
+/* Moved and tightened an assertion */
+
+val aarch64_memory_exclusive_pair : forall ('datasize : Int) ('regsize : Int) ('elsize : Int).
+ (AccType, atom('datasize), atom('elsize), MemOp, int, bool, atom('regsize), int, int, int) -> unit effect {escape, undef, rreg, wreg, rmem, wmem}
+
+function aarch64_memory_exclusive_pair (acctype, datasize, elsize, memop, n, pair, regsize, s, t, t2) = {
+ assert(constraint('regsize >= 0), "regsize constraint");
+ let 'dbytes = ex_int(datasize / 8);
+ assert(constraint('datasize in {8, 16, 32, 64, 128}), "datasize constraint");
+ assert(constraint(8 * 'dbytes = 'datasize), "dbytes constraint");
+ address : bits(64) = undefined;
+ data : bits('datasize) = undefined;
+ rt_unknown : bool = false;
+ rn_unknown : bool = false;
+ if (memop == MemOp_LOAD & pair) & t == t2 then {
+ c : Constraint = ConstrainUnpredictable(Unpredictable_LDPOVERLAP);
+ assert(c == Constraint_UNKNOWN | c == Constraint_UNDEF | c == Constraint_NOP, "((c == Constraint_UNKNOWN) || ((c == Constraint_UNDEF) || (c == Constraint_NOP)))");
+ match c {
+ Constraint_UNKNOWN => rt_unknown = true,
+ Constraint_UNDEF => UnallocatedEncoding(),
+ Constraint_NOP => EndOfInstruction()
+ }
+ } else ();
+ if memop == MemOp_STORE then {
+ if s == t | pair & s == t2 then {
+ c : Constraint = ConstrainUnpredictable(Unpredictable_DATAOVERLAP);
+ assert(c == Constraint_UNKNOWN | c == Constraint_NONE | c == Constraint_UNDEF | c == Constraint_NOP, "((c == Constraint_UNKNOWN) || ((c == Constraint_NONE) || ((c == Constraint_UNDEF) || (c == Constraint_NOP))))");
+ match c {
+ Constraint_UNKNOWN => rt_unknown = true,
+ Constraint_NONE => rt_unknown = false,
+ Constraint_UNDEF => UnallocatedEncoding(),
+ Constraint_NOP => EndOfInstruction()
+ }
+ } else ();
+ if s == n & n != 31 then {
+ c : Constraint = ConstrainUnpredictable(Unpredictable_BASEOVERLAP);
+ assert(c == Constraint_UNKNOWN | c == Constraint_NONE | c == Constraint_UNDEF | c == Constraint_NOP, "((c == Constraint_UNKNOWN) || ((c == Constraint_NONE) || ((c == Constraint_UNDEF) || (c == Constraint_NOP))))");
+ match c {
+ Constraint_UNKNOWN => rn_unknown = true,
+ Constraint_NONE => rn_unknown = false,
+ Constraint_UNDEF => UnallocatedEncoding(),
+ Constraint_NOP => EndOfInstruction()
+ }
+ } else ()
+ } else ();
+ if n == 31 then {
+ CheckSPAlignment();
+ address = aget_SP()
+ } else if rn_unknown then address = undefined
+ else address = aget_X(n);
+ secondstage : bool = undefined;
+ iswrite : bool = undefined;
+ match memop {
+ MemOp_STORE => {
+ if rt_unknown then data = undefined
+ else if pair then let 'v = ex_int(datasize / 2) in {
+ assert(constraint(2 * 'v = 'datasize));
+ el1 : bits('v) = aget_X(t);
+ el2 : bits('v) = aget_X(t2);
+ data = if BigEndian() then el1 @ el2 else el2 @ el1
+ } else data = aget_X(t);
+ status : bits(1) = 0b1;
+ if AArch64_ExclusiveMonitorsPass(address, dbytes) then {
+ aset_Mem(address, dbytes, acctype, data);
+ status = ExclusiveMonitorsStatus()
+ } else ();
+ aset_X(s, ZeroExtend(status, 32))
+ },
+ MemOp_LOAD => {
+ AArch64_SetExclusiveMonitors(address, dbytes);
+ if pair then
+ if rt_unknown then aset_X(t, undefined : bits(32)) else if elsize == 32 then {
+ assert(constraint(- 'elsize + 'datasize > 0 & 'elsize >= 0), "datasize constraint");
+ data = aget_Mem(address, dbytes, acctype);
+ if BigEndian() then {
+ aset_X(t, slice(data, elsize, negate(elsize) + datasize));
+ aset_X(t2, slice(data, 0, elsize))
+ } else {
+ aset_X(t, slice(data, 0, elsize));
+ aset_X(t2, slice(data, elsize, negate(elsize) + datasize))
+ }
+ } else {
+ if address != Align(address, dbytes) then {
+ iswrite = false;
+ secondstage = false;
+ AArch64_Abort(address, AArch64_AlignmentFault(acctype, iswrite, secondstage))
+ } else ();
+ aset_X(t, aget_Mem(address + 0, 8, acctype));
+ aset_X(t2, aget_Mem(address + 8, 8, acctype))
+ }
+ else {
+ data = aget_Mem(address, dbytes, acctype);
+ aset_X(t, ZeroExtend(data, regsize))
+ }
+ }
+ }
+}
diff --git a/aarch64/mono/aarch64_memory_exclusive_single.sail b/aarch64/mono/aarch64_memory_exclusive_single.sail
new file mode 100644
index 00000000..a370794e
--- /dev/null
+++ b/aarch64/mono/aarch64_memory_exclusive_single.sail
@@ -0,0 +1,97 @@
+/* Changed an assertion to a strict inequality and moved it to where that's true */
+
+val aarch64_memory_exclusive_single : forall ('datasize : Int) 'elsize ('regsize : Int).
+ (AccType, atom('datasize), atom('elsize), MemOp, int, bool, atom('regsize), int, int, int) -> unit effect {escape, undef, rreg, wreg, rmem, wmem}
+
+function aarch64_memory_exclusive_single (acctype, datasize, elsize, memop, n, pair, regsize, s, t, t2) = {
+ assert(constraint('regsize >= 0), "destsize constraint");
+ let 'dbytes = ex_int(datasize / 8);
+ assert(constraint('datasize in {8, 16, 32, 64, 128}), "datasize constraint");
+ assert(constraint(8 * 'dbytes = 'datasize), "dbytes constraint");
+ address : bits(64) = undefined;
+ data : bits('datasize) = undefined;
+ rt_unknown : bool = false;
+ rn_unknown : bool = false;
+ if (memop == MemOp_LOAD & pair) & t == t2 then {
+ c : Constraint = ConstrainUnpredictable(Unpredictable_LDPOVERLAP);
+ assert(c == Constraint_UNKNOWN | c == Constraint_UNDEF | c == Constraint_NOP, "((c == Constraint_UNKNOWN) || ((c == Constraint_UNDEF) || (c == Constraint_NOP)))");
+ match c {
+ Constraint_UNKNOWN => rt_unknown = true,
+ Constraint_UNDEF => UnallocatedEncoding(),
+ Constraint_NOP => EndOfInstruction()
+ }
+ } else ();
+ if memop == MemOp_STORE then {
+ if s == t | pair & s == t2 then {
+ c : Constraint = ConstrainUnpredictable(Unpredictable_DATAOVERLAP);
+ assert(c == Constraint_UNKNOWN | c == Constraint_NONE | c == Constraint_UNDEF | c == Constraint_NOP, "((c == Constraint_UNKNOWN) || ((c == Constraint_NONE) || ((c == Constraint_UNDEF) || (c == Constraint_NOP))))");
+ match c {
+ Constraint_UNKNOWN => rt_unknown = true,
+ Constraint_NONE => rt_unknown = false,
+ Constraint_UNDEF => UnallocatedEncoding(),
+ Constraint_NOP => EndOfInstruction()
+ }
+ } else ();
+ if s == n & n != 31 then {
+ c : Constraint = ConstrainUnpredictable(Unpredictable_BASEOVERLAP);
+ assert(c == Constraint_UNKNOWN | c == Constraint_NONE | c == Constraint_UNDEF | c == Constraint_NOP, "((c == Constraint_UNKNOWN) || ((c == Constraint_NONE) || ((c == Constraint_UNDEF) || (c == Constraint_NOP))))");
+ match c {
+ Constraint_UNKNOWN => rn_unknown = true,
+ Constraint_NONE => rn_unknown = false,
+ Constraint_UNDEF => UnallocatedEncoding(),
+ Constraint_NOP => EndOfInstruction()
+ }
+ } else ()
+ } else ();
+ if n == 31 then {
+ CheckSPAlignment();
+ address = aget_SP()
+ } else if rn_unknown then address = undefined
+ else address = aget_X(n);
+ secondstage : bool = undefined;
+ iswrite : bool = undefined;
+ match memop {
+ MemOp_STORE => {
+ if rt_unknown then data = undefined
+ else if pair then let 'v = ex_int(datasize / 2) in {
+ assert(constraint(2 * 'v = 'datasize));
+ el1 : bits('v) = aget_X(t);
+ el2 : bits('v) = aget_X(t2);
+ data = if BigEndian() then el1 @ el2 else el2 @ el1
+ } else data = aget_X(t);
+ status : bits(1) = 0b1;
+ if AArch64_ExclusiveMonitorsPass(address, dbytes) then {
+ aset_Mem(address, dbytes, acctype, data);
+ status = ExclusiveMonitorsStatus()
+ } else ();
+ aset_X(s, ZeroExtend(status, 32))
+ },
+ MemOp_LOAD => {
+ AArch64_SetExclusiveMonitors(address, dbytes);
+ if pair then {
+ assert(constraint(- 'elsize + 'datasize > 0 & 'elsize >= 0));
+ if rt_unknown then aset_X(t, undefined : bits(32)) else if elsize == 32 then {
+ data = aget_Mem(address, dbytes, acctype);
+ if BigEndian() then {
+ aset_X(t, slice(data, elsize, negate(elsize) + datasize));
+ aset_X(t2, slice(data, 0, elsize))
+ } else {
+ aset_X(t, slice(data, 0, elsize));
+ aset_X(t2, slice(data, elsize, negate(elsize) + datasize))
+ }
+ } else {
+ if address != Align(address, dbytes) then {
+ iswrite = false;
+ secondstage = false;
+ AArch64_Abort(address, AArch64_AlignmentFault(acctype, iswrite, secondstage))
+ } else ();
+ aset_X(t, aget_Mem(address + 0, 8, acctype));
+ aset_X(t2, aget_Mem(address + 8, 8, acctype))
+ }
+ } else {
+ data = aget_Mem(address, dbytes, acctype);
+ aset_X(t, ZeroExtend(data, regsize))
+ }
+ }
+ }
+}
diff --git a/aarch64/mono/extra_constraints.sail b/aarch64/mono/extra_constraints.sail
new file mode 100644
index 00000000..2f89f401
--- /dev/null
+++ b/aarch64/mono/extra_constraints.sail
@@ -0,0 +1,56 @@
+/* Ideally we'd rewrite these to take the bit size */
+
+val aarch64_memory_literal_simdfp : forall ('size : Int).
+ (bits(64), atom('size), int) -> unit effect {escape, undef, wreg, rreg, rmem, wmem}
+
+function aarch64_memory_literal_simdfp (offset, size, t) = {
+ assert(constraint('size >= 0));
+ assert(constraint('size in {4,8,16}));
+ address : bits(64) = aget_PC() + offset;
+ data : bits(8 * 'size) = undefined;
+ CheckFPAdvSIMDEnabled64();
+ data = aget_Mem(address, size, AccType_VEC);
+ aset_V(t, data)
+}
+
+/* like this, which would be difficult otherwise... */
+
+val aarch64_memory_literal_general : forall ('size : Int).
+ (MemOp, bits(64), bool, atom('size), int) -> unit effect {escape, undef, wreg, rreg, rmem, wmem}
+
+function aarch64_memory_literal_general (memop, offset, signed, size, t) = {
+ address : bits(64) = aget_PC() + offset;
+ data : bits('size) = undefined;
+ match memop {
+ MemOp_LOAD => {
+ assert(constraint('size >= 0));
+ let 'bytes = size / 8;
+ assert(constraint(8 * 'bytes = 'size));
+ data = aget_Mem(address, bytes, AccType_NORMAL);
+ if signed then aset_X(t, SignExtend(data, 64)) else aset_X(t, data)
+ },
+ MemOp_PREFETCH => Prefetch(address, __GetSlice_int(5, t, 0))
+ }
+}
+
+val memory_literal_general_decode : (bits(2), bits(1), bits(19), bits(5)) -> unit effect {escape, rmem, rreg, undef, wmem, wreg}
+
+function memory_literal_general_decode (opc, V, imm19, Rt) = {
+ __unconditional = true;
+ t : int = UInt(Rt);
+ memop : MemOp = MemOp_LOAD;
+ signed : bool = false;
+ size : int = undefined;
+ offset : bits(64) = undefined;
+ match opc {
+ 0b00 => size = 4,
+ 0b01 => size = 8,
+ 0b10 => {
+ size = 4;
+ signed = true
+ },
+ 0b11 => memop = MemOp_PREFETCH
+ };
+ offset = SignExtend(imm19 @ 0b00, 64);
+ aarch64_memory_literal_general(memop, offset, signed, 8 * size, t)
+}
diff --git a/aarch64/no_vector.isail b/aarch64/no_vector.isail
index 15e29e26..aa0855d5 100644
--- a/aarch64/no_vector.isail
+++ b/aarch64/no_vector.isail
@@ -1,4 +1,8 @@
:unload
:load prelude.sail no_vector/spec.sail decode_start.sail no_vector/decode.sail decode_end.sail main.sail
initialize_registers()
-:run \ No newline at end of file
+:run
+:elf ../bench.elf
+main()
+:run
+:q \ No newline at end of file
diff --git a/aarch64/no_vector/spec.sail b/aarch64/no_vector/spec.sail
index 6edec31c..21448bc1 100644
--- a/aarch64/no_vector/spec.sail
+++ b/aarch64/no_vector/spec.sail
@@ -1370,7 +1370,7 @@ function AArch64_SysRegWrite ('op0, 'op1, 'crn, 'crm, 'op2, val_name) = assert(f
val AArch64_SysRegRead : (int, int, int, int, int) -> bits(64) effect {escape, undef}
-function AArch64_SysRegRead _ = {
+function AArch64_SysRegRead (_, _, _, _, _) = {
assert(false, "Tried to read system register");
undefined
}
diff --git a/aarch64/prelude.sail b/aarch64/prelude.sail
index b4c59fef..4c6b7974 100644
--- a/aarch64/prelude.sail
+++ b/aarch64/prelude.sail
@@ -5,65 +5,89 @@ $include <flow.sail>
type bits ('n : Int) = vector('n, dec, bit)
-val eq_bit = {ocaml: "(fun (x, y) -> x = y)", lem: "eq", interpreter: "eq_anything"} : (bit, bit) -> bool
+val eq_bit = {ocaml: "(fun (x, y) -> x = y)", lem: "eq", interpreter: "eq_anything", c: "eq_bit"} : (bit, bit) -> bool
-val eq_vec = {ocaml: "eq_list", lem: "eq_vec"} : forall 'n. (bits('n), bits('n)) -> bool
+val eq_vec = {ocaml: "eq_list", lem: "eq_vec", c: "eq_bits"} : forall 'n. (bits('n), bits('n)) -> bool
-val eq_string = {ocaml: "eq_string", lem: "eq"} : (string, string) -> bool
+val eq_string = {ocaml: "eq_string", lem: "eq", c: "eq_string"} : (string, string) -> bool
-val eq_real = {ocaml: "eq_real", lem: "eq"} : (real, real) -> bool
+val eq_real = {ocaml: "eq_real", lem: "eq", c: "eq_real"} : (real, real) -> bool
val eq_anything = {
ocaml: "(fun (x, y) -> x = y)",
interpreter: "eq_anything",
- lem: "eq"
+ lem: "eq",
+ c: "eq_anything"
} : forall ('a : Type). ('a, 'a) -> bool
-val bitvector_length = {ocaml: "length", lem: "length"} : forall 'n. bits('n) -> atom('n)
-val vector_length = {ocaml: "length", lem: "length_list"} : forall 'n ('a : Type). vector('n, dec, 'a) -> atom('n)
-val list_length = {ocaml: "length", lem: "length_list"} : forall ('a : Type). list('a) -> int
+val bitvector_length = "length" : forall 'n. bits('n) -> atom('n)
+val vector_length = {ocaml: "length", lem: "length_list", c: "length"} : forall 'n ('a : Type). vector('n, dec, 'a) -> atom('n)
+val list_length = {ocaml: "length", lem: "length_list", c: "length"} : forall ('a : Type). list('a) -> int
overload length = {bitvector_length, vector_length, list_length}
overload operator == = {eq_bit, eq_vec, eq_string, eq_real, eq_anything}
-val vector_subrange_A = {ocaml: "subrange", lem: "subrange_vec_dec"} : forall ('n : Int) ('m : Int) ('o : Int), 'o <= 'm <= 'n.
+val vector_subrange_A = {
+ ocaml: "subrange",
+ lem: "subrange_vec_dec",
+ c: "vector_subrange"
+} : forall ('n : Int) ('m : Int) ('o : Int), 'o <= 'm <= 'n.
(bits('n), atom('m), atom('o)) -> bits('m - ('o - 1))
-val vector_subrange_B = {ocaml: "subrange", lem: "subrange_vec_dec"} : forall ('n : Int) ('m : Int) ('o : Int).
+val vector_subrange_B = {
+ ocaml: "subrange",
+ lem: "subrange_vec_dec",
+ c: "vector_subrange"
+} : forall ('n : Int) ('m : Int) ('o : Int).
(bits('n), atom('m), atom('o)) -> bits('m - ('o - 1))
overload vector_subrange = {vector_subrange_A, vector_subrange_B}
-val bitvector_access_A = {ocaml: "access", lem: "access_vec_dec"} : forall ('n : Int) ('m : Int), 0 <= 'm < 'n.
- (bits('n), atom('m)) -> bit
-
-val bitvector_access_B = {ocaml: "access", lem: "access_vec_dec"} : forall ('n : Int).
- (bits('n), int) -> bit
-
-val vector_access_A = {ocaml: "access", lem: "access_list_dec"} : forall ('n : Int) ('m : Int) ('a : Type), 0 <= 'm < 'n.
- (vector('n, dec, 'a), atom('m)) -> 'a
-
-val vector_access_B = {ocaml: "access", lem: "access_list_dec"} : forall ('n : Int) ('a : Type).
- (vector('n, dec, 'a), int) -> 'a
+val bitvector_access_A = {
+ ocaml: "access",
+ lem: "access_vec_dec",
+ c: "vector_access"
+} : forall ('n : Int) ('m : Int), 0 <= 'm < 'n. (bits('n), atom('m)) -> bit
+
+val bitvector_access_B = {
+ ocaml: "access",
+ lem: "access_vec_dec",
+ c: "vector_access"
+} : forall ('n : Int). (bits('n), int) -> bit
+
+val vector_access_A = {
+ ocaml: "access",
+ lem: "access_list_dec",
+ c: "vector_access"
+} : forall ('n : Int) ('m : Int) ('a : Type), 0 <= 'm < 'n. (vector('n, dec, 'a), atom('m)) -> 'a
+
+val vector_access_B = {
+ ocaml: "access",
+ lem: "access_list_dec",
+ c: "vector_access"
+} : forall ('n : Int) ('a : Type). (vector('n, dec, 'a), int) -> 'a
overload vector_access = {bitvector_access_A, bitvector_access_B, vector_access_A, vector_access_B}
-val bitvector_update_B = {ocaml: "update", lem: "update_vec_dec"} : forall 'n.
+val bitvector_update_B = {ocaml: "update", lem: "update_vec_dec", c: "vector_update"} : forall 'n.
(bits('n), int, bit) -> bits('n)
-val vector_update_B = {ocaml: "update", lem: "update_list_dec"} : forall 'n ('a : Type).
+val vector_update_B = {ocaml: "update", lem: "update_list_dec", c: "vector_update"} : forall 'n ('a : Type).
(vector('n, dec, 'a), int, 'a) -> vector('n, dec, 'a)
overload vector_update = {bitvector_update_B, vector_update_B}
-val vector_update_subrange = {ocaml: "update_subrange", lem: "update_subrange_vec_dec"} : forall 'n 'm 'o.
- (bits('n), atom('m), atom('o), bits('m - ('o - 1))) -> bits('n)
+val vector_update_subrange = {
+ ocaml: "update_subrange",
+ lem: "update_subrange_vec_dec",
+ c: "vector_update_subrange"
+} : forall 'n 'm 'o. (bits('n), atom('m), atom('o), bits('m - ('o - 1))) -> bits('n)
val vcons : forall ('n : Int) ('a : Type).
('a, vector('n, dec, 'a)) -> vector('n + 1, dec, 'a)
-val bitvector_concat = {ocaml: "append", lem: "concat_vec"} : forall ('n : Int) ('m : Int).
+val bitvector_concat = {ocaml: "append", lem: "concat_vec", c: "append"} : forall ('n : Int) ('m : Int).
(bits('n), bits('m)) -> bits('n + 'm)
val vector_concat = {ocaml: "append", lem: "append_list"} : forall ('n : Int) ('m : Int) ('a : Type).
@@ -71,7 +95,11 @@ val vector_concat = {ocaml: "append", lem: "append_list"} : forall ('n : Int) ('
overload append = {bitvector_concat, vector_concat}
-val not_vec = "not_vec" : forall 'n. bits('n) -> bits('n)
+val not_vec = {
+ ocaml: "not_vec",
+ lem: "not_vec",
+ c: "not_bits"
+} : forall 'n. bits('n) -> bits('n)
overload ~ = {not_bool, not_vec}
@@ -89,23 +117,28 @@ function neq_anything (x, y) = not_bool(x == y)
overload operator != = {neq_atom, neq_int, neq_vec, neq_anything}
-val builtin_and_vec = {ocaml: "and_vec"} : forall 'n. (bits('n), bits('n)) -> bits('n)
+val builtin_and_vec = {ocaml: "and_vec", c: "and_bits"} : forall 'n. (bits('n), bits('n)) -> bits('n)
-val and_vec = {lem: "and_vec"} : forall 'n. (bits('n), bits('n)) -> bits('n)
+val and_vec = {lem: "and_vec", c: "and_bits"} : forall 'n. (bits('n), bits('n)) -> bits('n)
function and_vec (xs, ys) = builtin_and_vec(xs, ys)
overload operator & = {and_bool, and_vec}
-val builtin_or_vec = {ocaml: "or_vec"} : forall 'n. (bits('n), bits('n)) -> bits('n)
+val builtin_or_vec = {ocaml: "or_vec", c: "or_bits"} : forall 'n. (bits('n), bits('n)) -> bits('n)
-val or_vec = {lem: "or_vec"}: forall 'n. (bits('n), bits('n)) -> bits('n)
+val or_vec = {lem: "or_vec", c: "or_bits"}: forall 'n. (bits('n), bits('n)) -> bits('n)
function or_vec (xs, ys) = builtin_or_vec(xs, ys)
overload operator | = {or_bool, or_vec}
-val UInt = "uint" : forall 'n. bits('n) -> range(0, 2 ^ 'n - 1)
+val UInt = {
+ ocaml: "uint",
+ lem: "uint",
+ interpreter: "uint",
+ c: "sail_uint"
+} : forall 'n. bits('n) -> range(0, 2 ^ 'n - 1)
val SInt = "sint" : forall 'n. bits('n) -> range(- (2 ^ ('n - 1)), 2 ^ ('n - 1) - 1)
@@ -146,7 +179,12 @@ function cast_unit_vec b =
val print = "prerr_endline" : string -> unit
-val putchar = "putchar" : forall ('a : Type). 'a -> unit
+val putchar = {
+ ocaml: "putchar",
+ lem: "putchar",
+ interpreter: "putchar",
+ c: "sail_putchar"
+} : int -> unit
val concat_str = {ocaml: "concat_str", lem: "stringAppend"} : (string, string) -> string
@@ -156,78 +194,94 @@ val HexStr : int -> string
val BitStr = "string_of_bits" : forall 'n. bits('n) -> string
-val xor_vec = "xor_vec" : forall 'n. (bits('n), bits('n)) -> bits('n)
+val xor_vec = {
+ ocaml: "xor_vec",
+ lem: "xor_vec",
+ c: "xor_bits"
+} : forall 'n. (bits('n), bits('n)) -> bits('n)
val int_power = {lem: "pow"} : (int, int) -> int
-val real_power = {ocaml: "real_power", lem: "realPowInteger"} : (real, int) -> real
+val real_power = {ocaml: "real_power", lem: "realPowInteger", c: "real_power"} : (real, int) -> real
overload operator ^ = {xor_vec, int_power, real_power}
-val add_range = {ocaml: "add_int", lem: "integerAdd"} : forall 'n 'm 'o 'p.
+val add_range = {ocaml: "add_int", lem: "integerAdd", c: "add_int"} : forall 'n 'm 'o 'p.
(range('n, 'm), range('o, 'p)) -> range('n + 'o, 'm + 'p)
-val add_int = {ocaml: "add_int", lem: "integerAdd"} : (int, int) -> int
+val add_int = {ocaml: "add_int", lem: "integerAdd", c: "add_int"} : (int, int) -> int
-val add_vec = "add_vec" : forall 'n. (bits('n), bits('n)) -> bits('n)
+val add_vec = {
+ ocaml: "add_vec",
+ lem: "add_vec",
+ c: "add_bits"
+} : forall 'n. (bits('n), bits('n)) -> bits('n)
-val add_vec_int = "add_vec_int" : forall 'n. (bits('n), int) -> bits('n)
+val add_vec_int = {
+ ocaml: "add_vec_int",
+ lem: "add_vec_int",
+ c: "add_bits_int"
+} : forall 'n. (bits('n), int) -> bits('n)
-val add_real = {ocaml: "add_real", lem: "realAdd"} : (real, real) -> real
+val add_real = {ocaml: "add_real", lem: "realAdd", c: "add_real"} : (real, real) -> real
overload operator + = {add_range, add_int, add_vec, add_vec_int, add_real}
-val sub_range = {ocaml: "sub_int", lem: "integerMinus"} : forall 'n 'm 'o 'p.
+val sub_range = {ocaml: "sub_int", lem: "integerMinus", c: "sub_int"} : forall 'n 'm 'o 'p.
(range('n, 'm), range('o, 'p)) -> range('n - 'p, 'm - 'o)
-val sub_int = {ocaml: "sub_int", lem: "integerMinus"} : (int, int) -> int
+val sub_int = {ocaml: "sub_int", lem: "integerMinus", c: "sub_int"} : (int, int) -> int
val "sub_vec" : forall 'n. (bits('n), bits('n)) -> bits('n)
-val "sub_vec_int" : forall 'n. (bits('n), int) -> bits('n)
+val sub_vec_int = {
+ ocaml: "sub_vec_int",
+ lem: "sub_vec_int",
+ c: "sub_bits_int"
+} : forall 'n. (bits('n), int) -> bits('n)
-val sub_real = {ocaml: "sub_real", lem: "realMinus"} : (real, real) -> real
+val sub_real = {ocaml: "sub_real", lem: "realMinus", c: "sub_real"} : (real, real) -> real
-val negate_range = {ocaml: "negate", lem: "integerNegate"} : forall 'n 'm. range('n, 'm) -> range(- 'm, - 'n)
+val negate_range = {ocaml: "negate", lem: "integerNegate", c: "neg_int"} : forall 'n 'm. range('n, 'm) -> range(- 'm, - 'n)
-val negate_int = {ocaml: "negate", lem: "integerNegate"} : int -> int
+val negate_int = {ocaml: "negate", lem: "integerNegate", c: "neg_int"} : int -> int
-val negate_real = {ocaml: "negate_real", lem: "realNegate"} : real -> real
+val negate_real = {ocaml: "negate_real", lem: "realNegate", c: "neg_real"} : real -> real
overload operator - = {sub_range, sub_int, sub_vec, sub_vec_int, sub_real}
overload negate = {negate_range, negate_int, negate_real}
-val mult_range = {ocaml: "mult", lem: "integerMult"} : forall 'n 'm 'o 'p.
+val mult_range = {ocaml: "mult", lem: "integerMult", c: "mult_int"} : forall 'n 'm 'o 'p.
(range('n, 'm), range('o, 'p)) -> range('n * 'o, 'm * 'p)
-val mult_int = {ocaml: "mult", lem: "integerMult"} : (int, int) -> int
+val mult_int = {ocaml: "mult", lem: "integerMult", c: "mult_int"} : (int, int) -> int
-val mult_real = {ocaml: "mult_real", lem: "realMult"} : (real, real) -> real
+val mult_real = {ocaml: "mult_real", lem: "realMult", c: "mult_real"} : (real, real) -> real
overload operator * = {mult_range, mult_int, mult_real}
-val Sqrt = {ocaml: "sqrt_real", lem: "realSqrt"} : real -> real
+val Sqrt = {ocaml: "sqrt_real", lem: "realSqrt", c: "sqrt_real"} : real -> real
-val gteq_real = {ocaml: "gteq_real", lem: "gteq"} : (real, real) -> bool
+val gteq_real = {ocaml: "gteq_real", lem: "gteq", c: "gteq_real"} : (real, real) -> bool
overload operator >= = {gteq_atom, gteq_int, gteq_real}
-val lteq_real = {ocaml: "lteq_real", lem: "lteq"} : (real, real) -> bool
+val lteq_real = {ocaml: "lteq_real", lem: "lteq", c: "lteq_real"} : (real, real) -> bool
overload operator <= = {lteq_atom, lteq_int, lteq_real}
-val gt_real = {ocaml: "gt_real", lem: "gt"} : (real, real) -> bool
+val gt_real = {ocaml: "gt_real", lem: "gt", c: "gt_real"} : (real, real) -> bool
overload operator > = {gt_atom, gt_int, gt_real}
-val lt_real = {ocaml: "lt_real", lem: "lt"} : (real, real) -> bool
+val lt_real = {ocaml: "lt_real", lem: "lt", c: "lt_real"} : (real, real) -> bool
overload operator < = {lt_atom, lt_int, lt_real}
-val RoundDown = {ocaml: "round_down", lem: "realFloor"} : real -> int
+val RoundDown = {ocaml: "round_down", lem: "realFloor", c: "round_down"} : real -> int
-val RoundUp = {ocaml: "round_up", lem: "realCeiling"} : real -> int
+val RoundUp = {ocaml: "round_up", lem: "realCeiling", c: "round_up"} : real -> int
val abs_int = "abs_int" : int -> int
@@ -235,31 +289,31 @@ val abs_real = "abs_real" : real -> real
overload abs = {abs_atom, abs_int, abs_real}
-val quotient_nat = {ocaml: "quotient", lem: "integerDiv"} : (nat, nat) -> nat
+val quotient_nat = {ocaml: "quotient", lem: "integerDiv", c: "div_int"} : (nat, nat) -> nat
-val quotient_real = {ocaml: "quotient_real", lem: "realDiv"} : (real, real) -> real
+val quotient_real = {ocaml: "quotient_real", lem: "realDiv", c: "div_real"} : (real, real) -> real
-val quotient = {ocaml: "quotient", lem: "integerDiv"} : (int, int) -> int
+val quotient = {ocaml: "quotient", lem: "integerDiv", c: "div_int"} : (int, int) -> int
overload operator / = {quotient_nat, quotient, quotient_real}
-val modulus = {ocaml: "modulus", lem: "hardware_mod"} : (int, int) -> int
+val modulus = {ocaml: "modulus", lem: "hardware_mod", c: "mod_int"} : (int, int) -> int
overload operator % = {modulus}
-val Real = {ocaml: "to_real", lem: "realFromInteger"} : int -> real
+val Real = {ocaml: "to_real", lem: "realFromInteger", c: "to_real"} : int -> real
val shl_int = "shl_int" : (int, int) -> int
val shr_int = "shr_int" : (int, int) -> int
-val min_nat = {ocaml: "min_int", lem: "min"} : (nat, nat) -> nat
+val min_nat = {ocaml: "min_int", lem: "min", c: "max_int"} : (nat, nat) -> nat
-val min_int = {ocaml: "min_int", lem: "min"} : (int, int) -> int
+val min_int = {ocaml: "min_int", lem: "min", c: "max_int"} : (int, int) -> int
-val max_nat = {ocaml: "max_int", lem: "max"} : (nat, nat) -> nat
+val max_nat = {ocaml: "max_int", lem: "max", c: "max_int"} : (nat, nat) -> nat
-val max_int = {ocaml: "max_int", lem: "max"} : (int, int) -> int
+val max_int = {ocaml: "max_int", lem: "max", c: "max_int"} : (int, int) -> int
overload min = {min_nat, min_int}
@@ -273,7 +327,7 @@ val __TraceMemoryWrite : forall 'n 'm.
val __InitRAM : forall 'm. (atom('m), int, bits('m), bits(8)) -> unit
-function __InitRAM _ = ()
+function __InitRAM (_, _, _, _) = ()
val __ReadRAM = "read_ram" : forall 'n 'm.
(atom('m), atom('n), bits('m), bits('m)) -> bits(8 * 'n) effect {rmem}
@@ -296,7 +350,7 @@ val ex_range : forall 'n 'm.
val coerce_int_nat : int -> nat effect {escape}
function coerce_int_nat 'x = {
- assert(constraint('x >= 0));
+ assert(constraint('x >= 0), "Cannot coerce int to nat");
x
}
diff --git a/language/bytecode.ott b/language/bytecode.ott
index e909fc09..e0d7db24 100644
--- a/language/bytecode.ott
+++ b/language/bytecode.ott
@@ -140,6 +140,8 @@ cdef :: 'CDEF_' ::=
} :: :: let
% The first list of instructions creates up the global letbinding, the
% second kills it.
+ | val id ( ctyp0 , ... , ctypn ) -> ctyp
+ :: :: spec
| function id mid ( id0 , ... , idn ) {
instr0 ; ... ; instrm
} :: :: fundef
diff --git a/lib/elf.sail b/lib/elf.sail
index f158fbad..e953839d 100644
--- a/lib/elf.sail
+++ b/lib/elf.sail
@@ -1,8 +1,14 @@
$ifndef _ELF
$define _ELF
-val elf_entry = "Elf_loader.elf_entry" : unit -> int
+val elf_entry = {
+ ocaml: "Elf_loader.elf_entry",
+ c: "elf_entry"
+} : unit -> int
-val elf_tohost = "Elf_loader.elf_tohost" : unit -> int
+val elf_tohost = {
+ ocaml: "Elf_loader.elf_tohost",
+ c: "elf_tohost"
+} : unit -> int
$endif
diff --git a/lib/flow.sail b/lib/flow.sail
index 8c902803..1a0e0f2f 100644
--- a/lib/flow.sail
+++ b/lib/flow.sail
@@ -34,6 +34,8 @@ val lt_int = "lt" : (int, int) -> bool
val gt_int = "lt" : (int, int) -> bool
overload operator == = {eq_atom, eq_range, eq_int}
+overload operator | = {or_bool}
+overload operator & = {and_bool}
$ifdef TEST
diff --git a/lib/vector_dec.sail b/lib/vector_dec.sail
index e24f5111..8a55ed61 100644
--- a/lib/vector_dec.sail
+++ b/lib/vector_dec.sail
@@ -13,6 +13,18 @@ val "zero_extend" : forall 'n 'm, 'm >= 'n. (bits('n), atom('m)) -> bits('m)
/* Used for creating long bitvector literals in the C backend. */
val "append_64" : forall 'n. (bits('n), bits(64)) -> bits('n + 64)
+val vector_access = {
+ ocaml: "access",
+ lem: "access_list_dec",
+ c: "vector_access"
+} : forall ('n : Int) ('m : Int) ('a : Type), 0 <= 'm < 'n. (vector('n, dec, 'a), atom('m)) -> 'a
+
+val vector_update = {
+ ocaml: "update",
+ lem: "update_list_dec",
+ c: "vector_update"
+} : forall 'n ('a : Type). (vector('n, dec, 'a), int, 'a) -> vector('n, dec, 'a)
+
val add_bits = {
ocaml: "add_vec",
c: "add_bits"
diff --git a/src/ast_util.ml b/src/ast_util.ml
index 27ae93e8..5bbf9a40 100644
--- a/src/ast_util.ml
+++ b/src/ast_util.ml
@@ -182,6 +182,7 @@ module IdSet = Set.Make(Id)
module KBindings = Map.Make(Kid)
module KidSet = Set.Make(Kid)
module NexpSet = Set.Make(Nexp)
+module NexpMap = Map.Make(Nexp)
let rec nexp_identical nexp1 nexp2 = (Nexp.compare nexp1 nexp2 = 0)
@@ -235,6 +236,14 @@ and nexp_simp_aux = function
when Big_int.equal c1 (Big_int.negate c2) -> n
| _, _ -> Nexp_minus (n1, n2)
end
+ | Nexp_app (Id_aux (Id "div",_) as id,[n1;n2]) ->
+ begin
+ let (Nexp_aux (n1_simp, _) as n1) = nexp_simp n1 in
+ let (Nexp_aux (n2_simp, _) as n2) = nexp_simp n2 in
+ match n1_simp, n2_simp with
+ | Nexp_constant c1, Nexp_constant c2 -> Nexp_constant (Big_int.div c1 c2)
+ | _, _ -> Nexp_app (id,[n1;n2])
+ end
| nexp -> nexp
let mk_typ typ = Typ_aux (typ, Parse_ast.Unknown)
@@ -483,7 +492,7 @@ let append_id id str =
match id with
| Id_aux (Id v, l) -> Id_aux (Id (v ^ str), l)
| Id_aux (DeIid v, l) -> Id_aux (DeIid (v ^ str), l)
-
+
let prepend_kid str = function
| Kid_aux (Var v, l) -> Kid_aux (Var ("'" ^ str ^ String.sub v 1 (String.length v - 1)), l)
@@ -1098,3 +1107,10 @@ and subst_lexp id value (LEXP_aux (lexp_aux, annot) as lexp) =
| LEXP_field (lexp, id') -> LEXP_field (subst_lexp id value lexp, id')
in
wrap lexp_aux
+
+let hex_to_bin hex =
+ Util.string_to_list hex
+ |> List.map Sail_lib.hex_char
+ |> List.concat
+ |> List.map Sail_lib.char_of_bit
+ |> (fun bits -> String.init (List.length bits) (List.nth bits))
diff --git a/src/ast_util.mli b/src/ast_util.mli
index bbbde27f..9f815899 100644
--- a/src/ast_util.mli
+++ b/src/ast_util.mli
@@ -248,6 +248,10 @@ module NexpSet : sig
include Set.S with type elt = nexp
end
+module NexpMap : sig
+ include Map.S with type key = nexp
+end
+
module BESet : sig
include Set.S with type elt = base_effect
end
@@ -316,3 +320,5 @@ val ids_of_defs : 'a defs -> IdSet.t
val pat_ids : 'a pat -> IdSet.t
val subst : id -> 'a exp -> 'a exp -> 'a exp
+
+val hex_to_bin : string -> string
diff --git a/src/bitfield.ml b/src/bitfield.ml
index 67a26b89..391a653d 100644
--- a/src/bitfield.ml
+++ b/src/bitfield.ml
@@ -92,18 +92,18 @@ let full_accessor name size order =
combine [full_getter name size order; full_setter name size order; full_overload name order]
(* For every index range, create a getter and setter *)
-let index_range_getter' name field order start stop =
+let index_range_getter name field order start stop =
let size = if start > stop then start - (stop - 1) else stop - (start - 1) in
- let irg_val = Printf.sprintf "val _get_%s : %s -> %s" field name (bitvec size order) in
- let irg_function = Printf.sprintf "function _get_%s Mk_%s(v) = v[%i .. %i]" field name start stop in
+ let irg_val = Printf.sprintf "val _get_%s_%s : %s -> %s" name field name (bitvec size order) in
+ let irg_function = Printf.sprintf "function _get_%s_%s Mk_%s(v) = v[%i .. %i]" name field name start stop in
combine [ast_of_def_string order irg_val; ast_of_def_string order irg_function]
-let index_range_setter' name field order start stop =
+let index_range_setter name field order start stop =
let size = if start > stop then start - (stop - 1) else stop - (start - 1) in
- let irs_val = Printf.sprintf "val _set_%s : (register(%s), %s) -> unit effect {wreg}" field name (bitvec size order) in
+ let irs_val = Printf.sprintf "val _set_%s_%s : (register(%s), %s) -> unit effect {wreg}" name field name (bitvec size order) in
(* Read-modify-write using an internal _reg_deref function without rreg effect *)
let irs_function = String.concat "\n"
- [ Printf.sprintf "function _set_%s (r_ref, v) = {" field;
+ [ Printf.sprintf "function _set_%s_%s (r_ref, v) = {" name field;
Printf.sprintf " r = _get_%s(_reg_deref(r_ref));" name;
Printf.sprintf " r[%i .. %i] = v;" start stop;
Printf.sprintf " (*r_ref) = Mk_%s(r)" name;
@@ -112,16 +112,30 @@ let index_range_setter' name field order start stop =
in
combine [ast_of_def_string order irs_val; ast_of_def_string order irs_function]
-let index_range_overload field order =
- ast_of_def_string order (Printf.sprintf "overload _mod_%s = {_get_%s, _set_%s}" field field field)
+let index_range_update name field order start stop =
+ let size = if start > stop then start - (stop - 1) else stop - (start - 1) in
+ let iru_val = Printf.sprintf "val _update_%s_%s : (%s, %s) -> %s" name field name (bitvec size order) name in
+ (* Read-modify-write using an internal _reg_deref function without rreg effect *)
+ let iru_function = String.concat "\n"
+ [ Printf.sprintf "function _update_%s_%s (Mk_%s(v), x) = {" name field name;
+ Printf.sprintf " Mk_%s([v with %i .. %i = x]);" name start stop;
+ "}"
+ ]
+ in
+ let iru_overload = Printf.sprintf "overload update_%s = {_update_%s_%s}" field name field in
+ combine [ast_of_def_string order iru_val; ast_of_def_string order iru_function; ast_of_def_string order iru_overload]
+
+let index_range_overload name field order =
+ ast_of_def_string order (Printf.sprintf "overload _mod_%s = {_get_%s_%s, _set_%s_%s}" field name field name field)
let index_range_accessor name field order (BF_aux (bf_aux, l)) =
- let getter n m = index_range_getter' name field order (Big_int.to_int n) (Big_int.to_int m) in
- let setter n m = index_range_setter' name field order (Big_int.to_int n) (Big_int.to_int m) in
- let overload = index_range_overload field order in
+ let getter n m = index_range_getter name field order (Big_int.to_int n) (Big_int.to_int m) in
+ let setter n m = index_range_setter name field order (Big_int.to_int n) (Big_int.to_int m) in
+ let update n m = index_range_update name field order (Big_int.to_int n) (Big_int.to_int m) in
+ let overload = index_range_overload name field order in
match bf_aux with
- | BF_single n -> combine [getter n n; setter n n; overload]
- | BF_range (n, m) -> combine [getter n m; setter n m; overload]
+ | BF_single n -> combine [getter n n; setter n n; update n n; overload]
+ | BF_range (n, m) -> combine [getter n m; setter n m; update n m; overload]
| BF_concat _ -> failwith "Unimplemented"
let field_accessor name order (id, ir) = index_range_accessor name (string_of_id id) order ir
diff --git a/src/c_backend.ml b/src/c_backend.ml
index 77f1b39f..fa1f2b5e 100644
--- a/src/c_backend.ml
+++ b/src/c_backend.ml
@@ -158,6 +158,107 @@ and aval =
| AV_record of aval Bindings.t * typ
| AV_C_fragment of fragment * typ
+(* Renaming variables in ANF expressions *)
+
+let rec frag_rename from_id to_id = function
+ | F_id id when Id.compare id from_id = 0 -> F_id to_id
+ | F_id id -> F_id id
+ | F_lit str -> F_lit str
+ | F_have_exception -> F_have_exception
+ | F_current_exception -> F_current_exception
+ | F_op (f1, op, f2) -> F_op (frag_rename from_id to_id f1, op, frag_rename from_id to_id f2)
+ | F_unary (op, f) -> F_unary (op, frag_rename from_id to_id f)
+ | F_field (f, field) -> F_field (frag_rename from_id to_id f, field)
+
+let rec apat_bindings = function
+ | AP_tup apats -> List.fold_left IdSet.union IdSet.empty (List.map apat_bindings apats)
+ | AP_id id -> IdSet.singleton id
+ | AP_global (id, typ) -> IdSet.empty
+ | AP_app (id, apat) -> apat_bindings apat
+ | AP_cons (apat1, apat2) -> IdSet.union (apat_bindings apat1) (apat_bindings apat2)
+ | AP_nil -> IdSet.empty
+ | AP_wild -> IdSet.empty
+
+let rec aval_rename from_id to_id = function
+ | AV_lit (lit, typ) -> AV_lit (lit, typ)
+ | AV_id (id, lvar) when Id.compare id from_id = 0 -> AV_id (to_id, lvar)
+ | AV_id (id, lvar) -> AV_id (id, lvar)
+ | AV_ref (id, lvar) when Id.compare id from_id = 0 -> AV_ref (to_id, lvar)
+ | AV_ref (id, lvar) -> AV_ref (id, lvar)
+ | AV_tuple avals -> AV_tuple (List.map (aval_rename from_id to_id) avals)
+ | AV_list (avals, typ) -> AV_list (List.map (aval_rename from_id to_id) avals, typ)
+ | AV_vector (avals, typ) -> AV_vector (List.map (aval_rename from_id to_id) avals, typ)
+ | AV_record (avals, typ) -> AV_record (Bindings.map (aval_rename from_id to_id) avals, typ)
+ | AV_C_fragment (fragment, typ) -> AV_C_fragment (frag_rename from_id to_id fragment, typ)
+
+let rec aexp_rename from_id to_id aexp =
+ let recur = aexp_rename from_id to_id in
+ match aexp with
+ | AE_val aval -> AE_val (aval_rename from_id to_id aval)
+ | AE_app (id, avals, typ) -> AE_app (id, List.map (aval_rename from_id to_id) avals, typ)
+ | AE_cast (aexp, typ) -> AE_cast (recur aexp, typ)
+ | AE_assign (id, typ, aexp) when Id.compare from_id id = 0 -> AE_assign (to_id, typ, aexp)
+ | AE_assign (id, typ, aexp) -> AE_assign (id, typ, aexp)
+ | AE_let (id, typ1, aexp1, aexp2, typ2) when Id.compare from_id id = 0 -> AE_let (id, typ1, aexp1, aexp2, typ2)
+ | AE_let (id, typ1, aexp1, aexp2, typ2) -> AE_let (id, typ1, recur aexp1, recur aexp2, typ2)
+ | AE_block (aexps, aexp, typ) -> AE_block (List.map recur aexps, recur aexp, typ)
+ | AE_return (aval, typ) -> AE_return (aval_rename from_id to_id aval, typ)
+ | AE_throw (aval, typ) -> AE_throw (aval_rename from_id to_id aval, typ)
+ | AE_if (aval, then_aexp, else_aexp, typ) -> AE_if (aval_rename from_id to_id aval, recur then_aexp, recur else_aexp, typ)
+ | AE_field (aval, id, typ) -> AE_field (aval_rename from_id to_id aval, id, typ)
+ | AE_case (aval, apexps, typ) -> AE_case (aval_rename from_id to_id aval, List.map (apexp_rename from_id to_id) apexps, typ)
+ | AE_try (aexp, apexps, typ) -> AE_try (aexp_rename from_id to_id aexp, List.map (apexp_rename from_id to_id) apexps, typ)
+ | AE_record_update (aval, avals, typ) -> AE_record_update (aval_rename from_id to_id aval, Bindings.map (aval_rename from_id to_id) avals, typ)
+ | AE_for (id, aexp1, aexp2, aexp3, order, aexp4) when Id.compare from_id to_id = 0 -> AE_for (id, aexp1, aexp2, aexp3, order, aexp4)
+ | AE_for (id, aexp1, aexp2, aexp3, order, aexp4) -> AE_for (id, recur aexp1, recur aexp2, recur aexp3, order, recur aexp4)
+ | AE_loop (loop, aexp1, aexp2) -> AE_loop (loop, recur aexp1, recur aexp2)
+
+and apexp_rename from_id to_id (apat, aexp1, aexp2) =
+ if IdSet.mem from_id (apat_bindings apat) then
+ (apat, aexp1, aexp2)
+ else
+ (apat, aexp_rename from_id to_id aexp1, aexp_rename from_id to_id aexp2)
+
+let shadow_counter = ref 0
+
+let new_shadow id =
+ let shadow_id = append_id id ("shadow#" ^ string_of_int !shadow_counter) in
+ incr shadow_counter;
+ shadow_id
+
+let rec no_shadow ids aexp =
+ match aexp with
+ | AE_val aval -> AE_val aval
+ | AE_app (id, avals, typ) -> AE_app (id, avals, typ)
+ | AE_cast (aexp, typ) -> AE_cast (no_shadow ids aexp, typ)
+ | AE_assign (id, typ, aexp) -> AE_assign (id, typ, no_shadow ids aexp)
+ | AE_let (id, typ1, aexp1, aexp2, typ2) when IdSet.mem id ids ->
+ let shadow_id = new_shadow id in
+ let aexp1 = no_shadow ids aexp1 in
+ let ids = IdSet.add shadow_id ids in
+ AE_let (shadow_id, typ1, aexp1, no_shadow ids (aexp_rename id shadow_id aexp2), typ2)
+ | AE_let (id, typ1, aexp1, aexp2, typ2) ->
+ AE_let (id, typ1, no_shadow ids aexp1, no_shadow (IdSet.add id ids) aexp2, typ2)
+ | AE_block (aexps, aexp, typ) -> AE_block (List.map (no_shadow ids) aexps, no_shadow ids aexp, typ)
+ | AE_return (aval, typ) -> AE_return (aval, typ)
+ | AE_throw (aval, typ) -> AE_throw (aval, typ)
+ | AE_if (aval, then_aexp, else_aexp, typ) -> AE_if (aval, no_shadow ids then_aexp, no_shadow ids else_aexp, typ)
+ | AE_field (aval, id, typ) -> AE_field (aval, id, typ)
+ | AE_case (aval, apexps, typ) -> AE_case (aval, List.map (no_shadow_apexp ids) apexps, typ)
+ | AE_try (aexp, apexps, typ) -> AE_try (no_shadow ids aexp, List.map (no_shadow_apexp ids) apexps, typ)
+ | AE_record_update (aval, avals, typ) -> AE_record_update (aval, avals, typ)
+ | AE_for (id, aexp1, aexp2, aexp3, order, aexp4) ->
+ let ids = IdSet.add id ids in
+ AE_for (id, no_shadow ids aexp1, no_shadow ids aexp2, no_shadow ids aexp3, order, no_shadow ids aexp4)
+ | AE_loop (loop, aexp1, aexp2) -> AE_loop (loop, no_shadow ids aexp1, no_shadow ids aexp2)
+
+and no_shadow_apexp ids (apat, aexp1, aexp2) =
+ let shadows = IdSet.inter (apat_bindings apat) ids in
+ let shadows = List.map (fun id -> id, new_shadow id) (IdSet.elements shadows) in
+ let rename aexp = List.fold_left (fun aexp (from_id, to_id) -> aexp_rename from_id to_id aexp) aexp shadows in
+ let ids = IdSet.union ids (IdSet.of_list (List.map snd shadows)) in
+ (apat, no_shadow ids (rename aexp1), no_shadow ids (rename aexp2))
+
(* Map over all the avals in an aexp. *)
let rec map_aval f = function
| AE_val v -> AE_val (f v)
@@ -557,8 +658,6 @@ let rec anf (E_aux (e_aux, exp_annot) as exp) =
| E_internal_cast _ | E_internal_exp _ | E_sizeof_internal _ | E_internal_plet _ | E_internal_return _ | E_internal_exp_user _ ->
failwith "encountered unexpected internal node when converting to ANF"
- | E_record _ -> AE_val (AV_lit (mk_lit (L_string "testing"), string_typ)) (* c_error ("Cannot convert to ANF: " ^ string_of_exp exp) *)
-
(**************************************************************************)
(* 2. Converting sail types to C types *)
(**************************************************************************)
@@ -927,8 +1026,14 @@ let rec instr_ctyps (I_aux (instr, aux)) =
| I_throw cval | I_jump (cval, _) | I_return cval -> [cval_ctyp cval]
| I_comment _ | I_label _ | I_goto _ | I_raw _ | I_match_failure -> []
+let rec c_ast_registers = function
+ | CDEF_reg_dec (id, ctyp) :: ast -> (id, ctyp) :: c_ast_registers ast
+ | _ :: ast -> c_ast_registers ast
+ | [] -> []
+
let cdef_ctyps ctx = function
| CDEF_reg_dec (_, ctyp) -> [ctyp]
+ | CDEF_spec (_, ctyps, ctyp) -> ctyp :: ctyps
| CDEF_fundef (id, _, _, instrs) ->
(* TODO: Move this code to DEF_fundef -> CDEF_fundef translation, and modify bytecode.ott *)
let _, Typ_aux (fn_typ, _) =
@@ -1036,6 +1141,8 @@ let pp_ctype_def = function
^^ surround 2 0 lbrace (separate_map (semi ^^ hardline) (fun (id, ctyp) -> pp_id id ^^ string " : " ^^ pp_ctyp ctyp) ctors) rbrace
let pp_cdef = function
+ | CDEF_spec (id, ctyps, ctyp) ->
+ pp_keyword "val" ^^ pp_id id ^^ space ^^ parens (separate_map (comma ^^ space) pp_ctyp ctyps) ^^ string " -> " ^^ pp_ctyp ctyp
| CDEF_fundef (id, ret, args, instrs) ->
let ret = match ret with
| None -> empty
@@ -1071,6 +1178,10 @@ let is_ct_list = function
| CT_list _ -> true
| _ -> false
+let is_ct_vector = function
+ | CT_vector _ -> true
+ | _ -> false
+
let rec is_bitvector = function
| [] -> true
| AV_lit (L_aux (L_zero, _), _) :: avals -> is_bitvector avals
@@ -1100,7 +1211,7 @@ let rec compile_aval ctx = function
| AV_lit (L_aux (L_num n, _), typ) when Big_int.less_equal min_int64 n && Big_int.less_equal n max_int64 ->
let gs = gensym () in
[idecl CT_mpz gs;
- iinit CT_mpz gs (F_lit (Big_int.to_string n ^ "L"), CT_int64)],
+ iinit CT_mpz gs (F_lit (Big_int.to_string n ^ "l"), CT_int64)],
(F_id gs, CT_mpz),
[iclear CT_mpz gs]
@@ -1117,6 +1228,13 @@ let rec compile_aval ctx = function
| AV_lit (L_aux (L_true, _), _) -> [], (F_lit "true", CT_bool), []
| AV_lit (L_aux (L_false, _), _) -> [], (F_lit "false", CT_bool), []
+ | AV_lit (L_aux (L_real str, _), _) ->
+ let gs = gensym () in
+ [idecl CT_real gs;
+ iinit CT_real gs (F_lit ("\"" ^ str ^ "\""), CT_string)],
+ (F_id gs, CT_real),
+ [iclear CT_real gs]
+
| AV_lit (L_aux (_, l) as lit, _) ->
c_error ~loc:l ("Encountered unexpected literal " ^ string_of_lit lit)
@@ -1301,7 +1419,7 @@ let rec compile_match ctx apat cval case_label =
[]
| AP_global (pid, _), _ -> [icopy (CL_id pid) cval], []
| AP_id pid, (frag, ctyp) when is_ct_enum ctyp ->
- [ijump (F_op (F_id pid, "!=", frag), CT_bool) case_label], []
+ [idecl ctyp pid; ijump (F_op (F_id pid, "!=", frag), CT_bool) case_label], []
| AP_id pid, _ ->
let ctyp = cval_ctyp cval in
let init, cleanup = if is_stack_ctyp ctyp then [], [] else [ialloc ctyp pid], [iclear ctyp pid] in
@@ -1355,15 +1473,14 @@ let label str =
let rec compile_aexp ctx = function
| AE_let (id, _, binding, body, typ) ->
let setup, ctyp, call, cleanup = compile_aexp ctx binding in
- let letb1, letb1c =
+ let letb_setup, letb_cleanup =
if is_stack_ctyp ctyp then
- [idecl ctyp id; call (CL_id id)], []
+ [idecl ctyp id; iblock (setup @ [call (CL_id id)] @ cleanup)], []
else
- [idecl ctyp id; ialloc ctyp id; call (CL_id id)], [iclear ctyp id]
+ [idecl ctyp id; ialloc ctyp id; iblock (setup @ [call (CL_id id)] @ cleanup)], [iclear ctyp id]
in
- let letb2 = setup @ letb1 @ cleanup in
let setup, ctyp, call, cleanup = compile_aexp ctx body in
- letb2 @ setup, ctyp, call, cleanup @ letb1c
+ letb_setup @ setup, ctyp, call, cleanup @ letb_cleanup
| AE_app (id, vs, typ) ->
compile_funcall ctx id vs typ
@@ -1539,6 +1656,29 @@ let rec compile_aexp ctx = function
(fun clexp -> icopy clexp unit_fragment),
[]
+ | AE_loop (Until, cond, body) ->
+ let loop_start_label = label "repeat_" in
+ let loop_end_label = label "until_" in
+ let cond_setup, _, cond_call, cond_cleanup = compile_aexp ctx cond in
+ let body_setup, _, body_call, body_cleanup = compile_aexp ctx body in
+ let gs = gensym () in
+ let unit_gs = gensym () in
+ let loop_test = (F_unary ("!", F_id gs), CT_bool) in
+ [idecl CT_bool gs; idecl CT_unit unit_gs]
+ @ [ilabel loop_start_label]
+ @ [iblock (body_setup
+ @ [body_call (CL_id unit_gs)]
+ @ body_cleanup
+ @ cond_setup
+ @ [cond_call (CL_id gs)]
+ @ cond_cleanup
+ @ [ijump loop_test loop_end_label]
+ @ [igoto loop_start_label])]
+ @ [ilabel loop_end_label],
+ CT_unit,
+ (fun clexp -> icopy clexp unit_fragment),
+ []
+
| AE_cast (aexp, typ) -> compile_aexp ctx aexp
| AE_return (aval, typ) ->
@@ -1583,15 +1723,20 @@ and compile_block ctx = function
let gs = gensym () in
setup @ [idecl CT_unit gs; call (CL_id gs)] @ cleanup @ rest
-let rec pat_ids (P_aux (p_aux, (l, _)) as pat) =
- match p_aux with
- | P_id id -> [id]
- | P_tup pats -> List.concat (List.map pat_ids pats)
- | P_lit (L_aux (L_unit, _)) -> let gs = gensym () in [gs]
- | P_wild -> let gs = gensym () in [gs]
- | P_var (pat, _) -> pat_ids pat
- | P_typ (_, pat) -> pat_ids pat
- | _ -> c_error ~loc:l ("Cannot compile pattern " ^ string_of_pat pat ^ " to C")
+(* FIXME: this function is a bit of a hack *)
+let rec pat_ids (Typ_aux (arg_typ_aux, _) as arg_typ) (P_aux (p_aux, (l, _)) as pat) =
+ prerr_endline (string_of_typ arg_typ);
+ match p_aux, arg_typ_aux with
+ | P_id id, _ -> [id]
+ | P_tup pats, Typ_tup arg_typs when List.length pats = List.length arg_typs ->
+ List.concat (List.map2 pat_ids arg_typs pats)
+ | P_tup pats, _ -> c_error ~loc:l ("Cannot compile tuple pattern " ^ string_of_pat pat ^ " to C, as it doesn't have tuple type.")
+ | P_lit (L_aux (L_unit, _)), _ -> let gs = gensym () in [gs]
+ | P_wild, Typ_tup arg_typs -> List.map (fun _ -> let gs = gensym () in gs) arg_typs
+ | P_wild, _ -> let gs = gensym () in [gs]
+ | P_var (pat, _), _ -> pat_ids arg_typ pat
+ | P_typ (_, pat), _ -> pat_ids arg_typ pat
+ | _, _ -> c_error ~loc:l ("Cannot compile pattern " ^ string_of_pat pat ^ " to C")
(** Compile a sail type definition into a IR one. Most of the
actual work of translating the typedefs into C is done by the code
@@ -1773,27 +1918,40 @@ let compile_def ctx = function
[CDEF_reg_dec (id, ctyp_of_typ ctx typ)], ctx
| DEF_reg_dec _ -> failwith "Unsupported register declaration" (* FIXME *)
- | DEF_spec _ -> [], ctx
+ | DEF_spec (VS_aux (VS_val_spec (_, id, _, _), _)) ->
+ let _, Typ_aux (fn_typ, _) = Env.get_val_spec id ctx.tc_env in
+ let arg_typs, ret_typ = match fn_typ with
+ | Typ_fn (Typ_aux (Typ_tup arg_typs, _), ret_typ, _) -> arg_typs, ret_typ
+ | Typ_fn (arg_typ, ret_typ, _) -> [arg_typ], ret_typ
+ | _ -> assert false
+ in
+ let arg_ctyps, ret_ctyp = List.map (ctyp_of_typ ctx) arg_typs, ctyp_of_typ ctx ret_typ in
+ [CDEF_spec (id, arg_ctyps, ret_ctyp)], ctx
| DEF_fundef (FD_aux (FD_function (_, _, _, [FCL_aux (FCL_Funcl (id, Pat_aux (Pat_exp (pat, exp), _)), _)]), _)) ->
- let aexp = map_functions (analyze_primop ctx) (c_literals ctx (anf exp)) in
- prerr_endline (Pretty_print_sail.to_string (pp_aexp aexp));
+ let aexp = map_functions (analyze_primop ctx) (c_literals ctx (no_shadow IdSet.empty (anf exp))) in
+ if string_of_id id = "system_barriers_decode" then prerr_endline (Pretty_print_sail.to_string (pp_aexp aexp)) else ();
let setup, ctyp, call, cleanup = compile_aexp ctx aexp in
let gs = gensym () in
let pat = match pat with
| P_aux (P_tup [], annot) -> P_aux (P_lit (mk_lit L_unit), annot)
| _ -> pat
in
+ let _, Typ_aux (fn_typ, _) = Env.get_val_spec id ctx.tc_env in
+ let arg_typ, ret_typ = match fn_typ with
+ | Typ_fn (arg_typ, ret_typ, _) -> arg_typ, ret_typ
+ | _ -> assert false
+ in
prerr_endline (string_of_id id ^ " : " ^ string_of_ctyp ctyp);
if is_stack_ctyp ctyp then
let instrs = [idecl ctyp gs] @ setup @ [call (CL_id gs)] @ cleanup @ [ireturn (F_id gs, ctyp)] in
let instrs = fix_exception ctx instrs in
- [CDEF_fundef (id, None, pat_ids pat, instrs)], ctx
+ [CDEF_fundef (id, None, pat_ids arg_typ pat, instrs)], ctx
else
let instrs = setup @ [call (CL_addr gs)] @ cleanup in
let instrs = fix_early_return (CL_addr gs) ctx instrs in
let instrs = fix_exception ctx instrs in
- [CDEF_fundef (id, Some gs, pat_ids pat, instrs)], ctx
+ [CDEF_fundef (id, Some gs, pat_ids arg_typ pat, instrs)], ctx
| DEF_fundef (FD_aux (FD_function (_, _, _, []), (l, _))) ->
c_error ~loc:l "Encountered function with no clauses"
@@ -1809,7 +1967,7 @@ let compile_def ctx = function
[CDEF_type tdef], ctx
| DEF_val (LB_aux (LB_val (pat, exp), _)) ->
- let aexp = map_functions (analyze_primop ctx) (c_literals ctx (anf exp)) in
+ let aexp = map_functions (analyze_primop ctx) (c_literals ctx (no_shadow IdSet.empty (anf exp))) in
let setup, ctyp, call, cleanup = compile_aexp ctx aexp in
let apat = anf_pat ~global:true pat in
let gs = gensym () in
@@ -2073,8 +2231,9 @@ let sgen_ctyp = function
| CT_enum (id, _) -> "enum " ^ sgen_id id
| CT_variant (id, _) -> "struct " ^ sgen_id id
| CT_list _ as l -> Util.zencode_string (string_of_ctyp l)
- | CT_vector _ -> "int" (* FIXME *)
+ | CT_vector _ as v -> Util.zencode_string (string_of_ctyp v)
| CT_string -> "sail_string"
+ | CT_real -> "real"
let sgen_ctyp_name = function
| CT_unit -> "unit"
@@ -2089,8 +2248,9 @@ let sgen_ctyp_name = function
| CT_enum (id, _) -> sgen_id id
| CT_variant (id, _) -> sgen_id id
| CT_list _ as l -> Util.zencode_string (string_of_ctyp l)
- | CT_vector _ -> "int" (* FIXME *)
+ | CT_vector _ as v -> Util.zencode_string (string_of_ctyp v)
| CT_string -> "sail_string"
+ | CT_real -> "real"
let sgen_cval_param (frag, ctyp) =
match ctyp with
@@ -2149,12 +2309,41 @@ let rec codegen_instr ctx (I_aux (instr, _)) =
^^ jump 2 2 (separate_map hardline (codegen_instr ctx) instrs) ^^ hardline
^^ string " }"
| I_funcall (x, f, args, ctyp) ->
- let args = Util.string_of_list ", " sgen_cval args in
+ let c_args = Util.string_of_list ", " sgen_cval args in
let fname = if Env.is_extern f ctx.tc_env "c" then Env.get_extern f ctx.tc_env "c" else sgen_id f in
+ let fname =
+ match fname, ctyp with
+ | "internal_pick", _ -> Printf.sprintf "pick_%s" (sgen_ctyp_name ctyp)
+ | "eq_anything", _ ->
+ begin match args with
+ | cval :: _ -> Printf.sprintf "eq_%s" (sgen_ctyp_name (cval_ctyp cval))
+ | _ -> c_error "eq_anything function with bad arity."
+ end
+ | "length", _ ->
+ begin match args with
+ | cval :: _ -> Printf.sprintf "length_%s" (sgen_ctyp_name (cval_ctyp cval))
+ | _ -> c_error "length function with bad arity."
+ end
+ | "vector_access", CT_bit -> "bitvector_access"
+ | "vector_access", _ ->
+ begin match args with
+ | cval :: _ -> Printf.sprintf "vector_access_%s" (sgen_ctyp_name (cval_ctyp cval))
+ | _ -> c_error "vector access function with bad arity."
+ end
+ | "vector_update_subrange", _ -> Printf.sprintf "vector_update_subrange_%s" (sgen_ctyp_name ctyp)
+ | "vector_subrange", _ -> Printf.sprintf "vector_subrange_%s" (sgen_ctyp_name ctyp)
+ | "vector_update", CT_uint64 _ -> "update_uint64_t"
+ | "vector_update", CT_bv _ -> "update_bv"
+ | "vector_update", _ -> Printf.sprintf "vector_update_%s" (sgen_ctyp_name ctyp)
+ | "undefined_vector", CT_uint64 _ -> "undefined_uint64_t"
+ | "undefined_vector", CT_bv _ -> "undefined_bv_t"
+ | "undefined_vector", _ -> Printf.sprintf "undefined_vector_%s" (sgen_ctyp_name ctyp)
+ | fname, _ -> fname
+ in
if is_stack_ctyp ctyp then
- string (Printf.sprintf " %s = %s(%s);" (sgen_clexp_pure x) fname args)
+ string (Printf.sprintf " %s = %s(%s);" (sgen_clexp_pure x) fname c_args)
else
- string (Printf.sprintf " %s(%s, %s);" fname (sgen_clexp x) args)
+ string (Printf.sprintf " %s(%s, %s);" fname (sgen_clexp x) c_args)
| I_clear (ctyp, id) ->
string (Printf.sprintf " clear_%s(&%s);" (sgen_ctyp_name ctyp) (sgen_id id))
| I_init (ctyp, id, cval) ->
@@ -2165,6 +2354,23 @@ let rec codegen_instr ctx (I_aux (instr, _)) =
(sgen_cval_param cval))
| I_alloc (ctyp, id) ->
string (Printf.sprintf " init_%s(&%s);" (sgen_ctyp_name ctyp) (sgen_id id))
+ (* FIXME: This just covers the cases we see in our specs, need a
+ special conversion code-generator for full generality *)
+ | I_convert (x, CT_tup ctyps1, y, CT_tup ctyps2) when List.length ctyps1 = List.length ctyps2 ->
+ let convert i (ctyp1, ctyp2) =
+ if ctyp_equal ctyp1 ctyp2 then string " /* no change */"
+ else if is_stack_ctyp ctyp1 then
+ string (Printf.sprintf " %s.ztup%i = convert_%s_of_%s(%s.ztup%i);"
+ (sgen_clexp_pure x)
+ i
+ (sgen_ctyp_name ctyp1)
+ (sgen_ctyp_name ctyp2)
+ (sgen_id y)
+ i)
+ else
+ c_error "Cannot compile type conversion"
+ in
+ separate hardline (List.mapi convert (List.map2 (fun x y -> (x, y)) ctyps1 ctyps2))
| I_convert (x, ctyp1, y, ctyp2) ->
if is_stack_ctyp ctyp1 then
string (Printf.sprintf " %s = convert_%s_of_%s(%s);"
@@ -2195,8 +2401,14 @@ let rec codegen_instr ctx (I_aux (instr, _)) =
let codegen_type_def ctx = function
| CTD_enum (id, ids) ->
+ let codegen_eq =
+ let name = sgen_id id in
+ string (Printf.sprintf "bool eq_%s(enum %s op1, enum %s op2) { return op1 == op2; }" name name name)
+ in
string (Printf.sprintf "// enum %s" (string_of_id id)) ^^ hardline
^^ separate space [string "enum"; codegen_id id; lbrace; separate_map (comma ^^ space) upper_codegen_id ids; rbrace ^^ semi]
+ ^^ twice hardline
+ ^^ codegen_eq
| CTD_struct (id, ctors) ->
(* Generate a set_T function for every struct T *)
@@ -2224,6 +2436,9 @@ let codegen_type_def ctx = function
(separate hardline (Bindings.bindings ctors |> List.map (codegen_field_init f) |> List.concat))
rbrace
in
+ let codegen_eq =
+ string (Printf.sprintf "bool eq_%s(struct %s op1, struct %s op2) { return true; }" (sgen_id id) (sgen_id id) (sgen_id id))
+ in
(* Generate the struct and add the generated functions *)
let codegen_ctor (id, ctyp) =
string (sgen_ctyp ctyp) ^^ space ^^ codegen_id id
@@ -2239,6 +2454,8 @@ let codegen_type_def ctx = function
^^ codegen_init "init" id (ctor_bindings ctors)
^^ twice hardline
^^ codegen_init "clear" id (ctor_bindings ctors)
+ ^^ twice hardline
+ ^^ codegen_eq
| CTD_variant (id, tus) ->
let codegen_tu (ctor_id, ctyp) =
@@ -2403,7 +2620,8 @@ let codegen_list_init id =
let codegen_list_clear id ctyp =
string (Printf.sprintf "void clear_%s(%s *rop) {\n" (sgen_id id) (sgen_id id))
^^ string (Printf.sprintf " if (*rop == NULL) return;")
- ^^ string (Printf.sprintf " clear_%s(&(*rop)->hd);\n" (sgen_ctyp_name ctyp))
+ ^^ (if is_stack_ctyp ctyp then empty
+ else string (Printf.sprintf " clear_%s(&(*rop)->hd);\n" (sgen_ctyp_name ctyp)))
^^ string (Printf.sprintf " clear_%s(&(*rop)->tl);\n" (sgen_id id))
^^ string " free(*rop);"
^^ string "}"
@@ -2412,8 +2630,11 @@ let codegen_list_set id ctyp =
string (Printf.sprintf "void internal_set_%s(%s *rop, const %s op) {\n" (sgen_id id) (sgen_id id) (sgen_id id))
^^ string " if (op == NULL) { *rop = NULL; return; };\n"
^^ string (Printf.sprintf " *rop = malloc(sizeof(struct node_%s));\n" (sgen_id id))
- ^^ string (Printf.sprintf " init_%s(&(*rop)->hd);\n" (sgen_ctyp_name ctyp))
- ^^ string (Printf.sprintf " set_%s(&(*rop)->hd, op->hd);\n" (sgen_ctyp_name ctyp))
+ ^^ (if is_stack_ctyp ctyp then
+ string " (*rop)->hd = op->hd;\n"
+ else
+ string (Printf.sprintf " init_%s(&(*rop)->hd);\n" (sgen_ctyp_name ctyp))
+ ^^ string (Printf.sprintf " set_%s(&(*rop)->hd, op->hd);\n" (sgen_ctyp_name ctyp)))
^^ string (Printf.sprintf " internal_set_%s(&(*rop)->tl, op->tl);\n" (sgen_id id))
^^ string "}"
^^ twice hardline
@@ -2426,11 +2647,20 @@ let codegen_cons id ctyp =
let cons_id = mk_id ("cons#" ^ string_of_ctyp ctyp) in
string (Printf.sprintf "void %s(%s *rop, const %s x, const %s xs) {\n" (sgen_id cons_id) (sgen_id id) (sgen_ctyp ctyp) (sgen_id id))
^^ string (Printf.sprintf " *rop = malloc(sizeof(struct node_%s));\n" (sgen_id id))
- ^^ string (Printf.sprintf " init_%s(&(*rop)->hd);\n" (sgen_ctyp_name ctyp))
- ^^ string (Printf.sprintf " set_%s(&(*rop)->hd, x);\n" (sgen_ctyp_name ctyp))
+ ^^ (if is_stack_ctyp ctyp then
+ string " (*rop)->hd = x;\n"
+ else
+ string (Printf.sprintf " init_%s(&(*rop)->hd);\n" (sgen_ctyp_name ctyp))
+ ^^ string (Printf.sprintf " set_%s(&(*rop)->hd, x);\n" (sgen_ctyp_name ctyp)))
^^ string " (*rop)->tl = xs;\n"
^^ string "}"
+let codegen_pick id ctyp =
+ if is_stack_ctyp ctyp then
+ string (Printf.sprintf "%s pick_%s(const %s xs) { return xs->hd; }" (sgen_ctyp ctyp) (sgen_ctyp_name ctyp) (sgen_id id))
+ else
+ string (Printf.sprintf "void pick_%s(%s *x, const %s xs) { set_%s(x, xs->hd); }" (sgen_ctyp_name ctyp) (sgen_ctyp ctyp) (sgen_id id) (sgen_ctyp_name ctyp))
+
let codegen_list ctx ctyp =
let id = mk_id (string_of_ctyp (CT_list ctyp)) in
if IdSet.mem id !generated then
@@ -2443,6 +2673,94 @@ let codegen_list ctx ctyp =
^^ codegen_list_clear id ctyp ^^ twice hardline
^^ codegen_list_set id ctyp ^^ twice hardline
^^ codegen_cons id ctyp ^^ twice hardline
+ ^^ codegen_pick id ctyp ^^ twice hardline
+ end
+
+let codegen_vector ctx (direction, ctyp) =
+ let id = mk_id (string_of_ctyp (CT_vector (direction, ctyp))) in
+ if IdSet.mem id !generated then
+ empty
+ else
+ let vector_typedef =
+ string (Printf.sprintf "struct %s {\n size_t len;\n %s *data;\n};\n" (sgen_id id) (sgen_ctyp ctyp))
+ ^^ string (Printf.sprintf "typedef struct %s %s;" (sgen_id id) (sgen_id id))
+ in
+ let vector_init =
+ string (Printf.sprintf "void init_%s(%s *rop) {\n rop->len = 0;\n rop->data = NULL;\n}" (sgen_id id) (sgen_id id))
+ in
+ let vector_set =
+ string (Printf.sprintf "void set_%s(%s *rop, %s op) {\n" (sgen_id id) (sgen_id id) (sgen_id id))
+ ^^ string (Printf.sprintf " clear_%s(rop);\n" (sgen_id id))
+ ^^ string " rop->len = op.len;\n"
+ ^^ string (Printf.sprintf " rop->data = malloc((rop->len) * sizeof(%s));\n" (sgen_ctyp ctyp))
+ ^^ string " for (int i = 0; i < op.len; i++) {\n"
+ ^^ string (if is_stack_ctyp ctyp then
+ " (rop->data)[i] = op.data[i];\n"
+ else
+ Printf.sprintf " init_%s((rop->data) + i);\n set_%s((rop->data) + i, op.data[i]);\n" (sgen_ctyp_name ctyp) (sgen_ctyp_name ctyp))
+ ^^ string " }\n"
+ ^^ string "}"
+ in
+ let vector_clear =
+ string (Printf.sprintf "void clear_%s(%s *rop) {\n" (sgen_id id) (sgen_id id))
+ ^^ (if is_stack_ctyp ctyp then empty
+ else
+ string " for (int i = 0; i < (rop->len); i++) {\n"
+ ^^ string (Printf.sprintf " clear_%s((rop->data) + i);\n" (sgen_ctyp_name ctyp))
+ ^^ string " }\n")
+ ^^ string " if (rop->data != NULL) free(rop->data);\n"
+ ^^ string "}"
+ in
+ let vector_update =
+ string (Printf.sprintf "void vector_update_%s(%s *rop, %s op, mpz_t n, %s elem) {\n" (sgen_id id) (sgen_id id) (sgen_id id) (sgen_ctyp ctyp))
+ ^^ string " int m = mpz_get_ui(n);\n"
+ ^^ string " if (rop->data == op.data) {\n"
+ ^^ string (if is_stack_ctyp ctyp then
+ " rop->data[m] = elem;\n"
+ else
+ Printf.sprintf " set_%s((rop->data) + m, elem);\n" (sgen_ctyp_name ctyp))
+ ^^ string " } else {\n"
+ ^^ string (Printf.sprintf " set_%s(rop, op);\n" (sgen_id id))
+ ^^ string (if is_stack_ctyp ctyp then
+ " rop->data[m] = elem;\n"
+ else
+ Printf.sprintf " set_%s((rop->data) + m, elem);\n" (sgen_ctyp_name ctyp))
+ ^^ string " }\n"
+ ^^ string "}"
+ in
+ let vector_access =
+ if is_stack_ctyp ctyp then
+ string (Printf.sprintf "%s vector_access_%s(%s op, mpz_t n) {\n" (sgen_ctyp ctyp) (sgen_id id) (sgen_id id))
+ ^^ string " int m = mpz_get_ui(n);\n"
+ ^^ string " return op.data[m];\n"
+ ^^ string "}"
+ else
+ string (Printf.sprintf "void vector_access_%s(%s *rop, %s op, mpz_t n) {\n" (sgen_id id) (sgen_ctyp ctyp) (sgen_id id))
+ ^^ string " int m = mpz_get_ui(n);\n"
+ ^^ string (Printf.sprintf " set_%s(rop, op.data[m]);\n" (sgen_ctyp_name ctyp))
+ ^^ string "}"
+ in
+ let vector_undefined =
+ string (Printf.sprintf "void undefined_vector_%s(%s *rop, mpz_t len, %s elem) {\n" (sgen_id id) (sgen_id id) (sgen_ctyp ctyp))
+ ^^ string (Printf.sprintf " rop->len = mpz_get_ui(len);\n")
+ ^^ string (Printf.sprintf " rop->data = malloc((rop->len) * sizeof(%s));\n" (sgen_ctyp ctyp))
+ ^^ string " for (int i = 0; i < (rop->len); i++) {\n"
+ ^^ string (if is_stack_ctyp ctyp then
+ " (rop->data)[i] = elem;\n"
+ else
+ Printf.sprintf " init_%s((rop->data) + i);\n set_%s((rop->data) + i, elem);\n" (sgen_ctyp_name ctyp) (sgen_ctyp_name ctyp))
+ ^^ string " }\n"
+ ^^ string "}"
+ in
+ begin
+ generated := IdSet.add id !generated;
+ vector_typedef ^^ twice hardline
+ ^^ vector_init ^^ twice hardline
+ ^^ vector_clear ^^ twice hardline
+ ^^ vector_undefined ^^ twice hardline
+ ^^ vector_access ^^ twice hardline
+ ^^ vector_set ^^ twice hardline
+ ^^ vector_update ^^ twice hardline
end
let codegen_def' ctx = function
@@ -2450,6 +2768,14 @@ let codegen_def' ctx = function
string (Printf.sprintf "// register %s" (string_of_id id)) ^^ hardline
^^ string (Printf.sprintf "%s %s;" (sgen_ctyp ctyp) (sgen_id id))
+ | CDEF_spec (id, arg_ctyps, ret_ctyp) ->
+ if Env.is_extern id ctx.tc_env "c" then
+ empty
+ else if is_stack_ctyp ret_ctyp then
+ string (Printf.sprintf "%s %s(%s);" (sgen_ctyp ret_ctyp) (sgen_id id) (Util.string_of_list ", " sgen_ctyp arg_ctyps))
+ else
+ string (Printf.sprintf "void %s(%s *rop, %s);" (sgen_id id) (sgen_ctyp ret_ctyp) (Util.string_of_list ", " sgen_ctyp arg_ctyps))
+
| CDEF_fundef (id, ret_arg, args, instrs) as def ->
if !opt_ddump_flow_graphs then make_dot id (instrs_graph instrs) else ();
let instrs = add_local_labels instrs in
@@ -2504,13 +2830,20 @@ let codegen_def ctx def =
| CT_list ctyp -> ctyp
| _ -> assert false
in
+ let unvector = function
+ | CT_vector (direction, ctyp) -> (direction, ctyp)
+ | _ -> assert false
+ in
let tups = List.filter is_ct_tup (cdef_ctyps ctx def) in
let tups = List.map (fun ctyp -> codegen_tup ctx (untup ctyp)) tups in
let lists = List.filter is_ct_list (cdef_ctyps ctx def) in
let lists = List.map (fun ctyp -> codegen_list ctx (unlist ctyp)) lists in
- prerr_endline (Pretty_print_sail.to_string (pp_cdef def));
+ let vectors = List.filter is_ct_vector (cdef_ctyps ctx def) in
+ let vectors = List.map (fun ctyp -> codegen_vector ctx (unvector ctyp)) vectors in
+ (* prerr_endline (Pretty_print_sail.to_string (pp_cdef def)); *)
concat tups
^^ concat lists
+ ^^ concat vectors
^^ codegen_def' ctx def
let compile_ast ctx (Defs defs) =
@@ -2542,14 +2875,29 @@ let compile_ast ctx (Defs defs) =
List.map (fun n -> Printf.sprintf " kill_letbind_%d();" n) ctx.letbinds
in
+ let regs = c_ast_registers cdefs in
+
+ let register_init_clear (id, ctyp) =
+ if is_stack_ctyp ctyp then
+ [], []
+ else
+ [ Printf.sprintf " init_%s(&%s);" (sgen_ctyp_name ctyp) (sgen_id id) ],
+ [ Printf.sprintf " clear_%s(&%s);" (sgen_ctyp_name ctyp) (sgen_id id) ]
+ in
+
let postamble = separate hardline (List.map string
( [ "int main(void)";
- "{" ]
+ "{";
+ " setup_real();" ]
@ fst exn_boilerplate
+ @ List.concat (List.map (fun r -> fst (register_init_clear r)) regs)
+ @ (if regs = [] then [] else [ " zinitializze_registers(UNIT);" ])
@ letbind_initializers
@ [ " zmain(UNIT);" ]
@ letbind_finalizers
+ @ List.concat (List.map (fun r -> snd (register_init_clear r)) regs)
@ snd exn_boilerplate
+ @ [ " return 0;" ]
@ [ "}" ] ))
in
diff --git a/src/monomorphise.ml b/src/monomorphise.ml
index 71efcb22..d14097af 100644
--- a/src/monomorphise.ml
+++ b/src/monomorphise.ml
@@ -54,7 +54,7 @@ open Ast_util
module Big_int = Nat_big_num
open Type_check
-let size_set_limit = 32
+let size_set_limit = 64
let optmap v f =
match v with
@@ -69,6 +69,11 @@ let bindings_union s1 s2 =
| _, (Some x) -> Some x
| (Some x), _ -> Some x
| _, _ -> None) s1 s2
+let kbindings_union s1 s2 =
+ KBindings.merge (fun _ x y -> match x,y with
+ | _, (Some x) -> Some x
+ | (Some x), _ -> Some x
+ | _, _ -> None) s1 s2
let subst_nexp substs nexp =
let rec s_snexp substs (Nexp_aux (ne,l) as nexp) =
@@ -615,9 +620,9 @@ let bindings_from_pat p =
and aux_fpat (FP_aux (FP_Fpat (_,p), _)) = aux_pat p
in aux_pat p
-let remove_bound env pat =
+let remove_bound (substs,ksubsts) pat =
let bound = bindings_from_pat pat in
- List.fold_left (fun sub v -> Bindings.remove v sub) env bound
+ List.fold_left (fun sub v -> Bindings.remove v sub) substs bound, ksubsts
(* Attempt simple pattern matches *)
let lit_match = function
@@ -721,6 +726,30 @@ let int_of_str_lit = function
| L_bin bin -> Big_int.of_string ("0b" ^ bin)
| _ -> assert false
+let bits_of_lit = function
+ | L_bin bin -> bin
+ | L_hex hex -> hex_to_bin hex
+ | _ -> assert false
+
+let slice_lit (L_aux (lit,ll)) i len (Ord_aux (ord,_)) =
+ let i = Big_int.to_int i in
+ let len = Big_int.to_int len in
+ match match ord with
+ | Ord_inc -> Some i
+ | Ord_dec -> Some (len - i)
+ | Ord_var _ -> None
+ with
+ | None -> None
+ | Some i ->
+ match lit with
+ | L_bin bin -> Some (L_aux (L_bin (String.sub bin i len),Generated ll))
+ | _ -> assert false
+
+let concat_vec lit1 lit2 =
+ let bits1 = bits_of_lit lit1 in
+ let bits2 = bits_of_lit lit2 in
+ L_bin (bits1 ^ bits2)
+
let lit_eq (L_aux (l1,_)) (L_aux (l2,_)) =
match l1,l2 with
| (L_zero|L_false), (L_zero|L_false)
@@ -758,16 +787,47 @@ let try_app (l,ann) (id,args) =
| [E_aux (E_lit L_aux ((L_hex _| L_bin _) as lit,_), _)] ->
Some (E_aux (E_lit (L_aux (L_num (int_of_str_lit lit),new_l)),(l,ann)))
| _ -> None
+ else if is_id "slice" then
+ match args with
+ | [E_aux (E_lit (L_aux ((L_hex _| L_bin _),_) as lit),
+ (_,Some (_,Typ_aux (Typ_app (_,[_;Typ_arg_aux (Typ_arg_order ord,_);_]),_),_)));
+ E_aux (E_lit L_aux (L_num i,_), _);
+ E_aux (E_lit L_aux (L_num len,_), _)] ->
+ (match slice_lit lit i len ord with
+ | Some lit' -> Some (E_aux (E_lit lit',(l,ann)))
+ | None -> None)
+ | _ -> None
+ else if is_id "bitvector_concat" then
+ match args with
+ | [E_aux (E_lit L_aux ((L_hex _| L_bin _) as lit1,_), _);
+ E_aux (E_lit L_aux ((L_hex _| L_bin _) as lit2,_), _)] ->
+ Some (E_aux (E_lit (L_aux (concat_vec lit1 lit2,new_l)),(l,ann)))
+ | _ -> None
else if is_id "shl_int" then
match args with
| [E_aux (E_lit L_aux (L_num i,_),_); E_aux (E_lit L_aux (L_num j,_),_)] ->
Some (E_aux (E_lit (L_aux (L_num (Big_int.shift_left i (Big_int.to_int j)),new_l)),(l,ann)))
| _ -> None
- else if is_id "mult_int" then
+ else if is_id "mult_int" || is_id "mult_range" then
match args with
| [E_aux (E_lit L_aux (L_num i,_),_); E_aux (E_lit L_aux (L_num j,_),_)] ->
Some (E_aux (E_lit (L_aux (L_num (Big_int.mul i j),new_l)),(l,ann)))
| _ -> None
+ else if is_id "quotient_nat" then
+ match args with
+ | [E_aux (E_lit L_aux (L_num i,_),_); E_aux (E_lit L_aux (L_num j,_),_)] ->
+ Some (E_aux (E_lit (L_aux (L_num (Big_int.div i j),new_l)),(l,ann)))
+ | _ -> None
+ else if is_id "add_range" then
+ match args with
+ | [E_aux (E_lit L_aux (L_num i,_),_); E_aux (E_lit L_aux (L_num j,_),_)] ->
+ Some (E_aux (E_lit (L_aux (L_num (Big_int.add i j),new_l)),(l,ann)))
+ | _ -> None
+ else if is_id "negate_range" then
+ match args with
+ | [E_aux (E_lit L_aux (L_num i,_),_)] ->
+ Some (E_aux (E_lit (L_aux (L_num (Big_int.negate i),new_l)),(l,ann)))
+ | _ -> None
else if is_id "ex_int" then
match args with
| [E_aux (E_lit lit,(l,_))] -> Some (E_aux (E_lit lit,(l,ann)))
@@ -1034,6 +1094,13 @@ let apply_pat_choices choices =
e_assert = rewrite_assert;
e_case = rewrite_case }
+(* Check whether the current environment with the given kid assignments is
+ inconsistent (and hence whether the code is dead) *)
+let is_env_inconsistent env ksubsts =
+ let env = KBindings.fold (fun k nexp env ->
+ Env.add_constraint (nc_eq (nvar k) nexp) env) ksubsts env in
+ prove env nc_false
+
let split_defs all_errors splits defs =
let no_errors_happened = ref true in
let split_constructors (Defs defs) =
@@ -1065,8 +1132,13 @@ let split_defs all_errors splits defs =
let (refinements, defs') = split_constructors defs in
+ (* COULD DO: dead code is only eliminated at if expressions, but we could
+ also cut out impossible case branches and code after assertions. *)
+
(* Constant propogation.
Takes maps of immutable/mutable variables to subsitute.
+ The substs argument also contains the current type-level kid refinements
+ so that we can check for dead code.
Extremely conservative about evaluation order of assignments in
subexpressions, dropping assignments rather than committing to
any particular order *)
@@ -1123,7 +1195,7 @@ let split_defs all_errors splits defs =
let env = Type_check.env_of_annot (l, annot) in
(try
match Env.lookup_id id env with
- | Local (Immutable,_) -> Bindings.find id substs
+ | Local (Immutable,_) -> Bindings.find id (fst substs)
| Local (Mutable,_) -> Bindings.find id assigns
| _ -> exp
with Not_found -> exp),assigns
@@ -1154,20 +1226,48 @@ let split_defs all_errors splits defs =
re (E_tuple es') assigns
| E_if (e1,e2,e3) ->
let e1',assigns = const_prop_exp substs assigns e1 in
- let e2',assigns2 = const_prop_exp substs assigns e2 in
- let e3',assigns3 = const_prop_exp substs assigns e3 in
- (match drop_casts e1' with
+ let e1_no_casts = drop_casts e1' in
+ (match e1_no_casts with
| E_aux (E_lit (L_aux ((L_true|L_false) as lit ,_)),_) ->
- (match lit with L_true -> e2',assigns2 | _ -> e3',assigns3)
+ (match lit with
+ | L_true -> const_prop_exp substs assigns e2
+ | _ -> const_prop_exp substs assigns e3)
| _ ->
- let assigns = isubst_minus_set assigns (assigned_vars e2) in
- let assigns = isubst_minus_set assigns (assigned_vars e3) in
- re (E_if (e1',e2',e3')) assigns)
+ (* If the guard is an equality check, propagate the value. *)
+ let env1 = env_of e1_no_casts in
+ let is_equal id =
+ List.exists (fun id' -> Id.compare id id' == 0)
+ (Env.get_overloads (Id_aux (DeIid "==", Parse_ast.Unknown))
+ env1)
+ in
+ let substs_true =
+ match e1_no_casts with
+ | E_aux (E_app (id, [E_aux (E_id var,_); vl]),_)
+ | E_aux (E_app (id, [vl; E_aux (E_id var,_)]),_)
+ when is_equal id ->
+ if is_value vl then
+ (match Env.lookup_id var env1 with
+ | Local (Immutable,_) -> Bindings.add var vl (fst substs),snd substs
+ | _ -> substs)
+ else substs
+ | _ -> substs
+ in
+ (* Discard impossible branches *)
+ if is_env_inconsistent (env_of e2) (snd substs) then
+ const_prop_exp substs assigns e3
+ else if is_env_inconsistent (env_of e3) (snd substs) then
+ const_prop_exp substs_true assigns e2
+ else
+ let e2',assigns2 = const_prop_exp substs_true assigns e2 in
+ let e3',assigns3 = const_prop_exp substs assigns e3 in
+ let assigns = isubst_minus_set assigns (assigned_vars e2) in
+ let assigns = isubst_minus_set assigns (assigned_vars e3) in
+ re (E_if (e1',e2',e3')) assigns)
| E_for (id,e1,e2,e3,ord,e4) ->
(* Treat e1, e2 and e3 (from, to and by) as a non-det tuple *)
let e1',e2',e3',assigns = non_det_exp_3 e1 e2 e3 in
let assigns = isubst_minus_set assigns (assigned_vars e4) in
- let e4',_ = const_prop_exp (Bindings.remove id substs) assigns e4 in
+ let e4',_ = const_prop_exp (Bindings.remove id (fst substs),snd substs) assigns e4 in
re (E_for (id,e1',e2',e3',ord,e4')) assigns
| E_loop (loop,e1,e2) ->
let assigns = isubst_minus_set assigns (IdSet.union (assigned_vars e1) (assigned_vars e2)) in
@@ -1227,7 +1327,7 @@ let split_defs all_errors splits defs =
| Some (E_aux (_,(_,annot')) as exp,newbindings,kbindings) ->
let exp = nexp_subst_exp (kbindings_from_list kbindings) exp in
let newbindings_env = bindings_from_list newbindings in
- let substs' = bindings_union substs newbindings_env in
+ let substs' = bindings_union (fst substs) newbindings_env, snd substs in
const_prop_exp substs' assigns exp)
| E_let (lb,e2) ->
begin
@@ -1245,7 +1345,7 @@ let split_defs all_errors splits defs =
| Some (e'',bindings,kbindings) ->
let e'' = nexp_subst_exp (kbindings_from_list kbindings) e'' in
let bindings = bindings_from_list bindings in
- let substs'' = bindings_union substs' bindings in
+ let substs'' = bindings_union (fst substs') bindings, snd substs' in
const_prop_exp substs'' assigns e''
else plain ()
end
@@ -1350,9 +1450,9 @@ let split_defs all_errors splits defs =
let cases = List.map (function
| FCL_aux (FCL_Funcl (_,pexp), ann) -> pexp)
fcls in
- match can_match_with_env env arg cases Bindings.empty Bindings.empty with
+ match can_match_with_env env arg cases (Bindings.empty,KBindings.empty) Bindings.empty with
| Some (exp,bindings,kbindings) ->
- let substs = bindings_from_list bindings in
+ let substs = bindings_from_list bindings, kbindings_from_list kbindings in
let result,_ = const_prop_exp substs Bindings.empty exp in
let result = match result with
| E_aux (E_return e,_) -> e
@@ -1361,7 +1461,7 @@ let split_defs all_errors splits defs =
if is_value result then Some result else None
| None -> None
- and can_match_with_env env (E_aux (e,(l,annot)) as exp0) cases substs assigns =
+ and can_match_with_env env (E_aux (e,(l,annot)) as exp0) cases (substs,ksubsts) assigns =
let rec findpat_generic check_pat description assigns = function
| [] -> (Reporting_basic.print_err false true l "Monomorphisation"
("Failed to find a case for " ^ description); None)
@@ -1373,7 +1473,7 @@ let split_defs all_errors splits defs =
Some (exp, [(id', exp0)], [])
| (Pat_aux (Pat_when (P_aux (P_id id',_),guard,exp),_))::tl
when pat_id_is_variable env id' -> begin
- let substs = Bindings.add id' exp0 substs in
+ let substs = Bindings.add id' exp0 substs, ksubsts in
let (E_aux (guard,_)),assigns = const_prop_exp substs assigns guard in
match guard with
| E_lit (L_aux (L_true,_)) -> Some (exp,[(id',exp0)],[])
@@ -1385,7 +1485,8 @@ let split_defs all_errors splits defs =
| DoesNotMatch -> findpat_generic check_pat description assigns tl
| DoesMatch (vsubst,ksubst) -> begin
let guard = nexp_subst_exp (kbindings_from_list ksubst) guard in
- let substs = bindings_union substs (bindings_from_list vsubst) in
+ let substs = bindings_union substs (bindings_from_list vsubst),
+ kbindings_union ksubsts (kbindings_from_list ksubst) in
let (E_aux (guard,_)),assigns = const_prop_exp substs assigns guard in
match guard with
| E_lit (L_aux (L_true,_)) -> Some (exp,vsubst,ksubst)
@@ -1463,8 +1564,8 @@ let split_defs all_errors splits defs =
can_match_with_env env exp
in
- let subst_exp substs exp =
- let substs = bindings_from_list substs in
+ let subst_exp substs ksubsts exp =
+ let substs = bindings_from_list substs, ksubsts in
fst (const_prop_exp substs Bindings.empty exp)
in
@@ -1813,8 +1914,9 @@ let split_defs all_errors splits defs =
| VarSplit patsubsts ->
if check_split_size patsubsts (pat_loc p) then
List.map (fun (pat',substs,pchoices,ksubsts) ->
- let exp' = nexp_subst_exp (kbindings_from_list ksubsts) e in
- let exp' = subst_exp substs exp' in
+ let ksubsts = kbindings_from_list ksubsts in
+ let exp' = nexp_subst_exp ksubsts e in
+ let exp' = subst_exp substs ksubsts exp' in
let exp' = apply_pat_choices pchoices exp' in
let exp' = stop_at_false_assertions exp' in
Pat_aux (Pat_exp (pat', map_exp exp'),l))
@@ -1833,11 +1935,12 @@ let split_defs all_errors splits defs =
| VarSplit patsubsts ->
if check_split_size patsubsts (pat_loc p) then
List.map (fun (pat',substs,pchoices,ksubsts) ->
- let exp1' = nexp_subst_exp (kbindings_from_list ksubsts) e1 in
- let exp1' = subst_exp substs exp1' in
+ let ksubsts = kbindings_from_list ksubsts in
+ let exp1' = nexp_subst_exp ksubsts e1 in
+ let exp1' = subst_exp substs ksubsts exp1' in
let exp1' = apply_pat_choices pchoices exp1' in
- let exp2' = nexp_subst_exp (kbindings_from_list ksubsts) e2 in
- let exp2' = subst_exp substs exp2' in
+ let exp2' = nexp_subst_exp ksubsts e2 in
+ let exp2' = subst_exp substs ksubsts exp2' in
let exp2' = apply_pat_choices pchoices exp2' in
let exp2' = stop_at_false_assertions exp2' in
Pat_aux (Pat_when (pat', map_exp exp1', map_exp exp2'),l))
@@ -1917,27 +2020,27 @@ let findi f =
let mapat f is xs =
let rec aux n = function
- | _, [] -> []
- | (i,_)::is, h::t when i = n ->
+ | [] -> []
+ | h::t when Util.IntSet.mem n is ->
let h' = f h in
- let t' = aux (n+1) (is, t) in
+ let t' = aux (n+1) t in
h'::t'
- | is, h::t ->
- let t' = aux (n+1) (is, t) in
+ | h::t ->
+ let t' = aux (n+1) t in
h::t'
- in aux 0 (is, xs)
+ in aux 0 xs
let mapat_extra f is xs =
let rec aux n = function
- | _, [] -> [], []
- | (i,v)::is, h::t when i = n ->
- let h',x = f v h in
- let t',xs = aux (n+1) (is, t) in
+ | [] -> [], []
+ | h::t when Util.IntSet.mem n is ->
+ let h',x = f h in
+ let t',xs = aux (n+1) t in
h'::t',x::xs
- | is, h::t ->
- let t',xs = aux (n+1) (is, t) in
+ | h::t ->
+ let t',xs = aux (n+1) t in
h::t',xs
- in aux 0 (is, xs)
+ in aux 0 xs
let tyvars_bound_in_pat pat =
let open Rewriter in
@@ -1975,34 +2078,45 @@ let sizes_of_annot = function
| _,None -> KidSet.empty
| _,Some (env,typ,_) -> sizes_of_typ (Env.base_typ_of env typ)
-let change_parameter_pat kid = function
- | P_aux (P_id var, (l,_))
- | P_aux (P_typ (_,P_aux (P_id var, (l,_))),_)
- -> P_aux (P_id var, (l,None)), (var,kid)
+let change_parameter_pat = function
+ | P_aux (P_id var, (l,Some (env,typ,_)))
+ | P_aux (P_typ (_,P_aux (P_id var, (l,Some (env,typ,_)))),_) ->
+ P_aux (P_id var, (l,None)), var
| P_aux (_,(l,_)) -> raise (Reporting_basic.err_unreachable l
"Expected variable pattern")
(* We add code to change the itself('n) parameter into the corresponding
integer. *)
-let add_var_rebind exp (var,kid) =
+let add_var_rebind exp var =
let l = Generated Unknown in
let annot = (l,None) in
E_aux (E_let (LB_aux (LB_val (P_aux (P_id var,annot),
E_aux (E_app (mk_id "size_itself_int",[E_aux (E_id var,annot)]),annot)),annot),exp),annot)
(* atom('n) arguments to function calls need to be rewritten *)
-let replace_with_the_value (E_aux (_,(l,_)) as exp) =
+let replace_with_the_value bound_nexps (E_aux (_,(l,_)) as exp) =
let env = env_of exp in
let typ, wrap = match typ_of exp with
| Typ_aux (Typ_exist (kids,nc,typ),l) -> typ, fun t -> Typ_aux (Typ_exist (kids,nc,t),l)
| typ -> typ, fun x -> x
in
let typ = Env.expand_synonyms env typ in
+ let replace_size size =
+ (* TODO: pick simpler nexp when there's a choice (also in pretty printer) *)
+ let is_equal nexp =
+ prove env (NC_aux (NC_equal (size,nexp), Parse_ast.Unknown))
+ in
+ if is_nexp_constant size then size else
+ match List.find is_equal bound_nexps with
+ | nexp -> nexp
+ | exception Not_found -> size
+ in
let mk_exp nexp l l' =
- E_aux (E_cast (wrap (Typ_aux (Typ_app (Id_aux (Id "itself",Generated Unknown),
- [Typ_arg_aux (Typ_arg_nexp nexp,l')]),Generated Unknown)),
- E_aux (E_app (Id_aux (Id "make_the_value",Generated Unknown),[exp]),(Generated l,None))),
- (Generated l,None))
+ let nexp = replace_size nexp in
+ E_aux (E_cast (wrap (Typ_aux (Typ_app (Id_aux (Id "itself",Generated Unknown),
+ [Typ_arg_aux (Typ_arg_nexp nexp,l')]),Generated Unknown)),
+ E_aux (E_app (Id_aux (Id "make_the_value",Generated Unknown),[exp]),(Generated l,None))),
+ (Generated l,None))
in
match typ with
| Typ_aux (Typ_app (Id_aux (Id "range",_),
@@ -2032,91 +2146,77 @@ let replace_type env typ =
let rewrite_size_parameters env (Defs defs) =
let open Rewriter in
- let size_vars pexp =
- fst (fold_pexp
- { (compute_exp_alg KidSet.empty KidSet.union) with
- e_aux = (fun ((s,e),annot) -> KidSet.union s (sizes_of_annot annot), E_aux (e,annot));
- e_let = (fun ((sl,lb),(s2,e2)) -> KidSet.union sl (KidSet.diff s2 (tyvars_bound_in_lb lb)), E_let (lb,e2));
- e_for = (fun (id,(s1,e1),(s2,e2),(s3,e3),ord,(s4,e4)) ->
- let kid = mk_kid ("loop_" ^ string_of_id id) in
- KidSet.union s1 (KidSet.union s2 (KidSet.union s3 (KidSet.remove kid s4))),
- E_for (id,e1,e2,e3,ord,e4));
- pat_exp = (fun ((sp,pat),(s,e)) -> KidSet.diff s (tyvars_bound_in_pat pat), Pat_exp (pat,e))}
- pexp)
- in
- let exposed_sizes_funcl fnsizes (FCL_aux (FCL_Funcl (id,pexp),(l,_))) =
- let sizes = size_vars pexp in
- let pat,guard,exp,pannot = destruct_pexp pexp in
- let visible_tyvars =
- KidSet.union
- (Pretty_print_lem.lem_tyvars_of_typ (pat_typ_of pat))
- (Pretty_print_lem.lem_tyvars_of_typ (typ_of exp))
- in
- let expose_tyvars = KidSet.diff sizes visible_tyvars in
- KidSet.union fnsizes expose_tyvars
- in
- let sizes_funcl expose_tyvars fsizes (FCL_aux (FCL_Funcl (id,pexp),(l,_))) =
+ let open Util in
+
+ let sizes_funcl fsizes (FCL_aux (FCL_Funcl (id,pexp),(l,_))) =
let pat,guard,exp,pannot = destruct_pexp pexp in
let parameters = match pat with
| P_aux (P_tup ps,_) -> ps
| _ -> [pat]
in
- let to_change = Util.map_filter
- (fun kid ->
- let check (P_aux (_,(_,Some (env,typ,_)))) =
- match Env.expand_synonyms env typ with
- Typ_aux (Typ_app(Id_aux (Id "range",_),
- [Typ_arg_aux (Typ_arg_nexp (Nexp_aux (Nexp_var kid',_)),_);
- Typ_arg_aux (Typ_arg_nexp (Nexp_aux (Nexp_var kid'',_)),_)]),_) ->
- if Kid.compare kid kid' = 0 && Kid.compare kid kid'' = 0 then Some kid else None
- | Typ_aux (Typ_app(Id_aux (Id "atom", _),
- [Typ_arg_aux (Typ_arg_nexp (Nexp_aux (Nexp_var kid',_)),_)]), _) ->
- if Kid.compare kid kid' = 0 then Some kid else None
- | _ -> None
- in match findi check parameters with
- | None -> (Reporting_basic.print_error (Reporting_basic.Err_general (l,
- ("Unable to find an argument for " ^ string_of_kid kid)));
- None)
- | Some i -> Some i)
- (KidSet.elements expose_tyvars)
+ let add_parameter (i,nmap) (P_aux (_,(_,Some (env,typ,_)))) =
+ let nmap =
+ match Env.base_typ_of env typ with
+ Typ_aux (Typ_app(Id_aux (Id "range",_),
+ [Typ_arg_aux (Typ_arg_nexp nexp,_);
+ Typ_arg_aux (Typ_arg_nexp nexp',_)]),_)
+ when Nexp.compare nexp nexp' = 0 && not (NexpMap.mem nexp nmap) ->
+ NexpMap.add nexp i nmap
+ | Typ_aux (Typ_app(Id_aux (Id "atom", _),
+ [Typ_arg_aux (Typ_arg_nexp nexp,_)]), _)
+ when not (NexpMap.mem nexp nmap) ->
+ NexpMap.add nexp i nmap
+ | _ -> nmap
+ in (i+1,nmap)
+ in
+ let (_,nexp_map) = List.fold_left add_parameter (0,NexpMap.empty) parameters in
+ let nexp_list = NexpMap.bindings nexp_map in
+ let parameters_for = function
+ | Some (env,typ,_) ->
+ begin match Env.base_typ_of env typ with
+ | Typ_aux (Typ_app (Id_aux (Id "vector",_), [Typ_arg_aux (Typ_arg_nexp size,_);_;_]),_)
+ when not (is_nexp_constant size) ->
+ begin
+ match NexpMap.find size nexp_map with
+ | i -> IntSet.singleton i
+ | exception Not_found ->
+ (* Look for equivalent nexps, but only in consistent type env *)
+ if prove env (NC_aux (NC_false,Unknown)) then IntSet.empty else
+ match List.find (fun (nexp,i) ->
+ prove env (NC_aux (NC_equal (nexp,size),Unknown))) nexp_list with
+ | _, i -> IntSet.singleton i
+ | exception Not_found -> IntSet.empty
+ end
+ | _ -> IntSet.empty
+ end
+ | None -> IntSet.empty
in
- let ik_compare (i,k) (i',k') =
- match compare (i : int) i' with
- | 0 -> Kid.compare k k'
- | x -> x
+ let parameters_to_rewrite =
+ fst (fold_pexp
+ { (compute_exp_alg IntSet.empty IntSet.union) with
+ e_aux = (fun ((s,e),(l,annot)) -> IntSet.union s (parameters_for annot),E_aux (e,(l,annot)))
+ } pexp)
in
- let to_change = List.sort ik_compare to_change in
+ let new_nexps = NexpSet.of_list (List.map fst
+ (List.filter (fun (nexp,i) -> IntSet.mem i parameters_to_rewrite) nexp_list)) in
match Bindings.find id fsizes with
- | old -> if List.for_all2 (fun x y -> ik_compare x y = 0) old to_change then fsizes else
- let str l = String.concat "," (List.map (fun (i,k) -> string_of_int i ^ "." ^ string_of_kid k) l) in
- raise (Reporting_basic.err_general l
- ("Different size type variables in different clauses of " ^ string_of_id id ^
- " old: " ^ str old ^ " new: " ^ str to_change))
- | exception Not_found -> Bindings.add id to_change fsizes
+ | old,old_nexps -> Bindings.add id (IntSet.union old parameters_to_rewrite,
+ NexpSet.union old_nexps new_nexps) fsizes
+ | exception Not_found -> Bindings.add id (parameters_to_rewrite, new_nexps) fsizes
in
let sizes_def fsizes = function
| DEF_fundef (FD_aux (FD_function (_,_,_,funcls),_)) ->
- let expose_tyvars = List.fold_left exposed_sizes_funcl KidSet.empty funcls in
- List.fold_left (sizes_funcl expose_tyvars) fsizes funcls
+ List.fold_left sizes_funcl fsizes funcls
| _ -> fsizes
in
let fn_sizes = List.fold_left sizes_def Bindings.empty defs in
- let rewrite_e_app (id,args) =
- match Bindings.find id fn_sizes with
- | [] -> E_app (id,args)
- | to_change ->
- let args' = mapat replace_with_the_value to_change args in
- E_app (id,args')
- | exception Not_found -> E_app (id,args)
- in
let rewrite_funcl (FCL_aux (FCL_Funcl (id,pexp),(l,annot))) =
let pat,guard,body,(pl,_) = destruct_pexp pexp in
- let pat,guard,body =
+ let pat,guard,body, nexps =
(* Update pattern and add itself -> nat wrapper to body *)
match Bindings.find id fn_sizes with
- | [] -> pat,guard,body
- | to_change ->
+ | to_change,nexps ->
let pat, vars =
match pat with
P_aux (P_tup pats,(l,_)) ->
@@ -2124,13 +2224,10 @@ let rewrite_size_parameters env (Defs defs) =
P_aux (P_tup pats,(l,None)), vars
| P_aux (_,(l,_)) ->
begin
- match to_change with
- | [0,kid] ->
- let pat, var = change_parameter_pat kid pat in
+ if IntSet.is_empty to_change then pat, []
+ else
+ let pat, var = change_parameter_pat pat in
pat, [var]
- | _ ->
- raise (Reporting_basic.err_unreachable l
- "Expected multiple parameters at single parameter")
end
in
(* TODO: only add bindings that are necessary (esp for guards) *)
@@ -2139,10 +2236,24 @@ let rewrite_size_parameters env (Defs defs) =
| None -> None
| Some exp -> Some (List.fold_left add_var_rebind exp vars)
in
- pat,guard,body
- | exception Not_found -> pat,guard,body
+ pat,guard,body,nexps
+ | exception Not_found -> pat,guard,body,NexpSet.empty
in
(* Update function applications *)
+ let funcl_typ = typ_of_annot (l,annot) in
+ let already_visible_nexps =
+ NexpSet.union
+ (Pretty_print_lem.lem_nexps_of_typ funcl_typ)
+ (Pretty_print_lem.typeclass_nexps funcl_typ)
+ in
+ let bound_nexps = NexpSet.elements (NexpSet.union nexps already_visible_nexps) in
+ let rewrite_e_app (id,args) =
+ match Bindings.find id fn_sizes with
+ | to_change,_ ->
+ let args' = mapat (replace_with_the_value bound_nexps) to_change args in
+ E_app (id,args')
+ | exception Not_found -> E_app (id,args)
+ in
let body = fold_exp { id_exp_alg with e_app = rewrite_e_app } body in
let guard = match guard with
| None -> None
@@ -2156,8 +2267,7 @@ let rewrite_size_parameters env (Defs defs) =
| DEF_spec (VS_aux (VS_val_spec (typschm,id,extern,cast),(l,annot))) as spec ->
begin
match Bindings.find id fn_sizes with
- | [] -> spec
- | to_change ->
+ | to_change,_ when not (IntSet.is_empty to_change) ->
let typschm = match typschm with
| TypSchm_aux (TypSchm_ts (tq,typ),l) ->
let typ = match typ with
@@ -2169,6 +2279,7 @@ let rewrite_size_parameters env (Defs defs) =
in TypSchm_aux (TypSchm_ts (tq,typ),l)
in
DEF_spec (VS_aux (VS_val_spec (typschm,id,extern,cast),(l,None)))
+ | _ -> spec
| exception Not_found -> spec
end
| def -> def
diff --git a/src/pattern_completeness.ml b/src/pattern_completeness.ml
index 123592a3..ebb402e5 100644
--- a/src/pattern_completeness.ml
+++ b/src/pattern_completeness.ml
@@ -58,13 +58,6 @@ type ctx =
variants : IdSet.t Bindings.t
}
-let hex_to_bin hex =
- Util.string_to_list hex
- |> List.map Sail_lib.hex_char
- |> List.concat
- |> List.map Sail_lib.char_of_bit
- |> (fun bits -> String.init (List.length bits) (List.nth bits))
-
type gpat =
| GP_lit of lit
| GP_wild
diff --git a/src/pretty_print_lem.ml b/src/pretty_print_lem.ml
index 38862382..ac8ad48d 100644
--- a/src/pretty_print_lem.ml
+++ b/src/pretty_print_lem.ml
@@ -179,10 +179,11 @@ let doc_nexp_lem nexp =
| Nexp_minus (n1, n2) -> mangle_nexp n1 ^ "_minus_" ^ mangle_nexp n2
| Nexp_exp n -> "exp_" ^ mangle_nexp n
| Nexp_neg n -> "neg_" ^ mangle_nexp n
+ | _ ->
+ raise (Reporting_basic.err_unreachable l
+ ("cannot pretty-print nexp \"" ^ string_of_nexp full_nexp ^ "\""))
end in
string ("'" ^ mangle_nexp full_nexp)
- (* raise (Reporting_basic.err_unreachable l
- ("cannot pretty-print non-atomic nexp \"" ^ string_of_nexp full_nexp ^ "\"")) *)
(* Rewrite mangled names of type variables to the original names *)
let rec orig_nexp (Nexp_aux (nexp, l)) =
@@ -321,12 +322,30 @@ let contains_t_pp_var ctxt (Typ_aux (t,a) as typ) =
NexpSet.diff (lem_nexps_of_typ typ) ctxt.bound_nexps
|> NexpSet.exists (fun nexp -> not (is_nexp_constant nexp))
-let doc_tannot_lem ctxt eff typ =
- if contains_t_pp_var ctxt typ then empty
- else
+let replace_typ_size ctxt env (Typ_aux (t,a)) =
+ match t with
+ | Typ_app (Id_aux (Id "vector",_) as id, [Typ_arg_aux (Typ_arg_nexp size,_);ord;typ']) ->
+ begin
+ let is_equal nexp =
+ prove env (NC_aux (NC_equal (size,nexp),Parse_ast.Unknown))
+ in match List.find is_equal (NexpSet.elements ctxt.bound_nexps) with
+ | nexp -> Some (Typ_aux (Typ_app (id, [Typ_arg_aux (Typ_arg_nexp nexp,Parse_ast.Unknown);ord;typ']),a))
+ | exception Not_found -> None
+ end
+ | _ -> None
+
+let doc_tannot_lem ctxt env eff typ =
+ let of_typ typ =
let ta = doc_typ_lem typ in
if eff then string " : M " ^^ parens ta
else string " : " ^^ ta
+ in
+ if contains_t_pp_var ctxt typ
+ then
+ match replace_typ_size ctxt env typ with
+ | None -> empty
+ | Some typ -> of_typ typ
+ else of_typ typ
let doc_lit_lem (L_aux(lit,l)) =
match lit with
@@ -676,10 +695,11 @@ let doc_exp_lem, doc_let_lem =
let argspp = align (separate_map (break 1) (expV true) args) in
let epp = align (call ^//^ argspp) in
let (taepp,aexp_needed) =
- let t = Env.expand_synonyms (env_of full_exp) (typ_of full_exp) in
+ let env = env_of full_exp in
+ let t = Env.expand_synonyms env (typ_of full_exp) in
let eff = effect_of full_exp in
if typ_needs_printed t
- then (align epp ^^ (doc_tannot_lem ctxt (effectful eff) t), true)
+ then (align epp ^^ (doc_tannot_lem ctxt env (effectful eff) t), true)
else (epp, aexp_needed) in
liftR (if aexp_needed then parens (align taepp) else taepp)
end
@@ -714,7 +734,7 @@ let doc_exp_lem, doc_let_lem =
if has_effect eff BE_rreg then
let epp = separate space [string "read_reg";doc_id_lem id] in
if is_bitvector_typ base_typ
- then liftR (parens (epp ^^ doc_tannot_lem ctxt true base_typ))
+ then liftR (parens (epp ^^ doc_tannot_lem ctxt env true base_typ))
else liftR epp
else if is_ctor env id then doc_id_lem_ctor id
else doc_id_lem id
@@ -768,7 +788,7 @@ let doc_exp_lem, doc_let_lem =
let (epp,aexp_needed) =
if is_bit_typ etyp && !opt_mwords then
let bepp = string "of_bits" ^^ space ^^ parens (align epp) in
- (bepp ^^ doc_tannot_lem ctxt false t, true)
+ (bepp ^^ doc_tannot_lem ctxt (env_of full_exp) false t, true)
else (epp,aexp_needed) in
if aexp_needed then parens (align epp) else epp
| E_vector_update(v,e1,e2) ->
@@ -912,7 +932,7 @@ let doc_typdef_lem (TD_aux(td, (l, annot))) = match td with
mk_typ (Typ_app (Id_aux (Id "field_ref", Parse_ast.Unknown),
[mk_typ_arg (Typ_arg_typ rectyp);
mk_typ_arg (Typ_arg_typ ftyp)])) in
- let rfannot = doc_tannot_lem empty_ctxt false reftyp in
+ let rfannot = doc_tannot_lem empty_ctxt env false reftyp in
let get, set =
string "rec_val" ^^ dot ^^ fname fid,
anglebars (space ^^ string "rec_val with " ^^
@@ -1342,7 +1362,7 @@ let doc_regtype_fields (tname, (n1, n2, fields)) =
mk_typ (Typ_app (Id_aux (Id "field_ref", Parse_ast.Unknown),
[mk_typ_arg (Typ_arg_typ (mk_id_typ (mk_id tname)));
mk_typ_arg (Typ_arg_typ ftyp)])) in
- let rfannot = doc_tannot_lem empty_ctxt false reftyp in
+ let rfannot = doc_tannot_lem empty_ctxt Env.empty false reftyp in
doc_op equals
(concat [string "let "; parens (concat [string tname; underscore; doc_id_lem fid; rfannot])])
(concat [
diff --git a/src/pretty_print_sail.ml b/src/pretty_print_sail.ml
index 1dac7a1c..7620ca50 100644
--- a/src/pretty_print_sail.ml
+++ b/src/pretty_print_sail.ml
@@ -265,7 +265,7 @@ let fixities =
(mk_id "|", (InfixR, 2));
]
in
- ref Bindings.empty (*(fixities' : (prec * int) Bindings.t)*)
+ ref (fixities' : (prec * int) Bindings.t)
let rec doc_exp (E_aux (e_aux, _) as exp) =
match e_aux with
diff --git a/src/rewrites.ml b/src/rewrites.ml
index 9cba6b39..cc3df801 100644
--- a/src/rewrites.ml
+++ b/src/rewrites.ml
@@ -3005,6 +3005,7 @@ let rewrite_defs_c = [
("simple_assignments", rewrite_simple_assignments);
("remove_vector_concat", rewrite_defs_remove_vector_concat);
("remove_bitvector_pats", rewrite_defs_remove_bitvector_pats);
+ ("guarded_pats", rewrite_defs_guarded_pats);
("exp_lift_assign", rewrite_defs_exp_lift_assign);
("constraint", rewrite_constraint);
("trivial_sizeof", rewrite_trivial_sizeof);
diff --git a/src/sail.ml b/src/sail.ml
index 35a7279b..95e060b2 100644
--- a/src/sail.ml
+++ b/src/sail.ml
@@ -88,7 +88,7 @@ let options = Arg.align ([
Arg.Tuple [Arg.Set opt_print_ocaml; Arg.Set Initial_check.opt_undefined_gen; Arg.Set Ocaml_backend.opt_trace_ocaml],
" output an OCaml translated version of the input with tracing instrumentation, implies -ocaml");
( "-c",
- Arg.Tuple [Arg.Set opt_print_c; (* Arg.Set Initial_check.opt_undefined_gen *)],
+ Arg.Tuple [Arg.Set opt_print_c; Arg.Set Initial_check.opt_undefined_gen],
" output a C translated version of the input");
( "-lem_ast",
Arg.Set opt_print_lem_ast,
diff --git a/src/specialize.ml b/src/specialize.ml
index efa8783e..2ebc7307 100644
--- a/src/specialize.ml
+++ b/src/specialize.ml
@@ -80,6 +80,10 @@ let id_of_instantiation id instantiation =
let str = Util.zencode_string (Util.string_of_list ", " string_of_binding (KBindings.bindings instantiation)) ^ "#" in
prepend_id str id
+let string_of_instantiation instantiation =
+ let string_of_binding (kid, uvar) = string_of_kid kid ^ " => " ^ Type_check.string_of_uvar uvar in
+ Util.zencode_string (Util.string_of_list ", " string_of_binding (KBindings.bindings instantiation))
+
(* Returns a list of all the instantiations of a function id in an
ast. *)
let rec instantiations_of id ast =
@@ -161,6 +165,7 @@ let specialize_id_valspec instantiations id ast =
let typschm = mk_typschm typq typ in
let spec_id = id_of_instantiation id instantiation in
+
if IdSet.mem spec_id !spec_ids then [] else
begin
spec_ids := IdSet.add spec_id !spec_ids;
@@ -209,7 +214,7 @@ let specialize_id_overloads instantiations id (Defs defs) =
valspecs are then re-specialized. This process is iterated until
the whole spec is specialized. *)
let remove_unused_valspecs ast =
- let calls = ref (IdSet.singleton (mk_id "main")) in
+ let calls = ref (IdSet.of_list [mk_id "main"; mk_id "execute"; mk_id "decode"; mk_id "initialize_registers"]) in
let vs_ids = Initial_check.val_spec_ids ast in
let inspect_exp = function
diff --git a/test/c/gvector.expect b/test/c/gvector.expect
new file mode 100644
index 00000000..ae7bf842
--- /dev/null
+++ b/test/c/gvector.expect
@@ -0,0 +1,3 @@
+T[1] = 5
+y[1] = 5
+R[0] = 32'0xDEADBEEF
diff --git a/test/c/gvector.sail b/test/c/gvector.sail
new file mode 100644
index 00000000..e7553644
--- /dev/null
+++ b/test/c/gvector.sail
@@ -0,0 +1,20 @@
+default Order dec
+
+$include <vector_dec.sail>
+
+val "print_int" : (string, int) -> unit
+
+register R : vector(32, dec, vector(32, dec, bit))
+
+register T : vector(32, dec, int)
+
+val main : unit -> unit effect {rreg, wreg}
+
+function main () = {
+ R[0] = 0xDEAD_BEEF;
+ T[1] = 5;
+ print_int("T[1] = ", T[1]);
+ let y = T;
+ print_int("y[1] = ", y[1]);
+ print_bits("R[0] = ", R[0]);
+} \ No newline at end of file
diff --git a/test/c/sail.h b/test/c/sail.h
index 033d791e..880f5a57 100644
--- a/test/c/sail.h
+++ b/test/c/sail.h
@@ -26,10 +26,50 @@ void sail_match_failure(void) {
exit(1);
}
+unit sail_assert(bool b, sail_string msg) {
+ if (b) return UNIT;
+ fprintf(stderr, "Assertion failed: %s\n", msg);
+ exit(1);
+}
+
+unit sail_exit(const unit u) {
+ fprintf(stderr, "Unexpected exit\n");
+ exit(1);
+}
+
+void elf_entry(mpz_t *rop, const unit u) {
+ mpz_set_ui(*rop, 0x8000ul);
+}
+
+// Sail bits are mapped to ints where bitzero = 0 and bitone = 1
+bool eq_bit(const int a, const int b) {
+ return a == b;
+}
+
+int undefined_bit(unit u) { return 0; }
+
+// ***** Sail booleans *****
+
bool not(const bool b) {
return !b;
}
+bool and_bool(const bool a, const bool b) {
+ return a && b;
+}
+
+bool or_bool(const bool a, const bool b) {
+ return a || b;
+}
+
+bool eq_bool(const bool a, const bool b) {
+ return a == b;
+}
+
+bool undefined_bool(const unit u) {
+ return false;
+}
+
// ***** Sail strings *****
void init_sail_string(sail_string *str) {
char *istr = (char *) malloc(1 * sizeof(char));
@@ -47,11 +87,20 @@ void clear_sail_string(sail_string *str) {
free(*str);
}
+bool eq_string(const sail_string str1, const sail_string str2) {
+ return strcmp(str1, str2) == 0;
+}
+
unit print_endline(sail_string str) {
printf("%s\n", str);
return UNIT;
}
+unit prerr_endline(sail_string str) {
+ fprintf(stderr, "%s\n", str);
+ return UNIT;
+}
+
unit print_int(const sail_string str, const mpz_t op) {
fputs(str, stdout);
mpz_out_str(stdout, 10, op);
@@ -64,7 +113,12 @@ unit print_int64(const sail_string str, const int64_t op) {
return UNIT;
}
-// ***** Multiple precision integers *****
+unit sail_putchar(const mpz_t op) {
+ char c = (char) mpz_get_ui(op);
+ putchar(c);
+}
+
+// ***** Arbitrary precision integers *****
// We wrap around the GMP functions so they follow a consistent naming
// scheme that is shared with the other builtin sail types.
@@ -89,6 +143,10 @@ void init_mpz_t_of_sail_string(mpz_t *rop, sail_string str) {
mpz_init_set_str(*rop, str, 10);
}
+int64_t convert_int64_t_of_mpz_t(const mpz_t op) {
+ return mpz_get_si(op);
+}
+
// ***** Sail builtins for integers *****
bool eq_int(const mpz_t op1, const mpz_t op2) {
@@ -103,6 +161,26 @@ bool gt(const mpz_t op1, const mpz_t op2) {
return mpz_cmp(op1, op2) > 0;
}
+bool lteq(const mpz_t op1, const mpz_t op2) {
+ return mpz_cmp(op1, op2) <= 0;
+}
+
+bool gteq(const mpz_t op1, const mpz_t op2) {
+ return mpz_cmp(op1, op2) >= 0;
+}
+
+void shl_int(mpz_t *rop, const mpz_t op1, const mpz_t op2) {
+ mpz_mul_2exp(*rop, op1, mpz_get_ui(op2));
+}
+
+void undefined_int(mpz_t *rop, const unit u) {
+ mpz_set_ui(*rop, 0ul);
+}
+
+void undefined_range(mpz_t *rop, const mpz_t l, const mpz_t u) {
+ mpz_set(*rop, l);
+}
+
void add_int(mpz_t *rop, const mpz_t op1, const mpz_t op2)
{
mpz_add(*rop, op1, op2);
@@ -113,6 +191,37 @@ void sub_int(mpz_t *rop, const mpz_t op1, const mpz_t op2)
mpz_sub(*rop, op1, op2);
}
+void mult_int(mpz_t *rop, const mpz_t op1, const mpz_t op2)
+{
+ mpz_mul(*rop, op1, op2);
+}
+
+void div_int(mpz_t *rop, const mpz_t op1, const mpz_t op2)
+{
+ mpz_tdiv_q(*rop, op1, op2);
+}
+
+void mod_int(mpz_t *rop, const mpz_t op1, const mpz_t op2)
+{
+ mpz_tdiv_r(*rop, op1, op2);
+}
+
+void max_int(mpz_t *rop, const mpz_t op1, const mpz_t op2) {
+ if (lt(op1, op2)) {
+ mpz_set(*rop, op2);
+ } else {
+ mpz_set(*rop, op1);
+ }
+}
+
+void min_int(mpz_t *rop, const mpz_t op1, const mpz_t op2) {
+ if (gt(op1, op2)) {
+ mpz_set(*rop, op2);
+ } else {
+ mpz_set(*rop, op1);
+ }
+}
+
void neg_int(mpz_t *rop, const mpz_t op) {
mpz_neg(*rop, op);
}
@@ -121,8 +230,20 @@ void abs_int(mpz_t *rop, const mpz_t op) {
mpz_abs(*rop, op);
}
+void pow2(mpz_t *rop, mpz_t exp) {
+ uint64_t exp_ui = mpz_get_ui(exp);
+ mpz_t base;
+ mpz_init_set_ui(base, 2ul);
+ mpz_pow_ui(*rop, base, exp_ui);
+ mpz_clear(base);
+}
+
// ***** Sail bitvectors *****
+void length_bv_t(mpz_t *rop, const bv_t op) {
+ mpz_set_ui(*rop, op.len);
+}
+
void init_bv_t(bv_t *rop) {
rop->bits = malloc(sizeof(mpz_t));
rop->len = 0;
@@ -140,12 +261,32 @@ void set_bv_t(bv_t *rop, const bv_t op) {
mpz_set(*rop->bits, *op.bits);
}
-void append_64(bv_t *rop, bv_t op, const uint64_t chunk) {
+void append_64(bv_t *rop, const bv_t op, const uint64_t chunk) {
rop->len = rop->len + 64ul;
mpz_mul_2exp(*rop->bits, *op.bits, 64ul);
mpz_add_ui(*rop->bits, *rop->bits, chunk);
}
+void append(bv_t *rop, const bv_t op1, const bv_t op2) {
+ rop->len = op1.len + op2.len;
+ mpz_mul_2exp(*rop->bits, *op1.bits, op2.len);
+ mpz_add(*rop->bits, *rop->bits, *op2.bits);
+}
+
+void replicate_bits(bv_t *rop, const bv_t op1, const mpz_t op2) {
+ uint64_t op2_ui = mpz_get_ui(op2);
+ rop->len = op1.len * op2_ui;
+ mpz_set(*rop->bits, *op1.bits);
+ for (int i = 1; i < op2_ui; i++) {
+ mpz_mul_2exp(*rop->bits, *rop->bits, op2_ui);
+ mpz_add(*rop->bits, *rop->bits, *op1.bits);
+ }
+}
+
+void slice(bv_t *rop, const bv_t op, const mpz_t i, const mpz_t len) {
+ // TODO
+}
+
uint64_t convert_uint64_t_of_bv_t(const bv_t op) {
return mpz_get_ui(*op.bits);
}
@@ -165,6 +306,10 @@ void clear_bv_t(bv_t *rop) {
free(rop->bits);
}
+void undefined_bv_t(bv_t *rop, mpz_t len, int bit) {
+ zeros(rop, len);
+}
+
void mask(bv_t *rop) {
if (mpz_sizeinbase(*rop->bits, 2) > rop->len) {
mpz_t m;
@@ -176,6 +321,49 @@ void mask(bv_t *rop) {
}
}
+void and_bits(bv_t *rop, const bv_t op1, const bv_t op2) {
+ rop->len = op1.len;
+ mpz_and(*rop->bits, *op1.bits, *op2.bits);
+}
+
+void or_bits(bv_t *rop, const bv_t op1, const bv_t op2) {
+ rop->len = op1.len;
+ mpz_ior(*rop->bits, *op1.bits, *op2.bits);
+}
+
+void not_bits(bv_t *rop, const bv_t op) {
+ rop->len = op.len;
+ mpz_com(*rop->bits, *op.bits);
+}
+
+void xor_bits(bv_t *rop, const bv_t op1, const bv_t op2) {
+ rop->len = op1.len;
+ mpz_xor(*rop->bits, *op1.bits, *op2.bits);
+}
+
+bool eq_bits(const bv_t op1, const bv_t op2) {
+ return mpz_cmp(*op1.bits, *op2.bits) == 0;
+}
+
+void sail_uint(mpz_t *rop, const bv_t op) {
+ mpz_set(*rop, *op.bits);
+}
+
+void sint(mpz_t *rop, const bv_t op) {
+ if (mpz_tstbit(*op.bits, op.len - 1)) {
+ mpz_set(*rop, *op.bits);
+ mpz_clrbit(*rop, op.len - 1);
+ mpz_t x;
+ mpz_init(x);
+ mpz_setbit(x, op.len - 1);
+ mpz_neg(x, x);
+ mpz_add(*rop, *rop, *op.bits);
+ mpz_clear(x);
+ } else {
+ mpz_set(*rop, *op.bits);
+ }
+}
+
void add_bits(bv_t *rop, const bv_t op1, const bv_t op2) {
rop->len = op1.len;
mpz_add(*rop->bits, *op1.bits, *op2.bits);
@@ -188,9 +376,161 @@ void add_bits_int(bv_t *rop, const bv_t op1, const mpz_t op2) {
mask(rop);
}
+void sub_bits_int(bv_t *rop, const bv_t op1, const mpz_t op2) {
+ rop->len = op1.len;
+ mpz_sub(*rop->bits, *op1.bits, op2);
+ mask(rop);
+}
+
+void get_slice_int(bv_t *rop, const mpz_t n, const mpz_t m, const mpz_t o) {
+ // TODO
+}
+
+void set_slice_int(mpz_t *rop, const mpz_t n, const mpz_t m, const mpz_t o, const bv_t op) {
+ // TODO
+}
+
+void vector_update_subrange_bv_t(bv_t *rop, const bv_t op, const mpz_t n, const mpz_t m, const bv_t slice) {
+ // TODO
+}
+
+void vector_subrange_bv_t(bv_t *rop, const bv_t op, const mpz_t n, const mpz_t m) {
+ // TODO
+}
+
+int bitvector_access(const bv_t op, const mpz_t n) {
+ return 0; // TODO
+}
+
+void hex_slice (bv_t *rop, const sail_string hex, const mpz_t n, const mpz_t m) {
+ // TODO
+}
+
+void set_slice (bv_t *rop, const mpz_t len, const mpz_t slen, const bv_t op, const mpz_t i, const bv_t slice) {
+ // TODO
+}
+
unit print_bits(const sail_string str, const bv_t op) {
fputs(str, stdout);
gmp_printf("%d'0x%ZX\n", op.len, op.bits);
}
+// ***** Real number implementation *****
+
+#define REAL_FLOAT
+
+#ifdef REAL_FLOAT
+
+typedef mpf_t real;
+
+#define FLOAT_PRECISION 255
+
+void setup_real(void) {
+ mpf_set_default_prec(FLOAT_PRECISION);
+}
+
+void init_real(real *rop) {
+ mpf_init(*rop);
+}
+
+void clear_real(real *rop) {
+ mpf_clear(*rop);
+}
+
+void set_real(real *rop, const real op) {
+ mpf_set(*rop, op);
+}
+
+void undefined_real(real *rop, unit u) {
+ mpf_set_ui(*rop, 0ul);
+}
+
+void neg_real(real *rop, const real op) {
+ mpf_neg(*rop, op);
+}
+
+void mult_real(real *rop, const real op1, const real op2) {
+ mpf_mul(*rop, op1, op2);
+}
+
+void sub_real(real *rop, const real op1, const real op2) {
+ mpf_sub(*rop, op1, op2);
+}
+
+void add_real(real *rop, const real op1, const real op2) {
+ mpf_add(*rop, op1, op2);
+}
+
+void div_real(real *rop, const real op1, const real op2) {
+ mpf_div(*rop, op1, op2);
+}
+
+void sqrt_real(real *rop, const real op) {
+ mpf_sqrt(*rop, op);
+}
+
+void abs_real(real *rop, const real op) {
+ mpf_abs(*rop, op);
+}
+
+void round_up(mpz_t *rop, const real op) {
+ mpf_t x;
+ mpf_ceil(x, op);
+ mpz_set_ui(*rop, mpf_get_ui(x));
+ mpf_clear(x);
+}
+
+void round_down(mpz_t *rop, const real op) {
+ mpf_t x;
+ mpf_floor(x, op);
+ mpz_set_ui(*rop, mpf_get_ui(x));
+ mpf_clear(x);
+}
+
+void to_real(real *rop, const mpz_t op) {
+ mpf_set_z(*rop, op);
+}
+
+bool eq_real(const real op1, const real op2) {
+ return mpf_cmp(op1, op2) == 0;
+}
+
+bool lt_real(const real op1, const real op2) {
+ return mpf_cmp(op1, op2) < 0;
+}
+
+bool gt_real(const real op1, const real op2) {
+ return mpf_cmp(op1, op2) > 0;
+}
+
+bool lteq_real(const real op1, const real op2) {
+ return mpf_cmp(op1, op2) <= 0;
+}
+
+bool gteq_real(const real op1, const real op2) {
+ return mpf_cmp(op1, op2) >= 0;
+}
+
+void real_power(real *rop, const real base, const mpz_t exp) {
+ uint64_t exp_ui = mpz_get_ui(exp);
+ mpf_pow_ui(*rop, base, exp_ui);
+}
+
+void init_real_of_sail_string(real *rop, const sail_string op) {
+ // FIXME
+ mpf_init(*rop);
+}
+
#endif
+
+#endif
+
+// ***** Memory *****
+
+unit write_ram(const mpz_t m, const mpz_t n, const bv_t x, const bv_t y, const bv_t data) {
+ return UNIT;
+}
+
+void read_ram(bv_t *data, const mpz_t m, const mpz_t n, const bv_t x, const bv_t y) {
+ // TODO
+}
diff --git a/test/ocaml/bitfield/bitfield.sail b/test/ocaml/bitfield/bitfield.sail
index 5a70d52e..2a53ab3c 100644
--- a/test/ocaml/bitfield/bitfield.sail
+++ b/test/ocaml/bitfield/bitfield.sail
@@ -12,7 +12,8 @@ bitfield cr : bits(8) = {
register CR : cr
bitfield dr : vector(4, inc, bit) = {
- DR0 : 2 .. 3
+ DR0 : 2 .. 3,
+ LT : 2
}
register DR : dr
@@ -28,5 +29,9 @@ function main () = {
print_bits("CR.CR0: ", CR.CR0());
print_bits("CR: ", CR.bits());
CR->CR3() = 0b0;
- print_bits("CR: ", CR.bits())
+ print_bits("CR: ", CR.bits());
+ CR = update_CR1(CR, 0b11);
+ print_bits("CR.CR1: ", CR.CR1());
+ CR = update_CR1(CR, 0b01);
+ print_bits("CR.CR1: ", CR.CR1());
}
diff --git a/test/ocaml/bitfield/expect b/test/ocaml/bitfield/expect
index 63247dfd..e6e5a618 100644
--- a/test/ocaml/bitfield/expect
+++ b/test/ocaml/bitfield/expect
@@ -3,3 +3,5 @@ CR: 0x0F
CR.CR0: 0x8
CR: 0x8F
CR: 0x8E
+CR.CR1: 0b11
+CR.CR1: 0b01