Get lightning-fast inference, quick quantization, torch.compile boosts, and effortless fine-tuning
for any timm model—all inside the friendly 🤗 transformers ecosystem.
Enter TimmWrapper, a straightforward yet powerful tool that unlocks this potential.
In this post, we’ll cover:
- How the timm integration works and why it’s a game-changer.
- How to integrate timm models with 🤗 transformers.
- Practical examples: pipelines, quantization, fine-tuning, and more.
To follow along with this blog post, install the latest versions of transformers and timm by running:
pip install -Uq transformers timm
Check out the full repository for all code examples and notebooks:
🔗 TimmWrapper Examples
What’s timm?
The PyTorch Image Models (timm) library
offers a rich collection of state-of-the-art computer vision models,
along with useful layers, utilities, optimizers, and data augmentations.
With more than 32K GitHub stars and more than 200K daily downloads at the time of writing,
it is a go-to resource for image classification and feature extraction for object detection,
segmentation, image search, and other downstream tasks.
With pre-trained models covering a wide range of architectures, timm simplifies the workflow for
computer vision practitioners.
Why Use the timm integration?
While 🤗 transformers supports several vision models, timm offers an even broader collection,
including many mobile-friendly and efficient models not available in transformers.
The timm integration bridges this gap, bringing the best of both worlds:
- ✅ Pipeline API Support: Easily plug any timm model into the high-level transformers pipeline for streamlined inference.
- 🧩 Compatibility with Auto Classes: While timm models aren’t natively compatible with transformers, the integration makes them work seamlessly with the Auto classes API.
- ⚡ Quick Quantization: With just ~5 lines of code, you can quantize any timm model for efficient inference.
- 🎯 Fine-Tuning with the Trainer API: Fine-tune timm models using the Trainer API and even integrate with adapters like low-rank adaptation (LoRA).
- 🔁 Round Trip to timm: Use fine-tuned models back in timm.
- 🚀 Torch Compile for Speed: Leverage torch.compile to optimize inference time.
Pipeline API: Using timm Models for Image Classification
One of the standout features of the timm integration is that it lets you leverage the 🤗 pipeline API.
The pipeline API abstracts away a lot of complexity, making it easy to load a pre-trained model,
perform inference, and inspect results with a few lines of code.
Let’s look at how to use a transformers pipeline with MobileNetV4. This architecture doesn’t have a native transformers implementation, but can be easily used from timm:
from transformers import pipeline

image_classifier = pipeline(model="timm/mobilenetv4_conv_medium.e500_r256_in1k")

url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/timm/cat.jpg"
outputs = image_classifier(url)

for output in outputs:
    print(f"Label: {output['label']:20} Score: {output['score']:0.2f}")
Outputs:
Device set to use cpu
Label: tabby, tabby cat     Score: 0.69
Label: tiger cat            Score: 0.21
Label: Egyptian cat         Score: 0.02
Label: bee                  Score: 0.00
Label: marmoset             Score: 0.00
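The pipeline is not limited to URLs: it also accepts local file paths and PIL images, and you can control how many labels come back with the top_k argument. Here is a quick, hedged variation (the local path cat.jpg is just a placeholder):

```python
from PIL import Image
from transformers import pipeline

image_classifier = pipeline(model="timm/mobilenetv4_conv_medium.e500_r256_in1k")

# A local image works just as well as a URL; "cat.jpg" is a placeholder path.
image = Image.open("cat.jpg")

# top_k limits how many labels the pipeline returns.
for output in image_classifier(image, top_k=3):
    print(f"Label: {output['label']:20} Score: {output['score']:0.2f}")
```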
Gradio Integration: Building a Food Classifier Demo 🍣
Want to quickly create an interactive web app for image classification? Gradio makes it easy
to build a user-friendly interface with minimal code. Let’s combine Gradio with the pipeline API
to classify food images using a fine-tuned timm ViT model (we’ll cover fine-tuning in a later section).
Here’s how you can set up a quick demo with a timm model:
import gradio as gr
from transformers import pipeline
pipe = pipeline(
"image-classification",
model="ariG23498/vit_base_patch16_224.augreg2_in21k_ft_in1k.ft_food101"
)
def classify(image):
return pipe(image)[0]["label"]
demo = gr.Interface(
fn=classify,
inputs=gr.Image(type="pil"),
outputs="text",
examples=[["./sushi.png", "sushi"]]
)
demo.launch()
Here’s a live example hosted on Hugging Face Spaces. You can test it directly in your browser!
Auto Classes: Simplifying Model Loading
The 🤗 transformers library provides Auto Classes to abstract away the complexity of loading
models and processors. With the TimmWrapper, you can use AutoModelForImageClassification
and AutoImageProcessor to load any timm model effortlessly.
Here’s a fast example:
from transformers import (
AutoModelForImageClassification,
AutoImageProcessor,
)
from transformers.image_utils import load_image
image_url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/timm/cat.jpg"
image = load_image(image_url)
checkpoint = "timm/mobilenetv4_conv_medium.e500_r256_in1k"
image_processor = AutoImageProcessor.from_pretrained(checkpoint)
model = AutoModelForImageClassification.from_pretrained(checkpoint).eval()
print(type(image_processor))
print(type(model))
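To round out the example, here is a minimal inference sketch with the loaded model and processor. It assumes, as the pipeline output above suggests, that the checkpoint’s config exposes an id2label mapping:

```python
import torch

# Preprocess the image and run a forward pass with the wrapped timm model.
inputs = image_processor(image, return_tensors="pt")
with torch.no_grad():
    logits = model(**inputs).logits

# Map the top prediction back to a human-readable label.
predicted_id = logits.argmax(dim=-1).item()
print(f"Predicted label: {model.config.id2label[predicted_id]}")
```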
Running quantized timm models
Quantization is a powerful technique to reduce model size and speed up inference,
especially for deployment on resource-constrained devices. With the timm integration,
you can quantize any timm model on the fly with just a few lines of code using
BitsAndBytesConfig from bitsandbytes.
Here’s how easy it is to quantize a timm model:
from transformers import TimmWrapperForImageClassification, BitsAndBytesConfig
quantization_config = BitsAndBytesConfig(load_in_8bit=True)
checkpoint = "timm/vit_base_patch16_224.augreg2_in21k_ft_in1k"
model = TimmWrapperForImageClassification.from_pretrained(checkpoint).to("cuda")
model_8bit = TimmWrapperForImageClassification.from_pretrained(
checkpoint,
quantization_config=quantization_config,
low_cpu_mem_usage=True,
)
original_footprint = model.get_memory_footprint()
quantized_footprint = model_8bit.get_memory_footprint()
print(f"Original model size: {original_footprint / 1e6:.2f} MB")
print(f"Quantized model size: {quantized_footprint / 1e6:.2f} MB")
print(f"Reduction: {(original_footprint - quantized_footprint) / original_footprint * 100:.2f}%")
Output:
Original model size: 346.27 MB
Quantized model size: 88.20 MB
Reduction: 74.53%
Quantized models perform almost identically to full-precision models during inference:
| Model | Label | Score |
|---|---|---|
| Original Model | remote control, remote | 0.35 |
| Quantized Model | remote control, remote | 0.33 |
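If you want to reproduce a comparison like this yourself, a sketch along the following lines works. It reuses the cat image from the Auto Classes section, so the predicted label will differ from the remote-control example in the table, and the exact scores depend on the input image:

```python
import torch
from transformers import AutoImageProcessor
from transformers.image_utils import load_image

# Load a test image and a processor that matches this checkpoint.
image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/timm/cat.jpg")
image_processor = AutoImageProcessor.from_pretrained(checkpoint)

# Compare the top prediction of the full-precision and 8-bit models.
for name, m in [("Original", model), ("Quantized", model_8bit)]:
    inputs = image_processor(image, return_tensors="pt").to(m.device)
    with torch.no_grad():
        probs = m(**inputs).logits.softmax(dim=-1)
    score, idx = probs.max(dim=-1)
    print(f"{name} model: {m.config.id2label[idx.item()]} (score: {score.item():.2f})")
```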
Supervised Fine-Tuning of timm models
Fine-tuning a timm model with the Trainer API from 🤗 transformers is straightforward and highly flexible.
You can fine-tune your model on custom datasets using the Trainer class, which handles the training loop,
logging, and evaluation. Additionally, you can fine-tune using LoRA (Low-Rank Adaptation) to train efficiently with fewer parameters.
This section gives a quick overview of both standard fine-tuning and LoRA fine-tuning, with links to the complete code.
Standard Fine-Tuning with the Trainer API
The Trainer API makes it easy to set up training with minimal code. Here’s an outline of what a fine-tuning setup looks like:
from transformers import TrainingArguments, Trainer
training_args = TrainingArguments(
output_dir="my_model_output",
evaluation_strategy="epoch",
save_strategy="epoch",
learning_rate=5e-5,
per_device_train_batch_size=16,
num_train_epochs=3,
load_best_model_at_end=True,
push_to_hub=True,
)
trainer = Trainer(
model=model,
args=training_args,
train_dataset=train_ds,
eval_dataset=val_ds,
data_collator=data_collator,
compute_metrics=compute_metrics,
)
trainer.train()
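The outline above relies on a few pieces defined elsewhere in the linked notebook: preprocessed train_ds and val_ds splits, a data_collator, and a compute_metrics function. As a rough sketch (the exact implementations in the notebook may differ), they could look like this for image classification, assuming each example already holds a pixel_values tensor and an integer label:

```python
import numpy as np
import torch
import evaluate

accuracy = evaluate.load("accuracy")

def data_collator(examples):
    # Stack preprocessed images and labels into a single batch.
    pixel_values = torch.stack([example["pixel_values"] for example in examples])
    labels = torch.tensor([example["label"] for example in examples])
    return {"pixel_values": pixel_values, "labels": labels}

def compute_metrics(eval_pred):
    # eval_pred holds the model's logits and the reference label ids.
    predictions = np.argmax(eval_pred.predictions, axis=1)
    return accuracy.compute(predictions=predictions, references=eval_pred.label_ids)
```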
What’s remarkable about this approach is that it mirrors the exact workflow used for native transformers models,
maintaining consistency across different model types.
This means you can use the familiar Trainer API to fine-tune not only Transformers models, but
also any timm model, bringing powerful models from the timm library into the Hugging Face
ecosystem with minimal adjustments. This significantly broadens the scope of models you can fine-tune
using the same trusted tools and workflows.
Model Example:
Fine-tuned ViT on Food-101: vit_base_patch16_224.augreg2_in21k_ft_in1k.ft_food101
LoRA Fine-Tuning for Efficient Training
LoRA (Low-Rank Adaptation) allows you to fine-tune large models efficiently by training only a
few additional parameters, rather than the full set of model weights. This makes fine-tuning faster
and allows the use of consumer hardware. You can fine-tune a timm model using
LoRA with the PEFT library.
Here’s how you can set it up:
from peft import LoraConfig, get_peft_model
model = AutoModelForImageClassification.from_pretrained(checkpoint, num_labels=num_labels)
lora_config = LoraConfig(
r=16,
lora_alpha=16,
target_modules=["qkv"],
lora_dropout=0.1,
bias="none",
modules_to_save=["head"],
)
lora_model = get_peft_model(model, lora_config)
lora_model.print_trainable_parameters()
Trainable Parameters with LoRA:
trainable params: 667,493 || all params: 86,543,818 || trainable%: 0.77%
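From here, training proceeds exactly as in the standard setup: the PEFT-wrapped model drops straight into the same Trainer call. A sketch, reusing the training_args, datasets, collator, and metrics from the previous section:

```python
trainer = Trainer(
    model=lora_model,  # the PEFT-wrapped timm model
    args=training_args,
    train_dataset=train_ds,
    eval_dataset=val_ds,
    data_collator=data_collator,
    compute_metrics=compute_metrics,
)
trainer.train()
```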
Model Example:
LoRA Fine-Tuned ViT on Food-101: vit_base_patch16_224.augreg2_in21k_ft_in1k.lora_ft_food101
LoRA is just one example of the efficient adapter-based fine-tuning methods you can apply to timm models.
The integration of timm with the 🤗 ecosystem opens up a wide range of parameter-efficient fine-tuning (PEFT) techniques,
allowing you to choose the method that best fits your use case.
Inference with a LoRA Fine-Tuned Model
Once the model is LoRA fine-tuned, we only push the adapter weights to the Hugging Face Hub. This section shows
you how to download the adapter weights, merge them with the base model, and then perform inference.
from peft import PeftModel, PeftConfig
repo_name = "ariG23498/vit_base_patch16_224.augreg2_in21k_ft_in1k.lora_ft_food101"
config = PeftConfig.from_pretrained(repo_name)
model = AutoModelForImageClassification.from_pretrained(
config.base_model_name_or_path,
label2id=label2id,
num_labels=num_labels,
id2label=id2label,
ignore_mismatched_sizes=True,
)
inference_model = PeftModel.from_pretrained(model, repo_name)
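With the adapter attached, inference works like with any other image-classification model, and you can optionally fold the LoRA weights into the base model with PEFT’s merge_and_unload(). A minimal sketch, assuming a processor loaded from the base checkpoint, the sushi image used later in this post, and the same id2label mapping passed to from_pretrained above:

```python
import torch
from transformers import AutoImageProcessor
from transformers.image_utils import load_image

# The processor comes from the base timm checkpoint the adapter was trained on.
image_processor = AutoImageProcessor.from_pretrained(config.base_model_name_or_path)
image = load_image("https://cdn.britannica.com/52/128652-050-14AD19CA/Maki-zushi.jpg")

# Optionally merge the adapter into the base weights for faster inference.
inference_model = inference_model.merge_and_unload()

inputs = image_processor(image, return_tensors="pt")
with torch.no_grad():
    logits = inference_model(**inputs).logits
predicted_id = logits.argmax(dim=-1).item()
print(f"Predicted label: {id2label[predicted_id]}")
```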
Round trip integration
One of Ross’ (the creator of timm) favourite features is that this integration maintains
full ’round-trip’ compatibility. Namely, using the wrapper one can fine-tune a timm model on a new dataset using transformers’ Trainer, publish the resulting model to the Hugging Face Hub, and then load the fine-tuned model in timm again using timm.create_model('hf-hub:my_org/my_fine_tuned_model', pretrained=True).
Let’s see how we can load our fine-tuned model ariG23498/vit_base_patch16_224.augreg2_in21k_ft_in1k.ft_food101 with timm:
import timm
import torch
from transformers import AutoConfig
from transformers.image_utils import load_image

checkpoint = "ariG23498/vit_base_patch16_224.augreg2_in21k_ft_in1k.ft_food101"
config = AutoConfig.from_pretrained(checkpoint)

model = timm.create_model(f"hf_hub:{checkpoint}", pretrained=True)
model = model.eval()

image = load_image("https://cdn.britannica.com/52/128652-050-14AD19CA/Maki-zushi.jpg")
data_config = timm.data.resolve_model_data_config(model)
transforms = timm.data.create_transform(**data_config, is_training=False)

output = model(transforms(image).unsqueeze(0))
top5_probabilities, top5_class_indices = torch.topk(output.softmax(dim=1) * 100, k=5)
for prob, idx in zip(top5_probabilities[0], top5_class_indices[0]):
    print(f"Label: {config.id2label[idx.item()]:20} Score: {prob/100:0.2f}")
Outputs:
Label: sushi                Score: 0.98
Label: spring_rolls         Score: 0.01
Label: sashimi              Score: 0.00
Label: club_sandwich        Score: 0.00
Label: cannoli              Score: 0.00
Torch Compile: Quick Speedup
With torch.compile in PyTorch 2.0, you can achieve faster inference by compiling your model
with just one line of code. The timm integration is fully compatible with torch.compile.
Here’s a quick benchmark to compare inference time with and without torch.compile using the TimmWrapper.
import time

import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = TimmWrapperForImageClassification.from_pretrained(checkpoint).to(device)
processed_input = image_processor(image, return_tensors="pt").to(device)
def run_benchmark(model, input_data, warmup_runs=5, benchmark_runs=300):
model.eval()
with torch.no_grad():
for _ in range(warmup_runs):
_ = model(**input_data)
times = []
with torch.no_grad():
for _ in range(benchmark_runs):
start_time = time.perf_counter()
_ = model(**input_data)
if device.type == "cuda":
torch.cuda.synchronize(device=device)
times.append(time.perf_counter() - start_time)
avg_time = sum(times) / benchmark_runs
return avg_time
time_no_compile = run_benchmark(model, processed_input)
compiled_model = torch.compile(model).to(device)
time_compile = run_benchmark(compiled_model, processed_input)
print(f"Without torch.compile: {time_no_compile:.4f} s")
print(f"With torch.compile: {time_compile:.4f} s")
Wrapping Up
timm’s integration with transformers opens new doors for leveraging state-of-the-art vision models
with minimal effort. Whether you are looking to fine-tune, quantize, or simply run inference, this
integration provides a unified API to streamline your workflow.
Start exploring today and unlock new possibilities in computer vision!
Acknowledgments
We would like to give a huge shout-out to the folks who made this integration happen in
Transformers PR #34564.
In no particular order, a big thanks to Pavel Iakubovskii, Ross Wightman, Lysandre Debut,
Pablo Montalvo, Arthur Zucker, and Amy Roberts for all your incredible work. Your combined efforts took
this feature from an idea to something everyone can now enjoy!


