summaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
authorThomas Bauereiss2018-03-21 19:54:28 +0000
committerThomas Bauereiss2018-03-22 13:48:29 +0000
commit5c1754d3a8170167c58c876be36d451c7607fb2c (patch)
tree4883fc0d8fda864cf9ddd3bf50699b55ec270e5f /src
parent2dcd2d7b77c2c0759791d92114a844b9990d0820 (diff)
Tune Lem pretty-printing
In particular, improve indentation of if-expressions, and provide infix syntax for monadic binds in Isabelle, allowing Lem to preserve source whitespace.
Diffstat (limited to 'src')
-rw-r--r--src/ast_util.ml8
-rw-r--r--src/ast_util.mli3
-rw-r--r--src/gen_lib/prompt.lem8
-rw-r--r--src/gen_lib/prompt_monad.lem10
-rw-r--r--src/gen_lib/sail_operators.lem8
-rw-r--r--src/gen_lib/sail_operators_bitlists.lem23
-rw-r--r--src/gen_lib/sail_operators_mwords.lem42
-rw-r--r--src/gen_lib/sail_values.lem14
-rw-r--r--src/pretty_print_lem.ml61
9 files changed, 113 insertions, 64 deletions
diff --git a/src/ast_util.ml b/src/ast_util.ml
index 5261a6d2..08532eb4 100644
--- a/src/ast_util.ml
+++ b/src/ast_util.ml
@@ -71,9 +71,17 @@ let mk_nexp nexp_aux = Nexp_aux (nexp_aux, Parse_ast.Unknown)
let mk_exp exp_aux = E_aux (exp_aux, no_annot)
let unaux_exp (E_aux (exp_aux, _)) = exp_aux
+let uncast_exp = function
+ | E_aux (E_internal_return (E_aux (E_cast (typ, exp), _)), a) ->
+ E_aux (E_internal_return exp, a), Some typ
+ | E_aux (E_cast (typ, exp), _) -> exp, Some typ
+ | exp -> exp, None
let mk_pat pat_aux = P_aux (pat_aux, no_annot)
let unaux_pat (P_aux (pat_aux, _)) = pat_aux
+let untyp_pat = function
+ | P_aux (P_typ (typ, pat), _) -> pat, Some typ
+ | pat -> pat, None
let mk_pexp pexp_aux = Pat_aux (pexp_aux, no_annot)
diff --git a/src/ast_util.mli b/src/ast_util.mli
index 89004b0d..951f5bed 100644
--- a/src/ast_util.mli
+++ b/src/ast_util.mli
@@ -93,6 +93,9 @@ val unaux_nexp : nexp -> nexp_aux
val unaux_order : order -> order_aux
val unaux_typ : typ -> typ_aux
+val untyp_pat : 'a pat -> 'a pat * typ option
+val uncast_exp : 'a exp -> 'a exp * typ option
+
val inc_ord : order
val dec_ord : order
diff --git a/src/gen_lib/prompt.lem b/src/gen_lib/prompt.lem
index 8cef266e..8214bf49 100644
--- a/src/gen_lib/prompt.lem
+++ b/src/gen_lib/prompt.lem
@@ -4,6 +4,14 @@ open import Sail_values
open import Prompt_monad
open import {isabelle} `Prompt_monad_lemmas`
+val (>>=) : forall 'rv 'a 'b 'e. monad 'rv 'a 'e -> ('a -> monad 'rv 'b 'e) -> monad 'rv 'b 'e
+declare isabelle target_rep function (>>=) = infix `\<bind>`
+let inline ~{isabelle} (>>=) = bind
+
+val (>>) : forall 'rv 'b 'e. monad 'rv unit 'e -> monad 'rv 'b 'e -> monad 'rv 'b 'e
+declare isabelle target_rep function (>>) = infix `\<then>`
+let inline ~{isabelle} (>>) m n = m >>= fun (_ : unit) -> n
+
val iter_aux : forall 'rv 'a 'e. integer -> (integer -> 'a -> monad 'rv unit 'e) -> list 'a -> monad 'rv unit 'e
let rec iter_aux i f xs = match xs with
| x :: xs -> f i x >> iter_aux (i + 1) f xs
diff --git a/src/gen_lib/prompt_monad.lem b/src/gen_lib/prompt_monad.lem
index 92b9ac5e..b1dd59c4 100644
--- a/src/gen_lib/prompt_monad.lem
+++ b/src/gen_lib/prompt_monad.lem
@@ -62,10 +62,6 @@ let rec bind m f = match m with
| Exception e -> Exception e
end
-let inline (>>=) = bind
-val (>>) : forall 'rv 'b 'e. monad 'rv unit 'e -> monad 'rv 'b 'e -> monad 'rv 'b 'e
-let inline (>>) m n = m >>= fun (_ : unit) -> n
-
val exit : forall 'rv 'a 'e. unit -> monad 'rv 'a 'e
let exit () = Fail "exit"
@@ -139,8 +135,10 @@ let read_mem_bytes rk addr sz =
val read_mem : forall 'rv 'a 'b 'e. Bitvector 'a, Bitvector 'b => read_kind -> 'a -> integer -> monad 'rv 'b 'e
let read_mem rk addr sz =
- read_mem_bytes rk addr sz >>= (fun bytes ->
- maybe_fail "bits_of_mem_bytes" (of_bits (bits_of_mem_bytes bytes)))
+ bind
+ (read_mem_bytes rk addr sz)
+ (fun bytes ->
+ maybe_fail "bits_of_mem_bytes" (of_bits (bits_of_mem_bytes bytes)))
val read_tag : forall 'rv 'a 'e. Bitvector 'a => 'a -> monad 'rv bitU 'e
let read_tag addr = Read_tag (bits_of addr) return
diff --git a/src/gen_lib/sail_operators.lem b/src/gen_lib/sail_operators.lem
index 9b0857d9..d4275c87 100644
--- a/src/gen_lib/sail_operators.lem
+++ b/src/gen_lib/sail_operators.lem
@@ -187,12 +187,12 @@ let duplicate_bit_bv bit len = replicate_bits_bv [bit] len
val eq_bv : forall 'a. Bitvector 'a => 'a -> 'a -> bool
let eq_bv l r = (bits_of l = bits_of r)
-let eq_mword l r = (l = r)
+let inline eq_mword l r = (l = r)
val neq_bv : forall 'a. Bitvector 'a => 'a -> 'a -> bool
let neq_bv l r = not (eq_bv l r)
-let neq_mword l r = (l <> r)
+let inline neq_mword l r = (l <> r)
val ult_bv : forall 'a. Bitvector 'a => 'a -> 'a -> bool
let ult_bv l r = lexicographicLess (List.reverse (bits_of l)) (List.reverse (bits_of r))
@@ -219,7 +219,7 @@ let sgt_bv l r = not (slteq_bv l r)
let sgteq_bv l r = (eq_bv l r) || (sgt_bv l r)
val ucmp_mword : forall 'a. Size 'a => (integer -> integer -> bool) -> mword 'a -> mword 'a -> bool
-let ucmp_mword cmp l r = cmp (unsignedIntegerFromWord l) (unsignedIntegerFromWord r)
+let inline ucmp_mword cmp l r = cmp (unsignedIntegerFromWord l) (unsignedIntegerFromWord r)
val scmp_mword : forall 'a. Size 'a => (integer -> integer -> bool) -> mword 'a -> mword 'a -> bool
-let scmp_mword cmp l r = cmp (signedIntegerFromWord l) (signedIntegerFromWord r)
+let inline scmp_mword cmp l r = cmp (signedIntegerFromWord l) (signedIntegerFromWord r)
diff --git a/src/gen_lib/sail_operators_bitlists.lem b/src/gen_lib/sail_operators_bitlists.lem
index edba83ba..3c1afe79 100644
--- a/src/gen_lib/sail_operators_bitlists.lem
+++ b/src/gen_lib/sail_operators_bitlists.lem
@@ -23,6 +23,23 @@ let sint_oracle v =
return (int_of_bools true bs))
let sint v = maybe_failwith (sint_maybe v)
+val extz_vec : integer -> list bitU -> list bitU
+let extz_vec = extz_bv
+
+val exts_vec : integer -> list bitU -> list bitU
+let exts_vec = exts_bv
+
+val vec_of_bits_maybe : list bitU -> maybe (list bitU)
+val vec_of_bits_fail : forall 'rv 'e. list bitU -> monad 'rv (list bitU) 'e
+val vec_of_bits_oracle : forall 'rv 'e. list bitU -> monad 'rv (list bitU) 'e
+val vec_of_bits_failwith : list bitU -> list bitU
+val vec_of_bits : list bitU -> list bitU
+let inline vec_of_bits bits = bits
+let inline vec_of_bits_maybe bits = Just bits
+let inline vec_of_bits_fail bits = return bits
+let inline vec_of_bits_oracle bits = return bits
+let inline vec_of_bits_failwith bits = bits
+
val access_vec_inc : list bitU -> integer -> bitU
let access_vec_inc = access_bv_inc
@@ -53,12 +70,6 @@ let update_subrange_vec_inc = update_subrange_bv_inc
val update_subrange_vec_dec : list bitU -> integer -> integer -> list bitU -> list bitU
let update_subrange_vec_dec = update_subrange_bv_dec
-val extz_vec : integer -> list bitU -> list bitU
-let extz_vec = extz_bv
-
-val exts_vec : integer -> list bitU -> list bitU
-let exts_vec = exts_bv
-
val concat_vec : list bitU -> list bitU -> list bitU
let concat_vec = concat_bv
diff --git a/src/gen_lib/sail_operators_mwords.lem b/src/gen_lib/sail_operators_mwords.lem
index 16b9a912..79b7674e 100644
--- a/src/gen_lib/sail_operators_mwords.lem
+++ b/src/gen_lib/sail_operators_mwords.lem
@@ -7,24 +7,26 @@ open import Prompt
(* Specialisation of operators to machine words *)
-let uint v = unsignedIntegerFromWord v
+let inline uint v = unsignedIntegerFromWord v
let uint_maybe v = Just (uint v)
let uint_fail v = return (uint v)
let uint_oracle v = return (uint v)
-let sint v = signedIntegerFromWord v
+let inline sint v = signedIntegerFromWord v
let sint_maybe v = Just (sint v)
let sint_fail v = return (sint v)
let sint_oracle v = return (sint v)
-val vec_of_bits_failwith : forall 'a. Size 'a => integer -> list bitU -> mword 'a
-let vec_of_bits_failwith _ bits = of_bits_failwith bits
-
-val vec_of_bits_fail : forall 'rv 'a 'e. Size 'a => integer -> list bitU -> monad 'rv (mword 'a) 'e
-let vec_of_bits_fail _ bits = of_bits_fail bits
-
-val vec_of_bits_oracle : forall 'rv 'a 'e. Size 'a => integer -> list bitU -> monad 'rv (mword 'a) 'e
-let vec_of_bits_oracle _ bits = of_bits_oracle bits
+val vec_of_bits_maybe : forall 'a. Size 'a => list bitU -> maybe (mword 'a)
+val vec_of_bits_fail : forall 'rv 'a 'e. Size 'a => list bitU -> monad 'rv (mword 'a) 'e
+val vec_of_bits_oracle : forall 'rv 'a 'e. Size 'a => list bitU -> monad 'rv (mword 'a) 'e
+val vec_of_bits_failwith : forall 'a. Size 'a => list bitU -> mword 'a
+val vec_of_bits : forall 'a. Size 'a => list bitU -> mword 'a
+let vec_of_bits_maybe bits = of_bits bits
+let vec_of_bits_fail bits = of_bits_fail bits
+let vec_of_bits_oracle bits = of_bits_oracle bits
+let vec_of_bits_failwith bits = of_bits_failwith bits
+let vec_of_bits bits = of_bits_failwith bits
val access_vec_inc : forall 'a. Size 'a => mword 'a -> integer -> bitU
let access_vec_inc = access_bv_inc
@@ -243,13 +245,13 @@ val ulteq_vec : forall 'a. Size 'a => mword 'a -> mword 'a -> bool
val slteq_vec : forall 'a. Size 'a => mword 'a -> mword 'a -> bool
val ugteq_vec : forall 'a. Size 'a => mword 'a -> mword 'a -> bool
val sgteq_vec : forall 'a. Size 'a => mword 'a -> mword 'a -> bool
-let eq_vec = eq_mword
-let neq_vec = neq_mword
-let ult_vec = ucmp_mword (<)
-let slt_vec = scmp_mword (<)
-let ugt_vec = ucmp_mword (>)
-let sgt_vec = scmp_mword (>)
-let ulteq_vec = ucmp_mword (<=)
-let slteq_vec = scmp_mword (<=)
-let ugteq_vec = ucmp_mword (>=)
-let sgteq_vec = scmp_mword (>=)
+let inline eq_vec = eq_mword
+let inline neq_vec = neq_mword
+let inline ult_vec = ucmp_mword (<)
+let inline slt_vec = scmp_mword (<)
+let inline ugt_vec = ucmp_mword (>)
+let inline sgt_vec = scmp_mword (>)
+let inline ulteq_vec = ucmp_mword (<=)
+let inline slteq_vec = scmp_mword (<=)
+let inline ugteq_vec = ucmp_mword (>=)
+let inline sgteq_vec = scmp_mword (>=)
diff --git a/src/gen_lib/sail_values.lem b/src/gen_lib/sail_values.lem
index 238ebe58..a89456b9 100644
--- a/src/gen_lib/sail_values.lem
+++ b/src/gen_lib/sail_values.lem
@@ -45,12 +45,12 @@ let negate_real r = realNegate r
let abs_real r = realAbs r
let power_real b e = realPowInteger b e*)
-let or_bool l r = (l || r)
-let and_bool l r = (l && r)
-let xor_bool l r = xor l r
+let inline or_bool l r = (l || r)
+let inline and_bool l r = (l && r)
+let inline xor_bool l r = xor l r
-let append_list l r = l ++ r
-let length_list xs = integerFromNat (List.length xs)
+let inline append_list l r = l ++ r
+let inline length_list xs = integerFromNat (List.length xs)
let take_list n xs = List.take (nat_of_int n) xs
let drop_list n xs = List.drop (nat_of_int n) xs
@@ -467,7 +467,7 @@ end
(*** Machine words *)
val length_mword : forall 'a. mword 'a -> integer
-let length_mword w = integerFromNat (word_length w)
+let inline length_mword w = integerFromNat (word_length w)
val slice_mword_dec : forall 'a 'b. mword 'a -> integer -> integer -> mword 'b
let slice_mword_dec w i j = word_extract (nat_of_int i) (nat_of_int j) w
@@ -526,7 +526,7 @@ let size_itself_int x = integerFromNat (size_itself x)
the actual integer is ignored. *)
val make_the_value : forall 'n. integer -> itself 'n
-let inline make_the_value = (fun _ -> the_value)
+let make_the_value _ = the_value
(*** Bitvectors *)
diff --git a/src/pretty_print_lem.ml b/src/pretty_print_lem.ml
index eb9feee3..b5e2a14d 100644
--- a/src/pretty_print_lem.ml
+++ b/src/pretty_print_lem.ml
@@ -142,6 +142,10 @@ let doc_id_lem_ctor (Id_aux(i,_)) =
let doc_var_lem kid = string (fix_id true (string_of_kid kid))
+let doc_docstring_lem (l, _) = match l with
+ | Parse_ast.Documented (str, _) -> string ("(*" ^ str ^ "*)") ^^ hardline
+ | _ -> empty
+
let simple_annot l typ = (Parse_ast.Generated l, Some (Env.empty, typ, no_effect))
let simple_num l n = E_aux (
E_lit (L_aux (L_num n, Parse_ast.Generated l)),
@@ -595,14 +599,7 @@ let doc_exp_lem, doc_let_lem =
"E_vector_append should have been rewritten before pretty-printing")
| E_cons(le,re) -> doc_op (group (colon^^colon)) (expY le) (expY re)
| E_if(c,t,e) ->
- let indent = match e with
- | E_aux (E_if _, _) -> 0
- | _ -> 2 in
- let epp =
- separate space [string "if";group (expY c)] ^^
- break 1 ^^
- (prefix 2 1 (string "then") (expN t)) ^^ (break 1) ^^
- (prefix indent 1 (string "else") (expN e)) in
+ let epp = if_exp ctxt false c t e in
if aexp_needed then parens (align epp) else epp
| E_for(id,exp1,exp2,exp3,(Ord_aux(order,_)),exp4) ->
raise (report l "E_for should have been rewritten before pretty-printing")
@@ -696,14 +693,13 @@ let doc_exp_lem, doc_let_lem =
| Some (env, _, _) when Env.is_extern f env "lem" ->
string (Env.get_extern f env "lem"), true
| _ -> doc_id_lem f, false in
- let argspp = align (separate_map (break 1) (expV true) args) in
- let epp = align (call ^//^ argspp) in
+ let epp = hang 2 (flow (break 1) (call :: List.map expY args)) in
let (taepp,aexp_needed) =
let env = env_of full_exp in
let t = Env.expand_synonyms env (typ_of full_exp) in
let eff = effect_of full_exp in
if typ_needs_printed t
- then (align epp ^^ (doc_tannot_lem ctxt env (effectful eff) t), true)
+ then (align (group (prefix 0 1 epp (doc_tannot_lem ctxt env (effectful eff) t))), true)
else (epp, aexp_needed) in
liftR (if aexp_needed then parens (align taepp) else taepp)
end
@@ -738,7 +734,7 @@ let doc_exp_lem, doc_let_lem =
if has_effect eff BE_rreg then
let epp = separate space [string "read_reg";doc_id_lem (append_id id "_ref")] in
if is_bitvector_typ base_typ
- then liftR (parens (epp ^^ doc_tannot_lem ctxt env true base_typ))
+ then liftR (parens (align (group (prefix 0 1 epp (doc_tannot_lem ctxt env true base_typ)))))
else liftR epp
else if Env.is_register id env then doc_id_lem (append_id id "_ref")
else if is_ctor env id then doc_id_lem_ctor id
@@ -747,7 +743,7 @@ let doc_exp_lem, doc_let_lem =
| E_cast(typ,e) ->
expV aexp_needed e
| E_tuple exps ->
- parens (separate_map comma expN exps)
+ parens (align (group (separate_map (comma ^^ break 1) expN exps)))
| E_record(FES_aux(FES_Fexps(fexps,_),_)) ->
let recordtyp = match annot with
| Some (env, Typ_aux (Typ_id tid,_), _)
@@ -792,8 +788,8 @@ let doc_exp_lem, doc_let_lem =
let epp = brackets expspp in
let (epp,aexp_needed) =
if is_bit_typ etyp && !opt_mwords then
- let bepp = string "of_bits_failwith" ^^ space ^^ parens (align epp) in
- (bepp ^^ doc_tannot_lem ctxt (env_of full_exp) false t, true)
+ let bepp = string "vec_of_bits" ^^ space ^^ align epp in
+ (align (group (prefix 0 1 bepp (doc_tannot_lem ctxt (env_of full_exp) false t))), true)
else (epp,aexp_needed) in
if aexp_needed then parens (align epp) else epp
| E_vector_update(v,e1,e2) ->
@@ -836,12 +832,21 @@ let doc_exp_lem, doc_let_lem =
| E_internal_plet (pat,e1,e2) ->
let epp =
let b = match e1 with E_aux (E_if _,_) -> true | _ -> false in
- match pat with
- | P_aux (P_wild,_) | P_aux (P_typ (_, P_aux (P_wild, _)), _) ->
- (separate space [expV b e1; string ">>"]) ^^ hardline ^^ expN e2
- | _ ->
- (separate space [expV b e1; string ">>= fun";
- doc_pat_lem ctxt true pat;arrow]) ^^ hardline ^^ expN e2 in
+ let middle =
+ match fst (untyp_pat pat) with
+ | P_aux (P_wild,_) | P_aux (P_typ (_, P_aux (P_wild, _)), _) ->
+ string ">>"
+ | P_aux (P_tup _, _) ->
+ (* TODO Make sure to avoid name-clashes with temp variable *)
+ separate space
+ [string ">>= fun varstup -> let";
+ doc_pat_lem ctxt true pat;
+ string "= varstup in"]
+ | _ ->
+ separate space [string ">>= fun"; doc_pat_lem ctxt true pat; arrow]
+ in
+ infix 0 1 middle (expV b e1) (expN e2)
+ in
if aexp_needed then parens (align epp) else epp
| E_internal_return (e1) ->
separate space [string "return"; expY e1]
@@ -867,6 +872,19 @@ let doc_exp_lem, doc_let_lem =
| E_internal_exp_user _ | E_internal_value _ ->
raise (Reporting_basic.err_unreachable l
"unsupported internal expression encountered while pretty-printing")
+ and if_exp ctxt (elseif : bool) c t e =
+ let if_pp = string (if elseif then "else if" else "if") in
+ let else_pp = match e with
+ | E_aux (E_if (c', t', e'), _)
+ | E_aux (E_cast (_, E_aux (E_if (c', t', e'), _)), _) ->
+ if_exp ctxt true c' t' e'
+ | _ -> prefix 2 1 (string "else") (top_exp ctxt false e)
+ in
+ (prefix 2 1
+ (soft_surround 2 1 if_pp (top_exp ctxt true c) (string "then"))
+ (top_exp ctxt false t)) ^^
+ break 1 ^^
+ else_pp
and let_exp ctxt (LB_aux(lb,_)) = match lb with
| LB_val(pat,e) ->
prefix 2 1
@@ -1252,6 +1270,7 @@ let doc_spec_lem (VS_aux (valspec,annot)) =
| VS_val_spec (typschm,id,ext,_) when ext "lem" = None ->
(* let (TypSchm_aux (TypSchm_ts (tq, typ), _)) = typschm in
if contains_t_pp_var typ then empty else *)
+ doc_docstring_lem annot ^^
separate space [string "val"; doc_id_lem id; string ":";doc_typschm_lem true typschm] ^/^ hardline
(* | VS_val_spec (_,_,Some _,_) -> empty *)
| _ -> empty