PyTorch学习笔记:使用不同设备(CPU/GPU)对模型进行保存和加载

本篇其实与PyTorch学习笔记:使用state_dict来保存和加载模型是高度关联的,之所以单独拎出来写,主要是想突出它的重要性。

首先来描述一个本人实际遇到的问题:

首先在GPU服务器上训练了一个ResNet34的模型,然后将该模型在本人PC机(没有GPU)上进行推理,模型加载代码如下:

# load model weights
weights_path = "./resNet34.pth"
assert os.path.exists(weights_path), "file: '{}' dose not exist.".format(weights_path)

# 直接加载模型
model.load_state_dict(torch.load(weights_path))

结果运行时出现如下错误:

RuntimeError: Attempting to deserialize object on a CUDA device but torch.cuda.is_available() is False. If you are running on a CPU-only machine, please use torch.load with map_location=torch.device('cpu') to map your storages to the CPU.

这就引出了今天的问题,当保存模型的设备和加载模型的设备不同时,更直白一点就是,训练模型的设备和推理(或者继续训练)的设备不一致时,我们就需要在load_state_dict()函数中引入一个参数map_location来指定需要将不同设备上保存的模型映射到另一个设备上。 

具体包括4种情况:

1. CPU上保存,CPU上加载

这种情况是最简单的,可不使用map_location参数,也可不适用model.to(device)。

# 保存模型
torch.save(model.state_dict(), PATH)

device = torch.device('cpu')
#加载模型
model = resnet34(num_classes=5)
# load model weights
weights_path = "./resNet34.pth"
model.load_state_dict(torch.load(weights_path))

model.eval()

2. GPU上保存,GPU上加载

这种情况也比较简单,由于保存和加载设备都是GPU,因此可省略map_location参数,但需要使用model.to(device)将模型的参数张量转化为CUDA张量,否则会出现如下报错:

RuntimeError: Input type (torch.cuda.FloatTensor) and weight type (torch.FloatTensor) should be the same.

这种情况下,代码可写为:

# 保存模型
torch.save(model.state_dict(), PATH)

device = torch.device('cuda')
#加载模型
model = resnet34(num_classes=5)
# load model weights
weights_path = "./resNet34.pth"
model.load_state_dict(torch.load(weights_path))
model.to(device)

model.eval()

3. CPU上保存,GPU上加载

模型保存设备和加载设备不一致,需要借助map_location参数,当然model.to(device)也是需要的。

# 保存模型
torch.save(model.state_dict(), PATH)

device = torch.device('cuda')
#加载模型
model = resnet34(num_classes=5)
# load model weights
weights_path = "./resNet34.pth"
model.load_state_dict(torch.load(weights_path), map_location="cuda:0")
model.to(device)

model.eval()

4. GPU上保存,CPU上加载

模型保存设备和加载设备不一致,需要借助map_location参数,但model.to(device)可不使用:

# 保存模型
torch.save(model.state_dict(), PATH)

device = torch.device('cpu')
#加载模型
model = resnet34(num_classes=5)
# load model weights
weights_path = "./resNet34.pth"
model.load_state_dict(torch.load(weights_path), map_location=device)

model.eval()

好了,以上4种情况介绍完了,但如果我们平时工作的时候要根据不同的设备来写不同的代码未免也太麻烦。所以最好有一套通用的代码来cover各种情况,仍然拿我的工程来举例:

# 如果有GPU,不管有几块,我们只指定第一块使用
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

model = resnet34(num_classes=5)
weights_path = "./resNet34.pth"
assert os.path.exists(weights_path), "file: '{}' dose not exist.".format(weights_path)
model.load_state_dict(torch.load(weights_path, map_location=device))
model.to(device)

model.eval()
......

归纳起来也就是,先获取本机可用的设备,有GPU当然就首选GPU啦,后面的代码,不管是否相同设备,map_location和model.to(device)都用上就好了。

你可能感兴趣的