summaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
Diffstat (limited to 'src')
-rw-r--r--src/ast_util.ml2
-rw-r--r--src/rewrites.ml67
2 files changed, 23 insertions, 46 deletions
diff --git a/src/ast_util.ml b/src/ast_util.ml
index 29d48543..b6f81de6 100644
--- a/src/ast_util.ml
+++ b/src/ast_util.ml
@@ -742,7 +742,7 @@ and string_of_pat (P_aux (pat, l)) =
| P_list pats -> "[||" ^ string_of_list "," string_of_pat pats ^ "||]"
| P_vector_concat pats -> string_of_list " : " string_of_pat pats
| P_vector pats -> "[" ^ string_of_list ", " string_of_pat pats ^ "]"
- | P_as (pat, id) -> string_of_pat pat ^ " as " ^ string_of_id id
+ | P_as (pat, id) -> "(" ^ string_of_pat pat ^ " as " ^ string_of_id id ^ ")"
| P_string_append pats -> string_of_list " ^^ " string_of_pat pats
| _ -> "PAT"
diff --git a/src/rewrites.ml b/src/rewrites.ml
index 6ae987c0..b7104da3 100644
--- a/src/rewrites.ml
+++ b/src/rewrites.ml
@@ -669,7 +669,7 @@ let remove_vector_concat_pat pat =
P_aux (p,annot)
)
} in
-
+
(* let pat = remove_typed_patterns pat in *)
let fresh_id_v = fresh_id "v__" in
@@ -719,13 +719,14 @@ let remove_vector_concat_pat pat =
let name_vector_concat_elements =
let p_vector_concat pats =
let rec aux ((P_aux (p,((l,_) as a))) as pat) = match p with
- | P_vector _ -> P_aux (P_as (pat,fresh_id_v l),a)
+ | P_vector _ -> P_aux (P_as (pat, fresh_id_v l),a)
| P_lit _ -> P_aux (P_as (pat, fresh_id_v l), a)
| P_id id -> P_aux (P_id id,a)
| P_as (p,id) -> P_aux (P_as (p,id),a)
| P_typ (typ, pat) -> P_aux (P_typ (typ, aux pat),a)
| P_wild -> P_aux (P_wild,a)
- | P_app (id, pats) when Env.is_mapping id (env_of_annot a) -> P_aux (P_app (id, List.map aux pats), a)
+ | P_app (id, pats) when Env.is_mapping id (env_of_annot a) ->
+ P_aux (P_app (id, List.map aux pats), a)
| _ ->
raise
(Reporting_basic.err_unreachable
@@ -735,38 +736,25 @@ let remove_vector_concat_pat pat =
let pat = fold_pat name_vector_concat_elements pat in
-
-
let rec tag_last = function
| x :: xs -> let is_last = xs = [] in (x,is_last) :: tag_last xs
| _ -> [] in
(* remove names from vectors in vector_concat patterns and collect them as declarations for the
function body or expression *)
- let unname_vector_concat_elements = (* :
- ('a,
- 'a pat * ((tannot exp -> tannot exp) list),
- 'a pat_aux * ((tannot exp -> tannot exp) list),
- 'a fpat * ((tannot exp -> tannot exp) list),
- 'a fpat_aux * ((tannot exp -> tannot exp) list))
- pat_alg = *)
-
+ let unname_vector_concat_elements =
(* build a let-expression of the form "let child = root[i..j] in body" *)
let letbind_vec typ_opt (rootid,rannot) (child,cannot) (i,j) =
let (l,_) = cannot in
let env = env_of_annot rannot in
let rootname = string_of_id rootid in
let childname = string_of_id child in
-
+
let root = E_aux (E_id rootid, rannot) in
let index_i = simple_num l i in
let index_j = simple_num l j in
- (* FIXME *)
let subv = fix_eff_exp (E_aux (E_vector_subrange (root, index_i, index_j), cannot)) in
- (* let (_, ord, _) = vector_typ_args_of (Env.base_typ_of (env_of root) (typ_of root)) in
- let subrange_id = if is_order_inc ord then "bitvector_subrange_inc" else "bitvector_subrange_dec" in
- let subv = fix_eff_exp (E_aux (E_app (mk_id subrange_id, [root; index_i; index_j]), cannot)) in *)
let id_pat =
match typ_opt with
@@ -859,38 +847,28 @@ let remove_vector_concat_pat pat =
let (pat,decls) = fold_pat unname_vector_concat_elements pat in
- let decls =
- let module S = Set.Make(String) in
-
- let roots_needed =
- List.fold_right
- (fun (_,(rootid,childid)) roots_needed ->
- if S.mem childid roots_needed then
- (* let _ = print_endline rootid in *)
- S.add rootid roots_needed
- else if String.length childid >= 3 && String.sub childid 0 2 = String.sub "v__" 0 2 then
- roots_needed
- else
- S.add rootid roots_needed
- ) decls S.empty in
- List.filter
- (fun (_,(_,childid)) ->
- S.mem childid roots_needed ||
- String.length childid < 3 ||
- not (String.sub childid 0 2 = String.sub "v__" 0 2))
- decls in
+ (* We need to put the decls in the right order so letbinds are generated correctly for nested patterns *)
+ let module G = Graph.Make(String) in
+ let root_graph = List.fold_left (fun g (_, (root_id, child_id)) -> G.add_edge root_id child_id g) G.empty decls in
+ let root_order = G.topsort root_graph in
+ let find_root root_id =
+ try List.find (fun (_, (root_id', _)) -> root_id = root_id') decls with
+ | Not_found ->
+ (* If it's not a root the it's a leaf node in the graph, so search for child_id *)
+ try List.find (fun (_, (_, child_id)) -> root_id = child_id) decls with
+ | Not_found -> assert false (* Should never happen *)
+ in
+ let decls = List.map find_root root_order in
let (letbinds,decls) =
- let (decls,_) = List.split decls in
+ let decls = List.map fst decls in
List.split decls in
let decls = List.fold_left (fun f g x -> f (g x)) (fun b -> b) decls in
-
(* at this point shouldn't have P_as patterns in P_vector_concat patterns any more,
all P_as and P_id vectors should have their declarations in decls.
Now flatten all vector_concat patterns *)
-
let flatten =
let p_vector_concat ps =
let aux p acc = match p with
@@ -898,12 +876,11 @@ let remove_vector_concat_pat pat =
| pat -> pat :: acc in
P_vector_concat (List.fold_right aux ps []) in
{id_pat_alg with p_vector_concat = p_vector_concat} in
-
+
let pat = fold_pat flatten pat in
(* at this point pat should be a flat pattern: no vector_concat patterns
with vector_concats patterns as direct child-nodes anymore *)
-
let range a b =
let rec aux a b = if Big_int.greater a b then [] else a :: aux (Big_int.add a (Big_int.of_int 1)) b in
if Big_int.greater a b then List.rev (aux b a) else aux a b in
@@ -948,9 +925,9 @@ let remove_vector_concat_pat pat =
P_vector_concat ps' in
{id_pat_alg with p_vector_concat = p_vector_concat} in
-
+
let pat = fold_pat remove_vector_concats pat in
-
+
(pat,letbinds,decls)
(* assumes there are no more E_internal expressions *)