diff options
| author | Alasdair Armstrong | 2017-07-28 15:39:52 +0100 |
|---|---|---|
| committer | Alasdair Armstrong | 2017-07-28 15:39:52 +0100 |
| commit | 3c18efc6153c340517d7b229fe64b38e4d3e5f33 (patch) | |
| tree | 34f7ed3cce7bf6a3b35b94e117e0c6690ae73399 | |
| parent | 34c27ada18e9e36a0224e2ff9999559ed2899157 (diff) | |
| parent | f951a1712fe88eadc812643175ea8f3d31a558cf (diff) | |
Merge remote-tracking branch 'origin/sail_new_tc' into experiments
| -rw-r--r-- | mips/mips_extras.lem | 12 | ||||
| -rw-r--r-- | risc-v/risc-v.sail | 2173 | ||||
| -rw-r--r-- | src/ast.ml | 2 | ||||
| -rw-r--r-- | src/ast_util.ml | 2 | ||||
| -rw-r--r-- | src/initial_check.ml | 18 | ||||
| -rw-r--r-- | src/initial_check.mli | 4 | ||||
| -rw-r--r-- | src/monomorphise.ml | 55 | ||||
| -rw-r--r-- | src/pretty_print_lem.ml | 34 | ||||
| -rw-r--r-- | src/pretty_print_lem_ast.ml | 18 | ||||
| -rw-r--r-- | src/pretty_print_ocaml.ml | 21 | ||||
| -rw-r--r-- | src/process_file.ml | 34 | ||||
| -rw-r--r-- | src/process_file.mli | 5 | ||||
| -rw-r--r-- | src/rewriter.ml | 578 | ||||
| -rw-r--r-- | src/rewriter.mli | 4 | ||||
| -rw-r--r-- | src/sail.ml | 2 | ||||
| -rw-r--r-- | src/spec_analysis.ml | 5 | ||||
| -rw-r--r-- | src/type_check.ml | 94 | ||||
| -rw-r--r-- | src/type_check.mli | 4 | ||||
| -rw-r--r-- | test/typecheck/fail/overlap_field_wreg.sail | 13 | ||||
| -rw-r--r-- | test/typecheck/pass/add_real.sail | 5 | ||||
| -rw-r--r-- | test/typecheck/pass/add_vec_exts_no_annot.sail | 19 | ||||
| -rw-r--r-- | test/typecheck/pass/add_vec_exts_no_annot_overload.sail | 19 | ||||
| -rw-r--r-- | test/typecheck/pass/overlap_field.sail | 13 |
23 files changed, 852 insertions, 2282 deletions
diff --git a/mips/mips_extras.lem b/mips/mips_extras.lem index 99487f49..1fbba038 100644 --- a/mips/mips_extras.lem +++ b/mips/mips_extras.lem @@ -7,13 +7,13 @@ import Set_extra let memory_parameter_transformer mode v = match v with - | Interp.V_tuple [location;length] -> + | Interp_ast.V_tuple [location;length] -> let (v,loc_regs) = extern_with_track mode extern_vector_value location in match length with - | Interp.V_lit (L_aux (L_num len) _) -> + | Interp_ast.V_lit (L_aux (L_num len) _) -> (v,(natFromInteger len),loc_regs) - | Interp.V_track (Interp.V_lit (L_aux (L_num len) _)) size_regs -> + | Interp_ast.V_track (Interp_ast.V_lit (L_aux (L_num len) _)) size_regs -> match loc_regs with | Nothing -> (v,(natFromInteger len),Just (List.map (fun r -> extern_reg r Nothing) (Set_extra.toList size_regs))) | Just loc_regs -> (v,(natFromInteger len),Just (loc_regs++(List.map (fun r -> extern_reg r Nothing) (Set_extra.toList size_regs)))) @@ -25,7 +25,7 @@ let memory_parameter_transformer mode v = let memory_parameter_transformer_option_address _mode v = match v with - | Interp.V_tuple [location;_] -> + | Interp_ast.V_tuple [location;_] -> Just (extern_vector_value location) | _ -> Assert_extra.failwith ("memory_parameter_transformer_option_address: expected 'V_tuple [_;_]' given " ^ (Interp.string_of_value v)) end @@ -54,7 +54,7 @@ let memory_vals : memory_write_vals = ("MEMval_conditional", (MV memory_parameter_transformer_option_address (Just (fun (IState interp context) b -> - let bit = Interp.V_lit (L_aux (if b then L_one else L_zero) Interp_ast.Unknown) in + let bit = Interp_ast.V_lit (L_aux (if b then L_one else L_zero) Interp_ast.Unknown) in (IState (Interp.add_answer_to_stack interp bit) context))))); ] @@ -64,7 +64,7 @@ let memory_vals_tagged : memory_write_vals_tagged = ("MEMval_tag_conditional", (MVT memory_parameter_transformer_option_address (Just (fun (IState interp context) b -> - let bit = Interp.V_lit (L_aux (if b then L_one else L_zero) Interp_ast.Unknown) in + let bit = Interp_ast.V_lit (L_aux (if b then L_one else L_zero) Interp_ast.Unknown) in (IState (Interp.add_answer_to_stack interp bit) context))))); ] diff --git a/risc-v/risc-v.sail b/risc-v/risc-v.sail index 92e8da51..edf95c62 100644 --- a/risc-v/risc-v.sail +++ b/risc-v/risc-v.sail @@ -1,1966 +1,263 @@ default Order dec -register (bit[64]) x0 -register (bit[64]) x1 -register (bit[64]) x2 -register (bit[64]) x3 -register (bit[64]) x4 -register (bit[64]) x5 -register (bit[64]) x6 -register (bit[64]) x7 -register (bit[64]) x8 -register (bit[64]) x9 -register (bit[64]) x10 -register (bit[64]) x11 -register (bit[64]) x12 -register (bit[64]) x13 -register (bit[64]) x14 -register (bit[64]) x15 -register (bit[64]) x16 -register (bit[64]) x17 -register (bit[64]) x18 -register (bit[64]) x19 -register (bit[64]) x20 -register (bit[64]) x21 -register (bit[64]) x22 -register (bit[64]) x23 -register (bit[64]) x24 -register (bit[64]) x25 -register (bit[64]) x26 -register (bit[64]) x27 -register (bit[64]) x28 -register (bit[64]) x29 -register (bit[64]) x30 -register (bit[64]) x31 - -let (vector <0, 32, inc, (register<(bit[64])>) >) x = +typedef regval = bit[64] +typedef regno = bit[5] + +register (regval) x0 +register (regval) x1 +register (regval) x2 +register (regval) x3 +register (regval) x4 +register (regval) x5 +register (regval) x6 +register (regval) x7 +register (regval) x8 +register (regval) x9 +register (regval) x10 +register (regval) x11 +register (regval) x12 +register (regval) x13 +register (regval) x14 +register (regval) x15 +register (regval) x16 +register (regval) x17 +register (regval) x18 +register (regval) x19 +register (regval) x20 +register (regval) x21 +register (regval) x22 +register (regval) x23 +register (regval) x24 +register (regval) x25 +register (regval) x26 +register (regval) x27 +register (regval) x28 +register (regval) x29 +register (regval) x30 +register (regval) x31 + +register (bit[64]) PC +register (bit[64]) nextPC + +let (vector <0, 32, inc, (register<(regval)>)>) GPRs = [x0, x1,x2,x3,x4,x5,x6,x7,x8,x9,x10,x11,x12,x13,x14,x15,x16,x17,x18,x19,x20,x21,x22,x23,x24,x25,x26,x27,x28,x29,x30,x31] -typedef arithr = enumerate { ADD; SLL; SLT; SLTU; XOR; SRL; OR; AND } - - -scattered function unit execute -scattered typedef ast = const union - -val bit[32] -> ast effect pure decode - -scattered function ast decode - -union ast member (bit[5], bit[5], arithr, bit[5]) Arithr1 - -function clause decode (0b0000000 : -(bit[5]) src2 : -(bit[5]) src1 : -(bit[3]) arithr_op : -(bit[5]) dest : -0b0110011 ) = -Arithr1(src2,src1, (([|7|]) arithr_op), dest) - - -function clause execute (Arithr1(src2,src1, arithr_op, dest)) = - { - switch arithr_op { - case ADD -> x[dest] := x[src1] + x[src2] -(* case SLL -> - case SLT -> - case SLTU -> - case XOR -> - case SRL -> - case OR -> - case AND -> *) - } -} - -end ast -end decode -end execute - -(*val extern forall Nat 'm, Nat 'n. (implicit<'m>,bit['n]) -> bit['m] effect pure EXTS - -(* XXX binary coded decimal *) -function forall Type 'a . 'a DEC_TO_BCD ( x ) = x -function forall Type 'a . 'a BCD_TO_DEC ( x ) = x -(* XXX carry out *) -function forall Nat 'a . bit carry_out ( (bit['a]) _,carry ) = carry -(* XXX Storage control *) -function forall Type 'a . 'a real_addr ( x ) = x -(* XXX For stvxl and lvxl - what does that do? *) -function forall Type 'a . unit mark_as_not_likely_to_be_needed_again_anytime_soon ( x ) = () - -(* XXX *) -val extern forall Nat 'k, Nat 'r, - 0 <= 'k, 'k <= 64, 'r + 'k = 64. - (bit[64], [|'k|]) -> [|0:'r|] effect pure countLeadingZeroes - -function forall Nat 'n, Nat 'm . - bit['m] EXTS_EXPLICIT((bit['n]) v, ([|'m|]) m) = - (v[0] ^^ (m - length(v))) : v - -val forall Nat 'n, Nat 'm, 0 <= 'n, 'n <= 'm, 'm <= 63 . - ([|'n|],[|'m|]) -> bit[64] - effect pure - MASK - -function (bit[64]) MASK(start, stop) = { - (bit[64]) mask_temp := 0; - if(start > stop) then { - mask_temp[start .. 63] := bitone ^^ (64 - start); - mask_temp[0 .. stop] := bitone ^^ (stop + 1); - } else { - mask_temp[start .. stop ] := bitone ^^ (stop - start + 1); - }; - mask_temp; -} - -val forall Nat 'n, 0 <= 'n, 'n <= 63 . - (bit[64], [|'n|]) -> bit[64] effect pure ROTL - -function (bit[64]) ROTL(v, n) = v[n .. 63] : v[0 .. (n-1)] - -(* Branch facility registers *) - -typedef cr = register bits [ 32 : 63 ] { - 32 .. 35 : CR0; - 32 : LT; 33 : GT; 34 : EQ; 35 : SO; - 36 .. 39 : CR1; - 36 : FX; 37 : FEX; 38 : VX; 39 : OX; - 40 .. 43 : CR2; - 44 .. 47 : CR3; - 48 .. 51 : CR4; - 52 .. 55 : CR5; - 56 .. 59 : CR6; - (* name clashing - do we need hierarchical naming for fields, or do - we just don't care? LT, GT, etc. are not used in the code anyway. - 56 : LT; 57 : GT; 58 : EQ; 59 : SO; - *) - 60 .. 63 : CR7; -} -register (cr) CR - -register (bit[64]) CTR -register (bit[64]) LR - -typedef xer = register bits [ 0 : 63 ] { - 32 : SO; - 33 : OV; - 34 : CA; -} -register (xer) XER - -register alias CA = XER.CA - -(* Fixed-point registers *) - -register (bit[64]) GPR0 -register (bit[64]) GPR1 -register (bit[64]) GPR2 -register (bit[64]) GPR3 -register (bit[64]) GPR4 -register (bit[64]) GPR5 -register (bit[64]) GPR6 -register (bit[64]) GPR7 -register (bit[64]) GPR8 -register (bit[64]) GPR9 -register (bit[64]) GPR10 -register (bit[64]) GPR11 -register (bit[64]) GPR12 -register (bit[64]) GPR13 -register (bit[64]) GPR14 -register (bit[64]) GPR15 -register (bit[64]) GPR16 -register (bit[64]) GPR17 -register (bit[64]) GPR18 -register (bit[64]) GPR19 -register (bit[64]) GPR20 -register (bit[64]) GPR21 -register (bit[64]) GPR22 -register (bit[64]) GPR23 -register (bit[64]) GPR24 -register (bit[64]) GPR25 -register (bit[64]) GPR26 -register (bit[64]) GPR27 -register (bit[64]) GPR28 -register (bit[64]) GPR29 -register (bit[64]) GPR30 -register (bit[64]) GPR31 - -let (vector <0, 32, inc, (register<(bit[64])>) >) GPR = - [ GPR0, GPR1, GPR2, GPR3, GPR4, GPR5, GPR6, GPR7, GPR8, GPR9, GPR10, - GPR11, GPR12, GPR13, GPR14, GPR15, GPR16, GPR17, GPR18, GPR19, GPR20, - GPR21, GPR22, GPR23, GPR24, GPR25, GPR26, GPR27, GPR28, GPR29, GPR30, GPR31 - ] - -register (bit[32:63]) VRSAVE - -register (bit[64]) SPRG4 -register (bit[64]) SPRG5 -register (bit[64]) SPRG6 -register (bit[64]) SPRG7 - -(* XXX bogus, length should be 1024 with many more values - cf. mfspr - definition - eg. SPRG4 to SPRG7 are at offsets 260 to 263, VRSAVE is 256, - etc. *) -let (vector <0, 10, inc, (register<(bit[64])>) >) SPR = - [ undefined, XER, undefined, undefined, - undefined, undefined, undefined, undefined, - LR, CTR - ] - -(* XXX DCR is implementation-dependent; also, some DCR are only 32 bits - instead of 64, and mtdcrux/mfdcrux do special tricks in that case, not - shown in pseudo-code. We just define two dummy DCR here, using sparse - vector definition. *) -register (vector <0, 64, inc, bit>) DCR0 -register (vector <0, 64, inc, bit>) DCR1 -let (vector <0, 2** 64, inc, (register<(vector<0, 64, inc, bit>)>) >) DCR = - [ 0=DCR0, 1=DCR1 ] - -(* Floating-point registers *) - -register (bit[64]) FPR0 -register (bit[64]) FPR1 -register (bit[64]) FPR2 -register (bit[64]) FPR3 -register (bit[64]) FPR4 -register (bit[64]) FPR5 -register (bit[64]) FPR6 -register (bit[64]) FPR7 -register (bit[64]) FPR8 -register (bit[64]) FPR9 -register (bit[64]) FPR10 -register (bit[64]) FPR11 -register (bit[64]) FPR12 -register (bit[64]) FPR13 -register (bit[64]) FPR14 -register (bit[64]) FPR15 -register (bit[64]) FPR16 -register (bit[64]) FPR17 -register (bit[64]) FPR18 -register (bit[64]) FPR19 -register (bit[64]) FPR20 -register (bit[64]) FPR21 -register (bit[64]) FPR22 -register (bit[64]) FPR23 -register (bit[64]) FPR24 -register (bit[64]) FPR25 -register (bit[64]) FPR26 -register (bit[64]) FPR27 -register (bit[64]) FPR28 -register (bit[64]) FPR29 -register (bit[64]) FPR30 -register (bit[64]) FPR31 - -let (vector <0, 32, inc, (register<(bit[64])>) >) FPR = - [ FPR0, FPR1, FPR2, FPR3, FPR4, FPR5, FPR6, FPR7, FPR8, FPR9, FPR10, - FPR11, FPR12, FPR13, FPR14, FPR15, FPR16, FPR17, FPR18, FPR19, FPR20, - FPR21, FPR22, FPR23, FPR24, FPR25, FPR26, FPR27, FPR28, FPR29, FPR30, FPR31 - ] - -typedef fpscr = register bits [ 0 : 63 ] { - 32 : FX; - 33 : FEX; - 34 : VX; - 35 : OX; - 36 : UX; - 37 : ZX; - 38 : XX; - 39 : VXSNAN; - 40 : VXISI; - 41 : VXIDI; - 42 : VXZDZ; - 43 : VXIMZ; - 44 : VXVC; - 45 : FR; - 46 : FI; - 47 .. 51 : FPRF; - 47 : C; - 48 .. 51 : FPCC; - 48 : FL; 49 : FG; 50 : FE; 51 : FU; - 53 : VXSOFT; - 54 : VXSQRT; - 55 : VXCVI; - 56 : VE; - 57 : OE; - 58 : UE; - 59 : ZE; - 60 : XE; - 61 : NI; - 62 .. 63 : RN; -} -register (fpscr) FPSCR +function (regval) rGPR ((regno) r) = + if (r == 0) then + 0 + else + GPRs[r] -(* Pair-wise access to FPR registers *) +function unit wGPR((regno) r, (regval) v) = + if (r != 0) then + GPRs[r] := v -register alias FPRp0 = FPR0 : FPR1 -register alias FPRp2 = FPR2 : FPR3 -register alias FPRp4 = FPR4 : FPR5 -register alias FPRp6 = FPR6 : FPR7 -register alias FPRp8 = FPR8 : FPR9 -register alias FPRp10 = FPR10 : FPR11 -register alias FPRp12 = FPR12 : FPR13 -register alias FPRp14 = FPR14 : FPR15 -register alias FPRp16 = FPR16 : FPR17 -register alias FPRp18 = FPR18 : FPR19 -register alias FPRp20 = FPR20 : FPR21 -register alias FPRp22 = FPR22 : FPR23 -register alias FPRp24 = FPR24 : FPR25 -register alias FPRp26 = FPR26 : FPR27 -register alias FPRp28 = FPR28 : FPR29 -register alias FPRp30 = FPR30 : FPR31 - -let (vector <0, 32, inc, (register<(bit[128])>)>) FPRp = - [ 0 = FPRp0, 2 = FPRp2, 4 = FPRp4, 6 = FPRp6, 8 = FPRp8, 10 = FPRp10, - 12 = FPRp12, 14 = FPRp14, 16 = FPRp16, 18 = FPRp18, 20 = FPRp20, 22 = - FPRp22, 24 = FPRp24, 26 = FPRp26, 28 = FPRp28, 30 = FPRp30 ] - - -(* XXX *) -val bit[32] -> bit[64] effect pure DOUBLE -val bit[64] -> bit[32] effect { undef } SINGLE - -function bit[64] DOUBLE word = { - (bit[64]) temp := 0; - if word[1..8] > 0 & word[1..8] < 255 - then { - temp[0..1] := word[0..1]; - temp[2] := ~(word[1]); - temp[3] := ~(word[1]); - temp[4] := ~(word[1]); - temp[5..63] := word[2..31] : 0b00000000000000000000000000000; - } else if word[1..8] == 0 & word[9..31] != 0 - then { - sign := word[0]; - exp := 0-126; - (bit[53]) frac := 0b0 : word[9..31] : 0b00000000000000000000000000000; - foreach (i from 0 to 52) { - if frac[0] == 0 - then { frac[0..52] := frac[1..52] : 0b0; - exp := exp -1; } - else () - }; - temp[0] := sign; - temp[1..11] := (bit[10]) exp + 1023; - temp[12..63] := frac[1..52]; - } else { - temp[0..1] := word[0..1]; - temp[2] := word[1]; - temp[3] := word[1]; - temp[4] := word[1]; - temp[5..63] := word[2..31] : 0b00000000000000000000000000000; - }; - temp -} - -function bit[32] SINGLE ((bit[64]) frs) = { - (bit[32]) word := 0; - if (frs[1..11] > 896) | (frs[1..63] == 0) - then { word[0..1] := frs[0..1]; - word[2..31] := frs[5..34]; } - else if (874 <= frs[1..11]) & (frs[1..11] <= 896) - then { - sign := frs[0]; - (bit[10]) exp := frs[1..11] - 1023; - (bit[53]) frac := 0b1 : frs[12..63]; - foreach (i from 0 to 53) { - if exp < (0-126) - then { frac[0..52] := 0b0 : frac[0..51]; - exp := exp + 1; } - else ()}; - } else word := undefined; - word -} - -(* Vector registers *) - -register (bit[128]) VR0 -register (bit[128]) VR1 -register (bit[128]) VR2 -register (bit[128]) VR3 -register (bit[128]) VR4 -register (bit[128]) VR5 -register (bit[128]) VR6 -register (bit[128]) VR7 -register (bit[128]) VR8 -register (bit[128]) VR9 -register (bit[128]) VR10 -register (bit[128]) VR11 -register (bit[128]) VR12 -register (bit[128]) VR13 -register (bit[128]) VR14 -register (bit[128]) VR15 -register (bit[128]) VR16 -register (bit[128]) VR17 -register (bit[128]) VR18 -register (bit[128]) VR19 -register (bit[128]) VR20 -register (bit[128]) VR21 -register (bit[128]) VR22 -register (bit[128]) VR23 -register (bit[128]) VR24 -register (bit[128]) VR25 -register (bit[128]) VR26 -register (bit[128]) VR27 -register (bit[128]) VR28 -register (bit[128]) VR29 -register (bit[128]) VR30 -register (bit[128]) VR31 - -let (vector <0, 32, inc, (register<(bit[128])>) >) VR = - [ VR0, VR1, VR2, VR3, VR4, VR5, VR6, VR7, VR8, VR9, VR10, - VR11, VR12, VR13, VR14, VR15, VR16, VR17, VR18, VR19, VR20, - VR21, VR22, VR23, VR24, VR25, VR26, VR27, VR28, VR29, VR30, VR31 - ] - -typedef vscr = register bits [ 96 : 127 ] { - 111 : NJ; - 127 : SAT; -} -register (vscr) VSCR - -(* XXX extend with zeroes -- the resulting size in completely unknown and depends of context *) -val extern forall Nat 'n, Nat 'm. bit['n] -> bit['m] effect pure EXTZ - -(* Chop has a very weird definition where the resulting size depends of - context, but in practice it is used with the following definition everywhere, - except in vaddcuw which probably needs to be patched accordingly. *) -val forall Nat 'n, Nat 'm, 'm <= 'n. (bit['n], [|'m|]) -> bit['m] effect pure Chop -function forall Nat 'm. (bit['m]) Chop(x, y) = x[0..y] - -val forall Nat 'n, Nat 'm, Nat 'k, 'n <= 0, 0 <= 'm. - (implicit<'k>, int, [|'n|], [|'m|]) -> bit['k] effect { wreg } Clamp - -function forall Nat 'n, Nat 'm, Nat 'k, 'n <= 0, 0 <= 'm. (bit['k]) -Clamp((int) x, ([|'n|]) y, ([|'m|]) z) = { - ([|'n:'m|]) result := 0; - if (x<y) then { - result := y; - VSCR.SAT := 1; - } else if (x > z) then { - result := z; - VSCR.SAT := 1; - } else { - result := x; - }; - (bit['k]) result; -} - -(* XXX *) -val extern bit[32] -> bit[32] effect pure RoundToSPIntCeil -val extern bit[32] -> bit[32] effect pure RoundToSPIntFloor -val extern bit[32] -> bit[32] effect pure RoundToSPIntNear -val extern bit[32] -> bit[32] effect pure RoundToSPIntTrunc -val extern bit[32] -> bit[32] effect pure RoundToNearSP -val extern bit[32] -> bit[32] effect pure ReciprocalEstimateSP -val extern bit[32] -> bit[32] effect pure ReciprocalSquareRootEstimateSP -val extern bit[32] -> bit[32] effect pure LogBase2EstimateSP -val extern bit[32] -> bit[32] effect pure Power2EstimateSP -val extern (bit[32], bit[5]) -> bit[32] effect pure ConvertSPtoSXWsaturate -val extern (bit[32], bit[5]) -> bit[32] effect pure ConvertSPtoUXWsaturate - - -register (bit[64]) NIA (* next instruction address *) -register (bit[64]) CIA (* current instruction address *) - - -val extern forall Nat 'n. ( bit[64] , [|'n|] , bit[8*'n]) -> unit effect { wmem } MEMw val extern forall Nat 'n. ( bit[64] , [|'n|] ) -> (bit[8 * 'n]) effect { rmem } MEMr -val extern forall Nat 'n. ( bit[64] , [|'n|] ) -> (bit[8 * 'n]) effect { rmem } MEMr_reserve -val extern forall Nat 'n. ( bit[64] , [|'n|] , bit[8*'n]) -> bool effect { wmem } MEMw_conditional +val extern forall Nat 'n. ( bit[64] , [|'n|]) -> unit effect { eamem } MEMea +val extern forall Nat 'n. ( bit[64] , [|'n|] , bit[8*'n]) -> unit effect { wmv } MEMval -val extern unit -> unit effect { barr } I_Sync -val extern unit -> unit effect { barr } H_Sync (*corresponds to Sync in barrier kinds*) -val extern unit -> unit effect { barr } LW_Sync -val extern unit -> unit effect { barr } EIEIO_Sync +(* Ideally these would be sail builtin *) +function (bit[64]) shift_right_arith64 ((bit[64]) v, (bit[6]) shift) = + let (bit[128]) v128 = EXTS(v) in + (v128 >> shift)[63..0] -val forall Nat 'n, Nat 'm, 'n *8 = 'm. (implicit<'m>,(bit['m])) -> (bit['m]) effect pure byte_reverse -function forall Nat 'n, Nat 'm, 'n*8 = 'm. (bit['m]) effect pure byte_reverse((bit['m]) input) = { - (bit['m]) output := 0; - j := length(input); - foreach (i from 0 to (length(input)) by 8) { - output[i..i+7] := input[j-7 ..j]; - j := j-8; }; - output -} +function (bit[32]) shift_right_arith32 ((bit[32]) v, (bit[5]) shift) = + let (bit[64]) v64 = EXTS(v) in + (v64 >> shift)[31..0] -(* XXX effect for trap? *) -val extern unit -> unit effect pure trap - -register (bit[1]) mode64bit -register (bit[1]) bigendianmode - -val (bit[64],bit) -> unit effect {rreg,wreg} set_overflow_cr0 -function (unit) set_overflow_cr0(target_register,new_xer_so) = { - (if mode64bit - then m := 0 - else m := 32); - (bit[64]) zero := 0; - (if target_register[m..63] <_s zero[m..63] - then c := 0b100 - else if target_register[m..63] >_s zero[m..63] - then c := 0b010 - else c := 0b001); - CR.CR0 := c:[new_xer_so] -} - -function (unit) set_SO_OV(overflow) = { - XER.OV := overflow; - XER.SO := (XER.SO | overflow); -} +typedef uop = enumerate {LUI; AUIPC} (* upper immediate ops *) +typedef bop = enumerate {BEQ; BNE; BLT; BGE; BLTU; BGEU} (* branch ops *) +typedef iop = enumerate {ADDI; SLTI; SLTIU; XORI; ORI; ANDI} (* immediate ops *) +typedef sop = enumerate {SLLI; SRLI; SRAI} (* shift ops *) +typedef rop = enumerate {ADD; SUB; SLL; SLT; SLTU; XOR; SRL; SRA; OR; AND} (* reg-reg ops *) +typedef ropw = enumerate {ADDW; SUBW; SLLW; SRLW; SRAW} (* reg-reg 32-bit ops *) scattered function unit execute scattered typedef ast = const union -val bit[32] -> ast effect pure decode - -scattered function ast decode - -union ast member (bit[24], bit, bit) B - -function clause decode (0b010010 : -(bit[24]) LI : -[AA] : -[LK] as instr) = - B (LI,AA,LK) - -function clause execute (B (LI, AA, LK)) = - { - if AA then NIA := EXTS (LI : 0b00) else NIA := CIA + EXTS (LI : 0b00); - if LK then LR := CIA + 4 else () - } - -union ast member (bit[5], bit[5], bit[14], bit, bit) Bc - -function clause decode (0b010000 : -(bit[5]) BO : -(bit[5]) BI : -(bit[14]) BD : -[AA] : -[LK] as instr) = - Bc (BO,BI,BD,AA,LK) - -function clause execute (Bc (BO, BI, BD, AA, LK)) = - { - if mode64bit then M := 0 else M := 32; - ctr_temp := CTR; - if ~ (BO[2]) - then { - ctr_temp := ctr_temp - 1; - CTR := ctr_temp - } - else (); - ctr_ok := (BO[2] | ~ (ctr_temp[M .. 63] == 0) ^ BO[3]); - cond_ok := (BO[0] | CR[BI + 32] ^ ~ (BO[1])); - if ctr_ok & cond_ok - then if AA then NIA := EXTS (BD : 0b00) else NIA := CIA + EXTS (BD : 0b00) - else (); - if LK then LR := CIA + 4 else () - } - -union ast member (bit[5], bit[5], bit[2], bit) Bclr - -function clause decode (0b010011 : -(bit[5]) BO : -(bit[5]) BI : -(bit[3]) _ : -(bit[2]) BH : -0b0000010000 : -[LK] as instr) = - Bclr (BO,BI,BH,LK) - -function clause execute (Bclr (BO, BI, BH, LK)) = - { - if mode64bit then M := 0 else M := 32; - ctr_temp := CTR; - if ~ (BO[2]) - then { - ctr_temp := ctr_temp - 1; - CTR := ctr_temp - } - else (); - ctr_ok := (BO[2] | ~ (ctr_temp[M .. 63] == 0) ^ BO[3]); - cond_ok := (BO[0] | CR[BI + 32] ^ ~ (BO[1])); - if ctr_ok & cond_ok then NIA := LR[0 .. 61] : 0b00 else (); - if LK then LR := CIA + 4 else () - } - -union ast member (bit[5], bit[5], bit[2], bit) Bcctr - -function clause decode (0b010011 : -(bit[5]) BO : -(bit[5]) BI : -(bit[3]) _ : -(bit[2]) BH : -0b1000010000 : -[LK] as instr) = - Bcctr (BO,BI,BH,LK) - -function clause execute (Bcctr (BO, BI, BH, LK)) = - { - cond_ok := (BO[0] | CR[BI + 32] ^ ~ (BO[1])); - if cond_ok then NIA := CTR[0 .. 61] : 0b00 else (); - if LK then LR := CIA + 4 else () - } - -union ast member (bit[7]) Sc - -function clause decode (0b010001 : -(bit[5]) _ : -(bit[5]) _ : -(bit[4]) _ : -(bit[7]) LEV : -(bit[3]) _ : -0b1 : -(bit[1]) _ as instr) = - Sc (LEV) - -function clause execute (Sc (LEV)) = () - -union ast member (bit[5], bit[5], bit[16]) Lbzu - -function clause decode (0b100011 : -(bit[5]) RT : -(bit[5]) RA : -(bit[16]) D as instr) = - Lbzu (RT,RA,D) - -function clause execute (Lbzu (RT, RA, D)) = - { - (bit[64]) EA := 0; - EA := GPR[RA] + EXTS (D); - GPR[RT] := - 0b00000000000000000000000000000000000000000000000000000000 : MEMr (EA,1); - GPR[RA] := EA - } - -union ast member (bit[5], bit[5], bit[5]) Lbzux - -function clause decode (0b011111 : -(bit[5]) RT : -(bit[5]) RA : -(bit[5]) RB : -0b0001110111 : -(bit[1]) _ as instr) = - Lbzux (RT,RA,RB) - -function clause execute (Lbzux (RT, RA, RB)) = - { - (bit[64]) EA := 0; - EA := GPR[RA] + GPR[RB]; - GPR[RT] := - 0b00000000000000000000000000000000000000000000000000000000 : MEMr (EA,1); - GPR[RA] := EA - } - -union ast member (bit[5], bit[5], bit[16]) Lhzu - -function clause decode (0b101001 : -(bit[5]) RT : -(bit[5]) RA : -(bit[16]) D as instr) = - Lhzu (RT,RA,D) - -function clause execute (Lhzu (RT, RA, D)) = - { - (bit[64]) EA := 0; - EA := GPR[RA] + EXTS (D); - GPR[RT] := 0b000000000000000000000000000000000000000000000000 : MEMr (EA,2); - GPR[RA] := EA - } - -union ast member (bit[5], bit[5], bit[5]) Lhzux - -function clause decode (0b011111 : -(bit[5]) RT : -(bit[5]) RA : -(bit[5]) RB : -0b0100110111 : -(bit[1]) _ as instr) = - Lhzux (RT,RA,RB) - -function clause execute (Lhzux (RT, RA, RB)) = - { - (bit[64]) EA := 0; - EA := GPR[RA] + GPR[RB]; - GPR[RT] := 0b000000000000000000000000000000000000000000000000 : MEMr (EA,2); - GPR[RA] := EA - } - -union ast member (bit[5], bit[5], bit[16]) Lhau - -function clause decode (0b101011 : -(bit[5]) RT : -(bit[5]) RA : -(bit[16]) D as instr) = - Lhau (RT,RA,D) - -function clause execute (Lhau (RT, RA, D)) = - { - (bit[64]) EA := 0; - EA := GPR[RA] + EXTS (D); - GPR[RT] := EXTS (MEMr (EA,2)); - GPR[RA] := EA - } - -union ast member (bit[5], bit[5], bit[5]) Lhaux - -function clause decode (0b011111 : -(bit[5]) RT : -(bit[5]) RA : -(bit[5]) RB : -0b0101110111 : -(bit[1]) _ as instr) = - Lhaux (RT,RA,RB) - -function clause execute (Lhaux (RT, RA, RB)) = - { - (bit[64]) EA := 0; - EA := GPR[RA] + GPR[RB]; - GPR[RT] := EXTS (MEMr (EA,2)); - GPR[RA] := EA - } - -union ast member (bit[5], bit[5], bit[16]) Lwz - -function clause decode (0b100000 : -(bit[5]) RT : -(bit[5]) RA : -(bit[16]) D as instr) = - Lwz (RT,RA,D) - -function clause execute (Lwz (RT, RA, D)) = - { - (bit[64]) b := 0; - (bit[64]) EA := 0; - if RA == 0 then b := 0 else b := GPR[RA]; - EA := b + EXTS (D); - GPR[RT] := 0b00000000000000000000000000000000 : MEMr (EA,4) - } - -union ast member (bit[5], bit[5], bit[16]) Lwzu - -function clause decode (0b100001 : -(bit[5]) RT : -(bit[5]) RA : -(bit[16]) D as instr) = - Lwzu (RT,RA,D) - -function clause execute (Lwzu (RT, RA, D)) = - { - (bit[64]) EA := 0; - EA := GPR[RA] + EXTS (D); - GPR[RT] := 0b00000000000000000000000000000000 : MEMr (EA,4); - GPR[RA] := EA - } - -union ast member (bit[5], bit[5], bit[5]) Lwzux - -function clause decode (0b011111 : -(bit[5]) RT : -(bit[5]) RA : -(bit[5]) RB : -0b0000110111 : -(bit[1]) _ as instr) = - Lwzux (RT,RA,RB) - -function clause execute (Lwzux (RT, RA, RB)) = - { - (bit[64]) EA := 0; - EA := GPR[RA] + GPR[RB]; - GPR[RT] := 0b00000000000000000000000000000000 : MEMr (EA,4); - GPR[RA] := EA - } - -union ast member (bit[5], bit[5], bit[5]) Lwaux - -function clause decode (0b011111 : -(bit[5]) RT : -(bit[5]) RA : -(bit[5]) RB : -0b0101110101 : -(bit[1]) _ as instr) = - Lwaux (RT,RA,RB) - -function clause execute (Lwaux (RT, RA, RB)) = - { - (bit[64]) EA := 0; - EA := GPR[RA] + GPR[RB]; - GPR[RT] := EXTS (MEMr (EA,4)); - GPR[RA] := EA - } - -union ast member (bit[5], bit[5], bit[14]) Ld - -function clause decode (0b111010 : -(bit[5]) RT : -(bit[5]) RA : -(bit[14]) DS : -0b00 as instr) = - Ld (RT,RA,DS) - -function clause execute (Ld (RT, RA, DS)) = - { - (bit[64]) b := 0; - (bit[64]) EA := 0; - if RA == 0 then b := 0 else b := GPR[RA]; - EA := b + EXTS (DS : 0b00); - GPR[RT] := MEMr (EA,8) - } - -union ast member (bit[5], bit[5], bit[14]) Ldu - -function clause decode (0b111010 : -(bit[5]) RT : -(bit[5]) RA : -(bit[14]) DS : -0b01 as instr) = - Ldu (RT,RA,DS) - -function clause execute (Ldu (RT, RA, DS)) = - { - (bit[64]) EA := 0; - EA := GPR[RA] + EXTS (DS : 0b00); - GPR[RT] := MEMr (EA,8); - GPR[RA] := EA - } - -union ast member (bit[5], bit[5], bit[5]) Ldux - -function clause decode (0b011111 : -(bit[5]) RT : -(bit[5]) RA : -(bit[5]) RB : -0b0000110101 : -(bit[1]) _ as instr) = - Ldux (RT,RA,RB) - -function clause execute (Ldux (RT, RA, RB)) = - { - (bit[64]) EA := 0; - EA := GPR[RA] + GPR[RB]; - GPR[RT] := MEMr (EA,8); - GPR[RA] := EA - } - -union ast member (bit[5], bit[5], bit[16]) Stbu - -function clause decode (0b100111 : -(bit[5]) RS : -(bit[5]) RA : -(bit[16]) D as instr) = - Stbu (RS,RA,D) - -function clause execute (Stbu (RS, RA, D)) = - { - (bit[64]) EA := 0; - EA := GPR[RA] + EXTS (D); - MEMw(EA,1) := (GPR[RS])[56 .. 63]; - GPR[RA] := EA - } - -union ast member (bit[5], bit[5], bit[5]) Stbux - -function clause decode (0b011111 : -(bit[5]) RS : -(bit[5]) RA : -(bit[5]) RB : -0b0011110111 : -(bit[1]) _ as instr) = - Stbux (RS,RA,RB) - -function clause execute (Stbux (RS, RA, RB)) = - { - (bit[64]) EA := 0; - EA := GPR[RA] + GPR[RB]; - MEMw(EA,1) := (GPR[RS])[56 .. 63]; - GPR[RA] := EA - } - -union ast member (bit[5], bit[5], bit[16]) Sthu - -function clause decode (0b101101 : -(bit[5]) RS : -(bit[5]) RA : -(bit[16]) D as instr) = - Sthu (RS,RA,D) - -function clause execute (Sthu (RS, RA, D)) = - { - (bit[64]) EA := 0; - EA := GPR[RA] + EXTS (D); - MEMw(EA,2) := (GPR[RS])[48 .. 63]; - GPR[RA] := EA - } - -union ast member (bit[5], bit[5], bit[5]) Sthux - -function clause decode (0b011111 : -(bit[5]) RS : -(bit[5]) RA : -(bit[5]) RB : -0b0110110111 : -(bit[1]) _ as instr) = - Sthux (RS,RA,RB) - -function clause execute (Sthux (RS, RA, RB)) = - { - (bit[64]) EA := 0; - EA := GPR[RA] + GPR[RB]; - MEMw(EA,2) := (GPR[RS])[48 .. 63]; - GPR[RA] := EA - } - -union ast member (bit[5], bit[5], bit[16]) Stw - -function clause decode (0b100100 : -(bit[5]) RS : -(bit[5]) RA : -(bit[16]) D as instr) = - Stw (RS,RA,D) - -function clause execute (Stw (RS, RA, D)) = - { - (bit[64]) b := 0; - (bit[64]) EA := 0; - if RA == 0 then b := 0 else b := GPR[RA]; - EA := b + EXTS (D); - MEMw(EA,4) := (GPR[RS])[32 .. 63] - } - -union ast member (bit[5], bit[5], bit[16]) Stwu - -function clause decode (0b100101 : -(bit[5]) RS : -(bit[5]) RA : -(bit[16]) D as instr) = - Stwu (RS,RA,D) - -function clause execute (Stwu (RS, RA, D)) = - { - (bit[64]) EA := 0; - EA := GPR[RA] + EXTS (D); - MEMw(EA,4) := (GPR[RS])[32 .. 63]; - GPR[RA] := EA - } - -union ast member (bit[5], bit[5], bit[5]) Stwux - -function clause decode (0b011111 : -(bit[5]) RS : -(bit[5]) RA : -(bit[5]) RB : -0b0010110111 : -(bit[1]) _ as instr) = - Stwux (RS,RA,RB) - -function clause execute (Stwux (RS, RA, RB)) = - { - (bit[64]) EA := 0; - EA := GPR[RA] + GPR[RB]; - MEMw(EA,4) := (GPR[RS])[32 .. 63]; - GPR[RA] := EA - } - -union ast member (bit[5], bit[5], bit[14]) Std - -function clause decode (0b111110 : -(bit[5]) RS : -(bit[5]) RA : -(bit[14]) DS : -0b00 as instr) = - Std (RS,RA,DS) - -function clause execute (Std (RS, RA, DS)) = - { - (bit[64]) b := 0; - (bit[64]) EA := 0; - if RA == 0 then b := 0 else b := GPR[RA]; - EA := b + EXTS (DS : 0b00); - MEMw(EA,8) := GPR[RS] - } - -union ast member (bit[5], bit[5], bit[14]) Stdu - -function clause decode (0b111110 : -(bit[5]) RS : -(bit[5]) RA : -(bit[14]) DS : -0b01 as instr) = - Stdu (RS,RA,DS) - -function clause execute (Stdu (RS, RA, DS)) = - { - (bit[64]) EA := 0; - EA := GPR[RA] + EXTS (DS : 0b00); - MEMw(EA,8) := GPR[RS]; - GPR[RA] := EA - } - -union ast member (bit[5], bit[5], bit[5]) Stdux - -function clause decode (0b011111 : -(bit[5]) RS : -(bit[5]) RA : -(bit[5]) RB : -0b0010110101 : -(bit[1]) _ as instr) = - Stdux (RS,RA,RB) - -function clause execute (Stdux (RS, RA, RB)) = - { - (bit[64]) EA := 0; - EA := GPR[RA] + GPR[RB]; - MEMw(EA,8) := GPR[RS]; - GPR[RA] := EA - } - -union ast member (bit[5], bit[5], bit[12], bit[4]) Lq +val bit[32] -> option<ast> effect pure decode -function clause decode (0b111000 : -(bit[5]) RTp : -(bit[5]) RA : -(bit[12]) DQ : -(bit[4]) PT as instr) = - Lq (RTp,RA,DQ,PT) +scattered function option<ast> decode -function clause execute (Lq (RTp, RA, DQ, PT)) = - { - (bit[64]) EA := 0; - (bit[64]) b := 0; - if RA == 0 then b := 0 else b := GPR[RA]; - EA := b + EXTS (DQ : 0b0000); - (bit[128]) mem := MEMr (EA,16); - if bigendianmode - then { - GPR[RTp] := mem[0 .. 63]; - GPR[RTp + 1] := mem[64 .. 127] - } - else { - (bit[128]) bytereverse := byte_reverse (mem); - GPR[RTp] := bytereverse[0 .. 63]; - GPR[RTp + 1] := bytereverse[64 .. 127] - } - } - -union ast member (bit[5], bit[5], bit[14]) Stq - -function clause decode (0b111110 : -(bit[5]) RSp : -(bit[5]) RA : -(bit[14]) DS : -0b10 as instr) = - Stq (RSp,RA,DS) - -function clause execute (Stq (RSp, RA, DS)) = - { - (bit[64]) b := 0; - (bit[64]) EA := 0; - { - (bit[64]) EA := 0; - (bit[64]) b := 0; - if RA == 0 then b := 0 else b := GPR[RA]; - EA := b + EXTS (DS : 0b00); - (bit[128]) mem := 0; - mem[0..63] := GPR[RSp]; - mem[64..127] := GPR[RSp + 1]; - if ~ (bigendianmode) then mem := byte_reverse (mem) else (); - MEMw(EA,16) := mem - }; - EA := b + EXTS (DS : 0b00); - MEMw(EA,8) := RSp - } +union ast member ((bit[20]), regno, uop) UTYPE -union ast member (bit[5], bit[5], bit[16]) Lmw +function clause decode ((bit[20]) imm : (regno) rd : 0b0110111) = Some(UTYPE(imm, rd, LUI)) +function clause decode ((bit[20]) imm : (regno) rd : 0b0010111) = Some(UTYPE(imm, rd, AUIPC)) -function clause decode (0b101110 : -(bit[5]) RT : -(bit[5]) RA : -(bit[16]) D as instr) = - Lmw (RT,RA,D) +function clause execute (UTYPE(imm, rd, op)) = + let (bit[64]) off = EXTS(imm : 0x000) in + let ret = switch (op) { + case LUI -> off + case AUIPC -> PC + off + } in + wGPR(rd, ret) -function clause execute (Lmw (RT, RA, D)) = - { - (bit[64]) b := 0; - (bit[64]) EA := 0; - if RA == 0 then b := 0 else b := GPR[RA]; - EA := b + EXTS (D); - size := ([|32|]) (32 - RT) * 4; - buffer := MEMr (EA,size); - i := 0; - foreach (r from RT to 31 by 1 in inc) - { - GPR[r] := 0b00000000000000000000000000000000 : buffer[i .. i + 31]; - i := i + 32 +union ast member ((bit[20]), regno) JAL +function clause decode ((bit[20]) imm : (regno) rd : 0b1101111) = Some (JAL(imm, rd)) +function clause execute (JAL(imm, rd)) = + let (bit[64]) offset = EXTS(imm[19] : imm[7..0] : imm[8] : imm[18..13] : imm[12..9] : 0b0) in { + nextPC := PC + offset; + wGPR(rd, PC + 4); } - } -union ast member (bit[5], bit[5], bit[5]) Lswi - -function clause decode (0b011111 : -(bit[5]) RT : -(bit[5]) RA : -(bit[5]) NB : -0b1001010101 : -(bit[1]) _ as instr) = - Lswi (RT,RA,NB) - -function clause execute (Lswi (RT, RA, NB)) = - { - (bit[64]) EA := 0; - if RA == 0 then EA := 0 else EA := GPR[RA]; - ([|32|]) r := 0; - r := RT - 1; - ([|32|]) size := if NB == 0 then 32 else NB; - (bit[256]) membuffer := MEMr (EA,size); - j := 0; - i := 32; - foreach (n from (if NB == 0 then 32 else NB) to 1 by 1 in dec) - { - if i == 32 - then { - r := ([|32|]) (r + 1) mod 32; - GPR[r] := 0 - } - else (); - (GPR[r])[i..i + 7] := membuffer[j .. j + 7]; - j := j + 8; - i := i + 8; - if i == 64 then i := 32 else (); - EA := EA + 1 +union ast member((bit[12]), regno, regno) JALR +function clause decode ((bit[12]) imm : (regno) rs1 : 0b000 : (regno) rd : 0b1100111) = + Some(JALR(imm, rs1, rd)) +function clause execute (JALR(imm, rs1, rd)) = + let (bit[64]) newPC = rGPR(rs1) + EXTS(imm) in { + nextPC := newPC[63..1] : 0b0; + wGPR(rd, PC + 4); } - } - -union ast member (bit[5], bit[5], bit[16]) Addi - -function clause decode (0b001110 : -(bit[5]) RT : -(bit[5]) RA : -(bit[16]) SI as instr) = - Addi (RT,RA,SI) - -function clause execute (Addi (RT, RA, SI)) = - if RA == 0 then GPR[RT] := EXTS (SI) else GPR[RT] := GPR[RA] + EXTS (SI) - -union ast member (bit[5], bit[5], bit[16]) Addis - -function clause decode (0b001111 : -(bit[5]) RT : -(bit[5]) RA : -(bit[16]) SI as instr) = - Addis (RT,RA,SI) - -function clause execute (Addis (RT, RA, SI)) = - if RA == 0 - then GPR[RT] := EXTS (SI : 0b0000000000000000) - else GPR[RT] := GPR[RA] + EXTS (SI : 0b0000000000000000) - -union ast member (bit[5], bit[5], bit[5], bit, bit) Add - -function clause decode (0b011111 : -(bit[5]) RT : -(bit[5]) RA : -(bit[5]) RB : -[OE] : -0b100001010 : -[Rc] as instr) = - Add (RT,RA,RB,OE,Rc) - -function clause execute (Add (RT, RA, RB, OE, Rc)) = - let (temp, overflow, _) = (GPR[RA] +_s GPR[RB]) in - { - GPR[RT] := temp; - if Rc - then { - xer_so := XER.SO; - if OE & overflow then xer_so := overflow else (); - set_overflow_cr0 (temp,xer_so) - } - else (); - if OE then set_SO_OV (overflow) else () - } - -union ast member (bit[5], bit[5], bit[5], bit, bit) Subf - -function clause decode (0b011111 : -(bit[5]) RT : -(bit[5]) RA : -(bit[5]) RB : -[OE] : -0b000101000 : -[Rc] as instr) = - Subf (RT,RA,RB,OE,Rc) -function clause execute (Subf (RT, RA, RB, OE, Rc)) = - let (t, overflow, _) = (~ (GPR[RA]) +_s GPR[RB]) in - { - (bit[64]) temp := t + 1; - GPR[RT] := temp; - if Rc - then { - xer_so := XER.SO; - if OE & overflow then xer_so := overflow else (); - set_overflow_cr0 (temp,xer_so) - } - else (); - if OE then set_SO_OV (overflow) else () +union ast member ((bit[12]), regno, regno, bop) BTYPE +function clause decode ((bit[7]) imm7 : (regno) rs2 : (regno) rs1 : 0b000 : (bit[5]) imm5 : 0b1100011) = Some(BTYPE(imm7[6] : imm5[0] : imm7[5..0] : imm5[4..1], rs2, rs1, BEQ)) +function clause decode ((bit[7]) imm7 : (regno) rs2 : (regno) rs1 : 0b001 : (bit[5]) imm5 : 0b1100011) = Some(BTYPE(imm7[6] : imm5[0] : imm7[5..0] : imm5[4..1], rs2, rs1, BNE)) +function clause decode ((bit[7]) imm7 : (regno) rs2 : (regno) rs1 : 0b100 : (bit[5]) imm5 : 0b1100011) = Some(BTYPE(imm7[6] : imm5[0] : imm7[5..0] : imm5[4..1], rs2, rs1, BLT)) +function clause decode ((bit[7]) imm7 : (regno) rs2 : (regno) rs1 : 0b101 : (bit[5]) imm5 : 0b1100011) = Some(BTYPE(imm7[6] : imm5[0] : imm7[5..0] : imm5[4..1], rs2, rs1, BGE)) +function clause decode ((bit[7]) imm7 : (regno) rs2 : (regno) rs1 : 0b110 : (bit[5]) imm5 : 0b1100011) = Some(BTYPE(imm7[6] : imm5[0] : imm7[5..0] : imm5[4..1], rs2, rs1, BLTU)) +function clause decode ((bit[7]) imm7 : (regno) rs2 : (regno) rs1 : 0b111 : (bit[5]) imm5 : 0b1100011) = Some(BTYPE(imm7[6] : imm5[0] : imm7[5..0] : imm5[4..1], rs2, rs1, BGEU)) + +function clause execute (BTYPE(imm, rs2, rs1, op)) = + let rs1_val = rGPR(rs1) in + let rs2_val = rGPR(rs2) in + let taken = switch(op) { + case BEQ -> rs1_val == rs2_val + case BNE -> rs1_val != rs2_val + case BLT -> rs1_val <_s rs2_val + case BGE -> rs1_val >=_s rs2_val + case BLTU -> rs1_val <_u rs2_val + case BGEU -> rs1_val >= rs2_val (* XXX is this signed or unsigned? *) + } in + if (taken) then + nextPC := PC + EXTS(imm : 0b0) + +union ast member ((bit[12]), regno, regno, iop) ITYPE +function clause decode ((bit[12]) imm : (regno) rs1 : 0b000 : (regno) rd : 0b0010011) = Some(ITYPE(imm, rs1, rd, ADDI)) +function clause decode ((bit[12]) imm : (regno) rs1 : 0b010 : (regno) rd : 0b0010011) = Some(ITYPE(imm, rs1, rd, SLTI)) +function clause decode ((bit[12]) imm : (regno) rs1 : 0b011 : (regno) rd : 0b0010011) = Some(ITYPE(imm, rs1, rd, SLTIU)) +function clause decode ((bit[12]) imm : (regno) rs1 : 0b100 : (regno) rd : 0b0010011) = Some(ITYPE(imm, rs1, rd, XORI)) +function clause decode ((bit[12]) imm : (regno) rs1 : 0b110 : (regno) rd : 0b0010011) = Some(ITYPE(imm, rs1, rd, ORI)) +function clause decode ((bit[12]) imm : (regno) rs1 : 0b111 : (regno) rd : 0b0010011) = Some(ITYPE(imm, rs1, rd, ANDI)) +function clause execute (ITYPE (imm, rs1, rd, op)) = + let rs1_val = rGPR(rs1) in + let imm64 = (bit[64]) (EXTS(imm)) in + let (bit[64]) result = switch(op) { + case ADDI -> rs1_val + imm64 + case SLTI -> EXTZ(rs1_val <_s imm64) + case SLTIU -> EXTZ(rs1_val <_u imm64) + case XORI -> rs1_val ^ imm64 + case ORI -> rs1_val | imm64 + case ANDI -> rs1_val & imm64 + } in + wGPR(rd, result) + +union ast member ((bit[6]), regno, regno, sop) SHIFTIOP +function clause decode (0b000000 : (bit[6]) shamt : (regno) rs1 : 0b001 : (regno) rd : 0b0010011) = Some(SHIFTIOP(shamt, rs1, rd, SLLI)) +function clause decode (0b000000 : (bit[6]) shamt : (regno) rs1 : 0b101 : (regno) rd : 0b0010011) = Some(SHIFTIOP(shamt, rs1, rd, SRLI)) +function clause decode (0b010000 : (bit[6]) shamt : (regno) rs1 : 0b101 : (regno) rd : 0b0010011) = Some(SHIFTIOP(shamt, rs1, rd, SRAI)) +function clause execute (SHIFTIOP(shamt, rs1, rd, op)) = + let rs1_val = rGPR(rs1) in + let result = switch(op) { + case SLLI -> rs1_val >> shamt + case SRLI -> rs1_val << shamt + case SRAI -> shift_right_arith64(rs1_val, shamt) + } in + wGPR(rd, result) + +union ast member (regno, regno, regno, rop) RTYPE +function clause decode (0b0000000 : (regno) rs2 : (regno) rs1 : 0b000 : (regno) rd : 0b0110011) = Some(RTYPE(rs2, rs1, rd, ADD)) +function clause decode (0b0100000 : (regno) rs2 : (regno) rs1 : 0b000 : (regno) rd : 0b0110011) = Some(RTYPE(rs2, rs1, rd, SUB)) +function clause decode (0b0000000 : (regno) rs2 : (regno) rs1 : 0b001 : (regno) rd : 0b0110011) = Some(RTYPE(rs2, rs1, rd, SLL)) +function clause decode (0b0000000 : (regno) rs2 : (regno) rs1 : 0b010 : (regno) rd : 0b0110011) = Some(RTYPE(rs2, rs1, rd, SLT)) +function clause decode (0b0000000 : (regno) rs2 : (regno) rs1 : 0b011 : (regno) rd : 0b0110011) = Some(RTYPE(rs2, rs1, rd, SLTU)) +function clause decode (0b0000000 : (regno) rs2 : (regno) rs1 : 0b100 : (regno) rd : 0b0110011) = Some(RTYPE(rs2, rs1, rd, XOR)) +function clause decode (0b0000000 : (regno) rs2 : (regno) rs1 : 0b101 : (regno) rd : 0b0110011) = Some(RTYPE(rs2, rs1, rd, SRL)) +function clause decode (0b0100000 : (regno) rs2 : (regno) rs1 : 0b101 : (regno) rd : 0b0110011) = Some(RTYPE(rs2, rs1, rd, SRA)) +function clause decode (0b0000000 : (regno) rs2 : (regno) rs1 : 0b110 : (regno) rd : 0b0110011) = Some(RTYPE(rs2, rs1, rd, OR)) +function clause decode (0b0000000 : (regno) rs2 : (regno) rs1 : 0b111 : (regno) rd : 0b0110011) = Some(RTYPE(rs2, rs1, rd, AND)) +function clause execute (RTYPE(rs2, rs1, rd, op)) = + let rs1_val = rGPR(rs1) in + let rs2_val = rGPR(rs2) in + let (bit[64]) result = switch(op) { + case ADD -> rs1_val + rs2_val + case SUB -> rs1_val - rs2_val + case SLL -> rs1_val << (rs2_val[5..0]) + case SLT -> EXTZ(rs1_val <_s rs2_val) + case SLTU -> EXTZ(rs1_val <_u rs2_val) + case XOR -> rs1_val ^ rs2_val + case SRL -> rs1_val >> (rs2_val[5..0]) + case SRA -> shift_right_arith64(rs1_val, rs2_val[5..0]) + case OR -> rs1_val | rs2_val + case AND -> rs1_val & rs2_val + } in + wGPR(rd, result) + +union ast member ((bit[12]), regno, regno, bool, [|8|]) LOAD +function clause decode ((bit[12]) imm : (regno) rs1 : 0b000 : (regno) rd : 0b0000011) = Some(LOAD(imm, rs1, rd, false, 1)) +function clause decode ((bit[12]) imm : (regno) rs1 : 0b001 : (regno) rd : 0b0000011) = Some(LOAD(imm, rs1, rd, false, 2)) +function clause decode ((bit[12]) imm : (regno) rs1 : 0b010 : (regno) rd : 0b0000011) = Some(LOAD(imm, rs1, rd, false, 4)) +function clause decode ((bit[12]) imm : (regno) rs1 : 0b011 : (regno) rd : 0b0000011) = Some(LOAD(imm, rs1, rd, false, 8)) +function clause decode ((bit[12]) imm : (regno) rs1 : 0b100 : (regno) rd : 0b0000011) = Some(LOAD(imm, rs1, rd, true, 1)) +function clause decode ((bit[12]) imm : (regno) rs1 : 0b101 : (regno) rd : 0b0000011) = Some(LOAD(imm, rs1, rd, true, 2)) +function clause decode ((bit[12]) imm : (regno) rs1 : 0b110 : (regno) rd : 0b0000011) = Some(LOAD(imm, rs1, rd, true, 4)) +function clause execute(LOAD(imm, rs1, rd, unsigned, width)) = + let (bit[64]) addr = rGPR(rs1) + EXTS(imm) in + let (bit[64]) result = if unsigned then + EXTZ(MEMr(addr, width)) + else + EXTS(MEMr(addr, width)) in + wGPR(rd, result) + +union ast member ((bit[12]), regno, regno, [|8|]) STORE +function clause decode ((bit[7]) imm7 : (regno) rs2 : (regno) rs1 : 0b000 : (bit[5]) imm5 : 0b0100011) = Some(STORE(imm7 : imm5, rs2, rs1, 1)) +function clause decode ((bit[7]) imm7 : (regno) rs2 : (regno) rs1 : 0b001 : (bit[5]) imm5 : 0b0100011) = Some(STORE(imm7 : imm5, rs2, rs1, 2)) +function clause decode ((bit[7]) imm7 : (regno) rs2 : (regno) rs1 : 0b010 : (bit[5]) imm5 : 0b0100011) = Some(STORE(imm7 : imm5, rs2, rs1, 4)) +function clause decode ((bit[7]) imm7 : (regno) rs2 : (regno) rs1 : 0b011 : (bit[5]) imm5 : 0b0100011) = Some(STORE(imm7 : imm5, rs2, rs1, 8)) +function clause execute (STORE(imm, rs2, rs1, width)) = + let (bit[64]) addr = rGPR(rs1) + EXTS(imm) in { + MEMea(addr, width); + MEMval(addr, width, rGPR(rs2)); } -union ast member (bit[5], bit[5], bit[16]) Addic - -function clause decode (0b001100 : -(bit[5]) RT : -(bit[5]) RA : -(bit[16]) SI as instr) = - Addic (RT,RA,SI) - -function clause execute (Addic (RT, RA, SI)) = - let (temp, _, carry) = (GPR[RA] +_s EXTS (SI)) in - { - GPR[RT] := temp; - CA := carry - } - -union ast member (bit[5], bit[5], bit[16]) AddicDot - -function clause decode (0b001101 : -(bit[5]) RT : -(bit[5]) RA : -(bit[16]) SI as instr) = - AddicDot (RT,RA,SI) - -function clause execute (AddicDot (RT, RA, SI)) = - let (temp, overflow, carry) = (GPR[RA] +_s EXTS (SI)) in - { - GPR[RT] := temp; - CA := carry; - set_overflow_cr0 (temp,overflow) - } - -union ast member (bit[5], bit[5], bit, bit) Neg - -function clause decode (0b011111 : -(bit[5]) RT : -(bit[5]) RA : -(bit[5]) _ : -[OE] : -0b001101000 : -[Rc] as instr) = - Neg (RT,RA,OE,Rc) - -function clause execute (Neg (RT, RA, OE, Rc)) = - let (temp, overflow, _) = (~ (GPR[RA]) +_s (bit) 1) in - { - GPR[RT] := temp; - if Rc then set_overflow_cr0 (temp,XER.SO) else () - } - -union ast member (bit[5], bit[5], bit[5], bit, bit) Mullw - -function clause decode (0b011111 : -(bit[5]) RT : -(bit[5]) RA : -(bit[5]) RB : -[OE] : -0b011101011 : -[Rc] as instr) = - Mullw (RT,RA,RB,OE,Rc) - -function clause execute (Mullw (RT, RA, RB, OE, Rc)) = - let (prod, overflow, _) = ((GPR[RA])[32 .. 63] *_s (GPR[RB])[32 .. 63]) in - { - GPR[RT] := prod; - if Rc - then { - xer_so := XER.SO; - if OE & overflow then xer_so := overflow else (); - set_overflow_cr0 (prod,xer_so) - } - else (); - if OE then set_SO_OV (overflow) else () - } - -union ast member (bit[3], bit, bit[5], bit[16]) Cmpi - -function clause decode (0b001011 : -(bit[3]) BF : -(bit[1]) _ : -[L] : -(bit[5]) RA : -(bit[16]) SI as instr) = - Cmpi (BF,L,RA,SI) - -function clause execute (Cmpi (BF, L, RA, SI)) = - { - (bit[64]) a := 0; - if L == 0 then a := EXTS ((GPR[RA])[32 .. 63]) else a := GPR[RA]; - if a < EXTS (SI) - then c := 0b100 - else if a > EXTS (SI) then c := 0b010 else c := 0b001; - CR[4 * BF + 32..4 * BF + 35] := c : [XER.SO] - } - -union ast member (bit[3], bit, bit[5], bit[5]) Cmp - -function clause decode (0b011111 : -(bit[3]) BF : -(bit[1]) _ : -[L] : -(bit[5]) RA : -(bit[5]) RB : -0b0000000000 : -(bit[1]) _ as instr) = - Cmp (BF,L,RA,RB) - -function clause execute (Cmp (BF, L, RA, RB)) = - { - (bit[64]) a := 0; - (bit[64]) b := 0; - if L == 0 - then { - a := EXTS ((GPR[RA])[32 .. 63]); - b := EXTS ((GPR[RB])[32 .. 63]) - } - else { - a := GPR[RA]; - b := GPR[RB] - }; - if a < b then c := 0b100 else if a > b then c := 0b010 else c := 0b001; - CR[4 * BF + 32..4 * BF + 35] := c : [XER.SO] - } - -union ast member (bit[5], bit[5], bit[16]) Ori - -function clause decode (0b011000 : -(bit[5]) RS : -(bit[5]) RA : -(bit[16]) UI as instr) = - Ori (RS,RA,UI) - -function clause execute (Ori (RS, RA, UI)) = - GPR[RA] := (GPR[RS] | 0b000000000000000000000000000000000000000000000000 : UI) - -union ast member (bit[5], bit[5], bit[16]) Oris - -function clause decode (0b011001 : -(bit[5]) RS : -(bit[5]) RA : -(bit[16]) UI as instr) = - Oris (RS,RA,UI) - -function clause execute (Oris (RS, RA, UI)) = - GPR[RA] := - (GPR[RS] | 0b00000000000000000000000000000000 : UI : 0b0000000000000000) - -union ast member (bit[5], bit[5], bit[16]) Xori - -function clause decode (0b011010 : -(bit[5]) RS : -(bit[5]) RA : -(bit[16]) UI as instr) = - Xori (RS,RA,UI) - -function clause execute (Xori (RS, RA, UI)) = - GPR[RA] := GPR[RS] ^ 0b000000000000000000000000000000000000000000000000 : UI - -union ast member (bit[5], bit[5], bit[5], bit) Or - -function clause decode (0b011111 : -(bit[5]) RS : -(bit[5]) RA : -(bit[5]) RB : -0b0110111100 : -[Rc] as instr) = - Or (RS,RA,RB,Rc) - -function clause execute (Or (RS, RA, RB, Rc)) = - { - (bit[64]) temp := (GPR[RS] | GPR[RB]); - GPR[RA] := temp; - if Rc then set_overflow_cr0 (temp,XER.SO) else () - } - -union ast member (bit[5], bit[5], bit) Extsw - -function clause decode (0b011111 : -(bit[5]) RS : -(bit[5]) RA : -(bit[5]) _ : -0b1111011010 : -[Rc] as instr) = - Extsw (RS,RA,Rc) - -function clause execute (Extsw (RS, RA, Rc)) = - { - s := (GPR[RS])[32]; - (bit[64]) temp := 0; - temp := (GPR[RS])[32 .. 63]; - temp := s ^^ 32; - if Rc then set_overflow_cr0 (temp,XER.SO) else (); - GPR[RA] := temp - } - -union ast member (bit[5], bit[5], bit[6], bit[6], bit) Rldicr - -function clause decode (0b011110 : -(bit[5]) RS : -(bit[5]) RA : -(bit[5]) _ : -(bit[6]) me : -0b001 : -(bit[1]) _ : -[Rc] as instr) = - Rldicr (RS,RA,instr[16 .. 20] : instr[30 .. 30],me,Rc) - -function clause execute (Rldicr (RS, RA, sh, me, Rc)) = - { - n := [sh[5]] : sh[0 .. 4]; - r := ROTL (GPR[RS],n); - e := [me[5]] : me[0 .. 4]; - m := MASK (0,e); - (bit[64]) temp := (r & m); - GPR[RA] := temp; - if Rc then set_overflow_cr0 (temp,XER.SO) else () - } - -union ast member (bit[5], bit[10]) Mtspr - -function clause decode (0b011111 : -(bit[5]) RS : -(bit[10]) spr : -0b0111010011 : -(bit[1]) _ as instr) = - Mtspr (RS,spr) - -function clause execute (Mtspr (RS, spr)) = - { - n := spr[5 .. 9] : spr[0 .. 4]; - if n == 13 - then trap () - else if length (SPR[n]) == 64 - then SPR[n] := GPR[RS] - else SPR[n] := (GPR[RS])[32 .. 63] - } - -union ast member (bit[5], bit[10]) Mfspr - -function clause decode (0b011111 : -(bit[5]) RT : -(bit[10]) spr : -0b0101010011 : -(bit[1]) _ as instr) = - Mfspr (RT,spr) - -function clause execute (Mfspr (RT, spr)) = - { - n := spr[5 .. 9] : spr[0 .. 4]; - if length (SPR[n]) == 64 - then GPR[RT] := SPR[n] - else GPR[RT] := 0b00000000000000000000000000000000 : SPR[n] - } - -union ast member (bit[5], bit[8]) Mtcrf - -function clause decode (0b011111 : -(bit[5]) RS : -0b0 : -(bit[8]) FXM : -(bit[1]) _ : -0b0010010000 : -(bit[1]) _ as instr) = - Mtcrf (RS,FXM) - -function clause execute (Mtcrf (RS, FXM)) = - { - mask := - (FXM[0] ^^ 4) : - (FXM[1] ^^ 4) : - (FXM[2] ^^ 4) : - (FXM[3] ^^ 4) : - (FXM[4] ^^ 4) : (FXM[5] ^^ 4) : (FXM[6] ^^ 4) : (FXM[7] ^^ 4); - CR := - ((bit[32]) ((GPR[RS])[32 .. 63] & mask) | - (bit[32]) (CR & ~ ((bit[32]) mask))) - } - -union ast member (bit[5]) Mfcr - -function clause decode (0b011111 : -(bit[5]) RT : -0b0 : -(bit[9]) _ : -0b0000010011 : -(bit[1]) _ as instr) = - Mfcr (RT) - -function clause execute (Mfcr (RT)) = - GPR[RT] := 0b00000000000000000000000000000000 : CR - -union ast member (bit[5], bit[5], bit[16]) Lfsu - -function clause decode (0b110001 : -(bit[5]) FRT : -(bit[5]) RA : -(bit[16]) D as instr) = - Lfsu (FRT,RA,D) - -function clause execute (Lfsu (FRT, RA, D)) = - { - (bit[64]) EA := 0; - EA := GPR[RA] + EXTS (D); - FPR[FRT] := DOUBLE (MEMr (EA,4)); - GPR[RA] := EA - } - -union ast member (bit[5], bit[5], bit[5]) Lfsux - -function clause decode (0b011111 : -(bit[5]) FRT : -(bit[5]) RA : -(bit[5]) RB : -0b1000110111 : -(bit[1]) _ as instr) = - Lfsux (FRT,RA,RB) - -function clause execute (Lfsux (FRT, RA, RB)) = - { - (bit[64]) EA := 0; - EA := GPR[RA] + GPR[RB]; - FPR[FRT] := DOUBLE (MEMr (EA,4)); - GPR[RA] := EA - } - -union ast member (bit[5], bit[5], bit[16]) Lfd - -function clause decode (0b110010 : -(bit[5]) FRT : -(bit[5]) RA : -(bit[16]) D as instr) = - Lfd (FRT,RA,D) - -function clause execute (Lfd (FRT, RA, D)) = - { - (bit[64]) b := 0; - (bit[64]) EA := 0; - if RA == 0 then b := 0 else b := GPR[RA]; - EA := b + EXTS (D); - FPR[FRT] := MEMr (EA,8) - } - -union ast member (bit[5], bit[5], bit[16]) Lfdu - -function clause decode (0b110011 : -(bit[5]) FRT : -(bit[5]) RA : -(bit[16]) D as instr) = - Lfdu (FRT,RA,D) - -function clause execute (Lfdu (FRT, RA, D)) = - { - (bit[64]) EA := 0; - EA := GPR[RA] + EXTS (D); - FPR[FRT] := MEMr (EA,8); - GPR[RA] := EA - } - -union ast member (bit[5], bit[5], bit[5]) Lfdux - -function clause decode (0b011111 : -(bit[5]) FRT : -(bit[5]) RA : -(bit[5]) RB : -0b1001110111 : -(bit[1]) _ as instr) = - Lfdux (FRT,RA,RB) - -function clause execute (Lfdux (FRT, RA, RB)) = - { - (bit[64]) EA := 0; - EA := GPR[RA] + GPR[RB]; - FPR[FRT] := MEMr (EA,8); - GPR[RA] := EA - } - -union ast member (bit[5], bit[5], bit[16]) Stfsu - -function clause decode (0b110101 : -(bit[5]) FRS : -(bit[5]) RA : -(bit[16]) D as instr) = - Stfsu (FRS,RA,D) - -function clause execute (Stfsu (FRS, RA, D)) = - { - (bit[64]) EA := 0; - EA := GPR[RA] + EXTS (D); - MEMw(EA,4) := SINGLE (FPR[FRS]); - GPR[RA] := EA - } - -union ast member (bit[5], bit[5], bit[5]) Stfsux - -function clause decode (0b011111 : -(bit[5]) FRS : -(bit[5]) RA : -(bit[5]) RB : -0b1010110111 : -(bit[1]) _ as instr) = - Stfsux (FRS,RA,RB) - -function clause execute (Stfsux (FRS, RA, RB)) = - { - (bit[64]) EA := 0; - EA := GPR[RA] + GPR[RB]; - MEMw(EA,4) := SINGLE (FPR[FRS]); - GPR[RA] := EA - } - -union ast member (bit[5], bit[5], bit[16]) Stfd - -function clause decode (0b110110 : -(bit[5]) FRS : -(bit[5]) RA : -(bit[16]) D as instr) = - Stfd (FRS,RA,D) - -function clause execute (Stfd (FRS, RA, D)) = - { - (bit[64]) b := 0; - (bit[64]) EA := 0; - if RA == 0 then b := 0 else b := GPR[RA]; - EA := b + EXTS (D); - MEMw(EA,8) := FPR[FRS] - } - -union ast member (bit[5], bit[5], bit[16]) Stfdu - -function clause decode (0b110111 : -(bit[5]) FRS : -(bit[5]) RA : -(bit[16]) D as instr) = - Stfdu (FRS,RA,D) - -function clause execute (Stfdu (FRS, RA, D)) = - { - (bit[64]) EA := 0; - EA := GPR[RA] + EXTS (D); - MEMw(EA,8) := FPR[FRS]; - GPR[RA] := EA - } - -union ast member (bit[5], bit[5], bit[5]) Stfdux - -function clause decode (0b011111 : -(bit[5]) FRS : -(bit[5]) RA : -(bit[5]) RB : -0b1011110111 : -(bit[1]) _ as instr) = - Stfdux (FRS,RA,RB) - -function clause execute (Stfdux (FRS, RA, RB)) = - { - (bit[64]) EA := 0; - EA := GPR[RA] + GPR[RB]; - MEMw(EA,8) := FPR[FRS]; - GPR[RA] := EA - } - -union ast member (bit[5], bit[5], bit[14]) Lfdp - -function clause decode (0b111001 : -(bit[5]) FRTp : -(bit[5]) RA : -(bit[14]) DS : -0b0 : -0b0 as instr) = - Lfdp (FRTp,RA,DS) - -function clause execute (Lfdp (FRTp, RA, DS)) = - { - (bit[64]) b := 0; - (bit[64]) EA := 0; - if RA == 0 then b := 0 else b := GPR[RA]; - EA := b + EXTS (DS : 0b00); - FPRp[FRTp] := MEMr (EA,16) - } - -union ast member (bit[5], bit[5], bit[5]) Lfdpx - -function clause decode (0b011111 : -(bit[5]) FRTp : -(bit[5]) RA : -(bit[5]) RB : -0b1100010111 : -(bit[1]) _ as instr) = - Lfdpx (FRTp,RA,RB) - -function clause execute (Lfdpx (FRTp, RA, RB)) = - { - (bit[64]) b := 0; - (bit[64]) EA := 0; - if RA == 0 then b := 0 else b := GPR[RA]; - EA := b + GPR[RB]; - FPRp[FRTp] := MEMr (EA,16) - } - -union ast member (bit[5], bit[5], bit[14]) Stfdp - -function clause decode (0b111101 : -(bit[5]) FRSp : -(bit[5]) RA : -(bit[14]) DS : -0b0 : -0b0 as instr) = - Stfdp (FRSp,RA,DS) - -function clause execute (Stfdp (FRSp, RA, DS)) = - { - (bit[64]) b := 0; - (bit[64]) EA := 0; - if RA == 0 then b := 0 else b := GPR[RA]; - EA := b + EXTS (DS : 0b00); - MEMw(EA,16) := FPRp[FRSp] - } - -union ast member (bit[5], bit[5], bit[5]) Stfdpx - -function clause decode (0b011111 : -(bit[5]) FRSp : -(bit[5]) RA : -(bit[5]) RB : -0b1110010111 : -(bit[1]) _ as instr) = - Stfdpx (FRSp,RA,RB) - -function clause execute (Stfdpx (FRSp, RA, RB)) = - { - (bit[64]) b := 0; - (bit[64]) EA := 0; - if RA == 0 then b := 0 else b := GPR[RA]; - EA := b + GPR[RB]; - MEMw(EA,16) := FPRp[FRSp] - } - -union ast member (bit[5], bit) Mffs - -function clause decode (0b111111 : -(bit[5]) FRT : -(bit[5]) _ : -(bit[5]) _ : -0b1001000111 : -[Rc] as instr) = - Mffs (FRT,Rc) - -function clause execute (Mffs (FRT, Rc)) = () - -union ast member (bit[3], bit[3]) Mcrfs - -function clause decode (0b111111 : -(bit[3]) BF : -(bit[2]) _ : -(bit[3]) BFA : -(bit[2]) _ : -(bit[5]) _ : -0b0001000000 : -(bit[1]) _ as instr) = - Mcrfs (BF,BFA) - -function clause execute (Mcrfs (BF, BFA)) = () - -union ast member (bit[5], bit[5], bit[5]) Lvx - -function clause decode (0b011111 : -(bit[5]) VRT : -(bit[5]) RA : -(bit[5]) RB : -0b0001100111 : -(bit[1]) _ as instr) = - Lvx (VRT,RA,RB) - -function clause execute (Lvx (VRT, RA, RB)) = - { - (bit[64]) b := 0; - (bit[64]) EA := 0; - if RA == 0 then b := 0 else b := GPR[RA]; - EA := b + GPR[RB]; - VR[VRT] := - MEMr - (EA & 0b1111111111111111111111111111111111111111111111111111111111110000,16) - } - -union ast member (bit[5], bit[5], bit[5]) Stvx - -function clause decode (0b011111 : -(bit[5]) VRS : -(bit[5]) RA : -(bit[5]) RB : -0b0011100111 : -(bit[1]) _ as instr) = - Stvx (VRS,RA,RB) - -function clause execute (Stvx (VRS, RA, RB)) = - { - (bit[64]) b := 0; - (bit[64]) EA := 0; - if RA == 0 then b := 0 else b := GPR[RA]; - EA := b + GPR[RB]; - MEMw(EA & 0b1111111111111111111111111111111111111111111111111111111111110000,16) := - VR[VRS] - } - -union ast member (bit[5]) Mtvscr - -function clause decode (0b000100 : -(bit[10]) _ : -(bit[5]) VRB : -0b11001000100 as instr) = - Mtvscr (VRB) - -function clause execute (Mtvscr (VRB)) = VSCR := (VR[VRB])[96 .. 127] - -union ast member (bit[5]) Mfvscr - -function clause decode (0b000100 : -(bit[5]) VRT : -(bit[10]) _ : -0b11000000100 as instr) = - Mfvscr (VRT) - -function clause execute (Mfvscr (VRT)) = - VR[VRT] := - 0b000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000 : - VSCR - -union ast member (bit[2]) Sync - -function clause decode (0b011111 : -(bit[3]) _ : -(bit[2]) L : -(bit[5]) _ : -(bit[5]) _ : -0b1001010110 : -(bit[1]) _ as instr) = - Sync (L) - -function clause execute (Sync (L)) = - switch L { case 0b00 -> { H_Sync (()) } case 0b01 -> { LW_Sync (()) } } - -union ast member (bit[5]) Mbar - -function clause decode (0b011111 : -(bit[5]) MO : -(bit[5]) _ : -(bit[5]) _ : -0b1101010110 : -(bit[1]) _ as instr) = - Mbar (MO) - -function clause execute (Mbar (MO)) = () - - -typedef decode_failure = enumerate { no_matching_pattern; unsupported_instruction; illegal_instruction } - -function clause decode _ = exit no_matching_pattern +union ast member ((bit[12]), regno, regno) ADDIW +function clause decode ((bit[12]) imm : (regno) rs1 : 0b000 : (regno) rd : 0b0011011) = Some(ADDIW(imm, rs1, rd)) +function clause execute (ADDIW(imm, rs1, rd)) = + let (bit[64]) imm64 = EXTS(imm) in + let (bit[64]) result64 = imm64 + rGPR(rs1) in + let (bit[64]) result32 = EXTS(result64[31..0]) in + wGPR(rd, result32) + +union ast member ((bit[5]), regno, regno, sop) SHIFTW +function clause decode (0b0000000 : (bit[5]) shamt : (regno) rs1 : 0b001 : (regno) rd : 0b0011011) = Some(SHIFTW(shamt, rs1, rd, SLLI)) +function clause decode (0b0000000 : (bit[5]) shamt : (regno) rs1 : 0b101 : (regno) rd : 0b0011011) = Some(SHIFTW(shamt, rs1, rd, SRLI)) +function clause decode (0b0100000 : (bit[5]) shamt : (regno) rs1 : 0b101 : (regno) rd : 0b0011011) = Some(SHIFTW(shamt, rs1, rd, SRAI)) +function clause execute (SHIFTW(shamt, rs1, rd, op)) = + let rs1_val = (rGPR(rs1))[31..0] in + let result = switch(op) { + case SLLI -> rs1_val >> shamt + case SRLI -> rs1_val << shamt + case SRAI -> shift_right_arith32(rs1_val, shamt) + } in + wGPR(rd, EXTS(result)) + +union ast member (regno, regno, regno, ropw) RTYPEW +function clause decode (0b0000000 : (regno) rs2 : (regno) rs1 : 0b000 : (regno) rd : 0b0111011) = Some(RTYPEW(rs2, rs1, rd, ADDW)) +function clause decode (0b0100000 : (regno) rs2 : (regno) rs1 : 0b000 : (regno) rd : 0b0111011) = Some(RTYPEW(rs2, rs1, rd, SUBW)) +function clause decode (0b0000000 : (regno) rs2 : (regno) rs1 : 0b001 : (regno) rd : 0b0111011) = Some(RTYPEW(rs2, rs1, rd, SLLW)) +function clause decode (0b0000000 : (regno) rs2 : (regno) rs1 : 0b101 : (regno) rd : 0b0111011) = Some(RTYPEW(rs2, rs1, rd, SRLW)) +function clause decode (0b0100000 : (regno) rs2 : (regno) rs1 : 0b101 : (regno) rd : 0b0111011) = Some(RTYPEW(rs2, rs1, rd, SRAW)) +function clause execute (RTYPEW(rs2, rs1, rd, op)) = + let rs1_val = (rGPR(rs1))[31..0] in + let rs2_val = (rGPR(rs2))[31..0] in + let (bit[32]) result = switch(op) { + case ADDW -> rs1_val + rs2_val + case SUBW -> rs1_val - rs2_val + case SLLW -> rs1_val << (rs2_val[4..0]) + case SRLW -> rs1_val >> (rs2_val[4..0]) + case SRAW -> shift_right_arith32(rs1_val, rs2_val[4..0]) + } in + wGPR(rd, EXTS(result)) +end ast end decode end execute -end ast - -val ast -> ast effect pure supported_instructions -function ast supported_instructions ((ast) instr) = { - switch instr { - case (Mbar(_)) -> exit unsupported_instruction - case (Sync(0b10)) -> exit unsupported_instruction - case (Sync(0b11)) -> exit unsupported_instruction - case _ -> instr - } -} - -val ast -> bit effect pure illegal_instructions_pred -function bit illegal_instructions_pred ((ast) instr) = { - switch instr { - case (Bcctr(BO,BI,BH,LK)) -> ~(BO[2]) - case (Lbzu(RT,RA,D)) -> (RA == 0) | (RA == RT) - case (Lbzux(RT,RA,_)) ->(RA == 0) | (RA == RT) - case (Lhzu(RT,RA,D)) -> (RA == 0) | (RA == RT) - case (Lhzux(RT,RA,RB)) -> (RA == 0) | (RA == RT) - case (Lhau(RT,RA,D)) -> (RA == 0) | (RA == RT) - case (Lhaux(RT,RA,RB)) -> (RA == 0) | (RA == RT) - case (Lwzu(RA,RT,D)) -> (RA == 0) | (RA == RT) - case (Lwzux(RT,RA,RB)) -> (RA == 0) | (RA == RT) - case (Lwaux(RA,RT,RB)) -> (RA == 0) | (RA == RT) - case (Ldu(RT,RA,DS)) -> (RA == 0) | (RA == RT) - case (Ldux(RT,RA,RB)) -> (RA == 0) | (RA == RT) - case (Stbu(RS,RA,D)) -> (RA == 0) - case (Stbux(RS,RA,RB)) -> (RA == 0) - case (Sthu(RS,RA,RB)) -> (RA == 0) - case (Sthux(RS,RA,RB)) -> (RA == 0) - case (Stwu(RS,RA,D)) -> (RA == 0) - case (Stwux(RS,RA,RB)) -> (RA == 0) - case (Stdu(RS,RA,DS)) -> (RA == 0) - case (Stdux(RS,RA,RB)) -> (RA == 0) - case (Lmw(RT,RA,D)) -> (RA == 0) | ((RT <= RA) & (RA <= 31)) - case (Lswi(RT,RA,NB)) -> - let (([|32|]) n) = (if ~(NB == 0) then NB else 32) in - let ceil = - (if (n mod 4) == 0 - then n quot 4 else (n quot 4) + 1) in - (RT <= RA) & (RA <= ((bit[5]) (((bit[5]) (RT + ceil)) -1))) - (* Can't read XER at the time meant, so will need to rethink *) - (* case (Lswx(RT,RA,RB)) -> - let (([|32|]) n) = (XER[57..63]) in - let ceil = - (if (n mod 4 == 0) - then n quot 4 else (n quot 4) + 1) in - let ((bit[5]) upper_bound) = (RT + ceil) in - (RT <= RA & RA <= upper_bound) | - (RT <= RB & RB <= upper_bound) | - (RT == RA) | (RT == RB)*) - case (Lfsu(FRT,RA,D)) -> (RA == 0) - case (Lfsux(FRT,RA,RB)) -> (RA == 0) - case (Lfdu(FRT,RA,D)) -> (RA == 0) - case (Lfdux(FRT,RA,RB)) -> (RA == 0) - case (Stfsu(FRS,RA,D)) -> (RA == 0) - case (Stfsux(FRS,RA,RB)) -> (RA == 0) - case (Stfdu(FRS,D,RA)) -> (RA == 0) - case (Stfdux(FRS,RA,RB)) -> (RA == 0) - case (Lfdp(FRTp,RA,DS)) -> (FRTp mod 2 == 1) - case (Stfdp(FRSp,RA,DS)) -> (FRSp mod 2 == 1) - case (Lfdpx(FRTp,RA,RB)) -> (FRTp mod 2 == 1) - case (Stfdpx(FRSp,RA,RB)) -> (FRSp mod 2 == 1) - case (Lq(RTp,RA,DQ,Pt)) -> ((RTp mod 2 ==1) | RTp == RA) - case (Stq(RSp,RA,RS)) -> (RSp mod 2 == 1) - case (Mtspr(RS, spr)) -> - ~ ((spr == 1) | (spr == 8) | (spr == 9) | (spr == 256) | - (spr == 512) | (spr == 896) | (spr == 898)) -(*One of these causes a stack overflow error, don't want to debug why now*) - (*case (Mfspr(RT, spr)) -> - ~ ((spr == 1) | (spr == 8) | (spr == 9) | (spr == 136) | - (spr == 256) | (spr == 259) | (spr == 260) | (spr == 261) | - (spr == 262) | (spr == 263) | (spr == 268) | (spr == 268) | - (spr == 269) | (spr == 512) | (spr == 526) | (spr == 526) | - (spr == 527) | (spr == 896) | (spr == 898)) - case (Se_illegal) -> true - case (E_lhau(RT,RA,D8)) -> (RA == 0 | RA == RT) - case (E_Lhzu(RT,RA,D8)) -> (RA == 0 | RA == RT) - case (E_lwzu(RT,RA,D8)) -> (RA == 0 | RA == RT) - case (E_stbu(RS,RA,D8)) -> (RA == 0) - case (E_sthu(RS,RA,D8)) -> (RA == 0) - case (E_stwu(RS,RA,D8)) -> (RA == 0) - case (E_lmw(RT,RA,D8)) -> (RT <= RA & RA <= 31)*) - case _ -> false - } -} - -val ast -> ast effect pure illegal_instructions -function ast illegal_instructions ((ast) instr) = - if (illegal_instructions_pred ((ast) instr)) - then exit illegal_instruction else instr - -(* fetch-decode-execute *) -function unit fde () = { - NIA := CIA + 4; - instr := decode(MEMr(CIA, 4)); - instr := supported_instructions(instr); - execute(instr); - CIA := NIA; -}*) @@ -175,6 +175,8 @@ n_constraint_aux = (* constraint over kind $_$ *) | NC_nat_set_bounded of kid * (int) list | NC_or of n_constraint * n_constraint | NC_and of n_constraint * n_constraint + | NC_true + | NC_false and n_constraint = diff --git a/src/ast_util.ml b/src/ast_util.ml index 67eedf72..2109175f 100644 --- a/src/ast_util.ml +++ b/src/ast_util.ml @@ -232,6 +232,8 @@ and string_of_n_constraint = function "(" ^ string_of_n_constraint nc1 ^ " & " ^ string_of_n_constraint nc2 ^ ")" | NC_aux (NC_nat_set_bounded (kid, ns), _) -> string_of_kid kid ^ " IN {" ^ string_of_list ", " string_of_int ns ^ "}" + | NC_aux (NC_true, _) -> "true" + | NC_aux (NC_false, _) -> "false" let string_of_quant_item_aux = function | QI_id (KOpt_aux (KOpt_none kid, _)) -> string_of_kid kid diff --git a/src/initial_check.ml b/src/initial_check.ml index 0e68ad81..b831e288 100644 --- a/src/initial_check.ml +++ b/src/initial_check.ml @@ -112,7 +112,7 @@ let typ_error l msg opt_id opt_var opt_kind = | None,Some(v),None -> ": " ^ (var_to_string v) | None,None,Some(kind) -> " " ^ (kind_to_string kind) | _ -> ""))) - + let to_ast_id (Parse_ast.Id_aux(id,l)) = Id_aux( (match id with | Parse_ast.Id(x) -> Id(x) @@ -144,16 +144,8 @@ let rec to_ast_typ (k_env : kind Envmap.t) (def_ord : order) (t: Parse_ast.atyp) match t with | Parse_ast.ATyp_aux(t,l) -> Typ_aux( (match t with - | Parse_ast.ATyp_id(id) -> - let id = to_ast_id id in - let mk = Envmap.apply k_env (id_to_string id) in - (match mk with - | Some(k) -> (match k.k with - | K_Typ -> Typ_id id - | K_infer -> k.k <- K_Typ; Typ_id id - | _ -> typ_error l "Required an identifier with kind Type, encountered " (Some id) None (Some k)) - | None -> typ_error l "Encountered an unbound type identifier" (Some id) None None) - | Parse_ast.ATyp_var(v) -> + | Parse_ast.ATyp_id(id) -> Typ_id (to_ast_id id) + | Parse_ast.ATyp_var(v) -> let v = to_ast_var v in let mk = Envmap.apply k_env (var_to_string v) in (match mk with @@ -1010,6 +1002,6 @@ let initial_kind_env = ("implicit", {k = K_Lam( [{k = K_Nat}], {k=K_Typ})} ); ] -let process_ast defs = - let (ast, _, _) = to_ast Nameset.empty initial_kind_env (Ast.Ord_aux(Ast.Ord_inc,Parse_ast.Unknown)) defs in +let process_ast order defs = + let (ast, _, _) = to_ast Nameset.empty initial_kind_env order defs in ast diff --git a/src/initial_check.mli b/src/initial_check.mli index 063a0131..ed4eb0bf 100644 --- a/src/initial_check.mli +++ b/src/initial_check.mli @@ -42,7 +42,5 @@ open Ast -val process_ast : Parse_ast.defs -> unit defs +val process_ast : order -> Parse_ast.defs -> unit defs - - diff --git a/src/monomorphise.ml b/src/monomorphise.ml index 7bfc3a3d..63be60b2 100644 --- a/src/monomorphise.ml +++ b/src/monomorphise.ml @@ -77,9 +77,10 @@ let make_vectors sz = -(* Based on current type checker's behaviour *) let pat_id_is_variable env id = match Env.lookup_id id env with + (* Unbound is returned for both variables and constructors which take + arguments, but the latter only don't appear in a P_id *) | Unbound (* Shadowing of immutable locals is allowed; mutable locals and registers are rejected by the type checker, so don't matter *) @@ -90,21 +91,24 @@ let pat_id_is_variable env id = | Union _ -> false - let rec is_value (E_aux (e,(l,annot))) = + let is_constructor id = + match annot with + | None -> + (Reporting_basic.print_err false true l "Monomorphisation" + ("Missing type information for identifier " ^ string_of_id id); + false) (* Be conservative if we have no info *) + | Some (env,_,_) -> + Env.is_union_constructor id env || + (match Env.lookup_id id env with + | Enum _ | Union _ -> true + | Unbound | Local _ | Register _ -> false) + in match e with - | E_id id -> - (match annot with - | None -> - (Reporting_basic.print_err false true l "Monomorphisation" - ("Missing type information for identifier " ^ string_of_id id); - false) (* Be conservative if we have no info *) - | Some (env,_,_) -> - match Env.lookup_id id env with - | Enum _ | Union _ -> true - | Unbound | Local _ | Register _ -> false) + | E_id id -> is_constructor id | E_lit _ -> true | E_tuple es -> List.for_all is_value es + | E_app (id,es) -> is_constructor id && List.for_all is_value es (* TODO: more? *) | _ -> false @@ -294,6 +298,7 @@ let nexp_subst_fns substs refinements = | E_lit _ | E_comment _ -> re e | E_sizeof ne -> re (E_sizeof ne) (* TODO: does this need done? does it appear in type checked code? *) + | E_constraint _ -> re e (* TODO: actual substitution if necessary *) | E_internal_exp (l,annot) -> re (E_internal_exp (l, (*s_tannot*) annot)) | E_sizeof_internal (l,annot) -> re (E_sizeof_internal (l, (*s_tannot*) annot)) | E_internal_exp_user ((l1,annot1),(l2,annot2)) -> @@ -308,12 +313,15 @@ let nexp_subst_fns substs refinements = | _ -> E_aux (E_tuple es',(l,None)) in let id' = - match Env.lookup_id id (fst (env_typ_expected l annot)) with - | Union (qs,Typ_aux (Typ_fn(inty,outty,_),_)) -> - (match refine_constructor refinements id substs arg inty with - | None -> id - | Some id' -> id') - | _ -> id + let env,_ = env_typ_expected l annot in + if Env.is_union_constructor id env then + let (qs,ty) = Env.get_val_spec id env in + match ty with (Typ_aux (Typ_fn(inty,outty,_),_)) -> + (match refine_constructor refinements id substs arg inty with + | None -> id + | Some id' -> id') + | _ -> id + else id in re (E_app (id',es')) | E_app_infix (e1,id,e2) -> re (E_app_infix (s_exp e1,id,s_exp e2)) | E_tuple es -> re (E_tuple (List.map s_exp es)) @@ -395,6 +403,7 @@ let bindings_from_pat p = -> List.concat (List.map aux_pat ps) | P_record (fps,_) -> List.concat (List.map aux_fpat fps) | P_vector_indexed ips -> List.concat (List.map (fun (_,p) -> aux_pat p) ips) + | P_cons (p1,p2) -> aux_pat p1 @ aux_pat p2 and aux_fpat (FP_aux (FP_Fpat (_,p), _)) = aux_pat p in aux_pat p @@ -577,6 +586,7 @@ let split_defs splits defs = | E_sizeof_internal _ | E_internal_exp_user _ | E_comment _ + | E_constraint _ -> exp | E_cast (t,e') -> re (E_cast (t, const_prop_exp substs e')) | E_app (id,es) -> @@ -841,6 +851,10 @@ let split_defs splits defs = relist spl (fun ps -> P_tup ps) ps | P_list ps -> relist spl (fun ps -> P_list ps) ps + | P_cons (p1,p2) -> + match re (fun p' -> P_cons (p',p2)) p1 with + | Some r -> Some r + | None -> re (fun p' -> P_cons (p1,p')) p2 in spl p in @@ -861,8 +875,8 @@ let split_defs splits defs = let (_,variants) = List.find (fun (id',_) -> Id.compare id id' = 0) refinements in let env,_ = env_typ_expected l tannot in let constr_out_typ = - match Env.lookup_id id env with - | Union (qs,Typ_aux (Typ_fn(_,outt,_),_)) -> outt + match Env.get_val_spec id env with + | (qs,Typ_aux (Typ_fn(_,outt,_),_)) -> outt | _ -> raise (Reporting_basic.err_general l ("Constructor " ^ string_of_id id ^ " is not a construtor!")) in @@ -909,6 +923,7 @@ let split_defs splits defs = | E_sizeof_internal _ | E_internal_exp_user _ | E_comment _ + | E_constraint _ -> ea | E_cast (t,e') -> re (E_cast (t, map_exp e')) | E_app (id,es) -> re (E_app (id,List.map map_exp es)) diff --git a/src/pretty_print_lem.ml b/src/pretty_print_lem.ml index 95ddc580..7adccfdf 100644 --- a/src/pretty_print_lem.ml +++ b/src/pretty_print_lem.ml @@ -239,7 +239,8 @@ let doc_lit_lem in_pat (L_aux(lit,l)) a = | Typ_id (Id_aux (Id "string", _)) -> "\"\"" | _ -> "(failwith \"undefined value of unsupported type\")") | _ -> "(failwith \"undefined value of unsupported type\")") - | L_string s -> "\"" ^ s ^ "\"") + | L_string s -> "\"" ^ s ^ "\"" + | L_real s -> s (* TODO What's the Lem syntax for reals? *)) (* typ_doc is the doc for the type being quantified *) @@ -257,16 +258,10 @@ let is_ctor env id = match Env.lookup_id id env with *) let rec doc_pat_lem regtypes apat_needed (P_aux (p,(l,annot)) as pa) = match p with | P_app(id, ((_ :: _) as pats)) -> - (match annot with - | Some (env, _, _) when (is_ctor env id) -> - let ppp = doc_unop (doc_id_lem_ctor id) - (parens (separate_map comma (doc_pat_lem regtypes true) pats)) in - if apat_needed then parens ppp else ppp - | _ -> empty) - | P_app(id,[]) -> - (match annot with - | Some (env, _, _) when (is_ctor env id) -> doc_id_lem_ctor id - | _ -> empty) + let ppp = doc_unop (doc_id_lem_ctor id) + (parens (separate_map comma (doc_pat_lem regtypes true) pats)) in + if apat_needed then parens ppp else ppp + | P_app(id,[]) -> doc_id_lem_ctor id | P_lit lit -> doc_lit_lem true lit annot | P_wild -> underscore | P_id id -> @@ -281,15 +276,14 @@ let rec doc_pat_lem regtypes apat_needed (P_aux (p,(l,annot)) as pa) = match p w [string "Vector";brackets (separate_map semi (doc_pat_lem regtypes true) pats);underscore;underscore] in if apat_needed then parens ppp else ppp | P_vector_concat pats -> - let ppp = - (separate space) - [string "Vector";parens (separate_map (string "::") (doc_pat_lem regtypes true) pats);underscore;underscore] in - if apat_needed then parens ppp else ppp + raise (Reporting_basic.err_unreachable l + "vector concatenation patterns should have been removed before pretty-printing") | P_tup pats -> (match pats with | [p] -> doc_pat_lem regtypes apat_needed p | _ -> parens (separate_map comma_sp (doc_pat_lem regtypes false) pats)) - | P_list pats -> brackets (separate_map semi (doc_pat_lem regtypes false) pats) (*Never seen but easy in lem*) + | P_list pats -> brackets (separate_map semi (doc_pat_lem regtypes false) pats) (*Never seen but easy in lem*) + | P_cons (p,p') -> doc_op (string "::") (doc_pat_lem regtypes true p) (doc_pat_lem regtypes true p') | P_record (_,_) | P_vector_indexed _ -> empty (* TODO *) let rec contains_bitvector_typ (Typ_aux (t,_) as typ) = match t with @@ -926,7 +920,7 @@ let doc_exp_lem, doc_let_lem = | E_return _ -> raise (Reporting_basic.err_todo l "pretty-printing early return statements to Lem not yet supported") - | E_comment _ | E_comment_struc _ -> empty + | E_constraint _ | E_comment _ | E_comment_struc _ -> empty | E_internal_cast _ | E_internal_exp _ | E_sizeof_internal _ | E_internal_exp_user _ -> raise (Reporting_basic.err_unreachable l "unsupported internal expression encountered while pretty-printing") @@ -944,9 +938,13 @@ let doc_exp_lem, doc_let_lem = else doc_id_lem id in group (doc_op equals fname (top_exp regtypes true e)) - and doc_case regtypes (Pat_aux(Pat_exp(pat,e),_)) = + and doc_case regtypes = function + | Pat_aux(Pat_exp(pat,e),_) -> group (prefix 3 1 (separate space [pipe; doc_pat_lem regtypes false pat;arrow]) (group (top_exp regtypes false e))) + | Pat_aux(Pat_when(_,_,_),(l,_)) -> + raise (Reporting_basic.err_unreachable l + "guarded pattern expression should have been rewritten before pretty-printing") and doc_lexp_deref_lem regtypes ((LEXP_aux(lexp,(l,annot))) as le) = match lexp with | LEXP_field (le,id) -> diff --git a/src/pretty_print_lem_ast.ml b/src/pretty_print_lem_ast.ml index 6809826a..0875aee7 100644 --- a/src/pretty_print_lem_ast.ml +++ b/src/pretty_print_lem_ast.ml @@ -219,12 +219,15 @@ let pp_lem_ord ppf o = base ppf (pp_format_ord_lem o) let pp_lem_effects ppf e = base ppf (pp_format_effects_lem e) let pp_lem_beffect ppf be = base ppf (pp_format_base_effect_lem be) -let pp_format_nexp_constraint_lem (NC_aux(nc,l)) = +let rec pp_format_nexp_constraint_lem (NC_aux(nc,l)) = "(NC_aux " ^ (match nc with | NC_fixed(n1,n2) -> "(NC_fixed " ^ pp_format_nexp_lem n1 ^ " " ^ pp_format_nexp_lem n2 ^ ")" | NC_bounded_ge(n1,n2) -> "(NC_bounded_ge " ^ pp_format_nexp_lem n1 ^ " " ^ pp_format_nexp_lem n2 ^ ")" | NC_bounded_le(n1,n2) -> "(NC_bounded_le " ^ pp_format_nexp_lem n1 ^ " " ^ pp_format_nexp_lem n2 ^ ")" + | NC_not_equal(n1,n2) -> "(NC_not_equal " ^ pp_format_nexp_lem n1 ^ " " ^ pp_format_nexp_lem n2 ^ ")" + | NC_or(nc1,nc2) -> "(NC_or " ^ pp_format_nexp_constraint_lem nc1 ^ " " ^ pp_format_nexp_constraint_lem nc2 ^ ")" + | NC_and(nc1,nc2) -> "(NC_and " ^ pp_format_nexp_constraint_lem nc1 ^ " " ^ pp_format_nexp_constraint_lem nc2 ^ ")" | NC_nat_set_bounded(id,bounds) -> "(NC_nat_set_bounded " ^ pp_format_var_lem id ^ " [" ^ @@ -278,7 +281,8 @@ let pp_format_lit_lem (L_aux(lit,l)) = | L_hex(n) -> "(L_hex \"" ^ n ^ "\")" | L_bin(n) -> "(L_bin \"" ^ n ^ "\")" | L_undef -> "L_undef" - | L_string(s) -> "(L_string \"" ^ s ^ "\")") ^ " " ^ + | L_string(s) -> "(L_string \"" ^ s ^ "\")" + | L_real(s) -> "(L_real \"" ^ s ^ "\")") ^ " " ^ (pp_format_l_lem l) ^ ")" let pp_lem_lit ppf l = base ppf (pp_format_lit_lem l) @@ -336,7 +340,8 @@ let rec pp_format_pat_lem (P_aux(p,(l,annot))) = "(P_vector_indexed [" ^ list_format "; " (fun (i,p) -> Printf.sprintf "(%d, %s)" i (pp_format_pat_lem p)) ipats ^ "])" | P_vector_concat(pats) -> "(P_vector_concat [" ^ list_format "; " pp_format_pat_lem pats ^ "])" | P_tup(pats) -> "(P_tup [" ^ (list_format "; " pp_format_pat_lem pats) ^ "])" - | P_list(pats) -> "(P_list [" ^ (list_format "; " pp_format_pat_lem pats) ^ "])") ^ + | P_list(pats) -> "(P_list [" ^ (list_format "; " pp_format_pat_lem pats) ^ "])" + | P_cons(pat,pat') -> "(P_cons " ^ pp_format_pat_lem pat ^ " " ^ pp_format_pat_lem pat' ^ ")") ^ " (" ^ pp_format_l_lem l ^ ", " ^ pp_format_annot annot ^ "))" let pp_lem_pat ppf p = base ppf (pp_format_pat_lem p) @@ -426,6 +431,8 @@ and pp_lem_exp ppf (E_aux(e,(l,annot))) = pp_lem_lexp lexp pp_lem_exp exp pp_lem_l l pp_annot annot | E_sizeof nexp -> fprintf ppf "@[<0>(E_aux (E_sizeof %a) (%a, %a))@]" pp_lem_nexp nexp pp_lem_l l pp_annot annot + | E_constraint nc -> + fprintf ppf "@[<0>(E_aux (E_constraint %a) (%a, %a))@]" pp_lem_nexp_constraint nc pp_lem_l l pp_annot annot | E_exit exp -> fprintf ppf "@[<0>(E_aux (E_exit %a) (%a, %a))@]" pp_lem_exp exp pp_lem_l l pp_annot annot | E_return exp -> @@ -476,8 +483,11 @@ and pp_lem_fexp ppf (FE_aux(FE_Fexp(id,exp),(l,annot))) = fprintf ppf "@[<1>(FE_aux (FE_Fexp %a %a) (%a, %a))@]" pp_lem_id id pp_lem_exp exp pp_lem_l l pp_annot annot and pp_semi_lem_fexp ppf fexp = fprintf ppf "@[<1>%a %a@]" pp_lem_fexp fexp kwd ";" -and pp_lem_case ppf (Pat_aux(Pat_exp(pat,exp),(l,annot))) = +and pp_lem_case ppf = function +| Pat_aux(Pat_exp(pat,exp),(l,annot)) -> fprintf ppf "@[<1>(Pat_aux (Pat_exp %a@ %a) (%a, %a))@]" pp_lem_pat pat pp_lem_exp exp pp_lem_l l pp_annot annot +| Pat_aux(Pat_when(pat,guard,exp),(l,annot)) -> + fprintf ppf "@[<1>(Pat_aux (Pat_exp %a@ %a %a) (%a, %a))@]" pp_lem_pat pat pp_lem_exp guard pp_lem_exp exp pp_lem_l l pp_annot annot and pp_semi_lem_case ppf case = fprintf ppf "@[<1>%a %a@]" pp_lem_case case kwd ";" and pp_lem_lexp ppf (LEXP_aux(lexp,(l,annot))) = diff --git a/src/pretty_print_ocaml.ml b/src/pretty_print_ocaml.ml index 652b0ce9..4f2c3ab0 100644 --- a/src/pretty_print_ocaml.ml +++ b/src/pretty_print_ocaml.ml @@ -140,7 +140,8 @@ let doc_lit_ocaml in_pat (L_aux(l,_)) = | L_hex n -> "(num_to_vec " ^ ("0x" ^ n) ^ ")" (*shouldn't happen*) | L_bin n -> "(num_to_vec " ^ ("0b" ^ n) ^ ")" (*shouldn't happen*) | L_undef -> "(failwith \"undef literal not supported\")" (* XXX Undef vectors get handled with to_vec_undef. We could support undef bit but would need to check type. For the moment treat as runtime error. *) - | L_string s -> "\"" ^ s ^ "\"") + | L_string s -> "\"" ^ s ^ "\"" + | L_real s -> s) (* typ_doc is the doc for the type being quantified *) let doc_typquant_ocaml (TypQ_aux(tq,_)) typ_doc = typ_doc @@ -170,7 +171,7 @@ let doc_pat_ocaml = | P_wild -> underscore | P_id id -> doc_id_ocaml id | P_as(p,id) -> parens (separate space [pat p; string "as"; doc_id_ocaml id]) - | P_typ(typ,p) -> doc_op colon (pat p) (doc_typ_ocaml typ) + | P_typ(typ,p) -> parens (doc_op colon (pat p) (doc_typ_ocaml typ)) | P_app(id,[]) -> (match annot with | Some (env, typ, eff) -> @@ -196,6 +197,7 @@ let doc_pat_ocaml = | None -> non_bit_print()) | P_tup pats -> parens (separate_map comma_sp pat pats) | P_list pats -> brackets (separate_map semi pat pats) (*Never seen but easy in ocaml*) + | P_cons (p,p') -> doc_op (string "::") (pat p) (pat p') | P_record _ -> raise (Reporting_basic.err_unreachable l "unhandled record pattern") | P_vector_indexed _ -> raise (Reporting_basic.err_unreachable l "unhandled vector_indexed pattern") | P_vector_concat _ -> raise (Reporting_basic.err_unreachable l "unhandled vector_concat pattern") @@ -467,6 +469,13 @@ let doc_exp_ocaml, doc_let_ocaml = separate space [string "return"; exp e1;] | E_assert (e1, e2) -> (string "assert") ^^ parens ((string "to_bool") ^^ space ^^ exp e1) (* XXX drops e2 *) + | E_sizeof _ -> raise (Reporting_basic.err_unreachable l + "E_sizeof should have been rewritten before pretty-printing") + | E_constraint _ -> empty + | E_sizeof_internal _ | E_internal_exp_user (_, _) | E_internal_cast (_, _) + | E_internal_exp _ -> raise (Reporting_basic.err_unreachable l + "internal expression should have been rewritten before pretty-printing") + | E_comment _ | E_comment_struc _ -> empty (* TODO Should we output comments? *) and let_exp (LB_aux(lb,_)) = match lb with | LB_val_explicit(ts,pat,e) -> prefix 2 1 @@ -479,8 +488,14 @@ let doc_exp_ocaml, doc_let_ocaml = and doc_fexp (FE_aux(FE_Fexp(id,e),_)) = doc_op equals (doc_id_ocaml id) (top_exp false e) - and doc_case (Pat_aux(Pat_exp(pat,e),_)) = + and doc_case = function + | (Pat_aux(Pat_exp(pat,e),_)) -> doc_op arrow (separate space [pipe; doc_pat_ocaml pat]) (group (top_exp false e)) + | (Pat_aux(Pat_when(pat,guard,e),_)) -> + doc_op arrow + (separate space [pipe; + doc_op (string "when") (doc_pat_ocaml pat) (top_exp false guard)]) + (group (top_exp false e)) and doc_lexp_ocaml top_call ((LEXP_aux(lexp,(l,annot))) as le) = let exp = top_exp false in diff --git a/src/process_file.ml b/src/process_file.ml index 0601bfab..c9a4f178 100644 --- a/src/process_file.ml +++ b/src/process_file.ml @@ -45,16 +45,17 @@ type out_type = | Lem_out of string option | Ocaml_out of string option -let get_lexbuf fn = - let lexbuf = Lexing.from_channel (open_in fn) in - lexbuf.Lexing.lex_curr_p <- { Lexing.pos_fname = fn; - Lexing.pos_lnum = 1; - Lexing.pos_bol = 0; - Lexing.pos_cnum = 0; }; - lexbuf +let get_lexbuf f = + let in_chan = open_in f in + let lexbuf = Lexing.from_channel in_chan in + lexbuf.Lexing.lex_curr_p <- { Lexing.pos_fname = f; + Lexing.pos_lnum = 1; + Lexing.pos_bol = 0; + Lexing.pos_cnum = 0; }; + lexbuf, in_chan let parse_file (f : string) : Parse_ast.defs = - let scanbuf = get_lexbuf f in + let scanbuf, in_chan = get_lexbuf f in let type_names = try Pre_parser.file Pre_lexer.token scanbuf @@ -67,25 +68,26 @@ let parse_file (f : string) : Parse_ast.defs = | Lexer.LexError(s,p) -> raise (Reporting_basic.Fatal_error (Reporting_basic.Err_lex (p, s))) in let () = Lexer.custom_type_names := !Lexer.custom_type_names @ type_names in - let lexbuf = get_lexbuf f in + close_in in_chan; + let lexbuf, in_chan = get_lexbuf f in try - Parser.file Lexer.token lexbuf + let ast = Parser.file Lexer.token lexbuf in + close_in in_chan; ast with | Parsing.Parse_error -> let pos = Lexing.lexeme_start_p lexbuf in - raise (Reporting_basic.Fatal_error (Reporting_basic.Err_syntax (pos, "main"))) + raise (Reporting_basic.Fatal_error (Reporting_basic.Err_syntax (pos, "main"))) | Parse_ast.Parse_error_locn(l,m) -> raise (Reporting_basic.Fatal_error (Reporting_basic.Err_syntax_locn (l, m))) | Lexer.LexError(s,p) -> raise (Reporting_basic.Fatal_error (Reporting_basic.Err_lex (p, s))) +let convert_ast (order : Ast.order) (defs : Parse_ast.defs) : unit Ast.defs = Initial_check.process_ast order defs -(*Should add a flag to say whether we want to consider Oinc or Odec the default order *) -let convert_ast (defs : Parse_ast.defs) : unit Ast.defs = Initial_check.process_ast defs +let load_file_no_check order f = convert_ast order (parse_file f) -let load_file env f = - let ast = parse_file f in - let ast = convert_ast ast in +let load_file order env f = + let ast = convert_ast order (parse_file f) in Type_check.check env ast let opt_new_typecheck = ref false diff --git a/src/process_file.mli b/src/process_file.mli index b15523bb..9907b743 100644 --- a/src/process_file.mli +++ b/src/process_file.mli @@ -41,14 +41,15 @@ (**************************************************************************) val parse_file : string -> Parse_ast.defs -val convert_ast : Parse_ast.defs -> unit Ast.defs +val convert_ast : Ast.order -> Parse_ast.defs -> unit Ast.defs val check_ast: unit Ast.defs -> Type_check.tannot Ast.defs * Type_check.Env.t val monomorphise_ast : ((string * int) * string) list -> Type_check.tannot Ast.defs -> Type_check.tannot Ast.defs * Type_check.Env.t val rewrite_ast: Type_check.tannot Ast.defs -> Type_check.tannot Ast.defs val rewrite_ast_lem : Type_check.tannot Ast.defs -> Type_check.tannot Ast.defs val rewrite_ast_ocaml : Type_check.tannot Ast.defs -> Type_check.tannot Ast.defs -val load_file : Type_check.Env.t -> string -> Type_check.tannot Ast.defs * Type_check.Env.t +val load_file_no_check : Ast.order -> string -> unit Ast.defs +val load_file : Ast.order -> Type_check.Env.t -> string -> Type_check.tannot Ast.defs * Type_check.Env.t val opt_new_typecheck : bool ref val opt_just_check : bool ref diff --git a/src/rewriter.ml b/src/rewriter.ml index 166c31f0..0cf25103 100644 --- a/src/rewriter.ml +++ b/src/rewriter.ml @@ -135,7 +135,7 @@ let fix_eff_exp (E_aux (e,((l,_) as annot))) = match snd annot with | E_assign (lexp,e) -> union_effects (effect_of_lexp lexp) (effect_of e) | E_exit e -> effect_of e | E_return e -> effect_of e - | E_sizeof _ | E_sizeof_internal _ -> no_effect + | E_sizeof _ | E_sizeof_internal _ | E_constraint _ -> no_effect | E_assert (c,m) -> no_effect | E_comment _ | E_comment_struc _ -> no_effect | E_internal_cast (_,e) -> effect_of e @@ -157,6 +157,8 @@ let fix_eff_lexp (LEXP_aux (lexp,((l,_) as annot))) = match snd annot with | LEXP_id _ -> no_effect | LEXP_cast _ -> no_effect | LEXP_memory (_,es) -> union_eff_exps es + | LEXP_tup les -> + List.fold_left (fun eff le -> union_effects eff (effect_of_lexp le)) no_effect les | LEXP_vector (lexp,e) -> union_effects (effect_of_lexp lexp) (effect_of e) | LEXP_vector_range (lexp,e1,e2) -> union_effects (effect_of_lexp lexp) @@ -188,7 +190,8 @@ let fix_eff_opt_default (Def_val_aux (opt_default,((l,_) as annot))) = match snd let fix_eff_pexp (Pat_aux (pexp,((l,_) as annot))) = match snd annot with | Some (env, typ, eff) -> let effsum = union_effects eff (match pexp with - | Pat_exp (_,e) -> effect_of e) in + | Pat_exp (_,e) -> effect_of e + | Pat_when (_,e,e') -> union_effects (effect_of e) (effect_of e')) in Pat_aux (pexp, (l, Some (env, typ, effsum))) | None -> Pat_aux (pexp, (l, None)) @@ -396,11 +399,13 @@ let rewrite_exp rewriters (E_aux (exp,(l,annot))) = (List.map (fun (FE_aux(FE_Fexp(id,e),fannot)) -> FE_aux(FE_Fexp(id,rewrite e),fannot)) fexps, bool), fannot)))) | E_field(exp,id) -> rewrap (E_field(rewrite exp,id)) - | E_case (exp ,pexps) -> - rewrap (E_case (rewrite exp, - (List.map - (fun (Pat_aux (Pat_exp(p,e),pannot)) -> - Pat_aux (Pat_exp(rewriters.rewrite_pat rewriters p,rewrite e),pannot)) pexps))) + | E_case (exp,pexps) -> + let rewrite_pexp = function + | (Pat_aux (Pat_exp(p, e), pannot)) -> + Pat_aux (Pat_exp(rewriters.rewrite_pat rewriters p, rewrite e), pannot) + | (Pat_aux (Pat_when(p, e, e'), pannot)) -> + Pat_aux (Pat_when(rewriters.rewrite_pat rewriters p, rewrite e, rewrite e'), pannot) in + rewrap (E_case (rewrite exp, List.map rewrite_pexp pexps)) | E_let (letbind,body) -> rewrap (E_let(rewriters.rewrite_let rewriters letbind,rewrite body)) | E_assign (lexp,exp) -> rewrap (E_assign(rewriters.rewrite_lexp rewriters lexp,rewrite exp)) | E_sizeof n -> rewrap (E_sizeof n) @@ -615,6 +620,7 @@ type ('a,'pat,'pat_aux,'fpat,'fpat_aux) pat_alg = ; p_vector_concat : 'pat list -> 'pat_aux ; p_tup : 'pat list -> 'pat_aux ; p_list : 'pat list -> 'pat_aux + ; p_cons : 'pat * 'pat -> 'pat_aux ; p_aux : 'pat_aux * 'a annot -> 'pat ; fP_aux : 'fpat_aux * 'a annot -> 'fpat ; fP_Fpat : id * 'pat -> 'fpat_aux @@ -634,6 +640,7 @@ let rec fold_pat_aux (alg : ('a,'pat,'pat_aux,'fpat,'fpat_aux) pat_alg) : 'a pat | P_vector_concat ps -> alg.p_vector_concat (List.map (fold_pat alg) ps) | P_tup ps -> alg.p_tup (List.map (fold_pat alg) ps) | P_list ps -> alg.p_list (List.map (fold_pat alg) ps) + | P_cons (ph,pt) -> alg.p_cons (fold_pat alg ph, fold_pat alg pt) and fold_pat (alg : ('a,'pat,'pat_aux,'fpat,'fpat_aux) pat_alg) : 'a pat -> 'pat = @@ -660,6 +667,7 @@ let id_pat_alg : ('a,'a pat, 'a pat_aux, 'a fpat, 'a fpat_aux) pat_alg = ; p_vector_concat = (fun ps -> P_vector_concat ps) ; p_tup = (fun ps -> P_tup ps) ; p_list = (fun ps -> P_list ps) + ; p_cons = (fun (ph,pt) -> P_cons (ph,pt)) ; p_aux = (fun (pat,annot) -> P_aux (pat,annot)) ; fP_aux = (fun (fpat,annot) -> FP_aux (fpat,annot)) ; fP_Fpat = (fun (id,pat) -> FP_Fpat (id,pat)) @@ -700,6 +708,8 @@ type ('a,'exp,'exp_aux,'lexp,'lexp_aux,'fexp,'fexp_aux,'fexps,'fexps_aux, ; e_internal_cast : 'a annot * 'exp -> 'exp_aux ; e_internal_exp : 'a annot -> 'exp_aux ; e_internal_exp_user : 'a annot * 'a annot -> 'exp_aux + ; e_comment : string -> 'exp_aux + ; e_comment_struc : 'exp -> 'exp_aux ; e_internal_let : 'lexp * 'exp * 'exp -> 'exp_aux ; e_internal_plet : 'pat * 'exp * 'exp -> 'exp_aux ; e_internal_return : 'exp -> 'exp_aux @@ -720,6 +730,7 @@ type ('a,'exp,'exp_aux,'lexp,'lexp_aux,'fexp,'fexp_aux,'fexps,'fexps_aux, ; def_val_dec : 'exp -> 'opt_default_aux ; def_val_aux : 'opt_default_aux * 'a annot -> 'opt_default ; pat_exp : 'pat * 'exp -> 'pexp_aux + ; pat_when : 'pat * 'exp * 'exp -> 'pexp_aux ; pat_aux : 'pexp_aux * 'a annot -> 'pexp ; lB_val_explicit : typschm * 'pat * 'exp -> 'letbind_aux ; lB_val_implicit : 'pat * 'exp -> 'letbind_aux @@ -759,12 +770,18 @@ let rec fold_exp_aux alg = function | E_let (letbind,e) -> alg.e_let (fold_letbind alg letbind, fold_exp alg e) | E_assign (lexp,e) -> alg.e_assign (fold_lexp alg lexp, fold_exp alg e) | E_sizeof nexp -> alg.e_sizeof nexp + | E_constraint nc -> raise (Reporting_basic.err_unreachable (Parse_ast.Unknown) + "E_constraint encountered during rewriting") | E_exit e -> alg.e_exit (fold_exp alg e) | E_return e -> alg.e_return (fold_exp alg e) | E_assert(e1,e2) -> alg.e_assert (fold_exp alg e1, fold_exp alg e2) | E_internal_cast (annot,e) -> alg.e_internal_cast (annot, fold_exp alg e) | E_internal_exp annot -> alg.e_internal_exp annot + | E_sizeof_internal a -> raise (Reporting_basic.err_unreachable (Parse_ast.Unknown) + "E_sizeof_internal encountered during rewriting") | E_internal_exp_user (annot1,annot2) -> alg.e_internal_exp_user (annot1,annot2) + | E_comment c -> alg.e_comment c + | E_comment_struc e -> alg.e_comment_struc (fold_exp alg e) | E_internal_let (lexp,e1,e2) -> alg.e_internal_let (fold_lexp alg lexp, fold_exp alg e1, fold_exp alg e2) | E_internal_plet (pat,e1,e2) -> @@ -774,6 +791,7 @@ and fold_exp alg (E_aux (exp_aux,annot)) = alg.e_aux (fold_exp_aux alg exp_aux, and fold_lexp_aux alg = function | LEXP_id id -> alg.lEXP_id id | LEXP_memory (id,es) -> alg.lEXP_memory (id, List.map (fold_exp alg) es) + | LEXP_tup les -> alg.lEXP_tup (List.map (fold_lexp alg) les) | LEXP_cast (typ,id) -> alg.lEXP_cast (typ,id) | LEXP_vector (lexp,e) -> alg.lEXP_vector (fold_lexp alg lexp, fold_exp alg e) | LEXP_vector_range (lexp,e1,e2) -> @@ -790,7 +808,9 @@ and fold_opt_default_aux alg = function | Def_val_dec e -> alg.def_val_dec (fold_exp alg e) and fold_opt_default alg (Def_val_aux (opt_default_aux,annot)) = alg.def_val_aux (fold_opt_default_aux alg opt_default_aux, annot) -and fold_pexp_aux alg (Pat_exp (pat,e)) = alg.pat_exp (fold_pat alg.pat_alg pat, fold_exp alg e) +and fold_pexp_aux alg = function + | Pat_exp (pat,e) -> alg.pat_exp (fold_pat alg.pat_alg pat, fold_exp alg e) + | Pat_when (pat,e,e') -> alg.pat_when (fold_pat alg.pat_alg pat, fold_exp alg e, fold_exp alg e') and fold_pexp alg (Pat_aux (pexp_aux,annot)) = alg.pat_aux (fold_pexp_aux alg pexp_aux, annot) and fold_letbind_aux alg = function | LB_val_explicit (t,pat,e) -> alg.lB_val_explicit (t,fold_pat alg.pat_alg pat, fold_exp alg e) @@ -830,6 +850,8 @@ let id_exp_alg = ; e_internal_cast = (fun (a,e1) -> E_internal_cast (a,e1)) ; e_internal_exp = (fun a -> E_internal_exp a) ; e_internal_exp_user = (fun (a1,a2) -> E_internal_exp_user (a1,a2)) + ; e_comment = (fun c -> E_comment c) + ; e_comment_struc = (fun e -> E_comment_struc e) ; e_internal_let = (fun (lexp, e2, e3) -> E_internal_let (lexp,e2,e3)) ; e_internal_plet = (fun (pat, e1, e2) -> E_internal_plet (pat,e1,e2)) ; e_internal_return = (fun e -> E_internal_return e) @@ -850,6 +872,7 @@ let id_exp_alg = ; def_val_dec = (fun e -> Def_val_dec e) ; def_val_aux = (fun (defval,aux) -> Def_val_aux (defval,aux)) ; pat_exp = (fun (pat,e) -> (Pat_exp (pat,e))) + ; pat_when = (fun (pat,e,e') -> (Pat_when (pat,e,e'))) ; pat_aux = (fun (pexp,a) -> (Pat_aux (pexp,a))) ; lB_val_explicit = (fun (typ,pat,e) -> LB_val_explicit (typ,pat,e)) ; lB_val_implicit = (fun (pat,e) -> LB_val_implicit (pat,e)) @@ -880,6 +903,7 @@ let compute_pat_alg bot join = ; p_vector_concat = split_join (fun ps -> P_vector_concat ps) ; p_tup = split_join (fun ps -> P_tup ps) ; p_list = split_join (fun ps -> P_list ps) + ; p_cons = (fun ((vh,ph),(vt,pt)) -> (join vh vt, P_cons (ph,pt))) ; p_aux = (fun ((v,pat),annot) -> (v, P_aux (pat,annot))) ; fP_aux = (fun ((v,fpat),annot) -> (v, FP_aux (fpat,annot))) ; fP_Fpat = (fun (id,(v,pat)) -> (v, FP_Fpat (id,pat))) @@ -926,6 +950,8 @@ let compute_exp_alg bot join = ; e_internal_cast = (fun (a,(v1,e1)) -> (v1, E_internal_cast (a,e1))) ; e_internal_exp = (fun a -> (bot, E_internal_exp a)) ; e_internal_exp_user = (fun (a1,a2) -> (bot, E_internal_exp_user (a1,a2))) + ; e_comment = (fun c -> (bot, E_comment c)) + ; e_comment_struc = (fun (v,e) -> (bot, E_comment_struc e)) (* ignore value by default, since it is comes from a comment *) ; e_internal_let = (fun ((vl, lexp), (v2,e2), (v3,e3)) -> (join_list [vl;v2;v3], E_internal_let (lexp,e2,e3))) ; e_internal_plet = (fun ((vp,pat), (v1,e1), (v2,e2)) -> @@ -935,7 +961,9 @@ let compute_exp_alg bot join = ; lEXP_id = (fun id -> (bot, LEXP_id id)) ; lEXP_memory = (fun (id,es) -> split_join (fun es -> LEXP_memory (id,es)) es) ; lEXP_cast = (fun (typ,id) -> (bot, LEXP_cast (typ,id))) - ; lEXP_tup = split_join (fun tups -> LEXP_tup tups) + ; lEXP_tup = (fun ls -> + let (vs,ls) = List.split ls in + (join_list vs, LEXP_tup ls)) ; lEXP_vector = (fun ((vl,lexp),(v2,e2)) -> (join vl v2, LEXP_vector (lexp,e2))) ; lEXP_vector_range = (fun ((vl,lexp),(v2,e2),(v3,e3)) -> (join_list [vl;v2;v3], LEXP_vector_range (lexp,e2,e3))) @@ -951,6 +979,7 @@ let compute_exp_alg bot join = ; def_val_dec = (fun (v,e) -> (v, Def_val_dec e)) ; def_val_aux = (fun ((v,defval),aux) -> (v, Def_val_aux (defval,aux))) ; pat_exp = (fun ((vp,pat),(v,e)) -> (join vp v, Pat_exp (pat,e))) + ; pat_when = (fun ((vp,pat),(v,e),(v',e')) -> (join_list [vp;v;v'], Pat_when (pat,e,e'))) ; pat_aux = (fun ((v,pexp),a) -> (v, Pat_aux (pexp,a))) ; lB_val_explicit = (fun (typ,(vp,pat),(v,e)) -> (join vp v, LB_val_explicit (typ,pat,e))) ; lB_val_implicit = (fun ((vp,pat),(v,e)) -> (join vp v, LB_val_implicit (pat,e))) @@ -986,11 +1015,15 @@ let rewrite_sizeof (Defs defs) = when string_of_id atom = "atom" -> [nexp, E_id id] | Typ_app (vector, _) when string_of_id vector = "vector" -> - let (_,len,_,_) = vector_typ_args_of typ_aux in - let exp = E_app - (Id_aux (Id "length", Parse_ast.Generated l), - [E_aux (E_id id, annot)]) in - [len, exp] + let id_length = Id_aux (Id "length", Parse_ast.Generated l) in + (try + (match Env.get_val_spec id_length (env_of_annot annot) with + | _ -> + let (_,len,_,_) = vector_typ_args_of typ_aux in + let exp = E_app (id_length, [E_aux (E_id id, annot)]) in + [len, exp]) + with + | _ -> []) | _ -> []) | _ -> [] in (v @ v', P_aux (pat,annot)))} pat) in @@ -1166,6 +1199,7 @@ let remove_vector_concat_pat pat = ; p_vector_concat = (fun ps -> P_vector_concat (List.map (fun p -> p false) ps)) ; p_tup = (fun ps -> P_tup (List.map (fun p -> p false) ps)) ; p_list = (fun ps -> P_list (List.map (fun p -> p false) ps)) + ; p_cons = (fun (p,ps) -> P_cons (p false, ps false)) ; p_aux = (fun (pat,((l,_) as annot)) contained_in_p_as -> match pat with @@ -1218,8 +1252,8 @@ let remove_vector_concat_pat pat = (* build a let-expression of the form "let child = root[i..j] in body" *) let letbind_vec typ_opt (rootid,rannot) (child,cannot) (i,j) = let (l,_) = cannot in - let (Id_aux (Id rootname,_)) = rootid in - let (Id_aux (Id childname,_)) = child in + let rootname = string_of_id rootid in + let childname = string_of_id child in let root = E_aux (E_id rootid, rannot) in let index_i = simple_num l i in @@ -1248,38 +1282,29 @@ let remove_vector_concat_pat pat = let rec aux typ_opt (pos,pat_acc,decl_acc) (P_aux (p,cannot),is_last) = let ctyp = Env.base_typ_of (env_of_annot cannot) (typ_of_annot cannot) in let (_,length,ord,_) = vector_typ_args_of ctyp in - (*)| (_,length,ord,_) ->*) - let (pos',index_j) = match length with - | Nexp_aux (Nexp_constant i,_) -> - if is_order_inc ord then (pos+i, pos+i-1) - else (pos-i, pos-i+1) - | Nexp_aux (_,l) -> - if is_last then (pos,last_idx) - else - raise - (Reporting_basic.err_unreachable - l ("unname_vector_concat_elements: vector of unspecified length in vector-concat pattern")) in - (match p with - (* if we see a named vector pattern, remove the name and remember to - declare it later *) - | P_as (P_aux (p,cannot),cname) -> - let (lb,decl,info) = letbind_vec typ_opt (rootid,rannot) (cname,cannot) (pos,index_j) in - (pos', pat_acc @ [P_aux (p,cannot)], decl_acc @ [((lb,decl),info)]) - (* if we see a P_id variable, remember to declare it later *) - | P_id cname -> - let (lb,decl,info) = letbind_vec typ_opt (rootid,rannot) (cname,cannot) (pos,index_j) in - (pos', pat_acc @ [P_aux (P_id cname,cannot)], decl_acc @ [((lb,decl),info)]) - | P_typ (typ, pat) -> aux (Some typ) (pos,pat_acc,decl_acc) (pat, is_last) - (* normal vector patterns are fine *) - | _ -> (pos', pat_acc @ [P_aux (p,cannot)],decl_acc) ) - (* non-vector patterns aren't *) - (*)| _ -> - raise - (Reporting_basic.err_unreachable - (fst cannot) - ("unname_vector_concat_elements: Non-vector in vector-concat pattern:" ^ - string_of_typ (typ_of_annot cannot)) - )*) in + let (pos',index_j) = match length with + | Nexp_aux (Nexp_constant i,_) -> + if is_order_inc ord then (pos+i, pos+i-1) + else (pos-i, pos-i+1) + | Nexp_aux (_,l) -> + if is_last then (pos,last_idx) + else + raise + (Reporting_basic.err_unreachable + l ("unname_vector_concat_elements: vector of unspecified length in vector-concat pattern")) in + (match p with + (* if we see a named vector pattern, remove the name and remember to + declare it later *) + | P_as (P_aux (p,cannot),cname) -> + let (lb,decl,info) = letbind_vec typ_opt (rootid,rannot) (cname,cannot) (pos,index_j) in + (pos', pat_acc @ [P_aux (p,cannot)], decl_acc @ [((lb,decl),info)]) + (* if we see a P_id variable, remember to declare it later *) + | P_id cname -> + let (lb,decl,info) = letbind_vec typ_opt (rootid,rannot) (cname,cannot) (pos,index_j) in + (pos', pat_acc @ [P_aux (P_id cname,cannot)], decl_acc @ [((lb,decl),info)]) + | P_typ (typ, pat) -> aux (Some typ) (pos,pat_acc,decl_acc) (pat, is_last) + (* normal vector patterns are fine *) + | _ -> (pos', pat_acc @ [P_aux (p,cannot)],decl_acc)) in let pats_tagged = tag_last pats in let (_,pats',decls') = List.fold_left (aux None) (start,[],[]) pats_tagged in @@ -1309,6 +1334,7 @@ let remove_vector_concat_pat pat = (P_tup ps,List.flatten decls)) ; p_list = (fun ps -> let (ps,decls) = List.split ps in (P_list ps,List.flatten decls)) + ; p_cons = (fun ((p,decls),(p',decls')) -> (P_cons (p,p'), decls @ decls')) ; p_aux = (fun ((pat,decls),annot) -> p_aux ((pat,decls),annot)) ; fP_aux = (fun ((fpat,decls),annot) -> (FP_aux (fpat,annot),decls)) ; fP_Fpat = (fun (id,(pat,decls)) -> (FP_Fpat (id,pat),decls)) @@ -1417,9 +1443,13 @@ let rewrite_exp_remove_vector_concat_pat rewriters (E_aux (exp,(l,annot)) as ful let rewrite_base = rewrite_exp rewriters in match exp with | E_case (e,ps) -> - let aux (Pat_aux (Pat_exp (pat,body),annot')) = + let aux = function + | (Pat_aux (Pat_exp (pat,body),annot')) -> let (pat,_,decls) = remove_vector_concat_pat pat in - Pat_aux (Pat_exp (pat, decls (rewrite_rec body)),annot') in + Pat_aux (Pat_exp (pat, decls (rewrite_rec body)),annot') + | (Pat_aux (Pat_when (pat,guard,body),annot')) -> + let (pat,_,decls) = remove_vector_concat_pat pat in + Pat_aux (Pat_when (pat, decls (rewrite_rec guard), decls (rewrite_rec body)),annot') in rewrap (E_case (rewrite_rec e, List.map aux ps)) | E_let (LB_aux (LB_val_explicit (typ,pat,v),annot'),body) -> let (pat,_,decls) = remove_vector_concat_pat pat in @@ -1462,6 +1492,177 @@ let rewrite_defs_remove_vector_concat (Defs defs) = | d -> [d] in Defs (List.flatten (List.map rewrite_def defs)) +(* A few helper functions for rewriting guarded pattern clauses. + Used both by the rewriting of P_when and separately by the rewriting of + bitvectors in parameter patterns of function clauses *) + +let remove_wildcards pre (P_aux (_,(l,_)) as pat) = + fold_pat + {id_pat_alg with + p_aux = function + | (P_wild,(l,annot)) -> P_aux (P_id (fresh_id pre l),(l,annot)) + | (p,annot) -> P_aux (p,annot) } + pat + +(* Check if one pattern subsumes the other, and if so, calculate a + substitution of variables that are used in the same position. + TODO: Check somewhere that there are no variable clashes (the same variable + name used in different positions of the patterns) + *) +let rec subsumes_pat (P_aux (p1,annot1) as pat1) (P_aux (p2,annot2) as pat2) = + let rewrap p = P_aux (p,annot1) in + let subsumes_list s pats1 pats2 = + if List.length pats1 = List.length pats2 + then + let subs = List.map2 s pats1 pats2 in + List.fold_right + (fun p acc -> match p, acc with + | Some subst, Some substs -> Some (subst @ substs) + | _ -> None) + subs (Some []) + else None in + match p1, p2 with + | P_lit (L_aux (lit1,_)), P_lit (L_aux (lit2,_)) -> + if lit1 = lit2 then Some [] else None + | P_as (pat1,_), _ -> subsumes_pat pat1 pat2 + | _, P_as (pat2,_) -> subsumes_pat pat1 pat2 + | P_typ (_,pat1), _ -> subsumes_pat pat1 pat2 + | _, P_typ (_,pat2) -> subsumes_pat pat1 pat2 + | P_id (Id_aux (id1,_) as aid1), P_id (Id_aux (id2,_) as aid2) -> + if id1 = id2 then Some [] + else if Env.lookup_id aid1 (env_of_annot annot1) = Unbound && + Env.lookup_id aid2 (env_of_annot annot2) = Unbound + then Some [(id2,id1)] else None + | P_id id1, _ -> + if Env.lookup_id id1 (env_of_annot annot1) = Unbound then Some [] else None + | P_wild, _ -> Some [] + | P_app (Id_aux (id1,l1),args1), P_app (Id_aux (id2,_),args2) -> + if id1 = id2 then subsumes_list subsumes_pat args1 args2 else None + | P_record (fps1,b1), P_record (fps2,b2) -> + if b1 = b2 then subsumes_list subsumes_fpat fps1 fps2 else None + | P_vector pats1, P_vector pats2 + | P_vector_concat pats1, P_vector_concat pats2 + | P_tup pats1, P_tup pats2 + | P_list pats1, P_list pats2 -> + subsumes_list subsumes_pat pats1 pats2 + | P_list (pat1 :: pats1), P_cons _ -> + subsumes_pat (rewrap (P_cons (pat1, rewrap (P_list pats1)))) pat2 + | P_cons _, P_list (pat2 :: pats2)-> + subsumes_pat pat1 (rewrap (P_cons (pat2, rewrap (P_list pats2)))) + | P_cons (pat1, pats1), P_cons (pat2, pats2) -> + (match subsumes_pat pat1 pat2, subsumes_pat pats1 pats2 with + | Some substs1, Some substs2 -> Some (substs1 @ substs2) + | _ -> None) + | P_vector_indexed ips1, P_vector_indexed ips2 -> + let (is1,ps1) = List.split ips1 in + let (is2,ps2) = List.split ips2 in + if is1 = is2 then subsumes_list subsumes_pat ps1 ps2 else None + | _ -> None +and subsumes_fpat (FP_aux (FP_Fpat (id1,pat1),_)) (FP_aux (FP_Fpat (id2,pat2),_)) = + if id1 = id2 then subsumes_pat pat1 pat2 else None + +let equiv_pats pat1 pat2 = + match subsumes_pat pat1 pat2, subsumes_pat pat2 pat1 with + | Some _, Some _ -> true + | _, _ -> false + +let subst_id_pat pat (id1,id2) = + let p_id (Id_aux (id,l)) = (if id = id1 then P_id (Id_aux (id2,l)) else P_id (Id_aux (id,l))) in + fold_pat {id_pat_alg with p_id = p_id} pat + +let subst_id_exp exp (id1,id2) = + (* TODO Don't substitute bound occurrences inside let expressions etc *) + let e_id (Id_aux (id,l)) = (if id = id1 then E_id (Id_aux (id2,l)) else E_id (Id_aux (id,l))) in + fold_exp {id_exp_alg with e_id = e_id} exp + +let rec pat_to_exp (P_aux (pat,(l,annot))) = + let rewrap e = E_aux (e,(l,annot)) in + match pat with + | P_lit lit -> rewrap (E_lit lit) + | P_wild -> raise (Reporting_basic.err_unreachable l + "pat_to_exp given wildcard pattern") + | P_as (pat,id) -> rewrap (E_id id) + | P_typ (_,pat) -> pat_to_exp pat + | P_id id -> rewrap (E_id id) + | P_app (id,pats) -> rewrap (E_app (id, List.map pat_to_exp pats)) + | P_record (fpats,b) -> + rewrap (E_record (FES_aux (FES_Fexps (List.map fpat_to_fexp fpats,b),(l,annot)))) + | P_vector pats -> rewrap (E_vector (List.map pat_to_exp pats)) + | P_vector_concat pats -> raise (Reporting_basic.err_unreachable l + "pat_to_exp not implemented for P_vector_concat") + (* We assume that vector concatenation patterns have been transformed + away already *) + | P_tup pats -> rewrap (E_tuple (List.map pat_to_exp pats)) + | P_list pats -> rewrap (E_list (List.map pat_to_exp pats)) + | P_cons (p,ps) -> rewrap (E_cons (pat_to_exp p, pat_to_exp ps)) + | P_vector_indexed ipats -> raise (Reporting_basic.err_unreachable l + "pat_to_exp not implemented for P_vector_indexed") (* TODO *) +and fpat_to_fexp (FP_aux (FP_Fpat (id,pat),(l,annot))) = + FE_aux (FE_Fexp (id, pat_to_exp pat),(l,annot)) + +let case_exp e t cs = + let pexp (pat,body,annot) = Pat_aux (Pat_exp (pat,body),annot) in + let ps = List.map pexp cs in + (* let efr = union_effs (List.map effect_of_pexp ps) in *) + fix_eff_exp (E_aux (E_case (e,ps), (get_loc_exp e, Some (env_of e, t, no_effect)))) + +let rewrite_guarded_clauses l cs = + let rec group clauses = + let add_clause (pat,cls,annot) c = (pat,cls @ [c],annot) in + let rec group_aux current acc = (function + | ((pat,guard,body,annot) as c) :: cs -> + let (current_pat,_,_) = current in + (match subsumes_pat current_pat pat with + | Some substs -> + let pat' = List.fold_left subst_id_pat pat substs in + let guard' = (match guard with + | Some exp -> Some (List.fold_left subst_id_exp exp substs) + | None -> None) in + let body' = List.fold_left subst_id_exp body substs in + let c' = (pat',guard',body',annot) in + group_aux (add_clause current c') acc cs + | None -> + let pat = remove_wildcards "g__" pat in + group_aux (pat,[c],annot) (acc @ [current]) cs) + | [] -> acc @ [current]) in + let groups = match clauses with + | ((pat,guard,body,annot) as c) :: cs -> + group_aux (remove_wildcards "g__" pat, [c], annot) [] cs + | _ -> + raise (Reporting_basic.err_unreachable l + "group given empty list in rewrite_guarded_clauses") in + List.map (fun cs -> if_pexp cs) groups + and if_pexp (pat,cs,annot) = (match cs with + | c :: _ -> + (* fix_eff_pexp (pexp *) + let body = if_exp pat cs in + let pexp = fix_eff_pexp (Pat_aux (Pat_exp (pat,body),annot)) in + let (Pat_aux (_,annot)) = pexp in + (pat, body, annot) + | [] -> + raise (Reporting_basic.err_unreachable l + "if_pexp given empty list in rewrite_guarded_clauses")) + and if_exp current_pat = (function + | (pat,guard,body,annot) :: ((pat',guard',body',annot') as c') :: cs -> + (match guard with + | Some exp -> + let else_exp = + if equiv_pats current_pat pat' + then if_exp current_pat (c' :: cs) + else case_exp (pat_to_exp current_pat) (typ_of body') (group (c' :: cs)) in + fix_eff_exp (E_aux (E_if (exp,body,else_exp), simple_annot (fst annot) (typ_of body))) + | None -> body) + | [(pat,guard,body,annot)] -> body + | [] -> + raise (Reporting_basic.err_unreachable l + "if_exp given empty list in rewrite_guarded_clauses")) in + group cs + +let bitwise_and_exp exp1 exp2 = + let (E_aux (_,(l,_))) = exp1 in + let andid = Id_aux (Id "bool_and", Parse_ast.Generated l) in + E_aux (E_app(andid,[exp1;exp2]), simple_annot l bool_typ) + let rec contains_bitvector_pat (P_aux (pat,annot)) = match pat with | P_lit _ | P_wild | P_id _ -> false | P_as (pat,_) | P_typ (_,pat) -> contains_bitvector_pat pat @@ -1470,9 +1671,16 @@ let rec contains_bitvector_pat (P_aux (pat,annot)) = match pat with is_bitvector_typ typ | P_app (_,pats) | P_tup pats | P_list pats -> List.exists contains_bitvector_pat pats +| P_cons (p,ps) -> contains_bitvector_pat p || contains_bitvector_pat ps | P_record (fpats,_) -> List.exists (fun (FP_aux (FP_Fpat (_,pat),_)) -> contains_bitvector_pat pat) fpats +let contains_bitvector_pexp = function +| Pat_aux (Pat_exp (pat,_),_) | Pat_aux (Pat_when (pat,_,_),_) -> + contains_bitvector_pat pat + +(* Rewrite bitvector patterns to guarded patterns *) + let remove_bitvector_pat pat = (* first introduce names for bitvector patterns *) @@ -1489,6 +1697,7 @@ let remove_bitvector_pat pat = ; p_vector_concat = (fun ps -> P_vector_concat (List.map (fun p -> p false) ps)) ; p_tup = (fun ps -> P_tup (List.map (fun p -> p false) ps)) ; p_list = (fun ps -> P_list (List.map (fun p -> p false) ps)) + ; p_cons = (fun (p,ps) -> P_cons (p false, ps false)) ; p_aux = (fun (pat,annot) contained_in_p_as -> let env = env_of_annot annot in @@ -1557,14 +1766,8 @@ let remove_bitvector_pat pat = E_aux (E_let (letbind,body), (Parse_ast.Generated l, bannot))) in (letexp, letbind) in - (* Helper functions for composing guards *) - let bitwise_and exp1 exp2 = - let (E_aux (_,(l,_))) = exp1 in - let andid = Id_aux (Id "bool_and", Parse_ast.Generated l) in - E_aux (E_app(andid,[exp1;exp2]), simple_annot l bool_typ) in - let compose_guards guards = - List.fold_right (Util.option_binop bitwise_and) guards None in + List.fold_right (Util.option_binop bitwise_and_exp) guards None in let flatten_guards_decls gd = let (guards,decls,letbinds) = Util.split3 gd in @@ -1651,6 +1854,8 @@ let remove_bitvector_pat pat = (P_tup ps, flatten_guards_decls gdls)) ; p_list = (fun ps -> let (ps,gdls) = List.split ps in (P_list ps, flatten_guards_decls gdls)) + ; p_cons = (fun ((p,gdls),(p',gdls')) -> + (P_cons (p,p'), flatten_guards_decls [gdls;gdls'])) ; p_aux = (fun ((pat,gdls),annot) -> let env = env_of_annot annot in let t = Env.base_typ_of env (typ_of_annot annot) in @@ -1665,183 +1870,27 @@ let remove_bitvector_pat pat = } in fold_pat guard_bitvector_pat pat -let remove_wildcards pre (P_aux (_,(l,_)) as pat) = - fold_pat - {id_pat_alg with - p_aux = function - | (P_wild,(l,annot)) -> P_aux (P_id (fresh_id pre l),(l,annot)) - | (p,annot) -> P_aux (p,annot) } - pat - -(* Check if one pattern subsumes the other, and if so, calculate a - substitution of variables that are used in the same position. - TODO: Check somewhere that there are no variable clashes (the same variable - name used in different positions of the patterns) - *) -let rec subsumes_pat (P_aux (p1,annot1) as pat1) (P_aux (p2,annot2) as pat2) = - let rewrap p = P_aux (p,annot1) in - let subsumes_list s pats1 pats2 = - if List.length pats1 = List.length pats2 - then - let subs = List.map2 s pats1 pats2 in - List.fold_right - (fun p acc -> match p, acc with - | Some subst, Some substs -> Some (subst @ substs) - | _ -> None) - subs (Some []) - else None in - match p1, p2 with - | P_lit (L_aux (lit1,_)), P_lit (L_aux (lit2,_)) -> - if lit1 = lit2 then Some [] else None - | P_as (pat1,_), _ -> subsumes_pat pat1 pat2 - | _, P_as (pat2,_) -> subsumes_pat pat1 pat2 - | P_typ (_,pat1), _ -> subsumes_pat pat1 pat2 - | _, P_typ (_,pat2) -> subsumes_pat pat1 pat2 - | P_id (Id_aux (id1,_) as aid1), P_id (Id_aux (id2,_) as aid2) -> - if id1 = id2 then Some [] - else if Env.lookup_id aid1 (env_of_annot annot1) = Unbound && - Env.lookup_id aid2 (env_of_annot annot2) = Unbound - then Some [(id2,id1)] else None - | P_id id1, _ -> - if Env.lookup_id id1 (env_of_annot annot1) = Unbound then Some [] else None - | P_wild, _ -> Some [] - | P_app (Id_aux (id1,l1),args1), P_app (Id_aux (id2,_),args2) -> - if id1 = id2 then subsumes_list subsumes_pat args1 args2 else None - | P_record (fps1,b1), P_record (fps2,b2) -> - if b1 = b2 then subsumes_list subsumes_fpat fps1 fps2 else None - | P_vector pats1, P_vector pats2 - | P_vector_concat pats1, P_vector_concat pats2 - | P_tup pats1, P_tup pats2 - | P_list pats1, P_list pats2 -> - subsumes_list subsumes_pat pats1 pats2 - | P_vector_indexed ips1, P_vector_indexed ips2 -> - let (is1,ps1) = List.split ips1 in - let (is2,ps2) = List.split ips2 in - if is1 = is2 then subsumes_list subsumes_pat ps1 ps2 else None - | _ -> None -and subsumes_fpat (FP_aux (FP_Fpat (id1,pat1),_)) (FP_aux (FP_Fpat (id2,pat2),_)) = - if id1 = id2 then subsumes_pat pat1 pat2 else None - -let equiv_pats pat1 pat2 = - match subsumes_pat pat1 pat2, subsumes_pat pat2 pat1 with - | Some _, Some _ -> true - | _, _ -> false - -let subst_id_pat pat (id1,id2) = - let p_id (Id_aux (id,l)) = (if id = id1 then P_id (Id_aux (id2,l)) else P_id (Id_aux (id,l))) in - fold_pat {id_pat_alg with p_id = p_id} pat - -let subst_id_exp exp (id1,id2) = - (* TODO Don't substitute bound occurrences inside let expressions etc *) - let e_id (Id_aux (id,l)) = (if id = id1 then E_id (Id_aux (id2,l)) else E_id (Id_aux (id,l))) in - fold_exp {id_exp_alg with e_id = e_id} exp - -let rec pat_to_exp (P_aux (pat,(l,annot))) = - let rewrap e = E_aux (e,(l,annot)) in - match pat with - | P_lit lit -> rewrap (E_lit lit) - | P_wild -> raise (Reporting_basic.err_unreachable l - "pat_to_exp given wildcard pattern") - | P_as (pat,id) -> rewrap (E_id id) - | P_typ (_,pat) -> pat_to_exp pat - | P_id id -> rewrap (E_id id) - | P_app (id,pats) -> rewrap (E_app (id, List.map pat_to_exp pats)) - | P_record (fpats,b) -> - rewrap (E_record (FES_aux (FES_Fexps (List.map fpat_to_fexp fpats,b),(l,annot)))) - | P_vector pats -> rewrap (E_vector (List.map pat_to_exp pats)) - | P_vector_concat pats -> raise (Reporting_basic.err_unreachable l - "pat_to_exp not implemented for P_vector_concat") - (* We assume that vector concatenation patterns have been transformed - away already *) - | P_tup pats -> rewrap (E_tuple (List.map pat_to_exp pats)) - | P_list pats -> rewrap (E_list (List.map pat_to_exp pats)) - | P_vector_indexed ipats -> raise (Reporting_basic.err_unreachable l - "pat_to_exp not implemented for P_vector_indexed") (* TODO *) -and fpat_to_fexp (FP_aux (FP_Fpat (id,pat),(l,annot))) = - FE_aux (FE_Fexp (id, pat_to_exp pat),(l,annot)) - -let case_exp e t cs = - let pexp (pat,body,annot) = Pat_aux (Pat_exp (pat,body),annot) in - let ps = List.map pexp cs in - (* let efr = union_effs (List.map effect_of_pexp ps) in *) - fix_eff_exp (E_aux (E_case (e,ps), (get_loc_exp e, Some (env_of e, t, no_effect)))) - -let rewrite_guarded_clauses l cs = - let rec group clauses = - let add_clause (pat,cls,annot) c = (pat,cls @ [c],annot) in - let rec group_aux current acc = (function - | ((pat,guard,body,annot) as c) :: cs -> - let (current_pat,_,_) = current in - (match subsumes_pat current_pat pat with - | Some substs -> - let pat' = List.fold_left subst_id_pat pat substs in - let guard' = (match guard with - | Some exp -> Some (List.fold_left subst_id_exp exp substs) - | None -> None) in - let body' = List.fold_left subst_id_exp body substs in - let c' = (pat',guard',body',annot) in - group_aux (add_clause current c') acc cs - | None -> - let pat = remove_wildcards "g__" pat in - group_aux (pat,[c],annot) (acc @ [current]) cs) - | [] -> acc @ [current]) in - let groups = match clauses with - | ((pat,guard,body,annot) as c) :: cs -> - group_aux (remove_wildcards "g__" pat, [c], annot) [] cs - | _ -> - raise (Reporting_basic.err_unreachable l - "group given empty list in rewrite_guarded_clauses") in - List.map (fun cs -> if_pexp cs) groups - and if_pexp (pat,cs,annot) = (match cs with - | c :: _ -> - (* fix_eff_pexp (pexp *) - let body = if_exp pat cs in - let pexp = fix_eff_pexp (Pat_aux (Pat_exp (pat,body),annot)) in - let (Pat_aux (Pat_exp (_,_),annot)) = pexp in - (pat, body, annot) - | [] -> - raise (Reporting_basic.err_unreachable l - "if_pexp given empty list in rewrite_guarded_clauses")) - and if_exp current_pat = (function - | (pat,guard,body,annot) :: ((pat',guard',body',annot') as c') :: cs -> - (match guard with - | Some exp -> - let else_exp = - if equiv_pats current_pat pat' - then if_exp current_pat (c' :: cs) - else case_exp (pat_to_exp current_pat) (typ_of body') (group (c' :: cs)) in - fix_eff_exp (E_aux (E_if (exp,body,else_exp), simple_annot (fst annot) (typ_of body))) - | None -> body) - | [(pat,guard,body,annot)] -> body - | [] -> - raise (Reporting_basic.err_unreachable l - "if_exp given empty list in rewrite_guarded_clauses")) in - group cs - let rewrite_exp_remove_bitvector_pat rewriters (E_aux (exp,(l,annot)) as full_exp) = let rewrap e = E_aux (e,(l,annot)) in let rewrite_rec = rewriters.rewrite_exp rewriters in let rewrite_base = rewrite_exp rewriters in match exp with | E_case (e,ps) - when List.exists (fun (Pat_aux (Pat_exp (pat,_),_)) -> contains_bitvector_pat pat) ps -> - let clause (Pat_aux (Pat_exp (pat,body),annot')) = - let (pat',(guard,decls,_)) = remove_bitvector_pat pat in + when List.exists contains_bitvector_pexp ps -> + let rewrite_pexp = function + | Pat_aux (Pat_exp (pat,body),annot') -> + let (pat',(guard',decls,_)) = remove_bitvector_pat pat in let body' = decls (rewrite_rec body) in - (pat',guard,body',annot') in - let clauses = rewrite_guarded_clauses l (List.map clause ps) in - if (effectful e) then - let e = rewrite_rec e in - let (E_aux (_,(el,eannot))) = e in - let pat_e' = fresh_id_pat "p__" (el,eannot) in - let exp_e' = pat_to_exp pat_e' in - (* let fresh = fresh_id "p__" el in - let exp_e' = E_aux (E_id fresh, gen_annot l (get_type e) pure_e) in - let pat_e' = P_aux (P_id fresh, gen_annot l (get_type e) pure_e) in *) - let letbind_e = LB_aux (LB_val_implicit (pat_e',e), (el,eannot)) in - let exp' = case_exp exp_e' (typ_of full_exp) clauses in - rewrap (E_let (letbind_e, exp')) - else case_exp e (typ_of full_exp) clauses + (match guard' with + | Some guard' -> Pat_aux (Pat_when (pat', guard', body'), annot') + | None -> Pat_aux (Pat_exp (pat', body'), annot')) + | Pat_aux (Pat_when (pat,guard,body),annot') -> + let (pat',(guard',decls,_)) = remove_bitvector_pat pat in + let body' = decls (rewrite_rec body) in + (match guard' with + | Some guard' -> Pat_aux (Pat_when (pat', bitwise_and_exp guard guard', body'), annot') + | None -> Pat_aux (Pat_when (pat', guard, body'), annot')) in + rewrap (E_case (e, List.map rewrite_pexp ps)) | E_let (LB_aux (LB_val_explicit (typ,pat,v),annot'),body) -> let (pat,(_,decls,_)) = remove_bitvector_pat pat in rewrap (E_let (LB_aux (LB_val_explicit (typ,pat,rewrite_rec v),annot'), @@ -1891,6 +1940,38 @@ let rewrite_defs_remove_bitvector_pats (Defs defs) = Defs (List.flatten (List.map rewrite_def defs)) +(* Remove pattern guards by rewriting them to if-expressions within the + pattern expression. Shares code with the rewriting of bitvector patterns. *) +let rewrite_exp_guarded_pats rewriters (E_aux (exp,(l,annot)) as full_exp) = + let rewrap e = E_aux (e,(l,annot)) in + let rewrite_rec = rewriters.rewrite_exp rewriters in + let rewrite_base = rewrite_exp rewriters in + let is_guarded_pexp = function + | Pat_aux (Pat_when (_,_,_),_) -> true + | _ -> false in + match exp with + | E_case (e,ps) + when List.exists is_guarded_pexp ps -> + let clause = function + | Pat_aux (Pat_exp (pat, body), annot) -> + (pat, None, rewrite_rec body, annot) + | Pat_aux (Pat_when (pat, guard, body), annot) -> + (pat, Some guard, rewrite_rec body, annot) in + let clauses = rewrite_guarded_clauses l (List.map clause ps) in + if (effectful e) then + let e = rewrite_rec e in + let (E_aux (_,(el,eannot))) = e in + let pat_e' = fresh_id_pat "p__" (el,eannot) in + let exp_e' = pat_to_exp pat_e' in + let letbind_e = LB_aux (LB_val_implicit (pat_e',e), (el,eannot)) in + let exp' = case_exp exp_e' (typ_of full_exp) clauses in + rewrap (E_let (letbind_e, exp')) + else case_exp e (typ_of full_exp) clauses + | _ -> rewrite_base full_exp + +let rewrite_defs_guarded_pats = + rewrite_defs_base { rewriters_base with rewrite_exp = rewrite_exp_guarded_pats } + (*Expects to be called after rewrite_defs; thus the following should not appear: internal_exp of any form lit vectors in patterns or expressions @@ -2145,8 +2226,11 @@ let rewrite_defs_letbind_effects = mapCont n_fexp fexps k and n_pexp (newreturn : bool) (pexp : 'a pexp) (k : 'a pexp -> 'a exp) : 'a exp = - let (Pat_aux (Pat_exp (pat,exp),annot)) = pexp in - k (fix_eff_pexp (Pat_aux (Pat_exp (pat,n_exp_term newreturn exp), annot))) + match pexp with + | Pat_aux (Pat_exp (pat,exp),annot) -> + k (fix_eff_pexp (Pat_aux (Pat_exp (pat,n_exp_term newreturn exp), annot))) + | Pat_aux (Pat_when (pat,guard,exp),annot) -> + k (fix_eff_pexp (Pat_aux (Pat_when (pat,n_exp_term newreturn guard,n_exp_term newreturn exp), annot))) and n_pexpL (newreturn : bool) (pexps : 'a pexp list) (k : 'a pexp list -> 'a exp) : 'a exp = mapCont (n_pexp newreturn) pexps k @@ -2181,6 +2265,9 @@ let rewrite_defs_letbind_effects = | LEXP_memory (id,es) -> n_exp_nameL es (fun es -> k (fix_eff_lexp (LEXP_aux (LEXP_memory (id,es),annot)))) + | LEXP_tup es -> + n_lexpL es (fun es -> + k (fix_eff_lexp (LEXP_aux (LEXP_tup es,annot)))) | LEXP_cast (typ,id) -> k (fix_eff_lexp (LEXP_aux (LEXP_cast (typ,id),annot))) | LEXP_vector (lexp,e) -> @@ -2196,6 +2283,9 @@ let rewrite_defs_letbind_effects = n_lexp lexp (fun lexp -> k (fix_eff_lexp (LEXP_aux (LEXP_field (lexp,id),annot)))) + and n_lexpL (lexps : 'a lexp list) (k : 'a lexp list -> 'a exp) : 'a exp = + mapCont n_lexp lexps k + and n_exp_term (newreturn : bool) (exp : 'a exp) : 'a exp = let (E_aux (_,(l,tannot))) = exp in let exp = @@ -2308,6 +2398,7 @@ let rewrite_defs_letbind_effects = rewrap (E_let (lb,n_exp body k))) | E_sizeof nexp -> k (rewrap (E_sizeof nexp)) + | E_constraint nc -> failwith "E_constraint should have been removed till now" | E_sizeof_internal annot -> k (rewrap (E_sizeof_internal annot)) | E_assign (lexp,exp1) -> @@ -2403,7 +2494,7 @@ let eqidtyp (id1,_) (id2,_) = let name2 = match id2 with Id_aux ((Id name | DeIid name),_) -> name in name1 = name2 -let find_updated_vars exp = +let find_updated_vars (E_aux (_,(l,_)) as exp) = let ( @@ ) (a,b) (a',b') = (a @ a',b @ b') in let lapp2 (l : (('a list * 'b list) list)) : ('a list * 'b list) = List.fold_left @@ -2445,8 +2536,14 @@ let find_updated_vars exp = ; e_internal_cast = (fun (_,e1) -> e1) ; e_internal_exp = (fun _ -> ([],[])) ; e_internal_exp_user = (fun _ -> ([],[])) + ; e_comment = (fun _ -> ([],[])) + ; e_comment_struc = (fun _ -> ([],[])) ; e_internal_let = - (fun (([id],acc),e2,e3) -> + (fun ((ids,acc),e2,e3) -> + let id = match ids with + | [] -> raise (Reporting_basic.err_unreachable l "E_internal_let found not introducing a variable") + | [id] -> id + | _ -> raise (Reporting_basic.err_unreachable l "E_internal_let found introducing more than one variable") in let (xs,ys) = ([id],[]) @@ acc @@ e2 @@ e3 in let ys = List.filter (fun id2 -> not (eqidtyp id id2)) ys in (xs,ys)) @@ -2475,6 +2572,7 @@ let find_updated_vars exp = ; def_val_dec = (fun e -> e) ; def_val_aux = (fun (defval,_) -> defval) ; pat_exp = (fun (_,e) -> e) + ; pat_when = (fun (_,_,e) -> e) ; pat_aux = (fun (pexp,_) -> pexp) ; lB_val_explicit = (fun (_,_,e) -> e) ; lB_val_implicit = (fun (_,e) -> e) @@ -2568,7 +2666,9 @@ let rec rewrite_var_updates ((E_aux (expaux,((l,_) as annot))) as exp) = | false, Ord_aux (Ord_inc,_) -> "foreach_inc" | false, Ord_aux (Ord_dec,_) -> "foreach_dec" | true, Ord_aux (Ord_inc,_) -> "foreachM_inc" - | true, Ord_aux (Ord_dec,_) -> "foreachM_dec" in + | true, Ord_aux (Ord_dec,_) -> "foreachM_dec" + | _ -> raise (Reporting_basic.err_unreachable el + "Could not determine foreach combinator") in let funcl = Id_aux (Id fname,Parse_ast.Generated el) in let loopvar = (* Don't bother with creating a range type annotation, since the @@ -2618,16 +2718,21 @@ let rec rewrite_var_updates ((E_aux (expaux,((l,_) as annot))) as exp) = | E_case (e1,ps) -> (* after rewrite_defs_letbind_effects e1 needs no rewriting *) let vars = - let f acc (Pat_aux (Pat_exp (_,e),_)) = acc @ find_updated_vars e in + let f acc (Pat_aux ((Pat_exp (_,e)|Pat_when (_,_,e)),_)) = + acc @ find_updated_vars e in List.map (fun (var,(l,t)) -> E_aux (E_id var,(l,t))) (dedup eqidtyp (List.fold_left f [] ps)) in if vars = [] then - let ps = List.map (fun (Pat_aux (Pat_exp (p,e),a)) -> Pat_aux (Pat_exp (p,rewrite_var_updates e),a)) ps in + let ps = List.map (function + | Pat_aux (Pat_exp (p,e),a) -> + Pat_aux (Pat_exp (p,rewrite_var_updates e),a) + | Pat_aux (Pat_when (p,g,e),a) -> + Pat_aux (Pat_when (p,g,rewrite_var_updates e),a)) ps in Same_vars (E_aux (E_case (e1,ps),annot)) else let vartuple = mktup el vars in let typ = - let (Pat_aux (Pat_exp (_,first),_)) = List.hd ps in + let (Pat_aux ((Pat_exp (_,first)|Pat_when (_,_,first)),_)) = List.hd ps in typ_of first in let (ps,typ,effs) = let f (acc,typ,effs) (Pat_aux (Pat_exp (p,e),pannot)) = @@ -2856,9 +2961,10 @@ let rewrite_defs_remove_e_assign = let rewrite_defs_lem = top_sort_defs >> + rewrite_sizeof >> rewrite_defs_remove_vector_concat >> rewrite_defs_remove_bitvector_pats >> - rewrite_sizeof >> + rewrite_defs_guarded_pats >> rewrite_defs_exp_lift_assign >> rewrite_defs_remove_blocks >> rewrite_defs_letbind_effects >> diff --git a/src/rewriter.mli b/src/rewriter.mli index b2b0bf5e..473456f6 100644 --- a/src/rewriter.mli +++ b/src/rewriter.mli @@ -73,6 +73,7 @@ type ('a,'pat,'pat_aux,'fpat,'fpat_aux) pat_alg = ; p_vector_concat : 'pat list -> 'pat_aux ; p_tup : 'pat list -> 'pat_aux ; p_list : 'pat list -> 'pat_aux + ; p_cons : 'pat * 'pat -> 'pat_aux ; p_aux : 'pat_aux * 'a annot -> 'pat ; fP_aux : 'fpat_aux * 'a annot -> 'fpat ; fP_Fpat : id * 'pat -> 'fpat_aux @@ -117,6 +118,8 @@ type ('a,'exp,'exp_aux,'lexp,'lexp_aux,'fexp,'fexp_aux,'fexps,'fexps_aux, ; e_internal_cast : 'a annot * 'exp -> 'exp_aux ; e_internal_exp : 'a annot -> 'exp_aux ; e_internal_exp_user : 'a annot * 'a annot -> 'exp_aux + ; e_comment : string -> 'exp_aux + ; e_comment_struc : 'exp -> 'exp_aux ; e_internal_let : 'lexp * 'exp * 'exp -> 'exp_aux ; e_internal_plet : 'pat * 'exp * 'exp -> 'exp_aux ; e_internal_return : 'exp -> 'exp_aux @@ -137,6 +140,7 @@ type ('a,'exp,'exp_aux,'lexp,'lexp_aux,'fexp,'fexp_aux,'fexps,'fexps_aux, ; def_val_dec : 'exp -> 'opt_default_aux ; def_val_aux : 'opt_default_aux * 'a annot -> 'opt_default ; pat_exp : 'pat * 'exp -> 'pexp_aux + ; pat_when : 'pat * 'exp * 'exp -> 'pexp_aux ; pat_aux : 'pexp_aux * 'a annot -> 'pexp ; lB_val_explicit : typschm * 'pat * 'exp -> 'letbind_aux ; lB_val_implicit : 'pat * 'exp -> 'letbind_aux diff --git a/src/sail.ml b/src/sail.ml index 3500b213..c7c14a67 100644 --- a/src/sail.ml +++ b/src/sail.ml @@ -132,7 +132,7 @@ let main() = let ast = List.fold_right (fun (_,(Parse_ast.Defs ast_nodes)) (Parse_ast.Defs later_nodes) -> Parse_ast.Defs (ast_nodes@later_nodes)) parsed (Parse_ast.Defs []) in - let ast = convert_ast ast in + let ast = convert_ast Type_check.inc_ord ast in let (ast, type_envs) = check_ast ast in let (ast, type_envs) = diff --git a/src/spec_analysis.ml b/src/spec_analysis.ml index 1447ff02..fdd56ecc 100644 --- a/src/spec_analysis.ml +++ b/src/spec_analysis.ml @@ -357,6 +357,11 @@ and fv_of_pes consider_var bound used set pes = let bound_p,us_p = pat_bindings consider_var bound used p in let bound_e,us_e,set_e = fv_of_exp consider_var bound_p us_p set e in fv_of_pes consider_var bound us_e set_e pes + | Pat_aux(Pat_when (p,g,e),_)::pes -> + let bound_p,us_p = pat_bindings consider_var bound used p in + let bound_g,us_g,set_g = fv_of_exp consider_var bound_p us_p set g in + let bound_e,us_e,set_e = fv_of_exp consider_var bound_g us_g set_g e in + fv_of_pes consider_var bound us_e set_e pes and fv_of_let consider_var bound used set (LB_aux(lebind,_)) = match lebind with | LB_val_explicit(typsch,pat,exp) -> diff --git a/src/type_check.ml b/src/type_check.ml index 3c133405..ca9c3618 100644 --- a/src/type_check.ml +++ b/src/type_check.ml @@ -73,6 +73,11 @@ let deinfix = function | Id_aux (Id v, l) -> Id_aux (DeIid v, l) | Id_aux (DeIid v, l) -> Id_aux (DeIid v, l) +let field_name rec_id id = + match rec_id, id with + | Id_aux (Id r, _), Id_aux (Id v, l) -> Id_aux (Id (r ^ "." ^ v), l) + | _, _ -> assert false + let string_of_bind (typquant, typ) = string_of_typquant typquant ^ ". " ^ string_of_typ typ let unaux_nexp (Nexp_aux (nexp, _)) = nexp @@ -133,6 +138,9 @@ let nc_gteq n1 n2 = NC_aux (NC_bounded_ge (n1, n2), Parse_ast.Unknown) let nc_lt n1 n2 = nc_lteq n1 (nsum n2 (nconstant 1)) let nc_gt n1 n2 = nc_gteq n1 (nsum n2 (nconstant 1)) let nc_and nc1 nc2 = mk_nc (NC_and (nc1, nc2)) +let nc_or nc1 nc2 = mk_nc (NC_or (nc1, nc2)) +let nc_true = mk_nc NC_true +let nc_false = mk_nc NC_false let mk_lit l = E_aux (E_lit (L_aux (l, Parse_ast.Unknown)), (Parse_ast.Unknown, ())) @@ -145,6 +153,8 @@ let rec nc_negate (NC_aux (nc, _)) = | NC_not_equal (n1, n2) -> nc_eq n1 n2 | NC_and (n1, n2) -> mk_nc (NC_or (nc_negate n1, nc_negate n2)) | NC_or (n1, n2) -> mk_nc (NC_and (nc_negate n1, nc_negate n2)) + | NC_false -> mk_nc NC_true + | NC_true -> mk_nc NC_false | NC_nat_set_bounded (kid, []) -> typ_error Parse_ast.Unknown "Cannot negate empty nexp set" | NC_nat_set_bounded (kid, [int]) -> nc_neq (nvar kid) (nconstant int) | NC_nat_set_bounded (kid, int :: ints) -> @@ -208,7 +218,6 @@ let is_typ_kopt = function | KOpt_aux (KOpt_kind (K_aux (K_kind [BK_aux (BK_type, _)], _), _), _) -> true | _ -> false - (**************************************************************************) (* 1. Substitutions *) (**************************************************************************) @@ -240,6 +249,8 @@ and nc_subst_nexp_aux l sv subst = function else set_nc | NC_or (nc1, nc2) -> NC_or (nc_subst_nexp sv subst nc1, nc_subst_nexp sv subst nc2) | NC_and (nc1, nc2) -> NC_and (nc_subst_nexp sv subst nc1, nc_subst_nexp sv subst nc2) + | NC_false -> NC_false + | NC_true -> NC_true let rec typ_subst_nexp sv subst (Typ_aux (typ, l)) = Typ_aux (typ_subst_nexp_aux sv subst typ, l) and typ_subst_nexp_aux sv subst = function @@ -374,7 +385,7 @@ module Env : sig val is_union_constructor : id -> t -> bool val add_record : id -> typquant -> (typ * id) list -> t -> t val is_record : id -> t -> bool - val get_accessor : id -> t -> typquant * typ + val get_accessor : id -> id -> t -> typquant * typ val add_local : id -> mut * typ -> t -> t val add_variant : id -> typquant * type_union list -> t -> t val add_union_id : id -> typquant * typ -> t -> t @@ -613,18 +624,18 @@ end = struct in let fold_accessors accs (typ, fid) = let acc_typ = mk_typ (Typ_fn (rectyp, typ, Effect_aux (Effect_set [], Parse_ast.Unknown))) in - typ_print (indent 1 ^ "Adding accessor " ^ string_of_id fid ^ " :: " ^ string_of_bind (typq, acc_typ)); - Bindings.add fid (typq, acc_typ) accs + typ_print (indent 1 ^ "Adding accessor " ^ string_of_id id ^ "." ^ string_of_id fid ^ " :: " ^ string_of_bind (typq, acc_typ)); + Bindings.add (field_name id fid) (typq, acc_typ) accs in { env with records = Bindings.add id (typq, fields) env.records; accessors = List.fold_left fold_accessors env.accessors fields } end - let get_accessor id env = + let get_accessor rec_id id env = let freshen_bind bind = List.fold_left (fun bind (kid, _) -> freshen_kid env kid bind) bind (KBindings.bindings env.typ_vars) in - try freshen_bind (Bindings.find id env.accessors) + try freshen_bind (Bindings.find (field_name rec_id id) env.accessors) with - | Not_found -> typ_error (id_loc id) ("No accessor found for " ^ string_of_id id) + | Not_found -> typ_error (id_loc id) ("No accessor found for " ^ string_of_id (field_name rec_id id)) let is_mutable id env = try @@ -776,6 +787,7 @@ end = struct | NC_nat_set_bounded (kid, ints) -> () (* MAYBE: We could demand that ints are all unique here *) | NC_or (nc1, nc2) -> wf_constraint env nc1; wf_constraint env nc2 | NC_and (nc1, nc2) -> wf_constraint env nc1; wf_constraint env nc2 + | NC_true | NC_false -> () let get_constraints env = env.constraints @@ -1045,6 +1057,8 @@ let rec nc_constraint var_of (NC_aux (nc, l)) = (List.map (fun i -> Constraint.eq (nexp_constraint var_of (nvar kid)) (Constraint.constant (big_int_of_int i))) ints) | NC_or (nc1, nc2) -> Constraint.disj (nc_constraint var_of nc1) (nc_constraint var_of nc2) | NC_and (nc1, nc2) -> Constraint.conj (nc_constraint var_of nc1) (nc_constraint var_of nc2) + | NC_false -> Constraint.literal false + | NC_true -> Constraint.literal true let rec nc_constraints var_of ncs = match ncs with @@ -1085,6 +1099,8 @@ let prove env (NC_aux (nc_aux, _) as nc) = | NC_fixed (nexp1, nexp2) when compare_const (fun c1 c2 -> c1 <> c2) (nexp_simp nexp1) (nexp_simp nexp2) -> false | NC_bounded_le (nexp1, nexp2) when compare_const (fun c1 c2 -> c1 > c2) (nexp_simp nexp1) (nexp_simp nexp2) -> false | NC_bounded_ge (nexp1, nexp2) when compare_const (fun c1 c2 -> c1 < c2) (nexp_simp nexp1) (nexp_simp nexp2) -> false + | NC_true -> true + | NC_false -> false | _ -> prove_z3 env nc let rec subtyp_tnf env tnf1 tnf2 = @@ -1600,6 +1616,24 @@ let restrict_range_lower c1 (Typ_aux (typ_aux, l) as typ) = range_typ (nconstant (max c1 c2)) nexp | _ -> typ +exception Not_a_constraint;; + +let rec assert_nexp (E_aux (exp_aux, l)) = + match exp_aux with + | E_sizeof nexp -> nexp + | E_lit (L_aux (L_num n, _)) -> nconstant n + | _ -> raise Not_a_constraint + +let rec assert_constraint (E_aux (exp_aux, l)) = + match exp_aux with + | E_app_infix (x, op, y) when string_of_id op = "|" -> + nc_or (assert_constraint x) (assert_constraint y) + | E_app_infix (x, op, y) when string_of_id op = "&" -> + nc_and (assert_constraint x) (assert_constraint y) + | E_app_infix (x, op, y) when string_of_id op = "==" -> + nc_eq (assert_nexp x) (assert_nexp y) + | _ -> nc_true + type flow_constraint = | Flow_lteq of int | Flow_gteq of int @@ -1725,7 +1759,7 @@ let rec check_exp env (E_aux (exp_aux, (l, ())) as exp : unit exp) (Typ_aux (typ | E_block exps, _ -> begin let rec check_block l env exps typ = match exps with - | [] -> typ_error l "Empty block found" + | [] -> typ_equality l env typ unit_typ; [] | [exp] -> [crule check_exp env exp typ] | (E_aux (E_assign (lexp, bind), _) :: exps) -> let texp, env = bind_assignment env lexp bind in @@ -1734,6 +1768,14 @@ let rec check_exp env (E_aux (exp_aux, (l, ())) as exp : unit exp) (Typ_aux (typ typ_print ("Adding constraint " ^ string_of_n_constraint nc ^ " for assert"); let inferred_exp = irule infer_exp env exp in inferred_exp :: check_block l (Env.add_constraint nc env) exps typ + | ((E_aux (E_assert (const_expr, assert_msg), _) as exp) :: exps) -> + begin + try + let nc = assert_constraint const_expr in + check_block l (Env.add_constraint nc env) exps typ + with + | Not_a_constraint -> check_block l env exps typ + end | (exp :: exps) -> let texp = crule check_exp env exp (mk_typ (Typ_id (mk_id "unit"))) in texp :: check_block l env exps typ @@ -1797,7 +1839,7 @@ let rec check_exp env (E_aux (exp_aux, (l, ())) as exp : unit exp) (Typ_aux (typ begin let (start, len, ord, vtyp) = destructure_vec_typ l env typ in let checked_items = List.map (fun i -> crule check_exp env i vtyp) vec in - match len with + match nexp_simp len with | Nexp_aux (Nexp_constant lenc, _) -> if List.length vec = lenc then annot_exp (E_vector checked_items) typ else typ_error l "List length didn't match" (* FIXME: improve error message *) @@ -1932,10 +1974,17 @@ and bind_pat env (P_aux (pat_aux, (l, ())) as pat) (Typ_aux (typ_aux, _) as typ) annot_pat (P_list pats) typ, env | _ -> typ_error l "Cannot match list pattern against non-list type" end + | P_tup [] -> + begin + match Env.expand_synonyms env typ with + | Typ_aux (Typ_id typ_id, _) when string_of_id typ_id = "unit" -> + annot_pat (P_tup []) typ, env + | _ -> typ_error l "Cannot match unit pattern against non-unit type" + end | P_tup pats -> begin - match typ_aux with - | Typ_tup typs -> + match Env.expand_synonyms env typ with + | Typ_aux (Typ_tup typs, _) -> let tpats, env = try List.fold_left2 bind_tuple_pat ([], env) pats typs with | Invalid_argument _ -> typ_error l "Tuple pattern and tuple type have different length" @@ -2040,24 +2089,27 @@ and bind_assignment env (LEXP_aux (lexp_aux, _) as lexp) (E_aux (_, (l, ())) as let infer_flexp = function | LEXP_id v -> begin match Env.lookup_id v env with - | Register typ -> typ, LEXP_id v - | _ -> typ_error l "l-expression field is not a register" + | Register typ -> typ, LEXP_id v, true + | Local (Mutable, typ) -> typ, LEXP_id v, false + | _ -> typ_error l "l-expression field is not a register or a local mutable type" end | LEXP_vector (LEXP_aux (LEXP_id v, _), exp) -> begin (* Check: is this ok if the vector is immutable? *) - let is_immutable, vtyp = match Env.lookup_id v env with + let is_immutable, vtyp, is_register = match Env.lookup_id v env with | Unbound -> typ_error l "Cannot assign to element of unbound vector" | Enum _ -> typ_error l "Cannot vector assign to enumeration element" - | Local (Immutable, vtyp) -> true, vtyp - | Local (Mutable, vtyp) | Register vtyp -> false, vtyp + | Local (Immutable, vtyp) -> true, vtyp, false + | Local (Mutable, vtyp) -> false, vtyp, false + | Register vtyp -> false, vtyp, true in let access = infer_exp (Env.enable_casts env) (E_aux (E_app (mk_id "vector_access", [E_aux (E_id v, (l, ())); exp]), (l, ()))) in let E_aux (E_app (_, [_; inferred_exp]), _) = access in - typ_of access, LEXP_vector (annot_lexp (LEXP_id v) vtyp, inferred_exp) + typ_of access, LEXP_vector (annot_lexp (LEXP_id v) vtyp, inferred_exp), is_register end in - let regtyp, inferred_flexp = infer_flexp flexp in + let regtyp, inferred_flexp, is_register = infer_flexp flexp in + let eff = if is_register then mk_effect [BE_wreg] else no_effect in typ_debug ("REGTYP: " ^ string_of_typ regtyp ^ " / " ^ string_of_typ (Env.expand_synonyms env regtyp)); match Env.expand_synonyms env regtyp with | Typ_aux (Typ_id regtyp_id, _) when Env.is_regtyp regtyp_id env -> @@ -2074,13 +2126,13 @@ and bind_assignment env (LEXP_aux (lexp_aux, _) as lexp) (E_aux (_, (l, ())) as | _, _ -> typ_error l "Not implemented this register field type yet..." in let checked_exp = crule check_exp env exp vec_typ in - annot_assign (annot_lexp (LEXP_field (annot_lexp_effect inferred_flexp regtyp (mk_effect [BE_wreg]), field)) vec_typ) checked_exp, env + annot_assign (annot_lexp (LEXP_field (annot_lexp_effect inferred_flexp regtyp eff, field)) vec_typ) checked_exp, env | Typ_aux (Typ_id rectyp_id, _) | Typ_aux (Typ_app (rectyp_id, _), _) when Env.is_record rectyp_id env -> let (typq, Typ_aux (Typ_fn (rectyp_q, field_typ, _), _)) = Env.get_accessor field env in let unifiers, _, _ (* FIXME *) = try unify l env rectyp_q regtyp with Unification_error (l, m) -> typ_error l ("Unification error: " ^ m) in let field_typ' = subst_unifiers unifiers field_typ in let checked_exp = crule check_exp env exp field_typ' in - annot_assign (annot_lexp (LEXP_field (annot_lexp_effect inferred_flexp regtyp (mk_effect [BE_wreg]), field)) field_typ') checked_exp, env + annot_assign (annot_lexp (LEXP_field (annot_lexp_effect inferred_flexp regtyp eff, field)) field_typ') checked_exp, env | _ -> typ_error l "Field l-expression has invalid type" end | LEXP_memory (f, xs) -> @@ -2254,7 +2306,7 @@ and infer_exp env (E_aux (exp_aux, (l, ())) as exp) = (* Accessing a field of a record *) | Typ_aux (Typ_id rectyp, _) as typ when Env.is_record rectyp env -> begin - let inferred_acc, _ = infer_funapp' l (Env.no_casts env) field (Env.get_accessor field env) [strip_exp inferred_exp] None in + let inferred_acc, _ = infer_funapp' l (Env.no_casts env) field (Env.get_accessor rectyp field env) [strip_exp inferred_exp] None in match inferred_acc with | E_aux (E_app (field, [inferred_exp]) ,_) -> annot_exp (E_field (inferred_exp, field)) (typ_of inferred_acc) | _ -> assert false (* Unreachable *) diff --git a/src/type_check.mli b/src/type_check.mli index 647feaaa..a2b8a10c 100644 --- a/src/type_check.mli +++ b/src/type_check.mli @@ -87,7 +87,7 @@ module Env : sig val is_record : id -> t -> bool - val get_accessor : id -> t -> typquant * typ + val get_accessor : id -> id -> t -> typquant * typ (* If the environment is checking a function, then this will get the expected return type of the function. It's useful for checking or @@ -105,6 +105,8 @@ module Env : sig won't throw any exceptions. *) val lookup_id : id -> t -> lvar + val is_union_constructor : id -> t -> bool + (* Return a fresh kind identifier that doesn't exist in the environment *) val fresh_kid : t -> kid diff --git a/test/typecheck/fail/overlap_field_wreg.sail b/test/typecheck/fail/overlap_field_wreg.sail new file mode 100644 index 00000000..4c4d858d --- /dev/null +++ b/test/typecheck/fail/overlap_field_wreg.sail @@ -0,0 +1,13 @@ + +typedef A = const struct {bool field_A; int shared} +typedef B = const struct {bool field_B; int shared} + +val (bool, int) -> A effect {undef, wreg} makeA + +function makeA (x, y) = +{ + (A) record := undefined; + record.field_A := x; + record.shared := y; + record +} diff --git a/test/typecheck/pass/add_real.sail b/test/typecheck/pass/add_real.sail new file mode 100644 index 00000000..38a9cff3 --- /dev/null +++ b/test/typecheck/pass/add_real.sail @@ -0,0 +1,5 @@ +val (real, real) -> real effect pure add_real + +overload (deinfix +) [add_real] + +let (real) r = 2.2 + 0.2 diff --git a/test/typecheck/pass/add_vec_exts_no_annot.sail b/test/typecheck/pass/add_vec_exts_no_annot.sail new file mode 100644 index 00000000..54aa2d40 --- /dev/null +++ b/test/typecheck/pass/add_vec_exts_no_annot.sail @@ -0,0 +1,19 @@ +default Order dec + +val forall Nat 'n, Nat 'm, Nat 'o, Nat 'p, Order 'ord. + vector<'o, 'n, 'ord, bit> -> vector<'p, 'm, 'ord, bit> effect pure exts + +overload EXTS [exts] + +val forall Nat 'n, Nat 'o, Order 'ord. + (vector<'o, 'n, 'ord, bit>, vector<'o, 'n, 'ord, bit>) -> vector<'o, 'n, 'ord, bit> effect pure add_vec + +overload (deinfix +) [add_vec] + +val (bit[32], bit[32]) -> unit effect pure test + +function test (x, y) = +{ + let (bit[64]) z = add_vec(exts(x), exts(y)) in + () +} diff --git a/test/typecheck/pass/add_vec_exts_no_annot_overload.sail b/test/typecheck/pass/add_vec_exts_no_annot_overload.sail new file mode 100644 index 00000000..01e3bf7c --- /dev/null +++ b/test/typecheck/pass/add_vec_exts_no_annot_overload.sail @@ -0,0 +1,19 @@ +default Order dec + +val forall Nat 'n, Nat 'm, Nat 'o, Nat 'p, Order 'ord. + vector<'o, 'n, 'ord, bit> -> vector<'p, 'm, 'ord, bit> effect pure exts + +overload EXTS [exts] + +val forall Nat 'n, Nat 'o, Order 'ord. + (vector<'o, 'n, 'ord, bit>, vector<'o, 'n, 'ord, bit>) -> vector<'o, 'n, 'ord, bit> effect pure add_vec + +overload (deinfix +) [add_vec] + +val (bit[32], bit[32]) -> unit effect pure test + +function test (x, y) = +{ + let (bit[64]) z = EXTS(x) + EXTS(y) in + () +} diff --git a/test/typecheck/pass/overlap_field.sail b/test/typecheck/pass/overlap_field.sail new file mode 100644 index 00000000..82e685ee --- /dev/null +++ b/test/typecheck/pass/overlap_field.sail @@ -0,0 +1,13 @@ + +typedef A = const struct {bool field_A; int shared} +typedef B = const struct {bool field_B; int shared} + +val (bool, int) -> A effect {undef} makeA + +function makeA (x, y) = +{ + (A) record := undefined; + record.field_A := x; + record.shared := y; + record +} |
