分享

一文上手OpenCV DNN(实现图像分类)

 黄爸爸好 2021-03-30

一文上手OpenCV DNN

1.DNN模块介绍

OpenCV的DNN模块最早位于扩展模块(opencv_contrib)中,自OpenCV 3.3起被合并到OpenCV主发行版。它可以导入Caffe、TensorFlow、Torch/PyTorch(PyTorch模型通常先导出为ONNX格式)等深度学习框架训练生成的模型文件,通过前向传播(推理)实现预测功能;DNN模块只负责推理,不支持训练。

2.加载模型读取网络信息

模型可以使用readNet API来加载:

// Loads a trained network from file.
//   model     - binary file with the trained weights; models from several frameworks are supported.
//   config    - optional file describing the network; its extension depends on the framework (empty by default).
//   framework - optional explicit name of the framework the model comes from (empty by default).
Net cv::dnn::readNet(const String & model,const String & config =String(),const String & framework =String());
  • 1

其中model是训练好的二进制网络权重文件,支持多种框架训练出来的模型;config是描述网络结构的配置文件(如Caffe的.prototxt、TensorFlow的.pbtxt),不同框架的配置文件扩展名不同,加载TensorFlow冻结图(.pb)时可以省略;framework用于显式声明模型所属的框架名称,省略时readNet会根据文件扩展名自动判断。
除了readNet,也可以使用

// Framework-specific loaders, used instead of readNet when the framework is known.
// TensorFlow: `model` is the frozen graph file (e.g. a .pb); `config` is optional (empty by default).
Net readNetFromTensorflow(const String&model, const String&config = String());
// Caffe: `prototxt` describes the network; `caffeModel` holds the trained weights (empty by default).
Net readNetFromCaffe(const String&prototxt,const String&caffeModel = String());
  • 1
  • 2

等API直接加载对应框架训练出来的模型。这里以加载TensorFlow模型为例,代码如下:

// Path of the frozen TensorFlow graph.
String tf_pbfile = "./tensorflow_inception_graph.pb";

// Load the model; an empty Net means loading failed.
Net cnn_net = readNetFromTensorflow(tf_pbfile);
if (cnn_net.empty())
{
	cout << "load net failed!" << endl;
	return -1;
}
// Enumerate the layers with getLayerNames() and print the id, type and name of each.
// A range-for avoids comparing a signed index against the unsigned size().
vector<String> layer_names = cnn_net.getLayerNames();
for (const String& layer_name : layer_names)
{
	int id = cnn_net.getLayerId(layer_name);
	auto layer = cnn_net.getLayer(id);
	cout << "layerIndex:" << id << " " << "type:" << layer->type << " " << "name:" << layer->name << '\n';
}
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18

输出网络的每层信息如下:

layerIndex:1 type:Convolution name:conv2d0_pre_relu/conv
layerIndex:2 type:ReLU name:conv2d0
layerIndex:3 type:Pooling name:maxpool0
layerIndex:4 type:LRN name:localresponsenorm0
layerIndex:5 type:Convolution name:conv2d1_pre_relu/conv
layerIndex:6 type:ReLU name:conv2d1
layerIndex:7 type:Convolution name:conv2d2_pre_relu/conv
layerIndex:8 type:ReLU name:conv2d2
layerIndex:9 type:LRN name:localresponsenorm1
layerIndex:10 type:Pooling name:maxpool1
layerIndex:11 type:Convolution name:mixed3a_1x1_pre_relu/conv
layerIndex:12 type:ReLU name:mixed3a_1x1
layerIndex:13 type:Convolution name:mixed3a_3x3_bottleneck_pre_relu/conv
layerIndex:14 type:ReLU name:mixed3a_3x3_bottleneck
layerIndex:15 type:Convolution name:mixed3a_3x3_pre_relu/conv
...
layerIndex:151 type:Convolution name:head1_bottleneck_pre_relu/conv
layerIndex:152 type:ReLU name:head1_bottleneck
layerIndex:153 type:Permute name:head1_bottleneck/reshape/nchw
layerIndex:154 type:Reshape name:head1_bottleneck/reshape
layerIndex:155 type:InnerProduct name:nn1_pre_relu/matmul
layerIndex:156 type:ReLU name:nn1
layerIndex:157 type:Reshape name:nn1/reshape
layerIndex:158 type:InnerProduct name:softmax1_pre_activation/matmul
layerIndex:159 type:Softmax name:softmax1
layerIndex:160 type:Permute name:avgpool0/reshape/nchw
layerIndex:161 type:Reshape name:avgpool0/reshape
layerIndex:162 type:InnerProduct name:softmax2_pre_activation/matmul
layerIndex:163 type:Softmax name:softmax2
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29

3.模型正向传递预测

使用模型进行预测时,需要读取图像作为输入。网络接受的输入是四维张量(NCHW,即批大小×通道数×高×宽),因此要把读取到的Mat图像转换为四维blob,OpenCV提供的API如下:

// Converts an image into a 4-D blob that can be fed to the network.
//   image       - input image.
//   scalefactor - multiplier applied to the pixel values (default 1.0).
//   size        - spatial size the image is resized to; Size takes (width, height).
//   mean        - per-channel values subtracted from the image (e.g. the training-set mean).
//   swapRB      - swap the first and last channels (BGR <-> RGB).
//   crop        - whether to crop after resizing (default false).
//   ddepth      - depth of the output blob (default CV_32F).
Mat blobFromImage(
InputArray image,
double scalefactor = 1.0,
const Size & size = Size(),
const Scalar & mean = Scalar(),
bool swapRB = false,
bool crop = false,
int ddepth = CV_32F
);
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9

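其中image为输入图像;scalefactor是像素值的缩放系数,默认为1.0;size表示网络接受的输入尺寸,注意Size的参数顺序是(宽, 高);mean表示训练时数据集的均值,会从图像各通道中减去;swapRB表示是否交换Red与Blue通道——imread读入的图像是BGR顺序,设为true即可得到RGB输入,此时不要再用cvtColor做BGR2RGB转换,否则通道会被交换两次、又变回BGR;crop表示缩放后是否从中心裁剪;ddepth是输出blob的数据类型,默认为CV_32F。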

// Read the input image (imread returns pixels in BGR channel order).
Mat input_image = imread("bird.jpg");
if (input_image.empty())
{
	cout << "read image failed!" << endl;
	return -1;
}
namedWindow("input_image", WINDOW_AUTOSIZE);
imshow("input_image", input_image);

// Build the 4-D input blob. swapRB=true already turns the BGR image into RGB, so the
// image must NOT also be converted with cvtColor(COLOR_BGR2RGB): doing both swaps the
// channels twice and feeds BGR data to the network. Size takes (width, height).
Mat input_blob = blobFromImage(input_image, 1.0f, Size(w, h), Scalar(), true, false);
// Subtract the mean value 117 from every element of the blob.
input_blob -= 117.0;
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14

设置模型输入和执行前向传播分别使用如下两个API:

// Sets the input blob of the network; `name` selects the input layer (empty by default).
void setInput(InputArray blob, const String& name = "");
// Runs the forward pass and returns the output of the layer named `outputName`
// (empty by default; per the OpenCV docs an empty name means the network's final output).
Mat forward(const String& outputName = String());
  • 1
  • 2

部分代码如下:

// Feed the preprocessed blob into the layer named "input" ...
cnn_net.setInput(input_blob, "input");
// ... and run the network up to the "softmax2" Softmax layer, whose output holds the class scores.
Mat prob = cnn_net.forward("softmax2");
  • 1
  • 2
  • 3
  • 4
  • 5

4.输出预测结果

首先读取标签文件,定义一个读取文件的函数read_class_names():

String label_file = "./imagenet_comp_graph_label_strings.txt";
// Reads class names from a text file, one name per line.
// Empty lines are skipped, and a trailing '\r' (Windows line ending) is removed so that
// labels are clean on every platform.
// Prints an error and terminates the process with exit(-1) if the file cannot be opened.
vector<String> read_class_names(String model_label_file)
{
	vector<String> class_names;
	ifstream fp(model_label_file);
	if (!fp.is_open())
	{
		cout << "open label file failed!" << endl;
		exit(-1);
	}
	string name;
	// Loop on getline's result instead of eof(): an eof()-driven loop never ends if a
	// read error sets failbit/badbit without eofbit.
	while (getline(fp, name))
	{
		if (!name.empty() && name.back() == '\r')
			name.pop_back();
		if (!name.empty())
			class_names.push_back(name);
	}
	return class_names;
}

//main中:
vector<String> labels = read_class_names(label_file);

Mat probMat = prob.reshape(1, 1);
Point classNumber;
double classProb;
minMaxLoc(probMat, NULL, &classProb, NULL, &classNumber);
int classidx = classNumber.x;
cout<<"classification:"<< labels.at(classidx).c_str() <<"score:"<<fixed<<setprecision(2) <<classProb;

// 显示文本
putText(input_image, labels.at(classidx), Point(20, 20), FONT_HERSHEY_SIMPLEX, 1.0, Scalar(0, 0, 255), 2, 8);
imshow("result", input_image);
imwrite("result.png", input_image);
waitKey(0);
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36

输出结果如下图所示:
result

5.完整代码

#include<cstdlib>
#include<fstream>
#include<iomanip>
#include<iostream>
#include<string>
#include<vector>
#include<opencv2/opencv.hpp>
#include<opencv2/dnn.hpp>

using namespace std;
using namespace cv;
using namespace cv::dnn;

// Spatial size (in pixels) of the blob the input image is resized to.
const int h = 224;
const int w = 224;

// Model and label files, looked up relative to the working directory.
String tf_pbfile = "./tensorflow_inception_graph.pb";
String label_file = "./imagenet_comp_graph_label_strings.txt";

// Loads the class names, one per line; defined after main().
vector<String> read_class_names(String model_label_file);

int main(int argc, char**argv)
{
Mat input_image = imread("bird.jpg");
if (input_image.empty())
{
cout << "read image failed!" << endl;
return -1;
}

namedWindow("input_image", WINDOW_AUTOSIZE);
imshow("input_image", input_image);
cvtColor(input_image, input_image, COLOR_BGR2RGB);
vector<String> labels = read_class_names(label_file);

Net cnn_net = readNetFromTensorflow(tf_pbfile);
if (cnn_net.empty())
{
cout << "load net failed!" << endl;
return -1;
}
vector<String> layer_names = cnn_net.getLayerNames();
for (int i = 0; i < layer_names.size(); i++)
{
int id = cnn_net.getLayerId(layer_names[i]);
auto layer = cnn_net.getLayer(id);
cout << "layerIndex:" << id << " " << "type:" << layer->type.c_str() << " " << "name:" << layer->name.c_str() << endl;
}

Mat input_blob = blobFromImage(input_image, 1.0f, Size(h, w), Scalar(), true, false);
input_blob -= 117.0;

Mat prob;
cnn_net.setInput(input_blob, "input");
prob = cnn_net.forward("softmax2");
Mat probMat = prob.reshape(1, 1);
Point classNumber;
double classProb;
minMaxLoc(probMat, NULL, &classProb, NULL, &classNumber);
int classidx = classNumber.x;
cout<<"classification:"<< labels.at(classidx).c_str()<<endl <<"score:"<<fixed<<setprecision(2) <<classProb;

putText(input_image, labels.at(classidx), Point(20, 20), FONT_HERSHEY_SIMPLEX, 1.0, Scalar(0, 0, 255), 2, 8);
imshow("result", input_image);
waitKey(0);
return 0;
}

// Reads class names from a text file, one name per line.
// Empty lines are skipped, and a trailing '\r' (Windows line ending) is removed so that
// labels are clean on every platform.
// Prints an error and terminates the process with exit(-1) if the file cannot be opened.
vector<String> read_class_names(String model_label_file)
{
	vector<String> class_names;
	ifstream fp(model_label_file);
	if (!fp.is_open())
	{
		cout << "open label file failed!" << endl;
		exit(-1);
	}
	string name;
	// Loop on getline's result instead of eof(): an eof()-driven loop never ends if a
	// read error sets failbit/badbit without eofbit.
	while (getline(fp, name))
	{
		if (!name.empty() && name.back() == '\r')
			name.pop_back();
		if (!name.empty())
			class_names.push_back(name);
	}
	return class_names;
}
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
  • 48
  • 49
  • 50
  • 51
  • 52
  • 53
  • 54
  • 55
  • 56
  • 57
  • 58
  • 59
  • 60
  • 61
  • 62
  • 63
  • 64
  • 65
  • 66
  • 67
  • 68
  • 69
  • 70
  • 71
  • 72
  • 73
  • 74
  • 75
  • 76
  • 77
  • 78
  • 79
  • 80
  • 81
  • 82

    本站是提供个人知识管理的网络存储空间,所有内容均由用户发布,不代表本站观点。请注意甄别内容中的联系方式、诱导购买等信息,谨防诈骗。如发现有害或侵权内容,请点击一键举报。
    转藏 分享 献花(0

    0条评论

    发表

    请遵守用户 评论公约

    类似文章 更多