当前位置:网站首页>Handwritten character recognition
Handwritten character recognition
2022-07-29 09:05:00 【Salty salty】
import numpy as np
import torch
from torchvision.datasets import mnist # Import pytorch Built in mnist data
from torch import nn
from torch.autograd import Variable
Download the MNIST datasets
# Download the raw MNIST training and test splits (no transform applied yet).
train_set = mnist.MNIST('./data', train=True, download=True)
test_set = mnist.MNIST('./data', train=False, download=True)

# Inspect the first training sample. Without a transform, each item is an
# (image, label) pair; print explicitly so this also works outside a notebook.
a_data, a_lable = train_set[0]
print(a_data)
print(a_lable)

# The image is read in as a PIL image; convert it to a float32 NumPy array
# to look at its raw pixel values.
a_data = np.array(a_data, dtype='float32')
print(a_data.shape)
print(a_data)

def data_tf(x):
    """Turn one image into a flat, normalized float32 tensor.

    Values are divided by 255 and then mapped through (v - 0.5) / 0.5, which
    sends the range [0, 255] to [-1, 1], assuming 8-bit pixel values. The
    result is flattened to 1-D so a fully connected network can consume it.

    Args:
        x: Anything ``np.array`` accepts, e.g. a PIL image or a nested list.

    Returns:
        torch.Tensor: 1-D float32 tensor with one entry per input element.
    """
    pixels = np.array(x, dtype='float32')
    normalized = (pixels / 255 - 0.5) / 0.5
    return torch.from_numpy(normalized.reshape((-1,)))
# Re-create both splits, this time running every sample through data_tf so
# each image comes back as a flat, normalized float32 tensor.
train_set = mnist.MNIST(root='./data', train=True, transform=data_tf, download=True)
test_set = mnist.MNIST(root='./data', train=False, transform=data_tf, download=True)

# Sanity-check one transformed sample: its tensor shape and its label.
a, a_lable = train_set[0]
print(a.shape)
print(a_lable)
from torch.utils.data import DataLoader

# Batch the datasets: training batches are reshuffled each pass, while the
# test loader keeps a fixed order and uses larger batches.
train_data = DataLoader(train_set, batch_size=64, shuffle=True)
test_data = DataLoader(test_set, batch_size=128, shuffle=False)

# Peek at one batch: iter() turns the loader into an iterator and next()
# returns its first (images, labels) batch.
a, a_lable = next(iter(train_data))
print(a.shape)
print(a_lable.shape)
# Four-layer fully connected classifier built with nn.Sequential:
# 784 inputs -> 400 -> 200 -> 100 -> 10 outputs, with a ReLU after each hidden
# layer. The last layer has no activation, so the network emits raw scores
# (logits), which is what nn.CrossEntropyLoss expects.
net = nn.Sequential(
    nn.Linear(784, 400), nn.ReLU(),  # hidden layer 1
    nn.Linear(400, 200), nn.ReLU(),  # hidden layer 2
    nn.Linear(200, 100), nn.ReLU(),  # hidden layer 3
    nn.Linear(100, 10),              # output layer: one score per class
)
# Cross-entropy on the network's raw outputs, minimized with plain SGD
# (learning rate 0.1).
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(net.parameters(), 1e-1)

# Per-epoch history (mean loss / mean accuracy), kept for plotting afterwards.
losses = []
acces = []
eval_losses = []
eval_acces = []

for e in range(20):
    # ---- Training pass ----
    train_loss = 0
    train_acc = 0
    net.train()
    for im, lable in train_data:
        # Tensors are fed directly; autograd.Variable has been a no-op wrapper
        # since PyTorch 0.4.
        out = net(im)
        loss = criterion(out, lable)
        optimizer.zero_grad()
        loss.backward()
        # Apply the computed gradients to the weights. Without this call the
        # parameters never change and the network never learns.
        optimizer.step()
        train_loss += loss.item()
        _, pred = out.max(1)  # index of the highest score = predicted class
        num_correct = (pred == lable).sum().item()  # correct predictions in this batch
        acc = num_correct / im.shape[0]
        train_acc += acc
    # Epoch figures are means of per-batch values (a smaller final batch is
    # weighted the same as a full one).
    losses.append(train_loss / len(train_data))
    acces.append(train_acc / len(train_data))

    # ---- Evaluation pass ----
    eval_loss = 0
    eval_acc = 0
    net.eval()
    # No parameter updates happen here, so skip building the autograd graph.
    with torch.no_grad():
        for im, lable in test_data:
            out = net(im)
            loss = criterion(out, lable)
            eval_loss += loss.item()
            _, pred = out.max(1)
            num_correct = (pred == lable).sum().item()
            acc = num_correct / im.shape[0]
            eval_acc += acc
    eval_losses.append(eval_loss / len(test_data))
    eval_acces.append(eval_acc / len(test_data))
    print('epoch: {}, Train Loss: {:.6f}, Train Acc: {:.6f}, Eval Loss: {:.6f}, EvalAcc: {:.6f}'.format(
        e, train_loss / len(train_data), train_acc / len(train_data),
        eval_loss / len(test_data), eval_acc / len(test_data)))
We can then plot the train loss, train accuracy, test loss, and test accuracy curves.

边栏推荐
- Cloud security daily 220712: the IBM integration bus integration solution has found a vulnerability in the execution of arbitrary code, which needs to be upgraded as soon as possible
- How to choose effective keywords
- LeetCode practice (6)
- Redis series 3: highly available master-slave architecture
- One article tells you the salary after passing the PMP Exam
- Database system design: partition
- GBase 8s数据库有哪些备份恢复方式
- (视频+图文)机器学习入门系列-第3章 逻辑回归
- Unity3d learning notes (I)
- smart-webcomponents 14.2.0 Crack
猜你喜欢

WQS binary learning notes

C#: use a database to bind ListView control data

Regular expression verification version number

Leetcode: interview question 08.14. Boolean operation
![[from_bilibili_dr_can][[advanced control theory] 9_ State observer design] [learning record]](/img/9d/d9a4a3091c00ec65a9492ca49267db.png)
[from_bilibili_dr_can][[advanced control theory] 9_ State observer design] [learning record]

On the charm of code language

正则表达式校验版本号

Fastjson 1.2.24 deserialization: TemplatesImpl exploit chain analysis (very detailed)

Intel将逐步结束Optane存储业务 未来不再开发新产品

Flowable 高级篇
随机推荐
One article tells you the salary after passing the PMP Exam
access数据库可以被远程访问吗
Regular expression verification version number
(Video + graphics) Introduction to machine learning series - Chapter 3: Logistic regression
How does XJSON implement the four basic arithmetic operations?
What are the backup and recovery methods of gbase 8s database
Sublime text create page
On the charm of code language
Data is the main body of future world development, and data security should be raised to the national strategic level
md
2022危险化学品经营单位主要负责人操作证考试题库及答案
Use disco diffusion to generate AI artwork in moment pool cloud
Application of matrix transpose
CVPR 2022 | clonedperson: building a large-scale virtual pedestrian data set of real wear and wear from a single photo
One click automated data analysis! Come and have a look at these treasure tool libraries
Tesseract text recognition -- simple
2022 spsspro certification cup mathematical modeling problem B phase II scheme and post game summary
2022 electrician (elementary) test question simulation test platform operation
Flowable 基础篇2
6.2 function-parameters