当前位置:网站首页>基于tensorflow搭建神经网络
基于tensorflow搭建神经网络
2022-07-28 05:22:00 【积雨辋川】
基于tensorflow搭建神经网络
一、tf.keras搭建神经网络步骤
六步法
- import
- train,test
- model = tf.keras.models.Sequential
- model.compile
- model.fit
- model.summary
models.Sequential()
model = tf.keras.models.Sequential ([ 网络结构 ]) #描述各层网络
网络结构举例:
- 拉直层:
tf.keras.layers.Flatten( )
- 全连接层:
tf.keras.layers.Dense(神经元个数, activation= “激活函数”,kernel_regularizer=哪种正则化)
activation(字符串给出)可选: relu、 softmax、 sigmoid 、 tanh
kernel_regularizer可选: tf.keras.regularizers.l1()、tf.keras.regularizers.l2()
- 卷积层:
tf.keras.layers.Conv2D(filters = 卷积核个数, kernel_size = 卷积核尺寸,strides = 卷积步长, padding = " valid" or “same”)
- LSTM层: tf.keras.layers.LSTM()
model.compile()
model.compile(optimizer = 优化器,
loss = 损失函数
metrics = [“准确率”] )
Optimizer可选:
‘sgd’ or tf.keras.optimizers.SGD (lr=学习率,momentum=动量参数)
‘adagrad’ or tf.keras.optimizers.Adagrad (lr=学习率)
‘adadelta’ or tf.keras.optimizers.Adadelta (lr=学习率)
‘adam’ or tf.keras.optimizers.Adam (lr=学习率, beta_1=0.9, beta_2=0.999)
loss可选:
‘mse’ or tf.keras.losses.MeanSquaredError()
‘sparse_categorical_crossentropy’ or tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False)
Metrics可选:
‘accuracy’ :y_和y都是数值,如y_=[1] y=[1]
‘categorical_accuracy’ :y_和y都是独热码(概率分布),如y_=[0,1,0] y=[0.256,0.695,0.048]
‘sparse_categorical_accuracy’ :y_是数值,y是独热码(概率分布),如y_=[1] y=[0.256,0.695,0.048]
model.fit()
model.fit (训练集的输入特征, 训练集的标签,
batch_size= , epochs= ,
validation_data=(验证集的输入特征,验证集的标签),
validation_split=从训练集划分多少比例给验证集,
validation_freq = 多少次epoch验证一次)
model.summary()

二、自定义Model
class MyModel(Model):
def __init__(self):
super(MyModel, self).__init__()
定义网络结构块
def call(self, x):
调用网络结构块,实现前向传播
return y
model = MyModel()
例:
class IrisModel(Model):
def __init__(self):
super(IrisModel, self).__init__()
self.d1 = Dense(3, activation='softmax', kernel_regularizer=tf.keras.regularizers.l2())
def call(self, x):
y = self.d1(x)
return y
model = IrisModel()
三、tf.keras实现手写数字分类
import tensorflow as tf
mnist = tf.keras.datasets.mnist
(x_train, y_train), (x_test, y_test) = mnist.load_data()
x_train, x_test = x_train / 255.0, x_test / 255.0
model = tf.keras.models.Sequential([
tf.keras.layers.Flatten(),
tf.keras.layers.Dense(128, activation='relu'),
tf.keras.layers.Dense(10, activation='softmax')
])
model.compile(optimizer='adam',
loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False),
metrics=['sparse_categorical_accuracy'])
model.fit(x_train, y_train, batch_size=32, epochs=5, validation_data=(x_test, y_test), validation_freq=1)
model.summary()
运行结果如下:
60000/60000 [==============================] - 5s 89us/sample - loss: 0.0455 - sparse_categorical_accuracy: 0.9861 - val_loss: 0.0806 - val_sparse_categorical_accuracy: 0.9752
Model: "sequential"
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
flatten (Flatten) multiple 0
_________________________________________________________________
dense (Dense) multiple 100480
_________________________________________________________________
dense_1 (Dense) multiple 1290
=================================================================
Total params: 101,770
Trainable params: 101,770
Non-trainable params: 0
边栏推荐
- Digital collections "chaos", 100 billion market change is coming?
- Service reliability guarantee -watchdog
- XML parsing entity tool class
- Automatic scheduled backup of remote MySQL scripts
- transformer的理解
- 【1】 Introduction to redis
- tf.keras搭建神经网络功能扩展
- Linux(centOs7) 下安装redis
- 速查表之转MD5
- Deep learning (self supervision: simple Siam) -- Exploring simple Siamese representation learning
猜你喜欢

3: MySQL master-slave replication setup

Service reliability guarantee -watchdog

Applet development

Distributed cluster architecture scenario optimization solution: distributed ID solution

深度学习(自监督:CPC v2)——Data-Efficient Image Recognition with Contrastive Predictive Coding

How much does it cost to make a small program mall? What are the general expenses?

面试官:让你设计一套图片加载框架,你会怎么设计?

高端大气的小程序开发设计有哪些注意点?

微信小程序开发费用制作费用是多少?

深度学习(自监督:MoCo V3):An Empirical Study of Training Self-Supervised Vision Transformers
随机推荐
uView上传组件upload上传auto-upload模式图片压缩
强化学习——价值学习中的SARSA
Record the problems encountered in online capacity expansion server nochange: partition 1 is size 419428319. It cannot be grown
transformer的理解
mysql5.6(根据.ibd,.frm文件)恢复单表数据
【3】 Redis features and functions
Utils commonly used in NLP
Deep learning (self supervision: Moco V2) -- improved bases with momentum contractual learning
Alpine, Debian replacement source
Quick look-up table to MD5
神经网络实现鸢尾花分类
Micro service architecture cognition and service governance Eureka
小程序搭建制作流程是怎样的?
svn incoming内容无法更新下来,且提交报错:svn: E155015: Aborting commit: XXX remains in conflict
Regular verification rules of wechat applet mobile number
Deep learning - patches are all you need
2: Why read write separation
Kotlin语言现在怎么不火了?你怎么看?
uniapp webview监听页面加载后回调
vscode uniapp