summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorBrian Campbell2017-06-29 18:28:39 +0100
committerBrian Campbell2017-06-29 18:28:39 +0100
commit0b3d26f0c7727631ac47c61ff88a16e0a217641d (patch)
tree2d4ab473e9a369a2bf98c1b108624fc9934a78d0
parentae8e96ef1aef29b19ebe50a12dea552b740dc57a (diff)
Propagate type information from reducing case expressions
-rw-r--r--src/monomorphise.ml152
1 files changed, 149 insertions, 3 deletions
diff --git a/src/monomorphise.ml b/src/monomorphise.ml
index 4683279f..c7fd7c31 100644
--- a/src/monomorphise.ml
+++ b/src/monomorphise.ml
@@ -27,6 +27,104 @@ let pat_id_is_variable t_env id =
| _ -> true
+let nexp_subst substs exp =
+ let s_t t = typ_subst substs true t in
+(* let s_typschm (TypSchm_aux (TypSchm_ts (q,t),l)) = TypSchm_aux (TypSchm_ts (q,s_t t),l) in
+ hopefully don't need this anyway *)
+ let s_typschm tsh = tsh in
+ let s_tannot = function
+ | Base ((params,t),tag,ranges,effl,effc,bounds) ->
+ (* TODO: do other fields need mapped? *)
+ Base ((params,s_t t),tag,ranges,effl,effc,bounds)
+ | tannot -> tannot
+ in
+ let rec s_pat (P_aux (p,(l,annot))) =
+ let re p = P_aux (p,(l,s_tannot annot)) in
+ match p with
+ | P_lit _ | P_wild | P_id _ -> re p
+ | P_as (p',id) -> re (P_as (s_pat p', id))
+ | P_typ (ty,p') -> re (P_typ (ty,s_pat p'))
+ | P_app (id,ps) -> re (P_app (id, List.map s_pat ps))
+ | P_record (fps,flag) -> re (P_record (List.map s_fpat fps, flag))
+ | P_vector ps -> re (P_vector (List.map s_pat ps))
+ | P_vector_indexed ips -> re (P_vector_indexed (List.map (fun (i,p) -> (i,s_pat p)) ips))
+ | P_vector_concat ps -> re (P_vector_concat (List.map s_pat ps))
+ | P_tup ps -> re (P_tup (List.map s_pat ps))
+ | P_list ps -> re (P_list (List.map s_pat ps))
+ and s_fpat (FP_aux (FP_Fpat (id, p), (l,annot))) =
+ FP_aux (FP_Fpat (id, s_pat p), (l,s_tannot annot))
+ in
+ let rec s_exp (E_aux (e,(l,annot))) =
+ let re e = E_aux (e,(l,s_tannot annot)) in
+ match e with
+ | E_block es -> re (E_block (List.map s_exp es))
+ | E_nondet es -> re (E_nondet (List.map s_exp es))
+ | E_id _
+ | E_lit _
+ | E_comment _ -> re e
+ | E_sizeof ne -> re (E_sizeof ne) (* TODO: do this need done? does it appear in type checked code? *)
+ | E_internal_exp (l,annot) -> re (E_internal_exp (l, s_tannot annot))
+ | E_sizeof_internal (l,annot) -> re (E_sizeof_internal (l, s_tannot annot))
+ | E_internal_exp_user ((l1,annot1),(l2,annot2)) ->
+ re (E_internal_exp_user ((l1, s_tannot annot1),(l2, s_tannot annot2)))
+ | E_cast (t,e') -> re (E_cast (t, s_exp e'))
+ | E_app (id,es) -> re (E_app (id,List.map s_exp es))
+ | E_app_infix (e1,id,e2) -> re (E_app_infix (s_exp e1,id,s_exp e2))
+ | E_tuple es -> re (E_tuple (List.map s_exp es))
+ | E_if (e1,e2,e3) -> re (E_if (s_exp e1, s_exp e2, s_exp e3))
+ | E_for (id,e1,e2,e3,ord,e4) -> re (E_for (id,s_exp e1,s_exp e2,s_exp e3,ord,s_exp e4))
+ | E_vector es -> re (E_vector (List.map s_exp es))
+ | E_vector_indexed (ies,ed) -> re (E_vector_indexed (List.map (fun (i,e) -> (i,s_exp e)) ies,
+ s_opt_default ed))
+ | E_vector_access (e1,e2) -> re (E_vector_access (s_exp e1,s_exp e2))
+ | E_vector_subrange (e1,e2,e3) -> re (E_vector_subrange (s_exp e1,s_exp e2,s_exp e3))
+ | E_vector_update (e1,e2,e3) -> re (E_vector_update (s_exp e1,s_exp e2,s_exp e3))
+ | E_vector_update_subrange (e1,e2,e3,e4) -> re (E_vector_update_subrange (s_exp e1,s_exp e2,s_exp e3,s_exp e4))
+ | E_vector_append (e1,e2) -> re (E_vector_append (s_exp e1,s_exp e2))
+ | E_list es -> re (E_list (List.map s_exp es))
+ | E_cons (e1,e2) -> re (E_cons (s_exp e1,s_exp e2))
+ | E_record fes -> re (E_record (s_fexps fes))
+ | E_record_update (e,fes) -> re (E_record_update (s_exp e, s_fexps fes))
+ | E_field (e,id) -> re (E_field (s_exp e,id))
+ | E_case (e,cases) -> re (E_case (s_exp e, List.map s_pexp cases))
+ | E_let (lb,e) -> re (E_let (s_letbind lb, s_exp e))
+ | E_assign (le,e) -> re (E_assign (s_lexp le, s_exp e))
+ | E_exit e -> re (E_exit (s_exp e))
+ | E_return e -> re (E_return (s_exp e))
+ | E_assert (e1,e2) -> re (E_assert (s_exp e1,s_exp e2))
+ | E_internal_cast ((l,ann),e) -> re (E_internal_cast ((l,s_tannot ann),s_exp e))
+ | E_comment_struc e -> re (E_comment_struc e)
+ | E_internal_let (le,e1,e2) -> re (E_internal_let (s_lexp le, s_exp e1, s_exp e2))
+ | E_internal_plet (p,e1,e2) -> re (E_internal_plet (s_pat p, s_exp e1, s_exp e2))
+ | E_internal_return e -> re (E_internal_return (s_exp e))
+ and s_opt_default (Def_val_aux (ed,(l,annot))) =
+ match ed with
+ | Def_val_empty -> Def_val_aux (Def_val_empty,(l,s_tannot annot))
+ | Def_val_dec e -> Def_val_aux (Def_val_dec (s_exp e),(l,s_tannot annot))
+ and s_fexps (FES_aux (FES_Fexps (fes,flag), (l,annot))) =
+ FES_aux (FES_Fexps (List.map s_fexp fes, flag), (l,s_tannot annot))
+ and s_fexp (FE_aux (FE_Fexp (id,e), (l,annot))) =
+ FE_aux (FE_Fexp (id,s_exp e),(l,s_tannot annot))
+ and s_pexp (Pat_aux (Pat_exp (p,e),(l,annot))) =
+ Pat_aux (Pat_exp (s_pat p, s_exp e),(l,s_tannot annot))
+ and s_letbind (LB_aux (lb,(l,annot))) =
+ match lb with
+ | LB_val_explicit (tysch,p,e) ->
+ LB_aux (LB_val_explicit (s_typschm tysch,s_pat p,s_exp e), (l,s_tannot annot))
+ | LB_val_implicit (p,e) -> LB_aux (LB_val_implicit (s_pat p,s_exp e), (l,s_tannot annot))
+ and s_lexp (LEXP_aux (e,(l,annot))) =
+ let re e = LEXP_aux (e,(l,s_tannot annot)) in
+ match e with
+ | LEXP_id _
+ | LEXP_cast _
+ -> re e
+ | LEXP_memory (id,es) -> re (LEXP_memory (id,List.map s_exp es))
+ | LEXP_tup les -> re (LEXP_tup (List.map s_lexp les))
+ | LEXP_vector (le,e) -> re (LEXP_vector (s_lexp le, s_exp e))
+ | LEXP_vector_range (le,e1,e2) -> re (LEXP_vector_range (s_lexp le, s_exp e1, s_exp e2))
+ | LEXP_field (le,id) -> re (LEXP_field (s_lexp le, id))
+ in s_exp exp
+
let bindings_from_pat t_env p =
let rec aux_pat (P_aux (p,annot)) =
match p with
@@ -77,6 +175,48 @@ let split_defs splits (Type_check.Env (d_env,t_env,b_env,tp_env)) defs =
| _ -> None)
| _ -> None
in
+
+ let build_nexp_subst l t1 t2 =
+ let rec from_types t1 t2 =
+ let t1 = match t1.t with Tabbrev(_,t) -> t | _ -> t1 in
+ let t2 = match t2.t with Tabbrev(_,t) -> t | _ -> t2 in
+ if t1 = t2 then [] else
+ match t1.t,t2.t with
+ | Tapp (s1,args1), Tapp (s2,args2) ->
+ if s1 = s2 then
+ List.concat (List.map2 from_args args1 args2)
+ else (Reporting_basic.print_err false true l "Monomorphisation"
+ "Unexpected type mismatch"; [])
+ | Ttup ts1, Ttup ts2 ->
+ if List.length ts1 = List.length ts2 then
+ List.concat (List.map2 from_types ts1 ts2)
+ else (Reporting_basic.print_err false true l "Monomorphisation"
+ "Unexpected type mismatch"; [])
+ | _ -> []
+ and from_args arg1 arg2 =
+ match arg1,arg2 with
+ | TA_typ t1, TA_typ t2 -> from_types t1 t2
+ | TA_nexp n1, TA_nexp n2 -> from_nexps n1 n2
+ | _ -> []
+ and from_nexps n1 n2 =
+ match n1.nexp, n2.nexp with
+ | Nvar s, Nvar s' when s = s' -> []
+ | Nvar s, _ -> [(s,n2)]
+ | Nadd (n3,n4), Nadd (n5,n6)
+ | Nsub (n3,n4), Nsub (n5,n6)
+ | Nmult (n3,n4), Nmult (n5,n6)
+ -> from_nexps n3 n5 @ from_nexps n4 n6
+ | N2n (n3,p1), N2n (n4,p2) when p1 = p2 -> from_nexps n3 n4
+ | Npow (n3,p1), Npow (n4,p2) when p1 = p2 -> from_nexps n3 n4
+ | Nneg n3, Nneg n4 -> from_nexps n3 n4
+ | _ -> []
+ in match t1,t2 with
+ | Base ((_,t1),_,_,_,_,_),Base ((_,t2),_,_,_,_,_) -> from_types t1 t2
+ | _ -> []
+ in
+
+ let nexp_substs = ref [] in
+
(* Constant propogation *)
let rec const_prop_exp substs ((E_aux (e,(l,annot))) as exp) =
let re e = E_aux (e,(l,annot)) in
@@ -117,10 +257,11 @@ let split_defs splits (Type_check.Env (d_env,t_env,b_env,tp_env)) defs =
| E_field (e,id) -> re (E_field (const_prop_exp substs e,id))
| E_case (e,cases) ->
let e' = const_prop_exp substs e in
- (* TODO: ought to propagate type substitution to other terms *)
(match can_match e' cases with
| None -> re (E_case (e', List.map (const_prop_pexp substs) cases))
- | Some e'' -> const_prop_exp substs e'')
+ | Some (E_aux (_,(_,annot')) as exp) ->
+ nexp_substs := build_nexp_subst l annot annot' @ !nexp_substs;
+ const_prop_exp substs exp)
| E_let (lb,e) ->
let (lb',substs') = const_prop_letbind substs lb in
re (E_let (lb', const_prop_exp substs' e))
@@ -176,7 +317,12 @@ let split_defs splits (Type_check.Env (d_env,t_env,b_env,tp_env)) defs =
E_aux (E_let (LB_aux (LB_val_implicit (p,sube),(lg,annot)), exp),(lg,annot))
else
let substs = Envmap.from_list [subst] in
- const_prop_exp substs exp
+ let () = nexp_substs := [] in
+ let exp' = const_prop_exp substs exp in
+ (* Substitute what we've learned about nvars into the term *)
+ let nsubsts = Envmap.from_list (List.map (fun (id,ne) -> (id,TA_nexp ne)) !nexp_substs) in
+ let () = nexp_substs := [] in
+ nexp_subst nsubsts exp'
in