当前位置:网站首页>基于tensorflow搭建神经网络
基于tensorflow搭建神经网络
2022-07-28 05:22:00 【积雨辋川】
基于tensorflow搭建神经网络
一、tf.keras搭建神经网络步骤
六步法
- import
- train,test
- model = tf.keras.models.Sequential
- model.compile
- model.fit
- model.summary
models.Sequential()
model = tf.keras.models.Sequential ([ 网络结构 ]) #描述各层网络
网络结构举例:
- 拉直层:
tf.keras.layers.Flatten( )
- 全连接层:
tf.keras.layers.Dense(神经元个数, activation= “激活函数”,kernel_regularizer=哪种正则化)
activation(字符串给出)可选: relu、 softmax、 sigmoid 、 tanh
kernel_regularizer可选: tf.keras.regularizers.l1()、tf.keras.regularizers.l2()
- 卷积层:
tf.keras.layers.Conv2D(filters = 卷积核个数, kernel_size = 卷积核尺寸,strides = 卷积步长, padding = " valid" or “same”)
- LSTM层: tf.keras.layers.LSTM()
model.compile()
model.compile(optimizer = 优化器,
loss = 损失函数
metrics = [“准确率”] )
Optimizer可选:
‘sgd’ or tf.keras.optimizers.SGD (lr=学习率,momentum=动量参数)
‘adagrad’ or tf.keras.optimizers.Adagrad (lr=学习率)
‘adadelta’ or tf.keras.optimizers.Adadelta (lr=学习率)
‘adam’ or tf.keras.optimizers.Adam (lr=学习率, beta_1=0.9, beta_2=0.999)
loss可选:
‘mse’ or tf.keras.losses.MeanSquaredError()
‘sparse_categorical_crossentropy’ or tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False)
Metrics可选:
‘accuracy’ :y_和y都是数值,如y_=[1] y=[1]
‘categorical_accuracy’ :y_和y都是独热码(概率分布),如y_=[0,1,0] y=[0.256,0.695,0.048]
‘sparse_categorical_accuracy’ :y_是数值,y是独热码(概率分布),如y_=[1] y=[0.256,0.695,0.048]
model.fit()
model.fit (训练集的输入特征, 训练集的标签,
batch_size= , epochs= ,
validation_data=(验证集的输入特征,验证集的标签),
validation_split=从训练集划分多少比例给验证集,
validation_freq = 多少次epoch验证一次)
model.summary()

二、自定义Model
class MyModel(Model):
def __init__(self):
super(MyModel, self).__init__()
定义网络结构块
def call(self, x):
调用网络结构块,实现前向传播
return y
model = MyModel()
例:
class IrisModel(Model):
def __init__(self):
super(IrisModel, self).__init__()
self.d1 = Dense(3, activation='softmax', kernel_regularizer=tf.keras.regularizers.l2())
def call(self, x):
y = self.d1(x)
return y
model = IrisModel()
三、tf.keras实现手写数字分类
import tensorflow as tf
mnist = tf.keras.datasets.mnist
(x_train, y_train), (x_test, y_test) = mnist.load_data()
x_train, x_test = x_train / 255.0, x_test / 255.0
model = tf.keras.models.Sequential([
tf.keras.layers.Flatten(),
tf.keras.layers.Dense(128, activation='relu'),
tf.keras.layers.Dense(10, activation='softmax')
])
model.compile(optimizer='adam',
loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False),
metrics=['sparse_categorical_accuracy'])
model.fit(x_train, y_train, batch_size=32, epochs=5, validation_data=(x_test, y_test), validation_freq=1)
model.summary()
运行结果如下:
60000/60000 [==============================] - 5s 89us/sample - loss: 0.0455 - sparse_categorical_accuracy: 0.9861 - val_loss: 0.0806 - val_sparse_categorical_accuracy: 0.9752
Model: "sequential"
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
flatten (Flatten) multiple 0
_________________________________________________________________
dense (Dense) multiple 100480
_________________________________________________________________
dense_1 (Dense) multiple 1290
=================================================================
Total params: 101,770
Trainable params: 101,770
Non-trainable params: 0
边栏推荐
- Deep learning (self supervised: Moco V3): An Empirical Study of training self supervised vision transformers
- SQLAlchemy使用相关
- Service reliability guarantee -watchdog
- 深度学习(自监督:CPC v2)——Data-Efficient Image Recognition with Contrastive Predictive Coding
- What are the points for attention in the development and design of high-end atmospheric applets?
- 神经网络实现鸢尾花分类
- Distributed cluster architecture scenario optimization solution: distributed ID solution
- flutter webivew input唤起相机相册
- Automatic scheduled backup of remote MySQL scripts
- What should we pay attention to when making template application of wechat applet?
猜你喜欢

分布式集群架构场景优化解决方案:分布式调度问题

Applet development

微信小程序开发制作注意这几个重点方面

Kotlin语言现在怎么不火了?你怎么看?

The project does not report an error, operates normally, and cannot request services

深度学习(自监督:MoCo v2)——Improved Baselines with Momentum Contrastive Learning

How much does small program development cost? Analysis of two development methods!

深度学习(自监督:SimCLR)——A Simple Framework for Contrastive Learning of Visual Representations

Deep learning (incremental learning) - iccv2022:continuous continuous learning

微信小程序制作模板套用时需要注意什么呢?
随机推荐
What should we pay attention to when making template application of wechat applet?
word2vec和bert的基本使用方法
flutter webivew input唤起相机相册
What are the detailed steps of wechat applet development?
NLP中基于Bert的数据预处理
What is the detail of the applet development process?
The business of digital collections is not so easy to do
深度学习——Patches Are All You Need
Distributed cluster architecture scenario optimization solution: session sharing problem
Installing redis under Linux (centos7)
CertPathValidatorException:validity check failed
Idempotent component
Linux(centOs7) 下安装redis
【5】 Redis master-slave synchronization and redis sentinel (sentinel)
Automatic scheduled backup of remote MySQL scripts
Deep learning (self supervision: Moco V2) -- improved bases with momentum contractual learning
Xshell suddenly failed to connect to the virtual machine
Deep learning (incremental learning) - iccv2022:continuous continuous learning
如何选择小程序开发企业
深度学习(增量学习)——ICCV2022:Contrastive Continual Learning