TensorFlow 2.0: solving for function coefficients with a custom training loop
2022-07-06 19:29:00 【Hekai】
When doing curve fitting, if you already know the general form of the equation, how do you find its coefficients?
For polynomials, NumPy already provides functions for this:
import numpy as np

# x and y hold the x- and y-coordinates of the data points
z = np.polyfit(x, y, 3)  # fit a degree-3 (cubic) polynomial; coefficients run from the highest power down to the constant term
# p is the fitted polynomial
p = np.poly1d(z)  # build a callable polynomial from the coefficient array
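A small worked example makes the coefficient order concrete. The data here are made up for illustration only: they are generated from the cubic y = 2x^3 - x + 5, so the fit should recover roughly [2, 0, -1, 5].

```python
import numpy as np

# Illustrative data generated from a known cubic: y = 2x^3 - x + 5
# (these coefficients are an assumption chosen for the example)
x = np.arange(-3.0, 4.0)          # x = -3, -2, ..., 3
y = 2 * x**3 - x + 5

# Fit a degree-3 polynomial; coefficients come back highest power first
z = np.polyfit(x, y, 3)
print(z)                          # close to [2, 0, -1, 5]

# poly1d turns the coefficient array into a callable polynomial
p = np.poly1d(z)
print(p(10.0))                    # close to 2*1000 - 10 + 5 = 1995
```

Because these points lie exactly on a cubic, the fitted coefficients match the generating ones up to floating-point rounding; with noisy data, `polyfit` returns the least-squares fit instead.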
Sometimes, however, the curve is not a polynomial but, for example, an exponential or power function. In that case we have to write the fitting ourselves. It works the same way as solving for the coefficients of a linear function earlier; only the equation changes. The code below fits y = a·exp(b·x) + c·x + d with a custom TensorFlow training loop.
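Written out, the model fitted in the code below and its sum-of-squares loss over n data points (x_i, y_i) are as follows, together with the partial derivatives that `tape.gradient` computes automatically. This hand derivation is added for illustration and is not part of the original post.

```latex
\hat{y}_i = a\,e^{b x_i} + c\,x_i + d, \qquad
r_i = \hat{y}_i - y_i, \qquad
L = \sum_{i=1}^{n} r_i^2

\frac{\partial L}{\partial a} = 2\sum_{i=1}^{n} r_i\,e^{b x_i}, \quad
\frac{\partial L}{\partial b} = 2\sum_{i=1}^{n} r_i\,a\,x_i\,e^{b x_i}, \quad
\frac{\partial L}{\partial c} = 2\sum_{i=1}^{n} r_i\,x_i, \quad
\frac{\partial L}{\partial d} = 2\sum_{i=1}^{n} r_i
```

The optimizer only ever sees these four numbers per step; changing the model equation changes the gradients, but not the shape of the training loop.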
import tensorflow as tf

# Independent and dependent variables x, y
x = tf.constant([1, 2, 3, 4, 5, 6], dtype=tf.float64)
y = tf.constant([11, 23, 45, 46, 48, 59], dtype=tf.float64)
# Coefficients updated by backpropagation, i.e. the unknowns we want to solve for
w_a = tf.Variable(-0.0, dtype=tf.float64, name='a')
w_b = tf.Variable(0.1, dtype=tf.float64, name='b')
w_c = tf.Variable(-0.0, dtype=tf.float64, name='c')
w_d = tf.Variable(2.1, dtype=tf.float64, name='d')
# Collect the coefficients in a list so the tape and the optimizer can handle them together
variables = [w_a, w_b, w_c, w_d]
# Number of training iterations
epoch = 1000
# Optimizer
optimizer = tf.keras.optimizers.Adam(learning_rate=1e-3)
# Training loop
for e in range(epoch):
    with tf.GradientTape() as tape:
        # Model: y = a * exp(b * x) + c * x + d
        y_pred = w_a * tf.exp(w_b * x) + w_c * x + w_d
        # Loss: sum of squared errors
        loss = tf.reduce_sum(tf.square(y_pred - y))
    # Gradients of the loss with respect to each coefficient
    grads = tape.gradient(loss, variables)
    # Update the coefficients
    optimizer.apply_gradients(zip(grads, variables))
    # Print progress every 100 steps
    if e % 100 == 0:
        print(f"step: {e}, loss: {loss.numpy()}, a: {w_a.numpy()}, b: {w_b.numpy()}, "
              f"c: {w_c.numpy()}, d: {w_d.numpy()}")
# Output the fitted coefficients
print([v.numpy() for v in variables])
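To see what `tape.gradient` is computing in the loop above, the sketch below derives the gradients of the same loss by hand in plain NumPy and checks them against central finite differences. It is an illustrative check, not part of the original method; the helper names `loss`, `analytic_grad`, `numeric_grad` and the test point are hypothetical.

```python
import numpy as np

# Same data as in the TensorFlow example
x = np.array([1, 2, 3, 4, 5, 6], dtype=np.float64)
y = np.array([11, 23, 45, 46, 48, 59], dtype=np.float64)

def loss(params):
    """Sum of squared errors for y_hat = a*exp(b*x) + c*x + d."""
    a, b, c, d = params
    r = a * np.exp(b * x) + c * x + d - y
    return np.sum(r ** 2)

def analytic_grad(params):
    """Hand-derived partial derivatives of the loss w.r.t. (a, b, c, d)."""
    a, b, c, d = params
    e = np.exp(b * x)
    r = a * e + c * x + d - y
    return np.array([
        2 * np.sum(r * e),          # dL/da
        2 * np.sum(r * a * x * e),  # dL/db (note the factor a)
        2 * np.sum(r * x),          # dL/dc
        2 * np.sum(r),              # dL/dd
    ])

def numeric_grad(params, h=1e-6):
    """Central finite differences, one coordinate at a time."""
    params = np.asarray(params, dtype=np.float64)
    g = np.zeros_like(params)
    for k in range(params.size):
        step = np.zeros_like(params)
        step[k] = h
        g[k] = (loss(params + step) - loss(params - step)) / (2 * h)
    return g

# Starting values from the TensorFlow code (a = -0.0 there, which equals 0)
init = np.array([0.0, 0.1, 0.0, 2.1])
print(analytic_grad(init))   # the dL/db entry is zero because a = 0

# At an arbitrary nonzero point the two gradient estimates should agree closely
point = np.array([1.5, 0.3, 2.0, 5.0])
print(np.max(np.abs(analytic_grad(point) - numeric_grad(point))))
```

One consequence worth noting: with the starting values used above, a is 0, so b receives exactly zero gradient on the first step and only starts moving once Adam has pushed a away from 0. Starting a at a small nonzero value would give b a nonzero gradient from the first step.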