class image2csv(object):
    """Split paired image/label files into train / val / test CSV listings.

    Each CSV row holds ``image_path,label_path`` and is written without a
    header or index column.

    Args:
        data_root: Directory the three CSV files are written to.
        image_dir: Directory containing the input images.
        label_dir: Directory containing the ``<stem>_GT.bmp`` label images.
        slice_data: Sequence whose first two items are the train and
            validation fractions; the remainder becomes the test split.
        width_input: Minimum width (exclusive) an image/label pair must have.
        height_input: Minimum height (exclusive) an image/label pair must have.
    """

    def __init__(self, data_root, image_dir, label_dir, slice_data,
                 width_input, height_input):
        self.data_root = data_root
        self.image_dir = image_dir
        self.label_dir = label_dir
        self.slice_train = slice_data[0]
        self.slice_val = slice_data[1]
        self.width = width_input
        self.height = height_input

    @staticmethod
    def _image_size(path):
        """Return ``(width, height)`` of the image at *path* and close the file."""
        with Image.open(path) as img:
            return img.size

    @staticmethod
    def _split_point(fraction, total):
        """Return ``floor(fraction * total)``, tolerant of float round-off.

        Without the epsilon, ``int((0.7 + 0.1) * 10)`` evaluates to 7 instead
        of 8 because ``0.7 + 0.1 == 0.7999999999999999``.
        """
        return int(fraction * total + 1e-9)

    def read_path(self):
        """Collect image/label path pairs that are large enough to crop.

        A pair is kept only when the label file exists and both the image and
        the label are strictly wider than ``self.width`` and strictly taller
        than ``self.height``.

        Returns:
            dict with keys ``'image'`` and ``'label'`` mapping to two lists of
            paths of equal length. Also stores that length in
            ``self.data_length``.
        """
        images = []
        labels = []
        # os.listdir returns entries in arbitrary order; sorting makes the
        # resulting train/val/test split reproducible between runs.
        for im in sorted(os.listdir(self.image_dir)):
            # splitext keeps stems that themselves contain dots intact.
            label_name = os.path.splitext(im)[0] + '_GT' + '.bmp'
            label_path = os.path.join(self.label_dir, label_name)
            if not os.path.exists(label_path):
                continue
            image_path = os.path.join(self.image_dir, im)
            size_w, size_h = self._image_size(image_path)
            size_lw, size_lh = self._image_size(label_path)
            # Images differ in size, so keep only those larger than the crop.
            if (min(size_w, size_lw) > self.width
                    and min(size_h, size_lh) > self.height):
                images.append(image_path)
                labels.append(label_path)
        # Number of images that have a matching, large-enough label.
        self.data_length = len(images)
        return {
            'image': images,
            'label': labels,
        }

    def generate_csv(self):
        """Write ``train.csv``, ``val.csv`` and ``test.csv`` into ``data_root``."""
        data_path_pd = pd.DataFrame(self.read_path())
        total = len(data_path_pd)
        train_end = self._split_point(self.slice_train, total)
        val_end = self._split_point(self.slice_train + self.slice_val, total)
        splits = (
            ('train.csv', data_path_pd.iloc[:train_end, :]),
            ('val.csv', data_path_pd.iloc[train_end:val_end, :]),
            ('test.csv', data_path_pd.iloc[val_end:, :]),
        )
        for file_name, frame in splits:
            frame.to_csv(os.path.join(self.data_root, file_name),
                         header=False, index=False)
def colormap(n):
    """Return an ``(n, 3)`` uint8 palette built by bit-interleaving the index.

    For palette entry ``i``, bits ``3*j``, ``3*j + 1`` and ``3*j + 2`` of ``i``
    become bit ``7 - j`` of the red, green and blue channel respectively.
    """
    cmap = np.zeros((n, 3), dtype=np.uint8)
    for i in range(n):
        r = g = b = 0
        for j in range(8):
            weight = 1 << (7 - j)
            r += weight * ((i >> (3 * j)) & 1)
            g += weight * ((i >> (3 * j + 1)) & 1)
            b += weight * ((i >> (3 * j + 2)) & 1)
        cmap[i, :] = (r, g, b)
    return cmap


class label2image():
    """Turn class-index maps into RGB images for visualisation.

    NOTE(review): the palette comes from ``colormap()`` and differs from the
    hard-coded palette used by ``image2label`` (e.g. index 5 is
    ``[128, 0, 128]`` here but ``[0, 128, 128]`` there), so rendered labels do
    not reproduce the original ground-truth colours — confirm this is intended.
    """

    def __init__(self, num_classes=22):
        self.colormap = colormap(256)[:num_classes].astype('uint8')

    def __call__(self, label_pred, label_true):
        """Return ``(pred_rgb, true_rgb)`` by indexing the palette with each map."""
        pred = self.colormap[label_pred]
        true = self.colormap[label_true]
        return pred, true


class image2label():
    """Convert an RGB label image into a map of integer class indices."""

    def __init__(self, num_classes=22):
        # One colour per class.
        palette = [[0, 0, 0], [128, 0, 0], [0, 128, 0], [128, 128, 0],
                   [0, 0, 128], [0, 128, 128], [128, 128, 128], [192, 0, 0],
                   [64, 128, 0], [192, 128, 0], [64, 0, 128], [192, 0, 128],
                   [64, 128, 128], [192, 128, 128], [0, 64, 0], [128, 64, 0],
                   [0, 192, 0], [128, 64, 128], [0, 192, 128],
                   [128, 192, 128], [64, 64, 0], [192, 64, 0]]
        self.colormap = palette[:num_classes]
        # Lookup table over every possible 24-bit colour. uint8 is enough for
        # the at most 22 class ids and needs 16 MiB instead of the 128 MiB a
        # float64 table of the same length would reserve.
        cm2lb = np.zeros(256 ** 3, dtype=np.uint8)
        for i, cm in enumerate(self.colormap):
            # Mark the slot of this colour with its class index.
            cm2lb[(cm[0] * 256 + cm[1]) * 256 + cm[2]] = i
        self.cm2lb = cm2lb

    def __call__(self, image):
        """Return an int64 array of class indices for *image*.

        *image* is indexed as ``[:, :, channel]`` with channels 0..2 read as
        R, G, B. Colours that are not in the palette map to class 0.
        """
        image = np.array(image, dtype=np.int64)
        idx = (image[:, :, 0] * 256 + image[:, :, 1]) * 256 + image[:, :, 2]
        # Look up the class index of every pixel's colour.
        label = np.array(self.cm2lb[idx], dtype=np.int64)
        return label
class RandomCrop(object):
    """Crop an image and its label at the same randomly chosen location.

    ``size`` is unpacked as ``(crop_height, crop_width)``.
    """

    def __init__(self, size):
        self.size = size

    @staticmethod
    def get_params(img, output_size):
        """Pick the crop window for *img*.

        Args:
            img: Object whose ``size`` attribute is ``(width, height)``.
            output_size: ``(crop_height, crop_width)``.

        Returns:
            ``(top, left, crop_height, crop_width)``. When the image already
            has exactly the requested size the window covers the whole image.
        """
        width, height = img.size
        crop_h, crop_w = output_size
        if (width, height) == (crop_w, crop_h):
            return 0, 0, height, width

        # random.randint is inclusive on both ends.
        top = random.randint(0, height - crop_h)
        left = random.randint(0, width - crop_w)
        return top, left, crop_h, crop_w

    def __call__(self, img, label):
        """Return ``(img_crop, label_crop)`` cut with one shared box."""
        top, left, crop_h, crop_w = self.get_params(img, self.size)
        # crop() takes (left, upper, right, lower).
        box = (left, top, left + crop_w, top + crop_h)
        return img.crop(box), label.crop(box)
class CustomDataset(Dataset):
    """Dataset of (image, label) pairs listed in a header-less two-column CSV.

    Args:
        data_root_csv: Path of a CSV whose rows are ``image_path,label_path``.
        input_width: Width of the random crop returned for every sample.
        input_height: Height of the random crop returned for every sample.
        test: Accepted for interface compatibility; currently unused.
    """

    def __init__(self, data_root_csv, input_width, input_height, test=False):
        # Run the parent initialiser as well.
        super(CustomDataset, self).__init__()
        self.data_root_csv = data_root_csv
        # The CSV files are written without a header row, so header=None is
        # required; otherwise pandas treats the first sample as column names
        # and silently drops it.
        self.data_all = pd.read_csv(self.data_root_csv, header=None)
        self.image_list = list(self.data_all.iloc[:, 0])
        self.label_list = list(self.data_all.iloc[:, 1])
        self.width = input_width
        self.height = input_height

    def __len__(self):
        return len(self.image_list)

    def __getitem__(self, index):
        """Return ``(image_tensor, label_tensor)`` for the sample at *index*."""
        img = Image.open(self.image_list[index]).convert('RGB')
        label = Image.open(self.label_list[index]).convert('RGB')
        # RandomCrop unpacks its size as (height, width).
        img, label = self.train_transform(
            img, label, crop_size=(self.height, self.width))
        return img, label

    def train_transform(self, image, label, crop_size=(256, 256)):
        """Crop both inputs identically, then convert them to tensors.

        Args:
            image: RGB image to be normalised and converted to a tensor.
            label: RGB label image to be converted to class indices.
            crop_size: ``(height, width)`` of the crop, as expected by
                ``RandomCrop``.

        Returns:
            ``(image, label)`` where *image* is the normalised tensor produced
            by the transform pipeline and *label* is a ``torch.long`` tensor
            of class indices.
        """
        # First parentheses build the object, second pair invokes __call__.
        image, label = RandomCrop(crop_size)(image, label)
        tfs = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize([.485, .456, .406], [.229, .224, .225])
        ])
        image = tfs(image)
        label = image2label()(label)
        label = torch.from_numpy(label).long()
        return image, label
# 4. U-Net network architecture
class DoubleConv(nn.Module):
    """Two stacked ``Conv2d(3x3, padding=1) -> BatchNorm2d -> ReLU`` stages.

    Args:
        in_channels: Channels of the input tensor.
        out_channels: Channels produced by the second convolution.
        mid_channels: Channels between the two convolutions; falls back to
            ``out_channels`` when falsy.
    """

    def __init__(self, in_channels, out_channels, mid_channels=None):
        super().__init__()
        hidden = mid_channels or out_channels
        stages = []
        for src, dst in ((in_channels, hidden), (hidden, out_channels)):
            stages += [
                nn.Conv2d(src, dst, kernel_size=3, padding=1),
                nn.BatchNorm2d(dst),
                nn.ReLU(inplace=True),
            ]
        self.double_conv = nn.Sequential(*stages)

    def forward(self, x):
        """Apply both convolution stages to *x*."""
        return self.double_conv(x)
class Down(nn.Module):
    """Downscaling step: 2x2 max pooling followed by a ``DoubleConv``."""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        pool = nn.MaxPool2d(2)
        conv = DoubleConv(in_channels, out_channels)
        self.maxpool_conv = nn.Sequential(pool, conv)

    def forward(self, x):
        """Pool *x*, then run the double convolution on the result."""
        return self.maxpool_conv(x)
class Up(nn.Module):
    """Upscaling step: upsample, pad to the skip tensor, concatenate, convolve.

    Args:
        in_channels: Channels of the concatenated tensor fed to the conv.
        out_channels: Channels produced by the step.
        bilinear: Use bilinear ``nn.Upsample`` when true, otherwise a
            stride-2 ``nn.ConvTranspose2d`` that halves the channel count.
    """

    def __init__(self, in_channels, out_channels, bilinear=True):
        super().__init__()
        half = in_channels // 2
        if not bilinear:
            self.up = nn.ConvTranspose2d(
                in_channels, half, kernel_size=2, stride=2)
            self.conv = DoubleConv(in_channels, out_channels)
        else:
            self.up = nn.Upsample(
                scale_factor=2, mode='bilinear', align_corners=True)
            self.conv = DoubleConv(in_channels, out_channels, half)

    def forward(self, x1, x2):
        """Upsample *x1*, align it with skip tensor *x2* and fuse both.

        Dimensions 2 and 3 of the tensors are treated as height and width.
        """
        x1 = self.up(x1)
        gap_h = x2.size(2) - x1.size(2)
        gap_w = x2.size(3) - x1.size(3)
        left, top = gap_w // 2, gap_h // 2
        # F.pad order for the last two dims: (left, right, top, bottom).
        x1 = F.pad(x1, [left, gap_w - left, top, gap_h - top])
        return self.conv(torch.cat([x2, x1], dim=1))
class OutConv(nn.Module):
    """1x1 convolution mapping ``in_channels`` features to ``out_channels``."""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1)

    def forward(self, x):
        """Return the 1x1 convolution of *x*."""
        projected = self.conv(x)
        return projected
class UNet(nn.Module):
    """Encoder/decoder network with four down steps and four up steps.

    Args:
        n_channels: Channels of the input tensor.
        n_classes: Channels of the output (one per class).
        bilinear: Forwarded to every ``Up`` step; also halves the channel
            count of the deepest layers when true.
    """

    def __init__(self, n_channels, n_classes, bilinear=True):
        super().__init__()
        self.n_channels = n_channels
        self.n_classes = n_classes
        self.bilinear = bilinear

        # Encoder.
        self.inc = DoubleConv(n_channels, 64)
        self.down1 = Down(64, 128)
        self.down2 = Down(128, 256)
        self.down3 = Down(256, 512)
        if bilinear:
            factor = 2
        else:
            factor = 1
        self.down4 = Down(512, 1024 // factor)

        # Decoder.
        self.up1 = Up(1024, 512 // factor, bilinear)
        self.up2 = Up(512, 256 // factor, bilinear)
        self.up3 = Up(256, 128 // factor, bilinear)
        self.up4 = Up(128, 64, bilinear)
        self.outc = OutConv(64, n_classes)

    def forward(self, x):
        """Return the output of the final 1x1 convolution for *x*.

        No activation is applied after the final convolution.
        """
        # Encoder outputs, shallowest first; the last one is the bottleneck.
        features = [self.inc(x)]
        for down in (self.down1, self.down2, self.down3, self.down4):
            features.append(down(features[-1]))

        out = features.pop()
        # Each up step consumes the next-shallower encoder output as its skip.
        for up in (self.up1, self.up2, self.up3, self.up4):
            out = up(out, features.pop())
        return self.outc(out)
# Build the confusion matrix.
def _fast_hist(label_true, label_pred, n_class):
    """Return the ``(n_class, n_class)`` confusion matrix of two flat arrays.

    Rows index the true class and columns the predicted class. Positions whose
    true label lies outside ``[0, n_class)`` are ignored.
    """
    mask = (label_true >= 0) & (label_true < n_class)
    hist = np.bincount(
        n_class * label_true[mask].astype(int) + label_pred[mask].astype(int),
        minlength=n_class ** 2).reshape(n_class, n_class)
    return hist


# Compute mIoU.
def miou_score(label_trues, label_preds, n_class):
    """Return the mean intersection-over-union across classes.

    Args:
        label_trues: Iterable of ground-truth label arrays.
        label_preds: Iterable of predicted label arrays, paired with
            *label_trues* element by element.
        n_class: Number of classes.

    Returns:
        Mean IoU over the classes that occur in the labels or predictions;
        ``nan`` when no class occurs at all.
    """
    hist = np.zeros((n_class, n_class))
    for lt, lp in zip(label_trues, label_preds):
        hist += _fast_hist(lt.flatten(), lp.flatten(), n_class)
    intersection = np.diag(hist)
    union = hist.sum(axis=1) + hist.sum(axis=0) - intersection
    # Absent classes give 0/0; that is expected and excluded from the mean,
    # so do not emit a RuntimeWarning for it.
    with np.errstate(divide='ignore', invalid='ignore'):
        iou = intersection / union
    if not (union > 0).any():
        # np.nanmean would warn about an empty slice and return nan anyway.
        return float('nan')
    miou = np.nanmean(iou)
    return miou
# Training script: build the CSV splits, then train and validate the U-Net,
# keeping the weights with the best validation mIoU.
GPU_ID = 0
INPUT_WIDTH = 200
INPUT_HEIGHT = 200
BATCH_SIZE = 2
NUM_CLASSES = 22
LEARNING_RATE = 1e-3
epoch = 300

# -------------------- generate the CSV files ------------------
DATA_ROOT = './MSRC2/'
image = os.path.join(DATA_ROOT, 'Images')
label = os.path.join(DATA_ROOT, 'GroundTruth')
slice_data = [0.7, 0.1, 0.2]  # fractions for train / validation / test
tocsv = image2csv(DATA_ROOT, image, label, slice_data,
                  INPUT_WIDTH, INPUT_HEIGHT)
tocsv.generate_csv()
# ---------------------------------------------------------------

model_path = './model_result/best_model_UNet.mdl'
# Read the splits from the directory generate_csv() just wrote them to.
train_csv_dir = os.path.join(DATA_ROOT, 'train.csv')
val_csv_dir = os.path.join(DATA_ROOT, 'val.csv')

train_data = CustomDataset(train_csv_dir, INPUT_WIDTH, INPUT_HEIGHT)
train_dataloader = DataLoader(
    train_data, batch_size=BATCH_SIZE, shuffle=True, num_workers=0)
val_data = CustomDataset(val_csv_dir, INPUT_WIDTH, INPUT_HEIGHT)
val_dataloader = DataLoader(
    val_data, batch_size=BATCH_SIZE, shuffle=True, num_workers=0)

# Build the network, optimiser and loss.
net = UNet(3, NUM_CLASSES)
use_gpu = torch.cuda.is_available()
optimizer = optim.Adam(net.parameters(), lr=LEARNING_RATE, weight_decay=1e-4)
# CrossEntropyLoss applies log-softmax itself, so it is fed the raw network
# output; an extra log_softmax beforehand is redundant.
criterion = nn.CrossEntropyLoss()
if use_gpu:
    torch.cuda.set_device(GPU_ID)
    net.cuda()
    criterion = criterion.cuda()

if os.path.exists(model_path):
    # map_location lets a checkpoint saved on GPU load on a CPU-only machine;
    # load_state_dict copies the values onto the model's own device.
    net.load_state_dict(torch.load(model_path, map_location='cpu'))
    print('successful load weight!')
else:
    print('not successful load weight')

# torch.save does not create missing directories.
model_dir = os.path.dirname(model_path)
if model_dir:
    os.makedirs(model_dir, exist_ok=True)

# Training and validation.
best_score = 0.0
for e in range(epoch):
    net.train()
    train_loss = 0.0
    label_true = torch.LongTensor()
    label_pred = torch.LongTensor()
    for i, (batchdata, batchlabel) in enumerate(train_dataloader):
        if use_gpu:
            batchdata, batchlabel = batchdata.cuda(), batchlabel.cuda()

        output = net(batchdata)
        loss = criterion(output, batchlabel)

        # No squeeze(): it would drop the batch dimension when the last batch
        # holds a single sample and break the concatenation below.
        pred = output.argmax(dim=1).detach().cpu()
        real = batchlabel.detach().cpu()

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Weight the batch loss by the batch size to get a per-sample mean.
        train_loss += loss.item() * batchlabel.size(0)
        label_true = torch.cat((label_true, real), dim=0)
        label_pred = torch.cat((label_pred, pred), dim=0)

    train_loss /= len(train_data)
    miou = miou_score(
        label_true.numpy(), label_pred.numpy(), NUM_CLASSES)
    print('\nepoch:{}, train_loss:{:.4f},miou:{:.4f}'.format(
        e + 1, train_loss, miou))

    net.eval()
    val_loss = 0.0
    val_label_true = torch.LongTensor()
    val_label_pred = torch.LongTensor()
    with torch.no_grad():
        for i, (batchdata, batchlabel) in enumerate(val_dataloader):
            if use_gpu:
                batchdata, batchlabel = batchdata.cuda(), batchlabel.cuda()

            output = net(batchdata)
            loss = criterion(output, batchlabel)

            pred = output.argmax(dim=1).cpu()
            real = batchlabel.cpu()

            val_loss += loss.item() * batchlabel.size(0)
            val_label_true = torch.cat((val_label_true, real), dim=0)
            val_label_pred = torch.cat((val_label_pred, pred), dim=0)

    val_loss /= len(val_data)
    val_miou = miou_score(val_label_true.numpy(),
                          val_label_pred.numpy(), NUM_CLASSES)
    print('epoch:{}, val_loss:{:.4f}, miou:{:.4f}'.format(
        e + 1, val_loss, val_miou))

    # Judge the model by its validation mIoU and keep the best weights.
    score = val_miou
    if score > best_score:
        best_score = score
        torch.save(net.state_dict(), model_path)
# Evaluation script: load the best weights, score one batch of the test split
# and save a figure (input / ground truth / prediction) per image.
GPU_ID = 0
INPUT_WIDTH = 200
INPUT_HEIGHT = 200
BATCH_SIZE = 2
NUM_CLASSES = 22
LEARNING_RATE = 1e-3
model_path = './model_result/best_model_UNet.mdl'

# Use the configured GPU when there is one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    torch.cuda.set_device(GPU_ID)
    device = torch.device('cuda:{}'.format(GPU_ID))
else:
    device = torch.device('cpu')

net = UNet(3, NUM_CLASSES)

# Load the network for testing. The reported score is labelled as the
# test-set mIoU, so it is computed on the test split, not the training split.
test_csv_dir = './MSRC2/test.csv'
testset = CustomDataset(test_csv_dir, INPUT_WIDTH, INPUT_HEIGHT)
test_dataloader = DataLoader(testset, batch_size=15, shuffle=False)
net.load_state_dict(torch.load(model_path, map_location=device))
net.to(device)
# Inference mode: BatchNorm must use its running statistics.
net.eval()

test_label_true = torch.LongTensor()
test_label_pred = torch.LongTensor()

# savefig does not create missing directories.
os.makedirs('./pic_results', exist_ok=True)

# Values used to undo the input normalisation.
mean = [.485, .456, .406]
std = [.229, .224, .225]

# Only the first batch (up to 15 images) is evaluated.
for (val_image, val_label) in test_dataloader:
    with torch.no_grad():
        out = net(val_image.to(device))
    # argmax over the class dimension; softmax would not change the result.
    pred = out.argmax(dim=1).cpu()
    real = val_label.cpu()

    test_label_true = torch.cat((test_label_true, real), dim=0)
    test_label_pred = torch.cat((test_label_pred, pred), dim=0)
    test_miou = miou_score(test_label_true.numpy(),
                           test_label_pred.numpy(), NUM_CLASSES)
    print("测试集上的miou为:" + str(test_miou))

    val_pred, val_label = label2image(NUM_CLASSES)(pred.numpy(), real.numpy())

    # The batch can hold fewer than 15 images.
    for i in range(len(val_image)):
        val_pre = val_pred[i]
        val_labe = val_label[i]

        # Undo the normalisation on a copy of the input image.
        x = val_image[i].clone()
        for j in range(3):
            x[j] = x[j].mul(std[j]) + mean[j]
        # Clamp before the byte cast so round-off cannot wrap around.
        img = x.mul(255).clamp(0, 255).byte()
        img = img.numpy().transpose((1, 2, 0))  # original image, HWC

        fig, ax = plt.subplots(1, 3, figsize=(30, 30))
        ax[0].imshow(img)
        ax[1].imshow(val_labe)
        ax[2].imshow(val_pre)
        # Save before show(): once the window is closed the figure is gone
        # and savefig would write an empty image.
        fig.savefig('./pic_results/pic_UNet_{}.png'.format(i))
        plt.show()
        plt.close(fig)
    break
