diff options
| author | Alasdair Armstrong | 2017-12-15 21:27:27 +0000 |
|---|---|---|
| committer | Alasdair Armstrong | 2017-12-15 21:27:27 +0000 |
| commit | 810dca66a6328fd635f5992256bc24960dcc3899 (patch) | |
| tree | b55fed28dd33deebf3c5a97190b0d0366c00ad98 /src | |
| parent | 2162c6586b8024789875c2e619b09ba8348e72e0 (diff) | |
Experimenting with interactive mode
Diffstat (limited to 'src')
| -rw-r--r-- | src/interpreter.ml | 193 | ||||
| -rw-r--r-- | src/isail.ml | 26 | ||||
| -rw-r--r-- | src/pretty_print_sail2.ml | 5 | ||||
| -rw-r--r-- | src/process_file.ml | 1 | ||||
| -rw-r--r-- | src/process_file.mli | 1 | ||||
| -rw-r--r-- | src/rewrites.ml | 6 | ||||
| -rw-r--r-- | src/rewrites.mli | 3 | ||||
| -rw-r--r-- | src/sail.ml | 2 | ||||
| -rw-r--r-- | src/sail_lib.ml | 477 | ||||
| -rw-r--r-- | src/util.ml | 1 | ||||
| -rw-r--r-- | src/util.mli | 1 | ||||
| -rw-r--r-- | src/value.ml | 203 |
12 files changed, 841 insertions, 78 deletions
diff --git a/src/interpreter.ml b/src/interpreter.ml index 3356b9dc..d0dd6f48 100644 --- a/src/interpreter.ml +++ b/src/interpreter.ml @@ -51,9 +51,16 @@ open Ast open Ast_util open Value -(* open Type_check *) -type state = St +type state = + { registers : value Bindings.t; + locals : value Bindings.t + } + +let initial_state = + { registers = Bindings.empty; + locals = Bindings.empty + } let value_of_lit (L_aux (l_aux, _)) = match l_aux with @@ -63,6 +70,7 @@ let value_of_lit (L_aux (l_aux, _)) = | L_true -> V_bool true | L_false -> V_bool false | L_string str -> V_string str + | L_num n -> V_int n | _ -> failwith "Unimplemented value_of_lit" (* TODO *) let is_value = function @@ -87,23 +95,23 @@ let value_of_exp = function (**************************************************************************) type 'a response = - | Final of value + | Early_return of value | Exception of value | Assertion_failed of string | Call of id * value list * (value -> 'a) | Gets of (state -> 'a) - | Puts of state * 'a + | Puts of state * (unit -> 'a) and 'a monad = | Pure of 'a | Yield of ('a monad response) let map_response f = function - | Final v -> Final v + | Early_return v -> Early_return v | Exception v -> Exception v | Assertion_failed str -> Assertion_failed str | Gets g -> Gets (fun s -> f (g s)) - | Puts (s, x) -> Puts (s, f x) + | Puts (s, cont) -> Puts (s, fun () -> f (cont ())) | Call (id, vals, cont) -> Call (id, vals, fun v -> f (cont v)) let rec liftM f = function @@ -142,9 +150,9 @@ let gets : state monad = Yield (Gets (fun s -> Pure s)) let puts (s : state) : unit monad = - Yield (Puts (s, Pure ())) + Yield (Puts (s, fun () -> Pure ())) -let final v = Yield (Final v) +let early_return v = Yield (Early_return v) let assertion_failed msg = Yield (Assertion_failed msg) @@ -156,9 +164,15 @@ let rec subst id value (E_aux (e_aux, annot) as exp) = | E_block exps -> E_block (List.map (subst id value) exps) | E_nondet exps -> E_nondet (List.map (subst id value) exps) | E_id id' -> if Id.compare id id' = 0 then unaux_exp (exp_of_value value) else E_id id' + | E_lit lit -> E_lit lit | E_cast (typ, exp) -> E_cast (typ, subst id value exp) | E_app (fn, exps) -> E_app (fn, List.map (subst id value) exps) - | _ -> assert false (* TODO *) + | E_app_infix (exp1, op, exp2) -> E_app_infix (subst id value exp1, op, subst id value exp2) + | E_vector exps -> E_vector (List.map (subst id value) exps) + | E_return exp -> E_return (subst id value exp) + | E_assert (exp1, exp2) -> E_assert (subst id value exp1, subst id value exp2) + | E_internal_value v -> E_internal_value v + | _ -> failwith ("subst " ^ string_of_exp exp) in wrap e_aux @@ -167,7 +181,12 @@ let rec subst id value (E_aux (e_aux, annot) as exp) = (* 2. Expression Evaluation *) (**************************************************************************) -let rec step (E_aux (e_aux, annot)) = +let unit_exp = E_lit (L_aux (L_unit, Parse_ast.Unknown)) + +let is_value_fexp (FE_aux (FE_Fexp (id, exp), _)) = is_value exp +let value_of_fexp (FE_aux (FE_Fexp (id, exp), _)) = (string_of_id id, value_of_exp exp) + +let rec step (E_aux (e_aux, annot) as orig_exp) = let wrap e_aux' = return (E_aux (e_aux', annot)) in match e_aux with | E_block [] -> wrap (E_lit (L_aux (L_unit, Parse_ast.Unknown))) @@ -183,11 +202,20 @@ let rec step (E_aux (e_aux, annot)) = | E_if (exp, then_exp, else_exp) -> step exp >>= fun exp' -> wrap (E_if (exp', then_exp, else_exp)) - | E_assert (exp, msg) when is_true exp -> wrap (E_lit (L_aux (L_unit, Parse_ast.Unknown))) + | E_assert (exp, msg) when is_true exp -> wrap unit_exp | E_assert (exp, msg) when is_false exp -> assertion_failed "FIXME" | E_assert (exp, msg) -> step exp >>= fun exp' -> wrap (E_assert (exp', msg)) + | E_vector exps -> + let evaluated, unevaluated = Util.take_drop is_value exps in + begin + match unevaluated with + | exp :: exps -> + step exp >>= fun exp' -> wrap (E_vector (evaluated @ exp' :: exps)) + | [] -> return (exp_of_value (V_vector (List.map value_of_exp evaluated))) + end + | E_app (id, exps) -> let evaluated, unevaluated = Util.take_drop is_value exps in begin @@ -199,7 +227,8 @@ let rec step (E_aux (e_aux, annot)) = return (exp_of_value (V_ctor (string_of_id id, List.map value_of_exp evaluated))) | [] when Env.is_extern id (env_of_annot annot) "interpreter" -> begin - let primop = StringMap.find (Env.get_extern id (env_of_annot annot) "interpreter") primops in + let extern = Env.get_extern id (env_of_annot annot) "interpreter" in + let primop = try StringMap.find extern primops with Not_found -> failwith ("No primop " ^ extern) in return (exp_of_value (primop (List.map value_of_exp evaluated))) end | [] -> liftM exp_of_value (call id (List.map value_of_exp evaluated)) @@ -211,7 +240,7 @@ let rec step (E_aux (e_aux, annot)) = | E_app_infix (x, id, y) -> step x >>= fun x' -> wrap (E_app_infix (x', id, y)) - | E_return exp when is_value exp -> final (value_of_exp exp) + | E_return exp when is_value exp -> early_return (value_of_exp exp) | E_return exp -> step exp >>= fun exp' -> wrap (E_return exp') | E_tuple exps -> @@ -238,6 +267,73 @@ let rec step (E_aux (e_aux, annot)) = | E_throw exp when is_value exp -> throw (value_of_exp exp) | E_throw exp -> step exp >>= fun exp' -> wrap (E_throw exp') + | E_id id -> + begin + let open Type_check in + gets >>= fun state -> + match Env.lookup_id id (env_of_annot annot) with + | Register _ -> + let exp = + try exp_of_value (Bindings.find id state.registers) with + | Not_found -> + let exp = mk_exp (E_app (mk_id ("undefined_" ^ string_of_typ (typ_of orig_exp)), [mk_exp (E_lit (mk_lit (L_unit)))])) in + Type_check.check_exp (env_of_annot annot) exp (typ_of orig_exp) + in + return exp + | _ -> failwith "id" + end + + | E_record (FES_aux (FES_Fexps (fexps, flag), fes_annot)) -> + let evaluated, unevaluated = Util.take_drop is_value_fexp fexps in + begin + match unevaluated with + | FE_aux (FE_Fexp (id, exp), fe_annot) :: fexps -> + step exp >>= fun exp' -> + wrap (E_record (FES_aux (FES_Fexps (evaluated @ FE_aux (FE_Fexp (id, exp'), fe_annot) :: fexps, flag), fes_annot))) + | [] -> + List.map value_of_fexp fexps + |> List.fold_left (fun record (field, v) -> StringMap.add field v record) StringMap.empty + |> (fun record -> V_record record) + |> exp_of_value + |> return + end + + | E_record_update (exp, fexps) when not (is_value exp) -> + step exp >>= fun exp' -> wrap (E_record_update (exp', fexps)) + | E_record_update (record, FES_aux (FES_Fexps (fexps, flag), fes_annot)) -> + let evaluated, unevaluated = Util.take_drop is_value_fexp fexps in + begin + match unevaluated with + | FE_aux (FE_Fexp (id, exp), fe_annot) :: fexps -> + step exp >>= fun exp' -> + wrap (E_record_update (record, FES_aux (FES_Fexps (evaluated @ FE_aux (FE_Fexp (id, exp'), fe_annot) :: fexps, flag), fes_annot))) + | [] -> + List.map value_of_fexp fexps + |> List.fold_left (fun record (field, v) -> StringMap.add field v record) (coerce_record (value_of_exp record)) + |> (fun record -> V_record record) + |> exp_of_value + |> return + end + + | E_assign (lexp, exp) when not (is_value exp) -> step exp >>= fun exp' -> wrap (E_assign (lexp, exp')) + | E_assign (LEXP_aux (LEXP_memory (id, args), _), exp) -> wrap (E_app (id, args @ [exp])) + | E_assign (LEXP_aux (LEXP_field (lexp, id), _), exp) -> + let open Type_check in + let lexp_exp = infer_exp (env_of_annot annot) (exp_of_lexp (strip_lexp lexp)) in + let ul = (Parse_ast.Unknown, None) in + let exp' = E_aux (E_record_update (lexp_exp, FES_aux (FES_Fexps ([FE_aux (FE_Fexp (id, exp), ul)], false), ul)), ul) in + wrap (E_assign (lexp, exp')) + | E_assign (LEXP_aux (LEXP_id id, _), exp) -> + begin + let open Type_check in + gets >>= fun state -> + match Env.lookup_id id (env_of_annot annot) with + | Register _ -> + puts { state with registers = Bindings.add id (value_of_exp exp) state.registers } >> wrap unit_exp + | _ -> failwith "Assign" + end + | E_assign _ -> assert false + | E_try (exp, pexps) when is_value exp -> return exp | E_try (exp, pexps) -> begin @@ -249,7 +345,7 @@ let rec step (E_aux (e_aux, annot)) = | E_sizeof _ | E_constraint _ -> assert false (* Must be re-written before interpreting *) - | _ -> assert false (* TODO *) + | _ -> failwith ("Unimplemented " ^ string_of_exp orig_exp) and combine _ v1 v2 = match (v1, v2) with @@ -258,6 +354,22 @@ and combine _ v1 v2 = | None, Some v2 -> Some v2 | Some v1, Some v2 -> failwith "Pattern binds same identifier twice!" +and exp_of_lexp (LEXP_aux (lexp_aux, _) as lexp) = + match lexp_aux with + | LEXP_id id -> mk_exp (E_id id) + | LEXP_memory (f, args) -> mk_exp (E_app (f, args)) + | LEXP_cast (typ, id) -> mk_exp (E_cast (typ, mk_exp (E_id id))) + | LEXP_tup lexps -> mk_exp (E_tuple (List.map exp_of_lexp lexps)) + | LEXP_vector (lexp, exp) -> mk_exp (E_vector_access (exp_of_lexp lexp, exp)) + | LEXP_vector_range (lexp, exp1, exp2) -> mk_exp (E_vector_subrange (exp_of_lexp lexp, exp1, exp2)) + | LEXP_field (lexp, id) -> mk_exp (E_field (exp_of_lexp lexp, id)) + +and lexp_assign (LEXP_aux (lexp_aux, _) as lexp) value = + print_endline ("Assigning: " ^ string_of_lexp lexp ^ " to " ^ string_of_value value |> Util.yellow |> Util.clear); + match lexp_aux with + | LEXP_id id -> Bindings.singleton id value + | _ -> failwith "Unhandled lexp_assign" + and pattern_match (P_aux (p_aux, _) as pat) value = print_endline ("Matching: " ^ string_of_pat pat ^ " with " ^ string_of_value value |> Util.yellow |> Util.clear); match p_aux with @@ -294,27 +406,32 @@ let rec get_fundef id (Defs defs) = | (DEF_fundef fdef) :: _ when Id.compare id (id_of_fundef fdef) = 0 -> fdef | _ :: defs -> get_fundef id (Defs defs) -let rec untilM p f x = - if p x then - return x - else - f (return x) >>= fun x' -> untilM p f x' - -type trace = - | Done of value - | Step of (Type_check.tannot exp) monad * (value -> (Type_check.tannot exp) monad) list - -let rec eval_exp ast m = - match m with - | Pure v when is_value v -> Done (value_of_exp v) - | Pure exp' -> - Pretty_print_sail2.pretty_sail stdout (Pretty_print_sail2.doc_exp exp'); - print_newline (); - Step (step exp', []) - | Yield (Call (id, vals, cont)) -> - print_endline ("Calling " ^ string_of_id id |> Util.cyan |> Util.clear); - let arg = if List.length vals != 1 then tuple_value vals else List.hd vals in - let body = exp_of_fundef (get_fundef id ast) arg in - Step (return body, [cont]) - | _ -> assert false - +type frame = + | Done of state * value + | Step of string * state * (Type_check.tannot exp) monad * (string * (value -> (Type_check.tannot exp) monad)) list + +let rec eval_frame ast = function + | Done (state, v) -> Done (state, v) + | Step (out, state, m, stack) -> + match (m, stack) with + | Pure v, [] when is_value v -> Done (state, value_of_exp v) + | Pure v, (head :: stack') when is_value v -> + print_endline ("Returning value: " ^ string_of_value (value_of_exp v) |> Util.cyan |> Util.clear); + Step (fst head, state, snd head (value_of_exp v), stack') + | Pure exp', _ -> + let out' = Pretty_print_sail2.to_string (Pretty_print_sail2.doc_exp exp') in + Step (out', state, step exp', stack) + | Yield (Call(id, vals, cont)), _ -> + print_endline ("Calling " ^ string_of_id id |> Util.cyan |> Util.clear); + let arg = if List.length vals != 1 then tuple_value vals else List.hd vals in + let body = exp_of_fundef (get_fundef id ast) arg in + Step ("", state, return body, (out, cont) :: stack) + | Yield (Gets cont), _ -> + eval_frame ast (Step (out, state, cont state, stack)) + | Yield (Puts (state', cont)), _ -> + eval_frame ast (Step (out, state', cont (), stack)) + | Yield (Early_return v), [] -> Done (state, v) + | Yield (Early_return v), (head :: stack') -> + print_endline ("Returning value: " ^ string_of_value v |> Util.cyan |> Util.clear); + Step (fst head, state, snd head v, stack') + | _ -> assert false diff --git a/src/isail.ml b/src/isail.ml index e1860451..b15e9d87 100644 --- a/src/isail.ml +++ b/src/isail.ml @@ -56,7 +56,7 @@ open Interpreter open Pretty_print_sail2 type mode = - | Evaluation of trace + | Evaluation of frame | Normal let current_mode = ref Normal @@ -127,26 +127,20 @@ let handle_input input = | _ -> print_endline ("Unrecognised command " ^ input) else if input <> "" then let exp = Type_check.infer_exp !interactive_env (Initial_check.exp_of_string Ast_util.dec_ord input) in - current_mode := Evaluation (eval_exp !interactive_ast (return exp)) + current_mode := Evaluation (eval_frame !interactive_ast (Step ("", initial_state, return exp, []))) else () end - | Evaluation trace -> + | Evaluation frame -> begin - match trace with - | Done v -> + match frame with + | Done (_, v) -> print_endline ("Result = " ^ Value.string_of_value v); current_mode := Normal - | Step (exp, stack) -> - let next = match eval_exp !interactive_ast exp with - | Step (exp', stack') -> Evaluation (Step (exp', stack' @ stack)) - | Done v when stack = [] -> - print_endline ("Result = " ^ Value.string_of_value v); - Normal - | Done v -> - print_endline ("Returning: " ^ Value.string_of_value v |> Util.cyan |> Util.clear); - Evaluation (Step (List.hd stack v, List.tl stack)) - in - current_mode := next + | Step (out, _, _, stack) -> + let sep = "-----------------------------------------------------" |> Util.blue |> Util.clear in + List.map fst stack |> List.rev |> List.iter (fun code -> print_endline code; print_endline sep); + print_endline out; + current_mode := Evaluation (eval_frame !interactive_ast frame) end diff --git a/src/pretty_print_sail2.ml b/src/pretty_print_sail2.ml index 71fcd587..7f91bbe5 100644 --- a/src/pretty_print_sail2.ml +++ b/src/pretty_print_sail2.ml @@ -531,3 +531,8 @@ let doc_defs (Defs(defs)) = let pp_defs f d = ToChannel.pretty 1. 80 f (doc_defs d) let pretty_sail f doc = ToChannel.pretty 1. 120 f doc + +let to_string doc = + let b = Buffer.create 120 in + ToBuffer.pretty 1. 120 b doc; + Buffer.contents b diff --git a/src/process_file.ml b/src/process_file.ml index 68e5786e..68b08fd4 100644 --- a/src/process_file.ml +++ b/src/process_file.ml @@ -237,4 +237,5 @@ let rewrite_undefined = rewrite [("undefined", fun x -> Rewrites.rewrite_undefin let rewrite_ast_lem = rewrite Rewrites.rewrite_defs_lem let rewrite_ast_ocaml = rewrite Rewrites.rewrite_defs_ocaml let rewrite_ast_sil = rewrite Rewrites.rewrite_defs_sil +let rewrite_ast_interpreter = rewrite Rewrites.rewrite_defs_interpreter let rewrite_ast_check = rewrite Rewrites.rewrite_defs_check diff --git a/src/process_file.mli b/src/process_file.mli index f99bdf54..5477af86 100644 --- a/src/process_file.mli +++ b/src/process_file.mli @@ -57,6 +57,7 @@ val rewrite_undefined: Type_check.tannot Ast.defs -> Type_check.tannot Ast.defs val rewrite_ast_lem : Type_check.tannot Ast.defs -> Type_check.tannot Ast.defs val rewrite_ast_ocaml : Type_check.tannot Ast.defs -> Type_check.tannot Ast.defs val rewrite_ast_sil : Type_check.tannot Ast.defs -> Type_check.tannot Ast.defs +val rewrite_ast_interpreter : Type_check.tannot Ast.defs -> Type_check.tannot Ast.defs val rewrite_ast_check : Type_check.tannot Ast.defs -> Type_check.tannot Ast.defs val load_file_no_check : Ast.order -> string -> unit Ast.defs diff --git a/src/rewrites.ml b/src/rewrites.ml index 8c0526fe..f735aef6 100644 --- a/src/rewrites.ml +++ b/src/rewrites.ml @@ -2970,6 +2970,12 @@ let rewrite_defs_ocaml = [ (* ("separate_numbs", rewrite_defs_separate_numbs) *) ] +let rewrite_defs_interpreter = [ + ("constraint", rewrite_constraint); + ("trivial_sizeof", rewrite_trivial_sizeof); + ("sizeof", rewrite_sizeof); + ] + let rewrite_defs_sil = [ ("top_sort_defs", top_sort_defs); ("tuple_vector_assignments", rewrite_tuple_vector_assignments); diff --git a/src/rewrites.mli b/src/rewrites.mli index ce24a4c4..db82f679 100644 --- a/src/rewrites.mli +++ b/src/rewrites.mli @@ -61,6 +61,9 @@ val rewrite_defs_ocaml : (string * (tannot defs -> tannot defs)) list val rewrite_defs_lem : (string * (tannot defs -> tannot defs)) list (* Perform rewrites to sail intermediate language *) +val rewrite_defs_interpreter : (string * (tannot defs -> tannot defs)) list + +(* Perform rewrites to sail intermediate language *) val rewrite_defs_sil : (string * (tannot defs -> tannot defs)) list (* This is a special rewriter pass that checks AST invariants without diff --git a/src/sail.ml b/src/sail.ml index 519ec916..d4e2526e 100644 --- a/src/sail.ml +++ b/src/sail.ml @@ -216,7 +216,7 @@ let main() = begin (if !(opt_interactive) then - (interactive_ast := ast; interactive_env := type_envs) + (interactive_ast := Process_file.rewrite_ast_interpreter ast; interactive_env := type_envs) else ()); (if !(opt_sanity) then diff --git a/src/sail_lib.ml b/src/sail_lib.ml new file mode 100644 index 00000000..b4dffba9 --- /dev/null +++ b/src/sail_lib.ml @@ -0,0 +1,477 @@ +module Big_int = Nat_big_num + +type 'a return = { return : 'b . 'a -> 'b } + +let opt_trace = ref false + +let trace_depth = ref 0 +let random = ref false + +let sail_call (type t) (f : _ -> t) = + let module M = + struct exception Return of t end + in + let return = { return = (fun x -> raise (M.Return x)) } in + try + f return + with M.Return x -> x + +let trace str = + if !opt_trace + then + begin + if !trace_depth < 0 then trace_depth := 0 else (); + prerr_endline (String.make (!trace_depth * 2) ' ' ^ str) + end + else () + +let trace_write name str = + trace ("Write: " ^ name ^ " " ^ str) + +let trace_read name str = + trace ("Read: " ^ name ^ " " ^ str) + +let sail_trace_call (type t) (name : string) (in_string : string) (string_of_out : t -> string) (f : _ -> t) = + let module M = + struct exception Return of t end + in + let return = { return = (fun x -> raise (M.Return x)) } in + trace ("Call: " ^ name ^ " " ^ in_string); + incr trace_depth; + let result = try f return with M.Return x -> x in + decr trace_depth; + trace ("Return: " ^ string_of_out result); + result + +let trace_call str = + trace str; incr trace_depth + +type bit = B0 | B1 + +let and_bit = function + | B1, B1 -> B1 + | _, _ -> B0 + +let or_bit = function + | B0, B0 -> B0 + | _, _ -> B1 + +let xor_bit = function + | B1, B0 -> B1 + | B0, B1 -> B1 + | _, _ -> B0 + +let and_vec (xs, ys) = + assert (List.length xs = List.length ys); + List.map2 (fun x y -> and_bit (x, y)) xs ys + +let and_bool (b1, b2) = b1 && b2 + +let or_vec (xs, ys) = + assert (List.length xs = List.length ys); + List.map2 (fun x y -> or_bit (x, y)) xs ys + +let or_bool (b1, b2) = b1 || b2 + +let xor_vec (xs, ys) = + assert (List.length xs = List.length ys); + List.map2 (fun x y -> xor_bit (x, y)) xs ys + +let xor_bool (b1, b2) = (b1 || b2) && (b1 != b2) + +let undefined_bit () = + if !random + then (if Random.bool () then B0 else B1) + else B0 + +let undefined_bool () = + if !random then Random.bool () else false + +let rec undefined_vector (start_index, len, item) = + if Big_int.equal len Big_int.zero + then [] + else item :: undefined_vector (start_index, Big_int.sub len (Big_int.of_int 1), item) + +let undefined_string () = "" + +let undefined_unit () = () + +let undefined_int () = + if !random then Big_int.of_int (Random.int 0xFFFF) else Big_int.zero + +let undefined_nat () = Big_int.zero + +let undefined_range (lo, hi) = lo + +let internal_pick list = + if !random + then List.nth list (Random.int (List.length list)) + else List.nth list 0 + +let eq_int (n, m) = Big_int.equal n m + +let rec drop n xs = + match n, xs with + | 0, xs -> xs + | n, [] -> [] + | n, (x :: xs) -> drop (n -1) xs + +let rec take n xs = + match n, xs with + | 0, xs -> [] + | n, (x :: xs) -> x :: take (n - 1) xs + | n, [] -> [] + +let subrange (list, n, m) = + let n = Big_int.to_int n in + let m = Big_int.to_int m in + List.rev (take (n - (m - 1)) (drop m (List.rev list))) + +let slice (list, n, m) = + let n = Big_int.to_int n in + let m = Big_int.to_int m in + List.rev (take m (drop n (List.rev list))) + +let eq_list (xs, ys) = List.for_all2 (fun x y -> x = y) xs ys + +let access (xs, n) = List.nth (List.rev xs) (Big_int.to_int n) + +let append (xs, ys) = xs @ ys + +let update (xs, n, x) = + let n = (List.length xs - Big_int.to_int n) - 1 in + take n xs @ [x] @ drop (n + 1) xs + +let update_subrange (xs, n, m, ys) = + let rec aux xs o = function + | [] -> xs + | (y :: ys) -> aux (update (xs, o, y)) (Big_int.sub o (Big_int.of_int 1)) ys + in + aux xs n ys + + +let length xs = Big_int.of_int (List.length xs) + +let big_int_of_bit = function + | B0 -> Big_int.zero + | B1 -> (Big_int.of_int 1) + +let uint xs = + let uint_bit x (n, pos) = + Big_int.add n (Big_int.mul (Big_int.pow_int_positive 2 pos) (big_int_of_bit x)), pos + 1 + in + fst (List.fold_right uint_bit xs (Big_int.zero, 0)) + +let sint = function + | [] -> Big_int.zero + | [msb] -> Big_int.negate (big_int_of_bit msb) + | msb :: xs -> + let msb_pos = List.length xs in + let complement = + Big_int.negate (Big_int.mul (Big_int.pow_int_positive 2 msb_pos) (big_int_of_bit msb)) + in + Big_int.add complement (uint xs) + +let add (x, y) = Big_int.add x y +let sub (x, y) = Big_int.sub x y +let mult (x, y) = Big_int.mul x y +let quotient (x, y) = Big_int.div x y +let modulus (x, y) = Big_int.modulus x y + +let add_bit_with_carry (x, y, carry) = + match x, y, carry with + | B0, B0, B0 -> B0, B0 + | B0, B1, B0 -> B1, B0 + | B1, B0, B0 -> B1, B0 + | B1, B1, B0 -> B0, B1 + | B0, B0, B1 -> B1, B0 + | B0, B1, B1 -> B0, B1 + | B1, B0, B1 -> B0, B1 + | B1, B1, B1 -> B1, B1 + +let sub_bit_with_carry (x, y, carry) = + match x, y, carry with + | B0, B0, B0 -> B0, B0 + | B0, B1, B0 -> B0, B1 + | B1, B0, B0 -> B1, B0 + | B1, B1, B0 -> B0, B0 + | B0, B0, B1 -> B1, B0 + | B0, B1, B1 -> B0, B0 + | B1, B0, B1 -> B1, B1 + | B1, B1, B1 -> B1, B0 + +let not_bit = function + | B0 -> B1 + | B1 -> B0 + +let not_vec xs = List.map not_bit xs + +let add_vec_carry (xs, ys) = + assert (List.length xs = List.length ys); + let (carry, result) = + List.fold_right2 (fun x y (c, result) -> let (z, c) = add_bit_with_carry (x, y, c) in (c, z :: result)) xs ys (B0, []) + in + carry, result + +let add_vec (xs, ys) = snd (add_vec_carry (xs, ys)) + +let rec replicate_bits (bits, n) = + if Big_int.less_equal n Big_int.zero + then [] + else bits @ replicate_bits (bits, Big_int.sub n (Big_int.of_int 1)) + +let identity x = x + +let rec bits_of_big_int bit n = + if not (Big_int.equal bit Big_int.zero) + then + begin + if Big_int.greater (Big_int.div n bit) Big_int.zero + then B1 :: bits_of_big_int (Big_int.div bit (Big_int.of_int 2)) (Big_int.sub n bit) + else B0 :: bits_of_big_int (Big_int.div bit (Big_int.of_int 2)) n + end + else [] + +let add_vec_int (v, n) = + let n_bits = bits_of_big_int (Big_int.pow_int_positive 2 (List.length v - 1)) n in + add_vec(v, n_bits) + +let sub_vec (xs, ys) = add_vec (xs, add_vec_int (not_vec ys, (Big_int.of_int 1))) + +let sub_vec_int (v, n) = + let n_bits = bits_of_big_int (Big_int.pow_int_positive 2 (List.length v - 1)) n in + sub_vec(v, n_bits) + +let get_slice_int (n, m, o) = + let bits = bits_of_big_int (Big_int.pow_int_positive 2 (Big_int.add n o |> Big_int.to_int)) (Big_int.abs m) in + let bits = + if Big_int.less m Big_int.zero + then sub_vec (List.map (fun _ -> B0) bits, bits) + else bits + in + let slice = List.rev (take (Big_int.to_int n) (drop (Big_int.to_int o) (List.rev bits))) in + assert (Big_int.equal (Big_int.of_int (List.length slice)) n); + slice + +let hex_char = function + | '0' -> [B0; B0; B0; B0] + | '1' -> [B0; B0; B0; B1] + | '2' -> [B0; B0; B1; B0] + | '3' -> [B0; B0; B1; B1] + | '4' -> [B0; B1; B0; B0] + | '5' -> [B0; B1; B0; B1] + | '6' -> [B0; B1; B1; B0] + | '7' -> [B0; B1; B1; B1] + | '8' -> [B1; B0; B0; B0] + | '9' -> [B1; B0; B0; B1] + | 'A' | 'a' -> [B1; B0; B1; B0] + | 'B' | 'b' -> [B1; B0; B1; B1] + | 'C' | 'c' -> [B1; B1; B0; B0] + | 'D' | 'd' -> [B1; B1; B0; B1] + | 'E' | 'e' -> [B1; B1; B1; B0] + | 'F' | 'f' -> [B1; B1; B1; B1] + +let list_of_string s = + let rec aux i acc = + if i < 0 then acc + else aux (i-1) (s.[i] :: acc) + in aux (String.length s - 1) [] + +let bits_of_string str = + List.concat (List.map hex_char (list_of_string str)) + +let concat_str (str1, str2) = str1 ^ str2 + +let rec break n = function + | [] -> [] + | (_ :: _ as xs) -> [take n xs] @ break n (drop n xs) + +let string_of_bit = function + | B0 -> "0" + | B1 -> "1" + +let string_of_hex = function + | [B0; B0; B0; B0] -> "0" + | [B0; B0; B0; B1] -> "1" + | [B0; B0; B1; B0] -> "2" + | [B0; B0; B1; B1] -> "3" + | [B0; B1; B0; B0] -> "4" + | [B0; B1; B0; B1] -> "5" + | [B0; B1; B1; B0] -> "6" + | [B0; B1; B1; B1] -> "7" + | [B1; B0; B0; B0] -> "8" + | [B1; B0; B0; B1] -> "9" + | [B1; B0; B1; B0] -> "A" + | [B1; B0; B1; B1] -> "B" + | [B1; B1; B0; B0] -> "C" + | [B1; B1; B0; B1] -> "D" + | [B1; B1; B1; B0] -> "E" + | [B1; B1; B1; B1] -> "F" + +let string_of_bits bits = + if List.length bits mod 4 == 0 + then "0x" ^ String.concat "" (List.map string_of_hex (break 4 bits)) + else "0b" ^ String.concat "" (List.map string_of_bit bits) + +let hex_slice (str, n, m) = + let bits = List.concat (List.map hex_char (list_of_string (String.sub str 2 (String.length str - 2)))) in + let padding = replicate_bits([B0], n) in + let bits = padding @ bits in + let slice = List.rev (take (Big_int.to_int n) (drop (Big_int.to_int m) (List.rev bits))) in + slice + +let putchar n = + print_char (char_of_int (Big_int.to_int n)); + flush stdout + +let rec bits_of_int bit n = + if bit <> 0 + then + begin + if n / bit > 0 + then B1 :: bits_of_int (bit / 2) (n - bit) + else B0 :: bits_of_int (bit / 2) n + end + else [] + +let byte_of_int n = bits_of_int 128 n + +module BigIntHash = + struct + type t = Big_int.num + let equal i j = Big_int.equal i j + let hash i = Hashtbl.hash i + end + +module RAM = Hashtbl.Make(BigIntHash) + +let ram : int RAM.t = RAM.create 256 + +let write_ram' (addr_size, data_size, hex_ram, addr, data) = + let data = List.map (fun byte -> Big_int.to_int (uint byte)) (break 8 data) in + let rec write_byte i byte = + trace (Printf.sprintf "Store: %s 0x%02X" (Big_int.to_string (Big_int.add addr (Big_int.of_int i))) byte); + RAM.add ram (Big_int.add addr (Big_int.of_int i)) byte + in + List.iteri write_byte (List.rev data) + +let write_ram (addr_size, data_size, hex_ram, addr, data) = + write_ram' (addr_size, data_size, hex_ram, uint addr, data) + +let wram addr byte = + RAM.add ram addr byte + +let read_ram (addr_size, data_size, hex_ram, addr) = + let addr = uint addr in + let rec read_byte i = + if Big_int.equal i Big_int.zero + then [] + else + begin + let loc = Big_int.sub (Big_int.add addr i) (Big_int.of_int 1) in + let byte = try RAM.find ram loc with Not_found -> 0 in + trace (Printf.sprintf "Load: %s 0x%02X" (Big_int.to_string loc) byte); + byte_of_int byte @ read_byte (Big_int.sub i (Big_int.of_int 1)) + end + in + read_byte data_size + +let rec reverse_endianness bits = + if List.length bits <= 8 then bits else + reverse_endianness (drop 8 bits) @ (take 8 bits) + +(* FIXME: Casts can't be externed *) +let zcast_unit_vec x = [x] + +let shl_int (n, m) = Big_int.shift_left n (Big_int.to_int m) +let shr_int (n, m) = Big_int.shift_right n (Big_int.to_int m) + +let debug (str1, n, str2, v) = prerr_endline (str1 ^ Big_int.to_string n ^ str2 ^ string_of_bits v) + +let eq_string (str1, str2) = String.compare str1 str2 == 0 + +let lt_int (x, y) = Big_int.less x y + +let set_slice (out_len, slice_len, out, n, slice) = + let out = update_subrange(out, Big_int.add n (Big_int.of_int (List.length slice - 1)), n, slice) in + assert (List.length out = Big_int.to_int out_len); + out + +let set_slice_int (_, _, _, _) = assert false + +(* +let eq_real (x, y) = Num.eq_num x y +let lt_real (x, y) = Num.lt_num x y +let gt_real (x, y) = Num.gt_num x y +let lteq_real (x, y) = Num.le_num x y +let gteq_real (x, y) = Num.ge_num x y + +let round_down x = Num.big_int_of_num (Num.floor_num x) +let round_up x = Num.big_int_of_num (Num.ceiling_num x) +let quotient_real (x, y) = Num.div_num x y +let mult_real (x, y) = Num.mult_num x y +let real_power (x, y) = Num.power_num x (Num.num_of_big_int y) +let add_real (x, y) = Num.add_num x y +let sub_real (x, y) = Num.sub_num x y + +let abs_real x = Num.abs_num x + *) + +let lt (x, y) = Big_int.less x y +let gt (x, y) = Big_int.greater x y +let lteq (x, y) = Big_int.less_equal x y +let gteq (x, y) = Big_int.greater_equal x y + +let pow2 x = Big_int.pow_int x 2 + +let max_int (x, y) = Big_int.max x y +let min_int (x, y) = Big_int.min x y +let abs_int x = Big_int.abs x + +(* +let undefined_real () = Num.num_of_int 0 + +let real_of_string str = + try + let point = String.index str '.' in + let whole = Num.num_of_string (String.sub str 0 point) in + let frac_str = String.sub str (point + 1) (String.length str - (point + 1)) in + let frac = Num.div_num (Num.num_of_string frac_str) (Num.num_of_big_int (Big_int.pow_int_positive 10 (String.length frac_str))) in + Num.add_num whole frac + with + | Not_found -> Num.num_of_string str + +(* Not a very good sqrt implementation *) +let sqrt_real x = real_of_string (string_of_float (sqrt (Num.float_of_num x))) + *) + +let print_int (str, x) = + print_endline (str ^ Big_int.to_string x) + +let print_bits (str, xs) = + print_endline (str ^ string_of_bits xs) + +let reg_deref r = !r + +let string_of_zbit = function + | B0 -> "0" + | B1 -> "1" +let string_of_znat n = Big_int.to_string n +let string_of_zint n = Big_int.to_string n +let string_of_zunit () = "()" +let string_of_zbool = function + | true -> "true" + | false -> "false" +(* let string_of_zreal r = Num.string_of_num r *) +let string_of_zstring str = "\"" ^ String.escaped str ^ "\"" + +let rec string_of_list sep string_of = function + | [] -> "" + | [x] -> string_of x + | x::ls -> (string_of x) ^ sep ^ (string_of_list sep string_of ls) + +let zero_extend (vec, n) = + let m = Big_int.to_int n in + if m <= List.length vec + then take m vec + else replicate_bits ([B0], Big_int.of_int (m - List.length vec)) @ vec diff --git a/src/util.ml b/src/util.ml index bd083a8b..51ed8926 100644 --- a/src/util.ml +++ b/src/util.ml @@ -392,4 +392,5 @@ let green str = termcode 92 ^ str let yellow str = termcode 93 ^ str let red str = termcode 91 ^ str let cyan str = termcode 96 ^ str +let blue str = termcode 94 ^ str let clear str = str ^ termcode 0 diff --git a/src/util.mli b/src/util.mli index bdf6e594..39bc8a19 100644 --- a/src/util.mli +++ b/src/util.mli @@ -238,4 +238,5 @@ val green : string -> string val red : string -> string val yellow : string -> string val cyan : string -> string +val blue : string -> string val clear : string -> string diff --git a/src/value.ml b/src/value.ml index f49b230c..e42f68cb 100644 --- a/src/value.ml +++ b/src/value.ml @@ -50,33 +50,19 @@ module Big_int = Nat_big_num -type bit = B0 | B1 +module StringMap = Map.Make(String) type value = | V_vector of value list | V_list of value list | V_int of Big_int.num | V_bool of bool - | V_bit of bit + | V_bit of Sail_lib.bit | V_tuple of value list | V_unit | V_string of string | V_ctor of string * value list - -let rec string_of_value = function - | V_vector _ -> "VEC" - | V_bool true -> "true" - | V_bool false -> "false" - | V_bit B0 -> "bitzero" - | V_bit B1 -> "bitone" - | V_int n -> Big_int.to_string n - | V_tuple vals -> "(" ^ Util.string_of_list ", " string_of_value vals ^ ")" - | V_list vals -> "[" ^ Util.string_of_list ", " string_of_value vals ^ "]" - | V_unit -> "()" - | V_string str -> "\"" ^ str ^ "\"" - | V_ctor (str, vals) -> str ^ "(" ^ Util.string_of_list ", " string_of_value vals ^ ")" - -let eq_value v1 v2 = string_of_value v1 = string_of_value v2 + | V_record of value StringMap.t let coerce_bit = function | V_bit b -> b @@ -90,6 +76,10 @@ let coerce_bool = function | V_bool b -> b | _ -> assert false +let coerce_record = function + | V_record record -> record + | _ -> assert false + let and_bool = function | [v1; v2] -> V_bool (coerce_bool v1 && coerce_bool v2) | _ -> assert false @@ -98,12 +88,14 @@ let or_bool = function | [v1; v2] -> V_bool (coerce_bool v1 || coerce_bool v2) | _ -> assert false -let print = function - | [v] -> print_endline (string_of_value v |> Util.red |> Util.clear); V_unit - | _ -> assert false - let tuple_value (vs : value list) : value = V_tuple vs +let mk_vector (bits : Sail_lib.bit list) : value = V_vector (List.map (fun bit -> V_bit bit) bits) + +let coerce_bit = function + | V_bit b -> b + | _ -> assert false + let coerce_tuple = function | V_tuple vs -> vs | _ -> assert false @@ -111,6 +103,11 @@ let coerce_tuple = function let coerce_listlike = function | V_tuple vs -> vs | V_list vs -> vs + | V_unit -> [] + | _ -> assert false + +let coerce_int = function + | V_int i -> i | _ -> assert false let coerce_cons = function @@ -118,9 +115,140 @@ let coerce_cons = function | V_list [] -> None | _ -> assert false +let coerce_gv = function + | V_vector vs -> vs + | _ -> assert false + +let coerce_bv = function + | V_vector vs -> List.map coerce_bit vs + | _ -> assert false + +let coerce_string = function + | V_string str -> str + | _ -> assert false + let unit_value = V_unit -module StringMap = Map.Make(String) +let value_eq_int = function + | [v1; v2] -> V_bool (Sail_lib.eq_int (coerce_int v1, coerce_int v2)) + | _ -> failwith "value eq_int" + +let value_lteq = function + | [v1; v2] -> V_bool (Sail_lib.lteq (coerce_int v1, coerce_int v2)) + | _ -> failwith "value lteq" + +let value_gteq = function + | [v1; v2] -> V_bool (Sail_lib.gteq (coerce_int v1, coerce_int v2)) + | _ -> failwith "value gteq" + +let value_lt = function + | [v1; v2] -> V_bool (Sail_lib.lt (coerce_int v1, coerce_int v2)) + | _ -> failwith "value lt" + +let value_gt = function + | [v1; v2] -> V_bool (Sail_lib.gt (coerce_int v1, coerce_int v2)) + | _ -> failwith "value gt" + +let value_eq_list = function + | [v1; v2] -> V_bool (Sail_lib.eq_list (coerce_bv v1, coerce_bv v2)) + | _ -> failwith "value eq_list" + +let value_eq_string = function + | [v1; v2] -> V_bool (Sail_lib.eq_string (coerce_string v1, coerce_string v2)) + | _ -> failwith "value eq_string" + +let value_length = function + | [v] -> V_int (coerce_gv v |> List.length |> Big_int.of_int) + | _ -> failwith "value length" + +let value_subrange = function + | [v1; v2; v3] -> mk_vector (Sail_lib.subrange (coerce_bv v1, coerce_int v2, coerce_int v3)) + | _ -> failwith "value subrange" + +let value_access = function + | [v1; v2] -> Sail_lib.access (coerce_gv v1, coerce_int v2) + | _ -> failwith "value access" + +let value_update = function + | [v1; v2; v3] -> V_vector (Sail_lib.update (coerce_gv v1, coerce_int v2, v3)) + | _ -> failwith "value update" + +let value_update_subrange = function + | [v1; v2; v3; v4] -> mk_vector (Sail_lib.update_subrange (coerce_bv v1, coerce_int v2, coerce_int v3, coerce_bv v4)) + | _ -> failwith "value update_subrange" + +let value_append = function + | [v1; v2] -> V_vector (coerce_gv v1 @ coerce_gv v2) + | _ -> failwith "value append" + +let value_not = function + | [v] -> V_bool (not (coerce_bool v)) + | _ -> failwith "value not" + +let value_not_vec = function + | [v] -> mk_vector (Sail_lib.not_vec (coerce_bv v)) + | _ -> failwith "value not_vec" + +let value_and_vec = function + | [v1; v2] -> mk_vector (Sail_lib.and_vec (coerce_bv v1, coerce_bv v2)) + | _ -> failwith "value not_vec" + +let value_or_vec = function + | [v1; v2] -> mk_vector (Sail_lib.or_vec (coerce_bv v1, coerce_bv v2)) + | _ -> failwith "value not_vec" + +let value_uint = function + | [v] -> V_int (Sail_lib.uint (coerce_bv v)) + | _ -> failwith "value uint" + +let value_sint = function + | [v] -> V_int (Sail_lib.sint (coerce_bv v)) + | _ -> failwith "value sint" + +let value_get_slice_int = function + | [v1; v2; v3] -> mk_vector (Sail_lib.get_slice_int (coerce_int v1, coerce_int v2, coerce_int v3)) + | _ -> failwith "value get_slice_int" + +let value_add = function + | [v1; v2] -> V_int (Sail_lib.add (coerce_int v1, coerce_int v2)) + | _ -> failwith "value add" + +let value_sub = function + | [v1; v2] -> V_int (Sail_lib.sub (coerce_int v1, coerce_int v2)) + | _ -> failwith "value sub" + +let value_replicate_bits = function + | [v1; v2] -> mk_vector (Sail_lib.replicate_bits (coerce_bv v1, coerce_int v2)) + | _ -> failwith "value replicate_bits" + +let rec string_of_value = function + | V_vector vs -> Sail_lib.string_of_bits (List.map coerce_bit vs) + | V_bool true -> "true" + | V_bool false -> "false" + | V_bit B0 -> "bitzero" + | V_bit B1 -> "bitone" + | V_int n -> Big_int.to_string n + | V_tuple vals -> "(" ^ Util.string_of_list ", " string_of_value vals ^ ")" + | V_list vals -> "[" ^ Util.string_of_list ", " string_of_value vals ^ "]" + | V_unit -> "()" + | V_string str -> "\"" ^ str ^ "\"" + | V_ctor (str, vals) -> str ^ "(" ^ Util.string_of_list ", " string_of_value vals ^ ")" + | V_record record -> + "{" ^ Util.string_of_list ", " (fun (field, v) -> field ^ "=" ^ string_of_value v) (StringMap.bindings record) ^ "}" + +let eq_value v1 v2 = string_of_value v1 = string_of_value v2 + +let value_eq_anything = function + | [v1; v2] -> V_bool (eq_value v1 v2) + | _ -> failwith "value eq_anything" + +let value_print = function + | [v] -> print_endline (string_of_value v |> Util.red |> Util.clear); V_unit + | _ -> assert false + +let value_undefined_vector = function + | [v1; v2; v3] -> V_vector (Sail_lib.undefined_vector (coerce_int v1, coerce_int v2, v3)) + | _ -> failwith "value undefined_vector" let primops = List.fold_left @@ -128,5 +256,34 @@ let primops = StringMap.empty [ ("and_bool", and_bool); ("or_bool", or_bool); - ("print_endline", print); + ("print_endline", value_print); + ("prerr_endline", value_print); + ("string_of_bits", fun vs -> V_string (string_of_value (List.hd vs))); + ("eq_int", value_eq_int); + ("lteq", value_lteq); + ("gteq", value_gteq); + ("lt", value_lt); + ("gt", value_gt); + ("eq_list", value_eq_list); + ("eq_string", value_eq_string); + ("eq_anything", value_eq_anything); + ("length", value_length); + ("subrange", value_subrange); + ("access", value_access); + ("update", value_update); + ("update_subrange", value_update_subrange); + ("append", value_append); + ("not", value_not); + ("not_vec", value_not_vec); + ("and_vec", value_and_vec); + ("or_vec", value_or_vec); + ("uint", value_uint); + ("sint", value_sint); + ("get_slice_int", value_get_slice_int); + ("add", value_add); + ("sub", value_sub); + ("undefined_bit", fun _ -> V_bit Sail_lib.B0); + ("undefined_vector", value_undefined_vector); + ("replicate_bits", value_replicate_bits); + ("Elf_loader.elf_entry", fun _ -> V_int (Big_int.of_int 0)); ] |
