动手学深度学习(六)——线性回归网络

文章目录

    • pytorch学习(六)——线性回归网络
      • 一、导入相关模块儿(主要是torch相关的内容)
      • 二、生成数据(数据+噪声)
      • 三、构建网络模型(继承tensor.nn中的模型类)
      • 四、定义模型(网络模型、代价函数与优化器)
      • 五、查看模型参数(并非必要)
      • 六、模型训练
      • 七、绘图查看

pytorch学习(六)——线性回归网络

说明:这篇博客主要是pytorch实现线性回归模型,可以非常清晰地看清在pytorch之中进行模型训练的步骤和方式
相关代码下载地址:https://download.csdn.net/download/jerry_liufeng/13095216

一、导入相关模块儿(主要是torch相关的内容)

import numpy as np
import matplotlib.pyplot as plt
from torch import optim
from torch.autograd import Variable
import torch
from torch import nn

二、生成数据(数据+噪声)

x_data = np.random.rand(100) #范围是0-1
noise = np.random.normal(0,0.01,x_data.shape) # normal获得高斯分布的随机数
y_data = x_data*0.1+0.2+noise

plt.scatter(x_data,y_data)
plt.show()

动手学深度学习(六)——线性回归网络_第1张图片
将数据变为二维数据

x_data = x_data.reshape(-1,1) # -1表示自动匹配最大行,1表示1列
y_data = y_data.reshape(-1,1)

将numpy格式的数据变为tensor中的变量

# 将numpy数据变为tensor数据
x_data = torch.FloatTensor(x_data)
y_data = torch.FloatTensor(y_data)

# 将数据变为tensor中的变量
inputs = Variable(x_data) 
target = Variable(y_data)

三、构建网络模型(继承tensor.nn中的模型类)

# 构建神经网络模型
# 一般将网络中具有科学系参数的层放在初始化__init__()之中
class LinearRegression(nn.Module):
    def __init__(self):
        # 定义网络结构
        super(LinearRegression,self).__init__() # 初始化父类
        self.fc = nn.Linear(1,1)
        
    def forward(self,x):
        #定义网络计算
        out = self.fc(x)
        return out

四、定义模型(网络模型、代价函数与优化器)

# 定义模型
model = LinearRegression()

# 定义代价函数-均方根代价函数
mse_loss = nn.MSELoss()

# 定义优化器—随机梯度下降法
optimizer = optim.SGD(model.parameters(),lr=0.1)

五、查看模型参数(并非必要)

# 查看模型参数
for name,parameter in model.named_parameters():
    print('name:{},param:{}'.format(name,parameter))

在这里插入图片描述

六、模型训练

# 模型训练
for i in range(1001):
    out = model(inputs)
    #计算loss
    loss =mse_loss(out,target)
    # 梯度清零
    optimizer.zero_grad()
    # 计算梯度
    loss.backward()
    # 修改权值
    optimizer.step()
    if i%200==0:
        print(i,loss.item())

七、绘图查看

# 绘图
y_pred = model(inputs)
plt.scatter(x_data,y_data)
plt.plot(x_data,y_pred.data.numpy(),'r-',lw=3) # 注意y_pred为tensor的变量类型,需要取出数据然后变为numpy类型
plt.show()

动手学深度学习(六)——线性回归网络_第2张图片

你可能感兴趣的