CNN 定长验证码识别(PyTorch 版,含样本,可直接运行)

写在前面

我使用的样本例子:
样本
样本链接在最下边
训练时,第一个 epoch 的准确率就能达到 99%。
直接运行下面的代码即可。

训练

import torch
from torch import nn
from torch.nn import functional as F
from torchvision import models
from torch.optim.lr_scheduler import StepLR
from torch.utils.data import Dataset, DataLoader
import os
import cv2
import numpy as np
from tqdm import tqdm
import torchvision.transforms as T

# Size passed to T.Resize; every captcha image is resized to this before
# being fed to the network.
IMAGE_SHAPE = (24, 100)

# Preprocessing pipeline: ndarray -> PIL image -> resized image -> tensor.
transform = T.Compose([T.ToPILImage(), T.Resize(IMAGE_SHAPE), T.ToTensor()])
# Characters that may appear in a captcha; the list position of a character
# is used as its class index.
LABEL_MAP = list('0123456789')
# Non-standard name, please ignore it. This is really the captcha length,
# because the text has a fixed length.
Max_label_len = 5


class MyDataset(Dataset):
    """Captcha dataset read from a flat directory of image files.

    The label of a sample is the part of its file name before the first
    ``.`` (for example ``01234.png`` -> ``"01234"``).

    Args:
        data_path: Directory that contains the captcha image files.
        label_map: Iterable of the characters that may appear in a label.
            The position of a character is its class index.
        max_label_len: Captcha length. It is stored on the instance but is
            not used by the dataset itself.
    """

    def __init__(self, data_path, label_map, max_label_len):
        super(MyDataset, self).__init__()
        self.data = [(os.path.join(data_path, file), file.split('.')[0]) for file in os.listdir(data_path)]
        self.label_map = [char for char in label_map]
        self.label_map_len = len(self.label_map)
        self.max_label_len = max_label_len
        # Character -> class index lookup. setdefault keeps the first
        # occurrence, which matches what list.index() would return.
        self._char_to_index = {}
        for idx, char in enumerate(self.label_map):
            self._char_to_index.setdefault(char, idx)

    def _encode_label(self, label):
        """Return the one-hot encoding of ``label``.

        The result is a float tensor of shape
        ``(len(label), self.label_map_len)``.

        Raises:
            ValueError: If ``label`` contains a character that is not in
                ``label_map``.
        """
        try:
            indices = [self._char_to_index[char] for char in label]
        except KeyError as err:
            raise ValueError(
                "character %r of label %r is not in label_map" % (err.args[0], label)
            ) from None
        indices = torch.as_tensor(indices, dtype=torch.int64)
        # Size the encoding from this dataset's own alphabet rather than the
        # module-level LABEL_MAP, so a custom label_map stays consistent.
        return F.one_hot(indices, num_classes=self.label_map_len).float()

    def __getitem__(self, index):
        """Return ``(image_tensor, one_hot_label, label_length)`` for a sample.

        Raises:
            ValueError: If the image file cannot be decoded, or the label
                contains a character that is not in ``label_map``.
        """
        file, label = self.data[index]
        raw_len = len(label)
        # The file is read as raw bytes and decoded in memory instead of
        # using cv2.imread -- presumably so that non-ASCII paths work on
        # Windows; confirm before simplifying.
        im = np.fromfile(file, dtype=np.uint8)
        im = cv2.imdecode(im, cv2.IMREAD_COLOR)
        if im is None:
            # cv2.imdecode reports failure by returning None; fail here with
            # the file name instead of an opaque error inside the transform.
            raise ValueError("could not decode image file: %s" % file)
        im = transform(im)
        return im, self._encode_label(label), raw_len

    def __len__(self):
        """Return the number of files found in ``data_path``."""
        return len(self.data)


class Net(nn.Module):
    """ResNet-18 wrapper used as the captcha recogniser.

    A class is used here so that the network structure is easy to customise.
    The backbone is built with ``num_classes`` set to
    ``Max_label_len * len(LABEL_MAP)``, i.e. one value per
    (character position, character class) pair.
    """

    def __init__(self):
        super().__init__()
        n_outputs = Max_label_len * len(LABEL_MAP)
        self.resnet18 = models.resnet18(num_classes=n_outputs)

    def forward(self, x):
        """Run ``x`` through the ResNet-18 backbone and return its output."""
        return self.resnet18(x)


# Training loader: shuffled batches of 32 samples, loaded by 3 workers.
_train_set = MyDataset(r'../sample/train', label_map=LABEL_MAP, max_label_len=Max_label_len)
train = DataLoader(dataset=_train_set, batch_size=32, shuffle=True, num_workers=3)

# Evaluation loader: shuffled batches of 4 samples, loaded in the main process.
_test_set = MyDataset(r'../sample/test', label_map=LABEL_MAP, max_label_len=Max_label_len)
test = DataLoader(dataset=_test_set, batch_size=4, shuffle=True, num_workers=0)

if __name__ == '__main__':
    # Train the network for 100 epochs. After every epoch the model is
    # evaluated on the test loader and its weights are saved under models/.
    DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = Net()
    model.to(DEVICE)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
    # The flattened one-hot label is regressed with a mean-squared error.
    loss_func = nn.MSELoss()
    scheduler = StepLR(optimizer, step_size=2, gamma=0.7)
    # torch.save does not create missing directories.
    os.makedirs("models", exist_ok=True)

    for epoch in range(0, 100):
        # Train
        model.train()  # batch-norm layers must use batch statistics here
        bar = tqdm(train, 'Training')
        for x, label, _ in bar:
            x, label = x.to(DEVICE), label.to(DEVICE)
            out = model(x)
            # Flatten (batch, Max_label_len, len(LABEL_MAP)) to match `out`.
            label = label.view(-1, Max_label_len * len(LABEL_MAP))
            loss = loss_func(out, label)

            # The usual three steps: reset gradients, backpropagate, update.
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            lr = optimizer.param_groups[0]['lr']
            bar.set_description("Train epoch %d, loss %.4f, lr %.6f" % (
                epoch, loss.item(), lr
            ))

        # Valid
        # eval() freezes the batch-norm statistics so that test data cannot
        # leak into the model, and matches how the model is used at inference.
        model.eval()
        bar = tqdm(test, 'Evaluating')
        correct = count = 0
        with torch.no_grad():  # no gradients are needed for evaluation
            for x, label, _ in bar:
                x, label = x.to(DEVICE), label.to(DEVICE)
                out = model(x)
                label_copy = label.view(-1, Max_label_len * len(LABEL_MAP))
                loss = loss_func(out, label_copy)

                out = out.view(-1, Max_label_len, len(LABEL_MAP))  # (batch, Max_label_len, len(LABEL_MAP))
                predict = torch.argmax(out, dim=2)  # (batch, Max_label_len)
                label = torch.argmax(label, dim=2)

                # Accuracy is counted per character, not per whole captcha.
                count += x.shape[0] * Max_label_len
                correct += (predict == label).sum().item()

                lr = optimizer.param_groups[0]['lr']
                bar.set_description("Eval epoch %d, acc %.4f, loss %.4f, lr %.6f" % (
                    epoch, correct / count, loss.item(), lr
                ))

        # Passing the epoch to step() is deprecated; the scheduler keeps its
        # own counter.
        scheduler.step()
        torch.save(model.state_dict(), "models/save_%d.model" % epoch)

测试

import torch
from torch import nn
from torchvision import models
import cv2
import numpy as np
import torchvision.transforms as T

# Size passed to T.Resize for every input image.
IMAGE_SHAPE = (24, 100)

# Preprocessing pipeline: ndarray -> PIL image -> resized image -> tensor.
transform = T.Compose([T.ToPILImage(), T.Resize(IMAGE_SHAPE), T.ToTensor()])
# Characters that may appear in a captcha; the list position of a character
# is used as its class index.
LABEL_MAP = list('0123456789')
# Number of values the network output is split into along the position axis.
Max_label_len = 5


class Net(nn.Module):
    """ResNet-18 wrapper whose weights are restored from a saved state dict.

    The backbone is built with ``num_classes`` set to
    ``Max_label_len * len(LABEL_MAP)``.
    """

    def __init__(self):
        super().__init__()
        output_size = Max_label_len * len(LABEL_MAP)
        self.resnet18 = models.resnet18(num_classes=output_size)

    def forward(self, x):
        """Return the backbone's output for the batch ``x``."""
        return self.resnet18(x)

# Whether to use the GPU
# DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
DEVICE = torch.device('cpu')
model = Net()
model.to(DEVICE)
# Load the trained weights. map_location remaps the stored tensors onto
# DEVICE, so a checkpoint saved from a CUDA model also loads on a machine
# without a GPU.
model.load_state_dict(torch.load("./models/save_9.model", map_location=DEVICE))
# Switch to inference mode (e.g. batch-norm uses its running statistics).
model.eval()


def captcha(im):
    """Recognise the text of a single captcha image.

    Args:
        im: Image array accepted by ``transform`` -- presumably the
            ndarray returned by ``cv2.imread``; confirm against the caller.

    Returns:
        The predicted text: one character from ``LABEL_MAP`` for each of
        the ``Max_label_len`` positions.

    Raises:
        ValueError: If ``im`` is ``None``, which is what ``cv2.imread``
            returns when the file is missing or cannot be decoded.
    """
    if im is None:
        raise ValueError("im is None; the image could not be read")
    im = transform(im)
    im = im.to(DEVICE)
    im = im.unsqueeze(0)  # add the batch dimension expected by the model
    with torch.no_grad():  # inference only, no autograd graph needed
        out = model(im)
    out = out.view(-1, Max_label_len, len(LABEL_MAP))
    predict = torch.argmax(out, dim=2)
    label = predict.cpu().tolist()[0]
    # Map class indices back through LABEL_MAP rather than printing the raw
    # index, so the result stays correct if the alphabet changes.
    return ''.join(LABEL_MAP[i] for i in label)
# Example usage: replace 'path' with the path of the captcha image to recognise.
im = cv2.imread('path', cv2.IMREAD_COLOR)
ret = captcha(im)
print(ret)

链接: https://pan.baidu.com/s/1bEeqtlCFLqvAQcH3GqKXqQ 提取码: 7ue5

你可能感兴趣的