Welcome EmbeddingGemma, Google’s latest efficient embedding model



Today, Google releases EmbeddingGemma, a state-of-the-art multilingual embedding model perfect for on-device use cases. Designed for speed and efficiency, the model features a compact size of 308M parameters and a 2K context window, unlocking new possibilities for mobile RAG pipelines, agents, and more. EmbeddingGemma is trained to support over 100 languages and, at the time of writing, is the highest-ranking text-only multilingual embedding model under 500M parameters on the Massive Text Embedding Benchmark (MTEB).






Introduction

Text embeddings have become the backbone of modern natural‑language applications, turning words, sentences, and documents into dense vectors that capture meaning, sentiment, and intent. These vectors enable fast similarity search, clustering, classification, and retrieval across massive corpora, powering everything from recommendation engines and semantic search to retrieval-augmented generation and code‑search tools. Embedding models that calculate these embeddings are widely used, with well over 200 million monthly downloads on Hugging Face.

Building on this foundation, Google DeepMind’s EmbeddingGemma arrives as the newest and most capable small multilingual embedding model yet. With just 308M parameters, a 2k‑token context window, and support for over 100 languages, EmbeddingGemma delivers state‑of‑the‑art performance on the Massive Multilingual Text Embedding Benchmark (MMTEB) while using less than 200 MB of RAM when quantized.
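As a rough sanity check on the memory figure, the weights-only footprint of a 308M-parameter model can be estimated for a few bit widths. This is a back-of-the-envelope sketch, not how the 200 MB figure was measured: it assumes every parameter is stored at the same bit width and ignores activations, the tokenizer, and runtime overhead.

```python
NUM_PARAMS = 308_000_000  # parameter count stated above


def weight_memory_mb(num_params: int, bits_per_param: int) -> float:
    """Weights-only memory in megabytes (1 MB = 10**6 bytes).

    Assumes a uniform bit width; ignores activations and runtime overhead.
    """
    return num_params * bits_per_param / 8 / 1_000_000


for bits in (32, 16, 8, 4):
    print(f"{bits:>2}-bit weights: ~{weight_memory_mb(NUM_PARAMS, bits):,.0f} MB")
```

Under these assumptions only the 4-bit estimate (about 154 MB) falls below 200 MB; the text above does not specify which quantization scheme the figure refers to.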

These design decisions result in a very practical, open-source tool for computing high-quality multilingual embeddings on everyday devices.

In this blog post, we describe the EmbeddingGemma architecture and training, and show you how to use the model with various frameworks like Sentence Transformers, LangChain, LlamaIndex, Haystack, txtai, Transformers.js, Text Embeddings Inference, and ONNX.

Afterwards, we demonstrate how to finetune EmbeddingGemma on your domain for even stronger performance. In our example, we finetune EmbeddingGemma on the Medical Instruction and Retrieval Dataset (MIRIAD). The resulting model, sentence-transformers/embeddinggemma-300m-medical, achieves state-of-the-art performance on our task: retrieving passages of scientific medical papers in response to detailed medical questions. It even outperforms models twice its size on this task.



Architecture

EmbeddingGemma builds on the Gemma3 transformer backbone, but is modified to use bi-directional attention instead of causal (one-way) attention. This means that earlier tokens in the sequence can attend to later tokens, effectively turning the architecture from a decoder into an encoder. Encoder models can outperform LLMs, which are decoders, on embedding tasks like retrieval (Weller et al., 2025). With this backbone, the model can process up to 2048 tokens at once, which is sufficient for typical retrieval inputs, especially given that larger inputs often result in information loss in the text embeddings.
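To make the causal versus bi-directional distinction concrete, the sketch below builds both kinds of attention mask for a short sequence. It illustrates the general idea only and is not EmbeddingGemma's implementation; padding masks and the attention computation itself are left out.

```python
def causal_mask(n: int) -> list[list[bool]]:
    """Decoder-style mask: token i may attend only to tokens 0..i."""
    return [[j <= i for j in range(n)] for i in range(n)]


def bidirectional_mask(n: int) -> list[list[bool]]:
    """Encoder-style mask: every token may attend to every token."""
    return [[True] * n for _ in range(n)]


# Print both masks for a 4-token sequence ("x" = allowed, "." = blocked).
for name, mask in (("causal", causal_mask(4)), ("bidirectional", bidirectional_mask(4))):
    print(name)
    for row in mask:
        print(" ".join("x" if allowed else "." for allowed in row))
```

Row 0 shows the key difference: under the causal mask the first token sees only itself, while under the bi-directional mask it sees the whole sequence.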

Beyond the new Gemma3-based encoder backbone, which produces token embeddings, a mean pooling layer converts these token embeddings into a text embedding. Lastly, two dense layers transform the text embedding into its final form, a 768-dimensional vector.
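The pooling-plus-dense head can be sketched in plain Python. Everything here is a toy under stated assumptions: 2-dimensional token vectors and hand-picked weights stand in for the model's 768-dimensional outputs and learned parameters, and the dense layers are plain linear maps without an activation, which may differ from the real layers.

```python
def mean_pool(token_embeddings: list[list[float]], attention_mask: list[int]) -> list[float]:
    """Average the token vectors, skipping padding positions (mask == 0)."""
    dim = len(token_embeddings[0])
    total = [0.0] * dim
    count = 0
    for vec, keep in zip(token_embeddings, attention_mask):
        if keep:
            total = [t + v for t, v in zip(total, vec)]
            count += 1
    if count == 0:
        raise ValueError("attention_mask has no non-padding tokens")
    return [t / count for t in total]


def dense(vector: list[float], weights: list[list[float]], bias: list[float]) -> list[float]:
    """Toy linear layer: one output value per row of `weights` (no activation)."""
    return [sum(w * x for w, x in zip(row, vector)) + b for row, b in zip(weights, bias)]


tokens = [[1.0, 2.0], [3.0, 4.0], [0.0, 0.0]]  # the last token is padding
mask = [1, 1, 0]
pooled = mean_pool(tokens, mask)                                               # token -> text embedding
hidden = dense(pooled, [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]], [0.0, 0.0, 0.0])  # first dense layer: 2 -> 3
output = dense(hidden, [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]], [0.0, 0.0])         # second dense layer: 3 -> 2
print(pooled, hidden, output)  # [2.0, 3.0] [2.0, 3.0, 5.0] [2.0, 3.0]
```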

The EmbeddingGemma model has been trained with Matryoshka Representation Learning (MRL), allowing you to truncate the 768‑dimensional output to 512, 256, or 128 dimensions on demand. This results in faster downstream processing and lower memory and disk space usage. See the Sentence Transformers usage for a snippet showing how to perform this truncation.
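Matryoshka truncation amounts to keeping the leading values of the vector. A minimal sketch, assuming you also want unit-length vectors afterwards (cosine similarity is unaffected by the rescaling, but a plain dot product is not); the 3-dimensional input is a stand-in for a 768-dimensional embedding.

```python
import math


def truncate_embedding(embedding: list[float], dim: int) -> list[float]:
    """Keep the first `dim` values (the Matryoshka prefix) and rescale to unit length."""
    if not 0 < dim <= len(embedding):
        raise ValueError(f"dim must be in 1..{len(embedding)}, got {dim}")
    prefix = embedding[:dim]
    norm = math.sqrt(sum(x * x for x in prefix))
    return [x / norm for x in prefix]


full = [3.0, 4.0, 12.0]  # toy stand-in for a 768-d embedding
print(truncate_embedding(full, 2))  # [0.6, 0.8]
```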

The model has been trained on a carefully curated, multilingual corpus totalling roughly 320 billion tokens. The proprietary dataset is a mix of publicly available web text, code and technical documentation, and synthetic task‑specific examples. It has been filtered to avoid Child Sexual Abuse Material (CSAM), sensitive data, and low-quality or unsafe content.



Evaluation

EmbeddingGemma was benchmarked on the MMTEB (Multilingual, v2) and MTEB (English, v2) suites, which span a wide range of tasks, domains, and languages. Despite its modest 308M‑parameter size, the model consistently beats comparable baselines while keeping a very small memory footprint.

Figures: MTEB (Multilingual, v2) performance; MTEB (English, v2) performance

The results will be listed on the official MTEB Leaderboard. We exclude any model that has been trained on more than 20% of the MTEB data, to mitigate potential overfitting.



Demo




Experience the demo yourself on a Desktop device.



Usage

EmbeddingGemma is integrated with many popular tools, making it easy to incorporate into your existing workflows and applications. The model has been integrated in Sentence Transformers, and thus also in projects that use Sentence Transformers behind the scenes, such as LangChain, LlamaIndex, Haystack, and txtai. See the examples below to get started with your preferred framework.

For production deployments, you can use Text Embeddings Inference (TEI) to serve the model efficiently on various hardware configurations, and you can use Transformers.js in web applications.

Regardless of your framework choice, you should be mindful of the prompts. For embedding models, prompts are prepended to the input text to allow the model to distinguish between different tasks. EmbeddingGemma was trained with these prompt names and prompts, so they should also be included when using the model:

  • query: "task: search result | query: ",
  • document: "title: none | text: ",
  • BitextMining: "task: search result | query: ",
  • Clustering: "task: clustering | query: ",
  • Classification: "task: classification | query: ",
  • InstructionRetrieval: "task: code retrieval | query: ",
  • MultilabelClassification: "task: classification | query: ",
  • PairClassification: "task: sentence similarity | query: ",
  • Reranking: "task: search result | query: ",
  • Retrieval-query: "task: search result | query: ",
  • Retrieval-document: "title: none | text: ",
  • STS: "task: sentence similarity | query: ",
  • Summarization: "task: summarization | query: "

In Sentence Transformers, the query and document prompts are used automatically when calling model.encode_query and model.encode_document, but for other frameworks you might need to:

  1. specify prompt names (e.g. “Reranking”),
  2. specify prompt strings (e.g. “task: search result | query: ”), or
  3. manually prepend the prompts to your input text.
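Option 3 is a simple string operation. The sketch below copies a subset of the prompts listed above; the helper name with_prompt is hypothetical and not part of any library.

```python
# A subset of the EmbeddingGemma prompts listed above, copied verbatim.
PROMPTS = {
    "query": "task: search result | query: ",
    "document": "title: none | text: ",
    "Clustering": "task: clustering | query: ",
    "STS": "task: sentence similarity | query: ",
}


def with_prompt(texts: list[str], prompt_name: str) -> list[str]:
    """Hypothetical helper: prepend the named prompt to every input text."""
    prefix = PROMPTS[prompt_name]
    return [prefix + text for text in texts]


print(with_prompt(["Which planet is known as the Red Planet?"], "query"))
print(with_prompt(["Mars is often referred to as the Red Planet."], "document"))
```

The prefixed strings can then be passed to any embedding call that does not apply prompts by itself.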

The following example scripts demonstrate this with various frameworks.



Sentence Transformers

You will need to install the following packages:

pip install git+https://github.com/huggingface/transformers@v4.56.0-Embedding-Gemma-preview
pip install "sentence-transformers>=5.0.0"



Retrieval

Inference using Sentence Transformers is rather simple; see this example for semantic search:

from sentence_transformers import SentenceTransformer

# Download from the Hugging Face Hub
model = SentenceTransformer("google/embeddinggemma-300m")

# Run inference; encode_query and encode_document apply the "query" and "document" prompts
query = "Which planet is often called the Red Planet?"
documents = [
    "Venus is often called Earth's twin because of its similar size and proximity.",
    "Mars, known for its reddish appearance, is often referred to as the Red Planet.",
    "Jupiter, the largest planet in our solar system, has a prominent red spot.",
    "Saturn, famous for its rings, is sometimes mistaken for the Red Planet."
]
query_embeddings = model.encode_query(query)
document_embeddings = model.encode_document(documents)
print(query_embeddings.shape, document_embeddings.shape)
# (768,) (4, 768)

# Compute similarities to determine a ranking
similarities = model.similarity(query_embeddings, document_embeddings)
print(similarities)

# Convert similarities to a ranking: document indices, most similar first
ranking = similarities.argsort(descending=True)[0]
print(ranking)

Non-retrieval usage

If you're not looking to use this model for Information Retrieval, then you're likely best off using the most general encode method together with the model prompt that best describes your downstream task out of these options:

  • BitextMining: Find translated sentence pairs in two languages.
  • Clustering: Find similar texts to group them together.
  • Classification: Assign predefined labels to texts.
  • InstructionRetrieval: Retrieve relevant code snippets based on natural language instructions.
  • MultilabelClassification: Assign multiple labels to texts.
  • PairClassification: Assign predefined labels to pairs of texts.
  • Reranking: Reorder search results based on relevance.
  • Retrieval-query: Retrieve documents based on a question.
  • Retrieval-document: Retrieve documents based on their content.
  • STS: Compute semantic textual similarity between texts.
  • Summarization: Generate concise summaries of texts.

from sentence_transformers import SentenceTransformer

# Download from the Hugging Face Hub
model = SentenceTransformer("google/embeddinggemma-300m")

# Inspect the available prompt names and their prompt strings
print(model.prompts)

# Encode texts with the prompt that best matches the task, here semantic textual similarity (STS)
texts = [
    "The weather is beautiful today.",
    "It's a lovely day outside.",
    "The stock market crashed yesterday.",
    "I enjoy programming with Python."
]
embeddings = model.encode(texts, prompt_name="STS")
print(embeddings.shape)
# (4, 768)

# Compute pairwise similarities between all texts
similarities = model.similarity(embeddings, embeddings)
print(similarities)
"""
tensor([[1.0000, 0.9305, 0.4660, 0.4326],
        [0.9305, 1.0000, 0.4227, 0.4434],
        [0.4660, 0.4227, 1.0000, 0.2638],
        [0.4326, 0.4434, 0.2638, 1.0000]])
"""
Truncating embedding dimensionality for faster and cheaper search

Because google/embeddinggemma-300m was trained with MRL, the embeddings generated by this model can be truncated to lower dimensionalities without considerably hurting the evaluation performance. Embeddings with lower dimensionalities are both cheaper to store on disk and in memory and faster for downstream tasks like retrieval, clustering, or classification.

In Sentence Transformers, you can set a lower dimensionality using the truncate_dim parameter, either when initializing the SentenceTransformer or when calling model.encode/model.encode_query/model.encode_document:

from sentence_transformers import SentenceTransformer

# Download from the Hugging Face Hub, truncating all embeddings to 256 dimensions
model = SentenceTransformer("google/embeddinggemma-300m", truncate_dim=256)

# Run inference with a query and documents
query = "Which planet is often called the Red Planet?"
documents = [
    "Venus is often called Earth's twin because of its similar size and proximity.",
    "Mars, known for its reddish appearance, is often referred to as the Red Planet.",
    "Jupiter, the largest planet in our solar system, has a prominent red spot.",
    "Saturn, famous for its rings, is sometimes mistaken for the Red Planet."
]
query_embeddings = model.encode_query(query)
document_embeddings = model.encode_document(documents)
print(query_embeddings.shape, document_embeddings.shape)
# (256,) (4, 256)

# Compute similarities to determine a ranking
similarities = model.similarity(query_embeddings, document_embeddings)
print(similarities)

# Convert similarities to a ranking: document indices, most similar first
ranking = similarities.argsort(descending=True)[0]
print(ranking)

Note that the ranking is preserved despite using 3x smaller embeddings compared to the full-sized embeddings.
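The 3x figure follows directly from the vector sizes. A back-of-the-envelope sketch, assuming float32 storage (4 bytes per value), a hypothetical corpus of one million documents, and no index overhead:

```python
BYTES_PER_FLOAT32 = 4


def embeddings_size_bytes(num_docs: int, dim: int, bytes_per_value: int = BYTES_PER_FLOAT32) -> int:
    """Raw storage for `num_docs` dense vectors of size `dim` (no index overhead)."""
    return num_docs * dim * bytes_per_value


num_docs = 1_000_000  # hypothetical corpus size
full = embeddings_size_bytes(num_docs, 768)
truncated = embeddings_size_bytes(num_docs, 256)
print(f"768-d: {full / 1e9:.3f} GB, 256-d: {truncated / 1e9:.3f} GB, ratio: {full / truncated:.1f}x")
```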



LangChain

If you prefer, you can also use the LangChain HuggingFaceEmbeddings, which uses Sentence Transformers behind the scenes. Note that you'll need to tell LangChain to use the prompts called “query” and “document” for queries and documents, respectively. This example involves a simple information retrieval setup, but the same embedding model can be used in more complex scenarios.

You will need to install the following packages:

pip install git+https://github.com/huggingface/transformers@v4.56.0-Embedding-Gemma-preview
pip install sentence-transformers
pip install langchain
pip install langchain-community
pip install langchain-huggingface
pip install faiss-cpu

from langchain.docstore.document import Document
from langchain_community.vectorstores import FAISS
from langchain_huggingface.embeddings import HuggingFaceEmbeddings

# Load the embedding model, using the "query" prompt for queries and the "document" prompt for documents
embedder = HuggingFaceEmbeddings(
    model_name="google/embeddinggemma-300m",
    query_encode_kwargs={"prompt_name": "query"},
    encode_kwargs={"prompt_name": "document"}
)

data = [
    "Venus is often called Earth's twin because of its similar size and proximity.",
    "Mars, known for its reddish appearance, is often referred to as the Red Planet.",
    "Jupiter, the largest planet in our solar system, has a prominent red spot.",
    "Saturn, famous for its rings, is sometimes mistaken for the Red Planet."
]

# Wrap each text in a LangChain Document, keeping its position as metadata
documents = [Document(page_content=text, metadata={"id": i}) for i, text in enumerate(data)]

# Build a FAISS vector store that uses the inner product as its similarity metric
vector_store = FAISS.from_documents(documents, embedder, distance_strategy="MAX_INNER_PRODUCT")

# Search for the 3 documents most similar to the query
query = "Which planet is often called the Red Planet?"
results = vector_store.similarity_search_with_score(query, k=3)

# Print the results
for doc, score in results:
    print(f"Text: {doc.page_content} (score: {score:.4f})")
"""
Text: Mars, known for its reddish appearance, is often referred to as the Red Planet. (score: 0.6359)
Text: Jupiter, the largest planet in our solar system, has a prominent red spot. (score: 0.4930)
Text: Saturn, famous for its rings, is sometimes mistaken for the Red Planet. (score: 0.4889)
"""



LlamaIndex

EmbeddingGemma is also supported in LlamaIndex, as it uses Sentence Transformers under the hood. For correct behaviour, you need to specify the query and document prompts as defined in the model configuration; otherwise, performance will be suboptimal. This script shows a rudimentary example of using EmbeddingGemma with LlamaIndex, but you can also use the HuggingFaceEmbedding class in more complex settings.

You will need to install the following packages:

pip install git+https://github.com/huggingface/transformers@v4.56.0-Embedding-Gemma-preview
pip install sentence-transformers
pip install llama-index
pip install llama-index-embeddings-huggingface
pip install llama-index-vector-stores-faiss

import faiss
from llama_index.core.schema import TextNode
from llama_index.core.vector_stores import VectorStoreQuery
from llama_index.embeddings.huggingface import HuggingFaceEmbedding
from llama_index.vector_stores.faiss import FaissVectorStore

# Load the embedding model with the EmbeddingGemma query and document prompts
embeddings = HuggingFaceEmbedding(
    model_name="google/embeddinggemma-300m",
    query_instruction="task: search result | query: ",
    text_instruction="title: none | text: ",
)

data = [
    "Venus is often called Earth's twin because of its similar size and proximity.",
    "Mars, known for its reddish appearance, is often referred to as the Red Planet.",
    "Jupiter, the largest planet in our solar system, has a prominent red spot.",
    "Saturn, famous for its rings, is sometimes mistaken for the Red Planet."
]

# Create a FAISS inner-product index for 768-dimensional vectors and add the embedded texts
store = FaissVectorStore(faiss_index=faiss.IndexFlatIP(768))
store.add([TextNode(id_=str(i), text=text, embedding=embeddings.get_text_embedding(text)) for i, text in enumerate(data)])

# Search for the 3 documents most similar to the query
query = "Which planet is often called the Red Planet?"
query_embedding = embeddings.get_query_embedding(query)
results = store.query(VectorStoreQuery(query_embedding=query_embedding, similarity_top_k=3))

# Print the results
for idx, score in zip(results.ids, results.similarities):
    print(f"Text: {data[int(idx)]} (score: {score:.4f})")
"""
Text: Mars, known for its reddish appearance, is often referred to as the Red Planet. (score: 0.6359)
Text: Jupiter, the largest planet in our solar system, has a prominent red spot. (score: 0.4930)
Text: Saturn, famous for its rings, is sometimes mistaken for the Red Planet. (score: 0.4889)
"""



Haystack

EmbeddingGemma can also be used with Haystack, a framework for building production-ready search and language applications. Like LangChain and LlamaIndex, Haystack uses Sentence Transformers behind the scenes and requires you to specify the appropriate prompts. The following example shows how to set up a basic retrieval pipeline using EmbeddingGemma with Haystack.

You will need to install the following packages:

pip install git+https://github.com/huggingface/transformers@v4.56.0-Embedding-Gemma-preview
pip install sentence-transformers
pip install haystack-ai

from haystack import Document, Pipeline
from haystack.components.embedders import SentenceTransformersDocumentEmbedder, SentenceTransformersTextEmbedder
from haystack.components.retrievers import InMemoryEmbeddingRetriever
from haystack.document_stores.in_memory import InMemoryDocumentStore

# Initialize the document store
document_store = InMemoryDocumentStore()

# Initialize the document and query embedders with the matching EmbeddingGemma prompts
document_embedder = SentenceTransformersDocumentEmbedder(
    model="google/embeddinggemma-300m", encode_kwargs={"prompt_name": "document"}
)
query_embedder = SentenceTransformersTextEmbedder(
    model="google/embeddinggemma-300m", encode_kwargs={"prompt_name": "query"}
)
document_embedder.warm_up()
query_embedder.warm_up()

data = [
    "Venus is often called Earth's twin because of its similar size and proximity.",
    "Mars, known for its reddish appearance, is often referred to as the Red Planet.",
    "Jupiter, the largest planet in our solar system, has a prominent red spot.",
    "Saturn, famous for its rings, is sometimes mistaken for the Red Planet.",
]

# Convert the texts to Haystack Documents, embed them, and write them to the store
documents = [Document(content=text, id=str(i)) for i, text in enumerate(data)]
documents_with_embeddings = document_embedder.run(documents=documents)["documents"]
document_store.write_documents(documents_with_embeddings)

# Build a query pipeline: embed the query, then retrieve the 3 most similar documents
query_pipeline = Pipeline()
query_pipeline.add_component("text_embedder", query_embedder)
query_pipeline.add_component("retriever", InMemoryEmbeddingRetriever(document_store=document_store, top_k=3))
query_pipeline.connect("text_embedder.embedding", "retriever.query_embedding")

# Run the pipeline for a query
query = "Which planet is often called the Red Planet?"
results = query_pipeline.run({"text_embedder": {"text": query}})

# Print the results
for document in results["retriever"]["documents"]:
    print(f"Text: {document.content} (score: {document.score:.4f})")
"""
Text: Mars, known for its reddish appearance, is often referred to as the Red Planet. (score: 0.6359)
Text: Jupiter, the largest planet in our solar system, has a prominent red spot. (score: 0.4930)
Text: Saturn, famous for its rings, is sometimes mistaken for the Red Planet. (score: 0.4889)
"""



txtai

txtai is also compatible with EmbeddingGemma. Like other frameworks, txtai uses Sentence Transformers under the hood and needs the appropriate prompts for optimal performance with EmbeddingGemma. The following example demonstrates how to set up a basic retrieval system with txtai.

You will need to install the following packages:

pip install git+https://github.com/huggingface/transformers@v4.56.0-Embedding-Gemma-preview
pip install sentence-transformers
pip install txtai

from txtai import Embeddings

# Load the embedding model via Sentence Transformers, with the query and document ("data") prompts
embeddings = Embeddings(
    path="google/embeddinggemma-300m",
    method="sentence-transformers",
    instructions={
        "query": "task: search result | query: ",
        "data": "title: none | text: ",
    }
)

data = [
    "Venus is often called Earth's twin because of its similar size and proximity.",
    "Mars, known for its reddish appearance, is often referred to as the Red Planet.",
    "Jupiter, the largest planet in our solar system, has a prominent red spot.",
    "Saturn, famous for its rings, is sometimes mistaken for the Red Planet."
]

# Embed and index the documents
embeddings.index(data)

# Search for the 3 documents most similar to the query
query = "Which planet is often called the Red Planet?"
results = embeddings.search(query, 3)

# Print the results
for idx, score in results:
    print(f"Text: {data[int(idx)]} (score: {score:.4f})")
"""
Text: Mars, known for its reddish appearance, is often referred to as the Red Planet. (score: 0.6359)
Text: Jupiter, the largest planet in our solar system, has a prominent red spot. (score: 0.4930)
Text: Saturn, famous for its rings, is sometimes mistaken for the Red Planet. (score: 0.4889)
"""



Transformers.js

You can even run EmbeddingGemma 100% locally in your browser with Transformers.js! If you haven't already, you can install the library from NPM using:

npm i @huggingface/transformers

You can then compute embeddings as follows:

import { AutoModel, AutoTokenizer, matmul } from "@huggingface/transformers";

// Download from the Hugging Face Hub
const model_id = "onnx-community/embeddinggemma-300m-ONNX";
const tokenizer = await AutoTokenizer.from_pretrained(model_id);
const model = await AutoModel.from_pretrained(model_id, {
  dtype: "fp32",
});

// Prompts for queries and documents
const prefixes = {
  query: "task: search result | query: ",
  document: "title: none | text: ",
};
const query = prefixes.query + "Which planet is often called the Red Planet?";
const documents = [
  "Venus is often called Earth's twin because of its similar size and proximity.",
  "Mars, known for its reddish appearance, is often referred to as the Red Planet.",
  "Jupiter, the largest planet in our solar system, has a prominent red spot.",
  "Saturn, famous for its rings, is sometimes mistaken for the Red Planet.",
].map((x) => prefixes.document + x);

const inputs = await tokenizer([query, ...documents], { padding: true });
const { sentence_embedding } = await model(inputs);

// Compute similarities; row 0 holds the query, so drop its self-similarity
const scores = await matmul(sentence_embedding, sentence_embedding.transpose(1, 0));
const similarities = scores.tolist()[0].slice(1);
console.log(similarities);

// Convert similarities to a ranking (highest score first)
const ranking = similarities.map((score, index) => ({ index, score })).sort((a, b) => b.score - a.score);
console.log(ranking);


Text Embeddings Inference

You can easily deploy EmbeddingGemma for both development and production using Text Embeddings Inference (TEI) version 1.8.1 or later.

docker run -p 8080:80 ghcr.io/huggingface/text-embeddings-inference:cpu-1.8.1 --model-id google/embeddinggemma-300m --dtype float32
docker run -p 8080:80 ghcr.io/huggingface/text-embeddings-inference:cpu-1.8.1 --model-id onnx-community/embeddinggemma-300m-ONNX --dtype float32 --pooling mean
docker run --gpus all --shm-size 1g -p 8080:80 ghcr.io/huggingface/text-embeddings-inference:cuda-1.8.1 --model-id google/embeddinggemma-300m --dtype float32

If you run the Docker container with the cuda-1.8.1 tag, it includes support for multiple GPU architectures: Turing, Ampere, Ada Lovelace, and Hopper. For a lighter image tailored to just your GPU, you can instead use a specific tag such as turing-1.8.1, 1.8.1 and 86-1.8.1 (Ampere), 89-1.8.1 (Ada Lovelace), or hopper-1.8.1.

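Once deployed, regardless of the device or runtime, you can leverage the /v1/embeddings endpoint, which follows the OpenAI Embeddings API Specification, to generate embeddings. Here, the prompt has to be prepended to each input manually.

curl http://0.0.0.0:8080/v1/embeddings -H "Content-Type: application/json" -d '{"input":["task: search result | query: Which planet is known as the Red Planet?","task: search result | query: Where did Amelia Earhart first fly?"]}'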

Alternatively, you can also leverage the /embed endpoint from the Text Embeddings Inference Embeddings API, which supports the prompt_name parameter, so there's no need to manually prepend the prompt to the inputs; you can select it via prompt_name instead.

curl http://0.0.0.0:8080/embed -H "Content-Type: application/json" -d '{"inputs":["Which planet is known as the Red Planet?","Where did Amelia Earhart first fly?"],"prompt_name":"query","normalize":true}'

Moreover, since google/embeddinggemma-300m was trained with Matryoshka Representation Learning (MRL), you can also leverage the dimensions parameter, on both /v1/embeddings and /embed, to truncate the embeddings to lower dimensionalities (512, 256, and 128) without considerably hurting the evaluation performance.
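For calling the server from Python instead of curl, the requests can be assembled with the standard library. This is a sketch under assumptions: it reuses the local address and JSON fields from the curl examples above, adds the dimensions field described in this paragraph, and only builds the requests. Sending them (for example with urllib.request.urlopen) requires the container to be running, and the response format is not covered here. The helper names are hypothetical.

```python
import json
import urllib.request
from typing import Optional

BASE_URL = "http://0.0.0.0:8080"  # address used in the docker and curl examples above
QUERY_PROMPT = "task: search result | query: "


def _post(path: str, body: dict) -> urllib.request.Request:
    """Build (but do not send) a JSON POST request."""
    return urllib.request.Request(
        BASE_URL + path,
        data=json.dumps(body).encode("utf-8"),
        headers={"Content-Type": "application/json"},
        method="POST",
    )


def openai_embeddings_request(texts: list, dimensions: Optional[int] = None) -> urllib.request.Request:
    """/v1/embeddings: the prompt must be prepended manually."""
    body = {"input": [QUERY_PROMPT + text for text in texts]}
    if dimensions is not None:
        body["dimensions"] = dimensions
    return _post("/v1/embeddings", body)


def tei_embed_request(texts: list, dimensions: Optional[int] = None) -> urllib.request.Request:
    """/embed: the server applies the prompt selected via prompt_name."""
    body = {"inputs": texts, "prompt_name": "query", "normalize": True}
    if dimensions is not None:
        body["dimensions"] = dimensions
    return _post("/embed", body)


request = tei_embed_request(["Which planet is known as the Red Planet?"], dimensions=256)
print(request.full_url, request.data.decode("utf-8"))
# To send it: with urllib.request.urlopen(request) as response: print(json.load(response))
```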



ONNX Runtime

You can also run the model directly with ONNX Runtime, making it highly portable and cross-platform compatible. The example below shows usage in Python, but the same approach can be applied in other languages (Java, C#, C++, etc.) as well.

from huggingface_hub import hf_hub_download
import onnxruntime as ort
from transformers import AutoTokenizer

# Download the ONNX graph and its external weights file from the Hugging Face Hub
model_id = "onnx-community/embeddinggemma-300m-ONNX"
model_path = hf_hub_download(model_id, subfolder="onnx", filename="model.onnx")
hf_hub_download(model_id, subfolder="onnx", filename="model.onnx_data")
session = ort.InferenceSession(model_path)
tokenizer = AutoTokenizer.from_pretrained(model_id)

# Prompts for queries and documents
prefixes = {
    "query": "task: search result | query: ",
    "document": "title: none | text: ",
}

query = prefixes["query"] + "Which planet is often called the Red Planet?"
documents = [
    "Venus is often called Earth's twin because of its similar size and proximity.",
    "Mars, known for its reddish appearance, is often referred to as the Red Planet.",
    "Jupiter, the largest planet in our solar system, has a prominent red spot.",
    "Saturn, famous for its rings, is sometimes mistaken for the Red Planet."
]
documents = [prefixes["document"] + x for x in documents]

inputs = tokenizer([query] + documents, padding=True, return_tensors="np")

# Only the second output, the sentence embeddings, is needed here
_, sentence_embedding = session.run(None, inputs.data)
print(sentence_embedding.shape)  # (5, 768): 1 query + 4 documents

# Compute similarities between the query and each document
query_embeddings = sentence_embedding[0]
document_embeddings = sentence_embedding[1:]
similarities = query_embeddings @ document_embeddings.T
print(similarities)

# Convert similarities to a ranking: document indices, most similar first
ranking = similarities.argsort()[::-1]
print(ranking)



Finetuning

As with all models compatible with the Sentence Transformers library, EmbeddingGemma can easily be fine-tuned on your specific dataset. To showcase this, we'll be finetuning google/embeddinggemma-300m on the Medical Instruction and RetrIeval Dataset (MIRIAD), such that our finetuned model becomes particularly adept at finding passages of up to 1000 tokens from scientific medical papers given detailed medical questions. These passages can be used as crucial context for a generative model to answer questions more effectively.

Below, we go through each key component of the finetuning process. Each part contains the relevant code and a detailed explanation.

Model
from sentence_transformers import SentenceTransformer, SentenceTransformerModelCardData

# Load the model to finetune, with optional metadata for the automatically generated model card
model = SentenceTransformer(
    "google/embeddinggemma-300m",
    model_card_data=SentenceTransformerModelCardData(
        language="en",
        license="apache-2.0",
        model_name="EmbeddingGemma-300m trained on the Medical Instruction and RetrIeval Dataset (MIRIAD)",
    ),
)

This code loads the EmbeddingGemma model from Hugging Face, with optional model card metadata for documentation and sharing. The SentenceTransformer class loads the model weights and configuration, while the model_card_data argument attaches metadata to be included in the automatically generated model card.

Dataset
from datasets import load_dataset

# Load the train, eval, and test splits and take subsets for faster experimentation
train_dataset = load_dataset("tomaarsen/miriad-4.4M-split", split="train").select(range(100_000))
eval_dataset = load_dataset("tomaarsen/miriad-4.4M-split", split="eval").select(range(1_000))
test_dataset = load_dataset("tomaarsen/miriad-4.4M-split", split="test").select(range(1_000))

This code loads the MIRIAD dataset, or rather, a copy that has been divided into train, eval, and test splits. Using a large, high-quality dataset helps the model learn meaningful representations, while subsetting allows for faster experimentation. The load_dataset function fetches the dataset from Hugging Face Datasets, and the .select() method limits the number of samples for each split.

Loss Function
from sentence_transformers.losses import CachedMultipleNegativesRankingLoss

loss = CachedMultipleNegativesRankingLoss(model, mini_batch_size=8)

This code defines the loss function for training, using Cached Multiple Negatives Ranking Loss (CMNRL). CMNRL is effective for retrieval tasks, as it uses in-batch negatives to efficiently train the model to distinguish between correct and incorrect pairs. The loss takes question-passage pairs and treats the other passages in the batch as negatives, maximizing the distance between unrelated pairs in the embedding space. The mini_batch_size parameter controls memory usage, but doesn't affect the training dynamics.
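The in-batch negatives objective can be written out in a few lines. This is a plain-Python sketch of the idea behind (Cached)MultipleNegativesRankingLoss, not the library's code: each question is scored against every passage in the batch with scaled cosine similarity, and cross-entropy is applied with the matching passage as the target. The scale of 20.0 is an assumption, and the caching that lets mini_batch_size reduce memory is not modeled.

```python
import math


def cosine(a: list[float], b: list[float]) -> float:
    dot = sum(x * y for x, y in zip(a, b))
    return dot / (math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b)))


def in_batch_negatives_loss(questions, passages, scale: float = 20.0) -> float:
    """Mean cross-entropy: passage i is the positive for question i,
    and all other passages in the batch serve as negatives."""
    total = 0.0
    for i, q in enumerate(questions):
        logits = [scale * cosine(q, p) for p in passages]
        m = max(logits)  # subtract the max for a numerically stable log-sum-exp
        log_sum_exp = m + math.log(sum(math.exp(z - m) for z in logits))
        total += log_sum_exp - logits[i]
    return total / len(questions)


questions = [[1.0, 0.0], [0.0, 1.0]]
aligned = [[0.9, 0.1], [0.1, 0.9]]  # each passage points toward its own question
swapped = [[0.1, 0.9], [0.9, 0.1]]  # each passage points toward the other question
print(in_batch_negatives_loss(questions, aligned), in_batch_negatives_loss(questions, swapped))
```

Because the loss only drops when each question is closer to its own passage than to the others, a duplicate of the positive passage elsewhere in the batch would wrongly be pushed away as a negative.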

It's recommended to use this loss with a large per_device_train_batch_size in SentenceTransformerTrainingArguments and a low mini_batch_size in CachedMultipleNegativesRankingLoss for a strong training signal with low memory usage. Additionally, the NO_DUPLICATES batch sampler is recommended to avoid accidental false negatives.
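The idea behind the NO_DUPLICATES sampler can be illustrated with a greedy batching sketch: a sample is postponed to a later batch if any of its texts already appears in the current one. The function below is a hypothetical simplification, not the Sentence Transformers implementation, whose details (such as ordering and the handling of leftover samples) may differ.

```python
def no_duplicate_batches(rows: list[dict], batch_size: int) -> list[list[dict]]:
    """Greedily fill batches so that no text value occurs twice within a batch.

    Rows that would introduce a duplicate are deferred to a later batch.
    """
    remaining = list(rows)
    batches = []
    while remaining:
        batch, seen, deferred = [], set(), []
        for row in remaining:
            texts = set(row.values())
            if len(batch) < batch_size and not (texts & seen):
                batch.append(row)
                seen |= texts
            else:
                deferred.append(row)
        batches.append(batch)
        remaining = deferred
    return batches


rows = [
    {"question": "q1", "passage_text": "shared passage"},
    {"question": "q2", "passage_text": "shared passage"},
    {"question": "q3", "passage_text": "another passage"},
]
for batch in no_duplicate_batches(rows, batch_size=2):
    print([row["question"] for row in batch])  # ['q1', 'q3'], then ['q2']
```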

Training Arguments
from sentence_transformers.training_args import BatchSamplers
from sentence_transformers import SentenceTransformerTrainingArguments

run_name = "embeddinggemma-300m-medical-100k"
args = SentenceTransformerTrainingArguments(
    output_dir=f"models/{run_name}",
    num_train_epochs=1,
    per_device_train_batch_size=128,
    per_device_eval_batch_size=128,
    learning_rate=2e-5,
    warmup_ratio=0.1,
    fp16=True,  # Set to False if your GPU does not support FP16
    bf16=False,  # Set to True if your GPU supports BF16
    batch_sampler=BatchSamplers.NO_DUPLICATES,  # Avoid duplicate texts in a batch (false negatives)
    prompts={  # Map training dataset column names to model prompts
        "question": model.prompts["query"],
        "passage_text": model.prompts["document"],
    },
    eval_strategy="steps",
    eval_steps=100,
    save_strategy="steps",
    save_steps=100,
    save_total_limit=2,
    logging_steps=20,
    run_name=run_name,
)

This code sets up all hyperparameters and configuration for training, evaluation, and logging. Proper training arguments are crucial for efficient, stable, and reproducible training. The arguments control batch sizes, learning rate, mixed precision, evaluation and saving frequency, and more. Notably, the prompts dictionary maps dataset column names to the prompts used by the model to distinguish queries from documents.
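To see what these arguments imply for the 100k-sample subset, the step counts and learning-rate schedule can be worked out by hand. This sketch assumes a single GPU, no gradient accumulation, batches of exactly 128 samples (the NO_DUPLICATES sampler may yield a slightly different count), and a linear warmup-then-decay schedule as in the Hugging Face Trainer default; treat the numbers as approximate.

```python
import math

num_samples = 100_000
batch_size = 128  # per_device_train_batch_size on a single device (assumption)
warmup_ratio = 0.1
peak_lr = 2e-5
eval_steps = 100

total_steps = math.ceil(num_samples / batch_size)     # optimizer steps in one epoch
warmup_steps = math.ceil(total_steps * warmup_ratio)  # steps spent ramping up the learning rate
num_evaluations = total_steps // eval_steps           # evaluations (and checkpoints) during the epoch


def linear_schedule_lr(step: int) -> float:
    """Linear warmup from 0 to peak_lr, then linear decay back to 0."""
    if step < warmup_steps:
        return peak_lr * step / warmup_steps
    return peak_lr * max(0.0, (total_steps - step) / (total_steps - warmup_steps))


print(total_steps, warmup_steps, num_evaluations)  # 782 79 7
print(linear_schedule_lr(0), linear_schedule_lr(warmup_steps), linear_schedule_lr(total_steps))
```

Under these assumptions, the learning rate peaks at 2e-5 after about 79 steps, and evaluation plus checkpointing run roughly 7 times during the single epoch.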

Evaluator
from sentence_transformers.evaluation import InformationRetrievalEvaluator

queries = dict(enumerate(eval_dataset["question"]))
corpus = dict(enumerate(eval_dataset["passage_text"] + train_dataset["passage_text"][:30_000]))
relevant_docs = {idx: [idx] for idx in queries}
dev_evaluator = InformationRetrievalEvaluator(
    queries=queries,
    corpus=corpus,
    relevant_docs=relevant_docs,
    name="miriad-eval-1kq-31kd",
    show_progress_bar=True,
)
dev_evaluator(model)

This code sets up an evaluator for information retrieval, using queries and a corpus to measure model performance. Evaluation during training helps monitor progress and avoid overfitting. The evaluator computes retrieval metrics (NDCG, MRR, Recall, Precision, MAP, etc.) by checking whether the model retrieves the correct passages for each query. It can be run before, during, and after training, and the results will be logged and incorporated into the automatically generated model card.

Note that this snippet specifically evaluates all (1k) evaluation questions against a corpus of all (1k) evaluation passages plus 30k training passages, for a total of 31k documents. Evaluating only against the evaluation passages would be too easy for the model.
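Because the corpus lists the 1k evaluation passages first, evaluation question i has exactly one relevant document, corpus id i, which is what relevant_docs = {idx: [idx] for idx in queries} encodes. The sketch below shows how two of the reported metrics, Recall@k and MRR@k, follow from ranked results under that setup. It is a simplified illustration with made-up rankings, not the InformationRetrievalEvaluator implementation.

```python
def recall_and_mrr_at_k(rankings: dict, relevant_docs: dict, k: int = 10) -> tuple:
    """rankings: query id -> corpus ids sorted by decreasing similarity.
    relevant_docs: query id -> relevant corpus ids."""
    recall_total, reciprocal_rank_total = 0.0, 0.0
    for qid, ranked_ids in rankings.items():
        relevant = set(relevant_docs[qid])
        top_k = ranked_ids[:k]
        recall_total += sum(doc in relevant for doc in top_k) / len(relevant)
        for rank, doc in enumerate(top_k, start=1):
            if doc in relevant:
                reciprocal_rank_total += 1.0 / rank  # only the first relevant hit counts
                break
    return recall_total / len(rankings), reciprocal_rank_total / len(rankings)


# Made-up rankings for 3 queries, with relevant_docs built the same way as above.
rankings = {0: [0, 17, 402], 1: [250, 1, 9], 2: [31, 8, 700]}
relevant_docs = {idx: [idx] for idx in rankings}
print(recall_and_mrr_at_k(rankings, relevant_docs, k=3))  # (0.6666666666666666, 0.5)
```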

Trainer
from sentence_transformers import SentenceTransformerTrainer

trainer = SentenceTransformerTrainer(
    model=model,
    args=args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    loss=loss,
    evaluator=dev_evaluator,
)
trainer.train()

This code initializes and runs the training loop, coordinating all components.



Full Finetuning Script

Below is the entire script, combining all components above:

import logging
import traceback

from datasets import load_dataset
from sentence_transformers import (
    SentenceTransformer,
    SentenceTransformerModelCardData,
    SentenceTransformerTrainer,
    SentenceTransformerTrainingArguments,
)
from sentence_transformers.evaluation import InformationRetrievalEvaluator
from sentence_transformers.losses import CachedMultipleNegativesRankingLoss
from sentence_transformers.training_args import BatchSamplers


logging.basicConfig(format="%(asctime)s - %(message)s", datefmt="%Y-%m-%d %H:%M:%S", level=logging.INFO)


model = SentenceTransformer(
    "google/embeddinggemma-300m",
    model_card_data=SentenceTransformerModelCardData(
        language="en",
        license="apache-2.0",
        model_name="EmbeddingGemma-300m trained on the Medical Instruction and RetrIeval Dataset (MIRIAD)",
    ),
)

# Load the MIRIAD splits: 100k pairs for training, plus 1k pairs each for evaluation and testing
train_dataset = load_dataset("tomaarsen/miriad-4.4M-split", split="train").select(range(100_000))
eval_dataset = load_dataset("tomaarsen/miriad-4.4M-split", split="eval").select(range(1_000))
test_dataset = load_dataset("tomaarsen/miriad-4.4M-split", split="test").select(range(1_000))
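
# CachedMultipleNegativesRankingLoss treats the other passages in each batch as negatives.
# mini_batch_size only controls how many samples are embedded at once: lower values use
# less memory but train more slowly, without changing the effective batch size.
loss = CachedMultipleNegativesRankingLoss(model, mini_batch_size=8)
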
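# Training arguments
run_name = "embeddinggemma-300m-medical-100k"
args = SentenceTransformerTrainingArguments(
    # Required parameter:
    output_dir=f"models/{run_name}",
    # Optional training parameters:
    num_train_epochs=1,
    per_device_train_batch_size=128,
    per_device_eval_batch_size=128,
    learning_rate=2e-5,
    warmup_ratio=0.1,
    fp16=True,  # Set to False if your GPU can't run FP16
    bf16=False,  # Set to True if your GPU supports BF16
    batch_sampler=BatchSamplers.NO_DUPLICATES,  # Avoid duplicate texts in a batch, which would act as false negatives
    prompts={  # Map training column names to the model's built-in prompts
        "question": model.prompts["query"],
        "passage_text": model.prompts["document"],
    },
    # Optional tracking/debugging parameters:
    eval_strategy="steps",
    eval_steps=100,
    save_strategy="steps",
    save_steps=100,
    save_total_limit=2,
    logging_steps=20,
    run_name=run_name,  # Used by experiment trackers such as W&B, if installed
)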

# Evaluator: 1k eval questions against their 1k passages plus 30k training passages as distractors
queries = dict(enumerate(eval_dataset["question"]))
corpus = dict(enumerate(eval_dataset["passage_text"] + train_dataset["passage_text"][:30_000]))
relevant_docs = {idx: [idx] for idx in queries}
dev_evaluator = InformationRetrievalEvaluator(
    queries=queries,
    corpus=corpus,
    relevant_docs=relevant_docs,
    name="miriad-eval-1kq-31kd",  # 1k queries, 31k documents
    show_progress_bar=True,
)
# Evaluate the base model before training
dev_evaluator(model)

# Create the trainer and train
trainer = SentenceTransformerTrainer(
    model=model,
    args=args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    loss=loss,
    evaluator=dev_evaluator,
)
trainer.train()

# Evaluate the finetuned model on the eval set, then on the held-out test set
dev_evaluator(model)

queries = dict(enumerate(test_dataset["question"]))
corpus = dict(enumerate(test_dataset["passage_text"] + train_dataset["passage_text"][:30_000]))
relevant_docs = {idx: [idx] for idx in queries}
test_evaluator = InformationRetrievalEvaluator(
    queries=queries,
    corpus=corpus,
    relevant_docs=relevant_docs,
    name="miriad-test-1kq-31kd",  # 1k queries, 31k documents
    show_progress_bar=True,
)
test_evaluator(model)

# Save the final model locally
final_output_dir = f"models/{run_name}/final"
model.save_pretrained(final_output_dir)

# (Optional) Push the model to the Hugging Face Hub; this requires being logged in
try:
    model.push_to_hub(run_name)
except Exception:
    logging.error(
        f"Error uploading model to the Hugging Face Hub:\n{traceback.format_exc()}To upload it manually, you can run "
        f"`huggingface-cli login`, followed by loading the model using `model = SentenceTransformer({final_output_dir!r})` "
        f"and saving it using `model.push_to_hub('{run_name}')`."
    )
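
The script pairs a training batch size of 128 with `mini_batch_size=8` in `CachedMultipleNegativesRankingLoss`. The NumPy sketch below illustrates, in simplified form, why that is possible: a hypothetical linear "encoder" stands in for EmbeddingGemma, and the in-batch negatives loss (softmax cross-entropy over scaled cosine similarities, where passage i is the positive for query i and every other passage in the batch is a negative) is computed once from embeddings produced in a single pass and once from embeddings produced in chunks. The scale of 20 mirrors the default of `MultipleNegativesRankingLoss`; everything else is an assumption for illustration. The real loss is based on GradCache and also chunks the backward pass, which this forward-only sketch omits.

```python
import numpy as np

rng = np.random.default_rng(0)
batch_size, dim_in, dim_out, mini_batch_size = 16, 32, 8, 4

# Hypothetical stand-in for the embedding model: a linear projection followed by L2 normalization.
W = rng.normal(size=(dim_in, dim_out))


def embed(x: np.ndarray) -> np.ndarray:
    e = x @ W
    return e / np.linalg.norm(e, axis=1, keepdims=True)


def embed_in_chunks(x: np.ndarray, chunk: int) -> np.ndarray:
    # Embed `chunk` rows at a time and concatenate; in a real model this lowers peak memory.
    return np.concatenate([embed(x[i:i + chunk]) for i in range(0, len(x), chunk)])


def in_batch_negatives_loss(q: np.ndarray, p: np.ndarray, scale: float = 20.0) -> float:
    # Cosine similarity of every query to every passage (embeddings are normalized).
    scores = scale * q @ p.T
    scores -= scores.max(axis=1, keepdims=True)  # numerical stability
    log_probs = scores - np.log(np.exp(scores).sum(axis=1, keepdims=True))
    # The positive for query i sits on the diagonal; all other columns are negatives.
    return float(-np.mean(np.diag(log_probs)))


queries = rng.normal(size=(batch_size, dim_in))
passages = rng.normal(size=(batch_size, dim_in))

full = in_batch_negatives_loss(embed(queries), embed(passages))
chunked = in_batch_negatives_loss(
    embed_in_chunks(queries, mini_batch_size), embed_in_chunks(passages, mini_batch_size)
)
print(full, chunked)
```

Both values agree up to floating-point precision: the chunk size changes how the embeddings are computed, not the number of in-batch negatives, which is set by the batch size.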



Training

We ran the full training script on an RTX 3090 with 24GB of VRAM, and the complete training and evaluation run took 5.5 hours. If needed, you can further reduce the memory footprint by lowering mini_batch_size on the CachedMultipleNegativesRankingLoss and batch_size on the InformationRetrievalEvaluator instances. Below are the logs from our training run; the rows with epoch and step -1 are the evaluations of the base model before training and of the final model after training:

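| Epoch | Step | Training Loss | Validation Loss | miriad-eval-1kq-31kd_cosine_ndcg@10 | miriad-test-1kq-31kd_cosine_ndcg@10 |
|---|---|---|---|---|---|
| -1 | -1 | - | - | 0.8474 | 0.8340 |
| 0.0256 | 20 | 0.1019 | - | - | - |
| 0.0512 | 40 | 0.0444 | - | - | - |
| 0.0767 | 60 | 0.0408 | - | - | - |
| 0.1023 | 80 | 0.0462 | - | - | - |
| 0.1279 | 100 | 0.0542 | 0.0525 | 0.8616 | - |
| 0.1535 | 120 | 0.0454 | - | - | - |
| 0.1790 | 140 | 0.0403 | - | - | - |
| 0.2046 | 160 | 0.0463 | - | - | - |
| 0.2302 | 180 | 0.0508 | - | - | - |
| 0.2558 | 200 | 0.0497 | 0.0449 | 0.8643 | - |
| 0.2813 | 220 | 0.0451 | - | - | - |
| 0.3069 | 240 | 0.0445 | - | - | - |
| 0.3325 | 260 | 0.0489 | - | - | - |
| 0.3581 | 280 | 0.0452 | - | - | - |
| 0.3836 | 300 | 0.0461 | 0.0406 | 0.8832 | - |
| 0.4092 | 320 | 0.0415 | - | - | - |
| 0.4348 | 340 | 0.04 | - | - | - |
| 0.4604 | 360 | 0.0399 | - | - | - |
| 0.4859 | 380 | 0.0423 | - | - | - |
| 0.5115 | 400 | 0.0352 | 0.0316 | 0.8823 | - |
| 0.5371 | 420 | 0.0408 | - | - | - |
| 0.5627 | 440 | 0.0356 | - | - | - |
| 0.5882 | 460 | 0.0371 | - | - | - |
| 0.6138 | 480 | 0.0276 | - | - | - |
| 0.6394 | 500 | 0.028 | 0.0280 | 0.8807 | - |
| 0.6650 | 520 | 0.0302 | - | - | - |
| 0.6905 | 540 | 0.0345 | - | - | - |
| 0.7161 | 560 | 0.0325 | - | - | - |
| 0.7417 | 580 | 0.033 | - | - | - |
| 0.7673 | 600 | 0.0314 | 0.0264 | 0.8910 | - |
| 0.7928 | 620 | 0.033 | - | - | - |
| 0.8184 | 640 | 0.029 | - | - | - |
| 0.8440 | 660 | 0.0396 | - | - | - |
| 0.8696 | 680 | 0.0266 | - | - | - |
| 0.8951 | 700 | 0.0262 | 0.0240 | 0.8968 | - |
| 0.9207 | 720 | 0.0262 | - | - | - |
| 0.9463 | 740 | 0.0327 | - | - | - |
| 0.9719 | 760 | 0.0293 | - | - | - |
| 0.9974 | 780 | 0.0304 | - | - | - |
| -1 | -1 | - | - | 0.9026 | 0.8862 |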
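
The Epoch and Step columns follow from the training arguments. As a quick check, assuming a single GPU and that the last partial batch is kept, 100k training pairs at a batch size of 128 give 782 steps per epoch, which is consistent with the logged epoch values. The warmup line further assumes the trainer rounds `warmup_ratio * total_steps` up.

```python
import math

num_samples = 100_000        # train_dataset.select(range(100_000))
per_device_batch_size = 128  # per_device_train_batch_size, assuming a single GPU
warmup_ratio = 0.1
eval_steps = 100

# One optimizer step per batch; the final partial batch still counts as a step.
steps_per_epoch = math.ceil(num_samples / per_device_batch_size)

print(steps_per_epoch)                            # steps in the single epoch
print(round(20 / steps_per_epoch, 4))             # epoch value logged at step 20
print(round(780 / steps_per_epoch, 4))            # epoch value logged at step 780
print(math.ceil(warmup_ratio * steps_per_epoch))  # learning-rate warmup steps (assumed rounding)
print(steps_per_epoch // eval_steps)              # evaluations during training (steps 100..700)
```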



Finetuned Evaluation

The performance of the base model was already excellent, with a strong 0.8340 NDCG@10 on our MIRIAD test set. Despite that, we were able to increase it considerably on this domain-specific dataset.

Our fine-tuning process achieved a significant improvement of +0.0522 NDCG@10 on the test set (from 0.8340 to 0.8862), resulting in a model that comfortably outperforms any existing general-purpose embedding model of this size on our specific task. Additional time and compute could yield even stronger results, for example through hard negatives mining or training on more than 100k data pairs.
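
The reported gain can be recomputed from the first and last rows of the training log. The short sketch below does so for both the eval and test sets and also expresses each gain relative to the base model's score; the dictionary names are illustrative only.

```python
# NDCG@10 values copied from the -1/-1 rows of the training log (before and after training).
base = {"eval": 0.8474, "test": 0.8340}
finetuned = {"eval": 0.9026, "test": 0.8862}

gains = {}
for split in ("eval", "test"):
    abs_gain = finetuned[split] - base[split]  # absolute NDCG@10 improvement
    rel_gain = abs_gain / base[split]          # improvement relative to the base model
    gains[split] = (abs_gain, rel_gain)
    print(f"{split}: +{abs_gain:.4f} NDCG@10 ({rel_gain:.1%} relative)")
```

On the test set this reproduces the +0.0522 figure, roughly a 6% relative improvement over an already strong base model.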



Further Reading


