Sentence Transformers is a Python library for using and training embedding and reranker models for a wide range of applications, such as retrieval augmented generation, semantic search, semantic textual similarity, paraphrase mining, and more. Its v4.0 update introduces a new training approach for rerankers, also known as cross-encoder models, similar to what the v3.0 update introduced for embedding models. In this blogpost, I’ll show you how to use it to finetune a reranker model that beats all existing options on exactly your data. This method can also train extremely strong new reranker models from scratch.
Finetuning reranker models involves several components: datasets, loss functions, training arguments, evaluators, and the trainer class itself. I’ll take a look at each of these components, accompanied by practical examples of how they can be used for finetuning strong reranker models.
Lastly, in the Evaluation section, I’ll show you that my small finetuned tomaarsen/reranker-ModernBERT-base-gooaq-bce reranker model, which I trained alongside this blogpost, easily outperforms the 13 most commonly used public reranker models on my evaluation dataset. It even beats models that are 4x larger.
Repeating the recipe with a larger base model results in tomaarsen/reranker-ModernBERT-large-gooaq-bce, a reranker model that blows all existing general-purpose reranker models out of the water on my data.
If you’re interested in finetuning embedding models instead, consider reading through my prior Training and Finetuning Embedding Models with Sentence Transformers v3 blogpost as well.
Table of Contents
What are Reranker models?
Reranker models, often implemented using Cross Encoder architectures, are designed to evaluate the relevance between pairs of texts (e.g., a query and a document, or two sentences). Unlike Sentence Transformers (a.k.a. bi-encoders, embedding models), which independently embed each text into vectors and compute similarity via a distance metric, Cross Encoders process the paired texts together through a shared neural network, resulting in one output score. By letting the two texts attend to each other, Cross Encoder models can outperform embedding models.
However, this strength comes with a trade-off: Cross Encoder models are slower because they process every possible pair of texts (e.g., 10 queries with 500 candidate documents requires 5,000 computations instead of 510 for embedding models). This makes them less efficient for large-scale initial retrieval but ideal for reranking: refining the top-k results first identified by faster Sentence Transformer models. The strongest search systems commonly use this 2-stage “retrieve and rerank” approach.
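To make this concrete, here is a minimal sketch of scoring and reranking candidates with a pretrained reranker. The cross-encoder/ms-marco-MiniLM-L6-v2 model and the example documents are illustrative choices, not part of the recipe in this post:
from sentence_transformers import CrossEncoder

# Illustrative pretrained reranker; any CrossEncoder model from the Hub works here
model = CrossEncoder("cross-encoder/ms-marco-MiniLM-L6-v2")

query = "How many people live in Berlin?"
documents = [
    "Berlin had a population of 3,520,031 registered inhabitants in 2016.",
    "Berlin is well known for its museums and nightlife.",
    "In 2014, the city state of Berlin had 37,368 live births.",
]

# Score each (query, document) pair jointly; one forward pass per pair
scores = model.predict([(query, document) for document in documents])
print(scores)

# Or rerank the candidate documents for the query directly, highest score first
for entry in model.rank(query, documents, return_documents=True):
    print(f"{entry['score']:.2f}  {entry['text']}")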
Throughout this blogpost, I’ll use “reranker model” and “Cross Encoder model” interchangeably.
Why Finetune?
Reranker models are often tasked with a difficult problem:
Which of these highly-related documents answers the query best?
General-purpose reranker models are trained to perform adequately on this exact question across a wide range of domains and topics, preventing them from reaching their maximum potential in your specific domain. Through finetuning, the model can learn to focus exclusively on the domain and/or language that matters to you.
In the Evaluation section at the end of this blogpost, I’ll show that training a model on your domain can outperform any general-purpose reranker model, even when those baselines are much larger. Don’t underestimate the power of finetuning on your domain!
Training Components
Training reranker models involves the following components:
- Dataset: The data used for training and/or evaluation.
- Loss Function: A function that measures the model’s performance and guides the optimization process.
- Training Arguments (optional): Parameters that impact training performance, tracking, and debugging.
- Evaluator (optional): A class for evaluating the model before, during, or after training.
- Trainer: Brings together all training components.
Let’s take a closer look at each component.
Dataset
The CrossEncoderTrainer uses datasets.Dataset or datasets.DatasetDict instances for training and evaluation. You can load data from the Hugging Face Datasets Hub or use your local data in whatever format you prefer (e.g. CSV, JSON, Parquet, Arrow, or SQL).
Note: Many public datasets that work out of the box with Sentence Transformers have been tagged with sentence-transformers on the Hugging Face Hub, so you can easily find them on https://huggingface.co/datasets?other=sentence-transformers. Consider browsing through these to find ready-to-go datasets that may be useful for your tasks, domains, or languages.
Data on the Hugging Face Hub
You can use the load_dataset function to load data from datasets on the Hugging Face Hub:
from datasets import load_dataset
train_dataset = load_dataset("sentence-transformers/natural-questions", split="train")
print(train_dataset)
"""
Dataset({
    features: ['query', 'answer'],
    num_rows: 100231
})
"""
Some datasets, like nthakur/swim-ir-monolingual, have multiple subsets with different data formats. You must specify the subset name along with the dataset name, e.g. dataset = load_dataset("nthakur/swim-ir-monolingual", "de", split="train").
Local Data (CSV, JSON, Parquet, Arrow, SQL)
You can also use load_dataset for loading local data in certain file formats:
from datasets import load_dataset
dataset = load_dataset("csv", data_files="my_file.csv")
dataset = load_dataset("json", data_files="my_file.json")
Local Data that requires pre-processing
You can use datasets.Dataset.from_dict if your local data requires pre-processing. This allows you to initialize your dataset with a dictionary of lists:
from datasets import Dataset
queries = []
documents = []
dataset = Dataset.from_dict({
    "query": queries,
    "document": documents,
})
Each key in the dictionary becomes a column in the resulting dataset.
Dataset Format
It is important that your dataset format matches your loss function (or that you choose a loss function that matches your dataset format and model). Verifying whether a dataset format and model work with a loss function involves three steps:
- All columns not named “label”, “labels”, “score”, or “scores” are considered Inputs according to the Loss Overview table. The number of remaining columns must match the number of valid inputs for your chosen loss.
- If your loss function requires a Label according to the Loss Overview table, then your dataset must have a column named “label”, “labels”, “score”, or “scores”. This column is automatically taken as the label.
- The number of model output labels matches what is required for the loss according to the Loss Overview table.
For example, given a dataset with columns ["text1", "text2", "label"] where the “label” column has float similarity scores ranging from 0 to 1 and a model outputting 1 label, we can use it with BinaryCrossEntropyLoss because:
- the dataset has a “label” column, as is required for this loss function.
- the dataset has 2 non-label columns, exactly the number required by this loss function.
- the model has 1 output label, exactly as required by this loss function.
Be sure to re-order your dataset columns with Dataset.select_columns if your columns are not ordered correctly. For example, if your dataset has ["good_answer", "bad_answer", "question"] as columns, then this dataset can technically be used with a loss that requires (anchor, positive, negative) triplets, but the good_answer column will be taken as the anchor, bad_answer as the positive, and question as the negative.
Additionally, if your dataset has extraneous columns (e.g. sample_id, metadata, source, type), you should remove these with Dataset.remove_columns, as they will be used as inputs otherwise. You can also use Dataset.select_columns to keep only the desired columns, as shown in the sketch below.
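Here is a small sketch of both operations on a hypothetical triplet dataset (the column names and example rows are made up purely for illustration):
from datasets import Dataset

# Hypothetical dataset: columns in the wrong order, plus an extraneous "source" column
dataset = Dataset.from_dict({
    "good_answer": ["Paris is the capital of France."],
    "bad_answer": ["France borders Spain."],
    "question": ["What is the capital of France?"],
    "source": ["example"],
})

# Drop the extraneous column, then re-order so (anchor, positive, negative) line up as the loss expects
dataset = dataset.remove_columns("source")
dataset = dataset.select_columns(["question", "good_answer", "bad_answer"])
print(dataset.column_names)  # ['question', 'good_answer', 'bad_answer']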
Hard Negatives Mining
The success of training reranker models often depends on the quality of the negatives, i.e. the passages for which the query-negative score should be low. Negatives can be divided into two types:
- Soft negatives: passages that are completely unrelated. Also known as easy negatives.
- Hard negatives: passages that seem like they might be relevant for the query, but are not.
A concise example is:
- Query: Where was Apple founded?
- Soft Negative: The Cache River Bridge is a Parker pony truss that spans the Cache River between Walnut Ridge and Paragould, Arkansas.
- Hard Negative: The Fuji apple is an apple cultivar developed in the late 1930s, and brought to market in 1962.
The strongest CrossEncoder models are generally trained to recognize hard negatives, and so it’s valuable to be able to “mine” hard negatives to train with. Sentence Transformers supports a strong mine_hard_negatives function that can help, given a dataset of query-answer pairs:
from datasets import load_dataset
from sentence_transformers import SentenceTransformer
from sentence_transformers.util import mine_hard_negatives
train_dataset = load_dataset("sentence-transformers/gooaq", split="train").select(range(100_000))
print(train_dataset)
embedding_model = SentenceTransformer("sentence-transformers/static-retrieval-mrl-en-v1", device="cpu")
hard_train_dataset = mine_hard_negatives(
    train_dataset,
    embedding_model,
    num_negatives=5,  # How many negatives per question-answer pair
    range_min=10,  # Skip the 10 most similar samples
    range_max=100,  # Consider only the top 100 most similar samples
    max_score=0.8,  # Only consider samples with a similarity score of at most 0.8
    margin=0.1,  # Negatives must be at least 0.1 less similar to the query than the positive
    sampling_strategy="top",  # Sample the top negatives from the valid range
    batch_size=4096,  # Batch size for the embedding model
    output_format="labeled-pair",  # Output (query, passage, label) pairs, e.g. for BinaryCrossEntropyLoss
    use_faiss=True,  # Use FAISS for the nearest-neighbor search
)
print(hard_train_dataset)
print(hard_train_dataset[1])
Click to see the outputs of this script.
Dataset({
    features: ['question', 'answer'],
    num_rows: 100000
})
Batches: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 22/22 [00:01<00:00, 13.74it/s]
Batches: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 25/25 [00:00<00:00, 36.49it/s]
Querying FAISS index: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:19<00:00, 2.80s/it]
Metric Positive Negative Difference
Count 100,000 436,925
Mean 0.5882 0.4040 0.2157
Median 0.5989 0.4024 0.1836
Std 0.1425 0.0905 0.1013
Min -0.0514 0.1405 0.1014
25% 0.4993 0.3377 0.1352
50% 0.5989 0.4024 0.1836
75% 0.6888 0.4681 0.2699
Max 0.9748 0.7486 0.7545
Skipped 2420871 potential negatives (23.97%) due to the margin of 0.1.
Skipped 43 potential negatives (0.00%) due to the maximum score of 0.8.
Could not find enough negatives for 63075 samples (12.62%). Consider adjusting the range_max, range_min, margin and max_score parameters if you'd like to find more valid negatives.
Dataset({
    features: ['question', 'answer', 'label'],
    num_rows: 536925
})
{
    'question': 'how to transfer bookmarks from one laptop to another?',
    'answer': 'Using an External Drive Nearly any external drive, including a USB thumb drive, or an SD card can be used to transfer your files from one laptop to another. Connect the drive to your old laptop; drag your files to the drive, then disconnect it and transfer the drive contents onto your new laptop.',
    'label': 0
}
Loss Function
Loss functions help evaluate a model’s performance on a set of data and guide the training process. The right loss function for your task depends on the data you have and what you’re trying to achieve. You can find a full list of available loss functions in the Loss Overview.
Most loss functions are easy to set up – you only need to provide the CrossEncoder model you’re training:
from datasets import load_dataset
from sentence_transformers import CrossEncoder
from sentence_transformers.cross_encoder.losses import CachedMultipleNegativesRankingLoss
model = CrossEncoder("xlm-roberta-base", num_labels=1)
loss = CachedMultipleNegativesRankingLoss(model)
train_dataset = load_dataset("sentence-transformers/gooaq", split="train")
...
Training Arguments
You can customize the training process using the CrossEncoderTrainingArguments class. This class lets you adjust parameters that can impact training speed and help you understand what’s happening during training.
For more information on the most useful training arguments, check out the Cross Encoder > Training Overview > Training Arguments documentation. It’s worth reading to get the most out of your training.
Here’s an example of how to set up CrossEncoderTrainingArguments:
from sentence_transformers.cross_encoder import CrossEncoderTrainingArguments
from sentence_transformers.training_args import BatchSamplers

args = CrossEncoderTrainingArguments(
    # Required parameter:
    output_dir="models/reranker-MiniLM-msmarco-v1",
    # Optional training parameters:
    num_train_epochs=1,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    learning_rate=2e-5,
    warmup_ratio=0.1,
    fp16=True,  # Set to False if your GPU can't handle FP16
    bf16=False,  # Set to True if your GPU supports BF16
    batch_sampler=BatchSamplers.NO_DUPLICATES,
    # Optional tracking/debugging parameters:
    eval_strategy="steps",
    eval_steps=100,
    save_strategy="steps",
    save_steps=100,
    save_total_limit=2,
    logging_steps=100,
    run_name="reranker-MiniLM-msmarco-v1",  # Used in W&B if `wandb` is installed
)
Evaluator
To track your model’s performance during training, you can pass an eval_dataset to the CrossEncoderTrainer. However, you may want more detailed metrics beyond just the evaluation loss. That’s where evaluators come in: they help you assess your model’s performance using specific metrics at various stages of training. You can use an evaluation dataset, an evaluator, both, or neither, depending on your needs. The evaluation strategy and frequency are controlled by the eval_strategy and eval_steps Training Arguments.
Sentence Transformers includes the following built-in evaluators:
You can also use SequentialEvaluator to combine multiple evaluators into one, which can then be passed to the CrossEncoderTrainer. Alternatively, you can just pass a list of evaluators to the trainer.
Sometimes you don’t have the required evaluation data to prepare one of these evaluators on your own, but you still want to track how well the model performs on some common benchmarks. In that case, you can use these evaluators with data from Hugging Face.
CrossEncoderCorrelationEvaluator with STSb
The STS Benchmark (a.k.a. STSb) is a commonly used benchmarking dataset to measure the model’s understanding of semantic textual similarity of short texts like “A man is feeding a mouse to a snake.”.
Feel free to browse the sentence-transformers/stsb dataset on Hugging Face.
from datasets import load_dataset
from sentence_transformers import CrossEncoder
from sentence_transformers.cross_encoder.evaluation import CrossEncoderCorrelationEvaluator
model = CrossEncoder("cross-encoder/stsb-TinyBERT-L4")
eval_dataset = load_dataset("sentence-transformers/stsb", split="validation")
pairs = list(zip(eval_dataset["sentence1"], eval_dataset["sentence2"]))
dev_evaluator = CrossEncoderCorrelationEvaluator(
    sentence_pairs=pairs,
    scores=eval_dataset["score"],
    name="sts_dev",
)
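You can run the evaluator on a model directly, or pass it via the evaluator argument of the CrossEncoderTrainer so that it runs on the eval_strategy schedule. A minimal sketch of the direct call; the exact metric keys shown in the comment are an assumption based on the evaluator name:
results = dev_evaluator(model)
print(results)
# e.g. {'sts_dev_pearson': ..., 'sts_dev_spearman': ...}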
CrossEncoderRerankingEvaluator with GooAQ mined negatives
Preparing data for CrossEncoderRerankingEvaluator can be difficult, as you need negatives in addition to your query-positive data.
The mine_hard_negatives function has a convenient include_positives parameter, which can be set to True to also include the positive texts. When supplied as documents (which must be 1. ranked and 2. contain positives) to CrossEncoderRerankingEvaluator, the evaluator will not just evaluate the reranking performance of the CrossEncoder, but also the original rankings by the embedding model used for mining.
For instance:
CrossEncoderRerankingEvaluator: Evaluating the model on the gooaq-dev dataset:
Queries: 1000 Positives: Min 1.0, Mean 1.0, Max 1.0 Negatives: Min 49.0, Mean 49.1, Max 50.0
Base -> Reranked
MAP: 53.28 -> 67.28
MRR@10: 52.40 -> 66.65
NDCG@10: 59.12 -> 71.35
Note that by default, if you are using CrossEncoderRerankingEvaluator with documents, the evaluator will rerank with all positives, even if they are not among the documents. This is useful for getting a stronger signal out of your evaluator, but does give a slightly unrealistic performance. After all, the maximum performance is now 100, whereas normally it’s bounded by whether the first-stage retriever actually retrieved the positives.
You can enable the realistic behaviour by setting always_rerank_positives=False when initializing CrossEncoderRerankingEvaluator. Repeating the same script with this realistic two-stage setting results in:
CrossEncoderRerankingEvaluator: Evaluating the model on the gooaq-dev dataset:
Queries: 1000 Positives: Min 1.0, Mean 1.0, Max 1.0 Negatives: Min 49.0, Mean 49.1, Max 50.0
Base -> Reranked
MAP: 53.28 -> 66.12
MRR@10: 52.40 -> 65.61
NDCG@10: 59.12 -> 70.10
from datasets import load_dataset
from sentence_transformers import SentenceTransformer
from sentence_transformers.cross_encoder import CrossEncoder
from sentence_transformers.cross_encoder.evaluation import CrossEncoderRerankingEvaluator
from sentence_transformers.util import mine_hard_negatives
model = CrossEncoder("cross-encoder/ms-marco-MiniLM-L6-v2")
full_dataset = load_dataset("sentence-transformers/gooaq", split="train").select(range(100_000))
dataset_dict = full_dataset.train_test_split(test_size=1_000, seed=12)
train_dataset = dataset_dict["train"]
eval_dataset = dataset_dict["test"]
print(eval_dataset)
"""
Dataset({
    features: ['question', 'answer'],
    num_rows: 1000
})
"""
embedding_model = SentenceTransformer("sentence-transformers/static-retrieval-mrl-en-v1", device="cpu")
hard_eval_dataset = mine_hard_negatives(
    eval_dataset,
    embedding_model,
    corpus=full_dataset["answer"],  # Use the full set of answers as the corpus
    num_negatives=50,  # How many documents to retrieve per query
    batch_size=4096,
    output_format="n-tuple",
    include_positives=True,  # Key: keep the positive answer among the retrieved documents
    use_faiss=True,
)
print(hard_eval_dataset)
"""
Dataset({
features: ['question', 'answer', 'negative_1', 'negative_2', 'negative_3', 'negative_4', 'negative_5', 'negative_6', 'negative_7', 'negative_8', 'negative_9', 'negative_10', 'negative_11', 'negative_12', 'negative_13', 'negative_14', 'negative_15', 'negative_16', 'negative_17', 'negative_18', 'negative_19', 'negative_20', 'negative_21', 'negative_22', 'negative_23', 'negative_24', 'negative_25', 'negative_26', 'negative_27', 'negative_28', 'negative_29', 'negative_30', 'negative_31', 'negative_32', 'negative_33', 'negative_34', 'negative_35', 'negative_36', 'negative_37', 'negative_38', 'negative_39', 'negative_40', 'negative_41', 'negative_42', 'negative_43', 'negative_44', 'negative_45', 'negative_46', 'negative_47', 'negative_48', 'negative_49', 'negative_50'],
num_rows: 1000
})
"""
reranking_evaluator = CrossEncoderRerankingEvaluator(
    samples=[
        {
            "query": sample["question"],
            "positive": [sample["answer"]],
            "documents": [sample[column_name] for column_name in hard_eval_dataset.column_names[2:]],
        }
        for sample in hard_eval_dataset
    ],
    batch_size=32,
    name="gooaq-dev",
)
results = reranking_evaluator(model)
"""
CrossEncoderRerankingEvaluator: Evaluating the model on the gooaq-dev dataset:
Queries: 1000 Positives: Min 1.0, Mean 1.0, Max 1.0 Negatives: Min 49.0, Mean 49.1, Max 50.0
Base -> Reranked
MAP: 53.28 -> 67.28
MRR@10: 52.40 -> 66.65
NDCG@10: 59.12 -> 71.35
"""
Trainer
The CrossEncoderTrainer is where all previous components come together. We only have to specify the trainer with the model, training arguments (optional), training dataset, evaluation dataset (optional), loss function, evaluator (optional), and we can start training. Let’s take a look at a script where all of these components come together:
import logging
import traceback
import torch
from datasets import load_dataset
from sentence_transformers import SentenceTransformer
from sentence_transformers.cross_encoder import (
CrossEncoder,
CrossEncoderModelCardData,
CrossEncoderTrainer,
CrossEncoderTrainingArguments,
)
from sentence_transformers.cross_encoder.evaluation import (
CrossEncoderNanoBEIREvaluator,
CrossEncoderRerankingEvaluator,
)
from sentence_transformers.cross_encoder.losses.BinaryCrossEntropyLoss import BinaryCrossEntropyLoss
from sentence_transformers.evaluation.SequentialEvaluator import SequentialEvaluator
from sentence_transformers.util import mine_hard_negatives
logging.basicConfig(format="%(asctime)s - %(message)s", datefmt="%Y-%m-%d %H:%M:%S", level=logging.INFO)
def main():
    model_name = "answerdotai/ModernBERT-base"

    train_batch_size = 16
    num_epochs = 1
    num_hard_negatives = 5  # How many hard negatives to mine per question-answer pair

    # 1. Load a model to finetune, with optional model card data
    model = CrossEncoder(
        model_name,
        model_card_data=CrossEncoderModelCardData(
            language="en",
            license="apache-2.0",
            model_name="ModernBERT-base trained on GooAQ",
        ),
    )
    print("Model max length:", model.max_length)
    print("Model num labels:", model.num_labels)

    # 2. Load the GooAQ dataset and create a train/eval split
    logging.info("Read the gooaq training dataset")
    full_dataset = load_dataset("sentence-transformers/gooaq", split="train").select(range(100_000))
    dataset_dict = full_dataset.train_test_split(test_size=1_000, seed=12)
    train_dataset = dataset_dict["train"]
    eval_dataset = dataset_dict["test"]
    logging.info(train_dataset)
    logging.info(eval_dataset)

    # 3. Mine hard negatives to create the labeled (query, passage) training pairs
    embedding_model = SentenceTransformer("sentence-transformers/static-retrieval-mrl-en-v1", device="cpu")
    hard_train_dataset = mine_hard_negatives(
        train_dataset,
        embedding_model,
        num_negatives=num_hard_negatives,
        margin=0,
        range_min=0,
        range_max=100,
        sampling_strategy="top",
        batch_size=4096,
        output_format="labeled-pair",
        use_faiss=True,
    )
    logging.info(hard_train_dataset)
    # 4. Define the loss; pos_weight compensates for having ~5x more negatives than positives
    loss = BinaryCrossEntropyLoss(model=model, pos_weight=torch.tensor(num_hard_negatives))

    # 5a. Define a lightweight NanoBEIR evaluator on a few of its datasets
    nano_beir_evaluator = CrossEncoderNanoBEIREvaluator(
        dataset_names=["msmarco", "nfcorpus", "nq"],
        batch_size=train_batch_size,
    )

    # 5b. Define a reranking evaluator by mining hard negatives for the evaluation queries
    hard_eval_dataset = mine_hard_negatives(
        eval_dataset,
        embedding_model,
        corpus=full_dataset["answer"],
        num_negatives=30,
        batch_size=4096,
        include_positives=True,
        output_format="n-tuple",
        use_faiss=True,
    )
    logging.info(hard_eval_dataset)
    reranking_evaluator = CrossEncoderRerankingEvaluator(
        samples=[
            {
                "query": sample["question"],
                "positive": [sample["answer"]],
                "documents": [sample[column_name] for column_name in hard_eval_dataset.column_names[2:]],
            }
            for sample in hard_eval_dataset
        ],
        batch_size=train_batch_size,
        name="gooaq-dev",
        always_rerank_positives=False,
    )

    # 5c. Combine the evaluators and run them on the base model before training
    evaluator = SequentialEvaluator([reranking_evaluator, nano_beir_evaluator])
    evaluator(model)
    # 6. Define the training arguments
    short_model_name = model_name if "/" not in model_name else model_name.split("/")[-1]
    run_name = f"reranker-{short_model_name}-gooaq-bce"
    args = CrossEncoderTrainingArguments(
        # Required parameter:
        output_dir=f"models/{run_name}",
        # Optional training parameters:
        num_train_epochs=num_epochs,
        per_device_train_batch_size=train_batch_size,
        per_device_eval_batch_size=train_batch_size,
        learning_rate=2e-5,
        warmup_ratio=0.1,
        fp16=False,  # Set to True if your GPU can't handle BF16
        bf16=True,  # Set to False if your GPU can't handle BF16
        dataloader_num_workers=4,
        load_best_model_at_end=True,
        metric_for_best_model="eval_gooaq-dev_ndcg@10",
        # Optional tracking/debugging parameters:
        eval_strategy="steps",
        eval_steps=4000,
        save_strategy="steps",
        save_steps=4000,
        save_total_limit=2,
        logging_steps=1000,
        logging_first_step=True,
        run_name=run_name,  # Used in W&B if `wandb` is installed
        seed=12,
    )
    # 7. Create the trainer and start training
    trainer = CrossEncoderTrainer(
        model=model,
        args=args,
        train_dataset=hard_train_dataset,
        loss=loss,
        evaluator=evaluator,
    )
    trainer.train()

    # 8. Evaluate the final model
    evaluator(model)

    # 9. Save the final model locally
    final_output_dir = f"models/{run_name}/final"
    model.save_pretrained(final_output_dir)

    # 10. (Optional) Push it to the Hugging Face Hub
    try:
        model.push_to_hub(run_name)
    except Exception:
        logging.error(
            f"Error uploading model to the Hugging Face Hub:\n{traceback.format_exc()}To upload it manually, you can run "
            f"`huggingface-cli login`, followed by loading the model using `model = CrossEncoder({final_output_dir!r})` "
            f"and saving it using `model.push_to_hub('{run_name}')`."
        )


if __name__ == "__main__":
    main()
In this example I’m finetuning from answerdotai/ModernBERT-base, a base model that is not yet a Cross Encoder model. This generally requires more training data than finetuning an existing reranker model like Alibaba-NLP/gte-multilingual-reranker-base. I’m using 99k query-answer pairs from the GooAQ dataset, after which I mine hard negatives using the sentence-transformers/static-retrieval-mrl-en-v1 embedding model. This results in 578k labeled pairs: 99k positive pairs (i.e. label=1) and 479k negative pairs (i.e. label=0).
I use the BinaryCrossEntropyLoss, which is well suited to these labeled pairs. I also set up 2 types of evaluation: CrossEncoderNanoBEIREvaluator, which evaluates against the NanoBEIR benchmark, and CrossEncoderRerankingEvaluator, which evaluates the performance of reranking the top 30 results from the aforementioned static embedding model. Afterwards, I define a fairly standard set of hyperparameters, including learning rate, warmup ratio, bf16, loading the best model at the end, and some debugging parameters. Lastly, I run the trainer, perform post-training evaluation, and save the model both locally and on the Hugging Face Hub.
After running this script, the tomaarsen/reranker-ModernBERT-base-gooaq-bce model was uploaded for me. See the upcoming Evaluation section for evidence that this model outperformed 13 commonly used open-source alternatives, including much larger models. I also ran the script with answerdotai/ModernBERT-large as the base model, resulting in tomaarsen/reranker-ModernBERT-large-gooaq-bce.
Evaluation results are automatically stored in the generated model card upon saving a model, alongside the base model, language, license, evaluation results, training & evaluation dataset info, hyperparameters, training logs, and more. Without any effort, your uploaded models should contain all the information that your potential users would need to determine whether your model is suitable for them.
Callbacks
The CrossEncoderTrainer supports various transformers.TrainerCallback subclasses, including:
- WandbCallback for logging training metrics to W&B if wandb is installed
- TensorBoardCallback for logging training metrics to TensorBoard if tensorboard is accessible
- CodeCarbonCallback for tracking carbon emissions during training if codecarbon is installed
These are automatically used without you having to specify anything, as long as the required dependency is installed.
Refer to the Transformers Callbacks documentation for more information on these callbacks and how to create your own; a rough sketch follows below.
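As a small sketch of what a custom callback could look like (the class name and behaviour here are purely illustrative, not a library feature), you can subclass transformers.TrainerCallback and hand it to the trainer:
from transformers import TrainerCallback

class PrintMetricsCallback(TrainerCallback):
    # Illustrative callback: print the logged metrics at every logging step
    def on_log(self, args, state, control, logs=None, **kwargs):
        if logs:
            print(f"step {state.global_step}: {logs}")

# Pass it to the trainer like with any transformers Trainer:
# trainer = CrossEncoderTrainer(..., callbacks=[PrintMetricsCallback()])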
Multi-Dataset Training
Typically, top-performing general-purpose models are trained on multiple datasets at once. However, this approach can be difficult due to the varying formats of each dataset. Fortunately, the CrossEncoderTrainer lets you train on multiple datasets without requiring a uniform format. Additionally, it provides the flexibility to apply different loss functions to each dataset. Here are the steps to train with multiple datasets at once:
- Use a dictionary of datasets.Dataset instances (or a datasets.DatasetDict) as the train_dataset (and optionally also eval_dataset).
- (Optional) Use a dictionary of loss functions mapping dataset names to losses. Only required if you wish to use different loss functions for different datasets.
Each training/evaluation batch will only contain samples from one of the datasets. The order in which batches are sampled from the multiple datasets is defined by the MultiDatasetBatchSamplers enum, which can be passed to the CrossEncoderTrainingArguments via multi_dataset_batch_sampler. Valid options are:
- MultiDatasetBatchSamplers.ROUND_ROBIN: Round-robin sampling from each dataset until one is exhausted. With this strategy, it’s likely that not all samples from each dataset are used, but each dataset is sampled from equally.
- MultiDatasetBatchSamplers.PROPORTIONAL (default): Sample from each dataset in proportion to its size. With this strategy, all samples from each dataset are used and larger datasets are sampled from more frequently.
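Here is a minimal multi-dataset sketch; the dataset picks, subset sizes, and loss mapping are illustrative, so adjust them to your own data:
from datasets import load_dataset
from sentence_transformers.cross_encoder import CrossEncoder, CrossEncoderTrainer, CrossEncoderTrainingArguments
from sentence_transformers.cross_encoder.losses import MultipleNegativesRankingLoss
from sentence_transformers.training_args import MultiDatasetBatchSamplers

model = CrossEncoder("xlm-roberta-base", num_labels=1)

# A dictionary of datasets; each entry may have its own format
train_dataset = {
    "gooaq": load_dataset("sentence-transformers/gooaq", split="train").select(range(10_000)),
    "natural-questions": load_dataset("sentence-transformers/natural-questions", split="train").select(range(10_000)),
}

# A dictionary of losses with matching keys; a single loss shared by all datasets also works
loss = {
    "gooaq": MultipleNegativesRankingLoss(model),
    "natural-questions": MultipleNegativesRankingLoss(model),
}

args = CrossEncoderTrainingArguments(
    output_dir="models/multi-dataset-reranker",
    multi_dataset_batch_sampler=MultiDatasetBatchSamplers.PROPORTIONAL,  # the default
)

trainer = CrossEncoderTrainer(
    model=model,
    args=args,
    train_dataset=train_dataset,
    loss=loss,
)
trainer.train()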
Training Tips
Cross Encoder models have their own unique quirks, so here are some tips to help you out:
- CrossEncoder models overfit rather quickly, so it’s recommended to use an evaluator like CrossEncoderNanoBEIREvaluator or CrossEncoderRerankingEvaluator together with the load_best_model_at_end and metric_for_best_model training arguments to load the model with the best evaluation performance after training.
- CrossEncoders are particularly receptive to strong hard negatives (mine_hard_negatives). They teach the model to be very strict, useful e.g. when distinguishing between passages that answer a question and passages that merely relate to a question.
  - Note that if you only use hard negatives, your model may unexpectedly perform worse for easier tasks. This could mean that reranking the top 200 results from a first-stage retrieval system (e.g. with a SentenceTransformer model) can actually give worse top-10 results than reranking the top 100. Training with random negatives alongside hard negatives can mitigate this.
- Don’t underestimate BinaryCrossEntropyLoss: it remains a very strong option despite being simpler than learning-to-rank (LambdaLoss, ListNetLoss) or in-batch negatives (CachedMultipleNegativesRankingLoss, MultipleNegativesRankingLoss) losses, and its data is easy to prepare, especially using mine_hard_negatives.
Evaluation
I ran a reranking evaluation of my model from the Trainer section against several baselines on the GooAQ development set, with both always_rerank_positives=False and always_rerank_positives=True in the reranking evaluator. These represent the realistic (only rerank what the retriever found) and evaluation (rerank all positives, even if the retriever didn’t find them) formats, respectively.
As a reminder, I used the extremely efficient sentence-transformers/static-retrieval-mrl-en-v1 static embedding model to retrieve the top 30 documents for reranking.
Click to see the Evaluation Script & datasets
Here is the evaluation script:
import logging
from pprint import pprint
from datasets import load_dataset
from sentence_transformers.cross_encoder import CrossEncoder
from sentence_transformers.cross_encoder.evaluation import CrossEncoderRerankingEvaluator
logging.basicConfig(format="%(asctime)s - %(message)s", datefmt="%Y-%m-%d %H:%M:%S", level=logging.INFO)
def main():
    model_name = "tomaarsen/reranker-ModernBERT-base-gooaq-bce"

    eval_batch_size = 64
    model = CrossEncoder(model_name)

    # Load the GooAQ reranking dataset: 1k samples with a question, answer, and 30 retrieved documents each
    logging.info("Read the gooaq reranking dataset")
    hard_eval_dataset = load_dataset("tomaarsen/gooaq-reranker-blogpost-datasets", "rerank", split="eval")

    samples = [
        {
            "query": sample["question"],
            "positive": [sample["answer"]],
            "documents": [sample[column_name] for column_name in hard_eval_dataset.column_names[2:]],
        }
        for sample in hard_eval_dataset
    ]

    # Realistic setting: only rerank the documents that the first-stage retriever actually found
    reranking_evaluator = CrossEncoderRerankingEvaluator(
        samples=samples,
        batch_size=eval_batch_size,
        name="gooaq-dev-realistic",
        always_rerank_positives=False,
    )
    realistic_results = reranking_evaluator(model)
    pprint(realistic_results)

    # Evaluation setting: always rerank all positives, even if the retriever missed them
    reranking_evaluator = CrossEncoderRerankingEvaluator(
        samples=samples,
        batch_size=eval_batch_size,
        name="gooaq-dev-evaluation",
        always_rerank_positives=True,
    )
    evaluation_results = reranking_evaluator(model)
    pprint(evaluation_results)


if __name__ == "__main__":
    main()
This script uses the rerank subset from my tomaarsen/gooaq-reranker-blogpost-datasets dataset. This dataset contains:
- pair subset, train split: 99k training samples taken directly from GooAQ. This is not directly used for training, but for preparing the hard-labeled-pair subset, which is used in training.
- pair subset, eval split: 1k samples taken directly from GooAQ, with no overlap with the previous 99k. This is not directly used for evaluation, but used to prepare the rerank subset, which is used in evaluation.
- hard-labeled-pair subset, train split: 578k labeled pairs used for training, mined with sentence-transformers/static-retrieval-mrl-en-v1 from the 99k samples in the pair subset, train split. This dataset is used in training.
- rerank subset, eval split: 1k samples with a question, an answer, and exactly 30 documents as retrieved by sentence-transformers/static-retrieval-mrl-en-v1 using the full 100k train and evaluation answers from my subset of GooAQ. This ranking already has an NDCG@10 of 59.12.
With just 99k out of the 3 million training pairs from the gooaq dataset and just 30 minutes of training on my RTX 3090, my small 150M parameter tomaarsen/reranker-ModernBERT-base-gooaq-bce was able to handily outperform every <1B parameter general-purpose reranker. The larger tomaarsen/reranker-ModernBERT-large-gooaq-bce took less than an hour to train, and is in a league of its own with a massive 79.42 NDCG@10 in the realistic setting. The GooAQ training and evaluation dataset aligns very well with what these baselines were trained for, so the difference should be even larger when training for a more niche domain.
Note that this doesn’t mean that tomaarsen/reranker-ModernBERT-large-gooaq-bce is the strongest model here on all domains: it’s simply the strongest on our domain. That is completely fine, as we just need this reranker to work well on our data.
Don’t underestimate the power of finetuning reranker models on your domain. You can improve both the search performance and the latency of your search stack by finetuning a (small) reranker!
Additional Resources
Training Examples
These pages have training examples with explanations as well as links to training scripts. You can use them to get familiar with the reranker training loop:
Documentation
For further learning, you may also want to explore the following resources on Sentence Transformers:
And here is an advanced page which may interest you: