Jagged Flash Attention Optimization

Meta researchers have introduced Jagged Flash Attention, a novel technique that significantly enhances the performance and scalability of large-scale recommendation systems. By combining jagged tensors with flash attention, this innovation achieves up to 9× speedup and 22× memory reduction compared to dense attention, outperforming even dense flash attention with 3× speedup and 53% better memory efficiency.

A write-up of the paper "Enhancing Performance and Scalability of Large-Scale Recommendation Systems with Jagged Flash Attention" by Meta Platforms (CA, USA), published in RecSys '24: Proceedings of the 18th ACM Conference on Recommender Systems.

The Problem: Why Traditional Methods Fall Short

Traditional recommendation systems face challenges with variable-length categorical features, such as user interaction history. Unlike fixed-size numerical features, these require special handling. The conventional approach of padding to standardize lengths introduces significant overhead, especially in GPU-intensive operations.

Consider this scenario: if you're tracking a user's last 100 interactions but they only have 20, you'd need to pad the remaining 80 slots with zeros. This padding (illustrated in the sketch below) creates:

  • Unnecessary memory usage
  • Increased computational load
  • Higher communication overhead between system components
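
To make that overhead concrete, here is a minimal sketch in plain PyTorch contrasting a padded dense layout with a jagged (values + offsets) layout. The tensor names are illustrative, not taken from the paper.

```python
import torch

# Three users with 20, 3, and 7 interactions, padded to a fixed window of 100.
lengths = [20, 3, 7]
max_len = 100

# Dense/padded layout: every user occupies max_len slots regardless of history length.
padded = torch.zeros(len(lengths), max_len, dtype=torch.long)
for row, n in enumerate(lengths):
    padded[row, :n] = torch.arange(1, n + 1)  # dummy item ids

# Jagged layout: all values stored contiguously, with offsets marking sample boundaries.
values = torch.cat([torch.arange(1, n + 1) for n in lengths])
offsets = torch.zeros(len(lengths) + 1, dtype=torch.long)
offsets[1:] = torch.cumsum(torch.tensor(lengths), dim=0)  # [0, 20, 23, 30]

print(padded.numel())  # 300 slots allocated, 270 of them padding
print(values.numel())  # 30 slots, no padding at all
```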

TorchRec: Scalable Recommender Systems 

TorchRec is a powerful PyTorch domain library designed to address the unique challenges of building and deploying large-scale recommendation systems. It offers several key features and optimizations:

Embedding Operations

  • Fused embedding tables and bulk lookups for improved performance
  • Efficient single kernel lookups across multiple embedding tables

Sparse Data Handling

  • Specialized containers and operations for sparse data
  • Optimized permutation and all-to-all communication

Advanced Sharding Capabilities

  • Supports various techniques: data parallel, table-wise, row-wise, column-wise
  • Hierarchical sharding for scaling to many GPUs
  • Automated sharding planner for optimal strategies

Performance Optimizations

  • Quantization support for embeddings (int8/int4)
  • High-performance GPU inference with TorchDeploy integration
  • Caching between GPU and system memory

Production Impact at Meta

  • Enables training of 3+ trillion parameter models
  • Up to 10x performance improvements
  • Facilitates transition to accelerator-based full-sync training

TorchRec excels at handling models combining deep neural networks with wide embedding tables, addressing PyTorch's previous limitations with sparse data and wide models. This enables researchers and engineers to build and efficiently deploy state-of-the-art personalization models in production environments.
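
As a quick illustration of the embedding and sparse-data features above, here is a minimal lookup sketch following the TorchRec quick-start pattern. Exact constructor arguments and module paths may vary between TorchRec releases, so treat the API details as approximate.

```python
import torch
import torchrec

# One fused embedding table holding item embeddings.
ebc = torchrec.EmbeddingBagCollection(
    device=torch.device("cpu"),
    tables=[
        torchrec.EmbeddingBagConfig(
            name="item_table",
            embedding_dim=16,
            num_embeddings=4096,
            feature_names=["item_history"],
        )
    ],
)

# A KeyedJaggedTensor holds variable-length features without padding:
# two samples with 3 and 1 interactions respectively.
kjt = torchrec.KeyedJaggedTensor(
    keys=["item_history"],
    values=torch.tensor([101, 202, 303, 404]),
    lengths=torch.tensor([3, 1]),
)

pooled = ebc(kjt)  # single fused lookup across all tables
print(pooled.to_dict()["item_history"].shape)  # torch.Size([2, 16])
```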

The Game-Changer: Jagged Feature Interaction Kernels

Jagged Feature Interaction Kernels represent a significant advancement in handling variable-length categorical features in recommendation systems. The approach extracts fine-grained insights from long categorical features by operating on dynamically sized tensors. Specifically, the kernels operate on jagged tensors, which store variable-length features from multiple samples contiguously in memory without padding.

Image Source: Research paper

The key components of Jagged Feature Interaction Kernels include:

  • Values tensor: A contiguous array storing all feature values collectively
  • Offset tensor: Determines sample boundaries for each feature segment
  • Triton kernels: Custom-built for both forward and backward computations, optimizing data locality and parallelism

These kernels enable efficient operations such as jagged tensor multiplication, softmax computations, and element-wise operations specifically tailored for sparse data structures. By prioritizing the most relevant feature values and assigning them higher weights, Jagged Feature Interaction Kernels significantly improve the performance and memory efficiency of large-scale recommendation models.
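
The paper's Triton kernels are not reproduced here, but a plain-PyTorch reference of what a jagged softmax computes over a values tensor partitioned by an offsets tensor can make the layout concrete. This is illustrative code, not the paper's implementation.

```python
import torch

def jagged_softmax(values: torch.Tensor, offsets: torch.Tensor) -> torch.Tensor:
    """Softmax applied independently to each variable-length segment.

    values:  1-D tensor holding all samples' scores back to back.
    offsets: 1-D tensor of segment boundaries, e.g. [0, 20, 23, 30].
    """
    out = torch.empty_like(values)
    for start, end in zip(offsets[:-1].tolist(), offsets[1:].tolist()):
        out[start:end] = torch.softmax(values[start:end], dim=0)
    return out

values = torch.tensor([0.5, 1.0, -0.2, 2.0, 0.1, 0.3])
offsets = torch.tensor([0, 3, 4, 6])  # three samples of length 3, 1, 2
print(jagged_softmax(values, offsets))
```

A production kernel fuses this per-segment loop into a single Triton launch, so each segment is handled by a block of threads with good data locality rather than a Python loop.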

Performance Gains

Image Source: Research paper

Speedup

  • Jagged attention: Up to 2× faster than dense attention
  • Jagged Flash Attention: 9× speedup compared to dense attention
  • 3× speedup over dense flash attention

Memory Efficiency

  • Jagged attention: Up to 3.5× reduction vs. dense attention
  • Jagged Flash Attention: Impressive 22× memory reduction

Real-World Impact (Production)

  • 10% improvement in Queries Per Second (QPS)
  • 18% reduction in memory usage
  • Enhanced ability to handle longer feature sequences
  • Support for more complex model architectures

These optimizations significantly enhance the efficiency and scalability of large-scale recommendation systems, enabling more complex model architectures and longer feature sequences.

Flash Attention Tiling Optimization

Flash Attention's tiling optimization is a key innovation that significantly improves the efficiency of attention computations in large language models. By leveraging the GPU memory hierarchy, FlashAttention reduces the number of accesses to high-bandwidth memory (HBM) and maximizes the use of fast on-chip SRAM. The tiling strategy divides the input matrices into smaller blocks that fit into SRAM, allowing for efficient processing without excessive data movement.

The core algorithm employs two main techniques:

  • Tiling: Input matrices Q, K, and V are divided into blocks of size B×d, where B is the block size and d is the embedding dimension.
  • Incremental softmax: A modified online softmax algorithm is used to process attention scores block-wise, maintaining running statistics to ensure numerical stability.

This approach reduces attention's memory footprint from O(N²) to O(N), where N is the sequence length, because the full attention matrix is never materialized in HBM. The practical benefits include up to a 3× speedup over standard dense attention implementations and significant memory savings, enabling the processing of longer sequences with limited GPU resources.
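
To illustrate the two techniques, here is a small single-head PyTorch sketch of block-wise attention with an online softmax. It mirrors the structure of FlashAttention but is a readability-oriented reference, not the fused kernel.

```python
import torch

def blockwise_attention(q, k, v, block_size=64):
    """Single-head attention computed over key/value blocks with an online softmax.

    q, k, v: [N, d] tensors. Only one [N, block_size] score block is
    materialized at a time, so the full N x N attention matrix never exists.
    """
    n, d = q.shape
    scale = d ** -0.5
    out = torch.zeros_like(q)
    row_max = torch.full((n, 1), float("-inf"))  # running max of scores per query
    row_sum = torch.zeros(n, 1)                  # running softmax denominator

    for start in range(0, n, block_size):
        k_blk, v_blk = k[start:start + block_size], v[start:start + block_size]
        scores = (q @ k_blk.T) * scale  # [N, B]

        new_max = torch.maximum(row_max, scores.max(dim=1, keepdim=True).values)
        correction = torch.exp(row_max - new_max)  # rescale previous partial results
        p = torch.exp(scores - new_max)

        out = out * correction + p @ v_blk
        row_sum = row_sum * correction + p.sum(dim=1, keepdim=True)
        row_max = new_max

    return out / row_sum

q, k, v = (torch.randn(256, 32) for _ in range(3))
reference = torch.softmax((q @ k.T) * 32 ** -0.5, dim=-1) @ v
print((blockwise_attention(q, k, v) - reference).abs().max())  # near zero
```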

A New Era for Recommendation Systems

Jagged Flash Attention and its open-source TorchRec implementation are a fantastic contribution to the recommendation-systems community. Together they address key challenges in handling variable-length categorical features and attention mechanisms, significantly improving performance in production systems and opening the door to further advances in the field.

Key implementation considerations for leveraging Jagged Flash Attention include the following (a sketch tying them together follows the list):

  • Memory efficiency: Prioritize jagged tensor implementations over dense padded approaches to reduce memory overhead.
  • Computational optimization: Utilize custom Triton kernels for jagged tensor operations, achieving up to 2.52× speedup for matrix multiplications and 3.06× for softmax operations.
  • Scalability: Implement block-wise processing for large-scale operations, allowing for efficient handling of longer sequences and more complex model architectures.
  • GPU utilization: Leverage shared memory effectively and implement fused operations to maximize computational efficiency.
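
Putting the pieces together, the sketch below shows the effect of jagged attention in plain PyTorch: each sample's variable-length sequence is attended over directly via the offsets, with no padded positions ever allocated. Function and variable names are illustrative; the production version fuses this into Triton kernels using the block-wise softmax shown earlier.

```python
import torch

def jagged_attention(q, k, v, offsets):
    """Attention computed per variable-length sample, with no padding.

    q, k, v: [total_tokens, d] jagged values tensors (all samples concatenated).
    offsets: [num_samples + 1] boundaries into the token dimension.
    """
    d = q.shape[-1]
    out = torch.empty_like(q)
    for start, end in zip(offsets[:-1].tolist(), offsets[1:].tolist()):
        scores = (q[start:end] @ k[start:end].T) * d ** -0.5
        out[start:end] = torch.softmax(scores, dim=-1) @ v[start:end]
    return out

# Two samples with 5 and 2 tokens; 7 rows in total, zero padded rows.
offsets = torch.tensor([0, 5, 7])
q, k, v = (torch.randn(7, 16) for _ in range(3))
print(jagged_attention(q, k, v, offsets).shape)  # torch.Size([7, 16])
```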

The practical impact of these optimizations is substantial, with production models demonstrating a 10% improvement in Queries Per Second (QPS) and an 18% reduction in memory usage. The experiments target recommendation-system use cases, but we could see this being useful for any attention-based workload with sparse, variable-length sequences within a batch.

At Shaped we use Jagged Tensors and the TorchRec library to power many of our PyTorch models. We're excited to start integrating Jagged Flash Attention and to see what improvements we get across our customer base!
