分享

CNTK从入门到深入研究(9)

 imelee 2016-12-29

前言

在上一篇文章中,已经说明过CNTK的工程中涉及的代码实现的共有三部分(CNTK Core、Extensibility扩展性、 Reader Plugins),并且已经针对第一部分的CNTK Core 做了一些介绍,本篇文章将针对剩余的两部分Extensibility和Reader Plugins的工程结构进行说明。

Extensibility扩展性

这里写图片描述

Extensibility在这里解读为扩展性,CNTK作为C++工程,在目前主流通过Python、.Net等更高级语言平台下的数据分析略显得有些不方便。毕竟目前做数据的主流还在R语言以及Python上。Extensibility这部分主要其实是实现了一个Wrapper的库将C++的EvalDll通过C++\CLI封装成.Net的Assembly。之后CSEvalClient其实是一个通过C#集成CNTK的一个例子。

Extensibility目前只是做了一些简单封装。封装的本质是将CNTK Core中的EvalDLL做一些包装使其将IEvaluateModel<ElemType>接口暴露出来供第三方调用而已。这里官方给出了针对.Net的封装,也是抛砖引玉的过程,证明其可扩展。需要调用或者封装的接口如下:

// IEvaluateModel - interface used by decoders and other components that need just evaluator functionality in DLL form
template <class ElemType>
class IEvaluateModel // Evaluate Model Interface
{
public:
    virtual void Init(const std::string& config) = 0;
    virtual void Destroy() = 0;

    virtual void CreateNetwork(const std::string& networkDescription) = 0;
    virtual void GetNodeDimensions(std::map<std::wstring, size_t>& dimensions, NodeGroup nodeGroup) = 0;
    virtual void StartEvaluateMinibatchLoop(const std::wstring& outputNodeName) = 0;
    virtual void Evaluate(std::map<std::wstring, std::vector<ElemType>*>& inputs, std::map<std::wstring, std::vector<ElemType>*>& outputs) = 0;
    virtual void Evaluate(std::map<std::wstring, std::vector<ElemType>*>& outputs) = 0;
    virtual void ResetState() = 0;
};
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15

通过调用GetEvalF或者是GetEvalD即可获得上述描述接口的一个实例,然后通过配置文件的内容加载网络模型后,调用Evaluate方法即可。下面的内容指出获得IEvaluateModel实例的接口:

// GetEval - get a evaluator type from the DLL
// since we have 2 evaluator types based on template parameters, exposes 2 exports
// could be done directly with the templated name, but that requires mangled C++ names
template <class ElemType>
void EVAL_API GetEval(IEvaluateModel<ElemType>** peval);
extern "C" EVAL_API void GetEvalF(IEvaluateModel<float>** peval);
extern "C" EVAL_API void GetEvalD(IEvaluateModel<double>** peval);
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7

Reader Plugins

这里写图片描述

通过上图可以看出,CNTK提供了很多Reader用于针对不同的数据源来源。毕竟数据格式不同,搜需要的Reader也不同。

CNTK中的所有Reader都是以插件的形式提供的,所谓插件形式就是实现某一特定的接口的动态库供CNTK动态加载并调用相关接口。

之前在介绍CNTK Core的时候,有说过在CNTK Core中的CNTK主工程中,包括了DataReader.h,这个文件中描述了具体的Reader的接口。

// Data Reader interface
// implemented by DataReader and underlying classes
class DATAREADER_API IDataReader
{
public:
    typedef std::string  LabelType;     // surface form of an input token
    typedef unsigned int LabelIdType;   // input token mapped to an integer  --TODO: why not size_t? Does this save space?

    // BUGBUG: We should not have data members in an interace!
    unsigned m_seed;
    size_t mRequestedNumParallelSequences; // number of desired parallel sequences in each minibatch

    virtual void Init(const ConfigParameters& config) = 0;
    virtual void Init(const ScriptableObjects::IConfigRecord& config) = 0;
    virtual void Destroy() = 0;
protected:
    virtual ~IDataReader() { }
public:
    virtual void StartMinibatchLoop(size_t mbSize, size_t epoch, size_t requestedEpochSamples = requestDataSize) = 0;

    virtual bool SupportsDistributedMBRead() const
    {
        return false;
    };
    virtual void StartDistributedMinibatchLoop(size_t mbSize, size_t epoch, size_t subsetNum, size_t numSubsets, size_t requestedEpochSamples = requestDataSize)
    {
        if (SupportsDistributedMBRead() || (numSubsets != 1) || (subsetNum != 0))
        {
            LogicError("This reader does not support distributed reading of mini-batches");
        }

        return StartMinibatchLoop(mbSize, epoch, requestedEpochSamples);
    }

    virtual bool GetMinibatch(StreamMinibatchInputs& matrices) = 0;
    virtual bool GetMinibatch4SE(std::vector<shared_ptr<const msra::dbn::latticepair>>& /*latticeinput*/, vector<size_t>& /*uids*/, vector<size_t>& /*boundaries*/, vector<size_t>& /*extrauttmap*/)
    {
        NOT_IMPLEMENTED;
    };
    virtual bool GetHmmData(msra::asr::simplesenonehmm* /*hmm*/)
    {
        NOT_IMPLEMENTED;
    };
    virtual size_t GetNumParallelSequences() = 0;
    //virtual int GetSentenceEndIdFromOutputLabel() { return -1; }
    virtual void SetNumParallelSequences(const size_t sz)
    {
        mRequestedNumParallelSequences = sz;
    }
    //virtual bool RequireSentenceSeg() const { return false; }
    virtual const std::map<LabelIdType, LabelType>& GetLabelMapping(const std::wstring&)
    {
        NOT_IMPLEMENTED;
    }
    virtual void SetLabelMapping(const std::wstring&, const std::map<LabelIdType, LabelType>&)
    {
        NOT_IMPLEMENTED;
    }
    virtual bool GetData(const std::wstring&, size_t, void*, size_t&, size_t)
    {
        NOT_IMPLEMENTED;
    }
    virtual bool DataEnd()
    {
        NOT_IMPLEMENTED;
    }
    virtual void CopyMBLayoutTo(MBLayoutPtr)
    {
        NOT_IMPLEMENTED;
    }
    virtual void SetRandomSeed(unsigned seed = 0)
    {
        m_seed = seed;
    }
    virtual bool GetProposalObs(StreamMinibatchInputs*, const size_t, vector<size_t>&)
    {
        return false;
    }
    virtual void InitProposals(StreamMinibatchInputs*)
    {
    }
    virtual bool CanReadFor(wstring /* nodeName */) // return true if this reader can output for a node with name nodeName  --TODO: const wstring&
    {
        return false;
    }

    bool GetFrame(StreamMinibatchInputs& /*matrices*/, const size_t /*tidx*/, vector<size_t>& /*history*/)
    {
        NOT_IMPLEMENTED;
    }

    // Workaround for the two-forward-pass sequence and ctc training, which
    // allows processing more utterances at the same time. Only used in
    // Kaldi2Reader.
    // TODO: move this out of the reader.
    virtual bool GetMinibatchCopy(
        std::vector<std::vector<std::pair<wstring, size_t>>>& /*uttInfo*/,
        StreamMinibatchInputs& /*matrices*/,
        MBLayoutPtr /*data copied here*/)
    {
        return false;
    }

    // Workaround for the two-forward-pass sequence and ctc training, which
    // allows processing more utterances at the same time. Only used in
    // Kaldi2Reader.
    // TODO: move this out of the reader.
    virtual bool SetNetOutput(
        const std::vector<std::vector<std::pair<wstring, size_t>>>& /*uttInfo*/,
        const MatrixBase& /*outputs*/,
        const MBLayoutPtr)
    {
        return false;
    }
};
typedef std::shared_ptr<IDataReader> IDataReaderPtr;

// GetReaderX() - get a reader type from the DLL
// The F version gets the 'float' version, and D gets 'double'.
extern "C" DATAREADER_API void GetReaderF(IDataReader** preader);
extern "C" DATAREADER_API void GetReaderD(IDataReader** preader);
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
  • 48
  • 49
  • 50
  • 51
  • 52
  • 53
  • 54
  • 55
  • 56
  • 57
  • 58
  • 59
  • 60
  • 61
  • 62
  • 63
  • 64
  • 65
  • 66
  • 67
  • 68
  • 69
  • 70
  • 71
  • 72
  • 73
  • 74
  • 75
  • 76
  • 77
  • 78
  • 79
  • 80
  • 81
  • 82
  • 83
  • 84
  • 85
  • 86
  • 87
  • 88
  • 89
  • 90
  • 91
  • 92
  • 93
  • 94
  • 95
  • 96
  • 97
  • 98
  • 99
  • 100
  • 101
  • 102
  • 103
  • 104
  • 105
  • 106
  • 107
  • 108
  • 109
  • 110
  • 111
  • 112
  • 113
  • 114
  • 115
  • 116
  • 117
  • 118
  • 119
  • 120
  • 121
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
  • 48
  • 49
  • 50
  • 51
  • 52
  • 53
  • 54
  • 55
  • 56
  • 57
  • 58
  • 59
  • 60
  • 61
  • 62
  • 63
  • 64
  • 65
  • 66
  • 67
  • 68
  • 69
  • 70
  • 71
  • 72
  • 73
  • 74
  • 75
  • 76
  • 77
  • 78
  • 79
  • 80
  • 81
  • 82
  • 83
  • 84
  • 85
  • 86
  • 87
  • 88
  • 89
  • 90
  • 91
  • 92
  • 93
  • 94
  • 95
  • 96
  • 97
  • 98
  • 99
  • 100
  • 101
  • 102
  • 103
  • 104
  • 105
  • 106
  • 107
  • 108
  • 109
  • 110
  • 111
  • 112
  • 113
  • 114
  • 115
  • 116
  • 117
  • 118
  • 119
  • 120
  • 121

笔者也发现,该接口中有很多方法与通常概念下的Reader显得有些唐突,例如GetHmmData这个,实际跟了下代码发现,接口用的很多方法其实都应该是衍生类的一些方法,但是也被定义在了这里。所以对于在实现自己的Reader时,可以只针对于自己所关心的进行实现。

这里写图片描述

CNTK其实也提供了一些方便的库用来专门的去实现Reader的扩展。
大家可以在工程目录中发现,有一个专门的工程为ReaderLib。该工程中针对IDataReader进行了一些封装,在ReaderShim中实现了对IDataReader的接口的继承,同时提供了一个更为简单直观的接口类Reader class。

//////////////////////////////////////////////////////////////////////////////////////////////////
// Main Reader interface. The border interface between the CNTK and reader libraries.
// TODO: Expect to change in a little bit: stream matrices provided by the network as input.
//////////////////////////////////////////////////////////////////////////////////////////////////
class Reader
{
public:
    // Describes the streams this reader produces.
    virtual std::vector<StreamDescriptionPtr> GetStreamDescriptions() = 0;

    // Starts a new epoch with the provided configuration
    virtual void StartEpoch(const EpochConfiguration& config) = 0;

    // Reads a minibatch that contains data across all streams.
    virtual Minibatch ReadMinibatch() = 0;

    virtual ~Reader() {};
};

typedef std::shared_ptr<Reader> ReaderPtr;
}}}
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21

上述就是class Reader接口,可以看出其比IDataReader更加简单干练。当我们需要自行实现Reader时,可以继承并实现class Reader,然后其作为ReaderShim的一个成员变量的方式被调用(用组合代替继承)。
ReaderLib中还提供了很多其他的工具方法,例如随机读取的的发生器、数据变换等。

总结

CNTK工程接口部分分为了两篇文章进行讲解,第一篇文章中讲解了CNTKCore的相关的工程的作用以及他们之间的关系,而第二篇文章中,针对CNTK目前所支持的扩展性的工程进行了简单的说明。

CNTK作为一个开放的工具箱,扩展性是十分重要的,CNTK不仅仅是提供了针对第三方调用以及Reader的扩展,还提供了训练方法的扩展以及网络节点以及模型的扩展。本文作为抛砖引玉,希望能够为读者在解读CNTK工程的过程提供一些便利,当然本文中如有错误,请及时指正。

    本站是提供个人知识管理的网络存储空间,所有内容均由用户发布,不代表本站观点。请注意甄别内容中的联系方式、诱导购买等信息,谨防诈骗。如发现有害或侵权内容,请点击一键举报。
    转藏 分享 献花(0

    0条评论

    发表

    请遵守用户 评论公约

    类似文章 更多