Key point detection data preparation and model design based on a U-net network: a detection model for four key points of industrial components
2022-06-13 00:55:00 【Your name is Yu yuezheng】
Key point detection: data preparation and U-net based network model design
Full project code
The code is open source on GitHub. Only one demonstration image has been uploaded; you can adapt and improve the project to fit your own use case.
Project address: https://github.com/ExileSaber/Industry-Keypoint-Detection
The model's results are shown below.
The green points are the annotated key points; the red points are the key points predicted by the model.
Brief introduction
Key point detection on self-annotated industrial images. Each image is annotated with 4 key points, and the model is a U-net network.
Main content
- This is the author's first attempt at a detection task, done mainly to practice and to understand the network
- The network is a U-net
- Labels are built with the Coordinate method; the loss function is simply the sum of the squared distances between the ground-truth coordinate points and the predicted coordinate points
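The loss described in the last bullet can be sketched as follows. This is a minimal NumPy illustration, not the repository's training code; the function name `squared_distance_loss` is hypothetical, and it assumes coordinates are stored flat as [x1, y1, x2, y2, ...]. The same arithmetic can be written with torch tensors in a training loop.

```python
import numpy as np


def squared_distance_loss(pred, target):
    """Sum of squared Euclidean distances between predicted and true points.

    Both inputs are flat sequences [x1, y1, x2, y2, ...] (assumed layout).
    """
    pred = np.asarray(pred, dtype=np.float64).reshape(-1, 2)
    target = np.asarray(target, dtype=np.float64).reshape(-1, 2)
    # Squared distance for each key point, then summed over all points
    sq_dist = ((pred - target) ** 2).sum(axis=1)
    return float(sq_dist.sum())


# One point is off by (3, 4), the other is exact: 3^2 + 4^2 = 25
loss = squared_distance_loss([0, 0, 1, 1], [3, 4, 1, 1])
```

Because each squared distance is itself a sum of squared coordinate errors, this equals the plain sum of squared errors over all 8 output values.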
Further exploration
- Try building labels with Heatmap and with Heatmap + Offsets
- Improve the network structure
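The article does not show how the Heatmap labels would be built. One common construction, sketched here purely as an assumption, places a 2-D Gaussian at each key point and recovers the point from the location of the maximum. The function names and the `sigma` value are hypothetical, and the Offsets variant is not covered.

```python
import numpy as np


def gaussian_heatmap(height, width, cx, cy, sigma=3.0):
    """Heatmap of shape (height, width) with a Gaussian peak at (cx, cy)."""
    ys, xs = np.mgrid[0:height, 0:width]
    return np.exp(-((xs - cx) ** 2 + (ys - cy) ** 2) / (2.0 * sigma ** 2))


def heatmap_to_point(heatmap):
    """Recover (x, y) as the location of the heatmap maximum."""
    row, col = np.unravel_index(np.argmax(heatmap), heatmap.shape)
    return int(col), int(row)


# One heatmap per key point, at the 352 x 512 input size used in the code
hm = gaussian_heatmap(352, 512, cx=100, cy=50)
point = heatmap_to_point(hm)
```

With this kind of label the network would predict one map per key point instead of 8 numbers, so the fully connected layers would no longer be needed.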
Data preparation
First, annotate the positions of the key points you want to detect on each image. The author used labelme as the annotation tool; once an image has been annotated, you get a json file.
Save each annotated image and its corresponding json file in the corresponding folders. In this project the annotated data is split into a training set and a test set, saved under the following paths:
Only one demonstration image is provided under these paths. If your images are not in jpg format, or the key name under which the key point coordinates are stored in your json files differs from the author's, such data preprocessing problems can be solved by modifying data_pre.py.
Code description
The function of each Python file is briefly explained below. See GitHub for details: https://github.com/ExileSaber/Industry-Keypoint-Detection
config.py
Network model parameters, training path, test path, and other parameters
import torch

config = {
    # Network training
    'device': torch.device("cuda" if torch.cuda.is_available() else "cpu"),
    'batch_size': 1,
    'epochs': 1000,
    'save_epoch': 100,
    # Network evaluation
    'test_batch_size': 1,
    'test_threshold': 0.5,
    # Paths
    'train_date': '07_23_2',
    'train_way': 'train',
    'test_date': '07_23_2',
    'test_way': 'test',
}
data_pre.py
Reads the json file corresponding to an image and extracts the N annotated coordinate points as an ndarray. As written, the code stores the N (x, y) pairs flattened into a 1-D array of length 2N (x1, y1, ..., xN, yN).
import os
import json
import numpy as np
import matplotlib.pyplot as plt
from config import config as cfg
import cv2


# Convert labelme JSON annotations to a NumPy array
def json_to_numpy(dataset_path):
    # Image and label folders
    imgs_path = os.path.join(dataset_path, 'imgs')
    labels_path = os.path.join(dataset_path, 'labels')
    # Process every image in the folder
    for name in os.listdir(imgs_path):
        # Read the label file that shares the image's base name
        with open(os.path.join(labels_path, name.split('.')[0] + '.json'),
                  'r', encoding='utf8') as fp:
            json_data = json.load(fp)
            points = json_data['shapes']
        # Flatten the annotated points into [x1, y1, x2, y2, ...]
        landmarks = []
        for point in points:
            for p in point['points'][0]:
                landmarks.append(p)
        landmarks = np.array(landmarks)
    # Note: only the landmarks of the last file processed are returned
    return landmarks
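Since json_to_numpy walks the whole folder but returns a single array, a dataset with more than one image needs labels looked up per image. The following is a minimal sketch of such a variant, not code from the repository: the helper names `landmarks_from_labelme` and `load_landmarks_for_image` are hypothetical, and the `imgs`/`labels` folder layout is taken from the code above.

```python
import json
import os


def landmarks_from_labelme(json_data):
    """Flatten labelme point shapes into [x1, y1, x2, y2, ...]."""
    landmarks = []
    for shape in json_data['shapes']:
        x, y = shape['points'][0]  # a point shape stores a single [x, y] pair
        landmarks.extend([float(x), float(y)])
    return landmarks


def load_landmarks_for_image(dataset_path, img_name):
    """Read the label file that shares its base name with one image."""
    stem = os.path.splitext(img_name)[0]
    label_path = os.path.join(dataset_path, 'labels', stem + '.json')
    with open(label_path, 'r', encoding='utf8') as fp:
        return landmarks_from_labelme(json.load(fp))


# In-memory example with two annotated points (label names are made up)
example = {'shapes': [{'label': 'p1', 'points': [[10.5, 20.0]]},
                      {'label': 'p2', 'points': [[30, 40]]}]}
flat = landmarks_from_labelme(example)
```

A Dataset could then call load_landmarks_for_image with the file name at the requested index, so that each image is paired with its own label.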
determine_rotation_angle.py
Calculates the rotation angle of the object (a project-specific part: the rotation angle of the object relative to a given direction is computed from the detected key points)
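The article does not list this file's code. As an illustration only, the angle of the line through two detected key points relative to the image x-axis can be computed with atan2. The function name `rotation_angle_deg` and the choice of which two key points define the direction are assumptions, not the author's method.

```python
import math


def rotation_angle_deg(p_start, p_end):
    """Angle in degrees of the vector p_start -> p_end relative to the x-axis.

    Points are (x, y) in image coordinates, where y grows downward, so a
    positive angle appears as a clockwise rotation on screen.
    """
    dx = p_end[0] - p_start[0]
    dy = p_end[1] - p_start[1]
    return math.degrees(math.atan2(dy, dx))


# Two hypothetical detected key points lying on a diagonal
angle = rotation_angle_deg((100.0, 100.0), (200.0, 200.0))
```

Using atan2 rather than atan(dy / dx) keeps the result well defined for vertical lines and distinguishes all four quadrants.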
models.py
Builds a U-net network model for detecting 4 key points. The last two layers of the network are fully connected layers, which turn the convolutional feature map into the coordinates of the 4 key points (8 output values).
import torch
from torch import nn
from torchsummaryX import summary
from config import config as cfg
from net_util import *


# U-net downsampling module: double convolution
class DoubleConv(nn.Module):
    def __init__(self, in_channels, out_channels, channel_reduce=False):
        super(DoubleConv, self).__init__()
        # Channel coefficient: when reducing, the first conv outputs 2 * out_channels
        coefficient = 2 if channel_reduce else 1
        self.down = nn.Sequential(
            nn.Conv2d(in_channels, coefficient * out_channels, kernel_size=(3, 3), padding=1),
            nn.BatchNorm2d(coefficient * out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(coefficient * out_channels, out_channels, kernel_size=(3, 3), padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        return self.down(x)


# Upsampling (transposed convolution plus skip connection)
class Up(nn.Module):
    # in_channels is the channel count fed into the double convolution
    # (after concatenation); out_channels is the channel count it produces
    def __init__(self, in_channels, out_channels):
        super().__init__()
        # First upsample the feature map (doubles height and width)
        self.up = nn.ConvTranspose2d(in_channels // 2, in_channels // 2, kernel_size=4, stride=2, padding=1)
        self.conv = DoubleConv(in_channels, out_channels, channel_reduce=True)

    def forward(self, x1, x2):
        x1 = self.up(x1)
        # Concatenate with the encoder feature map along the channel axis
        x = torch.cat([x1, x2], dim=1)
        x = self.conv(x)
        return x


# Simple U-net model
class U_net(nn.Module):
    def __init__(self):
        super(U_net, self).__init__()
        # Downsampling
        self.double_conv1 = DoubleConv(3, 32)
        self.double_conv2 = DoubleConv(32, 64)
        self.double_conv3 = DoubleConv(64, 128)
        self.double_conv4 = DoubleConv(128, 256)
        self.double_conv5 = DoubleConv(256, 256)
        # Upsampling
        self.up1 = Up(512, 128)
        self.up2 = Up(256, 64)
        self.up3 = Up(128, 32)
        self.up4 = Up(64, 16)
        # Final layers; 180224 = 1 * 352 * 512, so the input must be 352 x 512
        self.conv = nn.Conv2d(16, 1, kernel_size=(1, 1), padding=0)
        self.fc1 = nn.Linear(180224, 1024)
        self.fc2 = nn.Linear(1024, 8)

    def forward(self, x):
        # Down (shapes shown for a 352 x 512 input)
        c1 = self.double_conv1(x)    # (N, 32, 352, 512)
        p1 = nn.MaxPool2d(2)(c1)     # (N, 32, 176, 256)
        c2 = self.double_conv2(p1)   # (N, 64, 176, 256)
        p2 = nn.MaxPool2d(2)(c2)     # (N, 64, 88, 128)
        c3 = self.double_conv3(p2)   # (N, 128, 88, 128)
        p3 = nn.MaxPool2d(2)(c3)     # (N, 128, 44, 64)
        c4 = self.double_conv4(p3)   # (N, 256, 44, 64)
        p4 = nn.MaxPool2d(2)(c4)     # (N, 256, 22, 32)
        c5 = self.double_conv5(p4)   # (N, 256, 22, 32)
        # The last convolution is not followed by pooling
        # Up
        u1 = self.up1(c5, c4)        # (N, 128, 44, 64)
        u2 = self.up2(u1, c3)        # (N, 64, 88, 128)
        u3 = self.up3(u2, c2)        # (N, 32, 176, 256)
        u4 = self.up4(u3, c1)        # (N, 16, 352, 512)
        # Final layers: map to a single-channel feature map, then flatten
        x1 = self.conv(u4)           # (N, 1, 352, 512)
        x1 = x1.view(x1.size(0), -1) # (N, 180224)
        x = self.fc1(x1)
        out = self.fc2(x)            # (N, 8): x and y for each of 4 key points
        return out

    def summary(self, net):
        x = torch.rand(cfg['batch_size'], 3, 352, 512)  # height 352, width 512
        # Move the input to the configured device
        x = x.to(cfg['device'])
        # Show the network structure (calls torchsummaryX.summary)
        summary(net, x)
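The input size of fc1 is tied to the image size: four 2x poolings followed by four 2x transposed convolutions return to the input resolution, and the final 1x1 convolution leaves one channel, so the flattened length is height x width. The sketch below reproduces that arithmetic in plain Python; the function name `flattened_features` is hypothetical and the helper is not part of the repository.

```python
def flattened_features(height, width, out_channels=1, num_pools=4):
    """Length of the vector fed to the first fully connected layer."""
    factor = 2 ** num_pools
    if height % factor or width % factor:
        # Otherwise the upsampled maps would not match the skip connections
        raise ValueError(f"height and width must be divisible by {factor}")
    h, w = height, width
    for _ in range(num_pools):
        h, w = h // 2, w // 2  # each MaxPool2d(2) halves height and width
    for _ in range(num_pools):
        h, w = h * 2, w * 2    # each ConvTranspose2d(k=4, s=2, p=1) doubles them
    return out_channels * h * w


fc1_inputs = flattened_features(352, 512)
```

For the 352 x 512 input used in summary() this gives 180224, the value hard-coded in nn.Linear(180224, 1024). Any other input size would require changing that number.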
net_util.py
Reads the image data and its corresponding labels
import torch
import os
import numpy as np
from torch import nn
import torchvision
from config import config as cfg
import torch.utils.data
from torchvision import datasets, transforms, models
import cv2
from data_pre import json_to_numpy


# box_3D dataset
class Dataset(torch.utils.data.Dataset):
    # Initialization
    def __init__(self, dataset_path):
        self.dataset_path = dataset_path
        self.img_name_list = os.listdir(os.path.join(dataset_path, 'imgs'))

    # Return the image and label at the given index
    def __getitem__(self, index):
        # Process the image first; cv2.resize takes (width, height)
        img = cv2.imread(os.path.join(self.dataset_path, 'imgs', self.img_name_list[index]))
        img = cv2.resize(img, (512, 352))
        img = transforms.ToTensor()(img)
        # Read the labels
        # Note: json_to_numpy does not take the index, so it returns the same
        # landmarks for every item; the labels are also not rescaled to 512 x 352
        mask = json_to_numpy(self.dataset_path)
        # mask = np.load(os.path.join(self.dataset_path, 'masks', self.img_name_list[index].split('.')[0] + '.npy'))
        mask = torch.tensor(mask, dtype=torch.float32)
        return img, mask

    # Size of the dataset
    def __len__(self):
        return len(self.img_name_list)
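The Dataset resizes every image to 512 x 352, while the coordinates in the json files refer to the original image. If the source images are not already 512 x 352 (the article does not say), the labels need the same scaling. A minimal sketch, with the hypothetical helper name `scale_keypoints` and the flat [x1, y1, ...] layout assumed:

```python
def scale_keypoints(flat_points, original_size, target_size=(512, 352)):
    """Scale flat [x1, y1, x2, y2, ...] coordinates to a resized image.

    Sizes are (width, height), matching the argument order of cv2.resize.
    """
    sx = target_size[0] / original_size[0]
    sy = target_size[1] / original_size[1]
    # Even positions hold x values, odd positions hold y values
    return [v * (sx if i % 2 == 0 else sy) for i, v in enumerate(flat_points)]


# A 1024 x 704 image shrunk to 512 x 352 halves every coordinate
scaled = scale_keypoints([100, 50, 1024, 704], original_size=(1024, 704))
```

The original size can be read from the image before resizing (img.shape[1] for width and img.shape[0] for height with OpenCV).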
test_main.py
Tests the model on the test set (json files with annotated key points)
train_main.py
Trains the model on the training set
Model effect
The green points are the annotated key points; the red points are the key points predicted by the model.