Implementing softmax classification from scratch and with MXNet
2022-07-26 17:10:00 【Full stack programmer webmaster】
1. Softmax from scratch
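The classifier in this section computes softmax probabilities and a mean cross-entropy loss by hand. As a framework-free reference, the two formulas can be sketched with the Python standard library alone. The function names and toy inputs here are hypothetical, and subtracting the row maximum before exponentiating is an addition for numerical stability that the MXNet code in this section does not perform.

```python
import math

def softmax(logits):
    """Row-wise softmax; subtracts each row's max so exp() cannot overflow."""
    result = []
    for row in logits:
        shift = max(row)
        exps = [math.exp(v - shift) for v in row]
        total = sum(exps)
        result.append([e / total for e in exps])
    return result

def cross_entropy(probs, onehot):
    """Mean cross-entropy between predicted probabilities and one-hot targets."""
    total = 0.0
    for p_row, y_row in zip(probs, onehot):
        # Only the entry of the true class contributes to the sum
        total -= sum(y * math.log(p) for p, y in zip(p_row, y_row) if y > 0)
    return total / len(probs)

# Toy data (illustrative only): the second row would overflow a naive exp()
logits = [[2.0, 1.0, 0.1], [1000.0, 1000.0, 1000.0]]
targets = [[1, 0, 0], [0, 1, 0]]
probs = softmax(logits)
print(probs[1])                       # equal logits give equal probabilities
print(cross_entropy(probs, targets))
```

Without the shift, `math.exp(1000.0)` raises `OverflowError`; with it, the second row simply yields three equal probabilities.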
from mxnet.gluon import data as gdata
from sklearn import datasets
from mxnet import nd, autograd

# Load the data set
digits = datasets.load_digits()
features, labels = nd.array(digits['data']), nd.array(digits['target'])
print(features.shape, labels.shape)
labels_onehot = nd.one_hot(labels, 10)
print(labels_onehot.shape)

Output:
(1797, 64) (1797,)
(1797, 10)

class softmaxClassifier:
    def __init__(self, inputs, outputs):
        self.inputs = inputs
        self.outputs = outputs
        self.weight = nd.random.normal(scale=0.01, shape=(inputs, outputs))
        self.bias = nd.zeros(shape=(1, outputs))
        # Allocate gradient buffers so autograd can fill them in
        self.weight.attach_grad()
        self.bias.attach_grad()

    def forward(self, x):
        output = nd.dot(x, self.weight) + self.bias
        return self._softmax(output)

    def _softmax(self, x):
        # Exponentiate, then normalize each row so it sums to 1
        step1 = x.exp()
        step2 = step1.sum(axis=1, keepdims=True)
        return step1 / step2

    def _bgd(self, params, learning_rate, batch_size):
        '''
        Mini-batch gradient descent
        '''
        for param in params:
            # Use the gradients computed by MXNet autograd directly
            param[:] = param - param.grad * learning_rate / batch_size

    def loss(self, y_pred, y):
        # Mean cross-entropy over the batch; y is one-hot
        return nd.sum((-y * y_pred.log())) / len(y)

    def dataIter(self, x, y, batch_size):
        dataset = gdata.ArrayDataset(x, y)
        return gdata.DataLoader(dataset, batch_size, shuffle=True)

    def fit(self, x, y, learning_rate, epoches, batch_size):
        for epoch in range(epoches):
            for x_batch, y_batch in self.dataIter(x, y, batch_size):
                with autograd.record():
                    y_pred = self.forward(x_batch)
                    l = self.loss(y_pred, y_batch)
                l.backward()
                self._bgd([self.weight, self.bias], learning_rate, batch_size)
            # Report every 50 epochs. The label is epoch + 50, so the first
            # line printed reflects the state after the first epoch.
            if epoch % 50 == 0:
                y_all_pred = self.forward(x)
                print('epoch:{},loss:{},accuracy:{}'.format(epoch+50, self.loss(y_all_pred, y), self.accuracyScore(y_all_pred, y)))

    def predict(self, x):
        y_pred = self.forward(x)
        # Take the argmax over the class axis (axis=1): one label per sample
        return y_pred.argmax(axis=1)

    def accuracyScore(self, y_pred, y):
        acc_sum = (y_pred.argmax(axis=1) == y.argmax(axis=1)).sum().asscalar()
        return acc_sum / len(y)

sfm_clf = softmaxClassifier(64, 10)
sfm_clf.fit(features, labels_onehot, learning_rate=0.1, epoches=500, batch_size=200)

Output:
epoch:50,loss:
[1.9941667]
<NDArray 1 @cpu(0)>,accuracy:0.3550361713967724
epoch:100,loss:
[0.37214527]
<NDArray 1 @cpu(0)>,accuracy:0.9393433500278241
epoch:150,loss:
[0.25443634]
<NDArray 1 @cpu(0)>,accuracy:0.9549248747913188
epoch:200,loss:
[0.20699367]
<NDArray 1 @cpu(0)>,accuracy:0.9588202559821926
epoch:250,loss:
[0.1799827]
<NDArray 1 @cpu(0)>,accuracy:0.9660545353366722
epoch:300,loss:
[0.1619963]
<NDArray 1 @cpu(0)>,accuracy:0.9677239844184753
epoch:350,loss:
[0.14888664]
<NDArray 1 @cpu(0)>,accuracy:0.9716193656093489
epoch:400,loss:
[0.13875261]
<NDArray 1 @cpu(0)>,accuracy:0.9738452977184195
epoch:450,loss:
[0.13058177]
<NDArray 1 @cpu(0)>,accuracy:0.9760712298274903
epoch:500,loss:
[0.12379646]
<NDArray 1 @cpu(0)>,accuracy:0.9777406789092933

print('Predicted results:', sfm_clf.predict(features[:10]))
print('True results:', labels[:10])

Output:
Predicted results:
[0. 1. 2. 3. 4. 5. 6. 7. 8. 9.]
<NDArray 10 @cpu(0)>
True results:
[0. 1. 2. 3. 4. 5. 6. 7. 8. 9.]
<NDArray 10 @cpu(0)>

2. Softmax classification with MXNet Gluon
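One detail carries over from the first section into the Gluon code in this section: the loss is already averaged over the batch, and the update then divides the gradient by `batch_size` a second time. `_bgd` does this explicitly, and `Trainer.step(batch_size)` does it by normalizing gradients by 1/batch_size. The sketch below illustrates the effect on a hypothetical one-parameter squared-error model; the data, function names and model are assumptions for illustration, not part of the original code.

```python
def mean_loss_grad(w, xs, ys):
    """Gradient of 0.5 * (w*x - y)**2 with respect to w, averaged over the batch."""
    return sum((w * x - y) * x for x, y in zip(xs, ys)) / len(xs)

def sgd_step(w, grad, learning_rate, batch_size):
    """Same update rule as _bgd: the gradient is divided by batch_size once more."""
    return w - learning_rate * grad / batch_size

# Toy batch (illustrative only): targets follow y = 2 * x
xs, ys = [1.0, 2.0, 3.0, 4.0], [2.0, 4.0, 6.0, 8.0]
w0, lr, batch_size = 0.0, 0.1, len(xs)

grad = mean_loss_grad(w0, xs, ys)
w_double = sgd_step(w0, grad, lr, batch_size)  # averaged loss, rescaled again
w_single = w0 - lr * grad                      # plain SGD on the averaged loss
effective_lr = lr / batch_size                 # step size actually applied

print(w_double, w_single, effective_lr)
```

Under this reading, `learning_rate=0.1` with `batch_size=200` acts like a step size of 0.0005 on the mean gradient, which may help explain why several hundred epochs are needed. Either averaging the loss or rescaling in the update step, but not both, would apply the nominal learning rate.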
from mxnet import gluon, nd, autograd, init
from mxnet.gluon import nn, trainer, loss as gloss, data as gdata

# Define the model
net = nn.Sequential()
net.add(nn.Dense(10))

# Initialize the model parameters
net.initialize(init=init.Normal(sigma=0.01))

# Loss function (the labels are one-hot, hence sparse_label=False)
loss = gloss.SoftmaxCrossEntropyLoss(sparse_label=False)

# Optimization algorithm
optimizer = trainer.Trainer(net.collect_params(), 'sgd', {'learning_rate': 0.1})

# Training
epoches = 500
batch_size = 200
dataset = gdata.ArrayDataset(features, labels_onehot)
data_iter = gdata.DataLoader(dataset, batch_size, shuffle=True)
for epoch in range(epoches):
    for x_batch, y_batch in data_iter:
        with autograd.record():
            l = loss(net.forward(x_batch), y_batch).sum() / batch_size
        l.backward()
        optimizer.step(batch_size)
    # Report every 50 epochs (the label is epoch + 50, as in section 1)
    if epoch % 50 == 0:
        y_all_pred = net.forward(features)
        acc_sum = (y_all_pred.argmax(axis=1) == labels_onehot.argmax(axis=1)).sum().asscalar()
        print('epoch:{},loss:{},accuracy:{}'.format(epoch+50, loss(y_all_pred, labels_onehot).sum() / len(labels_onehot), acc_sum/len(y_all_pred)))

Output:
epoch:50,loss:
[2.1232333]
<NDArray 1 @cpu(0)>,accuracy:0.24652198107957707
epoch:100,loss:
[0.37193483]
<NDArray 1 @cpu(0)>,accuracy:0.9410127991096272
epoch:150,loss:
[0.25408813]
<NDArray 1 @cpu(0)>,accuracy:0.9543683917640512
epoch:200,loss:
[0.20680156]
<NDArray 1 @cpu(0)>,accuracy:0.9627156371730662
epoch:250,loss:
[0.1799252]
<NDArray 1 @cpu(0)>,accuracy:0.9666110183639399
epoch:300,loss:
[0.16203885]
<NDArray 1 @cpu(0)>,accuracy:0.9699499165275459
epoch:350,loss:
[0.14899409]
<NDArray 1 @cpu(0)>,accuracy:0.9738452977184195
epoch:400,loss:
[0.13890252]
<NDArray 1 @cpu(0)>,accuracy:0.9749582637729549
epoch:450,loss:
[0.13076076]
<NDArray 1 @cpu(0)>,accuracy:0.9755147468002225
epoch:500,loss:
[0.1239901]
<NDArray 1 @cpu(0)>,accuracy:0.9777406789092933

Publisher: Full stack programmer webmaster. If you reprint this article, please cite the source: https://javaforall.cn/120005.html Link to the original text: https://javaforall.cn