summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorAlasdair Armstrong2017-07-17 18:46:00 +0100
committerAlasdair Armstrong2017-07-17 18:46:00 +0100
commit1090d8667193e3bc56bfc7a0d028566b36ad3b96 (patch)
treebd7c9d69120df927f4ae3ddf6477345f0c009f4f
parent6c75d9386a9c179969c22baf1231f1bd7b9a60a3 (diff)
Added pattern guards to sail
Introduces a when keyword for case statements, as the Pat_when constructor for pexp's in the AST. This allows us to write things like: typedef T = const union { int C1; int C2 } function int test ((int) x, (T) y) = switch y { case (C1(z)) when z == 0 -> 0 case (C1(z)) when z != 0 -> x quot z case (C2(z)) -> z } this should make translation from ASL's patterns much more straightforward
-rw-r--r--editors/sail-mode.el4
-rw-r--r--src/ast.ml3
-rw-r--r--src/ast_util.ml10
-rw-r--r--src/initial_check.ml2
-rw-r--r--src/lexer.mll1
-rw-r--r--src/parse_ast.ml1
-rw-r--r--src/parser.mly4
-rw-r--r--src/pretty_print_sail.ml8
-rw-r--r--src/type_check_new.ml46
-rw-r--r--test/typecheck/pass/guards.sail22
10 files changed, 81 insertions, 20 deletions
diff --git a/editors/sail-mode.el b/editors/sail-mode.el
index e1da042b..b667b084 100644
--- a/editors/sail-mode.el
+++ b/editors/sail-mode.el
@@ -641,7 +641,7 @@ though, if you also have `sail-electric-indent' on."
(mapc 'sail-define-abbrev
'("scattered" "function" "typedef" "let" "default" "val" "register"
"alias" "union" "member" "clause" "extern" "cast" "effect"
- "rec" "and" "switch" "case" "exit" "foreach" "from" "else"
+ "rec" "and" "switch" "case" "when" "exit" "foreach" "from" "else"
"to" "end" "downto" "in" "then" "with" "if" "nondet" "as"
"undefined" "const" "struct" "IN" "deinfix" "return" "sizeof"))
(setq abbrevs-changed nil))
@@ -711,7 +711,7 @@ Based on Tuareg mode. See Tuareg mode for usage"
`(("\\<\\(extern\\|cast\\|overload\\|deinfix\\|function\\|scattered\\|clause\\|effect\\|default\\|struct\\|const\\|union\\|val\\|typedef\\|in\\|let\\|rec\\|and\\|end\\|register\\|alias\\|member\\|enumerate\\)\\>"
0 sail-font-lock-governing-face nil nil)
("\\<\\(false\\|true\\|bitzero\\|bitone\\|0x[:xdigit:]\\|[:digit:]\\)\\>" 0 font-lock-constant-face nil nil)
- ("\\<\\(as\\|downto\\|else\\|foreach\\|if\\|t\\(hen\\|o\\)\\|when\\|switch\\|with\\|case\\|exit\\|sizeof\\|nondet\\|from\\|by\\|return\\)\\>"
+ ("\\<\\(as\\|downto\\|else\\|foreach\\|if\\|t\\(hen\\|o\\)\\|when\\|switch\\|with\\|case\\|when\\|exit\\|sizeof\\|nondet\\|from\\|by\\|return\\)\\>"
0 font-lock-keyword-face nil nil)
("\\<\\(clause\\)\\>[ \t\n]*\\(\\(\\w\\|[_ \t()*,]\\)+\\)"
2 font-lock-variable-name-face keep nil)
diff --git a/src/ast.ml b/src/ast.ml
index 0d418afc..5444e580 100644
--- a/src/ast.ml
+++ b/src/ast.ml
@@ -364,7 +364,8 @@ and 'a opt_default =
Def_val_aux of 'a opt_default_aux * 'a annot
and 'a pexp_aux = (* Pattern match *)
- Pat_exp of 'a pat * 'a exp
+ Pat_exp of 'a pat * 'a exp
+| Pat_when of 'a pat * 'a exp * 'a exp
and 'a pexp =
Pat_aux of 'a pexp_aux * 'a annot
diff --git a/src/ast_util.ml b/src/ast_util.ml
index 9e2b7a2d..fc59fcf2 100644
--- a/src/ast_util.ml
+++ b/src/ast_util.ml
@@ -93,7 +93,10 @@ and map_opt_default_annot_aux f = function
| Def_val_dec exp -> Def_val_dec (map_exp_annot f exp)
and map_fexps_annot f (FES_aux (FES_Fexps (fexps, b), annot)) = FES_aux (FES_Fexps (List.map (map_fexp_annot f) fexps, b), f annot)
and map_fexp_annot f (FE_aux (FE_Fexp (id, exp), annot)) = FE_aux (FE_Fexp (id, map_exp_annot f exp), f annot)
-and map_pexp_annot f (Pat_aux (Pat_exp (pat, exp), annot)) = Pat_aux (Pat_exp (map_pat_annot f pat, map_exp_annot f exp), f annot)
+and map_pexp_annot f (Pat_aux (pexp, annot)) = Pat_aux (map_pexp_annot_aux f pexp, f annot)
+and map_pexp_annot_aux f = function
+ | Pat_exp (pat, exp) -> Pat_exp (map_pat_annot f pat, map_exp_annot f exp)
+ | Pat_when (pat, guard, exp) -> Pat_when (map_pat_annot f pat, map_exp_annot f guard, map_exp_annot f exp)
and map_pat_annot f (P_aux (pat, annot)) = P_aux (map_pat_annot_aux f pat, f annot)
and map_pat_annot_aux f = function
| P_lit lit -> P_lit lit
@@ -278,7 +281,10 @@ let rec string_of_exp (E_aux (exp, _)) =
^ ") { "
^ string_of_exp body
| _ -> "INTERNAL"
-and string_of_pexp (Pat_aux (Pat_exp (pat, exp), _)) = string_of_pat pat ^ " -> " ^ string_of_exp exp
+and string_of_pexp (Pat_aux (pexp, _)) =
+ match pexp with
+ | Pat_exp (pat, exp) -> string_of_pat pat ^ " -> " ^ string_of_exp exp
+ | Pat_when (pat, guard, exp) -> string_of_pat pat ^ " when " ^ string_of_exp guard ^ " -> " ^ string_of_exp exp
and string_of_pat (P_aux (pat, l)) =
match pat with
| P_lit lit -> string_of_lit lit
diff --git a/src/initial_check.ml b/src/initial_check.ml
index df8fb678..89dad186 100644
--- a/src/initial_check.ml
+++ b/src/initial_check.ml
@@ -525,6 +525,8 @@ and to_ast_lexp (k_env : kind Envmap.t) (def_ord : order) (Parse_ast.E_aux(exp,l
and to_ast_case (k_env : kind Envmap.t) (def_ord : order) (Parse_ast.Pat_aux(pex,l) : Parse_ast.pexp) : tannot pexp =
match pex with
| Parse_ast.Pat_exp(pat,exp) -> Pat_aux(Pat_exp(to_ast_pat k_env def_ord pat, to_ast_exp k_env def_ord exp),(l,NoTyp))
+ | Parse_ast.Pat_when(pat,guard,exp) ->
+ Pat_aux (Pat_when (to_ast_pat k_env def_ord pat, to_ast_exp k_env def_ord guard, to_ast_exp k_env def_ord exp), (l, NoTyp))
and to_ast_fexps (fail_on_error:bool) (k_env:kind Envmap.t) (def_ord:order) (exps : Parse_ast.exp list) : tannot fexps option =
match exps with
diff --git a/src/lexer.mll b/src/lexer.mll
index 99965e20..36aefb1d 100644
--- a/src/lexer.mll
+++ b/src/lexer.mll
@@ -106,6 +106,7 @@ let kw_table =
("undefined", (fun x -> Undefined));
("union", (fun x -> Union));
("with", (fun x -> With));
+ ("when", (fun x -> When));
("val", (fun x -> Val));
("div", (fun x -> Div_));
diff --git a/src/parse_ast.ml b/src/parse_ast.ml
index 8b52b2ab..4529aa8f 100644
--- a/src/parse_ast.ml
+++ b/src/parse_ast.ml
@@ -305,6 +305,7 @@ and opt_default =
and pexp_aux = (* Pattern match *)
Pat_exp of pat * exp
+ | Pat_when of pat * exp * exp
and pexp =
Pat_aux of pexp_aux * l
diff --git a/src/parser.mly b/src/parser.mly
index 8e61a0ac..b64e0de6 100644
--- a/src/parser.mly
+++ b/src/parser.mly
@@ -131,7 +131,7 @@ let make_vector_sugar order_set is_inc typ typ1 =
%token And Alias As Assert Bitzero Bitone Bits By Case Clause Const Dec Def Default Deinfix Effect EFFECT End
%token Enumerate Else Exit Extern False Forall Foreach Overload Function_ If_ In IN Inc Let_ Member Nat NatNum Order Cast
%token Pure Rec Register Return Scattered Sizeof Struct Switch Then True TwoStarStar Type TYPE Typedef
-%token Undefined Union With Val
+%token Undefined Union With When Val
%token Barr Depend Rreg Wreg Rmem Rmemt Wmem Wmv Wmvt Eamem Exmem Undef Unspec Nondet Escape
@@ -962,6 +962,8 @@ case_exps:
patsexp:
| atomic_pat MinusGt exp
{ peloc (Pat_exp($1,$3)) }
+ | atomic_pat When exp MinusGt exp
+ { peloc (Pat_when ($1, $3, $5)) }
letbind:
| Let_ atomic_pat Eq exp
diff --git a/src/pretty_print_sail.ml b/src/pretty_print_sail.ml
index 6826087a..a484bd1f 100644
--- a/src/pretty_print_sail.ml
+++ b/src/pretty_print_sail.ml
@@ -364,8 +364,12 @@ let doc_exp, doc_let =
and doc_fexp (FE_aux(FE_Fexp(id,e),_)) = doc_op equals (doc_id id) (exp e)
- and doc_case (Pat_aux(Pat_exp(pat,e),_)) =
- doc_op arrow (separate space [string "case"; doc_atomic_pat pat]) (group (exp e))
+ and doc_case (Pat_aux (pexp, _)) =
+ match pexp with
+ | Pat_exp(pat, e) ->
+ doc_op arrow (separate space [string "case"; doc_atomic_pat pat]) (group (exp e))
+ | Pat_when(pat, guard, e) ->
+ doc_op arrow (separate space [string "case"; doc_atomic_pat pat; string "when"; exp guard]) (group (exp e))
(* lexps are parsed as eq_exp - we need to duplicate the precedence
* structure for them *)
diff --git a/src/type_check_new.ml b/src/type_check_new.ml
index d78a8874..094d11de 100644
--- a/src/type_check_new.ml
+++ b/src/type_check_new.ml
@@ -1560,9 +1560,14 @@ let rec check_exp env (E_aux (exp_aux, (l, ())) as exp : unit exp) (Typ_aux (typ
end
| E_case (exp, cases), _ ->
let inferred_exp = irule infer_exp env exp in
- let check_case (Pat_aux (Pat_exp (pat, case), (l, _))) typ =
- let tpat, env = bind_pat env pat (typ_of inferred_exp) in
- Pat_aux (Pat_exp (tpat, crule check_exp env case typ), (l, None))
+ let check_case pat typ = match pat with
+ | Pat_aux (Pat_exp (pat, case), (l, _)) ->
+ let tpat, env = bind_pat env pat (typ_of inferred_exp) in
+ Pat_aux (Pat_exp (tpat, crule check_exp env case typ), (l, None))
+ | Pat_aux (Pat_when (pat, guard, case), (l, _)) ->
+ let tpat, env = bind_pat env pat (typ_of inferred_exp) in
+ let checked_guard = check_exp env guard bool_typ in
+ Pat_aux (Pat_when (tpat, checked_guard, crule check_exp env case typ), (l, None))
in
annot_exp (E_case (inferred_exp, List.map (fun case -> check_case case typ) cases)) typ
| E_let (LB_aux (letbind, (let_loc, _)), exp), _ ->
@@ -2286,15 +2291,32 @@ and propagate_exp_effect_aux = function
| exp_aux -> typ_error Parse_ast.Unknown ("Unimplemented: Cannot propagate effect in expression "
^ string_of_exp (E_aux (exp_aux, (Parse_ast.Unknown, None))))
-and propagate_pexp_effect (Pat_aux (Pat_exp (pat, exp), (l, annot))) =
- let propagated_pat = propagate_pat_effect pat in
- let propagated_exp = propagate_exp_effect exp in
- let propagated_eff = union_effects (effect_of_pat propagated_pat) (effect_of propagated_exp) in
- match annot with
- | Some (typq, typ, eff) ->
- Pat_aux (Pat_exp (propagated_pat, propagated_exp), (l, Some (typq, typ, union_effects eff propagated_eff))),
- union_effects eff propagated_eff
- | None -> Pat_aux (Pat_exp (propagated_pat, propagated_exp), (l, None)), propagated_eff
+and propagate_pexp_effect = function
+ | Pat_aux (Pat_exp (pat, exp), (l, annot)) ->
+ begin
+ let propagated_pat = propagate_pat_effect pat in
+ let propagated_exp = propagate_exp_effect exp in
+ let propagated_eff = union_effects (effect_of_pat propagated_pat) (effect_of propagated_exp) in
+ match annot with
+ | Some (typq, typ, eff) ->
+ Pat_aux (Pat_exp (propagated_pat, propagated_exp), (l, Some (typq, typ, union_effects eff propagated_eff))),
+ union_effects eff propagated_eff
+ | None -> Pat_aux (Pat_exp (propagated_pat, propagated_exp), (l, None)), propagated_eff
+ end
+ | Pat_aux (Pat_when (pat, guard, exp), (l, annot)) ->
+ begin
+ let propagated_pat = propagate_pat_effect pat in
+ let propagated_guard = propagate_exp_effect guard in
+ let propagated_exp = propagate_exp_effect exp in
+ let propagated_eff = union_effects (effect_of_pat propagated_pat)
+ (union_effects (effect_of propagated_guard) (effect_of propagated_exp))
+ in
+ match annot with
+ | Some (typq, typ, eff) ->
+ Pat_aux (Pat_when (propagated_pat, propagated_guard, propagated_exp), (l, Some (typq, typ, union_effects eff propagated_eff))),
+ union_effects eff propagated_eff
+ | None -> Pat_aux (Pat_when (propagated_pat, propagated_guard, propagated_exp), (l, None)), propagated_eff
+ end
and propagate_pat_effect (P_aux (pat, annot)) =
let propagated_pat, eff = propagate_pat_effect_aux pat in
diff --git a/test/typecheck/pass/guards.sail b/test/typecheck/pass/guards.sail
new file mode 100644
index 00000000..2811c428
--- /dev/null
+++ b/test/typecheck/pass/guards.sail
@@ -0,0 +1,22 @@
+
+val (int, int) -> int effect pure add_int
+overload (deinfix +) [add_int]
+
+val forall Type 'a. ('a, 'a) -> bool effect pure eq
+val forall Type 'a. ('a, 'a) -> bool effect pure neq
+
+overload (deinfix ==) [eq]
+overload (deinfix !=) [neq]
+
+val (int, int) -> int effect pure quotient
+
+overload (deinfix quot) [quotient]
+
+typedef T = const union { int C1; int C2 }
+
+function int test ((int) x, (T) y) =
+ switch y {
+ case (C1(z)) when z == 0 -> 0
+ case (C1(z)) when z != 0 -> x quot z
+ case (C2(z)) -> z
+ }