分享

如何信任机器学习模型的预测结果?(上)

 精诚至_金石开 2020-12-19

也就是说,如何确定机器学习模型预测的结果是符合常理的,进而确定所选择的机器学习模型是可信的。
关于这个问题,我将通过两个篇幅向大家介绍机器学习模型的可信性,即机器学习预测结果的解释,以及 MATLAB 如何支持对机器学习模型预测结果的解释,并通过一个例子说明在 MATLAB 中的实现过程。
首先介绍机器学习模型预测结果的解释及 MATLAB 的支持情况。



1



机器学习模型的可信性

机器学习的目的是在训练数据集中学习到一个最优的模型,并且这个最优模型对未知数据有很好的预测能力(即模型具有很好的泛化能力)。
广泛使用的衡量模型好坏的标准是模型精度,即模型在训练集上的准确度。而在追求模型准确度的过程中,往往会产生过拟合问题,在训练样本不足的情况这个问题尤为严重。模型在训练集上有很好的精度,但是在测试集上精度确很差。因此,仅仅是通过模型的精度进行模型的选择,往往会使得模型的可信度变低,尤其是在过拟合的情况下。
为了避免过拟合,提高机器学习模型的精度,正则化和交叉验证是最常采用的方法。
正则化是遵循结构化风险最小化原则,即减少模型中非零参数的数量(也就是减少变量的个数或者特征的个数),进而从模型结构上降低复杂度。尤其在训练样本不足的情况下,过多的变量或数据特征容易造成模型的过拟合。交叉验证是通过引入验证集提高模型的泛化能力。
交叉验证将训练集进行 k 份(也叫k-fold)的切分,其中 k-1 份用于训练,1 份用于验证。并且不断的交叉选择训练集和验证集,得到 k 个模型训练结果精度,使用 k 个结果的平局值作为最终的模型精度。
不论是正则化的方法,还是交叉验证的方法,选择模型的依据都是比较模型的精度。但是,模型预测结果的合理性或可信性,是决定能否将机器学习模型应用到生产环境下的重要因素。特别是,对于复杂的非线性模型,我们并不清楚为什么能够取得某种预测结果,是那些变量或那些数据特征影响模型的预测结果。这往往使得复杂机器学习模型变成一个黑盒。黑盒模型带来了可信度的问题:
我是否可以相信机器学习模型的预测结果?进而,我是否可以相信训练生成的机器学习模型?
这就引出了今天将要给大家介绍的内容:机器学习模型的可信性
这里的可信性是指:是否可以对机器学习模型的预测结果给出一个合理的解释,能够定性地呈现出模型输入数据的特征和模型输出的预测结果之间的关系。
通俗的说,我得到的预测结果是受到哪些因素的影响。领域专家往往对研究的领域具有先验知识,当对模型预测结果的解释与先验知识相匹配时,机器学习模型预测结果的可信性就会大大提升,否则就会拒绝预测结果。
如何实现对机器学习模型的预测结果进行解释呢?
Ribeiro 等人在论文[1]中提出了 LIME(Local Interpretable Model-Agnostic Explanations)。LIME 的目标是:对于一个训练好的机器学习模型,找到一个适合的可解释模型,用于解释机器学习模型的预测结果,也就是定性分析输入数据的特征与预测结果的相关性。这里提到的可解释模型是指模型本身具有自解释性,这中自解释性是由模型结构特性决定的。这类模型主要是指线性模型和决策树模型。

  • 线性模型
线性模型 f 可以简单表示为如下的数学表达:

其中,x1,x2,…,xm 表示输入数据的特征(m 维),ω12,…,ωm 为模型的系数,系数绝对值的大小表明了输入特征对结果的影响程度。因此,线性模型的可解释性是通过系数的绝对值大小表示的。
  • 决策树模型

决策树模型是利用特征对数据进行划分,而选择特征的依据就是特征值能否将数据完全分隔。
如果通过特征值可以将数据完全分隔开(相同标注的数据划分在一起),那么该特征的重要性就会升高;如果通过特征值不能分隔数据(不同标注的数据混在一起),那么这个特征的重要性就会降低。
因此,特征划分数据的能力就体现了特征对结果的影响程度。因此,决策树的可解释性是通过特征的划分能力表示的。如何选择特征进行数据划分,这就要依据于所选择的决策树算法。决策树算法包括:ID3、C4.5 和 CART。关于这些算法,有很多资料可以参考,在这里就不进行说明了。MATLAB 的决策树使用的 CART 算法,决策树既可以用于数据分类也可以用于数据回归。
LIME 正是基于线性模型或决策树模型构建用于解释复杂机器学习模型的可解释模型。对于复杂的机器学习模型,例如支持向量机(SVM),LIME 模型定性说明输入的数据特征对预测结果的影响程度。
LIME 实现对预测结果的解释的方式是:以单个预测数据为基础,通过增加扰动生成合成数据集。合成数据集包含扰动数据以及对应的预测结果。在合成数据集上找到一个可以与复杂的分类模型或回归模型近似的可解释模型,并利用可解释模型来解释原始的机器学习模型的预测结果。
LIME 主要包含三个部分:
  • Local fidelity
Local fidelity 表示 LIME 的局部保真或局部相似性。意味着 LIME 生成的可解释模型必须与复杂机器学习模型在被预测实例附近的预测结果相一致。
  • Interpretable

Interpretable 表示 LIME 的可解释性,这种解释是直接反应输入的数据特征(或变量)对预测结果的影响程度,并且这种解释会重点突出部分数据特征或输入变量而不是全部,因此说 LIME 模型可以给出比较容易理解的解释。
  • Model Agnostic

Model Agnostic 表示 LIME 的模型无关性。LIME 将被解释的原始模型视为一个黑盒(black box),并不关心原始模型的内部细节。LIME 只是通过在输入预测数据的局部范围内生成扰动数据,以及对应的预测结果,进而训练出一个局部近似的可解释模型,该模型与原始机器学习模型的预测结果一致。
上述的三个组成部分也就决定的 LIME 的工作过程。LIME 是对局部数据训练可解释模型。假设有机器学习模型 f,对于一个预测数据 xx 有 m 维或 m 个特征 :

(m 表示数据的维度或特征数)
机器学习模型 f x 的预测结果是 p

针对预测结果 p,LIME 解释过程如下:
基于概率分布,在 x 的局部范围内随机生成的多个数据样本(称之为合成数据(synthetic data)),由这些数据构成合成数据集 X’

(n 表示数据集中样本的个数,并且每个样本 x_i^'都具有 m 个维度)
对X’中的样本使用机器学习模型 f 进行预测,并生成结果 P’

即:

由此,产生了一个新的数据集 D = (X’, P’ ),LIME 在新的数据集 D 上训练一个简单的可解释模型 f’(例如,线性模型),并且,该可解释模型对于预测数据 x 的预测结果与机器学习模型 f 的预测结果近似。即:

因此,通过可解释模型,可以定性的给出数据特征对预测结果的影响程度,即参数 ω12,…,ωm 的绝对值表示特征对结果的影响程度。



2



MATLAB 对此的支持
MATLAB 2020b 版本中的统计和机器学习工具箱(Statistics and Machine Learning Toolbox)实现了对 LIME 的支持, 主要是通过 lime, fit 和 plot 三个函数。
1. lime
lime 是实现 LIME 的主函数。通过 lime 可以生成一个 LIME 的对象(如(1)所示,results 表示生成的 LIME 对象)。
 results = lime(blackbox)                 (1)
其中,blackbox是被解释的机器学习模型。lime提供了很多参数用于对LIME进行配置,包括:
  • 'QueryPoint'

QueryPoint 是LIME用于生成可解释模型的基础输入数据(也是预测数据),QueryPoint的值对应单个输入数据。当在lime函数中设置’QueryPoint‘,可以直接生成可解释模型。
'CategoricalPredictors' 
CategoricalPredictors指明输入数据的分类变量,可以是索引数组、列名向量等。
  • 'NumImportantPredictors'

可解释模型的特征数量。NumImportantPredictors 指定用于解释预测结果的特征个数
  • 'DataLocality'

指定评估概率分布的数据范围,该概率分布用于随机生成合成数据。因此DataLocality 决定数据“局部性“的范围大小。”'gloal' 表示使用机器学习模型的训练集评估概率分布,因此针对不同的输入数据(QueryPoint),生成合成数据的概率分布式一样的;'local' 表示使用输入数据的局部范围内的数据评估概率分布。使用“局部的大小”的由参数'NumNeighbours' 确定(默认值为 1500)。因此,当'DataLocality' 指定为'local' 时,需要与'NumNeighbours'配合使用,以生成概率分布的参数。
  • 'NumSyntheticData'

合成数据集的样本数量。此参数仅在'DataLocality'为'local'时有效。
  • 'SimpleModelType' 

可解释模型的类型,包括线性模型和决策树模型
  • 'Type'

被解释的‘黑盒’机器学习模型的类别,包括分类模型和回归模型
2. fit
fit 用于生成 LIME 的可解释模型。当 lime 没有指定'QueryPoint'时,需要使用 fit 函数生成可解释模型。fit 的标准形式如下
newresults = fit(results, queryPoint, numImportantPredictors)
fit 包含三个必须输入参数:
  • results,lime 函数的输出结果;

  • queryPoint,输入数据;

  • numImportantPredictors,可解释模型的特征数量。

除了必选参数,fit 也包含 'NumNeighbors','NumSyntheticData','SimpleModelType' 等在 lime 函数中出现的可选参数。并且 fit 会覆盖 lime 的对应的参数值。
3. plot
plot 用于可视化 LIME 的结果。plot 输出两种结果:水平条形图,显示线性简单模型的系数值或决策树变量(特征)的重要性值;原始机器学习模型和新生成的可解释模型的的两个预测结果。通过 plot 的结果可以直观看到那些特征或变量对原始模机器学习模型预测结果的有重要影响,并且按照重要性进行排序。
如何利用 lime、fit 和 plot 函数实现对复杂机器学习模型的解释?
下一篇中,我们将通过一个样例介绍实现过程。敬请期待!

    ◆  

参考文献
[1] Ribeiro, Marco Tulio, S. Singh, and C. Guestrin. ''Why Should I Trust You?': Explaining the Predictions of Any Classifier.' In Proceedings of the 22nd ACM SIGKDD International Conference on Knowledge Discovery and Data Mining, 1135–44. San Francisco California USA: ACM, 2016.

    本站是提供个人知识管理的网络存储空间,所有内容均由用户发布,不代表本站观点。请注意甄别内容中的联系方式、诱导购买等信息,谨防诈骗。如发现有害或侵权内容,请点击一键举报。
    转藏 分享 献花(0

    0条评论

    发表

    请遵守用户 评论公约

    类似文章 更多