当前位置:网站首页>使用SwinUnet训练自己的数据集

使用SwinUnet训练自己的数据集

2022-07-07 05:19:00 我是一个小稻米

参考博文: https://blog.csdn.net/qq_37652891/article/details/123932772

数据集准备

遥感图像多类别语义分割,总共分为7类(包括背景)
在这里插入图片描述
image:
在这里插入图片描述
label_rgb
在这里插入图片描述
label(这里并不是全黑,其中的类别取值为0,1,2,3,4,5,6),此后的训练使用的也是这样的数据
在这里插入图片描述

数据地址
百度云:https://pan.baidu.com/s/1zZHnZfBgVWxs6TJW4yjeeQ

提取码:2022

SwinUNet代码地址

数据集处理

数据集的imagelabel,这个数据集应该提供了rgb格式标签和包含0,1,2,3,4,5,6值的标签,SwinUNet使用的是包含0,1,2,3,4,5,6的标签图像;

1. 数据集

数据集存放在SwinUNet根目录下,image中是原图像,label中是标签图像(共7类,其标签取值为0,1,2,3,4,5,6,7);
如果使用其他数据集,要注意标签的取值。比如如果是二分类。即标签0255,需要换成01

—SwinUNet
---------configs
---------img_datas
---------------train
--------------------image
--------------------label
---------------test
--------------------image
--------------------label

2. 在SwinUnet根目录下创建npz.py文件,运行npz.py文件

import glob
import cv2
import numpy as np
import os

def npz(im, la, s):
    images_path = im
    labels_path = la
    path2 = s
    images = os.listdir(images_path)
    for s in images:
        image_path = os.path.join(images_path, s)
        label_path = os.path.join(labels_path, s)

        image = cv2.imread(image_path)
        image = cv2.cvtColor(image,cv2.COLOR_BGR2RGB)
		# 标签由三通道转换为单通道
        label = cv2.imread(label_path, flags=0)
        # 保存npz文件 
        np.savez(path2+s[:-4]+".npz",image=image,label=label)

npz('./img_datas/train/image/', './img_datas/train/label/', './data/Synapse/train_npz')

npz('./img_datas/test/image/', './img_datas/test/label/', './data/Synapse/test_vol_h5')

3. 在SwinUnet根目录下创建txt.py文件,运行txt.py文件

目的是生成./list/list_Synapse/train.txt./list/list_Synapse/test_vol.txt文件

import os
def write_name(np, tx):
    #npz文件路径
    files = os.listdir(np)
    #txt文件路径
    f = open(tx, 'w')
    for i in files:
        #name = i.split('\\')[-1]
        name = i[:-4]+'\n'
        f.write(name)
        
write_name('./data/Synapse/train_npz', './lists/lists_Synapse/train.txt')
write_name('./data/Synapse/test_vol_h5', './lists/lists_Synapse/test_vol.txt')

4. 下载预训练权重,放在SwinUnet目录下的pretrained_ckpt文件夹下

链接:https://pan.baidu.com/s/1-hYwJRlr95Fv08e9AEARww
提取码:2022

在这里插入图片描述

修改网络

1. 修改train.py文件

在这里插入图片描述
比较重要的是类别数量,其他视情况而定
在这里插入图片描述

2. 修改./datasets/dataset_synapse.py文件

在这里插入图片描述

3. 修改trainer.py文件

此处不知道为什么
在这里插入图片描述

4. 运行代码

这些信息可以作为超参传入,如果不能,那么可以使用default=的方式写入默认值
在这里插入图片描述
如果设置好啦默认值,那么运行python train.py就可以啦
在这里插入图片描述

原网站

版权声明
本文为[我是一个小稻米]所创,转载请带上原文链接,感谢
https://blog.csdn.net/weixin_44669966/article/details/125623961