分享

TensorFlow常用API | 天空的城

 imelee 2017-05-26
文章目录
  1. 轴(axis)
  2. tf.slice

越来越感觉学习python一个第三方库就跟学一门新语言差不多,庞大的TensorFlow更是如此。这篇博客会持续记录关于tensorflow的一些api的用法。

轴(axis)

在tensorflow中,最基本的变量或者常量都是tensor类型,说直接点就是多维数组,因此涉及到了很多在某一个维度上操作的API,比如tf.reduce_*()tf_arg*()tf_expand_dims()等,这些基本的API都有一个共同的参数:axis=,也就说这些操作都是哪些维度(轴)上做运算,除了在tf中,在numpy里面也有这一参数,之前一直是一知半解。现在总结一下这个axis,由于numpy与tensorflow的axis基本一样,因此下面以numpy为例。

首先可以从多维数组(tensor)的shape来得到数组的维度: d=len(array.shape):

1
2
3
4
5
In [2]: a = np.array([[1,2,3],[4,5,6]])
In [3]: a.shape
Out[3]: (2, 3)
In [4]: len(a.shape)
Out[4]: 2

也就是说上述的数组a的shape为(2,3),维度是2维,也就是有两层嵌套数组,看看有多少层[],因此axis的取值只能为0,1,下面以sum为例,看看不同的axis的结果:

1
2
3
4
5
6
7
8
In [5]: np.sum(a)
Out[5]: 21
In [6]: np.sum(a,axis=0)
Out[6]: array([5, 7, 9])
In [7]: np.sum(a,axis=1)
Out[7]: array([ 6, 15])

一个一个的说明。首先需要明白操作的对象是谁?这这个例子中,就是那些元素相加。怎么来确定是哪些元素呢?个人使用的原则是:

对原来数组去掉axis+1层嵌套([]),对剩下的元素再分组进行相应的操作。

首先如果不加axis参数,则默认对所有元素直接做sum操作,因此结果就是21。当axis=0的时候, 相加的对象,则是变成了[1,2,3]+[4,5,6],按照上面的原则,a去掉一层嵌套,也就是[]之后,变为: [1,2,3],[4,5,6] 那么sum的对象: [1,2,3]+[4,5,6]=[5,7,9]。 再接着看axis=1,这时候需要去掉两层嵌套,去掉第一层变为:[1,2,3], [4,5,6],去掉第二层嵌套,剩下了1,2,34,5,6两个组,因此两个组内各个相加即可,得到的最终的结果为:[6,15]。 下面以同样的思路,计算最小值:

1
2
3
4
5
6
7
8
In [3]: np.min(a)
Out[3]: 1
In [4]: np.min(a, axis=0)
Out[4]: array([1, 2, 3])
In [5]: np.min(a, axis=1)
Out[5]: array([1, 4])

有一个稍微特殊的是expand_dims 这个是扩展维度。对于这个api它的axis没有上限约束,不过当所有的元素自成一个列表之后,不再变化。它的axis参数不一样的地方,是在axis的位置加入新的一个轴,也就是说此时对原数组去掉axis层嵌套,而非axis+1层,然后再对每组进行加一个轴,看例子:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
In [25]: a
Out[25]:
array([[1, 2, 3],
[4, 5, 6]])
In [26]: np.expand_dims(a, axis=0) # 新加的轴为第0轴
Out[26]:
array([[[1, 2, 3],
[4, 5, 6]]])
In [27]: np.expand_dims(a, axis=1)# 新加的轴为第1层
Out[27]:
array([[[1, 2, 3]],
[[4, 5, 6]]])
In [28]: np.expand_dims(a, axis=2)
Out[28]:
array([[[1],
[2],
[3]],
[[4],
[5],
[6]]])

到现在基本上tensorflow以及numpy的axis这个参数基本没有问题了。 下面再看几个常用的函数。

tf.slice

这个函数主要用于对多维数组的截取(切片),首先看看原型:

1
tf.slice(input_, begin, size, name=None)
  • input_: 输入
  • begin: 开始截取的位置
  • size: 每一个维度截取的长度

直接看文档的例子:

1
2
3
4
5
6
7
8
# 'input' is [[[1, 1, 1], [2, 2, 2]],
# [[3, 3, 3], [4, 4, 4]],
# [[5, 5, 5], [6, 6, 6]]]
tf.slice(input, [1, 0, 0], [1, 1, 3]) ==> [[[3, 3, 3]]]
tf.slice(input, [1, 0, 0], [1, 2, 3]) ==> [[[3, 3, 3],
[4, 4, 4]]]
tf.slice(input, [1, 0, 0], [2, 1, 3]) ==> [[[3, 3, 3]],
[[5, 5, 5]]]

先看第一个例子:

begin=[1,0,0]也就是开始的元素,这里就是input[1][0][0] = 3,就是第二行[3,3,3]的第一个3,,再看截取距离: [1,1,3],首先在axis=0的上截取一个距离,得到是第二整行[[3,3,3],[4,4,4]],在axis=1上截取2个距离,得到[3,3,3],最后在axis=2上截取了3个距离,也就是[3,3,3]这三个元素都得到,最终结果为:[3,3,3]

第二个例子解析:

开始的元素仍然是input[1][0][0=0],在axis=0上截取1个距离,得到仍然是[[[3,3,3],[4,4,4]]] ,第二个维度上截取距离为2,得到[[[3,3,3],[4,4,4]]],最后在第三个维度上截取3,即三个元素都保留,得到结果。这里如果第三个维度截取2的话,那么得到就是[[[3,3],[4,4]]]

第三个同理,不再叙述。

参考: 1

    本站是提供个人知识管理的网络存储空间,所有内容均由用户发布,不代表本站观点。请注意甄别内容中的联系方式、诱导购买等信息,谨防诈骗。如发现有害或侵权内容,请点击一键举报。
    转藏 分享 献花(0

    0条评论

    发表

    请遵守用户 评论公约

    类似文章 更多