当前位置:网站首页>PyTorch — implementing U-Net semantic segmentation with PyTorch
PyTorch — implementing U-Net semantic segmentation with PyTorch
2022-07-02 04:11:00 【Brother Shui is very water】
1. Dataset: the dataset used in the code can be downloaded from the following link:
Baidu Netdisk (extraction code: f1j7)
2. Runtime environment
Pytorch-gpu==1.10.1
Python==3.8
3. Dataset processing code (data_loader.py):
import os
import torch
from torch.utils import data
from torchvision import transforms
from PIL import Image
import matplotlib.pyplot as plt
from torchvision.utils import draw_segmentation_masks
class MaskDataset(data.Dataset):
    """Paired image / binary-mask dataset for semantic segmentation.

    Each item is ``(img_tensor, label_tensor)``:

    * ``img_tensor`` is the image converted to RGB and passed through
      ``transform``.
    * ``label_tensor`` is the mask passed through ``mask_transform``. It is
      binarized (every value > 0 becomes 1), cast to ``torch.long``, and has
      its leading dimension removed.

    Args:
        image_paths: Paths of the input images.
        mask_paths: Paths of the masks; ``mask_paths[i]`` belongs to
            ``image_paths[i]``.
        transform: Callable applied to the RGB ``PIL.Image``; expected to
            return a tensor.
        mask_transform: Optional callable applied to the mask image. Defaults
            to ``transform``. ``transforms.Resize`` interpolates bilinearly by
            default, which blurs mask edges; combined with the ``> 0``
            binarization below, that dilates the mask. Pass a transform that
            resizes with nearest-neighbour interpolation to avoid this.

    Raises:
        ValueError: if ``image_paths`` and ``mask_paths`` differ in length.
    """

    def __init__(self, image_paths, mask_paths, transform, mask_transform=None):
        super(MaskDataset, self).__init__()
        if len(image_paths) != len(mask_paths):
            raise ValueError(
                f"image_paths and mask_paths differ in length "
                f"({len(image_paths)} != {len(mask_paths)})"
            )
        self.image_paths = image_paths
        self.mask_paths = mask_paths
        self.transform = transform
        self.mask_transform = transform if mask_transform is None else mask_transform

    def __getitem__(self, index):
        # Open inside ``with`` so file handles are released deterministically;
        # the transforms run inside the block because PIL loads pixels lazily.
        with Image.open(self.image_paths[index]) as pil_img:
            img_tensor = self.transform(pil_img.convert('RGB'))
        with Image.open(self.mask_paths[index]) as pil_label:
            label_tensor = self.mask_transform(pil_label)
        # Any non-zero mask value counts as foreground (class 1).
        label_tensor[label_tensor > 0] = 1
        # Drop only the leading channel dimension. This assumes a single-channel
        # mask, for which ToTensor yields (1, H, W), giving the (H, W) target
        # layout used with nn.CrossEntropyLoss. TODO: confirm for RGB/RGBA mattes.
        label_tensor = torch.squeeze(label_tensor, dim=0).type(torch.long)
        return img_tensor, label_tensor

    def __len__(self):
        return len(self.mask_paths)
def _paired_paths(directory):
    """Return ``(image_paths, mask_paths)`` for the matte files in *directory*.

    A file whose name contains ``'matte'`` is treated as a mask. It is paired
    with the image ``<stem>.png``, where ``<stem>`` is the part of the mask
    name before the first ``'_'``. Masks without a matching image are skipped.
    File names are sorted so the resulting order is deterministic.
    """
    file_names = sorted(os.listdir(directory))
    # Set membership keeps the pairing linear instead of O(n^2) list scans.
    available = set(file_names)
    mask_names = [
        name for name in file_names
        if 'matte' in name and name.split('_')[0] + '.png' in available
    ]
    image_paths = [os.path.join(directory, name.split('_')[0] + '.png') for name in mask_names]
    mask_paths = [os.path.join(directory, name) for name in mask_names]
    return image_paths, mask_paths


def load_data(dataset_path=r'/Users/leeakita/Desktop/hk', batch_size=8):
    """Build the train/test DataLoaders for the segmentation dataset.

    Args:
        dataset_path: Root directory containing ``training`` and ``testing``
            sub-directories of ``<stem>.png`` images and matching matte masks.
        batch_size: Batch size for both loaders.

    Returns:
        ``(train_dl, test_dl)``. The training loader shuffles; the test loader
        keeps the (sorted) file order.
    """
    train_image_paths, train_label_paths = _paired_paths(os.path.join(dataset_path, 'training'))
    test_image_paths, test_label_paths = _paired_paths(os.path.join(dataset_path, 'testing'))
    # Every image and mask is resized to 256x256 and converted to a tensor.
    transform = transforms.Compose([
        transforms.Resize((256, 256)),
        transforms.ToTensor(),
    ])
    train_ds = MaskDataset(image_paths=train_image_paths, mask_paths=train_label_paths, transform=transform)
    test_ds = MaskDataset(image_paths=test_image_paths, mask_paths=test_label_paths, transform=transform)
    train_dl = data.DataLoader(dataset=train_ds, batch_size=batch_size, shuffle=True)
    test_dl = data.DataLoader(dataset=test_ds, batch_size=batch_size)
    return train_dl, test_dl
if __name__ == '__main__':
    # Visual sanity check: overlay one training mask on its image.
    train_loader, _ = load_data()
    batch_images, batch_labels = next(iter(train_loader))
    sample = 5
    # ToTensor produces floats in [0, 1]; draw_segmentation_masks needs uint8.
    image_u8 = torch.as_tensor(data=batch_images[sample] * 255, dtype=torch.uint8)
    # Add a leading mask axis, (H, W) -> (1, H, W), and convert to boolean.
    mask = torch.as_tensor(
        data=torch.unsqueeze(input=batch_labels[sample], dim=0),
        dtype=torch.bool,
    )
    overlay = draw_segmentation_masks(image=image_u8, masks=mask, alpha=0.6, colors=['red'])
    # Channels-first (C, H, W) -> channels-last (H, W, C) for matplotlib.
    plt.imshow(overlay.permute(1, 2, 0).numpy())
    plt.show()
4. Model construction code (model_loader.py):
from torch import nn
import torch
class DownSample(nn.Module):
    """Encoder stage: optional 2x2 max-pool, then two 3x3 conv + ReLU layers.

    The convolutions use padding 1, so they preserve the spatial size; only
    the pooling halves it.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        # The attribute names and layer order fix the state_dict keys
        # (conv_relu.0.*, conv_relu.2.*); keep them stable so saved weights load.
        layers = [
            nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(in_channels=out_channels, out_channels=out_channels, kernel_size=3, padding=1),
            nn.ReLU(),
        ]
        self.conv_relu = nn.Sequential(*layers)
        self.pool = nn.MaxPool2d(kernel_size=2)

    def forward(self, x, is_pool=True):
        """Pool (unless ``is_pool`` is False), then apply the conv block."""
        features = self.pool(x) if is_pool else x
        return self.conv_relu(features)
class UpSample(nn.Module):
    """Decoder stage: fuse a concatenated skip connection, then upsample 2x.

    Takes ``2 * channels`` input channels and returns ``channels // 2``
    channels at twice the spatial resolution.
    """

    def __init__(self, channels):
        super().__init__()
        # Two 3x3 convs (padding 1, size-preserving) reduce 2*channels -> channels.
        self.conv_relu = nn.Sequential(
            nn.Conv2d(2 * channels, channels, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(channels, channels, kernel_size=3, padding=1),
            nn.ReLU(),
        )
        # Transposed-conv output size:
        # (H - 1) * 2 - 2 * 1 + (3 - 1) + 1 + 1 = 2H, i.e. exact 2x upsampling.
        self.up_conv = nn.Sequential(
            nn.ConvTranspose2d(channels, channels // 2, kernel_size=3, stride=2,
                               padding=1, output_padding=1),
            nn.ReLU(),
        )

    def forward(self, x):
        """Apply the fusion convs, then the upsampling transposed conv."""
        return self.up_conv(self.conv_relu(x))
class UnetModel(nn.Module):
    """U-Net producing per-pixel class logits.

    Encoder: five ``DownSample`` stages (64 -> 1024 channels, four 2x
    max-poolings). Decoder: a transposed-conv upsample, three ``UpSample``
    stages and a final non-pooling ``DownSample``. Each decoder stage receives
    the matching encoder feature map through channel-wise concatenation. A
    final 1x1 conv maps 64 channels to ``num_classes`` logits (no activation).

    The input height and width must be multiples of 16 (four poolings);
    otherwise the skip concatenations fail on mismatched spatial sizes.

    Args:
        in_channels: Number of input image channels (default 3).
        num_classes: Number of output logit channels (default 2).
    """

    def __init__(self, in_channels=3, num_classes=2):
        super(UnetModel, self).__init__()
        # Attribute names are kept unchanged so existing state_dicts still load.
        self.down_1 = DownSample(in_channels=in_channels, out_channels=64)
        self.down_2 = DownSample(in_channels=64, out_channels=128)
        self.down_3 = DownSample(in_channels=128, out_channels=256)
        self.down_4 = DownSample(in_channels=256, out_channels=512)
        self.down_5 = DownSample(in_channels=512, out_channels=1024)
        # Bottleneck upsample: 1024 ch at H/16 -> 512 ch at H/8.
        self.up = nn.Sequential(
            nn.ConvTranspose2d(in_channels=1024, out_channels=512, kernel_size=3, stride=2, output_padding=1,
                               padding=1),
            nn.ReLU()
        )
        self.up_1 = UpSample(channels=512)
        self.up_2 = UpSample(channels=256)
        self.up_3 = UpSample(channels=128)
        # Final fusion at full resolution (no pooling): 128 -> 64 channels.
        self.conv_2 = DownSample(in_channels=128, out_channels=64)
        self.last = nn.Conv2d(in_channels=64, out_channels=num_classes, kernel_size=1)

    def forward(self, x):
        """Map (N, in_channels, H, W) input to (N, num_classes, H, W) raw logits."""
        skip_1 = self.down_1(x, is_pool=False)  # 64 ch, H
        skip_2 = self.down_2(skip_1)  # 128 ch, H/2
        skip_3 = self.down_3(skip_2)  # 256 ch, H/4
        skip_4 = self.down_4(skip_3)  # 512 ch, H/8
        bottom = self.down_5(skip_4)  # 1024 ch, H/16
        y = self.up(bottom)  # 512 ch, H/8
        y = self.up_1(torch.cat([skip_4, y], dim=1))  # 1024 -> 256 ch, H/4
        y = self.up_2(torch.cat([skip_3, y], dim=1))  # 512 -> 128 ch, H/2
        y = self.up_3(torch.cat([skip_2, y], dim=1))  # 256 -> 64 ch, H
        y = self.conv_2(torch.cat([skip_1, y], dim=1), is_pool=False)  # 128 -> 64 ch, H
        return self.last(y)
5. Model training code:
import torch
from data_loader import load_data
from model_loader import UnetModel
from torch import nn
from torch import optim
import tqdm
import os
# Device selection: use the GPU when available.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Load data
train_dl, test_dl = load_data()
# Load model
model = UnetModel()
model = model.to(device=device)
# Training configuration.
# nn.CrossEntropyLoss applies log-softmax internally, so it must receive the
# model's raw logits in both training and evaluation.
loss_fn = nn.CrossEntropyLoss()
optimizer = optim.Adam(params=model.parameters(), lr=0.001)
# Multiply the learning rate by 0.7 every 5 epochs.
lr_scheduler = optim.lr_scheduler.StepLR(optimizer=optimizer, step_size=5, gamma=0.7)
NUM_EPOCHS = 100
# Start training
for epoch in range(NUM_EPOCHS):
    # train()/eval() have no effect on layers without dropout/batch-norm, but
    # they keep the loop correct if such layers are added to the model.
    model.train()
    train_tqdm = tqdm.tqdm(iterable=train_dl, total=len(train_dl))
    train_tqdm.set_description_str('Train epoch: {:3d}'.format(epoch))
    # Running mean kept as Python floats instead of an ever-growing tensor.
    train_loss_total = 0.0
    train_batches = 0
    for train_images, train_labels in train_tqdm:
        train_images, train_labels = train_images.to(device), train_labels.to(device)
        pred = model(train_images)
        loss = loss_fn(pred, train_labels)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        train_loss_total += loss.item()
        train_batches += 1
        train_tqdm.set_postfix({
            'train loss': train_loss_total / train_batches})
    train_tqdm.close()
    lr_scheduler.step()
    model.eval()
    with torch.no_grad():
        test_tqdm = tqdm.tqdm(iterable=test_dl, total=len(test_dl))
        test_tqdm.set_description_str('Test epoch: {:3d}'.format(epoch))
        test_loss_total = 0.0
        test_batches = 0
        for test_images, test_labels in test_tqdm:
            test_images, test_labels = test_images.to(device), test_labels.to(device)
            test_pred = model(test_images)
            # Pass raw logits: applying softmax first would normalize twice and
            # make the test loss wrong and incomparable with the training loss.
            test_loss = loss_fn(test_pred, test_labels)
            test_loss_total += test_loss.item()
            test_batches += 1
            test_tqdm.set_postfix({
                'test loss': test_loss_total / test_batches})
        test_tqdm.close()
# Save model
os.makedirs('model_data', exist_ok=True)
torch.save(model.state_dict(), os.path.join('model_data', 'model.pth'))
6. Model prediction code:
import torch
import os
import matplotlib.pyplot as plt
from torchvision.utils import draw_segmentation_masks
from data_loader import load_data
from model_loader import UnetModel
# Data loading: only the test split is used below.
_, test_dl = load_data()
# Rebuild the network and load the trained weights onto the CPU.
model = UnetModel()
weights_path = os.path.join('model_data', 'model.pth')
model.load_state_dict(torch.load(weights_path, map_location='cpu'))
# Predict on the first test batch and visualise one sample.
images, labels = next(iter(test_dl))
sample = 1
with torch.no_grad():
    # Per-pixel class = arg-max over the class (channel) dimension.
    class_map = model(images).argmax(dim=1)
    # Images are floats in [0, 1]; draw_segmentation_masks needs uint8.
    overlay = draw_segmentation_masks(
        image=(images[sample] * 255).to(torch.uint8),
        masks=class_map[sample].bool(),
        alpha=0.6,
        colors=['red'],
    )
plt.figure(figsize=(8, 8), dpi=500)
plt.axis('off')
# Channels-first (C, H, W) -> channels-last (H, W, C) for matplotlib.
plt.imshow(overlay.permute(1, 2, 0))
plt.savefig('result.png')
plt.show()
7. Results of running the code:
边栏推荐
- Sword Offer II 006: Sum of two numbers in a sorted array
- Play with concurrency: draw a thread state transition diagram
- Bitmap principle code record
- Opencv learning example code 3.2.4 LUT
- 【IBDFE】基于IBDFE的频域均衡matlab仿真
- [wireless image transmission] FPGA based simple wireless image transmission system Verilog development, matlab assisted verification
- 【小技巧】使用matlab GUI以对话框模式读取文件
- Homework in Chapter 3 of slam course of dark blue vision -- derivative application of T6 common functions
- [untitled]
- 整理了一份ECS夏日省钱秘籍,这次@老用户快来领走
猜你喜欢
初识P4语言
Dare to go out for an interview without learning some distributed technology?
Set vscode. When double clicking, the selected string includes the $symbol - convenient for PHP operation
Actual combat | use composite material 3 in application
《西线无战事》我们才刚开始热爱生活,却不得不对一切开炮
Installation and use of blue lake
A summary of common interview questions in 2022, including 25 technology stacks, has helped me successfully get an offer from Tencent
Www 2022 | rethinking the knowledge map completion of graph convolution network
66.qt quick-qml自定义日历组件(支持竖屏和横屏)
FAQ | FAQ for building applications for large screen devices
随机推荐
go 包的使用
Which is better, industrial intelligent gateway or edge computing gateway? How to choose the right one?
office_ Delete the last page of word (the seemingly blank page)
[untitled]
Sorted out an ECS summer money saving secret, this time @ old users come and take it away
Feature Engineering: summary of common feature transformation methods
[ibdfe] matlab simulation of frequency domain equalization based on ibdfe
Fingertips life Chapter 4 modules and packages
cookie, session, token
Demonstration description of integrated base scheme
Jetpack之LiveData扩展MediatorLiveData
Spring recruitment at Internet companies: Kuaishou and Meituan have expanded hiring the most, and annual salaries for technical posts reach nearly 400,000
How to solve the problem that objects cannot be deleted in Editor Mode
TypeScript practice for SAP UI5
WeChat mini program: waterfall-flow layout with pull-to-load-more
文档声明与字符编码
How to solve the code error when storing array data into the database
JVM知识点
[source code analysis] NVIDIA hugectr, GPU version parameter server - (1)
Go language introduction