2022-07-04 19:42:00 【水哥很水】
import os
import torch
from torch.utils import data
from torchvision import transforms
from PIL import Image
import matplotlib.pyplot as plt
from torchvision.utils import draw_segmentation_masks
class MaskDataset(data.Dataset):
    """Dataset yielding (image, binary mask) pairs from two index-aligned path lists.

    Args:
        image_paths: Sequence of image file paths.
        mask_paths: Sequence of mask file paths; entry ``i`` is the mask for
            ``image_paths[i]``.
        transform: Callable applied to both the image and the mask. It must
            return a tensor, because the mask is binarized with tensor indexing.
    """

    def __init__(self, image_paths, mask_paths, transform):
        super().__init__()
        self.image_paths = image_paths
        self.mask_paths = mask_paths
        self.transform = transform

    def __getitem__(self, index):
        """Return ``(img_tensor, label_tensor)`` for the sample at ``index``.

        The image is converted to RGB before the transform. In the mask every
        value greater than 0 becomes 1, size-1 dimensions are squeezed away and
        the result is cast to int64.
        """
        # Context managers release the underlying file handles deterministically.
        with Image.open(self.image_paths[index]) as pil_img:
            img_tensor = self.transform(pil_img.convert('RGB'))
        with Image.open(self.mask_paths[index]) as pil_label:
            label_tensor = self.transform(pil_label)
        # Collapse the matte to a binary foreground/background mask.
        label_tensor[label_tensor > 0] = 1
        label_tensor = torch.squeeze(input=label_tensor).long()
        return img_tensor, label_tensor

    def __len__(self):
        """Return the number of masks (the dataset length)."""
        return len(self.mask_paths)
def _paired_paths(directory):
    """Return ``(image_paths, matte_paths)`` for every matte in ``directory`` whose source image exists."""
    file_names = os.listdir(directory)
    # Set membership keeps the pairing linear instead of quadratic in the file count.
    known_names = set(file_names)
    matte_names = [name for name in file_names
                   if 'matte' in name and name.split('_')[0] + '.png' in known_names]
    image_paths = [os.path.join(directory, name.split('_')[0] + '.png') for name in matte_names]
    matte_paths = [os.path.join(directory, name) for name in matte_names]
    return image_paths, matte_paths


def load_data(dataset_path=r'/Users/leeakita/Desktop/hk', batch_size=8):
    """Build the training and testing data loaders.

    Files are read from ``<dataset_path>/training`` and ``<dataset_path>/testing``.
    A file whose name contains ``matte`` is used as a mask when the file named
    ``<text before the first '_'>.png`` exists in the same directory.

    Args:
        dataset_path: Dataset root directory. Defaults to the path hard-coded
            in the original script (an alternative seen there was
            ``/home/akita/hk``).
        batch_size: Batch size for both loaders. The original referenced a
            ``BATCH_SIZE`` constant that is not defined in the visible source;
            the default of 8 is an assumption - TODO confirm.

    Returns:
        ``(train_dl, test_dl)``; only the training loader shuffles.
    """
    train_image_paths, train_label_paths = _paired_paths(os.path.join(dataset_path, 'training'))
    test_image_paths, test_label_paths = _paired_paths(os.path.join(dataset_path, 'testing'))
    transform = transforms.Compose([
        transforms.Resize((256, 256)),
        # MaskDataset indexes the transformed mask as a tensor, so a tensor conversion is required.
        transforms.ToTensor(),
    ])
    train_ds = MaskDataset(image_paths=train_image_paths, mask_paths=train_label_paths, transform=transform)
    test_ds = MaskDataset(image_paths=test_image_paths, mask_paths=test_label_paths, transform=transform)
    train_dl = data.DataLoader(dataset=train_ds, batch_size=batch_size, shuffle=True)
    test_dl = data.DataLoader(dataset=test_ds, batch_size=batch_size)
    return train_dl, test_dl
if __name__ == '__main__':
    # Visual smoke test: overlay one training mask on its image.
    train_my, test_my = load_data()
    images, labels = next(iter(train_my))
    # Prefer sample 5, but clamp so batches with fewer than six items still work.
    indexx = min(5, images.shape[0] - 1)
    images = images[indexx]
    labels = labels[indexx]
    # draw_segmentation_masks is given the mask with a leading mask dimension.
    labels = torch.unsqueeze(input=labels, dim=0)
    result = draw_segmentation_masks(image=torch.as_tensor(data=images * 255, dtype=torch.uint8),
                                     masks=torch.as_tensor(data=labels, dtype=torch.bool),
                                     alpha=0.6, colors=['red'])
    plt.imshow(result.permute(1, 2, 0).numpy())
    # Without show() nothing is displayed when this file is run as a script.
    plt.show()
from torch import nn
import torch
class ConvBlock(nn.Module):
    """Conv2d followed by BatchNorm2d and ReLU.

    The original ``nn.Sequential`` call was truncated after ``stride=stride,``;
    the remaining layers are reconstructed from the attribute name
    ``conv_bn_relu`` and the otherwise unused ``padding`` parameter.

    Args:
        in_channels: Number of input channels.
        out_channels: Number of output channels.
        kernel_size: Convolution kernel size.
        stride: Convolution stride.
        padding: Convolution padding.
    """

    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=1):
        super().__init__()
        self.conv_bn_relu = nn.Sequential(
            nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, stride=stride,
                      padding=padding),
            nn.BatchNorm2d(num_features=out_channels),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        """Apply convolution, batch normalization and ReLU to ``x``."""
        return self.conv_bn_relu(x)
class DecodeConvBlock(nn.Module):
    """Transposed convolution with optional BatchNorm + ReLU.

    Args:
        in_channels: Number of input channels.
        out_channels: Number of output channels.
        kernel_size: Transposed-convolution kernel size.
        stride: Transposed-convolution stride.
        padding: Transposed-convolution padding.
        out_padding: Forwarded as ``output_padding``.
    """

    def __init__(self, in_channels, out_channels, kernel_size=3, stride=2, padding=1, out_padding=1):
        super().__init__()
        self.de_conv = nn.ConvTranspose2d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size,
                                          stride=stride, padding=padding, output_padding=out_padding)
        self.bn = nn.BatchNorm2d(num_features=out_channels)

    def forward(self, x, is_act=True):
        """Upsample ``x``; when ``is_act`` is false, skip both batch norm and ReLU."""
        x = self.de_conv(x)
        if is_act:
            x = torch.relu(self.bn(x))
        return x
class EncodeBlock(nn.Module):
    """Encoder stage built from four ConvBlocks and a strided shortcut.

    ``conv1`` and ``short_cut`` both use stride 2; the other convolutions keep
    the stride at 1.

    Args:
        in_channels: Number of input channels.
        out_channels: Number of output channels.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv1 = ConvBlock(in_channels=in_channels, out_channels=out_channels, stride=2)
        self.conv2 = ConvBlock(in_channels=out_channels, out_channels=out_channels)
        self.conv3 = ConvBlock(in_channels=out_channels, out_channels=out_channels)
        self.conv4 = ConvBlock(in_channels=out_channels, out_channels=out_channels)
        self.short_cut = ConvBlock(in_channels=in_channels, out_channels=out_channels, stride=2)

    def forward(self, x):
        """Run both convolution pairs and return the sum of their outputs."""
        out1 = self.conv1(x)
        out1 = self.conv2(out1)
        short_cut = self.short_cut(x)
        out2 = self.conv3(out1 + short_cut)
        out2 = self.conv4(out2)
        # NOTE(review): the sum uses out1, not (out1 + short_cut); kept as written - confirm intent.
        return out1 + out2
class DecodeBlock(nn.Module):
    """Decoder stage: 1x1 conv to ``in_channels // 4``, transposed conv, 1x1 conv to ``out_channels``.

    Args:
        in_channels: Number of input channels.
        out_channels: Number of output channels.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv1 = ConvBlock(in_channels=in_channels, out_channels=in_channels // 4, kernel_size=1, padding=0)
        self.de_conv = DecodeConvBlock(in_channels=in_channels // 4, out_channels=in_channels // 4)
        self.conv3 = ConvBlock(in_channels=in_channels // 4, out_channels=out_channels, kernel_size=1, padding=0)

    def forward(self, x):
        """Apply the three sub-blocks in order and return the result."""
        x = self.conv1(x)
        x = self.de_conv(x)
        x = self.conv3(x)
        return x
class LinkNet(nn.Module):
    """Encoder-decoder segmentation network with additive skip connections.

    Takes 3-channel input and produces 2-channel output. The stem and each of
    the four encoders use a stride-2 layer; the four decoders and the two
    output transposed convolutions use stride 2 in the other direction.
    """

    def __init__(self):
        super().__init__()
        self.init_conv = ConvBlock(in_channels=3, out_channels=64, stride=2, kernel_size=7, padding=3)
        self.init_maxpool = nn.MaxPool2d(kernel_size=(2, 2))
        self.encode_1 = EncodeBlock(in_channels=64, out_channels=64)
        self.encode_2 = EncodeBlock(in_channels=64, out_channels=128)
        self.encode_3 = EncodeBlock(in_channels=128, out_channels=256)
        self.encode_4 = EncodeBlock(in_channels=256, out_channels=512)
        self.decode_4 = DecodeBlock(in_channels=512, out_channels=256)
        self.decode_3 = DecodeBlock(in_channels=256, out_channels=128)
        self.decode_2 = DecodeBlock(in_channels=128, out_channels=64)
        self.decode_1 = DecodeBlock(in_channels=64, out_channels=64)
        self.deconv_out1 = DecodeConvBlock(in_channels=64, out_channels=32)
        self.conv_out = ConvBlock(in_channels=32, out_channels=32)
        self.deconv_out2 = DecodeConvBlock(in_channels=32, out_channels=2, kernel_size=2, padding=0, out_padding=0)

    def forward(self, x):
        """Return raw per-pixel class logits for the image batch ``x``."""
        x = self.init_conv(x)
        x = self.init_maxpool(x)
        e1 = self.encode_1(x)
        e2 = self.encode_2(e1)
        e3 = self.encode_3(e2)
        e4 = self.encode_4(e3)
        # Each decoder output is summed with the encoder feature map of matching size.
        d4 = self.decode_4(e4)
        d3 = self.decode_3(d4 + e3)
        d2 = self.decode_2(d3 + e2)
        d1 = self.decode_1(d2 + e1)
        f1 = self.deconv_out1(d1)
        f2 = self.conv_out(f1)
        # The output feeds CrossEntropyLoss/argmax, so it must stay unactivated:
        # the default is_act=True would clamp every logit to >= 0.
        f3 = self.deconv_out2(f2, is_act=False)
        return f3
import torch
from data_loader import load_data
from model_loader import LinkNet
from torch import nn
from torch import optim
import tqdm
import os
# Environment configuration
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Load the data
train_dl, test_dl = load_data()
# Build the model
model = LinkNet()
model = model.to(device=device)
# Training configuration
loss_fn = nn.CrossEntropyLoss()
optimizer = optim.Adam(params=model.parameters(), lr=0.001)
lr_scheduler = optim.lr_scheduler.StepLR(optimizer=optimizer, step_size=5, gamma=0.7)


def _batch_iou(labels, logits):
    """Return the IoU of ``labels`` and ``argmax(logits, dim=1)`` over the whole batch (nan if the union is empty)."""
    predicted = torch.argmax(input=logits, dim=1)
    intersection = torch.logical_and(input=labels, other=predicted)
    union = torch.logical_or(input=labels, other=predicted)
    return torch.true_divide(torch.sum(intersection), torch.sum(union))


# Start training
for epoch in range(100):
    model.train()
    train_tqdm = tqdm.tqdm(iterable=train_dl, total=len(train_dl))
    train_tqdm.set_description_str('Train epoch: {:3d}'.format(epoch))
    train_loss_sum = torch.tensor(data=[], dtype=torch.float, device=device)
    train_iou_sum = torch.tensor(data=[], dtype=torch.float, device=device)
    for train_images, train_labels in train_tqdm:
        train_images, train_labels = train_images.to(device), train_labels.to(device)
        pred = model(train_images)
        loss = loss_fn(pred, train_labels)
        # The optimization step was missing from the original loop, so the weights never changed.
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        with torch.no_grad():
            batch_iou = _batch_iou(train_labels, pred)
            train_iou_sum = torch.cat([train_iou_sum, torch.unsqueeze(input=batch_iou, dim=-1)], dim=-1)
            # Detach so the running statistics do not keep every batch's autograd graph alive.
            train_loss_sum = torch.cat([train_loss_sum, torch.unsqueeze(input=loss.detach(), dim=-1)], dim=-1)
            train_tqdm.set_postfix({
                'train loss': train_loss_sum.mean().item(),
                'train iou': train_iou_sum.mean().item()
            })
    # Advance the StepLR schedule once per epoch.
    lr_scheduler.step()
    # Evaluation mode makes BatchNorm use its running statistics.
    model.eval()
    with torch.no_grad():
        test_tqdm = tqdm.tqdm(iterable=test_dl, total=len(test_dl))
        test_tqdm.set_description_str('Test epoch: {:3d}'.format(epoch))
        test_loss_sum = torch.tensor(data=[], dtype=torch.float, device=device)
        test_iou_sum = torch.tensor(data=[], dtype=torch.float, device=device)
        for test_images, test_labels in test_tqdm:
            test_images, test_labels = test_images.to(device), test_labels.to(device)
            test_pred = model(test_images)
            # CrossEntropyLoss takes raw logits; the extra softmax made this
            # value incomparable with the training loss.
            test_loss = loss_fn(test_pred, test_labels)
            test_batch_iou = _batch_iou(test_labels, test_pred)
            test_iou_sum = torch.cat([test_iou_sum, torch.unsqueeze(input=test_batch_iou, dim=-1)], dim=-1)
            test_loss_sum = torch.cat([test_loss_sum, torch.unsqueeze(input=test_loss, dim=-1)], dim=-1)
            test_tqdm.set_postfix({
                'test loss': test_loss_sum.mean().item(),
                'test iou': test_iou_sum.mean().item()
            })
# Save the model
# Create the output directory when needed instead of only testing for it.
os.makedirs('model_data', exist_ok=True)
torch.save(model.state_dict(), os.path.join('model_data', 'model.pth'))
import torch
import os
import matplotlib.pyplot as plt
from torchvision.utils import draw_segmentation_masks
from data_loader import load_data
from model_loader import LinkNet
# Load the data
train_dl, test_dl = load_data()
# Load the model
model = LinkNet()
model_state_dict = torch.load(os.path.join('model_data', 'model.pth'), map_location='cpu')
# The loaded weights must be applied; otherwise predictions come from a randomly initialized network.
model.load_state_dict(model_state_dict)
# Evaluation mode makes BatchNorm use its stored running statistics.
model.eval()
# Run the prediction
images, labels = next(iter(test_dl))
# Prefer sample 2, but clamp so batches with fewer than three items still work.
index = min(2, images.shape[0] - 1)
with torch.no_grad():
    pred = model(images)
    pred = torch.argmax(input=pred, dim=1)
result = draw_segmentation_masks(image=torch.as_tensor(data=images[index] * 255, dtype=torch.uint8),
                                 masks=torch.as_tensor(data=pred[index], dtype=torch.bool),
                                 alpha=0.8, colors=['red'])
plt.figure(figsize=(8, 8), dpi=500)
plt.imshow(result.permute(1, 2, 0).numpy())
# Without show() nothing is displayed when this file is run as a script.
plt.show()
- 易周金融 | Q1保险行业活跃人数8688.67万人 19家支付机构牌照被注销
- Flet tutorial 08: basic introduction to the AppBar toolbar (tutorial includes source code)
- QT writing the Internet of things management platform 38- multiple database support
- 电脑页面不能全屏怎么办?Win11页面不能全屏的解决方法
- 阿里测试师用UI自动化测试实现元素定位
- Jenkins初始化密码没有或找不到
- 强化学习-学习笔记2 | 价值学习
- MySQL statement execution details
- Taishan Office Technology Lecture: about the order of background (shading and highlighting)
- What if the brightness of win11 is locked? Solution to win11 brightness locking
Flet tutorial 04 basic introduction to filledtonalbutton (tutorial includes source code)
How to adapt your games to different sizes of mobile screen
Four traversal methods of binary tree, as well as the creation of binary tree from middle order to post order, pre order to middle order, pre order to post order, and sequence [specially created for t
Aiming at the "amnesia" of deep learning, scientists proposed that based on similarity weighted interleaved learning, they can board PNAS
Leetcode+ 81 - 85 monotone stack topic
NetCore3.1 Json web token 中间件
ICML 2022 | Meta提出鲁棒的多目标贝叶斯优化方法,有效应对输入噪声
NLP, vision, chip What is the development direction of AI? Release of the outlook report of Qingyuan Association [download attached]
Advantages of semantic tags and block level inline elements
acwing 3302. 表达式求值
NetCore3.1 Json web token 中间件
Six stones programming: about code, there are six triumphs
Après l'insertion de l'image dans le mot, il y a une ligne vide au - dessus de l'image, et la disposition est désordonnée après la suppression
ICML 2022 | meta proposes a robust multi-objective Bayesian optimization method to effectively deal with input noise
How does the computer save web pages to the desktop for use
In operation (i.e. included in) usage of SSRs filter
Idea restore default shortcut key
Play the music of youth
Flet tutorial 05 outlinedbutton basic introduction (tutorial includes source code)
Idea configuration standard notes
What if win11u disk refuses access? An effective solution to win11u disk access denial