[Deep Learning] Data Preparation: Loading a Custom Image Segmentation Dataset in PyTorch
2022-07-29 07:25:00 【鱼与钰遇雨】
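I. References
For the overall framework of the method, see my earlier article on the topic.
Other references for details:
1. Two ways to load your own image dataset in PyTorch, explained
2. Converting a four-channel RGBA PNG to a three-channel RGB JPG with PIL
3. Converting images read with cv2, plt, and PIL Image to tensor format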
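II. Problem Definition
Suppose I have 20 images and 20 corresponding label images annotated with labelme, stored in the train and train_label folders respectively and named so that each image matches its label one to one.
How can these images be loaded into a deep learning pipeline as a training set and a test set in which every sample is paired with its label?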
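III. Implementation
Note that there is more than one way to write this code; this is only one approach. If you have a better one, feel free to share it in the comments.
The core of the approach is to subclass Dataset and override its __len__ and __getitem__ methods. __len__ returns the number of samples in the dataset, and __getitem__ returns the sample (here, an image and its label) at a given index.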
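As a minimal sketch of that contract, the toy class below implements only __len__ and __getitem__, and DataLoader handles batching on its own. PairDataset is a hypothetical name, and the random tensors stand in for real images and masks.

```python
import torch
from torch.utils.data import Dataset, DataLoader

class PairDataset(Dataset):
    """Hypothetical toy dataset: (image, mask) tensor pairs held in memory."""

    def __init__(self, images, masks):
        assert len(images) == len(masks), "every image needs exactly one mask"
        self.images = images
        self.masks = masks

    def __len__(self):
        # Number of samples; DataLoader and random_split rely on this
        return len(self.images)

    def __getitem__(self, idx):
        # Return one (sample, label) pair for the given index
        return self.images[idx], self.masks[idx]

# Five fake 1x8x8 grayscale images with matching binary masks
images = [torch.rand(1, 8, 8) for _ in range(5)]
masks = [(img > 0.5).float() for img in images]
ds = PairDataset(images, masks)

# DataLoader stacks consecutive samples into a batch of shape [batch, 1, 8, 8]
loader = DataLoader(ds, batch_size=2, shuffle=False)
batch_imgs, batch_masks = next(iter(loader))
print(len(ds), tuple(batch_imgs.shape), tuple(batch_masks.shape))
# prints: 5 (2, 1, 8, 8) (2, 1, 8, 8)
```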
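Step 1: Generate the file index lists
In the directory that contains the train and train_label folders, create a Python script that writes the file names of all images in each folder into an index txt file.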
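```python
import os

path = os.getcwd()  # absolute path of the current working directory

def make_txt(root, file_name):
    # Write the names of all files in root/file_name to root/file_name.txt, one per line
    folder = os.path.join(root, file_name)
    # os.listdir has no guaranteed order; sorting keeps train.txt and
    # train_label.txt in the same order when images and labels share names
    data = sorted(os.listdir(folder))
    with open(os.path.join(root, file_name + '.txt'), 'w') as f:
        for line in data:
            f.write(line + '\n')
    print('success')

# Generate the index txt files for both folders
make_txt(path, file_name='train')
make_txt(path, file_name='train_label')
```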
output:
success
success
This produces the index files train.txt and train_label.txt.
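The index files only work if line i of train.txt and line i of train_label.txt name the same sample. As a hedged sketch, assuming each image and its label share a file stem (for example 001.jpg and 001.png), a hypothetical helper such as pair_by_stem can check this before training and fail loudly on unmatched files instead of silently mis-pairing them:

```python
import os

def pair_by_stem(image_dir, label_dir):
    """Hypothetical helper: pair files whose names match apart from the extension.

    Returns (image_name, label_name) tuples sorted by stem and raises
    ValueError if any image lacks a label or any label lacks an image.
    """
    def by_stem(folder):
        # Map "001" -> "001.jpg"; assumes stems are unique within a folder
        return {os.path.splitext(name)[0]: name for name in os.listdir(folder)}

    images, labels = by_stem(image_dir), by_stem(label_dir)
    no_label = sorted(images.keys() - labels.keys())
    no_image = sorted(labels.keys() - images.keys())
    if no_label or no_image:
        raise ValueError(f"images without labels: {no_label}; labels without images: {no_image}")
    # Sorting by stem guarantees the i-th image is paired with the i-th label
    return [(images[s], labels[s]) for s in sorted(images)]
```

Calling pair_by_stem('train', 'train_label') from the same directory returns the pairs in matching order; writing its two columns to train.txt and train_label.txt would give index files that are aligned by construction.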
Step 2: Subclass Dataset
Subclass Dataset and use the index files to load the images and labels.
```python
import os
import torch
from torch.utils.data import Dataset, DataLoader
from PIL import Image
import torchvision.transforms as trans

trans1 = trans.ToTensor()  # PIL image -> float tensor of shape [C, H, W]

class MyDataset(Dataset):
    def __init__(self, path, transform=None):
        super(MyDataset, self).__init__()
        self.path = path
        self.image_path = os.path.join(path, 'train.txt')
        self.label_path = os.path.join(path, 'train_label.txt')
        # Read image file names from train.txt and turn them into full paths
        with open(self.image_path, 'r') as f:
            self.img = [os.path.join(path, 'train', line.rstrip())
                        for line in f if line.strip()]
        # Same for the labels; line i of both files must describe the same sample
        with open(self.label_path, 'r') as f:
            self.label = [os.path.join(path, 'train_label', line.rstrip())
                          for line in f if line.strip()]
        self.transform = transform

    def __len__(self):
        return len(self.label)

    def __getitem__(self, item):
        img_path = self.img[item]      # str
        label_path = self.label[item]  # str
        # '1' converts to a 1-bit black-and-white image; use 'L' to keep grey levels
        img = Image.open(img_path).convert('1')
        img = trans1(img)  # img is now a tensor, so transform must accept tensors
        if self.transform is not None:
            img = self.transform(img)
        label = Image.open(label_path).convert('1')
        label = trans1(label)  # binary mask tensor [1, H, W] with values 0.0 or 1.0
        return img, label
```
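In MyDataset, transform is applied to the image only. That is fine for purely photometric changes such as normalization, but a geometric augmentation (flip, crop, rotation) has to move the mask together with the image, otherwise pixels and labels fall out of alignment. A minimal sketch of such a joint transform, assuming both inputs are [C, H, W] tensors as produced by trans1; the class name JointRandomHFlip is hypothetical:

```python
import random
import torch

class JointRandomHFlip:
    """Hypothetical joint transform: flip image and mask together with probability p."""

    def __init__(self, p=0.5):
        self.p = p

    def __call__(self, img, mask):
        # One random draw decides for both tensors, so they always stay aligned
        if random.random() < self.p:
            img = torch.flip(img, dims=[-1])    # flip along the width axis
            mask = torch.flip(mask, dims=[-1])
        return img, mask
```

To use it, __getitem__ would call img, label = self.transform(img, label) once both tensors exist, instead of transforming img alone.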
Instantiate MyDataset and split it into a training set (70%) and a test set (30%):
```python
import os
import torch

path = os.getcwd()
data = MyDataset(path, transform=None)
train_size = int(len(data) * 0.7)   # 70% for training
test_size = len(data) - train_size  # the rest for testing
train_dataset, test_dataset = torch.utils.data.random_split(data, [train_size, test_size])
```
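random_split draws a new random permutation on every run, so which samples land in the test set changes between runs. As a sketch (a TensorDataset of 20 dummy items stands in for MyDataset), passing a seeded torch.Generator makes the split reproducible, and computing test_size as the remainder guarantees that the two lengths add up to len(data):

```python
import torch
from torch.utils.data import TensorDataset, random_split

# Stand-in for MyDataset: 20 dummy samples
data = TensorDataset(torch.arange(20))

train_size = int(len(data) * 0.7)
test_size = len(data) - train_size  # remainder, so the sizes always sum to len(data)

# A seeded generator yields the same permutation, hence the same split, every run
g = torch.Generator().manual_seed(42)
train_dataset, test_dataset = random_split(data, [train_size, test_size], generator=g)

print(len(train_dataset), len(test_dataset))  # prints: 14 6
```

Each returned object is a Subset whose indices attribute lists the original sample positions, which is handy for logging exactly which files went into the test set.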
Step 3: Use DataLoader to build sample iterators for training
```python
train_loader = DataLoader(train_dataset, batch_size=4, shuffle=True, drop_last=False, num_workers=0)
test_loader = DataLoader(test_dataset, batch_size=4, shuffle=False, drop_last=False, num_workers=0)
```
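With the 70/30 split, 20 images leave 6 test samples, so batch_size=4 yields one batch of 4 and a final batch of 2; drop_last=True would discard that last partial batch. A quick check with dummy tensors of the same layout (the TensorDataset and the 4x4 size are stand-ins, not the article's data):

```python
import torch
from torch.utils.data import TensorDataset, DataLoader

# 6 dummy single-channel 4x4 images with binary masks, standing in for test_dataset
imgs = torch.rand(6, 1, 4, 4)
masks = (imgs > 0.5).float()
dummy_test = TensorDataset(imgs, masks)

for drop_last in (False, True):
    loader = DataLoader(dummy_test, batch_size=4, shuffle=False, drop_last=drop_last)
    shapes = [tuple(x.shape) for x, _ in loader]
    print(drop_last, len(loader), shapes)
# prints:
# False 2 [(4, 1, 4, 4), (2, 1, 4, 4)]
# True 1 [(4, 1, 4, 4)]
```

Every batch therefore has shape [batch_size, 1, H, W], which matters when plotting individual samples.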
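Read batches from the iterator and check the result.
How do you draw two images at once with plt? Use plt.subplot(121) and plt.subplot(122) for the left and right panels: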
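```python
from matplotlib import pyplot as plt

for i, (images, GT) in enumerate(test_loader):
    print(i)
    print(GT.shape)  # [batch_size, 1, H, W]
    # Show the first sample of each batch: image on the left, label on the right.
    # (reshape(256, 256) only works when a batch holds a single 256x256 image.)
    plt.subplot(121)
    plt.imshow(images[0, 0].numpy(), cmap='gray')
    plt.subplot(122)
    plt.imshow(GT[0, 0].numpy(), cmap='gray')
    plt.show()
```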
