diff options
| author | Pierre-Marie Pédrot | 2019-04-23 21:35:20 +0200 |
|---|---|---|
| committer | Pierre-Marie Pédrot | 2019-04-23 21:35:20 +0200 |
| commit | 75c5264aa687480c66a6765d64246b5ebd2c0d54 (patch) | |
| tree | 756baf747199c1f88c601c514887adabd9a05c5f /kernel/nativecode.ml | |
| parent | 9834f23fe9bc8a659ed36c426d557e94179476b0 (diff) | |
| parent | 8b886a0a201444b7eb782f3fa0dc52a7b6fe8837 (diff) | |
Merge PR #9962: [native compiler] Encoding of constructors based on tags
Ack-by: maximedenes
Reviewed-by: ppedrot
Diffstat (limited to 'kernel/nativecode.ml')
| -rw-r--r-- | kernel/nativecode.ml | 245 |
1 files changed, 148 insertions, 97 deletions
diff --git a/kernel/nativecode.ml b/kernel/nativecode.ml index d7ec2ecf72..3f791dfc22 100644 --- a/kernel/nativecode.ml +++ b/kernel/nativecode.ml @@ -377,8 +377,8 @@ type mllambda = | MLif of mllambda * mllambda * mllambda | MLmatch of annot_sw * mllambda * mllambda * mllam_branches (* argument, prefix, accu branch, branches *) - | MLconstruct of string * constructor * mllambda array - (* prefix, constructor name, arguments *) + | MLconstruct of string * inductive * int * mllambda array + (* prefix, inductive name, tag, arguments *) | MLint of int | MLuint of Uint63.t | MLsetref of string * mllambda @@ -386,7 +386,11 @@ type mllambda = | MLarray of mllambda array | MLisaccu of string * inductive * mllambda -and mllam_branches = ((constructor * lname option array) list * mllambda) array +and 'a mllam_pattern = + | ConstPattern of int + | NonConstPattern of tag * 'a array + +and mllam_branches = (lname option mllam_pattern list * mllambda) array let push_lnames n env lns = snd (Array.fold_left (fun (i,r) x -> (i+1, LNmap.add x i r)) (n,env) lns) @@ -439,9 +443,10 @@ let rec eq_mllambda gn1 gn2 n env1 env2 t1 t2 = eq_mllambda gn1 gn2 n env1 env2 c1 c2 && eq_mllambda gn1 gn2 n env1 env2 accu1 accu2 && eq_mllam_branches gn1 gn2 n env1 env2 br1 br2 - | MLconstruct (pf1, cs1, args1), MLconstruct (pf2, cs2, args2) -> + | MLconstruct (pf1, ind1, tag1, args1), MLconstruct (pf2, ind2, tag2, args2) -> String.equal pf1 pf2 && - eq_constructor cs1 cs2 && + eq_ind ind1 ind2 && + Int.equal tag1 tag2 && Array.equal (eq_mllambda gn1 gn2 n env1 env2) args1 args2 | MLint i1, MLint i2 -> Int.equal i1 i2 @@ -474,15 +479,22 @@ and eq_letrec gn1 gn2 n env1 env2 defs1 defs2 = (* we require here that patterns have the same order, which may be too strong *) and eq_mllam_branches gn1 gn2 n env1 env2 br1 br2 = - let eq_cargs (cs1, args1) (cs2, args2) body1 body2 = + let eq_cargs args1 args2 body1 body2 = Int.equal (Array.length args1) (Array.length args2) && - eq_constructor cs1 cs2 && let env1 = opush_lnames n env1 args1 in let env2 = opush_lnames n env2 args2 in eq_mllambda gn1 gn2 (n + Array.length args1) env1 env2 body1 body2 in - let eq_branch (ptl1,body1) (ptl2,body2) = - List.equal (fun pt1 pt2 -> eq_cargs pt1 pt2 body1 body2) ptl1 ptl2 + let eq_pattern pat1 pat2 body1 body2 = + match pat1, pat2 with + | ConstPattern tag1, ConstPattern tag2 -> + Int.equal tag1 tag2 && eq_mllambda gn1 gn2 n env1 env2 body1 body2 + | NonConstPattern (tag1,args1), NonConstPattern (tag2,args2) -> + Int.equal tag1 tag2 && eq_cargs args1 args2 body1 body2 + | (ConstPattern _ | NonConstPattern _), _ -> false + in + let eq_branch (patl1,body1) (patl2,body2) = + List.equal (fun pt1 pt2 -> eq_pattern pt1 pt2 body1 body2) patl1 patl2 in Array.equal eq_branch br1 br2 @@ -518,10 +530,11 @@ let rec hash_mllambda gn n env t = let hc = hash_mllambda gn n env c in let haccu = hash_mllambda gn n env accu in combinesmall 9 (hash_mllam_branches gn n env (combine3 hannot hc haccu) br) - | MLconstruct (pf, cs, args) -> + | MLconstruct (pf, ind, tag, args) -> let hpf = String.hash pf in - let hcs = constructor_hash cs in - combinesmall 10 (hash_mllambda_array gn n env (combine hpf hcs) args) + let hcs = ind_hash ind in + let htag = Int.hash tag in + combinesmall 10 (hash_mllambda_array gn n env (combine3 hpf hcs htag) args) | MLint i -> combinesmall 11 i | MLuint i -> @@ -551,15 +564,18 @@ and hash_mllambda_array gn n env init arr = Array.fold_left (fun acc t -> combine (hash_mllambda gn n env t) acc) init arr and hash_mllam_branches gn n env init br = - let hash_cargs (cs, args) body = + let hash_cargs args body = let nargs = Array.length args in - let hcs = constructor_hash cs in let env = opush_lnames n env args in let hbody = hash_mllambda gn (n + nargs) env body in - combine3 nargs hcs hbody + combine nargs hbody + in + let hash_pattern pat body = match pat with + | ConstPattern i -> combinesmall 1 (Int.hash i) + | NonConstPattern (tag,args) -> combinesmall 2 (combine (Int.hash tag) (hash_cargs args body)) in let hash_branch acc (ptl,body) = - List.fold_left (fun acc t -> combine (hash_cargs t body) acc) acc ptl + List.fold_left (fun acc t -> combine (hash_pattern t body) acc) acc ptl in Array.fold_left hash_branch init br @@ -589,17 +605,20 @@ let fv_lam l = | MLmatch(_,a,p,bs) -> let fv = aux a bind (aux p bind fv) in let fv_bs (cargs, body) fv = - let bind = - List.fold_right (fun (_,args) bind -> - Array.fold_right - (fun o bind -> match o with - | Some l -> LNset.add l bind - | _ -> bind) args bind) - cargs bind in - aux body bind fv in + let bind = + List.fold_right (fun pat bind -> + match pat with + | ConstPattern _ -> bind + | NonConstPattern(_,args) -> + Array.fold_right + (fun o bind -> match o with + | Some l -> LNset.add l bind + | _ -> bind) args bind) + cargs bind in + aux body bind fv in Array.fold_right fv_bs bs fv - (* argument, accu branch, branches *) - | MLconstruct (_,_,p) -> + (* argument, accu branch, branches *) + | MLconstruct (_,_,_,p) -> Array.fold_right (fun a fv -> aux a bind fv) p fv | MLsetref(_,l) -> aux l bind fv | MLsequence(l1,l2) -> aux l1 bind (aux l2 bind fv) @@ -647,8 +666,8 @@ type global = | Gletcase of gname * lname array * annot_sw * mllambda * mllambda * mllam_branches | Gopen of string - | Gtype of inductive * int array - (* ind name, arities of constructors *) + | Gtype of inductive * (tag * int) array + (* ind name, tag and arities of constructors *) | Gcomment of string (* Alpha-equivalence on globals *) @@ -673,7 +692,8 @@ let eq_global g1 g2 = eq_mllambda gn1 gn2 (Array.length lns1) env1 env2 t1 t2 | Gopen s1, Gopen s2 -> String.equal s1 s2 | Gtype (ind1, arr1), Gtype (ind2, arr2) -> - eq_ind ind1 ind2 && Array.equal Int.equal arr1 arr2 + eq_ind ind1 ind2 && + Array.equal (fun (tag1,ar1) (tag2,ar2) -> Int.equal tag1 tag2 && Int.equal ar1 ar2) arr1 arr2 | Gcomment s1, Gcomment s2 -> String.equal s1 s2 | _, _ -> false @@ -700,7 +720,10 @@ let hash_global g = combinesmall 4 (combine nlns (hash_mllambda gn nlns env t)) | Gopen s -> combinesmall 5 (String.hash s) | Gtype (ind, arr) -> - combinesmall 6 (combine (ind_hash ind) (Array.fold_left combine 0 arr)) + let hash_aux acc (tag,ar) = + combine3 acc (Int.hash tag) (Int.hash ar) + in + combinesmall 6 (combine (ind_hash ind) (Array.fold_left hash_aux 0 arr)) | Gcomment s -> combinesmall 7 (String.hash s) let global_stack = ref ([] : global list) @@ -907,26 +930,33 @@ let get_proj_code i = [|MLglobal symbols_tbl_name; MLint i|]) type rlist = - | Rnil - | Rcons of (constructor * lname option array) list ref * LNset.t * mllambda * rlist' + | Rnil + | Rcons of lname option mllam_pattern list ref * LNset.t * mllambda * rlist' and rlist' = rlist ref -let rm_params fv params = - Array.map (fun l -> if LNset.mem l fv then Some l else None) params +let rm_params fv params = + Array.map (fun l -> if LNset.mem l fv then Some l else None) params -let rec insert cargs body rl = +let rec insert pat body rl = match !rl with | Rnil -> let fv = fv_lam body in - let (c,params) = cargs in - let params = rm_params fv params in - rl:= Rcons(ref [(c,params)], fv, body, ref Rnil) + begin match pat with + | ConstPattern _ as p -> + rl:= Rcons(ref [p], fv, body, ref Rnil) + | NonConstPattern (tag,args) -> + let args = rm_params fv args in + rl:= Rcons(ref [NonConstPattern (tag,args)], fv, body, ref Rnil) + end | Rcons(l,fv,body',rl) -> - if eq_mllambda body body' then - let (c,params) = cargs in - let params = rm_params fv params in - l := (c,params)::!l - else insert cargs body rl + if eq_mllambda body body' then + match pat with + | ConstPattern _ as p -> + l := p::!l + | NonConstPattern (tag,args) -> + let args = rm_params fv args in + l := NonConstPattern (tag,args)::!l + else insert pat body rl let rec to_list rl = match !rl with @@ -935,7 +965,7 @@ let rec to_list rl = let merge_branches t = let newt = ref Rnil in - Array.iter (fun (c,args,body) -> insert (c,args) body newt) t; + Array.iter (fun (pat,body) -> insert pat body newt) t; Array.of_list (to_list newt) let app_prim p args = MLapp(MLprimitive p, args) @@ -1092,14 +1122,19 @@ let ml_of_instance instance u = let a_uid = fresh_lname Anonymous in let la_uid = MLlocal a_uid in (* compilation of branches *) - let ml_br (c,params, body) = - let lnames, env_c = push_rels env_c params in - (c, lnames, ml_of_lam env_c l body) + let nbconst = Array.length bs.constant_branches in + let nbtotal = nbconst + Array.length bs.nonconstant_branches in + let br = Array.init nbtotal (fun i -> if i < Array.length bs.constant_branches then + (ConstPattern i, ml_of_lam env_c l bs.constant_branches.(i)) + else + let (params, body) = bs.nonconstant_branches.(i-nbconst) in + let lnames, env_c = push_rels env_c params in + (NonConstPattern (i-nbconst+1,lnames), ml_of_lam env_c l body) + ) in - let bs = Array.map ml_br bs in let cn = fresh_gcase l in (* Compilation of accu branch *) - let pred = MLapp(MLglobal pn, fv_args env_c pfvn pfvr) in + let pred = MLapp(MLglobal pn, fv_args env_c pfvn pfvr) in let (fvn, fvr) = !(env_c.env_named), !(env_c.env_urel) in let cn_fv = mkMLapp (MLglobal cn) (fv_args env_c fvn fvr) in (* remark : the call to fv_args does not add free variables in env_c *) @@ -1112,7 +1147,7 @@ let ml_of_instance instance u = (* let body = MLlam([|a_uid|], MLmatch(annot, la_uid, accu, bs)) in let case = generalize_fv env_c body in *) let cn = push_global_case cn (Array.append (fv_params env_c) [|a_uid|]) - annot la_uid accu (merge_branches bs) + annot la_uid accu (merge_branches br) in (* Final result *) let arg = ml_of_lam env l a in @@ -1272,9 +1307,11 @@ let ml_of_instance instance u = (lname, paramsi, body) in MLletrec(Array.mapi mkrec lf, lf_args.(start)) *) - | Lmakeblock (prefix,(cn,_u),_,args) -> + | Lint tag -> MLapp(MLprimitive Mk_int, [|MLint tag|]) + + | Lmakeblock (prefix,cn,tag,args) -> let args = Array.map (ml_of_lam env l) args in - MLconstruct(prefix,cn,args) + MLconstruct(prefix,cn,tag,args) | Luint i -> MLapp(MLprimitive Mk_uint, [|MLuint i|]) | Lval v -> let i = push_symbol (SymbValue v) in get_value_code i @@ -1337,7 +1374,7 @@ let subst s l = | MLmatch(annot,a,accu,bs) -> let auxb (cargs,body) = (cargs,aux body) in MLmatch(annot,a,aux accu, Array.map auxb bs) - | MLconstruct(prefix,c,args) -> MLconstruct(prefix,c,Array.map aux args) + | MLconstruct(prefix,c,tag,args) -> MLconstruct(prefix,c,tag,Array.map aux args) | MLsetref(s,l1) -> MLsetref(s,aux l1) | MLsequence(l1,l2) -> MLsequence(aux l1, aux l2) | MLarray arr -> MLarray (Array.map aux arr) @@ -1446,8 +1483,8 @@ let optimize gdef l = | MLmatch(annot,a,accu,bs) -> let opt_b (cargs,body) = (cargs,optimize s body) in MLmatch(annot, optimize s a, subst s accu, Array.map opt_b bs) - | MLconstruct(prefix,c,args) -> - MLconstruct(prefix,c,Array.map (optimize s) args) + | MLconstruct(prefix,c,tag,args) -> + MLconstruct(prefix,c,tag,Array.map (optimize s) args) | MLsetref(r,l) -> MLsetref(r, optimize s l) | MLsequence(l1,l2) -> MLsequence(optimize s l1, optimize s l2) | MLarray arr -> MLarray (Array.map (optimize s) arr) @@ -1520,6 +1557,7 @@ let string_of_kn kn = let string_of_con c = string_of_kn (Constant.user c) let string_of_mind mind = string_of_kn (MutInd.user mind) +let string_of_ind (mind,i) = string_of_kn (MutInd.user mind) ^ "_" ^ string_of_int i let string_of_gname g = match g with @@ -1557,10 +1595,13 @@ let pp_ldecls fmt ids = Format.fprintf fmt " (%a : Nativevalues.t)" pp_lname ids.(i) done -let string_of_construct prefix ((mind,i),j) = - let id = Format.sprintf "Construct_%s_%i_%i" (string_of_mind mind) i (j-1) in - prefix ^ id - +let string_of_construct prefix ~constant ind tag = + let base = if constant then "Int" else "Construct" in + Format.sprintf "%s%s_%s_%i" prefix base (string_of_ind ind) tag + +let string_of_accu_construct prefix ind = + Format.sprintf "%sAccu_%s" prefix (string_of_ind ind) + let pp_int fmt i = if i < 0 then Format.fprintf fmt "(%i)" i else Format.fprintf fmt "%i" i @@ -1586,16 +1627,16 @@ let pp_mllam fmt l = Format.fprintf fmt "@[(if %a then@\n %a@\nelse@\n %a)@]" pp_mllam t pp_mllam l1 pp_mllam l2 | MLmatch (annot, c, accu_br, br) -> - let mind,i = annot.asw_ind in + let ind = annot.asw_ind in let prefix = annot.asw_prefix in - let accu = Format.sprintf "%sAccu_%s_%i" prefix (string_of_mind mind) i in - Format.fprintf fmt - "@[begin match Obj.magic (%a) with@\n| %s _ ->@\n %a@\n%aend@]" - pp_mllam c accu pp_mllam accu_br (pp_branches prefix) br - - | MLconstruct(prefix,c,args) -> + let accu = string_of_accu_construct prefix ind in + Format.fprintf fmt + "@[begin match Obj.magic (%a) with@\n| %s _ ->@\n %a@\n%aend@]" + pp_mllam c accu pp_mllam accu_br (pp_branches prefix ind) br + + | MLconstruct(prefix,ind,tag,args) -> Format.fprintf fmt "@[(Obj.magic (%s%a) : Nativevalues.t)@]" - (string_of_construct prefix c) pp_cargs args + (string_of_construct prefix ~constant:false ind tag) pp_cargs args | MLint i -> pp_int fmt i | MLuint i -> Format.fprintf fmt "(%s)" (Uint63.compile i) | MLsetref (s, body) -> @@ -1612,8 +1653,8 @@ let pp_mllam fmt l = pp_mllam fmt arr.(len-1) end; Format.fprintf fmt "|]@]" - | MLisaccu (prefix, (mind, i), c) -> - let accu = Format.sprintf "%sAccu_%s_%i" prefix (string_of_mind mind) i in + | MLisaccu (prefix, ind, c) -> + let accu = string_of_accu_construct prefix ind in Format.fprintf fmt "@[begin match Obj.magic (%a) with@\n| %s _ ->@\n true@\n| _ ->@\n false@\nend@]" pp_mllam c accu @@ -1636,7 +1677,7 @@ let pp_mllam fmt l = | MLprimitive (Mk_prod | Mk_sort) (* FIXME: why this special case? *) | MLlam _ | MLletrec _ | MLlet _ | MLapp _ | MLif _ -> Format.fprintf fmt "(%a)" pp_mllam l - | MLconstruct(_,_,args) when Array.length args > 0 -> + | MLconstruct(_,_,_,args) when Array.length args > 0 -> Format.fprintf fmt "(%a)" pp_mllam l | _ -> pp_mllam fmt l @@ -1675,19 +1716,23 @@ let pp_mllam fmt l = done in Format.fprintf fmt "(%a)" aux params - and pp_branches prefix fmt bs = + and pp_branches prefix ind fmt bs = let pp_branch (cargs,body) = - let pp_c fmt (cn,args) = - Format.fprintf fmt "| %s%a " - (string_of_construct prefix cn) pp_cparams args in - let rec pp_cargs fmt cargs = - match cargs with - | [] -> () - | cargs::cargs' -> - Format.fprintf fmt "%a%a" pp_c cargs pp_cargs cargs' in - Format.fprintf fmt "%a ->@\n %a@\n" - pp_cargs cargs pp_mllam body + let pp_pat fmt = function + | ConstPattern i -> + Format.fprintf fmt "| %s " + (string_of_construct prefix ~constant:true ind i) + | NonConstPattern (tag,args) -> + Format.fprintf fmt "| %s%a " + (string_of_construct prefix ~constant:false ind tag) pp_cparams args in + let rec pp_pats fmt pats = + match pats with + | [] -> () + | pat::pats -> + Format.fprintf fmt "%a%a" pp_pat pat pp_pats pats in + Format.fprintf fmt "%a ->@\n %a@\n" pp_pats cargs pp_mllam body + in Array.iter pp_branch bs and pp_primitive fmt = function @@ -1761,19 +1806,24 @@ let pp_global fmt g = pp_mllam c | Gopen s -> Format.fprintf fmt "@[open %s@]@." s - | Gtype ((mind, i), lar) -> - let l = string_of_mind mind in - let rec aux s ar = - if Int.equal ar 0 then s else aux (s^" * Nativevalues.t") (ar-1) in - let pp_const_sig i fmt j ar = - let sig_str = if ar > 0 then aux "of Nativevalues.t" (ar-1) else "" in - Format.fprintf fmt " | Construct_%s_%i_%i %s@\n" l i j sig_str - in - let pp_const_sigs i fmt lar = - Format.fprintf fmt " | Accu_%s_%i of Nativevalues.t@\n" l i; - Array.iteri (pp_const_sig i fmt) lar - in - Format.fprintf fmt "@[type ind_%s_%i =@\n%a@]@\n@." l i (pp_const_sigs i) lar + | Gtype (ind, lar) -> + let rec aux s arity = + if Int.equal arity 0 then s else aux (s^" * Nativevalues.t") (arity-1) in + let pp_const_sig fmt (tag,arity) = + if arity > 0 then + let sig_str = aux "of Nativevalues.t" (arity-1) in + let cstr = string_of_construct "" ~constant:false ind tag in + Format.fprintf fmt " | %s %s@\n" cstr sig_str + else + let sig_str = if arity > 0 then aux "of Nativevalues.t" (arity-1) else "" in + let cstr = string_of_construct "" ~constant:true ind tag in + Format.fprintf fmt " | %s %s@\n" cstr sig_str + in + let pp_const_sigs fmt lar = + Format.fprintf fmt " | %s of Nativevalues.t@\n" (string_of_accu_construct "" ind); + Array.iter (pp_const_sig fmt) lar + in + Format.fprintf fmt "@[type ind_%s =@\n%a@]@\n@." (string_of_ind ind) pp_const_sigs lar | Gtblfixtype (g, params, t) -> Format.fprintf fmt "@[let %a %a =@\n %a@]@\n@." pp_gname g pp_ldecls params pp_array t @@ -1910,7 +1960,7 @@ let compile_mind mb mind stack = (** Generate data for every block *) let f i stack ob = let ind = (mind, i) in - let gtype = Gtype(ind, Array.map snd ob.mind_reloc_tbl) in + let gtype = Gtype(ind, ob.mind_reloc_tbl) in let j = push_symbol (SymbInd ind) in let name = Gind ("", ind) in let accu = @@ -1933,7 +1983,8 @@ let compile_mind mb mind stack = asw_reloc = tbl; asw_finite = true } in let c_uid = fresh_lname Anonymous in let cf_uid = fresh_lname Anonymous in - let _, arity = tbl.(0) in + let tag, arity = tbl.(0) in + assert (arity > 0); let ci_uid = fresh_lname Anonymous in let cargs = Array.init arity (fun i -> if Int.equal i proj_arg then Some ci_uid else None) @@ -1941,7 +1992,7 @@ let compile_mind mb mind stack = let i = push_symbol (SymbProj (ind, proj_arg)) in let accu = MLapp (MLprimitive Cast_accu, [|MLlocal cf_uid|]) in let accu_br = MLapp (MLprimitive Mk_proj, [|get_proj_code i;accu|]) in - let code = MLmatch(asw,MLlocal cf_uid,accu_br,[|[((ind,1),cargs)],MLlocal ci_uid|]) in + let code = MLmatch(asw,MLlocal cf_uid,accu_br,[|[NonConstPattern (tag,cargs)],MLlocal ci_uid|]) in let code = MLlet(cf_uid, mkForceCofix "" ind (MLlocal c_uid), code) in let gn = Gproj ("", ind, proj_arg) in Glet (gn, mkMLlam [|c_uid|] code) :: acc |
