Implementing YOLOX from scratch: dataset class
This series draws on the blog and code of the blogger Bubbliiiing: https://blog.csdn.net/weixin_44791964/article/details/120476949
Before reimplementing it, you need to understand how YOLOX works and how to use the PyTorch framework. These are the basic prerequisites.
Let's start the reimplementation.
1 The dataset and its split
(1) Dataset file organization
Create the directory structure shown in the figure to store the data.
Annotations stores the label files (the xml files), ImageSets stores the txt files produced by splitting the dataset, and JPEGImages stores the pictures. Each picture must have the same base name as its label file.
Copy the label files into Annotations and the pictures into JPEGImages. After copying:
Annotations looks like this:
JPEGImages looks like this:
(2) Splitting the dataset
The dataset has 10506 pictures. We split it into training, validation and test sets in a 7:1:2 ratio and store the corresponding file names in train.txt, val.txt and test.txt. To do that, create a new program called split_voc.py under yolox_from_scratch. The new directory structure is as follows:
The contents of split_voc.py are as follows:
import os
import random

trainval_percent = 0.8  # share of (training set + validation set) in the whole dataset
train_percent = 0.875  # share of the training set inside trainval; 0.875 * 0.8 = 0.7, so the training set is 70% of all samples
VOCdevkit_path = 'VOCdevkit'  # dataset root path

random.seed(0)  # fix the seed so that the split can be reproduced
print("Generate txt in ImageSets.")
xmlfilepath = os.path.join(VOCdevkit_path, 'VOC2007/Annotations')  # label file path
saveBasePath = os.path.join(VOCdevkit_path, 'VOC2007/ImageSets/Main')  # where the train/val/test txt files are written
temp_xml = os.listdir(xmlfilepath)
total_xml = []
for xml in temp_xml:
    if xml.endswith(".xml"):
        total_xml.append(xml)  # keep only the xml label files

num = len(total_xml)  # total number of samples
indices = range(num)  # indexes of the samples
tv = int(num * trainval_percent)  # number of samples in training set + validation set
tr = int(tv * train_percent)  # number of samples in the training set
trainval = random.sample(indices, tv)  # sample indexes of training set + validation set
train = random.sample(trainval, tr)  # sample indexes of the training set
# random.sample(indices, tv) returns a new list of length tv whose elements are drawn from indices without repetition;
# indices is a range object that stands for the indexes of the dataset

print("train and val size", tv)
print("train size", tr)
ftrainval = open(os.path.join(saveBasePath, 'trainval.txt'), 'w')
ftest = open(os.path.join(saveBasePath, 'test.txt'), 'w')
ftrain = open(os.path.join(saveBasePath, 'train.txt'), 'w')
fval = open(os.path.join(saveBasePath, 'val.txt'), 'w')

for i in indices:
    name = total_xml[i][:-4] + '\n'  # [:-4] drops the last 4 characters, ".xml", which we do not need here
    if i in trainval:
        ftrainval.write(name)
        if i in train:
            ftrain.write(name)
        else:
            fval.write(name)
    else:
        ftest.write(name)

ftrainval.close()
ftrain.close()
fval.close()
ftest.close()
print("Generate txt in ImageSets done.")
Running the program prints:
Generate txt in ImageSets.
train and val size 8404
train size 7353
Generate txt in ImageSets done.
VOCdevkit/VOC2007/ImageSets/Main now contains several txt files. The structure of VOCdevkit is shown in the figure below:
The four txt files contain the sample file names with the extension removed. train.txt, for example, looks like the figure below.
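The arithmetic of the split can be checked without touching any files. The sketch below repeats the two-stage sampling on plain indexes; the function name `split_indices` and the local `random.Random` instance are choices made for this illustration, not part of split_voc.py.

```python
import random

def split_indices(num, trainval_percent=0.8, train_percent=0.875, seed=0):
    """Return (train, val, test) index lists using two-stage sampling."""
    rng = random.Random(seed)            # local RNG, leaves the global state alone
    indices = range(num)
    tv = int(num * trainval_percent)     # size of train + val
    tr = int(tv * train_percent)         # size of train
    trainval = rng.sample(indices, tv)
    train = rng.sample(trainval, tr)
    train_set, trainval_set = set(train), set(trainval)
    val = [i for i in trainval if i not in train_set]
    test = [i for i in indices if i not in trainval_set]
    return train, val, test

train, val, test = split_indices(10506)
print(len(train), len(val), len(test))  # 7353 1051 2102
```

The three sizes are roughly 70%, 10% and 20% of 10506, and the first number matches the "train size" line printed above.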
(3) Extracting the target information (boxes and classes) from the xml files
The dataset is now split, but the boxes and classes of the targets are still inside the xml files. Let's extract them.
Under yolox_from_scratch, create a new folder called model_data to store the class information we need. After creating it, the project structure is as follows:
Under yolox_from_scratch, create a new program called annotations_convert.py with the following contents:
import os
import xml.etree.ElementTree as ET

VOCdevkit_sets = [('2007', 'train'), ('2007', 'val')]  # dataset splits to convert
VOCdevkit_path = 'VOCdevkit'  # dataset root path
classes = ['D00', 'D10', 'D20', 'D40']  # class names


def convert_annotation(year, image_id, list_file):
    xml_path = os.path.join(VOCdevkit_path, 'VOC%s/Annotations/%s.xml' % (year, image_id))
    with open(xml_path, encoding='utf-8') as in_file:
        tree = ET.parse(in_file)  # parse the xml file
    root = tree.getroot()  # get the root element
    for obj in root.iter('object'):
        difficult = 0
        if obj.find('difficult') is not None:
            difficult = obj.find('difficult').text  # difficulty flag (not used as a filter in this version)
        cls = obj.find('name').text  # class name of the target
        if cls not in classes:  # not every target needs to be detected
            continue
        cls_id = classes.index(cls)
        xmlbox = obj.find('bndbox')
        b = (int(float(xmlbox.find('xmin').text)), int(float(xmlbox.find('ymin').text)),
             int(float(xmlbox.find('xmax').text)), int(float(xmlbox.find('ymax').text)))
        list_file.write(" " + ",".join([str(a) for a in b]) + ',' + str(cls_id))
        # ",".join([str(a) for a in b]) builds a new string in which the list elements are separated by ","


if __name__ == '__main__':
    print("Generate 2007_train.txt and 2007_val.txt for train.")
    for year, image_set in VOCdevkit_sets:
        image_ids = open(os.path.join(VOCdevkit_path, 'VOC%s/ImageSets/Main/%s.txt' % (year, image_set)),
                         encoding='utf-8').read().strip().split()
        # os.path.join(VOCdevkit_path, 'VOC%s/ImageSets/Main/%s.txt' % (year, image_set))
        # returns VOCdevkit/VOC2007/ImageSets/Main/train.txt or VOCdevkit/VOC2007/ImageSets/Main/val.txt
        # read() reads everything at once and returns one string, whereas readlines() returns a list with one element per line
        # strip() removes the whitespace at the beginning and the end
        # split() cuts the string at the \n characters, because the string returned by read() still contains the line breaks
        list_file = open('%s_%s.txt' % (year, image_set), 'w', encoding='utf-8')  # open 2007_train.txt or 2007_val.txt
        for image_id in image_ids:
            list_file.write('%s/VOC%s/JPEGImages/%s.jpg' % (VOCdevkit_path, year, image_id))  # write the picture path
            convert_annotation(year, image_id, list_file)
            list_file.write('\n')  # one picture per line
        list_file.close()
    print("Generate 2007_train.txt and 2007_val.txt for train done.")
    with open('model_data/voc_classes.txt', 'w+') as f:
        for cls in classes:
            f.write(cls + '\n')  # assumed layout: one class name per line
The program above is more complex than it needs to be because it was copied from elsewhere and I was pressed for time, so I have not simplified it.
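To see the xml handling in isolation, here is a minimal sketch that parses an annotation held in a string instead of a file. The XML snippet is made up for the example (it only contains the tags the program reads), and `extract_boxes` is a hypothetical helper name.

```python
import xml.etree.ElementTree as ET

CLASSES = ['D00', 'D10', 'D20', 'D40']  # class names used in the article

# Made-up annotation: one target we want and one with an unknown class name.
SAMPLE_XML = """
<annotation>
  <object>
    <name>D20</name>
    <bndbox><xmin>151</xmin><ymin>427</ymin><xmax>581</xmax><ymax>600</ymax></bndbox>
  </object>
  <object>
    <name>D99</name>
    <bndbox><xmin>1</xmin><ymin>2</ymin><xmax>3</xmax><ymax>4</ymax></bndbox>
  </object>
</annotation>
"""

def extract_boxes(xml_text, classes):
    """Return a list of (xmin, ymin, xmax, ymax, class_id) tuples."""
    root = ET.fromstring(xml_text.strip())
    boxes = []
    for obj in root.iter('object'):
        name = obj.find('name').text
        if name not in classes:          # skip targets we do not detect
            continue
        bnd = obj.find('bndbox')
        coords = [int(float(bnd.find(tag).text))
                  for tag in ('xmin', 'ymin', 'xmax', 'ymax')]
        boxes.append(tuple(coords) + (classes.index(name),))
    return boxes

boxes = extract_boxes(SAMPLE_XML, CLASSES)
print(boxes)  # [(151, 427, 581, 600, 2)]
```

The second object is dropped because its class name is not in the list, which is the job of the `continue` in the loop.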
After the program runs, the directory structure becomes as follows:
There are now two more txt files under yolox_from_scratch. Opening 2007_train.txt shows the following:
This txt file puts the picture path and the label information of its targets on the same line, one picture per line. 2007_val.txt has the same layout. A picture may contain several targets (such as Japan_00000.jpg) or no target at all (such as Japan_00005.jpg).
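A line of this file can be taken apart with two `split` calls. The sketch below is only an illustration: the two sample lines are invented, and `parse_annotation_line` is a hypothetical helper, not a function from the project.

```python
def parse_annotation_line(line):
    """Split one annotation line into the picture path and its boxes.

    Each box is x_min,y_min,x_max,y_max,class_id; a line may have no boxes.
    """
    parts = line.split()                 # whitespace separates path and boxes
    path = parts[0]
    boxes = [tuple(int(v) for v in item.split(',')) for item in parts[1:]]
    return path, boxes

# Invented sample lines in the same layout as 2007_train.txt
with_boxes = "VOCdevkit/VOC2007/JPEGImages/a.jpg 151,427,581,600,2 2,493,53,581,2\n"
without_boxes = "VOCdevkit/VOC2007/JPEGImages/b.jpg\n"

print(parse_annotation_line(with_boxes))
print(parse_annotation_line(without_boxes))
```

Because the fields are separated by whitespace, this layout assumes that the picture path itself contains no spaces.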
model_data now also contains a file named voc_classes.txt, with the following contents:
2 Dataset class
Under yolox_from_scratch, create a new package named utils, and inside it a new file named dataloader.py. The new structure is shown in the figure below:
In dataloader.py, first import the packages we will use:
from random import sample, shuffle
import cv2
import numpy as np
from PIL import Image
from torch.utils.data.dataset import Dataset
Define a dataset class in this py file. It inherits from the Dataset class in torch.utils.data. A custom dataset class must implement three functions: __init__, __len__ and __getitem__. They respectively initialize the object, return the length for len(obj), and fetch a single sample and its label by index.
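As a rough illustration of this three-method protocol, here is a plain Python class that needs no torch at all. `ToyDataset` is a made-up stand-in; it returns only the picture path, whereas the real class returns an image and its boxes.

```python
class ToyDataset:
    """Plain-Python stand-in for a Dataset subclass (illustration only)."""

    def __init__(self, annotation_lines):
        self.annotation_lines = annotation_lines
        self.length = len(annotation_lines)

    def __len__(self):
        # called by len(obj)
        return self.length

    def __getitem__(self, index):
        # called by obj[index]; wrap the index so it never goes out of range
        index = index % self.length
        return self.annotation_lines[index].split()[0]

ds = ToyDataset(["a.jpg 1,2,3,4,0", "b.jpg"])
print(len(ds), ds[0], ds[3])  # 2 a.jpg b.jpg
```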
First write the two functions __init__ and __len__:
import cv2
import numpy as np
from PIL import Image
from torch.utils.data.dataset import Dataset


class YoloDataset(Dataset):
    def __init__(self, annotation_lines, input_shape, num_classes,
                 is_train, mosaic=False, mixup=False, mosaic_prob=0.5, mixup_prob=0.5):
        """
        Args:
            annotation_lines: list of the lines of the label file (e.g. 2007_train.txt), obtained with readlines() after open
            input_shape: image size fed to the model
            num_classes: number of classes to detect
            is_train: whether the model is in training mode; this decides whether ordinary augmentation is applied.
                In training, ordinary augmentation is always used, whether or not mosaic and mixup are enabled.
                Ordinary augmentation includes random aspect-ratio changes, random mirroring, color-gamut distortion, etc.
                Outside training (eval mode), no augmentation of any kind is used.
            mosaic: whether to use mosaic augmentation
            mixup: whether to use mixup augmentation
            mosaic_prob: when mosaic=True, the probability of applying mosaic augmentation to a picture
            mixup_prob: when mixup=True, the probability of applying mixup augmentation to a picture
        """
        super(YoloDataset, self).__init__()
        self.annotation_lines = annotation_lines
        self.length = len(self.annotation_lines)  # number of label lines, i.e. the number of pictures
        self.input_shape = input_shape  # image size fed to the model
        self.num_classes = num_classes  # number of classes to detect
        self.is_train = is_train  # whether the model is in training mode
        self.mosaic = mosaic  # whether to use mosaic augmentation
        self.mixup = mixup  # whether to use mixup augmentation
        self.mosaic_prob = mosaic_prob  # when mosaic=True, probability of mosaic augmentation
        self.mixup_prob = mixup_prob  # when mixup=True, probability of mixup augmentation
        self.step_now = -1  # counts the number of pictures read

    def __len__(self):
        return self.length
Next is __getitem__. In a custom dataset class this function is usually the most complex one, because it processes the labels and converts them to a standard format, and any data augmentation is handled here as well. (When doing computer vision tasks with torch, the two hardest parts to write are usually this __getitem__ function and the loss computation.)
    def __getitem__(self, index):
        index = index % self.length  # wrap the index into 0..self.length-1 so it cannot go out of range
        self.step_now += 1  # one more picture read
        # ---------------------------------------------------#
        #   Random augmentation is applied during training
        #   No random augmentation is applied during validation
        # ---------------------------------------------------#
        if self.is_train:
            if self.mosaic:
                # In the original yolox code, mosaic and mixup are not independent: mixup is only considered when mosaic is True.
                # I have not fully understood the mosaic augmentation code yet, so it is skipped for now.
                raise NotImplementedError("mosaic augmentation is not implemented yet; use mosaic=False")
            else:
                image, box = self.get_random_data(self.annotation_lines[index], self.input_shape, rand=True)
        else:
            image, box = self.get_random_data(self.annotation_lines[index], self.input_shape, rand=False)
        # Normalize the picture with the ImageNet mean and standard deviation, then move the channel axis to the front
        from utils.utils import preprocess_input
        image = np.transpose(preprocess_input(np.array(image, dtype=np.float32)), (2, 0, 1))
        # Fix the data type: after augmentation box is np.int32, here it becomes np.float32
        box = np.array(box, dtype=np.float32)
        # If the picture has no target, box is an empty array; the command above sets the type of an empty array too
        # Convert box from (x_min, y_min, x_max, y_max) to (center_x, center_y, w, h)
        if len(box) != 0:
            box[:, 2:4] = box[:, 2:4] - box[:, 0:2]
            box[:, 0:2] = box[:, 0:2] + box[:, 2:4] / 2
        return image, box
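The last two assignments of __getitem__ turn corner coordinates into center coordinates plus width and height. A minimal sketch for a single box, written with plain Python numbers instead of NumPy slices (the helper name `corners_to_center` is made up for this example):

```python
def corners_to_center(box):
    """(x_min, y_min, x_max, y_max, cls) -> [center_x, center_y, w, h, cls]."""
    x1, y1, x2, y2, cls_id = box
    w = x2 - x1                # first step: width and height from the corners
    h = y2 - y1
    return [x1 + w / 2, y1 + h / 2, w, h, cls_id]  # second step: shift the top-left corner by half the size

print(corners_to_center([151, 427, 581, 600, 2]))  # [366.0, 513.5, 430, 173, 2]
```

The order of the two steps matters in the NumPy version: the width and height must be written into columns 2:4 before they are used to move columns 0:2.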
The program above calls two methods, self.get_random_data and preprocess_input. Let's start with self.get_random_data.
When mosaic augmentation is not involved, everything is handled inside self.get_random_data: if the model is in training, traditional augmentation (random scaling and so on) is applied; if the model is in evaluation mode, no augmentation is done.
Here is the function get_random_data:
    def get_random_data(self, annotation_line, input_shape, jitter=.3, hue=.1, sat=1.5, val=1.5, rand=True):
        """
        Traditional augmentation: random scaling, aspect-ratio distortion, random mirroring and color-gamut distortion.
        For the color gamut (the HSV color model), see https://www.cnblogs.com/lfri/p/10426113.html
        Args:
            annotation_line: one line of self.annotation_lines, holding the picture path and the box labels
            input_shape: model input size; the picture is converted to this size here
            jitter: controls the width/height scaling factors; with jitter=0.3 each factor is drawn from (1-0.3, 1+0.3)
            hue: hue
            sat: saturation
            val: brightness
            rand: whether to apply random augmentation. Augmentation is only needed when the model is in training,
                so True/False here stands for whether the model is in training mode
        Returns:
            image_data, box: the processed picture and its boxes, both numpy arrays
        """
From annotation_line we can get the picture and its boxes. This general information is needed in both training and evaluation mode:
        """ Split the line into the picture path and the annotation information """
        line = annotation_line.split()
        """ Read the picture and convert it to RGB """
        from utils.utils import cvtColor
        image = Image.open(line[0])
        image = cvtColor(image)
        """ Get the height and width of the picture and the input height and width of the model """
        iw, ih = image.size  # width and height of the original picture
        h, w = input_shape  # input size of the model, height first
        """ Get the target boxes as a numpy array """
        box = np.array([np.array(list(map(int, box.split(',')))) for box in line[1:]])  # line is a list
        # If the picture has no target, line has a single element, the path string of the picture.
        # line[1:] still works and returns an empty list, whereas line[1] would raise an error.
        # In other words, indexing a list out of range raises an error for a single element but not for a slice.
This code calls the cvtColor function. Create a new file called utils.py in the utils package; the project structure becomes:
Write the following function in utils.py:
import numpy as np


# ---------------------------------------------------------#
#   Convert the picture to RGB so that grayscale pictures do not cause errors at prediction time.
#   The code only supports prediction on RGB pictures; every other type is converted to RGB.
# ---------------------------------------------------------#
def cvtColor(image):
    """image is the return value of PIL.Image.open; this function makes sure it has the three RGB channels"""
    if len(np.shape(image)) == 3 and np.shape(image)[2] == 3:  # shape is (height, width, channels): already 3 channels?
        return image
    image = image.convert('RGB')
    return image
Back in get_random_data, let's handle the case where the model is in evaluation mode:
        """ Outside training no augmentation is needed, go straight to the letter_box conversion """
        if not rand:
            scale = min(w/iw, h/ih)  # the longer side of the original picture determines the scale
            # The model input is square, so w = h and the two ratios w/iw and h/ih share the same numerator.
            # If the original picture is narrower than it is tall (iw < ih), w/iw is the larger ratio and min(w/iw, h/ih) is h/ih.
            # In other words, the scale is determined by the longer side of the original picture.
            nw = int(iw*scale)  # new width
            nh = int(ih*scale)  # new height
            dx = (w-nw)//2  # width of the left and right gray bars in the letter_box algorithm
            dy = (h-nh)//2  # height of the top and bottom gray bars
            # Gray bars appear in at most one direction, so at least one of dx and dy is 0.
            # If the original picture is square, both dx and dy are 0.
            # Scale the picture to the new width and height
            image = image.resize((nw, nh), Image.BICUBIC)
            # Create a gray canvas of the target width and height; all three color channels are 128
            new_image = Image.new('RGB', (w, h), (128, 128, 128))  # (128, 128, 128) is the pixel value of the gray bars
            # Paste the scaled picture into the center of the canvas
            new_image.paste(image, (dx, dy))
            image_data = np.array(new_image, np.float32)  # convert to a float32 numpy array
            # Adjust the ground-truth boxes
            if len(box) > 0:
                box[:, [0, 2]] = box[:, [0, 2]]*nw/iw + dx  # x coordinates of both corners after letter_box
                box[:, [1, 3]] = box[:, [1, 3]]*nh/ih + dy  # y coordinates of both corners after letter_box
                box[:, 0:2][box[:, 0:2] < 0] = 0  # negative-value check; columns 0:2 are x_min and y_min, so both axes are covered
                box[:, 2][box[:, 2] > w] = w  # out-of-range check
                box[:, 3][box[:, 3] > h] = h
                # With valid annotations these three checks change nothing in the letter_box case; they are a safety net.
                box_w = box[:, 2] - box[:, 0]
                box_h = box[:, 3] - box[:, 1]
                box = box[np.logical_and(box_w > 1, box_h > 1)]  # keep only the boxes whose width and height both exceed 1
            return image_data, box
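The letter_box geometry can be tried on plain numbers. The sketch below assumes a 1280x720 picture and a 640x640 model input, both chosen only for illustration; `letterbox_params` and `letterbox_box` are hypothetical helper names.

```python
def letterbox_params(iw, ih, w, h):
    """Scaled size (nw, nh) and gray-bar offsets (dx, dy) for a letter_box resize."""
    scale = min(w / iw, h / ih)
    nw, nh = int(iw * scale), int(ih * scale)
    dx, dy = (w - nw) // 2, (h - nh) // 2
    return nw, nh, dx, dy

def letterbox_box(box, iw, ih, w, h):
    """Map (x_min, y_min, x_max, y_max) from the original picture to the canvas."""
    nw, nh, dx, dy = letterbox_params(iw, ih, w, h)
    x1, y1, x2, y2 = box
    return (x1 * nw / iw + dx, y1 * nh / ih + dy,
            x2 * nw / iw + dx, y2 * nh / ih + dy)

print(letterbox_params(1280, 720, 640, 640))                     # (640, 360, 0, 140)
print(letterbox_box((100, 200, 300, 400), 1280, 720, 640, 640))  # (50.0, 240.0, 150.0, 340.0)
```

The picture is wider than it is tall, so the width fills the canvas (dx is 0) and gray bars of 140 pixels appear above and below.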
If the model is in training, the if statement above is not executed and we move on to the augmentation, which has four parts: random scaling, aspect-ratio distortion, random mirroring and color-gamut distortion.
The code below handles the geometric part of the augmentation, starting with random scaling and aspect-ratio distortion:
        """ Scale and distort the picture """
        new_ar = w/h * self.rand(1-jitter, 1+jitter) / self.rand(1-jitter, 1+jitter)  # randomly generate a new aspect ratio
        scale = self.rand(.25, 2)  # randomly generate a scaling factor
        # Whichever of height and width is larger (according to new_ar) is scaled by the factor; the other follows from the aspect ratio
        if new_ar < 1:
            nh = int(scale*h)  # scale the height by the factor
            nw = int(nh*new_ar)  # derive the new width from the new height and the new aspect ratio
        else:
            nw = int(scale*w)
            nh = int(nw/new_ar)
        # Scale the picture to the new width and height
        image = image.resize((nw, nh), Image.BICUBIC)
        """ Fill the rest of the canvas with gray bars; the left/right (or top/bottom) bars need not have the same thickness """
        dx = int(self.rand(0, w-nw))
        dy = int(self.rand(0, h-nh))
        # dx and dy can be negative: scale may be greater than 1, in which case nh and nw may exceed h and w
        new_image = Image.new('RGB', (w, h), (128, 128, 128))  # canvas of the target width and height
        new_image.paste(image, (dx, dy))  # paste the scaled picture at the chosen position of the canvas
        # If dx is greater than 0, then w > nw: the picture was shrunk horizontally and gray bars fill the left and right sides.
        # If dx is less than 0, then w < nw: the picture was enlarged horizontally and the left and right sides are cropped.
        # The same holds for dy. Either way, after the commands above new_image has size (w, h).
        image = new_image
        """ Random mirroring """
        flip = self.rand() < .5  # mirror the picture left-right with probability 0.5; the boxes are adjusted later with the same flag
        if flip:
            image = image.transpose(Image.FLIP_LEFT_RIGHT)
This code calls self.rand, a member method of the YoloDataset class. With no arguments it generates a random number between 0 and 1; given a and b it generates a random number between a and b.
    def rand(self, a=0, b=1):
        """ Generate a random number between a and b; for a random number between 0 and 100, use a=0, b=100 """
        return np.random.rand()*(b-a) + a
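To get a feel for the scaling step, the sketch below draws many random geometries for a 640x640 input and checks that the paste offset is never positive when the scaled picture is wider than the canvas. It uses Python's `random` module instead of NumPy, and `uniform` and `jittered_geometry` are names invented for this illustration.

```python
import random

def uniform(rng, a=0.0, b=1.0):
    """Same formula as rand(); it also works when b is smaller than a."""
    return rng.random() * (b - a) + a

def jittered_geometry(w, h, jitter, rng):
    """Return (scale, nw, nh, dx, dy) for one random scaling of a w x h canvas."""
    new_ar = w / h * uniform(rng, 1 - jitter, 1 + jitter) / uniform(rng, 1 - jitter, 1 + jitter)
    scale = uniform(rng, 0.25, 2)
    if new_ar < 1:
        nh = int(scale * h)
        nw = int(nh * new_ar)
    else:
        nw = int(scale * w)
        nh = int(nw / new_ar)
    dx = int(uniform(rng, 0, w - nw))   # negative range when nw > w
    dy = int(uniform(rng, 0, h - nh))
    return scale, nw, nh, dx, dy

rng = random.Random(0)
samples = [jittered_geometry(640, 640, 0.3, rng) for _ in range(1000)]
enlarged = [s for s in samples if s[1] > 640]
print(len(enlarged) > 0, all(s[3] <= 0 for s in enlarged))
```

A non-positive offset means the picture is pasted starting left of the canvas, so the part that sticks out is cropped. This is how a target can end up partly or almost entirely cut off.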
Back in get_random_data, next comes the color-gamut distortion:
        """ Color-gamut distortion """
        hue = self.rand(-hue, hue)  # hue offset, as a fraction of the full circle
        sat = self.rand(1, sat) if self.rand()<.5 else 1/self.rand(1, sat)  # saturation factor
        val = self.rand(1, val) if self.rand()<.5 else 1/self.rand(1, val)  # brightness factor
        x = cv2.cvtColor(np.array(image, np.float32)/255, cv2.COLOR_RGB2HSV)  # convert RGB to HSV; the result is a numpy array
        # Adjust the hue
        x[..., 0] += hue * 360
        x[..., 0][x[..., 0] > 360] -= 360  # wrap hues above 360 back into range
        x[..., 0][x[..., 0] < 0] += 360  # wrap negative hues back into range
        # x[..., 0] is a numpy array of shape (h, w),
        # x[..., 0] > 360 is a Boolean array of the same shape,
        # x[..., 0][x[..., 0] > 360] and x[..., 0][x[..., 0] < 0] are Boolean indexing.
        # After x[..., 0] += hue * 360 the hue may exceed 360 or drop below 0; the two lines above bring it back into 0-360.
        # Adjust saturation and brightness
        x[..., 1] *= sat
        x[..., 2] *= val
        # Clip saturation and brightness to 0-1
        x[:, :, 1:][x[:, :, 1:] > 1] = 1
        x[:, :, 1:][x[:, :, 1:] < 0] = 0
        # Convert HSV back to RGB
        image_data = cv2.cvtColor(x, cv2.COLOR_HSV2RGB)*255
In the program above, the picture data is normalized to 0-1 before the conversion from RGB to HSV. As a result, saturation and brightness are in the 0-1 range after the conversion, but the hue is not: it still ranges from 0 to 360.
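The same three adjustments can be tried on a single pixel with the standard-library `colorsys` module. This is a swapped-in tool for illustration, not what the article's code uses: `colorsys` keeps the hue in the 0-1 range rather than 0-360, so the wrap-around becomes a modulo by 1. The function name `distort_pixel` is made up.

```python
import colorsys

def distort_pixel(rgb, hue_shift, sat_scale, val_scale):
    """rgb: floats in 0-1; hue_shift: fraction of the full hue circle."""
    h, s, v = colorsys.rgb_to_hsv(*rgb)
    h = (h + hue_shift) % 1.0                  # wrap the hue around the circle
    s = min(max(s * sat_scale, 0.0), 1.0)      # scale, then clip saturation to 0-1
    v = min(max(v * val_scale, 0.0), 1.0)      # scale, then clip brightness to 0-1
    return colorsys.hsv_to_rgb(h, s, v)

print(distort_pixel((1.0, 0.0, 0.0), 0.5, 1.0, 1.0))  # half a turn: red becomes cyan
print(distort_pixel((1.0, 0.0, 0.0), 0.0, 1.0, 0.5))  # half the brightness: darker red
```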
Finally, adjust the target boxes according to the augmentation and return the augmented picture and the boxes:
        """ Adjust the target boxes """
        if len(box) > 0:
            # New box positions follow from the scaling ratio and the gray bars
            box[:, [0, 2]] = box[:, [0, 2]]*nw/iw + dx
            box[:, [1, 3]] = box[:, [1, 3]]*nh/ih + dy
            # If the picture was mirrored, mirror the x coordinates of the boxes as well
            if flip:
                box[:, [0, 2]] = w - box[:, [2, 0]]
            # Check the boxes for out-of-range values
            box[:, 0:2][box[:, 0:2] < 0] = 0
            box[:, 2][box[:, 2] > w] = w
            box[:, 3][box[:, 3] > h] = h
            # Keep only the boxes with a valid width and height
            box_w = box[:, 2] - box[:, 0]
            box_h = box[:, 3] - box[:, 1]
            box = box[np.logical_and(box_w > 1, box_h > 1)]
        """ Return the picture data (numpy array) and the boxes (also a numpy array) """
        return image_data, box
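The mirroring and the validity checks are easy to verify by hand on a few boxes. The following sketch does so with plain tuples; the helper names and the sample boxes are invented for the example.

```python
def flip_box_horizontally(box, w):
    """Mirror (x_min, y_min, x_max, y_max) on a canvas of width w."""
    x1, y1, x2, y2 = box
    return (w - x2, y1, w - x1, y2)   # old right edge becomes the new left edge

def clip_and_filter(boxes, w, h):
    """Clip boxes to the canvas and drop those not larger than 1 pixel."""
    kept = []
    for x1, y1, x2, y2 in boxes:
        x1, y1 = max(x1, 0), max(y1, 0)
        x2, y2 = min(x2, w), min(y2, h)
        if x2 - x1 > 1 and y2 - y1 > 1:
            kept.append((x1, y1, x2, y2))
    return kept

print(flip_box_horizontally((50, 240, 150, 340), 640))                       # (490, 240, 590, 340)
print(clip_and_filter([(-20, 10, 100, 50), (639, 10, 700, 50)], 640, 640))  # [(0, 10, 100, 50)]
```

Swapping the two x columns while subtracting from w is what keeps x_min smaller than x_max after the flip.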
Next, write the preprocess_input function mentioned above. Add the following function to utils.py:
def preprocess_input(image):
    """ Normalize the picture with the ImageNet mean and standard deviation before it enters the model """
    image /= 255.0
    image -= np.array([0.485, 0.456, 0.406])  # ImageNet mean  # TODO should the mean and standard deviation be replaced by those of our own dataset?
    image /= np.array([0.229, 0.224, 0.225])  # ImageNet standard deviation
    return image
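For a single pixel the normalization is simple arithmetic, and it can be undone. The sketch below uses plain tuples instead of NumPy arrays; `normalize_pixel` and `denormalize_pixel` are hypothetical helpers, and the inverse function is an addition for illustration that the project does not contain.

```python
MEAN = (0.485, 0.456, 0.406)  # ImageNet mean per channel (R, G, B)
STD = (0.229, 0.224, 0.225)   # ImageNet standard deviation per channel

def normalize_pixel(rgb):
    """rgb: values in 0-255 -> normalized values."""
    return tuple((c / 255.0 - m) / s for c, m, s in zip(rgb, MEAN, STD))

def denormalize_pixel(norm):
    """Inverse of normalize_pixel, back to the 0-255 range."""
    return tuple((c * s + m) * 255.0 for c, m, s in zip(norm, MEAN, STD))

gray = normalize_pixel((128, 128, 128))   # the gray-bar color
black = normalize_pixel((0, 0, 0))
print(gray)
print(black)
print(denormalize_pixel(gray))
```

Dark pixels become negative and bright ones exceed 1, which is why matplotlib prints a clipping warning when a normalized picture is passed straight to imshow.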
(3) Test script for the dataset
Now that the dataset class is finished, let's write a test script.
Under yolox_from_scratch, create a new file called dataloader_test.py with the following contents:
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.patches as patches
from utils.dataloader import YoloDataset
import cv2
import os
os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"

if __name__ == '__main__':
    """ Set seeds """
    np.random.seed(0)  # the seed value is arbitrary here; any fixed value makes the augmentation repeatable
    """ Get the initialization parameters of the dataset class """
    train_annotation_path = '2007_train.txt'
    with open(train_annotation_path) as f:
        train_lines = f.readlines()  # train_lines is a list
    input_shape = [640, 640]
    num_classes = 4
    mosaic = False
    mixup = False
    """ Create the dataset object """
    train_dataset = YoloDataset(train_lines, input_shape, num_classes, is_train=True, mosaic=mosaic, mixup=mixup)
    """ Get the augmented picture and labels by index """
    img, boxes = train_dataset[2]
    img = np.transpose(img, (1, 2, 0))  # move the channel axis to the end
    print("boxes info after data_augmentation (center_x, center_y, w, h):")
    print(boxes)
    # plot
    ax1 = plt.subplot(1, 2, 1)
    plt.imshow(img)
    for box in boxes:
        center_x, center_y, w, h, _ = box
        ax1.add_patch(patches.Rectangle((center_x-w//2, center_y-h//2), w, h, facecolor="red", alpha=0.3))
        # The first argument of Rectangle is the corner closest to the origin (here the top-left corner), followed by width, height, color and transparency
    """ Original picture and labels """
    orig_info = train_lines[2]
    line = orig_info.split()
    img_dir = line[0]  # picture path
    boxes = line[1:]  # box information
    boxes = np.array([np.array(list(map(int, box.split(',')))) for box in line[1:]])
    print("original boxes:")
    print(boxes)
    # plot
    ax2 = plt.subplot(1, 2, 2)
    img_orig = cv2.imread(img_dir)
    img_orig = cv2.cvtColor(img_orig, cv2.COLOR_BGR2RGB)
    plt.imshow(img_orig)
    for box in boxes:
        top_left_x, top_left_y, low_right_x, low_right_y, _ = box
        w = low_right_x - top_left_x  # box width
        h = low_right_y - top_left_y  # box height
        ax2.add_patch(patches.Rectangle((top_left_x, top_left_y), w, h, facecolor="red", alpha=0.3))
    plt.show()
The terminal output is :
boxes info after data_augmentation (center_x, center_y, w, h):
[[383. 576. 514. 128. 2.]
[ 1. 622. 2. 36. 2.]]
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
original boxes:
[[151 427 581 600 2]
[ 2 493 53 581 2]]
Note that the second box after augmentation is only 2 pixels wide. A box would not normally be that small; the only explanation is that the random cropping cut most of the target off. The figure shows the boxes in the augmented picture and in the original picture:
In 2007_train.txt, the fifth line contains only the picture path and no box information. Let's change the index to 4 and debug the program to see what box __getitem__ returns when a picture has no boxes, and to track how the type of box changes inside __getitem__.
The program is as follows. It is hard to show a debugging session here, so we simply run it; when you type it in yourself, it is best to step through it in a debugger.
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.patches as patches
from utils.dataloader import YoloDataset
import cv2
import os
os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"

if __name__ == '__main__':
    """ Set seeds """
    np.random.seed(0)  # the seed value is arbitrary here; any fixed value makes the augmentation repeatable
    """ Get the initialization parameters of the dataset class """
    train_annotation_path = '2007_train.txt'
    with open(train_annotation_path) as f:
        train_lines = f.readlines()  # train_lines is a list
    input_shape = [640, 640]
    num_classes = 4
    mosaic = False
    mixup = False
    """ Create the dataset object """
    train_dataset = YoloDataset(train_lines, input_shape, num_classes, is_train=True, mosaic=mosaic, mixup=mixup)
    """ Get the augmented picture and labels by index """
    img, boxes = train_dataset[4]  # index 4 corresponds to VOCdevkit/VOC2007/JPEGImages/Japan_000005.jpg
    img = np.transpose(img, (1, 2, 0))  # move the channel axis to the end
    print("boxes info after data_augmentation (center_x, center_y, w, h):")
    print(type(boxes))  # print the type of boxes
    # plot
    ax1 = plt.subplot(1, 2, 1)
    plt.imshow(img)
    for box in boxes:
        center_x, center_y, w, h, _ = box
        ax1.add_patch(patches.Rectangle((center_x-w//2, center_y-h//2), w, h, facecolor="red", alpha=0.3))
        # The first argument of Rectangle is the corner closest to the origin (here the top-left corner), followed by width, height, color and transparency
    """ Original picture and labels """
    orig_info = train_lines[4]  # index 4 corresponds to VOCdevkit/VOC2007/JPEGImages/Japan_000005.jpg
    line = orig_info.split()
    img_dir = line[0]  # picture path
    boxes = line[1:]  # box information
    boxes = np.array([np.array(list(map(int, box.split(',')))) for box in line[1:]])
    print("original boxes:")
    # plot
    ax2 = plt.subplot(1, 2, 2)
    img_orig = cv2.imread(img_dir)
    img_orig = cv2.cvtColor(img_orig, cv2.COLOR_BGR2RGB)
    plt.imshow(img_orig)
    for box in boxes:
        top_left_x, top_left_y, low_right_x, low_right_y, _ = box
        w = low_right_x - top_left_x  # box width
        h = low_right_y - top_left_y  # box height
        ax2.add_patch(patches.Rectangle((top_left_x, top_left_y), w, h, facecolor="red", alpha=0.3))
    plt.show()
The output is:
boxes info after data_augmentation (center_x, center_y, w, h):
<class 'numpy.ndarray'>
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
original boxes:
The image displayed is :
3 collate_fn and its test scripts
Add a function to dataloader.py so that DataLoader can load several pictures and their labels at once (that is, the data and targets of one batch):
# Used as collate_fn in DataLoader
def yolo_dataset_collate(batch):
    images = []
    bboxes = []
    for img, box in batch:
        images.append(img)
        bboxes.append(box)
    images = np.array(images)
    return images, bboxes
The function above stacks all the pictures of a batch into a single array (a NumPy array, which plays the role of the batch tensor). The boxes of each picture form a two-dimensional NumPy array, and the function collects the box arrays of the whole batch in one list.
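The boxes stay in a list because pictures have different numbers of targets, so their box arrays cannot be stacked into one rectangular array. The sketch below shows the idea with plain lists standing in for the NumPy arrays; the batch contents are invented and `collate` is a simplified stand-in for yolo_dataset_collate.

```python
def collate(batch):
    """Group the samples of a batch: one list of images, one list of box lists."""
    images, bboxes = [], []
    for img, box in batch:
        images.append(img)
        bboxes.append(box)     # kept per picture, lengths may differ
    return images, bboxes

# Three made-up samples with 2, 0 and 1 boxes (center_x, center_y, w, h, class)
batch = [
    ("img0", [[570.0, 541.0, 54.0, 78.0, 0.0], [505.5, 575.0, 177.0, 32.0, 1.0]]),
    ("img1", []),
    ("img2", [[322.0, 412.0, 636.0, 174.0, 2.0]]),
]
images, bboxes = collate(batch)
print(images)                    # ['img0', 'img1', 'img2']
print([len(b) for b in bboxes])  # [2, 0, 1]
```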
Let's write two test scripts.
The first script checks the types of the return values. The code is as follows:
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.patches as patches
from torch.utils.data import DataLoader
from utils.dataloader import YoloDataset, yolo_dataset_collate

if __name__ == '__main__':
    """ Set seeds """
    np.random.seed(0)  # the seed value is arbitrary here
    """ Get the initialization parameters of the dataset class """
    train_annotation_path = '2007_train.txt'
    with open(train_annotation_path) as f:
        train_lines = f.readlines()  # train_lines is a list
    input_shape = [640, 640]
    num_classes = 4
    mosaic = False
    mixup = False
    """ Create the dataset object """
    train_dataset = YoloDataset(train_lines, input_shape, num_classes, is_train=True, mosaic=mosaic, mixup=mixup)
    batch_size = 4
    num_workers = 4
    """ Create the loader object """
    gen = DataLoader(train_dataset, batch_size=batch_size, num_workers=num_workers, pin_memory=True,
                     drop_last=True, collate_fn=yolo_dataset_collate)
    for iteration, batch in enumerate(gen):
        images, targets = batch[0], batch[1]
        print("images type", type(images))
        print("images shape", images.shape)
        print("targets type", type(targets))
        print(targets)
        if iteration == 1:
            break  # two batches are enough for the check
The terminal output is:
images type <class 'numpy.ndarray'>
images shape (4, 3, 640, 640)
targets type <class 'list'>
[array([[570. , 541. , 54. , 78. , 0. ],
[505.5, 575. , 177. , 32. , 1. ],
[535. , 574.5, 210. , 107. , 2. ],
[535.5, 474.5, 205. , 87. , 3. ]], dtype=float32), array([[ 59., 515., 118., 30., 1.],
[380., 514., 520., 66., 1.]], dtype=float32), array([[492.5, 373. , 39. , 50. , 2. ],
[237. , 360. , 324. , 98. , 2. ]], dtype=float32), array([[322., 412., 636., 174., 2.]], dtype=float32)]
images type <class 'numpy.ndarray'>
images shape (4, 3, 640, 640)
targets type <class 'list'>
[array([], dtype=float32), array([[348.5, 255.5, 155. , 299. , 2. ],
[244. , 230.5, 50. , 311. , 0. ],
[540. , 517. , 80. , 32. , 1. ]], dtype=float32), array([], dtype=float32), array([[187.5, 411. , 39. , 58. , 2. ]], dtype=float32)]
The second script draws the pictures of a batch. The code is as follows:
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.patches as patches
from torch.utils.data import DataLoader
from utils.dataloader import YoloDataset, yolo_dataset_collate
import os
os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"

if __name__ == '__main__':
    """ Set seeds """
    np.random.seed(0)  # the seed value is arbitrary here
    """ Get the initialization parameters of the dataset class """
    train_annotation_path = '2007_train.txt'
    with open(train_annotation_path) as f:
        train_lines = f.readlines()  # train_lines is a list
    input_shape = [640, 640]
    num_classes = 4
    mosaic = False
    mixup = False
    """ Create the dataset object """
    train_dataset = YoloDataset(train_lines, input_shape, num_classes, is_train=True, mosaic=mosaic, mixup=mixup)
    batch_size = 4
    num_workers = 4
    """ Create the loader object """
    gen = DataLoader(train_dataset, batch_size=batch_size, num_workers=num_workers, pin_memory=True,
                     drop_last=True, collate_fn=yolo_dataset_collate)
    for iteration, batch in enumerate(gen):
        images, targets = batch[0], batch[1]
        images = np.transpose(images, (0, 2, 3, 1))  # move the channel axis to the end for plotting
        ax = [0, 0, 0, 0]
        for index in range(4):
            ax[index] = plt.subplot(2, 2, index+1)
            plt.imshow(images[index])
            for box in targets[index]:
                center_x, center_y, w, h, _ = box
                ax[index].add_patch(patches.Rectangle((center_x - w // 2, center_y - h // 2), w, h, facecolor="red", alpha=0.3))
                # The first argument of Rectangle is the corner closest to the origin (here the top-left corner), followed by width, height, color and transparency
        plt.show()
        break  # one batch is enough for a visual check
The displayed image is:
This concludes the explanation of the dataset class and its collate_fn. In the next section we will build the YOLOX network structure.
