Graph Classification with Transformers

Clémentine Fourrier

In the previous blog, we explored some of the theoretical aspects of machine learning on graphs. This one will explore how you can do graph classification using the Transformers library. (You can also follow along by downloading the demo notebook here!)

At the moment, the only graph transformer model available in Transformers is Microsoft’s Graphormer, so this is the one we’ll use here. We’re looking forward to seeing what other models people will use and integrate 🤗



Requirements

To follow this tutorial, you need to have installed datasets and transformers (version >= 4.27.2), which you can do with pip install -U datasets transformers.



Data

To use graph data, you can either start from your own datasets, or use those available on the Hub. We’ll focus on using already available ones, but feel free to add your datasets!



Loading

Loading a graph dataset from the Hub is very easy. Let’s load the ogbg-molhiv dataset (a baseline from the Open Graph Benchmark by Stanford), stored in the OGB repository:

from datasets import load_dataset


dataset = load_dataset("OGB/ogbg-molhiv")

dataset = dataset.shuffle(seed=0)

This dataset already has three splits, train, validation, and test, and all these splits contain our 5 columns of interest (edge_index, edge_attr, y, num_nodes, node_feat), which you can see by doing print(dataset).

If you have other graph libraries, you can use them to plot your graphs and further inspect the dataset. For instance, using NetworkX and matplotlib:

import networkx as nx
import matplotlib.pyplot as plt

# We want to plot the first graph of the train split
graph = dataset["train"][0]

edges = graph["edge_index"]
num_edges = len(edges[0])
num_nodes = graph["num_nodes"]

# Conversion to networkx format
G = nx.Graph()
G.add_nodes_from(range(num_nodes))
G.add_edges_from([(edges[0][i], edges[1][i]) for i in range(num_edges)])

# Plot
nx.draw(G)
plt.show()



Format

On the Hub, graph datasets are mostly stored as lists of graphs (using the jsonl format).
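In a jsonl file, each line is one standalone JSON object, so a graph dataset is simply one graph dictionary per line. The sketch below round-trips two made-up toy graphs through an in-memory buffer using Python’s standard json module; on the Hub you would normally let load_dataset do this parsing for you.

```python
import io
import json

# Two made-up toy graphs in the dictionary format described below
graphs = [
    {"edge_index": [[1, 1, 3], [2, 3, 1]], "num_nodes": 4, "y": [0],
     "node_feat": [[1], [0], [1], [1]], "edge_attr": [[0], [1], [1]]},
    {"edge_index": [[0], [1]], "num_nodes": 2, "y": [1]},
]

# Write: one JSON object per line
buffer = io.StringIO()
for g in graphs:
    buffer.write(json.dumps(g) + "\n")

# Read back: each non-empty line is parsed independently
buffer.seek(0)
loaded = [json.loads(line) for line in buffer if line.strip()]
print(len(loaded), loaded[0]["num_nodes"])  # 2 4
```

Because every line parses on its own, a reader can stream through a large jsonl file one graph at a time instead of loading it whole.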

A single graph is a dictionary, and here is the expected format for our graph classification datasets:

  • edge_index contains the indices of nodes in edges, stored as a list containing two parallel lists of edge indices.
    • Type: list of two lists of integers.
    • Example: a graph containing 4 nodes (0, 1, 2 and 3) and where connections are 1->2, 1->3 and 3->1 would have edge_index = [[1, 1, 3], [2, 3, 1]]. You might notice here that node 0 is not present, as it is not part of an edge per se. This is why the next attribute is important.
  • num_nodes indicates the total number of nodes available in the graph (by default, it is assumed that nodes are numbered sequentially).
    • Type: integer
    • Example: In our above example, num_nodes = 4.
  • y maps each graph to what we want to predict from it (be it a class, a property value, or several binary labels for different tasks).
    • Type: list of either integers (for multi-class classification), floats (for regression), or lists of ones and zeroes (for binary multi-task classification)
    • Example: We could predict the graph size (small = 0, medium = 1, big = 2). Here, y = [0].
  • node_feat contains the available features (if present) for each node of the graph, ordered by node index.
    • Type: list of lists of integers (Optional)
    • Example: Our above nodes could have, for example, types (like different atoms in a molecule). This would give node_feat = [[1], [0], [1], [1]].
  • edge_attr contains the available attributes (if present) for each edge of the graph, following the edge_index ordering.
    • Type: list of lists of integers (Optional)
    • Example: Our above edges could have, for example, types (like molecular bonds). This would give edge_attr = [[0], [1], [1]].
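Before uploading or training on your own data, it can help to check each graph against the format above. The function below, check_graph, is a hypothetical helper written for illustration (it is not part of datasets or transformers); it checks the constraints listed above on the 4-node example graph.

```python
def check_graph(graph: dict) -> None:
    """Raise ValueError if `graph` breaks the format described above.

    Hypothetical helper for illustration, not a library function.
    """
    num_nodes = graph["num_nodes"]
    src, dst = graph["edge_index"]  # two parallel lists of node indices
    if len(src) != len(dst):
        raise ValueError("edge_index lists must have the same length")
    for node in src + dst:
        if not 0 <= node < num_nodes:
            raise ValueError(f"node {node} out of range for num_nodes={num_nodes}")
    if not isinstance(graph["y"], list):
        raise ValueError("y must be a list")
    # Optional features: one row per node (node_feat) or per edge (edge_attr)
    if "node_feat" in graph and len(graph["node_feat"]) != num_nodes:
        raise ValueError("node_feat needs one entry per node")
    if "edge_attr" in graph and len(graph["edge_attr"]) != len(src):
        raise ValueError("edge_attr needs one entry per edge, in edge_index order")


example = {
    "edge_index": [[1, 1, 3], [2, 3, 1]],
    "num_nodes": 4,
    "y": [0],
    "node_feat": [[1], [0], [1], [1]],
    "edge_attr": [[0], [1], [1]],
}
check_graph(example)  # passes silently
```

Note that node 0 appears in no edge but is still valid, since the range check relies on num_nodes rather than on the nodes seen in edge_index.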



Preprocessing

Graph transformer frameworks usually apply specific preprocessing to their datasets to generate additional features and properties which help the underlying learning task (classification in our case).
Here, we use Graphormer’s default preprocessing, which generates in/out degree information, the shortest-path matrices between nodes, and other properties of interest for the model.

from transformers.models.graphormer.collating_graphormer import preprocess_item, GraphormerDataCollator

dataset_processed = dataset.map(preprocess_item, batched=False)
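To get a feel for what such structural features look like, here is a plain-Python sketch computing in/out degrees and a hop-count shortest-path matrix for the 4-node example graph from the Format section. It is a simplified illustration under our own conventions (edges treated as directed, -1 for unreachable pairs, BFS for distances), not the code that preprocess_item actually runs.

```python
from collections import deque


def degrees(edge_index, num_nodes):
    """In- and out-degree of every node, counted from a directed edge_index."""
    in_deg = [0] * num_nodes
    out_deg = [0] * num_nodes
    for src, dst in zip(*edge_index):
        out_deg[src] += 1
        in_deg[dst] += 1
    return in_deg, out_deg


def shortest_path_matrix(edge_index, num_nodes):
    """Hop distance between every pair of nodes via BFS (-1 if unreachable)."""
    neighbours = [[] for _ in range(num_nodes)]
    for src, dst in zip(*edge_index):
        neighbours[src].append(dst)
    dist = [[-1] * num_nodes for _ in range(num_nodes)]
    for start in range(num_nodes):
        dist[start][start] = 0
        queue = deque([start])
        while queue:
            node = queue.popleft()
            for nxt in neighbours[node]:
                if dist[start][nxt] == -1:  # first visit is the shortest
                    dist[start][nxt] = dist[start][node] + 1
                    queue.append(nxt)
    return dist


edge_index = [[1, 1, 3], [2, 3, 1]]
print(degrees(edge_index, 4))
# ([0, 1, 1, 1], [0, 2, 0, 1])
print(shortest_path_matrix(edge_index, 4))
# [[0, -1, -1, -1], [-1, 0, 1, 1], [-1, -1, 0, -1], [-1, 1, 2, 0]]
```

Computing all-pairs distances grows quickly with graph size, which is one reason storing these matrices for large graphs can become expensive.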

It is also possible to apply this preprocessing on the fly, in the DataCollator’s parameters (by setting on_the_fly_processing to True): not all datasets are as small as ogbg-molhiv, and for large graphs, it might be too costly to store all the preprocessed data beforehand.



Model



Loading

Here, we load an existing pretrained model/checkpoint and fine-tune it on our downstream task, which is a binary classification task (hence num_classes = 2). We could also fine-tune our model on regression tasks (num_classes = 1) or on multi-task classification.

from transformers import GraphormerForGraphClassification

model = GraphormerForGraphClassification.from_pretrained(
    "clefourrier/pcqm4mv2_graphormer_base",
    num_classes=2, 
    ignore_mismatched_sizes=True,
)

Let’s look at this in more detail.

Calling the from_pretrained method on our model downloads and caches the weights for us. As the number of classes (for prediction) is dataset dependent, we pass the new num_classes as well as ignore_mismatched_sizes alongside the model checkpoint. This makes sure a custom classification head is created, specific to our task, and hence likely different from the original decoder head.

It is also possible to create a new randomly initialized model to train from scratch, either following the known parameters of a given checkpoint or by manually choosing them.



Training or fine-tuning

To train our model easily, we’ll use a Trainer. To instantiate it, we will need to define the training configuration and the evaluation metric. The most important is the TrainingArguments, which is a class that contains all the attributes to customize the training. It requires a folder name, which will be used to save the checkpoints of the model.

from transformers import TrainingArguments, Trainer

training_args = TrainingArguments(
    "graph-classification",
    logging_dir="graph-classification",
    per_device_train_batch_size=64,
    per_device_eval_batch_size=64,
    auto_find_batch_size=True, 
    gradient_accumulation_steps=10,
    dataloader_num_workers=4, 
    num_train_epochs=20,
    evaluation_strategy="epoch",
    logging_strategy="epoch",
    push_to_hub=False,
)

For graph datasets, it is particularly important to play around with batch sizes and gradient accumulation steps to train on enough samples while avoiding out-of-memory errors.
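This trade-off comes down to simple arithmetic: gradients are accumulated over several forward/backward passes before each optimizer step, so the number of samples per update is the per-device batch size times the accumulation steps times the number of devices. The helper below is a hypothetical function (not part of transformers) that makes this explicit. Keep in mind that with auto_find_batch_size=True, the Trainer may lower the per-device batch size after an out-of-memory error, which changes the result.

```python
def effective_batch_size(per_device_batch_size: int,
                         gradient_accumulation_steps: int,
                         num_devices: int = 1) -> int:
    """Number of samples contributing to each optimizer update."""
    return per_device_batch_size * gradient_accumulation_steps * num_devices


# Values from the TrainingArguments above, on a single device
print(effective_batch_size(64, 10))  # 640
# Halving the per-device batch (to save memory) while doubling accumulation
print(effective_batch_size(32, 20))  # 640
```

Trading per-device batch size for accumulation steps keeps the update size constant while lowering peak memory, at the cost of more forward/backward passes per step.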

The last argument push_to_hub allows the Trainer to push the model to the Hub regularly during training, at each saving step.

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=dataset_processed["train"],
    eval_dataset=dataset_processed["validation"],
    data_collator=GraphormerDataCollator(),
)

In the Trainer for graph classification, it is important to pass the specific data collator for the given graph dataset, which will convert individual graphs to batches for training.
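Since graphs in a batch have different numbers of nodes, a collator has to pad them to a common size and record which positions hold real nodes. The function below, pad_node_features, is a heavily simplified, hypothetical sketch of that idea for node features and labels only. GraphormerDataCollator does considerably more (it also pads the structural features computed during preprocessing and returns tensors), so treat this as an illustration of the principle rather than its actual implementation.

```python
def pad_node_features(graphs, pad_value=0):
    """Pad variable-size graphs to a common number of nodes.

    Returns padded node features (batch x max_nodes x feat_dim, as nested
    lists), a mask marking real nodes, and the list of labels.
    Hypothetical helper for illustration only.
    """
    max_nodes = max(g["num_nodes"] for g in graphs)
    feat_dim = len(graphs[0]["node_feat"][0])
    features, mask = [], []
    for g in graphs:
        n = g["num_nodes"]
        padding = [[pad_value] * feat_dim for _ in range(max_nodes - n)]
        features.append(g["node_feat"] + padding)
        mask.append([True] * n + [False] * (max_nodes - n))
    labels = [g["y"] for g in graphs]
    return features, mask, labels


batch = [
    {"num_nodes": 4, "node_feat": [[1], [0], [1], [1]], "y": [0]},
    {"num_nodes": 2, "node_feat": [[0], [1]], "y": [1]},
]
features, mask, labels = pad_node_features(batch)
print(mask)  # [[True, True, True, True], [True, True, False, False]]
```

The mask lets the model ignore padded positions (for example in attention), so padding does not change what the model sees for each individual graph.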

train_results = trainer.train()
trainer.push_to_hub()

Once the model is trained, it can be saved to the Hub with all the associated training artefacts using push_to_hub.

As this model is quite big, it takes about a day to train/fine-tune for 20 epochs on CPU (Intel Core i7). To go faster, you could use powerful GPUs and parallelization instead, by launching the code either in a Colab notebook or directly on the cluster of your choice.



Ending note

Now that you know how to use transformers to train a graph classification model, we hope you will try to share your favorite graph transformer checkpoints, models, and datasets on the Hub for the rest of the community to use!


