
Network compression and efficient inference

In this series (25 parts)
  1. Neural networks: the basic building block
  2. Forward pass and backpropagation
  3. Training neural networks: a practical guide
  4. Convolutional neural networks
  5. Recurrent neural networks and LSTMs
  6. Attention mechanism and transformers
  7. Word embeddings: from one-hot to dense representations
  8. Transfer learning and fine-tuning
  9. Optimization techniques for deep networks
  10. Regularization for deep networks
  11. Encoder-decoder architectures
  12. Generative models: an overview
  13. Restricted Boltzmann Machines
  14. Deep Belief Networks
  15. Variational Autoencoders
  16. Generative Adversarial Networks: training and theory
  17. DCGAN, conditional GANs, and GAN variants
  18. Representation learning and self-supervised learning
  19. Domain adaptation and fine-tuning strategies
  20. Distributed representations and latent spaces
  21. AutoML and hyperparameter optimization
  22. Neural architecture search
  23. Network compression and efficient inference
  24. Graph neural networks
  25. Practical deep learning: debugging and tuning

A ResNet-50 has 25 million parameters and needs 4 billion FLOPs for one forward pass. That’s fine on a server with a GPU. It’s not fine on a phone, a drone, or an IoT sensor. Network compression makes large models small and fast enough to run where they actually need to run: at the edge, under latency constraints, with limited memory and power.

Prerequisites

You should be comfortable with convolutional neural networks, transfer learning, and matrix decompositions (especially SVD). Understanding how training works will help with the fine-tuning steps that most compression methods require.

Why compression matters

Three practical reasons drive the need for smaller models:

  1. Latency: a self-driving car can’t wait 200ms for a prediction. Real-time applications need inference in single-digit milliseconds.
  2. Memory: mobile devices have limited RAM. A 400MB model won’t fit alongside the rest of an app.
  3. Energy: every FLOP costs energy. On battery-powered devices, fewer FLOPs means longer battery life. In data centers, less compute means lower electricity bills.

The good news: most neural networks are heavily overparameterized. They contain far more capacity than they need for the task. Compression exploits this redundancy.

Pruning: removing unnecessary weights

Pruning removes weights (or entire neurons/filters) that contribute little to the output. The simplest approach: remove weights with the smallest magnitude.

Unstructured vs structured pruning

Unstructured pruning zeroes out individual weights anywhere in the network. You get a sparse weight matrix. The compression ratio can be very high (90%+ of weights removed), but sparse matrix operations are not well supported on most hardware. You need special sparse libraries or hardware to see speedups.

Structured pruning removes entire filters, channels, or layers. The resulting network is a regular, smaller dense network that runs faster on standard hardware without special support. The compression ratio is usually lower, but the speedup is real and immediate.

graph LR
  A[Train full model] --> B[Rank weights by magnitude]
  B --> C[Remove smallest weights/filters]
  C --> D[Fine-tune to recover accuracy]
  D --> E{Accuracy acceptable?}
  E -->|No| B
  E -->|Yes| F[Deploy compressed model]

Example 1: Magnitude pruning

Weight matrix before pruning:

$$W = \begin{bmatrix} 0.80 & -0.10 & 0.50 \\ 0.02 & -0.70 & 0.30 \\ 0.15 & 0.60 & -0.04 \end{bmatrix}$$

Threshold: prune all weights with $|w| < 0.2$.

Check each weight:

  • $|0.80| = 0.80 \geq 0.2$ ✓ keep
  • $|-0.10| = 0.10 < 0.2$ ✗ prune
  • $|0.50| = 0.50 \geq 0.2$ ✓ keep
  • $|0.02| = 0.02 < 0.2$ ✗ prune
  • $|-0.70| = 0.70 \geq 0.2$ ✓ keep
  • $|0.30| = 0.30 \geq 0.2$ ✓ keep
  • $|0.15| = 0.15 < 0.2$ ✗ prune
  • $|0.60| = 0.60 \geq 0.2$ ✓ keep
  • $|-0.04| = 0.04 < 0.2$ ✗ prune

Sparse matrix after pruning:

$$W_{\text{pruned}} = \begin{bmatrix} 0.80 & 0 & 0.50 \\ 0 & -0.70 & 0.30 \\ 0 & 0.60 & 0 \end{bmatrix}$$

We removed 4 out of 9 weights. Compression ratio: $9/5 = 1.8\times$. If we stored only non-zero values plus indices, we'd need $5 \times (\text{value} + \text{index})$ instead of $9 \times \text{value}$.

In practice, you’d fine-tune the remaining weights for a few epochs to recover accuracy lost from pruning.
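The magnitude-pruning step above can be sketched in a few lines of NumPy. This is a minimal illustration of thresholding the example matrix, not a production pruning routine:

```python
import numpy as np

# Weight matrix from the worked example above.
W = np.array([
    [0.80, -0.10,  0.50],
    [0.02, -0.70,  0.30],
    [0.15,  0.60, -0.04],
])

threshold = 0.2
mask = np.abs(W) >= threshold   # True where a weight survives pruning
W_pruned = W * mask             # zero out the pruned weights

print(W_pruned)
print("weights kept:", int(mask.sum()), "of", W.size)      # 5 of 9
print("compression ratio:", W.size / int(mask.sum()))       # 1.8
```

In a real pipeline the mask would be kept fixed while the surviving weights are fine-tuned, so pruned weights stay at zero during the recovery epochs.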

The lottery ticket hypothesis

Frankle and Carbin (2019) proposed a striking idea: within a randomly initialized dense network, there exists a sparse subnetwork (the “winning ticket”) that, when trained in isolation from the same initialization, reaches the same accuracy as the full network.

The practical implication: you can find small networks that work just as well as large ones. The catch is that finding the winning ticket currently requires training the full network first, then pruning, then rewinding to the original initialization and retraining. This is expensive, but it tells us something deep about overparameterization.

Quantization: fewer bits per weight

Standard neural networks use 32-bit floating point (float32) for weights and activations. Quantization reduces this to 16-bit, 8-bit, or even lower. Fewer bits means less memory, faster computation, and lower energy.

Post-training quantization (PTQ): take a trained float32 model and convert weights to int8. No retraining needed. Simple but can lose accuracy, especially at very low bit widths.

The mapping from float32 to int8:

$$q = \text{round}\left(\frac{x - x_{\min}}{x_{\max} - x_{\min}} \times 255\right)$$

$$\hat{x} = \frac{q}{255} \times (x_{\max} - x_{\min}) + x_{\min}$$
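These two formulas can be sketched directly in NumPy. This is a minimal affine (min-max) quantization demo using the unsigned 0-255 range exactly as written above; the function names are illustrative, and real toolkits also quantize per-channel and calibrate activation ranges:

```python
import numpy as np

def quantize_uint8(x):
    """Map float values onto the integer grid 0..255 (affine min-max quantization)."""
    x_min, x_max = float(x.min()), float(x.max())
    q = np.round((x - x_min) / (x_max - x_min) * 255.0).astype(np.uint8)
    return q, x_min, x_max

def dequantize_uint8(q, x_min, x_max):
    """Recover approximate float values from the 8-bit codes."""
    return q.astype(np.float32) / 255.0 * (x_max - x_min) + x_min

w = np.array([-0.70, -0.10, 0.02, 0.30, 0.80], dtype=np.float32)
q, lo, hi = quantize_uint8(w)
w_hat = dequantize_uint8(q, lo, hi)

# Rounding error is at most half a quantization step: (hi - lo) / 255 / 2
print("max reconstruction error:", np.max(np.abs(w - w_hat)))
```

The endpoints $x_{\min}$ and $x_{\max}$ are reconstructed exactly (they map to codes 0 and 255); everything in between is off by at most half a step.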

Quantization-aware training (QAT): simulate quantization during training. Forward passes use quantized values; backward passes use the straight-through estimator (gradients flow through the rounding operation as if it were the identity). This gives the network a chance to adapt to the quantization noise.

Mixed-precision: use lower precision (float16 or int8) where it doesn’t hurt and full precision where it does. Typically, the first and last layers are kept at higher precision because they handle raw inputs and final logits.

Key numbers to remember:

  • float32 to float16: 2x memory reduction, minimal accuracy loss
  • float32 to int8: 4x memory reduction, usually < 1% accuracy loss with QAT
  • float32 to int4: 8x memory reduction, requires careful calibration

Knowledge distillation: teacher-student learning

Knowledge distillation trains a small “student” network to mimic a large “teacher” network. The key insight from Hinton et al. (2015): the teacher’s soft probability outputs contain more information than hard labels.

When a teacher classifies an image of a cat, its softmax output might be [0.7, 0.2, 0.1] for [cat, dog, car]. The hard label is just “cat.” But the soft output tells you that this image looks somewhat like a dog and not at all like a car. This “dark knowledge” helps the student learn better representations.

graph TD
  Input["Input x"] --> Teacher["Teacher (large model)"]
  Input --> Student["Student (small model)"]
  Teacher --> SoftT["Soft targets (temperature T)"]
  Student --> SoftS["Soft predictions (temperature T)"]
  SoftT --> KL["KL divergence loss"]
  SoftS --> KL
  Student --> HardS["Hard predictions"]
  Labels["True labels"] --> CE["Cross-entropy loss"]
  HardS --> CE
  KL --> Total["Total loss = α·KL + (1-α)·CE"]
  CE --> Total

The temperature parameter $T$ controls how soft the distributions are. At $T = 1$, you get the standard softmax. At higher $T$, the distribution becomes smoother, revealing more information about relative similarities:

$$p_i = \frac{\exp(z_i / T)}{\sum_j \exp(z_j / T)}$$

Example 2: Knowledge distillation with temperature

Teacher logits: $z^T = [2.8, 0.8, 0.4]$

At $T = 1$ (standard softmax):

$$p^T = \frac{[\exp(2.8), \exp(0.8), \exp(0.4)]}{\exp(2.8) + \exp(0.8) + \exp(0.4)} = \frac{[16.44, 2.23, 1.49]}{20.16} = [0.816, 0.111, 0.074]$$

At $T = 4$ (soft targets):

$$z^T / 4 = [0.70, 0.20, 0.10]$$

$$p^T_{T=4} = \frac{[\exp(0.70), \exp(0.20), \exp(0.10)]}{\exp(0.70) + \exp(0.20) + \exp(0.10)} = \frac{[2.014, 1.221, 1.105]}{4.340} = [0.464, 0.281, 0.255]$$

Student logits: $z^S = [2.1, 0.8, 0.3]$

At $T = 4$:

$$z^S / 4 = [0.525, 0.200, 0.075]$$

$$p^S_{T=4} = \frac{[\exp(0.525), \exp(0.200), \exp(0.075)]}{\text{sum}} = \frac{[1.691, 1.221, 1.078]}{3.990} = [0.424, 0.306, 0.270]$$

KL divergence of the student's soft predictions from the teacher's (at $T = 4$):

$$D_{KL}(p^T \| p^S) = \sum_i p^T_i \log\frac{p^T_i}{p^S_i}$$

$$= 0.464 \log\frac{0.464}{0.424} + 0.281 \log\frac{0.281}{0.306} + 0.255 \log\frac{0.255}{0.270}$$

$$= 0.464 \times 0.090 + 0.281 \times (-0.085) + 0.255 \times (-0.057)$$

$$= 0.0418 - 0.0239 - 0.0145 = 0.0034$$

The KL divergence is small (0.0034), meaning the student's soft predictions are close to the teacher's. Notice how $T = 4$ spreads the probability mass, making the teacher's “dark knowledge” visible. The student can learn that class 2 (0.281) is more similar to the input than class 3 (0.255).
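The temperature softmax and the KL term of Example 2 can be reproduced in a few lines of NumPy. A minimal sketch (the small discrepancy in the last digit versus the hand calculation comes from rounding the probabilities to three decimals above):

```python
import numpy as np

def softmax_T(z, T=1.0):
    """Softmax with temperature: p_i = exp(z_i / T) / sum_j exp(z_j / T)."""
    z = np.asarray(z, dtype=np.float64) / T
    e = np.exp(z - z.max())          # subtract max for numerical stability
    return e / e.sum()

def kl(p, q):
    """D_KL(p || q) for two discrete distributions."""
    return float(np.sum(p * np.log(p / q)))

z_teacher = [2.8, 0.8, 0.4]
z_student = [2.1, 0.8, 0.3]

p_T = softmax_T(z_teacher, T=4.0)
p_S = softmax_T(z_student, T=4.0)

print(np.round(p_T, 3))              # [0.464 0.281 0.255]
print(round(kl(p_T, p_S), 4))        # ≈ 0.003, matching the worked example
```

In an actual distillation loss, this KL term (scaled by $T^2$ in Hinton et al.'s formulation, so its gradient magnitude stays comparable across temperatures) is mixed with the hard-label cross-entropy.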

Low-rank factorization

A weight matrix $W \in \mathbb{R}^{m \times n}$ can be approximated by two smaller matrices using SVD:

$$W \approx U_r \Sigma_r V_r^T$$

where $U_r \in \mathbb{R}^{m \times r}$, $\Sigma_r \in \mathbb{R}^{r \times r}$, $V_r \in \mathbb{R}^{n \times r}$, and $r \ll \min(m, n)$.

This replaces one layer with two smaller layers. The original layer computes $Wx$ at cost $mn$. The factorized version computes $U_r(\Sigma_r(V_r^T x))$ at cost $nr + r + mr = (m + n)r + r$.

Example 3: Low-rank approximation savings

Consider a $4 \times 4$ weight matrix with singular values $[5.0, 3.0, 0.1, 0.05]$.

The first two singular values (5.0 and 3.0) capture most of the energy. The last two (0.1 and 0.05) are tiny. Keeping rank $r = 2$:

  • Original parameters: $4 \times 4 = 16$
  • Factorized: $U_r$ is $4 \times 2$, $\Sigma_r$ is $2 \times 2$ (diagonal, so 2 values), $V_r$ is $4 \times 2$. Total: $8 + 2 + 8 = 18$.

For this tiny matrix, factorization actually uses more parameters. The savings come with larger matrices.

Scaling to a realistic layer: $W$ is $100 \times 100$, rank $r = 10$.

  • Original: $100 \times 100 = 10{,}000$ parameters
  • Factorized: $100 \times 10 + 10 + 10 \times 100 = 1{,}000 + 10 + 1{,}000 = 2{,}010$ parameters
  • Compression ratio: $10{,}000 / 2{,}010 \approx 4.98\times$
  • Energy captured: $(5.0^2 + 3.0^2) / (5.0^2 + 3.0^2 + 0.1^2 + 0.05^2) = 34.0 / 34.0125 = 99.96\%$

So we keep 99.96% of the information with 5x fewer parameters. In practice, you'd merge $\Sigma_r$ into $U_r$ or $V_r$ to avoid the extra diagonal matrix.
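This factorization can be sketched with NumPy's SVD. The example below builds a synthetic, approximately rank-10 $100 \times 100$ matrix (the matrix and noise level are illustrative) and folds the singular values into $U_r$, as suggested above:

```python
import numpy as np

rng = np.random.default_rng(0)
m = n = 100
r = 10

# Synthetic weight matrix: a rank-10 product plus a little noise.
W = rng.normal(size=(m, r)) @ rng.normal(size=(r, n)) + 0.01 * rng.normal(size=(m, n))

U, s, Vt = np.linalg.svd(W, full_matrices=False)   # singular values in descending order
A = U[:, :r] * s[:r]    # m x r, with Sigma_r merged into U_r
B = Vt[:r, :]           # r x n
W_approx = A @ B

params_full = m * n
params_factored = A.size + B.size
print("compression ratio:", params_full / params_factored)   # 5.0 with Sigma merged
print("relative error:", np.linalg.norm(W - W_approx) / np.linalg.norm(W))
```

Because $\Sigma_r$ is absorbed into $A$, the factorized layer stores $2{,}000$ parameters rather than $2{,}010$, and the forward pass is just two matrix multiplies: `A @ (B @ x)`.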

Mobile architectures: depthwise separable convolutions

Instead of compressing an existing model, you can design efficient architectures from scratch. MobileNet uses depthwise separable convolutions, which factor a standard convolution into two steps:

  1. Depthwise convolution: apply one filter per input channel (no cross-channel mixing)
  2. Pointwise convolution: 1x1 convolution to mix channels

A standard convolution with kernel $k \times k$, $C_{in}$ input channels, and $C_{out}$ output channels costs:

$$k^2 \cdot C_{in} \cdot C_{out} \cdot H \cdot W \text{ FLOPs}$$

Depthwise separable convolution costs:

$$k^2 \cdot C_{in} \cdot H \cdot W + C_{in} \cdot C_{out} \cdot H \cdot W$$

The ratio:

$$\frac{k^2 \cdot C_{in} + C_{in} \cdot C_{out}}{k^2 \cdot C_{in} \cdot C_{out}} = \frac{1}{C_{out}} + \frac{1}{k^2}$$

For $k = 3$ and $C_{out} = 256$: reduction factor $\approx 1/256 + 1/9 \approx 0.115$. That's roughly 8-9x fewer FLOPs.
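The two cost formulas can be checked numerically. A small sketch (the layer dimensions $C_{in} = 128$ and $H = W = 56$ are illustrative; the ratio depends only on $k$ and $C_{out}$):

```python
def conv_flops(k, c_in, c_out, h, w):
    """Multiply-accumulate count of a standard k x k convolution."""
    return k * k * c_in * c_out * h * w

def depthwise_separable_flops(k, c_in, c_out, h, w):
    """Depthwise (one k x k filter per channel) plus pointwise (1x1) convolution."""
    return k * k * c_in * h * w + c_in * c_out * h * w

k, c_in, c_out, h, w = 3, 128, 256, 56, 56
standard = conv_flops(k, c_in, c_out, h, w)
separable = depthwise_separable_flops(k, c_in, c_out, h, w)

print(separable / standard)    # 1/256 + 1/9 ≈ 0.115
print(standard / separable)    # ≈ 8.7x fewer FLOPs
```

Note how $H$, $W$, and $C_{in}$ cancel out of the ratio, which is why the $1/C_{out} + 1/k^2$ reduction holds at every spatial resolution.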

[Figure: model size vs. accuracy after compression]

Compression methods comparison

| Method | Compression ratio | Accuracy drop | Hardware friendly | Training needed |
| --- | --- | --- | --- | --- |
| Unstructured pruning | 5-20x | 0.5-2% | ✗ Needs sparse support | Fine-tuning |
| Structured pruning | 2-5x | 1-3% | ✓ Standard dense ops | Fine-tuning |
| PTQ (int8) | 4x | 0.5-1% | ✓ Widely supported | None |
| QAT (int8) | 4x | < 0.5% | ✓ Widely supported | Full retraining |
| Knowledge distillation | 3-10x (model dependent) | 1-3% | ✓ Student is standard | Full training |
| Low-rank factorization | 2-5x | 1-2% | ✓ Standard dense ops | Fine-tuning |
| Depthwise separable | 8-9x FLOPs | Architecture dependent | ✓ Optimized on mobile | Full training |

Combining methods

These methods are not mutually exclusive. A common pipeline:

  1. Start with a large, accurate teacher model
  2. Use NAS or manual design for a small student architecture (perhaps with depthwise separable convolutions)
  3. Train the student with knowledge distillation
  4. Apply quantization-aware training
  5. Optionally prune and fine-tune

Each step gives an independent compression factor. A 4x from distillation, 4x from quantization, and 2x from pruning gives 32x total compression.

What comes next

With efficient models ready for deployment, we can tackle a different class of problems: data that lives on graphs rather than grids. Graph neural networks extend the ideas from CNNs and attention mechanisms to irregular, non-Euclidean structures like social networks, molecules, and knowledge graphs.
