econml.cate_interpreter.SingleTreeCateInterpreter

class econml.cate_interpreter.SingleTreeCateInterpreter(*, include_model_uncertainty=False, uncertainty_level=0.05, uncertainty_only_on_leaves=True, splitter='best', max_depth=None, min_samples_split=2, min_samples_leaf=1, min_weight_fraction_leaf=0.0, max_features=None, random_state=None, max_leaf_nodes=None, min_impurity_decrease=0.0)[source]

基类: econml.cate_interpreter._interpreters._SingleTreeInterpreter

一个用于解释 CATE 估计器估计的效应的解释器

参数
  • include_model_uncertainty (bool, default False) – 构建简化版 CATE 模型时是否包含置信区间信息。如果设置为 True,则 CATE 估计器需要支持 const_marginal_ate_inference 方法。

  • uncertainty_level (double, default 0.05) – 用于构建简化模型时使用的置信区间的 uncertainty level。如果 value=alpha,则将构建一个多任务决策树,使得叶节点中的所有样本具有相似的目标预测以及相似的 alpha 置信区间。

  • uncertainty_only_on_leaves (bool, default True) – 不确定性信息是否仅显示在叶节点上。如果为 False,则解释可能会稍慢,特别是对于推断方法计算成本较高的 CATE 模型而言。

  • splitter (str, default “best”) – 用于选择每个节点分裂的策略。支持的策略有:“best” 表示选择最佳分裂,“random” 表示选择最佳随机分裂。

  • max_depth (int, optional) – 树的最大深度。如果为 None,则节点会持续扩展,直到所有叶节点都是纯的,或者所有叶节点包含的样本少于 min_samples_split。

  • min_samples_split (int, float, default 2) – 分裂内部节点所需的最小样本数

    • 如果为 int,则将其视为最小样本数。

    • 如果为 float,则表示一个分数比例,且 ceil(min_samples_split * n_samples) 为每次分裂所需的最小样本数。

  • min_samples_leaf (int, float, default 1) – 叶节点所需的最小样本数。仅当分裂点在左右分支中分别至少留下 min_samples_leaf 个训练样本时,才会考虑该分裂点。这可能有助于平滑模型,尤其是在回归任务中。

    • 如果为 int,则将其视为最小样本数。

    • 如果为 float,则表示一个分数比例,且 ceil(min_samples_leaf * n_samples) 为每个节点所需的最小样本数。

  • min_weight_fraction_leaf (float, default 0.) – 叶节点所需的总权重(所有输入样本的总权重)的最小加权分数。未提供 sample_weight 时,样本权重相等。

  • max_features (int, float, {“auto”, “sqrt”, “log2”}, or None, default None) – 寻找最佳分裂时考虑的特征数量

    • 如果为 int,则在每次分裂时考虑 max_features 个特征。

    • 如果为 float,则表示一个分数比例,且在每次分裂时考虑 int(max_features * n_features) 个特征。

    • 如果为 “auto”,则 max_features=n_features

    • 如果为 “sqrt”,则 max_features=sqrt(n_features)

    • 如果为 “log2”,则 max_features=log2(n_features)

    • 如果为 None,则 max_features=n_features

    注意: 寻找分裂点的过程不会停止,直到找到至少一个有效的节点样本分区,即使这实际上需要检查超过 max_features 个特征。

  • random_state (int, RandomState instance, or None, default None) – 如果为 int,则 random_state 为随机数生成器使用的种子;如果为 RandomState 实例,则 random_state 为随机数生成器;如果为 None,则随机数生成器为 np.random 使用的 RandomState 实例。

  • max_leaf_nodes (int, optional) – 以最佳优先方式生成具有 max_leaf_nodes 个叶节点的树。最佳节点定义为杂质的相对减少量。如果为 None,则叶节点数量不受限制。

  • min_impurity_decrease (float, default 0.) – 如果此分裂导致的杂质减少量大于或等于此值,则节点将被分裂。

    加权杂质减少量的计算公式如下

    N_t / N * (impurity - N_t_R / N_t * right_impurity
                        - N_t_L / N_t * left_impurity)
    

    其中 N 是样本总数,N_t 是当前节点的样本数,N_t_L 是左子节点的样本数,N_t_R 是右子节点的样本数。如果传递了 sample_weight,则 N, N_t, N_t_RN_t_L 都指加权总和。

__init__(*, include_model_uncertainty=False, uncertainty_level=0.05, uncertainty_only_on_leaves=True, splitter='best', max_depth=None, min_samples_split=2, min_samples_leaf=1, min_weight_fraction_leaf=0.0, max_features=None, random_state=None, max_leaf_nodes=None, min_impurity_decrease=0.0)[source]

方法

__init__(*[, include_model_uncertainty, ...])

export_graphviz([out_file, feature_names, ...])

导出表示学习到的树模型的 graphviz dot 文件

interpret(cate_estimator, X)

解释将 CATE 估计器应用于一组特征时的异质性

plot([ax, title, feature_names, ...])

将策略树导出到 matplotlib

render(out_file[, format, view, ...])

将树渲染到文件

属性

node_dict_

tree_model_

export_graphviz(out_file=None, feature_names=None, treatment_names=None, max_depth=None, filled=True, leaves_parallel=True, rotate=False, rounded=True, special_characters=False, precision=3)

导出表示学习到的树模型的 graphviz dot 文件

参数
  • out_file (file object or str, optional) – 输出文件的句柄或名称。如果为 None,结果将作为字符串返回。

  • feature_names (list of str, optional) – 每个特征的名称。

  • treatment_names (list of str, optional) – 每个处理的名称

  • max_depth (int, optional) – 要绘制的最大树深度

  • filled (bool, default False) – 设置为 True 时,节点将着色以指示分类任务中的多数类别、回归任务中的值极端性或多输出任务中的节点纯度。

  • leaves_parallel (bool, default True) – 设置为 True 时,将所有叶节点绘制在树的底部。

  • rotate (bool, default False) – 设置为 True 时,将树的方向设置为从左到右,而不是从上到下。

  • rounded (bool, default True) – 设置为 True 时,绘制圆角节点框,并使用 Helvetica 字体而非 Times-Roman。

  • special_characters (bool, default False) – 设置为 False 时,为兼容 PostScript 而忽略特殊字符。

  • precision (int, default 3) – 每个节点的 impurity、threshold 和 value 属性中浮点数的精度位数。

interpret(cate_estimator, X)[source]

解释将 CATE 估计器应用于一组特征时的异质性

参数
  • cate_estimator (LinearCateEstimator) – 要解释的已拟合估计器

  • X (array_like) – 用于解释估计器的特征;必须与用于拟合估计器的特征在形状上兼容

返回值

self

返回类型

对象实例

plot(ax=None, title=None, feature_names=None, treatment_names=None, max_depth=None, filled=True, rounded=True, precision=3, fontsize=None)

将策略树导出到 matplotlib

参数
  • ax (matplotlib.axes.Axes, optional) – 要绘制的坐标轴

  • title (str, optional) – 要打印在页面顶部的最终图形的标题。

  • feature_names (list of str, optional) – 每个特征的名称。

  • treatment_names (list of str, optional) – 每个处理的名称

  • max_depth (int, optional) – 要绘制的最大树深度

  • filled (bool, default False) – 设置为 True 时,节点将着色以指示分类任务中的多数类别、回归任务中的值极端性或多输出任务中的节点纯度。

  • rounded (bool, default True) – 设置为 True 时,绘制圆角节点框,并使用 Helvetica 字体而非 Times-Roman。

  • precision (int, default 3) – 每个节点的 impurity、threshold 和 value 属性中浮点数的精度位数。

  • fontsize (int, optional) – 文本的字体大小

render(out_file, format='pdf', view=True, feature_names=None, treatment_names=None, max_depth=None, filled=True, leaves_parallel=True, rotate=False, rounded=True, special_characters=False, precision=3)

将树渲染到文件

参数
  • out_file (要保存的文件名)

  • format (str, default ‘pdf’) – 要渲染的文件格式;必须受 graphviz 支持

  • view (bool, default True) – 是否使用默认应用程序打开渲染结果。

  • feature_names (list of str, optional) – 每个特征的名称。

  • treatment_names (list of str, optional) – 每个处理的名称

  • max_depth (int, optional) – 要绘制的最大树深度

  • filled (bool, default False) – 设置为 True 时,节点将着色以指示分类任务中的多数类别、回归任务中的值极端性或多输出任务中的节点纯度。

  • leaves_parallel (bool, default True) – 设置为 True 时,将所有叶节点绘制在树的底部。

  • rotate (bool, default False) – 设置为 True 时,将树的方向设置为从左到右,而不是从上到下。

  • rounded (bool, default True) – 设置为 True 时,绘制圆角节点框,并使用 Helvetica 字体而非 Times-Roman。

  • special_characters (bool, default False) – 设置为 False 时,为兼容 PostScript 而忽略特殊字符。

  • precision (int, default 3) – 每个节点的 impurity、threshold 和 value 属性中浮点数的精度位数。