summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--language/sail.ott4
-rw-r--r--src/ast_util.ml9
-rw-r--r--src/ast_util.mli1
-rw-r--r--src/bitfield.ml25
-rw-r--r--src/initial_check.ml20
-rw-r--r--src/parse_ast.ml4
-rw-r--r--src/parser.mly4
-rw-r--r--src/pretty_print_common.ml2
-rw-r--r--src/pretty_print_coq.ml22
-rw-r--r--src/pretty_print_lem.ml13
-rw-r--r--src/type_check.ml5
-rw-r--r--test/typecheck/pass/bitvector_param.sail42
12 files changed, 124 insertions, 27 deletions
diff --git a/language/sail.ott b/language/sail.ott
index 6d2760ff..7dbd3c9e 100644
--- a/language/sail.ott
+++ b/language/sail.ott
@@ -316,8 +316,8 @@ type_union :: 'Tu_' ::=
index_range :: 'BF_' ::= {{ com index specification, for bitfields in register types}}
{{ aux _ l }}
- | num :: :: 'single' {{ com single index }}
- | num1 '..' num2 :: :: range {{ com index range }}
+ | nexp :: :: 'single' {{ com single index }}
+ | nexp1 '..' nexp2 :: :: range {{ com index range }}
| index_range1 , index_range2 :: :: concat {{ com concatenation of index ranges }}
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
diff --git a/src/ast_util.ml b/src/ast_util.ml
index 03031767..b3ab2cfd 100644
--- a/src/ast_util.ml
+++ b/src/ast_util.ml
@@ -236,6 +236,11 @@ let rec is_nexp_constant (Nexp_aux (nexp, _)) = match nexp with
| Nexp_exp n | Nexp_neg n -> is_nexp_constant n
| Nexp_app (_, nexps) -> List.for_all is_nexp_constant nexps
+let int_of_nexp_opt nexp =
+ match nexp with
+ | Nexp_aux(Nexp_constant i,_) -> Some i
+ | _ -> None
+
let rec nexp_simp (Nexp_aux (nexp, l)) = Nexp_aux (nexp_simp_aux nexp, l)
and nexp_simp_aux = function
(* (n - (n - m)) often appears in foreach loops *)
@@ -911,8 +916,8 @@ and string_of_letbind (LB_aux (lb, l)) =
let rec string_of_index_range (BF_aux (ir, _)) =
match ir with
- | BF_single n -> Big_int.to_string n
- | BF_range (n, m) -> Big_int.to_string n ^ " .. " ^ Big_int.to_string m
+ | BF_single n -> string_of_nexp n
+ | BF_range (n, m) -> string_of_nexp n ^ " .. " ^ string_of_nexp m
| BF_concat (ir1, ir2) -> "(" ^ string_of_index_range ir1 ^ ") : (" ^ string_of_index_range ir2 ^ ")"
diff --git a/src/ast_util.mli b/src/ast_util.mli
index 4cbea3dc..c4eb0b4b 100644
--- a/src/ast_util.mli
+++ b/src/ast_util.mli
@@ -337,6 +337,7 @@ end
val nexp_frees : nexp -> KidSet.t
val nexp_identical : nexp -> nexp -> bool
val is_nexp_constant : nexp -> bool
+val int_of_nexp_opt : nexp -> Big_int.num option
val lexp_to_exp : 'a lexp -> 'a exp
diff --git a/src/bitfield.ml b/src/bitfield.ml
index e8250598..1f64adbd 100644
--- a/src/bitfield.ml
+++ b/src/bitfield.ml
@@ -150,19 +150,30 @@ let index_range_update name field order start stop =
let index_range_overload name field order =
ast_of_def_string (Printf.sprintf "overload _mod_%s = {_get_%s_%s, _set_%s_%s}" field name field name field)
-let index_range_accessor name field order (BF_aux (bf_aux, l)) =
+let index_range_accessor (eval, typ_error) name field order (BF_aux (bf_aux, l)) =
let getter n m = index_range_getter name field order (Big_int.to_int n) (Big_int.to_int m) in
let setter n m = index_range_setter name field order (Big_int.to_int n) (Big_int.to_int m) in
let update n m = index_range_update name field order (Big_int.to_int n) (Big_int.to_int m) in
let overload = index_range_overload name field order in
+ let const_fold nexp = match eval nexp with
+ | Some v -> v
+ | None -> typ_error l (Printf.sprintf "Non-constant index for field %s" field) in
match bf_aux with
- | BF_single n -> combine [getter n n; setter n n; update n n; overload]
- | BF_range (n, m) -> combine [getter n m; setter n m; update n m; overload]
+ | BF_single n ->
+ let n = const_fold n in
+ combine [getter n n; setter n n; update n n; overload]
+ | BF_range (n, m) ->
+ let n, m = const_fold n, const_fold m in
+ combine [getter n m; setter n m; update n m; overload]
| BF_concat _ -> failwith "Unimplemented"
-let field_accessor name order (id, ir) = index_range_accessor name (string_of_id id) order ir
+let field_accessor (eval, typ_error) name order (id, ir) =
+ index_range_accessor (eval, typ_error) name (string_of_id id) order ir
-let macro id size order ranges =
+let macro (eval, typ_error) id size order ranges =
let name = string_of_id id in
- let ranges = (mk_id "bits", BF_aux (BF_range (Big_int.of_int (size - 1), Big_int.of_int 0), Parse_ast.Unknown)) :: ranges in
- combine ([newtype name size order; constructor name order (size - 1) 0] @ List.map (field_accessor name order) ranges)
+ let ranges = (mk_id "bits", BF_aux (BF_range (nconstant (Big_int.of_int (size - 1)),
+ nconstant (Big_int.of_int 0)),
+ Parse_ast.Unknown)) :: ranges in
+ combine ([newtype name size order; constructor name order (size - 1) 0]
+ @ List.map (field_accessor (eval, typ_error) name order) ranges)
diff --git a/src/initial_check.ml b/src/initial_check.ml
index 108e04d0..07316c6d 100644
--- a/src/initial_check.ml
+++ b/src/initial_check.ml
@@ -201,6 +201,20 @@ and to_ast_nexp ctx (P.ATyp_aux (aux, l)) =
in
Nexp_aux (aux, l)
+and to_ast_bitfield_index_nexp (P.ATyp_aux (aux, l)) =
+ let aux = match aux with
+ | P.ATyp_id id -> Nexp_id (to_ast_id id)
+ | P.ATyp_lit (P.L_aux (P.L_num c, _)) -> Nexp_constant c
+ | P.ATyp_sum (t1, t2) -> Nexp_sum (to_ast_bitfield_index_nexp t1, to_ast_bitfield_index_nexp t2)
+ | P.ATyp_exp t1 -> Nexp_exp (to_ast_bitfield_index_nexp t1)
+ | P.ATyp_neg t1 -> Nexp_neg (to_ast_bitfield_index_nexp t1)
+ | P.ATyp_times (t1, t2) -> Nexp_times (to_ast_bitfield_index_nexp t1, to_ast_bitfield_index_nexp t2)
+ | P.ATyp_minus (t1, t2) -> Nexp_minus (to_ast_bitfield_index_nexp t1, to_ast_bitfield_index_nexp t2)
+ | P.ATyp_app (id, ts) -> Nexp_app (to_ast_id id, List.map (to_ast_bitfield_index_nexp) ts)
+ | _ -> raise (Reporting.err_typ l "Invalid numeric expression in field index")
+ in
+ Nexp_aux (aux, l)
+
and to_ast_order ctx (P.ATyp_aux (aux, l)) =
match aux with
| ATyp_var v -> Ord_aux (Ord_var (to_ast_var v), l)
@@ -503,9 +517,9 @@ let to_ast_spec ctx (val_:P.val_spec) : (unit val_spec) ctx_out =
let rec to_ast_range (P.BF_aux(r,l)) = (* TODO add check that ranges are sensible for some definition of sensible *)
BF_aux(
(match r with
- | P.BF_single(i) -> BF_single(i)
- | P.BF_range(i1,i2) -> BF_range(i1,i2)
- | P.BF_concat(ir1,ir2) -> BF_concat( to_ast_range ir1, to_ast_range ir2)),
+ | P.BF_single(i) -> BF_single(to_ast_bitfield_index_nexp i)
+ | P.BF_range(i1,i2) -> BF_range(to_ast_bitfield_index_nexp i1,to_ast_bitfield_index_nexp i2)
+ | P.BF_concat(ir1,ir2) -> BF_concat(to_ast_range ir1, to_ast_range ir2)),
l)
let to_ast_type_union ctx (P.Tu_aux (P.Tu_ty_id (atyp, id), l)) =
diff --git a/src/parse_ast.ml b/src/parse_ast.ml
index 6401331e..eb5c3dc6 100644
--- a/src/parse_ast.ml
+++ b/src/parse_ast.ml
@@ -371,8 +371,8 @@ type_union =
type
index_range_aux = (* index specification, for bitfields in register types *)
- BF_single of Big_int.num (* single index *)
- | BF_range of Big_int.num * Big_int.num (* index range *)
+ BF_single of atyp (* single index *)
+ | BF_range of atyp * atyp (* index range *)
| BF_concat of index_range * index_range (* concatenation of index ranges *)
and index_range =
diff --git a/src/parser.mly b/src/parser.mly
index cbbc41e3..bd832d28 100644
--- a/src/parser.mly
+++ b/src/parser.mly
@@ -1130,9 +1130,9 @@ funcl_typ:
{ mk_tannot mk_typqn $1 $startpos $endpos }
index_range:
- | Num
+ | typ
{ mk_ir (BF_single $1) $startpos $endpos }
- | Num DotDot Num
+ | typ DotDot typ
{ mk_ir (BF_range ($1, $3)) $startpos $endpos }
r_id_def:
diff --git a/src/pretty_print_common.ml b/src/pretty_print_common.ml
index c01896ac..3a1deed0 100644
--- a/src/pretty_print_common.ml
+++ b/src/pretty_print_common.ml
@@ -89,10 +89,12 @@ let doc_id (Id_aux(i,_)) =
* token in case of x ending with star. *)
parens (separate space [string "deinfix"; string x; empty])
+(*
let rec doc_range (BF_aux(r,_)) = match r with
| BF_single i -> doc_int i
| BF_range(i1,i2) -> doc_op dotdot (doc_int i1) (doc_int i2)
| BF_concat(ir1,ir2) -> (doc_range ir1) ^^ comma ^^ (doc_range ir2)
+*)
let print ?(len=100) channel doc = ToChannel.pretty 1. len channel doc
let to_buf ?(len=100) buf doc = ToBuffer.pretty 1. len buf doc
diff --git a/src/pretty_print_coq.ml b/src/pretty_print_coq.ml
index bb6a3d6a..4596f23f 100644
--- a/src/pretty_print_coq.ml
+++ b/src/pretty_print_coq.ml
@@ -1896,10 +1896,12 @@ let doc_type_union ctxt typ_name (Tu_aux(Tu_ty_id(typ,id),_)) =
separate space [doc_id_ctor id; colon;
doc_typ ctxt typ; arrow; typ_name]
-let rec doc_range (BF_aux(r,_)) = match r with
- | BF_single i -> parens (doc_op comma (doc_int i) (doc_int i))
- | BF_range(i1,i2) -> parens (doc_op comma (doc_int i1) (doc_int i2))
- | BF_concat(ir1,ir2) -> (doc_range ir1) ^^ comma ^^ (doc_range ir2)
+(*
+let rec doc_range ctxt (BF_aux(r,_)) = match r with
+ | BF_single i -> parens (doc_op comma (doc_nexp ctxt i) (doc_nexp ctxt i))
+ | BF_range(i1,i2) -> parens (doc_op comma (doc_nexp ctxt i1) (doc_nexp ctxt i2))
+ | BF_concat(ir1,ir2) -> (doc_range ctxt ir1) ^^ comma ^^ (doc_range ctxt ir2)
+ *)
let doc_typdef generic_eq_types (TD_aux(td, (l, annot))) = match td with
| TD_abbrev(id,typq,A_aux (A_typ typ, _)) ->
@@ -2408,7 +2410,15 @@ let is_field_accessor regtypes fdef =
(access = "get" || access = "set") && is_field_of regtyp field
| _ -> false
+
+let int_of_field_index tname fid nexp =
+ match int_of_nexp_opt nexp with
+ | Some i -> i
+ | None -> raise (Reporting.err_typ Parse_ast.Unknown
+ ("Non-constant bitfield index in field " ^ string_of_id fid ^ " of " ^ tname))
+
let doc_regtype_fields (tname, (n1, n2, fields)) =
+ let const_int fid idx = int_of_field_index tname fid idx in
let i1, i2 = match n1, n2 with
| Nexp_aux(Nexp_constant i1,_),Nexp_aux(Nexp_constant i2,_) -> i1, i2
| _ -> raise (Reporting.err_typ Parse_ast.Unknown
@@ -2417,8 +2427,8 @@ let doc_regtype_fields (tname, (n1, n2, fields)) =
let dir = (if dir_b then "true" else "false") in
let doc_field (fr, fid) =
let i, j = match fr with
- | BF_aux (BF_single i, _) -> (i, i)
- | BF_aux (BF_range (i, j), _) -> (i, j)
+ | BF_aux (BF_single i, _) -> let i = const_int fid i in (i, i)
+ | BF_aux (BF_range (i, j), _) -> (const_int fid i, const_int fid j)
| _ -> raise (Reporting.err_unreachable Parse_ast.Unknown __POS__
("Unsupported type in field " ^ string_of_id fid ^ " of " ^ tname)) in
let fsize = Big_int.succ (Big_int.abs (Big_int.sub i j)) in
diff --git a/src/pretty_print_lem.ml b/src/pretty_print_lem.ml
index 7d2cc479..dee0a29f 100644
--- a/src/pretty_print_lem.ml
+++ b/src/pretty_print_lem.ml
@@ -1016,10 +1016,12 @@ let doc_type_union_lem env (Tu_aux(Tu_ty_id(typ,id),_)) =
separate space [pipe; doc_id_lem_ctor id; string "of";
parens (doc_typ_lem env typ)]
+(*
let rec doc_range_lem (BF_aux(r,_)) = match r with
| BF_single i -> parens (doc_op comma (doc_int i) (doc_int i))
| BF_range(i1,i2) -> parens (doc_op comma (doc_int i1) (doc_int i2))
| BF_concat(ir1,ir2) -> (doc_range ir1) ^^ comma ^^ (doc_range ir2)
+ *)
let doc_typdef_lem env (TD_aux(td, (l, annot))) = match td with
| TD_abbrev(id,typq,A_aux (A_typ typ, _)) ->
@@ -1392,7 +1394,14 @@ let is_field_accessor regtypes fdef =
(access = "get" || access = "set") && is_field_of regtyp field
| _ -> false
+let int_of_field_index tname fid nexp =
+ match int_of_nexp_opt nexp with
+ | Some i -> i
+ | None -> raise (Reporting.err_typ Parse_ast.Unknown
+ ("Non-constant bitfield index in field " ^ string_of_id fid ^ " of " ^ tname))
+
let doc_regtype_fields (tname, (n1, n2, fields)) =
+ let const_int fid idx = int_of_field_index tname fid idx in
let i1, i2 = match n1, n2 with
| Nexp_aux(Nexp_constant i1,_),Nexp_aux(Nexp_constant i2,_) -> i1, i2
| _ -> raise (Reporting.err_typ Parse_ast.Unknown
@@ -1401,8 +1410,8 @@ let doc_regtype_fields (tname, (n1, n2, fields)) =
let dir = (if dir_b then "true" else "false") in
let doc_field (fr, fid) =
let i, j = match fr with
- | BF_aux (BF_single i, _) -> (i, i)
- | BF_aux (BF_range (i, j), _) -> (i, j)
+ | BF_aux (BF_single i, _) -> let i = const_int fid i in (i, i)
+ | BF_aux (BF_range (i, j), _) -> (const_int fid i, const_int fid j)
| _ -> raise (Reporting.err_unreachable Parse_ast.Unknown __POS__
("Unsupported type in field " ^ string_of_id fid ^ " of " ^ tname)) in
let fsize = Big_int.succ (Big_int.abs (Big_int.sub i j)) in
diff --git a/src/type_check.ml b/src/type_check.ml
index f31da5f4..a19f77de 100644
--- a/src/type_check.ml
+++ b/src/type_check.ml
@@ -4787,7 +4787,10 @@ let rec check_typedef : 'a. Env.t -> 'a type_def -> (tannot def) list * Env.t =
A_aux (A_typ (Typ_aux (Typ_id b, _)), _)]), _)
when string_of_id v = "vector" && string_of_id b = "bit" ->
let size = Big_int.to_int size in
- let (Defs defs), env = check env (Bitfield.macro id size order ranges) in
+ let eval_index_nexp env nexp =
+ int_of_nexp_opt (nexp_simp (Env.expand_nexp_synonyms env nexp)) in
+ let (Defs defs), env =
+ check env (Bitfield.macro (eval_index_nexp env, (typ_error env)) id size order ranges) in
defs, env
| _ ->
typ_error env l "Bad bitfield type"
diff --git a/test/typecheck/pass/bitvector_param.sail b/test/typecheck/pass/bitvector_param.sail
new file mode 100644
index 00000000..ffebeb6e
--- /dev/null
+++ b/test/typecheck/pass/bitvector_param.sail
@@ -0,0 +1,42 @@
+/* from prelude */
+default Order dec
+type bits ('n : Int) = vector('n, dec, bit)
+
+val vector_subrange = {
+ ocaml: "subrange",
+ lem: "subrange_vec_dec",
+ c: "vector_subrange",
+ coq: "subrange_vec_dec"
+} : forall ('n : Int) ('m : Int) ('o : Int), 0 <= 'o <= 'm < 'n.
+ (bits('n), atom('m), atom('o)) -> bits('m - 'o + 1)
+
+val vector_update_subrange_dec = {ocaml: "update_subrange", c: "vector_update_subrange", lem: "update_subrange_vec_dec", coq: "update_subrange_vec_dec"} : forall 'n 'm 'o.
+ (bits('n), atom('m), atom('o), bits('m - ('o - 1))) -> bits('n)
+
+val vector_update_subrange_inc = {ocaml: "update_subrange", lem: "update_subrange_vec_inc"} : forall 'n 'm 'o.
+ (vector('n, inc, bit), atom('m), atom('o), vector('o - ('m - 1), inc, bit)) -> vector('n, inc, bit)
+
+overload vector_update_subrange = {vector_update_subrange_dec, vector_update_subrange_inc}
+
+val bitvector_concat = {c: "append", ocaml: "append", lem: "concat_vec", coq: "concat_vec"} : forall ('n : Int) ('m : Int).
+ (bits('n), bits('m)) -> bits('n + 'm)
+
+val vector_concat = {ocaml: "append", lem: "append_list"} : forall ('n : Int) ('m : Int) ('a : Type).
+ (vector('n, dec, 'a), vector('m, dec, 'a)) -> vector('n + 'm, dec, 'a)
+
+overload append = {bitvector_concat, vector_concat}
+
+val "reg_deref" : forall ('a : Type). register('a) -> 'a effect {rreg}
+/* sneaky deref with no effect necessary for bitfield writes */
+val _reg_deref = "reg_deref" : forall ('a : Type). register('a) -> 'a
+
+type xlen : Int = 64
+type ylen : Int = 1
+
+type xlenbits = bits(xlen)
+
+bitfield Mstatus : xlenbits = {
+ SD : xlen - ylen,
+ SXL : xlen - ylen - 1 .. xlen - ylen - 3
+}
+register mstatus : Mstatus