当前位置:网站首页>PyTorch①---加载数据、tensorboard的使用

PyTorch①---加载数据、tensorboard的使用

2022-08-02 14:07:00 伏月三十

两大函数

dir():打开
help():该函数的使用方法
在控制台:
dir()打开torch函数,看里面都有那些方法。
help()查看torch.cuda.is_available()函数的使用方法。
在这里插入图片描述

PyTorth加载数据

如何读取数据(得到自己想要的数据),主要涉及到两个类:
eg:一堆垃圾,如何获取可回收垃圾。
Dataset:提供一种方式去获取数据及其label。如何去获取想要的数据,例如把可回收垃圾提取出来,并且编号。【如何获取每一个数据及其label】【告诉我们总共有多少个数据】
Dataloader:为后面的网络提供不同的数据形式。例如送进网络时,对数据进行打包(数据不会一个一个送进去)。

Dataset的使用

from torch.utils.data import Dataset
from PIL import Image
import os
#继承Dataset类
class MyData(Dataset):
    def __init__(self,root_dir,label_dir):
        '''初始化类:根据一个类创建实例。为整个class/后面的函数提供全局变量'''
        '''获取图片的步骤 1、创建图片的列表:获取文件夹;获取文件夹下所有的图片root_dir 2、label:这里的label是图片上一级的名称(文件夹的名字) '''
        # self:指定了类里面的一个全局变量
        #获取图片的列表和label
        #dataset/train
        self.root_dir=root_dir
        #ants或bees,注意:文件名就是label
        self.label_dir=label_dir
        #将路径连接起来,就是ants或bees的路径
        self.path=os.path.join(self.root_dir,self.label_dir)
        #获取所有图片的路径(只是获取了所有图片的名字)
        self.img_path=os.listdir(self.path)


    def __getitem__(self, idx):
        '''从路径获取每一张图片'''
        #从列表里拿出图片的名称
        img_name=self.img_path[idx]
        #加上相对路径:加上每一张图片的路径
        img_item_path=os.path.join(self.root_dir,self.label_dir,img_name)
        #获取img:从路径打开图片
        img=Image.open(img_item_path)
        #获取label
        label=self.label_dir
        return img,label
    def __len__(self):
        '''返回数据集的长度'''
        return len(self.img_path)

#获取蚂蚁的数据集
root_dir="dataset/train"
ants_label_dir="ants"
ants_dataset=MyData(root_dir,ants_label_dir)
#获取蜜蜂的数据集
bees_label_dir="bees"
bees_dataset=MyData(root_dir,bees_label_dir)

#整个数据集:对象可以相加
train_dataset=ants_dataset+bees_dataset

Dataloader的使用

待定

tensorboard的使用

add_scalar()函数:可视化
add_image()函数:加载图片

from torch.utils.tensorboard import SummaryWriter
import numpy as np
from PIL import Image
#创建一个实例,生成的事件文件存到logs里
writer=SummaryWriter("logs")
'''加载图片'''
image_path="data/train/bees_image/29494643_e3410f0d37.jpg"
img_PIL=Image.open(image_path)
img_array=np.array(img_PIL)

writer.add_image("test1",img_array,3,dataformats='HWC')
'''可视化'''
#y=x
for i in range(100):
    writer.add_scalar("y=2x",2*i,i)
writer.close()

【注意事项】
writer.add_image(“test1”,img_array,3,dataformats=‘HWC’)
test1:为标题。可修改,修改后图片变位置
img_array:图片,为ndarray形式。torch.Tensor, numpy.array, or string/blobname。
3:第几步。可修改,修改后图片变位置
dataformats=‘HWC’:图片的长、宽、通道。

在终端运行:
在这里插入图片描述
结果:
在这里插入图片描述
在这里插入图片描述

原网站

版权声明
本文为[伏月三十]所创,转载请带上原文链接,感谢
https://blog.csdn.net/weixin_45703331/article/details/125989946