损失函数与反向传播
2024-01-07 18:08:05
nn.L1Loss()
import torch
from torch import nn
inputs = torch.tensor([1, 2, 3], dtype=torch.float32)
targets = torch.tensor([1, 2, 5], dtype=torch.float32)
inputs = torch.reshape(inputs, (1, 1, 1, 3))
targets = torch.reshape(targets, (1, 1, 1, 3))
loss = nn.L1Loss()
result = loss(inputs, targets)
print(result) # tensor(0.6667)
reduction=‘sum’(默认是mean)
loss = nn.L1Loss(reduction='sum')
MSE
loss_mse = nn.MSELoss()
result_mse = loss_mse(inputs, targets)
CrossEntropyLoss_交叉熵损失
x = torch.tensor([0.1, 0.2, 0.3])
y = torch.tensor([1])
x = torch.reshape(x, (1,3))
loss_cross = nn.CrossEntropyLoss()
result_cross = loss_cross(x, y)
print(result_cross)
使用之前的网络
import torchvision.datasets
from torch import nn
from torch.utils.data import DataLoader
dataset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True,
transform=torchvision.transforms.ToTensor())
dataloader = DataLoader(dataset, batch_size=64, shuffle=True)
class Tudui(nn.Module):
def __init__(self):
super(Tudui, self).__init__()
self.model1 = nn.Sequential(
nn.Conv2d(in_channels=3, out_channels=32, kernel_size=5, stride=1, padding=2),
nn.MaxPool2d(kernel_size=2),
nn.Conv2d(in_channels=32, out_channels=32, kernel_size=5, stride=1, padding=2),
nn.MaxPool2d(2),
nn.Conv2d(32, 64, 5, padding=2),
nn.MaxPool2d(2),
nn.Flatten(),
nn.Linear(1024, 64),
nn.Linear(64, 10)
)
def forward(self, x):
x = self.model1(x)
return x
tudui = Tudui()
for data in dataloader:
imgs, targets = data
outputs = tudui(imgs)
print(outputs)
print(targets)
可以正确输出,且是分类问题,因此使用交叉熵
tudui = Tudui()
loss = nn.CrossEntropyLoss()
for data in dataloader:
imgs, targets = data
outputs = tudui(imgs)
result_loss = loss(outputs, targets)
print(result_loss)
1.计算实际输出和目标之间的差距
2.为我们更新输出提供一定的依据(反向传播)grad
查看梯度
执行反向传播后
文章来源:https://blog.csdn.net/weixin_51788042/article/details/135440339
本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。 如若内容造成侵权/违法违规/事实不符,请联系我的编程经验分享网邮箱:veading@qq.com进行投诉反馈,一经查实,立即删除!
本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。 如若内容造成侵权/违法违规/事实不符,请联系我的编程经验分享网邮箱:veading@qq.com进行投诉反馈,一经查实,立即删除!