summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorThomas Bauereiss2017-09-27 15:23:43 +0100
committerThomas Bauereiss2017-09-27 15:23:43 +0100
commit381a3967ebd9269082b452669f507787decf28b0 (patch)
treead1f38b6689e1bccb267520124cb0d89365b4a82
parentced56765ec9324a0e690cbb4e790280d17413f99 (diff)
Add while-loops to Lem backend
-rw-r--r--src/gen_lib/prompt.lem29
-rw-r--r--src/gen_lib/state.lem29
-rw-r--r--src/pretty_print_lem.ml20
-rw-r--r--src/rewriter.ml30
-rw-r--r--test/typecheck/pass/while_MM.sail23
-rw-r--r--test/typecheck/pass/while_MP.sail17
-rw-r--r--test/typecheck/pass/while_PM.sail23
-rw-r--r--test/typecheck/pass/while_PP.sail24
8 files changed, 195 insertions, 0 deletions
diff --git a/src/gen_lib/prompt.lem b/src/gen_lib/prompt.lem
index 5c539354..8e04bd30 100644
--- a/src/gen_lib/prompt.lem
+++ b/src/gen_lib/prompt.lem
@@ -170,6 +170,35 @@ let rec foreachM_dec (i,stop,by) vars body =
foreachM_dec (i - by,stop,by) vars body
else return vars
+val while_PP : forall 'vars. bool -> 'vars -> ('vars -> bool) -> ('vars -> 'vars) -> 'vars
+let rec while_PP is_while vars cond body =
+ if (cond vars = is_while)
+ then while_PP is_while (body vars) cond body
+ else vars
+
+val while_PM : forall 'vars 'r. bool -> 'vars -> ('vars -> bool) ->
+ ('vars -> MR 'vars 'r) -> MR 'vars 'r
+let rec while_PM is_while vars cond body =
+ if (cond vars = is_while)
+ then body vars >>= fun vars -> while_PM is_while vars cond body
+ else return vars
+
+val while_MP : forall 'vars 'r. bool -> 'vars -> ('vars -> MR bool 'r) ->
+ ('vars -> 'vars) -> MR 'vars 'r
+let rec while_MP is_while vars cond body =
+ cond vars >>= fun continue ->
+ if (continue = is_while)
+ then while_MP is_while (body vars) cond body
+ else return vars
+
+val while_MM : forall 'vars 'r. bool -> 'vars -> ('vars -> MR bool 'r) ->
+ ('vars -> MR 'vars 'r) -> MR 'vars 'r
+let rec while_MM is_while vars cond body =
+ cond vars >>= fun continue ->
+ if (continue = is_while)
+ then body vars >>= fun vars -> while_MM is_while vars cond body
+ else return vars
+
let write_two_regs r1 r2 vec =
let is_inc =
let is_inc_r1 = is_inc_of_reg r1 in
diff --git a/src/gen_lib/state.lem b/src/gen_lib/state.lem
index 914955e0..4e649144 100644
--- a/src/gen_lib/state.lem
+++ b/src/gen_lib/state.lem
@@ -203,6 +203,35 @@ let rec foreachM_dec (i,stop,by) vars body =
foreachM_dec (i - by,stop,by) vars body
else return vars
+val while_PP : forall 'vars. bool -> 'vars -> ('vars -> bool) -> ('vars -> 'vars) -> 'vars
+let rec while_PP is_while vars cond body =
+ if (cond vars = is_while)
+ then while_PP is_while (body vars) cond body
+ else vars
+
+val while_PM : forall 'regs 'vars 'e. bool -> 'vars -> ('vars -> bool) ->
+ ('vars -> ME 'regs 'vars 'e) -> ME 'regs 'vars 'e
+let rec while_PM is_while vars cond body =
+ if (cond vars = is_while)
+ then body vars >>= fun vars -> while_PM is_while vars cond body
+ else return vars
+
+val while_MP : forall 'regs 'vars 'e. bool -> 'vars -> ('vars -> ME 'regs bool 'e) ->
+ ('vars -> 'vars) -> ME 'regs 'vars 'e
+let rec while_MP is_while vars cond body =
+ cond vars >>= fun continue ->
+ if (continue = is_while)
+ then while_MP is_while (body vars) cond body
+ else return vars
+
+val while_MM : forall 'regs 'vars 'e. bool -> 'vars -> ('vars -> ME 'regs bool 'e) ->
+ ('vars -> ME 'regs 'vars 'e) -> ME 'regs 'vars 'e
+let rec while_MM is_while vars cond body =
+ cond vars >>= fun continue ->
+ if (continue = is_while)
+ then body vars >>= fun vars -> while_MM is_while vars cond body
+ else return vars
+
(*let write_two_regs r1 r2 bvec state =
let vec = bvec_to_vec bvec in
let is_inc =
diff --git a/src/pretty_print_lem.ml b/src/pretty_print_lem.ml
index 1bfb19aa..f981297d 100644
--- a/src/pretty_print_lem.ml
+++ b/src/pretty_print_lem.ml
@@ -514,6 +514,26 @@ let doc_exp_lem, doc_let_lem =
(prefix 1 1 (separate space [string "fun";expY id;varspp;arrow]) (expN body))
)
)
+ | Id_aux ((Id (("while_PP" | "while_PM" |
+ "while_MP" | "while_MM" ) as loopf),_)) ->
+ let [is_while;cond;body;e5] = args in
+ let varspp = match e5 with
+ | E_aux (E_tuple vars,_) ->
+ let vars = List.map (fun (E_aux (E_id (Id_aux (Id name,_)),_)) -> string name) vars in
+ begin match vars with
+ | [v] -> v
+ | _ -> parens (separate comma vars) end
+ | E_aux (E_id (Id_aux (Id name,_)),_) ->
+ string name
+ | E_aux (E_lit (L_aux (L_unit,_)),_) ->
+ string "_" in
+ parens (
+ (prefix 2 1)
+ ((separate space) [string loopf;expY is_while;expY e5])
+ ((prefix 0 1)
+ (parens (prefix 1 1 (separate space [string "fun";varspp;arrow]) (expN cond)))
+ (parens (prefix 1 1 (separate space [string "fun";varspp;arrow]) (expN body))))
+ )
(* | Id_aux (Id "append",_) ->
let [e1;e2] = args in
let epp = align (expY e1 ^^ space ^^ string "++" ^//^ expY e2) in
diff --git a/src/rewriter.ml b/src/rewriter.ml
index e257e19c..5cf1a6b9 100644
--- a/src/rewriter.ml
+++ b/src/rewriter.ml
@@ -130,6 +130,7 @@ let fix_eff_exp (E_aux (e,((l,_) as annot))) = match snd annot with
union_effects (fun_app_effects f env) (union_eff_exps [e1;e2])
| E_if (e1,e2,e3) -> union_eff_exps [e1;e2;e3]
| E_for (_,e1,e2,e3,_,e4) -> union_eff_exps [e1;e2;e3;e4]
+ | E_loop (_,e1,e2) -> union_eff_exps [e1;e2]
| E_vector es -> union_eff_exps es
| E_vector_indexed (ies,opt_default) ->
let (_,es) = List.split ies in
@@ -388,6 +389,8 @@ let rewrite_exp rewriters (E_aux (exp,(l,annot))) =
| E_if (c,t,e) -> rewrap (E_if (rewrite c,rewrite t, rewrite e))
| E_for (id, e1, e2, e3, o, body) ->
rewrap (E_for (id, rewrite e1, rewrite e2, rewrite e3, o, rewrite body))
+ | E_loop (loop, e1, e2) ->
+ rewrap (E_loop (loop, rewrite e1, rewrite e2))
| E_vector exps -> rewrap (E_vector (List.map rewrite exps))
| E_vector_indexed (exps,(Def_val_aux(default,dannot))) ->
let def = match default with
@@ -2847,6 +2850,10 @@ let rewrite_defs_letbind_effects =
n_exp_name by (fun by ->
let body = n_exp_term (effectful body) body in
k (rewrap (E_for (id,start,stop,by,dir,body))))))
+ | E_loop (loop, cond, body) ->
+ let cond = n_exp_term (effectful cond) cond in
+ let body = n_exp_term (effectful body) body in
+ k (rewrap (E_loop (loop,cond,body)))
| E_vector exps ->
n_exp_nameL exps (fun exps ->
k (rewrap (E_vector exps)))
@@ -3171,6 +3178,29 @@ let rec rewrite_var_updates ((E_aux (expaux,((l,_) as annot))) as exp) =
else P_aux (P_tup [pat; mktup_pat pl vars],
simple_annot pl (typ_of v)) in
Added_vars (v,pat)
+ | E_loop(loop,cond,body) ->
+ let vars = List.map (fun (var,(l,t)) -> E_aux (E_id var,(l,t))) (find_updated_vars body) in
+ let vartuple = mktup el vars in
+ (* let cond = rewrite_var_updates (add_vars false cond vartuple) in *)
+ let body = rewrite_var_updates (add_vars overwrite body vartuple) in
+ let (E_aux (_,(_,bannot))) = body in
+ let fname = match effectful cond, effectful body with
+ | false, false -> "while_PP"
+ | false, true -> "while_PM"
+ | true, false -> "while_MP"
+ | true, true -> "while_MM" in
+ let funcl = Id_aux (Id fname,Parse_ast.Generated el) in
+ let is_while =
+ match loop with
+ | While -> E_aux (E_lit (mk_lit L_true), simple_annot el bool_typ)
+ | Until -> E_aux (E_lit (mk_lit L_false), simple_annot el bool_typ) in
+ let v = E_aux (E_app (funcl,[is_while;cond;body;vartuple]),
+ (Parse_ast.Generated el, bannot)) in
+ let pat =
+ if overwrite then mktup_pat el vars
+ else P_aux (P_tup [pat; mktup_pat pl vars],
+ simple_annot pl (typ_of v)) in
+ Added_vars (v,pat)
| E_if (c,e1,e2) ->
let vars = List.map (fun (var,(l,t)) -> E_aux (E_id var,(l,t)))
(dedup eqidtyp (find_updated_vars e1 @ find_updated_vars e2)) in
diff --git a/test/typecheck/pass/while_MM.sail b/test/typecheck/pass/while_MM.sail
new file mode 100644
index 00000000..e6916edd
--- /dev/null
+++ b/test/typecheck/pass/while_MM.sail
@@ -0,0 +1,23 @@
+default Order dec
+
+val extern forall Num 'n, Num 'm, Num 'o, Num 'p.
+ ([|'n:'m|], [|'o:'p|]) -> [|'n + 'o:'m + 'p|] effect pure add_range = "add"
+val extern (int, int) -> int effect pure add_int = "add"
+val forall Num 'n, Num 'o, Order 'ord.
+ (vector<'o, 'n, 'ord, bit>, int) -> vector<'o, 'n, 'ord, bit> effect pure add_vec_int
+overload (deinfix +) [add_vec_int; add_range; add_int]
+
+val extern bool -> bool effect pure bool_not = "not"
+
+val cast forall Num 'n, Num 'l. [:0:] -> vector<'n,'l,dec,bit> effect pure cast_0_vec_dec
+
+register (bit[64]) COUNT
+register (bool) INT
+
+function (unit) test () = {
+ COUNT := 0;
+ while (bool_not(INT)) do {
+ COUNT := COUNT + 1;
+ }
+}
+
diff --git a/test/typecheck/pass/while_MP.sail b/test/typecheck/pass/while_MP.sail
new file mode 100644
index 00000000..05d396e2
--- /dev/null
+++ b/test/typecheck/pass/while_MP.sail
@@ -0,0 +1,17 @@
+default Order dec
+
+val extern (int, int) -> int effect pure add_int = "add"
+overload (deinfix +) [add_vec_int; add_range; add_int]
+
+val extern bool -> bool effect pure bool_not = "not"
+
+register (bool) INT
+
+function (int) test () = {
+ (int) count := 0;
+ while (bool_not(INT)) do {
+ count := count + 1;
+ };
+ return count;
+}
+
diff --git a/test/typecheck/pass/while_PM.sail b/test/typecheck/pass/while_PM.sail
new file mode 100644
index 00000000..b03a87dc
--- /dev/null
+++ b/test/typecheck/pass/while_PM.sail
@@ -0,0 +1,23 @@
+default Order dec
+
+val extern forall Num 'n, Num 'm, Num 'o. ([|'n:'m|], [:'o:]) -> bool effect pure lt_range_atom = "lt"
+val extern (int, int) -> bool effect pure lt_int = "lt"
+overload (deinfix <) [lt_range_atom; lt_int]
+
+val extern forall Num 'n, Num 'm, Num 'o, Num 'p.
+ ([|'n:'m|], [|'o:'p|]) -> [|'n + 'o:'m + 'p|] effect pure add_range = "add"
+val extern (int, int) -> int effect pure add_int = "add"
+overload (deinfix +) [add_range; add_int]
+
+val extern forall Num 'n, Num 'l, 'l >= 0. (vector<'n,'l,dec,bit>, int) -> bit effect pure vector_access = "bitvector_access_dec"
+
+register (bit[64]) GPR00
+
+function (unit) test ((bit) b) = {
+ (int) i := 0;
+ while (i < 64) do {
+ GPR00[i] := b;
+ i := i + 1;
+ }
+}
+
diff --git a/test/typecheck/pass/while_PP.sail b/test/typecheck/pass/while_PP.sail
new file mode 100644
index 00000000..454cc9ac
--- /dev/null
+++ b/test/typecheck/pass/while_PP.sail
@@ -0,0 +1,24 @@
+default Order dec
+
+val extern forall Num 'n, Num 'm, Num 'o. ([|'n:'m|], [:'o:]) -> bool effect pure lt_range_atom = "lt"
+val extern (int, int) -> bool effect pure lt_int = "lt"
+overload (deinfix <) [lt_range_atom; lt_int]
+
+val (int, int) -> int effect pure mult_int
+overload (deinfix * ) [mult_int]
+
+val extern forall Num 'n, Num 'm, Num 'o, Num 'p.
+ ([|'n:'m|], [|'o:'p|]) -> [|'n + 'o:'m + 'p|] effect pure add_range = "add"
+val extern (int, int) -> int effect pure add_int = "add"
+overload (deinfix +) [add_range; add_int]
+
+function (int) test ((int) n) = {
+ (int) i := 1;
+ (int) j := 1;
+ while (i < n) do {
+ j := i * j;
+ i := i + 1;
+ };
+ j
+}
+