Model prediction of semantic segmentation
2022-06-13 02:32:00 【Prodigal son's private dishes】
Small size image input
A typical image does not need to be cropped: it can be fed into the model as a whole and trained end to end. Prediction is equally simple. Taking binary segmentation as an example, the probability map output by the model is converted into a binary mask, and there are two ways to do this. First, if the output channels include a background class, apply argmax over the channel dimension so that each pixel takes the class with the largest score. Second, if the output channels do not include the background class, apply sigmoid to squash the values into the range 0-1 and then threshold them, usually at 0.5: pixels above 0.5 belong to the positive class and the rest to the background.
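Both conversions can be sketched with NumPy alone. This is a minimal illustration, assuming the raw network output is already available as an array: channel-first (C, H, W) scores with channel 0 as background for the argmax case, and a single (H, W) map of raw logits for the sigmoid case. The function names are hypothetical.

```python
import numpy as np

def mask_from_multiclass(scores):
    """scores: (C, H, W) array, one channel per class including background.
    Returns an (H, W) array holding the index of the highest-scoring class."""
    return np.argmax(scores, axis=0)

def mask_from_single_channel(logits, threshold=0.5):
    """logits: (H, W) raw scores for the foreground class only (no background channel).
    Sigmoid maps them into (0, 1); values above `threshold` become 1 (positive class)."""
    prob = 1.0 / (1.0 + np.exp(-logits))
    return (prob > threshold).astype(np.uint8)

# usage
two_class = np.array([[[2.0, -1.0]], [[0.5, 3.0]]])  # shape (2, 1, 2)
print(mask_from_multiclass(two_class))                       # [[0 1]]
print(mask_from_single_channel(np.array([[-2.0, 0.0, 3.0]])))  # [[0 0 1]]
```

Note that a logit of exactly 0 maps to a probability of 0.5 and therefore falls on the background side of the threshold.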
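import os
import cv2
import torch
# img_to_tensor is not defined in the post; older albumentations releases provide one here
from albumentations.pytorch.functional import img_to_tensor

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# input_img_folder: folder with the images to predict (set elsewhere)
net = torch.load('./model.pth', map_location=lambda storage, loc: storage)["model"]
net = net.to(device)
net.eval()  # inference mode: dropout off, BatchNorm uses running statistics
imglist = os.listdir(input_img_folder)
img = cv2.imread(os.path.join(input_img_folder, imglist[400]))
tensor = img_to_tensor(img)                      # HWC image -> CHW float tensor
tensor = torch.unsqueeze(tensor, dim=0).float()  # add the batch dimension: 1 x C x H x W
with torch.no_grad():
    # assumes the network already ends with a sigmoid, so values lie in 0-1;
    # if it outputs raw logits, apply torch.sigmoid before thresholding
    predict = net(tensor.to(device))[0, 0, :, :]
predict = predict.cpu().numpy()
predict[predict <= 0.5] = 0  # Background class
predict[predict > 0.5] = 1   # Positive (foreground) class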
Large image input
When the image is large, feeding the whole image into the model easily leads to `CUDA out of memory`, which happens all the time with remote sensing images. The usual solution is to cut the large image into tiles, predict each tile, and stitch the predictions back together. The steps are:
(1) Get the paths of all images;
(2) Loop over the images: cut each image into tiles, store them in a temporary folder (deleted once prediction is complete), and build a data loader over them;
(3) Run the model on the data loader and stitch the tile probability maps into one large probability map with the same size as the original image;
(4) Convert the probability map into a binary mask and colorize it as needed for visualization;
(5) Finally, delete the temporary files, and repeat (2), (3) and (4) for the next image.
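The core of steps (2) and (3) can be sketched in memory with NumPy, skipping the temporary files. This is a simplified illustration, not the author's code: `predict_tiled` and `predict_tile` are hypothetical names, the tiles do not overlap, and `predict_tile` stands in for running the model on one tile and converting its output into a class map.

```python
import numpy as np

def predict_tiled(image, predict_tile, tile=256):
    """Cut `image` (H, W[, C]) into tile x tile pieces, predict each one and stitch the results.

    predict_tile: hypothetical callable mapping a (tile, tile[, C]) array to a (tile, tile) mask.
    The image is zero-padded on the bottom/right up to a multiple of `tile`, and the stitched
    mask is cropped back to (H, W), so the output has the same size as the input.
    """
    h, w = image.shape[:2]
    rows = -(-h // tile)  # ceil division: partial tiles at the border are kept
    cols = -(-w // tile)
    padded = np.zeros((rows * tile, cols * tile) + image.shape[2:], dtype=image.dtype)
    padded[:h, :w] = image
    mask = np.zeros((rows * tile, cols * tile), dtype=np.uint8)
    for i in range(rows):
        for j in range(cols):
            ys = slice(i * tile, (i + 1) * tile)
            xs = slice(j * tile, (j + 1) * tile)
            mask[ys, xs] = predict_tile(padded[ys, xs])
    return mask[:h, :w]
```

Writing the tiles to disk, as the post does, lets a standard DataLoader batch them; keeping them in memory is simpler when the whole image fits in RAM. Because the tiles here do not overlap, a model that depends on context may show seams along tile borders.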
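import argparse
import glob
import os
import shutil

import cv2
import numpy as np
import torch
import tqdm
from path import Path  # the "path" package: Path(...).files()
from torch.utils.data import DataLoader

## use the model(s) to predict
def predict(model):
    """Average the outputs of every model in the dict `model` and return one class map per tile."""
    for net in model.values():
        net.eval()
    result = []
    with torch.no_grad():  # inference only: no gradient buffers
        for images in tqdm.tqdm(test_loader):
            images = images.to(device)
            temp = 0
            for net in model.values():
                temp += net(images)
            preds = temp / len(model)       # ensemble: mean of the model outputs
            preds = torch.max(preds, 1)[1]  # argmax over the class dimension
            result.append(preds.cpu().numpy())
    return result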
def input_and_output(pic_path, model, generate_data):
    """
    args:
        pic_path : the picture you want to predict
        model : dict of the models you want to predict with
        generate_data : True for step one, False for step two
    note:
        step one : cut the picture into overlapping tiles and save them in temp_pic
        step two : predict the tiles from step one and stitch the results
    """
    image_size = args.crop_size
    b = args.padding_size
    img = cv2.imread(pic_path)
    H, W = img.shape[:2]
    # ceil division, so the right and bottom borders are not dropped
    row = -(-H // image_size)
    col = -(-W // image_size)
    # reflect-pad b pixels on every side, plus extra on the bottom/right so the grid covers the image
    padding_img = cv2.copyMakeBorder(img, b, b + row * image_size - H,
                                     b, b + col * image_size - W, cv2.BORDER_REFLECT)
    mask_whole = np.zeros((row * image_size, col * image_size), dtype=np.uint8)
    if not generate_data:
        result = predict(model)
        # assumes the dataset enumerates temp_pic in the same order as Path.files()
        map_list = [str(i.name) for i in Path('temp_pic').files()]
    for i in range(row):
        for j in range(col):
            if generate_data:
                crop = redundancy_crop(padding_img, i, j, image_size)
                cv2.imwrite(f'temp_pic/{i}_{j}.png', crop)
            else:
                temp = result[map_list.index(f'{i}_{j}.png')]
                temp = redundancy_crop2(temp)[0]  # drop the batch dimension
                mask_whole[i * image_size:(i + 1) * image_size,
                           j * image_size:(j + 1) * image_size] = temp
    return mask_whole[:H, :W]  # crop back to the original image size
def redundancy_crop(img, i, j, targetSize):
    """Tile (i, j) of the padded image, with padding_size extra pixels on every side."""
    p = args.padding_size
    temp_img = img[i * targetSize:(i + 1) * targetSize + 2 * p,
                   j * targetSize:(j + 1) * targetSize + 2 * p, :]
    return temp_img

def redundancy_crop2(img):
    """Strip the margin from a predicted tile of shape (batch, h, w), keeping the central part."""
    p = args.padding_size
    h = img.shape[1]
    w = img.shape[2]
    temp_img = img[:, p:h - p, p:w - p]
    return temp_img
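def get_dataset_loaders(workers):
    batch_size = 1
    # urban3dDWM is the author's custom Dataset class (not shown in the post);
    # it is assumed here to read the tiles saved in temp_pic
    test_dataset = urban3dDWM('temp_pic', './', test=True)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, num_workers=workers)
    return test_loader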
def get_labels():
    """Load the mapping that associates each class index with a label color
    Returns:
        np.ndarray with dimensions (2, 3): background black, foreground white
    """
    return np.asarray(
        [
            [0, 0, 0],
            [255, 255, 255]
        ]
    )

def decode_segmap(label_mask, n_classes):
    """Decode segmentation class labels into a color image
    Args:
        label_mask (np.ndarray): an (M,N) array of integer values denoting
            the class label at each spatial location.
        n_classes (int): the number of classes to colorize.
    Returns:
        (np.ndarray): the resulting (M,N,3) uint8 color image.
    """
    label_colours = get_labels()
    r = label_mask.copy()
    g = label_mask.copy()
    b = label_mask.copy()
    for ll in range(0, n_classes):
        r[label_mask == ll] = label_colours[ll, 0]
        g[label_mask == ll] = label_colours[ll, 1]
        b[label_mask == ll] = label_colours[ll, 2]
    # note: cv2.imwrite treats channel 0 as blue; this does not matter for a black/white palette
    rgb = np.zeros((label_mask.shape[0], label_mask.shape[1], 3), dtype=np.uint8)
    rgb[:, :, 0] = r
    rgb[:, :, 1] = g
    rgb[:, :, 2] = b
    return rgb
if __name__ == "__main__":
    parse = argparse.ArgumentParser()
    parse.add_argument("--n_class", type=int, default=2, help="the number of classes")
    parse.add_argument("--model_name", type=str, default='UNet', help="UNet,PSPNet,FPN")
    parse.add_argument("--n_workers", type=int, default=4, help="the number of workers")
    parse.add_argument("--crop_size", type=int, default=256, help="side length of each tile")
    parse.add_argument("--padding_size", type=int, default=32, help="margin added on each side of a tile")
    args = parse.parse_args()
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # model_groups = ["UNet","PSPNet","FPN"]
    model_groups = ["UNet"]
    # predict with several models (ensemble)
    models = {}
    for item in model_groups:
        # the checkpoint presumably stores the whole model object under "model_state"
        models[item] = torch.load(f'./results_{item}2/{item}_weights_best.pth',
                                  map_location=device)["model_state"].to(device)
    imgList = glob.glob("./valid/*RGB.tif")
    save_path = f'./predict_{args.model_name}'
    os.makedirs(save_path, exist_ok=True)
    for img_path in tqdm.tqdm(imgList):
        os.makedirs('temp_pic', exist_ok=True)
        ### predict on one picture
        input_and_output(img_path, models, generate_data=True)  # step one: write the tiles
        name = os.path.split(img_path)[-1].split(".")[0]
        test_loader = get_dataset_loaders(args.n_workers)
        mask_result = input_and_output(img_path, models, generate_data=False)  # step two
        # Recursively delete the temporary tiles
        shutil.rmtree('temp_pic', ignore_errors=True)
        decoded = decode_segmap(mask_result, args.n_class)
        cv2.imwrite(f'{save_path}/{name}.png', decoded)
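To avoid grid artifacts (visible seams along tile borders) in the predicted map, the code above uses redundant prediction: each tile is cut with an extra margin of padding_size pixels on every side, so the model sees context beyond the tile, and only the central crop_size × crop_size part of its prediction is kept.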
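The redundant prediction idea can be sketched compactly with NumPy alone. This is an illustrative reimplementation under stated assumptions, not the author's code: `predict_overlap` and `predict_tile` are hypothetical names, `np.pad(mode="symmetric")` stands in for `cv2.copyMakeBorder(..., cv2.BORDER_REFLECT)` (both repeat the edge pixel), and the image is padded up to a multiple of the tile size so the right and bottom borders are predicted too.

```python
import numpy as np

def predict_overlap(image, predict_tile, tile=256, margin=32):
    """Redundant (overlap-tile) prediction.

    Each tile is predicted with `margin` extra context pixels on every side, and only the
    central tile x tile part of the prediction is kept, which suppresses seams at tile borders.
    predict_tile: hypothetical callable, (tile+2*margin, tile+2*margin[, C]) array -> mask of same H, W.
    """
    h, w = image.shape[:2]
    rows, cols = -(-h // tile), -(-w // tile)  # ceil division
    # mirror-pad: margin on every side, plus extra bottom/right so the tile grid covers the image
    pad = [(margin, margin + rows * tile - h), (margin, margin + cols * tile - w)]
    pad += [(0, 0)] * (image.ndim - 2)  # leave the channel axis untouched
    padded = np.pad(image, pad, mode="symmetric")
    mask = np.zeros((rows * tile, cols * tile), dtype=np.uint8)
    for i in range(rows):
        for j in range(cols):
            y, x = i * tile, j * tile
            # window in padded coordinates = original rows y-margin .. y+tile+margin
            window = padded[y:y + tile + 2 * margin, x:x + tile + 2 * margin]
            pred = predict_tile(window)
            mask[y:y + tile, x:x + tile] = pred[margin:margin + tile, margin:margin + tile]
    return mask[:h, :w]
```

The margin trades compute for quality: with tile=256 and margin=32 each window is 320×320, so about (320/256)² ≈ 1.56 times as many pixels pass through the model as with non-overlapping tiles.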