Scale Biology Transformer Models with PyTorch and NVIDIA BioNeMo Recipes



Training models with billions or trillions of parameters demands advanced parallel computing. Researchers must decide how to combine parallelism strategies, select the most efficient accelerated libraries, and integrate low-precision formats such as FP8 and FP4, all without sacrificing speed or memory.

Accelerated frameworks exist to help, but adapting to their specific methodologies can significantly slow R&D, as users typically must learn an entirely new codebase.

NVIDIA BioNeMo Recipes simplify and speed up this process by lowering the barrier to entry for large-scale model training. Using step-by-step guides built on familiar frameworks like PyTorch and Hugging Face (HF), we show how integrating accelerated libraries such as NVIDIA Transformer Engine (TE) unlocks speed and memory efficiency, and how techniques like Fully Sharded Data Parallel (FSDP) and Context Parallelism scale performance further.

In this blog post, we demonstrate how to speed up transformer-style AI models for biology by taking the Hugging Face ESM-2 protein language model with a native PyTorch training loop and:

  1. Accelerating it with TE. 
  2. Integrating it with FSDP2 for auto-parallelism (a minimal sketch follows this list). 
  3. Using sequence packing to achieve even greater performance.
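For step 2, the FSDP2 integration itself is only a few lines. The snippet below is a minimal sketch rather than the recipe’s exact launch code: it assumes a torchrun launch, PyTorch 2.6 or later (where the FSDP2 fully_shard API lives under torch.distributed.fsdp), and a hypothetical build_encoder() helper standing in for the TE-based encoder built in the next section.

import os
import torch
import torch.distributed as dist
from torch.distributed.fsdp import fully_shard  # FSDP2 API, PyTorch >= 2.6

# Initialize the process group created by torchrun and pin each rank to a GPU.
dist.init_process_group(backend="nccl")
torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))

# build_encoder() is a hypothetical helper returning an encoder with a
# .layers ModuleList, such as the TE-based encoder shown below.
model = build_encoder().cuda()

# Shard each transformer block first, then the root module, so parameters are
# gathered layer by layer during forward/backward rather than all at once.
for layer in model.layers:
    fully_shard(layer)
fully_shard(model)

# The native PyTorch training loop (forward, loss, backward, optimizer step)
# then runs unchanged; FSDP2 handles parameter sharding and gradient reduction.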

All you need to get started is PyTorch, NVIDIA CUDA 12.8, and the following resources: 

Integrating Transformer Engine into ESM-2

TE enables significant performance gains by optimizing transformer computations, particularly on NVIDIA GPUs. It can be integrated into existing training pipelines without requiring a complete overhaul of your datasets, data loaders, or trainers. This section shows how to incorporate TE into a model like ESM-2, drawing inspiration from the BioNeMo recipes.

In most use cases, using the ready-made TransformerLayer module from TE is straightforward. It encapsulates all fused TE operations and best practices into a single drop-in module, reducing boilerplate code and setup. The following snippet shows how we integrated TE into ESM-2. The full implementation can be found in the NVEsmEncoder class definition in bionemo-recipes.

import torch
import transformer_engine.pytorch as te
from transformer_engine.common.recipe import Format, DelayedScaling

class MyEsmEncoder(torch.nn.Module):
    def __init__(self, num_layers, hidden_size, ffn_hidden_size, num_heads):
        super().__init__()
        self.layers = torch.nn.ModuleList([
            te.TransformerLayer(
                hidden_size=hidden_size,
                ffn_hidden_size=ffn_hidden_size,
                num_attention_heads=num_heads,
                layer_type="encoder",
                self_attn_mask_type="padding",
                attn_input_format="bshd", # or "thd"; see the sequence packing section below
                window_size=(-1, -1), # disable windowed attention
            ) for _ in range(num_layers)
        ])
        # Optionally add embedding, head, etc.

    def forward(self, x, attention_mask=None):
        for layer in self.layers:
            x = layer(x, attention_mask=attention_mask)
        return x

# Layer configuration
layer_num = 8
hidden_size = 4096
sequence_length = 2048
batch_size = 4
ffn_hidden_size = 16384
num_attention_heads = 32
dtype = torch.bfloat16

# Synthetic data (batch, seq, hidden) for bshd format
x = torch.rand(batch_size, sequence_length, hidden_size).cuda().to(dtype=dtype)
# Boolean padding mask in TE convention: True marks positions to mask out
attention_mask = torch.zeros(batch_size, 1, 1, sequence_length, dtype=torch.bool).cuda()
myEsm = MyEsmEncoder(layer_num, hidden_size, ffn_hidden_size, num_attention_heads)
myEsm.to(dtype=dtype).cuda()

fp8_recipe = DelayedScaling(fp8_format=Format.HYBRID)
with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
    y = myEsm(x, attention_mask=attention_mask)

If your architecture deviates from a standard Transformer block, TE can still be integrated at the layer level. The core idea is to replace standard PyTorch modules (e.g., nn.Linear, nn.LayerNorm) with their TE counterparts and use FP8 autocasting to achieve maximum performance gains. TE provides several alternative implementations of common layers, such as Linear, the fused LayerNormLinear, and attention modules like DotProductAttention and MultiheadAttention. For a complete list of supported modules, check the TE documentation.
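As an illustration, the following sketch builds a hypothetical MLP block (not code from the recipe) by swapping in TE modules: te.LayerNormLinear fuses the normalization and the first projection, te.Linear replaces the second projection, and FP8 autocasting wraps the block exactly as in the TransformerLayer example above.

import torch
import transformer_engine.pytorch as te
from transformer_engine.common.recipe import Format, DelayedScaling

class MyMlpBlock(torch.nn.Module):
    def __init__(self, hidden_size, ffn_hidden_size):
        super().__init__()
        # Fused LayerNorm + up-projection replaces nn.LayerNorm + nn.Linear
        self.norm_fc1 = te.LayerNormLinear(hidden_size, ffn_hidden_size)
        self.act = torch.nn.GELU()
        # Drop-in replacement for nn.Linear on the down-projection
        self.fc2 = te.Linear(ffn_hidden_size, hidden_size)

    def forward(self, x):
        return self.fc2(self.act(self.norm_fc1(x)))

block = MyMlpBlock(hidden_size=1024, ffn_hidden_size=4096).to(dtype=torch.bfloat16).cuda()
x = torch.rand(4, 2048, 1024, dtype=torch.bfloat16, device="cuda")

# The same FP8 recipe used for TransformerLayer applies to individual TE layers.
fp8_recipe = DelayedScaling(fp8_format=Format.HYBRID)
with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
    y = block(x)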

Efficient sequence packing

Standard input data formats can be inefficient when samples have varying sequence lengths. For example, ESM-2 pretraining with a context length of 1,024 can consist of around 60% padding tokens, wasting compute on tokens that don’t participate in the model’s attention mechanism. Internally, networks typically represent the hidden state of input sequences as a tensor with four dimensions: [batch size (B), max sequence length (S), number of attention heads (H), head hidden dimension (D)], or BSHD.

Instead, modern attention kernels let users provide packed inputs without padding tokens, using index vectors to mark the boundaries between input sequences. Here, hidden states are represented by a flattened tensor of size [flattened input tokens (T), number of attention heads (H), head hidden dimension (D)], or THD. Figure 1 shows this format change, which results in lower memory usage and faster token throughput by removing padding tokens (gray).

Figure 1. BSHD vs. THD “sequence‑packed” input: converting padded BSHD tensors to THD using cumulative sequence lengths (cu_seq_lens)

TE makes this optimization relatively easy by adding an attn_input_format parameter to relevant layers, whose forward pass then accepts standard flash-attention-style cumulative sequence length keyword arguments (cu_seq_lens_q). These can be generated using THD-aware collators, such as Hugging Face’s DataCollatorWithFlattening, or the masking version implemented in BioNeMo Recipes.

def sequence_pack(input_ids, labels):
    # input_ids is a list of variable-length token sequences: [(S1,), (S2,), ..., (SN,)],
    # i.e., a ragged (B, S) batch
    # Flatten and track sequence boundaries

    # Determine the length of every sequence    
    sample_lengths = [len(sample) for sample in input_ids]

    # Flatten the input_ids and labels
    flat_input_ids = [token for sample in input_ids for token in sample]
    flat_labels = [label for sample in labels for label in sample]

    # Create a list of cumulative sums marking where each sequence starts/stops
    # Note: for self-attention, cu_seqlens_q and cu_seqlens_kv are the same
    cu_seqlens = torch.cumsum(torch.tensor([0] + sample_lengths), dim=0, dtype=torch.int32)

    max_length = max(sample_lengths)
    
    return {
        "input_ids": torch.tensor(flat_input_ids, dtype=torch.int64),
        "labels": torch.tensor(flat_labels, dtype=torch.int64),
        # These are the same kwargs used by `flash_attn_varlen_func`, etc.
        "cu_seqlens_q": cu_seqlens,
        "cu_seqlens_kv": cu_seqlens,
        "max_length_q": max_length,
        "max_length_kv": max_length,
    }
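To show how this collator output is consumed, here is a minimal usage sketch. The toy input_ids and labels values are made up, the embedding lookup is stubbed with random hidden states, and the cu_seqlens_*/max_seqlen_* keyword arguments are assumed to be accepted by TransformerLayer.forward as in recent TE releases; check the documentation for your TE version.

import torch
import transformer_engine.pytorch as te

# Toy ragged batch; in practice this comes from your tokenizer/collator.
input_ids = [[5, 8, 9], [5, 8, 9, 11, 12], [5, 8]]
labels = [[-100, 8, 9], [-100, 8, 9, 11, 12], [-100, 8]]
packed = sequence_pack(input_ids, labels)  # cu_seqlens = [0, 3, 8, 10]

hidden_size, ffn_hidden_size, num_heads = 1024, 4096, 16
thd_layer = te.TransformerLayer(
    hidden_size=hidden_size,
    ffn_hidden_size=ffn_hidden_size,
    num_attention_heads=num_heads,
    layer_type="encoder",
    self_attn_mask_type="padding",
    attn_input_format="thd",  # hidden states are (T, hidden_size), no padding
).to(dtype=torch.bfloat16).cuda()

# Stand-in for the embedding lookup: one hidden vector per packed token.
total_tokens = packed["input_ids"].shape[0]
hidden = torch.rand(total_tokens, hidden_size, dtype=torch.bfloat16, device="cuda")

# THD inputs need no dense attention mask; boundaries come from cu_seqlens.
out = thd_layer(
    hidden,
    cu_seqlens_q=packed["cu_seqlens_q"].cuda(),
    cu_seqlens_kv=packed["cu_seqlens_kv"].cuda(),
    max_seqlen_q=packed["max_length_q"],
    max_seqlen_kv=packed["max_length_kv"],
)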

TE and sequence packing on/off performance 

Figure 2. TE and sequence packing on/off performance

Figure 2 shows the performance comparison, with a significant uplift in token throughput when TE is employed. This demonstrates TE’s ability to maximize the computational efficiency of your NVIDIA GPUs.

EvolutionaryScale integrated Transformer Engine across their next-generation models as well:

“ESM3 is the largest foundation model trained on biological data. Integrating the NVIDIA Transformer Engine was crucial to training it at this 98B parameter scale with high throughput and GPU utilization,” said Tom Sercu, co-founder and VP of Engineering at EvolutionaryScale. “The precision and speed of FP8 acceleration, combined with optimized kernels for fused layers, allow us to push the boundaries of compute and model scale across NVIDIA GPUs. This results in emergent understanding of biology in our frontier models for the scientific community.”

Hugging Face interoperability

One of the key benefits of TE is its interoperability with existing machine learning ecosystems, including popular libraries like Hugging Face. This means you can take advantage of TE’s performance benefits even when working with models loaded from the Hugging Face Transformers library.

TE layers can be embedded directly inside a Hugging Face Transformers PreTrainedModel and are fully compatible with AutoModel.from_pretrained. See the NVIDIA BioNeMo Collection on the Hugging Face Hub for pre-optimized models.

The process typically involves loading your Hugging Face model, then carefully identifying and replacing its standard PyTorch layers (such as nn.Linear, nn.LayerNorm, and nn.MultiheadAttention) with their TE-optimized counterparts. This often requires renaming some layers or writing a custom model wrapper to ensure the TE layers are correctly integrated into the model’s forward pass.
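As a minimal sketch of that process (not the BioNeMo recipe itself), the snippet below loads the stock facebook/esm2_t33_650M_UR50D checkpoint and swaps every nn.Linear for te.Linear, copying the pretrained weights across; a full integration, like the one in bionemo-recipes, also replaces LayerNorm and attention modules and handles the renaming and state-dict mapping described above.

import torch
import transformer_engine.pytorch as te
from transformers import AutoModelForMaskedLM

model = AutoModelForMaskedLM.from_pretrained("facebook/esm2_t33_650M_UR50D")

def swap_linear_layers(module):
    # Recursively replace nn.Linear children with te.Linear, preserving weights.
    for name, child in module.named_children():
        if isinstance(child, torch.nn.Linear):
            te_linear = te.Linear(
                child.in_features, child.out_features, bias=child.bias is not None
            )
            with torch.no_grad():
                te_linear.weight.copy_(child.weight)
                if child.bias is not None:
                    te_linear.bias.copy_(child.bias)
            setattr(module, name, te_linear)
        else:
            swap_linear_layers(child)

swap_linear_layers(model)
model = model.to(dtype=torch.bfloat16).cuda()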

Get started

Our mission with BioNeMo Recipes is to make acceleration and scaling accessible to all foundation model builders. To help us build a more powerful and practical toolkit, we want to hear from you. We encourage you to try out the recipes and contribute by submitting a pull request or opening an issue on our GitHub.


