图像分割Unet算法及其Pytorch实现

2024-01-01 09:47:39

简介

UNet是一种用于图像分割的神经网络,由于这个算法前后两个部分在处理上比较对称,类似一个U形,如下图所示,故称之为Unet,论文链接:U-Net: Convolutional Networks for Biomedical Image Segmentation,全文仅8页。

在这里插入图片描述

从此图可以看出,左边的基础操作是两次 3 × 3 3\times3 3×3卷积后池化,连续4次,图像从 572 × 572 572\times572 572×572变成 32 × 32 32\times32 32×32。右侧则调转过来,以两次 3 × 3 3\times3 3×3卷积核一个 2 × 2 2\times2 2×2上采样卷积作为一组,再来四次,最后恢复成 388 × 388 388\times388 388×388的图像。

实现

整理一下上图,其计算顺序依次是

  1. 3 × 3 3\times3 3×3卷积-> 3 × 3 3\times3 3×3卷积-> 2 × 2 2\times2 2×2池化
  2. 3 × 3 3\times3 3×3卷积-> 3 × 3 3\times3 3×3卷积-> 2 × 2 2\times2 2×2池化
  3. 3 × 3 3\times3 3×3卷积-> 3 × 3 3\times3 3×3卷积-> 2 × 2 2\times2 2×2池化
  4. 3 × 3 3\times3 3×3卷积-> 3 × 3 3\times3 3×3卷积-> 2 × 2 2\times2 2×2池化
  5. 3 × 3 3\times3 3×3卷积-> 3 × 3 3\times3 3×3卷积-> 2 × 2 2\times2 2×2上采样,拼接4的结果
  6. 3 × 3 3\times3 3×3卷积-> 3 × 3 3\times3 3×3卷积-> 2 × 2 2\times2 2×2上采样,拼接3的结果
  7. 3 × 3 3\times3 3×3卷积-> 3 × 3 3\times3 3×3卷积-> 2 × 2 2\times2 2×2上采样,拼接2的结果
  8. 3 × 3 3\times3 3×3卷积-> 3 × 3 3\times3 3×3卷积-> 2 × 2 2\times2 2×2上采样,拼接1的结果
  9. 3 × 3 3\times3 3×3卷积-> 3 × 3 3\times3 3×3卷积-> 1 × 1 1\times1 1×1卷积

由于两次 3 × 3 3\times3 3×3卷积累计出现多次,故而先将其封装成类,便于后续调用

import torch
import torch.nn as nn
import torch.nn.functional as F

class DoubleConv(nn.Module):
    def __init__(self, inSize, outSize):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(inSize, outSize, kernel_size=3, padding=1),
            nn.BatchNorm2d(outSize),
            nn.ReLU(inplace=True),
            nn.Conv2d(outSize, outSize, kernel_size=3, padding=1),
            nn.BatchNorm2d(outSize),
            nn.ReLU(inplace=True)
        )
    def forward(self, x):
        return self.conv(x)

然后分别实现其降采样、上采样以及最终的输出过程,其中降采样没什么好说的,就是两次卷积一次池化,最终输出的 1 × 1 1\times1 1×1卷积当然就更简单了,二者一并实现如下

class Down(nn.Module):
    def __init__(self, inSize, outSize):
        super().__init__()
        self.conv = nn.Sequential(
            nn.MaxPool2d(2),
            DoubleConv(inSize, outSize))

    def forward(self, x):
        return self.conv(x)

class OutConv(nn.Module):
    def __init__(self, inSize, outSize):
        super(OutConv, self).__init__()
        self.conv = nn.Conv2d(inSize, outSize, 1)

    def forward(self, x):
        return self.conv(x)

上采样过程相对来说复杂一点,多了一个拼接操作,故而其forward函数中,除了需要输入被卷积的数据之外,还要输入U形中,与之对应的那部分计算结果

class Up(nn.Module):
    def __init__(self, inSize, outSize):
        super().__init__()

        self.up = nn.UpsamplingBilinear2d(scale_factor=2)
        self.conv = DoubleConv(inSize, outSize)

    def forward(self, x1, x2):
        x1 = self.up(x1)
        x = torch.cat([x2, x1], dim=1)
        return self.conv(x)

最后,将这几个组分拼接成一个UNet

class UNet(nn.Module):
    def __init__(self, nChannel, nClass):
        super(UNet, self).__init__()
        self.inc = DoubleConv(nChannel, 64)
        self.down1 = Down(64, 128)
        self.down2 = Down(128, 256)
        self.down3 = Down(256, 512)
        self.down4 = Down(512, 512)
        self.up1 = Up(1024, 256)
        self.up2 = Up(512, 128)
        self.up3 = Up(256, 64)
        self.up4 = Up(128, 64)
        self.outc = OutConv(64, nClass)

    def forward(self, x):
        x1 = self.inc(x)
        x2 = self.down1(x1)
        x3 = self.down2(x2)
        x4 = self.down3(x3)
        x5 = self.down4(x4)
        x = self.up1(x5, x4)
        x = self.up2(x, x3)
        x = self.up3(x, x2)
        x = self.up4(x, x1)
        return self.outc(x)

数据集

在具体训练之前,需要准备数据集,其中图像存放在image文件夹中,标签存放在label文件夹中,同名的图像和标签文件一一对应。

from PIL import Image
import os
import numpy as np
from torch.utils.data import Dataset

class ImgData(Dataset):
    def __init__(self, data_path):
        self.path = data_path
        self.imgForder = os.path.join(data_path, "image")

    # 加载图像
    def loadImg(self, path):
        img = np.array(Image.open(path))
        return img.reshape(1, *img.shape)

    # 根据index读取图片
    def __getitem__(self, index):
        pImg = os.path.join(self.path, f"image\{index}.png")
        pLabel = os.path.join(self.path, f"label\{index}.png")
        image = self.loadImg(pImg)
        label = self.loadImg(pLabel)
        # 数据标签归一化
        if label.max() > 1:
            label = label / 255
        # 随机翻转图像,增加训练样本
        flipCode = np.random.randint(3)
        if flipCode!=0:
            image = np.flip(image, flipCode).copy()
            label = np.flip(label, flipCode).copy()
        return image, label

    def __len__(self):
        # 返回训练集大小
        return len(os.listdir(self.imgForder))

训练

接下来就是激动人心的训练过程了,UNet采用RMSprop优化算法和BCEWithLogits损失函数,训练函数如下

from torch.utils.data import DataLoader
from torch import optim
import torch.nn as nn

def train(net, device, path, epochs=40, bSize=1, lr=0.00001):
    igmData = ImgData(path)
    train_loader = DataLoader(igmData, bSize, shuffle=True)
    # 优化算法
    optimizer = optim.RMSprop(net.parameters(),
            lr=lr, weight_decay=1e-8, momentum=0.9)

    criterion = nn.BCEWithLogitsLoss()      # 损失函数
    bestLoss = float('inf')                # 最佳loss,初始化为无穷大

    # 训练epochs次
    for epoch in range(epochs):
        net.train()     # 训练模式
        for image, label in train_loader:
            optimizer.zero_grad()
            # 将数据拷贝到device中
            image = image.to(device=device, dtype=torch.float32)
            label = label.to(device=device, dtype=torch.float32)

            pred = net(image)   # 使用网络参数,输出预测结果
            loss = criterion(pred, label)   # 计算损失
            # 保存loss最小的网络参数
            if loss < bestLoss:
                bestLoss = loss
                torch.save(net.state_dict(), 'best_model.pth')

            loss.backward() # 更新参数
            optimizer.step()

        print(epoch, 'Loss/train', loss.item())

接下来调用训练函数,经过40次训练之后,得到51MB的best_model.pth模型文件,此即最佳测试结果

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
net = UNet(1, 1)
net.to(device=device)
path = "train/"
train(net, device, path)

预测

所谓预测,无非是重新做一次训练,而且不及损失,只需保存被神经网络处理之后的结果即可,下面是预测一张图像的函数,其输入net即为我们训练好的网络,device为设备。

def predictOne(net, device, pRead, pSave):
    img = Image.open(pRead)
    img = np.array(img)
    img = img.reshape(1, 1, *img.shape)

    img = torch.from_numpy(img)
    img = img.to(device=device, dtype=torch.float32)

    pred = net(img)     # 预测
    pred[pred >= 0.5] = 255
    pred[pred < 0.5] = 0

    pred = np.array(pred.data.cpu()[0])[0]
    img = Image.fromarray(pred.astype(np.uint8))
    img.save(pSave)

最后,批量处理预测数据集,test和predict分别是存放测试文件和预测图像的文件夹。

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
net = UNet(1, 1)
net.to(device=device)
net.load_state_dict(torch.load('best_model.pth', map_location=device))

net.eval()      # 测试模式
fs = os.listdir('test')
for f in fs:
    pRead = os.path.join('test', f)
    pSave = os.path.join("predict",f)
    predictOne(net, device, pRead, pSave)

预测结果如下,左侧为图像,右侧为标签。

在这里插入图片描述

文章来源:https://blog.csdn.net/m0_37816922/article/details/135130464
本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。