当前位置:网站首页>【load dataset】

【load dataset】

2022-07-05 11:34:00 网络星空(luoc)


在这里插入图片描述

1. Generate the data index (.txt) files

# coding:utf-8
import os
''' 为数据集生成对应的txt文件 '''

# Dataset root; relative, so it is resolved against the current working directory.
_data_dir = os.path.join("../..", "..", "Data")

# Index file and image folder of the training split.
train_txt_path = os.path.join(_data_dir, "train.txt")
train_dir = os.path.join(_data_dir, "train")

# Index file and image folder of the validation split.
valid_txt_path = os.path.join(_data_dir, "valid.txt")
valid_dir = os.path.join(_data_dir, "valid")


def gen_txt(txt_path, img_dir):
    """Write an index file listing every PNG image found in the sub-directories of ``img_dir``.

    Each output line has the form ``<image path> <label>``, where the label is the
    part of the file name before the first ``_`` (e.g. ``3_0001.png`` -> ``3``).
    Only files inside sub-directories of ``img_dir`` (at any depth) are listed;
    files sitting directly in ``img_dir`` are ignored. Output order is sorted so
    the file is reproducible across runs and platforms.

    Args:
        txt_path: path of the index file to create (overwritten if it exists).
        img_dir: root directory whose sub-directories hold the images.
    """
    with open(txt_path, 'w') as f:  # closed even if listing/writing fails
        for root, s_dirs, _ in os.walk(img_dir, topdown=True):
            # Sorting in place (allowed with topdown=True) fixes the traversal order too.
            s_dirs.sort()
            for sub_dir in s_dirs:
                # Folder of one class; relative if img_dir is relative.
                i_dir = os.path.join(root, sub_dir)
                for img_name in sorted(os.listdir(i_dir)):
                    img_path = os.path.join(i_dir, img_name)
                    # Keep only real .png files (case-insensitive); skip other files and folders.
                    if not img_name.lower().endswith('.png') or not os.path.isfile(img_path):
                        continue
                    label = img_name.split('_')[0]
                    f.write(img_path + ' ' + label + '\n')


if __name__ == '__main__':
    # Build one index file per split: (output txt file, image root directory).
    for split_txt, split_dir in ((train_txt_path, train_dir), (valid_txt_path, valid_dir)):
        gen_txt(split_txt, split_dir)


2.MyDataset class

# coding: utf-8
from PIL import Image
from torch.utils.data import Dataset


class MyDataset(Dataset):
    """Image-classification dataset backed by a text index file.

    Each non-blank line of ``txt_path`` is ``<image path> <integer label>``. The
    label is taken as the last whitespace-separated field, so image paths may
    contain spaces.

    Args:
        txt_path: path of the index file.
        transform: optional callable applied to the loaded RGB PIL image.
        target_transform: optional callable applied to the integer label.
    """

    def __init__(self, txt_path, transform=None, target_transform=None):
        imgs = []
        with open(txt_path, 'r') as fh:  # file is closed even if a line fails to parse
            for line in fh:
                line = line.strip()
                if not line:  # tolerate blank lines, e.g. an extra newline at EOF
                    continue
                path, label = line.rsplit(maxsplit=1)
                imgs.append((path, int(label)))

        # The key structure: (path, label) pairs. DataLoader supplies an index and
        # __getitem__ loads the corresponding image.
        self.imgs = imgs
        self.transform = transform
        self.target_transform = target_transform

    def __getitem__(self, index):
        """Return ``(image, label)`` for ``index``, with the optional transforms applied."""
        fn, label = self.imgs[index]
        # Pixel values are 0~255 here; a ToTensor-style transform (if used) rescales them.
        img = Image.open(fn).convert('RGB')

        if self.transform is not None:
            img = self.transform(img)   # e.g. augmentation / conversion to tensor
        if self.target_transform is not None:
            label = self.target_transform(label)  # previously stored but never applied

        return img, label

    def __len__(self):
        """Return the number of entries in the index file."""
        return len(self.imgs)

3. Instantiate MyDataset

# Build the training and validation datasets from their index files.
# NOTE(review): trainTransform / validTransform are not defined in this snippet; they are
# presumably transform pipelines defined elsewhere — confirm before running.
train_data = MyDataset(txt_path=train_txt_path, transform=trainTransform)
valid_data = MyDataset(txt_path=valid_txt_path, transform=validTransform)

4. Compute the per-channel mean and std

# coding: utf-8

import numpy as np
import cv2
import random
import os

""" 随机挑选CNum张图片,进行按通道计算均值mean和标准差std 先将像素从0~255归一化至 0-1 再计算 """


# Randomly sample up to CNum training images, scale pixels from 0~255 to 0~1, and
# compute the per-channel mean and standard deviation over all sampled pixels.
train_txt_path = os.path.join("../..", "..", "Data", "train.txt")

CNum = 2000     # maximum number of images to sample

img_h, img_w = 32, 32
means, stdevs = [], []

with open(train_txt_path, 'r') as f:
    lines = [ln for ln in f if ln.strip()]  # ignore blank lines
random.shuffle(lines)   # shuffle so the first CNum lines are a random sample

samples = []
# Slicing (rather than range(CNum)) avoids IndexError when there are fewer than CNum lines.
for i, line in enumerate(lines[:CNum]):
    # Path is everything before the trailing label field, so it may contain spaces.
    img_path = line.strip().rsplit(maxsplit=1)[0]

    img = cv2.imread(img_path)   # 3-channel BGR array, or None if unreadable
    if img is None:
        print("skip unreadable image: {}".format(img_path))
        continue
    img = cv2.resize(img, (img_w, img_h))   # cv2 dsize is (width, height)

    samples.append(img)
    print(i)

if not samples:
    raise RuntimeError("no readable images found in {}".format(train_txt_path))

# Stack once into shape (img_h, img_w, 3, N) and scale to [0, 1]. Unlike concatenating
# onto an all-zero seed array, this adds no dummy black image to the statistics and
# avoids re-copying the whole array on every iteration.
imgs = np.stack(samples, axis=3).astype(np.float32) / 255.


for i in range(3):
    pixels = imgs[:, :, i, :].ravel()  # every value of channel i, flattened
    means.append(np.mean(pixels))
    stdevs.append(np.std(pixels))

means.reverse()  # cv2 loads BGR; reverse to RGB order
stdevs.reverse()

print("normMean = {}".format(means))
print("normStd = {}".format(stdevs))
print('transforms.Normalize(normMean = {}, normStd = {})'.format(means, stdevs))

5. DataLoader

# Alternative: PyTorch's built-in CIFAR-10 dataset
#import torchvision
#train_data = torchvision.datasets.CIFAR10(root = './data/', train = True, transform = trainTransform, download = True)
#valid_data = torchvision.datasets.CIFAR10(root = './data/', train = False, transform = validTransform, download = True)
# Build the DataLoaders (DataLoader was used here without being imported).
from torch.utils.data import DataLoader

# NOTE(review): train_data / valid_data come from step 3; train_bs / valid_bs are batch
# sizes not defined in this snippet — confirm they are set before this point.
train_loader = DataLoader(dataset=train_data, batch_size=train_bs, shuffle=True)
# Validation batches are read in a fixed order (DataLoader's shuffle defaults to False).
valid_loader = DataLoader(dataset=valid_data, batch_size=valid_bs)

6. Network, loss, optimizer, and learning-rate scheduler

#import torchvision.models as models
#net = models.vgg() # 创建一个网络

# ------------------------------------ step 3/5 : 定义损失函数和优化器 ------------------------------------

#criterion = nn.CrossEntropyLoss() # 选择损失函数
#optimizer = optim.SGD(net.parameters(), lr=lr_init, momentum=0.9, dampening=0.1) # 选择优化器
#scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=50, gamma=0.1) # 设置学习率下降策略
class Net(nn.Module):
    """Small LeNet-style CNN: two conv+max-pool stages followed by four linear layers.

    conv1 expects 3 input channels. The 16 * 5 * 5 flatten size implies 32x32 inputs:
    each 5x5 conv shrinks the side by 4 and each 2x2 pool halves it (32->28->14->10->5).
    The final layer produces 10 logits.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool1 = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.pool2 = nn.MaxPool2d(2, 2)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        # Attribute name is part of the state_dict keys ("fc2_1.weight"); keep it stable.
        self.fc2_1 = nn.Linear(84, 40)
        self.fc3 = nn.Linear(40, 10)

    def forward(self, x):
        """Return logits of shape (N, 10) for an input batch of shape (N, 3, 32, 32)."""
        x = self.pool1(F.relu(self.conv1(x)))
        x = self.pool2(F.relu(self.conv2(x)))
        # Flatten per sample. Unlike view(-1, 400), a wrongly sized input then fails in
        # fc1 instead of being silently reshaped into a different batch size.
        x = torch.flatten(x, 1)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = F.relu(self.fc2_1(x))  # was self.fc2_2, an attribute that does not exist
        x = self.fc3(x)
        return x

    def initialize_weights(self):
        """Re-initialise parameters in place.

        Conv2d weights use Xavier-normal initialisation and Linear weights use N(0, 0.01).
        BatchNorm2d weights are set to 1. All biases are set to 0. Missing parameters
        (bias=False, affine=False) are skipped.
        """
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                torch.nn.init.xavier_normal_(m.weight.data)
                if m.bias is not None:
                    m.bias.data.zero_()
            elif isinstance(m, nn.BatchNorm2d):
                if m.weight is not None:
                    m.weight.data.fill_(1)
                if m.bias is not None:
                    m.bias.data.zero_()
            elif isinstance(m, nn.Linear):
                torch.nn.init.normal_(m.weight.data, 0, 0.01)
                if m.bias is not None:
                    m.bias.data.zero_()


net = Net()     # create the network

# ================================ #
# Fine-tuning: initialise from pretrained weights
# ================================ #

# Load the pretrained parameters. map_location='cpu' lets a checkpoint saved on a GPU
# load on a CPU-only machine; load_state_dict below copies values onto net's own device.
# NOTE(review): torch.load unpickles the file — only load checkpoints from trusted sources.
pretrained_dict = torch.load('net_params.pkl', map_location='cpu')

# State dict of the current network
net_state_dict = net.state_dict()

# Drop pretrained entries that do not match: keep only keys that exist in the current
# network AND have the same shape, so e.g. a classifier with a different number of
# outputs is skipped instead of making load_state_dict fail with a size mismatch.
pretrained_dict_1 = {
    k: v for k, v in pretrained_dict.items()
    if k in net_state_dict and v.shape == net_state_dict[k].shape}

# Overlay the matching pretrained values on the current parameters
net_state_dict.update(pretrained_dict_1)

# Load the merged dictionary into the network
net.load_state_dict(net_state_dict)

# ------------------------------------ step 3/5 : loss function and optimizer ------------------------------------
# ================================= #
# Per-layer learning rates
# ================================= #

# Separate fc3's parameters from the rest of the network (matched by object identity)
ignored_params = set(map(id, net.fc3.parameters()))
base_params = [p for p in net.parameters() if id(p) not in ignored_params]

# fc3 trains with 10x the base learning rate; other layers use lr_init.
# NOTE(review): lr_init is not defined in this snippet — confirm it is set elsewhere.
optimizer = optim.SGD([
    {'params': base_params},
    {'params': net.fc3.parameters(), 'lr': lr_init * 10},
], lr=lr_init, momentum=0.9, weight_decay=1e-4)

criterion = nn.CrossEntropyLoss()                                                   # loss function
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=50, gamma=0.1)     # multiply lr by 0.1 every 50 scheduler steps

7. Train the model

在这里插入代码片
原网站

版权声明
本文为[网络星空(luoc)]所创,转载请带上原文链接,感谢
https://blog.csdn.net/weixin_42483745/article/details/125075092