当前位置:网站首页>Tensorflow2.0 自定义训练的方式求解函数系数
Tensorflow2.0 自定义训练的方式求解函数系数
2022-07-06 11:32:00 【赫凯】
做曲线拟合逼近的时候,如果知道拟合的大致方程,如何求出方程的系数呢?
其实在numpy中已经有函数去求多项式系数了,如下
# x和y对应的就是坐标轴上x和y。
z = np.polyfit(x, y, 3) #用3次多项式拟合,输出系数从高到0
# p就是所求方程
p = np.poly1d(z) #使用次数合成多项式
但是有时候曲线不只是多项式,而是指数函数、幂函数,这就需要我们自己写了,其实就和之前求线性函数的系数一样,只不过换了一个方程式
# 这里定义自变量和因变量x,y。
x = tf.constant([1, 2, 3, 4, 5, 6],dtype=tf.float64)
y = tf.constant([11, 23, 45, 46, 48, 59],dtype=tf.float64)
# 这里定义反传时需要更新的参数,也就是我们要求的参数
w_a = tf.Variable(-0.0, dtype=tf.float64,name='a')
w_b = tf.Variable(0.1, dtype=tf.float64,name='b')
w_c = tf.Variable(-0.0, dtype=tf.float64,name='c')
w_d = tf.Variable(2.1, dtype=tf.float64,name='d')
# 把我们的系数放在一个数组里,方便后面反传
variables=[w_a, w_b, w_c, w_d]
# 定义循环次数
epoch = 1000
# 定义优化器
optimizer=tf.keras.optimizers.Adam(learning_rate=1e-3)
# 循环
for e in range(epoch):
with tf.GradientTape() as tape:
# 预测函数
y_pred=w_a*tf.exp(w_b*x)+w_c*x+w_d
# 损失函数
loss=tf.reduce_sum(tf.square(y_pred-y))
# 对变量求导
grads=tape.gradient(loss, variables)
# 调整参数
optimizer.apply_gradients(grads_and_vars=zip(grads, variables))
# 打印结果
if e%100==0:
print(f"step: {
e}, loss: {
loss}, a: {
w_a.numpy()}, b: {
w_b.numpy()}, c: {
w_c.numpy()}, d: {
w_d.numpy()}")
# 输出参数
print(variables)
就酱
边栏推荐
- The dplyr package of R language performs data grouping aggregation statistical transformations and calculates the grouping mean of dataframe data
- ModuleNotFoundError: No module named ‘PIL‘解决方法
- R语言使用order函数对dataframe数据进行排序、基于单个字段(变量)进行降序排序(DESCENDING)
- 业务与应用同步发展:应用现代化的策略建议
- 深入分析,Android面试真题解析火爆全网
- Live broadcast today | the 2022 Hongji ecological partnership conference of "Renji collaboration has come" is ready to go
- R语言使用dt函数生成t分布密度函数数据、使用plot函数可视化t分布密度函数数据(t Distribution)
- 通俗的讲解,带你入门协程
- Documents to be used in IC design process
- R语言使用rchisq函数生成符合卡方分布的随机数、使用plot函数可视化符合卡方分布的随机数(Chi Square Distribution)
猜你喜欢

今日直播 | “人玑协同 未来已来”2022弘玑生态伙伴大会蓄势待发

凤凰架构3——事务处理

pychrm社区版调用matplotlib.pyplot.imshow()函数图像不弹出的解决方法

php+redis实现超时取消订单功能

Synchronous development of business and application: strategic suggestions for application modernization

ACTF 2022圆满落幕,0ops战队二连冠!!

10 schemes to ensure interface data security

Countdown 2 days | live broadcast preview of Tencent cloud message queue data import platform
![Looting iii[post sequence traversal and backtracking + dynamic planning]](/img/9b/e9eeed138e46afdeed340bf2629ee1.png)
Looting iii[post sequence traversal and backtracking + dynamic planning]

倒计时2天|腾讯云消息队列数据接入平台(Data Import Platform)直播预告
随机推荐
GCC [7] - compilation checks the declaration of functions, and link checks the definition bugs of functions
GCC【7】- 编译检查的是函数的声明,链接检查的是函数的定义bug
Solution of intelligent management platform for suppliers in hardware and electromechanical industry: optimize supply chain management and drive enterprise performance growth
Use of map (the data of the list is assigned to the form, and the JSON comma separated display assignment)
Dark horse -- redis
The dplyr package of R language performs data grouping aggregation statistical transformations and calculates the grouping mean of dataframe data
spark基础-scala
受益匪浅,安卓面试问题
应用使用Druid连接池经常性断链问题分析
CCNP Part 11 BGP (III) (essence)
Don't miss this underestimated movie because of controversy!
如何提高网站权重
Elastic search indexes are often deleted [closed] - elastic search indexes gets deleted frequently [closed]
倒计时2天|腾讯云消息队列数据接入平台(Data Import Platform)直播预告
RT-Thread 组件 FinSH 使用时遇到的问题
R language ggplot2 visualization: use ggviolin function of ggpubr package to visualize violin diagram
swagger2报错Illegal DefaultValue null for parameter type integer
业务与应用同步发展:应用现代化的策略建议
The second day of rhcsa study
ROS自定义消息发布订阅示例