编写自己的层
对于简单的定制操作,我们或许可以通过使用layers.core.Lambda 层来完成。但对于任何具有可训练权重的定制层,你应该自己来实现。
这里是一个Keras层应该具有的框架结构,要定制自己的层,你需要实现下面三个方法
-
build(input_shape) :这是定义权重的方法,可训练的权应该在这里被加入列表`self.trainable_weights 中。其他的属性还包括self.non_trainabe_weights (列表)和self.updates (需要更新的形如(tensor, new_tensor)的tuple的列表)。你可以参考BatchNormalization 层的实现来学习如何使用上面两个属性。
-
call(x) :这是定义层功能的方法,除非你希望你写的层支持masking,否则你只需要关心call 的第一个参数:输入张量
-
get_output_shape_for(input_shape) :如果你的层修改了输入数据的shape,你应该在这里指定shape变化的方法,这个函数使得Keras可以做自动shape推断
from keras import backend as K
from keras.engine.topology import Layer
class MyLayer(Layer):
def __init__(self, output_dim, **kwargs):
self.output_dim = output_dim
super(MyLayer, self).__init__(**kwargs)
def build(self, input_shape):
input_dim = input_shape[1]
initial_weight_value = np.random.random((input_dim, output_dim))
self.W = K.variable(initial_weight_value)
self.trainable_weights = [self.W]
def call(self, x, mask=None):
return K.dot(x, self.W)
def get_output_shape_for(self, input_shape):
return (input_shape[0] + self.output_dim)
调整旧版Keras编写的层以适应Keras1.0
以下内容是你在将旧版Keras实现的层调整为新版Keras应注意的内容,这些内容对你在Keras1.0中编写自己的层也有所帮助。
-
你的Layer应该继承自keras.engine.topology.Layer ,而不是之前的keras.layers.core.Layer 。另外,MaskedLayer 已经被移除。
-
build 方法现在接受input_shape 参数,而不是像以前一样通过self.input_shape 来获得该值,所以请把build(self) 转为build(self, input_shape)
-
请正确将output_shape 属性转换为方法get_output_shape_for(self, train=False) ,并删去原来的output_shape
-
新层的计算逻辑现在应实现在call 方法中,而不是之前的get_output 。注意不要改动__call__ 方法。将get_output(self,train=False) 转换为call(self,x,mask=None) 后请删除原来的get_output 方法。
-
Keras1.0不再使用布尔值train 来控制训练状态和测试状态,如果你的层在测试和训练两种情形下表现不同,请在call 中使用指定状态的函数。如,x=K.in_train_phase(train_x, test_y) 。例如,在Dropout的call 方法中你可以看到:
return K.in_train_phase(K.dropout(x, level=self.p), x)
-
get_config 返回的配置信息可能会包括类名,请从该函数中将其去掉。如果你的层在实例化时需要更多信息(即使将config 作为kwargs传入也不能提供足够信息),请重新实现from_config 。请参考Lambda 或Merge 层看看复杂的from_config 是如何实现的。
-
如果你在使用Masking,请实现compute_mas(input_tensor, input_mask) ,该函数将返回output_mask 。请确保在__init__() 中设置self.supports_masking = True
-
如果你希望Keras在你编写的层与Keras内置层相连时进行输入兼容性检查,请在__init__ 设置self.input_specs 或实现input_specs() 并包装为属性(@property)。该属性应为engine.InputSpec 的对象列表。在你希望在call 中获取输入shape时,该属性也比较有用。
-
下面的方法和属性是内置的,请不要覆盖它们
现存的Keras层代码可以为你的实现提供良好参考,阅读源代码吧!
|