diff options
| author | Alasdair Armstrong | 2017-07-17 18:46:00 +0100 |
|---|---|---|
| committer | Alasdair Armstrong | 2017-07-17 18:46:00 +0100 |
| commit | 1090d8667193e3bc56bfc7a0d028566b36ad3b96 (patch) | |
| tree | bd7c9d69120df927f4ae3ddf6477345f0c009f4f | |
| parent | 6c75d9386a9c179969c22baf1231f1bd7b9a60a3 (diff) | |
Added pattern guards to sail
Introduces a when keyword for case statements, as the Pat_when constructor for pexp's in the AST. This allows us to write things like:
typedef T = const union { int C1; int C2 }
function int test ((int) x, (T) y) =
switch y {
case (C1(z)) when z == 0 -> 0
case (C1(z)) when z != 0 -> x quot z
case (C2(z)) -> z
}
this should make translation from ASL's patterns much more straightforward
| -rw-r--r-- | editors/sail-mode.el | 4 | ||||
| -rw-r--r-- | src/ast.ml | 3 | ||||
| -rw-r--r-- | src/ast_util.ml | 10 | ||||
| -rw-r--r-- | src/initial_check.ml | 2 | ||||
| -rw-r--r-- | src/lexer.mll | 1 | ||||
| -rw-r--r-- | src/parse_ast.ml | 1 | ||||
| -rw-r--r-- | src/parser.mly | 4 | ||||
| -rw-r--r-- | src/pretty_print_sail.ml | 8 | ||||
| -rw-r--r-- | src/type_check_new.ml | 46 | ||||
| -rw-r--r-- | test/typecheck/pass/guards.sail | 22 |
10 files changed, 81 insertions, 20 deletions
diff --git a/editors/sail-mode.el b/editors/sail-mode.el index e1da042b..b667b084 100644 --- a/editors/sail-mode.el +++ b/editors/sail-mode.el @@ -641,7 +641,7 @@ though, if you also have `sail-electric-indent' on." (mapc 'sail-define-abbrev '("scattered" "function" "typedef" "let" "default" "val" "register" "alias" "union" "member" "clause" "extern" "cast" "effect" - "rec" "and" "switch" "case" "exit" "foreach" "from" "else" + "rec" "and" "switch" "case" "when" "exit" "foreach" "from" "else" "to" "end" "downto" "in" "then" "with" "if" "nondet" "as" "undefined" "const" "struct" "IN" "deinfix" "return" "sizeof")) (setq abbrevs-changed nil)) @@ -711,7 +711,7 @@ Based on Tuareg mode. See Tuareg mode for usage" `(("\\<\\(extern\\|cast\\|overload\\|deinfix\\|function\\|scattered\\|clause\\|effect\\|default\\|struct\\|const\\|union\\|val\\|typedef\\|in\\|let\\|rec\\|and\\|end\\|register\\|alias\\|member\\|enumerate\\)\\>" 0 sail-font-lock-governing-face nil nil) ("\\<\\(false\\|true\\|bitzero\\|bitone\\|0x[:xdigit:]\\|[:digit:]\\)\\>" 0 font-lock-constant-face nil nil) - ("\\<\\(as\\|downto\\|else\\|foreach\\|if\\|t\\(hen\\|o\\)\\|when\\|switch\\|with\\|case\\|exit\\|sizeof\\|nondet\\|from\\|by\\|return\\)\\>" + ("\\<\\(as\\|downto\\|else\\|foreach\\|if\\|t\\(hen\\|o\\)\\|when\\|switch\\|with\\|case\\|when\\|exit\\|sizeof\\|nondet\\|from\\|by\\|return\\)\\>" 0 font-lock-keyword-face nil nil) ("\\<\\(clause\\)\\>[ \t\n]*\\(\\(\\w\\|[_ \t()*,]\\)+\\)" 2 font-lock-variable-name-face keep nil) @@ -364,7 +364,8 @@ and 'a opt_default = Def_val_aux of 'a opt_default_aux * 'a annot and 'a pexp_aux = (* Pattern match *) - Pat_exp of 'a pat * 'a exp + Pat_exp of 'a pat * 'a exp +| Pat_when of 'a pat * 'a exp * 'a exp and 'a pexp = Pat_aux of 'a pexp_aux * 'a annot diff --git a/src/ast_util.ml b/src/ast_util.ml index 9e2b7a2d..fc59fcf2 100644 --- a/src/ast_util.ml +++ b/src/ast_util.ml @@ -93,7 +93,10 @@ and map_opt_default_annot_aux f = function | Def_val_dec exp -> Def_val_dec (map_exp_annot f exp) and map_fexps_annot f (FES_aux (FES_Fexps (fexps, b), annot)) = FES_aux (FES_Fexps (List.map (map_fexp_annot f) fexps, b), f annot) and map_fexp_annot f (FE_aux (FE_Fexp (id, exp), annot)) = FE_aux (FE_Fexp (id, map_exp_annot f exp), f annot) -and map_pexp_annot f (Pat_aux (Pat_exp (pat, exp), annot)) = Pat_aux (Pat_exp (map_pat_annot f pat, map_exp_annot f exp), f annot) +and map_pexp_annot f (Pat_aux (pexp, annot)) = Pat_aux (map_pexp_annot_aux f pexp, f annot) +and map_pexp_annot_aux f = function + | Pat_exp (pat, exp) -> Pat_exp (map_pat_annot f pat, map_exp_annot f exp) + | Pat_when (pat, guard, exp) -> Pat_when (map_pat_annot f pat, map_exp_annot f guard, map_exp_annot f exp) and map_pat_annot f (P_aux (pat, annot)) = P_aux (map_pat_annot_aux f pat, f annot) and map_pat_annot_aux f = function | P_lit lit -> P_lit lit @@ -278,7 +281,10 @@ let rec string_of_exp (E_aux (exp, _)) = ^ ") { " ^ string_of_exp body | _ -> "INTERNAL" -and string_of_pexp (Pat_aux (Pat_exp (pat, exp), _)) = string_of_pat pat ^ " -> " ^ string_of_exp exp +and string_of_pexp (Pat_aux (pexp, _)) = + match pexp with + | Pat_exp (pat, exp) -> string_of_pat pat ^ " -> " ^ string_of_exp exp + | Pat_when (pat, guard, exp) -> string_of_pat pat ^ " when " ^ string_of_exp guard ^ " -> " ^ string_of_exp exp and string_of_pat (P_aux (pat, l)) = match pat with | P_lit lit -> string_of_lit lit diff --git a/src/initial_check.ml b/src/initial_check.ml index df8fb678..89dad186 100644 --- a/src/initial_check.ml +++ b/src/initial_check.ml @@ -525,6 +525,8 @@ and to_ast_lexp (k_env : kind Envmap.t) (def_ord : order) (Parse_ast.E_aux(exp,l and to_ast_case (k_env : kind Envmap.t) (def_ord : order) (Parse_ast.Pat_aux(pex,l) : Parse_ast.pexp) : tannot pexp = match pex with | Parse_ast.Pat_exp(pat,exp) -> Pat_aux(Pat_exp(to_ast_pat k_env def_ord pat, to_ast_exp k_env def_ord exp),(l,NoTyp)) + | Parse_ast.Pat_when(pat,guard,exp) -> + Pat_aux (Pat_when (to_ast_pat k_env def_ord pat, to_ast_exp k_env def_ord guard, to_ast_exp k_env def_ord exp), (l, NoTyp)) and to_ast_fexps (fail_on_error:bool) (k_env:kind Envmap.t) (def_ord:order) (exps : Parse_ast.exp list) : tannot fexps option = match exps with diff --git a/src/lexer.mll b/src/lexer.mll index 99965e20..36aefb1d 100644 --- a/src/lexer.mll +++ b/src/lexer.mll @@ -106,6 +106,7 @@ let kw_table = ("undefined", (fun x -> Undefined)); ("union", (fun x -> Union)); ("with", (fun x -> With)); + ("when", (fun x -> When)); ("val", (fun x -> Val)); ("div", (fun x -> Div_)); diff --git a/src/parse_ast.ml b/src/parse_ast.ml index 8b52b2ab..4529aa8f 100644 --- a/src/parse_ast.ml +++ b/src/parse_ast.ml @@ -305,6 +305,7 @@ and opt_default = and pexp_aux = (* Pattern match *) Pat_exp of pat * exp + | Pat_when of pat * exp * exp and pexp = Pat_aux of pexp_aux * l diff --git a/src/parser.mly b/src/parser.mly index 8e61a0ac..b64e0de6 100644 --- a/src/parser.mly +++ b/src/parser.mly @@ -131,7 +131,7 @@ let make_vector_sugar order_set is_inc typ typ1 = %token And Alias As Assert Bitzero Bitone Bits By Case Clause Const Dec Def Default Deinfix Effect EFFECT End %token Enumerate Else Exit Extern False Forall Foreach Overload Function_ If_ In IN Inc Let_ Member Nat NatNum Order Cast %token Pure Rec Register Return Scattered Sizeof Struct Switch Then True TwoStarStar Type TYPE Typedef -%token Undefined Union With Val +%token Undefined Union With When Val %token Barr Depend Rreg Wreg Rmem Rmemt Wmem Wmv Wmvt Eamem Exmem Undef Unspec Nondet Escape @@ -962,6 +962,8 @@ case_exps: patsexp: | atomic_pat MinusGt exp { peloc (Pat_exp($1,$3)) } + | atomic_pat When exp MinusGt exp + { peloc (Pat_when ($1, $3, $5)) } letbind: | Let_ atomic_pat Eq exp diff --git a/src/pretty_print_sail.ml b/src/pretty_print_sail.ml index 6826087a..a484bd1f 100644 --- a/src/pretty_print_sail.ml +++ b/src/pretty_print_sail.ml @@ -364,8 +364,12 @@ let doc_exp, doc_let = and doc_fexp (FE_aux(FE_Fexp(id,e),_)) = doc_op equals (doc_id id) (exp e) - and doc_case (Pat_aux(Pat_exp(pat,e),_)) = - doc_op arrow (separate space [string "case"; doc_atomic_pat pat]) (group (exp e)) + and doc_case (Pat_aux (pexp, _)) = + match pexp with + | Pat_exp(pat, e) -> + doc_op arrow (separate space [string "case"; doc_atomic_pat pat]) (group (exp e)) + | Pat_when(pat, guard, e) -> + doc_op arrow (separate space [string "case"; doc_atomic_pat pat; string "when"; exp guard]) (group (exp e)) (* lexps are parsed as eq_exp - we need to duplicate the precedence * structure for them *) diff --git a/src/type_check_new.ml b/src/type_check_new.ml index d78a8874..094d11de 100644 --- a/src/type_check_new.ml +++ b/src/type_check_new.ml @@ -1560,9 +1560,14 @@ let rec check_exp env (E_aux (exp_aux, (l, ())) as exp : unit exp) (Typ_aux (typ end | E_case (exp, cases), _ -> let inferred_exp = irule infer_exp env exp in - let check_case (Pat_aux (Pat_exp (pat, case), (l, _))) typ = - let tpat, env = bind_pat env pat (typ_of inferred_exp) in - Pat_aux (Pat_exp (tpat, crule check_exp env case typ), (l, None)) + let check_case pat typ = match pat with + | Pat_aux (Pat_exp (pat, case), (l, _)) -> + let tpat, env = bind_pat env pat (typ_of inferred_exp) in + Pat_aux (Pat_exp (tpat, crule check_exp env case typ), (l, None)) + | Pat_aux (Pat_when (pat, guard, case), (l, _)) -> + let tpat, env = bind_pat env pat (typ_of inferred_exp) in + let checked_guard = check_exp env guard bool_typ in + Pat_aux (Pat_when (tpat, checked_guard, crule check_exp env case typ), (l, None)) in annot_exp (E_case (inferred_exp, List.map (fun case -> check_case case typ) cases)) typ | E_let (LB_aux (letbind, (let_loc, _)), exp), _ -> @@ -2286,15 +2291,32 @@ and propagate_exp_effect_aux = function | exp_aux -> typ_error Parse_ast.Unknown ("Unimplemented: Cannot propagate effect in expression " ^ string_of_exp (E_aux (exp_aux, (Parse_ast.Unknown, None)))) -and propagate_pexp_effect (Pat_aux (Pat_exp (pat, exp), (l, annot))) = - let propagated_pat = propagate_pat_effect pat in - let propagated_exp = propagate_exp_effect exp in - let propagated_eff = union_effects (effect_of_pat propagated_pat) (effect_of propagated_exp) in - match annot with - | Some (typq, typ, eff) -> - Pat_aux (Pat_exp (propagated_pat, propagated_exp), (l, Some (typq, typ, union_effects eff propagated_eff))), - union_effects eff propagated_eff - | None -> Pat_aux (Pat_exp (propagated_pat, propagated_exp), (l, None)), propagated_eff +and propagate_pexp_effect = function + | Pat_aux (Pat_exp (pat, exp), (l, annot)) -> + begin + let propagated_pat = propagate_pat_effect pat in + let propagated_exp = propagate_exp_effect exp in + let propagated_eff = union_effects (effect_of_pat propagated_pat) (effect_of propagated_exp) in + match annot with + | Some (typq, typ, eff) -> + Pat_aux (Pat_exp (propagated_pat, propagated_exp), (l, Some (typq, typ, union_effects eff propagated_eff))), + union_effects eff propagated_eff + | None -> Pat_aux (Pat_exp (propagated_pat, propagated_exp), (l, None)), propagated_eff + end + | Pat_aux (Pat_when (pat, guard, exp), (l, annot)) -> + begin + let propagated_pat = propagate_pat_effect pat in + let propagated_guard = propagate_exp_effect guard in + let propagated_exp = propagate_exp_effect exp in + let propagated_eff = union_effects (effect_of_pat propagated_pat) + (union_effects (effect_of propagated_guard) (effect_of propagated_exp)) + in + match annot with + | Some (typq, typ, eff) -> + Pat_aux (Pat_when (propagated_pat, propagated_guard, propagated_exp), (l, Some (typq, typ, union_effects eff propagated_eff))), + union_effects eff propagated_eff + | None -> Pat_aux (Pat_when (propagated_pat, propagated_guard, propagated_exp), (l, None)), propagated_eff + end and propagate_pat_effect (P_aux (pat, annot)) = let propagated_pat, eff = propagate_pat_effect_aux pat in diff --git a/test/typecheck/pass/guards.sail b/test/typecheck/pass/guards.sail new file mode 100644 index 00000000..2811c428 --- /dev/null +++ b/test/typecheck/pass/guards.sail @@ -0,0 +1,22 @@ + +val (int, int) -> int effect pure add_int +overload (deinfix +) [add_int] + +val forall Type 'a. ('a, 'a) -> bool effect pure eq +val forall Type 'a. ('a, 'a) -> bool effect pure neq + +overload (deinfix ==) [eq] +overload (deinfix !=) [neq] + +val (int, int) -> int effect pure quotient + +overload (deinfix quot) [quotient] + +typedef T = const union { int C1; int C2 } + +function int test ((int) x, (T) y) = + switch y { + case (C1(z)) when z == 0 -> 0 + case (C1(z)) when z != 0 -> x quot z + case (C2(z)) -> z + } |
