batchnorm2d參數(shù)torch_Pytorch自由載入部分模型參數(shù)并凍結(jié)的示例分析

這篇文章給大家介紹batchnorm2d參數(shù) torch_Pytorch自由載入部分模型參數(shù)并凍結(jié)的示例分析,內(nèi)容非常詳細(xì),感興趣的小伙伴們可以參考借鑒,希望對(duì)大家能有所幫助。

專(zhuān)業(yè)領(lǐng)域包括成都網(wǎng)站設(shè)計(jì)、成都網(wǎng)站建設(shè)、成都做商城網(wǎng)站、微信營(yíng)銷(xiāo)、系統(tǒng)平臺(tái)開(kāi)發(fā), 與其他網(wǎng)站設(shè)計(jì)及系統(tǒng)開(kāi)發(fā)公司不同,創(chuàng)新互聯(lián)的整合解決方案結(jié)合了幫做網(wǎng)絡(luò)品牌建設(shè)經(jīng)驗(yàn)和互聯(lián)網(wǎng)整合營(yíng)銷(xiāo)的理念,并將策略和執(zhí)行緊密結(jié)合,為客戶提供全網(wǎng)互聯(lián)網(wǎng)整合方案。

Pytorch的load方法和load_state_dict方法只能較為固定的讀入?yún)?shù)文件,他們要求讀入的state_dict的key和Model.state_dict()的key對(duì)應(yīng)相等。

而我們?cè)谶M(jìn)行遷移學(xué)習(xí)的過(guò)程中也許只需要使用某個(gè)預(yù)訓(xùn)練網(wǎng)絡(luò)的一部分,把多個(gè)網(wǎng)絡(luò)拼和成一個(gè)網(wǎng)絡(luò),或者為了得到中間層的輸出而分離預(yù)訓(xùn)練模型中的Sequential 等等,這些情況下。傳統(tǒng)的load方法就不是很有效了。

例如,我們想利用Mobilenet的前7個(gè)卷積并把這幾層凍結(jié),后面的部分接別的結(jié)構(gòu),或者改寫(xiě)成FCN結(jié)構(gòu),傳統(tǒng)的方法就不奏效了。

最普適的方法是:構(gòu)建一個(gè)字典,使得字典的keys和我們自己創(chuàng)建的網(wǎng)絡(luò)相同,我們?cè)購(gòu)母鞣N預(yù)訓(xùn)練網(wǎng)絡(luò)把想要的參數(shù)對(duì)著新的keys填進(jìn)去就可以有一個(gè)新的state_dict了,這樣我們就可以load這個(gè)新的state_dict,目前只能想到這個(gè)方法應(yīng)對(duì)較為復(fù)雜的網(wǎng)絡(luò)變換。

網(wǎng)上查“載入部分模型”,“凍結(jié)部分模型”一般都是只改個(gè)FC,根本沒(méi)有用,初學(xué)的時(shí)候自己寫(xiě)state_dict也踩了一些坑,發(fā)出來(lái)記錄一下。


一.載入部分預(yù)訓(xùn)練參數(shù)

我們先看看Mobilenet的結(jié)構(gòu)

( 來(lái)源github,附帶預(yù)訓(xùn)練模型mobilenet_sgd_rmsprop_69.526.tar)

class Net(nn.Module):

def __init__(self):

super(Net, self).__init__()


def conv_bn(inp, oup, stride):

return nn.Sequential(

nn.Conv2d(inp, oup, 3, stride, 1, bias=False),

nn.BatchNorm2d(oup),

nn.ReLU(inplace=True)

)


def conv_dw(inp, oup, stride):

return nn.Sequential(

nn.Conv2d(inp, inp, 3, stride, 1, groups=inp, bias=False),

nn.BatchNorm2d(inp),

nn.ReLU(inplace=True),


nn.Conv2d(inp, oup, 1, 1, 0, bias=False),

nn.BatchNorm2d(oup),

nn.ReLU(inplace=True),

)


self.model = nn.Sequential(

conv_bn( 3, 32, 2),

conv_dw( 32, 64, 1),

conv_dw( 64, 128, 2),

conv_dw(128, 128, 1),

conv_dw(128, 256, 2),

conv_dw(256, 256, 1),

conv_dw(256, 512, 2),

conv_dw(512, 512, 1),

conv_dw(512, 512, 1),

conv_dw(512, 512, 1),

conv_dw(512, 512, 1),

conv_dw(512, 512, 1),

conv_dw(512, 1024, 2),

conv_dw(1024, 1024, 1),

nn.AvgPool2d(7),

)

self.fc = nn.Linear(1024, 1000)


def forward(self, x):

x = self.model(x)

x = x.view(-1, 1024)

x = self.fc(x)

return x

我們只需要前7層卷積,并且為了方便日后concate操作,我們把Sequential拆開(kāi),成為下面的樣子

class Net(nn.Module):

def __init__(self):

super(Net, self).__init__()


def conv_bn(inp, oup, stride):

return nn.Sequential(

nn.Conv2d(inp, oup, 3, stride, 1, bias=False),

nn.BatchNorm2d(oup),

nn.ReLU(inplace=True)

)


def conv_dw(inp, oup, stride):

return nn.Sequential(

nn.Conv2d(inp, inp, 3, stride, 1, groups=inp, bias=False),

nn.BatchNorm2d(inp),

nn.ReLU(inplace=True),


nn.Conv2d(inp, oup, 1, 1, 0, bias=False),

nn.BatchNorm2d(oup),

nn.ReLU(inplace=True),

)


self.conv1 = conv_bn( 3, 32, 2)

self.conv2 = conv_dw( 32, 64, 1)

self.conv3 = conv_dw( 64, 128, 2)

self.conv4 = conv_dw(128, 128, 1)

self.conv5 = conv_dw(128, 256, 2)

self.conv6 = conv_dw(256, 256, 1)

self.conv7 = conv_dw(256, 512, 2)


# 原來(lái)這些不要了

# 可以自己接后面的結(jié)構(gòu)

'''

self.features = nn.Sequential(

conv_dw(512, 512, 1),

conv_dw(512, 512, 1),

conv_dw(512, 512, 1),

conv_dw(512, 512, 1),

conv_dw(512, 512, 1),

conv_dw(512, 1024, 2),

conv_dw(1024, 1024, 1),

nn.AvgPool2d(7),)


self.fc = nn.Linear(1024, 1000)

'''


def forward(self, x):

x1 = self.conv1(x)

x2 = self.conv2(x1)

x3 = self.conv3(x2)

x4 = self.conv4(x3)

x5 = self.conv5(x4)

x6 = self.conv6(x5)

x7 = self.conv7(x6)

#x8 = self.features(x7)

#out = self.fc

return (x1,x2,x3,x4,x4,x6,x7)

我們更具改過(guò)的結(jié)構(gòu)創(chuàng)建一個(gè)net,看看他的state_dict和我們預(yù)訓(xùn)練文件的state_dict有啥區(qū)別

net = Net()

#我的電腦沒(méi)有GPU,他的參數(shù)是GPU訓(xùn)練的cudatensor,于是要下面這樣轉(zhuǎn)換一下

dict_trained = torch.load("mobilenet_sgd_rmsprop_69.526.tar",map_location=lambda storage, loc: storage)["state_dict"]

dict_new = net.state_dict().copy()


new_list = list (net.state_dict().keys() )

trained_list = list (dict_trained.keys() )

print("new_state_dict size: {} trained state_dict size: {}".format(len(new_list),len(trained_list)) )

print("New state_dict first 10th parameters names")

print(new_list[:10])

print("trained state_dict first 10th parameters names")

print(trained_list[:10])


print(type(dict_new))

print(type(dict_trained))

得到輸出如下:

我們截?cái)嘁话胫?,參?shù)由137變成65了,前十個(gè)參數(shù)看出,名字變了但是順序其實(shí)沒(méi)變。state_dict的數(shù)據(jù)類(lèi)型是Odict,可以按照dict的操作方法操作。

new_state_dict size: 65 trained state_dict size: 137
New state_dict first 10th parameters names
['conv1.0.weight', 'conv1.1.weight', 'conv1.1.bias', 'conv1.1.running_mean', 'conv1.1.running_var', 'conv2.0.weight', 'conv2.1.weight', 'conv2.1.bias', 'conv2.1.running_mean', 'conv2.1.running_var']
trained state_dict first 10th parameters names
['module.model.0.0.weight', 'module.model.0.1.weight', 'module.model.0.1.bias', 'module.model.0.1.running_mean', 'module.model.0.1.running_var', 'module.model.1.0.weight', 'module.model.1.1.weight', 'module.model.1.1.bias', 'module.model.1.1.running_mean', 'module.model.1.1.running_var']
<class 'collections.OrderedDict'>
<class 'collections.OrderedDict'>

我們看出只要構(gòu)建一個(gè)字典,使得字典的keys和我們自己創(chuàng)建的網(wǎng)絡(luò)相同,我們?cè)趶母鞣N預(yù)訓(xùn)練網(wǎng)絡(luò)把想要的參數(shù)對(duì)著新的keys填進(jìn)去就可以有一個(gè)新的state_dict了,這樣我們就可以load這個(gè)新的state_dict,這是最普適的方法適用于所有的網(wǎng)絡(luò)變化。

for i in range(65):

dict_new[ new_list[i] ] = dict_trained[ trained_list[i] ]


net.load_state_dict(dict_new)

還有別的情況,比如我們只是在后面加了一些層,沒(méi)有改變?cè)瓉?lái)網(wǎng)絡(luò)層的名字和結(jié)構(gòu),可以用下面的簡(jiǎn)便方法:

loaded_dict = {k: loaded_dict[k] for k, _ in model.state_dict()}

二.凍結(jié)這幾層參數(shù)

方法很多,這里用和上面方法對(duì)應(yīng)的凍結(jié)方法

發(fā)現(xiàn)之前的凍結(jié)有問(wèn)題,還是建議看一下

https://discuss.pytorch.org/t/how-the-pytorch-freeze-network-in-some-layers-only-the-rest-of-the-training/7088

或者

https://discuss.pytorch.org/t/correct-way-to-freeze-layers/26714

或者

對(duì)應(yīng)的,在訓(xùn)練時(shí)候,optimizer里面只能更新requires_grad = True的參數(shù),于是

optimizer = torch.optim.Adam( filter(lambda p: p.requires_grad, net.parameters(),lr) )

關(guān)于batchnorm2d參數(shù) torch_Pytorch自由載入部分模型參數(shù)并凍結(jié)的示例分析就分享到這里了,希望以上內(nèi)容可以對(duì)大家有一定的幫助,可以學(xué)到更多知識(shí)。如果覺(jué)得文章不錯(cuò),可以把它分享出去讓更多的人看到。

文章名稱:batchnorm2d參數(shù)torch_Pytorch自由載入部分模型參數(shù)并凍結(jié)的示例分析
分享地址:http://muchs.cn/article34/pidppe.html

成都網(wǎng)站建設(shè)公司_創(chuàng)新互聯(lián),為您提供建站公司、全網(wǎng)營(yíng)銷(xiāo)推廣、軟件開(kāi)發(fā)、品牌網(wǎng)站設(shè)計(jì)App開(kāi)發(fā)、做網(wǎng)站

廣告

聲明:本網(wǎng)站發(fā)布的內(nèi)容(圖片、視頻和文字)以用戶投稿、用戶轉(zhuǎn)載內(nèi)容為主,如果涉及侵權(quán)請(qǐng)盡快告知,我們將會(huì)在第一時(shí)間刪除。文章觀點(diǎn)不代表本網(wǎng)站立場(chǎng),如需處理請(qǐng)聯(lián)系客服。電話:028-86922220;郵箱:631063699@qq.com。內(nèi)容未經(jīng)允許不得轉(zhuǎn)載,或轉(zhuǎn)載時(shí)需注明來(lái)源: 創(chuàng)新互聯(lián)

h5響應(yīng)式網(wǎng)站建設(shè)