diff options
| author | Brian Campbell | 2017-09-28 11:09:24 +0100 |
|---|---|---|
| committer | Brian Campbell | 2017-09-28 11:09:24 +0100 |
| commit | b5969ea7ca7de19ea2b96c48b1765e2c51e5d2af (patch) | |
| tree | 64fdca7e2bedaef9a78e63f5bfbc23a82c4d451b | |
| parent | 8bbb538494a3fc17c2c6a2fda2106a3c07ac0ed9 (diff) | |
Refine constructors during monomorphisation
| -rw-r--r-- | src/monomorphise.ml | 139 | ||||
| -rw-r--r-- | src/type_check.mli | 2 | ||||
| -rwxr-xr-x | test/mono/test.sh | 4 | ||||
| -rw-r--r-- | test/mono/tests | 1 | ||||
| -rw-r--r-- | test/mono/union-exist.sail | 33 |
5 files changed, 109 insertions, 70 deletions
diff --git a/src/monomorphise.ml b/src/monomorphise.ml index f5847db8..27b88ea7 100644 --- a/src/monomorphise.ml +++ b/src/monomorphise.ml @@ -320,8 +320,12 @@ let split_src_type id ty (TypQ_aux (q,ql)) = let insts = cross'' insts in let ty_and_inst (inst0,ty) inst = let kids, ty = apply_kid_insts inst ty in - let ty = Typ_aux (Typ_exist (kids, nc', ty),l) in - inst@inst0, ty + let ty = + (* Typ_exist is not allowed an empty list of kids *) + match kids with + | [] -> ty + | _ -> Typ_aux (Typ_exist (kids, nc', ty),l) + in inst@inst0, ty in let tys = List.concat (List.map (fun instty -> List.map (ty_and_inst instty) insts) tys) in let free = List.fold_left (fun vars k -> KidSet.remove k vars) vars kids in @@ -335,6 +339,13 @@ let split_src_type id ty (TypQ_aux (q,ql)) = match snd (size_nvars_ty typ) with | [] -> [] | tys -> + (* One level of tuple type is stripped off by the type checker, so + add another here *) + let tys = + List.map (fun (x,ty) -> + x, match ty with + | Typ_aux (Typ_tup _,_) -> Typ_aux (Typ_tup [ty],Unknown) + | _ -> ty) tys in if contains_exist t then raise (Reporting_basic.err_general l "Only prenex types in unions are supported by monomorphisation") @@ -380,48 +391,57 @@ let reduce_nexp subst ne = string_of_nexp nexp ^ " into concrete value")) in eval ne + +let typ_of_args args = + match args with + | [E_aux (_,(l,annot))] -> + snd (env_typ_expected l annot) + | _ -> + let tys = List.map (fun (E_aux (_,(l,annot))) -> snd (env_typ_expected l annot)) args in + Typ_aux (Typ_tup tys,Unknown) + (* Check to see if we need to monomorphise a use of a constructor. Currently assumes that bitvector sizes are always given as a variable; don't yet handle more general cases (e.g., 8 * var) *) -(* TODO: use type checker's instantiation instead *) -let refine_constructor refinements id substs (E_aux (_,(l,_)) as arg) t = - let rec derive_vars (Typ_aux (t,_)) (E_aux (e,(l,tannot))) = - match t with - | Typ_app (Id_aux (Id "vector",_), [_;Typ_arg_aux (Typ_arg_nexp (Nexp_aux (Nexp_var v,_)),_);_;Typ_arg_aux (Typ_arg_typ (Typ_aux (Typ_id (Id_aux (Id "bit",_)),_)),_)]) -> - (match tannot with - | Some (_,Typ_aux (Typ_app (Id_aux (Id "vector",_), [_;Typ_arg_aux (Typ_arg_nexp ne,_);_;Typ_arg_aux (Typ_arg_typ (Typ_aux (Typ_id (Id_aux (Id "bit",_)),_)),_)]),_),_) -> - [(v,reduce_nexp substs ne)] - | _ -> []) - | Typ_wild - | Typ_var _ - | Typ_id _ - | Typ_fn _ - | Typ_app _ - -> [] - | Typ_tup ts -> - match e with - | E_tuple es -> List.concat (List.map2 derive_vars ts es) - | _ -> [] (* TODO? *) - in - try - let (_,irefinements) = List.find (fun (id',_) -> Id.compare id id' = 0) refinements in - let vars = List.sort_uniq (fun x y -> Kid.compare (fst x) (fst y)) (derive_vars t arg) in - try - Some (List.assoc vars irefinements) - with Not_found -> - (Reporting_basic.print_err false true l "Monomorphisation" - ("Failed to find a monomorphic constructor for " ^ string_of_id id ^ " instance " ^ - match vars with [] -> "<empty>" - | _ -> String.concat "," (List.map (fun (x,y) -> string_of_kid x ^ "=" ^ string_of_int y) vars)); None) - with Not_found -> None +let refine_constructor refinements l env id args = + match List.find (fun (id',_) -> Id.compare id id' = 0) refinements with + | (_,irefinements) -> begin + let (_,constr_ty) = Env.get_val_spec id env in + match constr_ty with + | Typ_aux (Typ_fn (constr_ty,_,_),_) -> begin + let arg_ty = typ_of_args args in + match Type_check.destruct_exist env constr_ty with + | None -> None + | Some (kids,nc,constr_ty) -> + let (bindings,_,_) = Type_check.unify l env constr_ty arg_ty in + let find_kid kid = try Some (KBindings.find kid bindings) with Not_found -> None in + let bindings = List.map find_kid kids in + let matches_refinement (mapping,_,_) = + List.for_all2 + (fun v (_,w) -> + match v,w with + | _,None -> true + | Some (U_nexp (Nexp_aux (Nexp_constant n, _))),Some m -> n = m + | _,_ -> false) bindings mapping + in + match List.find matches_refinement irefinements with + | (_,new_id,_) -> Some (E_app (new_id,args)) + | exception Not_found -> + (Reporting_basic.print_err false true l "Monomorphisation" + ("Unable to refine constructor " ^ string_of_id id); + None) + end + | _ -> None + end + | exception Not_found -> None (* Substitute found nexps for variables in an expression, and rename constructors to reflect specialisation *) (* TODO: kid shadowing *) -let nexp_subst_fns substs refinements = +let nexp_subst_fns substs = let s_t t = subst_src_typ substs t in (* let s_typschm (TypSchm_aux (TypSchm_ts (q,t),l)) = TypSchm_aux (TypSchm_ts (q,s_t t),l) in @@ -464,25 +484,7 @@ let nexp_subst_fns substs refinements = | E_internal_exp_user ((l1,annot1),(l2,annot2)) -> re (E_internal_exp_user ((l1, s_tannot annot1),(l2, s_tannot annot2))) | E_cast (t,e') -> re (E_cast (t, s_exp e')) - | E_app (id,es) -> - let es' = List.map s_exp es in - let arg = - match es' with - | [] -> E_aux (E_lit (L_aux (L_unit,Unknown)),(l,None)) - | [e] -> e - | _ -> E_aux (E_tuple es',(l,None)) - in - let id' = - let env,_ = env_typ_expected l annot in - if Env.is_union_constructor id env then - let (qs,ty) = Env.get_val_spec id env in - match ty with (Typ_aux (Typ_fn(inty,outty,_),_)) -> - (match refine_constructor refinements id substs arg inty with - | None -> id - | Some id' -> id') - | _ -> id - else id - in re (E_app (id',es')) + | E_app (id,es) -> re (E_app (id, List.map s_exp es)) | E_app_infix (e1,id,e2) -> re (E_app_infix (s_exp e1,id,s_exp e2)) | E_tuple es -> re (E_tuple (List.map s_exp es)) | E_if (e1,e2,e3) -> re (E_if (s_exp e1, s_exp e2, s_exp e3)) @@ -542,8 +544,8 @@ let nexp_subst_fns substs refinements = | LEXP_vector_range (le,e1,e2) -> re (LEXP_vector_range (s_lexp le, s_exp e1, s_exp e2)) | LEXP_field (le,id) -> re (LEXP_field (s_lexp le, id)) in ((fun x -> x (*s_pat*)),s_exp) -let nexp_subst_pat substs refinements = fst (nexp_subst_fns substs refinements) -let nexp_subst_exp substs refinements = snd (nexp_subst_fns substs refinements) +let nexp_subst_pat substs = fst (nexp_subst_fns substs) +let nexp_subst_exp substs = snd (nexp_subst_fns substs) let bindings_from_pat p = let rec aux_pat (P_aux (p,(l,annot))) = @@ -884,7 +886,11 @@ let split_defs splits defs = (match try_app (l,annot) (id,es') with | None -> (match const_prop_try_fn (id,es') with - | None -> re (E_app (id,es')) assigns + | None -> + (let env,_ = env_typ_expected l annot in + match Env.is_union_constructor id env, refine_constructor refinements l env id es' with + | true, Some exp -> re exp assigns + | _,_ -> re (E_app (id,es')) assigns) | Some r -> r,assigns) | Some r -> r,assigns) | E_app_infix (e1,id,e2) -> @@ -968,7 +974,7 @@ let split_defs splits defs = let assigns' = isubst_minus_set assigns assigned_in in re (E_case (e', List.map (const_prop_pexp substs assigns) cases)) assigns' | Some (E_aux (_,(_,annot')) as exp,newbindings,kbindings) -> - let exp = nexp_subst_exp (ksubst_from_list kbindings) [] (*???*) exp in + let exp = nexp_subst_exp (ksubst_from_list kbindings) exp in let newbindings_env = isubst_from_list newbindings in let substs' = isubst_union substs newbindings_env in const_prop_exp substs' assigns exp) @@ -991,7 +997,7 @@ let split_defs splits defs = match can_match e' [Pat_aux (Pat_exp (p,e2),(Unknown,None))] with | None -> plain () | Some (e'',bindings,kbindings) -> - let e'' = nexp_subst_exp (ksubst_from_list kbindings) [] (*???*) e'' in + let e'' = nexp_subst_exp (ksubst_from_list kbindings) e'' in let bindings = isubst_from_list bindings in let substs'' = isubst_union substs' bindings in const_prop_exp substs'' assigns e'' @@ -1372,9 +1378,8 @@ let split_defs splits defs = patsubsts | ConstrSplit patnsubsts -> List.map (fun (pat',nsubst) -> - (* Leave refinements to later *) - let pat' = nexp_subst_pat nsubst [] pat' in - let exp' = nexp_subst_exp nsubst [] e in + let pat' = nexp_subst_pat nsubst pat' in + let exp' = nexp_subst_exp nsubst e in Pat_aux (Pat_exp (pat', map_exp exp'),l) ) patnsubsts) | Pat_aux (Pat_when (p,e1,e2),l) -> @@ -1388,10 +1393,9 @@ let split_defs splits defs = patsubsts | ConstrSplit patnsubsts -> List.map (fun (pat',nsubst) -> - (* Leave refinements to later *) - let pat' = nexp_subst_pat nsubst [] pat' in - let exp1' = nexp_subst_exp nsubst [] e1 in - let exp2' = nexp_subst_exp nsubst [] e2 in + let pat' = nexp_subst_pat nsubst pat' in + let exp1' = nexp_subst_exp nsubst e1 in + let exp2' = nexp_subst_exp nsubst e2 in Pat_aux (Pat_when (pat', map_exp exp1', map_exp exp2'),l) ) patnsubsts) and map_letbind (LB_aux (lb,annot)) = @@ -1421,9 +1425,8 @@ let split_defs splits defs = patsubsts | ConstrSplit patnsubsts -> List.map (fun (pat',nsubst) -> - (* Leave refinements to later *) - let pat' = nexp_subst_pat nsubst [] pat' in - let exp' = nexp_subst_exp nsubst [] exp in + let pat' = nexp_subst_pat nsubst pat' in + let exp' = nexp_subst_exp nsubst exp in FCL_aux (FCL_Funcl (id, pat', map_exp exp'), annot) ) patnsubsts in diff --git a/src/type_check.mli b/src/type_check.mli index ca2fb90c..e5279067 100644 --- a/src/type_check.mli +++ b/src/type_check.mli @@ -225,6 +225,8 @@ type uvar = val string_of_uvar : uvar -> string +val unify : l -> Env.t -> typ -> typ -> uvar KBindings.t * kid list * n_constraint option + (* Throws Invalid_argument if the argument is not a E_app expression *) val instantiation_of : tannot exp -> uvar KBindings.t diff --git a/test/mono/test.sh b/test/mono/test.sh index b82406c3..2a5aa80b 100755 --- a/test/mono/test.sh +++ b/test/mono/test.sh @@ -26,7 +26,7 @@ while read -u 3 TEST ARGS; do if [ -z "$TESTONLY" -o "$TEST" = "$TESTONLY" ]; then # echo "$TEST ocaml" # rm -f -- "$OUTDIR"/* -# "$SAILDIR/sail" -ocaml "$SAILDIR/lib/prelude.sail" "$DIR/$TEST" -o "$OUTDIR/testout" $ARGS +# "$SAILDIR/sail" -ocaml "$SAILDIR/lib/prelude.sail" "$SAILDIR/lib/prelude_wrappers.sail" "$DIR/$TEST" -o "$OUTDIR/testout" $ARGS # cp -- "$SAILDIR"/src/gen_lib/sail_values.ml . # cp -- "$DIR"/test.ml . # ocamlc -I "$ZARITH" "$ZARITH/zarith.cma" -dllpath "$ZARITH" -I "$LEMDIR/ocaml-lib" "$LEMDIR/ocaml-lib/extract.cma" -I "$SAILDIR/src/_build/lem_interp" "$SAILDIR/src/_build/lem_interp/extract.cma" sail_values.ml testout.ml test.ml -o test @@ -34,7 +34,7 @@ while read -u 3 TEST ARGS; do echo "$TEST lem - ocaml" | tee -a -- "$LOG" rm -f -- "$OUTDIR"/* - "$SAILDIR/sail" -lem -lem_sequential -lem_mwords "$SAILDIR/lib/prelude.sail" "$DIR/$TEST".sail -o "$OUTDIR/testout" $ARGS $@ &>> "$LOG" && \ + "$SAILDIR/sail" -lem -lem_sequential -lem_mwords "$SAILDIR/lib/prelude.sail" "$SAILDIR/lib/prelude_wrappers.sail" "$DIR/$TEST".sail -o "$OUTDIR/testout" $ARGS $@ &>> "$LOG" && \ "$LEMDIR/bin/lem" -ocaml -lib "$SAILDIR/src/lem_interp" "$SAILDIR/src/gen_lib/sail_values.lem" "$SAILDIR/src/gen_lib/sail_operators_mwords.lem" "$SAILDIR/src/gen_lib/state.lem" testout_embed_types_sequential.lem testout_embed_sequential.lem -outdir "$OUTDIR" &>> "$LOG" && \ cp -- "$DIR"/test.ml "$OUTDIR" && \ ocamlc -I "$ZARITH" "$ZARITH/zarith.cma" -dllpath "$ZARITH" -I "$LEMDIR/ocaml-lib" "$LEMDIR/ocaml-lib/extract.cma" -I "$SAILDIR/src/_build/lem_interp" "$SAILDIR/src/_build/lem_interp/extract.cma" sail_values.ml sail_operators_mwords.ml state.ml testout_embed_types_sequential.ml testout_embed_sequential.ml test.ml -o test &>> "$LOG" && \ diff --git a/test/mono/tests b/test/mono/tests index 425230da..0825c686 100644 --- a/test/mono/tests +++ b/test/mono/tests @@ -1,3 +1,4 @@ fnreduce -mono-split fnreduce.sail:43:x varmatch -mono-split varmatch.sail:7:x vector -mono-split vector.sail:7:sel +union-exist -mono-split union-exist.sail:9:v diff --git a/test/mono/union-exist.sail b/test/mono/union-exist.sail new file mode 100644 index 00000000..74ab429a --- /dev/null +++ b/test/mono/union-exist.sail @@ -0,0 +1,33 @@ +default Order dec + +typedef myunion = const union { + (exist 'n, 'n in {8,16}. ([:'n:],bit['n])) MyConstr; +} + +val bit[2] -> myunion effect pure make + +function make(v) = + (* Can't mention these below without running into exp/nexp parsing conflict! *) + let eight = 8 in let sixteen = 16 in + switch v { + case 0b00 -> MyConstr( ( eight, 0x12) ) + case 0b01 -> MyConstr( (sixteen,0x1234) ) + case 0b10 -> MyConstr( ( eight, 0x56) ) + case 0b11 -> MyConstr( (sixteen,0x5678) ) + } + +val myunion -> bit[32] effect pure use + +function use(MyConstr('n)) = { + switch n { + case (n,v) -> extz(v) + } +} +val unit -> bool effect pure run + +function run () = { + use(make(0b00)) == 0x00000012 & + use(make(0b01)) == 0x00001234 & + use(make(0b10)) == 0x00000056 & + use(make(0b11)) == 0x00005678 +} |
