Introduction to Graph Neural Networks
Graph Neural Networks (GNNs) are a class of deep learning methods designed to perform inference on data structured as graphs.
Why Graphs?
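Many real-world problems involve entities (nodes) linked by relationships (edges), which a graph represents directly:
- Social networks (users connected by friendships or follows)
- Molecular structures (atoms connected by chemical bonds)
- Citation networks (papers connected by citations)
- Knowledge graphs (entities connected by typed relations)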
Message Passing Framework
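Most GNNs follow the message passing paradigm: at each layer, every node aggregates the features of its neighbors and uses the result to update its own feature:

$$h_v^{(k+1)} = \mathrm{UPDATE}^{(k)}\left(h_v^{(k)},\ \mathrm{AGGREGATE}^{(k)}\left(\{ h_u^{(k)} : u \in \mathcal{N}(v) \}\right)\right)$$

Where:
- $h_v^{(k)}$ is the feature of node $v$ at layer $k$
- $\mathcal{N}(v)$ is the neighborhood of node $v$
- $\mathrm{AGGREGATE}^{(k)}$ is a permutation-invariant function (e.g., sum, mean, or max) and $\mathrm{UPDATE}^{(k)}$ is a learnable function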
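A single round of message passing can be sketched in plain Python. This is an illustrative sketch, not any specific published GNN: it assumes the aggregation step is the element-wise mean of neighbor features and the update step simply adds that mean to the node's own feature, with no learnable weights. The names `message_passing_step`, `Features`, and `Neighbors` are hypothetical.

```python
from typing import Dict, List

Features = Dict[int, List[float]]   # node id -> feature vector h_v
Neighbors = Dict[int, List[int]]    # node id -> list of neighbor ids N(v)


def message_passing_step(h: Features, neighbors: Neighbors) -> Features:
    """One round of message passing (illustrative, no learnable weights)."""
    h_next: Features = {}
    for v, h_v in h.items():
        nbrs = neighbors.get(v, [])
        dim = len(h_v)
        if nbrs:
            # Aggregate: element-wise mean of h_u over u in N(v)
            agg = [sum(h[u][i] for u in nbrs) / len(nbrs) for i in range(dim)]
        else:
            # An isolated node receives no messages, so it keeps its feature
            agg = [0.0] * dim
        # Update: add the aggregated message to the node's own feature
        h_next[v] = [h_v[i] + agg[i] for i in range(dim)]
    return h_next


# Path graph 0 - 1 - 2 with scalar features
h0 = {0: [1.0], 1: [2.0], 2: [3.0]}
adj = {0: [1], 1: [0, 2], 2: [1]}
print(message_passing_step(h0, adj))  # {0: [3.0], 1: [4.0], 2: [5.0]}
```

Every node reads from the previous layer's features `h` rather than from `h_next`, so all nodes update synchronously, matching the layer-by-layer indexing from $k$ to $k+1$.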
Graph Convolutional Network (GCN)
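The GCN layer can be written as:

$$H^{(k+1)} = \sigma\left(\tilde{D}^{-1/2} \tilde{A} \tilde{D}^{-1/2} H^{(k)} W^{(k)}\right)$$

Where:
- $H^{(k)}$ is the matrix of node features at layer $k$ (row $v$ is $h_v^{(k)}$)
- $\tilde{A} = A + I$ is the adjacency matrix with added self-loops
- $\tilde{D}$ is the diagonal degree matrix of $\tilde{A}$, with $\tilde{D}_{vv} = \sum_u \tilde{A}_{vu}$
- $W^{(k)}$ is a learnable weight matrix and $\sigma$ is a nonlinearity such as ReLU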
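The GCN's symmetric normalization (add self-loops, then scale each entry by $1/\sqrt{\tilde{d}_u \tilde{d}_v}$ before multiplying by features and weights) can be checked numerically with a small pure-Python sketch. It assumes a dense, symmetric adjacency matrix and uses ReLU as the nonlinearity; `matmul`, `normalized_adjacency`, and `gcn_layer` are hypothetical helpers written for illustration. By default, PyTorch Geometric's `GCNConv` performs this normalization internally on sparse `edge_index` input, so you would not write it by hand in practice.

```python
import math
from typing import List

Matrix = List[List[float]]


def matmul(a: Matrix, b: Matrix) -> Matrix:
    """Dense matrix product a @ b."""
    return [
        [sum(a[i][t] * b[t][j] for t in range(len(b))) for j in range(len(b[0]))]
        for i in range(len(a))
    ]


def normalized_adjacency(adj: Matrix) -> Matrix:
    """Return D~^{-1/2} (A + I) D~^{-1/2} for a dense, symmetric adjacency matrix."""
    n = len(adj)
    # A~ = A + I: self-loops let each node keep part of its own feature
    a_tilde = [[adj[i][j] + (1.0 if i == j else 0.0) for j in range(n)] for i in range(n)]
    # D~_ii = sum_j A~_ij (always >= 1 thanks to the self-loop, so no division by zero)
    deg = [sum(row) for row in a_tilde]
    return [[a_tilde[i][j] / math.sqrt(deg[i] * deg[j]) for j in range(n)] for i in range(n)]


def gcn_layer(adj: Matrix, h: Matrix, w: Matrix) -> Matrix:
    """One GCN layer with ReLU as the nonlinearity sigma."""
    z = matmul(matmul(normalized_adjacency(adj), h), w)
    return [[max(0.0, x) for x in row] for row in z]


# Two nodes joined by one edge, 1-dimensional features, identity weight
print(gcn_layer([[0.0, 1.0], [1.0, 0.0]], [[1.0], [3.0]], [[1.0]]))  # [[2.0], [2.0]]
```

In the two-node example both nodes have degree 2 after adding self-loops, so every normalized entry is $1/2$ and each node's output is the average of the two input features.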
Code Example
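```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GCNConv


class GCN(nn.Module):
    """Two-layer GCN: in_channels -> hidden_channels -> out_channels."""

    def __init__(self, in_channels, hidden_channels, out_channels):
        super().__init__()
        self.conv1 = GCNConv(in_channels, hidden_channels)
        self.conv2 = GCNConv(hidden_channels, out_channels)

    def forward(self, x, edge_index):
        # x: [num_nodes, in_channels] node features
        # edge_index: [2, num_edges] tensor of (source, target) node indices
        x = self.conv1(x, edge_index)
        x = F.relu(x)
        # Dropout is only active in training mode (model.train())
        x = F.dropout(x, p=0.5, training=self.training)
        x = self.conv2(x, edge_index)
        # Raw per-node scores (logits); apply softmax or cross-entropy outside the model
        return x
```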
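The `edge_index` argument passed to `forward` stores the graph in coordinate (COO) format: a two-row array of source and target node indices, with each undirected edge listed once per direction. The hypothetical helper below builds that layout from a plain edge list in pure Python; with PyTorch Geometric you would then convert the result with `torch.tensor(..., dtype=torch.long)`.

```python
from typing import List, Tuple


def to_edge_index(edges: List[Tuple[int, int]]) -> List[List[int]]:
    """Build a 2 x (2 * num_edges) COO edge list from undirected edges.

    Row 0 holds source nodes and row 1 holds target nodes. Each undirected
    edge is stored in both directions so messages can flow both ways.
    Assumes the input contains no self-loops or duplicate edges.
    """
    sources: List[int] = []
    targets: List[int] = []
    for u, v in edges:
        sources += [u, v]
        targets += [v, u]
    return [sources, targets]


print(to_edge_index([(0, 1), (1, 2)]))  # [[0, 1, 1, 2], [1, 0, 2, 1]]
```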
Applications
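- Node Classification: Predicting a label for each node (e.g., the topic of a paper in a citation network)
- Link Prediction: Predicting missing or future edges between pairs of nodes
- Graph Classification: Predicting a label for an entire graph (e.g., a property of a molecule)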
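These task types differ mainly in what is read out from the node embeddings a GNN produces. The sketch below assumes a list `z` of per-node embeddings (for example, the output of the GCN model) and uses simple, common readout choices: argmax over per-node class scores, a dot-product score for a candidate edge, and mean pooling for a whole-graph vector. The function names are hypothetical, and real models usually add trainable layers and a loss on top of these readouts.

```python
from typing import List

Embeddings = List[List[float]]  # one row per node, e.g. the output of a GNN


def node_predictions(z: Embeddings) -> List[int]:
    """Node classification: treat each row as class scores and take the argmax."""
    return [max(range(len(row)), key=row.__getitem__) for row in z]


def link_score(z: Embeddings, u: int, v: int) -> float:
    """Link prediction: dot product of two node embeddings (higher = more likely)."""
    return sum(a * b for a, b in zip(z[u], z[v]))


def graph_embedding(z: Embeddings) -> List[float]:
    """Graph classification: mean-pool node embeddings into one graph vector."""
    n = len(z)
    return [sum(row[i] for row in z) / n for i in range(len(z[0]))]


z = [[1.0, 3.0], [4.0, 2.0], [1.0, 1.0]]
print(node_predictions(z))    # [1, 0, 0]
print(link_score(z, 0, 1))    # 10.0
print(graph_embedding(z))     # [2.0, 2.0]
```

In practice the pooled graph vector would feed a classifier, and the link score would typically pass through a sigmoid to yield an edge probability.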