当前位置:网站首页>Tensorflow2.0 自定义训练的方式求解函数系数
Tensorflow2.0 自定义训练的方式求解函数系数
2022-07-06 11:32:00 【赫凯】
做曲线拟合逼近的时候,如果知道拟合的大致方程,如何求出方程的系数呢?
其实在numpy中已经有函数去求多项式系数了,如下
# x和y对应的就是坐标轴上x和y。
z = np.polyfit(x, y, 3) #用3次多项式拟合,输出系数从高到0
# p就是所求方程
p = np.poly1d(z) #使用次数合成多项式
但是有时候曲线不只是多项式,而是指数函数、幂函数,这就需要我们自己写了,其实就和之前求线性函数的系数一样,只不过换了一个方程式
# 这里定义自变量和因变量x,y。
x = tf.constant([1, 2, 3, 4, 5, 6],dtype=tf.float64)
y = tf.constant([11, 23, 45, 46, 48, 59],dtype=tf.float64)
# 这里定义反传时需要更新的参数,也就是我们要求的参数
w_a = tf.Variable(-0.0, dtype=tf.float64,name='a')
w_b = tf.Variable(0.1, dtype=tf.float64,name='b')
w_c = tf.Variable(-0.0, dtype=tf.float64,name='c')
w_d = tf.Variable(2.1, dtype=tf.float64,name='d')
# 把我们的系数放在一个数组里,方便后面反传
variables=[w_a, w_b, w_c, w_d]
# 定义循环次数
epoch = 1000
# 定义优化器
optimizer=tf.keras.optimizers.Adam(learning_rate=1e-3)
# 循环
for e in range(epoch):
with tf.GradientTape() as tape:
# 预测函数
y_pred=w_a*tf.exp(w_b*x)+w_c*x+w_d
# 损失函数
loss=tf.reduce_sum(tf.square(y_pred-y))
# 对变量求导
grads=tape.gradient(loss, variables)
# 调整参数
optimizer.apply_gradients(grads_and_vars=zip(grads, variables))
# 打印结果
if e%100==0:
print(f"step: {
e}, loss: {
loss}, a: {
w_a.numpy()}, b: {
w_b.numpy()}, c: {
w_c.numpy()}, d: {
w_d.numpy()}")
# 输出参数
print(variables)
就酱
边栏推荐
- 第五期个人能力认证考核通过名单公布
- Synchronous development of business and application: strategic suggestions for application modernization
- 驼峰式与下划线命名规则(Camel case With hungarian notation)
- Reflection and illegalaccessexception exception during application
- Problems encountered in using RT thread component fish
- Looting iii[post sequence traversal and backtracking + dynamic planning]
- R language ggplot2 visualization: use the ggstripchart function of ggpubr package to visualize the grouped dot strip plot, and set the add parameter to add box plots for different levels of dot strip
- R语言ggplot2可视化:使用ggpubr包的ggdotplot函数可视化点阵图(dot plot)、设置palette参数设置不同水平点阵图数据点和箱图的颜色
- map的使用(列表的数据赋值到表单,json逗号隔开显示赋值)
- R语言ggplot2可视化:使用ggpubr包的ggviolin函数可视化小提琴图
猜你喜欢
Mysql Information Schema 学习(二)--Innodb表
The list of people who passed the fifth phase of personal ability certification assessment was published
Solution of intelligent management platform for suppliers in hardware and electromechanical industry: optimize supply chain management and drive enterprise performance growth
How to type multiple spaces when editing CSDN articles
php+redis实现超时取消订单功能
第五期个人能力认证考核通过名单公布
Interview assault 63: how to remove duplication in MySQL?
Tongyu Xincai rushes to Shenzhen Stock Exchange: the annual revenue is 947million Zhang Chi and Su Shiguo are the actual controllers
Benefit a lot, Android interview questions
Graffiti intelligence is listed on the dual main board in Hong Kong: market value of 11.2 billion Hong Kong, with an annual revenue of 300 million US dollars
随机推荐
关于静态类型、动态类型、id、instancetype
Tensorflow and torch code verify whether CUDA is successfully installed
渲大师携手向日葵,远控赋能云渲染及GPU算力服务
谷粒商城--分布式高级篇P129~P339(完结)
学习探索-函数防抖
R语言ggplot2可视化时间序列柱形图:通过双色渐变配色颜色主题可视化时间序列柱形图
Interview assault 63: how to remove duplication in MySQL?
ModuleNotFoundError: No module named ‘PIL‘解决方法
10 schemes to ensure interface data security
凤凰架构2——访问远程服务
USB host driver - UVC swap
[translation] a GPU approach to particle physics
ROS custom message publishing subscription example
Test technology stack arrangement -- self cultivation of test development engineers
The dplyr package of R language performs data grouping aggregation statistical transformations and calculates the grouping mean of dataframe data
Graffiti intelligence is listed on the dual main board in Hong Kong: market value of 11.2 billion Hong Kong, with an annual revenue of 300 million US dollars
The list of people who passed the fifth phase of personal ability certification assessment was published
冒烟测试怎么做
提前解锁 2 大直播主题!今天手把手教你如何完成软件包集成?|第 29-30 期
Meilu biological IPO was terminated: the annual revenue was 385million, and Chen Lin was the actual controller