当前位置:网站首页>yolov1
yolov1
2022-07-26 21:23:00 【通宵睡一宿】
论文: https://arxiv.org/abs/1506.02640
target 的格式为 7 × 7 × 30:前 20 个是类别,然后是
[box1_confidence , x,y,w,h,box2_confidence,x,y,w,h]
记得先把输入图片 resize 到 448 × 448。
backbone
import torch
import torch.nn as nn
# YOLOv1 backbone specification: 24 convolution layers and 4 max-pool layers in total.
# Each entry is interpreted by Yolo_v1._create_cnn as one of:
#   (kernel_size, out_channels, stride, padding) -> a single CNN_BLOCK
#   "M"                                          -> nn.MaxPool2d(kernel_size=2, stride=2)
#   [conv_a, conv_b, repeats]                    -> conv_a then conv_b, the pair stacked `repeats` times
# The spatial-size comments below assume a 448 x 448 input (overall downsampling factor 64).
architecture_config = [
    (7, 64, 2, 3),                          # 448 -> 224
    "M",                                    # 224 -> 112
    (3, 192, 1, 1),
    "M",                                    # 112 -> 56
    (1, 128, 1, 0),
    (3, 256, 1, 1),
    (1, 256, 1, 0),
    (3, 512, 1, 1),
    "M",                                    # 56 -> 28
    [(1, 256, 1, 0), (3, 512, 1, 1), 4],
    (1, 512, 1, 0),
    (3, 1024, 1, 1),
    "M",                                    # 28 -> 14
    [(1, 512, 1, 0), (3, 1024, 1, 1), 2],
    (3, 1024, 1, 1),
    (3, 1024, 2, 1),                        # 14 -> 7
    (3, 1024, 1, 1),
    (3, 1024, 1, 1),                        # output feature map: 1024 x 7 x 7
]
class CNN_BLOCK(nn.Module):
    """Conv2d (without bias) -> BatchNorm2d -> LeakyReLU(0.1).

    Any extra keyword arguments (kernel_size, stride, padding, ...) are passed
    straight to ``nn.Conv2d``. The convolution bias is disabled because the
    BatchNorm layer that follows has its own learnable shift.
    """

    def __init__(self, in_channels, out_channels, **kwargs):
        super().__init__()
        # Attribute names (cnn / bn / relu) determine the state_dict keys; keep them stable.
        self.cnn = nn.Conv2d(in_channels, out_channels, bias=False, **kwargs)
        self.bn = nn.BatchNorm2d(out_channels)
        self.relu = nn.LeakyReLU(0.1)

    def forward(self, x):
        features = self.cnn(x)
        normalized = self.bn(features)
        return self.relu(normalized)
class Yolo_v1(nn.Module):
    """YOLOv1 network: the backbone described by ``architecture_config`` plus two linear layers.

    Args:
        S: grid size. The first linear layer expects a 1024 x S x S feature map; the
            backbone downsamples by 64, so a 448 x 448 input yields 7 x 7 (S = 7).
        B: number of boxes predicted per grid cell.
        C: number of classes.
        in_channels: number of channels of the input image (default 3).

    ``forward`` maps (N, in_channels, H, W) to a flat (N, S * S * (C + B * 5)) tensor;
    reshape it to (N, S, S, C + B * 5) to read the per-cell predictions.
    """

    def __init__(self, S=7, B=2, C=20, in_channels=3):
        super(Yolo_v1, self).__init__()
        self.S = S
        self.B = B
        self.C = C
        self.in_channels = in_channels
        self.cnn = self._create_cnn()
        self.fc = self._create_fc()

    @staticmethod
    def _conv_block(in_channels, spec):
        """Build one CNN_BLOCK from a (kernel_size, out_channels, stride, padding) spec."""
        return CNN_BLOCK(in_channels=in_channels, out_channels=spec[1], kernel_size=spec[0],
                         stride=spec[2], padding=spec[3])

    def _create_cnn(self):
        """Translate ``architecture_config`` into an ``nn.Sequential`` backbone."""
        in_channels = self.in_channels
        layers = []
        for layer in architecture_config:
            if isinstance(layer, str):
                # "M": 2x2 max-pool with stride 2 (halves height and width).
                layers.append(nn.MaxPool2d(kernel_size=2, stride=2))
            elif isinstance(layer, tuple):
                layers.append(self._conv_block(in_channels, layer))
                in_channels = layer[1]
            else:
                # [conv1, conv2, repeats]: append conv1 then conv2, `repeats` times.
                conv1, conv2, repeats = layer[0], layer[1], layer[-1]
                for _ in range(repeats):
                    layers.append(self._conv_block(in_channels, conv1))
                    layers.append(self._conv_block(conv1[1], conv2))
                    in_channels = conv2[1]
        return nn.Sequential(*layers)

    def _create_fc(self):
        """Detection head: flatten -> 4096 hidden units -> S * S * (C + B * 5) outputs."""
        return nn.Sequential(
            nn.Flatten(),
            nn.Linear(1024 * self.S * self.S, 4096),
            nn.LeakyReLU(0.1),
            nn.Linear(4096, self.S * self.S * (self.C + self.B * 5))
        )

    def forward(self, x):
        x = self.cnn(x)
        x = self.fc(x)
        return x
输入 batch × 3 × 448 × 448,输出 batch × 1470(即 7 × 7 × 30 展平后的向量,loss 中会 reshape 回 batch × 7 × 7 × 30)
loss 按照论文实现
import torch
import torch.nn as nn
from utils import intersection_over_union
class Yolo_v1_loss(nn.Module):
    """Sum-squared-error YOLOv1 loss (summed over the batch, not averaged).

    Inputs to ``forward``:
        predict: raw network output with N * S * S * (C + B * 5) elements; reshaped
            internally to (N, S, S, C + B * 5). Within a cell, the first C entries are
            class scores, followed by B groups of [confidence, x, y, w, h].
        target: (N, S, S, C + B * 5). The first C entries are class targets, index C is the
            objectness flag, and C+1:C+5 is the ground-truth box (x, y, w, h); the remaining
            entries are not read.

    For each cell, the predictor with the highest IoU against the ground-truth box is
    "responsible"; coordinates (with sqrt on w, h), confidence and class terms are only
    counted in cells whose objectness flag is set. Weights: coord = 5, noobj = 0.5.

    NOTE(review): the confidence target is the objectness flag (1), not the IoU used in the
    paper, and the no-object term only covers cells without an object (the non-responsible
    predictor in an object cell gets no confidence penalty).
    """

    def __init__(self, S=7, B=2, C=20):
        super(Yolo_v1_loss, self).__init__()
        self.mse = nn.MSELoss(reduction='sum')
        self.S = S
        self.B = B
        self.C = C
        self.coord = 5    # lambda_coord
        self.noobj = 0.5  # lambda_noobj

    def forward(self, predict, target):
        # predict: N,(S * S * (B * 5 + C))
        # target: N,S,S,(B * 5 + C)
        C = self.C
        predict = predict.reshape(-1, self.S, self.S, self.B * 5 + C)

        # Ground truth per cell: objectness flag at C, box (x, y, w, h) at C+1:C+5.
        exists_box = target[..., C:C + 1]
        target_box = target[..., C + 1:C + 5]

        # Predictor b occupies [confidence, x, y, w, h] starting at index C + 5 * b.
        starts = [C + 5 * b for b in range(self.B)]
        pred_conf = [predict[..., s:s + 1] for s in starts]
        pred_box = [predict[..., s + 1:s + 5] for s in starts]

        # Responsible predictor: the one with the highest IoU against the ground-truth box.
        ious = torch.cat(
            [intersection_over_union(box, target_box).unsqueeze(0) for box in pred_box], dim=0
        )
        _, best_box_idx = torch.max(ious, dim=0)
        responsible = [(best_box_idx == b).to(predict.dtype) for b in range(self.B)]

        # ---- coordinate loss (responsible predictor, object cells only) ----
        box_predict = exists_box * sum(r * box for r, box in zip(responsible, pred_box))
        box_target = exists_box * target_box
        # The paper regresses sqrt(w), sqrt(h). Predictions can be negative, so the sign is
        # kept; the epsilon is added outside abs() so sqrt's argument is always >= 1e-6 and
        # its gradient stays finite. torch.cat avoids in-place writes on an autograd tensor.
        pred_wh = box_predict[..., 2:4]
        pred_wh = torch.sign(pred_wh) * torch.sqrt(torch.abs(pred_wh) + 1e-6)
        box_predict = torch.cat([box_predict[..., 0:2], pred_wh], dim=-1)
        box_target = torch.cat([box_target[..., 0:2], torch.sqrt(box_target[..., 2:4])], dim=-1)
        box_loss = self.mse(
            torch.flatten(box_predict, end_dim=-2),
            torch.flatten(box_target, end_dim=-2),
        )

        # ---- object confidence loss (responsible predictor, object cells only) ----
        confidence_predict = sum(r * conf for r, conf in zip(responsible, pred_conf))
        confidence_target = target[..., C:C + 1]
        confidence_loss = self.mse(
            torch.flatten(exists_box * confidence_predict),
            torch.flatten(exists_box * confidence_target),
        )

        # ---- no-object confidence loss (every predictor, cells without an object) ----
        # Down-weighted by noobj because these cells vastly outnumber object cells.
        no_exists = 1 - exists_box
        no_confidence_loss = sum(
            self.mse(torch.flatten(no_exists * conf), torch.flatten(no_exists * confidence_target))
            for conf in pred_conf
        )

        # ---- classification loss (object cells only) ----
        class_loss = self.mse(
            torch.flatten(exists_box * predict[..., :C], end_dim=-2),
            torch.flatten(exists_box * target[..., :C], end_dim=-2),
        )

        loss = self.coord * box_loss + confidence_loss + self.noobj * no_confidence_loss + class_loss
        return loss
utils
import xml.etree.ElementTree as ET
import os
import os.path
import numpy as np
import torch
import matplotlib.pyplot as plt # 导入绘图包
import cv2 as cv
# Pascal VOC class names; a name's position is the class index written to the label files.
class_list = ['aeroplane', 'bicycle', 'bird', 'boat', 'bottle', 'bus', 'car',
              'cat', 'chair', 'cow', 'diningtable', 'dog', 'horse', 'motorbike',
              'person', 'pottedplant', 'sheep', 'sofa', 'train', 'tvmonitor']
# Reverse lookup (class name -> index), derived from class_list so the two cannot drift apart.
class_dict = {name: i for i, name in enumerate(class_list)}
def parse_xml(xml_path='../../VOCtrainval_06-Nov-2007/VOCdevkit/VOC2007/Annotations/',
              label_dir='labels'):
    """Convert Pascal VOC XML annotations into YOLO-style text label files.

    For every ``*.xml`` file in ``xml_path``, ``<label_dir>/<stem>.txt`` is written with one
    line per ``object`` element: ``class_index x_center y_center width height``. The box values
    are divided by the image width/height read from the XML ``size`` element.

    Args:
        xml_path: directory containing the VOC annotation XML files.
        label_dir: output directory; created if it does not exist.

    Raises:
        KeyError: if an object's ``name`` is not a key of ``class_dict``.
    """
    os.makedirs(label_dir, exist_ok=True)
    for file in sorted(os.listdir(xml_path)):
        stem, ext = os.path.splitext(file)
        # Skip anything that is not an annotation file (e.g. hidden/system files).
        if ext.lower() != '.xml':
            continue
        # Parse before opening the output so a malformed XML does not leave an empty label file.
        root = ET.parse(os.path.join(xml_path, file)).getroot()
        width = float(root.find('size/width').text)
        height = float(root.find('size/height').text)
        lines = []
        for child in root.findall('object'):
            c = class_dict[child.find('name').text]
            bndbox = child.find('bndbox')
            xmin = float(bndbox.find('xmin').text)
            ymin = float(bndbox.find('ymin').text)
            xmax = float(bndbox.find('xmax').text)
            ymax = float(bndbox.find('ymax').text)
            x_center = (xmin + xmax) / (2 * width)
            y_center = (ymin + ymax) / (2 * height)
            w = (xmax - xmin) / width
            h = (ymax - ymin) / height
            lines.append(' '.join([str(c), str(x_center), str(y_center), str(w), str(h)]) + '\n')
        with open(os.path.join(label_dir, stem + '.txt'), 'w', encoding='utf-8') as f:
            f.writelines(lines)
def intersection_over_union(box1, box2, mode='center'):
    """Element-wise IoU between two tensors of boxes.

    Args:
        box1, box2: tensors whose last dimension holds 4 box values; leading dimensions
            must broadcast against each other.
        mode: 'center' for (x_center, y_center, w, h); any other value means corner
            format (xmin, ymin, xmax, ymax).

    Returns:
        Tensor of IoU values with shape ``(..., 1)``. A 1e-6 term in the denominator
        avoids division by zero for degenerate boxes.
    """
    def _corners(box):
        # Return (x1, y1, x2, y2), each keeping a trailing dimension of size 1.
        if mode == 'center':
            half_w = box[..., 2:3] / 2
            half_h = box[..., 3:4] / 2
            return (box[..., 0:1] - half_w, box[..., 1:2] - half_h,
                    box[..., 0:1] + half_w, box[..., 1:2] + half_h)
        # Corner format: the max corner is stored at indices 2 and 3.
        return box[..., 0:1], box[..., 1:2], box[..., 2:3], box[..., 3:4]

    box1_x1, box1_y1, box1_x2, box1_y2 = _corners(box1)
    box2_x1, box2_y1, box2_x2, box2_y2 = _corners(box2)

    # Intersection rectangle; clamp(0) yields zero area for non-overlapping boxes.
    x1 = torch.max(box1_x1, box2_x1)
    y1 = torch.max(box1_y1, box2_y1)
    x2 = torch.min(box1_x2, box2_x2)
    y2 = torch.min(box1_y2, box2_y2)
    intersection_area = (x2 - x1).clamp(0) * (y2 - y1).clamp(0)

    # Union = area1 + area2 - intersection.
    box1_area = ((box1_x2 - box1_x1) * (box1_y2 - box1_y1)).abs()
    box2_area = ((box2_x2 - box2_x1) * (box2_y2 - box2_y1)).abs()
    return intersection_area / (box1_area + box2_area - intersection_area + 1e-6)
def non_max_suppression(bboxes, iou_threshold=0.5, threshold=0.4):
    """Greedy per-class non-maximum suppression.

    Args:
        bboxes: iterable of boxes ``[class, confidence, x, y, w, h]`` (lists or 1-D tensors),
            with coordinates in center format.
        iou_threshold: a lower-scored box of the same class is discarded when its IoU with
            a kept box is >= this value.
        threshold: boxes whose confidence is <= this value are dropped up front.

    Returns:
        List of the kept boxes (the original elements), highest confidence first.
    """
    candidates = sorted(
        (box for box in bboxes if box[1] > threshold), key=lambda box: box[1], reverse=True
    )
    kept = []
    while candidates:
        chosen_box = candidates.pop(0)
        # as_tensor avoids copying (and the copy-construct warning) when rows are already tensors.
        chosen_coords = torch.as_tensor(chosen_box[2:6])
        # Keep a box if it predicts a different class or overlaps the chosen box too little:
        # in either case the two boxes are not predicting the same object.
        candidates = [
            box for box in candidates
            if box[0] != chosen_box[0]
            or float(intersection_over_union(chosen_coords, torch.as_tensor(box[2:6]))) < iou_threshold
        ]
        kept.append(chosen_box)
    return kept
def plot_box(boxes, img):
    """Draw labelled boxes on ``img`` with matplotlib and show the figure.

    Args:
        boxes: iterable of ``[class, confidence, x, y, w, h]`` (lists or 1-D tensors), where
            x, y is the box center and x, y, w, h are fractions of the image width/height.
        img: image array whose first two dimensions are (H, W), as accepted by ``plt.imshow``.
    """
    H, W = img.shape[:2]
    plt.imshow(img)
    current_axis = plt.gca()
    for bbox in boxes:
        classes = int(bbox[0])
        # float() accepts Python numbers and one-element tensors alike (.item() is tensor-only).
        confidence = round(float(bbox[1]), 2)
        x, y, w, h = (float(v) for v in bbox[2:6])
        # Center format -> pixel corners.
        xmin = (x - w / 2) * W
        xmax = (x + w / 2) * W
        ymin = (y - h / 2) * H
        ymax = (y + h / 2) * H
        current_axis.add_patch(
            plt.Rectangle((xmin, ymin), xmax - xmin, ymax - ymin, color='green', fill=False, linewidth=2))
        current_axis.text(xmin, ymin, class_list[classes] + ': {}'.format(confidence),
                          color='white', bbox={'facecolor': 'green', 'alpha': 1.0})
    plt.show()
def get_boxes(pre, S=7, C=20, B=2):
    """Decode a single image's YOLOv1 prediction into image-relative boxes and apply NMS.

    Args:
        pre: prediction for ONE image with S * S * (C + 5 * B) elements, e.g. shape
            (1, S, S, C + 5 * B) or the flat (1, S * S * (C + 5 * B)) model output.
            The tensor is not modified.
        S: grid size.
        C: number of classes.
        B: boxes predicted per cell.

    Returns:
        The result of ``non_max_suppression`` on an (S * S * B, 6) tensor whose rows are
        ``[class, confidence, x, y, w, h]``; rows are ordered as every cell for predictor 0,
        then every cell for predictor 1, and so on.
    """
    # reshape returns a view or copy; nothing below writes in place, so the caller's tensor
    # is left untouched.
    pre = pre.reshape(S, S, C + 5 * B)
    # Cell offsets on the same device/dtype as the prediction: x shifts by the column
    # index, y by the row index.
    offsets = torch.arange(S, device=pre.device, dtype=pre.dtype)
    col = offsets.view(1, S, 1)
    row = offsets.view(S, 1, 1)
    classes = torch.max(pre[..., :C], dim=-1).indices.unsqueeze(-1).to(pre.dtype)

    per_predictor = []
    for b in range(B):
        start = C + 5 * b
        confidence = pre[..., start:start + 1]
        x = (pre[..., start + 1:start + 2] + col) / S
        y = (pre[..., start + 2:start + 3] + row) / S
        # NOTE(review): w, h are divided by S, i.e. they are assumed to be predicted
        # relative to the cell size -- confirm against the dataset encoding.
        wh = pre[..., start + 3:start + 5] / S
        per_predictor.append(torch.cat([classes, confidence, x, y, wh], dim=-1).reshape(-1, 6))
    new_box = torch.cat(per_predictor, dim=0)
    return non_max_suppression(new_box)
不知道要 train 多久,租了个服务器一直在跑。
边栏推荐
- Oppo self-developed large-scale knowledge map and its application in digital intelligence engineering
- Get network time by unity
- Vb.net chart1 processing
- : could not determine a constructor for the tag !RootAdmin
- 现货黄金操作指南与建议(上)
- Thoroughly understand the principle and implementation of service discovery
- 进程的概念和分类
- ansible安装及使用
- Method overloading and method rewriting
- 基于CAShapeLayer和贝塞尔曲线的圆形进度条动画
猜你喜欢

Matlab pitch period estimation post-processing

仅需一个依赖给Swagger换上新皮肤,既简单又炫酷~

Basic operation of (C language) files

Thorough load balancing
![[mysql]substr usage - query the value of specific digits of a field in the table](/img/d5/68658ff15f204dc97abfe7c9e6b354.png)
[mysql]substr usage - query the value of specific digits of a field in the table

Isilon's onefs common operation commands (I)

TASK04|分类分析

Kalibr calibration realsensed435i -- multi camera calibration

A new technical director asked me to do an IP territorial function~

小白学习MySQL - Derived Table
随机推荐
45. Instance segmented labelme dataset to coco dataset and coco dataset to labelme dataset
OPPO 自研大规模知识图谱及其在数智工程中的应用
Use of cmake
Ansible installation and use
Shrimp Shope gets the product details API according to the ID
也谈数据治理
C# 数据类型_摘自菜鸟教程
Leetcode exercise - Sword finger offer II 005. maximum product of word length
寻找数字零售的发展新方向,才是保证数字零售可以进入到全新发展阶段的关键
第15章 mysql用户管理
Altium designer 22 Chinese character garbled
FreeRTOS personal notes - Software Timer
【C语言基础】17 链表初探
Go----Go语言中的变量使用方法
MPLS基础知识概述
Supplement - nonlinear programming
iptables防止nmap扫描以及binlog实现增量备份
Just one dependency to give swagger a new skin, which is simple and cool
仅需一个依赖给Swagger换上新皮肤,既简单又炫酷~
正规方程法(Normal Equation)原理以及与梯度下降法的区别