PyTorch --- Implementing LinkNet for Semantic Segmentation with PyTorch
2022-07-04 19:42:00 【水哥很水】
1. The dataset used in the code can be obtained from the following link
2. Runtime environment
Pytorch-gpu==1.10.1
Python==3.8
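Before running anything else, the versions above can be confirmed with a few lines. This is an optional check added here for convenience (not part of the original scripts); it prints the interpreter and PyTorch versions and whether a CUDA device is visible.

import sys
import torch

print('Python:', sys.version.split()[0])            # expected 3.8.x
print('PyTorch:', torch.__version__)                 # expected 1.10.1
print('CUDA available:', torch.cuda.is_available())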
3. The dataset processing code is shown below
import os
import torch
from torch.utils import data
from torchvision import transforms
from PIL import Image
import matplotlib.pyplot as plt
from torchvision.utils import draw_segmentation_masks


class MaskDataset(data.Dataset):
    def __init__(self, image_paths, mask_paths, transform):
        super(MaskDataset, self).__init__()
        self.image_paths = image_paths
        self.mask_paths = mask_paths
        self.transform = transform

    def __getitem__(self, index):
        image_path = self.image_paths[index]
        label_path = self.mask_paths[index]
        pil_img = Image.open(image_path)
        pil_img = pil_img.convert('RGB')
        img_tensor = self.transform(pil_img)
        pil_label = Image.open(label_path)
        label_tensor = self.transform(pil_label)
        # Binarize the matte: every non-zero pixel becomes foreground (class 1)
        label_tensor[label_tensor > 0] = 1
        label_tensor = torch.squeeze(input=label_tensor).type(torch.LongTensor)
        return img_tensor, label_tensor

    def __len__(self):
        return len(self.mask_paths)


def load_data():
    # DATASET_PATH = r'/home/akita/hk'
    DATASET_PATH = r'/Users/leeakita/Desktop/hk'
    TRAIN_DATASET_PATH = os.path.join(DATASET_PATH, 'training')
    TEST_DATASET_PATH = os.path.join(DATASET_PATH, 'testing')
    train_file_names = os.listdir(TRAIN_DATASET_PATH)
    test_file_names = os.listdir(TEST_DATASET_PATH)
    # Keep a '<id>_matte.png' mask only if the matching image '<id>.png' also exists
    train_image_names = [name for name in train_file_names if
                         'matte' in name and name.split('_')[0] + '.png' in train_file_names]
    train_image_paths = [os.path.join(TRAIN_DATASET_PATH, name.split('_')[0] + '.png') for name in
                         train_image_names]
    train_label_paths = [os.path.join(TRAIN_DATASET_PATH, name) for name in train_image_names]
    test_image_names = [name for name in test_file_names if
                        'matte' in name and name.split('_')[0] + '.png' in test_file_names]
    test_image_paths = [os.path.join(TEST_DATASET_PATH, name.split('_')[0] + '.png') for name in test_image_names]
    test_label_paths = [os.path.join(TEST_DATASET_PATH, name) for name in test_image_names]
    transform = transforms.Compose([
        transforms.Resize((256, 256)),
        transforms.ToTensor()
    ])
    BATCH_SIZE = 8
    train_ds = MaskDataset(image_paths=train_image_paths, mask_paths=train_label_paths, transform=transform)
    test_ds = MaskDataset(image_paths=test_image_paths, mask_paths=test_label_paths, transform=transform)
    train_dl = data.DataLoader(dataset=train_ds, batch_size=BATCH_SIZE, shuffle=True)
    test_dl = data.DataLoader(dataset=test_ds, batch_size=BATCH_SIZE)
    return train_dl, test_dl


if __name__ == '__main__':
    train_my, test_my = load_data()
    images, labels = next(iter(train_my))
    indexx = 5
    images = images[indexx]
    labels = labels[indexx]
    labels = torch.unsqueeze(input=labels, dim=0)
    # Overlay the binary mask on the image for a quick visual check
    result = draw_segmentation_masks(image=torch.as_tensor(data=images * 255, dtype=torch.uint8),
                                     masks=torch.as_tensor(data=labels, dtype=torch.bool),
                                     alpha=0.6, colors=['red'])
    plt.imshow(result.permute(1, 2, 0).numpy())
    plt.show()
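Before moving on to the model, it is worth confirming the shapes the loader produces. The short check below is an optional sketch (it assumes the code above is saved as data_loader.py, the module name the training script imports later); with BATCH_SIZE = 8 and 256×256 resizing it should print the shapes given in the comments.

import torch
from data_loader import load_data

train_dl, test_dl = load_data()
imgs, masks = next(iter(train_dl))
print(imgs.shape, imgs.dtype)    # torch.Size([8, 3, 256, 256]) torch.float32
print(masks.shape, masks.dtype)  # torch.Size([8, 256, 256]) torch.int64
print(torch.unique(masks))       # tensor([0, 1]): background and person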
4. The model construction code is shown below
from torch import nn
import torch


class ConvBlock(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=1):
        super(ConvBlock, self).__init__()
        self.conv_bn_relu = nn.Sequential(
            nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, stride=stride,
                      padding=padding),
            nn.BatchNorm2d(num_features=out_channels),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        return self.conv_bn_relu(x)


class DecodeConvBlock(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size=3, stride=2, padding=1, out_padding=1):
        super(DecodeConvBlock, self).__init__()
        self.de_conv = nn.ConvTranspose2d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size,
                                          stride=stride, padding=padding, output_padding=out_padding)
        self.bn = nn.BatchNorm2d(num_features=out_channels)

    def forward(self, x, is_act=True):
        x = self.de_conv(x)
        if is_act:
            x = torch.relu(self.bn(x))
        return x


class EncodeBlock(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(EncodeBlock, self).__init__()
        self.conv1 = ConvBlock(in_channels=in_channels, out_channels=out_channels, stride=2)
        self.conv2 = ConvBlock(in_channels=out_channels, out_channels=out_channels)
        self.conv3 = ConvBlock(in_channels=out_channels, out_channels=out_channels)
        self.conv4 = ConvBlock(in_channels=out_channels, out_channels=out_channels)
        self.short_cut = ConvBlock(in_channels=in_channels, out_channels=out_channels, stride=2)

    def forward(self, x):
        out1 = self.conv1(x)
        out1 = self.conv2(out1)
        # Residual shortcut that matches the downsampled resolution
        short_cut = self.short_cut(x)
        out2 = self.conv3(out1 + short_cut)
        out2 = self.conv4(out2)
        return out1 + out2


class DecodeBlock(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(DecodeBlock, self).__init__()
        self.conv1 = ConvBlock(in_channels=in_channels, out_channels=in_channels // 4, kernel_size=1, padding=0)
        self.de_conv = DecodeConvBlock(in_channels=in_channels // 4, out_channels=in_channels // 4)
        self.conv3 = ConvBlock(in_channels=in_channels // 4, out_channels=out_channels, kernel_size=1, padding=0)

    def forward(self, x):
        x = self.conv1(x)
        x = self.de_conv(x)
        x = self.conv3(x)
        return x


class LinkNet(nn.Module):
    def __init__(self):
        super(LinkNet, self).__init__()
        self.init_conv = ConvBlock(in_channels=3, out_channels=64, stride=2, kernel_size=7, padding=3)
        self.init_maxpool = nn.MaxPool2d(kernel_size=(2, 2))
        self.encode_1 = EncodeBlock(in_channels=64, out_channels=64)
        self.encode_2 = EncodeBlock(in_channels=64, out_channels=128)
        self.encode_3 = EncodeBlock(in_channels=128, out_channels=256)
        self.encode_4 = EncodeBlock(in_channels=256, out_channels=512)
        self.decode_4 = DecodeBlock(in_channels=512, out_channels=256)
        self.decode_3 = DecodeBlock(in_channels=256, out_channels=128)
        self.decode_2 = DecodeBlock(in_channels=128, out_channels=64)
        self.decode_1 = DecodeBlock(in_channels=64, out_channels=64)
        self.deconv_out1 = DecodeConvBlock(in_channels=64, out_channels=32)
        self.conv_out = ConvBlock(in_channels=32, out_channels=32)
        self.deconv_out2 = DecodeConvBlock(in_channels=32, out_channels=2, kernel_size=2, padding=0, out_padding=0)

    def forward(self, x):
        x = self.init_conv(x)
        x = self.init_maxpool(x)
        # Encoder path
        e1 = self.encode_1(x)
        e2 = self.encode_2(e1)
        e3 = self.encode_3(e2)
        e4 = self.encode_4(e3)
        # Decoder path with additive skip connections from the encoder
        d4 = self.decode_4(e4)
        d3 = self.decode_3(d4 + e3)
        d2 = self.decode_2(d3 + e2)
        d1 = self.decode_1(d2 + e1)
        f1 = self.deconv_out1(d1)
        f2 = self.conv_out(f1)
        f3 = self.deconv_out2(f2)
        return f3
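A quick way to validate the architecture is to push a dummy batch through it and check that the output resolution matches the input. The snippet below is an optional sketch (it assumes the model code above is saved as model_loader.py, the module name the training script imports); a 256×256 RGB input should come back as a 2-channel map of the same spatial size.

import torch
from model_loader import LinkNet

model = LinkNet()
dummy = torch.randn(1, 3, 256, 256)
with torch.no_grad():
    out = model(dummy)
print(out.shape)  # expected: torch.Size([1, 2, 256, 256])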
5. The model training code is shown below
import torch
from data_loader import load_data
from model_loader import LinkNet
from torch import nn
from torch import optim
import tqdm
import os

# Device configuration
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Load the data
train_dl, test_dl = load_data()
# Load the model
model = LinkNet()
model = model.to(device=device)
# Training configuration
loss_fn = nn.CrossEntropyLoss()
optimizer = optim.Adam(params=model.parameters(), lr=0.001)
lr_scheduler = optim.lr_scheduler.StepLR(optimizer=optimizer, step_size=5, gamma=0.7)
# Start training
for epoch in range(100):
    model.train()
    train_tqdm = tqdm.tqdm(iterable=train_dl, total=len(train_dl))
    train_tqdm.set_description_str('Train epoch: {:3d}'.format(epoch))
    train_loss_sum = torch.tensor(data=[], dtype=torch.float, device=device)
    train_iou_sum = torch.tensor(data=[], dtype=torch.float, device=device)
    for train_images, train_labels in train_tqdm:
        train_images, train_labels = train_images.to(device), train_labels.to(device)
        pred = model(train_images)
        loss = loss_fn(pred, train_labels)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        with torch.no_grad():
            # Foreground IoU of the predicted mask against the ground truth
            intersection = torch.logical_and(input=train_labels, other=torch.argmax(input=pred, dim=1))
            union = torch.logical_or(input=train_labels, other=torch.argmax(input=pred, dim=1))
            batch_iou = torch.true_divide(torch.sum(intersection), torch.sum(union))
            train_iou_sum = torch.cat([train_iou_sum, torch.unsqueeze(input=batch_iou, dim=-1)], dim=-1)
            train_loss_sum = torch.cat([train_loss_sum, torch.unsqueeze(input=loss, dim=-1)], dim=-1)
            train_tqdm.set_postfix({
                'train loss': train_loss_sum.mean().item(),
                'train iou': train_iou_sum.mean().item()
            })
    train_tqdm.close()
    lr_scheduler.step()
    # Evaluation on the test set
    model.eval()
    with torch.no_grad():
        test_tqdm = tqdm.tqdm(iterable=test_dl, total=len(test_dl))
        test_tqdm.set_description_str('Test epoch: {:3d}'.format(epoch))
        test_loss_sum = torch.tensor(data=[], dtype=torch.float, device=device)
        test_iou_sum = torch.tensor(data=[], dtype=torch.float, device=device)
        for test_images, test_labels in test_tqdm:
            test_images, test_labels = test_images.to(device), test_labels.to(device)
            test_pred = model(test_images)
            # CrossEntropyLoss expects raw logits, so no softmax is applied here
            test_loss = loss_fn(test_pred, test_labels)
            test_intersection = torch.logical_and(input=test_labels, other=torch.argmax(input=test_pred, dim=1))
            test_union = torch.logical_or(input=test_labels, other=torch.argmax(input=test_pred, dim=1))
            test_batch_iou = torch.true_divide(torch.sum(test_intersection), torch.sum(test_union))
            test_iou_sum = torch.cat([test_iou_sum, torch.unsqueeze(input=test_batch_iou, dim=-1)], dim=-1)
            test_loss_sum = torch.cat([test_loss_sum, torch.unsqueeze(input=test_loss, dim=-1)], dim=-1)
            test_tqdm.set_postfix({
                'test loss': test_loss_sum.mean().item(),
                'test iou': test_iou_sum.mean().item()
            })
        test_tqdm.close()
# Save the model
if not os.path.exists(os.path.join('model_data')):
    os.mkdir(os.path.join('model_data'))
torch.save(model.state_dict(), os.path.join('model_data', 'model.pth'))
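The script above keeps only the weights from the final epoch. If the checkpoint with the best test IoU is preferred instead, a small variation such as the following sketch can be used (best_iou and best_model.pth are names introduced here for illustration; the rest reuses variables from the script above).

# Initialize once, before the epoch loop
best_iou = 0.0

# Inside the epoch loop, after the test phase
epoch_iou = test_iou_sum.mean().item()
if epoch_iou > best_iou:
    best_iou = epoch_iou
    if not os.path.exists('model_data'):
        os.mkdir('model_data')
    torch.save(model.state_dict(), os.path.join('model_data', 'best_model.pth'))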
6. The model prediction code is shown below
import torch
import os
import matplotlib.pyplot as plt
from torchvision.utils import draw_segmentation_masks
from data_loader import load_data
from model_loader import LinkNet

# Load the data
train_dl, test_dl = load_data()
# Load the trained model
model = LinkNet()
model_state_dict = torch.load(os.path.join('model_data', 'model.pth'), map_location='cpu')
model.load_state_dict(model_state_dict)
model.eval()
# Run prediction
images, labels = next(iter(test_dl))
index = 2
with torch.no_grad():
    pred = model(images)
    pred = torch.argmax(input=pred, dim=1)
# Overlay the predicted mask on the input image
result = draw_segmentation_masks(image=torch.as_tensor(data=images[index] * 255, dtype=torch.uint8),
                                 masks=torch.as_tensor(data=pred[index], dtype=torch.bool),
                                 alpha=0.8, colors=['red'])
plt.figure(figsize=(8, 8), dpi=500)
plt.axis('off')
plt.imshow(result.permute(1, 2, 0).numpy())
plt.savefig('result.png')
plt.show()
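Besides the overlay, the predicted mask itself can be written to disk. The snippet below is an optional sketch that reuses pred and index from the script above; the output file name mask.png is chosen here for illustration.

import numpy as np
from PIL import Image

# pred[index] is a 256x256 LongTensor of class ids (0 = background, 1 = person)
mask = pred[index].numpy().astype(np.uint8) * 255
Image.fromarray(mask, mode='L').save('mask.png')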
7. The result of running the code is shown below