diff options
| -rw-r--r-- | lib/mono_rewrites.sail | 92 | ||||
| -rw-r--r-- | lib/vector_dec.sail | 2 | ||||
| -rw-r--r-- | src/monomorphise.ml | 168 | ||||
| -rw-r--r-- | src/rewrites.ml | 12 | ||||
| -rw-r--r-- | src/type_check.ml | 58 |
5 files changed, 226 insertions, 106 deletions
diff --git a/lib/mono_rewrites.sail b/lib/mono_rewrites.sail index 8259ec47..90d74149 100644 --- a/lib/mono_rewrites.sail +++ b/lib/mono_rewrites.sail @@ -19,13 +19,9 @@ overload operator >> = {shiftright} val arith_shiftright = "arith_shiftr" : forall 'n ('ord : Order). (vector('n, 'ord, bit), int) -> vector('n, 'ord, bit) effect pure -val "extz_vec" : forall 'n 'm. (atom('m),vector('n, dec, bit)) -> vector('m, dec, bit) effect pure -val extzv : forall 'n 'm. vector('n, dec, bit) -> vector('m, dec, bit) effect pure -function extzv(v) = extz_vec(sizeof('m),v) +val extzv = "extz_vec" : forall 'n 'm. (implicit('m), vector('n, dec, bit)) -> vector('m, dec, bit) effect pure -val "exts_vec" : forall 'n 'm. (atom('m),vector('n, dec, bit)) -> vector('m, dec, bit) effect pure -val extsv : forall 'n 'm. vector('n, dec, bit) -> vector('m, dec, bit) effect pure -function extsv(v) = exts_vec(sizeof('m),v) +val extsv = "exts_vec" : forall 'n 'm. (implicit('m), vector('n, dec, bit)) -> vector('m, dec, bit) effect pure /* This is generated internally to deal with case splits which reveal the size of a bitvector */ @@ -34,9 +30,9 @@ val bitvector_cast_out = "zeroExtend" : forall 'n. bits('n) -> bits('n) effect p /* Definitions for the rewrites */ -val slice_mask : forall 'n, 'n >= 0. (int, int) -> bits('n) effect pure -function slice_mask(i,l) = - let one : bits('n) = extzv(0b1) in +val slice_mask : forall 'n, 'n >= 0. (implicit('n), int, int) -> bits('n) effect pure +function slice_mask(n,i,l) = + let one : bits('n) = extzv(n, 0b1) in ((one << l) - one) << i val is_zero_subrange : forall 'n, 'n >= 0. @@ -46,6 +42,13 @@ function is_zero_subrange (xs, i, j) = { (xs & slice_mask(j, i-j+1)) == extzv(0b0) } +val is_zeros_slice : forall 'n, 'n >= 0. + (bits('n), int, int) -> bool effect pure + +function is_zeros_slice (xs, i, l) = { + (xs & slice_mask(i, l)) == extzv(0b0) +} + val is_ones_subrange : forall 'n, 'n >= 0. (bits('n), int, int) -> bool effect pure @@ -54,21 +57,29 @@ function is_ones_subrange (xs, i, j) = { (xs & m) == m } +val is_ones_slice : forall 'n, 'n >= 0. + (bits('n), int, int) -> bool effect pure + +function is_ones_slice (xs, i, j) = { + let m : bits('n) = slice_mask(i,j) in + (xs & m) == m +} + val slice_slice_concat : forall 'n 'm 'r, 'n >= 0 & 'm >= 0 & 'r >= 0. - (bits('n), int, int, bits('m), int, int) -> bits('r) effect pure + (implicit('r), bits('n), int, int, bits('m), int, int) -> bits('r) effect pure -function slice_slice_concat (xs, i, l, ys, i', l') = { +function slice_slice_concat (r, xs, i, l, ys, i', l') = { let xs = (xs & slice_mask(i,l)) >> i in let ys = (ys & slice_mask(i',l')) >> i' in - extzv(xs) << l' | extzv(ys) + extzv(r, xs) << l' | extzv(r, ys) } -val slice_zeros_concat : forall 'n 'p 'q 'r, 'r == 'p + 'q & 'n >= 0 & /*'p >= 0 & 'q >= 0 &*/ 'r >= 0. - (bits('n), int, atom('p), atom('q)) -> bits('r) effect pure +val slice_zeros_concat : forall 'n 'p 'q, 'n >= 0 & 'p + 'q >= 0. + (bits('n), int, atom('p), atom('q)) -> bits('p + 'q) effect pure function slice_zeros_concat (xs, i, l, l') = { let xs = (xs & slice_mask(i,l)) >> i in - extzv(xs) << l' + extzv(l + l', xs) << l' } /* Assumes initial vectors are of equal size */ @@ -83,44 +94,59 @@ function subrange_subrange_eq (xs, i, j, ys, i', j') = { } val subrange_subrange_concat : forall 'n 'o 'p 'm 'q 'r 's, 's >= 0 & 'n >= 0 & 'm >= 0. - (bits('n), atom('o), atom('p), bits('m), atom('q), atom('r)) -> bits('s) effect pure + (implicit('s), bits('n), atom('o), atom('p), bits('m), atom('q), atom('r)) -> bits('s) effect pure -function subrange_subrange_concat (xs, i, j, ys, i', j') = { +function subrange_subrange_concat (s, xs, i, j, ys, i', j') = { let xs = (xs & slice_mask(j,i-j+1)) >> j in let ys = (ys & slice_mask(j',i'-j'+1)) >> j' in - extzv(xs) << (i' - j' + 1) | extzv(ys) + extzv(s, xs) << (i' - j' + 1) | extzv(s, ys) } val place_subrange : forall 'n 'm, 'n >= 0 & 'm >= 0. - (bits('n), int, int, int) -> bits('m) effect pure + (implicit('m), bits('n), int, int, int) -> bits('m) effect pure -function place_subrange(xs,i,j,shift) = { +function place_subrange(m,xs,i,j,shift) = { let xs = (xs & slice_mask(j,i-j+1)) >> j in - extzv(xs) << shift + extzv(m, xs) << shift } val place_slice : forall 'n 'm, 'n >= 0 & 'm >= 0. - (bits('n), int, int, int) -> bits('m) effect pure + (implicit('m), bits('n), int, int, int) -> bits('m) effect pure -function place_slice(xs,i,l,shift) = { +function place_slice(m,xs,i,l,shift) = { let xs = (xs & slice_mask(i,l)) >> i in - extzv(xs) << shift + extzv(m, xs) << shift +} + +val set_slice_zeros : forall 'n, 'n >= 0. + (atom('n), int, bits('n), int) -> bits('n) effect pure + +function set_slice_zeros(n, i, xs, l) = { + let ys : bits('n) = slice_mask(n, i, l) in + xs & ~(ys) } val zext_slice : forall 'n 'm, 'n >= 0 & 'm >= 0. - (bits('n), int, int) -> bits('m) effect pure + (implicit('m), bits('n), int, int) -> bits('m) effect pure -function zext_slice(xs,i,l) = { +function zext_slice(m,xs,i,l) = { let xs = (xs & slice_mask(i,l)) >> i in - extzv(xs) + extzv(m, xs) } val sext_slice : forall 'n 'm, 'n >= 0 & 'm >= 0. - (bits('n), int, int) -> bits('m) effect pure + (implicit('m), bits('n), int, int) -> bits('m) effect pure -function sext_slice(xs,i,l) = { +function sext_slice(m,xs,i,l) = { let xs = arith_shiftright(((xs & slice_mask(i,l)) << ('n - i - l)), 'n - l) in - extsv(xs) + extsv(m, xs) +} + +val place_slice_signed : forall 'n 'm, 'n >= 0 & 'm >= 0. + (implicit('m), bits('n), int, int, int) -> bits('m) effect pure + +function place_slice_signed(m,xs,i,l,shift) = { + sext_slice(m, xs, i, l) << shift } /* This has different names in the aarch64 prelude (UInt) and the other @@ -150,9 +176,9 @@ function unsigned_subrange(xs,i,j) = { } -val zext_ones : forall 'n, 'n >= 0. int -> bits('n) effect pure +val zext_ones : forall 'n, 'n >= 0. (implicit('n), int) -> bits('n) effect pure -function zext_ones(m) = { +function zext_ones(n, m) = { let v : bits('n) = extsv(0b1) in - v >> ('n - m) + v >> (n - m) } diff --git a/lib/vector_dec.sail b/lib/vector_dec.sail index 166db243..8c6426d4 100644 --- a/lib/vector_dec.sail +++ b/lib/vector_dec.sail @@ -166,6 +166,6 @@ val signed = { _: "sint" } : forall 'n, 'n > 0. bits('n) -> range(- (2 ^ ('n - 1)), 2 ^ ('n - 1) - 1) -overload __size = {length} +overload __size = {__id, length} $endif diff --git a/src/monomorphise.ml b/src/monomorphise.ml index 2b6da219..8c52fce1 100644 --- a/src/monomorphise.ml +++ b/src/monomorphise.ml @@ -2263,9 +2263,12 @@ let replace_with_the_value bound_nexps (E_aux (_,(l,_)) as exp) = prove __POS__ env (NC_aux (NC_equal (size,nexp), Parse_ast.Unknown)) in if is_nexp_constant size then size else - match List.find is_equal bound_nexps with - | nexp -> nexp - | exception Not_found -> size + match solve env size with + | Some n -> nconstant n + | None -> + match List.find is_equal bound_nexps with + | nexp -> nexp + | exception Not_found -> size in let mk_exp nexp l l' = let nexp = replace_size nexp in @@ -2419,15 +2422,15 @@ in *) | Some exp -> Some (fold_exp { id_exp_alg with e_app = rewrite_e_app } exp) in FCL_aux (FCL_Funcl (id,construct_pexp (pat,guard,body,(pl,empty_tannot))),(l,empty_tannot)) in - let rewrite_letbind lb = - let rewrite_e_app (id,args) = - match Bindings.find id fn_sizes with - | to_change,_ -> - let args' = mapat (replace_with_the_value []) to_change args in - E_app (id,args') - | exception Not_found -> E_app (id,args) - in fold_letbind { id_exp_alg with e_app = rewrite_e_app } lb + let rewrite_e_app (id,args) = + match Bindings.find id fn_sizes with + | to_change,_ -> + let args' = mapat (replace_with_the_value []) to_change args in + E_app (id,args') + | exception Not_found -> E_app (id,args) in + let rewrite_letbind = fold_letbind { id_exp_alg with e_app = rewrite_e_app } in + let rewrite_exp = fold_exp { id_exp_alg with e_app = rewrite_e_app } in let rewrite_def = function | DEF_fundef (FD_aux (FD_function (recopt,tannopt,effopt,funcls),(l,_))) -> (* TODO rewrite tannopt? *) @@ -2449,6 +2452,8 @@ in *) | _ -> spec | exception Not_found -> spec end + | DEF_reg_dec (DEC_aux (DEC_config (id, typ, exp), a)) -> + DEF_reg_dec (DEC_aux (DEC_config (id, typ, rewrite_exp exp), a)) | def -> def in (* @@ -3671,14 +3676,18 @@ let is_constant_vec_typ env typ = (* We have to add casts in here with appropriate length information so that the type checker knows the expected return types. *) -let rewrite_app env typ (id,args) = +let rec rewrite_app env typ (id,args) = let is_append = is_id env (Id "append") in + let is_subrange = is_id env (Id "vector_subrange") in + let is_slice = is_id env (Id "slice") in + let is_zeros = is_id env (Id "Zeros") in let is_zero_extend = - is_id env (Id "Extend") id || is_id env (Id "ZeroExtend") id || + is_id env (Id "ZeroExtend") id || is_id env (Id "zero_extend") id || is_id env (Id "sail_zero_extend") id || is_id env (Id "mips_zero_extend") id in - let try_cast_to_typ (E_aux (e,_) as exp) = + let mk_exp e = E_aux (e, (Unknown, empty_tannot)) in + let try_cast_to_typ (E_aux (e,(l, _)) as exp) = let (size,order,bittyp) = vector_typ_args_of (Env.base_typ_of env typ) in match size with | Nexp_aux (Nexp_constant _,_) -> E_cast (typ,exp) @@ -3688,9 +3697,6 @@ let rewrite_app env typ (id,args) = in let rewrap e = E_aux (e, (Unknown, empty_tannot)) in if is_append id then - let is_subrange = is_id env (Id "vector_subrange") in - let is_slice = is_id env (Id "slice") in - let is_zeros = is_id env (Id "Zeros") in match args with (* (known-size-vector @ variable-vector) @ variable-vector *) | [E_aux (E_app (append, @@ -3750,6 +3756,14 @@ let rewrite_app env typ (id,args) = (Unknown,empty_tannot))]) end + (* variable-slice @ zeros *) + | [E_aux (E_app (slice1, [vector1; start1; len1]),_); + E_aux (E_app (zeros2, [len2]),_)] + when is_slice slice1 && is_zeros zeros2 && + not (is_constant start1 && is_constant len1 && is_constant len2) -> + try_cast_to_typ + (mk_exp (E_app (mk_id "place_slice", [vector1; start1; len1; len2]))) + (* variable-range @ variable-range *) | [E_aux (E_app (subrange1, [vector1; start1; end1]),_); @@ -3803,9 +3817,14 @@ let rewrite_app env typ (id,args) = end | _ -> E_app (id,args) - else if is_id env (Id "eq_vec") id then + else if is_id env (Id "eq_vec") id || is_id env (Id "neq_vec") id then (* variable-range == variable_range *) let is_subrange = is_id env (Id "vector_subrange") in + let wrap e = + if is_id env (Id "neq_vec") id + then E_app (mk_id "not_bool", [mk_exp e]) + else e + in match args with | [E_aux (E_app (subrange1, [vector1; start1; end1]),_); @@ -3813,17 +3832,37 @@ let rewrite_app env typ (id,args) = [vector2; start2; end2]),_)] when is_subrange subrange1 && is_subrange subrange2 && not (is_constant_range (start1, end1) || is_constant_range (start2, end2)) -> - E_app (mk_id "subrange_subrange_eq", - [vector1; start1; end1; vector2; start2; end2]) + wrap (E_app (mk_id "subrange_subrange_eq", + [vector1; start1; end1; vector2; start2; end2])) + | [E_aux (E_app (slice1, + [vector1; len1; start1]),_); + E_aux (E_app (slice2, + [vector2; len2; start2]),_)] + when is_slice slice1 && is_slice slice2 && + not (is_constant len1 && is_constant start1 && is_constant len2 && is_constant start2) -> + let upper start len = + mk_exp (E_app_infix (start, mk_id "+", + mk_exp (E_app_infix (len, mk_id "-", + mk_exp (E_lit (mk_lit (L_num (Big_int.of_int 1)))))))) + in + wrap (E_app (mk_id "subrange_subrange_eq", + [vector1; upper start1 len1; start1; vector2; upper start2 len2; start2])) + | [E_aux (E_app (slice1, [vector1; start1; len1]), _); + E_aux (E_app (zeros2, _), _)] + when is_slice slice1 && is_zeros zeros2 && not (is_constant len1) -> + wrap (E_app (mk_id "is_zeros_slice", [vector1; start1; len1])) | _ -> E_app (id,args) else if is_id env (Id "IsZero") id then match args with | [E_aux (E_app (subrange1, [vector1; start1; end1]),_)] - when is_id env (Id "vector_subrange") subrange1 && + when (is_id env (Id "vector_subrange") subrange1) && not (is_constant_range (start1,end1)) -> - E_app (mk_id "is_zero_subrange", - [vector1; start1; end1]) + E_app (mk_id "is_zero_subrange", [vector1; start1; end1]) + | [E_aux (E_app (slice1, [vector1; start1; len1]),_)] + when (is_slice slice1) && + not (is_constant len1) -> + E_app (mk_id "is_zeros_slice", [vector1; start1; len1]) | _ -> E_app (id,args) else if is_id env (Id "IsOnes") id then @@ -3833,6 +3872,9 @@ let rewrite_app env typ (id,args) = not (is_constant_range (start1,end1)) -> E_app (mk_id "is_ones_subrange", [vector1; start1; end1]) + | [E_aux (E_app (slice1, [vector1; start1; len1]),_)] + when is_slice slice1 && not (is_constant len1) -> + E_app (mk_id "is_ones_slice", [vector1; start1; len1]) | _ -> E_app (id,args) else if is_zero_extend then @@ -3840,52 +3882,59 @@ let rewrite_app env typ (id,args) = let is_slice = is_id env (Id "slice") in let is_zeros = is_id env (Id "Zeros") in let is_ones = is_id env (Id "Ones") in - match args with - | (E_aux (E_app (append1, + let length_arg = List.filter (fun arg -> is_number (typ_of arg)) args in + match List.filter (fun arg -> not (is_number (typ_of arg))) args with + | [E_aux (E_app (append1, [E_aux (E_app (subrange1, [vector1; start1; end1]), _); - E_aux (E_app (zeros1, [len1]),_)]),_)):: - ([] | [_;E_aux (E_id (Id_aux (Id "unsigned",_)),_)]) + E_aux (E_app (zeros1, [len1]),_)]),_)] when is_subrange subrange1 && is_zeros zeros1 && is_append append1 - -> try_cast_to_typ (rewrap (E_app (mk_id "place_subrange", [vector1; start1; end1; len1]))) + -> try_cast_to_typ (rewrap (E_app (mk_id "place_subrange", length_arg @ [vector1; start1; end1; len1]))) - | (E_aux (E_app (append1, + | [E_aux (E_app (append1, [E_aux (E_app (slice1, [vector1; start1; length1]), _); - E_aux (E_app (zeros1, [length2]),_)]),_)):: - ([] | [_;E_aux (E_id (Id_aux (Id "unsigned",_)),_)]) + E_aux (E_app (zeros1, [length2]),_)]),_)] when is_slice slice1 && is_zeros zeros1 && is_append append1 - -> try_cast_to_typ (rewrap (E_app (mk_id "place_slice", [vector1; start1; length1; length2]))) + -> try_cast_to_typ (rewrap (E_app (mk_id "place_slice", length_arg @ [vector1; start1; length1; length2]))) (* If we've already rewritten to slice_slice_concat or subrange_subrange_concat, we can just drop the zero extension because those functions can do it themselves *) - | (E_aux (E_cast (_, (E_aux (E_app (Id_aux ((Id "slice_slice_concat" | Id "subrange_subrange_concat"),_) as op, args),_))),_)):: - ([] | [_;E_aux (E_id (Id_aux (Id "unsigned",_)),_)]) - -> try_cast_to_typ (rewrap (E_app (op, args))) + | [E_aux (E_cast (_, (E_aux (E_app (Id_aux ((Id "slice_slice_concat" | Id "subrange_subrange_concat" | Id "place_slice"),_) as op, args),_))),_)] + -> try_cast_to_typ (rewrap (E_app (op, length_arg @ args))) - | (E_aux (E_app (Id_aux ((Id "slice_slice_concat" | Id "subrange_subrange_concat"),_) as op, args),_)):: - ([] | [_;E_aux (E_id (Id_aux (Id "unsigned",_)),_)]) - -> try_cast_to_typ (rewrap (E_app (op, args))) + | [E_aux (E_app (Id_aux ((Id "slice_slice_concat" | Id "subrange_subrange_concat" | Id "place_slice"),_) as op, args),_)] + -> try_cast_to_typ (rewrap (E_app (op, length_arg @ args))) | [E_aux (E_app (slice1, [vector1; start1; length1]),_)] when is_slice slice1 && not (is_constant length1) -> - try_cast_to_typ (rewrap (E_app (mk_id "zext_slice", [vector1; start1; length1]))) + try_cast_to_typ (rewrap (E_app (mk_id "zext_slice", length_arg @ [vector1; start1; length1]))) - | [E_aux (E_app (ones, [len1]),_); - _ (* unnecessary ZeroExtend length *)] - when is_ones ones -> - try_cast_to_typ (rewrap (E_app (mk_id "zext_ones", [len1]))) + | [E_aux (E_app (ones, [len1]),_)] when is_ones ones -> + try_cast_to_typ (rewrap (E_app (mk_id "zext_ones", length_arg @ [len1]))) | _ -> E_app (id,args) else if is_id env (Id "SignExtend") id || is_id env (Id "sign_extend") id then let is_slice = is_id env (Id "slice") in - match args with + let length_arg = List.filter (fun arg -> is_number (typ_of arg)) args in + match List.filter (fun arg -> not (is_number (typ_of arg))) args with | [E_aux (E_app (slice1, [vector1; start1; length1]),_)] when is_slice slice1 && not (is_constant length1) -> - E_app (mk_id "sext_slice", [vector1; start1; length1]) + try_cast_to_typ (rewrap (E_app (mk_id "sext_slice", length_arg @ [vector1; start1; length1]))) + + | [E_aux (E_app (append, + [E_aux (E_app (slice1, [vector1; start1; len1]), _); + E_aux (E_app (zeros2, [len2]), _)]), _)] + when is_append append && is_slice slice1 && is_zeros zeros2 && + not (is_constant len1 && is_constant len2) -> + E_app (mk_id "place_slice_signed", length_arg @ [vector1; start1; len1; len2]) + + | [E_aux (E_cast (_, (E_aux (E_app (Id_aux ((Id "place_slice"),_), args),_))),_)] + | [E_aux (E_app (Id_aux ((Id "place_slice"),_), args),_)] + -> try_cast_to_typ (rewrap (E_app (mk_id "place_slice_signed", length_arg @ args))) (* If the original had a length, keep it *) - | [E_aux (E_app (slice1, [vector1; start1; length1]),_);length2] + (* | [E_aux (E_app (slice1, [vector1; start1; length1]),_);length2] when is_slice slice1 && not (is_constant length1) -> begin match Type_check.destruct_atom_nexp (env_of length2) (typ_of length2) with @@ -3895,10 +3944,18 @@ let rewrite_app env typ (id,args) = E_cast (vector_typ nlen order bittyp, E_aux (E_app (mk_id "sext_slice", [vector1; start1; length1]), (Unknown,empty_tannot))) - end + end *) | _ -> E_app (id,args) + else if is_id env (Id "Extend") id then + match args with + | [vector; len; unsigned] -> + let extz = mk_exp (rewrite_app env typ (mk_id "ZeroExtend", [vector; len])) in + let exts = mk_exp (rewrite_app env typ (mk_id "SignExtend", [vector; len])) in + E_if (unsigned, extz, exts) + | _ -> E_app (id, args) + else if is_id env (Id "UInt") id || is_id env (Id "unsigned") id then let is_slice = is_id env (Id "slice") in let is_subrange = is_id env (Id "vector_subrange") in @@ -3912,6 +3969,13 @@ let rewrite_app env typ (id,args) = | _ -> E_app (id,args) + else if is_id env (Id "__SetSlice_bits") id then + match args with + | [len; slice_len; vector; pos; E_aux (E_app (zeros, _), _)] + when is_zeros zeros -> + E_app (mk_id "set_slice_zeros", [len; slice_len; vector; pos]) + | _ -> E_app (id, args) + else E_app (id,args) let rewrite_aux = function @@ -4412,7 +4476,9 @@ let rewrite_toplevel_nexps (Defs defs) = VS_aux (VS_val_spec (TypSchm_aux (TypSchm_ts (tqs,typ),ts_l),id,ext_opt,is_cast),ann) in Some (id, nexp_map, vs) in - let rewrite_typ_in_body env nexp_map typ = + (* Changing types in the body confuses simple sizeof rewriting, so turn it + off for now *) + (* let rewrite_typ_in_body env nexp_map typ = let rec aux (Typ_aux (t,l) as typ_full) = match t with | Typ_tup typs -> Typ_aux (Typ_tup (List.map aux typs),l) @@ -4468,19 +4534,19 @@ let rewrite_toplevel_nexps (Defs defs) = match Bindings.find id spec_map with | nexp_map -> FCL_aux (FCL_Funcl (id,rewrite_body nexp_map pexp),ann) | exception Not_found -> funcl - in + in *) let rewrite_def spec_map def = match def with | DEF_spec vs -> (match rewrite_valspec vs with | None -> spec_map, def | Some (id, nexp_map, vs) -> Bindings.add id nexp_map spec_map, DEF_spec vs) - | DEF_fundef (FD_aux (FD_function (recopt,_,eff,funcls),ann)) -> + (* | DEF_fundef (FD_aux (FD_function (recopt,_,eff,funcls),ann)) -> (* Type annotations on function definitions will have been turned into valspecs by type checking, so it should be safe to drop them rather than updating them. *) let tann = Typ_annot_opt_aux (Typ_annot_opt_none,Generated Unknown) in spec_map, - DEF_fundef (FD_aux (FD_function (recopt,tann,eff,List.map (rewrite_funcl spec_map) funcls),ann)) + DEF_fundef (FD_aux (FD_function (recopt,tann,eff,List.map (rewrite_funcl spec_map) funcls),ann)) *) | _ -> spec_map, def in let _, defs = List.fold_left (fun (spec_map,t) def -> diff --git a/src/rewrites.ml b/src/rewrites.ml index 4b147aee..5cbc3545 100644 --- a/src/rewrites.ml +++ b/src/rewrites.ml @@ -4012,6 +4012,16 @@ let rewrite_defs_remove_superfluous_letbinds = E_aux (E_internal_return (exp1),e1annot) | _ -> E_aux (exp,annot) end + | E_internal_plet (_, E_aux (E_throw e, a), _) -> E_aux (E_throw e, a) + | E_internal_plet (_, E_aux (E_assert (c, msg), a), _) -> + begin match typ_of c with + | Typ_aux (Typ_app (Id_aux (Id "atom_bool", _), [A_aux (A_bool nc, _)]), _) + when prove __POS__ (env_of c) (nc_not nc) -> + (* Drop rest of block after an 'assert(false)' *) + E_aux (E_exit (infer_exp (env_of c) (mk_lit_exp L_unit)), a) + | _ -> + E_aux (exp, annot) + end | _ -> E_aux (exp,annot) in let alg = { id_exp_alg with e_aux = e_aux } in @@ -5064,7 +5074,7 @@ let rewrite_defs_lem = [ (* ("remove_assert", rewrite_defs_remove_assert); *) ("top_sort_defs", top_sort_defs); ("trivial_sizeof", rewrite_trivial_sizeof); - ("sizeof", rewrite_sizeof); + (* ("sizeof", rewrite_sizeof); *) ("early_return", rewrite_defs_early_return); ("fix_val_specs", rewrite_fix_val_specs); (* early_return currently breaks the types *) diff --git a/src/type_check.ml b/src/type_check.ml index 8fca2c7a..7faa0234 100644 --- a/src/type_check.ml +++ b/src/type_check.ml @@ -689,6 +689,8 @@ end = struct typ_error env l "Bidirectional types cannot be the same on both sides" | Typ_bidir (typ1, typ2) -> wf_typ ~exs:exs env typ1; wf_typ ~exs:exs env typ2 | Typ_tup typs -> List.iter (wf_typ ~exs:exs env) typs + | Typ_app (id, [A_aux (A_nexp _, _) as arg]) when string_of_id id = "implicit" -> + wf_typ_arg ~exs:exs env arg | Typ_app (id, args) when bound_typ_id env id -> List.iter (wf_typ_arg ~exs:exs env) args; check_args_typquant id env args (infer_kind env id) @@ -1612,12 +1614,12 @@ and unify_nexp l env goals (Nexp_aux (nexp_aux1, _) as nexp1) (Nexp_aux (nexp_au else begin match nexp_aux2 with | Nexp_sum (n2a, n2b) -> - if nexp_identical n1a n2a - then unify_nexp l env goals n1b n2b + if KidSet.is_empty (nexp_frees n2a) + then unify_nexp l env goals n2b (nminus nexp1 n2a) else - if nexp_identical n1b n2b - then unify_nexp l env goals n1a n2a - else unify_error l "Unification error" + if KidSet.is_empty (nexp_frees n2a) + then unify_nexp l env goals n2a (nminus nexp1 n2b) + else merge_uvars l (unify_nexp l env goals n1a n2a) (unify_nexp l env goals n1b n2b) | _ -> unify_error l ("Both sides of Int expression " ^ string_of_nexp nexp1 ^ " contain free type variables so it cannot be unified with " ^ string_of_nexp nexp2) end @@ -1714,13 +1716,23 @@ let rec ambiguous_vars (Typ_aux (aux, _)) = and ambiguous_arg_vars (A_aux (aux, _)) = match aux with | A_bool nc -> ambiguous_nc_vars nc + | A_nexp nexp -> ambiguous_nexp_vars nexp | _ -> KidSet.empty and ambiguous_nc_vars (NC_aux (aux, _)) = match aux with | NC_and (nc1, nc2) -> KidSet.union (tyvars_of_constraint nc1) (tyvars_of_constraint nc2) + | NC_bounded_le (n1, n2) -> KidSet.union (tyvars_of_nexp n1) (tyvars_of_nexp n2) + | NC_bounded_ge (n1, n2) -> KidSet.union (tyvars_of_nexp n1) (tyvars_of_nexp n2) + | NC_equal (n1, n2) | NC_not_equal (n1, n2) -> + KidSet.union (ambiguous_nexp_vars n1) (ambiguous_nexp_vars n2) | _ -> KidSet.empty - + +and ambiguous_nexp_vars (Nexp_aux (aux, _)) = + match aux with + | Nexp_sum (nexp1, nexp2) -> KidSet.union (tyvars_of_nexp nexp1) (tyvars_of_nexp nexp2) + | _ -> KidSet.empty + (**************************************************************************) (* 3.5. Subtyping with existentials *) (**************************************************************************) @@ -2831,7 +2843,7 @@ and type_coercion_unify env goals (E_aux (_, (l, _)) as annotated_exp) typ = try let inferred_cast = irule infer_exp (Env.no_casts env) (strip (E_app (cast, [annotated_exp]))) in let ityp, env = bind_existential l None (typ_of inferred_cast) env in - inferred_cast, unify l env goals typ ityp, env + inferred_cast, unify l env (KidSet.diff goals (ambiguous_vars typ)) typ ityp, env with | Type_error (_, _, err) -> try_casts casts | Unification_error (_, err) -> try_casts casts @@ -2841,7 +2853,7 @@ and type_coercion_unify env goals (E_aux (_, (l, _)) as annotated_exp) typ = try typ_debug (lazy ("Coercing unification: from " ^ string_of_typ (typ_of annotated_exp) ^ " to " ^ string_of_typ typ)); let atyp, env = bind_existential l None (typ_of annotated_exp) env in - annotated_exp, unify l env goals typ atyp, env + annotated_exp, unify l env (KidSet.diff goals (ambiguous_vars typ)) typ atyp, env with | Unification_error (_, m) when Env.allow_casts env -> let casts = filter_casts env (typ_of annotated_exp) typ (Env.get_casts env) in @@ -3662,15 +3674,21 @@ and infer_funapp' l env f (typq, f_typ) xs expected_ret_typ = typ_debug (lazy ("Quantifiers " ^ Util.string_of_list ", " string_of_quant_item !quants)); - let implicits, typ_args = - if not (List.length typ_args = List.length xs) then - let typ_args' = List.filter is_not_implicit typ_args in - if not (List.length typ_args' = List.length xs) then - typ_error env l (Printf.sprintf "Function %s applied to %d args, expected %d" (string_of_id f) (List.length xs) (List.length typ_args)) - else - get_implicits typ_args, typ_args' - else - [], List.map implicit_to_int typ_args + let implicits, typ_args, xs = + let typ_args' = List.filter is_not_implicit typ_args in + match xs, typ_args' with + (* Support the case where a function has only implicit arguments; + allow it to be called either as f() or f(i...) *) + | [E_aux (E_lit (L_aux (L_unit, _)), _)], [] -> + get_implicits typ_args, [], [] + | _ -> + if not (List.length typ_args = List.length xs) then + if not (List.length typ_args' = List.length xs) then + typ_error env l (Printf.sprintf "Function %s applied to %d args, expected %d (%d explicit): %s" (string_of_id f) (List.length xs) (List.length typ_args) (List.length typ_args') (String.concat ", " (List.map string_of_typ typ_args))) + else + get_implicits typ_args, typ_args', xs + else + [], List.map implicit_to_int typ_args, xs in let instantiate_quant (v, arg) (QI_aux (aux, l) as qi) = @@ -3734,7 +3752,7 @@ and infer_funapp' l env f (typq, f_typ) xs expected_ret_typ = let solve_implicit impl = match KBindings.find_opt impl !all_unifiers with | Some (A_aux (A_nexp (Nexp_aux (Nexp_constant c, _)), _)) -> irule infer_exp env (mk_lit_exp (L_num c)) | Some (A_aux (A_nexp n, _)) -> irule infer_exp env (mk_exp (E_sizeof n)) - | _ -> typ_error env l "bad" + | _ -> typ_error env l ("Cannot solve implicit " ^ string_of_kid impl ^ " in " ^ string_of_exp (mk_exp (E_app (f, List.map strip_exp xs)))) in let xs = List.map solve_implicit implicits @ xs in @@ -4448,10 +4466,10 @@ let check_funcl env (FCL_aux (FCL_Funcl (id, pexp), (l, _))) typ = function arguments as like a tuple, and maybe we shouldn't. *) let typed_pexp, prop_eff = - match typ_args with + match List.map implicit_to_int typ_args with | [typ_arg] -> propagate_pexp_effect (check_case env typ_arg (strip_pexp pexp) typ_ret) - | _ -> + | typ_args -> propagate_pexp_effect (check_case env (Typ_aux (Typ_tup typ_args, l)) (strip_pexp pexp) typ_ret) in FCL_aux (FCL_Funcl (id, typed_pexp), (l, mk_expected_tannot env typ prop_eff (Some typ))) |
