4 custom model training
2022-06-26 15:58:00 【X1996_】
The workflow: build the model (the forward pass of the neural network) --> define the loss function --> define the optimizer --> open a gradient tape --> run the model to get the predictions --> compute the loss from the predictions --> backpropagate to get the gradients --> use the optimizer to apply the gradients to the variables
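The steps above can be sketched on a toy problem before turning to the full examples. This is only a minimal illustration: the data (`y = 3 * x`), the one-weight `Dense` model, the `train_step` helper and the learning rate are assumptions made for this sketch and are not part of the original post.

```python
import tensorflow as tf

# Hypothetical toy data: y = 3 * x, so the single weight should approach 3.
x = tf.constant([[1.0], [2.0], [3.0]])
y = tf.constant([[3.0], [6.0], [9.0]])

# Build the model (forward pass): one weight, no bias.
model = tf.keras.Sequential([tf.keras.layers.Dense(1, use_bias=False)])
# Define the loss function.
loss_fn = tf.keras.losses.MeanSquaredError()
# Define the optimizer.
optimizer = tf.keras.optimizers.SGD(learning_rate=0.05)


def train_step(x_batch, y_batch):
    # Open the tape, run the forward pass and compute the loss under it.
    with tf.GradientTape() as tape:
        pred = model(x_batch, training=True)
        loss = loss_fn(y_batch, pred)
    # Backpropagation: gradients of the loss w.r.t. the trainable weights.
    grads = tape.gradient(loss, model.trainable_weights)
    # The optimizer applies the gradients to the variables.
    optimizer.apply_gradients(zip(grads, model.trainable_weights))
    return float(loss)


losses = [train_step(x, y) for _ in range(50)]
learned_weight = float(model.trainable_weights[0].numpy()[0, 0])
print(losses[0], losses[-1], learned_weight)
```

The loss should shrink from step to step and the learned weight should end up close to 3. The two sections below follow the same pattern with a larger model and batched data.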
Custom model training without an evaluation function
import numpy as np
import tensorflow as tf
data = np.random.random((1000, 32))
labels = np.random.random((1000, 10))
class MyModel(tf.keras.Model):
    def __init__(self, num_classes=10):
        super(MyModel, self).__init__(name='my_model')
        self.num_classes = num_classes
        # Define the layers you need
        self.dense_1 = tf.keras.layers.Dense(32, activation='relu')
        # No activation on the last layer, so the model outputs raw logits
        self.dense_2 = tf.keras.layers.Dense(num_classes)

    def call(self, inputs):
        # Define the forward pass,
        # using the layers defined in `__init__`
        x = self.dense_1(inputs)
        return self.dense_2(x)
model = MyModel(num_classes=10)
# Instantiate an optimizer.
optimizer = tf.keras.optimizers.SGD(learning_rate=1e-3)
# Instantiate a loss function. The model outputs raw logits, so pass from_logits=True.
loss_fn = tf.keras.losses.CategoricalCrossentropy(from_logits=True)
# Prepare the training dataset.
batch_size = 64
train_dataset = tf.data.Dataset.from_tensor_slices((data, labels))
train_dataset = train_dataset.shuffle(buffer_size=1024).batch(batch_size)
# Training: the outer loop runs over epochs, the inner loop over batches of batch_size samples.
# In each step the tape computes the gradients and the optimizer uses them to update the weights.
epochs = 10
for epoch in range(epochs):
    # print('Start of epoch %d' % (epoch,))
    # Iterate over the dataset one batch at a time
    for step, (x_batch_train, y_batch_train) in enumerate(train_dataset):
        # Open a GradientTape to record the operations run during the forward pass; this enables automatic differentiation.
        with tf.GradientTape() as tape:
            # Run the forward pass of the model. The operations the model applies to its input are recorded on the GradientTape.
            logits = model(x_batch_train, training=True)  # Predictions for this minibatch
            # Compute the loss value for this minibatch
            loss_value = loss_fn(y_batch_train, logits)
        # Use the GradientTape to get the gradients of the loss with respect to the trainable variables.
        grads = tape.gradient(loss_value, model.trainable_weights)
        # Run one step of gradient descent: update the variables to reduce the loss.
        optimizer.apply_gradients(zip(grads, model.trainable_weights))
    # Print the loss of the last batch once per epoch.
    print('Epoch: %s Training loss: %s' % (epoch, float(loss_value)))
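Because the last `Dense` layer of `MyModel` has no activation, the model returns raw logits, and the loss has to be configured accordingly. The sketch below illustrates the difference with hand-picked values; the labels and logits are hypothetical and chosen only for this illustration.

```python
import tensorflow as tf

# Hypothetical one-hot labels and raw (unnormalized) model outputs.
y_true = tf.constant([[0.0, 1.0, 0.0], [1.0, 0.0, 0.0]])
logits = tf.constant([[1.0, 3.0, 0.5], [2.0, 1.0, 0.5]])

loss_on_logits = tf.keras.losses.CategoricalCrossentropy(from_logits=True)
loss_on_probs = tf.keras.losses.CategoricalCrossentropy()  # from_logits=False

# These two agree: softmax is applied either inside the loss or by hand.
loss_from_logits = float(loss_on_logits(y_true, logits))
loss_from_probs = float(loss_on_probs(y_true, tf.nn.softmax(logits)))

# Feeding raw logits to the default loss treats them as probabilities
# and silently yields a different value.
loss_misused = float(loss_on_probs(y_true, logits))

print(loss_from_logits, loss_from_probs, loss_misused)
```

The first two values should match up to floating-point error, while the third differs. Training still runs in the misused case, but the gradients no longer correspond to the intended cross-entropy.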
Adding an evaluation function
import numpy as np
import tensorflow as tf
x_train = np.random.random((1000, 32))
y_train = np.random.random((1000, 10))
x_val = np.random.random((200, 32))
y_val = np.random.random((200, 10))
x_test = np.random.random((200, 32))
y_test = np.random.random((200, 10))
class MyModel(tf.keras.Model):
    def __init__(self, num_classes=10):
        super(MyModel, self).__init__(name='my_model')
        self.num_classes = num_classes
        # Define the layers you need
        self.dense_1 = tf.keras.layers.Dense(32, activation='relu')
        # No activation on the last layer, so the model outputs raw logits
        self.dense_2 = tf.keras.layers.Dense(num_classes)

    def call(self, inputs):
        # Define the forward pass,
        # using the layers defined in `__init__`
        x = self.dense_1(inputs)
        return self.dense_2(x)
# Optimizer
optimizer = tf.keras.optimizers.SGD(learning_rate=1e-3)
# Loss function
loss_fn = tf.keras.losses.CategoricalCrossentropy(from_logits=True)
# Prepare the metrics
train_acc_metric = tf.keras.metrics.CategoricalAccuracy()
val_acc_metric = tf.keras.metrics.CategoricalAccuracy()
# Prepare the training data set
batch_size = 64
train_dataset = tf.data.Dataset.from_tensor_slices((x_train, y_train))
train_dataset = train_dataset.shuffle(buffer_size=1024).batch(batch_size)
# Prepare the validation dataset
val_dataset = tf.data.Dataset.from_tensor_slices((x_val, y_val))
val_dataset = val_dataset.batch(64)
model = MyModel(num_classes=10)
epochs = 10
for epoch in range(epochs):
    print('Start of epoch %d' % (epoch,))
    # Iterate over the dataset one batch at a time
    for step, (x_batch_train, y_batch_train) in enumerate(train_dataset):
        # One batch
        with tf.GradientTape() as tape:
            logits = model(x_batch_train, training=True)
            loss_value = loss_fn(y_batch_train, logits)
        grads = tape.gradient(loss_value, model.trainable_weights)
        optimizer.apply_gradients(zip(grads, model.trainable_weights))
        # Update the training metric with this batch
        train_acc_metric.update_state(y_batch_train, logits)
    # Read the training metric at the end of every epoch.
    train_acc = train_acc_metric.result()
    # print('Training acc over epoch: %s' % (float(train_acc),))
    # Reset the training metric at the end of every epoch; otherwise it keeps accumulating across epochs.
    # (Releases before TF 2.5 name this method reset_states().)
    train_acc_metric.reset_state()
    # Run the validation set at the end of every epoch.
    for x_batch_val, y_batch_val in val_dataset:
        val_logits = model(x_batch_val, training=False)
        # Update the validation metric
        val_acc_metric.update_state(y_batch_val, val_logits)
    val_acc = val_acc_metric.result()
    # print('Validation acc: %s' % (float(val_acc),))
    val_acc_metric.reset_state()
    print('Training_losses: %s Training_acc: %s Validation_acc: %s' % (float(loss_value), float(train_acc), float(val_acc)))
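The metric objects in the loop above are stateful: they accumulate over every batch they see until they are reset. A small sketch with hand-made values shows this update / result / reset cycle. The labels and predictions here are hypothetical, and the sketch assumes a TensorFlow release in which the reset method is named `reset_state()` (older releases call it `reset_states()`).

```python
import tensorflow as tf

metric = tf.keras.metrics.CategoricalAccuracy()

# Batch 1: both predictions hit the true class.
metric.update_state(tf.constant([[0.0, 1.0], [1.0, 0.0]]),
                    tf.constant([[0.2, 0.8], [0.9, 0.1]]))
# Batch 2: only the first prediction hits the true class.
metric.update_state(tf.constant([[0.0, 1.0], [0.0, 1.0]]),
                    tf.constant([[0.3, 0.7], [0.6, 0.4]]))
# The result covers everything seen since the last reset: 3 correct out of 4.
running_acc = float(metric.result())

# After a reset the metric starts counting from scratch.
metric.reset_state()
metric.update_state(tf.constant([[1.0, 0.0]]), tf.constant([[0.1, 0.9]]))
acc_after_reset = float(metric.result())  # 0 correct out of 1

print(running_acc, acc_after_reset)
```

`CategoricalAccuracy` compares the argmax of the labels with the argmax of the predictions, so it gives the same answer for raw logits as for softmax probabilities. That is why the training loop can pass `logits` to it directly.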