当前位置:网站首页>MMSeg——Multi-view时序数据检查与可视化

MMSeg——Multi-view时序数据检查与可视化

2022-07-05 13:17:00 Irving.Gao

可视化功能函数

首先需要编写一组用于可视化的工具函数,我们将其放在 tools/visualize_helper.py 中,方便在其他模块中调用:

visualize_helper.py

import os
import cv2
import torch
import numpy as np

# Visualization palette: row ``i`` is the 3-channel color painted on pixels
# whose class index is ``i`` when a mask is colorized.
# NOTE(review): the colorized image is saved with cv2.imwrite, which reads the
# channels as BGR, so a row meant as [R, G, B] will appear channel-swapped —
# confirm the intended channel order.
label_colors = np.array(
    [
        (255, 255, 255),  # class 0
        (0, 0, 255),      # class 1
        (0, 255, 0),      # class 2
        (255, 0, 0),      # class 3
        (0, 255, 255),    # class 4
    ]
)

def decode_segmap(mask, palette=None):
    """Colorize a 2-D class-index mask with a per-class color palette.

    Args:
        mask: 2-D array of class indices, shape (H, W).
        palette: optional (num_classes, 3) sequence of colors; row ``i`` is
            painted on every pixel whose value equals ``i``. Defaults to the
            module-level ``label_colors``.

    Returns:
        Integer array of shape (H, W, 3). Pixels whose value matches no
        palette row keep their raw value in all three channels.

    Raises:
        ValueError: if ``mask`` is not 2-D.
    """
    palette = label_colors if palette is None else np.asarray(palette)
    mask = np.asarray(mask)
    if mask.ndim != 2:
        raise ValueError(f"decode_segmap expects a 2-D mask, got shape {mask.shape}")

    # Start from the raw mask replicated on 3 channels in a float buffer, so
    # unmatched values survive and palette values (e.g. 255) never overflow a
    # narrow input dtype such as int8.
    rgb = np.repeat(mask[..., np.newaxis], 3, axis=2).astype(np.float64)
    for class_idx, color in enumerate(palette):
        # Compare against the untouched input mask rather than a buffer that
        # is being recolored in place; otherwise a color value equal to a
        # later class index would be repainted.
        rgb[mask == class_idx] = color
    # ``np.int`` was removed in NumPy 1.24; builtin ``int`` is the same dtype.
    return rgb.astype(int)
    
def post_process(mask):
    """Turn a 2-D class-index mask into a color image.

    Args:
        mask: class-index array of shape (H, W).

    Returns:
        Colorized array of shape (H, W, 3), as produced by ``decode_segmap``.
    """
    return decode_segmap(mask)
    
    

def get_multi_view_imgs(multi_view_file_list, dataset_dir, num_views=6, num_cols=3):
    """Load multi-view camera images and tile them into a single mosaic.

    The paths are consumed in consecutive groups of ``num_views`` images, one
    group per time step. Each group is laid out as a grid with ``num_cols``
    images per row (the default 6 views / 3 columns gives a 2x3 grid), and the
    per-time-step grids are concatenated left to right in list order.

    Args:
        multi_view_file_list: image paths relative to ``dataset_dir``.
        dataset_dir: directory the relative paths are joined onto.
        num_views: number of images per time step (default 6).
        num_cols: number of images per grid row (default 3); must divide
            ``num_views``.

    Returns:
        tuple: ``(mosaic, num_temporal)`` — the float32 mosaic image and the
        number of time steps it contains.

    Raises:
        FileNotFoundError: if an image cannot be read by ``cv2.imread``.
        ValueError: if ``num_cols`` does not divide ``num_views`` or fewer than
            ``num_views`` paths are given.
    """
    import warnings

    if num_views % num_cols != 0:
        raise ValueError(f"num_cols={num_cols} must divide num_views={num_views}")

    img_list = []
    for img_path in multi_view_file_list:
        full_path = os.path.join(dataset_dir, img_path)
        img = cv2.imread(full_path)
        # cv2.imread signals failure by returning None instead of raising.
        if img is None:
            raise FileNotFoundError(f"could not read image: {full_path}")
        img_list.append(img.astype('float32'))

    num_temporal = len(img_list) // num_views
    if num_temporal == 0:
        raise ValueError(
            f"need at least {num_views} images to build one frame, got {len(img_list)}"
        )
    leftover = len(img_list) - num_temporal * num_views
    if leftover:
        # Surface the mismatch instead of silently dropping views.
        warnings.warn(
            f"{len(img_list)} images is not a multiple of {num_views}; "
            f"ignoring the last {leftover}"
        )

    frame_mosaics = []
    for t in range(num_temporal):
        frame = img_list[t * num_views:(t + 1) * num_views]
        rows = [cv2.hconcat(frame[r:r + num_cols]) for r in range(0, num_views, num_cols)]
        frame_mosaics.append(cv2.vconcat(rows))

    # Time steps side by side, in the order they appear in the file list.
    return cv2.hconcat(frame_mosaics), num_temporal
    
    
def get_gt_imgs(gt):
    """Render a ground-truth label tensor as a float32 color image.

    Up to two leading size-1 dimensions are squeezed away before colorizing,
    so the tensor must reduce to a 2-D (H, W) label map — e.g. a single
    sample shaped (1, 1, H, W) or (1, H, W).

    Args:
        gt (Tensor): GT label tensor for one sample (see above).

    Returns:
        np.ndarray: float32 color image of shape (H, W, 3).
    """
    label_map = gt.squeeze(0).squeeze(0).cpu().numpy()
    color_img = post_process(label_map)
    return color_img.astype('float32')
    
def save_train_imgs(img_metas, gt, save_name="multi_view_imgs"):
    """Save one debug image per batch sample: camera mosaic plus colorized GT.

    For every sample:
    - the multi-view images listed in ``img_meta['ori_filename']`` are tiled
      with ``get_multi_view_imgs``;
    - that sample's GT is colorized with ``get_gt_imgs``;
    - both are resized to a height of 600 px and placed side by side;
    - the result is written to ``f"{save_name}_bs{idx}.jpg"``.

    Args:
        img_metas (list[dict]): per-sample meta dicts. Each must provide
            'filename' (used to locate the dataset root) and 'ori_filename'
            (image paths relative to that root).
        gt (Tensor): batched GT label maps, indexed per sample along dim 0
            (assumes [B, 1, H, W] or [B, H, W] — TODO confirm).
        save_name (str): output file prefix. Files with the same prefix are
            overwritten on every call.
    """
    import warnings

    for idx, img_meta in enumerate(img_metas):
        # NOTE(review): the dataset root is taken to be 'filename' with its
        # last three path components removed — confirm against the dataset.
        dataset_dir = "/".join(img_meta['filename'].split("/")[:-3])
        file_list = img_meta['ori_filename']
        surr_img, num_temporal = get_multi_view_imgs(file_list, dataset_dir)
        # Colorize only this sample's GT; passing the whole batch broke for B > 1.
        gt_img = get_gt_imgs(gt[idx])

        # cv2.resize takes (width, height): 1600 px per time step, 600 px tall.
        surr_img = cv2.resize(surr_img, (1600 * num_temporal, 600))
        gt_img = cv2.resize(gt_img, (300, 600))

        print(surr_img.shape)
        print(gt_img.shape)
        all_img = cv2.hconcat([surr_img, gt_img])  # horizontal concatenation
        out_path = f"{save_name}_bs{idx}.jpg"
        # cv2.imwrite reports failure via its return value rather than raising.
        if not cv2.imwrite(out_path, all_img):
            warnings.warn(f"failed to write visualization to {out_path}")

在训练过程中进行可视化

在 forward_train 函数中加入以下调用即可实现该功能:

	def forward_train(self, img, img_metas, gt_semantic_seg):
		#####################
		# 训练过程中数据检查与可视化
		from tools.visualize_helper import save_train_imgs
        save_train_imgs(img_metas, gt_semantic_seg)
		###################
        x = self.extract_feat(img)

        losses = dict()
        loss_decode = self._decode_head_forward_train(x, img_metas,
                                                      gt_semantic_seg)
        losses.update(loss_decode)
        if self.with_auxiliary_head:
            loss_aux = self._auxiliary_head_forward_train(
                x, img_metas, gt_semantic_seg)
            losses.update(loss_aux)
        return losses

可视化效果

  • 从左到右依次为:前一时刻的6张环视图像、当前时刻的6张环视图像(每个时刻按2行3列排列),最右侧为BEV视角下的GT。
    (原文此处为可视化效果图)
原网站

版权声明
本文为[Irving.Gao]所创,转载请带上原文链接,感谢
https://blog.csdn.net/qq_45779334/article/details/125604249