Deeplearning4j Deep Learning: Object Detection with Tiny YOLO
2022-07-30 05:44:00 【victorkevin】
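Tiny YOLO is a simplified version of YOLOv2. It is somewhat dated, but it still works well for many object detection scenarios. This example uses Deeplearning4j to build a YOLO model for object detection; the figure below shows the network structure used here: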

// parameters matching the pretrained TinyYOLO model
int width = 416;
int height = 416;
int nChannels = 3;
int gridWidth = 13;
int gridHeight = 13;
// number of classes to detect in this dataset (see classes.txt)
int nClasses = 5;
// parameters for the Yolo2OutputLayer
double[][] priorBoxes = { { 1.5, 2.2 }, { 1.4, 1.95 }, { 1.8, 3.3 }, { 2.4, 2.9 }, { 1.7, 2.2 } };
double detectionThreshold = 0.8;
// parameters for the training phase
int batchSize = 10;
int nEpochs = 20;
int seed = 123;
Random rng = new Random(seed);
File imageDir = new File("D:\\train");
InputSplit[] data = new FileSplit(imageDir, NativeImageLoader.ALLOWED_FORMATS, rng).sample(null, 0.9, 0.1);
InputSplit trainData = data[0];
InputSplit testData = data[1];
ObjectDetectionRecordReader recordReaderTrain = new ObjectDetectionRecordReader(height, width, nChannels, gridHeight, gridWidth,
new YoloLabelProvider(imageDir.getAbsolutePath()));
recordReaderTrain.initialize(trainData);
ObjectDetectionRecordReader recordReaderTest = new ObjectDetectionRecordReader(height, width, nChannels, gridHeight, gridWidth,
new YoloLabelProvider(imageDir.getAbsolutePath()));
recordReaderTest.initialize(testData);
// ObjectDetectionRecordReader performs regression, so we need to specify it here
RecordReaderDataSetIterator train = new RecordReaderDataSetIterator(recordReaderTrain, batchSize, 1, 1, true);
train.setPreProcessor(new ImagePreProcessingScaler(0, 1));
RecordReaderDataSetIterator test = new RecordReaderDataSetIterator(recordReaderTest, 1, 1, 1, true);
test.setPreProcessor(new ImagePreProcessingScaler(0, 1));
ComputationGraph model;
String modelFilename = "D:\\model.zip";
if (new File(modelFilename).exists()) {
System.out.println("Load model...");
model = ComputationGraph.load(new File(modelFilename), true);
} else {
System.out.println("Build model...");
// a fresh TinyYOLO with nClasses outputs and the custom prior boxes (no pretrained weights)
model = TinyYOLO.builder().numClasses(nClasses).priorBoxes(priorBoxes).build().init();
System.out.println(model.summary(InputType.convolutional(height, width, nChannels)));
System.out.println("Train model...");
model.setListeners(new ScoreIterationListener(1));
model.fit(train, nEpochs);
// save the trained model, including the updater state so training can be resumed
ModelSerializer.writeModel(model, modelFilename, true);
}
// visualize results on the test set
NativeImageLoader imageLoader = new NativeImageLoader();
CanvasFrame frame = new CanvasFrame("WatermelonDetection");
OpenCVFrameConverter.ToMat converter = new OpenCVFrameConverter.ToMat();
org.deeplearning4j.nn.layers.objdetect.Yolo2OutputLayer yout = (org.deeplearning4j.nn.layers.objdetect.Yolo2OutputLayer) model
.getOutputLayer(0);
List<String> labels = train.getLabels();
test.setCollectMetaData(true);
Scalar[] colormap = { RED, BLUE, GREEN, CYAN, YELLOW, MAGENTA, ORANGE, PINK, LIGHTBLUE, VIOLET };
while (test.hasNext() && frame.isVisible()) {
org.nd4j.linalg.dataset.DataSet ds = test.next();
RecordMetaDataImageURI metadata = (RecordMetaDataImageURI) ds.getExampleMetaData().get(0);
INDArray features = ds.getFeatures();
INDArray results = model.outputSingle(features);
List<DetectedObject> objs = yout.getPredictedObjects(results, detectionThreshold);
File file = new File(metadata.getURI());
Mat mat = imageLoader.asMat(features);
Mat convertedMat = new Mat();
mat.convertTo(convertedMat, CV_8U, 255, 0);
int w = metadata.getOrigW();
int h = metadata.getOrigH();
Mat image = new Mat();
resize(convertedMat, image, new Size(w, h));
for (DetectedObject obj : objs) {
double[] xy1 = obj.getTopLeftXY();
double[] xy2 = obj.getBottomRightXY();
String label = labels.get(obj.getPredictedClass());
int x1 = (int) Math.round(w * xy1[0] / gridWidth);
int y1 = (int) Math.round(h * xy1[1] / gridHeight);
int x2 = (int) Math.round(w * xy2[0] / gridWidth);
int y2 = (int) Math.round(h * xy2[1] / gridHeight);
rectangle(image, new Point(x1, y1), new Point(x2, y2), colormap[obj.getPredictedClass()]);
putText(image, label, new Point(x1 + 2, y2 - 2), FONT_HERSHEY_DUPLEX, 1, colormap[obj.getPredictedClass()]);
}
frame.setTitle(file.getName() + " - WatermelonDetection");
frame.setCanvasSize(w, h);
frame.showImage(converter.convert(image));
frame.waitKey();
}
frame.dispose();
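The drawing loop above relies on DetectedObject reporting box corners in grid-cell units (0 to 13 here), which is why each coordinate is multiplied by the original image size and divided by the grid size. A minimal sketch of just that conversion, assuming the same 13 × 13 grid; GridToPixel and its methods are hypothetical helpers, not part of DL4J.

```java
/**
 * Sketch of the grid-to-pixel conversion used when drawing boxes.
 * Hypothetical helper class, not DL4J API.
 */
final class GridToPixel {

    /** Converts one coordinate from grid units to pixels, rounded to the nearest pixel. */
    static int toPixels(double gridCoord, int imageSize, int gridSize) {
        return (int) Math.round(imageSize * gridCoord / gridSize);
    }

    /** Converts grid corners {x1, y1} and {x2, y2} into a pixel rectangle {x1, y1, x2, y2}. */
    static int[] toPixelBox(double[] topLeft, double[] bottomRight,
                            int imageWidth, int imageHeight, int gridWidth, int gridHeight) {
        return new int[] {
            toPixels(topLeft[0], imageWidth, gridWidth),
            toPixels(topLeft[1], imageHeight, gridHeight),
            toPixels(bottomRight[0], imageWidth, gridWidth),
            toPixels(bottomRight[1], imageHeight, gridHeight)
        };
    }
}
```

For example, a box whose top-left corner is at grid (3.25, 6.5) on a 416 × 416 image starts at pixel (104, 208).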
Parameter notes
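Image width and height: int width = 416; int height = 416; these are fixed (416 × 416 is the Tiny YOLO input size).
Number of image channels: int nChannels = 3 for colour images, nChannels = 1 for grayscale; the default is 3.
Output grid size: for Tiny YOLO the grid is 13 × 13 (the network downsamples by a factor of 32, and 416 / 32 = 13) and cannot be changed: int gridWidth = 13; int gridHeight = 13;
Number of classes to detect; this example has 5: int nClasses = 5
Width and height of the prior (anchor) boxes, in grid-cell units: double[][] priorBoxes = { { 1.5, 2.2 }, { 1.4, 1.95 }, { 1.8, 3.3 }, { 2.4, 2.9 }, { 1.7, 2.2 } }; In YOLOv2 the prior boxes are obtained by running k-means over the training boxes. The code is as follows: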
// trainDir is the training directory holding the .png images and the LabelImg .txt files.
// ImageObjectWrapper (a Clusterable whose getPoint() returns {width, height} in pixels) and
// YoloIOUDistanceMeasure (a 1 - IoU distance) are the author's own classes and are not shown here.
YoloLabelProvider svhnLabelProvider = new YoloLabelProvider(trainDir.getAbsolutePath());
DistanceMeasure distanceMeasure = new YoloIOUDistanceMeasure();
// Commons Math k-means++: 5 clusters (one per prior box), at most 15 iterations
KMeansPlusPlusClusterer<ImageObjectWrapper> clusterer = new KMeansPlusPlusClusterer<>(5, 15, distanceMeasure);
File[] pngFiles = trainDir.listFiles((dir, name) -> name.endsWith(".png"));
List<ImageObjectWrapper> clusterInput = Stream.of(pngFiles)
        .flatMap(png -> svhnLabelProvider.getImageObjectsForPath(png.getName()).stream())
        .map(ImageObjectWrapper::new)
        .filter(wrapper -> {
            double[] point = wrapper.getPoint();
            // skip boxes smaller than one grid cell (32 px) in both dimensions
            return !(point[0] <= 32d && point[1] <= 32d);
        })
        .collect(Collectors.toList());
List<CentroidCluster<ImageObjectWrapper>> clusterResults = clusterer.cluster(clusterInput);
for (CentroidCluster<ImageObjectWrapper> centroidCluster : clusterResults) {
    double[] point = centroidCluster.getCenter().getPoint();
    System.out.println("width:" + point[0] + " height:" + point[1] + " ratio:" + point[1] / point[0]
            + " size:" + centroidCluster.getPoints().size());
    // centroid size in grid cells (416 / 13 = 32 px per cell): the candidate priorBoxes values
    System.out.println("bbox in grid cells:" + point[0] / 32 + "," + point[1] / 32);
    ImageObjectWrapper maxWidthImage = centroidCluster.getPoints().stream()
            .max(Comparator.comparingDouble(ImageObjectWrapper::getWidth)).get();
    ImageObjectWrapper maxHeightImage = centroidCluster.getPoints().stream()
            .max(Comparator.comparingDouble(ImageObjectWrapper::getHeight)).get();
    System.out.println("  max width:" + maxWidthImage.getWidth() + " max height:" + maxHeightImage.getHeight());
    System.out.println("-----------");
}
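The code above uses k-means to find representative box widths and heights in the training samples. The default k-means distance measure has to be replaced with an IoU-based one; for details see "YOLO v2目标检测详解二 计算iou" on 灰信网 (a software development blog aggregator).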
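To make the IoU-based distance concrete, here is a minimal, dependency-free sketch of anchor clustering with d = 1 - IoU, where two boxes are compared as if they shared the same centre, so only width and height matter. It is not the author's YoloIOUDistanceMeasure (which is not shown); AnchorKMeans is a hypothetical name, and it seeds the centroids with the first k boxes for reproducibility instead of the k-means++ seeding that Commons Math's KMeansPlusPlusClusterer uses. The body of iouDistance is what a Commons Math DistanceMeasure's compute(double[] a, double[] b) would return.

```java
import java.util.Arrays;
import java.util.List;

/**
 * Sketch of YOLOv2-style anchor clustering: k-means over {width, height}
 * pairs with distance 1 - IoU. Hypothetical class, not the author's code.
 */
final class AnchorKMeans {

    /** 1 - IoU of two boxes given as {width, height}, aligned at a common centre. */
    static double iouDistance(double[] a, double[] b) {
        double inter = Math.min(a[0], b[0]) * Math.min(a[1], b[1]);
        double union = a[0] * a[1] + b[0] * b[1] - inter;
        return 1.0 - inter / union;
    }

    /** Lloyd's k-means; the first k boxes seed the centroids (assumes boxes.size() >= k). */
    static double[][] cluster(List<double[]> boxes, int k, int maxIterations) {
        double[][] centroids = new double[k][];
        for (int i = 0; i < k; i++) {
            centroids[i] = boxes.get(i).clone();
        }
        int[] assignment = new int[boxes.size()];
        Arrays.fill(assignment, -1);
        for (int iter = 0; iter < maxIterations; iter++) {
            boolean changed = false;
            // Assignment step: each box joins the centroid with the smallest 1 - IoU.
            for (int p = 0; p < boxes.size(); p++) {
                int best = 0;
                for (int c = 1; c < k; c++) {
                    if (iouDistance(boxes.get(p), centroids[c]) < iouDistance(boxes.get(p), centroids[best])) {
                        best = c;
                    }
                }
                if (assignment[p] != best) {
                    assignment[p] = best;
                    changed = true;
                }
            }
            if (!changed) {
                break; // converged
            }
            // Update step: each centroid becomes the mean width and height of its members.
            for (int c = 0; c < k; c++) {
                double sumW = 0;
                double sumH = 0;
                int n = 0;
                for (int p = 0; p < boxes.size(); p++) {
                    if (assignment[p] == c) {
                        sumW += boxes.get(p)[0];
                        sumH += boxes.get(p)[1];
                        n++;
                    }
                }
                if (n > 0) {
                    centroids[c] = new double[] { sumW / n, sumH / n };
                }
            }
        }
        return centroids;
    }
}
```

Dividing the resulting pixel centroids by 32 (416 / 13) gives grid-cell values of the kind priorBoxes expects; a centroid of 105 × 52.5 px, for instance, becomes { 3.28125, 1.640625 }.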
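detectionThreshold is the confidence threshold for detections: the higher it is, the fewer objects are reported and the more likely each reported object is correct (higher precision, at the cost of recall).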
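As a rough illustration of that trade-off, the sketch below keeps only detections whose confidence is at least the threshold, which is roughly the filtering getPredictedObjects(results, detectionThreshold) applies to the network output. Detection is a hypothetical, simplified stand-in for DL4J's DetectedObject.

```java
import java.util.ArrayList;
import java.util.List;

/** Sketch of confidence thresholding; hypothetical types, not DL4J API. */
final class ConfidenceFilter {

    /** Simplified detection: a predicted class index and its confidence score. */
    static final class Detection {
        final int predictedClass;
        final double confidence;

        Detection(int predictedClass, double confidence) {
            this.predictedClass = predictedClass;
            this.confidence = confidence;
        }
    }

    /** Keeps detections whose confidence is >= threshold; a higher threshold keeps fewer. */
    static List<Detection> filter(List<Detection> detections, double threshold) {
        List<Detection> kept = new ArrayList<>();
        for (Detection d : detections) {
            if (d.confidence >= threshold) {
                kept.add(d);
            }
        }
        return kept;
    }
}
```

With confidences 0.95, 0.85, 0.6 and 0.3, a threshold of 0.5 keeps three detections, 0.8 keeps two and 0.9 keeps one.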
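My training set was annotated with LabelImg in YOLO format. Sample training images are shown below; note that the image size must match the 416 × 416 input parameters.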



The class label file is classes.txt, which lists five classes: xi, cake, dan, ss, bi.
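Each LabelImg .txt file in YOLO format holds one line per object: the class index (0-based, into classes.txt), then x centre, y centre, width and height, all normalised to [0, 1] by the image size. The sketch below shows how one such line becomes a labelled pixel box, the same conversion YoloLabelProvider performs; YoloLine is a hypothetical name, and unlike the provider it computes the corners in normalised units before rounding.

```java
import java.util.List;

/** Sketch of converting one YOLO-format label line into a pixel box. Hypothetical helper. */
final class YoloLine {
    final String label;
    final int xMin;
    final int yMin;
    final int xMax;
    final int yMax;

    private YoloLine(String label, int xMin, int yMin, int xMax, int yMax) {
        this.label = label;
        this.xMin = xMin;
        this.yMin = yMin;
        this.xMax = xMax;
        this.yMax = yMax;
    }

    /** Parses "classIndex xCenter yCenter width height" (normalised to [0, 1]). */
    static YoloLine parse(String line, List<String> classes, int imageWidth, int imageHeight) {
        String[] f = line.trim().split("\\s+");
        int classIndex = Integer.parseInt(f[0]); // 0-based index into classes.txt
        double cx = Double.parseDouble(f[1]);
        double cy = Double.parseDouble(f[2]);
        double w = Double.parseDouble(f[3]);
        double h = Double.parseDouble(f[4]);
        // Corners in normalised units, scaled to pixels once and then rounded,
        // so halving the box size is not affected by integer division.
        return new YoloLine(classes.get(classIndex),
                (int) Math.round((cx - w / 2) * imageWidth),
                (int) Math.round((cy - h / 2) * imageHeight),
                (int) Math.round((cx + w / 2) * imageWidth),
                (int) Math.round((cy + h / 2) * imageHeight));
    }
}
```

For example, the line "1 0.5 0.5 0.25 0.5" on a 416 × 416 image maps to cake, from (156, 104) to (260, 312).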
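The label provider class YoloLabelProvider is shown below. Its job is to convert the .txt annotations produced by LabelImg into ImageObject instances that the record reader can use: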
public class YoloLabelProvider implements ImageObjectLabelProvider {
private String baseDirectory;
private List<String> labels;
public YoloLabelProvider(String baseDirectory) {
this.baseDirectory = baseDirectory;
Assert.notNull(baseDirectory, "label directory must not be null");
if (!new File(baseDirectory).exists()) {
throw new IllegalStateException(
"baseDirectory directory does not exist. txt files should be " + "present at Expected location: " + baseDirectory);
}
String classTxtPath = FilenameUtils.concat(this.baseDirectory, "classes.txt");
File classFile = new File(classTxtPath);
Assert.isTrue(classFile.exists(), "classTxtPath does not exist");
try {
labels = Files.readAllLines(classFile.toPath());
} catch (Exception e) {
throw new RuntimeException(e);
}
}
@Override
public List<ImageObject> getImageObjectsForPath(String path) {
int idx = path.lastIndexOf('/');
idx = Math.max(idx, path.lastIndexOf('\\'));
String filename = path.substring(idx + 1, path.length() - 4); //-4: ".png"
String txtPath = FilenameUtils.concat(this.baseDirectory, filename + ".txt");
String pngPath = FilenameUtils.concat(this.baseDirectory, filename + ".png");
File txtFile = new File(txtPath);
if (!txtFile.exists()) {
throw new IllegalStateException("Could not find TXT file for image " + path + "; expected at " + txtPath);
}
List<String> readAllLines = null;
BufferedImage image = null;
try {
image = ImageIO.read(Paths.get(pngPath).toFile());
readAllLines = Files.readAllLines(txtFile.toPath());
} catch (Exception e) {
throw new RuntimeException(e);
}
int width = image.getWidth();
int height = image.getHeight();
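List<ImageObject> imageObjects = readAllLines.stream()
.filter(line -> !line.trim().isEmpty()) // skip blank lines, which would otherwise cause an index error
.map(line -> {
String[] data = StringUtils.split(line, " ");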
int centerX = Math.round(Float.valueOf(data[1]) * width);
int centerY = Math.round(Float.valueOf(data[2]) * height);
int bboxWidth = Math.round(Float.valueOf(data[3]) * width);
int bboxHeight = Math.round(Float.valueOf(data[4]) * height);
int xmin = centerX - (bboxWidth / 2);
int ymin = centerY - (bboxHeight / 2);
int xmax = centerX + (bboxWidth / 2);
int ymax = centerY + (bboxHeight / 2);
ImageObject imageObject = new ImageObject(xmin, ymin, xmax, ymax, this.labels.get(Integer.valueOf(data[0])));
return imageObject;
}).collect(Collectors.toList());
return imageObjects;
}
@Override
public List<ImageObject> getImageObjectsForPath(URI uri) {
return getImageObjectsForPath(uri.toString());
}
}
Training on a little over 300 images took about 4 hours; the results are as follows: