summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorBrian Campbell2017-09-28 11:09:24 +0100
committerBrian Campbell2017-09-28 11:09:24 +0100
commitb5969ea7ca7de19ea2b96c48b1765e2c51e5d2af (patch)
tree64fdca7e2bedaef9a78e63f5bfbc23a82c4d451b
parent8bbb538494a3fc17c2c6a2fda2106a3c07ac0ed9 (diff)
Refine constructors during monomorphisation
-rw-r--r--src/monomorphise.ml139
-rw-r--r--src/type_check.mli2
-rwxr-xr-xtest/mono/test.sh4
-rw-r--r--test/mono/tests1
-rw-r--r--test/mono/union-exist.sail33
5 files changed, 109 insertions, 70 deletions
diff --git a/src/monomorphise.ml b/src/monomorphise.ml
index f5847db8..27b88ea7 100644
--- a/src/monomorphise.ml
+++ b/src/monomorphise.ml
@@ -320,8 +320,12 @@ let split_src_type id ty (TypQ_aux (q,ql)) =
let insts = cross'' insts in
let ty_and_inst (inst0,ty) inst =
let kids, ty = apply_kid_insts inst ty in
- let ty = Typ_aux (Typ_exist (kids, nc', ty),l) in
- inst@inst0, ty
+ let ty =
+ (* Typ_exist is not allowed an empty list of kids *)
+ match kids with
+ | [] -> ty
+ | _ -> Typ_aux (Typ_exist (kids, nc', ty),l)
+ in inst@inst0, ty
in
let tys = List.concat (List.map (fun instty -> List.map (ty_and_inst instty) insts) tys) in
let free = List.fold_left (fun vars k -> KidSet.remove k vars) vars kids in
@@ -335,6 +339,13 @@ let split_src_type id ty (TypQ_aux (q,ql)) =
match snd (size_nvars_ty typ) with
| [] -> []
| tys ->
+ (* One level of tuple type is stripped off by the type checker, so
+ add another here *)
+ let tys =
+ List.map (fun (x,ty) ->
+ x, match ty with
+ | Typ_aux (Typ_tup _,_) -> Typ_aux (Typ_tup [ty],Unknown)
+ | _ -> ty) tys in
if contains_exist t then
raise (Reporting_basic.err_general l
"Only prenex types in unions are supported by monomorphisation")
@@ -380,48 +391,57 @@ let reduce_nexp subst ne =
string_of_nexp nexp ^ " into concrete value"))
in eval ne
+
+let typ_of_args args =
+ match args with
+ | [E_aux (_,(l,annot))] ->
+ snd (env_typ_expected l annot)
+ | _ ->
+ let tys = List.map (fun (E_aux (_,(l,annot))) -> snd (env_typ_expected l annot)) args in
+ Typ_aux (Typ_tup tys,Unknown)
+
(* Check to see if we need to monomorphise a use of a constructor. Currently
assumes that bitvector sizes are always given as a variable; don't yet handle
more general cases (e.g., 8 * var) *)
-(* TODO: use type checker's instantiation instead *)
-let refine_constructor refinements id substs (E_aux (_,(l,_)) as arg) t =
- let rec derive_vars (Typ_aux (t,_)) (E_aux (e,(l,tannot))) =
- match t with
- | Typ_app (Id_aux (Id "vector",_), [_;Typ_arg_aux (Typ_arg_nexp (Nexp_aux (Nexp_var v,_)),_);_;Typ_arg_aux (Typ_arg_typ (Typ_aux (Typ_id (Id_aux (Id "bit",_)),_)),_)]) ->
- (match tannot with
- | Some (_,Typ_aux (Typ_app (Id_aux (Id "vector",_), [_;Typ_arg_aux (Typ_arg_nexp ne,_);_;Typ_arg_aux (Typ_arg_typ (Typ_aux (Typ_id (Id_aux (Id "bit",_)),_)),_)]),_),_) ->
- [(v,reduce_nexp substs ne)]
- | _ -> [])
- | Typ_wild
- | Typ_var _
- | Typ_id _
- | Typ_fn _
- | Typ_app _
- -> []
- | Typ_tup ts ->
- match e with
- | E_tuple es -> List.concat (List.map2 derive_vars ts es)
- | _ -> [] (* TODO? *)
- in
- try
- let (_,irefinements) = List.find (fun (id',_) -> Id.compare id id' = 0) refinements in
- let vars = List.sort_uniq (fun x y -> Kid.compare (fst x) (fst y)) (derive_vars t arg) in
- try
- Some (List.assoc vars irefinements)
- with Not_found ->
- (Reporting_basic.print_err false true l "Monomorphisation"
- ("Failed to find a monomorphic constructor for " ^ string_of_id id ^ " instance " ^
- match vars with [] -> "<empty>"
- | _ -> String.concat "," (List.map (fun (x,y) -> string_of_kid x ^ "=" ^ string_of_int y) vars)); None)
- with Not_found -> None
+let refine_constructor refinements l env id args =
+ match List.find (fun (id',_) -> Id.compare id id' = 0) refinements with
+ | (_,irefinements) -> begin
+ let (_,constr_ty) = Env.get_val_spec id env in
+ match constr_ty with
+ | Typ_aux (Typ_fn (constr_ty,_,_),_) -> begin
+ let arg_ty = typ_of_args args in
+ match Type_check.destruct_exist env constr_ty with
+ | None -> None
+ | Some (kids,nc,constr_ty) ->
+ let (bindings,_,_) = Type_check.unify l env constr_ty arg_ty in
+ let find_kid kid = try Some (KBindings.find kid bindings) with Not_found -> None in
+ let bindings = List.map find_kid kids in
+ let matches_refinement (mapping,_,_) =
+ List.for_all2
+ (fun v (_,w) ->
+ match v,w with
+ | _,None -> true
+ | Some (U_nexp (Nexp_aux (Nexp_constant n, _))),Some m -> n = m
+ | _,_ -> false) bindings mapping
+ in
+ match List.find matches_refinement irefinements with
+ | (_,new_id,_) -> Some (E_app (new_id,args))
+ | exception Not_found ->
+ (Reporting_basic.print_err false true l "Monomorphisation"
+ ("Unable to refine constructor " ^ string_of_id id);
+ None)
+ end
+ | _ -> None
+ end
+ | exception Not_found -> None
(* Substitute found nexps for variables in an expression, and rename constructors to reflect
specialisation *)
(* TODO: kid shadowing *)
-let nexp_subst_fns substs refinements =
+let nexp_subst_fns substs =
let s_t t = subst_src_typ substs t in
(* let s_typschm (TypSchm_aux (TypSchm_ts (q,t),l)) = TypSchm_aux (TypSchm_ts (q,s_t t),l) in
@@ -464,25 +484,7 @@ let nexp_subst_fns substs refinements =
| E_internal_exp_user ((l1,annot1),(l2,annot2)) ->
re (E_internal_exp_user ((l1, s_tannot annot1),(l2, s_tannot annot2)))
| E_cast (t,e') -> re (E_cast (t, s_exp e'))
- | E_app (id,es) ->
- let es' = List.map s_exp es in
- let arg =
- match es' with
- | [] -> E_aux (E_lit (L_aux (L_unit,Unknown)),(l,None))
- | [e] -> e
- | _ -> E_aux (E_tuple es',(l,None))
- in
- let id' =
- let env,_ = env_typ_expected l annot in
- if Env.is_union_constructor id env then
- let (qs,ty) = Env.get_val_spec id env in
- match ty with (Typ_aux (Typ_fn(inty,outty,_),_)) ->
- (match refine_constructor refinements id substs arg inty with
- | None -> id
- | Some id' -> id')
- | _ -> id
- else id
- in re (E_app (id',es'))
+ | E_app (id,es) -> re (E_app (id, List.map s_exp es))
| E_app_infix (e1,id,e2) -> re (E_app_infix (s_exp e1,id,s_exp e2))
| E_tuple es -> re (E_tuple (List.map s_exp es))
| E_if (e1,e2,e3) -> re (E_if (s_exp e1, s_exp e2, s_exp e3))
@@ -542,8 +544,8 @@ let nexp_subst_fns substs refinements =
| LEXP_vector_range (le,e1,e2) -> re (LEXP_vector_range (s_lexp le, s_exp e1, s_exp e2))
| LEXP_field (le,id) -> re (LEXP_field (s_lexp le, id))
in ((fun x -> x (*s_pat*)),s_exp)
-let nexp_subst_pat substs refinements = fst (nexp_subst_fns substs refinements)
-let nexp_subst_exp substs refinements = snd (nexp_subst_fns substs refinements)
+let nexp_subst_pat substs = fst (nexp_subst_fns substs)
+let nexp_subst_exp substs = snd (nexp_subst_fns substs)
let bindings_from_pat p =
let rec aux_pat (P_aux (p,(l,annot))) =
@@ -884,7 +886,11 @@ let split_defs splits defs =
(match try_app (l,annot) (id,es') with
| None ->
(match const_prop_try_fn (id,es') with
- | None -> re (E_app (id,es')) assigns
+ | None ->
+ (let env,_ = env_typ_expected l annot in
+ match Env.is_union_constructor id env, refine_constructor refinements l env id es' with
+ | true, Some exp -> re exp assigns
+ | _,_ -> re (E_app (id,es')) assigns)
| Some r -> r,assigns)
| Some r -> r,assigns)
| E_app_infix (e1,id,e2) ->
@@ -968,7 +974,7 @@ let split_defs splits defs =
let assigns' = isubst_minus_set assigns assigned_in in
re (E_case (e', List.map (const_prop_pexp substs assigns) cases)) assigns'
| Some (E_aux (_,(_,annot')) as exp,newbindings,kbindings) ->
- let exp = nexp_subst_exp (ksubst_from_list kbindings) [] (*???*) exp in
+ let exp = nexp_subst_exp (ksubst_from_list kbindings) exp in
let newbindings_env = isubst_from_list newbindings in
let substs' = isubst_union substs newbindings_env in
const_prop_exp substs' assigns exp)
@@ -991,7 +997,7 @@ let split_defs splits defs =
match can_match e' [Pat_aux (Pat_exp (p,e2),(Unknown,None))] with
| None -> plain ()
| Some (e'',bindings,kbindings) ->
- let e'' = nexp_subst_exp (ksubst_from_list kbindings) [] (*???*) e'' in
+ let e'' = nexp_subst_exp (ksubst_from_list kbindings) e'' in
let bindings = isubst_from_list bindings in
let substs'' = isubst_union substs' bindings in
const_prop_exp substs'' assigns e''
@@ -1372,9 +1378,8 @@ let split_defs splits defs =
patsubsts
| ConstrSplit patnsubsts ->
List.map (fun (pat',nsubst) ->
- (* Leave refinements to later *)
- let pat' = nexp_subst_pat nsubst [] pat' in
- let exp' = nexp_subst_exp nsubst [] e in
+ let pat' = nexp_subst_pat nsubst pat' in
+ let exp' = nexp_subst_exp nsubst e in
Pat_aux (Pat_exp (pat', map_exp exp'),l)
) patnsubsts)
| Pat_aux (Pat_when (p,e1,e2),l) ->
@@ -1388,10 +1393,9 @@ let split_defs splits defs =
patsubsts
| ConstrSplit patnsubsts ->
List.map (fun (pat',nsubst) ->
- (* Leave refinements to later *)
- let pat' = nexp_subst_pat nsubst [] pat' in
- let exp1' = nexp_subst_exp nsubst [] e1 in
- let exp2' = nexp_subst_exp nsubst [] e2 in
+ let pat' = nexp_subst_pat nsubst pat' in
+ let exp1' = nexp_subst_exp nsubst e1 in
+ let exp2' = nexp_subst_exp nsubst e2 in
Pat_aux (Pat_when (pat', map_exp exp1', map_exp exp2'),l)
) patnsubsts)
and map_letbind (LB_aux (lb,annot)) =
@@ -1421,9 +1425,8 @@ let split_defs splits defs =
patsubsts
| ConstrSplit patnsubsts ->
List.map (fun (pat',nsubst) ->
- (* Leave refinements to later *)
- let pat' = nexp_subst_pat nsubst [] pat' in
- let exp' = nexp_subst_exp nsubst [] exp in
+ let pat' = nexp_subst_pat nsubst pat' in
+ let exp' = nexp_subst_exp nsubst exp in
FCL_aux (FCL_Funcl (id, pat', map_exp exp'), annot)
) patnsubsts
in
diff --git a/src/type_check.mli b/src/type_check.mli
index ca2fb90c..e5279067 100644
--- a/src/type_check.mli
+++ b/src/type_check.mli
@@ -225,6 +225,8 @@ type uvar =
val string_of_uvar : uvar -> string
+val unify : l -> Env.t -> typ -> typ -> uvar KBindings.t * kid list * n_constraint option
+
(* Throws Invalid_argument if the argument is not a E_app expression *)
val instantiation_of : tannot exp -> uvar KBindings.t
diff --git a/test/mono/test.sh b/test/mono/test.sh
index b82406c3..2a5aa80b 100755
--- a/test/mono/test.sh
+++ b/test/mono/test.sh
@@ -26,7 +26,7 @@ while read -u 3 TEST ARGS; do
if [ -z "$TESTONLY" -o "$TEST" = "$TESTONLY" ]; then
# echo "$TEST ocaml"
# rm -f -- "$OUTDIR"/*
-# "$SAILDIR/sail" -ocaml "$SAILDIR/lib/prelude.sail" "$DIR/$TEST" -o "$OUTDIR/testout" $ARGS
+# "$SAILDIR/sail" -ocaml "$SAILDIR/lib/prelude.sail" "$SAILDIR/lib/prelude_wrappers.sail" "$DIR/$TEST" -o "$OUTDIR/testout" $ARGS
# cp -- "$SAILDIR"/src/gen_lib/sail_values.ml .
# cp -- "$DIR"/test.ml .
# ocamlc -I "$ZARITH" "$ZARITH/zarith.cma" -dllpath "$ZARITH" -I "$LEMDIR/ocaml-lib" "$LEMDIR/ocaml-lib/extract.cma" -I "$SAILDIR/src/_build/lem_interp" "$SAILDIR/src/_build/lem_interp/extract.cma" sail_values.ml testout.ml test.ml -o test
@@ -34,7 +34,7 @@ while read -u 3 TEST ARGS; do
echo "$TEST lem - ocaml" | tee -a -- "$LOG"
rm -f -- "$OUTDIR"/*
- "$SAILDIR/sail" -lem -lem_sequential -lem_mwords "$SAILDIR/lib/prelude.sail" "$DIR/$TEST".sail -o "$OUTDIR/testout" $ARGS $@ &>> "$LOG" && \
+ "$SAILDIR/sail" -lem -lem_sequential -lem_mwords "$SAILDIR/lib/prelude.sail" "$SAILDIR/lib/prelude_wrappers.sail" "$DIR/$TEST".sail -o "$OUTDIR/testout" $ARGS $@ &>> "$LOG" && \
"$LEMDIR/bin/lem" -ocaml -lib "$SAILDIR/src/lem_interp" "$SAILDIR/src/gen_lib/sail_values.lem" "$SAILDIR/src/gen_lib/sail_operators_mwords.lem" "$SAILDIR/src/gen_lib/state.lem" testout_embed_types_sequential.lem testout_embed_sequential.lem -outdir "$OUTDIR" &>> "$LOG" && \
cp -- "$DIR"/test.ml "$OUTDIR" && \
ocamlc -I "$ZARITH" "$ZARITH/zarith.cma" -dllpath "$ZARITH" -I "$LEMDIR/ocaml-lib" "$LEMDIR/ocaml-lib/extract.cma" -I "$SAILDIR/src/_build/lem_interp" "$SAILDIR/src/_build/lem_interp/extract.cma" sail_values.ml sail_operators_mwords.ml state.ml testout_embed_types_sequential.ml testout_embed_sequential.ml test.ml -o test &>> "$LOG" && \
diff --git a/test/mono/tests b/test/mono/tests
index 425230da..0825c686 100644
--- a/test/mono/tests
+++ b/test/mono/tests
@@ -1,3 +1,4 @@
fnreduce -mono-split fnreduce.sail:43:x
varmatch -mono-split varmatch.sail:7:x
vector -mono-split vector.sail:7:sel
+union-exist -mono-split union-exist.sail:9:v
diff --git a/test/mono/union-exist.sail b/test/mono/union-exist.sail
new file mode 100644
index 00000000..74ab429a
--- /dev/null
+++ b/test/mono/union-exist.sail
@@ -0,0 +1,33 @@
+default Order dec
+
+typedef myunion = const union {
+ (exist 'n, 'n in {8,16}. ([:'n:],bit['n])) MyConstr;
+}
+
+val bit[2] -> myunion effect pure make
+
+function make(v) =
+ (* Can't mention these below without running into exp/nexp parsing conflict! *)
+ let eight = 8 in let sixteen = 16 in
+ switch v {
+ case 0b00 -> MyConstr( ( eight, 0x12) )
+ case 0b01 -> MyConstr( (sixteen,0x1234) )
+ case 0b10 -> MyConstr( ( eight, 0x56) )
+ case 0b11 -> MyConstr( (sixteen,0x5678) )
+ }
+
+val myunion -> bit[32] effect pure use
+
+function use(MyConstr('n)) = {
+ switch n {
+ case (n,v) -> extz(v)
+ }
+}
+val unit -> bool effect pure run
+
+function run () = {
+ use(make(0b00)) == 0x00000012 &
+ use(make(0b01)) == 0x00001234 &
+ use(make(0b10)) == 0x00000056 &
+ use(make(0b11)) == 0x00005678
+}