summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--src/rewriter.ml191
1 files changed, 122 insertions, 69 deletions
diff --git a/src/rewriter.ml b/src/rewriter.ml
index de865be2..5207d880 100644
--- a/src/rewriter.ml
+++ b/src/rewriter.ml
@@ -312,41 +312,41 @@ let rewrite_defs (Defs defs) = rewrite_defs_base
rewrite_defs = rewrite_defs_base} (Defs defs)
type ('pat,'pat_aux,'fpat,'fpat_aux,'annot) pat_alg =
- { p_lit : lit -> 'pat_aux
- ; p_wild : 'pat_aux
- ; p_as : 'pat * id -> 'pat_aux
- ; p_typ : Ast.typ * 'pat -> 'pat_aux
- ; p_id : id -> 'pat_aux
- ; p_app : id * 'pat list -> 'pat_aux
- ; p_record : 'fpat list * bool -> 'pat_aux
- ; p_vector : 'pat list -> 'pat_aux
+ { p_lit : lit -> 'pat_aux
+ ; p_wild : 'pat_aux
+ ; p_as : 'pat * id -> 'pat_aux
+ ; p_typ : Ast.typ * 'pat -> 'pat_aux
+ ; p_id : id -> 'pat_aux
+ ; p_app : id * 'pat list -> 'pat_aux
+ ; p_record : 'fpat list * bool -> 'pat_aux
+ ; p_vector : 'pat list -> 'pat_aux
; p_vector_indexed : (int * 'pat) list -> 'pat_aux
- ; p_vector_concat : 'pat list -> 'pat_aux
- ; p_tup : 'pat list -> 'pat_aux
- ; p_list : 'pat list -> 'pat_aux
- ; p_aux : 'pat_aux * 'annot -> 'pat
- ; fP_aux : 'fpat_aux * 'annot -> 'fpat
- ; fP_Fpat : id * 'pat -> 'fpat_aux
+ ; p_vector_concat : 'pat list -> 'pat_aux
+ ; p_tup : 'pat list -> 'pat_aux
+ ; p_list : 'pat list -> 'pat_aux
+ ; p_aux : 'pat_aux * 'annot -> 'pat
+ ; fP_aux : 'fpat_aux * 'annot -> 'fpat
+ ; fP_Fpat : id * 'pat -> 'fpat_aux
}
(* fold from term alg into alg *)
let rec fold_pat_aux alg = function
- | P_lit lit -> alg.p_lit lit
- | P_wild -> alg.p_wild
- | P_id id -> alg.p_id id
- | P_as (p,id) -> alg.p_as (fold_pat alg p,id)
- | P_typ (typ,p) -> alg.p_typ (typ,fold_pat alg p)
- | P_app (id,ps) -> alg.p_app (id,List.map (fold_pat alg) ps)
- | P_record (ps,b) -> alg.p_record (List.map (fold_fpat alg) ps, b)
- | P_vector ps -> alg.p_vector (List.map (fold_pat alg) ps)
+ | P_lit lit -> alg.p_lit lit
+ | P_wild -> alg.p_wild
+ | P_id id -> alg.p_id id
+ | P_as (p,id) -> alg.p_as (fold_pat alg p,id)
+ | P_typ (typ,p) -> alg.p_typ (typ,fold_pat alg p)
+ | P_app (id,ps) -> alg.p_app (id,List.map (fold_pat alg) ps)
+ | P_record (ps,b) -> alg.p_record (List.map (fold_fpat alg) ps, b)
+ | P_vector ps -> alg.p_vector (List.map (fold_pat alg) ps)
| P_vector_indexed ps -> alg.p_vector_indexed (List.map (fun (i,p) -> (i, fold_pat alg p)) ps)
- | P_vector_concat ps -> alg.p_vector_concat (List.map (fold_pat alg) ps)
- | P_tup ps -> alg.p_tup (List.map (fold_pat alg) ps)
- | P_list ps -> alg.p_list (List.map (fold_pat alg) ps)
+ | P_vector_concat ps -> alg.p_vector_concat (List.map (fold_pat alg) ps)
+ | P_tup ps -> alg.p_tup (List.map (fold_pat alg) ps)
+ | P_list ps -> alg.p_list (List.map (fold_pat alg) ps)
and fold_pat alg = function
- | P_aux (pat,annot) -> alg.p_aux (fold_pat_aux alg pat,annot)
+ | P_aux (pat,annot) -> alg.p_aux (fold_pat_aux alg pat,annot)
and fold_fpat_aux alg = function
- | FP_Fpat (id,pat) -> alg.fP_Fpat (id,fold_pat alg pat)
+ | FP_Fpat (id,pat) -> alg.fP_Fpat (id,fold_pat alg pat)
and fold_fpat alg = function
| FP_aux (fpat,annot) -> alg.fP_aux (fold_fpat_aux alg fpat,annot)
@@ -395,64 +395,88 @@ let remove_vector_concat_pat pat =
let p_vector_concat pats =
let aux ((P_aux (p,a)) as pat) = match p with
| P_vector _ -> P_aux (P_as (pat,fresh_name()),a)
- (* | P_vector_concat. cannot happen after fold function name_vector_concat_roots *)
+ (* | P_vector_concat. cannot happen after folding function name_vector_concat_roots *)
| _ -> pat in (* this can only be P_as and P_id *)
P_vector_concat (List.map aux pats) in
{id_f with p_vector_concat = p_vector_concat} in
let pat = fold_pat name_vector_concat_elements pat in
- let zip l1 l2 = List.fold_right2 (fun x y acc -> (x,y) :: acc) l1 l2 [] in
- let unzip l = List.fold_right (fun (a,b) (accA,accB) -> (a :: accA, b :: accB)) l ([],[]) in
-
(* remove names from vectors in vector_concat patterns and collect them as declarations for the
function body or expression *)
- let unname_vector_concat_elements : ('a pat * (string list), 'a pat_aux * (string list), 'a fpat * (string list),
- 'a fpat_aux * (string list), 'a annot) pat_alg =
- let p_aux ((pattern,decls),annot) = match pattern with
- | P_as (P_aux (P_vector_concat pats,_),name) ->
- let aux (pat_acc,decl_acc,pos) = function
- | (P_aux (P_as (P_aux (p,annot),name2),
- (l,Base(([],{t = Tapp ("vector",[_;TA_nexp {nexp = Nconst length};_;_])}),_,_,_,_)))) ->
- (pat_acc @ [P_aux (p,annot)],
- decl_acc @ ["define name2 as vector <name> [pos;pos + length -1]"],
- add_big_int pos length)
- | (P_aux (P_id name2,
- ((l,Base(([],{t = Tapp ("vector", [_;TA_nexp {nexp = Nconst length};_;_])}),_,_,_,_)) as annot))) ->
- (pat_acc @ [P_aux (P_id name2,annot)],
- decl_acc @ ["define name2 as vector <name> [pos;pos + length -1]"],
- add_big_int pos length)
- | (P_aux (_,(l,Base(([],{t = Tapp ("vector", [_;TA_nexp {nexp = Nconst length};_;_])}),_,_,_,_))) as p)
- -> (pat_acc @ [p],decl_acc,add_big_int pos length)
- | (P_aux (_,(l,_))) -> raise (Reporting_basic.err_unreachable l "Non-vector in vector-concat pattern") in
- let (pats',decls',_) = List.fold_left aux ([],[],zero_big_int) pats in
- (P_aux (P_vector_concat pats',annot),decls @ decls')
- | _ -> (P_aux (pattern,annot),decls) in
+ let unname_vector_concat_elements :
+ ('a pat * ((tannot exp -> tannot exp) list),
+ 'a pat_aux * ((tannot exp -> tannot exp) list),
+ 'a fpat * ((tannot exp -> tannot exp) list),
+ 'a fpat_aux * ((tannot exp -> tannot exp) list),
+ 'a annot)
+ pat_alg =
+
+ (* build a let-expression of the form "let child = root[i..j] in body" *)
+ let letbind_vec (rootid,rannot) (child,cannot) (i,j) body =
+ let index n =
+ let typ = simple_annot {t = Tapp ("atom",[TA_nexp (mk_c (big_int_of_int n))])} in
+ E_aux (E_lit (L_aux (L_num n,Unknown)), (Parse_ast.Unknown,typ)) in
+ let subv = E_aux (E_vector_subrange (E_aux (E_id rootid,rannot),index i,index j),cannot) in
+ let typ = (Parse_ast.Unknown,simple_annot {t = Tid "unit"}) in
+ E_aux (E_let (LB_aux (LB_val_implicit (P_aux (P_id child,cannot),subv),cannot),body),typ) in
+
+ let p_aux ((pattern,decls),rannot) = match pattern with
+ | P_as (P_aux (P_vector_concat pats,_),rootid) ->
+ let aux (pos,pat_acc,decl_acc) (P_aux (p,cannot)) = match cannot with
+ | (_,Base((_,{t = Tapp ("vector",[_;TA_nexp {nexp = Nconst length};_;_])}),_,_,_,_)) ->
+ let length = int_of_big_int length in
+ (match p with
+ (* if we see a named vector pattern, remove the name and remember to declare it later *)
+ | P_as (P_aux (p,cannot),cname) ->
+ (pos + length, pat_acc @ [P_aux (p,cannot)],
+ decl_acc @ [letbind_vec (rootid,rannot) (cname,cannot) (pos,pos + length - 1)])
+ (* if we see a P_id variable, remember to declare it later *)
+ | P_id cname ->
+ (pos + length, pat_acc @ [P_aux (P_id cname,cannot)],
+ decl_acc @ [letbind_vec (rootid,rannot) (cname,cannot) (pos,pos + length - 1)])
+ (* normal vector patterns are fine *)
+ | _ -> (pos + length, pat_acc @ [P_aux (p,cannot)],decl_acc) )
+ (* non-vector patterns aren't *)
+ | (l,_) -> raise (Reporting_basic.err_unreachable l "Non-vector in vector-concat pattern") in
+ let (_,pats',decls') = List.fold_left aux (0,[],[]) pats in
+ (P_aux (P_vector_concat pats',rannot),decls @ decls')
+ | _ -> (P_aux (pattern,rannot),decls) in
{ p_lit = (fun lit -> (P_lit lit,[]))
; p_wild = (P_wild,[])
; p_as = (fun ((pat,decls),id) -> (P_as (pat,id),decls))
; p_typ = (fun (typ,(pat,decls)) -> (P_typ (typ,pat),decls))
; p_id = (fun id -> (P_id id,[]))
- ; p_app = (fun (id,ps) -> let (ps,decls) = unzip ps in (P_app (id,ps),List.flatten decls))
- ; p_record = (fun (ps,b) -> let (ps,decls) = unzip ps in (P_record (ps,b),List.flatten decls))
- ; p_vector = (fun ps -> let (ps,decls) = unzip ps in (P_vector ps,List.flatten decls))
- ; p_vector_indexed = (fun ps -> let (is,ps) = unzip ps in let (ps,decls) = unzip ps in let ps = zip is ps in
+ ; p_app = (fun (id,ps) -> let (ps,decls) = List.split ps in
+ (P_app (id,ps),List.flatten decls))
+ ; p_record = (fun (ps,b) -> let (ps,decls) = List.split ps in
+ (P_record (ps,b),List.flatten decls))
+ ; p_vector = (fun ps -> let (ps,decls) = List.split ps in
+ (P_vector ps,List.flatten decls))
+ ; p_vector_indexed = (fun ps -> let (is,ps) = List.split ps in
+ let (ps,decls) = List.split ps in
+ let ps = List.combine is ps in
(P_vector_indexed ps,List.flatten decls))
- ; p_vector_concat = (fun ps -> let (ps,decls) = unzip ps in (P_vector_concat ps,List.flatten decls))
- ; p_tup = (fun ps -> let (ps,decls) = unzip ps in (P_tup ps,List.flatten decls))
- ; p_list = (fun ps -> let (ps,decls) = unzip ps in (P_list ps,List.flatten decls))
+ ; p_vector_concat = (fun ps -> let (ps,decls) = List.split ps in
+ (P_vector_concat ps,List.flatten decls))
+ ; p_tup = (fun ps -> let (ps,decls) = List.split ps in
+ (P_tup ps,List.flatten decls))
+ ; p_list = (fun ps -> let (ps,decls) = List.split ps in
+ (P_list ps,List.flatten decls))
; p_aux = (fun ((pat,decls),annot) -> p_aux ((pat,decls),annot))
; fP_aux = (fun ((fpat,decls),annot) -> (FP_aux (fpat,annot),decls))
; fP_Fpat = (fun (id,(pat,decls)) -> (FP_Fpat (id,pat),decls))
} in
- let (pat,decls) = fold_pat unname_vector_concat_elements pat in
-
- (* at this point shouldn't have P_as patterns in P_vector_concat patterns any more,
- all P_as and P_id vectors should have their declarations in decls.
- Now flatten all vector_concat patterns*)
+ let (pat,decls_list) = fold_pat unname_vector_concat_elements pat in
+ let decls = List.fold_right (fun f g x -> f (g x)) decls_list (fun b -> b) in
+
+ (* at this point shouldn't have P_as patterns in P_vector_concat patterns any more,
+ all P_as and P_id vectors should have their declarations in decls.
+ Now flatten all vector_concat patterns *)
+
let flatten =
let p_vector_concat ps =
let aux p acc = match p with
@@ -472,12 +496,13 @@ let remove_vector_concat_pat pat =
let remove_vector_concats =
let p_vector_concat ps =
- let aux acc = function
- | P_aux (P_vector ps,annot) -> acc @ ps
- | P_aux (P_id name2, (_,Base(([],{t = Tapp ("vector", [_;TA_nexp {nexp = Nconst length};_;_])}),_,_,_,_))) ->
+ let aux acc (P_aux (p,annot)) = match p,annot with
+ | P_vector ps,_ -> acc @ ps
+ | P_id _,
+ (_,Base((_,{t = Tapp ("vector", [_;TA_nexp {nexp = Nconst length};_;_])}),_,_,_,_)) ->
let wild _ = P_aux (P_wild,(Parse_ast.Unknown,simple_annot {t = Tid "bit"})) in
acc @ (List.map wild (range 0 ((int_of_big_int length) - 1)))
- | (P_aux (_,(l,_))) -> raise (Reporting_basic.err_unreachable l "Non-vector in vector-concat pattern") in
+ | _,(l,_) -> raise (Reporting_basic.err_unreachable l "Non-vector in vector-concat pattern") in
P_vector_concat (List.fold_left aux [] ps) in
{id_f with p_vector_concat = p_vector_concat} in
@@ -485,7 +510,35 @@ let remove_vector_concat_pat pat =
(pat,decls)
-
+(* assumes there are no more E_internal expressions *)
+let rewrite_exp_remove_vector_concat_pat rewriters nmap (E_aux (exp,(l,annot)) as full_exp) =
+ let rewrap e = E_aux (e,(l,annot)) in
+ let rewrite_rec = rewriters.rewrite_exp rewriters nmap in
+ let rewrite_base = rewrite_exp rewriters nmap in
+ match exp with
+ | E_case (e,ps) ->
+ let aux (Pat_aux (Pat_exp (pat,body),annot')) =
+ let (pat,decls) = remove_vector_concat_pat pat in
+ Pat_aux (Pat_exp (pat,decls (rewrite_rec body)),annot') in
+ rewrap (E_case (rewrite_rec e,List.map aux ps))
+ | E_let (LB_aux (LB_val_explicit (typ,pat,v),annot'),body) ->
+ let (pat,decls) = remove_vector_concat_pat pat in
+ rewrap (E_let (LB_aux (LB_val_explicit (typ,pat,rewrite_rec v),annot'),
+ decls (rewrite_rec body)))
+ | E_let (LB_aux (LB_val_implicit (pat,v),annot'),body) ->
+ let (pat,decls) = remove_vector_concat_pat pat in
+ rewrap (E_let (LB_aux (LB_val_implicit (pat,rewrite_rec v),annot'),
+ decls (rewrite_rec body)))
+ | exp -> rewrite_base full_exp
+
+let rewrite_defs_remove_vector_concat defs = rewrite_defs_base
+ {rewrite_exp = rewrite_exp_remove_vector_concat_pat;
+ rewrite_pat = rewrite_pat;
+ rewrite_let = rewrite_let;
+ rewrite_lexp = rewrite_lexp;
+ rewrite_fun = rewrite_fun;
+ rewrite_def = rewrite_def;
+ rewrite_defs = rewrite_defs_base} defs
(*Expects to be called after rewrite_defs; thus the following should not appear:
internal_exp of any form