当前位置:网站首页>【load dataset】
【load dataset】
2022-07-05 11:48:00 【Network starry sky (LUOC)】
List of articles
1.generate data txt file
# coding:utf-8
import os
''' Generate corresponding txt file '''
# NOTE(review): the "../../../Data" layout assumes this script lives three
# directory levels below the folder containing Data/ -- confirm against the
# actual repository layout before running.
train_txt_path = os.path.join("../..", "..", "Data", "train.txt")
train_dir = os.path.join("../..", "..", "Data", "train")
valid_txt_path = os.path.join("../..", "..", "Data", "valid.txt")
valid_dir = os.path.join("../..", "..", "Data", "valid")
def gen_txt(txt_path, img_dir):
    """Write an index file mapping each .png under img_dir to its label.

    Walks the class sub-folders of ``img_dir`` (one folder per class) and
    writes one line per image to ``txt_path`` in the form::

        <image_path> <label>

    The label is taken from the file-name prefix before the first ``_``
    (e.g. ``3_0042.png`` -> label ``3``).

    Parameters
    ----------
    txt_path : str
        Destination text file (overwritten if it already exists).
    img_dir : str
        Root directory containing one sub-folder per class.
    """
    # ``with`` guarantees the file is closed even if os.walk/os.listdir
    # raises; the original opened/closed manually and leaked on error.
    with open(txt_path, 'w') as f:
        for root, sub_dirs, _ in os.walk(img_dir, topdown=True):
            for sub_dir in sub_dirs:
                class_dir = os.path.join(root, sub_dir)
                for fname in os.listdir(class_dir):
                    # '.png' with the dot, so a name like 'xpng' is skipped
                    if not fname.endswith('.png'):
                        continue
                    label = fname.split('_')[0]
                    f.write(os.path.join(class_dir, fname) + ' ' + label + '\n')
if __name__ == '__main__':
    # Build the train/valid index files only when run as a script.
    gen_txt(train_txt_path, train_dir)
    gen_txt(valid_txt_path, valid_dir)
2.MyDataset class
# coding: utf-8
from PIL import Image
from torch.utils.data import Dataset
class MyDataset(Dataset):
    """Image dataset backed by a plain-text index file.

    Each line of ``txt_path`` must be ``<image_path> <integer_label>``
    (the format produced by ``gen_txt``).  Images are loaded lazily in
    ``__getitem__``, so only the (path, label) list is kept in memory.
    """

    def __init__(self, txt_path, transform=None, target_transform=None):
        # ``with`` closes the index file even on a parse error; the
        # original left the handle open for the object's lifetime.
        imgs = []
        with open(txt_path, 'r') as fh:
            for line in fh:
                words = line.rstrip().split()
                # words[0] = image path, words[1] = integer class label
                imgs.append((words[0], int(words[1])))
        # DataLoader feeds indices into __getitem__ against this list.
        self.imgs = imgs
        self.transform = transform
        self.target_transform = target_transform

    def __getitem__(self, index):
        fn, label = self.imgs[index]
        # Pixel values are 0~255; transforms.ToTensor (if supplied via
        # ``transform``) rescales them to 0~1.
        img = Image.open(fn).convert('RGB')
        if self.transform is not None:
            img = self.transform(img)
        # Bug fix: target_transform was accepted and stored but never
        # applied, so callers passing one silently got untransformed labels.
        if self.target_transform is not None:
            label = self.target_transform(label)
        return img, label

    def __len__(self):
        return len(self.imgs)
3.instance MyDataset
# NOTE(review): trainTransform / validTransform must be defined earlier
# (e.g. torchvision.transforms.Compose pipelines) -- not shown in this file.
train_data = MyDataset(txt_path=train_txt_path, transform=trainTransform)
valid_data = MyDataset(txt_path=valid_txt_path, transform=validTransform)
4.compute mean
# coding: utf-8
import numpy as np
import cv2
import random
import os
""" Random selection CNum A picture , Calculate the mean value by channel mean And standard deviation std First, change the pixels from 0~255 Normalize to 0-1 Calculate again """
train_txt_path = os.path.join("../..", "..", "Data/train.txt")
CNum = 2000 # How many pictures are selected for calculation
img_h, img_w = 32, 32
imgs = np.zeros([img_w, img_h, 3, 1])
means, stdevs = [], []
with open(train_txt_path, 'r') as f:
lines = f.readlines()
random.shuffle(lines) # shuffle , Randomly selected pictures
for i in range(CNum):
img_path = lines[i].rstrip().split()[0]
img = cv2.imread(img_path)
img = cv2.resize(img, (img_h, img_w))
img = img[:, :, :, np.newaxis]
imgs = np.concatenate((imgs, img), axis=3)
print(i)
imgs = imgs.astype(np.float32)/255.
for i in range(3):
pixels = imgs[:,:,i,:].ravel() # Pull into a line
means.append(np.mean(pixels))
stdevs.append(np.std(pixels))
means.reverse() # BGR --> RGB
stdevs.reverse()
print("normMean = {}".format(means))
print("normStd = {}".format(stdevs))
print('transforms.Normalize(normMean = {}, normStd = {})'.format(means, stdevs))
5.DataLoder
# pytorch dataset
#import torchvision
#train_data = torchvision.datasets.CIFAR10(root = './data/', train = True, transform = trainTransform, download = True)
#valid_data = torchvision.datasets.CIFAR10(root = './data/', train = False, transform = trainTransform, download = True)
# structure DataLoder
# NOTE(review): DataLoader, train_bs and valid_bs must be imported/defined
# earlier (torch.utils.data.DataLoader; integer batch sizes) -- not shown here.
train_loader = DataLoader(dataset=train_data, batch_size=train_bs, shuffle=True)
valid_loader = DataLoader(dataset=valid_data, batch_size=valid_bs)  # no shuffle for validation
6.net loss optimizer scheduler
# Commented-out alternative: use a pretrained torchvision model instead of
# the custom Net defined below.
#import torchvision.models as models
#net = models.vgg() # Create a network
# ------------------------------------ step 3/5 : Define the loss function and optimizer ------------------------------------
#criterion = nn.CrossEntropyLoss() # Choose the loss function
#optimizer = optim.SGD(net.parameters(), lr=lr_init, momentum=0.9, dampening=0.1) # Choose the optimizer
#scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=50, gamma=0.1) # Set learning rate reduction strategy
class Net(nn.Module):
    """LeNet-style CNN for 32x32 RGB input with 10 output classes.

    conv1(3->6,5) -> pool -> conv2(6->16,5) -> pool
    -> fc1(400->120) -> fc2(120->84) -> fc2_1(84->40) -> fc3(40->10)
    """

    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool1 = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.pool2 = nn.MaxPool2d(2, 2)
        # 16 * 5 * 5 is the conv stack's output flattened for 32x32 input
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc2_1 = nn.Linear(84, 40)
        self.fc3 = nn.Linear(40, 10)

    def forward(self, x):
        x = self.pool1(F.relu(self.conv1(x)))
        x = self.pool2(F.relu(self.conv2(x)))
        x = x.view(-1, 16 * 5 * 5)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        # Bug fix: the original referenced self.fc2_2, which is never
        # declared (the layer is named fc2_1) -> AttributeError on first
        # forward pass.
        x = F.relu(self.fc2_1(x))
        x = self.fc3(x)   # raw logits; CrossEntropyLoss applies log-softmax
        return x

    def initialize_weights(self):
        """Re-initialize parameters: Xavier-normal for convs, N(0, 0.01)
        for linear layers, (weight=1, bias=0) for batch-norm layers."""
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                torch.nn.init.xavier_normal_(m.weight.data)
                if m.bias is not None:
                    m.bias.data.zero_()
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()
            elif isinstance(m, nn.Linear):
                torch.nn.init.normal_(m.weight.data, 0, 0.01)
                m.bias.data.zero_()
net = Net()   # instantiate the network

# ================================ #
#   finetune: initialize from pretrained weights
# ================================ #
# NOTE(review): torch.load unpickles arbitrary objects -- only load
# checkpoint files from a trusted source.
pretrained_dict = torch.load('net_params.pkl')
# Current model's parameter dictionary.
model_state = net.state_dict()
# Keep only the pretrained entries whose names also exist in this model.
overlap = {name: tensor for name, tensor in pretrained_dict.items()
           if name in model_state}
model_state.update(overlap)
# "Pour" the merged dictionary back into the new model.
net.load_state_dict(model_state)

# ------------------------------------ step 3/5 : Define the loss function and optimizer ------------------------------------
# ================================= #
#   per-layer learning rate: fc3 trains at 10x the base rate
# ================================= #
fc3_param_ids = {id(p) for p in net.fc3.parameters()}
base_params = [p for p in net.parameters() if id(p) not in fc3_param_ids]
optimizer = optim.SGD(
    [{'params': base_params},
     {'params': net.fc3.parameters(), 'lr': lr_init * 10}],
    lr_init, momentum=0.9, weight_decay=1e-4)
criterion = nn.CrossEntropyLoss()   # loss function
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=50, gamma=0.1)  # LR decay schedule
7. train
(Insert the training code snippet here.)
边栏推荐
猜你喜欢
随机推荐
redis集群中hash tag 使用
[upsampling method opencv interpolation]
Linux: prohibit ping; explaining the different usages of traceroute
Pytorch linear regression
基于Lucene3.5.0怎样从TokenStream获得Token
[cloud native | kubernetes] actual battle of ingress case (13)
投资理财适合女生吗?女生可以买哪些理财产品?
Pytorch weight decay and dropout
【load dataset】
Evolution of multi-objective sorting model for classified tab commodity flow
Splunk configuration 163 mailbox alarm
谜语1
Prevent browser backward operation
Advanced technology management - what is the physical, mental and mental strength of managers
网络五连鞭
Harbor镜像仓库搭建
【TFLite, ONNX, CoreML, TensorRT Export】
【主流Nivida显卡深度学习/强化学习/AI算力汇总】
[crawler] bugs encountered by wasm
COMSOL -- establishment of 3D graphics