econml.cate_interpreter.SingleTreePolicyInterpreter
- class econml.cate_interpreter.SingleTreePolicyInterpreter(*, include_model_uncertainty=False, uncertainty_level=0.05, uncertainty_only_on_leaves=True, risk_level=None, risk_seeking=False, max_depth=None, min_samples_split=2, min_samples_leaf=1, min_weight_fraction_leaf=0.0, max_features=None, min_balancedness_tol=0.45, min_impurity_decrease=0.0, random_state=None)[source]
基类:
econml.cate_interpreter._interpreters._SingleTreeInterpreter
基于 CATE 估计的策略解释器
- 参数
include_model_uncertainty (bool, default False) – 构建简化的 CATE 模型时是否包含置信区间信息。如果设置为 True,则 CATE 估计器需要支持 const_marginal_ate_inference 方法。
uncertainty_level (double, default 0.05) – 用于构建和简化模型中使用的置信区间的未确定性水平。如果 value=alpha,则将构建一个多任务决策树,使得叶节点中的所有样本具有相似的目标预测和相似的 alpha 置信区间。
uncertainty_only_on_leaves (bool, default True) – 未确定性信息是否只显示在叶节点上。如果为 False,则解释可能会稍慢,特别是对于计算成本高昂的推断方法的 CATE 模型。
risk_level (float, optional) – 如果为 None,则每个点的 CATE 点估计值将用作处理效应。如果为任何浮点数 alpha 且 risk_seeking=False (默认),则将使用 CATE 的 alpha 置信区间的下限。否则,如果 risk_seeking=True,则将使用 alpha 置信区间的上限。
risk_seeking (bool, default False,) – 是否对样本点的效应估计使用乐观或悲观的值。仅当 risk_level 不为 None 时使用。
max_depth (int, optional) – 树的最大深度。如果为 None,则节点将一直展开,直到所有叶节点都是纯的,或者直到所有叶节点包含少于 min_samples_split 个样本。
min_samples_split (int, float, default 2) – 分割内部节点所需的最小样本数
如果为 int,则将 min_samples_split 视为最小样本数。
如果为 float,则 min_samples_split 是一个分数,ceil(min_samples_split * n_samples) 是每次分割所需的最小样本数。
min_samples_leaf (int, float, default 1) – 叶节点所需的最小样本数。只有当分割点在左右分支中都留下至少
min_samples_leaf
个训练样本时,才会考虑任何深度的分割点。这可能有助于平滑模型,尤其是在回归中。如果为 int,则将 min_samples_leaf 视为最小样本数。
如果为 float,则 min_samples_leaf 是一个分数,ceil(min_samples_leaf * n_samples) 是每个节点所需的最小样本数。
min_weight_fraction_leaf (float, default 0.) – 叶节点所需的总权重(所有输入样本的)的最小加权分数。未提供 sample_weight 时,样本具有相同的权重。
max_features (int, float, {“auto”, “sqrt”, “log2”}, 或 None, default None) – 寻找最佳分割时考虑的特征数量
如果为 int,则每次分割考虑 max_features 个特征。
如果为 float,则 max_features 是一个分数,每次分割考虑 int(max_features * n_features) 个特征。
如果为 “auto”,则 max_features=n_features。
如果为 “sqrt”,则 max_features=sqrt(n_features)。
如果为 “log2”,则 max_features=log2(n_features)。
如果为 None,则 max_features=n_features。
注意:寻找分割的过程不会停止,直到找到至少一个有效的节点样本划分,即使这需要实际检查超过
max_features
个特征。min_balancedness_tol (float in [0, .5], default .45) – 我们能容忍的分割不平衡程度。这强制要求每次分割在分割的两侧至少留下 (.5 - min_balancedness_tol) 分数的样本;或者在 sample_weight 不为 None 时,留下总样本权重的相应分数。默认值确保父节点权重的至少 5% 落入分割的每一侧。将其设置为 0.0 表示不强制平衡,设置为 .5 表示完美平衡的分割。为了使正式推断理论有效,这必须是任何大于零的有界常数。
min_impurity_decrease (float, default 0.) – 如果此分割引起的杂质减少量大于或等于此值,则节点将被分割。
加权杂质减少方程如下
N_t / N * (impurity - N_t_R / N_t * right_impurity - N_t_L / N_t * left_impurity)
其中
N
是样本总数,N_t
是当前节点的样本数,N_t_L
是左子节点的样本数,N_t_R
是右子节点的样本数。N
、N_t
、N_t_R
和N_t_L
都指加权和,如果传递了sample_weight
。random_state (int, RandomState 实例, 或 None, default None) – 如果为 int,则 random_state 是随机数生成器使用的种子;如果为 RandomState 实例,则 random_state 是随机数生成器;如果为 None,则随机数生成器是 np.random 使用的 RandomState 实例。
- tree_model_
表示学习到的策略的策略树模型;仅在调用
interpret()
后可用。- 类型
- policy_value_
将学习到的策略应用于与
interpret()
一起使用的样本后的值- 类型
- always_treat_value_
将始终处理所有单位的策略应用于与
interpret()
一起使用的样本后的值- 类型
- __init__(*, include_model_uncertainty=False, uncertainty_level=0.05, uncertainty_only_on_leaves=True, risk_level=None, risk_seeking=False, max_depth=None, min_samples_split=2, min_samples_leaf=1, min_weight_fraction_leaf=0.0, max_features=None, min_balancedness_tol=0.45, min_impurity_decrease=0.0, random_state=None)[source]
方法
__init__
(*[, include_model_uncertainty, ...])export_graphviz
([out_file, feature_names, ...])导出表示学习到的树模型的 graphviz dot 文件
interpret
(cate_estimator, X[, ...])将基于线性 CATE 估计器的策略应用于一组特征时进行解释
plot
([ax, title, feature_names, ...])将策略树导出到 matplotlib
render
(out_file[, format, view, ...])将树渲染到文件
treat
(X)使用通过调用
interpret()
学习到的策略模型,为一组单位分配处理属性
node_dict_
- export_graphviz(out_file=None, feature_names=None, treatment_names=None, max_depth=None, filled=True, leaves_parallel=True, rotate=False, rounded=True, special_characters=False, precision=3)
导出表示学习到的树模型的 graphviz dot 文件
- 参数
out_file (文件对象或 str, optional) – 输出文件的句柄或名称。如果为
None
,结果将作为字符串返回。feature_names (str 列表, optional) – 每个特征的名称。
treatment_names (str 列表, optional) – 每个处理的名称
max_depth (int, optional) – 要绘制的最大树深度
filled (bool, default False) – 设置为
True
时,为节点着色以指示分类的多数类别、回归的值的极端性或多输出的节点纯度。leaves_parallel (bool, default True) – 设置为
True
时,将所有叶节点绘制在树的底部。rotate (bool, default False) – 设置为
True
时,将树的方向从自上而下改为从左到右。rounded (bool, default True) – 设置为
True
时,绘制圆角节点框并使用 Helvetica 字体而不是 Times-Roman。special_characters (bool, default False) – 设置为
False
时,为 PostScript 兼容性忽略特殊字符。precision (int, default 3) – 每个节点的 impurity、threshold 和 value 属性中浮点数的精度位数。
- interpret(cate_estimator, X, sample_treatment_costs=None)[source]
将基于线性 CATE 估计器的策略应用于一组特征时进行解释
- 参数
cate_estimator (
LinearCateEstimator
) – 要解释的已拟合估计器X (array_like) – 用于解释估计器的特征;形状必须与用于拟合估计器的特征兼容
sample_treatment_costs (array_like, optional) – 处理的成本。可以是标量,或维度为 (n_samples, n_treatments),或者如果 T 是向量,则维度为 (n_samples,)
- 返回
self
- 返回类型
object 实例
- plot(ax=None, title=None, feature_names=None, treatment_names=None, max_depth=None, filled=True, rounded=True, precision=3, fontsize=None)
将策略树导出到 matplotlib
- 参数
ax (
matplotlib.axes.Axes
, optional) – 要绘制的轴title (str, optional) – 要打印在页面顶部的最终图形标题。
feature_names (str 列表, optional) – 每个特征的名称。
treatment_names (str 列表, optional) – 每个处理的名称
max_depth (int, optional) – 要绘制的最大树深度
filled (bool, default False) – 设置为
True
时,为节点着色以指示分类的多数类别、回归的值的极端性或多输出的节点纯度。rounded (bool, default True) – 设置为
True
时,绘制圆角节点框并使用 Helvetica 字体而不是 Times-Roman。precision (int, default 3) – 每个节点的 impurity、threshold 和 value 属性中浮点数的精度位数。
fontsize (int, optional) – 文本字体大小
- render(out_file, format='pdf', view=True, feature_names=None, treatment_names=None, max_depth=None, filled=True, leaves_parallel=True, rotate=False, rounded=True, special_characters=False, precision=3)
将树渲染到文件
- 参数
out_file (要保存到的文件名)
format (str, default ‘pdf’) – 要渲染到的文件格式;必须由 graphviz 支持
view (bool, default True) – 是否使用默认应用程序打开渲染结果。
feature_names (str 列表, optional) – 每个特征的名称。
treatment_names (str 列表, optional) – 每个处理的名称
max_depth (int, optional) – 要绘制的最大树深度
filled (bool, default False) – 设置为
True
时,为节点着色以指示分类的多数类别、回归的值的极端性或多输出的节点纯度。leaves_parallel (bool, default True) – 设置为
True
时,将所有叶节点绘制在树的底部。rotate (bool, default False) – 设置为
True
时,将树的方向从自上而下改为从左到右。rounded (bool, default True) – 设置为
True
时,绘制圆角节点框并使用 Helvetica 字体而不是 Times-Roman。special_characters (bool, default False) – 设置为
False
时,为 PostScript 兼容性忽略特殊字符。precision (int, default 3) – 每个节点的 impurity、threshold 和 value 属性中浮点数的精度位数。
- treat(X)[source]
使用通过调用
interpret()
学习到的策略模型,为一组单位分配处理- 参数
X (array_like) – 要处理的单位的特征;形状必须与解释期间使用的特征兼容
- 返回
T – 解释器学习到的策略所隐含的处理,其中处理 0 表示未处理,处理 1 表示第一个处理,依此类推。
- 返回类型
array_like