Fine-Tune ViT for Image Classification with 🤗 Transformers



Nate Raw



Just as transformer-based models have revolutionized NLP, we’re now seeing an explosion of papers applying them to all kinds of other domains. One of the most revolutionary of these was the Vision Transformer (ViT), which was introduced in October 2020 by a team of researchers at Google Brain.

This paper explored how you can tokenize images, just as you would tokenize sentences, so that they can be passed to transformer models for training. It’s quite a simple concept, really…

  1. Split an image into a grid of sub-image patches
  2. Embed each patch with a linear projection
  3. Each embedded patch becomes a token, and the resulting sequence of embedded patches is the sequence you pass to the model.
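The three steps above can be sketched in a few lines of plain PyTorch. This is an illustration only, not the embedding code 🤗 Transformers actually uses. The library's ViT does the split and the projection in one step with a Conv2d whose kernel size and stride equal the patch size, and it also prepends a [CLS] token and adds position embeddings. Here the 16×16 patch size and the 768-dimensional hidden size are assumptions chosen to match ViT-Base.

```python
import torch
from torch import nn

# Illustrative patch-embedding sketch; sizes assumed to match ViT-Base.
patch_size = 16
hidden_size = 768

# A fake batch holding one RGB 224x224 image: (batch, channels, height, width)
pixel_values = torch.randn(1, 3, 224, 224)

# 1. Split the image into non-overlapping 16x16 patches.
#    Unfolding height, then width, gives (batch, channels, 14, 14, 16, 16).
patches = pixel_values.unfold(2, patch_size, patch_size).unfold(3, patch_size, patch_size)
b, c, nh, nw, ph, pw = patches.shape
# Flatten each patch into one vector of length channels * 16 * 16 = 768.
patches = patches.permute(0, 2, 3, 1, 4, 5).reshape(b, nh * nw, c * ph * pw)

# 2. Embed each flattened patch with a (randomly initialized) linear projection.
projection = nn.Linear(c * ph * pw, hidden_size)
tokens = projection(patches)

# 3. The result is a sequence of patch tokens: (batch, num_patches, hidden_size).
print(tokens.shape)  # torch.Size([1, 196, 768])
```

A 224×224 image with 16×16 patches yields 14 × 14 = 196 tokens. That is the "sequence length" the transformer sees, before the [CLS] token is added.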

It turns out that once you’ve done the above, you can pre-train and fine-tune transformers just as you’re used to with NLP tasks. Pretty sweet 😎.


In this blog post, we’ll walk through how to leverage 🤗 datasets to download and process image classification datasets, and then use them to fine-tune a pre-trained ViT with 🤗 transformers.

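To get started, let’s first install the packages we’ll need: 🤗 datasets and 🤗 transformers, plus evaluate (for the accuracy metric), accelerate (required by the Trainer in recent transformers versions), and tensorboard (used for logging below).

pip install datasets transformers evaluate accelerate tensorboard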



Load a dataset

Let’s start by loading a small image classification dataset and taking a look at its structure.

We’ll use the beans dataset, which is a set of images of healthy and unhealthy bean leaves. 🍃

from datasets import load_dataset

ds = load_dataset('beans')
ds

Let’s take a look at the example at index 400 of the 'train' split of the beans dataset. You’ll notice each example from the dataset has 3 features:

  1. image: A PIL Image
  2. image_file_path: The str path to the image file that was loaded as image
  3. labels: A datasets.ClassLabel feature, which is an integer representation of the label. (Later you’ll see how to get the string class names, don’t worry!)

ex = ds['train'][400]
ex
{
  'image': <PIL.JpegImagePlugin.JpegImageFile image ...>,
  'image_file_path': '/root/.cache/.../bean_rust_train.4.jpg',
  'labels': 1
}

Let’s take a look at the image 👀

image = ex['image']
image

That is definitely a leaf! But what kind? 😅

Since the 'labels' feature of this dataset is a datasets.features.ClassLabel, we can use it to look up the corresponding name for this example’s label ID.

First, let’s access the feature definition for the 'labels'.

labels = ds['train'].features['labels']
labels
ClassLabel(num_classes=3, names=['angular_leaf_spot', 'bean_rust', 'healthy'], names_file=None, id=None)

Now, let’s print out the class label for our example. You can do this with the int2str function of ClassLabel, which, as the name implies, lets you pass the integer representation of the class and get back the string label.

labels.int2str(ex['labels'])
'bean_rust'

Seems the leaf shown above is infected with Bean Rust, a serious disease in bean plants. 😢

Let’s write a function that’ll display a grid of examples from each class to get a better idea of what you’re working with.

import random
from PIL import ImageDraw, ImageFont, Image

def show_examples(ds, seed: int = 1234, examples_per_class: int = 3, size=(350, 350)):

    w, h = size
    labels = ds['train'].features['labels'].names
    grid = Image.new('RGB', size=(examples_per_class * w, len(labels) * h))
    draw = ImageDraw.Draw(grid)
    try:
        font = ImageFont.truetype("/usr/share/fonts/truetype/liberation/LiberationMono-Bold.ttf", 24)
    except OSError:
        # Fall back to Pillow's built-in font if Liberation Mono isn't installed
        font = ImageFont.load_default()

    for label_id, label in enumerate(labels):

        # Filter the dataset by a single label, shuffle it, and grab a few samples
        ds_slice = ds['train'].filter(lambda ex: ex['labels'] == label_id).shuffle(seed).select(range(examples_per_class))

        # Plot this label's examples along a row
        for i, example in enumerate(ds_slice):
            image = example['image']
            idx = examples_per_class * label_id + i
            box = (idx % examples_per_class * w, idx // examples_per_class * h)
            grid.paste(image.resize(size), box=box)
            draw.text(box, label, (255, 255, 255), font=font)

    return grid

show_examples(ds, seed=random.randint(0, 1337), examples_per_class=3)
A grid of just a few examples from each class within the dataset

From what I’m seeing,

  • Angular Leaf Spot: Has irregular brown patches
  • Bean Rust: Has circular brown spots surrounded by a white-ish yellow ring
  • Healthy: …looks healthy. 🤷‍♂️



Loading ViT Image Processor

Now we know what our images look like and better understand the problem we’re trying to solve. Let’s see how we can prepare these images for our model!

When ViT models are trained, specific transformations are applied to the images fed into them. Use the wrong transformations on your image, and the model won’t understand what it’s seeing! 🖼 ➡️ 🔢

To make sure we apply the correct transformations, we will use a ViTImageProcessor initialized with a configuration that was saved along with the pretrained model we plan to use. In our case, we’ll be using the google/vit-base-patch16-224-in21k model, so let’s load its image processor from the Hugging Face Hub.

from transformers import ViTImageProcessor

model_name_or_path = 'google/vit-base-patch16-224-in21k'
processor = ViTImageProcessor.from_pretrained(model_name_or_path)

You can see the image processor configuration by printing it.

ViTImageProcessor {
  "do_normalize": true,
  "do_resize": true,
  "image_mean": [
    0.5,
    0.5,
    0.5
  ],
  "image_std": [
    0.5,
    0.5,
    0.5
  ],
  "resample": 2,
  "size": 224
}

To process an image, simply pass it to the image processor’s call function. This will return a dict containing pixel_values, which is the numeric representation that gets passed to the model.

You get a NumPy array by default, but if you add the return_tensors="pt" argument, you’ll get back torch tensors instead.

processor(image, return_tensors='pt')

This should give you something like…

{
  'pixel_values': tensor([[[[ 0.2706,  0.3255,  0.3804,  ...]]]])
}

…where the shape of the tensor is (1, 3, 224, 224).
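To see roughly what the processor does under the hood, here is a NumPy/Pillow approximation of the transformations implied by the configuration above. It resizes to 224×224 with bilinear resampling ("resample": 2), scales pixels to [0, 1], and then normalizes each channel with mean 0.5 and std 0.5. In other words, x_norm = (x / 255 - 0.5) / 0.5, which maps [0, 255] onto [-1, 1]. The helper name manual_vit_preprocess is hypothetical. The 1/255 rescale step is an assumption based on how current ViTImageProcessor versions behave, since it does not appear in the printed config. As a worked check, the first value in the tensor above, 0.2706, corresponds to a raw pixel of about 162: (162 / 255 - 0.5) / 0.5 ≈ 0.2706.

```python
import numpy as np
from PIL import Image

def manual_vit_preprocess(image: Image.Image, size: int = 224) -> np.ndarray:
    """Rough approximation of ViTImageProcessor: resize, rescale, normalize, CHW."""
    # Resize with bilinear resampling ("resample": 2 in the config above)
    image = image.convert("RGB").resize((size, size), resample=Image.BILINEAR)
    # Rescale uint8 pixels from [0, 255] to [0, 1]
    x = np.asarray(image, dtype=np.float32) / 255.0
    # Normalize each channel with mean 0.5 and std 0.5, mapping [0, 1] to [-1, 1]
    mean = np.array([0.5, 0.5, 0.5], dtype=np.float32)
    std = np.array([0.5, 0.5, 0.5], dtype=np.float32)
    x = (x - mean) / std
    # HWC -> CHW, then add a batch dimension: (1, 3, size, size)
    return x.transpose(2, 0, 1)[np.newaxis]

# A uniform gray image with raw value 162 should map to about 0.2706 everywhere
gray = Image.new("RGB", (500, 500), (162, 162, 162))
out = manual_vit_preprocess(gray)
print(out.shape)                         # (1, 3, 224, 224)
print(round(float(out[0, 0, 0, 0]), 4))  # 0.2706
```

On a real photo, the output should be close to processor(image)['pixel_values'], but small resizing differences mean it is not guaranteed to match bit-for-bit. In practice, stick with the processor.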



Processing the Dataset

Now that you know how to read images and transform them into inputs, let’s write a function that puts those two things together to process a single example from the dataset.

def process_example(example):
    inputs = processor(example['image'], return_tensors='pt')
    inputs['labels'] = example['labels']
    return inputs
process_example(ds['train'][0])
{
  'pixel_values': tensor([[[[-0.6157, -0.6000, -0.6078,  ..., ]]]]),
  'labels': 0
}

While you could call ds.map and apply this to every example up front, this can be very slow, especially if you use a larger dataset. Instead, you can apply a transform to the dataset. Transforms are only applied to examples as you index them.

First, though, you’ll need to update the last function to accept a batch of data, as that’s what ds.with_transform expects.

ds = load_dataset('beans')

def transform(example_batch):
    # Take a list of PIL images and turn them into pixel values
    inputs = processor([x for x in example_batch['image']], return_tensors='pt')

    # Don't forget to include the labels!
    inputs['labels'] = example_batch['labels']
    return inputs

You can apply this directly to the dataset using ds.with_transform(transform).

prepared_ds = ds.with_transform(transform)

Now, whenever you get an example from the dataset, the transform will be applied in real time (on both single samples and slices, as shown below).

prepared_ds['train'][0:2]

This time, the resulting pixel_values tensor will have shape (2, 3, 224, 224).

{
  'pixel_values': tensor([[[[-0.6157, -0.6000, -0.6078,  ..., ]]]]),
  'labels': [0, 0]
}



Training and Evaluation

The data is processed and you’re ready to start setting up the training pipeline. This blog post uses 🤗’s Trainer, but that’ll require us to do a few things first:

  • Define a collate function.

  • Define an evaluation metric. During training, the model should be evaluated on its prediction accuracy. You should define a compute_metrics function accordingly.

  • Load a pretrained checkpoint. You need to load a pretrained checkpoint and configure it correctly for training.

  • Define the training configuration.

After fine-tuning the model, you’ll evaluate it on the evaluation data to verify that it has indeed learned to correctly classify the images.



Define our data collator

Batches are coming in as lists of dicts, so you can just unpack + stack those into batch tensors.

Since the collate_fn will return a batch dict, you can **unpack the inputs to the model later. ✨

import torch

def collate_fn(batch):
    return {
        'pixel_values': torch.stack([x['pixel_values'] for x in batch]),
        'labels': torch.tensor([x['labels'] for x in batch])
    }
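Here is a quick offline check of the collator and of the **unpacking it enables. It uses a hypothetical, tiny, randomly initialized ViT built from a ViTConfig, with 32×32 inputs and one layer, so nothing is downloaded. It also uses four fake examples shaped like single items from prepared_ds. The collator is repeated so the snippet runs on its own. Because the batch contains labels, the model computes a loss as well as logits.

```python
import torch
from transformers import ViTConfig, ViTForImageClassification

def collate_fn(batch):
    # Same collator as above, repeated so this snippet runs on its own
    return {
        'pixel_values': torch.stack([x['pixel_values'] for x in batch]),
        'labels': torch.tensor([x['labels'] for x in batch])
    }

# Hypothetical tiny ViT with random weights (not the pretrained checkpoint)
config = ViTConfig(
    image_size=32, patch_size=16, hidden_size=32, num_hidden_layers=1,
    num_attention_heads=2, intermediate_size=64, num_labels=3,
)
model = ViTForImageClassification(config)

# Four fake examples: a (3, H, W) tensor and an int label each
examples = [{'pixel_values': torch.randn(3, 32, 32), 'labels': i % 3} for i in range(4)]
batch = collate_fn(examples)
print(batch['pixel_values'].shape)  # torch.Size([4, 3, 32, 32])

with torch.no_grad():
    # The dict keys match forward()'s argument names, so ** unpacking just works
    outputs = model(**batch)
print(outputs.logits.shape)  # torch.Size([4, 3])
print(outputs.loss)          # scalar cross-entropy loss, computed because labels were passed
```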



Define an evaluation metric

The accuracy metric from evaluate can easily be used to compare the predictions with the labels. Below, you can see how to use it inside a compute_metrics function that the Trainer will call.

import numpy as np
from evaluate import load

metric = load("accuracy")
def compute_metrics(p):
    return metric.compute(predictions=np.argmax(p.predictions, axis=1), references=p.label_ids)
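For intuition, here is what that metric boils down to on a tiny made-up example. It is written with plain NumPy instead of evaluate, so it needs no download. The helper name accuracy_from_logits is hypothetical. It receives the same EvalPrediction object the Trainer passes to compute_metrics, where predictions holds the model’s logits and label_ids holds the true labels.

```python
import numpy as np
from transformers import EvalPrediction

def accuracy_from_logits(p: EvalPrediction) -> dict:
    # Pick the highest-scoring class for each example
    preds = np.argmax(p.predictions, axis=1)
    # Accuracy is the fraction of predictions that match the labels
    return {"accuracy": float(np.mean(preds == p.label_ids))}

# Made-up logits for 4 examples over 3 classes
logits = np.array([
    [2.0, 0.1, 0.3],  # predicts 0
    [0.2, 1.5, 0.1],  # predicts 1
    [0.1, 0.2, 3.0],  # predicts 2
    [1.0, 2.0, 0.5],  # predicts 1
])
labels = np.array([0, 1, 2, 0])

print(accuracy_from_logits(EvalPrediction(predictions=logits, label_ids=labels)))
# {'accuracy': 0.75}
```

Three of the four argmax predictions match their labels, hence 0.75.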

Let’s load the pretrained model. We’ll pass num_labels on init so the model creates a classification head with the right number of units. We’ll also include the id2label and label2id mappings to have human-readable labels in the Hub widget (if you choose to push_to_hub).

from transformers import ViTForImageClassification

labels = ds['train'].features['labels'].names

model = ViTForImageClassification.from_pretrained(
    model_name_or_path,
    num_labels=len(labels),
    id2label={str(i): c for i, c in enumerate(labels)},
    label2id={c: str(i) for i, c in enumerate(labels)}
)

Almost ready to train! The last thing needed before that is to set up the training configuration by defining TrainingArguments.

Most of these are pretty self-explanatory, but one that is quite important here is remove_unused_columns=False. When True, this option drops any features not used by the model’s call function. It defaults to True because it’s usually ideal to drop unused feature columns, which makes it easier to unpack inputs into the model’s call function. But in our case, we need the unused features (‘image’ specifically) in order to create ‘pixel_values’.

What I’m trying to say is that you’ll have a bad time if you forget to set remove_unused_columns=False.

from transformers import TrainingArguments

training_args = TrainingArguments(
    output_dir="./vit-base-beans",
    per_device_train_batch_size=16,
    eval_strategy="steps",  # named evaluation_strategy in transformers versions before 4.41
    num_train_epochs=4,
    fp16=True,  # mixed precision needs a CUDA GPU; set to False when training on CPU
    save_steps=100,
    eval_steps=100,
    logging_steps=10,
    learning_rate=2e-4,
    save_total_limit=2,
    remove_unused_columns=False,
    push_to_hub=False,
    report_to='tensorboard',
    load_best_model_at_end=True,
)
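As a quick sanity check on these settings, here is the step arithmetic they imply. This is a sketch under stated assumptions: a single GPU, no gradient accumulation, and a beans train split of 1,034 examples (check len(ds['train']) on your copy).

```python
import math

# Assumptions: single device, no gradient accumulation, 1,034 training examples
num_train_examples = 1034
per_device_train_batch_size = 16
num_train_epochs = 4
eval_steps = 100

# The last partial batch still counts as a step (drop_last defaults to False)
steps_per_epoch = math.ceil(num_train_examples / per_device_train_batch_size)
total_steps = steps_per_epoch * num_train_epochs
# Global steps at which a mid-training evaluation is triggered
eval_points = list(range(eval_steps, total_steps + 1, eval_steps))

print(steps_per_epoch, total_steps, eval_points)  # 65 260 [100, 200]
```

Under these assumptions, you would get mid-training evaluations at steps 100 and 200. Note that with load_best_model_at_end=True, save_steps must be a round multiple of eval_steps, which 100 and 100 satisfy.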

Now, all instances can be passed to Trainer and we’re ready to start training!

from transformers import Trainer

trainer = Trainer(
    model=model,
    args=training_args,
    data_collator=collate_fn,
    compute_metrics=compute_metrics,
    train_dataset=prepared_ds["train"],
    eval_dataset=prepared_ds["validation"],
    processing_class=processor,  # named tokenizer=processor in transformers versions before 4.46
)



Train 🚀

train_results = trainer.train()
trainer.save_model()
trainer.log_metrics("train", train_results.metrics)
trainer.save_metrics("train", train_results.metrics)
trainer.save_state()



Evaluate 📊

metrics = trainer.evaluate(prepared_ds['validation'])
trainer.log_metrics("eval", metrics)
trainer.save_metrics("eval", metrics)

Here were my evaluation results. Cool beans! Sorry, had to say it.

***** eval metrics *****
  epoch                   =        4.0
  eval_accuracy           =      0.985
  eval_loss               =     0.0637
  eval_runtime            = 0:00:02.13
  eval_samples_per_second =     62.356
  eval_steps_per_second   =       7.97

Finally, if you want, you can push your model up to the hub. Here, we’ll push it up if you specified push_to_hub=True in the training configuration. Note that in order to push to the hub, you’ll need to be logged into your Hugging Face account (which can be done via huggingface-cli login); older versions of the library also required git-lfs to be installed.

kwargs = {
    "finetuned_from": model.config._name_or_path,
    "tasks": "image-classification",
    "dataset": 'beans',
    "tags": ['image-classification'],
}

if training_args.push_to_hub:
    trainer.push_to_hub('🍻 cheers', **kwargs)
else:
    trainer.create_model_card(**kwargs)

The resulting model has been shared to nateraw/vit-base-beans. I’m assuming you don’t have pictures of bean leaves lying around, so I added some examples for you to give it a try! 🚀
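If you would rather run the model on your own images in Python, here is a minimal inference sketch. predict_label and load_finetuned are hypothetical helpers, not part of 🤗 Transformers. The sketch assumes the model and processor were saved together to the ./vit-base-beans output directory. trainer.save_model() writes the processor there too because it was passed to the Trainer. The same loader should also work with the shared checkpoint name.

```python
import torch
from PIL import Image
from transformers import ViTForImageClassification, ViTImageProcessor

def load_finetuned(path: str = "./vit-base-beans"):
    # Assumes both the model and its image processor were saved to `path`
    model = ViTForImageClassification.from_pretrained(path)
    processor = ViTImageProcessor.from_pretrained(path)
    return model, processor

def predict_label(image: Image.Image, model, processor) -> str:
    """Return the predicted class name for a single PIL image."""
    inputs = processor(image.convert("RGB"), return_tensors="pt")
    model.eval()
    with torch.no_grad():
        logits = model(**inputs).logits
    predicted_id = logits.argmax(-1).item()
    # id2label keys are ints once loaded into the model config
    return model.config.id2label[predicted_id]

# Usage (downloads or reads weights, so it is left commented out here):
# model, processor = load_finetuned()  # or load_finetuned("nateraw/vit-base-beans")
# print(predict_label(Image.open("my_leaf.jpg"), model, processor))
```

For quick experiments, transformers’ pipeline("image-classification", model=...) wraps the same preprocess, forward, and argmax steps.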


