Save, Load and Inference From TensorFlow Frozen Graph

Introduction

TensorFlow model saving has become easier than it was in the early days. Now you can either use Keras to save the model in h5 format or use tf.train.Saver to save checkpoint files. Loading those saved models is also easy: you can find plenty of instructions in the official TensorFlow tutorials. There is another model format called pb, which is frequently seen in model zoos but hardly mentioned by official TensorFlow channels. pb stands for Protocol Buffers, a language-neutral, platform-neutral, extensible mechanism for serializing structured data. It is widely used in model deployment, for example by the fast inference tool TensorRT. While pb format models seem to be important, there is a lack of systematic tutorials on how to save, load, and do inference on pb format models in TensorFlow.

In this blog post, I am going to introduce how to save, load, and run inference on a frozen graph in TensorFlow 1.x. For the equivalent tasks in TensorFlow 2.x, please read the other blog post “Save, Load and Inference From TensorFlow 2.x Frozen Graph”.

Materials

The sample code is available on my GitHub. It was modified from my previous simple CNN model for classifying the CIFAR10 dataset.

Train Model

We have to train our model first. Train the model using the following command:

$ python main.py --train --test --epoch 30 --lr_decay 0.9 --dropout 0.5

The test accuracy after training is around 0.793900.

Save PB Model

The major components of a pb file are the graph structure and the parameters of your model. While the parameters are technically optional in a pb file, we need them for our task, since we use them to do inference. Otherwise, people who download your pb file will not be able to deploy it.

This is the key code to save the pb file:

import os

import tensorflow as tf
from tensorflow.python.framework import graph_util
from tensorflow.python.tools import freeze_graph

# The two functions below are methods of the model class;
# self.sess and self.saver are set up elsewhere in the class.

def save(self, directory, filename):

    if not os.path.exists(directory):
        os.makedirs(directory)
    filepath = os.path.join(directory, filename + '.ckpt')
    self.saver.save(self.sess, filepath)
    return filepath

def save_as_pb(self, directory, filename):

    if not os.path.exists(directory):
        os.makedirs(directory)

    # Save a checkpoint so the graph can be frozen later
    ckpt_filepath = self.save(directory=directory, filename=filename)
    pbtxt_filename = filename + '.pbtxt'
    pbtxt_filepath = os.path.join(directory, pbtxt_filename)
    pb_filepath = os.path.join(directory, filename + '.pb')
    # This will only save the graph; the variables will not be saved.
    # You have to freeze your model first.
    tf.train.write_graph(graph_or_graph_def=self.sess.graph_def, logdir=directory, name=pbtxt_filename, as_text=True)

    # Freeze graph
    # Method 1
    freeze_graph.freeze_graph(input_graph=pbtxt_filepath, input_saver='', input_binary=False, input_checkpoint=ckpt_filepath, output_node_names='cnn/output', restore_op_name='save/restore_all', filename_tensor_name='save/Const:0', output_graph=pb_filepath, clear_devices=True, initializer_nodes='')

    # Method 2
    '''
    graph = tf.get_default_graph()
    input_graph_def = graph.as_graph_def()
    output_node_names = ['cnn/output']

    output_graph_def = graph_util.convert_variables_to_constants(self.sess, input_graph_def, output_node_names)
    # For some models, we would like to remove training nodes
    # output_graph_def = graph_util.remove_training_nodes(output_graph_def, protected_nodes=None)

    with tf.gfile.GFile(pb_filepath, 'wb') as f:
        f.write(output_graph_def.SerializeToString())
    '''

    return pb_filepath

You are required to save a checkpoint of your model first, followed by saving the graph. Saving a checkpoint is easy: you just have to use tf.train.Saver and everything should be straightforward. In my code, I wrapped the checkpoint saving with tf.train.Saver in the self.save method. To save the graph, use tf.train.write_graph. There are two arguments which might be confusing to new users, name and as_text. as_text is a boolean value indicating whether the saved graph is human-readable or not. By convention, if it is human-readable, the file extension we use will be .pbtxt, else the file extension will be .pb. But this pb file will not contain the parameters you trained in your model.
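As a quick illustration, here is a minimal sketch of the two conventions, reusing the variables from save_as_pb above ('graph.pbtxt' and 'graph.pb' are placeholder file names, not from my code):

# Human-readable text format, saved with the .pbtxt extension by convention
tf.train.write_graph(graph_or_graph_def=self.sess.graph_def, logdir=directory, name='graph.pbtxt', as_text=True)
# Binary format, saved with the .pb extension; still graph-only, no parameters
tf.train.write_graph(graph_or_graph_def=self.sess.graph_def, logdir=directory, name='graph.pb', as_text=False)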

We then need to freeze the graph, combining the graph and the parameters into one pb file. There are two ways to freeze a graph.

The first method is to use the freeze_graph function. The argument descriptions of freeze_graph could be found here. If input_graph is a human-readable pbtxt file, input_binary should be False. If input_graph is a binary pb file, input_binary should be True. You will also need to specify the name of your output node; for freeze_graph this is a string, with multiple output node names separated by commas. restore_op_name and filename_tensor_name are being deprecated, and using the values provided here should be universal to all models. Leaving the rest of the arguments the same as mine should be fine. The pb file will be saved to the output_graph path you provided.
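For example, if you had written the graph in binary format instead of pbtxt, the call would look roughly like this (a sketch; the input path 'model/graph.pb' is a hypothetical placeholder, the remaining variables come from save_as_pb above):

# Freezing from a binary graph-only pb file instead of a pbtxt file
freeze_graph.freeze_graph(input_graph='model/graph.pb',  # hypothetical binary GraphDef path
                          input_saver='',
                          input_binary=True,  # must match the input graph format
                          input_checkpoint=ckpt_filepath,
                          output_node_names='cnn/output',
                          restore_op_name='save/restore_all',  # being deprecated
                          filename_tensor_name='save/Const:0',  # being deprecated
                          output_graph=pb_filepath,
                          clear_devices=True,
                          initializer_nodes='')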

The second method is to do the serialization yourself. I believe the first method is just a higher-level wrapper for the second. The pb files generated from the two methods both pass the accuracy tests that I am going to show below.

The model files generated in the model directory are as follows:

.
├── checkpoint
├── cifar10_cnn.ckpt.data-00000-of-00001
├── cifar10_cnn.ckpt.index
├── cifar10_cnn.ckpt.meta
├── cifar10_cnn.pb
└── cifar10_cnn.pbtxt

The pb file is there!
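If you want to convince yourself that the parameters really made it into the pb file, a quick sanity check (a sketch, assuming the model directory and file names above) is to count Const nodes, since freezing converts variables into constants:

import tensorflow as tf

graph_def = tf.GraphDef()
with tf.gfile.GFile('model/cifar10_cnn.pb', 'rb') as f:
    graph_def.ParseFromString(f.read())

# A frozen graph stores the weights as Const nodes and should contain no Variable nodes
print('Const nodes: %d' % sum(1 for n in graph_def.node if n.op == 'Const'))
print('Variable nodes: %d' % sum(1 for n in graph_def.node if n.op in ('Variable', 'VariableV2')))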

Load PB Model

We wrote an object to load the model from pb files.

import numpy as np
import tensorflow as tf


class CNN(object):

    def __init__(self, model_filepath):

        # The file path of the model
        self.model_filepath = model_filepath
        # Initialize the model
        self.load_graph(model_filepath=self.model_filepath)

    def load_graph(self, model_filepath):
        '''
        Load trained model.
        '''
        print('Loading model...')
        self.graph = tf.Graph()
        self.sess = tf.InteractiveSession(graph=self.graph)

        with tf.gfile.GFile(model_filepath, 'rb') as f:
            graph_def = tf.GraphDef()
            graph_def.ParseFromString(f.read())

        print('Check out the input placeholders:')
        nodes = [n.name + ' => ' + n.op for n in graph_def.node if n.op == 'Placeholder']
        for node in nodes:
            print(node)

        # Define input tensors
        self.input = tf.placeholder(np.float32, shape=[None, 32, 32, 3], name='input')
        self.dropout_rate = tf.placeholder(tf.float32, shape=[], name='dropout_rate')

        tf.import_graph_def(graph_def, {'input': self.input, 'dropout_rate': self.dropout_rate})

        print('Model loading complete!')

        '''
        # Get layer names
        layers = [op.name for op in self.graph.get_operations()]
        for layer in layers:
            print(layer)
        '''

        '''
        # Check out the weights of the nodes
        # from tensorflow.python.framework import tensor_util
        weight_nodes = [n for n in graph_def.node if n.op == 'Const']
        for n in weight_nodes:
            print("Name of the node - %s" % n.name)
            print("Value - ")
            print(tensor_util.MakeNdarray(n.attr['value'].tensor))
        '''

    def test(self, data):

        # Know your output node name
        output_tensor = self.graph.get_tensor_by_name("import/cnn/output:0")
        output = self.sess.run(output_tensor, feed_dict={self.input: data, self.dropout_rate: 0})

        return output

Working with models loaded from pb files is a little bit painful since you have to work with tensor names all the time. If you are not sure about the tensor names you are working with, try printing out the names from graph_def.node. In our case, because we are going to do inference, we need to bind the inputs of the graph to placeholders so that we can feed values into the model. The values of the parameters are also available via graph_def.node. Here I attached two placeholders to the graph using tf.import_graph_def(graph_def, {'input': self.input, 'dropout_rate': self.dropout_rate}). It should be noted that 'input' and 'dropout_rate' are the names of the input nodes I defined in the original graph.
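For instance, a small sketch to discover the node names before binding the placeholders:

# Print every node name and op type in the loaded GraphDef
for n in graph_def.node:
    print('%s => %s' % (n.name, n.op))

# Tensor names are node names with an output index appended, e.g. 'cnn/output:0'.
# After tf.import_graph_def with the default name scope, they are additionally
# prefixed with 'import/', e.g. 'import/cnn/output:0'.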

We also set up the test method. Simply find the tensor you are interested in, in our case the output tensor, and feed the input values to it using sess.run.

Inference from PB Model

To verify that our loaded graph is correct and working, we run some inference as a test.

import numpy as np
import tensorflow as tf

# CIFAR10, CNN and model_accuracy are defined elsewhere in the repository

def test_from_frozen_graph(model_filepath):

    tf.reset_default_graph()

    # Load CIFAR10 dataset
    cifar10 = CIFAR10()
    x_test = cifar10.x_test
    y_test = cifar10.y_test
    y_test_onehot = cifar10.y_test_onehot
    num_classes = cifar10.num_classes
    input_size = cifar10.input_size

    # Test 500 samples
    x_test = x_test[0:500]
    y_test = y_test[0:500]

    model = CNN(model_filepath=model_filepath)

    test_prediction_onehot = model.test(data=x_test)
    test_prediction = np.argmax(test_prediction_onehot, axis=1).reshape((-1, 1))
    test_accuracy = model_accuracy(label=y_test, prediction=test_prediction)

    print('Test Accuracy: %f' % test_accuracy)

Run the following command to test:

$ python test_pb.py

Here I tested 500 samples from the test set. If you want to test all the examples, you can write a for loop to do so, as sketched below. The test accuracy is 0.788000. Compared to the test accuracy of 0.793900 we got right after training, this suggests that the pb file we saved is valid.
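Such a loop might look like this (a sketch, assuming x_test and y_test hold the full test set and reusing the CNN object and model_accuracy function from above):

# Evaluate the whole test set in batches instead of only the first 500 samples
batch_size = 500
test_predictions = []
for i in range(0, x_test.shape[0], batch_size):
    batch_prediction_onehot = model.test(data=x_test[i:i + batch_size])
    test_predictions.append(np.argmax(batch_prediction_onehot, axis=1).reshape((-1, 1)))
test_prediction = np.vstack(test_predictions)
test_accuracy = model_accuracy(label=y_test, prediction=test_prediction)
print('Test Accuracy: %f' % test_accuracy)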

Updates

2019/9/16

Thanks to a question raised by Yuqiong Li, I removed the usage of tf.InteractiveSession and replaced it with tf.Session. The new object for loading pb files is as follows.

import numpy as np
import tensorflow as tf


class CNN(object):

    def __init__(self, model_filepath):

        # The file path of the model
        self.model_filepath = model_filepath
        # Initialize the model
        self.load_graph(model_filepath=self.model_filepath)

    def load_graph(self, model_filepath):
        '''
        Load trained model.
        '''
        print('Loading model...')
        self.graph = tf.Graph()

        with tf.gfile.GFile(model_filepath, 'rb') as f:
            graph_def = tf.GraphDef()
            graph_def.ParseFromString(f.read())

        print('Check out the input placeholders:')
        nodes = [n.name + ' => ' + n.op for n in graph_def.node if n.op == 'Placeholder']
        for node in nodes:
            print(node)

        with self.graph.as_default():
            # Define input tensors
            self.input = tf.placeholder(np.float32, shape=[None, 32, 32, 3], name='input')
            self.dropout_rate = tf.placeholder(tf.float32, shape=[], name='dropout_rate')
            tf.import_graph_def(graph_def, {'input': self.input, 'dropout_rate': self.dropout_rate})

        self.graph.finalize()

        print('Model loading complete!')

        # Get layer names
        layers = [op.name for op in self.graph.get_operations()]
        for layer in layers:
            print(layer)

        """
        # Check out the weights of the nodes
        weight_nodes = [n for n in graph_def.node if n.op == 'Const']
        for n in weight_nodes:
            print("Name of the node - %s" % n.name)
            # print("Value - ")
            # print(tensor_util.MakeNdarray(n.attr['value'].tensor))
        """

        # In this version, tf.InteractiveSession and tf.Session can be used interchangeably.
        # self.sess = tf.InteractiveSession(graph=self.graph)
        self.sess = tf.Session(graph=self.graph)

    def test(self, data):

        # Know your output node name
        output_tensor = self.graph.get_tensor_by_name("import/cnn/output:0")
        output = self.sess.run(output_tensor, feed_dict={self.input: data, self.dropout_rate: 0})

        return output

There was nothing wrong with the previous implementation, but I placed the tf.InteractiveSession before the graphdef was loaded into the default graph, taking advantage of the side effect that tf.InteractiveSession sets its corresponding graph as the default graph globally. Therefore, simply replacing tf.InteractiveSession with tf.Session would not work in the previous implementation. This might cause some confusion for readers who really want to understand what is happening underneath. In the new implementation, I explicitly created the default graph using a Python context manager and loaded the graphdef into it. No side effect is used, and therefore it should be much easier to understand.
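A minimal sketch of the side effect in isolation:

import tensorflow as tf

# tf.InteractiveSession installs itself as the default session and its graph
# as the default graph, so new ops silently land in that graph
graph = tf.Graph()
sess = tf.InteractiveSession(graph=graph)
x = tf.constant(1.0)
assert x.graph is graph

# tf.Session does not touch the default graph; you have to be explicit
another_graph = tf.Graph()
with another_graph.as_default():
    y = tf.constant(2.0)
another_sess = tf.Session(graph=another_graph)
assert y.graph is another_graph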

2020/1/9

This blog and example were designed for TensorFlow 1.x. TensorFlow 2.x also supports the frozen graph. Please check the blog post “Save, Load and Inference From TensorFlow 2.x Frozen Graph”.

Final Remarks

Now you should be good to go with pb files in your deployment!

One additional caveat is that TensorFlow has started deprecating or changing a lot of APIs, including parts of freeze_graph. We have to keep ourselves updated on those functions.
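If you ever need to run the 1.x-style code from this post under TensorFlow 2.x, one option (a hedged sketch, not something this post covers in depth) is the compatibility module:

# TensorFlow 2.x: restore the 1.x API surface used in this post
import tensorflow.compat.v1 as tf
tf.disable_v2_behavior()

# tf.Session, tf.placeholder, tf.GraphDef, tf.gfile, tf.import_graph_def, etc.
# are then available as in TensorFlow 1.x.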
