Hands on deep learning (33) -- style transfer
2022-07-04 09:36:00 【Stay a little star】
I. CNN-Based Style Transfer
1. An Intuitive View of Style Transfer
If you are a photography enthusiast, you are probably familiar with filters. A filter can change the color style of a photo, making landscapes sharper or people's skin look fairer. But a filter usually changes only one aspect of the picture. To reach the style you want, you may need to try many different combinations, a process no less complex than tuning a model's hyperparameters.
In this section, we will show how to use convolutional neural networks to automatically apply the style of one image to another, i.e., style transfer. This requires two input images: one is the content image and the other is the style image (note: both the content image and the style image are inputs here, unlike the single input image used in earlier sections). We use a neural network to modify the content image so that it becomes close to the style image in style.
For example, the content image in the figure below is a landscape photo taken by the book's author (Li Mu) in Mount Rainier National Park near Seattle, while the style image is an oil painting with an autumn-oak theme. The final composite image applies the oil-painting brush strokes of the style image, making the overall colors more vivid, while preserving the shapes of the main objects in the content image.
2. The CNN-Based Style Transfer Method
The figure below illustrates the CNN-based style transfer method with a simplified example.
- First, we initialize the composite image, for example as a copy of the content image (the initialization does not matter much; you can try different initializations). The composite image is the only variable that needs to be updated during style transfer, i.e., it is the model parameter iterated on during style transfer.
- Then, we choose a pretrained convolutional neural network to extract image features; its parameters do not need to be updated during training. This deep CNN extracts image features level by level through multiple layers, and we can select the outputs of some of these layers as content features or style features. (The deeper the layer, the more global the features it extracts.)
In the example shown, the chosen pretrained network contains 3 convolutional layers: the second layer outputs the content features, while the first and third layers output the style features.
Next, we compute the style transfer loss function through forward propagation (direction of the solid arrows), and iterate the model parameters, i.e., keep updating the composite image, through backpropagation (direction of the dashed arrows).
The loss function commonly used in style transfer consists of 3 parts:
- (i) Content loss makes the composite image close to the content image in terms of content features;
- (ii) Style loss makes the composite image close to the style image in terms of style features;
- (iii) Total variation loss helps reduce the noise in the composite image.
Finally, when model training ends, we output the model parameters of style transfer, i.e., the final composite image.
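To make the idea concrete before diving into the details, here is a minimal runnable sketch (with toy stand-in losses rather than the real ones defined later in this section) showing that the composite image itself is the parameter that gradient descent updates:
import torch
# Toy stand-ins for the content image and the style statistics (assumed shapes)
content = torch.rand(1, 3, 64, 64)
style_target = torch.rand(1, 3, 64, 64)
# The composite image is the only trainable variable
synthesized = content.clone().requires_grad_(True)
optimizer = torch.optim.Adam([synthesized], lr=0.1)
for step in range(5):
    # Toy objective: stay close to the content, pull channel means toward the style image
    content_l = (synthesized - content).pow(2).mean()
    style_l = (synthesized.mean(dim=(2, 3)) - style_target.mean(dim=(2, 3))).pow(2).mean()
    loss = content_l + 1e3 * style_l
    optimizer.zero_grad()
    loss.backward()  # gradients flow into the image itself
    optimizer.step()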
Below, we go through the technical details of style transfer with code.
II. Style Transfer Implementation
1. Read content and style images
First, we read the content image and the style image.
From the axes of the printed images, we can see that the two images do not have the same size.
%matplotlib inline
import torch
import torchvision
from torch import nn
from d2l import torch as d2l
import os
import matplotlib.pyplot as plt
d2l.set_figsize()
content_img = d2l.Image.open('./course_file/pytorch/img/rainier.jpg')
d2l.plt.imshow(content_img)
style_img = d2l.Image.open('./course_file/pytorch/img/autumn-oak.jpg')
d2l.plt.imshow(style_img);
2. Preprocessing and Postprocessing
""" below , Define image pre-processing function and post-processing function . Preprocessing functions `preprocess` The input image is displayed in RGB The three channels are standardized respectively , The result is transformed into the input format accepted by convolutional neural network . Post processing functions `postprocess` Then the pixel value in the output image is restored to the value before standardization . Because the image printing function requires the floating-point value of each pixel to be in 0 To 1 Between , We are less than 0 And greater than 1 The values of are respectively 0 and 1. """
rgb_mean = torch.tensor([0.485, 0.456, 0.406])
rgb_std = torch.tensor([0.229, 0.224, 0.225])
def preprocess(img, image_shape):
""" The picture becomes tensor"""
transforms = torchvision.transforms.Compose([
torchvision.transforms.Resize(image_shape),
torchvision.transforms.ToTensor(),
torchvision.transforms.Normalize(mean=rgb_mean, std=rgb_std)])
return transforms(img).unsqueeze(0)
def postprocess(img):
"""tensor Become a picture """
img = img[0].to(rgb_std.device)
img = torch.clamp(img.permute(1, 2, 0) * rgb_std + rgb_mean, 0, 1)
return torchvision.transforms.ToPILImage()(img.permute(2, 0, 1))
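As a quick sanity check (a hypothetical snippet, assuming `content_img` has been read as above), `preprocess` yields a 4-D tensor and `postprocess` maps it back to a PIL image; the round trip is not exact because out-of-range values are clamped:
x = preprocess(content_img, image_shape=(300, 450))
print(x.shape)         # torch.Size([1, 3, 300, 450])
img_back = postprocess(x)
print(img_back.size)   # (450, 300) -- PIL reports (width, height)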
3. Extract image features
# We use the VGG-19 model pretrained on the ImageNet dataset to extract image features.
pretrained_net = torchvision.models.vgg19(pretrained=True)
pretrained_net
VGG(
(features): Sequential(
(0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(1): ReLU(inplace=True)
(2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(3): ReLU(inplace=True)
(4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
(5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(6): ReLU(inplace=True)
(7): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(8): ReLU(inplace=True)
(9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
(10): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(11): ReLU(inplace=True)
(12): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(13): ReLU(inplace=True)
(14): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(15): ReLU(inplace=True)
(16): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(17): ReLU(inplace=True)
(18): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
(19): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(20): ReLU(inplace=True)
(21): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(22): ReLU(inplace=True)
(23): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(24): ReLU(inplace=True)
(25): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(26): ReLU(inplace=True)
(27): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
(28): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(29): ReLU(inplace=True)
(30): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(31): ReLU(inplace=True)
(32): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(33): ReLU(inplace=True)
(34): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(35): ReLU(inplace=True)
(36): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
)
(avgpool): AdaptiveAvgPool2d(output_size=(7, 7))
(classifier): Sequential(
(0): Linear(in_features=25088, out_features=4096, bias=True)
(1): ReLU(inplace=True)
(2): Dropout(p=0.5, inplace=False)
(3): Linear(in_features=4096, out_features=4096, bias=True)
(4): ReLU(inplace=True)
(5): Dropout(p=0.5, inplace=False)
(6): Linear(in_features=4096, out_features=1000, bias=True)
)
)
In order to extract the content features and style features of an image, we can select the outputs of certain layers of the VGG network.
Generally speaking:
- the closer a layer is to the input layer, the easier it is to extract details of the image;
- conversely, the closer it is to the output layer, the easier it is to extract global information about the image.
To prevent the composite image from retaining too many details of the content image, we choose a VGG layer closer to the output, called the content layer, to output the content features of the image. We also select the outputs of several different VGG layers to match local and global styles; these are called the style layers. As introduced in the section on VGG, the VGG network uses 5 convolutional blocks. In this experiment, we select the last convolutional layer of the fourth convolutional block as the content layer, and the first convolutional layer of each convolutional block as the style layers. The indices of these layers can be obtained by printing the `pretrained_net` instance.
style_layers, content_layers = [0, 5, 10, 19, 28], [25]  # smaller indices are closer to the input (local style); larger indices are closer to the output (global style)
When using VGG layers to extract features, we only need all layers from the input layer up to the content layer or style layer that is closest to the output. Below we build a new network `net`, which keeps only the VGG layers we need.
# Discard all layers after the deepest content/style layer we need
net = nn.Sequential(*[
pretrained_net.features[i]
for i in range(max(content_layers + style_layers) + 1)])
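As a small optional check (assuming the layer indices chosen above), `net` should end exactly at the deepest layer we selected:
print(len(net))          # 29: layers 0 through 28 of pretrained_net.features
print(net[25], net[28])  # the content layer (index 25) and the deepest style layer (index 28)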
Given an input `X`, if we simply call the forward computation `net(X)`, we only obtain the output of the last layer. Since we also need the outputs of intermediate layers, we compute layer by layer here and keep the outputs of the content layer and the style layers.
def extract_features(X, content_layers, style_layers):
""" Input X, Get the output of content layer and style layer """
contents = []
styles = []
for i in range(len(net)):
X = net[i](X)
if i in style_layers:
styles.append(X)
if i in content_layers:
contents.append(X)
return contents, styles
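A quick shape check illustrates what `extract_features` returns (a small sketch using a random input of a plausible size):
X = torch.rand(1, 3, 300, 450)
contents, styles = extract_features(X, content_layers, style_layers)
print([Y.shape for Y in contents])  # one feature map, from layer 25
print([Y.shape for Y in styles])    # five feature maps, from layers 0, 5, 10, 19, 28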
# Since the parameters of the pretrained VGG do not need to change, we can extract the content features and style features before training starts
# Since the composite image is the model parameter iterated on during style transfer, we can only extract its content and style features by calling the extract_features function during training
def get_contents(image_shape, device):
""" Extract content features from content images """
content_X = preprocess(content_img, image_shape).to(device)
contents_Y, _ = extract_features(content_X, content_layers, style_layers)
return content_X, contents_Y
def get_styles(image_shape, device):
""" Extract style features from style images """
style_X = preprocess(style_img, image_shape).to(device)
_, styles_Y = extract_features(style_X, content_layers, style_layers)
return style_X, styles_Y
4. Define the loss function
Next we describe the loss function for style transfer. It consists of 3 parts: content loss, style loss, and total variation loss.
4.1 Content loss
Similar to the loss function in linear regression, the content loss measures the difference in content features between the composite image and the content image via a squared-error function. The two inputs of the squared-error function are both outputs of the content layer computed by the `extract_features` function.
def content_loss(Y_hat, Y):
""" Differences in content features between synthetic images and content images """
    # We detach the target content from the dynamic computation graph:
    # it is treated as a fixed value, not a variable.
return torch.square(Y_hat - Y.detach()).mean()
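A small illustration with hypothetical feature tensors: `content_loss` returns a scalar, and detaching `Y` ensures the gradient flows only into `Y_hat`:
Y_hat = torch.rand(1, 512, 32, 32, requires_grad=True)
Y = torch.rand(1, 512, 32, 32)
l = content_loss(Y_hat, Y)
l.backward()
print(l.item(), Y_hat.grad.shape)  # a scalar loss and a gradient w.r.t. Y_hat only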
4.2 Style loss
For style, we treat it as the statistical distribution of pixel values on each channel. For example, to match the colors of two images, one approach is to match their histograms on the three RGB channels.
Style loss, like content loss, also measures the difference in style between the composite image and the style image through a squared-error function. To express the style output by a style layer, we first compute the style layer's output with the `extract_features` function. Suppose this output has 1 sample, $c$ channels, height $h$, and width $w$. We can reshape it into a matrix $\mathbf{X}$ with $c$ rows and $hw$ columns. This matrix can be viewed as the concatenation of $c$ vectors $\mathbf{x}_1, \ldots, \mathbf{x}_c$, each of length $hw$, where vector $\mathbf{x}_i$ represents the style feature of channel $i$.
In the Gram matrix of these vectors, $\mathbf{X}\mathbf{X}^\top \in \mathbb{R}^{c \times c}$, the element $x_{ij}$ in row $i$ and column $j$ is the inner product of vectors $\mathbf{x}_i$ and $\mathbf{x}_j$; it represents the correlation between the style features of channels $i$ and $j$. We use this Gram matrix to express the style output by a style layer. Note that when $hw$ is large, the elements of the Gram matrix tend to take large values. Also, the height and width of the Gram matrix both equal the number of channels $c$. To make the style loss independent of these sizes, the `gram` function defined below divides the Gram matrix by the number of its elements, i.e., $chw$.
def gram(X):
    num_channels, n = X.shape[1], X.numel() // X.shape[1]  # num_channels: number of channels; n: height times width
X = X.reshape((num_channels, n))
return torch.matmul(X, X.T) / (num_channels * n)
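A quick shape check for `gram` (random features as a stand-in for a style-layer output): a (1, c, h, w) feature map yields a (c, c) Gram matrix regardless of h and w:
X = torch.rand(1, 64, 30, 45)
print(gram(X).shape)  # torch.Size([64, 64])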
Naturally, the two Gram-matrix inputs of the squared-error function for the style loss are based on the style-layer outputs of the composite image and the style image, respectively. Here we assume that the Gram matrix `gram_Y` based on the style image has been computed in advance.
def style_loss(Y_hat, gram_Y):
return torch.square(gram(Y_hat) - gram_Y.detach()).mean()
4.3 Total variation loss
Sometimes the learned composite image contains a lot of high-frequency noise, i.e., particularly bright or dark pixels. A common denoising method is total variation denoising: let $x_{i,j}$ denote the pixel value at coordinate $(i, j)$; reducing the total variation loss
$$\sum_{i, j} \left|x_{i, j} - x_{i+1, j}\right| + \left|x_{i, j} - x_{i, j+1}\right|$$
makes neighboring pixel values as similar as possible.
def tv_loss(Y_hat):
"""TV Noise reduction , Make adjacent pixel values similar """
return 0.5 * (torch.abs(Y_hat[:, :, 1:, :] - Y_hat[:, :, :-1, :]).mean() +
torch.abs(Y_hat[:, :, :, 1:] - Y_hat[:, :, :, :-1]).mean())
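A small illustration with toy tensors: a constant image has zero total variation, while adding noise increases it:
flat = torch.ones(1, 3, 32, 32)
noisy = flat + 0.5 * torch.rand(1, 3, 32, 32)
print(tv_loss(flat).item(), tv_loss(noisy).item())  # 0.0 vs. a positive value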
4.4 Weighted Sum of Losses
The loss function for style transfer is the weighted sum of the content loss, the style loss, and the total variation loss. By adjusting these weight hyperparameters, we can balance the relative importance of content preservation, style transfer, and denoising in the composite image.
content_weight, style_weight, tv_weight = 1, 1e3, 10
def compute_loss(X, contents_Y_hat, styles_Y_hat, contents_Y, styles_Y_gram):
    # Compute the content loss, style loss, and total variation loss separately
contents_l = [
content_loss(Y_hat, Y) * content_weight
for Y_hat, Y in zip(contents_Y_hat, contents_Y)]
styles_l = [
style_loss(Y_hat, Y) * style_weight
for Y_hat, Y in zip(styles_Y_hat, styles_Y_gram)]
tv_l = tv_loss(X) * tv_weight
# Sum all losses
l = sum(10 * styles_l + contents_l + [tv_l])
return contents_l, styles_l, tv_l, l
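One detail worth noting in `compute_loss`: multiplying a Python list by 10, as in `10 * styles_l`, repeats its elements rather than scaling them, so each style-layer loss is effectively counted 10 times on top of `style_weight`. A tiny snippet makes this explicit:
demo = [torch.tensor(1.0), torch.tensor(2.0)]
print(sum(10 * demo))  # tensor(30.) -- (1 + 2) repeated 10 times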
5. Initialize the composite image
In style transfer, the composite image is the only variable that needs to be updated during training. Therefore, we can define a simple model, `SynthesizedImage`, and treat the composite image as its model parameter. The forward computation of this model simply returns the model parameter.
class SynthesizedImage(nn.Module):
def __init__(self, img_shape, **kwargs):
super(SynthesizedImage, self).__init__(**kwargs)
self.weight = nn.Parameter(torch.rand(*img_shape))
def forward(self):
return self.weight
def get_inits(X, device, lr, styles_Y):
""" This function creates a model instance of the composite image , And initialize it as an image `X` . The style image is in the Gram matrix of each style layer `styles_Y_gram` Will be pre calculated before training . """
gen_img = SynthesizedImage(X.shape).to(device)
gen_img.weight.data.copy_(X.data)
trainer = torch.optim.Adam(gen_img.parameters(), lr=lr)
styles_Y_gram = [gram(Y) for Y in styles_Y]
return gen_img(), styles_Y_gram, trainer
6. Model Training
When training the model for style transfer, we continuously extract the content features and style features of the composite image and then compute the loss function. The training loop is defined below. This training process differs from conventional neural network training in several ways:
- the loss function is more complex;
- we only update the input (which means the input X must carry gradients);
- we may change which layers are used to match content and style, and adjust the weights among the losses, to obtain outputs in different styles;
- a gradient-based optimizer is still used (Adam in the code below), and the learning rate is reduced every `lr_decay_epoch` epochs by a scheduler.
def train(X, contents_Y, styles_Y, device, lr, num_epochs, lr_decay_epoch):
X, styles_Y_gram, trainer = get_inits(X, device, lr, styles_Y)
    scheduler = torch.optim.lr_scheduler.StepLR(trainer, lr_decay_epoch, 0.8)  # decay the learning rate by a factor of 0.8 every lr_decay_epoch epochs
animator = d2l.Animator(xlabel='epoch', ylabel='loss',
xlim=[10, num_epochs],
legend=['content', 'style',
'TV'], ncols=2, figsize=(7, 2.5))
for epoch in range(num_epochs):
trainer.zero_grad()
contents_Y_hat, styles_Y_hat = extract_features(
X, content_layers, style_layers)
contents_l, styles_l, tv_l, l = compute_loss(X, contents_Y_hat,
styles_Y_hat, contents_Y,
styles_Y_gram)
l.backward()
trainer.step()
scheduler.step()
if (epoch + 1) % 10 == 0:
animator.axes[1].imshow(postprocess(X))
animator.add(
epoch + 1,
[float(sum(contents_l)),
float(sum(styles_l)),
float(tv_l)])
return X
# Now we train the model: first we set the height and width of the content and style images to 300 and 450 pixels, and initialize the composite image with the content image.
device, image_shape = d2l.try_gpu(), (300, 450)
net = net.to(device)
content_X, contents_Y = get_contents(image_shape, device)
_, styles_Y = get_styles(image_shape, device)
output = train(content_X, contents_Y, styles_Y, device, 0.1, 500, 200)
plt.imshow(postprocess(output))
plt.savefig("test.png")
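Alternatively (a small usage note), since `postprocess` returns a PIL image, the composite image can be saved directly without going through matplotlib:
postprocess(output).save('composite.jpg')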
III. Summary
- The loss of style transfer consists of three parts:
  - content loss: makes the composite image close to the content image in content features;
  - style loss: makes the composite image close to the style image in style features;
  - total variation loss: reduces the noise in the composite image.
- A pretrained convolutional neural network can be used to extract image features; the composite image, treated as the model parameter, is continuously updated by minimizing the loss function.
- The Gram matrix is used to express the style output by a style layer.