summaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
Diffstat (limited to 'src')
-rw-r--r--src/jib/anf.ml64
-rw-r--r--src/jib/anf.mli2
-rw-r--r--src/jib/jib_compile.ml7
-rw-r--r--src/type_check.ml2
4 files changed, 46 insertions, 29 deletions
diff --git a/src/jib/anf.ml b/src/jib/anf.ml
index 7b91e4a5..29d2c016 100644
--- a/src/jib/anf.ml
+++ b/src/jib/anf.ml
@@ -208,7 +208,7 @@ let rec alexp_rename from_id to_id = function
| AL_addr (id, typ) when Id.compare from_id id = 0 -> AL_addr (to_id, typ)
| AL_addr (id, typ) -> AL_id (id, typ)
| AL_field (alexp, field_id) -> AL_field (alexp_rename from_id to_id alexp, field_id)
-
+
let rec aexp_rename from_id to_id (AE_aux (aexp, env, l)) =
let recur = aexp_rename from_id to_id in
let aexp = match aexp with
@@ -239,6 +239,38 @@ and apexp_rename from_id to_id (apat, aexp1, aexp2) =
else
(apat, aexp_rename from_id to_id aexp1, aexp_rename from_id to_id aexp2)
+let rec fold_aexp f (AE_aux (aexp, env, l)) =
+ let aexp = match aexp with
+ | AE_app (id, vs, typ) -> AE_app (id, vs, typ)
+ | AE_cast (aexp, typ) -> AE_cast (fold_aexp f aexp, typ)
+ | AE_assign (alexp, aexp) -> AE_assign (alexp, fold_aexp f aexp)
+ | AE_short_circuit (op, aval, aexp) -> AE_short_circuit (op, aval, fold_aexp f aexp)
+ | AE_let (mut, id, typ1, aexp1, aexp2, typ2) -> AE_let (mut, id, typ1, fold_aexp f aexp1, fold_aexp f aexp2, typ2)
+ | AE_block (aexps, aexp, typ) -> AE_block (List.map (fold_aexp f) aexps, fold_aexp f aexp, typ)
+ | AE_if (aval, aexp1, aexp2, typ) ->
+ AE_if (aval, fold_aexp f aexp1, fold_aexp f aexp2, typ)
+ | AE_loop (loop_typ, aexp1, aexp2) -> AE_loop (loop_typ, fold_aexp f aexp1, fold_aexp f aexp2)
+ | AE_for (id, aexp1, aexp2, aexp3, order, aexp4) ->
+ AE_for (id, fold_aexp f aexp1, fold_aexp f aexp2, fold_aexp f aexp3, order, fold_aexp f aexp4)
+ | AE_case (aval, cases, typ) ->
+ AE_case (aval, List.map (fun (pat, aexp1, aexp2) -> pat, fold_aexp f aexp1, fold_aexp f aexp2) cases, typ)
+ | AE_try (aexp, cases, typ) ->
+ AE_try (fold_aexp f aexp, List.map (fun (pat, aexp1, aexp2) -> pat, fold_aexp f aexp1, fold_aexp f aexp2) cases, typ)
+ | AE_field _ | AE_record_update _ | AE_val _ | AE_return _ | AE_throw _ as v -> v
+ in
+ f (AE_aux (aexp, env, l))
+
+let aexp_bindings aexp =
+ let ids = ref IdSet.empty in
+ let collect_lets = function
+ | AE_aux (AE_let (_, id, _, _, _, _), _, _) as aexp ->
+ ids := IdSet.add id !ids;
+ aexp
+ | aexp -> aexp
+ in
+ ignore (fold_aexp collect_lets aexp);
+ !ids
+
let shadow_counter = ref 0
let new_shadow id =
@@ -288,7 +320,8 @@ and no_shadow_apexp ids (apat, aexp1, aexp2) =
let rename aexp = List.fold_left (fun aexp (from_id, to_id) -> aexp_rename from_id to_id aexp) aexp shadows in
let rename_apat apat = List.fold_left (fun apat (from_id, to_id) -> apat_rename from_id to_id apat) apat shadows in
let ids = IdSet.union (apat_bindings apat) (IdSet.union ids (IdSet.of_list (List.map snd shadows))) in
- (rename_apat apat, no_shadow ids (rename aexp1), no_shadow ids (rename aexp2))
+ let new_guard = no_shadow ids (rename aexp1) in
+ (rename_apat apat, new_guard, no_shadow (IdSet.union ids (aexp_bindings new_guard)) (rename aexp2))
(* Map over all the avals in an aexp. *)
let rec map_aval f (AE_aux (aexp, env, l)) =
@@ -341,27 +374,6 @@ let rec map_functions f (AE_aux (aexp, env, l)) =
in
AE_aux (aexp, env, l)
-let rec fold_aexp f (AE_aux (aexp, env, l)) =
- let aexp = match aexp with
- | AE_app (id, vs, typ) -> AE_app (id, vs, typ)
- | AE_cast (aexp, typ) -> AE_cast (fold_aexp f aexp, typ)
- | AE_assign (alexp, aexp) -> AE_assign (alexp, fold_aexp f aexp)
- | AE_short_circuit (op, aval, aexp) -> AE_short_circuit (op, aval, fold_aexp f aexp)
- | AE_let (mut, id, typ1, aexp1, aexp2, typ2) -> AE_let (mut, id, typ1, fold_aexp f aexp1, fold_aexp f aexp2, typ2)
- | AE_block (aexps, aexp, typ) -> AE_block (List.map (fold_aexp f) aexps, fold_aexp f aexp, typ)
- | AE_if (aval, aexp1, aexp2, typ) ->
- AE_if (aval, fold_aexp f aexp1, fold_aexp f aexp2, typ)
- | AE_loop (loop_typ, aexp1, aexp2) -> AE_loop (loop_typ, fold_aexp f aexp1, fold_aexp f aexp2)
- | AE_for (id, aexp1, aexp2, aexp3, order, aexp4) ->
- AE_for (id, fold_aexp f aexp1, fold_aexp f aexp2, fold_aexp f aexp3, order, fold_aexp f aexp4)
- | AE_case (aval, cases, typ) ->
- AE_case (aval, List.map (fun (pat, aexp1, aexp2) -> pat, fold_aexp f aexp1, fold_aexp f aexp2) cases, typ)
- | AE_try (aexp, cases, typ) ->
- AE_try (fold_aexp f aexp, List.map (fun (pat, aexp1, aexp2) -> pat, fold_aexp f aexp1, fold_aexp f aexp2) cases, typ)
- | AE_field _ | AE_record_update _ | AE_val _ | AE_return _ | AE_throw _ as v -> v
- in
- f (AE_aux (aexp, env, l))
-
(* For debugging we provide a pretty printer for ANF expressions. *)
let pp_lvar lvar doc =
@@ -393,8 +405,8 @@ let rec pp_alexp = function
| AL_addr (id, typ) ->
string "*" ^^ parens (pp_annot typ (pp_id id))
| AL_field (alexp, field) ->
- pp_alexp alexp ^^ dot ^^ pp_id field
-
+ pp_alexp alexp ^^ dot ^^ pp_id field
+
let rec pp_aexp (AE_aux (aexp, _, _)) =
match aexp with
| AE_val v -> pp_aval v
@@ -550,7 +562,7 @@ let rec anf (E_aux (e_aux, ((l, _) as exp_annot)) as exp) =
Reporting.unreachable l __POS__
("Encountered complex l-expression " ^ string_of_lexp lexp ^ " when converting to ANF")
in
-
+
let to_aval (AE_aux (aexp_aux, env, _) as aexp) =
let mk_aexp (AE_aux (_, _, l)) aexp = AE_aux (aexp, env, l) in
match aexp_aux with
diff --git a/src/jib/anf.mli b/src/jib/anf.mli
index a9ee10a2..4007911b 100644
--- a/src/jib/anf.mli
+++ b/src/jib/anf.mli
@@ -152,6 +152,8 @@ val map_functions : (Env.t -> Ast.l -> id -> ('a aval) list -> 'a -> 'a aexp_aux
val fold_aexp : ('a aexp -> 'a aexp) -> 'a aexp -> 'a aexp
+val aexp_bindings : 'a aexp -> IdSet.t
+
(** Remove all variable shadowing in an ANF expression *)
val no_shadow : IdSet.t -> 'a aexp -> 'a aexp
diff --git a/src/jib/jib_compile.ml b/src/jib/jib_compile.ml
index 5bf53009..42228cd6 100644
--- a/src/jib/jib_compile.ml
+++ b/src/jib/jib_compile.ml
@@ -1297,9 +1297,12 @@ let compile_funcl ctx id pat guard exp =
List.fold_left2 (fun ctx (id, _) ctyp -> { ctx with locals = Bindings.add id (Immutable, ctyp) ctx.locals }) ctx compiled_args arg_ctyps
in
+ let guard_bindings = ref IdSet.empty in
let guard_instrs = match guard with
| Some guard ->
- let guard_aexp = C.optimize_anf ctx (no_shadow (pat_ids pat) (anf guard)) in
+ let guard = anf guard in
+ guard_bindings := aexp_bindings guard;
+ let guard_aexp = C.optimize_anf ctx (no_shadow (pat_ids pat) guard) in
let guard_setup, guard_call, guard_cleanup = compile_aexp ctx guard_aexp in
let guard_label = label "guard_" in
let gs = ngensym () in
@@ -1316,7 +1319,7 @@ let compile_funcl ctx id pat guard exp =
in
(* Optimize and compile the expression to ANF. *)
- let aexp = C.optimize_anf ctx (no_shadow (pat_ids pat) (anf exp)) in
+ let aexp = C.optimize_anf ctx (no_shadow (IdSet.union (pat_ids pat) !guard_bindings) (anf exp)) in
let setup, call, cleanup = compile_aexp ctx aexp in
let destructure, destructure_cleanup =
diff --git a/src/type_check.ml b/src/type_check.ml
index 1d6566ef..3aae9b09 100644
--- a/src/type_check.ml
+++ b/src/type_check.ml
@@ -1156,7 +1156,7 @@ end = struct
let add_toplevel_lets ids env =
{ env with top_letbinds = IdSet.union ids env.top_letbinds }
-
+
let get_toplevel_lets env = env.top_letbinds
let add_variant id variant env =