
ViT Slimming: Joint Structure Search and Patch Selection

2022-06-09 00:18:00 Law-Yao

Paper address: https://arxiv.org/abs/2201.00814

GitHub link: https://github.com/Arnav0400/ViT-Slim

Methods

ViT Slimming combines structure search with patch selection: the former performs multi-dimensional, multi-scale structural compression, while the latter removes redundancy along the patch (token) length, so that both the parameter count and the computation cost are reduced effectively. Concretely, a soft mask is attached to every tensor flowing through the ViT and multiplied with that tensor during computation, and an L1 regularization constraint on the soft masks is added to the loss function, giving an objective of the form

min_{W, ζ}  L_train(W, ζ) + λ Σ_i ||ζ_i||_1

where ζ = {ζ_i} is the collection of soft-mask vectors, each corresponding to one intermediate tensor, W denotes the network weights, and λ controls the sparsity strength.
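As a rough illustration of how this L1 term can be attached to the training loss, the sketch below collects the `zeta` and `patch_zeta` soft-mask parameters of a model built from modules like the `SparseAttention` class shown later and adds their L1 norms to the task loss. The helper name `l1_penalty` and the coefficients `lambda_attn` / `lambda_patch` are illustrative assumptions, not the ViT-Slim repository's API.

import torch
import torch.nn as nn

def l1_penalty(model: nn.Module, lambda_attn: float = 1e-4, lambda_patch: float = 1e-3) -> torch.Tensor:
    # Sum the L1 norms of all soft masks (illustrative helper).
    reg = torch.zeros((), device=next(model.parameters()).device)
    for name, p in model.named_parameters():
        if name.endswith('patch_zeta'):
            reg = reg + lambda_patch * p.abs().sum()
        elif name.endswith('zeta'):  # attention / FFN masks
            reg = reg + lambda_attn * p.abs().sum()
    return reg

# Typical search-training step: task loss plus sparsity term, with masks and
# weights updated jointly (criterion, model, images, targets come from the usual loop):
#   loss = criterion(model(images), targets) + l1_penalty(model)
#   loss.backward(); optimizer.step()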

  • Structure search: first, a differentiable soft mask is introduced into the MHSA module to impose L1 sparsity on the feature size of each attention head (a similar construction can be used to sparsify the number of heads). Consistent with the code below, the mask of each head is multiplied elementwise onto that head's query, key and value:

Q̂_{l,h} = Q_{l,h} ⊙ ζ_{l,h},   K̂_{l,h} = K_{l,h} ⊙ ζ_{l,h},   V̂_{l,h} = V_{l,h} ⊙ ζ_{l,h}

where ζ_{l,h} denotes the soft mask of the h-th head in the l-th layer. Secondly, a differentiable soft mask is introduced into the FFN to impose L1 sparsity on its intermediate size, i.e. the mask is applied elementwise to the intermediate activation:

FFN_l(x) = ( σ(x W_{l,1} + b_{l,1}) ⊙ ζ'_l ) W_{l,2} + b_{l,2}

where ζ'_l denotes the corresponding soft mask (a simplified sketch of this FFN mask is given after the attention code below).

  • Patch selection: for the input (or output) tensor of every Transformer layer, a soft mask is likewise defined to remove low-importance patches, and the mask values are first passed through a Tanh to keep them bounded. In addition, a patch removed in a shallow layer must also be removed in the deeper layers, otherwise the computation would become inconsistent.
import torch
import torch.nn as nn

# `Attention` is the dense multi-head self-attention module that this sparse variant
# wraps; it is defined elsewhere in the ViT-Slim repository.
class SparseAttention(Attention):
    def __init__(self, attn_module, head_search=False, uniform_search=False):
        super().__init__(attn_module.qkv.in_features, attn_module.num_heads, True, attn_module.scale, attn_module.attn_drop.p, attn_module.proj_drop.p)
        self.is_searched = False
        self.num_gates = attn_module.qkv.in_features // self.num_heads
        # Soft mask (zeta) over the MHSA dimensions: per head (head_search),
        # shared across heads (uniform_search), or per head and per feature (default).
        if head_search:
            self.zeta = nn.Parameter(torch.ones(1, 1, self.num_heads, 1, 1))
        elif uniform_search:
            self.zeta = nn.Parameter(torch.ones(1, 1, 1, 1, self.num_gates))
        else:
            self.zeta = nn.Parameter(torch.ones(1, 1, self.num_heads, 1, self.num_gates))
        self.searched_zeta = torch.ones_like(self.zeta)
        # Soft mask over patches; initialized at 3 so that tanh(3) ≈ 1 (every patch starts as "keep").
        # self.num_patches is expected to be provided by the surrounding model definition.
        self.patch_zeta = nn.Parameter(torch.ones(1, self.num_patches, 1)*3)
        self.searched_patch_zeta = torch.ones_like(self.patch_zeta)
        self.patch_activation = nn.Tanh()
    
    def forward(self, x):
        # Apply the patch mask (binary after search, tanh(soft mask) during search) to the input tokens.
        z_patch = self.searched_patch_zeta if self.is_searched else self.patch_activation(self.patch_zeta)
        x *= z_patch
        B, N, C = x.shape
        z = self.searched_zeta if self.is_searched else self.zeta
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) # 3, B, H, N, d(C/H)
        # Apply the MHSA mask to q, k and v of every head.
        qkv *= z
        q, k, v = qkv[0], qkv[1], qkv[2]   # make torchscript happy (cannot use tensor as tuple) # B, H, N, d

        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)

        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x

    def compress(self, threshold_attn):
        # Binarize the MHSA soft mask (keep dimensions whose zeta reaches the threshold) and freeze it.
        self.is_searched = True
        self.searched_zeta = (self.zeta >= threshold_attn).float()
        self.zeta.requires_grad = False

    def compress_patch(self, threshold_patch=None, zetas=None):
        # Replace the patch soft mask with an externally computed binary mask (a numpy array),
        # so that patches removed in shallow layers are also removed in deeper layers, and freeze it.
        self.is_searched = True
        zetas = torch.from_numpy(zetas).reshape_as(self.patch_zeta)
        self.searched_patch_zeta = zetas.float().to(self.zeta.device)
        self.patch_zeta.requires_grad = False
The differentiable soft masks are trained jointly with the network weights, and the weights are initialized from pretrained parameters, so the overall time cost of the search training is relatively low. After search training, the soft-mask values are sorted by magnitude and the unimportant network weights or patches are pruned away, which realizes the structure search and the patch selection. Once the specific slimmed structure has been extracted, an additional re-training step is needed to recover the model accuracy.
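As a rough sketch of this pruning step (the budget-to-threshold helper below and the way it iterates over `SparseAttention` modules are assumptions for illustration, not the repository's actual pruning script): gather all attention-mask values, sort them, convert a keep ratio into a threshold, and binarize each module's mask.

import torch

def compute_threshold(model, keep_ratio=0.5):
    # Collect all MHSA soft-mask values (SparseAttention is the class defined above)
    # and pick the threshold that keeps the top `keep_ratio` fraction of them.
    zetas = torch.cat([m.zeta.detach().flatten()
                       for m in model.modules() if isinstance(m, SparseAttention)])
    sorted_zetas, _ = torch.sort(zetas, descending=True)
    k = max(int(keep_ratio * sorted_zetas.numel()) - 1, 0)
    return sorted_zetas[k].item()

# threshold = compute_threshold(model, keep_ratio=0.5)
# for m in model.modules():
#     if isinstance(m, SparseAttention):
#         m.compress(threshold)  # binarize and freeze the MHSA masks
# # Patch masks are binarized consistently across layers via compress_patch(...),
# # after which the slimmed model is re-trained to recover accuracy.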

Experimental results

For more discussion of Transformer model compression and optimization acceleration, refer to the following article:

Bert/Transformer Model Compression and Optimization Acceleration (Law-Yao's CSDN blog)
