当前位置:网站首页>【YOLOv3 SPP 数据集准备】YOLOv3 SPP数据集准备代码理解
【YOLOv3 SPP 数据集准备】YOLOv3 SPP数据集准备代码理解
2022-08-03 05:27:00 【寻找永不遗憾】
本文主要参考 霹雳吧啦Wz 的代码,如果不想看文章,欢迎去看他的视频,视频链接在感谢链接中!
1 VOC数据集准备
1.1 数据集介绍
用Pascal Voc2007+2012做训练,在Pascal Voc2007上做测试,数据情况如下:
训练数据:16551张图像,共40058个目标
测试(验证)数据:4952张图像,共12032个目标
VOC数据集格式中的Annotations是.xml文件,需要使用脚本将voc数据格式(.xml)转成yolo数据格式(.txt),也就是大家常见的生成train.txt和val.txt。
提供Pascal Voc2007+2012数据集链接:
链接:https://pan.baidu.com/s/1dzoU8_kCqCXHT7smYlJ3oQ
提取码:uzei
1.2 生成需要的文件1
最终的目标如下:
├── my_yolo_dataset 自定义数据集根目录
│ ├── train 训练集目录
│ │ ├── images 训练集图像目录
│ │ └── labels 训练集标签目录
│ └── val 验证集目录
│ ├── images 验证集图像目录
│ └── labels 验证集标签目录
│
├── data
│ ├── my_data_label.names 数据集类别标签名称
其中,labels文件夹下的一个txt文件表示一张图片的目标框信息,举例内容如下:
[class_index, xcenter, ycenter, w, h]:第一个参数是类别id,后面四个参数是目标的相对位置。
代码如下:
""" 本脚本有两个功能: 1.将voc数据集标注信息(.xml)转为yolo标注格式(.txt),并将图像文件复制到相应文件夹 2.根据json标签文件,生成对应names标签(my_data_label.names) """
import os
from tqdm import tqdm # 用于进度条显示
from lxml import etree
import json
import shutil
# 原voc数据集根目录以及版本
voc_root = "D:\DeepLearning\dataset\VOCdevkit"
voc_version = "VOC2007_12" # 和VOCdevkit下的文件夹名称一致
# 原VOC训练集以及验证集对应txt文件
train_txt = "train.txt"
val_txt = "val.txt"
# 转换后的文件保存目录
save_file_root = "./my_yolo_dataset"
# label标签对应json文件,字典
label_json_path = './data/pascal_voc_classes.json'
# 拼接出voc的images目录,xml目录,txt目录
voc_images_path = os.path.join(voc_root, voc_version, "JPEGImages")
voc_xml_path = os.path.join(voc_root, voc_version, "Annotations")
train_txt_path = os.path.join(voc_root, voc_version, "ImageSets", "Main", train_txt)
val_txt_path = os.path.join(voc_root, voc_version, "ImageSets", "Main", val_txt)
# 检查文件/文件夹都是否存在
assert os.path.exists(voc_images_path), "VOC images path not exist..."
assert os.path.exists(voc_xml_path), "VOC xml path not exist..."
assert os.path.exists(train_txt_path), "VOC train txt file not exist..."
assert os.path.exists(val_txt_path), "VOC val txt file not exist..."
assert os.path.exists(label_json_path), "label_json_path does not exist..."
if os.path.exists(save_file_root) is False:
os.makedirs(save_file_root)
def parse_xml_to_dict(xml):
""" 将xml文件解析成字典形式,参考tensorflow的recursive_parse_xml_to_dict Args: xml: xml tree obtained by parsing XML file contents using lxml.etree Returns: Python dictionary holding XML contents. """
if len(xml) == 0: # 遍历到底层,直接返回tag对应的信息
return {
xml.tag: xml.text}
result = {
}
for child in xml:
child_result = parse_xml_to_dict(child) # 递归遍历标签信息
if child.tag != 'object':
result[child.tag] = child_result[child.tag]
else:
if child.tag not in result: # 因为object可能有多个,所以需要放入列表里
result[child.tag] = []
result[child.tag].append(child_result[child.tag])
return {
xml.tag: result}
def translate_info(file_names: list, save_root: str, class_dict: dict, train_val='train'):
""" 将对应xml文件信息转为yolo中使用的txt文件信息 :param file_names: :param save_root: :param class_dict: :param train_val: :return: """
save_txt_path = os.path.join(save_root, train_val, "labels")
if os.path.exists(save_txt_path) is False:
os.makedirs(save_txt_path)
save_images_path = os.path.join(save_root, train_val, "images")
if os.path.exists(save_images_path) is False:
os.makedirs(save_images_path)
# 进度条用法:第一个参数是可迭代对象;desc参数是进度条前的说明信息
for file in tqdm(file_names, desc="translate {} file...".format(train_val)):
# 检查下图像文件是否存在
img_path = os.path.join(voc_images_path, file + ".jpg")
assert os.path.exists(img_path), "file:{} not exist...".format(img_path)
# 检查xml文件是否存在
xml_path = os.path.join(voc_xml_path, file + ".xml")
assert os.path.exists(xml_path), "file:{} not exist...".format(xml_path)
# read xml
with open(xml_path) as fid:
xml_str = fid.read() # xml_str里放着xml文件中的所有字符信息
# 此时xml里存放着乱七八糟的东西
xml = etree.fromstring(xml_str)
# data里存放字典,包括xml文件中的信息
data = parse_xml_to_dict(xml)["annotation"]
img_height = int(data["size"]["height"])
img_width = int(data["size"]["width"])
# write object info into txt
assert "object" in data.keys(), "file: '{}' lack of object key.".format(xml_path)
if len(data["object"]) == 0:
# 如果xml文件中没有目标就直接忽略该样本
print("Warning: in '{}' xml, there are no objects.".format(xml_path))
continue
with open(os.path.join(save_txt_path, file + ".txt"), "w") as f:
for index, obj in enumerate(data["object"]): # data["object"]是列表里装多个字典元素
# 获取每个object的box信息
xmin = float(obj["bndbox"]["xmin"])
xmax = float(obj["bndbox"]["xmax"])
ymin = float(obj["bndbox"]["ymin"])
ymax = float(obj["bndbox"]["ymax"])
class_name = obj["name"]
class_index = class_dict[class_name] - 1 # 目标id从0开始
# 进一步检查数据,有的标注信息中可能有w或h为0的情况,这样的数据会导致计算回归loss为nan
if xmax <= xmin or ymax <= ymin:
print("Warning: in '{}' xml, there are some bbox w/h <=0".format(xml_path))
continue
# 将box信息转换到yolo格式
xcenter = xmin + (xmax - xmin) / 2
ycenter = ymin + (ymax - ymin) / 2
w = xmax - xmin
h = ymax - ymin
# 绝对坐标转相对坐标,保存6位小数
xcenter = round(xcenter / img_width, 6)
ycenter = round(ycenter / img_height, 6)
w = round(w / img_width, 6)
h = round(h / img_height, 6)
info = [str(i) for i in [class_index, xcenter, ycenter, w, h]]
if index == 0:
f.write(" ".join(info))
else:
f.write("\n" + " ".join(info))
# copy image into save_images_path
# os.sep:文件的路径分隔符,保证在linux和windows上都能用
path_copy_to = os.path.join(save_images_path, img_path.split(os.sep)[-1])
if os.path.exists(path_copy_to) is False:
shutil.copyfile(img_path, path_copy_to) # 图片copy过去
# 生成my_data_label.names文件,里面存放所有数据集中的所有类别
def create_class_names(class_dict: dict):
keys = class_dict.keys()
with open("./data/my_data_label.names", "w") as w:
for index, k in enumerate(keys):
if index + 1 == len(keys):
w.write(k)
else:
w.write(k + "\n")
def main():
# read class_indict
json_file = open(label_json_path, 'r')
class_dict = json.load(json_file)
# 读取train.txt中的所有行信息,删除空行
with open(train_txt_path, "r") as r:
train_file_names = [i for i in r.read().splitlines() if len(i.strip()) > 0]
# voc信息转yolo,并将图像文件复制到相应文件夹
translate_info(train_file_names, save_file_root, class_dict, "train")
# 读取val.txt中的所有行信息,删除空行
with open(val_txt_path, "r") as r:
val_file_names = [i for i in r.read().splitlines() if len(i.strip()) > 0]
# voc信息转yolo,并将图像文件复制到相应文件夹
translate_info(val_file_names, save_file_root, class_dict, "val")
# 创建my_data_label.names文件
create_class_names(class_dict)
if __name__ == "__main__":
main()
1.2 生成需要的文件2
使用calculate_dataset.py
脚本生成my_train_data.txt
文件、my_val_data.txt
文件以及my_data.data
文件,并生成新的my_yolov3.cfg
文件。
my_train_data.txt
和my_val_data.txt
内容类似于:里面连空行都要正确,有点搞,目前不太喜欢这种构建网络的方式。
./my_yolo_dataset/train/images\000005.jpg
./my_yolo_dataset/train/images\000007.jpg
./my_yolo_dataset/train/images\000009.jpg
...
my_data.data
内容如下:
classes=20
train=data/my_train_data.txt
valid=data/my_val_data.txt
names=data/my_data_label.names
my_yolov3.cfg
用于构建网络结构,部分内容如下:
[net]
# Testing
# batch=1
# subdivisions=1
# Training
batch=64
subdivisions=16
...
[yolo]
mask = 6,7,8
anchors = 10,13, 16,30, 33,23, 30,61, 62,45, 59,119, 116,90, 156,198, 373,326
...
代码如下:
""" 该脚本有3个功能: 1.统计训练集和验证集的数据并生成相应.txt文件 2.创建data.data文件,记录classes个数, train以及val数据集文件(.txt)路径和label.names文件路径 3.根据yolov3-spp.cfg创建my_yolov3.cfg文件修改其中的predictor filters以及yolo classes参数(这两个参数是根据类别数改变的) """
import os
train_annotation_dir = "./my_yolo_dataset/train/labels"
val_annotation_dir = "./my_yolo_dataset/val/labels"
classes_label = "./data/my_data_label.names"
cfg_path = "./cfg/yolov3-spp.cfg"
assert os.path.exists(train_annotation_dir), "train_annotation_dir not exist!"
assert os.path.exists(val_annotation_dir), "val_annotation_dir not exist!"
assert os.path.exists(classes_label), "classes_label not exist!"
assert os.path.exists(cfg_path), "cfg_path not exist!"
def calculate_data_txt(txt_path, dataset_dir):
# create my_data.txt file that record image list
with open(txt_path, "w") as w:
for file_name in os.listdir(dataset_dir):
if file_name == "classes.txt":
continue
img_path = os.path.join(dataset_dir.replace("labels", "images"),
file_name.split(".")[0]) + ".jpg"
line = img_path + "\n"
assert os.path.exists(img_path), "file:{} not exist!".format(img_path)
w.write(line)
def create_data_data(create_data_path, label_path, train_path, val_path, classes_info):
# create my_data.data file that record classes, train, valid and names info.
# shutil.copyfile(label_path, "./data/my_data_label.names")
with open(create_data_path, "w") as w:
w.write("classes={}".format(len(classes_info)) + "\n") # 记录类别个数
w.write("train={}".format(train_path) + "\n") # 记录训练集对应txt文件路径
w.write("valid={}".format(val_path) + "\n") # 记录验证集对应txt文件路径
w.write("names=data/my_data_label.names" + "\n") # 记录label.names文件路径
def change_and_create_cfg_file(classes_info, save_cfg_path="./cfg/my_yolov3.cfg"):
# create my_yolov3.cfg file changed predictor filters and yolo classes param.
# this operation only deal with yolov3-spp.cfg
filters_lines = [636, 722, 809]
classes_lines = [643, 729, 816]
# cfg_lines:列表里放着每一行的str(内容)
cfg_lines = open(cfg_path, "r").readlines()
for i in filters_lines:
assert "filters" in cfg_lines[i-1], "filters param is not in line:{}".format(i-1)
output_num = (5 + len(classes_info)) * 3
cfg_lines[i-1] = "filters={}\n".format(output_num)
for i in classes_lines:
assert "classes" in cfg_lines[i-1], "classes param is not in line:{}".format(i-1)
cfg_lines[i-1] = "classes={}\n".format(len(classes_info))
with open(save_cfg_path, "w") as w:
w.writelines(cfg_lines)
def main():
# 统计训练集和验证集的数据并生成相应txt文件
train_txt_path = "data/my_train_data.txt"
val_txt_path = "data/my_val_data.txt"
# 把训练与测试的图片路径写到对应的txt文件里
calculate_data_txt(train_txt_path, train_annotation_dir)
calculate_data_txt(val_txt_path, val_annotation_dir)
classes_info = [line.strip() for line in open(classes_label, "r").readlines() if len(line.strip()) > 0]
# 创建data.data文件,记录classes个数, train以及val数据集文件(.txt)路径和label.names文件路径
create_data_data("./data/my_data.data", classes_label, train_txt_path, val_txt_path, classes_info)
# 根据yolov3-spp.cfg创建my_yolov3.cfg文件修改其中的predictor filters以及yolo classes参数(这两个参数是根据类别数改变的)
change_and_create_cfg_file(classes_info)
if __name__ == '__main__':
main()
2 感谢链接
https://www.bilibili.com/video/BV1t54y1C7ra/?spm_id_from=trigger_reload
边栏推荐
猜你喜欢
随机推荐
C#操作FTP上传文件后检查上传正确性
【随笔】平常心
域名怎么管理,域名管理注意事项有哪些?
C#通过WebBrowser对网页截图
NIO知识汇总 收藏这一篇就够了!!!
Oracle 数据库集群常用巡检命令
在OracleLinux8.6的Zabbix6.0中监控Oracle11gR2
快速理解JVM+GC
使用Powershell批量导入Task
C#使用Oracle.ManagedDataAccess连接C#数据库
VI和VIM编辑指令
数据库OracleRAC节点宕机处理流程
JS--正则表达式
2. What is the difference between Exception and Error?
DNS常见资源记录类型详解
BurpSuite 进阶玩法
二分查找4 - 搜索旋转排序数组
3D建模:做什么副业在家就能月入1W?
C#切换输入法
域名管理常见问题:IP、域名和DNS之间的区别和关系