The Backstory
I recently bought an RTX 4090 and was eager to see how difficult it would be to run a large language model (LLM) inference engine on it. Specifically, I wanted to explore the steps needed to deploy Llama-3 (the 8B version) using TensorRT-LLM on my new GPU. In this blog post, I’ll share my journey and the challenges I faced, and provide a guide to help others do the same. I will keep the code used throughout this post here.
Getting the RTX 4090 was not as easy as one might think, especially if you are planning to get the Founders Edition. I tried for a few months but never got lucky enough to find it in stock. I ended up getting the ASUS TUF Gaming GeForce RTX 4090, and everything has been going well so far.
Getting Started with TensorRT-LLM
Once I had installed the card in my machine and set up all the drivers, the next step was to figure out how to build a Docker image for running TensorRT-LLM. When going through the documentation, the first reaction is to assume that Triton Inference Server is required alongside TensorRT-LLM to serve models. You can use it if you need to, but it is not necessary for testing an optimized TensorRT engine.
One advantage of using Triton, however, is that it offers features needed to run these engines efficiently at scale, such as in-flight batching for higher throughput and paged attention for memory efficiency. For those curious, in-flight batching is intended to produce the same effect as continuous batching, the term used by other LLM serving engines such as vLLM and TGI. The following article has great diagrams describing batching approaches for serving LLMs.
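To make the idea more concrete, here is a toy Python sketch of the scheduling pattern behind continuous/in-flight batching. It is not TensorRT-LLM or Triton code; the Request class, the step counts, and the batch size are all made up for illustration. The point is that finished requests are evicted and waiting requests are admitted at every decoding step, instead of draining the whole batch first.

from collections import deque
from dataclasses import dataclass

# Toy illustration of continuous / in-flight batching (not real TensorRT-LLM code).
# Each request needs a different number of decoding steps; the scheduler swaps
# finished requests out and admits waiting ones without draining the whole batch.

@dataclass
class Request:
    rid: int
    steps_left: int  # pretend number of tokens still to generate

MAX_BATCH = 4
waiting = deque(Request(rid=i, steps_left=(i % 5) + 1) for i in range(10))
active: list[Request] = []

step = 0
while waiting or active:
    # admit new requests into any free batch slots
    while waiting and len(active) < MAX_BATCH:
        active.append(waiting.popleft())
    # one decoding step for every active request
    for req in active:
        req.steps_left -= 1
    # evict finished requests so their slots free up immediately
    done = [r.rid for r in active if r.steps_left == 0]
    active = [r for r in active if r.steps_left > 0]
    step += 1
    if done:
        print(f"step {step}: finished requests {done}")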
Interestingly, you only need roughly five lines in a Dockerfile to install the TensorRT-LLM environment:
# base image
FROM nvidia/cuda:12.1.1-devel-ubuntu22.04
# install trt-llm dependencies
RUN apt-get update && apt-get -y install python3.10 python3-pip openmpi-bin libopenmpi-dev git git-lfs wget
# download converter
ARG GIT_SHA=71d8d4d3dc655671f32535d6d2b60cab87f36e87
RUN wget https://raw.githubusercontent.com/NVIDIA/TensorRT-LLM/${GIT_SHA}/examples/llama/convert_checkpoint.py -O "/convert_checkpoint.py"
# install trt-llm
RUN pip3 install tensorrt_llm==0.10.0.dev2024042300 -U --pre --extra-index-url https://pypi.nvidia.com
And the image can be built with:
docker build -t trt-llm .
Downloading Llama-3
After building the Docker image, we need to download the weights of the model we will be working with, which in this case is Llama-3-8B. Make sure you have accepted the licensing terms and generated a HuggingFace access token so you can download the model.
For example, I used the huggingface-cli for this task:
huggingface-cli download meta-llama/Meta-Llama-3-8B-Instruct --local-dir $HOME/models/llama-3-8b-instruct --local-dir-use-symlinks False
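If you prefer to stay in Python, the huggingface_hub library can perform the same download. This is a minimal sketch, assuming your access token is exported as HF_TOKEN and that you want the same local directory as above:

import os
from huggingface_hub import snapshot_download

# Download the model repository to a local directory (equivalent to the CLI call above).
# Assumes the access token is exported as HF_TOKEN; adjust local_dir as needed.
snapshot_download(
    repo_id="meta-llama/Meta-Llama-3-8B-Instruct",
    local_dir=os.path.expanduser("~/models/llama-3-8b-instruct"),
    token=os.environ.get("HF_TOKEN"),
)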
From HuggingFace Models to TensorRT-LLM Engines
There are two major steps required to make this happen: converting the HuggingFace model into a TensorRT-compatible model and building the engine from there.
The conversion of the HuggingFace model can be done with a single command using a script. One thing to keep in mind is that there is a conversion script for every model architecture, and these can be found in the examples directory. The most basic options used when converting a model include tensor parallelism (tp_size), which is the number of GPUs intended for inference, and the data type (dtype).
For example, for Llama-3-8B we can do:
HF_MODEL_DIR=$HOME/models/llama-3-8b-instruct
TRT_MODEL_DIR=$HOME/models/llama-3-8b-instruct-trt
USER_ID=$(id -u)
GROUP_ID=$(id -g)
mkdir -p $TRT_MODEL_DIR
docker run \
--rm \
--gpus all \
--user $USER_ID:$GROUP_ID \
-v $HF_MODEL_DIR:/input_model \
-v $TRT_MODEL_DIR:/output_model \
trt-llm \
/bin/bash -c "python3 convert_checkpoint.py \
--model_dir=/input_model \
--output_dir=/output_model \
--tp_size=1 \
--dtype=float16"
This should produce two files: a model configuration (config.json) and the weights (rank0.safetensors), both located in the $HOME/models/llama-3-8b-instruct-trt directory.
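Before building the engine, it can be useful to sanity-check the converted checkpoint. The sketch below simply prints the generated config.json; the exact fields can vary between TensorRT-LLM versions, so treat it as a quick inspection rather than a fixed schema:

import json
import os

# Print the configuration produced by convert_checkpoint.py so we can verify
# the dtype, tensor parallelism, and architecture before building the engine.
trt_model_dir = os.path.expanduser("~/models/llama-3-8b-instruct-trt")
with open(os.path.join(trt_model_dir, "config.json")) as f:
    config = json.load(f)

print(json.dumps(config, indent=2))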
Next, we can build the TensorRT-LLM engine with another command. Several optimizations are available at this step, including the batch size and the maximum input and output token lengths, among others. Beyond these general settings, there are plugins that enable more advanced optimizations, such as GEMM kernels or fused operations for flash attention.
The simplest engine configuration I used for this example was:
TRT_MODEL_DIR=$HOME/models/llama-3-8b-instruct-trt
TRT_ENGINE_DIR=$HOME/engines/llama-3-8b-instruct
mkdir -p $TRT_ENGINE_DIR
USER_ID=$(id -u)
GROUP_ID=$(id -g)
# --user $USER_ID:$GROUP_ID does not work here for some reason
# therefore, I am adding chown -R $USER_ID:$GROUP_ID /trt-engine at the end
docker run \
--rm \
--gpus all \
-v $TRT_MODEL_DIR:/trt-model \
-v $TRT_ENGINE_DIR:/trt-engine \
trt-llm \
/bin/bash -c "trtllm-build --checkpoint_dir=/trt-model \
--output_dir=/trt-engine \
--tp_size=1 \
--workers=1 \
--max_batch_size=4 \
--max_input_len=8192 \
--max_output_len=8192 \
--gemm_plugin=float16 \
--gpt_attention_plugin=float16 && chown -R $USER_ID:$GROUP_ID /trt-engine"
If everything goes well, we should have two new files located at $HOME/engines/llama-3-8b-instruct: a configuration file (config.json) and an engine file (rank0.engine), assuming we are running on a single GPU. Now that we have built the engine, we can test it with a single Python file. In addition to the engine, we need the tokenizer to convert text into tokens.
Here is the test file:
import tensorrt_llm
from tensorrt_llm.runtime import ModelRunner
from transformers import AutoTokenizer


def generate(user_prompt: str, max_new_tokens: int = 128):
    # Get tokenizer from a folder
    tokenizer = AutoTokenizer.from_pretrained("/tokenizer")
    # Llama models do not have a padding token, so we use the EOS token
    tokenizer.add_special_tokens({"pad_token": tokenizer.eos_token})
    # and pad from the left, to minimize the impact on the output
    tokenizer.padding_side = "left"

    # engine options
    engine_opt = {
        "engine_dir": "/engine",
        "rank": tensorrt_llm.mpi_rank(),
    }

    # engine
    engine = ModelRunner.from_dir(**engine_opt)

    # chat template
    messages = [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": user_prompt},
    ]

    # tokenize (encode)
    inputs = tokenizer.apply_chat_template(
        messages, add_generation_prompt=True, tokenize=True, return_tensors="pt"
    )

    # inference options
    inference_opt = {
        "temperature": 0.1,
        "top_k": 1,
        "repetition_penalty": 1.1,
        "max_new_tokens": max_new_tokens,
        "end_id": tokenizer.eos_token_id,
        "pad_id": tokenizer.eos_token_id,
        "streaming": True,
    }

    # run inference
    outputs = engine.generate(inputs, **inference_opt)

    # number of input tokens
    start = inputs.size(-1)

    # for streaming, we decode one token at a time and return a generator
    for i, out in enumerate(outputs):
        token = out[0][0][start + i].item()
        # found the last token
        if token == tokenizer.eos_token_id:
            break
        yield tokenizer.decode([token])


if __name__ == "__main__":
    question = "what is life is like a box of chocolates?"
    print(f"\n\n{question}\n")
    gen = generate(question, max_new_tokens=128)
    for x in gen:
        print(x, end="")
    print()
We can keep using our Docker image to run this test:
TOKENIZER_DIR=$HOME/models/llama-3-8b-instruct
ENGINE_DIR=$HOME/engines/llama-3-8b-instruct
docker run \
--rm \
--gpus all \
-v $TOKENIZER_DIR:/tokenizer \
-v $ENGINE_DIR:/engine \
-v $PWD:/app \
trt-llm \
/bin/bash -c "python3 /app/test_engine.py"
This is the final output:
==========
== CUDA ==
==========
CUDA Version 12.1.1
Container image Copyright (c) 2016-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
This container image and its contents are governed by the NVIDIA Deep Learning Container License.
By pulling and using the container, you accept the terms and conditions of this license:
https://developer.nvidia.com/ngc/nvidia-deep-learning-container-license
A copy of this license is made available in this container at /NGC-DL-CONTAINER-LICENSE for your convenience.
[TensorRT-LLM] TensorRT-LLM version: 0.10.0.dev2024042300
Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.
/usr/local/lib/python3.10/dist-packages/torch/nested/__init__.py:166: UserWarning: The PyTorch API of nested tensors is in prototype stage and will change in the near future. (Triggered internally at ../aten/src/ATen/NestedTensorImpl.cpp:177.)
return _nested.nested_tensor(
what is life is like a box of chocolates?
A classic quote! "Life is like a box of chocolates, you never know what you're gonna get" is a famous line from the 1994 movie Forrest Gump, spoken by Tom Hanks as the titular character.
The phrase has since become a popular idiom that suggests that life is unpredictable and full of surprises. Just like how you can't know which piece of candy you'll get when you open a box of chocolates, you can't always anticipate what will happen in life.
It's a reminder to be flexible, adaptable, and open-minded, as things don't always go according to plan.
Final Thoughts
Deploying Llama-3 with TensorRT-LLM was an interesting experience compared to how things were a couple of years ago when you wanted to optimize and serve machine learning models with TensorRT. Think about it: we didn’t have to do any PyTorch ahead-of-time compilation or worry about exporting a model to ONNX. Better yet, we didn’t have to deal with complex Python dependencies to make everything work. That’s how it used to be. There’s still a long way to go, but this is definitely heading in the right direction.