diff options
Diffstat (limited to 'src/interpreter.ml')
| -rw-r--r-- | src/interpreter.ml | 233 |
1 files changed, 195 insertions, 38 deletions
diff --git a/src/interpreter.ml b/src/interpreter.ml index 3356b9dc..4fd75094 100644 --- a/src/interpreter.ml +++ b/src/interpreter.ml @@ -51,9 +51,22 @@ open Ast open Ast_util open Value -(* open Type_check *) -type state = St +type gstate = + { registers : value Bindings.t } + +type lstate = + { locals : value Bindings.t } + +type state = lstate * gstate + +let initial_gstate = + { registers = Bindings.empty } + +let initial_lstate = + { locals = Bindings.empty } + +let initial_state = initial_lstate, initial_gstate let value_of_lit (L_aux (l_aux, _)) = match l_aux with @@ -63,6 +76,7 @@ let value_of_lit (L_aux (l_aux, _)) = | L_true -> V_bool true | L_false -> V_bool false | L_string str -> V_string str + | L_num n -> V_int n | _ -> failwith "Unimplemented value_of_lit" (* TODO *) let is_value = function @@ -87,23 +101,23 @@ let value_of_exp = function (**************************************************************************) type 'a response = - | Final of value + | Early_return of value | Exception of value | Assertion_failed of string | Call of id * value list * (value -> 'a) | Gets of (state -> 'a) - | Puts of state * 'a + | Puts of state * (unit -> 'a) and 'a monad = | Pure of 'a | Yield of ('a monad response) let map_response f = function - | Final v -> Final v + | Early_return v -> Early_return v | Exception v -> Exception v | Assertion_failed str -> Assertion_failed str | Gets g -> Gets (fun s -> f (g s)) - | Puts (s, x) -> Puts (s, f x) + | Puts (s, cont) -> Puts (s, fun () -> f (cont ())) | Call (id, vals, cont) -> Call (id, vals, fun v -> f (cont v)) let rec liftM f = function @@ -142,9 +156,9 @@ let gets : state monad = Yield (Gets (fun s -> Pure s)) let puts (s : state) : unit monad = - Yield (Puts (s, Pure ())) + Yield (Puts (s, fun () -> Pure ())) -let final v = Yield (Final v) +let early_return v = Yield (Early_return v) let assertion_failed msg = Yield (Assertion_failed msg) @@ -156,9 +170,22 @@ let rec subst id value (E_aux (e_aux, annot) as exp) = | E_block exps -> E_block (List.map (subst id value) exps) | E_nondet exps -> E_nondet (List.map (subst id value) exps) | E_id id' -> if Id.compare id id' = 0 then unaux_exp (exp_of_value value) else E_id id' + | E_lit lit -> E_lit lit | E_cast (typ, exp) -> E_cast (typ, subst id value exp) | E_app (fn, exps) -> E_app (fn, List.map (subst id value) exps) - | _ -> assert false (* TODO *) + | E_app_infix (exp1, op, exp2) -> E_app_infix (subst id value exp1, op, subst id value exp2) + | E_tuple exps -> E_tuple (List.map (subst id value) exps) + | E_assign (lexp, exp) -> E_assign (lexp, subst id value exp) (* Shadowing... *) + | E_let (LB_aux (LB_val (pat, bind), lb_annot), body) -> + (* TODO: Fix shadowing *) + E_let (LB_aux (LB_val (pat, subst id value bind), lb_annot), subst id value body) + | E_if (cond, then_exp, else_exp) -> + E_if (subst id value cond, subst id value then_exp, subst id value else_exp) + | E_vector exps -> E_vector (List.map (subst id value) exps) + | E_return exp -> E_return (subst id value exp) + | E_assert (exp1, exp2) -> E_assert (subst id value exp1, subst id value exp2) + | E_internal_value v -> E_internal_value v + | _ -> failwith ("subst " ^ string_of_exp exp) in wrap e_aux @@ -167,7 +194,12 @@ let rec subst id value (E_aux (e_aux, annot) as exp) = (* 2. Expression Evaluation *) (**************************************************************************) -let rec step (E_aux (e_aux, annot)) = +let unit_exp = E_lit (L_aux (L_unit, Parse_ast.Unknown)) + +let is_value_fexp (FE_aux (FE_Fexp (id, exp), _)) = is_value exp +let value_of_fexp (FE_aux (FE_Fexp (id, exp), _)) = (string_of_id id, value_of_exp exp) + +let rec step (E_aux (e_aux, annot) as orig_exp) = let wrap e_aux' = return (E_aux (e_aux', annot)) in match e_aux with | E_block [] -> wrap (E_lit (L_aux (L_unit, Parse_ast.Unknown))) @@ -183,11 +215,38 @@ let rec step (E_aux (e_aux, annot)) = | E_if (exp, then_exp, else_exp) -> step exp >>= fun exp' -> wrap (E_if (exp', then_exp, else_exp)) - | E_assert (exp, msg) when is_true exp -> wrap (E_lit (L_aux (L_unit, Parse_ast.Unknown))) + | E_assert (exp, msg) when is_true exp -> wrap unit_exp | E_assert (exp, msg) when is_false exp -> assertion_failed "FIXME" | E_assert (exp, msg) -> step exp >>= fun exp' -> wrap (E_assert (exp', msg)) + | E_vector exps -> + let evaluated, unevaluated = Util.take_drop is_value exps in + begin + match unevaluated with + | exp :: exps -> + step exp >>= fun exp' -> wrap (E_vector (evaluated @ exp' :: exps)) + | [] -> return (exp_of_value (V_vector (List.map value_of_exp evaluated))) + end + + (* Special rules for short circuting boolean operators *) + | E_app (id, [x; y]) when (string_of_id id = "and_bool" || string_of_id id = "or_bool") && not (is_value x) -> + step x >>= fun x' -> wrap (E_app (id, [x'; y])) + | E_app (id, [x; y]) when string_of_id id = "and_bool" && is_false x -> + return (exp_of_value (V_bool false)) + | E_app (id, [x; y]) when string_of_id id = "or_bool" && is_true x -> + return (exp_of_value (V_bool true)) + + | E_let (LB_aux (LB_val (pat, bind), lb_annot), body) when not (is_value bind) -> + step bind >>= fun bind' -> wrap (E_let (LB_aux (LB_val (pat, bind'), lb_annot), body)) + | E_let (LB_aux (LB_val (pat, bind), lb_annot), body) -> + let matched, bindings = pattern_match pat (value_of_exp bind) in + if matched then + return (List.fold_left (fun body (id, v) -> subst id v body) body (Bindings.bindings bindings)) + else + failwith "Match failure" + + (* otherwise left-to-right evaluation order for function applications *) | E_app (id, exps) -> let evaluated, unevaluated = Util.take_drop is_value exps in begin @@ -199,7 +258,8 @@ let rec step (E_aux (e_aux, annot)) = return (exp_of_value (V_ctor (string_of_id id, List.map value_of_exp evaluated))) | [] when Env.is_extern id (env_of_annot annot) "interpreter" -> begin - let primop = StringMap.find (Env.get_extern id (env_of_annot annot) "interpreter") primops in + let extern = Env.get_extern id (env_of_annot annot) "interpreter" in + let primop = try StringMap.find extern primops with Not_found -> failwith ("No primop " ^ extern) in return (exp_of_value (primop (List.map value_of_exp evaluated))) end | [] -> liftM exp_of_value (call id (List.map value_of_exp evaluated)) @@ -211,7 +271,7 @@ let rec step (E_aux (e_aux, annot)) = | E_app_infix (x, id, y) -> step x >>= fun x' -> wrap (E_app_infix (x', id, y)) - | E_return exp when is_value exp -> final (value_of_exp exp) + | E_return exp when is_value exp -> early_return (value_of_exp exp) | E_return exp -> step exp >>= fun exp' -> wrap (E_return exp') | E_tuple exps -> @@ -238,6 +298,76 @@ let rec step (E_aux (e_aux, annot)) = | E_throw exp when is_value exp -> throw (value_of_exp exp) | E_throw exp -> step exp >>= fun exp' -> wrap (E_throw exp') + | E_id id -> + begin + let open Type_check in + gets >>= fun (lstate, gstate) -> + match Env.lookup_id id (env_of_annot annot) with + | Register _ -> + let exp = + try exp_of_value (Bindings.find id gstate.registers) with + | Not_found -> + let exp = mk_exp (E_app (mk_id ("undefined_" ^ string_of_typ (typ_of orig_exp)), [mk_exp (E_lit (mk_lit (L_unit)))])) in + Type_check.check_exp (env_of_annot annot) exp (typ_of orig_exp) + in + return exp + | Local (Mutable, _) -> return (exp_of_value (Bindings.find id lstate.locals)) + | _ -> failwith "id" + end + + | E_record (FES_aux (FES_Fexps (fexps, flag), fes_annot)) -> + let evaluated, unevaluated = Util.take_drop is_value_fexp fexps in + begin + match unevaluated with + | FE_aux (FE_Fexp (id, exp), fe_annot) :: fexps -> + step exp >>= fun exp' -> + wrap (E_record (FES_aux (FES_Fexps (evaluated @ FE_aux (FE_Fexp (id, exp'), fe_annot) :: fexps, flag), fes_annot))) + | [] -> + List.map value_of_fexp fexps + |> List.fold_left (fun record (field, v) -> StringMap.add field v record) StringMap.empty + |> (fun record -> V_record record) + |> exp_of_value + |> return + end + + | E_record_update (exp, fexps) when not (is_value exp) -> + step exp >>= fun exp' -> wrap (E_record_update (exp', fexps)) + | E_record_update (record, FES_aux (FES_Fexps (fexps, flag), fes_annot)) -> + let evaluated, unevaluated = Util.take_drop is_value_fexp fexps in + begin + match unevaluated with + | FE_aux (FE_Fexp (id, exp), fe_annot) :: fexps -> + step exp >>= fun exp' -> + wrap (E_record_update (record, FES_aux (FES_Fexps (evaluated @ FE_aux (FE_Fexp (id, exp'), fe_annot) :: fexps, flag), fes_annot))) + | [] -> + List.map value_of_fexp fexps + |> List.fold_left (fun record (field, v) -> StringMap.add field v record) (coerce_record (value_of_exp record)) + |> (fun record -> V_record record) + |> exp_of_value + |> return + end + + | E_assign (lexp, exp) when not (is_value exp) -> step exp >>= fun exp' -> wrap (E_assign (lexp, exp')) + | E_assign (LEXP_aux (LEXP_memory (id, args), _), exp) -> wrap (E_app (id, args @ [exp])) + | E_assign (LEXP_aux (LEXP_field (lexp, id), _), exp) -> + let open Type_check in + let lexp_exp = infer_exp (env_of_annot annot) (exp_of_lexp (strip_lexp lexp)) in + let ul = (Parse_ast.Unknown, None) in + let exp' = E_aux (E_record_update (lexp_exp, FES_aux (FES_Fexps ([FE_aux (FE_Fexp (id, exp), ul)], false), ul)), ul) in + wrap (E_assign (lexp, exp')) + | E_assign (LEXP_aux (LEXP_id id, _), exp) | E_assign (LEXP_aux (LEXP_cast (_, id), _), exp) -> + begin + let open Type_check in + gets >>= fun (lstate, gstate) -> + match Env.lookup_id id (env_of_annot annot) with + | Register _ -> + puts (lstate, { gstate with registers = Bindings.add id (value_of_exp exp) gstate.registers }) >> wrap unit_exp + | Local (Mutable, _) | Unbound -> + puts ({ lstate with locals = Bindings.add id (value_of_exp exp) lstate.locals }, gstate) >> wrap unit_exp + | _ -> failwith "Assign" + end + | E_assign _ -> assert false + | E_try (exp, pexps) when is_value exp -> return exp | E_try (exp, pexps) -> begin @@ -249,7 +379,7 @@ let rec step (E_aux (e_aux, annot)) = | E_sizeof _ | E_constraint _ -> assert false (* Must be re-written before interpreting *) - | _ -> assert false (* TODO *) + | _ -> failwith ("Unimplemented " ^ string_of_exp orig_exp) and combine _ v1 v2 = match (v1, v2) with @@ -258,6 +388,22 @@ and combine _ v1 v2 = | None, Some v2 -> Some v2 | Some v1, Some v2 -> failwith "Pattern binds same identifier twice!" +and exp_of_lexp (LEXP_aux (lexp_aux, _) as lexp) = + match lexp_aux with + | LEXP_id id -> mk_exp (E_id id) + | LEXP_memory (f, args) -> mk_exp (E_app (f, args)) + | LEXP_cast (typ, id) -> mk_exp (E_cast (typ, mk_exp (E_id id))) + | LEXP_tup lexps -> mk_exp (E_tuple (List.map exp_of_lexp lexps)) + | LEXP_vector (lexp, exp) -> mk_exp (E_vector_access (exp_of_lexp lexp, exp)) + | LEXP_vector_range (lexp, exp1, exp2) -> mk_exp (E_vector_subrange (exp_of_lexp lexp, exp1, exp2)) + | LEXP_field (lexp, id) -> mk_exp (E_field (exp_of_lexp lexp, id)) + +and lexp_assign (LEXP_aux (lexp_aux, _) as lexp) value = + print_endline ("Assigning: " ^ string_of_lexp lexp ^ " to " ^ string_of_value value |> Util.yellow |> Util.clear); + match lexp_aux with + | LEXP_id id -> Bindings.singleton id value + | _ -> failwith "Unhandled lexp_assign" + and pattern_match (P_aux (p_aux, _) as pat) value = print_endline ("Matching: " ^ string_of_pat pat ^ " with " ^ string_of_value value |> Util.yellow |> Util.clear); match p_aux with @@ -294,27 +440,38 @@ let rec get_fundef id (Defs defs) = | (DEF_fundef fdef) :: _ when Id.compare id (id_of_fundef fdef) = 0 -> fdef | _ :: defs -> get_fundef id (Defs defs) -let rec untilM p f x = - if p x then - return x - else - f (return x) >>= fun x' -> untilM p f x' - -type trace = - | Done of value - | Step of (Type_check.tannot exp) monad * (value -> (Type_check.tannot exp) monad) list - -let rec eval_exp ast m = - match m with - | Pure v when is_value v -> Done (value_of_exp v) - | Pure exp' -> - Pretty_print_sail2.pretty_sail stdout (Pretty_print_sail2.doc_exp exp'); - print_newline (); - Step (step exp', []) - | Yield (Call (id, vals, cont)) -> - print_endline ("Calling " ^ string_of_id id |> Util.cyan |> Util.clear); - let arg = if List.length vals != 1 then tuple_value vals else List.hd vals in - let body = exp_of_fundef (get_fundef id ast) arg in - Step (return body, [cont]) - | _ -> assert false - +let stack_cont (_, _, cont) = cont +let stack_string (str, _, _) = str +let stack_state (_, lstate, _) = lstate + +type frame = + | Done of state * value + | Step of string * state * (Type_check.tannot exp) monad * (string * lstate * (value -> (Type_check.tannot exp) monad)) list + +let rec eval_frame ast = function + | Done (state, v) -> Done (state, v) + | Step (out, state, m, stack) -> + match (m, stack) with + | Pure v, [] when is_value v -> Done (state, value_of_exp v) + | Pure v, (head :: stack') when is_value v -> + print_endline ("Returning value: " ^ string_of_value (value_of_exp v) |> Util.cyan |> Util.clear); + Step (stack_string head, (stack_state head, snd state), stack_cont head (value_of_exp v), stack') + | Pure exp', _ -> + let out' = Pretty_print_sail2.to_string (Pretty_print_sail2.doc_exp exp') in + Step (out', state, step exp', stack) + | Yield (Call(id, vals, cont)), _ -> + print_endline ("Calling " ^ string_of_id id |> Util.cyan |> Util.clear); + let arg = if List.length vals != 1 then tuple_value vals else List.hd vals in + let body = exp_of_fundef (get_fundef id ast) arg in + Step ("", (initial_lstate, snd state), return body, (out, fst state, cont) :: stack) + | Yield (Gets cont), _ -> + eval_frame ast (Step (out, state, cont state, stack)) + | Yield (Puts (state', cont)), _ -> + eval_frame ast (Step (out, state', cont (), stack)) + | Yield (Early_return v), [] -> Done (state, v) + | Yield (Early_return v), (head :: stack') -> + print_endline ("Returning value: " ^ string_of_value v |> Util.cyan |> Util.clear); + Step (stack_string head, (stack_state head, snd state), stack_cont head v, stack') + | Yield (Assertion_failed msg), _ -> + failwith msg + | _ -> assert false |
