diff options
| -rw-r--r-- | src/ast_util.ml | 2 | ||||
| -rw-r--r-- | src/rewrites.ml | 67 | ||||
| -rw-r--r-- | test/c/pattern_concat_nest.expect | 2 | ||||
| -rw-r--r-- | test/c/pattern_concat_nest.sail | 44 |
4 files changed, 69 insertions, 46 deletions
diff --git a/src/ast_util.ml b/src/ast_util.ml index 29d48543..b6f81de6 100644 --- a/src/ast_util.ml +++ b/src/ast_util.ml @@ -742,7 +742,7 @@ and string_of_pat (P_aux (pat, l)) = | P_list pats -> "[||" ^ string_of_list "," string_of_pat pats ^ "||]" | P_vector_concat pats -> string_of_list " : " string_of_pat pats | P_vector pats -> "[" ^ string_of_list ", " string_of_pat pats ^ "]" - | P_as (pat, id) -> string_of_pat pat ^ " as " ^ string_of_id id + | P_as (pat, id) -> "(" ^ string_of_pat pat ^ " as " ^ string_of_id id ^ ")" | P_string_append pats -> string_of_list " ^^ " string_of_pat pats | _ -> "PAT" diff --git a/src/rewrites.ml b/src/rewrites.ml index 6ae987c0..b7104da3 100644 --- a/src/rewrites.ml +++ b/src/rewrites.ml @@ -669,7 +669,7 @@ let remove_vector_concat_pat pat = P_aux (p,annot) ) } in - + (* let pat = remove_typed_patterns pat in *) let fresh_id_v = fresh_id "v__" in @@ -719,13 +719,14 @@ let remove_vector_concat_pat pat = let name_vector_concat_elements = let p_vector_concat pats = let rec aux ((P_aux (p,((l,_) as a))) as pat) = match p with - | P_vector _ -> P_aux (P_as (pat,fresh_id_v l),a) + | P_vector _ -> P_aux (P_as (pat, fresh_id_v l),a) | P_lit _ -> P_aux (P_as (pat, fresh_id_v l), a) | P_id id -> P_aux (P_id id,a) | P_as (p,id) -> P_aux (P_as (p,id),a) | P_typ (typ, pat) -> P_aux (P_typ (typ, aux pat),a) | P_wild -> P_aux (P_wild,a) - | P_app (id, pats) when Env.is_mapping id (env_of_annot a) -> P_aux (P_app (id, List.map aux pats), a) + | P_app (id, pats) when Env.is_mapping id (env_of_annot a) -> + P_aux (P_app (id, List.map aux pats), a) | _ -> raise (Reporting_basic.err_unreachable @@ -735,38 +736,25 @@ let remove_vector_concat_pat pat = let pat = fold_pat name_vector_concat_elements pat in - - let rec tag_last = function | x :: xs -> let is_last = xs = [] in (x,is_last) :: tag_last xs | _ -> [] in (* remove names from vectors in vector_concat patterns and collect them as declarations for the function body or expression *) - let unname_vector_concat_elements = (* : - ('a, - 'a pat * ((tannot exp -> tannot exp) list), - 'a pat_aux * ((tannot exp -> tannot exp) list), - 'a fpat * ((tannot exp -> tannot exp) list), - 'a fpat_aux * ((tannot exp -> tannot exp) list)) - pat_alg = *) - + let unname_vector_concat_elements = (* build a let-expression of the form "let child = root[i..j] in body" *) let letbind_vec typ_opt (rootid,rannot) (child,cannot) (i,j) = let (l,_) = cannot in let env = env_of_annot rannot in let rootname = string_of_id rootid in let childname = string_of_id child in - + let root = E_aux (E_id rootid, rannot) in let index_i = simple_num l i in let index_j = simple_num l j in - (* FIXME *) let subv = fix_eff_exp (E_aux (E_vector_subrange (root, index_i, index_j), cannot)) in - (* let (_, ord, _) = vector_typ_args_of (Env.base_typ_of (env_of root) (typ_of root)) in - let subrange_id = if is_order_inc ord then "bitvector_subrange_inc" else "bitvector_subrange_dec" in - let subv = fix_eff_exp (E_aux (E_app (mk_id subrange_id, [root; index_i; index_j]), cannot)) in *) let id_pat = match typ_opt with @@ -859,38 +847,28 @@ let remove_vector_concat_pat pat = let (pat,decls) = fold_pat unname_vector_concat_elements pat in - let decls = - let module S = Set.Make(String) in - - let roots_needed = - List.fold_right - (fun (_,(rootid,childid)) roots_needed -> - if S.mem childid roots_needed then - (* let _ = print_endline rootid in *) - S.add rootid roots_needed - else if String.length childid >= 3 && String.sub childid 0 2 = String.sub "v__" 0 2 then - roots_needed - else - S.add rootid roots_needed - ) decls S.empty in - List.filter - (fun (_,(_,childid)) -> - S.mem childid roots_needed || - String.length childid < 3 || - not (String.sub childid 0 2 = String.sub "v__" 0 2)) - decls in + (* We need to put the decls in the right order so letbinds are generated correctly for nested patterns *) + let module G = Graph.Make(String) in + let root_graph = List.fold_left (fun g (_, (root_id, child_id)) -> G.add_edge root_id child_id g) G.empty decls in + let root_order = G.topsort root_graph in + let find_root root_id = + try List.find (fun (_, (root_id', _)) -> root_id = root_id') decls with + | Not_found -> + (* If it's not a root the it's a leaf node in the graph, so search for child_id *) + try List.find (fun (_, (_, child_id)) -> root_id = child_id) decls with + | Not_found -> assert false (* Should never happen *) + in + let decls = List.map find_root root_order in let (letbinds,decls) = - let (decls,_) = List.split decls in + let decls = List.map fst decls in List.split decls in let decls = List.fold_left (fun f g x -> f (g x)) (fun b -> b) decls in - (* at this point shouldn't have P_as patterns in P_vector_concat patterns any more, all P_as and P_id vectors should have their declarations in decls. Now flatten all vector_concat patterns *) - let flatten = let p_vector_concat ps = let aux p acc = match p with @@ -898,12 +876,11 @@ let remove_vector_concat_pat pat = | pat -> pat :: acc in P_vector_concat (List.fold_right aux ps []) in {id_pat_alg with p_vector_concat = p_vector_concat} in - + let pat = fold_pat flatten pat in (* at this point pat should be a flat pattern: no vector_concat patterns with vector_concats patterns as direct child-nodes anymore *) - let range a b = let rec aux a b = if Big_int.greater a b then [] else a :: aux (Big_int.add a (Big_int.of_int 1)) b in if Big_int.greater a b then List.rev (aux b a) else aux a b in @@ -948,9 +925,9 @@ let remove_vector_concat_pat pat = P_vector_concat ps' in {id_pat_alg with p_vector_concat = p_vector_concat} in - + let pat = fold_pat remove_vector_concats pat in - + (pat,letbinds,decls) (* assumes there are no more E_internal expressions *) diff --git a/test/c/pattern_concat_nest.expect b/test/c/pattern_concat_nest.expect new file mode 100644 index 00000000..b3bdd5e7 --- /dev/null +++ b/test/c/pattern_concat_nest.expect @@ -0,0 +1,2 @@ +works = 0b101 +doesnt = 0b010 diff --git a/test/c/pattern_concat_nest.sail b/test/c/pattern_concat_nest.sail new file mode 100644 index 00000000..92150e66 --- /dev/null +++ b/test/c/pattern_concat_nest.sail @@ -0,0 +1,44 @@ +default Order dec +type bits ('n : Int) = vector('n, dec, bit) + +union option ('a : Type) = {None : unit, Some : 'a} + +val vector_subrange = {ocaml: "subrange", lem: "subrange_vec_dec", c: "vector_subrange"} + : forall ('n : Int) ('m : Int) ('o : Int), 'o <= 'm & 'm <= 'n. + (bits('n), atom('m), atom('o)) -> bits('m - ('o - 1)) + +val bitvector_access = {ocaml: "access", lem: "access_vec_dec", c: "vector_access"} + : forall ('n : Int) ('m : Int), 0 <= 'm & 'm + 1 <= 'n. + (bits('n), atom('m)) -> bit + +overload vector_access = {bitvector_access} + +val eq_bit = {ocaml: "eq_bit", lem: "eq", interpreter: "eq_anything", c: "eq_bit"}: (bit, bit) -> bool + +overload operator == = { + eq_bit +} + +val and_bool = {ocaml: "and_bool", lem: "and_bool", smt: "and_bool", interpreter: "and_bool", c: "and_bool"}: (bool, bool) -> bool + + +//////////////////////////////////////////////////////////// + +val works : bits(8) -> bits(3) +function works bv = match bv { + a : bits(3) @ b : bits(2) @ 0b000 => a +} + +val doesnt : bits(8) -> bits(3) +function doesnt bv = match bv { + (a : bits(3) @ b : bits(2)) @ 0b000 => a +} + +val "print_bits" : forall 'n. (string, bits('n)) -> unit + +val main : unit -> unit + +function main() = { + print_bits("works = ", works(0b1010_0000)); + print_bits("doesnt = ", doesnt(0b0101_0000)); +}
\ No newline at end of file |
