This guide shows how you can use CLIPSeg, a zero-shot image segmentation model, using 🤗 transformers. CLIPSeg creates rough segmentation masks that can be used for robot perception, image inpainting, and many other tasks. If you need more precise segmentation masks, we’ll show how you can refine the results of CLIPSeg on Segments.ai.
Image segmentation is a well-known task within the field of computer vision. It allows a computer to know not only what’s in an image (classification) and where objects are in the image (detection), but also what the outlines of those objects are. Knowing the outlines of objects is essential in fields such as robotics and autonomous driving. For example, a robot has to know the shape of an object to grab it correctly. Segmentation can also be combined with image inpainting to allow users to describe which part of the image they want to replace.
One limitation of most image segmentation models is that they only work with a fixed list of categories. For example, you can’t simply use a segmentation model trained on oranges to segment apples. To teach the segmentation model an additional category, you have to label data of the new category and train a new model, which can be costly and time-consuming. But what if there was a model that can already segment almost any kind of object, without any further training? That’s exactly what CLIPSeg, a zero-shot segmentation model, achieves.
Currently, CLIPSeg still has its limitations. For example, the model uses images of 352 x 352 pixels, so the output is quite low-resolution. This means we cannot expect pixel-perfect results when we work with images from modern cameras. If we want more precise segmentations, we can fine-tune a state-of-the-art segmentation model, as shown in our previous blog post. In that case, we can still use CLIPSeg to generate some rough labels, and then refine them in a labeling tool such as Segments.ai. Before we describe how to do that, let’s first take a look at how CLIPSeg works.
CLIP: the magic model behind CLIPSeg
CLIP, which stands for Contrastive Language–Image Pre-training, is a model developed by OpenAI in 2021. You can give CLIP an image or a piece of text, and CLIP will output an abstract representation of your input. This abstract representation, also called an embedding, is really just a vector (a list of numbers). You can think of this vector as a point in high-dimensional space. CLIP is trained so that the representations of similar pictures and texts are similar as well. This means that if we input an image and a text description that fits that image, the representations of the image and the text will be similar (i.e., the high-dimensional points will be close together).
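To make “similar inputs get similar embeddings” a bit more concrete, here is a minimal NumPy sketch of a symmetric contrastive loss in the spirit of CLIP’s training objective. It assumes a batch of N matching image/text embedding pairs (row i of each matrix belongs together) and a fixed temperature of 0.07; the function names and toy data are illustrative, not OpenAI’s code.

```python
import numpy as np


def l2_normalize(x: np.ndarray) -> np.ndarray:
    """Scale each row to unit length so dot products become cosine similarities."""
    return x / np.linalg.norm(x, axis=1, keepdims=True)


def log_softmax(logits: np.ndarray, axis: int) -> np.ndarray:
    """Numerically stable log-softmax along the given axis."""
    shifted = logits - logits.max(axis=axis, keepdims=True)
    return shifted - np.log(np.exp(shifted).sum(axis=axis, keepdims=True))


def clip_style_loss(image_emb: np.ndarray, text_emb: np.ndarray, temperature: float = 0.07) -> float:
    """Symmetric cross-entropy over an N x N image-text similarity matrix.

    Row i of image_emb and row i of text_emb are assumed to be a matching pair,
    so the "correct answer" for every row and every column lies on the diagonal.
    """
    logits = l2_normalize(image_emb) @ l2_normalize(text_emb).T / temperature
    idx = np.arange(len(logits))
    loss_images = -log_softmax(logits, axis=1)[idx, idx].mean()  # each image should pick its caption
    loss_texts = -log_softmax(logits, axis=0)[idx, idx].mean()   # each caption should pick its image
    return float((loss_images + loss_texts) / 2)


# Toy embeddings: four orthogonal "images", captions that either match or are shifted by one.
images = np.eye(4, 8)
matched_captions = images.copy()
shifted_captions = images[[1, 2, 3, 0]]
print(clip_style_loss(images, matched_captions) < clip_style_loss(images, shifted_captions))  # True
```

Minimizing this loss pulls matching image/text pairs together and pushes mismatched pairs apart, which is what makes the nearest-neighbor tricks in the rest of this section work.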
At first, this might not seem very useful, but it is actually very powerful. As an example, let’s take a quick look at how CLIP can be used to classify images without ever having been trained on that task. To classify an image, we input the image and the different categories we want to choose from to CLIP (e.g. we input an image and the words “apple”, “orange”, …). CLIP then gives us back an embedding of the image and of each category. Now, we simply have to check which category embedding is closest to the embedding of the image, et voilà! Feels like magic, doesn’t it?
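The “closest embedding wins” step can be sketched in a few lines of NumPy. The 3-dimensional vectors below are made-up stand-ins for real CLIP embeddings; with 🤗 transformers you would typically obtain real ones from `CLIPModel.get_image_features` and `CLIPModel.get_text_features`. Closeness is measured with cosine similarity here, which is an assumption of this sketch.

```python
import numpy as np


def zero_shot_classify(image_emb: np.ndarray, label_embs: np.ndarray, labels: list[str]) -> tuple[str, np.ndarray]:
    """Return the label whose embedding is closest to the image embedding,
    together with the cosine similarity to every label."""
    image_emb = image_emb / np.linalg.norm(image_emb)
    label_embs = label_embs / np.linalg.norm(label_embs, axis=1, keepdims=True)
    sims = label_embs @ image_emb  # one cosine similarity per label
    return labels[int(np.argmax(sims))], sims


# Toy "embeddings" standing in for real CLIP outputs.
labels = ["apple", "orange", "banana"]
label_embs = np.array([[1.0, 0.1, 0.0], [0.1, 1.0, 0.0], [0.0, 0.1, 1.0]])
image_emb = np.array([0.2, 0.9, 0.1])  # points roughly in the "orange" direction
print(zero_shot_classify(image_emb, label_embs, labels)[0])  # orange
```

No classifier is trained here: adding a new category only means embedding one more piece of text.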
What’s more, CLIP is not only useful for classification, but it can also be used for image search (can you see how this is similar to classification?), text-to-image models (DALL-E 2 is powered by CLIP), object detection (OWL-ViT), and most importantly for us: image segmentation. Now you see why CLIP was truly a breakthrough in machine learning.
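Image search is classification turned around: instead of comparing one image against many texts, you compare one text query against many images and rank them. A minimal sketch, again with hypothetical toy vectors in place of real CLIP embeddings:

```python
import numpy as np


def search_images(query_emb: np.ndarray, gallery_embs: np.ndarray, top_k: int = 3) -> np.ndarray:
    """Return the indices of the top_k gallery images most similar to a text query embedding."""
    query = query_emb / np.linalg.norm(query_emb)
    gallery = gallery_embs / np.linalg.norm(gallery_embs, axis=1, keepdims=True)
    sims = gallery @ query           # cosine similarity of every image to the query
    return np.argsort(-sims)[:top_k]  # highest similarity first


gallery_embs = np.array([[0.9, 0.1], [0.1, 0.9], [0.7, 0.7], [-1.0, 0.0]])  # toy image embeddings
query_emb = np.array([1.0, 0.0])                                           # toy text embedding
print(search_images(query_emb, gallery_embs, top_k=2))  # [0 2]
```

In practice the gallery embeddings would be computed once and stored, so each search only needs one text embedding and a matrix-vector product.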
The reason why CLIP works so well is that the model was trained on a huge dataset of images with text captions. The dataset contained a whopping 400 million image-text pairs taken from the internet. These images contain a wide variety of objects and concepts, and CLIP is great at creating a representation for each of them.
CLIPSeg: image segmentation with CLIP
CLIPSeg is a model that uses CLIP representations to create image segmentation masks. It was published by Timo Lüddecke and Alexander Ecker. They achieved zero-shot image segmentation by training a Transformer-based decoder on top of the CLIP model, which is kept frozen. The decoder takes in the CLIP representation of an image and the CLIP representation of the thing you want to segment. Using these two inputs, the CLIPSeg decoder creates a binary segmentation mask. To be more precise, the decoder doesn’t only use the final CLIP representation of the image we want to segment, but also the outputs of some of CLIP’s intermediate layers.
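As a rough mental model of how one decoder can produce different masks for different prompts, here is a heavily simplified NumPy sketch. It assumes FiLM (feature-wise linear modulation) as the conditioning mechanism, which the CLIPSeg paper reports using, but it leaves out the decoder’s transformer layers and the upsampling to 352 x 352, and it simply sums the projected layer activations. All weights, sizes and function names are made up for illustration; this is not the transformers implementation.

```python
import numpy as np


def film(x: np.ndarray, cond: np.ndarray, w_gamma: np.ndarray, w_beta: np.ndarray) -> np.ndarray:
    """Scale and shift every channel of x using values predicted from the prompt embedding."""
    gamma = cond @ w_gamma  # (channels,) per-channel scale
    beta = cond @ w_beta    # (channels,) per-channel shift
    return x * gamma + beta  # broadcast over all patch tokens


def toy_clipseg_decoder(layer_acts, cond, proj_ws, w_gamma, w_beta, w_head):
    """Toy stand-in for a CLIPSeg-style decoder.

    layer_acts: list of (num_patches, clip_dim) activations from several frozen CLIP layers
    cond:       (cond_dim,) CLIP embedding of the text or image prompt
    returns:    (side, side) grid with one logit per image patch
    """
    # Project each selected CLIP layer to a small width and combine them (summed for simplicity).
    hidden = sum(acts @ w for acts, w in zip(layer_acts, proj_ws))  # (num_patches, reduce_dim)
    hidden = film(hidden, cond, w_gamma, w_beta)                    # inject the prompt
    logits = hidden @ w_head                                        # (num_patches,)
    side = int(np.sqrt(len(logits)))
    return logits.reshape(side, side)


rng = np.random.default_rng(0)
num_patches, clip_dim, reduce_dim, cond_dim = 16, 32, 8, 12
layer_acts = [rng.normal(size=(num_patches, clip_dim)) for _ in range(3)]
proj_ws = [rng.normal(size=(clip_dim, reduce_dim)) for _ in range(3)]
w_gamma, w_beta = rng.normal(size=(cond_dim, reduce_dim)), rng.normal(size=(cond_dim, reduce_dim))
w_head = rng.normal(size=reduce_dim)
cond = rng.normal(size=cond_dim)
print(toy_clipseg_decoder(layer_acts, cond, proj_ws, w_gamma, w_beta, w_head).shape)  # (4, 4)
```

The image activations stay fixed; only `cond` changes between prompts, and that alone is enough to change the predicted mask.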
The decoder is trained on the PhraseCut dataset, which contains over 340,000 phrases with corresponding image segmentation masks. The authors also experimented with various augmentations to expand the size of the dataset. The goal here is not only to be able to segment the categories that are present in the dataset, but also to segment unseen categories. Experiments indeed show that the decoder can generalize to unseen categories.
One interesting feature of CLIPSeg is that both the query (the image we want to segment) and the prompt (the thing we want to segment in the image) are input as CLIP embeddings. The CLIP embedding for the prompt can come either from a piece of text (the category name) or from another image. This means you can segment oranges in a photo by giving CLIPSeg an example image of an orange.
This technique, which is called “visual prompting”, is really helpful when the thing you want to segment is hard to describe. For example, if you want to segment a logo in a picture of a t-shirt, it’s not easy to describe the shape of the logo, but CLIPSeg allows you to simply use the image of the logo as the prompt.
The CLIPSeg paper contains some tips on improving the effectiveness of visual prompting. The authors find that cropping the prompt image (so that it only contains the object you want to segment) helps a lot. Blurring and darkening the background of the prompt image also helps a little. In the next section, we’ll show how you can try out visual prompting yourself using 🤗 transformers.
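Those prompt-engineering tips can be applied with a few Pillow calls. The sketch below is one possible way to do it, not the paper’s exact procedure: you supply the object’s bounding box yourself, and the margin, darkening factor and blur radius are arbitrary values chosen for illustration.

```python
from PIL import Image, ImageEnhance, ImageFilter


def prepare_visual_prompt(image: Image.Image, box: tuple[int, int, int, int],
                          margin: int = 20, darken: float = 0.5, blur_radius: float = 5.0) -> Image.Image:
    """Crop a prompt image around the object and de-emphasize everything outside its box.

    box is (left, upper, right, lower) in pixel coordinates of the object of interest.
    """
    image = image.convert("RGB")
    # Blur and darken a copy of the whole image to serve as the background.
    background = image.filter(ImageFilter.GaussianBlur(blur_radius))
    background = ImageEnhance.Brightness(background).enhance(darken)
    # Paste the untouched object region back on top.
    background.paste(image.crop(box), box[:2])
    # Crop to the object plus a small margin of (now blurred and darkened) context.
    left, upper, right, lower = box
    crop_box = (max(left - margin, 0), max(upper - margin, 0),
                min(right + margin, image.width), min(lower + margin, image.height))
    return background.crop(crop_box)


# Synthetic example: a red square "object" on a gray background.
img = Image.new("RGB", (200, 200), (200, 200, 200))
img.paste((255, 0, 0), (80, 80, 120, 120))
prompt_image = prepare_visual_prompt(img, box=(80, 80, 120, 120))
print(prompt_image.size)  # (80, 80)
```

The result can then be passed to the processor as the prompt image, exactly like an uncropped one.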
Using CLIPSeg with Hugging Face Transformers
Using Hugging Face Transformers, you can easily download and run a pre-trained CLIPSeg model on your images. Let’s start by installing transformers.
!pip install -q transformers
To download the model, simply instantiate it.
from transformers import CLIPSegProcessor, CLIPSegForImageSegmentation
processor = CLIPSegProcessor.from_pretrained("CIDAS/clipseg-rd64-refined")
model = CLIPSegForImageSegmentation.from_pretrained("CIDAS/clipseg-rd64-refined")
Now we can load an image to try out the segmentation. We’ll choose a picture of a delicious breakfast taken by Calum Lewis.
from PIL import Image
import requests
url = "https://unsplash.com/photos/8Nc_oQsc2qQ/download?ixid=MnwxMjA3fDB8MXxhbGx8fHx8fHx8fHwxNjcxMjAwNzI0&force=true&w=640"
image = Image.open(requests.get(url, stream=True).raw)
image
Text prompting
Let’s start by defining some text categories we want to segment.
prompts = ["cutlery", "pancakes", "blueberries", "orange juice"]
Now that we have our inputs, we can process them and input them to the model.
import torch
inputs = processor(text=prompts, images=[image] * len(prompts), padding="max_length", return_tensors="pt")
with torch.no_grad():
    outputs = model(**inputs)
preds = outputs.logits.unsqueeze(1)
Finally, let’s visualize the output.
import matplotlib.pyplot as plt
_, ax = plt.subplots(1, len(prompts) + 1, figsize=(3*(len(prompts) + 1), 4))
[a.axis('off') for a in ax.flatten()]
ax[0].imshow(image)
[ax[i+1].imshow(torch.sigmoid(preds[i][0])) for i in range(len(prompts))];
[ax[i+1].text(0, -15, prompt) for i, prompt in enumerate(prompts)];
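The plots above show soft probabilities. To get an actual mask per prompt, you need to pick a threshold; 0.5 below is an arbitrary default, not a value recommended by the CLIPSeg authors, and the toy logits are made up for illustration.

```python
import numpy as np


def logits_to_masks(logits: np.ndarray, threshold: float = 0.5) -> np.ndarray:
    """Turn raw logits of shape (num_prompts, H, W) into boolean masks."""
    probs = 1.0 / (1.0 + np.exp(-logits))  # sigmoid, like torch.sigmoid in the plotting code
    return probs > threshold


# Toy logits for two prompts on a 2 x 2 image.
toy_logits = np.array([[[-3.0, 2.0], [4.0, -1.0]],
                       [[0.5, -2.0], [-4.0, 3.0]]])
masks = logits_to_masks(toy_logits)
print(masks.mean(axis=(1, 2)))  # [0.5 0.5], i.e. the fraction of pixels covered by each prompt
```

With the real model output, `logits_to_masks(preds[:, 0].numpy())` should give one boolean mask per prompt, assuming `preds` has shape (num_prompts, 1, 352, 352) as in the code above. Lowering the threshold makes the masks larger and fuzzier; raising it keeps only the most confident pixels.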
Visual prompting
As mentioned before, we can also use images as the input prompts (i.e. in place of the category names). This can be especially useful if it’s not easy to describe the thing you want to segment. For this example, we’ll use a picture of a coffee cup taken by Daniel Hooper.
url = "https://unsplash.com/photos/Ki7sAc8gOGE/download?ixid=MnwxMjA3fDB8MXxzZWFyY2h8MTJ8fGNvZmZlJTIwdG8lMjBnb3xlbnwwfHx8fDE2NzExOTgzNDQ&force=true&w=640"
prompt = Image.open(requests.get(url, stream=True).raw)
prompt
We can now process the input image and prompt image and input them to the model.
encoded_image = processor(images=[image], return_tensors="pt")
encoded_prompt = processor(images=[prompt], return_tensors="pt")
with torch.no_grad():
    outputs = model(**encoded_image, conditional_pixel_values=encoded_prompt.pixel_values)
preds = outputs.logits.unsqueeze(1)
preds = torch.transpose(preds, 0, 1)
Then, we can visualize the results as before.
_, ax = plt.subplots(1, 2, figsize=(6, 4))
[a.axis('off') for a in ax.flatten()]
ax[0].imshow(image)
ax[1].imshow(torch.sigmoid(preds[0]))
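The `unsqueeze`/`transpose` step above suggests that, for a single input image, `outputs.logits` comes back as a 2-D (352, 352) tensor without a batch dimension; that behavior may differ between transformers versions. A small helper that only assumes the last two dimensions are height and width sidesteps the question. The name `as_mask_stack` is hypothetical.

```python
import numpy as np


def as_mask_stack(logits):
    """Reshape logits to (num_masks, H, W).

    Accepts (H, W), (N, H, W) or (N, 1, H, W). Works for NumPy arrays and
    torch tensors alike, since both support .shape and .reshape(-1, H, W).
    """
    height, width = logits.shape[-2:]
    return logits.reshape(-1, height, width)


print(as_mask_stack(np.zeros((352, 352))).shape)      # (1, 352, 352)
print(as_mask_stack(np.zeros((4, 1, 352, 352))).shape)  # (4, 352, 352)
```

With it, `preds = as_mask_stack(outputs.logits)` followed by `ax[1].imshow(torch.sigmoid(preds[0]))` should work whether or not the batch dimension was dropped.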
Let’s try one last time by using the visual prompting tips described in the paper, i.e. cropping the image and darkening the background.
url = "https://i.imgur.com/mRSORqz.jpg"
alternative_prompt = Image.open(requests.get(url, stream=True).raw)
alternative_prompt
encoded_alternative_prompt = processor(images=[alternative_prompt], return_tensors="pt")
with torch.no_grad():
    outputs = model(**encoded_image, conditional_pixel_values=encoded_alternative_prompt.pixel_values)
preds = outputs.logits.unsqueeze(1)
preds = torch.transpose(preds, 0, 1)
_, ax = plt.subplots(1, 2, figsize=(6, 4))
[a.axis('off') for a in ax.flatten()]
ax[0].imshow(image)
ax[1].imshow(torch.sigmoid(preds[0]))
In this case, the result is pretty much the same. This is probably because the coffee cup was already well separated from the background in the original image.
Using CLIPSeg to pre-label images on Segments.ai
As you can see, the results from CLIPSeg are a little fuzzy and very low-res. If we want to obtain better results, we can fine-tune a state-of-the-art segmentation model, as explained in our previous blog post. To fine-tune the model, we’ll need labeled data. In this section, we’ll show you how you can use CLIPSeg to create some rough segmentation masks and then refine them on Segments.ai, a labeling platform with smart labeling tools for image segmentation.
First, create an account at https://segments.ai/join and install the Segments Python SDK. Then you can initialize the Segments.ai Python client using an API key. This key can be found on the account page.
!pip install -q segments-ai
from segments import SegmentsClient
from getpass import getpass
api_key = getpass('Enter your API key: ')
segments_client = SegmentsClient(api_key)
Next, let’s load an image from a dataset using the Segments client. We’ll use the a2d2 self-driving dataset. You can also create your own dataset by following these instructions.
samples = segments_client.get_samples("admin-tobias/clipseg")
sample = samples[1]
image = Image.open(requests.get(sample.attributes.image.url, stream=True).raw)
image
We also need to get the category names from the dataset attributes.
dataset = segments_client.get_dataset("admin-tobias/clipseg")
category_names = [category.name for category in dataset.task_attributes.categories]
Now we can use CLIPSeg on the image as before. This time, we’ll also scale up the outputs so that they match the input image’s size.
from torch import nn
inputs = processor(text=category_names, images=[image] * len(category_names), padding="max_length", return_tensors="pt")
with torch.no_grad():
    outputs = model(**inputs)
preds = nn.functional.interpolate(
    outputs.logits.unsqueeze(1),
    size=(image.size[1], image.size[0]),
    mode="bilinear"
)
And we can visualize the results again.
len_cats = len(category_names)
_, ax = plt.subplots(1, len_cats + 1, figsize=(3*(len_cats + 1), 4))
[a.axis('off') for a in ax.flatten()]
ax[0].imshow(image)
[ax[i+1].imshow(torch.sigmoid(preds[i][0])) for i in range(len_cats)];
[ax[i+1].text(0, -15, category_name) for i, category_name in enumerate(category_names)];
Now we have to combine the predictions into a single segmented image. We’ll simply do this by taking the category with the highest sigmoid value for each pixel. We’ll also make sure that all values below a certain threshold don’t count.
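threshold = 0.1
# Flatten each category's probability map: (num_categories, H * W)
flat_preds = torch.sigmoid(preds.squeeze()).reshape((preds.shape[0], -1))
# Row 0 holds the threshold and acts as "background"; rows 1..N hold the category probabilities
flat_preds_with_threshold = torch.full((preds.shape[0] + 1, flat_preds.shape[-1]), threshold)
flat_preds_with_threshold[1:preds.shape[0]+1,:] = flat_preds
# For each pixel, pick the row with the highest value, then reshape back to (H, W)
inds = torch.topk(flat_preds_with_threshold, 1, dim=0).indices.reshape((preds.shape[-2], preds.shape[-1]))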
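The `topk` trick above, with the threshold inserted as an extra “background” row at index 0, can also be written as an argmax plus a mask. The NumPy sketch below is an equivalent formulation under one small assumption: a pixel whose best score is exactly equal to the threshold is treated as background, whereas `topk` may break that tie either way. The function name and toy values are illustrative.

```python
import numpy as np


def combine_masks(probs: np.ndarray, threshold: float = 0.1) -> np.ndarray:
    """Merge per-category probabilities (num_categories, H, W) into one label map.

    Each pixel gets 1 + the index of its most likely category, or 0 (background)
    when even the best category does not exceed the threshold.
    """
    best = probs.argmax(axis=0)
    return np.where(probs.max(axis=0) > threshold, best + 1, 0)


toy_probs = np.array([[[0.05, 0.6], [0.3, 0.2]],
                      [[0.08, 0.2], [0.7, 0.4]]])
print(combine_masks(toy_probs))
# [[0 1]
#  [2 2]]
```

Applied to the real predictions, something like `inds_np = combine_masks(torch.sigmoid(preds.squeeze(1)).numpy(), threshold)` should give the same label map as a NumPy array; `squeeze(1)` only drops the channel dimension, so it also works when there is a single category.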
Let’s quickly visualize the result.
plt.imshow(inds)
Lastly, we can upload the prediction to Segments.ai. To do that, we’ll first convert the bitmap to a PNG file, then we’ll upload this file to Segments, and finally we’ll add the label to the sample.
from segments.utils import bitmap2file
import numpy as np
inds_np = inds.numpy().astype(np.uint32)
unique_inds = np.unique(inds_np).tolist()
f = bitmap2file(inds_np, is_segmentation_bitmap=True)
asset = segments_client.upload_asset(f, "clipseg_prediction.png")
attributes = {
    'format_version': '0.1',
    'annotations': [{"id": i, "category_id": i} for i in unique_inds if i != 0],
    'segmentation_bitmap': { 'url': asset.url },
}
segments_client.add_label(sample.uuid, 'ground-truth', attributes)
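The attributes above set `category_id` equal to the bitmap value, which only lines up with the dataset if its category ids happen to be 1, 2, 3, … in the same order as `category_names`. If that’s not the case, you can map the values explicitly. The helper below is a hypothetical sketch; the list of ids could come from something like `[category.id for category in dataset.task_attributes.categories]`, assuming the category objects expose an `id` field.

```python
def build_annotations(unique_inds, category_ids):
    """Map each non-background bitmap value to the matching dataset category id.

    unique_inds:  values present in the bitmap (0 = background, i = i-th entry of category_names)
    category_ids: dataset category ids, listed in the same order as category_names
    """
    return [{"id": i, "category_id": category_ids[i - 1]} for i in unique_inds if i != 0]


print(build_annotations([0, 1, 3], [10, 11, 12]))
# [{'id': 1, 'category_id': 10}, {'id': 3, 'category_id': 12}]
```

The bitmap itself doesn’t need to change: its values are instance ids, and only the `annotations` list decides which category each instance belongs to.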
If you take a look at the uploaded prediction on Segments.ai, you can see that it’s not perfect. However, you can manually correct the biggest mistakes, and then use the corrected dataset to train a better model than CLIPSeg.
Conclusion
CLIPSeg is a zero-shot segmentation model that works with both text and image prompts. The model adds a decoder to CLIP and can segment almost anything. However, the output segmentation masks are still very low-res for now, so you’ll probably still want to fine-tune a different segmentation model if accuracy is important.
Note that more research on zero-shot segmentation is currently being conducted, so you can expect more models to be added in the near future. One example is GroupViT, which is already available in 🤗 Transformers. To stay up to date with the latest news in segmentation research, you can follow us on Twitter: @TobiasCornille, @NielsRogge, and @huggingface.
If you’re interested in learning how to fine-tune a state-of-the-art segmentation model, check out our previous blog post: https://huggingface.co/blog/fine-tune-segformer.
