当前位置:网站首页>Swin-Unet最强分割网络

Swin-Unet最强分割网络

2022-06-10 08:05:00 算法之名

Swin-Unet以Swin Transformer为基础(可参考Swin Transformer介绍 ),结合U-Net网络的特点(可参考Tensorflow深度学习算法整理(三) 中的U-Net),组合而成一种新的分割网络。

它与Swin Transformer不同的地方在于:编码器(Encoder)这边虽然跟Swin Transformer一样有4个Stage,但(按下文代码中的默认参数depths=[2, 2, 2, 2])每个Stage中Swin Transformer Block的数量为[2,2,2,2],而不是Swin Transformer-T的[2,2,6,2]。而在解码器(Decoder)这边,由于是上采样,使用的不再是Patch Embedding和Patch Merging,而是Patch Expanding,它是Patch Merging的逆过程。

我们来看一下Patch Expanding的代码实现

from einops import rearrange
class PatchExpand(nn.Module):
    """Patch Expanding: double the spatial resolution, halve the channel count.

    This is the inverse of Patch Merging. With dim_scale == 2, a bias-free
    linear layer first maps C -> 2C channels. Each token's 2C channels are then
    redistributed over a 2x2 block of output pixels (2C / 4 = C / 2 channels
    per pixel), and the result is normalised over the channel dimension.
    """

    def __init__(self, input_resolution, dim, dim_scale=2, norm_layer=nn.LayerNorm):
        """
        Args:
            input_resolution: (H, W) of the incoming decoder feature map, in tokens.
            dim: number of input channels C.
            dim_scale: when 2, the C -> 2C linear layer is used; any other value
                replaces it with nn.Identity. The norm layer is built for
                dim // dim_scale channels. The spatial factor is always 2.
            norm_layer: normalisation layer applied over the channel dimension.
        """
        super().__init__()
        self.input_resolution = input_resolution
        self.dim = dim
        if dim_scale == 2:
            # C -> 2C, so that splitting each token into 2x2 pixels leaves C/2 channels each
            self.expand = nn.Linear(dim, 2 * dim, bias=False)
        else:
            self.expand = nn.Identity()
        self.norm = norm_layer(dim // dim_scale)

    def forward(self, x):
        """
        Args:
            x: tokens of shape (B, H*W, C) with (H, W) = self.input_resolution.

        Returns:
            Tokens of shape (B, 4*H*W, C' // 4), where C' is the channel count
            after self.expand.
        """
        height, width = self.input_resolution
        x = self.expand(x)
        batch, num_tokens, channels = x.shape
        assert num_tokens == height * width, "input feature has wrong size"
        per_pixel = channels // 4
        # (B, H, W, p1, p2, c) -> (B, H, p1, W, p2, c) -> (B, 2H, 2W, c):
        # channel group (p1, p2) of token (h, w) becomes pixel (2h + p1, 2w + p2).
        # Same layout as einops 'b h w (p1 p2 c) -> b (h p1) (w p2) c'.
        x = x.view(batch, height, width, 2, 2, per_pixel)
        x = x.permute(0, 1, 3, 2, 4, 5).reshape(batch, height * 2, width * 2, per_pixel)
        x = x.reshape(batch, -1, per_pixel)
        return self.norm(x)

编码器这边基本上跟Swin Transformer是一样的,我们重点来看解码器这边。解码器使用BasicLayer_up类把SwinTransformerBlock和Patch Expanding组合在一起。

class BasicLayer_up(nn.Module):
    """One Swin-Unet decoder stage.

    A stack of ``depth`` SwinTransformerBlocks, alternating regular window
    attention (W-MSA, shift_size=0, even indices) and shifted window attention
    (SW-MSA, shift_size=window_size // 2, odd indices). The stack is optionally
    followed by an upsampling layer (Patch Expanding).
    """

    def __init__(self, dim, input_resolution, depth, num_heads, window_size,
                 mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0.,
                 drop_path=0., norm_layer=nn.LayerNorm, upsample=None, use_checkpoint=False):
        """
        Args:
            dim: number of feature-map channels.
            input_resolution: (H, W) of the feature map, in tokens.
            depth: number of SwinTransformerBlocks in this stage.
            num_heads: number of attention heads.
            window_size: window side length (in patches) for window attention.
            mlp_ratio: hidden/in channel ratio of the MLP inside each block.
            qkv_bias: if True, add a learnable bias to Q, K and V.
            qk_scale: constant used in the attention formula (passed to the blocks).
            drop: dropout rate, default 0.
            attn_drop: dropout rate inside attention, default 0.
            drop_path: stochastic-depth rate, either one float for every block or
                a list with one rate per block.
            norm_layer: normalisation layer over the channel dimension.
            upsample: None for no upsampling, otherwise the upsampling layer class
                (e.g. PatchExpand). It is instantiated as
                ``upsample(input_resolution, dim=dim, dim_scale=2, norm_layer=norm_layer)``.
                A truthy value that is not callable falls back to PatchExpand.
            use_checkpoint: if True, run each block through checkpoint.checkpoint.
        """
        super().__init__()
        self.dim = dim
        self.input_resolution = input_resolution
        self.depth = depth
        self.use_checkpoint = use_checkpoint

        # build SwinTransformerBlocks
        self.blocks = nn.ModuleList([
            SwinTransformerBlock(dim=dim, input_resolution=input_resolution,
                                 num_heads=num_heads, window_size=window_size,
                                 # 0 -> W-MSA, window_size // 2 -> SW-MSA
                                 shift_size=0 if (i % 2 == 0) else window_size // 2,
                                 mlp_ratio=mlp_ratio,
                                 qkv_bias=qkv_bias, qk_scale=qk_scale,
                                 drop=drop, attn_drop=attn_drop,
                                 drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
                                 norm_layer=norm_layer)
            for i in range(depth)])

        # patch expanding (upsampling) layer; None for the last decoder stage
        if upsample is None:
            self.upsample = None
        else:
            # honour the injected layer class instead of always building PatchExpand
            upsample_cls = upsample if callable(upsample) else PatchExpand
            self.upsample = upsample_cls(input_resolution, dim=dim, dim_scale=2, norm_layer=norm_layer)

    def forward(self, x):
        """Run every block in order, then the optional upsampling layer."""
        for blk in self.blocks:
            if self.use_checkpoint:
                x = checkpoint.checkpoint(blk, x)
            else:
                x = blk(x)
        if self.upsample is not None:
            x = self.upsample(x)
        return x

SwinTransformerBlock跟SwinTransformer中的代码也是一样的,这里就不重复了。

此外,编码器与解码器之间还有跳连(skip connection)。这里需要看一下Swin-Unet的主类代码:

class SwinTransformerSys(nn.Module):
    """Swin-Unet segmentation network.

    Encoder: PatchEmbed followed by ``len(depths)`` BasicLayer stages; every stage
    except the last ends in PatchMerging. Decoder: a PatchExpand stage followed by
    BasicLayer_up stages. Each later stage receives the matching encoder feature
    through a skip connection and a channel-halving linear layer. Head (only
    when ``final_upsample == "expand_first"``): FinalPatchExpand_X4 plus a 1x1
    convolution with ``num_classes`` output channels.
    """

    def __init__(self, img_size=224, patch_size=4, in_chans=3, num_classes=1000,
                 embed_dim=96, depths=[2, 2, 2, 2], depths_decoder=[1, 2, 2, 2], num_heads=[3, 6, 12, 24],
                 window_size=7, mlp_ratio=4., qkv_bias=True, qk_scale=None,
                 drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1,
                 norm_layer=nn.LayerNorm, ape=False, patch_norm=True,
                 use_checkpoint=False, final_upsample="expand_first", **kwargs):
        """
        Args:
            img_size: input image size (passed to PatchEmbed).
            patch_size: side length in pixels of one patch.
            in_chans: number of input image channels.
            num_classes: number of output channels of the segmentation head.
            embed_dim: channels after patch embedding; encoder stage i has
                ``embed_dim * 2**i`` channels.
            depths: SwinTransformerBlocks per encoder stage. Decoder stages reuse
                these counts in reverse order.
            depths_decoder: only printed below; it is NOT used to build the decoder.
            num_heads: attention heads per stage, indexed like ``depths``.
            window_size: window side length (in patches) for window attention.
            mlp_ratio: hidden/in channel ratio of the MLP in each block.
            qkv_bias: if True, add a learnable bias to Q, K and V.
            qk_scale: constant used in the attention formula (passed to the blocks).
            drop_rate: dropout rate, default 0.
            attn_drop_rate: dropout rate inside attention, default 0.
            drop_path_rate: largest stochastic-depth rate. Per-block rates grow
                linearly from 0 to this value over the encoder blocks.
            norm_layer: normalisation layer over the channel dimension.
            ape: if True, add a learnable absolute position embedding.
            patch_norm: if True, normalise after patch embedding.
            use_checkpoint: forwarded to the stages (activation checkpointing).
            final_upsample: "expand_first" builds the x4 expansion + 1x1 conv
                head; any other value builds no head.
            **kwargs: ignored.
        """
        super().__init__()
        print("SwinTransformerSys expand initial----depths:{};depths_decoder:{};drop_path_rate:{};num_classes:{}".format(depths,
              depths_decoder, drop_path_rate, num_classes))

        self.num_classes = num_classes
        self.num_layers = len(depths)  # number of stages
        self.embed_dim = embed_dim
        self.ape = ape
        self.patch_norm = patch_norm
        # channels of the deepest encoder stage (768 for embed_dim=96 with 4 stages)
        self.num_features = int(embed_dim * 2 ** (self.num_layers - 1))
        # 2 * embed_dim; stored but not read by the code in this class
        self.num_features_up = int(embed_dim * 2)
        self.mlp_ratio = mlp_ratio
        self.final_upsample = final_upsample

        # split the image into non-overlapping patches
        self.patch_embed = PatchEmbed(
            img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim,
            norm_layer=norm_layer if self.patch_norm else None)
        num_patches = self.patch_embed.num_patches
        # (H, W) of the token grid after patch embedding
        patches_resolution = self.patch_embed.patches_resolution
        self.patches_resolution = patches_resolution

        # absolute position embedding
        if self.ape:
            self.absolute_pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim))
            trunc_normal_(self.absolute_pos_embed, std=.02)

        self.pos_drop = nn.Dropout(p=drop_rate)

        # stochastic depth decay rule: linearly increasing drop-path rates
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]

        # encoder stages
        self.layers = nn.ModuleList()
        for i_layer in range(self.num_layers):
            layer = BasicLayer(dim=int(embed_dim * 2 ** i_layer),
                               input_resolution=(patches_resolution[0] // (2 ** i_layer),
                                                 patches_resolution[1] // (2 ** i_layer)),
                               depth=depths[i_layer],
                               num_heads=num_heads[i_layer],
                               window_size=window_size,
                               mlp_ratio=self.mlp_ratio,
                               qkv_bias=qkv_bias, qk_scale=qk_scale,
                               drop=drop_rate, attn_drop=attn_drop_rate,
                               drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])],
                               norm_layer=norm_layer,
                               # every stage but the last ends with PatchMerging
                               downsample=PatchMerging if (i_layer < self.num_layers - 1) else None,
                               use_checkpoint=use_checkpoint)
            self.layers.append(layer)

        # decoder stages; decoder stage i_layer mirrors encoder stage enc_idx
        self.layers_up = nn.ModuleList()
        self.concat_back_dim = nn.ModuleList()
        for i_layer in range(self.num_layers):
            enc_idx = self.num_layers - 1 - i_layer
            stage_dim = int(embed_dim * 2 ** enc_idx)
            stage_resolution = (patches_resolution[0] // (2 ** enc_idx),
                                patches_resolution[1] // (2 ** enc_idx))
            # halves the channels again after concatenating the skip feature
            # (created before layer_up to keep the original parameter-creation order)
            concat_linear = nn.Linear(2 * stage_dim, stage_dim) if i_layer > 0 else nn.Identity()
            if i_layer == 0:
                # first decoder stage only upsamples
                layer_up = PatchExpand(input_resolution=stage_resolution,
                                       dim=stage_dim, dim_scale=2, norm_layer=norm_layer)
            else:
                layer_up = BasicLayer_up(dim=stage_dim,
                                         input_resolution=stage_resolution,
                                         depth=depths[enc_idx],
                                         num_heads=num_heads[enc_idx],
                                         window_size=window_size,
                                         mlp_ratio=self.mlp_ratio,
                                         qkv_bias=qkv_bias, qk_scale=qk_scale,
                                         drop=drop_rate, attn_drop=attn_drop_rate,
                                         drop_path=dpr[sum(depths[:enc_idx]):sum(depths[:enc_idx + 1])],
                                         norm_layer=norm_layer,
                                         # every decoder stage but the last ends with PatchExpand
                                         upsample=PatchExpand if (i_layer < self.num_layers - 1) else None,
                                         use_checkpoint=use_checkpoint)
            self.layers_up.append(layer_up)
            self.concat_back_dim.append(concat_linear)

        self.norm = norm_layer(self.num_features)
        self.norm_up = norm_layer(self.embed_dim)

        # segmentation head after the last decoder stage
        if self.final_upsample == "expand_first":
            print("---final upsample expand_first---")
            # use the patch-embedding token grid, i.e. the same (H, W) that up_x4
            # reads from self.patches_resolution; this also works for non-square img_size
            self.up = FinalPatchExpand_X4(input_resolution=tuple(patches_resolution), dim_scale=4, dim=embed_dim)
            self.output = nn.Conv2d(in_channels=embed_dim, out_channels=self.num_classes, kernel_size=1, bias=False)

        self.apply(self._init_weights)

这里用到了FinalPatchExpand_X4类,我们来看一下它的实现:

class FinalPatchExpand_X4(nn.Module):
    """Final Patch Expanding applied after the last decoder stage.

    Upsamples the token grid by ``dim_scale`` along each spatial axis (x4 by
    default) while keeping the channel count equal to ``dim``.
    """

    def __init__(self, input_resolution, dim, dim_scale=4, norm_layer=nn.LayerNorm):
        """
        Args:
            input_resolution: (H, W) of the incoming feature map, in tokens.
            dim: number of channels, both in and out.
            dim_scale: spatial upsampling factor per axis.
            norm_layer: normalisation layer over the channel dimension.
        """
        super().__init__()
        self.input_resolution = input_resolution
        self.dim = dim
        self.dim_scale = dim_scale
        # Expand to dim_scale**2 * dim channels, so splitting each token into a
        # dim_scale x dim_scale pixel block leaves exactly `dim` channels per
        # pixel. A fixed 16 * dim would only be correct for dim_scale == 4.
        self.expand = nn.Linear(dim, dim_scale ** 2 * dim, bias=False)
        self.output_dim = dim
        self.norm = norm_layer(self.output_dim)

    def forward(self, x):
        """
        Args:
            x: tokens of shape (B, H*W, dim) with (H, W) = self.input_resolution.

        Returns:
            Tokens of shape (B, H*W*dim_scale**2, dim).
        """
        H, W = self.input_resolution
        x = self.expand(x)
        B, L, C = x.shape
        assert L == H * W, "input feature has wrong size"
        s = self.dim_scale
        c = C // (s ** 2)
        # (B, H, W, p1, p2, c) -> (B, H, p1, W, p2, c) -> (B, H*s, W*s, c):
        # channel group (p1, p2) of token (h, w) becomes pixel (h*s + p1, w*s + p2)
        x = x.view(B, H, W, s, s, c).permute(0, 1, 3, 2, 4, 5).reshape(B, H * s, W * s, c)
        x = x.reshape(B, -1, self.output_dim)
        return self.norm(x)

回到SwinTransformerSys代码中

def _init_weights(self, m):
    """Initialise one submodule; meant to be passed to ``self.apply``.

    nn.Linear: weight ~ truncated normal (std 0.02), bias (if present) = 0.
    nn.LayerNorm: weight = 1, bias = 0. Other module types are left unchanged.
    """
    if isinstance(m, nn.Linear):
        trunc_normal_(m.weight, std=.02)
        if m.bias is not None:
            nn.init.constant_(m.bias, 0)
    elif isinstance(m, nn.LayerNorm):
        nn.init.constant_(m.bias, 0)
        nn.init.constant_(m.weight, 1.0)


@torch.jit.ignore
def no_weight_decay(self):
    """Return parameter names to exclude from weight decay."""
    return {'absolute_pos_embed'}


@torch.jit.ignore
def no_weight_decay_keywords(self):
    """Return parameter-name keywords to exclude from weight decay."""
    return {'relative_position_bias_table'}


# Encoder and bottleneck
def forward_features(self, x):
    """Run the encoder.

    Returns:
        (x, x_downsample): the normalised output of the last encoder stage, and
        a list whose i-th entry is the input of encoder stage i (the skip features).
    """
    x = self.patch_embed(x)  # split the image into patch tokens
    if self.ape:
        x = x + self.absolute_pos_embed
    x = self.pos_drop(x)
    x_downsample = []
    for layer in self.layers:
        x_downsample.append(x)  # skip feature = input of this stage
        x = layer(x)
    x = self.norm(x)  # B L C
    return x, x_downsample


# Decoder and skip connections
def forward_up_features(self, x, x_downsample):
    """Run the decoder, fusing the encoder skip features.

    Decoder stage 0 only upsamples. Every later stage ``inx`` concatenates the
    matching encoder feature along the channel axis, halves the channels with
    ``concat_back_dim[inx]``, and then runs ``layers_up[inx]``. Skip features
    are taken from the end of ``x_downsample`` backwards, so any number of
    stages works (a fixed ``3 - inx`` would only match 4 stages).
    """
    for inx, layer_up in enumerate(self.layers_up):
        if inx > 0:
            skip = x_downsample[len(x_downsample) - 1 - inx]
            x = torch.cat([x, skip], -1)
            x = self.concat_back_dim[inx](x)
        x = layer_up(x)
    x = self.norm_up(x)  # B L C
    return x


def up_x4(self, x):
    """Apply the final x4 expansion head after the last decoder stage.

    Args:
        x: decoder output tokens (B, H*W, C) with (H, W) = self.patches_resolution.

    Returns:
        (B, num_classes, 4H, 4W) maps from self.output when
        ``final_upsample == "expand_first"``; otherwise ``x`` unchanged.
    """
    H, W = self.patches_resolution
    B, L, C = x.shape
    assert L == H * W, "input features has wrong size"
    if self.final_upsample == "expand_first":
        x = self.up(x)
        x = x.view(B, 4 * H, 4 * W, -1)
        x = x.permute(0, 3, 1, 2)  # B,C,H,W
        x = self.output(x)
    return x


def forward(self, x):
    """Full forward pass: encoder -> decoder with skips -> x4 head."""
    x, x_downsample = self.forward_features(x)
    x = self.forward_up_features(x, x_downsample)
    x = self.up_x4(x)
    return x


def flops(self):
    """Return an approximate FLOP count.

    Counts patch embedding, the encoder stages, the final encoder norm, and a
    ``num_features * num_classes`` classifier-style term. Decoder stages and
    the segmentation head are NOT counted.
    """
    total = self.patch_embed.flops()
    for layer in self.layers:
        total += layer.flops()
    total += self.num_features * self.patches_resolution[0] * self.patches_resolution[1] // (2 ** self.num_layers)
    total += self.num_features * self.num_classes
    return total

接下来就是模型训练了。这里我舍弃了原框架的训练代码,因为它完全不符合我的训练要求(包括损失函数等)。

import copy
import os
import time

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
from PIL import Image
from torch import optim
from torch.utils.data import DataLoader, Dataset, random_split
from torch.utils.tensorboard import SummaryWriter
from torchvision import transforms

from swin_transformer_unet_skip_expand_decoder_sys import SwinTransformerSys

RUN_NAME = 'swinunetv1'
N_CLASSES = 3  # not used below; the model's output channels come from CLASSES
INPUT_SIZE = 128  # not used below
EPOCHS = 21
LEARNING_RATE = 0.002
START_FRAME = 16  # not used below
DROP_DATE = 0.5  # not used below
DATA_PATH = '/media/jingzhi/新加卷/'
IMAGE_PATH = 'data_dataset_voc/JPEGImagespng/'
MASK_PATH = 'data_dataset_voc/SegmentationClassPNG-new/'
TEST_IMAGE_PATH = 'test_dataset_voc/JPEGImagespng/'
TEST_MASK_PATH = 'test_dataset_voc/SegmentationClassPNG-new/'
IMAGE_TYPE = '.png'
MASK_TYPE = '.png'
LOG_PATH = './runs'
SAVE_PATH = './'
REAL_HEIGHT = 3000
REAL_WIDTH = 4096
IMG_HEIGHT = 224
IMG_WIDTH = 224
RANDOM_SEED = 42
VALID_RATIO = 0.2
BATCH_SIZE = 4
NUM_WORKERS = 1
CLASSES = {1: 'line'}  # foreground classes; the model gets one output channel per entry


def _list_ids(directory, suffix):
    """Return the sorted stems of files in *directory* whose extension equals *suffix*."""
    ids = []
    for name in sorted(os.listdir(directory)):
        stem, ext = os.path.splitext(name)  # handles names containing several dots
        if ext == suffix:
            ids.append(stem)
    return ids


def _binarize(mask):
    """Map every non-zero mask value to 1.0 and zero to 0.0.

    ToTensor divides 8-bit masks by 255, so a stored label of 1 would otherwise
    become 1/255 and act as an almost-zero BCE target.
    """
    return (mask > 0).float()


class LineDataset(Dataset):
    """Training set of (image, binary mask) pairs.

    Images are read from ``root_dir + IMAGE_PATH + <id> + IMAGE_TYPE`` and masks
    from ``root_dir + MASK_PATH + <id> + MASK_TYPE``. Ids are the stems of the
    image files with extension IMAGE_TYPE.
    """

    def __init__(self, root_dir=DATA_PATH, transform=None):
        self.root_dir = root_dir
        # list under root_dir so listing and loading use the same location
        self.ids = _list_ids(root_dir + IMAGE_PATH, IMAGE_TYPE)
        if transform is None:
            # image: resize + colour-jitter augmentation
            self.transform1 = transforms.Compose(
                [transforms.Resize((IMG_HEIGHT, IMG_WIDTH), interpolation=transforms.InterpolationMode.NEAREST),
                 transforms.ColorJitter(brightness=0.5, contrast=0.5, saturation=0.5, hue=0.5),
                 transforms.ToTensor()])
            # mask: resize only; NEAREST keeps label values intact
            self.transform2 = transforms.Compose(
                [transforms.Resize((IMG_HEIGHT, IMG_WIDTH), interpolation=transforms.InterpolationMode.NEAREST),
                 transforms.ToTensor()])
        else:
            # a caller-supplied transform is applied to both image and mask
            # (random transforms are sampled independently for each call)
            self.transform1 = transform
            self.transform2 = transform

    def __len__(self):
        return len(self.ids)

    def __getitem__(self, index):
        sample_id = self.ids[index]
        with Image.open(self.root_dir + IMAGE_PATH + sample_id + IMAGE_TYPE) as img:
            # the model expects 3 input channels; convert() is a no-op for RGB files
            image = self.transform1(img.convert('RGB'))
        with Image.open(self.root_dir + MASK_PATH + sample_id + MASK_TYPE) as msk:
            # assumes a single-channel mask (e.g. 'L' or 'P' mode) -- TODO confirm
            mask = _binarize(self.transform2(msk))
        return image, mask


def get_trainloader(dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS):
    """Wrap *dataset* in a DataLoader (shuffled by default)."""
    return DataLoader(dataset, batch_size, shuffle=shuffle, num_workers=num_workers)


def get_dataloader(dataset, batch_size=BATCH_SIZE, random_seed=RANDOM_SEED,
                   valid_ratio=VALID_RATIO, shuffle=True, num_workers=NUM_WORKERS):
    """Split *dataset* into train/validation subsets and return a loader for each.

    The split is reproducible for a given *random_seed*. A dedicated generator
    is used, so the global torch RNG state is left untouched.
    """
    error_msg = "[!] valid_ratio should be in the range [0, 1]."
    assert 0 <= valid_ratio <= 1, error_msg
    n = len(dataset)
    n_valid = int(valid_ratio * n)
    n_train = n - n_valid
    generator = torch.Generator().manual_seed(random_seed)
    train_dataset, valid_dataset = random_split(dataset, (n_train, n_valid), generator=generator)
    train_loader = DataLoader(train_dataset, batch_size, shuffle=shuffle, num_workers=num_workers)
    valid_loader = DataLoader(valid_dataset, batch_size, shuffle=False, num_workers=num_workers)
    return train_loader, valid_loader


def show_dataset(dataset, n_sample=4):
    """Plot the first *n_sample* images with their masks overlaid."""
    plt.figure(figsize=(30, 15))
    for i in range(n_sample):
        image, mask = dataset[i]
        image = transforms.ToPILImage()(image)
        mask = transforms.ToPILImage()(mask)
        print(i, image.size, mask.size)
        plt.tight_layout()
        ax = plt.subplot(n_sample, 1, i + 1)
        ax.set_title('Sample #{}'.format(i))
        ax.axis('off')
        plt.imshow(image, cmap="Greys")
        plt.imshow(mask, alpha=0.3, cmap="OrRd")
        if i == n_sample - 1:
            plt.show()


class Test_LineDataset(Dataset):
    """Test set of (image, binary mask) pairs.

    Ids are the stems of the mask files (extension MASK_TYPE) in
    ``root_dir + TEST_MASK_PATH``. Images come from ``root_dir + TEST_IMAGE_PATH``.
    """

    def __init__(self, root_dir=DATA_PATH, transform=None):
        self.root_dir = root_dir
        self.ids = _list_ids(root_dir + TEST_MASK_PATH, MASK_TYPE)
        if transform is None:
            self.transform = transforms.Compose(
                [transforms.Resize((IMG_HEIGHT, IMG_WIDTH), interpolation=transforms.InterpolationMode.NEAREST),
                 transforms.ToTensor()])
        else:
            self.transform = transform

    def __len__(self):
        return len(self.ids)

    def __getitem__(self, index):
        sample_id = self.ids[index]
        with Image.open(self.root_dir + TEST_IMAGE_PATH + sample_id + IMAGE_TYPE) as img:
            image = self.transform(img.convert('RGB'))
        with Image.open(self.root_dir + TEST_MASK_PATH + sample_id + MASK_TYPE) as msk:
            mask = _binarize(self.transform(msk))
        return image, mask


def get_validloader(dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS):
    """Wrap *dataset* in a DataLoader (not shuffled by default)."""
    return DataLoader(dataset, batch_size, shuffle=shuffle, num_workers=num_workers)


def show_test_dataset(dataset, n_sample=2):
    """Plot the first *n_sample* test images side by side."""
    plt.figure(figsize=(30, 15))
    for i in range(n_sample):
        image = transforms.ToPILImage()(dataset[i][0])
        print(i, image.size)
        plt.tight_layout()
        ax = plt.subplot(1, n_sample, i + 1)
        ax.set_title('Sample #{}'.format(i))
        ax.axis('off')
        plt.imshow(image, cmap="Greys")
        if i == n_sample - 1:
            plt.show()


def labels():
    """Return ``{position: key}`` for the keys of CLASSES (currently {0: 1})."""
    return dict(enumerate(CLASSES))


def tensor2np(tensor):
    """Squeeze *tensor*, move it to the CPU and return it as a NumPy array."""
    tensor = tensor.squeeze().cpu()
    return tensor.detach().numpy()


def normtensor(tensor):
    """Threshold logits at 0: values < 0 become 0, all others 1 (same device/dtype as input)."""
    return torch.where(tensor < 0., torch.zeros_like(tensor), torch.ones_like(tensor))


def count_params(model):
    """Return the total number of parameters of *model*."""
    return sum(p.numel() for p in model.parameters())


def cal_iou(outputs, labels, SMOOTH=1e-6):
    """Per-sample IoU of two (B, 1, H, W) masks interpreted as booleans."""
    with torch.no_grad():
        outputs = outputs.squeeze(1).bool()
        labels = labels.squeeze(1).bool()
        intersection = (outputs & labels).float().sum((1, 2))
        union = (outputs | labels).float().sum((1, 2))
        iou = (intersection + SMOOTH) / (union + SMOOTH)
    return iou


def get_iou_score(outputs, labels):
    """Per-sample IoU (NumPy array) between thresholded logits and a target mask.

    A pixel is predicted positive unless its logit is < 0.
    """
    target = labels.squeeze(1).bool()
    pred = (~(outputs < 0.)).squeeze(1)  # device-agnostic; NaN counts as positive
    intersection = (target & pred).float().sum((1, 2))
    union = (target | pred).float().sum((1, 2))
    iou = (intersection + 1e-6) / (union + 1e-6)
    return iou.cpu().detach().numpy()


def train(model, device, trainloader, optimizer, loss_function, epoch):
    """Train *model* for one epoch; return (mean loss, mean IoU).

    Logs the mean loss to the module-level ``writer`` (created in __main__).
    """
    model.train()
    model.is_train = True
    running_loss = 0.0
    iou = []
    for i, (inputs, mask) in enumerate(trainloader):
        inputs, mask = inputs.to(device), mask.to(device)
        predict = model(inputs)
        loss = loss_function(predict, mask)
        iou.append(get_iou_score(predict, mask).mean())
        running_loss += loss.item()
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        if (i + 1) % 10 == 0:
            print(f'Epoch: {epoch} | Item: {i} | Train loss: {loss.item():.5f}')
    mean_iou = np.mean(iou)
    total_loss = running_loss / len(trainloader)
    writer.add_scalar('training loss', total_loss, epoch)
    return total_loss, mean_iou


def test(model, device, testloader, loss_function, best_iou, epoch):
    """Evaluate *model*; save its weights if the mean IoU beats *best_iou*.

    Returns (mean loss, mean IoU) and logs the loss to the module-level ``writer``.
    """
    model.eval()
    model.is_train = False
    running_loss = 0.0
    iou = []
    with torch.no_grad():
        for i, (inputs, mask) in enumerate(testloader):
            inputs, mask = inputs.to(device), mask.to(device)
            predict = model(inputs)
            loss = loss_function(predict, mask)
            running_loss += loss.item()
            iou.append(get_iou_score(predict, mask).mean())
            print(f'Epoch: {epoch} | Item: {i} | test loss: {loss.item():.5f}')
    test_loss = running_loss / len(testloader)
    mean_iou = np.mean(iou)
    writer.add_scalar('val loss', test_loss, epoch)
    if mean_iou > best_iou:
        try:
            torch.save(model.state_dict(), SAVE_PATH + RUN_NAME + '.pth')
        except (OSError, RuntimeError) as err:
            print(f"Can't export weights: {err}")
    return test_loss, mean_iou


def model_pipeline(prev_model=None):
    """Train for EPOCHS epochs and return a deep copy of the best-IoU model.

    Uses the module-level ``device``, ``trainloader`` and ``validloader`` (set in __main__).
    """
    best_model = None
    model, criterion, optimizer = make_model(prev_model)
    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1)
    best_iou = -1
    for epoch in range(EPOCHS):
        t0 = time.time()
        train_loss, train_iou = train(model, device, trainloader, optimizer, criterion, epoch)
        t1 = time.time()
        print(f'Epoch: {epoch} | Train loss: {train_loss:.5f} | Train IoU: {train_iou:.3f} | Time: '
              f'{(t1 - t0):.1f}s')
        t0 = time.time()
        test_loss, test_iou = test(model, device, validloader, criterion, best_iou, epoch)
        t1 = time.time()
        print(f'Epoch: {epoch} | Valid loss: {test_loss:.5f} | Valid IoU: {test_iou:.3f} | Time: '
              f'{(t1 - t0):.1f}s')
        scheduler.step()
        if best_iou < test_iou:
            best_iou = test_iou
            best_model = copy.deepcopy(model)
    return best_model


def make_model(prev_model=None):
    """Build (or reuse) the model on ``device`` and return (model, criterion, optimizer)."""
    if prev_model is None:
        # one logit channel per foreground class, matching the single-channel
        # masks that BCEWithLogitsLoss and get_iou_score expect
        model = SwinTransformerSys(num_classes=len(CLASSES))
    else:
        model = prev_model
    model = model.to(device)
    print("Number of parameter:", count_params(model))
    criterion = nn.BCEWithLogitsLoss()
    optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)
    return model, criterion, optimizer


def predict(model, test_loader, device):
    """Return binary masks (float 0/1) for every sample, resized to REAL_HEIGHT x REAL_WIDTH."""
    model.eval()
    predicted_masks = []
    back_transform = transforms.Compose([transforms.Resize((REAL_HEIGHT, REAL_WIDTH))])
    with torch.no_grad():
        for inputs, _ in test_loader:
            logits = back_transform(model(inputs.to(device)))
            predicted_masks.append((logits > 0).type(torch.float))
    return torch.cat(predicted_masks)


def show_sample_test_result(test_dataset, predicted_mask, n_samples=60):
    """Overlay predicted masks on the test images, four samples per 2x2 figure."""
    plt.rcParams['figure.figsize'] = (30, 15)
    back_transform = transforms.Compose([transforms.Resize((REAL_HEIGHT, REAL_WIDTH))])
    n_samples = min(n_samples, len(predicted_mask), len(test_dataset))
    for i in range(n_samples):
        sample = transforms.ToPILImage()(torch.squeeze(predicted_mask[i], dim=0))
        X = transforms.ToPILImage()(back_transform(test_dataset[i][0]))
        index = i % 4 + 1  # subplot slot 1..4
        ax = plt.subplot(2, 2, index)
        ax.set_title('Sample #{}'.format(i))
        ax.axis('off')
        plt.imshow(X, cmap="Greys")
        plt.imshow(sample, alpha=0.7, cmap="winter")
        # show when the 2x2 grid is full (or at the end) so figures never mix old panels
        if index == 4 or i == n_samples - 1:
            plt.show()


if __name__ == '__main__':
    writer = SummaryWriter(LOG_PATH)
    dataset = LineDataset(DATA_PATH)
    valid_dataset = Test_LineDataset(DATA_PATH)
    trainloader = get_trainloader(dataset=dataset)
    validloader = get_validloader(dataset=valid_dataset)
    # show_dataset(dataset)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # model.load_state_dict(torch.load(SAVE_PATH + RUN_NAME + '.pth'))
    try:
        model = model_pipeline()
    finally:
        writer.close()
    # predict_mask = predict(model, validloader, device)
    # show_sample_test_result(valid_dataset, predict_mask)
原网站

版权声明
本文为[算法之名]所创,转载请带上原文链接,感谢
https://my.oschina.net/u/3768341/blog/5535941