我是靠谱客的博主 悦耳煎饼,最近开发中收集的这篇文章主要介绍Tensorflow2(Temporal-level Attention Calculation),觉得挺不错的,现在分享给大家,希望可以做个参考。

概述

论文 : Vehicle Trajectory Prediction Using LSTMs with Spatial-Temporal Attention Mechanisms
 


import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import initializers
class TemporalAttention(keras.Model):
def __init__(self,fin,fout=1):
super(TemporalAttention,self).__init__()
self.fin = fin # 输入维度
self.fout = fout # 输出维度 这里为1 求得是分数
self.initializer = initializers.GlorotUniform() # 初始化分布
# 自定义可学习参数
self.w = tf.Variable(self.initializer(shape=[self.fin, self.fout], dtype=tf.float32))
def call(self,h): # h:[bs,seq,fin]
x = h # [bs,seq,fin]
alpha = h @ self.w # [bs,seq,1] fout==1
alpha = tf.nn.softmax(tf.tanh(alpha),1) # [bs,seq,1]
x = tf.einsum('ijk,ijm->ikm', alpha, x) # [bs,1,fin]
return tf.squeeze(x,[1]) # [bs,fin]
a = tf.random.normal([42,8,64])
model = TemporalAttention(64,1)
z = model(a)
print(z.shape)

 

最后

以上就是悦耳煎饼为你收集整理的Tensorflow2(Temporal-level Attention Calculation)的全部内容,希望文章能够帮你解决Tensorflow2(Temporal-level Attention Calculation)所遇到的程序开发问题。

如果觉得靠谱客网站的内容还不错,欢迎将靠谱客网站推荐给程序员好友。

本图文内容来源于网友提供,作为学习参考使用,或来自网络收集整理,版权属于原作者所有。
点赞(99)

评论列表共有 0 条评论

立即
投稿
返回
顶部