Object detection with Tiny YOLO in Deeplearning4j
2022-07-30 07:08:00 【victorkevin】
Tiny YOLO is a simplified version of YOLOv2. Although somewhat dated, it is still useful in many object detection scenarios. This example uses Deeplearning4j to build a YOLO model for object detection. The following figure shows the network structure of this example:

// input size and output grid size of the TinyYOLO architecture (416 / 32 = 13 cells per side)
int width = 416;
int height = 416;
int nChannels = 3;
int gridWidth = 13;
int gridHeight = 13;
// number of object classes in this dataset (the five entries in classes.txt)
int nClasses = 5;
// parameters for the Yolo2OutputLayer: prior (anchor) boxes as {width, height} in grid-cell units
double[][] priorBoxes = { { 1.5, 2.2 }, { 1.4, 1.95 }, { 1.8, 3.3 }, { 2.4, 2.9 }, { 1.7, 2.2 } };
// minimum confidence for a detection to be reported
double detectionThreshold = 0.8;
// parameters for the training phase
int batchSize = 10;
int nEpochs = 20;
int seed = 123;
Random rng = new Random(seed);
// folder holding the images together with their LabelImg .txt files and classes.txt
File imageDir = new File("D:\\train");
// random 90% / 10% split of the images into training and test sets
InputSplit[] data = new FileSplit(imageDir, NativeImageLoader.ALLOWED_FORMATS, rng).sample(null, 0.9, 0.1);
InputSplit trainData = data[0];
InputSplit testData = data[1];
ObjectDetectionRecordReader recordReaderTrain = new ObjectDetectionRecordReader(height, width, nChannels,
        gridHeight, gridWidth, new YoloLabelProvider(imageDir.getAbsolutePath()));
recordReaderTrain.initialize(trainData);
ObjectDetectionRecordReader recordReaderTest = new ObjectDetectionRecordReader(height, width, nChannels,
        gridHeight, gridWidth, new YoloLabelProvider(imageDir.getAbsolutePath()));
recordReaderTest.initialize(testData);
// ObjectDetectionRecordReader performs regression, so we need to specify it here
RecordReaderDataSetIterator train = new RecordReaderDataSetIterator(recordReaderTrain, batchSize, 1, 1, true);
// scale pixel values from [0, 255] to [0, 1]
train.setPreProcessor(new ImagePreProcessingScaler(0, 1));
RecordReaderDataSetIterator test = new RecordReaderDataSetIterator(recordReaderTest, 1, 1, 1, true);
test.setPreProcessor(new ImagePreProcessingScaler(0, 1));

ComputationGraph model;
String modelFilename = "D:\\model.zip";
if (new File(modelFilename).exists()) {
    // reuse a previously trained model
    System.out.println("Load model...");
    model = ComputationGraph.load(new File(modelFilename), true);
} else {
    System.out.println("Build model...");
    // TinyYOLO from the DL4J model zoo, configured with our classes and prior boxes and trained from scratch
    model = TinyYOLO.builder().numClasses(nClasses).priorBoxes(priorBoxes).build().init();
    System.out.println(model.summary(InputType.convolutional(height, width, nChannels)));
    System.out.println("Train model...");
    model.setListeners(new ScoreIterationListener(1));
    model.fit(train, nEpochs);
    ModelSerializer.writeModel(model, modelFilename, true);
}

// visualize results on the test set
NativeImageLoader imageLoader = new NativeImageLoader();
CanvasFrame frame = new CanvasFrame("WatermelonDetection");
OpenCVFrameConverter.ToMat converter = new OpenCVFrameConverter.ToMat();
org.deeplearning4j.nn.layers.objdetect.Yolo2OutputLayer yout =
        (org.deeplearning4j.nn.layers.objdetect.Yolo2OutputLayer) model.getOutputLayer(0);
List<String> labels = train.getLabels();
test.setCollectMetaData(true);
Scalar[] colormap = { RED, BLUE, GREEN, CYAN, YELLOW, MAGENTA, ORANGE, PINK, LIGHTBLUE, VIOLET };
while (test.hasNext() && frame.isVisible()) {
    org.nd4j.linalg.dataset.DataSet ds = test.next();
    RecordMetaDataImageURI metadata = (RecordMetaDataImageURI) ds.getExampleMetaData().get(0);
    INDArray features = ds.getFeatures();
    INDArray results = model.outputSingle(features);
    // decode the network output into boxes whose confidence is at least detectionThreshold
    List<DetectedObject> objs = yout.getPredictedObjects(results, detectionThreshold);
    // turn the [0, 1] feature array back into an 8-bit image at the original size
    Mat mat = imageLoader.asMat(features);
    Mat convertedMat = new Mat();
    mat.convertTo(convertedMat, CV_8U, 255, 0);
    int w = metadata.getOrigW();
    int h = metadata.getOrigH();
    Mat image = new Mat();
    resize(convertedMat, image, new Size(w, h));
    for (DetectedObject obj : objs) {
        // box corners are predicted in grid units; scale them to pixels of the original image
        double[] xy1 = obj.getTopLeftXY();
        double[] xy2 = obj.getBottomRightXY();
        String label = labels.get(obj.getPredictedClass());
        int x1 = (int) Math.round(w * xy1[0] / gridWidth);
        int y1 = (int) Math.round(h * xy1[1] / gridHeight);
        int x2 = (int) Math.round(w * xy2[0] / gridWidth);
        int y2 = (int) Math.round(h * xy2[1] / gridHeight);
        rectangle(image, new Point(x1, y1), new Point(x2, y2), colormap[obj.getPredictedClass()]);
        putText(image, label, new Point(x1 + 2, y2 - 2), FONT_HERSHEY_DUPLEX, 1, colormap[obj.getPredictedClass()]);
    }
    frame.setTitle(new File(metadata.getURI()).getName() + " - WatermelonDetection");
    frame.setCanvasSize(w, h);
    frame.showImage(converter.convert(image));
    frame.waitKey();
}
frame.dispose();
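In the drawing loop above, the corners returned by getTopLeftXY() and getBottomRightXY() are in grid units (0 to 13), which is why each one is multiplied by the original image size and divided by gridWidth or gridHeight. Below is a minimal sketch of just that arithmetic in plain Java; the class GridToPixel and its method toPixel are hypothetical helpers for illustration, not part of DL4J:

```java
public class GridToPixel {
    // Hypothetical helper: convert a coordinate predicted in grid units (0..gridSize)
    // into pixels of an image that is imageSize pixels along the same axis.
    static int toPixel(double gridCoord, int imageSize, int gridSize) {
        return (int) Math.round(imageSize * gridCoord / gridSize);
    }

    public static void main(String[] args) {
        // a box from (3.0, 4.5) to (5.0, 7.0) on a 13x13 grid, drawn on a 416x416 image
        int x1 = toPixel(3.0, 416, 13);
        int y1 = toPixel(4.5, 416, 13);
        int x2 = toPixel(5.0, 416, 13);
        int y2 = toPixel(7.0, 416, 13);
        System.out.println(x1 + "," + y1 + " -> " + x2 + "," + y2); // 96,144 -> 160,224
    }
}
```

With a 416x416 image one grid cell is 416 / 13 = 32 pixels, so grid coordinate 3.0 lands at pixel 96. For an original image of a different size, the same formula simply uses that image's width and height.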
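Parameter explanation
Image width and height: int width = 416; int height = 416; are fixed. Tiny YOLO downsamples its input by a factor of 32, so a 416x416 image maps exactly onto the 13x13 output grid (416 / 32 = 13).
Number of image channels: int nChannels = 3 for color images; a grayscale image uses nChannels = 1. The default is 3.
Output grid size: Tiny YOLO divides the image into a 13x13 grid of cells, and each cell predicts boxes. The value is 13 by default and cannot be changed: int gridWidth = 13; int gridHeight = 13;
Number of classes to detect: 5 in my example, int nClasses = 5;
Prior (anchor) boxes: double[][] priorBoxes = { { 1.5, 2.2 }, { 1.4, 1.95 }, { 1.8, 3.3 }, { 2.4, 2.9 }, { 1.7, 2.2 } }; each pair is a {width, height} in grid-cell units, which sets both the size and the aspect ratio of a prior box. In YOLOv2 the prior boxes are obtained by running k-means clustering over the bounding boxes in the training set. The code is as follows: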
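The 13x13 grid follows from the input size and the network's downsampling. A small sketch of that calculation, assuming Tiny YOLO here has five stride-2 max-pooling layers (as in the standard Tiny YOLOv2 design, where the sixth pooling layer uses stride 1); TinyYoloGrid and gridSize are hypothetical names for illustration:

```java
public class TinyYoloGrid {
    // Assumption: each stride-2 max-pooling layer halves the spatial size,
    // and Tiny YOLO has five of them, so the total downsampling is 2^5 = 32.
    static int gridSize(int inputSize, int stride2Layers) {
        int downsample = 1 << stride2Layers;
        if (inputSize % downsample != 0) {
            throw new IllegalArgumentException("input size must be a multiple of " + downsample);
        }
        return inputSize / downsample;
    }

    public static void main(String[] args) {
        int grid = gridSize(416, 5);
        System.out.println(grid);       // 13 cells per side
        System.out.println(416 / grid); // 32 pixels per grid cell
    }
}
```

This is also where the 32-pixel cell size used in the k-means code comes from.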
import org.apache.commons.math3.ml.clustering.CentroidCluster;
import org.apache.commons.math3.ml.clustering.KMeansPlusPlusClusterer;
import org.apache.commons.math3.ml.distance.DistanceMeasure;

// trainDir: the folder with the training images and their LabelImg .txt files (e.g. D:\train)
YoloLabelProvider svhnLabelProvider = new YoloLabelProvider(trainDir.getAbsolutePath());
// YoloIOUDistanceMeasure is the author's own DistanceMeasure (1 - IOU instead of Euclidean distance)
DistanceMeasure distanceMeasure = new YoloIOUDistanceMeasure();
// k-means++ with k = 5 clusters (one per prior box) and at most 15 iterations
KMeansPlusPlusClusterer<ImageObjectWrapper> clusterer = new KMeansPlusPlusClusterer<>(5, 15, distanceMeasure);
File[] pngFiles = trainDir.listFiles((dir, name) -> name.endsWith(".png"));
// every labeled box in the training set becomes one {width, height} point for clustering
List<ImageObjectWrapper> clusterInput = Stream.of(pngFiles)
        .flatMap(png -> svhnLabelProvider.getImageObjectsForPath(png.getName()).stream())
        .map(ImageObjectWrapper::new)
        .filter(wrapper -> {
            double[] point = wrapper.getPoint();
            // boxes smaller than one 32x32 grid cell in both dimensions are not counted
            return !(point[0] <= 32d && point[1] <= 32d);
        })
        .collect(Collectors.toList());
List<CentroidCluster<ImageObjectWrapper>> clusterResults = clusterer.cluster(clusterInput);
for (CentroidCluster<ImageObjectWrapper> centroidCluster : clusterResults) {
    // cluster center: average box width and height in pixels
    double[] point = centroidCluster.getCenter().getPoint();
    System.out.println("width:" + point[0] + " height:" + point[1] + " ratio:" + point[1] / point[0]
            + " size:" + centroidCluster.getPoints().size());
    // one grid cell is 416 / 13 = 32 pixels, so this is the prior box in grid units (a priorBoxes entry)
    System.out.println("prior box (grid units):" + point[0] / 32 + "," + point[1] / 32);
    ImageObjectWrapper maxWidthImage = centroidCluster.getPoints().stream()
            .max(Comparator.comparingDouble(ImageObjectWrapper::getWidth)).get();
    ImageObjectWrapper maxHeightImage = centroidCluster.getPoints().stream()
            .max(Comparator.comparingDouble(ImageObjectWrapper::getHeight)).get();
    System.out.println(" max width:" + maxWidthImage.getWidth() + " max height:" + maxHeightImage.getHeight());
    System.out.println("-----------");
}
The code above uses k-means to extract representative box sizes (and hence aspect ratios) from the training samples. The distance measure used by k-means has to be rewritten to use IOU instead of Euclidean distance; for details, see the article "YOLO v2 target detection explained, part 2: computing IOU" on 灰信网 (软件开发博客聚合, a software development blog aggregator).
detectionThreshold is the confidence threshold for detections: the higher the value, the fewer objects are detected and the higher the precision of those that remain.
My training set was annotated with LabelImg and saved in YOLO format. Some training samples are shown below. Note that the image size should match the 416x416 input size set in the parameters.
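The author's YoloIOUDistanceMeasure is not shown, so here is a sketch of what an IOU-based distance for anchor clustering typically looks like: each box is reduced to {width, height}, both boxes are assumed to share the same center, and the distance is 1 - IOU. The class IouDistance is a hypothetical name; to plug the same logic into Apache Commons Math you would implement DistanceMeasure, whose method is double compute(double[] a, double[] b).

```java
public class IouDistance {
    // 1 - IOU of two boxes given only as {width, height},
    // assuming both boxes are centered on the same point.
    static double iouDistance(double[] a, double[] b) {
        double intersection = Math.min(a[0], b[0]) * Math.min(a[1], b[1]);
        double union = a[0] * a[1] + b[0] * b[1] - intersection;
        return 1.0 - intersection / union;
    }

    public static void main(String[] args) {
        System.out.println(iouDistance(new double[] {64, 64}, new double[] {64, 64})); // 0.0
        System.out.println(iouDistance(new double[] {64, 64}, new double[] {32, 64})); // 0.5
    }
}
```

Unlike Euclidean distance, this measure does not penalize large boxes more than small ones just because their width and height values are bigger, which is the usual reason given for using IOU when clustering anchors.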
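In the main code, Yolo2OutputLayer.getPredictedObjects(results, detectionThreshold) applies this threshold internally. The sketch below only illustrates the effect of raising the threshold on a made-up list of detections; the Detection class and the filter method are hypothetical and are not DL4J's implementation:

```java
import java.util.List;
import java.util.stream.Collectors;

public class ThresholdDemo {
    // Hypothetical stand-in for a detection: predicted class index plus confidence score.
    static final class Detection {
        final int predictedClass;
        final double confidence;

        Detection(int predictedClass, double confidence) {
            this.predictedClass = predictedClass;
            this.confidence = confidence;
        }
    }

    // Keep only detections whose confidence reaches the threshold.
    static List<Detection> filter(List<Detection> detections, double threshold) {
        return detections.stream()
                .filter(d -> d.confidence >= threshold)
                .collect(Collectors.toList());
    }

    public static void main(String[] args) {
        List<Detection> raw = List.of(
                new Detection(0, 0.95), new Detection(1, 0.82),
                new Detection(2, 0.60), new Detection(0, 0.30));
        System.out.println(filter(raw, 0.5).size()); // 3
        System.out.println(filter(raw, 0.8).size()); // 2
    }
}
```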
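In the YOLO format written by LabelImg, each image has a .txt file with one line per box: class index, box center x, box center y, box width, box height, where the last four values are normalized to [0, 1] relative to the image size. A minimal sketch of turning one such line into pixel corners, which is the same conversion the label provider performs; the class YoloLine and method toPixelBox are hypothetical names:

```java
import java.util.Arrays;

public class YoloLine {
    // Parse "class xCenter yCenter width height" (normalized to 0..1) into
    // {classId, xmin, ymin, xmax, ymax} in pixels for an imgW x imgH image.
    static int[] toPixelBox(String line, int imgW, int imgH) {
        String[] f = line.trim().split("\\s+");
        int classId = Integer.parseInt(f[0]);
        double cx = Double.parseDouble(f[1]) * imgW;
        double cy = Double.parseDouble(f[2]) * imgH;
        double bw = Double.parseDouble(f[3]) * imgW;
        double bh = Double.parseDouble(f[4]) * imgH;
        return new int[] {
            classId,
            (int) Math.round(cx - bw / 2), (int) Math.round(cy - bh / 2),
            (int) Math.round(cx + bw / 2), (int) Math.round(cy + bh / 2)
        };
    }

    public static void main(String[] args) {
        int[] box = toPixelBox("1 0.5 0.5 0.25 0.5", 416, 416);
        System.out.println(Arrays.toString(box)); // [1, 156, 104, 260, 312]
    }
}
```

The class index (1 here) is a line number into classes.txt, so with this dataset it would refer to "cake".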



The class list file is classes.txt, which contains five categories: xi, cake, dan, ss, bi.
The label provider class YoloLabelProvider is shown below. Its main job is to convert the txt annotations produced by LabelImg into a form the algorithm can use.
import java.awt.image.BufferedImage;
import java.io.File;
import java.io.IOException;
import java.io.UncheckedIOException;
import java.net.URI;
import java.nio.file.Files;
import java.util.List;
import java.util.Objects;
import java.util.stream.Collectors;
import javax.imageio.ImageIO;
import org.apache.commons.io.FilenameUtils;
import org.datavec.image.recordreader.objdetect.ImageObject;
import org.datavec.image.recordreader.objdetect.ImageObjectLabelProvider;

public class YoloLabelProvider implements ImageObjectLabelProvider {
    private final String baseDirectory;
    // class names from classes.txt; line i is the name of class index i in the label files
    private final List<String> labels;

    public YoloLabelProvider(String baseDirectory) {
        this.baseDirectory = Objects.requireNonNull(baseDirectory, "The label directory cannot be empty");
        if (!new File(baseDirectory).exists()) {
            throw new IllegalStateException(
                    "baseDirectory directory does not exist. txt files should be present at expected location: " + baseDirectory);
        }
        File classFile = new File(FilenameUtils.concat(this.baseDirectory, "classes.txt"));
        if (!classFile.exists()) {
            throw new IllegalStateException("classes.txt does not exist in " + baseDirectory);
        }
        try {
            labels = Files.readAllLines(classFile.toPath());
        } catch (IOException e) {
            throw new UncheckedIOException(e);
        }
    }

    @Override
    public List<ImageObject> getImageObjectsForPath(String path) {
        // "D:\\train\\a.png" -> base name "a", file name "a.png" (works for any image extension)
        String baseName = FilenameUtils.getBaseName(path);
        String txtPath = FilenameUtils.concat(this.baseDirectory, baseName + ".txt");
        String imagePath = FilenameUtils.concat(this.baseDirectory, FilenameUtils.getName(path));
        File txtFile = new File(txtPath);
        if (!txtFile.exists()) {
            throw new IllegalStateException("Could not find TXT file for image " + path + "; expected at " + txtPath);
        }
        List<String> lines;
        BufferedImage image;
        try {
            image = ImageIO.read(new File(imagePath));
            lines = Files.readAllLines(txtFile.toPath());
        } catch (IOException e) {
            throw new UncheckedIOException(e);
        }
        // YOLO coordinates are normalized to [0, 1]; multiply by the image size to get pixels
        int width = image.getWidth();
        int height = image.getHeight();
        return lines.stream()
                .filter(line -> !line.trim().isEmpty()) // skip blank lines
                .map(line -> {
                    // line format: classIndex xCenter yCenter boxWidth boxHeight
                    String[] data = line.trim().split("\\s+");
                    int centerX = Math.round(Float.parseFloat(data[1]) * width);
                    int centerY = Math.round(Float.parseFloat(data[2]) * height);
                    int bboxWidth = Math.round(Float.parseFloat(data[3]) * width);
                    int bboxHeight = Math.round(Float.parseFloat(data[4]) * height);
                    int xmin = centerX - (bboxWidth / 2);
                    int ymin = centerY - (bboxHeight / 2);
                    int xmax = centerX + (bboxWidth / 2);
                    int ymax = centerY + (bboxHeight / 2);
                    return new ImageObject(xmin, ymin, xmax, ymax, labels.get(Integer.parseInt(data[0])));
                })
                .collect(Collectors.toList());
    }

    @Override
    public List<ImageObject> getImageObjectsForPath(URI uri) {
        return getImageObjectsForPath(uri.toString());
    }
}

The first training run took about 4 hours on 300-odd images; the results are as follows: