分享

详解TensorFlow的新seq2seq模块及其用法

 zbpjlc 2018-05-27

基于注意力机制的seq2seq模型现在在自然语言处理、语音识别中的应用已经越来越广泛了,如何更灵活地掌握seq2seq的使用非常重要,本期我将带领大家阅读TensorFlow的seq2seq源码,并手动写一个完整的基于TensorFlow的seq2seq模型(代码见文末),下一期将重点介绍如何基于TensorFlow构建基于注意力机制的seq2seq模型。



TensorFlow在 1.1 版本中就已经开始对seq2seq模块进行了完整构建,为了避免源码中可能出现的较大变动,我现在以 TF1.3 版本的源码为例来带领大家熟悉各个代码的结构以及其功能。在tensorflow/contrib/seq2seq/python/ops/目录下面有六个重要的代码文件,它们的文件名如下所示,除了attention_wrapper.py下一期介绍以外,接下来我将根据seq2seq模型的结构顺序来介绍这些文件的作用与设计细节。

attention_wrapper.py

basic_decoder.py

beam_search_decoder.py

decoder.py

helper.py

loss.py


1. 编码器

有读者会注意到,上面列举的七个源码文件没有编码器的代码,这是因为编码器是一个比较基础的网络结构,普通的BasicRNNCell或者BasicLSTMCell就可以作为编码器,需要注意的是编码器的输出应该具有两个成分:

每一时刻的输出

最后时刻的隐含状态


如果我们用python中的namedtuple数据结构来表示编码器的输出,其可以表示如下:

EncoderOutput = namedtuple('encoder_output',  'outputs_final_state')


虽然编码器大多是由RNN构成的,但编码器不一定只有RNN才可以充当,只要我们使得编码器的输出满足上面的形式,任何形式的模型都可以作为seq2seq模型的编码器,例如之前推送的完全基于卷积神经网络的seq2seq就是使用卷积神经网络来充当编码器。


2. 解码器

从上面的文件目录我们可以发现,有两个文件都是关于解码器的,分别是decoder.py和basic_decoder.py,个人感觉decoder.py的代码不是十分清晰,该代码中主要包含一个Decoder抽象类和一个dynamic_decode函数,定义Decoder抽象类是为了更好地实现具体类,很多大型面向对象的项目中都会定义抽象类或接口,Decoder抽象类提供了batch_size、output_size、output_dtype、initialize、step等未实现的抽象函数,这些函数都是一个具体Decoder类必须要实现的函数,这里面initialize函数的功能是提供每一步解码的输入、初始状态、是否完成解码,即(finished, first_inputs, initial_state);step函数的功能是执行解码操作,提供输入和状态就能通过解码得到下一时刻的输入以及状态。


除了Decoder抽象类以外,decoder.py中还写了一个dynamic_decode函数,这个函数放在decoder.py代码中显得有点突兀。如果说Decoder抽象类中的initializer函数和step函数是单步操作的话,那么dynamic_decode函数就是用来将各个时刻串联起来完整地实现解码过程,返回值包括最后时刻的输出final_outputs,最后时刻的状态final_state,最后序列的长度final_sequence_length。


另外一个文件basic_decoder.py只包含一个继承Decoder类的BasicDecoder类,Decoder中未实现的函数在BasicDecoder中都有具体实现,因此,我们在搭建seq2seq模型时只需要调用BasicDecoder类即可。


3. helper

不太明白为什么这个文件取名为helper.py,感觉有点奇怪,因为仅仅看文件名根据猜不出来这个文件的作用,实际上,helper.py的作用就是提供在解码过程中的抽样方法,例如训练过程中解码器的输出是采用argmax算法还是广义伯努利分布算法,推断过程中输出是采用argmax来获取输出id还是采用广义伯努利分布采样来得到输出id,除此之外,用户也可以自定义其他类型的解码方法。


在helper.py代码中,主要有以下几个类:

'Helper'

'TrainingHelper'

'GreedyEmbeddingHelper'

'SampleEmbeddingHelper'

'CustomHelper'

'ScheduledEmbeddingTrainingHelper'

'ScheduledOutputTrainingHelper'

'InferenceHelper'


看起来似乎很复杂,其实我们可以将它们分为三类,Helper类是抽象基础类,其中定义了几个抽象方法,如initialize、sample、next_inputs,接下来所有的具体类都是继承自Helper抽象类。


CustomHelper类虽然是继承自Helper类的一个具体类,但是这个类没有外加太多约束,它需要用户自定义initialize_fn, sample_fn, next_inputs_fn这三个函数,而InferenceHelper类我们可以看成是CustomHelper类的一个特殊情况,由于这个类只在推断的时候使用,因此在next_inputs函数中只需要将前一时刻的抽样结果作为下一时刻的输入即可;GreedyEmbeddingHelper类也是用于推理过程,不过它是采取argmax抽样算法来得到输出id,并且经过embedding层作为下一时刻的输入;而SampleEmbeddingHelper是继承自GreedyEmbeddingHelper类的一个类,与GreedyEmbeddingHelper类不同的是,SampleEmbeddingLayer是通过抽样算法来得到解码器的输出。


TrainingHelper类也是继承自Helper类的一个具体类,在sample过程中,它采用的是最简单的argmax算法;而ScheduledEmbeddingTrainingHelper类是继承自TrainingHelper类,其中的sample算法采取的是广义伯努利算法,并且并不是每一个时刻都会采样,同时这里添加了embedding操作,即根据解码器的输出id从embedding矩阵中查找其对应的embedding向量;ScheduledOutputTrainingHelper类同样也是继承自TrainingHelper类,没有embedding操作,直接对输出进行抽样。


4. beam search decoder

除了上面提到的argmax算法和伯努利抽样算法以外,我们还可以使用Beam Search的抽样方法来获得最终的解码序列,在beam_search_decoder.py文件中,BeamSearchDecoder类是继承自Decoder类的,与之前的BasicDecoder类不同的是,BasicDecoder类需要设定helper参数,而这里的BeamSearchDecoder没有helper参数,因为它所采用的算法是Beam Search,其细节在该文件中有实现。


5. 基于TensorFlow的seq2seq代码示例

下图中是根据上面提到的API来搭建一个简单的seq2seq模型,流程是定义编码器、定义helper、定义解码器、动态解码。其中inputX是指输入的tensor变量,类型默认为tf.int32;target是指对应的输出标签,类型默认为tf.int32;seq_len是指每个batch中的target的长度;start_token是指在推理时使用的起始字符id;end_token是指在推理时使用的结束字符id;inference为布尔变量,设定是推断模式还是训练模式,因为对于这两种模式,我们所定义的helper是不同的。



    本站是提供个人知识管理的网络存储空间,所有内容均由用户发布,不代表本站观点。请注意甄别内容中的联系方式、诱导购买等信息,谨防诈骗。如发现有害或侵权内容,请点击一键举报。
    转藏 分享 献花(0

    0条评论

    发表

    请遵守用户 评论公约

    类似文章 更多