Graph neural networks
In this series (25 parts)
- Neural networks: the basic building block
- Forward pass and backpropagation
- Training neural networks: a practical guide
- Convolutional neural networks
- Recurrent neural networks and LSTMs
- Attention mechanism and transformers
- Word embeddings: from one-hot to dense representations
- Transfer learning and fine-tuning
- Optimization techniques for deep networks
- Regularization for deep networks
- Encoder-decoder architectures
- Generative models: an overview
- Restricted Boltzmann Machines
- Deep Belief Networks
- Variational Autoencoders
- Generative Adversarial Networks: training and theory
- DCGAN, conditional GANs, and GAN variants
- Representation learning and self-supervised learning
- Domain adaptation and fine-tuning strategies
- Distributed representations and latent spaces
- AutoML and hyperparameter optimization
- Neural architecture search
- Network compression and efficient inference
- Graph neural networks
- Practical deep learning: debugging and tuning
Not all data lives on a grid. Social networks, molecules, citation graphs, road networks, protein structures: these are all graphs. Nodes have features, edges encode relationships, and the structure itself carries information. Graph neural networks (GNNs) extend deep learning to handle this irregular, non-Euclidean data.
Prerequisites
You should understand attention mechanisms (especially for GAT) and matrix operations. Familiarity with CNNs helps because GNNs generalize the same local-aggregation idea from grids to graphs.
Graphs: the basics
A graph $G = (V, E)$ has nodes $V$ and edges $E$. For deep learning, we add:
- Node features: each node $v$ has a feature vector $x_v \in \mathbb{R}^d$
- Adjacency matrix: $A_{ij} = 1$ if there's an edge from $i$ to $j$, else $0$
- Edge features (optional): attributes on edges, like bond type in a molecule
Why graphs need special treatment: standard neural networks expect fixed-size, ordered inputs. A fully connected layer on $n$ nodes would need $O(n^2)$ parameters and wouldn't generalize to graphs of different sizes. GNNs solve this by operating locally: each node looks at its neighbors, regardless of the total graph size.
The message passing framework
Almost all GNNs follow the same pattern: message passing. In each layer, every node:
- Collects messages from its neighbors
- Aggregates them (sum, mean, max, or attention-weighted)
- Updates its own representation using the aggregated message
```mermaid
graph TD
    subgraph "Step 1: Collect"
        N1["Neighbor u₁"] -->|"message m₁"| V["Node v"]
        N2["Neighbor u₂"] -->|"message m₂"| V
        N3["Neighbor u₃"] -->|"message m₃"| V
    end
    subgraph "Step 2: Aggregate"
        V --> AGG["AGG(m₁, m₂, m₃)<br/>e.g., mean or sum"]
    end
    subgraph "Step 3: Update"
        AGG --> UPD["UPDATE(h_v, agg)<br/>new node state"]
    end
```
Formally, one message passing layer computes:

$$h_v^{(l+1)} = \mathrm{UPDATE}\left(h_v^{(l)},\ \mathrm{AGG}\left(\left\{\mathrm{MSG}\left(h_v^{(l)}, h_u^{(l)}\right) : u \in \mathcal{N}(v)\right\}\right)\right)$$

where $h_v^{(l)}$ is node $v$'s representation at layer $l$, $\mathcal{N}(v)$ is the set of neighbors of $v$, and MSG, AGG, UPDATE are learnable functions.
Stacking $L$ message passing layers lets each node "see" information from nodes up to $L$ hops away. This is analogous to how stacking CNN layers increases the receptive field.
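The collect-aggregate-update loop fits in a few lines of NumPy. This is a minimal sketch, not a full framework: mean aggregation and a tanh update are chosen for concreteness, and the graph, weights, and function names are illustrative.

```python
import numpy as np

def message_passing_layer(H, adj, W_msg, W_upd):
    """One message passing layer: MSG -> AGG (mean) -> UPDATE.

    H:     (n, d) matrix of node representations h_v^(l)
    adj:   dict mapping each node to its list of neighbors
    W_msg: (d, d) weights for the message function
    W_upd: (2d, d) weights for the update function
    """
    n, d = H.shape
    H_next = np.zeros_like(H)
    for v in range(n):
        # MSG: transform each neighbor's current representation
        msgs = np.array([H[u] @ W_msg for u in adj[v]])
        # AGG: permutation-invariant mean over the incoming messages
        agg = msgs.mean(axis=0) if len(adj[v]) else np.zeros(d)
        # UPDATE: combine the node's own state with the aggregate
        H_next[v] = np.tanh(np.concatenate([H[v], agg]) @ W_upd)
    return H_next

# Two stacked layers: after the second, node 3 has "seen" nodes
# two hops away (0 and 1), not just its direct neighbor 2
rng = np.random.default_rng(0)
adj = {0: [1, 2], 1: [0, 2], 2: [0, 1, 3], 3: [2]}
H = rng.normal(size=(4, 2))
W_msg, W_upd = rng.normal(size=(2, 2)), rng.normal(size=(4, 2))
H1 = message_passing_layer(H, adj, W_msg, W_upd)
H2 = message_passing_layer(H1, adj, W_msg, W_upd)
```

Note that the layer never assumes a fixed number of neighbors: the same weights handle node 2 (three neighbors) and node 3 (one neighbor).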
CNN as message passing on a grid
Here’s an insight that connects GNNs to what you already know. A 2D convolution on an image is message passing on a grid graph. Each pixel is a node. Its neighbors are the adjacent pixels (defined by the kernel size). The convolution kernel defines the MSG function. Summation is the AGG function. The result is the UPDATE.
The difference: on a grid, every node has the same number of neighbors in the same arrangement. On a general graph, nodes can have any number of neighbors in any arrangement. GNNs handle this by using permutation-invariant aggregation functions (sum, mean, max) that don’t depend on neighbor ordering.
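The permutation-invariance point is easy to verify numerically; the feature values below are arbitrary:

```python
import numpy as np

# The same three neighbor feature vectors, in two different orders
msgs = np.array([[1.0, 2.0], [3.0, 0.0], [0.5, 0.5]])
msgs_shuffled = msgs[[2, 0, 1]]

# Sum, mean, and max are permutation-invariant: the aggregate is
# identical no matter how the neighbors are ordered
assert np.allclose(msgs.sum(axis=0), msgs_shuffled.sum(axis=0))
assert np.allclose(msgs.mean(axis=0), msgs_shuffled.mean(axis=0))
assert np.allclose(msgs.max(axis=0), msgs_shuffled.max(axis=0))
```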
GCN: graph convolutional network
Kipf and Welling (2017) proposed one of the simplest and most influential GNN architectures. The GCN layer computes:

$$H^{(l+1)} = \sigma\left(\tilde{D}^{-1/2} \tilde{A} \tilde{D}^{-1/2} H^{(l)} W^{(l)}\right)$$

where:
- $\tilde{A} = A + I$ is the adjacency matrix with self-loops added
- $\tilde{D}$ is the degree matrix of $\tilde{A}$ (diagonal, $\tilde{D}_{ii} = \sum_j \tilde{A}_{ij}$)
- $W^{(l)}$ is the learnable weight matrix for layer $l$
- $\sigma$ is an activation function (typically ReLU)

The normalization $\tilde{D}^{-1/2} \tilde{A} \tilde{D}^{-1/2}$ is called symmetric normalization. It prevents the scale of node features from growing with the number of neighbors. Without normalization, high-degree nodes would dominate.
Example 1: GCN forward pass
Consider a 4-node graph:
Edges: (0,1), (0,2), (1,2), (2,3). Undirected.
Adjacency matrix $A$:

$$A = \begin{bmatrix} 0 & 1 & 1 & 0 \\ 1 & 0 & 1 & 0 \\ 1 & 1 & 0 & 1 \\ 0 & 0 & 1 & 0 \end{bmatrix}$$

Add self-loops: $\tilde{A} = A + I$:

$$\tilde{A} = \begin{bmatrix} 1 & 1 & 1 & 0 \\ 1 & 1 & 1 & 0 \\ 1 & 1 & 1 & 1 \\ 0 & 0 & 1 & 1 \end{bmatrix}$$

Degree matrix $\tilde{D}$: row sums are [3, 3, 4, 2].
Normalized adjacency $\hat{A} = \tilde{D}^{-1/2} \tilde{A} \tilde{D}^{-1/2}$:

$$\hat{A} \approx \begin{bmatrix} 0.333 & 0.333 & 0.289 & 0 \\ 0.333 & 0.333 & 0.289 & 0 \\ 0.289 & 0.289 & 0.250 & 0.354 \\ 0 & 0 & 0.354 & 0.500 \end{bmatrix}$$

For entry $(0, 1)$: $\hat{A}_{01} = \tilde{A}_{01} / \sqrt{\tilde{D}_{00} \tilde{D}_{11}} = 1/\sqrt{3 \cdot 3} = 1/3 \approx 0.333$.
Node features $X$ (4 nodes, 2 features each):

$$X = \begin{bmatrix} 1 & 0 \\ 0 & 1 \\ 1 & 1 \\ 2 & 0 \end{bmatrix}$$

Weight matrix $W$ (2x2):

$$W = \begin{bmatrix} 1 & 0.5 \\ 0.5 & 1 \end{bmatrix}$$

Step 1: Compute $\hat{A}X$. For node 0:

$$(\hat{A}X)_0 = 0.333 \cdot [1, 0] + 0.333 \cdot [0, 1] + 0.289 \cdot [1, 1] = [0.622, 0.622]$$

Step 2: Multiply by $W$:

$$(\hat{A}XW)_0 = [0.622 \cdot 1 + 0.622 \cdot 0.5,\ 0.622 \cdot 0.5 + 0.622 \cdot 1] = [0.933, 0.933]$$

Step 3: Apply ReLU (all positive, so no change): $H^{(1)}_0 = [0.933, 0.933]$.
The output for node 0 is a smooth mix of its own features and its neighbors' features, weighted by the normalized adjacency.
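The whole forward pass is easy to check in NumPy. The adjacency comes from the edges listed above; the feature matrix `X` and weights `W` are small illustrative values:

```python
import numpy as np

# Graph from the example: edges (0,1), (0,2), (1,2), (2,3), undirected
A = np.zeros((4, 4))
for i, j in [(0, 1), (0, 2), (1, 2), (2, 3)]:
    A[i, j] = A[j, i] = 1.0

A_tilde = A + np.eye(4)                      # add self-loops
d = A_tilde.sum(axis=1)                      # degrees [3, 3, 4, 2]
D_inv_sqrt = np.diag(d ** -0.5)
A_hat = D_inv_sqrt @ A_tilde @ D_inv_sqrt    # symmetric normalization

# Entry (0,1) is 1/sqrt(3*3) = 1/3; entry (2,3) is 1/sqrt(4*2)
print(np.round(A_hat, 3))

# One GCN layer: ReLU(A_hat X W), with illustrative X and W
X = np.array([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [2.0, 0.0]])
W = np.array([[1.0, 0.5], [0.5, 1.0]])
H1 = np.maximum(A_hat @ X @ W, 0)
print(np.round(H1, 3))
```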
GraphSAGE: sampling for scalability
GCN requires the full adjacency matrix, which is impractical for graphs with millions of nodes. GraphSAGE (Hamilton et al., 2017) solves this by sampling a fixed number of neighbors per node and aggregating only from the sample.
For each node, GraphSAGE:
- Samples $k$ neighbors (e.g., $k = 10$)
- Aggregates their features (mean, LSTM, or max pooling)
- Concatenates the aggregated features with the node’s own features
- Applies a linear transformation and activation
This makes the computational cost per node constant, regardless of the actual degree. Training uses mini-batches of nodes rather than the full graph.
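A minimal sketch of the per-node procedure with mean aggregation; the graph, weights, and function name are illustrative, and a real implementation would batch this over many nodes:

```python
import numpy as np

def graphsage_layer(H, adj, W, k=2, rng=None):
    """One GraphSAGE layer: sample k neighbors, mean-aggregate,
    concatenate with own features, transform.

    H:   (n, d) node features;  adj: dict of neighbor lists
    W:   (2d, d) weights applied to [own features || aggregate]
    """
    if rng is None:
        rng = np.random.default_rng(0)
    H_next = np.zeros_like(H)
    for v in range(H.shape[0]):
        # Sampling caps the per-node cost at O(k), whatever the true degree
        sample = rng.choice(adj[v], size=min(k, len(adj[v])), replace=False)
        agg = H[sample].mean(axis=0)
        # Concatenate own features with the aggregate, then linear + ReLU
        H_next[v] = np.maximum(np.concatenate([H[v], agg]) @ W, 0)
    return H_next

adj = {0: [1, 2, 3], 1: [0], 2: [0], 3: [0]}
H = np.array([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [2.0, 0.0]])
# This W computes h_v + 0.5 * agg, to keep the output interpretable
W = np.array([[1.0, 0.0], [0.0, 1.0], [0.5, 0.0], [0.0, 0.5]])
H_new = graphsage_layer(H, adj, W)
```

Node 0 has three neighbors but only two are sampled each pass, which is exactly the trade the method makes: a noisier aggregate in exchange for bounded cost.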
GAT: attention over neighbors
Graph Attention Networks (Velickovic et al., 2018) apply the attention mechanism to graph neighborhoods. Instead of treating all neighbors equally (GCN) or sampling randomly (GraphSAGE), GAT learns to weight neighbors by importance.
For each edge $(v, u)$, GAT computes an attention coefficient:

$$e_{vu} = \mathrm{LeakyReLU}\left(a^\top [W h_v \,\|\, W h_u]\right)$$

where $\|$ denotes concatenation, $W$ is a shared linear transformation, and $a$ is a learnable attention vector. These raw scores are then normalized across all neighbors of $v$:

$$\alpha_{vu} = \frac{\exp(e_{vu})}{\sum_{u' \in \mathcal{N}(v)} \exp(e_{vu'})}$$

The updated node representation:

$$h_v' = \sigma\left(\sum_{u \in \mathcal{N}(v)} \alpha_{vu} W h_u\right)$$
Multi-head attention (just like in Transformers) is used to stabilize learning: run $K$ independent attention heads and concatenate (or average) their outputs.
```mermaid
graph TD
    subgraph "GCN"
        GCN_A["All neighbors<br/>weighted equally<br/>by degree normalization"]
    end
    subgraph "GraphSAGE"
        SAGE_A["Sample fixed k<br/>neighbors, aggregate<br/>with mean/max/LSTM"]
    end
    subgraph "GAT"
        GAT_A["All neighbors<br/>weighted by learned<br/>attention coefficients"]
    end
```
Example 2: Message passing step
Node $v$ with features $h_v = [1, 2]$ has three neighbors:
- $u_1$: $h_{u_1} = [2, 0]$
- $u_2$: $h_{u_2} = [0, 2]$
- $u_3$: $h_{u_3} = [1, 1]$

Step 1: Mean aggregation of neighbor features:

$$\mathrm{agg} = \tfrac{1}{3}\left([2, 0] + [0, 2] + [1, 1]\right) = [1, 1]$$

Step 2: Concatenate with node's own features:

$$[h_v \,\|\, \mathrm{agg}] = [1, 2, 1, 1]$$

Step 3: Linear transformation with $W = \begin{bmatrix} 1 & 0 & 0.5 & 0 \\ 0 & 1 & 0 & 0.5 \end{bmatrix}$:

$$W [1, 2, 1, 1]^\top = [1 + 0.5,\ 2 + 0.5] = [1.5, 2.5]$$

Step 4: ReLU activation (all positive, so unchanged): $h_v' = [1.5, 2.5]$.
Node $v$'s new representation now encodes information from both itself and its neighbors.
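The aggregate-concatenate-transform step is only a few lines of NumPy; the feature and weight values below are illustrative:

```python
import numpy as np

h_v = np.array([1.0, 2.0])                     # node's own features
neighbors = np.array([[2.0, 0.0], [0.0, 2.0], [1.0, 1.0]])

agg = neighbors.mean(axis=0)                   # step 1: mean aggregation
combined = np.concatenate([h_v, agg])          # step 2: [h_v || agg]
W = np.array([[1.0, 0.0, 0.5, 0.0],
              [0.0, 1.0, 0.0, 0.5]])
h_new = np.maximum(W @ combined, 0)            # steps 3-4: linear + ReLU
print(h_new)                                   # [1.5 2.5]
```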
Example 3: GAT attention computation
Node $v$ with $h_v = [1, 1]$ and neighbor $u_1$ with $h_{u_1} = [0, 1]$. (For simplicity, this example skips the shared transformation $W$ and attends over raw features.)
Attention vector: $a = [0.1, 0.2, 0.3, 0.4]$ (dimension 4 because we concatenate two 2D vectors).
Concatenation: $[h_v \,\|\, h_{u_1}] = [1, 1, 0, 1]$.
Raw attention score:

$$e_{vu_1} = \mathrm{LeakyReLU}(0.1 \cdot 1 + 0.2 \cdot 1 + 0.3 \cdot 0 + 0.4 \cdot 1) = \mathrm{LeakyReLU}(0.7) = 0.7$$

(Since 0.7 > 0, LeakyReLU has no effect.)
Now suppose we also have neighbor $u_2$ with $h_{u_2} = [2, -1]$, giving $e_{vu_2} = \mathrm{LeakyReLU}(0.1 + 0.2 + 0.6 - 0.4) = 0.5$. Softmax normalization over neighbors:

$$\alpha_{vu_1} = \frac{e^{0.7}}{e^{0.7} + e^{0.5}} \approx \frac{2.014}{3.662} \approx 0.55, \qquad \alpha_{vu_2} = \frac{e^{0.5}}{e^{0.7} + e^{0.5}} \approx 0.45$$

Neighbor $u_1$ gets 55% of the attention weight and neighbor $u_2$ gets 45%. The attention mechanism learned that $u_1$ is slightly more relevant to $v$ than $u_2$ is. The final aggregation would be:

$$h_v' = 0.55 \cdot [0, 1] + 0.45 \cdot [2, -1] = [0.90, 0.10]$$
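The score-softmax-aggregate computation in NumPy, with illustrative feature and attention values (and, as in the hand example, the shared transformation $W$ omitted for simplicity):

```python
import numpy as np

def leaky_relu(x, slope=0.2):
    return np.where(x > 0, x, slope * x)

a = np.array([0.1, 0.2, 0.3, 0.4])     # learnable attention vector
h_v = np.array([1.0, 1.0])
h_u1 = np.array([0.0, 1.0])
h_u2 = np.array([2.0, -1.0])

# Raw scores: a^T [h_v || h_u], passed through LeakyReLU
e1 = leaky_relu(a @ np.concatenate([h_v, h_u1]))   # 0.7
e2 = leaky_relu(a @ np.concatenate([h_v, h_u2]))   # 0.5

# Softmax over this node's neighbors
alphas = np.exp([e1, e2]) / np.exp([e1, e2]).sum()
print(np.round(alphas, 2))                         # [0.55 0.45]

# Attention-weighted aggregation of neighbor features
h_new = alphas[0] * h_u1 + alphas[1] * h_u2
print(np.round(h_new, 2))                          # [0.9 0.1]
```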
Readout: from node representations to predictions
Different tasks need different outputs:
Node classification: predict a label for each node (e.g., classify users in a social network). Use the final node representations directly with a classifier on top.
Link prediction: predict whether an edge exists between two nodes. Compute a score from pairs of node representations, e.g., a dot product: $\mathrm{score}(u, v) = h_u^\top h_v$.
Graph classification: predict a label for the entire graph (e.g., predict whether a molecule is toxic). You need a readout function that aggregates all node representations into a single graph-level vector:

$$h_G = \mathrm{READOUT}\left(\{h_v : v \in V\}\right)$$
Common readout functions: mean pooling, sum pooling, or hierarchical pooling that coarsens the graph in stages.
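The three task types differ only in what you do with the final node representations; a quick NumPy sketch with illustrative values:

```python
import numpy as np

# Final node representations for a 4-node graph (illustrative values)
H = np.array([[1.0, 0.0], [0.0, 2.0], [1.0, 1.0], [2.0, 1.0]])

# Node classification: feed each row of H directly to a classifier.

# Link prediction: score a candidate edge (0, 2) by a dot product
score = H[0] @ H[2]             # 1.0

# Graph classification: permutation-invariant readout over all nodes
h_graph_mean = H.mean(axis=0)   # [1. 1.]
h_graph_sum = H.sum(axis=0)     # [4. 4.]
```

Sum pooling preserves information about graph size (a 100-node graph yields larger values than a 4-node one), while mean pooling is size-invariant; which is preferable depends on the task.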
GNN variants comparison
| Name | Aggregation | Attention | Scalable | Edge features | Best for |
|---|---|---|---|---|---|
| GCN | Normalized mean | ✗ | Moderate | ✗ | Semi-supervised node classification |
| GraphSAGE | Sampled mean/max/LSTM | ✗ | ✓ | ✗ | Large-scale inductive tasks |
| GAT | Attention-weighted sum | ✓ | Moderate | With modification | Tasks needing neighbor importance |
| GIN | Sum (injective) | ✗ | Moderate | ✗ | Graph classification (most expressive) |
| MPNN | Flexible (learned) | Optional | Varies | ✓ | Molecular property prediction |
| SchNet | Continuous filter | ✗ | Moderate | ✓ Distance | 3D molecular geometry |
Applications
GNNs have found use across many domains:
- Chemistry: predicting molecular properties, drug discovery, reaction prediction
- Social networks: recommendation systems, community detection, influence modeling
- Computer vision: scene graphs, point cloud processing, human pose estimation
- NLP: knowledge graph reasoning, semantic parsing, dependency parsing
- Physics: particle simulations, fluid dynamics, material science
- Infrastructure: traffic prediction, power grid analysis
[Figure: node classification accuracy vs. GNN depth]
Known limitations
- Over-smoothing: as you stack more layers, node representations converge to the same vector. After too many message passing rounds, every node looks the same. This limits GNNs to relatively shallow architectures (2-4 layers typically).
- Over-squashing: information from distant nodes gets compressed into fixed-size vectors, losing information. Think of it as a bottleneck in long-range message passing.
- Expressiveness: the Weisfeiler-Lehman (WL) test sets an upper bound on what message passing GNNs can distinguish. Standard GNNs cannot distinguish certain non-isomorphic graphs that the 1-WL test also fails on.
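Over-smoothing is easy to see numerically. The sketch below reuses the 4-node graph from Example 1 and strips the GNN down to repeated neighborhood averaging (powers of the normalized adjacency, no weights or nonlinearities): as depth grows, all node representations collapse toward the same direction.

```python
import numpy as np

# 4-node graph from Example 1: edges (0,1), (0,2), (1,2), (2,3)
A = np.array([[0, 1, 1, 0],
              [1, 0, 1, 0],
              [1, 1, 0, 1],
              [0, 0, 1, 0]], dtype=float)
A_tilde = A + np.eye(4)
d = A_tilde.sum(axis=1)
A_hat = np.diag(d ** -0.5) @ A_tilde @ np.diag(d ** -0.5)

H = np.array([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [2.0, 0.0]])
sims = []
for depth in (1, 4, 16, 64):
    H_l = np.linalg.matrix_power(A_hat, depth) @ H
    directions = H_l / np.linalg.norm(H_l, axis=1, keepdims=True)
    # Minimum pairwise cosine similarity between node representations:
    # 1.0 means every node has collapsed to the same direction
    sims.append((directions @ directions.T).min())
print(np.round(sims, 4))  # climbs toward 1.0 as depth grows
```

Real GNNs have weights and nonlinearities between averaging steps, so the collapse is slower, but the same pressure is why 2-4 layers is the typical sweet spot.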
What comes next
With architectures for grids (CNNs), sequences (RNNs, Transformers), and graphs (GNNs) under your belt, you have the building blocks for nearly any deep learning system. The final piece is knowing how to make it all work in practice. Practical deep learning: debugging and tuning covers the skills that turn theoretical knowledge into working models.