diff options
| author | Alasdair | 2019-02-02 00:10:27 +0000 |
|---|---|---|
| committer | Alasdair | 2019-02-02 00:10:27 +0000 |
| commit | 38befd6856101fb7a7a4b49dcb0306dc9dd2f64f (patch) | |
| tree | d8abbace91daa7ab39e77d2c01ee20f1602a87ce | |
| parent | 2f8dd66dcaec500561f8736c98bebf65938fa608 (diff) | |
| parent | 4f45f462333c5494a84886677bc78a49c84da081 (diff) | |
Merge remote-tracking branch 'origin/sail2' into asl_flow2
| -rw-r--r-- | editors/sail-mode.el | 22 | ||||
| -rw-r--r-- | lib/hol/Holmakefile | 8 | ||||
| -rw-r--r-- | lib/hol/Makefile | 4 | ||||
| -rw-r--r-- | src/ast_util.ml | 141 | ||||
| -rw-r--r-- | src/monomorphise.ml | 184 | ||||
| -rw-r--r-- | src/pretty_print_lem.ml | 16 | ||||
| -rw-r--r-- | src/rewrites.ml | 10 | ||||
| -rw-r--r-- | src/type_check.mli | 4 | ||||
| -rw-r--r-- | test/mono/castrequnion.sail | 58 | ||||
| -rw-r--r-- | test/mono/flow_extend.sail | 16 | ||||
| -rw-r--r-- | test/mono/pass/castrequnion | 1 | ||||
| -rw-r--r-- | test/mono/pass/flow_extend | 1 |
12 files changed, 353 insertions, 112 deletions
diff --git a/editors/sail-mode.el b/editors/sail-mode.el index 6dae0761..b1adccaf 100644 --- a/editors/sail-mode.el +++ b/editors/sail-mode.el @@ -1,3 +1,24 @@ +;;; sail-mode.el --- Major mode for editing .sail files -*- lexical-binding: t; -*- + +;; Copyright (C) 2013-2018 The Sail Authors +;; +;; Author: The Sail Authors +;; URL: http://github.com/rems-project/sail +;; Package-Requires: ((emacs "25")) +;; Version: 0.0.1 +;; Keywords: language + +;; This file is not part of GNU Emacs. + +;;; License: + +;; 2-Clause BSD License (See LICENSE file in Sail repository) + +;;; Commentary: + +;; This mode is only compatible with new, recent of the new Sail on the "sail2" branch. + +;;; Code: (defvar sail2-mode-hook nil) @@ -65,3 +86,4 @@ (provide 'sail2-mode) +;;; sail-mode.el ends here diff --git a/lib/hol/Holmakefile b/lib/hol/Holmakefile index 8e8403f3..0da5813f 100644 --- a/lib/hol/Holmakefile +++ b/lib/hol/Holmakefile @@ -1,3 +1,5 @@ +# Ensure LEM_DIR is set before running Holmake, e.g., by using the accompanying Makefile + LEM_SCRIPTS = sail2_instr_kindsScript.sml sail2_valuesScript.sml sail2_operatorsScript.sml \ sail2_operators_mwordsScript.sml sail2_operators_bitlistsScript.sml \ sail2_state_monadScript.sml sail2_stateScript.sml sail2_promptScript.sml sail2_prompt_monadScript.sml \ @@ -10,9 +12,7 @@ SCRIPTS = $(LEM_SCRIPTS) \ THYS = $(patsubst %Script.sml,%Theory.uo,$(SCRIPTS)) -LEMDIR=../../../lem/hol-lib - -INCLUDES = $(LEMDIR) +INCLUDES = $(LEM_DIR)/hol-lib all: $(THYS) .PHONY: all @@ -23,7 +23,7 @@ ifdef POLY HOLHEAP = sail-heap EXTRA_CLEANS = $(LEM_CLEANS) $(HOLHEAP) $(HOLHEAP).o -BASE_HEAP = $(LEMDIR)/lemheap +BASE_HEAP = $(LEM_DIR)/hol-lib/lemheap $(HOLHEAP): $(BASE_HEAP) $(protect $(HOLDIR)/bin/buildheap) -o $(HOLHEAP) -b $(BASE_HEAP) diff --git a/lib/hol/Makefile b/lib/hol/Makefile index 783ef23d..c863a05b 100644 --- a/lib/hol/Makefile +++ b/lib/hol/Makefile @@ -1,3 +1,5 @@ +LEM_DIR?=$(shell opam config var lem:share) + LEMSRC = \ ../../src/lem_interp/sail2_instr_kinds.lem \ ../../src/gen_lib/sail2_values.lem \ @@ -25,7 +27,7 @@ $(SCRIPTS): $(LEMSRC) lem -hol -outdir . -auxiliary_level none -lib ../../src/lem_interp -lib ../../src/gen_lib $(LEMSRC) $(THYS) sail-heap: $(SCRIPTS) - Holmake + LEM_DIR=$(LEM_DIR) Holmake # Holmake will also clear out the generated $(SCRIPTS) files clean: diff --git a/src/ast_util.ml b/src/ast_util.ml index 26c6b9df..1fe4798f 100644 --- a/src/ast_util.ml +++ b/src/ast_util.ml @@ -996,77 +996,88 @@ let lex_ord f g x1 x2 y1 y2 = | 0 -> g y1 y2 | n -> n +let rec nc_compare (NC_aux (nc1,_)) (NC_aux (nc2,_)) = + match nc1, nc2 with + | NC_equal (n1,n2), NC_equal (n3,n4) + | NC_bounded_ge (n1,n2), NC_bounded_ge (n3,n4) + | NC_bounded_le (n1,n2), NC_bounded_le (n3,n4) + | NC_not_equal (n1,n2), NC_not_equal (n3,n4) + -> lex_ord Nexp.compare Nexp.compare n1 n3 n2 n4 + | NC_set (k1,s1), NC_set (k2,s2) -> + lex_ord Kid.compare (Util.compare_list Nat_big_num.compare) k1 k2 s1 s2 + | NC_or (nc1,nc2), NC_or (nc3,nc4) + | NC_and (nc1,nc2), NC_and (nc3,nc4) + -> lex_ord nc_compare nc_compare nc1 nc3 nc2 nc4 + | NC_app (f1,args1), NC_app (f2,args2) + -> lex_ord Id.compare (Util.compare_list typ_arg_compare) f1 f2 args1 args2 + | NC_var v1, NC_var v2 + -> Kid.compare v1 v2 + | NC_true, NC_true + | NC_false, NC_false + -> 0 + | NC_equal _, _ -> -1 | _, NC_equal _ -> 1 + | NC_bounded_ge _, _ -> -1 | _, NC_bounded_ge _ -> 1 + | NC_bounded_le _, _ -> -1 | _, NC_bounded_le _ -> 1 + | NC_not_equal _, _ -> -1 | _, NC_not_equal _ -> 1 + | NC_set _, _ -> -1 | _, NC_set _ -> 1 + | NC_or _, _ -> -1 | _, NC_or _ -> 1 + | NC_and _, _ -> -1 | _, NC_and _ -> 1 + | NC_app _, _ -> -1 | _, NC_app _ -> 1 + | NC_var _, _ -> -1 | _, NC_var _ -> 1 + | NC_true, _ -> -1 | _, NC_true -> 1 + +and typ_compare (Typ_aux (t1,_)) (Typ_aux (t2,_)) = + match t1,t2 with + | Typ_internal_unknown, Typ_internal_unknown -> 0 + | Typ_id id1, Typ_id id2 -> Id.compare id1 id2 + | Typ_var kid1, Typ_var kid2 -> Kid.compare kid1 kid2 + | Typ_fn (ts1,t2,e1), Typ_fn (ts3,t4,e2) -> + (match Util.compare_list typ_compare ts1 ts3 with + | 0 -> (match typ_compare t2 t4 with + | 0 -> effect_compare e1 e2 + | n -> n) + | n -> n) + | Typ_bidir (t1,t2), Typ_bidir (t3,t4) -> + (match typ_compare t1 t3 with + | 0 -> typ_compare t2 t3 + | n -> n) + | Typ_tup ts1, Typ_tup ts2 -> Util.compare_list typ_compare ts1 ts2 + | Typ_exist (ks1,nc1,t1), Typ_exist (ks2,nc2,t2) -> + (match Util.compare_list KOpt.compare ks1 ks2 with + | 0 -> (match nc_compare nc1 nc2 with + | 0 -> typ_compare t1 t2 + | n -> n) + | n -> n) + | Typ_app (id1,ts1), Typ_app (id2,ts2) -> + (match Id.compare id1 id2 with + | 0 -> Util.compare_list typ_arg_compare ts1 ts2 + | n -> n) + | Typ_internal_unknown, _ -> -1 | _, Typ_internal_unknown -> 1 + | Typ_id _, _ -> -1 | _, Typ_id _ -> 1 + | Typ_var _, _ -> -1 | _, Typ_var _ -> 1 + | Typ_fn _, _ -> -1 | _, Typ_fn _ -> 1 + | Typ_bidir _, _ -> -1 | _, Typ_bidir _ -> 1 + | Typ_tup _, _ -> -1 | _, Typ_tup _ -> 1 + | Typ_exist _, _ -> -1 | _, Typ_exist _ -> 1 + +and typ_arg_compare (A_aux (ta1,_)) (A_aux (ta2,_)) = + match ta1, ta2 with + | A_nexp n1, A_nexp n2 -> Nexp.compare n1 n2 + | A_typ t1, A_typ t2 -> typ_compare t1 t2 + | A_order o1, A_order o2 -> order_compare o1 o2 + | A_bool nc1, A_bool nc2 -> nc_compare nc1 nc2 + | A_nexp _, _ -> -1 | _, A_nexp _ -> 1 + | A_typ _, _ -> -1 | _, A_typ _ -> 1 + | A_order _, _ -> -1 | _, A_order _ -> 1 + module NC = struct type t = n_constraint - let rec compare (NC_aux (nc1,_)) (NC_aux (nc2,_)) = - match nc1, nc2 with - | NC_equal (n1,n2), NC_equal (n3,n4) - | NC_bounded_ge (n1,n2), NC_bounded_ge (n3,n4) - | NC_bounded_le (n1,n2), NC_bounded_le (n3,n4) - | NC_not_equal (n1,n2), NC_not_equal (n3,n4) - -> lex_ord Nexp.compare Nexp.compare n1 n3 n2 n4 - | NC_set (k1,s1), NC_set (k2,s2) -> - lex_ord Kid.compare (Util.compare_list Nat_big_num.compare) k1 k2 s1 s2 - | NC_or (nc1,nc2), NC_or (nc3,nc4) - | NC_and (nc1,nc2), NC_and (nc3,nc4) - -> lex_ord compare compare nc1 nc3 nc2 nc4 - | NC_true, NC_true - | NC_false, NC_false - -> 0 - | NC_equal _, _ -> -1 | _, NC_equal _ -> 1 - | NC_bounded_ge _, _ -> -1 | _, NC_bounded_ge _ -> 1 - | NC_bounded_le _, _ -> -1 | _, NC_bounded_le _ -> 1 - | NC_not_equal _, _ -> -1 | _, NC_not_equal _ -> 1 - | NC_set _, _ -> -1 | _, NC_set _ -> 1 - | NC_or _, _ -> -1 | _, NC_or _ -> 1 - | NC_and _, _ -> -1 | _, NC_and _ -> 1 - | NC_true, _ -> -1 | _, NC_true -> 1 + let compare = nc_compare end module Typ = struct type t = typ - let rec compare (Typ_aux (t1,_)) (Typ_aux (t2,_)) = - match t1,t2 with - | Typ_internal_unknown, Typ_internal_unknown -> 0 - | Typ_id id1, Typ_id id2 -> Id.compare id1 id2 - | Typ_var kid1, Typ_var kid2 -> Kid.compare kid1 kid2 - | Typ_fn (ts1,t2,e1), Typ_fn (ts3,t4,e2) -> - (match Util.compare_list compare ts1 ts3 with - | 0 -> (match compare t2 t4 with - | 0 -> effect_compare e1 e2 - | n -> n) - | n -> n) - | Typ_bidir (t1,t2), Typ_bidir (t3,t4) -> - (match compare t1 t3 with - | 0 -> compare t2 t3 - | n -> n) - | Typ_tup ts1, Typ_tup ts2 -> Util.compare_list compare ts1 ts2 - | Typ_exist (ks1,nc1,t1), Typ_exist (ks2,nc2,t2) -> - (match Util.compare_list KOpt.compare ks1 ks2 with - | 0 -> (match NC.compare nc1 nc2 with - | 0 -> compare t1 t2 - | n -> n) - | n -> n) - | Typ_app (id1,ts1), Typ_app (id2,ts2) -> - (match Id.compare id1 id2 with - | 0 -> Util.compare_list arg_compare ts1 ts2 - | n -> n) - | Typ_internal_unknown, _ -> -1 | _, Typ_internal_unknown -> 1 - | Typ_id _, _ -> -1 | _, Typ_id _ -> 1 - | Typ_var _, _ -> -1 | _, Typ_var _ -> 1 - | Typ_fn _, _ -> -1 | _, Typ_fn _ -> 1 - | Typ_bidir _, _ -> -1 | _, Typ_bidir _ -> 1 - | Typ_tup _, _ -> -1 | _, Typ_tup _ -> 1 - | Typ_exist _, _ -> -1 | _, Typ_exist _ -> 1 - and arg_compare (A_aux (ta1,_)) (A_aux (ta2,_)) = - match ta1, ta2 with - | A_nexp n1, A_nexp n2 -> Nexp.compare n1 n2 - | A_typ t1, A_typ t2 -> compare t1 t2 - | A_order o1, A_order o2 -> order_compare o1 o2 - | A_bool nc1, A_bool nc2 -> NC.compare nc1 nc2 - | A_nexp _, _ -> -1 | _, A_nexp _ -> 1 - | A_typ _, _ -> -1 | _, A_typ _ -> 1 - | A_order _, _ -> -1 | _, A_order _ -> 1 + let compare = typ_compare end module TypMap = Map.Make(Typ) diff --git a/src/monomorphise.ml b/src/monomorphise.ml index e16431b8..ab1a2f82 100644 --- a/src/monomorphise.ml +++ b/src/monomorphise.ml @@ -3935,16 +3935,46 @@ end module BitvectorSizeCasts = struct -let simplify_size_nexp env quant_kids (Nexp_aux (_,l) as nexp) = - match solve env nexp with - | Some n -> Some (nconstant n) - | None -> - let is_equal kid = - prove __POS__ env (NC_aux (NC_equal (Nexp_aux (Nexp_var kid,Unknown), nexp),Unknown)) - in - match List.find is_equal quant_kids with - | kid -> Some (Nexp_aux (Nexp_var kid,Generated l)) - | exception Not_found -> None +let simplify_size_nexp env quant_kids nexp = + let rec aux (Nexp_aux (ne,l) as nexp) = + match solve env nexp with + | Some n -> Some (nconstant n) + | None -> + let is_equal kid = + prove __POS__ env (NC_aux (NC_equal (Nexp_aux (Nexp_var kid,Unknown), nexp),Unknown)) + in + match List.find is_equal quant_kids with + | kid -> Some (Nexp_aux (Nexp_var kid,Generated l)) + | exception Not_found -> + (* Normally rewriting of complex nexps in function signatures will + produce a simple constant or variable above, but occasionally it's + useful to work when that rewriting hasn't been applied. In + particular, that rewriting isn't fully working with RISC-V at the + moment. *) + let re f = function + | Some n1, Some n2 -> Some (Nexp_aux (f n1 n2,l)) + | _ -> None + in + match ne with + | Nexp_times(n1,n2) -> + re (fun n1 n2 -> Nexp_times(n1,n2)) (aux n1, aux n2) + | Nexp_sum(n1,n2) -> + re (fun n1 n2 -> Nexp_sum(n1,n2)) (aux n1, aux n2) + | Nexp_minus(n1,n2) -> + re (fun n1 n2 -> Nexp_times(n1,n2)) (aux n1, aux n2) + | Nexp_exp n -> + Util.option_map (fun n -> Nexp_aux (Nexp_exp n,l)) (aux n) + | Nexp_neg n -> + Util.option_map (fun n -> Nexp_aux (Nexp_neg n,l)) (aux n) + | _ -> None + in aux nexp + +let specs_required = ref IdSet.empty +let check_for_spec env name = + let id = mk_id name in + match Env.get_val_spec id env with + | _ -> () + | exception _ -> specs_required := IdSet.add id !specs_required (* These functions add cast functions across case splits, so that when a bitvector size becomes known in sail, the generated Lem code contains a @@ -4000,6 +4030,7 @@ let make_bitvector_cast_fns cast_name env quant_kids src_typ target_typ = let pat, e' = aux src_typ' target_typ' in match !at_least_one with | Some one_target_typ -> begin + check_for_spec env cast_name; let src_ann = mk_tannot env src_typ no_effect in let tar_ann = mk_tannot env target_typ no_effect in match src_typ' with @@ -4033,7 +4064,49 @@ let make_bitvector_env_casts env quant_kids (kid,i) exp = Bindings.fold (fun var (mut,typ) exp -> if mut = Immutable then mk_cast var typ exp else exp) locals exp -let make_bitvector_cast_exp cast_name env quant_kids typ target_typ exp = (snd (make_bitvector_cast_fns cast_name env quant_kids typ target_typ)) exp +let make_bitvector_cast_exp cast_name cast_env quant_kids typ target_typ exp = + let infer_arg_typ env f l typ = + let (typq, ctor_typ) = Env.get_union_id f env in + let quants = quant_items typq in + match Env.expand_synonyms env ctor_typ with + | Typ_aux (Typ_fn ([arg_typ], ret_typ, _), _) -> + begin + let goals = quant_kopts typq |> List.map kopt_kid |> KidSet.of_list in + let unifiers = unify l env goals ret_typ typ in + let arg_typ' = subst_unifiers unifiers arg_typ in + arg_typ' + end + | _ -> typ_error l ("Malformed constructor " ^ string_of_id f ^ " with type " ^ string_of_typ ctor_typ) + + in + (* Push the cast down, including through constructors *) + let rec aux exp (typ, target_typ) = + let exp_env = env_of exp in + match exp with + | E_aux (E_let (lb,exp'),ann) -> + E_aux (E_let (lb,aux exp' (typ, target_typ)),ann) + | E_aux (E_block exps,ann) -> + let exps' = match List.rev exps with + | [] -> [] + | final::l -> aux final (typ, target_typ)::l + in E_aux (E_block (List.rev exps'),ann) + | E_aux (E_tuple exps,(l,ann)) -> begin + match Env.expand_synonyms exp_env typ, Env.expand_synonyms exp_env target_typ with + | Typ_aux (Typ_tup src_typs,_), Typ_aux (Typ_tup tgt_typs,_) -> + E_aux (E_tuple (List.map2 aux exps (List.combine src_typs tgt_typs)),(l,ann)) + | _ -> raise (Reporting.err_unreachable l __POS__ + ("Attempted to insert cast on tuple on non-tuple type: " ^ + string_of_typ typ ^ " to " ^ string_of_typ target_typ)) + end + | E_aux (E_app (f,args),(l,ann)) when Env.is_union_constructor f (env_of exp) -> + let arg = match args with [arg] -> arg | _ -> E_aux (E_tuple args, (l,empty_tannot)) in + let src_arg_typ = infer_arg_typ (env_of exp) f l typ in + let tgt_arg_typ = infer_arg_typ (env_of exp) f l target_typ in + E_aux (E_app (f,[aux arg (src_arg_typ, tgt_arg_typ)]),(l,ann)) + | _ -> + (snd (make_bitvector_cast_fns cast_name cast_env quant_kids typ target_typ)) exp + in + aux exp (typ, target_typ) let rec extract_value_from_guard var (E_aux (e,_)) = match e with @@ -4082,13 +4155,14 @@ let add_bitvector_casts (Defs defs) = let pat,guard,body,ann = destruct_pexp pexp in let body = match pat, guard with | P_aux (P_lit (L_aux (L_num i,_)),_), _ -> - let src_typ = subst_src_typ (KBindings.singleton kid (nconstant i)) result_typ in + (* We used to just substitute kid, but fill_in_type also catches other kids defined by it *) + let src_typ = fill_in_type (Env.add_constraint (nc_eq (nvar kid) (nconstant i)) env) result_typ in make_bitvector_cast_exp "bitvector_cast_out" env quant_kids src_typ result_typ (make_bitvector_env_casts env quant_kids (kid,i) body) | P_aux (P_id var,_), Some guard -> (match extract_value_from_guard var guard with | Some i -> - let src_typ = subst_src_typ (KBindings.singleton kid (nconstant i)) result_typ in + let src_typ = fill_in_type (Env.add_constraint (nc_eq (nvar kid) (nconstant i)) env) result_typ in make_bitvector_cast_exp "bitvector_cast_out" env quant_kids src_typ result_typ (make_bitvector_env_casts env quant_kids (kid,i) body) | None -> body) @@ -4112,6 +4186,13 @@ let add_bitvector_casts (Defs defs) = (match destruct_atom_nexp (env_of y) (typ_of y) with | Some (Nexp_aux (Nexp_constant i,_)) -> [(kid,i)] | _ -> []) + | E_app (op,[x;y]) + when string_of_id op = "eq_int" -> + (match destruct_atom_nexp (env_of x) (typ_of x), destruct_atom_nexp (env_of y) (typ_of y) with + | Some (Nexp_aux (Nexp_var kid,_)), Some (Nexp_aux (Nexp_constant i,_)) + | Some (Nexp_aux (Nexp_constant i,_)), Some (Nexp_aux (Nexp_var kid,_)) + -> [(kid,i)] + | _ -> []) | E_app (op, [x;y]) when string_of_id op = "and_bool" -> extract x @ extract y | _ -> [] @@ -4126,10 +4207,16 @@ let add_bitvector_casts (Defs defs) = E_aux (E_if (e1,e2',e3), ann) | E_return e' -> E_aux (E_return (make_bitvector_cast_exp "bitvector_cast_out" top_env quant_kids (fill_in_type (env_of e') (typ_of e')) ret_typ e'),ann) - | E_assign (LEXP_aux (lexp,lexp_annot),e') -> - E_aux (E_assign (LEXP_aux (lexp,lexp_annot), - make_bitvector_cast_exp "bitvector_cast_out" top_env quant_kids (fill_in_type (env_of e') (typ_of e')) - (typ_of_annot lexp_annot) e'),ann) + | E_assign (LEXP_aux (_,lexp_annot) as lexp,e') -> begin + (* The type in the lexp_annot might come from e' rather than being the + type of the storage, so ask the type checker what it really is. *) + match infer_lexp (env_of_annot lexp_annot) (strip_lexp lexp) with + | LEXP_aux (_,lexp_annot') -> + E_aux (E_assign (lexp, + make_bitvector_cast_exp "bitvector_cast_out" top_env quant_kids (fill_in_type (env_of e') (typ_of e')) + (typ_of_annot lexp_annot') e'),ann) + | exception _ -> E_aux (e,ann) + end | E_id id -> begin let env = env_of_annot ann in match Env.lookup_id id env with @@ -4173,7 +4260,23 @@ let add_bitvector_casts (Defs defs) = | DEF_fundef (FD_aux (FD_function (r,t,e,fcls),fd_ann)) -> DEF_fundef (FD_aux (FD_function (r,t,e,List.map rewrite_funcl fcls),fd_ann)) | d -> d - in Defs (List.map rewrite_def defs) + in + specs_required := IdSet.empty; + let defs = List.map rewrite_def defs in + let l = Generated Unknown in + let Defs cast_specs,_ = + (* TODO: use default/relevant order *) + let kid = mk_kid "n" in + let bitsn = vector_typ (nvar kid) dec_ord bit_typ in + let ts = mk_typschm (mk_typquant [mk_qi_id K_int kid]) + (function_typ [bitsn] bitsn no_effect) in + let extfn _ = Some "zeroExtend" in + let mkfn name = + mk_val_spec (VS_val_spec (ts,name,extfn,false)) + in + let defs = List.map mkfn (IdSet.elements !specs_required) in + check Env.empty (Defs defs) + in Defs (cast_specs @ defs) end let replace_nexp_in_typ env typ orig new_nexp = @@ -4249,14 +4352,13 @@ let rewrite_toplevel_nexps (Defs defs) = let nexp_map, typ = rewrite_typ_in_spec env nexp_map typ in (nexp_map, typ::t)) typs (nexp_map,[]) in nexp_map, Typ_aux (Typ_tup typs,ann) - | _ -> - let typ' = Env.base_typ_of env typ_full in + | _ when is_number typ_full || is_bitvector_typ typ_full -> begin let nexp_opt = - match destruct_atom_nexp env typ' with + match destruct_atom_nexp env typ_full with | Some nexp -> Some nexp | None -> - if is_bitvector_typ typ' then - let (size,_,_) = vector_typ_args_of typ' in + if is_bitvector_typ typ_full then + let (size,_,_) = vector_typ_args_of typ_full in Some size else None in match nexp_opt with @@ -4272,10 +4374,27 @@ let rewrite_toplevel_nexps (Defs defs) = (kid, nexp)::nexp_map, kid in let new_nexp = nvar kid in - (* Try to avoid expanding the original type *) - let changed, typ = replace_nexp_in_typ env typ_full nexp new_nexp in - if changed then nexp_map, typ - else nexp_map, snd (replace_nexp_in_typ env typ' nexp new_nexp) + nexp_map, snd (replace_nexp_in_typ env typ_full nexp new_nexp) + end + | _ -> + let typ' = Env.base_typ_of env typ_full in + if Typ.compare typ_full typ' == 0 then + match t with + | Typ_app (f,args) -> + let in_arg nexp_map (A_aux (arg,l) as arg_full) = + match arg with + | A_typ typ -> + let nexp_map, typ' = rewrite_typ_in_spec env nexp_map typ in + nexp_map, A_aux (A_typ typ',l) + | A_bool _ | A_nexp _ | A_order _ -> nexp_map, arg_full + in + let nexp_map, args = + List.fold_right (fun arg (nexp_map,args) -> + let nexp_map, arg = in_arg nexp_map arg in + (nexp_map, arg::args)) args (nexp_map,[]) + in nexp_map, Typ_aux (Typ_app (f,args),ann) + | _ -> nexp_map, typ_full + else rewrite_typ_in_spec env nexp_map typ' in let rewrite_valspec (VS_aux (VS_val_spec (TypSchm_aux (TypSchm_ts (tqs,typ),ts_l),id,ext_opt,is_cast),ann)) = match tqs with @@ -4307,10 +4426,9 @@ let rewrite_toplevel_nexps (Defs defs) = | A_typ typ -> A_aux (A_typ (aux typ),l) | A_order _ -> ta_full | A_nexp nexp -> - (match find_nexp env nexp_map nexp with + match find_nexp env nexp_map nexp with | (kid,_) -> A_aux (A_nexp (nvar kid),l) - | exception Not_found -> ta_full) - | _ -> ta_full + | exception Not_found -> ta_full in aux typ in let rewrite_one_exp nexp_map (e,ann) = @@ -4344,7 +4462,11 @@ let rewrite_toplevel_nexps (Defs defs) = | DEF_spec vs -> (match rewrite_valspec vs with | None -> spec_map, def | Some (id, nexp_map, vs) -> Bindings.add id nexp_map spec_map, DEF_spec vs) - | DEF_fundef (FD_aux (FD_function (recopt,tann,eff,funcls),ann)) -> + | DEF_fundef (FD_aux (FD_function (recopt,_,eff,funcls),ann)) -> + (* Type annotations on function definitions will have been turned into + valspecs by type checking, so it should be safe to drop them rather + than updating them. *) + let tann = Typ_annot_opt_aux (Typ_annot_opt_none,Generated Unknown) in spec_map, DEF_fundef (FD_aux (FD_function (recopt,tann,eff,List.map (rewrite_funcl spec_map) funcls),ann)) | _ -> spec_map, def diff --git a/src/pretty_print_lem.ml b/src/pretty_print_lem.ml index 8cef3529..169bd824 100644 --- a/src/pretty_print_lem.ml +++ b/src/pretty_print_lem.ml @@ -359,14 +359,14 @@ let replace_typ_size ctxt env (Typ_aux (t,a)) = let mk_typ nexp = Some (Typ_aux (Typ_app (id, [A_aux (A_nexp nexp,Parse_ast.Unknown);ord;typ']),a)) in - let is_equal nexp = - prove __POS__ env (NC_aux (NC_equal (size,nexp),Parse_ast.Unknown)) - in match List.find is_equal (NexpSet.elements ctxt.bound_nexps) with - | nexp -> mk_typ nexp - | exception Not_found -> - match Type_check.solve env size with - | Some n -> mk_typ (nconstant n) - | None -> None + match Type_check.solve env size with + | Some n -> mk_typ (nconstant n) + | None -> + let is_equal nexp = + prove __POS__ env (NC_aux (NC_equal (size,nexp),Parse_ast.Unknown)) + in match List.find is_equal (NexpSet.elements ctxt.bound_nexps) with + | nexp -> mk_typ nexp + | exception Not_found -> None end | _ -> None diff --git a/src/rewrites.ml b/src/rewrites.ml index f8146a72..8fa90643 100644 --- a/src/rewrites.ml +++ b/src/rewrites.ml @@ -5017,6 +5017,10 @@ let if_mono f defs = | [], false -> defs | _, _ -> f defs +(* Also turn mwords stages on when we're just trying out mono *) +let if_mwords f defs = + if !Pretty_print_lem.opt_mwords then f defs else if_mono f defs + let rewrite_defs_lem = [ ("realise_mappings", rewrite_defs_realise_mappings); ("remove_mapping_valspecs", remove_mapping_valspecs); @@ -5027,10 +5031,10 @@ let rewrite_defs_lem = [ ("recheck_defs", if_mono recheck_defs); ("rewrite_toplevel_nexps", if_mono rewrite_toplevel_nexps); ("monomorphise", if_mono monomorphise); - ("recheck_defs", if_mono recheck_defs); - ("add_bitvector_casts", if_mono Monomorphise.add_bitvector_casts); + ("recheck_defs", if_mwords recheck_defs); + ("add_bitvector_casts", if_mwords Monomorphise.add_bitvector_casts); ("rewrite_atoms_to_singletons", if_mono Monomorphise.rewrite_atoms_to_singletons); - ("recheck_defs", if_mono recheck_defs); + ("recheck_defs", if_mwords recheck_defs); ("rewrite_undefined", rewrite_undefined_if_gen false); ("rewrite_defs_vector_string_pats_to_bit_list", rewrite_defs_vector_string_pats_to_bit_list); ("remove_not_pats", rewrite_defs_not_pats); diff --git a/src/type_check.mli b/src/type_check.mli index 2663c1c7..82e9ebc1 100644 --- a/src/type_check.mli +++ b/src/type_check.mli @@ -206,6 +206,8 @@ module Env : sig val pattern_completeness_ctx : t -> Pattern_completeness.ctx val builtin_typs : typquant Bindings.t + + val get_union_id : id -> t -> typquant * typ end (** Push all the type variables and constraints from a typquant into @@ -299,6 +301,8 @@ val infer_exp : Env.t -> unit exp -> tannot exp val infer_pat : Env.t -> unit pat -> tannot pat * Env.t * unit exp list +val infer_lexp : Env.t -> unit lexp -> tannot lexp + val check_case : Env.t -> typ -> unit pexp -> typ -> tannot pexp val check_fundef : Env.t -> 'a fundef -> tannot def list * Env.t diff --git a/test/mono/castrequnion.sail b/test/mono/castrequnion.sail new file mode 100644 index 00000000..4729fb11 --- /dev/null +++ b/test/mono/castrequnion.sail @@ -0,0 +1,58 @@ +default Order dec +$include <prelude.sail> + +val bitvector_cast_in = "zeroExtend" : forall 'n. bits('n) -> bits('n) effect pure +val bitvector_cast_out = "zeroExtend" : forall 'n. bits('n) -> bits('n) effect pure + +val foo : forall 'm 'n, 'm in {8,16} & 'n in {16,32,64}. bits('m) -> option(bits('n)) effect pure + +function foo(x) = + let y : bits(16) = sail_zero_extend(x,16) in + match 'n { + 16 => None(), + 32 => Some(y@y), + 64 => let z = y@y@y@y in let dfsf = 4 in Some(z) + } + +union Result ('a : Type) = { + Value : ('a, int), + Complaint : string +} + +/* Getting ahead of myself: the 2*'n isn't supported yet, although shouldn't it end up in the form below? +*/ +val bar : forall 'n, 'n in {8,16,32}. bits('n) -> Result(bits(2*'n)) + +function bar(x) = + match 'n { + 8 => Complaint("No bytes"), + 16 => Value(x@x, unsigned(x)), + 32 => Value(sail_sign_extend(x,64), unsigned(x)) + } +/* +val bar : forall 'n 'm, 'n in {8,16,32} & 'm == 2*'n. bits('n) -> Result(bits('m)) + +function bar(x) = + match 'n { + 8 => Complaint("No bytes"), + 16 => Value(x@x, unsigned(x)), + 32 => Value(sail_sign_extend(x,64), unsigned(x)) + } +*/ + +val cmp : forall 'n. (option(bits('n)), option(bits('n))) -> bool + +function cmp (None(),None()) = true +and cmp (None(),Some(_)) = false +and cmp (Some(_),None()) = false +and cmp (Some(x),Some(y)) = x == y + +overload operator == = {cmp} + +val run : unit -> unit effect {escape} + +function run() = { + assert((foo(0x12) : option(bits(16))) == None()); + assert((foo(0x12) : option(bits(32))) == Some(0x00120012)); + assert((foo(0x12) : option(bits(64))) == Some(0x0012001200120012)); +} diff --git a/test/mono/flow_extend.sail b/test/mono/flow_extend.sail new file mode 100644 index 00000000..7e118993 --- /dev/null +++ b/test/mono/flow_extend.sail @@ -0,0 +1,16 @@ +default Order dec +$include <prelude.sail> + +val bitvector_cast_in = "zeroExtend" : forall 'n. bits('n) -> bits('n) effect pure +val bitvector_cast_out = "zeroExtend" : forall 'n. bits('n) -> bits('n) effect pure + +val byte_extend : forall 'n, 'n >= 8. (bits(8), int('n)) -> bits('n) + +function byte_extend (v, n) = if (n == 8) then v else sail_zero_extend(v, n) + +val run : unit -> unit effect {escape} + +function run() = { + assert(byte_extend(0x12,8) == 0x12); + assert(byte_extend(0x12,16) == 0x0012); +} diff --git a/test/mono/pass/castrequnion b/test/mono/pass/castrequnion new file mode 100644 index 00000000..9b2a2f38 --- /dev/null +++ b/test/mono/pass/castrequnion @@ -0,0 +1 @@ +castrequnion.sail -auto_mono diff --git a/test/mono/pass/flow_extend b/test/mono/pass/flow_extend new file mode 100644 index 00000000..fea386e5 --- /dev/null +++ b/test/mono/pass/flow_extend @@ -0,0 +1 @@ +flow_extend.sail -auto_mono |
