Machine Learning Notes - Spatial Transformer Networks Using TensorFlow
2022-06-11 09:08:00 【Sit and watch the clouds rise】
1. Overview
A machine learning model does not look at a digit the way we do; it takes the full arrangement of pixels into account. After all, a machine learning model is just matrices, and the values of these weight matrices are shaped by the complete pixel layout of the input data.
For example, consider the two images below. For the unrotated digit, the CNN's prediction is correct, but the prediction may fail once the image is rotated. This shows that a CNN is not rotation invariant. That may not matter much in simple tests or experiments, but the variability of real-world inputs can cause many problems in deployment.
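As a quick illustration, here is a minimal sketch of this failure mode. The MNIST digits and the trained classifier `model` are assumptions for demonstration purposes, not part of this tutorial's code:

# minimal sketch: rotate a test digit and compare a plain CNN's predictions;
# `model` is a hypothetical trained Keras classifier
import tensorflow as tf

(_, _), (xTest, _) = tf.keras.datasets.mnist.load_data()
digit = tf.convert_to_tensor(xTest[:1, ..., None], dtype=tf.float32) / 255.0
rotated = tf.image.rot90(digit, k=1)  # the same digit, rotated 90 degrees

# predUpright = model.predict(digit).argmax(axis=-1)
# predRotated = model.predict(rotated).argmax(axis=-1)
# the two predictions often disagree, because the CNN is not rotation invariant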
Researchers in the field of deep learning quickly came up with a solution: Spatial Transformer Networks.


The main purpose behind the Spatial Transformer network module is to help our model select the most relevant ROI (region of interest) of an image. Once the model has successfully located the relevant pixels, the spatial transformer module helps it decide what transformation the image needs to be brought into a standard format.
The model must find a transformation that lets it predict the correct label for the image. This may not be a transformation we can interpret, but as long as it reduces the loss function, it works for the model.
The module is built so that the model can apply a range of transformations as the situation demands: translation, cropping, isotropic scaling, and skewing. You will learn more about these in the next section.
Think of the spatial transformer module as an add-on to your model. During the forward pass, it applies a specific spatial transformation to the feature map, conditioned on the particular input; so for a particular input, we get one corresponding output feature map.
It helps the model decide what kind of transformation is required for an input like the digit 7 in Figure 2. For multi-channel inputs (e.g., RGB images), the same transformation is applied to all three channels (to preserve spatial consistency). Most importantly, the module is differentiable, so it is learned along with the rest of the model's weights.
The spatial transformer module, shown below, is divided into three parts: the localization network, the grid generator, and the sampler.

Localization network: This network accepts an input feature map U of width W, height H, and C channels. Its job is to output θ, the transformation parameters to be applied to the feature map. The localization network can be anything: a fully connected network or a convolutional network.
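To make θ concrete, here is a small illustrative example (not from the tutorial's code). θ is a 2×3 affine matrix applied to homogeneous grid coordinates (x, y, 1); the identity matrix leaves a coordinate unchanged, while a rotation matrix produces a rotated sampling coordinate:

import numpy as np

thetaIdentity = np.array([[1.0, 0.0, 0.0],
                          [0.0, 1.0, 0.0]])

angle = np.deg2rad(45.0)
thetaRotate = np.array([[np.cos(angle), -np.sin(angle), 0.0],
                        [np.sin(angle),  np.cos(angle), 0.0]])

point = np.array([0.5, 0.5, 1.0])  # a homogeneous (x, y, 1) grid coordinate
print(thetaIdentity @ point)       # [0.5 0.5], unchanged
print(thetaRotate @ point)         # [0. 0.70710678], rotated by 45 degrees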
Parameterized sampling grid: We now have the transformation parameters θ. Suppose our input feature map U is as shown in Figure 4; we can see that U is a rotated version of the digit 9. Our output feature map V is a square grid, so we already know its indices (i.e., a normal rectangular grid). The grid generator passes each of these output coordinates through θ to find the corresponding source location in U.

Sampler: Now that we have the source coordinates, the values of the output feature map V are estimated using only the input pixel values. In other words, we perform linear/bilinear interpolation from the input pixels to fill in the output pixels. Bilinear interpolation uses the nearest pixel values, located diagonally around the given position, to compute an appropriate intensity value for that pixel.
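Here is a worked example with made-up numbers (purely illustrative) that uses the same corner weights the bilinear_sampler function computes later in this post:

# sample at (x, y) = (1.3, 2.6), surrounded by four integer corner points
x, y = 1.3, 2.6
x0, y0, x1, y1 = 1, 2, 2, 3
Ia, Ib, Ic, Id = 10.0, 20.0, 30.0, 40.0  # values at (x0,y0), (x0,y1), (x1,y0), (x1,y1)

wa = (x1 - x) * (y1 - y)  # 0.28
wb = (x1 - x) * (y - y0)  # 0.42
wc = (x - x0) * (y1 - y)  # 0.12
wd = (x - x0) * (y - y0)  # 0.18

# the weights sum to 1; the interpolated value is 22.0
print(wa * Ia + wb * Ib + wc * Ic + wd * Id)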
2. Creating the Configuration File
# import the necessary packages
from tensorflow.data import AUTOTUNE
import os
# define AUTOTUNE
AUTO = AUTOTUNE
# define the image height, width and channel size
IMAGE_HEIGHT = 28
IMAGE_WIDTH = 28
CHANNEL = 1
# define the dataset path, dataset name, and the batch size
DATASET_PATH = "dataset"
DATASET_NAME = "emnist"
BATCH_SIZE = 1024
# define the number of epochs
EPOCHS = 100
# define the conv filters
FILTERS = 256
# define an output directory
OUTPUT_PATH = "output"
# define the loss function and optimizer
LOSS_FN = "sparse_categorical_crossentropy"
OPTIMIZER = "adam"
# define the name of the gif
GIF_NAME = "stn.gif"
# define the number of classes for classification
CLASSES = 62
# define the stn layer name
STN_LAYER_NAME = "stn"

3. Creating a Callback Function to Export the GIF
This function will help us build our GIF. It also allows us to track the progress of the model during training .
# import the necessary packages
from tensorflow.keras.callbacks import Callback
from tensorflow.keras import Model
import matplotlib.pyplot as plt
def get_train_monitor(testDs, outputPath, stnLayerName):
    # iterate over the test dataset and take a batch of test images
    (testImg, _) = next(iter(testDs))

    # define a training monitor
    class TrainMonitor(Callback):
        def on_epoch_end(self, epoch, logs=None):
            model = Model(self.model.input,
                self.model.get_layer(stnLayerName).output)
            testPred = model(testImg)

            # plot the image and the transformed image
            (_, axes) = plt.subplots(nrows=5, ncols=2, figsize=(5, 10))
            for (ax, im, t_im) in zip(axes, testImg[:5], testPred[:5]):
                ax[0].imshow(im[..., 0], cmap="gray")
                ax[0].set_title(epoch)
                ax[0].axis("off")
                ax[1].imshow(t_im[..., 0], cmap="gray")
                ax[1].set_title(epoch)
                ax[1].axis("off")

            # save the figures
            plt.savefig(f"{outputPath}/{epoch:03d}")
            plt.close()

    # instantiate the training monitor callback
    trainMonitor = TrainMonitor()

    # return the training monitor object
    return trainMonitor
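The callback only saves one PNG figure per epoch; the post does not show the code that assembles them into the GIF. A minimal sketch, assuming the imageio package is installed and using the OUTPUT_PATH and GIF_NAME values from the configuration file:

import glob
import imageio

# collect the per-epoch frames saved by the training monitor and
# write them out as an animated GIF
frames = [imageio.imread(p) for p in sorted(glob.glob("output/*.png"))]
imageio.mimsave("output/stn.gif", frames, duration=0.2)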
4. The Spatial Transformer Module

To attach the spatial transformer module to our main model, we create a separate script that contains all the necessary helper functions and the main layer. The module works in three steps:
- The localization network regresses θ, the required transformation parameters, from the input image.
- The grid generator maps the output feature map's coordinates back onto the input feature map.
- Bilinear interpolation is applied to estimate the pixel values of the output feature map.
# import the necessary packages
from tensorflow.keras import Sequential
from tensorflow.keras.layers import Conv2D
from tensorflow.keras.layers import MaxPool2D
from tensorflow.keras.layers import GlobalAveragePooling2D
from tensorflow.keras.layers import Reshape
from tensorflow.keras.layers import Dense
from tensorflow.keras.layers import Layer
import tensorflow as tf
def get_pixel_value(B, H, W, featureMap, x, y):
    # create batch indices and reshape it
    batchIdx = tf.range(0, B)
    batchIdx = tf.reshape(batchIdx, (B, 1, 1))

    # create the indices matrix which will be used to sample the
    # feature map
    b = tf.tile(batchIdx, (1, H, W))
    indices = tf.stack([b, y, x], 3)

    # gather the feature map values for the corresponding indices
    gatheredPixelValue = tf.gather_nd(featureMap, indices)

    # return the gathered pixel values
    return gatheredPixelValue
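# illustrative note (not in the original script): for B=2 and H=W=3,
# `indices` has shape (2, 3, 3, 3), and tf.gather_nd returns
# featureMap[b, y, x] for every (b, y, x) triple, i.e. one pixel per
# output grid location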
def affine_grid_generator(B, H, W, theta):
    # create normalized 2D grid
    x = tf.linspace(-1.0, 1.0, H)
    y = tf.linspace(-1.0, 1.0, W)
    (xT, yT) = tf.meshgrid(x, y)

    # flatten the meshgrid
    xTFlat = tf.reshape(xT, [-1])
    yTFlat = tf.reshape(yT, [-1])

    # reshape the meshgrid and concatenate ones to convert it to
    # homogeneous form
    ones = tf.ones_like(xTFlat)
    samplingGrid = tf.stack([xTFlat, yTFlat, ones])

    # repeat the grid batch size times
    samplingGrid = tf.broadcast_to(samplingGrid, (B, 3, H * W))

    # cast the affine parameters and sampling grid to float32
    # required for matmul
    theta = tf.cast(theta, "float32")
    samplingGrid = tf.cast(samplingGrid, "float32")

    # transform the sampling grid with the affine parameters
    batchGrids = tf.matmul(theta, samplingGrid)

    # reshape the sampling grid to (B, 2, H, W)
    batchGrids = tf.reshape(batchGrids, [B, 2, H, W])

    # return the transformed grid
    return batchGrids
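# illustrative note (not in the original script): batchGrids[:, 0] and
# batchGrids[:, 1] hold, for every output pixel, the normalized x and y
# source coordinates in the input feature map that the sampler reads from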
def bilinear_sampler(B, H, W, featureMap, x, y):
    # define the bounds of the image
    maxY = tf.cast(H - 1, "int32")
    maxX = tf.cast(W - 1, "int32")
    zero = tf.zeros([], dtype="int32")

    # rescale x and y to feature spatial dimensions
    x = tf.cast(x, "float32")
    y = tf.cast(y, "float32")
    x = 0.5 * ((x + 1.0) * tf.cast(maxX - 1, "float32"))
    y = 0.5 * ((y + 1.0) * tf.cast(maxY - 1, "float32"))

    # grab 4 nearest corner points for each (x, y)
    x0 = tf.cast(tf.floor(x), "int32")
    x1 = x0 + 1
    y0 = tf.cast(tf.floor(y), "int32")
    y1 = y0 + 1

    # clip to range to not violate feature map boundaries
    x0 = tf.clip_by_value(x0, zero, maxX)
    x1 = tf.clip_by_value(x1, zero, maxX)
    y0 = tf.clip_by_value(y0, zero, maxY)
    y1 = tf.clip_by_value(y1, zero, maxY)

    # get pixel value at corner coords
    Ia = get_pixel_value(B, H, W, featureMap, x0, y0)
    Ib = get_pixel_value(B, H, W, featureMap, x0, y1)
    Ic = get_pixel_value(B, H, W, featureMap, x1, y0)
    Id = get_pixel_value(B, H, W, featureMap, x1, y1)

    # recast as float for delta calculation
    x0 = tf.cast(x0, "float32")
    x1 = tf.cast(x1, "float32")
    y0 = tf.cast(y0, "float32")
    y1 = tf.cast(y1, "float32")

    # calculate deltas
    wa = (x1 - x) * (y1 - y)
    wb = (x1 - x) * (y - y0)
    wc = (x - x0) * (y1 - y)
    wd = (x - x0) * (y - y0)

    # add dimension for addition
    wa = tf.expand_dims(wa, axis=3)
    wb = tf.expand_dims(wb, axis=3)
    wc = tf.expand_dims(wc, axis=3)
    wd = tf.expand_dims(wd, axis=3)

    # compute transformed feature map
    transformedFeatureMap = tf.add_n(
        [wa * Ia, wb * Ib, wc * Ic, wd * Id])

    # return the transformed feature map
    return transformedFeatureMap
class STN(Layer):
    def __init__(self, name, filter):
        # initialize the layer
        super().__init__(name=name)
        self.B = None
        self.H = None
        self.W = None
        self.C = None

        # create the constant bias initializer
        self.output_bias = tf.keras.initializers.Constant(
            [1.0, 0.0, 0.0,
             0.0, 1.0, 0.0]
        )

        # define the filter size
        self.filter = filter

    def build(self, input_shape):
        # get the batch size, height, width and channel size of the
        # input
        (self.B, self.H, self.W, self.C) = input_shape

        # define the localization network
        self.localizationNet = Sequential([
            Conv2D(filters=self.filter // 4, kernel_size=3,
                input_shape=(self.H, self.W, self.C),
                activation="relu", kernel_initializer="he_normal"),
            MaxPool2D(),
            Conv2D(filters=self.filter // 2, kernel_size=3,
                activation="relu", kernel_initializer="he_normal"),
            MaxPool2D(),
            Conv2D(filters=self.filter, kernel_size=3,
                activation="relu", kernel_initializer="he_normal"),
            MaxPool2D(),
            GlobalAveragePooling2D()
        ])

        # define the regressor network
        self.regressorNet = tf.keras.Sequential([
            Dense(units=self.filter, activation="relu",
                kernel_initializer="he_normal"),
            Dense(units=self.filter // 2, activation="relu",
                kernel_initializer="he_normal"),
            Dense(units=3 * 2, kernel_initializer="zeros",
                bias_initializer=self.output_bias),
            Reshape(target_shape=(2, 3))
        ])

    def call(self, x):
        # get the localization feature map
        localFeatureMap = self.localizationNet(x)

        # get the regressed parameters
        theta = self.regressorNet(localFeatureMap)

        # get the transformed meshgrid
        grid = affine_grid_generator(self.B, self.H, self.W, theta)

        # get the x and y coordinates from the transformed meshgrid
        xS = grid[:, 0, :, :]
        yS = grid[:, 1, :, :]

        # get the transformed feature map
        x = bilinear_sampler(self.B, self.H, self.W, x, xS, yS)

        # return the transformed feature map
        return x
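Because the regressor's kernel is initialized to zeros and its bias to the identity transform, a freshly built STN layer starts out as an approximately identity mapping. A quick sanity check (illustrative, not part of the original script):

import tensorflow as tf

stn = STN(name="stn", filter=256)
dummy = tf.random.uniform((4, 28, 28, 1))  # (batch, height, width, channels)
out = stn(dummy)
print(out.shape)  # (4, 28, 28, 1): same shape, near-identity content at init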
5. Creating the Classification Model

# import the necessary packages
from tensorflow.keras import Input
from tensorflow.keras import Model
from tensorflow.keras.layers import Conv2D
from tensorflow.keras.layers import MaxPool2D
from tensorflow.keras.layers import Reshape
from tensorflow.keras.layers import GlobalAveragePooling2D
from tensorflow.keras.layers import Lambda
from tensorflow.keras.layers import Dense
from tensorflow.keras.layers import Dropout
import tensorflow as tf
def get_training_model(batchSize, height, width, channel, stnLayer,
        numClasses, filter):
    # define the input layer, scale the pixels to [0, 1], and pass
    # the input through the STN layer
    inputs = Input((height, width, channel), batch_size=batchSize)
    x = Lambda(lambda image: tf.cast(image, "float32") / 255.0)(inputs)
    x = stnLayer(x)

    # apply a series of conv and maxpool layers
    x = Conv2D(filter // 4, 3, activation="relu",
        kernel_initializer="he_normal")(x)
    x = MaxPool2D()(x)
    x = Conv2D(filter // 2, 3, activation="relu",
        kernel_initializer="he_normal")(x)
    x = MaxPool2D()(x)
    x = Conv2D(filter, 3, activation="relu",
        kernel_initializer="he_normal")(x)
    x = MaxPool2D()(x)

    # global average pool the output of the previous layer
    x = GlobalAveragePooling2D()(x)

    # pass the pooled output through a couple of dense layers
    x = Dense(filter, activation="relu",
        kernel_initializer="he_normal")(x)
    x = Dense(filter // 2, activation="relu",
        kernel_initializer="he_normal")(x)

    # apply dropout for better regularization
    x = Dropout(0.5)(x)

    # apply softmax to the output for the multi-class classification
    # task
    outputs = Dense(numClasses, activation="softmax")(x)

    # return the model
    return Model(inputs, outputs)

6. Training the Model
# USAGE
# python train.py
# setting seed for reproducibility
import tensorflow as tf
tf.random.set_seed(42)
# import the necessary packages
from pyimagesearch.stn import STN
from pyimagesearch.classification_model import get_training_model
from pyimagesearch.callback import get_train_monitor
from pyimagesearch import config
from tensorflow.keras.callbacks import EarlyStopping
import tensorflow_datasets as tfds
import os
# load the train and test dataset
print("[INFO] loading the train and test dataset...")
trainingDs = tfds.load(name=config.DATASET_NAME,
data_dir=config.DATASET_PATH, split="train", shuffle_files=True,
as_supervised=True)
testingDs = tfds.load(name=config.DATASET_NAME,
data_dir=config.DATASET_PATH, split="test", as_supervised=True)
# preprocess the train and test dataset
print("[INFO] preprocessing the train and test dataset...")
trainDs = (
trainingDs
.shuffle(config.BATCH_SIZE*100)
.batch(config.BATCH_SIZE, drop_remainder=True)
.prefetch(config.AUTO)
)
testDs = (
testingDs
.batch(config.BATCH_SIZE, drop_remainder=True)
.prefetch(config.AUTO)
)
# initialize the stn layer
print("[INFO] initializing the stn layer...")
stnLayer = STN(name=config.STN_LAYER_NAME, filter=config.FILTERS)
# get the classification model for the emnist dataset
print("[INFO] grabbing the multiclass classification model...")
model = get_training_model(batchSize=config.BATCH_SIZE,
height=config.IMAGE_HEIGHT, width=config.IMAGE_WIDTH,
channel=config.CHANNEL, stnLayer=stnLayer,
numClasses=config.CLASSES, filter=config.FILTERS)
# print the model summary
print("[INFO] the model summary...")
print(model.summary())
# create the output images directory if it does not already exist
if not os.path.exists(config.OUTPUT_PATH):
    os.makedirs(config.OUTPUT_PATH)
# get the training monitor
trainMonitor = get_train_monitor(testDs=testDs,
outputPath=config.OUTPUT_PATH, stnLayerName=config.STN_LAYER_NAME)
# compile the model
print("[INFO] compiling the model...")
model.compile(loss=config.LOSS_FN, optimizer=config.OPTIMIZER,
metrics=["accuracy"])
# define an early stopping callback
esCallback = EarlyStopping(patience=5, restore_best_weights=True)
# train the model
print("[INFO] training the model...")
model.fit(trainDs, epochs=config.EPOCHS,
    callbacks=[trainMonitor, esCallback], validation_data=testDs)

7. Training Results and Visualization
[INFO] compiling the model...
[INFO] training the model...
Epoch 1/100
681/681 [==============================] - 104s 121ms/step - loss: 0.9146 - accuracy: 0.7350 - val_loss: 0.4381 - val_accuracy: 0.8421
Epoch 2/100
681/681 [==============================] - 84s 118ms/step - loss: 0.4705 - accuracy: 0.8392 - val_loss: 0.4064 - val_accuracy: 0.8526
Epoch 3/100
681/681 [==============================] - 85s 119ms/step - loss: 0.4258 - ...
Epoch 16/100
681/681 [==============================] - 85s 119ms/step - loss: 0.3192 - accuracy: 0.8794 - val_loss: 0.3483 - val_accuracy: 0.8725
Epoch 17/100
681/681 [==============================] - 85s 119ms/step - loss: 0.3151 - accuracy: 0.8803 - val_loss: 0.3487 - val_accuracy: 0.8736
Epoch 18/100
681/681 [==============================] - 85s 118ms/step - loss: 0.3113 - accuracy: 0.8814 - val_loss: 0.3503 - val_accuracy: 0.8719

We can see that early stopping kicked in at epoch 18 and halted training. The final training and validation accuracies were 88.14% and 87.19%, respectively.

In Figure 6, we can see how a batch of digits is transformed at each epoch. You might notice that these transformations are not easy for the human eye to interpret: the module does not transform digits the way our brains would expect, it applies whatever transformation reduces the loss. So the result may not be something we can readily understand, but as long as it lowers the loss function, it is good enough for the model.