summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorAlasdair Armstrong2017-07-13 14:25:34 +0100
committerAlasdair Armstrong2017-07-13 14:25:34 +0100
commitc19b8e2b934149b6670f43d875d773115b08410e (patch)
tree65047a852db3ffb1773f59eb2d859884179abaaf
parent73e54aeec2febe58424b44c2c8f649b29910f3d9 (diff)
Improved type inference for let statements and assignments with type annotated patterns and lexps
Added get_enum to type checker interface
-rw-r--r--lib/prelude.sail47
-rw-r--r--mips_new_tc/mips_insts.sail44
-rw-r--r--src/parser.mly8
-rw-r--r--src/pretty_print_sail.ml1
-rw-r--r--src/type_check_new.ml27
-rw-r--r--src/type_check_new.mli4
6 files changed, 103 insertions, 28 deletions
diff --git a/lib/prelude.sail b/lib/prelude.sail
index 05b1ac80..350c6c20 100644
--- a/lib/prelude.sail
+++ b/lib/prelude.sail
@@ -1,6 +1,8 @@
val cast forall Nat 'n, Nat 'm, Order 'ord. vector<'n,'m,'ord,bit> -> [|0:2**'m - 1|] effect pure unsigned
+val forall Nat 'n, Nat 'm, Order 'ord. vector<'n,'m,'ord,bit> -> [|0 - (2**('m - 1)):2**('m - 1) - 1|] effect pure signed
+
val forall Nat 'n, Nat 'm. [|0:'n|] -> vector<'m - 1,'m,dec,bit> effect pure to_vec
(* Vector access can't actually be properly polymorphic on vector
@@ -37,10 +39,15 @@ val cast forall Type 'a. register<'a> -> 'a effect {rreg} reg_deref
(* Bitvector duplication *)
val forall Nat 'n. (bit, [:'n:]) -> vector<'n - 1,'n,dec,bit> effect pure duplicate
+val (bit, int) -> list<bit> effect pure duplicate_to_list
+
val forall Nat 'n, Nat 'm, Nat 'o, Order 'ord.
(vector<'o,'n,'ord,bit>, [:'m:]) -> vector<'o,'m*'n,'ord,bit> effect pure duplicate_bits
-overload (deinfix ^^) [duplicate; duplicate_bits]
+val forall Nat 'n, Nat 'o, Order 'ord.
+ (vector<'o,'n,'ord,bit>, int) -> list<bit> effect pure duplicate_bits_to_list
+
+overload (deinfix ^^) [duplicate; duplicate_bits; duplicate_to_list; duplicate_bits_to_list]
(* Bitvector extension *)
val forall Nat 'n, Nat 'm, Nat 'o, Nat 'p, Order 'ord.
@@ -49,11 +56,14 @@ val forall Nat 'n, Nat 'm, Nat 'o, Nat 'p, Order 'ord.
val forall Nat 'm, Nat 'p, Order 'ord.
list<bit> -> vector<'p, 'm, 'ord, bit> effect pure extz_bl
-val cast forall Nat 'n, Nat 'm, Nat 'o, Nat 'p, Order 'ord.
+val forall Nat 'n, Nat 'm, Nat 'o, Nat 'p, Order 'ord.
vector<'o, 'n, 'ord, bit> -> vector<'p, 'm, 'ord, bit> effect pure exts
+val forall Nat 'm, Nat 'p, Order 'ord.
+ list<bit> -> vector<'p, 'm, 'ord, bit> effect pure exts_bl
+
overload EXTZ [extz; extz_bl]
-overload EXTS [exts]
+overload EXTS [exts; exts_bl]
val forall Type 'a, Nat 'n, Nat 'm, Nat 'o, Nat 'p, Order 'ord, 'm >= 'o.
vector<'n, 'm, 'ord, 'a> -> vector<'p, 'o, 'ord, 'a> effect pure mask
@@ -123,6 +133,27 @@ overload (deinfix -) [
sub_int
]
+val forall Nat 'n, Nat 'o, Order 'ord.
+ (vector<'o, 'n, 'ord, bit>, vector<'o, 'n, 'ord, bit>) -> vector<'o, 'n, 'ord, bit> effect pure add_svec
+
+overload (deinfix *_s) [
+ add_svec
+]
+
+val forall Nat 'n, Nat 'o, Order 'ord.
+ (vector<'o, 'n, 'ord, bit>, vector<'o, 'n, 'ord, bit>) -> vector<'o, 'n, 'ord, bit> effect pure xor_vec
+
+overload (deinfix ^) [
+ xor_vec
+]
+
+val forall Nat 'n, Nat 'o, Order 'ord.
+ (vector<'o, 'n, 'ord, bit>, int) -> vector<'o, 'n, 'ord, bit> effect pure shiftl
+
+overload (deinfix <<) [
+ shiftl
+]
+
(* Boolean operators *)
val bool -> bool effect pure bool_not
val (bool, bool) -> bool effect pure bool_or
@@ -185,6 +216,16 @@ overload (deinfix >) [gt_atom_atom; gt_vec; gt_int]
overload (deinfix <=) [lteq_atom_atom; lteq_range_atom; lteq_atom_range; lteq_vec; lteq_int]
overload (deinfix <) [lt_atom_atom; lt_vec; lt_int]
+val forall Num 'n, Num 'm, Order 'ord. (vector<'n,'m,'ord,bit>, vector<'n,'m,'ord,bit>) -> bool effect pure gteq_svec
+val forall Num 'n, Num 'm, Order 'ord. (vector<'n,'m,'ord,bit>, vector<'n,'m,'ord,bit>) -> bool effect pure gt_svec
+val forall Num 'n, Num 'm, Order 'ord. (vector<'n,'m,'ord,bit>, vector<'n,'m,'ord,bit>) -> bool effect pure lteq_svec
+val forall Num 'n, Num 'm, Order 'ord. (vector<'n,'m,'ord,bit>, vector<'n,'m,'ord,bit>) -> bool effect pure lt_svec
+
+overload (deinfix <_s) [lt_svec]
+overload (deinfix <=_s) [lteq_svec]
+overload (deinfix >_s) [gt_svec]
+overload (deinfix >=_s) [gteq_svec]
+
val (int, int) -> int effect pure quotient
overload (deinfix quot) [quotient]
diff --git a/mips_new_tc/mips_insts.sail b/mips_new_tc/mips_insts.sail
index 07d4d841..19176c72 100644
--- a/mips_new_tc/mips_insts.sail
+++ b/mips_new_tc/mips_insts.sail
@@ -144,13 +144,13 @@ function clause execute (ADDI(rs, rt, imm)) =
{
(bit[64]) opA := rGPR(rs);
if NotWordVal(opA) then
- wGPR(rt) := undefined (* XXX could exit instead *)
+ wGPR(rt) := (bit[64]) undefined (* XXX could exit instead *)
else
- let (bit[33]) sum33 = (EXTS(opA[31 .. 0]) + EXTS(imm)) in
+ let sum33 = (bit[33]) (EXTS(opA[31 .. 0])) + (bit[33]) (EXTS(imm)) in
if sum33[32] != sum33[31] then
(SignalException(Ov))
else
- wGPR(rt) := EXTS(sum33[31..0])
+ wGPR(rt) := EXTS(sum33[31..0])
}
(* ADDU 32-bit add immediate -- reg, reg, reg with possible undefined behaviour *)
@@ -165,9 +165,9 @@ function clause execute (ADDU(rs, rt, rd)) =
(bit[64]) opA := rGPR(rs);
(bit[64]) opB := rGPR(rt);
if NotWordVal(opA) | NotWordVal(opB) then
- wGPR(rd) := undefined
+ wGPR(rd) := (bit[64]) undefined
else
- wGPR(rd) := EXTS(opA[31..0] + opB[31..0])
+ wGPR(rd) := (bit[64]) (EXTS(opA[31..0] + opB[31..0]))
}
@@ -182,9 +182,9 @@ function clause execute (ADDIU(rs, rt, imm)) =
{
(bit[64]) opA := rGPR(rs);
if NotWordVal(opA) then
- wGPR(rt) := undefined (* XXX could exit instead *)
+ wGPR(rt) := (bit[64]) undefined (* XXX could exit instead *)
else
- wGPR(rt) := EXTS((opA[31 .. 0]) + EXTS(imm))
+ wGPR(rt) := (bit[64]) (EXTS((opA[31 .. 0]) + (bit[32]) (EXTS(imm))))
}
(**************************************************************************************)
@@ -212,7 +212,7 @@ function clause decode (0b000000 : (regno) rs : (regno) rt : (regno) rd : 0b0000
function clause execute (DSUB (rs, rt, rd)) =
{
- let (bit[65]) temp65 = (EXTS(rGPR(rs)) - EXTS(rGPR(rt))) in
+ let (bit[65]) temp65 = (bit[65]) (EXTS(rGPR(rs))) - (bit[65]) (EXTS(rGPR(rt))) in
{
if temp65[64] != temp65[63] then
(SignalException(Ov))
@@ -228,12 +228,12 @@ union ast member regregreg SUB
function clause decode (0b000000 : (regno) rs : (regno) rt : (regno) rd : 0b00000 : 0b100010) =
Some(SUB(rs, rt, rd))
-function clause execute (SUB(rs, rt, rd)) =
+function clause execute (SUB(rs, rt, rd)) =
{
(bit[64]) opA := rGPR(rs);
(bit[64]) opB := rGPR(rt);
if NotWordVal(opA) | NotWordVal(opB) then
- wGPR(rd) := undefined (* XXX could instead *)
+ wGPR(rd) := (bit[64]) undefined (* XXX could instead *)
else
let (bit[33]) temp33 = (EXTS(opA[31..0]) - EXTS(opB[31..0])) in
if temp33[32] != temp33[31] then
@@ -388,7 +388,7 @@ function clause decode (0b000000 : 0b00000 : (regno) rt : (regno) rd : (bit[5])
function clause execute (DSRA (rt, rd, sa)) =
{
temp := rGPR(rt);
- wGPR(rd) := ((temp[63] ^^ sa) : (temp[63 .. sa]))
+ wGPR(rd) := EXTS((temp[63] ^^ sa) : (temp[63 .. sa]))
}
(* DSRA32 reg, reg, imm5 *)
@@ -400,7 +400,7 @@ function clause execute (DSRA32 (rt, rd, sa)) =
{
temp := rGPR(rt);
sa32 := (0b1 : sa); (* sa+32 *)
- wGPR(rd) := ((temp[63] ^^ sa32) : (temp[63 .. sa32]))
+ wGPR(rd) := EXTS((temp[63] ^^ sa32) : (temp[63 .. sa32]))
}
(* DSRAV reg, reg, reg *)
@@ -411,7 +411,7 @@ function clause execute (DSRAV (rs, rt, rd)) =
{
temp := rGPR(rt);
sa := (rGPR(rs)) [5..0];
- wGPR(rd) := ((temp[63] ^^ sa) : temp[63 .. sa])
+ wGPR(rd) := EXTS((temp[63] ^^ sa) : temp[63 .. sa])
}
(* DSRL shift right logical - reg, reg, imm5 *)
@@ -422,7 +422,7 @@ function clause decode (0b000000 : 0b00000 : (regno) rt : (regno) rd : (bit[5])
function clause execute (DSRL (rt, rd, sa)) =
{
temp := rGPR(rt);
- wGPR(rd) := ((bitzero ^^ sa) : (temp[63 .. sa]))
+ wGPR(rd) := EXTS((bitzero ^^ sa) : (temp[63 .. sa]))
}
(* DSRL32 reg, reg, imm5 *)
@@ -434,7 +434,7 @@ function clause execute (DSRL32 (rt, rd, sa)) =
{
temp := rGPR(rt);
sa32 := (0b1 : sa); (* sa+32 *)
- wGPR(rd) := ((bitzero ^^ sa32) : (temp[63 .. sa32]))
+ wGPR(rd) := EXTS((bitzero ^^ sa32) : (temp[63 .. sa32]))
}
(* DSRLV reg, reg, reg *)
@@ -446,7 +446,7 @@ function clause execute (DSRLV (rs, rt, rd)) =
{
temp := rGPR(rt);
sa := (rGPR(rs)) [5..0];
- wGPR(rd) := ((bitzero ^^ sa) : temp[63 .. sa])
+ wGPR(rd) := EXTS((bitzero ^^ sa) : temp[63 .. sa])
}
(**************************************************************************************)
@@ -460,7 +460,7 @@ function clause decode (0b000000 : 0b00000 : (regno) rt : (regno) rd : (regno) s
Some(SLL(rt, rd, sa))
function clause execute (SLL(rt, rd, sa)) =
{
- wGPR(rd) := EXTS((rGPR(rt)) [(31-sa)..0] : (bitzero ^^ sa))
+ wGPR(rd) := EXTS((rGPR(rt)) [31 - sa .. 0] : (bitzero ^^ sa))
}
(* SLLV 32-bit shift left variable *)
@@ -471,7 +471,7 @@ function clause decode (0b000000 : (regno) rs : (regno) rt : (regno) rd : 0b0000
function clause execute (SLLV(rs, rt, rd)) =
{
sa := (rGPR(rs))[4..0];
- wGPR(rd) := EXTS((rGPR(rt)) [(31-sa)..0] : (bitzero ^^ sa))
+ wGPR(rd) := EXTS((rGPR(rt)) [31 - sa .. 0] : (bitzero ^^ sa))
}
(* SRA 32-bit arithmetic shift right *)
@@ -485,7 +485,7 @@ function clause execute (SRA(rt, rd, sa)) =
if (NotWordVal(temp)) then
wGPR(rd) := undefined
else
- wGPR(rd) := (temp[31] ^^ (sa+32)) : temp [31..sa]
+ wGPR(rd) := EXTS((temp[31] ^^ (sa+32)) : temp [31..sa])
}
(* SRAV 32-bit arithmetic shift right variable *)
@@ -500,7 +500,7 @@ function clause execute (SRAV(rs, rt, rd)) =
if (NotWordVal(temp)) then
wGPR(rd) := undefined
else
- wGPR(rd) := (temp[31] ^^ (sa+32)) : temp [31..sa]
+ wGPR(rd) := EXTS((temp[31] ^^ (sa+32)) : temp [31..sa])
}
(* SRL 32-bit shift right *)
@@ -652,7 +652,7 @@ function clause execute (MUL(rs, rt, rd)) =
{
rsVal := rGPR(rs);
rtVal := rGPR(rt);
- (bit[64]) result := (rsVal[31..0]) *_s (rtVal[31..0]);
+ (bit[64]) result := EXTS((rsVal[31..0]) *_s (rtVal[31..0]));
wGPR(rd) := if (NotWordVal(rsVal) | NotWordVal(rtVal)) then
undefined
else
@@ -675,7 +675,7 @@ function clause execute (MULT(rs, rt)) =
(bit[64]) result := if (NotWordVal(rsVal) | NotWordVal(rtVal)) then
undefined
else
- (rsVal[31..0]) *_s (rtVal[31..0]);
+ EXTS((rsVal[31..0]) *_s (rtVal[31..0]));
HI := EXTS(result[63..32]);
LO := EXTS(result[31..0]);
}
diff --git a/src/parser.mly b/src/parser.mly
index 9f48067f..8e61a0ac 100644
--- a/src/parser.mly
+++ b/src/parser.mly
@@ -209,6 +209,10 @@ id:
{ idl (DeIid($3)) }
| Lparen Deinfix Lt Rparen
{ idl (DeIid($3)) }
+ | Lparen Deinfix GtUnderS Rparen
+ { idl (DeIid($3)) }
+ | Lparen Deinfix LtUnderS Rparen
+ { idl (DeIid($3)) }
| Lparen Deinfix Minus Rparen
{ idl (DeIid("-")) }
| Lparen Deinfix MinusUnderS Rparen
@@ -243,6 +247,8 @@ id:
{ idl (DeIid($3)) }
| Lparen Deinfix GtEq Rparen
{ idl (DeIid($3)) }
+ | Lparen Deinfix GtEqUnderS Rparen
+ { idl (DeIid($3)) }
| Lparen Deinfix GtEqPlus Rparen
{ idl (DeIid($3)) }
| Lparen Deinfix GtGt Rparen
@@ -257,6 +263,8 @@ id:
{ idl (DeIid($3)) }
| Lparen Deinfix LtEq Rparen
{ idl (DeIid($3)) }
+ | Lparen Deinfix LtEqUnderS Rparen
+ { idl (DeIid($3)) }
| Lparen Deinfix LtLt Rparen
{ idl (DeIid($3)) }
| Lparen Deinfix LtLtLt Rparen
diff --git a/src/pretty_print_sail.ml b/src/pretty_print_sail.ml
index 9c33c841..6826087a 100644
--- a/src/pretty_print_sail.ml
+++ b/src/pretty_print_sail.ml
@@ -488,6 +488,7 @@ let doc_rec (Rec_aux(r,_)) = match r with
let doc_tannot_opt (Typ_annot_opt_aux(t,_)) = match t with
| Typ_annot_opt_some(tq,typ) -> doc_typquant tq (doc_typ typ)
+ | Typ_annot_opt_none -> empty
let doc_effects_opt (Effect_opt_aux(e,_)) = match e with
| Effect_opt_pure -> string "pure"
diff --git a/src/type_check_new.ml b/src/type_check_new.ml
index 8232fd6a..5be2cd43 100644
--- a/src/type_check_new.ml
+++ b/src/type_check_new.ml
@@ -346,6 +346,7 @@ module Env : sig
val set_default_order_inc : t -> t
val set_default_order_dec : t -> t
val add_enum : id -> id list -> t -> t
+ val get_enum : id -> t -> id list
val get_casts : t -> id list
val allow_casts : t -> bool
val no_casts : t -> t
@@ -518,6 +519,11 @@ end = struct
{ env with enums = Bindings.add id (IdSet.of_list ids) env.enums }
end
+ let get_enum id env =
+ try IdSet.elements (Bindings.find id env.enums)
+ with
+ | Not_found -> typ_error (id_loc id) ("Enumeration " ^ string_of_id id ^ " does not exist")
+
let is_record id env = Bindings.mem id env.records
let add_record id typq fields env =
@@ -1533,19 +1539,24 @@ let rec check_exp env (E_aux (exp_aux, (l, ())) as exp : unit exp) (Typ_aux (typ
begin
match letbind with
| LB_val_explicit (typschm, pat, bind) -> assert false
+ | LB_val_implicit (P_aux (P_typ (ptyp, _), _) as pat, bind) ->
+ let checked_bind = crule check_exp env bind ptyp in
+ let tpat, env = bind_pat env pat (typ_of checked_bind) in
+ annot_exp (E_let (LB_aux (LB_val_implicit (tpat, checked_bind), (let_loc, None)), crule check_exp env exp typ)) typ
| LB_val_implicit (pat, bind) ->
let inferred_bind = irule infer_exp env bind in
let tpat, env = bind_pat env pat (typ_of inferred_bind) in
annot_exp (E_let (LB_aux (LB_val_implicit (tpat, inferred_bind), (let_loc, None)), crule check_exp env exp typ)) typ
end
- | E_app_infix (x, op, y), _ when List.length (Env.get_overloads (deinfix op) env) > 0 -> check_exp env (E_aux (E_app (deinfix op, [x; y]), (l, ()))) typ
+ | E_app_infix (x, op, y), _ when List.length (Env.get_overloads (deinfix op) env) > 0 ->
+ check_exp env (E_aux (E_app (deinfix op, [x; y]), (l, ()))) typ
| E_app (f, xs), _ when List.length (Env.get_overloads f env) > 0 ->
let rec try_overload = function
| [] -> typ_error l ("No valid overloading for " ^ string_of_exp exp)
| (f :: fs) -> begin
typ_print ("Overload: " ^ string_of_id f ^ "(" ^ string_of_list ", " string_of_exp xs ^ ")");
try crule check_exp env (E_aux (E_app (f, xs), (l, ()))) typ with
- | Type_error (_, m2) -> try_overload fs
+ | Type_error (_, m) -> typ_print ("Error : " ^ m); try_overload fs
end
in
try_overload (Env.get_overloads f env)
@@ -1789,6 +1800,10 @@ and bind_assignment env (LEXP_aux (lexp_aux, _) as lexp) (E_aux (_, (l, ())) as
end
| LEXP_memory (f, xs) ->
check_exp env (E_aux (E_app (f, xs @ [exp]), (l, ()))) unit_typ, env
+ | LEXP_cast (typ_annot, v) ->
+ let checked_exp = crule check_exp env exp typ_annot in
+ let tlexp, env' = bind_lexp env lexp (typ_of checked_exp) in
+ annot_assign tlexp checked_exp, env'
| _ ->
let inferred_exp = irule infer_exp env exp in
let tlexp, env' = bind_lexp env lexp (typ_of inferred_exp) in
@@ -1963,7 +1978,7 @@ and infer_exp env (E_aux (exp_aux, (l, ())) as exp) =
| (f :: fs) -> begin
typ_print ("Overload: " ^ string_of_id f ^ "(" ^ string_of_list ", " string_of_exp xs ^ ")");
try irule infer_exp env (E_aux (E_app (f, xs), (l, ()))) with
- | Type_error (_, m2) -> try_overload fs
+ | Type_error (_, m) -> typ_print ("Error: " ^ m); try_overload fs
end
in
try_overload (Env.get_overloads f env)
@@ -2043,6 +2058,12 @@ and infer_funapp' l env f (typq, f_typ) xs ret_ctx_typ =
else typ_error l ("Quantifiers " ^ string_of_list ", " string_of_quant_item quants
^ " not resolved during application of " ^ string_of_id f)
end
+ | (utyps, (typ :: typs)), (uargs, ((n, arg) :: args)) when KidSet.is_empty (typ_frees typ) ->
+ begin
+ let carg = crule check_exp env arg typ in
+ let (iargs, ret_typ') = instantiate quants (utyps, typs) ret_typ (uargs, args) in
+ ((n, carg) :: iargs, ret_typ')
+ end
| (utyps, (typ :: typs)), (uargs, ((n, arg) :: args)) ->
begin
typ_debug ("INSTANTIATE: " ^ string_of_exp arg ^ " with " ^ string_of_typ typ ^ " NF " ^ string_of_tnf (normalize_typ env typ));
diff --git a/src/type_check_new.mli b/src/type_check_new.mli
index f55ccf3a..cdacf523 100644
--- a/src/type_check_new.mli
+++ b/src/type_check_new.mli
@@ -64,6 +64,10 @@ module Env : sig
val get_regtyp : id -> t -> int * int * (index_range * id) list
+ (* Return all the identifiers in an enumeration. Throws a type error
+ if the enumeration doesn't exist. *)
+ val get_enum : id -> t -> id list
+
(* Returns true if id is a register type, false otherwise *)
val is_regtyp : id -> t -> bool