summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--src/rewrites.ml128
-rw-r--r--src/type_check.ml28
-rw-r--r--test/typecheck/pass/execute_decode_hard.sail26
-rw-r--r--test/typecheck/pass/fpthreesimp.sail (renamed from test/typecheck/fpthreesimp.sail)6
4 files changed, 111 insertions, 77 deletions
diff --git a/src/rewrites.ml b/src/rewrites.ml
index d8cb5a5d..bc9792ef 100644
--- a/src/rewrites.ml
+++ b/src/rewrites.ml
@@ -2043,60 +2043,74 @@ let rewrite_split_fun_constr_pats fun_name env (Defs defs) =
let clauses, aux_funs =
List.fold_left
(fun (clauses, aux_funs) (FCL_aux (FCL_Funcl (id, pexp), fannot) as clause) ->
- let pat, guard, exp, annot = destruct_pexp pexp in
- match pat with
- | P_aux (P_app (constr_id, args), pannot) ->
- let argstup_typ = tuple_typ (List.map typ_of_pat args) in
- let pannot' = swaptyp argstup_typ pannot in
- let pat' =
- match args with
- | [arg] -> arg
- | _ -> P_aux (P_tup args, pannot')
- in
- let pexp' = construct_pexp (pat', guard, exp, annot) in
- let aux_fun_id = prepend_id (fun_name ^ "_") constr_id in
- let aux_funcl = FCL_aux (FCL_Funcl (aux_fun_id, pexp'), pannot') in
- begin
- try
- let aux_clauses = Bindings.find aux_fun_id aux_funs in
- clauses,
- Bindings.add aux_fun_id (aux_clauses @ [aux_funcl]) aux_funs
- with Not_found ->
- let argpats, argexps = List.split (List.mapi
- (fun idx (P_aux (_,a) as pat) ->
- let id = match pat_var pat with
- | Some id -> id
- | None -> mk_id ("arg" ^ string_of_int idx)
- in
- P_aux (P_id id, a), E_aux (E_id id, a))
- args)
- in
- let pexp = construct_pexp
- (P_aux (P_app (constr_id, argpats), pannot),
- None,
- E_aux (E_app (aux_fun_id, argexps), annot),
- annot)
- in
- clauses @ [FCL_aux (FCL_Funcl (id, pexp), fannot)],
- Bindings.add aux_fun_id [aux_funcl] aux_funs
- end
- | _ -> clauses @ [clause], aux_funs)
+ let pat, guard, exp, annot = destruct_pexp pexp in
+ match pat with
+ | P_aux (P_app (constr_id, args), pannot) ->
+ let ctor_typq, ctor_typ = Env.get_union_id constr_id env in
+ let args = match args with [P_aux (P_tup args, _)] -> args | _ -> args in
+ let argstup_typ = tuple_typ (List.map typ_of_pat args) in
+ let pannot' = swaptyp argstup_typ pannot in
+ let pat' =
+ match args with
+ | [arg] -> arg
+ | _ -> P_aux (P_tup args, pannot')
+ in
+ let pexp' = construct_pexp (pat', guard, exp, annot) in
+ let aux_fun_id = prepend_id (fun_name ^ "_") constr_id in
+ let aux_funcl = FCL_aux (FCL_Funcl (aux_fun_id, pexp'), pannot') in
+ begin
+ try
+ let aux_clauses = Bindings.find aux_fun_id aux_funs in
+ clauses,
+ Bindings.add aux_fun_id (aux_clauses @ [(aux_funcl, ctor_typq, ctor_typ)]) aux_funs
+ with Not_found ->
+ let argpats, argexps = List.split (List.mapi
+ (fun idx (P_aux (_,a) as pat) ->
+ let id = match pat_var pat with
+ | Some id -> id
+ | None -> mk_id ("arg" ^ string_of_int idx)
+ in
+ P_aux (P_id id, a), E_aux (E_id id, a))
+ args)
+ in
+ let pexp = construct_pexp
+ (P_aux (P_app (constr_id, argpats), pannot),
+ None,
+ E_aux (E_app (aux_fun_id, argexps), annot),
+ annot)
+ in
+ clauses @ [FCL_aux (FCL_Funcl (id, pexp), fannot)],
+ Bindings.add aux_fun_id [(aux_funcl, ctor_typq, ctor_typ)] aux_funs
+ end
+ | _ -> clauses @ [clause], aux_funs)
([], Bindings.empty) clauses
in
- let add_aux_def id funcls defs =
- let env, args_typ, ret_typ = match funcls with
- | FCL_aux (FCL_Funcl (_, pexp), _) :: _ ->
+ let add_aux_def id aux_funs defs =
+ let funcls = List.map (fun (fcl, _, _) -> fcl) aux_funs in
+ let env, quants, args_typ, ret_typ = match aux_funs with
+ | (FCL_aux (FCL_Funcl (_, pexp), _), ctor_typq, ctor_typ) :: _ ->
let pat, _, exp, _ = destruct_pexp pexp in
- env_of exp, typ_of_pat pat, typ_of exp
+ let ctor_quants args_typ =
+ List.filter (fun qi -> KOptSet.subset (kopts_of_quant_item qi) (kopts_of_typ args_typ))
+ (quant_items ctor_typq)
+ in
+ begin match ctor_typ with
+ | Typ_aux (Typ_fn ([Typ_aux (Typ_exist (kopts, nc, args_typ), _)], _, _), _) ->
+ env_of exp, ctor_quants args_typ @ List.map mk_qi_kopt kopts @ [mk_qi_nc nc], args_typ, typ_of exp
+ | Typ_aux (Typ_fn ([args_typ], _, _), _) -> env_of exp, ctor_quants args_typ, args_typ, typ_of exp
+ | _ ->
+ raise (Reporting.err_unreachable l __POS__
+ ("Union constructor has non-function type: " ^ string_of_typ ctor_typ))
+ end
| _ ->
raise (Reporting.err_unreachable l __POS__
- "rewrite_split_fun_constr_pats: empty auxiliary function")
+ "rewrite_split_fun_constr_pats: empty auxiliary function")
in
let eff = List.fold_left
- (fun eff (FCL_aux (FCL_Funcl (_, pexp), _)) ->
- let _, _, exp, _ = destruct_pexp pexp in
- union_effects eff (effect_of exp))
- no_effect funcls
+ (fun eff (FCL_aux (FCL_Funcl (_, pexp), _)) ->
+ let _, _, exp, _ = destruct_pexp pexp in
+ union_effects eff (effect_of exp))
+ no_effect funcls
in
let fun_typ =
(* Because we got the argument type from a pattern we need to
@@ -2107,27 +2121,9 @@ let rewrite_split_fun_constr_pats fun_name env (Defs defs) =
| _ ->
function_typ [args_typ] ret_typ eff
in
- let quant_new_kopts qis =
- let quant_kopts = List.fold_left KOptSet.union KOptSet.empty (List.map kopts_of_quant_item qis) in
- let typ_kopts = kopts_of_typ fun_typ in
- let new_kopts = KOptSet.diff typ_kopts quant_kopts in
- List.map mk_qi_kopt (KOptSet.elements new_kopts)
- in
- let typquant = match typquant with
- | TypQ_aux (TypQ_tq qis, l) ->
- let qis =
- List.filter
- (fun qi -> KOptSet.subset (kopts_of_quant_item qi) (kopts_of_typ fun_typ))
- qis
- @ quant_new_kopts qis
- in
- TypQ_aux (TypQ_tq qis, l)
- | _ ->
- TypQ_aux (TypQ_tq (List.map mk_qi_kopt (KOptSet.elements (kopts_of_typ fun_typ))), l)
- in
let val_spec =
VS_aux (VS_val_spec
- (mk_typschm typquant fun_typ, id, (fun _ -> None), false),
+ (mk_typschm (mk_typquant quants) fun_typ, id, (fun _ -> None), false),
(Parse_ast.Unknown, empty_tannot))
in
let fundef = FD_aux (FD_function (r_o, t_o, e_o, funcls), fdannot) in
diff --git a/src/type_check.ml b/src/type_check.ml
index c1689a82..603052b5 100644
--- a/src/type_check.ml
+++ b/src/type_check.ml
@@ -432,6 +432,7 @@ module Env : sig
val get_typ_var_loc : kid -> t -> Ast.l
val get_typ_vars : t -> kind_aux KBindings.t
val get_typ_var_locs : t -> Ast.l KBindings.t
+ val add_typ_var_shadow : l -> kinded_id -> t -> t * kid option
val add_typ_var : l -> kinded_id -> t -> t
val get_ret_typ : t -> typ option
val add_ret_typ : typ -> t -> t
@@ -656,10 +657,9 @@ end = struct
^ " with " ^ Util.string_of_list ", " string_of_n_constraint env.constraints)
let get_typ_synonym id env =
- begin match Bindings.find_opt id env.typ_synonyms with
+ match Bindings.find_opt id env.typ_synonyms with
| Some (typq, arg) -> mk_synonym typq arg
| None -> raise Not_found
- end
let rec expand_constraint_synonyms env (NC_aux (aux, l) as nc) =
typ_debug ~level:2 (lazy ("Expanding " ^ string_of_n_constraint nc));
@@ -1208,7 +1208,7 @@ end = struct
with
| Not_found -> Unbound
- let add_typ_var l (KOpt_aux (KOpt_kind (K_aux (k, _), v), _)) env =
+ let add_typ_var_shadow l (KOpt_aux (KOpt_kind (K_aux (k, _), v), _)) env =
if KBindings.mem v env.typ_vars then begin
let n = match KBindings.find_opt v env.shadow_vars with Some n -> n | None -> 0 in
let s_l, s_k = KBindings.find v env.typ_vars in
@@ -1218,13 +1218,15 @@ end = struct
constraints = List.map (constraint_subst v (arg_kopt (mk_kopt s_k s_v))) env.constraints;
typ_vars = KBindings.add v (l, k) (KBindings.add s_v (s_l, s_k) env.typ_vars);
shadow_vars = KBindings.add v (n + 1) env.shadow_vars
- }
+ }, Some s_v
end
else begin
typ_print (lazy (adding ^ "type variable " ^ string_of_kid v ^ " : " ^ string_of_kind_aux k));
- { env with typ_vars = KBindings.add v (l, k) env.typ_vars }
+ { env with typ_vars = KBindings.add v (l, k) env.typ_vars }, None
end
+ let add_typ_var l kopt env = fst (add_typ_var_shadow l kopt env)
+
let get_constraints env = env.constraints
let add_constraint constr env =
@@ -3133,6 +3135,8 @@ and bind_pat env (P_aux (pat_aux, (l, ())) as pat) (Typ_aux (typ_aux, _) as typ)
end
| P_app (f, pats) when Env.is_union_constructor f env ->
begin
+ (* Treat Ctor((p, x)) the same as Ctor(p, x) *)
+ let pats = match pats with [P_aux (P_tup pats, _)] -> pats | _ -> pats in
let (typq, ctor_typ) = Env.get_union_id f env in
let quants = quant_items typq in
let untuple (Typ_aux (typ_aux, _) as typ) = match typ_aux with
@@ -3152,6 +3156,7 @@ and bind_pat env (P_aux (pat_aux, (l, ())) as pat) (Typ_aux (typ_aux, _) as typ)
typ_raise env l (Err_unresolved_quants (f, quants', Env.get_locals env, Env.get_constraints env))
else ();
let ret_typ' = subst_unifiers unifiers ret_typ in
+ let arg_typ', env = bind_existential l None arg_typ' env in
let tpats, env, guards =
try List.fold_left2 bind_tuple_pat ([], env, []) pats (untuple arg_typ') with
| Invalid_argument _ -> typ_error env l "Union constructor pattern arguments have incorrect length"
@@ -3325,15 +3330,20 @@ and infer_pat env (P_aux (pat_aux, (l, ())) as pat) =
| _ -> typ_error env l ("Couldn't infer type of pattern " ^ string_of_pat pat)
and bind_typ_pat env (TP_aux (typ_pat_aux, l) as typ_pat) (Typ_aux (typ_aux, _) as typ) =
+ typ_print (lazy (Util.("Binding type pattern " |> yellow |> clear) ^ string_of_typ_pat typ_pat ^ " to " ^ string_of_typ typ));
match typ_pat_aux, typ_aux with
| TP_wild, _ -> env
| TP_var kid, _ ->
begin
match typ_nexps typ, typ_constraints typ with
| [nexp], [] ->
- Env.add_constraint (nc_eq (nvar kid) nexp) (Env.add_typ_var l (mk_kopt K_int kid) env)
+ let env, shadow = Env.add_typ_var_shadow l (mk_kopt K_int kid) env in
+ let nexp = match shadow with Some s_v -> nexp_subst kid (arg_nexp (nvar s_v)) nexp | None -> nexp in
+ Env.add_constraint (nc_eq (nvar kid) nexp) env
| [], [nc] ->
- Env.add_constraint (nc_and (nc_or (nc_not nc) (nc_var kid)) (nc_or nc (nc_not (nc_var kid)))) (Env.add_typ_var l (mk_kopt K_bool kid) env)
+ let env, shadow = Env.add_typ_var_shadow l (mk_kopt K_bool kid) env in
+ let nexp = match shadow with Some s_v -> constraint_subst kid (arg_bool (nc_var s_v)) nc | None -> nc in
+ Env.add_constraint (nc_and (nc_or (nc_not nc) (nc_var kid)) (nc_or nc (nc_not (nc_var kid)))) env
| [], [] ->
typ_error env l ("No numeric expressions in " ^ string_of_typ typ ^ " to bind " ^ string_of_kid kid ^ " to")
| _, _ ->
@@ -3346,7 +3356,9 @@ and bind_typ_pat_arg env (TP_aux (typ_pat_aux, l) as typ_pat) (A_aux (typ_arg_au
match typ_pat_aux, typ_arg_aux with
| TP_wild, _ -> env
| TP_var kid, A_nexp nexp ->
- Env.add_constraint (nc_eq (nvar kid) nexp) (Env.add_typ_var l (mk_kopt K_int kid) env)
+ let env, shadow = Env.add_typ_var_shadow l (mk_kopt K_int kid) env in
+ let nexp = match shadow with Some s_v -> nexp_subst kid (arg_nexp (nvar s_v)) nexp | None -> nexp in
+ Env.add_constraint (nc_eq (nvar kid) nexp) env
| _, A_typ typ -> bind_typ_pat env typ_pat typ
| _, A_order _ -> typ_error env l "Cannot bind type pattern against order"
| _, _ -> typ_error env l ("Couldn't bind type argument " ^ string_of_typ_arg typ_arg ^ " with " ^ string_of_typ_pat typ_pat)
diff --git a/test/typecheck/pass/execute_decode_hard.sail b/test/typecheck/pass/execute_decode_hard.sail
new file mode 100644
index 00000000..d5e91b79
--- /dev/null
+++ b/test/typecheck/pass/execute_decode_hard.sail
@@ -0,0 +1,26 @@
+default Order dec
+
+$include <prelude.sail>
+
+union ast('D: Int), 'D in {32, 64, 128} = {
+ Instr1 : {'R, 'R in {32, 64}. (int('R), bits('D))}
+}
+
+val execute : forall 'd, 'd in {32, 64, 128}. ast('d) -> unit
+
+function clause execute(Instr1(r as int('R), d)) = {
+ _prove(constraint('R in {32, 64}));
+ if length(d) == 64 then {
+ let _ = d[r - 1 .. 0];
+ ()
+ }
+}
+
+function clause execute(Instr1((r as int('R), d))) = {
+ _prove(constraint('R in {32, 64}));
+ if length(d) == 64 then {
+ let _ = d[r - 1 .. 0];
+ ()
+ }
+}
+
diff --git a/test/typecheck/fpthreesimp.sail b/test/typecheck/pass/fpthreesimp.sail
index 3f759ba4..d0f44119 100644
--- a/test/typecheck/fpthreesimp.sail
+++ b/test/typecheck/pass/fpthreesimp.sail
@@ -4,11 +4,11 @@ $include <prelude.sail>
val Zeros : forall 'N, 'N >= 0. int('N) -> bits('N)
-type FPExponent ('N : Int) = {'E, ('N = 16 & 'E = 5) | ('N = 32 & 'E = 8) | ('N = 64 & 'E = 11). int('E)}
+type FPExponent ('N : Int) = {'E, ('N == 16 & 'E == 5) | ('N == 32 & 'E == 8) | ('N == 64 & 'E == 11). int('E)}
-val FPThree : forall 'N, 'N in {16, 32, 64}. bits(1) -> bits('N)
+val FPThree : forall 'N, 'N in {16, 32, 64}. (implicit('N), bits(1)) -> bits('N)
-function FPThree(sign) = {
+function FPThree(N, sign) = {
let E : FPExponent('N) = if 'N == 16 then 5 else if 'N == 32 then 8 else 11;
sign @ 0b1 @ Zeros(E - 1) @ 0b1 @ Zeros('N - E - 2)
} \ No newline at end of file