Building a REST API with Tensorflow Serving (part 1)

What is Tensorflow Serving?

One of the features of Tensorflow that I personally think is undervalued is its capability to serve Tensorflow models. At the time of writing this post, the API that helps you do that is called Tensorflow Serving, and it is part of the Tensorflow Extended ecosystem, or TFX for short.

During the first releases of Tensorflow Serving, I found the documentation somewhat daunting. There were also many concepts that data scientists are not used to working with: servables, sources, loaders, managers… All these elements are part of the Tensorflow Serving architecture. More details on it here.

So, as a gentle introduction, I will show you how to build a REST API with Tensorflow Serving, from saving a Tensorflow object (this is what we call a servable) to testing the API endpoint. This first part explains how to create and save Tensorflow objects that are ready to be put into production.

The meaning of servables

Functions, embeddings or saved models are some of the objects that can be used as servables. But how do we define those servables in Tensorflow?

Well, this is up to you, but they must be savable in what is called the SavedModel format. This format preserves all the components of a Tensorflow object, so the object is in the same state when we load it in a new environment. What are these components? The relevant ones are the weights, the graph, additional assets, etc.
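The round trip described above can be sketched with a tiny module. Counter is a hypothetical class used only for illustration: it holds one tf.Variable (a weight) and one tf.function (a graph), and the snippet checks that the variable's value survives saving and loading into a fresh object. The temporary folder is an assumption to keep the example self-contained.

```python
import tempfile
import tensorflow as tf

# Hypothetical module: one variable (the "weights") and one traced function (the "graph")
class Counter(tf.Module):
    def __init__(self):
        super().__init__()
        self.count = tf.Variable(0, dtype=tf.int32)

    @tf.function(input_signature=[])
    def increment(self):
        self.count.assign_add(1)
        return self.count.read_value()

counter = Counter()
counter.increment()
counter.increment()  # count is now 2

export_dir = tempfile.mkdtemp()
tf.saved_model.save(counter, export_dir)

# Loading restores both the variable value and the traced function
restored = tf.saved_model.load(export_dir)
print(restored.count.numpy())        # 2
print(restored.increment().numpy())  # 3
```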

The module used to save Tensorflow objects is tf.saved_model, and as we’ll see shortly, it’s simple to use. For now, let’s see how we can generate two types of servables:

  • Tensorflow functions
  • Keras models

TensorFlow function as servable

Tensorflow functions are saved as valid servables if they are defined in this way:

import tensorflow as tf

class Adder(tf.Module):
    @tf.function(input_signature=[tf.TensorSpec(shape=[None, 3], dtype=tf.float32, name="x")])
    def sum_two(self, x):
        return x + 2

  • The function is defined inside a class
  • The parent class has to be tf.Module
  • The @tf.function decorator traces the Python function into a Tensorflow graph
  • The input_signature argument defines the dtype and shape of the tensors that can be passed to the function
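A quick way to see what the input_signature does is to call the method eagerly. The sketch below re-declares Adder so it runs on its own; the inputs are illustrative. It catches both TypeError and ValueError because the exception type raised for a signature mismatch may differ between TensorFlow versions.

```python
import tensorflow as tf

class Adder(tf.Module):
    @tf.function(input_signature=[tf.TensorSpec(shape=[None, 3], dtype=tf.float32, name="x")])
    def sum_two(self, x):
        return x + 2

adder = Adder()

# Any number of rows is accepted, as long as each row has three float32 values
result = adder.sum_two(tf.constant([[1.0, 2.0, 3.0], [0.0, 0.0, 0.0]]))
print(result.numpy())

# A second dimension of 2 does not match the declared shape [None, 3]
rejected = False
try:
    adder.sum_two(tf.constant([[1.0, 2.0]]))
except (TypeError, ValueError):
    rejected = True
print("rejected:", rejected)
```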

In this case we accept two-dimensional float32 tensors, where the first dimension can have any length and the second one is fixed to three elements. Another example of a Tensorflow function is:

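class Randomizer(tf.Module):
    @tf.function(input_signature=[tf.TensorSpec(shape=[], dtype=tf.int32, name="N")])
    def fun_runif(self, N):
        # N uniform random numbers in [0, 1)
        return tf.random.uniform(shape=(N,))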

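Notice that the input_signature argument is not strictly required by @tf.function, but it is always good to include it as a safety check when functions like this go into production: inputs that do not match the declared shape and dtype are rejected.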
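Calling a Randomizer shows that the length of the output follows N. This sketch assumes N is declared as a scalar int32 tensor in the input_signature, and repeats the class so the snippet runs on its own.

```python
import tensorflow as tf

class Randomizer(tf.Module):
    # Assumed signature: N is a scalar int32 tensor
    @tf.function(input_signature=[tf.TensorSpec(shape=[], dtype=tf.int32, name="N")])
    def fun_runif(self, N):
        return tf.random.uniform(shape=(N,))

randomizer = Randomizer()
sample = randomizer.fun_runif(tf.constant(5))
print(sample.shape)  # (5,)
```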

Next, we create an instance of each class and save it:

# For the first function
myfun = Adder()
tf.saved_model.save(myfun, "tmp/sum_two/1")

# For the second function
myfun2 = Randomizer()
tf.saved_model.save(myfun2, "tmp/fun_runif/1")

The first argument of tf.saved_model.save points to the instance of the class, whereas the second argument is the path in your local filesystem where the model is going to be saved.
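To check that a saved servable is self-contained, you can load it back with tf.saved_model.load and call it. This minimal sketch assumes a temporary folder instead of tmp/. It also prints the signature names, because the Tensorflow Serving REST API calls a model through a signature (serving_default by default); a tf.Module with exactly one tf.function that has an input_signature gets that function exported as its default signature.

```python
import os
import tempfile
import tensorflow as tf

class Adder(tf.Module):
    @tf.function(input_signature=[tf.TensorSpec(shape=[None, 3], dtype=tf.float32, name="x")])
    def sum_two(self, x):
        return x + 2

# Version 1 of the servable lives in its own numbered subfolder
export_path = os.path.join(tempfile.mkdtemp(), "sum_two", "1")
tf.saved_model.save(Adder(), export_path)

# Load it as if in a new environment and call the restored function
loaded = tf.saved_model.load(export_path)
out = loaded.sum_two(tf.constant([[0.0, 1.0, 2.0]]))
print(out.numpy())  # [[2. 3. 4.]]
print(list(loaded.signatures.keys()))
```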

Keras model as servable

You can follow a similar procedure for saving Keras models. This example focuses on a pretrained image classification model, loaded with TensorFlow Hub. In addition, we are going to build a custom class around it to preprocess input images.

import tensorflow as tf
import tensorflow_hub as hub

class CustomMobileNet_string(tf.keras.Model):
    model_handler = ""  # TF Hub handle of the pretrained MobileNet V2 classifier
    def __init__(self):
        super().__init__()
        self.model = hub.load(self.__class__.model_handler)
        self.labels = None  # replaced by a tf.saved_model.Asset before saving
    # Design your API with the 'tf.function' decorator
    @tf.function(input_signature=[tf.TensorSpec(shape=None, dtype=tf.string)])
    def call(self, input_img):
        def _preprocess(img_file):
            img_bytes = tf.reshape(img_file, [])  # one-element array -> scalar string
            img = tf.io.decode_jpeg(img_bytes, channels=3)
            img = tf.image.convert_image_dtype(img, tf.float32)
            return tf.image.resize(img, (224, 224))

        labels = tf.io.read_file(self.labels)
        labels = tf.strings.split(labels, sep='\n')
        img = _preprocess(input_img)[tf.newaxis, :]
        logits = self.model(img)
        get_class = lambda x: labels[tf.argmax(x)]
        class_text = tf.map_fn(get_class, logits, fn_output_signature=tf.string)
        return class_text  # text label of the predicted class

The class inherits from tf.keras.Model, and there are a few things to discuss about it:

  1. The input to the model is a string of bytes, which comes in a JSON request with a specific set of key-value pairs. More on that in the second part of the tutorial.
  2. tf.reshape is the first function in the preprocessing stage because the image arrives wrapped in an array in the JSON payload. Since the input_signature in @tf.function accepts a string tensor of any shape, we reshape it to a scalar, which is what the image decoder expects.
  3. The labels attribute stores the ImageNet labels (available here). We want the model to output the label in text format rather than the index of the output layer. The reason why it is defined this way is explained below.
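The preprocessing and label lookup can be tried without downloading the Hub model. This sketch mirrors the _preprocess logic as a standalone function and, as assumptions for illustration only, uses a synthetic JPEG instead of a real photo and a hypothetical three-class label string instead of ImageNetLabels.txt.

```python
import tensorflow as tf

def preprocess(img_file):
    # The image arrives as a one-element array; decode_jpeg needs a scalar string
    img_bytes = tf.reshape(img_file, [])
    img = tf.io.decode_jpeg(img_bytes, channels=3)
    img = tf.image.convert_image_dtype(img, tf.float32)  # uint8 [0, 255] -> float32 [0, 1]
    return tf.image.resize(img, (224, 224))

# Synthetic 64x48 RGB JPEG, wrapped in a one-element array like the JSON payload
fake_pixels = tf.random.uniform((64, 48, 3), maxval=256, dtype=tf.int32)
jpeg_bytes = tf.io.encode_jpeg(tf.cast(fake_pixels, tf.uint8))
batch = preprocess(tf.reshape(jpeg_bytes, [1]))[tf.newaxis, :]
print(batch.shape)  # (1, 224, 224, 3)

# Map logits to a text label using a hypothetical label list
labels = tf.strings.split(tf.constant("cat\ndog\nbird"), sep="\n")
logits = tf.constant([[0.1, 2.5, 0.3]])
get_class = lambda x: labels[tf.argmax(x)]
predicted = tf.map_fn(get_class, logits, fn_output_signature=tf.string)
print(predicted.numpy())  # [b'dog']
```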

As usual, we save the model following the same procedure, with a small addition:

model_string = CustomMobileNet_string()
# Save the image labels as an asset, stored in the 'assets' folder
model_string.labels = tf.saved_model.Asset("data/labels/ImageNetLabels.txt")
tf.saved_model.save(model_string, "tmp/mobilenet_v2_test/1/")

To store extra components in the SavedModel object, we have to define an asset. We do so with tf.saved_model.Asset. I call it outside the class definition to make it more explicit, although it will probably work the same way inside the class definition. Notice that we have to set the asset as an attribute of the class instance before saving the model.
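The asset mechanism can be checked on a small scale. In this sketch, LabelReader and labels.txt are hypothetical stand-ins for the MobileNet class and ImageNetLabels.txt: the asset is attached as an attribute before saving, the file is copied into the assets folder, and the loaded model still reads it after the original file is deleted.

```python
import os
import tempfile
import tensorflow as tf

# Hypothetical label file standing in for ImageNetLabels.txt
work_dir = tempfile.mkdtemp()
labels_path = os.path.join(work_dir, "labels.txt")
with open(labels_path, "w") as f:
    f.write("cat\ndog\nbird")

class LabelReader(tf.Module):
    @tf.function(input_signature=[])
    def read_labels(self):
        # self.labels is attached after construction, as in the post
        return tf.strings.split(tf.io.read_file(self.labels), sep="\n")

reader = LabelReader()
reader.labels = tf.saved_model.Asset(labels_path)  # attach before saving
export_dir = os.path.join(work_dir, "label_reader", "1")
tf.saved_model.save(reader, export_dir)
print(os.listdir(os.path.join(export_dir, "assets")))  # ['labels.txt']

# The loaded model reads the copy inside the SavedModel, not the original file
os.remove(labels_path)
loaded = tf.saved_model.load(export_dir)
restored_labels = loaded.read_labels()
print(restored_labels.numpy())  # [b'cat' b'dog' b'bird']
```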

Further details

Saving the custom Keras model generates a folder structure in the local filesystem. The files generated are:

  • the graph of the function or model, saved in a Protobuf file with the extension .pb (saved_model.pb)
  • the weights of the model, or any Tensorflow Variable used in the servable, saved in the variables folder
  • extra components, saved in the assets folder (here, a copy of ImageNetLabels.txt; the Adder and Randomizer functions have nothing to put there)
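You can list the generated files yourself. This sketch saves a hypothetical Scaler module with one variable into a temporary folder and walks the export directory; the exact file names inside variables, and any extra files, may vary between TensorFlow versions, so no output is shown.

```python
import os
import tempfile
import tensorflow as tf

class Scaler(tf.Module):
    def __init__(self):
        super().__init__()
        self.factor = tf.Variable(3.0)  # a weight, written to the variables folder

    @tf.function(input_signature=[tf.TensorSpec(shape=[None], dtype=tf.float32)])
    def scale(self, x):
        return x * self.factor

export_dir = os.path.join(tempfile.mkdtemp(), "scaler", "1")
tf.saved_model.save(Scaler(), export_dir)

# Collect every file below the export folder, relative to it
saved_files = []
for root, _dirs, files in os.walk(export_dir):
    for name in files:
        saved_files.append(os.path.relpath(os.path.join(root, name), export_dir))

for path in sorted(saved_files):
    print(path)
```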

Some questions may arise when you build your own functions or models:

  • What’s the reasoning behind the choice of parent classes? Defining a tf.function as a method of a tf.Module subclass makes it trackable, so tf.saved_model.save can find it and export it together with any variables it uses. The same goes for tf.keras.Model. You can find more info here.
  • Why do we add /1 to the model’s path? Servables must have an ID indicating the version of the model we are running inside the container. Tensorflow Serving expects each version in a numbered subfolder of the model’s base path, which helps you keep track of multiple versions of your model when you are monitoring their metrics. You can find a more in-depth explanation in the following link.
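By default, Tensorflow Serving serves the latest version found under the model’s base path, and versions are compared as numbers. The helper below, latest_version, is hypothetical and only mimics that selection rule on a directory layout; it is not Tensorflow Serving code. It shows why numeric comparison matters: a plain string sort would rank "9" above "10".

```python
import os
import tempfile

def latest_version(model_base_path):
    """Hypothetical helper: return the highest numeric version folder, or None."""
    versions = [
        int(d)
        for d in os.listdir(model_base_path)
        if d.isdigit() and os.path.isdir(os.path.join(model_base_path, d))
    ]
    return max(versions) if versions else None

# Layout with several versions plus a non-version folder that is ignored
base = tempfile.mkdtemp()
for name in ["1", "2", "9", "10", "notes"]:
    os.makedirs(os.path.join(base, name))

print(latest_version(base))  # 10
```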

This is all for now. Thanks for reading!

You can continue with part 2, where we dive into the process of putting these Tensorflow objects into production.
