This post is a sequel to Optimizing Data Transfer in AI/ML Workloads, where we demonstrated the use of NVIDIA Nsight™ Systems (nsys) to study and solve the common data-loading bottleneck — occurrences where the GPU idles while it waits for input data from the CPU. In this post we focus our attention on data travelling in the opposite direction, from the GPU device to the CPU host. More specifically, we address AI/ML inference workloads where the dimension of the output returned by the model is comparatively high. Common examples include: 1) running a scene segmentation (per-pixel labeling) model on batches of high-resolution images and 2) capturing high-dimensional feature embeddings of input sequences using an encoder model (e.g., to create a vector database). Both examples involve executing a model on an input batch and then copying the output tensor from the GPU to the CPU for further processing, storage, and/or over-the-network communication.
GPU-to-CPU memory copies of the model output typically receive much less attention in optimization tutorials than the CPU-to-GPU copies that feed the model (e.g., see here). But their potential impact on model efficiency and execution costs can be just as detrimental. Furthermore, while optimizations to CPU-to-GPU data loading are well documented and straightforward to implement, optimizing data copies in the opposite direction requires a bit more manual labor.
In this post we'll apply the same strategy we used in our previous post: we'll define a toy model and use the nsys profiler to identify and solve performance bottlenecks. We'll run our experiments on an Amazon EC2 g6e.2xlarge instance (with an NVIDIA L40S GPU) running an AWS Deep Learning (Ubuntu 24.04) AMI with PyTorch (2.8), the nsys-cli profiler (version 2025.6.1), and the NVIDIA Tools Extension (NVTX) library.
Disclaimers
The code we'll share is intended for demonstrative purposes; please don't rely on its correctness or optimality. Please don't interpret our use of any library, tool, or platform as an endorsement of its use. The impact of the optimizations we'll cover can vary greatly based on the details of the model and the runtime environment. Please be sure to assess their effect on your own use case before adopting them.
Many thanks to Yitzhak Levi and Gilad Wasserman for their contributions to this post.
A Toy PyTorch Model
We introduce a batched inference script that performs image segmentation on a synthetic dataset using a DeepLabV3 model with a ResNet-50 backbone. The model outputs are copied to the CPU for post-processing and storage. We wrap the various portions of the inference step with color-coded nvtx annotations:
import time, torch, nvtx
from torch.utils.data import Dataset, DataLoader
from torch.cuda import profiler
from torchvision.models.segmentation import deeplabv3_resnet50
DEVICE = "cuda"
WARMUP_STEPS = 10
PROFILE_STEPS = 3
COOLDOWN_STEPS = 1
TOTAL_STEPS = WARMUP_STEPS + PROFILE_STEPS + COOLDOWN_STEPS
BATCH_SIZE = 64
TOTAL_SAMPLES = TOTAL_STEPS * BATCH_SIZE
IMG_SIZE = 512
N_CLASSES = 21
NUM_WORKERS = 8
ASYNC_DATALOAD = True
# A synthetic dataset that returns random images
class FakeDataset(Dataset):
    def __len__(self):
        return TOTAL_SAMPLES
    def __getitem__(self, index):
        img = torch.randn((3, IMG_SIZE, IMG_SIZE))
        return img

# utility class for prefetching data to the GPU
class DataPrefetcher:
    def __init__(self, loader):
        self.loader = iter(loader)
        self.stream = torch.cuda.Stream()
        self.next_batch = None
        self.preload()
    def preload(self):
        try:
            data = next(self.loader)
            # copy to the GPU on a side stream so it overlaps with compute
            with torch.cuda.stream(self.stream):
                next_data = data.to(DEVICE, non_blocking=ASYNC_DATALOAD)
                self.next_batch = next_data
        except StopIteration:
            self.next_batch = None
    def __iter__(self):
        return self
    def __next__(self):
        # make sure the prefetched batch is visible to the default stream
        torch.cuda.current_stream().wait_stream(self.stream)
        data = self.next_batch
        self.preload()
        return data
model = deeplabv3_resnet50(weights_backbone=None).to(DEVICE).eval()

data_loader = DataLoader(
    FakeDataset(),
    batch_size=BATCH_SIZE,
    num_workers=NUM_WORKERS,
    pin_memory=ASYNC_DATALOAD
)
data_iter = DataPrefetcher(data_loader)

def synchronize_all():
    torch.cuda.synchronize()

def to_cpu(output):
    return output.cpu()

def process_output(batch_id, logits):
    # do some post processing on output
    with open('/dev/null', 'wb') as f:
        f.write(logits.numpy().tobytes())
with torch.inference_mode():
    for i in range(TOTAL_STEPS):
        if i == WARMUP_STEPS:
            synchronize_all()
            start_time = time.perf_counter()
            profiler.start()
        elif i == WARMUP_STEPS + PROFILE_STEPS:
            synchronize_all()
            profiler.stop()
            end_time = time.perf_counter()
        with nvtx.annotate(f"Batch {i}", color="blue"):
            with nvtx.annotate("get batch", color="red"):
                batch = next(data_iter)
            with nvtx.annotate("compute", color="green"):
                output = model(batch)
            with nvtx.annotate("copy to CPU", color="yellow"):
                output_cpu = to_cpu(output['out'])
            with nvtx.annotate("process output", color="cyan"):
                process_output(i, output_cpu)

total_time = end_time - start_time
throughput = PROFILE_STEPS / total_time
print(f"Throughput: {throughput:.2f} steps/sec")
Note the inclusion of all of the CPU-to-GPU data-loading optimizations discussed in our previous post.
We run the following command to capture an nsys profile trace:
nsys profile \
  --capture-range=cudaProfilerApi \
  --trace=cuda,nvtx,osrt \
  --output=baseline \
  python batch_infer.py
This results in a trace file that we copy over to our development machine for analysis.
To measure the inference throughput, we increase the number of steps to 100. The average throughput of our baseline experiment is 0.45 steps per second. In the following sections we'll use the nsys profile traces to incrementally improve this result.
Baseline Performance Analysis
The image below shows the nsys profile trace of our baseline experiment:
In the GPU section we see the following recurring pattern:
- A block of kernel compute (in light blue) that runs for ~520 milliseconds.
- A small block of host-to-device memory copy (in green) that runs in parallel to the kernel compute. This concurrency was achieved using the optimizations discussed in our previous post.
- A block of device-to-host memory copy (in red) that runs for ~750 milliseconds.
- An extended period (~940 milliseconds) of GPU idle time (white space) between every two steps.
In the NVTX bar of the CPU section, we can see that the whitespace aligns perfectly with the "process output" block (in cyan). In our initial implementation, both the model execution and the output storage function run sequentially in the same single process. This results in significant idle time on the GPU as the CPU waits for the storage function to return before feeding the GPU the next batch.
Optimization 1: Multi-Worker Output Processing
The first step we take is to run the output storage function in parallel worker processes. We took a similar step in our previous post when we moved the input batch preparation to dedicated workers. However, whereas there we were able to enable multi-process data loading simply by setting the num_workers argument of the DataLoader class to a non-zero value, applying multi-worker output processing requires a manual implementation. Here we opt for a simple solution for demonstrative purposes; it should be customized per your needs and design preferences.
PyTorch Multiprocessing
We implement a producer-consumer strategy using PyTorch's built-in multiprocessing package, torch.multiprocessing. We define a queue for storing output batches and multiple consumer workers that process the batches on the queue. We modify our inference loop to place the output tensors on the output queue. We also update the synchronize_all utility to drain the queue and append a cleanup sequence at the end of the script.
The following block of code contains our initial implementation. As we'll see in the next sections, it will require some tuning in order to reach maximum performance.
import torch.multiprocessing as mp

POSTPROC_WORKERS = 8  # tune for optimal throughput

output_queue = mp.JoinableQueue(maxsize=POSTPROC_WORKERS)

def output_worker(in_q):
    while True:
        item = in_q.get()
        if item is None: break  # signal to shut down
        batch_id, batch_preds = item
        process_output(batch_id, batch_preds)
        in_q.task_done()

processes = []
for _ in range(POSTPROC_WORKERS):
    p = mp.Process(target=output_worker, args=(output_queue,))
    p.start()
    processes.append(p)

def synchronize_all():
    torch.cuda.synchronize()
    output_queue.join()  # wait for the queue to drain
with torch.inference_mode():
    for i in range(TOTAL_STEPS):
        if i == WARMUP_STEPS:
            synchronize_all()
            start_time = time.perf_counter()
            profiler.start()
        elif i == WARMUP_STEPS + PROFILE_STEPS:
            synchronize_all()
            profiler.stop()
            end_time = time.perf_counter()
        with nvtx.annotate(f"Batch {i}", color="blue"):
            with nvtx.annotate("get batch", color="red"):
                batch = next(data_iter)
            with nvtx.annotate("compute", color="green"):
                output = model(batch)
            with nvtx.annotate("copy to CPU", color="yellow"):
                output_cpu = to_cpu(output['out'])
            with nvtx.annotate("queue output", color="cyan"):
                output_queue.put((i, output_cpu))

total_time = end_time - start_time
throughput = PROFILE_STEPS / total_time
print(f"Throughput: {throughput:.2f} steps/sec")

# cleanup
for _ in range(POSTPROC_WORKERS):
    output_queue.put(None)
The multi-worker output-processing optimization results in a throughput of 0.71 steps per second — a 58% increase over our baseline result.
Rerunning the nsys command results in the following profile trace:

We can see that the size of the block of whitespace has dropped considerably (from ~940 milliseconds to ~50). Were we to zoom in on the remaining whitespace, we would find it aligned with a munmap operation. In our previous post, the same finding informed our asynchronous data-copy optimization. But this time we take an intermediate memory-optimization step in the form of a pre-allocated pool of buffers.
Optimization 2: Buffer Pool Pre-allocation
In order to reduce the overhead of allocating and managing a new CPU tensor on every iteration, we initialize a pool of tensors pre-allocated in shared memory and define a second queue to manage their use.
Our updated code appears below:
shape = (BATCH_SIZE, N_CLASSES, IMG_SIZE, IMG_SIZE)
buffer_pool = [torch.empty(shape).share_memory_()
               for _ in range(POSTPROC_WORKERS)]
buf_queue = mp.Queue()
for i in range(POSTPROC_WORKERS):
    buf_queue.put(i)

def output_worker(buffer_pool, in_q, buf_q):
    while True:
        item = in_q.get()
        if item is None: break  # signal to shut down
        batch_id, buf_id = item
        process_output(batch_id, buffer_pool[buf_id])
        buf_q.put(buf_id)  # release the buffer back to the pool
        in_q.task_done()

processes = []
for _ in range(POSTPROC_WORKERS):
    p = mp.Process(target=output_worker,
                   args=(buffer_pool, output_queue, buf_queue))
    p.start()
    processes.append(p)

def to_cpu(output):
    buf_id = buf_queue.get()  # block until a buffer is available
    output_cpu = buffer_pool[buf_id]
    output_cpu.copy_(output)
    return output_cpu, buf_id
with torch.inference_mode():
    for i in range(TOTAL_STEPS):
        if i == WARMUP_STEPS:
            synchronize_all()
            start_time = time.perf_counter()
            profiler.start()
        elif i == WARMUP_STEPS + PROFILE_STEPS:
            synchronize_all()
            profiler.stop()
            end_time = time.perf_counter()
        with nvtx.annotate(f"Batch {i}", color="blue"):
            with nvtx.annotate("get batch", color="red"):
                batch = next(data_iter)
            with nvtx.annotate("compute", color="green"):
                output = model(batch)
            with nvtx.annotate("copy to CPU", color="yellow"):
                output_cpu, buf_id = to_cpu(output['out'])
            with nvtx.annotate("queue output", color="cyan"):
                output_queue.put((i, buf_id))
Following these changes, the inference throughput jumps to 1.51 steps per second — more than a 2x speed-up over our previous result.
The new profile trace appears below:

Not only has the whitespace all but disappeared, but the CUDA DtoH memory operation (in red) has dropped from ~750 milliseconds to ~110. Presumably, the large GPU-to-CPU data copy involved quite a bit of memory-management overhead that we have now removed by implementing a dedicated buffer pool.
Despite the considerable improvement, if we zoom in we'll find that around ~0.5 milliseconds of whitespace remains, caused by the synchronous nature of the GPU-to-CPU copy command — as long as the copy has not completed, the CPU does not trigger the kernel computation of the next batch.
Optimization 3: Asynchronous Data Copy
Our third optimization is to change the device-to-host copy to be asynchronous. As before, we'll find that implementing this change is more difficult than in the CPU-to-GPU direction.
The first step is to pass non_blocking=True to the GPU-to-CPU copy command:
def to_cpu(output):
    buf_id = buf_queue.get()
    output_cpu = buffer_pool[buf_id]
    output_cpu.copy_(output, non_blocking=True)
    return output_cpu, buf_id
However, as we saw in our previous post, this change will not have a meaningful impact unless we modify our tensors to use pinned memory:
shape = (BATCH_SIZE, N_CLASSES, IMG_SIZE, IMG_SIZE)
buffer_pool = [torch.empty(shape, pin_memory=True).share_memory_()
               for _ in range(POSTPROC_WORKERS)]
Crucially, if we apply only these two changes to our script, the throughput will increase but the output may be corrupted (e.g., see here). We need an event-based mechanism for identifying when each GPU-to-CPU copy has completed so that we can proceed with the output data processing. (Note that this was not required when making the CPU-to-GPU copy asynchronous: because a single GPU stream processes commands sequentially, the kernel computation only starts once the copy has completed. Synchronization was only required when introducing a second stream.)
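The hazard, and the event-based remedy, can be illustrated in isolation with a minimal sketch (the tensor names here are hypothetical and not part of our script):
import torch

src = torch.randn(1000, 1000, device="cuda")
dst = torch.empty(1000, 1000, pin_memory=True)
dst.copy_(src, non_blocking=True)  # returns before the copy has completed
copy_done = torch.cuda.Event()
copy_done.record()       # recorded on the current stream, after the copy
copy_done.synchronize()  # block the CPU until the copy has completed
# only now is it safe to read dst on the CPU
assert torch.equal(dst, src.cpu())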
To implement the notification mechanism, we define a pool of CUDA events and an additional queue for managing their use. We further define a listener thread that monitors the events on the queue and populates the output queue once the copies have completed.
import threading, queue

event_pool = [torch.cuda.Event() for _ in range(POSTPROC_WORKERS)]
event_queue = queue.Queue()

def event_monitor(event_pool, event_queue, output_queue):
    while True:
        item = event_queue.get()
        if item is None: break
        batch_id, buf_idx = item
        event_pool[buf_idx].synchronize()  # wait for the copy to complete
        output_queue.put((batch_id, buf_idx))
        event_queue.task_done()

monitor = threading.Thread(target=event_monitor,
                           args=(event_pool, event_queue, output_queue))
monitor.start()
The updated inference sequence consists of the following steps:
- Get an input batch that was prefetched to the GPU.
- Execute the model on the input batch to get an output tensor on the GPU.
- Request a vacant CPU buffer from the buffer queue and use it to trigger an asynchronous data copy. Record an event that will trigger when the copy is complete and push it to the event queue.
- The monitor thread waits for the event to trigger after which pushes the output tensor to the output queue for processing.
- A worker process pulls the output tensor from the output queue and saves it to disk. It then releases the buffer back to the buffer queue.
The updated code appears below.
def synchronize_all():
    torch.cuda.synchronize()
    event_queue.join()
    output_queue.join()

with torch.inference_mode():
    for i in range(TOTAL_STEPS):
        if i == WARMUP_STEPS:
            synchronize_all()
            start_time = time.perf_counter()
            profiler.start()
        elif i == WARMUP_STEPS + PROFILE_STEPS:
            synchronize_all()
            profiler.stop()
            end_time = time.perf_counter()
        with nvtx.annotate(f"Batch {i}", color="blue"):
            with nvtx.annotate("get batch", color="red"):
                batch = next(data_iter)
            with nvtx.annotate("compute", color="green"):
                output = model(batch)
            with nvtx.annotate("copy to CPU", color="yellow"):
                output_cpu, buf_id = to_cpu(output['out'])
            with nvtx.annotate("queue CUDA event", color="cyan"):
                event_pool[buf_id].record()
                event_queue.put((i, buf_id))

total_time = end_time - start_time
throughput = PROFILE_STEPS / total_time
print(f"Throughput: {throughput:.2f} steps/sec")

# cleanup
event_queue.put(None)
for _ in range(POSTPROC_WORKERS):
    output_queue.put(None)
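For a fully clean shutdown, we may also want to wait for the monitor thread and the worker processes to exit. A minimal sketch, relying on the monitor and processes variables defined above:
monitor.join()  # the monitor thread exits after receiving None
for p in processes:
    p.join()    # each worker process exits after receiving None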
The resulting throughput is 1.55 steps per second.
The new profile trace appears below:

In the NVTX row of the CPU section we can see all of the operations of the inference loop bunched together on the left side — implying that they all ran immediately and asynchronously. We also see the event synchronization calls (in light green) running on the dedicated monitor thread. In the GPU section we see that the kernel computation begins immediately after the device-to-host copy has completed.
Our final optimization will focus on improving the parallelization of the kernel and memory operations on the GPU.
Optimization 4: Pipelining Using CUDA Streams
As in our previous post, we wish to take advantage of the independent engines for memory copying (the DMA) and kernel compute (the SMs). We do this by assigning the memory copy to a dedicated CUDA stream:
egress_stream = torch.cuda.Stream()

with torch.inference_mode():
    for i in range(TOTAL_STEPS):
        if i == WARMUP_STEPS:
            synchronize_all()
            start_time = time.perf_counter()
            profiler.start()
        elif i == WARMUP_STEPS + PROFILE_STEPS:
            synchronize_all()
            profiler.stop()
            end_time = time.perf_counter()
        with nvtx.annotate(f"Batch {i}", color="blue"):
            with nvtx.annotate("get batch", color="red"):
                batch = next(data_iter)
            with nvtx.annotate("compute", color="green"):
                output = model(batch)
            # on separate stream
            with torch.cuda.stream(egress_stream):
                # wait for default stream to finish compute
                egress_stream.wait_stream(torch.cuda.default_stream())
                with nvtx.annotate("copy to CPU", color="yellow"):
                    output_cpu, buf_id = to_cpu(output['out'])
                with nvtx.annotate("queue CUDA event", color="cyan"):
                    event_pool[buf_id].record(egress_stream)
                    event_queue.put((i, buf_id))
This results in a throughput of 1.85 steps per second — a further 19.3% improvement over our previous experiment.
The ultimate profile trace appears below:

In the GPU section we see a continuous block of kernel compute (in light blue) with both the host-to-device (in light green) and device-to-host (in purple) memory copies running in parallel. Our inference loop is now compute-bound, implying that we have exhausted all practical opportunities for data-transfer optimization.
Results
We summarize our results below:
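- Baseline: 0.45 steps per second
- Optimization 1 (multi-worker output processing): 0.71
- Optimization 2 (buffer pool pre-allocation): 1.51
- Optimization 3 (asynchronous data copy): 1.55
- Optimization 4 (pipelining with CUDA streams): 1.85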

Through the use of the nsys profiler we were able to increase efficiency by over 4x. Naturally, the impact of the optimizations we discussed will vary based on the details of the model and runtime environment.
Summary
This concludes the second part of our series of posts on the topic of optimizing data transfer in AI/ML workloads. Part one focused on host-to-device copies and part two on device-to-host copies. When implemented naively, data transfer in either direction can result in significant performance bottlenecks, leading to GPU starvation and increased runtime costs. Using the Nsight Systems profiler, we demonstrated how to identify and resolve these bottlenecks and increase runtime efficiency.
Although the optimization of both directions involved similar steps, the implementation details were very different. While optimizing CPU-to-GPU data transfer is well supported by PyTorch's data-loading APIs and required relatively small changes to the execution loop, optimizing the GPU-to-CPU direction required a bit more software engineering. Importantly, the solutions we put forth in this post were chosen for demonstrative purposes. Your own solution may differ considerably based on your project needs and design preferences.
Having covered both CPU-to-GPU and GPU-to-CPU data copies, we turn our attention to GPU-to-GPU transactions: stay tuned for a future post on the topic of optimizing data transfer between GPUs in distributed training workloads.
