Training and Finetuning Embedding Models with Sentence Transformers v3

Tom Aarsen


Sentence Transformers is a Python library for using and training embedding models for a wide range of applications, such as retrieval augmented generation, semantic search, semantic textual similarity, paraphrase mining, and more. Its v3.0 update is the largest since the project’s inception, introducing a new training approach. In this blogpost, I’ll show you how to use it to finetune Sentence Transformer models to improve their performance on specific tasks. You can also use this method to train new Sentence Transformer models from scratch.

Finetuning Sentence Transformers now involves several components, including datasets, loss functions, training arguments, evaluators, and the new trainer itself. I’ll go through each of these components in detail and provide examples of how to use them to train effective models.






Why Finetune?

Finetuning Sentence Transformer models can significantly improve their performance on specific tasks. This is because each task requires a unique notion of similarity. Let’s consider a couple of news article headlines as an example:

  • “Apple launches the new iPad”
  • “NVIDIA is gearing up for the next GPU generation”

Depending on the use case, we might want similar or dissimilar embeddings for these texts. For instance, a classification model for news articles could treat these texts as similar since they both belong to the Technology category. On the other hand, a semantic textual similarity or retrieval model should consider them dissimilar due to their distinct meanings.



Training Components

Training Sentence Transformer models involves the following components:

  1. Dataset: The data used for training and evaluation.
  2. Loss Function: A function that quantifies the model’s performance and guides the optimization process.
  3. Training Arguments (optional): Parameters that influence training performance and tracking/debugging.
  4. Evaluator (optional): A tool for evaluating the model before, during, or after training.
  5. Trainer: Brings together the model, dataset, loss function, and other components for training.

Now, let’s dive into each of these components in more detail.



Dataset

The SentenceTransformerTrainer uses datasets.Dataset or datasets.DatasetDict instances for training and evaluation. You can load data from the Hugging Face Datasets Hub or use local data in various formats such as CSV, JSON, Parquet, Arrow, or SQL.

Note: Many Hugging Face datasets that work out of the box with Sentence Transformers have been tagged with sentence-transformers, allowing you to easily find them by browsing to https://huggingface.co/datasets?other=sentence-transformers. We strongly recommend that you browse these datasets to find training datasets that might be useful for your tasks.



Data on Hugging Face Hub

To load data from datasets in the Hugging Face Hub, use the load_dataset function:

from datasets import load_dataset

train_dataset = load_dataset("sentence-transformers/all-nli", "pair-class", split="train")
eval_dataset = load_dataset("sentence-transformers/all-nli", "pair-class", split="dev")

print(train_dataset)
"""
Dataset({
    features: ['premise', 'hypothesis', 'label'],
    num_rows: 942069
})
"""

Some datasets, like sentence-transformers/all-nli, have multiple subsets with different data formats. You need to specify the subset name along with the dataset name.



Local Data (CSV, JSON, Parquet, Arrow, SQL)

If you have local data in common file formats, you can easily load it using load_dataset too:

from datasets import load_dataset

dataset = load_dataset("csv", data_files="my_file.csv")
# or
dataset = load_dataset("json", data_files="my_file.json")



Local Data that requires pre-processing

If your local data requires pre-processing, you can use datasets.Dataset.from_dict to initialize your dataset with a dictionary of lists:

from datasets import Dataset

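anchors = []
positives = []
# Open a file, do preprocessing, filter out bad data, etc.
# and append the results to `anchors` and `positives`
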
dataset = Dataset.from_dict({
    "anchor": anchors,
    "positive": positives,
})

Each key in the dictionary becomes a column in the resulting dataset.



Dataset Format

It’s crucial to ensure that your dataset format matches your chosen loss function. This involves checking two things:

  1. If your loss function requires a Label (as indicated in the Loss Overview table), your dataset must have a column named “label” or “score”.
  2. All columns other than “label” or “score” are considered Inputs (as indicated in the Loss Overview table). The number of these columns must match the number of valid inputs for your chosen loss function. The names of the columns don’t matter, only their order matters.

For example, if your loss function accepts (anchor, positive, negative) triplets, then your first, second, and third dataset columns correspond with anchor, positive, and negative, respectively. This means that your first and second column must contain texts that should embed closely, and that your first and third column must contain texts that should embed far apart. That is why, depending on your loss function, your dataset column order matters.

Consider a dataset with columns ["text1", "text2", "label"], where the "label" column contains floating point similarity scores. This dataset can be used with CoSENTLoss, AnglELoss, and CosineSimilarityLoss because:

  1. The dataset has a “label” column, which is required by these loss functions.
  2. The dataset has 2 non-label columns, matching the number of inputs required by these loss functions.
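The two checks above are mechanical enough to script. The helper below, check_loss_compatibility, is hypothetical and not part of Sentence Transformers; you supply the number of inputs and whether a label is needed, as read from the Loss Overview table for your loss.

```python
LABEL_COLUMNS = ("label", "score")


def check_loss_compatibility(column_names, num_inputs, needs_label):
    """Hypothetical helper: return a list of problems, empty if the columns look compatible."""
    label_columns = [c for c in column_names if c in LABEL_COLUMNS]
    input_columns = [c for c in column_names if c not in LABEL_COLUMNS]
    problems = []
    if needs_label and not label_columns:
        problems.append("missing a 'label' or 'score' column")
    # Every non-label column is treated as an input, in order
    if len(input_columns) != num_inputs:
        problems.append(f"expected {num_inputs} input columns, found {len(input_columns)}: {input_columns}")
    return problems


# The CoSENTLoss example from the text: 2 inputs plus a label
print(check_loss_compatibility(["text1", "text2", "label"], num_inputs=2, needs_label=True))
# A triplet dataset with a stray ID column would be read as 4 inputs
print(check_loss_compatibility(["anchor", "positive", "negative", "sample_id"], num_inputs=3, needs_label=False))
```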

If the columns in your dataset are not ordered correctly, use Dataset.select_columns to reorder them. Additionally, remove any extraneous columns (e.g., sample_id, metadata, source, type) using Dataset.remove_columns, as they will be treated as inputs otherwise.



Loss Function

Loss functions measure how well a model performs on a given batch of data and guide the optimization process. The choice of loss function depends on your available data and target task. Refer to the Loss Overview for a comprehensive list of options.

Most loss functions can be initialized with just the SentenceTransformer model that you’re training:

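from datasets import load_dataset
from sentence_transformers import SentenceTransformer
from sentence_transformers.losses import CoSENTLoss

# Load a model to train/finetune
model = SentenceTransformer("FacebookAI/xlm-roberta-base")

# Initialize the CoSENTLoss
# This loss requires pairs of text and a float similarity score as a label
loss = CoSENTLoss(model)

# Load an example training dataset that works with this loss function
train_dataset = load_dataset("sentence-transformers/all-nli", "pair-score", split="train")
"""
Dataset({
    features: ['sentence1', 'sentence2', 'label'],
    num_rows: 942069
})
"""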
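Some losses need no label at all. MultipleNegativesRankingLoss, used in the trainer example later in this post, relies on in-batch negatives: for each anchor, every other positive in the batch acts as a negative. The function below is a simplified sketch of that idea on precomputed embeddings, not the library’s implementation; the cosine similarity and the scale of 20 mirror that loss’s defaults.

```python
import torch
import torch.nn.functional as F


def in_batch_negatives_loss(anchors: torch.Tensor, positives: torch.Tensor, scale: float = 20.0) -> torch.Tensor:
    """Simplified sketch of a MultipleNegativesRankingLoss-style objective.

    Row i of `anchors` should match row i of `positives`; all other
    positives in the batch serve as negatives for anchor i.
    """
    # Pairwise cosine similarities, shape (batch_size, batch_size)
    scores = F.normalize(anchors, dim=-1) @ F.normalize(positives, dim=-1).T * scale
    # The correct "class" for anchor i is positive i, i.e. the diagonal
    labels = torch.arange(scores.size(0))
    return F.cross_entropy(scores, labels)


aligned = torch.eye(4)
# Near zero: each anchor is most similar to its own positive
print(in_batch_negatives_loss(aligned, aligned))
# Large: the positives are shifted by one row, so every pairing is wrong
print(in_batch_negatives_loss(aligned, aligned.roll(1, dims=0)))
```

This also shows why duplicates hurt: if two rows share the same positive text, one anchor’s “negative” is identical to its positive, which is what BatchSamplers.NO_DUPLICATES avoids.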



Training Arguments

The SentenceTransformerTrainingArguments class lets you specify parameters that influence training performance and tracking/debugging. While optional, experimenting with these arguments can help improve training efficiency and provide insights into the training process.

In the Sentence Transformers documentation, I’ve outlined some of the most useful training arguments. I would recommend reading it in Training Overview > Training Arguments.

Here’s an example of how to initialize SentenceTransformerTrainingArguments:

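from sentence_transformers.training_args import BatchSamplers, SentenceTransformerTrainingArguments

args = SentenceTransformerTrainingArguments(
    # Required parameter:
    output_dir="models/mpnet-base-all-nli-triplet",
    # Optional training parameters:
    num_train_epochs=1,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    warmup_ratio=0.1,
    fp16=True,  # Set to False if your GPU can't run on FP16
    bf16=False,  # Set to True if your GPU supports BF16
    batch_sampler=BatchSamplers.NO_DUPLICATES,  # Losses that use in-batch negatives benefit from no duplicates
    # Optional tracking/debugging parameters:
    eval_strategy="steps",
    eval_steps=100,
    save_strategy="steps",
    save_steps=100,
    save_total_limit=2,
    logging_steps=100,
    run_name="mpnet-base-all-nli-triplet",  # Used in W&B if `wandb` is installed
)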

Note that eval_strategy was introduced in transformers version 4.41.0. Prior versions should use evaluation_strategy instead.



Evaluator

You can provide the SentenceTransformerTrainer with an eval_dataset to get the evaluation loss during training, but it may be useful to get more concrete metrics during training, too. For this, you can use evaluators to assess the model’s performance with useful metrics before, during, or after training. You can use both an eval_dataset and an evaluator, one or the other, or neither. They evaluate based on the eval_strategy and eval_steps Training Arguments.

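Here are the implemented Evaluators that come with Sentence Transformers, along with the data each one requires:

  • BinaryClassificationEvaluator: pairs with class labels
  • EmbeddingSimilarityEvaluator: pairs with similarity scores
  • InformationRetrievalEvaluator: queries (qid => question), a corpus (cid => document), and relevant documents (qid => set[cid])
  • MSEEvaluator: source sentences to embed with a teacher model and target sentences to embed with the student model (these can be the same texts)
  • ParaphraseMiningEvaluator: a mapping of IDs to sentences, and pairs with IDs of duplicate sentences
  • RerankingEvaluator: a list of {'query': '..', 'positive': [...], 'negative': [...]} dictionaries
  • TranslationEvaluator: pairs of sentences in two separate languages
  • TripletEvaluator: (anchor, positive, negative) triplets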

Additionally, you can use SequentialEvaluator to combine multiple evaluators into one, which can then be passed to the SentenceTransformerTrainer.

If you don’t have the necessary evaluation data but still want to track the model’s performance on common benchmarks, you can use these evaluators with data from Hugging Face:



EmbeddingSimilarityEvaluator with STSb

The STS Benchmark (a.k.a. STSb) is a commonly used benchmarking dataset to measure the model’s understanding of semantic textual similarity of short texts like “A person is feeding a mouse to a snake.”.

Feel free to browse the sentence-transformers/stsb dataset on Hugging Face.

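from datasets import load_dataset
from sentence_transformers.evaluation import EmbeddingSimilarityEvaluator, SimilarityFunction

# Load the STSB dataset (https://huggingface.co/datasets/sentence-transformers/stsb)
eval_dataset = load_dataset("sentence-transformers/stsb", split="validation")

# Initialize the evaluator
dev_evaluator = EmbeddingSimilarityEvaluator(
    sentences1=eval_dataset["sentence1"],
    sentences2=eval_dataset["sentence2"],
    scores=eval_dataset["score"],
    main_similarity=SimilarityFunction.COSINE,
    name="sts-dev",
)
# Run evaluation manually:
# print(dev_evaluator(model))
# Later, you can provide this evaluator to the trainer to get results during training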
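The headline number this evaluator reports with main_similarity=SimilarityFunction.COSINE is the Spearman rank correlation between the cosine similarity of each embedded pair and its gold score. The sketch below reproduces only that one metric with NumPy and SciPy on made-up 2D "embeddings"; the real evaluator also computes other correlations and similarity functions.

```python
import numpy as np
from scipy.stats import spearmanr


def cosine_spearman(embeddings1: np.ndarray, embeddings2: np.ndarray, gold_scores) -> float:
    """Spearman correlation between row-wise cosine similarities and gold similarity scores."""
    cosine = np.sum(embeddings1 * embeddings2, axis=1) / (
        np.linalg.norm(embeddings1, axis=1) * np.linalg.norm(embeddings2, axis=1)
    )
    rho, _ = spearmanr(cosine, gold_scores)
    return float(rho)


# Made-up embeddings: the second text of each pair rotates toward the first
angles = np.radians([80, 60, 30, 0])
embeddings1 = np.tile([1.0, 0.0], (4, 1))
embeddings2 = np.stack([np.cos(angles), np.sin(angles)], axis=1)
gold_scores = [0.1, 0.4, 0.7, 1.0]

# Cosine similarity rises in the same order as the gold scores, so the rank correlation is perfect
print(cosine_spearman(embeddings1, embeddings2, gold_scores))
```

Because Spearman only compares rankings, the model is rewarded for ordering pairs correctly rather than for matching the exact score values.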






TripletEvaluator with AllNLI

AllNLI is a concatenation of the SNLI and MultiNLI datasets, both of which are datasets for Natural Language Inference. This task is traditionally for determining whether two texts are an entailment, contradiction, or neither. It has since been adopted for training embedding models, as the entailing and contradictory sentences make for useful (anchor, positive, negative) triplets: a common format for training embedding models.

In this snippet, it is used to evaluate how frequently the model considers the anchor text and the entailing text to be more similar than the anchor text and the contradictory text. An example text is “An older man is drinking orange juice at a restaurant.”.

Feel free to browse the sentence-transformers/all-nli dataset on Hugging Face.

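from datasets import load_dataset
from sentence_transformers.evaluation import TripletEvaluator, SimilarityFunction

# Load triplets from the AllNLI dataset (https://huggingface.co/datasets/sentence-transformers/all-nli)
max_samples = 1000
eval_dataset = load_dataset("sentence-transformers/all-nli", "triplet", split=f"dev[:{max_samples}]")

# Initialize the evaluator
dev_evaluator = TripletEvaluator(
    anchors=eval_dataset["anchor"],
    positives=eval_dataset["positive"],
    negatives=eval_dataset["negative"],
    main_distance_function=SimilarityFunction.COSINE,
    name=f"all-nli-{max_samples}-dev",
)
# Run evaluation manually:
# print(dev_evaluator(model))
# Later, you can provide this evaluator to the trainer to get results during training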
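The core metric here is triplet accuracy: the fraction of triplets where cosine_similarity(anchor, positive) > cosine_similarity(anchor, negative). The NumPy sketch below computes just that number on made-up 2D embeddings, so you can see exactly what is being counted; it is an illustration, not the evaluator’s code.

```python
import numpy as np


def row_cosine(a: np.ndarray, b: np.ndarray) -> np.ndarray:
    """Cosine similarity between matching rows of a and b."""
    return np.sum(a * b, axis=1) / (np.linalg.norm(a, axis=1) * np.linalg.norm(b, axis=1))


def triplet_accuracy(anchors: np.ndarray, positives: np.ndarray, negatives: np.ndarray) -> float:
    """Fraction of triplets where cos(anchor, positive) > cos(anchor, negative)."""
    return float(np.mean(row_cosine(anchors, positives) > row_cosine(anchors, negatives)))


anchors = np.array([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]])
positives = np.array([[0.9, 0.1], [0.1, 0.9], [0.0, 1.0]])
# In the third triplet the "negative" is closer to the anchor than the positive is
negatives = np.array([[0.0, 1.0], [1.0, 0.0], [1.0, 0.9]])
print(triplet_accuracy(anchors, positives, negatives))  # 2 of 3 triplets are ranked correctly
```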






Trainer

The SentenceTransformerTrainer brings together the model, dataset, loss function, and other components for training:

from datasets import load_dataset
from sentence_transformers import (
    SentenceTransformer,
    SentenceTransformerTrainer,
    SentenceTransformerTrainingArguments,
    SentenceTransformerModelCardData,
)
from sentence_transformers.losses import MultipleNegativesRankingLoss
from sentence_transformers.training_args import BatchSamplers
from sentence_transformers.evaluation import TripletEvaluator

# 1. Load a model to finetune with 2. (Optional) model card data
model = SentenceTransformer(
    "microsoft/mpnet-base",
    model_card_data=SentenceTransformerModelCardData(
        language="en",
        license="apache-2.0",
        model_name="MPNet base trained on AllNLI triplets",
    )
)

# 3. Load a dataset to finetune on
dataset = load_dataset("sentence-transformers/all-nli", "triplet")
train_dataset = dataset["train"].select(range(100_000))
eval_dataset = dataset["dev"]
test_dataset = dataset["test"]

# 4. Define a loss function
loss = MultipleNegativesRankingLoss(model)

# 5. (Optional) Specify training arguments
args = SentenceTransformerTrainingArguments(
    # Required parameter:
    output_dir="models/mpnet-base-all-nli-triplet",
    # Optional training parameters:
    num_train_epochs=1,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    warmup_ratio=0.1,
    fp16=True,  # Set to False if your GPU can't run on FP16
    bf16=False,  # Set to True if your GPU supports BF16
    batch_sampler=BatchSamplers.NO_DUPLICATES,  # MultipleNegativesRankingLoss benefits from no duplicate samples in a batch
    # Optional tracking/debugging parameters:
    eval_strategy="steps",
    eval_steps=100,
    save_strategy="steps",
    save_steps=100,
    save_total_limit=2,
    logging_steps=100,
    run_name="mpnet-base-all-nli-triplet",  # Used in W&B if `wandb` is installed
)

# 6. (Optional) Create an evaluator & evaluate the base model
dev_evaluator = TripletEvaluator(
    anchors=eval_dataset["anchor"],
    positives=eval_dataset["positive"],
    negatives=eval_dataset["negative"],
    name="all-nli-dev",
)
dev_evaluator(model)

# 7. Create a trainer & train
trainer = SentenceTransformerTrainer(
    model=model,
    args=args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    loss=loss,
    evaluator=dev_evaluator,
)
trainer.train()

# (Optional) Evaluate the trained model on the test set
test_evaluator = TripletEvaluator(
    anchors=test_dataset["anchor"],
    positives=test_dataset["positive"],
    negatives=test_dataset["negative"],
    name="all-nli-test",
)
test_evaluator(model)

# 8. Save the trained model
model.save_pretrained("models/mpnet-base-all-nli-triplet/final")

# 9. (Optional) Push it to the Hugging Face Hub
model.push_to_hub("mpnet-base-all-nli-triplet")

In this example I’m finetuning from microsoft/mpnet-base, a base model that is not yet a Sentence Transformer model. This requires more training data than finetuning an existing Sentence Transformer model, like all-mpnet-base-v2.
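When SentenceTransformer loads a plain Hugging Face checkpoint like microsoft/mpnet-base that has no Sentence Transformers configuration, it adds a mean pooling layer on top of the transformer to turn token embeddings into one sentence embedding. The function below is a minimal PyTorch sketch of mean pooling with an attention mask, on made-up tensors; it illustrates the operation rather than reproducing the library’s Pooling module.

```python
import torch


def mean_pooling(token_embeddings: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
    """Average token embeddings per sequence, ignoring padding positions.

    token_embeddings: (batch_size, seq_len, hidden_dim)
    attention_mask:   (batch_size, seq_len), 1 for real tokens and 0 for padding
    """
    mask = attention_mask.unsqueeze(-1).to(token_embeddings.dtype)
    summed = (token_embeddings * mask).sum(dim=1)
    # Guard against division by zero for fully padded sequences
    counts = mask.sum(dim=1).clamp(min=1e-9)
    return summed / counts


# Two made-up sequences of length 3; the second ends with one padding token
tokens = torch.tensor([
    [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]],
    [[2.0, 2.0], [4.0, 4.0], [100.0, 100.0]],
])
mask = torch.tensor([[1, 1, 1], [1, 1, 0]])
print(mean_pooling(tokens, mask))  # the padded [100., 100.] token does not affect the second row
```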

After running this script, the tomaarsen/mpnet-base-all-nli-triplet model was uploaded for me. The triplet accuracy using cosine similarity, i.e. what percentage of the time cosine_similarity(anchor, positive) > cosine_similarity(anchor, negative), is 90.04% for the development set and 91.5% for the test set! For reference, the microsoft/mpnet-base model scored only 68.32% on the dev set before training.

All of this information is stored in the automatically generated model card, including the base model, language, license, evaluation results, training & evaluation dataset info, hyperparameters, training logs, and more. Without any effort, your uploaded models should contain all the information that your potential users would need to determine whether your model is suitable for them.



Callbacks

The Sentence Transformers trainer supports various transformers.TrainerCallback subclasses, including:

  • WandbCallback for logging training metrics to W&B if wandb is installed
  • TensorBoardCallback for logging training metrics to TensorBoard if tensorboard is accessible
  • CodeCarbonCallback for tracking carbon emissions during training if codecarbon is installed

These are automatically used without you having to specify anything, as long as the required dependency is installed.

Refer to the Transformers Callbacks documentation for more information on these callbacks and how to create your own.
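Creating your own callback means subclassing transformers.TrainerCallback and overriding one of its event methods. As a minimal sketch, the hypothetical LossHistoryCallback below overrides on_log and keeps every logged training loss in memory.

```python
from transformers import TrainerCallback


class LossHistoryCallback(TrainerCallback):
    """Hypothetical example callback that records every logged training loss."""

    def __init__(self):
        self.losses = []

    def on_log(self, args, state, control, logs=None, **kwargs):
        # `logs` is the dictionary the trainer is about to log, e.g. {"loss": ..., "learning_rate": ...}
        if logs and "loss" in logs:
            self.losses.append(logs["loss"])
```

It can then be passed as callbacks=[LossHistoryCallback()] when constructing the SentenceTransformerTrainer, which should accept the same callbacks argument as the transformers Trainer.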



Multi-Dataset Training

Top-performing models are often trained using multiple datasets at once. The SentenceTransformerTrainer simplifies this process by allowing you to train with multiple datasets without converting them to the same format. You can even apply different loss functions to each dataset. Here are the steps for multi-dataset training:

  1. Use a dictionary of datasets.Dataset instances (or a datasets.DatasetDict) as the train_dataset and eval_dataset.
  2. (Optional) Use a dictionary of loss functions mapping dataset names to losses if you want to use different losses for different datasets.

Each training/evaluation batch will contain samples from only one of the datasets. The order in which batches are sampled from the multiple datasets is determined by the MultiDatasetBatchSamplers enum, which can be passed to the SentenceTransformerTrainingArguments via multi_dataset_batch_sampler. The valid options are:

  • MultiDatasetBatchSamplers.ROUND_ROBIN: Samples from each dataset in a round-robin fashion until one is exhausted. This strategy may not use all samples from each dataset, but it ensures equal sampling from each dataset.
  • MultiDatasetBatchSamplers.PROPORTIONAL (default): Samples from each dataset proportionally to its size. This strategy ensures that all samples from each dataset are used, and larger datasets are sampled from more frequently.
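The difference between the two strategies is easiest to see as a batch order. The hypothetical functions below simulate only which dataset each batch comes from, given a number of batches per dataset; they are a simplified model of the behaviour described above, not the library’s sampler code, and the exact stopping rule of the real round-robin sampler may differ.

```python
import random
from collections import Counter


def round_robin_order(num_batches: dict) -> list:
    """Alternate between datasets one batch at a time; stop once the smallest dataset runs out."""
    if not num_batches:
        return []
    rounds = min(num_batches.values())
    return [name for _ in range(rounds) for name in num_batches]


def proportional_order(num_batches: dict, seed: int = 0) -> list:
    """Use every batch of every dataset once, shuffled, so larger datasets appear more often."""
    order = [name for name, count in num_batches.items() for _ in range(count)]
    random.Random(seed).shuffle(order)
    return order


sizes = {"all-nli-triplet": 3, "stsb": 1}
print(round_robin_order(sizes))  # ['all-nli-triplet', 'stsb']
print(Counter(proportional_order(sizes)))
```

With round-robin, two of the three all-nli-triplet batches are never used; with proportional sampling, all four batches are used. In real training you would select the strategy with multi_dataset_batch_sampler=MultiDatasetBatchSamplers.ROUND_ROBIN, importing MultiDatasetBatchSamplers from sentence_transformers.training_args.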

Multi-task training has proven to be highly effective. For instance, Huang et al. 2024 employed MultipleNegativesRankingLoss, CoSENTLoss, and a variation of MultipleNegativesRankingLoss without in-batch negatives and only hard negatives to achieve state-of-the-art performance on Chinese. They also applied MatryoshkaLoss to enable the model to produce Matryoshka Embeddings.

Here’s an example of multi-dataset training:

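from datasets import load_dataset
from sentence_transformers import SentenceTransformer, SentenceTransformerTrainer
from sentence_transformers.losses import CoSENTLoss, MultipleNegativesRankingLoss, SoftmaxLoss

# 1. Load a model to finetune
model = SentenceTransformer("bert-base-uncased")

# 2. Load several datasets to train with
# (anchor, positive)
all_nli_pair_train = load_dataset("sentence-transformers/all-nli", "pair", split="train[:10000]")
# (premise, hypothesis) + label
all_nli_pair_class_train = load_dataset("sentence-transformers/all-nli", "pair-class", split="train[:10000]")
# (sentence1, sentence2) + score
all_nli_pair_score_train = load_dataset("sentence-transformers/all-nli", "pair-score", split="train[:10000]")
# (anchor, positive, negative)
all_nli_triplet_train = load_dataset("sentence-transformers/all-nli", "triplet", split="train[:10000]")
# (sentence1, sentence2) + score
stsb_pair_score_train = load_dataset("sentence-transformers/stsb", split="train[:10000]")
# (anchor, positive)
quora_pair_train = load_dataset("sentence-transformers/quora-duplicates", "pair", split="train[:10000]")
# (query, answer)
natural_questions_train = load_dataset("sentence-transformers/natural-questions", split="train[:10000]")

# Combine all datasets into a dictionary mapping dataset names to datasets
train_dataset = {
    "all-nli-pair": all_nli_pair_train,
    "all-nli-pair-class": all_nli_pair_class_train,
    "all-nli-pair-score": all_nli_pair_score_train,
    "all-nli-triplet": all_nli_triplet_train,
    "stsb": stsb_pair_score_train,
    "quora": quora_pair_train,
    "natural-questions": natural_questions_train,
}

# 3. Load several datasets to evaluate with
# (anchor, positive, negative)
all_nli_triplet_dev = load_dataset("sentence-transformers/all-nli", "triplet", split="dev")
# (sentence1, sentence2) + score
stsb_pair_score_dev = load_dataset("sentence-transformers/stsb", split="validation")
# (anchor, positive); rows 10000-11000 of the train split, disjoint from the training rows
quora_pair_dev = load_dataset("sentence-transformers/quora-duplicates", "pair", split="train[10000:11000]")
# (query, answer); rows 10000-11000 of the train split, disjoint from the training rows
natural_questions_dev = load_dataset("sentence-transformers/natural-questions", split="train[10000:11000]")

# Use a dictionary for the evaluation dataset too, or just use one dataset or none at all
eval_dataset = {
    "all-nli-triplet": all_nli_triplet_dev,
    "stsb": stsb_pair_score_dev,
    "quora": quora_pair_dev,
    "natural-questions": natural_questions_dev,
}

# 4. Load several loss functions to train with
# (anchor, positive), (anchor, positive, negative)
mnrl_loss = MultipleNegativesRankingLoss(model)
# (sentence_A, sentence_B) + class; SoftmaxLoss needs the embedding size and the number of
# classes (3 for NLI: entailment, neutral, contradiction)
softmax_loss = SoftmaxLoss(
    model,
    sentence_embedding_dimension=model.get_sentence_embedding_dimension(),
    num_labels=3,
)
# (sentence_A, sentence_B) + score
cosent_loss = CoSENTLoss(model)

# Map dataset names to loss functions so the trainer knows which loss to apply where
# Note: you can also pass a single loss if all of your datasets use the same one
losses = {
    "all-nli-pair": mnrl_loss,
    "all-nli-pair-class": softmax_loss,
    "all-nli-pair-score": cosent_loss,
    "all-nli-triplet": mnrl_loss,
    "stsb": cosent_loss,
    "quora": mnrl_loss,
    "natural-questions": mnrl_loss,
}

# 5. Define a simple trainer, although it's recommended to use one with args & evaluators
trainer = SentenceTransformerTrainer(
    model=model,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    loss=losses,
)
trainer.train()

# 6. Save the trained model and optionally push it to the Hugging Face Hub
model.save_pretrained("bert-base-all-nli-stsb-quora-nq")
model.push_to_hub("bert-base-all-nli-stsb-quora-nq")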



Deprecation

Prior to the Sentence Transformer v3 release, all models would be trained using the SentenceTransformer.fit method. Rather than deprecating this method, starting from v3.0, this method will use the SentenceTransformerTrainer behind the scenes. This means that your old training code should still work, and can even be upgraded with the new features such as multi-GPU training, loss logging, etc. That said, the new training approach is much more powerful, so it’s recommended to write new training scripts using the new approach.



Additional Resources



Training Examples

The following pages contain training examples with explanations as well as links to code. We recommend that you browse through these to familiarize yourself with the training loop:



Documentation

Additionally, the following pages may be useful to learn more about Sentence Transformers:

And lastly, here are some advanced pages that might interest you:


