如何利用Tensorflow的隊(duì)列多線程讀取數(shù)據(jù)-創(chuàng)新互聯(lián)

這篇文章主要介紹了如何利用Tensorflow的隊(duì)列多線程讀取數(shù)據(jù),具有一定借鑒價(jià)值,感興趣的朋友可以參考下,希望大家閱讀完這篇文章之后大有收獲,下面讓小編帶著大家一起了解一下。

成都創(chuàng)新互聯(lián)公司-專業(yè)網(wǎng)站定制、快速模板網(wǎng)站建設(shè)、高性價(jià)比方城網(wǎng)站開發(fā)、企業(yè)建站全套包干低至880元,成熟完善的模板庫,直接使用。一站式方城網(wǎng)站制作公司更省心,省錢,快速模板網(wǎng)站建設(shè)找我們,業(yè)務(wù)覆蓋方城地區(qū)。費(fèi)用合理售后完善,十多年實(shí)體公司更值得信賴。

在tensorflow中,有三種方式輸入數(shù)據(jù)

1. 利用feed_dict送入numpy數(shù)組

2. 利用隊(duì)列從文件中直接讀取數(shù)據(jù)

3. 預(yù)加載數(shù)據(jù)

其中第一種方式很常用,在tensorflow的MNIST訓(xùn)練源碼中可以看到,通過feed_dict={},可以將任意數(shù)據(jù)送入tensor中。

第二種方式相比于第一種,速度更快,可以利用多線程的優(yōu)勢(shì)把數(shù)據(jù)送入隊(duì)列,再以batch的方式出隊(duì),并且在這個(gè)過程中可以很方便地對(duì)圖像進(jìn)行隨機(jī)裁剪、翻轉(zhuǎn)、改變對(duì)比度等預(yù)處理,同時(shí)可以選擇是否對(duì)數(shù)據(jù)隨機(jī)打亂,可以說是非常方便。該部分的源碼在tensorflow官方的CIFAR-10訓(xùn)練源碼中可以看到,但是對(duì)于剛學(xué)習(xí)tensorflow的人來說,比較難以理解,本篇博客就當(dāng)成我調(diào)試完成后寫的一篇總結(jié),以防自己再忘記具體細(xì)節(jié)。

讀取CIFAR-10數(shù)據(jù)集

按照第一種方式的話,CIFAR-10的讀取只需要寫一段非常簡單的代碼即可將測試集與訓(xùn)練集中的圖像分別讀?。?/p>

path = 'E:\Dataset\cifar-10\cifar-10-batches-py'
# extract train examples
num_train_examples = 50000
x_train = np.empty((num_train_examples, 32, 32, 3), dtype='uint8')
y_train = np.empty((num_train_examples), dtype='uint8')
for i in range(1, 6): 
 fpath = os.path.join(path, 'data_batch_' + str(i)) 
 (x_train[(i - 1) * 10000: i * 10000, :, :, :], y_train[(i - 1) * 10000: i * 10000])   = load_and_decode(fpath)
# extract test examples
fpath = os.path.join(path, 'test_batch')
x_test, y_test = load_and_decode(fpath)
return x_train, y_train, x_test, np.array(y_test)

其中l(wèi)oad_and_decode函數(shù)只需要按照CIFAR-10官網(wǎng)給出的方式decode就行,最終返回的x_train是一個(gè)[50000, 32, 32, 3]的ndarray,但對(duì)于ndarray來說,進(jìn)行預(yù)處理就要麻煩很多,為了取mini-SGD的batch,還自己寫了一個(gè)類,通過調(diào)用train_set.next_batch()函數(shù)來取,總而言之就是什么都要自己動(dòng)手,效率確實(shí)不高

但對(duì)于第二種方式,讀取起來就要麻煩很多,但使用起來,又快又方便

首先,把CIFAR-10的測試集文件讀取出來,生成文件名列表

path = 'E:\Dataset\cifar-10\cifar-10-batches-py'
filenames = [os.path.join(path, 'data_batch_%d' % i) for i in range(1, 6)]

有了列表以后,利用tf.train.string_input_producer函數(shù)生成一個(gè)讀取隊(duì)列

filename_queue = tf.train.string_input_producer(filenames)

接下來,我們調(diào)用read_cifar10函數(shù),得到一幅一幅的圖像,該函數(shù)的代碼如下:

def read_cifar10(filename_queue):
 label_bytes = 1
 IMAGE_SIZE = 32
 CHANNELS = 3
 image_bytes = IMAGE_SIZE*IMAGE_SIZE*3
 record_bytes = label_bytes+image_bytes

 # define a reader
 reader = tf.FixedLengthRecordReader(record_bytes)
 key, value = reader.read(filename_queue)
 record_bytes = tf.decode_raw(value, tf.uint8)

 label = tf.strided_slice(record_bytes, [0], [label_bytes])
 depth_major = tf.reshape(tf.strided_slice(record_bytes, [label_bytes],  
            [label_bytes + image_bytes]),
        [CHANNELS, IMAGE_SIZE, IMAGE_SIZE])
 image = tf.transpose(depth_major, [1, 2, 0])
 return image, label

第9行,定義一個(gè)reader,來讀取固定長度的數(shù)據(jù),這個(gè)固定長度是由CIFAR-10數(shù)據(jù)集圖片的存儲(chǔ)格式?jīng)Q定的,1byte的標(biāo)簽加上32 *32 *3長度的圖像,3代表RGB三通道,由于圖片的是按[channel, height, width]的格式存儲(chǔ)的,為了變?yōu)槌S玫腫height, width, channel]維度,需要在17行reshape一次圖像,最終我們提取出了一副完整的圖像與對(duì)應(yīng)的標(biāo)簽

對(duì)圖像進(jìn)行預(yù)處理

我們?nèi)〕龅膇mage與label均為tensor格式,因此預(yù)處理將變得非常簡單

 if not distortion:
  IMAGE_SIZE = 32
 else:
  IMAGE_SIZE = 24
  # 隨機(jī)裁剪為24*24大小
  distorted_image = tf.random_crop(tf.cast(image, tf.float32), [IMAGE_SIZE, IMAGE_SIZE, 3])
  # 隨機(jī)水平翻轉(zhuǎn)
  distorted_image = tf.image.random_flip_left_right(distorted_image)
  # 隨機(jī)調(diào)整亮度
  distorted_image = tf.image.random_brightness(distorted_image, max_delta=63)
  # 隨機(jī)調(diào)整對(duì)比度
  distorted_image = tf.image.random_contrast(distorted_image, lower=0.2, upper=1.8)
  # 對(duì)圖像進(jìn)行白化操作,即像素值轉(zhuǎn)為零均值單位方差
  float_image = tf.image.per_image_standardization(distorted_image)

distortion是定義的一個(gè)輸入布爾型變量,默認(rèn)為True,表示是否對(duì)圖像進(jìn)行處理

填充隊(duì)列與隨機(jī)打亂

調(diào)用tf.train.shuffle_batch或tf.train.batch函數(shù),以tf.train.shuffle_batch為例,函數(shù)的定義如下:

def shuffle_batch(tensors, batch_size, capacity, min_after_dequeue,
     num_threads=1, seed=None, enqueue_many=False, shapes=None,
     allow_smaller_final_batch=False, shared_name=None, name=None):

tensors表示輸入的張量(tensor),batch_size表示要輸出的batch的大小,capacity表示隊(duì)列的容量,即大小,min_after_dequeue表示出隊(duì)操作后隊(duì)列中的最小元素?cái)?shù)量,這個(gè)值是要小于隊(duì)列的capacity的,通過調(diào)整min_after_dequeue與capacity兩個(gè)變量,可以改變數(shù)據(jù)被隨機(jī)打亂的程度,num_threads表示使用的線程數(shù),只要取大于1的數(shù),隊(duì)列的效率就會(huì)高很多。

通常情況下,我們只需要輸入以上幾個(gè)變量即可,在CIFAR-10_input.py中,谷歌給出的代碼是這樣寫的:

if shuffle:
 images, label_batch = tf.train.shuffle_batch([image, label], batch_size,         min_queue_examples+3*batch_size,
       min_queue_examples, num_preprocess_threads)
else:
 images, label_batch = tf.train.batch([image, label], batch_size,
           num_preprocess_threads, 
           min_queue_examples + 3 * batch_size)

min_queue_examples由以下方式得到:

min_fraction_of_examples_in_queue = 0.4
min_queue_examples = int(NUM_EXAMPLES_PER_EPOCH_FOR_TRAIN 
       *min_fraction_of_examples_in_queue)

當(dāng)然,這些值均可以自己隨意設(shè)置,

最終得到的images,labels(label_batch),即為shape=[128, 32, 32, 3]的tensor,其中128為默認(rèn)batch_size。

激活隊(duì)列與處理異常

得到了images和labels兩個(gè)tensor后,我們便可以把這兩個(gè)tensor送入graph中進(jìn)行運(yùn)算了

# input tensor
img_batch, label_batch = cifar10_input.tesnsor_shuffle_input(batch_size)

# build graph that computes the logits predictions from the inference model
logits, predicts = train.inference(img_batch, keep_prob)

# calculate loss
loss = train.loss(logits, label_batch)

定義sess=tf.Session()后,運(yùn)行sess.run(),然而你會(huì)發(fā)現(xiàn)并沒有輸出,程序直接掛起了,仿佛死掉了一樣

原因是這樣的,雖然我們?cè)跀?shù)據(jù)流圖中加入了隊(duì)列,但只有調(diào)用tf.train.start_queue_runners()函數(shù)后,數(shù)據(jù)才會(huì)動(dòng)起來,被負(fù)責(zé)輸入管道的線程填入隊(duì)列,否則隊(duì)列將會(huì)掛起。

OK,我們調(diào)用函數(shù),讓隊(duì)列運(yùn)行起來

with tf.Session(config=run_config) as sess:
 sess.run(init_op) # intialization
 queue_runner = tf.train.start_queue_runners(sess)
 for i in range(10):
  b1, b2 = sess.run([img_batch, label_batch])
  print(b1.shape)

在這里為了測試,我們?nèi)?0次輸出,看看輸出的batch2的維度是否正確

如何利用Tensorflow的隊(duì)列多線程讀取數(shù)據(jù)

10個(gè)batch的維度均為正確的,但是tensorflow卻報(bào)了錯(cuò),錯(cuò)誤的文字內(nèi)容如下:

2017-12-19 16:40:56.429687: W C:\tf_jenkins\home\workspace\rel-win\M\windows-gpu\PY\36\tensorflow\core\kernels\queue_base.cc:295] _ 0 _ input_producer: Skipping cancelled enqueue attempt with queue not closed

簡單地看一下,大致意思是說我們的隊(duì)列里還有數(shù)據(jù),但是程序結(jié)束了,拋出了異常,因此,我們還需要定義一個(gè)Coordinator,也就是協(xié)調(diào)器來處理異常

Coordinator有3個(gè)主要方法:

1. tf.train.Coordinator.should_stop() 如果線程應(yīng)該停止,返回True

2. tf.train.Coordinator.request_stop() 請(qǐng)求停止線程

3. tf.train.Coordinator.join() 等待直到指定線程停止

首先,定義協(xié)調(diào)器

coord = tf.train.Coordinator()

將協(xié)調(diào)器應(yīng)用于QueueRunner

queue_runner = tf.train.start_queue_runners(sess, coord=coord)

結(jié)束數(shù)據(jù)的訓(xùn)練或測試后,關(guān)閉線程

coord.request_stop()
coord.join(queue_runner)

最終的sess代碼段如下:

coord = tf.train.Coordinator()
with tf.Session(config=run_config) as sess:
 sess.run(init_op)
 queue_runner = tf.train.start_queue_runners(sess, coord=coord)
 for i in range(10):
  b1, b2 = sess.run([img_batch, label_batch])
  print(b1.shape)
 coord.request_stop()
 coord.join(queue_runner)

得到的輸出結(jié)果為:

如何利用Tensorflow的隊(duì)列多線程讀取數(shù)據(jù)

感謝你能夠認(rèn)真閱讀完這篇文章,希望小編分享的“如何利用Tensorflow的隊(duì)列多線程讀取數(shù)據(jù)”這篇文章對(duì)大家有幫助,同時(shí)也希望大家多多支持創(chuàng)新互聯(lián)成都網(wǎng)站設(shè)計(jì)公司,關(guān)注創(chuàng)新互聯(lián)成都網(wǎng)站設(shè)計(jì)公司行業(yè)資訊頻道,更多相關(guān)知識(shí)等著你來學(xué)習(xí)!

另外有需要云服務(wù)器可以了解下創(chuàng)新互聯(lián)scvps.cn,海內(nèi)外云服務(wù)器15元起步,三天無理由+7*72小時(shí)售后在線,公司持有idc許可證,提供“云服務(wù)器、裸金屬服務(wù)器、網(wǎng)站設(shè)計(jì)器、香港服務(wù)器、美國服務(wù)器、虛擬主機(jī)、免備案服務(wù)器”等云主機(jī)租用服務(wù)以及企業(yè)上云的綜合解決方案,具有“安全穩(wěn)定、簡單易用、服務(wù)可用性高、性價(jià)比高”等特點(diǎn)與優(yōu)勢(shì),專為企業(yè)上云打造定制,能夠滿足用戶豐富、多元化的應(yīng)用場景需求。

新聞名稱:如何利用Tensorflow的隊(duì)列多線程讀取數(shù)據(jù)-創(chuàng)新互聯(lián)
網(wǎng)站網(wǎng)址:http://muchs.cn/article18/cosidp.html

成都網(wǎng)站建設(shè)公司_創(chuàng)新互聯(lián),為您提供營銷型網(wǎng)站建設(shè)、網(wǎng)站設(shè)計(jì)、面包屑導(dǎo)航、關(guān)鍵詞優(yōu)化、企業(yè)建站、定制網(wǎng)站

廣告

聲明:本網(wǎng)站發(fā)布的內(nèi)容(圖片、視頻和文字)以用戶投稿、用戶轉(zhuǎn)載內(nèi)容為主,如果涉及侵權(quán)請(qǐng)盡快告知,我們將會(huì)在第一時(shí)間刪除。文章觀點(diǎn)不代表本網(wǎng)站立場,如需處理請(qǐng)聯(lián)系客服。電話:028-86922220;郵箱:631063699@qq.com。內(nèi)容未經(jīng)允許不得轉(zhuǎn)載,或轉(zhuǎn)載時(shí)需注明來源: 創(chuàng)新互聯(lián)

微信小程序開發(fā)