summaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
Diffstat (limited to 'src')
-rw-r--r--src/ast.ml2
-rw-r--r--src/ast_util.ml5
-rw-r--r--src/ast_util.mli1
-rw-r--r--src/initial_check.ml2
-rw-r--r--src/parse_ast.ml2
-rw-r--r--src/parser.mly4
-rw-r--r--src/parser2.mly10
-rw-r--r--src/pretty_print_lem.ml2
-rw-r--r--src/pretty_print_lem_ast.ml2
-rw-r--r--src/pretty_print_sail.ml5
-rw-r--r--src/rewriter.ml20
-rw-r--r--src/rewriter.mli2
-rw-r--r--src/type_check.ml34
13 files changed, 54 insertions, 37 deletions
diff --git a/src/ast.ml b/src/ast.ml
index c4b225f1..526314ca 100644
--- a/src/ast.ml
+++ b/src/ast.ml
@@ -260,7 +260,7 @@ type
| P_as of 'a pat * id (* named pattern *)
| P_typ of typ * 'a pat (* typed pattern *)
| P_id of id (* identifier *)
- | P_var of kid (* bind identifier and type variable *)
+ | P_var of 'a pat * kid (* bind pattern to type variable *)
| P_app of id * ('a pat) list (* union constructor pattern *)
| P_record of ('a fpat) list * bool (* struct pattern *)
| P_vector of ('a pat) list (* vector pattern *)
diff --git a/src/ast_util.ml b/src/ast_util.ml
index 61633236..200c63fe 100644
--- a/src/ast_util.ml
+++ b/src/ast_util.ml
@@ -61,6 +61,7 @@ let mk_exp exp_aux = E_aux (exp_aux, no_annot)
let unaux_exp (E_aux (exp_aux, _)) = exp_aux
let mk_pat pat_aux = P_aux (pat_aux, no_annot)
+let unaux_pat (P_aux (pat_aux, _)) = pat_aux
let mk_lexp lexp_aux = LEXP_aux (lexp_aux, no_annot)
@@ -269,7 +270,7 @@ and map_pat_annot_aux f = function
| P_as (pat, id) -> P_as (map_pat_annot f pat, id)
| P_typ (typ, pat) -> P_typ (typ, map_pat_annot f pat)
| P_id id -> P_id id
- | P_var kid -> P_var kid
+ | P_var (pat, kid) -> P_var (map_pat_annot f pat, kid)
| P_app (id, pats) -> P_app (id, List.map (map_pat_annot f) pats)
| P_record (fpats, b) -> P_record (List.map (map_fpat_annot f) fpats, b)
| P_tup pats -> P_tup (List.map (map_pat_annot f) pats)
@@ -502,7 +503,7 @@ and string_of_pat (P_aux (pat, l)) =
| P_lit lit -> string_of_lit lit
| P_wild -> "_"
| P_id v -> string_of_id v
- | P_var kid -> string_of_kid kid
+ | P_var (pat, kid) -> string_of_pat pat ^ " as " ^ string_of_kid kid
| P_typ (typ, pat) -> "(" ^ string_of_typ typ ^ ") " ^ string_of_pat pat
| P_tup pats -> "(" ^ string_of_list ", " string_of_pat pats ^ ")"
| P_app (f, pats) -> string_of_id f ^ "(" ^ string_of_list ", " string_of_pat pats ^ ")"
diff --git a/src/ast_util.mli b/src/ast_util.mli
index 33d65ede..18223f4a 100644
--- a/src/ast_util.mli
+++ b/src/ast_util.mli
@@ -67,6 +67,7 @@ val mk_fexps : (unit fexp) list -> unit fexps
val mk_letbind : unit pat -> unit exp -> unit letbind
val unaux_exp : 'a exp -> 'a exp_aux
+val unaux_pat : 'a pat -> 'a pat_aux
val inc_ord : order
val dec_ord : order
diff --git a/src/initial_check.ml b/src/initial_check.ml
index 9f5fd4e6..42f4e1dd 100644
--- a/src/initial_check.ml
+++ b/src/initial_check.ml
@@ -418,7 +418,7 @@ let rec to_ast_pat (k_env : kind Envmap.t) (def_ord : order) (Parse_ast.P_aux(pa
| Parse_ast.P_as(pat,id) -> P_as(to_ast_pat k_env def_ord pat,to_ast_id id)
| Parse_ast.P_typ(typ,pat) -> P_typ(to_ast_typ k_env def_ord typ,to_ast_pat k_env def_ord pat)
| Parse_ast.P_id(id) -> P_id(to_ast_id id)
- | Parse_ast.P_var kid -> P_var (to_ast_var kid)
+ | Parse_ast.P_var (pat, kid) -> P_var (to_ast_pat k_env def_ord pat, to_ast_var kid)
| Parse_ast.P_app(id,pats) ->
if pats = []
then P_id (to_ast_id id)
diff --git a/src/parse_ast.ml b/src/parse_ast.ml
index 73f75919..c7365d03 100644
--- a/src/parse_ast.ml
+++ b/src/parse_ast.ml
@@ -234,7 +234,7 @@ pat_aux = (* Pattern *)
| P_as of pat * id (* named pattern *)
| P_typ of atyp * pat (* typed pattern *)
| P_id of id (* identifier *)
- | P_var of kid
+ | P_var of pat * kid (* bind pat to type variable *)
| P_app of id * (pat) list (* union constructor pattern *)
| P_record of (fpat) list * bool (* struct pattern *)
| P_vector of (pat) list (* vector pattern *)
diff --git a/src/parser.mly b/src/parser.mly
index dd0981fc..0508f116 100644
--- a/src/parser.mly
+++ b/src/parser.mly
@@ -49,6 +49,8 @@ open Parse_ast
let loc () = Range(Parsing.symbol_start_pos(),Parsing.symbol_end_pos())
let locn m n = Range(Parsing.rhs_start_pos m,Parsing.rhs_end_pos n)
+let id_of_kid (Kid_aux (Var str, l)) = Id_aux (Id str, l)
+
let idl i = Id_aux(i, loc())
let string_of_id = function
@@ -506,7 +508,7 @@ atomic_pat:
| id
{ ploc (P_app($1,[])) }
| tyvar
- { ploc (P_var $1) }
+ { ploc (P_var (ploc (P_id (id_of_kid $1)), $1)) }
| Lcurly fpats Rcurly
{ ploc (P_record((fst $2, snd $2))) }
| Lsquare comma_pats Rsquare
diff --git a/src/parser2.mly b/src/parser2.mly
index c00287e0..a735c3f5 100644
--- a/src/parser2.mly
+++ b/src/parser2.mly
@@ -51,6 +51,8 @@ let loc n m = Range (m, n)
let mk_id i n m = Id_aux (i, loc m n)
let mk_kid str n m = Kid_aux (Var str, loc n m)
+let id_of_kid (Kid_aux (Var str, l)) = Id_aux (Id str, l)
+
let deinfix (Id_aux (Id v, l)) = Id_aux (DeIid v, l)
let mk_effect e n m = BE_aux (e, loc n m)
@@ -604,7 +606,7 @@ atomic_pat:
| id
{ mk_pat (P_id $1) $startpos $endpos }
| kid
- { mk_pat (P_var $1) $startpos $endpos }
+ { mk_pat (P_var (mk_pat (P_id (id_of_kid $1)) $startpos $endpos, $1)) $startpos $endpos }
| id Lparen pat_list Rparen
{ mk_pat (P_app ($1, $3)) $startpos $endpos }
| pat Colon typ
@@ -943,6 +945,12 @@ let_def:
val_spec_def:
| Val id Colon typschm
{ mk_vs (VS_val_spec ($4, $2, None, false)) $startpos $endpos }
+ | Val Cast id Colon typschm
+ { mk_vs (VS_val_spec ($5, $3, None, true)) $startpos $endpos }
+ | Val id Eq String Colon typschm
+ { mk_vs (VS_val_spec ($6, $2, Some $4, false)) $startpos $endpos }
+ | Val Cast id Eq String Colon typschm
+ { mk_vs (VS_val_spec ($7, $3, Some $5, true)) $startpos $endpos }
register_def:
| Register id Colon typ
diff --git a/src/pretty_print_lem.ml b/src/pretty_print_lem.ml
index d40381b9..29f4cbf2 100644
--- a/src/pretty_print_lem.ml
+++ b/src/pretty_print_lem.ml
@@ -323,7 +323,7 @@ let rec doc_pat_lem sequential mwords apat_needed (P_aux (p,(l,annot)) as pa) =
begin match id with
| Id_aux (Id "None",_) -> string "Nothing" (* workaround temporary issue *)
| _ -> doc_id_lem id end
- | P_var kid -> doc_var_lem kid
+ | P_var(p,kid) -> parens (separate space [doc_pat_lem sequential mwords true p; string "as"; doc_var_lem kid])
| P_as(p,id) -> parens (separate space [doc_pat_lem sequential mwords true p; string "as"; doc_id_lem id])
| P_typ(typ,p) ->
let doc_p = doc_pat_lem sequential mwords true p in
diff --git a/src/pretty_print_lem_ast.ml b/src/pretty_print_lem_ast.ml
index a7631843..2c9cf83c 100644
--- a/src/pretty_print_lem_ast.ml
+++ b/src/pretty_print_lem_ast.ml
@@ -328,7 +328,7 @@ let rec pp_format_pat_lem (P_aux(p,(l,annot))) =
| P_lit(lit) -> "(P_lit " ^ pp_format_lit_lem lit ^ ")"
| P_wild -> "P_wild"
| P_id(id) -> "(P_id " ^ pp_format_id_lem id ^ ")"
- | P_var(kid) -> "(P_var " ^ pp_format_var_lem kid ^ ")"
+ | P_var(pat,kid) -> "(P_var " ^ pp_format_pat_lem pat ^ " " ^ pp_format_var_lem kid ^ ")"
| P_as(pat,id) -> "(P_as " ^ pp_format_pat_lem pat ^ " " ^ pp_format_id_lem id ^ ")"
| P_typ(typ,pat) -> "(P_typ " ^ pp_format_typ_lem typ ^ " " ^ pp_format_pat_lem pat ^ ")"
| P_app(id,pats) -> "(P_app " ^ pp_format_id_lem id ^ " [" ^
diff --git a/src/pretty_print_sail.ml b/src/pretty_print_sail.ml
index 72218bc3..6a099712 100644
--- a/src/pretty_print_sail.ml
+++ b/src/pretty_print_sail.ml
@@ -41,6 +41,7 @@
(**************************************************************************)
open Ast
+open Ast_util
open PPrint
open Pretty_print_common
@@ -109,7 +110,9 @@ let doc_pat, doc_atomic_pat =
| P_lit lit -> doc_lit lit
| P_wild -> underscore
| P_id id -> doc_id id
- | P_var kid -> doc_var kid
+ | P_var (P_aux (P_id id, _), kid) when Id.compare (id_of_kid kid) id == 0 ->
+ doc_var kid
+ | P_var(p,kid) -> parens (separate space [pat p; string "as"; doc_var kid])
| P_as(p,id) -> parens (separate space [pat p; string "as"; doc_id id])
| P_typ(typ,p) -> separate space [parens (doc_typ typ); atomic_pat p]
| P_app(id,[]) -> doc_id id
diff --git a/src/rewriter.ml b/src/rewriter.ml
index 5c6aa0d8..df264d1e 100644
--- a/src/rewriter.ml
+++ b/src/rewriter.ml
@@ -609,7 +609,7 @@ type ('a,'pat,'pat_aux,'fpat,'fpat_aux) pat_alg =
; p_as : 'pat * id -> 'pat_aux
; p_typ : Ast.typ * 'pat -> 'pat_aux
; p_id : id -> 'pat_aux
- ; p_var : kid -> 'pat_aux
+ ; p_var : 'pat * kid -> 'pat_aux
; p_app : id * 'pat list -> 'pat_aux
; p_record : 'fpat list * bool -> 'pat_aux
; p_vector : 'pat list -> 'pat_aux
@@ -627,8 +627,8 @@ let rec fold_pat_aux (alg : ('a,'pat,'pat_aux,'fpat,'fpat_aux) pat_alg) : 'a pat
| P_lit lit -> alg.p_lit lit
| P_wild -> alg.p_wild
| P_id id -> alg.p_id id
- | P_var kid -> alg.p_var kid
- | P_as (p,id) -> alg.p_as (fold_pat alg p,id)
+ | P_var (p, kid) -> alg.p_var (fold_pat alg p, kid)
+ | P_as (p,id) -> alg.p_as (fold_pat alg p, id)
| P_typ (typ,p) -> alg.p_typ (typ,fold_pat alg p)
| P_app (id,ps) -> alg.p_app (id,List.map (fold_pat alg) ps)
| P_record (ps,b) -> alg.p_record (List.map (fold_fpat alg) ps, b)
@@ -656,7 +656,7 @@ let id_pat_alg : ('a,'a pat, 'a pat_aux, 'a fpat, 'a fpat_aux) pat_alg =
; p_as = (fun (pat,id) -> P_as (pat,id))
; p_typ = (fun (typ,pat) -> P_typ (typ,pat))
; p_id = (fun id -> P_id id)
- ; p_var = (fun kid -> P_var kid)
+ ; p_var = (fun (pat,kid) -> P_var (pat,kid))
; p_app = (fun (id,ps) -> P_app (id,ps))
; p_record = (fun (ps,b) -> P_record (ps,b))
; p_vector = (fun ps -> P_vector ps)
@@ -883,7 +883,7 @@ let compute_pat_alg bot join =
; p_as = (fun ((v,pat),id) -> (v, P_as (pat,id)))
; p_typ = (fun (typ,(v,pat)) -> (v, P_typ (typ,pat)))
; p_id = (fun id -> (bot, P_id id))
- ; p_var = (fun kid -> (bot, P_var kid))
+ ; p_var = (fun ((v,pat),kid) -> (v, P_var (pat,kid)))
; p_app = (fun (id,ps) -> split_join (fun ps -> P_app (id,ps)) ps)
; p_record = (fun (ps,b) -> split_join (fun ps -> P_record (ps,b)) ps)
; p_vector = split_join (fun ps -> P_vector ps)
@@ -1331,7 +1331,7 @@ let remove_vector_concat_pat pat =
; p_wild = P_wild
; p_as = (fun (pat,id) -> P_as (pat true,id))
; p_id = (fun id -> P_id id)
- ; p_var = (fun kid -> P_var kid)
+ ; p_var = (fun (pat,kid) -> P_var (pat true,kid))
; p_app = (fun (id,ps) -> P_app (id, List.map (fun p -> p false) ps))
; p_record = (fun (fpats,b) -> P_record (fpats, b))
; p_vector = (fun ps -> P_vector (List.map (fun p -> p false) ps))
@@ -1461,7 +1461,7 @@ let remove_vector_concat_pat pat =
; p_as = (fun ((pat,decls),id) -> (P_as (pat,id),decls))
; p_typ = (fun (typ,(pat,decls)) -> (P_typ (typ,pat),decls))
; p_id = (fun id -> (P_id id,[]))
- ; p_var = (fun kid -> (P_var kid, []))
+ ; p_var = (fun ((pat,decls),kid) -> (P_var (pat,kid),decls))
; p_app = (fun (id,ps) -> let (ps,decls) = List.split ps in
(P_app (id,ps),List.flatten decls))
; p_record = (fun (ps,b) -> let (ps,decls) = List.split ps in
@@ -1816,7 +1816,7 @@ let remove_bitvector_pat pat =
; p_wild = P_wild
; p_as = (fun (pat,id) -> P_as (pat true,id))
; p_id = (fun id -> P_id id)
- ; p_var = (fun kid -> P_var kid)
+ ; p_var = (fun (pat,kid) -> P_var (pat true,kid))
; p_app = (fun (id,ps) -> P_app (id, List.map (fun p -> p false) ps))
; p_record = (fun (fpats,b) -> P_record (fpats, b))
; p_vector = (fun ps -> P_vector (List.map (fun p -> p false) ps))
@@ -1976,7 +1976,7 @@ let remove_bitvector_pat pat =
; p_as = (fun ((pat,gdls),id) -> (P_as (pat,id), gdls))
; p_typ = (fun (typ,(pat,gdls)) -> (P_typ (typ,pat), gdls))
; p_id = (fun id -> (P_id id, (None, (fun b -> b), [])))
- ; p_var = (fun kid -> (P_var kid, (None, (fun b -> b), [])))
+ ; p_var = (fun ((pat,gdls),kid) -> (P_var (pat,kid), gdls))
; p_app = (fun (id,ps) -> let (ps,gdls) = List.split ps in
(P_app (id,ps), flatten_guards_decls gdls))
; p_record = (fun (ps,b) -> let (ps,gdls) = List.split ps in
@@ -2499,7 +2499,7 @@ let rewrite_simple_types (Defs defs) =
let simple_pat = {
id_pat_alg with
p_typ = (fun (typ, pat) -> P_typ (simple_typ typ, pat));
- p_var = (fun kid -> P_id (id_of_kid kid));
+ p_var = (fun (pat, kid) -> unaux_pat pat);
p_vector = (fun pats -> P_list pats)
} in
let simple_exp = {
diff --git a/src/rewriter.mli b/src/rewriter.mli
index 32974bd0..2bf00b06 100644
--- a/src/rewriter.mli
+++ b/src/rewriter.mli
@@ -66,7 +66,7 @@ type ('a,'pat,'pat_aux,'fpat,'fpat_aux) pat_alg =
; p_as : 'pat * id -> 'pat_aux
; p_typ : Ast.typ * 'pat -> 'pat_aux
; p_id : id -> 'pat_aux
- ; p_var : kid -> 'pat_aux
+ ; p_var : 'pat * kid -> 'pat_aux
; p_app : id * 'pat list -> 'pat_aux
; p_record : 'fpat list * bool -> 'pat_aux
; p_vector : 'pat list -> 'pat_aux
diff --git a/src/type_check.ml b/src/type_check.ml
index 33c6ff6d..b585c85e 100644
--- a/src/type_check.ml
+++ b/src/type_check.ml
@@ -2107,22 +2107,22 @@ and bind_pat env (P_aux (pat_aux, (l, ())) as pat) (Typ_aux (typ_aux, _) as typ)
| Unification_error (l, m) -> typ_error l ("Unification error when pattern matching against union constructor: " ^ m)
end
end
- | P_var kid ->
+ | P_var (pat, kid) ->
+ let typ = Env.expand_synonyms env typ in
begin
- let v = id_of_kid kid in
- match Env.lookup_id v env with
- | Local (Immutable, _) | Unbound ->
- begin
- match destruct_exist env typ with
- | Some ([kid'], nc, typ) ->
- let env = Env.add_typ_var kid BK_nat env in
- let env = Env.add_constraint (nc_subst_nexp kid' (Nexp_var kid) nc) env in
- let env = Env.add_local v (Immutable, typ_subst_nexp kid' (Nexp_var kid) typ) env in
- annot_pat (P_var kid) typ, env
- | Some _ -> typ_error l ("Cannot bind type variable pattern against multiple argument existential")
- | None _ -> typ_error l ("Cannot bind type variable against non existential type")
- end
- | _ -> typ_error l ("Bad type identifer pattern: " ^ string_of_pat pat)
+ match destruct_exist env typ, typ with
+ | Some ([kid'], nc, ex_typ), _ ->
+ let env = Env.add_typ_var kid BK_nat env in
+ let ex_typ = typ_subst_nexp kid' (Nexp_var kid) ex_typ in
+ let env = Env.add_constraint (nc_subst_nexp kid' (Nexp_var kid) nc) env in
+ let typed_pat, env = bind_pat env pat ex_typ in
+ annot_pat (P_var (typed_pat, kid)) typ, env
+ | Some _, _ -> typ_error l ("Cannot bind type variable pattern against multiple argument existential")
+ | None, Typ_aux (Typ_id id, _) when Id.compare id (mk_id "int") == 0 ->
+ let env = Env.add_typ_var kid BK_nat env in
+ let typed_pat, env = bind_pat env pat (atom_typ (nvar kid)) in
+ annot_pat (P_var (typed_pat, kid)) typ, env
+ | None, _ -> typ_error l ("Cannot bind type variable against non existential type")
end
| P_wild -> annot_pat P_wild typ, env
| P_cons (hd_pat, tl_pat) ->
@@ -2940,7 +2940,9 @@ and propagate_pat_effect_aux = function
let p_pat = propagate_pat_effect pat in
P_typ (typ, p_pat), effect_of_pat p_pat
| P_id id -> P_id id, no_effect
- | P_var kid -> P_var kid, no_effect
+ | P_var (pat, kid) ->
+ let p_pat = propagate_pat_effect pat in
+ P_var (p_pat, kid), effect_of_pat p_pat
| P_app (id, pats) ->
let p_pats = List.map propagate_pat_effect pats in
P_app (id, p_pats), collect_effects_pat p_pats