Swin-Unet: The Strongest Segmentation Network
2022-06-10 08:05:00 【算法之名】
Swin-Unet is a new segmentation network built on the Swin Transformer (see the earlier Swin Transformer introduction) and combined with the characteristics of the U-Net network (see the U-Net section in Tensorflow深度学习算法整理(三)).

It differs from the Swin Transformer in the following ways. On the encoder side it has the same 4 stages as the Swin Transformer, but the number of Swin Transformer Blocks per stage defaults to [2, 2, 2, 2] (with [1, 2, 2, 2] declared for the decoder, as the code below shows) rather than Swin Transformer's [2, 2, 6, 2]. On the decoder side, since it performs upsampling, it no longer uses Patch Embedding and Patch Merging; instead it uses Patch Expanding, which is the inverse of Patch Merging.
Let's look at the code implementation of Patch Expanding.
```python
from einops import rearrange
import torch.nn as nn


class PatchExpand(nn.Module):
    """Patch expansion: doubles the spatial size and halves the channel count."""

    def __init__(self, input_resolution, dim, dim_scale=2, norm_layer=nn.LayerNorm):
        """
        Args:
            input_resolution: height/width of the decoder feature map
            dim: number of feature-map channels
            dim_scale: channel expansion factor
            norm_layer: normalization along the channel dimension
        """
        super().__init__()
        self.input_resolution = input_resolution
        self.dim = dim
        # expand the channel count through a fully connected layer
        self.expand = nn.Linear(dim, 2 * dim, bias=False) if dim_scale == 2 else nn.Identity()
        self.norm = norm_layer(dim // dim_scale)

    def forward(self, x):
        """
        x: B, H*W, C
        """
        H, W = self.input_resolution
        # first double the channel count
        x = self.expand(x)
        B, L, C = x.shape
        assert L == H * W, "input feature has wrong size"

        x = x.view(B, H, W, C)
        # split the channels apart and tile them into one feature map,
        # enlarging the spatial size
        x = rearrange(x, 'b h w (p1 p2 c) -> b (h p1) (w p2) c', p1=2, p2=2, c=C // 4)
        # doubling the channels and then dividing by 4 effectively halves them
        x = x.view(B, -1, C // 4)
        x = self.norm(x)

        return x
```
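As a quick sanity check of the shapes, here is a minimal sketch (the sizes are illustrative, matching the bottleneck of a Swin-Tiny encoder):

```python
# Minimal shape check for PatchExpand (illustrative sizes).
import torch

expand = PatchExpand(input_resolution=(7, 7), dim=768, dim_scale=2)
x = torch.randn(1, 7 * 7, 768)  # (B, H*W, C)
y = expand(x)
print(y.shape)  # torch.Size([1, 196, 384]): (B, 2H*2W, C/2)
```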
The encoder side is essentially the same as the Swin Transformer, so we will focus on the decoder side. It uses the BasicLayer_up class to pair SwinTransformerBlock with Patch Expanding.
```python
import torch.utils.checkpoint as checkpoint


class BasicLayer_up(nn.Module):
    """A basic Swin Transformer layer for one decoder stage.

    A BasicLayer_up contains an even number of SwinTransformerBlocks and one
    upsample layer (the PatchExpand layer).
    """

    def __init__(self, dim, input_resolution, depth, num_heads, window_size,
                 mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0.,
                 drop_path=0., norm_layer=nn.LayerNorm, upsample=None, use_checkpoint=False):
        """
        Args:
            dim: number of feature-map channels
            input_resolution: height/width of the feature map
            depth: number of Swin Transformer Blocks in this stage
            num_heads: number of attention heads in this stage
            window_size: number of patches per window in window self-attention
            mlp_ratio: channel expansion factor of the first fully connected layer in the MLP module
            qkv_bias: if True, add a learnable bias to Q, K and V in the self-attention formula
            qk_scale: constant in the window self-attention formula
            drop: dropout rate, 0 by default
            attn_drop: dropout rate inside self-attention, 0 by default
            drop_path: probability of dropping an entire residual branch of a block
                (LN plus W-MSA or SW-MSA), keeping only the shortcut; a form of
                stochastic depth, 0 by default
            norm_layer: normalization along the channel dimension
            upsample: PatchExpand used for upsampling
            use_checkpoint: whether to use PyTorch gradient checkpointing
        """
        super().__init__()
        self.dim = dim
        self.input_resolution = input_resolution
        self.depth = depth
        self.use_checkpoint = use_checkpoint

        # build the SwinTransformerBlocks
        self.blocks = nn.ModuleList([
            SwinTransformerBlock(dim=dim, input_resolution=input_resolution,
                                 num_heads=num_heads, window_size=window_size,
                                 # selects W-MSA vs. SW-MSA: shift_size 0 means W-MSA,
                                 # window_size // 2 means SW-MSA
                                 shift_size=0 if (i % 2 == 0) else window_size // 2,
                                 mlp_ratio=mlp_ratio,
                                 qkv_bias=qkv_bias, qk_scale=qk_scale,
                                 drop=drop, attn_drop=attn_drop,
                                 drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
                                 norm_layer=norm_layer)
            for i in range(depth)])

        # patch expanding layer; None in the last decoder stage
        if upsample is not None:
            self.upsample = PatchExpand(input_resolution, dim=dim, dim_scale=2, norm_layer=norm_layer)
        else:
            self.upsample = None

    def forward(self, x):
        # pass through each SwinTransformerBlock
        for blk in self.blocks:
            if self.use_checkpoint:
                x = checkpoint.checkpoint(blk, x)
            else:
                x = blk(x)
        # upsample via PatchExpand
        if self.upsample is not None:
            x = self.upsample(x)
        return x
```
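As a rough usage sketch (assuming BasicLayer_up and PatchExpand are imported from the repo file swin_transformer_unet_skip_expand_decoder_sys.py that the training script below also uses), one decoder stage maps a 14×14, 384-channel feature map to 28×28 with 192 channels:

```python
# Sketch: one decoder stage with Swin-Tiny numbers (14x14 map, 384 channels).
import torch
from swin_transformer_unet_skip_expand_decoder_sys import BasicLayer_up, PatchExpand

layer_up = BasicLayer_up(dim=384, input_resolution=(14, 14), depth=2,
                         num_heads=12, window_size=7, upsample=PatchExpand)
x = torch.randn(1, 14 * 14, 384)
y = layer_up(x)
print(y.shape)  # torch.Size([1, 784, 192]): spatial size doubled, channels halved
```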
SwinTransformerBlock is identical to its counterpart in the Swin Transformer code, so it is not repeated here.
There are also skip connections from the encoder to the decoder. For those we need to look at the main class of Swin-Unet.
```python
class SwinTransformerSys(nn.Module):
    """The Swin-Unet network model."""

    def __init__(self, img_size=224, patch_size=4, in_chans=3, num_classes=1000,
                 embed_dim=96, depths=[2, 2, 2, 2], depths_decoder=[1, 2, 2, 2],
                 num_heads=[3, 6, 12, 24], window_size=7, mlp_ratio=4.,
                 qkv_bias=True, qk_scale=None, drop_rate=0., attn_drop_rate=0.,
                 drop_path_rate=0.1, norm_layer=nn.LayerNorm, ape=False,
                 patch_norm=True, use_checkpoint=False,
                 final_upsample="expand_first", **kwargs):
        """
        Args:
            img_size: original image size
            patch_size: number of pixels per patch
            in_chans: number of channels of the input image
            num_classes: number of classes
            embed_dim: number of feature-map channels
            depths: number of Swin Transformer Blocks per encoder stage
            depths_decoder: number of Swin Transformer Blocks per decoder stage
            num_heads: number of attention heads per stage
            window_size: number of patches per window in window self-attention
            mlp_ratio: channel expansion factor of the first fully connected layer in the MLP module
            qkv_bias: if True, add a learnable bias to Q, K and V in the self-attention formula
            qk_scale: constant in the self-attention formula
            drop_rate: dropout rate, 0 by default
            attn_drop_rate: dropout rate inside self-attention, 0 by default
            drop_path_rate: probability of dropping an entire residual branch of a block
                (LN plus W-MSA or SW-MSA), keeping only the shortcut; a form of
                stochastic depth, 0.1 by default
            norm_layer: normalization along the channel dimension
            ape: whether to use absolute position embedding, False by default
            patch_norm: if True, apply normalization after patch embedding
            use_checkpoint: whether to use PyTorch gradient checkpointing
            final_upsample: the Patch Expanding after decoder stage 4
            **kwargs:
        """
        super().__init__()

        print("SwinTransformerSys expand initial----depths:{};depths_decoder:{};drop_path_rate:{};num_classes:{}".format(
            depths, depths_decoder, drop_path_rate, num_classes))

        self.num_classes = num_classes
        # number of stages
        self.num_layers = len(depths)
        self.embed_dim = embed_dim
        self.ape = ape
        self.patch_norm = patch_norm
        # number of channels of the encoder stage-4 output (Swin-Tiny: 768)
        self.num_features = int(embed_dim * 2 ** (self.num_layers - 1))
        # number of channels of the decoder stage-4 output (192)
        self.num_features_up = int(embed_dim * 2)
        self.mlp_ratio = mlp_ratio
        self.final_upsample = final_upsample

        # split the image into non-overlapping patches
        self.patch_embed = PatchEmbed(
            img_size=img_size, patch_size=patch_size, in_chans=in_chans,
            embed_dim=embed_dim, norm_layer=norm_layer if self.patch_norm else None)
        num_patches = self.patch_embed.num_patches
        # height/width of the feature map
        patches_resolution = self.patch_embed.patches_resolution
        self.patches_resolution = patches_resolution

        # absolute position embedding
        if self.ape:
            self.absolute_pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim))
            trunc_normal_(self.absolute_pos_embed, std=.02)

        self.pos_drop = nn.Dropout(p=drop_rate)

        # each stage drops whole residual branches with a different probability,
        # increasing from 0 up to drop_path_rate (0.1)
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]  # stochastic depth decay rule

        # build the encoder layers
        self.layers = nn.ModuleList()
        for i_layer in range(self.num_layers):
            # each layer corresponds to one stage
            layer = BasicLayer(dim=int(embed_dim * 2 ** i_layer),
                               input_resolution=(patches_resolution[0] // (2 ** i_layer),
                                                 patches_resolution[1] // (2 ** i_layer)),
                               depth=depths[i_layer],
                               num_heads=num_heads[i_layer],
                               window_size=window_size,
                               mlp_ratio=self.mlp_ratio,
                               qkv_bias=qkv_bias, qk_scale=qk_scale,
                               drop=drop_rate, attn_drop=attn_drop_rate,
                               drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])],
                               norm_layer=norm_layer,
                               # only the first 3 stages have PatchMerging, the last one does not
                               downsample=PatchMerging if (i_layer < self.num_layers - 1) else None,
                               use_checkpoint=use_checkpoint)
            self.layers.append(layer)

        # build the decoder layers
        self.layers_up = nn.ModuleList()
        self.concat_back_dim = nn.ModuleList()
        for i_layer in range(self.num_layers):
            # each layer corresponds to one stage; after each stage a fully
            # connected layer halves the channels of the skip concatenation
            concat_linear = nn.Linear(2 * int(embed_dim * 2 ** (self.num_layers - 1 - i_layer)),
                                      int(embed_dim * 2 ** (self.num_layers - 1 - i_layer))) \
                if i_layer > 0 else nn.Identity()
            if i_layer == 0:
                # the first decoder stage only upsamples
                layer_up = PatchExpand(
                    input_resolution=(patches_resolution[0] // (2 ** (self.num_layers - 1 - i_layer)),
                                      patches_resolution[1] // (2 ** (self.num_layers - 1 - i_layer))),
                    dim=int(embed_dim * 2 ** (self.num_layers - 1 - i_layer)),
                    dim_scale=2, norm_layer=norm_layer)
            else:
                layer_up = BasicLayer_up(
                    dim=int(embed_dim * 2 ** (self.num_layers - 1 - i_layer)),
                    input_resolution=(patches_resolution[0] // (2 ** (self.num_layers - 1 - i_layer)),
                                      patches_resolution[1] // (2 ** (self.num_layers - 1 - i_layer))),
                    depth=depths[(self.num_layers - 1 - i_layer)],
                    num_heads=num_heads[(self.num_layers - 1 - i_layer)],
                    window_size=window_size,
                    mlp_ratio=self.mlp_ratio,
                    qkv_bias=qkv_bias, qk_scale=qk_scale,
                    drop=drop_rate, attn_drop=attn_drop_rate,
                    drop_path=dpr[sum(depths[:(self.num_layers - 1 - i_layer)]):
                                  sum(depths[:(self.num_layers - 1 - i_layer) + 1])],
                    norm_layer=norm_layer,
                    # only the first 3 decoder stages have PatchExpand, the last one does not
                    upsample=PatchExpand if (i_layer < self.num_layers - 1) else None,
                    use_checkpoint=use_checkpoint)
            self.layers_up.append(layer_up)
            self.concat_back_dim.append(concat_linear)

        self.norm = norm_layer(self.num_features)
        self.norm_up = norm_layer(self.embed_dim)

        # after the last decoder stage, apply FinalPatchExpand_X4
        if self.final_upsample == "expand_first":
            print("---final upsample expand_first---")
            self.up = FinalPatchExpand_X4(input_resolution=(img_size // patch_size, img_size // patch_size),
                                          dim_scale=4, dim=embed_dim)
            self.output = nn.Conv2d(in_channels=embed_dim, out_channels=self.num_classes,
                                    kernel_size=1, bias=False)

        self.apply(self._init_weights)
```
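Before moving on, it helps to see concrete numbers. With the defaults (img_size=224, patch_size=4, embed_dim=96), a quick sanity computation of the per-stage geometry (a sketch, not part of the model code):

```python
# Stage geometry under the default config (illustrative arithmetic only).
img_size, patch_size, embed_dim, num_layers = 224, 4, 96, 4
res = img_size // patch_size  # 56 patches per side after patch embedding
for i in range(num_layers):
    print(f"encoder stage {i + 1}: {res // 2 ** i}x{res // 2 ** i}, dim {embed_dim * 2 ** i}")
# encoder stage 1: 56x56, dim 96
# encoder stage 2: 28x28, dim 192
# encoder stage 3: 14x14, dim 384
# encoder stage 4: 7x7, dim 768   (the decoder mirrors this back up to 56x56, dim 96)
```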
The code above references a FinalPatchExpand_X4 class; let's look at its implementation.
```python
class FinalPatchExpand_X4(nn.Module):
    """The PatchExpand after stage 4: spatial size x4 per side, channel count unchanged."""

    def __init__(self, input_resolution, dim, dim_scale=4, norm_layer=nn.LayerNorm):
        """
        Args:
            input_resolution: height/width of the feature map
            dim: number of feature-map channels
            dim_scale: spatial expansion factor per side
            norm_layer: normalization along the channel dimension
        """
        super().__init__()
        self.input_resolution = input_resolution
        self.dim = dim
        self.dim_scale = dim_scale
        # expand the channel count through a fully connected layer
        self.expand = nn.Linear(dim, 16 * dim, bias=False)
        self.output_dim = dim
        self.norm = norm_layer(self.output_dim)

    def forward(self, x):
        """
        x: B, H*W, C
        """
        H, W = self.input_resolution
        # first expand the channel count 16x
        x = self.expand(x)
        B, L, C = x.shape
        assert L == H * W, "input feature has wrong size"

        x = x.view(B, H, W, C)
        # split the channels apart and tile them into one feature map,
        # enlarging the spatial size 4x per side
        x = rearrange(x, 'b h w (p1 p2 c) -> b (h p1) (w p2) c',
                      p1=self.dim_scale, p2=self.dim_scale, c=C // (self.dim_scale ** 2))
        # bring the expanded channels back to the original channel count
        x = x.view(B, -1, self.output_dim)
        x = self.norm(x)

        return x
```
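Again a quick shape check (a sketch with Swin-Tiny numbers: the decoder ends at 56×56 with 96 channels):

```python
# Shape check for FinalPatchExpand_X4 (illustrative sizes).
import torch

up = FinalPatchExpand_X4(input_resolution=(56, 56), dim=96, dim_scale=4)
x = torch.randn(1, 56 * 56, 96)
y = up(x)
print(y.shape)  # torch.Size([1, 50176, 96]): (B, 224*224, C), size x4 per side, channels kept
```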
Back to the SwinTransformerSys code.
```python
    def _init_weights(self, m):
        """Initialize the weights and biases of fully connected and LayerNorm layers."""
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    @torch.jit.ignore
    def no_weight_decay(self):
        return {'absolute_pos_embed'}

    @torch.jit.ignore
    def no_weight_decay_keywords(self):
        return {'relative_position_bias_table'}

    # Encoder and Bottleneck
    def forward_features(self, x):
        """The encoder pass."""
        # split the image into patches
        x = self.patch_embed(x)
        # absolute position embedding
        if self.ape:
            x = x + self.absolute_pos_embed
        x = self.pos_drop(x)
        # skip-connection features
        x_downsample = []
        # pass through each encoder stage
        for layer in self.layers:
            x_downsample.append(x)
            x = layer(x)

        x = self.norm(x)  # B L C

        return x, x_downsample

    # Decoder and Skip connection
    def forward_up_features(self, x, x_downsample):
        """The decoder pass, including the skip-connection concatenation."""
        # pass through each decoder stage
        for inx, layer_up in enumerate(self.layers_up):
            if inx == 0:
                x = layer_up(x)
            else:
                # concatenate the encoder skip features before entering the Swin Transformer Blocks
                x = torch.cat([x, x_downsample[3 - inx]], -1)
                x = self.concat_back_dim[inx](x)
                x = layer_up(x)

        x = self.norm_up(x)  # B L C

        return x

    def up_x4(self, x):
        """Entered after the last decoder stage completes."""
        H, W = self.patches_resolution
        B, L, C = x.shape
        assert L == H * W, "input features has wrong size"

        if self.final_upsample == "expand_first":
            x = self.up(x)
            x = x.view(B, 4 * H, 4 * W, -1)
            x = x.permute(0, 3, 1, 2)  # B, C, H, W
            x = self.output(x)

        return x

    def forward(self, x):
        """The forward pass."""
        x, x_downsample = self.forward_features(x)
        x = self.forward_up_features(x, x_downsample)
        x = self.up_x4(x)

        return x

    def flops(self):
        flops = 0
        flops += self.patch_embed.flops()
        for i, layer in enumerate(self.layers):
            flops += layer.flops()
        flops += self.num_features * self.patches_resolution[0] * self.patches_resolution[1] // (2 ** self.num_layers)
        flops += self.num_features * self.num_classes
        return flops
```
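Putting the pieces together, a minimal end-to-end shape check might look like this (a sketch: num_classes=2 is an illustrative value, and the import assumes the repo file used by the training script below):

```python
# End-to-end shape check (sketch; num_classes=2 is illustrative).
import torch
from swin_transformer_unet_skip_expand_decoder_sys import SwinTransformerSys

model = SwinTransformerSys(img_size=224, patch_size=4, in_chans=3, num_classes=2)
x = torch.randn(1, 3, 224, 224)
with torch.no_grad():
    logits = model(x)
print(logits.shape)  # torch.Size([1, 2, 224, 224]): per-pixel class logits at input resolution
```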
Next comes model training. Here I discarded the original framework's training code, since it did not meet my training requirements at all, including the loss function.
```python
import torch
from torch import optim
import torch.nn as nn
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader, random_split
from PIL import Image
import numpy as np
import matplotlib.pyplot as plt
import os
import time
import copy
from torch.utils.tensorboard import SummaryWriter
from swin_transformer_unet_skip_expand_decoder_sys import SwinTransformerSys

# hyperparameters and data paths
RUN_NAME = 'swinunetv1'
N_CLASSES = 3
INPUT_SIZE = 128
EPOCHS = 21
LEARNING_RATE = 0.002
START_FRAME = 16
DROP_RATE = 0.5
DATA_PATH = '/media/jingzhi/新加卷/'
IMAGE_PATH = 'data_dataset_voc/JPEGImagespng/'
MASK_PATH = 'data_dataset_voc/SegmentationClassPNG-new/'
TEST_IMAGE_PATH = 'test_dataset_voc/JPEGImagespng/'
TEST_MASK_PATH = 'test_dataset_voc/SegmentationClassPNG-new/'
IMAGE_TYPE = '.png'
MASK_TYPE = '.png'
LOG_PATH = './runs'
SAVE_PATH = './'
REAL_HEIGHT = 3000
REAL_WIDTH = 4096
IMG_HEIGHT = 224
IMG_WIDTH = 224
RANDOM_SEED = 42
VALID_RATIO = 0.2
BATCH_SIZE = 4
NUM_WORKERS = 1
CLASSES = {1: 'line'}


class LineDataset(Dataset):
    """Training dataset: image/mask pairs resized to the network input size."""

    def __init__(self, root_dir=DATA_PATH, transform=None):
        self.root_dir = root_dir
        listname = []
        for imgfile in os.listdir(DATA_PATH + IMAGE_PATH):
            parts = imgfile.split('.')
            l = len(parts)
            if '.' + parts[l - 1] == IMAGE_TYPE:
                if l > 2:
                    filename = parts[0] + '.' + parts[1]
                else:
                    filename = parts[0]
                listname.append(filename)
        self.ids = listname
        if transform is None:
            self.transform1 = transforms.Compose(
                [transforms.Resize((IMG_HEIGHT, IMG_WIDTH), interpolation=transforms.InterpolationMode.NEAREST),
                 transforms.ColorJitter(brightness=0.5, contrast=0.5, saturation=0.5, hue=0.5),
                 transforms.ToTensor()])
            self.transform2 = transforms.Compose(
                [transforms.Resize((IMG_HEIGHT, IMG_WIDTH), interpolation=transforms.InterpolationMode.NEAREST),
                 transforms.ToTensor()])
            # transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))])

    def __len__(self):
        return len(self.ids)

    def __getitem__(self, index):
        id = self.ids[index]
        image = Image.open(self.root_dir + IMAGE_PATH + id + IMAGE_TYPE)
        mask = Image.open(self.root_dir + MASK_PATH + id + MASK_TYPE)
        image = self.transform1(image)
        mask = self.transform2(mask)
        return image, mask


def get_trainloader(dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS):
    train_loader = DataLoader(dataset, batch_size, shuffle=shuffle, num_workers=num_workers)
    return train_loader


def get_dataloader(dataset, batch_size=BATCH_SIZE, random_seed=RANDOM_SEED, valid_ratio=VALID_RATIO,
                   shuffle=True, num_workers=NUM_WORKERS):
    error_msg = "[!] valid_ratio should be in the range [0, 1]."
    assert ((valid_ratio >= 0) and (valid_ratio <= 1)), error_msg

    n = len(dataset)
    n_valid = int(valid_ratio * n)
    n_train = n - n_valid
    torch.manual_seed(random_seed)
    train_dataset, valid_dataset = random_split(dataset, (n_train, n_valid))
    train_loader = DataLoader(train_dataset, batch_size, shuffle=shuffle, num_workers=num_workers)
    valid_loader = DataLoader(valid_dataset, batch_size, shuffle=False, num_workers=num_workers)
    return train_loader, valid_loader


def show_dataset(dataset, n_sample=4):
    plt.figure(figsize=(30, 15))
    for i in range(n_sample):
        image, mask = dataset[i]
        image = transforms.ToPILImage()(image)
        mask = transforms.ToPILImage()(mask)
        print(i, image.size, mask.size)
        plt.tight_layout()
        ax = plt.subplot(n_sample, 1, i + 1)
        ax.set_title('Sample #{}'.format(i))
        ax.axis('off')
        plt.imshow(image, cmap="Greys")
        plt.imshow(mask, alpha=0.3, cmap="OrRd")
        if i == n_sample - 1:
            plt.show()
            break


class Test_LineDataset(Dataset):
    """Test dataset: image/mask pairs, no color augmentation."""

    def __init__(self, root_dir=DATA_PATH, transform=None):
        self.root_dir = root_dir
        listname = []
        for imgfile in os.listdir(DATA_PATH + TEST_MASK_PATH):
            if '.' + imgfile.split('.')[1] == MASK_TYPE:
                filename = imgfile.split('.')[0]
                listname.append(filename)
        self.ids = listname
        if transform is None:
            self.transform = transforms.Compose(
                [transforms.Resize((IMG_HEIGHT, IMG_WIDTH), interpolation=transforms.InterpolationMode.NEAREST),
                 transforms.ToTensor()])
            # transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))])

    def __len__(self):
        return len(self.ids)

    def __getitem__(self, index):
        id = self.ids[index]
        image = Image.open(self.root_dir + TEST_IMAGE_PATH + id + IMAGE_TYPE)
        mask = Image.open(self.root_dir + TEST_MASK_PATH + id + MASK_TYPE)
        image = self.transform(image)
        mask = self.transform(mask)
        return image, mask


def get_validloader(dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS):
    valid_loader = DataLoader(dataset, batch_size, shuffle=shuffle, num_workers=num_workers)
    return valid_loader


def show_test_dataset(dataset, n_sample=2):
    plt.figure(figsize=(30, 15))
    for i in range(n_sample):
        image = dataset[i][0]
        image = transforms.ToPILImage()(image)
        print(i, image.size)
        plt.tight_layout()
        ax = plt.subplot(1, n_sample, i + 1)
        ax.set_title('Sample #{}'.format(i))
        ax.axis('off')
        plt.imshow(image, cmap="Greys")
        if i == n_sample - 1:
            plt.show()
            break


def labels():
    l = {}
    for i, label in enumerate(CLASSES):
        l[i] = label
    return l


def tensor2np(tensor):
    tensor = tensor.squeeze().cpu()
    return tensor.detach().numpy()


def normtensor(tensor):
    # binarize logits at 0 (sigmoid threshold 0.5)
    tensor = torch.where(tensor < 0., torch.zeros(1).cuda(), torch.ones(1).cuda())
    return tensor


def count_params(model):
    pytorch_total_params = sum(p.numel() for p in model.parameters())
    return pytorch_total_params


def cal_iou(outputs, labels, SMOOTH=1e-6):
    with torch.no_grad():
        outputs = outputs.squeeze(1).bool()
        labels = labels.squeeze(1).bool()
        intersection = (outputs & labels).float().sum((1, 2))
        union = (outputs | labels).float().sum((1, 2))
        iou = (intersection + SMOOTH) / (union + SMOOTH)
    return iou


def get_iou_score(outputs, labels):
    A = labels.squeeze(1).bool()
    pred = torch.where(outputs < 0., torch.zeros(1).cuda(), torch.ones(1).cuda())
    B = pred.squeeze(1).bool()
    intersection = (A & B).float().sum((1, 2))
    union = (A | B).float().sum((1, 2))
    iou = (intersection + 1e-6) / (union + 1e-6)
    return iou.cpu().detach().numpy()


def train(model, device, trainloader, optimizer, loss_function, epoch):
    """One training epoch; logs the mean loss to TensorBoard via the global writer."""
    model.train()
    model.is_train = True
    running_loss = 0
    mask_list, iou = [], []
    for i, (input, mask) in enumerate(trainloader):
        input, mask = input.to(device), mask.to(device)
        predict = model(input)
        loss = loss_function(predict, mask)
        iou.append(get_iou_score(predict, mask).mean())
        running_loss += loss.item()
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        if ((i + 1) % 10) == 0:
            pred = normtensor(predict[0])
            img, pred, mak = tensor2np(input[0]), tensor2np(pred), tensor2np(mask[0])
            print(f'Epoch: {epoch} | Item: {i} | Train loss: {loss:.5f}')
    mean_iou = np.mean(iou)
    total_loss = running_loss / len(trainloader)
    writer.add_scalar('training loss', total_loss, epoch)
    return total_loss, mean_iou


def test(model, device, testloader, loss_function, best_iou, epoch):
    """One validation epoch; saves the weights whenever the mean IoU improves."""
    model.eval()
    model.is_train = False
    running_loss = 0
    mask_list, iou = [], []
    with torch.no_grad():
        for i, (input, mask) in enumerate(testloader):
            input, mask = input.to(device), mask.to(device)
            predict = model(input)
            loss = loss_function(predict, mask)
            running_loss += loss.item()
            iou.append(get_iou_score(predict, mask).mean())
            if ((i + 1) % 1) == 0:
                pred = normtensor(predict[0])
                img, pred, mak = tensor2np(input[0]), tensor2np(pred), tensor2np(mask[0])
                print(f'Epoch: {epoch} | Item: {i} | test loss: {loss:.5f}')
    test_loss = running_loss / len(testloader)
    mean_iou = np.mean(iou)
    writer.add_scalar('val loss', test_loss, epoch)
    if mean_iou > best_iou:
        try:
            torch.save(model.state_dict(), SAVE_PATH + RUN_NAME + '.pth')
        except:
            print("Can't export weights")
    return test_loss, mean_iou


def model_pipeline(prev_model=None):
    best_model = None
    model, criterion, optimizer = make_model(prev_model)
    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1)
    best_iou = -1
    for epoch in range(EPOCHS):
        t0 = time.time()
        train_loss, train_iou = train(model, device, trainloader, optimizer, criterion, epoch)
        t1 = time.time()
        print(f'Epoch: {epoch} | Train loss: {train_loss:.5f} | Train IoU: {train_iou:.3f} | Time: '
              f'{(t1 - t0):.1f}s')
        t0 = time.time()
        test_loss, test_iou = test(model, device, validloader, criterion, best_iou, epoch)
        t1 = time.time()
        print(f'Epoch: {epoch} | Valid loss: {test_loss:.5f} | Valid IoU: {test_iou:.3f} | Time: '
              f'{(t1 - t0):.1f}s')
        scheduler.step()
        if best_iou < test_iou:
            best_iou = test_iou
            best_model = copy.deepcopy(model)
    return best_model


def make_model(prev_model=None):
    if prev_model is None:
        # binary segmentation with BCEWithLogitsLoss needs a single-channel
        # output to match the 1-channel masks, so num_classes is set to 1
        # (the class default of 1000 would not fit this loss)
        model = SwinTransformerSys(num_classes=1).to(device)
    else:
        model = prev_model
    print("Number of parameter:", count_params(model))
    criterion = nn.BCEWithLogitsLoss()
    optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)
    return model, criterion, optimizer


def predict(model, test_loader, device):
    model.eval()
    predicted_masks = []
    back_transform = transforms.Compose([transforms.Resize((REAL_HEIGHT, REAL_WIDTH))])
    with torch.no_grad():
        for i, (input, _) in enumerate(test_loader):
            input = input.to(device)
            predict = model(input)
            predict = back_transform(predict)
            predict = (predict > 0).type(torch.float)
            predicted_masks.append(predict)
    predicted_masks = torch.cat(predicted_masks)
    return predicted_masks


def show_sample_test_result(test_dataset, predicted_mask, n_samples=60):
    plt.rcParams['figure.figsize'] = (30, 15)
    back_transform = transforms.Compose([transforms.Resize((REAL_HEIGHT, REAL_WIDTH))])
    for i in range(n_samples):
        sample = predicted_mask[i]
        sample = torch.squeeze(sample, dim=0)
        sample = transforms.ToPILImage()(sample)
        X = test_dataset[i][0]
        X = back_transform(X)
        X = transforms.ToPILImage()(X)
        if (i + 1) % 4 != 0:
            index = (i + 1) % 4
        else:
            index = 4
        ax = plt.subplot(2, 2, index)
        ax.set_title('Sample #{}'.format(i))
        ax.axis('off')
        plt.imshow(X, cmap="Greys")
        plt.imshow(sample, alpha=0.7, cmap="winter")
        # if i == n_samples - 1:
        if i % 3 == 0 and i != 0:
            plt.show()
            # break


if __name__ == '__main__':
    writer = SummaryWriter(LOG_PATH)
    dataset = LineDataset(DATA_PATH)
    valid_dataset = Test_LineDataset(DATA_PATH)
    trainloader = get_trainloader(dataset=dataset)
    validloader = get_validloader(dataset=valid_dataset)
    # show_dataset(dataset)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # model = UNet_ResNet()
    # model = model.to(device)
    # model.load_state_dict(torch.load(SAVE_PATH + RUN_NAME + '.pth'))
    # print(device)
    model = model_pipeline()
    writer.close()
    # predict_mask = predict(model, validloader, device)
    # show_sample_test_result(valid_dataset, predict_mask)
```
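For later inference with the saved weights, something along these lines works (a sketch reusing the script's constants and helpers; the single-channel model matches make_model above):

```python
# Sketch: reload the best checkpoint saved by test() and visualize predictions.
model = SwinTransformerSys(num_classes=1).to(device)
model.load_state_dict(torch.load(SAVE_PATH + RUN_NAME + '.pth', map_location=device))
predict_mask = predict(model, validloader, device)
show_sample_test_result(valid_dataset, predict_mask)
```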