diff options
| author | Christopher Pulte | 2015-11-10 22:59:56 +0000 |
|---|---|---|
| committer | Christopher Pulte | 2015-11-10 22:59:56 +0000 |
| commit | 3945afb351cda3ed4eacb494ff426d108fd38612 (patch) | |
| tree | 085834c127bd733013c341af587c89cab43a5df4 /src | |
| parent | afb10f429248912984a7915bf05c58de85ea5cbb (diff) | |
rewriting fixes, syntactically correct lem syntax, number type errors remaining
Diffstat (limited to 'src')
| -rw-r--r-- | src/gen_lib/sail_values.lem | 26 | ||||
| -rw-r--r-- | src/gen_lib/state.lem | 115 | ||||
| -rw-r--r-- | src/gen_lib/vector.lem | 20 | ||||
| -rw-r--r-- | src/pretty_print.ml | 602 | ||||
| -rw-r--r-- | src/rewriter.ml | 138 |
5 files changed, 425 insertions, 476 deletions
diff --git a/src/gen_lib/sail_values.lem b/src/gen_lib/sail_values.lem index a51a0091..00b3e3ab 100644 --- a/src/gen_lib/sail_values.lem +++ b/src/gen_lib/sail_values.lem @@ -9,22 +9,22 @@ let to_bool = function (* | BU -> assert false *) end -let get_elements (Vector elements _) = elements let get_start (Vector _ s) = s let length (Vector bs _) = length bs -(* -let write_two_registers r1 r2 vec = - let size = length_register r1 in +let write_two_regs r1 r2 vec = + let size = length_reg r1 in let start = get_start vec in let vsize = length vec in - let r1_v = read_vector_subrange is_inc vec start ((if is_inc then size - start else start - size) - 1) in + let r1_v = + (slice vec) + start + ((if is_inc then size - start else start - size) - 1) in let r2_v = - read_vector_subrange is_inc - vec (if is_inc then size - start else start - size) + (slice vec) + (if is_inc then size - start else start - size) (if is_inc then vsize - start else start - vsize) in - write_register r1 r1_v >> write_register r2 r2_v - *) + write_reg r1 r1_v >> write_reg r2 r2_v let rec replace bs ((n : nat),b') = match (n,bs) with | (_, []) -> [] @@ -269,15 +269,15 @@ let shift_op_vec op (((Vector bs start) as l),r) = match op with | LL (*"<<"*) -> let right_vec = Vector (List.replicate n B0) 0 in - let left_vec = read_vector_subrange is_inc l n (if is_inc then len + start else start - len) in + let left_vec = slice l n (if is_inc then len + start else start - len) in vector_concat left_vec right_vec | RR (*">>"*) -> - let right_vec = read_vector_subrange is_inc l start n in + let right_vec = slice l start n in let left_vec = Vector (List.replicate n B0) 0 in vector_concat left_vec right_vec | LLL (*"<<<"*) -> - let left_vec = read_vector_subrange is_inc l n (if is_inc then len + start else start - len) in - let right_vec = read_vector_subrange is_inc l start n in + let left_vec = slice l n (if is_inc then len + start else start - len) in + let right_vec = slice l start n in vector_concat left_vec right_vec end diff --git a/src/gen_lib/state.lem b/src/gen_lib/state.lem index f9404416..268077da 100644 --- a/src/gen_lib/state.lem +++ b/src/gen_lib/state.lem @@ -1,4 +1,6 @@ open import Pervasives +open import Vector +open import Arch type M 's 'a = 's -> ('a * 's) @@ -11,18 +13,111 @@ let bind m f s = let (a,s') = m s in f a s' let (>>=) = bind let (>>) m n = m >>= fun _ -> n +val foreach_inc : forall 's 'vars. (nat * nat * nat) -> 'vars -> + (nat -> 'vars -> (unit * 'vars)) -> (unit * 'vars) +let rec foreach_inc (i,stop,by) vars body = + if i <= stop + then + let (_,vars) = body i vars in + foreach_inc (i + by,stop,by) vars body + else ((),vars) + + +val foreach_dec : forall 's 'vars. (nat * nat * nat) -> 'vars -> + (nat -> 'vars -> (unit * 'vars)) -> (unit * 'vars) +let rec foreach_dec (i,stop,by) vars body = + if i >= stop + then + let (_,vars) = body i vars in + foreach_dec (i - by,stop,by) vars body + else ((),vars) + -val foreach_inc : forall 's 'vars. (nat * nat * nat) -> (nat -> 'vars -> M 's 'vars) -> - 'vars -> M 's 'vars -let rec foreach_inc (i,stop,by) body vars = +val foreachM_inc : forall 's 'vars. (nat * nat * nat) -> 'vars -> + (nat -> 'vars -> M 's (unit * 'vars)) -> M 's (unit * 'vars) +let rec foreachM_inc (i,stop,by) vars body = if i <= stop - then (body i vars >>= fun vars -> foreach_inc (i + by,stop,by) body vars) - else return vars + then + body i vars >>= fun (_,vars) -> + foreachM_inc (i + by,stop,by) vars body + else return ((),vars) -val foreach_dec : forall 's 'vars. (nat * nat * nat) -> (nat -> 'vars -> M 's 'vars) -> - 'vars -> M 's 'vars -let rec foreach_dec (i,stop,by) body vars = +val foreachM_dec : forall 's 'vars. (nat * nat * nat) -> 'vars -> + (nat -> 'vars -> M 's (unit * 'vars)) -> M 's (unit * 'vars) +let rec foreachM_dec (i,stop,by) vars body = if i >= stop - then (body i vars >>= fun vars -> foreach_dec (i - by,stop,by) body vars) - else return vars + then + body i vars >>= fun (_,vars) -> + foreachM_dec (i - by,stop,by) vars body + else return ((),vars) + + + +let slice (Vector bs start) n m = + let (length,offset) = if is_inc then (m-n+1,n-start) else (n-m+1,start-n) in + let (_,suffix) = List.splitAt offset bs in + let (subvector,_) = List.splitAt length suffix in + Vector subvector n + +let update (Vector bs start) n m (Vector bs' _) = + let (length,offset) = if is_inc then (m-n+1,n-start) else (n-m+1,start-n) in + let (prefix,_) = List.splitAt offset bs in + let (_,suffix) = List.splitAt (offset + length) bs in + Vector (prefix ++ (List.take length bs') ++ suffix) start + +let hd (x :: _) = x + +val access : forall 'a. vector 'a -> nat -> 'a +let access (Vector bs start) n = + if is_inc then nth bs (n - start) else nth bs (start - n) + +val update_pos : forall 'a. vector 'a -> nat -> 'a -> vector 'a +let update_pos v n b = + update v n n (Vector [b] 0) + +val read_reg_range : register -> (nat * nat) -> M state (vector bit) +let read_reg_range reg (i,j) s = + let v = slice (read_regstate s reg) i j in + (v,s) + +val read_reg_bit : register -> nat -> M state bit +let read_reg_bit reg i s = + let v = access (read_regstate s reg) i in + (v,s) + +val write_reg_range : register -> (nat * nat) -> vector bit -> M state unit +let write_reg_range (reg : register) (i,j) (v : vector bit) s = + let v' = update (read_regstate s reg) i j v in + let s' = write_regstate s reg v' in + ((),s') + +val write_reg_bit : register -> nat -> bit -> M state unit +let write_reg_bit reg i bit s = + let v = read_regstate s reg in + let v' = update_pos v i bit in + let s' = write_regstate s reg v' in + ((),s') + + +val read_reg : register -> M state (vector bit) +let read_reg reg s = + let v = read_regstate s reg in + (v,s) + +val write_reg : register -> vector bit -> M state unit +let write_reg reg v s = + let s' = write_regstate s reg v in + ((),s') + +val read_reg_field : register -> register_field -> M state (vector bit) +let read_reg_field reg rfield = read_reg_range reg (field_indices rfield) + +val write_reg_field : register -> register_field -> vector bit -> M state unit +let write_reg_field reg rfield = write_reg_range reg (field_indices rfield) + +val read_reg_field_bit : register -> register_field_bit -> M state bit +let read_reg_field_bit reg rbit = read_reg_bit reg (field_index_bit rbit) + +val write_reg_field_bit : register -> register_field_bit -> bit -> M state unit +let write_reg_field_bit reg rbit = write_reg_bit reg (field_index_bit rbit) diff --git a/src/gen_lib/vector.lem b/src/gen_lib/vector.lem index d5f47492..5f239e37 100644 --- a/src/gen_lib/vector.lem +++ b/src/gen_lib/vector.lem @@ -7,23 +7,3 @@ let rec nth xs (n : nat) = match (n,xs) with | (0,x :: xs) -> x | (n + 1,x :: xs) -> nth xs n end - -let vector_access is_inc (Vector bs start) n = - if is_inc then nth bs (n - start) else nth bs (start - n) - -let read_vector_subrange is_inc (Vector bs start) n m = - let (length,offset) = if is_inc then (m-n+1,n-start) else (n-m+1,start-n) in - let (_,suffix) = List.splitAt offset bs in - let (subvector,_) = List.splitAt length suffix in - Vector subvector n - -let write_vector_subrange is_inc (Vector bs start) n m (Vector bs' _) = - let (length,offset) = if is_inc then (m-n+1,n-start) else (n-m+1,start-n) in - let (prefix,_) = List.splitAt offset bs in - let (_,suffix) = List.splitAt (offset + length) bs in - Vector (prefix ++ (List.take length bs') ++ suffix) start - -let write_vector_bit is_inc v n (Vector [b] 0) = - write_vector_subrange is_inc v n n b - -let hd (x :: xs) = x diff --git a/src/pretty_print.ml b/src/pretty_print.ml index 9fb93a48..897330d7 100644 --- a/src/pretty_print.ml +++ b/src/pretty_print.ml @@ -1803,6 +1803,7 @@ let pp_defs_ocaml f d top_line opens = * PPrint-based sail-to-lem pprinter ****************************************************************************) + let langlebar = string "<|" let ranglebar = string "|>" let anglebars = enclose langlebar ranglebar @@ -1843,7 +1844,7 @@ let doc_typ_lem, doc_atomic_typ_lem = separate space [tup_typ arg; arrow; fn_typ ret] | _ -> tup_typ ty and tup_typ ((Typ_aux (t, _)) as ty) = match t with - | Typ_tup typs -> parens (separate_map star app_typ typs) + | Typ_tup typs -> separate_map star app_typ typs | _ -> app_typ ty and app_typ ((Typ_aux (t, _)) as ty) = match t with | Typ_app(Id_aux (Id "vector", _), [ @@ -1883,7 +1884,7 @@ let doc_lit_lem in_pat (L_aux(l,_)) = | L_one -> "B1" | L_false -> "B0" | L_true -> "B1" - | L_num i -> (* if in_pat then string_of_int i else "(big_int_of_int " ^ string_of_int i ^ ")" *) string_of_int i + | L_num i -> if i < 0 then "(0 " ^ string_of_int i ^ ")" else string_of_int i | L_hex n -> failwith "Shouldn't happen" (*"(num_to_vec " ^ ("0x" ^ n) ^ ")" (*shouldn't happen*)*) | L_bin n -> failwith "Shouldn't happen" (*"(num_to_vec " ^ ("0b" ^ n) ^ ")" (*shouldn't happen*)*) | L_undef -> "BU" @@ -1929,7 +1930,10 @@ let doc_pat_lem = underscore])]) else non_bit_print() | _ -> non_bit_print ()) - | P_tup pats -> parens (separate_map comma_sp pat pats) + | P_tup pats -> + (match pats with + | [p] -> pat p + | _ -> parens (separate_map comma_sp pat pats)) | P_list pats -> brackets (separate_map semi pat pats) (*Never seen but easy in lem*) in pat @@ -1938,21 +1942,26 @@ let doc_exp_lem, doc_let_lem = let exp = top_exp in match e with | E_assign((LEXP_aux(le_act,tannot) as le),e) -> - (match annot with - | Base(_,(Emp_local | Emp_set),_,_,_,_) -> - (match le_act with - | LEXP_id _ | LEXP_cast _ -> - (*Setting local variable fully *) - doc_op coloneq (doc_lexp_lem true le) (exp e) - | LEXP_vector _ -> - separate space [string "write_register";parens (doc_lexp_array_lem le);exp e] - | LEXP_vector_range _ -> - doc_lexp_rwrite le e) - | _ -> - (match le_act with - | LEXP_vector _ | LEXP_vector_range _ | LEXP_cast _ | LEXP_field _ | LEXP_id _ -> - (doc_lexp_rwrite le e) - | LEXP_memory _ -> (doc_lexp_fcall le e))) + (* can only be register writes *) + let (_,(Base ((_,_),tag,_,_,_,_))) = tannot in + (separate space) + (match tag with + | External _ -> + (match le_act with + | LEXP_vector (le,e2) -> + [string "write_reg_bit";doc_lexp_deref_lem le;exp e2;exp e] + | LEXP_vector_range (le,e2,e3) -> + [string "write_reg_range";doc_lexp_deref_lem le;exp e2;exp e3;exp e] + | LEXP_field (lexp,id) -> + let (Base ((_,{t = t}),_,_,_,_,_)) = annot in + (match t with + | Tid "bit" -> + [string "write_reg_field_bit";doc_lexp_deref_lem le;doc_id_lem id;exp e] + | _ -> + [string "write_reg_field";doc_lexp_deref_lem le;doc_id_lem id;exp e]) + | (LEXP_id _ | LEXP_cast _) -> [string "write_reg";doc_lexp_deref_lem le;exp e]) + | _ -> [string "write_reg";doc_lexp_deref_lem le;exp e] + ) | E_vector_append(l,r) -> parens ((string "vector_concat ") ^^ (exp l) ^^ space ^^ (exp r)) | E_cons(l,r) -> doc_op (group (colon^^colon)) (exp l) (exp r) @@ -1962,21 +1971,26 @@ let doc_exp_lem, doc_let_lem = (prefix 2 1 (string "then") (exp t)) ^/^ (prefix 2 1 (string "else") (exp e)) ) ^^ hardline - | E_for(id,exp1,exp2,exp3,(Ord_aux(order,_)),exp4) -> failwith "shouldn't happen" + | E_for(id,exp1,exp2,exp3,(Ord_aux(order,_)),exp4) -> + failwith "E_for should have been removed till now" | E_let(leb,e) -> (let_exp leb) ^^ space ^^ string "in" ^/^ (exp e) | E_app(f,args) -> (match f with (* temporary hack to make the loop body a function of the temporary variables *) | Id_aux ((Id (("foreach_inc" | "foreach_dec") as loopf),_)) -> let call = doc_id_lem in - let [indices;body;E_aux (E_tuple vars,_) as e5] = args in + let [id;indices;body;E_aux (E_tuple vars,_) as e5] = args in let vars = List.map (fun (E_aux (E_id (Id_aux (Id name,_)),_)) -> string name) vars in - separate space [string loopf;exp indices; - parens((separate space) - [string "fun";parens (separate comma vars);arrow] ^/^ - exp body - ); - exp e5] + let varspp = + match vars with + | [v] -> v + | _ -> parens (separate comma vars) in + (separate space) + [string loopf;exp indices;exp e5] ^/^ + parens((prefix 2 1) + (separate space [string "fun";exp id;varspp;arrow]) + (exp body) + ) | _ -> let call = match annot with | Base(_,External (Some n),_,_,_,_) -> string n @@ -1984,16 +1998,9 @@ let doc_exp_lem, doc_let_lem = | _ -> doc_id_lem f in parens (doc_unop call (parens (separate_map comma exp args))) ) - | E_vector_access(v,e) -> - let call = (match annot with - | Base((_,t),_,_,_,_,_) -> - (match t.t with - | Tid "bit" | Tabbrev(_,{t=Tid "bit"}) -> (string "bit_vector_access") - | _ -> (string "vector_access")) - | _ -> (string "vector_access")) in - parens (call ^^ space ^^ exp v ^^ space ^^ exp e) + | E_vector_access(v,e) -> separate space [string "access";exp v;exp e] | E_vector_subrange(v,e1,e2) -> - parens ((string "read_vector_subrange") ^^ space ^^ (exp v) ^^ space ^^ (exp e1) ^^ space ^^ (exp e2)) + parens ((string "slice") ^^ space ^^ (exp v) ^^ space ^^ (exp e1) ^^ space ^^ (exp e2)) | E_field((E_aux(_,(_,fannot)) as fexp),id) -> (match fannot with | Base((_,{t= Tapp("register",_)}),_,_,_,_,_) | @@ -2001,10 +2008,10 @@ let doc_exp_lem, doc_let_lem = let field_f = match annot with | Base((_,{t = Tid "bit"}),_,_,_,_,_) | Base((_,{t = Tabbrev(_,{t=Tid "bit"})}),_,_,_,_,_) -> - string "read_register_field_bit" - | _ -> string "read_register_field" in - parens (field_f ^^ space ^^ (exp fexp) ^^ space ^^ string_lit (doc_id id)) - | _ -> exp fexp ^^ dot ^^ doc_id id) + string "read_reg_field_bit" + | _ -> string "read_reg_field" in + parens (field_f ^^ space ^^ (exp fexp) ^^ space ^^ string_lit (doc_id_lem id)) + | _ -> exp fexp ^^ dot ^^ doc_id_lem id) | E_block [] -> string "()" | E_block exps | E_nondet exps -> let exps_doc = separate_map (semi ^^ hardline) exp exps in @@ -2015,21 +2022,21 @@ let doc_exp_lem, doc_let_lem = doc_id_lem id | Base((_, ({t = Tapp("register",_)} | {t=Tabbrev(_,{t=Tapp("register",_)})})),tag,_,_,_,_) -> (match tag with - | External _ -> separate space [string "read_register";doc_id_lem id] + | External _ -> separate space [string "read_reg";doc_id_lem id] | _ -> doc_id_lem id) | Base(_,(Constructor i |Enum i),_,_,_,_) -> doc_id_lem_ctor id | Base((_,t),Alias alias_info,_,_,_,_) -> (match alias_info with | Alias_field(reg,field) -> let field_f = match t.t with - | Tid "bit" | Tabbrev(_,{t=Tid "bit"}) -> string "read_register_field_bit" - | _ -> string "read_register_field" in + | Tid "bit" | Tabbrev(_,{t=Tid "bit"}) -> string "read_reg_field_bit" + | _ -> string "read_reg_field" in separate space [field_f; string reg; string_lit (string field)] | Alias_extract(reg,start,stop) -> if start = stop - then parens (separate space [string "vector_access";string reg;doc_int start]) + then parens (separate space [string "access";string reg;doc_int start]) else parens - (separate space [string "vector_subrange"; string reg; doc_int start; doc_int stop]) + (separate space [string "slice"; string reg; doc_int start; doc_int stop]) | Alias_pair(reg1,reg2) -> parens (separate space [string "vector_concat"; string reg1; @@ -2039,25 +2046,27 @@ let doc_exp_lem, doc_let_lem = then parens (separate space - [string "vector_access"; doc_int start; - parens (string "read_register" ^^ space ^^ string reg)]) + [string "access";doc_int start; + parens (string "read_reg" ^^ space ^^ string reg)]) else parens (separate space - [string "vector_subrange"; doc_int start; doc_int stop; - parens (string "read_register" ^^ space ^^ string reg)]) + [string "slice"; doc_int start; doc_int stop; + parens (string "read_reg" ^^ space ^^ string reg)]) | Alias_pair(reg1,reg2) -> parens (separate space [string "vector_concat"; - parens (string "read_register" ^^ space ^^ string reg1); - parens (string "read_register" ^^ space ^^ string reg2)])) + parens (string "read_reg" ^^ space ^^ string reg1); + parens (string "read_reg" ^^ space ^^ string reg2)])) | _ -> doc_id_lem id) | E_lit lit -> doc_lit_lem false lit | E_cast(typ,e) -> (match annot with - | Base(_,External _,_,_,_,_) -> string "read_register" ^^ space ^^ exp e + | Base(_,External _,_,_,_,_) -> string "read_reg" ^^ space ^^ exp e | _ -> exp e) (*(parens (doc_op colon (group (exp e)) (doc_typ_lem typ)))) *) | E_tuple exps -> - parens (separate_map comma exp exps) + (match exps with + | [e] -> exp e + | _ -> parens (separate_map comma exp exps)) | E_record(FES_aux(FES_Fexps(fexps,_),_)) -> anglebars (separate_map semi_sp doc_fexp fexps) | E_record_update(e,(FES_aux(FES_Fexps(fexps,_),_))) -> @@ -2105,19 +2114,18 @@ let doc_exp_lem, doc_let_lem = string start; string size])) | E_vector_update(v,e1,e2) -> - (*Has never happened to date*) - brackets (doc_op (string "with") (exp v) (doc_op equals (exp e1) (exp e2))) + separate space [string "update_pos";exp v;exp e1;exp e2] | E_vector_update_subrange(v,e1,e2,e3) -> - (*Has never happened to date*) - brackets ( - doc_op (string "with") (exp v) - (doc_op equals (exp e1 ^^ colon ^^ exp e2) (exp e3))) + separate space [string "update";exp v;exp e1;exp e2;exp e3] | E_list exps -> brackets (separate_map semi exp exps) | E_case(e,pexps) -> - let opening = separate space [string "("; string "match"; top_exp e; string "with"] in - let cases = separate_map (break 1) doc_case pexps in - surround 2 1 opening cases rparen + parens + ((prefix 2 1) + (separate space [string "match"; top_exp e; string "with"]) + ((separate_map (break 1) doc_case pexps) ^/^ + (string "end" ^^ hardline)) + ) | E_exit e -> separate space [string "exit"; exp e;] | E_app_infix (e1,id,e2) -> @@ -2127,14 +2135,15 @@ let doc_exp_lem, doc_let_lem = | _ -> doc_id_lem id in parens (separate space [call; parens (separate_map comma exp [e1;e2])]) | E_internal_let(lexp, eq_exp, in_exp) -> - (separate + failwith "E_internal_lets should have been removed till now" +(* (separate space - [string "let"; + [string "let TAKTAK"; doc_lexp_lem true lexp; (*Rewriter/typecheck should ensure this is only cast or id*) coloneq; exp eq_exp; string "in"]) ^/^ - exp in_exp + exp in_exp *) | E_internal_plet (pat,e1,e2) -> (match pat with | P_aux (P_wild,_) -> @@ -2157,86 +2166,18 @@ let doc_exp_lem, doc_let_lem = and doc_case (Pat_aux(Pat_exp(pat,e),_)) = doc_op arrow (separate space [pipe; doc_pat_lem pat]) (group (top_exp e)) - and doc_lexp_lem top_call ((LEXP_aux(lexp,(l,annot))) as le) = - let exp = top_exp in - match lexp with - | LEXP_vector(v,e) -> doc_lexp_array_lem le - | LEXP_vector_range(v,e1,e2) -> - parens (string "vector_subrange" ^^ space ^^ (doc_lexp_lem false v) ^^ space ^^ (exp e1) ^^ space ^^ (exp e2)) - | LEXP_field(v,id) -> (doc_lexp_lem false v) ^^ dot ^^ doc_id_lem id - | LEXP_id id | LEXP_cast(_,id) -> doc_id_lem id - - and doc_lexp_array_lem ((LEXP_aux(lexp,(l,annot))) as le) = match lexp with - | LEXP_vector(v,e) -> - (match annot with - | Base((_,t),_,_,_,_,_) -> - let t_act = match t.t with | Tapp("reg",[TA_typ t]) | Tabbrev(_,{t=Tapp("reg",[TA_typ t])}) -> t | _ -> t in - (match t_act.t with - | Tid "bit" - | Tabbrev(_,{t=Tid "bit"}) -> - separate space [string "nth"; - parens (string "get_elements" ^^ space ^^ doc_lexp_lem false v); - parens (top_exp e)] - | _ -> - separate space [string "nth"; - parens (string "get_elements" ^^ space ^^ doc_lexp_lem false v); - parens (top_exp e)] - | _ -> - parens ((string "get_elements") ^^ space ^^ doc_lexp_lem false v) ^^ dot ^^ parens (top_exp e))) + and doc_lexp_deref_lem ((LEXP_aux(lexp,(l,annot))) as le) = match lexp with + | LEXP_field (le,id) -> + parens (separate empty [doc_lexp_deref_lem le;dot;doc_id_lem id]) + | LEXP_vector(le,e) -> + parens + ((separate space) + [string "access";parens (doc_lexp_deref_lem le);parens (top_exp e)] + ) + | LEXP_id id -> doc_id_lem id | _ -> empty - and doc_lexp_rwrite ((LEXP_aux(lexp,(l,annot))) as le) e_new_v = - let exp = top_exp in - let (is_bit,is_bitv) = match e_new_v with - | E_aux(_,(_,Base((_,t),_,_,_,_,_))) -> - (match t.t with - | Tapp("vector", [_;_;_;(TA_typ ({t=Tid "bit"} | {t=Tabbrev(_,{t=Tid "bit"})}))]) | - Tabbrev(_,{t=Tapp("vector",[_;_;_;TA_typ ({t=Tid "bit"} | {t=Tabbrev(_,{t=Tid "bit"})})])}) | - Tapp("reg", [TA_typ {t= Tapp("vector", [_;_;_;(TA_typ ({t=Tid "bit"} | {t=Tabbrev(_,{t=Tid "bit"})}))])}]) - -> - (false,true) - | Tid "bit" | Tabbrev(_,{t=Tid "bit"}) | Tapp("reg",[TA_typ ({t=Tid "bit"} | {t=Tabbrev(_,{t=Tid "bit"})})]) - -> (true,false) - | _ -> (false,false)) - | _ -> (false,false) in - match lexp with - | LEXP_vector(v,e) -> - doc_op (string "<-") - (group (parens ((string "get_elements") ^^ space ^^ doc_lexp_lem false v)) ^^ - dot ^^ parens (exp e)) - (exp e_new_v) - | LEXP_vector_range(v,e1,e2) -> - parens ((string "write_vector_subrange") ^^ space ^^ - doc_lexp_lem false v ^^ space ^^ exp e1 ^^ space ^^ exp e2 ^^ space ^^ exp e_new_v) - | LEXP_field(v,id) -> - parens ((string (if is_bit then "read_register_field_bit" else "read_register_field")) ^^ space ^^ - doc_lexp_lem false v ^^ space ^^string_lit (doc_id id) ^^ space ^^ exp e_new_v) - | LEXP_id id | LEXP_cast (_,id) -> - (match annot with - | Base(_,Alias alias_info,_,_,_,_) -> - (match alias_info with - | Alias_field(reg,field) -> - parens ((if is_bit then string "write_register_field_bit" else string "write_register_field_v") ^^ space ^^ - string (String.uncapitalize reg) ^^ space ^^string_lit (string field) ^^ space ^^ exp e_new_v) - | Alias_extract(reg,start,stop) -> - if start = stop - then - doc_op (string "<-") - (group (parens ((string "get_elements") ^^ space ^^ string reg)) ^^ - dot ^^ parens (doc_int start)) - (exp e_new_v) - else - parens ((string "write_vector_subrange") ^^ space ^^ - string reg ^^ space ^^ doc_int start ^^ space ^^ doc_int stop ^^ space ^^ exp e_new_v) - | Alias_pair(reg1,reg2) -> - parens ((string "write_two_regs") ^^ space ^^ string reg1 ^^ space ^^ string reg2 ^^ space ^^ exp e_new_v)) - | _ -> - separate space [string "write_register"; doc_id_lem id; exp e_new_v]) - - and doc_lexp_fcall ((LEXP_aux(lexp,(l,annot))) as le) e_new_v = match lexp with - | LEXP_memory(id,args) -> doc_id_lem id ^^ parens (separate_map comma top_exp (args@[e_new_v])) - - (* expose doc_exp and doc_let *) + (* expose doc_exp_lem and doc_let *) in top_exp, let_exp (*TODO Upcase and downcase type and constructors as needed*) @@ -2264,7 +2205,7 @@ let doc_typdef_lem (TD_aux(td,_)) = match td with (concat [string "type"; space; doc_id_lem_type id;]) (doc_typquant_lem typq ar_doc) | TD_enum(id,nm,enums,_) -> - let enums_doc = group (separate_map (break 1 ^^ pipe) doc_id_lem_ctor enums) in + let enums_doc = group (separate_map (break 1 ^^ pipe ^^ space) doc_id_lem_ctor enums) in doc_op equals (concat [string "type"; space; doc_id_lem_type id;]) (enums_doc) @@ -2295,21 +2236,20 @@ let doc_fundef_lem (FD_aux(FD_function(r, typa, efa, fcls),_)) = (doc_exp_lem exp) | _ -> let id = get_id fcls in - let sep = hardline ^^ pipe ^^ space in - let clauses = hardline ^^ pipe ^^ separate_map sep doc_funcl_lem fcls in - prefix 2 1 - (separate space [string "let" ^^ doc_rec_lem r ^^ doc_id_lem id; - equals; - (string "function")] - ) - clauses + (* let sep = hardline ^^ pipe ^^ space in *) + let clauses = separate_map + (break 1) + (fun fcl -> separate space [pipe;doc_funcl_lem fcl] ) fcls in + (prefix 2 1) + ((separate space) [string "let" ^^ doc_rec_lem r ^^ doc_id_lem id;equals;string "function"]) + (clauses ^/^ string "end") let doc_dec_lem (DEC_aux (reg,(l,annot))) = match reg with | DEC_reg(typ,id) -> failwith "DEC_reg shouldn't occur here" | DEC_alias(id,alspec) -> empty (* - doc_op equals (string "register alias" ^^ space ^^ doc_id id) (doc_alias alspec) *) + doc_op equals (string "register alias" ^^ space ^^ doc_id_lem id) (doc_alias alspec) *) | DEC_typ_alias(typ,id,alspec) -> empty (* doc_op equals (string "register alias" ^^ space ^^ doc_atomic_typ typ) (doc_alias alspec) *) @@ -2329,49 +2269,38 @@ let reg_decls (Defs defs) = | {order = Odec} -> false | {order = _} -> failwith "Can't deal with variable order" in - let dirpp = + let dir_pp = let is_inc = if is_inc then "true" else "false" in separate space [string "let";string "is_inc";equals;string is_inc] in - let (regtypes,defs) = + let (regtypes,rsranges,rsbits,defs) = List.fold_left - (fun (regtypes,defs) def -> + (fun (regtypes,rsranges,rsbits,defs) def -> match def with - | DEF_type (TD_aux(TD_register (Id_aux (Id name, _),n1,n2,rs),_)) -> - let (rsbits,rsranges) = + | DEF_type (TD_aux(TD_register (Id_aux (Id tname, _),n1,n2,rs),_)) -> + let (rsbits',rsranges') = List.fold_left (fun (rsbits,rsranges) field -> match field with - | (BF_aux (BF_range (i,j), _), Id_aux (Id name,_)) -> - (rsbits,rsranges @ [(name,i,j)]) - | (BF_aux (BF_single i, _), Id_aux (Id name, _)) -> - (rsbits @ [(name,i)],rsranges) + | (BF_aux (BF_range (i,j), _), Id_aux (Id fname,_)) -> + (rsbits,rsranges @ [(fname,tname,i,j)]) + | (BF_aux (BF_single i, _), Id_aux (Id fname, _)) -> + (rsbits @ [(fname,tname,i)],rsranges) ) ([],[]) rs in - (regtypes @ [(name,(n1,n2,rsranges,rsbits))],defs) - | _ -> (regtypes,defs @ [def]) - ) ([],[]) defs in + (regtypes @ [(tname,(n1,n2,rsranges',rsbits'))],rsranges @ rsranges',rsbits @ rsbits',defs) + | _ -> (regtypes,rsranges,rsbits,defs @ [def]) + ) ([],[],[],[]) defs in - let ((simpleregs : string list),(typedregs : ((string * string) list)),defs) = + let (regs,defs) = List.fold_left - (fun (simpleregs,typedregs,defs) def -> + (fun (regs,defs) def -> match def with | DEF_reg_dec (DEC_aux (DEC_reg(Typ_aux (Typ_id (Id_aux (Id typ,_)),_),Id_aux (Id name,_)),_)) -> - (simpleregs,typedregs @ [(name,typ)],defs) + (regs @ [(name,Some typ)],defs) | DEF_reg_dec (DEC_aux (DEC_reg(_, Id_aux (Id name,_)),_)) -> - (simpleregs @ [name],typedregs,defs) - | def -> (simpleregs,typedregs,defs @ [def]) - ) ([],[],[]) defs in - - - let typedregs_per_type : (string * (string list)) list = - List.map (fun (typ,_) -> - let regs = List.filter (fun (_,regtyp) -> regtyp = typ) typedregs in - (typ,List.map fst regs)) regtypes in - - - let regs_per_type : (string option * string list) list = - (None,simpleregs) :: - (List.map (fun (name,regs) -> (Some name,regs)) typedregs_per_type) in + (regs @ [(name,None)],defs) + | def -> (regs,defs @ [def]) + ) ([],[]) defs in (* maybe we need a function that analyses the spec for this as well *) let default = @@ -2379,229 +2308,112 @@ let reg_decls (Defs defs) = Nexp_aux (Nexp_constant (if is_inc then 63 else 0),Unknown), [],[]) in - let regspp = - separate_map - (break 1) - (fun (typ,names) -> - let typ = match typ with Some typ -> "register_" ^ typ | None -> "register" in - (prefix 2 1) - (separate space [string "type"; string typ; equals]) - (separate_map space (fun reg -> pipe ^^ space ^^ string reg) names) - ) regs_per_type in - - let regfieldspp = - separate_map - (break 1) - (fun (typ,(_,_,rsranges,rsbits)) -> - (if rsranges = [] then empty else - (prefix 2 1) - (separate space [string "type"; string ("register_field_" ^ typ); equals]) - (separate_map space (fun (name,_,_) -> pipe ^^ space ^^ string name) rsranges) - ) ^/^ - (if rsranges = [] then empty else - (prefix 2 1) - (separate space [string "type"; string ("register_field_bit_" ^ typ); equals]) - (separate_map space (fun (name,_) -> pipe ^^ space ^^ string name) rsbits) - ) - ) - regtypes in + let regs_pp = + (prefix 2 1) + (separate space [string "type";string "register";equals]) + (separate_map space (fun (reg,_) -> pipe ^^ space ^^ string reg) regs) in + + let regfields_pp = + (prefix 2 1) + (separate space [string "type";string "register_field";equals]) + (separate_map space (fun (fname,tname,_,_) -> + pipe ^^ space ^^ string (tname ^ "_" ^ fname)) rsranges) in - let statepp = + let regfieldsbit_pp = + (prefix 2 1) + (separate space [string "type";string "register_field_bit";equals]) + (separate_map space (fun (fname,tname,_) -> + pipe ^^ space ^^ string (tname ^ "_" ^ fname)) rsbits) in + + let state_pp = (prefix 2 1) (separate space [string "type";string "state";equals]) (anglebars ((separate_map (semi ^^ break 1)) - (fun reg -> separate space [string (String.lowercase reg);colon;string "vector bit"]) - (simpleregs @ List.map fst typedregs) + (fun (reg,_) -> separate space [string (String.lowercase reg);colon;string "vector bit"]) + regs )) in let length_pp = - (separate_map (break 1)) - (fun (typ,regs) -> - let ((n1,n2,_,_),typname) = - match typ with - | Some typname -> (List.assoc typname regtypes,"register_" ^ typname) - | None -> (default,"register") in - ((separate space) - [string "let";string ("length_" ^ typname);underscore;equals; - string "natFromInteger"; - parens ( - (separate space) - [string "abs";parens (separate space [doc_nexp n2;minus;doc_nexp n1]); - plus;string "1"] - )]) - ) regs_per_type in - - let read_register_pp = - separate_map - (break 1) - (fun (typ,regs) -> - let typ = match typ with Some typ -> "register_" ^ typ | None -> "register" in - (prefix 2 1) - (separate space [string "let";string ("read_" ^ typ);string "reg";string "s";equals]) - ((prefix 2 1) - ((separate space) [string "let";string "v";equals]) - (string "match reg with" ^/^ - ((separate_map (break 1)) - (fun name -> - separate space [pipe;string name;arrow; - string "s." ^^ (string (String.lowercase name))]) - regs) ^/^ - string "end in" ) - ^^ hardline ^^ string "(v,s)" ^^ hardline) - ) regs_per_type - in - - let write_register_pp = - separate_map - (break 1) - (fun (typ,regs) -> - let typ = match typ with Some typ -> "register_" ^ typ | None -> "register" in - (prefix 2 1) - (separate space [string "let";string ("write_" ^ typ);string "reg";string "v";string "s";equals]) - (string "match reg with" ^/^ - ((separate_map (break 1)) - (fun name -> - separate - space - [pipe;string name;arrow; - parens (string "()" ^^ comma ^^ - anglebars ( - (separate space) [string "s";string"with";string (String.lowercase name); - equals;string "v"] - )) - ]) - regs) ^/^ - string "end" ^^ hardline ) - ) regs_per_type - in + (prefix 2 1) + (separate space [string "let";string "length_reg";string "reg";equals]) + ( + (prefix 2 1) + (separate space [string "let";string "v";equals;string "match reg with"]) + ((separate_map (break 1)) + (fun (name,typ) -> + let ((n1,n2,_,_),typname) = + match typ with + | Some typname -> (List.assoc typname regtypes,"register_" ^ typname) + | None -> (default,"register") in + separate space [pipe;string name;arrow;string "abs"; + parens (separate space [doc_nexp n2;minus;doc_nexp n1]); + plus;string "1"]) + regs) ^/^ + string "end in" ^/^ + string "natFromInteger v") in + + let field_indices_pp = + (prefix 2 1) + ((separate space) [string "let";string "field_indices"; + colon;string "register_field";arrow;string "(nat * nat)"; + equals;string "function"]) + ( + ((separate_map (break 1)) + (fun (fname,tname,i,j) -> + separate space[pipe;string ((String.capitalize tname) ^ "_" ^ fname);arrow; + parens (separate comma [string (string_of_int i); + string (string_of_int j)])] + ) rsranges + ) ^/^ string "end" ^^ hardline + ) in + +let field_index_bit_pp = + (prefix 2 1) + ((separate space) [string "let";string "field_index_bit"; + colon;string "register_field_bit";arrow;string "nat"; + equals;string "function"]) + ( + ((separate_map (break 1)) + (fun (fname,tname,i) -> + separate space[pipe;string ((String.capitalize tname) ^ "_" ^ fname); + arrow;string (string_of_int i)] + ) rsbits + ) ^/^ string "end" ^^ hardline + ) in + + + let read_regstate_pp = + (prefix 2 1) + (separate space [string "let";string "read_regstate";string "s";equals;string "function"]) + ( + ((separate_map (break 1)) + (fun (name,_) -> + separate space [pipe;string name;arrow;string "s." ^^ (string (String.lowercase name))]) + regs) ^/^ + string "end" ^^ hardline ) in + + let write_regstate_pp = + (prefix 2 1) + (separate space [string "let";string "write_regstate";string "s";string "reg";string "v"; + equals;string "match reg with"]) + ( + ((separate_map (break 1)) + (fun (name,_) -> + separate + space + [pipe;string name;arrow; + anglebars + ((separate space) + [string "s";string"with";string (String.lowercase name);equals;string "v"] + )] + ) regs) ^/^ + string "end" ^^ hardline ) in - let read_register_field_pp = - separate_map - (break 1) - (fun (typ,(n1,n2,rsranges,rsbits)) -> - (if rsranges = [] then empty else - ((prefix 2 1) - ((separate space) [string "let";string ("read_register_field_" ^ typ); - string "reg";string "rfield";equals]) - (string "match rfield with" ^/^ - ((separate_map (break 1)) - (fun (name,i,j) -> - (separate space) - [pipe;string name;arrow; - string ("read_register_" ^ typ ^ " reg"); - string ">>=";string "fun v";arrow; - string "return"; - parens ( - (separate space) - [string "read_vector_subrange"; - string "is_inc"; - string "v"; - string (string_of_int i); - string (string_of_int j)] - ) - ]; - ) - rsranges) ^/^ - string "end" - ) ^^ hardline - ) ^/^ - hardline - ) ^^ - (if rsbits = [] then empty else - (prefix 2 1) - ((separate space) [string "let";string ("read_register_field_bit_" ^ typ); - string "reg";string "rfield";equals]) - (string "match rfield with" ^/^ - ((separate_map (break 1)) - (fun (name,i) -> - (separate space) - [pipe;string name;arrow; - string ("read_register_" ^ typ ^ " reg"); - string ">>=";string "fun v";arrow; - string "return"; - parens ( - (separate space) - [string "vector_access"; - string "is_inc"; - string "v"; - string (string_of_int i)] - ) - ]; - ) - rsbits) ^/^ - string "end" - ) ^^ hardline - ) - ) regtypes - in - - let write_register_field_pp = - separate_map - (break 1) - (fun (typ,(n1,n2,rsranges,rsbits)) -> - (if rsranges = [] then empty else - (prefix 2 1) - (separate space [string "let";string ("write_register_field_" ^ typ);string "reg"; - string "rfield";string "v";string "s";equals]) - (string "match (reg,rfield) with" ^/^ - (separate_map (break 1)) - (fun regname -> - ((separate_map (break 1)) - (fun (fieldname,i,j) -> - (prefix 2 1) - (separate space [pipe;parens (string regname ^^ comma ^^ string fieldname);arrow]) - (parens - (string "()" ^^ comma ^^ - anglebars ( - (separate space) - [string "s";string"with";string (String.lowercase regname);equals; - string "write_vector_subrange";string "is_inc"; - string "s." ^^ string (String.lowercase regname); - string (string_of_int i); string (string_of_int j);string "v"] - ) - ) - ) - ) rsranges - ) - ) (List.assoc typ typedregs_per_type) ^/^ - string "end" ^^ hardline - ) ^/^ hardline) ^^ - (if rsbits = [] then empty else - (prefix 2 1) - (separate space [string "let";string ("write_register_field_bit_" ^ typ);string "reg"; - string "rfield";string "v";string "s";equals]) - (string "match (reg,rfield) with" ^/^ - (separate_map (break 1)) - (fun regname -> - ((separate_map (break 1)) - (fun (fieldname,i) -> - (prefix 2 1) - (separate space [pipe;parens (string regname ^^ comma ^^ string fieldname);arrow]) - (parens - (string "()" ^^ comma ^^ - anglebars ( - (separate space) - [string "s";string"with";string (String.lowercase regname);equals; - string "write_vector_bit";string "is_inc"; - string "s." ^^ string (String.lowercase regname); - string (string_of_int i);string "v"] - ) - ) - ) - ) rsbits - ) - ) (List.assoc typ typedregs_per_type) ^/^ - string "end" ^^ hardline - ) ^^ hardline) - ) regtypes - in - (separate (hardline ^^ hardline) - [dirpp;regspp;regfieldspp;statepp;length_pp;read_register_pp;write_register_pp; - read_register_field_pp;write_register_field_pp],defs) - - + [dir_pp;regs_pp;regfields_pp;regfieldsbit_pp;field_index_bit_pp;field_indices_pp; + state_pp;length_pp;read_regstate_pp;write_regstate_pp],defs) + let doc_defs_lem (Defs defs) = let (decls,defs) = reg_decls (Defs defs) in (decls,separate_map empty doc_def_lem defs) @@ -2618,4 +2430,4 @@ let pp_defs_lem f_arch f d top_line opens = (string "(*" ^^ (string top_line) ^^ string "*)" ^/^ ((separate_map hardline) (fun lib -> separate space [string "open import";string lib]) - ["Pervasives";"State";"Vector"]) ^/^ decls) + ["Pervasives";"Vector"]) ^/^ decls) diff --git a/src/rewriter.ml b/src/rewriter.ml index 1f28bb34..c42942ac 100644 --- a/src/rewriter.ml +++ b/src/rewriter.ml @@ -1407,6 +1407,10 @@ let mktup_pat es = P_aux (P_tup pats,(Parse_ast.Unknown,simple_annot {t = Ttup typs})) +type 'a updated_term = + | Added_vars of 'a exp * 'a pat + | Same_vars of 'a exp + let rec rewrite_var_updates ((E_aux (expaux,annot)) as exp) = let rec add_vars (*overwrite*) ((E_aux (expaux,annot)) as exp) vars = @@ -1439,8 +1443,8 @@ let rec rewrite_var_updates ((E_aux (expaux,annot)) as exp) = else*) E_aux (E_tuple [exp;vars],swaptyp {t = Ttup [gettype exp;gettype vars]} annot) in - let rewrite (E_aux (eaux,annot)) (P_aux (_,pannot) as pat) = - match eaux with + let rewrite (E_aux (expaux,annot)) (P_aux (_,pannot) as pat) = + match expaux with | E_for(id,exp1,exp2,exp3,order,exp4) -> let vars = List.map (fun (var,(l,t)) -> E_aux (E_id var,(l,t))) (find_updated_vars exp4) in let vartuple = mktup vars in @@ -1448,21 +1452,41 @@ let rec rewrite_var_updates ((E_aux (expaux,annot)) as exp) = | {t = Tid "unit"} -> true | _ -> false in*) let exp4 = rewrite_var_updates (add_vars (*overwrite*) exp4 vartuple) in - let funcl = match order with - | Ord_aux (Ord_inc,_) -> Id_aux (Id "foreach_inc",Unknown) - | Ord_aux (Ord_dec,_) -> Id_aux (Id "foreach_dec",Unknown) in - let v = E_aux (E_app (funcl,[mktup [exp1;exp2;exp3];exp4;vartuple]), + let orderb = match order with + | Ord_aux (Ord_inc,_) -> true + | Ord_aux (Ord_dec,_) -> false in + let funcl = match effectful exp4 with + | false -> Id_aux (Id (if orderb then "foreach_inc" else "foreach_dec"),Unknown) + | true -> Id_aux (Id (if orderb then "foreachM_inc" else "foreachM_dec"),Unknown) in + let loopvar = + let (bf,tf) = match gettype exp1 with + | {t = Tapp ("atom",[TA_nexp f])} -> (TA_nexp f,TA_nexp f) + | {t = Tapp ("reg", [TA_typ {t = Tapp ("atom",[TA_nexp f])}])} -> (TA_nexp f,TA_nexp f) + | {t = Tapp ("range",[TA_nexp bf;TA_nexp tf])} -> (TA_nexp bf,TA_nexp tf) + | {t = Tapp ("reg", [TA_typ {t = Tapp ("range",[TA_nexp bf;TA_nexp tf])}])} -> (TA_nexp bf,TA_nexp tf) + | {t = Tapp (name,_)} -> failwith (name ^ " shouldn't be here") in + let (bt,tt) = match gettype exp2 with + | {t = Tapp ("atom",[TA_nexp t])} -> (TA_nexp t,TA_nexp t) + | {t = Tapp ("atom",[TA_typ {t = Tapp ("atom", [TA_nexp t])}])} -> (TA_nexp t,TA_nexp t) + | {t = Tapp ("range",[TA_nexp bt;TA_nexp tt])} -> (TA_nexp bt,TA_nexp tt) + | {t = Tapp ("atom",[TA_typ {t = Tapp ("range",[TA_nexp bt;TA_nexp tt])}])} -> (TA_nexp bt,TA_nexp tt) + | {t = Tapp (name,_)} -> failwith (name ^ " shouldn't be here") in + let t = {t = Tapp ("range",if orderb then [bf;tt] else [tf;bt])} in + E_aux (E_id id,(Unknown,simple_annot t)) in + let v = E_aux (E_app (funcl,[loopvar;mktup [exp1;exp2;exp3];exp4;vartuple]), (Unknown,simple_annot_efr (gettype exp4) (geteffs exp4))) in let pat = (* if overwrite then mktup_pat vars else *) P_aux (P_tup [pat; mktup_pat vars], (Unknown,simple_annot (gettype v))) in - Some (v,pat) + Added_vars (v,pat) | E_if (c,e1,e2) -> let vars = List.map (fun (var,(l,t)) -> E_aux (E_id var,(l,t))) (dedup eqidtyp (find_updated_vars e1 @ find_updated_vars e2)) in - if vars = [] then None else + if vars = [] then + (Same_vars (E_aux (E_if (c,rewrite_var_updates e1,rewrite_var_updates e2),annot))) + else let vartuple = mktup vars in (* let overwrite = match gettype exp with | {t = Tid "unit"} -> true @@ -1478,14 +1502,17 @@ let rec rewrite_var_updates ((E_aux (expaux,annot)) as exp) = mktup_pat vars else*) P_aux (P_tup [pat; mktup_pat vars],(Unknown,simple_annot (gettype v))) in - Some (v,pat) + Added_vars (v,pat) | E_case (e1,ps) -> (* after a-normalisation e1 shouldn't need rewriting *) let vars = let f acc (Pat_aux (Pat_exp (_,e),_)) = acc @ find_updated_vars e in List.map (fun (var,(l,t)) -> E_aux (E_id var,(l,t))) (dedup eqidtyp (List.fold_left f [] ps)) in - if vars = [] then None else + if vars = [] then + let ps = List.map (fun (Pat_aux (Pat_exp (p,e),a)) -> Pat_aux (Pat_exp (p,rewrite_var_updates e),a)) ps in + Same_vars (E_aux (E_case (e1,ps),annot)) + else let vartuple = mktup vars in (* let overwrite = match gettype exp with | {t = Tid "unit"} -> true @@ -1514,33 +1541,34 @@ let rec rewrite_var_updates ((E_aux (expaux,annot)) as exp) = P_aux (P_tup [mktup_pat vars],(Unknown,simple_annot (gettype v))) else*) P_aux (P_tup [pat; mktup_pat vars],(Unknown,simple_annot (gettype v))) in - Some (v,pat) + Added_vars (v,pat) | E_assign (lexp,vexp) -> let {effect = Eset effs} = geteffs_annot annot in - if not (List.exists (function BE_aux (BE_lset,_) -> true | _ -> false) effs) - then None else + if not (List.exists (function BE_aux (BE_lset,_) -> true | _ -> false) effs) then + Same_vars (E_aux (E_assign (lexp,vexp),annot)) + else (match lexp with | LEXP_aux (LEXP_id id,annot) -> let pat = P_aux (P_id id,(Unknown,simple_annot (gettype vexp))) in - Some (vexp,pat) + Added_vars (vexp,pat) | LEXP_aux (LEXP_cast (_,id),annot) -> let pat = P_aux (P_id id,(Unknown,simple_annot (gettype vexp))) in - Some (vexp,pat) + Added_vars (vexp,pat) | LEXP_aux (LEXP_vector (LEXP_aux (LEXP_id id,annot2),i),annot) -> let eid = E_aux (E_id id,(Unknown,simple_annot (gettype_annot annot2))) in let vexp = E_aux (E_vector_update (eid,i,vexp),(Unknown,simple_annot (gettype_annot annot))) in let pat = P_aux (P_id id,(Unknown,simple_annot (gettype vexp))) in - Some (vexp,pat) + Added_vars (vexp,pat) | LEXP_aux (LEXP_vector_range (LEXP_aux (LEXP_id id,annot2),i,j),annot) -> let eid = E_aux (E_id id,(Unknown,simple_annot (gettype_annot annot2))) in let vexp = E_aux (E_vector_update_subrange (eid,i,j,vexp), (Unknown,simple_annot (gettype_annot annot))) in let pat = P_aux (P_id id,(Unknown,simple_annot (gettype vexp))) in - Some (vexp,pat)) + Added_vars (vexp,pat)) | _ -> (* assumes everying's a-normlised: an expression is a sequence of let-expressions, * "control-flow" structures and a return value, possibly wrapped in E_return *) - None in + Same_vars (E_aux (expaux,annot)) in match expaux with | E_let (lb,body) -> @@ -1548,41 +1576,35 @@ let rec rewrite_var_updates ((E_aux (expaux,annot)) as exp) = let (eff,lb) = match lb with | LB_aux (LB_val_implicit (pat,v),lbannot) -> (match rewrite v pat with - | Some (v,pat) -> + | Added_vars (v,pat) -> let lbannot = (Parse_ast.Unknown,simple_annot (gettype v)) in (geteffs v,LB_aux (LB_val_implicit (pat,v),lbannot)) - | None -> (geteffs v,lb)) + | Same_vars v -> (geteffs v,LB_aux (LB_val_implicit (pat,v),lbannot))) | LB_aux (LB_val_explicit (typ,pat,v),lbannot) -> (match rewrite v pat with - | Some (v,pat) -> + | Added_vars (v,pat) -> let lbannot = (Parse_ast.Unknown,simple_annot (gettype v)) in (geteffs v,LB_aux (LB_val_implicit (pat,v),lbannot)) - | None -> (geteffs v,lb)) in + | Same_vars v -> (geteffs v,LB_aux (LB_val_explicit (typ,pat,v),lbannot))) in E_aux (E_let (lb,body), (Unknown,simple_annot_efr (gettype body) (union_effects eff (geteffs body)))) | E_internal_plet (pat,v,body) -> let body = rewrite_var_updates body in (match rewrite v pat with - | Some (v,pat) -> + | Added_vars (v,pat) -> E_aux (E_internal_plet (pat,v,body), (Unknown,simple_annot_efr (gettype body) (eff_union v body))) - | None -> E_aux (E_internal_plet (pat,v,body),annot)) + | Same_vars v -> E_aux (E_internal_plet (pat,v,body),annot)) | E_internal_let (lexp,v,body) -> - (* because we need patterns and internal_plets are needed to distinguish monadic - * expressions E_internal_lets are rewritten to E_lets. We only need them for OCaml - * anyways. *) + (* After a-normalisation E_internal_lets can only bind values to names, those don't + * need rewriting. *) let body = rewrite_var_updates body in let id = match lexp with | LEXP_aux (LEXP_id id,_) -> id | LEXP_aux (LEXP_cast (_,id),_) -> id in let pat = P_aux (P_id id, (Parse_ast.Unknown,simple_annot (gettype v))) in - let lb = (match rewrite v pat with - | Some (v,pat) -> - let lbannot = (Parse_ast.Unknown,simple_annot_efr (gettype v) (geteffs v)) in - LB_aux (LB_val_implicit (pat,v),lbannot) - | None -> - let lbannot = (Parse_ast.Unknown,simple_annot_efr (gettype v) (geteffs v)) in - LB_aux (LB_val_implicit (pat,v),lbannot)) in + let lbannot = (Parse_ast.Unknown,simple_annot_efr (gettype v) (geteffs v)) in + let lb = LB_aux (LB_val_implicit (pat,v),lbannot) in E_aux (E_let (lb,body),(Unknown,simple_annot_efr (gettype body) (eff_union v body))) (* In tail-position there shouldn't be anything we need to do as the terms after * a-normalisation are pure and don't update local variables. There can't be any variable @@ -1659,17 +1681,57 @@ let replace_var_update_e_assign = { id_exp_alg with e_aux = e_aux } *) -let replace_memwrite_e_assign = +let replace_memwrite_e_assign exp = let e_aux = fun (expaux,annot) -> match expaux with | E_assign (LEXP_aux (LEXP_memory (id,args),_),v) -> E_aux (E_app (id,args @ [v]),annot) | _ -> E_aux (expaux,annot) in - { id_exp_alg with e_aux = e_aux } + fold_exp { id_exp_alg with e_aux = e_aux } exp + + + +let remove_reference_types exp = + + let rec rewrite_t {t = t_aux} = {t = rewrite_t_aux t_aux} + and rewrite_t_aux t_aux = match t_aux with + | Tapp ("reg",[TA_typ {t = t_aux2}]) -> rewrite_t_aux t_aux2 + | Tapp (name,t_args) -> Tapp (name,List.map rewrite_t_arg t_args) + | Tfn (t1,t2,imp,e) -> Tfn (rewrite_t t1,rewrite_t t2,imp,e) + | Ttup ts -> Ttup (List.map rewrite_t ts) + | Tabbrev (t1,t2) -> Tabbrev (rewrite_t t1,rewrite_t t2) + | Toptions (t1,t2) -> + let t2 = match t2 with Some t2 -> Some (rewrite_t t2) | None -> None in + Toptions (rewrite_t t1,t2) + | Tuvar t_uvar -> Tuvar t_uvar (*(rewrite_t_uvar t_uvar) *) + | _ -> t_aux +(* and rewrite_t_uvar t_uvar = + t_uvar.subst <- (match t_uvar.subst with None -> None | Some t -> Some (rewrite_t t)) *) + and rewrite_t_arg t_arg = match t_arg with + | TA_typ t -> TA_typ (rewrite_t t) + | _ -> t_arg in + + let rec rewrite_annot = function + | NoTyp -> NoTyp + | Base ((tparams,t),tag,nexprs,effs,effsum,bounds) -> + Base ((tparams,rewrite_t t),tag,nexprs,effs,effsum,bounds) + | Overload (tannot1,b,tannots) -> + Overload (rewrite_annot tannot1,b,List.map rewrite_annot tannots) in + + + fold_exp + { id_exp_alg with + e_aux = (fun (e,(l,annot)) -> E_aux (e,(l,rewrite_annot annot))) + ; lEXP_aux = (fun (lexp,(l,annot)) -> LEXP_aux (lexp,(l,rewrite_annot annot))) + ; fE_aux = (fun (fexp,(l,annot)) -> FE_aux (fexp,(l,(rewrite_annot annot)))) + ; fES_aux = (fun (fexp,(l,annot)) -> FES_aux (fexp,(l,rewrite_annot annot))) + ; pat_aux = (fun (pexp,(l,annot)) -> Pat_aux (pexp,(l,rewrite_annot annot))) + ; lB_aux = (fun (lb,(l,annot)) -> LB_aux (lb,(l,rewrite_annot annot))) + } + exp let rewrite_defs_remove_e_assign = let rewrite_exp _ _ e = - (fold_exp replace_memwrite_e_assign) - ((rewrite_var_updates e)) in + replace_memwrite_e_assign (remove_reference_types (rewrite_var_updates e)) in rewrite_defs_base {rewrite_exp = rewrite_exp ; rewrite_pat = rewrite_pat |
