当前位置:网站首页>yolov1
yolov1
2022-07-26 21:23:00 【通宵睡一宿】
论文: https://arxiv.org/abs/1506.02640
target 的格式: 7 × 7 × 30,每个网格的前 20 个值是类别,然后是
[box1_confidence, x, y, w, h, box2_confidence, x, y, w, h]
记得先把输入图片 resize 到 448 × 448。
backbone
import torch
import torch.nn as nn
# YOLOv1 backbone description, consumed entry by entry by Yolo_v1._create_cnn:
#   (kernel_size, out_channels, stride, padding) -> one CNN_BLOCK (Conv2d + BatchNorm2d + LeakyReLU)
#   "M"                                          -> nn.MaxPool2d(kernel_size=2, stride=2)
#   [conv_a, conv_b, repeats]                    -> conv_a followed by conv_b, stacked `repeats` times
# The stride-2 7x7 conv, the four max-pools and the stride-2 3x3 conv downsample
# the input 64x overall (448x448 -> 7x7).
architecture_config = [
    (7, 64, 2, 3),
    "M",
    (3, 192, 1, 1),
    "M",
    (1, 128, 1, 0),
    (3, 256, 1, 1),
    (1, 256, 1, 0),
    (3, 512, 1, 1),
    "M",
    [(1, 256, 1, 0), (3, 512, 1, 1), 4],
    (1, 512, 1, 0),
    (3, 1024, 1, 1),
    "M",
    [(1, 512, 1, 0), (3, 1024, 1, 1), 2],
    (3, 1024, 1, 1),
    (3, 1024, 2, 1),  # the only stride-2 conv after the stem: halves the map to 7x7
    (3, 1024, 1, 1),
    (3, 1024, 1, 1),
]
class CNN_BLOCK(nn.Module):
    """Convolution -> batch norm -> LeakyReLU(0.1) building block.

    Any extra keyword arguments (kernel_size, stride, padding, ...) are passed
    straight to ``nn.Conv2d``. The convolution has no bias because the
    following BatchNorm2d supplies its own shift.
    """

    def __init__(self, in_channels, out_channels, **kwargs):
        super().__init__()
        # Registration order (cnn, bn, relu) fixes the state_dict key order.
        self.cnn = nn.Conv2d(in_channels, out_channels, bias=False, **kwargs)
        self.bn = nn.BatchNorm2d(out_channels)
        self.relu = nn.LeakyReLU(0.1)

    def forward(self, x):
        features = self.cnn(x)
        normalized = self.bn(features)
        return self.relu(normalized)
class Yolo_v1(nn.Module):
    """YOLOv1 detector: a convolutional backbone built from ``architecture_config``
    followed by a two-layer fully connected head.

    ``forward`` returns a flat tensor of shape (N, S * S * (C + B * 5)), i.e.
    (N, 1470) for the defaults; callers reshape it to (N, S, S, C + B * 5) to
    index individual grid cells. The head assumes the backbone produces
    (N, 1024, S, S), which with the default configuration means an input of
    (N, in_channels, 448, 448) when S=7.

    Args:
        S: grid size.
        B: number of boxes predicted per cell.
        C: number of classes.
        in_channels: channels of the input image (default 3, RGB).
    """

    def __init__(self, S=7, B=2, C=20, in_channels=3):
        super().__init__()
        self.S = S
        self.B = B
        self.C = C
        self.in_channels = in_channels
        self.cnn = self._create_cnn()
        self.fc = self._create_fc()

    @staticmethod
    def _conv_block(in_channels, spec):
        """Create a CNN_BLOCK from a (kernel_size, out_channels, stride, padding) tuple."""
        kernel_size, out_channels, stride, padding = spec
        return CNN_BLOCK(in_channels=in_channels, out_channels=out_channels,
                         kernel_size=kernel_size, stride=stride, padding=padding)

    def _create_cnn(self):
        """Build the backbone as an ``nn.Sequential`` from ``architecture_config``.

        Entry encoding: a tuple is one conv block, the string "M" is a 2x2
        max-pool, and a list ``[conv_a, conv_b, ..., repeats]`` stacks the
        listed convs ``repeats`` times. Layer order matches the original
        implementation, so existing state_dict checkpoints still load.
        """
        layers = []
        channels = self.in_channels
        for layer in architecture_config:
            if isinstance(layer, str):
                layers.append(nn.MaxPool2d(kernel_size=2, stride=2))
            elif isinstance(layer, tuple):
                layers.append(self._conv_block(channels, layer))
                channels = layer[1]
            else:
                *convs, repeats = layer
                for _ in range(repeats):
                    for conv in convs:
                        layers.append(self._conv_block(channels, conv))
                        channels = conv[1]
        return nn.Sequential(*layers)

    def _create_fc(self):
        """Fully connected head mapping 1024*S*S backbone features to S*S*(C + 5B) outputs."""
        return nn.Sequential(
            nn.Flatten(),
            nn.Linear(1024 * self.S * self.S, 4096),
            nn.LeakyReLU(0.1),
            nn.Linear(4096, self.S * self.S * (self.C + self.B * 5)),
        )

    def forward(self, x):
        """Return the flat prediction tensor of shape (N, S * S * (C + B * 5))."""
        x = self.cnn(x)
        x = self.fc(x)
        return x
输入 batch × 3 × 448 × 448,输出 batch × 1470(即 7 × 7 × 30 展平;loss 里会先 reshape 成 batch × 7 × 7 × 30)。
loss 按照论文实现:
import torch
import torch.nn as nn
from utils import intersection_over_union
class Yolo_v1_loss(nn.Module):
    """YOLOv1 sum-squared-error loss with the paper's lambda weights.

    Per grid cell, both the (reshaped) prediction and the target are laid out as
    ``[20 class scores, conf1, x1, y1, w1, h1, conf2, x2, y2, w2, h2]``; the slice
    offsets below are fixed to this layout (C=20, B=2). Only the target's first
    box slot (indices 20-24) and its class scores are read.

    Args:
        S: grid size.
        B: boxes per cell.
        C: number of classes.
    """

    def __init__(self, S=7, B=2, C=20):
        super().__init__()
        self.mse = nn.MSELoss(reduction='sum')
        self.S = S
        self.B = B
        self.C = C
        self.coord = 5    # lambda_coord: weight of the box-coordinate term
        self.noobj = 0.5  # lambda_noobj: weight of the empty-cell confidence term

    def forward(self, predict, target):
        """Return the scalar loss, summed over all cells and the whole batch.

        Args:
            predict: network output reshapeable to (N, S, S, B*5 + C), e.g. (N, S*S*(B*5 + C)).
            target: label tensor of shape (N, S, S, B*5 + C).
        """
        predict = predict.reshape(-1, self.S, self.S, self.B * 5 + self.C)
        true_box = target[..., 21:25]
        true_conf = target[..., 20:21]
        # Per-cell object indicator, read from the target's first confidence slot.
        obj = true_conf
        no_obj = 1 - obj

        # Responsible predictor per cell: index 0 (first box) or 1 (second box),
        # whichever has the larger IoU with the ground-truth box.
        ious = torch.stack([
            intersection_over_union(predict[..., 21:25], true_box),
            intersection_over_union(predict[..., 26:30], true_box),
        ])
        best = torch.max(ious, dim=0).indices
        first = 1 - best

        # ---- coordinate loss: object cells, responsible box only ----
        box_pred = obj * (best * predict[..., 26:30] + first * predict[..., 21:25])
        box_true = obj * true_box
        # The paper regresses sqrt(w) and sqrt(h). sign/abs keep the predicted side
        # defined when the network outputs a negative width or height.
        wh_pred = box_pred[..., 2:4]
        box_pred = torch.cat(
            [box_pred[..., :2], torch.sign(wh_pred) * torch.sqrt(torch.abs(wh_pred + 1e-6))],
            dim=-1,
        )
        box_true = torch.cat([box_true[..., :2], torch.sqrt(box_true[..., 2:4])], dim=-1)
        box_loss = self.mse(
            torch.flatten(box_pred, end_dim=-2),
            torch.flatten(box_true, end_dim=-2),
        )

        # ---- confidence loss in object cells (responsible box only) ----
        conf_pred = best * predict[..., 25:26] + first * predict[..., 20:21]
        object_loss = self.mse(
            torch.flatten(obj * conf_pred),
            torch.flatten(obj * true_conf),
        )

        # ---- confidence loss in empty cells, for both predictors ----
        # Empty cells vastly outnumber object cells, hence the lambda_noobj down-weighting below.
        noobj_loss = self.mse(
            torch.flatten(no_obj * predict[..., 20:21]),
            torch.flatten(no_obj * true_conf),
        )
        noobj_loss = noobj_loss + self.mse(
            torch.flatten(no_obj * predict[..., 25:26]),
            torch.flatten(no_obj * true_conf),
        )

        # ---- classification loss in object cells ----
        class_loss = self.mse(
            torch.flatten(obj * predict[..., :20], end_dim=-2),
            torch.flatten(obj * target[..., :20], end_dim=-2),
        )

        return self.coord * box_loss + object_loss + self.noobj * noobj_loss + class_loss
utils
import xml.etree.ElementTree as ET
import os
import os.path
import numpy as np
import torch
import matplotlib.pyplot as plt # 导入绘图包
import cv2 as cv
# Pascal VOC class names; a name's position in this list is its integer class id.
class_list = ['aeroplane', 'bicycle', 'bird', 'boat', 'bottle', 'bus', 'car',
              'cat', 'chair', 'cow', 'diningtable', 'dog', 'horse', 'motorbike',
              'person', 'pottedplant', 'sheep', 'sofa', 'train', 'tvmonitor']
# Reverse lookup (class name -> class id), derived from class_list so the two
# can never drift apart.
class_dict = {name: i for i, name in enumerate(class_list)}
def parse_xml(xml_path='../../VOCtrainval_06-Nov-2007/VOCdevkit/VOC2007/Annotations/',
              out_dir='labels', class_to_index=None):
    """Convert Pascal VOC XML annotations into YOLO-style text label files.

    For every ``*.xml`` file in ``xml_path`` a ``<stem>.txt`` file is written to
    ``out_dir`` with one line per object: ``class x_center y_center w h``, where
    the coordinates are divided by the image width/height from ``<size>``.

    Args:
        xml_path: directory containing the VOC annotation files.
        out_dir: output directory; created if it does not exist.
        class_to_index: mapping from class name to integer id. Defaults to the
            module-level ``class_dict``.

    Raises:
        KeyError: if an object's class name is missing from ``class_to_index``.
    """
    if class_to_index is None:
        class_to_index = class_dict
    os.makedirs(out_dir, exist_ok=True)
    for file in os.listdir(xml_path):
        if not file.endswith('.xml'):
            continue  # skip stray non-annotation files instead of failing to parse them
        # Parse first so a malformed annotation does not leave an empty label file behind.
        root = ET.parse(os.path.join(xml_path, file)).getroot()
        width = float(root.find('size/width').text)
        height = float(root.find('size/height').text)
        lines = []
        for obj in root.findall('object'):
            c = class_to_index[obj.find('name').text]
            bndbox = obj.find('bndbox')
            xmin = float(bndbox.find('xmin').text)
            ymin = float(bndbox.find('ymin').text)
            xmax = float(bndbox.find('xmax').text)
            ymax = float(bndbox.find('ymax').text)
            x_center = (xmin + xmax) / (2 * width)
            y_center = (ymin + ymax) / (2 * height)
            w = (xmax - xmin) / width
            h = (ymax - ymin) / height
            lines.append(' '.join([str(c), str(x_center), str(y_center), str(w), str(h)]) + '\n')
        with open(os.path.join(out_dir, file[:-len('.xml')] + '.txt'), 'w') as f:
            f.writelines(lines)
def intersection_over_union(box1, box2, mode='center'):
    """Element-wise IoU between two tensors of boxes.

    Args:
        box1, box2: tensors whose last dimension holds 4 box coordinates; the
            leading dimensions must broadcast against each other.
        mode: 'center' -> boxes are (x_center, y_center, w, h);
              any other value -> boxes are corners (xmin, ymin, xmax, ymax).

    Returns:
        IoU tensor with the broadcast leading shape and a trailing dimension of
        size 1. Areas use abs() so a box with negative width/height still has a
        non-negative area; 1e-6 is added to the union to avoid division by zero.
    """
    if mode == 'center':
        # (x_center, y_center, w, h) -> (xmin, ymin, xmax, ymax)
        box1_x1 = box1[..., 0:1] - box1[..., 2:3] / 2
        box1_y1 = box1[..., 1:2] - box1[..., 3:4] / 2
        box1_x2 = box1[..., 0:1] + box1[..., 2:3] / 2
        box1_y2 = box1[..., 1:2] + box1[..., 3:4] / 2
        box2_x1 = box2[..., 0:1] - box2[..., 2:3] / 2
        box2_y1 = box2[..., 1:2] - box2[..., 3:4] / 2
        box2_x2 = box2[..., 0:1] + box2[..., 2:3] / 2
        box2_y2 = box2[..., 1:2] + box2[..., 3:4] / 2
    else:
        # Already (xmin, ymin, xmax, ymax): the max corner lives at indices 2 and 3.
        box1_x1 = box1[..., 0:1]
        box1_y1 = box1[..., 1:2]
        box1_x2 = box1[..., 2:3]
        box1_y2 = box1[..., 3:4]
        box2_x1 = box2[..., 0:1]
        box2_y1 = box2[..., 1:2]
        box2_x2 = box2[..., 2:3]
        box2_y2 = box2[..., 3:4]
    # Intersection rectangle; clamp(0) makes disjoint boxes contribute zero area.
    x1 = torch.max(box1_x1, box2_x1)
    y1 = torch.max(box1_y1, box2_y1)
    x2 = torch.min(box1_x2, box2_x2)
    y2 = torch.min(box1_y2, box2_y2)
    intersection_area = (x2 - x1).clamp(0) * (y2 - y1).clamp(0)
    # Union = area1 + area2 - intersection.
    box1_area = abs((box1_x2 - box1_x1) * (box1_y2 - box1_y1))
    box2_area = abs((box2_x2 - box2_x1) * (box2_y2 - box2_y1))
    return intersection_area / (box1_area + box2_area - intersection_area + 1e-6)
def non_max_suppression(bboxes, iou_threshold=0.5, threshold=0.4):
    """Greedy per-class non-maximum suppression.

    Args:
        bboxes: iterable of boxes ``[class, confidence, x, y, w, h]`` (tensor rows
            or sequences of numbers) with (x, y, w, h) in center format.
        iou_threshold: a lower-confidence box of the same class is dropped when
            its IoU with an already kept box is >= this value.
        threshold: boxes whose confidence is <= this value are discarded first.

    Returns:
        List of the kept box objects, in decreasing order of confidence.
    """
    candidates = [box for box in bboxes if box[1] > threshold]
    candidates = sorted(candidates, key=lambda box: box[1], reverse=True)
    kept = []
    while candidates:
        chosen = candidates.pop(0)
        # as_tensor avoids the copy-construct warning torch.tensor() gives for tensor inputs.
        chosen_coords = torch.as_tensor(chosen[2:6], dtype=torch.float32)
        # Keep a box if it predicts a different class, or if it overlaps the chosen
        # box less than iou_threshold (so it likely belongs to another object).
        candidates = [
            box for box in candidates
            if box[0] != chosen[0]
            or intersection_over_union(
                chosen_coords, torch.as_tensor(box[2:6], dtype=torch.float32)
            ) < iou_threshold
        ]
        kept.append(chosen)
    return kept
def plot_box(boxes, img):
    """Display ``img`` with the given boxes and class labels drawn on it, then call plt.show().

    Args:
        boxes: iterable of ``[class, confidence, x, y, w, h]`` where (x, y) is the
            box centre and (w, h) its size, expressed as fractions of the image
            width/height. Entries may be tensors or plain Python numbers.
        img: image array of shape (H, W) or (H, W, channels), shown with plt.imshow.
    """
    H, W = img.shape[:2]
    plt.imshow(img)
    ax = plt.gca()
    for bbox in boxes:
        # float() accepts both 0-d tensors and plain numbers (.item() only worked for tensors).
        cls_id, conf, x, y, w, h = (float(v) for v in bbox[:6])
        confidence = round(conf, 2)
        xmin, xmax = (x - w / 2) * W, (x + w / 2) * W
        ymin, ymax = (y - h / 2) * H, (y + h / 2) * H
        ax.add_patch(
            plt.Rectangle((xmin, ymin), xmax - xmin, ymax - ymin, color='green', fill=False, linewidth=2))
        ax.text(xmin, ymin, class_list[int(cls_id)] + ': {}'.format(confidence),
                color='white', bbox={'facecolor': 'green', 'alpha': 1.0})
    plt.show()
def get_boxes(pre, S=7):
    """Decode a single YOLOv1 prediction into image-level boxes and apply NMS.

    Args:
        pre: prediction for ONE image holding S*S*30 values. Per cell the layout
            is 20 class scores, then [conf1, x, y, w, h, conf2, x, y, w, h], with
            x, y relative to the cell and w, h in cell units. Any shape that
            reshapes to (S, S, 30) is accepted, e.g. (1, S, S, 30) or the flat
            (1, S*S*30) output of the network.
        S: grid size.

    Returns:
        The result of ``non_max_suppression`` applied to the 2*S*S decoded boxes,
        each a row ``[class, confidence, x, y, w, h]`` with x, y, w, h expressed
        as fractions of the whole image.

    The input tensor is not modified; decoding happens on a copy.
    """
    pre = pre.reshape(S, S, -1).clone()
    # col[i, j] == j and row[i, j] == i: column / row index of every grid cell.
    col = torch.arange(S, dtype=pre.dtype, device=pre.device).repeat(S, 1).unsqueeze(-1)
    row = col.transpose(0, 1)
    for start in (21, 26):  # x coordinate of box 1 and box 2
        pre[..., start:start + 1] = (pre[..., start:start + 1] + col) / S
        pre[..., start + 1:start + 2] = (pre[..., start + 1:start + 2] + row) / S
        pre[..., start + 2:start + 4] = pre[..., start + 2:start + 4] / S

    n = S * S
    # Both boxes of a cell share the cell's highest-scoring class.
    classes = torch.max(pre[..., :20], dim=-1).indices.reshape(n, 1)
    new_box = torch.zeros((2 * n, 6))
    new_box[:n, 0:1] = classes
    new_box[n:, 0:1] = classes
    new_box[:n, 1:2] = pre[..., 20:21].reshape(n, 1)
    new_box[n:, 1:2] = pre[..., 25:26].reshape(n, 1)
    new_box[:n, 2:6] = pre[..., 21:25].reshape(n, 4)
    new_box[n:, 2:6] = pre[..., 26:30].reshape(n, 4)
    return non_max_suppression(new_box)
不知道要train多久,租了个服务器一直在跑
边栏推荐
- Shrimp Shope gets the product details API according to the ID
- FreeRTOS personal notes - Software Timer
- Isilon 的OneFs常见操作命令(一)
- Pytoch -- used by visdom
- JS verify complex password
- Afnetworking understand
- Isilon's onefs common operation commands (I)
- Thorough load balancing
- VB.net Chart1的处理
- Basic operation of (C language) files
猜你喜欢

If you do not add waitkey() function after imshow() function, it will not be displayed

I successfully landed the automatic testing post, with a maximum monthly salary of 15.4k. I'm great~

matlab 激励模型 三角波频谱

Try new functions | decrypt Doris complex data type array

Basic operation of (C language) files
![[mysql]substr usage - query the value of specific digits of a field in the table](/img/d5/68658ff15f204dc97abfe7c9e6b354.png)
[mysql]substr usage - query the value of specific digits of a field in the table

Understanding and practice of the trend of Bank of London foreign exchange

Isilon 的OneFs常见操作命令(一)

Development to testing: a six-year road to automation from scratch

梦里的一碗面
随机推荐
[mysql]substr usage - query the value of specific digits of a field in the table
matlab 短时自相关实现
Finding a new direction for the development of digital retail is the key to ensure that digital retail can enter a new stage of development
If you do not add waitkey() function after imshow() function, it will not be displayed
仅需一个依赖给Swagger换上新皮肤,既简单又炫酷
Pytoch uses RNN model to build person name classifier
Can you use redis? Then come and learn about redis protocol
Just one dependency to give swagger a new skin, which is simple and cool~
js验证复杂密码
Use of cmake
Afnetworking understand
Thorough load balancing
第15章 mysql用户管理
NPM, NPM Chinese documents, NPM learning and using
day07-
Matlab draws short-term average amplitude spectrum
梦里的一碗面
easyui的combobox默认选中第一个选项
深入源码剖析String类为什么不可变?(还不明白就来打我)
Method overloading and method rewriting