summaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
Diffstat (limited to 'src')
-rw-r--r--src/ast.ml2
-rw-r--r--src/ast_util.ml2
-rw-r--r--src/initial_check.ml18
-rw-r--r--src/initial_check.mli4
-rw-r--r--src/monomorphise.ml55
-rw-r--r--src/pretty_print_lem.ml34
-rw-r--r--src/pretty_print_lem_ast.ml18
-rw-r--r--src/pretty_print_ocaml.ml21
-rw-r--r--src/process_file.ml34
-rw-r--r--src/process_file.mli5
-rw-r--r--src/rewriter.ml578
-rw-r--r--src/rewriter.mli4
-rw-r--r--src/sail.ml2
-rw-r--r--src/spec_analysis.ml5
-rw-r--r--src/type_check.ml94
-rw-r--r--src/type_check.mli4
16 files changed, 542 insertions, 338 deletions
diff --git a/src/ast.ml b/src/ast.ml
index 1b3cbfd3..6a74d5b2 100644
--- a/src/ast.ml
+++ b/src/ast.ml
@@ -175,6 +175,8 @@ n_constraint_aux = (* constraint over kind $_$ *)
| NC_nat_set_bounded of kid * (int) list
| NC_or of n_constraint * n_constraint
| NC_and of n_constraint * n_constraint
+ | NC_true
+ | NC_false
and
n_constraint =
diff --git a/src/ast_util.ml b/src/ast_util.ml
index 67eedf72..2109175f 100644
--- a/src/ast_util.ml
+++ b/src/ast_util.ml
@@ -232,6 +232,8 @@ and string_of_n_constraint = function
"(" ^ string_of_n_constraint nc1 ^ " & " ^ string_of_n_constraint nc2 ^ ")"
| NC_aux (NC_nat_set_bounded (kid, ns), _) ->
string_of_kid kid ^ " IN {" ^ string_of_list ", " string_of_int ns ^ "}"
+ | NC_aux (NC_true, _) -> "true"
+ | NC_aux (NC_false, _) -> "false"
let string_of_quant_item_aux = function
| QI_id (KOpt_aux (KOpt_none kid, _)) -> string_of_kid kid
diff --git a/src/initial_check.ml b/src/initial_check.ml
index 0e68ad81..b831e288 100644
--- a/src/initial_check.ml
+++ b/src/initial_check.ml
@@ -112,7 +112,7 @@ let typ_error l msg opt_id opt_var opt_kind =
| None,Some(v),None -> ": " ^ (var_to_string v)
| None,None,Some(kind) -> " " ^ (kind_to_string kind)
| _ -> "")))
-
+
let to_ast_id (Parse_ast.Id_aux(id,l)) =
Id_aux( (match id with
| Parse_ast.Id(x) -> Id(x)
@@ -144,16 +144,8 @@ let rec to_ast_typ (k_env : kind Envmap.t) (def_ord : order) (t: Parse_ast.atyp)
match t with
| Parse_ast.ATyp_aux(t,l) ->
Typ_aux( (match t with
- | Parse_ast.ATyp_id(id) ->
- let id = to_ast_id id in
- let mk = Envmap.apply k_env (id_to_string id) in
- (match mk with
- | Some(k) -> (match k.k with
- | K_Typ -> Typ_id id
- | K_infer -> k.k <- K_Typ; Typ_id id
- | _ -> typ_error l "Required an identifier with kind Type, encountered " (Some id) None (Some k))
- | None -> typ_error l "Encountered an unbound type identifier" (Some id) None None)
- | Parse_ast.ATyp_var(v) ->
+ | Parse_ast.ATyp_id(id) -> Typ_id (to_ast_id id)
+ | Parse_ast.ATyp_var(v) ->
let v = to_ast_var v in
let mk = Envmap.apply k_env (var_to_string v) in
(match mk with
@@ -1010,6 +1002,6 @@ let initial_kind_env =
("implicit", {k = K_Lam( [{k = K_Nat}], {k=K_Typ})} );
]
-let process_ast defs =
- let (ast, _, _) = to_ast Nameset.empty initial_kind_env (Ast.Ord_aux(Ast.Ord_inc,Parse_ast.Unknown)) defs in
+let process_ast order defs =
+ let (ast, _, _) = to_ast Nameset.empty initial_kind_env order defs in
ast
diff --git a/src/initial_check.mli b/src/initial_check.mli
index 063a0131..ed4eb0bf 100644
--- a/src/initial_check.mli
+++ b/src/initial_check.mli
@@ -42,7 +42,5 @@
open Ast
-val process_ast : Parse_ast.defs -> unit defs
+val process_ast : order -> Parse_ast.defs -> unit defs
-
-
diff --git a/src/monomorphise.ml b/src/monomorphise.ml
index 7bfc3a3d..63be60b2 100644
--- a/src/monomorphise.ml
+++ b/src/monomorphise.ml
@@ -77,9 +77,10 @@ let make_vectors sz =
-(* Based on current type checker's behaviour *)
let pat_id_is_variable env id =
match Env.lookup_id id env with
+ (* Unbound is returned for both variables and constructors which take
+ arguments, but the latter only don't appear in a P_id *)
| Unbound
(* Shadowing of immutable locals is allowed; mutable locals and registers
are rejected by the type checker, so don't matter *)
@@ -90,21 +91,24 @@ let pat_id_is_variable env id =
| Union _
-> false
-
let rec is_value (E_aux (e,(l,annot))) =
+ let is_constructor id =
+ match annot with
+ | None ->
+ (Reporting_basic.print_err false true l "Monomorphisation"
+ ("Missing type information for identifier " ^ string_of_id id);
+ false) (* Be conservative if we have no info *)
+ | Some (env,_,_) ->
+ Env.is_union_constructor id env ||
+ (match Env.lookup_id id env with
+ | Enum _ | Union _ -> true
+ | Unbound | Local _ | Register _ -> false)
+ in
match e with
- | E_id id ->
- (match annot with
- | None ->
- (Reporting_basic.print_err false true l "Monomorphisation"
- ("Missing type information for identifier " ^ string_of_id id);
- false) (* Be conservative if we have no info *)
- | Some (env,_,_) ->
- match Env.lookup_id id env with
- | Enum _ | Union _ -> true
- | Unbound | Local _ | Register _ -> false)
+ | E_id id -> is_constructor id
| E_lit _ -> true
| E_tuple es -> List.for_all is_value es
+ | E_app (id,es) -> is_constructor id && List.for_all is_value es
(* TODO: more? *)
| _ -> false
@@ -294,6 +298,7 @@ let nexp_subst_fns substs refinements =
| E_lit _
| E_comment _ -> re e
| E_sizeof ne -> re (E_sizeof ne) (* TODO: does this need done? does it appear in type checked code? *)
+ | E_constraint _ -> re e (* TODO: actual substitution if necessary *)
| E_internal_exp (l,annot) -> re (E_internal_exp (l, (*s_tannot*) annot))
| E_sizeof_internal (l,annot) -> re (E_sizeof_internal (l, (*s_tannot*) annot))
| E_internal_exp_user ((l1,annot1),(l2,annot2)) ->
@@ -308,12 +313,15 @@ let nexp_subst_fns substs refinements =
| _ -> E_aux (E_tuple es',(l,None))
in
let id' =
- match Env.lookup_id id (fst (env_typ_expected l annot)) with
- | Union (qs,Typ_aux (Typ_fn(inty,outty,_),_)) ->
- (match refine_constructor refinements id substs arg inty with
- | None -> id
- | Some id' -> id')
- | _ -> id
+ let env,_ = env_typ_expected l annot in
+ if Env.is_union_constructor id env then
+ let (qs,ty) = Env.get_val_spec id env in
+ match ty with (Typ_aux (Typ_fn(inty,outty,_),_)) ->
+ (match refine_constructor refinements id substs arg inty with
+ | None -> id
+ | Some id' -> id')
+ | _ -> id
+ else id
in re (E_app (id',es'))
| E_app_infix (e1,id,e2) -> re (E_app_infix (s_exp e1,id,s_exp e2))
| E_tuple es -> re (E_tuple (List.map s_exp es))
@@ -395,6 +403,7 @@ let bindings_from_pat p =
-> List.concat (List.map aux_pat ps)
| P_record (fps,_) -> List.concat (List.map aux_fpat fps)
| P_vector_indexed ips -> List.concat (List.map (fun (_,p) -> aux_pat p) ips)
+ | P_cons (p1,p2) -> aux_pat p1 @ aux_pat p2
and aux_fpat (FP_aux (FP_Fpat (_,p), _)) = aux_pat p
in aux_pat p
@@ -577,6 +586,7 @@ let split_defs splits defs =
| E_sizeof_internal _
| E_internal_exp_user _
| E_comment _
+ | E_constraint _
-> exp
| E_cast (t,e') -> re (E_cast (t, const_prop_exp substs e'))
| E_app (id,es) ->
@@ -841,6 +851,10 @@ let split_defs splits defs =
relist spl (fun ps -> P_tup ps) ps
| P_list ps ->
relist spl (fun ps -> P_list ps) ps
+ | P_cons (p1,p2) ->
+ match re (fun p' -> P_cons (p',p2)) p1 with
+ | Some r -> Some r
+ | None -> re (fun p' -> P_cons (p1,p')) p2
in spl p
in
@@ -861,8 +875,8 @@ let split_defs splits defs =
let (_,variants) = List.find (fun (id',_) -> Id.compare id id' = 0) refinements in
let env,_ = env_typ_expected l tannot in
let constr_out_typ =
- match Env.lookup_id id env with
- | Union (qs,Typ_aux (Typ_fn(_,outt,_),_)) -> outt
+ match Env.get_val_spec id env with
+ | (qs,Typ_aux (Typ_fn(_,outt,_),_)) -> outt
| _ -> raise (Reporting_basic.err_general l
("Constructor " ^ string_of_id id ^ " is not a construtor!"))
in
@@ -909,6 +923,7 @@ let split_defs splits defs =
| E_sizeof_internal _
| E_internal_exp_user _
| E_comment _
+ | E_constraint _
-> ea
| E_cast (t,e') -> re (E_cast (t, map_exp e'))
| E_app (id,es) -> re (E_app (id,List.map map_exp es))
diff --git a/src/pretty_print_lem.ml b/src/pretty_print_lem.ml
index 95ddc580..7adccfdf 100644
--- a/src/pretty_print_lem.ml
+++ b/src/pretty_print_lem.ml
@@ -239,7 +239,8 @@ let doc_lit_lem in_pat (L_aux(lit,l)) a =
| Typ_id (Id_aux (Id "string", _)) -> "\"\""
| _ -> "(failwith \"undefined value of unsupported type\")")
| _ -> "(failwith \"undefined value of unsupported type\")")
- | L_string s -> "\"" ^ s ^ "\"")
+ | L_string s -> "\"" ^ s ^ "\""
+ | L_real s -> s (* TODO What's the Lem syntax for reals? *))
(* typ_doc is the doc for the type being quantified *)
@@ -257,16 +258,10 @@ let is_ctor env id = match Env.lookup_id id env with
*)
let rec doc_pat_lem regtypes apat_needed (P_aux (p,(l,annot)) as pa) = match p with
| P_app(id, ((_ :: _) as pats)) ->
- (match annot with
- | Some (env, _, _) when (is_ctor env id) ->
- let ppp = doc_unop (doc_id_lem_ctor id)
- (parens (separate_map comma (doc_pat_lem regtypes true) pats)) in
- if apat_needed then parens ppp else ppp
- | _ -> empty)
- | P_app(id,[]) ->
- (match annot with
- | Some (env, _, _) when (is_ctor env id) -> doc_id_lem_ctor id
- | _ -> empty)
+ let ppp = doc_unop (doc_id_lem_ctor id)
+ (parens (separate_map comma (doc_pat_lem regtypes true) pats)) in
+ if apat_needed then parens ppp else ppp
+ | P_app(id,[]) -> doc_id_lem_ctor id
| P_lit lit -> doc_lit_lem true lit annot
| P_wild -> underscore
| P_id id ->
@@ -281,15 +276,14 @@ let rec doc_pat_lem regtypes apat_needed (P_aux (p,(l,annot)) as pa) = match p w
[string "Vector";brackets (separate_map semi (doc_pat_lem regtypes true) pats);underscore;underscore] in
if apat_needed then parens ppp else ppp
| P_vector_concat pats ->
- let ppp =
- (separate space)
- [string "Vector";parens (separate_map (string "::") (doc_pat_lem regtypes true) pats);underscore;underscore] in
- if apat_needed then parens ppp else ppp
+ raise (Reporting_basic.err_unreachable l
+ "vector concatenation patterns should have been removed before pretty-printing")
| P_tup pats ->
(match pats with
| [p] -> doc_pat_lem regtypes apat_needed p
| _ -> parens (separate_map comma_sp (doc_pat_lem regtypes false) pats))
- | P_list pats -> brackets (separate_map semi (doc_pat_lem regtypes false) pats) (*Never seen but easy in lem*)
+ | P_list pats -> brackets (separate_map semi (doc_pat_lem regtypes false) pats) (*Never seen but easy in lem*)
+ | P_cons (p,p') -> doc_op (string "::") (doc_pat_lem regtypes true p) (doc_pat_lem regtypes true p')
| P_record (_,_) | P_vector_indexed _ -> empty (* TODO *)
let rec contains_bitvector_typ (Typ_aux (t,_) as typ) = match t with
@@ -926,7 +920,7 @@ let doc_exp_lem, doc_let_lem =
| E_return _ ->
raise (Reporting_basic.err_todo l
"pretty-printing early return statements to Lem not yet supported")
- | E_comment _ | E_comment_struc _ -> empty
+ | E_constraint _ | E_comment _ | E_comment_struc _ -> empty
| E_internal_cast _ | E_internal_exp _ | E_sizeof_internal _ | E_internal_exp_user _ ->
raise (Reporting_basic.err_unreachable l
"unsupported internal expression encountered while pretty-printing")
@@ -944,9 +938,13 @@ let doc_exp_lem, doc_let_lem =
else doc_id_lem id in
group (doc_op equals fname (top_exp regtypes true e))
- and doc_case regtypes (Pat_aux(Pat_exp(pat,e),_)) =
+ and doc_case regtypes = function
+ | Pat_aux(Pat_exp(pat,e),_) ->
group (prefix 3 1 (separate space [pipe; doc_pat_lem regtypes false pat;arrow])
(group (top_exp regtypes false e)))
+ | Pat_aux(Pat_when(_,_,_),(l,_)) ->
+ raise (Reporting_basic.err_unreachable l
+ "guarded pattern expression should have been rewritten before pretty-printing")
and doc_lexp_deref_lem regtypes ((LEXP_aux(lexp,(l,annot))) as le) = match lexp with
| LEXP_field (le,id) ->
diff --git a/src/pretty_print_lem_ast.ml b/src/pretty_print_lem_ast.ml
index 6809826a..0875aee7 100644
--- a/src/pretty_print_lem_ast.ml
+++ b/src/pretty_print_lem_ast.ml
@@ -219,12 +219,15 @@ let pp_lem_ord ppf o = base ppf (pp_format_ord_lem o)
let pp_lem_effects ppf e = base ppf (pp_format_effects_lem e)
let pp_lem_beffect ppf be = base ppf (pp_format_base_effect_lem be)
-let pp_format_nexp_constraint_lem (NC_aux(nc,l)) =
+let rec pp_format_nexp_constraint_lem (NC_aux(nc,l)) =
"(NC_aux " ^
(match nc with
| NC_fixed(n1,n2) -> "(NC_fixed " ^ pp_format_nexp_lem n1 ^ " " ^ pp_format_nexp_lem n2 ^ ")"
| NC_bounded_ge(n1,n2) -> "(NC_bounded_ge " ^ pp_format_nexp_lem n1 ^ " " ^ pp_format_nexp_lem n2 ^ ")"
| NC_bounded_le(n1,n2) -> "(NC_bounded_le " ^ pp_format_nexp_lem n1 ^ " " ^ pp_format_nexp_lem n2 ^ ")"
+ | NC_not_equal(n1,n2) -> "(NC_not_equal " ^ pp_format_nexp_lem n1 ^ " " ^ pp_format_nexp_lem n2 ^ ")"
+ | NC_or(nc1,nc2) -> "(NC_or " ^ pp_format_nexp_constraint_lem nc1 ^ " " ^ pp_format_nexp_constraint_lem nc2 ^ ")"
+ | NC_and(nc1,nc2) -> "(NC_and " ^ pp_format_nexp_constraint_lem nc1 ^ " " ^ pp_format_nexp_constraint_lem nc2 ^ ")"
| NC_nat_set_bounded(id,bounds) -> "(NC_nat_set_bounded " ^
pp_format_var_lem id ^
" [" ^
@@ -278,7 +281,8 @@ let pp_format_lit_lem (L_aux(lit,l)) =
| L_hex(n) -> "(L_hex \"" ^ n ^ "\")"
| L_bin(n) -> "(L_bin \"" ^ n ^ "\")"
| L_undef -> "L_undef"
- | L_string(s) -> "(L_string \"" ^ s ^ "\")") ^ " " ^
+ | L_string(s) -> "(L_string \"" ^ s ^ "\")"
+ | L_real(s) -> "(L_real \"" ^ s ^ "\")") ^ " " ^
(pp_format_l_lem l) ^ ")"
let pp_lem_lit ppf l = base ppf (pp_format_lit_lem l)
@@ -336,7 +340,8 @@ let rec pp_format_pat_lem (P_aux(p,(l,annot))) =
"(P_vector_indexed [" ^ list_format "; " (fun (i,p) -> Printf.sprintf "(%d, %s)" i (pp_format_pat_lem p)) ipats ^ "])"
| P_vector_concat(pats) -> "(P_vector_concat [" ^ list_format "; " pp_format_pat_lem pats ^ "])"
| P_tup(pats) -> "(P_tup [" ^ (list_format "; " pp_format_pat_lem pats) ^ "])"
- | P_list(pats) -> "(P_list [" ^ (list_format "; " pp_format_pat_lem pats) ^ "])") ^
+ | P_list(pats) -> "(P_list [" ^ (list_format "; " pp_format_pat_lem pats) ^ "])"
+ | P_cons(pat,pat') -> "(P_cons " ^ pp_format_pat_lem pat ^ " " ^ pp_format_pat_lem pat' ^ ")") ^
" (" ^ pp_format_l_lem l ^ ", " ^ pp_format_annot annot ^ "))"
let pp_lem_pat ppf p = base ppf (pp_format_pat_lem p)
@@ -426,6 +431,8 @@ and pp_lem_exp ppf (E_aux(e,(l,annot))) =
pp_lem_lexp lexp pp_lem_exp exp pp_lem_l l pp_annot annot
| E_sizeof nexp ->
fprintf ppf "@[<0>(E_aux (E_sizeof %a) (%a, %a))@]" pp_lem_nexp nexp pp_lem_l l pp_annot annot
+ | E_constraint nc ->
+ fprintf ppf "@[<0>(E_aux (E_constraint %a) (%a, %a))@]" pp_lem_nexp_constraint nc pp_lem_l l pp_annot annot
| E_exit exp ->
fprintf ppf "@[<0>(E_aux (E_exit %a) (%a, %a))@]" pp_lem_exp exp pp_lem_l l pp_annot annot
| E_return exp ->
@@ -476,8 +483,11 @@ and pp_lem_fexp ppf (FE_aux(FE_Fexp(id,exp),(l,annot))) =
fprintf ppf "@[<1>(FE_aux (FE_Fexp %a %a) (%a, %a))@]" pp_lem_id id pp_lem_exp exp pp_lem_l l pp_annot annot
and pp_semi_lem_fexp ppf fexp = fprintf ppf "@[<1>%a %a@]" pp_lem_fexp fexp kwd ";"
-and pp_lem_case ppf (Pat_aux(Pat_exp(pat,exp),(l,annot))) =
+and pp_lem_case ppf = function
+| Pat_aux(Pat_exp(pat,exp),(l,annot)) ->
fprintf ppf "@[<1>(Pat_aux (Pat_exp %a@ %a) (%a, %a))@]" pp_lem_pat pat pp_lem_exp exp pp_lem_l l pp_annot annot
+| Pat_aux(Pat_when(pat,guard,exp),(l,annot)) ->
+ fprintf ppf "@[<1>(Pat_aux (Pat_exp %a@ %a %a) (%a, %a))@]" pp_lem_pat pat pp_lem_exp guard pp_lem_exp exp pp_lem_l l pp_annot annot
and pp_semi_lem_case ppf case = fprintf ppf "@[<1>%a %a@]" pp_lem_case case kwd ";"
and pp_lem_lexp ppf (LEXP_aux(lexp,(l,annot))) =
diff --git a/src/pretty_print_ocaml.ml b/src/pretty_print_ocaml.ml
index 652b0ce9..4f2c3ab0 100644
--- a/src/pretty_print_ocaml.ml
+++ b/src/pretty_print_ocaml.ml
@@ -140,7 +140,8 @@ let doc_lit_ocaml in_pat (L_aux(l,_)) =
| L_hex n -> "(num_to_vec " ^ ("0x" ^ n) ^ ")" (*shouldn't happen*)
| L_bin n -> "(num_to_vec " ^ ("0b" ^ n) ^ ")" (*shouldn't happen*)
| L_undef -> "(failwith \"undef literal not supported\")" (* XXX Undef vectors get handled with to_vec_undef. We could support undef bit but would need to check type. For the moment treat as runtime error. *)
- | L_string s -> "\"" ^ s ^ "\"")
+ | L_string s -> "\"" ^ s ^ "\""
+ | L_real s -> s)
(* typ_doc is the doc for the type being quantified *)
let doc_typquant_ocaml (TypQ_aux(tq,_)) typ_doc = typ_doc
@@ -170,7 +171,7 @@ let doc_pat_ocaml =
| P_wild -> underscore
| P_id id -> doc_id_ocaml id
| P_as(p,id) -> parens (separate space [pat p; string "as"; doc_id_ocaml id])
- | P_typ(typ,p) -> doc_op colon (pat p) (doc_typ_ocaml typ)
+ | P_typ(typ,p) -> parens (doc_op colon (pat p) (doc_typ_ocaml typ))
| P_app(id,[]) ->
(match annot with
| Some (env, typ, eff) ->
@@ -196,6 +197,7 @@ let doc_pat_ocaml =
| None -> non_bit_print())
| P_tup pats -> parens (separate_map comma_sp pat pats)
| P_list pats -> brackets (separate_map semi pat pats) (*Never seen but easy in ocaml*)
+ | P_cons (p,p') -> doc_op (string "::") (pat p) (pat p')
| P_record _ -> raise (Reporting_basic.err_unreachable l "unhandled record pattern")
| P_vector_indexed _ -> raise (Reporting_basic.err_unreachable l "unhandled vector_indexed pattern")
| P_vector_concat _ -> raise (Reporting_basic.err_unreachable l "unhandled vector_concat pattern")
@@ -467,6 +469,13 @@ let doc_exp_ocaml, doc_let_ocaml =
separate space [string "return"; exp e1;]
| E_assert (e1, e2) ->
(string "assert") ^^ parens ((string "to_bool") ^^ space ^^ exp e1) (* XXX drops e2 *)
+ | E_sizeof _ -> raise (Reporting_basic.err_unreachable l
+ "E_sizeof should have been rewritten before pretty-printing")
+ | E_constraint _ -> empty
+ | E_sizeof_internal _ | E_internal_exp_user (_, _) | E_internal_cast (_, _)
+ | E_internal_exp _ -> raise (Reporting_basic.err_unreachable l
+ "internal expression should have been rewritten before pretty-printing")
+ | E_comment _ | E_comment_struc _ -> empty (* TODO Should we output comments? *)
and let_exp (LB_aux(lb,_)) = match lb with
| LB_val_explicit(ts,pat,e) ->
prefix 2 1
@@ -479,8 +488,14 @@ let doc_exp_ocaml, doc_let_ocaml =
and doc_fexp (FE_aux(FE_Fexp(id,e),_)) = doc_op equals (doc_id_ocaml id) (top_exp false e)
- and doc_case (Pat_aux(Pat_exp(pat,e),_)) =
+ and doc_case = function
+ | (Pat_aux(Pat_exp(pat,e),_)) ->
doc_op arrow (separate space [pipe; doc_pat_ocaml pat]) (group (top_exp false e))
+ | (Pat_aux(Pat_when(pat,guard,e),_)) ->
+ doc_op arrow
+ (separate space [pipe;
+ doc_op (string "when") (doc_pat_ocaml pat) (top_exp false guard)])
+ (group (top_exp false e))
and doc_lexp_ocaml top_call ((LEXP_aux(lexp,(l,annot))) as le) =
let exp = top_exp false in
diff --git a/src/process_file.ml b/src/process_file.ml
index 0601bfab..c9a4f178 100644
--- a/src/process_file.ml
+++ b/src/process_file.ml
@@ -45,16 +45,17 @@ type out_type =
| Lem_out of string option
| Ocaml_out of string option
-let get_lexbuf fn =
- let lexbuf = Lexing.from_channel (open_in fn) in
- lexbuf.Lexing.lex_curr_p <- { Lexing.pos_fname = fn;
- Lexing.pos_lnum = 1;
- Lexing.pos_bol = 0;
- Lexing.pos_cnum = 0; };
- lexbuf
+let get_lexbuf f =
+ let in_chan = open_in f in
+ let lexbuf = Lexing.from_channel in_chan in
+ lexbuf.Lexing.lex_curr_p <- { Lexing.pos_fname = f;
+ Lexing.pos_lnum = 1;
+ Lexing.pos_bol = 0;
+ Lexing.pos_cnum = 0; };
+ lexbuf, in_chan
let parse_file (f : string) : Parse_ast.defs =
- let scanbuf = get_lexbuf f in
+ let scanbuf, in_chan = get_lexbuf f in
let type_names =
try
Pre_parser.file Pre_lexer.token scanbuf
@@ -67,25 +68,26 @@ let parse_file (f : string) : Parse_ast.defs =
| Lexer.LexError(s,p) ->
raise (Reporting_basic.Fatal_error (Reporting_basic.Err_lex (p, s))) in
let () = Lexer.custom_type_names := !Lexer.custom_type_names @ type_names in
- let lexbuf = get_lexbuf f in
+ close_in in_chan;
+ let lexbuf, in_chan = get_lexbuf f in
try
- Parser.file Lexer.token lexbuf
+ let ast = Parser.file Lexer.token lexbuf in
+ close_in in_chan; ast
with
| Parsing.Parse_error ->
let pos = Lexing.lexeme_start_p lexbuf in
- raise (Reporting_basic.Fatal_error (Reporting_basic.Err_syntax (pos, "main")))
+ raise (Reporting_basic.Fatal_error (Reporting_basic.Err_syntax (pos, "main")))
| Parse_ast.Parse_error_locn(l,m) ->
raise (Reporting_basic.Fatal_error (Reporting_basic.Err_syntax_locn (l, m)))
| Lexer.LexError(s,p) ->
raise (Reporting_basic.Fatal_error (Reporting_basic.Err_lex (p, s)))
+let convert_ast (order : Ast.order) (defs : Parse_ast.defs) : unit Ast.defs = Initial_check.process_ast order defs
-(*Should add a flag to say whether we want to consider Oinc or Odec the default order *)
-let convert_ast (defs : Parse_ast.defs) : unit Ast.defs = Initial_check.process_ast defs
+let load_file_no_check order f = convert_ast order (parse_file f)
-let load_file env f =
- let ast = parse_file f in
- let ast = convert_ast ast in
+let load_file order env f =
+ let ast = convert_ast order (parse_file f) in
Type_check.check env ast
let opt_new_typecheck = ref false
diff --git a/src/process_file.mli b/src/process_file.mli
index b15523bb..9907b743 100644
--- a/src/process_file.mli
+++ b/src/process_file.mli
@@ -41,14 +41,15 @@
(**************************************************************************)
val parse_file : string -> Parse_ast.defs
-val convert_ast : Parse_ast.defs -> unit Ast.defs
+val convert_ast : Ast.order -> Parse_ast.defs -> unit Ast.defs
val check_ast: unit Ast.defs -> Type_check.tannot Ast.defs * Type_check.Env.t
val monomorphise_ast : ((string * int) * string) list -> Type_check.tannot Ast.defs -> Type_check.tannot Ast.defs * Type_check.Env.t
val rewrite_ast: Type_check.tannot Ast.defs -> Type_check.tannot Ast.defs
val rewrite_ast_lem : Type_check.tannot Ast.defs -> Type_check.tannot Ast.defs
val rewrite_ast_ocaml : Type_check.tannot Ast.defs -> Type_check.tannot Ast.defs
-val load_file : Type_check.Env.t -> string -> Type_check.tannot Ast.defs * Type_check.Env.t
+val load_file_no_check : Ast.order -> string -> unit Ast.defs
+val load_file : Ast.order -> Type_check.Env.t -> string -> Type_check.tannot Ast.defs * Type_check.Env.t
val opt_new_typecheck : bool ref
val opt_just_check : bool ref
diff --git a/src/rewriter.ml b/src/rewriter.ml
index 166c31f0..0cf25103 100644
--- a/src/rewriter.ml
+++ b/src/rewriter.ml
@@ -135,7 +135,7 @@ let fix_eff_exp (E_aux (e,((l,_) as annot))) = match snd annot with
| E_assign (lexp,e) -> union_effects (effect_of_lexp lexp) (effect_of e)
| E_exit e -> effect_of e
| E_return e -> effect_of e
- | E_sizeof _ | E_sizeof_internal _ -> no_effect
+ | E_sizeof _ | E_sizeof_internal _ | E_constraint _ -> no_effect
| E_assert (c,m) -> no_effect
| E_comment _ | E_comment_struc _ -> no_effect
| E_internal_cast (_,e) -> effect_of e
@@ -157,6 +157,8 @@ let fix_eff_lexp (LEXP_aux (lexp,((l,_) as annot))) = match snd annot with
| LEXP_id _ -> no_effect
| LEXP_cast _ -> no_effect
| LEXP_memory (_,es) -> union_eff_exps es
+ | LEXP_tup les ->
+ List.fold_left (fun eff le -> union_effects eff (effect_of_lexp le)) no_effect les
| LEXP_vector (lexp,e) -> union_effects (effect_of_lexp lexp) (effect_of e)
| LEXP_vector_range (lexp,e1,e2) ->
union_effects (effect_of_lexp lexp)
@@ -188,7 +190,8 @@ let fix_eff_opt_default (Def_val_aux (opt_default,((l,_) as annot))) = match snd
let fix_eff_pexp (Pat_aux (pexp,((l,_) as annot))) = match snd annot with
| Some (env, typ, eff) ->
let effsum = union_effects eff (match pexp with
- | Pat_exp (_,e) -> effect_of e) in
+ | Pat_exp (_,e) -> effect_of e
+ | Pat_when (_,e,e') -> union_effects (effect_of e) (effect_of e')) in
Pat_aux (pexp, (l, Some (env, typ, effsum)))
| None ->
Pat_aux (pexp, (l, None))
@@ -396,11 +399,13 @@ let rewrite_exp rewriters (E_aux (exp,(l,annot))) =
(List.map (fun (FE_aux(FE_Fexp(id,e),fannot)) ->
FE_aux(FE_Fexp(id,rewrite e),fannot)) fexps, bool), fannot))))
| E_field(exp,id) -> rewrap (E_field(rewrite exp,id))
- | E_case (exp ,pexps) ->
- rewrap (E_case (rewrite exp,
- (List.map
- (fun (Pat_aux (Pat_exp(p,e),pannot)) ->
- Pat_aux (Pat_exp(rewriters.rewrite_pat rewriters p,rewrite e),pannot)) pexps)))
+ | E_case (exp,pexps) ->
+ let rewrite_pexp = function
+ | (Pat_aux (Pat_exp(p, e), pannot)) ->
+ Pat_aux (Pat_exp(rewriters.rewrite_pat rewriters p, rewrite e), pannot)
+ | (Pat_aux (Pat_when(p, e, e'), pannot)) ->
+ Pat_aux (Pat_when(rewriters.rewrite_pat rewriters p, rewrite e, rewrite e'), pannot) in
+ rewrap (E_case (rewrite exp, List.map rewrite_pexp pexps))
| E_let (letbind,body) -> rewrap (E_let(rewriters.rewrite_let rewriters letbind,rewrite body))
| E_assign (lexp,exp) -> rewrap (E_assign(rewriters.rewrite_lexp rewriters lexp,rewrite exp))
| E_sizeof n -> rewrap (E_sizeof n)
@@ -615,6 +620,7 @@ type ('a,'pat,'pat_aux,'fpat,'fpat_aux) pat_alg =
; p_vector_concat : 'pat list -> 'pat_aux
; p_tup : 'pat list -> 'pat_aux
; p_list : 'pat list -> 'pat_aux
+ ; p_cons : 'pat * 'pat -> 'pat_aux
; p_aux : 'pat_aux * 'a annot -> 'pat
; fP_aux : 'fpat_aux * 'a annot -> 'fpat
; fP_Fpat : id * 'pat -> 'fpat_aux
@@ -634,6 +640,7 @@ let rec fold_pat_aux (alg : ('a,'pat,'pat_aux,'fpat,'fpat_aux) pat_alg) : 'a pat
| P_vector_concat ps -> alg.p_vector_concat (List.map (fold_pat alg) ps)
| P_tup ps -> alg.p_tup (List.map (fold_pat alg) ps)
| P_list ps -> alg.p_list (List.map (fold_pat alg) ps)
+ | P_cons (ph,pt) -> alg.p_cons (fold_pat alg ph, fold_pat alg pt)
and fold_pat (alg : ('a,'pat,'pat_aux,'fpat,'fpat_aux) pat_alg) : 'a pat -> 'pat =
@@ -660,6 +667,7 @@ let id_pat_alg : ('a,'a pat, 'a pat_aux, 'a fpat, 'a fpat_aux) pat_alg =
; p_vector_concat = (fun ps -> P_vector_concat ps)
; p_tup = (fun ps -> P_tup ps)
; p_list = (fun ps -> P_list ps)
+ ; p_cons = (fun (ph,pt) -> P_cons (ph,pt))
; p_aux = (fun (pat,annot) -> P_aux (pat,annot))
; fP_aux = (fun (fpat,annot) -> FP_aux (fpat,annot))
; fP_Fpat = (fun (id,pat) -> FP_Fpat (id,pat))
@@ -700,6 +708,8 @@ type ('a,'exp,'exp_aux,'lexp,'lexp_aux,'fexp,'fexp_aux,'fexps,'fexps_aux,
; e_internal_cast : 'a annot * 'exp -> 'exp_aux
; e_internal_exp : 'a annot -> 'exp_aux
; e_internal_exp_user : 'a annot * 'a annot -> 'exp_aux
+ ; e_comment : string -> 'exp_aux
+ ; e_comment_struc : 'exp -> 'exp_aux
; e_internal_let : 'lexp * 'exp * 'exp -> 'exp_aux
; e_internal_plet : 'pat * 'exp * 'exp -> 'exp_aux
; e_internal_return : 'exp -> 'exp_aux
@@ -720,6 +730,7 @@ type ('a,'exp,'exp_aux,'lexp,'lexp_aux,'fexp,'fexp_aux,'fexps,'fexps_aux,
; def_val_dec : 'exp -> 'opt_default_aux
; def_val_aux : 'opt_default_aux * 'a annot -> 'opt_default
; pat_exp : 'pat * 'exp -> 'pexp_aux
+ ; pat_when : 'pat * 'exp * 'exp -> 'pexp_aux
; pat_aux : 'pexp_aux * 'a annot -> 'pexp
; lB_val_explicit : typschm * 'pat * 'exp -> 'letbind_aux
; lB_val_implicit : 'pat * 'exp -> 'letbind_aux
@@ -759,12 +770,18 @@ let rec fold_exp_aux alg = function
| E_let (letbind,e) -> alg.e_let (fold_letbind alg letbind, fold_exp alg e)
| E_assign (lexp,e) -> alg.e_assign (fold_lexp alg lexp, fold_exp alg e)
| E_sizeof nexp -> alg.e_sizeof nexp
+ | E_constraint nc -> raise (Reporting_basic.err_unreachable (Parse_ast.Unknown)
+ "E_constraint encountered during rewriting")
| E_exit e -> alg.e_exit (fold_exp alg e)
| E_return e -> alg.e_return (fold_exp alg e)
| E_assert(e1,e2) -> alg.e_assert (fold_exp alg e1, fold_exp alg e2)
| E_internal_cast (annot,e) -> alg.e_internal_cast (annot, fold_exp alg e)
| E_internal_exp annot -> alg.e_internal_exp annot
+ | E_sizeof_internal a -> raise (Reporting_basic.err_unreachable (Parse_ast.Unknown)
+ "E_sizeof_internal encountered during rewriting")
| E_internal_exp_user (annot1,annot2) -> alg.e_internal_exp_user (annot1,annot2)
+ | E_comment c -> alg.e_comment c
+ | E_comment_struc e -> alg.e_comment_struc (fold_exp alg e)
| E_internal_let (lexp,e1,e2) ->
alg.e_internal_let (fold_lexp alg lexp, fold_exp alg e1, fold_exp alg e2)
| E_internal_plet (pat,e1,e2) ->
@@ -774,6 +791,7 @@ and fold_exp alg (E_aux (exp_aux,annot)) = alg.e_aux (fold_exp_aux alg exp_aux,
and fold_lexp_aux alg = function
| LEXP_id id -> alg.lEXP_id id
| LEXP_memory (id,es) -> alg.lEXP_memory (id, List.map (fold_exp alg) es)
+ | LEXP_tup les -> alg.lEXP_tup (List.map (fold_lexp alg) les)
| LEXP_cast (typ,id) -> alg.lEXP_cast (typ,id)
| LEXP_vector (lexp,e) -> alg.lEXP_vector (fold_lexp alg lexp, fold_exp alg e)
| LEXP_vector_range (lexp,e1,e2) ->
@@ -790,7 +808,9 @@ and fold_opt_default_aux alg = function
| Def_val_dec e -> alg.def_val_dec (fold_exp alg e)
and fold_opt_default alg (Def_val_aux (opt_default_aux,annot)) =
alg.def_val_aux (fold_opt_default_aux alg opt_default_aux, annot)
-and fold_pexp_aux alg (Pat_exp (pat,e)) = alg.pat_exp (fold_pat alg.pat_alg pat, fold_exp alg e)
+and fold_pexp_aux alg = function
+ | Pat_exp (pat,e) -> alg.pat_exp (fold_pat alg.pat_alg pat, fold_exp alg e)
+ | Pat_when (pat,e,e') -> alg.pat_when (fold_pat alg.pat_alg pat, fold_exp alg e, fold_exp alg e')
and fold_pexp alg (Pat_aux (pexp_aux,annot)) = alg.pat_aux (fold_pexp_aux alg pexp_aux, annot)
and fold_letbind_aux alg = function
| LB_val_explicit (t,pat,e) -> alg.lB_val_explicit (t,fold_pat alg.pat_alg pat, fold_exp alg e)
@@ -830,6 +850,8 @@ let id_exp_alg =
; e_internal_cast = (fun (a,e1) -> E_internal_cast (a,e1))
; e_internal_exp = (fun a -> E_internal_exp a)
; e_internal_exp_user = (fun (a1,a2) -> E_internal_exp_user (a1,a2))
+ ; e_comment = (fun c -> E_comment c)
+ ; e_comment_struc = (fun e -> E_comment_struc e)
; e_internal_let = (fun (lexp, e2, e3) -> E_internal_let (lexp,e2,e3))
; e_internal_plet = (fun (pat, e1, e2) -> E_internal_plet (pat,e1,e2))
; e_internal_return = (fun e -> E_internal_return e)
@@ -850,6 +872,7 @@ let id_exp_alg =
; def_val_dec = (fun e -> Def_val_dec e)
; def_val_aux = (fun (defval,aux) -> Def_val_aux (defval,aux))
; pat_exp = (fun (pat,e) -> (Pat_exp (pat,e)))
+ ; pat_when = (fun (pat,e,e') -> (Pat_when (pat,e,e')))
; pat_aux = (fun (pexp,a) -> (Pat_aux (pexp,a)))
; lB_val_explicit = (fun (typ,pat,e) -> LB_val_explicit (typ,pat,e))
; lB_val_implicit = (fun (pat,e) -> LB_val_implicit (pat,e))
@@ -880,6 +903,7 @@ let compute_pat_alg bot join =
; p_vector_concat = split_join (fun ps -> P_vector_concat ps)
; p_tup = split_join (fun ps -> P_tup ps)
; p_list = split_join (fun ps -> P_list ps)
+ ; p_cons = (fun ((vh,ph),(vt,pt)) -> (join vh vt, P_cons (ph,pt)))
; p_aux = (fun ((v,pat),annot) -> (v, P_aux (pat,annot)))
; fP_aux = (fun ((v,fpat),annot) -> (v, FP_aux (fpat,annot)))
; fP_Fpat = (fun (id,(v,pat)) -> (v, FP_Fpat (id,pat)))
@@ -926,6 +950,8 @@ let compute_exp_alg bot join =
; e_internal_cast = (fun (a,(v1,e1)) -> (v1, E_internal_cast (a,e1)))
; e_internal_exp = (fun a -> (bot, E_internal_exp a))
; e_internal_exp_user = (fun (a1,a2) -> (bot, E_internal_exp_user (a1,a2)))
+ ; e_comment = (fun c -> (bot, E_comment c))
+ ; e_comment_struc = (fun (v,e) -> (bot, E_comment_struc e)) (* ignore value by default, since it is comes from a comment *)
; e_internal_let = (fun ((vl, lexp), (v2,e2), (v3,e3)) ->
(join_list [vl;v2;v3], E_internal_let (lexp,e2,e3)))
; e_internal_plet = (fun ((vp,pat), (v1,e1), (v2,e2)) ->
@@ -935,7 +961,9 @@ let compute_exp_alg bot join =
; lEXP_id = (fun id -> (bot, LEXP_id id))
; lEXP_memory = (fun (id,es) -> split_join (fun es -> LEXP_memory (id,es)) es)
; lEXP_cast = (fun (typ,id) -> (bot, LEXP_cast (typ,id)))
- ; lEXP_tup = split_join (fun tups -> LEXP_tup tups)
+ ; lEXP_tup = (fun ls ->
+ let (vs,ls) = List.split ls in
+ (join_list vs, LEXP_tup ls))
; lEXP_vector = (fun ((vl,lexp),(v2,e2)) -> (join vl v2, LEXP_vector (lexp,e2)))
; lEXP_vector_range = (fun ((vl,lexp),(v2,e2),(v3,e3)) ->
(join_list [vl;v2;v3], LEXP_vector_range (lexp,e2,e3)))
@@ -951,6 +979,7 @@ let compute_exp_alg bot join =
; def_val_dec = (fun (v,e) -> (v, Def_val_dec e))
; def_val_aux = (fun ((v,defval),aux) -> (v, Def_val_aux (defval,aux)))
; pat_exp = (fun ((vp,pat),(v,e)) -> (join vp v, Pat_exp (pat,e)))
+ ; pat_when = (fun ((vp,pat),(v,e),(v',e')) -> (join_list [vp;v;v'], Pat_when (pat,e,e')))
; pat_aux = (fun ((v,pexp),a) -> (v, Pat_aux (pexp,a)))
; lB_val_explicit = (fun (typ,(vp,pat),(v,e)) -> (join vp v, LB_val_explicit (typ,pat,e)))
; lB_val_implicit = (fun ((vp,pat),(v,e)) -> (join vp v, LB_val_implicit (pat,e)))
@@ -986,11 +1015,15 @@ let rewrite_sizeof (Defs defs) =
when string_of_id atom = "atom" ->
[nexp, E_id id]
| Typ_app (vector, _) when string_of_id vector = "vector" ->
- let (_,len,_,_) = vector_typ_args_of typ_aux in
- let exp = E_app
- (Id_aux (Id "length", Parse_ast.Generated l),
- [E_aux (E_id id, annot)]) in
- [len, exp]
+ let id_length = Id_aux (Id "length", Parse_ast.Generated l) in
+ (try
+ (match Env.get_val_spec id_length (env_of_annot annot) with
+ | _ ->
+ let (_,len,_,_) = vector_typ_args_of typ_aux in
+ let exp = E_app (id_length, [E_aux (E_id id, annot)]) in
+ [len, exp])
+ with
+ | _ -> [])
| _ -> [])
| _ -> [] in
(v @ v', P_aux (pat,annot)))} pat) in
@@ -1166,6 +1199,7 @@ let remove_vector_concat_pat pat =
; p_vector_concat = (fun ps -> P_vector_concat (List.map (fun p -> p false) ps))
; p_tup = (fun ps -> P_tup (List.map (fun p -> p false) ps))
; p_list = (fun ps -> P_list (List.map (fun p -> p false) ps))
+ ; p_cons = (fun (p,ps) -> P_cons (p false, ps false))
; p_aux =
(fun (pat,((l,_) as annot)) contained_in_p_as ->
match pat with
@@ -1218,8 +1252,8 @@ let remove_vector_concat_pat pat =
(* build a let-expression of the form "let child = root[i..j] in body" *)
let letbind_vec typ_opt (rootid,rannot) (child,cannot) (i,j) =
let (l,_) = cannot in
- let (Id_aux (Id rootname,_)) = rootid in
- let (Id_aux (Id childname,_)) = child in
+ let rootname = string_of_id rootid in
+ let childname = string_of_id child in
let root = E_aux (E_id rootid, rannot) in
let index_i = simple_num l i in
@@ -1248,38 +1282,29 @@ let remove_vector_concat_pat pat =
let rec aux typ_opt (pos,pat_acc,decl_acc) (P_aux (p,cannot),is_last) =
let ctyp = Env.base_typ_of (env_of_annot cannot) (typ_of_annot cannot) in
let (_,length,ord,_) = vector_typ_args_of ctyp in
- (*)| (_,length,ord,_) ->*)
- let (pos',index_j) = match length with
- | Nexp_aux (Nexp_constant i,_) ->
- if is_order_inc ord then (pos+i, pos+i-1)
- else (pos-i, pos-i+1)
- | Nexp_aux (_,l) ->
- if is_last then (pos,last_idx)
- else
- raise
- (Reporting_basic.err_unreachable
- l ("unname_vector_concat_elements: vector of unspecified length in vector-concat pattern")) in
- (match p with
- (* if we see a named vector pattern, remove the name and remember to
- declare it later *)
- | P_as (P_aux (p,cannot),cname) ->
- let (lb,decl,info) = letbind_vec typ_opt (rootid,rannot) (cname,cannot) (pos,index_j) in
- (pos', pat_acc @ [P_aux (p,cannot)], decl_acc @ [((lb,decl),info)])
- (* if we see a P_id variable, remember to declare it later *)
- | P_id cname ->
- let (lb,decl,info) = letbind_vec typ_opt (rootid,rannot) (cname,cannot) (pos,index_j) in
- (pos', pat_acc @ [P_aux (P_id cname,cannot)], decl_acc @ [((lb,decl),info)])
- | P_typ (typ, pat) -> aux (Some typ) (pos,pat_acc,decl_acc) (pat, is_last)
- (* normal vector patterns are fine *)
- | _ -> (pos', pat_acc @ [P_aux (p,cannot)],decl_acc) )
- (* non-vector patterns aren't *)
- (*)| _ ->
- raise
- (Reporting_basic.err_unreachable
- (fst cannot)
- ("unname_vector_concat_elements: Non-vector in vector-concat pattern:" ^
- string_of_typ (typ_of_annot cannot))
- )*) in
+ let (pos',index_j) = match length with
+ | Nexp_aux (Nexp_constant i,_) ->
+ if is_order_inc ord then (pos+i, pos+i-1)
+ else (pos-i, pos-i+1)
+ | Nexp_aux (_,l) ->
+ if is_last then (pos,last_idx)
+ else
+ raise
+ (Reporting_basic.err_unreachable
+ l ("unname_vector_concat_elements: vector of unspecified length in vector-concat pattern")) in
+ (match p with
+ (* if we see a named vector pattern, remove the name and remember to
+ declare it later *)
+ | P_as (P_aux (p,cannot),cname) ->
+ let (lb,decl,info) = letbind_vec typ_opt (rootid,rannot) (cname,cannot) (pos,index_j) in
+ (pos', pat_acc @ [P_aux (p,cannot)], decl_acc @ [((lb,decl),info)])
+ (* if we see a P_id variable, remember to declare it later *)
+ | P_id cname ->
+ let (lb,decl,info) = letbind_vec typ_opt (rootid,rannot) (cname,cannot) (pos,index_j) in
+ (pos', pat_acc @ [P_aux (P_id cname,cannot)], decl_acc @ [((lb,decl),info)])
+ | P_typ (typ, pat) -> aux (Some typ) (pos,pat_acc,decl_acc) (pat, is_last)
+ (* normal vector patterns are fine *)
+ | _ -> (pos', pat_acc @ [P_aux (p,cannot)],decl_acc)) in
let pats_tagged = tag_last pats in
let (_,pats',decls') = List.fold_left (aux None) (start,[],[]) pats_tagged in
@@ -1309,6 +1334,7 @@ let remove_vector_concat_pat pat =
(P_tup ps,List.flatten decls))
; p_list = (fun ps -> let (ps,decls) = List.split ps in
(P_list ps,List.flatten decls))
+ ; p_cons = (fun ((p,decls),(p',decls')) -> (P_cons (p,p'), decls @ decls'))
; p_aux = (fun ((pat,decls),annot) -> p_aux ((pat,decls),annot))
; fP_aux = (fun ((fpat,decls),annot) -> (FP_aux (fpat,annot),decls))
; fP_Fpat = (fun (id,(pat,decls)) -> (FP_Fpat (id,pat),decls))
@@ -1417,9 +1443,13 @@ let rewrite_exp_remove_vector_concat_pat rewriters (E_aux (exp,(l,annot)) as ful
let rewrite_base = rewrite_exp rewriters in
match exp with
| E_case (e,ps) ->
- let aux (Pat_aux (Pat_exp (pat,body),annot')) =
+ let aux = function
+ | (Pat_aux (Pat_exp (pat,body),annot')) ->
let (pat,_,decls) = remove_vector_concat_pat pat in
- Pat_aux (Pat_exp (pat, decls (rewrite_rec body)),annot') in
+ Pat_aux (Pat_exp (pat, decls (rewrite_rec body)),annot')
+ | (Pat_aux (Pat_when (pat,guard,body),annot')) ->
+ let (pat,_,decls) = remove_vector_concat_pat pat in
+ Pat_aux (Pat_when (pat, decls (rewrite_rec guard), decls (rewrite_rec body)),annot') in
rewrap (E_case (rewrite_rec e, List.map aux ps))
| E_let (LB_aux (LB_val_explicit (typ,pat,v),annot'),body) ->
let (pat,_,decls) = remove_vector_concat_pat pat in
@@ -1462,6 +1492,177 @@ let rewrite_defs_remove_vector_concat (Defs defs) =
| d -> [d] in
Defs (List.flatten (List.map rewrite_def defs))
+(* A few helper functions for rewriting guarded pattern clauses.
+ Used both by the rewriting of P_when and separately by the rewriting of
+ bitvectors in parameter patterns of function clauses *)
+
+let remove_wildcards pre (P_aux (_,(l,_)) as pat) =
+ fold_pat
+ {id_pat_alg with
+ p_aux = function
+ | (P_wild,(l,annot)) -> P_aux (P_id (fresh_id pre l),(l,annot))
+ | (p,annot) -> P_aux (p,annot) }
+ pat
+
+(* Check if one pattern subsumes the other, and if so, calculate a
+ substitution of variables that are used in the same position.
+ TODO: Check somewhere that there are no variable clashes (the same variable
+ name used in different positions of the patterns)
+ *)
+let rec subsumes_pat (P_aux (p1,annot1) as pat1) (P_aux (p2,annot2) as pat2) =
+ let rewrap p = P_aux (p,annot1) in
+ let subsumes_list s pats1 pats2 =
+ if List.length pats1 = List.length pats2
+ then
+ let subs = List.map2 s pats1 pats2 in
+ List.fold_right
+ (fun p acc -> match p, acc with
+ | Some subst, Some substs -> Some (subst @ substs)
+ | _ -> None)
+ subs (Some [])
+ else None in
+ match p1, p2 with
+ | P_lit (L_aux (lit1,_)), P_lit (L_aux (lit2,_)) ->
+ if lit1 = lit2 then Some [] else None
+ | P_as (pat1,_), _ -> subsumes_pat pat1 pat2
+ | _, P_as (pat2,_) -> subsumes_pat pat1 pat2
+ | P_typ (_,pat1), _ -> subsumes_pat pat1 pat2
+ | _, P_typ (_,pat2) -> subsumes_pat pat1 pat2
+ | P_id (Id_aux (id1,_) as aid1), P_id (Id_aux (id2,_) as aid2) ->
+ if id1 = id2 then Some []
+ else if Env.lookup_id aid1 (env_of_annot annot1) = Unbound &&
+ Env.lookup_id aid2 (env_of_annot annot2) = Unbound
+ then Some [(id2,id1)] else None
+ | P_id id1, _ ->
+ if Env.lookup_id id1 (env_of_annot annot1) = Unbound then Some [] else None
+ | P_wild, _ -> Some []
+ | P_app (Id_aux (id1,l1),args1), P_app (Id_aux (id2,_),args2) ->
+ if id1 = id2 then subsumes_list subsumes_pat args1 args2 else None
+ | P_record (fps1,b1), P_record (fps2,b2) ->
+ if b1 = b2 then subsumes_list subsumes_fpat fps1 fps2 else None
+ | P_vector pats1, P_vector pats2
+ | P_vector_concat pats1, P_vector_concat pats2
+ | P_tup pats1, P_tup pats2
+ | P_list pats1, P_list pats2 ->
+ subsumes_list subsumes_pat pats1 pats2
+ | P_list (pat1 :: pats1), P_cons _ ->
+ subsumes_pat (rewrap (P_cons (pat1, rewrap (P_list pats1)))) pat2
+ | P_cons _, P_list (pat2 :: pats2)->
+ subsumes_pat pat1 (rewrap (P_cons (pat2, rewrap (P_list pats2))))
+ | P_cons (pat1, pats1), P_cons (pat2, pats2) ->
+ (match subsumes_pat pat1 pat2, subsumes_pat pats1 pats2 with
+ | Some substs1, Some substs2 -> Some (substs1 @ substs2)
+ | _ -> None)
+ | P_vector_indexed ips1, P_vector_indexed ips2 ->
+ let (is1,ps1) = List.split ips1 in
+ let (is2,ps2) = List.split ips2 in
+ if is1 = is2 then subsumes_list subsumes_pat ps1 ps2 else None
+ | _ -> None
+and subsumes_fpat (FP_aux (FP_Fpat (id1,pat1),_)) (FP_aux (FP_Fpat (id2,pat2),_)) =
+ if id1 = id2 then subsumes_pat pat1 pat2 else None
+
+let equiv_pats pat1 pat2 =
+ match subsumes_pat pat1 pat2, subsumes_pat pat2 pat1 with
+ | Some _, Some _ -> true
+ | _, _ -> false
+
+let subst_id_pat pat (id1,id2) =
+ let p_id (Id_aux (id,l)) = (if id = id1 then P_id (Id_aux (id2,l)) else P_id (Id_aux (id,l))) in
+ fold_pat {id_pat_alg with p_id = p_id} pat
+
+let subst_id_exp exp (id1,id2) =
+ (* TODO Don't substitute bound occurrences inside let expressions etc *)
+ let e_id (Id_aux (id,l)) = (if id = id1 then E_id (Id_aux (id2,l)) else E_id (Id_aux (id,l))) in
+ fold_exp {id_exp_alg with e_id = e_id} exp
+
+let rec pat_to_exp (P_aux (pat,(l,annot))) =
+ let rewrap e = E_aux (e,(l,annot)) in
+ match pat with
+ | P_lit lit -> rewrap (E_lit lit)
+ | P_wild -> raise (Reporting_basic.err_unreachable l
+ "pat_to_exp given wildcard pattern")
+ | P_as (pat,id) -> rewrap (E_id id)
+ | P_typ (_,pat) -> pat_to_exp pat
+ | P_id id -> rewrap (E_id id)
+ | P_app (id,pats) -> rewrap (E_app (id, List.map pat_to_exp pats))
+ | P_record (fpats,b) ->
+ rewrap (E_record (FES_aux (FES_Fexps (List.map fpat_to_fexp fpats,b),(l,annot))))
+ | P_vector pats -> rewrap (E_vector (List.map pat_to_exp pats))
+ | P_vector_concat pats -> raise (Reporting_basic.err_unreachable l
+ "pat_to_exp not implemented for P_vector_concat")
+ (* We assume that vector concatenation patterns have been transformed
+ away already *)
+ | P_tup pats -> rewrap (E_tuple (List.map pat_to_exp pats))
+ | P_list pats -> rewrap (E_list (List.map pat_to_exp pats))
+ | P_cons (p,ps) -> rewrap (E_cons (pat_to_exp p, pat_to_exp ps))
+ | P_vector_indexed ipats -> raise (Reporting_basic.err_unreachable l
+ "pat_to_exp not implemented for P_vector_indexed") (* TODO *)
+and fpat_to_fexp (FP_aux (FP_Fpat (id,pat),(l,annot))) =
+ FE_aux (FE_Fexp (id, pat_to_exp pat),(l,annot))
+
+let case_exp e t cs =
+ let pexp (pat,body,annot) = Pat_aux (Pat_exp (pat,body),annot) in
+ let ps = List.map pexp cs in
+ (* let efr = union_effs (List.map effect_of_pexp ps) in *)
+ fix_eff_exp (E_aux (E_case (e,ps), (get_loc_exp e, Some (env_of e, t, no_effect))))
+
+let rewrite_guarded_clauses l cs =
+ let rec group clauses =
+ let add_clause (pat,cls,annot) c = (pat,cls @ [c],annot) in
+ let rec group_aux current acc = (function
+ | ((pat,guard,body,annot) as c) :: cs ->
+ let (current_pat,_,_) = current in
+ (match subsumes_pat current_pat pat with
+ | Some substs ->
+ let pat' = List.fold_left subst_id_pat pat substs in
+ let guard' = (match guard with
+ | Some exp -> Some (List.fold_left subst_id_exp exp substs)
+ | None -> None) in
+ let body' = List.fold_left subst_id_exp body substs in
+ let c' = (pat',guard',body',annot) in
+ group_aux (add_clause current c') acc cs
+ | None ->
+ let pat = remove_wildcards "g__" pat in
+ group_aux (pat,[c],annot) (acc @ [current]) cs)
+ | [] -> acc @ [current]) in
+ let groups = match clauses with
+ | ((pat,guard,body,annot) as c) :: cs ->
+ group_aux (remove_wildcards "g__" pat, [c], annot) [] cs
+ | _ ->
+ raise (Reporting_basic.err_unreachable l
+ "group given empty list in rewrite_guarded_clauses") in
+ List.map (fun cs -> if_pexp cs) groups
+ and if_pexp (pat,cs,annot) = (match cs with
+ | c :: _ ->
+ (* fix_eff_pexp (pexp *)
+ let body = if_exp pat cs in
+ let pexp = fix_eff_pexp (Pat_aux (Pat_exp (pat,body),annot)) in
+ let (Pat_aux (_,annot)) = pexp in
+ (pat, body, annot)
+ | [] ->
+ raise (Reporting_basic.err_unreachable l
+ "if_pexp given empty list in rewrite_guarded_clauses"))
+ and if_exp current_pat = (function
+ | (pat,guard,body,annot) :: ((pat',guard',body',annot') as c') :: cs ->
+ (match guard with
+ | Some exp ->
+ let else_exp =
+ if equiv_pats current_pat pat'
+ then if_exp current_pat (c' :: cs)
+ else case_exp (pat_to_exp current_pat) (typ_of body') (group (c' :: cs)) in
+ fix_eff_exp (E_aux (E_if (exp,body,else_exp), simple_annot (fst annot) (typ_of body)))
+ | None -> body)
+ | [(pat,guard,body,annot)] -> body
+ | [] ->
+ raise (Reporting_basic.err_unreachable l
+ "if_exp given empty list in rewrite_guarded_clauses")) in
+ group cs
+
+let bitwise_and_exp exp1 exp2 =
+ let (E_aux (_,(l,_))) = exp1 in
+ let andid = Id_aux (Id "bool_and", Parse_ast.Generated l) in
+ E_aux (E_app(andid,[exp1;exp2]), simple_annot l bool_typ)
+
let rec contains_bitvector_pat (P_aux (pat,annot)) = match pat with
| P_lit _ | P_wild | P_id _ -> false
| P_as (pat,_) | P_typ (_,pat) -> contains_bitvector_pat pat
@@ -1470,9 +1671,16 @@ let rec contains_bitvector_pat (P_aux (pat,annot)) = match pat with
is_bitvector_typ typ
| P_app (_,pats) | P_tup pats | P_list pats ->
List.exists contains_bitvector_pat pats
+| P_cons (p,ps) -> contains_bitvector_pat p || contains_bitvector_pat ps
| P_record (fpats,_) ->
List.exists (fun (FP_aux (FP_Fpat (_,pat),_)) -> contains_bitvector_pat pat) fpats
+let contains_bitvector_pexp = function
+| Pat_aux (Pat_exp (pat,_),_) | Pat_aux (Pat_when (pat,_,_),_) ->
+ contains_bitvector_pat pat
+
+(* Rewrite bitvector patterns to guarded patterns *)
+
let remove_bitvector_pat pat =
(* first introduce names for bitvector patterns *)
@@ -1489,6 +1697,7 @@ let remove_bitvector_pat pat =
; p_vector_concat = (fun ps -> P_vector_concat (List.map (fun p -> p false) ps))
; p_tup = (fun ps -> P_tup (List.map (fun p -> p false) ps))
; p_list = (fun ps -> P_list (List.map (fun p -> p false) ps))
+ ; p_cons = (fun (p,ps) -> P_cons (p false, ps false))
; p_aux =
(fun (pat,annot) contained_in_p_as ->
let env = env_of_annot annot in
@@ -1557,14 +1766,8 @@ let remove_bitvector_pat pat =
E_aux (E_let (letbind,body), (Parse_ast.Generated l, bannot))) in
(letexp, letbind) in
- (* Helper functions for composing guards *)
- let bitwise_and exp1 exp2 =
- let (E_aux (_,(l,_))) = exp1 in
- let andid = Id_aux (Id "bool_and", Parse_ast.Generated l) in
- E_aux (E_app(andid,[exp1;exp2]), simple_annot l bool_typ) in
-
let compose_guards guards =
- List.fold_right (Util.option_binop bitwise_and) guards None in
+ List.fold_right (Util.option_binop bitwise_and_exp) guards None in
let flatten_guards_decls gd =
let (guards,decls,letbinds) = Util.split3 gd in
@@ -1651,6 +1854,8 @@ let remove_bitvector_pat pat =
(P_tup ps, flatten_guards_decls gdls))
; p_list = (fun ps -> let (ps,gdls) = List.split ps in
(P_list ps, flatten_guards_decls gdls))
+ ; p_cons = (fun ((p,gdls),(p',gdls')) ->
+ (P_cons (p,p'), flatten_guards_decls [gdls;gdls']))
; p_aux = (fun ((pat,gdls),annot) ->
let env = env_of_annot annot in
let t = Env.base_typ_of env (typ_of_annot annot) in
@@ -1665,183 +1870,27 @@ let remove_bitvector_pat pat =
} in
fold_pat guard_bitvector_pat pat
-let remove_wildcards pre (P_aux (_,(l,_)) as pat) =
- fold_pat
- {id_pat_alg with
- p_aux = function
- | (P_wild,(l,annot)) -> P_aux (P_id (fresh_id pre l),(l,annot))
- | (p,annot) -> P_aux (p,annot) }
- pat
-
-(* Check if one pattern subsumes the other, and if so, calculate a
- substitution of variables that are used in the same position.
- TODO: Check somewhere that there are no variable clashes (the same variable
- name used in different positions of the patterns)
- *)
-let rec subsumes_pat (P_aux (p1,annot1) as pat1) (P_aux (p2,annot2) as pat2) =
- let rewrap p = P_aux (p,annot1) in
- let subsumes_list s pats1 pats2 =
- if List.length pats1 = List.length pats2
- then
- let subs = List.map2 s pats1 pats2 in
- List.fold_right
- (fun p acc -> match p, acc with
- | Some subst, Some substs -> Some (subst @ substs)
- | _ -> None)
- subs (Some [])
- else None in
- match p1, p2 with
- | P_lit (L_aux (lit1,_)), P_lit (L_aux (lit2,_)) ->
- if lit1 = lit2 then Some [] else None
- | P_as (pat1,_), _ -> subsumes_pat pat1 pat2
- | _, P_as (pat2,_) -> subsumes_pat pat1 pat2
- | P_typ (_,pat1), _ -> subsumes_pat pat1 pat2
- | _, P_typ (_,pat2) -> subsumes_pat pat1 pat2
- | P_id (Id_aux (id1,_) as aid1), P_id (Id_aux (id2,_) as aid2) ->
- if id1 = id2 then Some []
- else if Env.lookup_id aid1 (env_of_annot annot1) = Unbound &&
- Env.lookup_id aid2 (env_of_annot annot2) = Unbound
- then Some [(id2,id1)] else None
- | P_id id1, _ ->
- if Env.lookup_id id1 (env_of_annot annot1) = Unbound then Some [] else None
- | P_wild, _ -> Some []
- | P_app (Id_aux (id1,l1),args1), P_app (Id_aux (id2,_),args2) ->
- if id1 = id2 then subsumes_list subsumes_pat args1 args2 else None
- | P_record (fps1,b1), P_record (fps2,b2) ->
- if b1 = b2 then subsumes_list subsumes_fpat fps1 fps2 else None
- | P_vector pats1, P_vector pats2
- | P_vector_concat pats1, P_vector_concat pats2
- | P_tup pats1, P_tup pats2
- | P_list pats1, P_list pats2 ->
- subsumes_list subsumes_pat pats1 pats2
- | P_vector_indexed ips1, P_vector_indexed ips2 ->
- let (is1,ps1) = List.split ips1 in
- let (is2,ps2) = List.split ips2 in
- if is1 = is2 then subsumes_list subsumes_pat ps1 ps2 else None
- | _ -> None
-and subsumes_fpat (FP_aux (FP_Fpat (id1,pat1),_)) (FP_aux (FP_Fpat (id2,pat2),_)) =
- if id1 = id2 then subsumes_pat pat1 pat2 else None
-
-let equiv_pats pat1 pat2 =
- match subsumes_pat pat1 pat2, subsumes_pat pat2 pat1 with
- | Some _, Some _ -> true
- | _, _ -> false
-
-let subst_id_pat pat (id1,id2) =
- let p_id (Id_aux (id,l)) = (if id = id1 then P_id (Id_aux (id2,l)) else P_id (Id_aux (id,l))) in
- fold_pat {id_pat_alg with p_id = p_id} pat
-
-let subst_id_exp exp (id1,id2) =
- (* TODO Don't substitute bound occurrences inside let expressions etc *)
- let e_id (Id_aux (id,l)) = (if id = id1 then E_id (Id_aux (id2,l)) else E_id (Id_aux (id,l))) in
- fold_exp {id_exp_alg with e_id = e_id} exp
-
-let rec pat_to_exp (P_aux (pat,(l,annot))) =
- let rewrap e = E_aux (e,(l,annot)) in
- match pat with
- | P_lit lit -> rewrap (E_lit lit)
- | P_wild -> raise (Reporting_basic.err_unreachable l
- "pat_to_exp given wildcard pattern")
- | P_as (pat,id) -> rewrap (E_id id)
- | P_typ (_,pat) -> pat_to_exp pat
- | P_id id -> rewrap (E_id id)
- | P_app (id,pats) -> rewrap (E_app (id, List.map pat_to_exp pats))
- | P_record (fpats,b) ->
- rewrap (E_record (FES_aux (FES_Fexps (List.map fpat_to_fexp fpats,b),(l,annot))))
- | P_vector pats -> rewrap (E_vector (List.map pat_to_exp pats))
- | P_vector_concat pats -> raise (Reporting_basic.err_unreachable l
- "pat_to_exp not implemented for P_vector_concat")
- (* We assume that vector concatenation patterns have been transformed
- away already *)
- | P_tup pats -> rewrap (E_tuple (List.map pat_to_exp pats))
- | P_list pats -> rewrap (E_list (List.map pat_to_exp pats))
- | P_vector_indexed ipats -> raise (Reporting_basic.err_unreachable l
- "pat_to_exp not implemented for P_vector_indexed") (* TODO *)
-and fpat_to_fexp (FP_aux (FP_Fpat (id,pat),(l,annot))) =
- FE_aux (FE_Fexp (id, pat_to_exp pat),(l,annot))
-
-let case_exp e t cs =
- let pexp (pat,body,annot) = Pat_aux (Pat_exp (pat,body),annot) in
- let ps = List.map pexp cs in
- (* let efr = union_effs (List.map effect_of_pexp ps) in *)
- fix_eff_exp (E_aux (E_case (e,ps), (get_loc_exp e, Some (env_of e, t, no_effect))))
-
-let rewrite_guarded_clauses l cs =
- let rec group clauses =
- let add_clause (pat,cls,annot) c = (pat,cls @ [c],annot) in
- let rec group_aux current acc = (function
- | ((pat,guard,body,annot) as c) :: cs ->
- let (current_pat,_,_) = current in
- (match subsumes_pat current_pat pat with
- | Some substs ->
- let pat' = List.fold_left subst_id_pat pat substs in
- let guard' = (match guard with
- | Some exp -> Some (List.fold_left subst_id_exp exp substs)
- | None -> None) in
- let body' = List.fold_left subst_id_exp body substs in
- let c' = (pat',guard',body',annot) in
- group_aux (add_clause current c') acc cs
- | None ->
- let pat = remove_wildcards "g__" pat in
- group_aux (pat,[c],annot) (acc @ [current]) cs)
- | [] -> acc @ [current]) in
- let groups = match clauses with
- | ((pat,guard,body,annot) as c) :: cs ->
- group_aux (remove_wildcards "g__" pat, [c], annot) [] cs
- | _ ->
- raise (Reporting_basic.err_unreachable l
- "group given empty list in rewrite_guarded_clauses") in
- List.map (fun cs -> if_pexp cs) groups
- and if_pexp (pat,cs,annot) = (match cs with
- | c :: _ ->
- (* fix_eff_pexp (pexp *)
- let body = if_exp pat cs in
- let pexp = fix_eff_pexp (Pat_aux (Pat_exp (pat,body),annot)) in
- let (Pat_aux (Pat_exp (_,_),annot)) = pexp in
- (pat, body, annot)
- | [] ->
- raise (Reporting_basic.err_unreachable l
- "if_pexp given empty list in rewrite_guarded_clauses"))
- and if_exp current_pat = (function
- | (pat,guard,body,annot) :: ((pat',guard',body',annot') as c') :: cs ->
- (match guard with
- | Some exp ->
- let else_exp =
- if equiv_pats current_pat pat'
- then if_exp current_pat (c' :: cs)
- else case_exp (pat_to_exp current_pat) (typ_of body') (group (c' :: cs)) in
- fix_eff_exp (E_aux (E_if (exp,body,else_exp), simple_annot (fst annot) (typ_of body)))
- | None -> body)
- | [(pat,guard,body,annot)] -> body
- | [] ->
- raise (Reporting_basic.err_unreachable l
- "if_exp given empty list in rewrite_guarded_clauses")) in
- group cs
-
let rewrite_exp_remove_bitvector_pat rewriters (E_aux (exp,(l,annot)) as full_exp) =
let rewrap e = E_aux (e,(l,annot)) in
let rewrite_rec = rewriters.rewrite_exp rewriters in
let rewrite_base = rewrite_exp rewriters in
match exp with
| E_case (e,ps)
- when List.exists (fun (Pat_aux (Pat_exp (pat,_),_)) -> contains_bitvector_pat pat) ps ->
- let clause (Pat_aux (Pat_exp (pat,body),annot')) =
- let (pat',(guard,decls,_)) = remove_bitvector_pat pat in
+ when List.exists contains_bitvector_pexp ps ->
+ let rewrite_pexp = function
+ | Pat_aux (Pat_exp (pat,body),annot') ->
+ let (pat',(guard',decls,_)) = remove_bitvector_pat pat in
let body' = decls (rewrite_rec body) in
- (pat',guard,body',annot') in
- let clauses = rewrite_guarded_clauses l (List.map clause ps) in
- if (effectful e) then
- let e = rewrite_rec e in
- let (E_aux (_,(el,eannot))) = e in
- let pat_e' = fresh_id_pat "p__" (el,eannot) in
- let exp_e' = pat_to_exp pat_e' in
- (* let fresh = fresh_id "p__" el in
- let exp_e' = E_aux (E_id fresh, gen_annot l (get_type e) pure_e) in
- let pat_e' = P_aux (P_id fresh, gen_annot l (get_type e) pure_e) in *)
- let letbind_e = LB_aux (LB_val_implicit (pat_e',e), (el,eannot)) in
- let exp' = case_exp exp_e' (typ_of full_exp) clauses in
- rewrap (E_let (letbind_e, exp'))
- else case_exp e (typ_of full_exp) clauses
+ (match guard' with
+ | Some guard' -> Pat_aux (Pat_when (pat', guard', body'), annot')
+ | None -> Pat_aux (Pat_exp (pat', body'), annot'))
+ | Pat_aux (Pat_when (pat,guard,body),annot') ->
+ let (pat',(guard',decls,_)) = remove_bitvector_pat pat in
+ let body' = decls (rewrite_rec body) in
+ (match guard' with
+ | Some guard' -> Pat_aux (Pat_when (pat', bitwise_and_exp guard guard', body'), annot')
+ | None -> Pat_aux (Pat_when (pat', guard, body'), annot')) in
+ rewrap (E_case (e, List.map rewrite_pexp ps))
| E_let (LB_aux (LB_val_explicit (typ,pat,v),annot'),body) ->
let (pat,(_,decls,_)) = remove_bitvector_pat pat in
rewrap (E_let (LB_aux (LB_val_explicit (typ,pat,rewrite_rec v),annot'),
@@ -1891,6 +1940,38 @@ let rewrite_defs_remove_bitvector_pats (Defs defs) =
Defs (List.flatten (List.map rewrite_def defs))
+(* Remove pattern guards by rewriting them to if-expressions within the
+ pattern expression. Shares code with the rewriting of bitvector patterns. *)
+let rewrite_exp_guarded_pats rewriters (E_aux (exp,(l,annot)) as full_exp) =
+ let rewrap e = E_aux (e,(l,annot)) in
+ let rewrite_rec = rewriters.rewrite_exp rewriters in
+ let rewrite_base = rewrite_exp rewriters in
+ let is_guarded_pexp = function
+ | Pat_aux (Pat_when (_,_,_),_) -> true
+ | _ -> false in
+ match exp with
+ | E_case (e,ps)
+ when List.exists is_guarded_pexp ps ->
+ let clause = function
+ | Pat_aux (Pat_exp (pat, body), annot) ->
+ (pat, None, rewrite_rec body, annot)
+ | Pat_aux (Pat_when (pat, guard, body), annot) ->
+ (pat, Some guard, rewrite_rec body, annot) in
+ let clauses = rewrite_guarded_clauses l (List.map clause ps) in
+ if (effectful e) then
+ let e = rewrite_rec e in
+ let (E_aux (_,(el,eannot))) = e in
+ let pat_e' = fresh_id_pat "p__" (el,eannot) in
+ let exp_e' = pat_to_exp pat_e' in
+ let letbind_e = LB_aux (LB_val_implicit (pat_e',e), (el,eannot)) in
+ let exp' = case_exp exp_e' (typ_of full_exp) clauses in
+ rewrap (E_let (letbind_e, exp'))
+ else case_exp e (typ_of full_exp) clauses
+ | _ -> rewrite_base full_exp
+
+let rewrite_defs_guarded_pats =
+ rewrite_defs_base { rewriters_base with rewrite_exp = rewrite_exp_guarded_pats }
+
(*Expects to be called after rewrite_defs; thus the following should not appear:
internal_exp of any form
lit vectors in patterns or expressions
@@ -2145,8 +2226,11 @@ let rewrite_defs_letbind_effects =
mapCont n_fexp fexps k
and n_pexp (newreturn : bool) (pexp : 'a pexp) (k : 'a pexp -> 'a exp) : 'a exp =
- let (Pat_aux (Pat_exp (pat,exp),annot)) = pexp in
- k (fix_eff_pexp (Pat_aux (Pat_exp (pat,n_exp_term newreturn exp), annot)))
+ match pexp with
+ | Pat_aux (Pat_exp (pat,exp),annot) ->
+ k (fix_eff_pexp (Pat_aux (Pat_exp (pat,n_exp_term newreturn exp), annot)))
+ | Pat_aux (Pat_when (pat,guard,exp),annot) ->
+ k (fix_eff_pexp (Pat_aux (Pat_when (pat,n_exp_term newreturn guard,n_exp_term newreturn exp), annot)))
and n_pexpL (newreturn : bool) (pexps : 'a pexp list) (k : 'a pexp list -> 'a exp) : 'a exp =
mapCont (n_pexp newreturn) pexps k
@@ -2181,6 +2265,9 @@ let rewrite_defs_letbind_effects =
| LEXP_memory (id,es) ->
n_exp_nameL es (fun es ->
k (fix_eff_lexp (LEXP_aux (LEXP_memory (id,es),annot))))
+ | LEXP_tup es ->
+ n_lexpL es (fun es ->
+ k (fix_eff_lexp (LEXP_aux (LEXP_tup es,annot))))
| LEXP_cast (typ,id) ->
k (fix_eff_lexp (LEXP_aux (LEXP_cast (typ,id),annot)))
| LEXP_vector (lexp,e) ->
@@ -2196,6 +2283,9 @@ let rewrite_defs_letbind_effects =
n_lexp lexp (fun lexp ->
k (fix_eff_lexp (LEXP_aux (LEXP_field (lexp,id),annot))))
+ and n_lexpL (lexps : 'a lexp list) (k : 'a lexp list -> 'a exp) : 'a exp =
+ mapCont n_lexp lexps k
+
and n_exp_term (newreturn : bool) (exp : 'a exp) : 'a exp =
let (E_aux (_,(l,tannot))) = exp in
let exp =
@@ -2308,6 +2398,7 @@ let rewrite_defs_letbind_effects =
rewrap (E_let (lb,n_exp body k)))
| E_sizeof nexp ->
k (rewrap (E_sizeof nexp))
+ | E_constraint nc -> failwith "E_constraint should have been removed till now"
| E_sizeof_internal annot ->
k (rewrap (E_sizeof_internal annot))
| E_assign (lexp,exp1) ->
@@ -2403,7 +2494,7 @@ let eqidtyp (id1,_) (id2,_) =
let name2 = match id2 with Id_aux ((Id name | DeIid name),_) -> name in
name1 = name2
-let find_updated_vars exp =
+let find_updated_vars (E_aux (_,(l,_)) as exp) =
let ( @@ ) (a,b) (a',b') = (a @ a',b @ b') in
let lapp2 (l : (('a list * 'b list) list)) : ('a list * 'b list) =
List.fold_left
@@ -2445,8 +2536,14 @@ let find_updated_vars exp =
; e_internal_cast = (fun (_,e1) -> e1)
; e_internal_exp = (fun _ -> ([],[]))
; e_internal_exp_user = (fun _ -> ([],[]))
+ ; e_comment = (fun _ -> ([],[]))
+ ; e_comment_struc = (fun _ -> ([],[]))
; e_internal_let =
- (fun (([id],acc),e2,e3) ->
+ (fun ((ids,acc),e2,e3) ->
+ let id = match ids with
+ | [] -> raise (Reporting_basic.err_unreachable l "E_internal_let found not introducing a variable")
+ | [id] -> id
+ | _ -> raise (Reporting_basic.err_unreachable l "E_internal_let found introducing more than one variable") in
let (xs,ys) = ([id],[]) @@ acc @@ e2 @@ e3 in
let ys = List.filter (fun id2 -> not (eqidtyp id id2)) ys in
(xs,ys))
@@ -2475,6 +2572,7 @@ let find_updated_vars exp =
; def_val_dec = (fun e -> e)
; def_val_aux = (fun (defval,_) -> defval)
; pat_exp = (fun (_,e) -> e)
+ ; pat_when = (fun (_,_,e) -> e)
; pat_aux = (fun (pexp,_) -> pexp)
; lB_val_explicit = (fun (_,_,e) -> e)
; lB_val_implicit = (fun (_,e) -> e)
@@ -2568,7 +2666,9 @@ let rec rewrite_var_updates ((E_aux (expaux,((l,_) as annot))) as exp) =
| false, Ord_aux (Ord_inc,_) -> "foreach_inc"
| false, Ord_aux (Ord_dec,_) -> "foreach_dec"
| true, Ord_aux (Ord_inc,_) -> "foreachM_inc"
- | true, Ord_aux (Ord_dec,_) -> "foreachM_dec" in
+ | true, Ord_aux (Ord_dec,_) -> "foreachM_dec"
+ | _ -> raise (Reporting_basic.err_unreachable el
+ "Could not determine foreach combinator") in
let funcl = Id_aux (Id fname,Parse_ast.Generated el) in
let loopvar =
(* Don't bother with creating a range type annotation, since the
@@ -2618,16 +2718,21 @@ let rec rewrite_var_updates ((E_aux (expaux,((l,_) as annot))) as exp) =
| E_case (e1,ps) ->
(* after rewrite_defs_letbind_effects e1 needs no rewriting *)
let vars =
- let f acc (Pat_aux (Pat_exp (_,e),_)) = acc @ find_updated_vars e in
+ let f acc (Pat_aux ((Pat_exp (_,e)|Pat_when (_,_,e)),_)) =
+ acc @ find_updated_vars e in
List.map (fun (var,(l,t)) -> E_aux (E_id var,(l,t)))
(dedup eqidtyp (List.fold_left f [] ps)) in
if vars = [] then
- let ps = List.map (fun (Pat_aux (Pat_exp (p,e),a)) -> Pat_aux (Pat_exp (p,rewrite_var_updates e),a)) ps in
+ let ps = List.map (function
+ | Pat_aux (Pat_exp (p,e),a) ->
+ Pat_aux (Pat_exp (p,rewrite_var_updates e),a)
+ | Pat_aux (Pat_when (p,g,e),a) ->
+ Pat_aux (Pat_when (p,g,rewrite_var_updates e),a)) ps in
Same_vars (E_aux (E_case (e1,ps),annot))
else
let vartuple = mktup el vars in
let typ =
- let (Pat_aux (Pat_exp (_,first),_)) = List.hd ps in
+ let (Pat_aux ((Pat_exp (_,first)|Pat_when (_,_,first)),_)) = List.hd ps in
typ_of first in
let (ps,typ,effs) =
let f (acc,typ,effs) (Pat_aux (Pat_exp (p,e),pannot)) =
@@ -2856,9 +2961,10 @@ let rewrite_defs_remove_e_assign =
let rewrite_defs_lem =
top_sort_defs >>
+ rewrite_sizeof >>
rewrite_defs_remove_vector_concat >>
rewrite_defs_remove_bitvector_pats >>
- rewrite_sizeof >>
+ rewrite_defs_guarded_pats >>
rewrite_defs_exp_lift_assign >>
rewrite_defs_remove_blocks >>
rewrite_defs_letbind_effects >>
diff --git a/src/rewriter.mli b/src/rewriter.mli
index b2b0bf5e..473456f6 100644
--- a/src/rewriter.mli
+++ b/src/rewriter.mli
@@ -73,6 +73,7 @@ type ('a,'pat,'pat_aux,'fpat,'fpat_aux) pat_alg =
; p_vector_concat : 'pat list -> 'pat_aux
; p_tup : 'pat list -> 'pat_aux
; p_list : 'pat list -> 'pat_aux
+ ; p_cons : 'pat * 'pat -> 'pat_aux
; p_aux : 'pat_aux * 'a annot -> 'pat
; fP_aux : 'fpat_aux * 'a annot -> 'fpat
; fP_Fpat : id * 'pat -> 'fpat_aux
@@ -117,6 +118,8 @@ type ('a,'exp,'exp_aux,'lexp,'lexp_aux,'fexp,'fexp_aux,'fexps,'fexps_aux,
; e_internal_cast : 'a annot * 'exp -> 'exp_aux
; e_internal_exp : 'a annot -> 'exp_aux
; e_internal_exp_user : 'a annot * 'a annot -> 'exp_aux
+ ; e_comment : string -> 'exp_aux
+ ; e_comment_struc : 'exp -> 'exp_aux
; e_internal_let : 'lexp * 'exp * 'exp -> 'exp_aux
; e_internal_plet : 'pat * 'exp * 'exp -> 'exp_aux
; e_internal_return : 'exp -> 'exp_aux
@@ -137,6 +140,7 @@ type ('a,'exp,'exp_aux,'lexp,'lexp_aux,'fexp,'fexp_aux,'fexps,'fexps_aux,
; def_val_dec : 'exp -> 'opt_default_aux
; def_val_aux : 'opt_default_aux * 'a annot -> 'opt_default
; pat_exp : 'pat * 'exp -> 'pexp_aux
+ ; pat_when : 'pat * 'exp * 'exp -> 'pexp_aux
; pat_aux : 'pexp_aux * 'a annot -> 'pexp
; lB_val_explicit : typschm * 'pat * 'exp -> 'letbind_aux
; lB_val_implicit : 'pat * 'exp -> 'letbind_aux
diff --git a/src/sail.ml b/src/sail.ml
index 3500b213..c7c14a67 100644
--- a/src/sail.ml
+++ b/src/sail.ml
@@ -132,7 +132,7 @@ let main() =
let ast =
List.fold_right (fun (_,(Parse_ast.Defs ast_nodes)) (Parse_ast.Defs later_nodes)
-> Parse_ast.Defs (ast_nodes@later_nodes)) parsed (Parse_ast.Defs []) in
- let ast = convert_ast ast in
+ let ast = convert_ast Type_check.inc_ord ast in
let (ast, type_envs) = check_ast ast in
let (ast, type_envs) =
diff --git a/src/spec_analysis.ml b/src/spec_analysis.ml
index 1447ff02..fdd56ecc 100644
--- a/src/spec_analysis.ml
+++ b/src/spec_analysis.ml
@@ -357,6 +357,11 @@ and fv_of_pes consider_var bound used set pes =
let bound_p,us_p = pat_bindings consider_var bound used p in
let bound_e,us_e,set_e = fv_of_exp consider_var bound_p us_p set e in
fv_of_pes consider_var bound us_e set_e pes
+ | Pat_aux(Pat_when (p,g,e),_)::pes ->
+ let bound_p,us_p = pat_bindings consider_var bound used p in
+ let bound_g,us_g,set_g = fv_of_exp consider_var bound_p us_p set g in
+ let bound_e,us_e,set_e = fv_of_exp consider_var bound_g us_g set_g e in
+ fv_of_pes consider_var bound us_e set_e pes
and fv_of_let consider_var bound used set (LB_aux(lebind,_)) = match lebind with
| LB_val_explicit(typsch,pat,exp) ->
diff --git a/src/type_check.ml b/src/type_check.ml
index 3c133405..ca9c3618 100644
--- a/src/type_check.ml
+++ b/src/type_check.ml
@@ -73,6 +73,11 @@ let deinfix = function
| Id_aux (Id v, l) -> Id_aux (DeIid v, l)
| Id_aux (DeIid v, l) -> Id_aux (DeIid v, l)
+let field_name rec_id id =
+ match rec_id, id with
+ | Id_aux (Id r, _), Id_aux (Id v, l) -> Id_aux (Id (r ^ "." ^ v), l)
+ | _, _ -> assert false
+
let string_of_bind (typquant, typ) = string_of_typquant typquant ^ ". " ^ string_of_typ typ
let unaux_nexp (Nexp_aux (nexp, _)) = nexp
@@ -133,6 +138,9 @@ let nc_gteq n1 n2 = NC_aux (NC_bounded_ge (n1, n2), Parse_ast.Unknown)
let nc_lt n1 n2 = nc_lteq n1 (nsum n2 (nconstant 1))
let nc_gt n1 n2 = nc_gteq n1 (nsum n2 (nconstant 1))
let nc_and nc1 nc2 = mk_nc (NC_and (nc1, nc2))
+let nc_or nc1 nc2 = mk_nc (NC_or (nc1, nc2))
+let nc_true = mk_nc NC_true
+let nc_false = mk_nc NC_false
let mk_lit l = E_aux (E_lit (L_aux (l, Parse_ast.Unknown)), (Parse_ast.Unknown, ()))
@@ -145,6 +153,8 @@ let rec nc_negate (NC_aux (nc, _)) =
| NC_not_equal (n1, n2) -> nc_eq n1 n2
| NC_and (n1, n2) -> mk_nc (NC_or (nc_negate n1, nc_negate n2))
| NC_or (n1, n2) -> mk_nc (NC_and (nc_negate n1, nc_negate n2))
+ | NC_false -> mk_nc NC_true
+ | NC_true -> mk_nc NC_false
| NC_nat_set_bounded (kid, []) -> typ_error Parse_ast.Unknown "Cannot negate empty nexp set"
| NC_nat_set_bounded (kid, [int]) -> nc_neq (nvar kid) (nconstant int)
| NC_nat_set_bounded (kid, int :: ints) ->
@@ -208,7 +218,6 @@ let is_typ_kopt = function
| KOpt_aux (KOpt_kind (K_aux (K_kind [BK_aux (BK_type, _)], _), _), _) -> true
| _ -> false
-
(**************************************************************************)
(* 1. Substitutions *)
(**************************************************************************)
@@ -240,6 +249,8 @@ and nc_subst_nexp_aux l sv subst = function
else set_nc
| NC_or (nc1, nc2) -> NC_or (nc_subst_nexp sv subst nc1, nc_subst_nexp sv subst nc2)
| NC_and (nc1, nc2) -> NC_and (nc_subst_nexp sv subst nc1, nc_subst_nexp sv subst nc2)
+ | NC_false -> NC_false
+ | NC_true -> NC_true
let rec typ_subst_nexp sv subst (Typ_aux (typ, l)) = Typ_aux (typ_subst_nexp_aux sv subst typ, l)
and typ_subst_nexp_aux sv subst = function
@@ -374,7 +385,7 @@ module Env : sig
val is_union_constructor : id -> t -> bool
val add_record : id -> typquant -> (typ * id) list -> t -> t
val is_record : id -> t -> bool
- val get_accessor : id -> t -> typquant * typ
+ val get_accessor : id -> id -> t -> typquant * typ
val add_local : id -> mut * typ -> t -> t
val add_variant : id -> typquant * type_union list -> t -> t
val add_union_id : id -> typquant * typ -> t -> t
@@ -613,18 +624,18 @@ end = struct
in
let fold_accessors accs (typ, fid) =
let acc_typ = mk_typ (Typ_fn (rectyp, typ, Effect_aux (Effect_set [], Parse_ast.Unknown))) in
- typ_print (indent 1 ^ "Adding accessor " ^ string_of_id fid ^ " :: " ^ string_of_bind (typq, acc_typ));
- Bindings.add fid (typq, acc_typ) accs
+ typ_print (indent 1 ^ "Adding accessor " ^ string_of_id id ^ "." ^ string_of_id fid ^ " :: " ^ string_of_bind (typq, acc_typ));
+ Bindings.add (field_name id fid) (typq, acc_typ) accs
in
{ env with records = Bindings.add id (typq, fields) env.records;
accessors = List.fold_left fold_accessors env.accessors fields }
end
- let get_accessor id env =
+ let get_accessor rec_id id env =
let freshen_bind bind = List.fold_left (fun bind (kid, _) -> freshen_kid env kid bind) bind (KBindings.bindings env.typ_vars) in
- try freshen_bind (Bindings.find id env.accessors)
+ try freshen_bind (Bindings.find (field_name rec_id id) env.accessors)
with
- | Not_found -> typ_error (id_loc id) ("No accessor found for " ^ string_of_id id)
+ | Not_found -> typ_error (id_loc id) ("No accessor found for " ^ string_of_id (field_name rec_id id))
let is_mutable id env =
try
@@ -776,6 +787,7 @@ end = struct
| NC_nat_set_bounded (kid, ints) -> () (* MAYBE: We could demand that ints are all unique here *)
| NC_or (nc1, nc2) -> wf_constraint env nc1; wf_constraint env nc2
| NC_and (nc1, nc2) -> wf_constraint env nc1; wf_constraint env nc2
+ | NC_true | NC_false -> ()
let get_constraints env = env.constraints
@@ -1045,6 +1057,8 @@ let rec nc_constraint var_of (NC_aux (nc, l)) =
(List.map (fun i -> Constraint.eq (nexp_constraint var_of (nvar kid)) (Constraint.constant (big_int_of_int i))) ints)
| NC_or (nc1, nc2) -> Constraint.disj (nc_constraint var_of nc1) (nc_constraint var_of nc2)
| NC_and (nc1, nc2) -> Constraint.conj (nc_constraint var_of nc1) (nc_constraint var_of nc2)
+ | NC_false -> Constraint.literal false
+ | NC_true -> Constraint.literal true
let rec nc_constraints var_of ncs =
match ncs with
@@ -1085,6 +1099,8 @@ let prove env (NC_aux (nc_aux, _) as nc) =
| NC_fixed (nexp1, nexp2) when compare_const (fun c1 c2 -> c1 <> c2) (nexp_simp nexp1) (nexp_simp nexp2) -> false
| NC_bounded_le (nexp1, nexp2) when compare_const (fun c1 c2 -> c1 > c2) (nexp_simp nexp1) (nexp_simp nexp2) -> false
| NC_bounded_ge (nexp1, nexp2) when compare_const (fun c1 c2 -> c1 < c2) (nexp_simp nexp1) (nexp_simp nexp2) -> false
+ | NC_true -> true
+ | NC_false -> false
| _ -> prove_z3 env nc
let rec subtyp_tnf env tnf1 tnf2 =
@@ -1600,6 +1616,24 @@ let restrict_range_lower c1 (Typ_aux (typ_aux, l) as typ) =
range_typ (nconstant (max c1 c2)) nexp
| _ -> typ
+exception Not_a_constraint;;
+
+let rec assert_nexp (E_aux (exp_aux, l)) =
+ match exp_aux with
+ | E_sizeof nexp -> nexp
+ | E_lit (L_aux (L_num n, _)) -> nconstant n
+ | _ -> raise Not_a_constraint
+
+let rec assert_constraint (E_aux (exp_aux, l)) =
+ match exp_aux with
+ | E_app_infix (x, op, y) when string_of_id op = "|" ->
+ nc_or (assert_constraint x) (assert_constraint y)
+ | E_app_infix (x, op, y) when string_of_id op = "&" ->
+ nc_and (assert_constraint x) (assert_constraint y)
+ | E_app_infix (x, op, y) when string_of_id op = "==" ->
+ nc_eq (assert_nexp x) (assert_nexp y)
+ | _ -> nc_true
+
type flow_constraint =
| Flow_lteq of int
| Flow_gteq of int
@@ -1725,7 +1759,7 @@ let rec check_exp env (E_aux (exp_aux, (l, ())) as exp : unit exp) (Typ_aux (typ
| E_block exps, _ ->
begin
let rec check_block l env exps typ = match exps with
- | [] -> typ_error l "Empty block found"
+ | [] -> typ_equality l env typ unit_typ; []
| [exp] -> [crule check_exp env exp typ]
| (E_aux (E_assign (lexp, bind), _) :: exps) ->
let texp, env = bind_assignment env lexp bind in
@@ -1734,6 +1768,14 @@ let rec check_exp env (E_aux (exp_aux, (l, ())) as exp : unit exp) (Typ_aux (typ
typ_print ("Adding constraint " ^ string_of_n_constraint nc ^ " for assert");
let inferred_exp = irule infer_exp env exp in
inferred_exp :: check_block l (Env.add_constraint nc env) exps typ
+ | ((E_aux (E_assert (const_expr, assert_msg), _) as exp) :: exps) ->
+ begin
+ try
+ let nc = assert_constraint const_expr in
+ check_block l (Env.add_constraint nc env) exps typ
+ with
+ | Not_a_constraint -> check_block l env exps typ
+ end
| (exp :: exps) ->
let texp = crule check_exp env exp (mk_typ (Typ_id (mk_id "unit"))) in
texp :: check_block l env exps typ
@@ -1797,7 +1839,7 @@ let rec check_exp env (E_aux (exp_aux, (l, ())) as exp : unit exp) (Typ_aux (typ
begin
let (start, len, ord, vtyp) = destructure_vec_typ l env typ in
let checked_items = List.map (fun i -> crule check_exp env i vtyp) vec in
- match len with
+ match nexp_simp len with
| Nexp_aux (Nexp_constant lenc, _) ->
if List.length vec = lenc then annot_exp (E_vector checked_items) typ
else typ_error l "List length didn't match" (* FIXME: improve error message *)
@@ -1932,10 +1974,17 @@ and bind_pat env (P_aux (pat_aux, (l, ())) as pat) (Typ_aux (typ_aux, _) as typ)
annot_pat (P_list pats) typ, env
| _ -> typ_error l "Cannot match list pattern against non-list type"
end
+ | P_tup [] ->
+ begin
+ match Env.expand_synonyms env typ with
+ | Typ_aux (Typ_id typ_id, _) when string_of_id typ_id = "unit" ->
+ annot_pat (P_tup []) typ, env
+ | _ -> typ_error l "Cannot match unit pattern against non-unit type"
+ end
| P_tup pats ->
begin
- match typ_aux with
- | Typ_tup typs ->
+ match Env.expand_synonyms env typ with
+ | Typ_aux (Typ_tup typs, _) ->
let tpats, env =
try List.fold_left2 bind_tuple_pat ([], env) pats typs with
| Invalid_argument _ -> typ_error l "Tuple pattern and tuple type have different length"
@@ -2040,24 +2089,27 @@ and bind_assignment env (LEXP_aux (lexp_aux, _) as lexp) (E_aux (_, (l, ())) as
let infer_flexp = function
| LEXP_id v ->
begin match Env.lookup_id v env with
- | Register typ -> typ, LEXP_id v
- | _ -> typ_error l "l-expression field is not a register"
+ | Register typ -> typ, LEXP_id v, true
+ | Local (Mutable, typ) -> typ, LEXP_id v, false
+ | _ -> typ_error l "l-expression field is not a register or a local mutable type"
end
| LEXP_vector (LEXP_aux (LEXP_id v, _), exp) ->
begin
(* Check: is this ok if the vector is immutable? *)
- let is_immutable, vtyp = match Env.lookup_id v env with
+ let is_immutable, vtyp, is_register = match Env.lookup_id v env with
| Unbound -> typ_error l "Cannot assign to element of unbound vector"
| Enum _ -> typ_error l "Cannot vector assign to enumeration element"
- | Local (Immutable, vtyp) -> true, vtyp
- | Local (Mutable, vtyp) | Register vtyp -> false, vtyp
+ | Local (Immutable, vtyp) -> true, vtyp, false
+ | Local (Mutable, vtyp) -> false, vtyp, false
+ | Register vtyp -> false, vtyp, true
in
let access = infer_exp (Env.enable_casts env) (E_aux (E_app (mk_id "vector_access", [E_aux (E_id v, (l, ())); exp]), (l, ()))) in
let E_aux (E_app (_, [_; inferred_exp]), _) = access in
- typ_of access, LEXP_vector (annot_lexp (LEXP_id v) vtyp, inferred_exp)
+ typ_of access, LEXP_vector (annot_lexp (LEXP_id v) vtyp, inferred_exp), is_register
end
in
- let regtyp, inferred_flexp = infer_flexp flexp in
+ let regtyp, inferred_flexp, is_register = infer_flexp flexp in
+ let eff = if is_register then mk_effect [BE_wreg] else no_effect in
typ_debug ("REGTYP: " ^ string_of_typ regtyp ^ " / " ^ string_of_typ (Env.expand_synonyms env regtyp));
match Env.expand_synonyms env regtyp with
| Typ_aux (Typ_id regtyp_id, _) when Env.is_regtyp regtyp_id env ->
@@ -2074,13 +2126,13 @@ and bind_assignment env (LEXP_aux (lexp_aux, _) as lexp) (E_aux (_, (l, ())) as
| _, _ -> typ_error l "Not implemented this register field type yet..."
in
let checked_exp = crule check_exp env exp vec_typ in
- annot_assign (annot_lexp (LEXP_field (annot_lexp_effect inferred_flexp regtyp (mk_effect [BE_wreg]), field)) vec_typ) checked_exp, env
+ annot_assign (annot_lexp (LEXP_field (annot_lexp_effect inferred_flexp regtyp eff, field)) vec_typ) checked_exp, env
| Typ_aux (Typ_id rectyp_id, _) | Typ_aux (Typ_app (rectyp_id, _), _) when Env.is_record rectyp_id env ->
let (typq, Typ_aux (Typ_fn (rectyp_q, field_typ, _), _)) = Env.get_accessor field env in
let unifiers, _, _ (* FIXME *) = try unify l env rectyp_q regtyp with Unification_error (l, m) -> typ_error l ("Unification error: " ^ m) in
let field_typ' = subst_unifiers unifiers field_typ in
let checked_exp = crule check_exp env exp field_typ' in
- annot_assign (annot_lexp (LEXP_field (annot_lexp_effect inferred_flexp regtyp (mk_effect [BE_wreg]), field)) field_typ') checked_exp, env
+ annot_assign (annot_lexp (LEXP_field (annot_lexp_effect inferred_flexp regtyp eff, field)) field_typ') checked_exp, env
| _ -> typ_error l "Field l-expression has invalid type"
end
| LEXP_memory (f, xs) ->
@@ -2254,7 +2306,7 @@ and infer_exp env (E_aux (exp_aux, (l, ())) as exp) =
(* Accessing a field of a record *)
| Typ_aux (Typ_id rectyp, _) as typ when Env.is_record rectyp env ->
begin
- let inferred_acc, _ = infer_funapp' l (Env.no_casts env) field (Env.get_accessor field env) [strip_exp inferred_exp] None in
+ let inferred_acc, _ = infer_funapp' l (Env.no_casts env) field (Env.get_accessor rectyp field env) [strip_exp inferred_exp] None in
match inferred_acc with
| E_aux (E_app (field, [inferred_exp]) ,_) -> annot_exp (E_field (inferred_exp, field)) (typ_of inferred_acc)
| _ -> assert false (* Unreachable *)
diff --git a/src/type_check.mli b/src/type_check.mli
index 647feaaa..a2b8a10c 100644
--- a/src/type_check.mli
+++ b/src/type_check.mli
@@ -87,7 +87,7 @@ module Env : sig
val is_record : id -> t -> bool
- val get_accessor : id -> t -> typquant * typ
+ val get_accessor : id -> id -> t -> typquant * typ
(* If the environment is checking a function, then this will get the
expected return type of the function. It's useful for checking or
@@ -105,6 +105,8 @@ module Env : sig
won't throw any exceptions. *)
val lookup_id : id -> t -> lvar
+ val is_union_constructor : id -> t -> bool
+
(* Return a fresh kind identifier that doesn't exist in the environment *)
val fresh_kid : t -> kid