当前位置:网站首页>Tensorflow2.0 自定义训练的方式求解函数系数
Tensorflow2.0 自定义训练的方式求解函数系数
2022-07-06 11:32:00 【赫凯】
做曲线拟合逼近的时候,如果知道拟合的大致方程,如何求出方程的系数呢?
其实在numpy中已经有函数去求多项式系数了,如下
# x和y对应的就是坐标轴上x和y。
z = np.polyfit(x, y, 3) #用3次多项式拟合,输出系数从高到0
# p就是所求方程
p = np.poly1d(z) #使用次数合成多项式
但是有时候曲线不只是多项式,而是指数函数、幂函数,这就需要我们自己写了,其实就和之前求线性函数的系数一样,只不过换了一个方程式
# 这里定义自变量和因变量x,y。
x = tf.constant([1, 2, 3, 4, 5, 6],dtype=tf.float64)
y = tf.constant([11, 23, 45, 46, 48, 59],dtype=tf.float64)
# 这里定义反传时需要更新的参数,也就是我们要求的参数
w_a = tf.Variable(-0.0, dtype=tf.float64,name='a')
w_b = tf.Variable(0.1, dtype=tf.float64,name='b')
w_c = tf.Variable(-0.0, dtype=tf.float64,name='c')
w_d = tf.Variable(2.1, dtype=tf.float64,name='d')
# 把我们的系数放在一个数组里,方便后面反传
variables=[w_a, w_b, w_c, w_d]
# 定义循环次数
epoch = 1000
# 定义优化器
optimizer=tf.keras.optimizers.Adam(learning_rate=1e-3)
# 循环
for e in range(epoch):
with tf.GradientTape() as tape:
# 预测函数
y_pred=w_a*tf.exp(w_b*x)+w_c*x+w_d
# 损失函数
loss=tf.reduce_sum(tf.square(y_pred-y))
# 对变量求导
grads=tape.gradient(loss, variables)
# 调整参数
optimizer.apply_gradients(grads_and_vars=zip(grads, variables))
# 打印结果
if e%100==0:
print(f"step: {
e}, loss: {
loss}, a: {
w_a.numpy()}, b: {
w_b.numpy()}, c: {
w_c.numpy()}, d: {
w_d.numpy()}")
# 输出参数
print(variables)
就酱
边栏推荐
- Solution of commercial supply chain management platform for packaging industry: layout smart supply system and digitally integrate the supply chain of packaging industry
- 应用使用Druid连接池经常性断链问题分析
- 快速幂模板求逆元,逆元的作用以及例题【第20届上海大学程序设计联赛夏季赛】排列计数
- Dark horse -- redis
- How can my Haskell program or library find its version number- How can my Haskell program or library find its version number?
- tensorflow和torch代码验证cuda是否安装成功
- 今日直播 | “人玑协同 未来已来”2022弘玑生态伙伴大会蓄势待发
- Help improve the professional quality of safety talents | the first stage of personal ability certification and assessment has been successfully completed!
- pytorch常见损失函数
- Benefit a lot, Android interview questions
猜你喜欢
Pychrm Community Edition calls matplotlib pyplot. Solution of imshow() function image not popping up
三年Android开发,2022疫情期间八家大厂的Android面试经历和真题整理
五金机电行业供应商智慧管理平台解决方案:优化供应链管理,带动企业业绩增长
Computer network: sorting out common network interview questions (I)
Yutai micro rushes to the scientific innovation board: Huawei and Xiaomi fund are shareholders to raise 1.3 billion
提前解锁 2 大直播主题!今天手把手教你如何完成软件包集成?|第 29-30 期
凤凰架构3——事务处理
pychrm社区版调用matplotlib.pyplot.imshow()函数图像不弹出的解决方法
思维导图+源代码+笔记+项目,字节跳动+京东+360+网易面试题整理
pytorch常见损失函数
随机推荐
RT-Thread 组件 FinSH 使用时遇到的问题
Characteristic colleges and universities, jointly build Netease Industrial College
Graffiti intelligence is listed on the dual main board in Hong Kong: market value of 11.2 billion Hong Kong, with an annual revenue of 300 million US dollars
Black Horse - - Redis Chapter
第五期个人能力认证考核通过名单公布
ACTF 2022圆满落幕,0ops战队二连冠!!
五金机电行业智能供应链管理系统解决方案:数智化供应链为传统产业“造新血”
Lucun smart sprint technology innovation board: annual revenue of 400million, proposed to raise 700million
安装Mysql报错:Could not create or access the registry key needed for the...
Cereals Mall - Distributed Advanced p129~p339 (end)
The slave i/o thread stops because master and slave have equal MySQL serv
多线程基础:线程基本概念与线程的创建
ROS custom message publishing subscription example
Pytorch common loss function
ModuleNotFoundError: No module named ‘PIL‘解决方法
C # use Marshall to manually create unmanaged memory in the heap and use
GCC【7】- 编译检查的是函数的声明,链接检查的是函数的定义bug
Airiot IOT platform enables the container industry to build [welding station information monitoring system]
提前解锁 2 大直播主题!今天手把手教你如何完成软件包集成?|第 29-30 期
English topic assignment (25)