33. CV练习: 验证码识别
Hi, 你好。我是茶桁。
上一节课,我给大家留了个作业,内容是对验证码进行识别。
咱们的再把练习内容重复看一下:
-
练习内容: 训练一个模型,对验证码中的字符进行分类识别,并最终完成验证码识别的任务。
-
数据集: 数据集内包含0-9以及A-Z一共36个字符,训练集中每个字符有50张图片,验证集中每个字符有10张图片,验证码数据集是由随机去除的4个字符图片拼接而成。
-
需要的相关知识:
- 数据读取
- 使用torch搭建、训练、验证模型
- 模型预测于图片切分
好,让我们来看看具体的,我们该怎么完成这个练习。
问题分析
首先,我们需要一步步的确定我们的问题。第一个问题,肯定是要先建立字符对照表,第二个问题,要定义一个datasets
和一个dataloader
。 第三个问题,是需要定义网络结构。 第四个问题,就是定义模型训练函数。 最后,就是验证我们的训练结果。
先从第一个问题分析,我们可以通过遍历字典,将每一对键值反转,并存储于新的字典中。我们可以按如下方式去做:
new_dict = {v:k for k, v in old_dict.items()}
那么第二个问题就简单了,在opencv-python中,可以使用image = cv2.medianBlur(image, kernel_size)
进行中值滤波。
第三个问题,在torch中,卷积于前连接层的定义方法可以是:
conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding)
fc = nn.Linear(in_features, out_features, bias)
具体的,还可以参看PyTorch的相关手册。 用这两个方法进行组合,就可以定义出一个卷积神经网络结构。
接下来第四个问题,要定义模型训练函数。那么torch框架的模型训练过程会包含清空梯度、前向传播、计算损失、计算梯度、更新权重等操作。
- 清空梯度:目的是消除step与step之间的干扰,即每次都只用一个batch的数据损失计算梯度并更新权重。一般可以放在最前或最后;
- 前向传播:使用一个batch的数据跑一边前向传播的过程,生成模型输出结果;
- 计算损失:使用定义好的损失函数、模型输出结果以及label计算单个batch的损失值;
- 计算梯度:根据损失值,计算模型所有权中在本次优化中所需的梯度值;
- 更新权重:使用计算好的梯度值,更新所有权重的值。
模拟一下单词流程代码,样例如下:
>>> optimizer.zero_grad() # 清空梯度(也可以放在最后一行)
>>> output = model(data) # 前向传播
>>> loss = loss_fn(output, target) # 计算损失
>>> optimizer.step() # 更新权重
好,步骤分析完了,接着咱们来进行实现,这次我们所用到的代码库如下,将其都引入进来:
import cv2
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms
import pickle
import matplotlib.pyplot as plt
from PIL import Image
然后我们来看一下数据集,上一课末尾我给过数据集了,有需要的自己去找一下。先将需要用到的数据路径都定义出来:
train_data_dir = 'data/train_data.bin'
val_data_dir = 'data/val_data.bin'
verification_code_dir = 'data/verification_code_data.bin'
这个数据集保存在二进制文件中,我们需要定义一个函数,读取二进制文件中的图片:
def load_file(file_name):
with open(file_name, mode='rb') as f:
result = pickle.load(f)
return result
来让我们查看一下数据集:
train_data = load_file(train_data_dir)
img_test = list()
for i in range(1,1800,50):
img_test.append(train_data[i][1])
plt.figure()
for i in range(1,37):
plt.subplot(6,6,i)
plt.imshow(img_test[i-1])
plt.xticks([])
plt.yticks([])
plt.show()
放大其中单张图:
plt.imshow(train_data[500][1])
plt.xticks([])
plt.yticks([])
plt.show()
观察这张图片,我们可以看到字符图片中含有大量的噪声,而噪声会对模型预测结果产生不良影响,因此我们可以在数据预处理时,使用特定的滤波器,消除图片噪声。
简单观察可知,刚才定义字符字典中,键与值都没有重复项,因此可以将字典中的键与值进行反转,以便我们用值查找键(将模型预测结果转换成可读字符)。将字典中的键与值进行反转(例:dict={'A':10,'B':11}
反转后得到new_dict={10:'A',11:'B'}
char_dict = {'0':0,'1':1,'2':2,'3':3,'4':4,'5':5,'6':6,'7':7,'8':8,'9':9,\
'A':10,'B':11,'C':12,'D':13,'E':14,'F':15,'G':16,'H':17,'I':18,'J':19,'K':20,'L':21,'M':22,\
'N':23,'O':24,'P':25,'Q':26,'R':27,'S':28,'T':29,'U':30,'V':31,'W':32,'X':33,'Y':34,'Z':35}
new_char_dict = {v : k for k, v in char_dict.items()}
然后我们就要定义datasets
和dataloader
了, 我们需要使用torch.utils.data.Dataset
作为父类,定义自己的datasets,以便规范自己的数据集。
class MyDataset(Dataset):
def __init__(self, file_name, transforms):
self.file_name = file_name # 文件名称
self.image_label_arr = load_file(self.file_name) # 读入二进制文件
self.transforms = transforms # 图片转换器
def __getitem__(self, index):
label, img = self.image_label_arr[index]
img = cv2.cvtColor(img,cv2.COLOR_RGB2GRAY) # 将图片转为灰度图
img = cv2.medianBlur(img, 5) # 使用中值模糊除去图片噪音
img = self.transforms(img) # 对图片进行转换
return img, char_dict[label[0]]
def __len__(self):
return len(self.image_label_arr)
接着我们来定义transform
和dataloader
:
transform = transforms.Compose([transforms.ToPILImage(),
transforms.Resize([28,28]), # 将图片尺寸调整为28*28
transforms.ToTensor(), # 将图片转为tensor
transforms.Normalize(mean=[0.5],std=[0.5])]) # 进行归一化处理
train_datasets = MyDataset(train_data_dir, transform)
train_loader = DataLoader(dataset=train_datasets,batch_size=32,shuffle=True)
val_datasets = MyDataset(val_data_dir, transform)
val_loader = DataLoader(dataset=val_datasets,batch_size=32,shuffle=True)
在数据准备好之后,我们需要定义一个简单的卷积神经网络,神经网络的输入是[batchsize,chanel(1),w(28),h(28)]
,输出是36个分类。
我们的神经网络将使用2个卷积层搭配2个全连接层,这四层的参数设置如下表所示(未标注的直接使用默认参数即可):
conv1: in_chanel=1, out_chanel=10, kernel_size=5
conv2: in_chanel=10, out_chanel=20, kernel_size=3
fc1: in_feature=2000, out_feature=500
fc2: in_feature=500, out_feature=36
class ConvNet(nn.Module):
def __init__(self):
super().__init__()
self.conv1 = nn.Conv2d(1, 10, 5)
self.conv2 = nn.Conv2d(10, 20, 3)
self.fc1 = nn.Linear(20 * 10 * 10, 500)
self.fc2 = nn.Linear(500, 36)
def forward(self, x):
# inputsize:[b,1,28,28]
in_size = x.size(0) # b
out= self.conv1(x) # inputsize:[b,1,28,28] -> outputsize:[b,10,24,24]
out = F.relu(out)
out = F.max_pool2d(out, 2, 2) # inputsize:[b,10,24,24] -> outputsize:[b,10,12,12]
out = self.conv2(out) # inputsize:[b,10,12,12] -> outputsize:[b,20,10,10]
out = F.relu(out)
out = out.view(in_size, -1) # inputsize:[b,20,10,10] -> outputsize:[b,2000]
out = self.fc1(out) # inputsize:[b,2000] -> outputsize:[b,500]
out = F.relu(out)
out = self.fc2(out) # inputsize:[b,500] -> outputsize:[b,36]
out = F.log_softmax(out, dim = 1)
return out
接着,我们来定义模型训练函数,需要实现下面四项操作:
- 清空梯度
- 前向传播
- 计算梯度
- 更新权重
def train(model, train_loader, optimizer, epoch):
model.train()
for batch_idx, (data, target) in enumerate(train_loader):
optimizer.zero_grad()
output = model(data)
loss = F.nll_loss(output, target)
loss.backward()
optimizer.step()
if (batch_idx + 1) % 10 == 0:
print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
epoch, batch_idx * len(data), len(train_loader.dataset),
100. * batch_idx / len(train_loader), loss.item()))
再来定义模型测试函数:
def test(model, test_loader):
model.eval()
test_loss =0
correct = 0
with torch.no_grad():
for data, target in test_loader:
output = model(data)
test_loss += F.nll_loss(output, target, reduction = 'sum')
pred = output.max(1, keepdim = True)[1]
correct += pred.eq(target.view_as(pred)).sum().item()
test_loss /= len(test_loader.dataset)
print("\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%) \n".format(
test_loss, correct, len(test_loader.dataset),
100.* correct / len(test_loader.dataset)))
接着是定义模型及优化器,我们将刚刚搭建好的模型结构定义为model,并选择使用Adam优化器。
model = ConvNet()
optimizer = optim.Adam(model.parameters())
终于可以来进行模型训练与测试了。我们可以先设置epochs数为3,进行模型训练,看看模型精度是多少,是否满足验证码识别的要求。如果模型精度不够,你还可以尝试调整epochs数,重新进行训练。
EPOCHS = 3
for epoch in range(1, EPOCHS + 1):
train(model, train_loader, optimizer, epoch)
test(model, val_loader)
---
Train Epoch: 1 [288/1800 (16%)] Loss: 3.454732
Train Epoch: 1 [608/1800 (33%)] Loss: 2.911864
Train Epoch: 1 [928/1800 (51%)] Loss: 1.960211
Train Epoch: 1 [1248/1800 (68%)] Loss: 0.972134
Train Epoch: 1 [1568/1800 (86%)] Loss: 0.420529
Test set: Average loss: 0.3054, Accuracy: 335/360 (93%)
Train Epoch: 2 [288/1800 (16%)] Loss: 0.094786
Train Epoch: 2 [608/1800 (33%)] Loss: 0.120601
Train Epoch: 2 [928/1800 (51%)] Loss: 0.073640
Train Epoch: 2 [1248/1800 (68%)] Loss: 0.058856
Train Epoch: 2 [1568/1800 (86%)] Loss: 0.007260
Test set: Average loss: 0.0139, Accuracy: 359/360 (100%)
Train Epoch: 3 [288/1800 (16%)] Loss: 0.002425
Train Epoch: 3 [608/1800 (33%)] Loss: 0.004629
Train Epoch: 3 [928/1800 (51%)] Loss: 0.005880
Train Epoch: 3 [1248/1800 (68%)] Loss: 0.008300
Train Epoch: 3 [1568/1800 (86%)] Loss: 0.004973
Test set: Average loss: 0.0039, Accuracy: 360/360 (100%)
训练结果看起来相当不错。那么下面我们自然就是要开始进行验证码识别了,我们需要先导入验证码数据集:
verification_code_data = load_file(verification_code_dir)
下面我们随便选一张图(图3),看看这个验证码长什么样。
image = verification_code_data[3]
IMG = Image.fromarray(cv2.cvtColor(image.copy(), cv2.COLOR_BGR2RGB))
plt.imshow(IMG)
再来看看中值滤波能对验证码图片产生什么效果。
# 中值滤波效果
img = cv2.medianBlur(image.copy(), 5)
plt.imshow(img)
好,最后来让我们看看识别的实际情况如何:
# 查看实际识别结果
IMAGES = list()
NUMS = list()
for img in verification_code_data:
IMAGES.append(img)
img = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)
image_1 = img[:, :80]
image_2 = img[:, 80:160]
image_3 = img[:, 160:240]
image_4 = img[:, 240:320]
img_list = [image_1, image_2, image_3, image_4]
nums = []
for one_img in img_list:
one_img = transform(one_img)
one_img = one_img.unsqueeze(0)
output = model(one_img)
nums.append(new_char_dict[torch.argmax(output).item()])
NUMS.append('Verification_code: '+''.join(nums))
plt.figure(figsize=(20, 20))
plt.subplots_adjust(wspace=0.2, hspace=0.5)
for i in range(1, 11):
plt.subplot(5,2,i)
plt.title(NUMS[i-1], fontsize=25, color='red')
plt.imshow(IMAGES[i-1])
plt.xticks([])
plt.yticks([])
plt.show()
相当令人满意的结果。
本次练习我主要是讲解思路,大家看完课程之后要多加练习,自己反复敲打代码,去琢磨其中的一些逻辑关系,要结合我们之前课程中讲过的知识。
还有就是,关于一些第三方库,大家要习惯于自行查看手册。这才是正确的学习打开方式。
好,那这个练习就讲到这里,大家下来记得多练习。
本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。 如若内容造成侵权/违法违规/事实不符,请联系我的编程经验分享网邮箱:veading@qq.com进行投诉反馈,一经查实,立即删除!