Run PyTorch models on the Jetson Nano with TensorRT

by Gilbert Tanner on Jul 04, 2020


PyTorch models can be converted to TensorRT using the torch2trt converter.

torch2trt is a PyTorch to TensorRT converter that uses the TensorRT Python API. The converter is:

  • Easy to use - Convert modules with a single function call, torch2trt
  • Easy to extend - Write your own layer converter in Python and register it with @tensorrt_converter

Installation

torch2trt can be installed by cloning the GitHub repository and running its setup.py script.

git clone https://github.com/NVIDIA-AI-IOT/torch2trt
cd torch2trt
sudo python3 setup.py install
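Before running the conversion example, it can help to confirm that Python can find every package the example imports. The snippet below is a minimal sketch of such a check; the helper name check_modules and the module list are illustrative, not part of torch2trt. It uses importlib.util.find_spec, which only locates a package without importing it, so a package whose native libraries are broken could still show up as "ok".

```python
import importlib.util
from typing import Dict, Iterable

# Import names (not pip names) used by the conversion example; illustrative list.
REQUIRED_MODULES = ("torch", "torchvision", "tensorrt", "torch2trt")


def check_modules(names: Iterable[str] = REQUIRED_MODULES) -> Dict[str, bool]:
    """Map each top-level module name to whether Python can locate it."""
    # find_spec returns None for a top-level module that cannot be found.
    return {name: importlib.util.find_spec(name) is not None for name in names}


if __name__ == "__main__":
    for name, found in check_modules().items():
        print(f"{name:12s} {'ok' if found else 'MISSING'}")
```

Run it with python3 on the Nano after the install; any line marked MISSING points to a package that still needs to be installed.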

Conversion Example

import torch
from torch2trt import torch2trt
from torchvision.models.alexnet import alexnet

# create some regular pytorch model...
model = alexnet(pretrained=True).eval().cuda()

# create example data
x = torch.ones((1, 3, 224, 224)).cuda()

# convert to TensorRT feeding sample data as input
model_trt = torch2trt(model, [x])

Here we first load a pre-trained AlexNet model from torchvision and put it in evaluation mode on the GPU. Next, we create some sample input that torch2trt uses to infer the shapes and data types of the TensorRT engine. Finally, we call the torch2trt function, which builds the optimized TensorRT engine and returns it as a TRTModule.

Make Predictions

We can execute the returned TRTModule just like the original PyTorch model.

y = model(x)
y_trt = model_trt(x)

# check the output against PyTorch
print(torch.max(torch.abs(y - y_trt)))
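To see what the conversion buys you on the Nano, you can time both models on the same input. The helper below is a minimal sketch that works with any callable; the name measure_latency and its parameters are hypothetical, not a torch2trt API. Because CUDA kernels run asynchronously, pass torch.cuda.synchronize as sync so each measurement includes the GPU work rather than only the kernel launch.

```python
import statistics
import time
from typing import Any, Callable, List, Optional


def measure_latency(
    fn: Callable[[Any], Any],
    x: Any,
    warmup: int = 10,
    runs: int = 50,
    sync: Optional[Callable[[], None]] = None,
) -> float:
    """Return the median latency of fn(x) in milliseconds.

    sync, if given, is called once after the warm-up calls and after every
    timed call, so pending asynchronous work is finished before the clock stops.
    """
    if runs < 1:
        raise ValueError("runs must be at least 1")
    # Warm-up calls keep one-time setup costs out of the timed loop.
    for _ in range(warmup):
        fn(x)
    if sync is not None:
        sync()
    timings: List[float] = []
    for _ in range(runs):
        start = time.perf_counter()
        fn(x)
        if sync is not None:
            sync()
        timings.append((time.perf_counter() - start) * 1000.0)
    # The median is less sensitive to occasional slow iterations than the mean.
    return statistics.median(timings)


# Usage on the Jetson Nano (assumes model, model_trt and x from the example above):
#   import torch
#   with torch.no_grad():
#       print("PyTorch :", measure_latency(model, x, sync=torch.cuda.synchronize))
#       print("TensorRT:", measure_latency(model_trt, x, sync=torch.cuda.synchronize))
```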

Save and load model

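We can save the converted model as a state_dict. For a TRTModule, the state_dict contains the serialized TensorRT engine, so the engine does not have to be rebuilt every time the program starts.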

torch.save(model_trt.state_dict(), 'alexnet_trt.pth')

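We can load the saved model into a TRTModule using the load_state_dict method. Because a serialized TensorRT engine is tied to the GPU and TensorRT version it was built with, load it on the same kind of device (here, the Jetson Nano) with the same TensorRT version.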

import torch
from torch2trt import TRTModule

model_trt = TRTModule()

model_trt.load_state_dict(torch.load('alexnet_trt.pth'))