当前位置:网站首页>PyTorch --- using PyTorch to implement LinkNet for semantic segmentation
PyTorch --- using PyTorch to implement LinkNet for semantic segmentation
2022-07-04 23:26:00 【Brother Shui is very water】
1. The dataset used in the code can be obtained through the following link
Baidu Netdisk extraction code: f1j7
2. Runtime environment
Pytorch-gpu==1.10.1
Python==3.8
3. The dataset processing code is as follows (the training and prediction scripts import it as the module data_loader, i.e. data_loader.py)
import os
import torch
from torch.utils import data
from torchvision import transforms
from PIL import Image
import matplotlib.pyplot as plt
from torchvision.utils import draw_segmentation_masks
class MaskDataset(data.Dataset):
    """Dataset that pairs RGB images with binary segmentation masks.

    Each item is an ``(image_tensor, label_tensor)`` tuple. The image is
    converted to RGB and passed through ``transform``. The mask is passed
    through the same ``transform``, binarised (every value greater than 0
    becomes 1) and returned with dtype ``torch.long``.
    """

    def __init__(self, image_paths, mask_paths, transform):
        """Store the sample paths and the shared transform.

        Args:
            image_paths: Sequence of image file paths.
            mask_paths: Sequence of mask file paths, index-aligned with
                ``image_paths``.
            transform: Callable applied to both the PIL image and the PIL
                mask; it must return a tensor.
        """
        super().__init__()
        self.image_paths = image_paths
        self.mask_paths = mask_paths
        self.transform = transform

    def __getitem__(self, index):
        """Load, transform and return the sample at ``index``."""
        # Context managers close the underlying file handles promptly; the
        # bare ``Image.open`` calls left them open until garbage collection.
        with Image.open(self.image_paths[index]) as pil_img:
            img_tensor = self.transform(pil_img.convert('RGB'))
        with Image.open(self.mask_paths[index]) as pil_label:
            label_tensor = self.transform(pil_label)

        # Any non-zero mask value counts as foreground (class 1).
        label_tensor[label_tensor > 0] = 1
        # Only drop the leading channel axis. A bare ``squeeze()`` would also
        # remove a spatial axis of size 1.
        # Assumes the transform yields a (1, H, W) mask tensor - TODO confirm
        # for mask files that are not single-channel.
        label_tensor = torch.squeeze(label_tensor, dim=0).long()
        return img_tensor, label_tensor

    def __len__(self):
        """Return the number of samples (the number of mask paths)."""
        return len(self.mask_paths)
def _collect_pairs(directory):
    """Return ``(image_paths, mask_paths)`` for every usable pair in ``directory``.

    A file is treated as a mask when its name contains ``'matte'`` and the
    directory also holds ``<prefix>.png``, where ``<prefix>`` is the part of
    the mask name before the first underscore.
    """
    # Sorted so sample order does not depend on the OS directory listing order.
    file_names = sorted(os.listdir(directory))
    # Set membership keeps the pairing linear instead of quadratic.
    known_names = set(file_names)
    image_paths = []
    mask_paths = []
    for name in file_names:
        image_name = name.split('_')[0] + '.png'
        if 'matte' in name and image_name in known_names:
            image_paths.append(os.path.join(directory, image_name))
            mask_paths.append(os.path.join(directory, name))
    return image_paths, mask_paths


def load_data(dataset_path=r'/Users/leeakita/Desktop/hk', batch_size=8, image_size=(256, 256)):
    """Build the training and testing data loaders.

    Args:
        dataset_path: Root directory containing the ``training`` and
            ``testing`` sub-directories. Defaults to the path that was
            previously hard-coded, so existing ``load_data()`` calls behave
            as before.
        batch_size: Batch size used by both loaders.
        image_size: ``(height, width)`` every image and mask is resized to.

    Returns:
        Tuple ``(train_dl, test_dl)``. Only the training loader shuffles.
    """
    train_image_paths, train_label_paths = _collect_pairs(os.path.join(dataset_path, 'training'))
    test_image_paths, test_label_paths = _collect_pairs(os.path.join(dataset_path, 'testing'))

    transform = transforms.Compose([
        transforms.Resize(image_size),
        transforms.ToTensor(),
    ])

    train_ds = MaskDataset(image_paths=train_image_paths, mask_paths=train_label_paths, transform=transform)
    test_ds = MaskDataset(image_paths=test_image_paths, mask_paths=test_label_paths, transform=transform)
    train_dl = data.DataLoader(dataset=train_ds, batch_size=batch_size, shuffle=True)
    test_dl = data.DataLoader(dataset=test_ds, batch_size=batch_size)
    return train_dl, test_dl
if __name__ == '__main__':
    # Visual sanity check: overlay one training mask on its image.
    train_my, test_my = load_data()
    images, labels = next(iter(train_my))
    # Clamp the sample index so a batch smaller than 6 does not raise IndexError.
    indexx = min(5, images.shape[0] - 1)
    images = images[indexx]
    labels = labels[indexx]
    # draw_segmentation_masks is given the mask with a leading mask axis.
    labels = torch.unsqueeze(input=labels, dim=0)
    result = draw_segmentation_masks(image=torch.as_tensor(data=images * 255, dtype=torch.uint8),
                                     masks=torch.as_tensor(data=labels, dtype=torch.bool),
                                     alpha=0.6, colors=['red'])
    # Channels-first tensor -> channels-last array for matplotlib.
    plt.imshow(result.permute(1, 2, 0).numpy())
    plt.show()
4. The model construction code is as follows (the training and prediction scripts import it as the module model_loader, i.e. model_loader.py)
from torch import nn
import torch
class ConvBlock(nn.Module):
    """Conv2d -> BatchNorm2d -> ReLU, applied in sequence."""

    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=1):
        """Create the three-layer sequence.

        Args:
            in_channels: Number of input channels of the convolution.
            out_channels: Number of output channels (also the BatchNorm width).
            kernel_size: Convolution kernel size.
            stride: Convolution stride.
            padding: Convolution zero-padding.
        """
        super().__init__()
        self.conv_bn_relu = nn.Sequential(
            nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size,
                      stride=stride, padding=padding),
            nn.BatchNorm2d(num_features=out_channels),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        """Return ``x`` after convolution, batch normalisation and ReLU."""
        return self.conv_bn_relu(x)
class DecodeConvBlock(nn.Module):
    """Transposed convolution, optionally followed by BatchNorm2d and ReLU.

    With the default arguments (kernel 3, stride 2, padding 1, output
    padding 1) the spatial size of the input is doubled.
    """

    def __init__(self, in_channels, out_channels, kernel_size=3, stride=2, padding=1, out_padding=1):
        """Create the transposed convolution and its batch-norm layer.

        Args:
            in_channels: Number of input channels.
            out_channels: Number of output channels.
            kernel_size: Transposed-convolution kernel size.
            stride: Transposed-convolution stride.
            padding: Transposed-convolution padding.
            out_padding: Passed to ``ConvTranspose2d`` as ``output_padding``.
        """
        super().__init__()
        self.de_conv = nn.ConvTranspose2d(in_channels=in_channels, out_channels=out_channels,
                                          kernel_size=kernel_size, stride=stride, padding=padding,
                                          output_padding=out_padding)
        self.bn = nn.BatchNorm2d(num_features=out_channels)

    def forward(self, x, is_act=True):
        """Upsample ``x``.

        Args:
            x: Input feature map.
            is_act: When true, batch normalisation and ReLU follow the
                transposed convolution. When false, BOTH are skipped and the
                raw transposed-convolution output is returned.
        """
        x = self.de_conv(x)
        if is_act:
            x = torch.relu(self.bn(x))
        return x
class EncodeBlock(nn.Module):
    """Encoder stage: four 3x3 ConvBlocks plus a strided shortcut ConvBlock.

    ``conv1`` and ``short_cut`` both use stride 2, so the stage halves the
    spatial size of its input (for even input sizes).
    """

    def __init__(self, in_channels, out_channels):
        """Create the stage.

        Args:
            in_channels: Channels of the incoming feature map.
            out_channels: Channels produced by every ConvBlock in the stage.
        """
        super().__init__()
        self.conv1 = ConvBlock(in_channels=in_channels, out_channels=out_channels, stride=2)
        self.conv2 = ConvBlock(in_channels=out_channels, out_channels=out_channels)
        self.conv3 = ConvBlock(in_channels=out_channels, out_channels=out_channels)
        self.conv4 = ConvBlock(in_channels=out_channels, out_channels=out_channels)
        self.short_cut = ConvBlock(in_channels=in_channels, out_channels=out_channels, stride=2)

    def forward(self, x):
        """Return the encoded feature map for ``x``."""
        # First pair of convolutions (downsampling happens in conv1).
        out1 = self.conv2(self.conv1(x))
        # Projection shortcut so the input matches out1 in shape.
        short_cut = self.short_cut(x)
        # Second pair of convolutions operates on the first residual sum.
        out2 = self.conv4(self.conv3(out1 + short_cut))
        # NOTE(review): the result adds out1 (not out1 + short_cut) to out2;
        # kept as-is so existing checkpoints stay valid - confirm intended.
        return out1 + out2
class DecodeBlock(nn.Module):
    """Decoder stage: 1x1 ConvBlock -> DecodeConvBlock -> 1x1 ConvBlock.

    The first 1x1 convolution reduces the channels to ``in_channels // 4``,
    the transposed convolution (default DecodeConvBlock settings) doubles the
    spatial size, and the last 1x1 convolution maps to ``out_channels``.
    """

    def __init__(self, in_channels, out_channels):
        """Create the stage.

        Args:
            in_channels: Channels of the incoming feature map.
            out_channels: Channels of the produced feature map.
        """
        super().__init__()
        bottleneck = in_channels // 4
        self.conv1 = ConvBlock(in_channels=in_channels, out_channels=bottleneck, kernel_size=1, padding=0)
        self.de_conv = DecodeConvBlock(in_channels=bottleneck, out_channels=bottleneck)
        self.conv3 = ConvBlock(in_channels=bottleneck, out_channels=out_channels, kernel_size=1, padding=0)

    def forward(self, x):
        """Return the upsampled feature map for ``x``."""
        x = self.conv1(x)
        x = self.de_conv(x)
        return self.conv3(x)
class LinkNet(nn.Module):
    """LinkNet-style encoder/decoder producing 2-class segmentation logits.

    The input passes through six stride-2 stages (initial convolution, max
    pooling and four encoder stages) and is upsampled back by four decoder
    stages and two final transposed convolutions. Each decoder input is the
    sum of the previous decoder output and the matching encoder output, so
    input heights and widths that are multiples of 64 keep those additions
    shape-compatible.
    """

    def __init__(self):
        """Create all layers; the network takes 3-channel input."""
        super().__init__()
        self.init_conv = ConvBlock(in_channels=3, out_channels=64, stride=2, kernel_size=7, padding=3)
        self.init_maxpool = nn.MaxPool2d(kernel_size=(2, 2))

        self.encode_1 = EncodeBlock(in_channels=64, out_channels=64)
        self.encode_2 = EncodeBlock(in_channels=64, out_channels=128)
        self.encode_3 = EncodeBlock(in_channels=128, out_channels=256)
        self.encode_4 = EncodeBlock(in_channels=256, out_channels=512)

        self.decode_4 = DecodeBlock(in_channels=512, out_channels=256)
        self.decode_3 = DecodeBlock(in_channels=256, out_channels=128)
        self.decode_2 = DecodeBlock(in_channels=128, out_channels=64)
        self.decode_1 = DecodeBlock(in_channels=64, out_channels=64)

        self.deconv_out1 = DecodeConvBlock(in_channels=64, out_channels=32)
        self.conv_out = ConvBlock(in_channels=32, out_channels=32)
        self.deconv_out2 = DecodeConvBlock(in_channels=32, out_channels=2, kernel_size=2, padding=0,
                                           out_padding=0)

    def forward(self, x):
        """Return per-pixel class logits with 2 channels for ``x``."""
        x = self.init_conv(x)
        x = self.init_maxpool(x)

        e1 = self.encode_1(x)
        e2 = self.encode_2(e1)
        e3 = self.encode_3(e2)
        e4 = self.encode_4(e3)

        # Skip connections: add the encoder output of matching resolution.
        d4 = self.decode_4(e4)
        d3 = self.decode_3(d4 + e3)
        d2 = self.decode_2(d3 + e2)
        d1 = self.decode_1(d2 + e1)

        f1 = self.deconv_out1(d1)
        f2 = self.conv_out(f1)
        # The last layer must emit raw logits: the previous default call
        # pushed them through BatchNorm + ReLU, clamping every logit to be
        # non-negative before the cross-entropy loss. The layer's bn
        # parameters stay registered, so state_dict keys are unchanged, but
        # weights trained with the old behaviour should be retrained.
        return self.deconv_out2(f2, is_act=False)
5. The training code of the model is as follows
import torch
from data_loader import load_data
from model_loader import LinkNet
from torch import nn
from torch import optim
import tqdm
import os
# Configuration of environment variables
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


def _batch_iou(logits, labels):
    """Return the foreground IoU of one batch as a Python float.

    The predicted mask is the arg-max over dimension 1 of ``logits``. When
    neither the prediction nor the labels contain any foreground pixel the
    union is empty; 1.0 is returned instead of the NaN that 0/0 would give.
    """
    pred_mask = torch.argmax(input=logits, dim=1)
    intersection = torch.logical_and(input=labels, other=pred_mask).sum().item()
    union = torch.logical_or(input=labels, other=pred_mask).sum().item()
    if union == 0:
        return 1.0
    return intersection / union


# Load data
train_dl, test_dl = load_data()
# Load model
model = LinkNet()
model = model.to(device=device)
# Training related configurations
loss_fn = nn.CrossEntropyLoss()
optimizer = optim.Adam(params=model.parameters(), lr=0.001)
lr_scheduler = optim.lr_scheduler.StepLR(optimizer=optimizer, step_size=5, gamma=0.7)
# Start training
for epoch in range(100):
    # BatchNorm layers must use batch statistics while training.
    model.train()
    train_tqdm = tqdm.tqdm(iterable=train_dl, total=len(train_dl))
    train_tqdm.set_description_str('Train epoch: {:3d}'.format(epoch))
    train_loss_total, train_iou_total, train_batches = 0.0, 0.0, 0
    for train_images, train_labels in train_tqdm:
        train_images, train_labels = train_images.to(device), train_labels.to(device)
        pred = model(train_images)
        loss = loss_fn(pred, train_labels)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # Accumulate plain floats so no autograd graph is kept alive.
        train_loss_total += loss.item()
        train_iou_total += _batch_iou(pred.detach(), train_labels)
        train_batches += 1
        train_tqdm.set_postfix({
            'train loss': train_loss_total / train_batches,
            'train iou': train_iou_total / train_batches
        })
    train_tqdm.close()
    lr_scheduler.step()

    # Evaluation must use the BatchNorm running statistics and must not
    # update them with test data.
    model.eval()
    with torch.no_grad():
        test_tqdm = tqdm.tqdm(iterable=test_dl, total=len(test_dl))
        test_tqdm.set_description_str('Test epoch: {:3d}'.format(epoch))
        test_loss_total, test_iou_total, test_batches = 0.0, 0.0, 0
        for test_images, test_labels in test_tqdm:
            test_images, test_labels = test_images.to(device), test_labels.to(device)
            test_pred = model(test_images)
            # CrossEntropyLoss expects raw logits; the earlier extra softmax
            # made the test loss incomparable with the training loss.
            test_loss = loss_fn(test_pred, test_labels)
            test_loss_total += test_loss.item()
            test_iou_total += _batch_iou(test_pred, test_labels)
            test_batches += 1
            test_tqdm.set_postfix({
                'test loss': test_loss_total / test_batches,
                'test iou': test_iou_total / test_batches
            })
        test_tqdm.close()

    # Save model
    # Written after every epoch so an interrupted run still leaves the latest
    # weights on disk; the file left after the last epoch is the final model.
    os.makedirs('model_data', exist_ok=True)
    torch.save(model.state_dict(), os.path.join('model_data', 'model.pth'))
6. The prediction code of the model is as follows
import torch
import os
import matplotlib.pyplot as plt
from torchvision.utils import draw_segmentation_masks
from data_loader import load_data
from model_loader import LinkNet
# Data loading
train_dl, test_dl = load_data()
# Model loading
model = LinkNet()
model_state_dict = torch.load(os.path.join('model_data', 'model.pth'), map_location='cpu')
model.load_state_dict(model_state_dict)
# Inference must use the BatchNorm running statistics rather than the
# statistics of the single batch being predicted.
model.eval()
# Start Forecasting
images, labels = next(iter(test_dl))
# Clamp so a first batch with fewer than 3 samples does not raise IndexError.
index = min(2, images.shape[0] - 1)
with torch.no_grad():
    pred = model(images)
    # Arg-max over dimension 1 of the model output picks the class per pixel.
    pred = torch.argmax(input=pred, dim=1)
result = draw_segmentation_masks(image=torch.as_tensor(data=images[index] * 255, dtype=torch.uint8),
                                 masks=torch.as_tensor(data=pred[index], dtype=torch.bool),
                                 alpha=0.8, colors=['red'])
plt.figure(figsize=(8, 8), dpi=500)
plt.axis('off')
# Channels-first tensor -> channels-last ndarray, as in the data-loader preview.
plt.imshow(result.permute(1, 2, 0).numpy())
plt.savefig('result.png')
plt.show()
7. The running result of the code is as follows

边栏推荐
- Explanation of bitwise operators
- 解决无法通过ssh服务远程连接虚拟机
- Object detection based on OpenCV haarcascades
- phpcms付费阅读功能支付宝支付
- The initial arrangement of particles in SPH (solved by two pictures)
- [JS] - [dynamic planning] - Notes
- Wechat official account solves the cache problem of entering from the customized menu
- 如何在外地外网电脑远程公司项目?
- 头文件重复定义问题解决“C1014错误“
- Basic knowledge of database
猜你喜欢
![[sword finger offer] questions 1-5](/img/54/b70d5290978e842939db99645c6ada.png)
[sword finger offer] questions 1-5

取得PMP证书需要多长时间?

ICML 2022 | 3dlinker: e (3) equal variation self encoder for molecular link design

Editplus-- usage -- shortcut key / configuration / background color / font size

D3.js+Three. JS data visualization 3D Earth JS special effect

CTF竞赛题解之stm32逆向入门

Redis:Redis消息的发布与订阅(了解)

45岁教授,她投出2个超级独角兽

PMP证书续证流程

Why does infographic help your SEO
随机推荐
法国学者:最优传输理论下对抗攻击可解释性探讨
Stm32 Reverse Introduction to CTF Competition Interpretation
Recommended collection: build a cross cloud data warehouse environment, which is particularly dry!
ETCD数据库源码分析——处理Entry记录简要流程
Basic use and upgrade of Android native database
Network namespace
时间 (计算)总工具类 例子: 今年开始时间和今年结束时间等
Combien de temps faut - il pour obtenir un certificat PMP?
Is the account opening link of Huatai Securities with low commission safe?
debug和release的区别
JS 3D explosive fragment image switching JS special effect
【监控】zabbix
ScriptableObject
初试为锐捷交换机跨设备型号升级版本(以RG-S2952G-E为例)
LabVIEW中比较两个VI
CTF competition problem solution STM32 reverse introduction
字体设计符号组合多功能微信小程序源码
QT addition calculator (simple case)
PS style JS webpage graffiti board plug-in
PaddleOCR教程