小編給大家分享一下PyTorch中加載數(shù)據(jù)集的示例分析,希望大家閱讀完這篇文章之后都有所收獲,下面讓我們一起去探討吧!
創(chuàng)新互聯(lián)長(zhǎng)期為上1000+客戶提供的網(wǎng)站建設(shè)服務(wù),團(tuán)隊(duì)從業(yè)經(jīng)驗(yàn)10年,關(guān)注不同地域、不同群體,并針對(duì)不同對(duì)象提供差異化的產(chǎn)品和服務(wù);打造開放共贏平臺(tái),與合作伙伴共同營(yíng)造健康的互聯(lián)網(wǎng)生態(tài)環(huán)境。為隴縣企業(yè)提供專業(yè)的成都做網(wǎng)站、網(wǎng)站制作,隴縣網(wǎng)站改版等技術(shù)服務(wù)。擁有十載豐富建站經(jīng)驗(yàn)和眾多成功案例,為您定制開發(fā)。數(shù)據(jù)預(yù)處理在解決深度學(xué)習(xí)問(wèn)題的過(guò)程中,往往需要花費(fèi)大量的時(shí)間和精力。 數(shù)據(jù)處理的質(zhì)量對(duì)訓(xùn)練神經(jīng)網(wǎng)絡(luò)來(lái)說(shuō)十分重要,良好的數(shù)據(jù)處理不僅會(huì)加速模型訓(xùn)練, 更會(huì)提高模型性能。為解決這一問(wèn)題,PyTorch提供了幾個(gè)高效便捷的工具, 以便使用者進(jìn)行數(shù)據(jù)處理或增強(qiáng)等操作,同時(shí)可通過(guò)并行化加速數(shù)據(jù)加載。
數(shù)據(jù)集存放大致有以下兩種方式:
(1)所有數(shù)據(jù)集放在一個(gè)目錄下,文件名上附有標(biāo)簽名,數(shù)據(jù)集存放格式如下: root/cat_dog/cat.01.jpg
root/cat_dog/cat.02.jpg
........................
root/cat_dog/dog.01.jpg
root/cat_dog/dog.02.jpg
......................
(2)不同類別的數(shù)據(jù)集放在不同目錄下,目錄名就是標(biāo)簽,數(shù)據(jù)集存放格式如下:
root/ants/xxx.png
root/ants/xxy.jpeg
root/ants/xxz.png
................
root/bees/123.jpg
root/bees/nsdf3.png
root/bees/asd932_.png
..................
1.1 對(duì)第1種數(shù)據(jù)集的處理步驟
(1)生成包含各文件名的列表(List)
(2)定義Dataset的一個(gè)子類,該子類需要繼承Dataset類,查看Dataset類的源碼
(3)重寫父類Dataset中的兩個(gè)魔法方法: 一個(gè)是: __lent__(self),其功能是len(Dataset),返回Dataset的樣本數(shù)。 另一個(gè)是__getitem__(self,index),其功能假設(shè)索引為i,使Dataset[i]返回第i個(gè)樣本。
(4)使用torch.utils.data.DataLoader加載數(shù)據(jù)集Dataset.
1.2 實(shí)例詳解
以下以cat-dog數(shù)據(jù)集為例,說(shuō)明如何實(shí)現(xiàn)自定義數(shù)據(jù)集的加載。
1.2.1 數(shù)據(jù)集結(jié)構(gòu)
所有數(shù)據(jù)集在cat-dog目錄下:
.\cat_dog\cat.01.jpg
.\cat_dog\cat.02.jpg
.\cat_dog\cat.03.jpg
....................
.\cat_dog\dog.01.jpg
.\cat_dog\dog.02.jpg
....................
1.2.2 導(dǎo)入需要用到的模塊
from torch.utils.data import DataLoader,Dataset from skimage import io,transform import matplotlib.pyplot as plt import oimport torch from torchvision import transforms, utils from PIL import Image import pandas as pd import numpy as np #過(guò)濾警告信息 import warnings warnings.filterwarnings("ignore")
1.2.3定義加載自定義數(shù)據(jù)的類
class MyDataset(Dataset): #繼承Dataset def __init__(self, path_dir, transform=None): #初始化一些屬性 self.path_dir = path_dir #文件路徑,如'.\data\cat-dog' self.transform = transform #對(duì)圖形進(jìn)行處理,如標(biāo)準(zhǔn)化、截取、轉(zhuǎn)換等 self.images = os.listdir(self.path_dir)#把路徑下的所有文件放在一個(gè)列表中 def __len__(self):#返回整個(gè)數(shù)據(jù)集的大小 return len(self.images) def __getitem__(self,index):#根據(jù)索引index返回圖像及標(biāo)簽 image_index = self.images[index]#根據(jù)索引獲取圖像文件名稱 img_path = os.path.join(self.path_dir, image_index)#獲取圖像的路徑或目錄 img = Image.open(img_path).convert('RGB')# 讀取圖像 # 根據(jù)目錄名稱獲取圖像標(biāo)簽(cat或dog) label = img_path.split('\\')[-1].split('.')[0] #把字符轉(zhuǎn)換為數(shù)字cat-0,dog-1 label = 1 if 'dog' in label else 0 if self.transform is not None: img = self.transform(img) return img,label
1.2.4 實(shí)例化類
dataset = MyDataset('.\data\cat-dog',transform=None) img, label = dataset[0] #將啟動(dòng)魔法方法__getitem__(0) print(type(img)) <class 'PIL.Image.Image'>
1.2.5 查看圖像形狀
i=1
for img, label in dataset:
if i
img的形狀(500, 374),label的值0
img的形狀(300, 280),label的值0
img的形狀(489, 499),label的值0
img的形狀(431, 410),label的值0
img的形狀(300, 224),label的值0
從上面返回樣本的形狀來(lái)看:
(1)每張圖片的大小不一樣,如果需要取batch訓(xùn)練的神經(jīng)網(wǎng)絡(luò)來(lái)說(shuō)很不友好。
(2)返回樣本的數(shù)值較大,未歸一化至[-1, 1]
為此需要對(duì)img進(jìn)行轉(zhuǎn)換,如何轉(zhuǎn)換?只要使用torchvision中的transforms即可
1.2.6 對(duì)圖像數(shù)據(jù)進(jìn)行處理
這里使用torchvision中的transforms模塊
from torchvision import transforms as T transform = T.Compose([ T.Resize(224), # 縮放圖片(Image),保持長(zhǎng)寬比不變,最短邊為224像素 T.CenterCrop(224), # 從圖片中間切出224*224的圖片 T.ToTensor(), # 將圖片(Image)轉(zhuǎn)成Tensor,歸一化至[0, 1] T.Normalize(mean=[.5, .5, .5], std=[.5, .5, .5]) # 標(biāo)準(zhǔn)化至[-1, 1],規(guī)定均值和標(biāo)準(zhǔn)差 ])
1.2.7查看處理后的數(shù)據(jù)
dataset = MyDataset('.\data\cat-dog',transform=transform) for img, label in dataset: print("圖像img的形狀{},標(biāo)簽label的值{}".format(img.shape, label)) print("圖像數(shù)據(jù)預(yù)處理后:\n",img) break
圖像img的形狀torch.Size([3, 224, 224]),標(biāo)簽label的值0
圖像數(shù)據(jù)預(yù)處理后:
tensor([[[ 0.9059, 0.9137, 0.9137, ..., 0.9451, 0.9451, 0.9451],
[ 0.9059, 0.9137, 0.9137, ..., 0.9451, 0.9451, 0.9451],
[ 0.9059, 0.9137, 0.9137, ..., 0.9529, 0.9529, 0.9529],
...,
[-0.4824, -0.5294, -0.5373, ..., -0.9216, -0.9294, -0.9451],
[-0.4980, -0.5529, -0.5608, ..., -0.9294, -0.9373, -0.9529],
[-0.4980, -0.5529, -0.5686, ..., -0.9529, -0.9608, -0.9608]],
[[ 0.5686, 0.5765, 0.5765, ..., 0.7961, 0.7882, 0.7882],
[ 0.5686, 0.5765, 0.5765, ..., 0.7961, 0.7882, 0.7882],
[ 0.5686, 0.5765, 0.5765, ..., 0.8039, 0.7961, 0.7961],
...,
[-0.6078, -0.6471, -0.6549, ..., -0.9137, -0.9216, -0.9373],
[-0.6157, -0.6706, -0.6784, ..., -0.9216, -0.9294, -0.9451],
[-0.6157, -0.6706, -0.6863, ..., -0.9451, -0.9529, -0.9529]],
[[-0.0510, -0.0431, -0.0431, ..., 0.2078, 0.2157, 0.2157],
[-0.0510, -0.0431, -0.0431, ..., 0.2078, 0.2157, 0.2157],
[-0.0510, -0.0431, -0.0431, ..., 0.2157, 0.2235, 0.2235],
...,
[-0.9529, -0.9843, -0.9922, ..., -0.9529, -0.9608, -0.9765],
[-0.9686, -0.9922, -1.0000, ..., -0.9608, -0.9686, -0.9843],
[-0.9686, -0.9922, -1.0000, ..., -0.9843, -0.9922, -0.9922]]])
由此可知,數(shù)據(jù)已標(biāo)準(zhǔn)化、規(guī)范化。
1.2.8對(duì)數(shù)據(jù)集進(jìn)行批量加載
使用DataLoader模塊,對(duì)數(shù)據(jù)集dataset進(jìn)行批量加載
#使用DataLoader加載數(shù)據(jù) dataloader = DataLoader(dataset,batch_size=4,shuffle=True) for batch_datas, batch_labels in dataloader: print(batch_datas.size(),batch_labels.size()) torch.Size([4, 3, 224, 224]) torch.Size([4]) torch.Size([4, 3, 224, 224]) torch.Size([4]) torch.Size([4, 3, 224, 224]) torch.Size([4]) torch.Size([4, 3, 224, 224]) torch.Size([4]) torch.Size([4, 3, 224, 224]) torch.Size([4]) torch.Size([4, 3, 224, 224]) torch.Size([4]) torch.Size([4, 3, 224, 224]) torch.Size([4]) torch.Size([4, 3, 224, 224]) torch.Size([4]) torch.Size([4, 3, 224, 224]) torch.Size([4]) torch.Size([4, 3, 224, 224]) torch.Size([4]) torch.Size([2, 3, 224, 224]) torch.Size([2])
1.2.9隨機(jī)查看一個(gè)批次的圖像
import torchvision import matplotlib.pyplot as plt import numpy as np %matplotlib inline # 顯示圖像 def imshow(img): img = img / 2 + 0.5 # unnormalize npimg = img.numpy() plt.imshow(np.transpose(npimg, (1, 2, 0))) plt.show() # 隨機(jī)獲取部分訓(xùn)練數(shù)據(jù) dataiter = iter(dataloader) images, labels = dataiter.next() # 顯示圖像 imshow(torchvision.utils.make_grid(images)) # 打印標(biāo)簽 print(' '.join('%s' % ["小狗" if labels[j].item()==1 else "小貓" for j in range(4)]))
2 對(duì)第2種數(shù)據(jù)集的處理
處理這種情況比較簡(jiǎn)單,可分為2步:
(1)使用datasets.ImageFolder讀取、處理圖像。
(2)使用.data.DataLoader批量加載數(shù)據(jù)集,示例如下:
import torch from torchvision import transforms, datasets data_transform = transforms.Compose([ transforms.RandomSizedCrop(224), transforms.RandomHorizontalFlip(), transforms.ToTensor(), transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) ]) hymenoptera_dataset = datasets.ImageFolder(root='.\catdog\train', transform=data_transform) dataset_loader = torch.utils.data.DataLoader(hymenoptera_dataset,
看完了這篇文章,相信你對(duì)“PyTorch中加載數(shù)據(jù)集的示例分析”有了一定的了解,如果想了解更多相關(guān)知識(shí),歡迎關(guān)注創(chuàng)新互聯(lián)成都網(wǎng)站設(shè)計(jì)公司行業(yè)資訊頻道,感謝各位的閱讀!
另外有需要云服務(wù)器可以了解下創(chuàng)新互聯(lián)scvps.cn,海內(nèi)外云服務(wù)器15元起步,三天無(wú)理由+7*72小時(shí)售后在線,公司持有idc許可證,提供“云服務(wù)器、裸金屬服務(wù)器、高防服務(wù)器、香港服務(wù)器、美國(guó)服務(wù)器、虛擬主機(jī)、免備案服務(wù)器”等云主機(jī)租用服務(wù)以及企業(yè)上云的綜合解決方案,具有“安全穩(wěn)定、簡(jiǎn)單易用、服務(wù)可用性高、性價(jià)比高”等特點(diǎn)與優(yōu)勢(shì),專為企業(yè)上云打造定制,能夠滿足用戶豐富、多元化的應(yīng)用場(chǎng)景需求。
文章題目:PyTorch中加載數(shù)據(jù)集的示例分析-創(chuàng)新互聯(lián)
網(wǎng)址分享:http://muchs.cn/article38/ceocpp.html
成都網(wǎng)站建設(shè)公司_創(chuàng)新互聯(lián),為您提供網(wǎng)站排名、網(wǎng)站設(shè)計(jì)、定制開發(fā)、網(wǎng)站策劃、響應(yīng)式網(wǎng)站、移動(dòng)網(wǎng)站建設(shè)
聲明:本網(wǎng)站發(fā)布的內(nèi)容(圖片、視頻和文字)以用戶投稿、用戶轉(zhuǎn)載內(nèi)容為主,如果涉及侵權(quán)請(qǐng)盡快告知,我們將會(huì)在第一時(shí)間刪除。文章觀點(diǎn)不代表本網(wǎng)站立場(chǎng),如需處理請(qǐng)聯(lián)系客服。電話:028-86922220;郵箱:631063699@qq.com。內(nèi)容未經(jīng)允許不得轉(zhuǎn)載,或轉(zhuǎn)載時(shí)需注明來(lái)源: 創(chuàng)新互聯(lián)
猜你還喜歡下面的內(nèi)容