diff options
| author | Thomas Bauereiss | 2018-03-21 19:54:28 +0000 |
|---|---|---|
| committer | Thomas Bauereiss | 2018-03-22 13:48:29 +0000 |
| commit | 5c1754d3a8170167c58c876be36d451c7607fb2c (patch) | |
| tree | 4883fc0d8fda864cf9ddd3bf50699b55ec270e5f /src | |
| parent | 2dcd2d7b77c2c0759791d92114a844b9990d0820 (diff) | |
Tune Lem pretty-printing
In particular, improve indentation of if-expressions, and provide infix syntax
for monadic binds in Isabelle, allowing Lem to preserve source whitespace.
Diffstat (limited to 'src')
| -rw-r--r-- | src/ast_util.ml | 8 | ||||
| -rw-r--r-- | src/ast_util.mli | 3 | ||||
| -rw-r--r-- | src/gen_lib/prompt.lem | 8 | ||||
| -rw-r--r-- | src/gen_lib/prompt_monad.lem | 10 | ||||
| -rw-r--r-- | src/gen_lib/sail_operators.lem | 8 | ||||
| -rw-r--r-- | src/gen_lib/sail_operators_bitlists.lem | 23 | ||||
| -rw-r--r-- | src/gen_lib/sail_operators_mwords.lem | 42 | ||||
| -rw-r--r-- | src/gen_lib/sail_values.lem | 14 | ||||
| -rw-r--r-- | src/pretty_print_lem.ml | 61 |
9 files changed, 113 insertions, 64 deletions
diff --git a/src/ast_util.ml b/src/ast_util.ml index 5261a6d2..08532eb4 100644 --- a/src/ast_util.ml +++ b/src/ast_util.ml @@ -71,9 +71,17 @@ let mk_nexp nexp_aux = Nexp_aux (nexp_aux, Parse_ast.Unknown) let mk_exp exp_aux = E_aux (exp_aux, no_annot) let unaux_exp (E_aux (exp_aux, _)) = exp_aux +let uncast_exp = function + | E_aux (E_internal_return (E_aux (E_cast (typ, exp), _)), a) -> + E_aux (E_internal_return exp, a), Some typ + | E_aux (E_cast (typ, exp), _) -> exp, Some typ + | exp -> exp, None let mk_pat pat_aux = P_aux (pat_aux, no_annot) let unaux_pat (P_aux (pat_aux, _)) = pat_aux +let untyp_pat = function + | P_aux (P_typ (typ, pat), _) -> pat, Some typ + | pat -> pat, None let mk_pexp pexp_aux = Pat_aux (pexp_aux, no_annot) diff --git a/src/ast_util.mli b/src/ast_util.mli index 89004b0d..951f5bed 100644 --- a/src/ast_util.mli +++ b/src/ast_util.mli @@ -93,6 +93,9 @@ val unaux_nexp : nexp -> nexp_aux val unaux_order : order -> order_aux val unaux_typ : typ -> typ_aux +val untyp_pat : 'a pat -> 'a pat * typ option +val uncast_exp : 'a exp -> 'a exp * typ option + val inc_ord : order val dec_ord : order diff --git a/src/gen_lib/prompt.lem b/src/gen_lib/prompt.lem index 8cef266e..8214bf49 100644 --- a/src/gen_lib/prompt.lem +++ b/src/gen_lib/prompt.lem @@ -4,6 +4,14 @@ open import Sail_values open import Prompt_monad open import {isabelle} `Prompt_monad_lemmas` +val (>>=) : forall 'rv 'a 'b 'e. monad 'rv 'a 'e -> ('a -> monad 'rv 'b 'e) -> monad 'rv 'b 'e +declare isabelle target_rep function (>>=) = infix `\<bind>` +let inline ~{isabelle} (>>=) = bind + +val (>>) : forall 'rv 'b 'e. monad 'rv unit 'e -> monad 'rv 'b 'e -> monad 'rv 'b 'e +declare isabelle target_rep function (>>) = infix `\<then>` +let inline ~{isabelle} (>>) m n = m >>= fun (_ : unit) -> n + val iter_aux : forall 'rv 'a 'e. integer -> (integer -> 'a -> monad 'rv unit 'e) -> list 'a -> monad 'rv unit 'e let rec iter_aux i f xs = match xs with | x :: xs -> f i x >> iter_aux (i + 1) f xs diff --git a/src/gen_lib/prompt_monad.lem b/src/gen_lib/prompt_monad.lem index 92b9ac5e..b1dd59c4 100644 --- a/src/gen_lib/prompt_monad.lem +++ b/src/gen_lib/prompt_monad.lem @@ -62,10 +62,6 @@ let rec bind m f = match m with | Exception e -> Exception e end -let inline (>>=) = bind -val (>>) : forall 'rv 'b 'e. monad 'rv unit 'e -> monad 'rv 'b 'e -> monad 'rv 'b 'e -let inline (>>) m n = m >>= fun (_ : unit) -> n - val exit : forall 'rv 'a 'e. unit -> monad 'rv 'a 'e let exit () = Fail "exit" @@ -139,8 +135,10 @@ let read_mem_bytes rk addr sz = val read_mem : forall 'rv 'a 'b 'e. Bitvector 'a, Bitvector 'b => read_kind -> 'a -> integer -> monad 'rv 'b 'e let read_mem rk addr sz = - read_mem_bytes rk addr sz >>= (fun bytes -> - maybe_fail "bits_of_mem_bytes" (of_bits (bits_of_mem_bytes bytes))) + bind + (read_mem_bytes rk addr sz) + (fun bytes -> + maybe_fail "bits_of_mem_bytes" (of_bits (bits_of_mem_bytes bytes))) val read_tag : forall 'rv 'a 'e. Bitvector 'a => 'a -> monad 'rv bitU 'e let read_tag addr = Read_tag (bits_of addr) return diff --git a/src/gen_lib/sail_operators.lem b/src/gen_lib/sail_operators.lem index 9b0857d9..d4275c87 100644 --- a/src/gen_lib/sail_operators.lem +++ b/src/gen_lib/sail_operators.lem @@ -187,12 +187,12 @@ let duplicate_bit_bv bit len = replicate_bits_bv [bit] len val eq_bv : forall 'a. Bitvector 'a => 'a -> 'a -> bool let eq_bv l r = (bits_of l = bits_of r) -let eq_mword l r = (l = r) +let inline eq_mword l r = (l = r) val neq_bv : forall 'a. Bitvector 'a => 'a -> 'a -> bool let neq_bv l r = not (eq_bv l r) -let neq_mword l r = (l <> r) +let inline neq_mword l r = (l <> r) val ult_bv : forall 'a. Bitvector 'a => 'a -> 'a -> bool let ult_bv l r = lexicographicLess (List.reverse (bits_of l)) (List.reverse (bits_of r)) @@ -219,7 +219,7 @@ let sgt_bv l r = not (slteq_bv l r) let sgteq_bv l r = (eq_bv l r) || (sgt_bv l r) val ucmp_mword : forall 'a. Size 'a => (integer -> integer -> bool) -> mword 'a -> mword 'a -> bool -let ucmp_mword cmp l r = cmp (unsignedIntegerFromWord l) (unsignedIntegerFromWord r) +let inline ucmp_mword cmp l r = cmp (unsignedIntegerFromWord l) (unsignedIntegerFromWord r) val scmp_mword : forall 'a. Size 'a => (integer -> integer -> bool) -> mword 'a -> mword 'a -> bool -let scmp_mword cmp l r = cmp (signedIntegerFromWord l) (signedIntegerFromWord r) +let inline scmp_mword cmp l r = cmp (signedIntegerFromWord l) (signedIntegerFromWord r) diff --git a/src/gen_lib/sail_operators_bitlists.lem b/src/gen_lib/sail_operators_bitlists.lem index edba83ba..3c1afe79 100644 --- a/src/gen_lib/sail_operators_bitlists.lem +++ b/src/gen_lib/sail_operators_bitlists.lem @@ -23,6 +23,23 @@ let sint_oracle v = return (int_of_bools true bs)) let sint v = maybe_failwith (sint_maybe v) +val extz_vec : integer -> list bitU -> list bitU +let extz_vec = extz_bv + +val exts_vec : integer -> list bitU -> list bitU +let exts_vec = exts_bv + +val vec_of_bits_maybe : list bitU -> maybe (list bitU) +val vec_of_bits_fail : forall 'rv 'e. list bitU -> monad 'rv (list bitU) 'e +val vec_of_bits_oracle : forall 'rv 'e. list bitU -> monad 'rv (list bitU) 'e +val vec_of_bits_failwith : list bitU -> list bitU +val vec_of_bits : list bitU -> list bitU +let inline vec_of_bits bits = bits +let inline vec_of_bits_maybe bits = Just bits +let inline vec_of_bits_fail bits = return bits +let inline vec_of_bits_oracle bits = return bits +let inline vec_of_bits_failwith bits = bits + val access_vec_inc : list bitU -> integer -> bitU let access_vec_inc = access_bv_inc @@ -53,12 +70,6 @@ let update_subrange_vec_inc = update_subrange_bv_inc val update_subrange_vec_dec : list bitU -> integer -> integer -> list bitU -> list bitU let update_subrange_vec_dec = update_subrange_bv_dec -val extz_vec : integer -> list bitU -> list bitU -let extz_vec = extz_bv - -val exts_vec : integer -> list bitU -> list bitU -let exts_vec = exts_bv - val concat_vec : list bitU -> list bitU -> list bitU let concat_vec = concat_bv diff --git a/src/gen_lib/sail_operators_mwords.lem b/src/gen_lib/sail_operators_mwords.lem index 16b9a912..79b7674e 100644 --- a/src/gen_lib/sail_operators_mwords.lem +++ b/src/gen_lib/sail_operators_mwords.lem @@ -7,24 +7,26 @@ open import Prompt (* Specialisation of operators to machine words *) -let uint v = unsignedIntegerFromWord v +let inline uint v = unsignedIntegerFromWord v let uint_maybe v = Just (uint v) let uint_fail v = return (uint v) let uint_oracle v = return (uint v) -let sint v = signedIntegerFromWord v +let inline sint v = signedIntegerFromWord v let sint_maybe v = Just (sint v) let sint_fail v = return (sint v) let sint_oracle v = return (sint v) -val vec_of_bits_failwith : forall 'a. Size 'a => integer -> list bitU -> mword 'a -let vec_of_bits_failwith _ bits = of_bits_failwith bits - -val vec_of_bits_fail : forall 'rv 'a 'e. Size 'a => integer -> list bitU -> monad 'rv (mword 'a) 'e -let vec_of_bits_fail _ bits = of_bits_fail bits - -val vec_of_bits_oracle : forall 'rv 'a 'e. Size 'a => integer -> list bitU -> monad 'rv (mword 'a) 'e -let vec_of_bits_oracle _ bits = of_bits_oracle bits +val vec_of_bits_maybe : forall 'a. Size 'a => list bitU -> maybe (mword 'a) +val vec_of_bits_fail : forall 'rv 'a 'e. Size 'a => list bitU -> monad 'rv (mword 'a) 'e +val vec_of_bits_oracle : forall 'rv 'a 'e. Size 'a => list bitU -> monad 'rv (mword 'a) 'e +val vec_of_bits_failwith : forall 'a. Size 'a => list bitU -> mword 'a +val vec_of_bits : forall 'a. Size 'a => list bitU -> mword 'a +let vec_of_bits_maybe bits = of_bits bits +let vec_of_bits_fail bits = of_bits_fail bits +let vec_of_bits_oracle bits = of_bits_oracle bits +let vec_of_bits_failwith bits = of_bits_failwith bits +let vec_of_bits bits = of_bits_failwith bits val access_vec_inc : forall 'a. Size 'a => mword 'a -> integer -> bitU let access_vec_inc = access_bv_inc @@ -243,13 +245,13 @@ val ulteq_vec : forall 'a. Size 'a => mword 'a -> mword 'a -> bool val slteq_vec : forall 'a. Size 'a => mword 'a -> mword 'a -> bool val ugteq_vec : forall 'a. Size 'a => mword 'a -> mword 'a -> bool val sgteq_vec : forall 'a. Size 'a => mword 'a -> mword 'a -> bool -let eq_vec = eq_mword -let neq_vec = neq_mword -let ult_vec = ucmp_mword (<) -let slt_vec = scmp_mword (<) -let ugt_vec = ucmp_mword (>) -let sgt_vec = scmp_mword (>) -let ulteq_vec = ucmp_mword (<=) -let slteq_vec = scmp_mword (<=) -let ugteq_vec = ucmp_mword (>=) -let sgteq_vec = scmp_mword (>=) +let inline eq_vec = eq_mword +let inline neq_vec = neq_mword +let inline ult_vec = ucmp_mword (<) +let inline slt_vec = scmp_mword (<) +let inline ugt_vec = ucmp_mword (>) +let inline sgt_vec = scmp_mword (>) +let inline ulteq_vec = ucmp_mword (<=) +let inline slteq_vec = scmp_mword (<=) +let inline ugteq_vec = ucmp_mword (>=) +let inline sgteq_vec = scmp_mword (>=) diff --git a/src/gen_lib/sail_values.lem b/src/gen_lib/sail_values.lem index 238ebe58..a89456b9 100644 --- a/src/gen_lib/sail_values.lem +++ b/src/gen_lib/sail_values.lem @@ -45,12 +45,12 @@ let negate_real r = realNegate r let abs_real r = realAbs r let power_real b e = realPowInteger b e*) -let or_bool l r = (l || r) -let and_bool l r = (l && r) -let xor_bool l r = xor l r +let inline or_bool l r = (l || r) +let inline and_bool l r = (l && r) +let inline xor_bool l r = xor l r -let append_list l r = l ++ r -let length_list xs = integerFromNat (List.length xs) +let inline append_list l r = l ++ r +let inline length_list xs = integerFromNat (List.length xs) let take_list n xs = List.take (nat_of_int n) xs let drop_list n xs = List.drop (nat_of_int n) xs @@ -467,7 +467,7 @@ end (*** Machine words *) val length_mword : forall 'a. mword 'a -> integer -let length_mword w = integerFromNat (word_length w) +let inline length_mword w = integerFromNat (word_length w) val slice_mword_dec : forall 'a 'b. mword 'a -> integer -> integer -> mword 'b let slice_mword_dec w i j = word_extract (nat_of_int i) (nat_of_int j) w @@ -526,7 +526,7 @@ let size_itself_int x = integerFromNat (size_itself x) the actual integer is ignored. *) val make_the_value : forall 'n. integer -> itself 'n -let inline make_the_value = (fun _ -> the_value) +let make_the_value _ = the_value (*** Bitvectors *) diff --git a/src/pretty_print_lem.ml b/src/pretty_print_lem.ml index eb9feee3..b5e2a14d 100644 --- a/src/pretty_print_lem.ml +++ b/src/pretty_print_lem.ml @@ -142,6 +142,10 @@ let doc_id_lem_ctor (Id_aux(i,_)) = let doc_var_lem kid = string (fix_id true (string_of_kid kid)) +let doc_docstring_lem (l, _) = match l with + | Parse_ast.Documented (str, _) -> string ("(*" ^ str ^ "*)") ^^ hardline + | _ -> empty + let simple_annot l typ = (Parse_ast.Generated l, Some (Env.empty, typ, no_effect)) let simple_num l n = E_aux ( E_lit (L_aux (L_num n, Parse_ast.Generated l)), @@ -595,14 +599,7 @@ let doc_exp_lem, doc_let_lem = "E_vector_append should have been rewritten before pretty-printing") | E_cons(le,re) -> doc_op (group (colon^^colon)) (expY le) (expY re) | E_if(c,t,e) -> - let indent = match e with - | E_aux (E_if _, _) -> 0 - | _ -> 2 in - let epp = - separate space [string "if";group (expY c)] ^^ - break 1 ^^ - (prefix 2 1 (string "then") (expN t)) ^^ (break 1) ^^ - (prefix indent 1 (string "else") (expN e)) in + let epp = if_exp ctxt false c t e in if aexp_needed then parens (align epp) else epp | E_for(id,exp1,exp2,exp3,(Ord_aux(order,_)),exp4) -> raise (report l "E_for should have been rewritten before pretty-printing") @@ -696,14 +693,13 @@ let doc_exp_lem, doc_let_lem = | Some (env, _, _) when Env.is_extern f env "lem" -> string (Env.get_extern f env "lem"), true | _ -> doc_id_lem f, false in - let argspp = align (separate_map (break 1) (expV true) args) in - let epp = align (call ^//^ argspp) in + let epp = hang 2 (flow (break 1) (call :: List.map expY args)) in let (taepp,aexp_needed) = let env = env_of full_exp in let t = Env.expand_synonyms env (typ_of full_exp) in let eff = effect_of full_exp in if typ_needs_printed t - then (align epp ^^ (doc_tannot_lem ctxt env (effectful eff) t), true) + then (align (group (prefix 0 1 epp (doc_tannot_lem ctxt env (effectful eff) t))), true) else (epp, aexp_needed) in liftR (if aexp_needed then parens (align taepp) else taepp) end @@ -738,7 +734,7 @@ let doc_exp_lem, doc_let_lem = if has_effect eff BE_rreg then let epp = separate space [string "read_reg";doc_id_lem (append_id id "_ref")] in if is_bitvector_typ base_typ - then liftR (parens (epp ^^ doc_tannot_lem ctxt env true base_typ)) + then liftR (parens (align (group (prefix 0 1 epp (doc_tannot_lem ctxt env true base_typ))))) else liftR epp else if Env.is_register id env then doc_id_lem (append_id id "_ref") else if is_ctor env id then doc_id_lem_ctor id @@ -747,7 +743,7 @@ let doc_exp_lem, doc_let_lem = | E_cast(typ,e) -> expV aexp_needed e | E_tuple exps -> - parens (separate_map comma expN exps) + parens (align (group (separate_map (comma ^^ break 1) expN exps))) | E_record(FES_aux(FES_Fexps(fexps,_),_)) -> let recordtyp = match annot with | Some (env, Typ_aux (Typ_id tid,_), _) @@ -792,8 +788,8 @@ let doc_exp_lem, doc_let_lem = let epp = brackets expspp in let (epp,aexp_needed) = if is_bit_typ etyp && !opt_mwords then - let bepp = string "of_bits_failwith" ^^ space ^^ parens (align epp) in - (bepp ^^ doc_tannot_lem ctxt (env_of full_exp) false t, true) + let bepp = string "vec_of_bits" ^^ space ^^ align epp in + (align (group (prefix 0 1 bepp (doc_tannot_lem ctxt (env_of full_exp) false t))), true) else (epp,aexp_needed) in if aexp_needed then parens (align epp) else epp | E_vector_update(v,e1,e2) -> @@ -836,12 +832,21 @@ let doc_exp_lem, doc_let_lem = | E_internal_plet (pat,e1,e2) -> let epp = let b = match e1 with E_aux (E_if _,_) -> true | _ -> false in - match pat with - | P_aux (P_wild,_) | P_aux (P_typ (_, P_aux (P_wild, _)), _) -> - (separate space [expV b e1; string ">>"]) ^^ hardline ^^ expN e2 - | _ -> - (separate space [expV b e1; string ">>= fun"; - doc_pat_lem ctxt true pat;arrow]) ^^ hardline ^^ expN e2 in + let middle = + match fst (untyp_pat pat) with + | P_aux (P_wild,_) | P_aux (P_typ (_, P_aux (P_wild, _)), _) -> + string ">>" + | P_aux (P_tup _, _) -> + (* TODO Make sure to avoid name-clashes with temp variable *) + separate space + [string ">>= fun varstup -> let"; + doc_pat_lem ctxt true pat; + string "= varstup in"] + | _ -> + separate space [string ">>= fun"; doc_pat_lem ctxt true pat; arrow] + in + infix 0 1 middle (expV b e1) (expN e2) + in if aexp_needed then parens (align epp) else epp | E_internal_return (e1) -> separate space [string "return"; expY e1] @@ -867,6 +872,19 @@ let doc_exp_lem, doc_let_lem = | E_internal_exp_user _ | E_internal_value _ -> raise (Reporting_basic.err_unreachable l "unsupported internal expression encountered while pretty-printing") + and if_exp ctxt (elseif : bool) c t e = + let if_pp = string (if elseif then "else if" else "if") in + let else_pp = match e with + | E_aux (E_if (c', t', e'), _) + | E_aux (E_cast (_, E_aux (E_if (c', t', e'), _)), _) -> + if_exp ctxt true c' t' e' + | _ -> prefix 2 1 (string "else") (top_exp ctxt false e) + in + (prefix 2 1 + (soft_surround 2 1 if_pp (top_exp ctxt true c) (string "then")) + (top_exp ctxt false t)) ^^ + break 1 ^^ + else_pp and let_exp ctxt (LB_aux(lb,_)) = match lb with | LB_val(pat,e) -> prefix 2 1 @@ -1252,6 +1270,7 @@ let doc_spec_lem (VS_aux (valspec,annot)) = | VS_val_spec (typschm,id,ext,_) when ext "lem" = None -> (* let (TypSchm_aux (TypSchm_ts (tq, typ), _)) = typschm in if contains_t_pp_var typ then empty else *) + doc_docstring_lem annot ^^ separate space [string "val"; doc_id_lem id; string ":";doc_typschm_lem true typschm] ^/^ hardline (* | VS_val_spec (_,_,Some _,_) -> empty *) | _ -> empty |
