分享

Keras

 行走在理想边缘 2022-05-23 发布于四川

调用.fit:

model.fit(trainX, trainY, batch_size=32, epochs=50)

一般来讲,model.fit调用起来比较简单。首先需要提供训练数据集和数据集的标签。然后需要提供训练批次的大小和迭代次数。
对.fit的调用在这里做出两个主要假设:

  • 我们的整个训练集可以放入RAM
  • 没有数据增强(即不需要Keras生成器)

相反,我们的网络模型将在原始数据上训练。
原始数据本身将适合内存,我们无需将旧批量数据从RAM中移出并将新批量数据移入RAM。
此外,我们不会使用数据增强动态操纵训练数据。

Keras fit_generator函数

在这里插入图片描述

对于小型,简单化的数据集,使用Keras的.fit函数是完全可以接受的。
这些数据集通常不是很具有挑战性,不需要任何数据增强。
但是,真实世界的数据集很少这么简单:
真实世界的数据集通常太大而无法放入内存中
它们也往往具有挑战性,要求我们执行数据增强以避免过拟合并增加我们的模型的泛化能力
在这些情况下,我们需要利用Keras的.fit_generator函数:

# initialize the number of epochs and batch size
EPOCHS = 100
BS = 32

# construct the training image generator for data augmentation
aug = ImageDataGenerator(rotation_range=20, zoom_range=0.15,
	width_shift_range=0.2, height_shift_range=0.2, shear_range=0.15,
	horizontal_flip=True, fill_mode="nearest")

# train the network
H = model.fit_generator(aug.flow(trainX, trainY, batch_size=BS),
	validation_data=(testX, testY), steps_per_epoch=len(trainX) // BS,
	epochs=EPOCHS)

我们首先初始化将要训练的网络的epoch和batch size。

然后我们初始化aug,这是一个Keras ImageDataGenerator对象,用于图像的数据增强,随机平移,旋转,调整大小等。

执行数据增强是正则化的一种形式,使我们的模型能够更好的被泛化。

但是,应用数据增强意味着我们的训练数据不再是“静态的” ——数据不断变化。

根据提供给ImageDataGenerator的参数随机调整每批新数据。

因此,我们现在需要利用Keras的.fit_generator函数来训练我们的模型。

顾名思义,.fit_generator函数假定存在一个为其生成数据的基础函数。

该函数本身是一个Python生成器。

Keras在使用.fit_generator训练模型时的过程:

  • Keras调用提供给.fit_generator的生成器函数(在本例中为aug.flow)
  • 生成器函数为.fit_generator函数生成一批大小为BS的数据
  • .fit_generator函数接受批量数据,执行反向传播,并更新模型中的权重
  • 重复该过程直到达到期望的epoch数量

您会注意到我们现在需要在调用.fit_generator时提供steps_per_epoch参数(.fit方法没有这样的参数)。

为什么我们需要steps_per_epoch?

请记住,Keras数据生成器意味着无限循环,它永远不会返回或退出。

由于该函数旨在无限循环,因此Keras无法确定一个epoch何时开始的,并且新的epoch何时开始。
因此,我们将训练数据的总数除以批量大小的结果作为steps_per_epoch的值。一旦Keras到达这一步,它就会知道这是一个新的epoch。

    本站是提供个人知识管理的网络存储空间,所有内容均由用户发布,不代表本站观点。请注意甄别内容中的联系方式、诱导购买等信息,谨防诈骗。如发现有害或侵权内容,请点击一键举报。
    转藏 分享 献花(0

    0条评论

    发表

    请遵守用户 评论公约

    类似文章 更多