diff options
Diffstat (limited to 'src')
| -rw-r--r-- | src/type_check.ml | 60 | ||||
| -rw-r--r-- | src/value.ml | 55 |
2 files changed, 54 insertions, 61 deletions
diff --git a/src/type_check.ml b/src/type_check.ml index 7af0ecff..69e1e2c9 100644 --- a/src/type_check.ml +++ b/src/type_check.ml @@ -3145,40 +3145,34 @@ and bind_pat env (P_aux (pat_aux, (l, ())) as pat) (Typ_aux (typ_aux, _) as typ) typ_error env l (Printf.sprintf "Cannot bind tuple pattern %s against non tuple type %s" (string_of_pat pat) (string_of_typ typ)) end - | P_app (f, pats) when Env.is_union_constructor f env -> - begin - (* Treat Ctor((p, x)) the same as Ctor(p, x) *) - let pats = match pats with [P_aux (P_tup pats, _)] -> pats | _ -> pats in - let (typq, ctor_typ) = Env.get_union_id f env in - let quants = quant_items typq in - let untuple (Typ_aux (typ_aux, _) as typ) = match typ_aux with - | Typ_tup typs -> typs - | _ -> [typ] - in - match Env.expand_synonyms env ctor_typ with - | Typ_aux (Typ_fn ([arg_typ], ret_typ, _), _) -> - begin - try - let goals = quant_kopts typq |> List.map kopt_kid |> KidSet.of_list in - typ_debug (lazy ("Unifying " ^ string_of_bind (typq, ctor_typ) ^ " for pattern " ^ string_of_typ typ)); - let unifiers = unify l env goals ret_typ typ in - let arg_typ' = subst_unifiers unifiers arg_typ in - let quants' = List.fold_left instantiate_quants quants (KBindings.bindings unifiers) in - if not (List.for_all (solve_quant env) quants') then - typ_raise env l (Err_unresolved_quants (f, quants', Env.get_locals env, Env.get_constraints env)) - else (); - let ret_typ' = subst_unifiers unifiers ret_typ in - let arg_typ', env = bind_existential l None arg_typ' env in - let tpats, env, guards = - try List.fold_left2 bind_tuple_pat ([], env, []) pats (untuple arg_typ') with - | Invalid_argument _ -> typ_error env l "Union constructor pattern arguments have incorrect length" - in - annot_pat (P_app (f, List.rev tpats)) typ, env, guards - with - | Unification_error (l, m) -> typ_error env l ("Unification error when pattern matching against union constructor: " ^ m) - end - | _ -> typ_error env l ("Mal-formed constructor " ^ string_of_id f ^ " with type " ^ string_of_typ ctor_typ) + | P_app (f, [pat]) when Env.is_union_constructor f env -> + let (typq, ctor_typ) = Env.get_union_id f env in + let quants = quant_items typq in + begin match Env.expand_synonyms env ctor_typ with + | Typ_aux (Typ_fn ([arg_typ], ret_typ, _), _) -> + begin + try + let goals = quant_kopts typq |> List.map kopt_kid |> KidSet.of_list in + typ_debug (lazy ("Unifying " ^ string_of_bind (typq, ctor_typ) ^ " for pattern " ^ string_of_typ typ)); + let unifiers = unify l env goals ret_typ typ in + let arg_typ' = subst_unifiers unifiers arg_typ in + let quants' = List.fold_left instantiate_quants quants (KBindings.bindings unifiers) in + if not (List.for_all (solve_quant env) quants') then + typ_raise env l (Err_unresolved_quants (f, quants', Env.get_locals env, Env.get_constraints env)) + else (); + let ret_typ' = subst_unifiers unifiers ret_typ in + let arg_typ', env = bind_existential l None arg_typ' env in + let tpat, env, guards = bind_pat env pat arg_typ' in + annot_pat (P_app (f, [tpat])) typ, env, guards + with + | Unification_error (l, m) -> typ_error env l ("Unification error when pattern matching against union constructor: " ^ m) + end + | _ -> typ_error env l ("Mal-formed constructor " ^ string_of_id f ^ " with type " ^ string_of_typ ctor_typ) end + + | P_app (f, pats) when Env.is_union_constructor f env -> + (* Treat Ctor(x, y) as Ctor((x, y)) *) + bind_pat env (mk_pat (P_app (f, [mk_pat (P_tup pats)]))) typ | P_app (f, pats) when Env.is_mapping f env -> begin diff --git a/src/value.ml b/src/value.ml index 8f8e651a..d1b945a7 100644 --- a/src/value.ml +++ b/src/value.ml @@ -93,6 +93,33 @@ type value = with a direct register read. *) | V_attempted_read of string +let coerce_bit = function + | V_bit b -> b + | _ -> assert false + +let is_bit = function + | V_bit _ -> true + | _ -> false + +let rec string_of_value = function + | V_vector vs when List.for_all is_bit vs -> Sail_lib.string_of_bits (List.map coerce_bit vs) + | V_vector vs -> "[" ^ Util.string_of_list ", " string_of_value vs ^ "]" + | V_bool true -> "true" + | V_bool false -> "false" + | V_bit Sail_lib.B0 -> "bitzero" + | V_bit Sail_lib.B1 -> "bitone" + | V_int n -> Big_int.to_string n + | V_tuple vals -> "(" ^ Util.string_of_list ", " string_of_value vals ^ ")" + | V_list vals -> "[|" ^ Util.string_of_list ", " string_of_value vals ^ "|]" + | V_unit -> "()" + | V_string str -> "\"" ^ str ^ "\"" + | V_ref str -> "ref " ^ str + | V_real r -> Sail_lib.string_of_real r + | V_ctor (str, vals) -> str ^ "(" ^ Util.string_of_list ", " string_of_value vals ^ ")" + | V_record record -> + "{" ^ Util.string_of_list ", " (fun (field, v) -> field ^ "=" ^ string_of_value v) (StringMap.bindings record) ^ "}" + | V_attempted_read _ -> assert false + let rec eq_value v1 v2 = match v1, v2 with | V_vector v1s, V_vector v2s when List.length v1s = List.length v2s -> List.for_all2 eq_value v1s v2s @@ -111,12 +138,7 @@ let rec eq_value v1 v2 = StringMap.equal eq_value fields1 fields2 | _, _ -> false -let coerce_bit = function - | V_bit b -> b - | _ -> assert false - let coerce_ctor = function - | V_ctor (str, [V_tuple vals]) -> (str, vals) | V_ctor (str, vals) -> (str, vals) | _ -> assert false @@ -371,33 +393,10 @@ let value_replicate_bits = function | [v1; v2] -> mk_vector (Sail_lib.replicate_bits (coerce_bv v1, coerce_int v2)) | _ -> failwith "value replicate_bits" -let is_bit = function - | V_bit _ -> true - | _ -> false - let is_ctor = function | V_ctor _ -> true | _ -> false -let rec string_of_value = function - | V_vector vs when List.for_all is_bit vs -> Sail_lib.string_of_bits (List.map coerce_bit vs) - | V_vector vs -> "[" ^ Util.string_of_list ", " string_of_value vs ^ "]" - | V_bool true -> "true" - | V_bool false -> "false" - | V_bit Sail_lib.B0 -> "bitzero" - | V_bit Sail_lib.B1 -> "bitone" - | V_int n -> Big_int.to_string n - | V_tuple vals -> "(" ^ Util.string_of_list ", " string_of_value vals ^ ")" - | V_list vals -> "[|" ^ Util.string_of_list ", " string_of_value vals ^ "|]" - | V_unit -> "()" - | V_string str -> "\"" ^ str ^ "\"" - | V_ref str -> "ref " ^ str - | V_real r -> Sail_lib.string_of_real r - | V_ctor (str, vals) -> str ^ "(" ^ Util.string_of_list ", " string_of_value vals ^ ")" - | V_record record -> - "{" ^ Util.string_of_list ", " (fun (field, v) -> field ^ "=" ^ string_of_value v) (StringMap.bindings record) ^ "}" - | V_attempted_read _ -> assert false - let value_sign_extend = function | [v1; v2] -> mk_vector (Sail_lib.sign_extend (coerce_bv v1, coerce_int v2)) | _ -> failwith "value sign_extend" |
