ViT: Vision Transformer explained in detail, with code
2022-07-26 23:14:00 【Thinking and acting recklessly】
Original paper: An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale
1. ViT model architecture

Specific steps:
1.1 Split the image into patches
1.2 Turn each patch into an embedding
A patch is a square block of pixels and cannot be fed to the Transformer (TRM) directly, so each patch must first be turned into an embedding of fixed dimension, and that embedding is what the Transformer takes as input. Method 1: flatten the patch, turning two dimensions into one (e.g. an original 16x16 patch becomes a 256-dimensional vector). Method 2: map the flattened dimension to a vector length of your own choosing.
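In the code in section 2 these two steps happen back to back: flatten, then a linear projection. A minimal sketch of that flatten-and-project flow is below; the 16x16x3 RGB patch (so the flattened length is 16*16*3 = 768 rather than 256) and the 1024-dimensional target are illustrative choices matching the full code later on.

    import torch
    from torch import nn

    # One 16x16 RGB patch: flattening 2D -> 1D gives 16*16*3 = 768 values
    patch = torch.randn(16, 16, 3)
    flattened = patch.reshape(-1)                 # shape: (768,)

    # Map the flattened patch to a chosen embedding size, e.g. 1024
    to_embedding = nn.Linear(flattened.numel(), 1024)
    patch_embedding = to_embedding(flattened)     # shape: (1024,)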
Note: the paper experiments with two implementations of this step. The one used here is a Linear Projection, i.e. a plain linear transformation. The alternative is a convolution: with patch = 16*16, you can use a 16*16 convolution kernel with a stride of 16 and 768 output channels, so that each patch is mapped directly to the 768 dimensions expected by the TRM encoder.
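A minimal sketch of this convolutional variant (the 224x224 RGB input size is an assumption for illustration; the 16x16 kernel, stride 16 and 768 output channels come from the note above):

    import torch
    from torch import nn

    # Convolutional patch embedding: 16x16 kernel, stride 16, 768 output channels
    conv_embed = nn.Conv2d(in_channels=3, out_channels=768, kernel_size=16, stride=16)

    img = torch.randn(1, 3, 224, 224)          # one 224x224 RGB image
    feat = conv_embed(img)                     # shape: (1, 768, 14, 14)
    tokens = feat.flatten(2).transpose(1, 2)   # shape: (1, 196, 768), one 768-d token per patch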
1.3 Add the position embedding and the token embedding
First generate the token embedding of the CLS symbol (the * in the figure), then generate the position codes for the whole sequence (1, 2, 3, ... in the figure). Adding the pink (position) embeddings to the purple (token) embeddings gives the input embedding; see the sketch after the position-encoding note below.
Why add a CLS symbol?
The paper later shows that CLS does not matter much; its role is to minimize changes to the original TRM model. BERT uses CLS because it has two pre-training tasks: NSP (a binary classification task that predicts the next sentence) and MLM (predicting the masked word). If both tasks computed their losses from pooled token outputs, the same tokens would be reused across the two losses; using CLS keeps the two tasks relatively independent. ViT, however, has no MLM-style task, only a single multi-class classification task, so the CLS symbol is not strictly necessary.
Position encoding
To preserve the spatial relationships between the input image patches, a position-encoding vector also has to be added to each patch embedding, denoted Epos in the formula above. ViT does not use a 2D-aware position embedding; it directly uses a one-dimensional learnable position embedding, because the authors found that 2D position embeddings did not give better results than 1D ones.
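Putting 1.3 and this note together, here is a minimal sketch (dim = 1024 and 196 patches are illustrative, matching the code in section 2). Note that the position embedding is just a learnable table over the 197 sequence positions, with no 2D grid structure:

    import torch
    from torch import nn

    dim, num_patches = 1024, 196
    patch_tokens = torch.randn(1, num_patches, dim)                      # purple: patch embeddings

    cls_token = nn.Parameter(torch.randn(1, 1, dim))                     # the * symbol in the figure
    pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))   # pink: 1D learnable positions

    x = torch.cat((cls_token, patch_tokens), dim=1)                      # prepend CLS: (1, 197, 1024)
    x = x + pos_embedding                                                # add the position embedding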
1.4 Feed the input into the TRM model
The input first goes through a normalization layer and then into the self-attention layer; the output and the input are combined with a residual connection, then passed through another normalization layer and into the feed-forward network, followed by another residual connection. This is repeated once for every Encoder layer, and in the end every token produces an output.
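A minimal sketch of one such pre-norm encoder block; nn.MultiheadAttention stands in here for the custom Attention class defined in section 2, and the dimensions are illustrative:

    import torch
    from torch import nn

    dim, heads = 1024, 16
    norm1, norm2 = nn.LayerNorm(dim), nn.LayerNorm(dim)
    attn = nn.MultiheadAttention(dim, heads, batch_first=True)
    ffn = nn.Sequential(nn.Linear(dim, 2048), nn.GELU(), nn.Linear(2048, dim))

    x = torch.randn(1, 197, dim)                  # CLS + 196 patch tokens
    h = norm1(x)
    x = attn(h, h, h, need_weights=False)[0] + x  # LayerNorm -> self-attention -> residual
    x = ffn(norm2(x)) + x                         # LayerNorm -> feed-forward -> residual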
1.5 Use the CLS output for the classification task
Take the output at the first position (the CLS token) and use it for the multi-class classification task.
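As a sketch of this step (the shapes and the LayerNorm + Linear head mirror the mlp_head in the code below):

    import torch
    from torch import nn

    dim, num_classes = 1024, 1000
    encoder_out = torch.randn(1, 197, dim)       # output of the Transformer encoder
    cls_out = encoder_out[:, 0]                  # the first token (CLS): shape (1, 1024)

    mlp_head = nn.Sequential(nn.LayerNorm(dim), nn.Linear(dim, num_classes))
    logits = mlp_head(cls_out)                   # (1, 1000) class logits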
2. Code
import torch
from torch import nn

from einops import rearrange, repeat
from einops.layers.torch import Rearrange

# helpers

def pair(t):
    return t if isinstance(t, tuple) else (t, t)

# classes

class PreNorm(nn.Module):
    def __init__(self, dim, fn):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.fn = fn

    def forward(self, x, **kwargs):
        return self.fn(self.norm(x), **kwargs)

class FeedForward(nn.Module):
    def __init__(self, dim, hidden_dim, dropout = 0.):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, dim),
            nn.Dropout(dropout)
        )

    def forward(self, x):
        return self.net(x)

# Multi-head attention
class Attention(nn.Module):
    def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
        super().__init__()
        inner_dim = dim_head * heads
        project_out = not (heads == 1 and dim_head == dim)

        self.heads = heads
        self.scale = dim_head ** -0.5

        self.attend = nn.Softmax(dim = -1)
        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)  # map the dim dimension to inner_dim * 3

        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, dim),
            nn.Dropout(dropout)
        ) if project_out else nn.Identity()

    def forward(self, x):
        qkv = self.to_qkv(x).chunk(3, dim = -1)  # chunk the tensor: for x of shape 1*197*1024, qkv is a tuple of length 3, each element of shape 1*197*1024
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv)

        dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale

        attn = self.attend(dots)

        out = torch.matmul(attn, v)  # multiply by the corresponding v matrix
        out = rearrange(out, 'b h n d -> b n (h d)')  # reshape back
        return self.to_out(out)

class Transformer(nn.Module):
    def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0.):
        super().__init__()
        self.layers = nn.ModuleList([])
        # stack multiple encoder layers together
        for _ in range(depth):
            self.layers.append(nn.ModuleList([
                PreNorm(dim, Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout)),  # multi-head attention
                PreNorm(dim, FeedForward(dim, mlp_dim, dropout = dropout))  # feed-forward network
            ]))

    def forward(self, x):
        for attn, ff in self.layers:
            x = attn(x) + x
            x = ff(x) + x
        return x

# The overall architecture
class ViT(nn.Module):
    def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, pool = 'cls', channels = 3, dim_head = 64, dropout = 0., emb_dropout = 0.):
        super().__init__()
        image_height, image_width = pair(image_size)  # 224 * 224
        patch_height, patch_width = pair(patch_size)  # 16 * 16

        assert image_height % patch_height == 0 and image_width % patch_width == 0, 'Image dimensions must be divisible by the patch size.'

        num_patches = (image_height // patch_height) * (image_width // patch_width)  # how many patches the image is split into
        patch_dim = channels * patch_height * patch_width  # flattened patch: width * height * number of channels
        assert pool in {'cls', 'mean'}, 'pool type must be either cls (cls token) or mean (mean pooling)'

        # flatten the image patches and map them to the encoder dimension
        self.to_patch_embedding = nn.Sequential(
            Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_height, p2 = patch_width),
            nn.Linear(patch_dim, dim),
        )

        self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))  # generate all position embeddings
        self.cls_token = nn.Parameter(torch.randn(1, 1, dim))  # initialization parameters of the CLS token
        self.dropout = nn.Dropout(emb_dropout)

        self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout)  # once the input is prepared, feed it into the TRM

        self.pool = pool
        self.to_latent = nn.Identity()

        self.mlp_head = nn.Sequential(
            nn.LayerNorm(dim),
            nn.Linear(dim, num_classes)
        )

    def forward(self, img):
        x = self.to_patch_embedding(img)  # img: batch 1, 3 channels, 224*224; output x: 1*196*1024
        b, n, _ = x.shape

        cls_tokens = repeat(self.cls_token, '() n d -> b n d', b = b)  # replicate b copies: each sample in the batch gets its own CLS symbol
        x = torch.cat((cls_tokens, x), dim=1)  # concatenate the CLS token embedding with the patch embeddings
        x += self.pos_embedding[:, :(n + 1)]  # add the position embedding
        x = self.dropout(x)

        x = self.transformer(x)

        x = x.mean(dim = 1) if self.pool == 'mean' else x[:, 0]

        x = self.to_latent(x)
        return self.mlp_head(x)

v = ViT(
    image_size = 224,      # input image size
    patch_size = 16,       # size of each patch
    num_classes = 1000,    # number of classes the final CLS output is mapped to
    dim = 1024,
    depth = 6,             # number of encoder layers
    heads = 16,            # number of attention heads
    mlp_dim = 2048,
    dropout = 0.1,
    emb_dropout = 0.1
)

img = torch.randn(1, 3, 224, 224)

preds = v(img)  # (1, 1000)

3. Summary
I wanted to write a summary, but it is hard to summarize. The idea itself is quite simple and there is not much to it, yet I still do not fully understand it. I will leave it here for now and revise it when I get the chance.