A guest blog post by Amog Kamsetty from the Anyscale team
Huggingface Transformers recently added the Retrieval Augmented Generation (RAG) model, a new NLP architecture that leverages external documents (like Wikipedia) to augment its knowledge and achieve state-of-the-art results on knowledge-intensive tasks. In this blog post, we introduce the integration of Ray, a library for building scalable applications, into the RAG contextual document retrieval mechanism. This speeds up retrieval calls by 2x and improves the scalability of RAG distributed fine-tuning.
What is Retrieval Augmented Generation (RAG)?
An overview of RAG. The model retrieves contextual documents from an external dataset as part of its execution. These contextual documents are used in conjunction with the original input to produce an output. The GIF is taken from Facebook’s original blog post.
Recently, Huggingface partnered with Facebook AI to introduce the RAG model as part of its Transformers library.
RAG acts just like any other seq2seq model. However, RAG has an intermediate component that retrieves contextual documents from an external knowledge base (like a Wikipedia text corpus). These documents are then used in conjunction with the input sequence and passed into the underlying seq2seq generator.
This information retrieval step allows RAG to draw on multiple sources of knowledge: the knowledge baked into the model parameters and the information contained in the contextual passages. This allows it to outperform other state-of-the-art models on tasks like question answering. You can try it for yourself using this demo provided by Huggingface!
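The retrieve-then-generate flow described above can be sketched in a few lines of plain Python. This is only a toy illustration, not the Transformers implementation: the word-overlap `score` function, the `//` separator, and the `CORPUS` contents are all assumptions made for the example, and a real RAG model uses dense embeddings and a learned generator instead.

```python
def score(query, doc):
    # Toy relevance score (an assumption for this sketch):
    # the number of lowercase words shared by the query and the document.
    return len(set(query.lower().split()) & set(doc.lower().split()))


def retrieve(query, corpus, k=2):
    # Return the k highest-scoring documents; ties keep corpus order
    # because Python's sort is stable.
    return sorted(corpus, key=lambda doc: score(query, doc), reverse=True)[:k]


def build_generator_inputs(query, docs):
    # Pair each retrieved document with the original input before it is
    # handed to the seq2seq generator. The separator is illustrative.
    return [f"{doc} // {query}" for doc in docs]


# Hypothetical stand-in for an external knowledge base.
CORPUS = [
    "Paris is the capital of France.",
    "Ray is a library for distributed Python.",
    "The Eiffel Tower is in Paris.",
]

if __name__ == "__main__":
    question = "What is the capital of France?"
    docs = retrieve(question, CORPUS, k=2)
    for text in build_generator_inputs(question, docs):
        print(text)
```

The point of the sketch is the shape of the pipeline: the generator never sees the question alone, only the question combined with each retrieved passage.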
Scaling up fine-tuning
This retrieval of contextual documents is crucial for RAG’s state-of-the-art results but introduces an additional layer of complexity. When scaling up the training process via a data-parallel training routine, a naive implementation of the document lookup can become a bottleneck for training. Further, the document index used in the retrieval component is often quite large, making it infeasible for each training worker to load its own replicated copy of the index.
The previous implementation of RAG fine-tuning leveraged the torch.distributed communication package for the document retrieval portion. However, this implementation sometimes proved to be inflexible and limited in scalability.
Instead, a framework-agnostic and more flexible implementation for ad-hoc concurrent programming is required. Ray fits the bill perfectly. Ray is a simple, yet powerful Python library for general-purpose distributed and parallel programming. Using Ray for distributed document retrieval, we achieved a 2x speedup per retrieval call compared to torch.distributed, and overall better fine-tuning scalability.
Ray for Document Retrieval

Document retrieval with the torch.distributed implementation
The main drawback of the torch.distributed implementation for document retrieval was that it latched onto the same process group used for training, and only the rank 0 training worker loaded the index into memory.
As a result, this implementation had some limitations:
- Synchronization bottleneck: The rank 0 worker had to receive the inputs from all workers, perform the index query, and then send the results back to the other workers. This limited performance with multiple training workers.
- PyTorch specific: The document retrieval process group had to latch onto the existing process group used for training, meaning that PyTorch had to be used for training as well.

Document retrieval with the Ray implementation
To overcome these limitations, we introduced a novel implementation of distributed retrieval based on Ray. With Ray’s stateful actor abstractions, multiple processes that are separate from the training processes are used to load the index and handle the retrieval queries. With multiple Ray actors, retrieval is no longer a bottleneck and PyTorch is no longer a requirement for RAG.
And as you can see below, using the Ray-based implementation leads to better retrieval performance for multi-GPU fine-tuning. The following results show the seconds per retrieval call: as we increase the number of GPUs that we train on, Ray performs comparatively better than torch.distributed. Also, if we increase the number of Ray processes that perform retrieval, we get better performance with more training workers, since a single retrieval process is no longer a bottleneck.
| | 2 GPU | 3 GPU | 4 GPU |
|---|---|---|---|
| torch.distributed | 2.12 sec/retrieval | 2.62 sec/retrieval | 3.438 sec/retrieval |
| Ray 2 retrieval processes | 1.49 sec/retrieval | 1.539 sec/retrieval | 2.029 sec/retrieval |
| Ray 4 retrieval processes | 1.145 sec/retrieval | 1.484 sec/retrieval | 1.66 sec/retrieval |
A performance comparison of different retrieval implementations. For each document retrieval implementation, we run 500 training steps with a per-GPU batch size of 8, and measure the time it takes to retrieve the contextual documents for each batch on the rank 0 training worker. As the results show, using multiple retrieval processes improves performance, especially as we scale training to multiple GPUs.
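For readers who want the speedups spelled out, the short script below derives them from the numbers in the table above. The timings are copied from the table; the helper name `speedup_over_baseline` is just a label chosen for this sketch.

```python
# Seconds per retrieval call, copied from the table above,
# keyed by implementation and then by number of GPUs.
SECONDS_PER_RETRIEVAL = {
    "torch.distributed": {2: 2.12, 3: 2.62, 4: 3.438},
    "Ray 2 retrieval processes": {2: 1.49, 3: 1.539, 4: 2.029},
    "Ray 4 retrieval processes": {2: 1.145, 3: 1.484, 4: 1.66},
}


def speedup_over_baseline(results, baseline="torch.distributed"):
    # Speedup = baseline time / candidate time, for each GPU count.
    base = results[baseline]
    return {
        name: {gpus: base[gpus] / secs for gpus, secs in timings.items()}
        for name, timings in results.items()
        if name != baseline
    }


if __name__ == "__main__":
    for name, row in speedup_over_baseline(SECONDS_PER_RETRIEVAL).items():
        print(name, {gpus: round(s, 2) for gpus, s in row.items()})
```

With four retrieval processes on 4 GPUs the ratio is 3.438 / 1.66, or about 2.07, which is where the roughly 2x figure comes from. At smaller GPU counts the gain is smaller, between about 1.4x and 1.9x.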
How do I use it?
Huggingface provides a PyTorch Lightning based fine-tuning script, and we extended it to add the Ray retrieval implementation as an option.
To try it out, first install the necessary requirements:
pip install ray
pip install transformers
pip install -r transformers/examples/research_projects/rag/requirements.txt
Then, you can specify your data paths and other configurations and run finetune-rag-ray.sh!
export PYTHONPATH="../":"${PYTHONPATH}"
# Start a single-node Ray cluster.
ray start --head
# Fine-tune RAG with Ray-based document retrieval.
python examples/rag/finetune_rag.py \
    --data_dir $DATA_DIR \
    --output_dir $OUTPUT_DIR \
    --model_name_or_path $MODEL_NAME_OR_PATH \
    --model_type rag_sequence \
    --fp16 \
    --gpus 8 \
    --profile \
    --do_train \
    --do_predict \
    --n_val -1 \
    --train_batch_size 8 \
    --eval_batch_size 1 \
    --max_source_length 128 \
    --max_target_length 25 \
    --val_max_target_length 25 \
    --test_max_target_length 25 \
    --label_smoothing 0.1 \
    --dropout 0.1 \
    --attention_dropout 0.1 \
    --weight_decay 0.001 \
    --adam_epsilon 1e-08 \
    --max_grad_norm 0.1 \
    --lr_scheduler polynomial \
    --learning_rate 3e-05 \
    --num_train_epochs 100 \
    --warmup_steps 500 \
    --gradient_accumulation_steps 1 \
    --distributed_retriever ray \
    --num_retrieval_workers 4
# Stop the Ray cluster once training is done.
ray stop
What’s next?
Using RAG with Huggingface transformers and the Ray retrieval implementation for faster distributed fine-tuning, you can leverage RAG for retrieval-based generation on your own knowledge-intensive tasks.
Also, hyperparameter tuning is another aspect of transformer fine-tuning and can have a huge impact on accuracy. For scalable and easy hyperparameter tuning, check out the Ray Tune library. By using Ray Tune’s integration with PyTorch Lightning, or the built-in integration with Huggingface transformers, you can run experiments to find the best hyperparameters for your RAG model.
And lastly, stay tuned for a potential TensorFlow implementation of RAG on Huggingface!
If you plan to try out the RAG+Ray integration, please feel free to share your experiences on the Ray Discourse or join the Ray community Slack for further discussion. We’d love to hear from you!
Also published at https://medium.com/distributed-computing-with-ray/retrieval-augmented-generation-with-huggingface-transformers-and-ray-b09b56161b1e

