当前位置:网站首页>第二讲 Linear Model 线性模型
第二讲 Linear Model 线性模型
2022-08-05 05:13:00 【长路漫漫 大佬为伴】
参考资料
- 一句话解释numpy.meshgrid()
- matplotlib教程之——自定义配置文件和绘图风格(rcParams和style)
- python中zip()函数的用法
- matplotlib之plot()详解
- matplotlib 3D绘图警告
课堂练习
实现线性模型y=wx的平面图
import numpy as np
import matplotlib.pyplot as plt
#保存数据集,相同的索引为一个样本
x_data = [1.0, 2.0, 3.0]
y_data = [2.0, 4.0, 6.0]
#模型的前馈
def forward(x):
return x * w
#损失函数
def loss(x, y):
y_pred = forward(x) #根据前馈求y_hat
return (y_pred - y) ** 2 #计算损失
# 穷举法
w_list = [] #权重
mse_list = [] #权重对应的损失值
for w in np.arange(0.0, 4.1, 0.1):
print("w=", w)
l_sum = 0
#从x_data, y_data取出x_val, y_val
for x_val, y_val in zip(x_data, y_data):
y_pred_val = forward(x_val)
loss_val = loss(x_val, y_val)
l_sum += loss_val
print('x_val==', x_val, 'y_val==',y_val, 'y_pred_val==',y_pred_val,'loss_val==', loss_val)
print('MSE=', l_sum / 3)
w_list.append(w)
mse_list.append(l_sum / 3)
#调用画图
plt.plot(w_list, mse_list)
plt.ylabel('Loss')
plt.xlabel('w')
plt.show()
图案轨迹
课后练习
实现线性模型(y=wx+b)并输出loss的3D图像
这里存在几个问题需要解决
1.w,b的取值
之前课堂练习中,只需要取一个w,因此可以用for循环取值。课后练习中需要对w,b两个值进行取值操作,因此要使用meshgrid函数
2.图像无法显示中文
在前方加上
from pylab import * mpl.rcParams[‘font.sans-serif’] = [‘SimHei’]
3.matplotlib 3D绘图警告
matplotlib 3D绘图警告
课后习题代码:
import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
from pylab import *
mpl.rcParams['font.sans-serif'] = ['SimHei']
#这里设函数为y=3x+2
x_data = [1.0,2.0,3.0]
y_data = [5.0,8.0,11.0]
def forward(x):
return x * w + b
def loss(x,y):
y_pred = forward(x)
return (y_pred-y)*(y_pred-y)
mse_list = []
W=np.arange(0.0,4.1,0.1)
B=np.arange(0.0,4.1,0.1)
w,b=np.meshgrid(W,B)
# print("w==",w)
# print('b==',b)
l_sum = 0
for x_val, y_val in zip(x_data, y_data):
y_pred_val = forward(x_val)
loss_val = loss(x_val, y_val)
print('x_val==', x_val,'\ny_val==', y_val,'\ny_pred_val==', y_pred_val, '\nloss_val==',loss_val)
l_sum += loss_val
fig = plt.figure()
ax = Axes3D(fig,auto_add_to_figure=False)
fig.add_axes(ax)
ax.plot_surface(w, b, l_sum/3)
ax.set_xlabel("权重 W")
ax.set_ylabel("偏置项 B")
ax.set_zlabel("损失值")
plt.show()
3D图:
边栏推荐
猜你喜欢

OFDM Lecture 16 5 -Discrete Convolution, ISI and ICI on DMT/OFDM Systems

类的底层机制

How can Flutter parent and child components receive click events

Detailed Explanation of Redis Sentinel Mode Configuration File

开发一套高容错分布式系统
![[Student Graduation Project] Design and Implementation of the Website Based on the Web Student Information Management System (13 pages)](/img/86/9c9a2541f2b7089ae47e9832fffdb3.png)
[Student Graduation Project] Design and Implementation of the Website Based on the Web Student Information Management System (13 pages)

Develop a highly fault-tolerant distributed system

【过一下 17】pytorch 改写 keras

【Transfer】What is etcd
![LeetCode: 1403. Minimum subsequence in non-increasing order [greedy]](/img/99/41629dcd84e95eb3672d0555d6ef2c.png)
LeetCode: 1403. Minimum subsequence in non-increasing order [greedy]
随机推荐
How to deal with DNS hijacking?
【Untitled】
1068找到更多的硬币
The mall background management system based on Web design and implementation
服务器磁盘阵列
2022 Hangzhou Electric Multi-School 1st Session 01
Difference between for..in and for..of
Flutter learning 2-dart learning
Mvi架构浅析
电话溥功能
Error creating bean with name ‘configDataContextRefresher‘ defined in class path resource
Wise Force Deleter强制删除工具
"Recursion" recursion concept and typical examples
jvm 三 之堆与栈
Transformation 和 Action 常用算子
【过一下14】自习室的一天
Redis哨兵模式配置文件详解
【读书】长期更新
【技能】长期更新
结构光三维重建(一)条纹结构光三维重建