diff options
| author | Alasdair Armstrong | 2017-07-13 14:25:34 +0100 |
|---|---|---|
| committer | Alasdair Armstrong | 2017-07-13 14:25:34 +0100 |
| commit | c19b8e2b934149b6670f43d875d773115b08410e (patch) | |
| tree | 65047a852db3ffb1773f59eb2d859884179abaaf | |
| parent | 73e54aeec2febe58424b44c2c8f649b29910f3d9 (diff) | |
Improved type inference for let statements and assignments with type annotated patterns and lexps
Added get_enum to type checker interface
| -rw-r--r-- | lib/prelude.sail | 47 | ||||
| -rw-r--r-- | mips_new_tc/mips_insts.sail | 44 | ||||
| -rw-r--r-- | src/parser.mly | 8 | ||||
| -rw-r--r-- | src/pretty_print_sail.ml | 1 | ||||
| -rw-r--r-- | src/type_check_new.ml | 27 | ||||
| -rw-r--r-- | src/type_check_new.mli | 4 |
6 files changed, 103 insertions, 28 deletions
diff --git a/lib/prelude.sail b/lib/prelude.sail index 05b1ac80..350c6c20 100644 --- a/lib/prelude.sail +++ b/lib/prelude.sail @@ -1,6 +1,8 @@ val cast forall Nat 'n, Nat 'm, Order 'ord. vector<'n,'m,'ord,bit> -> [|0:2**'m - 1|] effect pure unsigned +val forall Nat 'n, Nat 'm, Order 'ord. vector<'n,'m,'ord,bit> -> [|0 - (2**('m - 1)):2**('m - 1) - 1|] effect pure signed + val forall Nat 'n, Nat 'm. [|0:'n|] -> vector<'m - 1,'m,dec,bit> effect pure to_vec (* Vector access can't actually be properly polymorphic on vector @@ -37,10 +39,15 @@ val cast forall Type 'a. register<'a> -> 'a effect {rreg} reg_deref (* Bitvector duplication *) val forall Nat 'n. (bit, [:'n:]) -> vector<'n - 1,'n,dec,bit> effect pure duplicate +val (bit, int) -> list<bit> effect pure duplicate_to_list + val forall Nat 'n, Nat 'm, Nat 'o, Order 'ord. (vector<'o,'n,'ord,bit>, [:'m:]) -> vector<'o,'m*'n,'ord,bit> effect pure duplicate_bits -overload (deinfix ^^) [duplicate; duplicate_bits] +val forall Nat 'n, Nat 'o, Order 'ord. + (vector<'o,'n,'ord,bit>, int) -> list<bit> effect pure duplicate_bits_to_list + +overload (deinfix ^^) [duplicate; duplicate_bits; duplicate_to_list; duplicate_bits_to_list] (* Bitvector extension *) val forall Nat 'n, Nat 'm, Nat 'o, Nat 'p, Order 'ord. @@ -49,11 +56,14 @@ val forall Nat 'n, Nat 'm, Nat 'o, Nat 'p, Order 'ord. val forall Nat 'm, Nat 'p, Order 'ord. list<bit> -> vector<'p, 'm, 'ord, bit> effect pure extz_bl -val cast forall Nat 'n, Nat 'm, Nat 'o, Nat 'p, Order 'ord. +val forall Nat 'n, Nat 'm, Nat 'o, Nat 'p, Order 'ord. vector<'o, 'n, 'ord, bit> -> vector<'p, 'm, 'ord, bit> effect pure exts +val forall Nat 'm, Nat 'p, Order 'ord. + list<bit> -> vector<'p, 'm, 'ord, bit> effect pure exts_bl + overload EXTZ [extz; extz_bl] -overload EXTS [exts] +overload EXTS [exts; exts_bl] val forall Type 'a, Nat 'n, Nat 'm, Nat 'o, Nat 'p, Order 'ord, 'm >= 'o. vector<'n, 'm, 'ord, 'a> -> vector<'p, 'o, 'ord, 'a> effect pure mask @@ -123,6 +133,27 @@ overload (deinfix -) [ sub_int ] +val forall Nat 'n, Nat 'o, Order 'ord. + (vector<'o, 'n, 'ord, bit>, vector<'o, 'n, 'ord, bit>) -> vector<'o, 'n, 'ord, bit> effect pure add_svec + +overload (deinfix *_s) [ + add_svec +] + +val forall Nat 'n, Nat 'o, Order 'ord. + (vector<'o, 'n, 'ord, bit>, vector<'o, 'n, 'ord, bit>) -> vector<'o, 'n, 'ord, bit> effect pure xor_vec + +overload (deinfix ^) [ + xor_vec +] + +val forall Nat 'n, Nat 'o, Order 'ord. + (vector<'o, 'n, 'ord, bit>, int) -> vector<'o, 'n, 'ord, bit> effect pure shiftl + +overload (deinfix <<) [ + shiftl +] + (* Boolean operators *) val bool -> bool effect pure bool_not val (bool, bool) -> bool effect pure bool_or @@ -185,6 +216,16 @@ overload (deinfix >) [gt_atom_atom; gt_vec; gt_int] overload (deinfix <=) [lteq_atom_atom; lteq_range_atom; lteq_atom_range; lteq_vec; lteq_int] overload (deinfix <) [lt_atom_atom; lt_vec; lt_int] +val forall Num 'n, Num 'm, Order 'ord. (vector<'n,'m,'ord,bit>, vector<'n,'m,'ord,bit>) -> bool effect pure gteq_svec +val forall Num 'n, Num 'm, Order 'ord. (vector<'n,'m,'ord,bit>, vector<'n,'m,'ord,bit>) -> bool effect pure gt_svec +val forall Num 'n, Num 'm, Order 'ord. (vector<'n,'m,'ord,bit>, vector<'n,'m,'ord,bit>) -> bool effect pure lteq_svec +val forall Num 'n, Num 'm, Order 'ord. (vector<'n,'m,'ord,bit>, vector<'n,'m,'ord,bit>) -> bool effect pure lt_svec + +overload (deinfix <_s) [lt_svec] +overload (deinfix <=_s) [lteq_svec] +overload (deinfix >_s) [gt_svec] +overload (deinfix >=_s) [gteq_svec] + val (int, int) -> int effect pure quotient overload (deinfix quot) [quotient] diff --git a/mips_new_tc/mips_insts.sail b/mips_new_tc/mips_insts.sail index 07d4d841..19176c72 100644 --- a/mips_new_tc/mips_insts.sail +++ b/mips_new_tc/mips_insts.sail @@ -144,13 +144,13 @@ function clause execute (ADDI(rs, rt, imm)) = { (bit[64]) opA := rGPR(rs); if NotWordVal(opA) then - wGPR(rt) := undefined (* XXX could exit instead *) + wGPR(rt) := (bit[64]) undefined (* XXX could exit instead *) else - let (bit[33]) sum33 = (EXTS(opA[31 .. 0]) + EXTS(imm)) in + let sum33 = (bit[33]) (EXTS(opA[31 .. 0])) + (bit[33]) (EXTS(imm)) in if sum33[32] != sum33[31] then (SignalException(Ov)) else - wGPR(rt) := EXTS(sum33[31..0]) + wGPR(rt) := EXTS(sum33[31..0]) } (* ADDU 32-bit add immediate -- reg, reg, reg with possible undefined behaviour *) @@ -165,9 +165,9 @@ function clause execute (ADDU(rs, rt, rd)) = (bit[64]) opA := rGPR(rs); (bit[64]) opB := rGPR(rt); if NotWordVal(opA) | NotWordVal(opB) then - wGPR(rd) := undefined + wGPR(rd) := (bit[64]) undefined else - wGPR(rd) := EXTS(opA[31..0] + opB[31..0]) + wGPR(rd) := (bit[64]) (EXTS(opA[31..0] + opB[31..0])) } @@ -182,9 +182,9 @@ function clause execute (ADDIU(rs, rt, imm)) = { (bit[64]) opA := rGPR(rs); if NotWordVal(opA) then - wGPR(rt) := undefined (* XXX could exit instead *) + wGPR(rt) := (bit[64]) undefined (* XXX could exit instead *) else - wGPR(rt) := EXTS((opA[31 .. 0]) + EXTS(imm)) + wGPR(rt) := (bit[64]) (EXTS((opA[31 .. 0]) + (bit[32]) (EXTS(imm)))) } (**************************************************************************************) @@ -212,7 +212,7 @@ function clause decode (0b000000 : (regno) rs : (regno) rt : (regno) rd : 0b0000 function clause execute (DSUB (rs, rt, rd)) = { - let (bit[65]) temp65 = (EXTS(rGPR(rs)) - EXTS(rGPR(rt))) in + let (bit[65]) temp65 = (bit[65]) (EXTS(rGPR(rs))) - (bit[65]) (EXTS(rGPR(rt))) in { if temp65[64] != temp65[63] then (SignalException(Ov)) @@ -228,12 +228,12 @@ union ast member regregreg SUB function clause decode (0b000000 : (regno) rs : (regno) rt : (regno) rd : 0b00000 : 0b100010) = Some(SUB(rs, rt, rd)) -function clause execute (SUB(rs, rt, rd)) = +function clause execute (SUB(rs, rt, rd)) = { (bit[64]) opA := rGPR(rs); (bit[64]) opB := rGPR(rt); if NotWordVal(opA) | NotWordVal(opB) then - wGPR(rd) := undefined (* XXX could instead *) + wGPR(rd) := (bit[64]) undefined (* XXX could instead *) else let (bit[33]) temp33 = (EXTS(opA[31..0]) - EXTS(opB[31..0])) in if temp33[32] != temp33[31] then @@ -388,7 +388,7 @@ function clause decode (0b000000 : 0b00000 : (regno) rt : (regno) rd : (bit[5]) function clause execute (DSRA (rt, rd, sa)) = { temp := rGPR(rt); - wGPR(rd) := ((temp[63] ^^ sa) : (temp[63 .. sa])) + wGPR(rd) := EXTS((temp[63] ^^ sa) : (temp[63 .. sa])) } (* DSRA32 reg, reg, imm5 *) @@ -400,7 +400,7 @@ function clause execute (DSRA32 (rt, rd, sa)) = { temp := rGPR(rt); sa32 := (0b1 : sa); (* sa+32 *) - wGPR(rd) := ((temp[63] ^^ sa32) : (temp[63 .. sa32])) + wGPR(rd) := EXTS((temp[63] ^^ sa32) : (temp[63 .. sa32])) } (* DSRAV reg, reg, reg *) @@ -411,7 +411,7 @@ function clause execute (DSRAV (rs, rt, rd)) = { temp := rGPR(rt); sa := (rGPR(rs)) [5..0]; - wGPR(rd) := ((temp[63] ^^ sa) : temp[63 .. sa]) + wGPR(rd) := EXTS((temp[63] ^^ sa) : temp[63 .. sa]) } (* DSRL shift right logical - reg, reg, imm5 *) @@ -422,7 +422,7 @@ function clause decode (0b000000 : 0b00000 : (regno) rt : (regno) rd : (bit[5]) function clause execute (DSRL (rt, rd, sa)) = { temp := rGPR(rt); - wGPR(rd) := ((bitzero ^^ sa) : (temp[63 .. sa])) + wGPR(rd) := EXTS((bitzero ^^ sa) : (temp[63 .. sa])) } (* DSRL32 reg, reg, imm5 *) @@ -434,7 +434,7 @@ function clause execute (DSRL32 (rt, rd, sa)) = { temp := rGPR(rt); sa32 := (0b1 : sa); (* sa+32 *) - wGPR(rd) := ((bitzero ^^ sa32) : (temp[63 .. sa32])) + wGPR(rd) := EXTS((bitzero ^^ sa32) : (temp[63 .. sa32])) } (* DSRLV reg, reg, reg *) @@ -446,7 +446,7 @@ function clause execute (DSRLV (rs, rt, rd)) = { temp := rGPR(rt); sa := (rGPR(rs)) [5..0]; - wGPR(rd) := ((bitzero ^^ sa) : temp[63 .. sa]) + wGPR(rd) := EXTS((bitzero ^^ sa) : temp[63 .. sa]) } (**************************************************************************************) @@ -460,7 +460,7 @@ function clause decode (0b000000 : 0b00000 : (regno) rt : (regno) rd : (regno) s Some(SLL(rt, rd, sa)) function clause execute (SLL(rt, rd, sa)) = { - wGPR(rd) := EXTS((rGPR(rt)) [(31-sa)..0] : (bitzero ^^ sa)) + wGPR(rd) := EXTS((rGPR(rt)) [31 - sa .. 0] : (bitzero ^^ sa)) } (* SLLV 32-bit shift left variable *) @@ -471,7 +471,7 @@ function clause decode (0b000000 : (regno) rs : (regno) rt : (regno) rd : 0b0000 function clause execute (SLLV(rs, rt, rd)) = { sa := (rGPR(rs))[4..0]; - wGPR(rd) := EXTS((rGPR(rt)) [(31-sa)..0] : (bitzero ^^ sa)) + wGPR(rd) := EXTS((rGPR(rt)) [31 - sa .. 0] : (bitzero ^^ sa)) } (* SRA 32-bit arithmetic shift right *) @@ -485,7 +485,7 @@ function clause execute (SRA(rt, rd, sa)) = if (NotWordVal(temp)) then wGPR(rd) := undefined else - wGPR(rd) := (temp[31] ^^ (sa+32)) : temp [31..sa] + wGPR(rd) := EXTS((temp[31] ^^ (sa+32)) : temp [31..sa]) } (* SRAV 32-bit arithmetic shift right variable *) @@ -500,7 +500,7 @@ function clause execute (SRAV(rs, rt, rd)) = if (NotWordVal(temp)) then wGPR(rd) := undefined else - wGPR(rd) := (temp[31] ^^ (sa+32)) : temp [31..sa] + wGPR(rd) := EXTS((temp[31] ^^ (sa+32)) : temp [31..sa]) } (* SRL 32-bit shift right *) @@ -652,7 +652,7 @@ function clause execute (MUL(rs, rt, rd)) = { rsVal := rGPR(rs); rtVal := rGPR(rt); - (bit[64]) result := (rsVal[31..0]) *_s (rtVal[31..0]); + (bit[64]) result := EXTS((rsVal[31..0]) *_s (rtVal[31..0])); wGPR(rd) := if (NotWordVal(rsVal) | NotWordVal(rtVal)) then undefined else @@ -675,7 +675,7 @@ function clause execute (MULT(rs, rt)) = (bit[64]) result := if (NotWordVal(rsVal) | NotWordVal(rtVal)) then undefined else - (rsVal[31..0]) *_s (rtVal[31..0]); + EXTS((rsVal[31..0]) *_s (rtVal[31..0])); HI := EXTS(result[63..32]); LO := EXTS(result[31..0]); } diff --git a/src/parser.mly b/src/parser.mly index 9f48067f..8e61a0ac 100644 --- a/src/parser.mly +++ b/src/parser.mly @@ -209,6 +209,10 @@ id: { idl (DeIid($3)) } | Lparen Deinfix Lt Rparen { idl (DeIid($3)) } + | Lparen Deinfix GtUnderS Rparen + { idl (DeIid($3)) } + | Lparen Deinfix LtUnderS Rparen + { idl (DeIid($3)) } | Lparen Deinfix Minus Rparen { idl (DeIid("-")) } | Lparen Deinfix MinusUnderS Rparen @@ -243,6 +247,8 @@ id: { idl (DeIid($3)) } | Lparen Deinfix GtEq Rparen { idl (DeIid($3)) } + | Lparen Deinfix GtEqUnderS Rparen + { idl (DeIid($3)) } | Lparen Deinfix GtEqPlus Rparen { idl (DeIid($3)) } | Lparen Deinfix GtGt Rparen @@ -257,6 +263,8 @@ id: { idl (DeIid($3)) } | Lparen Deinfix LtEq Rparen { idl (DeIid($3)) } + | Lparen Deinfix LtEqUnderS Rparen + { idl (DeIid($3)) } | Lparen Deinfix LtLt Rparen { idl (DeIid($3)) } | Lparen Deinfix LtLtLt Rparen diff --git a/src/pretty_print_sail.ml b/src/pretty_print_sail.ml index 9c33c841..6826087a 100644 --- a/src/pretty_print_sail.ml +++ b/src/pretty_print_sail.ml @@ -488,6 +488,7 @@ let doc_rec (Rec_aux(r,_)) = match r with let doc_tannot_opt (Typ_annot_opt_aux(t,_)) = match t with | Typ_annot_opt_some(tq,typ) -> doc_typquant tq (doc_typ typ) + | Typ_annot_opt_none -> empty let doc_effects_opt (Effect_opt_aux(e,_)) = match e with | Effect_opt_pure -> string "pure" diff --git a/src/type_check_new.ml b/src/type_check_new.ml index 8232fd6a..5be2cd43 100644 --- a/src/type_check_new.ml +++ b/src/type_check_new.ml @@ -346,6 +346,7 @@ module Env : sig val set_default_order_inc : t -> t val set_default_order_dec : t -> t val add_enum : id -> id list -> t -> t + val get_enum : id -> t -> id list val get_casts : t -> id list val allow_casts : t -> bool val no_casts : t -> t @@ -518,6 +519,11 @@ end = struct { env with enums = Bindings.add id (IdSet.of_list ids) env.enums } end + let get_enum id env = + try IdSet.elements (Bindings.find id env.enums) + with + | Not_found -> typ_error (id_loc id) ("Enumeration " ^ string_of_id id ^ " does not exist") + let is_record id env = Bindings.mem id env.records let add_record id typq fields env = @@ -1533,19 +1539,24 @@ let rec check_exp env (E_aux (exp_aux, (l, ())) as exp : unit exp) (Typ_aux (typ begin match letbind with | LB_val_explicit (typschm, pat, bind) -> assert false + | LB_val_implicit (P_aux (P_typ (ptyp, _), _) as pat, bind) -> + let checked_bind = crule check_exp env bind ptyp in + let tpat, env = bind_pat env pat (typ_of checked_bind) in + annot_exp (E_let (LB_aux (LB_val_implicit (tpat, checked_bind), (let_loc, None)), crule check_exp env exp typ)) typ | LB_val_implicit (pat, bind) -> let inferred_bind = irule infer_exp env bind in let tpat, env = bind_pat env pat (typ_of inferred_bind) in annot_exp (E_let (LB_aux (LB_val_implicit (tpat, inferred_bind), (let_loc, None)), crule check_exp env exp typ)) typ end - | E_app_infix (x, op, y), _ when List.length (Env.get_overloads (deinfix op) env) > 0 -> check_exp env (E_aux (E_app (deinfix op, [x; y]), (l, ()))) typ + | E_app_infix (x, op, y), _ when List.length (Env.get_overloads (deinfix op) env) > 0 -> + check_exp env (E_aux (E_app (deinfix op, [x; y]), (l, ()))) typ | E_app (f, xs), _ when List.length (Env.get_overloads f env) > 0 -> let rec try_overload = function | [] -> typ_error l ("No valid overloading for " ^ string_of_exp exp) | (f :: fs) -> begin typ_print ("Overload: " ^ string_of_id f ^ "(" ^ string_of_list ", " string_of_exp xs ^ ")"); try crule check_exp env (E_aux (E_app (f, xs), (l, ()))) typ with - | Type_error (_, m2) -> try_overload fs + | Type_error (_, m) -> typ_print ("Error : " ^ m); try_overload fs end in try_overload (Env.get_overloads f env) @@ -1789,6 +1800,10 @@ and bind_assignment env (LEXP_aux (lexp_aux, _) as lexp) (E_aux (_, (l, ())) as end | LEXP_memory (f, xs) -> check_exp env (E_aux (E_app (f, xs @ [exp]), (l, ()))) unit_typ, env + | LEXP_cast (typ_annot, v) -> + let checked_exp = crule check_exp env exp typ_annot in + let tlexp, env' = bind_lexp env lexp (typ_of checked_exp) in + annot_assign tlexp checked_exp, env' | _ -> let inferred_exp = irule infer_exp env exp in let tlexp, env' = bind_lexp env lexp (typ_of inferred_exp) in @@ -1963,7 +1978,7 @@ and infer_exp env (E_aux (exp_aux, (l, ())) as exp) = | (f :: fs) -> begin typ_print ("Overload: " ^ string_of_id f ^ "(" ^ string_of_list ", " string_of_exp xs ^ ")"); try irule infer_exp env (E_aux (E_app (f, xs), (l, ()))) with - | Type_error (_, m2) -> try_overload fs + | Type_error (_, m) -> typ_print ("Error: " ^ m); try_overload fs end in try_overload (Env.get_overloads f env) @@ -2043,6 +2058,12 @@ and infer_funapp' l env f (typq, f_typ) xs ret_ctx_typ = else typ_error l ("Quantifiers " ^ string_of_list ", " string_of_quant_item quants ^ " not resolved during application of " ^ string_of_id f) end + | (utyps, (typ :: typs)), (uargs, ((n, arg) :: args)) when KidSet.is_empty (typ_frees typ) -> + begin + let carg = crule check_exp env arg typ in + let (iargs, ret_typ') = instantiate quants (utyps, typs) ret_typ (uargs, args) in + ((n, carg) :: iargs, ret_typ') + end | (utyps, (typ :: typs)), (uargs, ((n, arg) :: args)) -> begin typ_debug ("INSTANTIATE: " ^ string_of_exp arg ^ " with " ^ string_of_typ typ ^ " NF " ^ string_of_tnf (normalize_typ env typ)); diff --git a/src/type_check_new.mli b/src/type_check_new.mli index f55ccf3a..cdacf523 100644 --- a/src/type_check_new.mli +++ b/src/type_check_new.mli @@ -64,6 +64,10 @@ module Env : sig val get_regtyp : id -> t -> int * int * (index_range * id) list + (* Return all the identifiers in an enumeration. Throws a type error + if the enumeration doesn't exist. *) + val get_enum : id -> t -> id list + (* Returns true if id is a register type, false otherwise *) val is_regtyp : id -> t -> bool |
