Two Ways to Train on a Dataset in TensorFlow 2
2022-06-12 17:14:00 【老油条666】
Method 1:

import tensorflow as tf

def pre_process(x, y):
    # Scale pixel values from [0, 255] to [-1, 1]
    x = 2. * tf.cast(x, dtype=tf.float32) / 255. - 1.
    y = tf.cast(y, dtype=tf.int32)
    return x, y

(x_train, y_train), (x_test, y_test) = tf.keras.datasets.cifar100.load_data()
x_train, y_train = pre_process(x_train, y_train)
x_test, y_test = pre_process(x_test, y_test)
print(x_train.shape, y_train.shape)

# net is an already compiled tf.keras.Model; its definition is not shown in this post
history = net.fit(x_train, y_train,
                  batch_size=512,
                  epochs=100,
                  validation_split=0.2)
test_scores = net.evaluate(x_test, y_test, verbose=2)

Method 2:
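Method 2 feeds the model through a `tf.data.Dataset` instead of passing whole arrays to `fit`. As a rough sketch of what such a pipeline yields per batch, the snippet below runs the same kind of preprocessing on random stand-in data rather than CIFAR-10. The names `x_demo`, `y_demo`, `demo_pre_process` and `demo_db`, and the sizes used, are illustrative assumptions and do not come from the original code.

```python
import numpy as np
import tensorflow as tf

# Random stand-in data shaped like CIFAR-10 (assumption: 64 images of 32x32x3,
# integer labels of shape [N, 1] in the range 0..9).
rng = np.random.default_rng(0)
x_demo = rng.integers(0, 256, size=(64, 32, 32, 3), dtype=np.uint8)
y_demo = rng.integers(0, 10, size=(64, 1), dtype=np.uint8)

def demo_pre_process(x, y):
    # Scale pixels from [0, 255] to [-1, 1].
    x = 2. * tf.cast(x, dtype=tf.float32) / 255. - 1.
    # Drop the size-1 label dimension, then one-hot encode into 10 classes.
    y = tf.one_hot(tf.cast(tf.squeeze(y), dtype=tf.int32), depth=10)
    return x, y

# map() applies the function per sample; batch() then stacks 16 samples together.
demo_db = (tf.data.Dataset.from_tensor_slices((x_demo, y_demo))
           .map(demo_pre_process)
           .shuffle(64)
           .batch(16))

x_batch, y_batch = next(iter(demo_db))
print(x_batch.shape, y_batch.shape)
```

Because `map` runs before `batch`, the preprocessing function sees one image and one label at a time, which is why `tf.squeeze` turns each label into a scalar before `tf.one_hot` expands it. The full CIFAR-10 version follows.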
import tensorflow as tf
from tensorflow.keras import datasets, optimizers

def pre_process(x, y):
    # [0, 255] => [-1, 1]; [-1, 1] may be the range best suited to neural network computation
    x = 2. * tf.cast(x, dtype=tf.float32) / 255. - 1
    y = tf.squeeze(y)  # remove size-1 dimensions from the tensor's shape
    y = tf.cast(y, dtype=tf.int32)
    y = tf.one_hot(y, depth=10)
    return x, y

batch_size = 128
(x, y), (x_val, y_val) = datasets.cifar10.load_data()
print('datasets:', x.shape, y.shape, x.min(), y.min())

train_db = tf.data.Dataset.from_tensor_slices((x, y))
train_db = train_db.map(pre_process).shuffle(1000).batch(batch_size)
test_db = tf.data.Dataset.from_tensor_slices((x_val, y_val))
test_db = test_db.map(pre_process).shuffle(1000).batch(batch_size)
sample = next(iter(train_db))
print('batch:', sample[0].shape, sample[1].shape)

network = MyNetwork()  # MyNetwork is a subclass of keras.Model
network.compile(
    optimizer=optimizers.Adam(learning_rate=1e-3),
    loss=tf.losses.CategoricalCrossentropy(from_logits=True),
    metrics=['accuracy']
)
network.fit(train_db, epochs=50, validation_data=test_db, validation_freq=1)
network.evaluate(test_db)
network.save_weights('./ckpt/cifar10_weights.ckpt')  # save the model weights to disk

References:
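1. Mu Li (李沐), Dive into Deep Learning (《动手学深度学习》), TensorFlow implementation. GitHub: https://github.com/TrickyGo/Dive-into-DL-TensorFlow2.0 (this post borrows part of the code from the CNN section 5.9, GoogLeNet)
2. Long Liangqu (龙良曲), 深度学习与TensorFlow入门实战 (a hands-on introduction to deep learning with TensorFlow). Project GitHub: https://github.com/dragen1860/Deep-Learning-with-TensorFlow-book (excerpted from Lesson40--CIFAR与VGG实战)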
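A note on the two listings: neither one defines the model (`net` in Method 1, `MyNetwork` in Method 2). The sketch below is a minimal stand-in so the listings have something to run against. Its layers are an assumption chosen for brevity and are not the GoogLeNet or VGG architectures from the referenced repositories. It also shows a loss choice that matches each label format: Method 1 keeps integer labels, so a sparse loss fits, while Method 2 one-hot encodes them and uses `CategoricalCrossentropy`.

```python
import tensorflow as tf

class MyNetwork(tf.keras.Model):
    """Small stand-in CNN (hypothetical, not the architecture from the references)."""

    def __init__(self, num_classes=10):
        super().__init__()
        self.conv = tf.keras.layers.Conv2D(32, 3, padding='same', activation='relu')
        self.pool = tf.keras.layers.MaxPool2D()
        self.flatten = tf.keras.layers.Flatten()
        # No softmax here: the model returns raw logits, matching from_logits=True.
        self.out = tf.keras.layers.Dense(num_classes)

    def call(self, inputs, training=None):
        x = self.conv(inputs)
        x = self.pool(x)
        x = self.flatten(x)
        return self.out(x)

# Method 2 style: CIFAR-10 with one-hot labels.
network = MyNetwork(num_classes=10)
logits = network(tf.zeros([4, 32, 32, 3]))
print(logits.shape)

# Method 1 style: CIFAR-100 with integer labels, hence the sparse loss.
net = MyNetwork(num_classes=100)
net.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=1e-3),
            loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
            metrics=['accuracy'])
```

A model this small will not reach good accuracy on CIFAR; it only serves to make the shapes and loss settings concrete.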