当前位置:网站首页>Pytorch---使用Pytorch实现U-Net进行语义分割
Pytorch---使用Pytorch实现U-Net进行语义分割
2022-07-02 04:04:00 【水哥很水】
一、代码中的数据集可以通过以下链接获取
二、代码运行环境
Pytorch-gpu==1.10.1
Python==3.8
三、数据集处理代码如下所示
import os
import torch
from torch.utils import data
from torchvision import transforms
from PIL import Image
import matplotlib.pyplot as plt
from torchvision.utils import draw_segmentation_masks
class MaskDataset(data.Dataset):
    """Dataset yielding ``(image, mask)`` tensor pairs for binary segmentation.

    Args:
        image_paths: Sequence of paths to the input images.
        mask_paths: Sequence of paths to the mask images; ``mask_paths[i]``
            is used as the mask for ``image_paths[i]``.
        transform: Callable applied to both the PIL image and the PIL mask.
            It is expected to return a tensor for each of them.
    """

    def __init__(self, image_paths, mask_paths, transform):
        super().__init__()
        self.image_paths = image_paths
        self.mask_paths = mask_paths
        self.transform = transform

    def __getitem__(self, index):
        """Load, transform and return the sample stored at ``index``.

        Returns:
            A tuple ``(img_tensor, label_tensor)``. ``img_tensor`` is the
            transformed RGB image. ``label_tensor`` is the transformed mask,
            binarised to ``{0, 1}`` and cast to ``torch.long``.
        """
        image_path = self.image_paths[index]
        label_path = self.mask_paths[index]

        # Context managers close the underlying files deterministically
        # instead of relying on garbage collection.
        with Image.open(image_path) as pil_img:
            img_tensor = self.transform(pil_img.convert('RGB'))
        with Image.open(label_path) as pil_label:
            label_tensor = self.transform(pil_label)

        # Every non-zero mask value counts as foreground.
        label_tensor[label_tensor > 0] = 1
        # Only drop the leading (channel) axis. A bare ``squeeze`` would also
        # collapse a spatial axis of size 1.
        # NOTE(review): assumes the mask decodes to a single channel; a
        # multi-channel mask keeps its channel axis here - confirm dataset.
        label_tensor = torch.squeeze(input=label_tensor, dim=0).type(torch.long)
        return img_tensor, label_tensor

    def __len__(self):
        """Return the number of samples (the number of mask paths)."""
        return len(self.mask_paths)
def _collect_pairs(directory):
    """Return index-aligned ``(image_paths, mask_paths)`` found in ``directory``.

    A file is treated as a mask when its name contains ``'matte'`` and the
    file ``<prefix>.png`` (``<prefix>`` being the part of the mask name before
    the first ``'_'``) exists in the same directory.
    """
    # A set gives O(1) membership tests; sorting makes the order reproducible
    # because ``os.listdir`` returns entries in arbitrary order.
    file_names = set(os.listdir(directory))
    mask_names = sorted(
        name for name in file_names
        if 'matte' in name and name.split('_')[0] + '.png' in file_names
    )
    image_paths = [os.path.join(directory, name.split('_')[0] + '.png') for name in mask_names]
    mask_paths = [os.path.join(directory, name) for name in mask_names]
    return image_paths, mask_paths


def load_data(dataset_path=r'/Users/leeakita/Desktop/hk', batch_size=8, image_size=(256, 256)):
    """Build the training and testing data loaders.

    Args:
        dataset_path: Root directory that contains the ``training`` and
            ``testing`` sub-directories.
        batch_size: Number of samples per batch for both loaders.
        image_size: ``(height, width)`` every image and mask is resized to.

    Returns:
        A tuple ``(train_dl, test_dl)``. The training loader shuffles its
        samples; the testing loader keeps the sorted file order.
    """
    train_image_paths, train_label_paths = _collect_pairs(os.path.join(dataset_path, 'training'))
    test_image_paths, test_label_paths = _collect_pairs(os.path.join(dataset_path, 'testing'))

    transform = transforms.Compose([
        transforms.Resize(image_size),
        transforms.ToTensor(),
    ])

    train_ds = MaskDataset(image_paths=train_image_paths, mask_paths=train_label_paths, transform=transform)
    test_ds = MaskDataset(image_paths=test_image_paths, mask_paths=test_label_paths, transform=transform)
    train_dl = data.DataLoader(dataset=train_ds, batch_size=batch_size, shuffle=True)
    test_dl = data.DataLoader(dataset=test_ds, batch_size=batch_size)
    return train_dl, test_dl
if __name__ == '__main__':
    # Visual sanity check: overlay one training mask on its image in red.
    SAMPLE_INDEX = 5

    train_loader, _unused_test_loader = load_data()
    batch_images, batch_labels = next(iter(train_loader))

    sample_image = batch_images[SAMPLE_INDEX]
    # draw_segmentation_masks is given the mask with a leading mask axis.
    sample_mask = torch.unsqueeze(input=batch_labels[SAMPLE_INDEX], dim=0)

    image_uint8 = torch.as_tensor(data=sample_image * 255, dtype=torch.uint8)
    mask_bool = torch.as_tensor(data=sample_mask, dtype=torch.bool)
    overlay = draw_segmentation_masks(image=image_uint8, masks=mask_bool, alpha=0.6, colors=['red'])

    # Channels-first -> channels-last for matplotlib.
    plt.imshow(overlay.permute(1, 2, 0).numpy())
    plt.show()
四、模型的构建代码如下所示
from torch import nn
import torch
class DownSample(nn.Module):
    """Encoder stage: optional 2x2 max-pool, then two 3x3 conv + ReLU layers.

    Args:
        in_channels: Number of channels of the incoming feature map.
        out_channels: Number of channels produced by both convolutions.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        # padding=1 with a 3x3 kernel keeps the spatial size unchanged.
        first_conv = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
        second_conv = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1)
        self.conv_relu = nn.Sequential(first_conv, nn.ReLU(), second_conv, nn.ReLU())
        self.pool = nn.MaxPool2d(kernel_size=2)

    def forward(self, x, is_pool=True):
        """Apply the stage to ``x``; skip the pooling when ``is_pool`` is false."""
        features = self.pool(x) if is_pool else x
        return self.conv_relu(features)
class UpSample(nn.Module):
    """Decoder stage: two 3x3 conv + ReLU layers, then a 2x transposed conv.

    The input is expected to carry ``2 * channels`` channels; the output has
    ``channels // 2`` channels and twice the spatial size.

    Args:
        channels: Channel count after the first convolution.
    """

    def __init__(self, channels):
        super().__init__()
        reduce_conv = nn.Conv2d(2 * channels, channels, kernel_size=3, padding=1)
        refine_conv = nn.Conv2d(channels, channels, kernel_size=3, padding=1)
        self.conv_relu = nn.Sequential(reduce_conv, nn.ReLU(), refine_conv, nn.ReLU())

        # kernel 3, stride 2, padding 1, output_padding 1 doubles height/width.
        upscale = nn.ConvTranspose2d(channels, channels // 2, kernel_size=3, stride=2,
                                     padding=1, output_padding=1)
        self.up_conv = nn.Sequential(upscale, nn.ReLU())

    def forward(self, x):
        """Refine ``x`` with the conv stack and upsample the result."""
        return self.up_conv(self.conv_relu(x))
class UnetModel(nn.Module):
    """U-Net for semantic segmentation built from DownSample / UpSample stages.

    Args:
        in_channels: Number of channels of the input image (default 3).
        num_classes: Number of output classes, i.e. channels of the returned
            logits (default 2).

    The encoder pools four times and the decoder upsamples four times by a
    factor of two, so input height and width should be multiples of 16 for
    the skip connections to line up in ``torch.cat``.
    """

    def __init__(self, in_channels=3, num_classes=2):
        super().__init__()
        # Encoder.
        self.down_1 = DownSample(in_channels=in_channels, out_channels=64)
        self.down_2 = DownSample(in_channels=64, out_channels=128)
        self.down_3 = DownSample(in_channels=128, out_channels=256)
        self.down_4 = DownSample(in_channels=256, out_channels=512)
        self.down_5 = DownSample(in_channels=512, out_channels=1024)
        # First upsampling step out of the bottleneck.
        self.up = nn.Sequential(
            nn.ConvTranspose2d(in_channels=1024, out_channels=512, kernel_size=3, stride=2, output_padding=1,
                               padding=1),
            nn.ReLU()
        )
        # Decoder.
        self.up_1 = UpSample(channels=512)
        self.up_2 = UpSample(channels=256)
        self.up_3 = UpSample(channels=128)
        # Final double convolution, used without pooling in ``forward``.
        self.conv_2 = DownSample(in_channels=128, out_channels=64)
        # 1x1 convolution mapping features to per-class logits.
        self.last = nn.Conv2d(in_channels=64, out_channels=num_classes, kernel_size=1)

    def forward(self, x):
        """Return per-pixel class logits with the same height/width as ``x``."""
        skip_1 = self.down_1(x, is_pool=False)
        skip_2 = self.down_2(skip_1)
        skip_3 = self.down_3(skip_2)
        skip_4 = self.down_4(skip_3)
        bottleneck = self.down_5(skip_4)

        # Each decoder step concatenates the matching encoder output (skip
        # connection) along the channel axis before refining and upsampling.
        out = self.up(bottleneck)
        out = self.up_1(torch.cat([skip_4, out], dim=1))
        out = self.up_2(torch.cat([skip_3, out], dim=1))
        out = self.up_3(torch.cat([skip_2, out], dim=1))
        out = self.conv_2(torch.cat([skip_1, out], dim=1), is_pool=False)
        return self.last(out)
五、模型的训练代码如下所示
import torch
from data_loader import load_data
from model_loader import UnetModel
from torch import nn
from torch import optim
import tqdm
import os
# Select the compute device: GPU when available, otherwise CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Load the data.
train_dl, test_dl = load_data()

# Build the model and move it to the selected device.
model = UnetModel()
model = model.to(device=device)

# Training configuration.
loss_fn = nn.CrossEntropyLoss()
optimizer = optim.Adam(params=model.parameters(), lr=0.001)
lr_scheduler = optim.lr_scheduler.StepLR(optimizer=optimizer, step_size=5, gamma=0.7)

# Training loop.
for epoch in range(100):
    model.train()
    train_tqdm = tqdm.tqdm(iterable=train_dl, total=len(train_dl))
    train_tqdm.set_description_str('Train epoch: {:3d}'.format(epoch))
    train_loss_total = 0.0
    for step, (train_images, train_labels) in enumerate(train_tqdm, start=1):
        train_images, train_labels = train_images.to(device), train_labels.to(device)
        pred = model(train_images)
        loss = loss_fn(pred, train_labels)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # Running mean of the batch losses seen so far in this epoch.
        train_loss_total += loss.item()
        train_tqdm.set_postfix({
            'train loss': train_loss_total / step})
    train_tqdm.close()
    lr_scheduler.step()

    # Evaluation on the test set.
    model.eval()
    with torch.no_grad():
        test_tqdm = tqdm.tqdm(iterable=test_dl, total=len(test_dl))
        test_tqdm.set_description_str('Test epoch: {:3d}'.format(epoch))
        test_loss_total = 0.0
        for step, (test_images, test_labels) in enumerate(test_tqdm, start=1):
            test_images, test_labels = test_images.to(device), test_labels.to(device)
            test_pred = model(test_images)
            # CrossEntropyLoss expects raw logits; it applies log-softmax
            # itself, so no softmax may be applied beforehand.
            test_loss = loss_fn(test_pred, test_labels)
            test_loss_total += test_loss.item()
            test_tqdm.set_postfix({
                'test loss': test_loss_total / step})
        test_tqdm.close()

# Save the trained weights.
os.makedirs('model_data', exist_ok=True)
torch.save(model.state_dict(), os.path.join('model_data', 'model.pth'))
六、模型的预测代码如下所示
import torch
import os
import matplotlib.pyplot as plt
from torchvision.utils import draw_segmentation_masks
from data_loader import load_data
from model_loader import UnetModel
# Load the data.
train_dl, test_dl = load_data()

# Build the model and restore the trained weights on the CPU.
model = UnetModel()
model_state_dict = torch.load(os.path.join('model_data', 'model.pth'), map_location='cpu')
model.load_state_dict(model_state_dict)
# Switch to inference mode before predicting.
model.eval()

# Predict on the first test batch and visualise one sample.
images, labels = next(iter(test_dl))
index = 1
with torch.no_grad():
    pred = model(images)
    # Per-pixel class index = channel with the highest logit.
    pred = torch.argmax(input=pred, dim=1)
result = draw_segmentation_masks(image=torch.as_tensor(data=images[index] * 255, dtype=torch.uint8),
                                 masks=torch.as_tensor(data=pred[index], dtype=torch.bool),
                                 alpha=0.6, colors=['red'])
plt.figure(figsize=(8, 8), dpi=500)
plt.axis('off')
# Channels-first -> channels-last, converted to a NumPy array for matplotlib.
plt.imshow(result.permute(1, 2, 0).numpy())
plt.savefig('result.png')
plt.show()
七、代码的运行结果如下所示

边栏推荐
- regular expression
- Raspberry Pi GPIO pin controls traffic light and buzzer
- Nacos 配置中心整体设计原理分析(持久化,集群,信息同步)
- Demonstration description of integrated base scheme
- JVM知识点
- Dare to go out for an interview without learning some distributed technology?
- The 5th Blue Bridge Cup single chip microcomputer provincial competition
- 5g era is coming in an all-round way, talking about the past and present life of mobile communication
- The first game of the 11th provincial single chip microcomputer competition of the Blue Bridge Cup
- 云服务器的安全设置常识
猜你喜欢

Typescript practice for SAP ui5

SQL Yiwen get window function

Yyds dry inventory compiler and compiler tools

First acquaintance with string+ simple usage (II)

JVM知识点

LeetCode 540: Single Element in a Sorted Array

Basic operations of MySQL database (based on tables)

Learn more about MaterialApp and its common attributes in Flutter

Microsoft Research Institute's new book "Fundamentals of data science", 479 Pages pdf

How much is the tuition fee of SCM training class? How long is the study time?
随机推荐
Suggestions on settlement solution of u standard contract position explosion
[source code analysis] NVIDIA hugectr, GPU version parameter server - (1)
Fingertips life Chapter 4 modules and packages
微信小程序 - 实现获取手机验证码倒计时 60 秒(手机号+验证码登录功能)
Blue Bridge Cup SCM digital tube skills
Www 2022 | rethinking the knowledge map completion of graph convolution network
JVM knowledge points
QT designer plug-in implementation of QT plug-in
Which product of anti-cancer insurance is better?
Nacos 配置中心整体设计原理分析(持久化,集群,信息同步)
Document declaration and character encoding
BGP experiment the next day
How to solve the problem that objects cannot be deleted in Editor Mode
LeetCode interview question 02.08: Loop detection
Set vscode. When double clicking, the selected string includes the $symbol - convenient for PHP operation
cookie、session、token
一文彻底理解评分卡开发中——Y的确定(Vintage分析、滚动率分析等)
[untitled]
【直播回顾】战码先锋首期8节直播完美落幕,下期敬请期待!
Bitmap principle code record