如何在pytorch中解決state_dict()的拷貝問題?很多新手對此不是很清楚,為了幫助大家解決這個(gè)難題,下面小編將為大家詳細(xì)講解,有這方面需求的人可以來學(xué)習(xí)下,希望你能有所收獲。
10多年的盈江網(wǎng)站建設(shè)經(jīng)驗(yàn),針對設(shè)計(jì)、前端、開發(fā)、售后、文案、推廣等六對一服務(wù),響應(yīng)快,48小時(shí)及時(shí)工作處理。全網(wǎng)營銷推廣的優(yōu)勢是能夠根據(jù)用戶設(shè)備顯示端的尺寸不同,自動調(diào)整盈江建站的顯示方式,使網(wǎng)站能夠適用不同顯示終端,在瀏覽器中調(diào)整網(wǎng)站的寬度,無論在任何一種瀏覽器上瀏覽網(wǎng)站,都能展現(xiàn)優(yōu)雅布局與設(shè)計(jì),從而大程度地提升瀏覽體驗(yàn)。創(chuàng)新互聯(lián)建站從事“盈江網(wǎng)站設(shè)計(jì)”,“盈江網(wǎng)站推廣”以來,每個(gè)客戶項(xiàng)目都認(rèn)真落實(shí)執(zhí)行。model.state_dict()
是淺拷貝,返回的參數(shù)仍然會隨著網(wǎng)絡(luò)的訓(xùn)練而變化。應(yīng)該使用deepcopy(model.state_dict())
,或?qū)?shù)及時(shí)序列化到硬盤。
再講故事,前幾天在做一個(gè)模型的交叉驗(yàn)證訓(xùn)練時(shí),通過model.state_dict()保存了每一組交叉驗(yàn)證模型的參數(shù),后根據(jù)效果選擇準(zhǔn)確率很好的模型load回去,結(jié)果每一次都是最后一個(gè)模型,從地址來看,每一個(gè)保存的state_dict()都具有不同的地址,但進(jìn)一步發(fā)現(xiàn)state_dict()下的各個(gè)模型參數(shù)的地址是共享的,而我又使用了in-place的方式重置模型參數(shù),進(jìn)而導(dǎo)致了上述問題。
補(bǔ)充:pytorch中state_dict的理解
在PyTorch中,state_dict是一個(gè)Python字典對象(在這個(gè)有序字典中,key是各層參數(shù)名,value是各層參數(shù)),包含模型的可學(xué)習(xí)參數(shù)(即權(quán)重和偏差,以及bn層的的參數(shù)) 優(yōu)化器對象(torch.optim)也具有state_dict,其中包含有關(guān)優(yōu)化器狀態(tài)以及所用超參數(shù)的信息。
import torch import torch.nn as nn import torchvision import numpy as np from torchsummary import summary # Define model class TheModelClass(nn.Module): def __init__(self): super(TheModelClass, self).__init__() self.conv1 = nn.Conv2d(3, 6, 5) self.pool = nn.MaxPool2d(2, 2) self.conv2 = nn.Conv2d(6, 16, 5) self.fc1 = nn.Linear(16 * 5 * 5, 120) self.fc2 = nn.Linear(120, 84) self.fc3 = nn.Linear(84, 10) def forward(self, x): x = self.pool(F.relu(self.conv1(x))) x = self.pool(F.relu(self.conv2(x))) x = x.view(-1, 16 * 5 * 5) x = F.relu(self.fc1(x)) x = F.relu(self.fc2(x)) x = self.fc3(x) return x # Initialize model model = TheModelClass() # Initialize optimizer optimizer = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9) # Print model's state_dict print("Model's state_dict:") for param_tensor in model.state_dict(): print(param_tensor,"\t", model.state_dict()[param_tensor].size()) # Print optimizer's state_dict print("Optimizer's state_dict:") for var_name in optimizer.state_dict(): print(var_name, "\t", optimizer.state_dict()[var_name])
輸出如下:
Model's state_dict: conv1.weight torch.Size([6, 3, 5, 5]) conv1.bias torch.Size([6]) conv2.weight torch.Size([16, 6, 5, 5]) conv2.bias torch.Size([16]) fc1.weight torch.Size([120, 400]) fc1.bias torch.Size([120]) fc2.weight torch.Size([84, 120]) fc2.bias torch.Size([84]) fc3.weight torch.Size([10, 84]) fc3.bias torch.Size([10]) Optimizer's state_dict: state {} param_groups [{'lr': 0.001, 'momentum': 0.9, 'dampening': 0, 'weight_decay': 0, 'nesterov': False, 'params': [2238501264336, 2238501329800, 2238501330016, 2238501327136, 2238501328576, 2238501329728, 2238501327928, 2238501327064, 2238501330808, 2238501328288]}]
我是剛接觸深度學(xué)西的小白一個(gè),希望大佬可以為我指出我的不足,此博客僅為自己的筆記?。。。?/p>
補(bǔ)充:pytorch保存模型時(shí)報(bào)錯(cuò)***object has no attribute 'state_dict'
net=BaseNet()
保存net時(shí)報(bào)錯(cuò) object has no attribute 'state_dict'
torch.save(net.state_dict(), models_dir)
原因是定義類的時(shí)候不是繼承nn.Module類,比如:
class BaseNet(object): def __init__(self):
class BaseNet(nn.Module): def __init__(self): super(BaseNet, self).__init__()
看完上述內(nèi)容是否對您有幫助呢?如果還想對相關(guān)知識有進(jìn)一步的了解或閱讀更多相關(guān)文章,請關(guān)注創(chuàng)新互聯(lián)行業(yè)資訊頻道,感謝您對創(chuàng)新互聯(lián)網(wǎng)站建設(shè)公司,的支持。
當(dāng)前文章:如何在pytorch中解決state_dict()的拷貝問題-創(chuàng)新互聯(lián)
網(wǎng)站鏈接:http://muchs.cn/article36/dsgdsg.html
成都網(wǎng)站建設(shè)公司_創(chuàng)新互聯(lián),為您提供軟件開發(fā)、品牌網(wǎng)站設(shè)計(jì)、標(biāo)簽優(yōu)化、云服務(wù)器、動態(tài)網(wǎng)站、電子商務(wù)
聲明:本網(wǎng)站發(fā)布的內(nèi)容(圖片、視頻和文字)以用戶投稿、用戶轉(zhuǎn)載內(nèi)容為主,如果涉及侵權(quán)請盡快告知,我們將會在第一時(shí)間刪除。文章觀點(diǎn)不代表本網(wǎng)站立場,如需處理請聯(lián)系客服。電話:028-86922220;郵箱:631063699@qq.com。內(nèi)容未經(jīng)允許不得轉(zhuǎn)載,或轉(zhuǎn)載時(shí)需注明來源: 創(chuàng)新互聯(lián)
猜你還喜歡下面的內(nèi)容