Linear regression with MXNet: from scratch and with Gluon
2022-07-26 17:10:00 【Full stack programmer webmaster】
1. Linear regression from scratch
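The code in this section generates synthetic data from a known linear model and then fits it by minibatch gradient descent on a halved mean squared error. As a reading aid, the model, the loss, and the updates that a gradient step on this loss calls for can be written as follows. The symbols X, w, b, \epsilon, \eta and m (batch size) are notation introduced here, not names from the code; the true parameters are w^{*} = (2, -3.4) and b^{*} = 4.2, and the noise has standard deviation 0.01.

```latex
\begin{aligned}
y &= Xw^{*} + b^{*} + \epsilon, \qquad \epsilon \sim \mathcal{N}(0,\ 0.01^{2}) \\
\hat{y} &= Xw + b \\
L(w, b) &= \frac{1}{2m}\sum_{i=1}^{m}\bigl(\hat{y}_i - y_i\bigr)^{2} \\
\nabla_{w} L &= \frac{1}{m}\,X^{\top}(\hat{y} - y), \qquad
\frac{\partial L}{\partial b} = \frac{1}{m}\sum_{i=1}^{m}\bigl(\hat{y}_i - y_i\bigr) \\
w &\leftarrow w - \eta\,\nabla_{w} L, \qquad
b \leftarrow b - \eta\,\frac{\partial L}{\partial b}
\end{aligned}
```

The factor 1/2 in the loss cancels the 2 produced by differentiating the square, which is why the gradient carries only 1/m.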
from mxnet import ndarray as nd
import matplotlib.pyplot as plt
import numpy as np
import time

num_inputs = 2
num_examples = 1000
w = [2,-3.4]
b = 4.2
x = nd.random.normal(scale=1,shape=(num_examples,num_inputs))
y = nd.dot(x,nd.array(w).T) + b
y += nd.random.normal(scale=0.01,shape=y.shape)
print(y.shape)

Output:
(1000,)

plt.scatter(x[:,1].asnumpy(),y.asnumpy())
plt.show()

class LinearRegressor:
    def __init__(self,input_shape,output_shape):
        self.input_shape = input_shape
        self.output_shape = output_shape
        # Small random weights, zero bias
        self.weight = nd.random.normal(scale=0.01,shape=(input_shape,1))
        self.bias = nd.zeros(shape=(1))

    def fit(self,x,y,learning_rate,epoches,batch_size):
        start = time.time()
        for epoch in range(epoches):
            for batch_data in self.batches(x,y,batch_size):
                x_batch,y_batch = batch_data[0],batch_data[1]
                y_hat = self.forward(x_batch)
                loss = self.mse(y_batch,y_hat)
                # Prediction error, reshaped to match y_hat: (batch_size, 1)
                error = y_hat - y_batch.reshape(y_hat.shape)
                self.optimizer(x_batch,error,learning_rate)
            # Report the loss on the full dataset after each epoch
            print('epoch:{},loss:{:.4f}'.format(epoch+1,self.mse(y,self.forward(x)).asscalar()))
        print('weight:',self.weight)
        print('bias:',self.bias)
        print('time interval:{:.2f}'.format(time.time() - start))

    def forward(self,x):
        return nd.dot(x,self.weight) + self.bias

    def mse(self,y,y_hat):
        # Halved mean squared error
        m = len(y)
        mean_square = nd.sum((y - y_hat.reshape(y.shape)) ** 2) / (2 * m)
        return mean_square

    def optimizer(self,x,error,learning_rate):
        # Weight gradient: X^T * error / batch size
        gradient = 1/len(x) * nd.dot(x.T,error)
        self.weight = self.weight - learning_rate * gradient
        # Bias gradient: mean error over the batch
        self.bias = self.bias - learning_rate * nd.mean(error,axis=0)

    def batches(self,x,y,batch_size):
        # Yield shuffled minibatches; a trailing partial batch is dropped
        nSamples = len(x)
        nBatches = nSamples // batch_size
        indexes = np.random.permutation(nSamples)
        for i in range(nBatches):
            yield (x[indexes[i*batch_size:(i+1)*batch_size]], y[indexes[i*batch_size:(i+1)*batch_size]])

lr = LinearRegressor(input_shape=2,output_shape=1)
lr.fit(x,y,learning_rate=0.1,epoches=20,batch_size=200)

Output of one run (exact values vary between runs because the data and the initial weights are random):
epoch:1,loss:5.7996
epoch:2,loss:2.1903
epoch:3,loss:0.9078
epoch:4,loss:0.3178
epoch:5,loss:0.0795
epoch:6,loss:0.0204
epoch:7,loss:0.0156
epoch:8,loss:0.0068
epoch:9,loss:0.0022
epoch:10,loss:0.0009
epoch:11,loss:0.0003
epoch:12,loss:0.0001
epoch:13,loss:0.0001
epoch:14,loss:0.0001
epoch:15,loss:0.0000
epoch:16,loss:0.0000
epoch:17,loss:0.0000
epoch:18,loss:0.0001
epoch:19,loss:0.0001
epoch:20,loss:0.0001
weight:
[[ 1.999662]
[-3.400079]]
<NDArray 2x1 @cpu(0)>
bias:
[4.2030163]
<NDArray 1 @cpu(0)>
time interval:0.22

2. Concise implementation of linear regression with Gluon
from mxnet import gluon
from mxnet.gluon import loss as gloss
from mxnet.gluon import data as gdata
from mxnet.gluon import nn
from mxnet import init,autograd
# Defining models
net = nn.Sequential()
net.add(nn.Dense(1))
# Initialize model parameters
net.initialize(init.Normal(sigma=0.01))
# Define the loss function
loss = gloss.L2Loss()
# Define optimization algorithms
optimizer = gluon.Trainer(net.collect_params(), 'sgd',{'learning_rate':0.1})
epoches = 20
batch_size = 200
# Get batch data
dataset = gdata.ArrayDataset(x,y)
data_iter = gdata.DataLoader(dataset,batch_size,shuffle=True)
# Training models
start = time.time()
for epoch in range(epoches):
    for batch_x,batch_y in data_iter:
        with autograd.record():
            l = loss(net(batch_x),batch_y)
        l.backward()
        # step(batch_size) rescales the accumulated gradient by 1/batch_size
        optimizer.step(batch_size)
    # Report the loss on the full dataset after each epoch
    l = loss(net(x),y)
    print('epoch:{},loss:{:.4f}'.format(epoch+1,l.mean().asscalar()))
print('weight:',net[0].weight.data())
print('bias:',net[0].bias.data())
print('time interval:{:.2f}'.format(time.time() - start))

Output of one run (exact values vary between runs because the data and the initial weights are random):
epoch:1,loss:5.7794
epoch:2,loss:1.9934
epoch:3,loss:0.6884
epoch:4,loss:0.2381
epoch:5,loss:0.0825
epoch:6,loss:0.0286
epoch:7,loss:0.0100
epoch:8,loss:0.0035
epoch:9,loss:0.0012
epoch:10,loss:0.0005
epoch:11,loss:0.0002
epoch:12,loss:0.0001
epoch:13,loss:0.0001
epoch:14,loss:0.0001
epoch:15,loss:0.0001
epoch:16,loss:0.0000
epoch:17,loss:0.0000
epoch:18,loss:0.0000
epoch:19,loss:0.0000
epoch:20,loss:0.0000
weight:
[[ 1.9996439 -3.400059 ]]
<NDArray 1x2 @cpu(0)>
bias:
[4.2002025]
<NDArray 1 @cpu(0)>
time interval:0.86

3. Appendix: loss functions and initialization methods in MXNet
- Loss functions: __all__ = ['Loss', 'L2Loss', 'L1Loss', 'SigmoidBinaryCrossEntropyLoss', 'SigmoidBCELoss', 'SoftmaxCrossEntropyLoss', 'SoftmaxCELoss', 'KLDivLoss', 'CTCLoss', 'HuberLoss', 'HingeLoss', 'SquaredHingeLoss', 'LogisticLoss', 'TripletLoss', 'PoissonNLLLoss', 'CosineEmbeddingLoss']
- Initialization methods: ['Zero', 'One', 'Constant', 'Uniform', 'Normal', 'Orthogonal', 'Xavier', 'MSRAPrelu', 'Bilinear', 'LSTMBias', 'FusedRNN']
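To give a feel for how a few of the listed regression losses differ, here is a minimal NumPy sketch of the per-sample formulas as I understand them from the Gluon documentation for L2Loss, L1Loss and HuberLoss. The function names l2_loss, l1_loss and huber_loss are hypothetical stand-ins written for this illustration, not the Gluon classes, and the default rho=1.0 is an assumption chosen to mirror the documented default.

```python
import numpy as np

def l2_loss(pred, label):
    # Half squared error per sample (the loss used in both sections above)
    return 0.5 * (pred - label) ** 2

def l1_loss(pred, label):
    # Absolute error per sample
    return np.abs(pred - label)

def huber_loss(pred, label, rho=1.0):
    # Quadratic for small errors, linear for large ones;
    # the two branches meet at |error| == rho
    diff = np.abs(pred - label)
    return np.where(diff <= rho, 0.5 / rho * diff ** 2, diff - 0.5 * rho)

# Hypothetical predictions against a zero label: no error, small error, large error
pred = np.array([0.0, 0.5, 3.0])
label = np.zeros(3)

print('L2   :', l2_loss(pred, label))
print('L1   :', l1_loss(pred, label))
print('Huber:', huber_loss(pred, label))
```

For the large error (3.0), the L2 value grows quadratically while the Huber value grows only linearly, which is the usual reason to prefer Huber when the data contain outliers.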
Publisher: Full stack programmer webmaster. When reprinting, please credit the source: https://javaforall.cn/120006.html Link to the original text: https://javaforall.cn