当前位置:网站首页>JAX-based activation function, softmax function and cross entropy function
JAX-based activation function, softmax function and cross entropy function
2022-08-01 23:10:00 【Albert Darren】
1.tanh激活函数公式
tanh ( x ) = e x − e − x e x + e − x \tanh \left( x \right) =\frac{e^x-e^{-x}}{e^x+e^{-x}} \notag tanh(x)=ex+e−xex−e−x
2.基于JAX实现tanh激活函数
import jax.numpy as jnp
from jax import random
def tanh(x):
"""tanh function"""
return (jnp.exp(x)-jnp.exp(-x))/(jnp.exp(x)+jnp.exp(-x))
# Set the pseudorandom number seed
rng=random.PRNGKey(0)
# Standard normal sampling yields the input vector
x=random.normal(rng,shape=(4,1))
# 调用内置tanh函数实现
print(jnp.tanh(x))
# 调用自定义tanh函数实现
print(tanh(x))
3.softmax函数公式
s i = e x i ∑ i = 0 j e x i s_i=\frac{e^{x_i}}{\sum\limits _{i=0}^{j}e^{x_i} } \notag si=i=0∑jexiexi
其中 x i x_i xi表示第 i i ireal-valued output of a neuron
4.基于JAX实现softmax函数
import jax.numpy as jnp
import jax.nn as nn
def softmax(x,axis=-1):
"""softmax function"""
unnormalized=jnp.exp(x)
return unnormalized/jnp.sum(unnormalized)
# 定义数组
arr=jnp.arange(-2,4)
# 调用自定义softmax函数
print(softmax(arr))
# 调用jax自带softmax函数
print(nn.softmax(arr))
5.Cross-entropy function formula
H ( p , q ) = − ∑ i = 1 n p ( x i ) log ( q ( x i ) ) H\left( p,q \right) =-\sum_{i=1}^n{p\left( x_i \right) \log \left( q\left( x_i \right) \right)} \notag H(p,q)=−i=1∑np(xi)log(q(xi))
其中 p ( x ) p(x) p(x)represents the true probability distribution, q ( x ) q(x) q(x)represents the predicted probability distribution
6.基于JAXImplement the cross-entropy function
import jax.numpy as jnp
def cross_entropy(y_true,y_pred,eps=1e-7):
"""cross entropy function :param y_true:真实标签 :param y_pred:Neural network predicts labels :param eps:Default minimal positive number,The logarithm is guaranteed to be true0,增强logFunction Numerical Stability :return:交叉熵,保留到小数点后4位 """
y_true=jnp.array(y_true)
y_pred=jnp.array(y_pred)
res=-jnp.sum(y_true*jnp.log(y_pred+eps),axis=-1)
return jnp.round(res,4)
# 预测概率分布
y_pred=[0.1,0.05,0.6,0.0,0.05,0.1,0.0,0.1,0.0,0.0]
# 真实概率分布
y_true=[0,0,1,0,0,0,0,0,0,0]
# 交叉熵为0.5108
print(cross_entropy(y_true,y_pred))
边栏推荐
- 基于JAX的激活函数、softmax函数和交叉熵函数
- (Translation) How the contrasting color of the button guides the user's actions
- B. Difference Array--Codeforces Round #808 (Div. 1)
- qt-faststart 安装使用
- 6132. 使数组中所有元素都等于零-快速排序法
- SQL29 Calculate the average next day retention rate of users
- From 0 to 1: Design and R&D Notes of Graphic Voting Mini Program
- excel split text into different rows
- Codeforces CodeTON Round 2 (Div. 1 + Div. 2, Rated, Prizes!) A-D Solution
- 选择合适的 DevOps 工具,从理解 DevOps 开始
猜你喜欢
随机推荐
【参营经历贴】2022网安夏令营
JS prototype hasOwnProperty in Add method Prototype end point Inherit Override parent class method
SRv6 L3VPN的工作原理
Check if point is inside rectangle
What is CICD excuse me
CF1705D Mark and Lightbulbs
基于JAX的激活函数、softmax函数和交叉熵函数
PostgreSQL 基础--常用命令
Oracle 数据库设置为只读及读写
Jmeter是什么
Chapter 11 Working with Dates and Times
Mini Program Graduation Works WeChat Food Recipe Mini Program Graduation Design Finished Product (8) Graduation Design Thesis Template
【C补充】链表专题 - 单向链表
6133. Maximum number of packets
文件查询匹配神器 【glob.js】 实用教程
SQL29 Calculate the average next day retention rate of users
PHP算法之有效的括号
如何更好的理解的和做好工作?
Wechat Gymnasium Reservation Mini Program Graduation Design Finished Work Mini Program Graduation Design Finished Product (2) Mini Program Function
vscode hide menu bar