PyTorch: Implementing U-Net for Semantic Segmentation with PyTorch
2022-07-02 04:04:00 【水哥很水】
1. The dataset used in the code can be obtained from the following link
2. Runtime environment
Pytorch-gpu==1.10.1
Python==3.8
3. Dataset processing code
import os
import torch
from torch.utils import data
from torchvision import transforms
from PIL import Image
import matplotlib.pyplot as plt
from torchvision.utils import draw_segmentation_masks


class MaskDataset(data.Dataset):
    def __init__(self, image_paths, mask_paths, transform):
        super(MaskDataset, self).__init__()
        self.image_paths = image_paths
        self.mask_paths = mask_paths
        self.transform = transform

    def __getitem__(self, index):
        image_path = self.image_paths[index]
        label_path = self.mask_paths[index]
        pil_img = Image.open(image_path)
        pil_img = pil_img.convert('RGB')
        img_tensor = self.transform(pil_img)
        pil_label = Image.open(label_path)
        label_tensor = self.transform(pil_label)
        # Binarize the matte: every non-zero pixel becomes foreground (class 1)
        label_tensor[label_tensor > 0] = 1
        # Drop the channel dimension and cast to long, as CrossEntropyLoss expects
        label_tensor = torch.squeeze(input=label_tensor).type(torch.long)
        return img_tensor, label_tensor

    def __len__(self):
        return len(self.mask_paths)


def load_data():
    # DATASET_PATH = r'/home/akita/hk'
    DATASET_PATH = r'/Users/leeakita/Desktop/hk'
    TRAIN_DATASET_PATH = os.path.join(DATASET_PATH, 'training')
    TEST_DATASET_PATH = os.path.join(DATASET_PATH, 'testing')
    train_file_names = os.listdir(TRAIN_DATASET_PATH)
    test_file_names = os.listdir(TEST_DATASET_PATH)
    # Keep only the masks ('*_matte*') whose source image ('<id>.png') is also present
    train_image_names = [name for name in train_file_names if
                         'matte' in name and name.split('_')[0] + '.png' in train_file_names]
    train_image_paths = [os.path.join(TRAIN_DATASET_PATH, name.split('_')[0] + '.png') for name in
                         train_image_names]
    train_label_paths = [os.path.join(TRAIN_DATASET_PATH, name) for name in train_image_names]
    test_image_names = [name for name in test_file_names if
                        'matte' in name and name.split('_')[0] + '.png' in test_file_names]
    test_image_paths = [os.path.join(TEST_DATASET_PATH, name.split('_')[0] + '.png') for name in test_image_names]
    test_label_paths = [os.path.join(TEST_DATASET_PATH, name) for name in test_image_names]
    transform = transforms.Compose([
        transforms.Resize((256, 256)),
        transforms.ToTensor()
    ])
    BATCH_SIZE = 8
    train_ds = MaskDataset(image_paths=train_image_paths, mask_paths=train_label_paths, transform=transform)
    test_ds = MaskDataset(image_paths=test_image_paths, mask_paths=test_label_paths, transform=transform)
    train_dl = data.DataLoader(dataset=train_ds, batch_size=BATCH_SIZE, shuffle=True)
    test_dl = data.DataLoader(dataset=test_ds, batch_size=BATCH_SIZE)
    return train_dl, test_dl


if __name__ == '__main__':
    # Visual check: overlay the mask of one training sample on its image
    train_my, test_my = load_data()
    images, labels = next(iter(train_my))
    index = 5
    images = images[index]
    labels = labels[index]
    labels = torch.unsqueeze(input=labels, dim=0)
    result = draw_segmentation_masks(image=torch.as_tensor(data=images * 255, dtype=torch.uint8),
                                     masks=torch.as_tensor(data=labels, dtype=torch.bool),
                                     alpha=0.6, colors=['red'])
    plt.imshow(result.permute(1, 2, 0).numpy())
    plt.show()
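The file matching in `load_data` depends on a naming convention: each mask file name contains `matte`, and the part before the first underscore plus `.png` names its image. The sketch below isolates that pairing logic in a hypothetical helper, `pair_images_and_masks`. The example file names are assumptions for illustration, not taken from the dataset, and the sorting is added only to make the output deterministic.

```python
def pair_images_and_masks(file_names):
    """Return (image_name, mask_name) pairs for every mask whose image exists.

    Hypothetical helper mirroring the list comprehensions in load_data().
    """
    available = set(file_names)
    pairs = []
    for name in sorted(file_names):
        # Only mask files contain 'matte' in their name
        if 'matte' not in name:
            continue
        # The image shares the prefix before the first underscore
        image_name = name.split('_')[0] + '.png'
        if image_name in available:
            pairs.append((image_name, name))
    return pairs


# Assumed example names: one complete pair, one orphan mask, one orphan image
example = ['00001.png', '00001_matte.png', '00002_matte.png', '00003.png']
print(pair_images_and_masks(example))
```

Masks without a matching image, and images without a mask, are silently skipped, so the image and label lists always have the same length and order.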
4. Model construction code
from torch import nn
import torch


class DownSample(nn.Module):
    """Optional 2x2 max pooling followed by two 3x3 conv + ReLU layers."""

    def __init__(self, in_channels, out_channels):
        super(DownSample, self).__init__()
        self.conv_relu = nn.Sequential(
            nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(in_channels=out_channels, out_channels=out_channels, kernel_size=3, padding=1),
            nn.ReLU()
        )
        self.pool = nn.MaxPool2d(kernel_size=2)

    def forward(self, x, is_pool=True):
        if is_pool:
            x = self.pool(x)
        x = self.conv_relu(x)
        return x


class UpSample(nn.Module):
    """Two 3x3 conv + ReLU layers on the concatenated input, then a 2x transposed convolution."""

    def __init__(self, channels):
        super(UpSample, self).__init__()
        self.conv_relu = nn.Sequential(
            nn.Conv2d(in_channels=2 * channels, out_channels=channels, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(in_channels=channels, out_channels=channels, kernel_size=3, padding=1),
            nn.ReLU()
        )
        self.up_conv = nn.Sequential(
            nn.ConvTranspose2d(in_channels=channels, out_channels=channels // 2, kernel_size=3, stride=2,
                               output_padding=1, padding=1),
            nn.ReLU()
        )

    def forward(self, x):
        x = self.conv_relu(x)
        x = self.up_conv(x)
        return x


class UnetModel(nn.Module):
    def __init__(self):
        super(UnetModel, self).__init__()
        # Encoder
        self.down_1 = DownSample(in_channels=3, out_channels=64)
        self.down_2 = DownSample(in_channels=64, out_channels=128)
        self.down_3 = DownSample(in_channels=128, out_channels=256)
        self.down_4 = DownSample(in_channels=256, out_channels=512)
        self.down_5 = DownSample(in_channels=512, out_channels=1024)
        # Decoder
        self.up = nn.Sequential(
            nn.ConvTranspose2d(in_channels=1024, out_channels=512, kernel_size=3, stride=2, output_padding=1,
                               padding=1),
            nn.ReLU()
        )
        self.up_1 = UpSample(channels=512)
        self.up_2 = UpSample(channels=256)
        self.up_3 = UpSample(channels=128)
        self.conv_2 = DownSample(in_channels=128, out_channels=64)
        # 1x1 convolution producing two class scores (background, foreground) per pixel
        self.last = nn.Conv2d(in_channels=64, out_channels=2, kernel_size=1)

    def forward(self, x):
        down_1 = self.down_1(x, is_pool=False)
        down_2 = self.down_2(down_1)
        down_3 = self.down_3(down_2)
        down_4 = self.down_4(down_3)
        down_5 = self.down_5(down_4)
        down_5 = self.up(down_5)
        # Skip connections: concatenate encoder features along the channel dimension
        down_5 = torch.cat([down_4, down_5], dim=1)
        down_5 = self.up_1(down_5)
        down_5 = torch.cat([down_3, down_5], dim=1)
        down_5 = self.up_2(down_5)
        down_5 = torch.cat([down_2, down_5], dim=1)
        down_5 = self.up_3(down_5)
        down_5 = torch.cat([down_1, down_5], dim=1)
        down_5 = self.conv_2(down_5, is_pool=False)
        down_5 = self.last(down_5)
        return down_5
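The `torch.cat` calls in `forward` only work if each upsampled tensor has the same height and width as the encoder feature map it is joined with. The plain-Python sketch below traces the spatial size through the network using the standard output-size formulas for convolution, pooling, and transposed convolution. The function names are hypothetical, and a 256x256 input is assumed to match the `Resize((256, 256))` transform.

```python
def conv_out(size, kernel_size=3, stride=1, padding=1):
    # Conv2d output size (dilation 1)
    return (size + 2 * padding - kernel_size) // stride + 1


def pool_out(size, kernel_size=2):
    # MaxPool2d output size; stride defaults to kernel_size, no padding
    return (size - kernel_size) // kernel_size + 1


def conv_transpose_out(size, kernel_size=3, stride=2, padding=1, output_padding=1):
    # ConvTranspose2d output size (dilation 1)
    return (size - 1) * stride - 2 * padding + (kernel_size - 1) + output_padding + 1


def trace_unet_sizes(size=256):
    """Return the spatial sizes of the encoder stages and of each upsampling step."""
    encoder = [conv_out(conv_out(size))]  # down_1: no pooling
    for _ in range(4):                    # down_2 .. down_5: pool, then two convs
        encoder.append(conv_out(conv_out(pool_out(encoder[-1]))))
    decoder = []
    current = encoder[-1]
    for skip in reversed(encoder[:-1]):   # up, up_1, up_2, up_3
        current = conv_transpose_out(current)
        if current != skip:
            raise ValueError('skip connection size mismatch')
        decoder.append(current)
        current = conv_out(conv_out(current))  # the two convs keep the size
    return encoder, decoder


print(trace_unet_sizes(256))
```

With `kernel_size=3`, `stride=2`, `padding=1`, and `output_padding=1`, the transposed convolution exactly doubles the size, so every upsampled map lines up with its skip connection as long as the input side length is divisible by 16.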
5. Model training code
import torch
from data_loader import load_data
from model_loader import UnetModel
from torch import nn
from torch import optim
import tqdm
import os

# Device configuration
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Load the data
train_dl, test_dl = load_data()

# Build the model
model = UnetModel()
model = model.to(device=device)

# Training configuration
loss_fn = nn.CrossEntropyLoss()
optimizer = optim.Adam(params=model.parameters(), lr=0.001)
lr_scheduler = optim.lr_scheduler.StepLR(optimizer=optimizer, step_size=5, gamma=0.7)

# Start training
for epoch in range(100):
    train_tqdm = tqdm.tqdm(iterable=train_dl, total=len(train_dl))
    train_tqdm.set_description_str('Train epoch: {:3d}'.format(epoch))
    train_loss_sum = torch.tensor(data=[], dtype=torch.float, device=device)
    for train_images, train_labels in train_tqdm:
        train_images, train_labels = train_images.to(device), train_labels.to(device)
        pred = model(train_images)
        loss = loss_fn(pred, train_labels)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        with torch.no_grad():
            # Collect the batch losses and show their running mean
            train_loss_sum = torch.cat([train_loss_sum, torch.unsqueeze(input=loss, dim=-1)], dim=-1)
            train_tqdm.set_postfix({
                'train loss': train_loss_sum.mean().item()})
    train_tqdm.close()
    lr_scheduler.step()

    with torch.no_grad():
        test_tqdm = tqdm.tqdm(iterable=test_dl, total=len(test_dl))
        test_tqdm.set_description_str('Test epoch: {:3d}'.format(epoch))
        test_loss_sum = torch.tensor(data=[], dtype=torch.float, device=device)
        for test_images, test_labels in test_tqdm:
            test_images, test_labels = test_images.to(device), test_labels.to(device)
            test_pred = model(test_images)
            # CrossEntropyLoss expects raw logits, so no softmax is applied here
            test_loss = loss_fn(test_pred, test_labels)
            test_loss_sum = torch.cat([test_loss_sum, torch.unsqueeze(input=test_loss, dim=-1)], dim=-1)
            test_tqdm.set_postfix({
                'test loss': test_loss_sum.mean().item()})
        test_tqdm.close()

# Save the model
os.makedirs('model_data', exist_ok=True)
torch.save(model.state_dict(), os.path.join('model_data', 'model.pth'))
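`nn.CrossEntropyLoss` applies log-softmax to its input internally, so it is meant to receive raw logits. The plain-Python sketch below shows what happens numerically when probabilities are passed instead. The helpers `softmax` and `cross_entropy` are illustrative re-implementations for a single pixel, not PyTorch functions, and the logit values are made up.

```python
import math


def softmax(values):
    # Subtract the maximum for numerical stability
    shifted = [v - max(values) for v in values]
    exps = [math.exp(v) for v in shifted]
    total = sum(exps)
    return [e / total for e in exps]


def cross_entropy(scores, target):
    """Cross-entropy of one sample: -log(softmax(scores)[target])."""
    return -math.log(softmax(scores)[target])


# A confident, correct two-class prediction (illustrative values)
logits = [4.0, -4.0]
loss_from_logits = cross_entropy(logits, target=0)
loss_from_probs = cross_entropy(softmax(logits), target=0)

print('loss on logits:        {:.4f}'.format(loss_from_logits))
print('loss on probabilities: {:.4f}'.format(loss_from_probs))
```

Because softmax outputs lie in [0, 1], the two scores can differ by at most 1, so with two classes the second loss can never fall below log(1 + e^-1), roughly 0.313, however good the prediction is. A loss computed this way is not comparable with one computed on logits.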
6. Model prediction code
import torch
import os
import matplotlib.pyplot as plt
from torchvision.utils import draw_segmentation_masks
from data_loader import load_data
from model_loader import UnetModel

# Load the data
train_dl, test_dl = load_data()

# Load the model
model = UnetModel()
model_state_dict = torch.load(os.path.join('model_data', 'model.pth'), map_location='cpu')
model.load_state_dict(model_state_dict)
model.eval()

# Run the prediction
images, labels = next(iter(test_dl))
index = 1
with torch.no_grad():
    pred = model(images)
    # Pick the class with the highest score for every pixel
    pred = torch.argmax(input=pred, dim=1)
    result = draw_segmentation_masks(image=torch.as_tensor(data=images[index] * 255, dtype=torch.uint8),
                                     masks=torch.as_tensor(data=pred[index], dtype=torch.bool),
                                     alpha=0.6, colors=['red'])
    plt.figure(figsize=(8, 8), dpi=500)
    plt.axis('off')
    plt.imshow(result.permute(1, 2, 0))
    plt.savefig('result.png')
    plt.show()
7. The result of running the code is shown below