diff options
Diffstat (limited to 'src/rewriter.ml')
| -rw-r--r-- | src/rewriter.ml | 63 |
1 files changed, 42 insertions, 21 deletions
diff --git a/src/rewriter.ml b/src/rewriter.ml index 74eb4bb1..8a0c0599 100644 --- a/src/rewriter.ml +++ b/src/rewriter.ml @@ -6,6 +6,15 @@ type 'a exp = 'a Ast.exp type 'a emap = 'a Envmap.t type envs = Type_check.envs +type 'a rewriters = { rewrite_exp : 'a rewriters -> nexp_map option -> 'a exp -> 'a exp; + rewrite_lexp : 'a rewriters -> nexp_map option -> 'a lexp -> 'a lexp; + rewrite_pat : 'a rewriters -> nexp_map option -> 'a pat -> 'a pat; + rewrite_let : 'a rewriters -> nexp_map option -> 'a letbind -> 'a letbind; + rewrite_fun : 'a rewriters -> 'a fundef -> 'a fundef; + rewrite_def : 'a rewriters -> 'a def -> 'a def; + rewrite_defs : 'a rewriters -> 'a defs -> 'a defs; + } + let rec partial_assoc (eq: 'a -> 'a -> bool) (v: 'a) (ls : ('a *'b) list ) : 'b option = match ls with | [] -> None | (v1,v2)::ls -> if (eq v1 v) then Some v2 else partial_assoc eq v ls @@ -63,9 +72,9 @@ let rec match_to_program_vars ns bounds = (*let _ = Printf.eprintf "adding n %s to program var %s\n" (n_to_string n) ev in*) (n,(augment,ev))::(match_to_program_vars ns bounds) -let rec rewrite_exp nmap (E_aux (exp,(l,annot))) = +let rewrite_exp rewriters nmap (E_aux (exp,(l,annot))) = let rewrap e = E_aux (e,(l,annot)) in - let rewrite = rewrite_exp nmap in + let rewrite = rewriters.rewrite_exp rewriters nmap in match exp with | E_block exps -> rewrap (E_block (List.map rewrite exps)) | E_nondet exps -> rewrap (E_nondet (List.map rewrite exps)) @@ -109,8 +118,8 @@ let rec rewrite_exp nmap (E_aux (exp,(l,annot))) = (List.map (fun (Pat_aux (Pat_exp(p,e),pannot)) -> Pat_aux (Pat_exp(p,rewrite e),pannot)) pexps))) - | E_let (letbind,body) -> rewrap (E_let(rewrite_let nmap letbind,rewrite body)) - | E_assign (lexp,exp) -> rewrap (E_assign(rewrite_lexp nmap lexp,rewrite exp)) + | E_let (letbind,body) -> rewrap (E_let(rewriters.rewrite_let rewriters nmap letbind,rewrite body)) + | E_assign (lexp,exp) -> rewrap (E_assign(rewriters.rewrite_lexp rewriters nmap lexp,rewrite exp)) | E_exit e -> rewrap (E_exit (rewrite e)) | E_internal_cast ((_,casted_annot),exp) -> let new_exp = rewrite exp in @@ -191,45 +200,48 @@ let rec rewrite_exp nmap (E_aux (exp,(l,annot))) = raise (Reporting_basic.err_unreachable l ("Internal_exp_user given unexpected types " ^ (t_to_string tu) ^ ", " ^ (t_to_string ti)))) | _ -> raise (Reporting_basic.err_unreachable l ("Internal_exp_user given none Base annot"))) + | E_internal_let _ -> raise (Reporting_basic.err_unreachable l "Internal let found before it should have been introduced") -and rewrite_let map (LB_aux(letbind,(l,annot))) = +let rewrite_let rewriters map (LB_aux(letbind,(l,annot))) = let map = merge_option_maps map (get_map_tannot annot) in match letbind with | LB_val_explicit (typschm, pat,exp) -> - LB_aux(LB_val_explicit (typschm,pat, rewrite_exp map exp),(l,annot)) + LB_aux(LB_val_explicit (typschm,pat, rewriters.rewrite_exp rewriters map exp),(l,annot)) | LB_val_implicit ( pat, exp) -> - LB_aux(LB_val_implicit (pat,rewrite_exp map exp),(l,annot)) + LB_aux(LB_val_implicit (pat,rewriters.rewrite_exp rewriters map exp),(l,annot)) -and rewrite_lexp map (LEXP_aux(lexp,(l,annot))) = +let rewrite_lexp rewriters map (LEXP_aux(lexp,(l,annot))) = let rewrap le = LEXP_aux(le,(l,annot)) in match lexp with | LEXP_id _ | LEXP_cast _ -> rewrap lexp - | LEXP_memory (id,exps) -> rewrap (LEXP_memory(id,List.map (rewrite_exp map) exps)) - | LEXP_vector (lexp,exp) -> rewrap (LEXP_vector (rewrite_lexp map lexp,rewrite_exp map exp)) + | LEXP_memory (id,exps) -> rewrap (LEXP_memory(id,List.map (rewriters.rewrite_exp rewriters map) exps)) + | LEXP_vector (lexp,exp) -> + rewrap (LEXP_vector (rewriters.rewrite_lexp rewriters map lexp,rewriters.rewrite_exp rewriters map exp)) | LEXP_vector_range (lexp,exp1,exp2) -> - rewrap (LEXP_vector_range (rewrite_lexp map lexp,rewrite_exp map exp1,rewrite_exp map exp2)) - | LEXP_field (lexp,id) -> rewrap (LEXP_field (rewrite_lexp map lexp,id)) + rewrap (LEXP_vector_range (rewriters.rewrite_lexp rewriters map lexp, + rewriters.rewrite_exp rewriters map exp1, + rewriters.rewrite_exp rewriters map exp2)) + | LEXP_field (lexp,id) -> rewrap (LEXP_field (rewriters.rewrite_lexp rewriters map lexp,id)) -let rewrite_fun (FD_aux (FD_function(recopt,tannotopt,effectopt,funcls),(l,fdannot))) = +let rewrite_fun rewriters (FD_aux (FD_function(recopt,tannotopt,effectopt,funcls),(l,fdannot))) = let rewrite_funcl (FCL_aux (FCL_Funcl(id,pat,exp),(l,annot))) = (*let _ = Printf.eprintf "Rewriting function %s, pattern %s\n" (match id with (Id_aux (Id i,_)) -> i) (Pretty_print.pat_to_string pat) in*) - (FCL_aux (FCL_Funcl (id,pat,rewrite_exp (get_map_tannot fdannot) exp),(l,annot))) + (FCL_aux (FCL_Funcl (id,pat,rewriters.rewrite_exp rewriters (get_map_tannot fdannot) exp),(l,annot))) in FD_aux (FD_function(recopt,tannotopt,effectopt,List.map rewrite_funcl funcls),(l,fdannot)) -let rewrite_def d = match d with +let rewrite_def rewriters d = match d with | DEF_type _ | DEF_spec _ | DEF_default _ | DEF_reg_dec _ -> d - | DEF_fundef fdef -> DEF_fundef (rewrite_fun fdef) - | DEF_val letbind -> DEF_val (rewrite_let None letbind) + | DEF_fundef fdef -> DEF_fundef (rewriters.rewrite_fun rewriters fdef) + | DEF_val letbind -> DEF_val (rewriters.rewrite_let rewriters None letbind) | DEF_scattered _ -> raise (Reporting_basic.err_unreachable Parse_ast.Unknown "DEF_scattered survived to rewritter") -let rewrite_defs (Defs defs) = +let rewrite_defs_base rewriters (Defs defs) = let rec rewrite ds = match ds with | [] -> [] - | d::ds -> (rewrite_def d)::(rewrite ds) in + | d::ds -> (rewriters.rewrite_def rewriters d)::(rewrite ds) in Defs (rewrite defs) - let explode s = let rec exp i l = if i < 0 then l else exp (i - 1) (s.[i] :: l) in exp (String.length s - 1) [] @@ -257,7 +269,7 @@ let vector_string_to_bit_list lit = let s_bin = match lit with | L_hex s_hex -> List.flatten (List.map hexchar_to_binlist (explode s_hex)) - | L_bin s_bin -> explode s_bin in + | L_bin s_bin -> explode (String.uppercase s_bin) in List.map (function '0' -> L_aux (L_zero, Unknown) | '1' -> L_aux (L_one,Unknown)) s_bin @@ -274,3 +286,12 @@ let remove_vector_string_expressions exp = match exp with (vector_string_to_bit_list lit) in E_vector es | _ -> exp + +let rewrite_defs (Defs defs) = rewrite_defs_base + {rewrite_exp = rewrite_exp; + rewrite_pat = (fun _ _ p -> p); + rewrite_let = rewrite_let; + rewrite_lexp = rewrite_lexp; + rewrite_fun = rewrite_fun; + rewrite_def = rewrite_def; + rewrite_defs = rewrite_defs_base} (Defs defs) |
