# TorchPP: The PyTorch Performance Plus Toolkit
`torchpp` is an extension library for PyTorch designed to supercharge your deep learning workflows. It provides a suite of high-performance CUDA kernels and a high-level distributed training framework that speed up your models and simplify scaling.

Whether you are working with Large Language Models (LLMs), Diffusion Models, Text-to-Speech (TTS), or Time-Series Models, `torchpp` aims to be your go-to library for performance and scalability.
## Core Pillars

### 1. Accelerate Your Models
Boost your model's speed with our collection of high-performance, custom-written CUDA kernels. We leverage libraries like CUTLASS to build fused, optimized components that replace standard PyTorch modules, delivering significant gains by reducing memory traffic and exploiting hardware-specific features such as Tensor Cores (a short drop-in sketch follows the list below).
- Fused Kernels:
  - Linear + Activation: Fused `Linear` layers with `GeLU` and `SiLU` activations.
  - Optimized Normalization: High-performance `LayerNorm` and `RMSNorm`.
- Custom Implementations:
  - Rotary Position Embeddings (RoPE): An optimized implementation of RoPE.
  - Attention Variants: Efficient implementations of various attention mechanisms, including `Multi-Head`, `Multi-Query`, `Grouped-Query`, `Sliding-Window`, and `Cross-Attention`, all using `flash-attn` for maximum performance.
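As a quick illustration, the fused modules are intended as drop-in replacements for their standard PyTorch counterparts. Below is a minimal sketch using `LinearGELU` as documented in the API Reference; it assumes an FP16-capable CUDA GPU.

```python
import torch
import torch.nn as nn
from torchpp.dlops.linear import LinearGELU

# Standard PyTorch: a Linear layer followed by GELU runs as two separate kernels.
baseline = nn.Sequential(nn.Linear(512, 1024), nn.GELU()).half().cuda()

# torchpp drop-in: the same computation dispatched to a single fused FP16 CUDA kernel.
fused = LinearGELU(in_features=512, out_features=1024)

x = torch.randn(32, 512, dtype=torch.float16, device="cuda")
out = fused(x)  # shape: [32, 1024]
```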
### 2. Simplify Distributed Training

Move beyond the boilerplate of distributed training. `torchpp` provides a high-level, easy-to-use `DistributedTrainer` that handles the complexities of different parallelization strategies, so you can focus on your model.
- Effortless Scaling: Easily switch between strategies like Distributed Data Parallel (DDP), Fully Sharded Data Parallel (FSDP), and hybrid approaches with simple configuration changes.
- Out-of-the-Box Functionality: The trainer includes built-in support for mixed-precision training, gradient accumulation, checkpointing, and more.
## How it Works: The `torchpp` Architecture

`torchpp` achieves its performance by combining a low-level CUDA backend with a high-level Python frontend.
- CUDA Kernels (`csrc/`): The core operations are written in C++ and CUDA. These kernels are optimized for specific hardware (e.g., NVIDIA GPUs with Tensor Cores) and data types (currently `FP16`).
- Pybind11 Bindings: We use `pybind11` to create Python bindings for the C++ functions. This is handled in the `binding.cu` files within the `csrc` directory (e.g., `csrc/activation/binding.cu`). These bindings compile the CUDA code into Python modules (e.g., `linearActvationFp16`).
- Python API (`torchpp/`): The user-facing API is written in Python. The modules in `torchpp/` (e.g., `torchpp.dlops.linear`) import the compiled CUDA modules and wrap them in familiar `torch.nn.Module` classes, making them easy to integrate into existing PyTorch models.
This architecture allows you to get the performance of low-level CUDA programming with the ease of use of a high-level Python library.
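To make the layering concrete, here is a conceptual sketch of what a wrapper in `torchpp/` looks like. It is illustrative only: the extension module name is taken from the example above, but the symbols it exposes (such as the `forward` call below) are assumptions, not the library's real API.

```python
import torch
import torch.nn as nn

# Layer 1 (csrc/): CUDA kernels compiled via pybind11 into a Python extension,
# e.g. the module referred to above as `linearActvationFp16`.
import linearActvationFp16 as _cuda_ext  # exposed symbols are assumed for illustration

# Layer 2 (torchpp/): a thin nn.Module wrapper around the compiled kernel.
class FusedLinearGELUSketch(nn.Module):
    """Illustrative wrapper only; the real torchpp.dlops.linear.LinearGELU may
    hold different parameters and call a differently named kernel."""

    def __init__(self, in_features: int, out_features: int):
        super().__init__()
        self.weight = nn.Parameter(
            torch.empty(out_features, in_features, dtype=torch.float16, device="cuda")
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Hypothetical entry point exposed by the binding, shown only to
        # illustrate the Python -> pybind11 -> CUDA call path.
        return _cuda_ext.forward(x, self.weight)
```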
## Installation

**Prerequisites:**

- A CUDA-enabled GPU.
- The CUTLASS library. Ensure the `CUTLASS_PATH` environment variable is set:

  ```bash
  export CUTLASS_PATH=/path/to/cutlass/include
  ```
**Installation:**

```bash
git clone https://github.com/AmanSwar/TorchPlusPlus.git
cd TorchPlusPlus
pip install .
```
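After installation, a minimal smoke test (assuming a CUDA GPU and the modules documented below) is to instantiate one of the fused layers:

```python
# Minimal post-install check; uses the RmsNormFused module from the API Reference.
import torch
from torchpp.dlops.normalization import RmsNormFused

norm = RmsNormFused(normalize_dim_shape=64)
x = torch.randn(2, 8, 64, dtype=torch.float16, device="cuda")
print(norm(x).shape)  # expected: torch.Size([2, 8, 64])
```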
## API Reference

### Deep Learning Operations (`torchpp.dlops`)

This module provides optimized implementations of common deep learning operations. All modules in `torchpp.dlops` currently expect FP16 tensors.
#### Normalization (`torchpp.dlops.normalization`)

##### `RmsNormFused(nn.Module)`
A high-performance implementation of Root Mean Square Normalization.
- Arguments:
  - `normalize_dim_shape` (int): The size of the dimension to normalize.
  - `eps` (float, optional): A small value to avoid division by zero. Defaults to `1e-6`.
  - `dtype` (torch.dtype, optional): The data type of the weight. Defaults to `torch.float16`.
  - `device` (torch.device, optional): The device of the weight. Defaults to `torch.device("cuda")`.
- Usage:

  ```python
  import torch
  from torchpp.dlops.normalization import RmsNormFused

  norm = RmsNormFused(normalize_dim_shape=256)
  x = torch.randn(16, 128, 256, dtype=torch.float16, device="cuda")
  output = norm(x)
  ```
##### `LayerNormFused(nn.Module)`
A high-performance implementation of Layer Normalization.
- Arguments:
  - `normalize_dim_shape` (int): The size of the dimension to normalize.
  - `eps` (float, optional): A small value to avoid division by zero. Defaults to `1e-6`.
  - `dtype` (torch.dtype, optional): The data type of the weight. Defaults to `torch.float16`.
  - `device` (torch.device, optional): The device of the weight. Defaults to `torch.device("cuda")`.
- Usage:

  ```python
  import torch
  from torchpp.dlops.normalization import LayerNormFused

  norm = LayerNormFused(normalize_dim_shape=256)
  x = torch.randn(16, 128, 256, dtype=torch.float16, device="cuda")
  output = norm(x)
  ```
#### Fused Linear Layers (`torchpp.dlops.linear`)

##### `LinearGELU(nn.Module)`
A Linear layer fused with a GELU activation.
- Arguments:
  - `in_features` (int): Size of each input sample.
  - `out_features` (int): Size of each output sample.
- Usage:

  ```python
  import torch
  from torchpp.dlops.linear import LinearGELU

  layer = LinearGELU(in_features=512, out_features=1024)
  x = torch.randn(32, 512, dtype=torch.float16, device="cuda")
  output = layer(x)
  ```
##### `LinearSILU(nn.Module)`
A Linear layer fused with a SiLU (Swish) activation.
- Arguments:
  - `in_features` (int): Size of each input sample.
  - `out_features` (int): Size of each output sample.
- Usage:

  ```python
  import torch
  from torchpp.dlops.linear import LinearSILU

  layer = LinearSILU(in_features=512, out_features=1024)
  x = torch.randn(32, 512, dtype=torch.float16, device="cuda")
  output = layer(x)
  ```
#### Rotary Position Embeddings (`torchpp.dlops.rope`)

##### `rope_apply(x, cos, sin)`
A functional implementation of Rotary Position Embeddings.
- Arguments:
  - `x` (torch.Tensor): The input tensor of shape `[batch, heads, seq_len, head_dim]`.
  - `cos` (torch.Tensor): The cosine cache of shape `[seq_len, head_dim]`.
  - `sin` (torch.Tensor): The sine cache of shape `[seq_len, head_dim]`.
- Returns: A `torch.Tensor` with RoPE applied.
- Usage:

  ```python
  import torch
  from torchpp.dlops.rope import rope_apply

  # Assume x, cos_cache, and sin_cache are precomputed
  # x: [bs, n_heads, seq_len, head_dim]
  # cos_cache, sin_cache: [seq_len, head_dim]
  output = rope_apply(x, cos_cache, sin_cache)
  ```
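The usage above assumes precomputed caches. A generic way to build `cos`/`sin` caches of shape `[seq_len, head_dim]` is sketched below; the frequency layout (the `base` value, the duplicated half-dimension) follows the common RoPE convention and is an assumption about what the kernel expects.

```python
import torch
from torchpp.dlops.rope import rope_apply

def build_rope_cache(seq_len, head_dim, base=10000.0, device="cuda", dtype=torch.float16):
    # Standard RoPE frequencies; the exact layout torchpp expects is assumed here.
    inv_freq = 1.0 / (base ** (torch.arange(0, head_dim, 2, device=device).float() / head_dim))
    t = torch.arange(seq_len, device=device).float()
    freqs = torch.outer(t, inv_freq)         # [seq_len, head_dim // 2]
    emb = torch.cat((freqs, freqs), dim=-1)  # [seq_len, head_dim]
    return emb.cos().to(dtype), emb.sin().to(dtype)

cos_cache, sin_cache = build_rope_cache(seq_len=128, head_dim=64)
x = torch.randn(2, 8, 128, 64, dtype=torch.float16, device="cuda")  # [bs, heads, seq, head_dim]
output = rope_apply(x, cos_cache, sin_cache)
```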
### Attention Mechanisms (`torchpp.attention`)

This module provides several efficient attention implementations that leverage `flash-attn`.

- `MultiHeadAttention(embed_dim, n_heads, ...)`: Standard Multi-Head Attention.
- `GroupedQueryAttention(d_in, num_heads, n_kv_heads, ...)`: Grouped-Query Attention.
- `MQA_FA(num_q_heads, embed_dim, ...)`: Multi-Query Attention.
- `SlidingWindowAttention(window_size, embed_dim, n_heads, ...)`: Sliding Window Attention.
- `CrossAttention(embed_dim, cross_dim, n_heads, ...)`: Cross-Attention.

All attention modules follow a similar pattern and are initialized with model dimensions and optional arguments like `qknorm` and `dtype`.
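A hedged usage sketch for `MultiHeadAttention` follows. The constructor arguments are the ones listed above; the import path and the forward signature (self-attention over a single `[batch, seq_len, embed_dim]` tensor) are assumptions. Note that `flash-attn` requires FP16/BF16 tensors on a CUDA device.

```python
import torch
from torchpp.attention import MultiHeadAttention  # import path assumed from the module name above

# embed_dim and n_heads come from the signature above; the call convention is assumed.
attn = MultiHeadAttention(embed_dim=512, n_heads=8)

x = torch.randn(4, 128, 512, dtype=torch.float16, device="cuda")  # [batch, seq_len, embed_dim]
out = attn(x)
```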
### Distributed Training (`torchpp.train.dist_train`)

#### `DistributedTrainer`
A high-level trainer class to handle the complexities of distributed training.
- Key Arguments:
  - `model` (nn.Module): The model to be trained.
  - `config` (TrainingConfig): A dataclass containing all training configurations (strategy, learning rate, checkpoint paths, etc.).
  - `optimizer` (torch.optim.Optimizer, optional): The optimizer to use. If not provided, an `AdamW` optimizer is created by default.
  - `lr_sched` (optional): A learning rate scheduler.
  - `loss_function` (Callable, optional): The loss function. Defaults to cross-entropy loss.
- Key Methods:
  - `train(train_dataloader, eval_dataloader, num_epochs)`: Starts the training loop.
  - `evaluate(eval_dataloader)`: Runs the evaluation loop.
  - `_save_checkpoint()`: Saves a model checkpoint.
  - `load_checkpoint(path)`: Loads a model checkpoint.
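A minimal end-to-end sketch is shown below. The `TrainingConfig` field names (`strategy`, `lr`, `checkpoint_dir`) and its import path are assumptions for illustration; consult the dataclass for the real fields.

```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset
from torchpp.train.dist_train import DistributedTrainer, TrainingConfig  # TrainingConfig location assumed

# Toy model and data, just to show the wiring.
model = nn.Sequential(nn.Linear(512, 1024), nn.GELU(), nn.Linear(1024, 10))
dataset = TensorDataset(torch.randn(1024, 512), torch.randint(0, 10, (1024,)))
train_loader = DataLoader(dataset, batch_size=32)
eval_loader = DataLoader(dataset, batch_size=32)

# Field names here are illustrative assumptions, not the real dataclass fields.
config = TrainingConfig(strategy="fsdp", lr=3e-4, checkpoint_dir="checkpoints/")

# Optimizer defaults to AdamW and the loss to cross-entropy if not provided.
trainer = DistributedTrainer(model=model, config=config)
trainer.train(train_loader, eval_loader, num_epochs=3)
trainer.evaluate(eval_loader)
```

In a multi-GPU setting this script would typically be launched with `torchrun`, though the exact launch requirements depend on the chosen strategy.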
## Future Vision & Roadmap

We have an ambitious roadmap to make `torchpp` an indispensable tool for PyTorch developers:
- Quantization Support: Integration of popular quantization techniques like AWQ, GPTQ, and others to further boost inference performance.
- Faster Training with Custom Backward Kernels: Implementation of custom backward passes for all our fused kernels to accelerate the training process.
- Expanded Kernel Library: Introduction of new fused kernels for Diffusion, Convolution-based models, and RNN-based models.