分享

如何导出CNTK网络节点训练结果?

 imelee 2016-12-29

前言

CNTK作为微软在机器学习领域的重点项目,已经成功的应用于各种产品之中,CNTK可以快速且方便的训练神经网络并使用神经网络。其中CNTK的一大亮点就是CNTK的速度,在机器学习领域,训练速度以及计算决定了工具是否好用。

但是一些情况下,也许我们只需要使用CNTK对于网络训练的速度,之后导出训练的结果用于其他第三方软件使用,也许是绘图,也许是进行实际的生产计算。

通过dumpnode命令导出节点数据

dumpnode是一个action,在配置文件中可以进行指定,通过这个命令可以导出指定的节点信息(包括其结果)以文本的形式导出至一个文件。在使用这个dumpnode时需要制定如下参数。

modelPath 指定网络模型的文件。一般这个参数再配置文件中应该已经在上级给出了指定。
nodeName (可选)指定用于导出的网络节点名称。如果不指定或者指定的是一个不存在的名称则导出所有的节点。
nodeNameRegex (可选)可以通过正则表达式的方式指定网络节点名称,导出相匹配的节点。如果指定了这个参数的话,nodeName将值会被忽视。
outputFile (可选)指定导出文件的位置,如果不指定的情况下,默认会被设置为通网络模型文件同路径下的一个文本文件。在网络模型的文件名后面加上.{nodename}.txt
printValues (可选)是否导出网络节点的值,默认值为true
printMetadata (可选)是否导出元数据信息,包括节点名称、维度等,默认值为true

我们以Simple2D(位于CNTK的~\CNTK\Examples\Other\Simple2d\目录)为例,
我们只需要在配置文件末尾的位置添加如下内容,不需要设置modelPath是因为modelPath已经在上面指定了,

########################################
#  DUMP NODE INFORMATION               #
########################################
Simple_Demo_DumpNode=[
    action = "dumpnode"
]
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6

之后我们在配置文件的command中指定新设置的Simple_Demo_DumpNode即可,

command = Simple_Demo_Train:Simple_Demo_Test:Simple_Demo_Output:Simple_Demo_DumpNode
  • 1
  • 1

之后从新运行后,运行的结果如下,并在其中会多出有关dumpnode的相关内容,
这里写图片描述
(注,那个Action "dumpnode" complete.是笔者后期标黄的,不是默认输出就是黄色的。)

此时,我们可以去往model所在目录(~\CNTK\Examples\Other\Simple2d\Output\Models)打开导出的文件,
这里写图片描述

默认情况下,是导出的网络节点是包括参数以及元数据的,所以数据信息比较多。
这里写图片描述

通过探究CNTK源码来实现

CNTK可以通过dumpnode命令给出的文本格式的导出文件,如果是用来看的则没什么大问题,但是如果希望实现自动化的某些操作,则需要解析这个文件后输入下游程序中,解析过程也许会十分的繁琐,

我们也可以通过探究CNTK的源码来找出节点数据存储的位置,直接使用CNTK的文件即可,
这种情况,只需要我们在我们的程序中加载CNTK中的EvalDLL读取网络模型文件后,直接使用其参数,后者根据需要导出成我们所期望的形式。

CNTK中核心的一个类是ComputationNetwork,一个网络模型对应着一个ComputationNetwork的实例。
如果我们只是为了读取网络模型,那么我们可以通过如下代码打开一个CNTK的网络模型文件。

ComputationNetwork net(-1);    // -1 means we only need to use CPU
net.Load<ElemType>(modelPath); // ElemType should be a float type (float or double)
  • 1
  • 2
  • 1
  • 2

这样我们就可以将模型文件加载到net中,进而我们可以通过net的方法去获取我们所需要的内容,

class ComputationNetwork : ...
{
...
public:
    ComputationNetwork();
    ComputationNetwork(DEVICEID_TYPE deviceId);
    virtual ~ComputationNetwork();

    template <class ElemType> 
    void Load(const std::wstring& fileName);

    ComputationNodeBasePtr GetNodeFromName(const std::wstring& name) const;

    // GetNodesFromName - Get all the nodes from a name that may match a wildcard '*' pattern
    //   only patterns with a single '*' at the beginning, in the middle, or at the end are accepted
    // name - node name (with possible wildcard)
    // returns: vector of nodes that match the pattern, may return an empty vector for no match
    std::vector<ComputationNodeBasePtr> GetNodesFromName(const std::wstring& name) const;

    // these are specified as such by the user
    const std::vector<ComputationNodeBasePtr>& FeatureNodes();
    const std::vector<ComputationNodeBasePtr>& LabelNodes();
    const std::vector<ComputationNodeBasePtr>& FinalCriterionNodes();
    const std::vector<ComputationNodeBasePtr>& EvaluationNodes();
    const std::vector<ComputationNodeBasePtr>& OutputNodes();
    ...
private:
    // main node holder
    std::map<const std::wstring, ComputationNodeBasePtr, nocase_compare> m_nameToNodeMap; // [name] -> node; this is the main container that holds this networks' nodes

    // node groups
    // These are specified by the user by means of tags or explicitly listing the node groups.
    // TODO: Are these meant to be disjoint?
    std::vector<ComputationNodeBasePtr> m_featureNodes;    // tag="feature"
    std::vector<ComputationNodeBasePtr> m_labelNodes;      // tag="label"
    std::vector<ComputationNodeBasePtr> m_criterionNodes;  // tag="criterion"
    std::vector<ComputationNodeBasePtr> m_evaluationNodes; // tag="evaluation"
    std::vector<ComputationNodeBasePtr> m_outputNodes;     // tag="output"
...
};
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40

在Net中,我们能够获取每个节点的ComputationNodeBasePtr, 我们进而可以通过这个指针来进行其节点中数据或者参数的访问。

class ComputationNodeBase : ...
{
...
public:
    // -----------------------------------------------------------------------
    // accessors for value and gradient
    // -----------------------------------------------------------------------

    const Matrix<ElemType>& Value() const { return *m_value; }
    Matrix<ElemType>&       Value()       { return *m_value; }

    MatrixBasePtr ValuePtr() const override final { return m_value; }    // readers want this as a shared_ptr straight
    // Note: We cannot return a const& since returning m_value as a MatrixBasePtr is a type cast that generates a temporary. Interesting.

    const Matrix<ElemType>& Gradient() const { return *m_gradient; }
    Matrix<ElemType>&       Gradient()       { return *m_gradient; }

    MatrixBasePtr GradientPtr() const { return m_gradient; }
...
};
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20

这样我们就可以通过Value来获取来获取内部的计算矩阵以应用于实际的项目中。

总结

本文主要是描述了CNTK中网络节点训练结果的导出方法,首先将CNTK当做工具使用情况下的导出方法,之后接受的是将CNTK看做一个开源项目的前提下,通过研究源码的方式针对其内部数据结果进行探究。希望能够对大家使用或者学习CNTK有所帮助,如本文中有任何错误或者读者有任何意见或者建议,请在评论区给出,谢谢。

    本站是提供个人知识管理的网络存储空间,所有内容均由用户发布,不代表本站观点。请注意甄别内容中的联系方式、诱导购买等信息,谨防诈骗。如发现有害或侵权内容,请点击一键举报。
    转藏 分享 献花(0

    0条评论

    发表

    请遵守用户 评论公约

    类似文章 更多