怎么在Pytorch中只導(dǎo)入部分模型參數(shù)?針對(duì)這個(gè)問(wèn)題,這篇文章詳細(xì)介紹了相對(duì)應(yīng)的分析和解答,希望可以幫助更多想解決這個(gè)問(wèn)題的小伙伴找到更簡(jiǎn)單易行的方法。
創(chuàng)新互聯(lián)成立于2013年,先為肇源等服務(wù)建站,肇源等地企業(yè),進(jìn)行企業(yè)商務(wù)咨詢服務(wù)。為肇源企業(yè)網(wǎng)站制作PC+手機(jī)+微官網(wǎng)三網(wǎng)同步一站式服務(wù)解決您的所有建站問(wèn)題。pytorch的優(yōu)點(diǎn)1.PyTorch是相當(dāng)簡(jiǎn)潔且高效快速的框架;2.設(shè)計(jì)追求最少的封裝;3.設(shè)計(jì)符合人類思維,它讓用戶盡可能地專注于實(shí)現(xiàn)自己的想法;4.與google的Tensorflow類似,F(xiàn)AIR的支持足以確保PyTorch獲得持續(xù)的開(kāi)發(fā)更新;5.PyTorch作者親自維護(hù)的論壇 供用戶交流和求教問(wèn)題6.入門簡(jiǎn)單
import torch as t from torch.nn import Module from torch import nn from torch.nn import functional as F class Net(Module): def __init__(self): super(Net,self).__init__() self.conv1 = nn.Conv2d(3,32,3,1) self.conv2 = nn.Conv2d(32,3,3,1) self.w = nn.Parameter(t.randn(3,10)) for p in self.children(): nn.init.xavier_normal_(p.weight.data) nn.init.constant_(p.bias.data, 0) def forward(self, x): out = self.conv1(x) out = self.conv2(x) out = F.avg_pool2d(out,(out.shape[2],out.shape[3])) out = F.linear(out,weight=self.w) return out
然后我們保存這個(gè)網(wǎng)絡(luò)的初始值。
model = Net() t.save(model.state_dict(),'xxx.pth')
現(xiàn)在我們將Net修改一下,多加幾個(gè)卷積層,但并不加入到forward中,僅僅出于少些幾行的目的。
import torch as t from torch.nn import Module from torch import nn from torch.nn import functional as F class Net(Module): def __init__(self): super(Net, self).__init__() self.conv1 = nn.Conv2d(3, 32, 3, 1) self.conv2 = nn.Conv2d(32, 3, 3, 1) self.conv3 = nn.Conv2d(3,64,3,1) self.conv4 = nn.Conv2d(64,32,3,1) for p in self.children(): nn.init.xavier_normal_(p.weight.data) nn.init.constant_(p.bias.data, 0) self.w = nn.Parameter(t.randn(3, 10)) def forward(self, x): out = self.conv1(x) out = self.conv2(x) out = F.avg_pool2d(out, (out.shape[2], out.shape[3])) out = F.linear(out, weight=self.w) return out
我們現(xiàn)在試著導(dǎo)入之前保存的模型參數(shù)。
path = 'xxx.pth' model = Net() model.load_state_dict(t.load(path)) ''' RuntimeError: Error(s) in loading state_dict for Net: Missing key(s) in state_dict: "conv3.weight", "conv3.bias", "conv4.weight", "conv4.bias". '''
出現(xiàn)了沒(méi)有在模型文件中找到error中的關(guān)鍵字的錯(cuò)誤。
現(xiàn)在我們這樣導(dǎo)入模型
path = 'xxx.pth' model = Net() save_model = t.load(path) model_dict = model.state_dict() state_dict = {k:v for k,v in save_model.items() if k in model_dict.keys()} print(state_dict.keys()) # dict_keys(['w', 'conv1.weight', 'conv1.bias', 'conv2.weight', 'conv2.bias']) model_dict.update(state_dict) model.load_state_dict(model_dict)
看看上面的代碼,很容易弄明白。其中model_dict.update的作用是更新代碼中搭建的模型參數(shù)字典。為啥更新我其實(shí)并不清楚,但這一步驟是必須的,否則還會(huì)報(bào)錯(cuò)。
為了弄清楚為什么要更新model_dict,我們不妨分別輸出state_dict和model_dict的關(guān)鍵值看一看。
for k in state_dict.keys(): print(k) ''' w conv1.weight conv1.bias conv2.weight conv2.bias ''' for k in model_dict.keys(): print(k) ''' w conv1.weight conv1.bias conv2.weight conv2.bias conv3.weight conv3.bias conv4.weight conv4.bias '''
關(guān)于怎么在Pytorch中只導(dǎo)入部分模型參數(shù)問(wèn)題的解答就分享到這里了,希望以上內(nèi)容可以對(duì)大家有一定的幫助,如果你還有很多疑惑沒(méi)有解開(kāi),可以關(guān)注創(chuàng)新互聯(lián)行業(yè)資訊頻道了解更多相關(guān)知識(shí)。
網(wǎng)頁(yè)題目:怎么在Pytorch中只導(dǎo)入部分模型參數(shù)-創(chuàng)新互聯(lián)
路徑分享:http://muchs.cn/article20/cdgpjo.html
成都網(wǎng)站建設(shè)公司_創(chuàng)新互聯(lián),為您提供標(biāo)簽優(yōu)化、品牌網(wǎng)站制作、關(guān)鍵詞優(yōu)化、微信公眾號(hào)、定制網(wǎng)站、企業(yè)網(wǎng)站制作
聲明:本網(wǎng)站發(fā)布的內(nèi)容(圖片、視頻和文字)以用戶投稿、用戶轉(zhuǎn)載內(nèi)容為主,如果涉及侵權(quán)請(qǐng)盡快告知,我們將會(huì)在第一時(shí)間刪除。文章觀點(diǎn)不代表本網(wǎng)站立場(chǎng),如需處理請(qǐng)聯(lián)系客服。電話:028-86922220;郵箱:631063699@qq.com。內(nèi)容未經(jīng)允許不得轉(zhuǎn)載,或轉(zhuǎn)載時(shí)需注明來(lái)源: 創(chuàng)新互聯(lián)
猜你還喜歡下面的內(nèi)容