MixMatch A Holistic Approach to Semi-Supervised Learning
MixMatch: A Holistic Approach to Semi-Supervised Learning
1. 摘要
事实证明,半监督学习是一种强大的范式,可以利用未标记的数据来减轻对大型标记数据集的依赖。在这项工作中,我们统一了目前半监督学习的主流方法,产生了一种新的算法,MixMatch,它为数据增强的未标记的例子猜测低熵标签,并使用MixUp混合标记和未标记的数据。MixMatch在许多数据集和标记的数据量上都获得了最先进的结果,而且幅度很大。例如,在有250个标签的CIFAR-10上,我们将错误率降低了4倍(从38%到11%),在STL-10上降低了2倍。我们还展示了MixMatch是如何帮助实现差异化隐私的显著更好的准确性-隐私权衡的。最后,我们进行了一项消融研究,以区分MixMatch的哪些部分对其成功最重要。
1. 主要内容
2.1 主要贡献
本文介绍了MixMatch,一种SSL算法,它引入了一个单一的损失,优雅地统一了这些半监督学习的主流方法:
通过实验,我们表明MixMatch在所有标准图像基准上获得了最先进的结果,并将CIFAR-10的错误率降低了4倍
我们在一项消融研究中进一步表明,MixMatch大于其各部分的总和
我们表明,MixMatch对不同的隐私学习很有用,使PATE框架中的学生获得新的最先进结果,同时加强隐私保证和准确性。
2.2 方法
MixMatch中使用的标签猜测过程示意图。随机数据增强被应用于一个未标记的图像K次,每个增强的图像被送入分类器。然后,通过调整分布的温度,对这K个预测的平均值进行 “锐化”。完整描述见下算法。
给定一batch带有onehot标签(代表L个可能标签中的一个)的有标签的例子X和一批同等大小的无标签的例子U,MixMatch产生一批经过处理的有标签的例子X′和一批带有 “猜测的 “标签的无标签的例子U′。然后,U′和X′被用于计算单独的有标签和无标签的损失项。更正式地说,半监督学习的综合损失L被定义为
其中H(p, q)是分布p和q之间的交叉熵,而T、K、α和λU是下面描述的超参数。
2.2.1 数据增强
在许多SSL方法中,我们在有标签和无标签的数据上都使用数据增强。对于标签数据X批次中的每个xb,我们生成一个转换版本ˆ xb = Augment(xb) (算法1,第3行)。对于无标签数据U中的每个ub,我们生成K个增强版本ˆub,k = Augment(ub), k∈ (1, . . , K)(算法1,第5行)。我们使用这些单独的增强值为每个ub生成一个 “猜测的标签 “qb,具体过程我们将在下面的小节中描述。
2.2.2 标签猜测
对于U中的每个未标记的例子,MixMatch使用模型的预测结果对该例子的标签产生一个 “猜测”。这个猜测后来被用于无监督的损失项。为此,我们通过以下方式计算模型在ub的所有K个增量中预测的类别分布的平均值
锐化:
在生成标签猜测时,我们在半监督学习中熵最小化的成功启发下,执行了一个额外的步骤。鉴于对增量的平均预测 ̄ qb,我们应用一个锐化函数来减少标签分布的熵。在实践中,对于锐化函数,我们使用调整该分类分布的 “温度 “的常见方法,它被定义为以下操作
其中p是一些输入的分类分布(特别是在MixMatch中,p是对增 ̄ qb的平均类别预测,如算法1第8行所示),T是一个超参数。当T→0时,Sharpen(p, T)的输出将接近Dirac(”onehot”)分布。由于我们以后将使用qb = Sharpen( ̄ qb, T )作为模型预测的目标,以增加ub,降低温度可以鼓励模型产生较低熵的预测。
2.2.3 Mixup
我们使用MixUp进行半监督学习,与过去SSL的工作不同,我们将有标签的例子和无标签的例子与标签猜测混合在一起。为了与我们单独的损失项兼容,我们定义了一个稍加修改的MixUp版本。对于一对有相应标签概率的两个例子(x1,p1),(x2,p2),我们通过以下方式计算(x′,p′):
其中α是一个超参数。考虑到有标签和无标签的例子在同一批次中被串联起来,我们需要保留批次的顺序来适当地计算各个损失成分。这可以通过上述第二个公式来实现,该公式确保x′比x2更接近x1。为了应用MixUp,我们首先收集所有带标签的增强实例和所有带猜测标签的无标签实例,并将其放入
(算法1,第10-11行)。然后,我们将这些集合合并,并将结果洗牌,形成W,作为MixUp的数据源(算法1,第12行)。对于X中的每一个例子-标签对,我们计算MixUp( ˆ Xi, Wi),并将结果添加到X ′集合中(算法1,第13行)。对于i∈(1, . . , | ˆ U|),我们计算U ′ i = MixUp( ˆ Ui, Wi+| ˆ X |),有意使用在构建X ′时没有使用的W的剩余部分(算法1,第14行)。总而言之,MixMatch将X转化为X′,这是一个经过数据增强和MixUp(可能与未标记的例子混合)的标记例子的集合。同样地,U也被转化为U′,这是一个对每个无标签的例子进行多次增强的集合,并有相应的标签猜测。
2.2.4 损失函数
考虑到我们处理过的批次X′和U′,我们使用公式(3)到(5)中所示的标准半监督损失。公式(5)结合了标签和来自X′的模型预测之间的典型交叉熵损失和来自U′的预测和猜测的标签的平方L2损失。我们在公式(4)中使用这种L2损失(多类Brier得分),因为与交叉熵不同,它是有界的,对不正确的预测不太敏感。由于这个原因,它经常被用作SSL中的无标签数据损失,以及预测不确定性的测量。我们不通过计算猜测的标签来传播梯度,这是标准的做法 。
2.2.5 超参数
由于MixMatch结合了多种利用无标签数据的机制,它引入了各种超参数–特别是锐化温度T、无标签增强的数量K、MixUp中Beta的α参数和无监督损失权重λU。 在实践中,具有许多超参数的半监督学习方法可能会有问题,因为交叉验证在小的验证集中是很困难的。然而,我们在实践中发现,MixMatch的大多数超参数可以固定,不需要在每个实验或每个数据集的基础上进行调整。具体来说,对于所有的实验,我们设定T=0.5,K=2。此外,我们只在每个数据集的基础上改变α和λU;我们发现α=0.75和λU=100是调整的良好起点。在所有的实验中,我们按照通常的做法,在训练的前16,000步中线性地将λU提升到其最大值。
2.3 实验
首先,我们在四个标准基准数据集上评估MixMatch的有效性。CIFAR-10和CIFAR-100,SVHN,以及STL-10。在前三个数据集上评估半监督学习的标准做法是将大部分数据集作为未标记的数据,并使用一小部分作为标记的数据。STL-10是一个专门为SSL设计的数据集,有5,000张标注的图像和100,000张未标注的图像,这些图像的分布与标注的数据略有不同。
在CIFAR-10上,MixMatch与基线方法在不同数量的标签下的错误率比较。”监督 “指的是用所有50000个训练实例进行训练,没有未标记的数据。在有250个标签的情况下,MixMatch的错误率与有4000个标签的次佳方法的性能相当。
3. 启发
本文的核心思想是mix,即混合有标签数据、无标签数据(数据增强),本文提到其中的增强、混合、标签猜测等部分,融合了许多SSL方法中的思想和特点,最后也取得了很好的效果。但是我在自己的实验中如果混合很多不同的方法、tricks却不一定能够达到更好的效果,甚至有时候会有相反的效果,所以我认为需要对方法的特点和数据、任务的特点有一个明确的认识,才能够合理地结合它们,达到像本文一样的融合更佳的效果。而由于现在认识不足,常常觉得怎么组合都能够说得通,很多时候是实验结果得出结论后,再由结论去反推原因,逻辑上还是不够流畅。