分享

【科普】半监督学习的概述与思考,及其在联邦学习场景下的应用

 我爱计算机视觉 2022-03-03

在现实世界中,数据往往存在各种各样的问题,例如:图片分类模型对标注数据的依赖性很强、标注图片数据难以获取、大量未标注数据存在、针对某个场景的数据量过小…等等问题。

在联邦学习中,由于数据的非独立同分布特性(Non-IID)导致了每个客户端(数据拥有者)自身的数据可能存在噪声、标注不完全、数据量不够等等情况,同时我们从隐私安全的方面考虑到只要使用数据,就有可能存在隐私泄露的风险,因此有人思考到:能否只从每个数据拥有方抽取一小部分数据(含有标签)放到客户端,然后再添加大量无标注数据来帮助模型进行训练呢?

这就不得不提到半监督学习,半监督学习是指训练集同时包含带标签的样本数据以及未标记的样本数据,在不需要人工干预的情况下,让模型可以自动利用未标记样本数据来提升自己的学习性能。

进一步而言,本质就是模型利用已标注数据从未标注数据中提取信息用于自身训练,同时有些情况下如果标注数据很多,那么再利用未标注数据可以提升模型的泛化能力。如下图所示,半监督学习可以在标签不知道的情况下让模型也能正确完成任务。

图1:半监督学习example(链接1)

以往的联邦学习工作往往专注于监督学习任务的研究,即要求所有的数据都必须包含相应的标签。但是在联邦学习的现实场景下,本地客户端所包含的数据常常大部分甚至全部都是没有相应的标签的。因此,我们需要结合半监督学习的技术来改进算法从而解决联邦学习中存在的一些问题。

接下来我将从近期的一些论文从半监督学习与联邦半监督学习两个方面进行概述,并加以总结和思考。

 半监督学习

伪标签(Pseudo Label):伪标签是半监督学习的一个基本思路,即模型在标注数据上进行训练然后对未标注数据进行预测,得到的预测结果作为未标记数据的标签。

模型通过伪标签的方式,可以将未标注样本拉向与其最相邻的类,然后再训练时(有标签数据+无标签数据一起训练)这就相当于约束了模型对无标签的搜索空间,但是也有一个缺点就是这个伪标签的预测是不分对错的,其仅仅只是提高了模型对该样本的置信度。

同时我们也可以从熵的角度去思考,强迫模型对未标记数据做出预测,这就代表熵降低了,模型偏向于低熵预测,通过最小化熵将模型预测拉向当前最邻近的类别。伪标签方法带来了若干好处:

1)增强模型的泛化能力和鲁棒性;

2)模型充分利用了无标注数据。

最小化熵(Entropy Minimization):同时我们也可以从概率的角度去思考,熵代表事件的混乱程度,如果一个事件分布越均衡那么其熵越高。强迫模型对未标记数据做出预测,这就代表熵降低了,模型偏向于低熵预测,通过最小化熵将模型预测拉向当前最邻近的类别。强迫模型如何作出低熵预测呢?

实现方法其实很简单,就在损失函数中增加一项loss:最小P_model(y|x)对应的熵即可。如下图公式所示:

标签锐化(Sharpen Label):标签锐化也是在半监督学习方面常用的一个技巧,模型预测某个样本得出结果logit,这是一个矩阵,矩阵的每一列代表该数据样本所对应类别的概率。为了防止噪音等因素干扰,使用标签锐化操作,其实就是一个基于softmax函数的放大最高概率的方法,特别地,当锐化因子T→0的时候,趋于one-hot。具体如下:

一致性正则化(Consistency Regularization)& 数据增强(Data Augmented):一致性正则化(Consistency Regularization)依赖于模型在输入同一图像的扰动版本时应该输出相似预测的假设,用于未标记数据。

我们定义数据为x,某种弱数据增强方法(例如翻转、平移等)为a()函数,某种强数据增强方法(例如RandAugment、CTAugment)为A()函数,定义模型为f(),一致性正则化具体而言有如下方式:

1)f(x)应当与f(a(x))具有相同的预测;

2)f(a(x))应当与f(A(x))具有相同的预测;

此时模型f()应当是两个结构相同但是Dropout不同的模型,同时我们往往在带标记数据上采用CE loss交叉熵损失,在无标记数据上采用MSE loss。有人发现,不同的数据增强会给模型带来推理速度、时间开销以及batch_size受限等方面问题,因此提出Temporal Ensembling(链接2)带时序的方法。

具体而言将当前时间步和上一时间步得到了未标记数据增强后的数据,分别看成经过两次不同数据增强的数据,然后进行一致性正则化,同时添加参数来控制当前和上一次数据的所占比例,例如:f(x)= f(pre(x)*0.3+next(x)*0.7)。具体如下图所示:

图2:Temporal Ensembling流程(链接2)

经过上述前置知识后,接下来我们通过Mixup→MixMatch→ReMixMatch→FixMatch→FlexMatch顺序论文来看看半监督学习方法。

MixupMixup论文(链接3)主要提出了一种随机融合不同类别数据的策略方法,后来被广泛应用于半监督学习领域来重构建数据集。具体而言xi,yi,xj,yj分别是从数据集中随机挑选的两对数据,然后通过如下方式进行组合,产生新的数据(x,y):

图3:Mixup实现方式(链接3)

其中Beta(a,a)代表贝塔分布,公式具体如下所示:

图4:Beta分布公式(链接3)

MixMatch:MixMatch论文(链接4)通过“集百家之长”,集合多种步骤运用于半监督学习,并且取得不错效果。主要步骤如下:

1)对于一个batch内的数据,针对有标签数据进行一次数据增强,针对无标签数据进行K次数据增强;

2)对K次数据增强后的无标签数据用模型进行预测,然后对所有的预测结果取平均并锐化处理,然后将其作为无标签数据的伪标签;

3)混合打乱所有数据并进行Mixup操作,我们将所有数据记为W,将有标签数据记为X,无标签数据记为U,那么融合后的有标签数据为X’= Mixup(W,X),U’= Mixup(W,U);

4)针对X’进行CE loss计算,针对U’进行MSE loss计算。如下图所示:

图5:MixMatch操作(链接4)

这里我们对MixMatch进行一个小总结:

1)MixMatch采用了一致性正则化方法,在数据增强阶段对图像进行随机翻转于裁剪;

2)MixMatch采用了最小化熵的思想,对预测结果进行锐化处理,从而最小化无标签数据的分类熵;

3)MixMatch在训练阶段还采用了正则化手段进行优化;

4)MixMatch结合Mixup方法对数据进行重构建。

ReMixMatchReMixMatch(链接5)主要是针对MixMatch论文的改进工作。主要改进点在于:

1)由于MixMatch的标签猜测可能存在噪声和不一致的情况(初期模型预测率较低)基于此,作者提出利用有标签数据的标签分布,对无标签数据的预测结果进行对齐,称为Distribution Alignment。如下图左图所示,蓝色是对当前无标签数据的标签预测结果q,绿色是一个运行平均版本(average之后)的无标签数据预测结果p',黄色是有标签数据的标签分布p,对齐之后的标签预测为:q' = Normalize(q*(p/p'));

2)作者提出一个假设:对样本进行简单增强(比如翻转和裁切)之后的预测结果,要比多次复杂变换更加可靠和稳定,称为Augmentation Anchor。因此,对于同一张图片,首先进行弱增强,得到预测结果q',然后对同一张图片进行复杂的强增强。弱增强和强增强共同使用一个标签猜测q'进行Mixup和模型训练,即使用弱增强得到的预测结果作为强增强数据的标签来进行训练。

图6:ReMixMatch示意图(链接5)

FixMatch一种对现有SSL方法具有显著简化的算法(链接6)。首先,FixMatch对无标签数据进行弱增强然后预测结果,将预测结果作为该无标签数据的伪标签,将数据本身进行强增强作为数据。

进一步而言,无标签数据弱增强后的预测结果只有其置信度超过某个阈值才会被保留。FixMatch损失分为针对有标签数据的交叉熵损失L_s和针对无标签数据的交叉熵损失L_u,总损失为loss=L_s+参数*L_u。通过消融实验得知:阈值最好选择为0.95。具体FixMatch算法如下图所示。

图7:FixMatch算法(链接6)

FlexMatch这是针对上文所述的FixMatch算法的改进版本(链接7),由于FixMatch算法一直使用一个固定阈值,这会带来一些缺点:

1)模型从不同的类中进行训练比较困难:因为固定阈值的存在,使得容易被模型识别的类(简单类)具有较高的置信度。而难以被模型识别的类(困难类)具有较低的置信度,因此再一个批次数据中,可能会有较多的简单类参与到训练,而困难类却很少,这会导致模型训练有难度;

2)模型训练初期,大多数无标签样本都是较低的置信度,如果恰好有超过阈值的预测结果也很有可能是错误的预测或噪声数据在进行干扰,这有可能导致模型朝着错误的方向收敛(收敛缓慢)的问题,同时训练初期大量的无标签数据得不到训练(低利用率)问题。

因此FlexMatch提出一个自适应的阈值想法,即每个类的阈值会随着训练时间的变化而变化,一个很自然的想法就是用模型在当前时间步的精确度来调整该类的阈值,但是由于缺少验证集数据以及计算效率的昂贵性,这无疑是困难的。

但是作者又提出一个假设:利用当前时间步无标签数据中某个类别C被选取数据(大于阈值)的数量来近似精确度,然后通过归一化方法最终得到比率,然后乘上初始阈值就是当前时间步下类别C的阈值了。具体如下图所示,作者还使用了一系列trick来增强FlexMatch算法,例如:对得到比率进行进一步的非线性处理等。

图8:FlexMatch算法(链接7)

 SSL算法小总结

综上其实MixMatch思想就是对数据弱增广+平均和锐化处理,再结合Mixup重新融合数据集;ReMixMatch则是模型将强增广分布与弱增广对齐后的分布拉近相对之间的距离;FixMatch则是模型对弱增广与强增广之间分布直接拉近距离,类似于对比学习,个人感觉这样其实就是把无标签数据约束到一个个簇里面,模型更好识别。

在联邦学习场景下,数据不足/非独立同分布情况下,加入半监督学习又会碰撞出怎样的火花呢?后期我们将从联邦半监督学习方面进行概述,欢迎关注。

参考链接

[1] https://github.com/yassouali/awesome-semi-supervised-learning

[2] Laine S, Aila T. Temporal ensembling for semi-supervised learning[J]. arXiv preprint arXiv:1610.02242, 2016.

[3] Zhang H, Cisse M, Dauphin Y N, et al. mixup: Beyond empirical risk minimization[J]. arXiv preprint arXiv:1710.09412, 2017.

[4] Berthelot, David, et al. "Mixmatch: A holistic approach to semi-supervised learning." Advances in Neural Information Processing Systems 32 (2019).

[5] Berthelot, David, et al. "Remixmatch: Semi-supervised learning with distribution alignment and augmentation anchoring." arXiv preprint arXiv:1911.09785 (2019).

[6] Sohn K, Berthelot D, Carlini N, et al. Fixmatch: Simplifying semi-supervised learning with consistency and confidence[J]. Advances in Neural Information Processing Systems, 2020, 33: 596-608.

[7] Zhang B, Wang Y, Hou W, et al. Flexmatch: Boosting semi-supervised learning with curriculum pseudo labeling[J]. Advances in Neural Information Processing Systems, 2021, 34.

END

    转藏 分享 献花(0

    0条评论

    发表

    请遵守用户 评论公约

    类似文章 更多