Debugging the Dreaded NaN

You’re training your latest AI model, anxiously watching as the loss steadily decreases, when suddenly, boom! Your logs are flooded with NaNs (Not a Number). Your model is irreparably corrupted and you’re left staring at your screen in despair. To make matters worse, the NaNs don’t appear consistently. Sometimes your model trains just fine; other times, it fails inexplicably. Sometimes it crashes immediately, sometimes after many days of training.

NaNs in deep learning workloads are among the most frustrating issues to encounter. And because they often appear sporadically, triggered by a particular combination of model state, input data, and stochastic factors, they can be incredibly difficult to reproduce and debug.

Given the considerable cost of training AI models and the potential waste caused by NaN failures, it is strongly recommended to have dedicated tools for capturing and analyzing NaN occurrences. In a previous post, we discussed the challenge of debugging NaNs in a TensorFlow training workload. We proposed an efficient scheme for capturing and reproducing NaNs and shared a sample TensorFlow implementation. In this post, we adopt and demonstrate a similar mechanism for debugging NaNs in PyTorch workloads. The general scheme is as follows:

On each training step:

  1. Save a copy of the training input batch.
  2. Check the gradients for NaN values. If any appear, save a checkpoint with the current model weights before the model is corrupted. Also, save the input batch and, if necessary, the stochastic state. Discontinue the training job.
  3. Reproduce and debug the NaN occurrence by loading the saved experiment state.
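
In plain PyTorch, these steps might look roughly like the following. This is a minimal sketch rather than the implementation developed in this post: the helper names `grads_contain_nan` and `train_with_nan_capture`, the checkpoint layout, and the toy model are all assumptions made for illustration.

```python
import copy
import os
import tempfile

import torch
import torch.nn as nn


def grads_contain_nan(model: nn.Module) -> bool:
    """Return True if any gradient tensor of the model contains a NaN."""
    return any(
        torch.isnan(p.grad).any().item()
        for p in model.parameters()
        if p.grad is not None
    )


def train_with_nan_capture(model, optimizer, loss_fn, batches, ckpt_path):
    """Train until a NaN gradient appears, then save the state and stop.

    Returns the index of the failing step, or None if no NaN was seen.
    """
    for step, batch in enumerate(batches):
        # 1. keep a copy of the batch and the RNG state it was processed with
        last_batch = copy.deepcopy(batch)
        rng_state = torch.random.get_rng_state()

        inputs, targets = batch
        optimizer.zero_grad()
        loss = loss_fn(model(inputs), targets)
        loss.backward()

        # 2. check the gradients before they are applied to the weights
        if grads_contain_nan(model):
            torch.save(
                {"model": model.state_dict(),
                 "batch": last_batch,
                 "torch_rng": rng_state},
                ckpt_path,
            )
            return step  # discontinue training
        optimizer.step()
    return None


# Toy run: the second batch contains NaN inputs, so step 1 fails
torch.manual_seed(0)
model = nn.Linear(4, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
good = (torch.randn(8, 4), torch.randn(8, 1))
bad = (torch.full((8, 4), float("nan")), torch.randn(8, 1))
ckpt_path = os.path.join(tempfile.mkdtemp(), "nan_state.pt")
failed_step = train_with_nan_capture(
    model, optimizer, nn.MSELoss(), [good, bad], ckpt_path
)
weights_finite = all(torch.isfinite(p).all().item() for p in model.parameters())
```

Because the check runs before `optimizer.step()`, the saved weights are still finite. Step 3 then amounts to loading the saved file, restoring the weights and random state, and replaying the stored batch.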

Although this scheme can be easily implemented in native PyTorch, we will take the opportunity to demonstrate some of the conveniences of PyTorch Lightning, a powerful open-source framework designed to streamline the development of machine learning (ML) models. Built on PyTorch, Lightning abstracts away many of the boilerplate components of an ML experiment, such as training loops, data distribution, logging, and more, enabling developers to focus on the core logic of their models.

To implement our NaN capturing scheme, we will use Lightning’s callback interface, a dedicated structure that enables inserting custom logic at specific points during the flow of execution.

Importantly, please do not view our choice of Lightning or any other tool or technique that we mention as an endorsement of its use. The code that we will share is intended for demonstrative purposes; please do not rely on its correctness or optimality.

Many thanks to Rom Maltser for his contributions to this post.

NaNCapture Callback

To implement our NaN capturing solution, we create a NaNCapture Lightning callback. The constructor receives a directory path for storing/loading checkpoints and sets up the NaNCapture state. We also define utilities for checking for NaNs, storing checkpoints, and halting the training job.

import os
import torch
from copy import deepcopy
import lightning.pytorch as pl

class NaNCapture(pl.Callback):

    def __init__(self, dirpath: str):
        # path to checkpoint
        self.dirpath = dirpath
        
        # update to True when Nan is identified
        self.nan_captured = False
        
        # stores a duplicate of the last batch
        self.last_batch = None
        self.batch_idx = None

    @staticmethod
    def contains_nan(tensor):
        return torch.isnan(tensor).any().item()
        # alternatively check for finite (also catches +/-inf)
        # return not torch.isfinite(tensor).all().item()

    @staticmethod
    def halt_training(trainer):
        trainer.should_stop = True
        # communicate stop command to all other ranks
        trainer.strategy.reduce_boolean_decision(trainer.should_stop,
                                                 all=False)

    def save_ckpt(self, trainer):
        os.makedirs(self.dirpath, exist_ok=True)
        # include trainer.global_rank to avoid conflict
        filename = f"nan_checkpoint_rank_{trainer.global_rank}.ckpt"
        full_path = os.path.join(self.dirpath, filename)
        print(f"saving ckpt to {full_path}")
        trainer.save_checkpoint(full_path, False)

Callback Function: on_train_batch_start

We begin by implementing the on_train_batch_start hook to store a copy of each input batch. In the event of a NaN, this batch will be stored in the checkpoint.
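
A minimal sketch of this hook is shown below, modeled on the full callback listing that appears later in this post. To keep the snippet runnable without Lightning installed, it is written as a plain class (the name `BatchCapture` is hypothetical). In practice the method would live on the NaNCapture callback, and Lightning would supply the `trainer` and `pl_module` arguments.

```python
from copy import deepcopy


class BatchCapture:
    """Hypothetical stand-in for the batch-capturing part of NaNCapture."""

    def __init__(self):
        self.nan_captured = False
        self.last_batch = None
        self.batch_idx = None

    def on_train_batch_start(self, trainer, pl_module, batch, batch_idx):
        # Once a NaN has been captured, keep the stored batch untouched
        if not self.nan_captured:
            # deepcopy so later in-place changes to the batch do not leak in
            self.last_batch = deepcopy(batch)
            self.batch_idx = batch_idx


# Usage with a plain Python batch; tensors are deep-copied the same way
capture = BatchCapture()
batch = {"x": [1.0, 2.0], "y": [0, 1]}
capture.on_train_batch_start(None, None, batch, batch_idx=7)
batch["x"][0] = float("nan")  # mutate the original after the hook has run
```

The deep copy costs some memory and time on every step, but it ensures that the stored batch reflects exactly what the model saw, even if the batch is modified in place later in the step.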

Callback Function: on_before_optimizer_step

Next we implement the on_before_optimizer_step hook. Here, we check for NaN entries in all the gradient tensors. If found, we store a checkpoint with the uncorrupted model weights and halt the training.

    def on_before_optimizer_step(self, trainer, pl_module, optimizer):
        if not self.nan_captured:
            # Check if gradients contain NaN
            grads = [p.grad.view(-1) for p in pl_module.parameters()
                     if p.grad is not None]
            all_grads = torch.cat(grads)
            if self.contains_nan(all_grads):
                print("nan found")
                # mark the event so that state_dict() includes the batch
                self.nan_captured = True
                self.save_ckpt(trainer)
                self.halt_training(trainer)

Capturing the Training State

To enable reproducibility, we include the NaNCapture state within the checkpoint by appending it to the training state dictionary. Lightning provides dedicated utilities for saving and loading a callback state:

    def state_dict(self):
        d = {"nan_captured": self.nan_captured}
        if self.nan_captured:
            d["last_batch"] = self.last_batch
        return d


    def load_state_dict(self, state_dict):
        self.nan_captured = state_dict.get("nan_captured", False)
        if self.nan_captured:
            self.last_batch = state_dict["last_batch"]

Reproducing the NaN Occurrence

We’ve described how our NaNCapture callback can be used to store the training state that resulted in a NaN, but how do we reload this state in order to reproduce the issue and debug it? To accomplish this, we leverage Lightning’s dedicated data loading class, LightningDataModule.

DataModule Function: on_before_batch_transfer

In the code block below, we extend the LightningDataModule class to allow injecting a fixed training input batch. This is achieved by overriding the on_before_batch_transfer hook, as shown below:

from lightning.pytorch import LightningDataModule

class InjectableDataModule(LightningDataModule):

    def __init__(self):
        super().__init__()
        self.cached_batch = None

    def set_custom_batch(self, batch):
        self.cached_batch = batch

    def on_before_batch_transfer(self, batch, dataloader_idx):
        # compare to None: truth-testing a multi-element tensor raises an error
        if self.cached_batch is not None:
            return self.cached_batch
        return batch

Callback Function: on_train_start

The final step is modifying the on_train_start hook of our NaNCapture callback to inject the stored training batch into the LightningDataModule.

    def on_train_start(self, trainer, pl_module):
        if self.nan_captured:
            datamodule = trainer.datamodule
            datamodule.set_custom_batch(self.last_batch)

In the next section we will demonstrate the end-to-end solution using a toy example.

Toy Example

To test our new callback, we create a resnet50-based image classification model with a loss function deliberately designed to trigger NaN occurrences.

Instead of using the standard CrossEntropy loss, we compute binary_cross_entropy_with_logits for each class independently and divide the result by the number of samples belonging to that class. Inevitably, we will encounter a batch in which one or more classes are missing, resulting in a divide-by-zero operation, leading to NaN values and corrupting the model.
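
The failure mode can be seen in isolation with a few lines of PyTorch. The sketch below uses a made-up batch of four samples and three classes (these sizes and labels are assumptions for illustration). Note that the division by a zero count first produces an infinite loss and infinite gradients; NaNs arise once such infinities are combined, for example as inf - inf or inf * 0 further back in the network.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
num_classes = 3
logits = torch.randn(4, num_classes, requires_grad=True)
labels = torch.tensor([0, 0, 1, 1])  # class 2 never appears in this batch

c = 2
count = torch.count_nonzero(labels == c)     # zero samples of class 2
target = torch.where(labels == c, 1.0, 0.0)
loss_sum = F.binary_cross_entropy_with_logits(
    logits[..., c], target, reduction="sum"
)
class_loss = loss_sum / count                # positive value divided by zero

class_loss.backward()
grads_finite = torch.isfinite(logits.grad).all().item()

# Combining infinities is what produces an actual NaN
inf_value = class_loss.detach()
nan_value = inf_value - inf_value
```

This is also why the commented-out isfinite variant of the NaN check can be worth considering: it flags infinite gradients before they turn into NaNs.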

The implementation below follows Lightning’s introductory tutorial.

import lightning.pytorch as pl
import torch
import torchvision
import torch.nn.functional as F

num_classes = 20


# define a lightning module
class ResnetModel(pl.LightningModule):
    def __init__(self):
        """Initializes a new instance of the ResnetModel class."""
        super().__init__()
        self.model = torchvision.models.resnet50(num_classes=num_classes)

    def forward(self, x):
        return self.model(x)

    def training_step(self, batch, batch_nb):
        x, y = batch
        outputs = self(x)
        # uncomment for default loss
        # return F.cross_entropy(outputs, y)
        
        # calculate binary_cross_entropy for every class individually
        losses = []
        for c in range(num_classes):
            count = torch.count_nonzero(y==c)
            masked = torch.where(y==c, 1., 0.)
            loss = F.binary_cross_entropy_with_logits(
                outputs[..., c],
                masked,
                reduction='sum'
            )
            mean_loss = loss/count # could lead to NaN
            losses.append(mean_loss)
        total_loss = torch.stack(losses).mean()
        return total_loss

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=0.02)

We define a synthetic dataset and encapsulate it in our InjectableDataModule class:

import os
import random
from torch.utils.data import Dataset, DataLoader

batch_size = 128
num_steps = 800

# A dataset with random images and labels
class FakeDataset(Dataset):
    def __len__(self):
        return batch_size*num_steps

    def __getitem__(self, index):
        rand_image = torch.randn([3, 224, 224], dtype=torch.float32)
        label = torch.tensor(random.randint(0, num_classes-1),
                             dtype=torch.int64)
        return rand_image, label



# define a lightning datamodule
class FakeDataModule(InjectableDataModule):

    def train_dataloader(self):
        dataset = FakeDataset()
        return DataLoader(
            dataset,
            batch_size=batch_size,
            num_workers=os.cpu_count(),
            pin_memory=True
        )

Finally, we initialize a Lightning Trainer with our NaNCapture callback and call trainer.fit with our Lightning module and Lightning DataModule.

import time

if __name__ == "__main__":

    # Initialize a lightning module
    lit_module = ResnetModel()

    # Initialize a DataModule
    mnist_data = FakeDataModule()

    # Train the model
    ckpt_dir = "./ckpt_dir"
    trainer = pl.Trainer(
        max_epochs=1,
        callbacks=[NaNCapture(ckpt_dir)]
    )

    ckpt_path = None
    
    # check if nan ckpt exists
    if os.path.isdir(ckpt_dir):
        dir_contents = [os.path.join(ckpt_dir, f)
                        for f in os.listdir(ckpt_dir)]
        ckpts = [f for f in dir_contents
                 if os.path.isfile(f) and f.endswith('.ckpt')]
        if ckpts:
            ckpt_path = ckpts[0]

    t0 = time.perf_counter()
    trainer.fit(lit_module, mnist_data, ckpt_path=ckpt_path)
    print(f"total runtime: {time.perf_counter() - t0}")

After a number of training steps, a NaN event will occur. At this point a checkpoint is saved with the full training state and the training is halted.

When the script is run again, the exact state that caused the NaN is reloaded, allowing us to easily reproduce the issue and debug its root cause.
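
Once the failing step can be replayed on demand, one possible way to locate the offending operation is PyTorch’s autograd anomaly detection. This goes beyond what the callback itself does, and the tiny function below is an assumption chosen only because it is known to yield a NaN gradient (the derivative of sqrt at zero is infinite, and it is multiplied by a zero upstream gradient). With `torch.autograd.detect_anomaly()` active, the backward pass raises an error that names the backward function that returned NaN.

```python
import torch

x = torch.zeros(1, requires_grad=True)

error_message = None
with torch.autograd.detect_anomaly():
    # forward pass is finite: sqrt(0) * 0 == 0
    y = (torch.sqrt(x) * 0.0).sum()
    try:
        # backward computes 0 / (2 * sqrt(0)) = 0 / 0 -> NaN
        y.backward()
    except RuntimeError as err:
        error_message = str(err)

print(error_message)
```

Anomaly detection slows training considerably, so it is usually enabled only while replaying the captured step rather than during regular training.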

Performance Overhead

To evaluate the impact of our NaNCapture callback on runtime performance, we modified our experiment to use CrossEntropyLoss (to avoid NaNs) and measured the average throughput when running with and without the NaNCapture callback. The experiments were conducted on an NVIDIA L40S GPU, with a PyTorch 2.5.1 Docker image.

[Figure: Overhead of NaNCapture Callback (by Author)]

For our toy model, the NaNCapture callback adds a minimal 1.5% overhead to the runtime performance, a small price to pay for the valuable debugging capabilities it provides.

Naturally, the actual overhead will depend upon the specifics of the model and runtime environment.

How to Handle Stochasticity

The solution we have described thus far will succeed in reproducing the training state only if the model does not include any randomness. However, introducing stochasticity into the model definition is often critical for convergence. A common example of a stochastic layer is torch.nn.Dropout.
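
To see why the random state matters, consider the short sketch below (the tensor size and dropout probability are arbitrary choices for illustration). Two consecutive calls to a Dropout layer produce different masks, while restoring the generator state captured beforehand replays the first mask exactly. The example runs on the CPU generator; CUDA devices keep their own generator state.

```python
import torch

dropout = torch.nn.Dropout(p=0.5)  # modules start out in training mode
x = torch.ones(1000)

state = torch.random.get_rng_state()  # capture the CPU generator state
first = dropout(x)
second = dropout(x)                   # generator has advanced: new mask

torch.random.set_rng_state(state)     # restore the captured state
replayed = dropout(x)                 # same mask as `first`
```

If the NaN is triggered only by a particular dropout mask, replaying the batch without restoring this state is unlikely to reproduce it.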

You may find that your NaN event depends on the precise state of randomness when the failure occurred. Consequently, we would like to enhance our NaNCapture callback to capture and restore the random state at the point of failure. The random state is determined by a number of libraries. In the code block below, we attempt to capture the full state of randomness:

import os
import torch
import random
import numpy as np
from copy import deepcopy
import lightning.pytorch as pl

class NaNCapture(pl.Callback):

    def __init__(self, dirpath: str):
        # path to checkpoint
        self.dirpath = dirpath
        
        # update to True when Nan is identified
        self.nan_captured = False
        
        # stores a duplicate of the last batch
        self.last_batch = None
        self.batch_idx = None

        # rng state
        self.rng_state = {
            "torch": None,
            "torch_cuda": None,
            "numpy": None,
            "random": None
        }

    @staticmethod
    def contains_nan(tensor):
        return torch.isnan(tensor).any().item()
        # alternatively check for finite (also catches +/-inf)
        # return not torch.isfinite(tensor).all().item()

    @staticmethod
    def halt_training(trainer):
        trainer.should_stop = True
        trainer.strategy.reduce_boolean_decision(trainer.should_stop,
                                                 all=False)

    def save_ckpt(self, trainer):
        os.makedirs(self.dirpath, exist_ok=True)
        # include trainer.global_rank to avoid conflict
        filename = f"nan_checkpoint_rank_{trainer.global_rank}.ckpt"
        full_path = os.path.join(self.dirpath, filename)
        print(f"saving ckpt to {full_path}")
        trainer.save_checkpoint(full_path, False)

    def on_train_start(self, trainer, pl_module):
        if self.nan_captured:
            # inject batch
            datamodule = trainer.datamodule
            datamodule.set_custom_batch(self.last_batch)

    def on_train_batch_start(self, trainer, pl_module, batch, batch_idx):
        if self.nan_captured:
            # restore random state
            torch.random.set_rng_state(self.rng_state["torch"])
            torch.cuda.set_rng_state_all(self.rng_state["torch_cuda"])
            np.random.set_state(self.rng_state["numpy"])
            random.setstate(self.rng_state["random"])
        else:
            # capture current batch
            self.last_batch = deepcopy(batch)
            self.batch_idx = batch_idx
    
            # capture current random state
            self.rng_state["torch"] = torch.random.get_rng_state()
            self.rng_state["torch_cuda"] = torch.cuda.get_rng_state_all()
            self.rng_state["numpy"] = np.random.get_state()
            self.rng_state["random"] = random.getstate()
    
    def on_before_optimizer_step(self, trainer, pl_module, optimizer):
        if not self.nan_captured:
            # Check if gradients contain NaN
            grads = [p.grad.view(-1) for p in pl_module.parameters()
                     if p.grad is not None]
            all_grads = torch.cat(grads)
            if self.contains_nan(all_grads):
                print("nan found")
                # mark the event so that state_dict() includes the batch
                # and the random state
                self.nan_captured = True
                self.save_ckpt(trainer)
                self.halt_training(trainer)

    def state_dict(self):
        d = {"nan_captured": self.nan_captured}
        if self.nan_captured:
            d["last_batch"] = self.last_batch
            d["rng_state"] = self.rng_state
        return d

    def load_state_dict(self, state_dict):
        self.nan_captured = state_dict.get("nan_captured", False)
        if self.nan_captured:
            self.last_batch = state_dict["last_batch"]
            self.rng_state = state_dict["rng_state"]

Importantly, setting the random state may not guarantee full reproducibility. The GPU owes its power to its massive parallelism. In some GPU operations, multiple threads may read or write concurrently to the same memory locations, resulting in nondeterminism. PyTorch allows for some control over this via its use_deterministic_algorithms setting, but this may impact the runtime performance. Moreover, there is a possibility that the NaN event will not be reproduced once this configuration setting is changed. Please see the PyTorch documentation on reproducibility for more details.
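
As a rough illustration, a debugging run might request deterministic behavior as follows. This is a sketch rather than a complete recipe: which settings are needed depends on the operations your model uses, and the CUBLAS_WORKSPACE_CONFIG value shown is one of the values suggested in the PyTorch reproducibility notes for CUDA matrix operations.

```python
import os

import torch

# Needed by some CUDA (cuBLAS) operations when determinism is requested;
# it has to be set before those operations run.
os.environ.setdefault("CUBLAS_WORKSPACE_CONFIG", ":4096:8")

# Raise an error when an operation has no deterministic implementation
torch.use_deterministic_algorithms(True)

# Avoid cuDNN autotuning, which can select different kernels between runs
torch.backends.cudnn.benchmark = False

deterministic_enabled = torch.are_deterministic_algorithms_enabled()
```

With this setting enabled, operations that lack a deterministic implementation raise a RuntimeError instead of silently running nondeterministically, which at least makes the remaining sources of nondeterminism visible.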

Summary

Encountering NaN failures is one of the most discouraging events that can occur in machine learning development. These errors not only waste valuable computation and development resources, but often indicate fundamental issues in the model architecture or experiment design. Due to their sporadic, sometimes elusive nature, debugging NaN failures can be a nightmare.

This post introduced a proactive approach for capturing and reproducing NaN errors using a dedicated Lightning callback. The solution we shared is a proposal that can be modified and extended for your specific use case.

While this solution may not address every possible NaN scenario, it significantly reduces debugging time when applicable, potentially saving developers countless hours of frustration and wasted effort.
