TensorFlow Loss Functions
2022-07-27 08:45:00 【qq_27390023】
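A loss function, also called a cost function or objective function, measures the difference between the true values and the predicted values. Training optimizes the model to minimize this difference. In Keras, the loss is passed to model.compile(), either as a loss object or by its string name.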
from tensorflow import keras
from tensorflow.keras import layers
model = keras.Sequential()
model.add(layers.Dense(64, kernel_initializer='uniform', input_shape=(10,)))
model.add(layers.Activation('softmax'))
# The softmax layer already outputs probabilities (not logits), so from_logits must be False
loss_function = keras.losses.SparseCategoricalCrossentropy(from_logits=False)
model.compile(loss=loss_function, optimizer='adam')
# Equivalent: pass the loss by its string name
model.compile(loss='sparse_categorical_crossentropy', optimizer='adam')
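The from_logits flag has to match the output layer. With from_logits=True, Keras applies softmax to the model output itself, so combining it with a softmax activation applies softmax twice. The pure-Python sketch below only mimics the math. The helpers softmax, sparse_ce_from_probs and sparse_ce_from_logits are hypothetical, not TensorFlow functions. The sketch shows that the two correct pairings give the same loss, while the double-softmax pairing does not.

```python
import math


def softmax(logits):
    """Turn raw scores (logits) into probabilities that sum to 1."""
    m = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def sparse_ce_from_probs(label, probs):
    """Cross-entropy for an integer label when the inputs are already probabilities."""
    return -math.log(probs[label])


def sparse_ce_from_logits(label, logits):
    """Cross-entropy for an integer label when the inputs are logits.

    Uses log-softmax directly: -log p_k = logsumexp(z) - z_k.
    """
    m = max(logits)
    logsumexp = m + math.log(sum(math.exp(z - m) for z in logits))
    return logsumexp - logits[label]


logits = [2.0, 1.0, 0.1]
label = 0
probs = softmax(logits)

correct_logits = sparse_ce_from_logits(label, logits)  # like from_logits=True on logits
correct_probs = sparse_ce_from_probs(label, probs)     # like from_logits=False on softmax output
double_softmax = sparse_ce_from_logits(label, probs)   # the mistake: from_logits=True on softmax output

print(round(correct_logits, 4), round(correct_probs, 4), round(double_softmax, 4))
```

A softmax output layer should therefore be paired with from_logits=False. Alternatively, drop the softmax activation and keep from_logits=True, which the Keras documentation describes as potentially more numerically stable.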
### 1. Binary cross-entropy (BinaryCrossentropy)
# Targets lie between 0 and 1; the model's output layer uses a sigmoid activation
import tensorflow as tf
y_true = [[0., 1.], [0.2, 0.8], [0.3, 0.7], [0.4, 0.6]]
y_pred = [[0.6, 0.4], [0.4, 0.6], [0.6, 0.4], [0.8, 0.2]]
bce = tf.keras.losses.BinaryCrossentropy(reduction='sum_over_batch_size')
print(bce(y_true, y_pred).numpy()) # 0.839445
bce = tf.keras.losses.BinaryCrossentropy(reduction='sum')
print(bce(y_true, y_pred).numpy()) # 3.35778
bce = tf.keras.losses.BinaryCrossentropy(reduction='none')
print(bce(y_true, y_pred).numpy()) # [0.9162905 0.5919184 0.79465103 1.0549198 ]
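The pure-Python re-computation below shows where 0.839445, 3.35778 and the per-sample values come from. It is a sketch under assumptions. It clips predictions to [EPS, 1 - EPS], which approximates what Keras does; EPS = 1e-7 is assumed. It then averages the element-wise losses over the last axis to get one value per sample, and finally applies the reduction. binary_crossentropy is a hypothetical helper, not the TensorFlow API.

```python
import math

EPS = 1e-7  # assumed clipping constant, keeps log() away from 0


def binary_crossentropy(y_true, y_pred, reduction="sum_over_batch_size"):
    """Element-wise BCE, averaged over the last axis, then reduced over the batch."""
    per_sample = []
    for t_row, p_row in zip(y_true, y_pred):
        terms = []
        for t, p in zip(t_row, p_row):
            p = min(max(p, EPS), 1.0 - EPS)  # clip the probability
            terms.append(-(t * math.log(p) + (1.0 - t) * math.log(1.0 - p)))
        per_sample.append(sum(terms) / len(terms))  # mean over the last axis
    if reduction == "none":
        return per_sample
    total = sum(per_sample)
    return total if reduction == "sum" else total / len(per_sample)


y_true = [[0., 1.], [0.2, 0.8], [0.3, 0.7], [0.4, 0.6]]
y_pred = [[0.6, 0.4], [0.4, 0.6], [0.6, 0.4], [0.8, 0.2]]
print(binary_crossentropy(y_true, y_pred))
print(binary_crossentropy(y_true, y_pred, reduction="sum"))
print(binary_crossentropy(y_true, y_pred, reduction="none"))
```

'sum_over_batch_size' divides the summed loss by the number of samples (4 here). That is why the default result is exactly a quarter of the 'sum' result.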
### 2. Categorical cross-entropy (CategoricalCrossentropy)
# Labels are one-hot encoded; the output activation is softmax
y_true = [[0, 1, 0], [0, 0, 1]]
y_pred = [[0.05, 0.95, 0], [0.1, 0.8, 0.1]]
cce = tf.keras.losses.CategoricalCrossentropy()
print(cce(y_true, y_pred).numpy())
print(cce(y_true, y_pred, sample_weight=tf.constant([0.3, 0.7])).numpy())
cce = tf.keras.losses.CategoricalCrossentropy(reduction='sum')
print(cce(y_true, y_pred).numpy())
cce = tf.keras.losses.CategoricalCrossentropy(reduction='none')
print(cce(y_true, y_pred).numpy())
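The print calls in this section have no recorded outputs, but the values can be worked out by hand. Each sample's loss is -Σ y_true·log(y_pred): here -log(0.95) ≈ 0.0513 and -log(0.1) ≈ 2.3026. The pure-Python sketch below uses a hypothetical helper, categorical_crossentropy. It rescales each prediction row to sum to 1 and clips with an assumed EPS = 1e-7. It also shows how sample_weight interacts with 'sum_over_batch_size': the weighted losses are summed and divided by the number of samples, not by the sum of the weights.

```python
import math

EPS = 1e-7  # assumed clipping constant


def categorical_crossentropy(y_true, y_pred, sample_weight=None,
                             reduction="sum_over_batch_size"):
    """-sum(y_true * log(y_pred)) per sample, optional weights, then a batch reduction."""
    per_sample = []
    for t_row, p_row in zip(y_true, y_pred):
        row_sum = sum(p_row)
        probs = [min(max(p / row_sum, EPS), 1.0 - EPS) for p in p_row]  # normalize + clip
        per_sample.append(-sum(t * math.log(p) for t, p in zip(t_row, probs)))
    if sample_weight is not None:
        per_sample = [w * loss for w, loss in zip(sample_weight, per_sample)]
    if reduction == "none":
        return per_sample
    total = sum(per_sample)
    # 'sum_over_batch_size' divides by the number of samples, even when weights are given
    return total if reduction == "sum" else total / len(per_sample)


y_true = [[0, 1, 0], [0, 0, 1]]
y_pred = [[0.05, 0.95, 0], [0.1, 0.8, 0.1]]
print(categorical_crossentropy(y_true, y_pred))                            # mean
print(categorical_crossentropy(y_true, y_pred, sample_weight=[0.3, 0.7]))  # weighted
print(categorical_crossentropy(y_true, y_pred, reduction="sum"))
print(categorical_crossentropy(y_true, y_pred, reduction="none"))
```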
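### 3. Sparse categorical cross-entropy (SparseCategoricalCrossentropy)
# Labels are integer class indices; the output activation is softmax
# e.g. for MNIST, if y is not one-hot encoded, compile the model with sparse_categorical_crossentropy as the loss before fitting
# Other discrete class labels can first be converted to integer indices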
from sklearn import preprocessing
enc = preprocessing.OrdinalEncoder()
X = [['class1'], ['class2'],['class3']]
enc.fit(X)
X_transform = enc.transform(X)
print(X_transform)
y_true = [1, 2, 2]
y_pred = [[0.05, 0.95, 0], [0.1, 0.8, 0.1],[0.05, 0.05, 0.9]]
# Using 'auto'/'sum_over_batch_size' reduction type.
scce = tf.keras.losses.SparseCategoricalCrossentropy()
print(scce(y_true, y_pred).numpy())
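Sparse categorical cross-entropy computes the same quantity as CategoricalCrossentropy. The only difference is that it reads the true class from an integer index instead of a one-hot vector. The pure-Python sketch below checks this on the section's data, using hypothetical helpers one_hot, sparse_categorical_crossentropy and categorical_crossentropy and an assumed clipping constant EPS = 1e-7. The per-sample losses are -log(0.95), -log(0.1) and -log(0.9), whose mean is about 0.8197.

```python
import math

EPS = 1e-7  # assumed clipping constant


def one_hot(index, num_classes):
    """Integer class index -> one-hot vector."""
    return [1.0 if k == index else 0.0 for k in range(num_classes)]


def sparse_categorical_crossentropy(labels, y_pred):
    """Mean over the batch of -log(y_pred[label])."""
    losses = [-math.log(min(max(row[k], EPS), 1.0 - EPS)) for k, row in zip(labels, y_pred)]
    return sum(losses) / len(losses)


def categorical_crossentropy(one_hot_labels, y_pred):
    """Mean over the batch of -sum(y_true * log(y_pred))."""
    losses = [-sum(t * math.log(min(max(p, EPS), 1.0 - EPS)) for t, p in zip(t_row, p_row))
              for t_row, p_row in zip(one_hot_labels, y_pred)]
    return sum(losses) / len(losses)


y_true = [1, 2, 2]
y_pred = [[0.05, 0.95, 0], [0.1, 0.8, 0.1], [0.05, 0.05, 0.9]]
sparse = sparse_categorical_crossentropy(y_true, y_pred)
dense = categorical_crossentropy([one_hot(k, 3) for k in y_true], y_pred)
print(sparse, dense)
```

Integer labels, such as the OrdinalEncoder output earlier in this section once flattened, can therefore be used without one-hot encoding them first.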
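### 4. KL divergence (KLDivergence)
# Relative entropy, also called KL divergence, measures how one probability distribution differs from another.
# It is not a true distance because it is asymmetric. It is commonly used when the model directly regresses a discretely sampled, continuous output distribution.
# loss = y_true * log(y_true / y_pred), summed over the last axis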
y_true = [[0, 1], [0, 0]]
y_pred = [[0.6, 0.4], [0.4, 0.6]]
kl = tf.keras.losses.KLDivergence()
print(kl(y_true, y_pred).numpy())
print(kl(y_true, y_pred, sample_weight=[0.8, 0.2]).numpy())
kl = tf.keras.losses.KLDivergence(reduction='sum')
print(kl(y_true, y_pred).numpy())
kl = tf.keras.losses.KLDivergence(reduction='none')
print(kl(y_true, y_pred).numpy())
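The formula in the comment, summed over the last axis, can be checked in pure Python. Keras clips both y_true and y_pred to [EPS, 1] before taking the log (EPS is assumed to be 1e-7 here). This is why a row of all-zero targets such as [0, 0] yields a tiny negative number instead of exactly 0. The second part illustrates why KL divergence is not a true distance: swapping the arguments changes the result. kl_divergence is a hypothetical helper, not the TensorFlow API.

```python
import math

EPS = 1e-7  # assumed clipping constant


def kl_divergence(y_true, y_pred):
    """Per-sample KL divergence: sum(y_true * log(y_true / y_pred)) over the last axis."""
    result = []
    for t_row, p_row in zip(y_true, y_pred):
        t_row = [min(max(t, EPS), 1.0) for t in t_row]  # clip targets to [EPS, 1]
        p_row = [min(max(p, EPS), 1.0) for p in p_row]  # clip predictions to [EPS, 1]
        result.append(sum(t * math.log(t / p) for t, p in zip(t_row, p_row)))
    return result


per_sample = kl_divergence([[0, 1], [0, 0]], [[0.6, 0.4], [0.4, 0.6]])
print(per_sample)                         # like reduction='none'
print(sum(per_sample) / len(per_sample))  # like 'sum_over_batch_size'

# Asymmetry: KL(P || Q) != KL(Q || P)
p, q = [[0.5, 0.5]], [[0.9, 0.1]]
print(kl_divergence(p, q)[0], kl_divergence(q, p)[0])
```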
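### 5. Poisson loss (Poisson)
# Poisson loss suits datasets whose targets follow a Poisson distribution (e.g. counts); loss = y_pred - y_true * log(y_pred)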
y_true = [[0., 1.], [0., 0.]]
y_pred = [[1., 1.], [0., 0.]]
p = tf.keras.losses.Poisson()
print(p(y_true, y_pred).numpy()) # mean loss
print(p(y_true, y_pred, sample_weight=[0.8, 0.2]).numpy()) # with per-sample weights
p = tf.keras.losses.Poisson(reduction='sum') # sum over the batch
print(p(y_true, y_pred).numpy())
p = tf.keras.losses.Poisson(reduction='none') # one loss value per sample
print(p(y_true, y_pred).numpy())
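Keras's Poisson loss is the mean over the last axis of y_pred - y_true * log(y_pred + EPS). EPS is assumed to be 1e-7 here; it is added so that log(0) never occurs. The pure-Python sketch below uses a hypothetical helper, poisson_loss, to reproduce the four calls of this section. With these inputs the first sample's loss is about 1 and the second's is 0. The mean is therefore 0.5, the weighted mean is (0.8·1 + 0.2·0)/2 = 0.4, and the sum is 1.0.

```python
import math

EPS = 1e-7  # assumed constant added inside the log


def poisson_loss(y_true, y_pred, sample_weight=None, reduction="sum_over_batch_size"):
    """Per-sample mean of y_pred - y_true * log(y_pred + EPS), then a batch reduction."""
    per_sample = []
    for t_row, p_row in zip(y_true, y_pred):
        terms = [p - t * math.log(p + EPS) for t, p in zip(t_row, p_row)]
        per_sample.append(sum(terms) / len(terms))  # mean over the last axis
    if sample_weight is not None:
        per_sample = [w * loss for w, loss in zip(sample_weight, per_sample)]
    if reduction == "none":
        return per_sample
    total = sum(per_sample)
    return total if reduction == "sum" else total / len(per_sample)


y_true = [[0., 1.], [0., 0.]]
y_pred = [[1., 1.], [0., 0.]]
print(poisson_loss(y_true, y_pred))
print(poisson_loss(y_true, y_pred, sample_weight=[0.8, 0.2]))
print(poisson_loss(y_true, y_pred, reduction="sum"))
print(poisson_loss(y_true, y_pred, reduction="none"))
```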
### 6. Mean squared error (MeanSquaredError)
# For regression: MeanSquaredError computes the mean of the squared differences between true and predicted values.
y_true = [[0., 1.], [0., 0.]]
y_pred = [[1., 1.], [1., 0.]]
mse = tf.keras.losses.MeanSquaredError()
print(mse(y_true, y_pred).numpy())
print(mse(y_true, y_pred, sample_weight=[0.8, 0.2]).numpy())
mse = tf.keras.losses.MeanSquaredError(reduction='sum')
print(mse(y_true, y_pred).numpy())
mse = tf.keras.losses.MeanSquaredError(reduction='none')
print(mse(y_true, y_pred).numpy())
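MSE averages the squared differences over the last axis to get one value per sample, then reduces over the batch. In this example each sample has exactly one element off by 1, so both per-sample losses are (1² + 0²)/2 = 0.5. The mean is 0.5 and the sum is 1.0. With sample_weight=[0.8, 0.2], the result is (0.8·0.5 + 0.2·0.5)/2 = 0.25. The pure-Python sketch below uses a hypothetical helper, mean_squared_error, and spells out each reduction.

```python
def mean_squared_error(y_true, y_pred):
    """Per-sample MSE: mean of (y_true - y_pred)**2 over the last axis."""
    return [sum((t - p) ** 2 for t, p in zip(t_row, p_row)) / len(t_row)
            for t_row, p_row in zip(y_true, y_pred)]


y_true = [[0., 1.], [0., 0.]]
y_pred = [[1., 1.], [1., 0.]]
per_sample = mean_squared_error(y_true, y_pred)  # like reduction='none'
mean = sum(per_sample) / len(per_sample)         # like 'sum_over_batch_size'
total = sum(per_sample)                          # like 'sum'
weights = [0.8, 0.2]
# weighted losses are summed, then divided by the number of samples
weighted = sum(w * loss for w, loss in zip(weights, per_sample)) / len(per_sample)
print(per_sample, mean, total, weighted)
```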
### 7. Mean squared logarithmic error (MeanSquaredLogarithmicError)
# loss = square(log(y_true + 1.) - log(y_pred + 1.))
y_true = [[0., 1.], [0., 0.]]
y_pred = [[1., 1.], [1., 0.]]
msle = tf.keras.losses.MeanSquaredLogarithmicError()
print(msle(y_true, y_pred).numpy())
print(msle(y_true, y_pred, sample_weight=[0.7, 0.3]).numpy())
msle = tf.keras.losses.MeanSquaredLogarithmicError(reduction='sum')
print(msle(y_true, y_pred).numpy())
msle = tf.keras.losses.MeanSquaredLogarithmicError(reduction='none')
print(msle(y_true, y_pred).numpy())
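MSLE is MSE computed on log(1 + x) values. Because it compares values on a log scale, an absolute error between large values costs much less than the same absolute error between small values. For this section's data, each sample's loss is (ln 2)²/2 ≈ 0.2402. The mean is therefore ≈ 0.2402, the sum ≈ 0.4805, and with weights [0.7, 0.3] the result is ≈ 0.1201. The pure-Python sketch below uses a hypothetical helper, msle. It applies max(x, EPS) before the log, as Keras does, with EPS = 1e-7 assumed.

```python
import math

EPS = 1e-7  # assumed clipping constant


def msle(y_true, y_pred):
    """Per-sample MSLE: mean over the last axis of (log(1 + y_pred) - log(1 + y_true))**2."""
    per_sample = []
    for t_row, p_row in zip(y_true, y_pred):
        sq = [(math.log1p(max(p, EPS)) - math.log1p(max(t, EPS))) ** 2
              for t, p in zip(t_row, p_row)]
        per_sample.append(sum(sq) / len(sq))
    return per_sample


y_true = [[0., 1.], [0., 0.]]
y_pred = [[1., 1.], [1., 0.]]
per_sample = msle(y_true, y_pred)
print(per_sample)                          # like reduction='none'
print(sum(per_sample) / len(per_sample))   # like 'sum_over_batch_size'
print(sum(per_sample))                     # like 'sum'
weights = [0.7, 0.3]
print(sum(w * loss for w, loss in zip(weights, per_sample)) / len(per_sample))

# Log scale: the same absolute error of 10 costs far less between large values
print(msle([[1.]], [[11.]])[0], msle([[91.]], [[101.]])[0])
```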
### 8. Custom loss functions
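def get_loss(y_true, y_pred):
    # Keras calls it as get_loss(y_true, y_pred), e.g. after model.compile(loss=get_loss, optimizer='adam'), and expects one loss value per sample
    y_true = tf.cast(y_true, y_pred.dtype)
    return tf.reduce_mean(tf.square(y_true - y_pred), axis=-1)  # illustrative body: per-sample MSE
References: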
https://tensorflow.google.cn/api_docs/python/tf/keras/losses/Loss