3. Classification problems: a first look at handwritten digit recognition
2022-07-27 05:58:00 【Pie star's favorite spongebob】
Introduction to the classification problem
MNIST
Each picture is 28*28 pixels.
The [28,28] matrix is flattened: [28,28] -> [784] -> [1,784]. The two image dimensions become one, and the leading 1 is the number of pictures.
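The flattening step can be sketched in PyTorch as follows. This is a minimal illustration that assumes a random tensor as a stand-in for a real MNIST picture; the variable names are made up for the example.

```python
import torch

# A stand-in for one 28x28 grayscale picture (random values, not real MNIST data).
image = torch.rand(28, 28)

# [28, 28] -> [784]: the rows are laid out one after another.
flat = image.view(-1)

# [784] -> [1, 784]: add a leading dimension for the number of pictures.
batch = flat.unsqueeze(0)

print(image.shape, flat.shape, batch.shape)
```

Pixel (row 3, column 5) of the picture ends up at position 3*28+5 of the flattened vector, so no information is lost, only the layout changes.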
loss
H3: [1,10]. The first dimension, 1, is the number of pictures; the second, 10, is the number of classes (the digits 0-9).
Y: a label in 0-9, encoded with one_hot. For example, the label 3 becomes [0,0,0,1,0,0,0,0,0,0].
Suppose the digit in the picture is 5. H3 might then be [0.1,0.01,0.01,0.03,0.01,0.8,0.01,0.01,0.01,0.01], while the actual Y should be [0,0,0,0,0,1,0,0,0,0].
| Digit | 0 | 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 |
|---|---|---|---|---|---|---|---|---|---|---|
| H3 | 0.1 | 0.01 | 0.01 | 0.03 | 0.01 | 0.8 | 0.01 | 0.01 | 0.01 | 0.01 |
| Y | 0 | 0 | 0 | 0 | 0 | 1 | 0 | 0 | 0 | 0 |
The loss is computed with the Euclidean distance formula: the sum of the squared differences between H3 and Y.
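As a rough sketch of the encoding and the loss for the example with the digit 5: the `one_hot` function here is a hypothetical helper written for this illustration (the post calls a function of that name later but does not show its definition), and the loss is the plain sum of squared differences.

```python
import torch


def one_hot(label, depth=10):
    # Hypothetical helper: turns integer labels of shape [b]
    # into one-hot rows of shape [b, depth].
    out = torch.zeros(label.size(0), depth)
    idx = label.long().view(-1, 1)
    out.scatter_(1, idx, 1.0)  # write a 1 at each label's column
    return out


# The example output for a picture of the digit 5.
h3 = torch.tensor([[0.1, 0.01, 0.01, 0.03, 0.01, 0.8, 0.01, 0.01, 0.01, 0.01]])
y = one_hot(torch.tensor([5]))

# Sum of squared differences between the output and the one-hot target.
loss = ((h3 - y) ** 2).sum()

print(y)
print(loss.item())
```

The largest contribution comes from position 5, where (0.8 - 1)**2 = 0.04; the total is roughly 0.0516, and it shrinks as H3 moves closer to Y.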
gradient descent
H1 = relu(W1*x + b1)
H2 = relu(W2*H1 + b2)
H3 = W3*H2 + b3
pred = H3
loss = Σ(pred - Y)**2
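One gradient descent step on these formulas might look like the sketch below. It makes several assumptions: the hidden sizes 256 and 64 are borrowed from the `Net` class later in the post, the input is random rather than a real picture, the 0.01 weight scale and the `init` and `compute_loss` names are arbitrary, and the input is a row vector, so the product is written `x @ w1` instead of W1*x.

```python
import torch

torch.manual_seed(0)  # fixed seed so the run is repeatable

x = torch.rand(1, 784)  # one flattened picture (random stand-in)
y = torch.zeros(1, 10)
y[0, 5] = 1.0           # one-hot target for the digit 5


def init(rows, cols):
    # Small random weights; the 0.01 scale is an arbitrary choice for this sketch.
    return (torch.randn(rows, cols) * 0.01).requires_grad_()


w1, b1 = init(784, 256), torch.zeros(256, requires_grad=True)
w2, b2 = init(256, 64), torch.zeros(64, requires_grad=True)
w3, b3 = init(64, 10), torch.zeros(10, requires_grad=True)
params = [w1, b1, w2, b2, w3, b3]


def compute_loss():
    h1 = torch.relu(x @ w1 + b1)
    h2 = torch.relu(h1 @ w2 + b2)
    pred = h2 @ w3 + b3
    return ((pred - y) ** 2).sum()


lr = 0.01
loss_before = compute_loss()
loss_before.backward()        # gradients of the loss w.r.t. every parameter

with torch.no_grad():
    for p in params:
        p -= lr * p.grad      # w' = w - lr * grad
        p.grad.zero_()

loss_after = compute_loss()
print(loss_before.item(), loss_after.item())
```

With a small learning rate the loss after the step is lower than before it, which is the whole idea: repeat the step many times and the parameters drift toward values that fit the labels.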
result
argmax(pred): returns the index of the maximum value, which is the predicted digit.
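Applied to the example output for the digit 5, a minimal sketch of this step is:

```python
import torch

# The example network output from the loss discussion (one picture, ten classes).
pred = torch.tensor([[0.1, 0.01, 0.01, 0.03, 0.01, 0.8, 0.01, 0.01, 0.01, 0.01]])

# dim=1 searches across the ten class scores of each picture.
digit = pred.argmax(dim=1)

print(digit)  # tensor([5])
```

The result has one entry per picture, so a batch of shape [b,10] yields predictions of shape [b].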
A first look at handwritten digit recognition
1. Data loading (Load data)
import torch
import torchvision
from torch import nn, optim
from torch.nn import functional as F

# plot_image, plot_curve and one_hot are helper functions that this post
# uses but does not define; they are assumed to be available from elsewhere.

batch_size = 512

train_loader = torch.utils.data.DataLoader(
    torchvision.datasets.MNIST('mnist_data', train=True, download=True,
                               transform=torchvision.transforms.Compose([
                                   # Convert the numpy-format image to a tensor
                                   torchvision.transforms.ToTensor(),
                                   # Normalization: values centered near 0 can improve performance
                                   torchvision.transforms.Normalize(
                                       (0.1307,), (0.3081,))
                               ])),
    batch_size=batch_size, shuffle=True)

test_loader = torch.utils.data.DataLoader(
    torchvision.datasets.MNIST('mnist_data/', train=False, download=True,
                               transform=torchvision.transforms.Compose([
                                   torchvision.transforms.ToTensor(),
                                   torchvision.transforms.Normalize(
                                       (0.1307,), (0.3081,))
                               ])),
    batch_size=batch_size, shuffle=False)

# Take one batch and inspect its shapes and value range
x, y = next(iter(train_loader))
print(x.shape, y.shape, x.min(), x.max())
plot_image(x, y, 'image sample')
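The effect of the normalization step can be checked by hand. `Normalize(mean, std)` computes (x - mean) / std per channel, and `ToTensor` has already scaled the pixels to the range 0 to 1. The sketch below applies the same arithmetic to three sample pixel values; 0.1307 and 0.3081 are the values from the code above, commonly quoted as the mean and standard deviation of the MNIST training pixels.

```python
import torch

mean, std = 0.1307, 0.3081

# Three sample pixel values after ToTensor: black, the dataset mean, white.
pixels = torch.tensor([0.0, 0.1307, 1.0])

# The same arithmetic that Normalize((0.1307,), (0.3081,)) applies.
normalized = (pixels - mean) / std

print(normalized)
```

Black pixels map to about -0.42 and white pixels to about 2.82, which matches the kind of range `x.min()` and `x.max()` report for a loaded batch.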
2. Building the model (Build Model)
class Net(nn.Module):

    def __init__(self):
        super(Net, self).__init__()
        # Each layer computes wx+b
        # 28*28 is the dimension of x; 256 is generally chosen from experience.
        # The large dimension is reduced to a smaller one.
        self.fc1 = nn.Linear(28*28, 256)
        # The input of the second layer matches the output of the previous layer
        self.fc2 = nn.Linear(256, 64)
        # 10 classes; this one is fixed by the problem, not chosen from experience
        self.fc3 = nn.Linear(64, 10)

    # The computation
    def forward(self, x):
        # x: [b, 784] (the caller flattens [b,1,28,28] before calling the network)
        # h1 = relu(xw1+b1)
        x = F.relu(self.fc1(x))
        # h2 = relu(h1w2+b2)
        x = F.relu(self.fc2(x))
        # h3 = h2w3+b3; whether the last layer gets an activation depends on the task
        x = self.fc3(x)
        return x
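A quick way to sanity-check a model like this is to push a dummy batch through it and count its parameters. The sketch below repeats the class so that it runs on its own; the batch of 4 random pictures is an assumption made for the check and is not real data.

```python
import torch
from torch import nn
import torch.nn.functional as F


class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(28 * 28, 256)
        self.fc2 = nn.Linear(256, 64)
        self.fc3 = nn.Linear(64, 10)

    def forward(self, x):
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        return self.fc3(x)


net = Net()

# A dummy batch of 4 pictures in the loader's layout [b, 1, 28, 28].
dummy = torch.rand(4, 1, 28, 28)
out = net(dummy.view(dummy.size(0), 28 * 28))

# Weights plus biases of the three layers:
# (784*256 + 256) + (256*64 + 64) + (64*10 + 10) = 218058
n_params = sum(p.numel() for p in net.parameters())

print(out.shape, n_params)
```

The output shape [4,10] confirms one score per class for each picture, and almost all of the parameters sit in the first layer.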
3. Training (Train)
# train
# net.parameters() returns [w1,b1,w2,b2,w3,b3], the parameters to optimize;
# lr is the learning rate (step size); momentum helps the optimization
net = Net()
optimizer = optim.SGD(net.parameters(), lr=0.01, momentum=0.9)
# Keep the loss values for plotting
train_loss = []
for epoch in range(3):
    for batch_idx, (x, y) in enumerate(train_loader):
        # x: [b,1,28,28], y: [b] (b is 512 for a full batch)
        # Flatten [b,1,28,28] into [b,feature]; size(0) is the batch size
        x = x.view(x.size(0), 28*28)
        # out: [b,10]
        out = net(x)
        # y_onehot: [b,10], the actual labels
        y_onehot = one_hot(y)
        # loss = mse(out, y_onehot), the mean squared error
        loss = F.mse_loss(out, y_onehot)
        # Zero the gradients
        optimizer.zero_grad()
        # Compute the gradients
        loss.backward()
        # Update the parameters: w' = w - lr*grad
        optimizer.step()
        # Record the loss to visualize gradient descent
        train_loss.append(loss.item())
        if batch_idx % 10 == 0:
            print(epoch, batch_idx, loss.item())
plot_curve(train_loss)
# We now have the optimized [w1,b1,w2,b2,w3,b3]
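The comment `w' = w - lr*grad` describes plain gradient descent; with `momentum=0.9` the optimizer also keeps a running velocity of past gradients. The sketch below compares `optim.SGD` with a hand-written version of that update on a toy loss. The toy loss, the starting values, and the `velocity` and `w_manual` names are assumptions made for the illustration.

```python
import torch
from torch import optim

lr, momentum = 0.01, 0.9

# Toy problem: minimize the sum of w**2, whose gradient is 2*w.
w = torch.tensor([1.0, -2.0], requires_grad=True)
optimizer = optim.SGD([w], lr=lr, momentum=momentum)

# Hand-written copy of the same update, tracked separately.
w_manual = w.detach().clone()
velocity = torch.zeros_like(w_manual)

for _ in range(3):
    # Step taken by the optimizer.
    loss = (w ** 2).sum()
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    # The same step by hand: accumulate the gradient, then move along it.
    grad = 2 * w_manual
    velocity = momentum * velocity + grad
    w_manual = w_manual - lr * velocity

print(w.detach(), w_manual)
```

Both versions end at the same values. Because the velocity carries over part of the previous steps, consecutive gradients that point the same way produce larger moves than plain gradient descent would.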

4. Testing (Test)
total_correct = 0
for x, y in test_loader:
    x = x.view(x.size(0), 28*28)
    # out: [b,10]
    out = net(x)
    # out -> pred: [b]
    pred = out.argmax(dim=1)
    # Count the correct predictions in this batch; the sum is still a tensor,
    # so item() converts it to a Python number
    correct = pred.eq(y).sum().float().item()
    total_correct += correct

total_num = len(test_loader.dataset)
acc = total_correct / total_num
print('test acc', acc)

# Show the predictions for one test batch
x, y = next(iter(test_loader))
out = net(x.view(x.size(0), 28*28))
pred = out.argmax(dim=1)
plot_image(x, pred, 'test')
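The accuracy calculation can be followed on a tiny hand-made batch. The scores and labels below are made up for the illustration, with 3 classes instead of 10 to keep it short.

```python
import torch

# Made-up scores for 4 pictures and 3 classes.
out = torch.tensor([[0.1, 0.9, 0.0],
                    [0.8, 0.1, 0.1],
                    [0.2, 0.3, 0.5],
                    [0.6, 0.3, 0.1]])
y = torch.tensor([1, 0, 2, 2])  # made-up true labels

pred = out.argmax(dim=1)           # index of the largest score per picture
correct = pred.eq(y).sum().item()  # number of positions where pred equals y
acc = correct / y.size(0)

print(pred, correct, acc)
```

Three of the four predictions match their labels, so the accuracy is 0.75. The test loop does the same per batch and divides the running total by the size of the whole test set.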
