分享

【干货】基于Keras的注意力机制实战

 zbpjlc 2018-05-31

【导读】近几年,注意力机制(Attention)大量地出现在自动翻译、信息检索等模型中。可以把Attention看成模型中的一个特征选择组件,特征选择一方面可以增强模型的效果,另一方面,我们可以通过计算出的特征的权重来计算结果与特征之间的某种关联。例如在自动翻译模型中,Attention可以计算出不同语种词之间的关系。本文一个简单的例子,来展示Attention是怎么在模型中起到特征选择作用的。


代码




导入相关库

#coding=utf-8
import numpy as np
from keras.models import *
from keras.layers import Input, Dense, merge
import matplotlib.pyplot as plt
import pandas as pd


数据生成函数

# 输入维度
input_dim = 32


# 生成数据,数据的的第attention_column个特征由label决定,
# 即
label只与数据的第attention_column个特征相关
def get_data(n, input_dim, attention_column=1):
   x = np.random.standard_normal(size=(n, input_dim))
   y = np.random.randint(low=0, high=2, size=(n, 1))
   x[:, attention_column] = y[:, 0]
   return x, y


模型定义函数

将输入进行一次变换后,计算出Attention权重,将输入乘上Attention权重,获得新的特征。


# Attention模型
def build_model():
   inputs = Input(shape=(input_dim,))

   # 计算Attention权重
   
attention_probs = Dense(input_dim, activation='softmax',
name='attention_vec')(inputs)
   # 根据Attention权重更新特征
   
attention_mul = merge([inputs, attention_probs],
output_shape=32,
name='attention_mul', mode='mul')

   # 预测标签
   
attention_mul = Dense(64)(attention_mul)
   output = Dense(1, activation='sigmoid')(attention_mul)
   model = Model(input=[inputs], output=output)
   attention_vec_model = Model(input=[inputs],
output=attention_probs)
   return model, attention_vec_model


主函数

if __name__ == '__main__':
   # 生成训练数据
   
N = 10000
   
inputs_1, outputs = get_data(N, input_dim)

   # 获取模型,以及用于计算Attention权重的子模型
   
m, attention_vec_model = build_model()
   m.compile(optimizer='adam', loss='binary_crossentropy',
metrics=['accuracy'])
   print(m.summary())

   # 训练
   
m.fit([inputs_1], outputs, epochs=20, batch_size=64,
validation_split=0.5)

   # 生成测试数据
   
testing_inputs_1, testing_outputs = get_data(1, input_dim)

   # 根据测试数据计算Attention权重
   
attention_vector = attention_vec_model.
   predict([testing_inputs_1])[0].flatten()
   print('attention =', attention_vector)

   # 绘图
pd.DataFrame(attention_vector, columns=['attention (%)'])
.plot(kind='bar', title='Attention Mechanism as a function of
input dimensions.'
)
   plt.show()


运行结果

代码中,attention_column为1,也就是说,label只与数据的第1个特征相关。从运行结果中可以看出,Attention权重成功地获取了这个信息。


参考链接

https://github.com/philipperemy/keras-attention-mechanism


更多教程资料请访问:人工智能知识资料全集

    本站是提供个人知识管理的网络存储空间,所有内容均由用户发布,不代表本站观点。请注意甄别内容中的联系方式、诱导购买等信息,谨防诈骗。如发现有害或侵权内容,请点击一键举报。
    转藏 分享 献花(0

    0条评论

    发表

    请遵守用户 评论公约

    类似文章 更多