pytorch-minst手写字符识别实战

字符识别

文章目录

  • 字符识别
    • 1 字符识别准备
      • 1.1 minst数据集
      • 1.2 字符分类网络
      • 1.3 卷积和池化的计算公式:
      • 1.4 计算网络参数
    • 2字符识别代码
      • 2.1 网络model
      • 2.2 网络训练
      • 2.3 测试图片


1 字符识别准备

1.1 minst数据集

共有7万张图片。其中6万张训练集,1万张测试集。
每张图片是一个28*28像素点的单通道的手写数字(0~9)图片,采取黑底白字的形式。.

1.2 字符分类网络

注意:输入尺寸应该是28*28
pytorch-minst手写字符识别实战_第1张图片

1.3 卷积和池化的计算公式:

pytorch-minst手写字符识别实战_第2张图片
公式参考链接

  • kernel_size:卷积核大小
  • padding:边缘填充
  • dilation:表示感受野的扩张 默认为1
  • stride:卷积的步长

1.4 计算网络参数

卷积核大小统一为5*5

计算第一个卷积网络的参数,已知

kernel_size=5,diation=1,stride=1
H_out = 28 H_in=28
W_out = 28 W_in=28

所以只需计算出padding就好。根据网络图计算过程如下:

28 = ( 28 + 2 ∗ p a d d i n g − 1 ∗ ( 5 − 1 ) ) − 1 1 + 1 28 = \dfrac{(28+2*padding-1*(5-1)) - 1}{1} + 1 28=1(28+2padding1(51))1+1

最后计算出padding=2,同理可以计算出后面的参数

2字符识别代码

2.1 网络model

根据上一步计算的网络参数搭建网络

conv2d的参数
torch.nn.Conv2d(in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1)

import torch
import torchvision
from torch import nn
from torch.utils.data import DataLoader

class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.model = nn.Sequential(
            nn.Conv2d(1, 6, 5, 1, 2), #卷积
            nn.MaxPool2d(2),#池化
            nn.Conv2d(6, 16, 5, 1, 0),
            nn.MaxPool2d(2),
            nn.Flatten(),#摊平向量
            nn.Linear(400, 120),#全连接
            nn.Linear(120, 84),
            nn.Linear(84, 10)#最后输出10个类
        )

    def forward(self, input):
        output = self.model(input)
        return output

2.2 网络训练

利用tensorboard可视化训练过程

import torch
import torchvision
import time
from torch import nn
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter

from model import *

#加载数据
train_data = torchvision.datasets.MNIST(root='.\data', train=True,transform=torchvision.transforms.ToTensor(),download=True)
test_data = torchvision.datasets.MNIST(root='.\data', train=False,transform=torchvision.transforms.ToTensor(),download=True)

#记录长度
length_train = len(train_data)
length_test = len(test_data)

#加载数据
train_data = DataLoader(train_data, batch_size=64)
test_data = DataLoader(test_data, batch_size=64)


#加载网络
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
str_class = Model()
str_class.to(device)

#调节参数
loss_fn = nn.CrossEntropyLoss()
loss_fn.to(device)
learning_rate = 0.01
optimizer = torch.optim.SGD(str_class.parameters(), lr=learning_rate)
epoch = 100
total_train_step = 0


writer = SummaryWriter('./log')

for i in range(epoch):
    print('-'*5 + '第%d轮开始训练'%(i+1) + '-'*5)
    #训练网络
    start_time = time.time()
    str_class.train()
    for step, data in enumerate(train_data):
        imgs, targets = data
        imgs = imgs.to(device)
        targets = targets.to(device)
        outputs = str_class(imgs)
        loss = loss_fn(outputs, targets)

        #反向传播
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_train_step += 1
        #输出训练信息
        if step % 100 == 0:
            end_time = time.time()
            print('训练次数:{} 训练时间: {:.2f} loss: {:.4f}'.format(step, end_time - start_time, loss))
            writer.add_scalar('train_loss/step', loss, total_train_step)

    #测试网络
    str_class.eval()
    total_test_loss = 0.0
    total_test_accuracy = 0.0
    with torch.no_grad():
        for step, data in enumerate(test_data):
            imgs, targets = data
            imgs = imgs.to(device)
            targets = targets.to(device)
            outputs = str_class(imgs)
            loss = loss_fn(outputs, targets)

            accuracy = (outputs.argmax(1) == targets).sum()
            total_test_accuracy += accuracy
            total_test_loss += loss

        #输出测试信息
        print('测试平均loss: {:.4f} 测试平均准确率 {:.4f}'.format(total_test_loss/length_test, total_test_accuracy/length_test)) #损失函数是否要除长度
        writer.add_scalar('test_loss/epoch', total_test_loss/length_test, i+1)
        writer.add_scalar('test_accuracy/epoch', total_test_accuracy/length_test, i+1)

        if (i+1) % 10 == 0:
            torch.save(str_class, './model/demo2/model_{}.pth'.format(i+1))
            print('模型已保存')

writer.close()

可视化训练过程

命令行运行如下指令
tensorboard --logdir=log
访问http://localhost:6006/

训练曲线
pytorch-minst手写字符识别实战_第3张图片

pytorch-minst手写字符识别实战_第4张图片

2.3 测试图片

可以利用画图软件构建黑底白字的图片,注意一定是黑底白字的图片,然后进行测试。

import torch
import torchvision
from PIL import Image

from model import *

str_class =torch.load('./model/demo2/decade_71.pth', map_location=torch.device('cpu'))
train_data = torchvision.datasets.MNIST('./data', train=False, transform=torchvision.transforms.ToTensor(), download=True)
dict_target = train_data.class_to_idx
dict_target = [indx for indx, vale in dict_target.items()] #获得标签字典
transform_img = torchvision.transforms.Compose([torchvision.transforms.Resize((28, 28)), torchvision.transforms.ToTensor()])


def test_img(img_path):
    img = Image.open(img_path)
    img = img.convert('L')
    img = transform_img(img)
    img = torch.reshape(img, (1, 1, 28, 28))
    output = str_class(img)
    output = output.argmax(1)
    print('识别类型为{}'.format(dict_target[output]))

test_img('1.jpg')

输入图片经过网络个层后的结果
pytorch-minst手写字符识别实战_第5张图片

你可能感兴趣的