summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorAlasdair2019-02-02 00:10:27 +0000
committerAlasdair2019-02-02 00:10:27 +0000
commit38befd6856101fb7a7a4b49dcb0306dc9dd2f64f (patch)
treed8abbace91daa7ab39e77d2c01ee20f1602a87ce
parent2f8dd66dcaec500561f8736c98bebf65938fa608 (diff)
parent4f45f462333c5494a84886677bc78a49c84da081 (diff)
Merge remote-tracking branch 'origin/sail2' into asl_flow2
-rw-r--r--editors/sail-mode.el22
-rw-r--r--lib/hol/Holmakefile8
-rw-r--r--lib/hol/Makefile4
-rw-r--r--src/ast_util.ml141
-rw-r--r--src/monomorphise.ml184
-rw-r--r--src/pretty_print_lem.ml16
-rw-r--r--src/rewrites.ml10
-rw-r--r--src/type_check.mli4
-rw-r--r--test/mono/castrequnion.sail58
-rw-r--r--test/mono/flow_extend.sail16
-rw-r--r--test/mono/pass/castrequnion1
-rw-r--r--test/mono/pass/flow_extend1
12 files changed, 353 insertions, 112 deletions
diff --git a/editors/sail-mode.el b/editors/sail-mode.el
index 6dae0761..b1adccaf 100644
--- a/editors/sail-mode.el
+++ b/editors/sail-mode.el
@@ -1,3 +1,24 @@
+;;; sail-mode.el --- Major mode for editing .sail files -*- lexical-binding: t; -*-
+
+;; Copyright (C) 2013-2018 The Sail Authors
+;;
+;; Author: The Sail Authors
+;; URL: http://github.com/rems-project/sail
+;; Package-Requires: ((emacs "25"))
+;; Version: 0.0.1
+;; Keywords: language
+
+;; This file is not part of GNU Emacs.
+
+;;; License:
+
+;; 2-Clause BSD License (See LICENSE file in Sail repository)
+
+;;; Commentary:
+
+;; This mode is only compatible with new, recent of the new Sail on the "sail2" branch.
+
+;;; Code:
(defvar sail2-mode-hook nil)
@@ -65,3 +86,4 @@
(provide 'sail2-mode)
+;;; sail-mode.el ends here
diff --git a/lib/hol/Holmakefile b/lib/hol/Holmakefile
index 8e8403f3..0da5813f 100644
--- a/lib/hol/Holmakefile
+++ b/lib/hol/Holmakefile
@@ -1,3 +1,5 @@
+# Ensure LEM_DIR is set before running Holmake, e.g., by using the accompanying Makefile
+
LEM_SCRIPTS = sail2_instr_kindsScript.sml sail2_valuesScript.sml sail2_operatorsScript.sml \
sail2_operators_mwordsScript.sml sail2_operators_bitlistsScript.sml \
sail2_state_monadScript.sml sail2_stateScript.sml sail2_promptScript.sml sail2_prompt_monadScript.sml \
@@ -10,9 +12,7 @@ SCRIPTS = $(LEM_SCRIPTS) \
THYS = $(patsubst %Script.sml,%Theory.uo,$(SCRIPTS))
-LEMDIR=../../../lem/hol-lib
-
-INCLUDES = $(LEMDIR)
+INCLUDES = $(LEM_DIR)/hol-lib
all: $(THYS)
.PHONY: all
@@ -23,7 +23,7 @@ ifdef POLY
HOLHEAP = sail-heap
EXTRA_CLEANS = $(LEM_CLEANS) $(HOLHEAP) $(HOLHEAP).o
-BASE_HEAP = $(LEMDIR)/lemheap
+BASE_HEAP = $(LEM_DIR)/hol-lib/lemheap
$(HOLHEAP): $(BASE_HEAP)
$(protect $(HOLDIR)/bin/buildheap) -o $(HOLHEAP) -b $(BASE_HEAP)
diff --git a/lib/hol/Makefile b/lib/hol/Makefile
index 783ef23d..c863a05b 100644
--- a/lib/hol/Makefile
+++ b/lib/hol/Makefile
@@ -1,3 +1,5 @@
+LEM_DIR?=$(shell opam config var lem:share)
+
LEMSRC = \
../../src/lem_interp/sail2_instr_kinds.lem \
../../src/gen_lib/sail2_values.lem \
@@ -25,7 +27,7 @@ $(SCRIPTS): $(LEMSRC)
lem -hol -outdir . -auxiliary_level none -lib ../../src/lem_interp -lib ../../src/gen_lib $(LEMSRC)
$(THYS) sail-heap: $(SCRIPTS)
- Holmake
+ LEM_DIR=$(LEM_DIR) Holmake
# Holmake will also clear out the generated $(SCRIPTS) files
clean:
diff --git a/src/ast_util.ml b/src/ast_util.ml
index 26c6b9df..1fe4798f 100644
--- a/src/ast_util.ml
+++ b/src/ast_util.ml
@@ -996,77 +996,88 @@ let lex_ord f g x1 x2 y1 y2 =
| 0 -> g y1 y2
| n -> n
+let rec nc_compare (NC_aux (nc1,_)) (NC_aux (nc2,_)) =
+ match nc1, nc2 with
+ | NC_equal (n1,n2), NC_equal (n3,n4)
+ | NC_bounded_ge (n1,n2), NC_bounded_ge (n3,n4)
+ | NC_bounded_le (n1,n2), NC_bounded_le (n3,n4)
+ | NC_not_equal (n1,n2), NC_not_equal (n3,n4)
+ -> lex_ord Nexp.compare Nexp.compare n1 n3 n2 n4
+ | NC_set (k1,s1), NC_set (k2,s2) ->
+ lex_ord Kid.compare (Util.compare_list Nat_big_num.compare) k1 k2 s1 s2
+ | NC_or (nc1,nc2), NC_or (nc3,nc4)
+ | NC_and (nc1,nc2), NC_and (nc3,nc4)
+ -> lex_ord nc_compare nc_compare nc1 nc3 nc2 nc4
+ | NC_app (f1,args1), NC_app (f2,args2)
+ -> lex_ord Id.compare (Util.compare_list typ_arg_compare) f1 f2 args1 args2
+ | NC_var v1, NC_var v2
+ -> Kid.compare v1 v2
+ | NC_true, NC_true
+ | NC_false, NC_false
+ -> 0
+ | NC_equal _, _ -> -1 | _, NC_equal _ -> 1
+ | NC_bounded_ge _, _ -> -1 | _, NC_bounded_ge _ -> 1
+ | NC_bounded_le _, _ -> -1 | _, NC_bounded_le _ -> 1
+ | NC_not_equal _, _ -> -1 | _, NC_not_equal _ -> 1
+ | NC_set _, _ -> -1 | _, NC_set _ -> 1
+ | NC_or _, _ -> -1 | _, NC_or _ -> 1
+ | NC_and _, _ -> -1 | _, NC_and _ -> 1
+ | NC_app _, _ -> -1 | _, NC_app _ -> 1
+ | NC_var _, _ -> -1 | _, NC_var _ -> 1
+ | NC_true, _ -> -1 | _, NC_true -> 1
+
+and typ_compare (Typ_aux (t1,_)) (Typ_aux (t2,_)) =
+ match t1,t2 with
+ | Typ_internal_unknown, Typ_internal_unknown -> 0
+ | Typ_id id1, Typ_id id2 -> Id.compare id1 id2
+ | Typ_var kid1, Typ_var kid2 -> Kid.compare kid1 kid2
+ | Typ_fn (ts1,t2,e1), Typ_fn (ts3,t4,e2) ->
+ (match Util.compare_list typ_compare ts1 ts3 with
+ | 0 -> (match typ_compare t2 t4 with
+ | 0 -> effect_compare e1 e2
+ | n -> n)
+ | n -> n)
+ | Typ_bidir (t1,t2), Typ_bidir (t3,t4) ->
+ (match typ_compare t1 t3 with
+ | 0 -> typ_compare t2 t3
+ | n -> n)
+ | Typ_tup ts1, Typ_tup ts2 -> Util.compare_list typ_compare ts1 ts2
+ | Typ_exist (ks1,nc1,t1), Typ_exist (ks2,nc2,t2) ->
+ (match Util.compare_list KOpt.compare ks1 ks2 with
+ | 0 -> (match nc_compare nc1 nc2 with
+ | 0 -> typ_compare t1 t2
+ | n -> n)
+ | n -> n)
+ | Typ_app (id1,ts1), Typ_app (id2,ts2) ->
+ (match Id.compare id1 id2 with
+ | 0 -> Util.compare_list typ_arg_compare ts1 ts2
+ | n -> n)
+ | Typ_internal_unknown, _ -> -1 | _, Typ_internal_unknown -> 1
+ | Typ_id _, _ -> -1 | _, Typ_id _ -> 1
+ | Typ_var _, _ -> -1 | _, Typ_var _ -> 1
+ | Typ_fn _, _ -> -1 | _, Typ_fn _ -> 1
+ | Typ_bidir _, _ -> -1 | _, Typ_bidir _ -> 1
+ | Typ_tup _, _ -> -1 | _, Typ_tup _ -> 1
+ | Typ_exist _, _ -> -1 | _, Typ_exist _ -> 1
+
+and typ_arg_compare (A_aux (ta1,_)) (A_aux (ta2,_)) =
+ match ta1, ta2 with
+ | A_nexp n1, A_nexp n2 -> Nexp.compare n1 n2
+ | A_typ t1, A_typ t2 -> typ_compare t1 t2
+ | A_order o1, A_order o2 -> order_compare o1 o2
+ | A_bool nc1, A_bool nc2 -> nc_compare nc1 nc2
+ | A_nexp _, _ -> -1 | _, A_nexp _ -> 1
+ | A_typ _, _ -> -1 | _, A_typ _ -> 1
+ | A_order _, _ -> -1 | _, A_order _ -> 1
+
module NC = struct
type t = n_constraint
- let rec compare (NC_aux (nc1,_)) (NC_aux (nc2,_)) =
- match nc1, nc2 with
- | NC_equal (n1,n2), NC_equal (n3,n4)
- | NC_bounded_ge (n1,n2), NC_bounded_ge (n3,n4)
- | NC_bounded_le (n1,n2), NC_bounded_le (n3,n4)
- | NC_not_equal (n1,n2), NC_not_equal (n3,n4)
- -> lex_ord Nexp.compare Nexp.compare n1 n3 n2 n4
- | NC_set (k1,s1), NC_set (k2,s2) ->
- lex_ord Kid.compare (Util.compare_list Nat_big_num.compare) k1 k2 s1 s2
- | NC_or (nc1,nc2), NC_or (nc3,nc4)
- | NC_and (nc1,nc2), NC_and (nc3,nc4)
- -> lex_ord compare compare nc1 nc3 nc2 nc4
- | NC_true, NC_true
- | NC_false, NC_false
- -> 0
- | NC_equal _, _ -> -1 | _, NC_equal _ -> 1
- | NC_bounded_ge _, _ -> -1 | _, NC_bounded_ge _ -> 1
- | NC_bounded_le _, _ -> -1 | _, NC_bounded_le _ -> 1
- | NC_not_equal _, _ -> -1 | _, NC_not_equal _ -> 1
- | NC_set _, _ -> -1 | _, NC_set _ -> 1
- | NC_or _, _ -> -1 | _, NC_or _ -> 1
- | NC_and _, _ -> -1 | _, NC_and _ -> 1
- | NC_true, _ -> -1 | _, NC_true -> 1
+ let compare = nc_compare
end
module Typ = struct
type t = typ
- let rec compare (Typ_aux (t1,_)) (Typ_aux (t2,_)) =
- match t1,t2 with
- | Typ_internal_unknown, Typ_internal_unknown -> 0
- | Typ_id id1, Typ_id id2 -> Id.compare id1 id2
- | Typ_var kid1, Typ_var kid2 -> Kid.compare kid1 kid2
- | Typ_fn (ts1,t2,e1), Typ_fn (ts3,t4,e2) ->
- (match Util.compare_list compare ts1 ts3 with
- | 0 -> (match compare t2 t4 with
- | 0 -> effect_compare e1 e2
- | n -> n)
- | n -> n)
- | Typ_bidir (t1,t2), Typ_bidir (t3,t4) ->
- (match compare t1 t3 with
- | 0 -> compare t2 t3
- | n -> n)
- | Typ_tup ts1, Typ_tup ts2 -> Util.compare_list compare ts1 ts2
- | Typ_exist (ks1,nc1,t1), Typ_exist (ks2,nc2,t2) ->
- (match Util.compare_list KOpt.compare ks1 ks2 with
- | 0 -> (match NC.compare nc1 nc2 with
- | 0 -> compare t1 t2
- | n -> n)
- | n -> n)
- | Typ_app (id1,ts1), Typ_app (id2,ts2) ->
- (match Id.compare id1 id2 with
- | 0 -> Util.compare_list arg_compare ts1 ts2
- | n -> n)
- | Typ_internal_unknown, _ -> -1 | _, Typ_internal_unknown -> 1
- | Typ_id _, _ -> -1 | _, Typ_id _ -> 1
- | Typ_var _, _ -> -1 | _, Typ_var _ -> 1
- | Typ_fn _, _ -> -1 | _, Typ_fn _ -> 1
- | Typ_bidir _, _ -> -1 | _, Typ_bidir _ -> 1
- | Typ_tup _, _ -> -1 | _, Typ_tup _ -> 1
- | Typ_exist _, _ -> -1 | _, Typ_exist _ -> 1
- and arg_compare (A_aux (ta1,_)) (A_aux (ta2,_)) =
- match ta1, ta2 with
- | A_nexp n1, A_nexp n2 -> Nexp.compare n1 n2
- | A_typ t1, A_typ t2 -> compare t1 t2
- | A_order o1, A_order o2 -> order_compare o1 o2
- | A_bool nc1, A_bool nc2 -> NC.compare nc1 nc2
- | A_nexp _, _ -> -1 | _, A_nexp _ -> 1
- | A_typ _, _ -> -1 | _, A_typ _ -> 1
- | A_order _, _ -> -1 | _, A_order _ -> 1
+ let compare = typ_compare
end
module TypMap = Map.Make(Typ)
diff --git a/src/monomorphise.ml b/src/monomorphise.ml
index e16431b8..ab1a2f82 100644
--- a/src/monomorphise.ml
+++ b/src/monomorphise.ml
@@ -3935,16 +3935,46 @@ end
module BitvectorSizeCasts =
struct
-let simplify_size_nexp env quant_kids (Nexp_aux (_,l) as nexp) =
- match solve env nexp with
- | Some n -> Some (nconstant n)
- | None ->
- let is_equal kid =
- prove __POS__ env (NC_aux (NC_equal (Nexp_aux (Nexp_var kid,Unknown), nexp),Unknown))
- in
- match List.find is_equal quant_kids with
- | kid -> Some (Nexp_aux (Nexp_var kid,Generated l))
- | exception Not_found -> None
+let simplify_size_nexp env quant_kids nexp =
+ let rec aux (Nexp_aux (ne,l) as nexp) =
+ match solve env nexp with
+ | Some n -> Some (nconstant n)
+ | None ->
+ let is_equal kid =
+ prove __POS__ env (NC_aux (NC_equal (Nexp_aux (Nexp_var kid,Unknown), nexp),Unknown))
+ in
+ match List.find is_equal quant_kids with
+ | kid -> Some (Nexp_aux (Nexp_var kid,Generated l))
+ | exception Not_found ->
+ (* Normally rewriting of complex nexps in function signatures will
+ produce a simple constant or variable above, but occasionally it's
+ useful to work when that rewriting hasn't been applied. In
+ particular, that rewriting isn't fully working with RISC-V at the
+ moment. *)
+ let re f = function
+ | Some n1, Some n2 -> Some (Nexp_aux (f n1 n2,l))
+ | _ -> None
+ in
+ match ne with
+ | Nexp_times(n1,n2) ->
+ re (fun n1 n2 -> Nexp_times(n1,n2)) (aux n1, aux n2)
+ | Nexp_sum(n1,n2) ->
+ re (fun n1 n2 -> Nexp_sum(n1,n2)) (aux n1, aux n2)
+ | Nexp_minus(n1,n2) ->
+ re (fun n1 n2 -> Nexp_times(n1,n2)) (aux n1, aux n2)
+ | Nexp_exp n ->
+ Util.option_map (fun n -> Nexp_aux (Nexp_exp n,l)) (aux n)
+ | Nexp_neg n ->
+ Util.option_map (fun n -> Nexp_aux (Nexp_neg n,l)) (aux n)
+ | _ -> None
+ in aux nexp
+
+let specs_required = ref IdSet.empty
+let check_for_spec env name =
+ let id = mk_id name in
+ match Env.get_val_spec id env with
+ | _ -> ()
+ | exception _ -> specs_required := IdSet.add id !specs_required
(* These functions add cast functions across case splits, so that when a
bitvector size becomes known in sail, the generated Lem code contains a
@@ -4000,6 +4030,7 @@ let make_bitvector_cast_fns cast_name env quant_kids src_typ target_typ =
let pat, e' = aux src_typ' target_typ' in
match !at_least_one with
| Some one_target_typ -> begin
+ check_for_spec env cast_name;
let src_ann = mk_tannot env src_typ no_effect in
let tar_ann = mk_tannot env target_typ no_effect in
match src_typ' with
@@ -4033,7 +4064,49 @@ let make_bitvector_env_casts env quant_kids (kid,i) exp =
Bindings.fold (fun var (mut,typ) exp ->
if mut = Immutable then mk_cast var typ exp else exp) locals exp
-let make_bitvector_cast_exp cast_name env quant_kids typ target_typ exp = (snd (make_bitvector_cast_fns cast_name env quant_kids typ target_typ)) exp
+let make_bitvector_cast_exp cast_name cast_env quant_kids typ target_typ exp =
+ let infer_arg_typ env f l typ =
+ let (typq, ctor_typ) = Env.get_union_id f env in
+ let quants = quant_items typq in
+ match Env.expand_synonyms env ctor_typ with
+ | Typ_aux (Typ_fn ([arg_typ], ret_typ, _), _) ->
+ begin
+ let goals = quant_kopts typq |> List.map kopt_kid |> KidSet.of_list in
+ let unifiers = unify l env goals ret_typ typ in
+ let arg_typ' = subst_unifiers unifiers arg_typ in
+ arg_typ'
+ end
+ | _ -> typ_error l ("Malformed constructor " ^ string_of_id f ^ " with type " ^ string_of_typ ctor_typ)
+
+ in
+ (* Push the cast down, including through constructors *)
+ let rec aux exp (typ, target_typ) =
+ let exp_env = env_of exp in
+ match exp with
+ | E_aux (E_let (lb,exp'),ann) ->
+ E_aux (E_let (lb,aux exp' (typ, target_typ)),ann)
+ | E_aux (E_block exps,ann) ->
+ let exps' = match List.rev exps with
+ | [] -> []
+ | final::l -> aux final (typ, target_typ)::l
+ in E_aux (E_block (List.rev exps'),ann)
+ | E_aux (E_tuple exps,(l,ann)) -> begin
+ match Env.expand_synonyms exp_env typ, Env.expand_synonyms exp_env target_typ with
+ | Typ_aux (Typ_tup src_typs,_), Typ_aux (Typ_tup tgt_typs,_) ->
+ E_aux (E_tuple (List.map2 aux exps (List.combine src_typs tgt_typs)),(l,ann))
+ | _ -> raise (Reporting.err_unreachable l __POS__
+ ("Attempted to insert cast on tuple on non-tuple type: " ^
+ string_of_typ typ ^ " to " ^ string_of_typ target_typ))
+ end
+ | E_aux (E_app (f,args),(l,ann)) when Env.is_union_constructor f (env_of exp) ->
+ let arg = match args with [arg] -> arg | _ -> E_aux (E_tuple args, (l,empty_tannot)) in
+ let src_arg_typ = infer_arg_typ (env_of exp) f l typ in
+ let tgt_arg_typ = infer_arg_typ (env_of exp) f l target_typ in
+ E_aux (E_app (f,[aux arg (src_arg_typ, tgt_arg_typ)]),(l,ann))
+ | _ ->
+ (snd (make_bitvector_cast_fns cast_name cast_env quant_kids typ target_typ)) exp
+ in
+ aux exp (typ, target_typ)
let rec extract_value_from_guard var (E_aux (e,_)) =
match e with
@@ -4082,13 +4155,14 @@ let add_bitvector_casts (Defs defs) =
let pat,guard,body,ann = destruct_pexp pexp in
let body = match pat, guard with
| P_aux (P_lit (L_aux (L_num i,_)),_), _ ->
- let src_typ = subst_src_typ (KBindings.singleton kid (nconstant i)) result_typ in
+ (* We used to just substitute kid, but fill_in_type also catches other kids defined by it *)
+ let src_typ = fill_in_type (Env.add_constraint (nc_eq (nvar kid) (nconstant i)) env) result_typ in
make_bitvector_cast_exp "bitvector_cast_out" env quant_kids src_typ result_typ
(make_bitvector_env_casts env quant_kids (kid,i) body)
| P_aux (P_id var,_), Some guard ->
(match extract_value_from_guard var guard with
| Some i ->
- let src_typ = subst_src_typ (KBindings.singleton kid (nconstant i)) result_typ in
+ let src_typ = fill_in_type (Env.add_constraint (nc_eq (nvar kid) (nconstant i)) env) result_typ in
make_bitvector_cast_exp "bitvector_cast_out" env quant_kids src_typ result_typ
(make_bitvector_env_casts env quant_kids (kid,i) body)
| None -> body)
@@ -4112,6 +4186,13 @@ let add_bitvector_casts (Defs defs) =
(match destruct_atom_nexp (env_of y) (typ_of y) with
| Some (Nexp_aux (Nexp_constant i,_)) -> [(kid,i)]
| _ -> [])
+ | E_app (op,[x;y])
+ when string_of_id op = "eq_int" ->
+ (match destruct_atom_nexp (env_of x) (typ_of x), destruct_atom_nexp (env_of y) (typ_of y) with
+ | Some (Nexp_aux (Nexp_var kid,_)), Some (Nexp_aux (Nexp_constant i,_))
+ | Some (Nexp_aux (Nexp_constant i,_)), Some (Nexp_aux (Nexp_var kid,_))
+ -> [(kid,i)]
+ | _ -> [])
| E_app (op, [x;y]) when string_of_id op = "and_bool" ->
extract x @ extract y
| _ -> []
@@ -4126,10 +4207,16 @@ let add_bitvector_casts (Defs defs) =
E_aux (E_if (e1,e2',e3), ann)
| E_return e' ->
E_aux (E_return (make_bitvector_cast_exp "bitvector_cast_out" top_env quant_kids (fill_in_type (env_of e') (typ_of e')) ret_typ e'),ann)
- | E_assign (LEXP_aux (lexp,lexp_annot),e') ->
- E_aux (E_assign (LEXP_aux (lexp,lexp_annot),
- make_bitvector_cast_exp "bitvector_cast_out" top_env quant_kids (fill_in_type (env_of e') (typ_of e'))
- (typ_of_annot lexp_annot) e'),ann)
+ | E_assign (LEXP_aux (_,lexp_annot) as lexp,e') -> begin
+ (* The type in the lexp_annot might come from e' rather than being the
+ type of the storage, so ask the type checker what it really is. *)
+ match infer_lexp (env_of_annot lexp_annot) (strip_lexp lexp) with
+ | LEXP_aux (_,lexp_annot') ->
+ E_aux (E_assign (lexp,
+ make_bitvector_cast_exp "bitvector_cast_out" top_env quant_kids (fill_in_type (env_of e') (typ_of e'))
+ (typ_of_annot lexp_annot') e'),ann)
+ | exception _ -> E_aux (e,ann)
+ end
| E_id id -> begin
let env = env_of_annot ann in
match Env.lookup_id id env with
@@ -4173,7 +4260,23 @@ let add_bitvector_casts (Defs defs) =
| DEF_fundef (FD_aux (FD_function (r,t,e,fcls),fd_ann)) ->
DEF_fundef (FD_aux (FD_function (r,t,e,List.map rewrite_funcl fcls),fd_ann))
| d -> d
- in Defs (List.map rewrite_def defs)
+ in
+ specs_required := IdSet.empty;
+ let defs = List.map rewrite_def defs in
+ let l = Generated Unknown in
+ let Defs cast_specs,_ =
+ (* TODO: use default/relevant order *)
+ let kid = mk_kid "n" in
+ let bitsn = vector_typ (nvar kid) dec_ord bit_typ in
+ let ts = mk_typschm (mk_typquant [mk_qi_id K_int kid])
+ (function_typ [bitsn] bitsn no_effect) in
+ let extfn _ = Some "zeroExtend" in
+ let mkfn name =
+ mk_val_spec (VS_val_spec (ts,name,extfn,false))
+ in
+ let defs = List.map mkfn (IdSet.elements !specs_required) in
+ check Env.empty (Defs defs)
+ in Defs (cast_specs @ defs)
end
let replace_nexp_in_typ env typ orig new_nexp =
@@ -4249,14 +4352,13 @@ let rewrite_toplevel_nexps (Defs defs) =
let nexp_map, typ = rewrite_typ_in_spec env nexp_map typ in
(nexp_map, typ::t)) typs (nexp_map,[])
in nexp_map, Typ_aux (Typ_tup typs,ann)
- | _ ->
- let typ' = Env.base_typ_of env typ_full in
+ | _ when is_number typ_full || is_bitvector_typ typ_full -> begin
let nexp_opt =
- match destruct_atom_nexp env typ' with
+ match destruct_atom_nexp env typ_full with
| Some nexp -> Some nexp
| None ->
- if is_bitvector_typ typ' then
- let (size,_,_) = vector_typ_args_of typ' in
+ if is_bitvector_typ typ_full then
+ let (size,_,_) = vector_typ_args_of typ_full in
Some size
else None
in match nexp_opt with
@@ -4272,10 +4374,27 @@ let rewrite_toplevel_nexps (Defs defs) =
(kid, nexp)::nexp_map, kid
in
let new_nexp = nvar kid in
- (* Try to avoid expanding the original type *)
- let changed, typ = replace_nexp_in_typ env typ_full nexp new_nexp in
- if changed then nexp_map, typ
- else nexp_map, snd (replace_nexp_in_typ env typ' nexp new_nexp)
+ nexp_map, snd (replace_nexp_in_typ env typ_full nexp new_nexp)
+ end
+ | _ ->
+ let typ' = Env.base_typ_of env typ_full in
+ if Typ.compare typ_full typ' == 0 then
+ match t with
+ | Typ_app (f,args) ->
+ let in_arg nexp_map (A_aux (arg,l) as arg_full) =
+ match arg with
+ | A_typ typ ->
+ let nexp_map, typ' = rewrite_typ_in_spec env nexp_map typ in
+ nexp_map, A_aux (A_typ typ',l)
+ | A_bool _ | A_nexp _ | A_order _ -> nexp_map, arg_full
+ in
+ let nexp_map, args =
+ List.fold_right (fun arg (nexp_map,args) ->
+ let nexp_map, arg = in_arg nexp_map arg in
+ (nexp_map, arg::args)) args (nexp_map,[])
+ in nexp_map, Typ_aux (Typ_app (f,args),ann)
+ | _ -> nexp_map, typ_full
+ else rewrite_typ_in_spec env nexp_map typ'
in
let rewrite_valspec (VS_aux (VS_val_spec (TypSchm_aux (TypSchm_ts (tqs,typ),ts_l),id,ext_opt,is_cast),ann)) =
match tqs with
@@ -4307,10 +4426,9 @@ let rewrite_toplevel_nexps (Defs defs) =
| A_typ typ -> A_aux (A_typ (aux typ),l)
| A_order _ -> ta_full
| A_nexp nexp ->
- (match find_nexp env nexp_map nexp with
+ match find_nexp env nexp_map nexp with
| (kid,_) -> A_aux (A_nexp (nvar kid),l)
- | exception Not_found -> ta_full)
- | _ -> ta_full
+ | exception Not_found -> ta_full
in aux typ
in
let rewrite_one_exp nexp_map (e,ann) =
@@ -4344,7 +4462,11 @@ let rewrite_toplevel_nexps (Defs defs) =
| DEF_spec vs -> (match rewrite_valspec vs with
| None -> spec_map, def
| Some (id, nexp_map, vs) -> Bindings.add id nexp_map spec_map, DEF_spec vs)
- | DEF_fundef (FD_aux (FD_function (recopt,tann,eff,funcls),ann)) ->
+ | DEF_fundef (FD_aux (FD_function (recopt,_,eff,funcls),ann)) ->
+ (* Type annotations on function definitions will have been turned into
+ valspecs by type checking, so it should be safe to drop them rather
+ than updating them. *)
+ let tann = Typ_annot_opt_aux (Typ_annot_opt_none,Generated Unknown) in
spec_map,
DEF_fundef (FD_aux (FD_function (recopt,tann,eff,List.map (rewrite_funcl spec_map) funcls),ann))
| _ -> spec_map, def
diff --git a/src/pretty_print_lem.ml b/src/pretty_print_lem.ml
index 8cef3529..169bd824 100644
--- a/src/pretty_print_lem.ml
+++ b/src/pretty_print_lem.ml
@@ -359,14 +359,14 @@ let replace_typ_size ctxt env (Typ_aux (t,a)) =
let mk_typ nexp =
Some (Typ_aux (Typ_app (id, [A_aux (A_nexp nexp,Parse_ast.Unknown);ord;typ']),a))
in
- let is_equal nexp =
- prove __POS__ env (NC_aux (NC_equal (size,nexp),Parse_ast.Unknown))
- in match List.find is_equal (NexpSet.elements ctxt.bound_nexps) with
- | nexp -> mk_typ nexp
- | exception Not_found ->
- match Type_check.solve env size with
- | Some n -> mk_typ (nconstant n)
- | None -> None
+ match Type_check.solve env size with
+ | Some n -> mk_typ (nconstant n)
+ | None ->
+ let is_equal nexp =
+ prove __POS__ env (NC_aux (NC_equal (size,nexp),Parse_ast.Unknown))
+ in match List.find is_equal (NexpSet.elements ctxt.bound_nexps) with
+ | nexp -> mk_typ nexp
+ | exception Not_found -> None
end
| _ -> None
diff --git a/src/rewrites.ml b/src/rewrites.ml
index f8146a72..8fa90643 100644
--- a/src/rewrites.ml
+++ b/src/rewrites.ml
@@ -5017,6 +5017,10 @@ let if_mono f defs =
| [], false -> defs
| _, _ -> f defs
+(* Also turn mwords stages on when we're just trying out mono *)
+let if_mwords f defs =
+ if !Pretty_print_lem.opt_mwords then f defs else if_mono f defs
+
let rewrite_defs_lem = [
("realise_mappings", rewrite_defs_realise_mappings);
("remove_mapping_valspecs", remove_mapping_valspecs);
@@ -5027,10 +5031,10 @@ let rewrite_defs_lem = [
("recheck_defs", if_mono recheck_defs);
("rewrite_toplevel_nexps", if_mono rewrite_toplevel_nexps);
("monomorphise", if_mono monomorphise);
- ("recheck_defs", if_mono recheck_defs);
- ("add_bitvector_casts", if_mono Monomorphise.add_bitvector_casts);
+ ("recheck_defs", if_mwords recheck_defs);
+ ("add_bitvector_casts", if_mwords Monomorphise.add_bitvector_casts);
("rewrite_atoms_to_singletons", if_mono Monomorphise.rewrite_atoms_to_singletons);
- ("recheck_defs", if_mono recheck_defs);
+ ("recheck_defs", if_mwords recheck_defs);
("rewrite_undefined", rewrite_undefined_if_gen false);
("rewrite_defs_vector_string_pats_to_bit_list", rewrite_defs_vector_string_pats_to_bit_list);
("remove_not_pats", rewrite_defs_not_pats);
diff --git a/src/type_check.mli b/src/type_check.mli
index 2663c1c7..82e9ebc1 100644
--- a/src/type_check.mli
+++ b/src/type_check.mli
@@ -206,6 +206,8 @@ module Env : sig
val pattern_completeness_ctx : t -> Pattern_completeness.ctx
val builtin_typs : typquant Bindings.t
+
+ val get_union_id : id -> t -> typquant * typ
end
(** Push all the type variables and constraints from a typquant into
@@ -299,6 +301,8 @@ val infer_exp : Env.t -> unit exp -> tannot exp
val infer_pat : Env.t -> unit pat -> tannot pat * Env.t * unit exp list
+val infer_lexp : Env.t -> unit lexp -> tannot lexp
+
val check_case : Env.t -> typ -> unit pexp -> typ -> tannot pexp
val check_fundef : Env.t -> 'a fundef -> tannot def list * Env.t
diff --git a/test/mono/castrequnion.sail b/test/mono/castrequnion.sail
new file mode 100644
index 00000000..4729fb11
--- /dev/null
+++ b/test/mono/castrequnion.sail
@@ -0,0 +1,58 @@
+default Order dec
+$include <prelude.sail>
+
+val bitvector_cast_in = "zeroExtend" : forall 'n. bits('n) -> bits('n) effect pure
+val bitvector_cast_out = "zeroExtend" : forall 'n. bits('n) -> bits('n) effect pure
+
+val foo : forall 'm 'n, 'm in {8,16} & 'n in {16,32,64}. bits('m) -> option(bits('n)) effect pure
+
+function foo(x) =
+ let y : bits(16) = sail_zero_extend(x,16) in
+ match 'n {
+ 16 => None(),
+ 32 => Some(y@y),
+ 64 => let z = y@y@y@y in let dfsf = 4 in Some(z)
+ }
+
+union Result ('a : Type) = {
+ Value : ('a, int),
+ Complaint : string
+}
+
+/* Getting ahead of myself: the 2*'n isn't supported yet, although shouldn't it end up in the form below?
+*/
+val bar : forall 'n, 'n in {8,16,32}. bits('n) -> Result(bits(2*'n))
+
+function bar(x) =
+ match 'n {
+ 8 => Complaint("No bytes"),
+ 16 => Value(x@x, unsigned(x)),
+ 32 => Value(sail_sign_extend(x,64), unsigned(x))
+ }
+/*
+val bar : forall 'n 'm, 'n in {8,16,32} & 'm == 2*'n. bits('n) -> Result(bits('m))
+
+function bar(x) =
+ match 'n {
+ 8 => Complaint("No bytes"),
+ 16 => Value(x@x, unsigned(x)),
+ 32 => Value(sail_sign_extend(x,64), unsigned(x))
+ }
+*/
+
+val cmp : forall 'n. (option(bits('n)), option(bits('n))) -> bool
+
+function cmp (None(),None()) = true
+and cmp (None(),Some(_)) = false
+and cmp (Some(_),None()) = false
+and cmp (Some(x),Some(y)) = x == y
+
+overload operator == = {cmp}
+
+val run : unit -> unit effect {escape}
+
+function run() = {
+ assert((foo(0x12) : option(bits(16))) == None());
+ assert((foo(0x12) : option(bits(32))) == Some(0x00120012));
+ assert((foo(0x12) : option(bits(64))) == Some(0x0012001200120012));
+}
diff --git a/test/mono/flow_extend.sail b/test/mono/flow_extend.sail
new file mode 100644
index 00000000..7e118993
--- /dev/null
+++ b/test/mono/flow_extend.sail
@@ -0,0 +1,16 @@
+default Order dec
+$include <prelude.sail>
+
+val bitvector_cast_in = "zeroExtend" : forall 'n. bits('n) -> bits('n) effect pure
+val bitvector_cast_out = "zeroExtend" : forall 'n. bits('n) -> bits('n) effect pure
+
+val byte_extend : forall 'n, 'n >= 8. (bits(8), int('n)) -> bits('n)
+
+function byte_extend (v, n) = if (n == 8) then v else sail_zero_extend(v, n)
+
+val run : unit -> unit effect {escape}
+
+function run() = {
+ assert(byte_extend(0x12,8) == 0x12);
+ assert(byte_extend(0x12,16) == 0x0012);
+}
diff --git a/test/mono/pass/castrequnion b/test/mono/pass/castrequnion
new file mode 100644
index 00000000..9b2a2f38
--- /dev/null
+++ b/test/mono/pass/castrequnion
@@ -0,0 +1 @@
+castrequnion.sail -auto_mono
diff --git a/test/mono/pass/flow_extend b/test/mono/pass/flow_extend
new file mode 100644
index 00000000..fea386e5
--- /dev/null
+++ b/test/mono/pass/flow_extend
@@ -0,0 +1 @@
+flow_extend.sail -auto_mono