DataLoader source code: DataLoader
2022-06-30 23:11:00 【Full stack programmer webmaster】
```python
import paddle.fluid as fluid
import numpy as np

BATCH_NUM = 10
BATCH_SIZE = 16
EPOCH_NUM = 4

CLASS_NUM = 10

ITERABLE = True  # whether the created DataLoader object is iterable
USE_GPU = False  # whether to use GPU

DATA_FORMAT = 'batch_generator'  # data format of data source user provides


def simple_net(image, label):
    fc_tmp = fluid.layers.fc(image, size=CLASS_NUM)
    # Compute the loss on the fc output (the logits), not on the raw image
    cross_entropy = fluid.layers.softmax_with_cross_entropy(fc_tmp, label)
    loss = fluid.layers.reduce_mean(cross_entropy)
    sgd = fluid.optimizer.SGD(learning_rate=1e-3)
    sgd.minimize(loss)
    return loss
```
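Without Paddle at hand, the computation `simple_net` builds can still be sketched in NumPy. It is a fully connected layer producing `CLASS_NUM` logits, followed by softmax cross-entropy against an integer label and then the mean over the batch. This illustrates the math only and assumes hard (integer) labels. The helper `softmax_cross_entropy_mean` and the random weights `W`, `b` are hypothetical; this is not Paddle's kernel.

```python
import numpy as np


def softmax_cross_entropy_mean(logits, labels):
    """Hypothetical helper: mean of -log softmax(logits)[label] over the batch.

    logits: float array of shape [N, C]; labels: integer array of shape [N, 1].
    """
    # Subtract the row max before exponentiating for numerical stability
    shifted = logits - logits.max(axis=1, keepdims=True)
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))
    # Pick the log-probability of the true class for each sample -> shape [N, 1]
    picked = np.take_along_axis(log_probs, labels, axis=1)
    return float(-picked.mean())


# Hypothetical fc layer: logits = image @ W + b with random weights
rng = np.random.default_rng(0)
batch_size, in_dim, class_num = 16, 784, 10
image = rng.random((batch_size, in_dim)).astype('float32')
label = rng.integers(0, class_num, size=(batch_size, 1))
W = rng.normal(scale=0.01, size=(in_dim, class_num))
b = np.zeros(class_num)

loss = softmax_cross_entropy_mean(image @ W + b, label)
print(loss)
```

With all-zero logits every class gets probability 1/C, so the loss is exactly log C. This is a handy sanity check for any implementation.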
```python
def get_random_images_and_labels(image_shape, label_shape):
    image = np.random.random(size=image_shape).astype('float32')
    # Integer class ids in [0, CLASS_NUM)
    label = np.random.randint(0, CLASS_NUM, size=label_shape).astype('int64')
    return image, label
```
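`np.random.random` draws floats in [0, 1), and `astype('int64')` truncates toward zero. Casting such draws to `int64` therefore turns every label into class 0. Drawing integers directly, for example with `np.random.randint`, spreads the labels over all `CLASS_NUM` classes. Below is a quick NumPy check, assuming `CLASS_NUM = 10` as in the script:

```python
import numpy as np

CLASS_NUM = 10  # assumed to match the script above
np.random.seed(0)

# Uniform floats in [0, 1) truncated to int64: every entry becomes 0
truncated = np.random.random(size=[1000, 1]).astype('int64')

# Integer class ids drawn uniformly from [0, CLASS_NUM)
class_ids = np.random.randint(0, CLASS_NUM, size=[1000, 1]).astype('int64')

print(np.unique(truncated))  # only class 0 appears
print(np.unique(class_ids))  # labels spread over [0, CLASS_NUM)
```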
```python
# If the data generator yields one sample each time,
# use DataLoader.set_sample_generator to set the data source.
def sample_generator_creator():
    def __reader__():
        for _ in range(BATCH_NUM * BATCH_SIZE):
            image, label = get_random_images_and_labels([784], [1])
            yield image, label

    return __reader__
```
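`set_sample_generator` receives single samples along with `batch_size` and `drop_last`, so the loader must group the samples into batches itself. The pure-Python sketch below shows that grouping. `batch_samples` and `toy_samples` are hypothetical helpers written for illustration, not Paddle's implementation. The sketch assumes that `drop_last=True` discards a trailing batch smaller than `batch_size`.

```python
def batch_samples(sample_reader, batch_size, drop_last=True):
    """Hypothetical helper: group single samples into lists of batch_size."""
    def __reader__():
        batch = []
        for sample in sample_reader():
            batch.append(sample)
            if len(batch) == batch_size:
                yield batch
                batch = []
        # A trailing partial batch is kept only when drop_last is False
        if batch and not drop_last:
            yield batch

    return __reader__


def toy_samples():
    # 10 hypothetical (image, label) pairs
    for i in range(10):
        yield [float(i)] * 4, i


batches = list(batch_samples(toy_samples, batch_size=4, drop_last=True)())
print(len(batches))  # → 2 (the last 2 samples are dropped)
```

The script's sample generator yields exactly `BATCH_NUM * BATCH_SIZE` samples, so every batch is full and `drop_last=True` drops nothing there.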
```python
# If the data generator yields a list of samples each time,
# use DataLoader.set_sample_list_generator to set the data source.
def sample_list_generator_creator():
    def __reader__():
        for _ in range(BATCH_NUM):
            sample_list = []
            for _ in range(BATCH_SIZE):
                image, label = get_random_images_and_labels([784], [1])
                sample_list.append([image, label])
            yield sample_list

    return __reader__
```
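Each list yielded by `sample_list_generator_creator` holds `BATCH_SIZE` pairs of `[image, label]`. By contrast, `batch_generator_creator` yields two already-stacked arrays of shapes `[BATCH_SIZE, 784]` and `[BATCH_SIZE, 1]`. The NumPy sketch below turns one sample list into that batch layout. The helper `stack_sample_list` is hypothetical, and the loader's internal conversion may differ.

```python
import numpy as np


def stack_sample_list(sample_list):
    """Hypothetical helper: turn [[image, label], ...] into (batch_image, batch_label)."""
    images, labels = zip(*sample_list)
    return np.stack(images), np.stack(labels)


BATCH_SIZE = 16  # assumed to match the script above
sample_list = [[np.random.random([784]).astype('float32'),
                np.random.randint(0, 10, size=[1]).astype('int64')]
               for _ in range(BATCH_SIZE)]

batch_image, batch_label = stack_sample_list(sample_list)
print(batch_image.shape, batch_label.shape)  # → (16, 784) (16, 1)
```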
```python
# If the data generator yields a batch each time,
# use DataLoader.set_batch_generator to set the data source.
def batch_generator_creator():
    def __reader__():
        for _ in range(BATCH_NUM):
            batch_image, batch_label = get_random_images_and_labels(
                [BATCH_SIZE, 784], [BATCH_SIZE, 1])
            yield batch_image, batch_label

    return __reader__


# If DataLoader is iterable, use a for loop to train the network
def train_iterable(exe, prog, loss, loader):
    for _ in range(EPOCH_NUM):
        for data in loader():
            exe.run(prog, feed=data, fetch_list=[loss])


# If DataLoader is not iterable, use the start() and reset() methods to control the process
def train_non_iterable(exe, prog, loss, loader):
    for _ in range(EPOCH_NUM):
        loader.start()  # call DataLoader.start() before each epoch starts
        try:
            while True:
                exe.run(prog, fetch_list=[loss])
        except fluid.core.EOFException:
            loader.reset()  # call DataLoader.reset() after catching EOFException
```
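The two training loops differ only in who drives the epoch. An iterable loader is simply iterated, while a non-iterable loader is started, drained until it signals end of data, and then reset. The pure-Python sketch below mimics that protocol with a hypothetical `ToyLoader` class. Python's built-in `EOFError` stands in for `fluid.core.EOFException`, and `next_batch()` stands in for the data that `exe.run` pulls from the loader internally.

```python
class ToyLoader:
    """Hypothetical stand-in for a DataLoader over a batch generator function."""

    def __init__(self, batch_reader):
        self._batch_reader = batch_reader
        self._it = None

    # Iterable mode: each call returns a fresh pass over the data
    def __call__(self):
        return self._batch_reader()

    # Non-iterable mode: start() opens an epoch, next_batch() raises at its end
    def start(self):
        self._it = self._batch_reader()

    def next_batch(self):
        try:
            return next(self._it)
        except StopIteration:
            raise EOFError('end of epoch') from None  # plays the role of EOFException

    def reset(self):
        self._it = None


def batches():
    yield from range(3)  # three hypothetical batches per epoch


loader = ToyLoader(batches)
EPOCH_NUM = 2

seen_iterable = [b for _ in range(EPOCH_NUM) for b in loader()]

seen_non_iterable = []
for _ in range(EPOCH_NUM):
    loader.start()  # call start() before each epoch
    try:
        while True:
            seen_non_iterable.append(loader.next_batch())
    except EOFError:
        loader.reset()  # call reset() after the end-of-data signal

print(seen_iterable == seen_non_iterable)  # → True
```

Both styles visit the same batches. The non-iterable style simply moves the end-of-epoch signal from `StopIteration` to an exception that the training loop catches explicitly.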
```python
def set_data_source(loader, places):
    if DATA_FORMAT == 'sample_generator':
        loader.set_sample_generator(sample_generator_creator(), batch_size=BATCH_SIZE,
                                    drop_last=True, places=places)
    elif DATA_FORMAT == 'sample_list_generator':
        loader.set_sample_list_generator(sample_list_generator_creator(), places=places)
    elif DATA_FORMAT == 'batch_generator':
        loader.set_batch_generator(batch_generator_creator(), places=places)
    else:
        raise ValueError('Unsupported data format')


image = fluid.layers.data(name='image', shape=[784], dtype='float32')
label = fluid.layers.data(name='label', shape=[1], dtype='int64')

# Define DataLoader
loader = fluid.io.DataLoader.from_generator(feed_list=[image, label], capacity=16,
                                            iterable=ITERABLE)

# Define network
loss = simple_net(image, label)

# Set data source of DataLoader
#
# If DataLoader is iterable, places must be given and the number of places
# must be the same as the number of devices.
#  - If you are using GPU, call `fluid.cuda_places()` to get all GPU places.
#  - If you are using CPU, call `fluid.cpu_places()` to get all CPU places.
#
# If DataLoader is not iterable, places can be None.
places = fluid.cuda_places() if USE_GPU else fluid.cpu_places()
set_data_source(loader, places)

exe = fluid.Executor(places[0])
exe.run(fluid.default_startup_program())

prog = fluid.CompiledProgram(fluid.default_main_program()).with_data_parallel(
    loss_name=loss.name)

if loader.iterable:
    train_iterable(exe, prog, loss, loader)
else:
    train_non_iterable(exe, prog, loss, loader)
```
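The `capacity` argument of `DataLoader.from_generator` bounds the queue of prefetched data that the loader maintains. Paddle's docs measure it in batches. The idea can be pictured as a producer thread filling a `queue.Queue(maxsize=capacity)` while the training loop consumes from it. This is a simplified standard-library sketch with a hypothetical `prefetch` helper, not how Paddle implements its reader queue.

```python
import queue
import threading

_END = object()  # sentinel marking the end of the data


def prefetch(batch_reader, capacity):
    """Hypothetical helper: yield batches, buffering at most `capacity` ahead."""
    buf = queue.Queue(maxsize=capacity)

    def producer():
        for batch in batch_reader():
            buf.put(batch)  # blocks while `capacity` batches are already waiting
        buf.put(_END)

    threading.Thread(target=producer, daemon=True).start()
    while True:
        item = buf.get()
        if item is _END:
            return
        yield item


def batches():
    for i in range(10):
        yield [i] * 4  # hypothetical batch


consumed = list(prefetch(batches, capacity=2))
print(len(consumed))  # → 10
```

A larger capacity lets a fast reader run further ahead of training, at the cost of memory for the buffered batches.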
```python
# Users can use return_list=True in dygraph mode.
with fluid.dygraph.guard(places[0]):
    loader = fluid.io.DataLoader.from_generator(capacity=2, return_list=True)
    set_data_source(loader, places[0])
    for image, label in loader():
        relu = fluid.layers.relu(image)
        assert image.shape == [BATCH_SIZE, 784]
        assert label.shape == [BATCH_SIZE, 1]
        assert relu.shape == [BATCH_SIZE, 784]
```
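As the `from_generator` documentation describes it, `return_list=False` makes each step yield, per device, a dict keyed by the names of the fed variables. That dict is what `exe.run(prog, feed=data)` consumes in the static-graph loop. With `return_list=True`, the values come back as a list, which suits the tuple unpacking in the dygraph loop. Below is a pure-Python illustration of the two layouts for one device. `to_feed_dict` and `to_feed_list` are hypothetical helpers, and the strings stand in for tensors.

```python
FEED_NAMES = ['image', 'label']  # names given to fluid.layers.data in the script


def to_feed_dict(batch, names=FEED_NAMES):
    """Hypothetical helper, return_list=False layout: {variable name: value}."""
    return dict(zip(names, batch))


def to_feed_list(batch):
    """Hypothetical helper, return_list=True layout: values in data-source order."""
    return list(batch)


batch = ('image_tensor', 'label_tensor')  # placeholders for one batch
feed = to_feed_dict(batch)
image, label = to_feed_list(batch)
print(feed)  # → {'image': 'image_tensor', 'label': 'label_tensor'}
```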
Publisher: Full stack programmer webmaster. Please credit the source when reprinting: https://javaforall.cn/132199.html Original text: https://javaforall.cn