当前位置:网站首页>Tensorflow2.0 自定义训练的方式求解函数系数
Tensorflow2.0 自定义训练的方式求解函数系数
2022-07-06 11:32:00 【赫凯】
做曲线拟合逼近的时候,如果知道拟合的大致方程,如何求出方程的系数呢?
其实在numpy中已经有函数去求多项式系数了,如下
# x和y对应的就是坐标轴上x和y。
z = np.polyfit(x, y, 3) #用3次多项式拟合,输出系数从高到0
# p就是所求方程
p = np.poly1d(z) #使用次数合成多项式
但是有时候曲线不只是多项式,而是指数函数、幂函数,这就需要我们自己写了,其实就和之前求线性函数的系数一样,只不过换了一个方程式
# 这里定义自变量和因变量x,y。
x = tf.constant([1, 2, 3, 4, 5, 6],dtype=tf.float64)
y = tf.constant([11, 23, 45, 46, 48, 59],dtype=tf.float64)
# 这里定义反传时需要更新的参数,也就是我们要求的参数
w_a = tf.Variable(-0.0, dtype=tf.float64,name='a')
w_b = tf.Variable(0.1, dtype=tf.float64,name='b')
w_c = tf.Variable(-0.0, dtype=tf.float64,name='c')
w_d = tf.Variable(2.1, dtype=tf.float64,name='d')
# 把我们的系数放在一个数组里,方便后面反传
variables=[w_a, w_b, w_c, w_d]
# 定义循环次数
epoch = 1000
# 定义优化器
optimizer=tf.keras.optimizers.Adam(learning_rate=1e-3)
# 循环
for e in range(epoch):
with tf.GradientTape() as tape:
# 预测函数
y_pred=w_a*tf.exp(w_b*x)+w_c*x+w_d
# 损失函数
loss=tf.reduce_sum(tf.square(y_pred-y))
# 对变量求导
grads=tape.gradient(loss, variables)
# 调整参数
optimizer.apply_gradients(grads_and_vars=zip(grads, variables))
# 打印结果
if e%100==0:
print(f"step: {
e}, loss: {
loss}, a: {
w_a.numpy()}, b: {
w_b.numpy()}, c: {
w_c.numpy()}, d: {
w_d.numpy()}")
# 输出参数
print(variables)
就酱
边栏推荐
- Live broadcast today | the 2022 Hongji ecological partnership conference of "Renji collaboration has come" is ready to go
- pytorch常见损失函数
- tensorflow和torch代码验证cuda是否安装成功
- Php+redis realizes the function of canceling orders over time
- 【pytorch】yolov5 训练自己的数据集
- C # - realize serialization with Marshall class
- Precautions for binding shortcut keys of QPushButton
- Solution of intelligent management platform for suppliers in hardware and electromechanical industry: optimize supply chain management and drive enterprise performance growth
- 打家劫舍III[后序遍历与回溯+动态规划]
- php+redis实现超时取消订单功能
猜你喜欢
接雨水问题解析
渲大师携手向日葵,远控赋能云渲染及GPU算力服务
反射及在运用过程中出现的IllegalAccessException异常
Documents to be used in IC design process
通俗的讲解,带你入门协程
MRO工业品企业采购系统:如何精细化采购协同管理?想要升级的工业品企业必看!
提前解锁 2 大直播主题!今天手把手教你如何完成软件包集成?|第 29-30 期
Three years of Android development, Android interview experience and real questions sorting of eight major manufacturers during the 2022 epidemic
Php+redis realizes the function of canceling orders over time
The second day of rhcsa study
随机推荐
Precautions for binding shortcut keys of QPushButton
快速幂模板求逆元,逆元的作用以及例题【第20届上海大学程序设计联赛夏季赛】排列计数
谷粒商城--分布式高级篇P129~P339(完结)
倒计时2天|腾讯云消息队列数据接入平台(Data Import Platform)直播预告
Tensorflow and torch code verify whether CUDA is successfully installed
R language ggplot2 visual time series histogram: visual time series histogram through two-color gradient color matching color theme
R language uses the order function to sort the dataframe data, and descending sorting based on a single field (variable)
Cereals Mall - Distributed Advanced p129~p339 (end)
冒烟测试怎么做
助力安全人才专业素养提升 | 个人能力认证考核第一阶段圆满结束!
第五期个人能力认证考核通过名单公布
How to type multiple spaces when editing CSDN articles
Unlock 2 live broadcast themes in advance! Today, I will teach you how to complete software package integration Issues 29-30
Digital "new" operation and maintenance of energy industry
受益匪浅,安卓面试问题
五金机电行业供应商智慧管理平台解决方案:优化供应链管理,带动企业业绩增长
The nearest library of Qinglong panel
Intelligent supply chain management system solution for hardware and electromechanical industry: digital intelligent supply chain "creates new blood" for traditional industries
多线程基础:线程基本概念与线程的创建
CCNP Part 11 BGP (III) (essence)