JAX-based activation function, softmax function and cross entropy function
2022-08-01 23:10:00 【Albert Darren】
1. tanh activation function formula
$$\tanh(x) = \frac{e^{x} - e^{-x}}{e^{x} + e^{-x}}$$
2. Implementing the tanh activation function with JAX
import jax.numpy as jnp
from jax import random

def tanh(x):
    """tanh function"""
    return (jnp.exp(x) - jnp.exp(-x)) / (jnp.exp(x) + jnp.exp(-x))

# Set the pseudorandom number seed
rng = random.PRNGKey(0)
# Sample the input vector from a standard normal distribution
x = random.normal(rng, shape=(4, 1))
# Call the built-in tanh function
print(jnp.tanh(x))
# Call the custom tanh function
print(tanh(x))
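The direct translation of the formula works for moderate inputs, but `jnp.exp(x)` overflows for large |x|. The following is a minimal sketch of one common workaround, assuming the identity tanh(x) = sign(x) · (1 − e^{−2|x|}) / (1 + e^{−2|x|}); the names `stable_tanh` and `naive_tanh` are hypothetical helpers introduced only for this illustration, not part of the original post or of JAX.

```python
import jax.numpy as jnp

def naive_tanh(x):
    """Direct translation of the tanh formula."""
    return (jnp.exp(x) - jnp.exp(-x)) / (jnp.exp(x) + jnp.exp(-x))

def stable_tanh(x):
    """tanh written in terms of exp(-2|x|) (hypothetical helper)."""
    # exp(-2|x|) always lies in (0, 1], so nothing can overflow
    t = jnp.exp(-2.0 * jnp.abs(x))
    return jnp.sign(x) * (1.0 - t) / (1.0 + t)

x = jnp.array([-100.0, -1.0, 0.0, 1.0, 100.0])
# With JAX's default float32, exp(100) overflows to inf,
# so the first and last entries become inf/inf = nan
print(naive_tanh(x))
# The rewritten form stays finite and matches the built-in
print(stable_tanh(x))
print(jnp.tanh(x))
```

In practice `jnp.tanh` is the simplest choice; the sketch only shows why the naive form should not be used on unbounded inputs.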
3. Softmax function formula
$$s_i = \frac{e^{x_i}}{\sum\limits_{k=0}^{j} e^{x_k}}$$
where x_i denotes the real-valued output of the i-th neuron, and the sum in the denominator runs over all neuron outputs x_0, ..., x_j.
4. Implementing the softmax function with JAX
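import jax.numpy as jnp
import jax.nn as nn

def softmax(x, axis=-1):
    """softmax function"""
    unnormalized = jnp.exp(x)
    # Normalize along the requested axis; keepdims makes the division broadcast correctly
    return unnormalized / jnp.sum(unnormalized, axis=axis, keepdims=True)

# Define the input array (floating point, values -2 to 3)
arr = jnp.arange(-2.0, 4.0)
# Call the custom softmax function
print(softmax(arr))
# Call JAX's built-in softmax function
print(nn.softmax(arr))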
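Exponentiating raw scores can overflow when they are large. A common remedy is to subtract the maximum along the normalization axis before exponentiating, which leaves the result unchanged because the shift cancels between numerator and denominator. The sketch below illustrates this; `stable_softmax` and the sample `logits` are hypothetical and introduced only for illustration.

```python
import jax.numpy as jnp
import jax.nn as nn

def stable_softmax(x, axis=-1):
    """Softmax with the max subtracted first (hypothetical helper)."""
    # Shifting by the max makes the largest exponent exp(0) = 1
    shifted = x - jnp.max(x, axis=axis, keepdims=True)
    unnormalized = jnp.exp(shifted)
    return unnormalized / jnp.sum(unnormalized, axis=axis, keepdims=True)

# Two rows: one with very large scores, one with ordinary scores
logits = jnp.array([[1000.0, 1001.0, 1002.0],
                    [-2.0, 0.0, 3.0]])
probs = stable_softmax(logits)
print(probs)
# Each row should sum to 1
print(probs.sum(axis=-1))
# Compare against JAX's built-in softmax
print(nn.softmax(logits, axis=-1))
```

Without the shift, `jnp.exp(1000.0)` would be `inf` in float32 and the first row would not produce valid probabilities.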
5. Cross-entropy function formula
$$H(p, q) = -\sum_{i=1}^{n} p(x_i) \log\left(q(x_i)\right)$$
where p(x) represents the true probability distribution and q(x) represents the predicted probability distribution.
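As a worked special case, suppose the true distribution is one-hot, that is, p(x_c) = 1 for a single class c and 0 elsewhere (the index c is notation introduced here for illustration). Every term except the one for class c vanishes, so the sum collapses to a single logarithm. With a predicted probability of 0.6 for the true class and the natural logarithm:

```latex
H(p, q) = -\sum_{i=1}^{n} p(x_i) \log\left(q(x_i)\right)
        = -\log\left(q(x_c)\right)
        = -\ln(0.6) \approx 0.5108
```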
6. Implementing the cross-entropy function with JAX
import jax.numpy as jnp

def cross_entropy(y_true, y_pred, eps=1e-7):
    """Cross-entropy function.

    :param y_true: true labels
    :param y_pred: labels predicted by the neural network
    :param eps: small positive constant that keeps the argument of the
        logarithm away from 0, improving the numerical stability of log
    :return: cross-entropy, rounded to 4 decimal places
    """
    y_true = jnp.array(y_true)
    y_pred = jnp.array(y_pred)
    res = -jnp.sum(y_true * jnp.log(y_pred + eps), axis=-1)
    return jnp.round(res, 4)

# Predicted probability distribution
y_pred = [0.1, 0.05, 0.6, 0.0, 0.05, 0.1, 0.0, 0.1, 0.0, 0.0]
# True probability distribution
y_true = [0, 0, 1, 0, 0, 0, 0, 0, 0, 0]
# The cross-entropy is 0.5108
print(cross_entropy(y_true, y_pred))
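When the network produces raw scores rather than probabilities, the softmax and the logarithm can be combined through `jax.nn.log_softmax`, which avoids taking the log of a value near 0 and makes the `eps` term unnecessary. The following is a minimal sketch under that assumption; `cross_entropy_from_logits` and the sample values are hypothetical and not part of the original post.

```python
import jax
import jax.numpy as jnp
import jax.nn as nn

def cross_entropy_from_logits(y_true, logits):
    """Cross-entropy computed from raw scores (hypothetical helper)."""
    y_true = jnp.asarray(y_true, dtype=jnp.float32)
    logits = jnp.asarray(logits, dtype=jnp.float32)
    # log_softmax computes log(softmax(logits)) in a numerically stable way
    log_probs = nn.log_softmax(logits, axis=-1)
    return -jnp.sum(y_true * log_probs, axis=-1)

# Illustrative raw scores and a one-hot target for class 0
logits = jnp.array([2.0, 1.0, 0.1])
y_true = jnp.array([1.0, 0.0, 0.0])

loss = cross_entropy_from_logits(y_true, logits)
print(loss)
# Same value via the two-step route: softmax first, then log
reference = -jnp.log(nn.softmax(logits)[0])
print(reference)

# Gradient with respect to the logits; for a one-hot target it
# equals softmax(logits) - y_true
grad = jax.grad(cross_entropy_from_logits, argnums=1)(y_true, logits)
print(grad)
print(nn.softmax(logits) - y_true)
```

The last two printed arrays should agree, which is the usual reason softmax and cross-entropy are paired: the gradient reduces to the difference between predicted and true distributions.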