当前位置:网站首页>机器学习:numpy版本线性回归预测波士顿房价
机器学习:numpy版本线性回归预测波士顿房价
2022-07-02 23:36:00 【HanZee】
数据链接
链接: https://pan.baidu.com/s/1uDG_2IZVZCn9kndZ_ZIGaA?pwd=nec2 提取码: nec2
导入数据
import numpy as np
path = 'Desktop/波士顿房价/trian.csv'
data = np.loadtxt(path, delimiter = ",", skiprows=1)
data.shape
划分数据
train = data[:int(data.shape[0]*0.8)]
test = data[int(data.shape[0]*0.8):]
print(train.shape, test.shape)
train_x = train[:,:-1]
train_y = train[:,13:]
test_x = test[:,:-1]
test_y = test[:,13:]
print(train_x.shape, train_y.shape)
模型
class Network:
def __init__(self, num_weights):
self.num_weights = num_weights
self.w = np.random.rand(num_weights, 1)
self.b = 0
#前向计算
def forward(self, x):
z = np.dot(x, self.w) + self.b
return z
#计算loss
def loss(self, z, y):
cost = (z-y)*(z-y)
cost = np.mean(cost)
return cost
#计算梯度
def gradient(self, z, y):
w = (z-y)*train_x
w = np.mean(w, axis = 0)
w = np.array(w).reshape([13,1])
b = z-y
b = np.mean(b)
return w, b
#更新参数
def update(self, gradient_w, gradient_b, eta):
self.w = self.w - eta*gradient_w
self.b = self.b - eta*gradient_b
#训练
def train(self, items, eta):
for i in range(items):
z = self.forward(train_x)
loss = self.loss(z, train_y)
gradient_w, gradient_b = self.gradient(z, train_y)
self.update(gradient_w, gradient_b, eta)
if i%100 ==0:
test_loss = self.test()
print('item:',i,'loss:', loss, 'test_loss:', test_loss)
#测试
def test(self):
z = self.forward(test_x)
loss = self.loss(z,test_y)
return loss
net = Network(13)
net.train(1000000, eta= 6e-6)
边栏推荐
猜你喜欢
随机推荐
Which software can translate an English paper in its entirety?
How SQLSEVER removes data with duplicate IDS
Bigder:32/100 测试发现的bug开发认为不是bug怎么处理
NC50528 滑动窗口
关于XML一些介绍和注意事项
The "2022 China Digital Office Market Research Report" can be downloaded to explain the 176.8 billion yuan market in detail
在线预览Word文档
Preview word documents online
NC24840 [USACO 2009 Mar S]Look Up
写论文可以去哪些网站搜索参考文献?
关于Unity屏幕相关Screen的练习题目,Unity内部环绕某点做运动
[MCU project training] eight way answering machine
多进程编程(一):基本概念
NC24840 [USACO 2009 Mar S]Look Up
Wechat applet obtains the information of an element (height, width, etc.) and converts PX to rpx.
antv x6节点拖拽到画布上后的回调事件(踩大坑记录)
node_modules删不掉
setInterval定时器在ie不生效原因之一:回调的是箭头函数
百数不断创新,打造自由的低代码办公工具
Markdown tutorial