diff options
Diffstat (limited to 'src/ast_util.ml')
| -rw-r--r-- | src/ast_util.ml | 133 |
1 files changed, 129 insertions, 4 deletions
diff --git a/src/ast_util.ml b/src/ast_util.ml index 9966742e..9490366f 100644 --- a/src/ast_util.ml +++ b/src/ast_util.ml @@ -1082,7 +1082,7 @@ let rec tyvars_of_nexp (Nexp_aux (nexp,_)) = | Nexp_neg n -> tyvars_of_nexp n | Nexp_app (_, nexps) -> List.fold_left KidSet.union KidSet.empty (List.map tyvars_of_nexp nexps) -let rec tyvars_of_nc (NC_aux (nc, _)) = +let rec tyvars_of_constraint (NC_aux (nc, _)) = match nc with | NC_equal (nexp1, nexp2) | NC_bounded_ge (nexp1, nexp2) @@ -1092,7 +1092,7 @@ let rec tyvars_of_nc (NC_aux (nc, _)) = | NC_set (kid, _) -> KidSet.singleton kid | NC_or (nc1, nc2) | NC_and (nc1, nc2) -> - KidSet.union (tyvars_of_nc nc1) (tyvars_of_nc nc2) + KidSet.union (tyvars_of_constraint nc1) (tyvars_of_constraint nc2) | NC_app (id, nexps) -> List.fold_left KidSet.union KidSet.empty (List.map tyvars_of_nexp nexps) | NC_true @@ -1112,7 +1112,7 @@ let rec tyvars_of_typ (Typ_aux (t,_)) = List.fold_left (fun s ta -> KidSet.union s (tyvars_of_typ_arg ta)) KidSet.empty tas | Typ_exist (kids, nc, t) -> - let s = KidSet.union (tyvars_of_typ t) (tyvars_of_nc nc) in + let s = KidSet.union (tyvars_of_typ t) (tyvars_of_constraint nc) in List.fold_left (fun s k -> KidSet.remove k s) s kids and tyvars_of_typ_arg (Typ_arg_aux (ta,_)) = match ta with @@ -1123,7 +1123,7 @@ and tyvars_of_typ_arg (Typ_arg_aux (ta,_)) = let tyvars_of_quant_item (QI_aux (qi, _)) = match qi with | QI_id (KOpt_aux ((KOpt_none kid | KOpt_kind (_, kid)), _)) -> KidSet.singleton kid - | QI_const nc -> tyvars_of_nc nc + | QI_const nc -> tyvars_of_constraint nc let is_kid_generated kid = String.contains (string_of_kid kid) '#' @@ -1488,3 +1488,128 @@ and locate_fexps : 'a. l -> 'a fexps -> 'a fexps = fun l (FES_aux (FES_Fexps (fe and locate_fexp : 'a. l -> 'a fexp -> 'a fexp = fun l (FE_aux (FE_Fexp (id, exp), (_, annot))) -> FE_aux (FE_Fexp (locate_id l id, locate l exp), (l, annot)) + +(**************************************************************************) +(* 1. Substitutions *) +(**************************************************************************) + +let rec nexp_subst sv subst (Nexp_aux (nexp, l)) = Nexp_aux (nexp_subst_aux sv subst nexp, l) +and nexp_subst_aux sv subst = function + | Nexp_id v -> Nexp_id v + | Nexp_var kid -> if Kid.compare kid sv = 0 then subst else Nexp_var kid + | Nexp_constant c -> Nexp_constant c + | Nexp_times (nexp1, nexp2) -> Nexp_times (nexp_subst sv subst nexp1, nexp_subst sv subst nexp2) + | Nexp_sum (nexp1, nexp2) -> Nexp_sum (nexp_subst sv subst nexp1, nexp_subst sv subst nexp2) + | Nexp_minus (nexp1, nexp2) -> Nexp_minus (nexp_subst sv subst nexp1, nexp_subst sv subst nexp2) + | Nexp_app (id, nexps) -> Nexp_app (id, List.map (nexp_subst sv subst) nexps) + | Nexp_exp nexp -> Nexp_exp (nexp_subst sv subst nexp) + | Nexp_neg nexp -> Nexp_neg (nexp_subst sv subst nexp) + +let rec nexp_set_to_or l subst = function + | [] -> raise (Reporting_basic.err_unreachable l __POS__ "Empty set in constraint") + | [int] -> NC_equal (subst, nconstant int) + | (int :: ints) -> NC_or (mk_nc (NC_equal (subst, nconstant int)), mk_nc (nexp_set_to_or l subst ints)) + +let rec nc_subst_nexp sv subst (NC_aux (nc, l)) = NC_aux (nc_subst_nexp_aux l sv subst nc, l) +and nc_subst_nexp_aux l sv subst = function + | NC_equal (n1, n2) -> NC_equal (nexp_subst sv subst n1, nexp_subst sv subst n2) + | NC_bounded_ge (n1, n2) -> NC_bounded_ge (nexp_subst sv subst n1, nexp_subst sv subst n2) + | NC_bounded_le (n1, n2) -> NC_bounded_le (nexp_subst sv subst n1, nexp_subst sv subst n2) + | NC_not_equal (n1, n2) -> NC_not_equal (nexp_subst sv subst n1, nexp_subst sv subst n2) + | NC_set (kid, ints) as set_nc -> + if Kid.compare kid sv = 0 + then nexp_set_to_or l (mk_nexp subst) ints + else set_nc + | NC_or (nc1, nc2) -> NC_or (nc_subst_nexp sv subst nc1, nc_subst_nexp sv subst nc2) + | NC_and (nc1, nc2) -> NC_and (nc_subst_nexp sv subst nc1, nc_subst_nexp sv subst nc2) + | NC_app (id, nexps) -> NC_app (id, List.map (nexp_subst sv subst) nexps) + | NC_false -> NC_false + | NC_true -> NC_true + +let rec typ_subst_nexp sv subst (Typ_aux (typ, l)) = Typ_aux (typ_subst_nexp_aux sv subst typ, l) +and typ_subst_nexp_aux sv subst = function + | Typ_internal_unknown -> Typ_internal_unknown + | Typ_id v -> Typ_id v + | Typ_var kid -> Typ_var kid + | Typ_fn (arg_typs, ret_typ, effs) -> Typ_fn (List.map (typ_subst_nexp sv subst) arg_typs, typ_subst_nexp sv subst ret_typ, effs) + | Typ_bidir (typ1, typ2) -> Typ_bidir (typ_subst_nexp sv subst typ1, typ_subst_nexp sv subst typ2) + | Typ_tup typs -> Typ_tup (List.map (typ_subst_nexp sv subst) typs) + | Typ_app (f, args) -> Typ_app (f, List.map (typ_subst_arg_nexp sv subst) args) + | Typ_exist (kids, nc, typ) when KidSet.mem sv (KidSet.of_list kids) -> Typ_exist (kids, nc, typ) + | Typ_exist (kids, nc, typ) -> Typ_exist (kids, nc_subst_nexp sv subst nc, typ_subst_nexp sv subst typ) +and typ_subst_arg_nexp sv subst (Typ_arg_aux (arg, l)) = Typ_arg_aux (typ_subst_arg_nexp_aux sv subst arg, l) +and typ_subst_arg_nexp_aux sv subst = function + | Typ_arg_nexp nexp -> Typ_arg_nexp (nexp_subst sv subst nexp) + | Typ_arg_typ typ -> Typ_arg_typ (typ_subst_nexp sv subst typ) + | Typ_arg_order ord -> Typ_arg_order ord + +let rec typ_subst_typ sv subst (Typ_aux (typ, l)) = Typ_aux (typ_subst_typ_aux sv subst typ, l) +and typ_subst_typ_aux sv subst = function + | Typ_internal_unknown -> Typ_internal_unknown + | Typ_id v -> Typ_id v + | Typ_var kid -> if Kid.compare kid sv = 0 then subst else Typ_var kid + | Typ_fn (arg_typs, ret_typ, effs) -> Typ_fn (List.map (typ_subst_typ sv subst) arg_typs, typ_subst_typ sv subst ret_typ, effs) + | Typ_bidir (typ1, typ2) -> Typ_bidir (typ_subst_typ sv subst typ1, typ_subst_typ sv subst typ2) + | Typ_tup typs -> Typ_tup (List.map (typ_subst_typ sv subst) typs) + | Typ_app (f, args) -> Typ_app (f, List.map (typ_subst_arg_typ sv subst) args) + | Typ_exist (kids, nc, typ) -> Typ_exist (kids, nc, typ_subst_typ sv subst typ) +and typ_subst_arg_typ sv subst (Typ_arg_aux (arg, l)) = Typ_arg_aux (typ_subst_arg_typ_aux sv subst arg, l) +and typ_subst_arg_typ_aux sv subst = function + | Typ_arg_nexp nexp -> Typ_arg_nexp nexp + | Typ_arg_typ typ -> Typ_arg_typ (typ_subst_typ sv subst typ) + | Typ_arg_order ord -> Typ_arg_order ord + +let order_subst_aux sv subst = function + | Ord_var kid -> if Kid.compare kid sv = 0 then subst else Ord_var kid + | Ord_inc -> Ord_inc + | Ord_dec -> Ord_dec + +let order_subst sv subst (Ord_aux (ord, l)) = Ord_aux (order_subst_aux sv subst ord, l) + +let rec typ_subst_order sv subst (Typ_aux (typ, l)) = Typ_aux (typ_subst_order_aux sv subst typ, l) +and typ_subst_order_aux sv subst = function + | Typ_internal_unknown -> Typ_internal_unknown + | Typ_id v -> Typ_id v + | Typ_var kid -> Typ_var kid + | Typ_fn (arg_typs, ret_typ, effs) -> Typ_fn (List.map (typ_subst_order sv subst) arg_typs, typ_subst_order sv subst ret_typ, effs) + | Typ_bidir (typ1, typ2) -> Typ_bidir (typ_subst_order sv subst typ1, typ_subst_order sv subst typ2) + | Typ_tup typs -> Typ_tup (List.map (typ_subst_order sv subst) typs) + | Typ_app (f, args) -> Typ_app (f, List.map (typ_subst_arg_order sv subst) args) + | Typ_exist (kids, nc, typ) -> Typ_exist (kids, nc, typ_subst_order sv subst typ) +and typ_subst_arg_order sv subst (Typ_arg_aux (arg, l)) = Typ_arg_aux (typ_subst_arg_order_aux sv subst arg, l) +and typ_subst_arg_order_aux sv subst = function + | Typ_arg_nexp nexp -> Typ_arg_nexp nexp + | Typ_arg_typ typ -> Typ_arg_typ (typ_subst_order sv subst typ) + | Typ_arg_order ord -> Typ_arg_order (order_subst sv subst ord) + +let rec typ_subst_kid sv subst (Typ_aux (typ, l)) = Typ_aux (typ_subst_kid_aux sv subst typ, l) +and typ_subst_kid_aux sv subst = function + | Typ_internal_unknown -> Typ_internal_unknown + | Typ_id v -> Typ_id v + | Typ_var kid -> if Kid.compare kid sv = 0 then Typ_var subst else Typ_var kid + | Typ_fn (arg_typs, ret_typ, effs) -> Typ_fn (List.map (typ_subst_kid sv subst) arg_typs, typ_subst_kid sv subst ret_typ, effs) + | Typ_bidir (typ1, typ2) -> Typ_bidir (typ_subst_kid sv subst typ1, typ_subst_kid sv subst typ2) + | Typ_tup typs -> Typ_tup (List.map (typ_subst_kid sv subst) typs) + | Typ_app (f, args) -> Typ_app (f, List.map (typ_subst_arg_kid sv subst) args) + | Typ_exist (kids, nc, typ) when KidSet.mem sv (KidSet.of_list kids) -> Typ_exist (kids, nc, typ) + | Typ_exist (kids, nc, typ) -> Typ_exist (kids, nc_subst_nexp sv (Nexp_var subst) nc, typ_subst_kid sv subst typ) +and typ_subst_arg_kid sv subst (Typ_arg_aux (arg, l)) = Typ_arg_aux (typ_subst_arg_kid_aux sv subst arg, l) +and typ_subst_arg_kid_aux sv subst = function + | Typ_arg_nexp nexp -> Typ_arg_nexp (nexp_subst sv (Nexp_var subst) nexp) + | Typ_arg_typ typ -> Typ_arg_typ (typ_subst_kid sv subst typ) + | Typ_arg_order ord -> Typ_arg_order (order_subst sv (Ord_var subst) ord) + +let quant_item_subst_kid_aux sv subst = function + | QI_id (KOpt_aux (KOpt_none kid, l)) as qid -> + if Kid.compare kid sv = 0 then QI_id (KOpt_aux (KOpt_none subst, l)) else qid + | QI_id (KOpt_aux (KOpt_kind (k, kid), l)) as qid -> + if Kid.compare kid sv = 0 then QI_id (KOpt_aux (KOpt_kind (k, subst), l)) else qid + | QI_const nc -> QI_const (nc_subst_nexp sv (Nexp_var subst) nc) + +let quant_item_subst_kid sv subst (QI_aux (quant, l)) = QI_aux (quant_item_subst_kid_aux sv subst quant, l) + +let typquant_subst_kid_aux sv subst = function + | TypQ_tq quants -> TypQ_tq (List.map (quant_item_subst_kid sv subst) quants) + | TypQ_no_forall -> TypQ_no_forall + +let typquant_subst_kid sv subst (TypQ_aux (typq, l)) = TypQ_aux (typquant_subst_kid_aux sv subst typq, l) |
