当前位置:网站首页>加速訓練之並行化 tf.data.Dataset 生成器

加速訓練之並行化 tf.data.Dataset 生成器

2022-06-12 04:40:00 u012804784

優質資源分享

學習路線指引(點擊解鎖)知識定比特人群定比特
🧡 Python實戰微信訂餐小程序 🧡進階級本課程是python flask+微信小程序的完美結合,從項目搭建到騰訊雲部署上線,打造一個全棧訂餐系統。
Python量化交易實戰入門級手把手帶你打造一個易擴展、更安全、效率更高的量化交易系統

在處理大規模數據時,數據無法全部載入內存,我們通常用兩個選項

  • 使用tfrecords
  • 使用 tf.data.Dataset.from_generator()

tfrecords的並行化使用前文已經有過介紹,這裏不再贅述。如果我們不想生成tfrecord中間文件,那麼生成器就是你所需要的。

本文主要記錄針對 from_generator()的並行化方法,在 tf.data 中,並行化主要通過 mapnum_parallel_calls 實現,但是對一些場景,我們的generator()中有一些處理邏輯,是無法直接並行化的,最簡單的方法就是將generator()中的邏輯抽出來,使用map實現。

tf.data.Dataset generator 並行

generator()中的複雜邏輯,我們對其進行簡化,即僅在生成器中做一些下標取值的類型操作,將generator()中處理部分使用py_function 包裹(wrapped) ,然後調用map處理。

def func(i):
    i = i.numpy() # Decoding from the EagerTensor object
    x, y = your_processing_function(training_set[i])
    return x, y

z = list(range(len(training_set))) # The index generator

dataset = tf.data.Dataset.from_generator(lambda: z, tf.uint8)

dataset = dataset.map(lambda i: tf.py_function(func=func, 
                                               inp=[i], 
                                               Tout=[tf.uint8,
                                                     tf.float32]
                                               ), 
                      num_parallel_calls=tf.data.AUTOTUNE)

由於隱式推斷的原因,有時tensor的輸出shape是未知的,需要額外處理

dataset = dataset.batch(8)
def \_fixup\_shape(x, y):
    x.set_shape([None, None, None, nb_channels]) # n, h, w, c
    y.set_shape([None, nb_classes]) # n, nb\_classes
    return x, y
dataset = dataset.map(_fixup_shape)

tf.Tensor與tf.EagerTensor

為什麼需要 tf.py_function,先來看下tf.Tensortf.EagerTensor

EagerTensor是實時的,可以在任何時候獲取到它的值,即通過numpy獲取

Tensor是非實時的,它是靜態圖中的組件,只有當喂入數據、運算完成才能獲得該Tensor的值,

map中映射的函數運算,而僅僅是告訴dataset,你每一次拿出來的樣本時要先進行一遍function運算之後才使用的,所以function的調用是在每次迭代dataset的時候才調用的,屬於靜態圖邏輯

tensorflow.python.framework.ops.EagerTensor
tensorflow.python.framework.ops.Tensor

tf.py_function在這裏起了什麼作用?

Wraps a python function into a TensorFlow op that executes it eagerly.

剛才說到map數據靜態圖邏輯,默認參數都是Tensor。而 使用tf.py_function()包裝後,參數就變成了EagerTensor。

references

【1】https://medium.com/@acordier/tf-data-dataset-generators-with-parallelization-the-easy-way-b5c5f7d2a18

【2】https://blog.csdn.net/qq_27825451/article/details/105247211

【3】https://www.tensorflow.org/guide/data_performance#parallelizing_data_extraction

原网站

版权声明
本文为[u012804784]所创,转载请带上原文链接,感谢
https://yzsam.com/2022/163/202206120434005136.html