本文将通俗的进行解释embedding_lookup( )的用法
首先看一段简单代码:
复制代码
1
2
3
4
5
6
7
8
9
10
11#!/usr/bin/env/python # coding=utf-8 import tensorflow as tf import numpy as np input_ids = tf.placeholder(dtype=tf.int32, shape=[None]) embedding = tf.Variable(np.identity(5, dtype=np.int32)) input_embedding = tf.nn.embedding_lookup(embedding, input_ids) sess = tf.InteractiveSession() sess.run(tf.global_variables_initializer()) print(embedding.eval()) print(sess.run(input_embedding, feed_dict={input_ids:[1, 2, 3, 0, 3, 2, 1]}))
代码中先使用palceholder定义了一个未知变量input_ids用于存储索引,和一个已知变量embedding,是一个5*5的对角矩阵。
运行结果为:
复制代码
1
2
3
4
5
6
7
8
9
10
11
12embedding = [[1 0 0 0 0] [0 1 0 0 0] [0 0 1 0 0] [0 0 0 1 0] [0 0 0 0 1]] input_embedding = [[0 1 0 0 0] [0 0 1 0 0] [0 0 0 1 0] [1 0 0 0 0] [0 0 0 1 0] [0 0 1 0 0] [0 1 0 0 0]]
简单的讲就是根据input_ids中的id,寻找embedding中的对应元素。比如,input_ids=[1,3,5],则找出embedding中下标为1,3,5的向量组成一个矩阵返回。
如果将input_ids改写成下面的格式:
复制代码
1
2input_embedding = tf.nn.embedding_lookup(embedding, input_ids) print(sess.run(input_embedding, feed_dict={input_ids:[[1, 2], [2, 1], [3, 3]]}))
输出结果就会变成如下的格式:
复制代码
1
2
3
4
5
6[[[0 1 0 0 0] [0 0 1 0 0]] [[0 0 1 0 0] [0 1 0 0 0]] [[0 0 0 1 0] [0 0 0 1 0]]]
对比上下两个结果不难发现,相当于在np.array中直接采用下标数组获取数据。需要注意的细节是返回的tensor的dtype和传入的被查询的tensor的dtype保持一致;和ids的dtype无关。
最后
以上就是义气缘分最近收集整理的关于Tensorflow 中的embedding_lookup详解的全部内容,更多相关Tensorflow内容请搜索靠谱客的其他文章。
本图文内容来源于网友提供,作为学习参考使用,或来自网络收集整理,版权属于原作者所有。
发表评论 取消回复