summaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
Diffstat (limited to 'src')
-rw-r--r--src/pretty_print.ml2
-rw-r--r--src/rewriter.ml104
2 files changed, 78 insertions, 28 deletions
diff --git a/src/pretty_print.ml b/src/pretty_print.ml
index 0b776585..068f16a2 100644
--- a/src/pretty_print.ml
+++ b/src/pretty_print.ml
@@ -1614,7 +1614,7 @@ let doc_fundef_ocaml (FD_aux(FD_function(r, typa, efa, fcls),_)) =
match fcls with
| [] -> failwith "FD_function with empty function list"
| [FCL_aux (FCL_Funcl(id,pat,exp),_)] ->
- separate space [(string "let"); (doc_rec_ocaml r); (doc_id_ocaml id); (doc_pat_ocaml pat); equals; (doc_exp_ocaml exp)]
+ (separate space [(string "let"); (doc_rec_ocaml r); (doc_id_ocaml id); (doc_pat_ocaml pat); equals]) ^^ hardline ^^ (doc_exp_ocaml exp)
| _ ->
let id = get_id fcls in
let sep = hardline ^^ pipe ^^ space in
diff --git a/src/rewriter.ml b/src/rewriter.ml
index c598d6ef..ceae3462 100644
--- a/src/rewriter.ml
+++ b/src/rewriter.ml
@@ -498,7 +498,34 @@ and fold_letbind alg (LB_aux (letbind_aux,annot)) = alg.lB_aux (fold_letbind_aux
let remove_vector_concat_pat pat =
-
+ (* ivc: bool that indicates whether the exp is in a vector_concat pattern *)
+ let remove_tannot_in_vector_concats =
+ { p_lit = (fun lit ivc -> P_lit lit)
+ ; p_wild = (fun ivc -> P_wild)
+ ; p_as = (fun (pat,id) ivc -> P_as (pat ivc,id))
+ ; p_typ =
+ (fun (typ,pat) ivc ->
+ let P_aux (p,annot) = pat ivc in
+ if ivc then p else P_typ (typ,P_aux (p,annot))
+ )
+ ; p_id = (fun id ivc -> P_id id)
+ ; p_app = (fun (id,ps) ivc -> P_app (id, List.map (fun p -> p ivc) ps))
+ ; p_record =
+ (fun (fpats,b) ivc -> P_record (List.map (fun f -> f false) fpats, b))
+ ; p_vector = (fun ps ivc -> P_vector (List.map (fun p -> p ivc) ps))
+ ; p_vector_indexed =
+ (fun ps ivc -> P_vector_indexed (List.map (fun (i,p) -> (i,p ivc)) ps))
+ ; p_vector_concat =
+ (fun ps ivc -> P_vector_concat (List.map (fun p -> p true) ps))
+ ; p_tup = (fun ps ivc -> P_tup (List.map (fun p -> p ivc) ps))
+ ; p_list = (fun ps ivc -> P_list (List.map (fun p -> p ivc) ps))
+ ; p_aux = (fun (p,annot) ivc -> P_aux (p ivc,annot))
+ ; fP_aux = (fun (fpat,annot) ivc -> FP_aux (fpat false,annot))
+ ; fP_Fpat = (fun (id,p) ivc -> FP_Fpat (id,p false))
+ } in
+
+ let pat = (fold_pat remove_tannot_in_vector_concats pat) false in
+
let counter = ref 0 in
let fresh_name () =
let current = !counter in
@@ -511,12 +538,33 @@ let remove_vector_concat_pat pat =
(* introduce names for all patterns of form P_vector_concat *)
let name_vector_concat_roots =
- let p_aux (pat,annot) = match pat with
- | P_vector_concat pats -> P_aux (P_as (P_aux (pat,annot),fresh_name()),annot)
- | _ -> P_aux (pat,annot) in
- {id_pat_alg with p_aux = p_aux} in
+ { p_lit = (fun lit -> P_lit lit)
+ ; p_wild = P_wild
+ ; p_as = (fun (pat,id) -> P_as (pat true,id))
+ ; p_typ = (fun (typ,pat) -> P_typ (typ,pat false))
+ ; p_id = (fun id -> P_id id)
+ ; p_app = (fun (id,ps) -> P_app (id, List.map (fun p -> p false) ps))
+ ; p_record = (fun (fpats,b) -> P_record (fpats, b))
+ ; p_vector = (fun ps -> P_vector (List.map (fun p -> p false) ps))
+ ; p_vector_indexed = (fun ps -> P_vector_indexed (List.map (fun (i,p) -> (i,p false)) ps))
+ ; p_vector_concat = (fun ps -> P_vector_concat (List.map (fun p -> p false) ps))
+ ; p_tup = (fun ps -> P_tup (List.map (fun p -> p false) ps))
+ ; p_list = (fun ps -> P_list (List.map (fun p -> p false) ps))
+ ; p_aux =
+ (fun (pat,annot) contained_in_p_as ->
+ match pat with
+ | P_vector_concat pats ->
+ (if contained_in_p_as
+ then P_aux (pat,annot)
+ else P_aux (P_as (P_aux (pat,annot),fresh_name()),annot))
+ | _ -> P_aux (pat,annot)
+ )
+ ; fP_aux = (fun (fpat,annot) -> FP_aux (fpat,annot))
+ ; fP_Fpat = (fun (id,p) -> FP_Fpat (id,p false))
+ } in
+
- let pat = fold_pat name_vector_concat_roots pat in
+ let pat = (fold_pat name_vector_concat_roots pat) false in
(* introduce names for all unnamed child nodes of P_vector_concat *)
let name_vector_concat_elements =
@@ -550,29 +598,31 @@ let remove_vector_concat_pat pat =
let letbind = LB_val_implicit (P_aux (P_id child,cannot),subv) in
(LB_aux (letbind,typ), (fun body -> E_aux (E_let (LB_aux (letbind,cannot),body),typ))) in
- let p_aux ((pattern,decls),rannot) = match pattern with
- | P_as (P_aux (P_vector_concat pats,_),rootid) ->
- let aux (pos,pat_acc,decl_acc) (P_aux (p,cannot)) = match cannot with
- | (_,Base((_,{t = Tapp ("vector",[_;TA_nexp {nexp = Nconst length};_;_])}),_,_,_,_)) ->
- let length = int_of_big_int length in
- (match p with
- (* if we see a named vector pattern, remove the name and remember to
+ let p_aux = function
+ | ((P_as (P_aux (P_vector_concat pats,rannot'),rootid),decls),rannot) ->
+ let aux (pos,pat_acc,decl_acc) (P_aux (p,cannot)) = match cannot with
+ | (_,Base((_,{t = Tapp ("vector",[_;TA_nexp {nexp = Nconst length};_;_])}),_,_,_,_)) ->
+ let length = int_of_big_int length in
+ (match p with
+ (* if we see a named vector pattern, remove the name and remember to
declare it later *)
- | P_as (P_aux (p,cannot),cname) ->
+ | P_as (P_aux (p,cannot),cname) ->
let (lb,decl) = letbind_vec (rootid,rannot) (cname,cannot) (pos,pos + length - 1) in
(pos + length, pat_acc @ [P_aux (p,cannot)], decl_acc @ [(lb,decl)])
- (* if we see a P_id variable, remember to declare it later *)
- | P_id cname ->
- let (lb,decl) = letbind_vec (rootid,rannot) (cname,cannot) (pos,pos + length - 1) in
- (pos + length, pat_acc @ [P_aux (P_id cname,cannot)], decl_acc @ [(lb,decl)])
- (* normal vector patterns are fine *)
- | _ -> (pos + length, pat_acc @ [P_aux (p,cannot)],decl_acc) )
- (* non-vector patterns aren't *)
- | (l,_) -> raise (Reporting_basic.err_unreachable
- l "Non-vector in vector-concat pattern") in
- let (_,pats',decls') = List.fold_left aux (0,[],[]) pats in
- (P_aux (P_vector_concat pats',rannot),decls @ decls')
- | _ -> (P_aux (pattern,rannot),decls) in
+ (* if we see a P_id variable, remember to declare it later *)
+ | P_id cname ->
+ let (lb,decl) = letbind_vec (rootid,rannot) (cname,cannot) (pos,pos + length - 1) in
+ (pos + length, pat_acc @ [P_aux (P_id cname,cannot)], decl_acc @ [(lb,decl)])
+ (* normal vector patterns are fine *)
+ | _ -> (pos + length, pat_acc @ [P_aux (p,cannot)],decl_acc) )
+ (* non-vector patterns aren't *)
+ | (l,_) ->
+ raise
+ (Reporting_basic.err_unreachable
+ l "unname_vector_concat_elements: Non-vector in vector-concat pattern") in
+ let (_,pats',decls') = List.fold_left aux (0,[],[]) pats in
+ (P_aux (P_as (P_aux (P_vector_concat pats',rannot'),rootid),rannot), decls @ decls')
+ | ((p,decls),annot) -> (P_aux (p,annot),decls) in
{ p_lit = (fun lit -> (P_lit lit,[]))
; p_wild = (P_wild,[])
@@ -635,7 +685,7 @@ let remove_vector_concat_pat pat =
(_,Base((_,{t = Tapp ("vector", [_;TA_nexp {nexp = Nconst length};_;_])}),_,_,_,_)) ->
let wild _ = P_aux (P_wild,(Parse_ast.Unknown,simple_annot {t = Tid "bit"})) in
acc @ (List.map wild (range 0 ((int_of_big_int length) - 1)))
- | _,(l,_) -> raise (Reporting_basic.err_unreachable l "Non-vector in vector-concat pattern") in
+ | _,(l,_) -> raise (Reporting_basic.err_unreachable l "remove_vector_concats: Non-vector in vector-concat pattern") in
P_vector (List.fold_left aux [] ps) in
{id_pat_alg with p_vector_concat = p_vector_concat} in