我是靠谱客的博主 懦弱咖啡,最近开发中收集的这篇文章主要介绍TensorFlow Python API解析:图的核心数据结构,觉得挺不错的,现在分享给大家,希望可以做个参考。

概述

本文默认的一些名称叫法
1. “Tensor实例”、”Tensor对象”、”tensor”都是在说一个Tensor实例。同样,”Operation实例”、”Operation对象”、”operation”都是在说一个Operation实例。
2. 单独的”Tensor”和”Operation”都表示一个Python class
3. “operation”有时简写为”op”或”Op”,”tensor”有时简写为”t”

代码位置:tensorflow/tensorflow/python/framework/ops.py

一、Graph类

1. 要点

  1. TensorFlow中的计算,表示为一个数据流图,简称“图”
  2. 一个Graph实例就是一个图,由一组Operation对象和Tensor对象构成:每个Operation对象(简记为op)表示最小的计算单元,每个Tensor对象表示在operations间传递的基本数据单元
  3. 如果你没有注册自己的图,系统会提供一个默认图。你可通过调用tf.get_default_graph()显式地访问这个图,也可以不理会这个图,因为调用任一个operation函数时,如调用constant op,c=tf.constant(4.0),一个表示operation的节点会自动添加到这个图上,此时c.graph就指这个默认图。
  4. 如果我们创建了一个Graph实例,并想用它取代上面的默认图,把它指定为一个新的默认图,至少是临时换一下,可以调用该Graph实例的as_default()方法,并得到一个Python中的上下文管理器(context manager),来管理临时默认图的生命周期,即with ...下的代码区域。
g = tf.Graph()
with g.as_default():
  # 此时定义的operation和tensor都自动添加到图g上
  c = tf.constant(30.0)

小提示:组装图阶段,Graph类不是线程安全的,添加operations最好在单线程内完成。

2. Graph的属性

内部属性

  • 与operation相关:
    • _nodes_by_id:dict( op的id => op ),按id记录所有添加到图上的op
    • _nodes_by_name:dict( op的name => op ),按名字记录所有添加到图上的op
    • _next_id_counter:int,自增器,创建下一个op时用的id
    • _version:int,记录所有op中最大的id
    • _default_original_op:有些op需要附带一个original_op,如replica op需要指出它要对哪个op进行复制
    • _attr_scope_map:dict( name scope => attr ),用于添加一组额外的属性到指定scope中的所有op
    • _op_to_kernel_label_map:dict( op type => kernel label ),kernel可能是指operation中更底层的实现
    • _gradient_override_map:dict( op type => 另一个op type ),把一个含自定义gradient函数的注册op,用在一个已存在的op上
  • 与命名域name scope相关:
    • _name_stack:字符串,嵌套的各个scopes的名字拼成的栈,用带间隔符”/”的字符串表示
    • _names_in_use: dict( name scope => 使用次数 )
  • 与device相关:
    • _device_function_stack: list,用来选择device的函数栈,每个元素是一个device_function(op),用来获取op所在device
  • 与控制流相关:
    • _control_flow_context:一个context对象,表示当前控制流的上下文,如CondContext对象,WhileContext对象,定义在ops/control_flow_ops.py。实际上,控制流也是一个op,用来控制其他op的执行,添加一些条件依赖的关系到图中,使执行某个operation前先查看依赖
    • _control_dependencies_stack:list,一个控制器栈,每个控制器是一个上下文,存有控制依赖信息,表明当执行完依赖中的operations和tensors后,才能执行此上下文中的operations
  • 与feed和fetch相关:
    • _unfeedable_tensors:set,定义不能feed的tensors
    • _unfetchable_ops:set,定义不能fetch的ops
    • _handle_feeders:dict( tensor handle placeholder => tensor dtype )
    • _handle_readers:dict( tensor handle => 它的read op )
    • _handle_movers:dict( tensor handle => 它的move op )
    • _handle_deleters:dict( tensor handle => 它delete op )
  • 图需要:
    • _seed:当前图内使用的随机种子
    • _collections:dict( collection name => collection ),相当于图中的一块缓存,每个collection可看成一个list,可以存任何对象
    • _functions:定义图内使用中的一些函数
    • _container:资源容器resource container,用来存储跟踪stateful operations,如:variables,queues
    • _registered_ops:注册的所有操作
  • 程序运行需要:
    • _finalized:布尔值,真表示Graph属性都已确定,不再做修改
    • _lock:保证读取Graph某些属性(如:_version)时尽可能线程安全
  • TensorFlow框架需要:
    • _graph_def_version:图定义的版本
  • 其他:
    • _building_function:该图是否表示一个函数
    • _colocation_stack:保存共位设置(其他op都与指定op共位)的栈

对外属性

  • tf.Graph.version,也就是self._version,记录最新的节点version,即图中最大op id,但是与GraphDef的version无关
  • tf.Graph.graph_def_versions,也就是self._graph_def_versions,GraphDef版本,定义在tensorflow/tensorflow/core/framework/graph.proto
  • tf.Graph.seed,也就是self._seed,此图内使用的随机种子
  • tf.Graph.building_function,也就是self._building_function
  • tf.Graph.finalized,也就是self._finalized,表明组装图阶段是否完成

3. Graph的主要方法

构造方法

  • tf.Graph.__init__():创建一个空图

获取图元素(tensors, operations)的方法

  • tf.Graph.as_graph_def(from_version=None, add_shapes=False):返回该graph对应的GraphDef表示,使用了protocol buffer,见下面的message GraphDef。该方法是线程安全的。
    • 传参:(1) from_version表明包括的节点version(即op id)的范围,from_version之前的节点都不要;(2) add_shapes如果为真,则每个节点都要添加输出tensors的形状信息到_output_shapes
messsage GraphDef {
  // 图中的所有节点,参见下面message NodeDef
  repeated NodeDef node = 1;
  // 图的版本,不同于TensorFlow版本
  VersionDef versions = 4;
  // 丢弃
  int32 version = 3 [deprecated = true];
  // Experimental. 提供用户自定义的函数
  FunctionDefLibrary library = 2;
}
  • tf.Graph.as_graph_element(obj, allow_tensor=True, allow_operation=True):该获取信息的方法实际上完成了一个验证加转换的工作,给定一个obj,看它能否对应到图中的元素,可以是一个operation,也可以是一个tensor,如果对应,则以operation或tensor的身份返回它自己。该方法可以被多个线程同时调用。
    • 传参:(1) obj:可以是一个Tensor对象,或一个Operation对象,或tensor名,或operation名,或其他对象;(2) allow_tensor:真表示obj可以是tensor;(3) allow_operation:真表示obj可以是operation
  • get系列方法,可被多个线程同时调用:
    • tf.Graph.get_operation_by_name(name):根据名字获取某个operation
    • tf.Graph.get_tensor_by_name(name):根据名字获取某个tensor
    • tf.Graph.get_operations():获取所有operations
  • 判断是否可feed或可fetch
    • tf.Graph.is_feedable(tensor)
    • tf.Graph.is_fetchable(tensor_or_op)
  • 设置不可feed或不可fetch
    • tf.Graph.prevent_feeding(tensor)
    • tf.Graph.prevent_fetching(op)

添加节点组装图的方法

  • tf.Graph.unique_name(name, mark_as_used=True):为operation name构造一个唯一名,唯一名可能包含分隔符"/"。传参mark_as_used表示构造的唯一名,只是用来看看,还是要被创建出来使用的,将其传给方法create_op()来创建一个operation。
  • tf.Graph.create_op(op_type, inputs, dtypes, ...):这是一个低级别接口,开发者一般用不到,因为只用具体op的构造函数,如tf.constant(),即可实现向图添加op节点。此方法返回一个Operation实例,却有很多传入参数,包括:
    • 必填的有:(1) op_type:创建的op类型,也就是操作方法名,如”MatMul”,对应OpDef.name字段;(2) inputs:op的输入,是一个由Tensor对象组成的列表;(3) dtypes:op输出的tensors的数据类型,是一个由DType对象组成的列表
    • 可选的有:(1) input_types:op输入的tensors的类型,是一个由DType对象组成的列表,默认使用inputs中的tensors自带的Dtype;(2) name:op做节点的名字,默认基于op_type构造出;(3) attrs:dict( 属性名 => operation的属性),在NodeDef proto中有定义;(4) op_def:OpDef proto,是一个描述operation操作方法的protocol buffer;(5) compute_shapes:布尔值,是否计算输出的tensors的形状;(6) compute_device:布尔值,是否执行device_function来获取operation的device

结束组装图的定稿方法

  • tf.Graph.finalize():结束组装图,以后图只能读不能写,不能再添加新operation节点。此方法用在图要在多个线程间共享的场景下,如用于QueueRunner

切换默认图的方法

  • tf.Graph.as_default():让当前图取代默认图。一般来说,不常使用,因为当前线程会自动提供一个全局默认图,也就是说,全局默认图是当前线程的一个属性,新建一个线程后全局默认图就变了。除非你在同一个进程内创建了多个图,才会有用这个方法的需求。
    • 返回:一个context manager,实际上当前设为默认的graph就是一个上下文,在它的代码块内执行的op都会添加到该图上
# 两种等价方法
# 方法一:
g = tf.Graph()
with g.as_default():
  ...
# 方法二:
with tf.Graph().as_default() as g: # 当前上下文就是刚创建的g
  ...

返回新上下文的方法

  • tf.Graph.control_dependencies(control_inputs):返回一个控制依赖的上下文,使得上下文内的新加入op都有此依赖。控制依赖的意思是,若想执行下步的op,必须先完成依赖中的input。此方法的上下文就是一个控制器controller,controller内保存了新建的控制依赖,同时controller加入一个控制器栈controller stack,以支持嵌套的控制依赖上下文。
    • 传参:control_inputs是一个operation或tensor的列表
with g.control_dependencies([a, b]):
  # 这里新建的op在a,b后执行
  with g.control_dependencies([c, d]):
    # 这里新建的op在a,b,c,d后执行
    with g.control_dependencies(None):
      # 因为依赖链断掉,这里新建的op不需等待a,b,c,d
      with.control_dependencies([e, f])
        # 这里新建的op在e,f后执行
  • tf.Graph.device(device_name_or_function):返回一个默认device的上下文,使得上下文内的新加入op都被分派到该device上。此方法的上下文就是一个device function,它被压入一个device function stack,以支持嵌套。
    • 传参:device_name_or_function可以是一个表示device名的字符串,或一个返回device名的函数,或None
    • 特例:无论位于哪个device上下文中,variable assignment op v.assign()将随它的Variable对象v放在一起
def matmul_on_gpu(node):
  if node.type == "MatMul": # node.type就是指op type
    return "/gpu:0"
  else:
    return "/cpu:0"

with g.device(matmul_on_gpu):
  # 此处新建的所有type为"MatMul"的op将放在GPU 0上,其他则放在CPU 0上
  with g.device('/gpu:0'):
    # 此处新建的所有op都将放在GPU 0上

device的命名格式:
/job:<name>/replica:<replica>/task:<task>/device:<type>:<device_num>
- <name>:标识id,为一个字符串,形如[a-zA-Z][_a-zA-Z]*,比如/job:param_server为一个名为”param_server”的job
- <type>:支持的设备类型,如”cpu”或”gpu”
- <replica>,<task>,<device_num>:小的非负整数
举例:
(1) /job:w/replica:0/task:0/device:gpu:*:位于job w的replica 0, task 0上的任何gpu devices
(2) /job:*/replica:*/task:*/device:cpu:*:位于任何job/task/replica的任何cpu devices

  • tf.Graph.name_scope(name):返回一个层级命名operation的上下文。一个图维护一个命名域(或叫“命名空间”)的栈self._name_stack,此方法的传入参数name会被压入该栈,支持嵌套。
    • 参数:name可以是一个字符串,用于创建一个新的name scope;也可以是一个已有的name scope,用于重新进入这个已存在的scope;也可以是None或空字符,此时表示顶层的name scope
c = tf.constant(1.0, name="c") # c.op.name为"c"
c_1 = tf.constant(2.0, name="c") # c_1.op.name为"c_1"
with g.name_scope("nested") as scope:
  nested_c = tf.constant(3.0, name="c") # nested_c.op.name为"nested/c"
  with g.name_scope("inner"):
    nested_inner_c = tf.constant(4.0, name="c") # nested_inner_c.op.name为"nested/inner/c"
  with g.name_scope("inner"): # 因为此域下已有"inner",所以用"inner_1"
    nested_inner_1_c = tf.constant(5.0, name="c") # nested_inner_1_c.op.name为"nested/inner_1/c"
    with g.name_scope(scope): # 无论现在嵌套哪里,都转换成scope,即"nested/"
      nested_c_1 = tf.constant(6.0, name="c") # nested_c_1.op.name为"nested/c_1"
      with g.name_scope(""): # 变成顶级scope
        c_2 = tf.constant(7.0, name="c") # c_2.op.name为"c_2"
  • tf.Graph.gradient_override_map(op_type_map):返回一个改写gradient函数的上下文,使得针对某些operation,我们可以使用自己的gradient函数。一个图维护这样一个映射关系self._gradient_override_map.
# 先注册一个gradient函数
@tf.RegisterGradient("CustomSquare")
def _custom_square_grad(op, grad):
  # ...

with tf.Graph().as_default() as g:
  c = tf.constant(5.0)
  s_1 = tf.square(c) # 使用tf.square默认的gradient
  with g.gradient_override_map({"Sqaure": "CustomSquare"}):
    s_2 = tf.square(s_2): # 使用自定义的_custom_square_grad函数来计算s_2的梯度
  • tf.Graph.colocate_with(op, ignore_existing=False):返回一个共用给定op的位置的上下文,使得上下文内的新加入op都共用这个位置。传参ignore_existing为真则表示忽略以前所有的共位设置。
a = tf.Variable([1.0])
with g.colocate_with(a):
  # 下面的b,c与a共位
  b = tf.constant(1.0)
  c = tf.add(a, b)
  • tf.Graph.container(container_name):返回一个带资源容器的上下文,服务于带状态的operations,如:varaibles,queues,用来存储跟踪它们的状态。可使用tf.Session.reset()清除资源容器中保存的信息。
with g.container('experiment0'):
  v = tf.Variable([1.0]) # 将存到资源容器"experiment0"
  with g.container('experiment1'):
    q = tf.FIFOQueue(10, tf.float32) # 将存到资源容器"experiment1"
  with g.container(''):
    v2 = tf.Variable([2.0]) # 将存到默认的资源容器

与graph collections相关的方法

一个图中可以有多个collections,也称为graph collections,每个collection都有一个名字,用来存储一组相关的对象,可以把一个collection看成一个list或array。它的标准名字有:GLOBAL_VARIABLESLOCAL_VARIABLESMODEL_VARIABLES等,定义在GraphKeys类里。

  • 存value:
    • tf.Graph.add_to_collection(name, value):把value存到名为name的collection中
    • tf.Graph.add_to_collections(names, value):把value存到名字在names上的所有collections中
  • 取value:
    • tf.Graph.get_collection(name, scope=None):根据名为name的collection,返回它中所有的values,即一个values的列表。传参scope充当一个过滤器,筛出指定scope中的values。
    • tf.Graph.get_collection_ref(name):同上,不同之处是此方法返回对collection本身的引用,而不是复制一份,故在上面的修改会起作用。
    • tf.Graph.get_all_collection_keys():返回一个collections的list
  • 清除value:
    • f.Graph.clear_collection(name):清除名为name的collection上所有的values

二、Operation类

1. 要点

  1. Operation实例就是数据流图中的节点,负责tensors的计算,即输入是若干Tensor实例,输出也是若干Tensor实例。
  2. Operation实例与实例的type之间的区别:
    • 这里的Operation实例,也就是operation或op,与我们想的加法、减法等操作在概念上有略微差异,后者侧重于对方法的描述,前者则参与到图中,作为一个节点,叫”operator”更合适
    • 该实例的type,也就是op type,才是指像加法、减法这样的操作方法,如“MatMul”表示矩阵乘这个操作方法
    • 每个Operation实例的名字在图中都是唯一的,因为对应一个特定节点,但是相同的操作方法op type在图中可以有多个。它们在protocol buffers分别被定义为NodeDef和OpDef
  3. 创建一个Operation实例有两种方法:
    • 第一种:调用一个op构造函数,如c=tf.matmul(a,b),则创建一个表示矩阵乘操作的op节点,其中a, b, c都为tensor,a和b作输入,c作输出
    • 第二种:调用方法Graph.create_op()
  4. 启动一个session后,执行Operation实例也有两种方法:
    • 第一种:把该op传入session的方法run()
    • 第二种:直接调用op.run(),这实际上是tf.get_default_session().run(op)的简写

2. Operation的属性

  • 内部属性:
    • _node_def为一个NodeDef对象,_op_def为一个OpDef对象
    • _id_value为op在图中的id
    • _graph为op所在图
    • _inputs为输入op的tensors列表,_outputs为输出op的tensors列表
    • _input_types为输入op的tensors的数据类型列表,_output_types为输出op的tensors的数据类型列表
    • _control_inputs为执行op前的控制依赖
    • _original_op为当前op需要的一个原op,如replica op还需要一个op,称为原op
    • _traceback为创建op时的调用栈call stack
    • _control_flow_context为包含当前op的当前控制流上下文
  • 对外属性:
    • tf.Operation.name:该operation的全名
    • tf.Operation.type:该operation的type,如MatMul
    • tf.Operation.inputs:该operation的输入,是一个Tensor对象的列表
    • tf.Operation.outputs:该operation的输出,也是一个Tensor对象的列表
    • tf.Operation.control_inputs:是一个Operation对象的列表,执行当前operation前,需要保证此列表中的所有Operation对象都已执行完毕
    • tf.Operation.graph:该operation所在的graph
    • tf.Operation.device:该operation所在的device,表示为一个字符串
    • tf.Operation.traceback:自该operation创建以来的调用栈
    • tf.Operation.node_def:该operation对应的NodeDef表示,使用了protocol buffer,见下面的message NodeDef
    • tf.Operation.op_def:该operation的type对应的OpDef表示,使用了protocol buffer,见下面的message OpDef
// 定义图中的一个节点
message NodeDef {
  // 本operator名,在一个图中是唯一的,可看成当前节点名
  string name = 1;
  // operation名,在一个图中可有重复,可看成操作的方法名
  string op = 2;
  // input列表,每个input表示为字符串"<node>:<src_output>",表明来自哪个op的哪个output索引
  repeated string input = 3;
  // 本节点所在device,举例为:
  // 1) "@other/node":与另一个节点"other/node"共位置
  // 2) "/job:worker/replica:0/task:1/gpu:3":全路径
  // 3) "/job:worker/gpu:3":部分路径
  // 4) "":无
  string device = 4;
  // 应包含OpDef的所有attrs
  map<string, AttrValue> attr = 5;
}
// 定义一个操作
message OpDef {
  // operation名,等于NodeDef中的"op",名字采用CameCase格式,若首字符为"_"则为内部保留的操作
  string name = 1;

  // 定义一个argument message,用作一个input或output
  message ArgDef {
    // 当前input或output的名
    sting name = 1;
    // 给人读的描述
    string description = 2;
    // 当前input或output,可以接受一到多个tensors
    // (1)当接受一个tensors时,要么设置type字段,要么设置type_attr字段,指向一个类型为"type"的attr
    // (2)当接受多个type相同的tensors时,要设置number_attr字段,指向一个类型为"int"的attr,表示tensors的数目,当然还要设置type或type_attr字段
    // (3)当接受多个type不同的tensors时,要设置type_list_attr字段,指向一个类型为"list(type)"的attr,不用设置type、type_attr和number_attr字段
    DataType type = 3;
    string type_attr = 4;
    string number_attr = 5;
    string type_list_attr = 5;
    // 当前input或output是否为ref
    bool is_ref = 16;
  }
  // 当前操作的所有输入
  repeated ArgDef input_arg = 2;
  // 当前操作的所有输出
  repeated ArgDef output_arg = 3;

  // 定义一个attr message
  message AttrDef {
    string name = 1;
    string type = 2; // 如:"string", "list(string)", "int"
    AttrValue default_value = 3;
    string description = 4;
    // 对于"int"型,有下面两字段
    bool has_minimum = 5;
    int64 minimum = 6;
    AttrValue allowed_values = 7;
  }
  // 在op中定义的attr会加入NodeDef
  repeated AttrDef attr = 4;

  // 其他
  OpDeprecation deprecation = 8;
  string summary = 5;
  string description = 6;
  // 操作是否满足交换律
  bool is_commutative = 18;
  // 操作可接受2个以上的inputs,得出1个同类型的output,需满足交换律和结合律
  bool is_aggregate = 16;
  // 操作是否带状态,stateful ops不能在devices间移动,除非状态也能移动
  bool is_stateful = 17; // 如:variables, queue
  // 默认情况下,所有op的inputs必须是初始化后的tensors
  bool allows_uninitialized_input = 19; // 如:assign
}

3. Operation的主要方法

构造方法

  • tf.Operation.__init__(node_def, g, inputs=None, ...):创建一个Operation实例,传入参数有很多,包括:
    • 必填参数:(1) node_def为一个node_def_pb2.NodeDef实例,包含了描述operation的属性,有nameopdevice但没有input,因为input是生成模型时才有的;(2) g为所在的图
    • 可选参数:(1) inputs为当前operation的输入,是一个Tensor对象的列表;(2) output_types当前operation的输出的类型,为是一个DType对象的列表;(3) control_inputs执行当前operation的前提,为一个operations或tensors的列表;(4) input_types为输入的类型,默认为[x.dtype.base_dtype for x in inputs];(5) original_op为一个关联的原op,如复制op的replica op,要给出那个op;(6) op_def为当前operation代表的op type,如”matmul”,定义在op_def_pb2.OpDef

执行operation的方法

  • tf.Operation.run(feed_dict=None, session=None):在session中运行当前operation,会一级级触发那些给当前operation提供inputs的所有直接或间接的operations。实际上,最后调用的是session.run(operation, feed_dict)
    • 传参:(1) feed_dict是一个dict( Tensor对象或tensor名 => 具体值 ),具体值可以是list、numpy ndarray、TensorProto或string;(2) session若没指定,则用当前线程的默认session。

获取operation信息的方法

  • tf.Operation.get_attr(name):根据名字返回当前operation的某个属性值。一个operation会有多个属性,定义在self._node_def.attr上。
  • tf.Operation.colocation_groups():返回当前operation的共位置组列表,格式为["loc:@<节点名即_node_def.name>", ...]

三、Tensor类

1. 要点

  1. Tensor对象作为表示数量的符号,要参与到数学计算中,就要重载Python的许多操作符
  2. Tensor对象和Operation对象一起构建了图,如果说operation是节点,tensor更像是边,把不同的operations链接在一起
    • 一般来说,operation的输入除了tensor,还可以是其他类型,只要有能转化为tensor的相应支持,但是operation的输出只能是tensor。而且,任何一个Tensor实例都对应一个operation,作它的一个输出,所以tensor的创建离不开operation的创建,但Tensor实例并不保留operation的输出数值,而是提供一种计算这些数值的通道,用在session中
    • Tensor实例自创建就是某个operation的一个output,但它也可以是其他operation的一个input。把tensor传给其他operation作输入的过程,就是在operations间建立连接的过程,就是组网的过程.
  3. Tensor对象是符号,不是具体值
    • 启动session后,想得到tensor的具体数值,需要调用Session.run()t.eval()来计算。t.eval()实际上是tf.get_default_session().run(t)
# 组网过程,下面的c、d、e都是Tensor实例,即一个符号,而不是具体值
c = tf.constant([[1.0, 2.0], [3.0, 4.0]]) # constant op的输入是一个二维数组常量
d = tf.constant([[1.0, 1.0], [0.0, 1.0]])
e = tf.matmul(c, d)
# 启动session来执行图
sess = tf.Session()
result = sess.run(e) # 这里result是一个numpy array,负责存储具体值

2. Tensor的属性

  • 内部属性:
    • _op:以该tensor作输出的op
    • _value_index:在op输出中的索引
    • _dtype:数据类型
    • _shape:tensor形状
    • _consumers:使用该input做输入的operations列表,方便在图中游走
    • _handle_shape_handle_dtype:用于C++形状推断
  • 对外属性:
    • tf.Tensor.dtype:该tensor中元素的DType
    • tf.Tensor.name:该tensor的名字,为”op名:输出索引”
    • tf.Tensor.op:该tensor所在的operation,tensor作它的输出
    • tf.Tensor.value_index:该tensor位于它所在operation的输出列表中的索引
    • tf.Tensor.graph:该tensor所在的图,也是它的op的图
    • tf.Tensor.device:该tensor所在的device
    • tf.Tensor.shape:该tensor的形状,是一个TensorShape对象,如TensorShape([Dimension(3), Dimension(4)])

3. Tensor的主要方法

  • 构造方法
    • tf.Tensor.__init__(op, value_index, dtype):创建一个Tensor实例,op为以它做输出的operation,value_index为它在输出中的索引,dtype为它的元素数据类型
  • 求tensor值的方法
    • tf.Tensor.eval(feed_dict=None, session=None):启动session后,如果session的图与该tensor的图相同,则在session中对该tensor求值,会触发计算它的operation以及图中所依赖的前面operations。实际上,最终调用的是session.run(tensors, feed_dict)
      • 传参:(1) feed_dict是一个dict( Tensor对象或tensor名 => 具体值 ),具体值可以是list、numpy ndarray、TensorProto或string;(2) session若没指定,则用当前线程的默认session。
      • 返回:一个numpy array
  • 获取和设置tensor形状信息的方法
    • tf.Tensor.get_shape():获取该tensor的形状,返回是一个TensorShape对象,推断形状(shape inference)的过程不用启动session,但在operation中需注册一个用于推断形状的函数。比如,c=tf.constant([[1.0,2.0,3.0],[4.0,5.0,6.0]]),则调用c.get_shape()得到一个TensorShape([Dimension(2), Dimension(3)])
    • tf.Tensor.set_shape(shape):设置或更新该tensor的形状,如image.set_shape([28, 28, 3])
  • 获取tensor作输入的operations信息的方法
    • tf.Tensor.consumers():返回以该tensor做输入的所有operations

4. Tensor重载的Python operators

  • 算术操作:__(r)add__“+”,__(r)sub__“-“,__(r)mul__“*”,__(r)div__“/”, __(r)floordiv__“//”,__(r)truediv“/”,__(r)mod__“mod”,__neg__“-“,__(r)pow__“pow(x,y)”,__abs__“| |”
  • 逻辑操作:__(r)and__“&”,__(r)or__“|”,__invert__“~”,__(r)xor__“^”
  • 比较操作:__eq__“==”,__ge__“>=”,__gt__“>”,__le__“<=”,__lt__“<”
  • 其他:__getitem__“[ ]”,__hash__

默认都是元素级(element-wise)操作,除了__(r)mul__,源码注释这样说:# Dispatches cwise mul for "DenseDense" and "DenseSparse"

# __getitem__:限定到子tensor
foo = tf.constant([[1,2,3], [4,5,6], [7,8,9]])
print(foo[::2, ::-1].eval()) # => [[3,2,1], [9,8,7]]

最后

以上就是懦弱咖啡为你收集整理的TensorFlow Python API解析:图的核心数据结构的全部内容,希望文章能够帮你解决TensorFlow Python API解析:图的核心数据结构所遇到的程序开发问题。

如果觉得靠谱客网站的内容还不错,欢迎将靠谱客网站推荐给程序员好友。

本图文内容来源于网友提供,作为学习参考使用,或来自网络收集整理,版权属于原作者所有。
点赞(53)

评论列表共有 0 条评论

立即
投稿
返回
顶部