当前位置:网站首页>第二讲 Linear Model 线性模型
第二讲 Linear Model 线性模型
2022-08-05 05:13:00 【长路漫漫 大佬为伴】
参考资料
- 一句话解释numpy.meshgrid()
- matplotlib教程之——自定义配置文件和绘图风格(rcParams和style)
- python中zip()函数的用法
- matplotlib之plot()详解
- matplotlib 3D绘图警告
课堂练习
实现线性模型y=wx的平面图
import numpy as np
import matplotlib.pyplot as plt
#保存数据集,相同的索引为一个样本
x_data = [1.0, 2.0, 3.0]
y_data = [2.0, 4.0, 6.0]
#模型的前馈
def forward(x):
return x * w
#损失函数
def loss(x, y):
y_pred = forward(x) #根据前馈求y_hat
return (y_pred - y) ** 2 #计算损失
# 穷举法
w_list = [] #权重
mse_list = [] #权重对应的损失值
for w in np.arange(0.0, 4.1, 0.1):
print("w=", w)
l_sum = 0
#从x_data, y_data取出x_val, y_val
for x_val, y_val in zip(x_data, y_data):
y_pred_val = forward(x_val)
loss_val = loss(x_val, y_val)
l_sum += loss_val
print('x_val==', x_val, 'y_val==',y_val, 'y_pred_val==',y_pred_val,'loss_val==', loss_val)
print('MSE=', l_sum / 3)
w_list.append(w)
mse_list.append(l_sum / 3)
#调用画图
plt.plot(w_list, mse_list)
plt.ylabel('Loss')
plt.xlabel('w')
plt.show()
图案轨迹
课后练习
实现线性模型(y=wx+b)并输出loss的3D图像
这里存在几个问题需要解决
1.w,b的取值
之前课堂练习中,只需要取一个w,因此可以用for循环取值。课后练习中需要对w,b两个值进行取值操作,因此要使用meshgrid函数
2.图像无法显示中文
在前方加上
from pylab import * mpl.rcParams[‘font.sans-serif’] = [‘SimHei’]
3.matplotlib 3D绘图警告
matplotlib 3D绘图警告
课后习题代码:
import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
from pylab import *
mpl.rcParams['font.sans-serif'] = ['SimHei']
#这里设函数为y=3x+2
x_data = [1.0,2.0,3.0]
y_data = [5.0,8.0,11.0]
def forward(x):
return x * w + b
def loss(x,y):
y_pred = forward(x)
return (y_pred-y)*(y_pred-y)
mse_list = []
W=np.arange(0.0,4.1,0.1)
B=np.arange(0.0,4.1,0.1)
w,b=np.meshgrid(W,B)
# print("w==",w)
# print('b==',b)
l_sum = 0
for x_val, y_val in zip(x_data, y_data):
y_pred_val = forward(x_val)
loss_val = loss(x_val, y_val)
print('x_val==', x_val,'\ny_val==', y_val,'\ny_pred_val==', y_pred_val, '\nloss_val==',loss_val)
l_sum += loss_val
fig = plt.figure()
ax = Axes3D(fig,auto_add_to_figure=False)
fig.add_axes(ax)
ax.plot_surface(w, b, l_sum/3)
ax.set_xlabel("权重 W")
ax.set_ylabel("偏置项 B")
ax.set_zlabel("损失值")
plt.show()
3D图:
边栏推荐
- The role of DataContext in WPF
- LeetCode:1403. 非递增顺序的最小子序列【贪心】
- 类的底层机制
- 【过一下10】sklearn使用记录
- How to deal with DNS hijacking?
- Community Sharing|Tencent Overseas Games builds game security operation capabilities based on JumpServer
- Flutter real machine running and simulator running
- [Nine Lectures on Backpacks - 01 Backpack Problems]
- 电话溥功能
- Wise Force Deleter强制删除工具
猜你喜欢
随机推荐
WPF中DataContext作用
How to deal with DNS hijacking?
u-boot in u-boot, dm-pre-reloc
算法---一和零(Kotlin)
Flutter Learning 4 - Basic UI Components
How to quickly upgrade your Taobao account to a higher level
多线程查询结果,添加List集合
【过一下9】卷积
u-boot debugging and positioning means
flex布局青蛙游戏通关攻略
Transformation 和 Action 常用算子
Multi-threaded query results, add List collection
【cesium】元素高亮显示
Flutter learning 2-dart learning
人性的弱点
OFDM 十六讲 5 -Discrete Convolution, ISI and ICI on DMT/OFDM Systems
[Student Graduation Project] Design and Implementation of the Website Based on the Web Student Information Management System (13 pages)
淘宝账号如何快速提升到更高等级
Flutter学习4-基本UI组件
Error creating bean with name 'configDataContextRefresher' defined in class path resource