
Introduction to Graph Neural Networks

2025-12-15 · Deep Learning
Tags: Graph Neural Networks, Deep Learning, GCN, GAT


Graph Neural Networks (GNNs) are a class of deep learning methods designed to perform inference on data structured as graphs.

Why Graphs?

Many real-world problems involve data with complex relationships:

  • Social networks
  • Molecular structures
  • Citation networks
  • Knowledge graphs

Message Passing Framework

Most GNNs follow the message passing paradigm:

h_v^{(l+1)} = \text{UPDATE}\left(h_v^{(l)},\ \text{AGGREGATE}\left(\{h_u^{(l)} : u \in \mathcal{N}(v)\}\right)\right)

Where:

  • $h_v^{(l)}$ is the feature vector of node $v$ at layer $l$
  • $\mathcal{N}(v)$ is the set of neighbors of node $v$
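As a concrete, framework-free sketch, one message-passing step with mean aggregation and a simple averaging update could look like the following. The tiny triangle graph, the feature values, and the specific AGGREGATE/UPDATE choices are illustrative assumptions, not part of the general framework:

```python
# One message-passing layer: mean-aggregate neighbor features,
# then update each node by blending with its own feature.
def message_passing_step(features, neighbors):
    """features: {node: float}, neighbors: {node: [neighbor nodes]}."""
    new_features = {}
    for v, h_v in features.items():
        msgs = [features[u] for u in neighbors[v]]        # messages from N(v)
        agg = sum(msgs) / len(msgs) if msgs else 0.0      # AGGREGATE: mean
        new_features[v] = 0.5 * (h_v + agg)               # UPDATE: simple blend
    return new_features

# Triangle graph: each node is connected to the other two (toy example).
neighbors = {0: [1, 2], 1: [0, 2], 2: [0, 1]}
features = {0: 1.0, 1: 2.0, 2: 3.0}
print(message_passing_step(features, neighbors))  # {0: 1.75, 1: 2.0, 2: 2.25}
```

Stacking such layers lets information flow over longer paths: after $k$ layers, a node's representation depends on its $k$-hop neighborhood.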

Graph Convolutional Network (GCN)

The GCN layer can be written as:

H^{(l+1)} = \sigma\left(\tilde{D}^{-\frac{1}{2}}\tilde{A}\tilde{D}^{-\frac{1}{2}} H^{(l)} W^{(l)}\right)

where $\tilde{A} = A + I$ is the adjacency matrix with added self-loops, $\tilde{D}$ is its degree matrix, $W^{(l)}$ is a learnable weight matrix, and $\sigma$ is a nonlinearity such as ReLU.
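The symmetric normalization $\tilde{D}^{-1/2}\tilde{A}\tilde{D}^{-1/2}$ is just a few matrix operations; here is a NumPy sketch on a small example graph (the 3-node path 0–1–2 is an illustrative assumption):

```python
import numpy as np

# Adjacency matrix of a path graph 0 - 1 - 2 (toy example).
A = np.array([[0, 1, 0],
              [1, 0, 1],
              [0, 1, 0]], dtype=float)

A_tilde = A + np.eye(3)                      # add self-loops: A~ = A + I
d = A_tilde.sum(axis=1)                      # degrees under A~
D_inv_sqrt = np.diag(d ** -0.5)              # D~^{-1/2}
A_hat = D_inv_sqrt @ A_tilde @ D_inv_sqrt    # normalized adjacency

# A_hat is symmetric; entry (u, v) is 1 / sqrt(deg(u) * deg(v)).
print(A_hat)
```

The normalization keeps the spectrum of the propagation matrix bounded, which prevents node features from blowing up or vanishing as layers are stacked.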

Code Example

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GCNConv

class GCN(nn.Module):
    """Two-layer GCN for node classification."""
    def __init__(self, in_channels, hidden_channels, out_channels):
        super().__init__()
        self.conv1 = GCNConv(in_channels, hidden_channels)
        self.conv2 = GCNConv(hidden_channels, out_channels)

    def forward(self, x, edge_index):
        x = self.conv1(x, edge_index)                    # first graph convolution
        x = F.relu(x)
        x = F.dropout(x, p=0.5, training=self.training)  # regularize during training only
        x = self.conv2(x, edge_index)                    # per-node output logits
        return x

Applications

  1. Node Classification: Predicting labels for nodes
  2. Link Prediction: Predicting missing edges
  3. Graph Classification: Classifying entire graphs
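For link prediction, one common recipe is to score a candidate edge by the similarity of its endpoint embeddings. The sketch below uses a dot-product scorer (a standard but not the only choice); the embeddings are made-up stand-ins for the output of a trained GNN:

```python
import numpy as np

# Made-up node embeddings, standing in for the output of a trained GNN.
emb = np.array([[1.0, 0.0],
                [0.9, 0.1],
                [0.0, 1.0]])

def edge_score(u, v):
    """Probability-like score for edge (u, v): sigmoid of the dot product."""
    return 1.0 / (1.0 + np.exp(-emb[u] @ emb[v]))

print(edge_score(0, 1))  # similar embeddings -> score above 0.5
print(edge_score(0, 2))  # orthogonal embeddings -> score exactly 0.5
```

In practice the scorer is trained jointly with the GNN by contrasting observed edges against sampled non-edges.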