调用.fit:
model.fit(trainX, trainY, batch_size=32, epochs=50)
一般来讲,model.fit调用起来比较简单。首先需要提供训练数据集和数据集的标签。然后需要提供训练批次的大小和迭代次数。 对.fit的调用在这里做出两个主要假设:
- 我们的整个训练集可以放入RAM
- 没有数据增强(即不需要Keras生成器)
相反,我们的网络模型将在原始数据上训练。 原始数据本身将适合内存,我们无需将旧批量数据从RAM中移出并将新批量数据移入RAM。 此外,我们不会使用数据增强动态操纵训练数据。
Keras fit_generator函数

对于小型,简单化的数据集,使用Keras的.fit函数是完全可以接受的。 这些数据集通常不是很具有挑战性,不需要任何数据增强。 但是,真实世界的数据集很少这么简单: 真实世界的数据集通常太大而无法放入内存中 它们也往往具有挑战性,要求我们执行数据增强以避免过拟合并增加我们的模型的泛化能力 在这些情况下,我们需要利用Keras的.fit_generator函数:
# initialize the number of epochs and batch size
EPOCHS = 100
BS = 32
# construct the training image generator for data augmentation
aug = ImageDataGenerator(rotation_range=20, zoom_range=0.15,
width_shift_range=0.2, height_shift_range=0.2, shear_range=0.15,
horizontal_flip=True, fill_mode="nearest")
# train the network
H = model.fit_generator(aug.flow(trainX, trainY, batch_size=BS),
validation_data=(testX, testY), steps_per_epoch=len(trainX) // BS,
epochs=EPOCHS)
我们首先初始化将要训练的网络的epoch和batch size。
然后我们初始化aug,这是一个Keras ImageDataGenerator对象,用于图像的数据增强,随机平移,旋转,调整大小等。
执行数据增强是正则化的一种形式,使我们的模型能够更好的被泛化。
但是,应用数据增强意味着我们的训练数据不再是“静态的” ——数据不断变化。
根据提供给ImageDataGenerator的参数随机调整每批新数据。
因此,我们现在需要利用Keras的.fit_generator函数来训练我们的模型。
顾名思义,.fit_generator函数假定存在一个为其生成数据的基础函数。
该函数本身是一个Python生成器。
Keras在使用.fit_generator训练模型时的过程:
- Keras调用提供给.fit_generator的生成器函数(在本例中为aug.flow)
- 生成器函数为.fit_generator函数生成一批大小为BS的数据
- .fit_generator函数接受批量数据,执行反向传播,并更新模型中的权重
- 重复该过程直到达到期望的epoch数量
您会注意到我们现在需要在调用.fit_generator时提供steps_per_epoch参数(.fit方法没有这样的参数)。
为什么我们需要steps_per_epoch?
请记住,Keras数据生成器意味着无限循环,它永远不会返回或退出。
由于该函数旨在无限循环,因此Keras无法确定一个epoch何时开始的,并且新的epoch何时开始。 因此,我们将训练数据的总数除以批量大小的结果作为steps_per_epoch的值。一旦Keras到达这一步,它就会知道这是一个新的epoch。
|