Run PyTorch models on the Jetson Nano with TensorRT
Use TensorRT to run PyTorch models on the Jetson Nano.
PyTorch models can be converted to TensorRT using the torch2trt converter.
torch2trt is a PyTorch-to-TensorRT converter that uses the TensorRT Python API. The converter is:
- Easy to use - Convert modules with a single function call, torch2trt
- Easy to extend - Write your own layer converter in Python and register it with @tensorrt_converter
Installation
torch2trt can be installed by cloning the GitHub repository and running its setup.py script.
git clone https://github.com/NVIDIA-AI-IOT/torch2trt
cd torch2trt
sudo python3 setup.py install
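After installing, a quick way to confirm that Python can see torch2trt and its dependencies is to look the modules up without importing them. The helper below is hypothetical (check_modules is not part of torch2trt) and uses only the standard library; run it with the same interpreter you used for setup.py.

```python
import importlib.util


def check_modules(names):
    """Return a dict mapping each top-level module name to True if it can be found."""
    # find_spec locates a module without importing it; it returns None when the module is missing.
    return {name: importlib.util.find_spec(name) is not None for name in names}


if __name__ == "__main__":
    for name, found in check_modules(["torch", "torchvision", "tensorrt", "torch2trt"]).items():
        print(f"{name}: {'found' if found else 'MISSING'}")
```

A MISSING entry usually means the package was installed for a different Python interpreter than the one running the check.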
Conversion Example
import torch
from torch2trt import torch2trt
from torchvision.models.alexnet import alexnet
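# create a regular PyTorch model (pre-trained AlexNet in eval mode on the GPU)
model = alexnet(pretrained=True).eval().cuda()
# create example data with the input shape the model expects (batch of 1, 3x224x224)
x = torch.ones((1, 3, 224, 224)).cuda()
# convert to TensorRT, feeding the sample data as input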
model_trt = torch2trt(model, [x])
Here we first load a pre-trained AlexNet model from torchvision and put it in evaluation mode on the GPU. Next, we create sample input, which torch2trt uses to infer the shapes and data types of the TensorRT engine. Finally, we call the torch2trt function to build the optimized TensorRT engine.
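Because the engine is built for the shapes and data types of the sample input, it can help to check inputs before calling it. The wrapper below is a hypothetical, framework-agnostic sketch (ShapeCheckedModule is not part of torch2trt). It records the shape and dtype of each example input and raises ValueError when a later call does not match. It assumes the default conversion shown above, where inputs must match the sample exactly, and only requires that inputs expose .shape and .dtype attributes, as torch.Tensor does.

```python
class ShapeCheckedModule:
    """Hypothetical wrapper that rejects inputs whose shape or dtype differ
    from the example inputs used at conversion time."""

    def __init__(self, module, example_inputs):
        self.module = module
        # Record (shape, dtype) per example input, e.g. ((1, 3, 224, 224), torch.float32).
        self.expected = [(tuple(x.shape), x.dtype) for x in example_inputs]

    def __call__(self, *inputs):
        if len(inputs) != len(self.expected):
            raise ValueError(f"expected {len(self.expected)} inputs, got {len(inputs)}")
        for i, (x, (shape, dtype)) in enumerate(zip(inputs, self.expected)):
            if tuple(x.shape) != shape or x.dtype != dtype:
                raise ValueError(
                    f"input {i}: got {tuple(x.shape)} / {x.dtype}, "
                    f"engine was built for {shape} / {dtype}"
                )
        # Inputs match the conversion-time signature, so forward them unchanged.
        return self.module(*inputs)
```

For example, checked = ShapeCheckedModule(model_trt, [x]) followed by y_trt = checked(x) behaves like calling model_trt directly, but fails fast with a readable message on a mismatched input.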
Make Predictions
We can execute the returned TRTModule just like the original PyTorch model.
y = model(x)
y_trt = model_trt(x)
# check the output against PyTorch
print(torch.max(torch.abs(y - y_trt)))
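Since the TRTModule is called just like the original model, it is also easy to compare their latency. The helper below is a minimal, hypothetical sketch using only the standard library (time_callable is not part of torch2trt). CUDA launches kernels asynchronously, so pass torch.cuda.synchronize as sync to make sure GPU work has finished before each timestamp is taken; the warm-up iterations keep one-time setup costs out of the measurement.

```python
import statistics
import time


def time_callable(fn, *args, warmup=10, iters=50, sync=None):
    """Return the median wall-clock time in seconds of fn(*args).

    If sync is given, it is called before the clock starts and again right
    after fn returns (before the clock stops), e.g. torch.cuda.synchronize.
    """
    # Warm-up calls are not timed.
    for _ in range(warmup):
        fn(*args)
    times = []
    for _ in range(iters):
        if sync is not None:
            sync()
        start = time.perf_counter()
        fn(*args)
        if sync is not None:
            sync()
        times.append(time.perf_counter() - start)
    # The median is less sensitive to occasional slow iterations than the mean.
    return statistics.median(times)
```

On the Jetson you might compare time_callable(model, x, sync=torch.cuda.synchronize) with time_callable(model_trt, x, sync=torch.cuda.synchronize), running both inside a torch.no_grad() block so the PyTorch model does not record autograd history.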
Save and Load the Model
We can save the converted model as a state_dict.
torch.save(model_trt.state_dict(), 'alexnet_trt.pth')
We can load the saved model into a TRTModule using the load_state_dict method.
from torch2trt import TRTModule
model_trt = TRTModule()
model_trt.load_state_dict(torch.load('alexnet_trt.pth'))
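Building a TensorRT engine on the Jetson Nano can take noticeably longer than loading a saved one, so a common pattern is to convert once and reload the saved state_dict on later runs. The helper below is a hypothetical sketch (load_or_build is not part of torch2trt): the caching logic is kept generic, and the caller supplies build, save and load functions. The commented wiring at the end shows one way to plug in the torch2trt calls used above. Keep in mind that a serialized TensorRT engine is tied to the GPU and TensorRT version it was built with, so delete the cached file after upgrading JetPack or TensorRT.

```python
import os


def load_or_build(path, build, save, load):
    """Return load(path) if path exists; otherwise obj = build(), save(obj, path), return obj."""
    if os.path.exists(path):
        # Cached artifact found: skip the expensive build step.
        return load(path)
    obj = build()
    save(obj, path)
    return obj


# Hypothetical wiring for torch2trt (assumes torch, torch2trt, TRTModule,
# model and x from the examples above):
#
# def load_trt(path):
#     m = TRTModule()
#     m.load_state_dict(torch.load(path))
#     return m
#
# model_trt = load_or_build(
#     'alexnet_trt.pth',
#     build=lambda: torch2trt(model, [x]),
#     save=lambda m, p: torch.save(m.state_dict(), p),
#     load=load_trt,
# )
```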