分享

Keras 多GPU mult...

 行走在理想边缘 2022-06-24 发布于四川

Keras mult_gup_model 报错 cannot import name 'multi_gpu_model’ from 'keras.utils’

服务器上有多个GPU,想要使用多GPU训练,但是在调用多GPU模型时出现了报错,使用的命令和报错如下:

from keras.utils import multi_gpu_model
2022-03-22 14:42:15.655742: I tensorflow/stream_executor/platform/default/dso_loader.cc:49] Successfully opened dynamic library libcudart.so.11.0
Traceback (most recent call last):File "<stdin>", line 1, in <module>
ImportError: cannot import name 'multi_gpu_model' from 'keras.utils' (/opt/anaconda3/lib/python3.7/site-packages/keras/utils/__init__.py)

报错的大概意思是,在 keras.utils 里面没有multi_gpu_model这个函数

在查阅相关的网站我发现了问题所在,使用以下语句可以解决报错

from tensorflow.python.keras.utils.multi_gpu_utils import multi_gpu_model

或者

from keras.utils.multi_gpu_utils import multi_gpu_model

问题是在 keras.utils 下面没有 multi_gpu_model 这个函数,这个函数的实际位置是在 multi_gpu_utils 下面。所以在调用的时候需要在这下面调用。

然后使用以下代码可以调用多GPU训练

import tensorflow as tf
from keras.applications import Xception
from keras.utils import multi_gpu_model
import numpy as np

num_samples = 1000
height = 224
width = 224
num_classes = 1000

# 实例化基础模型(或者「模版」模型)。
# 我们推荐在 CPU 设备范围内做此操作,
# 这样模型的权重就会存储在 CPU 内存中。
# 否则它们会存储在 GPU 上,而完全被共享。
with tf.device('/cpu:0'):
    model = Xception(weights=None,
                     input_shape=(height, width, 3),
                     classes=num_classes)

# 复制模型到 8 个 GPU 上。
# 这假设你的机器有 8 个可用 GPU。
parallel_model = multi_gpu_model(model, gpus=8)
parallel_model.compile(loss='categorical_crossentropy',
                       optimizer='rmsprop')

# 生成虚拟数据
x = np.random.random((num_samples, height, width, 3))
y = np.random.random((num_samples, num_classes))

# 这个 `fit` 调用将分布在 8 个 GPU 上。
# 由于 batch size 是 256, 每个 GPU 将处理 32 个样本。
parallel_model.fit(x, y, epochs=20, batch_size=256)

# 通过模版模型存储模型(共享相同权重):
model.save('my_model.h5')

这里还需要注意的是,batchsize 需要设置为偶数不然会出现以下报错

Non-OK-status: GpuLaunchKernel( SwapDimension1And2InTensor3UsingTiles<T, kNumThreads, kTileSize, kTileSize, conjugate>, total_tiles_count, kNumThreads, 0, d.stream(), input, input_dims, output) status: Internal: invalid configuration argument

参考网址:
https://github.com/keras-team/keras/issues/14440
https://github.com/tensorflow/tensorflow/issues/36310
https:///zh/utils/

    本站是提供个人知识管理的网络存储空间,所有内容均由用户发布,不代表本站观点。请注意甄别内容中的联系方式、诱导购买等信息,谨防诈骗。如发现有害或侵权内容,请点击一键举报。
    转藏 分享 献花(0

    0条评论

    发表

    请遵守用户 评论公约

    类似文章 更多