首页 > Tensorflow C++ API调用Keras模型实现RGB图像语义分割

Tensorflow C++ API调用Keras模型实现RGB图像语义分割

我的实验基于PSPNet模型实现二维图像的语义分割,下面的代码从训练得到的h5文件开始,一步一步往下做。

也不知道是自己的检索能力出现了问题还是咋回事,搜遍全网都没有可以直接拿来用的语义分割代码,东拼西凑,算是搞成功了。

实验平台:Windows、VS2015、Tensorflow1.8 api、Python3.6

具体的流程为:keras训练模型 --> model.h5 --> 转换成.pb文件 --> tensorflow 载入.pb 验证正确性 --> tensorflow C++ api调用 .pb文件

1. 将训练好的h5模型转换为pb文件

# convert .h5 to .pb
import tensorflow as tf
from tensorflow.python.framework import graph_io
from keras import backend as K
from nets.pspnet import pspnet
from keras.models import load_model
from keras.models import Model, load_model
from keras.models import model_from_json


def freeze_session(session, keep_var_names=None, output_names=None, clear_devices=True):
    """Freeze a session's graph by converting its variables to constants.

    Args:
        session: TensorFlow 1.x session whose graph is frozen.
        keep_var_names: names of global variables that must NOT be frozen;
            None means every global variable is frozen.
        output_names: names of the output ops to keep; the names of all global
            variables are appended to this list before conversion.
        clear_devices: when True, the ``device`` field of every node in the
            GraphDef is blanked before conversion.

    Returns:
        The GraphDef returned by ``convert_variables_to_constants``.
    """
    from tensorflow.python.framework.graph_util import convert_variables_to_constants

    graph = session.graph
    with graph.as_default():
        global_var_names = [v.op.name for v in tf.global_variables()]
        freeze_var_names = list(set(global_var_names).difference(keep_var_names or []))
        # Build a new list instead of using ``+=`` so the caller's
        # ``output_names`` list is not modified as a side effect.
        output_names = list(output_names or []) + global_var_names
        input_graph_def = graph.as_graph_def()
        if clear_devices:
            for node in input_graph_def.node:
                node.device = ""
        frozen_graph = convert_variables_to_constants(
            session, input_graph_def, output_names, freeze_var_names)
        return frozen_graph


# Put Keras in inference mode before the model is loaded.
K.set_learning_phase(0)
keras_model = load_model('./model.h5')
# If the .h5 file does not contain the model architecture, load the JSON
# description first and then load the weights:
#
# json_file = '/model.json'
# with open(json_file, 'r') as f:
#     json_str = f.read()
# keras_model = model_from_json(json_str)
# keras_model.load_weights('./model.h5')

# .inputs and .outputs are very important; write them down, the tensor names
# are needed when the .pb file is loaded later.
print('Inputs are:', keras_model.inputs)
print('Outputs are:', keras_model.outputs)

# Save the frozen graph as a binary .pb file.
frozen_graph = freeze_session(K.get_session(), output_names=[keras_model.output.op.name])
graph_io.write_graph(frozen_graph, "./", "model.pb", as_text=False)

2. 在Python中验证pb文件是否可用

import numpy as np
import tensorflow as tf
import cv2
from PIL import Image
# INPUT_TENSOR_NAME and OUTPUT_TENSOR_NAME come from the inputs/outputs printed in step 1.
INPUT_TENSOR_NAME = 'input_1:0'
OUTPUT_TENSOR_NAME = 'main/truediv:0'
INPUT_SIZE = 473
colors = [(73, 73, 73), (0, 255, 255), (255, 255, 0)]
num_classes = 3

with tf.gfile.FastGFile('./model.pb', "rb") as f:
    graph_def = tf.GraphDef()
    graph_def.ParseFromString(f.read())

# tf.import_graph_def returns None when return_elements is not given, so the
# graph is imported into an explicit tf.Graph that the session then uses.
g_in = tf.Graph()
with g_in.as_default():
    tf.import_graph_def(graph_def, name="")
sess = tf.Session(graph=g_in)


def run(image):
    """Segment one image with the frozen graph and return a colour mask.

    Args:
        image: numpy array of shape (height, width, channels), as returned by
            cv2.imread.

    Returns:
        A float array of shape (height, width, 3) in which every pixel holds
        the entry of ``colors`` for its predicted class, resized back to the
        size of the input image.
    """
    height, width = image.shape[0:2]

    # Min-max normalisation is essential (the author lost more than a day to
    # it). A constant image has zero range, so guard the division.
    lowest = image.min()
    value_range = image.max() - lowest
    if value_range == 0:
        image = np.zeros(image.shape, dtype=np.float64)
    else:
        image = (image - lowest) / value_range
    resized_image = cv2.resize(image, (INPUT_SIZE, INPUT_SIZE))

    input_x = sess.graph.get_tensor_by_name(INPUT_TENSOR_NAME)
    out_softmax = sess.graph.get_tensor_by_name(OUTPUT_TENSOR_NAME)
    batch_seg_map = sess.run(
        out_softmax,
        feed_dict={input_x: [np.asarray(resized_image)]})

    # batch_seg_map is [1, 473, 473, 3]; batch_seg_map[0] is [473, 473, 3].
    # seg_map is the prediction: the most probable class index per pixel.
    seg_map = batch_seg_map[0]
    seg_map = seg_map.argmax(axis=-1).reshape([INPUT_SIZE, INPUT_SIZE])

    # Paint every pixel with the colour assigned to its class index.
    seg_img = np.zeros((np.shape(seg_map)[0], np.shape(seg_map)[1], 3))
    for c in range(num_classes):
        class_mask = seg_map[:, :] == c
        seg_img[:, :, 0] += (class_mask * (colors[c][0])).astype('uint8')
        seg_img[:, :, 1] += (class_mask * (colors[c][1])).astype('uint8')
        seg_img[:, :, 2] += (class_mask * (colors[c][2])).astype('uint8')

    image = cv2.resize(seg_img, (int(width), int(height)))
    return image


# cv2.imread returns None when the file cannot be read; fail with a clear message.
# NOTE(review): cv2.imread yields BGR channel order -- confirm this matches
# the channel order used when the model was trained.
input_image = cv2.imread('./img/image.jpg')
if input_image is None:
    raise IOError("cannot read './img/image.jpg'")
seg_map = run(input_image)
cv2.imwrite("./out.jpg", seg_map)

3. 在C++程序中调用pb文件

坑太多,小小的错误可能就得不到预期结果。

// NOTE(review): the names of the first five headers were lost when this
// listing was extracted; the headers below are the ones the visible code
// needs (std::cout, std::string, std::vector, OpenCV). Confirm against the
// original project.
#include <iostream>
#include <string>
#include <vector>
#include <opencv2/opencv.hpp>
#include "tensorflow/core/public/session.h"
#include "tensorflow/core/graph/default_device.h"

using namespace std;
using namespace cv;

// Network input size and number of segmentation classes.
// NOTE(review): the Python check earlier in this article uses INPUT_SIZE = 473
// and num_classes = 3 -- confirm these values match the exported model.
#define INPUT_W 474
#define INPUT_H 474
#define num_classes 4

// The function below converts an OpenCV Mat into a TensorFlow tensor.
// In Python the array returned by cv2.imread() can be fed to the network
// directly. In C++ the network input must be a tensorflow::Tensor, so an
// image read with OpenCV (a cv::Mat) has to be copied into one first.
void CVMat_to_Tensor(cv::Mat& img, tensorflow::Tensor& output_tensor, int input_rows, int input_cols)
{ Mat resize_img;resize(img, resize_img, cv::Size(input_cols, input_rows));Mat dst = resize_img.reshape(1, 1);// 第二个坑 rgb图像的归一化for (int i = 0; i < dst.cols; i++) { dst.at<float>(0, i) = dst.at<float>(0, i) / 255.0;}resize_img = dst.reshape(3, INPUT_H);float * p = (&output_tensor)->flat<float>().data();cv::Mat tempMat(input_rows, input_cols, CV_32FC3, p);resize_img.convertTo(tempMat, CV_32FC3);}void tensor2Mat(tensorflow::Tensor &t, cv::Mat &image) { float *p = t.flat<float>().data();image = Mat(INPUT_H, INPUT_W, CV_32FC3, p); //根据分类个数来,现在是3分类,如果是4分类就写成CV_32FC4}int main(int argc, char ** argv)
{ /* --------------------Configuration key information------------------------- -----------*/std::string model_path = "./psp_1w_resnet50_wudongjie.pb"; // pb model addressstd::string image_path = "./model/image.jpg"; // test pictureint input_height = INPUT_H; // Enter the image height of the networkint input_width = INPUT_W; // input network image widthstd::string input_tensor_name = "input_1:0"; // The name of the input node of the networkstd::string output_tensor_name = "main/truediv:0"; // The name of the output node of the network/* --------------------Create session------------------------------------*/tensorflow::Session * session;tensorflow::Status status = tensorflow::NewSession(tensorflow::SessionOptions(), &session); // Create a new session Session/* --------------------Read model from pb file------------------------------------*/tensorflow::GraphDef graphdef; //Define a graph for the current modeltensorflow::Status status_load = tensorflow::ReadBinaryProto(tensorflow::Env::Default(), model_path, &graphdef); // read graph model from pb fileif (!status_load.ok()) // Determine whether the read model is correct, if it is wrong, print out the wrong information{ std::cout << "ERROR: Loading model failed..." << model_path << std::endl;std::cout << status_load.ToString() << "
";return -1;}tensorflow::Status status_create = session->Create(graphdef); // Import the model into the session Sessionif (!status_create.ok()) // Determine whether the model is imported into the session successfully, if it is wrong, print out the error message{ std::cout << "ERROR: Creating graph in session failed..." << status_create.ToString() << std::endl;return -1;}std::cout << "<------Sucessfully created session and load graph------>" << std::endl;/* --------------------Load test picture------------------------ ------------*/cv::Mat img = cv::imread(image_path, -1); // read image, read grayscale imageimg.convertTo(img, CV_32FC3);//第一个小坑,整个程序都读取Floatif (img.empty()){ std::cout << "can't open the image!!!!!" << std::endl;return -1;}// Create a tensor as the input network interfacetensorflow::Tensor resized_tensor(tensorflow::DT_FLOAT, 		    tensorflow::TensorShape({ 1, input_height,input_width, 3}));// Save the Mat format picture read by opencv into tensorCVMat_to_Tensor(img, resized_tensor, input_height, input_width);std::cout << resized_tensor.DebugString() << std::endl;/* --------------------Test with network------------------------ ------------*/std::cout << std::endl << "<------------------Runing the model with test_image------------------->" << std::endl;// Run forward, the output result must be a vector of tensorstd::vector<tensorflow::Tensor> outputs;std::string output_node = output_tensor_name; // output node nametensorflow::Status status_run = session->Run({  {  input_tensor_name, resized_tensor } }
                

更多相关:

  • 上篇笔记中梳理了一把 resolver 和 balancer,这里顺着前面的流程走一遍入口的 ClientConn 对象。ClientConn// ClientConn represents a virtual connection to a conceptual endpoint, to // perform RPCs. // //...

  • Path Tracing 懒得翻译了,相信搞图形学的人都能看得懂,2333 Path Tracing is a rendering algorithm similar to ray tracing in which rays are cast from a virtual camera and traced through a s...

  • configure_file(<input> <output> [COPYONLY] [ESCAPE_QUOTES] [@ONLY] [NEWLINE_STYLE [UNIX|DOS|WIN32|LF|CRLF] ]) 我遇到的是 configure_file(config/config.in ${CMAKE_SOURCE_DIR}/...

  •     直接复制以下代码创建一个名为settings.xml的文件,放到C:\Users\Administrator\.m2下即可