当前位置:网站首页>dataloader 源码_DataLoader
dataloader 源码_DataLoader
2022-06-30 18:58:00 【全栈程序员站长】
大家好,又见面了,我是你们的朋友全栈君。
import paddle.fluid as fluid
import numpy as np
# Demo configuration shared by the data generators and training loops.
BATCH_NUM = 10  # number of batches generated per epoch
BATCH_SIZE = 16  # number of samples in each batch
EPOCH_NUM = 4  # number of passes over the generated data
CLASS_NUM = 10  # output size of the fully connected layer
ITERABLE = True  # whether the created DataLoader object is iterable
USE_GPU = False  # whether to use GPU
# Data format of the data source the user provides; one of
# 'sample_generator', 'sample_list_generator' or 'batch_generator'.
DATA_FORMAT = 'batch_generator'
def simple_net(image, label):
    """Build a one-layer classifier and attach an SGD optimizer to its loss.

    Args:
        image: Input variable passed to the fully connected layer.
        label: Class-index variable used as the cross-entropy target.

    Returns:
        The mean softmax cross-entropy loss variable.
    """
    logits = fluid.layers.fc(image, size=CLASS_NUM)
    # The loss must be taken on the FC output, not on the raw image;
    # otherwise the FC layer is dead code and its parameters never
    # receive gradients.
    cross_entropy = fluid.layers.softmax_with_cross_entropy(logits, label)
    loss = fluid.layers.reduce_mean(cross_entropy)
    sgd = fluid.optimizer.SGD(learning_rate=1e-3)
    sgd.minimize(loss)
    return loss
def get_random_images_and_labels(image_shape, label_shape):
    """Return a random ``(image, label)`` pair of NumPy arrays.

    Args:
        image_shape: Shape of the float32 image array.
        label_shape: Shape of the int64 label array.

    Returns:
        A tuple ``(image, label)``.
    """
    image = np.random.random(size=image_shape).astype('float32')
    # NOTE: np.random.random samples from [0, 1), so the int64 cast
    # truncates every label to 0.
    label = np.random.random(size=label_shape).astype('int64')
    return image, label
# If the data generator yields one sample each time,
# use DataLoader.set_sample_generator to set the data source.
def sample_generator_creator():
    """Create a reader that yields one sample at a time.

    Returns:
        A generator function yielding ``BATCH_NUM * BATCH_SIZE`` pairs
        ``(image, label)`` built with shapes ``[784]`` and ``[1]``.
    """
    def __reader__():
        for _ in range(BATCH_NUM * BATCH_SIZE):
            image, label = get_random_images_and_labels([784], [1])
            yield image, label

    return __reader__
# If the data generator yields a list of samples each time,
# use DataLoader.set_sample_list_generator to set the data source.
def sample_list_generator_creator():
    """Create a reader that yields a list of samples at a time.

    Returns:
        A generator function yielding ``BATCH_NUM`` lists, each holding
        ``BATCH_SIZE`` entries of the form ``[image, label]``.
    """
    def __reader__():
        for _ in range(BATCH_NUM):
            # Each entry is a two-element list [image, label].
            sample_list = [
                list(get_random_images_and_labels([784], [1]))
                for _ in range(BATCH_SIZE)
            ]
            yield sample_list

    return __reader__
# If the data generator yields a batch each time,
# use DataLoader.set_batch_generator to set the data source.
def batch_generator_creator():
    """Create a reader that yields one whole batch at a time.

    Returns:
        A generator function yielding ``BATCH_NUM`` pairs
        ``(batch_image, batch_label)`` built with shapes
        ``[BATCH_SIZE, 784]`` and ``[BATCH_SIZE, 1]``.
    """
    def __reader__():
        for _ in range(BATCH_NUM):
            batch_image, batch_label = get_random_images_and_labels(
                [BATCH_SIZE, 784], [BATCH_SIZE, 1])
            yield batch_image, batch_label

    return __reader__
# If DataLoader is iterable, use for loop to train the network
def train_iterable(exe, prog, loss, loader):
    """Train for ``EPOCH_NUM`` epochs by iterating over the DataLoader.

    Args:
        exe: Executor used to run the program.
        prog: Program to execute on every batch.
        loss: Loss variable to fetch on every step.
        loader: Iterable DataLoader; calling it returns the batches of
            one epoch, each passed to ``exe.run`` as ``feed``.
    """
    for _ in range(EPOCH_NUM):
        for data in loader():
            exe.run(prog, feed=data, fetch_list=[loss])
# If DataLoader is not iterable, use start() and reset() method to control the process
def train_non_iterable(exe, prog, loss, loader):
    """Train for ``EPOCH_NUM`` epochs with a non-iterable DataLoader.

    Args:
        exe: Executor used to run the program.
        prog: Program to execute on every step.
        loss: Loss variable to fetch on every step.
        loader: Non-iterable DataLoader driven through ``start()`` and
            ``reset()``.
    """
    for _ in range(EPOCH_NUM):
        loader.start()  # call DataLoader.start() before each epoch starts
        try:
            # Run until the loader signals the end of the epoch.
            while True:
                exe.run(prog, fetch_list=[loss])
        except fluid.core.EOFException:
            loader.reset()  # call DataLoader.reset() after catching EOFException
def set_data_source(loader, places):
    """Attach the generator selected by ``DATA_FORMAT`` to the loader.

    Args:
        loader: DataLoader whose data source is being set.
        places: Place or list of places forwarded to the loader setter.

    Raises:
        ValueError: If ``DATA_FORMAT`` is not one of the supported names.
    """
    if DATA_FORMAT == 'sample_generator':
        loader.set_sample_generator(
            sample_generator_creator(),
            batch_size=BATCH_SIZE,
            drop_last=True,
            places=places)
    elif DATA_FORMAT == 'sample_list_generator':
        loader.set_sample_list_generator(
            sample_list_generator_creator(), places=places)
    elif DATA_FORMAT == 'batch_generator':
        loader.set_batch_generator(batch_generator_creator(), places=places)
    else:
        raise ValueError('Unsupported data format')
# Static-graph example: declare the inputs, build the loader and network,
# then train with whichever loop matches the loader's mode.
image = fluid.layers.data(name='image', shape=[784], dtype='float32')
label = fluid.layers.data(name='label', shape=[1], dtype='int64')

# Define DataLoader
loader = fluid.io.DataLoader.from_generator(
    feed_list=[image, label], capacity=16, iterable=ITERABLE)

# Define network
loss = simple_net(image, label)

# Set data source of DataLoader
#
# If DataLoader is iterable, places must be given and the number of places
# must be the same as the device number.
#  - If you are using GPU, call `fluid.cuda_places()` to get all GPU places.
#  - If you are using CPU, call `fluid.cpu_places()` to get all CPU places.
#
# If DataLoader is not iterable, places can be None.
places = fluid.cuda_places() if USE_GPU else fluid.cpu_places()
set_data_source(loader, places)

exe = fluid.Executor(places[0])
exe.run(fluid.default_startup_program())

prog = fluid.CompiledProgram(
    fluid.default_main_program()).with_data_parallel(loss_name=loss.name)

if loader.iterable:
    train_iterable(exe, prog, loss, loader)
else:
    train_non_iterable(exe, prog, loss, loader)
# Users can use return_list = True in dygraph mode.
with fluid.dygraph.guard(places[0]):
    loader = fluid.io.DataLoader.from_generator(capacity=2, return_list=True)
    set_data_source(loader, places[0])
    for image, label in loader():
        relu = fluid.layers.relu(image)
        # Each batch is checked against the configured batch shapes.
        assert image.shape == [BATCH_SIZE, 784]
        assert label.shape == [BATCH_SIZE, 1]
        assert relu.shape == [BATCH_SIZE, 784]
发布者:全栈程序员栈长,转载请注明出处:https://javaforall.cn/132199.html 原文链接:https://javaforall.cn
边栏推荐
- Task04:集合运算-表的加减法和join等--天池龙珠计划SQL训练营学习笔记
- Code shoe set - mt3111 · assignment
- Browser window switch activation event visibilitychange
- Sqlserver SQL Server Management Studio and transact SQL create accounts and create read-only users to access the specified database
- 码蹄集 - MT3435 · 赋值 - 二分图问题 - 图文讲解
- Kubernetes为什么会赢,容器圈的风云变幻!
- Promise from recognition to use
- VR全景拍摄为什么要加盟?巧借资源实现共赢
- 科大讯飞活跃竞赛汇总!(12个)
- 一文详解|Go 分布式链路追踪实现原理
猜你喜欢
随机推荐
Entropy - conditional entropy - joint entropy - mutual information - cross entropy
将 EMQX Cloud 数据通过公网桥接到 AWS IoT
Tupu software has passed CMMI5 certification| High authority and high-level certification in the international software field
What securities dealers recommend? In addition, is it safe to open a mobile account?
Why do more and more people choose cloud rendering?
The project is configured with eslint. When the editor does not close the eslint function, the eslint does not take effect
mysql统计账单信息(上):mysql安装及客户端DBeaver连接使用
Makefile笔记(一文学会Makefile)
Go语言学习教程(十三)
Abaqus 2022软件安装包和安装教程
Friends in Guangzhou can join us if they have the opportunity
Pyth-Solana链上联通现实的桥梁
SQL continuous login problem
[solved] how does Tiktok cancel paying attention to the cancelled account
MySQL数据库查询优化
Safe holidays without holidays, VR traffic makes children travel safely | Guangzhou Sinovel viewpoint
1. 爬虫之Beautifulsoup解析库&在线解析图片验证码
力扣------统计包含给定前缀的字符串
手机炒股开户安全嘛!?
阿里天池SQL训练营学习笔记5




![[solved] how does Tiktok cancel paying attention to the cancelled account](/img/1f/7b0bd2c0f69f7f3d1c25c426cc5771.png)



