官方教程:SAVING AND LOADING MODELS

torch.save() 实现对网络结构和模型参数的保存。有两种保存方式:一是保存整个神经网络的结构信息和模型参数信息,save 的对象是网络 net;二是只保存神经网络的训练模型参数,save 的对象是 net.state_dict()。

torch.save(net1, '7-net.pth')  # 保存整个神经网络的结构和模型参数
torch.save(net1.state_dict(), '7-net_params.pth')  # 只保存神经网络的模型参数

对应上面两种保存方式,重载方式也有两种。对应第一种完整网络结构信息,重载的时候通过 torch.load('.pth') 直接初始化新的神经网络对象即可。对应第二种只保存模型参数信息,需要首先导入对应的网络,通过 net.load_state_dict(torch.load('.pth')) 完成模型参数的重载。在网络比较大的时候,第一种方法会花费较多的时间。

保存和加载整个模型

torch.save(model_object, ‘model.pkl’)
model = torch.load(‘model.pkl’)

仅保存和加载模型参数

torch.save(model_object.state_dict(), ‘params.pkl’)
model_object.load_state_dict(torch.load(‘params.pkl’))

发表评论

电子邮件地址不会被公开。 必填项已用*标注