summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--src/type_check.ml6
-rw-r--r--test/typecheck/pass/int_synonym.sail18
2 files changed, 23 insertions, 1 deletions
diff --git a/src/type_check.ml b/src/type_check.ml
index 65e13a19..0c936860 100644
--- a/src/type_check.ml
+++ b/src/type_check.ml
@@ -548,6 +548,10 @@ end = struct
match aux with
| NC_or (nc1, nc2) -> NC_aux (NC_or (expand_constraint_synonyms env nc1, expand_constraint_synonyms env nc2), l)
| NC_and (nc1, nc2) -> NC_aux (NC_and (expand_constraint_synonyms env nc1, expand_constraint_synonyms env nc2), l)
+ | NC_equal (n1, n2) -> NC_aux (NC_equal (expand_nexp_synonyms env n1, expand_nexp_synonyms env n2), l)
+ | NC_not_equal (n1, n2) -> NC_aux (NC_not_equal (expand_nexp_synonyms env n1, expand_nexp_synonyms env n2), l)
+ | NC_bounded_le (n1, n2) -> NC_aux (NC_bounded_le (expand_nexp_synonyms env n1, expand_nexp_synonyms env n2), l)
+ | NC_bounded_ge (n1, n2) -> NC_aux (NC_bounded_ge (expand_nexp_synonyms env n1, expand_nexp_synonyms env n2), l)
| NC_app (id, args) ->
(try
begin match Bindings.find id env.typ_synonyms l env args with
@@ -555,7 +559,7 @@ end = struct
| arg -> typ_error env l ("Expected Bool when expanding synonym " ^ string_of_id id ^ " got " ^ string_of_typ_arg arg)
end
with Not_found -> NC_aux (NC_app (id, List.map (expand_synonyms_arg env) args), l))
- | NC_true | NC_false | NC_equal _ | NC_not_equal _ | NC_bounded_le _ | NC_bounded_ge _ | NC_var _ | NC_set _ -> nc
+ | NC_true | NC_false | NC_var _ | NC_set _ -> nc
and expand_nexp_synonyms env (Nexp_aux (aux, l) as nexp) =
typ_debug ~level:2 (lazy ("Expanding " ^ string_of_nexp nexp));
diff --git a/test/typecheck/pass/int_synonym.sail b/test/typecheck/pass/int_synonym.sail
new file mode 100644
index 00000000..33bdaf0c
--- /dev/null
+++ b/test/typecheck/pass/int_synonym.sail
@@ -0,0 +1,18 @@
+/* from prelude */
+default Order dec
+
+$include <flow.sail>
+
+type bits ('n : Int) = vector('n, dec, bit)
+
+type xlen : Int = 64
+type xlen_bytes : Int = 8
+type xlenbits = bits(xlen)
+
+val "sign_extend" : forall 'n 'm, 'm >= 'n. (bits('n), atom('m)) -> bits('m)
+val EXTS : forall 'n 'm, 'm >= 'n. (implicit('m), bits('n)) -> bits('m)
+function EXTS(m, v) = sign_extend(v, m)
+
+val extend : forall 'n, 'n <= xlen_bytes. (bool, bits(8 * 'n)) -> xlenbits
+
+function extend(flag, value) = EXTS(value)