全卷积网络
全卷积网络
全卷积网络就是图像到图像的变换,一般用于语义分割,图像恢复啥的
我们使用Resnet18来进行,最后平均池化和全连接层我们不需要
import torch
import torchvision
from torch import nn
from torch.nn import functional as F
from d2l import torch as d2l
# 使用ImageNet数据集上预训练的Resnet-18来提取图像特征,Resnet-18的最后几层包括全局平均池化和全连接层,我们不需要
pretrained_net = torchvision.models.resnet18(pretrained=True)
list(pretrained_net.children())[-3:]
创建全卷积网络,抛掉最后两层
net = nn.Sequential(*list(pretrained_net.children())[:-2])
给定高320宽480, net前向传播为1/32
X = torch.rand(size=(1, 3, 320, 480))
net(X).shape
因此,我们需要利用转置卷积将shape转为原来的320*480
在此之前,我们需要将通道数改为类别数,这样每个通道相同位置上的点就代表其对应类别的概率
# 最后先用1*1卷积层将输出通道数换成类别数,因为这样子通道上相同位置的点就代表为这个类别的概率
# 然后再通过反卷积增加高和宽的size
# 如果步幅为s,填充为s/2,卷积核size为2s*2s,那么输入的高和宽为放大s倍
num_classes = 21
net.add_module('final_conv', nn.Conv2d(512, num_classes, kernel_size=1))
net.add_module('transpose_conv', nn.ConvTranspose2d(num_classes, num_classes, kernel_size=64, padding=16, stride=32))
对于转置卷积的初始化,我们有特殊的做法
初始化卷积层,我们有时候需要将图像放大,用上采样,双线性插值是常用的上采样方法之一,它也常用于初始化卷积层
假设我们将输出图像坐标(x, y) 映射到(x’, y’)上,例如,根据输入和输出的尺寸来映射
然后我们需要在输入图像上找到与(x’, y’)最近的四个像素
输出图像坐标(x,y)上的像素依据输入图像上的zz这四个像素和(x’, y’)距离来计算
了解一下原理就行,代码层面先不深究
# 可以通过转置卷积层实现,我们不讨论算法的原理,就给出一个函数(我也没了解)
def bilinear_kernel(in_channels, out_channels, kernel_size):
factor = (kernel_size + 1) // 2
if kernel_size % 2 == 1:
center = factor - 1
else:
center = factor - 0.5
og = (torch.arange(kernel_size).reshape(-1, 1),
torch.arange(kernel_size).reshape(1, -1))
filt = (1 - torch.abs(og[0] - center) / factor) * \
(1 - torch.abs(og[1] - center) / factor)
weight = torch.zeros((in_channels, out_channels,
kernel_size, kernel_size))
weight[range(in_channels), range(out_channels), :, :] = filt
return weight
假设需要放大两倍,使用该初始化技术
# 假设我们需要把图像放大两倍,使用bilinear_kernel初始化
conv_trans = nn.ConvTranspose2d(3, 3, kernel_size=4, padding=1, stride=2, bias=False)
conv_trans.weight.data.copy_(bilinear_kernel(3, 3, 4))
用一张图片来感受一下
img = torchvision.transforms.ToTensor()(d2l.Image.open('catdog.png'))
X = img.unsqueeze(0)
Y = conv_trans(X)
out_img = Y[0].permute(1, 2, 0).detach()
d2l.set_figsize()
print(img.permute(1, 2, 0).shape)
d2l.plt.imshow(img.permute(1, 2, 0))
print(out_img.shape)
d2l.plt.imshow(out_img)
可以看到跟原图片是没什么变化的(原图片前面的章节中有提到)
初始化之前的全卷积网络的参数
# 我们就使用bilinear_kernel来初始化W
# 对于1*1卷积层,使用Xavier初始化参数
W = bilinear_kernel(num_classes, num_classes, 64)
net.transpose_conv.weight.data.copy_(W)
读取数据, 这里我修改了内置函数的num_worker,要不然会报错
# 读取数据集
batch_size, crop_size = 32, (320, 480)
train_iter, test_iter = d2l.load_data_voc(batch_size, crop_size) # 我修改了内置函数的num_worker = 0
定义Loss和训练,这里由于我电脑性能问题,没有训练成功,应该代码是可以运行的
def loss(inputs, targets): # loss和之前的没什么区别,只不过是基于通道维度的
return F.cross_entropy(inputs, targets, reduction='none').mean(1).mean(1)
num_epochs, lr, wd, devices = 5, 0.001, 1e-3, d2l.try_all_gpus()
trainer = torch.optim.SGD(net.parameters(), lr=lr, weight_decay=wd)
d2l.train_ch13(net, train_iter, test_iter, loss, trainer, num_epochs, devices)
最后是预测
def predict(img): # 先标准化
X = test_iter.dataset.normalize_image(img).unsqueeze(0)
pred = net(X.to(device[0])).argmax(dim=1)
return pred.reshape(pred.shape[1], pred.shape[2])
# 可视化类别的颜色
def label2image(pred):
coloarmap = torch.tensor(d2l.VOC_COLORMAP, device = devices[0])
X = pred.long()
return colormap[X, :]
测试集中的图像大小各异,由于模型使用了步幅为32的转置卷积层,因此当输入图像的高或宽无法被32整除时,
转置卷积层输出的高或宽会与输入图像的尺寸有偏差。 为了解决这个问题,我们可以在图像中截取多块高和宽为32的整数倍的矩形区域,
并分别对这些区域中的像素做前向传播。 请注意,这些区域的并集需要完整覆盖输入图像。
当一个像素被多个区域所覆盖时,它在不同区域前向传播中转置卷积层输出的平均值可以作为softmax运算的输入,从而预测类别
我们只读取几张较大的测试图像,并从图像的左上角开始截取形状为320*480的区域用于预测。
对于这些测试图像,我们逐一打印它们截取的区域,再打印预测结果,最后打印标注的类别
voc_dir = d2l.download_extract('voc2012', 'VOCdevkit/VOC2012')
test_images, test_labels = d2l.read_voc_images(voc_dir, False)
n, imgs = 4, []
for i in range(n):
crop_rect = (0, 0, 320, 480)
X = torchvision.transforms.functional.crop(test_images[i], *crop_rect)
pred = label2image(predict(X))
imgs += [X.permute(1, 2, 0), pred.cpu(), torchvision.transforms.functional.crop(test_images[i], *crop_rect).permute(1, 2, 0)]
d2l.show_images(imgs[::3] + imgs[1::3] + imgs[2::3], 3, n, scales=2) # 连续的是图像、预测、标签 这样加起来后,得到的是连续的图像,连续的预测,连续的标签
参考:https://zh.d2l.ai/chapter_computer-vision/fcn.html
本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。 如若内容造成侵权/违法违规/事实不符,请联系我的编程经验分享网邮箱:veading@qq.com进行投诉反馈,一经查实,立即删除!