当前位置:网站首页>机器学习:numpy版本线性回归预测波士顿房价
机器学习:numpy版本线性回归预测波士顿房价
2022-07-02 23:36:00 【HanZee】
数据链接
链接: https://pan.baidu.com/s/1uDG_2IZVZCn9kndZ_ZIGaA?pwd=nec2 提取码: nec2
导入数据
import numpy as np
path = 'Desktop/波士顿房价/trian.csv'
data = np.loadtxt(path, delimiter = ",", skiprows=1)
data.shape
划分数据
train = data[:int(data.shape[0]*0.8)]
test = data[int(data.shape[0]*0.8):]
print(train.shape, test.shape)
train_x = train[:,:-1]
train_y = train[:,13:]
test_x = test[:,:-1]
test_y = test[:,13:]
print(train_x.shape, train_y.shape)
模型
class Network:
def __init__(self, num_weights):
self.num_weights = num_weights
self.w = np.random.rand(num_weights, 1)
self.b = 0
#前向计算
def forward(self, x):
z = np.dot(x, self.w) + self.b
return z
#计算loss
def loss(self, z, y):
cost = (z-y)*(z-y)
cost = np.mean(cost)
return cost
#计算梯度
def gradient(self, z, y):
w = (z-y)*train_x
w = np.mean(w, axis = 0)
w = np.array(w).reshape([13,1])
b = z-y
b = np.mean(b)
return w, b
#更新参数
def update(self, gradient_w, gradient_b, eta):
self.w = self.w - eta*gradient_w
self.b = self.b - eta*gradient_b
#训练
def train(self, items, eta):
for i in range(items):
z = self.forward(train_x)
loss = self.loss(z, train_y)
gradient_w, gradient_b = self.gradient(z, train_y)
self.update(gradient_w, gradient_b, eta)
if i%100 ==0:
test_loss = self.test()
print('item:',i,'loss:', loss, 'test_loss:', test_loss)
#测试
def test(self):
z = self.forward(test_x)
loss = self.loss(z,test_y)
return loss
net = Network(13)
net.train(1000000, eta= 6e-6)
边栏推荐
- Automated defect analysis in electron microscopic images-论文阅读笔记
- FRP reverse proxy +msf get shell
- Preview word documents online
- Why is the website slow to open?
- Multiprocess programming (II): Pipeline
- v8
- Array de duplication
- kubernetes资源对象介绍及常用命令(五)-(NFS&PV&PVC)
- Feature Engineering: summary of common feature transformation methods
- Nc17059 queue Q
猜你喜欢
UART、RS232、RS485、I2C和SPI的介绍
Sysdig analysis container system call
【单片机项目实训】八路抢答器
kubernetes资源对象介绍及常用命令(五)-(NFS&PV&PVC)
Hundreds of continuous innovation to create free low code office tools
MySQL 23道经典面试吊打面试官
Introduction of UART, RS232, RS485, I2C and SPI
Automated defect analysis in electronic microscopic images
pod生命周期详解
Redis21 classic interview questions, extreme pull interviewer
随机推荐
DotNet圈里一个优秀的ORM——FreeSql
NC24840 [USACO 2009 Mar S]Look Up
Introduction and use of ftrace tool
Linux软件:如何安装Redis服务
论文的英文文献在哪找(除了知网)?
腾讯云免费SSL证书扩展文件含义
Confluence的PDF导出中文文档异常显示问题解决
[IELTS reading] Wang Xiwei reading P2 (reading fill in the blank)
Multiprocess programming (II): Pipeline
logback配置文件
Free we media essential tools sharing
Bloom filter
Seckill system design
Sysdig analysis container system call
NC20806 区区区间间间
Bigder: how to deal with the bugs found in the 32/100 test if they are not bugs
NC17059 队列Q
【雅思阅读】王希伟阅读P1(阅读判断题)
How do educators find foreign language references?
多进程编程(四):共享内存