summaryrefslogtreecommitdiff
path: root/src/rewriter.ml
diff options
context:
space:
mode:
Diffstat (limited to 'src/rewriter.ml')
-rw-r--r--src/rewriter.ml63
1 files changed, 42 insertions, 21 deletions
diff --git a/src/rewriter.ml b/src/rewriter.ml
index 74eb4bb1..8a0c0599 100644
--- a/src/rewriter.ml
+++ b/src/rewriter.ml
@@ -6,6 +6,15 @@ type 'a exp = 'a Ast.exp
type 'a emap = 'a Envmap.t
type envs = Type_check.envs
+type 'a rewriters = { rewrite_exp : 'a rewriters -> nexp_map option -> 'a exp -> 'a exp;
+ rewrite_lexp : 'a rewriters -> nexp_map option -> 'a lexp -> 'a lexp;
+ rewrite_pat : 'a rewriters -> nexp_map option -> 'a pat -> 'a pat;
+ rewrite_let : 'a rewriters -> nexp_map option -> 'a letbind -> 'a letbind;
+ rewrite_fun : 'a rewriters -> 'a fundef -> 'a fundef;
+ rewrite_def : 'a rewriters -> 'a def -> 'a def;
+ rewrite_defs : 'a rewriters -> 'a defs -> 'a defs;
+ }
+
let rec partial_assoc (eq: 'a -> 'a -> bool) (v: 'a) (ls : ('a *'b) list ) : 'b option = match ls with
| [] -> None
| (v1,v2)::ls -> if (eq v1 v) then Some v2 else partial_assoc eq v ls
@@ -63,9 +72,9 @@ let rec match_to_program_vars ns bounds =
(*let _ = Printf.eprintf "adding n %s to program var %s\n" (n_to_string n) ev in*)
(n,(augment,ev))::(match_to_program_vars ns bounds)
-let rec rewrite_exp nmap (E_aux (exp,(l,annot))) =
+let rewrite_exp rewriters nmap (E_aux (exp,(l,annot))) =
let rewrap e = E_aux (e,(l,annot)) in
- let rewrite = rewrite_exp nmap in
+ let rewrite = rewriters.rewrite_exp rewriters nmap in
match exp with
| E_block exps -> rewrap (E_block (List.map rewrite exps))
| E_nondet exps -> rewrap (E_nondet (List.map rewrite exps))
@@ -109,8 +118,8 @@ let rec rewrite_exp nmap (E_aux (exp,(l,annot))) =
(List.map
(fun (Pat_aux (Pat_exp(p,e),pannot)) ->
Pat_aux (Pat_exp(p,rewrite e),pannot)) pexps)))
- | E_let (letbind,body) -> rewrap (E_let(rewrite_let nmap letbind,rewrite body))
- | E_assign (lexp,exp) -> rewrap (E_assign(rewrite_lexp nmap lexp,rewrite exp))
+ | E_let (letbind,body) -> rewrap (E_let(rewriters.rewrite_let rewriters nmap letbind,rewrite body))
+ | E_assign (lexp,exp) -> rewrap (E_assign(rewriters.rewrite_lexp rewriters nmap lexp,rewrite exp))
| E_exit e -> rewrap (E_exit (rewrite e))
| E_internal_cast ((_,casted_annot),exp) ->
let new_exp = rewrite exp in
@@ -191,45 +200,48 @@ let rec rewrite_exp nmap (E_aux (exp,(l,annot))) =
raise (Reporting_basic.err_unreachable l
("Internal_exp_user given unexpected types " ^ (t_to_string tu) ^ ", " ^ (t_to_string ti))))
| _ -> raise (Reporting_basic.err_unreachable l ("Internal_exp_user given none Base annot")))
+ | E_internal_let _ -> raise (Reporting_basic.err_unreachable l "Internal let found before it should have been introduced")
-and rewrite_let map (LB_aux(letbind,(l,annot))) =
+let rewrite_let rewriters map (LB_aux(letbind,(l,annot))) =
let map = merge_option_maps map (get_map_tannot annot) in
match letbind with
| LB_val_explicit (typschm, pat,exp) ->
- LB_aux(LB_val_explicit (typschm,pat, rewrite_exp map exp),(l,annot))
+ LB_aux(LB_val_explicit (typschm,pat, rewriters.rewrite_exp rewriters map exp),(l,annot))
| LB_val_implicit ( pat, exp) ->
- LB_aux(LB_val_implicit (pat,rewrite_exp map exp),(l,annot))
+ LB_aux(LB_val_implicit (pat,rewriters.rewrite_exp rewriters map exp),(l,annot))
-and rewrite_lexp map (LEXP_aux(lexp,(l,annot))) =
+let rewrite_lexp rewriters map (LEXP_aux(lexp,(l,annot))) =
let rewrap le = LEXP_aux(le,(l,annot)) in
match lexp with
| LEXP_id _ | LEXP_cast _ -> rewrap lexp
- | LEXP_memory (id,exps) -> rewrap (LEXP_memory(id,List.map (rewrite_exp map) exps))
- | LEXP_vector (lexp,exp) -> rewrap (LEXP_vector (rewrite_lexp map lexp,rewrite_exp map exp))
+ | LEXP_memory (id,exps) -> rewrap (LEXP_memory(id,List.map (rewriters.rewrite_exp rewriters map) exps))
+ | LEXP_vector (lexp,exp) ->
+ rewrap (LEXP_vector (rewriters.rewrite_lexp rewriters map lexp,rewriters.rewrite_exp rewriters map exp))
| LEXP_vector_range (lexp,exp1,exp2) ->
- rewrap (LEXP_vector_range (rewrite_lexp map lexp,rewrite_exp map exp1,rewrite_exp map exp2))
- | LEXP_field (lexp,id) -> rewrap (LEXP_field (rewrite_lexp map lexp,id))
+ rewrap (LEXP_vector_range (rewriters.rewrite_lexp rewriters map lexp,
+ rewriters.rewrite_exp rewriters map exp1,
+ rewriters.rewrite_exp rewriters map exp2))
+ | LEXP_field (lexp,id) -> rewrap (LEXP_field (rewriters.rewrite_lexp rewriters map lexp,id))
-let rewrite_fun (FD_aux (FD_function(recopt,tannotopt,effectopt,funcls),(l,fdannot))) =
+let rewrite_fun rewriters (FD_aux (FD_function(recopt,tannotopt,effectopt,funcls),(l,fdannot))) =
let rewrite_funcl (FCL_aux (FCL_Funcl(id,pat,exp),(l,annot))) =
(*let _ = Printf.eprintf "Rewriting function %s, pattern %s\n"
(match id with (Id_aux (Id i,_)) -> i) (Pretty_print.pat_to_string pat) in*)
- (FCL_aux (FCL_Funcl (id,pat,rewrite_exp (get_map_tannot fdannot) exp),(l,annot)))
+ (FCL_aux (FCL_Funcl (id,pat,rewriters.rewrite_exp rewriters (get_map_tannot fdannot) exp),(l,annot)))
in FD_aux (FD_function(recopt,tannotopt,effectopt,List.map rewrite_funcl funcls),(l,fdannot))
-let rewrite_def d = match d with
+let rewrite_def rewriters d = match d with
| DEF_type _ | DEF_spec _ | DEF_default _ | DEF_reg_dec _ -> d
- | DEF_fundef fdef -> DEF_fundef (rewrite_fun fdef)
- | DEF_val letbind -> DEF_val (rewrite_let None letbind)
+ | DEF_fundef fdef -> DEF_fundef (rewriters.rewrite_fun rewriters fdef)
+ | DEF_val letbind -> DEF_val (rewriters.rewrite_let rewriters None letbind)
| DEF_scattered _ -> raise (Reporting_basic.err_unreachable Parse_ast.Unknown "DEF_scattered survived to rewritter")
-let rewrite_defs (Defs defs) =
+let rewrite_defs_base rewriters (Defs defs) =
let rec rewrite ds = match ds with
| [] -> []
- | d::ds -> (rewrite_def d)::(rewrite ds) in
+ | d::ds -> (rewriters.rewrite_def rewriters d)::(rewrite ds) in
Defs (rewrite defs)
-
let explode s =
let rec exp i l = if i < 0 then l else exp (i - 1) (s.[i] :: l) in
exp (String.length s - 1) []
@@ -257,7 +269,7 @@ let vector_string_to_bit_list lit =
let s_bin = match lit with
| L_hex s_hex -> List.flatten (List.map hexchar_to_binlist (explode s_hex))
- | L_bin s_bin -> explode s_bin in
+ | L_bin s_bin -> explode (String.uppercase s_bin) in
List.map (function '0' -> L_aux (L_zero, Unknown) | '1' -> L_aux (L_one,Unknown)) s_bin
@@ -274,3 +286,12 @@ let remove_vector_string_expressions exp = match exp with
(vector_string_to_bit_list lit) in
E_vector es
| _ -> exp
+
+let rewrite_defs (Defs defs) = rewrite_defs_base
+ {rewrite_exp = rewrite_exp;
+ rewrite_pat = (fun _ _ p -> p);
+ rewrite_let = rewrite_let;
+ rewrite_lexp = rewrite_lexp;
+ rewrite_fun = rewrite_fun;
+ rewrite_def = rewrite_def;
+ rewrite_defs = rewrite_defs_base} (Defs defs)