tf.keras.layers.Attention: a summary of my understanding
2022-06-30 09:46:00 【A grain of sand in the vast sea of people】
Official documentation: https://tensorflow.google.cn/versions/r2.1/api_docs/python/tf/keras/layers/Attention
tf.keras.layers.Attention(
use_scale=False, **kwargs
)
The inputs are a query tensor of shape [batch_size, Tq, dim], a value tensor of shape [batch_size, Tv, dim], and an optional key tensor of shape [batch_size, Tv, dim]. If no key is given, value is used as the key as well. The calculation follows these steps:
- Calculate scores with shape [batch_size, Tq, Tv] as a query-key dot product: scores = tf.matmul(query, key, transpose_b=True).
- Use scores to calculate a distribution with shape [batch_size, Tq, Tv]: distribution = tf.nn.softmax(scores).
- Use distribution to create a linear combination of value with shape [batch_size, Tq, dim]: return tf.matmul(distribution, value).
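The three steps above can be sketched in plain NumPy to make the shapes explicit. This is only an illustration under stated assumptions, not the layer's actual implementation: `softmax` and `dot_product_attention` are hypothetical helper names, `np.matmul` stands in for `tf.matmul`, and the sizes (batch_size=2, Tq=3, Tv=5, dim=4) are arbitrary.

```python
import numpy as np

def softmax(x, axis=-1):
    # Subtract the maximum along the axis before exponentiating, for numerical stability.
    shifted = x - np.max(x, axis=axis, keepdims=True)
    exp = np.exp(shifted)
    return exp / np.sum(exp, axis=axis, keepdims=True)

def dot_product_attention(query, value, key=None):
    # Hypothetical helper: when no key is passed, reuse value as the key.
    if key is None:
        key = value
    # Step 1: query-key dot products -> [batch_size, Tq, Tv]
    scores = np.matmul(query, np.swapaxes(key, -1, -2))
    # Step 2: softmax over the Tv axis -> [batch_size, Tq, Tv]
    distribution = softmax(scores, axis=-1)
    # Step 3: weighted combination of value rows -> [batch_size, Tq, dim]
    return np.matmul(distribution, value), distribution

# Arbitrary illustrative sizes (assumptions, not taken from the post).
batch_size, Tq, Tv, dim = 2, 3, 5, 4
rng = np.random.default_rng(0)
query = rng.normal(size=(batch_size, Tq, dim))
value = rng.normal(size=(batch_size, Tv, dim))

output, distribution = dot_product_attention(query, value)
print(distribution.shape)  # (2, 3, 5)
print(output.shape)        # (2, 3, 4)
```

Each row of `distribution` sums to 1, so every output row is a weighted average of the rows of `value`.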
Example 1
import tensorflow as tf
import numpy as np
query = tf.convert_to_tensor(np.asarray([[[1., 1., 1., 3.]]]))
key_list = tf.convert_to_tensor(np.asarray([[[1., 1., 2., 4.], [4., 1., 1., 3.], [1., 1., 2., 1.]],
[[1., 0., 2., 1.], [1., 2., 1., 2.], [1., 0., 2., 1.]]]))
# Only [query, value] is passed, so key_list serves as both value and key.
query_value_attention_seq = tf.keras.layers.Attention()([query, key_list])
print('query shape:', query.shape)
print('key shape:', key_list.shape)
print('result 1:', query_value_attention_seq)
Result:
query shape: (1, 1, 4)
key shape: (2, 3, 4)
result 1: tf.Tensor(
[[[1.8067516 1. 1.7310829 3.730812 ]]
[[0.99999994 1.9293262 1.0353367 1.9646629 ]]], shape=(2, 1, 4), dtype=float32)
Now implement it by hand, following the steps from the documentation:
scores = tf.matmul(query, key_list, transpose_b=True)
distribution = tf.nn.softmax(scores)
result = tf.matmul(distribution, key_list)
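print('result 2:', result)
The result is as follows. The values match the layer output, which confirms that the layer does what the documented steps describe (the manual version runs in float64 here, so its printed dtype and trailing digits can differ from the float32 layer output):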
result 2: tf.Tensor(
[[[1.8067516 1. 1.7310829 3.730812 ]]
[[0.99999994 1.9293262 1.0353367 1.9646629 ]]], shape=(2, 1, 4), dtype=float32)
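As a cross-check of the numbers above, the same computation can be reproduced with NumPy alone. This is a sketch rather than the layer's code: `np.matmul` stands in for `tf.matmul`, the softmax is written out by hand, and `expected` is simply the printed output copied from the post. It also makes visible that the query has batch size 1 while the value tensor has batch size 2, so the matrix product broadcasts the single query across both batches.

```python
import numpy as np

# Inputs copied from Example 1.
query = np.array([[[1., 1., 1., 3.]]])                                   # shape (1, 1, 4)
value = np.array([[[1., 1., 2., 4.], [4., 1., 1., 3.], [1., 1., 2., 1.]],
                  [[1., 0., 2., 1.], [1., 2., 1., 2.], [1., 0., 2., 1.]]])  # shape (2, 3, 4)

# Step 1: dot products; the batch dimension of query (1) broadcasts to 2.
scores = np.matmul(query, np.swapaxes(value, -1, -2))   # shape (2, 1, 3)
# scores are [[16, 15, 7]] for the first batch and [[6, 10, 6]] for the second

# Step 2: softmax over the last axis, written out by hand.
weights = np.exp(scores - scores.max(axis=-1, keepdims=True))
weights /= weights.sum(axis=-1, keepdims=True)

# Step 3: weighted combination of the value rows.
result = np.matmul(weights, value)                      # shape (2, 1, 4)

# Values printed in the post, used here only for comparison.
expected = np.array([[[1.8067516, 1., 1.7310829, 3.730812]],
                     [[0.99999994, 1.9293262, 1.0353367, 1.9646629]]])
print(np.allclose(result, expected, atol=1e-5))  # True
```

In the first batch the scores 16 and 15 dominate 7, so the output is almost entirely a blend of the first two value rows; in the second batch the middle row (score 10) receives about 96% of the weight.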