Adding confidence threshold for demo visualization in detectron2
2022-06-26 09:10:00 【G fruit】
Preface
I have been running experiments with detectron2, the deep learning library developed by Facebook. When I visualized the detection results, the image was cluttered with boxes, and many low-confidence detections were drawn as well, which looked awful. So I added a confidence threshold (score_threshold) to filter the detection boxes, which makes the visualization cleaner and more flexible.
Demo before the change
Demo after adding the confidence threshold

Complete code of demo.py
import os
import random

import cv2
from detectron2.config import get_cfg
from detectron2.data import MetadataCatalog, DatasetCatalog
from detectron2.engine.defaults import DefaultPredictor
from detectron2.utils.visualizer import ColorMode, Visualizer
from swint.config import add_swint_config

import pig_dataset  # the author's own module; importing it registers "pig_coco_test"

pig_test_metadata = MetadataCatalog.get("pig_coco_test")
dataset_dicts = DatasetCatalog.get("pig_coco_test")

cfg = get_cfg()
add_swint_config(cfg)
cfg.merge_from_file("./configs/SwinT/retinanet_swint_T_FPN_3x.yaml")
cfg.MODEL.WEIGHTS = os.path.join(cfg.OUTPUT_DIR, "model_0019999.pth")
predictor = DefaultPredictor(cfg)

# Visualize predictions on three randomly chosen test images
for d in random.sample(dataset_dicts, 3):
    im = cv2.imread(d["file_name"])
    output = predictor(im)
    # The Visualizer expects RGB, while OpenCV loads BGR, hence the [:, :, ::-1]
    v = Visualizer(im[:, :, ::-1], metadata=pig_test_metadata,
                   scale=0.5, instance_mode=ColorMode.IMAGE_BW)
    # Draw the boxes. The second argument is the new confidence threshold;
    # it only works with the modified draw_instance_predictions shown below.
    out = v.draw_instance_predictions(output["instances"].to("cpu"), 0.5)  # threshold = 0.5
    cv2.namedWindow("pig", 0)
    cv2.resizeWindow("pig", 600, 400)
    cv2.imshow("pig", out.get_image()[:, :, ::-1])
    # cv2.imwrite("demo-%s" % os.path.basename(d["file_name"]), out.get_image()[:, :, ::-1])
    cv2.waitKey(3000)
cv2.destroyAllWindows()
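Code that applies the confidence threshold (score_threshold)
if score_threshold is not None:
    # Indices of the detections whose score exceeds the threshold
    top_id = np.where(scores.numpy() > score_threshold)[0].tolist()
    # Keep only those detections. Note that only boxes, scores, classes and
    # labels are filtered here; masks and keypoints are left untouched.
    scores = torch.tensor(scores.numpy()[top_id])
    boxes.tensor = torch.tensor(boxes.tensor.numpy()[top_id])
    classes = [classes[ii] for ii in top_id]
    labels = [labels[ii] for ii in top_id]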
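To make the filtering step easier to follow, here is a minimal sketch of the same idea on plain Python lists, so it runs without detectron2, NumPy or PyTorch. The function name filter_by_score and the toy data are made up for illustration; they are not part of detectron2.

```python
def filter_by_score(boxes, scores, classes, labels, score_threshold=None):
    """Keep only the detections whose score is strictly above the threshold.

    All four inputs are parallel lists: entry i of each describes detection i.
    """
    if score_threshold is None:
        # No threshold given: return everything unchanged
        return boxes, scores, classes, labels
    # Same role as np.where(scores > score_threshold)[0] in the snippet above
    top_id = [i for i, s in enumerate(scores) if s > score_threshold]
    return (
        [boxes[i] for i in top_id],
        [scores[i] for i in top_id],
        [classes[i] for i in top_id],
        [labels[i] for i in top_id],
    )


# Toy data (hypothetical): three detections, one of them below the threshold
boxes = [[10, 10, 50, 50], [20, 30, 80, 90], [0, 0, 5, 5]]
scores = [0.92, 0.31, 0.77]
classes = [0, 0, 0]
labels = ["pig 92%", "pig 31%", "pig 77%"]

kept_boxes, kept_scores, kept_classes, kept_labels = filter_by_score(
    boxes, scores, classes, labels, score_threshold=0.5
)
print(kept_labels)  # ['pig 92%', 'pig 77%']
```

The important point is that every per-detection list is indexed with the same top_id, so box i, score i, class i and label i still belong together after filtering.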
Complete code of draw_instance_predictions (after modification)
Just copy and paste the code above immediately before the line if predictions.has("pred_masks"):
The function is located in detectron2/utils/visualizer.py (in PyCharm, Ctrl+B jumps straight to it).
def draw_instance_predictions(self, predictions, score_threshold=None):
    """
    Draw instance-level prediction results on an image.

    Args:
        score_threshold: confidence threshold (new parameter). Detections
            with a score at or below it are not drawn. None disables filtering.
        predictions (Instances): the output of an instance detection/segmentation
            model. Following fields will be used to draw:
            "pred_boxes", "pred_classes", "scores", "pred_masks" (or "pred_masks_rle").

    Returns:
        output (VisImage): image object with visualizations.
    """
    boxes = predictions.pred_boxes if predictions.has("pred_boxes") else None
    scores = predictions.scores if predictions.has("scores") else None
    classes = predictions.pred_classes if predictions.has("pred_classes") else None
    labels = _create_text_labels(classes, scores, self.metadata.get("thing_classes", None))
    keypoints = predictions.pred_keypoints if predictions.has("pred_keypoints") else None

    # New code: keep only the detections whose score exceeds the threshold.
    # Masks and keypoints are not filtered here.
    if score_threshold is not None:
        top_id = np.where(scores.numpy() > score_threshold)[0].tolist()
        scores = torch.tensor(scores.numpy()[top_id])
        boxes.tensor = torch.tensor(boxes.tensor.numpy()[top_id])
        classes = [classes[ii] for ii in top_id]
        labels = [labels[ii] for ii in top_id]

    if predictions.has("pred_masks"):
        masks = np.asarray(predictions.pred_masks)
        masks = [GenericMask(x, self.output.height, self.output.width) for x in masks]
    else:
        masks = None

    if self._instance_mode == ColorMode.SEGMENTATION and self.metadata.get("thing_colors"):
        colors = [
            self._jitter([x / 255 for x in self.metadata.thing_colors[c]]) for c in classes
        ]
        alpha = 0.8
    else:
        colors = None
        alpha = 0.5

    if self._instance_mode == ColorMode.IMAGE_BW:
        self.output.img = self._create_grayscale_image(
            (predictions.pred_masks.any(dim=0) > 0).numpy()
            if predictions.has("pred_masks")
            else None
        )
        alpha = 0.3

    self.overlay_instances(
        masks=masks,
        boxes=boxes,
        labels=labels,
        keypoints=keypoints,
        assigned_colors=colors,
        alpha=alpha,
    )
    return self.output
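If you would rather not edit the installed library file, a possible alternative is to filter the predictions before handing them to the visualizer. The sketch below assumes that the object passed in supports boolean-mask indexing the way detectron2's Instances does, where the index is applied to every field. The helper name filter_instances is hypothetical and not part of detectron2.

```python
def filter_instances(instances, score_threshold):
    """Return only the detections whose score exceeds score_threshold.

    Assumption: `instances` exposes a `scores` field that can be compared
    against a float to give a boolean mask, and `instances[mask]` applies
    that mask to every field (boxes, classes, masks, keypoints, ...).
    """
    keep = instances.scores > score_threshold
    return instances[keep]


# Possible usage in demo.py with the unmodified visualizer (not run here):
#   instances = filter_instances(output["instances"].to("cpu"), 0.5)
#   out = v.draw_instance_predictions(instances)
```

Because the mask is applied to the whole object rather than to individual variables, masks and keypoints stay aligned with the boxes. detectron2's config also has test-time score thresholds (for RetinaNet, cfg.MODEL.RETINANET.SCORE_THRESH_TEST), which may be worth trying before changing any code.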