1. 导入库文件
# import the necessary packages
import os
import random

import cv2
import matplotlib.pyplot as plt
import numpy as np
from imutils import paths
from keras import backend as K
from keras.layers.convolutional import Conv2D
from keras.layers.convolutional import MaxPooling2D
from keras.layers.core import Activation
from keras.layers.core import Dense
from keras.layers.core import Flatten
from keras.models import Sequential
from keras.optimizers import Adam
from keras.preprocessing.image import ImageDataGenerator
from keras.preprocessing.image import img_to_array
from keras.utils import to_categorical
2. 定义常量
# initialize the number of epochs to train for, initial learning rate,
3. 读取数据
def label_from_path(image_path):
    """Return the integer class label encoded in the image's parent directory.

    The dataset layout is ``<root>/<class id>/<image>`` (for example
    ``data/test/00058/00413_00000.png`` -> 58), so the label is the
    second-to-last path component parsed as an int.
    """
    return int(image_path.split(os.path.sep)[-2])


def load_data(path):
    """Load every image under *path* as a normalized array plus one-hot labels.

    Images are read with OpenCV (BGR channel order), resized to
    ``norm_size`` x ``norm_size`` and scaled to [0, 1]. Each label comes from
    the image's parent directory name and is one-hot encoded into
    ``CLASS_NUM`` classes.

    :param path: root directory whose sub-directories are named by class id.
    :return: ``(data, labels)``: the image array and the matching one-hot
        label matrix, in the same (shuffled) order.
    """
    print("[INFO] loading images...")
    data = []
    labels = []
    # grab the image paths and randomly shuffle them
    imagePaths = sorted(list(paths.list_images(path)))
    random.shuffle(imagePaths)
    # loop over the input images
    for imagePath in imagePaths:
        # load the image, pre-process it, and store it in the data list
        image = cv2.imread(imagePath)
        if image is None:
            # cv2.imread returns None for unreadable files; skip them instead
            # of failing later with an opaque error inside cv2.resize.
            print("[WARN] skipping unreadable image: {}".format(imagePath))
            continue
        image = cv2.resize(image, (norm_size, norm_size))
        data.append(img_to_array(image))
        # extract the class label from the image path and update the labels list
        labels.append(label_from_path(imagePath))

    # scale the raw pixel intensities to the range [0, 1]
    data = np.array(data, dtype="float") / 255.0
    labels = np.array(labels)
    # convert the labels from integers to vectors
    labels = to_categorical(labels, num_classes=CLASS_NUM)
    return data, labels
4. 训练并保存模型
def history_series(history, *keys):
    """Return the values stored under the first of *keys* found in *history*.

    Keras renamed the accuracy curves from ``acc``/``val_acc`` to
    ``accuracy``/``val_accuracy`` (Keras 2.3). The plot therefore looks the
    curve up under whichever name the installed version recorded.

    :param history: a Keras ``History.history`` dict.
    :raises KeyError: if none of *keys* is present.
    """
    for key in keys:
        if key in history:
            return history[key]
    raise KeyError("none of {} found in training history".format(keys))


def train(aug, trainX, trainY, testX, testY,
          model_path="traffic.h5", plot_path="plot.png"):
    """Train the traffic-sign CNN, save it, and plot the learning curves.

    The network comes from ``build()`` and is trained on batches drawn from
    the augmenting generator *aug*. It is validated on ``(testX, testY)``,
    saved to *model_path*, and its loss/accuracy curves are written to
    *plot_path*.

    :param aug: a Keras ``ImageDataGenerator`` used for data augmentation.
    :param trainX: training images; ``trainY``: one-hot training labels.
    :param testX: validation images; ``testY``: one-hot validation labels.
    :param model_path: where the trained model is saved (``.h5``).
    :param plot_path: where the training-curve figure is saved.
    :return: the Keras ``History`` object returned by training.
    """
    print("[INFO] compiling model...")
    model = build(width=norm_size, height=norm_size, depth=3, classes=CLASS_NUM)
    opt = Adam(lr=INIT_LR, decay=INIT_LR / EPOCHS)
    # The accuracy metric is required: the curves below plot it.
    model.compile(loss="categorical_crossentropy", optimizer=opt,
                  metrics=["accuracy"])

    print("[INFO] training network...")
    # At least one step per epoch, so a training set smaller than one batch
    # does not give steps_per_epoch=0.
    H = model.fit_generator(aug.flow(trainX, trainY, batch_size=BS),
                            validation_data=(testX, testY),
                            steps_per_epoch=max(1, len(trainX) // BS),
                            epochs=EPOCHS, verbose=1)

    print("[INFO] serializing network...")
    model.save(model_path)

    # 5. Evaluate the trained model: plot the training loss and accuracy.
    history = H.history
    # One point per epoch actually recorded.
    epochs = np.arange(0, len(history["loss"]))
    plt.figure()
    plt.plot(epochs, history["loss"], label="train_loss")
    plt.plot(epochs, history["val_loss"], label="val_loss")
    plt.plot(epochs, history_series(history, "accuracy", "acc"),
             label="train_acc")
    plt.plot(epochs, history_series(history, "val_accuracy", "val_acc"),
             label="val_acc")
    plt.title("Training Loss and Accuracy on traffic-sign classifier")
    plt.xlabel("Epoch #")
    plt.ylabel("Loss/Accuracy")
    plt.legend(loc="lower left")
    plt.savefig(plot_path)
    return H
6. 训练模型主函数
# Usage: python train.py  (dataset paths are hard-coded below; no CLI flags)
if __name__ == '__main__':
    train_file_path = './data/train'
    test_file_path = './data/test'
    trainX, trainY = load_data(train_file_path)
    testX, testY = load_data(test_file_path)
    # construct the image generator for data augmentation
    # NOTE(review): these augmentation settings are assumed, conservative
    # values; tune them for the dataset. Horizontal flips are deliberately
    # omitted because mirroring changes the meaning of direction signs
    # (e.g. "turn left" vs "turn right").
    aug = ImageDataGenerator(rotation_range=10,
                             width_shift_range=0.1,
                             height_shift_range=0.1,
                             shear_range=0.1,
                             zoom_range=0.1,
                             fill_mode="nearest")
    train(aug, trainX, trainY, testX, testY)
7. 模型转换：将 Keras 的 .h5 模型转换为 TensorFlow 的 .pb 模型
def freeze_session(session, keep_var_names=None, output_names=None, clear_devices=True):
    """Freeze the state of a session into a pruned computation graph.

    Creates a new computation graph where variable nodes are replaced by
    constants taking their current value in the session. The new graph is
    pruned, removing subgraphs that are not needed to compute the requested
    outputs.

    @param session The TensorFlow session to be frozen.
    @param keep_var_names A list of variable names that should not be frozen,
                          or None to freeze all the variables in the graph.
    @param output_names Names of the relevant graph outputs. The caller's
                        list is not modified.
    @param clear_devices Remove the device directives from the graph for
                         better portability.
    @return The frozen graph definition.
    """
    graph = session.graph
    with graph.as_default():
        freeze_var_names = list(set(v.op.name for v in tf.global_variables())
                                .difference(keep_var_names or []))
        # Copy, so the caller's list is not extended in place.
        output_names = list(output_names or [])
        # Variable ops are listed as outputs too, so their (now constant)
        # nodes are kept by the pruning step.
        output_names += [v.op.name for v in tf.global_variables()]
        input_graph_def = graph.as_graph_def()
        if clear_devices:
            for node in input_graph_def.node:
                node.device = ""
        frozen_graph = tf.graph_util.convert_variables_to_constants(
            session, input_graph_def, output_names, freeze_var_names)
        return frozen_graph
def pb_filename(h5_name):
    """Return the ``.pb`` file name that corresponds to a Keras ``.h5`` file.

    Only the base name is returned (``models/traffic.h5`` -> ``traffic.pb``)
    because ``tf.train.write_graph`` joins it onto its ``logdir`` argument.
    """
    return os.path.splitext(os.path.basename(h5_name))[0] + '.pb'


if __name__ == '__main__':
    # Directory holding the trained model; the frozen graph is written there too.
    # NOTE(review): assumed to be the working directory, matching the relative
    # 'traffic.h5' / 'traffic.pb' paths used by the prediction scripts; confirm.
    input_path = '.'
    input_file = 'traffic.h5'
    weight_file_path = os.path.join(input_path, input_file)
    output_graph_name = pb_filename(input_file)

    # Learning phase 0 puts Keras in inference mode before the model is loaded.
    keras.backend.set_learning_phase(0)
    h5_model = keras.models.load_model(weight_file_path)
    frozen_graph = freeze_session(keras.backend.get_session(),
                                  output_names=[out.op.name for out in h5_model.outputs])
    tf.train.write_graph(frozen_graph, input_path, output_graph_name, as_text=False)
    # The result can be checked with: cv2.dnn.readNetFromTensorflow("traffic.pb")
8. 模型预测：通过 Keras 的 load_model 加载 .h5 模型实现
Created on Mon Aug 24 10:08:04 2020
# import the necessary packages from keras.preprocessing.image import img_to_array from keras.models import load_model
def format_prediction(result):
    """Return ``"<class index>: <confidence>%"`` for a 1-D vector of class scores."""
    class_id = int(np.argmax(result))
    return "{}: {:.2f}%".format(class_id, result[class_id] * 100)


def predict(image_path='E:/python learn/TrafficClassify/data/test/00058/00413_00000.png',
            model_path='traffic.h5'):
    """Classify one image with the saved Keras model and display the result.

    The image is preprocessed as in training: loaded with OpenCV (BGR),
    resized to ``norm_size`` x ``norm_size`` and scaled to [0, 1].

    :param image_path: image to classify.
    :param model_path: Keras ``.h5`` model to load.
    :raises FileNotFoundError: if OpenCV cannot read *image_path*.
    """
    # load the trained convolutional neural network
    print("[INFO] loading network...")
    model = load_model(model_path)

    image = cv2.imread(image_path)
    if image is None:
        raise FileNotFoundError("cannot read image: {}".format(image_path))
    # keep an unmodified copy for display
    orig = image.copy()

    # pre-process the image for classification
    image = cv2.resize(image, (norm_size, norm_size))
    image = image.astype("float") / 255.0
    image = img_to_array(image)
    image = np.expand_dims(image, axis=0)

    # classify the input image
    result = model.predict(image)[0]
    label = format_prediction(result)
    print(label)

    # draw the label on the image and keep the window open until a key press
    output = imutils.resize(orig, width=400)
    cv2.putText(output, label, (10, 25), cv2.FONT_HERSHEY_SIMPLEX,
                0.7, (0, 255, 0), 2)
    cv2.imshow("Output", output)
    cv2.waitKey(0)


# Usage: python predict.py  (image and model paths are predict()'s defaults)
if __name__ == '__main__':
    predict()
9. 模型预测：通过 OpenCV dnn 模块加载 .pb 模型
Created on Sat Aug 29 15:57:15 2020
def format_dnn_prediction(result):
    """Return ``"<class index>: <confidence>%"`` for an OpenCV dnn output.

    The output is flattened first, so the reported class is a plain integer
    index. ``cv2.minMaxLoc`` would instead report an ``(x, y)`` point.
    Presumably the output is ``(1, num_classes)`` for a single image; confirm.
    """
    scores = np.asarray(result).reshape(-1)
    class_id = int(np.argmax(scores))
    return "{}: {:.2f}%".format(class_id, scores[class_id] * 100)


def predict_dnn(image_path='E:/python learn/TrafficClassify/data/test/00058/00413_00000.png',
                model_path='traffic.pb'):
    """Classify one image with the frozen TensorFlow graph via OpenCV dnn.

    :param image_path: image to classify.
    :param model_path: frozen ``.pb`` graph to load.
    :raises FileNotFoundError: if OpenCV cannot read *image_path*.
    """
    # load the trained convolutional neural network
    print("[INFO] loading network...")
    net = cv2.dnn.readNetFromTensorflow(model_path)

    image = cv2.imread(image_path)
    if image is None:
        raise FileNotFoundError("cannot read image: {}".format(image_path))
    # keep an unmodified copy for display
    orig = image.copy()

    # Scale to [0, 1] and resize, mirroring the training preprocessing.
    # swapRB stays False: the network was trained on cv2.imread() output,
    # which is BGR, so swapping to RGB would feed it the wrong channel order.
    img_tensor = cv2.dnn.blobFromImage(image, 1 / 255.0, (norm_size, norm_size),
                                       swapRB=False, crop=False)

    # classify the input image
    net.setInput(img_tensor)
    result = net.forward()
    label = format_dnn_prediction(result)
    print(label)

    # draw the label on the image and keep the window open until a key press
    output = imutils.resize(orig, width=400)
    cv2.putText(output, label, (10, 25), cv2.FONT_HERSHEY_SIMPLEX,
                0.7, (0, 255, 0), 2)
    cv2.imshow("Output", output)
    cv2.waitKey(0)


if __name__ == '__main__':
    predict_dnn()
10. 通过 OpenCV dnn 加载深度学习模型，即可在传统的 C++ 代码中实现深度学习推理，根据输入图像完成图像分类。完整代码：https://download.csdn.net/download/mr_liyonghong/12785394
|