当前位置:网站首页>Tensorflow2.0 自定义训练的方式求解函数系数
Tensorflow2.0 自定义训练的方式求解函数系数
2022-07-06 11:32:00 【赫凯】
做曲线拟合逼近的时候,如果知道拟合的大致方程,如何求出方程的系数呢?
其实在numpy中已经有函数去求多项式系数了,如下
# x和y对应的就是坐标轴上x和y。
z = np.polyfit(x, y, 3) #用3次多项式拟合,输出系数从高到0
# p就是所求方程
p = np.poly1d(z) #使用次数合成多项式
但是有时候曲线不只是多项式,而是指数函数、幂函数,这就需要我们自己写了,其实就和之前求线性函数的系数一样,只不过换了一个方程式
# 这里定义自变量和因变量x,y。
x = tf.constant([1, 2, 3, 4, 5, 6],dtype=tf.float64)
y = tf.constant([11, 23, 45, 46, 48, 59],dtype=tf.float64)
# 这里定义反传时需要更新的参数,也就是我们要求的参数
w_a = tf.Variable(-0.0, dtype=tf.float64,name='a')
w_b = tf.Variable(0.1, dtype=tf.float64,name='b')
w_c = tf.Variable(-0.0, dtype=tf.float64,name='c')
w_d = tf.Variable(2.1, dtype=tf.float64,name='d')
# 把我们的系数放在一个数组里,方便后面反传
variables=[w_a, w_b, w_c, w_d]
# 定义循环次数
epoch = 1000
# 定义优化器
optimizer=tf.keras.optimizers.Adam(learning_rate=1e-3)
# 循环
for e in range(epoch):
with tf.GradientTape() as tape:
# 预测函数
y_pred=w_a*tf.exp(w_b*x)+w_c*x+w_d
# 损失函数
loss=tf.reduce_sum(tf.square(y_pred-y))
# 对变量求导
grads=tape.gradient(loss, variables)
# 调整参数
optimizer.apply_gradients(grads_and_vars=zip(grads, variables))
# 打印结果
if e%100==0:
print(f"step: {
e}, loss: {
loss}, a: {
w_a.numpy()}, b: {
w_b.numpy()}, c: {
w_c.numpy()}, d: {
w_d.numpy()}")
# 输出参数
print(variables)
就酱
边栏推荐
- 凤凰架构2——访问远程服务
- 如何提高网站权重
- Airiot IOT platform enables the container industry to build [welding station information monitoring system]
- zabbix 代理服务器 与 zabbix-snmp 监控
- Word如何显示修改痕迹
- Cereals Mall - Distributed Advanced p129~p339 (end)
- Elastic search indexes are often deleted [closed] - elastic search indexes gets deleted frequently [closed]
- spark基础-scala
- Sanmian ant financial successfully got the offer, and has experience in Android development agency recruitment and interview
- Mind map + source code + Notes + project, ByteDance + JD +360+ Netease interview question sorting
猜你喜欢
![[paper notes] transunet: transformers make strongencoders for medical image segmentation](/img/21/3d4710024248b62495e2681ebd1bc4.png)
[paper notes] transunet: transformers make strongencoders for medical image segmentation

Wx applet learning notes day01
Three years of Android development, Android interview experience and real questions sorting of eight major manufacturers during the 2022 epidemic

Characteristic colleges and universities, jointly build Netease Industrial College

ROS custom message publishing subscription example

10 schemes to ensure interface data security

Abstract classes and abstract methods

How to improve website weight

php+redis实现超时取消订单功能

C language daily practice - day 22: Zero foundation learning dynamic planning
随机推荐
倒计时2天|腾讯云消息队列数据接入平台(Data Import Platform)直播预告
How to improve website weight
ACTF 2022圆满落幕,0ops战队二连冠!!
Airiot IOT platform enables the container industry to build [welding station information monitoring system]
Druid 数据库连接池 详解
Benefit a lot, Android interview questions
GCC [7] - compilation checks the declaration of functions, and link checks the definition bugs of functions
保证接口数据安全的10种方案
Translation D28 (with AC code POJ 26:the nearest number)
第五期个人能力认证考核通过名单公布
R语言使用order函数对dataframe数据进行排序、基于单个字段(变量)进行降序排序(DESCENDING)
GCC【7】- 编译检查的是函数的声明,链接检查的是函数的定义bug
学习探索-使用伪元素清除浮动元素造成的高度坍塌
安装Mysql报错:Could not create or access the registry key needed for the...
Use of map (the data of the list is assigned to the form, and the JSON comma separated display assignment)
Druid database connection pool details
Qlabel marquee text display
R language uses rchisq function to generate random numbers that conform to Chi square distribution, and uses plot function to visualize random numbers that conform to Chi square distribution
全套教学资料,阿里快手拼多多等7家大厂Android面试真题
An error occurs when installing MySQL: could not create or access the registry key needed for the