当前位置:网站首页>多元线性回归(梯度下降法)
多元线性回归(梯度下降法)
2022-07-05 08:42:00 【python-码博士】
import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
# 读取数据
data = np.loadtxt('Delivery.csv',delimiter=',')
print(data)
# 构造特征x,目标y
# 特征
x_data = data[:,0:-1]
y_data = data[:,-1]
# 初始化学习率 (步长)
learning_rate = 0.0001
# 初始化 截距
theta0 = 0
# 初始化 系数
theta1 = 0
theta2 = 0
# 初始化最大迭代次数
n_iterables = 100
def compute_mse(theta0,theta1,theta2,x_data,y_data):
''' 计算代价函数 '''
total_error = 0
for i in range(len(x_data)):
# 计算损失 (真实值-预测值)**2
total_error += (y_data[i]-(theta0 + theta1*x_data[i,0]+theta2*x_data[i,1]))**2
mse_ = total_error/len(x_data)/2
return mse_
def gradient_descent(x_data,y_data,theta0,theta1,theta2,learning_rate,n_iterables):
''' 梯度下降法 '''
m = len(x_data)
# 循环
for i in range(n_iterables):
# 初始化theta0,theta1,theta2偏导数
theta0_grad = 0
theta1_grad = 0
theta2_grad = 0
# 计算偏导的总和再平均
# 遍历m次
for j in range(m):
theta0_grad += (1/m)*((theta1*x_data[j,0]+theta2*x_data[j,1]+theta0)-y_data[j])
theta1_grad += (1/m)*((theta1*x_data[j,0]+theta2*x_data[j,1]+theta0)-y_data[j])*x_data[j,0]
theta2_grad += (1/m)*((theta1*x_data[j,0]+theta2*x_data[j,1]+theta0)-y_data[j])*x_data[j,1]
# 更新theta
theta0 = theta0 - (learning_rate*theta0_grad)
theta1 = theta1 - (learning_rate*theta1_grad)
theta2 = theta2 - (learning_rate*theta2_grad)
return theta0,theta1,theta2
# 可视化分布
fig = plt.figure()
ax = Axes3D(fig)
ax.scatter(x_data[:,0],x_data[:,1],y_data)
plt.show()
print(f"开始:截距theta0={
theta0},theta1={
theta1},theta2={
theta2},损失={
compute_mse(theta0,theta1,theta2,x_data,y_data)}")
print("开始跑起来了~")
theta0,theta1,theta2 = gradient_descent(x_data,y_data,theta0,theta1,theta2,learning_rate,n_iterables)
print(f"迭代{
n_iterables}次后:截距theta0={
theta0},theta1={
theta1},theta2={
theta2},损失={
compute_mse(theta0,theta1,theta2,x_data,y_data)}")
# 绘制预期平面
x_0 = x_data[:,0]
x_1 = x_data[:,1]
# 生成网格矩阵
x_0,x_1 = np.meshgrid(x_0,x_1)
# y
y_hat = theta0 + theta1*x_0 +theta2*x_1
ax.plot_surface(x_0,x_1,y_hat)
# 设置标签
ax.set_xlabel('Miles')
ax.set_ylabel('nums')
ax.set_zlabel('Time')
plt.show()
边栏推荐
- Example 002: the bonus paid by the "individual income tax calculation" enterprise is based on the profit commission. When the profit (I) is less than or equal to 100000 yuan, the bonus can be increase
- 319. Bulb switch
- Xrosstools tool installation for X-Series
- Infected Tree(树形dp)
- Arduino operation stm32
- [牛客网刷题 Day4] JZ32 从上往下打印二叉树
- Guess riddles (11)
- Guess riddles (142)
- [daily training] 1200 Minimum absolute difference
- 每日一题——替换空格
猜你喜欢
每日一题——输入一个日期,输出它是该年的第几天
资源变现小程序添加折扣充值和折扣影票插件
Run菜单解析
Typical low code apaas manufacturer cases
Arduino+a4988 control stepper motor
Meizu Bluetooth remote control temperature and humidity access homeassistant
Digital analog 1: linear programming
STM32 summary (HAL Library) - DHT11 temperature sensor (intelligent safety assisted driving system)
Halcon Chinese character recognition
猜谜语啦(7)
随机推荐
319. Bulb switch
Arduino operation stm32
Task failed task_ 1641530057069_ 0002_ m_ 000000
Affected tree (tree DP)
剑指 Offer 06. 从尾到头打印链表
Yolov4 target detection backbone
leetcode - 445. Add two numbers II
Guess riddles (142)
【日常训练】1200. 最小绝对差
Digital analog 1: linear programming
[formation quotidienne - Tencent Selection 50] 557. Inverser le mot III dans la chaîne
[daily training -- Tencent selected 50] 557 Reverse word III in string
第十八章 使用工作队列管理器(一)
Illustration of eight classic pointer written test questions
【日常訓練--騰訊精選50】557. 反轉字符串中的單詞 III
Guess riddles (7)
Classification of plastic surgery: short in long long long
2022.7.4-----leetcode.1200
關於線性穩壓器的五個設計細節
TypeScript手把手教程,简单易懂