一区二区三区在线-一区二区三区亚洲视频-一区二区三区亚洲-一区二区三区午夜-一区二区三区四区在线视频-一区二区三区四区在线免费观看

腳本之家,腳本語言編程技術及教程分享平臺!
分類導航

Python|VBS|Ruby|Lua|perl|VBA|Golang|PowerShell|Erlang|autoit|Dos|bat|

服務器之家 - 腳本之家 - Python - pytorch 狀態字典:state_dict使用詳解

pytorch 狀態字典:state_dict使用詳解

2020-04-16 12:58wzg2016 Python

今天小編就為大家分享一篇pytorch 狀態字典:state_dict使用詳解,具有很好的參考價值,希望對大家有所幫助。一起跟隨小編過來看看吧

pytorch 中的 state_dict 是一個簡單的python的字典對象,將每一層與它的對應參數建立映射關系.(如model的每一層的weights及偏置等等)

(注意,只有那些參數可以訓練的layer才會被保存到模型的state_dict中,如卷積層,線性層等等)

優化器對象Optimizer也有一個state_dict,它包含了優化器的狀態以及被使用的超參數(如lr, momentum,weight_decay等)

備注:

1) state_dict是在定義了model或optimizer之后pytorch自動生成的,可以直接調用.常用的保存state_dict的格式是".pt"或'.pth'的文件,即下面命令的 PATH="./***.pt"

?
1
torch.save(model.state_dict(), PATH)

2) load_state_dict 也是model或optimizer之后pytorch自動具備的函數,可以直接調用

?
1
2
3
model = TheModelClass(*args, **kwargs)
model.load_state_dict(torch.load(PATH))
model.eval()

注意:model.eval() 的重要性,在2)中最后用到了model.eval(),是因為,只有在執行該命令后,"dropout層"及"batch normalization層"才會進入 evalution 模態. 而在"訓練(training)模態"與"評估(evalution)模態"下,這兩層有不同的表現形式.

模態字典(state_dict)的保存(model是一個網絡結構類的對象)

1.1)僅保存學習到的參數,用以下命令

?
1
torch.save(model.state_dict(), PATH)

1.2)加載model.state_dict,用以下命令

?
1
2
3
model = TheModelClass(*args, **kwargs)
model.load_state_dict(torch.load(PATH))
model.eval()

備注:model.load_state_dict的操作對象是 一個具體的對象,而不能是文件名

2.1)保存整個model的狀態,用以下命令

?
1
torch.save(model,PATH)

2.2)加載整個model的狀態,用以下命令:

?
1
2
3
4
5
  # Model class must be defined somewhere
 
model = torch.load(PATH)
 
model.eval()

state_dict 是一個python的字典格式,以字典的格式存儲,然后以字典的格式被加載,而且只加載key匹配的項

如何僅加載某一層的訓練的到的參數(某一層的state)

If you want to load parameters from one layer to another, but some keys do not match, simply change the name of the parameter keys in the state_dict that you are loading to match the keys in the model that you are loading into.

?
1
conv1_weight_state = torch.load('./model_state_dict.pt')['conv1.weight']

加載模型參數后,如何設置某層某參數的"是否需要訓練"(param.requires_grad)

?
1
2
for param in list(model.pretrained.parameters()):
 param.requires_grad = False

注意: requires_grad的操作對象是tensor.

疑問:能否直接對某個層直接之用requires_grad呢?例如:model.conv1.requires_grad=False

回答:經測試,不可以.model.conv1 沒有requires_grad屬性.

全部測試代碼:

?
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
#-*-coding:utf-8-*-
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
 
 
 
# define model
class TheModelClass(nn.Module):
 def __init__(self):
  super(TheModelClass,self).__init__()
  self.conv1 = nn.Conv2d(3,6,5)
  self.pool = nn.MaxPool2d(2,2)
  self.conv2 = nn.Conv2d(6,16,5)
  self.fc1 = nn.Linear(16*5*5,120)
  self.fc2 = nn.Linear(120,84)
  self.fc3 = nn.Linear(84,10)
 
 def forward(self,x):
  x = self.pool(F.relu(self.conv1(x)))
  x = self.pool(F.relu(self.conv2(x)))
  x = x.view(-1,16*5*5)
  x = F.relu(self.fc1(x))
  x = F.relu(self.fc2(x))
  x = self.fc3(x)
  return x
 
# initial model
model = TheModelClass()
 
#initialize the optimizer
optimizer = optim.SGD(model.parameters(),lr=0.001,momentum=0.9)
 
# print the model's state_dict
print("model's state_dict:")
for param_tensor in model.state_dict():
 print(param_tensor,'\t',model.state_dict()[param_tensor].size())
 
print("\noptimizer's state_dict")
for var_name in optimizer.state_dict():
 print(var_name,'\t',optimizer.state_dict()[var_name])
 
print("\nprint particular param")
print('\n',model.conv1.weight.size())
print('\n',model.conv1.weight)
 
print("------------------------------------")
torch.save(model.state_dict(),'./model_state_dict.pt')
# model_2 = TheModelClass()
# model_2.load_state_dict(torch.load('./model_state_dict'))
# model.eval()
# print('\n',model_2.conv1.weight)
# print((model_2.conv1.weight == model.conv1.weight).size())
## 僅僅加載某一層的參數
conv1_weight_state = torch.load('./model_state_dict.pt')['conv1.weight']
print(conv1_weight_state==model.conv1.weight)
 
model_2 = TheModelClass()
model_2.load_state_dict(torch.load('./model_state_dict.pt'))
model_2.conv1.requires_grad=False
print(model_2.conv1.requires_grad)
print(model_2.conv1.bias.requires_grad)

以上這篇pytorch 狀態字典:state_dict使用詳解就是小編分享給大家的全部內容了,希望能給大家一個參考,也希望大家多多支持服務器之家。

原文鏈接:https://blog.csdn.net/Strive_For_Future/article/details/83240081

延伸 · 閱讀

精彩推薦
主站蜘蛛池模板: 美女视频黄a| 天堂a免费视频在线观看 | 日本96在线精品视频免费观看 | 免费黄色片网站 | 福利一区福利二区 | 久久精品国产色蜜蜜麻豆国语版 | 咪咪爱小说 | 欧美有码 | 国内精品中文字幕 | 帅小伙和警官同性3p | 大肚孕妇的高h辣文 | 香蕉国产人午夜视频在线 | 好姑娘在线完整版视频 | 特黄特色大片免费高清视频 | 精品国语对白精品自拍视 | 大桥未久aⅴ一区二区 | 欧美腐剧mm在线观看 | 美女精品永久福利在线 | bt天堂在线最新版在线 | 日本高清在线看 | 国内老司机精品视频在线播出 | 亚洲国产精品综合久久一线 | 成年人免费观看的视频 | 国产成人cao在线 | 91久久偷偷做嫩草影院免费 | futa百合高肉全h | 亚洲 制服 欧美 中文字幕 | 国产精品成人扳一级aa毛片 | 亚洲第一在线 | 日本成熟 | 欧美一区二区三区免费不卡 | 免费全看男女拍拍拍的视频 | 国产成人激情视频 | 俄罗斯三级完整版在线观看 | 鬼惨笑小说 | 高清一级做a爱免费视 | 亚洲精品国产乱码AV在线观看 | 青青草国产精品久久久久 | 92精品国产成人观看免费 | 日本艳鉧动漫1~6完整版在 | 亚洲精彩视频在线观看 |