summaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
Diffstat (limited to 'src')
-rw-r--r--src/rewrites.ml57
-rw-r--r--src/sail_lib.ml2
-rw-r--r--src/type_check.ml51
3 files changed, 95 insertions, 15 deletions
diff --git a/src/rewrites.ml b/src/rewrites.ml
index b3e60423..9067b22c 100644
--- a/src/rewrites.ml
+++ b/src/rewrites.ml
@@ -2818,14 +2818,18 @@ let rec rewrite_defs_pat_string_append =
let (pat, _, _, _) = destruct_pexp pexp in
let env = pat_env_of pat in
let assert_false = mk_exp (E_assert (mk_exp (E_lit (mk_lit L_false)), mk_exp (E_lit (mk_lit (L_string "unreachable"))))) in
- let construct_single_match match_on pattern maybe_expr =
- let (true_exp, false_exp) =
- match maybe_expr with
- | Some expr -> expr, assert_false
- | None -> (mk_exp (E_lit (mk_lit L_true))), (mk_exp (E_lit (mk_lit L_false)))
+ let construct_bool_match match_on pexp =
+ let true_exp = (mk_exp (E_lit (mk_lit L_true))) in
+ let false_exp = (mk_exp (E_lit (mk_lit L_false))) in
+ let true_pexp =
+ match pexp with
+ | Pat_aux (Pat_exp (pat, exp), _) ->
+ mk_pexp (Pat_exp (pat, true_exp))
+ | Pat_aux (Pat_when (pat, guards, exp), _) ->
+ mk_pexp (Pat_when (pat, guards, true_exp))
in
- mk_exp (E_case (match_on, [mk_pexp (Pat_exp (pattern, true_exp));
- mk_pexp (Pat_exp (mk_pat P_wild, false_exp))]))
+ let false_pexp = mk_pexp (Pat_exp (mk_pat P_wild, false_exp)) in
+ mk_exp (E_case (match_on, [true_pexp; false_pexp]))
in
(* merge cases of Pat_exp and Pat_when *)
@@ -2842,6 +2846,21 @@ let rec rewrite_defs_pat_string_append =
let (new_pat, new_pat_typ, new_guards, new_expr) =
match (p_aux, p_annot) with
+ (* (pat1 ^^ pat2) ^^ pat3 => expr ---> pat1 ^^ (pat2 ^^ pat3) => expr and recurse*)
+ | P_string_append (P_aux (P_string_append (pat1, pat2), _), pat3), annot ->
+ let new_pat = P_aux (P_string_append (pat1, P_aux (P_string_append (pat2, pat3), p_annot)), annot) in
+ let new_pexp = match guards with
+ | [] -> Pat_aux (Pat_exp (new_pat, expr), p_annot)
+ | [g] -> Pat_aux (Pat_when (new_pat, g, expr), p_annot)
+ | gs -> assert false
+ in
+ Printf.printf "PEXP BEFORE RECURSE IS %s\n%!" (Pretty_print_sail.doc_pexp new_pexp |> Pretty_print_sail.to_string);
+ let rewritten = rewrite_pexp new_pexp in
+ Printf.printf "PEXP AFTER RECURSE IS %s\n%!" (Pretty_print_sail.doc_pexp rewritten |> Pretty_print_sail.to_string);
+ begin match rewritten with
+ | Pat_aux (Pat_exp (pat, exp), _) -> strip_pat pat, pat_typ_of pat, [], strip_exp exp
+ | Pat_aux (Pat_when (pat, guard, exp), _) -> strip_pat pat, pat_typ_of pat, [strip_exp guard], strip_exp exp
+ end
(*
"lit" ^^ pat2 => expr ---> s# if startswith(s#, "lit")
&& match str_drop(s#, strlen("lit")) {
@@ -2861,16 +2880,16 @@ let rec rewrite_defs_pat_string_append =
(* construct drop expression -- string_drop(s#, strlen("lit")) *)
let drop_exp = mk_exp (E_app (mk_id "string_drop", [mk_exp (E_id id); mk_exp (E_app (mk_id "string_length", [mk_exp (E_lit lit)]))])) in
- (* construct the two new guards *)
- let guard1 = mk_exp (E_app (mk_id "string_startswith", [mk_exp (E_id id); mk_exp (E_lit lit)])) in
- let guard2 = construct_single_match drop_exp (strip_pat pat2) None in
-
(* recurse into pat2 *)
let new_pat2_pexp = mk_pexp (Pat_exp (strip_pat pat2, strip_exp expr)) in
let new_pat2_pexp = check_case env (pat_typ_of pat2) new_pat2_pexp (typ_of expr) in
let new_pat2_pexp = rewrite_pexp new_pat2_pexp in
let new_pat2_pexp = strip_pexp new_pat2_pexp in
+ (* construct the two new guards *)
+ let guard1 = mk_exp (E_app (mk_id "string_startswith", [mk_exp (E_id id); mk_exp (E_lit lit)])) in
+ let guard2 = construct_bool_match drop_exp new_pat2_pexp in
+
(* construct new match expr *)
let new_expr = mk_exp (E_case (drop_exp, [new_pat2_pexp])) in
@@ -2916,17 +2935,22 @@ let rec rewrite_defs_pat_string_append =
(* construct None pattern *)
let none_exp = mk_pat (P_app (mk_id "None", [])) in
+ (* recurse into pat2 *)
+ let new_pat2_pexp = mk_pexp (Pat_exp (strip_pat pat2, strip_exp expr)) in
+ Printf.printf "PEXP BEFORE TYPECHECK IS %s\n%!" (Pretty_print_sail.doc_pexp new_pat2_pexp |> Pretty_print_sail.to_string);
+ let new_pat2_pexp = check_case env (pat_typ_of pat2) new_pat2_pexp (typ_of expr) in
+ let new_pat2_pexp = rewrite_pexp new_pat2_pexp in
+ let new_pat2_pexp = strip_pexp new_pat2_pexp in
+
(* construct the new guard *)
- let guard_inner_match = construct_single_match drop_exp (strip_pat pat2) None in
+ let guard_inner_match = construct_bool_match drop_exp new_pat2_pexp in
let new_guard = mk_exp (E_case (func_exp, [
mk_pexp (Pat_exp (some_exp, guard_inner_match));
mk_pexp (Pat_exp (none_exp, mk_exp (E_lit (mk_lit (L_false)))))
])) in
(* construct the new match *)
- let new_match = mk_exp (E_case (drop_exp, [
- mk_pexp (Pat_exp (strip_pat pat2, strip_exp expr))
- ])) in
+ let new_match = mk_exp (E_case (drop_exp, [new_pat2_pexp])) in
(* construct the new let *)
let new_binding = mk_exp (E_case (func_exp, [
@@ -2953,11 +2977,14 @@ let rec rewrite_defs_pat_string_append =
| [] -> mk_pexp (Pat_exp (new_pat, new_expr))
| gs -> mk_pexp (Pat_when (new_pat, fold_guards gs, new_expr))
in
+ Printf.printf "PEXP BEFORE TYPECHECK IS %s\n%!" (Pretty_print_sail.doc_pexp new_pexp |> Pretty_print_sail.to_string);
check_case env new_pat_typ new_pexp (typ_of expr)
in
pexp_rewriters rewrite_pexp
+(* let rewrite_defs_mapping_builtins =
+ * let rewrite_pexp *)
let rewrite_defs_pat_lits =
let rewrite_pexp (Pat_aux (pexp_aux, annot) as pexp) =
diff --git a/src/sail_lib.ml b/src/sail_lib.ml
index 132af6f5..188a0703 100644
--- a/src/sail_lib.ml
+++ b/src/sail_lib.ml
@@ -467,6 +467,8 @@ let string_drop (str, n) = let n = Big_int.to_int n in String.sub str n (String.
let string_length str = Big_int.of_int (String.length str)
+let string_append (s1, s2) = s1 ^ s2
+
let lt_int (x, y) = Big_int.less x y
let set_slice (out_len, slice_len, out, n, slice) =
diff --git a/src/type_check.ml b/src/type_check.ml
index 0eaec4f8..cda624fc 100644
--- a/src/type_check.ml
+++ b/src/type_check.ml
@@ -2611,6 +2611,57 @@ and bind_pat env (P_aux (pat_aux, (l, ())) as pat) (Typ_aux (typ_aux, _) as typ)
end
| _ -> typ_error l ("Mal-formed constructor " ^ string_of_id f)
end
+
+ | P_app (f, pats) when Env.is_mapping f env ->
+ begin
+ let (typq, mapping_typ) = Env.get_val_spec f env in
+ let quants = quant_items typq in
+ let untuple (Typ_aux (typ_aux, _) as typ) = match typ_aux with
+ | Typ_tup typs -> typs
+ | _ -> [typ]
+ in
+ match Env.expand_synonyms env mapping_typ with
+ | Typ_aux (Typ_bidir (typ1, typ2), _) ->
+ begin
+ try
+ typ_debug ("Unifying " ^ string_of_bind (typq, mapping_typ) ^ " for pattern " ^ string_of_typ typ);
+ let unifiers, _, _ (* FIXME! *) = unify l env typ2 typ in
+ typ_debug (string_of_list ", " (fun (kid, uvar) -> string_of_kid kid ^ " => " ^ string_of_uvar uvar) (KBindings.bindings unifiers));
+ let arg_typ' = subst_unifiers unifiers typ1 in
+ let quants' = List.fold_left (fun qs (kid, uvar) -> instantiate_quants qs kid uvar) quants (KBindings.bindings unifiers) in
+ if (match quants' with [] -> false | _ -> true)
+ then typ_error l ("Quantifiers " ^ string_of_list ", " string_of_quant_item quants' ^ " not resolved in pattern " ^ string_of_pat pat)
+ else ();
+ let ret_typ' = subst_unifiers unifiers typ2 in
+ let tpats, env, guards =
+ try List.fold_left2 bind_tuple_pat ([], env, []) pats (untuple arg_typ') with
+ | Invalid_argument _ -> typ_error l "Mapping pattern arguments have incorrect length"
+ in
+ annot_pat (P_app (f, List.rev tpats)) typ, env, guards
+ with
+ | Unification_error (l, m) ->
+ try
+ typ_debug "Unifying mapping forwards failed, trying backwards.";
+ typ_debug ("Unifying " ^ string_of_bind (typq, mapping_typ) ^ " for pattern " ^ string_of_typ typ);
+ let unifiers, _, _ (* FIXME! *) = unify l env typ1 typ in
+ typ_debug (string_of_list ", " (fun (kid, uvar) -> string_of_kid kid ^ " => " ^ string_of_uvar uvar) (KBindings.bindings unifiers));
+ let arg_typ' = subst_unifiers unifiers typ2 in
+ let quants' = List.fold_left (fun qs (kid, uvar) -> instantiate_quants qs kid uvar) quants (KBindings.bindings unifiers) in
+ if (match quants' with [] -> false | _ -> true)
+ then typ_error l ("Quantifiers " ^ string_of_list ", " string_of_quant_item quants' ^ " not resolved in pattern " ^ string_of_pat pat)
+ else ();
+ let ret_typ' = subst_unifiers unifiers typ1 in
+ let tpats, env, guards =
+ try List.fold_left2 bind_tuple_pat ([], env, []) pats (untuple arg_typ') with
+ | Invalid_argument _ -> typ_error l "Mapping pattern arguments have incorrect length"
+ in
+ annot_pat (P_app (f, List.rev tpats)) typ, env, guards
+ with
+ | Unification_error (l, m) -> typ_error l ("Unification error when pattern matching against union constructor: " ^ m)
+ end
+ | _ -> typ_error l ("Mal-formed mapping " ^ string_of_id f)
+ end
+
| P_app (f, _) when (not (Env.is_union_constructor f env) && not (Env.is_mapping f env)) ->
typ_error l (string_of_id f ^ " is not a union constructor or mapping in pattern " ^ string_of_pat pat)
| P_as (pat, id) ->