PyTorch 实现 U-Net 模型

自定义数据集

# -*- I Love Python!!! And You? -*-
# @Time    : 2022/3/27 12:25
# @Author  : sunao
# @Email   : 939419697@qq.com
# @File    : img_segData.py
# @Software: PyCharm
import matplotlib.pyplot as plt
import numpy as np
import torch
from PIL import Image
from torch.utils.data import Dataset,DataLoader
from torchvision import transforms
import torch.nn.functional as F
import os

class img_segData(Dataset):
	"""Segmentation dataset pairing every image with its "-profile.jpg" mask.

	Images are read from ``<path>/<data_file>/<name>.<ext>`` and the matching mask
	from ``<path>/<label_file>/<name>-profile.jpg``.
	"""

	def __init__(self,img_h=256,img_w=256,path="./data/img_seg",data_file="images",label_file="profiles",
	             preprocess=True):
		'''
		Initialise the dataset.
		:param img_h: height that images and masks are resized to
		:param img_w: width that images and masks are resized to
		:param path: dataset root directory
		:param data_file: name of the sub-folder holding the input images
		:param label_file: name of the sub-folder holding the label masks
		:param preprocess: if truthy, resize and convert samples to tensors;
		                   otherwise return PIL images
		'''
		super(img_segData, self).__init__()
		# os.listdir order is arbitrary, so sort for reproducible indices; skip hidden
		# entries (e.g. ".DS_Store"), which are not images and have no mask.
		self.file_list = sorted(
			name for name in os.listdir(os.path.join(path, data_file)) if not name.startswith(".")
		)
		self.data_file = data_file
		self.label_files = label_file
		self.path = path
		self.img_h = img_h
		self.img_w = img_w
		self.preprocess = preprocess

	@staticmethod
	def _label_name(img_name):
		"""Return the mask file name for ``img_name``: its stem plus ``-profile.jpg``."""
		# splitext strips only the last extension, so "a.b.png" -> "a.b-profile.jpg".
		return os.path.splitext(img_name)[0] + "-profile.jpg"

	def __len__(self):
		"""Return the number of samples (image files) in the dataset."""
		return len(self.file_list)

	def __getitem__(self, item):
		"""Return the ``item``-th ``(image, mask)`` pair.

		With preprocessing the image is an RGB tensor normalised with mean=std=0.5 and the
		mask is the ToTensor output of the label image, both resized to (img_h, img_w).
		"""
		img_name = self.file_list[item]
		img_path = os.path.join(self.path, self.data_file, img_name)
		label_path = os.path.join(self.path, self.label_files, self._label_name(img_name))

		# Context managers release the file handles immediately.
		with Image.open(img_path) as f:
			# Normalize below uses 3-channel statistics, so force RGB (grayscale/RGBA would fail).
			img = f.convert("RGB") if self.preprocess else f.copy()
		with Image.open(label_path) as f:
			label = f.copy()

		if self.preprocess:
			# torchvision's Resize expects (height, width).
			size = (self.img_h, self.img_w)
			trans_img = transforms.Compose([
				transforms.Resize(size=size),
				transforms.ToTensor(),
				transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
			])
			trans_label = transforms.Compose([
				transforms.Resize(size=size),
				transforms.ToTensor(),
			])
			img = trans_img(img)
			label = trans_label(label)
		return img, label



if __name__ == '__main__':
	# Smoke test: load one sample and display the image restricted to its mask.
	trans_data = img_segData()
	img, label = trans_data[5]
	print(img.size(), label.size())

	# Pixels where the mask is white are dropped and all others kept. Threshold at 0.5
	# instead of testing `== 1` exactly: resized JPEG masks are rarely exactly 1.0.
	label = torch.where(label > 0.5, torch.zeros_like(label), torch.ones_like(label))
	# The image was normalised with mean=std=0.5; map it back to [0, 1] so imshow does not clip.
	seg = label * (img * 0.5 + 0.5)
	plt.imshow(seg.numpy().transpose([1, 2, 0]))
	plt.show()

模型

# -*- I Love Python!!! And You? -*-
# @Time    : 2022/3/27 13:02
# @Author  : sunao
# @Email   : 939419697@qq.com
# @File    : model.py
# @Software: PyCharm

import torch
import torch.nn as nn
import torch.nn.functional as F


class conv_block(nn.Module):
	"""Two stacked stages of 3x3 conv -> BatchNorm -> ReLU; spatial size is preserved.

	:param ch_in: channels of the incoming feature map
	:param ch_out: channels produced by both convolutions
	"""

	def __init__(self, ch_in, ch_out):
		super().__init__()
		layers = []
		for stage_in in (ch_in, ch_out):
			layers.extend([
				nn.Conv2d(in_channels=stage_in, out_channels=ch_out, kernel_size=3, stride=1, padding=1),
				nn.BatchNorm2d(ch_out),
				nn.ReLU(inplace=True),
			])
		# A single Sequential named "conv" keeps state_dict keys as conv.0 ... conv.5.
		self.conv = nn.Sequential(*layers)

	def forward(self, x):
		return self.conv(x)
	
class up_block(nn.Module):
	"""Upsample by a factor of 2, then apply 3x3 conv -> BatchNorm -> ReLU.

	:param ch_in: channels of the incoming feature map
	:param ch_out: channels of the upsampled output
	"""

	def __init__(self, ch_in, ch_out):
		super().__init__()
		upsample = nn.Upsample(scale_factor=2)
		conv = nn.Conv2d(in_channels=ch_in, out_channels=ch_out, kernel_size=3, stride=1, padding=1)
		norm = nn.BatchNorm2d(ch_out)
		act = nn.ReLU(inplace=True)
		# Same order as before so state_dict keys stay conv.1 (conv) / conv.2 (norm).
		self.conv = nn.Sequential(upsample, conv, norm, act)

	def forward(self, x):
		return self.conv(x)
	
class U_Net(nn.Module):
	"""U-Net: 4-level encoder/decoder with skip connections and a sigmoid output.

	:param img_ch: channels of the input image
	:param output_ch: channels of the output probability map

	Input height and width must be divisible by 16 (four 2x2 max-pools); otherwise
	the skip-connection concatenations see mismatched spatial sizes.
	"""

	def __init__(self,img_ch=3,output_ch=1):
		super(U_Net, self).__init__()
		self.ndf=64
		f = self.ndf  # base channel width; each encoder level doubles it
		self.Maxpool = nn.MaxPool2d(2,2)
		# Encoder: f, 2f, 4f, 8f, 16f channels.
		self.conv1 = conv_block(ch_in=img_ch,ch_out=f)
		self.conv2 = conv_block(ch_in=f,ch_out=f*2)
		self.conv3 = conv_block(ch_in=f*2,ch_out=f*4)
		self.conv4 = conv_block(ch_in=f*4,ch_out=f*8)
		self.conv5 = conv_block(ch_in=f*8,ch_out=f*16)

		# Decoder: up* upsamples and halves the channels; up_conv* then fuses the result with
		# the same-resolution encoder features, so its input has twice its output channels.
		self.up4 = up_block(ch_in=f*16,ch_out=f*8)
		self.up_conv4 = conv_block(ch_in=f*16,ch_out=f*8)

		self.up3 = up_block(ch_in=f*8,ch_out=f*4)
		self.up_conv3 = conv_block(ch_in=f*8,ch_out=f*4)

		self.up2 = up_block(ch_in=f*4,ch_out=f*2)
		self.up_conv2 = conv_block(ch_in=f*4,ch_out=f*2)

		self.up1 = up_block(ch_in=f*2,ch_out=f)
		self.up_conv1 = conv_block(ch_in=f*2,ch_out=f)

		# Output head. It stays a conv_block so saved checkpoints keep loading (same
		# "conv1_1.conv.*" keys); forward() runs it without its trailing ReLU.
		self.conv1_1 = conv_block(ch_in=f,ch_out=output_ch)

	def forward(self,x):
		"""Map x of shape [N, img_ch, H, W] to probabilities of shape [N, output_ch, H, W]."""
		x1 = self.conv1(x)                   # [N, 64,   H,    W   ]
		x2 = self.conv2(self.Maxpool(x1))    # [N, 128,  H/2,  W/2 ]
		x3 = self.conv3(self.Maxpool(x2))    # [N, 256,  H/4,  W/4 ]
		x4 = self.conv4(self.Maxpool(x3))    # [N, 512,  H/8,  W/8 ]
		x5 = self.conv5(self.Maxpool(x4))    # [N, 1024, H/16, W/16]

		u4 = self.up_conv4(torch.cat([x4, self.up4(x5)], dim=1))  # [N, 512, H/8, W/8]
		u3 = self.up_conv3(torch.cat([x3, self.up3(u4)], dim=1))  # [N, 256, H/4, W/4]
		u2 = self.up_conv2(torch.cat([x2, self.up2(u3)], dim=1))  # [N, 128, H/2, W/2]
		u1 = self.up_conv1(torch.cat([x1, self.up1(u2)], dim=1))  # [N, 64,  H,   W  ]

		# conv1_1.conv ends in a ReLU, and sigmoid of a non-negative value is >= 0.5, so with
		# it the network could never predict a pixel below 0.5. Slicing the Sequential drops
		# only that ReLU while reusing the same registered layers (no new parameters).
		logits = self.conv1_1.conv[:-1](u1)  # [N, output_ch, H, W]
		return torch.sigmoid(logits)
		




if __name__ == '__main__':
	# Print the layer-by-layer structure of a default U_Net.
	print(U_Net())

训练模型

# -*- I Love Python!!! And You? -*-
# @Time    : 2022/3/27 15:34
# @Author  : sunao
# @Email   : 939419697@qq.com
# @File    : img2seg.py
# @Software: PyCharm

import numpy as np
import matplotlib.pyplot as plt
import torchvision
from img_segData import img_segData
from model import U_Net
from torch.utils import data
import torch
import os
from torchvision.utils import save_image

class Trainer(object):
	"""Train a U_Net segmentation model and write a preview image after every epoch."""

	def __init__(self,img_ch=3,out_ch=3,lr=0.005,
	             batch_size=16,num_epoch=60,train_set=None,
	             model_path="./model"):
		"""
		Build the data loader, model, loss and optimiser.
		:param img_ch: number of input image channels
		:param out_ch: number of output (mask) channels
		:param lr: learning rate
		:param batch_size: mini-batch size
		:param num_epoch: number of training epochs
		:param train_set: training dataset yielding (image, mask) pairs
		:param model_path: directory the checkpoint "Unet.pkl" is loaded from and saved to
		"""
		self.img_ch = img_ch
		self.out_ch = out_ch
		self.lr = lr
		self.batch_size = batch_size
		self.num_epoch = num_epoch
		self.model_path = model_path
		self.data_loader = data.DataLoader(dataset=train_set,
		                                   batch_size=self.batch_size,
		                                   shuffle=True,num_workers=0)

		# Model, loss and optimiser. BCELoss takes probabilities, which U_Net's sigmoid provides.
		self.unet = U_Net(self.img_ch,output_ch=self.out_ch)
		# NOTE: attribute name "divice" (sic) is kept because existing callers may use it.
		self.divice = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
		self.unet.to(self.divice)
		self.loss = torch.nn.BCELoss()
		self.optim = torch.optim.Adam(self.unet.parameters(),lr=self.lr,betas=(0.5,0.99))

	def _checkpoint_path(self):
		"""Return the path of the checkpoint file inside model_path."""
		return self.model_path + "/Unet.pkl"

	def train(self):
		"""Run the training loop, resuming from an existing checkpoint if there is one.

		After each epoch the summed batch loss is printed, a preview grid is written with
		save_img(), and the weights are saved whenever that loss is the lowest so far.
		"""
		ckpt_path = self._checkpoint_path()
		# Test for the file itself: model_path may exist without a checkpoint in it.
		if os.path.isfile(ckpt_path):
			# map_location lets a checkpoint written on a GPU machine load on a CPU-only one.
			# torch.load unpickles the file, so only load checkpoints you produced yourself.
			self.unet.load_state_dict(torch.load(ckpt_path, map_location=self.divice))
			print("模型导入成功",ckpt_path)

		best_loss = float("inf")
		for epoch in range(self.num_epoch):
			self.unet.train(True)
			epoch_loss = 0.0
			for bx, by in self.data_loader:
				bx = bx.to(self.divice)
				by = by.to(self.divice)
				loss = self.loss(self.unet(bx), by)
				self.optim.zero_grad()
				loss.backward()
				self.optim.step()
				epoch_loss += loss.item()

			print("| epoch %d/%d | loss %f |"%(epoch,self.num_epoch,epoch_loss))
			self.save_img(save_name="epoch"+str(epoch)+".png")
			if epoch_loss < best_loss:
				best_loss = epoch_loss
				os.makedirs(self.model_path, exist_ok=True)
				torch.save(self.unet.state_dict(), ckpt_path)

	@staticmethod
	def _compose_preview(img, pred, labels, threshold=0.5):
		"""Stack preview rows [image, predicted mask, predicted cut-out, ground-truth cut-out].

		Masks are inverted relative to the inputs: values above ``threshold`` become 0 and
		all others 1. In the cut-outs, pixels outside the mask are filled with 1.0, which
		save_image renders as white. Mask channels are broadcast to the image's channels.
		"""
		def to_mask(t):
			m = torch.where(t > threshold, torch.zeros_like(t), torch.ones_like(t)).to(img.dtype)
			return m.expand(-1, img.size(1), -1, -1)

		pred_mask = to_mask(pred)
		gt_mask = to_mask(labels)
		white = torch.ones_like(img)
		# Select on the mask itself, so genuinely zero-valued pixels inside it are kept.
		seg_pred = torch.where(pred_mask.bool(), img, white)
		seg_gt = torch.where(gt_mask.bool(), img, white)
		return torch.cat([img, pred_mask, seg_pred, seg_gt], 0)

	def save_img(self,save_path="./saved/Unet",save_name="result.png"):
		"""Predict on one shuffled batch and save a preview grid (up to 5 samples per row)."""
		img, labels = next(iter(self.data_loader))
		was_training = self.unet.training
		self.unet.eval()
		with torch.no_grad():
			pred = self.unet(img.to(self.divice))
		self.unet.train(was_training)

		save_tensor = self._compose_preview(img[:5].cpu(), pred[:5].cpu(), labels[:5].cpu())
		os.makedirs(save_path, exist_ok=True)
		save_image(save_tensor,save_path+'/'+save_name,nrow=5)
		



if __name__ == '__main__':
	# Free cached GPU memory held by this process's allocator before starting.
	torch.cuda.empty_cache()
	# Load the dataset; preprocess is a real boolean rather than the string "True".
	train_data = img_segData(img_h=256,img_w=256,path="./data/img_seg",data_file="images",
	                         label_file="profiles",preprocess=True)
	# Build and train the model; out_ch=1 yields a single-channel mask.
	trainer = Trainer(img_ch=3,out_ch=1,lr=0.01,batch_size=16,num_epoch=50,train_set=train_data)
	trainer.train()

你可能感兴趣的