Model prediction of semantic segmentation
2022-06-13 02:32:00 【Prodigal son's private dishes】
Small-image input
Ordinary-sized images do not need to be cropped: they can be fed into the model whole, and training is end-to-end. Prediction is also straightforward. Taking binary segmentation as an example, the probability map output by the model is converted into a binary map in one of two ways. First, if the set of classes includes a background class, apply argmax across the class dimension to pick the class with the highest score at each position. Second, if the set of classes does not include a background class, apply sigmoid to squash the output into the range 0-1 and then threshold it, usually at 0.5: values greater than 0.5 are the positive class, and values less than or equal to 0.5 are background.
import os
import cv2
import torch

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
net = torch.load('./model.pth', map_location=lambda storage, loc: storage)["model"]
net = net.to(device)

# img_to_tensor and input_img_folder are defined elsewhere
imglist = os.listdir(input_img_folder)
img = cv2.imread(os.path.join(input_img_folder, imglist[400]))
tensor = torch.unsqueeze(img_to_tensor(img), dim=0).float()

with torch.no_grad():  # no gradients needed at prediction time
    predict = net(tensor.to(device))[0, 0, :, :]
predict = predict.cpu().numpy()
predict[predict <= 0.5] = 0  # background class
predict[predict > 0.5] = 1   # positive class
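The snippet above takes the sigmoid-threshold route. A minimal sketch of the argmax route, assuming raw model output `logits` of shape [1, n_classes, H, W] with the background as class 0 (the helper name is hypothetical):

```python
import torch

def logits_to_mask(logits: torch.Tensor) -> torch.Tensor:
    """Pick the most probable class at each pixel (argmax over the class dim)."""
    return torch.argmax(logits.squeeze(0), dim=0)  # -> [H, W] integer mask

# Usage with a dummy 3-class output:
logits = torch.randn(1, 3, 4, 4)
mask = logits_to_mask(logits)
assert mask.shape == (4, 4)
```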
Large-image input
When the image is large, feeding the whole thing into the model easily triggers a CUDA out-of-memory error; with remote-sensing imagery this happens all the time. The usual solution is to cut the large image into tiles and stitch the predictions back together. The steps are:
(1) Collect all image paths;
(2) In a for loop, cut each image into tiles, store them in a temporary folder (deleted once prediction finishes), and build a data generator on top of it;
(3) Using the data generator, run model prediction and stitch all the probability maps into one large probability map the same size as the original image;
(4) Convert the probability map into a binary map and colour it according to the visualization requirements;
(5) Finally, delete the temporary folder and repeat (2)-(4) for the next image.
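The tile-and-stitch bookkeeping in steps (2)-(3) can be sketched with pure NumPy; the tile size, image shape, and the threshold standing in for a real model are all made up for illustration:

```python
import numpy as np

tile = 256
img = np.random.randint(0, 256, (1024, 768, 3), dtype=np.uint8)
rows, cols = img.shape[0] // tile, img.shape[1] // tile  # 4 x 3 tiles

stitched = np.zeros((rows * tile, cols * tile), dtype=np.uint8)
for i in range(rows):
    for j in range(cols):
        patch = img[i*tile:(i+1)*tile, j*tile:(j+1)*tile]
        pred = (patch.mean(axis=2) > 127).astype(np.uint8)  # stand-in for a model
        stitched[i*tile:(i+1)*tile, j*tile:(j+1)*tile] = pred  # splice back

assert stitched.shape == (1024, 768)
```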
import argparse
import glob
import os
import shutil
from pathlib import Path

import cv2
import numpy as np
import torch
import tqdm
from torch.utils.data import DataLoader

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

## use model to predict
def predict(model):
    """Average the outputs of every model in the `model` dict (ensemble mean)."""
    result = []
    for images in tqdm.tqdm(test_loader):
        images = images.to(device)
        temp = 0
        for key in model.keys():
            model[key].eval()
            temp += model[key](images)
        preds = temp / len(model)
        preds = torch.max(preds, 1)[1]  # per-pixel class index
        result.append(preds.cpu().numpy())
    return result
def input_and_output(pic_path, model, generate_data):
    """
    Args:
        pic_path: path of the picture to predict
        model: dict of models used for prediction
        generate_data: True to cut the picture into tiles,
            False to predict the tiles and stitch the mask back together
    Note:
        step one: generate tiles from one picture
        step two: predict from the tiles generated in step one
    """
    image_size = args.crop_size
    img = cv2.imread(pic_path)
    b = args.padding_size
    # Reflect-pad the borders so edge tiles also get full context
    padding_img = cv2.copyMakeBorder(img, b, b, b, b, cv2.BORDER_REFLECT)
    row = img.shape[0] // image_size
    col = img.shape[1] // image_size
    mask_whole = np.zeros((row * image_size, col * image_size), dtype=np.uint8)
    if not generate_data:
        result = predict(model)
        # Assumes the data loader iterates the tiles in the same order as iterdir()
        map_list = [p.name for p in Path('temp_pic').iterdir()]
    for i in range(row):
        for j in range(col):
            if generate_data:
                crop = redundancy_crop(padding_img, i, j, image_size)
                cv2.imwrite(f'temp_pic/{i}_{j}.png', crop)
            else:
                temp = result[map_list.index(f'{i}_{j}.png')]
                temp = redundancy_crop2(temp)
                mask_whole[i*image_size:(i+1)*image_size,
                           j*image_size:(j+1)*image_size] = temp
    return mask_whole
def redundancy_crop(img, i, j, targetSize):
    """Cut tile (i, j) plus `padding_size` pixels of overlap on every side."""
    b = args.padding_size
    return img[i*targetSize:(i+1)*targetSize + 2*b,
               j*targetSize:(j+1)*targetSize + 2*b, :]

def redundancy_crop2(img):
    """Drop the overlapping border from a predicted tile of shape [1, H, W]."""
    b = args.padding_size
    h, w = img.shape[1], img.shape[2]
    return img[0, b:h-b, b:w-b]  # squeeze the batch dim and keep the centre
def get_dataset_loaders(workers):
    # urban3dDWM and `path` are defined elsewhere
    batch_size = 1
    test_dataset = urban3dDWM(os.path.join(path), './', test=True)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, num_workers=workers)
    return test_loader
def get_labels():
    """Map each class index to a label color.

    Returns:
        np.ndarray with dimensions (2, 3)
    """
    return np.asarray([
        [0, 0, 0],        # background -> black
        [255, 255, 255],  # foreground -> white
    ])
def decode_segmap(label_mask, n_classes):
    """Decode segmentation class labels into a color image.

    Args:
        label_mask (np.ndarray): an (M, N) array of integer values denoting
            the class label at each spatial location.
        n_classes (int): number of classes in the mask.

    Returns:
        np.ndarray: the resulting (M, N, 3) decoded color image.
    """
    label_colours = get_labels()
    r = label_mask.copy()
    g = label_mask.copy()
    b = label_mask.copy()
    for ll in range(0, n_classes):
        r[label_mask == ll] = label_colours[ll, 0]
        g[label_mask == ll] = label_colours[ll, 1]
        b[label_mask == ll] = label_colours[ll, 2]
    rgb = np.zeros((label_mask.shape[0], label_mask.shape[1], 3))
    rgb[:, :, 0] = r
    rgb[:, :, 1] = g
    rgb[:, :, 2] = b
    return rgb
if __name__ == "__main__":
    parse = argparse.ArgumentParser()
    parse.add_argument("--n_class", type=int, default=2, help="the number of classes")
    parse.add_argument("--model_name", type=str, default='UNet', help="UNet, PSPNet or FPN")
    parse.add_argument("--n_workers", type=int, default=4, help="the number of data-loading workers")
    parse.add_argument("--crop_size", type=int, default=256, help="side length of each tile")
    parse.add_argument("--padding_size", type=int, default=32, help="overlap added around each tile")
    args = parse.parse_args()

    # model_groups = ["UNet", "PSPNet", "FPN"]  # predict with more models
    model_groups = ["UNet"]
    models = {}
    for item in model_groups:
        models[item] = torch.load(f'./results_{item}2/{item}_weights_best.pth')["model_state"]

    imgList = glob.glob("./valid/*RGB.tif")
    save_path = f'./predict_{args.model_name}'
    if not os.path.exists(save_path):
        os.makedirs(save_path)

    for i in tqdm.tqdm(range(len(imgList))):
        if not os.path.exists('temp_pic'):
            os.makedirs('temp_pic')
        # First pass: cut the picture into tiles
        input_and_output(imgList[i], models, generate_data=True)
        name = os.path.split(imgList[i])[-1].split(".")[0]
        test_loader = get_dataset_loaders(args.n_workers)
        # Second pass: predict the tiles and stitch the mask back together
        mask_result = input_and_output(imgList[i], models, generate_data=False)
        # Recursively delete the temporary tile folder
        try:
            shutil.rmtree('temp_pic')
        except OSError:
            pass
        decoded = decode_segmap(mask_result, args.n_class)
        cv2.imwrite(f'{save_path}/{name}.png', decoded)
To avoid a gridding effect in the predicted map, the code above uses redundant (overlapping) prediction.
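The idea can be sketched in isolation: each tile is cut with `padding_size` extra pixels of context on every side, and after prediction only the centre is kept, so tile borders never reach the stitched output (`centre_crop` is a hypothetical helper; the values mirror the script's defaults):

```python
import numpy as np

pad, tile = 32, 256

def centre_crop(pred: np.ndarray, pad: int) -> np.ndarray:
    """Drop `pad` pixels on every side of an over-sized tile prediction."""
    return pred[pad:-pad, pad:-pad]

oversized = np.ones((tile + 2 * pad, tile + 2 * pad), dtype=np.uint8)
kept = centre_crop(oversized, pad)
assert kept.shape == (tile, tile)
```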