TensorFlow Loss Functions
2022-07-27 08:45:00 【qq_27390023】
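A loss function, also called a cost function or objective function, measures the difference between the true values and the model's predictions. Training optimizes the model to minimize this difference.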
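A loss can be passed to `compile()` either as a loss object or by its string name:

```python
from tensorflow import keras
from tensorflow.keras import layers

model = keras.Sequential()
model.add(layers.Dense(64, kernel_initializer='uniform', input_shape=(10,)))
model.add(layers.Activation('softmax'))

# Option 1: pass a loss instance. The output layer already applies softmax,
# so y_pred holds probabilities and from_logits must be False (the default).
loss_function = keras.losses.SparseCategoricalCrossentropy(from_logits=False)
model.compile(loss=loss_function, optimizer='adam')

# Option 2: pass the loss by its string name (same default settings)
model.compile(loss='sparse_categorical_crossentropy', optimizer='adam')
```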
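The `from_logits` flag tells a Keras crossentropy loss whether `y_pred` holds raw scores (logits) or probabilities. The plain-Python sketch below (no TensorFlow; the helper names and logit values are made up for illustration) mimics the math to show why a softmax output layer should be paired with `from_logits=False`: with `from_logits=True` the loss applies softmax a second time and reports a different value (larger, in this example).

```python
import math


def softmax(logits):
    # Subtract the max before exponentiating for numerical stability
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def sparse_ce_from_probs(label, probs):
    # Like from_logits=False: y_pred is already a probability vector
    return -math.log(probs[label])


def sparse_ce_from_logits(label, logits):
    # Like from_logits=True: softmax is applied inside the loss
    return -math.log(softmax(logits)[label])


logits = [2.0, 1.0, 0.1]  # hypothetical raw scores from the last Dense layer
label = 0
probs = softmax(logits)   # what a softmax output layer would emit

right_a = sparse_ce_from_probs(label, probs)    # softmax layer + from_logits=False
right_b = sparse_ce_from_logits(label, logits)  # no softmax layer + from_logits=True
wrong = sparse_ce_from_logits(label, probs)     # softmax layer + from_logits=True
print(right_a, right_b, wrong)  # ≈ 0.417 0.417 0.802
```

Both correct pairings give the same loss; the mismatched one squashes the probabilities through softmax again, flattening them and distorting the loss and its gradients.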
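### 1. Binary Crossentropy (BinaryCrossentropy)

```python
# Targets lie in [0, 1]; the model's output layer uses a sigmoid activation,
# so y_pred holds probabilities (from_logits=False, the default).
import tensorflow as tf

y_true = [[0., 1.], [0.2, 0.8], [0.3, 0.7], [0.4, 0.6]]
y_pred = [[0.6, 0.4], [0.4, 0.6], [0.6, 0.4], [0.8, 0.2]]

# 'sum_over_batch_size': mean of the per-sample losses
bce = tf.keras.losses.BinaryCrossentropy(reduction='sum_over_batch_size')
print(bce(y_true, y_pred).numpy())  # 0.839445

# 'sum': sum of the per-sample losses
bce = tf.keras.losses.BinaryCrossentropy(reduction='sum')
print(bce(y_true, y_pred).numpy())  # 3.35778

# 'none': one loss value per sample
bce = tf.keras.losses.BinaryCrossentropy(reduction='none')
print(bce(y_true, y_pred).numpy())  # [0.9162905 0.5919184 0.79465103 1.0549198 ]
```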
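To see what `BinaryCrossentropy` computes, the per-sample formula can be written out by hand. This is a pure-Python sketch, not Keras' implementation: it assumes Keras' default epsilon of 1e-7 for clipping and ignores Keras' other small numerical safeguards, so it should agree with the printed values above to about five decimal places.

```python
import math

EPS = 1e-7  # assumed clipping bound, matching Keras' default epsilon


def binary_crossentropy(y_true, y_pred):
    # Per-sample loss: mean over the last axis of -[y*log(p) + (1-y)*log(1-p)]
    losses = []
    for t_row, p_row in zip(y_true, y_pred):
        terms = []
        for t, p in zip(t_row, p_row):
            p = min(max(p, EPS), 1 - EPS)  # clip to avoid log(0)
            terms.append(-(t * math.log(p) + (1 - t) * math.log(1 - p)))
        losses.append(sum(terms) / len(terms))
    return losses


y_true = [[0., 1.], [0.2, 0.8], [0.3, 0.7], [0.4, 0.6]]
y_pred = [[0.6, 0.4], [0.4, 0.6], [0.6, 0.4], [0.8, 0.2]]

per_sample = binary_crossentropy(y_true, y_pred)  # like reduction='none'
total = sum(per_sample)                           # like reduction='sum'
mean = total / len(per_sample)                    # like reduction='sum_over_batch_size'
print(per_sample, total, mean)
```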
### 2. Categorical Crossentropy (CategoricalCrossentropy)

```python
# Labels are one-hot encoded; the output layer uses softmax.
import tensorflow as tf

y_true = [[0, 1, 0], [0, 0, 1]]
y_pred = [[0.05, 0.95, 0], [0.1, 0.8, 0.1]]

cce = tf.keras.losses.CategoricalCrossentropy()
print(cce(y_true, y_pred).numpy())  # ≈ 1.177 (mean over the batch)
print(cce(y_true, y_pred, sample_weight=tf.constant([0.3, 0.7])).numpy())  # ≈ 0.814

cce = tf.keras.losses.CategoricalCrossentropy(reduction='sum')
print(cce(y_true, y_pred).numpy())  # ≈ 2.354

cce = tf.keras.losses.CategoricalCrossentropy(reduction='none')
print(cce(y_true, y_pred).numpy())  # ≈ [0.0513 2.303]
```
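A hand-written version of categorical crossentropy makes the role of `sample_weight` explicit: each sample's loss is multiplied by its weight, and the default `'sum_over_batch_size'` reduction still divides by the number of samples rather than by the sum of the weights. This is an illustrative pure-Python sketch (the helper name and the 1e-7 clipping bound are assumptions) that should reproduce the Keras results for the data above up to float32 rounding.

```python
import math

EPS = 1e-7  # assumed clipping bound, matching Keras' default epsilon


def categorical_crossentropy(y_true, y_pred):
    # Per-sample loss: -sum_k y_k * log(p_k), with p clipped away from 0 and 1
    return [-sum(t * math.log(min(max(p, EPS), 1 - EPS)) for t, p in zip(t_row, p_row))
            for t_row, p_row in zip(y_true, y_pred)]


y_true = [[0, 1, 0], [0, 0, 1]]
y_pred = [[0.05, 0.95, 0], [0.1, 0.8, 0.1]]

per_sample = categorical_crossentropy(y_true, y_pred)  # like reduction='none'
mean = sum(per_sample) / len(per_sample)              # default reduction

# sample_weight multiplies each sample's loss; the default reduction still
# divides by the number of samples (2), not by the sum of the weights
weights = [0.3, 0.7]
weighted = sum(w * l for w, l in zip(weights, per_sample)) / len(per_sample)
print(per_sample, mean, weighted)  # ≈ [0.0513, 2.3026] 1.1769 0.8136
```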
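### 3. Sparse Categorical Crossentropy (SparseCategoricalCrossentropy)

```python
# Labels are integer class indices; the output layer uses softmax.
# For example, if the MNIST labels y are not one-hot encoded, use
# sparse_categorical_crossentropy as the loss when calling fit().
# Other discrete class labels can first be converted to integer indices:
from sklearn import preprocessing
import tensorflow as tf

enc = preprocessing.OrdinalEncoder()
X = [['class1'], ['class2'], ['class3']]
enc.fit(X)
X_transform = enc.transform(X)
print(X_transform)  # class1, class2, class3 become 0., 1., 2.

y_true = [1, 2, 2]
y_pred = [[0.05, 0.95, 0], [0.1, 0.8, 0.1], [0.05, 0.05, 0.9]]
# Using 'auto'/'sum_over_batch_size' reduction type.
scce = tf.keras.losses.SparseCategoricalCrossentropy()
print(scce(y_true, y_pred).numpy())  # ≈ 0.8197
```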
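Sparse and one-hot labels describe the same targets, so the two crossentropy losses agree: an integer label `k` simply picks out `-log(p[k])`. The pure-Python sketch below (helper names are hypothetical) checks this on the data from the example above.

```python
import math


def sparse_categorical_crossentropy(labels, y_pred):
    # An integer label k selects the predicted probability of class k: -log(p[k])
    return [-math.log(p_row[k]) for k, p_row in zip(labels, y_pred)]


def one_hot(labels, num_classes):
    # Convert integer labels into one-hot rows
    return [[1.0 if i == k else 0.0 for i in range(num_classes)] for k in labels]


def categorical_crossentropy(y_true, y_pred):
    # -sum_k y_k * log(p_k); entries with y_k == 0 contribute nothing, so skip them
    return [-sum(t * math.log(p) for t, p in zip(t_row, p_row) if t > 0)
            for t_row, p_row in zip(y_true, y_pred)]


y_true = [1, 2, 2]
y_pred = [[0.05, 0.95, 0], [0.1, 0.8, 0.1], [0.05, 0.05, 0.9]]

sparse = sparse_categorical_crossentropy(y_true, y_pred)
dense = categorical_crossentropy(one_hot(y_true, 3), y_pred)
mean = sum(sparse) / len(sparse)
print(mean, sparse == dense)  # ≈ 0.8197 True
```

The sparse form avoids materializing one-hot vectors, which matters when the number of classes is large.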
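### 4. KL Divergence (KLDivergence)

```python
# Relative entropy, also called KL divergence, measures how one probability
# distribution differs from another (it is asymmetric, so not a true distance).
# It is typically used to regress directly onto a discretely sampled
# continuous output distribution.
# loss = y_true * log(y_true / y_pred), summed over the last axis
import tensorflow as tf

y_true = [[0, 1], [0, 0]]
y_pred = [[0.6, 0.4], [0.4, 0.6]]

kl = tf.keras.losses.KLDivergence()
print(kl(y_true, y_pred).numpy())  # ≈ 0.458
print(kl(y_true, y_pred, sample_weight=[0.8, 0.2]).numpy())  # ≈ 0.366

kl = tf.keras.losses.KLDivergence(reduction='sum')
print(kl(y_true, y_pred).numpy())  # ≈ 0.916

kl = tf.keras.losses.KLDivergence(reduction='none')
print(kl(y_true, y_pred).numpy())  # ≈ [0.916 -3.08e-06]
```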
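Writing KL divergence out by hand also shows a curious detail: a row of `y_true` that is all zeros yields a tiny negative per-sample loss instead of exactly 0. Keras clips both `y_true` and `y_pred` to `[epsilon, 1]` before taking the log (epsilon defaults to 1e-7), so the zero entries become 1e-7. The pure-Python sketch below mirrors that clipping; it approximates the Keras computation rather than copying it.

```python
import math

EPS = 1e-7  # Keras' default epsilon


def kl_divergence(y_true, y_pred):
    # Per-sample loss: sum over the last axis of y_true * log(y_true / y_pred),
    # after clipping both inputs to [EPS, 1]
    losses = []
    for t_row, p_row in zip(y_true, y_pred):
        total = 0.0
        for t, p in zip(t_row, p_row):
            t = min(max(t, EPS), 1.0)
            p = min(max(p, EPS), 1.0)
            total += t * math.log(t / p)
        losses.append(total)
    return losses


y_true = [[0, 1], [0, 0]]
y_pred = [[0.6, 0.4], [0.4, 0.6]]

per_sample = kl_divergence(y_true, y_pred)
mean = sum(per_sample) / len(per_sample)
print(per_sample, mean)  # ≈ [0.916, -3.08e-06] 0.458
```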
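### 5. Poisson Loss (Poisson)

```python
# The Poisson loss suits data that follow a Poisson distribution (e.g. counts).
# loss = y_pred - y_true * log(y_pred), averaged over the last axis
import tensorflow as tf

y_true = [[0., 1.], [0., 0.]]
y_pred = [[1., 1.], [0., 0.]]

p = tf.keras.losses.Poisson()
print(p(y_true, y_pred).numpy())  # mean loss, ≈ 0.5
print(p(y_true, y_pred, sample_weight=[0.8, 0.2]).numpy())  # with sample weights, ≈ 0.4

p = tf.keras.losses.Poisson(reduction='sum')  # sum over samples
print(p(y_true, y_pred).numpy())  # ≈ 1.0

p = tf.keras.losses.Poisson(reduction='none')  # one loss value per sample
print(p(y_true, y_pred).numpy())  # ≈ [1. 0.]
```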
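The Poisson formula is short enough to check by hand. The pure-Python sketch below assumes Keras' form `y_pred - y_true * log(y_pred + epsilon)` with epsilon = 1e-7, averaged over the last axis; the epsilon keeps `log(0)` finite for the zero predictions in the second sample.

```python
import math

EPS = 1e-7  # Keras' default epsilon


def poisson(y_true, y_pred):
    # Per-sample loss: mean over the last axis of y_pred - y_true * log(y_pred + EPS)
    return [sum(p - t * math.log(p + EPS) for t, p in zip(t_row, p_row)) / len(p_row)
            for t_row, p_row in zip(y_true, y_pred)]


y_true = [[0., 1.], [0., 0.]]
y_pred = [[1., 1.], [0., 0.]]

per_sample = poisson(y_true, y_pred)
mean = sum(per_sample) / len(per_sample)
# Weighted per-sample losses, still divided by the number of samples
weighted = (0.8 * per_sample[0] + 0.2 * per_sample[1]) / len(per_sample)
print(per_sample, mean, weighted)  # ≈ [1.0, 0.0] 0.5 0.4
```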
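### 6. Mean Squared Error (MeanSquaredError)

```python
# For regression: MeanSquaredError computes the mean of the squared
# differences between the true and predicted values.
# loss = mean(square(y_true - y_pred)), averaged over the last axis
import tensorflow as tf

y_true = [[0., 1.], [0., 0.]]
y_pred = [[1., 1.], [1., 0.]]

mse = tf.keras.losses.MeanSquaredError()
print(mse(y_true, y_pred).numpy())  # 0.5
print(mse(y_true, y_pred, sample_weight=[0.8, 0.2]).numpy())  # 0.25

mse = tf.keras.losses.MeanSquaredError(reduction='sum')
print(mse(y_true, y_pred).numpy())  # 1.0

mse = tf.keras.losses.MeanSquaredError(reduction='none')
print(mse(y_true, y_pred).numpy())  # [0.5 0.5]
```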
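The three `reduction` modes and `sample_weight` used throughout this article follow one pattern, sketched here for MSE in pure Python. The `apply_reduction` helper is hypothetical and only mirrors the behavior seen in the examples above: weights scale each per-sample loss, `'sum'` adds them, and `'sum_over_batch_size'` divides the sum by the number of samples.

```python
def mean_squared_error(y_true, y_pred):
    # Per-sample loss: mean over the last axis of (y_true - y_pred) ** 2
    return [sum((t - p) ** 2 for t, p in zip(t_row, p_row)) / len(t_row)
            for t_row, p_row in zip(y_true, y_pred)]


def apply_reduction(per_sample, reduction="sum_over_batch_size", sample_weight=None):
    # Hypothetical helper mirroring the reduction modes used in this article
    if sample_weight is not None:
        per_sample = [w * l for w, l in zip(sample_weight, per_sample)]
    if reduction == "none":
        return per_sample
    total = sum(per_sample)
    if reduction == "sum":
        return total
    return total / len(per_sample)  # 'sum_over_batch_size'


y_true = [[0., 1.], [0., 0.]]
y_pred = [[1., 1.], [1., 0.]]
losses = mean_squared_error(y_true, y_pred)

print(apply_reduction(losses))                            # 0.5
print(apply_reduction(losses, sample_weight=[0.8, 0.2]))  # 0.25
print(apply_reduction(losses, reduction="sum"))           # 1.0
print(apply_reduction(losses, reduction="none"))          # [0.5, 0.5]
```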
### 7. Mean Squared Logarithmic Error (MeanSquaredLogarithmicError)

```python
# loss = square(log(y_true + 1.) - log(y_pred + 1.)), averaged over the last axis
import tensorflow as tf

y_true = [[0., 1.], [0., 0.]]
y_pred = [[1., 1.], [1., 0.]]

msle = tf.keras.losses.MeanSquaredLogarithmicError()
print(msle(y_true, y_pred).numpy())  # ≈ 0.240
print(msle(y_true, y_pred, sample_weight=[0.7, 0.3]).numpy())  # ≈ 0.120

msle = tf.keras.losses.MeanSquaredLogarithmicError(reduction='sum')
print(msle(y_true, y_pred).numpy())  # ≈ 0.480

msle = tf.keras.losses.MeanSquaredLogarithmicError(reduction='none')
print(msle(y_true, y_pred).numpy())  # ≈ [0.240 0.240]
```
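Because `log(y_true + 1) - log(y_pred + 1) = log((y_true + 1) / (y_pred + 1))`, MSLE depends on the ratio of the shifted values rather than on their absolute difference, which is why it is often chosen for targets spanning several orders of magnitude. The pure-Python sketch below (it omits Keras' epsilon clipping, so agreement is approximate) reproduces the example above and shows two pairs with the same ratio getting the same MSLE while their squared errors differ a hundredfold.

```python
import math


def mean_squared_logarithmic_error(y_true, y_pred):
    # Per-sample loss: mean over the last axis of (log(y_true + 1) - log(y_pred + 1)) ** 2
    # (math.log1p(x) computes log(x + 1); Keras' epsilon clipping is omitted)
    return [sum((math.log1p(t) - math.log1p(p)) ** 2 for t, p in zip(t_row, p_row))
            / len(t_row)
            for t_row, p_row in zip(y_true, y_pred)]


y_true = [[0., 1.], [0., 0.]]
y_pred = [[1., 1.], [1., 0.]]
per_sample = mean_squared_logarithmic_error(y_true, y_pred)
mean = sum(per_sample) / len(per_sample)
print(per_sample, mean)  # ≈ [0.2402, 0.2402] 0.2402

# Same ratio (y_true + 1) / (y_pred + 1) = 0.5 at two very different scales
small = mean_squared_logarithmic_error([[9.]], [[19.]])[0]
large = mean_squared_logarithmic_error([[99.]], [[199.]])[0]
print(small, large)  # both ≈ 0.4805, while the squared errors are 100 vs 10000
```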
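### 8. Custom Loss Functions

```python
# Keras calls a custom loss as fn(y_true, y_pred) and expects one loss value per sample
def get_loss(y_true, y_pred):
    pass  # placeholder: compute and return the per-sample loss here
```

Reference: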
https://tensorflow.google.cn/api_docs/python/tf/keras/losses/Loss
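Following the `tf.keras.losses.Loss` API referenced above, a custom loss can be supplied either as a plain function of `(y_true, y_pred)` or as a `Loss` subclass that overrides `call()`. The sketch below fills in the `get_loss` stub with a mean squared error purely as a placeholder body; `MyMSE` is a hypothetical name, and the example assumes TensorFlow 2.x.

```python
import tensorflow as tf


def get_loss(y_true, y_pred):
    # Placeholder body (plain MSE): return one loss value per sample;
    # Keras applies the batch reduction afterwards
    y_true = tf.cast(y_true, y_pred.dtype)
    return tf.reduce_mean(tf.square(y_true - y_pred), axis=-1)


class MyMSE(tf.keras.losses.Loss):
    # Subclassing Loss provides reduction and sample_weight handling;
    # only call() needs to be implemented
    def call(self, y_true, y_pred):
        y_true = tf.cast(y_true, y_pred.dtype)
        return tf.reduce_mean(tf.square(y_true - y_pred), axis=-1)


y_true = tf.constant([[0., 1.], [0., 0.]])
y_pred = tf.constant([[1., 1.], [1., 0.]])

per_sample = get_loss(y_true, y_pred).numpy()  # [0.5 0.5]
batch_loss = MyMSE()(y_true, y_pred).numpy()   # 0.5 with the default reduction
print(per_sample, batch_loss)

# Either form can be passed to compile()
model = tf.keras.Sequential([tf.keras.Input(shape=(2,)), tf.keras.layers.Dense(2)])
model.compile(optimizer='adam', loss=get_loss)
model.compile(optimizer='adam', loss=MyMSE())
```

The function form is the quickest option; the subclass form is handy when the loss needs configuration stored on the object (for example, constructor arguments).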