当前位置:网站首页>Pytorch---使用Pytorch实现LinkNet进行语义分割
Pytorch---使用Pytorch实现LinkNet进行语义分割
2022-07-04 19:42:00 【水哥很水】
一、代码中使用的数据集可以通过以下链接获取(注:原页面此处未包含下载链接)
二、代码运行环境
PyTorch(GPU 版本)== 1.10.1
Python==3.8
三、数据集处理代码如下所示
import os
import torch
from torch.utils import data
from torchvision import transforms
from PIL import Image
import matplotlib.pyplot as plt
from torchvision.utils import draw_segmentation_masks
class MaskDataset(data.Dataset):
    """Dataset of (image, binary mask) pairs for semantic segmentation.

    Args:
        image_paths: Paths to the input images; ``image_paths[i]`` is paired
            with ``mask_paths[i]``.
        mask_paths: Paths to the mask ("matte") images.
        transform: Callable applied to the RGB-converted image and, unless
            ``mask_transform`` is given, to the mask as well.
        mask_transform: Optional callable applied to the mask instead of
            ``transform``. Bilinear resizing (torchvision's default for
            ``Resize``) blends values at mask borders, and the ``> 0``
            threshold below then counts those blended pixels as foreground.
            Passing a nearest-neighbour resize here avoids that.
    """

    def __init__(self, image_paths, mask_paths, transform, mask_transform=None):
        super(MaskDataset, self).__init__()
        self.image_paths = image_paths
        self.mask_paths = mask_paths
        self.transform = transform
        # Fall back to the image transform so existing callers behave exactly as before.
        self.mask_transform = transform if mask_transform is None else mask_transform

    def __getitem__(self, index):
        """Return ``(img_tensor, label_tensor)`` for sample ``index``.

        Every positive mask value is set to 1, the leading (channel) axis is
        removed, and the mask is returned as an int64 tensor.
        """
        # Open inside ``with`` so the file handles are closed even if a transform raises.
        with Image.open(self.image_paths[index]) as pil_img:
            img_tensor = self.transform(pil_img.convert('RGB'))
        with Image.open(self.mask_paths[index]) as pil_label:
            label_tensor = self.mask_transform(pil_label)
        label_tensor[label_tensor > 0] = 1
        # Squeeze only the channel axis; a bare squeeze() would also drop a
        # spatial axis that happens to have size 1.
        label_tensor = torch.squeeze(label_tensor, dim=0).long()
        return img_tensor, label_tensor

    def __len__(self):
        """Number of (image, mask) pairs."""
        return len(self.mask_paths)
def load_data(dataset_path=r'/Users/leeakita/Desktop/hk', batch_size=8):
    """Build the training and testing DataLoaders.

    Expects ``<dataset_path>/training`` and ``<dataset_path>/testing``
    directories. In each, every file whose name contains ``'matte'`` is
    treated as a mask. It is paired with the image ``<prefix>.png``, where
    ``<prefix>`` is the part of the mask's file name before the first ``'_'``.
    Masks without a matching image are skipped.

    Args:
        dataset_path: Root directory of the dataset (defaults to the
            original hard-coded location).
        batch_size: Batch size used by both loaders.

    Returns:
        ``(train_dl, test_dl)``. The training loader shuffles; the test
        loader does not.
    """
    transform = transforms.Compose([
        transforms.Resize((256, 256)),
        transforms.ToTensor()
    ])

    def _make_dataset(split):
        # Pair every mask in ``<dataset_path>/<split>`` with its image.
        directory = os.path.join(dataset_path, split)
        # sorted() makes the (unshuffled) test order reproducible across machines.
        file_names = sorted(os.listdir(directory))
        available = set(file_names)  # O(1) membership test per mask
        mask_names = [name for name in file_names
                      if 'matte' in name and name.split('_')[0] + '.png' in available]
        image_paths = [os.path.join(directory, name.split('_')[0] + '.png') for name in mask_names]
        label_paths = [os.path.join(directory, name) for name in mask_names]
        return MaskDataset(image_paths=image_paths, mask_paths=label_paths, transform=transform)

    train_ds = _make_dataset('training')
    test_ds = _make_dataset('testing')
    train_dl = data.DataLoader(dataset=train_ds, batch_size=batch_size, shuffle=True)
    test_dl = data.DataLoader(dataset=test_ds, batch_size=batch_size)
    return train_dl, test_dl
if __name__ == '__main__':
    # Visual sanity check: overlay one ground-truth mask from a training batch on its image.
    train_my, test_my = load_data()
    images, labels = next(iter(train_my))
    # Clamp the sample index so the preview still works when the batch has fewer than 6 samples.
    sample_index = min(5, images.shape[0] - 1)
    image = images[sample_index]
    label = labels[sample_index]
    # draw_segmentation_masks takes a uint8 (3, H, W) image and a bool (num_masks, H, W) mask.
    label = torch.unsqueeze(input=label, dim=0)
    result = draw_segmentation_masks(image=torch.as_tensor(data=image * 255, dtype=torch.uint8),
                                     masks=torch.as_tensor(data=label, dtype=torch.bool),
                                     alpha=0.6, colors=['red'])
    plt.imshow(result.permute(1, 2, 0).numpy())
    plt.show()
四、模型的构建代码如下所示
from torch import nn
import torch
class ConvBlock(nn.Module):
    """Conv2d -> BatchNorm2d -> ReLU (in place), applied in sequence.

    ``kernel_size``, ``stride`` and ``padding`` are forwarded unchanged to
    ``nn.Conv2d``; the defaults (3, 1, 1) keep the spatial size.
    """

    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=1):
        super().__init__()
        conv = nn.Conv2d(in_channels=in_channels, out_channels=out_channels,
                         kernel_size=kernel_size, stride=stride, padding=padding)
        norm = nn.BatchNorm2d(num_features=out_channels)
        act = nn.ReLU(inplace=True)
        # The attribute name fixes the state_dict keys of saved checkpoints; keep it.
        self.conv_bn_relu = nn.Sequential(*[conv, norm, act])

    def forward(self, x):
        """Apply convolution, normalization and activation to ``x``."""
        pipeline = self.conv_bn_relu
        return pipeline(x)
class DecodeConvBlock(nn.Module):
    """Transposed convolution, optionally followed by BatchNorm2d and ReLU.

    With the default hyper-parameters (kernel 3, stride 2, padding 1,
    output_padding 1) the output height and width are twice the input's.
    """

    def __init__(self, in_channels, out_channels, kernel_size=3, stride=2, padding=1, out_padding=1):
        super().__init__()
        self.de_conv = nn.ConvTranspose2d(in_channels=in_channels, out_channels=out_channels,
                                          kernel_size=kernel_size, stride=stride,
                                          padding=padding, output_padding=out_padding)
        self.bn = nn.BatchNorm2d(num_features=out_channels)

    def forward(self, x, is_act=True):
        """Upsample ``x``; with ``is_act=False`` both BatchNorm and ReLU are skipped."""
        upsampled = self.de_conv(x)
        if not is_act:
            return upsampled
        return torch.relu(self.bn(upsampled))
class EncodeBlock(nn.Module):
    """LinkNet encoder stage made of two residual units; halves H and W.

    Unit 1: ``r1 = conv2(conv1(x)) + short_cut(x)``. ``conv1`` and the
    ``short_cut`` projection both use stride 2, so the two terms have the
    same shape.
    Unit 2: ``r1 + conv4(conv3(r1))``.
    """

    def __init__(self, in_channels, out_channels):
        super(EncodeBlock, self).__init__()
        self.conv1 = ConvBlock(in_channels=in_channels, out_channels=out_channels, stride=2)
        self.conv2 = ConvBlock(in_channels=out_channels, out_channels=out_channels)
        self.conv3 = ConvBlock(in_channels=out_channels, out_channels=out_channels)
        self.conv4 = ConvBlock(in_channels=out_channels, out_channels=out_channels)
        self.short_cut = ConvBlock(in_channels=in_channels, out_channels=out_channels, stride=2)

    def forward(self, x):
        residual = self.conv2(self.conv1(x)) + self.short_cut(x)
        # The identity path of the second unit is that unit's actual input
        # (``residual``). The previous code added only the conv branch of
        # unit 1, so the shortcut projection reached the output only through
        # conv3/conv4.
        return residual + self.conv4(self.conv3(residual))
class DecodeBlock(nn.Module):
    """LinkNet decoder stage: 1x1 reduce -> 2x transposed-conv upsample -> 1x1 expand.

    The bottleneck width is ``in_channels // 4``. The output has
    ``out_channels`` channels and twice the input's height and width (the
    default ``DecodeConvBlock`` upsampling).
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        mid_channels = in_channels // 4
        self.conv1 = ConvBlock(in_channels=in_channels, out_channels=mid_channels, kernel_size=1, padding=0)
        self.de_conv = DecodeConvBlock(in_channels=mid_channels, out_channels=mid_channels)
        self.conv3 = ConvBlock(in_channels=mid_channels, out_channels=out_channels, kernel_size=1, padding=0)

    def forward(self, x):
        """Reduce channels, upsample, then project to ``out_channels``."""
        return self.conv3(self.de_conv(self.conv1(x)))
class LinkNet(nn.Module):
    """LinkNet-style encoder/decoder for semantic segmentation.

    Args:
        in_channels: Channels of the input image (default 3).
        num_classes: Number of output channels, i.e. classes (default 2).

    ``forward`` maps ``(N, in_channels, H, W)`` to raw, unnormalized class
    logits of shape ``(N, num_classes, H, W)``. H and W should be multiples
    of 64 (init conv, max-pool and four encoders each halve them) so that
    the decoder outputs line up with the encoder skip connections.
    """

    def __init__(self, in_channels=3, num_classes=2):
        super(LinkNet, self).__init__()
        self.init_conv = ConvBlock(in_channels=in_channels, out_channels=64, stride=2, kernel_size=7, padding=3)
        self.init_maxpool = nn.MaxPool2d(kernel_size=(2, 2))
        self.encode_1 = EncodeBlock(in_channels=64, out_channels=64)
        self.encode_2 = EncodeBlock(in_channels=64, out_channels=128)
        self.encode_3 = EncodeBlock(in_channels=128, out_channels=256)
        self.encode_4 = EncodeBlock(in_channels=256, out_channels=512)
        self.decode_4 = DecodeBlock(in_channels=512, out_channels=256)
        self.decode_3 = DecodeBlock(in_channels=256, out_channels=128)
        self.decode_2 = DecodeBlock(in_channels=128, out_channels=64)
        self.decode_1 = DecodeBlock(in_channels=64, out_channels=64)
        self.deconv_out1 = DecodeConvBlock(in_channels=64, out_channels=32)
        self.conv_out = ConvBlock(in_channels=32, out_channels=32)
        # Kernel 2 / stride 2 / no padding: exact 2x upsample back to the input resolution.
        # Its ``bn`` submodule is still created, so existing checkpoints load unchanged.
        self.deconv_out2 = DecodeConvBlock(in_channels=32, out_channels=num_classes, kernel_size=2,
                                           padding=0, out_padding=0)

    def forward(self, x):
        x = self.init_conv(x)
        x = self.init_maxpool(x)
        e1 = self.encode_1(x)
        e2 = self.encode_2(e1)
        e3 = self.encode_3(e2)
        e4 = self.encode_4(e3)
        # Decoder with additive skip connections from the matching encoder stage.
        d4 = self.decode_4(e4)
        d3 = self.decode_3(d4 + e3)
        d2 = self.decode_2(d3 + e2)
        d1 = self.decode_1(d2 + e1)
        f1 = self.deconv_out1(d1)
        f2 = self.conv_out(f1)
        # Classifier layer: return raw logits (no BatchNorm/ReLU). With ReLU,
        # every logit was clamped to >= 0, and clamped logits got no gradient.
        # CrossEntropyLoss and argmax both expect unnormalized logits.
        f3 = self.deconv_out2(f2, is_act=False)
        return f3
五、模型的训练代码如下所示
import torch
from data_loader import load_data
from model_loader import LinkNet
from torch import nn
from torch import optim
import tqdm
import os
# Run on the GPU when one is available.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Load the data.
train_dl, test_dl = load_data()
# Build the model.
model = LinkNet()
model = model.to(device=device)
# Training configuration: the learning rate is multiplied by 0.7 every 5 epochs.
loss_fn = nn.CrossEntropyLoss()
optimizer = optim.Adam(params=model.parameters(), lr=0.001)
lr_scheduler = optim.lr_scheduler.StepLR(optimizer=optimizer, step_size=5, gamma=0.7)


def _batch_iou(labels, logits):
    """Foreground IoU over a whole batch, as a Python float.

    ``labels`` are per-pixel class indices (nonzero = foreground). The
    prediction is the argmax of ``logits`` over dim 1. A batch where neither
    the labels nor the prediction contain foreground counts as a perfect
    match (1.0) instead of producing NaN (0 / 0).
    """
    pred_mask = torch.argmax(logits, dim=1)
    intersection = torch.logical_and(labels, pred_mask).sum()
    union = torch.logical_or(labels, pred_mask).sum()
    if union == 0:
        return 1.0
    return (intersection.float() / union.float()).item()


# Training loop.
for epoch in range(100):
    # train() re-enables BatchNorm batch statistics after the previous evaluation pass.
    model.train()
    train_tqdm = tqdm.tqdm(iterable=train_dl, total=len(train_dl))
    train_tqdm.set_description_str('Train epoch: {:3d}'.format(epoch))
    # Plain lists instead of repeated torch.cat, which copies the whole history every batch.
    train_losses, train_ious = [], []
    for train_images, train_labels in train_tqdm:
        train_images, train_labels = train_images.to(device), train_labels.to(device)
        pred = model(train_images)
        loss = loss_fn(pred, train_labels)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        with torch.no_grad():
            train_ious.append(_batch_iou(train_labels, pred))
        train_losses.append(loss.item())
        train_tqdm.set_postfix({
            'train loss': sum(train_losses) / len(train_losses),
            'train iou': sum(train_ious) / len(train_ious)
        })
    train_tqdm.close()
    lr_scheduler.step()

    # eval() makes BatchNorm use its running statistics, and stops test
    # batches from updating them.
    model.eval()
    with torch.no_grad():
        test_tqdm = tqdm.tqdm(iterable=test_dl, total=len(test_dl))
        test_tqdm.set_description_str('Test epoch: {:3d}'.format(epoch))
        test_losses, test_ious = [], []
        for test_images, test_labels in test_tqdm:
            test_images, test_labels = test_images.to(device), test_labels.to(device)
            test_pred = model(test_images)
            # CrossEntropyLoss applies log-softmax itself, so it takes raw
            # logits. Passing softmax probabilities gave a loss that was not
            # comparable to the training loss.
            test_loss = loss_fn(test_pred, test_labels)
            test_ious.append(_batch_iou(test_labels, test_pred))
            test_losses.append(test_loss.item())
            test_tqdm.set_postfix({
                'test loss': sum(test_losses) / len(test_losses),
                'test iou': sum(test_ious) / len(test_ious)
            })
        test_tqdm.close()

# Save the trained weights.
os.makedirs('model_data', exist_ok=True)
torch.save(model.state_dict(), os.path.join('model_data', 'model.pth'))
六、模型的预测代码如下所示
import torch
import os
import matplotlib.pyplot as plt
from torchvision.utils import draw_segmentation_masks
from data_loader import load_data
from model_loader import LinkNet
# Load the data.
train_dl, test_dl = load_data()
# Restore the trained weights; map_location='cpu' lets a GPU-trained checkpoint load on CPU.
model = LinkNet()
model_state_dict = torch.load(os.path.join('model_data', 'model.pth'), map_location='cpu')
model.load_state_dict(model_state_dict)
# Inference mode: BatchNorm must use its running statistics, not this batch's statistics.
model.eval()
# Predict on the first test batch and overlay one sample's predicted mask.
images, labels = next(iter(test_dl))
# Clamp the sample index so the script still works when the batch has fewer than 3 samples.
index = min(2, images.shape[0] - 1)
with torch.no_grad():
    pred = model(images)
pred = torch.argmax(input=pred, dim=1)
result = draw_segmentation_masks(image=torch.as_tensor(data=images[index] * 255, dtype=torch.uint8),
                                 masks=torch.as_tensor(data=pred[index], dtype=torch.bool),
                                 alpha=0.8, colors=['red'])
plt.figure(figsize=(8, 8), dpi=500)
plt.axis('off')
plt.imshow(result.permute(1, 2, 0))
plt.savefig('result.png')
plt.show()
七、代码的运行结果如下所示
边栏推荐
- Talking about cookies of client storage technology
- Idea plug-in
- 二叉树的四种遍历方式以及中序后序、前序中序、前序后序、层序创建二叉树【专为力扣刷题而打造】
- Flet教程之 08 AppBar工具栏基础入门(教程含源码)
- 企业数字化转型最佳实践案例:基于云的数字化平台系统安全措施简介与参考
- Browser render page pass
- Win11亮度被锁定怎么办?Win11亮度被锁定的解决方法
- Related concepts and motivation of federated learning (1)
- 2022 version of stronger jsonpath compatibility and performance test (snack3, fastjson2, jayway.jsonpath)
- Quelques suggestions pour la conception de l'interface
猜你喜欢
QT writing the Internet of things management platform 38- multiple database support
Quelques suggestions pour la conception de l'interface
idea配置标准注释
What if a Win11 USB drive refuses access? An effective solution to Win11 USB drive access denial
面对同样复杂的测试任务,为什么大佬很快就能梳理出解决方案?阿里十年测试工程师道出其中的技巧
What if the win11 shared file cannot be opened? The solution of win11 shared file cannot be opened
Qt五子棋人机对战画棋子之QPainter的使用误区总结
How does the computer save web pages to the desktop for use
【观察】联想:3X(1+N)智慧办公解决方案,释放办公生产力“乘数效应”
Form组件常用校验规则-1(持续更新中~)
随机推荐
Flet教程之 05 OutlinedButton基础入门(教程含源码)
ICML 2022 | meta proposes a robust multi-objective Bayesian optimization method to effectively deal with input noise
hash 表的概念及应用
强化学习-学习笔记2 | 价值学习
Flet tutorial 08: introduction to AppBar toolbar basics (tutorial includes source code)
Four traversal methods of binary tree, as well as the creation of binary tree from middle order to post order, pre order to middle order, pre order to post order, and sequence [specially created for t
word中插入图片后,图片上方有一空行,且删除后布局变乱
Hands on deep learning (III) -- convolutional neural network CNN
电脑共享打印机拒绝访问要怎么办
语义化标签的优势和块级行内元素
Qt五子棋人机对战画棋子之QPainter的使用误区总结
E-week finance | Q1 the number of active people in the insurance industry was 86.8867 million, and the licenses of 19 Payment institutions were cancelled
WinCC7.5 SP1如何通过交叉索引来寻找变量及其位置?
Flet教程之 08 AppBar工具栏基础入门(教程含源码)
【观察】联想:3X(1+N)智慧办公解决方案,释放办公生产力“乘数效应”
What ppt writing skills does the classic "pyramid principle" teach us?
2022 version of stronger jsonpath compatibility and performance test (snack3, fastjson2, jayway.jsonpath)
Every time I look at the interface documents of my colleagues, I get confused and have a lot of problems...
RFID仓库管理系统解决方案有哪些功能模块
Win11U盘拒绝访问怎么办?Win11U盘拒绝访问的有效解决方法