1、有了已經(jīng)訓練好的模型參數(shù),對這個模型的某些層做了改變,如何利用這些訓練好的模型參數(shù)繼續(xù)訓練:
1
2
3
|
pretrained_params = torch.load( 'Pretrained_Model' ) model = The_New_Model(xxx) model.load_state_dict(pretrained_params.state_dict(), strict = False ) |
strict=False 使得預訓練模型參數(shù)中和新模型對應(yīng)上的參數(shù)會被載入,對應(yīng)不上或沒有的參數(shù)被拋棄。
2、如果載入的這些參數(shù)中,有些參數(shù)不要求被更新,即固定不變,不參與訓練,需要手動設(shè)置這些參數(shù)的梯度屬性為Fasle,并且在optimizer傳參時篩選掉這些參數(shù):
1
2
3
4
5
6
7
8
|
# 載入預訓練模型參數(shù)后... for name, value in model.named_parameters(): if name 滿足某些條件: value.requires_grad = False # setup optimizer params = filter ( lambda p: p.requires_grad, model.parameters()) optimizer = torch.optim.Adam(params, lr = 1e - 4 ) |
將滿足條件的參數(shù)的 requires_grad 屬性設(shè)置為False, 同時 filter 函數(shù)將模型中屬性 requires_grad = True 的參數(shù)帥選出來,傳到優(yōu)化器(以Adam為例)中,只有這些參數(shù)會被求導數(shù)和更新。
3、如果載入的這些參數(shù)中,所有參數(shù)都更新,但要求一些參數(shù)和另一些參數(shù)的更新速度(學習率learning rate)不一樣,最好知道這些參數(shù)的名稱都有什么:
1
2
3
4
5
|
# 載入預訓練模型參數(shù)后... for name, value in model.named_parameters(): print (name) # 或 print (model.state_dict().keys()) |
假設(shè)該模型中有encoder,viewer和decoder兩部分,參數(shù)名稱分別是:
1
2
3
4
5
6
|
'encoder.visual_emb.0.weight' , 'encoder.visual_emb.0.bias' , 'viewer.bd.Wsi' , 'viewer.bd.bias' , 'decoder.core.layer_0.weight_ih' , 'decoder.core.layer_0.weight_hh' , |
假設(shè)要求encode、viewer的學習率為1e-6, decoder的學習率為1e-4,那么在將參數(shù)傳入優(yōu)化器時:
1
2
3
4
5
6
|
ignored_params = list ( map ( id , model.decoder.parameters())) base_params = filter ( lambda p: id (p) not in ignored_params, model.parameters()) optimizer = torch.optim.Adam([{ 'params' :base_params, 'lr' : 1e - 6 }, { 'params' :model.decoder.parameters()} ], lr = 1e - 4 , momentum = 0.9 ) |
代碼的結(jié)果是除decoder參數(shù)的learning_rate=1e-4 外,其他參數(shù)的額learning_rate=1e-6。
在傳入optimizer時,和一般的傳參方法torch.optim.Adam(model.parameters(), lr=xxx) 不同,參數(shù)部分用了一個list, list的每個元素有params和lr兩個鍵值。如果沒有 lr則應(yīng)用Adam的lr屬性。Adam的屬性除了lr, 其他都是參數(shù)所共有的(比如momentum)。
以上這篇pytorch載入預訓練模型后,實現(xiàn)訓練指定層就是小編分享給大家的全部內(nèi)容了,希望能給大家一個參考,也希望大家多多支持服務(wù)器之家。
參考:
pytorch官方文檔
原文鏈接:https://blog.csdn.net/weixin_36049506/article/details/89522860