Preface
这两天概览了一下卜居(赵永科)的《深度学习 21天实战caffe》,进入深度学习挺长时间的了。文章也看了不少,Caffe、Theano、Torch 也都用过。其实个人认为,这本书对于已经深入这个领域已定时间的人来说,帮助不大。本书讲述的只是“术“,有点像深度学习的说明书,讲的很浅。
但是翻了一翻,还是有点收获的,这个 MNIST 手写数字识别是深度学习入门很经典的例子。基本上所有的深度学习框架,在让初学者入门使用的时候都有这个例子。
我一直对 Caffe 中使用的 LMDB、LEVELDB 数据组织比较疑惑,很多时候不明白该怎么样组织图像数据、以及其对应的标签。之前都是按照别人的代码生成的,自己其实懵懵的。所以,我想通过 MNIST 输入数据生成过程,熟悉一下 LMDB、LEVELDB 的基本使用方法。
熟悉了 C++ 版本的转 lmdb 方式后,我会解析一下 Python 版本的 lmdb 转换过程。
最后 Reference 部分,列出了我这里面参考的文章。
MNIST 及其转 LMDB 数据库源码 create_mnist_data
MINIST(Mixed National Institute of Stanfords and Technology)是一个大型的手写数字数据库,广泛用于机器学习领域的训练和测试,由纽约大学 Yann LeCun 教授整理。MNIST 包括 60000 个训练集和 10000 个测试集,每张图都已经进行尺寸归一化,数字居中处理,固定尺寸为 28×28。如下图所示:
MNIST 数据格式描述
MNIST 具体的文件格式描述如下面的表所示:
MNIST 原始数据文件
训练集图片文件格式描述(train-images-idx3-ubyte)
训练集标签文件格式描述(train-labels-idx1-ubyte)
测试集图片文件格式描述(t10k-images-idx3-ubyte)
测试集标签文件格式描述(t10k-labels-idx1-ubyte)
注意:图片文件中像素按行组织,像素值 0 表示背景(白色),像素值 255 表示前景(黑色)。
转换格式、create_mnist_data.cpp 源码解析
先说一下 Caffe 为什么采用 LMDB、LEVELDB,而不是直接读取原始数据?
原因是,一方面,数据类型多种多样,有二进制文件、文本文件、编码后的图像文件(如 JPEG、PNG、网络爬取的数据等),不可能用一套代码实现所有类型的输入数据读取,转换为统一格式可以简化数据读取层的实现;
另一方面,使用 LMDB、LEVELDB 可以提高磁盘 IO 利用率。
下载到的原始数据为二进制文件,需要转换为 LEVELDB 或 LMDB 才能被 Caffe 识别。
我们 Git 得到的 Caffe 中,在 examples/mnist/ 下有一个脚本文件:create_mnist.sh ,这个就可以将原始的二进制数据,生成 LMDB 格式数据。
运行后,会生成 examples/mnist/mnist_train_lmdb/ 和 examples/mnist/mnist_test_lmdb/ 这两个目录。每个目录下都有两个文件:data.mdb 和 lock.mdb 。
看一下脚本文件:create_mnist.sh 里面是什么:
#!/usr/bin/env sh
# This script converts the mnist data into lmdb/leveldb format,
# depending on the value assigned to $BACKEND.
EXAMPLE=examples/mnist
DATA=data/mnist
BUILD=build/examples/mnist
BACKEND="lmdb"
echo "Creating ${BACKEND}..."
rm -rf $EXAMPLE/mnist_train_${BACKEND}
rm -rf $EXAMPLE/mnist_test_${BACKEND}
$BUILD/convert_mnist_data.bin $DATA/train-images-idx3-ubyte $DATA/train-labels-idx1-ubyte $EXAMPLE/mnist_train_${BACKEND} --backend=${BACKEND}
$BUILD/convert_mnist_data.bin $DATA/t10k-images-idx3-ubyte $DATA/t10k-labels-idx1-ubyte $EXAMPLE/mnist_test_${BACKEND} --backend=${BACKEND}
echo "Done."
- 1
- 2
- 3
- 4
- 5
- 6
- 7
- 8
- 9
- 10
- 11
- 12
- 13
- 14
- 15
- 16
- 17
- 18
- 19
- 20
- 21
- 1
- 2
- 3
- 4
- 5
- 6
- 7
- 8
- 9
- 10
- 11
- 12
- 13
- 14
- 15
- 16
- 17
- 18
- 19
- 20
- 21
create_mnist_data.cpp 源码解析
可以看到,上面脚本最核心的部分,就是调用 convert_mnist_data.bin 这个可执行程序,对应的源文件为 examples/mnist/convert_mnist_data.cpp ,对这个源代码的解读如下,深入这段代码可以更清楚的了解 LMDB 是如何生成的。
// 这段代码将 MNIST 数据集转换为(默认的)lmdb 或者 leveldb(--backend=leveldb) 格式,用于在使用 caffe 的时候读取数据
// 使用方法:
// convert_mnist_data [FLAGS] input_image_file input_label_file output_db_file
// gflags: 命令行参数解析头文件
#include <gflags/gflags.h>
// glog: 记录程序日志头文件
#include <glog/logging.h>
// 解析 *.prototxt 文件
#include <google/protobuf/text_format.h>
#include <leveldb/db.h>
#include <leveldb/write_batch.h>
#include <lmdb.h>
#include <stdint.h>
#include <sys/stat.h>
#include <fstream> // NOLINT(readability/streams)
#include <string>
// 解析caffe中proto类型文件的头文件
#include "caffe/proto/caffe.pb.h"
using namespace caffe; // NOLINT(build/namespace)
using std::string;
// GFLAGS 工具定义命令行选项 backend, 默认值为 lmdb, 即: --backend=lmdb
DEFINE_string(backend, "lmdb", "The backend for storing the result");
// 大小端转换, MNIST 原始数据文件中 32 位整型值为大端存储, C/C++ 变量为小端存储,因此需要加入转换机制
uint32_t swap_endian(uint32_t val) {
val = ((val << 8) & 0xFF00FF00) | ((val >> 8) & 0xFF00FF);
return (val << 16) | (val >> 16);
}
// 转换数据集函数
void convert_dataset(const char* image_filename, const char* label_filename,
const char* db_path, const string& db_backend) {
// 用 C++ 输入文件流以二进制方式打开
// 定义, 打开图像文件 对象: image_file(读入的文件名, 读入方式), 此处以二进制的方式
std::ifstream image_file(image_filename, std::ios::in | std::ios::binary);
// 定义, 打开标签文件 对象: label_file(读入的文件名, 读入方式), 此处以二进制的方式
std::ifstream label_file(label_filename, std::ios::in | std::ios::binary);
// CHECK: 用于检测文件能否正常打开函数
CHECK(image_file) << "Unable to open file " << image_filename;
CHECK(label_file) << "Unable to open file " << label_filename;
// 读取魔数与基本信息
// uint32_t 用 typedef 来自定义的一种数据类型, unsigned int32, 每个int32整数占用4个字节, 这样做是为了程序的可扩展性
uint32_t magic; // 魔数
uint32_t num_items; // 文件包含条目总数
uint32_t num_labels; // 标签值
uint32_t rows; // 行数
uint32_t cols; // 列数
// 读取魔数: magic
// image_file.read( 读取内容的指针, 读取的字节数 ) , magic 是一个 int32 类型的整数,每个占 4 个字节,所以这里指定为 4
// reinterpret_cast 为 C++ 中定义的强制转换符, 这里把 &magic, 即 magic 的地址(一个 16 进制的数), 转变成 char 类型的指针
image_file.read(reinterpret_cast<char*>(&magic), 4);
// 大端到小端的转换
magic = swap_endian(magic);
// 校验图像文件中魔数是否为 2051, 不是则报错
CHECK_EQ(magic, 2051) << "Incorrect image file magic.";
// 同理, 校验标签文件中的魔数是否为 2049, 不是则报错
label_file.read(reinterpret_cast<char*>(&magic), 4);
magic = swap_endian(magic);
CHECK_EQ(magic, 2049) << "Incorrect label file magic.";
// 读取图片的数量: num_items
image_file.read(reinterpret_cast<char*>(&num_items), 4);
num_items = swap_endian(num_items); // 大端到小端转换
// 读取图片标签的数量
label_file.read(reinterpret_cast<char*>(&num_labels), 4);
num_labels = swap_endian(num_labels); // 大端到小端转换
// 图片数量应等于其标签数量, 检查两者是否相等
CHECK_EQ(num_items, num_labels);
// 读取图片的行大小
image_file.read(reinterpret_cast<char*>(&rows), 4);
rows = swap_endian(rows); // 大端到小端转换
// 读取图片的列大小
image_file.read(reinterpret_cast<char*>(&cols), 4);
cols = swap_endian(cols); // 大端到小端转换
// lmdb 相关句柄
MDB_env *mdb_env;
MDB_dbi mdb_dbi;
MDB_val mdb_key, mdb_data;
MDB_txn *mdb_txn;
// leveldb 相关句柄
leveldb::DB* db;
leveldb::Options options;
options.error_if_exists = true;
options.create_if_missing = true;
options.write_buffer_size = 268435456;
level::WriteBatch* batch = NULL;
// 打开 db
if (db_backend == "leveldb") { // leveldb
LOG(INFO) << "Opening leveldb " << db_path;
leveldb::Status status = leveldb::DB::Open(
options, db_path, &db);
CHECK(status.ok()) << "Failed to open leveldb " << db_path << ". Is it already existing?";
batch = new leveldb::WriteBatch();
}else if (db_backend == "lmdb") { // lmdb
LOG(INFO) << "Opening lmdb " << db_path;
CHECK_EQ(mkdir(db_path, 0744), 0) << "mkdir " << db_path << "failed";
CHECK_EQ(mdb_env_create(&mdb_env), MDB_SUCCESS) << "mdb_env_create failed";
CHECK_EQ(mdb_env_set_mapsize(mdb_env, 1099511627776), MDB_SUCCESS) << "mdb_env_set_mapsize failed"; // 1TB
CHECK_EQ(mdb_env_open(mdb_env, db_path, 0, 0664), MDB_SUCCESS) << "mdb_env_open_failed";
CHECK_EQ(mdb_txn_begin(mdb_env, NULL, 0, &mdb_txn), MDB_SUCCESS) << "mdb_txn_begin failed";
CHECK_EQ(mdb_open(mdb_txn, NULL, 0, &mdb_dbi), MDB_SUCCESS) << "mdb_open failed. Does the lmdb already exist?";
} else {
LOG(FATAL) << "Unknown db backend " << db_backend;
}
// 将读取数据保存至 db
char label;
char* pixels = new char[rows * cols];
int count = 0;
const int kMaxKeyLength = 10;
char key_cstr[kMaxKeyLength];
string value;
// 设置datum数据对象的结构,其结构和源图像结构相同
Datum datum;
datum.set_channels(1);
datum.set_height(rows);
datum.set_width(cols);
// 输出 Log, 输出图片总数
LOG(INFO) << "A total of " << num_items << " items.";
// 输出 Log, 输出图片的行、列大小
LOG(INFO) << "Rows: " << rows << " Cols: " << cols;
// 读取图片数据以及 label 存入 protobuf 定义好的数据结构中,
// 序列化成字符串储存到数据库中,
// 这里为了减少单次操作带来的带宽成本(验证数据包完整等),
// 每 1000 次执行一次操作
for (int item_id = 0; item_id < num_items; ++item_id) {
// 从数据中读取 rows * cols 个字节, 图像中一个像素值(应该是 int8 类型)用一个字节表示即可
image_file.read(pixels, rows * cols);
// 读取标签
label_file.read(&label, 1);
// set_data 函数, 把源图像值放入 datum 对象
datum.set_data(pixels, rows*cols);
// set_label 函数, 把标签值放入 datum
datum.set_label(label);
// snprintf(str1, size_t, "format", str), 把 str 按照 format 的格式以字符串的形式写入 str1, size_t 表示写入的字符个数
// 这里是把 item_id 转换成 8 位长度的十进制整数,然后在变成字符串复制给 key_str, 如:item_id=1500(int), 则 key_cstr = 00015000(string, \0为字符串结束标志)
snprintf(key_cstr, kMaxKeyLength, "%08d", item_id);
datum.SerializeToString(&value);
// 感觉是将 datum 中的值序列化成字符串,保存在变量 value 内,通过指针来给 value 赋值
string keystr(key_cstr);
// 放到数据库中
if (db_backend == "leveldb") { // leveldb
// 通过 batch 中的子方法 Put, 把数据写入 datum 中(此时在内存中)
batch->Put(keystr, value);
} else if (db_backend == "lmdb") { // lmdb
// mv 应该是 move value, 应该是和 write() 和 read() 函数文件读写的方式一样, 以固定的子节长度按照地址进行读写操作
// 获取 value 的子节长度, 类似 sizeof() 函数
mdb_data.mv_size = value.size()
// 把 value 的首个字符地址转换成空类型的指针
mdb_data.mv_data = reinterpret_cast<void*>(&value[0]);
mdb_key.mv_size = keystr.size();
mdb_key.mv_data = reinterpret_cast<void*>(&keystr[0]);
// 通过 mdb_put 函数把 mdb_key 和 mdb_data 所指向的数据, 写入到 mdb_dbi
CHECK_EQ(mdb_put(mdb_txn, mdb_dbi, &mdb_key, &mdb_data, 0), MDB_SUCCESS) << "mdb_put failed";
} else {
LOG(FATAL) << "Unknown db backend " << db_back_end;
}
// 把 db 数据写入硬盘
// 选择 1000 个样本放入一个 batch 中,通过 batch 以批量的方式把数据写入硬盘
// 写入硬盘通过 db.write() 函数来实现
if (++count % 1000 == 0) {
// 批量提交更改
if(db_backend == "leveldb") { // leveldb
// 把batch写入到 db 中,然后删除 batch 并重新创建
db->Write(leveldb::WriteOptions(), batch);
delete batch;
batch = new leveldb::WriteBatch();
} else if (db_backend == "lmdb") { // lmdb
// 通过 mdb_txn_commit 函数把 mdb_txn 数据写入到硬盘
CHECK_EQ(mdb_txn_commit(mdb_txn), MDB_SUCCESS) << "mdb_txn_commit failed";
// 重新设置 mdb_txn 的写入位置, 追加(继续)写入
CHECK_EQ(mdb_txn_begin(mdb_env, NULL, 0, &mdb_txn), MDB_SUCCESS) << "mdb_txn_begin failed";
} else {
LOG(FATAL) << "Unknown db backend " << db_backend;
}
} // if (++count % 1000 == 0)
} // for (int item_id = 0; item_id < num_items; ++item_id)
// 写最后一个 batch
if (count % 1000 != 0) {
if (db_backend == "leveldb") { // leveldb
db->Write(leveldb::WriteOptions(), batch);
delete batch;
delete db; // 删除临时变量,清理内存占用
} else if (db_backend == "lmdb") { // lmdb
CHECK_EQ(mdb_txn_commit(mdb_txn), MDB_SUCCESS) << "mdb_txn_commit failed";
// 关闭 mdb 数据对象变量
mdb_close(mdb_env, mdb_dbi);
// 关闭 mdb 操作环境变量
mdb_env_close(mdb_env);
} else {
LOG(FATAL) << "Unknown db backend " << db_backend;
}
LOG(ERROE) << "Processed " << count << " files.";
}
delete[] pixels;
} // void convert_dataset(const char* image_filename, const char* label_filename, const char* db_path, const string& db_backend)
int main(int argc, char** argv) {
#ifndef GFLAGS_GFLAGS_H
namespace gflags = google;
#endif
gflags::SetUsageMessage("This script converts the MNIST dataset to \n"
"the lmdb/leveldb format used by Caffe to load data. \n"
"Usage:\n"
" convert_mnist_data [FLAGS] input_image_file input_label_file "
"output_db_file\n"
"The MNIST dataset could be downloaded at\n"
" http://yann./exdb/mnist/\n"
"You should gunzip them after downloading,"
"or directly use the data/mnist/get_mnist.sh\n");
gflags::ParseCommandLineFlags(&argc, &argv, true);
// FLAGS_backend 在前面通过 DEFINE_string 定义,是字符串类型
const string& db_backend = FLAGS_backend;
if (argc != 4) {
gflags::ShowUsageWithFlagsRestrict(argv[0], "examples/mnist/convert_mnist_data");
} else {
google::InitGoogleLogging(argv[0]);
convert_dataset(argv[1], argv[2], argv[3], db_backend);
}
return 0;
}
- 1
- 2
- 3
- 4
- 5
- 6
- 7
- 8
- 9
- 10
- 11
- 12
- 13
- 14
- 15
- 16
- 17
- 18
- 19
- 20
- 21
- 22
- 23
- 24
- 25
- 26
- 27
- 28
- 29
- 30
- 31
- 32
- 33
- 34
- 35
- 36
- 37
- 38
- 39
- 40
- 41
- 42
- 43
- 44
- 45
- 46
- 47
- 48
- 49
- 50
- 51
- 52
- 53
- 54
- 55
- 56
- 57
- 58
- 59
- 60
- 61
- 62
- 63
- 64
- 65
- 66
- 67
- 68
- 69
- 70
- 71
- 72
- 73
- 74
- 75
- 76
- 77
- 78
- 79
- 80
- 81
- 82
- 83
- 84
- 85
- 86
- 87
- 88
- 89
- 90
- 91
- 92
- 93
- 94
- 95
- 96
- 97
- 98
- 99
- 100
- 101
- 102
- 103
- 104
- 105
- 106
- 107
- 108
- 109
- 110
- 111
- 112
- 113
- 114
- 115
- 116
- 117
- 118
- 119
- 120
- 121
- 122
- 123
- 124
- 125
- 126
- 127
- 128
- 129
- 130
- 131
- 132
- 133
- 134
- 135
- 136
- 137
- 138
- 139
- 140
- 141
- 142
- 143
- 144
- 145
- 146
- 147
- 148
- 149
- 150
- 151
- 152
- 153
- 154
- 155
- 156
- 157
- 158
- 159
- 160
- 161
- 162
- 163
- 164
- 165
- 166
- 167
- 168
- 169
- 170
- 171
- 172
- 173
- 174
- 175
- 176
- 177
- 178
- 179
- 180
- 181
- 182
- 183
- 184
- 185
- 186
- 187
- 188
- 189
- 190
- 191
- 192
- 193
- 194
- 195
- 196
- 197
- 198
- 199
- 200
- 201
- 202
- 203
- 204
- 205
- 206
- 207
- 208
- 209
- 210
- 211
- 212
- 213
- 214
- 215
- 216
- 217
- 218
- 219
- 220
- 221
- 222
- 223
- 224
- 225
- 226
- 227
- 228
- 229
- 230
- 231
- 232
- 233
- 234
- 235
- 236
- 237
- 238
- 239
- 240
- 241
- 242
- 243
- 244
- 245
- 246
- 247
- 1
- 2
- 3
- 4
- 5
- 6
- 7
- 8
- 9
- 10
- 11
- 12
- 13
- 14
- 15
- 16
- 17
- 18
- 19
- 20
- 21
- 22
- 23
- 24
- 25
- 26
- 27
- 28
- 29
- 30
- 31
- 32
- 33
- 34
- 35
- 36
- 37
- 38
- 39
- 40
- 41
- 42
- 43
- 44
- 45
- 46
- 47
- 48
- 49
- 50
- 51
- 52
- 53
- 54
- 55
- 56
- 57
- 58
- 59
- 60
- 61
- 62
- 63
- 64
- 65
- 66
- 67
- 68
- 69
- 70
- 71
- 72
- 73
- 74
- 75
- 76
- 77
- 78
- 79
- 80
- 81
- 82
- 83
- 84
- 85
- 86
- 87
- 88
- 89
- 90
- 91
- 92
- 93
- 94
- 95
- 96
- 97
- 98
- 99
- 100
- 101
- 102
- 103
- 104
- 105
- 106
- 107
- 108
- 109
- 110
- 111
- 112
- 113
- 114
- 115
- 116
- 117
- 118
- 119
- 120
- 121
- 122
- 123
- 124
- 125
- 126
- 127
- 128
- 129
- 130
- 131
- 132
- 133
- 134
- 135
- 136
- 137
- 138
- 139
- 140
- 141
- 142
- 143
- 144
- 145
- 146
- 147
- 148
- 149
- 150
- 151
- 152
- 153
- 154
- 155
- 156
- 157
- 158
- 159
- 160
- 161
- 162
- 163
- 164
- 165
- 166
- 167
- 168
- 169
- 170
- 171
- 172
- 173
- 174
- 175
- 176
- 177
- 178
- 179
- 180
- 181
- 182
- 183
- 184
- 185
- 186
- 187
- 188
- 189
- 190
- 191
- 192
- 193
- 194
- 195
- 196
- 197
- 198
- 199
- 200
- 201
- 202
- 203
- 204
- 205
- 206
- 207
- 208
- 209
- 210
- 211
- 212
- 213
- 214
- 215
- 216
- 217
- 218
- 219
- 220
- 221
- 222
- 223
- 224
- 225
- 226
- 227
- 228
- 229
- 230
- 231
- 232
- 233
- 234
- 235
- 236
- 237
- 238
- 239
- 240
- 241
- 242
- 243
- 244
- 245
- 246
- 247
LMDB 句柄
变量 |
说明 |
MDB_dbi mdb_dbi |
环境中一个数据库的句柄 |
MDB_env *mdb_env |
整个数据环境的句柄 |
MDB_val mdb_key, mdb_data |
存放要输入进数据库的数据值 |
MDB_txn *mdb_txn |
数据库事物操作的句柄 |
LMDB 流程图
小端存储、大端存储(Little-Endian、Big-Endian)
上面的源码中,有一个函数是进行大端存储到小端存储的转换的。这部分没有计算机汇编的基础,一开始一头雾水……参考的一篇博客:http://www.cnblogs.com/passingcloudss/archive/2011/05/03/2035273.html
不同的CPU有不同的字节序类型,这些字节序是指整数在内存中保存的顺序。最常见的有两种:
1. Little-endian:将低序字节存储在起始地址(低位编址)
2. Big-endian:将高序字节存储在起始地址(高位编址)
LE(little-endian):
最符合人的思维的字节序,地址低位存储值的低位 ,地址高位存储值的高位 。
这种存储最符合人的思维的字节序,因为从人的第一观感来说,低位值小,就应该放在内存地址小的地方,也即内存地址低位。反之,高位值就应该放在内存地址大的地方,也即内存地址高位
BE(big-endian):
最直观的字节序,地址低位存储值的高位,地址高位存储值的低位
为什么说直观,不要考虑对应关系,只需要把内存地址从左到右按照由低到高的顺序写出,把值按照通常的高位到低位的顺序写出。两者对照,一个字节一个字节的填充进去 。
注:×86 系列的 CPU 都是 Little-Endian 的字节序。
例子1:在内存中双字 0x01020304(DWORD) 的存储方式:
内存地址为:4000 4001 4002 4003
小端存储: 04 03 02 01
大端存储: 01 02 03 04
注:每个地址存 1 个字节,每个字有 4 字节。2 位 16 进制数是 1 个字节(0xFF = 11111111)。
例子2:如果我们将 0x1234abcd 写入到以 0x0000 开始的内存中,则结果为:
|
big-endian |
little-endian |
0x0000 |
0x12 |
0xcd |
0x0001 |
0x23 |
0xab |
0x0002 |
0xab |
0x34 |
0x0003 |
0xcd |
0x12 |
Python 读写 LMDB 格式图像数据
我想这部分才是很多人关心的,因为我们使用 caffe,将图像数据转换为 caffe 可以识别的数据格式是第一步。同时大多数都是通过 python 接口来转换数据格式的。
LMDB 数据库
Caffe 使用 LMDB 的情况大约有两类:
- 第一类是 DataLayer 层中 使用的 训练集、验证集、测试集;
- 第二类 就是
./caffe/build/tools/extract_feature.bin 这种特征提取工具提取特征后,输出的特征文件。
LMDB 的全称是 Lighting Memory-Mapped Database(闪电般的内存映射数据库) 。它文件结构简单,一个文件夹,里面一个数据文件,一个锁文件。数据随意复制,随意传输。它的访问简单,不需要运行单独的数据管理进程。只要在访问的代码里引用 LMDB 库,访问时给文件路径即可。
Caffe 中使用的数据较为很简单,就是大量的矩阵/向量平铺开来。数据之间没有什么关联,数据内没有复杂的对象结构,就是向量和矩阵。既然数据并不复杂,Caffe 就选择了 LMDB 这个简单的数据库来存放数据。
上面提到了,Caffe 使用 LMDB 数据库有两点原因:
一方面是因为数据源的格式多样性,有文本文件、二进制文件图像文件等等,不可能用一个代码完成上述所有的数据格式。因此,通过 LMDB 数据库,转化成统一的数据格式可以简化数据读取层的实现。
第二个方面就是使用 LMDB 数据库可以大大的节约磁盘 IO 的时间开销。因为读取大量小文件的时间开销是相当大的,尤其是在机械硬盘上。
数据库单文件还能减少数据集复制、传输过程的开销。因为我们都有过体会,一个具有几万个、几十万个文件的数据集,不管是直接复制,还是打开再解包,过程都巨慢无比。LMDB 只有一个文件,你的介质有多快,就能复制多快,不会因为文件多而慢的令人心碎。
Caffe 中 Datum 数据结构
Caffe 并不是把向量和矩阵直接放进数据库的,而是将数据通过 caffe.proto 里定义的一个 datum 类来封装的。数据库里存放的是一个个 datum 序列化成的字符串。Datum 的定义如下:
message Datum {
optional int32 channels = 1;
optional int32 height = 2;
optional int32 width = 3;
// the actual image data, in bytes
optional bytes data = 4;
optional int32 label = 5;
// Optionally, the datum could also hold float data.
repeated float float_data = 6;
// If true data contains an encoded image that need to be decoded
optional bool encoded = 7 [default = false];
}
一个 Datum 有三个维度,channnels 、height 、width ,可以看作是少了 num 维度的 Blob 。
存放数据的地方有两个:bytes data 、float_data ,分别存放整数型和浮点型数据。图像数据一般是整形,放在 bytes data 中,特征向量一般是浮点型,存放在 float_data 中。
label 里存放的是类别标签,是整数型。
encoded 标识数据是否需要被解码,因为里面可能存放的是 JPEG 或者 PNG 之类经过编码的数据。
Datum 这个数据结构将数据和标签封装在一起,兼容整形和浮点型数据。经过 protobuf 编译后,可以在 Python 和 C++ 中都提供高效的访问。
同时 protobuf 还为它提供了序列化、反序列化的功能。存放进 LMDB 的就是 Datum 序列化生成的字符串。
Caffe 中将图像写入 LMDB 数据库
我上面解析的 create_mnist_data.cpp 代码对于这部分是很有用的,特别是 LMDB 流程图中的 lmdb 数据操作函数,如打开一个 lmdb 数据库,写入数据等操作,python 中的使用类似,但比 C++ 的要简洁许多 。
下面通过代码来说明吧,这段代码是一个大牛写的教程:《A Practical Introduction to Deep Learning with Caffe and Python》,写的很清晰。
import os
import glob
import random
import numpy as np
import cv2
import caffe
from caffe.proto import caffe_pb2
import lmdb
#Size of images
IMAGE_WIDTH = 227
IMAGE_HEIGHT = 227
# train_lmdb、validation_lmdb 路径
train_lmdb = '/home/chenxp/Documents/vehicleID/val/train_lmdb'
validation_lmdb = '/home/chenxp/Documents/vehicleID/val/validation_lmdb'
# 如果存在了这个文件夹, 先删除
os.system('rm -rf ' + train_lmdb)
os.system('rm -rf ' + validation_lmdb)
# 读取图像
train_data = [img for img in glob.glob("/home/chenxp/Documents/vehicleID/val/query/*jpg")]
test_data = [img for img in glob.glob("/home/chenxp/Documents/vehicleID/val/query/*jpg")]
# Shuffle train_data
# 打乱数据的顺序
random.shuffle(train_data)
# 图像的变换, 直方图均衡化, 以及裁剪到 IMAGE_WIDTH x IMAGE_HEIGHT 的大小
def transform_img(img, img_width=IMAGE_WIDTH, img_height=IMAGE_HEIGHT):
#Histogram Equalization
img[:, :, 0] = cv2.equalizeHist(img[:, :, 0])
img[:, :, 1] = cv2.equalizeHist(img[:, :, 1])
img[:, :, 2] = cv2.equalizeHist(img[:, :, 2])
#Image Resizing, 三次插值
img = cv2.resize(img, (img_width, img_height), interpolation = cv2.INTER_CUBIC)
return img
def make_datum(img, label):
#image is numpy.ndarray format. BGR instead of RGB
return caffe_pb2.Datum(
channels=3,
width=IMAGE_WIDTH,
height=IMAGE_HEIGHT,
label=label,
data=np.rollaxis(img, 2).tobytes()) # or .tostring() if numpy < 1.9
# 打开 lmdb 环境, 生成一个数据文件,定义最大空间, 1e12 = 1000000000000.0
in_db = lmdb.open(train_lmdb, map_size=int(1e12))
with in_db.begin(write=True) as in_txn: # 创建操作数据库句柄
for in_idx, img_path in enumerate(train_data):
if in_idx % 6 == 0: # 只处理 5/6 的数据作为训练集
continue # 留下 1/6 的数据用作验证集
# 读取图像. 做直方图均衡化、裁剪操作
img = cv2.imread(img_path, cv2.IMREAD_COLOR)
img = transform_img(img, img_width=IMAGE_WIDTH, img_height=IMAGE_HEIGHT)
if 'cat' in img_path: # 组织 label, 这里是如果文件名称中有 'cat', 标签就是 0
label = 0 # 如果图像名称中没有 'cat', 有的是 'dog', 标签则为 1
else: # 这里方, label 需要自己去组织
label = 1 # 每次情况可能不一样, 灵活点
datum = make_datum(img, label)
# '{:0>5d}'.format(in_idx):
# lmdb的每一个数据都是由键值对构成的,
# 因此生成一个用递增顺序排列的定长唯一的key
in_txn.put('{:0>5d}'.format(in_idx), datum.SerializeToString()) #调用句柄,写入内存
print '{:0>5d}'.format(in_idx) + ':' + img_path
# 结束后记住释放资源,否则下次用的时候打不开。。。
in_db.close()
# 创建验证集 lmdb 格式文件
print '\nCreating validation_lmdb'
in_db = lmdb.open(validation_lmdb, map_size=int(1e12))
with in_db.begin(write=True) as in_txn:
for in_idx, img_path in enumerate(train_data):
if in_idx % 6 != 0:
continue
img = cv2.imread(img_path, cv2.IMREAD_COLOR)
img = transform_img(img, img_width=IMAGE_WIDTH, img_height=IMAGE_HEIGHT)
if 'cat' in img_path:
label = 0
else:
label = 1
datum = make_datum(img, label)
in_txn.put('{:0>5d}'.format(in_idx), datum.SerializeToString())
print '{:0>5d}'.format(in_idx) + ':' + img_path
in_db.close()
print '\nFinished processing all images'
- 1
- 2
- 3
- 4
- 5
- 6
- 7
- 8
- 9
- 10
- 11
- 12
- 13
- 14
- 15
- 16
- 17
- 18
- 19
- 20
- 21
- 22
- 23
- 24
- 25
- 26
- 27
- 28
- 29
- 30
- 31
- 32
- 33
- 34
- 35
- 36
- 37
- 38
- 39
- 40
- 41
- 42
- 43
- 44
- 45
- 46
- 47
- 48
- 49
- 50
- 51
- 52
- 53
- 54
- 55
- 56
- 57
- 58
- 59
- 60
- 61
- 62
- 63
- 64
- 65
- 66
- 67
- 68
- 69
- 70
- 71
- 72
- 73
- 74
- 75
- 76
- 77
- 78
- 79
- 80
- 81
- 82
- 83
- 84
- 85
- 86
- 87
- 88
- 89
- 90
- 91
- 92
- 93
- 94
- 1
- 2
- 3
- 4
- 5
- 6
- 7
- 8
- 9
- 10
- 11
- 12
- 13
- 14
- 15
- 16
- 17
- 18
- 19
- 20
- 21
- 22
- 23
- 24
- 25
- 26
- 27
- 28
- 29
- 30
- 31
- 32
- 33
- 34
- 35
- 36
- 37
- 38
- 39
- 40
- 41
- 42
- 43
- 44
- 45
- 46
- 47
- 48
- 49
- 50
- 51
- 52
- 53
- 54
- 55
- 56
- 57
- 58
- 59
- 60
- 61
- 62
- 63
- 64
- 65
- 66
- 67
- 68
- 69
- 70
- 71
- 72
- 73
- 74
- 75
- 76
- 77
- 78
- 79
- 80
- 81
- 82
- 83
- 84
- 85
- 86
- 87
- 88
- 89
- 90
- 91
- 92
- 93
- 94
再展示一段生成 lmdb 的代码,来源自:http:///2015/04/28/creating-lmdb-in-python/
这段代码并没有用真实的图像数据来生成,二是用 numpy 中的 np.zeros() 生成了图像格式的数据:
import numpy as np
import lmdb
import caffe
N = 1000
# Let's pretend this is interesting data
X = np.zeros((N, 3, 32, 32), dtype=np.uint8)
y = np.zeros(N, dtype=np.int64)
# We need to prepare the database for the size. We'll set it 10 times
# greater than what we theoretically need. There is little drawback to
# setting this too big. If you still run into problem after raising
# this, you might want to try saving fewer entries in a single
# transaction.
map_size = X.nbytes * 10
env = lmdb.open('mylmdb', map_size=map_size)
with env.begin(write=True) as txn:
# txn is a Transaction object
for i in range(N):
datum = caffe.proto.caffe_pb2.Datum()
datum.channels = X.shape[1]
datum.height = X.shape[2]
datum.width = X.shape[3]
datum.data = X[i].tobytes() # or .tostring() if numpy < 1.9
datum.label = int(y[i])
str_id = '{:08}'.format(i)
# The encode is only essential in Python 3
txn.put(str_id.encode('ascii'), datum.SerializeToString())
- 1
- 2
- 3
- 4
- 5
- 6
- 7
- 8
- 9
- 10
- 11
- 12
- 13
- 14
- 15
- 16
- 17
- 18
- 19
- 20
- 21
- 22
- 23
- 24
- 25
- 26
- 27
- 28
- 29
- 30
- 31
- 32
- 1
- 2
- 3
- 4
- 5
- 6
- 7
- 8
- 9
- 10
- 11
- 12
- 13
- 14
- 15
- 16
- 17
- 18
- 19
- 20
- 21
- 22
- 23
- 24
- 25
- 26
- 27
- 28
- 29
- 30
- 31
- 32
运行上一段代码,会生成下面两个文件:
Caffe 从 LMDB 数据库中读取数据
下面就是从生成好的 lmdb 中读取数据了:
import numpy as np
import caffe
import lmdb
import cv2
# 打开 lmdb 数据库, 指定好位置
env = lmdd.open('mylmdb', readonly=True)
with env.begin() as txn:
raw_datum = txn.get(b'00000000')
datum = caffe.proto.caffe_pb2.Datum()
datum.ParseFromString(raw_datum)
flat_x = np.fromstring(datum.data, dtype=np.uint8)
x = flat_x.reshape(datum.channels, datum.height, datum.width)
y = datum.label
print(datum.channels)
print 'label = ' + str(y) # y 为整型, 需要转成字符串
# C x H x W 转换到 H x W x C, 才能在 cv2 中显示
img = cv2.transpose(img, (1, 2, 0)) # 或者: img = x.transpose(1, 2, 0)
cv2.imshow("Image", img)
cv2.waitKey(0)
- 1
- 2
- 3
- 4
- 5
- 6
- 7
- 8
- 9
- 10
- 11
- 12
- 13
- 14
- 15
- 16
- 17
- 18
- 19
- 20
- 21
- 22
- 23
- 24
- 1
- 2
- 3
- 4
- 5
- 6
- 7
- 8
- 9
- 10
- 11
- 12
- 13
- 14
- 15
- 16
- 17
- 18
- 19
- 20
- 21
- 22
- 23
- 24
输出为:
下图是输出的图像……别笑……那是因为上面代码用 np.zeros() 生成的太小了:
可以迭代读取 <key, value> :
with env.open() as txn:
cursor = txn.cursor()
for key, value in cursor:
print(key, value)
下面代码用迭代循环 txn.cursor() 读取:
import caffe
from caffe.proto import caffe_pb2
import lmdb
import cv2
import numpy as np
lmdb_env = lmdb.open('mylmdb', readonly=True) # 打开数据文件
lmdb_txn = lmdb_env.begin() # 生成处理句柄
lmdb_cursor = lmdb_txn.cursor() # 生成迭代器指针
datum = caffe_pb2.Datum() # caffe 定义的数据类型
for key, value in lmdb_cursor: # 循环获取数据
datum.ParseFromString(value) # 从 value 中读取 datum 数据
label = datum.label
data = caffe.io.datum_to_array(datum)
print data.shape
print datum.channels
image = data.transpose(1, 2, 0)
cv2.imshow('cv2.png', image)
cv2.waitKey(0)
cv2.destroyAllWindows()
lmdb_env.close()
- 1
- 2
- 3
- 4
- 5
- 6
- 7
- 8
- 9
- 10
- 11
- 12
- 13
- 14
- 15
- 16
- 17
- 18
- 19
- 20
- 21
- 22
- 23
- 24
- 25
- 1
- 2
- 3
- 4
- 5
- 6
- 7
- 8
- 9
- 10
- 11
- 12
- 13
- 14
- 15
- 16
- 17
- 18
- 19
- 20
- 21
- 22
- 23
- 24
- 25
Reference
- 《深度学习 21天实战 Caffe》, 卜居
- Caffe1——Mnist数据集创建lmdb或leveldb类型的数据
- caffe源码阅读(1): 数据加载
- 愚见caffe中的LeNet
- 小端格式和大端格式(Little-Endian&Big-Endian)
- Creating an LMDB database in Python
- A Practical Introduction to Deep Learning with Caffe and Python
- 中科院自动化所博士@beanfrog:Write/Read lmdb file for caffe with python
- 利用caffe与lmdb读写图像数据
- Caffe中LMDB的使用
- Caffe: Reading LMDB from Python
|