summaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
authorChristopher Pulte2015-11-25 10:58:46 +0000
committerChristopher Pulte2015-11-25 10:58:46 +0000
commitda258def4f0253c218cdcfef7d144bc256bf4ba5 (patch)
tree369ace633e533a300eb23cd68e9b70ce0da3f455 /src
parentdab6dc6a99f1b68ee701d21050dd6f86818aa525 (diff)
fixes, pp
Diffstat (limited to 'src')
-rw-r--r--src/gen_lib/sail_values.lem188
-rw-r--r--src/pretty_print.ml137
-rw-r--r--src/rewriter.ml20
3 files changed, 225 insertions, 120 deletions
diff --git a/src/gen_lib/sail_values.lem b/src/gen_lib/sail_values.lem
index 33b8444e..aa0578b2 100644
--- a/src/gen_lib/sail_values.lem
+++ b/src/gen_lib/sail_values.lem
@@ -138,72 +138,107 @@ let to_vec_undef is_inc (len : integer) =
let to_vec_inc_undef = to_vec_undef true
let to_vec_dec_undef = to_vec_undef false
-let add = uncurry integerAdd
-let add_signed = uncurry integerAdd
-let minus = uncurry integerMinus
-let multiply = uncurry integerMult
-let modulo = uncurry integerMod
-let quot = uncurry integerDiv
-let power = uncurry integerPow
-
-let arith_op_vec op sign (size : integer) ((V _ _ is_inc as l),r) =
+
+let add = integerAdd
+let add_signed = integerAdd
+let minus = integerMinus
+let multiply = integerMult
+let modulo = integerMod
+let quot = integerDiv
+let power = integerPow
+
+let arith_op_vec op sign (size : integer) (V _ _ is_inc as l) r =
let (l',r') = (to_num sign l, to_num sign r) in
let n = op l' r' in
to_vec is_inc (size * (length l),n)
-let add_vec = arith_op_vec integerAdd false 1
-let add_vec_signed = arith_op_vec integerAdd true 1
-let minus_vec = arith_op_vec integerMinus false 1
-let multiply_vec = arith_op_vec integerMult false 2
-let multiply_vec_signed = arith_op_vec integerMult true 2
-
-let arith_op_vec_range op sign size ((V _ _ is_inc as l),r) =
- arith_op_vec op sign size (l, to_vec is_inc (length l,r))
-
-let add_vec_range = arith_op_vec_range integerAdd false 1
-let add_vec_range_signed = arith_op_vec_range integerAdd true 1
-let minus_vec_range = arith_op_vec_range integerMinus false 1
-let mult_vec_range = arith_op_vec_range integerMult false 2
-let mult_vec_range_signed = arith_op_vec_range integerMult true 2
-
-let arith_op_range_vec op sign size (l,(V _ _ is_inc as r)) =
- arith_op_vec op sign size (to_vec is_inc (length r, l), r)
-
-let add_range_vec = arith_op_range_vec integerAdd false 1
-let add_range_vec_signed = arith_op_range_vec integerAdd true 1
-let minus_range_vec = arith_op_range_vec integerMinus false 1
-let mult_range_vec = arith_op_range_vec integerMult false 2
-let mult_range_vec_signed = arith_op_range_vec integerMult true 2
-let arith_op_range_vec_range op sign (l,r) = uncurry op (l, to_num sign r)
-
-let add_range_vec_range = arith_op_range_vec_range integerAdd false
-let add_range_vec_range_signed = arith_op_range_vec_range integerAdd true
-let minus_range_vec_range = arith_op_range_vec_range integerMinus false
-
-let arith_op_vec_range_range op sign (l,r) = uncurry op (to_num sign l,r)
-
-let add_vec_range_range = arith_op_vec_range_range integerAdd false
-let add_vec_range_range_signed = arith_op_vec_range_range integerAdd true
-let minus_vec_range_range = arith_op_vec_range_range integerMinus false
-
-let arith_op_vec_vec_range op sign ((V _ _ is_inc as l),r) =
+(* add_vec
+ * add_vec_signed
+ * minus_vec
+ * multiply_vec
+ * multiply_vec_signed
+ *)
+let add_VVV = arith_op_vec integerAdd false 1
+let addS_VVV = arith_op_vec integerAdd true 1
+let minus_VVV = arith_op_vec integerMinus false 1
+let mult_VVV = arith_op_vec integerMult false 2
+let multS_VVV = arith_op_vec integerMult true 2
+
+let arith_op_vec_range op sign size (V _ _ is_inc as l) r =
+ arith_op_vec op sign size l (to_vec is_inc (length l,r))
+
+(* add_vec_range
+ * add_vec_range_signed
+ * minus_vec_range
+ * mult_vec_range
+ * mult_vec_range_signed
+ *)
+let add_VIV = arith_op_vec_range integerAdd false 1
+let addS_VIV = arith_op_vec_range integerAdd true 1
+let minus_VIV = arith_op_vec_range integerMinus false 1
+let mult_VIV = arith_op_vec_range integerMult false 2
+let multS_VIV = arith_op_vec_range integerMult true 2
+
+let arith_op_range_vec op sign size l (V _ _ is_inc as r) =
+ arith_op_vec op sign size (to_vec is_inc (length r, l)) r
+
+(* add_range_vec
+ * add_range_vec_signed
+ * minus_range_vec
+ * mult_range_vec
+ * mult_range_vec_signed
+ *)
+let add_IVV = arith_op_range_vec integerAdd false 1
+let addS_IVV = arith_op_range_vec integerAdd true 1
+let minus_IVV = arith_op_range_vec integerMinus false 1
+let mult_IVV = arith_op_range_vec integerMult false 2
+let multS_IVV = arith_op_range_vec integerMult true 2
+
+let arith_op_range_vec_range op sign l r = op l (to_num sign r)
+
+(* add_range_vec_range
+ * add_range_vec_range_signed
+ * minus_range_vec_range
+ *)
+let add_IVI = arith_op_range_vec_range integerAdd false
+let addS_IVI = arith_op_range_vec_range integerAdd true
+let minus_IVI = arith_op_range_vec_range integerMinus false
+
+let arith_op_vec_range_range op sign l r = op (to_num sign l) r
+
+(* add_vec_range_range
+ * add_vec_range_range_signed
+ * minus_vec_range_range
+ *)
+let add_VII = arith_op_vec_range_range integerAdd false
+let addS_VII = arith_op_vec_range_range integerAdd true
+let minus_VII = arith_op_vec_range_range integerMinus false
+
+let arith_op_vec_vec_range op sign (V _ _ is_inc as l) r =
let (l',r') = (to_num sign l,to_num sign r) in
op l' r'
-let add_vec_vec_range = arith_op_vec_vec_range integerAdd false
-let add_vec_vec_range_signed = arith_op_vec_vec_range integerAdd true
+(* add_vec_vec_range
+ * add_vec_vec_range_signed
+ *)
+let add_VVI = arith_op_vec_vec_range integerAdd false
+let addS_VVI = arith_op_vec_vec_range integerAdd true
-let arith_op_vec_bit op sign (size : integer) ((V _ _ is_inc as l),r) =
+let arith_op_vec_bit op sign (size : integer) (V _ _ is_inc as l)r =
let l' = to_num sign l in
- let n = op l' match r with | I -> (1 : integer) | _ -> 0 end in
+ let n = op l' (match r with | I -> (1 : integer) | _ -> 0 end) in
to_vec is_inc (length l * size,n)
-let add_vec_bit = arith_op_vec_bit integerAdd false 1
-let add_vec_bit_signed = arith_op_vec_bit integerAdd true 1
-let minus_vec_bit = arith_op_vec_bit integerMinus true 1
-
-let rec arith_op_overflow_vec (op : integer -> integer -> integer) sign size ((V _ _ is_inc as l),r) =
+(* add_vec_bit
+ * add_vec_bit_signed
+ * minus_vec_bit_signed
+ *)
+let add_VBV = arith_op_vec_bit integerAdd false 1
+let addS_VBV = arith_op_vec_bit integerAdd true 1
+let minus_VBV = arith_op_vec_bit integerMinus true 1
+
+let rec arith_op_overflow_vec (op : integer -> integer -> integer) sign size (V _ _ is_inc as l) r =
let len = length l in
let act_size = len * size in
let (l_sign,r_sign) = (to_num sign l,to_num sign r) in
@@ -219,15 +254,22 @@ let rec arith_op_overflow_vec (op : integer -> integer -> integer) sign size ((V
let c_out = most_significant one_more_size_u in
(correct_size_num,overflow,c_out)
-let add_overflow_vec = arith_op_overflow_vec integerAdd false 1
-let add_overflow_vec_signed = arith_op_overflow_vec integerAdd true 1
-let minus_overflow_vec = arith_op_overflow_vec integerMinus false 1
-let minus_overflow_vec_signed = arith_op_overflow_vec integerMinus true 1
-let mult_overflow_vec = arith_op_overflow_vec integerMult false 2
-let mult_overflow_vec_signed = arith_op_overflow_vec integerMult true 2
+(* add_overflow_vec
+ * add_overflow_vec_signed
+ * minus_overflow_vec
+ * minus_overflow_vec_signed
+ * mult_overflow_vec
+ * mult_overflow_vec_signed
+ *)
+let addO_VVV = arith_op_overflow_vec integerAdd false 1
+let addSO_VVV = arith_op_overflow_vec integerAdd true 1
+let minusO_VVV = arith_op_overflow_vec integerMinus false 1
+let minusSO_VVV = arith_op_overflow_vec integerMinus true 1
+let multO_VVV = arith_op_overflow_vec integerMult false 2
+let multSO_VVV = arith_op_overflow_vec integerMult true 2
let rec arith_op_overflow_vec_bit (op : integer -> integer -> integer) sign (size : integer)
- ((V _ _ is_inc as l),r_bit) =
+ (V _ _ is_inc as l) r_bit =
let act_size = length l * size in
let l' = to_num sign l in
let l_u = to_num false l in
@@ -246,9 +288,13 @@ let rec arith_op_overflow_vec_bit (op : integer -> integer -> integer) sign (siz
else I in
(correct_size_num,overflow,most_significant one_larger)
-let add_overflow_vec_bit_signed = arith_op_overflow_vec_bit integerAdd true 1
-let minus_overflow_vec_bit = arith_op_overflow_vec_bit integerMinus false 1
-let minus_overflow_vec_bit_signed = arith_op_overflow_vec_bit integerMinus true 1
+(* add_overflow_vec_bit_signed
+ * minus_overflow_vec_bit
+ * minus_overflow_vec_bit_signed
+ *)
+let addSO_VBV = arith_op_overflow_vec_bit integerAdd true 1
+let minusO_VBV = arith_op_overflow_vec_bit integerMinus false 1
+let minusSO_VBV = arith_op_overflow_vec_bit integerMinus true 1
type shift = LL | RR | LLL
@@ -293,9 +339,9 @@ let rec arith_op_vec_no0 (op : integer -> integer -> integer) sign size (((V _ s
then to_vec is_inc (act_size,n')
else V (List.replicate (natFromInteger act_size) Undef) start is_inc
-let mod_vec = arith_op_vec_no0 integerMod false 1
-let quot_vec = arith_op_vec_no0 integerDiv false 1
-let quot_vec_signed = arith_op_vec_no0 integerDiv true 1
+let mod_VVV = arith_op_vec_no0 integerMod false 1
+let quot_VVV = arith_op_vec_no0 integerDiv false 1
+let quotS_VVV = arith_op_vec_no0 integerDiv true 1
let arith_op_overflow_no0_vec op sign size (((V _ start is_inc) as l),r) =
let rep_size = length r * size in
@@ -320,13 +366,13 @@ let arith_op_overflow_no0_vec op sign size (((V _ start is_inc) as l),r) =
let overflow = if representable then O else I in
(correct_size_num,overflow,most_significant one_more)
-let quot_overflow_vec = arith_op_overflow_no0_vec integerDiv false 1
-let quot_overflow_vec_signed = arith_op_overflow_no0_vec integerDiv true 1
+let quotO_VVV = arith_op_overflow_no0_vec integerDiv false 1
+let quotSO_VVV = arith_op_overflow_no0_vec integerDiv true 1
let arith_op_vec_range_no0 op sign size ((V _ _ is_inc as l),r) =
arith_op_vec_no0 op sign size (l,to_vec is_inc (length l,r))
-let mod_vec_range = arith_op_vec_range_no0 integerMod false 1
+let mod_VIV = arith_op_vec_range_no0 integerMod false 1
let duplicate (bit,length) =
V (List.replicate (natFromInteger length) bit) 0 true
@@ -352,7 +398,7 @@ let lt_vec_signed = compare_op_vec (<) true
let gt_vec_signed = compare_op_vec (>) true
let lteq_vec_signed = compare_op_vec (<=) true
let gteq_vec_signed = compare_op_vec (>=) true
-let lt_vec_unsignedp = compare_op_vec (<) false
+let lt_vec_unsigned = compare_op_vec (<) false
let gt_vec_unsigned = compare_op_vec (>) false
let lteq_vec_unsigned = compare_op_vec (<=) false
let gteq_vec_unsigned = compare_op_vec (>=) false
diff --git a/src/pretty_print.ml b/src/pretty_print.ml
index 2108c62f..0006f290 100644
--- a/src/pretty_print.ml
+++ b/src/pretty_print.ml
@@ -1992,20 +1992,19 @@ let doc_exp_lem, doc_let_lem =
prefix 2 1 f (separate (break 1) args)
| E_vector_append(l,r) ->
let epp =
- separate space [exp l;string "^^"] ^//^ exp r in
+ align (separate space [exp l;string "^^"] ^//^ exp r) in
if aexp_needed then parens epp else epp
| E_cons(l,r) -> doc_op (group (colon^^colon)) (exp l) (exp r)
| E_if(c,t,e) ->
let (E_aux (_,(_,cannot))) = c in
let epp =
- group (
- (match cannot with
- | Base ((_,({t = Tid "bit"})),_,_,_,_,_) ->
- separate space [string "if";string "to_bool";exp c]
- | _ -> separate space [string "if";exp c])
- ^^ break 1 ^^
- (prefix 2 1 (string "then") (top_exp false t)) ^^ (break 1) ^^
- (prefix 2 1 (string "else") (top_exp false e))) in
+ (match cannot with
+ | Base ((_,({t = Tid "bit"})),_,_,_,_,_) ->
+ separate space [string "if";string "to_bool";exp c]
+ | _ -> separate space [string "if";exp c])
+ ^^ break 1 ^^
+ (prefix 2 1 (string "then") (top_exp false t)) ^^ (break 1) ^^
+ (prefix 2 1 (string "else") (top_exp false e)) in
if aexp_needed then parens epp else epp
| E_for(id,exp1,exp2,exp3,(Ord_aux(order,_)),exp4) ->
failwith "E_for should have been removed till now"
@@ -2049,7 +2048,7 @@ let doc_exp_lem, doc_let_lem =
if aexp_needed then parens epp else epp
| Base (_,External (Some "bitwise_not_bit"),_,_,_,_) ->
let [a] = args in
- let epp = string "~" ^^ exp a in
+ let epp = align (string "~" ^^ exp a) in
if aexp_needed then parens epp else epp
| _ ->
let call = match annot with
@@ -2059,10 +2058,14 @@ let doc_exp_lem, doc_let_lem =
| Base(_,Constructor _,_,_,_,_) -> doc_id_lem_ctor false f
| _ -> doc_id_lem f in
let epp =
- (doc_unop call)
- (match args with
- | [a] -> exp a
- | args -> parens (separate_map comma (top_exp false) args)) in
+ align
+ (call ^//^
+ (match args with
+ | [a] -> exp a
+ | args -> (parens (separate_map (comma ^^ break 1) exp args))
+ )
+ )
+ in
if aexp_needed then parens epp else epp
)
)
@@ -2136,11 +2139,11 @@ let doc_exp_lem, doc_let_lem =
| E_cast(typ,e) ->
(match annot with
| Base(_,External _,_,_,_,_) -> string "read_reg" ^^ space ^^ exp e
- | _ -> exp e) (*(parens (doc_op colon (group (exp e)) (doc_typ_lem typ)))) *)
+ | _ -> top_exp aexp_needed e) (*(parens (doc_op colon (group (exp e)) (doc_typ_lem typ)))) *)
| E_tuple exps ->
(match exps with
- | [e] -> exp e
- | _ -> parens (separate_map comma exp exps))
+ | [e] -> top_exp aexp_needed e
+ | _ -> parens (separate_map comma (top_exp false) exps))
| E_record(FES_aux(FES_Fexps(fexps,_),_)) ->
anglebars (separate_map semi_sp doc_fexp fexps)
| E_record_update(e,(FES_aux(FES_Fexps(fexps,_),_))) ->
@@ -2158,9 +2161,18 @@ let doc_exp_lem, doc_let_lem =
| Nconst i -> string_of_big_int i
| N2n(_,Some i) -> string_of_big_int i
| _ -> if dir then "0" else string_of_int (List.length exps) in
-
- let epp = group (separate space [string "V"; brackets (separate_map (semi) exp exps);
- string start;string dir_out]) in
+ let expspp =
+ match exps with
+ | [] -> empty
+ | e :: es ->
+ let (expspp,_) =
+ List.fold_left
+ (fun (pp,count) e ->
+ (pp ^^ semi ^^ (if count = 20 then break 0 else empty) ^^ top_exp false e),
+ if count = 20 then 0 else count + 1)
+ (top_exp false e,0) exps in
+ align (group expspp) in
+ let epp = group (separate space [string "V"; brackets expspp;string start;string dir_out]) in
if aexp_needed then parens epp else epp
)
| E_vector_indexed (iexps, (Def_val_aux (default,(dl,dannot)))) ->
@@ -2200,7 +2212,7 @@ let doc_exp_lem, doc_let_lem =
| _ ->
raise (Reporting_basic.err_unreachable dl "nono") in
parens (string "Just " ^^ parens (string ("UndefinedReg " ^ string_of_big_int n)))) in
- let iexp (i,e) = parens (separate_map comma (fun x -> x) [(doc_int i); (exp e)]) in
+ let iexp (i,e) = parens (separate_map comma (fun x -> x) [(doc_int i); top_exp false e]) in
let epp =
(separate space)
[call;(brackets (separate_map semi iexp iexps));
@@ -2215,10 +2227,10 @@ let doc_exp_lem, doc_let_lem =
let epp = separate space [string "update";exp v;exp e1;exp e2;exp e3] in
if aexp_needed then parens epp else epp
| E_list exps ->
- brackets (separate_map semi exp exps)
+ brackets (separate_map semi (top_exp false) exps)
| E_case(e,pexps) ->
let epp =
- (prefix 2 1)
+ (prefix 0 1)
(separate space [string "match"; exp e; string "with"])
(separate_map (break 1) doc_case pexps) ^^ (break 1) ^^
(string "end" ^^ (break 1)) in
@@ -2228,36 +2240,71 @@ let doc_exp_lem, doc_let_lem =
| E_app_infix (e1,id,e2) ->
(match annot with
| Base((_,t),External(Some name),_,_,_,_) ->
- let epp = match name with
- | "bitwise_and_bit" -> separate space [exp e1;string "&."] ^//^ exp e2
- | "bitwise_or_bit" -> separate space [exp e1;string "|."] ^//^ exp e2
- | "bitwise_xor_bit" -> separate space [exp e1;string "+."] ^//^ exp e2
- | "add" -> separate space [exp e1;string "+";exp e2]
- | "minus" -> separate space [exp e1;string "-";exp e2]
- | "multiply" -> separate space [exp e1;string "*";exp e2]
- (* | "lt" -> separate space [exp e1;string "<";exp e2]
- | "gt" -> separate space [exp e1;string ">";exp e2]
- | "lteq" -> separate space [exp e1;string "<=";exp e2]
- | "gteq" -> separate space [exp e1;string ">=";exp e2]
- | "lt_vec" -> separate space [exp e1;string "<";exp e2]
- | "gt_vec" -> separate space [exp e1;string ">";exp e2]
- | "lteq_vec" -> separate space [exp e1;string "<=";exp e2]
- | "gteq_vec" -> separate space [exp e1;string ">=";exp e2] *)
- | _ -> separate space [string name; parens (separate_map comma (top_exp false) [e1;e2])] in
- if aexp_needed then parens epp else epp
+ let epp =
+ let aux name = exp e1 ^//^ string name ^//^ exp e2 in
+ let aux2 name = string name ^//^ exp e1 ^//^ exp e2 in
+ align
+ (match name with
+ | "bitwise_and_bit" -> aux "&."
+ | "bitwise_or_bit" -> aux "|."
+ | "bitwise_xor_bit" -> aux "+."
+ | "add" -> aux "+"
+ | "minus" -> aux "-"
+ | "multiply" -> aux "*"
+ | "quot" -> aux "/"
+ | "modulo" -> aux "(mod)"
+
+ | "add_vec" -> aux2 "add_VVV"
+ | "add_vec_signed" -> aux2 "addS_VVV"
+ | "minus_vec" -> aux2 "minus_VVV"
+ | "multiply_vec" -> aux2 "mult_VVV"
+ | "multiply_vec_signed" -> aux2 "multS_VVV"
+ | "add_vec_range" -> aux2 "add_VIV"
+ | "add_vec_range_signed" -> aux2 "addS_VIV"
+ | "minus_vec_range" -> aux2 "minus_VIV"
+ | "mult_vec_range" -> aux2 "mult_VIV"
+ | "mult_vec_range_signed" -> aux2 "multS_VIV"
+ | "add_range_vec" -> aux2 "add_IVV"
+ | "add_range_vec_signed" -> aux2 "addS_IVV"
+ | "minus_range_vec" -> aux2 "minus_IVV"
+ | "mult_range_vec" -> aux2 "mult_IVV"
+ | "mult_range_vec_signed" -> aux2 "multS_IVV"
+ | "add_range_vec_range" -> aux2 "add_IVI"
+ | "add_range_vec_range_signed" -> aux2 "addS_IVI"
+ | "minus_range_vec_range" -> aux2 "minus_IVI"
+ | "add_vec_range_range" -> aux2 "add_VII"
+ | "add_vec_range_range_signed" -> aux2 "addS_VII"
+ | "minus_vec_range_range" -> aux2 "minus_VII"
+ | "add_vec_bit" -> aux2 "add_VBV"
+ | "add_vec_bit_signed" -> aux2 "addS_VBV"
+ | "minus_vec_bit_signed" -> aux2 "minus_VBV"
+ | "add_overflow_vec" -> aux2 "addO_VVV"
+ | "add_overflow_vec_signed" -> aux2 "addSO_VVV"
+ | "minus_overflow_vec" -> aux2 "minusO_VVV"
+ | "minus_overflow_vec_signed" -> aux2 "minusSO_VVV"
+ | "mult_overflow_vec" -> aux2 "multO_VVV"
+ | "mult_overflow_vec_signed" -> aux2 "multSO_VVV"
+ | "add_overflow_vec_bit_signed" -> aux2 "addSO_VBV"
+ | "minus_overflow_vec_bit" -> aux2 "minusO_VBV"
+ | "minus_overflow_vec_bit_signed" -> aux2 "minusSO_VBV"
+
+ | _ ->
+ string name ^//^ parens (top_exp false e1 ^^ comma ^//^ top_exp false e2)) in
+ if aexp_needed then parens epp else epp
| _ ->
- let epp = separate space [doc_id_lem id; parens (separate_map comma (top_exp false) [e1;e2])] in
+ let epp =
+ align (doc_id_lem id ^//^ parens (top_exp false e1 ^^ comma ^//^ top_exp false e2)) in
if aexp_needed then parens epp else epp)
| E_internal_let(lexp, eq_exp, in_exp) ->
- (* failwith "E_internal_lets should have been removed till now" *)
- (separate
+ failwith "E_internal_lets should have been removed till now"
+(* (separate
space
[string "let internal";
(match lexp with (LEXP_aux ((LEXP_id id | LEXP_cast (_,id)),_)) -> doc_id_lem id);
coloneq;
exp eq_exp;
string "in"]) ^/^
- exp in_exp
+ exp in_exp *)
| E_internal_plet (pat,e1,e2) ->
let epp =
let b = match e1 with E_aux (E_if _,_) -> true | _ -> false in
@@ -2281,7 +2328,7 @@ let doc_exp_lem, doc_let_lem =
and doc_fexp (FE_aux(FE_Fexp(id,e),_)) = doc_op equals (doc_id_lem id) (top_exp false e)
and doc_case (Pat_aux(Pat_exp(pat,e),_)) =
- doc_op arrow (separate space [pipe; doc_pat_lem false pat]) (group (top_exp false e))
+ group (prefix 3 1 (separate space [pipe; doc_pat_lem false pat;arrow]) (group (top_exp false e)))
and doc_lexp_deref_lem ((LEXP_aux(lexp,(l,annot))) as le) = match lexp with
| LEXP_field (le,id) ->
diff --git a/src/rewriter.ml b/src/rewriter.ml
index 288410e0..e1a1f5a2 100644
--- a/src/rewriter.ml
+++ b/src/rewriter.ml
@@ -32,6 +32,11 @@ let get_effsum_annot (_,t) = match t with
| NoTyp -> failwith "no effect information"
| _ -> failwith "a_normalise doesn't support Overload"
+let get_localeff_annot (_,t) = match t with
+ | Base (_,_,_,eff,_,_) -> eff
+ | NoTyp -> failwith "no effect information"
+ | _ -> failwith "a_normalise doesn't support Overload"
+
let get_type_annot (_,t) = match t with
| Base((_,t),_,_,_,_,_) -> t
| NoTyp -> failwith "no type information"
@@ -1428,16 +1433,23 @@ and n_exp (E_aux (exp_aux,annot) as exp : 'a exp) (k : 'a exp -> 'a exp) : 'a ex
let rewrite_defs_a_normalise =
- let rewrite_exp _ _ e =
- let e = remove_blocks e in
- n_exp_term (effectful e) e in
+
+ let rewrite_fun _ (FD_aux (FD_function(recopt,tannotopt,effectopt,funcls),fdannot)) =
+ let newreturn =
+ List.fold_left
+ (fun b (FCL_aux (FCL_Funcl(id,pat,exp),annot)) ->
+ b || effectful_effs (get_localeff_annot annot)) false funcls in
+ let rewrite_funcl (FCL_aux (FCL_Funcl(id,pat,exp),annot)) =
+ let _ = reset_fresh_name_counter () in
+ FCL_aux (FCL_Funcl (id,pat,n_exp_term newreturn (remove_blocks exp)),annot)
+ in FD_aux (FD_function(recopt,tannotopt,effectopt,List.map rewrite_funcl funcls),fdannot) in
rewrite_defs_base
{rewrite_exp = rewrite_exp
; rewrite_pat = rewrite_pat
; rewrite_let = rewrite_let
; rewrite_lexp = rewrite_lexp
; rewrite_fun = rewrite_fun
- ; rewrite_def = (fun rws def -> let _ = reset_fresh_name_counter () in rewrite_def rws def)
+ ; rewrite_def = rewrite_def
; rewrite_defs = rewrite_defs_base
}