
Graph neural networks

In this series (25 parts)
  1. Neural networks: the basic building block
  2. Forward pass and backpropagation
  3. Training neural networks: a practical guide
  4. Convolutional neural networks
  5. Recurrent neural networks and LSTMs
  6. Attention mechanism and transformers
  7. Word embeddings: from one-hot to dense representations
  8. Transfer learning and fine-tuning
  9. Optimization techniques for deep networks
  10. Regularization for deep networks
  11. Encoder-decoder architectures
  12. Generative models: an overview
  13. Restricted Boltzmann Machines
  14. Deep Belief Networks
  15. Variational Autoencoders
  16. Generative Adversarial Networks: training and theory
  17. DCGAN, conditional GANs, and GAN variants
  18. Representation learning and self-supervised learning
  19. Domain adaptation and fine-tuning strategies
  20. Distributed representations and latent spaces
  21. AutoML and hyperparameter optimization
  22. Neural architecture search
  23. Network compression and efficient inference
  24. Graph neural networks
  25. Practical deep learning: debugging and tuning

Not all data lives on a grid. Social networks, molecules, citation graphs, road networks, protein structures: these are all graphs. Nodes have features, edges encode relationships, and the structure itself carries information. Graph neural networks (GNNs) extend deep learning to handle this irregular, non-Euclidean data.

Prerequisites

You should understand attention mechanisms (especially for GAT) and matrix operations. Familiarity with CNNs helps because GNNs generalize the same local-aggregation idea from grids to graphs.

Graphs: the basics

A graph G = (V, E) has nodes V and edges E. For deep learning, we add:

  • Node features X \in \mathbb{R}^{n \times d}: each node v has a feature vector x_v \in \mathbb{R}^d
  • Adjacency matrix A \in \{0, 1\}^{n \times n}: A_{ij} = 1 if there's an edge from i to j
  • Edge features (optional): attributes on edges, like bond type in a molecule

Why graphs need special treatment: standard neural networks expect fixed-size, ordered inputs. A fully connected layer on n nodes would need n^2 parameters and wouldn't generalize to graphs of different sizes. GNNs solve this by operating locally: each node looks at its neighbors, regardless of the total graph size.

The message passing framework

Almost all GNNs follow the same pattern: message passing. In each layer, every node:

  1. Collects messages from its neighbors
  2. Aggregates them (sum, mean, max, or attention-weighted)
  3. Updates its own representation using the aggregated message
```mermaid
graph TD
  subgraph "Step 1: Collect"
      N1["Neighbor u₁"] -->|"message m₁"| V["Node v"]
      N2["Neighbor u₂"] -->|"message m₂"| V
      N3["Neighbor u₃"] -->|"message m₃"| V
  end
  subgraph "Step 2: Aggregate"
      V --> AGG["AGG(m₁, m₂, m₃)
e.g., mean or sum"]
  end
  subgraph "Step 3: Update"
      AGG --> UPD["UPDATE(h_v, agg)
new node state"]
  end
```

Formally, one message passing layer computes:

h_v^{(l+1)} = \text{UPDATE}\left(h_v^{(l)},\; \text{AGG}\left(\left\{ \text{MSG}(h_v^{(l)}, h_u^{(l)}) : u \in \mathcal{N}(v) \right\}\right)\right)

where h_v^{(l)} is node v's representation at layer l, \mathcal{N}(v) is the set of neighbors of v, and MSG, AGG, UPDATE are learnable functions.

Stacking k message passing layers lets each node "see" information from nodes up to k hops away. This is analogous to how stacking CNN layers increases the receptive field.
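The three steps can be sketched in a few lines of NumPy. This is a minimal, hypothetical implementation (the names and weight shapes are illustrative): MSG is the identity, AGG is the mean, and UPDATE is a linear combination followed by ReLU, one common choice among many.

```python
import numpy as np

def message_passing_layer(H, adj_list, W_self, W_neigh):
    """One message passing round: MSG = identity, AGG = mean,
    UPDATE = linear combination of own state and aggregate, then ReLU."""
    H_new = np.zeros_like(H @ W_self)
    for v, neighbors in adj_list.items():
        agg = np.mean([H[u] for u in neighbors], axis=0)          # AGG step
        H_new[v] = np.maximum(0, H[v] @ W_self + agg @ W_neigh)   # UPDATE step
    return H_new

# Triangle graph with identity weights: node 0 mixes its own features
# with the mean of its two neighbors
H = np.array([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]])
adj = {0: [1, 2], 1: [0, 2], 2: [0, 1]}
print(message_passing_layer(H, adj, np.eye(2), np.eye(2))[0])  # → [1.5 1. ]
```

Note that the loop touches only each node's neighbor list, never the full n × n adjacency, which is exactly the locality property discussed above.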

CNN as message passing on a grid

Here’s an insight that connects GNNs to what you already know. A 2D convolution on an image is message passing on a grid graph. Each pixel is a node. Its neighbors are the adjacent pixels (defined by the kernel size). The convolution kernel defines the MSG function. Summation is the AGG function. The result is the UPDATE.

The difference: on a grid, every node has the same number of neighbors in the same arrangement. On a general graph, nodes can have any number of neighbors in any arrangement. GNNs handle this by using permutation-invariant aggregation functions (sum, mean, max) that don’t depend on neighbor ordering.

GCN: graph convolutional network

Kipf and Welling (2017) proposed one of the simplest and most influential GNN architectures. The GCN layer computes:

H^{(l+1)} = \sigma\left(\hat{D}^{-1/2} \hat{A} \hat{D}^{-1/2} H^{(l)} W^{(l)}\right)

where:

  • \hat{A} = A + I is the adjacency matrix with self-loops added
  • \hat{D} is the degree matrix of \hat{A} (diagonal, \hat{D}_{ii} = \sum_j \hat{A}_{ij})
  • W^{(l)} is the learnable weight matrix for layer l
  • \sigma is an activation function (typically ReLU)

The normalization \hat{D}^{-1/2} \hat{A} \hat{D}^{-1/2} is called symmetric normalization. It prevents the scale of node features from growing with the number of neighbors. Without normalization, high-degree nodes would dominate.
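The normalization is straightforward to compute; a minimal NumPy sketch (gcn_norm is a hypothetical helper name, not a library function):

```python
import numpy as np

def gcn_norm(A):
    """Symmetrically normalized adjacency D̂^{-1/2} Â D̂^{-1/2} with self-loops."""
    A_hat = A + np.eye(len(A))               # Â = A + I
    d_inv_sqrt = A_hat.sum(axis=1) ** -0.5   # diagonal of D̂^{-1/2} as a vector
    # Elementwise: Ã_ij = Â_ij / (√D̂_ii · √D̂_jj)
    return A_hat * np.outer(d_inv_sqrt, d_inv_sqrt)

# The 4-node graph used in Example 1: entries like Ã₀₀ = 1/3 ≈ 0.333
# and Ã₀₂ = 1/(√3·√4) ≈ 0.289
A = np.array([[0,1,1,0],[1,0,1,0],[1,1,0,1],[0,0,1,0]], dtype=float)
print(np.round(gcn_norm(A), 3))
```

The outer-product form avoids building the dense diagonal matrix explicitly; the result is symmetric whenever A is.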

Example 1: GCN forward pass

Consider a 4-node graph:

Edges: (0,1), (0,2), (1,2), (2,3). Undirected.

Adjacency matrix A:

A = \begin{bmatrix} 0 & 1 & 1 & 0 \\ 1 & 0 & 1 & 0 \\ 1 & 1 & 0 & 1 \\ 0 & 0 & 1 & 0 \end{bmatrix}

Add self-loops: \hat{A} = A + I:

\hat{A} = \begin{bmatrix} 1 & 1 & 1 & 0 \\ 1 & 1 & 1 & 0 \\ 1 & 1 & 1 & 1 \\ 0 & 0 & 1 & 1 \end{bmatrix}

Degree matrix \hat{D}: row sums are [3, 3, 4, 2].

\hat{D}^{-1/2} = \text{diag}\left(\frac{1}{\sqrt{3}}, \frac{1}{\sqrt{3}}, \frac{1}{\sqrt{4}}, \frac{1}{\sqrt{2}}\right) = \text{diag}(0.577, 0.577, 0.500, 0.707)

Normalized adjacency \tilde{A} = \hat{D}^{-1/2} \hat{A} \hat{D}^{-1/2}:

For entry (i, j): \tilde{A}_{ij} = \frac{\hat{A}_{ij}}{\sqrt{\hat{D}_{ii}} \cdot \sqrt{\hat{D}_{jj}}}

\tilde{A} = \begin{bmatrix} 0.333 & 0.333 & 0.289 & 0 \\ 0.333 & 0.333 & 0.289 & 0 \\ 0.289 & 0.289 & 0.250 & 0.354 \\ 0 & 0 & 0.354 & 0.500 \end{bmatrix}

Node features X (4 nodes, 2 features each):

X = \begin{bmatrix} 1.0 & 0.0 \\ 0.0 & 1.0 \\ 1.0 & 1.0 \\ 0.5 & 0.5 \end{bmatrix}

Weight matrix W (2×2):

W = \begin{bmatrix} 1.0 & 0.5 \\ 0.5 & 1.0 \end{bmatrix}

Step 1: Compute \tilde{A} X. For node 0:

(\tilde{A}X)_0 = 0.333 \times [1,0] + 0.333 \times [0,1] + 0.289 \times [1,1] + 0 \times [0.5,0.5]
= [0.333, 0] + [0, 0.333] + [0.289, 0.289] = [0.622, 0.622]

Step 2: Multiply by W:

(\tilde{A}XW)_0 = [0.622, 0.622] \times \begin{bmatrix} 1.0 & 0.5 \\ 0.5 & 1.0 \end{bmatrix} = [0.622 + 0.311, 0.311 + 0.622] = [0.933, 0.933]

Step 3: Apply ReLU (all positive, so no change): h_0^{(1)} = [0.933, 0.933].

The output for node 0 is a smooth mix of its own features and its neighbors' features, weighted by the normalized adjacency.
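The whole worked example can be checked numerically; a minimal NumPy sketch:

```python
import numpy as np

# Graph, features, and weights from Example 1
A = np.array([[0,1,1,0],[1,0,1,0],[1,1,0,1],[0,0,1,0]], dtype=float)
X = np.array([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [0.5, 0.5]])
W = np.array([[1.0, 0.5], [0.5, 1.0]])

A_hat = A + np.eye(4)                                  # add self-loops
d_inv_sqrt = A_hat.sum(axis=1) ** -0.5
A_tilde = A_hat * np.outer(d_inv_sqrt, d_inv_sqrt)     # symmetric normalization

H1 = np.maximum(0, A_tilde @ X @ W)                    # one GCN layer + ReLU
print(np.round(H1[0], 3))                              # → [0.933 0.933]
```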

GraphSAGE: sampling for scalability

GCN requires the full adjacency matrix, which is impractical for graphs with millions of nodes. GraphSAGE (Hamilton et al., 2017) solves this by sampling a fixed number of neighbors per node and aggregating only from the sample.

For each node, GraphSAGE:

  1. Samples k neighbors (e.g., k = 10)
  2. Aggregates their features (mean, LSTM, or max pooling)
  3. Concatenates the aggregated features with the node’s own features
  4. Applies a linear transformation and activation

This makes the computational cost per node constant, regardless of the actual degree. Training uses mini-batches of nodes rather than the full graph.

GAT: attention over neighbors

Graph Attention Networks (Velickovic et al., 2018) apply the attention mechanism to graph neighborhoods. Instead of treating all neighbors equally (GCN) or sampling randomly (GraphSAGE), GAT learns to weight neighbors by importance.

For each edge (v, u), GAT computes an attention coefficient:

e_{vu} = \text{LeakyReLU}\left(a^T [W h_v \| W h_u]\right)

where \| denotes concatenation, W is a shared linear transformation, and a is a learnable attention vector. These raw scores are then normalized across all neighbors of v:

\alpha_{vu} = \frac{\exp(e_{vu})}{\sum_{k \in \mathcal{N}(v)} \exp(e_{vk})}

The updated node representation:

h_v' = \sigma\left(\sum_{u \in \mathcal{N}(v)} \alpha_{vu} \, W h_u\right)

Multi-head attention (just like in Transformers) is used to stabilize learning: run K independent attention heads and concatenate (or average) their outputs.
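A single attention head can be sketched as follows (gat_attention is a hypothetical helper; shapes are illustrative):

```python
import numpy as np

def gat_attention(h_v, H_neigh, W, a, slope=0.2):
    """Single-head GAT attention over the neighbors of v.
    W: (d, d') shared transform; a: (2*d',) attention vector.
    Returns attention weights and the pre-activation aggregate."""
    Wh_v = h_v @ W
    Wh_u = H_neigh @ W
    # e_vu = LeakyReLU(a^T [Wh_v ‖ Wh_u]) for each neighbor u
    cat = np.hstack([np.tile(Wh_v, (len(Wh_u), 1)), Wh_u])
    e = cat @ a
    e = np.where(e > 0, e, slope * e)            # LeakyReLU
    alpha = np.exp(e) / np.exp(e).sum()          # softmax over neighbors
    return alpha, alpha @ Wh_u                   # attention-weighted sum

# Toy numbers (matching Example 3 below, with W = identity and a second
# neighbor chosen so its raw score comes out to 0.2)
alpha, _ = gat_attention(np.array([1.0, 1.0]),
                         np.array([[2.0, 0.0], [0.0, 1.0]]),
                         np.eye(2),
                         np.array([0.2, -0.1, 0.3, 0.1]))
print(np.round(alpha, 3))                        # → [0.622 0.378]
```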

```mermaid
graph TD
  subgraph "GCN"
      GCN_A["All neighbors
weighted equally
by degree normalization"]
  end
  subgraph "GraphSAGE"
      SAGE_A["Sample fixed k
neighbors, aggregate
with mean/max/LSTM"]
  end
  subgraph "GAT"
      GAT_A["All neighbors
weighted by learned
attention coefficients"]
  end
```

Example 2: Message passing step

Node v with features h_v = [1, 2] has three neighbors:

  • u_1: h_{u_1} = [0.5, 1.0]
  • u_2: h_{u_2} = [1.5, 0.5]
  • u_3: h_{u_3} = [0.0, 2.0]

Step 1: Mean aggregation of neighbor features:

\text{AGG} = \frac{1}{3}([0.5, 1.0] + [1.5, 0.5] + [0.0, 2.0]) = \frac{[2.0, 3.5]}{3} = [0.667, 1.167]

Step 2: Concatenate with node's own features:

[h_v \| \text{AGG}] = [1, 2, 0.667, 1.167]

Step 3: Linear transformation with W \in \mathbb{R}^{4 \times 2}:

W = \begin{bmatrix} 0.5 & 0.3 \\ -0.2 & 0.8 \\ 0.7 & -0.1 \\ 0.1 & 0.6 \end{bmatrix}

z = W^T [1, 2, 0.667, 1.167]
z_1 = 0.5(1) + (-0.2)(2) + 0.7(0.667) + 0.1(1.167) = 0.5 - 0.4 + 0.467 + 0.117 = 0.684
z_2 = 0.3(1) + 0.8(2) + (-0.1)(0.667) + 0.6(1.167) = 0.3 + 1.6 - 0.067 + 0.700 = 2.533

Step 4: ReLU activation:

h_v' = \text{ReLU}([0.684, 2.533]) = [0.684, 2.533]

Node v's new representation now encodes information from both itself and its neighbors.
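The arithmetic is easy to check with NumPy. (Full precision gives 0.683 for the first component; the 0.684 above comes from rounding the intermediate values.)

```python
import numpy as np

# Values from Example 2
h_v = np.array([1.0, 2.0])
neighbors = np.array([[0.5, 1.0], [1.5, 0.5], [0.0, 2.0]])
W = np.array([[0.5, 0.3], [-0.2, 0.8], [0.7, -0.1], [0.1, 0.6]])

agg = neighbors.mean(axis=0)                 # Step 1: mean aggregation
z = np.concatenate([h_v, agg]) @ W           # Steps 2-3: concat + transform
h_v_new = np.maximum(0, z)                   # Step 4: ReLU
print(np.round(h_v_new, 3))                  # → [0.683 2.533]
```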

Example 3: GAT attention computation

Node v with h_v = [1, 1] and neighbor u with h_u = [2, 0]. For simplicity, take the shared transform W to be the identity.

Attention vector: a = [0.2, -0.1, 0.3, 0.1] (dimension 4 because we concatenate two 2D vectors).

Concatenation: [h_v \| h_u] = [1, 1, 2, 0].

Raw attention score:

e_{vu} = \text{LeakyReLU}(a^T [h_v \| h_u]) = \text{LeakyReLU}(0.2(1) + (-0.1)(1) + 0.3(2) + 0.1(0))
= \text{LeakyReLU}(0.2 - 0.1 + 0.6 + 0) = \text{LeakyReLU}(0.7) = 0.7

(Since 0.7 > 0, LeakyReLU has no effect.)

Now suppose we also have a second neighbor w with raw score e_{vw} = 0.2. Softmax normalization over the neighbors of v:

\alpha_{vu} = \frac{\exp(0.7)}{\exp(0.7) + \exp(0.2)} = \frac{2.014}{2.014 + 1.221} = \frac{2.014}{3.235} = 0.622
\alpha_{vw} = \frac{\exp(0.2)}{\exp(0.7) + \exp(0.2)} = \frac{1.221}{3.235} = 0.378

Neighbor u gets 62% of the attention weight and neighbor w gets 38%. The attention mechanism learned that u is more relevant to v than w is. The final aggregation would be:

h_v' = \sigma(0.622 \cdot W h_u + 0.378 \cdot W h_w)
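The softmax step is easy to check numerically, using the raw score e_vu = 0.7 computed above and e_vw = 0.2 for the second neighbor:

```python
import numpy as np

e = np.array([0.7, 0.2])                  # raw scores e_vu, e_vw
alpha = np.exp(e) / np.exp(e).sum()       # softmax over v's neighbors
print(np.round(alpha, 3))                 # → [0.622 0.378]
```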

Readout: from node representations to predictions

Different tasks need different outputs:

Node classification: predict a label for each node (e.g., classify users in a social network). Use the final node representations directly with a classifier on top.

Link prediction: predict whether an edge exists between two nodes. Compute a score from pairs of node representations, e.g., dot product: \text{score}(u, v) = h_u^T h_v.

Graph classification: predict a label for the entire graph (e.g., predict whether a molecule is toxic). You need a readout function that aggregates all node representations into a single graph-level vector:

h_G = \text{READOUT}(\{h_v : v \in V\})

Common readout functions: mean pooling, sum pooling, or hierarchical pooling that coarsens the graph in stages.
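A minimal sketch of the three simple readouts and a dot-product link score (helper names are illustrative):

```python
import numpy as np

def readout(H, mode="mean"):
    """Pool all node representations H (n, d) into one graph-level vector."""
    return {"mean": H.mean(axis=0),
            "sum": H.sum(axis=0),
            "max": H.max(axis=0)}[mode]

def link_score(h_u, h_v):
    """Dot-product score for link prediction on a candidate edge (u, v)."""
    return float(h_u @ h_v)

H = np.array([[1.0, 2.0], [3.0, 4.0]])
print(readout(H, "mean"))                 # → [2. 3.]
print(link_score(H[0], H[1]))             # → 11.0
```

Mean readout is invariant to graph size; sum readout preserves size information, which is one reason GIN-style models prefer it for graph classification.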

GNN variants comparison

| Name | Aggregation | Attention | Scalable | Edge features | Best for |
| --- | --- | --- | --- | --- | --- |
| GCN | Normalized mean | — | Moderate | — | Semi-supervised node classification |
| GraphSAGE | Sampled mean/max/LSTM | — | ✓ | — | Large-scale inductive tasks |
| GAT | Attention-weighted sum | ✓ | Moderate | With modification | Tasks needing neighbor importance |
| GIN | Sum (injective) | — | Moderate | — | Graph classification (most expressive) |
| MPNN | Flexible (learned) | Optional | Varies | ✓ | Molecular property prediction |
| SchNet | Continuous filter | — | Moderate | ✓ (distance) | 3D molecular geometry |

Applications

GNNs have found use across many domains:

  • Chemistry: predicting molecular properties, drug discovery, reaction prediction
  • Social networks: recommendation systems, community detection, influence modeling
  • Computer vision: scene graphs, point cloud processing, human pose estimation
  • NLP: knowledge graph reasoning, semantic parsing, dependency parsing
  • Physics: particle simulations, fluid dynamics, material science
  • Infrastructure: traffic prediction, power grid analysis

[Figure: Node classification accuracy vs GNN depth]

Known limitations

  1. Over-smoothing: as you stack more layers, node representations converge to the same vector. After too many message passing rounds, every node looks the same. This limits GNNs to relatively shallow architectures (2-4 layers typically).
  2. Over-squashing: information from distant nodes gets compressed into fixed-size vectors, losing information. Think of it as a bottleneck in long-range message passing.
  3. Expressiveness: the Weisfeiler-Lehman (WL) test sets an upper bound on what message passing GNNs can distinguish. Standard GNNs cannot distinguish certain non-isomorphic graphs that the 1-WL test also fails on.
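Over-smoothing is easy to demonstrate in miniature. Repeatedly applying plain neighborhood averaging (no weights, no nonlinearity, a toy stand-in for a very deep GCN) to the 4-node graph from Example 1 drives every node toward the same vector:

```python
import numpy as np

A = np.array([[0,1,1,0],[1,0,1,0],[1,1,0,1],[0,0,1,0]], dtype=float)
A_hat = A + np.eye(4)
P = A_hat / A_hat.sum(axis=1, keepdims=True)    # row-normalized propagation

H = np.array([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [0.5, 0.5]])
for _ in range(50):                             # 50 message passing rounds
    H = P @ H
spread = H.max(axis=0) - H.min(axis=0)          # per-feature spread over nodes
print(spread.max() < 1e-6)                      # → True: all nodes identical
```

Because P is row-stochastic on a connected graph with self-loops, repeated application converges to a rank-one matrix, so every row of H becomes the same vector. Nonlinearities and learned weights slow this collapse but do not remove the underlying tendency.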

What comes next

With architectures for grids (CNNs), sequences (RNNs, Transformers), and graphs (GNNs) under your belt, you have the building blocks for nearly any deep learning system. The final piece is knowing how to make it all work in practice. Practical deep learning: debugging and tuning covers the skills that turn theoretical knowledge into working models.
