Using TensorRT-LLM and StreamingLLM for Efficient Inference on Mistral¶
Welcome!
In this notebook, we will walk through using TensorRT-LLM with the StreamingLLM feature to run inference on Mistral. TensorRT-LLM provides users with an easy-to-use Python API to define Large Language Models (LLMs) and build TensorRT engines that contain state-of-the-art optimizations to perform inference efficiently on NVIDIA GPUs. StreamingLLM is a novel framework developed at the MIT HAN Lab and is supported in TensorRT-LLM. See the GitHub repo for more examples and documentation!
Introduction to StreamingLLM¶
Handling infinite-length text with LLMs presents challenges. Notably, storing all previous Key and Value (KV) states demands significant memory, and models might struggle to generate text beyond their training sequence length. StreamingLLM addresses this by retaining only the most recent tokens and attention sinks, discarding intermediate tokens. This enables the model to generate coherent text from recent tokens without a cache reset — a capability not seen in earlier methods.
StreamingLLM is optimized for streaming applications, such as multi-round dialogues. It's ideal for scenarios where a model needs to operate continually without requiring extensive memory or dependency on past data. An example is a daily assistant based on LLMs. StreamingLLM would let the model function continuously, basing its responses on recent conversations without needing to refresh its cache. Earlier methods would either need a cache reset when the conversation length exceeded the training length (losing recent context) or recompute KV states from recent text history, which can be time-consuming.
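To make the cache policy concrete, here is a small Python sketch (illustrative only, not part of TensorRT-LLM; the function name and parameters are made up for this example) showing which token positions a StreamingLLM-style cache keeps: a handful of attention-sink tokens at the very start of the sequence plus a sliding window of the most recent tokens.
# Illustrative only: which KV-cache positions a StreamingLLM-style policy retains.
# "sink_len" attention-sink tokens are always kept, plus the most recent tokens,
# up to a total cache size of "window_len".
def retained_positions(seq_len, window_len, sink_len):
    if seq_len <= window_len:
        return list(range(seq_len))
    sinks = list(range(sink_len))
    recent = list(range(seq_len - (window_len - sink_len), seq_len))
    return sinks + recent

# A 20-token sequence with a cache of 8 and 4 sink tokens keeps
# positions [0, 1, 2, 3] plus the last 4 positions [16, 17, 18, 19].
print(retained_positions(20, 8, 4))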
Credits¶
Professor Song Han is an NVIDIA Distinguished Engineer and an associate professor in the MIT EECS department. He has been credited with numerous advances in the field of deep learning and has founded multiple AI companies.
Deployment powered by Brev.dev 🤙
!nvidia-smi
Install TensorRT-LLM¶
!pip install -q ipywidgets
!pip install tensorrt_llm -U -q --extra-index-url https://pypi.nvidia.com
# Fetch the latest Llama/Mistral conversion code from the TensorRT-LLM repo and
# overwrite the copy shipped with the installed package
!wget https://raw.githubusercontent.com/NVIDIA/TensorRT-LLM/main/tensorrt_llm/models/llama/convert.py
!mv convert.py /usr/local/lib/python3.10/dist-packages/tensorrt_llm/models/llama/

# Fetch the example checkpoint-conversion and inference scripts used below
!wget https://raw.githubusercontent.com/NVIDIA/TensorRT-LLM/main/examples/llama/convert_checkpoint.py -P .
!wget https://raw.githubusercontent.com/NVIDIA/TensorRT-LLM/main/examples/run.py -P .
!wget https://raw.githubusercontent.com/NVIDIA/TensorRT-LLM/main/examples/utils.py -P .
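To confirm the installation worked, you can print the installed version (assuming the package exposes __version__, which recent TensorRT-LLM releases do):
!python3 -c "import tensorrt_llm; print(tensorrt_llm.__version__)"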
Convert Mistral to the TensorRT-LLM checkpoint format¶
For StreamingLLM to be enabled, we pass two additional flags to the checkpoint conversion:
dense_context_fmha
- uses dense context FMHA in the context phase
enable_pos_shift
- lets us use positions in the KV cache for RoPE
# Convert the model to a TensorRT-LLM checkpoint with the StreamingLLM feature, using a single GPU and FP16.
!python convert_checkpoint.py --model_dir mistralai/Mistral-7B-v0.1 \
--output_dir ./tllm_checkpoint_1gpu_streamingllm \
--dtype float16 \
--dense_context_fmha \
--enable_pos_shift
# Convert the model again without the StreamingLLM feature, as a baseline, using a single GPU and FP16.
!python convert_checkpoint.py --model_dir mistralai/Mistral-7B-v0.1 \
--output_dir ./tllm_checkpoint_1gpu_nostream \
--dtype float16
Build the TensorRT engine for the model¶
# Build the engine from the StreamingLLM-enabled checkpoint
!trtllm-build --checkpoint_dir ./tllm_checkpoint_1gpu_streamingllm \
--output_dir ./mistralengine_streaming \
--gemm_plugin float16
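If you also want a baseline engine to compare against, the non-streaming checkpoint converted above can be built with the same command. It is not used later in this notebook, and the output directory name below is just a suggestion.
# Baseline, no StreamingLLM (optional)
!trtllm-build --checkpoint_dir ./tllm_checkpoint_1gpu_nostream \
              --output_dir ./mistralengine_nostream \
              --gemm_plugin float16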
Run inference with a large input sequence¶
We demonstrate with the open-source Tiny Shakespeare dataset, using approximately 125,000 characters as our input.
import requests
import re

# Download the Tiny Shakespeare text and collapse all whitespace into single spaces
url = 'https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt'
response = requests.get(url)
if response.status_code == 200:
    story = response.text
    story = re.sub(r'\s+', ' ', story).strip()
else:
    story = None
    print("Failed to retrieve the document.")
%%time
# Use the streaming engine with a sliding window/cache size of 4096 and a sink token length of 4
!python3 ./run.py --max_output_len=150 \
--tokenizer_dir mistralai/Mistral-7B-v0.1 \
--engine_dir=./mistralengine_streaming \
--max_attention_window_size=4096 \
--sink_token_length=4 \
--input_text "{story[983152:]}"