推荐答案
在 PyTorch 中,保存和加载模型通常使用 torch.save()
和 torch.load()
函数。以下是保存和加载模型的推荐方法:
保存模型
# 保存模型的状态字典 torch.save(model.state_dict(), 'model.pth')
加载模型
# 加载模型的状态字典 model = TheModelClass(*args, **kwargs) model.load_state_dict(torch.load('model.pth')) model.eval() # 设置为评估模式
本题详细解读
保存模型
在 PyTorch 中,模型的参数(权重和偏置)存储在 state_dict
中,这是一个 Python 字典对象。torch.save()
函数可以将这个字典保存到文件中。通常,我们只保存模型的 state_dict
,而不是整个模型对象,因为这样可以更灵活地加载模型。
torch.save(model.state_dict(), 'model.pth')
加载模型
加载模型时,首先需要实例化一个与保存时相同的模型结构,然后使用 torch.load()
加载保存的 state_dict
,并将其加载到模型中。加载后,通常需要调用 model.eval()
将模型设置为评估模式,这会影响到某些层(如 Dropout 和 BatchNorm)的行为。
model = TheModelClass(*args, **kwargs) model.load_state_dict(torch.load('model.pth')) model.eval()
保存和加载整个模型
虽然不推荐,但也可以保存整个模型对象,包括模型的结构和参数。这种方法在加载时不需要重新定义模型结构,但可能会导致代码的兼容性问题。
# 保存整个模型 torch.save(model, 'model.pth') # 加载整个模型 model = torch.load('model.pth') model.eval()
注意事项
- 模型结构一致性:加载模型时,必须确保模型结构与保存时一致。
- 设备兼容性:如果模型是在 GPU 上训练的,加载时可能需要将其映射到 CPU 或其他设备上。
- 评估模式:加载模型后,务必调用
model.eval()
以确保模型在评估模式下运行。
通过以上方法,可以有效地保存和加载 PyTorch 模型,确保模型的训练成果得以保留和复用。