Graph Neural Networks Part 3: How GraphSAGE Handles Changing Graph Structure


In the previous parts of this series, we looked at Graph Convolutional Networks (GCNs) and Graph Attention Networks (GATs). Both architectures work well, but they also have some limitations! A big one is that for large graphs, calculating the node representations with GCNs and GATs becomes v-e-r-y slow. Another limitation is that if the graph structure changes, GCNs and GATs cannot generalize. So if nodes are added to the graph, a GCN or GAT cannot make predictions for them. Luckily, these issues can be solved!

In this post, I’ll explain GraphSAGE and how it solves common problems of GCNs and GATs. We’ll train GraphSAGE and use it for graph predictions to compare its performance with GCNs and GATs.

New to GNNs? You can start with post 1 about GCNs (which also contains the initial setup for running the code samples), and post 2 about GATs.


Two Key Problems with GCNs and GATs

I briefly touched on this in the introduction, but let’s dive a bit deeper. What are the problems with the previous GNN models?

Problem 1. They don’t generalize

GCNs and GATs struggle with generalizing to unseen graphs. The graph structure needs to be the same as in the training data. This is known as transductive learning, where the model trains and makes predictions on the same fixed graph. It is actually overfitting to specific graph topologies. In reality, graphs will change: nodes and edges can be added or removed, and this happens often in real-world scenarios. We want our GNNs to be capable of learning patterns that generalize to unseen nodes, or to entirely new graphs (this is called inductive learning).

Problem 2. They have scalability issues

Training GCNs and GATs on large-scale graphs is computationally expensive. GCNs require repeated neighbor aggregation over the full neighborhood, and the number of nodes involved grows exponentially with the number of layers, while GATs add (multi-head) attention mechanisms that scale poorly as the number of nodes and edges increases.
In large production recommendation systems that have graphs with hundreds of thousands of users and products, GCNs and GATs are impractical and slow.

Let’s take a look at GraphSAGE to fix these issues.

GraphSAGE (SAmple and aggreGatE)

GraphSAGE makes training much faster and more scalable. It does this by sampling only a subset of each node’s neighbors. For very large graphs it’s computationally infeasible to process all neighbors of a node (unless you have unlimited time, which none of us do…), as traditional GCNs do. Another important step of GraphSAGE is aggregating the features of those sampled neighbors.
We’ll walk through all the steps of GraphSAGE below.

1. Sampling Neighbors

With tabular data, sampling is easy. It’s something you do in every common machine learning project when creating train, test, and validation sets. With graphs, you cannot simply select random nodes. This can lead to disconnected graphs, nodes without neighbors, etcetera:

Randomly selecting nodes, but some are disconnected. Image by author.

What you can do with graphs is select a random fixed-size subset of neighbors. For example, in a social network, you can sample 3 friends for each user (instead of all friends):

Randomly selecting three rows in the table, all neighbors selected in the GCN, three neighbors selected in GraphSAGE. Image by author.
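
To make the sampling step concrete, here is a minimal sketch in plain Python. The toy friendship graph, the helper name sample_neighbors, and the choice to keep all neighbors when a node has fewer than k are assumptions for illustration, not the PyG implementation:

```python
import random


def sample_neighbors(adjacency, node, k, rng):
    """Return at most k neighbors of `node`, drawn uniformly without replacement."""
    neighbors = adjacency[node]
    if len(neighbors) <= k:
        # Fewer than k neighbors: keep them all (an assumption of this sketch).
        return list(neighbors)
    return rng.sample(neighbors, k)


# Hypothetical toy social network: user -> list of friends.
friends = {
    "alice": ["bob", "carol", "dave", "erin", "frank"],
    "bob": ["alice", "carol"],
    "carol": ["alice", "bob"],
    "dave": ["alice"],
    "erin": ["alice"],
    "frank": ["alice"],
}

rng = random.Random(0)  # fixed seed so the sketch is repeatable
sampled = {user: sample_neighbors(friends, user, 3, rng) for user in friends}

for user, chosen in sampled.items():
    print(user, "->", chosen)
```

Every user ends up with at most three sampled friends, so the amount of work per node no longer depends on how many friends that user has.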

2. Aggregate Information

After the neighbor selection from the previous step, GraphSAGE combines their features into a single representation. There are multiple ways to do this (multiple aggregators). The most common types, and the ones explained in the paper, are mean aggregation, LSTM aggregation, and pooling aggregation.

With mean aggregation, the average is computed over all sampled neighbors’ features (very simple and often effective). In a formula:
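
Written in the notation commonly used for GraphSAGE (the symbols are assumed here, not taken from this post), mean aggregation for a node v at layer k can be expressed as:

```latex
h_{\mathcal{N}(v)}^{k} = \frac{1}{|\mathcal{N}(v)|} \sum_{u \in \mathcal{N}(v)} h_{u}^{k-1}
```

Here \mathcal{N}(v) is the set of sampled neighbors of v, h_u^{k-1} is the representation of neighbor u from the previous layer (the input features when k = 1), and h_{\mathcal{N}(v)}^{k} is the aggregated neighborhood vector.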

LSTM aggregation uses an LSTM (a type of neural network) to process neighbor features sequentially. It can capture more complex relationships and is more expressive than mean aggregation.

The third type, pooling aggregation, applies a non-linear function to extract key features (think of max-pooling in a neural network, where you also take the maximum of a set of values).
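
As a rough illustration of the difference between mean and pooling aggregation, the sketch below applies both to three hypothetical neighbor feature vectors. It is deliberately simplified: the pooling aggregator in the paper first passes each neighbor through a learned layer before taking the maximum, which is left out here, and the LSTM aggregator is not shown at all.

```python
def mean_aggregate(neighbor_features):
    """Element-wise average of the neighbors' feature vectors."""
    count = len(neighbor_features)
    return [sum(values) / count for values in zip(*neighbor_features)]


def max_pool_aggregate(neighbor_features):
    """Element-wise maximum of the neighbors' feature vectors."""
    return [max(values) for values in zip(*neighbor_features)]


# Hypothetical features of three sampled neighbors (one row per neighbor).
sampled_neighbor_features = [
    [1.0, 0.0, 2.0],
    [3.0, 4.0, 0.0],
    [2.0, 2.0, 4.0],
]

mean_vector = mean_aggregate(sampled_neighbor_features)
max_vector = max_pool_aggregate(sampled_neighbor_features)

print(mean_vector)  # [2.0, 2.0, 2.0]
print(max_vector)   # [3.0, 4.0, 4.0]
```

Both functions return one vector of the same length as a single neighbor’s features, no matter how many neighbors were sampled. The mean smooths the values, while the max keeps the strongest signal per feature.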

3. Update Node Representation

After sampling and aggregation, the node combines its previous features with the aggregated neighbor features. Nodes will learn from their neighbors but also keep their own identity, just like we saw before with GCNs and GATs. Information can flow across the graph effectively.

This is the formula for this step:
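
In the notation commonly used for GraphSAGE (again, the symbols are assumed here rather than taken from this post), the update for node v at layer k can be written as:

```latex
h_{v}^{k} = \sigma\!\left( W^{k} \cdot \mathrm{CONCAT}\!\left( h_{v}^{k-1},\; h_{\mathcal{N}(v)}^{k} \right) \right),
\qquad
h_{v}^{k} \leftarrow \frac{h_{v}^{k}}{\lVert h_{v}^{k} \rVert_{2}}
```

Here h_v^{k-1} is the node’s own representation from the previous layer, h_{\mathcal{N}(v)}^{k} is the aggregated neighborhood vector from step 2, W^{k} is the weight matrix of layer k, and \sigma is a non-linearity such as ReLU. The second expression is the optional normalization.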

The aggregation of step 2 is computed over all sampled neighbors, and then the feature representation of the node itself is concatenated to it. This vector is multiplied by the weight matrix and passed through a non-linearity (for example ReLU). As a final step, normalization can be applied.
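
The update step can be sketched in a few lines of plain Python. The function name update_node, the feature values, and the weight matrix are made up for illustration; in a real model the weights are learned during training.

```python
import math


def update_node(self_features, aggregated, weight):
    """Concatenate, apply a linear map, apply ReLU, then L2-normalize."""
    concatenated = self_features + aggregated  # list concatenation
    linear = [sum(w * x for w, x in zip(row, concatenated)) for row in weight]
    activated = [max(0.0, value) for value in linear]  # ReLU
    norm = math.sqrt(sum(value * value for value in activated))
    if norm == 0.0:
        return activated  # avoid dividing by zero for an all-zero vector
    return [value / norm for value in activated]


self_features = [1.0, 2.0]   # the node's own representation
aggregated = [2.0, 2.0]      # output of the aggregation step
weight = [                   # hypothetical 3 x 4 weight matrix
    [1.0, 0.0, 1.0, 0.0],    # 1 + 2 = 3
    [0.0, 1.0, 0.0, 1.0],    # 2 + 2 = 4
    [-1.0, 0.0, 0.0, 0.0],   # -1, which ReLU maps to 0
]

new_representation = update_node(self_features, aggregated, weight)
print(new_representation)  # approximately [0.6, 0.8, 0.0]
```

The function only needs a node’s own features, its neighbors’ aggregated features, and the shared weight matrix. That is why the same trained weights can be applied to a node that was not part of the graph during training.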

4. Repeat for Multiple Layers

The first three steps can be repeated multiple times. When this happens, information can flow from distant neighbors. In the image below you see a node with three neighbors selected in the first layer (direct neighbors), and two neighbors selected in the second layer (neighbors of neighbors).

Selected node with selected neighbors, three in the first layer, two in the second layer. Interesting to note is that one of the neighbors of the nodes in the first step is the selected node itself, so that one can also be chosen when two neighbors are selected in the second step (just a bit harder to visualize). Image by author.
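
A quick back-of-the-envelope sketch shows why fixed fan-outs keep the work per node bounded. It assumes that the second-layer fan-out applies to each first-layer neighbor, and the helper name max_sampled_nodes is made up for this example:

```python
def max_sampled_nodes(fanouts):
    """Upper bound on the nodes in the sampled computation tree of one target node."""
    total = 1     # the target node itself
    frontier = 1  # nodes whose neighbors are sampled at the current hop
    for fanout in fanouts:
        frontier *= fanout
        total += frontier
    return total


figure_example = max_sampled_nodes([3, 2])    # 1 + 3 + 6
script_setting = max_sampled_nodes([10, 10])  # 1 + 10 + 100

print(figure_example)  # 10
print(script_setting)  # 111
```

These are upper bounds: the same node can be sampled more than once (as the caption above notes, the target node can show up again in the second hop), and nodes with fewer neighbors than the fan-out contribute less. The key point is that the bound depends only on the fan-outs, not on the size of the graph.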

To summarize, the key strengths of GraphSAGE are its scalability (sampling makes it efficient for huge graphs); its flexibility, since you can use it for inductive learning (it works well when predicting on unseen nodes and graphs); its aggregation, which helps with generalization because it smooths out noisy features; and its multiple layers, which allow the model to learn from far-away nodes.

Cool! And best of all, GraphSAGE is implemented in PyG, so we can use it easily in PyTorch.

Predicting with GraphSAGE

In the previous posts, we implemented an MLP, GCN, and GAT on the Cora dataset (CC BY-SA). To refresh your memory, Cora is a dataset of scientific publications where you have to predict the subject of each paper, with seven classes in total. This dataset is relatively small, so it might not be the best set for testing GraphSAGE. We’ll do it anyway, just to be able to compare. Let’s see how well GraphSAGE performs.

Interesting parts of the code that I’d like to highlight, related to GraphSAGE:

  • The NeighborLoader that selects the neighbors for each layer:
from torch_geometric.loader import NeighborLoader

# 10 neighbors sampled in the first layer, 10 in the second layer
num_neighbors = [10, 10]

# sample data from the train set
train_loader = NeighborLoader(
    data,
    num_neighbors=num_neighbors,
    batch_size=batch_size,
    input_nodes=data.train_mask,
)
  • The aggregation type is set in the SAGEConv layer. The default is mean; you can change this to max or lstm:
from torch_geometric.nn import SAGEConv

SAGEConv(in_c, out_c, aggr='mean')
  • Another important difference is that GraphSAGE is trained in mini-batches, while GCN and GAT are trained on the full dataset. This touches the essence of GraphSAGE: because neighbor sampling makes it possible to train in mini-batches, we don’t need the full graph anymore. GCNs and GATs do need the whole graph for proper feature propagation and calculation of attention scores, so that’s why we train GCNs and GATs on the full graph.
  • The rest of the code is similar to before, except that we now have one class where the different models are instantiated based on the model_type (GCN, GAT, or SAGE). This makes it easy to compare or make small changes.

This is the complete script. We train for 100 epochs and repeat the experiment 10 times to calculate the average accuracy and standard deviation for each model:

import torch
import torch.nn.functional as F
from torch_geometric.nn import SAGEConv, GCNConv, GATConv
from torch_geometric.datasets import Planetoid
from torch_geometric.loader import NeighborLoader

# dataset_name can be 'Cora', 'CiteSeer', or 'PubMed'
dataset_name = 'Cora'
hidden_dim = 64
num_layers = 2
num_neighbors = [10, 10]
batch_size = 128
num_epochs = 100
model_types = ['GCN', 'GAT', 'SAGE']

dataset = Planetoid(root='data', name=dataset_name)
data = dataset[0]
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
data = data.to(device)

class GNN(torch.nn.Module):
    def __init__(self, in_channels, hidden_channels, out_channels, num_layers, model_type='SAGE', gat_heads=8):
        super().__init__()
        self.convs = torch.nn.ModuleList()
        self.model_type = model_type
        self.gat_heads = gat_heads

        def get_conv(in_c, out_c, is_final=False):
            if model_type == 'GCN':
                return GCNConv(in_c, out_c)
            elif model_type == 'GAT':
                # hidden layers concatenate the attention heads;
                # the final layer uses a single head without concatenation
                heads = 1 if is_final else gat_heads
                concat = not is_final
                return GATConv(in_c, out_c, heads=heads, concat=concat)
            else:
                return SAGEConv(in_c, out_c, aggr='mean')

        if model_type == 'GAT':
            self.convs.append(get_conv(in_channels, hidden_channels))
            in_dim = hidden_channels * gat_heads
            for _ in range(num_layers - 2):
                self.convs.append(get_conv(in_dim, hidden_channels))
                in_dim = hidden_channels * gat_heads
            self.convs.append(get_conv(in_dim, out_channels, is_final=True))
        else:
            self.convs.append(get_conv(in_channels, hidden_channels))
            for _ in range(num_layers - 2):
                self.convs.append(get_conv(hidden_channels, hidden_channels))
            self.convs.append(get_conv(hidden_channels, out_channels))

    def forward(self, x, edge_index):
        for conv in self.convs[:-1]:
            x = F.relu(conv(x, edge_index))
        x = self.convs[-1](x, edge_index)
        return x

@torch.no_grad()
def test(model):
    model.eval()
    out = model(data.x, data.edge_index)
    pred = out.argmax(dim=1)
    accs = []
    for mask in [data.train_mask, data.val_mask, data.test_mask]:
        accs.append(int((pred[mask] == data.y[mask]).sum()) / int(mask.sum()))
    return accs

results = {}

for model_type in model_types:
    print(f'Training {model_type}')
    results[model_type] = []

    for i in range(10):
        model = GNN(dataset.num_features, hidden_dim, dataset.num_classes, num_layers, model_type, gat_heads=8).to(device)
        optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4)

        if model_type == 'SAGE':
            train_loader = NeighborLoader(
                data,
                num_neighbors=num_neighbors,
                batch_size=batch_size,
                input_nodes=data.train_mask,
            )

            def train():
                model.train()
                total_loss = 0
                for batch in train_loader:
                    batch = batch.to(device)
                    optimizer.zero_grad()
                    out = model(batch.x, batch.edge_index)
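                    # The first batch.batch_size rows are the seed (training) nodes;
                    # the remaining rows are sampled neighbors, which can be
                    # validation or test nodes. Compute the loss on the seeds only.
                    seed = batch.batch_size
                    loss = F.cross_entropy(out[:seed], batch.y[:seed])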
                    loss.backward()
                    optimizer.step()
                    total_loss += loss.item()
                return total_loss / len(train_loader)

        else:
            def train():
                model.train()
                optimizer.zero_grad()
                out = model(data.x, data.edge_index)
                loss = F.cross_entropy(out[data.train_mask], data.y[data.train_mask])
                loss.backward()
                optimizer.step()
                return loss.item()

        best_val_acc = 0
        best_test_acc = 0
        for epoch in range(1, num_epochs + 1):
            loss = train()
            train_acc, val_acc, test_acc = test(model)
            if val_acc > best_val_acc:
                best_val_acc = val_acc
                best_test_acc = test_acc
            if epoch % 10 == 0:
                print(f'Epoch {epoch:02d} | Loss: {loss:.4f} | Train: {train_acc:.4f} | Val: {val_acc:.4f} | Test: {test_acc:.4f}')

        results[model_type].append([best_val_acc, best_test_acc])

for model_name, model_results in results.items():
    model_results = torch.tensor(model_results)
    print(f'{model_name} Val Accuracy: {model_results[:, 0].mean():.3f} ± {model_results[:, 0].std():.3f}')
    print(f'{model_name} Test Accuracy: {model_results[:, 1].mean():.3f} ± {model_results[:, 1].std():.3f}')

And here are the results:

GCN Val Accuracy: 0.791 ± 0.007
GCN Test Accuracy: 0.806 ± 0.006
GAT Val Accuracy: 0.790 ± 0.007
GAT Test Accuracy: 0.800 ± 0.004
SAGE Val Accuracy: 0.899 ± 0.005
SAGE Test Accuracy: 0.907 ± 0.004

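A big improvement on paper! Even on this small dataset, GraphSAGE scores well above GAT and GCN. I repeated this test for the CiteSeer and PubMed datasets, and GraphSAGE always came out best. One caveat when reading these numbers: a NeighborLoader batch contains the seed training nodes plus their sampled neighbors, and those neighbors can be validation or test nodes. If the loss is computed over every node in the batch instead of only the first batch.batch_size seed nodes, labels from outside the training set leak into training and the scores are inflated, so a gap this large over GCN and GAT deserves a second look.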

What I want to note here is that GCN is still very useful; it’s one of the most effective baselines (if the graph structure allows it). Also, I didn’t do much hyperparameter tuning, but just went with some standard values (like 8 heads for the GAT multi-head attention). In larger, more complex and noisier graphs, the advantages of GraphSAGE become clearer than in this example. We didn’t do any performance testing, because for these small graphs GraphSAGE isn’t faster than GCN.


Conclusion

GraphSAGE brings some very nice improvements and advantages compared to GATs and GCNs. Inductive learning is possible, so GraphSAGE can handle changing graph structures quite well. And although we didn’t test it in this post, neighbor sampling makes it possible to create feature representations for larger graphs with good performance.

