Pytorch:自定義網(wǎng)絡層實例-創(chuàng)新互聯(lián)

自定義Autograd函數(shù)

成都創(chuàng)新互聯(lián)公司成立于2013年,是專業(yè)互聯(lián)網(wǎng)技術服務公司,擁有項目成都網(wǎng)站建設、成都網(wǎng)站制作網(wǎng)站策劃,項目實施與項目整合能力。我們以讓每一個夢想脫穎而出為使命,1280元衛(wèi)東做網(wǎng)站,已為上家服務,為衛(wèi)東各地企業(yè)和個人服務,聯(lián)系電話:13518219792

對于淺層的網(wǎng)絡,我們可以手動的書寫前向傳播和反向傳播過程。但是當網(wǎng)絡變得很大時,特別是在做深度學習時,網(wǎng)絡結構變得復雜。前向傳播和反向傳播也隨之變得復雜,手動書寫這兩個過程就會存在很大的困難。幸運地是在pytorch中存在了自動微分的包,可以用來解決該問題。在使用自動求導的時候,網(wǎng)絡的前向傳播會定義一個計算圖(computational graph),圖中的節(jié)點是張量(tensor),兩個節(jié)點之間的邊對應了兩個張量之間變換關系的函數(shù)。有了計算圖的存在,張量的梯度計算也變得容易了些。例如, x是一個張量,其屬性 x.requires_grad = True,那么 x.grad就是一個保存這個張量x的梯度的一些標量值。

最基礎的自動求導操作在底層就是作用在兩個張量上。前向傳播函數(shù)是從輸入張量到輸出張量的計算過程;反向傳播是輸入輸出張量的梯度(一些標量)并輸出輸入張量的梯度(一些標量)。在pytorch中我們可以很容易地定義自己的自動求導操作,通過繼承torch.autograd.Function并定義forward和backward函數(shù)。

forward(): 前向傳播操作??梢暂斎肴我舛嗟膮?shù),任意的python對象都可以。

backward():反向傳播(梯度公式)。輸出的梯度個數(shù)需要與所使用的張量個數(shù)保持一致,且返回的順序也要對應起來。

# Inherit from Function
class LinearFunction(Function):

  # Note that both forward and backward are @staticmethods
  @staticmethod
  # bias is an optional argument
  def forward(ctx, input, weight, bias=None):
    # ctx在這里類似self,ctx的屬性可以在backward中調(diào)用
    ctx.save_for_backward(input, weight, bias)
    output = input.mm(weight.t())
    if bias is not None:
      output += bias.unsqueeze(0).expand_as(output)
    return output

  # This function has only a single output, so it gets only one gradient
  @staticmethod
  def backward(ctx, grad_output):
    # This is a pattern that is very convenient - at the top of backward
    # unpack saved_tensors and initialize all gradients w.r.t. inputs to
    # None. Thanks to the fact that additional trailing Nones are
    # ignored, the return statement is simple even when the function has
    # optional inputs.
    input, weight, bias = ctx.saved_tensors
    grad_input = grad_weight = grad_bias = None

    # These needs_input_grad checks are optional and there only to
    # improve efficiency. If you want to make your code simpler, you can
    # skip them. Returning gradients for inputs that don't require it is
    # not an error.
    if ctx.needs_input_grad[0]:
      grad_input = grad_output.mm(weight)
    if ctx.needs_input_grad[1]:
      grad_weight = grad_output.t().mm(input)
    if bias is not None and ctx.needs_input_grad[2]:
      grad_bias = grad_output.sum(0).squeeze(0)

    return grad_input, grad_weight, grad_bias

#調(diào)用自定義的自動求導函數(shù)
linear = LinearFunction.apply(*args) #前向傳播
linear.backward()#反向傳播
linear.grad_fn.apply(*args)#反向傳播

另外有需要云服務器可以了解下創(chuàng)新互聯(lián)scvps.cn,海內(nèi)外云服務器15元起步,三天無理由+7*72小時售后在線,公司持有idc許可證,提供“云服務器、裸金屬服務器、高防服務器、香港服務器、美國服務器、虛擬主機、免備案服務器”等云主機租用服務以及企業(yè)上云的綜合解決方案,具有“安全穩(wěn)定、簡單易用、服務可用性高、性價比高”等特點與優(yōu)勢,專為企業(yè)上云打造定制,能夠滿足用戶豐富、多元化的應用場景需求。

網(wǎng)頁題目:Pytorch:自定義網(wǎng)絡層實例-創(chuàng)新互聯(lián)
網(wǎng)頁URL:http://muchs.cn/article22/cdsocc.html

成都網(wǎng)站建設公司_創(chuàng)新互聯(lián),為您提供微信公眾號Google、微信小程序、服務器托管軟件開發(fā)、外貿(mào)網(wǎng)站建設

廣告

聲明:本網(wǎng)站發(fā)布的內(nèi)容(圖片、視頻和文字)以用戶投稿、用戶轉載內(nèi)容為主,如果涉及侵權請盡快告知,我們將會在第一時間刪除。文章觀點不代表本網(wǎng)站立場,如需處理請聯(lián)系客服。電話:028-86922220;郵箱:631063699@qq.com。內(nèi)容未經(jīng)允許不得轉載,或轉載時需注明來源: 創(chuàng)新互聯(lián)

營銷型網(wǎng)站建設