当前位置:网站首页>Adding confidence threshold for demo visualization in detectron2

Adding a confidence threshold to demo visualization in detectron2

2022-06-26 09:10:00 G fruit

Foreword

I have been running experiments with detectron2, the deep learning library developed by Facebook. When visualizing the detection boxes, I found the image cluttered with boxes, including many low-confidence detections, which looked awful. So I added a confidence threshold (score_threshold) to filter the detection boxes, which makes the visualization clearer and more flexible.

Original demo (no threshold)
[Figure: demo output without a confidence threshold]
Demo after adding the confidence threshold

[Figure: demo output with the confidence threshold applied]

Complete demo.py code

from detectron2.utils.visualizer import  ColorMode
from detectron2.data import MetadataCatalog, DatasetCatalog
from detectron2.engine.defaults import DefaultPredictor
from detectron2.utils.visualizer import Visualizer
from swint.config import add_swint_config
import random
import cv2
from detectron2.config import get_cfg
import os
import pig_dataset

# Confidence threshold passed to the modified draw_instance_predictions:
# only detections scoring strictly above it are drawn.
SCORE_THRESHOLD = 0.5
# How many random test images to visualize.
NUM_SAMPLES = 3

pig_test_metadata = MetadataCatalog.get("pig_coco_test")
dataset_dicts = DatasetCatalog.get("pig_coco_test")

cfg = get_cfg()
add_swint_config(cfg)
cfg.merge_from_file("./configs/SwinT/retinanet_swint_T_FPN_3x.yaml")
cfg.MODEL.WEIGHTS = os.path.join(cfg.OUTPUT_DIR, "model_0019999.pth")
predictor = DefaultPredictor(cfg)

# random.sample raises ValueError when asked for more items than exist,
# so never request more samples than the dataset holds.
for d in random.sample(dataset_dicts, min(NUM_SAMPLES, len(dataset_dicts))):
    im = cv2.imread(d["file_name"])
    if im is None:
        # cv2.imread returns None (it does not raise) for missing/unreadable files.
        print("Skipping unreadable image: %s" % d["file_name"])
        continue
    output = predictor(im)
    # OpenCV loads BGR; reversing the last axis hands Visualizer an RGB view.
    v = Visualizer(im[:, :, ::-1], metadata=pig_test_metadata,
                   scale=0.5, instance_mode=ColorMode.IMAGE_BW)
    # Draw boxes using the added confidence-threshold parameter.
    out = v.draw_instance_predictions(output["instances"].to("cpu"),
                                      score_threshold=SCORE_THRESHOLD)
    cv2.namedWindow("pig", cv2.WINDOW_NORMAL)
    cv2.resizeWindow("pig", 600, 400)
    cv2.imshow("pig", out.get_image()[:, :, ::-1])
    # Optional: save the visualization next to the script instead of showing it.
    # cv2.imwrite("demo-%s" % os.path.basename(d["file_name"]), out.get_image()[:, :, ::-1])

    cv2.waitKey(3000)
    cv2.destroyAllWindows()

Code that applies the confidence threshold (score_threshold)

 if score_threshold is not None and scores is not None:
     # Boolean keep-mask; works for CPU and CUDA tensors alike (no .numpy()).
     keep = scores > score_threshold
     # Indexing Instances returns a new, filtered copy: the caller's predictions
     # (and their Boxes) are not mutated, and pred_masks read further down is
     # filtered consistently with the boxes, labels and keypoints.
     predictions = predictions[keep]
     # Re-derive every per-instance local from the filtered predictions.
     boxes = predictions.pred_boxes if predictions.has("pred_boxes") else None
     scores = predictions.scores
     classes = predictions.pred_classes if predictions.has("pred_classes") else None
     labels = _create_text_labels(classes, scores, self.metadata.get("thing_classes", None))
     keypoints = predictions.pred_keypoints if predictions.has("pred_keypoints") else None

Complete code of the draw_instance_predictions function (after modification)

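Paste the code above into the function immediately before the line `if predictions.has("pred_masks"):`, and add a `score_threshold=None` parameter to the function signature, as shown below.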

This function is defined in detectron2/utils/visualizer.py (in PyCharm, pressing Ctrl+B on the function name jumps straight to it).

    def draw_instance_predictions(self, predictions, score_threshold=None):
        """
        Draw instance-level prediction results on an image.

        Args:
            predictions (Instances): the output of an instance detection/segmentation
                model. Following fields will be used to draw:
                "pred_boxes", "pred_classes", "scores", "pred_masks" (or "pred_masks_rle").
            score_threshold (float, optional): confidence threshold (new parameter).
                When not None and ``predictions`` has a "scores" field, only
                instances whose score is strictly greater than this value are
                drawn. When ``predictions`` has no "scores" field the threshold is
                ignored. ``None`` (the default) draws every instance.

        Returns:
            output (VisImage): image object with visualizations.
        """
        if score_threshold is not None and predictions.has("scores"):
            # Filter the whole Instances object once, before any field is read.
            # Indexing returns a new Instances, so the caller's object is not
            # mutated, and boxes, classes, labels, keypoints and masks (including
            # the IMAGE_BW grayscale mask below) all stay the same length.
            predictions = predictions[predictions.scores > score_threshold]

        boxes = predictions.pred_boxes if predictions.has("pred_boxes") else None
        scores = predictions.scores if predictions.has("scores") else None
        classes = predictions.pred_classes if predictions.has("pred_classes") else None
        labels = _create_text_labels(classes, scores, self.metadata.get("thing_classes", None))
        keypoints = predictions.pred_keypoints if predictions.has("pred_keypoints") else None

        if predictions.has("pred_masks"):
            masks = np.asarray(predictions.pred_masks)
            masks = [GenericMask(x, self.output.height, self.output.width) for x in masks]
        else:
            masks = None

        # Use per-class colors from metadata (jittered) only in SEGMENTATION mode.
        if self._instance_mode == ColorMode.SEGMENTATION and self.metadata.get("thing_colors"):
            colors = [
                self._jitter([x / 255 for x in self.metadata.thing_colors[c]]) for c in classes
            ]
            alpha = 0.8
        else:
            colors = None
            alpha = 0.5

        if self._instance_mode == ColorMode.IMAGE_BW:
            # Pass the union of all (kept) instance masks, or None when the
            # model produced no masks, to build the grayscale background image.
            self.output.img = self._create_grayscale_image(
                (predictions.pred_masks.any(dim=0) > 0).numpy()
                if predictions.has("pred_masks")
                else None
            )
            alpha = 0.3

        self.overlay_instances(
            masks=masks,
            boxes=boxes,
            labels=labels,
            keypoints=keypoints,
            assigned_colors=colors,
            alpha=alpha,
        )
        return self.output
原网站

版权声明
本文为[G fruit]所创,转载请带上原文链接,感谢
https://yzsam.com/2022/02/202202170552524029.html