Machine learning notes - Spatial Transformer Network using TensorFlow

2022-06-11 09:08:00 Sit and watch the clouds rise

1. Overview

        A machine learning model takes the complete pixel orientation of its input into account, not just the digit it depicts. After all, a machine learning model is a set of weight matrices, and the values of those matrices are shaped by the complete pixel orientation of the input data.

        Consider the following two images. For the unrotated digit, the CNN predicts correctly, but the prediction may fail once the image is rotated, which shows that a CNN is not rotation invariant. This may not be a big problem in simple tests or experiments, but real-world variability can cause many problems when the model is applied in practice.

        Researchers in the fast-moving field of deep learning have come up with a solution: Spatial Transformer Networks.

Figure 1: A standard CNN digit predictor

Figure 2: A standard CNN digit predictor failing

        The main purpose of the Spatial Transformer module is to help our model select the most relevant region of interest (ROI) in the image. Once the model has identified the relevant pixels, the spatial transformer module helps it decide what transformation the image needs in order to reach a canonical form.

        The model must find a transformation that lets it predict the correct label for the image. This may not be a transformation that makes sense to us, but as long as it reduces the loss function, it works for the model.

        The module is designed so that the model can apply different transformations as the situation requires: translation, cropping, isotropic scaling, and skew. These are covered in more detail below.

        Think of the spatial transformer module as an add-on to your model. During the forward pass, it applies a specific spatial transformation to the feature map, depending on the input. So for a particular input, we get a single output feature map.

        It helps our model decide what kind of transformation an input such as the 7 in Figure 2 requires. For multi-channel input (for example, RGB images), the same transformation is applied to all three channels (to maintain spatial consistency). Most importantly, this module is learned along with the other model weights (it is differentiable).

        Figure 3 shows the spatial transformer module, which is divided into three parts: the localization network, the grid generator, and the sampler.

Figure 3: Spatial transformer

        Localization network: This network takes an input feature map U of width W, height H, and C channels. Its job is to output \theta, the transformation parameters to be applied to the feature map. The localization network can be anything: a fully connected network or a convolutional network.
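
        To make \theta more tangible, here is a small NumPy sketch of what a 2x3 affine matrix can express. The helper name apply_theta and the example matrices are illustrative assumptions, not values produced by the article's network. Each matrix maps a coordinate of the output grid (normalized to [-1, 1]) to the input location it reads from.

```python
import numpy as np

def apply_theta(theta, x, y):
    """Apply a 2x3 affine matrix to one point in normalized coordinates."""
    x_s, y_s = np.asarray(theta, dtype=float) @ np.array([x, y, 1.0])
    return float(x_s), float(y_s)

# hypothetical parameter sets, chosen only to illustrate each transform
EXAMPLE_THETAS = {
    # leaves every coordinate unchanged
    "identity":          [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]],
    # shifts the sampling window by 0.5 along x
    "translation":       [[1.0, 0.0, 0.5], [0.0, 1.0, 0.0]],
    # samples only the central half of the input (a zoom / crop)
    "isotropic scaling": [[0.5, 0.0, 0.0], [0.0, 0.5, 0.0]],
    # x is displaced in proportion to y
    "skew":              [[1.0, 0.5, 0.0], [0.0, 1.0, 0.0]],
}

# where the output corner (1, 1) reads from in the input
corner_sources = {name: apply_theta(theta, 1.0, 1.0)
    for name, theta in EXAMPLE_THETAS.items()}
```

        Because the matrix describes where each output pixel samples from, a scale factor below 1 zooms in on the input rather than shrinking it.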

        Parameterized sampling grid: We now have the transformation parameters \theta. Suppose our input feature map is U, as shown in Figure 4. We can see that U is a rotated version of the digit 9. Our output feature map V is a square grid, so we already know its indices (that is, a regular rectangular grid).
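
        The grid step can be sketched in a few lines of NumPy. This is an illustration under assumptions (the function name affine_grid and the 3x3 grid size are made up for the example), not the TensorFlow implementation used later in this article. It builds the regular grid of output coordinates in homogeneous form and multiplies it by \theta to obtain the sampling locations in the input.

```python
import numpy as np

def affine_grid(theta, height, width):
    """Map a regular output grid to input sampling locations.

    theta is a (2, 3) affine matrix; coordinates are normalized to [-1, 1].
    Returns an array of shape (2, height, width) holding (x_s, y_s).
    """
    # regular grid of output (target) coordinates
    x = np.linspace(-1.0, 1.0, width)
    y = np.linspace(-1.0, 1.0, height)
    x_t, y_t = np.meshgrid(x, y)          # each has shape (height, width)
    # homogeneous coordinates: rows are x, y, 1
    grid = np.stack([x_t.ravel(), y_t.ravel(), np.ones(height * width)])
    # apply the affine transform to every grid point at once
    source = theta @ grid                 # shape (2, height * width)
    return source.reshape(2, height, width)

# identity parameters leave the grid unchanged
identity = np.array([[1.0, 0.0, 0.0],
                     [0.0, 1.0, 0.0]])
same = affine_grid(identity, 3, 3)

# a 90-degree rotation of the grid: x_s = -y_t, y_s = x_t
rotate90 = np.array([[0.0, -1.0, 0.0],
                     [1.0,  0.0, 0.0]])
rotated = affine_grid(rotate90, 3, 3)
```

        With the identity matrix every output location samples from the same location in the input, so the feature map passes through unchanged.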

        Sampler: Now that we have the coordinates, the values of the output feature map V are estimated using only our input pixel values. That is, we perform linear/bilinear interpolation on the output pixels using the input pixels. Bilinear interpolation uses the nearest pixel values, located diagonally around a given position, to find the appropriate color intensity for that pixel.
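
        As a worked example of bilinear interpolation, the sketch below estimates the value at one fractional position from its four surrounding pixels. It is a simplified, single-point illustration in plain Python (the function name bilinear_sample and the 2x2 image are assumptions), not the batched sampler used later in this article.

```python
import math

def bilinear_sample(image, x, y):
    """Estimate the value of `image` at the fractional position (x, y).

    `image` is a list of rows; x indexes columns and y indexes rows.
    """
    max_y, max_x = len(image) - 1, len(image[0]) - 1
    # keep the sampling position inside the image
    x = min(max(x, 0.0), float(max_x))
    y = min(max(y, 0.0), float(max_y))
    # the four integer neighbours surrounding (x, y)
    x0, y0 = math.floor(x), math.floor(y)
    x1, y1 = min(x0 + 1, max_x), min(y0 + 1, max_y)
    # fractional distances from the top-left neighbour
    dx, dy = x - x0, y - y0
    # each neighbour is weighted by how close the position is to it
    return (image[y0][x0] * (1 - dx) * (1 - dy)
            + image[y0][x1] * dx * (1 - dy)
            + image[y1][x0] * (1 - dx) * dy
            + image[y1][x1] * dx * dy)

image = [[0.0, 10.0],
         [20.0, 30.0]]
center = bilinear_sample(image, 0.5, 0.5)   # equal mix of all four pixels
corner = bilinear_sample(image, 1.0, 0.0)   # exactly the top-right pixel
```

        The four weights always sum to 1, so the result stays within the range of the neighbouring pixel values.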

2. Create the configuration file

# import the necessary packages
from tensorflow.data import AUTOTUNE
import os
# define AUTOTUNE
AUTO = AUTOTUNE
# define the image height, width and channel size
IMAGE_HEIGHT = 28
IMAGE_WIDTH = 28
CHANNEL = 1
# define the dataset path, dataset name, and the batch size
DATASET_PATH = "dataset"
DATASET_NAME = "emnist"
BATCH_SIZE = 1024
# define the number of epochs
EPOCHS = 100
# define the conv filters
FILTERS = 256
# define an output directory
OUTPUT_PATH = "output"
# define the loss function and optimizer
LOSS_FN = "sparse_categorical_crossentropy"
OPTIMIZER = "adam"
# define the name of the gif
GIF_NAME = "stn.gif"
# define the number of classes for classification
CLASSES = 62
# define the stn layer name
STN_LAYER_NAME = "stn"

3. Create the callback function for exporting the GIF

        This callback helps us build our GIF. It also lets us track the model's progress during training.

# import the necessary packages
from tensorflow.keras.callbacks import Callback
from tensorflow.keras import Model
import matplotlib.pyplot as plt
def get_train_monitor(testDs, outputPath, stnLayerName):
	# iterate over the test dataset and take a batch of test images
	(testImg, _) = next(iter(testDs))
	# define a training monitor
	class TrainMonitor(Callback):
		def on_epoch_end(self, epoch, logs=None):
			model = Model(self.model.input,
				self.model.get_layer(stnLayerName).output)
			testPred = model(testImg)
			# plot the image and the transformed image
			_, axes = plt.subplots(nrows=5, ncols=2, figsize=(5, 10))
			for ax, im, t_im  in zip(axes, testImg[:5], testPred[:5]):
				ax[0].imshow(im[..., 0], cmap="gray")
				ax[0].set_title(epoch)
				ax[0].axis("off")
				ax[1].imshow(t_im[..., 0], cmap="gray")
				ax[1].set_title(epoch)
				ax[1].axis("off")
			
			# save the figures
			plt.savefig(f"{outputPath}/{epoch:03d}")
			plt.close()
	
	# instantiate the training monitor callback
	trainMonitor = TrainMonitor()
	# return the training monitor object
	return trainMonitor
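
        The callback saves one figure per epoch, and the configuration defines GIF_NAME, but none of the scripts shown here assembles the frames. One possible way to do that is sketched below with Pillow. This is an assumption, not the author's method: the helper name frames_to_gif is hypothetical, and it presumes the frames were written as PNG files (Matplotlib's default format when the file name has no extension).

```python
import glob
import os
from PIL import Image

def frames_to_gif(frame_dir, gif_path, frame_ms=500):
    """Combine the PNG frames in `frame_dir` into one animated GIF."""
    # zero-padded names such as 000.png, 001.png sort in epoch order
    frame_paths = sorted(glob.glob(os.path.join(frame_dir, "*.png")))
    if not frame_paths:
        raise FileNotFoundError(f"no PNG frames found in {frame_dir}")
    # convert every frame to the same mode and detach it from its file
    frames = []
    for path in frame_paths:
        with Image.open(path) as frame:
            frames.append(frame.convert("RGB"))
    # the first frame is saved with the remaining ones appended to it
    frames[0].save(gif_path, save_all=True, append_images=frames[1:],
        duration=frame_ms, loop=0)
    return len(frames)
```

        After training, a call such as frames_to_gif(config.OUTPUT_PATH, config.GIF_NAME) would write the animation to the working directory.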

4. Spatial Transformer module

        To attach the spatial transformer module to our main model, we create a separate script that contains all the necessary helper functions and the main layer. The script covers three steps:

  • Derive the required transformation parameters from the input image.
  • Map each location of the output feature map back to a location in the input feature map.
  • Apply bilinear interpolation to estimate the pixel values of the output feature map.
# import the necessary packages
from tensorflow.keras import Sequential
from tensorflow.keras.layers import Conv2D
from tensorflow.keras.layers import MaxPool2D
from tensorflow.keras.layers import GlobalAveragePooling2D
from tensorflow.keras.layers import Reshape
from tensorflow.keras.layers import Dense
from tensorflow.keras.layers import Layer
import tensorflow as tf
def get_pixel_value(B, H, W, featureMap, x, y):
	# create batch indices and reshape it
	batchIdx = tf.range(0, B)
	batchIdx = tf.reshape(batchIdx, (B, 1, 1))
	# create the indices matrix which will be used to sample the 
	# feature map
	b = tf.tile(batchIdx, (1, H, W))
	indices = tf.stack([b, y, x], 3)
	# gather the feature map values for the corresponding indices
	gatheredPixelValue = tf.gather_nd(featureMap, indices)
	# return the gather pixel values
	return gatheredPixelValue
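
        The indexing in get_pixel_value can be hard to picture. The following NumPy sketch is an analogue for illustration only, with tiny made-up 2x2 feature maps. It builds the same batch index and reads one pixel per output location, which is what tf.gather_nd does with the stacked [b, y, x] indices.

```python
import numpy as np

# a batch of 2 feature maps, each 2x2 with 1 channel
B, H, W = 2, 2, 2
feature_map = np.arange(8, dtype=np.float32).reshape(B, H, W, 1)
# integer coordinates to read for every output location
x = np.array([[[1, 0], [1, 0]]] * B)   # column indices, shape (B, H, W)
y = np.array([[[0, 0], [1, 1]]] * B)   # row indices, shape (B, H, W)
# batch index for every output location, like tf.tile(batchIdx, (1, H, W))
b = np.tile(np.arange(B).reshape(B, 1, 1), (1, H, W))
# advanced indexing plays the role of tf.gather_nd(featureMap, indices)
gathered = feature_map[b, y, x]        # shape (B, H, W, 1)
```

        Here the x indices swap the two columns, so each gathered map is a horizontally flipped copy of its input.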

def affine_grid_generator(B, H, W, theta):
	# create normalized 2D grid (x spans the width, y spans the height)
	x = tf.linspace(-1.0, 1.0, W)
	y = tf.linspace(-1.0, 1.0, H)
	(xT, yT) = tf.meshgrid(x, y)
	# flatten the meshgrid
	xTFlat = tf.reshape(xT, [-1])
	yTFlat = tf.reshape(yT, [-1])
	# reshape the meshgrid and concatenate ones to convert it to 
	# homogeneous form
	ones = tf.ones_like(xTFlat)
	samplingGrid = tf.stack([xTFlat, yTFlat, ones])
	# repeat grid batch size times
	samplingGrid = tf.broadcast_to(samplingGrid, (B, 3, H * W))
	# cast the affine parameters and sampling grid to float32 
	# required for matmul
	theta = tf.cast(theta, "float32")
	samplingGrid = tf.cast(samplingGrid, "float32")
	# transform the sampling grid with the affine parameter
	batchGrids = tf.matmul(theta, samplingGrid)
	# reshape the sampling grid to (B, 2, H, W)
	batchGrids = tf.reshape(batchGrids, [B, 2, H, W])
	# return the transformed grid
	return batchGrids

def bilinear_sampler(B, H, W, featureMap, x, y):
	# define the bounds of the image
	maxY = tf.cast(H - 1, "int32")
	maxX = tf.cast(W - 1, "int32")
	zero = tf.zeros([], dtype="int32")
	# rescale x and y from [-1, 1] to feature map coordinates; scaling
	# by (max - 1) keeps x0 + 1 and y0 + 1 in bounds for that range
	x = tf.cast(x, "float32")
	y = tf.cast(y, "float32")
	x = 0.5 * ((x + 1.0) * tf.cast(maxX-1, "float32"))
	y = 0.5 * ((y + 1.0) * tf.cast(maxY-1, "float32"))
	# grab 4 nearest corner points for each (x, y)
	x0 = tf.cast(tf.floor(x), "int32")
	x1 = x0 + 1
	y0 = tf.cast(tf.floor(y), "int32")
	y1 = y0 + 1
	# clip to range to not violate feature map boundaries
	x0 = tf.clip_by_value(x0, zero, maxX)
	x1 = tf.clip_by_value(x1, zero, maxX)
	y0 = tf.clip_by_value(y0, zero, maxY)
	y1 = tf.clip_by_value(y1, zero, maxY)
	# get pixel value at corner coords
	Ia = get_pixel_value(B, H, W, featureMap, x0, y0)
	Ib = get_pixel_value(B, H, W, featureMap, x0, y1)
	Ic = get_pixel_value(B, H, W, featureMap, x1, y0)
	Id = get_pixel_value(B, H, W, featureMap, x1, y1)
	# recast as float for delta calculation
	x0 = tf.cast(x0, "float32")
	x1 = tf.cast(x1, "float32")
	y0 = tf.cast(y0, "float32")
	y1 = tf.cast(y1, "float32")
	# calculate the bilinear interpolation weights of the four corners
	wa = (x1-x) * (y1-y)
	wb = (x1-x) * (y-y0)
	wc = (x-x0) * (y1-y)
	wd = (x-x0) * (y-y0)
	# add dimension for addition
	wa = tf.expand_dims(wa, axis=3)
	wb = tf.expand_dims(wb, axis=3)
	wc = tf.expand_dims(wc, axis=3)
	wd = tf.expand_dims(wd, axis=3)
	# compute transformed feature map
	transformedFeatureMap = tf.add_n(
		[wa * Ia, wb * Ib, wc * Ic, wd * Id])
	# return the transformed feature map
	return transformedFeatureMap

class STN(Layer):
	def __init__(self, name, filter):
		# initialize the layer
		super().__init__(name=name)
		self.B = None
		self.H = None
		self.W = None
		self.C = None
		# create the constant bias initializer
		self.output_bias = tf.keras.initializers.Constant(
			[1.0, 0.0, 0.0,
			0.0, 1.0, 0.0]
		)
		# define the filter size
		self.filter = filter

	def build(self, input_shape):
		# get the batch size, height, width and channel size of the
		# input
		(self.B, self.H, self.W, self.C) = input_shape
		# define the localization network
		self.localizationNet = Sequential([
			Conv2D(filters=self.filter // 4, kernel_size=3,
				input_shape=(self.H, self.W, self.C), 
				activation="relu", kernel_initializer="he_normal"),
			MaxPool2D(),
			Conv2D(filters=self.filter // 2, kernel_size=3,
				activation="relu", kernel_initializer="he_normal"),
			MaxPool2D(),
			Conv2D(filters=self.filter, kernel_size=3,
				activation="relu", kernel_initializer="he_normal"),
			MaxPool2D(),
			GlobalAveragePooling2D()
		])
		# define the regressor network
		self.regressorNet = tf.keras.Sequential([
			Dense(units = self.filter, activation="relu",
				kernel_initializer="he_normal"),
			Dense(units = self.filter // 2, activation="relu",
				kernel_initializer="he_normal"),
			Dense(units = 3 * 2, kernel_initializer="zeros",
				bias_initializer=self.output_bias),
			Reshape(target_shape=(2, 3))
		])

	def call(self, x):
		# get the localization feature map
		localFeatureMap = self.localizationNet(x)
		# get the regressed parameters
		theta = self.regressorNet(localFeatureMap)
		# get the transformed meshgrid
		grid = affine_grid_generator(self.B, self.H, self.W, theta)
		# get the x and y coordinates from the transformed meshgrid
		xS = grid[:, 0, :, :]
		yS = grid[:, 1, :, :]
		# get the transformed feature map
		x = bilinear_sampler(self.B, self.H, self.W, x, xS, yS)
		# return the transformed feature map
		return x
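
        The last Dense layer of the regressor is initialized with a zero kernel and a constant bias of [1, 0, 0, 0, 1, 0]. A small NumPy sketch shows why this makes the layer start out as an identity transform regardless of its input. It is an illustration, not part of the article's script: the feature width of 128 and the batch of 4 are assumed for the example.

```python
import numpy as np

rng = np.random.default_rng(0)
# stand-in for the activations that reach the last Dense layer
features = rng.normal(size=(4, 128))
# kernel_initializer="zeros": the input has no influence at the start
kernel = np.zeros((128, 6))
# the constant bias initializer holds a flattened identity transform
bias = np.array([1.0, 0.0, 0.0,
                 0.0, 1.0, 0.0])
# Dense output followed by the Reshape to (2, 3)
theta = (features @ kernel + bias).reshape(-1, 2, 3)
```

        Every row of theta is the identity matrix, so at the start of training the STN layer passes images through with almost no change and only gradually learns to warp them.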

5. Create the classification model

# import the necessary packages
from tensorflow.keras import Input
from tensorflow.keras import Model
from tensorflow.keras.layers import Conv2D
from tensorflow.keras.layers import MaxPool2D
from tensorflow.keras.layers import Reshape
from tensorflow.keras.layers import GlobalAveragePooling2D
from tensorflow.keras.layers import Lambda
from tensorflow.keras.layers import Dense
from tensorflow.keras.layers import Dropout
import tensorflow as tf
def get_training_model(batchSize, height, width, channel, stnLayer,
	numClasses, filter):
	# define the input layer and pass the input through the STN
	# layer
	inputs = Input((height, width, channel), batch_size=batchSize)
	x = Lambda(lambda image: tf.cast(image, "float32")/255.0)(inputs)
	x = stnLayer(x) 

	# apply a series of conv and maxpool layers
	x = Conv2D(filter // 4, 3, activation="relu", 
		kernel_initializer="he_normal")(x)
	x = MaxPool2D()(x)
	x = Conv2D(filter // 2, 3, activation="relu",
		kernel_initializer="he_normal")(x)
	x = MaxPool2D()(x)
	x = Conv2D(filter, 3, activation="relu",
		kernel_initializer="he_normal")(x)
	x = MaxPool2D()(x)
	# global average pool the output of the previous layer
	x = GlobalAveragePooling2D()(x)
	# pass the flattened output through a couple of dense layers
	x = Dense(filter, activation="relu",
		kernel_initializer="he_normal")(x)
	x = Dense(filter // 2, activation="relu",
		kernel_initializer="he_normal")(x)
	# apply dropout for better regularization
	x = Dropout(0.5)(x)
	# apply softmax to the output for a multi-classification task
	outputs = Dense(numClasses, activation="softmax")(x)
	# return the model
	return Model(inputs, outputs)
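
        It can help to trace how the spatial size shrinks through the three convolution and pooling blocks. The sketch below does that arithmetic in plain Python, assuming the Keras defaults used in the model: "valid" padding with stride 1 for Conv2D and a 2x2 window for MaxPool2D. The helper name conv_pool_sizes is made up for this example.

```python
def conv_pool_sizes(size, num_blocks=3, kernel=3, pool=2):
    """Trace the spatial size through Conv2D + MaxPool2D blocks."""
    sizes = [size]
    for _ in range(num_blocks):
        # 'valid' convolution with stride 1 trims kernel - 1 pixels
        size = size - kernel + 1
        # max pooling keeps one value per pool x pool window
        size = size // pool
        sizes.append(size)
    return sizes

sizes = conv_pool_sizes(28)   # [28, 13, 5, 1]
```

        For a 28x28 input the feature map is down to 1x1 after the third block, so the global average pooling layer simply flattens it.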

6. Train the model

# USAGE
# python train.py
# setting seed for reproducibility
import tensorflow as tf
tf.random.set_seed(42)
# import the necessary packages
from pyimagesearch.stn import STN
from pyimagesearch.classification_model import get_training_model
from pyimagesearch.callback import get_train_monitor
from pyimagesearch import config
from tensorflow.keras.callbacks import EarlyStopping
import tensorflow_datasets as tfds
import os
# load the train and test dataset
print("[INFO] loading the train and test dataset...")
trainingDs = tfds.load(name=config.DATASET_NAME,
	data_dir=config.DATASET_PATH, split="train", shuffle_files=True,
	as_supervised=True)
testingDs = tfds.load(name=config.DATASET_NAME,
	data_dir=config.DATASET_PATH, split="test", as_supervised=True)
# preprocess the train and test dataset
print("[INFO] preprocessing the train and test dataset...")
trainDs = (
	trainingDs
	.shuffle(config.BATCH_SIZE*100)
	.batch(config.BATCH_SIZE, drop_remainder=True)
	.prefetch(config.AUTO)
)
testDs = (
	testingDs
	.batch(config.BATCH_SIZE, drop_remainder=True)
	.prefetch(config.AUTO)
)
# initialize the stn layer
print("[INFO] initializing the stn layer...")
stnLayer = STN(name=config.STN_LAYER_NAME, filter=config.FILTERS)
# get the classification model for emnist
print("[INFO] grabbing the multiclass classification model...")
model = get_training_model(batchSize=config.BATCH_SIZE,
	height=config.IMAGE_HEIGHT, width=config.IMAGE_WIDTH,
	channel=config.CHANNEL, stnLayer=stnLayer,
	numClasses=config.CLASSES, filter=config.FILTERS)
# print the model summary
print("[INFO] the model summary...")
model.summary()
# create an output images directory if it does not already exist
os.makedirs(config.OUTPUT_PATH, exist_ok=True)
# get the training monitor
trainMonitor = get_train_monitor(testDs=testDs,
	outputPath=config.OUTPUT_PATH, stnLayerName=config.STN_LAYER_NAME)
# compile the model
print("[INFO] compiling the model...")
model.compile(loss=config.LOSS_FN, optimizer=config.OPTIMIZER, 
	metrics=["accuracy"])
# define an early stopping callback
esCallback = EarlyStopping(patience=5, restore_best_weights=True)
# train the model
print("[INFO] training the model...")
model.fit(trainDs, epochs=config.EPOCHS, 
	callbacks=[trainMonitor, esCallback], validation_data=testDs)
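
        The EarlyStopping callback with patience=5 stops training once the validation loss has gone five epochs without improving. The plain-Python sketch below is a simplified model of that rule, not the Keras implementation. The helper name early_stopping_epoch and the loss values are hypothetical.

```python
def early_stopping_epoch(val_losses, patience=5):
    """Return (stop_epoch, best_epoch) for a list of per-epoch val losses.

    Epochs are numbered from 1. stop_epoch is None when training would
    run through every epoch without triggering early stopping.
    """
    best_loss = float("inf")
    best_epoch = None
    wait = 0
    for epoch, loss in enumerate(val_losses, start=1):
        if loss < best_loss:
            # improvement: remember it and reset the patience counter
            best_loss, best_epoch, wait = loss, epoch, 0
        else:
            wait += 1
            if wait >= patience:
                return epoch, best_epoch
    return None, best_epoch

# hypothetical validation losses, not the values from the training run
losses = [0.50, 0.40, 0.35, 0.36, 0.37, 0.36, 0.38, 0.39, 0.30]
stop_epoch, best_epoch = early_stopping_epoch(losses, patience=5)
```

        In this example training stops at epoch 8 and the best epoch is 3. With restore_best_weights=True, Keras then restores the weights from that best epoch.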

7. Train the model and visualize

[INFO] compiling the model...
[INFO] training the model...
Epoch 1/100
681/681 [==============================] - 104s 121ms/step - loss: 0.9146 - accuracy: 0.7350 - val_loss: 0.4381 - val_accuracy: 0.8421
Epoch 2/100
681/681 [==============================] - 84s 118ms/step - loss: 0.4705 - accuracy: 0.8392 - val_loss: 0.4064 - val_accuracy: 0.8526
Epoch 3/100
681/681 [==============================] - 85s 119ms/step - loss: 0.4258 - ... 
Epoch 16/100
681/681 [==============================] - 85s 119ms/step - loss: 0.3192 - accuracy: 0.8794 - val_loss: 0.3483 - val_accuracy: 0.8725
Epoch 17/100
681/681 [==============================] - 85s 119ms/step - loss: 0.3151 - accuracy: 0.8803 - val_loss: 0.3487 - val_accuracy: 0.8736
Epoch 18/100
681/681 [==============================] - 85s 118ms/step - loss: 0.3113 - accuracy: 0.8814 - val_loss: 0.3503 - val_accuracy: 0.8719

        We can see that at epoch 18, early stopping kicked in and halted the training. The final training and validation accuracies are 88.14% and 87.19%, respectively.

        In Figure 6, we can see how a handful of digits are transformed at each epoch. You may notice that these transformations are not easy for the human eye to interpret. The transformation does not always follow how our brain perceives digits; it depends on which transformation reduces the loss. So it may not be something we can readily make sense of, but if it works for the loss function, it is good enough for the model.

Copyright notice
This article was written by [Sit and watch the clouds rise]. Please include a link to the original when reposting. Thank you.
https://yzsam.com/2022/162/202206110859382745.html