這期內(nèi)容當(dāng)中小編將會給大家?guī)碛嘘P(guān)如何在Pytorch 中使用retain_graph,文章內(nèi)容豐富且以專業(yè)的角度為大家分析和敘述,閱讀完這篇文章希望大家可以有所收獲。
成都創(chuàng)新互聯(lián)公司專注為客戶提供全方位的互聯(lián)網(wǎng)綜合服務(wù),包含不限于網(wǎng)站設(shè)計(jì)、成都網(wǎng)站制作、南陵網(wǎng)絡(luò)推廣、微信小程序定制開發(fā)、南陵網(wǎng)絡(luò)營銷、南陵企業(yè)策劃、南陵品牌公關(guān)、搜索引擎seo、人物專訪、企業(yè)宣傳片、企業(yè)代運(yùn)營等,從售前售中售后,我們都將竭誠為您服務(wù),您的肯定,是我們大的嘉獎;成都創(chuàng)新互聯(lián)公司為所有大學(xué)生創(chuàng)業(yè)者提供南陵建站搭建服務(wù),24小時(shí)服務(wù)熱線:18982081108,官方網(wǎng)址:muchs.cn用法分析
在查看SRGAN源碼時(shí)有如下?lián)p失函數(shù),其中設(shè)置了retain_graph=True,其作用是什么?
############################ # (1) Update D network: maximize D(x)-1-D(G(z)) ########################### real_img = Variable(target) if torch.cuda.is_available(): real_img = real_img.cuda() z = Variable(data) if torch.cuda.is_available(): z = z.cuda() fake_img = netG(z) netD.zero_grad() real_out = netD(real_img).mean() fake_out = netD(fake_img).mean() d_loss = 1 - real_out + fake_out d_loss.backward(retain_graph=True) ##### optimizerD.step() ############################ # (2) Update G network: minimize 1-D(G(z)) + Perception Loss + Image Loss + TV Loss ########################### netG.zero_grad() g_loss = generator_criterion(fake_out, fake_img, real_img) g_loss.backward() optimizerG.step() fake_img = netG(z) fake_out = netD(fake_img).mean() g_loss = generator_criterion(fake_out, fake_img, real_img) running_results['g_loss'] += g_loss.data[0] * batch_size d_loss = 1 - real_out + fake_out running_results['d_loss'] += d_loss.data[0] * batch_size running_results['d_score'] += real_out.data[0] * batch_size running_results['g_score'] += fake_out.data[0] * batch_size
在更新D網(wǎng)絡(luò)時(shí)的loss反向傳播過程中使用了retain_graph=True,目的為是為保留該過程中計(jì)算的梯度,后續(xù)G網(wǎng)絡(luò)更新時(shí)使用;
其實(shí)retain_graph這個(gè)參數(shù)在平常中我們是用不到的,但是在特殊的情況下我們會用到它,
如下代碼:
import torch y=x**2 z=y*4 output1=z.mean() output2=z.sum() output1.backward() output2.backward()
輸出如下錯(cuò)誤信息:
--------------------------------------------------------------------------- RuntimeError Traceback (most recent call last) <ipython-input-19-8ad6b0658906> in <module>() ----> 1 output1.backward() 2 output2.backward() D:\ProgramData\Anaconda3\lib\site-packages\torch\tensor.py in backward(self, gradient, retain_graph, create_graph) 91 products. Defaults to ``False``. 92 """ ---> 93 torch.autograd.backward(self, gradient, retain_graph, create_graph) 94 95 def register_hook(self, hook): D:\ProgramData\Anaconda3\lib\site-packages\torch\autograd\__init__.py in backward(tensors, grad_tensors, retain_graph, create_graph, grad_variables) 88 Variable._execution_engine.run_backward( 89 tensors, grad_tensors, retain_graph, create_graph, ---> 90 allow_unreachable=True) # allow_unreachable flag 91 92 RuntimeError: Trying to backward through the graph a second time, but the buffers have already been freed. Specify retain_graph=True when calling backward the first time.
修改成如下正確:
import torch y=x**2 z=y*4 output1=z.mean() output2=z.sum() output1.backward(retain_graph=True) output2.backward()
# 假如你有兩個(gè)Loss,先執(zhí)行第一個(gè)的backward,再執(zhí)行第二個(gè)backward loss1.backward(retain_graph=True) loss2.backward() # 執(zhí)行完這個(gè)后,所有中間變量都會被釋放,以便下一次的循環(huán) optimizer.step() # 更新參數(shù)
Variable 類源代碼
class Variable(_C._VariableBase): """ Attributes: data: 任意類型的封裝好的張量。 grad: 保存與data類型和位置相匹配的梯度,此屬性難以分配并且不能重新分配。 requires_grad: 標(biāo)記變量是否已經(jīng)由一個(gè)需要調(diào)用到此變量的子圖創(chuàng)建的bool值。只能在葉子變量上進(jìn)行修改。 volatile: 標(biāo)記變量是否能在推理模式下應(yīng)用(如不保存歷史記錄)的bool值。只能在葉變量上更改。 is_leaf: 標(biāo)記變量是否是圖葉子(如由用戶創(chuàng)建的變量)的bool值. grad_fn: Gradient function graph trace. Parameters: data (any tensor class): 要包裝的張量. requires_grad (bool): bool型的標(biāo)記值. **Keyword only.** volatile (bool): bool型的標(biāo)記值. **Keyword only.** """ def backward(self, gradient=None, retain_graph=None, create_graph=None, retain_variables=None): """計(jì)算關(guān)于當(dāng)前圖葉子變量的梯度,圖使用鏈?zhǔn)椒▌t導(dǎo)致分化 如果Variable是一個(gè)標(biāo)量(例如它包含一個(gè)單元素?cái)?shù)據(jù)),你無需對backward()指定任何參數(shù) 如果變量不是標(biāo)量(包含多個(gè)元素?cái)?shù)據(jù)的矢量)且需要梯度,函數(shù)需要額外的梯度; 需要指定一個(gè)和tensor的形狀匹配的grad_output參數(shù)(y在指定方向投影對x的導(dǎo)數(shù)); 可以是一個(gè)類型和位置相匹配且包含與自身相關(guān)的不同函數(shù)梯度的張量。 函數(shù)在葉子上累積梯度,調(diào)用前需要對該葉子進(jìn)行清零。 Arguments: grad_variables (Tensor, Variable or None): 變量的梯度,如果是一個(gè)張量,除非“create_graph”是True,否則會自動轉(zhuǎn)換成volatile型的變量。 可以為標(biāo)量變量或不需要grad的值指定None值。如果None值可接受,則此參數(shù)可選。 retain_graph (bool, optional): 如果為False,用來計(jì)算梯度的圖將被釋放。 在幾乎所有情況下,將此選項(xiàng)設(shè)置為True不是必需的,通??梢砸愿行У姆绞浇鉀Q。 默認(rèn)值為create_graph的值。 create_graph (bool, optional): 為True時(shí),會構(gòu)造一個(gè)導(dǎo)數(shù)的圖,用來計(jì)算出更高階導(dǎo)數(shù)結(jié)果。 默認(rèn)為False,除非``gradient``是一個(gè)volatile變量。 """ torch.autograd.backward(self, gradient, retain_graph, create_graph, retain_variables) def register_hook(self, hook): """Registers a backward hook. 每當(dāng)與variable相關(guān)的梯度被計(jì)算時(shí)調(diào)用hook,hook的申明:hook(grad)->Variable or None 不能對hook的參數(shù)進(jìn)行修改,但可以選擇性地返回一個(gè)新的梯度以用在`grad`的相應(yīng)位置。 函數(shù)返回一個(gè)handle,其``handle.remove()``方法用于將hook從模塊中移除。 Example: >>> v = Variable(torch.Tensor([0, 0, 0]), requires_grad=True) >>> h = v.register_hook(lambda grad: grad * 2) # double the gradient >>> v.backward(torch.Tensor([1, 1, 1])) >>> v.grad.data 2 2 2 [torch.FloatTensor of size 3] >>> h.remove() # removes the hook """ if self.volatile: raise RuntimeError("cannot register a hook on a volatile variable") if not self.requires_grad: raise RuntimeError("cannot register a hook on a variable that " "doesn't require gradient") if self._backward_hooks is None: self._backward_hooks = OrderedDict() if self.grad_fn is not None: self.grad_fn._register_hook_dict(self) handle = hooks.RemovableHandle(self._backward_hooks) self._backward_hooks[handle.id] = hook return handle def reinforce(self, reward): """Registers a reward obtained as a result of a stochastic process. 區(qū)分隨機(jī)節(jié)點(diǎn)需要為他們提供reward值。如果圖表中包含任何的隨機(jī)操作,都應(yīng)該在其輸出上調(diào)用此函數(shù),否則會出現(xiàn)錯(cuò)誤。 Parameters: reward(Tensor): 帶有每個(gè)元素獎賞的張量,必須與Variable數(shù)據(jù)的設(shè)備位置和形狀相匹配。 """ if not isinstance(self.grad_fn, StochasticFunction): raise RuntimeError("reinforce() can be only called on outputs " "of stochastic functions") self.grad_fn._reinforce(reward) def detach(self): """返回一個(gè)從當(dāng)前圖分離出來的心變量。 結(jié)果不需要梯度,如果輸入是volatile,則輸出也是volatile。 .. 注意:: 返回變量使用與原始變量相同的數(shù)據(jù)張量,并且可以看到其中任何一個(gè)的就地修改,并且可能會觸發(fā)正確性檢查中的錯(cuò)誤。 """ result = NoGrad()(self) # this is needed, because it merges version counters result._grad_fn = None return result def detach_(self): """從創(chuàng)建它的圖中分離出變量并作為該圖的一個(gè)葉子""" self._grad_fn = None self.requires_grad = False def retain_grad(self): """Enables .grad attribute for non-leaf Variables.""" if self.grad_fn is None: # no-op for leaves return if not self.requires_grad: raise RuntimeError("can't retain_grad on Variable that has requires_grad=False") if hasattr(self, 'retains_grad'): return weak_self = weakref.ref(self) def retain_grad_hook(grad): var = weak_self() if var is None: return if var._grad is None: var._grad = grad.clone() else: var._grad = var._grad + grad self.register_hook(retain_grad_hook) self.retains_grad = True
上述就是小編為大家分享的如何在Pytorch 中使用retain_graph了,如果剛好有類似的疑惑,不妨參照上述分析進(jìn)行理解。如果想知道更多相關(guān)知識,歡迎關(guān)注創(chuàng)新互聯(lián)行業(yè)資訊頻道。
文章題目:如何在Pytorch中使用retain_graph-創(chuàng)新互聯(lián)
網(wǎng)站地址:http://muchs.cn/article26/degjcg.html
成都網(wǎng)站建設(shè)公司_創(chuàng)新互聯(lián),為您提供Google、虛擬主機(jī)、建站公司、網(wǎng)站內(nèi)鏈、面包屑導(dǎo)航、網(wǎng)站建設(shè)
聲明:本網(wǎng)站發(fā)布的內(nèi)容(圖片、視頻和文字)以用戶投稿、用戶轉(zhuǎn)載內(nèi)容為主,如果涉及侵權(quán)請盡快告知,我們將會在第一時(shí)間刪除。文章觀點(diǎn)不代表本網(wǎng)站立場,如需處理請聯(lián)系客服。電話:028-86922220;郵箱:631063699@qq.com。內(nèi)容未經(jīng)允許不得轉(zhuǎn)載,或轉(zhuǎn)載時(shí)需注明來源: 創(chuàng)新互聯(lián)
猜你還喜歡下面的內(nèi)容