Inference optimization techniques and solutions
Deep learning models and LLMs work well in development but often slow down in production due to the scale and unpredictability of input. This article explores more than ten approaches you can use to optimize your model inference.
After training and validation, you deploy your model to production so your end users can benefit from it. However, moving from controlled experimental conditions to the real world brings its own challenges. You may find model performance dropping significantly as usage scales. The model may take longer to generate a prediction than initially anticipated. Training data may become outdated, leading to decreased output accuracy. Running costs may spike unexpectedly, impacting the project budget.
That’s where inference optimization comes in. It includes different techniques that ML engineers use to improve the efficiency and speed of making predictions with a trained model. You can use inference optimization to reduce the computational cost and latency of the inference process and make your models faster and more scalable.
This article explores several different inference optimization strategies and ways to implement them.
What is inference optimization?
Inference optimization is the process of closing the performance gap between how a model runs during development and how it runs in production. Deep learning models (with neural network architectures) typically run faster in development environments because of controlled configurations and predictable input data. However, in production, other factors come into play. For example:
- Application dependencies
- Different hardware/operating system configurations
- Unexpected data inputs from users
- Large numbers of simultaneous data inputs
As model performance drops, the model’s running cost increases, and user experience also gets impacted. Inference optimization techniques attempt to restore model performance as close to the initial baseline as possible.
Performance metrics
Here are some metrics commonly used to measure model performance:
| Metric | Explanation | Unit of measurement |
|---|---|---|
| Latency (response time) | The time the model takes to return a result after receiving input | Milliseconds (ms) |
| Throughput | The number of inferences the system can handle in a given time frame | Inferences/second (inf/s) |
| Memory usage | The memory the model consumes during an inference | MB/GB |
| CPU/GPU utilization | The percentage of CPU/GPU resources used during inference | Percentage |
| Cost per inference | The financial cost associated with each inference operation | Dollars/inference ($/inf) |
| Error rate | The percentage of inferences that return incorrect or no results | Percentage |
Using the metrics
The first step in inference optimization is to document the baseline performance metrics. They act as benchmarks for the optimization process and allow you to quantify your progress.
For example, let’s say your model’s response time is 200 ms in development. That is your baseline. In production, your initial response time is 500 ms. After optimization, you improve it to 350 ms. You reduced latency by 150 ms, while the maximum possible improvement is 500 − 200 = 300 ms. Hence, you can say you have achieved a 50% improvement.
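As a quick illustration, here is how you might compute that relative improvement in Python (the numbers are the hypothetical ones from the example above):

```python
# Relative improvement against the development baseline from the example above.
baseline_ms, before_ms, after_ms = 200, 500, 350

gap_closed = (before_ms - after_ms) / (before_ms - baseline_ms)
print(f"Closed {gap_closed:.0%} of the performance gap")  # Closed 50% of the performance gap
```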
Inference optimization techniques
The next question is how to achieve inference optimization. We have summarized several techniques below, and the rest of the article explains each of them in detail.
| Technique | Description |
|---|---|
| Pruning | Removes unnecessary parts of a model to reduce its size and computational complexity. |
| Quantization | Reduces the precision of model weights and activations to smaller data types. |
| Knowledge distillation | Transfers knowledge from a large teacher model to a smaller student model. |
| Weight sharing | Shares weights across multiple neurons or layers. |
| Low-rank factorization | Decomposes large matrices into smaller ones. |
| Early exit mechanisms | Allows models to produce predictions before processing all layers. |
| Deployment strategy | Optimizes model deployment location and infrastructure. |
| Caching | Stores intermediate computations or inference results for faster retrieval. |
| Memoization | Stores results of expensive function calls for reuse. |
Model simplification techniques
These techniques reduce the computational complexity of your model, so you get lower memory usage, faster inference times, and reduced costs without significant accuracy drops.
Pruning
Pruning methods eliminate model parts that contribute little to the overall output. There are three approaches to pruning.
- Weight pruning sets certain weights to zero. Magnitude-based pruning ranks weights by size and removes the smallest based on a threshold (see the sketch after this list).
- Neuron pruning eliminates entire neurons based on different criteria. For example, activation-based pruning removes neurons with low activation values, while importance-based pruning removes neurons that have a minimal impact on the loss function.
- Channel pruning reduces the number of channels in the convolutional layers of a convolutional neural network (CNN). In a CNN, a channel refers to a feature map representing a different aspect of the input data. For example, image data typically has three color channels (red, green, and blue), and your CNN may learn channels representing features such as edges, textures, or patterns for image classification. You adaptively prune the least important channels during training.
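As a concrete illustration of magnitude-based weight pruning, here is a minimal sketch using PyTorch's built-in `torch.nn.utils.prune` utilities on a toy model (the layer sizes and the 30% pruning ratio are arbitrary choices for the example):

```python
import torch.nn as nn
import torch.nn.utils.prune as prune

# Toy model: magnitude-based (L1) pruning of the first linear layer.
model = nn.Sequential(nn.Linear(128, 64), nn.ReLU(), nn.Linear(64, 10))

# Zero out the 30% of weights with the smallest absolute values.
prune.l1_unstructured(model[0], name="weight", amount=0.3)

# Make the pruning permanent by removing the re-parameterization hook.
prune.remove(model[0], "weight")

sparsity = (model[0].weight == 0).float().mean().item()
print(f"First-layer sparsity: {sparsity:.0%}")  # roughly 30%
```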
Quantization
Models built in popular libraries like PyTorch use high-precision floating-point numbers by default. Quantization lowers this precision: instead of representing the model’s weights and activations as 32-bit floating-point numbers, you switch to lower-precision formats such as 16-bit floats or 8-bit integers.
There are two primary quantization methods:
- Post-training quantization (PTQ) applies quantization after training.
- Quantization-aware training (QAT) simulates lower precision during training so the model adapts to it.
PTQ is relatively straightforward to implement but may result in some accuracy loss, especially when applied to complex models. QAT manages accuracy retention from the start but requires more computational resources and time.
Quantization reduces memory use and computational demands. It is often used in IoT applications to deploy models on resource-constrained devices.
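For example, a minimal post-training dynamic quantization sketch in PyTorch might look like the following (the toy model and layer sizes are placeholders):

```python
import torch
import torch.nn as nn

# Post-training dynamic quantization: linear-layer weights are stored as int8,
# and activations are quantized on the fly at inference time.
model_fp32 = nn.Sequential(nn.Linear(256, 128), nn.ReLU(), nn.Linear(128, 10)).eval()

model_int8 = torch.quantization.quantize_dynamic(
    model_fp32, {nn.Linear}, dtype=torch.qint8
)

x = torch.randn(1, 256)
print(model_int8(x).shape)  # same interface as the FP32 model, smaller weights
```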
Quantization reduces floating-point precision
Knowledge distillation
Knowledge distillation involves training a smaller, simpler student model that mimics the predictions of a larger, more complex teacher model. The term knowledge refers to how the neural network arrives at its prediction. The goal is to transfer the knowledge from the large to the smaller model, which then performs more efficient inference.
There are different knowledge types, such as:
-
Response-based knowledge is the teacher model’s output layer (logits) and final predictions.
-
Feature-based knowledge is information from intermediate layers, such as activations or weights.
-
Relation-based knowledge captures relationships between neurons or activations, providing insights into how different parts of the model interact.
You can transfer this knowledge between the teacher and the student in different ways.
Offline knowledge distillation
This is predefined training. The teacher model generates probability distributions (soft targets) over classes for a given dataset. The student model is trained on the soft targets rather than the original data labels.
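A common way to implement this is to blend the standard cross-entropy loss on the hard labels with a KL-divergence term on the teacher's softened outputs. Here is a minimal sketch in PyTorch (the temperature `T` and weighting `alpha` are illustrative hyperparameters):

```python
import torch.nn.functional as F

def distillation_loss(student_logits, teacher_logits, labels, T=2.0, alpha=0.5):
    # Soft targets: the teacher's probabilities, softened by temperature T.
    soft_targets = F.softmax(teacher_logits / T, dim=-1)
    soft_student = F.log_softmax(student_logits / T, dim=-1)
    kd_term = F.kl_div(soft_student, soft_targets, reduction="batchmean") * (T * T)

    # Hard-label term: ordinary cross-entropy against the original labels.
    ce_term = F.cross_entropy(student_logits, labels)
    return alpha * kd_term + (1 - alpha) * ce_term
```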
Online knowledge distillation
You train the teacher’s parameters alongside the student’s. Because the teacher evolves with the student, it can adapt and refine its predictions based on the student’s progress, allowing for more adaptive learning.
Self-distillation
You use the same network as both the teacher and the student. Initially, you train the network on the labeled dataset and then re-train it on its soft targets to improve performance.
Other techniques
While the above three are more common, here are some more techniques you can use for model simplification.
Weight sharing
Instead of assigning individual weights to each neuron, you force the model to use a shared set of weights across multiple neurons or layers. This minimizes the parameters the model needs to store and compute. It is often used in CNNs for image processing, where filters are shared across input image regions to reduce the total number of weights.
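As a rough sketch, sharing one linear layer across several depths in PyTorch keeps a single weight matrix in memory regardless of depth (the module and sizes below are made up for illustration):

```python
import torch
import torch.nn as nn

class SharedWeightMLP(nn.Module):
    """Applies the same linear layer at every depth, so all depths share one set of weights."""
    def __init__(self, dim: int, depth: int):
        super().__init__()
        self.shared = nn.Linear(dim, dim)  # single parameter set, reused below
        self.depth = depth

    def forward(self, x):
        for _ in range(self.depth):
            x = torch.relu(self.shared(x))  # same weights applied at each step
        return x

model = SharedWeightMLP(dim=64, depth=4)
print(sum(p.numel() for p in model.parameters()))  # parameters of one layer, not four
```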
Low-rank factorization
Large-weight matrices are decomposed into smaller matrices with lower ranks, which approximate the original matrix’s values. This leads to faster matrix operations and reduced memory requirements.
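A minimal sketch of this idea, assuming a PyTorch linear layer and a truncated SVD, could look like this (the rank of 64 is an arbitrary choice you would tune against accuracy):

```python
import torch
import torch.nn as nn

def factorize_linear(layer: nn.Linear, rank: int) -> nn.Sequential:
    """Approximate one dense linear layer with two smaller ones via truncated SVD."""
    W = layer.weight.data                      # shape (out_features, in_features)
    U, S, Vh = torch.linalg.svd(W, full_matrices=False)
    U_r = U[:, :rank] * S[:rank]               # (out_features, rank)
    V_r = Vh[:rank, :]                         # (rank, in_features)

    first = nn.Linear(layer.in_features, rank, bias=False)
    second = nn.Linear(rank, layer.out_features, bias=layer.bias is not None)
    first.weight.data = V_r
    second.weight.data = U_r
    if layer.bias is not None:
        second.bias.data = layer.bias.data
    return nn.Sequential(first, second)

# A 1024x1024 layer (~1M weights) becomes two layers with ~131k weights at rank 64.
approx = factorize_linear(nn.Linear(1024, 1024), rank=64)
```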
Early exit mechanisms
Early exit mechanisms allow models to produce predictions before processing all layers. The model has multiple exit points, and if a high-confidence prediction is achieved at an intermediate layer, the model halts further computation and outputs the result early.
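Here is a minimal, illustrative sketch of the idea in PyTorch: a two-block classifier with an auxiliary exit head after the first block, which returns early for a single input when confidence exceeds a threshold (the architecture and the 0.9 threshold are assumptions for the example):

```python
import torch
import torch.nn as nn

class EarlyExitClassifier(nn.Module):
    def __init__(self, in_dim=32, hidden=64, n_classes=10, threshold=0.9):
        super().__init__()
        self.block1 = nn.Sequential(nn.Linear(in_dim, hidden), nn.ReLU())
        self.exit1 = nn.Linear(hidden, n_classes)   # early exit head
        self.block2 = nn.Sequential(nn.Linear(hidden, hidden), nn.ReLU())
        self.exit2 = nn.Linear(hidden, n_classes)   # final head
        self.threshold = threshold

    @torch.no_grad()
    def forward(self, x):
        h = self.block1(x)
        probs = torch.softmax(self.exit1(h), dim=-1)
        if probs.max().item() >= self.threshold:    # confident enough: skip block2
            return probs
        return torch.softmax(self.exit2(self.block2(h)), dim=-1)

model = EarlyExitClassifier().eval()
print(model(torch.randn(1, 32)).shape)  # torch.Size([1, 10])
```

In practice, the exit heads are trained jointly with the backbone, and batched inference needs per-sample bookkeeping, but the control flow above captures the core idea.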
Inference optimization without model simplification
You don’t have to rearchitect the model every time. Consider the following.
Deployment strategy
You can improve inference time by deploying your models strategically. You have to choose the optimal platform and location for your model to run. You can make the best decision considering cost, reliability, scale, security, and other project requirements.
Deploying the model closer to its end users and data sources reduces data transfer latency. For example, if your model processes inputs from an IoT system, you can deploy it on the edge device instead of in the cloud. However, if your model processes inputs from users around the globe, you are better off running it in the cloud on servers geographically closer to your end users.
Infrastructure optimization
Shaving even a few milliseconds off real-time inference can be a big win. Sometimes it is easier to optimize your network architecture than your model architecture. Consider removing any redundant network equipment that adds to data transfer time. For example, if you use two different load balancers or ingress controllers, you add extra network hops between your users and the model, increasing latency at scale.
Caching vs. memoization
You can implement both caching and memoization to optimize inference. Caching stores intermediate computations or entire inference results during input data processing. Subsequent requests with identical or similar inputs are served directly from the cache.
Memoization stores the results of expensive function calls likely to be encountered multiple times. Once memoized, the model avoids repeated recalculations and speeds up processing.
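In Python, the standard library's `functools.lru_cache` is a simple way to memoize such calls. The snippet below uses a stand-in computation in place of a real encoder:

```python
from functools import lru_cache
import hashlib

@lru_cache(maxsize=10_000)
def preprocess(text: str) -> tuple:
    # Stand-in for an expensive feature-extraction step; in practice this would
    # call your tokenizer or encoder. Repeated inputs are served from the cache.
    digest = hashlib.sha256(text.encode()).digest()
    return tuple(digest[:8])

preprocess("hello world")          # computed
preprocess("hello world")          # cache hit, no recomputation
print(preprocess.cache_info())     # CacheInfo(hits=1, misses=1, ...)
```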
Parallelism and batching for inference optimization
You can also leverage the power of multiple servers, GPUs, and cores in your ML infrastructure for inference optimization.
Parallelism
Parallelism means running multiple model instances simultaneously. It typically requires GPU/TPU hardware (or cloud instances) optimized for parallel processing. Graphics processing units (GPUs) are designed to handle thousands of matrix operations concurrently, and tensor processing units (TPUs) are purpose-built to accelerate ML workloads. Accelerated hardware reduces inference time for real-time model performance.
Batching
Batching involves grouping multiple inputs (or outputs) to process them in a single pass through your model. There are different approaches to batching.
- Mini-batch processing is used when the input data is static and can be pre-aggregated. You group several inputs in advance before feeding them into the model.
- Dynamic batching collects and processes incoming requests in real time. It handles fluctuating workloads and maximizes resource utilization.
- Batching with padding adds padding to inputs so that all inputs in a batch have the same size. This is often used in sequence-based tasks where inputs vary in length (see the sketch below).
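For example, padding variable-length requests into one tensor with PyTorch's `pad_sequence` lets the model process them in a single forward pass (the token IDs below are made up):

```python
import torch
from torch.nn.utils.rnn import pad_sequence

# Three variable-length token-ID sequences arriving as separate requests.
requests = [torch.tensor([5, 8, 2]),
            torch.tensor([7, 1]),
            torch.tensor([3, 9, 4, 6, 2])]

# Pad to the longest sequence so they form one (batch, max_len) tensor,
# then run a single forward pass instead of three.
batch = pad_sequence(requests, batch_first=True, padding_value=0)
mask = batch != 0   # marks which positions hold real tokens vs. padding
print(batch.shape)  # torch.Size([3, 5])
```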
Model serving frameworks
ML engineers use model serving frameworks for efficiently deploying and managing machine learning models in production environments. These frameworks help with parallelism, batching, and other inference optimization tasks. You can use them to implement the model simplification techniques along with model versioning, scaling, and other MLOps features.
We describe some common frameworks below, but there are many solutions available. All major cloud providers also have their serving frameworks.
ONNX Runtime
You can use ONNX Runtime with models trained in various ML frameworks (e.g., PyTorch, TensorFlow). The ONNX exporter libraries convert models to the ONNX format so that ONNX Runtime can execute them.
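As a rough sketch, exporting a toy PyTorch model and running it with ONNX Runtime looks like this (the model, file name, and tensor names are placeholders):

```python
import torch
import onnxruntime as ort

# Export a toy PyTorch model to the ONNX format.
model = torch.nn.Linear(16, 4).eval()
dummy = torch.randn(1, 16)
torch.onnx.export(model, dummy, "model.onnx",
                  input_names=["input"], output_names=["output"])

# Load the exported model and run inference with ONNX Runtime.
session = ort.InferenceSession("model.onnx", providers=["CPUExecutionProvider"])
outputs = session.run(None, {"input": dummy.numpy()})
print(outputs[0].shape)  # (1, 4)
```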
TensorFlow Serving
It offers several features like:
- Dynamic model loading/unloading to reduce startup time and memory overhead
- Asynchronous request handling
- Custom logic and extensions, such as pre-processing and post-processing
TensorFlow Serving also supports batch processing of requests to improve throughput.
Kubeflow
How Nebius can help
Nebius is an AI-centric cloud platform that offers all the essential services you need for inference optimization without model redesign. You can deploy your models built with any framework (PyTorch, ONNX, etc.) on the NVIDIA Triton™ Inference Server for optimum results. Our data centers offer a 400 GB/s network for maximum data transfer speed from your models to applications.
Nebius provides:
- Infrastructure as code to implement best deployment practices
- Built-in hardware monitoring, network balancing, and Managed Kubernetes
- An intuitive cloud console UI to visualize your data and manage it more efficiently
An on-demand payment model allows you to select optimal hardware based on model requirements and current workload.
FAQ
What is an inference in deep learning?
Inference in deep learning is the process of using a trained model to make predictions or decisions on new, unseen data. It involves feeding the model input data and getting the corresponding output.
What is inference computing?
What is an inference algorithm?
What is an inference cost?
Inference cost is the financial cost associated with each inference operation, typically measured in dollars per inference ($/inf). It depends on the compute resources the model consumes and how long each prediction takes.
How to reduce inference cost?
You can reduce inference cost by simplifying the model (for example, through pruning, quantization, or knowledge distillation), choosing a cost-effective deployment strategy, caching repeated results, and batching requests to make better use of your hardware.
How to optimize model inference?
Start by documenting baseline performance metrics, then apply techniques such as pruning, quantization, knowledge distillation, batching, and caching, or serve the model with an optimized framework such as ONNX Runtime or TensorFlow Serving.
Is inference more expensive than training?