当前位置:网站首页>【load dataset】
【load dataset】
2022-07-05 11:34:00 【网络星空(luoc)】
文章目录
1.generate data txt file
# coding:utf-8
import os
''' 为数据集生成对应的txt文件 '''
# Locations of the dataset folders and of the index files generated for them.
_data_root = os.path.join("../..", "..", "Data")
train_dir = os.path.join(_data_root, "train")
train_txt_path = os.path.join(_data_root, "train.txt")
valid_dir = os.path.join(_data_root, "valid")
valid_txt_path = os.path.join(_data_root, "valid.txt")
def gen_txt(txt_path, img_dir):
    """Write an index file listing the PNG images found below ``img_dir``.

    Every sub-directory reached while walking ``img_dir`` is scanned, and
    each file whose name ends in ``png`` produces one output line of the
    form ``<image path> <label>`` followed by a newline. The label is the
    part of the file name before the first underscore.

    Files that sit directly in ``img_dir`` (not inside a sub-directory)
    are not listed, because only the sub-directories reported by the walk
    are scanned.

    Args:
        txt_path: Path of the index file to create (overwritten if present).
        img_dir: Root directory holding one folder per class.
    """
    # The context manager guarantees the index file is flushed and closed
    # even if walking or listing a directory raises.
    with open(txt_path, 'w') as f:
        for root, s_dirs, _ in os.walk(img_dir, topdown=True):
            for sub_dir in s_dirs:
                # Path of this class folder, built from the walk root.
                i_dir = os.path.join(root, sub_dir)
                for img_name in os.listdir(i_dir):
                    if not img_name.endswith('png'):
                        # Not a PNG file: skip it.
                        continue
                    label = img_name.split('_')[0]
                    img_path = os.path.join(i_dir, img_name)
                    f.write(img_path + ' ' + label + '\n')
if __name__ == '__main__':
    # Generate the index file for each split, training first.
    for _txt, _dir in ((train_txt_path, train_dir), (valid_txt_path, valid_dir)):
        gen_txt(_txt, _dir)
2.MyDataset class
# coding: utf-8
from PIL import Image
from torch.utils.data import Dataset
class MyDataset(Dataset):
    """Image dataset backed by a text index file.

    Each non-empty line of the index is split on whitespace; the first
    field is used as the image path and the second is parsed as an integer
    label. Because of the whitespace split, image paths containing spaces
    are not supported.

    Args:
        txt_path: Path of the index file to read.
        transform: Optional callable applied to the loaded RGB image.
        target_transform: Optional callable applied to the integer label.
    """

    def __init__(self, txt_path, transform=None, target_transform=None):
        imgs = []
        # Read the index inside a context manager so the handle is closed.
        with open(txt_path, 'r') as fh:
            for line in fh:
                words = line.split()
                if not words:
                    # Tolerate blank lines (e.g. a trailing newline).
                    continue
                imgs.append((words[0], int(words[1])))
        # List of (path, label) pairs; __getitem__ indexes into it.
        self.imgs = imgs
        self.transform = transform
        self.target_transform = target_transform

    def __getitem__(self, index):
        """Return the ``(image, label)`` pair stored at ``index``."""
        fn, label = self.imgs[index]
        # Decode as 3-channel RGB; any scaling or tensor conversion is left
        # to ``transform``.
        img = Image.open(fn).convert('RGB')
        if self.transform is not None:
            img = self.transform(img)
        if self.target_transform is not None:
            label = self.target_transform(label)
        return img, label

    def __len__(self):
        """Return the number of indexed images."""
        return len(self.imgs)
3.Instantiate MyDataset
# Build the training and validation datasets from their index files.
# NOTE(review): trainTransform and validTransform are not defined in the
# visible snippet - confirm they are created before these lines run.
train_data = MyDataset(txt_path=train_txt_path, transform=trainTransform)
valid_data = MyDataset(txt_path=valid_txt_path, transform=validTransform)
4.compute mean
# coding: utf-8
import numpy as np
import cv2
import random
import os
""" 随机挑选CNum张图片,进行按通道计算均值mean和标准差std 先将像素从0~255归一化至 0-1 再计算 """
train_txt_path = os.path.join("../..", "..", "Data/train.txt")

CNum = 2000  # upper bound on how many images are sampled
img_h, img_w = 32, 32
means, stdevs = [], []

with open(train_txt_path, 'r') as f:
    lines = f.readlines()

random.shuffle(lines)  # shuffle so the first entries are a random sample

# Collect the resized images in a list and stack them once at the end.
# Only images that were actually loaded are stacked, so no placeholder
# frame takes part in the statistics.
sampled = []
for i in range(min(CNum, len(lines))):  # never index past the end of the list
    fields = lines[i].split()
    if not fields:
        # Blank line in the index: nothing to load.
        continue
    img_path = fields[0]
    img = cv2.imread(img_path)
    if img is None:
        # cv2.imread signals an unreadable file by returning None.
        print("skip unreadable image: {}".format(img_path))
        continue
    img = cv2.resize(img, (img_h, img_w))
    sampled.append(img)
    print(i)

if not sampled:
    raise RuntimeError(
        "no readable images listed in {}".format(train_txt_path))

# Stack along a new last axis and scale pixel values from 0-255 to 0-1.
imgs = np.stack(sampled, axis=3).astype(np.float32) / 255.

for i in range(3):
    pixels = imgs[:, :, i, :].ravel()  # flatten one channel across all images
    means.append(np.mean(pixels))
    stdevs.append(np.std(pixels))

# cv2 loads channels in BGR order; reverse to report RGB.
means.reverse()
stdevs.reverse()

print("normMean = {}".format(means))
print("normStd = {}".format(stdevs))
print('transforms.Normalize(normMean = {}, normStd = {})'.format(means, stdevs))
5.DataLoader
# Alternative: use a built-in torchvision dataset instead of MyDataset
#import torchvision
#train_data = torchvision.datasets.CIFAR10(root = './data/', train = True, transform = trainTransform, download = True)
#valid_data = torchvision.datasets.CIFAR10(root = './data/', train = False, transform = trainTransform, download = True)
# Build the DataLoaders; only the training loader is created with shuffle=True.
# NOTE(review): DataLoader, train_bs and valid_bs are not defined in the
# visible snippet - confirm they are imported / set before these lines run.
train_loader = DataLoader(dataset=train_data, batch_size=train_bs, shuffle=True)
valid_loader = DataLoader(dataset=valid_data, batch_size=valid_bs)
6.net loss optimizer scheduler
#import torchvision.models as models
#net = models.vgg() # create a network
# ------------------------------------ step 3/5 : define the loss function and optimizer ------------------------------------
#criterion = nn.CrossEntropyLoss() # choose the loss function
#optimizer = optim.SGD(net.parameters(), lr=lr_init, momentum=0.9, dampening=0.1) # choose the optimizer
#scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=50, gamma=0.1) # set the learning-rate decay policy
class Net(nn.Module):
    """Small convolutional classifier with 10 outputs.

    Two 5x5 convolution + 2x2 max-pool stages are followed by four fully
    connected layers. The flatten step assumes 16 feature maps of size
    5x5, which is what a 3-channel 32x32 input produces
    (32 -> 28 -> 14 -> 10 -> 5).
    """

    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool1 = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.pool2 = nn.MaxPool2d(2, 2)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc2_1 = nn.Linear(84, 40)
        self.fc3 = nn.Linear(40, 10)

    def forward(self, x):
        """Return the raw (un-normalised) class scores for batch ``x``."""
        x = self.pool1(F.relu(self.conv1(x)))
        x = self.pool2(F.relu(self.conv2(x)))
        x = x.view(-1, 16 * 5 * 5)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        # Must use the layer registered in __init__ (fc2_1).
        x = F.relu(self.fc2_1(x))
        x = self.fc3(x)
        return x

    def initialize_weights(self):
        """Re-initialise the weights and biases of all sub-modules in place.

        Conv2d weights use Xavier-normal init, BatchNorm2d weights are set
        to 1, Linear weights are drawn from N(0, 0.01); all biases are
        zeroed. Layers created without a bias (or without affine
        parameters) are skipped for the missing tensor.
        """
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                torch.nn.init.xavier_normal_(m.weight.data)
                if m.bias is not None:
                    m.bias.data.zero_()
            elif isinstance(m, nn.BatchNorm2d):
                if m.weight is not None:
                    m.weight.data.fill_(1)
                if m.bias is not None:
                    m.bias.data.zero_()
            elif isinstance(m, nn.Linear):
                torch.nn.init.normal_(m.weight.data, 0, 0.01)
                if m.bias is not None:
                    m.bias.data.zero_()
net = Net()  # create the network

# ================================ #
# Fine-tuning: initialise weights from a saved parameter file
# ================================ #

# Load the saved parameters.
# NOTE(review): torch.load unpickles the file - only load checkpoints that
# come from a trusted source.
pretrained_dict = torch.load('net_params.pkl')

# State dict of the freshly created network.
net_state_dict = net.state_dict()

# Keep only entries the current network can accept: the key must exist and
# the tensor shape must match, otherwise load_state_dict would fail on it.
# (Assumes the saved object is a state dict mapping names to tensors.)
pretrained_dict_1 = {
    k: v for k, v in pretrained_dict.items()
    if k in net_state_dict and v.shape == net_state_dict[k].shape}

# Overlay the compatible pretrained entries on the current parameters.
net_state_dict.update(pretrained_dict_1)

# Load the merged dict back into the network.
net.load_state_dict(net_state_dict)
# ------------------------------------ step 3/5 : define the loss function and optimizer ------------------------------------

# ================================= #
# Per-layer learning rates
# ================================= #

# ids of fc3's parameters, used to exclude them from the base group
ignored_params = [id(p) for p in net.fc3.parameters()]
base_params = (p for p in net.parameters() if id(p) not in ignored_params)

# Two parameter groups: everything except fc3 at the default rate, and fc3
# with its own rate of lr_init*10.
param_groups = [
    {'params': base_params},
    {'params': net.fc3.parameters(), 'lr': lr_init*10},
]
optimizer = optim.SGD(param_groups, lr_init, momentum=0.9, weight_decay=1e-4)

criterion = nn.CrossEntropyLoss()  # loss function

# learning-rate decay schedule
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=50, gamma=0.1)
7. train
在这里插入代码片
边栏推荐
- 我用开天平台做了一个城市防疫政策查询系统【开天aPaaS大作战】
- Dspic33ep clock initialization program
- COMSOL--三维图形的建立
- [LeetCode] Wildcard Matching 外卡匹配
- Technology sharing | common interface protocol analysis
- 如何通俗理解超级浏览器?可以用于哪些场景?有哪些品牌?
- Advanced technology management - what is the physical, mental and mental strength of managers
- Summary of websites of app stores / APP markets
- 网络五连鞭
- Spark Tuning (I): from HQL to code
猜你喜欢
Is it difficult to apply for a job after graduation? "Hundreds of days and tens of millions" online recruitment activities to solve your problems
COMSOL--三维图形的建立
Advanced technology management - what is the physical, mental and mental strength of managers
网络五连鞭
中非 钻石副石怎么镶嵌,才能既安全又好看?
OneForAll安装使用
简单解决redis cluster中从节点读取不了数据(error) MOVED
技术管理进阶——什么是管理者之体力、脑力、心力
CDGA|数据治理不得不坚持的六个原则
Ziguang zhanrui's first 5g R17 IOT NTN satellite in the world has been measured on the Internet of things
随机推荐
Manage multiple instagram accounts and share anti Association tips
How did the situation that NFT trading market mainly uses eth standard for trading come into being?
Project summary notes series wstax kt session2 code analysis
Web API configuration custom route
解决readObjectStart: expect { or n, but found N, error found in #1 byte of ...||..., bigger context ..
2048 game logic
How to make your products as expensive as possible
Solve the problem of slow access to foreign public static resources
Programmers are involved and maintain industry competitiveness
分类TAB商品流多目标排序模型的演进
【爬虫】charles unknown错误
Is it difficult to apply for a job after graduation? "Hundreds of days and tens of millions" online recruitment activities to solve your problems
go语言学习笔记-初识Go语言
紫光展锐全球首个5G R17 IoT NTN卫星物联网上星实测完成
Error assembling WAR: webxml attribute is required (or pre-existing WEB-INF/web.xml if executing in
【云原生 | Kubernetes篇】Ingress案例实战(十三)
Startup process of uboot:
I used Kaitian platform to build an urban epidemic prevention policy inquiry system [Kaitian apaas battle]
爬虫(9) - Scrapy框架(1) | Scrapy 异步网络爬虫框架
高校毕业求职难?“百日千万”网络招聘活动解决你的难题