4 custom model training
2022-06-26 15:58:00 【X1996_】
Build the model (the network's forward pass) --> define the loss function --> define the optimizer --> open a GradientTape --> run the model to get predictions --> compute the loss from the forward pass --> backpropagate --> let the optimizer apply the computed gradients to the variables.
Custom model training without an evaluation metric
import numpy as np
import tensorflow as tf

# Random toy data: 1000 samples with 32 features and 10-dimensional targets.
data = np.random.random((1000, 32))
labels = np.random.random((1000, 10))

class MyModel(tf.keras.Model):
    def __init__(self, num_classes=10):
        super(MyModel, self).__init__(name='my_model')
        self.num_classes = num_classes
        # Define the layers you need.
        self.dense_1 = tf.keras.layers.Dense(32, activation='relu')
        self.dense_2 = tf.keras.layers.Dense(num_classes)

    def call(self, inputs):
        # Define the forward pass using the layers defined in `__init__`.
        x = self.dense_1(inputs)
        return self.dense_2(x)

model = MyModel(num_classes=10)

# Instantiate an optimizer.
optimizer = tf.keras.optimizers.SGD(learning_rate=1e-3)

# Instantiate a loss function. dense_2 has no activation, so the model
# outputs raw logits and the loss needs from_logits=True.
loss_fn = tf.keras.losses.CategoricalCrossentropy(from_logits=True)

# Prepare the training dataset.
batch_size = 64
train_dataset = tf.data.Dataset.from_tensor_slices((data, labels))
train_dataset = train_dataset.shuffle(buffer_size=1024).batch(batch_size)

# Training: for each epoch, iterate over the batches, record the forward
# pass on a GradientTape, then apply the resulting gradients.
epochs = 10
for epoch in range(epochs):
    # Iterate over the batches of the dataset.
    for step, (x_batch_train, y_batch_train) in enumerate(train_dataset):
        # Open a GradientTape to record the operations run during the
        # forward pass; this enables automatic differentiation.
        with tf.GradientTape() as tape:
            # Run the forward pass of the model. The operations the model
            # applies to its inputs are recorded on the GradientTape.
            logits = model(x_batch_train, training=True)  # predictions for this minibatch
            # Compute the loss value for this minibatch.
            loss_value = loss_fn(y_batch_train, logits)
        # Use the GradientTape to retrieve the gradients of the loss with
        # respect to the trainable variables.
        grads = tape.gradient(loss_value, model.trainable_weights)
        # Run one step of gradient descent by updating the variables so as
        # to minimize the loss.
        optimizer.apply_gradients(zip(grads, model.trainable_weights))
    # Print the last batch's loss at the end of each epoch.
    print('Epoch %s training loss: %s' % (epoch, float(loss_value)))
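
The loop above runs eagerly, executing one operation at a time in Python. A common refinement, added here as a suggestion rather than something from the original post, is to move the train step into a function decorated with tf.function, which TensorFlow compiles into a graph on first call:

# Sketch: the same train step compiled with tf.function. Assumes `model`,
# `loss_fn`, `optimizer`, `epochs`, and `train_dataset` are defined as above.
@tf.function
def train_step(x, y):
    with tf.GradientTape() as tape:
        logits = model(x, training=True)
        loss_value = loss_fn(y, logits)
    grads = tape.gradient(loss_value, model.trainable_weights)
    optimizer.apply_gradients(zip(grads, model.trainable_weights))
    return loss_value

for epoch in range(epochs):
    for x_batch_train, y_batch_train in train_dataset:
        loss_value = train_step(x_batch_train, y_batch_train)
    print('Epoch %s training loss: %s' % (epoch, float(loss_value)))

The first call traces the Python function into a graph; later calls reuse that graph, which usually makes the per-step overhead much smaller than in eager mode.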
Custom model training with an evaluation metric
import numpy as np
import tensorflow as tf

# Random toy data: training, validation, and test splits.
x_train = np.random.random((1000, 32))
y_train = np.random.random((1000, 10))
x_val = np.random.random((200, 32))
y_val = np.random.random((200, 10))
x_test = np.random.random((200, 32))
y_test = np.random.random((200, 10))

class MyModel(tf.keras.Model):
    def __init__(self, num_classes=10):
        super(MyModel, self).__init__(name='my_model')
        self.num_classes = num_classes
        # Define the layers you need.
        self.dense_1 = tf.keras.layers.Dense(32, activation='relu')
        self.dense_2 = tf.keras.layers.Dense(num_classes)

    def call(self, inputs):
        # Define the forward pass using the layers defined in `__init__`.
        x = self.dense_1(inputs)
        return self.dense_2(x)

# Optimizer.
optimizer = tf.keras.optimizers.SGD(learning_rate=1e-3)
# Loss function (the model outputs raw logits).
loss_fn = tf.keras.losses.CategoricalCrossentropy(from_logits=True)

# Prepare the metrics.
train_acc_metric = tf.keras.metrics.CategoricalAccuracy()
val_acc_metric = tf.keras.metrics.CategoricalAccuracy()

# Prepare the training dataset.
batch_size = 64
train_dataset = tf.data.Dataset.from_tensor_slices((x_train, y_train))
train_dataset = train_dataset.shuffle(buffer_size=1024).batch(batch_size)

# Prepare the validation dataset.
val_dataset = tf.data.Dataset.from_tensor_slices((x_val, y_val))
val_dataset = val_dataset.batch(batch_size)

model = MyModel(num_classes=10)

epochs = 10
for epoch in range(epochs):
    print('Start of epoch %d' % (epoch,))
    # Iterate over the batches of the dataset.
    for step, (x_batch_train, y_batch_train) in enumerate(train_dataset):
        # One training step per batch.
        with tf.GradientTape() as tape:
            logits = model(x_batch_train, training=True)
            loss_value = loss_fn(y_batch_train, logits)
        grads = tape.gradient(loss_value, model.trainable_weights)
        optimizer.apply_gradients(zip(grads, model.trainable_weights))
        # Update the training metric with this batch's predictions.
        train_acc_metric(y_batch_train, logits)

    # Read the accumulated training metric at the end of the epoch.
    train_acc = train_acc_metric.result()
    # Reset the training metric so the next epoch starts from zero.
    train_acc_metric.reset_states()

    # Run a validation loop at the end of each epoch.
    for x_batch_val, y_batch_val in val_dataset:
        val_logits = model(x_batch_val, training=False)
        # Update the validation metric.
        val_acc_metric(y_batch_val, val_logits)
    val_acc = val_acc_metric.result()
    val_acc_metric.reset_states()

    print('Training_losses: %s Training_acc: %s Validation_acc: %s'
          % (float(loss_value), float(train_acc), float(val_acc)))
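
The snippet above defines x_test and y_test but never uses them. A minimal sketch of a final test-set pass, reusing the same metric class, might look like this (my addition, not part of the original post):

# Sketch: evaluate on the held-out test split once training is done.
test_dataset = tf.data.Dataset.from_tensor_slices((x_test, y_test)).batch(batch_size)
test_acc_metric = tf.keras.metrics.CategoricalAccuracy()

for x_batch_test, y_batch_test in test_dataset:
    test_logits = model(x_batch_test, training=False)
    test_acc_metric(y_batch_test, test_logits)

print('Test_acc: %s' % (float(test_acc_metric.result()),))
test_acc_metric.reset_states()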