Current location: Home > Model prediction of semantic segmentation

Model prediction of semantic segmentation

2022-06-13 02:32:00 Prodigal son's private dishes

Small image input

Images of ordinary size do not need to be cropped: they can be fed to the model directly for end-to-end training. Prediction is also relatively simple. Taking binary segmentation as an example, the probability map output by the model has to be converted into a binary mask, and there are two ways to do this. First, if the classes include a background class, use argmax to take, at each pixel, the index of the channel with the largest value. Second, if the classes do not include a background class, use a sigmoid to squash the output into the range 0–1 and then apply a threshold, usually 0.5: values greater than 0.5 are the positive class, and values less than or equal to 0.5 are background.

# Single-image inference: the whole image is small enough to go through the
# network in one pass.
net = torch.load('./model.pth', map_location='cpu')["model"]  # load storages on CPU first
net = net.to(device)
# Inference mode: BatchNorm uses its running statistics and dropout is off.
# Without this the prediction depends on the (single-image) batch statistics.
net.eval()

imglist = os.listdir(input_img_folder)
# NOTE(review): os.listdir order is arbitrary, so index 400 is not a stable
# choice of image across machines/filesystems.
img = cv2.imread(os.path.join(input_img_folder, imglist[400]))
tensor = img_to_tensor(img).unsqueeze(0).float()  # add the batch axis
with torch.no_grad():  # prediction needs no autograd graph
    prob = net(tensor.to(device))[0, 0, :, :]
prob = prob.cpu().numpy()

# NOTE(review): thresholding at 0.5 assumes the network already ends in a
# sigmoid; if it returns raw logits, apply torch.sigmoid before this step.
# Values > 0.5 become 1 (positive class), everything else 0 (background).
predict = (prob > 0.5).astype(prob.dtype)

Large image input

When the image is large, feeding the whole image to the model easily causes a `CUDA out of memory` error; with remote sensing images this happens all the time. The usual solution is to cut the large image into tiles, predict each tile, and stitch the predictions back together. The steps are:

(1) Collect the paths of all images;
(2) Loop over the images: cut each image into tiles, store the tiles in a temporary folder (deleted once prediction is finished), and build a data loader over them;
(3) Run the model over the data loader and stitch all tile predictions into one large map with the same size as the original image;
(4) Convert the probability map into a binary mask and color it as required for visualization;
(5) Finally, delete the temporary files and repeat (2)–(4) for the next image.
————————————————

## use model to predict
def predict(model, loader=None, target_device=None):
    """Run an ensemble of segmentation models over every batch of a loader.

    The outputs of all models are averaged and the predicted class of each
    pixel is the argmax over dimension 1. NOTE(review): this assumes outputs
    are laid out as (batch, classes, H, W); confirm against the models used.

    Args:
        model: dict mapping a name to a model. Every value is switched to
            eval mode and called on each batch.
        loader: iterable of image batches. Defaults to the module-level
            ``test_loader`` (backward compatible with the original call).
        target_device: device the batches are moved to. Defaults to the
            module-level ``device``.

    Returns:
        list with one numpy array of predicted class indices per batch.

    Raises:
        ValueError: if ``model`` is empty (averaging would divide by zero).
    """
    if not model:
        raise ValueError('predict() needs at least one model')
    if loader is None:
        loader = test_loader
    if target_device is None:
        target_device = device

    networks = list(model.values())
    for net in networks:
        net.eval()  # set once; it is idempotent, no need to repeat per batch

    result = []
    # no_grad: inference needs no autograd graph. Keeping one per batch (and
    # per ensemble member) wastes the GPU memory that tiling is meant to save.
    with torch.no_grad():
        for images in tqdm.tqdm(loader):
            images = images.to(target_device)
            summed = sum(net(images) for net in networks)
            preds = (summed / len(networks)).argmax(dim=1)
            result.append(preds.cpu().numpy())
    return result


def input_and_output(pic_path, model, generate_data):
    """Tile one large image for prediction, or stitch tile predictions back.

    Meant to be called twice per image:

    * ``generate_data=True``: reflect-pad the image, cut it into overlapping
      windows of ``args.crop_size + 2 * args.padding_size`` pixels and write
      them to ``temp_pic/{row}_{col}.png``.
    * ``generate_data=False``: run ``predict(model)`` over the tiles, strip the
      ``args.padding_size`` margin from each tile prediction with
      ``redundancy_crop2`` and paste the centres into one mask.

    args:
        pic_path      : path of the picture to predict
        model         : dict of models, passed through to ``predict``
        generate_data : True to write the tiles, False to predict and stitch

    returns:
        uint8 mask with the same height and width as the input image
        (all zeros when ``generate_data`` is True).

    raises:
        FileNotFoundError: if OpenCV cannot read ``pic_path``.
    """
    tile = args.crop_size
    margin = args.padding_size

    img = cv2.imread(pic_path)
    if img is None:  # cv2.imread signals failure by returning None
        raise FileNotFoundError(f'cannot read image: {pic_path}')
    height, width = img.shape[:2]

    # Ceil division, so the right/bottom edges get a tile too. Floor division
    # silently dropped up to tile-1 pixels per axis, which made the mask
    # smaller than the input.
    rows = -(-height // tile)
    cols = -(-width // tile)
    extra_h = rows * tile - height
    extra_w = cols * tile - width

    # Reflect-pad by `margin` on every side (context for the tile borders),
    # plus the extra rows/cols needed to complete the last tile.
    padded = cv2.copyMakeBorder(
        img, margin, margin + extra_h, margin, margin + extra_w, cv2.BORDER_REFLECT
    )

    mask_whole = np.zeros((rows * tile, cols * tile), dtype=np.uint8)

    if generate_data:
        for i in range(rows):
            for j in range(cols):
                crop = redundancy_crop(padded, i, j, tile)
                cv2.imwrite(f'temp_pic/{i}_{j}.png', crop)
        return mask_whole[:height, :width]

    result = predict(model)
    # NOTE(review): this assumes predict() yields one result per tile in the
    # same order that Path('temp_pic').files() lists the files, i.e. that the
    # dataset reads the folder in listing order. Confirm against urban3dDWM.
    position = {str(p.name): k for k, p in enumerate(Path('temp_pic').files())}
    for i in range(rows):
        for j in range(cols):
            centre = redundancy_crop2(result[position[f'{i}_{j}.png']])
            mask_whole[i * tile:(i + 1) * tile, j * tile:(j + 1) * tile] = centre
    # Drop the completion padding so the mask matches the original image.
    return mask_whole[:height, :width]


def redundancy_crop(img, i, j, targetSize, padding_size=None):
    """Cut tile (i, j) plus its context margin out of a padded image.

    The window starts at ``(i * targetSize, j * targetSize)`` and is
    ``targetSize + 2 * padding_size`` pixels on each side, so neighbouring
    windows overlap by ``2 * padding_size`` pixels.

    args:
        img          : 3-D array indexed (rows, cols, channels); the caller
                       passes an image already padded by the margin
        i, j         : tile row / column index
        targetSize   : side length of the tile's centre region
        padding_size : context margin; defaults to the global
                       ``args.padding_size`` so existing callers are unchanged

    returns:
        the window as a numpy slice (a view of ``img``); it is shorter at the
        border if ``img`` does not extend far enough.
    """
    if padding_size is None:
        padding_size = args.padding_size
    window = targetSize + 2 * padding_size
    top = i * targetSize
    left = j * targetSize
    return img[top:top + window, left:left + window, :]


def redundancy_crop2(img, padding_size=None):
    """Strip the context margin from a tile prediction.

    This is the counterpart of ``redundancy_crop``. It removes ``padding_size``
    pixels from both ends of axes 1 and 2 and keeps axis 0 whole.

    args:
        img          : 3-D array; per-batch output of ``predict`` indexed
                       (batch, rows, cols)
        padding_size : margin to remove; defaults to the global
                       ``args.padding_size`` so existing callers are unchanged

    returns:
        the centre region as a numpy slice (a view of ``img``).
    """
    if padding_size is None:
        padding_size = args.padding_size
    rows, cols = img.shape[1], img.shape[2]
    return img[:, padding_size:rows - padding_size, padding_size:cols - padding_size]


def get_dataset_loaders(workers):
    """Create the DataLoader that feeds the temporary tiles to ``predict``.

    Args:
        workers: number of DataLoader worker processes.

    Returns:
        A DataLoader over ``urban3dDWM(..., test=True)`` with batch size 1.
    """
    # NOTE(review): `path` is a module-level name that is not visible here;
    # presumably it points at the tile folder (temp_pic). Confirm.
    dataset = urban3dDWM(os.path.join(path), './', test=True)
    # Batch size must stay 1: input_and_output() looks up one entry of the
    # prediction list per tile file.
    return DataLoader(dataset, num_workers=workers, batch_size=1)


def get_labels():
    """Return the colour palette used to paint predicted class indices.

    Row ``k`` is the colour for class ``k``: class 0 is black and class 1 is
    white, which gives a two-class (binary) palette.

    Returns:
        np.ndarray with dimensions (2, 3)
    """
    black = [0, 0, 0]
    white = [255, 255, 255]
    return np.asarray([black, white])


def decode_segmap(label_mask, n_classes):
    """Turn a 2-D map of class indices into a 3-channel colour image.

    Pixels whose value is in ``range(n_classes)`` are painted with the
    matching row of ``get_labels()``. Any other value is copied unchanged
    into all three channels.

    Args:
        label_mask (np.ndarray): (M, N) array of integer class labels.
        n_classes (int): number of classes to paint. It must not exceed the
          number of palette rows, or indexing the palette raises IndexError.

    Returns:
        np.ndarray: (M, N, 3) float64 image. NOTE(review): the caller saves
        it with cv2.imwrite, which expects BGR channel order. The current
        black/white palette is symmetric, so the order does not matter yet.
    """
    palette = get_labels()
    planes = [label_mask.copy() for _ in range(3)]
    for cls in range(n_classes):
        selected = label_mask == cls
        for channel, plane in enumerate(planes):
            plane[selected] = palette[cls, channel]
    return np.stack(planes, axis=-1).astype(np.float64)


if __name__ == "__main__":
    # Command-line entry point: for every ./valid/*RGB.tif image, cut it into
    # tiles, predict the tiles, stitch the result and save a colour mask.
    parse = argparse.ArgumentParser()
    parse.add_argument("--n_class", type=int, default=2, help="the number of classes")
    parse.add_argument("--model_name", type=str, default='UNet',
                       help="model name, or a comma-separated list for an ensemble, e.g. UNet,PSPNet,FPN")

    parse.add_argument("--n_workers", type=int, default=4, help="the number of workers")
    parse.add_argument("--crop_size", type=int, default=256,
                       help="side length of the predicted centre of each tile")
    parse.add_argument("--padding_size", type=int, default=32,
                       help="context margin added around each tile on every side")

    args = parse.parse_args()

    # --model_name used to be ignored when loading (always "UNet"); the
    # default still yields ["UNet"], and a comma list gives an ensemble.
    model_groups = [name.strip() for name in args.model_name.split(',') if name.strip()]

    models = {}
    for item in model_groups:
        # NOTE(review): predict() calls .eval() and forward on this object, so
        # "model_state" must hold a full nn.Module. If the checkpoint stores a
        # state_dict here, load it into a model instance first. Confirm.
        models[item] = torch.load(f'./results_{item}2/{item}_weights_best.pth')["model_state"]

    imgList = glob.glob("./valid/*RGB.tif")

    save_path = f'./predict_{"_".join(model_groups)}'
    os.makedirs(save_path, exist_ok=True)

    for img_path in tqdm.tqdm(imgList):
        # splitext keeps dots inside the stem (the old split(".")[0] did not).
        name = os.path.splitext(os.path.basename(img_path))[0]

        # Start from an empty tile folder. Leftover tiles from a crashed run
        # would otherwise be read by the loader together with the new ones.
        shutil.rmtree('temp_pic', ignore_errors=True)
        os.makedirs('temp_pic', exist_ok=True)
        try:
            input_and_output(img_path, models, generate_data=True)
            # Module-level name: predict() reads the global `test_loader`.
            test_loader = get_dataset_loaders(args.n_workers)
            mask_result = input_and_output(img_path, models, generate_data=False)
        finally:
            # Best-effort cleanup, even if prediction fails.
            shutil.rmtree('temp_pic', ignore_errors=True)

        decoded = decode_segmap(mask_result, args.n_class)
        cv2.imwrite(f'{save_path}/{name}.png', decoded)

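To avoid grid artifacts along the tile seams of the stitched prediction, the code above uses redundant (overlapping) prediction. Each tile is cut with an extra `padding_size` margin of context on every side, and only the central `crop_size × crop_size` region of each tile's prediction is kept when the tiles are stitched back together.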

Original site

Copyright notice
This article was written by [Prodigal son's private dishes]; please include a link to the original article when reposting. Thank you.
https://yzsam.com/2022/02/202202280540496826.html