TensorFlow 2.0 Deep Learning: A Simple Guide to Running Code
2022-07-26 22:24:00 【赫凯】
I used to write PyTorch; my current job calls for the TensorFlow 2.0 framework, so after spending some time learning it, here is a summary. TensorFlow 2.0 offers two training styles to get started with.
Simple mode
In TensorFlow's simple mode, the basic idea is that every operation is wrapped up for you, kept as simple as possible, so all you need to know is the input data format and how the model receives data.
Preparing the data
# Imports
import tensorflow as tf
# Load the data; a nice TensorFlow feature is that common datasets come bundled
mnist = tf.keras.datasets.mnist
# x_train, y_train and x_test, y_test here are plain NumPy arrays
(x_train, y_train), (x_test, y_test) = mnist.load_data()
x_train, x_test = x_train / 255.0, x_test / 255.0
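The division by 255.0 above rescales the 0-255 uint8 pixel values into the [0, 1] range, which tends to make gradient descent better behaved. A minimal NumPy sketch of just that step:

```python
import numpy as np

# uint8 pixels span 0-255; dividing by 255.0 maps them into [0, 1]
# (NumPy promotes the result to float automatically)
img = np.array([[0, 51, 255]], dtype=np.uint8)
scaled = img / 255.0
```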
The model
TensorFlow's APIs come heavily pre-packaged: most building blocks already exist, so you rarely write them yourself; the only hassle is finding the right one.
# Building the model is as simple as listing layers, like PyTorch's
# `Sequential`; the layers execute in order
model = tf.keras.models.Sequential([
    tf.keras.layers.Flatten(input_shape=(28, 28)),
    tf.keras.layers.Dense(128, activation='relu'),
    tf.keras.layers.Dropout(0.2),
    tf.keras.layers.Dense(10, activation='softmax')
])
# Set the optimizer, the loss function, and the metrics to track
model.compile(optimizer='adam',
              loss='sparse_categorical_crossentropy',
              metrics=['accuracy'])
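The "sparse" in sparse_categorical_crossentropy means the labels are plain integer class ids (0-9 for MNIST) rather than one-hot vectors. For a single example the loss is just the negative log of the probability the softmax assigned to the true class; a small sketch of the arithmetic (with hypothetical numbers):

```python
import math

# Softmax output over 3 hypothetical classes, and an integer label
probs = [0.1, 0.7, 0.2]
label = 1  # "sparse" format: the class id itself, not a one-hot vector

# Cross-entropy for one example: -log(probability of the true class)
loss = -math.log(probs[label])
```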
Running it
Running it is even simpler:
# Train
model.fit(x_train, y_train, epochs=5)
# Evaluate
model.evaluate(x_test, y_test, verbose=2)
Of course, you can also plug in callbacks to control early stopping, model checkpointing, learning-rate scheduling, and so on.
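For instance, early stopping just watches a monitored metric and halts training once it stops improving for a set number of epochs (the "patience"). A stand-alone plain-Python sketch of that logic; in Keras the real version is tf.keras.callbacks.EarlyStopping, passed via model.fit(..., callbacks=[...]):

```python
# Hypothetical stand-alone version of early-stopping logic; the Keras
# callback hooks the same idea into model.fit automatically
class EarlyStopper:
    def __init__(self, patience=3):
        self.patience = patience      # epochs to tolerate without improvement
        self.best = float('inf')
        self.wait = 0

    def should_stop(self, val_loss):
        """Call once per epoch; returns True when training should stop."""
        if val_loss < self.best:
            self.best = val_loss
            self.wait = 0
            return False
        self.wait += 1
        return self.wait >= self.patience
```

With patience=2, a validation-loss sequence of 1.0, 0.9, 0.95, 0.92 would trigger a stop after the fourth epoch, since the best value 0.9 went unbeaten for two epochs in a row.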
Expert mode
Expert mode exposes a lot more that we can configure ourselves; let's walk through it step by step.
Preparing the data
Note that TensorFlow does not consume data as <x, y> pairs the way PyTorch does; you hand it a column of x values and a column of y values separately.
# Imports
import tensorflow as tf
import numpy as np
from tensorflow.keras.layers import Dense, Flatten, Conv2D
from tensorflow.keras import Model
# Build the simplest possible dataset
x = [i for i in range(1, 10000)]
y = [(i+1)*(i+2) for i in x]
# Convert to NumPy (float32, one feature per row) before turning the
# arrays into TensorFlow tensors; Dense layers expect rank-2 float input
x = np.array(x, dtype=np.float32).reshape(-1, 1)
y = np.array(y, dtype=np.float32).reshape(-1, 1)
# This plays roughly the role of PyTorch's Dataset
# batch(1000) is the batch size
# shuffle(5000) randomizes the order using a buffer of 5000 elements
# map(fun) lets you post-process elements as they come out
# fun's parameters must match the structure of from_tensor_slices((x, y));
# what it returns is up to you
def fun(x, y):
    return x, x+2, y
# ds is where TensorFlow keeps the data, similar to PyTorch's DataLoader;
# you can iterate over it with a for loop to inspect it
ds = tf.data.Dataset.from_tensor_slices((x, y)).batch(1000).shuffle(5000).map(fun)
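The key point about map is structural: fun receives exactly the element structure produced by from_tensor_slices((x, y)), and whatever it returns becomes the new element structure. A plain-Python analogue of that contract (tf.data does the same thing lazily on tensors, and here per batch, since map comes after batch):

```python
# Plain-Python analogue of from_tensor_slices((x, y)).map(fun):
# each element is an (x_i, y_i) pair, and fun's return value
# defines the new element structure (here a 3-tuple)
x = [1.0, 2.0, 3.0]
y = [6.0, 12.0, 20.0]

def fun(x, y):
    return x, x + 2, y

ds = [fun(a, b) for a, b in zip(x, y)]
```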
The model
The model is much like PyTorch's, except you must override the inherited call method, which is the forward pass.
class MyModel(Model):
    def __init__(self):
        super(MyModel, self).__init__()
        self.d = Dense(32, activation='relu')
        self.flatten = Flatten()
        self.d1 = Dense(32, activation='relu')
        self.d2 = Dense(1)  # a single output, since each y is one value

    # map() above emits two x tensors, so call receives two parameters;
    # accept training= so Keras can pass it through
    def call(self, x, x1, training=None):
        x = self.d(x)
        x = self.flatten(x)
        x1 = self.d1(x1)
        return self.d2(x + x1)

# Create a model instance
model = MyModel()
# Then define the loss function and the optimizer
loss_object = tf.keras.losses.MeanSquaredError()
optimizer = tf.keras.optimizers.Adam()
# A metric that accumulates the loss
train_loss = tf.keras.metrics.Mean(name='train_loss')
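tf.keras.metrics.Mean simply keeps a running average of every value fed into it until it is reset, which is why the training loop resets it at the start of each epoch. A plain-Python sketch of the behavior:

```python
# Plain-Python sketch of what tf.keras.metrics.Mean does: accumulate
# values across train steps, report their average, reset each epoch
class Mean:
    def __init__(self):
        self.total = 0.0
        self.count = 0

    def __call__(self, value):      # mirrors train_loss(loss)
        self.total += value
        self.count += 1

    def result(self):               # mirrors train_loss.result()
        return self.total / self.count

    def reset_states(self):         # mirrors train_loss.reset_states()
        self.total, self.count = 0.0, 0
```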
Running it
You usually define a training-step function and decorate it with @tf.function; this compiles it into a graph and speeds up training, but ordinary print statements inside it can no longer show the data flowing through.
@tf.function
def train_step(x, x1, y):
    with tf.GradientTape() as tape:
        # training=True is only needed if there are layers with different
        # behavior during training versus inference (e.g. Dropout).
        predictions = model(x, x1, training=True)
        loss = loss_object(y, predictions)
    gradients = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(gradients, model.trainable_variables))
    train_loss(loss)

# Run a few epochs
EPOCHS = 500
for epoch in range(EPOCHS):
    # Reset the metrics at the start of the next epoch
    train_loss.reset_states()
    for x, x1, y in ds:
        train_step(x, x1, y)
    print(f'Epoch {epoch + 1}, '
          f'Loss: {train_loss.result()}, ')
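What GradientTape plus apply_gradients does each step is ordinary gradient descent: compute the loss, differentiate it with respect to the trainable variables, and move them a small step downhill. A hand-rolled sketch on a one-parameter loss (w - 3)^2, with the gradient written out by hand instead of taped automatically:

```python
# Minimal gradient-descent loop, the same pattern train_step implements
w = 5.0          # the single "trainable variable"
lr = 0.1         # learning rate, what the Adam optimizer adapts for you
for _ in range(100):
    grad = 2 * (w - 3.0)   # d/dw of the loss (w - 3)^2
    w -= lr * grad
# w ends up close to the minimum at 3.0
```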
That's basically it. TensorFlow's API is extremely rich; once you know your way around it, many functions never need to be written by hand, which is quite convenient, but you do have to learn what's available.