我是靠谱客的博主 安静老师,这篇文章主要介绍Tensorflow学习---tf.nn.embedding_lookup,现在分享给大家,希望可以做个参考。

tf.nn.embedding_lookup(params,ids, partition_strategy=’mod’, name=None, validate_indices=True,max_norm=None)
根据ids中的id,寻找params中的对应元素,可以理解为索引,所以ids中元素值不能超出params的第一维的维数值。
比如,ids=[1,3,5],则找出params中下标为1,3,5的向量组成一个矩阵返回。
参数说明:
params: 表示完整的embedding张量,或者除了第一维度之外具有相同形状的P个张量的列表,表示经分割的嵌入张量。
ids: 一个类型为int32或int64的Tensor,包含要在params中查找的id

 

下面是代码

#!/usr/bin/python
#encoding:utf-8

import tensorflow as tf

encode_embeddings = tf.constant([[1,2,3,4,5],[6,7,8,9,0]]) #2*5
# input_ids
中元素的值和encode_embeddings的第一维的维数有关,此例中为2维,input_ids只能是[0,2),也就是0和1
input_ids =tf.constant([[1,1,0],[1,0,1],[1,0, 1],[0,1, 1]])  #4*3
session = tf.Session()
with session.as_default():
    # 结果results是4*3*5矩阵。
   
results =tf.nn.embedding_lookup(encode_embeddings,input_ids)
    print(results.eval())# tf.eval()函数用于显示张量tensor的值,但需要放在with session.as_default()中才能使用。
   
'''结果值
    [[[6 7 8 9 0]
  [6 7 8 9 0]
  [1 2 3 4 5]]

 [[6 7 8 9 0]
  [1 2 3 4 5]
  [6 7 8 9 0]]

 [[6 7 8 9 0]
  [1 2 3 4 5]
  [6 7 8 9 0]]

 [[1 2 3 4 5]
  [6 7 8 9 0]
  [6 7 8 9 0]]]'''

 

最后

以上就是安静老师最近收集整理的关于Tensorflow学习---tf.nn.embedding_lookup的全部内容,更多相关Tensorflow学习---tf内容请搜索靠谱客的其他文章。

本图文内容来源于网友提供,作为学习参考使用,或来自网络收集整理,版权属于原作者所有。
点赞(42)

评论列表共有 0 条评论

立即
投稿
返回
顶部