【迁移学习论文四】Multi-Adversarial Domain Adaptation论文原理及复现工作
Multi-Adversarial Domain Adaptation 多对抗域适应
前言
- 好久没有更新了,所以这周开始记录下来,也好督促自己。
- 记录本人预备研究生阶段相关迁移学习论文的原理阐述以及复现工作。
问题
跨域混淆或错误对齐
文章介绍
这篇文章于2018年发表在AAAI,作者是清华大学龙明盛老师的学生。作者提到,域适应存在两个技术挑战:
- 通过最大限度地匹配跨域数据分布的多模式结构来增强正迁移;
- 通过防止跨域分布中模式的错误对齐来减轻负迁移。
在这些挑战的激励下,作者提出一种多对抗域自适应(MADA)方法,它捕获多模式结构,以支持基于多个域鉴别器的不同数据分布的细粒度对齐。与以前的方法相比,一个关键的改进是能够同时促进相关数据的正迁移和减轻不相关数据的负迁移。利用线性时间内的反向传播计算梯度,通过随机梯度下降实现自适应。
模型结构
标签分类器
- 从图中可以看到,源域样本首先经过 G f G_f Gf?层提取到相关特征,然后送入标签分类器 G y G_y Gy?得到分类标签
- 然后使用交叉熵损失计算分类损失。标签分类器几乎是所有模型必备的一项。这里不过多赘述。
局部域分类器
-
我们可以看到上面的蓝色线条,这就是局部域分类器。
- 首先 G f G_f Gf?提取到目标域和源域的特征,然后经过GRL后送入局部域分类器。这里产生一个问题,目标域样本的标签我们是不知道的,只知道源域数据标签,那我们如何知道哪个样本应该送入哪个局部域分类器呢?
- 让经过GRL之后的特征再与各个类别的标签分类概率相乘,然后送入相应类别的局部域分类器,在这个类别的分类概率越高就意味着这个局部域分类器对你的关注度就应该越高。
-
对抗学习过程是一个双人博弈
- 第一个参与者是经过训练的域鉴别器 G d G_d Gd?,用于区分源域和目标域
- 第二个参与者是同时经过微调的特征提取器 G f G_f Gf?,用于混淆域鉴别器
损失函数
优点
本文提出的多对抗域自适应网络实现了细粒度自适应,每个数据点 x i x_i xi?仅由相关的域鉴别器根据其概率 y i y_i yi?进行匹配。这种细粒度的适应可能带来三个好处。
- 避免了将每个点只分配给一个域鉴别器的困难,避免了对目标域数据的不准确。
- 避免了负迁移,因为每个点只对齐到最相关的类,而不相关的类被概率过滤掉,不会包含在相应的域判别器中,从而避免了不同分布中判别结构的错误对齐。
- 用概率加权数据点训练多域鉴别器,自然学习到具有不同参数的多个域鉴别器;这些域具有不同参数的鉴别器促进每个实例的正迁移。
代码
# 前向传播方法
def forward(self, x):
# 计算 lambda(lbda)
lbda = self.get_lambda_p(self.get_p()) if self.mode == 'Train' else 0
# 提取特征
features = self.backbone(x) # 通过特征提取器获取特征
features = features.reshape(features.size(0), -1) # 重塑特征形状为二维
# 类别分类器得到类别预测结果
class_logits = self.class_classifier(features)
class_predictions = F.softmax(class_logits, dim=1) # 对类别 logits 进行 softmax 得到概率
# 对特征进行反转(领域自适应)
reverse_features = GRL.apply(features, lbda)
# 对每个类别使用独立的领域分类器进行域分类
domain_logits = []
for class_idx in range(self.num_classes):
weighted_reverse_features = class_predictions[:, class_idx].unsqueeze(1) * reverse_features
# 域分类器对加权的反转特征进行域分类
domain_logits.append(
self.domain_classifiers[class_idx](weighted_reverse_features).cuda()
)
return class_logits, domain_logits # 返回类别 logits 和域 logits
-
前向传播(forward):
- 输入数据 x 经过特征提取器(self.backbone)得到特征表示。
- 特征表示经过类别分类器得到类别的预测结果(class_logits)。
- 特征表示经过领域分类器,分别针对每个类别(num_classes)进行域分类。
-
多个域分类器: 对每个类别都有一个独立的领域分类器,以便更好地适应不同类别在不同域中的分布情况。
总结
本文提出了一种新的多对抗域自适应方法来实现有效的深度迁移学习。与以往的领域对抗自适应方法只匹配域间的特征分布而不利用复杂的多模结构不同,该方法进一步利用判别结构,在多对抗自适应框架中实现细粒度分布对齐,同时促进正迁移和规避负迁移。
本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。 如若内容造成侵权/违法违规/事实不符,请联系我的编程经验分享网邮箱:veading@qq.com进行投诉反馈,一经查实,立即删除!