当前位置:网站首页>PyTorch: implementing U-Net semantic segmentation with PyTorch
PyTorch: implementing U-Net semantic segmentation with PyTorch
2022-07-02 04:11:00 【Brother Shui is very water】
1. The dataset used in the code can be downloaded from the following link
Baidu Netdisk, extraction code: f1j7
2. Runtime environment
Pytorch-gpu==1.10.1
Python==3.8
3. The dataset-processing code (saved as `data_loader.py`, which the training and prediction scripts import) is as follows
import os
import torch
from torch.utils import data
from torchvision import transforms
from PIL import Image
import matplotlib.pyplot as plt
from torchvision.utils import draw_segmentation_masks
class MaskDataset(data.Dataset):
    """Pair images with binary segmentation masks.

    Each item is ``(img_tensor, label_tensor)``:

    * the image is converted to RGB and passed through ``transform``;
    * the mask is passed through ``mask_transform`` (``transform`` when not
      given), binarised so every value > 0 becomes 1, squeezed, and cast to
      ``torch.long`` for use as a class-index target.

    Args:
        image_paths: Indexable sequence of image file paths.
        mask_paths: Indexable sequence of mask file paths, aligned
            index-by-index with ``image_paths``.
        transform: Callable applied to the image (and, by default, to the
            mask). It must return a tensor.
        mask_transform: Optional callable applied to the mask instead of
            ``transform``. This is useful when ``transform`` resizes with a
            smoothing interpolation. Otherwise interpolated edge values all
            become foreground after the ``> 0`` threshold. Defaults to
            ``None``, which means "use ``transform``".

    Raises:
        ValueError: If ``image_paths`` and ``mask_paths`` differ in length.
    """

    def __init__(self, image_paths, mask_paths, transform, mask_transform=None):
        super().__init__()
        if len(image_paths) != len(mask_paths):
            raise ValueError(
                f'image_paths ({len(image_paths)}) and mask_paths '
                f'({len(mask_paths)}) must have the same length'
            )
        self.image_paths = image_paths
        self.mask_paths = mask_paths
        self.transform = transform
        self.mask_transform = mask_transform

    def __getitem__(self, index):
        # Resolve at call time so reassigning ``self.transform`` still affects
        # the mask when no dedicated mask transform was supplied.
        mask_transform = self.transform if self.mask_transform is None else self.mask_transform
        # Context managers close the underlying file handles promptly.
        with Image.open(self.image_paths[index]) as pil_img:
            img_tensor = self.transform(pil_img.convert('RGB'))
        with Image.open(self.mask_paths[index]) as pil_label:
            label_tensor = mask_transform(pil_label)
        # Binarise: 0 stays background, anything else becomes foreground (1).
        label_tensor[label_tensor > 0] = 1
        # Drop singleton dims. For a single-channel mask this leaves (H, W),
        # the per-pixel class-index target shape nn.CrossEntropyLoss expects.
        label_tensor = torch.squeeze(input=label_tensor).type(torch.long)
        return img_tensor, label_tensor

    def __len__(self):
        return len(self.mask_paths)
def _collect_pairs(directory):
    """Return sorted ``(image_paths, mask_paths)`` for one dataset split directory.

    A mask is any file whose name contains ``'matte'``. Its image is the file
    ``<prefix>.png`` in the same directory, where ``<prefix>`` is the part of the
    mask name before the first ``'_'``. Masks without a matching image are skipped.
    Names are sorted because ``os.listdir`` order is arbitrary.
    """
    file_names = sorted(os.listdir(directory))
    available = set(file_names)  # O(1) membership instead of scanning the list
    image_paths, mask_paths = [], []
    for name in file_names:
        if 'matte' not in name:
            continue
        image_name = name.split('_')[0] + '.png'
        if image_name in available:
            image_paths.append(os.path.join(directory, image_name))
            mask_paths.append(os.path.join(directory, name))
    return image_paths, mask_paths


def load_data(dataset_path=r'/Users/leeakita/Desktop/hk', batch_size=8):
    """Build the training and test DataLoaders.

    Args:
        dataset_path: Root directory containing ``training`` and ``testing``
            sub-directories. Override this instead of editing the default.
        batch_size: Batch size used by both loaders.

    Returns:
        ``(train_dl, test_dl)``. The training loader shuffles; the test loader
        keeps the (sorted) file order.
    """
    train_image_paths, train_label_paths = _collect_pairs(os.path.join(dataset_path, 'training'))
    test_image_paths, test_label_paths = _collect_pairs(os.path.join(dataset_path, 'testing'))
    # The same transform is applied to images and masks (see MaskDataset).
    transform = transforms.Compose([
        transforms.Resize((256, 256)),
        transforms.ToTensor(),
    ])
    train_ds = MaskDataset(image_paths=train_image_paths, mask_paths=train_label_paths, transform=transform)
    test_ds = MaskDataset(image_paths=test_image_paths, mask_paths=test_label_paths, transform=transform)
    train_dl = data.DataLoader(dataset=train_ds, batch_size=batch_size, shuffle=True)
    test_dl = data.DataLoader(dataset=test_ds, batch_size=batch_size)
    return train_dl, test_dl
if __name__ == '__main__':
    # Visual sanity check: overlay one training mask on its image.
    train_my, test_my = load_data()
    images, labels = next(iter(train_my))
    # Clamp so a first batch with fewer than 6 samples does not raise IndexError.
    index = min(5, images.shape[0] - 1)
    image = images[index]
    # Add a leading mask dimension: (H, W) -> (1, H, W).
    mask = torch.unsqueeze(input=labels[index], dim=0)
    result = draw_segmentation_masks(image=torch.as_tensor(data=image * 255, dtype=torch.uint8),
                                     masks=torch.as_tensor(data=mask, dtype=torch.bool),
                                     alpha=0.6, colors=['red'])
    plt.imshow(result.permute(1, 2, 0).numpy())
    plt.show()
4. The model-construction code (saved as `model_loader.py`, which the training and prediction scripts import) is as follows
from torch import nn
import torch
class DownSample(nn.Module):
    """Encoder stage: optional 2x2 max-pool followed by two 3x3 conv + ReLU layers.

    With ``padding=1`` the convolutions keep the spatial size. The output is
    therefore half the input size (floored) when pooling, and the same size
    otherwise.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        layers = [
            nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(in_channels=out_channels, out_channels=out_channels, kernel_size=3, padding=1),
            nn.ReLU(),
        ]
        # Attribute names and layer order are kept so saved state_dict keys match.
        self.conv_relu = nn.Sequential(*layers)
        self.pool = nn.MaxPool2d(kernel_size=2)

    def forward(self, x, is_pool=True):
        features = self.pool(x) if is_pool else x
        return self.conv_relu(features)
class UpSample(nn.Module):
    """Decoder stage: two 3x3 conv + ReLU layers, then a 2x transposed-conv upsample.

    Takes ``2 * channels`` input channels. UnetModel feeds it a skip connection
    concatenated with the upsampled path. It returns ``channels // 2`` channels
    at exactly twice the spatial size.
    """

    def __init__(self, channels):
        super().__init__()
        doubled = 2 * channels
        halved = channels // 2
        self.conv_relu = nn.Sequential(
            nn.Conv2d(in_channels=doubled, out_channels=channels, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(in_channels=channels, out_channels=channels, kernel_size=3, padding=1),
            nn.ReLU(),
        )
        # kernel 3, stride 2, padding 1, output_padding 1 gives out = 2 * in.
        self.up_conv = nn.Sequential(
            nn.ConvTranspose2d(in_channels=channels, out_channels=halved, kernel_size=3,
                               stride=2, padding=1, output_padding=1),
            nn.ReLU(),
        )

    def forward(self, x):
        return self.up_conv(self.conv_relu(x))
class UnetModel(nn.Module):
    """U-Net that outputs per-pixel class logits.

    Args:
        in_channels: Number of input image channels (default 3).
        num_classes: Number of output logit channels, one per class (default 2).

    ``forward`` maps ``(N, in_channels, H, W)`` to raw logits of shape
    ``(N, num_classes, H, W)``. H and W must be multiples of 16. The encoder
    halves the size four times (floored), and the decoder doubles it four
    times exactly. Any other size makes a skip connection disagree spatially,
    and ``torch.cat`` fails.

    The defaults reproduce the original architecture, so existing state_dicts
    still load.
    """

    def __init__(self, in_channels=3, num_classes=2):
        super(UnetModel, self).__init__()
        self.down_1 = DownSample(in_channels=in_channels, out_channels=64)
        self.down_2 = DownSample(in_channels=64, out_channels=128)
        self.down_3 = DownSample(in_channels=128, out_channels=256)
        self.down_4 = DownSample(in_channels=256, out_channels=512)
        self.down_5 = DownSample(in_channels=512, out_channels=1024)
        self.up = nn.Sequential(
            nn.ConvTranspose2d(in_channels=1024, out_channels=512, kernel_size=3, stride=2, output_padding=1,
                               padding=1),
            nn.ReLU()
        )
        self.up_1 = UpSample(channels=512)
        self.up_2 = UpSample(channels=256)
        self.up_3 = UpSample(channels=128)
        # Final double conv, always called with is_pool=False. The name is kept
        # so state_dict keys stay compatible.
        self.conv_2 = DownSample(in_channels=128, out_channels=64)
        self.last = nn.Conv2d(in_channels=64, out_channels=num_classes, kernel_size=1)

    def forward(self, x):
        skip_1 = self.down_1(x, is_pool=False)  # 64 ch,   H
        skip_2 = self.down_2(skip_1)            # 128 ch,  H/2
        skip_3 = self.down_3(skip_2)            # 256 ch,  H/4
        skip_4 = self.down_4(skip_3)            # 512 ch,  H/8
        bottom = self.down_5(skip_4)            # 1024 ch, H/16
        out = self.up(bottom)                                          # 512 ch, H/8
        out = self.up_1(torch.cat([skip_4, out], dim=1))               # 1024 -> 256 ch, H/4
        out = self.up_2(torch.cat([skip_3, out], dim=1))               # 512 -> 128 ch, H/2
        out = self.up_3(torch.cat([skip_2, out], dim=1))               # 256 -> 64 ch, H
        out = self.conv_2(torch.cat([skip_1, out], dim=1), is_pool=False)  # 128 -> 64 ch, H
        return self.last(out)
5. The model training code is as follows
import torch
from data_loader import load_data
from model_loader import UnetModel
from torch import nn
from torch import optim
import tqdm
import os
# Configuration of environment variables
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Load data
train_dl, test_dl = load_data()
# Load model
model = UnetModel()
model = model.to(device=device)
# Training related configurations
# nn.CrossEntropyLoss applies log-softmax internally, so it must be given raw logits.
loss_fn = nn.CrossEntropyLoss()
optimizer = optim.Adam(params=model.parameters(), lr=0.001)
lr_scheduler = optim.lr_scheduler.StepLR(optimizer=optimizer, step_size=5, gamma=0.7)
# Start training
for epoch in range(100):
    model.train()
    train_tqdm = tqdm.tqdm(iterable=train_dl, total=len(train_dl))
    train_tqdm.set_description_str('Train epoch: {:3d}'.format(epoch))
    # Running sum instead of repeatedly concatenating tensors (O(n^2) copies).
    train_loss_total, train_batches = 0.0, 0
    for train_images, train_labels in train_tqdm:
        train_images, train_labels = train_images.to(device), train_labels.to(device)
        pred = model(train_images)
        loss = loss_fn(pred, train_labels)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        train_loss_total += loss.item()
        train_batches += 1
        train_tqdm.set_postfix({
            'train loss': train_loss_total / train_batches})
    train_tqdm.close()
    lr_scheduler.step()
    model.eval()
    with torch.no_grad():
        test_tqdm = tqdm.tqdm(iterable=test_dl, total=len(test_dl))
        test_tqdm.set_description_str('Test epoch: {:3d}'.format(epoch))
        test_loss_total, test_batches = 0.0, 0
        for test_images, test_labels in test_tqdm:
            test_images, test_labels = test_images.to(device), test_labels.to(device)
            test_pred = model(test_images)
            # Pass raw logits, exactly as in training. Applying softmax first
            # would double-normalise and make the test loss meaningless.
            test_loss = loss_fn(test_pred, test_labels)
            test_loss_total += test_loss.item()
            test_batches += 1
            test_tqdm.set_postfix({
                'test loss': test_loss_total / test_batches})
        test_tqdm.close()
# Save model
os.makedirs('model_data', exist_ok=True)
torch.save(model.state_dict(), os.path.join('model_data', 'model.pth'))
6. The model prediction code is as follows
import torch
import os
import matplotlib.pyplot as plt
from torchvision.utils import draw_segmentation_masks
from data_loader import load_data
from model_loader import UnetModel
# Data loading
train_dl, test_dl = load_data()
# Model loading
model = UnetModel()
model_state_dict = torch.load(os.path.join('model_data', 'model.pth'), map_location='cpu')
model.load_state_dict(model_state_dict)
# Switch to inference mode. The current layers behave the same either way, but
# this keeps predictions correct if dropout/batch-norm is ever added.
model.eval()
# Start Forecasting
images, labels = next(iter(test_dl))
# Clamp so a test batch with a single sample does not raise IndexError.
index = min(1, images.shape[0] - 1)
with torch.no_grad():
    pred = model(images)
# Per-pixel class index; class 1 is the foreground produced by MaskDataset.
pred = torch.argmax(input=pred, dim=1)
result = draw_segmentation_masks(image=torch.as_tensor(data=images[index] * 255, dtype=torch.uint8),
                                 masks=torch.as_tensor(data=pred[index], dtype=torch.bool),
                                 alpha=0.6, colors=['red'])
plt.figure(figsize=(8, 8), dpi=500)
plt.axis('off')
plt.imshow(result.permute(1, 2, 0))
plt.savefig('result.png')
plt.show()
7. The result of running the code is as follows

边栏推荐
- Set vscode. When double clicking, the selected string includes the $symbol - convenient for PHP operation
- Sorted out an ECS summer money saving secret, this time @ old users come and take it away
- Uni app - realize the countdown of 60 seconds to obtain the mobile verification code (mobile number + verification code login function)
- The first practical project of software tester: web side (video tutorial + document + use case library)
- What is 5g industrial wireless gateway? What functions can 5g industrial wireless gateway achieve?
- [JS event -- event flow]
- Wechat applet JWT login issue token
- Vite: configure IP access
- [untitled]
- 10 minutes to understand CMS garbage collector in JVM
猜你喜欢

BGP experiment the next day

Recently, the weather has been extremely hot, so collect the weather data of Beijing, Shanghai, Guangzhou and Shenzhen last year, and make a visual map

Visual slam Lecture 3 -- Lie groups and Lie Algebras

MySQL advanced SQL statement 2

Dare to go out for an interview without learning some distributed technology?
![[Li Kou brush questions] 15 Sum of three numbers (double pointer); 17. Letter combination of phone number (recursive backtracking)](/img/5e/81e613370c808c63665c14298f9a39.png)
[Li Kou brush questions] 15 Sum of three numbers (double pointer); 17. Letter combination of phone number (recursive backtracking)

Target free or target specific: a simple and effective zero sample position detection comparative learning method

How much can a job hopping increase? Today, I saw the ceiling of job hopping.

Déchirure à la main - tri

Nacos 配置中心整体设计原理分析(持久化,集群,信息同步)
随机推荐
Monkey test
【leetcode】74. Search 2D matrix
Common sense of cloud server security settings
XSS prevention
《西线无战事》我们才刚开始热爱生活,却不得不对一切开炮
60后关机程序
[live broadcast review] the first 8 live broadcasts of battle code Pioneer have come to a perfect end. Please look forward to the next one!
【c语言】动态规划---入门到起立
向数据库中存入数组数据,代码出错怎么解决
【小技巧】使用matlab GUI以对话框模式读取文件
SQL Yiwen get window function
Learn more about materialapp and common attribute parsing in fluent
regular expression
10 minutes to understand CMS garbage collector in JVM
Finally got byte offer. The 25-year-old inexperienced perception of software testing is written to you who are still confused
Www2022 | know your way back: self training method of graph neural network under distribution and migration
【leetcode】34. Find the first and last positions of elements in a sorted array
Wechat applet - realize the countdown of 60 seconds to obtain the mobile verification code (mobile number + verification code login function)
2022-07-01:某公司年会上,大家要玩一食发奖金游戏,一共有n个员工, 每个员工都有建设积分和捣乱积分, 他们需要排成一队,在队伍最前面的一定是老板,老板也有建设积分和捣乱积分, 排好队后,所有
Target free or target specific: a simple and effective zero sample position detection comparative learning method