二分类任务、多分类任务、多标签分类任务、回归任务的代码实现上的区别和联系
2024-01-01 15:45:28
区别和联系:?
在PyTorch中,不同类型任务的主要区别在于输出层的设计和损失函数的选择:
- 二分类任务:(一个图片被分类为猫或者狗,被分类成2个分类中的一种)
- 实现方式1:使用单个输出节点和
Sigmoid
激活,损失函数通常是BCELoss
。 - 实现方式2:使用两个输出节点和
Softmax
激活函数,通常使用CrossEntropyLoss
。 - 多分类任务:使用与类别数相等的输出节点和
Softmax
激活,损失函数是CrossEntropyLoss
。(一个图片被分类为猫、狗、老鼠、青蛙....,可能被分类成多个分类中的一种) - 多标签分类任务:使用与标签数相等的输出节点和
Sigmoid
激活,损失函数是BCELoss
或BCEWithLogitsLoss
。(一个图片既包含猫,还包含狗,它同时含有多个标签) - 回归任务:单个输出节点,无激活函数,损失函数通常是
MSELoss
。
备注:其中“?二分类的实现方式2 ”和“ 多分类任务 ”实现方式是一样的
代码实现上:
1、二分类任务
实现方式1:
使用单个输出节点和Sigmoid
激活,损失函数通常是BCELoss
。
class BinaryClassificationModel(nn.Module):
def __init__(self):
super(BinaryClassificationModel, self).__init__()
self.layer = nn.Linear(10, 1)
def forward(self, x):
return torch.sigmoid(self.layer(x))
criterion = nn.BCELoss()
?实现方式2:
使用两个输出节点和Softmax
激活函数,通常使用CrossEntropyLoss
。
class BinaryModelSoftmax(nn.Module):
def __init__(self):
super(BinaryModelSoftmax, self).__init__()
self.layer = nn.Linear(10, 2)
def forward(self, x):
return torch.softmax(self.layer(x), dim=1)
criterion_softmax = nn.CrossEntropyLoss()
2、多分类任务
多分类任务:使用与类别数相等的输出节点和Softmax
激活,损失函数是CrossEntropyLoss
。
class MulticlassClassificationModel(nn.Module):
def __init__(self):
super(MulticlassClassificationModel, self).__init__()
self.layer = nn.Linear(10, 3)
def forward(self, x):
return torch.softmax(self.layer(x), dim=1)
criterion = nn.CrossEntropyLoss()
3、多标签分类任务
多标签分类任务:使用与标签数相等的输出节点和Sigmoid
激活,损失函数是BCELoss
或BCEWithLogitsLoss
。
class MultiLabelClassificationModel(nn.Module):
def __init__(self):
super(MultiLabelClassificationModel, self).__init__()
self.layer = nn.Linear(10, 3)
def forward(self, x):
return torch.sigmoid(self.layer(x))
criterion = nn.BCEWithLogitsLoss()
4、回归任务
回归任务:单个输出节点,无激活函数,损失函数通常是MSELoss
。
class RegressionModel(nn.Module):
def __init__(self):
super(RegressionModel, self).__init__()
self.layer = nn.Linear(10, 1)
def forward(self, x):
return self.layer(x)
criterion = nn.MSELoss()
文章来源:https://blog.csdn.net/weixin_43135178/article/details/135324730
本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。 如若内容造成侵权/违法违规/事实不符,请联系我的编程经验分享网邮箱:veading@qq.com进行投诉反馈,一经查实,立即删除!
本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。 如若内容造成侵权/违法违规/事实不符,请联系我的编程经验分享网邮箱:veading@qq.com进行投诉反馈,一经查实,立即删除!