当前位置:网站首页>机器学习:线性回归
机器学习:线性回归
2022-06-24 19:28:00 【翁炜强】
低级API实现:
1.随机初始化数据
import matplotlib.pyplot as plt
import tensorflow as tf
TRUE_W=3.0
TRUE_b=2.0
NUM_SAMPLES=100
#初始化随机数据
X=tf.random.normal(shape=[NUM_SAMPLES,1]).numpy()
noise=tf.random.normal(shape=[NUM_SAMPLES,1]).numpy()
y=X*TRUE_W+TRUE_b+noise #添加噪声
plt.scatter(X,y)
2.
定义一元回归模型并拟合曲线:
𝑓(𝑤,𝑏,𝑥)=𝑤∗𝑥+𝑏
class Model(object): #object模型的主体
def __init__(self):
self.W = tf.Variable(tf.random.uniform([1])) # 随机初始化参数
self.b = tf.Variable(tf.random.uniform([1]))
def __call__(self, x):
return self.W * x + self.b # w*x + b
model = Model() # 实例化模型
plt.scatter(X, y)
plt.plot(X, model(X), c='r')

可见拟合效果不是很好 因此继续训练模型
3.利用损失函数 去 进行梯度下降迭代 得到好的拟合结果
损失函数:

更新参数:
𝑏←b−𝑙𝑟∗∂loss(𝑤,𝑏)
w←w−𝑙𝑟∗∂loss(𝑤,𝑏)
lr指是学习率
最后迭代十次
def loss_fn(model,x,y):
y_=model(x)
return tf.reduce_mean(tf.square(y_ -y))
EPOCHS =10
LEARNING_RATE=0.1
for epoch in range (EPOCHS): #迭代次数
with tf.GradientTape() as tape:
loss=loss_fn(model,X,y)#计算损失
dW,db=tape.gradient(loss,[model.W,model.b]) #计算梯度
model.W.assign_sub(LEARNING_RATE*dW)
model.b.assign_sub(LEARNING_RATE*db)
#输出计算结果
print(f'Epoch[{epoch}/{EPOCHS}], loss[{loss}], W/b[{model.W.numpy()}/{model.b.numpy()}]')
plt.scatter(X, y)
plt.plot(X, model(X), c='r')得到以下结果:
高阶API实现:
使用tensorflow现有库中的keras
model = tf.keras.Sequential() # 新建顺序模型
model.add(tf.keras.layers.Dense(units=1, input_dim=1)) # 添加线性层
model.compile(optimizer='sgd', loss='mse') # 定义损失函数和优化方法
model.fit(X, y, epochs=10, batch_size=32) # 训练模型
边栏推荐
- Installing Oracle without graphical interface in virtual machine centos7 (nanny level installation)
- Tso hardware sharding is a header copy problem
- leetcode-201_2021_10_17
- VirtualBox virtual machine installation win10 Enterprise Edition
- 数据链路层 && 一些其他的协议or技术
- C语言-关键字1
- EasyBypass
- Multi view function in blender
- Datakit 代理实现局域网数据统一汇聚
- 如何化解35岁危机?华为云数据库首席架构师20年技术经验分享
猜你喜欢

EditText 控制软键盘出现 搜索

Bld3 getting started UI

VirtualBox virtual machine installation win10 Enterprise Edition

Multiplexer select

【吴恩达笔记】机器学习基础

Fuzhou business office of Fujian development and Reform Commission visited the health department of Yurun university to guide and inspect the work

2022 international women engineers' Day: Dyson design award shows women's design strength

【吴恩达笔记】卷积神经网络

Implementing DNS requester with C language

About transform InverseTransformPoint, transform. InverseTransofrmDirection
随机推荐
【论】Deep learning in the COVID-19 epidemic: A deep model for urban traffic revitalization index
Multi task model of recommended model: esmm, MMOE
Static routing experiment
01---两列波在相遇处发生干涉的条件
Blender FAQs
Memcached comprehensive analysis – 5 Memcached applications and compatible programs
Transport layer UDP & TCP
WMI and PowerShell get TCP connection list
OSI and tcp/ip model
[camera Foundation (II)] camera driving principle and Development & v4l2 subsystem driving architecture
Vscode netless environment rapid migration development environment (VIP collection version)
多路转接select
Docking of arkit and character creator animation curves
Li Kou daily question - day 25 -496 Next larger element I
[camera Foundation (I)] working principle and overall structure of camera
Failed to open after installing Charles without any prompt
图的邻接表存储 数组实现
MySQL optimizes query speed
#国企央企结构化面试#国企就业#墨斗互动就业服务管家
EditText controls the soft keyboard to search
