分享

如何学习TensorFlow源码

 _为往圣继绝学_ 2016-10-20

在静下心来默默看了大半年机器学习的资料并做了些实践后,打算学习下现在热门的TensorFlow的实现,毕竟系统这块和自己关系较大。本文会简单的说明一下如何阅读TensorFlow的源码。最重要的是了解其构建工具bazel以及脚本语言调用c或cpp的包裹工具swig。这里假设大家对bazel及swig以及有所了解(不了解的可以google下)。要看代码首先要知道代码怎么构建,因此本文的一大部分会关注构建这块。

如果从源码构建TensorFlow会需要执行如下命令:

1
bazel build -c opt //tensorflow/tools/pip_package:build_pip_package

对应的BUILD文件的rule为:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
sh_binary(
    name = 'build_pip_package',
    srcs = ['build_pip_package.sh'],
    data = [
        'MANIFEST.in',
        'README',
        'setup.py',
        '//tensorflow/core:framework_headers',
        ':other_headers',
        ':simple_console',
        '//tensorflow:tensorflow_py',
        '//tensorflow/examples/tutorials/mnist:package',
        '//tensorflow/models/embedding:package',
        '//tensorflow/models/image/cifar10:all_files',
        '//tensorflow/models/image/mnist:convolutional',
        '//tensorflow/models/rnn:package',
        '//tensorflow/models/rnn/ptb:package',
        '//tensorflow/models/rnn/translate:package',
        '//tensorflow/tensorboard',
    ],
)

sh_binary在这里的主要作用是生成data的这些依赖。一个一个来看,一开始的三个文件MANIFEST.in、README、setup.py是直接存在的,因此不会有什么操作。

“//tensorflow/core:framework_headers”:
其对应的rule为:

1
2
3
4
5
6
7
8
filegroup(
    name = 'framework_headers',
    srcs = [
        'framework/allocator.h',
        ......
        'util/device_name_utils.h',
    ],
)

这里filegroup的作用是给这一堆头文件一个别名,方便其他rule引用。

“:other_headers”:
rule为:

1
2
3
4
5
6
7
transitive_hdrs(
    name = 'other_headers',
    deps = [
        '//third_party/eigen3',
        '//tensorflow/core:protos_all_cc',
    ],
)

transitive_hdrs的定义在:

1
load('//tensorflow:tensorflow.bzl', 'transitive_hdrs')

实现为:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
# Bazel rule for collecting the header files that a target depends on.
def _transitive_hdrs_impl(ctx):
  outputs = set()
  for dep in ctx.attr.deps:
    outputs += dep.cc.transitive_headers
  return struct(files=outputs)
_transitive_hdrs = rule(attrs={
    'deps': attr.label_list(allow_files=True,
                            providers=['cc']),
},
                        implementation=_transitive_hdrs_impl,)
def transitive_hdrs(name, deps=[], **kwargs):
  _transitive_hdrs(name=name + '_gather',
                   deps=deps)
  native.filegroup(name=name,
                   srcs=[':' + name + '_gather'])

其作用依旧是收集依赖需要的头文件。

“:simple_console”:
其rule为:

1
2
3
4
5
6
py_binary(
    name = 'simple_console',
    srcs = ['simple_console.py'],
    srcs_version = 'PY2AND3',
    deps = ['//tensorflow:tensorflow_py'],
)
1
2
3
4
5
6
7
py_library(
    name = 'tensorflow_py',
    srcs = ['__init__.py'],
    srcs_version = 'PY2AND3',
    visibility = ['//visibility:public'],
    deps = ['//tensorflow/python'],
)

simple_console.py的代码的主要部分是:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import code
import sys
def main(_):
  '''Run an interactive console.'''
  code.interact()
  return 0
if __name__ == '__main__':
  sys.exit(main(sys.argv))

可以看到起通过deps = [“//tensorflow/python”]构建了依赖包,然后生成了对应的执行文件。看下依赖的rule规则。
//tensorflow/python对应的rule为:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
py_library(
    name = 'python',
    srcs = [
        '__init__.py',
    ],
    srcs_version = 'PY2AND3',
    visibility = ['//tensorflow:__pkg__'],
    deps = [
        ':client',
        ':client_testlib',
        ':framework',
        ':framework_test_lib',
        ':kernel_tests/gradient_checker',
        ':platform',
        ':platform_test',
        ':summary',
        ':training',
        '//tensorflow/contrib:contrib_py',
    ],
)

这里如果仔细看的话会发现其主要是生成一堆python的模块。从这里貌似可以看出每个python的module都对应了一个rule,且module依赖的module都写在了deps里。特别的,作为一个C++的切入,我们关注下training这个依赖:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
py_library(
    name = 'training',
    srcs = glob(
        ['training/**/*.py'],
        exclude = ['**/*test*'],
    ),
    srcs_version = 'PY2AND3',
    deps = [
        ':client',
        ':framework',
        ':lib',
        ':ops',
        ':protos_all_py',
        ':pywrap_tensorflow',
        ':training_ops',
    ],
)

这里其依赖的pywrap_tensorflow的rule为:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
tf_py_wrap_cc(
    name = 'pywrap_tensorflow',
    srcs = ['tensorflow.i'],
    swig_includes = [
        'client/device_lib.i',
        'client/events_writer.i',
        'client/server_lib.i',
        'client/tf_session.i',
        'framework/python_op_gen.i',
        'lib/core/py_func.i',
        'lib/core/status.i',
        'lib/core/status_helper.i',
        'lib/core/strings.i',
        'lib/io/py_record_reader.i',
        'lib/io/py_record_writer.i',
        'platform/base.i',
        'platform/numpy.i',
        'util/port.i',
        'util/py_checkpoint_reader.i',
    ],
    deps = [
        ':py_func_lib',
        ':py_record_reader_lib',
        ':py_record_writer_lib',
        ':python_op_gen',
        ':tf_session_helper',
        '//tensorflow/core/distributed_runtime:server_lib',
        '//tensorflow/core/distributed_runtime/rpc:grpc_server_lib',
        '//tensorflow/core/distributed_runtime/rpc:grpc_session',
        '//util/python:python_headers',
    ],
)

tf_py_wrap_cc为其自己实现的一个rule,这里的.i就是SWIG的interface文件。来看下其实现:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
def tf_py_wrap_cc(name, srcs, swig_includes=[], deps=[], copts=[], **kwargs):
  module_name = name.split('/')[-1]
  # Convert a rule name such as foo/bar/baz to foo/bar/_baz.so
  # and use that as the name for the rule producing the .so file.
  cc_library_name = '/'.join(name.split('/')[:-1] + ['_' + module_name + '.so'])
  extra_deps = []
  _py_wrap_cc(name=name + '_py_wrap',
              srcs=srcs,
              swig_includes=swig_includes,
              deps=deps + extra_deps,
              module_name=module_name,
              py_module_name=name)
  native.cc_binary(
      name=cc_library_name,
      srcs=[module_name + '.cc'],
      copts=(copts + ['-Wno-self-assign', '-Wno-write-strings']
             + tf_extension_copts()),
      linkopts=tf_extension_linkopts(),
      linkstatic=1,
      linkshared=1,
      deps=deps + extra_deps)
  native.py_library(name=name,
                    srcs=[':' + name + '.py'],
                    srcs_version='PY2AND3',
                    data=[':' + cc_library_name])

按照SWIG的正常流程,先要通过swig命令生成我们的wrap的c文件,然后和依赖生成我们的so文件,最后生成一个同名的python文件用于import。这里native.cc_binary和native.py_library做了我们后面的两件事情,而swig命令的执行则交给了_py_wrap_cc。其实现为:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
_py_wrap_cc = rule(attrs={
    'srcs': attr.label_list(mandatory=True,
                            allow_files=True,),
    'swig_includes': attr.label_list(cfg=DATA_CFG,
                                     allow_files=True,),
    'deps': attr.label_list(allow_files=True,
                            providers=['cc'],),
    'swig_deps': attr.label(default=Label(
        '//tensorflow:swig')),  # swig_templates
    'module_name': attr.string(mandatory=True),
    'py_module_name': attr.string(mandatory=True),
    'swig_binary': attr.label(default=Label('//tensorflow:swig'),
                              cfg=HOST_CFG,
                              executable=True,
                              allow_files=True,),
},
                   outputs={
                       'cc_out': '%{module_name}.cc',
                       'py_out': '%{py_module_name}.py',
                   },
                   implementation=_py_wrap_cc_impl,)

_py_wrap_cc_impl的实现为:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
# Bazel rules for building swig files.
def _py_wrap_cc_impl(ctx):
  srcs = ctx.files.srcs
  if len(srcs) != 1:
    fail('Exactly one SWIG source file label must be specified.', 'srcs')
  module_name = ctx.attr.module_name
  cc_out = ctx.outputs.cc_out
  py_out = ctx.outputs.py_out
  src = ctx.files.srcs[0]
  args = ['-c++', '-python']
  args += ['-module', module_name]
  args += ['-l' + f.path for f in ctx.files.swig_includes]
  cc_include_dirs = set()
  cc_includes = set()
  for dep in ctx.attr.deps:
    cc_include_dirs += [h.dirname for h in dep.cc.transitive_headers]
    cc_includes += dep.cc.transitive_headers
  args += ['-I' + x for x in cc_include_dirs]
  args += ['-I' + ctx.label.workspace_root]
  args += ['-o', cc_out.path]
  args += ['-outdir', py_out.dirname]
  args += [src.path]
  outputs = [cc_out, py_out]
  ctx.action(executable=ctx.executable.swig_binary,
             arguments=args,
             mnemonic='PythonSwig',
             inputs=sorted(set([src]) + cc_includes + ctx.files.swig_includes +
                         ctx.attr.swig_deps.files),
             outputs=outputs,
             progress_message='SWIGing {input}'.format(input=src.path))
  return struct(files=set(outputs))

这里的ctx.executable.swig_binary是一个shell脚本,内容为:

1
2
3
4
5
6
7
8
9
# If possible, read swig path out of 'swig_path' generated by configure
SWIG=swig
SWIG_PATH=tensorflow/tools/swig/swig_path
if [ -e $SWIG_PATH ]; then
  SWIG=`cat $SWIG_PATH`
fi
# If this line fails, rerun configure to set the path to swig correctly
'$SWIG' '$@'

可以看到起就是调用了swig命令。

“//tensorflow:tensorflow_py”:
其rule为:

1
2
3
4
5
6
7
py_library(
    name = 'tensorflow_py',
    srcs = ['__init__.py'],
    srcs_version = 'PY2AND3',
    visibility = ['//visibility:public'],
    deps = ['//tensorflow/python'],
)

可以看到起主要依赖了我们上面生成的”//tensorflow/python”这个module。

剩余的几个其实和主框架关系不大,主要是生成一些model、文档啥的。

现在清楚了其构建链后,我们来看个简单的程序,其通过梯度下降算法求线性拟合的W和b。我们会从这个例子入手看下如何找到其使用的函数的具体实现的源码位置:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
(python3.5)?  tmp cat th.py
import tensorflow as tf
import numpy as np
x_data = np.random.rand(100).astype(np.float32)
y_data = x_data * 0.1 + 0.3
W = tf.Variable(tf.random_uniform([1], -1.0, 1.0))
b = tf.Variable(tf.zeros([1]))
y = W * x_data + b
loss = tf.reduce_mean(tf.square(y - y_data))
optimizer = tf.train.GradientDescentOptimizer(0.5)
train = optimizer.minimize(loss)
init = tf.initialize_all_variables()
sess = tf.Session()
sess.run(init)
for step in range(0, 201):
    sess.run(train)
    if step % 20 == 0:
        print(step, sess.run(W), sess.run(b))
(python3.5)?  tmp python th.py
0 [ 0.42190057] [ 0.17155224]
20 [ 0.1743494] [ 0.26045772]
40 [ 0.11817314] [ 0.29033473]
60 [ 0.10444205] [ 0.29763755]
80 [ 0.10108578] [ 0.29942256]
100 [ 0.10026541] [ 0.29985884]
120 [ 0.10006487] [ 0.2999655]
140 [ 0.10001585] [ 0.29999158]
160 [ 0.10000388] [ 0.29999796]
180 [ 0.10000096] [ 0.29999951]
200 [ 0.10000025] [ 0.29999989]

从我们上面的分析可以看到,import tensorflow as tf来自于tensorflow目录下的__init__.py文件,其内容为:

1
from tensorflow.python import *

再来看tf.Variable,在tensorflow.python的__init__.py中可以看到其导入了很多符号。但要定位到Variable还是比较困难,因为其很多直接是import *。所以一个快速定位的方法是直接grep这个class:

1
2
  python grep 'class Variable(' -R ./*
./ops/variables.py:class Variable(object):

对于tf.Session等也可以用同样的方法定位。

我们来找个走SWIG包裹的,如果我们去看sess.run,我们会看到如下的代码:

1
2
3
return tf_session.TF_Run(session, options,
                         feed_dict, fetch_list, target_list,
                         run_metadata)

这里tf_session就是一个SWIG包裹的模块:

1
from tensorflow.python import pywrap_tensorflow as tf_session

pywrap_tensorflow在源码里是找不到的,因为这个得从SWIG生成后才有,我们可以从.i文件里找下TF_Run的声明,或者直接grep下这个函数:

1
2
  tensorflow grep 'TF_Run(' -R ./*
./core/client/tensor_c_api.cc:void TF_Run(TF_Session* s, const TF_Buffer* run_options,

这样就可以看其实现了:

1
2
3
4
5
6
7
8
9
10
11
12
13
void TF_Run(TF_Session* s, const TF_Buffer* run_options,
            // Input tensors
            const char** c_input_names, TF_Tensor** c_inputs, int ninputs,
            // Output tensors
            const char** c_output_tensor_names, TF_Tensor** c_outputs,
            int noutputs,
            // Target nodes
            const char** c_target_node_names, int ntargets,
            TF_Buffer* run_metadata, TF_Status* status) {
  TF_Run_Helper(s, nullptr, run_options, c_input_names, c_inputs, ninputs,
                c_output_tensor_names, c_outputs, noutputs, c_target_node_names,
                ntargets, run_metadata, status);
}

    本站是提供个人知识管理的网络存储空间,所有内容均由用户发布,不代表本站观点。请注意甄别内容中的联系方式、诱导购买等信息,谨防诈骗。如发现有害或侵权内容,请点击一键举报。
    转藏 分享 献花(0

    0条评论

    发表

    请遵守用户 评论公约

    类似文章 更多