diff options
Diffstat (limited to 'src')
| -rw-r--r-- | src/lem_interp/interp.lem | 90 | ||||
| -rw-r--r-- | src/lem_interp/interp_inter_imp.lem | 1 | ||||
| -rw-r--r-- | src/lem_interp/run_interp.ml | 10 | ||||
| -rw-r--r-- | src/type_internal.ml | 119 | ||||
| -rw-r--r-- | src/type_internal.mli | 1 |
5 files changed, 153 insertions, 68 deletions
diff --git a/src/lem_interp/interp.lem b/src/lem_interp/interp.lem index 135fdb8e..f0770d4e 100644 --- a/src/lem_interp/interp.lem +++ b/src/lem_interp/interp.lem @@ -8,6 +8,8 @@ open import Interp_ast type tannot = maybe (t * tag * list nec * effect) +let get_exp_l (E_aux e (l,annot)) = l + val pure : effect let pure = Effect_aux(Effect_set []) Unknown @@ -667,6 +669,8 @@ let resolve_outcome to_match value_thunk action_thunk = let update_stack (Action act stack) fn = Action act (fn stack) +let debug_out e tl lm le = (Action (Debug (get_exp_l e)) (Thunk_frame e tl le lm Top),lm,le) + (*Interpret a list of expressions, tracking local state but evaluating in the same scope (i.e. not tracking env) *) let rec exp_list mode t_level build_e build_v l_env l_mem vals exps = match exps with @@ -740,9 +744,11 @@ and interp_main mode t_level l_env l_mem (E_aux exp (l,annot)) = | E_if cond thn els -> resolve_outcome (interp_main mode t_level l_env l_mem cond) (fun value lm le -> - match value with - | V_lit(L_aux L_true _) -> interp_main mode t_level l_env lm thn - | _ -> interp_main mode t_level l_env lm els end) + match (value,mode.eager_eval) with + | (V_lit(L_aux L_true _),true) -> interp_main mode t_level l_env lm thn + | (V_lit(L_aux L_true _),false) -> debug_out thn t_level lm l_env + | (_,true) -> interp_main mode t_level l_env lm els + | (_,false) -> debug_out els t_level lm l_env end) (fun a -> update_stack a (add_to_top_frame (fun c -> (E_aux (E_if c thn els) (l,annot))))) | E_for id from to_ by ((Ord_aux o _) as order) exp -> let is_inc = match o with @@ -764,24 +770,27 @@ and interp_main mode t_level l_env l_mem (E_aux exp (l,annot)) = if (from_num = to_num) then (Value(V_lit (L_aux L_unit l)) Tag_empty,lm,le) else - let (ftyp,ttyp,btyp) = (val_typ fval,val_typ tval,val_typ bval) in - interp_main mode t_level le lm - (E_aux - (E_block - [(E_aux (E_let - (LB_aux (LB_val_implicit - (P_aux (P_id id) (fl,val_annot ftyp)) - (E_aux (E_lit(L_aux(L_num from_num) fl)) (fl,val_annot ftyp))) - (Unknown,val_annot ftyp)) - exp) (l,annot)); - (E_aux (E_for id - (if is_inc - then (E_aux (E_lit (L_aux (L_num (from_num + by_num)) fl)) (fl,val_annot (combine_typs [ftyp;ttyp]))) - else (E_aux (E_lit (L_aux (L_num (from_num - by_num)) fl)) (fl,val_annot (combine_typs [ttyp;ftyp])))) - (E_aux (E_lit (L_aux (L_num to_num) tl)) (tl,val_annot ttyp)) - (E_aux (E_lit (L_aux (L_num by_num) bl)) (bl,val_annot btyp)) - order exp) (l,annot))]) - (l,annot)) + let (ftyp,ttyp,btyp) = (val_typ fval,val_typ tval,val_typ bval) in + let e = (E_aux + (E_block + [(E_aux (E_let + (LB_aux (LB_val_implicit + (P_aux (P_id id) (fl,val_annot ftyp)) + (E_aux (E_lit(L_aux(L_num from_num) fl)) (fl,val_annot ftyp))) + (Unknown,val_annot ftyp)) + exp) (l,annot)); + (E_aux (E_for id + (if is_inc + then (E_aux (E_lit (L_aux (L_num (from_num + by_num)) fl)) (fl,val_annot (combine_typs [ftyp;ttyp]))) + else (E_aux (E_lit (L_aux (L_num (from_num - by_num)) fl)) (fl,val_annot (combine_typs [ttyp;ftyp])))) + (E_aux (E_lit (L_aux (L_num to_num) tl)) (tl,val_annot ttyp)) + (E_aux (E_lit (L_aux (L_num by_num) bl)) (bl,val_annot btyp)) + order exp) (l,annot))]) + (l,annot)) in + if mode.eager_eval + then interp_main mode t_level le lm e + else debug_out e t_level lm le + | _ -> (Error l "internal error: by must be a number",lm,le) end) (fun a -> update_stack a (add_to_top_frame (fun b -> @@ -801,7 +810,10 @@ and interp_main mode t_level l_env l_mem (E_aux exp (l,annot)) = (fun v lm le -> match find_case pats v with | Nothing -> (Error l "No matching patterns in case",lm,le) - | Just (env,exp) -> interp_main mode t_level (env++l_env) lm exp + | Just (env,exp) -> + if mode.eager_eval + then interp_main mode t_level (env++l_env) lm exp + else debug_out exp t_level lm (env++l_env) end) (fun a -> update_stack a (add_to_top_frame (fun e -> (E_aux (E_case e pats) (l,annot))))) | E_record(FES_aux (FES_Fexps fexps _) fes_annot) -> @@ -850,7 +862,7 @@ and interp_main mode t_level l_env l_mem (E_aux exp (l,annot)) = | V_record t fexps -> match in_env fexps id with | Just v -> (Value v Tag_empty,lm,l_env) | Nothing -> (Error l "Field not found in record",lm,le) end - | _ -> (Error l "Field access of vectors not implemented",lm,le) + | _ -> (Error l "Field access of vectors not implemented",lm,le) end ) (fun a -> match (exp,a) with @@ -985,7 +997,7 @@ and interp_main mode t_level l_env l_mem (E_aux exp (l,annot)) = | _ -> false end in exp_list mode t_level (fun es -> (E_aux (E_vector_indexed (map2 (fun i e -> (i,e)) indexes es) default) (l,annot))) (fun vals -> V_vector (List_extra.head indexes) is_inc vals) l_env l_mem [] exps - | E_block(exps) -> interp_block mode t_level l_env l_env l_mem exps + | E_block(exps) -> interp_block mode t_level l_env l_env l_mem l annot exps | E_app f args -> (match (exp_list mode t_level (fun es -> E_aux (E_app f es) (l,annot)) (fun vs -> match vs with | [] -> V_lit (L_aux L_unit l) | [v] -> v | vs -> V_tuple vs end) @@ -1000,8 +1012,8 @@ and interp_main mode t_level l_env l_mem (E_aux exp (l,annot)) = | Nothing -> (Error l (String.stringAppend "No matching pattern for function " name ),l_mem,l_env) | Just(env,exp) -> - resolve_outcome (interp_main mode t_level env l_mem exp) - (fun ret lm le -> (Value ret Tag_empty, lm,l_env)) + resolve_outcome (interp_main mode t_level env emem exp) + (fun ret lm le -> (Value ret Tag_empty, l_mem,l_env)) (fun a -> update_stack a (fun stack -> (Hole_frame (id_of_string "0") (E_aux (E_id (Id_aux (Id "0") l)) (l,annot)) t_level l_env l_mem stack))) end) @@ -1013,8 +1025,8 @@ and interp_main mode t_level l_env l_mem (E_aux exp (l,annot)) = | Nothing -> (Error l (String.stringAppend "No matching pattern for function " name ),l_mem,l_env) | Just(env,exp) -> - resolve_outcome (interp_main mode t_level env l_mem exp) - (fun ret lm le -> (Value ret Tag_empty, lm,l_env)) + resolve_outcome (interp_main mode t_level env emem exp) + (fun ret lm le -> (Value ret Tag_empty, l_mem,l_env)) (fun a -> update_stack a (fun stack -> (Hole_frame (id_of_string "0") (E_aux (E_id (Id_aux (Id "0") l)) (l,annot)) t_level l_env l_mem stack))) end) @@ -1081,7 +1093,10 @@ and interp_main mode t_level l_env l_mem (E_aux exp (l,annot)) = (fun a -> update_stack a (add_to_top_frame (fun lft -> (E_aux (E_app_infix lft op r) (l,annot))))) | E_let (lbind : letbind tannot) exp -> match (interp_letbind mode t_level l_env l_mem lbind) with - | ((Value v tag_l,lm,le),_) -> interp_main mode t_level le lm exp + | ((Value v tag_l,lm,le),_) -> + if mode.eager_eval + then interp_main mode t_level le lm exp + else debug_out exp t_level lm le | (((Action a s as o),lm,le),Just lbuild) -> ((update_stack o (add_to_top_frame (fun e -> (E_aux (E_let (lbuild e) exp) (l,annot))))),lm,le) | (e,_) -> e end @@ -1105,14 +1120,21 @@ and interp_main mode t_level l_env l_mem (E_aux exp (l,annot)) = (fun a -> update_stack a (add_to_top_frame (fun v -> (E_aux (E_assign lexp v) (l,annot))))) end - and interp_block mode t_level init_env local_env local_mem exps = +(*TODO shrink location information on recursive calls *) + and interp_block mode t_level init_env local_env local_mem l tannot exps = match exps with | [ ] -> (Value (V_lit (L_aux (L_unit) Unknown)) Tag_empty, local_mem, init_env) - | [ exp ] -> interp_main mode t_level local_env local_mem exp + | [ exp ] -> + if mode.eager_eval + then interp_main mode t_level local_env local_mem exp + else debug_out exp t_level local_mem local_env | exp:: exps -> resolve_outcome (interp_main mode t_level local_env local_mem exp) - (fun _ lm le -> interp_block mode t_level init_env le lm exps) - (fun a -> update_stack a (add_to_top_frame (fun e -> (E_aux (E_block(e::exps)) (Unknown,Nothing))))) + (fun _ lm le -> + if mode.eager_eval + then interp_block mode t_level init_env le lm l tannot exps + else debug_out (E_aux (E_block exps) (l,tannot)) t_level lm le) + (fun a -> update_stack a (add_to_top_frame (fun e -> (E_aux (E_block(e::exps)) (l,tannot))))) end and create_write_message_or_update mode t_level value l_env l_mem is_top_level ((LEXP_aux lexp (l,annot)):lexp tannot) = @@ -1330,7 +1352,7 @@ and interp_letbind mode t_level l_env l_mem (LB_aux lbind (l,annot)) = | e -> (e,Nothing) end end -(*Beef up to pick up enums as well*) +(*TODO: Beef up to pick up enums as well*) let rec to_global_letbinds (Defs defs) t_level = let (Env defs' lets regs ctors subregs) = t_level in match defs with diff --git a/src/lem_interp/interp_inter_imp.lem b/src/lem_interp/interp_inter_imp.lem index 5186222b..03def81c 100644 --- a/src/lem_interp/interp_inter_imp.lem +++ b/src/lem_interp/interp_inter_imp.lem @@ -80,6 +80,7 @@ let rec interp_to_outcome mode thunk = match List.lookup i external_functions with | Nothing -> Error ("External function not available " ^ i) | Just f -> interp_to_outcome mode (fun _ -> Interp.resume mode next_state (Just (f value))) end + | Interp.Debug l -> Internal next_state end end diff --git a/src/lem_interp/run_interp.ml b/src/lem_interp/run_interp.ml index bd8640ba..df45a373 100644 --- a/src/lem_interp/run_interp.ml +++ b/src/lem_interp/run_interp.ml @@ -95,6 +95,8 @@ let act_to_string = function (sub_to_string sub) (val_to_string value) | Call_extern (name, arg) -> sprintf "extern call %s applied to %s" name (val_to_string arg) + | Debug l -> + sprintf "debug, next step at %s" (loc_to_string l) ;; let id_compare i1 i2 = @@ -200,6 +202,7 @@ let rec perform_action ((reg, mem) as env) = function perform_action env (Write_mem (id, n, slice, V_vector(zero_big_int, true, [value]))) (* extern functions *) | Call_extern (name, arg) -> eval_external name arg, env + | Debug l -> V_lit (L_aux(L_unit,Interp_ast.Unknown)),env | _ -> assert false ;; @@ -211,6 +214,7 @@ let run ?(reg=Reg.empty) ?(mem=Mem.empty) (name, test) = + let mode = {eager_eval = false} in let rec loop env = function | Value (v, _) -> debugf "%s: returned %s\n" name (val_to_string v); true, env | Action (a, s) -> @@ -218,14 +222,14 @@ let run (*debugf "%s: suspended on action %s, with stack %s\n" name (act_to_string a) (stack_to_string s);*) let return, env' = perform_action env a in debugf "%s: action returned %s\n" name (val_to_string return); - loop env' (resume {eager_eval = true} s (Some return)) + loop env' (resume mode s (Some return)) | Error(l, e) -> debugf "%s: %s: error: %s\n" name (loc_to_string l) e; false, env in debugf "%s: starting\n" name; try Printexc.record_backtrace true; - loop (reg, mem) (interp {eager_eval = true} test entry) + loop (reg, mem) (interp mode test entry) with e -> let trace = Printexc.get_backtrace () in debugf "%s: interpretor error %s\n%s\n" name (Printexc.to_string e) trace; - false, (reg, mem) + false, (reg, mem ;; diff --git a/src/type_internal.ml b/src/type_internal.ml index b8aa0825..b4e9f61c 100644 --- a/src/type_internal.ml +++ b/src/type_internal.ml @@ -36,6 +36,7 @@ and nexp_aux = | Nadd of nexp * nexp | Nmult of nexp * nexp | N2n of nexp + | Npow of nexp * int (* nexp raised to the nat *) | Nneg of nexp (* Unary minus for representing new vector sizes after vector slicing *) | Nuvar of n_uvar and n_uvar = { nindex : int; mutable nsubst : nexp option; mutable nin : bool; } @@ -138,6 +139,7 @@ and n_to_string n = | Nadd(n1,n2) -> "("^ (n_to_string n1) ^ " + " ^ (n_to_string n2) ^")" | Nmult(n1,n2) -> "(" ^ (n_to_string n1) ^ " * " ^ (n_to_string n2) ^ ")" | N2n n -> "2**" ^ (n_to_string n) + | Npow(n, i) -> "(" ^ (n_to_string n) ^ ")**" ^ (string_of_int i) | Nneg n -> "-" ^ (n_to_string n) | Nuvar({nindex=i;nsubst=a}) -> "Nu_" ^ string_of_int i ^ "()" and e_to_string e = @@ -234,21 +236,31 @@ let rec compare_nexps n1 n2 = | Nuvar {nindex = n1}, Nuvar {nindex = n2} -> compare n1 n2 | Nuvar _ , _ -> -1 | _ , Nuvar _ -> 1 - | Nmult(_,n1),Nmult(_,n2) -> compare_nexps n1 n2 + | Nmult(n0,n1),Nmult(n2,n3) -> + (match compare_nexps n0 n2 with + | 0 -> compare_nexps n1 n3 + | a -> a) | Nmult _ , _ -> -1 | _ , Nmult _ -> 1 - | Nadd(n1,_),Nadd(n2,_) -> compare_nexps n1 n2 + | Nadd(n1,n12),Nadd(n2,n22) -> + (match compare_nexps n1 n2 with + | 0 -> compare_nexps n12 n22 + | a -> a) | Nadd _ , _ -> -1 | _ , Nadd _ -> 1 + | Npow(n1,_),Npow(n2,_)-> compare_nexps n1 n2 + | Npow _ , _ -> -1 + | _ , Npow _ -> 1 | N2n n1 , N2n n2 -> compare_nexps n1 n2 | N2n _ , _ -> -1 | _ , N2n _ -> 1 | Nneg n1 , Nneg n2 -> compare_nexps n1 n2 -let rec two_pow n = +let rec pow_i i n = match n with | 0 -> 1 - | n -> 2*(two_pow (n-1)) + | n -> i*(pow_i i (n-1)) +let two_pow = pow_i 2 (* eval an nexp as much as possible *) let rec eval_nexp n = @@ -290,6 +302,14 @@ let rec eval_nexp n = | _ -> {nexp = N2n n1'}) | Nvar _ | Nuvar _ -> n +(* predicate to determine if pushing a constant in for addition or multiplication could change the form *) +let rec contains_const n = + match n.nexp with + | Nvar _ | Nuvar _ | Npow _ | N2n _ -> false + | Nconst _ -> true + | Nneg n -> contains_const n + | Nmult(n1,n2) | Nadd(n1,n2) -> (contains_const n1) || (contains_const n2) + let rec get_var n = match n.nexp with | Nvar _ | Nuvar _ -> Some n @@ -300,7 +320,7 @@ let rec get_var n = let get_factor n = match n.nexp with | Nvar _ | Nuvar _ -> {nexp = Nconst 1} - | Nmult (n1,_) | Nneg n1 -> n1 + | Nmult (n1,_) -> n1 | _ -> assert false let increment_factor n i = @@ -335,6 +355,11 @@ let rec normalize_nexp n = | _,true -> negate n' end else n' + | Npow(n,i) -> + let n' = normalize_nexp n in + (match n'.nexp with + | Nconst n -> {nexp = Nconst (pow_i i n)} + | _ -> {nexp = Npow(n', i)}) | N2n n -> let n' = normalize_nexp n in (match n'.nexp with @@ -344,20 +369,27 @@ let rec normalize_nexp n = let n1',n2' = normalize_nexp n1, normalize_nexp n2 in (match n1'.nexp,n2'.nexp with | Nconst i1, Nconst i2 -> {nexp = Nconst (i1+i2)} - | Nconst _, Nvar _ | Nconst _, Nuvar _ | Nconst _, N2n _ | Nconst _, Nneg _ | Nconst _, Nmult _ -> {nexp = Nadd(n2',n1') } - | Nvar _, Nconst _ | Nuvar _, Nconst _ | Nmult _, Nconst _ | N2n _, Nconst _ -> {nexp = Nadd(n1',n2')} - | Nvar _, Nuvar _ | Nvar _, N2n _ -> {nexp = Nadd (n2',n1')} + | Nconst _, Nvar _ | Nconst _, Nuvar _ | Nconst _, N2n _ | Nconst _, Npow _ | Nconst _, Nneg _ | Nconst _, Nmult _ -> {nexp = Nadd(n2',n1') } + | Nvar _, Nconst _ | Nuvar _, Nconst _ | Nmult _, Nconst _ | N2n _, Nconst _ | Npow _, Nconst _-> {nexp = Nadd(n1',n2')} + | Nvar _, Nuvar _ | Nvar _, N2n _ | Nuvar _, Npow _ -> {nexp = Nadd (n2',n1')} | Nadd(n11,n12), Nadd(n21,n22) -> (match compare_nexps n11 n21 with - | -1 | 0 -> normalize_nexp {nexp = Nadd(n11, {nexp = Nadd(n12,n2')})} - | _ -> normalize_nexp {nexp = Nadd(n21, { nexp = Nadd(n22,n1') })}) - | Nadd(n11,n12), Nconst _ -> {nexp = Nadd(n11,{nexp = Nadd(n12,n2')}) } - | Nconst _, Nadd(n21,n22) -> {nexp = Nadd(n21,{nexp = Nadd(n22,n1')})} + | -1 -> {nexp = Nadd(n11, (normalize_nexp {nexp = Nadd(n12,n2')}))} + | 0 -> normalize_nexp {nexp = Nmult({nexp = Nconst 2},n1')} + | _ -> normalize_nexp {nexp = Nadd(n21, { nexp = Nadd(n22,n1') })}) + | Nadd(n11,n12), Nconst _ -> {nexp = Nadd(n11,normalize_nexp {nexp = Nadd(n12,n2')})} + | Nconst _, Nadd(n21,n22) -> {nexp = Nadd(n21,normalize_nexp {nexp = Nadd(n22,n1')})} | N2n n1, N2n n2 -> (match compare_nexps n1 n2 with - | -1 | 0 -> {nexp = Nadd (n2',n1')} - | _ -> { nexp = Nadd (n1',n2')}) - | _ -> + | -1 -> {nexp = Nadd (n2',n1')} + | 0 -> {nexp = N2n (normalize_nexp {nexp = Nadd(n1, {nexp = Nconst 1})})} + | _ -> { nexp = Nadd (n1',n2')}) + | Npow(n1,i1), Npow (n2,i2) -> + (match compare_nexps n1 n2, compare i1 i2 with + | -1,-1 | 0,-1 -> {nexp = Nadd (n2',n1')} + | 0,0 -> {nexp = Nmult ({nexp = Nconst 2},n1')} + | _ -> {nexp = Nadd (n1',n2')}) + | _ -> match get_var n1', get_var n2' with | Some(nv1),Some(nv2) -> (match compare_nexps nv1 nv2 with @@ -370,29 +402,54 @@ let rec normalize_nexp n = (match n1'.nexp,n2'.nexp with | Nconst i1, Nconst i2 -> {nexp = Nconst (i1*i2)} | Nconst 2, N2n n2 | N2n n2, Nconst 2 -> {nexp =N2n (normalize_nexp {nexp = Nadd(n2, {nexp = Nconst 1})})} - | Nconst _, Nvar _ | Nconst _, Nuvar _ | Nconst _, N2n _ | Nvar _, N2n _ -> { nexp = Nmult(n1',n2') } - | Nvar _, Nconst _ | Nuvar _, Nconst _ | N2n _, Nconst _ | Nvar _, Nmult _ | Nvar _, Nuvar _ -> { nexp = Nmult(n2',n1') } + | Nconst _, Nvar _ | Nconst _, Nuvar _ | Nconst _, N2n _ | Nconst _, Npow _ | Nvar _, N2n _ -> { nexp = Nmult(n1',n2') } + | Nvar _, Nconst _ | Nuvar _, Nconst _ | N2n _, Nconst _ | Npow _, Nconst _ | Nvar _, Nmult _ | Nvar _, Nuvar _ -> { nexp = Nmult(n2',n1') } | N2n n1, N2n n2 -> {nexp = N2n (normalize_nexp {nexp = Nadd(n1,n2)})} - | N2n _, Nvar _ | N2n _, Nuvar _ | N2n _, Nmult _ | Nuvar _, N2n _ | Nuvar _, Nmult _ -> {nexp =Nmult(n2',n1')} + | N2n _, Nvar _ | N2n _, Nuvar _ | N2n _, Nmult _ | Nuvar _, N2n _ | Nuvar _, Nmult _ -> {nexp =Nmult(n2',n1')} | Nuvar {nindex = i1}, Nuvar {nindex = i2} -> (match compare i1 i2 with - | 0 | 1 -> {nexp = Nmult(n1',n2')} + | 0 -> {nexp = Npow(n1', 2)} + | 1 -> {nexp = Nmult(n1',n2')} | _ -> {nexp = Nmult(n2',n1')}) | Nvar i1, Nvar i2 -> (match compare i1 i2 with - | 0 | 1 -> {nexp = Nmult(n1',n2')} + | 0 -> {nexp = Npow(n1', 2)} + | 1 -> {nexp = Nmult(n1',n2')} | _ -> {nexp = Nmult(n2',n1')}) - | Nconst _, Nadd(n21,n22) | Nvar _,Nadd(n21,n22) | Nuvar _,Nadd(n21,n22) | N2n _, Nadd(n21,n22) | Nmult _, Nadd(n21,n22) -> + | Npow(n1,i1),Npow(n2,i2) -> + (match compare_nexps n1 n2 with + | 0 -> {nexp = Npow(n1,(i1+i2))} + | -1 -> {nexp = Nmult(n2',n1')} + | _ -> {nexp = Nmult(n1',n2')}) + | Nconst _, Nadd(n21,n22) | Nvar _,Nadd(n21,n22) | Nuvar _,Nadd(n21,n22) | N2n _, Nadd(n21,n22) | Npow _,Nadd(n21,n22) | Nmult _, Nadd(n21,n22) -> normalize_nexp {nexp = Nadd( {nexp = Nmult(n1',n21)}, {nexp = Nmult(n1',n21)})} - | Nadd(n11,n12),Nconst _ | Nadd(n11,n12),Nvar _ | Nadd(n11,n12), Nuvar _ | Nadd(n11,n12), N2n _ | Nadd(n11,n12), Nmult _-> + | Nadd(n11,n12),Nconst _ | Nadd(n11,n12),Nvar _ | Nadd(n11,n12), Nuvar _ | Nadd(n11,n12), N2n _ | Nadd(n11,n12),Npow _ | Nadd(n11,n12), Nmult _-> normalize_nexp {nexp = Nadd( {nexp = Nmult(n11,n2')}, {nexp = Nmult(n12,n2')})} | Nadd(n11,n12),Nadd(n21,n22) -> - {nexp = Nadd( {nexp = Nmult(n11,n21)}, - {nexp = Nadd ({nexp = Nmult(n11,n22)}, - {nexp = Nadd({nexp = Nmult(n12,n21)}, - {nexp = Nmult(n12,n22)})})})} - | Nuvar _, Nvar _ | Nmult _, Nvar _| Nmult _, Nuvar _ | Nmult _, N2n _-> {nexp = Nmult (n1',n2')} + normalize_nexp {nexp = Nadd( {nexp = Nmult(n11,n21)}, + {nexp = Nadd ({nexp = Nmult(n11,n22)}, + {nexp = Nadd({nexp = Nmult(n12,n21)}, + {nexp = Nmult(n12,n22)})})})} + | Nuvar _, Nvar _ | Nmult _, N2n _-> {nexp = Nmult (n1',n2')} | Nmult(n11,n12), Nconst _ -> {nexp = Nmult({nexp = Nmult(n11,n2')},{nexp = Nmult(n12,n2')})} + | Nuvar _, Nmult(n1,n2) | Nvar _, Nmult(n1,n2) -> + (match get_var n1, get_var n2 with + | Some(nv1),Some(nv2) -> + (match compare_nexps nv1 nv2, n2.nexp with + | 0, Nuvar _ | 0, Nvar _ -> {nexp = Nmult(n1, {nexp = Npow(nv1,2)}) } + | 0, Npow(n2',i) -> {nexp = Nmult(n1, {nexp = Npow (n2',(i+1))})} + | -1, Nuvar _ | -1, Nvar _ -> {nexp = Nmult(n2',n1')} + | _,_ | _,_ -> {nexp = Nmult(normalize_nexp {nexp = Nmult(n1,n1')},n2)}) + | _ -> {nexp = Nmult(normalize_nexp {nexp = Nmult(n1,n1')},n2)}) + | Npow(n1,i),Nmult(n21,n22) -> + (match get_var n1, get_var n2 with + | Some(nv1),Some(nv2) -> + (match compare_nexps nv1 nv2,n22.nexp with + | 0, Nuvar _ | 0, Nvar _ -> {nexp = Nmult(n21,{nexp = Npow(n1,i+1)})} + | 0, Npow(_,i2) -> {nexp = Nmult(n21,{nexp=Npow(n1,i+i2)})} + | 1,Npow _ -> {nexp = Nmult(normalize_nexp {nexp = Nmult(n21,n1')},n22)} + | _ -> {nexp = Nmult(n2',n1')}) + | _ -> {nexp = Nmult(normalize_nexp {nexp = Nmult(n1',n21)},n22)}) | Nmult _ ,Nmult(n21,n22) | Nconst _, Nmult(n21,n22) -> {nexp = Nmult({nexp = Nmult(n21,n1')},{nexp = Nmult(n22,n1')})} | Nneg _,_ | _,Nneg _ -> assert false (* If things are normal, neg should be gone. *) ) @@ -997,7 +1054,7 @@ let rec nexp_eq_check n1 n2 = | _,_ -> false let nexp_eq n1 n2 = - nexp_eq_check (eval_nexp n1) (eval_nexp n2) + nexp_eq_check (normalize_nexp n1) (normalize_nexp n2) (*Is checking for structural equality amongst the types, building constraints for kind Nat. When considering two range type applications, will check for consistency instead of equality*) @@ -1305,7 +1362,7 @@ let rec simple_constraint_check in_env cs = | Eq(co,n1,n2)::cs -> let check_eq ok_to_set n1 n2 = (*let _ = Printf.printf "eq check, about to eval_nexp of %s, %s\n" (n_to_string n1) (n_to_string n2) in *) - let n1',n2' = eval_nexp n1,eval_nexp n2 in + let n1',n2' = normalize_nexp n1,normalize_nexp n2 in (*let _ = Printf.printf "finished evaled to %s, %s\n" (n_to_string n1') (n_to_string n2') in *) (match n1'.nexp,n2'.nexp with | Nconst i1, Nconst i2 -> @@ -1333,7 +1390,7 @@ let rec simple_constraint_check in_env cs = | _ -> (Eq(co,n1,n2)::(check cs))) | GtEq(co,n1,n2)::cs -> (* let _ = Printf.printf ">= check, about to eval_nexp of %s, %s\n" (n_to_string n1) (n_to_string n2) in *) - let n1',n2' = eval_nexp n1,eval_nexp n2 in + let n1',n2' = normalize_nexp n1,normalize_nexp n2 in (* let _ = Printf.printf "finished evaled to %s, %s\n" (n_to_string n1') (n_to_string n2') in *) (match n1'.nexp,n2'.nexp with | Nconst i1, Nconst i2 -> @@ -1344,7 +1401,7 @@ let rec simple_constraint_check in_env cs = | _,_ -> GtEq(co,n1',n2')::(check cs)) | LtEq(co,n1,n2)::cs -> (* let _ = Printf.printf "<= check, about to eval_nexp of %s, %s\n" (n_to_string n1) (n_to_string n2) in *) - let n1',n2' = eval_nexp n1,eval_nexp n2 in + let n1',n2' = normalize_nexp n1,normalize_nexp n2 in (* let _ = Printf.printf "finished evaled to %s, %s\n" (n_to_string n1') (n_to_string n2') in *) (match n1'.nexp,n2'.nexp with | Nconst i1, Nconst i2 -> diff --git a/src/type_internal.mli b/src/type_internal.mli index 13c27386..c64298a1 100644 --- a/src/type_internal.mli +++ b/src/type_internal.mli @@ -34,6 +34,7 @@ and nexp_aux = | Nadd of nexp * nexp | Nmult of nexp * nexp | N2n of nexp + | Npow of nexp * int | Nneg of nexp | Nuvar of n_uvar and effect = { mutable effect : effect_aux } |
