diff options
| -rw-r--r-- | language/sail.ott | 4 | ||||
| -rw-r--r-- | src/ast_util.ml | 9 | ||||
| -rw-r--r-- | src/ast_util.mli | 1 | ||||
| -rw-r--r-- | src/bitfield.ml | 25 | ||||
| -rw-r--r-- | src/initial_check.ml | 20 | ||||
| -rw-r--r-- | src/parse_ast.ml | 4 | ||||
| -rw-r--r-- | src/parser.mly | 4 | ||||
| -rw-r--r-- | src/pretty_print_common.ml | 2 | ||||
| -rw-r--r-- | src/pretty_print_coq.ml | 22 | ||||
| -rw-r--r-- | src/pretty_print_lem.ml | 13 | ||||
| -rw-r--r-- | src/type_check.ml | 5 | ||||
| -rw-r--r-- | test/typecheck/pass/bitvector_param.sail | 42 |
12 files changed, 124 insertions, 27 deletions
diff --git a/language/sail.ott b/language/sail.ott index 6d2760ff..7dbd3c9e 100644 --- a/language/sail.ott +++ b/language/sail.ott @@ -316,8 +316,8 @@ type_union :: 'Tu_' ::= index_range :: 'BF_' ::= {{ com index specification, for bitfields in register types}} {{ aux _ l }} - | num :: :: 'single' {{ com single index }} - | num1 '..' num2 :: :: range {{ com index range }} + | nexp :: :: 'single' {{ com single index }} + | nexp1 '..' nexp2 :: :: range {{ com index range }} | index_range1 , index_range2 :: :: concat {{ com concatenation of index ranges }} %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% diff --git a/src/ast_util.ml b/src/ast_util.ml index 03031767..b3ab2cfd 100644 --- a/src/ast_util.ml +++ b/src/ast_util.ml @@ -236,6 +236,11 @@ let rec is_nexp_constant (Nexp_aux (nexp, _)) = match nexp with | Nexp_exp n | Nexp_neg n -> is_nexp_constant n | Nexp_app (_, nexps) -> List.for_all is_nexp_constant nexps +let int_of_nexp_opt nexp = + match nexp with + | Nexp_aux(Nexp_constant i,_) -> Some i + | _ -> None + let rec nexp_simp (Nexp_aux (nexp, l)) = Nexp_aux (nexp_simp_aux nexp, l) and nexp_simp_aux = function (* (n - (n - m)) often appears in foreach loops *) @@ -911,8 +916,8 @@ and string_of_letbind (LB_aux (lb, l)) = let rec string_of_index_range (BF_aux (ir, _)) = match ir with - | BF_single n -> Big_int.to_string n - | BF_range (n, m) -> Big_int.to_string n ^ " .. " ^ Big_int.to_string m + | BF_single n -> string_of_nexp n + | BF_range (n, m) -> string_of_nexp n ^ " .. " ^ string_of_nexp m | BF_concat (ir1, ir2) -> "(" ^ string_of_index_range ir1 ^ ") : (" ^ string_of_index_range ir2 ^ ")" diff --git a/src/ast_util.mli b/src/ast_util.mli index 4cbea3dc..c4eb0b4b 100644 --- a/src/ast_util.mli +++ b/src/ast_util.mli @@ -337,6 +337,7 @@ end val nexp_frees : nexp -> KidSet.t val nexp_identical : nexp -> nexp -> bool val is_nexp_constant : nexp -> bool +val int_of_nexp_opt : nexp -> Big_int.num option val lexp_to_exp : 'a lexp -> 'a exp diff --git a/src/bitfield.ml b/src/bitfield.ml index e8250598..1f64adbd 100644 --- a/src/bitfield.ml +++ b/src/bitfield.ml @@ -150,19 +150,30 @@ let index_range_update name field order start stop = let index_range_overload name field order = ast_of_def_string (Printf.sprintf "overload _mod_%s = {_get_%s_%s, _set_%s_%s}" field name field name field) -let index_range_accessor name field order (BF_aux (bf_aux, l)) = +let index_range_accessor (eval, typ_error) name field order (BF_aux (bf_aux, l)) = let getter n m = index_range_getter name field order (Big_int.to_int n) (Big_int.to_int m) in let setter n m = index_range_setter name field order (Big_int.to_int n) (Big_int.to_int m) in let update n m = index_range_update name field order (Big_int.to_int n) (Big_int.to_int m) in let overload = index_range_overload name field order in + let const_fold nexp = match eval nexp with + | Some v -> v + | None -> typ_error l (Printf.sprintf "Non-constant index for field %s" field) in match bf_aux with - | BF_single n -> combine [getter n n; setter n n; update n n; overload] - | BF_range (n, m) -> combine [getter n m; setter n m; update n m; overload] + | BF_single n -> + let n = const_fold n in + combine [getter n n; setter n n; update n n; overload] + | BF_range (n, m) -> + let n, m = const_fold n, const_fold m in + combine [getter n m; setter n m; update n m; overload] | BF_concat _ -> failwith "Unimplemented" -let field_accessor name order (id, ir) = index_range_accessor name (string_of_id id) order ir +let field_accessor (eval, typ_error) name order (id, ir) = + index_range_accessor (eval, typ_error) name (string_of_id id) order ir -let macro id size order ranges = +let macro (eval, typ_error) id size order ranges = let name = string_of_id id in - let ranges = (mk_id "bits", BF_aux (BF_range (Big_int.of_int (size - 1), Big_int.of_int 0), Parse_ast.Unknown)) :: ranges in - combine ([newtype name size order; constructor name order (size - 1) 0] @ List.map (field_accessor name order) ranges) + let ranges = (mk_id "bits", BF_aux (BF_range (nconstant (Big_int.of_int (size - 1)), + nconstant (Big_int.of_int 0)), + Parse_ast.Unknown)) :: ranges in + combine ([newtype name size order; constructor name order (size - 1) 0] + @ List.map (field_accessor (eval, typ_error) name order) ranges) diff --git a/src/initial_check.ml b/src/initial_check.ml index 108e04d0..07316c6d 100644 --- a/src/initial_check.ml +++ b/src/initial_check.ml @@ -201,6 +201,20 @@ and to_ast_nexp ctx (P.ATyp_aux (aux, l)) = in Nexp_aux (aux, l) +and to_ast_bitfield_index_nexp (P.ATyp_aux (aux, l)) = + let aux = match aux with + | P.ATyp_id id -> Nexp_id (to_ast_id id) + | P.ATyp_lit (P.L_aux (P.L_num c, _)) -> Nexp_constant c + | P.ATyp_sum (t1, t2) -> Nexp_sum (to_ast_bitfield_index_nexp t1, to_ast_bitfield_index_nexp t2) + | P.ATyp_exp t1 -> Nexp_exp (to_ast_bitfield_index_nexp t1) + | P.ATyp_neg t1 -> Nexp_neg (to_ast_bitfield_index_nexp t1) + | P.ATyp_times (t1, t2) -> Nexp_times (to_ast_bitfield_index_nexp t1, to_ast_bitfield_index_nexp t2) + | P.ATyp_minus (t1, t2) -> Nexp_minus (to_ast_bitfield_index_nexp t1, to_ast_bitfield_index_nexp t2) + | P.ATyp_app (id, ts) -> Nexp_app (to_ast_id id, List.map (to_ast_bitfield_index_nexp) ts) + | _ -> raise (Reporting.err_typ l "Invalid numeric expression in field index") + in + Nexp_aux (aux, l) + and to_ast_order ctx (P.ATyp_aux (aux, l)) = match aux with | ATyp_var v -> Ord_aux (Ord_var (to_ast_var v), l) @@ -503,9 +517,9 @@ let to_ast_spec ctx (val_:P.val_spec) : (unit val_spec) ctx_out = let rec to_ast_range (P.BF_aux(r,l)) = (* TODO add check that ranges are sensible for some definition of sensible *) BF_aux( (match r with - | P.BF_single(i) -> BF_single(i) - | P.BF_range(i1,i2) -> BF_range(i1,i2) - | P.BF_concat(ir1,ir2) -> BF_concat( to_ast_range ir1, to_ast_range ir2)), + | P.BF_single(i) -> BF_single(to_ast_bitfield_index_nexp i) + | P.BF_range(i1,i2) -> BF_range(to_ast_bitfield_index_nexp i1,to_ast_bitfield_index_nexp i2) + | P.BF_concat(ir1,ir2) -> BF_concat(to_ast_range ir1, to_ast_range ir2)), l) let to_ast_type_union ctx (P.Tu_aux (P.Tu_ty_id (atyp, id), l)) = diff --git a/src/parse_ast.ml b/src/parse_ast.ml index 6401331e..eb5c3dc6 100644 --- a/src/parse_ast.ml +++ b/src/parse_ast.ml @@ -371,8 +371,8 @@ type_union = type index_range_aux = (* index specification, for bitfields in register types *) - BF_single of Big_int.num (* single index *) - | BF_range of Big_int.num * Big_int.num (* index range *) + BF_single of atyp (* single index *) + | BF_range of atyp * atyp (* index range *) | BF_concat of index_range * index_range (* concatenation of index ranges *) and index_range = diff --git a/src/parser.mly b/src/parser.mly index cbbc41e3..bd832d28 100644 --- a/src/parser.mly +++ b/src/parser.mly @@ -1130,9 +1130,9 @@ funcl_typ: { mk_tannot mk_typqn $1 $startpos $endpos } index_range: - | Num + | typ { mk_ir (BF_single $1) $startpos $endpos } - | Num DotDot Num + | typ DotDot typ { mk_ir (BF_range ($1, $3)) $startpos $endpos } r_id_def: diff --git a/src/pretty_print_common.ml b/src/pretty_print_common.ml index c01896ac..3a1deed0 100644 --- a/src/pretty_print_common.ml +++ b/src/pretty_print_common.ml @@ -89,10 +89,12 @@ let doc_id (Id_aux(i,_)) = * token in case of x ending with star. *) parens (separate space [string "deinfix"; string x; empty]) +(* let rec doc_range (BF_aux(r,_)) = match r with | BF_single i -> doc_int i | BF_range(i1,i2) -> doc_op dotdot (doc_int i1) (doc_int i2) | BF_concat(ir1,ir2) -> (doc_range ir1) ^^ comma ^^ (doc_range ir2) +*) let print ?(len=100) channel doc = ToChannel.pretty 1. len channel doc let to_buf ?(len=100) buf doc = ToBuffer.pretty 1. len buf doc diff --git a/src/pretty_print_coq.ml b/src/pretty_print_coq.ml index bb6a3d6a..4596f23f 100644 --- a/src/pretty_print_coq.ml +++ b/src/pretty_print_coq.ml @@ -1896,10 +1896,12 @@ let doc_type_union ctxt typ_name (Tu_aux(Tu_ty_id(typ,id),_)) = separate space [doc_id_ctor id; colon; doc_typ ctxt typ; arrow; typ_name] -let rec doc_range (BF_aux(r,_)) = match r with - | BF_single i -> parens (doc_op comma (doc_int i) (doc_int i)) - | BF_range(i1,i2) -> parens (doc_op comma (doc_int i1) (doc_int i2)) - | BF_concat(ir1,ir2) -> (doc_range ir1) ^^ comma ^^ (doc_range ir2) +(* +let rec doc_range ctxt (BF_aux(r,_)) = match r with + | BF_single i -> parens (doc_op comma (doc_nexp ctxt i) (doc_nexp ctxt i)) + | BF_range(i1,i2) -> parens (doc_op comma (doc_nexp ctxt i1) (doc_nexp ctxt i2)) + | BF_concat(ir1,ir2) -> (doc_range ctxt ir1) ^^ comma ^^ (doc_range ctxt ir2) + *) let doc_typdef generic_eq_types (TD_aux(td, (l, annot))) = match td with | TD_abbrev(id,typq,A_aux (A_typ typ, _)) -> @@ -2408,7 +2410,15 @@ let is_field_accessor regtypes fdef = (access = "get" || access = "set") && is_field_of regtyp field | _ -> false + +let int_of_field_index tname fid nexp = + match int_of_nexp_opt nexp with + | Some i -> i + | None -> raise (Reporting.err_typ Parse_ast.Unknown + ("Non-constant bitfield index in field " ^ string_of_id fid ^ " of " ^ tname)) + let doc_regtype_fields (tname, (n1, n2, fields)) = + let const_int fid idx = int_of_field_index tname fid idx in let i1, i2 = match n1, n2 with | Nexp_aux(Nexp_constant i1,_),Nexp_aux(Nexp_constant i2,_) -> i1, i2 | _ -> raise (Reporting.err_typ Parse_ast.Unknown @@ -2417,8 +2427,8 @@ let doc_regtype_fields (tname, (n1, n2, fields)) = let dir = (if dir_b then "true" else "false") in let doc_field (fr, fid) = let i, j = match fr with - | BF_aux (BF_single i, _) -> (i, i) - | BF_aux (BF_range (i, j), _) -> (i, j) + | BF_aux (BF_single i, _) -> let i = const_int fid i in (i, i) + | BF_aux (BF_range (i, j), _) -> (const_int fid i, const_int fid j) | _ -> raise (Reporting.err_unreachable Parse_ast.Unknown __POS__ ("Unsupported type in field " ^ string_of_id fid ^ " of " ^ tname)) in let fsize = Big_int.succ (Big_int.abs (Big_int.sub i j)) in diff --git a/src/pretty_print_lem.ml b/src/pretty_print_lem.ml index 7d2cc479..dee0a29f 100644 --- a/src/pretty_print_lem.ml +++ b/src/pretty_print_lem.ml @@ -1016,10 +1016,12 @@ let doc_type_union_lem env (Tu_aux(Tu_ty_id(typ,id),_)) = separate space [pipe; doc_id_lem_ctor id; string "of"; parens (doc_typ_lem env typ)] +(* let rec doc_range_lem (BF_aux(r,_)) = match r with | BF_single i -> parens (doc_op comma (doc_int i) (doc_int i)) | BF_range(i1,i2) -> parens (doc_op comma (doc_int i1) (doc_int i2)) | BF_concat(ir1,ir2) -> (doc_range ir1) ^^ comma ^^ (doc_range ir2) + *) let doc_typdef_lem env (TD_aux(td, (l, annot))) = match td with | TD_abbrev(id,typq,A_aux (A_typ typ, _)) -> @@ -1392,7 +1394,14 @@ let is_field_accessor regtypes fdef = (access = "get" || access = "set") && is_field_of regtyp field | _ -> false +let int_of_field_index tname fid nexp = + match int_of_nexp_opt nexp with + | Some i -> i + | None -> raise (Reporting.err_typ Parse_ast.Unknown + ("Non-constant bitfield index in field " ^ string_of_id fid ^ " of " ^ tname)) + let doc_regtype_fields (tname, (n1, n2, fields)) = + let const_int fid idx = int_of_field_index tname fid idx in let i1, i2 = match n1, n2 with | Nexp_aux(Nexp_constant i1,_),Nexp_aux(Nexp_constant i2,_) -> i1, i2 | _ -> raise (Reporting.err_typ Parse_ast.Unknown @@ -1401,8 +1410,8 @@ let doc_regtype_fields (tname, (n1, n2, fields)) = let dir = (if dir_b then "true" else "false") in let doc_field (fr, fid) = let i, j = match fr with - | BF_aux (BF_single i, _) -> (i, i) - | BF_aux (BF_range (i, j), _) -> (i, j) + | BF_aux (BF_single i, _) -> let i = const_int fid i in (i, i) + | BF_aux (BF_range (i, j), _) -> (const_int fid i, const_int fid j) | _ -> raise (Reporting.err_unreachable Parse_ast.Unknown __POS__ ("Unsupported type in field " ^ string_of_id fid ^ " of " ^ tname)) in let fsize = Big_int.succ (Big_int.abs (Big_int.sub i j)) in diff --git a/src/type_check.ml b/src/type_check.ml index f31da5f4..a19f77de 100644 --- a/src/type_check.ml +++ b/src/type_check.ml @@ -4787,7 +4787,10 @@ let rec check_typedef : 'a. Env.t -> 'a type_def -> (tannot def) list * Env.t = A_aux (A_typ (Typ_aux (Typ_id b, _)), _)]), _) when string_of_id v = "vector" && string_of_id b = "bit" -> let size = Big_int.to_int size in - let (Defs defs), env = check env (Bitfield.macro id size order ranges) in + let eval_index_nexp env nexp = + int_of_nexp_opt (nexp_simp (Env.expand_nexp_synonyms env nexp)) in + let (Defs defs), env = + check env (Bitfield.macro (eval_index_nexp env, (typ_error env)) id size order ranges) in defs, env | _ -> typ_error env l "Bad bitfield type" diff --git a/test/typecheck/pass/bitvector_param.sail b/test/typecheck/pass/bitvector_param.sail new file mode 100644 index 00000000..ffebeb6e --- /dev/null +++ b/test/typecheck/pass/bitvector_param.sail @@ -0,0 +1,42 @@ +/* from prelude */ +default Order dec +type bits ('n : Int) = vector('n, dec, bit) + +val vector_subrange = { + ocaml: "subrange", + lem: "subrange_vec_dec", + c: "vector_subrange", + coq: "subrange_vec_dec" +} : forall ('n : Int) ('m : Int) ('o : Int), 0 <= 'o <= 'm < 'n. + (bits('n), atom('m), atom('o)) -> bits('m - 'o + 1) + +val vector_update_subrange_dec = {ocaml: "update_subrange", c: "vector_update_subrange", lem: "update_subrange_vec_dec", coq: "update_subrange_vec_dec"} : forall 'n 'm 'o. + (bits('n), atom('m), atom('o), bits('m - ('o - 1))) -> bits('n) + +val vector_update_subrange_inc = {ocaml: "update_subrange", lem: "update_subrange_vec_inc"} : forall 'n 'm 'o. + (vector('n, inc, bit), atom('m), atom('o), vector('o - ('m - 1), inc, bit)) -> vector('n, inc, bit) + +overload vector_update_subrange = {vector_update_subrange_dec, vector_update_subrange_inc} + +val bitvector_concat = {c: "append", ocaml: "append", lem: "concat_vec", coq: "concat_vec"} : forall ('n : Int) ('m : Int). + (bits('n), bits('m)) -> bits('n + 'm) + +val vector_concat = {ocaml: "append", lem: "append_list"} : forall ('n : Int) ('m : Int) ('a : Type). + (vector('n, dec, 'a), vector('m, dec, 'a)) -> vector('n + 'm, dec, 'a) + +overload append = {bitvector_concat, vector_concat} + +val "reg_deref" : forall ('a : Type). register('a) -> 'a effect {rreg} +/* sneaky deref with no effect necessary for bitfield writes */ +val _reg_deref = "reg_deref" : forall ('a : Type). register('a) -> 'a + +type xlen : Int = 64 +type ylen : Int = 1 + +type xlenbits = bits(xlen) + +bitfield Mstatus : xlenbits = { + SD : xlen - ylen, + SXL : xlen - ylen - 1 .. xlen - ylen - 3 +} +register mstatus : Mstatus |
