pytorch:怎么實(shí)現(xiàn)簡(jiǎn)單的GAN-創(chuàng)新互聯(lián)

小編給大家分享一下pytorch:怎么實(shí)現(xiàn)簡(jiǎn)單的GAN,相信大部分人都還不怎么了解,因此分享這篇文章給大家參考一下,希望大家閱讀完這篇文章后大有收獲,下面讓我們一起去了解一下吧!

成都創(chuàng)新互聯(lián)專注為客戶提供全方位的互聯(lián)網(wǎng)綜合服務(wù),包含不限于成都做網(wǎng)站、成都網(wǎng)站設(shè)計(jì)、沭陽(yáng)網(wǎng)絡(luò)推廣、微信小程序定制開(kāi)發(fā)、沭陽(yáng)網(wǎng)絡(luò)營(yíng)銷、沭陽(yáng)企業(yè)策劃、沭陽(yáng)品牌公關(guān)、搜索引擎seo、人物專訪、企業(yè)宣傳片、企業(yè)代運(yùn)營(yíng)等,從售前售中售后,我們都將竭誠(chéng)為您服務(wù),您的肯定,是我們大的嘉獎(jiǎng);成都創(chuàng)新互聯(lián)為所有大學(xué)生創(chuàng)業(yè)者提供沭陽(yáng)建站搭建服務(wù),24小時(shí)服務(wù)熱線:13518219792,官方網(wǎng)址:muchs.cn

代碼如下

# -*- coding: utf-8 -*-
"""
Created on Sat Oct 13 10:22:45 2018
"""
 
import torch
from torch import nn
from torch.autograd import Variable
 
import torchvision.transforms as tfs
from torch.utils.data import DataLoader, sampler
from torchvision.datasets import MNIST
 
import numpy as np
 
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
 
plt.rcParams['figure.figsize'] = (10.0, 8.0) # 設(shè)置畫圖的尺寸
plt.rcParams['image.interpolation'] = 'nearest'
plt.rcParams['image.cmap'] = 'gray'
 
def show_images(images): # 定義畫圖工具
  images = np.reshape(images, [images.shape[0], -1])
  sqrtn = int(np.ceil(np.sqrt(images.shape[0])))
  sqrtimg = int(np.ceil(np.sqrt(images.shape[1])))
 
  fig = plt.figure(figsize=(sqrtn, sqrtn))
  gs = gridspec.GridSpec(sqrtn, sqrtn)
  gs.update(wspace=0.05, hspace=0.05)
 
  for i, img in enumerate(images):
    ax = plt.subplot(gs[i])
    plt.axis('off')
    ax.set_xticklabels([])
    ax.set_yticklabels([])
    ax.set_aspect('equal')
    plt.imshow(img.reshape([sqrtimg,sqrtimg]))
  return 
  
def preprocess_img(x):
  x = tfs.ToTensor()(x)
  return (x - 0.5) / 0.5
 
def deprocess_img(x):
  return (x + 1.0) / 2.0
 
class ChunkSampler(sampler.Sampler): # 定義一個(gè)取樣的函數(shù)
  """Samples elements sequentially from some offset. 
  Arguments:
    num_samples: # of desired datapoints
    start: offset where we should start selecting from
  """
  def __init__(self, num_samples, start=0):
    self.num_samples = num_samples
    self.start = start
 
  def __iter__(self):
    return iter(range(self.start, self.start + self.num_samples))
 
  def __len__(self):
    return self.num_samples
    
NUM_TRAIN = 50000
NUM_VAL = 5000
 
NOISE_DIM = 96
batch_size = 128
 
train_set = MNIST('E:/data', train=True, transform=preprocess_img)
 
train_data = DataLoader(train_set, batch_size=batch_size, sampler=ChunkSampler(NUM_TRAIN, 0))
 
val_set = MNIST('E:/data', train=True, transform=preprocess_img)
 
val_data = DataLoader(val_set, batch_size=batch_size, sampler=ChunkSampler(NUM_VAL, NUM_TRAIN))
 
imgs = deprocess_img(train_data.__iter__().next()[0].view(batch_size, 784)).numpy().squeeze() # 可視化圖片效果
show_images(imgs)
 
#判別網(wǎng)絡(luò)
def discriminator():
  net = nn.Sequential(    
      nn.Linear(784, 256),
      nn.LeakyReLU(0.2),
      nn.Linear(256, 256),
      nn.LeakyReLU(0.2),
      nn.Linear(256, 1)
    )
  return net
  
#生成網(wǎng)絡(luò)
def generator(noise_dim=NOISE_DIM):  
  net = nn.Sequential(
    nn.Linear(noise_dim, 1024),
    nn.ReLU(True),
    nn.Linear(1024, 1024),
    nn.ReLU(True),
    nn.Linear(1024, 784),
    nn.Tanh()
  )
  return net
  
#判別器的 loss 就是將真實(shí)數(shù)據(jù)的得分判斷為 1,假的數(shù)據(jù)的得分判斷為 0,而生成器的 loss 就是將假的數(shù)據(jù)判斷為 1
 
bce_loss = nn.BCEWithLogitsLoss()#交叉熵?fù)p失函數(shù)
 
def discriminator_loss(logits_real, logits_fake): # 判別器的 loss
  size = logits_real.shape[0]
  true_labels = Variable(torch.ones(size, 1)).float()
  false_labels = Variable(torch.zeros(size, 1)).float()
  loss = bce_loss(logits_real, true_labels) + bce_loss(logits_fake, false_labels)
  return loss
  
def generator_loss(logits_fake): # 生成器的 loss 
  size = logits_fake.shape[0]
  true_labels = Variable(torch.ones(size, 1)).float()
  loss = bce_loss(logits_fake, true_labels)
  return loss
  
# 使用 adam 來(lái)進(jìn)行訓(xùn)練,學(xué)習(xí)率是 3e-4, beta1 是 0.5, beta2 是 0.999
def get_optimizer(net):
  optimizer = torch.optim.Adam(net.parameters(), lr=3e-4, betas=(0.5, 0.999))
  return optimizer
  
def train_a_gan(D_net, G_net, D_optimizer, G_optimizer, discriminator_loss, generator_loss, show_every=250, 
        noise_size=96, num_epochs=10):
  iter_count = 0
  for epoch in range(num_epochs):
    for x, _ in train_data:
      bs = x.shape[0]
      # 判別網(wǎng)絡(luò)
      real_data = Variable(x).view(bs, -1) # 真實(shí)數(shù)據(jù)
      logits_real = D_net(real_data) # 判別網(wǎng)絡(luò)得分
      
      sample_noise = (torch.rand(bs, noise_size) - 0.5) / 0.5 # -1 ~ 1 的均勻分布
      g_fake_seed = Variable(sample_noise)
      fake_images = G_net(g_fake_seed) # 生成的假的數(shù)據(jù)
      logits_fake = D_net(fake_images) # 判別網(wǎng)絡(luò)得分
 
      d_total_error = discriminator_loss(logits_real, logits_fake) # 判別器的 loss
      D_optimizer.zero_grad()
      d_total_error.backward()
      D_optimizer.step() # 優(yōu)化判別網(wǎng)絡(luò)
      
      # 生成網(wǎng)絡(luò)
      g_fake_seed = Variable(sample_noise)
      fake_images = G_net(g_fake_seed) # 生成的假的數(shù)據(jù)
 
      gen_logits_fake = D_net(fake_images)
      g_error = generator_loss(gen_logits_fake) # 生成網(wǎng)絡(luò)的 loss
      G_optimizer.zero_grad()
      g_error.backward()
      G_optimizer.step() # 優(yōu)化生成網(wǎng)絡(luò)
 
      if (iter_count % show_every == 0):
        print('Iter: {}, D: {:.4}, G:{:.4}'.format(iter_count, d_total_error.item(), g_error.item()))
        imgs_numpy = deprocess_img(fake_images.data.cpu().numpy())
        show_images(imgs_numpy[0:16])
        plt.show()
        print()
      iter_count += 1
 
D = discriminator()
G = generator()
 
D_optim = get_optimizer(D)
G_optim = get_optimizer(G)
 
train_a_gan(D, G, D_optim, G_optim, discriminator_loss, generator_loss)

以上是“pytorch:怎么實(shí)現(xiàn)簡(jiǎn)單的GAN”這篇文章的所有內(nèi)容,感謝各位的閱讀!相信大家都有了一定的了解,希望分享的內(nèi)容對(duì)大家有所幫助,如果還想學(xué)習(xí)更多知識(shí),歡迎關(guān)注創(chuàng)新互聯(lián)行業(yè)資訊頻道!

分享文章:pytorch:怎么實(shí)現(xiàn)簡(jiǎn)單的GAN-創(chuàng)新互聯(lián)
文章分享:http://muchs.cn/article44/degjee.html

成都網(wǎng)站建設(shè)公司_創(chuàng)新互聯(lián),為您提供建站公司網(wǎng)站制作、商城網(wǎng)站、網(wǎng)站收錄、外貿(mào)建站網(wǎng)站建設(shè)

廣告

聲明:本網(wǎng)站發(fā)布的內(nèi)容(圖片、視頻和文字)以用戶投稿、用戶轉(zhuǎn)載內(nèi)容為主,如果涉及侵權(quán)請(qǐng)盡快告知,我們將會(huì)在第一時(shí)間刪除。文章觀點(diǎn)不代表本網(wǎng)站立場(chǎng),如需處理請(qǐng)聯(lián)系客服。電話:028-86922220;郵箱:631063699@qq.com。內(nèi)容未經(jīng)允許不得轉(zhuǎn)載,或轉(zhuǎn)載時(shí)需注明來(lái)源: 創(chuàng)新互聯(lián)

成都網(wǎng)站建設(shè)公司