我是靠谱客的博主 飞快蜡烛,最近开发中收集的这篇文章主要介绍tensorflow2(6)数据增幅,觉得挺不错的,现在分享给大家,希望可以做个参考。

概述

数据增幅

import tensorflow as tf
import numpy as np
from tensorflow import keras
from tensorflow.keras import layers,optimizers,datasets
a = tf.constant([1,2,3,4,5,6])
x = tf.range(9)
x
<tf.Tensor: id=4, shape=(9,), dtype=int32, numpy=array([0, 1, 2, 3, 4, 5, 6, 7, 8])>
tf.maximum(x,2)
<tf.Tensor: id=6, shape=(9,), dtype=int32, numpy=array([2, 2, 2, 3, 4, 5, 6, 7, 8])>
tf.minimum(x,7)
<tf.Tensor: id=8, shape=(9,), dtype=int32, numpy=array([0, 1, 2, 3, 4, 5, 6, 7, 7])>
def relu(x):
    return tf.maximum(x,0.)
x
<tf.Tensor: id=4, shape=(9,), dtype=int32, numpy=array([0, 1, 2, 3, 4, 5, 6, 7, 8])>

高级数据操作

tf.gather 可以实现根据索引号收集数据的目的。考虑班级成绩册的例子,假设共有 4个班级,每个班级 35 个学生,8 门科目,保存成绩册的张量 shape 为[4,35,8]。

x = tf.random.uniform([4,35,8],maxval=100,dtype=tf.int32)
x
<tf.Tensor: id=12, shape=(4, 35, 8), dtype=int32, numpy=
array([[[46, 12, 76, ...,  1, 40, 21],
        [42, 90, 47, ...,  6, 99,  9],
        [11, 17, 80, ..., 51, 65, 56],
        ...,
        [75, 81,  4, ..., 39, 64, 82],
        [79, 95, 79, ..., 53,  1, 98],
        [30, 48, 54, ..., 19, 36, 56]],

       [[19, 89, 76, ..., 91, 88, 84],
        [15,  8, 84, ..., 95, 59, 63],
        [ 6, 39, 48, ..., 82, 58, 28],
        ...,
        [33, 87, 46, ..., 94, 47, 98],
        [75, 69, 56, ..., 10, 20, 26],
        [48, 18, 24, ..., 57, 11,  3]],

       [[75, 12, 24, ..., 76, 59,  6],
        [68, 37, 42, ..., 79, 29, 70],
        [ 5, 69, 22, ..., 75, 10, 83],
        ...,
        [56, 45, 96, ..., 90, 67, 68],
        [32, 99, 95, ..., 10, 76, 71],
        [18, 28, 61, ..., 42, 23, 34]],

       [[29, 49, 27, ...,  5,  0, 77],
        [39, 78, 61, ..., 69, 20, 84],
        [54, 70, 48, ..., 39, 96, 17],
        ...,
        [33, 71, 69, ..., 42, 46, 34],
        [ 0, 56,  4, ...,  7, 50, 16],
        [90, 64,  1, ..., 50, 25, 63]]])>
tf.gather(x,[0,1],axis=0)
<tf.Tensor: id=15, shape=(2, 35, 8), dtype=int32, numpy=
array([[[46, 12, 76, 64, 34,  1, 40, 21],
        [42, 90, 47, 80, 37,  6, 99,  9],
        [11, 17, 80, 35,  1, 51, 65, 56],
        [41,  8, 27, 93, 52, 17, 27,  8],
        [78, 20, 59, 54, 22, 70, 82, 77],
        [42, 14, 58, 68, 60, 10, 35, 18],
        [68, 93, 43,  3, 63, 44,  2, 21],
        [21, 69, 90, 43, 54, 71, 40, 85],
        [93,  5, 20, 13, 34, 43, 76, 50],
        [77, 16, 25, 38, 34, 60, 32, 95],
        [85, 19, 68, 13, 96, 44, 37, 78],
        [ 4, 95, 77, 19, 99, 94, 17, 22],
        [63, 68,  0, 26, 53, 58, 18, 47],
        [ 3, 33, 79, 36,  5, 91,  0, 44],
        [61, 47, 26, 60, 66, 74, 95, 75],
        [91, 98, 62, 75, 58, 48, 23, 21],
        [10, 16, 78,  3, 10, 54, 83, 84],
        [91, 70, 27, 37,  8, 44, 68, 82],
        [83, 19, 52, 30, 91,  5,  3, 21],
        [79, 10, 42, 28, 95, 59, 12, 31],
        [79, 97, 65, 63, 29, 74, 92, 60],
        [84, 56, 13, 85, 77, 86, 14, 25],
        [70, 93, 19, 75, 69, 22, 42, 58],
        [57, 44, 85, 49, 66, 36,  5,  3],
        [31, 84, 63, 40, 23, 87, 55, 40],
        [80, 72, 81, 88, 77, 96,  5, 39],
        [65, 30, 41, 82, 99, 24, 34, 28],
        [11, 66, 58, 60,  2, 14, 14, 27],
        [77, 79, 69, 15,  5,  3, 36, 84],
        [51, 10, 40,  4, 56, 33, 29, 59],
        [83, 87, 12, 71, 40, 18, 41,  4],
        [38, 24, 62,  1, 20, 52, 49, 10],
        [75, 81,  4, 55, 38, 39, 64, 82],
        [79, 95, 79, 14, 96, 53,  1, 98],
        [30, 48, 54, 77,  6, 19, 36, 56]],

       [[19, 89, 76, 77, 32, 91, 88, 84],
        [15,  8, 84,  7, 24, 95, 59, 63],
        [ 6, 39, 48, 29,  0, 82, 58, 28],
        [ 9,  8, 45, 56, 91, 20, 92, 34],
        [65,  8, 25, 44, 53, 51, 75, 45],
        [ 0, 40,  4, 98, 19, 72, 31, 73],
        [31,  6, 96,  1, 88,  3, 35, 40],
        [82, 17, 15, 79,  4, 22, 92, 69],
        [78, 25, 16,  6, 40, 34, 30, 27],
        [63,  8, 55, 68, 94, 62, 41, 70],
        [70, 47,  5, 43, 31, 63, 40, 46],
        [61, 30, 65, 79, 77, 32, 37, 60],
        [34, 60, 44, 58, 69, 33, 59, 85],
        [45, 88, 36, 58, 10, 18,  5, 92],
        [88, 71, 67, 73, 39, 59, 16, 45],
        [26, 48, 83, 80, 33, 88, 74,  4],
        [47, 23, 30, 14, 13, 10,  2, 79],
        [17, 69, 36, 26, 69, 20,  2, 42],
        [12,  1, 90, 22, 82, 12, 43, 20],
        [52,  4, 18, 71, 46, 42, 90, 60],
        [95, 16, 90, 66, 80, 85, 33, 64],
        [29,  5, 13, 87, 26, 19, 65, 10],
        [53, 92, 23, 69, 89, 82, 11, 49],
        [25, 72, 22, 64, 22, 18, 81, 82],
        [ 8, 82, 69, 93, 25, 42,  0, 55],
        [13,  8, 56, 74, 94, 82, 71, 68],
        [94, 80, 32,  4, 30, 52, 37, 25],
        [37, 12, 25, 77, 25, 60, 55, 28],
        [56, 77, 53, 40,  7,  0, 46, 51],
        [71, 16, 29, 97, 70, 84, 40, 40],
        [21, 91, 28, 79, 55, 87, 42, 58],
        [14, 51, 41,  8, 76, 86, 41, 34],
        [33, 87, 46, 72, 44, 94, 47, 98],
        [75, 69, 56, 87, 63, 10, 20, 26],
        [48, 18, 24, 25, 92, 57, 11,  3]]])>

实际上,对于上述需求,通过切片????[:2]可以更加方便地实现。但是对于不规则的索引方
式,比如,需要抽查所有班级的第 1、4、9、12、13、27 号同学的成绩数据,则切片方式
实现起来非常麻烦,而 tf.gather 则是针对于此需求设计的,使用起来更加方便,实现如
下:

tf.gather(x,[0,3,28,29],axis=1)
<tf.Tensor: id=18, shape=(4, 4, 8), dtype=int32, numpy=
array([[[46, 12, 76, 64, 34,  1, 40, 21],
        [41,  8, 27, 93, 52, 17, 27,  8],
        [77, 79, 69, 15,  5,  3, 36, 84],
        [51, 10, 40,  4, 56, 33, 29, 59]],

       [[19, 89, 76, 77, 32, 91, 88, 84],
        [ 9,  8, 45, 56, 91, 20, 92, 34],
        [56, 77, 53, 40,  7,  0, 46, 51],
        [71, 16, 29, 97, 70, 84, 40, 40]],

       [[75, 12, 24, 49, 30, 76, 59,  6],
        [86, 49,  2, 27, 27, 74, 83, 18],
        [41, 68, 29, 36, 64,  8,  3, 24],
        [76, 10, 20, 69, 87, 83,  2, 29]],

       [[29, 49, 27, 68, 70,  5,  0, 77],
        [55, 12, 23,  0, 92, 69, 54, 19],
        [84, 70,  3, 45, 86, 67, 34, 76],
        [84,  7, 59, 12, 19, 82, 64, 16]]])>
tf.gather(x,[2,4],axis=2)
<tf.Tensor: id=21, shape=(4, 35, 2), dtype=int32, numpy=
array([[[76, 34],
        [47, 37],
        [80,  1],
        [27, 52],
        [59, 22],
        [58, 60],
        [43, 63],
        [90, 54],
        [20, 34],
        [25, 34],
        [68, 96],
        [77, 99],
        [ 0, 53],
        [79,  5],
        [26, 66],
        [62, 58],
        [78, 10],
        [27,  8],
        [52, 91],
        [42, 95],
        [65, 29],
        [13, 77],
        [19, 69],
        [85, 66],
        [63, 23],
        [81, 77],
        [41, 99],
        [58,  2],
        [69,  5],
        [40, 56],
        [12, 40],
        [62, 20],
        [ 4, 38],
        [79, 96],
        [54,  6]],

       [[76, 32],
        [84, 24],
        [48,  0],
        [45, 91],
        [25, 53],
        [ 4, 19],
        [96, 88],
        [15,  4],
        [16, 40],
        [55, 94],
        [ 5, 31],
        [65, 77],
        [44, 69],
        [36, 10],
        [67, 39],
        [83, 33],
        [30, 13],
        [36, 69],
        [90, 82],
        [18, 46],
        [90, 80],
        [13, 26],
        [23, 89],
        [22, 22],
        [69, 25],
        [56, 94],
        [32, 30],
        [25, 25],
        [53,  7],
        [29, 70],
        [28, 55],
        [41, 76],
        [46, 44],
        [56, 63],
        [24, 92]],

       [[24, 30],
        [42, 78],
        [22, 47],
        [ 2, 27],
        [ 6, 50],
        [74, 78],
        [29, 74],
        [25, 99],
        [39, 10],
        [76, 53],
        [76, 17],
        [24, 18],
        [77, 15],
        [13, 68],
        [58,  6],
        [20,  9],
        [17, 50],
        [ 9, 97],
        [ 8, 72],
        [37, 51],
        [42,  7],
        [13, 39],
        [94, 11],
        [52, 82],
        [37, 41],
        [37, 41],
        [24, 24],
        [68, 41],
        [29, 64],
        [20, 87],
        [16, 61],
        [66, 13],
        [96,  0],
        [95, 46],
        [61, 58]],

       [[27, 70],
        [61, 46],
        [48,  2],
        [23, 92],
        [35, 57],
        [38, 56],
        [23, 90],
        [23, 52],
        [87, 54],
        [29, 88],
        [63, 50],
        [44, 61],
        [30, 19],
        [ 4, 62],
        [90, 93],
        [14, 96],
        [27, 95],
        [73,  2],
        [47, 69],
        [36, 98],
        [50, 15],
        [90, 67],
        [38, 45],
        [20, 44],
        [87, 30],
        [51, 43],
        [13, 45],
        [ 8, 25],
        [ 3, 86],
        [59, 19],
        [53, 88],
        [54, 77],
        [69, 56],
        [ 4, 18],
        [ 1, 45]]])>
a = tf.range(8)
a = tf.reshape(a,[4,2])
a
<tf.Tensor: id=27, shape=(4, 2), dtype=int32, numpy=
array([[0, 1],
       [2, 3],
       [4, 5],
       [6, 7]])>
tf.gather(a,[0,2,3,1],axis=0)
<tf.Tensor: id=30, shape=(4, 2), dtype=int32, numpy=
array([[0, 1],
       [4, 5],
       [6, 7],
       [2, 3]])>
tf.gather(tf.gather(x,[1,2],axis=0),[2,3,4,5],axis=1)
<tf.Tensor: id=36, shape=(2, 4, 8), dtype=int32, numpy=
array([[[ 6, 39, 48, 29,  0, 82, 58, 28],
        [ 9,  8, 45, 56, 91, 20, 92, 34],
        [65,  8, 25, 44, 53, 51, 75, 45],
        [ 0, 40,  4, 98, 19, 72, 31, 73]],

       [[ 5, 69, 22, 81, 47, 75, 10, 83],
        [86, 49,  2, 27, 27, 74, 83, 18],
        [51, 60,  6, 60, 50,  1, 52, 33],
        [15, 69, 74, 24, 78, 88, 14, 47]]])>

通过 tf.gather_nd 函数,可以通过指定每次采样点的多维坐标来实现采样多个点的目
的。回到上面的挑战,我们希望抽查第 2 个班级的第 2 个同学的所有科目,第 3 个班级的
第 3 个同学的所有科目,第 4 个班级的第 4 个同学的所有科目。那么这 3 个采样点的索引
坐标可以记为:[1,1]、[2,2]、[3,3],我们将这个采样方案合并为一个 List 参数,即

tf.gather_nd(x,[[1,1],[2,2],[3,3]])
<tf.Tensor: id=38, shape=(3, 8), dtype=int32, numpy=
array([[15,  8, 84,  7, 24, 95, 59, 63],
       [ 5, 69, 22, 81, 47, 75, 10, 83],
       [55, 12, 23,  0, 92, 69, 54, 19]])>
tf.gather_nd(x,[[1,1,2],[2,2,3],[3,3,4]])
<tf.Tensor: id=40, shape=(3,), dtype=int32, numpy=array([84, 81, 92])>

除了可以通过给定索引号的方式采样,还可以通过给定掩码(Mask)的方式进行采样。
继续以 shape 为[4,35,8]的成绩册张量为例,这次我们以掩码方式进行数据提取。
考虑在班级维度上进行采样,对这 4 个班级的采样方案的掩码为
mask = [True,False,False,True]
即采样第 1 和第 4 个班级的数据,通过 tf.boolean_mask(x, mask, axis)可以在 axis 轴上根据
mask 方案进行采样,实现为:

tf.boolean_mask(x,mask=[True, False,False,True],axis=0)
WARNING:tensorflow:From D:anacodaenvstf2-cpulibsite-packagestensorflow_corepythonopsarray_ops.py:1486: where (from tensorflow.python.ops.array_ops) is deprecated and will be removed in a future version.
Instructions for updating:
Use tf.where in 2.0, which has the same broadcast rule as np.where





<tf.Tensor: id=68, shape=(2, 35, 8), dtype=int32, numpy=
array([[[46, 12, 76, 64, 34,  1, 40, 21],
        [42, 90, 47, 80, 37,  6, 99,  9],
        [11, 17, 80, 35,  1, 51, 65, 56],
        [41,  8, 27, 93, 52, 17, 27,  8],
        [78, 20, 59, 54, 22, 70, 82, 77],
        [42, 14, 58, 68, 60, 10, 35, 18],
        [68, 93, 43,  3, 63, 44,  2, 21],
        [21, 69, 90, 43, 54, 71, 40, 85],
        [93,  5, 20, 13, 34, 43, 76, 50],
        [77, 16, 25, 38, 34, 60, 32, 95],
        [85, 19, 68, 13, 96, 44, 37, 78],
        [ 4, 95, 77, 19, 99, 94, 17, 22],
        [63, 68,  0, 26, 53, 58, 18, 47],
        [ 3, 33, 79, 36,  5, 91,  0, 44],
        [61, 47, 26, 60, 66, 74, 95, 75],
        [91, 98, 62, 75, 58, 48, 23, 21],
        [10, 16, 78,  3, 10, 54, 83, 84],
        [91, 70, 27, 37,  8, 44, 68, 82],
        [83, 19, 52, 30, 91,  5,  3, 21],
        [79, 10, 42, 28, 95, 59, 12, 31],
        [79, 97, 65, 63, 29, 74, 92, 60],
        [84, 56, 13, 85, 77, 86, 14, 25],
        [70, 93, 19, 75, 69, 22, 42, 58],
        [57, 44, 85, 49, 66, 36,  5,  3],
        [31, 84, 63, 40, 23, 87, 55, 40],
        [80, 72, 81, 88, 77, 96,  5, 39],
        [65, 30, 41, 82, 99, 24, 34, 28],
        [11, 66, 58, 60,  2, 14, 14, 27],
        [77, 79, 69, 15,  5,  3, 36, 84],
        [51, 10, 40,  4, 56, 33, 29, 59],
        [83, 87, 12, 71, 40, 18, 41,  4],
        [38, 24, 62,  1, 20, 52, 49, 10],
        [75, 81,  4, 55, 38, 39, 64, 82],
        [79, 95, 79, 14, 96, 53,  1, 98],
        [30, 48, 54, 77,  6, 19, 36, 56]],

       [[29, 49, 27, 68, 70,  5,  0, 77],
        [39, 78, 61,  3, 46, 69, 20, 84],
        [54, 70, 48, 46,  2, 39, 96, 17],
        [55, 12, 23,  0, 92, 69, 54, 19],
        [34, 25, 35, 10, 57, 57, 64, 75],
        [30, 38, 38, 25, 56, 25,  9, 47],
        [68, 47, 23, 89, 90, 31, 41, 85],
        [27, 82, 23, 24, 52,  1,  9, 25],
        [76, 91, 87, 78, 54, 64, 93, 72],
        [92, 90, 29, 24, 88,  5,  6,  0],
        [46,  8, 63, 82, 50, 67, 48, 36],
        [21,  6, 44, 72, 61, 93, 56, 68],
        [85, 97, 30, 81, 19, 30, 49, 43],
        [26,  7,  4, 25, 62, 94, 48, 28],
        [42, 95, 90, 58, 93, 90, 56, 18],
        [ 2, 51, 14, 42, 96,  9, 64, 45],
        [14,  0, 27,  9, 95,  2, 99, 68],
        [37, 31, 73, 45,  2, 64, 69, 26],
        [60, 54, 47, 27, 69, 90, 13, 37],
        [46, 75, 36, 66, 98, 19, 10, 55],
        [10, 15, 50, 90, 15, 75, 13, 55],
        [20, 70, 90,  1, 67, 48, 17, 17],
        [84, 52, 38, 52, 45,  1, 37, 12],
        [28,  9, 20, 25, 44, 76, 40, 78],
        [45, 86, 87, 88, 30,  5, 86, 28],
        [38, 40, 51, 44, 43, 20, 93, 49],
        [25, 85, 13,  0, 45, 25, 14, 12],
        [47, 28,  8, 56, 25, 41, 30, 41],
        [84, 70,  3, 45, 86, 67, 34, 76],
        [84,  7, 59, 12, 19, 82, 64, 16],
        [60, 17, 53, 66, 88,  5, 48, 91],
        [96, 89, 54, 27, 77, 33, 13, 28],
        [33, 71, 69, 65, 56, 42, 46, 34],
        [ 0, 56,  4, 25, 18,  7, 50, 16],
        [90, 64,  1, 87, 45, 50, 25, 63]]])>
tf.boolean_mask(x,mask=[True,False,False,True,True,False,False,True],axis=2)
<tf.Tensor: id=96, shape=(4, 35, 4), dtype=int32, numpy=
array([[[46, 64, 34, 21],
        [42, 80, 37,  9],
        [11, 35,  1, 56],
        [41, 93, 52,  8],
        [78, 54, 22, 77],
        [42, 68, 60, 18],
        [68,  3, 63, 21],
        [21, 43, 54, 85],
        [93, 13, 34, 50],
        [77, 38, 34, 95],
        [85, 13, 96, 78],
        [ 4, 19, 99, 22],
        [63, 26, 53, 47],
        [ 3, 36,  5, 44],
        [61, 60, 66, 75],
        [91, 75, 58, 21],
        [10,  3, 10, 84],
        [91, 37,  8, 82],
        [83, 30, 91, 21],
        [79, 28, 95, 31],
        [79, 63, 29, 60],
        [84, 85, 77, 25],
        [70, 75, 69, 58],
        [57, 49, 66,  3],
        [31, 40, 23, 40],
        [80, 88, 77, 39],
        [65, 82, 99, 28],
        [11, 60,  2, 27],
        [77, 15,  5, 84],
        [51,  4, 56, 59],
        [83, 71, 40,  4],
        [38,  1, 20, 10],
        [75, 55, 38, 82],
        [79, 14, 96, 98],
        [30, 77,  6, 56]],

       [[19, 77, 32, 84],
        [15,  7, 24, 63],
        [ 6, 29,  0, 28],
        [ 9, 56, 91, 34],
        [65, 44, 53, 45],
        [ 0, 98, 19, 73],
        [31,  1, 88, 40],
        [82, 79,  4, 69],
        [78,  6, 40, 27],
        [63, 68, 94, 70],
        [70, 43, 31, 46],
        [61, 79, 77, 60],
        [34, 58, 69, 85],
        [45, 58, 10, 92],
        [88, 73, 39, 45],
        [26, 80, 33,  4],
        [47, 14, 13, 79],
        [17, 26, 69, 42],
        [12, 22, 82, 20],
        [52, 71, 46, 60],
        [95, 66, 80, 64],
        [29, 87, 26, 10],
        [53, 69, 89, 49],
        [25, 64, 22, 82],
        [ 8, 93, 25, 55],
        [13, 74, 94, 68],
        [94,  4, 30, 25],
        [37, 77, 25, 28],
        [56, 40,  7, 51],
        [71, 97, 70, 40],
        [21, 79, 55, 58],
        [14,  8, 76, 34],
        [33, 72, 44, 98],
        [75, 87, 63, 26],
        [48, 25, 92,  3]],

       [[75, 49, 30,  6],
        [68, 41, 78, 70],
        [ 5, 81, 47, 83],
        [86, 27, 27, 18],
        [51, 60, 50, 33],
        [15, 24, 78, 47],
        [63, 65, 74, 68],
        [33, 18, 99, 10],
        [29, 78, 10, 30],
        [84, 39, 53, 38],
        [84, 57, 17, 32],
        [ 8, 32, 18, 18],
        [88,  5, 15,  3],
        [30, 73, 68, 67],
        [60, 80,  6, 10],
        [97, 48,  9, 57],
        [14, 61, 50, 16],
        [79, 23, 97, 50],
        [62, 23, 72, 65],
        [74, 45, 51,  6],
        [90, 41,  7, 27],
        [ 8, 59, 39, 47],
        [ 6, 48, 11, 96],
        [64, 36, 82, 68],
        [20, 67, 41, 90],
        [20, 68, 41, 74],
        [92, 74, 24, 77],
        [86, 32, 41, 38],
        [41, 36, 64, 24],
        [76, 69, 87, 29],
        [97, 68, 61, 46],
        [ 3, 92, 13, 64],
        [56, 92,  0, 68],
        [32, 95, 46, 71],
        [18, 25, 58, 34]],

       [[29, 68, 70, 77],
        [39,  3, 46, 84],
        [54, 46,  2, 17],
        [55,  0, 92, 19],
        [34, 10, 57, 75],
        [30, 25, 56, 47],
        [68, 89, 90, 85],
        [27, 24, 52, 25],
        [76, 78, 54, 72],
        [92, 24, 88,  0],
        [46, 82, 50, 36],
        [21, 72, 61, 68],
        [85, 81, 19, 43],
        [26, 25, 62, 28],
        [42, 58, 93, 18],
        [ 2, 42, 96, 45],
        [14,  9, 95, 68],
        [37, 45,  2, 26],
        [60, 27, 69, 37],
        [46, 66, 98, 55],
        [10, 90, 15, 55],
        [20,  1, 67, 17],
        [84, 52, 45, 12],
        [28, 25, 44, 78],
        [45, 88, 30, 28],
        [38, 44, 43, 49],
        [25,  0, 45, 12],
        [47, 56, 25, 41],
        [84, 45, 86, 76],
        [84, 12, 19, 16],
        [60, 66, 88, 91],
        [96, 27, 77, 28],
        [33, 65, 56, 34],
        [ 0, 25, 18, 16],
        [90, 87, 45, 63]]])>

不难发现,这里的 tf.boolean_mask 的用法其实与 tf.gather 非常类似,只不过一个通过掩码
方式采样,一个直接给出索引号采样。

5.6.4 tf.where通过 tf.where(cond, a, b)操作可以根据 cond 条件的真假从参数????或????中读取数据,条件判定规则如

其中????为张量的元素索引,返回的张量大小与????和????一致,当对应位置的cond ???? 为 True,???? ???? 从
???? ???? 中复制数据;当对应位置的cond ???? 为 False,???? ???? 从???? ???? 中复制数据。考虑从 2 个全 1 和全 0 的
3 × 3大小的张量????和????中提取数据,其中cond ???? 为 True 的位

a = tf.ones([3,3])
b = tf.zeros([3,3])
cond = tf.constant([[True,False,True],[False,True,False],[True,True,True]])
tf.where(cond,a,b)
<tf.Tensor: id=104, shape=(3, 3), dtype=float32, numpy=
array([[1., 0., 1.],
       [0., 1., 0.],
       [1., 1., 1.]], dtype=float32)>
cond
<tf.Tensor: id=103, shape=(3, 3), dtype=bool, numpy=
array([[ True, False,  True],
       [False,  True, False],
       [ True,  True,  True]])>
tf.where(cond)
<tf.Tensor: id=105, shape=(6, 2), dtype=int64, numpy=
array([[0, 0],
       [0, 2],
       [1, 1],
       [2, 0],
       [2, 1],
       [2, 2]], dtype=int64)>

5.6.5 scatter_nd

indices = tf.constant([[4],[3],[1],[7]])
update = tf.constant([4.4,3.3,1.1,7.7])
tf.scatter_nd(indices,update,[8])
<tf.Tensor: id=109, shape=(8,), dtype=float32, numpy=array([0. , 1.1, 0. , 3.3, 4.4, 0. , 0. , 7.7], dtype=float32)>
indices = tf.constant([[1],[3]])
updates = tf.constant([# 构造写入数据,即 2 个矩阵
[[5,5,5,5],[6,6,6,6],[7,7,7,7],[8,8,8,8]],
[[1,1,1,1],[2,2,2,2],[3,3,3,3],[4,4,4,4]]
])
indices
<tf.Tensor: id=110, shape=(2, 1), dtype=int32, numpy=
array([[1],
       [3]])>
tf.scatter_nd(indices,updates,[4,4,4])
<tf.Tensor: id=113, shape=(4, 4, 4), dtype=int32, numpy=
array([[[0, 0, 0, 0],
        [0, 0, 0, 0],
        [0, 0, 0, 0],
        [0, 0, 0, 0]],

       [[5, 5, 5, 5],
        [6, 6, 6, 6],
        [7, 7, 7, 7],
        [8, 8, 8, 8]],

       [[0, 0, 0, 0],
        [0, 0, 0, 0],
        [0, 0, 0, 0],
        [0, 0, 0, 0]],

       [[1, 1, 1, 1],
        [2, 2, 2, 2],
        [3, 3, 3, 3],
        [4, 4, 4, 4]]])>

meshgrid

通过 tf.meshgrid 函数可以方便地生成二维网格的采样点坐标,方便可视化等应用场
合。考虑 2 个自变量 x 和 y 的 Sinc 函数表达式为:

points = []
from tensorflow import keras
import matplotlib
import numpy as np

for x in range(-8,8,100): # 循环生成 x 坐标,100 个采样点
for y in range(-8,8,100): # 循环生成 y 坐标,100 个采样点
z = sinc(x,y) # 计算每个点(x,y)处的 sinc 函数值
points.append([x,y,z]) # 保存采样点

x = tf.linspace(-8.,8,100) # 设置 x 轴的采样点
y = tf.linspace(-8.,8,100) # 设置 y 轴的采样点
x,y = tf.meshgrid(x,y)
x
<tf.Tensor: id=234, shape=(100, 100), dtype=float32, numpy=
array([[-8.       , -7.8383837, -7.676768 , ...,  7.6767673,  7.8383837,
         8.       ],
       [-8.       , -7.8383837, -7.676768 , ...,  7.6767673,  7.8383837,
         8.       ],
       [-8.       , -7.8383837, -7.676768 , ...,  7.6767673,  7.8383837,
         8.       ],
       ...,
       [-8.       , -7.8383837, -7.676768 , ...,  7.6767673,  7.8383837,
         8.       ],
       [-8.       , -7.8383837, -7.676768 , ...,  7.6767673,  7.8383837,
         8.       ],
       [-8.       , -7.8383837, -7.676768 , ...,  7.6767673,  7.8383837,
         8.       ]], dtype=float32)>
x.shape,y.shape
(TensorShape([100, 100]), TensorShape([100, 100]))
z = tf.sqrt(x**2,y**2)
z = tf.sin(z)/z
z
<tf.Tensor: id=242, shape=(100, 100), dtype=float32, numpy=
array([[0.1236698 , 0.12756181, 0.12822306, ..., 0.12822306, 0.12756181,
        0.1236698 ],
       [0.1236698 , 0.12756181, 0.12822306, ..., 0.12822306, 0.12756181,
        0.1236698 ],
       [0.1236698 , 0.12756181, 0.12822306, ..., 0.12822306, 0.12756181,
        0.1236698 ],
       ...,
       [0.1236698 , 0.12756181, 0.12822306, ..., 0.12822306, 0.12756181,
        0.1236698 ],
       [0.1236698 , 0.12756181, 0.12822306, ..., 0.12822306, 0.12756181,
        0.1236698 ],
       [0.1236698 , 0.12756181, 0.12822306, ..., 0.12822306, 0.12756181,
        0.1236698 ]], dtype=float32)>
from matplotlib import pyplot as plt
from mpl_toolkits.mplot3d import Axes3D  #倒入3d坐标
fig = plt.figure()
<Figure size 432x288 with 0 Axes>
ax = Axes3D(fig) #设置3d坐标轴
# 根据网格点绘制 sinc 函数 3D 曲面
plt.show(ax.contour3D(x.numpy(), y.numpy(), z.numpy(), 50))

数据集的加载

到这里为止,我们已经学习完张量的常用操作方法,已具备实现大部分深度网络的技

术储备。最后我们将以一个完整的张量方式实现的分类网络模型实战收尾本章。在进入实
战之前,我们先正式介绍对于常用的经典数据集,如何利用 TensorFlow 提供的工具便捷地
加载数据集。对于自定义的数据集的加载,我们会在后续章节介绍。

在 TensorFlow 中,keras.datasets 模块提供了常用经典数据集的自动下载、管理、加载
与转换功能,并且提供了 tf.data.Dataset 数据集对象,方便实现多线程(Multi-threading)、预
处理(Preprocessing)、随机打散(Shuffle)和批训练(Training on Batch)等常用数据集的功能。

from tensorflow.keras import datasets
(x,y),(x_train,y_train) = datasets.mnist.load_data()
x
array([[[0, 0, 0, ..., 0, 0, 0],
        [0, 0, 0, ..., 0, 0, 0],
        [0, 0, 0, ..., 0, 0, 0],
        ...,
        [0, 0, 0, ..., 0, 0, 0],
        [0, 0, 0, ..., 0, 0, 0],
        [0, 0, 0, ..., 0, 0, 0]],

       [[0, 0, 0, ..., 0, 0, 0],
        [0, 0, 0, ..., 0, 0, 0],
        [0, 0, 0, ..., 0, 0, 0],
        ...,
        [0, 0, 0, ..., 0, 0, 0],
        [0, 0, 0, ..., 0, 0, 0],
        [0, 0, 0, ..., 0, 0, 0]],

       [[0, 0, 0, ..., 0, 0, 0],
        [0, 0, 0, ..., 0, 0, 0],
        [0, 0, 0, ..., 0, 0, 0],
        ...,
        [0, 0, 0, ..., 0, 0, 0],
        [0, 0, 0, ..., 0, 0, 0],
        [0, 0, 0, ..., 0, 0, 0]],

       ...,

       [[0, 0, 0, ..., 0, 0, 0],
        [0, 0, 0, ..., 0, 0, 0],
        [0, 0, 0, ..., 0, 0, 0],
        ...,
        [0, 0, 0, ..., 0, 0, 0],
        [0, 0, 0, ..., 0, 0, 0],
        [0, 0, 0, ..., 0, 0, 0]],

       [[0, 0, 0, ..., 0, 0, 0],
        [0, 0, 0, ..., 0, 0, 0],
        [0, 0, 0, ..., 0, 0, 0],
        ...,
        [0, 0, 0, ..., 0, 0, 0],
        [0, 0, 0, ..., 0, 0, 0],
        [0, 0, 0, ..., 0, 0, 0]],

       [[0, 0, 0, ..., 0, 0, 0],
        [0, 0, 0, ..., 0, 0, 0],
        [0, 0, 0, ..., 0, 0, 0],
        ...,
        [0, 0, 0, ..., 0, 0, 0],
        [0, 0, 0, ..., 0, 0, 0],
        [0, 0, 0, ..., 0, 0, 0]]], dtype=uint8)
x_train
array([[[0, 0, 0, ..., 0, 0, 0],
        [0, 0, 0, ..., 0, 0, 0],
        [0, 0, 0, ..., 0, 0, 0],
        ...,
        [0, 0, 0, ..., 0, 0, 0],
        [0, 0, 0, ..., 0, 0, 0],
        [0, 0, 0, ..., 0, 0, 0]],

       [[0, 0, 0, ..., 0, 0, 0],
        [0, 0, 0, ..., 0, 0, 0],
        [0, 0, 0, ..., 0, 0, 0],
        ...,
        [0, 0, 0, ..., 0, 0, 0],
        [0, 0, 0, ..., 0, 0, 0],
        [0, 0, 0, ..., 0, 0, 0]],

       [[0, 0, 0, ..., 0, 0, 0],
        [0, 0, 0, ..., 0, 0, 0],
        [0, 0, 0, ..., 0, 0, 0],
        ...,
        [0, 0, 0, ..., 0, 0, 0],
        [0, 0, 0, ..., 0, 0, 0],
        [0, 0, 0, ..., 0, 0, 0]],

       ...,

       [[0, 0, 0, ..., 0, 0, 0],
        [0, 0, 0, ..., 0, 0, 0],
        [0, 0, 0, ..., 0, 0, 0],
        ...,
        [0, 0, 0, ..., 0, 0, 0],
        [0, 0, 0, ..., 0, 0, 0],
        [0, 0, 0, ..., 0, 0, 0]],

       [[0, 0, 0, ..., 0, 0, 0],
        [0, 0, 0, ..., 0, 0, 0],
        [0, 0, 0, ..., 0, 0, 0],
        ...,
        [0, 0, 0, ..., 0, 0, 0],
        [0, 0, 0, ..., 0, 0, 0],
        [0, 0, 0, ..., 0, 0, 0]],

       [[0, 0, 0, ..., 0, 0, 0],
        [0, 0, 0, ..., 0, 0, 0],
        [0, 0, 0, ..., 0, 0, 0],
        ...,
        [0, 0, 0, ..., 0, 0, 0],
        [0, 0, 0, ..., 0, 0, 0],
        [0, 0, 0, ..., 0, 0, 0]]], dtype=uint8)
x == x_train
D:anacodaenvstf2-cpulibsite-packagesipykernel_launcher.py:1: DeprecationWarning: elementwise comparison failed; this will raise an error in the future.
  """Entry point for launching an IPython kernel.





False
x.shape ,x_train.shape
((60000, 28, 28), (10000, 28, 28))
train_db = tf.data.Dataset.from_tensor_slices((x, y)) # 构建 Dataset 对象
train_db
<TensorSliceDataset shapes: ((28, 28), ()), types: (tf.uint8, tf.uint8)>

5.7.1 随机打散

通过 Dataset.shuffle(buffer_size)工具可以设置 Dataset 对象随机打散数据之间的顺序,

防止每次训练时数据按固定顺序产生,从而使得模型尝试“记忆”住标签信息,代码实现
如下:

train_db = train_db.shuffle(10000) # 随机打散样本,不会打乱样本与标签映射关系

其中,buffer_size 参数指定缓冲池的大小,一般设置为一个较大的常数即可。调用 Dataset
提供的这些工具函数会返回新的 Dataset 对象,可以通过

批训练

train_db = train_db.batch(128) # 设置批训练,batch size 为 128

其中 128 为 Batch Size 参数,即一次并行计算 128 个样本的数据。Batch Size 一般根据用户
的 GPU 显存资源来设置,当显存不足时,可以适量减少 Batch Size 来减少算法的显存使用
量。

预处理

从 keras.datasets 中加载的数据集的格式大部分情况都不能直接满足模型的输入要求,
因此需要根据用户的逻辑自行实现预处理步骤。Dataset 对象通过提供 map(func)工具函
数,可以非常方便地调用用户自定义的预处理逻辑,它实现在 func 函数里。例如,下方代
码调用名为 preprocess 的函数完成每个样本的预处理:

# 预处理函数实现在 preprocess 函数中,传入函数名即可
#train_db = train_db.map(preprocess)
def preprocess(x,y):#预处理函数
    # 调用此函数时会自动传入 x,y 对象,shape 为[b, 28, 28], [b]
    # 标准化到 0~1
    x = tf.cast(x,dtype=tf.float32)/255
    x = tf.reshape(x,[-1,28*28])#daping
    y = tf.cast(y,dtype=tf.int32)
    y = tf.one_hot(y,depth=10)  #采用onehot编码
    # 返回的 x,y 将替换传入的 x,y 参数,从而实现数据的预处理功能
    return x,y
# 预处理函数实现在 preprocess 函数中,传入函数名即可
train_db = train_db.map(preprocess)
WARNING:tensorflow:Entity <function preprocess at 0x000000F00FA13D08> could not be transformed and will be executed as-is. Please report this to the AutoGraph team. When filing the bug, set the verbosity to 10 (on Linux, `export AUTOGRAPH_VERBOSITY=10`) and attach the full output. Cause: Bad argument number for Name: 3, expecting 4
WARNING: Entity <function preprocess at 0x000000F00FA13D08> could not be transformed and will be executed as-is. Please report this to the AutoGraph team. When filing the bug, set the verbosity to 10 (on Linux, `export AUTOGRAPH_VERBOSITY=10`) and attach the full output. Cause: Bad argument number for Name: 3, expecting 4
train_db
<MapDataset shapes: ((None, 784), (None, None, 10, 10, 10)), types: (tf.float32, tf.float32)>

循环训练

#对于 Dataset 对象,在使用时可以通过
#for step, (x,y) in enumerate(train_db): # 迭代数据集对象,带 step 参数
if step % 100 == 0:
    
  File "<ipython-input-154-3179dee46fa9>", line 1
    if step % 100 == 0
                      ^
SyntaxError: invalid syntax

最后

以上就是飞快蜡烛为你收集整理的tensorflow2(6)数据增幅的全部内容,希望文章能够帮你解决tensorflow2(6)数据增幅所遇到的程序开发问题。

如果觉得靠谱客网站的内容还不错,欢迎将靠谱客网站推荐给程序员好友。

本图文内容来源于网友提供,作为学习参考使用,或来自网络收集整理,版权属于原作者所有。
点赞(55)

评论列表共有 0 条评论

立即
投稿
返回
顶部