当前位置:网站首页>Tensorflow2.0 自定义训练的方式求解函数系数
Tensorflow2.0 自定义训练的方式求解函数系数
2022-07-06 11:32:00 【赫凯】
做曲线拟合逼近的时候,如果知道拟合的大致方程,如何求出方程的系数呢?
其实在numpy中已经有函数去求多项式系数了,如下
# x和y对应的就是坐标轴上x和y。
z = np.polyfit(x, y, 3) #用3次多项式拟合,输出系数从高到0
# p就是所求方程
p = np.poly1d(z) #使用次数合成多项式
但是有时候曲线不只是多项式,而是指数函数、幂函数,这就需要我们自己写了,其实就和之前求线性函数的系数一样,只不过换了一个方程式
# 这里定义自变量和因变量x,y。
x = tf.constant([1, 2, 3, 4, 5, 6],dtype=tf.float64)
y = tf.constant([11, 23, 45, 46, 48, 59],dtype=tf.float64)
# 这里定义反传时需要更新的参数,也就是我们要求的参数
w_a = tf.Variable(-0.0, dtype=tf.float64,name='a')
w_b = tf.Variable(0.1, dtype=tf.float64,name='b')
w_c = tf.Variable(-0.0, dtype=tf.float64,name='c')
w_d = tf.Variable(2.1, dtype=tf.float64,name='d')
# 把我们的系数放在一个数组里,方便后面反传
variables=[w_a, w_b, w_c, w_d]
# 定义循环次数
epoch = 1000
# 定义优化器
optimizer=tf.keras.optimizers.Adam(learning_rate=1e-3)
# 循环
for e in range(epoch):
with tf.GradientTape() as tape:
# 预测函数
y_pred=w_a*tf.exp(w_b*x)+w_c*x+w_d
# 损失函数
loss=tf.reduce_sum(tf.square(y_pred-y))
# 对变量求导
grads=tape.gradient(loss, variables)
# 调整参数
optimizer.apply_gradients(grads_and_vars=zip(grads, variables))
# 打印结果
if e%100==0:
print(f"step: {
e}, loss: {
loss}, a: {
w_a.numpy()}, b: {
w_b.numpy()}, c: {
w_c.numpy()}, d: {
w_d.numpy()}")
# 输出参数
print(variables)
就酱
边栏推荐
- R语言ggplot2可视化:使用ggpubr包的ggviolin函数可视化小提琴图
- Solution of intelligent management platform for suppliers in hardware and electromechanical industry: optimize supply chain management and drive enterprise performance growth
- Documents to be used in IC design process
- Analysis of frequent chain breaks in applications using Druid connection pools
- IC设计流程中需要使用到的文件
- The second day of rhcsa study
- R语言ggplot2可视化:使用ggpubr包的ggstripchart函数可视化分组点状条带图(dot strip plot)、设置add参数为不同水平点状条带图添加箱图
- Unlock 2 live broadcast themes in advance! Today, I will teach you how to complete software package integration Issues 29-30
- test about BinaryTree
- JDBC details
猜你喜欢

Digital "new" operation and maintenance of energy industry

Php+redis realizes the function of canceling orders over time

Graffiti intelligence is listed on the dual main board in Hong Kong: market value of 11.2 billion Hong Kong, with an annual revenue of 300 million US dollars

今日直播 | “人玑协同 未来已来”2022弘玑生态伙伴大会蓄势待发

Simple understanding of MySQL database

三面蚂蚁金服成功拿到offer,Android开发社招面试经验

Live broadcast today | the 2022 Hongji ecological partnership conference of "Renji collaboration has come" is ready to go

助力安全人才专业素养提升 | 个人能力认证考核第一阶段圆满结束!

pychrm社区版调用matplotlib.pyplot.imshow()函数图像不弹出的解决方法

Lucun smart sprint technology innovation board: annual revenue of 400million, proposed to raise 700million
随机推荐
Use of deg2rad and rad2deg functions in MATLAB
PMP每日一练 | 考试不迷路-7.6
QPushButton绑定快捷键的注意事项
五金机电行业智能供应链管理系统解决方案:数智化供应链为传统产业“造新血”
First day of rhcsa study
How can my Haskell program or library find its version number- How can my Haskell program or library find its version number?
Live broadcast today | the 2022 Hongji ecological partnership conference of "Renji collaboration has come" is ready to go
反射及在运用过程中出现的IllegalAccessException异常
多线程基础:线程基本概念与线程的创建
Based on butterfly species recognition
CCNP Part 11 BGP (III) (essence)
终于可以一行代码也不用改了!ShardingSphere 原生驱动问世
spark基础-scala
IC设计流程中需要使用到的文件
Actf 2022 came to a successful conclusion, and 0ops team won the second consecutive championship!!
LeetCode-1279. Traffic light intersection
pytorch常见损失函数
GCC [7] - compilation checks the declaration of functions, and link checks the definition bugs of functions
Interface test tool - postman
Translation D28 (with AC code POJ 26:the nearest number)