Current location: Home > PyTorch — Using PyTorch to implement LinkNet for semantic segmentation

PyTorch — Using PyTorch to implement LinkNet for semantic segmentation

2022-07-04 23:26:00 Brother Shui is very water

1. The dataset used in the code can be obtained from the following link

Baidu Netdisk — extraction code: f1j7 (the share link itself was not preserved in this copy of the article)

2. Runtime environment

PyTorch (GPU build) == 1.10.1
Python == 3.8

3. Dataset processing code (data_loader.py)

import os
import torch
from torch.utils import data
from torchvision import transforms
from PIL import Image
import matplotlib.pyplot as plt
from torchvision.utils import draw_segmentation_masks


class MaskDataset(data.Dataset):
    """Dataset pairing RGB images with binary segmentation masks.

    Item ``i`` is ``(img_tensor, label_tensor)``:

    * ``img_tensor`` -- ``transform`` applied to ``image_paths[i]`` converted to RGB.
    * ``label_tensor`` -- ``mask_transform`` applied to ``mask_paths[i]``, then
      binarised (every value ``> 0`` becomes 1, everything else 0) and cast to
      ``torch.long`` so it can be used directly as a ``CrossEntropyLoss``
      target. A leading channel axis of size 1 is dropped.

    Args:
        image_paths: Indexable sequence of image file paths.
        mask_paths: Indexable sequence of mask file paths, aligned
            index-by-index with ``image_paths``.
        transform: Callable mapping a PIL image to a tensor.
        mask_transform: Optional callable applied to the mask instead of
            ``transform``; defaults to ``transform``. A mask transform that
            resizes with nearest-neighbour interpolation keeps interpolated
            edge pixels from being turned into foreground by the ``> 0``
            threshold.

    Raises:
        ValueError: If ``image_paths`` and ``mask_paths`` differ in length.
    """

    def __init__(self, image_paths, mask_paths, transform, mask_transform=None):
        super().__init__()
        if len(image_paths) != len(mask_paths):
            raise ValueError(
                'image_paths and mask_paths must have the same length '
                '({} != {})'.format(len(image_paths), len(mask_paths)))
        self.image_paths = image_paths
        self.mask_paths = mask_paths
        self.transform = transform
        self.mask_transform = transform if mask_transform is None else mask_transform

    def __getitem__(self, index):
        image_path = self.image_paths[index]
        label_path = self.mask_paths[index]

        # `with` closes each file handle as soon as the pixels have been read.
        with Image.open(image_path) as pil_img:
            img_tensor = self.transform(pil_img.convert('RGB'))

        with Image.open(label_path) as pil_label:
            label_tensor = self.mask_transform(pil_label)

        # Binarise: any positive value (including partially covered edge pixels) is foreground.
        label_tensor = (label_tensor > 0).long()
        # Drop only the channel axis of a (1, H, W) mask; a bare squeeze() would
        # also collapse a spatial axis of size 1.
        # NOTE(review): multi-channel masks keep their channel axis -- confirm the
        # matte files are single-channel.
        if label_tensor.dim() == 3 and label_tensor.size(0) == 1:
            label_tensor = label_tensor.squeeze(0)

        return img_tensor, label_tensor

    def __len__(self):
        return len(self.mask_paths)


def collect_image_mask_pairs(directory):
    """Return ``(image_paths, mask_paths)`` for every mask file in ``directory``.

    A mask is any file whose name contains ``'matte'`` and whose companion image
    ``<prefix>.png`` exists in the same directory, where ``<prefix>`` is the part
    of the mask name before the first ``'_'``. Names are sorted so the pairing
    order is deterministic (``os.listdir`` order is arbitrary).
    """
    file_names = sorted(os.listdir(directory))
    available = set(file_names)  # O(1) membership tests instead of scanning the list
    mask_names = [name for name in file_names
                  if 'matte' in name and name.split('_')[0] + '.png' in available]
    image_paths = [os.path.join(directory, name.split('_')[0] + '.png') for name in mask_names]
    mask_paths = [os.path.join(directory, name) for name in mask_names]
    return image_paths, mask_paths


def load_data(dataset_path=r'/Users/leeakita/Desktop/hk', batch_size=8, image_size=(256, 256)):
    """Build the training and testing DataLoaders.

    Args:
        dataset_path: Root directory containing ``training`` and ``testing``
            sub-directories. The default is the original author's local path.
        batch_size: Batch size for both loaders.
        image_size: ``(height, width)`` every image and mask is resized to.

    Returns:
        ``(train_dl, test_dl)``; only the training loader is shuffled.
    """
    train_image_paths, train_label_paths = collect_image_mask_pairs(os.path.join(dataset_path, 'training'))
    test_image_paths, test_label_paths = collect_image_mask_pairs(os.path.join(dataset_path, 'testing'))

    transform = transforms.Compose([
        transforms.Resize(tuple(image_size)),
        transforms.ToTensor()
    ])

    train_ds = MaskDataset(image_paths=train_image_paths, mask_paths=train_label_paths, transform=transform)
    test_ds = MaskDataset(image_paths=test_image_paths, mask_paths=test_label_paths, transform=transform)

    train_dl = data.DataLoader(dataset=train_ds, batch_size=batch_size, shuffle=True)
    test_dl = data.DataLoader(dataset=test_ds, batch_size=batch_size)

    return train_dl, test_dl


if __name__ == '__main__':
    # Visual sanity check: overlay one ground-truth mask on its training image.
    train_loader, _ = load_data()
    batch_images, batch_labels = next(iter(train_loader))
    sample_idx = 5
    image = batch_images[sample_idx]
    # draw_segmentation_masks wants masks with a leading "number of masks" axis.
    mask = batch_labels[sample_idx].unsqueeze(dim=0)

    overlay = draw_segmentation_masks(
        image=torch.as_tensor(data=image * 255, dtype=torch.uint8),
        masks=torch.as_tensor(data=mask, dtype=torch.bool),
        alpha=0.6,
        colors=['red'],
    )
    plt.imshow(overlay.permute(1, 2, 0).numpy())
    plt.show()

4. Model construction code (model_loader.py)

from torch import nn
import torch


class ConvBlock(nn.Module):
    """Conv2d -> BatchNorm2d -> ReLU, held in the Sequential ``conv_bn_relu``.

    The attribute name and the layer order fix the ``state_dict`` keys
    (``conv_bn_relu.0.*`` for the convolution, ``conv_bn_relu.1.*`` for the
    batch norm), so both are kept stable for checkpoint compatibility.
    """

    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=1):
        super().__init__()
        layers = [
            nn.Conv2d(in_channels, out_channels, kernel_size, stride=stride, padding=padding),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        ]
        self.conv_bn_relu = nn.Sequential(*layers)

    def forward(self, x):
        """Apply convolution, batch normalisation and ReLU to ``x``."""
        return self.conv_bn_relu(x)


class DecodeConvBlock(nn.Module):
    """Transposed convolution, optionally followed by BatchNorm2d and ReLU.

    With the default arguments (kernel 3, stride 2, padding 1, output padding 1)
    the output height and width are twice the input's.
    """

    def __init__(self, in_channels, out_channels, kernel_size=3, stride=2, padding=1, out_padding=1):
        super().__init__()
        self.de_conv = nn.ConvTranspose2d(in_channels, out_channels, kernel_size,
                                          stride=stride, padding=padding, output_padding=out_padding)
        self.bn = nn.BatchNorm2d(out_channels)

    def forward(self, x, is_act=True):
        """Upsample ``x``; when ``is_act`` is False, return the raw transposed-conv output."""
        upsampled = self.de_conv(x)
        return torch.relu(self.bn(upsampled)) if is_act else upsampled


class EncodeBlock(nn.Module):
    """Encoder stage: halves H and W and maps ``in_channels`` to ``out_channels``.

    It is built from two pairs of ConvBlocks. The first convolution has
    stride 2. A stride-2 ConvBlock applied to the input serves as the
    projection shortcut. Sub-module names and registration order are kept
    fixed because they determine the ``state_dict`` keys.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv1 = ConvBlock(in_channels=in_channels, out_channels=out_channels, stride=2)
        self.conv2 = ConvBlock(in_channels=out_channels, out_channels=out_channels)
        self.conv3 = ConvBlock(in_channels=out_channels, out_channels=out_channels)
        self.conv4 = ConvBlock(in_channels=out_channels, out_channels=out_channels)

        self.short_cut = ConvBlock(in_channels=in_channels, out_channels=out_channels, stride=2)

    def forward(self, x):
        first_half = self.conv2(self.conv1(x))
        projected = self.short_cut(x)
        second_half = self.conv4(self.conv3(first_half + projected))
        # NOTE(review): the second residual adds `first_half` rather than the
        # shortcut-augmented sum that was fed into conv3 (as a standard ResNet
        # block would); kept as-is because it defines trained checkpoints' outputs.
        return first_half + second_half


class DecodeBlock(nn.Module):
    """Decoder stage: 1x1 reduction to ``in_channels // 4``, 2x upsampling, 1x1 expansion.

    Sub-module names (``conv1``, ``de_conv``, ``conv3``) are kept because they
    determine the ``state_dict`` keys.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        mid_channels = in_channels // 4
        self.conv1 = ConvBlock(in_channels=in_channels, out_channels=mid_channels, kernel_size=1, padding=0)
        self.de_conv = DecodeConvBlock(in_channels=mid_channels, out_channels=mid_channels)
        self.conv3 = ConvBlock(in_channels=mid_channels, out_channels=out_channels, kernel_size=1, padding=0)

    def forward(self, x):
        return self.conv3(self.de_conv(self.conv1(x)))


class LinkNet(nn.Module):
    """LinkNet semantic-segmentation network.

    ``forward(x)`` takes an image batch of shape ``(N, 3, H, W)``. It returns
    raw per-pixel class scores (logits) of shape ``(N, num_classes, H, W)``,
    suitable as input to ``nn.CrossEntropyLoss``.

    The input is downsampled by 2 six times: the stride-2 initial conv, the
    2x2 max-pool, and four stride-2 encoder stages. Decoder outputs are
    added to the matching encoder outputs, so H and W should be multiples
    of 64 for those additions to line up.

    Args:
        num_classes: Number of output channels. Defaults to 2, the original
            hard-coded value.
    """

    def __init__(self, num_classes=2):
        super().__init__()
        self.init_conv = ConvBlock(in_channels=3, out_channels=64, stride=2, kernel_size=7, padding=3)
        self.init_maxpool = nn.MaxPool2d(kernel_size=(2, 2))

        self.encode_1 = EncodeBlock(in_channels=64, out_channels=64)
        self.encode_2 = EncodeBlock(in_channels=64, out_channels=128)
        self.encode_3 = EncodeBlock(in_channels=128, out_channels=256)
        self.encode_4 = EncodeBlock(in_channels=256, out_channels=512)

        self.decode_4 = DecodeBlock(in_channels=512, out_channels=256)
        self.decode_3 = DecodeBlock(in_channels=256, out_channels=128)
        self.decode_2 = DecodeBlock(in_channels=128, out_channels=64)
        self.decode_1 = DecodeBlock(in_channels=64, out_channels=64)

        self.deconv_out1 = DecodeConvBlock(in_channels=64, out_channels=32)
        self.conv_out = ConvBlock(in_channels=32, out_channels=32)
        # Its BatchNorm is never used (see forward); it stays registered so the
        # state_dict keys of existing checkpoints still match.
        self.deconv_out2 = DecodeConvBlock(in_channels=32, out_channels=num_classes,
                                           kernel_size=2, padding=0, out_padding=0)

    def forward(self, x):
        x = self.init_conv(x)
        x = self.init_maxpool(x)

        e1 = self.encode_1(x)
        e2 = self.encode_2(e1)
        e3 = self.encode_3(e2)
        e4 = self.encode_4(e3)

        # Skip connections: each decoder output is added to the encoder output of the same size.
        d4 = self.decode_4(e4)
        d3 = self.decode_3(d4 + e3)
        d2 = self.decode_2(d3 + e2)
        d1 = self.decode_1(d2 + e1)

        f1 = self.deconv_out1(d1)
        f2 = self.conv_out(f1)
        # The last layer must emit raw logits. Running them through BatchNorm and
        # ReLU (the is_act=True default) would clamp every negative class score
        # to zero and cut its gradient.
        f3 = self.deconv_out2(f2, is_act=False)
        return f3

5. Model training code

import torch
from data_loader import load_data
from model_loader import LinkNet
from torch import nn
from torch import optim
import tqdm
import os

def batch_iou(labels, logits):
    """Return the foreground IoU of one batch as a 0-dim float tensor.

    ``labels`` holds per-pixel class ids and must have the same shape as
    ``logits.argmax(dim=1)``. Non-zero label ids and non-zero predicted classes
    count as foreground. If neither the labels nor the prediction contain any
    foreground, the union is empty. That batch is scored 1.0 (perfect
    agreement) instead of 0/0 = NaN, which would poison the running mean.
    """
    predicted = torch.argmax(logits, dim=1)
    intersection = torch.logical_and(labels, predicted).sum()
    union = torch.logical_or(labels, predicted).sum()
    if union.item() == 0:
        return torch.ones((), device=logits.device)
    return torch.true_divide(intersection, union)


#  Configuration of environment variables
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

#  Load data
train_dl, test_dl = load_data()

#  Load model
model = LinkNet().to(device=device)

#  Training related configurations
loss_fn = nn.CrossEntropyLoss()
optimizer = optim.Adam(params=model.parameters(), lr=0.001)
lr_scheduler = optim.lr_scheduler.StepLR(optimizer=optimizer, step_size=5, gamma=0.7)
NUM_EPOCHS = 100

#  Start training
for epoch in range(NUM_EPOCHS):
    # train() re-enables batch statistics in BatchNorm after the previous epoch's eval().
    model.train()
    train_tqdm = tqdm.tqdm(iterable=train_dl, total=len(train_dl))
    train_tqdm.set_description_str('Train epoch: {:3d}'.format(epoch))
    train_loss_total = 0.0
    train_iou_total = 0.0
    for step, (train_images, train_labels) in enumerate(train_tqdm, start=1):
        train_images, train_labels = train_images.to(device), train_labels.to(device)
        pred = model(train_images)
        loss = loss_fn(pred, train_labels)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Running sums keep the per-step cost constant (no growing tensor concatenation).
        train_loss_total += loss.item()
        with torch.no_grad():
            train_iou_total += batch_iou(train_labels, pred).item()
        train_tqdm.set_postfix({
            'train loss': train_loss_total / step,
            'train iou': train_iou_total / step,
        })
    train_tqdm.close()

    lr_scheduler.step()

    # eval() makes BatchNorm normalise with its running statistics and stops it
    # from updating those statistics with test batches.
    model.eval()
    with torch.no_grad():
        test_tqdm = tqdm.tqdm(iterable=test_dl, total=len(test_dl))
        test_tqdm.set_description_str('Test epoch: {:3d}'.format(epoch))
        test_loss_total = 0.0
        test_iou_total = 0.0
        for step, (test_images, test_labels) in enumerate(test_tqdm, start=1):
            test_images, test_labels = test_images.to(device), test_labels.to(device)
            test_pred = model(test_images)
            # CrossEntropyLoss applies log-softmax internally and expects raw
            # logits, exactly as in the training loop. Applying softmax first
            # would normalise twice and make test loss incomparable to train loss.
            test_loss = loss_fn(test_pred, test_labels)

            test_loss_total += test_loss.item()
            test_iou_total += batch_iou(test_labels, test_pred).item()
            test_tqdm.set_postfix({
                'test loss': test_loss_total / step,
                'test iou': test_iou_total / step,
            })
        test_tqdm.close()

#  Save model
os.makedirs('model_data', exist_ok=True)
torch.save(model.state_dict(), os.path.join('model_data', 'model.pth'))

6. Model prediction code

import torch
import os
import matplotlib.pyplot as plt
from torchvision.utils import draw_segmentation_masks
from data_loader import load_data
from model_loader import LinkNet

#  Data loading
train_dl, test_dl = load_data()

#  Model loading
model = LinkNet()
model_state_dict = torch.load(os.path.join('model_data', 'model.pth'), map_location='cpu')
model.load_state_dict(model_state_dict)
# Switch BatchNorm to the running statistics learned during training. In the
# default train mode, predictions would depend on this particular test batch,
# and each forward pass would overwrite those statistics.
model.eval()

#  Start Forecasting
images, labels = next(iter(test_dl))
index = 2
with torch.no_grad():
    pred = model(images)
    pred = torch.argmax(input=pred, dim=1)
    result = draw_segmentation_masks(image=torch.as_tensor(data=images[index] * 255, dtype=torch.uint8),
                                     masks=torch.as_tensor(data=pred[index], dtype=torch.bool),
                                     alpha=0.8, colors=['red'])
    plt.figure(figsize=(8, 8), dpi=500)
    plt.axis('off')
    plt.imshow(result.permute(1, 2, 0).numpy())
    plt.savefig('result.png')
    plt.show()

7. Results

[Figure: output of the prediction script — the predicted mask overlaid in red on a test image (result.png)]

原网站

版权声明
本文为[Brother Shui is very water]所创,转载请带上原文链接,感谢
https://yzsam.com/2022/185/202207041941572972.html