Multiple Linear Regression (Gradient Descent)
2022-07-05 08:42:00 【python-码博士】
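```python
import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D  # registers the '3d' projection on older Matplotlib versions

# Load the data: one comma-separated record per row
data = np.loadtxt('Delivery.csv', delimiter=',')
print(data)

# Build the features x and the target y
# Features: every column except the last
x_data = data[:, 0:-1]
# Target: the last column
y_data = data[:, -1]
```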
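The slicing above assumes the last column of `Delivery.csv` is the target and every earlier column is a feature. A minimal sketch on a small made-up array (the numbers are hypothetical, not taken from `Delivery.csv`) shows the shapes this produces:

```python
import numpy as np

# Hypothetical rows laid out like Delivery.csv: feature 1, feature 2, target
data = np.array([[1.0, 2.0, 3.0],
                 [4.0, 5.0, 6.0],
                 [7.0, 8.0, 9.0]])

x_data = data[:, 0:-1]  # all rows, every column except the last -> shape (3, 2)
y_data = data[:, -1]    # all rows, last column only -> shape (3,)

print(x_data.shape, y_data.shape)
```

`x_data` stays two-dimensional (samples by features), while `y_data` is a one-dimensional vector, which is why the code indexes features as `x_data[i, 0]` and targets as `y_data[i]`.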
```python
# Learning rate (step size)
learning_rate = 0.0001
# Initial intercept
theta0 = 0
# Initial coefficients
theta1 = 0
theta2 = 0
# Maximum number of iterations
n_iterables = 100
```
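The script fits a plane with intercept θ₀ and coefficients θ₁, θ₂. Writing x₁, x₂ for columns 0 and 1 of `x_data` and m for the number of samples, the model and the cost that `compute_mse` returns (half the mean squared error) are:

```latex
h_\theta(x) = \theta_0 + \theta_1 x_1 + \theta_2 x_2

J(\theta_0, \theta_1, \theta_2) = \frac{1}{2m} \sum_{i=1}^{m} \left( y^{(i)} - h_\theta\big(x^{(i)}\big) \right)^2
```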
```python
def compute_mse(theta0, theta1, theta2, x_data, y_data):
    '''Compute the cost function (half the mean squared error).'''
    total_error = 0
    for i in range(len(x_data)):
        # Squared loss: (true value - predicted value) ** 2
        total_error += (y_data[i] - (theta0 + theta1 * x_data[i, 0] + theta2 * x_data[i, 1])) ** 2
    mse_ = total_error / len(x_data) / 2
    return mse_
```
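The per-sample loop in `compute_mse` can also be written with NumPy array operations. A minimal sketch, assuming the same two-feature layout; `compute_mse_vectorized` is a hypothetical name, and the toy data is made up:

```python
import numpy as np

def compute_mse_vectorized(theta0, theta1, theta2, x_data, y_data):
    """Half mean squared error, the same quantity as compute_mse, without a Python loop."""
    # Predictions for all samples at once
    y_pred = theta0 + theta1 * x_data[:, 0] + theta2 * x_data[:, 1]
    return np.mean((y_data - y_pred) ** 2) / 2

# Hypothetical toy data: two features, one target
x_demo = np.array([[1.0, 2.0], [2.0, 0.0], [3.0, 1.0]])
y_demo = np.array([5.0, 4.0, 7.0])
print(compute_mse_vectorized(1.0, 1.0, 1.0, x_demo, y_demo))  # 1.0
```

With θ = (1, 1, 1) the predictions are 4, 3, 5, the errors 1, 1, 2, so the cost is (1 + 1 + 4) / 3 / 2 = 1.0.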
```python
def gradient_descent(x_data, y_data, theta0, theta1, theta2, learning_rate, n_iterables):
    '''Batch gradient descent.'''
    m = len(x_data)
    for i in range(n_iterables):
        # Reset the partial derivatives with respect to theta0, theta1, theta2
        theta0_grad = 0
        theta1_grad = 0
        theta2_grad = 0
        # Sum the partial derivatives over all m samples, then average
        for j in range(m):
            # Prediction minus true value for sample j
            error = (theta1 * x_data[j, 0] + theta2 * x_data[j, 1] + theta0) - y_data[j]
            theta0_grad += (1 / m) * error
            theta1_grad += (1 / m) * error * x_data[j, 0]
            theta2_grad += (1 / m) * error * x_data[j, 1]
        # Update all thetas simultaneously
        theta0 = theta0 - learning_rate * theta0_grad
        theta1 = theta1 - learning_rate * theta1_grad
        theta2 = theta2 - learning_rate * theta2_grad
    return theta0, theta1, theta2
```
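The three gradients accumulated in the inner loop are the partial derivatives of the cost J with h_θ(x) = θ₀ + θ₁x₁ + θ₂x₂; the factor 1/2 in J cancels the 2 from the chain rule. Each iteration then moves every θⱼ against its gradient, scaled by α = `learning_rate`:

```latex
J(\theta) = \frac{1}{2m} \sum_{i=1}^{m} \left( y^{(i)} - h_\theta\big(x^{(i)}\big) \right)^2

\frac{\partial J}{\partial \theta_0} = \frac{1}{m} \sum_{i=1}^{m} \left( h_\theta\big(x^{(i)}\big) - y^{(i)} \right)

\frac{\partial J}{\partial \theta_1} = \frac{1}{m} \sum_{i=1}^{m} \left( h_\theta\big(x^{(i)}\big) - y^{(i)} \right) x_1^{(i)}

\frac{\partial J}{\partial \theta_2} = \frac{1}{m} \sum_{i=1}^{m} \left( h_\theta\big(x^{(i)}\big) - y^{(i)} \right) x_2^{(i)}

\theta_j \leftarrow \theta_j - \alpha \frac{\partial J}{\partial \theta_j}, \qquad j = 0, 1, 2
```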
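```python
# Visualize the data distribution
fig = plt.figure()
# Axes3D(fig) no longer attaches itself to the figure in Matplotlib >= 3.6, so use add_subplot
ax = fig.add_subplot(projection='3d')
ax.scatter(x_data[:, 0], x_data[:, 1], y_data)
plt.show()
```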
```python
print(f"Start: intercept theta0={theta0}, theta1={theta1}, theta2={theta2}, "
      f"loss={compute_mse(theta0, theta1, theta2, x_data, y_data)}")
print("Off and running~")
theta0, theta1, theta2 = gradient_descent(x_data, y_data, theta0, theta1, theta2, learning_rate, n_iterables)
print(f"After {n_iterables} iterations: intercept theta0={theta0}, theta1={theta1}, theta2={theta2}, "
      f"loss={compute_mse(theta0, theta1, theta2, x_data, y_data)}")
```
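The same batch updates can be computed without the inner loop over samples. A minimal sketch, assuming two features as above; `gradient_descent_vectorized` is a hypothetical name, and the demo data is a made-up exact plane y = 1 + 2x₁ + 3x₂ so the result is easy to check:

```python
import numpy as np

def gradient_descent_vectorized(x_data, y_data, theta0, theta1, theta2, learning_rate, n_iterables):
    """Batch gradient descent with the same updates as gradient_descent, using NumPy arrays."""
    m = len(x_data)
    for _ in range(n_iterables):
        # h(x) - y for every sample at once
        error = theta0 + theta1 * x_data[:, 0] + theta2 * x_data[:, 1] - y_data
        theta0_grad = error.sum() / m
        theta1_grad = (error * x_data[:, 0]).sum() / m
        theta2_grad = (error * x_data[:, 1]).sum() / m
        # Simultaneous update of all three parameters
        theta0 = theta0 - learning_rate * theta0_grad
        theta1 = theta1 - learning_rate * theta1_grad
        theta2 = theta2 - learning_rate * theta2_grad
    return theta0, theta1, theta2

# Hypothetical data lying exactly on the plane y = 1 + 2*x1 + 3*x2
x_demo = np.array([[0.0, 0.0], [1.0, 0.0], [0.0, 1.0], [1.0, 1.0]])
y_demo = 1 + 2 * x_demo[:, 0] + 3 * x_demo[:, 1]
t0, t1, t2 = gradient_descent_vectorized(x_demo, y_demo, 0, 0, 0, learning_rate=0.1, n_iterables=5000)
print(t0, t1, t2)  # should be very close to 1, 2, 3
```

The demo uses a much larger learning rate than the script because its features lie in [0, 1]; with unscaled features, a step that large could make the updates diverge.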
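```python
# Plot the fitted plane together with the data (in a new figure, since the previous one was closed by plt.show())
fig = plt.figure()
ax = fig.add_subplot(projection='3d')
ax.scatter(x_data[:, 0], x_data[:, 1], y_data)
# Evenly spaced points across the range of each feature
x_0 = np.linspace(x_data[:, 0].min(), x_data[:, 0].max(), 20)
x_1 = np.linspace(x_data[:, 1].min(), x_data[:, 1].max(), 20)
# Build the grid matrices
x_0, x_1 = np.meshgrid(x_0, x_1)
# Predicted y on the grid
y_hat = theta0 + theta1 * x_0 + theta2 * x_1
ax.plot_surface(x_0, x_1, y_hat, alpha=0.5)
# Axis labels
ax.set_xlabel('Miles')
ax.set_ylabel('nums')
ax.set_zlabel('Time')
plt.show()
```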
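With a learning rate of 0.0001 and only 100 iterations, the gradient-descent estimates may still be some distance from the least-squares optimum. One way to check is to compare them with the closed-form least-squares solution. A minimal sketch using NumPy's `np.linalg.lstsq`; `fit_least_squares` is a hypothetical helper, and the demo targets are made up:

```python
import numpy as np

def fit_least_squares(x_data, y_data):
    """Closed-form least-squares fit of y = theta0 + theta1*x1 + theta2*x2 (hypothetical helper)."""
    # Prepend a column of ones so the intercept theta0 is estimated too
    X = np.column_stack([np.ones(len(x_data)), x_data])
    theta, *_ = np.linalg.lstsq(X, y_data, rcond=None)
    return theta  # array([theta0, theta1, theta2])

# Hypothetical noisy targets
x_demo = np.array([[0.0, 0.0], [1.0, 0.0], [0.0, 1.0], [1.0, 1.0]])
y_demo = np.array([1.2, 2.9, 4.1, 6.0])
print(fit_least_squares(x_demo, y_demo))
```

On the real data one would call `fit_least_squares(x_data, y_data)` and compare the result with the `theta0`, `theta1`, `theta2` printed after training; a large gap suggests more iterations, a different learning rate, or feature scaling.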