summaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
Diffstat (limited to 'src')
-rw-r--r--src/type_check.ml60
-rw-r--r--src/value.ml55
2 files changed, 54 insertions, 61 deletions
diff --git a/src/type_check.ml b/src/type_check.ml
index 7af0ecff..69e1e2c9 100644
--- a/src/type_check.ml
+++ b/src/type_check.ml
@@ -3145,40 +3145,34 @@ and bind_pat env (P_aux (pat_aux, (l, ())) as pat) (Typ_aux (typ_aux, _) as typ)
typ_error env l (Printf.sprintf "Cannot bind tuple pattern %s against non tuple type %s"
(string_of_pat pat) (string_of_typ typ))
end
- | P_app (f, pats) when Env.is_union_constructor f env ->
- begin
- (* Treat Ctor((p, x)) the same as Ctor(p, x) *)
- let pats = match pats with [P_aux (P_tup pats, _)] -> pats | _ -> pats in
- let (typq, ctor_typ) = Env.get_union_id f env in
- let quants = quant_items typq in
- let untuple (Typ_aux (typ_aux, _) as typ) = match typ_aux with
- | Typ_tup typs -> typs
- | _ -> [typ]
- in
- match Env.expand_synonyms env ctor_typ with
- | Typ_aux (Typ_fn ([arg_typ], ret_typ, _), _) ->
- begin
- try
- let goals = quant_kopts typq |> List.map kopt_kid |> KidSet.of_list in
- typ_debug (lazy ("Unifying " ^ string_of_bind (typq, ctor_typ) ^ " for pattern " ^ string_of_typ typ));
- let unifiers = unify l env goals ret_typ typ in
- let arg_typ' = subst_unifiers unifiers arg_typ in
- let quants' = List.fold_left instantiate_quants quants (KBindings.bindings unifiers) in
- if not (List.for_all (solve_quant env) quants') then
- typ_raise env l (Err_unresolved_quants (f, quants', Env.get_locals env, Env.get_constraints env))
- else ();
- let ret_typ' = subst_unifiers unifiers ret_typ in
- let arg_typ', env = bind_existential l None arg_typ' env in
- let tpats, env, guards =
- try List.fold_left2 bind_tuple_pat ([], env, []) pats (untuple arg_typ') with
- | Invalid_argument _ -> typ_error env l "Union constructor pattern arguments have incorrect length"
- in
- annot_pat (P_app (f, List.rev tpats)) typ, env, guards
- with
- | Unification_error (l, m) -> typ_error env l ("Unification error when pattern matching against union constructor: " ^ m)
- end
- | _ -> typ_error env l ("Mal-formed constructor " ^ string_of_id f ^ " with type " ^ string_of_typ ctor_typ)
+ | P_app (f, [pat]) when Env.is_union_constructor f env ->
+ let (typq, ctor_typ) = Env.get_union_id f env in
+ let quants = quant_items typq in
+ begin match Env.expand_synonyms env ctor_typ with
+ | Typ_aux (Typ_fn ([arg_typ], ret_typ, _), _) ->
+ begin
+ try
+ let goals = quant_kopts typq |> List.map kopt_kid |> KidSet.of_list in
+ typ_debug (lazy ("Unifying " ^ string_of_bind (typq, ctor_typ) ^ " for pattern " ^ string_of_typ typ));
+ let unifiers = unify l env goals ret_typ typ in
+ let arg_typ' = subst_unifiers unifiers arg_typ in
+ let quants' = List.fold_left instantiate_quants quants (KBindings.bindings unifiers) in
+ if not (List.for_all (solve_quant env) quants') then
+ typ_raise env l (Err_unresolved_quants (f, quants', Env.get_locals env, Env.get_constraints env))
+ else ();
+ let ret_typ' = subst_unifiers unifiers ret_typ in
+ let arg_typ', env = bind_existential l None arg_typ' env in
+ let tpat, env, guards = bind_pat env pat arg_typ' in
+ annot_pat (P_app (f, [tpat])) typ, env, guards
+ with
+ | Unification_error (l, m) -> typ_error env l ("Unification error when pattern matching against union constructor: " ^ m)
+ end
+ | _ -> typ_error env l ("Mal-formed constructor " ^ string_of_id f ^ " with type " ^ string_of_typ ctor_typ)
end
+
+ | P_app (f, pats) when Env.is_union_constructor f env ->
+ (* Treat Ctor(x, y) as Ctor((x, y)) *)
+ bind_pat env (mk_pat (P_app (f, [mk_pat (P_tup pats)]))) typ
| P_app (f, pats) when Env.is_mapping f env ->
begin
diff --git a/src/value.ml b/src/value.ml
index 8f8e651a..d1b945a7 100644
--- a/src/value.ml
+++ b/src/value.ml
@@ -93,6 +93,33 @@ type value =
with a direct register read. *)
| V_attempted_read of string
+let coerce_bit = function
+ | V_bit b -> b
+ | _ -> assert false
+
+let is_bit = function
+ | V_bit _ -> true
+ | _ -> false
+
+let rec string_of_value = function
+ | V_vector vs when List.for_all is_bit vs -> Sail_lib.string_of_bits (List.map coerce_bit vs)
+ | V_vector vs -> "[" ^ Util.string_of_list ", " string_of_value vs ^ "]"
+ | V_bool true -> "true"
+ | V_bool false -> "false"
+ | V_bit Sail_lib.B0 -> "bitzero"
+ | V_bit Sail_lib.B1 -> "bitone"
+ | V_int n -> Big_int.to_string n
+ | V_tuple vals -> "(" ^ Util.string_of_list ", " string_of_value vals ^ ")"
+ | V_list vals -> "[|" ^ Util.string_of_list ", " string_of_value vals ^ "|]"
+ | V_unit -> "()"
+ | V_string str -> "\"" ^ str ^ "\""
+ | V_ref str -> "ref " ^ str
+ | V_real r -> Sail_lib.string_of_real r
+ | V_ctor (str, vals) -> str ^ "(" ^ Util.string_of_list ", " string_of_value vals ^ ")"
+ | V_record record ->
+ "{" ^ Util.string_of_list ", " (fun (field, v) -> field ^ "=" ^ string_of_value v) (StringMap.bindings record) ^ "}"
+ | V_attempted_read _ -> assert false
+
let rec eq_value v1 v2 =
match v1, v2 with
| V_vector v1s, V_vector v2s when List.length v1s = List.length v2s -> List.for_all2 eq_value v1s v2s
@@ -111,12 +138,7 @@ let rec eq_value v1 v2 =
StringMap.equal eq_value fields1 fields2
| _, _ -> false
-let coerce_bit = function
- | V_bit b -> b
- | _ -> assert false
-
let coerce_ctor = function
- | V_ctor (str, [V_tuple vals]) -> (str, vals)
| V_ctor (str, vals) -> (str, vals)
| _ -> assert false
@@ -371,33 +393,10 @@ let value_replicate_bits = function
| [v1; v2] -> mk_vector (Sail_lib.replicate_bits (coerce_bv v1, coerce_int v2))
| _ -> failwith "value replicate_bits"
-let is_bit = function
- | V_bit _ -> true
- | _ -> false
-
let is_ctor = function
| V_ctor _ -> true
| _ -> false
-let rec string_of_value = function
- | V_vector vs when List.for_all is_bit vs -> Sail_lib.string_of_bits (List.map coerce_bit vs)
- | V_vector vs -> "[" ^ Util.string_of_list ", " string_of_value vs ^ "]"
- | V_bool true -> "true"
- | V_bool false -> "false"
- | V_bit Sail_lib.B0 -> "bitzero"
- | V_bit Sail_lib.B1 -> "bitone"
- | V_int n -> Big_int.to_string n
- | V_tuple vals -> "(" ^ Util.string_of_list ", " string_of_value vals ^ ")"
- | V_list vals -> "[|" ^ Util.string_of_list ", " string_of_value vals ^ "|]"
- | V_unit -> "()"
- | V_string str -> "\"" ^ str ^ "\""
- | V_ref str -> "ref " ^ str
- | V_real r -> Sail_lib.string_of_real r
- | V_ctor (str, vals) -> str ^ "(" ^ Util.string_of_list ", " string_of_value vals ^ ")"
- | V_record record ->
- "{" ^ Util.string_of_list ", " (fun (field, v) -> field ^ "=" ^ string_of_value v) (StringMap.bindings record) ^ "}"
- | V_attempted_read _ -> assert false
-
let value_sign_extend = function
| [v1; v2] -> mk_vector (Sail_lib.sign_extend (coerce_bv v1, coerce_int v2))
| _ -> failwith "value sign_extend"