Keras mult_gup_model 报错 cannot import name 'multi_gpu_model’ from 'keras.utils’
服务器上有多个GPU,想要使用多GPU训练,但是在调用多GPU模型时出现了报错,使用的命令和报错如下:
from keras.utils import multi_gpu_model
2022-03-22 14:42:15.655742: I tensorflow/stream_executor/platform/default/dso_loader.cc:49] Successfully opened dynamic library libcudart.so.11.0
Traceback (most recent call last):File "<stdin>", line 1, in <module>
ImportError: cannot import name 'multi_gpu_model' from 'keras.utils' (/opt/anaconda3/lib/python3.7/site-packages/keras/utils/__init__.py)
报错的大概意思是,在 keras.utils 里面没有multi_gpu_model这个函数
在查阅相关的网站我发现了问题所在,使用以下语句可以解决报错
from tensorflow.python.keras.utils.multi_gpu_utils import multi_gpu_model
或者
from keras.utils.multi_gpu_utils import multi_gpu_model
问题是在 keras.utils 下面没有 multi_gpu_model 这个函数,这个函数的实际位置是在 multi_gpu_utils 下面。所以在调用的时候需要在这下面调用。
然后使用以下代码可以调用多GPU训练
import tensorflow as tf
from keras.applications import Xception
from keras.utils import multi_gpu_model
import numpy as np
num_samples = 1000
height = 224
width = 224
num_classes = 1000
# 实例化基础模型(或者「模版」模型)。
# 我们推荐在 CPU 设备范围内做此操作,
# 这样模型的权重就会存储在 CPU 内存中。
# 否则它们会存储在 GPU 上,而完全被共享。
with tf.device('/cpu:0'):
model = Xception(weights=None,
input_shape=(height, width, 3),
classes=num_classes)
# 复制模型到 8 个 GPU 上。
# 这假设你的机器有 8 个可用 GPU。
parallel_model = multi_gpu_model(model, gpus=8)
parallel_model.compile(loss='categorical_crossentropy',
optimizer='rmsprop')
# 生成虚拟数据
x = np.random.random((num_samples, height, width, 3))
y = np.random.random((num_samples, num_classes))
# 这个 `fit` 调用将分布在 8 个 GPU 上。
# 由于 batch size 是 256, 每个 GPU 将处理 32 个样本。
parallel_model.fit(x, y, epochs=20, batch_size=256)
# 通过模版模型存储模型(共享相同权重):
model.save('my_model.h5')
这里还需要注意的是,batchsize 需要设置为偶数不然会出现以下报错
Non-OK-status: GpuLaunchKernel( SwapDimension1And2InTensor3UsingTiles<T, kNumThreads, kTileSize, kTileSize, conjugate>, total_tiles_count, kNumThreads, 0, d.stream(), input, input_dims, output) status: Internal: invalid configuration argument
参考网址: https://github.com/keras-team/keras/issues/14440 https://github.com/tensorflow/tensorflow/issues/36310 https:///zh/utils/
|