[LZY study notes: Dive into Deep Learning] 3.5 The image classification dataset Fashion-MNIST
2022-07-03 10:18:00 【DadongDer】
The Fashion-MNIST dataset
Dataset overview
The MNIST dataset [LeCun et al., 1998] is one of the most widely used datasets in image classification, but it is too simple to serve as a benchmark. We will use the similar but more complex Fashion-MNIST dataset [Xiao et al., 2017].
Fashion-MNIST consists of images from 10 categories. Each category has 6000 images in the training dataset (train dataset) and 1000 images in the test dataset (test dataset), so the training set and the test set contain 60000 and 10000 images respectively. The test dataset is never used for training; it is used only to evaluate model performance.
Each input image is 28 pixels high and 28 pixels wide. The dataset consists of grayscale images, so the number of channels is 1.
The 10 categories in Fashion-MNIST are t-shirt, trouser, pullover, dress, coat, sandal, shirt, sneaker, bag and ankle boot.
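The sizes quoted above follow from simple arithmetic. The sketch below recomputes them in plain Python, without downloading anything; the variable names are illustrative and do not belong to any library.

```python
# Dataset sizes implied by the description above (names are illustrative).
num_classes = 10
train_per_class = 6000
test_per_class = 1000
height = width = 28
channels = 1  # grayscale

num_train = num_classes * train_per_class
num_test = num_classes * test_per_class
values_per_image = channels * height * width

print(num_train, num_test)  # 60000 10000
print(values_per_image)     # 784
```

So each image, once flattened, is a vector of 784 values.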
A first try: version 1.0
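import time
import numpy as np  # needed by Timer.cumsum
import torch
import torchvision
from torch.utils import data
from torchvision import transforms
import matplotlib.pyplot as plt
class Timer:
    """Record multiple running times."""
    def __init__(self):
        self.times = []
        self.tik = None
        self.start()
    def start(self):
        # Start (or restart) the timer
        self.tik = time.time()
    def stop(self):
        # Stop the timer, record the elapsed time and return it
        self.times.append(time.time() - self.tik)
        return self.times[-1]
    def avg(self):
        # Average of the recorded times
        return sum(self.times) / len(self.times)
    def sum(self):
        # Total of the recorded times (sum here is the built-in function)
        return sum(self.times)
    def cumsum(self):
        # Cumulative recorded times as a list
        return np.array(self.times).cumsum().tolist()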
# step 1: read the dataset
# ToTensor converts the image data from PIL type to 32-bit floating point and divides by 255, so that all pixel values lie between 0 and 1
trans = transforms.ToTensor()
mnist_train = torchvision.datasets.FashionMNIST(root="../data", train=True, transform=trans, download=True)
mnist_test = torchvision.datasets.FashionMNIST(root="../data", train=False, transform=trans, download=True)
print(len(mnist_train), len(mnist_test))
print(mnist_train[0][0].shape)
# print(mnist_train[0])
# print(mnist_train[0][0])
# Convert numeric label indices to their text names
def get_fashion_mnist_labels(labels):
    text_labels = ['t-shirt', 'trouser', 'pullover', 'dress', 'coat', 'sandal', 'shirt', 'sneaker', 'bag', 'ankle boot']
    return [text_labels[int(i)] for i in labels]
# Visualize samples
def show_images(imgs, num_rows, num_cols, titles=None):
    _, axes = plt.subplots(num_rows, num_cols)
    axes = axes.flatten()
    for i, (ax, img) in enumerate(zip(axes, imgs)):
        if torch.is_tensor(img):
            # Image tensor
            ax.imshow(img.numpy())
        else:
            # PIL image
            ax.imshow(img)
        ax.axes.get_xaxis().set_visible(False)
        ax.axes.get_yaxis().set_visible(False)
        if titles:
            ax.set_title(titles[i])
    plt.show()
X, y = next(iter(data.DataLoader(mnist_train, batch_size=18)))
print(X.shape)
# show_images(X.reshape(18, 28, 28), 2, 9, titles=get_fashion_mnist_labels(y))
# comment show_images(...) when running the following code
# step 2: read minibatches
batch_size = 256
def get_dataloader_workers():
    # Return 4 instead to read the data with 4 worker processes
    return 0
train_iter = data.DataLoader(mnist_train, batch_size, shuffle=True, num_workers=get_dataloader_workers())
# train_iter = data.DataLoader(mnist_train, batch_size, shuffle=True)
# In each iteration, the data loader reads one minibatch of size batch_size.
# With shuffle=True, the built-in data iterator randomly shuffles all samples, so minibatches are read without bias.
timer = Timer()
for X, y in train_iter:
    continue
print(f'{timer.stop():.2f} sec')
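To make the minibatch reading above concrete, here is a minimal sketch in plain Python of shuffled minibatch iteration. It only illustrates the idea and is not how `DataLoader` is implemented; the function name `minibatch_indices` and its `seed` parameter are hypothetical. It keeps the final partial batch, which matches `DataLoader`'s default of `drop_last=False`.

```python
import math
import random

def minibatch_indices(num_samples, batch_size, shuffle=True, seed=None):
    """Yield lists of sample indices, one list per minibatch (illustrative helper)."""
    indices = list(range(num_samples))
    if shuffle:
        # Shuffle once per pass so each sample is visited exactly once
        random.Random(seed).shuffle(indices)
    for start in range(0, num_samples, batch_size):
        yield indices[start:start + batch_size]

batches = list(minibatch_indices(60000, 256, seed=0))
print(len(batches))      # 235, i.e. ceil(60000 / 256)
print(len(batches[-1]))  # 96, the final partial batch
```

With 60000 training images and a batch size of 256, one pass over the training set therefore takes 235 iterations, which is what the timing loop above measures.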
Consolidated version 2.0
import torchvision
from torch.utils import data
from torchvision import transforms
def get_dataloader_workers():
    # Return 4 instead to read the data with 4 worker processes
    return 0
# Download and read the Fashion-MNIST dataset. This function returns the data iterators for the training set and the validation set (here, the test split).
# It also accepts an optional argument resize, used to resize the images to another shape.
def load_data_fashion_mnist(batch_size, resize=None):
    # Download the dataset and load it into memory
    # ToTensor converts the image data from PIL type to 32-bit floating point and divides by 255, so that all pixel values lie between 0 and 1
    trans = [transforms.ToTensor()]
    if resize:
        # Resize must run before ToTensor, so insert it at the front
        trans.insert(0, transforms.Resize(resize))
    trans = transforms.Compose(trans)
    # print(trans)
    # Compose(
    #     Resize(size=64, interpolation=bilinear)
    #     ToTensor()
    # )
    mnist_train = torchvision.datasets.FashionMNIST(root="../data", train=True, transform=trans, download=True)
    mnist_test = torchvision.datasets.FashionMNIST(root="../data", train=False, transform=trans, download=True)
    return (data.DataLoader(mnist_train, batch_size, shuffle=True, num_workers=get_dataloader_workers()),
            data.DataLoader(mnist_test, batch_size, shuffle=False, num_workers=get_dataloader_workers()))
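# Test the image resizing option
# With batch_size=32 and resize=64 this should print:
# torch.Size([32, 1, 64, 64]) torch.float32 torch.Size([32]) torch.int64
train_iter, test_iter = load_data_fashion_mnist(32, resize=64)
for X, y in train_iter:
    print(X.shape, X.dtype, y.shape, y.dtype)
    break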
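The comments above say that `ToTensor` divides pixel values by 255 so that they fall between 0 and 1. The sketch below shows that scaling step on a tiny made-up image in plain Python. It is only an illustration of the arithmetic, not torchvision's implementation; the helper name `to_unit_floats` and the pixel values are invented for the example.

```python
def to_unit_floats(pixels):
    """Scale 8-bit pixel values (0..255) to floats in [0.0, 1.0] (illustrative helper)."""
    return [[value / 255 for value in row] for row in pixels]

# A tiny 2x3 "grayscale image" with 8-bit intensities (made-up values)
image = [[0, 128, 255],
         [64, 32, 16]]

scaled = to_unit_floats(image)
print(scaled[0][0], scaled[0][2])  # 0.0 1.0
```

Black (0) maps to 0.0 and white (255) maps to 1.0, with every other intensity in between.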