【论文笔记】Distilling the Knowledge in a Neural Network
Abstract
几乎任何机器学习算法性能提升的一个非常简单的方法是在相同数据上训练多个不同的模型,然后对它们的预测结果进行平均。
不幸的是,使用整个模型集合进行预测繁琐,可能会因为计算成本过高而难以部署给大量用户,尤其是如果各个模型是庞大的神经网络时。
研究表明,可以将集合中的知识压缩成一个单一模型,这样更容易部署,而我们则进一步使用不同的压缩技术发展了这种方法。
本文在MNIST数据集上取得了令人惊讶的结果,并展示了通过将多个模型的知识融合到一个单一模型中,能够显著提升一个广泛使用的商业系统的声学模型。
本文还引入了一个新型的集成模型,包括一个或多个全模型和许多专业模型,这些模型学习区分全模型容易混淆的细粒度类别。
与专家混合模型不同的是,这些专业模型可以快速并行训练。
1 Introduction
在大规模机器学习中,通常在训练阶段和部署阶段使用非常相似的模型,尽管它们有着非常不同的需求:对于语音和物体识别等任务,训练必须从非常大、高度冗余的数据集中提取结构,但它不需要实时操作,并且可以利用大量的计算资源。然而,部署到大量用户中则对延迟和计算资源有着更严格的要求。
如果能够从数据中提取结构更容易,我们应该愿意训练非常庞大的模型。
一旦庞大的模型训练完成,我们可以使用一种称为“蒸馏”的不同训练方式,将庞大模型的知识转移至更适合部署的小型模型。
一个概念上的障碍可能阻止了对这种非常有前景的方法进行更多的研究,这是因为我们倾向于将训练模型中的知识与学习到的参数数值等同起来,这使得我们很难看到如何改变模型的形式却保持相同的知识。
以往的观念中将训练模型的知识与学习到的参数等同,这很难看到如何改变模型的形式但保持相同的知识。
对知识更抽象的观点是:它是从输入向量到输出向量的学习映射,不受特定实例化的限制。
对于复杂的模型,想要区分大量类别,一个常规的训练是最大化正确答案的平均对数概率,但学习的副作用是训练模型会为所有的错误答案分配概率,即使这些概率非常小,其中一些可能也比其他大很多。这些错误答案的相对概率是繁琐模型泛化能力的关键。
训练模型最好具有良好的泛化能力,但这需要有关正确泛化方式的信息。
如果复杂模型具有良好的泛化能力,是因为它是多个不同模型的平均值,则以相同方式训练的小模型通常在测试数据上会表现得更好。
转移泛化能力的一种明显办法是将复杂模型产生的类别概率作为小模型训练的软目标(soft target),转移阶段我们可以使用相同的训练集或单独的“转移集”(“transfer” set)。
当复杂模型是简单模型的大型集合时,可以使用各自预测分布的算术或几何平均值作为软目标。
软目标具有高熵时,每个训练案例提供的信息比硬目标多得多,在训练案例上的梯度变化要小得多,因此小模型通常可以在比原始庞大模型少得多的数据上进行训练,提高学习效率。
有时复杂模型以高置信度产生正确答案,错误答案的概率非常低,接近于0,因此对交叉熵损失函数的影响非常小。
过往研究中有用logit(最终的softmax输入)而不是softmax产生的概率训练小模型,得以绕过这个问题,并且最小化了复杂模型产生的logit和小模型产生的logit之间的平方距离。
更通用的解决方案称为“蒸馏”(distillation),将最终softmax的温度提高,直到庞大模型产生适当软化(soft)的目标,然后在训练小模型以匹配这些目标时,使用相同的温度。
事实证明,匹配复杂模型的logit实际上是蒸馏的一种特殊情况。
2 Distillation
神经网络通常使用softmax输出层来产生类别概率,该层为每个类别计算的logit
z
i
z_i
zi?转换为概率
q
i
q_i
qi?,通过比较
z
i
z_i
zi?和其他的logits完成转换。
q
i
=
exp
?
(
z
i
/
T
)
Σ
j
exp
?
(
z
i
/
T
)
\begin{align} q_i=\frac{\exp(z_i/T)}{\Sigma_j \exp(z_i/T)}\tag{1} \end{align}
qi?=Σj?exp(zi?/T)exp(zi?/T)??(1)?
其中
T
T
T是温度,在Sec. 1中有介绍。
T
T
T越高,产生的softmax值越soft。
在最简单的蒸馏形式中,知识通过在转移集上对蒸馏模型进行训练传输,使用复杂模型在softmax中使用较高温度产生的软目标分布,针对转移集中的每个案例都生成一个软目标分布。在训练蒸馏模型时使用相同的高温度,但在蒸馏模型训练完之后,温度 T T T恢复为1。
当所有或部分transfer set的正确标签已知时,通过训练蒸馏模型产生正确标签,可以显著提高这种方法的效果。
一种方法是使用正确标签来修改软目标,但我们发现更好的方法是简单地使用两种不同目标函数的加权平均。
- 第一个目标函数是使用软目标的交叉熵,这个交叉熵是使用在蒸馏模型的 softmax 中产生软目标时与从庞大模型生成软目标时使用的相同高温度计算的。
- 第二个目标函数是使用正确标签的交叉熵。这是使用蒸馏模型的 softmax 中完全相同的 logit 进行计算的,但是温度设置为 1。
通常通过在第二个目标函数上使用相当低的权重来获得最佳结果。
由于软目标产生的梯度大小为 1 T 2 \frac{1}{T^2} T21? ,因此在使用硬目标和软目标时将其乘以 T 2 T^2 T2非常重要。这确保了更改了用于蒸馏的温度时,硬目标和软目标的相对贡献大致保持不变。
2.1 Matching logits is a special case of distillation
对于蒸馏模型的每个logit和
z
i
z_i
zi?,transfer set的每个案例都提供一个交叉熵梯度,
d
C
/
d
z
i
dC/dz_i
dC/dzi?。
如果复杂模型有logit
v
i
v_i
vi?,产生软目标概率
p
i
p_i
pi?,转移训练在温度
T
T
T下进行,这个梯度可以写作:
?
C
?
z
i
=
1
T
(
q
i
?
p
i
)
=
1
T
(
e
z
i
/
T
Σ
j
e
z
j
/
T
?
e
v
i
/
T
Σ
j
e
v
j
/
T
)
\begin{align} \frac{\partial C}{\partial z_i}=\frac{1}{T}(q_i-p_i)=\frac{1}{T}(\frac{e^{z_i/T}}{\Sigma_j e^{z_j/T}}-\frac{e^{v_i/T}}{\Sigma_j e^{v_j/T}})\tag{2} \end{align}
?zi??C?=T1?(qi??pi?)=T1?(Σj?ezj?/Tezi?/T??Σj?evj?/Tevi?/T?)?(2)?
证明:
C
=
C
i
+
∑
j
≠
i
C
j
?
C
?
z
i
=
?
(
C
i
+
∑
j
≠
i
C
j
)
?
z
i
=
?
C
i
?
z
i
+
?
(
∑
j
≠
i
C
j
)
?
z
i
\begin{aligned} C&=C_i+\sum_{j\neq i}C_j \\ \frac{\partial C}{\partial z_i}&=\frac{\partial (C_i+\sum_{j\neq i}C_j)}{\partial z_i} \\ &= \frac{\partial C_i}{\partial z_i}+\frac{\partial (\sum_{j\neq i}C_j)}{\partial z_i} \end{aligned}
C?zi??C??=Ci?+j=i∑?Cj?=?zi??(Ci?+∑j=i?Cj?)?=?zi??Ci??+?zi??(∑j=i?Cj?)??
分别求
?
C
i
?
z
i
,
?
(
Σ
j
≠
i
C
j
)
?
z
i
\frac{\partial C_i}{\partial z_i}, \frac{\partial(\Sigma_{j\neq i}C_j)}{\partial z_i}
?zi??Ci??,?zi??(Σj=i?Cj?)?:
?
C
i
?
z
i
=
?
(
?
p
i
log
?
q
i
)
?
z
i
=
?
p
i
?
log
?
exp
?
(
z
i
/
T
)
Σ
j
exp
?
(
z
j
/
T
)
?
z
i
=
?
p
i
?
(
(
z
i
/
T
)
?
log
?
(
Σ
j
exp
?
(
z
j
/
T
)
)
)
?
z
i
=
?
p
i
(
1
T
?
log
?
(
exp
?
(
z
i
/
T
)
+
c
)
?
z
i
)
=
?
p
i
(
1
T
?
A
)
\begin{aligned} \frac{\partial C_i}{\partial z_i}&= \frac{\partial(-p_i\log q_i)}{\partial z_i}\\ &=-p_i \frac{\partial\log \frac{\exp(z_i/T)}{\Sigma_j \exp(z_j/T)}}{\partial z_i}=-p_i\frac{\partial( (z_i/T)-\log(\Sigma_j \exp(z_j/T)))}{\partial z_i} \\ &=-p_i(\frac{1}{T}-\frac{\log(\exp(z_i/T)+c)}{\partial z_i}) \\ &=-p_i(\frac{1}{T}-A) \end{aligned}
?zi??Ci???=?zi??(?pi?logqi?)?=?pi??zi??logΣj?exp(zj?/T)exp(zi?/T)??=?pi??zi??((zi?/T)?log(Σj?exp(zj?/T)))?=?pi?(T1???zi?log(exp(zi?/T)+c)?)=?pi?(T1??A)?
其中
c
=
Σ
j
≠
i
exp
?
(
z
j
/
T
)
c=\Sigma_{j\neq i} \exp(z_j/T)
c=Σj=i?exp(zj?/T),对
z
i
z_i
zi?求偏导时所有项均为常数,
A = 1 exp ? ( z i / T ) + c × exp ? ( z i T ) × 1 T = 1 T × exp ? ( z i / T ) Σ j exp ? ( z j / T ) = 1 T × q i \begin{aligned} A&=\frac{1}{\exp(z_i/T)+c}\times \exp(\frac{z_i}{T})\times\frac{1}{T} \\ &=\frac{1}{T}\times \frac{\exp(z_i/T)}{\Sigma_j \exp(z_j/T)} \\ &=\frac{1}{T}\times q_i \end{aligned} A?=exp(zi?/T)+c1?×exp(Tzi??)×T1?=T1?×Σj?exp(zj?/T)exp(zi?/T)?=T1?×qi??
另一边:
?
(
∑
j
≠
i
C
j
)
?
z
i
=
?
?
(
Σ
j
≠
i
p
j
log
?
q
j
)
?
z
i
=
?
?
(
Σ
j
≠
i
p
j
log
?
exp
?
(
z
j
/
T
)
exp
?
(
z
i
/
T
)
+
c
)
?
z
i
=
?
?
Σ
j
≠
i
p
j
(
z
j
/
T
?
log
?
(
exp
?
(
z
i
/
T
)
+
c
)
)
?
z
i
=
?
Σ
j
≠
i
p
j
(
log
?
(
exp
?
(
z
i
/
T
)
+
c
)
)
?
z
i
=
?
log
?
(
exp
?
(
z
i
/
T
)
+
c
)
Σ
j
≠
i
p
j
?
z
i
=
(
1
?
p
i
)
×
?
log
?
(
exp
?
(
z
i
/
T
)
+
c
)
?
z
i
=
(
1
?
p
i
)
×
A
\begin{aligned} \frac{\partial (\sum_{j\neq i}C_j)}{\partial z_i}&=-\frac{\partial(\Sigma_{j\neq i}p_j\log q_j)}{\partial z_i} \\ &=-\frac{\partial(\Sigma_{j\neq i} p_j\log \frac{\exp(z_j/T)}{\exp(z_i/T)+c})}{\partial z_i} \\ &=-\frac{\partial\Sigma_{j\neq i} p_j(z_j/T-\log(\exp(z_i/T)+c))}{\partial z_i} \\ &=\frac{\partial\Sigma_{j\neq i} p_j(\log(\exp(z_i/T)+c))}{\partial z_i} \\ &=\frac{\partial \log(\exp(z_i/T)+c)\Sigma_{j\neq i} p_j}{\partial z_i} \\ &=(1-p_i)\times\frac{\partial \log(\exp(z_i/T)+c)}{\partial z_i} \\ &=(1-p_i)\times A \end{aligned}
?zi??(∑j=i?Cj?)??=??zi??(Σj=i?pj?logqj?)?=??zi??(Σj=i?pj?logexp(zi?/T)+cexp(zj?/T)?)?=??zi??Σj=i?pj?(zj?/T?log(exp(zi?/T)+c))?=?zi??Σj=i?pj?(log(exp(zi?/T)+c))?=?zi??log(exp(zi?/T)+c)Σj=i?pj??=(1?pi?)×?zi??log(exp(zi?/T)+c)?=(1?pi?)×A?
带入
?
C
?
z
i
\frac{\partial C}{\partial z_i}
?zi??C?,得到:
?
C
?
z
i
=
?
C
i
?
z
i
+
?
(
∑
j
≠
i
C
j
)
?
z
i
=
?
p
i
(
1
T
?
A
)
+
(
1
?
p
i
)
×
A
=
?
p
i
T
+
A
=
?
p
i
T
+
q
i
T
=
?
1
T
(
q
i
?
p
i
)
\begin{aligned} \frac{\partial C}{\partial z_i} &= \frac{\partial C_i}{\partial z_i}+\frac{\partial (\sum_{j\neq i}C_j)}{\partial z_i} \\ &= -p_i(\frac{1}{T}-A)+(1-p_i)\times A \\ &=-\frac{p_i}{T}+A \\ &=-\frac{p_i}{T}+\frac{q_i}{T} \\ &=-\frac{1}{T}(q_i-p_i) \end{aligned}
?zi??C??=?zi??Ci??+?zi??(∑j=i?Cj?)?=?pi?(T1??A)+(1?pi?)×A=?Tpi??+A=?Tpi??+Tqi??=?T1?(qi??pi?)?
证毕。
如果
T
T
T与logit的大小相比较高,则可以近似(放缩法):
?
C
?
z
i
≈
1
T
(
1
+
z
i
/
T
N
+
Σ
j
z
j
/
T
?
1
+
v
i
/
T
N
+
Σ
j
v
j
/
T
)
\begin{align} \frac{\partial C}{\partial z_i}\approx\frac{1}{T}(\frac{1+z_i/T}{N+\Sigma_j z_j/T}-\frac{ 1+v_i/T}{N+\Sigma_j v_j/T})\tag{3} \end{align}
?zi??C?≈T1?(N+Σj?zj?/T1+zi?/T??N+Σj?vj?/T1+vi?/T?)?(3)?
如果假设每个转移情况的logit都是0均值的,则
Σ
j
z
j
=
Σ
j
v
j
=
0
\Sigma_j z_j=\Sigma_j v_j=0
Σj?zj?=Σj?vj?=0,那么Eq.3可以简化为:
?
C
?
z
i
≈
1
N
T
2
(
z
i
?
v
i
)
\begin{align} \frac{\partial C}{\partial z_i}\approx\frac{1}{NT^2}(z_i-v_i)\tag{4} \end{align}
?zi??C?≈NT21?(zi??vi?)?(4)?
因此,高温极限下,若logit是0均值的,则蒸馏相当于最小化
1
/
2
(
z
i
?
v
i
)
2
1/2(z_i-v_i)^2
1/2(zi??vi?)2。
低温条件下,蒸馏对于与平均值相比更负的logits匹配关注较少,这可能是有利的,因为这些logits在训练复杂模型时几乎完全不受成本函数的约束,可能导致很多噪点。
另一方面,非常负的logits可能会带有关复杂模型获得知识的有用信息,这些影响中哪一个是主导,这是一个经验问题。
当蒸馏模型太小而无法捕获繁琐模型中的所有知识时,中间温度效果最好。这表明忽略大的负对数可能会有益。
3 Preliminary experiments on MNIST
为了探明蒸馏的效果,在60000个案例上训练了一个大型神经网络,它具有两个隐藏层,每个隐藏层包含1200个ReLU单元(rectified linear hidden units)。该网络使用dropout和权重约束进行了强烈正则化。Dropout可以被视为训练共享权重的指数级大模型集合的一种方法。此外,输入图像在任意方向上的扰动不超过2个像素。这个网络取得了67个测试错误。
一个更小的网络,有2个包含800个ReLU单元的隐藏层,没有正则化,取得了146个错误。
如果将这个较小的网络仅通过增大与大型网络在
T
=
20
T=20
T=20时产生的软目标匹配的额外任务进行正则化,取得了76个错误。
这表明,soft target可以将大量的知识转移到蒸馏模型当中,即使是经过各种变换后的训练数据中学到的知识,尽管transfer set一般不包含任何转换。
当蒸馏网络的两个隐藏层中每层有300个活更多ReLU单元时,所有高于8的温度都产生了相似的效果。但是每层ReLU数被大幅减少到30时,温度在2.5到4范围内的效果显著优于更高或更低的温度。
还尝试了在转移集中删除所有数字3的示例。对于蒸馏模型,3是一个从没见过的神秘数字。
蒸馏模型只有206个测试错误,其中133个出现在测试集中1010个数字3上。大多数错误是由于对于数字3类别的学习偏差太低导致的。
如果将这种偏差增加 3.5 倍(这会优化测试集的整体性能),蒸馏模型产生了 109 个错误,其中有 14 个是数字 3。因此,通过正确的偏差,尽管在训练过程中从未见过数字 3,蒸馏模型能够正确识别 98.6% 的测试集中的数字 3。如果转移集仅包含训练集中的数字 7 和 8,蒸馏模型产生了 47.3% 的测试错误,但是当数字 7 和 8 的偏差减少 7.6 以优化测试性能时,错误率降至 13.2%。
4 Experiments on speech recognition
(这部分对音频识别不很了解,需要以后知识丰富后重新学习)
这一节,调查了用于自动语音识别 (ASR) 的深度神经网络 (DNN) 声学模型集成的影响,展示了本文提出的蒸馏策略实现了将模型集成蒸馏为单一模型的预期效果,该单一模型的性能明显优于直接从相同训练数据中学习的相同大小的模型。
目前最先进的自动语音识别 (ASR) 系统使用深度神经网络 (DNN) 将从波形中导出的(短暂的)时间上下文特征映射到隐马尔可夫模型 (HMM) 离散状态的概率分布。更具体地说,DNN 在每个时间点上产生一个三音素状态集群的概率分布,然后解码器找到通过 HMM 状态的路径,这个路径在使用高概率状态和在语言模型下产生概率高的转录之间找到最佳平衡。
虽然可以(也希望)以训练DNN的方式使得解码器(因此也包括语言模型)考虑到所有可能路径,但通常训练DNN以通过(局部地)最小化网络所作预测与通过强制与每个观察值的真实状态序列的标签之间的交叉熵进行逐帧分类:
θ
=
arg
?
max
?
θ
′
P
(
h
t
∣
s
t
;
θ
′
)
\mathbf{\theta}=\arg\max_{\mathbf{\theta}^{'}} P(h_t|s_t; \theta^{'})
θ=argθ′max?P(ht?∣st?;θ′)
其中,
θ
\theta
θ是我们声学模型
P
P
P的参数,该模型将时间
t
t
t的声学观察
s
t
s_t
st?映射到一个概率
P
(
h
t
∣
s
t
;
θ
′
)
P(h_t|s_t; \theta^{'})
P(ht?∣st?;θ′),表示“正确”的 HMM 状态
h
t
h_t
ht?的概率,这个状态由与正确单词序列的强制对齐确定。该模型采用分布式随机梯度下降方法进行训练。
我们采用的架构包含8个隐藏层,每个隐藏层包含2560个修正线性单元,以及一个最终的softmax层,具有14000个标签(HMM目标 h t h_t ht?)。输入是26帧的40个Mel频率滤波器组成的系数,每帧间隔10毫秒,我们预测第21帧的HMM状态。总参数数量约为8500万。这是Android语音搜索使用的声学模型的略微过时版本,应被视为一个非常强大的基准。为训练DNN声学模型,我们使用了约2000小时的英语口语数据,产生了约7亿个训练样本。该系统在我们的开发集上实现了58.9%的帧精度和10.9%的词错误率(WER)。
表1:帧分类准确率和词错误率显示,蒸馏后的单一模型的表现大致与用于生成软目标的10个模型的平均预测结果相当。
5 Training ensembles of specialists on very big datasets
训练一个模型集合是利用并行计算的非常简单的方法,通常对集合在测试时需要太多计算资源的反对意见可以通过使用蒸馏来解决。然而,对于集合还存在另一个重要的反对意见:如果单个模型是大型神经网络,且数据集非常庞大,在训练时所需的计算量过大,尽管很容易并行化。
这一节举例说明这样一个数据集,并展示了如何学习专家模型,每个模型专注于不同的易混淆类别子集,从而减少学习集合所需的总计算量。专家模型专注于进行细粒度区分的主要问题是它们很容易出现过拟合,我们描述了如何通过使用软目标来预防这种过拟合。
5.1 The JFT dataset
JFT 是谷歌内部的数据集,包含1亿张带有15,000个标签的标记图像。在我们进行此项工作时,谷歌对 JFT 的基准模型是一种深度卷积神经网络,经过约六个月的训练,使用了大量核心的异步随机梯度下降。这次训练使用了两种类型的并行方式。
异步随机梯度下降(Asynchronous stochastic gradient descent):一种梯度下降的变种,用于训练深度神经网络。在 ASGD 中,多个工作进程可以独立地更新模型参数,而不需要等待其他进程。每个工作进程都会计算一部分数据的梯度,并且可以在没有全局同步的情况下更新模型参数。
这种方法通常用于分布式系统中,其中多个计算节点可以并行地处理不同的数据批次,并且不需要等待其他节点完成它们的计算。ASGD 可以提高训练效率,并允许更快地收敛到较好的模型。然而,由于异步更新,ASGD 可能导致一定程度上的参数不稳定性,因为不同的节点在不同的时间更新参数。因此,对于某些应用场景,同步的方法可能更可靠,因为它们确保所有节点在更新参数时是同步的。
第一种方式是,神经网络有许多副本运行在不同的核心集上,处理来自训练集的不同小批量数据。每个副本计算其当前小批量的平均梯度,并将该梯度发送到分片参数服务器,该服务器发送回参数的新值。 这些新值反映了参数服务器自上次向副本发送参数以来接收到的所有梯度。
第二种方式是,通过将神经元的不同子集分配到每个核心上,将每个副本分布在多个核心上。
集合训练是第三种可以包裹在前两种类型周围的并行方式,但前提是有更多的核心可用。等待数年来训练一组模型并不是一个选择,因此我们需要一种更快的方法来改进基准模型。
5.2 Specialist Models
当类别数量非常庞大时,让复杂模型成为一个包含一种在所有数据上训练的通用模型和许多“专家”模型的集成,每个专家模型都是在高度丰富于某个类别子集(比如不同类型的蘑菇)的数据上训练的。这种专家模型的softmax可以通过将其不关心的所有类别合并成一个单一的“垃圾桶”类别而大幅减小。
为了减少过拟合并分享学习较低级特征检测器的工作,每个专家模型都使用通用模型的权重进行初始化。然后,通过训练专家模型,其一半示例来自其特殊子集,另一半随机抽样自其余训练集。训练后,我们可以通过将“垃圾桶”类别的 logit 增加专家类别过采样比例的对数来校正训练集的偏差。
5.3 Assigning classes to specialists
为了得到专家模型的对象类别分组,我们决定专注于全网络经常混淆的类别。尽管我们可以计算混淆矩阵并将其用作查找这些类别的方法,但我们选择了一种更简单的方法,不需要使用真实标签来构建这些集群。
特别是,我们对通用模型预测的协方差矩阵应用了聚类算法,这样经常一起被预测的一组类别 S m S^m Sm 将被用作我们其中一个专家模型 m m m的目标。我们对协方差矩阵的列应用了K-means算法的在线版本,并获得了合理的聚类结果(如表2所示)。我们尝试了几种聚类算法,其产生了类似的结果。
表2:协方差矩阵聚类算法计算出的集群示例类别
5.4 Performing inference with ensembles of specialists
在研究专家模型被提炼后的效果之前,我们想要看看包含专家的整体模型表现如何。除了专家模型外,我们始终有一个通用模型,以便处理我们没有专家的类别,并决定要使用哪些专家。
对于输入图像x,我们通过两步进行最高置信度的分类:
步骤1:对于每个测试案例,根据通用模型找到最可能的
n
n
n个类别。我们将这组类别称为
k
k
k。在我们的实验中,我们使用了
n
=
1
n=1
n=1。
步骤2:然后,我们考虑所有专家模型
m
m
m,其中其易混淆类别
S
m
S^m
Sm 与
k
k
k 有非空交集,并将这组模型称为活动专家集
A
k
A_k
Ak?(请注意,这个集合可能为空)。然后,找到最小化全类别概率分布
q
q
q 的公式:
K
L
(
p
g
,
q
)
+
∑
m
∈
A
k
K
L
(
p
m
,
q
)
\begin{align} KL(\mathbf{p}^{\mathbf{g}},\mathbf{q})+\sum_{m∈A_k}KL(\mathbf{p}^m,\mathbf{q})\tag{5} \end{align}
KL(pg,q)+m∈Ak?∑?KL(pm,q)?(5)?
其中
K
L
KL
KL代表KL散度,
p
m
,
p
q
\mathbf{p}^m,\mathbf{p}^{\mathbf{q}}
pm,pq表示专家模型或通用模型的概率分布。分布
p
m
\mathbf{p}^m
pm是
m
m
m的所有专家类别加上一个单独的垃圾箱类别,因此计算其与完整
q
\mathbf{q}
q分布的KL散度时,要先对
q
\mathbf{q}
q分布赋予的所有概率求和,这些概率对应
m
m
m垃圾箱类别中的所有类别。
Eq.5没有一个通用的闭式解,尽管当所有模型为每个类别产生单个概率时,解决方案要么是算术平均值,要么是几何平均值,这取决于我们使用的是 K L ( p , q ) KL(\mathbf{p}, \mathbf{q}) KL(p,q)还是 K L ( q , p ) KL(\mathbf{q}, \mathbf{p}) KL(q,p)。我们对q进行参数化,即 q = s o f t m a x ( z ) ( T = 1 ) \mathbf{q} = softmax(z)(T = 1) q=softmax(z)(T=1),并使用梯度下降来优化等式5中的logits z。请注意,这种优化必须针对每个图像进行。
5.5 Results
从经过训练的baseline全网络开始,专家们的训练速度非常快(与 JFT 训练的几周相比,仅需几天)。而且,所有专家都完全独立地进行训练。表3显示了baseline和baseline与专家模型结合的绝对测试准确度。使用61个专家模型,整体测试准确度相对提高了4.4%。我们还报告了条件测试准确度,即只考虑属于专家类别的样本,并将我们的预测限制在该类别的子集中。
表3:JFT数据集上的分类top-1准确度
对于我们的JFT专家实验,我们训练了61个专家模型,每个模型有300个类别(加上垃圾桶类别)。因为专家的类别集不是不相交的,我们经常有多个专家涵盖特定的图像类别。表4显示了测试集示例的数量,使用专家时第1位置的正确示例数量的变化,以及按照涵盖该类别的专家数量分解的JFT数据集的top1准确度的相对百分比改进。我们鼓励一般趋势是,当我们有更多专家涵盖特定类别时,准确度提高更大,因为训练独立的专家模型非常容易并行化。
表4:JFT测试集中每个正确类别所涵盖的专家模型数量与top1准确度的改进。
本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。 如若内容造成侵权/违法违规/事实不符,请联系我的编程经验分享网邮箱:veading@qq.com进行投诉反馈,一经查实,立即删除!