diff options
| author | Thomas Bauereiss | 2017-09-27 15:23:43 +0100 |
|---|---|---|
| committer | Thomas Bauereiss | 2017-09-27 15:23:43 +0100 |
| commit | 381a3967ebd9269082b452669f507787decf28b0 (patch) | |
| tree | ad1f38b6689e1bccb267520124cb0d89365b4a82 | |
| parent | ced56765ec9324a0e690cbb4e790280d17413f99 (diff) | |
Add while-loops to Lem backend
| -rw-r--r-- | src/gen_lib/prompt.lem | 29 | ||||
| -rw-r--r-- | src/gen_lib/state.lem | 29 | ||||
| -rw-r--r-- | src/pretty_print_lem.ml | 20 | ||||
| -rw-r--r-- | src/rewriter.ml | 30 | ||||
| -rw-r--r-- | test/typecheck/pass/while_MM.sail | 23 | ||||
| -rw-r--r-- | test/typecheck/pass/while_MP.sail | 17 | ||||
| -rw-r--r-- | test/typecheck/pass/while_PM.sail | 23 | ||||
| -rw-r--r-- | test/typecheck/pass/while_PP.sail | 24 |
8 files changed, 195 insertions, 0 deletions
diff --git a/src/gen_lib/prompt.lem b/src/gen_lib/prompt.lem index 5c539354..8e04bd30 100644 --- a/src/gen_lib/prompt.lem +++ b/src/gen_lib/prompt.lem @@ -170,6 +170,35 @@ let rec foreachM_dec (i,stop,by) vars body = foreachM_dec (i - by,stop,by) vars body else return vars +val while_PP : forall 'vars. bool -> 'vars -> ('vars -> bool) -> ('vars -> 'vars) -> 'vars +let rec while_PP is_while vars cond body = + if (cond vars = is_while) + then while_PP is_while (body vars) cond body + else vars + +val while_PM : forall 'vars 'r. bool -> 'vars -> ('vars -> bool) -> + ('vars -> MR 'vars 'r) -> MR 'vars 'r +let rec while_PM is_while vars cond body = + if (cond vars = is_while) + then body vars >>= fun vars -> while_PM is_while vars cond body + else return vars + +val while_MP : forall 'vars 'r. bool -> 'vars -> ('vars -> MR bool 'r) -> + ('vars -> 'vars) -> MR 'vars 'r +let rec while_MP is_while vars cond body = + cond vars >>= fun continue -> + if (continue = is_while) + then while_MP is_while (body vars) cond body + else return vars + +val while_MM : forall 'vars 'r. bool -> 'vars -> ('vars -> MR bool 'r) -> + ('vars -> MR 'vars 'r) -> MR 'vars 'r +let rec while_MM is_while vars cond body = + cond vars >>= fun continue -> + if (continue = is_while) + then body vars >>= fun vars -> while_MM is_while vars cond body + else return vars + let write_two_regs r1 r2 vec = let is_inc = let is_inc_r1 = is_inc_of_reg r1 in diff --git a/src/gen_lib/state.lem b/src/gen_lib/state.lem index 914955e0..4e649144 100644 --- a/src/gen_lib/state.lem +++ b/src/gen_lib/state.lem @@ -203,6 +203,35 @@ let rec foreachM_dec (i,stop,by) vars body = foreachM_dec (i - by,stop,by) vars body else return vars +val while_PP : forall 'vars. bool -> 'vars -> ('vars -> bool) -> ('vars -> 'vars) -> 'vars +let rec while_PP is_while vars cond body = + if (cond vars = is_while) + then while_PP is_while (body vars) cond body + else vars + +val while_PM : forall 'regs 'vars 'e. bool -> 'vars -> ('vars -> bool) -> + ('vars -> ME 'regs 'vars 'e) -> ME 'regs 'vars 'e +let rec while_PM is_while vars cond body = + if (cond vars = is_while) + then body vars >>= fun vars -> while_PM is_while vars cond body + else return vars + +val while_MP : forall 'regs 'vars 'e. bool -> 'vars -> ('vars -> ME 'regs bool 'e) -> + ('vars -> 'vars) -> ME 'regs 'vars 'e +let rec while_MP is_while vars cond body = + cond vars >>= fun continue -> + if (continue = is_while) + then while_MP is_while (body vars) cond body + else return vars + +val while_MM : forall 'regs 'vars 'e. bool -> 'vars -> ('vars -> ME 'regs bool 'e) -> + ('vars -> ME 'regs 'vars 'e) -> ME 'regs 'vars 'e +let rec while_MM is_while vars cond body = + cond vars >>= fun continue -> + if (continue = is_while) + then body vars >>= fun vars -> while_MM is_while vars cond body + else return vars + (*let write_two_regs r1 r2 bvec state = let vec = bvec_to_vec bvec in let is_inc = diff --git a/src/pretty_print_lem.ml b/src/pretty_print_lem.ml index 1bfb19aa..f981297d 100644 --- a/src/pretty_print_lem.ml +++ b/src/pretty_print_lem.ml @@ -514,6 +514,26 @@ let doc_exp_lem, doc_let_lem = (prefix 1 1 (separate space [string "fun";expY id;varspp;arrow]) (expN body)) ) ) + | Id_aux ((Id (("while_PP" | "while_PM" | + "while_MP" | "while_MM" ) as loopf),_)) -> + let [is_while;cond;body;e5] = args in + let varspp = match e5 with + | E_aux (E_tuple vars,_) -> + let vars = List.map (fun (E_aux (E_id (Id_aux (Id name,_)),_)) -> string name) vars in + begin match vars with + | [v] -> v + | _ -> parens (separate comma vars) end + | E_aux (E_id (Id_aux (Id name,_)),_) -> + string name + | E_aux (E_lit (L_aux (L_unit,_)),_) -> + string "_" in + parens ( + (prefix 2 1) + ((separate space) [string loopf;expY is_while;expY e5]) + ((prefix 0 1) + (parens (prefix 1 1 (separate space [string "fun";varspp;arrow]) (expN cond))) + (parens (prefix 1 1 (separate space [string "fun";varspp;arrow]) (expN body)))) + ) (* | Id_aux (Id "append",_) -> let [e1;e2] = args in let epp = align (expY e1 ^^ space ^^ string "++" ^//^ expY e2) in diff --git a/src/rewriter.ml b/src/rewriter.ml index e257e19c..5cf1a6b9 100644 --- a/src/rewriter.ml +++ b/src/rewriter.ml @@ -130,6 +130,7 @@ let fix_eff_exp (E_aux (e,((l,_) as annot))) = match snd annot with union_effects (fun_app_effects f env) (union_eff_exps [e1;e2]) | E_if (e1,e2,e3) -> union_eff_exps [e1;e2;e3] | E_for (_,e1,e2,e3,_,e4) -> union_eff_exps [e1;e2;e3;e4] + | E_loop (_,e1,e2) -> union_eff_exps [e1;e2] | E_vector es -> union_eff_exps es | E_vector_indexed (ies,opt_default) -> let (_,es) = List.split ies in @@ -388,6 +389,8 @@ let rewrite_exp rewriters (E_aux (exp,(l,annot))) = | E_if (c,t,e) -> rewrap (E_if (rewrite c,rewrite t, rewrite e)) | E_for (id, e1, e2, e3, o, body) -> rewrap (E_for (id, rewrite e1, rewrite e2, rewrite e3, o, rewrite body)) + | E_loop (loop, e1, e2) -> + rewrap (E_loop (loop, rewrite e1, rewrite e2)) | E_vector exps -> rewrap (E_vector (List.map rewrite exps)) | E_vector_indexed (exps,(Def_val_aux(default,dannot))) -> let def = match default with @@ -2847,6 +2850,10 @@ let rewrite_defs_letbind_effects = n_exp_name by (fun by -> let body = n_exp_term (effectful body) body in k (rewrap (E_for (id,start,stop,by,dir,body)))))) + | E_loop (loop, cond, body) -> + let cond = n_exp_term (effectful cond) cond in + let body = n_exp_term (effectful body) body in + k (rewrap (E_loop (loop,cond,body))) | E_vector exps -> n_exp_nameL exps (fun exps -> k (rewrap (E_vector exps))) @@ -3171,6 +3178,29 @@ let rec rewrite_var_updates ((E_aux (expaux,((l,_) as annot))) as exp) = else P_aux (P_tup [pat; mktup_pat pl vars], simple_annot pl (typ_of v)) in Added_vars (v,pat) + | E_loop(loop,cond,body) -> + let vars = List.map (fun (var,(l,t)) -> E_aux (E_id var,(l,t))) (find_updated_vars body) in + let vartuple = mktup el vars in + (* let cond = rewrite_var_updates (add_vars false cond vartuple) in *) + let body = rewrite_var_updates (add_vars overwrite body vartuple) in + let (E_aux (_,(_,bannot))) = body in + let fname = match effectful cond, effectful body with + | false, false -> "while_PP" + | false, true -> "while_PM" + | true, false -> "while_MP" + | true, true -> "while_MM" in + let funcl = Id_aux (Id fname,Parse_ast.Generated el) in + let is_while = + match loop with + | While -> E_aux (E_lit (mk_lit L_true), simple_annot el bool_typ) + | Until -> E_aux (E_lit (mk_lit L_false), simple_annot el bool_typ) in + let v = E_aux (E_app (funcl,[is_while;cond;body;vartuple]), + (Parse_ast.Generated el, bannot)) in + let pat = + if overwrite then mktup_pat el vars + else P_aux (P_tup [pat; mktup_pat pl vars], + simple_annot pl (typ_of v)) in + Added_vars (v,pat) | E_if (c,e1,e2) -> let vars = List.map (fun (var,(l,t)) -> E_aux (E_id var,(l,t))) (dedup eqidtyp (find_updated_vars e1 @ find_updated_vars e2)) in diff --git a/test/typecheck/pass/while_MM.sail b/test/typecheck/pass/while_MM.sail new file mode 100644 index 00000000..e6916edd --- /dev/null +++ b/test/typecheck/pass/while_MM.sail @@ -0,0 +1,23 @@ +default Order dec + +val extern forall Num 'n, Num 'm, Num 'o, Num 'p. + ([|'n:'m|], [|'o:'p|]) -> [|'n + 'o:'m + 'p|] effect pure add_range = "add" +val extern (int, int) -> int effect pure add_int = "add" +val forall Num 'n, Num 'o, Order 'ord. + (vector<'o, 'n, 'ord, bit>, int) -> vector<'o, 'n, 'ord, bit> effect pure add_vec_int +overload (deinfix +) [add_vec_int; add_range; add_int] + +val extern bool -> bool effect pure bool_not = "not" + +val cast forall Num 'n, Num 'l. [:0:] -> vector<'n,'l,dec,bit> effect pure cast_0_vec_dec + +register (bit[64]) COUNT +register (bool) INT + +function (unit) test () = { + COUNT := 0; + while (bool_not(INT)) do { + COUNT := COUNT + 1; + } +} + diff --git a/test/typecheck/pass/while_MP.sail b/test/typecheck/pass/while_MP.sail new file mode 100644 index 00000000..05d396e2 --- /dev/null +++ b/test/typecheck/pass/while_MP.sail @@ -0,0 +1,17 @@ +default Order dec + +val extern (int, int) -> int effect pure add_int = "add" +overload (deinfix +) [add_vec_int; add_range; add_int] + +val extern bool -> bool effect pure bool_not = "not" + +register (bool) INT + +function (int) test () = { + (int) count := 0; + while (bool_not(INT)) do { + count := count + 1; + }; + return count; +} + diff --git a/test/typecheck/pass/while_PM.sail b/test/typecheck/pass/while_PM.sail new file mode 100644 index 00000000..b03a87dc --- /dev/null +++ b/test/typecheck/pass/while_PM.sail @@ -0,0 +1,23 @@ +default Order dec + +val extern forall Num 'n, Num 'm, Num 'o. ([|'n:'m|], [:'o:]) -> bool effect pure lt_range_atom = "lt" +val extern (int, int) -> bool effect pure lt_int = "lt" +overload (deinfix <) [lt_range_atom; lt_int] + +val extern forall Num 'n, Num 'm, Num 'o, Num 'p. + ([|'n:'m|], [|'o:'p|]) -> [|'n + 'o:'m + 'p|] effect pure add_range = "add" +val extern (int, int) -> int effect pure add_int = "add" +overload (deinfix +) [add_range; add_int] + +val extern forall Num 'n, Num 'l, 'l >= 0. (vector<'n,'l,dec,bit>, int) -> bit effect pure vector_access = "bitvector_access_dec" + +register (bit[64]) GPR00 + +function (unit) test ((bit) b) = { + (int) i := 0; + while (i < 64) do { + GPR00[i] := b; + i := i + 1; + } +} + diff --git a/test/typecheck/pass/while_PP.sail b/test/typecheck/pass/while_PP.sail new file mode 100644 index 00000000..454cc9ac --- /dev/null +++ b/test/typecheck/pass/while_PP.sail @@ -0,0 +1,24 @@ +default Order dec + +val extern forall Num 'n, Num 'm, Num 'o. ([|'n:'m|], [:'o:]) -> bool effect pure lt_range_atom = "lt" +val extern (int, int) -> bool effect pure lt_int = "lt" +overload (deinfix <) [lt_range_atom; lt_int] + +val (int, int) -> int effect pure mult_int +overload (deinfix * ) [mult_int] + +val extern forall Num 'n, Num 'm, Num 'o, Num 'p. + ([|'n:'m|], [|'o:'p|]) -> [|'n + 'o:'m + 'p|] effect pure add_range = "add" +val extern (int, int) -> int effect pure add_int = "add" +overload (deinfix +) [add_range; add_int] + +function (int) test ((int) n) = { + (int) i := 1; + (int) j := 1; + while (i < n) do { + j := i * j; + i := i + 1; + }; + j +} + |
