[leave some code] Applying the Transformer to object detection: understanding how the model runs through debugging
2022-07-26 15:08:00 【Shameful child】
End-to-End Object Detection with Transformers
Facebook's Transformer-based end-to-end object detection network (DETR), published at ECCV 2020. The code is open source: facebookresearch/detr: End-to-End Object Detection with Transformers (github.com). This post applies a simplified version of DETR with its trained weight file to run inference and detection on images.
DETR model constructor
def __init__(self, num_classes, hidden_dim=256, nheads=8, num_encoder_layers=6, num_decoder_layers=6):
    super().__init__()
    # create ResNet-50 backbone: resnet50 from torchvision.models serves as the feature
    # extractor, so the fully connected layer used for image classification is not needed
    self.backbone = resnet50()
    del self.backbone.fc
    # create conversion layer: a 1x1 convolution projects the 2048-dim backbone features
    # down to 256 dims, i.e. 256 kernels of size 1*1*2048
    self.conv = nn.Conv2d(2048, hidden_dim, 1)
    # the transformer model comes from `from torch import nn`; the hyperparameters
    # set above are passed into it
    self.transformer = nn.Transformer(hidden_dim, nheads, num_encoder_layers, num_decoder_layers)
    self.linear_class = nn.Linear(hidden_dim, num_classes + 1)  # class prediction: one confidence per category, plus one extra "no object" class
    self.linear_bbox = nn.Linear(hidden_dim, 4)  # box prediction: a 1*4 vector of normalized (cx, cy, w, h)
    # output positional encodings (object queries)
    self.query_pos = nn.Parameter(torch.rand(100, hidden_dim))  # 100 object queries, under the assumption that "no more than 100 objects need to be detected"
    # spatial positional encodings
    self.row_embed = nn.Parameter(torch.rand(50, hidden_dim // 2))
    self.col_embed = nn.Parameter(torch.rand(50, hidden_dim // 2))
    # note that in baseline DETR we use sine positional encodings
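As a quick debugging check, you can instantiate the class and inspect the learned parameters. A minimal sketch (it assumes the complete DETRdemo class from the full listing further below is already defined):

model = DETRdemo(num_classes=91)
print(model.query_pos.shape)                          # torch.Size([100, 256]): the 100 object queries
print(model.row_embed.shape, model.col_embed.shape)   # torch.Size([50, 128]) each: spatial position embeddings
print(sum(p.numel() for p in model.parameters()))     # total parameter count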
- DETR forward function
def forward(self, inputs):
    # propagate inputs through ResNet-50 up to the avg-pool layer
    x = self.backbone.conv1(inputs)  # nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3, bias=False)
    x = self.backbone.bn1(x)
    x = self.backbone.relu(x)
    x = self.backbone.maxpool(x)  # nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
    x = self.backbone.layer1(x)
    x = self.backbone.layer2(x)
    x = self.backbone.layer3(x)
    x = self.backbone.layer4(x)
    # convert from 2048 to 256 feature planes for the transformer
    h = self.conv(x)
    # construct positional encodings
    H, W = h.shape[-2:]
    pos = torch.cat([
        self.col_embed[:W].unsqueeze(0).repeat(H, 1, 1),
        self.row_embed[:H].unsqueeze(1).repeat(1, W, 1),
    ], dim=-1).flatten(0, 1).unsqueeze(1)
    # propagate through the transformer (features scaled by 0.1 and added to the positional encodings)
    h = self.transformer(pos + 0.1 * h.flatten(2).permute(2, 0, 1),
                         self.query_pos.unsqueeze(1)).transpose(0, 1)
    # finally project transformer outputs to class labels and bounding boxes
    return {'pred_logits': self.linear_class(h),
            'pred_boxes': self.linear_bbox(h).sigmoid()}
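The dimension changes are easiest to follow by pushing a dummy tensor through the model and checking the output shapes, which mirrors what you see when stepping through forward() in a debugger. A minimal sketch, assuming the model instance from above:

x = torch.rand(1, 3, 800, 800)     # dummy input batch
with torch.no_grad():
    out = model(x)
print(out['pred_logits'].shape)    # torch.Size([1, 100, 92]): 100 queries x (91 classes + 1 "no object" class)
print(out['pred_boxes'].shape)     # torch.Size([1, 100, 4]): normalized (cx, cy, w, h) per query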

For more on DETR, you can refer to https://www.bilibili.com/video/BV1GB4y1X72R, starting around 30:00.
This program's DETR model uses ResNet-50 as the feature extractor. Below is the forward pass of torchvision's residual BasicBlock (note that ResNet-50 itself is built from the three-convolution Bottleneck variant, which follows the same identity-shortcut pattern):
def forward(self, x: Tensor) -> Tensor:
    identity = x
    out = self.conv1(x)
    out = self.bn1(out)
    out = self.relu(out)
    out = self.conv2(out)
    out = self.bn2(out)
    if self.downsample is not None:
        identity = self.downsample(x)
    out += identity
    out = self.relu(out)
    return out
Load the trained weight file
state_dict = torch.hub.load_state_dict_from_url(  # downloads the weight file on first run; afterwards it is read from the local cache
    url='https://dl.fbaipublicfiles.com/detr/detr_demo-da2a99e9.pth',
    map_location='cpu', check_hash=True)
detr.load_state_dict(state_dict)
You can further understand how the model operates by reading the dimensions of the tensors in the weight file, and the attributes attached to each layer.
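For example, iterating over the loaded state_dict prints each tensor's name and dimensions. A minimal sketch, assuming the state_dict loaded above:

for name, tensor in list(state_dict.items())[:5]:
    print(name, tuple(tensor.shape))
# the first entry should be ResNet-50's 7x7 stem convolution, e.g. backbone.conv1.weight (64, 3, 7, 7)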
Detection function
def detect(im, model, transform):
    # mean-std normalize the input image (batch-size: 1)
    img = transform(im).unsqueeze(0)
    # by default the demo model only supports images with an aspect ratio between 0.5 and 2;
    # if you want to use an image with an aspect ratio outside this range,
    # rescale your image so that the maximum side is at most 1333 for best results
    assert img.shape[-2] <= 1600 and img.shape[-1] <= 1600, 'demo model only supports images up to 1600 pixels on each side'
    # propagate the preprocessed image through the model
    outputs = model(img)
    # keep only predictions with 0.7+ confidence (lowering the threshold may detect more boxes)
    probas = outputs['pred_logits'].softmax(-1)[0, :, :-1]
    keep = probas.max(-1).values > 0.7
    # convert boxes from [0; 1] to image scale
    bboxes_scaled = rescale_bboxes(outputs['pred_boxes'][0, keep], im.size)
    return probas[keep], bboxes_scaled  # returns the kept confidence tensor and the box coordinates in the original image
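The final step converts boxes from normalized (cx, cy, w, h) values in [0, 1] to pixel-space corner coordinates. A small worked example, assuming the box_cxcywh_to_xyxy and rescale_bboxes helpers from the full listing below:

box = torch.tensor([[0.5, 0.5, 0.2, 0.4]])   # a centered box spanning 20% of the width and 40% of the height
print(rescale_bboxes(box, (1000, 800)))      # tensor([[400., 240., 600., 560.]]) as (x1, y1, x2, y2)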
Complete code (runnable in Colab); you can change lines 162-165 of the code (where the test image is loaded) to run detection on your own image.
from PIL import Image
# This program downloads data online (a few test images) and the model file; requests is imported to make the downloads easy
import requests
import matplotlib.pyplot as plt
# %config InlineBackend.figure_format = 'retina' placed after %matplotlib inline renders higher-resolution figures.
# %config InlineBackend.figure_format = 'retina'
import torch
from torch import nn
# the backbone is resnet50, used to extract features
from torchvision.models import resnet50
import torchvision.transforms as T
# inference only: gradients are not needed, so disabling them saves memory and leaves the trained weights untouched
torch.set_grad_enabled(False)
class DETRdemo(nn.Module):
    """
    Demo DETR implementation.

    Demo implementation of DETR in minimal number of lines, with the
    following differences wrt DETR in the paper:
    * learned positional encoding (instead of sine)
    * positional encoding is passed at input (instead of attention)
    * fc bbox predictor (instead of MLP)

    The model achieves ~40 AP on COCO val5k and runs at ~28 FPS on Tesla V100.
    Only batch size 1 supported.
    """
    # DETR constructor; by default 6 encoder and 6 decoder layers and 8 attention heads;
    # for a dataset like COCO, num_classes is 91
    def __init__(self, num_classes, hidden_dim=256, nheads=8, num_encoder_layers=6, num_decoder_layers=6):
        super().__init__()
        # create ResNet-50 backbone: resnet50 from torchvision.models serves as the feature
        # extractor, so the fully connected layer used for image classification is not needed
        self.backbone = resnet50()
        del self.backbone.fc
        # create conversion layer: a 1x1 convolution projects the 2048-dim backbone features
        # down to 256 dims, i.e. 256 kernels of size 1*1*2048
        self.conv = nn.Conv2d(2048, hidden_dim, 1)
        # create a default PyTorch transformer and pass the hyperparameters set above into it
        self.transformer = nn.Transformer(hidden_dim, nheads, num_encoder_layers, num_decoder_layers)
        # prediction heads, one extra class for predicting non-empty slots
        # note that in baseline DETR linear_bbox layer is 3-layer MLP
        self.linear_class = nn.Linear(hidden_dim, num_classes + 1)  # class prediction: one confidence per category, plus one extra "no object" class
        self.linear_bbox = nn.Linear(hidden_dim, 4)  # box prediction: a 1*4 vector of normalized (cx, cy, w, h)
        # output positional encodings (object queries)
        self.query_pos = nn.Parameter(torch.rand(100, hidden_dim))  # 100 object queries, i.e. up to 100 candidate boxes
        # spatial positional encodings
        self.row_embed = nn.Parameter(torch.rand(50, hidden_dim // 2))
        self.col_embed = nn.Parameter(torch.rand(50, hidden_dim // 2))
        # note that in baseline DETR we use sine positional encodings
    def forward(self, inputs):
        # propagate inputs through ResNet-50 up to the avg-pool layer
        x = self.backbone.conv1(inputs)  # nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3, bias=False)
        x = self.backbone.bn1(x)
        x = self.backbone.relu(x)
        x = self.backbone.maxpool(x)  # nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        x = self.backbone.layer1(x)
        x = self.backbone.layer2(x)
        x = self.backbone.layer3(x)
        x = self.backbone.layer4(x)
        # convert from 2048 to 256 feature planes for the transformer
        h = self.conv(x)
        # construct positional encodings
        H, W = h.shape[-2:]
        pos = torch.cat([
            self.col_embed[:W].unsqueeze(0).repeat(H, 1, 1),
            self.row_embed[:H].unsqueeze(1).repeat(1, W, 1),
        ], dim=-1).flatten(0, 1).unsqueeze(1)
        # propagate through the transformer (features scaled by 0.1 and added to the positional encodings)
        h = self.transformer(pos + 0.1 * h.flatten(2).permute(2, 0, 1),
                             self.query_pos.unsqueeze(1)).transpose(0, 1)
        # finally project transformer outputs to class labels and bounding boxes
        return {'pred_logits': self.linear_class(h),
                'pred_boxes': self.linear_bbox(h).sigmoid()}
detr = DETRdemo(num_classes=91)  # COCO: class IDs run up to 90 with gaps, so 91 slots cover the 80 actual object categories (person, car, elephant, ...) plus 11 unused 'N/A' entries
state_dict = torch.hub.load_state_dict_from_url(  # downloads the weight file on first run; afterwards it is read from the local cache
    url='https://dl.fbaipublicfiles.com/detr/detr_demo-da2a99e9.pth',
    map_location='cpu', check_hash=True)
detr.load_state_dict(state_dict)
# eval() disables Dropout and stops BatchNormalization from updating its statistics;
# PyTorch then uses the running mean/variance stored during training instead of the
# batch statistics. Otherwise, if the test batch_size is small, the BN layers can
# easily distort the results.
detr.eval()
# train() re-enables BatchNormalization updates and Dropout
# before running forward inference with a trained .pth model, always call model.eval() first
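# Optional debug sketch: eval() simply flips the module's training flag, which
# BatchNorm and Dropout consult at runtime.
print(detr.training)               # False after detr.eval()
print(detr.backbone.bn1.training)  # False: BN now uses its stored running statistics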
# COCO classes: 80 real entity categories plus 11 'N/A' placeholders
CLASSES = [
'N/A', 'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus',
'train', 'truck', 'boat', 'traffic light', 'fire hydrant', 'N/A',
'stop sign', 'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse',
'sheep', 'cow', 'elephant', 'bear', 'zebra', 'giraffe', 'N/A', 'backpack',
'umbrella', 'N/A', 'N/A', 'handbag', 'tie', 'suitcase', 'frisbee', 'skis',
'snowboard', 'sports ball', 'kite', 'baseball bat', 'baseball glove',
'skateboard', 'surfboard', 'tennis racket', 'bottle', 'N/A', 'wine glass',
'cup', 'fork', 'knife', 'spoon', 'bowl', 'banana', 'apple', 'sandwich',
'orange', 'broccoli', 'carrot', 'hot dog', 'pizza', 'donut', 'cake',
'chair', 'couch', 'potted plant', 'bed', 'N/A', 'dining table', 'N/A',
'N/A', 'toilet', 'N/A', 'tv', 'laptop', 'mouse', 'remote', 'keyboard',
'cell phone', 'microwave', 'oven', 'toaster', 'sink', 'refrigerator', 'N/A',
'book', 'clock', 'vase', 'scissors', 'teddy bear', 'hair drier',
'toothbrush'
]
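# quick sanity check: 80 real categories + 11 'N/A' placeholders = 91 entries, matching num_classes=91
assert len(CLASSES) == 91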
# colors for visualization
COLORS = [[0.000, 0.447, 0.741], [0.850, 0.325, 0.098], [0.929, 0.694, 0.125],
[0.494, 0.184, 0.556], [0.466, 0.674, 0.188], [0.301, 0.745, 0.933]]
# standard PyTorch mean-std input image normalization
transform = T.Compose([
    T.Resize(800),  # resize the shorter side to 800 pixels, preserving aspect ratio (not a fixed 800*800)
    T.ToTensor(),  # convert the PIL image to a tensor
    T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])  # normalize with the ImageNet mean and std
])
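# Optional debug sketch: Resize(800) scales the shorter side to 800 while preserving the
# aspect ratio, so the result is not a fixed 800*800 image. For example (once an image
# 'im' is loaded below):
# print(transform(im).shape)  # e.g. torch.Size([3, 800, 1066]) for a 640x480 input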
# for output bounding box post-processing
def box_cxcywh_to_xyxy(x):
    x_c, y_c, w, h = x.unbind(1)  # unbind removes dim 1 and returns a tuple of slices along that dimension
    b = [(x_c - 0.5 * w), (y_c - 0.5 * h),
         (x_c + 0.5 * w), (y_c + 0.5 * h)]
    return torch.stack(b, dim=1)
def rescale_bboxes(out_bbox, size):
    img_w, img_h = size
    b = box_cxcywh_to_xyxy(out_bbox)
    b = b * torch.tensor([img_w, img_h, img_w, img_h], dtype=torch.float32)
    return b
def detect(im, model, transform):
    # mean-std normalize the input image (batch-size: 1)
    img = transform(im).unsqueeze(0)
    # by default the demo model only supports images with an aspect ratio between 0.5 and 2;
    # if you want to use images with an aspect ratio outside this range,
    # rescale your image so that the maximum size is at most 1333 for best results
    assert img.shape[-2] <= 1600 and img.shape[-1] <= 1600, 'demo model only supports images up to 1600 pixels on each side'
    # propagate through the model
    outputs = model(img)
    # keep only predictions with 0.7+ confidence
    probas = outputs['pred_logits'].softmax(-1)[0, :, :-1]
    keep = probas.max(-1).values > 0.7
    # convert boxes from [0; 1] to image scales
    bboxes_scaled = rescale_bboxes(outputs['pred_boxes'][0, keep], im.size)
    return probas[keep], bboxes_scaled
# url = 'http://images.cocodataset.org/val2017/000000039769.jpg'
# im = Image.open(requests.get(url, stream=True).raw)
test_img = "./test.jpg"
im = Image.open(test_img)
scores, boxes = detect(im, detr, transform)
# scores has shape [num_detections, 91]; boxes has shape [num_detections, 4]
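# Optional debug sketch: print each kept detection before plotting.
for p, box in zip(scores, boxes):
    cl = p.argmax()
    print(f'{CLASSES[cl]}: {p[cl]:0.2f} at {[round(v, 1) for v in box.tolist()]}')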
def plot_results(pil_img, prob, boxes):
    plt.figure(figsize=(16, 10))
    plt.imshow(pil_img)
    ax = plt.gca()
    for p, (xmin, ymin, xmax, ymax), c in zip(prob, boxes.tolist(), COLORS * 100):
        ax.add_patch(plt.Rectangle((xmin, ymin), xmax - xmin, ymax - ymin,
                                   fill=False, color=c, linewidth=3))
        cl = p.argmax()
        text = f'{CLASSES[cl]}: {p[cl]:0.2f}'
        ax.text(xmin, ymin, text, fontsize=15,
                bbox=dict(facecolor='yellow', alpha=0.5))
    plt.axis('on')
    plt.savefig("wsg.jpg")
    plt.show()

plot_results(im, scores, boxes)
- test.jpg (the input image)
- wsg.jpg (the detection result; the cell phone target is detected incorrectly)