
Domain adaptation and fine-tuning strategies

In this series (25 parts)
  1. Neural networks: the basic building block
  2. Forward pass and backpropagation
  3. Training neural networks: a practical guide
  4. Convolutional neural networks
  5. Recurrent neural networks and LSTMs
  6. Attention mechanism and transformers
  7. Word embeddings: from one-hot to dense representations
  8. Transfer learning and fine-tuning
  9. Optimization techniques for deep networks
  10. Regularization for deep networks
  11. Encoder-decoder architectures
  12. Generative models: an overview
  13. Restricted Boltzmann Machines
  14. Deep Belief Networks
  15. Variational Autoencoders
  16. Generative Adversarial Networks: training and theory
  17. DCGAN, conditional GANs, and GAN variants
  18. Representation learning and self-supervised learning
  19. Domain adaptation and fine-tuning strategies
  20. Distributed representations and latent spaces
  21. AutoML and hyperparameter optimization
  22. Neural architecture search
  23. Network compression and efficient inference
  24. Graph neural networks
  25. Practical deep learning: debugging and tuning

Prerequisites

Before reading this article, make sure you are comfortable with the earlier entries in this series, in particular transfer learning and fine-tuning (part 8) and representation learning (part 18).

The domain shift problem

You trained on product photos from a studio. Clean backgrounds, consistent lighting, high resolution. Now you deploy on blurry phone photos taken by customers. Performance drops from 95% to 72%.

This is like moving to a new city where the rules are slightly different. You know how to drive, but the road signs look different, traffic flows in unexpected ways, and your GPS keeps recalculating. Your core skills still apply, but the surface details have changed.

| Metric | Source domain (studio photos) | Target domain (phone photos) |
| --- | --- | --- |
| Accuracy | 95% | 72% |
| Confidence | High | Drops on blurry inputs |
| Error type | Rare misclassifications | Systematic failures on new backgrounds |

Domain shift: source vs target distributions

graph LR
  A["Source distribution
Studio photos
Clean, consistent"] --> C["Same task
Same labels
Different features"]
  B["Target distribution
Phone photos
Noisy, varied"] --> C
  C --> D["Goal: learn features
that work for both"]

The core issue is that the training distribution (source domain) differs from the deployment distribution (target domain). The labels have not changed: a shoe is still a shoe. But the pixel patterns are different enough that the model struggles. We need strategies to bridge this gap.

Covariate shift vs dataset shift

Covariate shift: The input distribution changes ($p_S(x) \neq p_T(x)$), but the labeling function stays the same ($p(y|x)$ is identical). A cat is still a cat whether photographed with a DSLR or a phone camera. The relationship between features and labels hasn’t changed; the features themselves just look different.

Dataset shift (more general): Both $p(x)$ and $p(y|x)$ may change. Reviews from different time periods might use the same words differently, so the mapping from text to sentiment actually changes.

Label shift: $p(y)$ changes but $p(x|y)$ stays the same. A disease classifier trained where prevalence is 10% deployed where it’s 1%.

Most practical domain adaptation focuses on covariate shift, where the goal is to learn features that are invariant to the domain while remaining discriminative for the task.

Types of distribution shift

graph TD
  A["Distribution Shift"] --> B["Covariate Shift
P(x) changes
P(y|x) stays same"]
  A --> C["Label Shift
P(y) changes
P(x|y) stays same"]
  A --> D["Concept Drift
P(y|x) changes
The meaning of
features shifts"]
  B --> E["Example: DSLR vs phone
Same cat, different pixels"]
  C --> F["Example: disease prevalence
10% in training, 1% in deployment"]
  D --> G["Example: word meaning change
over time"]

Covariate shift is the most common and the most studied. The solutions below primarily target it, though some generalize to other shift types.

Feature extraction vs fine-tuning

Source vs target domain accuracy by adaptation method

The simplest approach to domain adaptation is transfer learning: take a pretrained model and adapt it to your target domain. Two strategies:

Feature extraction: Freeze the pretrained model. Use its output as a fixed feature vector. Train only a new classifier head on target data. This is safe when you have very little target data (say, under 100 examples per class), because you’re only training a small number of parameters and can’t overfit badly.

Full fine-tuning: Unfreeze the entire model and train on target data with a small learning rate. This is more powerful but riskier. If target data is scarce, the model can forget useful pretrained features (catastrophic forgetting) or overfit to the target training set.

The right choice depends on two factors: how much target data you have, and how different the source and target domains are.
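The feature-extraction strategy is only a few lines in PyTorch. This is a minimal sketch: the tiny MLP backbone and the layer sizes are placeholders standing in for a real pretrained network (e.g. a torchvision ResNet with its classification layer removed).

```python
import torch
import torch.nn as nn

# Hypothetical stand-in for a pretrained backbone.
backbone = nn.Sequential(nn.Linear(32, 64), nn.ReLU(), nn.Linear(64, 16))

# Feature extraction: freeze every backbone parameter...
for p in backbone.parameters():
    p.requires_grad = False

# ...and train only a fresh classifier head on the target classes.
num_target_classes = 5
head = nn.Linear(16, num_target_classes)
optimizer = torch.optim.Adam(head.parameters(), lr=1e-3)

x = torch.randn(8, 32)       # a batch of target-domain inputs
with torch.no_grad():        # the backbone acts as a fixed feature extractor
    feats = backbone(x)
logits = head(feats)         # only the head receives gradients
```

Because only the small head is trainable, this setup cannot badly overfit even with very little target data.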

flowchart TD
  START["Need to adapt model to new domain"] --> Q1{"How much target labeled data?"}
  Q1 -->|"< 100 per class"| Q2{"How different are domains?"}
  Q1 -->|"100-1000 per class"| FT_PARTIAL["Fine-tune top layers
Freeze bottom layers"]
  Q1 -->|"> 1000 per class"| FT_FULL["Full fine-tuning
with small lr"]

  Q2 -->|"Similar"| FE["Feature extraction
(freeze all, train head)"]
  Q2 -->|"Very different"| DA["Domain adaptation
(DANN, CORAL)"]

  FT_PARTIAL --> LR["Use layer-wise
learning rate decay"]

  style START fill:#9775fa,color:#fff
  style FE fill:#51cf66,color:#fff
  style FT_PARTIAL fill:#74c0fc,color:#fff
  style FT_FULL fill:#ffa94d,color:#fff
  style DA fill:#ff6b6b,color:#fff

Layer-wise learning rate decay

When fine-tuning, not all layers should learn at the same rate. Earlier layers capture general features (edges, textures) that transfer well. Later layers capture task-specific features that need more adaptation. Layer-wise learning rate decay assigns a smaller learning rate to earlier layers.

Given a base learning rate $\eta$ and a decay factor $\gamma \in (0, 1)$, the learning rate for layer $l$ (numbered from the bottom) is:

$\eta_l = \eta \cdot \gamma^{L - l}$

where $L$ is the total number of layers. The top layer ($l = L$, closest to the output) gets the full learning rate $\eta$. Each layer below it gets $\gamma$ times the rate of the layer above.

This prevents catastrophic forgetting of low-level features while allowing high-level features to adapt. The decay factor $\gamma$ is typically between 0.1 and 0.5. Smaller $\gamma$ means more aggressive freezing of early layers.

DANN: domain adversarial neural network

Domain Adversarial Neural Network (DANN, 2016) takes a more principled approach. The idea: train a feature extractor that produces representations which are (1) useful for the task and (2) indistinguishable between source and target domains.

The architecture has three components:

  1. Feature extractor $G_f$: maps input $x$ to features $f = G_f(x)$
  2. Label predictor $G_y$: maps features to task labels $\hat{y} = G_y(f)$
  3. Domain classifier $G_d$: maps features to domain labels $\hat{d} = G_d(f)$ (source vs target)

The training objective:

$\mathcal{L} = \mathcal{L}_y(G_y(G_f(x_s)), y_s) - \lambda \, \mathcal{L}_d(G_d(G_f(x)), d)$

The feature extractor minimizes the label prediction loss while maximizing the domain classification loss. In other words, it tries to create features that help predict labels but confuse the domain classifier.

The minus sign is implemented through a gradient reversal layer (GRL). During forward pass, the GRL is an identity function. During backward pass, it multiplies the gradient by λ-\lambda. This elegantly handles the adversarial objective with standard gradient descent.
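A gradient reversal layer takes only a few lines in PyTorch. This is a sketch of the standard pattern (a custom `autograd.Function`), not the original DANN code:

```python
import torch

class GradReverse(torch.autograd.Function):
    """Identity on the forward pass; scales gradients by -lambda on backward."""

    @staticmethod
    def forward(ctx, x, lambd):
        ctx.lambd = lambd
        return x.view_as(x)

    @staticmethod
    def backward(ctx, grad_output):
        # Flip the sign of the gradient flowing back into the feature extractor.
        return -ctx.lambd * grad_output, None  # no gradient for lambd

def grad_reverse(x, lambd=1.0):
    return GradReverse.apply(x, lambd)

# Forward is the identity; backward reverses the gradient.
x = torch.ones(3, requires_grad=True)
grad_reverse(x, lambd=0.5).sum().backward()
# x.grad is -0.5 everywhere: the domain loss gradient arrives sign-flipped
```

In a DANN model, `grad_reverse` sits between the feature extractor and the domain classifier, so the domain classifier trains normally while the feature extractor is pushed in the opposite direction.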

flowchart LR
  X["Input x"] --> GF["Feature extractor
Gf"]
  GF --> F["Features f"]
  F --> GY["Label predictor
Gy"]
  GY --> YHAT["ŷ (task label)"]

  F --> GRL["Gradient reversal
layer (×−λ)"]
  GRL --> GD["Domain classifier
Gd"]
  GD --> DHAT["d̂ (source/target)"]

  YHAT --> LY["Label loss ℒy
(minimize)"]
  DHAT --> LD["Domain loss ℒd
(maximize via GRL)"]

  style GRL fill:#ff6b6b,color:#fff
  style LY fill:#51cf66,color:#fff
  style LD fill:#ffa94d,color:#fff

The theory behind DANN connects to Ben-David’s domain adaptation bound. The target error is bounded by the source error plus the domain divergence (measured by the ability of a classifier to distinguish domains). By minimizing both terms, DANN minimizes the bound on target error.

CORAL: aligning statistics

CORrelation ALignment (CORAL) takes a simpler approach: align the second-order statistics (covariance matrices) of source and target features.

Given source features with covariance $C_S$ and target features with covariance $C_T$, the CORAL loss is:

$\mathcal{L}_{\text{CORAL}} = \frac{1}{4d^2} \|C_S - C_T\|_F^2$

where $d$ is the feature dimension and $\|\cdot\|_F$ is the Frobenius norm.

The $\frac{1}{4d^2}$ normalization makes the loss scale-invariant with respect to feature dimension. You can add this loss to any standard training objective as a regularization term:

$\mathcal{L}_{\text{total}} = \mathcal{L}_{\text{task}} + \alpha \, \mathcal{L}_{\text{CORAL}}$

CORAL is simple to implement and adds minimal computational overhead. It doesn’t require a separate domain classifier or gradient reversal. The limitation is that it only aligns second-order statistics. If the domain shift involves higher-order structure (e.g., different cluster shapes), CORAL may not be enough.

Deep CORAL applies this idea to features at intermediate layers of a deep network, not just the final features.
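A Deep-CORAL-style loss might look like the following sketch; the covariance estimator here is the usual unbiased sample covariance over the batch dimension.

```python
import torch

def coral_loss(source: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
    """CORAL loss between two feature batches of shape (n, d)."""
    d = source.size(1)

    def cov(m: torch.Tensor) -> torch.Tensor:
        # Unbiased sample covariance of an (n, d) batch.
        mc = m - m.mean(dim=0, keepdim=True)
        return mc.t() @ mc / (m.size(0) - 1)

    diff = cov(source) - cov(target)
    # Squared Frobenius norm, scaled by 1 / (4 d^2).
    return (diff * diff).sum() / (4 * d * d)
```

It would then be combined with the task objective, e.g. `loss = task_loss + alpha * coral_loss(source_feats, target_feats)`.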

Comparing adaptation strategies

graph TD
  A["Fine-tuning
Retrain top layers
on target data"] --> D["Needs labeled
target data"]
  B["Domain Adversarial
DANN
Fool domain classifier"] --> E["Works with
unlabeled target data"]
  C["Distribution Matching
CORAL
Align feature statistics"] --> F["Simple, minimal
overhead"]
  D --> G["Best when you have
100+ labeled examples"]
  E --> H["Best for significant
domain gap"]
  F --> I["Best as lightweight
regularizer"]

Fine-tuning is the default when you have labeled target data. DANN is more powerful when labels are scarce but unlabeled target data is plentiful. CORAL is the simplest to add and works well as a regularization term alongside any other objective.

Few-shot learning

Few-shot learning handles the extreme case: you have only a handful of labeled examples per class in the target domain. The standard setup is $N$-way $K$-shot: classify among $N$ classes, with only $K$ examples per class.

Prototypical networks: Compute a prototype (mean embedding) for each class from the $K$ examples. Classify a query by finding the nearest prototype in embedding space.

For class $c$ with support examples $S_c$:

$\mathbf{p}_c = \frac{1}{|S_c|} \sum_{x \in S_c} f(x)$

Classification uses softmax over negative distances:

$p(y = c \mid x) = \frac{\exp(-d(f(x), \mathbf{p}_c))}{\sum_{c'} \exp(-d(f(x), \mathbf{p}_{c'}))}$

where $d$ is typically squared Euclidean distance.
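Putting the two formulas together, a prototypical-network classifier is only a few lines. This sketch assumes the embeddings are already computed by some encoder and uses squared Euclidean distance:

```python
import torch

def proto_classify(support, support_labels, query, num_classes):
    """Prototypical-network prediction from precomputed embeddings.

    support: (n_support, d), support_labels: (n_support,), query: (n_query, d).
    """
    # Prototype = mean embedding of each class's support examples.
    protos = torch.stack([
        support[support_labels == c].mean(dim=0) for c in range(num_classes)
    ])
    # Squared Euclidean distance from every query to every prototype.
    dists = torch.cdist(query, protos) ** 2
    # Softmax over negative distances: closer prototype -> higher probability.
    return torch.softmax(-dists, dim=1)
```

Note that nothing here is trained at classification time; all the learning happens when the encoder is trained across episodes.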

Meta-learning (learning to learn): Train the model across many few-shot tasks (episodes) sampled from a large dataset. Each episode simulates the few-shot scenario. The model learns to quickly adapt to new classes from few examples. MAML (Model-Agnostic Meta-Learning) is the canonical example: it finds model parameters that can be adapted to any new task with just a few gradient steps.

Transfer scenarios comparison

| Source domain | Target domain | Target data size | Recommended strategy | Expected gain |
| --- | --- | --- | --- | --- |
| ImageNet (photos) | Product photos | > 5000 labeled | Full fine-tuning | High (similar domains) |
| ImageNet (photos) | Medical X-rays | 100-500 labeled | Feature extraction + linear head | Moderate (different domains, limited data) |
| ImageNet (photos) | Satellite imagery | 1000-5000 labeled | Fine-tune top layers, freeze bottom | Moderate to high |
| English NLP | French NLP | > 10000 labeled | Full fine-tuning multilingual model | High with multilingual pretrain |
| Hospital A scans | Hospital B scans | 50-200 labeled | DANN or CORAL + fine-tune | Moderate (covariate shift) |
| Natural images | Artistic paintings | No labels in target | Unsupervised DA (DANN, CycleGAN features) | Variable |

Example 1: layer-wise learning rate decay

A 5-layer network with base learning rate $\eta = 0.001$ and decay factor $\gamma = 0.3$.

Layers are numbered 1 (bottom/earliest) to 5 (top/closest to output).

$\eta_l = \eta \cdot \gamma^{L - l}$

| Layer | Computation | Learning rate |
| --- | --- | --- |
| 5 (top) | $0.001 \times 0.3^{5-5} = 0.001 \times 0.3^0$ | 0.001 |
| 4 | $0.001 \times 0.3^{5-4} = 0.001 \times 0.3^1$ | 0.0003 |
| 3 | $0.001 \times 0.3^{5-3} = 0.001 \times 0.3^2$ | 0.00009 |
| 2 | $0.001 \times 0.3^{5-2} = 0.001 \times 0.3^3$ | 0.000027 |
| 1 (bottom) | $0.001 \times 0.3^{5-1} = 0.001 \times 0.3^4$ | 0.0000081 |

Layer 1 learns ~123x slower than layer 5. The bottom layers (which capture general features like edges and textures) are nearly frozen, while the top layers (task-specific features) adapt freely.

In practice, you’d implement this in your optimizer by creating parameter groups. In PyTorch (assuming `model.layers` holds the layers ordered bottom to top):

```python
# One parameter group per layer, each with its own learning rate.
params = []
for i, layer in enumerate(model.layers):  # i = 0 is the bottom layer
    lr = base_lr * (decay ** (num_layers - 1 - i))  # top layer gets base_lr
    params.append({"params": layer.parameters(), "lr": lr})
optimizer = torch.optim.Adam(params)
```

If the decay were $\gamma = 0.1$ instead, layer 1 would get $\eta_1 = 0.001 \times 0.1^4 = 0.0000001$, essentially frozen. With $\gamma = 0.5$, it would get $0.001 \times 0.5^4 = 0.0000625$, still learning but slowly.

Example 2: CORAL loss

Source features have covariance matrix $C_S$ and target features have covariance matrix $C_T$.

$C_S = \begin{bmatrix} 2 & 1 \\ 1 & 3 \end{bmatrix}, \quad C_T = \begin{bmatrix} 1 & 0 \\ 0 & 2 \end{bmatrix}$

Feature dimension $d = 2$.

Step 1: Compute the difference

$C_S - C_T = \begin{bmatrix} 2-1 & 1-0 \\ 1-0 & 3-2 \end{bmatrix} = \begin{bmatrix} 1 & 1 \\ 1 & 1 \end{bmatrix}$

Step 2: Frobenius norm squared

The Frobenius norm squared is the sum of squares of all elements:

$\|C_S - C_T\|_F^2 = 1^2 + 1^2 + 1^2 + 1^2 = 4$

Step 3: CORAL loss

$\mathcal{L}_{\text{CORAL}} = \frac{1}{4d^2} \|C_S - C_T\|_F^2 = \frac{1}{4 \times 2^2} \times 4 = \frac{4}{16} = 0.25$

What does this number mean? The source features have positive correlation between dimensions (off-diagonal = 1), while target features are uncorrelated (off-diagonal = 0). The source also has higher variance in both dimensions. The CORAL loss penalizes these statistical differences.

If we add this to a classification loss with weight $\alpha = 1.0$:

$\mathcal{L}_{\text{total}} = \mathcal{L}_{\text{task}} + 1.0 \times 0.25 = \mathcal{L}_{\text{task}} + 0.25$

During training, the gradient of $\mathcal{L}_{\text{CORAL}}$ pushes the feature extractor to produce features whose source and target covariances align. Over time, $C_S$ and $C_T$ should converge, and the CORAL loss should approach zero.
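The arithmetic in this example can be double-checked in a few lines of plain Python:

```python
# Covariance matrices from the worked example above.
C_S = [[2.0, 1.0], [1.0, 3.0]]
C_T = [[1.0, 0.0], [0.0, 2.0]]
d = 2

# Squared Frobenius norm of the elementwise difference.
frob_sq = sum((C_S[i][j] - C_T[i][j]) ** 2 for i in range(d) for j in range(d))
coral = frob_sq / (4 * d * d)
# frob_sq = 4.0 and coral = 0.25, matching the hand computation
```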

Example 3: few-shot 5-way 1-shot classification

We have 5 classes, 1 support example per class. A query image arrives and we need to classify it.

The encoder ff maps each image to an embedding. The 5 class prototypes (just the single support embedding for each class) are already computed. We compute the Euclidean distance from the query embedding to each prototype.

Given distances from the query to each class prototype:

$d = [0.3, 0.8, 0.6, 0.9, 0.4]$

These are distances, so smaller means more similar. Prototypical networks use negative distances in a softmax:

$p(y = c \mid x) = \frac{\exp(-d_c)}{\sum_{c'=1}^{5} \exp(-d_{c'})}$

Step 1: Compute $\exp(-d_c)$ for each class

$\exp(-0.3) = 0.7408$, $\exp(-0.8) = 0.4493$, $\exp(-0.6) = 0.5488$, $\exp(-0.9) = 0.4066$, $\exp(-0.4) = 0.6703$

Step 2: Sum

$Z = 0.7408 + 0.4493 + 0.5488 + 0.4066 + 0.6703 = 2.8158$

Step 3: Softmax probabilities

$p(y=1) = \frac{0.7408}{2.8158} = 0.2631$, $p(y=2) = \frac{0.4493}{2.8158} = 0.1596$, $p(y=3) = \frac{0.5488}{2.8158} = 0.1949$, $p(y=4) = \frac{0.4066}{2.8158} = 0.1444$, $p(y=5) = \frac{0.6703}{2.8158} = 0.2380$

Prediction: Class 1 with probability 0.2631.

The model is not very confident. The margin between class 1 (0.263) and class 5 (0.238) is small. This is typical in 1-shot settings. With 5-shot (5 examples per class), prototypes would be more robust (averaging 5 embeddings reduces noise), and confidence would increase. The distances would also typically separate more because the prototypes better represent the class centers.
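The three steps above can be reproduced in a few lines of plain Python:

```python
import math

# Query-to-prototype distances from the worked example.
dists = [0.3, 0.8, 0.6, 0.9, 0.4]

exps = [math.exp(-di) for di in dists]   # step 1: exponentiate negatives
Z = sum(exps)                            # step 2: normalizing constant
probs = [e / Z for e in exps]            # step 3: softmax probabilities
# probs[0] ≈ 0.263 (class 1 wins), probs[4] ≈ 0.238 (class 5 close behind)
```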

Practical fine-tuning guide

  1. Always start with the simplest approach. Try feature extraction (frozen backbone + linear head) first. If that works well enough, stop. If not, try fine-tuning the top few layers. Only resort to full fine-tuning or domain adaptation methods if simpler approaches fail.

  2. Use a smaller learning rate for fine-tuning than for training from scratch. A common starting point: $\frac{1}{10}$ of the original pretraining learning rate.

  3. Watch for overfitting. With small target datasets, validation loss will start increasing quickly. Use early stopping, regularization (dropout, weight decay), and data augmentation.

  4. Gradual unfreezing is an alternative to layer-wise lr decay. Start with all layers frozen except the head. Train for a few epochs. Unfreeze the top layer. Train more. Unfreeze the next layer. This gives you fine-grained control and is less sensitive to learning rate choices.

  5. Domain adaptation methods (DANN, CORAL) are most valuable when you have unlabeled target data. If you have labeled target data, fine-tuning is usually simpler and works just as well.
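The gradual-unfreezing schedule from point 4 can be sketched as follows; the four-layer backbone and the epoch thresholds are placeholders, not a prescription:

```python
import torch.nn as nn

# Hypothetical backbone as an ordered stack of layers (bottom to top) plus a head.
layers = nn.ModuleList([nn.Linear(16, 16) for _ in range(4)])
head = nn.Linear(16, 3)

# Start with the whole backbone frozen; only the head trains.
for layer in layers:
    for p in layer.parameters():
        p.requires_grad = False

def unfreeze_top(n: int) -> None:
    """Unfreeze the top n backbone layers (those closest to the output)."""
    for layer in list(layers)[-n:]:
        for p in layer.parameters():
            p.requires_grad = True

# A schedule might then look like: train a few epochs, call unfreeze_top(1),
# train more, call unfreeze_top(2), and so on until the whole stack is trainable.
```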

Summary

Domain shift is inevitable in real applications. The right adaptation strategy depends on how much labeled target data you have and how different the domains are. Feature extraction is safe with little data. Fine-tuning with layer-wise learning rate decay works well with moderate data. DANN and CORAL are useful when you have unlabeled target data and significant domain shift. Few-shot methods handle the extreme case of just a handful of examples.

The common thread: good representations transfer. The better your pretrained features, the less adaptation you need.

What comes next

The next article on distributed representations and latent spaces dives into the geometry of learned representations. You’ll see how word embedding analogies work, why latent spaces have linear structure, how to measure disentanglement, and how latent arithmetic in GANs and VAEs enables controlled generation.
