Current location: Home > PyTorch: Implementing U-Net Semantic Segmentation with PyTorch
PyTorch: Implementing U-Net Semantic Segmentation with PyTorch
2022-07-02 04:11:00 【Brother Shui is very water】
1. The dataset used in the code can be obtained from the following link:
Baidu Netdisk (extraction code: f1j7)
2. Runtime environment
Pytorch-gpu==1.10.1
Python==3.8
3. Dataset processing code (data_loader.py)
import os
import torch
from torch.utils import data
from torchvision import transforms
from PIL import Image
import matplotlib.pyplot as plt
from torchvision.utils import draw_segmentation_masks
class MaskDataset(data.Dataset):
    """Paired image/mask dataset for binary semantic segmentation.

    Each item is ``(img_tensor, label_tensor)``:

    * ``img_tensor`` is the image converted to RGB and passed through
      ``transform``.
    * ``label_tensor`` is the mask passed through ``mask_transform`` (which is
      ``transform`` unless one is given). Every value > 0 becomes class 1, the
      leading channel dimension is dropped, and the result is cast to
      ``torch.long``, the target dtype ``CrossEntropyLoss`` expects.

    Args:
        image_paths: Paths of the input images.
        mask_paths: Paths of the masks; ``mask_paths[i]`` belongs to
            ``image_paths[i]``. Must have the same length as ``image_paths``.
        transform: Callable mapping a PIL image to a tensor, e.g. a
            ``transforms.Compose`` ending in ``ToTensor``.
        mask_transform: Optional dedicated callable for the masks. A transform
            that resizes with nearest-neighbour interpolation avoids the
            blurred edges a bilinear resize introduces, which the ``> 0``
            threshold would otherwise turn into a slightly dilated mask.

    Raises:
        ValueError: if ``image_paths`` and ``mask_paths`` differ in length.
    """

    def __init__(self, image_paths, mask_paths, transform, mask_transform=None):
        super(MaskDataset, self).__init__()
        if len(image_paths) != len(mask_paths):
            raise ValueError(
                'image_paths and mask_paths must have the same length '
                '({} != {})'.format(len(image_paths), len(mask_paths)))
        self.image_paths = image_paths
        self.mask_paths = mask_paths
        self.transform = transform
        # Masks reuse the image transform unless a dedicated one is supplied,
        # which keeps the original single-transform behaviour by default.
        self.mask_transform = transform if mask_transform is None else mask_transform

    def __getitem__(self, index):
        # Context managers close the image files deterministically instead of
        # relying on Pillow/garbage collection to release the handles.
        with Image.open(self.image_paths[index]) as pil_img:
            img_tensor = self.transform(pil_img.convert('RGB'))
        with Image.open(self.mask_paths[index]) as pil_label:
            label_tensor = self.mask_transform(pil_label)
        # Binarise: any foreground intensity becomes class 1, background stays 0.
        label_tensor[label_tensor > 0] = 1
        # Assumes a single-channel mask, i.e. a (1, H, W) tensor; TODO confirm
        # for the dataset in use. Only the channel dimension is dropped, since
        # a bare squeeze() would also remove a spatial dimension of size 1.
        label_tensor = label_tensor.squeeze(0).type(torch.long)
        return img_tensor, label_tensor

    def __len__(self):
        return len(self.mask_paths)
def _image_mask_paths(directory):
    """Pair every mask file in ``directory`` with its image file.

    A file counts as a mask when its name contains ``'matte'``. Its image is
    ``<prefix>.png`` in the same directory, where ``<prefix>`` is the mask name
    up to the first underscore. Masks whose image is missing are skipped.

    Returns:
        ``(image_paths, mask_paths)``: two equally long lists, aligned by
        index and ordered by mask file name.
    """
    # Sorting makes the sample order independent of os.listdir's arbitrary
    # order, so the unshuffled test loader is reproducible.
    file_names = sorted(os.listdir(directory))
    available = set(file_names)  # O(1) membership instead of scanning the list
    image_paths, mask_paths = [], []
    for name in file_names:
        image_name = name.split('_')[0] + '.png'
        if 'matte' in name and image_name in available:
            image_paths.append(os.path.join(directory, image_name))
            mask_paths.append(os.path.join(directory, name))
    return image_paths, mask_paths


def load_data(dataset_path=r'/Users/leeakita/Desktop/hk', batch_size=8):
    """Build the training and test DataLoaders for the segmentation dataset.

    Args:
        dataset_path: Root directory containing ``training`` and ``testing``
            sub-directories of ``<prefix>.png`` images and ``*matte*`` masks.
            Defaults to the author's original local path.
        batch_size: Samples per batch for both loaders (default 8).

    Returns:
        ``(train_dl, test_dl)``. The training loader reshuffles every epoch;
        the test loader keeps a fixed order.
    """
    # Images and masks share one transform: resize to 256x256, then convert
    # to a float tensor (scaled to [0, 1] for 8-bit images).
    transform = transforms.Compose([
        transforms.Resize((256, 256)),
        transforms.ToTensor(),
    ])
    train_images, train_masks = _image_mask_paths(os.path.join(dataset_path, 'training'))
    test_images, test_masks = _image_mask_paths(os.path.join(dataset_path, 'testing'))
    train_ds = MaskDataset(image_paths=train_images, mask_paths=train_masks, transform=transform)
    test_ds = MaskDataset(image_paths=test_images, mask_paths=test_masks, transform=transform)
    train_dl = data.DataLoader(dataset=train_ds, batch_size=batch_size, shuffle=True)
    test_dl = data.DataLoader(dataset=test_ds, batch_size=batch_size)
    return train_dl, test_dl
if __name__ == '__main__':
    # Visual sanity check: overlay one training mask (red) on its image.
    train_my, _ = load_data()
    images, labels = next(iter(train_my))
    # Keep the original sample index 5, but clamp it so a batch with fewer
    # than six samples (tiny dataset) does not raise IndexError.
    index = min(5, images.shape[0] - 1)
    image = images[index]
    # draw_segmentation_masks takes masks shaped (num_masks, H, W).
    mask = torch.unsqueeze(input=labels[index], dim=0)
    result = draw_segmentation_masks(image=torch.as_tensor(data=image * 255, dtype=torch.uint8),
                                     masks=torch.as_tensor(data=mask, dtype=torch.bool),
                                     alpha=0.6, colors=['red'])
    plt.imshow(result.permute(1, 2, 0).numpy())
    plt.show()
4. Model definition code (model_loader.py)
from torch import nn
import torch
class DownSample(nn.Module):
    """U-Net encoder stage: optional 2x2 max-pool, then two 3x3 conv + ReLU.

    The convolutions use padding 1, so they preserve height and width. With
    pooling enabled the output is half the input size (floored); without it
    the size is unchanged.

    Args:
        in_channels: Channels of the incoming feature map.
        out_channels: Channels produced by both convolutions.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        layers = [
            nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(in_channels=out_channels, out_channels=out_channels, kernel_size=3, padding=1),
            nn.ReLU(),
        ]
        # Attribute names and layer positions form the state_dict keys
        # (conv_relu.0.*, conv_relu.2.*), so saved checkpoints depend on them.
        self.conv_relu = nn.Sequential(*layers)
        self.pool = nn.MaxPool2d(kernel_size=2)

    def forward(self, x, is_pool=True):
        """Pool ``x`` (unless ``is_pool`` is False), then apply the conv pair."""
        features = self.pool(x) if is_pool else x
        return self.conv_relu(features)
class UpSample(nn.Module):
    """U-Net decoder stage: two 3x3 conv + ReLU, then a 2x transposed conv + ReLU.

    The input has ``2 * channels`` channels: in ``UnetModel`` it is a skip
    connection concatenated with upsampled features. The output has
    ``channels // 2`` channels at exactly twice the spatial size, because
    kernel 3, stride 2, padding 1 and output_padding 1 give ``out = 2 * in``.

    Args:
        channels: Width of the stage; the input is twice this, the output half.
    """

    def __init__(self, channels):
        super().__init__()
        merged_channels = 2 * channels
        # state_dict keys depend on these attribute names and layer positions.
        self.conv_relu = nn.Sequential(
            nn.Conv2d(in_channels=merged_channels, out_channels=channels, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(in_channels=channels, out_channels=channels, kernel_size=3, padding=1),
            nn.ReLU(),
        )
        self.up_conv = nn.Sequential(
            nn.ConvTranspose2d(in_channels=channels, out_channels=channels // 2, kernel_size=3,
                               stride=2, padding=1, output_padding=1),
            nn.ReLU(),
        )

    def forward(self, x):
        """Refine the merged features, then upsample them 2x."""
        return self.up_conv(self.conv_relu(x))
class UnetModel(nn.Module):
    """U-Net producing per-pixel class logits.

    Four pooling stages reduce the resolution to H/16, so the input height and
    width must be divisible by 16 for the skip-connection concatenations to
    line up. The decoder restores full resolution, and a 1x1 convolution maps
    64 feature channels to ``num_classes`` logits.

    Args:
        in_channels: Channels of the input image (default 3, i.e. RGB).
        num_classes: Number of output logit channels / classes (default 2).

    Shape:
        ``(N, in_channels, H, W)`` -> ``(N, num_classes, H, W)``
    """

    def __init__(self, in_channels=3, num_classes=2):
        super(UnetModel, self).__init__()
        # Attribute names are the state_dict keys of saved checkpoints
        # (e.g. model_data/model.pth); renaming them breaks loading.
        self.down_1 = DownSample(in_channels=in_channels, out_channels=64)
        self.down_2 = DownSample(in_channels=64, out_channels=128)
        self.down_3 = DownSample(in_channels=128, out_channels=256)
        self.down_4 = DownSample(in_channels=256, out_channels=512)
        self.down_5 = DownSample(in_channels=512, out_channels=1024)
        # Bottleneck upsampling: 1024 -> 512 channels, spatial size doubled.
        self.up = nn.Sequential(
            nn.ConvTranspose2d(in_channels=1024, out_channels=512, kernel_size=3, stride=2, output_padding=1,
                               padding=1),
            nn.ReLU()
        )
        self.up_1 = UpSample(channels=512)
        self.up_2 = UpSample(channels=256)
        self.up_3 = UpSample(channels=128)
        # Final conv pair at full resolution; called with is_pool=False, so its
        # max-pool layer is never used.
        self.conv_2 = DownSample(in_channels=128, out_channels=64)
        self.last = nn.Conv2d(in_channels=64, out_channels=num_classes, kernel_size=1)

    def forward(self, x):
        # Encoder: keep each stage's output for the matching skip connection.
        skip_1 = self.down_1(x, is_pool=False)  # 64 ch,   H
        skip_2 = self.down_2(skip_1)            # 128 ch,  H/2
        skip_3 = self.down_3(skip_2)            # 256 ch,  H/4
        skip_4 = self.down_4(skip_3)            # 512 ch,  H/8
        bottleneck = self.down_5(skip_4)        # 1024 ch, H/16
        # Decoder: concatenate the skip (doubling channels), refine, upsample.
        out = self.up(bottleneck)                                            # 512 ch, H/8
        out = self.up_1(torch.cat([skip_4, out], dim=1))                     # 1024 -> 256 ch, H/4
        out = self.up_2(torch.cat([skip_3, out], dim=1))                     # 512 -> 128 ch, H/2
        out = self.up_3(torch.cat([skip_2, out], dim=1))                     # 256 -> 64 ch, H
        out = self.conv_2(torch.cat([skip_1, out], dim=1), is_pool=False)    # 128 -> 64 ch, H
        return self.last(out)
5. Model training code
import torch
from data_loader import load_data
from model_loader import UnetModel
from torch import nn
from torch import optim
import tqdm
import os
# Configuration of environment variables
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Load data
train_dl, test_dl = load_data()
# Load model
model = UnetModel()
model = model.to(device=device)
# Training related configurations
# CrossEntropyLoss applies log-softmax itself, so it must receive raw logits.
loss_fn = nn.CrossEntropyLoss()
optimizer = optim.Adam(params=model.parameters(), lr=0.001)
# Multiply the learning rate by 0.7 every 5 epochs.
lr_scheduler = optim.lr_scheduler.StepLR(optimizer=optimizer, step_size=5, gamma=0.7)
EPOCHS = 100
# Start training
for epoch in range(EPOCHS):
    model.train()
    train_tqdm = tqdm.tqdm(iterable=train_dl, total=len(train_dl))
    train_tqdm.set_description_str('Train epoch: {:3d}'.format(epoch))
    # Running mean of per-batch losses, without re-copying a growing tensor.
    train_loss_total, train_batches = 0.0, 0
    for train_images, train_labels in train_tqdm:
        train_images, train_labels = train_images.to(device), train_labels.to(device)
        pred = model(train_images)
        loss = loss_fn(pred, train_labels)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        train_loss_total += loss.item()
        train_batches += 1
        train_tqdm.set_postfix({
            'train loss': train_loss_total / train_batches})
    train_tqdm.close()
    lr_scheduler.step()
    model.eval()
    with torch.no_grad():
        test_tqdm = tqdm.tqdm(iterable=test_dl, total=len(test_dl))
        test_tqdm.set_description_str('Test epoch: {:3d}'.format(epoch))
        test_loss_total, test_batches = 0.0, 0
        for test_images, test_labels in test_tqdm:
            test_images, test_labels = test_images.to(device), test_labels.to(device)
            test_pred = model(test_images)
            # Raw logits, exactly as in training. Applying softmax here first
            # made CrossEntropyLoss softmax twice and report a distorted loss
            # that was not comparable with the training loss.
            test_loss = loss_fn(test_pred, test_labels)
            test_loss_total += test_loss.item()
            test_batches += 1
            test_tqdm.set_postfix({
                'test loss': test_loss_total / test_batches})
        test_tqdm.close()
# Save model
os.makedirs('model_data', exist_ok=True)
torch.save(model.state_dict(), os.path.join('model_data', 'model.pth'))
6. Model prediction code
import torch
import os
import matplotlib.pyplot as plt
from torchvision.utils import draw_segmentation_masks
from data_loader import load_data
from model_loader import UnetModel
# Data loading (only the test split is used for prediction)
_, test_dl = load_data()
# Model loading: the training script saves model.state_dict() to this path;
# map_location='cpu' lets a checkpoint trained on a GPU load on a CPU-only machine.
model = UnetModel()
model_state_dict = torch.load(os.path.join('model_data', 'model.pth'), map_location='cpu')
model.load_state_dict(model_state_dict)
model.eval()
# Start Forecasting
images, labels = next(iter(test_dl))
# Clamp the sample index so a single-image batch does not raise IndexError.
index = min(1, images.shape[0] - 1)
with torch.no_grad():
    pred = model(images)
# Predicted class per pixel = argmax over the logit (class) dimension.
pred = torch.argmax(input=pred, dim=1)
result = draw_segmentation_masks(image=torch.as_tensor(data=images[index] * 255, dtype=torch.uint8),
                                 masks=torch.as_tensor(data=pred[index], dtype=torch.bool),
                                 alpha=0.6, colors=['red'])
plt.figure(figsize=(8, 8), dpi=500)
plt.axis('off')
plt.imshow(result.permute(1, 2, 0).numpy())
plt.savefig('result.png')
plt.show()
7. Output of the code

Sidebar recommendations
- Pandora IOT development board learning (RT thread) - Experiment 1 LED flashing experiment (learning notes)
- MySQL error: expression 1 of select list is not in group by claim and contains nonaggre
- 66.qt quick-qml自定义日历组件(支持竖屏和横屏)
- WiFi 5GHz frequency
- Go language introduction
- Homework in Chapter 3 of slam course of dark blue vision -- derivative application of T6 common functions
- Opencv learning example code 3.2.4 LUT
- BiShe cinema ticket purchasing system based on SSM
- How to model noise data? Hong Kong Baptist University's latest review paper on "label noise representation learning" comprehensively expounds the data, objective function and optimization strategy of
- Delete the code you wrote? Sentenced to 10 months!
You may also like

BGP experiment the next day

Www 2022 | rethinking the knowledge map completion of graph convolution network

The confusion I encountered when learning stm32

【leetcode】34. Find the first and last positions of elements in a sorted array

手撕——排序

Lost a few hairs, and finally learned - graph traversal -dfs and BFS

Pandora IOT development board learning (RT thread) - Experiment 1 LED flashing experiment (learning notes)

Lei Jun wrote a blog when he was a programmer. It's awesome

MySQL advanced SQL statement 2

Play with concurrency: what's the use of interruptedexception?
Random recommendations
Actual combat | use composite material 3 in application
XSS prevention
Okcc why is cloud call center better than traditional call center?
Spring recruitment of Internet enterprises: Kwai meituan has expanded the most, and the annual salary of technical posts is up to nearly 400000
【leetcode】81. Search rotation sort array II
FAQ | FAQ for building applications for large screen devices
Yyds dry goods inventory kubernetes introduction foundation pod concept and related operations
go 分支与循环
office_ Delete the last page of word (the seemingly blank page)
A thorough understanding of the development of scorecards - the determination of Y (Vintage analysis, rolling rate analysis, etc.)
Bitmap principle code record
Basic operations of MySQL database (based on tables)
Common sense of cloud server security settings
BGP experiment the next day
Use of go package
【c语言】动态规划---入门到起立
66.qt quick-qml自定义日历组件(支持竖屏和横屏)
How much can a job hopping increase? Today, I saw the ceiling of job hopping.
Wechat applet pull-down loading more waterfall flow loading
How to solve the problem that objects cannot be deleted in Editor Mode