可解释性

我们的软件包提供了多种可解释性工具,以更好地理解最终模型的 CATE。

树解释器

树解释器提供了易于展示的总结,说明了哪些关键特征解释了对干预响应性的最大差异。

SingleTreeCateInterpreter 针对您从我们任何可用的 CATE 估计器中学到的处理效应 \(\theta(X)\),在一小组您感兴趣的学习异质性的特征 \(X\) 上训练一棵浅层决策树。模型将在最大化每个叶子中处理效应差异的分割点上进行分割。最终,每个叶子将成为一个样本子组,这些样本对处理的响应方式与其他叶子不同。

例如

from econml.cate_interpreter import SingleTreeCateInterpreter
from econml.dml import LinearDML
est = LinearDML()
est.fit(y, t, X=X, W=W)
intrp = SingleTreeCateInterpreter(include_model_uncertainty=True, max_depth=2, min_samples_leaf=10)
# We interpret the CATE model's behavior based on the features used for heterogeneity
intrp.interpret(est, X)
# Plot the tree
intrp.plot(feature_names=['A', 'B', 'C'], fontsize=12)

策略解释器

策略解释器提供类似的功能,但考虑了成本。

不同于通过拟合树来学习具有不同处理效应的群体,SingleTreePolicyInterpreter 尝试将样本分割成不同的处理组。因此,在二元处理的情况下,它试图创建子组,使该组内的所有样本都具有正向效应或负向效应。因此,它试图将响应者与非响应者分开,而不是试图找到具有不同响应水平的群体。

通过这种方式,您可以构建一个可解释的个性化策略,对具有正向效应的群体进行处理,而不对具有负向效应的群体进行处理。我们的策略树在每个叶子节点提供推荐的处理。

例如

from econml.cate_interpreter import SingleTreePolicyInterpreter
# We find a tree-based treatment policy based on the CATE model
# sample_treatment_costs is the cost of treatment. Policy will treat if effect is above this cost.
intrp = SingleTreePolicyInterpreter(risk_level=None, max_depth=2, min_samples_leaf=1,min_impurity_decrease=.001)
intrp.interpret(est, X, sample_treatment_costs=0.02)
# Plot the tree
intrp.plot(feature_names=['A', 'B', 'C'], fontsize=12)

SHAP

SHAP 是一个流行的开源库,用于使用 Shapley 值方法解释黑盒机器学习模型(例如,参见 [Lundberg2017])。

类似于如何使用 SHAP 解释黑盒预测机器学习模型,我们也可以解释黑盒效应异质性模型。这种方法解释了为什么异质因果效应模型对特定人群段产生了更大或更小的效应值。哪些特征导致了这种差异?当模型被简洁描述时,这个问题很容易解决,例如线性异质性模型,其中可以简单地研究模型的系数。然而,当开始使用更具表达力的模型时,例如使用随机森林和因果森林来建模效应异质性时,这就变得困难了。SHAP 值对于理解模型从训练数据中捕捉到的效应异质性的主要因素非常有帮助。

我们的软件包提供了与 SHAP 库的无缝集成。每个 CATE 估计器都有一个方法 shap_values,它返回每个处理和结果对的估计器输出的 SHAP 值解释。然后可以使用 SHAP 库提供的丰富的可视化工具对这些值进行可视化。此外,只要可能,我们的库会针对每种类型的最终模型调用 SHAP 库中快速的专用算法,这可以大大减少计算时间。

例如

import shap
from econml.dml import LinearDML
est = LinearDML()
est.fit(y, t, X=X, W=W)
shap_values = est.shap_values(X)
# local view: explain heterogeneity for a given observation
ind=0
shap.plots.force(shap_values["Y0"]["T0"][ind], matplotlib=True)
# global view: explain heterogeneity for a sample of dataset
shap.summary_plot(shap_values['Y0']['T0'])