diff options
58 files changed, 1407 insertions, 592 deletions
diff --git a/aarch64_small/armV8.h.sail b/aarch64_small/armV8.h.sail index f5c5aa1e..30309e30 100644 --- a/aarch64_small/armV8.h.sail +++ b/aarch64_small/armV8.h.sail @@ -98,14 +98,47 @@ register R2 : bits(64) register R1 : bits(64) register R0 : bits(64) -let _R : vector(32,dec,(register(bits(64)))) = - [undefined,R30,R29,R28,R27,R26,R25,R24,R23,R22,R21, - R20,R19,R18,R17,R16,R15,R14,R13,R12,R11, - R10,R9 ,R8 ,R7 ,R6 ,R5 ,R4 ,R3 ,R2 ,R1 , - R0] +/* let _R : vector(32,dec,register(bits(64))) = */ +/* [ref undefined */ +let _R : vector(31,dec,register(bits(64))) = + [ref R30 + ,ref R29 + ,ref R28 + ,ref R27 + ,ref R26 + ,ref R25 + ,ref R24 + ,ref R23 + ,ref R22 + ,ref R21 + ,ref R20 + ,ref R19 + ,ref R18 + ,ref R17 + ,ref R16 + ,ref R15 + ,ref R14 + ,ref R13 + ,ref R12 + ,ref R11 + ,ref R10 + ,ref R9 + ,ref R8 + ,ref R7 + ,ref R6 + ,ref R5 + ,ref R4 + ,ref R3 + ,ref R2 + ,ref R1 + ,ref R0 + ] -val reg_index : reg_size -> UInt_reg effect pure -function reg_index x = (x : (reg_index)) +/* val reg_index : reg_size -> UInt_reg effect pure */ +/* function reg_index x = (x : (reg_index)) */ + +val reg_index : reg_size -> reg_index +function reg_index x = unsigned(x) /* SIMD and floating-point registers */ @@ -142,11 +175,46 @@ register V2 : bits(128) register V1 : bits(128) register V0 : bits(128) +/* let _V : vector(33,dec,(register(bits(128)))) = */ +/* [undefined,V31,V30,V29,V28,V27,V26,V25,V24,V23,V22, */ +/* V21,V20,V19,V18,V17,V16,V15,V14,V13,V12, */ +/* V11,V10,V9 ,V8 ,V7 ,V6 ,V5 ,V4 ,V3 ,V2 , */ +/* V1 ,V0] */ + let _V : vector(32,dec,(register(bits(128)))) = - [undefined,V31,V30,V29,V28,V27,V26,V25,V24,V23,V22, - V21,V20,V19,V18,V17,V16,V15,V14,V13,V12, - V11,V10,V9 ,V8 ,V7 ,V6 ,V5 ,V4 ,V3 ,V2 , - V1 ,V0] + [ref V31 + ,ref V30 + ,ref V29 + ,ref V28 + ,ref V27 + ,ref V26 + ,ref V25 + ,ref V24 + ,ref V23 + ,ref V22 + ,ref V21 + ,ref V20 + ,ref V19 + ,ref V18 + ,ref V17 + ,ref V16 + ,ref V15 + ,ref V14 + ,ref V13 + ,ref V12 + ,ref V11 + ,ref V10 + ,ref V9 + ,ref V8 + ,ref V7 + ,ref V6 + ,ref V5 + ,ref V4 + ,ref V3 + ,ref V2 + ,ref V1 + ,ref V0 + ] /* lsl: used instead of the ARM ARM << over integers */ @@ -154,22 +222,22 @@ val lsl : forall 'm 'n, 'm >= 0 & 'n >= 0. (atom('n), atom('m)) -> atom('n * (2 function lsl (n, m) = n * (2 ^ m) /* not_implemented is used to indicate something WE did not implement */ -val not_implemented : string -> unit effect { escape } +val not_implemented : forall ('a : Type). string -> 'a effect { escape } function not_implemented message = exit () /* TODO message */ /* not_implemented_extern is used to indicate something ARM did not define and we did not define yet either. Those functions used to be declared as external but undefined there. */ -val not_implemented_extern : forall 'a. string -> 'a effect { escape } +/* val not_implemented_extern : forall 'a. string -> 'a effect { escape } */ +val not_implemented_extern : forall ('a : Type). string -> 'a effect { escape } function not_implemented_extern (message) = exit () /* message; TODO */ /* info is used to convey information to the user */ val info : string -> unit effect pure -let info(message) = () +function info(message) = () -struct IMPLEMENTATION_DEFINED_type = -{ +struct IMPLEMENTATION_DEFINED_type = { HaveCRCExt : boolean, HaveAArch32EL : boolean, HaveAnyAArch32 : boolean, @@ -178,15 +246,15 @@ struct IMPLEMENTATION_DEFINED_type = HighestELUsingAArch32 : boolean, IsSecureBelowEL3 : boolean, } -let IMPLEMENTATION_DEFINED = -{ - HaveCRCExt = true; - HaveAArch32EL = false; - HaveAnyAArch32 = false; - HaveEL2 = false; - HaveEL3 = false; - HighestELUsingAArch32 = false; - IsSecureBelowEL3 = false; + +let IMPLEMENTATION_DEFINED : IMPLEMENTATION_DEFINED_type = struct { + HaveCRCExt = true, + HaveAArch32EL = false, + HaveAnyAArch32 = false, + HaveEL2 = false, + HaveEL3 = false, + HighestELUsingAArch32 = false, + IsSecureBelowEL3 = false } /* FIXME: ask Kathy what should we do with this */ diff --git a/aarch64_small/armV8_common_lib.sail b/aarch64_small/armV8_common_lib.sail index a7e14c62..b28dc462 100644 --- a/aarch64_small/armV8_common_lib.sail +++ b/aarch64_small/armV8_common_lib.sail @@ -34,12 +34,11 @@ /** FUNCTION:shared/debug/DoubleLockStatus/DoubleLockStatus */ -function boolean DoubleLockStatus() = -{ +function DoubleLockStatus() -> boolean= { if ELUsingAArch32(EL1) then - (DBGOSDLR.DLK == 1 & DBGPRCR.CORENPDRQ == 0 & ~(Halted())) + (DBGOSDLR.DLK() == 0b1 & DBGPRCR.CORENPDRQ() == 0b0 & ~(Halted())) else - (OSDLR_EL1.DLK == 1 & DBGPRCR_EL1.CORENPDRQ == 0 & ~(Halted())); + (OSDLR_EL1.DLK() == 0b1 & DBGPRCR_EL1.CORENPDRQ() == 0b0 & ~(Halted())); } /** FUNCTION:shared/debug/authentication/Debug_authentication */ @@ -48,15 +47,14 @@ function boolean DoubleLockStatus() = enum signalValue = {LOw, HIGH} -function signalValue signalDBGEN () = not_implemented_extern("signalDBGEN") -function signalValue signelNIDEN () = not_implemented_extern("signalNIDEN") -function signalValue signalSPIDEN () = not_implemented_extern("signalSPIDEN") -function signalValue signalDPNIDEN () = not_implemented_extern("signalSPNIDEN") +function signalDBGEN () -> signalValue = not_implemented_extern("signalDBGEN") +function signelNIDEN () -> signalValue = not_implemented_extern("signalNIDEN") +function signalSPIDEN () -> signalValue = not_implemented_extern("signalSPIDEN") +function signalDPNIDEN () -> signalValue = not_implemented_extern("signalSPNIDEN") /** FUNCTION:shared/debug/authentication/ExternalInvasiveDebugEnabled */ -function boolean ExternalInvasiveDebugEnabled() = -{ +function ExternalInvasiveDebugEnabled() -> boolean = { /* In the recommended interface, ExternalInvasiveDebugEnabled returns the state of the DBGEN */ /* signal. */ signalDBGEN() == HIGH; @@ -64,8 +62,7 @@ function boolean ExternalInvasiveDebugEnabled() = /** FUNCTION:shared/debug/authentication/ExternalSecureInvasiveDebugEnabled */ -function boolean ExternalSecureInvasiveDebugEnabled() = -{ +function ExternalSecureInvasiveDebugEnabled() -> boolean = { /* In the recommended interface, ExternalSecureInvasiveDebugEnabled returns the state of the */ /* (DBGEN AND SPIDEN) signal. */ /* CoreSight allows asserting SPIDEN without also asserting DBGEN, but this is not recommended. */ @@ -76,15 +73,13 @@ function boolean ExternalSecureInvasiveDebugEnabled() = /** FUNCTION:shared/debug/halting/DCPSInstruction */ -function unit DCPSInstruction(target_el : bits(2)) = -{ +function DCPSInstruction(target_el : bits(2)) -> unit = { not_implemented("DCPSInstruction"); } /** FUNCTION:shared/debug/halting/DRPSInstruction */ -function unit DRPSInstruction() = -{ +function DRPSInstruction() -> unit = { not_implemented("DRPSInstruction"); } @@ -104,22 +99,19 @@ let DebugHalt_Step_NoSyndrome = 0b111011 /** FUNCTION:shared/debug/halting/Halt */ -function unit Halt(reason : bits(6)) = -{ +function Halt(reason : bits(6)) -> unit= { not_implemented("Halt"); } /** FUNCTION:shared/debug/halting/Halted */ -function boolean Halted() = -{ - ~(EDSCR.STATUS == 0b000001 | EDSCR.STATUS == 0b000010); /* Halted */ +function Halted() -> boolean = { + ~(EDSCR.STATUS() == 0b000001 | EDSCR.STATUS() == 0b000010); /* Halted */ } /** FUNCTION:shared/debug/halting/HaltingAllowed */ -function boolean HaltingAllowed() = -{ +function HaltingAllowed() -> boolean = { if Halted() | DoubleLockStatus() then false else if IsSecure() then @@ -130,8 +122,7 @@ function boolean HaltingAllowed() = /** FUNCTION:shared/exceptions/traps/ReservedValue */ -function unit ReservedValue() = -{ +function ReservedValue() -> unit = { /* ARM: uncomment when adding aarch32 if UsingAArch32() && !AArch32.GeneralExceptionsToAArch64() then AArch32.TakeUndefInstrException() @@ -141,8 +132,7 @@ function unit ReservedValue() = /** FUNCTION:shared/exceptions/traps/UnallocatedEncoding */ -function unit UnallocatedEncoding() = -{ +function UnallocatedEncoding() -> unit = { /* If the unallocated encoding is an AArch32 CP10 or CP11 instruction, FPEXC.DEX must be written */ /* to zero. This is omitted from this code. */ /* ARM: uncomment whenimplementing aarch32 @@ -154,30 +144,40 @@ function unit UnallocatedEncoding() = /** FUNCTION:shared/functions/aborts/IsFault */ -function boolean IsFault(addrdesc : AddressDescriptor) = -{ +function IsFault(addrdesc : AddressDescriptor) -> boolean= { (addrdesc.fault).faulttype != Fault_None; } /** FUNCTION:shared/functions/common/ASR */ -val ASR : forall 'N, 'N >= 0. (bits('N), uinteger) -> bits('N) -function ASR (x, shift) = -{ +/* original: */ + +/* val ASR : forall 'N, 'N >= 0. (bits('N), uinteger) -> bits('N) */ +/* function ASR (x, shift) = */ +/* { */ +/* /\*assert shift >= 0;*\/ */ +/* result : bits('N) = 0; */ +/* if shift == 0 then */ +/* result = x */ +/* else */ +/* let (result', _) = ASR_C (x, shift) in { result = result' }; */ +/* result; */ +/* } */ + +/* CP: replacing this with the slightly simpler one below */ + +val ASR : forall 'N, 'N >= 1. (bits('N), uinteger) -> bits('N) +function ASR (x, shift) = { /*assert shift >= 0;*/ - result : bits('N) = 0; - if shift == 0 then - result = x - else - let (result', _) = ASR_C (x, shift) in { result = result' }; - result; + if shift == 0 then x + else let (result', _) = ASR_C (x, shift) in result' } + /** FUNCTION:shared/functions/common/ASR_C */ -val ASR_C : forall 'N 'S, 'N >= 0 & 'S >= 1. (bits('N), atom('S)) -> (bits('N), bit) -function ASR_C (x, shift) = -{ +val ASR_C : forall 'N 'S, 'N >= 1 & 'S >= 1. (bits('N), atom('S)) -> (bits('N), bit) +function ASR_C (x, shift) = { /*assert shift > 0;*/ extended_x : bits('S+'N) = SignExtend(x); result : bits('N) = extended_x[(shift + length(x) - 1)..shift]; @@ -185,59 +185,75 @@ function ASR_C (x, shift) = (result, carry_out); } +/* SignExtend : */ + +/* 'S+'N > 'N & 'N >= 1. (atom('S+'N),bits('N)) -> bits('S+'N) */ + + /** FUNCTION:integer Align(integer x, integer y) */ -function uinteger Align'(x : uinteger, y : uinteger) = +function Align'(x : uinteger, y : uinteger) -> uinteger = y * (quot (x,y)) /** FUNCTION:bits(N) Align(bits(N) x, integer y) */ val Align : forall 'N, 'N >= 0. (bits('N), uinteger) -> bits('N) function Align (x, y) = - Align'(UInt(x), y) : (bits('N)) + to_bits (Align'(UInt(x), y)) /** FUNCTION:integer CountLeadingSignBits(bits(N) x) */ -val CountLeadingSignBits : forall 'N, 'N >= 0. bits('N) -> range(0,'N) +val CountLeadingSignBits : forall 'N, 'N >= 2. bits('N) -> range(0,'N - 1) function CountLeadingSignBits(x) = - CountLeadingZeroBits(x[(length(x) - 1)..1] ^ x[(length(x) - 2)..0]) + CountLeadingZeroBits( (x[(length(x) - 1)..1]) ^ + (x[(length(x) - 2)..0]) ) /** FUNCTION:integer CountLeadingZeroBits(bits(N) x) */ val CountLeadingZeroBits : forall 'N, 'N >= 0. bits('N) -> range(0,'N) function CountLeadingZeroBits(x) = match HighestSetBit(x) { - None => length(x), + None() => length(x), Some(n) => length(x) - 1 - n } /** FUNCTION:bits(N) Extend(bits(M) x, integer N, boolean unsigned) */ -val Extend : forall 'N 'M, 0 <= 'M & 'M <= 'N. (implicit('N),bits('M),bit) -> bits('N) effect pure -function Extend (x, unsigned) = - if unsigned then ZeroExtend(x) else SignExtend(x) +val Extend : forall 'N 'M, 1 <= 'M & 'M < 'N. (atom('N),bits('M),bit) -> bits('N) effect pure +function Extend (n, x, _unsigned) = + if _unsigned then ZeroExtend(n,x) else SignExtend(n,x) /** FUNCTION:integer HighestSetBit(bits(N) x) */ -val HighestSetBit : forall 'N, 'N >= 0. bits('N) -> option(range(0, 'N + -1)) +/* val HighestSetBit : forall 'N, 'N >= 0. bits('N) -> option(range(0, 'N - 1)) */ +/* function HighestSetBit(x) = { */ +/* let N = (length(x)) in { */ +/* result : range(0,'N - 1) = 0; */ +/* break : bool = false; */ +/* foreach (i from (N - 1) downto 0) */ +/* if ~(break) & (x[i] == bitone) then { */ +/* result = i; */ +/* break = true; */ +/* }; */ + +/* if break then Some(result) else None; */ +/* }} */ + +val HighestSetBit : forall 'N, 'N >= 0. bits('N) -> option(range(0, 'N - 1)) function HighestSetBit(x) = { - let N = (length(x)) in { - result : range(0, 'N + -1) = 0; - break : bool = false; + let N = length(x) in { foreach (i from (N - 1) downto 0) - if ~(break) & x[i] == 1 then { - result = i; - break = true; + if x[i] == bitone then { + return (Some(i)); }; - - if break then Some(result) else None; + None(); }} + /** FUNCTION:integer Int(bits(N) x, boolean unsigned) */ /* used to be called Int */ val _Int : forall 'N, 'N >= 0. (bits('N), boolean) -> integer -function _Int (x, unsigned) = { - result = if unsigned then UInt(x) else SInt(x); - result; +function _Int (x, _unsigned) = { + if _unsigned then UInt(x) else SInt(x) } /** FUNCTION:boolean IsZero(bits(N) x) */ @@ -369,7 +385,7 @@ function Replicate (x) = { /** FUNCTION:integer SInt(bits(N) x) */ /*function forall Nat 'N, Nat 'M, Nat 'K, 'M = 'N + -1, 'K = 2**'M. [|'K * -1:'K + -1|] SInt((bits('N)) x) =*/ -val SInt : forall 'N 'M, 'N >= 0 & 'M >= 0. bits('N) -> atom('M) +val SInt : forall 'N, 'N >= 0. bits('N) -> range(-(2 ^ ('N - 1) - 1), 2 ^ ('N - 1) - 1) function SInt(x) = { signed(x) /*let N = (length((bits('N)) 0)) in { @@ -381,20 +397,18 @@ function SInt(x) = { /** FUNCTION:bits(N) SignExtend(bits(M) x, integer N) */ -val SignExtend : forall 'N 'M, 'N >= 0 & 'M >= 1. bits('M) -> bits('N) -function SignExtend ([h]:remainder as x) = +function SignExtend forall 'N 'M, 'N > M & 'M >= 1. (_ : atom('N), ([h]@remainder as x) : bits('M)) -> bits('N) = (Replicate([h]) : bits(('N - 'M))) @ x /** FUNCTION:integer UInt(bits(N) x) */ -/* function forall Nat 'N, Nat 'M, 'M = 2**'N. [|'M + -1|] UInt((bits('N)) x) = ([|'M + -1|]) x */ -val Uint : forall 'M 'N, 'M >= 0 & 'N >= 0. bits('N) -> atom('M) -function UInt(x) = unsigned(x) +function UInt forall 'N, 'N >=0 . (x : bits('N)) -> range(0, 2 ^ 'N - 1) = + unsigned(x) /** FUNCTION:bits(N) ZeroExtend(bits(M) x, integer N) */ -val ZeroExtend : forall 'M 'N, 'M >= 0 & 'N >= 0. bits('M) -> bits('N) -function ZeroExtend (x) = (Zeros() : bits(('N + 'M * -1))) @ x +val ZeroExtend : forall 'M 'N, N >= M & 'M >= 0. (atom('N),bits('M)) -> bits('N) +function ZeroExtend (x) = (Zeros() : bits(('N - 'M))) @ x /** FUNCTION:shared/functions/common/Zeros */ diff --git a/aarch64_small/armV8_lib.h.sail b/aarch64_small/armV8_lib.h.sail index 3b15a1cd..50909c88 100644 --- a/aarch64_small/armV8_lib.h.sail +++ b/aarch64_small/armV8_lib.h.sail @@ -153,21 +153,22 @@ struct AddressDescriptor = { enum PrefetchHint = {Prefetch_READ, Prefetch_WRITE, Prefetch_EXEC} -val ASR_C : forall 'N 'S, 'N >= 0 & 'S >= 1. (bits('N),atom('S)) -> (bits('N), bit) effect pure +val ASR_C : forall 'N 'S, 'N >= 1 & 'S >= 1. (bits('N),atom('S)) -> (bits('N), bit) effect pure val LSL_C : forall 'N 'S, 'N >= 0 & 'S >= 1. (bits('N),atom('S)) -> (bits('N), bit) effect pure val LSR_C : forall 'N 'S, 'N >= 0 & 'S >= 1. (bits('N),atom('S)) -> (bits('N), bit) effect pure val ROR_C : forall 'N 'S, 'N >= 0 & ('S >= 1 | 'S <= -1). (bits('N),int('S)) -> (bits('N), bit) effect pure val IsZero : forall 'N, 'N >=0. bits('N) -> boolean effect pure val Replicate : forall 'N 'M, 'N >=0 & 'M >=0. (implicit('N),bits('M)) -> bits('N) effect pure -val SignExtend : forall 'N 'M, 'N >= 'M & 'M >= 0. (implicit('N),bits('M)) -> bits('N) effect pure +val SignExtend : forall 'N 'M, 'N > 'M & 'M >= 1. (implicit('N),bits('M)) -> bits('N) effect pure val ZeroExtend : forall 'N 'M, 'N >= 'M & 'M >= 0. (implicit('N),bits('M)) -> bits('N) effect pure val Zeros : forall 'N, 'N >=0. implicit('N) -> bits('N) effect pure val Ones : forall 'N, 'N >=0. implicit('N) -> bits('N) effect pure /* val UInt : forall Nat 'N, Nat 'M, 'M = 2**'N. bits('N) -> [|'M + -1|] effect pure */ -val UInt : forall 'N 'M, 'N >=0 & 'M >= 0. bits('N) -> atom('M) effect pure +val UInt : forall 'N, 'N >=0 . bits('N) -> range(0, 2 ^ 'N + -1) +/* val UInt : forall 'N 'M, 'N >=0 & 'M >= 0. bits('N) -> Int('M) effect pure */ /* val SInt : forall Nat 'N, Nat 'M, Nat 'K, 'M = 'N + -1, 'K = 2**'M. bits('N) -> [|'K * -1:'K + -1|] effect pure */ -val SInt : forall 'N 'M, 'N >= 0 & 'M >=0. bits('N) -> atom('M) effect pure -val HighestSetBit : forall 'N, 'N >= 0. bits('N+1) -> option(range(0,'N + -1)) effect pure +val SInt : forall 'N, 'N >= 0. bits('N) -> range(-(2 ^ ('N - 1) - 1), 2 ^ ('N - 1) - 1) effect pure +val HighestSetBit : forall 'N, 'N >= 0. bits('N) -> option(range(0,'N - 1)) effect pure val CountLeadingZeroBits : forall 'N, 'N >= 0. bits('N) -> range(0,'N) effect pure val IsSecure : unit -> boolean effect {rreg} val IsSecureBelowEL3 : unit -> boolean effect {rreg} diff --git a/aarch64_small/armV8_pstate.sail b/aarch64_small/armV8_pstate.sail index dcf35488..9bc88891 100644 --- a/aarch64_small/armV8_pstate.sail +++ b/aarch64_small/armV8_pstate.sail @@ -33,61 +33,79 @@ /*========================================================================*/ /* register alias PSTATE_N = NZCV.N /\* Negative condition flag *\/ */ -function set_PSTATE_N(v) = {NZCV.N = v} -function get_PSTATE_N() = NZCV.N +val set_PSTATE_N : bits(1) -> unit effect{wreg} +val get_PSTATE_N : unit -> bits(1) effect{rreg} +function set_PSTATE_N(v) = {NZCV->N() = v} +function get_PSTATE_N() = NZCV.N() overload PSTATE_N = {set_PSTATE_N, get_PSTATE_N} /* register alias PSTATE_Z = NZCV.Z /\* Zero condition flag *\/ */ -function set_PSTATE_Z(v) = {NZCV.Z = v} -function get_PSTATE_Z() = NZCV.Z +val set_PSTATE_Z : bits(1) -> unit effect{wreg} +val get_PSTATE_Z : unit -> bits(1) effect{rreg} +function set_PSTATE_Z(v) = {NZCV->Z() = v} +function get_PSTATE_Z() = NZCV.Z() overload PSTATE_Z = {set_PSTATE_Z, get_PSTATE_Z} /* register alias PSTATE_C = NZCV.C /\* Carry condition flag *\/ */ -function set_PSTATE_C(v) = {NZCV.C = v} -function get_PSTATE_C() = NZCV.C +val set_PSTATE_C : bits(1) -> unit effect{wreg} +val get_PSTATE_C : unit -> bits(1) effect{rreg} +function set_PSTATE_C(v) = {NZCV->C() = v} +function get_PSTATE_C() = NZCV.C() overload PSTATE_C = {set_PSTATE_C, get_PSTATE_C} /* register alias PSTATE_V = NZCV.V /\* oVerflow condition flag *\/ */ -function set_PSTATE_V(v) = {NZCV.V = v} -function get_PSTATE_V() = NZCV.V +val set_PSTATE_V : bits(1) -> unit effect{wreg} +val get_PSTATE_V : unit -> bits(1) effect{rreg} +function set_PSTATE_V(v) = {NZCV->V() = v} +function get_PSTATE_V() = NZCV.V() overload PSTATE_V = {set_PSTATE_V, get_PSTATE_V} /* register alias PSTATE_D = DAIF.D /\* Debug mask bits(AArch64 only) *\/ */ -function set_PSTATE_D(v) = {NZCV.D = v} -function get_PSTATE_D() = NZCV.D +val set_PSTATE_D : bits(1) -> unit effect{wreg} +val get_PSTATE_D : unit -> bits(1) effect{rreg} +function set_PSTATE_D(v) = {DAIF->D() = v} +function get_PSTATE_D() = DAIF.D() overload PSTATE_D = {set_PSTATE_D, get_PSTATE_D} /* register alias PSTATE_A = DAIF.A /\* Asynchronous abort mask bit *\/ */ -function set_PSTATE_A(v) = {NZCV.A = v} -function get_PSTATE_A() = NZCV.A +val set_PSTATE_A : bits(1) -> unit effect{wreg} +val get_PSTATE_A : unit -> bits(1) effect{rreg} +function set_PSTATE_A(v) = {DAIF->A() = v} +function get_PSTATE_A() = DAIF.A() overload PSTATE_A = {set_PSTATE_A, get_PSTATE_A} /* register alias PSTATE_I = DAIF.I /\* IRQ mask bit *\/ */ -function set_PSTATE_I(v) = {NZCV.I = v} -function get_PSTATE_I() = NZCV.I +val set_PSTATE_I : bits(1) -> unit effect{wreg} +val get_PSTATE_I : unit -> bits(1) effect{rreg} +function set_PSTATE_I(v) = {DAIF->I() = v} +function get_PSTATE_I() = DAIF.I() overload PSTATE_I = {set_PSTATE_I, get_PSTATE_I} /* register alias PSTATE_F = DAIF.F /\* FIQ mask bit *\/ */ -function set_PSTATE_F(v) = {NZCV.F = v} -function get_PSTATE_F() = NZCV.F +val set_PSTATE_F : bits(1) -> unit effect{wreg} +val get_PSTATE_F : unit -> bits(1) effect{rreg} +function set_PSTATE_F(v) = {DAIF->F() = v} +function get_PSTATE_F() = DAIF.F() overload PSTATE_F = {set_PSTATE_F, get_PSTATE_F} /* register alias PSTATE_SS = /* Software step bit */ */ /* register alias PSTATE_IL = /* Illegal execution state bit */ */ /* register alias PSTATE_EL = CurrentEL.EL /\* Exception Level *\/ */ -function set_PSTATE_EL(v) = {NZCV.EL = v} -function get_PSTATE_EL() = NZCV.EL +val set_PSTATE_EL : bits(2) -> unit effect{wreg} +val get_PSTATE_EL : unit -> bits(2) effect{rreg} +function set_PSTATE_EL(v) = {CurrentEL->EL() = v} +function get_PSTATE_EL() = CurrentEL.EL() overload PSTATE_EL = {set_PSTATE_EL, get_PSTATE_EL} -/* register PSTATE_nRW : bits(1) /\* not Register Width: 0=64, 1=32 *\/ */ -function set_PSTATE_nRW(v) = {NZCV.nRW = v} -function get_PSTATE_nRW() = NZCV.nRW -overload PSTATE_nRW = {set_PSTATE_nRW, get_PSTATE_nRW} +register PSTATE_nRW : bits(1) /* not Register Width: 0=64, 1=32 */ + /* register alias PSTATE_SP = SPSel.SP /\* Stack pointer select: 0=SP0, 1=SPx [AArch64 only] *\/ /\* TODO: confirm this *\/ */ -function set_PSTATE_SP(v) = {NZCV.SP = v} -function get_PSTATE_SP() = NZCV.SP +val set_PSTATE_SP : bits(1) -> unit effect{wreg} +val get_PSTATE_SP : unit -> bits(1) effect{rreg} +function set_PSTATE_SP(v) = {SPSel->SP() = v} +function get_PSTATE_SP() = SPSel.SP() overload PSTATE_SP = {set_PSTATE_SP, get_PSTATE_SP} /* register alias PSTATE_Q = /* Cumulative saturation flag [AArch32 only] */ */ @@ -95,33 +113,26 @@ overload PSTATE_SP = {set_PSTATE_SP, get_PSTATE_SP} /* register alias PSTATE_IT = /* If-then bits, RES0 in CPSR [AArch32 only] */ */ /* register alias PSTATE_J = /* J bit, RES0 in CPSR [AArch32 only, RES0 in ARMv8] */ */ /* register alias PSTATE_T = /* T32 bit, RES0 in CPSR [AArch32 only] */ */ -/* register PSTATE_E : bits(1) /\* Endianness bits(AArch32 only) *\/ */ -function set_PSTATE_E(v) = {NZCV.E = v} -function get_PSTATE_E() = NZCV.E -overload PSTATE_E = {set_PSTATE_E, get_PSTATE_E} +register PSTATE_E : bits(1) /* Endianness bits(AArch32 only) */ -/* register PSTATE_M : bits(5) /\* Mode field [AArch32 only] *\/ */ -function set_PSTATE_M(v) = {NZCV.M = v} -function get_PSTATE_M() = NZCV.M -overload PSTATE_M = {set_PSTATE_M, get_PSTATE_M} +register PSTATE_M : bits(5) /* Mode field [AArch32 only] */ /* this is a convenient way to do "PSTATE.<N,Z,C,V> = nzcv;" */ val wPSTATE_NZCV : (unit, bits(4)) -> unit effect {wreg} -function wPSTATE_NZCV((), [n,z,c,v]) = -{ - PSTATE_N = n; - PSTATE_Z = z; - PSTATE_C = c; - PSTATE_V = v; +function wPSTATE_NZCV((), [n,z,c,v]) = { + PSTATE_N() = [n]; + PSTATE_Z() = [z]; + PSTATE_C() = [c]; + PSTATE_V() = [v]; } /* this is a convenient way to do "PSTATE.<D,A,I,F> = daif;" */ val wPSTATE_DAIF : (unit, bits(4)) -> unit effect {wreg} function wPSTATE_DAIF((), [d,a,i,f]) = { - PSTATE_D = d; - PSTATE_A = a; - PSTATE_I = i; - PSTATE_F = f; + PSTATE_D() = [d]; + PSTATE_A() = [a]; + PSTATE_I() = [i]; + PSTATE_F() = [f]; } diff --git a/aarch64_small/prelude.sail b/aarch64_small/prelude.sail index 75fdc129..e16e0a98 100644 --- a/aarch64_small/prelude.sail +++ b/aarch64_small/prelude.sail @@ -1,18 +1,30 @@ -default Order dec +/* default Order dec */ -union option ('a : Type) = {None : unit, Some : 'a} +val "reg_deref" : forall ('a : Type). register('a) -> 'a effect {rreg} +/* sneaky deref with no effect necessary for bitfield writes */ +val _reg_deref = "reg_deref" : forall ('a : Type). register('a) -> 'a + +/* this is here because if we don't have it before including vector_dec + we get infinite loops caused by interaction with bool/bit casts */ +/* val eq_bit2 = "eq_bit" : (bit, bit) -> bool */ +/* overload operator == = {eq_bit2} */ + + + +$include <smt.sail> +$include <flow.sail> +$include <arith.sail> +$include <option.sail> +$include <vector_dec.sail> -type bits ('n : Int) = vector('n, dec, bit) infix 7 >> infix 7 << +infix 7 ^^ val operator >> = "shift_bits_right" : forall 'n 'm. (bits('n), bits('m)) -> bits('n) val operator << = "shift_bits_left" : forall 'n 'm. (bits('n), bits('m)) -> bits('n) - -infix 7 ^^ - val replicate_bits = "replicate_bits" : forall 'n 'm. (bits('n), atom('m)) -> bits('n * 'm) val operator ^^ = "replicate_bits" : forall 'n 'm. (bits('n), atom('m)) -> bits('n * 'm) @@ -33,4 +45,73 @@ function operator <_s (x, y) = signed(x) < signed(y) function operator >=_s (x, y) = signed(x) >= signed(y) function operator <_u (x, y) = unsigned(x) < unsigned(y) function operator >=_u (x, y) = unsigned(x) >= unsigned(y) -function operator <=_u (x, y) = unsigned(x) <= unsigned(y)
\ No newline at end of file +function operator <=_u (x, y) = unsigned(x) <= unsigned(y) + +val pow2_atom = "pow2" : forall 'n. atom('n) -> atom(2 ^ 'n) +val pow2_int = "pow2" : int -> int + +overload pow2 = {pow2_atom, pow2_int} + + +val cast cast_bool_bit : bool -> bit +function cast_bool_bit(b) = + match b { + true => bitzero, + false => bitone + } + +val cast cast_bit_bool : bit -> bool +function cast_bit_bool (b) = + match b { + bitzero => false, + bitone => true + } + + + + +val and_bits = {c:"and_bits", _: "and_vec"} : forall 'n. (bits('n), bits('n)) -> bits('n) + +overload operator & = {and_bool, and_bits} + + +val not_vec = {c:"not_bits", _:"not_vec"} : forall 'n. bits('n) -> bits('n) + +overload ~ = {not_bool, not_vec} + +val eq_anything = {ocaml: "(fun (x, y) -> x = y)", lem: "eq", coq: "generic_eq", _:"eq_anything"} : forall ('a : Type). ('a, 'a) -> bool +overload operator == = {eq_anything} + + +val neq_vec = {lem: "neq"} : forall 'n. (bits('n), bits('n)) -> bool +function neq_vec (x, y) = not_bool(eq_bits(x, y)) + +val neq_anything = {lem: "neq", coq: "generic_neq"} : forall ('a : Type). ('a, 'a) -> bool +function neq_anything (x, y) = not_bool(x == y) + +overload operator != = {neq_atom, neq_int, neq_vec, neq_anything} + + +/* val reg_index : reg_size -> reg_index */ +/* function reg_index x = unsigned(x) */ + + +val quotient_nat = {ocaml: "quotient", lem: "integerDiv"} : (nat, nat) -> nat +val quotient = {ocaml: "quotient", lem: "integerDiv"} : (int, int) -> int +overload quot = {quotient_nat, quotient} + + +val __raw_GetSlice_int = "get_slice_int" : forall 'w, 'w >= 0. (atom('w), int, int) -> bits('w) + +val __GetSlice_int : forall 'n, 'n >= 0. (atom('n), int, int) -> bits('n) +function __GetSlice_int (n, m, o) = __raw_GetSlice_int(n, m, o) + +val to_bits : forall 'l, 'l >= 0.(implicit('l), int) -> bits('l) +function to_bits (l, n) = __raw_GetSlice_int(l, n, 0) + + +val xor_vec = {c: "xor_bits", _: "xor_vec"} : forall 'n. (bits('n), bits('n)) -> bits('n) + +val int_power = {ocaml: "int_power", lem: "pow", coq: "pow", c: "pow_int"} : (int, int) -> int + +overload operator ^ = {xor_vec, int_power, concat_str} diff --git a/doc/Makefile b/doc/Makefile index 981463ca..7afebdf2 100644 --- a/doc/Makefile +++ b/doc/Makefile @@ -85,5 +85,5 @@ clean: -rm -f code_riscv.tex -rm -rf code_myreplicatebits/ -rm -f code_myreplicatebits.tex - -rm *.aux - -rm *.log + -rm -f *.aux + -rm -f *.log diff --git a/doc/examples/my_replicate_bits.sail b/doc/examples/my_replicate_bits.sail index 8c3c9458..9334163b 100644 --- a/doc/examples/my_replicate_bits.sail +++ b/doc/examples/my_replicate_bits.sail @@ -44,9 +44,9 @@ function my_replicate_bits_2(n, xs) = { ys } -val cast extz : forall 'n 'm, 'm >= 'n. bits('n) -> bits('m) +val cast extz : forall 'n 'm, 'm >= 'n. (implicit('m), bits('n)) -> bits('m) -function extz(xs) = zero_extend(xs, 'm) +function extz(m, xs) = zero_extend(xs, m) val my_replicate_bits_3 : forall 'n 'm, 'm >= 1 & 'n >= 1. (int('n), bits('m)) -> bits('n * 'm) diff --git a/doc/riscv.tex b/doc/riscv.tex index 586efdf4..ee0c07e1 100644 --- a/doc/riscv.tex +++ b/doc/riscv.tex @@ -61,7 +61,7 @@ register Xs : vector(32, dec, xlen_t) \sailval{wX} \sailfn{wX} -\sailoverloadHHX +\sailoverloadIIX We also give a function \ll{MEMr} for reading memory, this function just points at a builtin we have defined elsewhere. Note that @@ -138,4 +138,4 @@ end execute The actual code for this example, as well as our more complete \riscv\ specification can be found on our github at -\anonymise{\url{https://github.com/rems-project/sail/blob/sail2/riscv/riscv_duopod.sail}}. +\anonymise{\url{https://github.com/rems-project/sail-riscv/blob/master/model/riscv_duopod.sail}}. diff --git a/doc/tutorial.tex b/doc/tutorial.tex index dcac1c13..9bf47f5b 100644 --- a/doc/tutorial.tex +++ b/doc/tutorial.tex @@ -372,7 +372,7 @@ Sail allows numerous ways to match on bitvectors, for example: \begin{lstlisting} match v { 0xFF => print("hex match"), - 0x0000_0001 => print("binary match"), + 0b0000_0001 => print("binary match"), 0xF @ v : bits(4) => print("vector concatenation pattern"), 0xF @ [bitone, _, b1, b0] => print("vector pattern"), _ : bits(4) @ v : bits(4) => print("annotated wildcard pattern") @@ -499,7 +499,7 @@ written \end{lstlisting} but it would not have type-checked. The reason for this is if a mutable variable is declared without a type, Sail will try to infer -the most specific type from the left hand side of the +the most specific type from the right hand side of the expression. However, in this case Sail will infer the type as \ll{int(3)} and will therefore complain when we try to reassign it to \ll{2}, as the type \ll{int(2)} is not a subtype of \ll{int(3)}. We @@ -629,7 +629,7 @@ implicitly dereference registers if that semantics is desired for a specific specification that makes heavy use of register references, like so: \begin{lstlisting} -val cast auto_reg_deref = "reg_deref" : forall ('a : Type). register('a) -> a effect {rreg} +val cast auto_reg_deref = "reg_deref" : forall ('a : Type). register('a) -> 'a effect {rreg} \end{lstlisting} @@ -687,7 +687,7 @@ the names of its fields. \subsubsection{Unions} \label{sec:union} -As an example, the \ll{maybe} type \'{a} la Haskell could be defined +As an example, the \ll{maybe} type \`{a} la Haskell could be defined in Sail as follows: \begin{lstlisting} union maybe ('a : Type) = { diff --git a/language/bytecode.ott b/language/bytecode.ott index d2580e8c..cc329e02 100644 --- a/language/bytecode.ott +++ b/language/bytecode.ott @@ -66,7 +66,7 @@ fragment :: 'F_' ::= ctyp :: 'CT_' ::= {{ com C type }} - | mpz_t :: :: int + | mpz_t :: :: lint % Arbitrary precision GMP integer, mpz_t in C. | bv_t ( bool ) :: :: lbits % Variable length bitvector - flag represents direction, true - dec or false - inc @@ -75,7 +75,7 @@ ctyp :: 'CT_' ::= | 'uint64_t' ( nat , bool ) :: :: fbits % Fixed length bitvector that fits within a 64-bit word. - int % represents length, and flag is the same as CT_bv. - | 'int64_t' :: :: int64 + | 'int64_t' nat :: :: fint % Used for (signed) integers that fit within 64-bits. | unit_t :: :: unit % unit is a value in sail, so we represent it as a one element type diff --git a/language/sail.ott b/language/sail.ott index cc97973c..b3df66bb 100644 --- a/language/sail.ott +++ b/language/sail.ott @@ -773,7 +773,7 @@ rec_opt :: 'Rec_' ::= effect_opt :: 'Effect_opt_' ::= {{ com optional effect annotation for functions }} {{ aux _ l }} - | :: :: pure {{ com sugar for empty effect set }} + | :: :: none {{ com no effect annotation }} | effectkw effect :: :: effect % Generate a pexp, but from slightly different syntax (= rather than ->) @@ -1460,7 +1460,8 @@ void get_time_ns(sail_int *rop, const unit u) // ARM specific optimisations -void arm_align(lbits *rop, const lbits x_bv, const sail_int y_mpz) { +void arm_align(lbits *rop, const lbits x_bv, const sail_int y_mpz) +{ uint64_t x = mpz_get_ui(*x_bv.bits); uint64_t y = mpz_get_ui(y_mpz); uint64_t z = y * (x / y); @@ -1468,3 +1469,14 @@ void arm_align(lbits *rop, const lbits x_bv, const sail_int y_mpz) { mpz_set_ui(*rop->bits, safe_rshift(UINT64_MAX, 64l - (n - 1)) & z); rop->len = n; } + +// Monomorphisation +void make_the_value(sail_int *rop, const sail_int op) +{ + mpz_set(*rop, op); +} + +void size_itself_int(sail_int *rop, const sail_int op) +{ + mpz_set(*rop, op); +} @@ -153,6 +153,9 @@ SAIL_INT_FUNCTION(pow_int, sail_int, const sail_int, const sail_int); SAIL_INT_FUNCTION(pow2, sail_int, const sail_int); +void make_the_value(sail_int *, const sail_int); +void size_itself_int(sail_int *, const sail_int); + /* ***** Sail bitvectors ***** */ typedef uint64_t fbits; Binary files differ@@ -1,6 +1,6 @@ opam-version: "1.2" name: "sail" -version: "0.7.1" +version: "0.8" maintainer: "Sail Devs <cl-sail-dev@lists.cam.ac.uk>" authors: [ "Alasdair Armstrong" @@ -36,5 +36,6 @@ depends: [ "conf-gmp" "conf-zlib" "base64" + "yojson" ] available: [ocaml-version >= "4.06.0"] diff --git a/src/Makefile b/src/Makefile index e29a1ef0..abf49423 100644 --- a/src/Makefile +++ b/src/Makefile @@ -96,7 +96,7 @@ else echo let dir=\"$(SHARE_DIR)\" >> manifest.ml echo let commit=\"opam\" >> manifest.ml echo let branch=\"sail2\" >> manifest.ml - echo let version=\"0.8\" >> manifest.ml + echo let version=\"$(shell grep '^version:' ../opam | grep -o -E '"[^"]+"')\" >> manifest.ml endif sail: ast.ml bytecode.ml manifest.ml @@ -436,6 +436,8 @@ and pp_aval = function let ae_lit lit typ = AE_val (AV_lit (lit, typ)) +let is_dead_aexp (AE_aux (_, env, _)) = prove __POS__ env nc_false + (** GLOBAL: gensym_counter is used to generate fresh identifiers where needed. It should be safe to reset between top level definitions. **) diff --git a/src/anf.mli b/src/anf.mli index 5e162b7c..6b9c9b51 100644 --- a/src/anf.mli +++ b/src/anf.mli @@ -112,6 +112,8 @@ val apat_globals : 'a apat -> (id * 'a) list val apat_types : 'a apat -> 'a Bindings.t +val is_dead_aexp : 'a aexp -> bool + (* Compiling to ANF expressions *) val anf_pat : ?global:bool -> tannot pat -> typ apat diff --git a/src/ast_util.ml b/src/ast_util.ml index 04b76a61..548cd15e 100644 --- a/src/ast_util.ml +++ b/src/ast_util.ml @@ -116,7 +116,7 @@ let mk_qi_kopt kopt = QI_aux (QI_id kopt, Parse_ast.Unknown) let mk_fundef funcls = let tannot_opt = Typ_annot_opt_aux (Typ_annot_opt_none, Parse_ast.Unknown) in - let effect_opt = Effect_opt_aux (Effect_opt_pure, Parse_ast.Unknown) in + let effect_opt = Effect_opt_aux (Effect_opt_none, Parse_ast.Unknown) in let rec_opt = Rec_aux (Rec_nonrec, Parse_ast.Unknown) in DEF_fundef (FD_aux (FD_function (rec_opt, tannot_opt, effect_opt, funcls), no_annot)) @@ -129,7 +129,7 @@ let mk_val_spec vs_aux = let kopt_kid (KOpt_aux (KOpt_kind (_, kid), _)) = kid let kopt_kind (KOpt_aux (KOpt_kind (k, _), _)) = k -let is_nat_kopt = function +let is_int_kopt = function | KOpt_aux (KOpt_kind (K_aux (K_int, _), _), _) -> true | _ -> false @@ -417,11 +417,16 @@ let nc_lteq n1 n2 = NC_aux (NC_bounded_le (n1, n2), Parse_ast.Unknown) let nc_gteq n1 n2 = NC_aux (NC_bounded_ge (n1, n2), Parse_ast.Unknown) let nc_lt n1 n2 = nc_lteq (nsum n1 (nint 1)) n2 let nc_gt n1 n2 = nc_gteq n1 (nsum n2 (nint 1)) -let nc_or nc1 nc2 = mk_nc (NC_or (nc1, nc2)) let nc_var kid = mk_nc (NC_var kid) let nc_true = mk_nc NC_true let nc_false = mk_nc NC_false +let nc_or nc1 nc2 = + match nc1, nc2 with + | _, NC_aux (NC_false, _) -> nc1 + | NC_aux (NC_false, _), _ -> nc2 + | _, _ -> mk_nc (NC_or (nc1, nc2)) + let nc_and nc1 nc2 = match nc1, nc2 with | _, NC_aux (NC_true, _) -> nc1 @@ -439,7 +444,7 @@ let arg_kopt (KOpt_aux (KOpt_kind (K_aux (k, _), v), l)) = | K_order -> arg_order (Ord_aux (Ord_var v, l)) | K_bool -> arg_bool (nc_var v) | K_type -> arg_typ (mk_typ (Typ_var v)) - + let nc_not nc = mk_nc (NC_app (mk_id "not", [arg_bool nc])) let mk_typschm typq typ = TypSchm_aux (TypSchm_ts (typq, typ), Parse_ast.Unknown) @@ -1355,6 +1360,7 @@ and undefined_of_typ_args mwords l annot (A_aux (typ_arg_aux, _) as typ_arg) = match typ_arg_aux with | A_nexp n -> [E_aux (E_sizeof n, (l, annot (atom_typ n)))] | A_typ typ -> [undefined_of_typ mwords l annot typ] + | A_bool nc -> [E_aux (E_constraint nc, (l, annot (atom_bool_typ nc)))] | A_order _ -> [] let destruct_pexp (Pat_aux (pexp,ann)) = @@ -1870,6 +1876,12 @@ let rec find_annot_exp sl (E_aux (aux, (l, annot)) as exp) = option_chain (find_annot_lexp sl lexp) (find_annot_exp sl exp) | E_var (lexp, exp1, exp2) -> option_chain (find_annot_lexp sl lexp) (option_mapm (find_annot_exp sl) [exp1; exp2]) + | E_if (cond_exp, then_exp, else_exp) -> + option_mapm (find_annot_exp sl) [cond_exp; then_exp; else_exp] + | E_case (exp, cases) | E_try (exp, cases) -> + option_chain (find_annot_exp sl exp) (option_mapm (find_annot_pexp sl) cases) + | E_return exp | E_cast (_, exp) -> + find_annot_exp sl exp | _ -> None in match result with @@ -1896,6 +1908,8 @@ and find_annot_lexp sl (LEXP_aux (aux, (l, annot))) = and find_annot_pat sl (P_aux (aux, (l, annot))) = if not (subloc sl l) then None else let result = match aux with + | P_vector_concat pats -> + option_mapm (find_annot_pat sl) pats | _ -> None in match result with @@ -1906,32 +1920,43 @@ and find_annot_pexp sl (Pat_aux (aux, (l, annot))) = if not (subloc sl l) then None else match aux with | Pat_exp (pat, exp) -> - find_annot_exp sl exp + option_chain (find_annot_pat sl pat) (find_annot_exp sl exp) | Pat_when (pat, guard, exp) -> - None + option_chain (find_annot_pat sl pat) (option_mapm (find_annot_exp sl) [guard; exp]) let find_annot_funcl sl (FCL_aux (FCL_Funcl (id, pexp), (l, annot))) = - if not (subloc sl l) then - None - else + if not (subloc sl l) then None else match find_annot_pexp sl pexp with | None -> Some (l, annot) | result -> result let find_annot_fundef sl (FD_aux (FD_function (_, _, _, funcls), (l, annot))) = - if not (subloc sl l) then - None - else + if not (subloc sl l) then None else match option_mapm (find_annot_funcl sl) funcls with | None -> Some (l, annot) | result -> result +let find_annot_scattered sl (SD_aux (aux, (l, annot))) = + if not (subloc sl l) then None else + let result = match aux with + | SD_funcl fcl -> find_annot_funcl sl fcl + | _ -> None + in + match result with + | None -> Some (l, annot) + | _ -> result + let rec find_annot_defs sl = function | DEF_fundef fdef :: defs -> begin match find_annot_fundef sl fdef with | None -> find_annot_defs sl defs | result -> result end + | DEF_scattered sdef :: defs -> + begin match find_annot_scattered sl sdef with + | None -> find_annot_defs sl defs + | result -> result + end | _ :: defs -> find_annot_defs sl defs | [] -> None diff --git a/src/ast_util.mli b/src/ast_util.mli index a2466326..ae63aca7 100644 --- a/src/ast_util.mli +++ b/src/ast_util.mli @@ -109,7 +109,7 @@ val dec_ord : order (* Utilites for working with kinded_ids *) val kopt_kid : kinded_id -> kid val kopt_kind : kinded_id -> kind -val is_nat_kopt : kinded_id -> bool +val is_int_kopt : kinded_id -> bool val is_order_kopt : kinded_id -> bool val is_typ_kopt : kinded_id -> bool val is_bool_kopt : kinded_id -> bool diff --git a/src/bytecode_util.ml b/src/bytecode_util.ml index 3ced48b6..489bcc64 100644 --- a/src/bytecode_util.ml +++ b/src/bytecode_util.ml @@ -252,14 +252,14 @@ and string_of_fragment' ?zencode:(zencode=true) f = (* String representation of ctyps here is only for debugging and intermediate language pretty-printer. *) and string_of_ctyp = function - | CT_int -> "int" + | CT_lint -> "int" | CT_lbits true -> "lbits(dec)" | CT_lbits false -> "lbits(inc)" | CT_fbits (n, true) -> "fbits(" ^ string_of_int n ^ ", dec)" | CT_fbits (n, false) -> "fbits(" ^ string_of_int n ^ ", int)" | CT_sbits true -> "sbits(dec)" | CT_sbits false -> "sbits(inc)" - | CT_int64 -> "int64" + | CT_fint n -> "int(" ^ string_of_int n ^ ")" | CT_bit -> "bit" | CT_unit -> "unit" | CT_bool -> "bool" @@ -276,14 +276,14 @@ and string_of_ctyp = function (** This function is like string_of_ctyp, but recursively prints all constructors in variants and structs. Used for debug output. *) and full_string_of_ctyp = function - | CT_int -> "int" + | CT_lint -> "int" | CT_lbits true -> "lbits(dec)" | CT_lbits false -> "lbits(inc)" | CT_fbits (n, true) -> "fbits(" ^ string_of_int n ^ ", dec)" | CT_fbits (n, false) -> "fbits(" ^ string_of_int n ^ ", int)" | CT_sbits true -> "sbits(dec)" | CT_sbits false -> "sbits(inc)" - | CT_int64 -> "int64" + | CT_fint n -> "int(" ^ string_of_int n ^ ")" | CT_bit -> "bit" | CT_unit -> "unit" | CT_bool -> "bool" @@ -303,7 +303,7 @@ and full_string_of_ctyp = function | CT_poly -> "*" let rec map_ctyp f = function - | (CT_int | CT_int64 | CT_lbits _ | CT_fbits _ | CT_sbits _ + | (CT_lint | CT_fint _ | CT_lbits _ | CT_fbits _ | CT_sbits _ | CT_bit | CT_unit | CT_bool | CT_real | CT_string | CT_poly | CT_enum _) as ctyp -> f ctyp | CT_tup ctyps -> f (CT_tup (List.map (map_ctyp f) ctyps)) | CT_ref ctyp -> f (CT_ref (map_ctyp f ctyp)) @@ -314,12 +314,12 @@ let rec map_ctyp f = function let rec ctyp_equal ctyp1 ctyp2 = match ctyp1, ctyp2 with - | CT_int, CT_int -> true + | CT_lint, CT_lint -> true | CT_lbits d1, CT_lbits d2 -> d1 = d2 | CT_sbits d1, CT_sbits d2 -> d1 = d2 | CT_fbits (m1, d1), CT_fbits (m2, d2) -> m1 = m2 && d1 = d2 | CT_bit, CT_bit -> true - | CT_int64, CT_int64 -> true + | CT_fint n, CT_fint m -> n = m | CT_unit, CT_unit -> true | CT_bool, CT_bool -> true | CT_struct (id1, _), CT_struct (id2, _) -> Id.compare id1 id2 = 0 @@ -353,11 +353,11 @@ let rec ctyp_unify ctyp1 ctyp2 = | _, _ -> raise (Invalid_argument "ctyp_unify") let rec ctyp_suprema = function - | CT_int -> CT_int + | CT_lint -> CT_lint | CT_lbits d -> CT_lbits d | CT_fbits (_, d) -> CT_lbits d | CT_sbits d -> CT_lbits d - | CT_int64 -> CT_int + | CT_fint _ -> CT_lint | CT_unit -> CT_unit | CT_bool -> CT_bool | CT_real -> CT_real @@ -382,7 +382,7 @@ let rec ctyp_ids = function IdSet.add id (List.fold_left (fun ids (_, ctyp) -> IdSet.union (ctyp_ids ctyp) ids) IdSet.empty ctors) | CT_tup ctyps -> List.fold_left (fun ids ctyp -> IdSet.union (ctyp_ids ctyp) ids) IdSet.empty ctyps | CT_vector (_, ctyp) | CT_list ctyp | CT_ref ctyp -> ctyp_ids ctyp - | CT_int | CT_int64 | CT_lbits _ | CT_fbits _ | CT_sbits _ | CT_unit + | CT_lint | CT_fint _ | CT_lbits _ | CT_fbits _ | CT_sbits _ | CT_unit | CT_bool | CT_real | CT_bit | CT_string | CT_poly -> IdSet.empty let rec unpoly = function @@ -394,7 +394,7 @@ let rec unpoly = function | f -> f let rec is_polymorphic = function - | CT_int | CT_int64 | CT_lbits _ | CT_fbits _ | CT_sbits _ | CT_bit | CT_unit | CT_bool | CT_real | CT_string -> false + | CT_lint | CT_fint _ | CT_lbits _ | CT_fbits _ | CT_sbits _ | CT_bit | CT_unit | CT_bool | CT_real | CT_string -> false | CT_tup ctyps -> List.exists is_polymorphic ctyps | CT_enum _ -> false | CT_struct (_, ctors) | CT_variant (_, ctors) -> List.exists (fun (_, ctyp) -> is_polymorphic ctyp) ctors diff --git a/src/c_backend.ml b/src/c_backend.ml index a1050972..aff2d49e 100644 --- a/src/c_backend.ml +++ b/src/c_backend.ml @@ -110,7 +110,7 @@ type ctx = letbinds : int list; recursive_functions : IdSet.t; no_raw : bool; - optimize_z3 : bool; + optimize_smt : bool; } let initial_ctx env = @@ -123,26 +123,28 @@ let initial_ctx env = letbinds = []; recursive_functions = IdSet.empty; no_raw = false; - optimize_z3 = true; + optimize_smt = true; } (** Convert a sail type into a C-type. This function can be quite - slow, because it uses ctx.local_env and Z3 to analyse the Sail + slow, because it uses ctx.local_env and SMT to analyse the Sail types and attempts to fit them into the smallest possible C - types, provided ctx.optimize_z3 is true (default) **) + types, provided ctx.optimize_smt is true (default) **) let rec ctyp_of_typ ctx typ = let Typ_aux (typ_aux, l) as typ = Env.expand_synonyms ctx.tc_env typ in match typ_aux with | Typ_id id when string_of_id id = "bit" -> CT_bit | Typ_id id when string_of_id id = "bool" -> CT_bool - | Typ_id id when string_of_id id = "int" -> CT_int - | Typ_id id when string_of_id id = "nat" -> CT_int + | Typ_id id when string_of_id id = "int" -> CT_lint + | Typ_id id when string_of_id id = "nat" -> CT_lint | Typ_id id when string_of_id id = "unit" -> CT_unit | Typ_id id when string_of_id id = "string" -> CT_string | Typ_id id when string_of_id id = "real" -> CT_real | Typ_app (id, _) when string_of_id id = "atom_bool" -> CT_bool + | Typ_app (id, args) when string_of_id id = "itself" -> + ctyp_of_typ ctx (Typ_aux (Typ_app (mk_id "atom", args), l)) | Typ_app (id, _) when string_of_id id = "range" || string_of_id id = "atom" || string_of_id id = "implicit" -> begin match destruct_range Env.empty typ with | None -> assert false (* Checked if range type in guard *) @@ -150,13 +152,13 @@ let rec ctyp_of_typ ctx typ = match nexp_simp n, nexp_simp m with | Nexp_aux (Nexp_constant n, _), Nexp_aux (Nexp_constant m, _) when Big_int.less_equal min_int64 n && Big_int.less_equal m max_int64 -> - CT_int64 - | n, m when ctx.optimize_z3 -> + CT_fint 64 + | n, m when ctx.optimize_smt -> if prove __POS__ ctx.local_env (nc_lteq (nconstant min_int64) n) && prove __POS__ ctx.local_env (nc_lteq m (nconstant max_int64)) then - CT_int64 + CT_fint 64 else - CT_int - | _ -> CT_int + CT_lint + | _ -> CT_lint end | Typ_app (id, [A_aux (A_typ typ, _)]) when string_of_id id = "list" -> @@ -173,7 +175,7 @@ let rec ctyp_of_typ ctx typ = let direction = match ord with Ord_aux (Ord_dec, _) -> true | Ord_aux (Ord_inc, _) -> false | _ -> assert false in begin match nexp_simp n with | Nexp_aux (Nexp_constant n, _) when Big_int.less_equal n (Big_int.of_int 64) -> CT_fbits (Big_int.to_int n, direction) - | n when ctx.optimize_z3 && prove __POS__ ctx.local_env (nc_lteq n (nint 64)) -> CT_sbits direction + | n when ctx.optimize_smt && prove __POS__ ctx.local_env (nc_lteq n (nint 64)) -> CT_sbits direction | _ -> CT_lbits direction end @@ -193,8 +195,8 @@ let rec ctyp_of_typ ctx typ = | Typ_tup typs -> CT_tup (List.map (ctyp_of_typ ctx) typs) - | Typ_exist _ when ctx.optimize_z3 -> - (* Use Type_check.destruct_exist when optimising with z3, to + | Typ_exist _ when ctx.optimize_smt -> + (* Use Type_check.destruct_exist when optimising with SMT, to ensure that we don't cause any type variable clashes in local_env, and that we can optimize the existential based upon it's constraints. *) @@ -212,8 +214,9 @@ let rec ctyp_of_typ ctx typ = | _ -> c_error ~loc:l ("No C type for type " ^ string_of_typ typ) let rec is_stack_ctyp ctyp = match ctyp with - | CT_fbits _ | CT_sbits _ | CT_int64 | CT_bit | CT_unit | CT_bool | CT_enum _ -> true - | CT_lbits _ | CT_int | CT_real | CT_string | CT_list _ | CT_vector _ -> false + | CT_fbits _ | CT_sbits _ | CT_bit | CT_unit | CT_bool | CT_enum _ -> true + | CT_fint n -> n <= 64 + | CT_lbits _ | CT_lint | CT_real | CT_string | CT_list _ | CT_vector _ -> false | CT_struct (_, fields) -> List.for_all (fun (_, ctyp) -> is_stack_ctyp ctyp) fields | CT_variant (_, ctors) -> false (* List.for_all (fun (_, ctyp) -> is_stack_ctyp ctyp) ctors *) (* FIXME *) | CT_tup ctyps -> List.for_all is_stack_ctyp ctyps @@ -262,7 +265,7 @@ let hex_char = let literal_to_fragment (L_aux (l_aux, _) as lit) = match l_aux with | L_num n when Big_int.less_equal min_int64 n && Big_int.less_equal n max_int64 -> - Some (F_lit (V_int n), CT_int64) + Some (F_lit (V_int n), CT_fint 64) | L_hex str when String.length str <= 16 -> let padding = 16 - String.length str in let padding = Util.list_init padding (fun _ -> Sail2_values.B0) in @@ -397,7 +400,7 @@ let rec analyze_functions ctx f (AE_aux (aexp, env, l)) = let aexp3 = analyze_functions ctx f aexp3 in let aexp4 = analyze_functions ctx f aexp4 in (* Currently we assume that loop indexes are always safe to put into an int64 *) - let ctx = { ctx with locals = Bindings.add id (Immutable, CT_int64) ctx.locals } in + let ctx = { ctx with locals = Bindings.add id (Immutable, CT_fint 64) ctx.locals } in AE_for (id, aexp1, aexp2, aexp3, order, aexp4) | AE_case (aval, cases, typ) -> @@ -542,14 +545,14 @@ let analyze_primop' ctx id args typ = match nexp_simp n, nexp_simp m with | Nexp_aux (Nexp_constant n, _), Nexp_aux (Nexp_constant m, _) when Big_int.less_equal min_int64 n && Big_int.less_equal m max_int64 -> - AE_val (AV_C_fragment (F_op (op1, "+", op2), typ, CT_int64)) + AE_val (AV_C_fragment (F_op (op1, "+", op2), typ, CT_fint 64)) | n, m when prove __POS__ ctx.local_env (nc_lteq (nconstant min_int64) n) && prove __POS__ ctx.local_env (nc_lteq m (nconstant max_int64)) -> - AE_val (AV_C_fragment (F_op (op1, "+", op2), typ, CT_int64)) + AE_val (AV_C_fragment (F_op (op1, "+", op2), typ, CT_fint 64)) | _ -> no_change end | "neg_int", [AV_C_fragment (frag, _, _)] -> - AE_val (AV_C_fragment (F_op (v_int 0, "-", frag), typ, CT_int64)) + AE_val (AV_C_fragment (F_op (v_int 0, "-", frag), typ, CT_fint 64)) | "replicate_bits", [AV_C_fragment (vec, vtyp, _); AV_C_fragment (times, _, _)] -> begin match destruct_vector ctx.tc_env typ, destruct_vector ctx.tc_env vtyp with @@ -736,15 +739,15 @@ let rec compile_aval l ctx = function | AV_lit (L_aux (L_num n, _), typ) when Big_int.less_equal min_int64 n && Big_int.less_equal n max_int64 -> let gs = gensym () in - [iinit CT_int gs (F_lit (V_int n), CT_int64)], - (F_id gs, CT_int), - [iclear CT_int gs] + [iinit CT_lint gs (F_lit (V_int n), CT_fint 64)], + (F_id gs, CT_lint), + [iclear CT_lint gs] | AV_lit (L_aux (L_num n, _), typ) -> let gs = gensym () in - [iinit CT_int gs (F_lit (V_string (Big_int.to_string n)), CT_string)], - (F_id gs, CT_int), - [iclear CT_int gs] + [iinit CT_lint gs (F_lit (V_string (Big_int.to_string n)), CT_string)], + (F_id gs, CT_lint), + [iclear CT_lint gs] | AV_lit (L_aux (L_zero, _), _) -> [], (F_lit (V_bit Sail2_values.B0), CT_bit), [] | AV_lit (L_aux (L_one, _), _) -> [], (F_lit (V_bit Sail2_values.B1), CT_bit), [] @@ -868,11 +871,11 @@ let rec compile_aval l ctx = function setup @ [iextern (CL_id (gs, vector_ctyp)) (mk_id "internal_vector_update") - [(F_id gs, vector_ctyp); (F_lit (V_int (Big_int.of_int i)), CT_int64); cval]] + [(F_id gs, vector_ctyp); (F_lit (V_int (Big_int.of_int i)), CT_fint 64); cval]] @ cleanup in [idecl vector_ctyp gs; - iextern (CL_id (gs, vector_ctyp)) (mk_id "internal_vector_init") [(F_lit (V_int (Big_int.of_int len)), CT_int64)]] + iextern (CL_id (gs, vector_ctyp)) (mk_id "internal_vector_init") [(F_lit (V_int (Big_int.of_int len)), CT_fint 64)]] @ List.concat (List.mapi aval_set (if direction then List.rev avals else avals)), (F_id gs, vector_ctyp), [iclear vector_ctyp gs] @@ -1083,7 +1086,7 @@ let rec compile_aexp ctx (AE_aux (aexp_aux, env, l)) = let compile_case (apat, guard, body) = let trivial_guard = match guard with | AE_aux (AE_val (AV_lit (L_aux (L_true, _), _)), _, _) - | AE_aux (AE_val (AV_C_fragment (F_lit (V_bool true), _, _)), _, _) -> true + | AE_aux (AE_val (AV_C_fragment (F_lit (V_bool true), _, _)), _, _) -> true | _ -> false in let case_label = label "case_" in @@ -1103,7 +1106,10 @@ let rec compile_aexp ctx (AE_aux (aexp_aux, env, l)) = @ body_setup @ [body_call (CL_id (case_return_id, ctyp))] @ body_cleanup @ destructure_cleanup @ [igoto finish_match_label] in - [iblock case_instrs; ilabel case_label] + if is_dead_aexp body then + [ilabel case_label] + else + [iblock case_instrs; ilabel case_label] in [icomment "begin match"] @ aval_setup @ [idecl ctyp case_return_id] @@ -1159,18 +1165,23 @@ let rec compile_aexp ctx (AE_aux (aexp_aux, env, l)) = [] | AE_if (aval, then_aexp, else_aexp, if_typ) -> - let if_ctyp = ctyp_of_typ ctx if_typ in - let compile_branch aexp = - let setup, call, cleanup = compile_aexp ctx aexp in - fun clexp -> setup @ [call clexp] @ cleanup - in - let setup, cval, cleanup = compile_aval l ctx aval in - setup, - (fun clexp -> iif cval - (compile_branch then_aexp clexp) - (compile_branch else_aexp clexp) - if_ctyp), - cleanup + if is_dead_aexp then_aexp then + compile_aexp ctx else_aexp + else if is_dead_aexp else_aexp then + compile_aexp ctx then_aexp + else + let if_ctyp = ctyp_of_typ ctx if_typ in + let compile_branch aexp = + let setup, call, cleanup = compile_aexp ctx aexp in + fun clexp -> setup @ [call clexp] @ cleanup + in + let setup, cval, cleanup = compile_aval l ctx aval in + setup, + (fun clexp -> iif cval + (compile_branch then_aexp clexp) + (compile_branch else_aexp clexp) + if_ctyp), + cleanup (* FIXME: AE_record_update could be AV_record_update - would reduce some copying. *) | AE_record_update (aval, fields, typ) -> @@ -1332,8 +1343,8 @@ let rec compile_aexp ctx (AE_aux (aexp_aux, env, l)) = cleanup | AE_for (loop_var, loop_from, loop_to, loop_step, Ord_aux (ord, _), body) -> - (* We assume that all loop indices are safe to put in a CT_int64. *) - let ctx = { ctx with locals = Bindings.add loop_var (Immutable, CT_int64) ctx.locals } in + (* We assume that all loop indices are safe to put in a CT_fint. *) + let ctx = { ctx with locals = Bindings.add loop_var (Immutable, CT_fint 64) ctx.locals } in let is_inc = match ord with | Ord_inc -> true @@ -1349,8 +1360,8 @@ let rec compile_aexp ctx (AE_aux (aexp_aux, env, l)) = let step_setup, step_call, step_cleanup = compile_aexp ctx loop_step in let step_gs = gensym () in let variable_init gs setup call cleanup = - [idecl CT_int64 gs; - iblock (setup @ [call (CL_id (gs, CT_int64))] @ cleanup)] + [idecl (CT_fint 64) gs; + iblock (setup @ [call (CL_id (gs, CT_fint 64))] @ cleanup)] in let loop_start_label = label "for_start_" in @@ -1361,16 +1372,16 @@ let rec compile_aexp ctx (AE_aux (aexp_aux, env, l)) = variable_init from_gs from_setup from_call from_cleanup @ variable_init to_gs to_setup to_call to_cleanup @ variable_init step_gs step_setup step_call step_cleanup - @ [iblock ([idecl CT_int64 loop_var; - icopy l (CL_id (loop_var, CT_int64)) (F_id from_gs, CT_int64); + @ [iblock ([idecl (CT_fint 64) loop_var; + icopy l (CL_id (loop_var, (CT_fint 64))) (F_id from_gs, (CT_fint 64)); idecl CT_unit body_gs; iblock ([ilabel loop_start_label] @ [ijump (F_op (F_id loop_var, (if is_inc then ">" else "<"), F_id to_gs), CT_bool) loop_end_label] @ body_setup @ [body_call (CL_id (body_gs, CT_unit))] @ body_cleanup - @ [icopy l (CL_id (loop_var, CT_int64)) - (F_op (F_id loop_var, (if is_inc then "+" else "-"), F_id step_gs), CT_int64)] + @ [icopy l (CL_id (loop_var, (CT_fint 64))) + (F_op (F_id loop_var, (if is_inc then "+" else "-"), F_id step_gs), (CT_fint 64))] @ [igoto loop_start_label]); ilabel loop_end_label])], (fun clexp -> icopy l clexp unit_fragment), @@ -1878,7 +1889,7 @@ let rec instrs_rename from_id to_id = | [] -> [] let hoist_ctyp = function - | CT_int | CT_lbits _ | CT_struct _ -> true + | CT_lint | CT_lbits _ | CT_struct _ -> true | _ -> false let hoist_counter = ref 0 @@ -2410,8 +2421,8 @@ let rec sgen_ctyp = function | CT_bool -> "bool" | CT_fbits _ -> "fbits" | CT_sbits _ -> "sbits" - | CT_int64 -> "mach_int" - | CT_int -> "sail_int" + | CT_fint _ -> "mach_int" + | CT_lint -> "sail_int" | CT_lbits _ -> "lbits" | CT_tup _ as tup -> "struct " ^ Util.zencode_string ("tuple_" ^ string_of_ctyp tup) | CT_struct (id, _) -> "struct " ^ sgen_id id @@ -2430,8 +2441,8 @@ let rec sgen_ctyp_name = function | CT_bool -> "bool" | CT_fbits _ -> "fbits" | CT_sbits _ -> "sbits" - | CT_int64 -> "mach_int" - | CT_int -> "sail_int" + | CT_fint _ -> "mach_int" + | CT_lint -> "sail_int" | CT_lbits _ -> "lbits" | CT_tup _ as tup -> Util.zencode_string ("tuple_" ^ string_of_ctyp tup) | CT_struct (id, _) -> sgen_id id @@ -2647,7 +2658,7 @@ let rec codegen_instr fid ctx (I_aux (instr, (_, l))) = match ctyp with | CT_unit -> "UNIT", [] | CT_bit -> "UINT64_C(0)", [] - | CT_int64 -> "INT64_C(0xdeadc0de)", [] + | CT_fint _ -> "INT64_C(0xdeadc0de)", [] | CT_fbits _ -> "UINT64_C(0xdeadc0de)", [] | CT_sbits _ -> "undefined_sbits()", [] | CT_bool -> "false", [] @@ -3265,7 +3276,7 @@ let rec ctyp_dependencies = function | CT_ref ctyp -> ctyp_dependencies ctyp | CT_struct (_, ctors) -> List.concat (List.map (fun (_, ctyp) -> ctyp_dependencies ctyp) ctors) | CT_variant (_, ctors) -> List.concat (List.map (fun (_, ctyp) -> ctyp_dependencies ctyp) ctors) - | CT_int | CT_int64 | CT_lbits _ | CT_fbits _ | CT_sbits _ | CT_unit | CT_bool | CT_real | CT_bit | CT_string | CT_enum _ | CT_poly -> [] + | CT_lint | CT_fint _ | CT_lbits _ | CT_fbits _ | CT_sbits _ | CT_unit | CT_bool | CT_real | CT_bit | CT_string | CT_enum _ | CT_poly -> [] let codegen_ctg ctx = function | CTG_vector (direction, ctyp) -> codegen_vector ctx (direction, ctyp) diff --git a/src/constant_fold.ml b/src/constant_fold.ml index 031493a4..2c46f38b 100644 --- a/src/constant_fold.ml +++ b/src/constant_fold.ml @@ -188,6 +188,8 @@ let rec rewrite_constant_function_calls' env ast = | E_app (id, args) when List.for_all is_constant args -> evaluate e_aux annot + | E_cast (typ, (E_aux (E_lit _, _) as lit)) -> ok (); lit + | E_field (exp, id) when is_constant exp -> evaluate e_aux annot diff --git a/src/constraint.ml b/src/constraint.ml index b7fa50c3..5402f6f7 100644 --- a/src/constraint.ml +++ b/src/constraint.ml @@ -55,6 +55,76 @@ open Util let opt_smt_verbose = ref false +type solver = { + command : string; + header : string; + footer : string; + negative_literals : bool; + uninterpret_power : bool + } + +let cvc4_solver = { + command = "cvc4 -L smtlib2 --tlimit=2000"; + header = "(set-logic QF_UFNIA)\n"; + footer = ""; + negative_literals = false; + uninterpret_power = true + } + +let mathsat_solver = { + command = "mathsat"; + header = "(set-logic QF_UFLIA)\n"; + footer = ""; + negative_literals = false; + uninterpret_power = true + } + +let z3_solver = { + command = "z3 -t:1000 -T:10"; + (* Using push and pop is much faster, I believe because + incremental mode uses a different solver. *) + header = "(push)\n"; + footer = "(pop)\n"; + negative_literals = true; + uninterpret_power = false; + } + +let yices_solver = { + command = "yices-smt2 --timeout=2"; + header = "(set-logic QF_UFLIA)\n"; + footer = ""; + negative_literals = false; + uninterpret_power = true + } + +let vampire_solver = { + (* vampire sometimes likes to ignore its time limit *) + command = "timeout -s SIGKILL 3s vampire --time_limit 2s --input_syntax smtlib2 --mode smtcomp"; + header = ""; + footer = ""; + negative_literals = false; + uninterpret_power = true + } + +let alt_ergo_solver ={ + command = "alt-ergo"; + header = ""; + footer = ""; + negative_literals = false; + uninterpret_power = true + } + +let opt_solver = ref z3_solver + +let set_solver = function + | "z3" -> opt_solver := z3_solver + | "alt-ergo" -> opt_solver := alt_ergo_solver + | "cvc4" -> opt_solver := cvc4_solver + | "mathsat" -> opt_solver := mathsat_solver + | "vampire" -> opt_solver := vampire_solver + | "yices" -> opt_solver := yices_solver + | unknown -> prerr_endline ("Unrecognised SMT solver " ^ unknown) + (* SMTLIB v2.0 format is based on S-expressions so we have a lightweight representation of those here. *) type sexpr = List of (sexpr list) | Atom of string @@ -101,6 +171,9 @@ let to_smt l vars constr = match aux with | Nexp_id id -> Atom (Util.zencode_string (string_of_id id)) | Nexp_var v -> smt_var v + | Nexp_constant c + when Big_int.less_equal c (Big_int.of_int (-1)) && not !opt_solver.negative_literals -> + sfun "-" [Atom "0"; Atom (Big_int.to_string (Big_int.abs c))] | Nexp_constant c -> Atom (Big_int.to_string c) | Nexp_app (id, nexps) -> sfun (string_of_id id) (List.map smt_nexp nexps) | Nexp_times (nexp1, nexp2) -> sfun "*" [smt_nexp nexp1; smt_nexp nexp2] @@ -108,6 +181,7 @@ let to_smt l vars constr = | Nexp_minus (nexp1, nexp2) -> sfun "-" [smt_nexp nexp1; smt_nexp nexp2] | Nexp_exp (Nexp_aux (Nexp_constant c, _)) when Big_int.greater c Big_int.zero -> Atom (Big_int.to_string (Big_int.pow_int_positive 2 (Big_int.to_int c))) + | Nexp_exp nexp when !opt_solver.uninterpret_power -> sfun "sailexp" [smt_nexp nexp] | Nexp_exp nexp -> sfun "^" [Atom "2"; smt_nexp nexp] | Nexp_neg nexp -> sfun "-" [smt_nexp nexp] in @@ -137,12 +211,13 @@ let to_smt l vars constr = let smtlib_of_constraints ?get_model:(get_model=false) l vars constr : string * (kid -> sexpr) = let variables, problem, var_map = to_smt l vars constr in - "(push)\n" + !opt_solver.header ^ variables ^ "\n" + ^ (if !opt_solver.uninterpret_power then "(declare-fun sailexp (Int) Int)\n" else "") ^ pp_sexpr (sfun "define-fun" [Atom "constraint"; List []; Atom "Bool"; problem]) ^ "\n(assert constraint)\n(check-sat)" - ^ (if get_model then "\n(get-model)" else "") - ^ "\n(pop)", + ^ (if get_model then "\n(get-model)\n" else "\n") + ^ !opt_solver.footer, var_map type smt_result = Unknown | Sat | Unsat @@ -184,12 +259,12 @@ let save_digests () = DigestMap.iter output !known_problems; close_out out_chan -let call_z3' l vars constraints : smt_result = +let call_smt' l vars constraints : smt_result = let problems = [constraints] in - let z3_file, _ = smtlib_of_constraints l vars constraints in + let smt_file, _ = smtlib_of_constraints l vars constraints in if !opt_smt_verbose then - prerr_endline (Printf.sprintf "SMTLIB2 constraints are: \n%s%!" z3_file) + prerr_endline (Printf.sprintf "SMTLIB2 constraints are: \n%s%!" smt_file) else (); let rec input_lines chan = function @@ -202,7 +277,7 @@ let call_z3' l vars constraints : smt_result = end in - let digest = Digest.string z3_file in + let digest = Digest.string smt_file in try let result = DigestMap.find digest !known_problems in result @@ -210,45 +285,49 @@ let call_z3' l vars constraints : smt_result = | Not_found -> begin let (input_file, tmp_chan) = - try Filename.open_temp_file "constraint_" ".sat" with - | Sys_error msg -> raise (Reporting.err_general l ("Could not open temp file when calling Z3: " ^ msg)) + try Filename.open_temp_file "constraint_" ".smt2" with + | Sys_error msg -> raise (Reporting.err_general l ("Could not open temp file when calling SMT: " ^ msg)) in - output_string tmp_chan z3_file; + output_string tmp_chan smt_file; close_out tmp_chan; - let z3_output = + let smt_output = try - let z3_chan = Unix.open_process_in ("z3 -t:1000 -T:10 " ^ input_file) in - let z3_output = List.combine problems (input_lines z3_chan (List.length problems)) in - let _ = Unix.close_process_in z3_chan in - z3_output + let smt_out, smt_in, smt_err = Unix.open_process_full (!opt_solver.command ^ " " ^ input_file) (Unix.environment ()) in + let smt_output = + try List.combine problems (input_lines smt_out (List.length problems)) with + | End_of_file -> List.combine problems ["unknown"] + in + let _ = Unix.close_process_full (smt_out, smt_in, smt_err) in + smt_output with - | exn -> raise (Reporting.err_general l ("Error when calling z3: " ^ Printexc.to_string exn)) + | exn -> raise (Reporting.err_general l ("Error when calling smt: " ^ Printexc.to_string exn)) in Sys.remove input_file; try - let (problem, _) = List.find (fun (_, result) -> result = "unsat") z3_output in + let (problem, _) = List.find (fun (_, result) -> result = "unsat") smt_output in known_problems := DigestMap.add digest Unsat !known_problems; Unsat with | Not_found -> - let unsolved = List.filter (fun (_, result) -> result = "unknown") z3_output in + let unsolved = List.filter (fun (_, result) -> result = "unknown") smt_output in if unsolved == [] then (known_problems := DigestMap.add digest Sat !known_problems; Sat) else (known_problems := DigestMap.add digest Unknown !known_problems; Unknown) end -let call_z3 l vars constraints = - let t = Profile.start_z3 () in - let result = call_z3' l vars constraints in - Profile.finish_z3 t; +let call_smt l vars constraints = + let t = Profile.start_smt () in + let result = call_smt' l vars constraints in + Profile.finish_smt t; result -let rec solve_z3 l vars constraints var = - let z3_file, smt_var = smtlib_of_constraints ~get_model:true l vars constraints in - let z3_var = pp_sexpr (smt_var var) in +let solve_smt l vars constraints var = + let smt_file, smt_var = smtlib_of_constraints ~get_model:true l vars constraints in + let smt_var = pp_sexpr (smt_var var) in - (* prerr_endline (Printf.sprintf "SMTLIB2 constraints are: \n%s%!" z3_file); - prerr_endline ("Solving for " ^ z3_var); *) + if !opt_smt_verbose then + prerr_endline (Printf.sprintf "SMTLIB2 constraints are (solve for %s): \n%s%!" smt_var smt_file) + else (); let rec input_all chan = try @@ -259,27 +338,45 @@ let rec solve_z3 l vars constraints var = End_of_file -> [] in - let (input_file, tmp_chan) = Filename.open_temp_file "constraint_" ".sat" in - output_string tmp_chan z3_file; + let (input_file, tmp_chan) = Filename.open_temp_file "constraint_" ".smt2" in + output_string tmp_chan smt_file; close_out tmp_chan; - let z3_output = + let smt_output = try - let z3_chan = Unix.open_process_in ("z3 -t:1000 -T:10 " ^ input_file) in - let z3_output = String.concat " " (input_all z3_chan) in - let _ = Unix.close_process_in z3_chan in - z3_output + let smt_chan = Unix.open_process_in ("z3 -t:1000 -T:10 " ^ input_file) in + let smt_output = String.concat " " (input_all smt_chan) in + let _ = Unix.close_process_in smt_chan in + smt_output with | exn -> - raise (Reporting.err_general l ("Got error when calling z3: " ^ Printexc.to_string exn)) + raise (Reporting.err_general l ("Got error when calling smt: " ^ Printexc.to_string exn)) in Sys.remove input_file; - let regexp = {|(define-fun |} ^ z3_var ^ {| () Int[ ]+\([0-9]+\))|} in + let regexp = {|(define-fun |} ^ smt_var ^ {| () Int[ ]+\([0-9]+\))|} in try - let _ = Str.search_forward (Str.regexp regexp) z3_output 0 in - let result = Big_int.of_string (Str.matched_group 1 z3_output) in - begin match call_z3 l vars (nc_and constraints (nc_neq (nconstant result) (nvar var))) with - | Unsat -> Some result - | _ -> None - end + let _ = Str.search_forward (Str.regexp regexp) smt_output 0 in + let result = Big_int.of_string (Str.matched_group 1 smt_output) in + Some result with - Not_found -> None + | Not_found -> None + +let solve_all_smt l vars constraints var = + let rec aux results = + let constraints = List.fold_left (fun ncs r -> (nc_and ncs (nc_neq (nconstant r) (nvar var)))) constraints results in + match solve_smt l vars constraints var with + | Some result -> aux (result :: results) + | None -> + match call_smt l vars constraints with + | Unsat -> Some results + | _ -> None + in + aux [] + +let solve_unique_smt l vars constraints var = + match solve_smt l vars constraints var with + | Some result -> + begin match call_smt l vars (nc_and constraints (nc_neq (nconstant result) (nvar var))) with + | Unsat -> Some result + | _ -> None + end + | None -> None diff --git a/src/constraint.mli b/src/constraint.mli index fa318c35..b5d6ff6b 100644 --- a/src/constraint.mli +++ b/src/constraint.mli @@ -54,11 +54,17 @@ open Ast_util val opt_smt_verbose : bool ref +val set_solver : string -> unit + type smt_result = Unknown | Sat | Unsat val load_digests : unit -> unit val save_digests : unit -> unit -val call_z3 : l -> kind_aux KBindings.t -> n_constraint -> smt_result +val call_smt : l -> kind_aux KBindings.t -> n_constraint -> smt_result + +val solve_smt : l -> kind_aux KBindings.t -> n_constraint -> kid -> Big_int.num option + +val solve_all_smt : l -> kind_aux KBindings.t -> n_constraint -> kid -> Big_int.num list option -val solve_z3 : l -> kind_aux KBindings.t -> n_constraint -> kid -> Big_int.num option +val solve_unique_smt : l -> kind_aux KBindings.t -> n_constraint -> kid -> Big_int.num option diff --git a/src/elf_loader.ml b/src/elf_loader.ml index c6fb0589..abe935b2 100644 --- a/src/elf_loader.ml +++ b/src/elf_loader.ml @@ -47,6 +47,10 @@ let opt_elf_threads = ref 1 let opt_elf_entry = ref Big_int.zero let opt_elf_tohost = ref Big_int.zero +(* the type of elf last loaded *) +type elf_class = ELF_Class_64 | ELF_Class_32 +let opt_elf_class = ref ELF_Class_64 (* default *) + type word8 = int let escape_char c = @@ -66,14 +70,16 @@ let break n xs = | (_ :: _ as xs) -> helper ([Lem_list.take n xs] @ acc) (Lem_list.drop n xs) in helper [] xs -let print_segment seg = - let bs = seg.Elf_interpreted_segment.elf64_segment_body in +let print_segment bs = prerr_endline "0011 2233 4455 6677 8899 aabb ccdd eeff 0123456789abcdef"; List.iter (fun bs -> prerr_endline (hex_line bs)) (break 16 (Byte_sequence.char_list_of_byte_sequence bs)) +type elf_segs = + | ELF64 of Elf_interpreted_segment.elf64_interpreted_segment list + | ELF32 of Elf_interpreted_segment.elf32_interpreted_segment list + let read name = let info = Sail_interface.populate_and_obtain_global_symbol_init_info name in - prerr_endline "Elf read:"; let (elf_file, elf_epi, symbol_map) = begin match info with @@ -87,20 +93,18 @@ let read name = (elf_file, elf_epi, symbol_map) end in - prerr_endline "\nElf segments:"; + + (* remove all the auto generated segments (they contain only 0s) *) + let prune_segments segs = + Lem_list.mapMaybe (fun (seg, prov) -> if prov = Elf_file.FromELF then Some seg else None) segs in let (segments, e_entry, e_machine) = begin match elf_epi, elf_file with - | (Sail_interface.ELF_Class_32 _, _) -> failwith "cannot handle ELF_Class_32" - | (_, Elf_file.ELF_File_32 _) -> failwith "cannot handle ELF_File_32" - | (Sail_interface.ELF_Class_64 (segments, e_entry, e_machine), Elf_file.ELF_File_64 f1) -> - (* remove all the auto generated segments (they contain only 0s) *) - let segments = - Lem_list.mapMaybe - (fun (seg, prov) -> if prov = Elf_file.FromELF then Some seg else None) - segments - in - (segments, e_entry, e_machine) + | (Sail_interface.ELF_Class_32 (segments, e_entry, e_machine), Elf_file.ELF_File_32 _) -> + (ELF32 (prune_segments segments), e_entry, e_machine) + | (Sail_interface.ELF_Class_64 (segments, e_entry, e_machine), Elf_file.ELF_File_64 _) -> + (ELF64 (prune_segments segments), e_entry, e_machine) + | (_, _) -> failwith "cannot handle ELF file" end in (segments, e_entry, symbol_map) @@ -120,24 +124,20 @@ let write_file chan paddr i byte = output_string chan (Big_int.to_string (Big_int.add paddr (Big_int.of_int i)) ^ "\n"); output_string chan (string_of_int byte ^ "\n") -let load_segment ?writer:(writer=write_sail_lib) seg = - let open Elf_interpreted_segment in - let bs = seg.elf64_segment_body in - let paddr = seg.elf64_segment_paddr in - let base = seg.elf64_segment_base in - let offset = seg.elf64_segment_offset in - let size = seg.elf64_segment_size in - let memsz = seg.elf64_segment_memsz in +let print_seg_info offset base paddr size memsz = prerr_endline "\nLoading Segment"; prerr_endline ("Segment offset: " ^ (Printf.sprintf "0x%Lx" (Big_int.to_int64 offset))); prerr_endline ("Segment base address: " ^ (Big_int.to_string base)); (* NB don't attempt to convert paddr to int64 because on MIPS it is quite likely to exceed signed - 64-bit range e.g. addresses beginning 0x9.... Really need to_uint64 or to_string_hex but lem + 64-bit range e.g. addresses beginning 0x9.... Really need to_uint64 or to_string_hex but lem doesn't have them. *) prerr_endline ("Segment physical address: " ^ (Printf.sprintf "0x%Lx" (Big_int.to_int64 paddr))); prerr_endline ("Segment size: " ^ (Printf.sprintf "0x%Lx" (Big_int.to_int64 size))); - prerr_endline ("Segment memsz: " ^ (Printf.sprintf "0x%Lx" (Big_int.to_int64 memsz))); - print_segment seg; + prerr_endline ("Segment memsz: " ^ (Printf.sprintf "0x%Lx" (Big_int.to_int64 memsz))) + +let load_segment ?writer:(writer=write_sail_lib) bs paddr base offset size memsz = + print_seg_info offset base paddr size memsz; + print_segment bs; List.iteri (writer paddr) (List.rev_map int_of_char (List.rev (Byte_sequence.char_list_of_byte_sequence bs))); write_mem_zeros (Big_int.add paddr size) (Big_int.sub memsz size) @@ -147,7 +147,32 @@ let load_elf ?writer:(writer=write_sail_lib) name = (if List.mem_assoc "tohost" symbol_map then let (_, _, tohost_addr, _, _) = List.assoc "tohost" symbol_map in opt_elf_tohost := tohost_addr); - List.iter (load_segment ~writer:writer) segments + (match segments with + | ELF64 segs -> + List.iter (fun seg -> + let open Elf_interpreted_segment in + let bs = seg.elf64_segment_body in + let paddr = seg.elf64_segment_paddr in + let base = seg.elf64_segment_base in + let offset = seg.elf64_segment_offset in + let size = seg.elf64_segment_size in + let memsz = seg.elf64_segment_memsz in + load_segment ~writer:writer bs paddr base offset size memsz) + segs; + opt_elf_class := ELF_Class_64 + | ELF32 segs -> + List.iter (fun seg -> + let open Elf_interpreted_segment in + let bs = seg.elf32_segment_body in + let paddr = seg.elf32_segment_paddr in + let base = seg.elf32_segment_base in + let offset = seg.elf32_segment_offset in + let size = seg.elf32_segment_size in + let memsz = seg.elf32_segment_memsz in + load_segment ~writer:writer bs paddr base offset size memsz) + segs; + opt_elf_class := ELF_Class_32 + ) let load_binary ?writer:(writer=write_sail_lib) addr name = let f = open_in_bin name in @@ -172,3 +197,5 @@ let load_binary ?writer:(writer=write_sail_lib) addr name = let elf_entry () = !opt_elf_entry (* Used by RISCV sail model test harness for exiting test *) let elf_tohost () = !opt_elf_tohost +(* Used to check last loaded elf class. *) +let elf_class () = !opt_elf_class diff --git a/src/initial_check.ml b/src/initial_check.ml index f728d92d..33844a72 100644 --- a/src/initial_check.ml +++ b/src/initial_check.ml @@ -595,7 +595,7 @@ let to_ast_typschm_opt ctx (P.TypSchm_opt_aux(aux,l)) : tannot_opt ctx_out = let to_ast_effects_opt (P.Effect_opt_aux(e,l)) : effect_opt = match e with - | P.Effect_opt_pure -> Effect_opt_aux(Effect_opt_pure,l) + | P.Effect_opt_none -> Effect_opt_aux(Effect_opt_none,l) | P.Effect_opt_effect(typ) -> Effect_opt_aux(Effect_opt_effect(to_ast_effects typ),l) let to_ast_funcl ctx (P.FCL_aux(fcl,l) : P.funcl) : (unit funcl) = @@ -824,15 +824,15 @@ let val_spec_ids (Defs defs) = IdSet.of_list (vs_ids defs) let quant_item_param = function - | QI_aux (QI_id kopt, _) when is_nat_kopt kopt -> [prepend_id "atom_" (id_of_kid (kopt_kid kopt))] + | QI_aux (QI_id kopt, _) when is_int_kopt kopt -> [prepend_id "atom_" (id_of_kid (kopt_kid kopt))] | QI_aux (QI_id kopt, _) when is_typ_kopt kopt -> [prepend_id "typ_" (id_of_kid (kopt_kid kopt))] | _ -> [] let quant_item_typ = function - | QI_aux (QI_id kopt, _) when is_nat_kopt kopt -> [atom_typ (nvar (kopt_kid kopt))] + | QI_aux (QI_id kopt, _) when is_int_kopt kopt -> [atom_typ (nvar (kopt_kid kopt))] | QI_aux (QI_id kopt, _) when is_typ_kopt kopt -> [mk_typ (Typ_var (kopt_kid kopt))] | _ -> [] let quant_item_arg = function - | QI_aux (QI_id kopt, _) when is_nat_kopt kopt -> [mk_typ_arg (A_nexp (nvar (kopt_kid kopt)))] + | QI_aux (QI_id kopt, _) when is_int_kopt kopt -> [mk_typ_arg (A_nexp (nvar (kopt_kid kopt)))] | QI_aux (QI_id kopt, _) when is_typ_kopt kopt -> [mk_typ_arg (A_typ (mk_typ (Typ_var (kopt_kid kopt))))] | _ -> [] let undefined_typschm id typq = diff --git a/src/isail.ml b/src/isail.ml index e513e0ee..9a300673 100644 --- a/src/isail.ml +++ b/src/isail.ml @@ -294,7 +294,7 @@ let load_session upto file = | Some upto_file when Filename.basename upto_file = file -> None | Some upto_file -> let (_, ast, env) = - load_files ~generate:false !Interactive.env [Filename.concat (Filename.dirname upto_file) file] + load_files ~check:true !Interactive.env [Filename.concat (Filename.dirname upto_file) file] in Interactive.ast := append_ast !Interactive.ast ast; Interactive.env := env; @@ -376,7 +376,7 @@ let handle_input' input = List.iter print_endline commands | ":poly" -> let is_kopt = match arg with - | "Int" -> is_nat_kopt + | "Int" -> is_int_kopt | "Type" -> is_typ_kopt | "Order" -> is_order_kopt | _ -> failwith "Invalid kind" @@ -392,7 +392,7 @@ let handle_input' input = | Arg.Bad message | Arg.Help message -> print_endline message end; | ":spec" -> - let ast, env = Specialize.specialize !Interactive.ast !Interactive.env in + let ast, env = Specialize.(specialize' 1 int_specialization !Interactive.ast !Interactive.env) in Interactive.ast := ast; Interactive.env := env; interactive_state := initial_state !Interactive.ast !Interactive.env Value.primops @@ -402,7 +402,7 @@ let handle_input' input = let open PPrint in let open C_backend in let ast = Process_file.rewrite_ast_c !Interactive.env !Interactive.ast in - let ast, env = Specialize.specialize ast !Interactive.env in + let ast, env = Specialize.(specialize typ_ord_specialization ast !Interactive.env) in let ctx = initial_ctx env in interactive_bytecode := bytecode_ast ctx (List.map flatten_cdef) ast | ":ir" -> @@ -492,7 +492,7 @@ let handle_input' input = begin try load_into_session arg; - let (_, ast, env) = load_files !Interactive.env [arg] in + let (_, ast, env) = load_files ~check:true !Interactive.env [arg] in Interactive.ast := append_ast !Interactive.ast ast; interactive_state := initial_state !Interactive.ast !Interactive.env Value.primops; Interactive.env := env; diff --git a/src/latex.ml b/src/latex.ml index a0660daa..1806da47 100644 --- a/src/latex.ml +++ b/src/latex.ml @@ -71,20 +71,23 @@ type latex_state = { mutable noindent : bool; mutable this : id option; mutable norefs : StringSet.t; - mutable generated_names : string Bindings.t + mutable generated_names : string Bindings.t; + mutable commands : StringSet.t } let reset_state state = state.noindent <- false; state.this <- None; state.norefs <- StringSet.empty; - state.generated_names <- Bindings.empty + state.generated_names <- Bindings.empty; + state.commands <- StringSet.empty let state = { noindent = false; this = None; norefs = StringSet.empty; - generated_names = Bindings.empty + generated_names = Bindings.empty; + commands = StringSet.empty } let rec unique_postfix n = @@ -285,24 +288,39 @@ let add_links str = in Str.global_substitute r subst str +let rec skip_lines in_chan = function + | n when n <= 0 -> () + | n -> ignore (input_line in_chan); skip_lines in_chan (n - 1) + +let rec read_lines in_chan = function + | n when n <= 0 -> [] + | n -> + let l = input_line in_chan in + let ls = read_lines in_chan (n - 1) in + l :: ls + let latex_loc no_loc l = - match l with - | Parse_ast.Range (_, _) | Parse_ast.Documented (_, Parse_ast.Range (_, _)) -> + match simp_loc l with + | Some (p1, p2) -> begin - let using_color = !Util.opt_colors in - Util.opt_colors := false; - let code = Util.split_on_char '\n' (Reporting.loc_to_string l) in - let doc = match code with - | _ :: _ :: code -> string (add_links (String.concat "\n" code)) - | _ -> empty - in - Util.opt_colors := using_color; - doc ^^ hardline + let open Lexing in + try + let in_chan = open_in p1.pos_fname in + try + skip_lines in_chan (p1.pos_lnum - 3); + let code = read_lines in_chan ((p2.pos_lnum - p1.pos_lnum) + 3) in + close_in in_chan; + let doc = match code with + | _ :: _ :: code -> string (add_links (String.concat "\n" code)) + | _ -> empty + in + doc ^^ hardline + with + | _ -> close_in_noerr in_chan; docstring l ^^ no_loc + with + | _ -> docstring l ^^ no_loc end - - | _ -> docstring l ^^ no_loc - -let commands = ref StringSet.empty + | None -> docstring l ^^ no_loc let doc_spec_simple (VS_aux (VS_val_spec (ts, id, ext, is_cast), _)) = Pretty_print_sail.doc_id id ^^ space @@ -322,10 +340,17 @@ let rec latex_command cat id no_loc ((l, _) as annot) = let doc = if cat = Val then no_loc else latex_loc no_loc l in output_string chan (Pretty_print_sail.to_string doc); close_out chan; + let command = sprintf "\\%s%s%s" !opt_prefix (category_name cat) (latex_id id) in + if StringSet.mem command state.commands then + (Util.warn ("Multiple instances of " ^ string_of_id id ^ " only generating latex for the first"); empty) + else + begin + state.commands <- StringSet.add command state.commands; - ksprintf string "\\newcommand{\\%s%s%s}{\\phantomsection%s\\saildoc%s{" !opt_prefix (category_name cat) (latex_id id) labelling (category_name_simple cat) - ^^ docstring l ^^ string "}{" - ^^ ksprintf string "\\lstinputlisting[language=sail]{%s}}}" (Filename.concat !opt_directory code_file) + ksprintf string "\\newcommand{%s}{\\phantomsection%s\\saildoc%s{" command labelling (category_name_simple cat) + ^^ docstring l ^^ string "}{" + ^^ ksprintf string "\\lstinputlisting[language=sail]{%s}}}" (Filename.concat !opt_directory code_file) + end let latex_label str id = string (Printf.sprintf "\\label{%s:%s}" str (Util.zencode_string (string_of_id id))) diff --git a/src/monomorphise.ml b/src/monomorphise.ml index 3167ad6b..c67b4fcb 100644 --- a/src/monomorphise.ml +++ b/src/monomorphise.ml @@ -201,7 +201,7 @@ let rec is_value (E_aux (e,(l,annot))) = let is_pure (Effect_opt_aux (e,_)) = match e with - | Effect_opt_pure -> true + | Effect_opt_none -> true | Effect_opt_effect (Effect_aux (Effect_set [],_)) -> true | _ -> false @@ -2263,7 +2263,7 @@ let replace_with_the_value bound_nexps (E_aux (_,(l,_)) as exp) = prove __POS__ env (NC_aux (NC_equal (size,nexp), Parse_ast.Unknown)) in if is_nexp_constant size then size else - match solve env size with + match solve_unique env size with | Some n -> nconstant n | None -> match List.find is_equal bound_nexps with @@ -2930,7 +2930,7 @@ let refine_dependency env (E_aux (e,(l,annot)) as exp) pexps = | _ -> None let simplify_size_nexp env typ_env (Nexp_aux (ne,l) as nexp) = - match solve typ_env nexp with + match solve_unique typ_env nexp with | Some n -> nconstant n | None -> let is_equal kid = @@ -3691,7 +3691,7 @@ let rec rewrite_app env typ (id,args) = let (size,order,bittyp) = vector_typ_args_of (Env.base_typ_of env typ) in match size with | Nexp_aux (Nexp_constant _,_) -> E_cast (typ,exp) - | _ -> match solve env size with + | _ -> match solve_unique env size with | Some c -> E_cast (vector_typ (nconstant c) order bittyp, exp) | None -> e in @@ -3711,7 +3711,7 @@ let rec rewrite_app env typ (id,args) = let (size,order,bittyp) = vector_typ_args_of (Env.base_typ_of env typ) in let (size1,_,_) = vector_typ_args_of (Env.base_typ_of env (typ_of e1)) in let midsize = nminus size size1 in begin - match solve env midsize with + match solve_unique env midsize with | Some c -> let midtyp = vector_typ (nconstant c) order bittyp in E_app (append, @@ -3739,7 +3739,7 @@ let rec rewrite_app env typ (id,args) = let (size,order,bittyp) = vector_typ_args_of (Env.base_typ_of env typ) in let (size1,_,_) = vector_typ_args_of (Env.base_typ_of env (typ_of e1)) in let midsize = nminus size size1 in begin - match solve env midsize with + match solve_unique env midsize with | Some c -> let midtyp = vector_typ (nconstant c) order bittyp in E_app (append, @@ -3797,7 +3797,7 @@ let rec rewrite_app env typ (id,args) = let (size,order,bittyp) = vector_typ_args_of (Env.base_typ_of env typ) in let (size1,_,_) = vector_typ_args_of (Env.base_typ_of env (typ_of e1)) in let midsize = nminus size size1 in begin - match solve env midsize with + match solve_unique env midsize with | Some c -> let midtyp = vector_typ (nconstant c) order bittyp in try_cast_to_typ @@ -4000,7 +4000,7 @@ struct let simplify_size_nexp env quant_kids nexp = let rec aux (Nexp_aux (ne,l) as nexp) = - match solve env nexp with + match solve_unique env nexp with | Some n -> Some (nconstant n) | None -> let is_equal kid = @@ -4191,7 +4191,7 @@ let fill_in_type env typ = | K_order | K_bool -> subst | K_int -> - (match solve env (nvar kid) with + (match solve_unique env (nvar kid) with | None -> subst | Some n -> KBindings.add kid (nconstant n) subst)) tyvars KBindings.empty in subst_src_typ subst typ @@ -4300,7 +4300,7 @@ let add_bitvector_casts (Defs defs) = let rewrite_funcl (FCL_aux (FCL_Funcl (id,pexp),fcl_ann)) = let fcl_env = env_of_annot fcl_ann in let (tq,typ) = Env.get_val_spec_orig id fcl_env in - let quant_kids = List.map kopt_kid (List.filter is_nat_kopt (quant_kopts tq)) in + let quant_kids = List.map kopt_kid (List.filter is_int_kopt (quant_kopts tq)) in let ret_typ = match typ with | Typ_aux (Typ_fn (_,ret,_),_) -> ret diff --git a/src/ocaml_backend.ml b/src/ocaml_backend.ml index ba21dd0a..f42a279b 100644 --- a/src/ocaml_backend.ml +++ b/src/ocaml_backend.ml @@ -716,13 +716,12 @@ let ocaml_pp_generators ctx defs orig_types required = and add_req_from_typarg required (A_aux (arg,_)) = match arg with | A_typ typ -> add_req_from_typ required typ - | A_nexp _ - | A_order _ - -> required + | A_nexp _ | A_order _ | A_bool _ -> required and add_req_from_td required (TD_aux (td,(l,_))) = match td with | TD_abbrev (_, _, A_aux (A_typ typ, _)) -> add_req_from_typ required typ + | TD_abbrev _ -> required | TD_record (_, _, fields, _) -> List.fold_left (fun req (typ,_) -> add_req_from_typ req typ) required fields | TD_variant (_, _, variants, _) -> @@ -751,13 +750,13 @@ let ocaml_pp_generators ctx defs orig_types required = let gen_tyvars = List.map (fun k -> kopt_kid k |> zencode_kid) (List.filter is_typ_kopt tquants) in let print_quant kindedid = - if is_nat_kopt kindedid then string "int" else + if is_int_kopt kindedid then string "int" else if is_order_kopt kindedid then string "bool" else parens (separate space [string "generators"; string "->"; zencode_kid (kopt_kid kindedid)]) in let name = "gen_" ^ type_name id in let make_tyarg kindedid = - if is_nat_kopt kindedid + if is_int_kopt kindedid then mk_typ_arg (A_nexp (nvar (kopt_kid kindedid))) else if is_order_kopt kindedid then mk_typ_arg (A_order (mk_ord (Ord_var (kopt_kid kindedid)))) @@ -793,7 +792,7 @@ let ocaml_pp_generators ctx defs orig_types required = | _ -> space ^^ separate space args_pp in string ("g.gen_" ^ typ_str) ^^ args_pp - and typearg (A_aux (arg,_)) = + and typearg (A_aux (arg,l)) = match arg with | A_nexp (Nexp_aux (nexp,l) as full_nexp) -> (match nexp with @@ -807,6 +806,7 @@ let ocaml_pp_generators ctx defs orig_types required = | Ord_inc -> string "true" | Ord_dec -> string "false") | A_typ typ -> parens (string "fun g -> " ^^ gen_type typ) + | A_bool nc -> raise (Reporting.err_todo l ("Unsupported constraint for generators: " ^ string_of_n_constraint nc)) in let make_subgen (Typ_aux (typ,l) as full_typ) = let typ_str, args_pp = diff --git a/src/parse_ast.ml b/src/parse_ast.ml index 5f0d7487..b86d4dd5 100644 --- a/src/parse_ast.ml +++ b/src/parse_ast.ml @@ -324,7 +324,7 @@ typschm_opt = type effect_opt_aux = (* Optional effect annotation for functions *) - Effect_opt_pure (* sugar for empty effect set *) + Effect_opt_none (* sugar for empty effect set *) | Effect_opt_effect of atyp diff --git a/src/parser.mly b/src/parser.mly index 2cd0dbe1..9f7e2e0c 100644 --- a/src/parser.mly +++ b/src/parser.mly @@ -133,7 +133,7 @@ let mk_recn = (Rec_aux((Rec_nonrec), Unknown)) let mk_typqn = (TypQ_aux(TypQ_no_forall,Unknown)) let mk_tannotn = Typ_annot_opt_aux(Typ_annot_opt_none,Unknown) let mk_tannot typq typ n m = Typ_annot_opt_aux(Typ_annot_opt_some (typq, typ), loc n m) -let mk_eannotn = Effect_opt_aux(Effect_opt_pure,Unknown) +let mk_eannotn = Effect_opt_aux(Effect_opt_none,Unknown) let mk_typq kopts nc n m = TypQ_aux (TypQ_tq (List.map qi_id_of_kopt kopts @ nc), loc n m) diff --git a/src/pretty_print_coq.ml b/src/pretty_print_coq.ml index 46d07cc3..adb00a77 100644 --- a/src/pretty_print_coq.ml +++ b/src/pretty_print_coq.ml @@ -370,6 +370,34 @@ let doc_nc_fn id = | "not" -> string "negb" | s -> string s +type ex_atom_bool = ExBool_simple | ExBool_val of bool | ExBool_complex + +let non_trivial_ex_atom_bool l kopts nc atom_nc = + let vars = KOptSet.union (kopts_of_constraint nc) (kopts_of_constraint atom_nc) in + let exists = KOptSet.of_list kopts in + if KOptSet.subset vars exists then + let kenv = List.fold_left (fun kenv kopt -> KBindings.add (kopt_kid kopt) (unaux_kind (kopt_kind kopt)) kenv) KBindings.empty kopts in + match Constraint.call_smt l kenv (nc_and nc atom_nc), + Constraint.call_smt l kenv (nc_and nc (nc_not atom_nc)) with + | Sat, Sat -> ExBool_simple + | Sat, Unsat -> ExBool_val true + | Unsat, Sat -> ExBool_val false + | _ -> ExBool_complex + else ExBool_complex + +type ex_kind = ExNone | ExBool | ExGeneral + +let classify_ex_type (Typ_aux (t,l) as t0) = + match t with + | Typ_exist (kopts,nc,(Typ_aux (Typ_app (Id_aux (Id "atom_bool",_), [A_aux (A_bool atom_nc,_)]),_) as t1)) -> begin + match non_trivial_ex_atom_bool l kopts nc atom_nc with + | ExBool_simple -> ExNone, t1 + | ExBool_val _ -> ExBool, t1 + | ExBool_complex -> ExGeneral, t1 + end + | Typ_exist (_,_,t1) -> ExGeneral,t1 + | _ -> ExNone,t0 + (* When making changes here, check whether they affect coq_nvars_of_typ *) let rec doc_typ_fns ctx = (* following the structure of parser for precedence *) @@ -476,13 +504,18 @@ let rec doc_typ_fns ctx = [doc_var ctx var; colon; tpp; ampersand; doc_arithfact ctx ~exists:(List.map kopt_kid kopts) ?extra:length_constraint_pp nc]) - | Typ_aux (Typ_app (Id_aux (Id "atom_bool",_), [A_aux (A_bool atom_nc,_)]),_) -> - let var = mk_kid "_bool" in (* TODO collision avoid *) - let nc = nice_and (nice_iff (nc_var var) atom_nc) nc in - braces (separate space - [doc_var ctx var; colon; string "bool"; - ampersand; - doc_arithfact ctx ~exists:(List.map kopt_kid kopts) nc]) + | Typ_aux (Typ_app (Id_aux (Id "atom_bool",_), [A_aux (A_bool atom_nc,_)]),_) -> begin + match non_trivial_ex_atom_bool l kopts nc atom_nc with + | ExBool_simple -> string "bool" + | ExBool_val t -> string "Bool(" ^^ if t then string "True)" else string "False)" + | ExBool_complex -> + let var = mk_kid "_bool" in (* TODO collision avoid *) + let nc = nice_and (nice_iff (nc_var var) atom_nc) nc in + braces (separate space + [doc_var ctx var; colon; string "bool"; + ampersand; + doc_arithfact ctx ~exists:(List.map kopt_kid kopts) nc]) + end | _ -> raise (Reporting.err_todo l ("Non-atom existential type not yet supported in Coq: " ^ @@ -1034,9 +1067,14 @@ let doc_exp, doc_let = let typ = expand_range_type typ in match destruct_exist_plain typ with | None -> epp + | Some (kopts,nc,Typ_aux (Typ_app (Id_aux (Id "atom_bool",_), [A_aux (A_bool atom_nc,_)]),l)) -> begin + match non_trivial_ex_atom_bool l kopts nc atom_nc with + | ExBool_simple -> epp + | ExBool_val t -> wrap_parens (string "build_Bool" ^/^ epp) + | ExBool_complex -> wrap_parens (string "build_ex" ^/^ epp) + end | Some _ -> - let epp = string "build_ex" ^/^ epp in - if aexp_needed then parens epp else epp + wrap_parens (string "build_ex" ^/^ epp) in let rec construct_dep_pairs env = let rec aux want_parens (E_aux (e,_) as exp) (Typ_aux (t,_) as typ) = @@ -1049,8 +1087,14 @@ let doc_exp, doc_let = let typ' = expand_range_type (Env.expand_synonyms (env_of exp) typ) in let build_ex, out_typ = match destruct_exist_plain typ' with - | Some (_,_,t) -> true, t - | None -> false, typ' + | Some (kopts,nc,(Typ_aux (Typ_app (Id_aux (Id "atom_bool",_), [A_aux (A_bool atom_nc,_)]),l) as t)) -> begin + match non_trivial_ex_atom_bool l kopts nc atom_nc with + | ExBool_simple -> None, t + | ExBool_val _ -> Some "build_Bool", t + | ExBool_complex -> Some "build_ex", t + end + | Some (_,_,t) -> Some "build_ex", t + | None -> None, typ' in let in_typ = expand_range_type (Env.expand_synonyms (env_of exp) (typ_of exp)) in let in_typ = match destruct_exist_plain in_typ with Some (_,_,t) -> t | None -> in_typ in @@ -1063,16 +1107,17 @@ let doc_exp, doc_let = not (similar_nexps ctxt (env_of exp) n1 n2) | _ -> false in - let exp_pp = expV (want_parens || autocast || build_ex) exp in + let exp_pp = expV (want_parens || autocast || Util.is_some build_ex) exp in let exp_pp = if autocast then let exp_pp = string "autocast" ^^ space ^^ exp_pp in - if want_parens || build_ex then parens exp_pp else exp_pp + if want_parens || Util.is_some build_ex then parens exp_pp else exp_pp else exp_pp - in if build_ex then - let exp_pp = string "build_ex" ^^ space ^^ exp_pp in + in match build_ex with + | Some s -> + let exp_pp = string s ^^ space ^^ exp_pp in if want_parens then parens exp_pp else exp_pp - else exp_pp + | None -> exp_pp in aux in let liftR doc = @@ -1162,7 +1207,7 @@ let doc_exp, doc_let = wrap_parens (hang 2 (flow (break 1) (call :: List.map expY args))) (* temporary hack to make the loop body a function of the temporary variables *) | Id_aux (Id "None", _) as none -> doc_id_ctor none - | Id_aux (Id "foreach", _) -> + | Id_aux (Id "foreach#", _) -> begin match args with | [from_exp; to_exp; step_exp; ord_exp; vartuple; body] -> @@ -1209,7 +1254,8 @@ let doc_exp, doc_let = | _ -> raise (Reporting.err_unreachable l __POS__ "Unexpected number of arguments for loop combinator") end - | Id_aux (Id (("while" | "until") as combinator), _) -> + | Id_aux (Id (("while#" | "until#") as combinator), _) -> + let combinator = String.sub combinator 0 (String.length combinator - 1) in begin match args with | [cond; varstuple; body] -> @@ -1499,21 +1545,9 @@ let doc_exp, doc_let = debug ctxt (lazy (" on expr of type " ^ string_of_typ inner_typ)); debug ctxt (lazy (" where type expected is " ^ string_of_typ outer_typ)) in - let outer_ex,outer_typ' = - match outer_typ with - | Typ_aux (Typ_exist (_,_,t1),_) -> true,t1 - | t1 -> false,t1 - in - let cast_ex,cast_typ' = - match cast_typ with - | Typ_aux (Typ_exist (_,_,t1),_) -> true,t1 - | t1 -> false,t1 - in - let inner_ex,inner_typ' = - match inner_typ with - | Typ_aux (Typ_exist (_,_,t1),_) -> true,t1 - | t1 -> false,t1 - in + let outer_ex,outer_typ' = classify_ex_type outer_typ in + let cast_ex,cast_typ' = classify_ex_type cast_typ in + let inner_ex,inner_typ' = classify_ex_type inner_typ in let autocast = (* Avoid using helper functions which simplify the nexps *) is_bitvector_typ outer_typ' && is_bitvector_typ cast_typ' && @@ -1526,30 +1560,34 @@ let doc_exp, doc_let = let effects = effectful (effect_of e) in let epp = if effects then - if inner_ex then - if cast_ex - (* If the types are the same use the cast as a hint to Coq, - otherwise derive the new type from the old one. *) - then if alpha_equivalent env inner_typ cast_typ - then epp - else string "derive_m" ^^ space ^^ epp - else string "projT1_m" ^^ space ^^ epp - else if cast_ex - then string "build_ex_m" ^^ space ^^ epp - else epp - else if cast_ex - then string "build_ex" ^^ space ^^ epp - else epp + match inner_ex, cast_ex with + | ExGeneral, ExGeneral -> + (* If the types are the same use the cast as a hint to Coq, + otherwise derive the new type from the old one. *) + if alpha_equivalent env inner_typ cast_typ + then epp + else string "derive_m" ^^ space ^^ epp + | ExGeneral, ExNone -> + string "projT1_m" ^^ space ^^ epp + | ExNone, ExGeneral -> + string "build_ex_m" ^^ space ^^ epp + | ExNone, ExNone -> epp + else match cast_ex with + | ExGeneral -> string "build_ex" ^^ space ^^ epp + | ExBool -> string "build_Bool" ^^ space ^^ epp + | ExNone -> epp in let epp = epp ^/^ doc_tannot ctxt (env_of e) effects typ in let epp = if effects then - if cast_ex && not outer_ex - then string "projT1_m" ^^ space ^^ parens epp - else epp - else if cast_ex - then string "projT1" ^^ space ^^ parens epp - else epp + match cast_ex, outer_ex with + | ExGeneral, ExNone -> string "projT1_m" ^^ space ^^ parens epp + | ExBool, ExNone -> string "projBool_m" ^^ space ^^ parens epp + | _ -> epp + else match cast_ex with + | ExGeneral -> string "projT1" ^^ space ^^ parens epp + | ExBool -> string "projBool" ^^ space ^^ parens epp + | ExNone -> epp in let epp = if autocast then @@ -1718,10 +1756,10 @@ let doc_exp, doc_let = | P_aux (P_var (P_aux (P_typ (typ, P_aux (P_id id,_)),_),_),_) when not (is_enum (env_of e1) id) -> let full_typ = (expand_range_type typ) in - let binder = match destruct_exist_plain (Env.expand_synonyms (env_of e1) full_typ) with - | Some _ -> + let binder = match classify_ex_type (Env.expand_synonyms (env_of e1) full_typ) with + | ExGeneral, _ -> squote ^^ parens (separate space [string "existT"; underscore; doc_id id; underscore; colon; doc_typ ctxt typ]) - | _ -> + | (ExBool | ExNone), _ -> parens (separate space [doc_id id; colon; doc_typ ctxt typ]) in separate space [string ">>= fun"; binder; bigarrow] | _ -> @@ -1775,6 +1813,10 @@ let doc_exp, doc_let = | E_aux (E_if (c', t', e'), _) | E_aux (E_cast (_, E_aux (E_if (c', t', e'), _)), _) -> if_exp ctxt true c' t' e' + (* Special case to prevent current arm decoder becoming a staircase *) + (* TODO: replace with smarter pretty printing *) + | E_aux (E_internal_plet (pat,exp1,E_aux (E_cast (typ, (E_aux (E_if (_, _, _), _) as exp2)),_)),ann) when Typ.compare typ unit_typ == 0 -> + string "else" ^/^ top_exp ctxt false (E_aux (E_internal_plet (pat,exp1,exp2),ann)) | _ -> prefix 2 1 (string "else") (top_exp ctxt false e) in (prefix 2 1 @@ -2200,9 +2242,10 @@ let doc_funcl mutrec rec_opt (FCL_aux(FCL_Funcl(id, pexp), annot)) = | _ -> failwith ("Function " ^ string_of_id id ^ " does not have function type") in let build_ex, ret_typ = replace_atom_return_type ret_typ in - let build_ex = match destruct_exist_plain (Env.expand_synonyms env (expand_range_type ret_typ)) with - | Some _ -> Some "build_ex" - | _ -> build_ex + let build_ex = match classify_ex_type (Env.expand_synonyms env (expand_range_type ret_typ)) with + | ExGeneral, _ -> Some "build_ex" + | ExBool, _ -> Some "build_Bool" + | ExNone, _ -> build_ex in let ids_to_avoid = all_ids pexp in let bound_kids = tyvars_of_typquant tq in @@ -2469,7 +2512,7 @@ let doc_axiom_typschm typ_env (TypSchm_aux (TypSchm_ts (tqs,typ),l) as ts) = let used = if is_number ret_ty then used else KidSet.union used (tyvars_of_typ ret_ty) in let tqs = match tqs with | TypQ_aux (TypQ_tq qs,l) -> TypQ_aux (TypQ_tq (List.filter (function - | QI_aux (QI_id kopt,_) when is_nat_kopt kopt -> + | QI_aux (QI_id kopt,_) when is_int_kopt kopt -> let kid = kopt_kid kopt in KidSet.mem kid used && not (KidSet.mem kid args) | _ -> true) qs),l) diff --git a/src/pretty_print_lem.ml b/src/pretty_print_lem.ml index aa03528f..6adcec46 100644 --- a/src/pretty_print_lem.ml +++ b/src/pretty_print_lem.ml @@ -316,7 +316,7 @@ let doc_typ_lem, doc_atomic_typ_lem = * if we add a new Typ constructor *) let tpp = typ ty in if atyp_needed then parens tpp else tpp - | Typ_exist (kopts,_,ty) when List.for_all is_nat_kopt kopts -> begin + | Typ_exist (kopts,_,ty) when List.for_all is_int_kopt kopts -> begin let kids = List.map kopt_kid kopts in let tpp = typ ty in let visible_vars = lem_tyvars_of_typ ty in @@ -359,7 +359,7 @@ let replace_typ_size ctxt env (Typ_aux (t,a)) = let mk_typ nexp = Some (Typ_aux (Typ_app (id, [A_aux (A_nexp nexp,Parse_ast.Unknown);ord;typ']),a)) in - match Type_check.solve env size with + match Type_check.solve_unique env size with | Some n -> mk_typ (nconstant n) | None -> let is_equal nexp = @@ -668,7 +668,7 @@ let doc_exp_lem, doc_let_lem = let call = doc_id_lem (append_id f "M") in wrap_parens (hang 2 (flow (break 1) (call :: List.map expY args))) (* temporary hack to make the loop body a function of the temporary variables *) - | Id_aux (Id "foreach", _) -> + | Id_aux (Id "foreach#", _) -> begin match args with | [exp1; exp2; exp3; ord_exp; vartuple; body] -> @@ -713,7 +713,8 @@ let doc_exp_lem, doc_let_lem = | _ -> raise (Reporting.err_unreachable l __POS__ "Unexpected number of arguments for loop combinator") end - | Id_aux (Id (("while" | "until") as combinator), _) -> + | Id_aux (Id (("while#" | "until#") as combinator), _) -> + let combinator = String.sub combinator 0 (String.length combinator - 1) in begin match args with | [cond; varstuple; body] -> diff --git a/src/pretty_print_sail.ml b/src/pretty_print_sail.ml index 67f291bd..56026c81 100644 --- a/src/pretty_print_sail.ml +++ b/src/pretty_print_sail.ml @@ -66,7 +66,7 @@ let doc_id (Id_aux (id_aux, _)) = let doc_kid kid = string (Ast_util.string_of_kid kid) let doc_kopt = function - | kopt when is_nat_kopt kopt -> doc_kid (kopt_kid kopt) + | kopt when is_int_kopt kopt -> doc_kid (kopt_kid kopt) | kopt when is_typ_kopt kopt -> parens (separate space [doc_kid (kopt_kid kopt); colon; string "Type"]) | kopt when is_order_kopt kopt -> parens (separate space [doc_kid (kopt_kid kopt); colon; string "Order"]) | kopt -> parens (separate space [doc_kid (kopt_kid kopt); colon; string "Bool"]) @@ -213,7 +213,7 @@ and doc_arg_typs = function let doc_quants quants = let doc_qi_kopt (QI_aux (qi_aux, _)) = match qi_aux with - | QI_id kopt when is_nat_kopt kopt -> [parens (separate space [doc_kid (kopt_kid kopt); colon; string "Int"])] + | QI_id kopt when is_int_kopt kopt -> [parens (separate space [doc_kid (kopt_kid kopt); colon; string "Int"])] | QI_id kopt when is_typ_kopt kopt -> [parens (separate space [doc_kid (kopt_kid kopt); colon; string "Type"])] | QI_id kopt when is_bool_kopt kopt -> [parens (separate space [doc_kid (kopt_kid kopt); colon; string "Bool"])] | QI_id kopt -> [parens (separate space [doc_kid (kopt_kid kopt); colon; string "Order"])] @@ -234,7 +234,7 @@ let doc_quants quants = let doc_param_quants quants = let doc_qi_kopt (QI_aux (qi_aux, _)) = match qi_aux with - | QI_id kopt when is_nat_kopt kopt -> [doc_kid (kopt_kid kopt) ^^ colon ^^ space ^^ string "Int"] + | QI_id kopt when is_int_kopt kopt -> [doc_kid (kopt_kid kopt) ^^ colon ^^ space ^^ string "Int"] | QI_id kopt when is_typ_kopt kopt -> [doc_kid (kopt_kid kopt) ^^ colon ^^ space ^^ string "Type"] | QI_id kopt when is_bool_kopt kopt -> [doc_kid (kopt_kid kopt) ^^ colon ^^ space ^^ string "Bool"] | QI_id kopt -> [doc_kid (kopt_kid kopt) ^^ colon ^^ space ^^ string "Order"] @@ -619,8 +619,7 @@ let doc_typdef (TD_aux(td,_)) = match td with | TD_variant (id, TypQ_aux (TypQ_tq qs, _), unions, _) -> separate space [string "union"; doc_id id; doc_param_quants qs; equals; surround 2 0 lbrace (separate_map (comma ^^ break 1) doc_union unions) rbrace] - | _ -> string "TYPEDEF" - + | TD_bitfield _ -> string "BITFIELD" (* should be rewritten *) let doc_spec ?comment:(comment=false) (VS_aux (v, annot)) = let doc_extern ext = @@ -654,6 +653,12 @@ let rec doc_scattered (SD_aux (sd_aux, _)) = string "scattered" ^^ space ^^ string "union" ^^ space ^^ doc_id id | SD_variant (id, TypQ_aux (TypQ_tq quants, _)) -> string "scattered" ^^ space ^^ string "union" ^^ space ^^ doc_id id ^^ doc_param_quants quants + | SD_mapcl (id, mapcl) -> + separate space [string "mapping clause"; doc_id id; equals; doc_mapcl mapcl] + | SD_mapping (id, Typ_annot_opt_aux (Typ_annot_opt_none, _)) -> + separate space [string "scattered mapping"; doc_id id] + | SD_mapping (id, Typ_annot_opt_aux (Typ_annot_opt_some (_, typ), _)) -> + separate space [string "scattered mapping"; doc_id id; string ":"; doc_typ typ] | SD_unioncl (id, tu) -> separate space [string "union clause"; doc_id id; equals; doc_union tu] diff --git a/src/process_file.ml b/src/process_file.ml index d2a43b4a..3c2d4a22 100644 --- a/src/process_file.ml +++ b/src/process_file.ml @@ -159,7 +159,7 @@ let rec preprocess opts = function symbols := StringSet.add symbol !symbols; preprocess opts defs - | Parse_ast.DEF_pragma ("option", command, l) :: defs -> + | (Parse_ast.DEF_pragma ("option", command, l) as opt_pragma) :: defs -> begin try let args = Str.split (Str.regexp " +") command in @@ -167,7 +167,7 @@ let rec preprocess opts = function with | Arg.Bad message | Arg.Help message -> raise (Reporting.err_general l message) end; - preprocess opts defs + opt_pragma :: preprocess opts defs | Parse_ast.DEF_pragma ("ifndef", symbol, l) :: defs -> let then_defs, else_defs, defs = cond_pragma l defs in diff --git a/src/profile.ml b/src/profile.ml index 1a8bd30b..f64bdfe0 100644 --- a/src/profile.ml +++ b/src/profile.ml @@ -51,13 +51,13 @@ let opt_profile = ref false type profile = { - z3_calls : int; - z3_time : float + smt_calls : int; + smt_time : float } let new_profile = { - z3_calls = 0; - z3_time = 0.0 + smt_calls = 0; + smt_time = 0.0 } let profile_stack = ref [] @@ -68,12 +68,12 @@ let update_profile f = | (p :: ps) -> profile_stack := f p :: ps -let start_z3 () = - update_profile (fun p -> { p with z3_calls = p.z3_calls + 1 }); +let start_smt () = + update_profile (fun p -> { p with smt_calls = p.smt_calls + 1 }); Sys.time () -let finish_z3 t = - update_profile (fun p -> { p with z3_time = p.z3_time +. (Sys.time () -. t) }) +let finish_smt t = + update_profile (fun p -> { p with smt_time = p.smt_time +. (Sys.time () -. t) }) let start () = profile_stack := new_profile :: !profile_stack; @@ -84,7 +84,7 @@ let finish msg t = begin match !profile_stack with | p :: ps -> prerr_endline (Printf.sprintf "%s %s: %fs" Util.("Profiled" |> magenta |> clear) msg (Sys.time () -. t)); - prerr_endline (Printf.sprintf " Z3 calls: %d, Z3 time: %fs" p.z3_calls p.z3_time); + prerr_endline (Printf.sprintf " SMT calls: %d, SMT time: %fs" p.smt_calls p.smt_time); profile_stack := ps | [] -> () end diff --git a/src/rewrites.ml b/src/rewrites.ml index 44d99537..d4601fa6 100644 --- a/src/rewrites.ml +++ b/src/rewrites.ml @@ -1549,10 +1549,11 @@ let rewrite_exp_remove_bitvector_pat rewriters (E_aux (exp,(l,annot)) as full_ex | None -> Pat_aux (Pat_exp (pat', body'), annot')) | Pat_aux (Pat_when (pat,guard,body),annot') -> let (pat',(guard',decls,_)) = remove_bitvector_pat pat in + let guard'' = rewrite_rec guard in let body' = decls (rewrite_rec body) in (match guard' with - | Some guard' -> Pat_aux (Pat_when (pat', bitwise_and_exp (decls guard) guard', body'), annot') - | None -> Pat_aux (Pat_when (pat', (decls guard), body'), annot')) in + | Some guard' -> Pat_aux (Pat_when (pat', bitwise_and_exp (decls guard'') guard', body'), annot') + | None -> Pat_aux (Pat_when (pat', (decls guard''), body'), annot')) in rewrap (E_case (e, List.map rewrite_pexp ps)) | E_let (LB_aux (LB_val (pat,v),annot'),body) -> let (pat,(_,decls,_)) = remove_bitvector_pat pat in @@ -3105,7 +3106,7 @@ let construct_toplevel_string_append_func env f_id pat = let new_val_spec, env = Type_check.check_val_spec env new_val_spec in let non_rec = (Rec_aux (Rec_nonrec, Parse_ast.Unknown)) in let no_tannot = (Typ_annot_opt_aux (Typ_annot_opt_none, Parse_ast.Unknown)) in - let effect_pure = (Effect_opt_aux (Effect_opt_pure, Parse_ast.Unknown)) in + let effect_none = (Effect_opt_aux (Effect_opt_none, Parse_ast.Unknown)) in let s_id = fresh_stringappend_id () in let arg_pat = mk_pat (P_id s_id) in (* We can ignore guards here because we've already removed them *) @@ -3210,7 +3211,7 @@ let construct_toplevel_string_append_func env f_id pat = in let wildcard = mk_pexp (Pat_exp (mk_pat P_wild, mk_exp (E_app (mk_id "None", [mk_lit_exp L_unit])))) in let new_match = mk_exp (E_case (mk_exp (E_id s_id), [strip_pexp new_pexp; wildcard])) in - let new_fun_def = FD_aux (FD_function (non_rec, no_tannot, effect_pure, [mk_funcl f_id arg_pat new_match]), (unk,())) in + let new_fun_def = FD_aux (FD_function (non_rec, no_tannot, effect_none, [mk_funcl f_id arg_pat new_match]), (unk,())) in let new_fun_def, env = Type_check.check_fundef env new_fun_def in List.flatten [new_val_spec; new_fun_def] @@ -3525,7 +3526,11 @@ let rewrite_defs_mapping_patterns env = let mapping_in_typ = typ_of_annot p_annot in let x = Env.get_val_spec mapping_id env in - let (_, Typ_aux(Typ_bidir(typ1, typ2), _)) = x in + + let typ1, typ2 = match x with + | (_, Typ_aux(Typ_bidir(typ1, typ2), _)) -> typ1, typ2 + | (_, typ) -> raise (Reporting.err_unreachable (fst p_annot) __POS__ ("Must be bi-directional mapping: " ^ string_of_typ typ)) + in let mapping_direction = if mapping_in_typ = typ1 then @@ -3822,7 +3827,7 @@ let rec rewrite_var_updates ((E_aux (expaux,((l,_) as annot))) as exp) = let v = fix_eff_exp (annot_exp (E_let (lb_lower, fix_eff_exp (annot_exp (E_let (lb_upper, - fix_eff_exp (annot_exp (E_app (mk_id "foreach", [exp1; exp2; exp3; ord_exp; tuple_exp vars; guarded_body])) + fix_eff_exp (annot_exp (E_app (mk_id "foreach#", [exp1; exp2; exp3; ord_exp; tuple_exp vars; guarded_body])) el env (typ_of exp4)))) el env (typ_of exp4)))) el env (typ_of exp4)) in @@ -3838,8 +3843,8 @@ let rec rewrite_var_updates ((E_aux (expaux,((l,_) as annot))) as exp) = let body = rewrite_var_updates (add_vars overwrite body vars) in let (E_aux (_,(_,bannot))) = body in let fname = match loop with - | While -> "while" - | Until -> "until" in + | While -> "while#" + | Until -> "until#" in let funcl = Id_aux (Id fname,gen_loc el) in let v = E_aux (E_app (funcl,[cond;tuple_exp vars;body]), (gen_loc el, bannot)) in Added_vars (v, tuple_pat (if overwrite then varpats else pat :: varpats)) @@ -4140,9 +4145,9 @@ let rewrite_defs_remove_superfluous_returns env = let rewrite_defs_remove_e_assign env (Defs defs) = let (Defs loop_specs) = fst (Type_error.check Env.empty (Defs (List.map gen_vs - [("foreach", "forall ('vars : Type). (int, int, int, bool, 'vars, 'vars) -> 'vars"); - ("while", "forall ('vars : Type). (bool, 'vars, 'vars) -> 'vars"); - ("until", "forall ('vars : Type). (bool, 'vars, 'vars) -> 'vars")]))) in + [("foreach#", "forall ('vars : Type). (int, int, int, bool, 'vars, 'vars) -> 'vars"); + ("while#", "forall ('vars : Type). (bool, 'vars, 'vars) -> 'vars"); + ("until#", "forall ('vars : Type). (bool, 'vars, 'vars) -> 'vars")]))) in let rewrite_exp _ e = replace_memwrite_e_assign (remove_reference_types (rewrite_var_updates e)) in rewrite_defs_base @@ -4334,7 +4339,7 @@ let rewrite_defs_realise_mappings _ (Defs defs) = let backwards_matches_id = mk_id (string_of_id id ^ "_backwards_matches") in let non_rec = (Rec_aux (Rec_nonrec, Parse_ast.Unknown)) in - let effect_pure = (Effect_opt_aux (Effect_opt_pure, Parse_ast.Unknown)) in + let effect_none = (Effect_opt_aux (Effect_opt_none, Parse_ast.Unknown)) in (* We need to make sure we get the environment for the last mapping clause *) let env = match List.rev mapcls with | MCL_aux (_, mapcl_annot) :: _ -> env_of_annot mapcl_annot @@ -4368,10 +4373,10 @@ let rewrite_defs_realise_mappings _ (Defs defs) = let forwards_matches_match = mk_exp (E_case (arg_exp, ((List.map (fun mapcl -> strip_mapcl mapcl |> realise_bool_mapcl true forwards_matches_id) mapcls) |> List.flatten) @ [wildcard])) in let backwards_matches_match = mk_exp (E_case (arg_exp, ((List.map (fun mapcl -> strip_mapcl mapcl |> realise_bool_mapcl false backwards_matches_id) mapcls) |> List.flatten) @ [wildcard])) in - let forwards_fun = (FD_aux (FD_function (non_rec, no_tannot, effect_pure, [mk_funcl forwards_id arg_pat forwards_match]), (l, ()))) in - let backwards_fun = (FD_aux (FD_function (non_rec, no_tannot, effect_pure, [mk_funcl backwards_id arg_pat backwards_match]), (l, ()))) in - let forwards_matches_fun = (FD_aux (FD_function (non_rec, no_tannot, effect_pure, [mk_funcl forwards_matches_id arg_pat forwards_matches_match]), (l, ()))) in - let backwards_matches_fun = (FD_aux (FD_function (non_rec, no_tannot, effect_pure, [mk_funcl backwards_matches_id arg_pat backwards_matches_match]), (l, ()))) in + let forwards_fun = (FD_aux (FD_function (non_rec, no_tannot, effect_none, [mk_funcl forwards_id arg_pat forwards_match]), (l, ()))) in + let backwards_fun = (FD_aux (FD_function (non_rec, no_tannot, effect_none, [mk_funcl backwards_id arg_pat backwards_match]), (l, ()))) in + let forwards_matches_fun = (FD_aux (FD_function (non_rec, no_tannot, effect_none, [mk_funcl forwards_matches_id arg_pat forwards_matches_match]), (l, ()))) in + let backwards_matches_fun = (FD_aux (FD_function (non_rec, no_tannot, effect_none, [mk_funcl backwards_matches_id arg_pat backwards_matches_match]), (l, ()))) in typ_debug (lazy (Printf.sprintf "forwards for mapping %s: %s\n%!" (string_of_id id) (Pretty_print_sail.doc_fundef forwards_fun |> Pretty_print_sail.to_string))); typ_debug (lazy (Printf.sprintf "backwards for mapping %s: %s\n%!" (string_of_id id) (Pretty_print_sail.doc_fundef backwards_fun |> Pretty_print_sail.to_string))); @@ -4390,7 +4395,7 @@ let rewrite_defs_realise_mappings _ (Defs defs) = let forwards_prefix_spec = VS_aux (VS_val_spec (mk_typschm typq forwards_prefix_typ, prefix_id, [], false), (Parse_ast.Unknown,())) in let forwards_prefix_spec, env = Type_check.check_val_spec env forwards_prefix_spec in let forwards_prefix_match = mk_exp (E_case (arg_exp, ((List.map (fun mapcl -> strip_mapcl mapcl |> realise_prefix_mapcl true prefix_id) mapcls) |> List.flatten) @ [prefix_wildcard])) in - let forwards_prefix_fun = (FD_aux (FD_function (non_rec, no_tannot, effect_pure, [mk_funcl prefix_id arg_pat forwards_prefix_match]), (l, ()))) in + let forwards_prefix_fun = (FD_aux (FD_function (non_rec, no_tannot, effect_none, [mk_funcl prefix_id arg_pat forwards_prefix_match]), (l, ()))) in typ_debug (lazy (Printf.sprintf "forwards prefix matches for mapping %s: %s\n%!" (string_of_id id) (Pretty_print_sail.doc_fundef forwards_prefix_fun |> Pretty_print_sail.to_string))); let forwards_prefix_fun, _ = Type_check.check_fundef env forwards_prefix_fun in forwards_prefix_spec @ forwards_prefix_fun @@ -4400,7 +4405,7 @@ let rewrite_defs_realise_mappings _ (Defs defs) = let backwards_prefix_spec = VS_aux (VS_val_spec (mk_typschm typq backwards_prefix_typ, prefix_id, [], false), (Parse_ast.Unknown,())) in let backwards_prefix_spec, env = Type_check.check_val_spec env backwards_prefix_spec in let backwards_prefix_match = mk_exp (E_case (arg_exp, ((List.map (fun mapcl -> strip_mapcl mapcl |> realise_prefix_mapcl false prefix_id) mapcls) |> List.flatten) @ [prefix_wildcard])) in - let backwards_prefix_fun = (FD_aux (FD_function (non_rec, no_tannot, effect_pure, [mk_funcl prefix_id arg_pat backwards_prefix_match]), (l, ()))) in + let backwards_prefix_fun = (FD_aux (FD_function (non_rec, no_tannot, effect_none, [mk_funcl prefix_id arg_pat backwards_prefix_match]), (l, ()))) in typ_debug (lazy (Printf.sprintf "backwards prefix matches for mapping %s: %s\n%!" (string_of_id id) (Pretty_print_sail.doc_fundef backwards_prefix_fun |> Pretty_print_sail.to_string))); let backwards_prefix_fun, _ = Type_check.check_fundef env backwards_prefix_fun in backwards_prefix_spec @ backwards_prefix_fun @@ -4642,8 +4647,11 @@ let rec remove_clause_from_pattern ctx (P_aux (rm_pat,ann)) res_pat = rp' @ List.map (function [rp1;rp2] -> RP_cons (rp1,rp2) | _ -> assert false) res_pats end | P_record _ -> - raise (Reporting.err_unreachable (fst ann) __POS__ - "Record pattern not supported") + raise (Reporting.err_unreachable (fst ann) __POS__ "Record pattern not supported") + | P_or _ -> + raise (Reporting.err_unreachable (fst ann) __POS__ "Or pattern not supported") + | P_not _ -> + raise (Reporting.err_unreachable (fst ann) __POS__ "Negated pattern not supported") | P_vector _ | P_vector_concat _ | P_string_append _ -> @@ -4929,7 +4937,7 @@ let rewrite_explicit_measure env (Defs defs) = match Bindings.find id measures with | (measure_pat, measure_exp) -> let e = match e with - | Effect_opt_aux (Effect_opt_pure, _) -> + | Effect_opt_aux (Effect_opt_none, _) -> Effect_opt_aux (Effect_opt_effect (mk_effect [BE_escape]), loc) | Effect_opt_aux (Effect_opt_effect eff,_) -> Effect_opt_aux (Effect_opt_effect (add_escape eff), loc) @@ -5164,10 +5172,21 @@ let rewrite_defs_ocaml = [ let rewrite_defs_c = [ ("no_effect_check", (fun _ defs -> opt_no_effects := true; defs)); + + (* Remove bidirectional mappings *) ("realise_mappings", rewrite_defs_realise_mappings); ("toplevel_string_append", rewrite_defs_toplevel_string_append); ("pat_string_append", rewrite_defs_pat_string_append); ("mapping_builtins", rewrite_defs_mapping_patterns); + + (* Monomorphisation *) + ("mono_rewrites", if_mono mono_rewrites); + ("recheck_defs", if_mono recheck_defs); + ("rewrite_toplevel_nexps", if_mono rewrite_toplevel_nexps); + ("monomorphise", if_mono monomorphise); + ("rewrite_atoms_to_singletons", if_mono (fun _ -> Monomorphise.rewrite_atoms_to_singletons)); + ("recheck_defs", if_mono recheck_defs); + ("rewrite_undefined", rewrite_undefined_if_gen false); ("rewrite_defs_vector_string_pats_to_bit_list", rewrite_defs_vector_string_pats_to_bit_list); ("remove_not_pats", rewrite_defs_not_pats); diff --git a/src/sail.ml b/src/sail.ml index f481eb7b..edf1bda4 100644 --- a/src/sail.ml +++ b/src/sail.ml @@ -65,7 +65,7 @@ let opt_print_c = ref false let opt_print_latex = ref false let opt_print_coq = ref false let opt_print_cgen = ref false -let opt_memo_z3 = ref true +let opt_memo_z3 = ref false let opt_sanity = ref false let opt_includes_c = ref ([]:string list) let opt_libs_lem = ref ([]:string list) @@ -117,9 +117,15 @@ let options = Arg.align ([ ( "-ocaml_generators", Arg.String (fun s -> opt_ocaml_generators := s::!opt_ocaml_generators), "<types> produce random generators for the given types"); + ( "-smt_solver", + Arg.String (fun s -> Constraint.set_solver (String.trim s)), + "<solver> choose SMT solver. Supported solvers are z3 (default), alt-ergo, cvc4, mathsat, vampire and yices."); + ( "-smt_linearize", + Arg.Set Type_check.opt_smt_linearize, + "(experimental) force linearization for constraints involving exponentials"); ( "-latex", - Arg.Tuple [Arg.Set opt_print_latex; Arg.Clear Type_check.opt_expand_valspec ], - " pretty print the input to latex"); + Arg.Tuple [Arg.Set opt_print_latex; Arg.Clear Type_check.opt_expand_valspec], + " pretty print the input to LaTeX"); ( "-latex_prefix", Arg.String (fun prefix -> Latex.opt_prefix := prefix), " set a custom prefix for generated LaTeX macro command (default sail)"); @@ -216,7 +222,7 @@ let options = Arg.align ([ " memoize calls to z3, improving performance when typechecking repeatedly (default)"); ( "-no_memo_z3", Arg.Clear opt_memo_z3, - " do not memoize calls to z3"); + " do not memoize calls to z3 (default)"); ( "-memo", Arg.Tuple [Arg.Set opt_memo_z3; Arg.Set C_backend.opt_memo_cache], " memoize calls to z3, and intermediate compilation results"); @@ -276,7 +282,7 @@ let options = Arg.align ([ "<verbosity> (debug) verbose typechecker output: 0 is silent"); ( "-dsmt_verbose", Arg.Set Constraint.opt_smt_verbose, - " (debug) print SMTLIB constraints sent to Z3"); + " (debug) print SMTLIB constraints sent to SMT solver"); ( "-dno_cast", Arg.Set opt_dno_cast, " (debug) typecheck without any implicit casting"); @@ -319,7 +325,7 @@ let _ = opt_file_arguments := (!opt_file_arguments) @ [s]) usage_msg -let load_files ?generate:(generate=true) type_envs files = +let load_files ?check:(check=false) type_envs files = if !opt_memo_z3 then Constraint.load_digests () else (); let t = Profile.start () in @@ -328,24 +334,27 @@ let load_files ?generate:(generate=true) type_envs files = List.fold_right (fun (_, Parse_ast.Defs ast_nodes) (Parse_ast.Defs later_nodes) -> Parse_ast.Defs (ast_nodes@later_nodes)) parsed (Parse_ast.Defs []) in let ast = Process_file.preprocess_ast options ast in - let ast = Initial_check.process_ast ~generate:generate ast in + let ast = Initial_check.process_ast ~generate:(not check) ast in Profile.finish "parsing" t; let t = Profile.start () in let (ast, type_envs) = check_ast type_envs ast in Profile.finish "type checking" t; - let ast = Scattered.descatter ast in - let ast = rewrite_ast type_envs ast in + if !opt_memo_z3 then Constraint.save_digests () else (); - let out_name = match !opt_file_out with - | None when parsed = [] -> "out.sail" - | None -> fst (List.hd parsed) - | Some f -> f ^ ".sail" in + if check then + ("out.sail", ast, type_envs) + else + let ast = Scattered.descatter ast in + let ast = rewrite_ast type_envs ast in - if !opt_memo_z3 then Constraint.save_digests () else (); + let out_name = match !opt_file_out with + | None when parsed = [] -> "out.sail" + | None -> fst (List.hd parsed) + | Some f -> f ^ ".sail" in - (out_name, ast, type_envs) + (out_name, ast, type_envs) let main() = if !opt_print_version then @@ -411,7 +420,8 @@ let main() = (if !(opt_print_c) then let ast_c = rewrite_ast_c type_envs ast in - let ast_c, type_envs = Specialize.specialize ast_c type_envs in + let ast_c, type_envs = Specialize.(specialize typ_ord_specialization ast_c type_envs) in + (* let ast_c, type_envs = Specialize.(specialize' 2 int_specialization ast_c type_envs) in *) (* let ast_c = Spec_analysis.top_sort_defs ast_c in *) Util.opt_warnings := true; C_backend.compile_ast (C_backend.initial_ctx type_envs) (!opt_includes_c) ast_c diff --git a/src/scattered.ml b/src/scattered.ml index de286e3f..92cb3561 100644 --- a/src/scattered.ml +++ b/src/scattered.ml @@ -66,8 +66,8 @@ let rec last_scattered_mapcl id = function | [] -> true (* Nothing cares about these and the AST should be changed *) -let fake_effect_opt l = Effect_opt_aux (Effect_opt_pure, gen_loc l) -let fake_rec_opt l = Rec_aux (Rec_rec, gen_loc l) +let no_effect_opt l = Effect_opt_aux (Effect_opt_none, gen_loc l) +let fake_rec_opt l = Rec_aux (Rec_nonrec, gen_loc l) let no_tannot_opt l = Typ_annot_opt_aux (Typ_annot_opt_none, gen_loc l) @@ -95,7 +95,7 @@ let rec descatter' funcls mapcls = function | Some clauses -> List.rev (funcl :: clauses) | None -> [funcl] in - DEF_fundef (FD_aux (FD_function (fake_rec_opt l, no_tannot_opt l, fake_effect_opt l, clauses), + DEF_fundef (FD_aux (FD_function (fake_rec_opt l, no_tannot_opt l, no_effect_opt l, clauses), (gen_loc l, Type_check.empty_tannot))) :: descatter' funcls mapcls defs diff --git a/src/specialize.ml b/src/specialize.ml index 00357557..591a415a 100644 --- a/src/specialize.ml +++ b/src/specialize.ml @@ -52,11 +52,26 @@ open Ast open Ast_util open Rewriter -let is_typ_ord_uvar = function +let is_typ_ord_arg = function | A_aux (A_typ _, _) -> true | A_aux (A_order _, _) -> true | _ -> false +type specialization = { + is_polymorphic : kinded_id -> bool; + instantiation_filter : kid -> typ_arg -> bool + } + +let typ_ord_specialization = { + is_polymorphic = (fun kopt -> is_typ_kopt kopt || is_order_kopt kopt); + instantiation_filter = (fun _ -> is_typ_ord_arg) + } + +let int_specialization = { + is_polymorphic = is_int_kopt; + instantiation_filter = (fun _ arg -> match arg with A_aux (A_nexp _, _) -> true | _ -> false) + } + let rec nexp_simp_typ (Typ_aux (typ_aux, l)) = let typ_aux = match typ_aux with | Typ_id v -> Typ_id v @@ -81,34 +96,43 @@ and nexp_simp_typ_arg (A_aux (typ_arg_aux, l)) = (* We have to be careful about whether the typechecker has renamed anything returned by instantiation_of. This part of the typechecker API is a bit ugly. *) -let fix_instantiation instantiation = - let instantiation = KBindings.bindings (KBindings.filter (fun _ arg -> is_typ_ord_uvar arg) instantiation) in +let fix_instantiation spec instantiation = + let instantiation = KBindings.bindings (KBindings.filter spec.instantiation_filter instantiation) in let instantiation = List.map (fun (kid, arg) -> Type_check.orig_kid kid, nexp_simp_typ_arg arg) instantiation in List.fold_left (fun m (k, v) -> KBindings.add k v m) KBindings.empty instantiation +(* polymorphic_functions returns all functions that are polymorphic + for some set of kinded-identifiers, specified by the is_kopt + predicate. For example, polymorphic_functions is_int_kopt will + return all Int-polymorphic functions. *) let rec polymorphic_functions is_kopt (Defs defs) = match defs with | DEF_spec (VS_aux (VS_val_spec (TypSchm_aux (TypSchm_ts (typq, typ) , _), id, _, externs), _)) :: defs -> - let is_type_polymorphic = List.exists is_kopt (quant_kopts typq) in - if is_type_polymorphic then + let is_polymorphic = List.exists is_kopt (quant_kopts typq) in + if is_polymorphic then IdSet.add id (polymorphic_functions is_kopt (Defs defs)) else polymorphic_functions is_kopt (Defs defs) | _ :: defs -> polymorphic_functions is_kopt (Defs defs) | [] -> IdSet.empty +(* When we specialize a function, we need to generate new name. To do + this we take the instantiation that the new function is specialized + for and turn that into a string in such a way that alpha-equivalent + instantiations always get the same name. We then zencode that + string so it is a valid identifier name, and prepend it to the + previous function name. *) let string_of_instantiation instantiation = let open Type_check in let kid_names = ref KOptMap.empty in let kid_counter = ref 0 in let kid_name kid = try KOptMap.find kid !kid_names with - | Not_found -> begin - let n = string_of_int !kid_counter in - kid_names := KOptMap.add kid n !kid_names; - incr kid_counter; - n - end + | Not_found -> + let n = string_of_int !kid_counter in + kid_names := KOptMap.add kid n !kid_names; + incr kid_counter; + n in (* We need custom string_of functions to ensure that alpha-equivalent definitions get the same name *) @@ -121,7 +145,7 @@ let string_of_instantiation instantiation = | Nexp_times (n1, n2) -> "(" ^ string_of_nexp n1 ^ " * " ^ string_of_nexp n2 ^ ")" | Nexp_sum (n1, n2) -> "(" ^ string_of_nexp n1 ^ " + " ^ string_of_nexp n2 ^ ")" | Nexp_minus (n1, n2) -> "(" ^ string_of_nexp n1 ^ " - " ^ string_of_nexp n2 ^ ")" - | Nexp_app (id, nexps) -> string_of_id id ^ "(" ^ Util.string_of_list ", " string_of_nexp nexps ^ ")" + | Nexp_app (id, nexps) -> string_of_id id ^ "(" ^ Util.string_of_list "," string_of_nexp nexps ^ ")" | Nexp_exp n -> "2 ^ " ^ string_of_nexp n | Nexp_neg n -> "- " ^ string_of_nexp n in @@ -132,7 +156,7 @@ let string_of_instantiation instantiation = | Typ_id id -> string_of_id id | Typ_var kid -> kid_name (mk_kopt K_type kid) | Typ_tup typs -> "(" ^ Util.string_of_list ", " string_of_typ typs ^ ")" - | Typ_app (id, args) -> string_of_id id ^ "(" ^ Util.string_of_list ", " string_of_typ_arg args ^ ")" + | Typ_app (id, args) -> string_of_id id ^ "(" ^ Util.string_of_list "," string_of_typ_arg args ^ ")" | Typ_fn (arg_typs, ret_typ, eff) -> "(" ^ Util.string_of_list ", " string_of_typ arg_typs ^ ") -> " ^ string_of_typ ret_typ ^ " effect " ^ string_of_effect eff | Typ_bidir (t1, t2) -> @@ -160,9 +184,11 @@ let string_of_instantiation instantiation = kid_name (mk_kopt K_int kid) ^ " in {" ^ Util.string_of_list ", " Big_int.to_string ns ^ "}" | NC_aux (NC_true, _) -> "true" | NC_aux (NC_false, _) -> "false" + | NC_aux (NC_var kid, _) -> kid_name (mk_kopt K_bool kid) + | NC_aux (NC_app (id, args), _) -> string_of_id id ^ "(" ^ Util.string_of_list "," string_of_typ_arg args ^ ")" in - let string_of_binding (kid, arg) = string_of_kid kid ^ " => " ^ string_of_typ_arg arg in + let string_of_binding (kid, arg) = string_of_kid kid ^ "=>" ^ string_of_typ_arg arg in Util.zencode_string (Util.string_of_list ", " string_of_binding (KBindings.bindings instantiation)) let id_of_instantiation id instantiation = @@ -179,12 +205,12 @@ let rec variant_generic_typ id (Defs defs) = (* Returns a list of all the instantiations of a function id in an ast. Also works with union constructors, and searches for them in patterns. *) -let rec instantiations_of id ast = +let rec instantiations_of spec id ast = let instantiations = ref [] in let inspect_exp = function | E_aux (E_app (id', _), _) as exp when Id.compare id id' = 0 -> - let instantiation = fix_instantiation (Type_check.instantiation_of exp) in + let instantiation = fix_instantiation spec (Type_check.instantiation_of exp) in instantiations := instantiation :: !instantiations; exp | exp -> exp @@ -202,7 +228,7 @@ let rec instantiations_of id ast = (variant_generic_typ variant_id ast) typ in - instantiations := fix_instantiation instantiation :: !instantiations; + instantiations := fix_instantiation spec instantiation :: !instantiations; pat | Typ_aux (Typ_id variant_id, _) -> pat | _ -> failwith ("Union constructor " ^ string_of_pat pat ^ " has non-union type") @@ -218,12 +244,12 @@ let rec instantiations_of id ast = !instantiations -let rec rewrite_polymorphic_calls id ast = +let rec rewrite_polymorphic_calls spec id ast = let vs_ids = Initial_check.val_spec_ids ast in let rewrite_e_aux = function | E_aux (E_app (id', args), annot) as exp when Id.compare id id' = 0 -> - let instantiation = fix_instantiation (Type_check.instantiation_of exp) in + let instantiation = fix_instantiation spec (Type_check.instantiation_of exp) in let spec_id = id_of_instantiation id instantiation in (* Make sure we only generate specialized calls when we've specialized the valspec. The valspec may not be generated if @@ -278,13 +304,61 @@ and typ_arg_int_frees ?exs:(exs=KidSet.empty) (A_aux (typ_arg_aux, l)) = | A_order ord -> KidSet.empty | A_bool _ -> KidSet.empty -let specialize_id_valspec instantiations id ast = +(* Implicit arguments have restrictions that won't hold + post-specialisation, but we can just remove them and turn them into + regular arguments. *) +let rec remove_implicit (Typ_aux (aux, l) as t) = + match aux with + | Typ_internal_unknown -> Typ_aux (Typ_internal_unknown, l) + | Typ_tup typs -> Typ_aux (Typ_tup (List.map remove_implicit typs), l) + | Typ_fn (arg_typs, ret_typ, effs) -> Typ_aux (Typ_fn (List.map remove_implicit arg_typs, remove_implicit ret_typ, effs), l) + | Typ_bidir (typ1, typ2) -> Typ_aux (Typ_bidir (remove_implicit typ1, remove_implicit typ2), l) + | Typ_app (Id_aux (Id "implicit", _), args) -> Typ_aux (Typ_app (mk_id "atom", List.map remove_implicit_arg args), l) + | Typ_app (id, args) -> Typ_aux (Typ_app (id, List.map remove_implicit_arg args), l) + | Typ_id id -> Typ_aux (Typ_id id, l) + | Typ_exist (kopts, nc, typ) -> Typ_aux (Typ_exist (kopts, nc, remove_implicit typ), l) + | Typ_var v -> Typ_aux (Typ_var v, l) +and remove_implicit_arg (A_aux (aux, l)) = + match aux with + | A_typ typ -> A_aux (A_typ (remove_implicit typ), l) + | arg -> A_aux (arg, l) + +let kopt_arg = function + | KOpt_aux (KOpt_kind (K_aux (K_int, _), kid), _) -> arg_nexp (nvar kid) + | KOpt_aux (KOpt_kind (K_aux (K_type,_), kid), _) -> arg_typ (mk_typ (Typ_var kid)) + | _ -> failwith "oh no" + +(* For numeric type arguments we have to be careful not to run into a + situation where we have an instantiation like + + 'n => 'm, 'm => 8 + + and end up re-writing 'n to 8. This function turns an instantition + like the above into two, + + 'n => 'i#m, 'm => 8 and 'i#m => 'm + + so we can do the substitution in two steps. *) +let safe_instantiation instantiation = + let args = + List.map (fun (_, arg) -> kopts_of_typ_arg arg) (KBindings.bindings instantiation) + |> List.fold_left KOptSet.union KOptSet.empty + |> KOptSet.elements + in + List.fold_left (fun (i, r) v -> KBindings.map (fun arg -> subst_kid typ_arg_subst (kopt_kid v) (prepend_kid "i#" (kopt_kid v)) arg) i, + KBindings.add (prepend_kid "i#" (kopt_kid v)) (kopt_arg v) r) + (instantiation, KBindings.empty) args + +let instantiate_constraints instantiation ncs = + List.map (fun c -> List.fold_left (fun c (v, a) -> constraint_subst v a c) c (KBindings.bindings instantiation)) ncs + +let specialize_id_valspec spec instantiations id ast = match split_defs (is_valspec id) ast with - | None -> failwith ("Valspec " ^ string_of_id id ^ " does not exist!") + | None -> Reporting.unreachable (id_loc id) __POS__ ("Valspec " ^ string_of_id id ^ " does not exist!") | Some (pre_ast, vs, post_ast) -> let typschm, externs, is_cast, annot = match vs with | DEF_spec (VS_aux (VS_val_spec (typschm, _, externs, is_cast), annot)) -> typschm, externs, is_cast, annot - | _ -> assert false (* unreachable *) + | _ -> Reporting.unreachable (id_loc id) __POS__ "val-spec is not actually a val-spec" in let TypSchm_aux (TypSchm_ts (typq, typ), _) = typschm in @@ -292,8 +366,9 @@ let specialize_id_valspec instantiations id ast = let spec_ids = ref IdSet.empty in let specialize_instance instantiation = + let safe_instantiation, reverse = safe_instantiation instantiation in (* Replace the polymorphic type variables in the type with their concrete instantiation. *) - let typ = Type_check.subst_unifiers instantiation typ in + let typ = remove_implicit (Type_check.subst_unifiers reverse (Type_check.subst_unifiers safe_instantiation typ)) in (* Collect any new type variables introduced by the instantiation *) let collect_kids kidsets = KidSet.elements (List.fold_left KidSet.union KidSet.empty kidsets) in @@ -302,11 +377,17 @@ let specialize_id_valspec instantiations id ast = (* Remove type variables from the type quantifier. *) let kopts, constraints = quant_split typq in - let kopts = List.filter (fun kopt -> not (is_typ_kopt kopt || is_order_kopt kopt)) kopts in - let typq = mk_typquant (List.map (mk_qi_id K_type) typ_frees - @ List.map (mk_qi_id K_int) int_frees - @ List.map mk_qi_kopt kopts - @ List.map mk_qi_nc constraints) in + let constraints = instantiate_constraints safe_instantiation constraints in + let constraints = instantiate_constraints reverse constraints in + let kopts = List.filter (fun kopt -> not (spec.is_polymorphic kopt)) kopts in + let typq = + if List.length (typ_frees @ int_frees) = 0 && List.length kopts = 0 then + mk_typquant [] + else + mk_typquant (List.map (mk_qi_id K_type) typ_frees + @ List.map (mk_qi_id K_int) int_frees + @ List.map mk_qi_kopt kopts + @ List.map mk_qi_nc constraints) in let typschm = mk_typschm typq typ in let spec_id = id_of_instantiation id instantiation in @@ -324,8 +405,9 @@ let specialize_id_valspec instantiations id ast = (* When we specialize a function definition we also need to specialize all the types that appear as annotations within the function - body. *) -let specialize_annotations instantiation = + body. Also remove any type-annotation from the fundef itself, + because at this point we have that as a separate valspec.*) +let specialize_annotations instantiation fdef = let open Type_check in let rw_pat = { id_pat_alg with @@ -337,12 +419,21 @@ let specialize_annotations instantiation = lEXP_cast = (fun (typ, lexp) -> LEXP_cast (subst_unifiers instantiation typ, lexp)); pat_alg = rw_pat } in - rewrite_fun { - rewriters_base with - rewrite_exp = (fun _ -> fold_exp rw_exp); - rewrite_pat = (fun _ -> fold_pat rw_pat) - } - + let fdef = + rewrite_fun { + rewriters_base with + rewrite_exp = (fun _ -> fold_exp rw_exp); + rewrite_pat = (fun _ -> fold_pat rw_pat) + } fdef + in + match fdef with + | FD_aux (FD_function (rec_opt, _, eff_opt, funcls), annot) -> + FD_aux (FD_function (rec_opt, + Typ_annot_opt_aux (Typ_annot_opt_none, Parse_ast.Unknown), + eff_opt, + funcls), + annot) + let specialize_id_fundef instantiations id ast = match split_defs (is_fundef id) ast with | None -> ast @@ -380,7 +471,15 @@ let specialize_id_overloads instantiations id (Defs defs) = valspecs are then re-specialized. This process is iterated until the whole spec is specialized. *) -let initial_calls = (IdSet.of_list [mk_id "main"; mk_id "__SetConfig"; mk_id "__ListConfig"; mk_id "execute"; mk_id "decode"; mk_id "initialize_registers"; mk_id "append_64"]) +let initial_calls = IdSet.of_list + [ mk_id "main"; + mk_id "__SetConfig"; + mk_id "__ListConfig"; + mk_id "execute"; + mk_id "decode"; + mk_id "initialize_registers"; + mk_id "append_64" (* used to construct bitvector literals in C backend *) + ] let remove_unused_valspecs ?(initial_calls=initial_calls) env ast = let calls = ref initial_calls in @@ -424,9 +523,9 @@ let slice_defs env (Defs defs) keep_ids = let defs = List.filter keep defs in remove_unused_valspecs env (Defs defs) ~initial_calls:keep_ids -let specialize_id id ast = - let instantiations = instantiations_of id ast in - let ast = specialize_id_valspec instantiations id ast in +let specialize_id spec id ast = + let instantiations = instantiations_of spec id ast in + let ast = specialize_id_valspec spec instantiations id ast in let ast = specialize_id_fundef instantiations id ast in specialize_id_overloads instantiations id ast @@ -448,21 +547,26 @@ let reorder_typedefs (Defs defs) = let others = filter_typedefs defs in Defs (List.rev !tdefs @ others) -let specialize_ids ids ast = - let ast = List.fold_left (fun ast id -> specialize_id id ast) ast (IdSet.elements ids) in +let specialize_ids spec ids ast = + let ast = List.fold_left (fun ast id -> specialize_id spec id ast) ast (IdSet.elements ids) in let ast = reorder_typedefs ast in let ast, _ = Type_error.check Type_check.initial_env ast in let ast = - List.fold_left (fun ast id -> rewrite_polymorphic_calls id ast) ast (IdSet.elements ids) + List.fold_left (fun ast id -> rewrite_polymorphic_calls spec id ast) ast (IdSet.elements ids) in let ast, env = Type_error.check Type_check.initial_env ast in let ast = remove_unused_valspecs env ast in ast, env -let rec specialize ast env = - let ids = polymorphic_functions (fun kopt -> is_typ_kopt kopt || is_order_kopt kopt) ast in - if IdSet.is_empty ids then +let rec specialize' n spec ast env = + if n = 0 then ast, env else - let ast, env = specialize_ids ids ast in - specialize ast env + let ids = polymorphic_functions spec.is_polymorphic ast in + if IdSet.is_empty ids then + ast, env + else + let ast, env = specialize_ids spec ids ast in + specialize' (n - 1) spec ast env + +let specialize = specialize' (-1) diff --git a/src/specialize.mli b/src/specialize.mli index 28029747..93dec239 100644 --- a/src/specialize.mli +++ b/src/specialize.mli @@ -54,10 +54,18 @@ open Ast open Ast_util open Type_check +type specialization + +(** Only specialize Type- and Ord- kinded polymorphism. *) +val typ_ord_specialization : specialization + +(** (experimental) specialise Int-kinded definitions *) +val int_specialization : specialization + (** Returns an IdSet with the function ids that have X-kinded parameters, e.g. val f : forall ('a : X). 'a -> 'a. The first argument specifies what X should be - it should be one of: - [is_nat_kopt], [is_order_kopt], or [is_typ_kopt] from [Ast_util], + [is_int_kopt], [is_order_kopt], or [is_typ_kopt] from [Ast_util], or some combination of those. *) val polymorphic_functions : (kinded_id -> bool) -> 'a defs -> IdSet.t @@ -66,11 +74,15 @@ val polymorphic_functions : (kinded_id -> bool) -> 'a defs -> IdSet.t AST with [Type_check.initial_env]. The env parameter is the environment to return if there is no polymorphism to remove, in which case specialize returns the AST unmodified. *) -val specialize : tannot defs -> Env.t -> tannot defs * Env.t +val specialize : specialization -> tannot defs -> Env.t -> tannot defs * Env.t + +val specialize' : int -> specialization -> tannot defs -> Env.t -> tannot defs * Env.t -val instantiations_of : id -> tannot defs -> typ_arg KBindings.t list +(** return all instantiations of a function id, with the + instantiations filtered according to the specialization. *) +val instantiations_of : specialization -> id -> tannot defs -> typ_arg KBindings.t list val string_of_instantiation : typ_arg KBindings.t -> string -(* Remove all function definitions except for the given set *) +(** Remove all function definitions except for the given set *) val slice_defs : Env.t -> tannot defs -> IdSet.t -> tannot defs diff --git a/src/type_check.ml b/src/type_check.ml index b9f8f323..b43e17ed 100644 --- a/src/type_check.ml +++ b/src/type_check.ml @@ -72,6 +72,10 @@ let opt_no_lexp_bounds_check = ref false We prefer not to do it for latex output but it is otherwise a good idea. *) let opt_expand_valspec = ref true +(* Linearize cases involving power where we would otherwise require + the SMT solver to use non-linear arithmetic. *) +let opt_smt_linearize = ref false + let depth = ref 0 let rec indent n = match n with @@ -246,6 +250,50 @@ and strip_kinded_id_aux = function and strip_kind = function | K_aux (k_aux, _) -> K_aux (k_aux, Parse_ast.Unknown) +let rec typ_nexps (Typ_aux (typ_aux, l)) = + match typ_aux with + | Typ_internal_unknown -> [] + | Typ_id v -> [] + | Typ_var kid -> [] + | Typ_tup typs -> List.concat (List.map typ_nexps typs) + | Typ_app (f, args) -> List.concat (List.map typ_arg_nexps args) + | Typ_exist (kids, nc, typ) -> typ_nexps typ + | Typ_fn (arg_typs, ret_typ, _) -> + List.concat (List.map typ_nexps arg_typs) @ typ_nexps ret_typ + | Typ_bidir (typ1, typ2) -> + typ_nexps typ1 @ typ_nexps typ2 +and typ_arg_nexps (A_aux (typ_arg_aux, l)) = + match typ_arg_aux with + | A_nexp n -> [n] + | A_typ typ -> typ_nexps typ + | A_bool nc -> constraint_nexps nc + | A_order ord -> [] +and constraint_nexps (NC_aux (nc_aux, l)) = + match nc_aux with + | NC_equal (n1, n2) | NC_bounded_ge (n1, n2) | NC_bounded_le (n1, n2) | NC_not_equal (n1, n2) -> + [n1; n2] + | NC_set _ | NC_true | NC_false | NC_var _ -> [] + | NC_or (nc1, nc2) | NC_and (nc1, nc2) -> constraint_nexps nc1 @ constraint_nexps nc2 + | NC_app (_, args) -> List.concat (List.map typ_arg_nexps args) + +(* Return a KidSet containing all the type variables appearing in + nexp, where nexp occurs underneath a Nexp_exp, i.e. 2^nexp *) +let rec nexp_power_variables (Nexp_aux (aux, _)) = + match aux with + | Nexp_times (n1, n2) | Nexp_sum (n1, n2) | Nexp_minus (n1, n2) -> + KidSet.union (nexp_power_variables n1) (nexp_power_variables n2) + | Nexp_neg n -> + nexp_power_variables n + | Nexp_id _ | Nexp_var _ | Nexp_constant _ -> + KidSet.empty + | Nexp_app (_, ns) -> + List.fold_left KidSet.union KidSet.empty (List.map nexp_power_variables ns) + | Nexp_exp n -> + tyvars_of_nexp n + +let constraint_power_variables nc = + List.fold_left KidSet.union KidSet.empty (List.map nexp_power_variables (constraint_nexps nc)) + let rec name_pat (P_aux (aux, _)) = match aux with | P_id id | P_as (_, id) -> Some ("_" ^ string_of_id id) @@ -261,7 +309,7 @@ let fresh_existential k = let named_existential k = function | Some n -> mk_kopt k (mk_kid n) | None -> fresh_existential k - + let destruct_exist_plain ?name:(name=None) typ = match typ with | Typ_aux (Typ_exist ([kopt], nc, typ), _) -> @@ -397,7 +445,7 @@ module Env : sig val wf_nexp : ?exs:KidSet.t -> t -> nexp -> unit val wf_constraint : ?exs:KidSet.t -> t -> n_constraint -> unit - (* Some of the code in the environment needs to use the Z3 prover, + (* Some of the code in the environment needs to use the smt solver, which is defined below. To break the circularity this would cause (as the prove code depends on the environment), we add a reference to the prover to the initial environment. *) @@ -521,7 +569,7 @@ end = struct let kopts, ncs = quant_split typq in let rec subst_args kopts args = match kopts, args with - | kopt :: kopts, (A_aux (A_nexp _, _) as arg) :: args when is_nat_kopt kopt -> + | kopt :: kopts, (A_aux (A_nexp _, _) as arg) :: args when is_int_kopt kopt -> List.map (constraint_subst (kopt_kid kopt) arg) (subst_args kopts args) | kopt :: kopts, A_aux (A_typ arg, _) :: args when is_typ_kopt kopt -> subst_args kopts args @@ -543,6 +591,10 @@ end = struct match aux with | NC_or (nc1, nc2) -> NC_aux (NC_or (expand_constraint_synonyms env nc1, expand_constraint_synonyms env nc2), l) | NC_and (nc1, nc2) -> NC_aux (NC_and (expand_constraint_synonyms env nc1, expand_constraint_synonyms env nc2), l) + | NC_equal (n1, n2) -> NC_aux (NC_equal (expand_nexp_synonyms env n1, expand_nexp_synonyms env n2), l) + | NC_not_equal (n1, n2) -> NC_aux (NC_not_equal (expand_nexp_synonyms env n1, expand_nexp_synonyms env n2), l) + | NC_bounded_le (n1, n2) -> NC_aux (NC_bounded_le (expand_nexp_synonyms env n1, expand_nexp_synonyms env n2), l) + | NC_bounded_ge (n1, n2) -> NC_aux (NC_bounded_ge (expand_nexp_synonyms env n1, expand_nexp_synonyms env n2), l) | NC_app (id, args) -> (try begin match Bindings.find id env.typ_synonyms l env args with @@ -550,7 +602,7 @@ end = struct | arg -> typ_error env l ("Expected Bool when expanding synonym " ^ string_of_id id ^ " got " ^ string_of_typ_arg arg) end with Not_found -> NC_aux (NC_app (id, List.map (expand_synonyms_arg env) args), l)) - | NC_true | NC_false | NC_equal _ | NC_not_equal _ | NC_bounded_le _ | NC_bounded_ge _ | NC_var _ | NC_set _ -> nc + | NC_true | NC_false | NC_var _ | NC_set _ -> nc and expand_nexp_synonyms env (Nexp_aux (aux, l) as nexp) = typ_debug ~level:2 (lazy ("Expanding " ^ string_of_nexp nexp)); @@ -935,7 +987,7 @@ end = struct typ_print (lazy (adding ^ "record " ^ string_of_id id)); let rec record_typ_args = function | [] -> [] - | ((QI_aux (QI_id kopt, _)) :: qis) when is_nat_kopt kopt -> + | ((QI_aux (QI_id kopt, _)) :: qis) when is_int_kopt kopt -> mk_typ_arg (A_nexp (nvar (kopt_kid kopt))) :: record_typ_args qis | ((QI_aux (QI_id kopt, _)) :: qis) when is_typ_kopt kopt -> mk_typ_arg (A_typ (mk_typ (Typ_var (kopt_kid kopt)))) :: record_typ_args qis @@ -1068,7 +1120,7 @@ end = struct let add_typ_var l (KOpt_aux (KOpt_kind (K_aux (k, _), v), _)) env = if KBindings.mem v env.typ_vars then begin let n = match KBindings.find_opt v env.shadow_vars with Some n -> n | None -> 0 in - let s_l, s_k = KBindings.find v env.typ_vars in + let s_l, s_k = KBindings.find v env.typ_vars in let s_v = Kid_aux (Var (string_of_kid v ^ "#" ^ string_of_int n), l) in typ_print (lazy (Printf.sprintf "%stype variable (shadowing %s) %s : %s" adding (string_of_kid s_v) (string_of_kid v) (string_of_kind_aux k))); { env with @@ -1087,11 +1139,34 @@ end = struct let add_constraint constr env = wf_constraint env constr; let (NC_aux (nc_aux, l) as constr) = constraint_simp (expand_constraint_synonyms env constr) in - match nc_aux with - | NC_true -> env - | _ -> - typ_print (lazy (adding ^ "constraint " ^ string_of_n_constraint constr)); - { env with constraints = constr :: env.constraints } + let power_vars = constraint_power_variables constr in + if KidSet.cardinal power_vars > 1 && !opt_smt_linearize then + typ_error env l ("Cannot add constraint " ^ string_of_n_constraint constr + ^ " where more than two variables appear within an exponential") + else if KidSet.cardinal power_vars = 1 && !opt_smt_linearize then + let v = KidSet.choose power_vars in + let constrs = List.fold_left nc_and nc_true (get_constraints env) in + begin match Constraint.solve_all_smt l (get_typ_vars env) constrs v with + | Some solutions -> + typ_print (lazy (Util.("Linearizing " |> red |> clear) ^ string_of_n_constraint constr + ^ " for " ^ string_of_kid v ^ " in " ^ Util.string_of_list ", " Big_int.to_string solutions)); + let linearized = + List.fold_left + (fun c s -> nc_or c (nc_and (nc_eq (nvar v) (nconstant s)) (constraint_subst v (arg_nexp (nconstant s)) constr))) + nc_false solutions + in + typ_print (lazy (adding ^ "constraint " ^ string_of_n_constraint linearized)); + { env with constraints = linearized :: env.constraints } + | None -> + typ_error env l ("Type variable " ^ string_of_kid v + ^ " must have a finite number of solutions to add " ^ string_of_n_constraint constr) + end + else + match nc_aux with + | NC_true -> env + | _ -> + typ_print (lazy (adding ^ "constraint " ^ string_of_n_constraint constr)); + { env with constraints = constr :: env.constraints } let get_ret_typ env = env.ret_typ @@ -1282,7 +1357,7 @@ and simp_typ_aux = function would become {('s:Bool) ('r: Bool), nc('r). bool('s & 'r)}, wherein all the redundant boolean variables have been combined into a single one. Making this simplification allows us to avoid - having to pass large numbers of pointless variables to Z3 if we + having to pass large numbers of pointless variables to SMT if we ever bind this existential. *) | Typ_exist (vars, nc, Typ_aux (Typ_app (Id_aux (Id "atom_bool", _), [A_aux (A_bool b, _)]), _)) -> let kids = KidSet.of_list (List.map kopt_kid vars) in @@ -1326,16 +1401,23 @@ this is equivalent to which is then a problem we can feed to the constraint solver expecting unsat. *) -let prove_z3 env (NC_aux (_, l) as nc) = +let prove_smt env (NC_aux (_, l) as nc) = let vars = Env.get_typ_vars env in let vars = KBindings.filter (fun _ k -> match k with K_int | K_bool -> true | _ -> false) vars in let ncs = Env.get_constraints env in - match Constraint.call_z3 l vars (List.fold_left nc_and (nc_not nc) ncs) with + match Constraint.call_smt l vars (List.fold_left nc_and (nc_not nc) ncs) with | Constraint.Unsat -> typ_debug (lazy "unsat"); true | Constraint.Sat -> typ_debug (lazy "sat"); false - | Constraint.Unknown -> typ_debug (lazy "unknown"); false - -let solve env (Nexp_aux (_, l) as nexp) = + | Constraint.Unknown -> + (* Work around versions of z3 that are confused by 2^n in + constraints, even when such constraints are irrelevant *) + let ncs' = List.concat (List.map constraint_conj ncs) in + let ncs' = List.filter (fun nc -> KidSet.is_empty (constraint_power_variables nc)) ncs' in + match Constraint.call_smt l vars (List.fold_left nc_and (nc_not nc) ncs') with + | Constraint.Unsat -> typ_debug (lazy "unsat"); true + | Constraint.Sat | Constraint.Unknown -> typ_debug (lazy "sat/unknown"); false + +let solve_unique env (Nexp_aux (_, l) as nexp) = typ_print (lazy (Util.("Solve " |> red |> clear) ^ string_of_list ", " string_of_n_constraint (Env.get_constraints env) ^ " |- " ^ string_of_nexp nexp ^ " = ?")); match nexp with @@ -1345,7 +1427,7 @@ let solve env (Nexp_aux (_, l) as nexp) = let vars = Env.get_typ_vars env in let vars = KBindings.filter (fun _ k -> match k with K_int | K_bool -> true | _ -> false) vars in let constr = List.fold_left nc_and (nc_eq (nvar (mk_kid "solve#")) nexp) (Env.get_constraints env) in - Constraint.solve_z3 l vars constr (mk_kid "solve#") + Constraint.solve_unique_smt l vars constr (mk_kid "solve#") let debug_pos (file, line, _, _) = "(" ^ file ^ "/" ^ string_of_int line ^ ") " @@ -1358,7 +1440,7 @@ let prove pos env nc = else (); match nc_aux with | NC_true -> true - | _ -> prove_z3 env nc + | _ -> prove_smt env nc (**************************************************************************) (* 3. Unification *) @@ -1376,32 +1458,6 @@ let rec nexp_frees ?exs:(exs=KidSet.empty) (Nexp_aux (nexp, l)) = | Nexp_exp n -> nexp_frees ~exs:exs n | Nexp_neg n -> nexp_frees ~exs:exs n -let rec typ_nexps (Typ_aux (typ_aux, l)) = - match typ_aux with - | Typ_internal_unknown -> [] - | Typ_id v -> [] - | Typ_var kid -> [] - | Typ_tup typs -> List.concat (List.map typ_nexps typs) - | Typ_app (f, args) -> List.concat (List.map typ_arg_nexps args) - | Typ_exist (kids, nc, typ) -> typ_nexps typ - | Typ_fn (arg_typs, ret_typ, _) -> - List.concat (List.map typ_nexps arg_typs) @ typ_nexps ret_typ - | Typ_bidir (typ1, typ2) -> - typ_nexps typ1 @ typ_nexps typ2 -and typ_arg_nexps (A_aux (typ_arg_aux, l)) = - match typ_arg_aux with - | A_nexp n -> [n] - | A_typ typ -> typ_nexps typ - | A_bool nc -> constraint_nexps nc - | A_order ord -> [] -and constraint_nexps (NC_aux (nc_aux, l)) = - match nc_aux with - | NC_equal (n1, n2) | NC_bounded_ge (n1, n2) | NC_bounded_le (n1, n2) | NC_not_equal (n1, n2) -> - [n1; n2] - | NC_set _ | NC_true | NC_false | NC_var _ -> [] - | NC_or (nc1, nc2) | NC_and (nc1, nc2) -> constraint_nexps nc1 @ constraint_nexps nc2 - | NC_app (_, args) -> List.concat (List.map typ_arg_nexps args) - let rec nexp_identical (Nexp_aux (nexp1, _)) (Nexp_aux (nexp2, _)) = match nexp1, nexp2 with | Nexp_id v1, Nexp_id v2 -> Id.compare v1 v2 = 0 @@ -1692,13 +1748,13 @@ let rec ambiguous_vars (Typ_aux (aux, _)) = match aux with | Typ_app (_, args) -> List.fold_left KidSet.union KidSet.empty (List.map ambiguous_arg_vars args) | _ -> KidSet.empty - + and ambiguous_arg_vars (A_aux (aux, _)) = match aux with | A_bool nc -> ambiguous_nc_vars nc | A_nexp nexp -> ambiguous_nexp_vars nexp | _ -> KidSet.empty - + and ambiguous_nc_vars (NC_aux (aux, _)) = match aux with | NC_and (nc1, nc2) -> KidSet.union (tyvars_of_constraint nc1) (tyvars_of_constraint nc2) @@ -1903,9 +1959,17 @@ let rec subtyp l env typ1 typ2 = let env = add_typ_vars l (List.map (mk_kopt K_int) (KidSet.elements (KidSet.inter (nexp_frees nexp2) (KidSet.of_list kids2)))) env in let kids2 = KidSet.elements (KidSet.diff (KidSet.of_list kids2) (nexp_frees nexp2)) in if not (kids2 = []) then typ_error env l ("Universally quantified constraint generated: " ^ Util.string_of_list ", " string_of_kid kids2) else (); - let env = Env.add_constraint (nc_eq nexp1 nexp2) env in - if prove __POS__ env nc2 then () - else typ_raise env l (Err_subtype (typ1, typ2, Env.get_constraints env, Env.get_typ_var_locs env)) + let vars = KBindings.filter (fun _ k -> match k with K_int | K_bool -> true | _ -> false) (Env.get_typ_vars env) in + begin match Constraint.call_smt l vars (nc_eq nexp1 nexp2) with + | Constraint.Sat -> + let env = Env.add_constraint (nc_eq nexp1 nexp2) env in + if prove __POS__ env nc2 then + () + else + typ_raise env l (Err_subtype (typ1, typ2, Env.get_constraints env, Env.get_typ_var_locs env)) + | _ -> + typ_error env l ("Constraint " ^ string_of_n_constraint (nc_eq nexp1 nexp2) ^ " is not satisfiable") + end | _, _ -> match typ_aux1, typ_aux2 with | _, Typ_internal_unknown when Env.allow_unknowns env -> () @@ -1966,11 +2030,24 @@ let subtype_check env typ1 typ2 = exception No_simple_rewrite;; +let rec move_to_front p ys = function + | x :: xs when p x -> x :: (ys @ xs) + | x :: xs -> move_to_front p (x :: ys) xs + | [] -> ys + let rec rewrite_sizeof' env (Nexp_aux (aux, l) as nexp) = let mk_exp exp = mk_exp ~loc:l exp in match aux with | Nexp_var v -> + (* Use a simple heuristic to find the most likely local we can + use, and move it to the front of the list. *) + let str = string_of_kid v in + let likely = + try let n = if str.[1] = '_' then 2 else 1 in String.sub str n (String.length str - n) with + | Invalid_argument _ -> str + in let locals = Env.get_locals env |> Bindings.bindings in + let locals = move_to_front (fun local -> likely = string_of_id (fst local)) [] locals in let same_size (local, (_, Typ_aux (aux, _))) = match aux with | Typ_app (id, [A_aux (A_nexp (Nexp_aux (Nexp_var v', _)), _)]) @@ -4731,7 +4808,7 @@ let mk_synonym typq typ_arg = let kopts = List.map snd kopts in let rec subst_args env l kopts args = match kopts, args with - | kopt :: kopts, A_aux (A_nexp arg, _) :: args when is_nat_kopt kopt -> + | kopt :: kopts, A_aux (A_nexp arg, _) :: args when is_int_kopt kopt -> let typ_arg, ncs = subst_args env l kopts args in typ_arg_subst (kopt_kid kopt) (arg_nexp arg) typ_arg, List.map (constraint_subst (kopt_kid kopt) (arg_nexp arg)) ncs diff --git a/src/type_check.mli b/src/type_check.mli index 4ff52cd9..5ab986d0 100644 --- a/src/type_check.mli +++ b/src/type_check.mli @@ -71,6 +71,10 @@ val opt_no_lexp_bounds_check : bool ref We prefer not to do it for latex output but it is otherwise a good idea. *) val opt_expand_valspec : bool ref +(** Linearize cases involving power where we would otherwise require + the SMT solver to use non-linear arithmetic. *) +val opt_smt_linearize : bool ref + (** {2 Type errors} *) type type_error = @@ -309,9 +313,15 @@ val check_fundef : Env.t -> 'a fundef -> tannot def list * Env.t val check_val_spec : Env.t -> 'a val_spec -> tannot def list * Env.t +(** Attempt to prove a constraint using z3. Returns true if z3 can + prove that the constraint is true, returns false if z3 cannot prove + the constraint true. Note that this does not guarantee that the + constraint is actually false, as the constraint solver is somewhat + untrustworthy. *) val prove : (string * int * int * int) -> Env.t -> n_constraint -> bool -val solve : Env.t -> nexp -> Big_int.num option +(** Returns Some c if there is a unique c such that nexp = c *) +val solve_unique : Env.t -> nexp -> Big_int.num option val canonicalize : Env.t -> typ -> typ diff --git a/test/c/dead_branch.expect b/test/c/dead_branch.expect new file mode 100644 index 00000000..ca6ef09a --- /dev/null +++ b/test/c/dead_branch.expect @@ -0,0 +1,2 @@ +v = 0x5678EF91 +v = 0xABCD12345678EF91 diff --git a/test/c/dead_branch.sail b/test/c/dead_branch.sail new file mode 100644 index 00000000..4d7900eb --- /dev/null +++ b/test/c/dead_branch.sail @@ -0,0 +1,42 @@ +default Order dec + +$include <arith.sail> +$include <vector_dec.sail> + +type xlen : Int = 32 +type xlenbits = bits(xlen) + +register reg : bits(64) + +function read_xlen (arg : bool) -> xlenbits = { + match (arg, sizeof(xlen)) { + (_, 32) => reg[31 .. 0], + (_, 64) => reg, + (_, _) => if sizeof(xlen) == 32 + then reg[31 .. 0] + else reg[63 .. 32] + } +} + +type ylen : Int = 64 +type ylenbits = bits(ylen) + +function read_ylen (arg : bool) -> ylenbits = { + match (arg, sizeof(ylen)) { + (_, 32) => reg[31 .. 0], + (_, 64) => reg, + (_, _) => if sizeof(ylen) == 32 + then reg[31 .. 0] + else reg + } +} + +val main : unit -> unit effect {rreg, wreg} +function main() = { + reg = 0xABCD_1234_5678_EF91; + let v = read_xlen(true); + print_bits("v = ", v); + let v = read_ylen(true); + print_bits("v = ", v); + () +} diff --git a/test/c/encdec.expect b/test/c/encdec.expect new file mode 100644 index 00000000..18fab89a --- /dev/null +++ b/test/c/encdec.expect @@ -0,0 +1,2 @@ +bin = 0x9FFF +bin' = 0x9FFF diff --git a/test/c/encdec.sail b/test/c/encdec.sail new file mode 100644 index 00000000..bac55c8d --- /dev/null +++ b/test/c/encdec.sail @@ -0,0 +1,38 @@ +default Order dec + +$include <prelude.sail> +$include <exception_basic.sail> + +enum pred = { + P_false, + P_true +} + +mapping decenc_p : bits(2) <-> pred = { + 0b00 <-> P_true, + 0b01 <-> P_false +} + +scattered union ast + +val encdec : ast <-> bits(16) + +union clause ast = ABS : (pred, bits(10)) + +mapping clause encdec = + ABS(decenc_p(0b0 @ p), rd @ rs) + <-> 0b10011 @ p : bits(1) @ rd : bits(5) @ rs : bits(5) + +function fetch(_: unit) -> bits(16) = { + 0b10011 @ 0xFF @ 0b111 +} + +val main : unit -> unit effect {barr, eamem, escape, exmem, rmem, rreg, wmv, wreg} +function main () = { + let bin = fetch(); + let ast = encdec(bin); + let bin' = encdec(ast); + assert(bin == bin'); + print_bits("bin = ", bin); + print_bits("bin' = ", bin') +} diff --git a/test/c/run_tests.py b/test/c/run_tests.py index 4927e281..4a02dd78 100755 --- a/test/c/run_tests.py +++ b/test/c/run_tests.py @@ -93,6 +93,7 @@ xml = '<testsuites>\n' xml += test_c('unoptimized C', '', '', True) xml += test_c('optimized C', '-O2', '-O', True) xml += test_c('constant folding', '', '-Oconstant_fold', True) +xml += test_c('monomorphised C', '-O2', '-O -Oconstant_fold -auto_mono', True) xml += test_c('full optimizations', '-O2 -mbmi2 -DINTRINSICS', '-O -Oconstant_fold', True) xml += test_c('address sanitised', '-O2 -fsanitize=undefined', '-O', False) diff --git a/test/typecheck/pass/complex_exist_sat.sail b/test/typecheck/pass/complex_exist_sat.sail new file mode 100644 index 00000000..65c3b6a9 --- /dev/null +++ b/test/typecheck/pass/complex_exist_sat.sail @@ -0,0 +1,7 @@ +val foo : int -> {'q, 'q in {0, 1}. atom(2 * 'q)} + +function foo(x) = 2 + +val bar : int -> {'q, 'q in {0, 1}. atom(2 * 'q)} + +function bar(x) = 0 diff --git a/test/typecheck/pass/complex_exist_sat/v1.expect b/test/typecheck/pass/complex_exist_sat/v1.expect new file mode 100644 index 00000000..b1937f47 --- /dev/null +++ b/test/typecheck/pass/complex_exist_sat/v1.expect @@ -0,0 +1,8 @@ +Type error: +[[96mcomplex_exist_sat/v1.sail[0m]:3:18-19 +3[96m |[0mfunction foo(x) = 3 + [91m |[0m [91m^[0m + [91m |[0m Tried performing type coercion from int(3) to {('q : Int), 'q in {0, 1}. int((2 * 'q))} on 3 + [91m |[0m Coercion failed because: + [91m |[0m Constraint 3 == (2 * 'ex1#) is not satisfiable + [91m |[0m diff --git a/test/typecheck/pass/complex_exist_sat/v1.sail b/test/typecheck/pass/complex_exist_sat/v1.sail new file mode 100644 index 00000000..f36f2dda --- /dev/null +++ b/test/typecheck/pass/complex_exist_sat/v1.sail @@ -0,0 +1,3 @@ +val foo : int -> {'q, 'q in {0, 1}. atom(2 * 'q)} + +function foo(x) = 3
\ No newline at end of file diff --git a/test/typecheck/pass/complex_exist_sat/v2.expect b/test/typecheck/pass/complex_exist_sat/v2.expect new file mode 100644 index 00000000..27af46d2 --- /dev/null +++ b/test/typecheck/pass/complex_exist_sat/v2.expect @@ -0,0 +1,8 @@ +Type error: +[[96mcomplex_exist_sat/v2.sail[0m]:3:18-19 +3[96m |[0mfunction foo(x) = 4 + [91m |[0m [91m^[0m + [91m |[0m Tried performing type coercion from int(4) to {('q : Int), 'q in {0, 1}. int((2 * 'q))} on 4 + [91m |[0m Coercion failed because: + [91m |[0m int(4) is not a subtype of {('q : Int), 'q in {0, 1}. int((2 * 'q))} + [91m |[0m diff --git a/test/typecheck/pass/complex_exist_sat/v2.sail b/test/typecheck/pass/complex_exist_sat/v2.sail new file mode 100644 index 00000000..e3e18e8c --- /dev/null +++ b/test/typecheck/pass/complex_exist_sat/v2.sail @@ -0,0 +1,3 @@ +val foo : int -> {'q, 'q in {0, 1}. atom(2 * 'q)} + +function foo(x) = 4
\ No newline at end of file diff --git a/test/typecheck/pass/int_synonym.sail b/test/typecheck/pass/int_synonym.sail new file mode 100644 index 00000000..33bdaf0c --- /dev/null +++ b/test/typecheck/pass/int_synonym.sail @@ -0,0 +1,18 @@ +/* from prelude */ +default Order dec + +$include <flow.sail> + +type bits ('n : Int) = vector('n, dec, bit) + +type xlen : Int = 64 +type xlen_bytes : Int = 8 +type xlenbits = bits(xlen) + +val "sign_extend" : forall 'n 'm, 'm >= 'n. (bits('n), atom('m)) -> bits('m) +val EXTS : forall 'n 'm, 'm >= 'n. (implicit('m), bits('n)) -> bits('m) +function EXTS(m, v) = sign_extend(v, m) + +val extend : forall 'n, 'n <= xlen_bytes. (bool, bits(8 * 'n)) -> xlenbits + +function extend(flag, value) = EXTS(value) diff --git a/test/typecheck/pass/pow_32_64.sail b/test/typecheck/pass/pow_32_64.sail new file mode 100644 index 00000000..bb4f207a --- /dev/null +++ b/test/typecheck/pass/pow_32_64.sail @@ -0,0 +1,14 @@ +default Order dec + +$include <prelude.sail> + +$option -smt_linearize + +val bar : forall 'n, 'n <= 2 ^ 64 - 1. int('n) -> unit + +val foo : forall 'n, 'n in {32, 64}. bits('n) -> unit + +function foo(xs) = { + let x = unsigned(xs); + bar(x) +} diff --git a/test/typecheck/run_tests.sh b/test/typecheck/run_tests.sh index ad2592df..e5650646 100755 --- a/test/typecheck/run_tests.sh +++ b/test/typecheck/run_tests.sh @@ -50,9 +50,9 @@ printf "<testsuites>\n" >> $DIR/tests.xml for i in `ls $DIR/pass/ | grep sail`; do - if $SAILDIR/sail -just_check -ddump_tc_ast -dsanity $DIR/pass/$i 2> /dev/null 1> $DIR/rtpass/$i; + if $SAILDIR/sail -no_memo_z3 -just_check -ddump_tc_ast -dsanity $DIR/pass/$i 2> /dev/null 1> $DIR/rtpass/$i; then - if $SAILDIR/sail -just_check -ddump_tc_ast -dmagic_hash -dno_cast -dsanity $DIR/rtpass/$i 2> /dev/null 1> $DIR/rtpass2/$i; + if $SAILDIR/sail -no_memo_z3 -just_check -ddump_tc_ast -dmagic_hash -dno_cast -dsanity $DIR/rtpass/$i 2> /dev/null 1> $DIR/rtpass2/$i; then if diff $DIR/rtpass/$i $DIR/rtpass2/$i; then @@ -71,7 +71,7 @@ do for file in $DIR/pass/${i%.sail}/*.sail; do pushd $DIR/pass > /dev/null; - if $SAILDIR/sail ${i%.sail}/$(basename $file) 2> result; + if $SAILDIR/sail -no_memo_z3 ${i%.sail}/$(basename $file) 2> result; then red "failing variant of $i $(basename $file) passed" "fail" else |
