6 Custom layers
2022-06-26 15:58:00 【X1996_】
The custom layer's class name must not be the same as the name of a built-in Keras layer (hence MyDense rather than Dense).
from sklearn import datasets
import tensorflow as tf
import numpy as np
iris = datasets.load_iris()
data = iris.data
labels = iris.target
# Define a fully connected (dense) layer
class MyDense(tf.keras.layers.Layer):
    def __init__(self, units=32, **kwargs):
        # Call the base-class constructor first so Keras can track the layer's attributes
        super(MyDense, self).__init__(**kwargs)
        self.units = units
    # build() is usually where the layer's trainable weights are created;
    # it runs once, when the shape of the first input is known.
    # trainable=True: the weight is updated during training; False: it stays frozen.
    # Give each weight a name, otherwise saving the model (HDF5) raises an error.
    def build(self, input_shape):
        self.w = self.add_weight(shape=(input_shape[-1], self.units),
                                 initializer='random_normal',
                                 trainable=True,
                                 name='w')
        self.b = self.add_weight(shape=(self.units,),
                                 initializer='random_normal',
                                 trainable=True,
                                 name='b')
        super(MyDense, self).build(input_shape)  # equivalent to setting self.built = True
    # call() defines the forward pass; Layer.__call__ invokes it.
    def call(self, inputs):
        return tf.matmul(inputs, self.w) + self.b
    # A custom layer used in a Functional API model must implement get_config()
    # so the model can be serialized. Without it, the model cannot be saved.
    def get_config(self):
        config = super(MyDense, self).get_config()
        config.update({'units': self.units})
        return config
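To make the build/call split concrete, here is a framework-free NumPy sketch of the same pattern. The class `LazyDense` is hypothetical and only imitates the idea. Weights are created on the first call, once the last dimension of the input is known, and every later call just runs the forward pass `x @ w + b`. It is not the Keras implementation.

```python
import numpy as np


class LazyDense:
    """Hypothetical NumPy stand-in for MyDense, illustrating deferred weight creation."""

    def __init__(self, units=32, seed=0):
        self.units = units
        self.built = False
        self._rng = np.random.default_rng(seed)

    def build(self, input_shape):
        # Weight shapes depend on the input's last dimension, only known at first call.
        # Normal(0, 0.05) loosely imitates Keras's 'random_normal' initializer defaults.
        self.w = self._rng.normal(0.0, 0.05, size=(input_shape[-1], self.units))
        self.b = self._rng.normal(0.0, 0.05, size=(self.units,))
        self.built = True

    def __call__(self, inputs):
        # Mirrors the role of Layer.__call__: build once on first use, then run call().
        if not self.built:
            self.build(inputs.shape)
        return self.call(inputs)

    def call(self, inputs):
        return inputs @ self.w + self.b

    def get_config(self):
        # What is needed to re-create the layer (not its weights).
        return {"units": self.units}


layer = LazyDense(units=16)
print(layer.built)               # False: no weights yet
out = layer(np.ones((5, 4)))     # first call triggers build with input_shape (5, 4)
print(layer.w.shape, out.shape)  # (4, 16) (5, 16)
```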
# Build the model with the Functional API
inputs = tf.keras.Input(shape=(4,))
x = MyDense(units=16)(inputs)  # 16 units (neurons)
x = tf.nn.tanh(x)  # activation applied after the dense layer
x = tf.keras.layers.Dense(8)(x)
x = tf.nn.relu(x)
x = MyDense(units=3)(x)  # one output unit per iris class
predictions = tf.nn.softmax(x)  # class probabilities
model = tf.keras.Model(inputs=inputs, outputs=predictions)
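The functional model above is a chain of affine maps and elementwise activations: 4 inputs → 16 (tanh) → 8 (ReLU) → 3 (softmax). The NumPy sketch below runs the same forward pass and shows the shape at each step and that every softmax row sums to 1. It assumes randomly initialised weights in place of trained ones, and the helper names `dense`, `softmax` and `forward` are illustrative.

```python
import numpy as np


def dense(x, w, b):
    # Same computation as MyDense.call (and Dense without an activation).
    return x @ w + b


def softmax(z):
    # Subtract the row max for numerical stability; each output row sums to 1.
    z = z - z.max(axis=1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=1, keepdims=True)


def forward(x, params):
    (w1, b1), (w2, b2), (w3, b3) = params
    h1 = np.tanh(dense(x, w1, b1))           # (n, 16)
    h2 = np.maximum(dense(h1, w2, b2), 0.0)  # (n, 8), ReLU
    return softmax(dense(h2, w3, b3))        # (n, 3), class probabilities


rng = np.random.default_rng(0)
sizes = [4, 16, 8, 3]
params = [(rng.normal(0, 0.05, (m, n)), rng.normal(0, 0.05, n))
          for m, n in zip(sizes[:-1], sizes[1:])]
probs = forward(rng.normal(size=(10, 4)), params)
print(probs.shape)  # (10, 3)
```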
# Shuffle: append the labels as an extra column so features and labels are shuffled together
data = np.concatenate((data,labels.reshape(150,1)),axis=-1)
np.random.shuffle(data)
labels = data[:,-1]
data = data[:,:4]
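Concatenating the labels onto the feature matrix keeps each sample paired with its label during `np.random.shuffle`. It also converts the integer labels to floats, so they come back as 0.0, 1.0 and 2.0. An alternative is to shuffle both arrays with one shared permutation of row indices, which keeps the pairs aligned and leaves the label dtype unchanged. The helper `shuffle_together` and the toy arrays in this sketch are hypothetical.

```python
import numpy as np


def shuffle_together(features, targets, seed=None):
    """Shuffle two arrays with one shared permutation (hypothetical helper)."""
    if len(features) != len(targets):
        raise ValueError("features and targets must have the same length")
    rng = np.random.default_rng(seed)
    perm = rng.permutation(len(features))  # one random ordering of the row indices
    return features[perm], targets[perm]   # same ordering applied to both arrays


# Toy data: row i holds the features [i, i] and the integer label i % 3.
toy_x = np.repeat(np.arange(6, dtype=float)[:, None], 2, axis=1)
toy_y = np.arange(6) % 3
shuffled_x, shuffled_y = shuffle_together(toy_x, toy_y, seed=42)
print(shuffled_y.dtype.kind)  # i (the labels are still integers)
```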
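# Optimizer: Adam
# Loss function: sparse categorical cross-entropy
# Metric: accuracy
# The model already ends in tf.nn.softmax, so it outputs probabilities, not logits:
# from_logits=False avoids applying softmax a second time.
model.compile(optimizer=tf.keras.optimizers.Adam(),
              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False),
              metrics=[tf.keras.metrics.SparseCategoricalAccuracy()])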
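The `from_logits` flag must match what the model outputs. Because the model ends in `tf.nn.softmax`, its outputs are already probabilities. With `from_logits=True` the loss would softmax them again, which flattens the distribution and distorts the loss. The simplified, per-sample NumPy sketch below compares the two settings. The helpers `softmax` and `sparse_ce` are illustrative and do not reproduce Keras's batching or clipping.

```python
import numpy as np


def softmax(z):
    z = z - z.max(axis=-1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)


def sparse_ce(outputs, label, from_logits):
    # Meaning of the flag: logits are passed through softmax first,
    # probabilities are used as they are.
    probs = softmax(outputs) if from_logits else outputs
    return -np.log(probs[label])


logits = np.array([2.0, 1.0, 0.1])
probs = softmax(logits)  # what the model's final tf.nn.softmax produces

correct = sparse_ce(probs, 0, from_logits=False)  # softmax applied once
double = sparse_ce(probs, 0, from_logits=True)    # softmax applied twice
print(round(float(correct), 3), round(float(double), 3))
```

The doubled version reports a noticeably higher loss for the same correct prediction, and its gradients are correspondingly weakened.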
# Training
model.fit(data, labels, batch_size=32, epochs=100,shuffle=True)
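With 150 samples and `batch_size=32`, each epoch is split into ceil(150 / 32) = 5 batches: four of 32 samples and a final one of 22. So 100 epochs amount to 500 optimizer updates. The helper names in this small arithmetic sketch are illustrative.

```python
import math


def batches_per_epoch(num_samples, batch_size):
    # The last batch is smaller when num_samples is not a multiple of batch_size.
    return math.ceil(num_samples / batch_size)


def last_batch_size(num_samples, batch_size):
    remainder = num_samples % batch_size
    return remainder if remainder else batch_size


steps = batches_per_epoch(150, 32)
print(steps, last_batch_size(150, 32), steps * 100)  # 5 22 500
```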
Show the network structure
model.summary()
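Each dense layer (MyDense or `tf.keras.layers.Dense`) has in_dim × units weights plus units biases. The `tf.nn` activation ops add no parameters. The sketch below computes the trainable-parameter count that `model.summary()` is expected to report for this 4 → 16 → 8 → 3 network. `dense_params` is an illustrative helper.

```python
def dense_params(in_dim, units):
    # weight matrix (in_dim x units) + bias vector (units)
    return in_dim * units + units


layer_sizes = [4, 16, 8, 3]  # input, MyDense(16), Dense(8), MyDense(3)
per_layer = [dense_params(m, n) for m, n in zip(layer_sizes[:-1], layer_sizes[1:])]
print(per_layer, sum(per_layer))  # [80, 136, 27] 243
```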
Save the model
model.save('keras_model_tf_version.h5')
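The saved file can only record a custom layer by its class name and the dictionary returned by `get_config()`. Loading therefore has to turn the string "MyDense" back into a Python class, which is why the loading step needs `custom_objects`. The pure-Python sketch below imitates that round trip. `ToyDense`, `ToyMyDense`, `BUILTIN_LAYERS`, `serialize_layer` and `deserialize_layer` are hypothetical stand-ins, not Keras internals, and the real file also stores the weights and the network topology. The sketch also shows why `get_config()` must include `units`: a config without it would rebuild the layer with the default of 32 units.

```python
class ToyDense:
    """Hypothetical stand-in for a built-in layer class."""

    def __init__(self, units=32, name=None):
        self.units = units
        self.name = name

    def get_config(self):
        return {"units": self.units, "name": self.name}

    @classmethod
    def from_config(cls, config):
        # Keras's default Layer.from_config is essentially cls(**config) as well.
        return cls(**config)


class ToyMyDense(ToyDense):
    """Hypothetical stand-in for the custom MyDense layer."""


BUILTIN_LAYERS = {"Dense": ToyDense}  # hypothetical registry of built-in names


def serialize_layer(layer, class_name):
    # What is stored per layer: a class-name string plus the get_config() dict.
    return {"class_name": class_name, "config": layer.get_config()}


def deserialize_layer(record, custom_objects=None):
    # The class name must resolve to a Python class, built-in or user-supplied.
    lookup = {**BUILTIN_LAYERS, **(custom_objects or {})}
    cls = lookup.get(record["class_name"])
    if cls is None:
        raise ValueError(f"Unknown layer: {record['class_name']}")
    return cls.from_config(record["config"])


record = serialize_layer(ToyMyDense(units=16, name="my_dense"), "MyDense")
restored = deserialize_layer(record, custom_objects={"MyDense": ToyMyDense})
print(type(restored).__name__, restored.units)  # ToyMyDense 16
```

Without the `custom_objects` mapping, the lookup fails with `ValueError`. This mirrors why `load_model` cannot rebuild MyDense on its own.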
Load the model and predict
# Before loading, map the custom layer's class name to its class in custom_objects.
# The MyDense class must already be defined (or imported) in the current session.
_custom_objects = {
"MyDense" : MyDense
}
new_model = tf.keras.models.load_model("keras_model_tf_version.h5",custom_objects=_custom_objects)
y_pred = new_model.predict(data)
np.argmax(y_pred,axis=1)
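`np.argmax(y_pred, axis=1)` turns each row of class probabilities into a predicted class index. Comparing those indices with the true labels gives the accuracy. The labels in this script were turned into floats by the concatenation step, so the sketch casts them back to integers first. The helper `accuracy_from_probs` and the toy arrays are illustrative only.

```python
import numpy as np


def accuracy_from_probs(y_pred, labels):
    predicted = np.argmax(y_pred, axis=1)  # most likely class per row
    return float(np.mean(predicted == np.asarray(labels).astype(int)))


toy_probs = np.array([[0.7, 0.2, 0.1],
                      [0.1, 0.8, 0.1],
                      [0.3, 0.3, 0.4],
                      [0.6, 0.3, 0.1]])
toy_labels = np.array([0.0, 1.0, 2.0, 1.0])  # float labels, as after the shuffle step
print(accuracy_from_probs(toy_probs, toy_labels))  # 0.75
```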