Convert your Tensorflow Object Detection model to Tensorflow Lite.

TFLite Object Detection Prediction Example

Use your TensorFlow Object Detection model on edge devices by converting it to TensorFlow Lite.

TensorFlow Lite is TensorFlow's lightweight solution for mobile and embedded devices. It allows you to run machine learning models on edge devices with low latency, eliminating the need for a server.

This article is a step-by-step guide to converting a TensorFlow Object Detection model to an optimized format that can be used with TensorFlow Lite, and to running it on an edge device like the Raspberry Pi.

All the code covered in the article can be found on my Github.

1. Train an object detection model using the Tensorflow Object Detection API

Figure 1: Tensorflow Object Detection Example

For this guide, you can use a pre-trained model from the Tensorflow Model zoo or train a custom model as described in one of my other Github repositories.

Note: At this time only SSD models are supported.

2. Convert the model to Tensorflow Lite

After you have a Tensorflow Object Detection model, you can start to convert it to Tensorflow Lite.

This is a three-step process:

  1. Export frozen inference graph for TFLite
  2. Build Tensorflow from source (needed for the third step)
  3. Use TOCO to create an optimized TensorFlow Lite model

2.1 Export frozen inference graph for TFLite

After training the model, you need to export it so that the graph architecture and network operations are compatible with TensorFlow Lite. This can be done with the export_tflite_ssd_graph.py file inside the object_detection directory.

To make the following commands easier to run, let's set up some environment variables:

export CONFIG_FILE=PATH_TO_BE_CONFIGURED/pipeline.config
export CHECKPOINT_PATH=PATH_TO_BE_CONFIGURED/model.ckpt-XXXX
export OUTPUT_DIR=/tmp/tflite

On Windows, use set instead of export (and reference the variables as %CONFIG_FILE% instead of $CONFIG_FILE in the commands below):

set CONFIG_FILE=PATH_TO_BE_CONFIGURED/pipeline.config
set CHECKPOINT_PATH=PATH_TO_BE_CONFIGURED/model.ckpt-XXXX
set OUTPUT_DIR=C:<path>/tflite

XXXX is the step number of the checkpoint you want to export, usually the highest one in your training directory.
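
If you are not sure which checkpoint is the latest, the number can be read from the file names. The following is a minimal sketch, assuming the checkpoints are stored as model.ckpt-<step>.index (next to the .meta and .data files) in a single training directory. The helper name latest_checkpoint_prefix is hypothetical and not part of the Object Detection API.

```python
import os
import re


def latest_checkpoint_prefix(train_dir):
    """Return the 'model.ckpt-XXXX' prefix with the highest step number, or None."""
    pattern = re.compile(r"^(model\.ckpt-(\d+))\.index$")
    best_step, best_prefix = -1, None
    for file_name in os.listdir(train_dir):
        match = pattern.match(file_name)
        if match and int(match.group(2)) > best_step:
            best_step = int(match.group(2))
            best_prefix = match.group(1)
    if best_prefix is None:
        return None
    return os.path.join(train_dir, best_prefix)


if __name__ == "__main__":
    # "training" is a placeholder; point it at your own training directory.
    if os.path.isdir("training"):
        print(latest_checkpoint_prefix("training"))
```

The returned value can be used directly as CHECKPOINT_PATH.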

python export_tflite_ssd_graph.py \
    --pipeline_config_path=$CONFIG_FILE \
    --trained_checkpoint_prefix=$CHECKPOINT_PATH \
    --output_directory=$OUTPUT_DIR \
    --add_postprocessing_op=true

After executing the above command, you should see two files in the OUTPUT_DIR: tflite_graph.pb and tflite_graph.pbtxt.

2.2 Convert to TFLite

To convert the frozen graph to Tensorflow Lite, we need to run it through the Tensorflow Lite Converter. It converts the model into an optimized FlatBuffer format that runs efficiently on Tensorflow Lite.

If you want to convert a quantized model, you can run the following command:

tflite_convert \
    --input_file=$OUTPUT_DIR/tflite_graph.pb \
    --output_file=$OUTPUT_DIR/detect.tflite \
    --input_shapes=1,300,300,3 \
    --input_arrays=normalized_input_image_tensor \
    --output_arrays='TFLite_Detection_PostProcess','TFLite_Detection_PostProcess:1','TFLite_Detection_PostProcess:2','TFLite_Detection_PostProcess:3' \
    --inference_type=QUANTIZED_UINT8 \
    --mean_values=128 \
    --std_values=128 \
    --change_concat_input_ranges=false \
    --allow_custom_ops
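
The --mean_values and --std_values flags tell the converter how the uint8 input maps to the real values the model was trained on: real_value = (quantized_value - mean_value) / std_value. Here is a small sketch of that arithmetic; the function name dequantize_input is only for illustration.

```python
def dequantize_input(quantized_value, mean_value=128, std_value=128):
    """Map a uint8 input value to the real value the model sees."""
    return (quantized_value - mean_value) / std_value


if __name__ == "__main__":
    for pixel in (0, 128, 255):
        print(pixel, "->", dequantize_input(pixel))
```

With both values set to 128, pixel values from 0 to 255 cover roughly the range -1 to 1.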

If you are using a floating-point model, you need to change the command:

tflite_convert \
    --input_file=$OUTPUT_DIR/tflite_graph.pb \
    --output_file=$OUTPUT_DIR/detect.tflite \
    --input_shapes=1,300,300,3 \
    --input_arrays=normalized_input_image_tensor \
    --output_arrays='TFLite_Detection_PostProcess','TFLite_Detection_PostProcess:1','TFLite_Detection_PostProcess:2','TFLite_Detection_PostProcess:3' \
    --inference_type=FLOAT  \
    --allow_custom_ops

If things ran successfully, you should now see a third file in the /tmp/tflite directory called detect.tflite.
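
To check that the converted file actually loads, you can open it with the interpreter and print its tensors. This is a rough sketch, assuming tflite_runtime is installed and OUTPUT_DIR is /tmp/tflite; summarize_details and inspect_model are hypothetical helper names.

```python
def summarize_details(details):
    """Turn interpreter tensor details into (name, shape, dtype) tuples."""
    summary = []
    for detail in details:
        shape = tuple(int(dim) for dim in detail["shape"])
        dtype = getattr(detail["dtype"], "__name__", str(detail["dtype"]))
        summary.append((detail["name"], shape, dtype))
    return summary


def inspect_model(model_path):
    """Load a .tflite file and print its input and output tensors."""
    # Imported here so summarize_details can be used without tflite_runtime.
    from tflite_runtime.interpreter import Interpreter

    interpreter = Interpreter(model_path=model_path)
    interpreter.allocate_tensors()
    print("Inputs:", summarize_details(interpreter.get_input_details()))
    print("Outputs:", summarize_details(interpreter.get_output_details()))


if __name__ == "__main__":
    import os

    # Path assumes OUTPUT_DIR=/tmp/tflite as set earlier.
    model_path = "/tmp/tflite/detect.tflite"
    if os.path.exists(model_path):
        inspect_model(model_path)
```

For a model converted with the commands above, you would expect one input of shape (1, 300, 300, 3) and four outputs.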

2.3 Create a new labelmap for Tensorflow Lite

Next, you need to create a label map for Tensorflow Lite since the format is different from classical TensorFlow.

Tensorflow labelmap:

item {
    name: "a"
    id: 1
    display_name: "a"
}
item {
    name: "b"
    id: 2
    display_name: "b"
}
item {
    name: "c"
    id: 3
    display_name: "c"
}

The Tensorflow Lite labelmap format only has the display_names (if there is no display_name, the name is used).

a
b
c

So basically, the only thing you need to do is create a new labelmap file and copy the display_names (names) from the other labelmap file into it.
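
Copying the names by hand gets tedious for larger label maps. The following is a minimal sketch that does the conversion, assuming the label map has the simple item { ... } layout shown above. It uses a regular expression instead of a real protobuf parser, and the function names are hypothetical.

```python
import re


def labelmap_to_lines(pbtxt_text):
    """Return display names (or names) ordered by id from a pbtxt label map."""
    entries = []
    for block in re.findall(r"item\s*\{(.*?)\}", pbtxt_text, flags=re.DOTALL):
        item_id = re.search(r"\bid:\s*(\d+)", block)
        display = re.search(r"\bdisplay_name:\s*[\"'](.*?)[\"']", block)
        name = re.search(r"\bname:\s*[\"'](.*?)[\"']", block)
        label = display or name  # fall back to name if display_name is missing
        if item_id and label:
            entries.append((int(item_id.group(1)), label.group(1)))
    return [text for _, text in sorted(entries)]


def convert_labelmap(src_path, dst_path):
    """Write one label per line to dst_path."""
    with open(src_path, "r", encoding="utf-8") as src:
        labels = labelmap_to_lines(src.read())
    with open(dst_path, "w", encoding="utf-8") as dst:
        dst.write("\n".join(labels) + "\n")
```

For example, convert_labelmap("labelmap.pbtxt", "labelmap.txt") writes the new file; both file names are placeholders. The sketch assumes the ids are contiguous and start at 1, so that the line position matches the class index.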

2.4 Optional: Convert the Tensorflow Lite model so it can be used with the Google Coral EdgeTPU

If you want to use the model with a Google Coral EdgeTPU, you need to run it through the EdgeTPU Compiler.

The compiler can be installed on Linux systems (Debian 6.0 or higher) with the following commands:

curl https://packages.cloud.google.com/apt/doc/apt-key.gpg | sudo apt-key add -

echo "deb https://packages.cloud.google.com/apt coral-edgetpu-stable main" | sudo tee /etc/apt/sources.list.d/coral-edgetpu.list

sudo apt-get update

sudo apt-get install edgetpu-compiler

After installing the compiler, you can convert the model with the following command:

edgetpu_compiler --out_dir <output_directory> <path_to_tflite_file>

Before using the compiler, be sure you have a model that's compatible with the Edge TPU. For compatibility details, read TensorFlow models on the Edge TPU.
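
If you want to run the compiler from a Python script, for example as the last step of a conversion pipeline, a thin wrapper around the command above could look like this. It is only a sketch; the function names are hypothetical and it uses nothing beyond the --out_dir option shown above.

```python
import shutil
import subprocess


def build_compile_command(tflite_path, out_dir):
    """Return the edgetpu_compiler command line as an argument list."""
    return ["edgetpu_compiler", "--out_dir", out_dir, tflite_path]


def compile_for_edgetpu(tflite_path, out_dir):
    """Run the compiler; raises CalledProcessError if compilation fails."""
    if shutil.which("edgetpu_compiler") is None:
        raise FileNotFoundError("edgetpu_compiler is not installed or not on PATH")
    subprocess.run(build_compile_command(tflite_path, out_dir), check=True)
```

The compiled model is written to the output directory, typically with an _edgetpu suffix added to the file name.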

3. Run the Tensorflow Lite model

Now, you should have your model ready to go. I created two scripts that allow you to run the object detection model on a webcam or a video. If you have any questions about the code, be sure to contact me.

Run object detection on webcam:

# based on https://github.com/tensorflow/examples/blob/master/lite/examples/object_detection/raspberry_pi/detect_picamera.py
from imutils.video import VideoStream, FPS
from tflite_runtime.interpreter import Interpreter, load_delegate
import argparse
import time
import cv2
import re
from PIL import Image, ImageDraw, ImageFont
import numpy as np


def draw_image(image, results, labels, size):
    # Prepare image for drawing
    draw = ImageDraw.Draw(image)
    font = ImageFont.truetype("/usr/share/fonts/truetype/piboto/Piboto-Regular.ttf", 20)

    for obj in results:
        print(obj)

        # Scale the normalized bounding box to pixel coordinates
        ymin, xmin, ymax, xmax = obj['bounding_box']
        xmin = int(xmin * size[0])
        xmax = int(xmax * size[0])
        ymin = int(ymin * size[1])
        ymax = int(ymax * size[1])

        # Draw the rectangle (PIL expects x0, y0, x1, y1)
        draw.rectangle((xmin, ymin, xmax, ymax), outline=(255, 255, 0), width=4)

        # Annotate image with label and confidence score
        display_str = labels[int(obj['class_id'])] + ": " + str(round(obj['score']*100, 2)) + "%"
        draw.text((xmin, ymin), display_str, font=font)

    # Show the frame even when nothing was detected
    displayImage = np.asarray(image)
    cv2.imshow('Coral Live Object Detection', displayImage)


def load_labels(path):
    """Loads the labels file. Supports files with or without index numbers."""
    with open(path, 'r', encoding='utf-8') as f:
        lines = f.readlines()
        labels = {}
        for row_number, content in enumerate(lines):
            pair = re.split(r'[:\s]+', content.strip(), maxsplit=1)
            if len(pair) == 2 and pair[0].strip().isdigit():
                labels[int(pair[0])] = pair[1].strip()
            else:
                labels[row_number] = pair[0].strip()
    return labels


def set_input_tensor(interpreter, image):
    """Sets the input tensor."""
    tensor_index = interpreter.get_input_details()[0]['index']
    input_tensor = interpreter.tensor(tensor_index)()[0]
    input_tensor[:, :] = image


def get_output_tensor(interpreter, index):
    """Returns the output tensor at the given index."""
    output_details = interpreter.get_output_details()[index]
    tensor = np.squeeze(interpreter.get_tensor(output_details['index']))
    return tensor


def detect_objects(interpreter, image, threshold):
    """Returns a list of detection results, each a dictionary of object info."""
    set_input_tensor(interpreter, image)
    interpreter.invoke()

    # Get all output details
    boxes = get_output_tensor(interpreter, 0)
    classes = get_output_tensor(interpreter, 1)
    scores = get_output_tensor(interpreter, 2)
    count = int(get_output_tensor(interpreter, 3))

    results = []
    for i in range(count):
        if scores[i] >= threshold:
            result = {
                'bounding_box': boxes[i],
                'class_id': classes[i],
                'score': scores[i]
            }
            results.append(result)
    return results


def make_interpreter(model_file, use_edgetpu):
    model_file, *device = model_file.split('@')
    if use_edgetpu:
        return Interpreter(
            model_path=model_file,
            experimental_delegates=[
                load_delegate('libedgetpu.so.1',
                {'device': device[0]} if device else {})
            ]
        )
    else:
        return Interpreter(model_path=model_file)


def main():
    parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument('-m', '--model', type=str, required=True, help='File path of .tflite file.')
    parser.add_argument('-l', '--labels', type=str, required=True, help='File path of labels file.')
    parser.add_argument('-t', '--threshold', type=float, default=0.4, required=False, help='Score threshold for detected objects.')
    parser.add_argument('-p', '--picamera', action='store_true', default=False, help='Use PiCamera for image capture')
    parser.add_argument('-e', '--use_edgetpu', action='store_true', default=False, help='Use EdgeTPU')
    args = parser.parse_args()

    labels = load_labels(args.labels)
    interpreter = make_interpreter(args.model, args.use_edgetpu)
    interpreter.allocate_tensors()
    _, input_height, input_width, _ = interpreter.get_input_details()[0]['shape']

    # Initialize video stream
    vs = VideoStream(usePiCamera=args.picamera, resolution=(640, 480)).start()
    time.sleep(1)

    fps = FPS().start()

    while True:
        try:
            # Read frame from video
            screenshot = vs.read()
            image = Image.fromarray(screenshot)
            # Image.ANTIALIAS was removed in Pillow 10; LANCZOS is the same filter
            image_pred = image.resize((input_width, input_height), Image.LANCZOS)

            # Perform inference
            results = detect_objects(interpreter, image_pred, args.threshold)
            
            draw_image(image, results, labels, image.size)

            if( cv2.waitKey( 5 ) & 0xFF == ord( 'q' ) ):
                fps.stop()
                break

            fps.update()
        except KeyboardInterrupt:
            fps.stop()
            break

    print("Elapsed time: " + str(fps.elapsed()))
    print("Approx FPS: " + str(fps.fps()))

    cv2.destroyAllWindows()
    vs.stop()
    time.sleep(2)


if __name__ == '__main__':
    main()

Figure 2: Tensorflow Lite Object Detection

You can also find the code on my Github.
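
One detail of the script that is easy to get wrong is the coordinate order. The model returns each bounding box as normalized [ymin, xmin, ymax, xmax] values, while Pillow's drawing functions take pixel coordinates in (x0, y0, x1, y1) order. Here is a small sketch of the conversion; to_pixel_box is a hypothetical helper, and clamping to the image is my own addition.

```python
def to_pixel_box(bounding_box, image_width, image_height):
    """Convert normalized (ymin, xmin, ymax, xmax) to pixel (xmin, ymin, xmax, ymax)."""
    # Clamp to [0, 1] so boxes that reach past the frame stay inside the image.
    ymin, xmin, ymax, xmax = (min(max(float(v), 0.0), 1.0) for v in bounding_box)
    return (
        int(xmin * image_width),
        int(ymin * image_height),
        int(xmax * image_width),
        int(ymax * image_height),
    )


if __name__ == "__main__":
    print(to_pixel_box((0.25, 0.25, 0.75, 0.5), 640, 480))
```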

Conclusion

Using the TensorFlow Object Detection API, you can create object detection models that run on many platforms, including desktops, mobile phones, and edge devices. For running models on edge devices and mobile phones, converting the model to TensorFlow Lite is recommended.

That's all from this article. If you have any questions or want to chat with me, feel free to contact me via EMAIL or social media.