PyTorch Learning Notes (VII): Vision Transformer
2022-06-25 02:35:00 【Clear thinking】
Contents
I. Patch and Linear map
II. Adding classification token
III. Positional encoding
IV. LN, MSA and Residual Connection
V. LN, MLP and Residual Connection
VI. Classification MLP
Preface: Since the Vision Transformer (ViT) was introduced by Dosovitskiy et al., it has played a leading role in computer vision, in most cases surpassing traditional convolutional neural networks (CNNs). The Transformer actually comes from the field of natural language processing (NLP), and the overall idea of ViT is not much different from NLP: a complete picture is divided into several tokens, and these tokens are fed into the network, similar to how a sentence is input in NLP; each separated token is like a little word.

This figure was published in Vision Transformers for Remote Sensing Image Classification; let me borrow it here.
From this figure you can see a picture separated into nine equally sized sub-images, x1-x9. These sub-images are linearly embedded, so each one is now just a one-dimensional vector. You can also see that x1-x9 are cut from the original picture in order; that is important, because positional information is then added to these token vectors, and through it the network can recover where each sub-image sits in the original image.
After the positional information is embedded, these tokens are passed into the transformer encoder together with one extra token used for classification. This is why the sequence length gains a +1 when the data is passed in: the 1 is the classification token. The transformer encoder contains a layer normalization (LN), multi-head self-attention (MSA) and a residual connection, followed by a second LN, a multi-layer perceptron (MLP) and another residual connection. In general, the encoder block can be repeated many times, similar to ResNet. Finally, a classification MLP block produces the classification from the special classification token that was passed in at the start.
Now look back at the figure above; things should be a little clearer.
I. Patch and Linear map
The first question is how to turn a picture into something like an English sentence. The author's method is to divide it into several sub-images and map each one to a vector, in positional order.
For example, take a 3*224*224 image (3 is the number of RGB channels). We can divide it into 14*14 patches, each of size 16*16:

(N, C, H, W) → (N, 3, 224, 224) → (N, patches, patch_dim) → (N, 14*14, 3*16*16)

The 3*224*224 input thus becomes (196, 768) per sample: each patch covers 16*16 pixels across 3 channels, so it flattens to 3*16*16 = 768 values. Each flattened patch is then fed through a linear mapping, which can project it to a vector of any size; call that size the hidden dimension. The code below maps to a hidden dimension of 8, while the shape walkthroughs that follow use 256. Note that the image height and width must be divisible by the number of patches per side.
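A minimal sketch of this patch-split-and-flatten step; the tensor names are illustrative, and torch.Tensor.unfold is one way (of several) to cut square patches:

import torch
import torch.nn as nn

N, C, H, W = 4, 3, 224, 224
images = torch.rand(N, C, H, W)

ph = pw = 16                                                  # patch height/width
patches = images.unfold(2, ph, ph).unfold(3, pw, pw)          # (N, C, 14, 14, 16, 16)
patches = patches.permute(0, 2, 3, 1, 4, 5).reshape(N, 14 * 14, C * ph * pw)  # (N, 196, 768)

linear_mapper = nn.Linear(C * ph * pw, 8)                     # map each flattened patch to hidden_d = 8
tokens = linear_mapper(patches)                               # (N, 196, 8)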
II. Adding classification token
As said before, when the tokens are passed into the transformer encoder, a classification token is added. Its purpose is to capture information about the other tokens, which happens inside MSA. Once the whole image has passed through the encoder, we can use this one classification token alone to classify the image.
Sticking with the 3*224*224 example, with a hidden dimension of 256:

(N, 196, 256) → (N, 196+1, 256)

The extra 1 is the classification token.
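A minimal sketch of prepending the classification token, continuing from the tokens tensor above. The full code below does this with torch.vstack per sequence; expand plus cat, shown here, is an equivalent alternative:

class_token = nn.Parameter(torch.rand(1, 8))     # one learnable token, hidden_d = 8
cls = class_token.expand(N, 1, 8)                # copy it across the batch: (N, 1, 8)
tokens = torch.cat([cls, tokens], dim=1)         # (N, 196+1, 8)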
III. Positional encoding
When the network receives each of these patches as input, how does it know where each patch sat in the original image?
Following the research of Vaswani et al., this can be done simply by adding sine and cosine waves to the tokens.
Since the token tensor has size (N, 197, 256), the (197, 256) positional encoding is simply repeated N times along the batch dimension.
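For reference, the sinusoidal formula implemented by get_positional_embeddings in the full code below (a standard choice, not the only one): for token position i and embedding dimension j out of d,

PE(i, j) = sin(i / 10000^(j/d)) for even j, and PE(i, j) = cos(i / 10000^((j-1)/d)) for odd j.

Adding it to the tokens is then a simple broadcast (assuming tokens has shape (N, 197, 256)):

pos_emb = get_positional_embeddings(197, 256)    # (197, 256), one row per token position
tokens = tokens + pos_emb.repeat(N, 1, 1)        # the same encoding repeated N times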
IV. LN, MSA and Residual Connection
LN: given an input, subtract the mean and divide by the standard deviation (with a learned scale and shift per dimension).
MSA: map each patch token to 3 different vectors: q, k and v. After the mapping, take the dot product of q with each k, divide by the square root of the head dimension, apply softmax to these results (the attention scores), and finally multiply each attention score by the corresponding v and sum the results (a bit dry, admittedly).

Each self-attention head gets its own set of Q, K, V mapping functions.
Let's use an example to illustrate:

(N, 197, 256) → (N, 197, 16, 16) → nn.Linear(16, 16) → (N, 197, 256)

The input is (N, 197, 256). With multi-head attention (16 heads here), each 256-dim vector is split into 16 heads of 16 dimensions, giving (N, 197, 16, 16); each head applies its own nn.Linear(16, 16) mappings, and the head outputs are concatenated back to (N, 197, 256).
Residual connection: add the block's input directly to its output.

It was said before that a classification token is added before the tokens enter the transformer encoder. How does it obtain information from the other tokens? Through MSA: after LN, MSA and the residual addition, the classification token carries information from all the other tokens.
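A minimal sketch of one such pre-norm step (LN → attention → residual), using a toy single-head attention in place of the multi-head MyMSA class given later; all names and shapes here are illustrative:

N, seq_len, d = 2, 197, 256
x = torch.rand(N, seq_len, d)                    # token sequences

ln = nn.LayerNorm(d)
q_map, k_map, v_map = nn.Linear(d, d), nn.Linear(d, d), nn.Linear(d, d)

h = ln(x)                                        # layer-normalize the tokens
q, k, v = q_map(h), k_map(h), v_map(h)           # project to queries, keys, values
scores = q @ k.transpose(-2, -1) / d ** 0.5      # scaled dot products: (N, 197, 197)
attn = scores.softmax(dim=-1)                    # attention weights per token
out = x + attn @ v                               # residual connection around attention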
V. LN, MLP and Residual Connection
As mentioned, the first step inside the transformer encoder block is LN, MSA and a residual connection; the second step adds another LN, an MLP and another residual connection (see the sketch below).
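A minimal sketch of this second step, reusing the out tensor and dimension d from the previous snippet; the two-layer perceptron here is illustrative:

ln2 = nn.LayerNorm(d)
mlp = nn.Sequential(nn.Linear(d, d), nn.ReLU())
out = out + mlp(ln2(out))                        # LN -> MLP -> residual connection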
VI. Classification MLP
After this series of operations, the network outputs one vector per token. In the classification MLP, we extract only the classification token from each of the N sequences and use that token alone to produce the classification.

For example, if each token we kept is a 16-dim vector and there are 5 classes to distinguish, we can use an MLP with a 16*5 weight matrix, activated with a softmax function.
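A minimal sketch of that head under those numbers (16-dim tokens, 5 classes); the encoder output tensor here is random stand-in data:

enc_out = torch.rand(N, 197, 16)                 # pretend encoder output, hidden dim 16
cls_token = enc_out[:, 0]                        # keep only the classification token: (N, 16)
head = nn.Sequential(nn.Linear(16, 5), nn.Softmax(dim=-1))
probs = head(cls_token)                          # (N, 5) class probabilities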
This completes the construction of the whole ViT network.

The full Python code is as follows:
import numpy as np
import torch
import torch.nn as nn


class MyViT(nn.Module):
    def __init__(self, input_shape, n_patches=14, hidden_d=8, n_heads=2, out_d=5, device=None):
        super(MyViT, self).__init__()
        self.device = device
        self.input_shape = input_shape  # (C, H, W)
        self.n_patches = n_patches      # number of patches per side
        self.n_heads = n_heads
        assert input_shape[1] % n_patches == 0, "Input height not divisible by number of patches"
        assert input_shape[2] % n_patches == 0, "Input width not divisible by number of patches"
        self.patch_size = (input_shape[1] // n_patches, input_shape[2] // n_patches)
        self.hidden_d = hidden_d
        # 1) Linear mapper: flattened patch (C * ph * pw values) -> hidden_d
        self.input_d = int(input_shape[0] * self.patch_size[0] * self.patch_size[1])
        self.linear_mapper = nn.Linear(self.input_d, self.hidden_d)
        # 2) Learnable classification token, prepended to every sequence
        self.class_token = nn.Parameter(torch.rand(1, self.hidden_d))
        # 3) Positional embedding (added in the forward method)
        # 4a) Layer normalization 1
        self.ln1 = nn.LayerNorm((self.n_patches ** 2 + 1, self.hidden_d))
        # 4b) Multi-head Self Attention (MSA)
        self.msa = MyMSA(self.hidden_d, n_heads)
        # 5a) Layer normalization 2
        self.ln2 = nn.LayerNorm((self.n_patches ** 2 + 1, self.hidden_d))
        # 5b) Encoder MLP
        self.enc_mlp = nn.Sequential(
            nn.Linear(self.hidden_d, self.hidden_d),
            nn.ReLU()
        )
        # 6) Classification MLP (softmax gives class probabilities; drop it if
        # training with nn.CrossEntropyLoss, which expects raw logits)
        self.mlp = nn.Sequential(
            nn.Linear(self.hidden_d, out_d),
            nn.Softmax(dim=-1)
        )
    def forward(self, images):
        n, c, h, w = images.shape
        ph, pw = self.patch_size
        # Cut each image into n_patches x n_patches square patches, then flatten each
        # patch (a plain reshape would slice pixel rows, not square patches)
        patches = images.unfold(2, ph, ph).unfold(3, pw, pw)  # (n, c, n_patches, n_patches, ph, pw)
        patches = patches.permute(0, 2, 3, 1, 4, 5).reshape(n, self.n_patches ** 2, self.input_d)
        tokens = self.linear_mapper(patches)  # (n, patches, hidden_d)
        # Prepend the learnable classification token to every sequence
        tokens = torch.stack([torch.vstack((self.class_token, tokens[i])) for i in range(len(tokens))])
        # Add the fixed sinusoidal positional encoding, repeated across the batch
        tokens = tokens + get_positional_embeddings(self.n_patches ** 2 + 1, self.hidden_d).repeat(n, 1, 1).to(self.device)
        out = tokens + self.msa(self.ln1(tokens))  # LN -> MSA -> residual
        out = out + self.enc_mlp(self.ln2(out))    # LN -> MLP -> residual
        out = out[:, 0]                            # keep only the classification token
        return self.mlp(out)
def get_positional_embeddings(sequence_length, d):
    # Standard sinusoidal encoding: sin on even dimensions, cos on odd ones
    result = torch.ones(sequence_length, d)
    for i in range(sequence_length):
        for j in range(d):
            result[i][j] = np.sin(i / (10000 ** (j / d))) if j % 2 == 0 else np.cos(i / (10000 ** ((j - 1) / d)))
    return result
class MyMSA(nn.Module):
    def __init__(self, d, n_heads=2):
        super(MyMSA, self).__init__()
        self.d = d
        self.n_heads = n_heads
        assert d % n_heads == 0, f"Can't divide dimension {d} into {n_heads} heads"
        d_head = int(d / n_heads)
        # One Q, K and V mapping per head, each acting on a d_head-dim slice
        self.q_mappings = nn.ModuleList([nn.Linear(d_head, d_head) for _ in range(self.n_heads)])
        self.k_mappings = nn.ModuleList([nn.Linear(d_head, d_head) for _ in range(self.n_heads)])
        self.v_mappings = nn.ModuleList([nn.Linear(d_head, d_head) for _ in range(self.n_heads)])
        self.d_head = d_head
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, sequences):
        result = []
        for sequence in sequences:  # loop over the batch (explicit but slow)
            seq_result = []
            for head in range(self.n_heads):
                q_mapping = self.q_mappings[head]
                k_mapping = self.k_mappings[head]
                v_mapping = self.v_mappings[head]
                # Each head sees its own slice of the hidden dimension
                seq = sequence[:, head * self.d_head: (head + 1) * self.d_head]
                q, k, v = q_mapping(seq), k_mapping(seq), v_mapping(seq)
                # Scaled dot-product attention for this head
                attention = self.softmax(q @ k.T / (self.d_head ** 0.5))
                seq_result.append(attention @ v)
            result.append(torch.hstack(seq_result))  # concatenate head outputs
        return torch.cat([torch.unsqueeze(r, dim=0) for r in result])
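A quick sanity check of the model above on random data (shapes match the 3*224*224 example; this is just a forward pass, not training):

model = MyViT(input_shape=(3, 224, 224), n_patches=14, hidden_d=8, n_heads=2, out_d=5)
images = torch.rand(4, 3, 224, 224)   # a batch of 4 random "images"
probs = model(images)                 # (4, 5) class probabilities
print(probs.shape)                    # torch.Size([4, 5])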
Corrections and suggestions are welcome. If you want the source code, leave a comment or send a private message; I'll reply when I see it.
Reference: https://medium.com/mlearning-ai/vision-transformers-from-scratch-pytorch-a-step-by-step-guide-96c3313c2e0c