Case study: implementing a convolution layer in CUDA
In this series (28 parts)
- GPUs: from pixels to parallel supercomputers
- Your first CUDA program: kernels, threads, and grids
- Thread hierarchy in CUDA: threads, blocks, warps, and grids
- Warps and warp divergence: the hidden performance trap
- CUDA memory hierarchy: where your data lives matters
- Memory coalescing: the most important optimization you will learn
- Shared memory and tiling: the key to fast matrix operations
- Debugging and profiling CUDA programs
- Device functions, host functions, and CUDA function qualifiers
- Synchronization and atomic operations in CUDA
- Parallel prefix sum and reduction: the core parallel primitives
- Concurrent data structures on the GPU
- CUDA streams and asynchronous execution
- CUDA events and fine-grained synchronization
- Dynamic parallelism: kernels launching kernels
- Unified virtual memory: one pointer for CPU and GPU
- Multi-GPU programming and peer access
- Memory allocation patterns and multi-dimensional arrays in CUDA
- Texture and constant memory: specialized caches
- Occupancy, register pressure, and performance tuning
- Case study: matrix multiplication from naive to cuBLAS speed
- Case study: implementing a convolution layer in CUDA
- Case study: reduction and histogram at scale
- Heterogeneous computing: CPU and GPU working together
- Advanced memory patterns: pinned memory, zero-copy, and more
- Advanced stream patterns and concurrent kernel execution
- Performance case studies and optimization patterns
- Where to go from here: CUDA ecosystem and next steps
Prerequisites
This article assumes you have read the following:
- Case study: matrix multiply for tiled GEMM kernels, register tiling, and cuBLAS performance characteristics.
- Shared memory and tiling for the tiling pattern, bank conflicts, and synchronization.
You should be comfortable with convolutional neural networks at an algorithmic level (filters, stride, padding, feature maps) and with matrix operations including how matrix dimensions propagate through multiplication. This article focuses on implementing those operations efficiently on a GPU.
Why convolution is a GPU problem
A single convolution layer in ResNet-50 can require over 100 million multiply-add operations. The entire forward pass of the network performs roughly 4 billion FLOPs. CPUs process these sequentially across output pixels. GPUs process them in parallel, turning minutes into milliseconds.
The challenge is not raw arithmetic throughput. Modern GPUs have more than enough compute for convolution. The challenge is keeping that compute fed with data. Convolution has a low arithmetic intensity when implemented naively: each input element is loaded from global memory many times across overlapping filter windows. Every optimization in this article is fundamentally about reducing redundant memory traffic.
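To put a number on that redundancy, a few lines of arithmetic (using the layer dimensions from the worked example later in this article) show how many times a naive kernel touches each input element:

```python
# Layer from the worked example: 32x32 input, 3 channels, 64 filters of 3x3,
# stride 1, padding 1 (illustrative numbers, not a measurement).
C, H, W = 3, 32, 32          # input channels and spatial size
K, R, S = 64, 3, 3           # filter count and filter size
OH, OW = 32, 32              # output spatial size

naive_loads = K * OH * OW * C * R * S   # one global load per filter tap per output
unique_elems = C * H * W                # distinct input values that exist
print(naive_loads, unique_elems, naive_loads // unique_elems)
# 1769472 3072 576 -- each input element is re-read ~576 times
```

Caches absorb some of this in practice, but the kernel still issues every one of those loads.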
Direct 2D convolution kernel
The most straightforward approach maps one CUDA thread to one output pixel. Each thread slides the filter over its corresponding input patch, accumulates the dot product across all input channels, and writes one output value.
```cuda
__global__ void conv2d_direct(
    const float* input,   // C x H x W (one image)
    const float* filter,  // K x C x R x S
    float* output,        // K x OH x OW
    int C, int H, int W,
    int K, int R, int S,
    int OH, int OW,
    int stride, int pad)
{
    int ow = blockIdx.x * blockDim.x + threadIdx.x;
    int oh = blockIdx.y * blockDim.y + threadIdx.y;
    int k = blockIdx.z;
    if (ow >= OW || oh >= OH) return;

    float sum = 0.0f;
    for (int c = 0; c < C; c++) {
        for (int r = 0; r < R; r++) {
            for (int s = 0; s < S; s++) {
                int ih = oh * stride - pad + r;
                int iw = ow * stride - pad + s;
                if (ih >= 0 && ih < H && iw >= 0 && iw < W) {
                    sum += input[c * H * W + ih * W + iw]
                         * filter[k * C * R * S + c * R * S + r * S + s];
                }
            }
        }
    }
    output[k * OH * OW + oh * OW + ow] = sum;
}
```
This kernel works correctly but has serious performance problems:
- Redundant global loads. Adjacent output pixels share most of their input data (overlapping receptive fields), but each thread loads independently.
- No data reuse. The filter weights are the same for every output pixel in a given output channel, yet every thread re-reads them from global memory.
- Branch divergence. The boundary checks (ih >= 0 && ih < H and iw >= 0 && iw < W) cause warp divergence for threads near the edges.
On an A100, this direct kernel achieves roughly 500 GFLOPS for a 3x3 convolution on 32x32 input. The card is capable of 19,500 GFLOPS (FP32). That is 2.5% utilization.
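Before optimizing, it helps to have a correctness oracle. This sketch (mine, not part of the article's CUDA code) mirrors the kernel's loop and indexing structure in NumPy, so a CUDA implementation can be checked against it for off-by-one errors in stride and padding:

```python
import numpy as np

def conv2d_direct_ref(x, filt, stride=1, pad=1):
    """NumPy mirror of conv2d_direct's loops and indexing (single image).

    x: (C, H, W), filt: (K, C, R, S) -> output (K, OH, OW).
    Pure-Python loops, so keep shapes small; this is a correctness
    oracle, not a fast path.
    """
    C, H, W = x.shape
    K, _, R, S = filt.shape
    OH = (H + 2 * pad - R) // stride + 1
    OW = (W + 2 * pad - S) // stride + 1
    out = np.zeros((K, OH, OW), dtype=np.float32)
    for k in range(K):
        for oh in range(OH):
            for ow in range(OW):
                acc = 0.0
                for c in range(C):
                    for r in range(R):
                        for s in range(S):
                            ih = oh * stride - pad + r
                            iw = ow * stride - pad + s
                            if 0 <= ih < H and 0 <= iw < W:
                                acc += x[c, ih, iw] * filt[k, c, r, s]
                out[k, oh, ow] = acc
    return out

# Smoke test: a 1x1 identity filter with no padding reproduces the input.
x = np.arange(16, dtype=np.float32).reshape(1, 4, 4)
w = np.ones((1, 1, 1, 1), dtype=np.float32)
assert np.allclose(conv2d_direct_ref(x, w, stride=1, pad=0)[0], x[0])
```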
Convolution as matrix multiply: the im2col transformation
The key insight behind every high-performance convolution implementation is that convolution can be reformulated as a single large matrix multiplication. The technique is called im2col (image to column): extract every input patch that a filter will touch, lay those patches out as columns of a matrix, then multiply by the filter matrix.
graph TD
subgraph "Input (C=3, H=32, W=32)"
A["3 x 32 x 32 tensor"]
end
subgraph "im2col: extract patches"
B["Patch at (0,0): 3x3x3 = 27 values"]
C["Patch at (0,1): 3x3x3 = 27 values"]
D["..."]
E["Patch at (31,31): 3x3x3 = 27 values"]
end
subgraph "Column Matrix (27 x 1024)"
F["Each column = one flattened patch"]
end
subgraph "Filter Matrix (64 x 27)"
G["Each row = one filter flattened"]
end
subgraph "Output Matrix (64 x 1024)"
H["Reshape to 64 x 32 x 32"]
end
A --> B
A --> C
A --> E
B --> F
C --> F
E --> F
G --- I["GEMM: (64x27) x (27x1024)"]
F --- I
I --> H
Worked example: shape arithmetic
Consider a concrete layer: input 32 x 32 with 3 channels, 64 filters of size 3 x 3, stride 1, padding 1.
Output spatial dimensions:
OH = (H + 2 * pad - R) / stride + 1 = (32 + 2 - 3) / 1 + 1 = 32
OW = (W + 2 * pad - S) / stride + 1 = (32 + 2 - 3) / 1 + 1 = 32
im2col matrix dimensions:
Each patch covers C * R * S = 3 * 3 * 3 = 27 values. There are OH * OW = 32 * 32 = 1024 patches. The im2col matrix is 27 x 1024.
Filter matrix dimensions:
K filters, each of size C * R * S = 27. The filter matrix is 64 x 27.
GEMM:
(64 x 27) * (27 x 1024) = 64 x 1024
Reshape to 64 x 32 x 32. That is the output feature map.
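The same shape arithmetic, expressed as a short Python check (the helper name conv_out_dims is mine):

```python
def conv_out_dims(H, W, R, S, stride, pad):
    """Output spatial dimensions of a 2D convolution (floor division)."""
    OH = (H + 2 * pad - R) // stride + 1
    OW = (W + 2 * pad - S) // stride + 1
    return OH, OW

C, H, W, K, R, S = 3, 32, 32, 64, 3, 3
OH, OW = conv_out_dims(H, W, R, S, stride=1, pad=1)
print((OH, OW))               # (32, 32)
print((C * R * S, OH * OW))   # im2col matrix: (27, 1024)
print((K, C * R * S))         # filter matrix: (64, 27)
```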
Convolution FLOPs
Total multiply-adds = K * C * R * S * OH * OW = 64 * 3 * 3 * 3 * 32 * 32 = 1,769,472.
Each multiply-add is 2 FLOPs (one multiply, one add), giving 3,538,944 FLOPs total.
On a GPU achieving 10 TFLOPS effective throughput, this single layer takes 3,538,944 / 10¹³ ≈ 0.35 microseconds of pure compute. In practice, memory latency and kernel launch overhead dominate at this scale. Larger layers (e.g., 56 x 56 x 64 input with 128 filters) produce FLOP counts in the hundreds of millions, where compute time becomes measurable and optimization matters.
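The FLOP count and compute-time bound can be verified in a few lines:

```python
# Multiply-adds for the worked example layer (64 filters of 3x3, 32x32 output)
K, C, R, S, OH, OW = 64, 3, 3, 3, 32, 32
macs = K * C * R * S * OH * OW   # one multiply-add per filter tap per output
flops = 2 * macs                 # multiply + add = 2 FLOPs
t_us = flops / 10e12 * 1e6       # seconds at 10 TFLOPS, in microseconds
print(macs, flops, round(t_us, 2))   # 1769472 3538944 0.35
```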
im2col + cuBLAS GEMM implementation
The im2col approach separates the data layout transformation from the computation. The im2col kernel rearranges data, then cuBLAS handles the optimized GEMM.
```cuda
__global__ void im2col_kernel(
    const float* input,  // C x H x W
    float* col,          // (C*R*S) x (OH*OW)
    int C, int H, int W,
    int R, int S,
    int OH, int OW,
    int stride, int pad)
{
    int idx = blockIdx.x * blockDim.x + threadIdx.x;
    int total = C * R * S * OH * OW;
    if (idx >= total) return;

    // Decompose the flat index into (c, r, s, oh, ow)
    int ow = idx % OW;
    int oh = (idx / OW) % OH;
    int s = (idx / (OW * OH)) % S;
    int r = (idx / (OW * OH * S)) % R;
    int c = idx / (OW * OH * S * R);

    int ih = oh * stride - pad + r;
    int iw = ow * stride - pad + s;
    float val = 0.0f;
    if (ih >= 0 && ih < H && iw >= 0 && iw < W) {
        val = input[c * H * W + ih * W + iw];
    }
    int col_row = c * R * S + r * S + s;
    int col_col = oh * OW + ow;
    col[col_row * (OH * OW) + col_col] = val;
}
```
```cuda
void conv2d_im2col(
    const float* d_input, const float* d_filter, float* d_output,
    int C, int H, int W, int K, int R, int S,
    int stride, int pad, cublasHandle_t handle)
{
    int OH = (H + 2 * pad - R) / stride + 1;
    int OW = (W + 2 * pad - S) / stride + 1;
    int col_rows = C * R * S;
    int col_cols = OH * OW;

    float* d_col;
    cudaMalloc(&d_col, col_rows * col_cols * sizeof(float));

    int total = col_rows * col_cols;
    int threads = 256;
    int blocks = (total + threads - 1) / threads;
    im2col_kernel<<<blocks, threads>>>(
        d_input, d_col, C, H, W, R, S, OH, OW, stride, pad);

    // Row-major output (K x col_cols) = filter (K x col_rows) * col
    // (col_rows x col_cols). cuBLAS is column-major, so compute the
    // transpose instead: output^T = col^T * filter^T, which is this
    // no-transpose call with the operands swapped.
    float alpha = 1.0f, beta = 0.0f;
    cublasSgemm(handle, CUBLAS_OP_N, CUBLAS_OP_N,
                col_cols, K, col_rows,
                &alpha, d_col, col_cols,
                d_filter, col_rows,
                &beta, d_output, col_cols);

    cudaFree(d_col);
}
```
The im2col approach has a clear trade-off: it allocates extra memory (the column matrix holds roughly R * S copies of the input, since every element appears in up to R * S overlapping patches) but converts convolution into a standard GEMM where cuBLAS can apply all its tiling, register blocking, and tensor core optimizations. For 3 x 3 filters, this 9x memory expansion is acceptable. For larger filters it becomes expensive.
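The conv-as-GEMM equivalence is easy to verify on the CPU before debugging it on the GPU. This NumPy sketch (stride 1 only, shapes are arbitrary small values) builds the column matrix with the same row ordering as the im2col kernel above and checks the GEMM result against a direct convolution:

```python
import numpy as np
from numpy.lib.stride_tricks import sliding_window_view

# Small illustrative shapes; 3x3 filter, stride 1, padding 1 as in the text.
rng = np.random.default_rng(0)
C, H, W, K, R, S, pad = 3, 8, 8, 4, 3, 3, 1
x = rng.standard_normal((C, H, W)).astype(np.float32)
w = rng.standard_normal((K, C, R, S)).astype(np.float32)

# im2col: pad, extract every RxS patch per channel, flatten patches to columns.
# Row index is c*R*S + r*S + s, matching the CUDA kernel's col_row.
xp = np.pad(x, ((0, 0), (pad, pad), (pad, pad)))
patches = sliding_window_view(xp, (R, S), axis=(1, 2))   # (C, OH, OW, R, S)
OH, OW = patches.shape[1], patches.shape[2]
col = patches.transpose(0, 3, 4, 1, 2).reshape(C * R * S, OH * OW)

# GEMM: (K x C*R*S) @ (C*R*S x OH*OW), reshaped to the feature map
out_gemm = (w.reshape(K, -1) @ col).reshape(K, OH, OW)

# Direct convolution for comparison
out_direct = np.zeros_like(out_gemm)
for k in range(K):
    for c in range(C):
        for r in range(R):
            for s in range(S):
                out_direct[k] += w[k, c, r, s] * xp[c, r:r + OH, s:s + OW]

assert np.allclose(out_gemm, out_direct, atol=1e-4)
```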
Shared memory tiling for direct convolution
An alternative to im2col is to keep the direct convolution structure but use shared memory to eliminate redundant global memory loads. The idea: load a tile of input data (including the halo region needed by border threads) into shared memory once, then let all threads in the block read from shared memory for their filter computations.
```cuda
#define TILE 16
#define PAD 1  // for 3x3 filter with padding=1

__global__ void conv2d_tiled(
    const float* input,
    const float* filter,
    float* output,
    int C, int H, int W,
    int K, int OH, int OW)
{
    __shared__ float tile[TILE + 2][TILE + 2];  // halo of 1 on each side

    int tx = threadIdx.x;
    int ty = threadIdx.y;
    int ow = blockIdx.x * TILE + tx;
    int oh = blockIdx.y * TILE + ty;
    int k = blockIdx.z;

    float sum = 0.0f;
    for (int c = 0; c < C; c++) {
        const float* in_c = input + c * H * W;
        int ih = oh - PAD;
        int iw = ow - PAD;

        // Each thread loads its core element; every slot is written,
        // with zero outside the input (out-of-bounds and padding)
        tile[ty][tx] = (ih >= 0 && ih < H && iw >= 0 && iw < W)
                     ? in_c[ih * W + iw] : 0.0f;

        // Threads with tx < 2 load the right halo columns
        if (tx < 2)
            tile[ty][tx + TILE] = (ih >= 0 && ih < H && (iw + TILE) < W)
                                ? in_c[ih * W + iw + TILE] : 0.0f;
        // Threads with ty < 2 load the bottom halo rows
        if (ty < 2)
            tile[ty + TILE][tx] = ((ih + TILE) < H && iw >= 0 && iw < W)
                                ? in_c[(ih + TILE) * W + iw] : 0.0f;
        // Corner threads load the bottom-right halo corner
        if (tx < 2 && ty < 2)
            tile[ty + TILE][tx + TILE] = ((ih + TILE) < H && (iw + TILE) < W)
                                       ? in_c[(ih + TILE) * W + iw + TILE] : 0.0f;
        __syncthreads();

        // Compute 3x3 convolution from shared memory
        for (int r = 0; r < 3; r++)
            for (int s = 0; s < 3; s++)
                sum += tile[ty + r][tx + s]
                     * filter[k * C * 9 + c * 9 + r * 3 + s];
        __syncthreads();
    }

    if (oh < OH && ow < OW)
        output[k * OH * OW + oh * OW + ow] = sum;
}
```
The tiled kernel loads each input element from global memory once per channel instead of up to 9 times (for a 3 x 3 filter). The shared memory tile includes a halo region so that border threads have access to neighboring data without extra global loads. This gives roughly a 3x to 5x speedup over the naive direct kernel, depending on the filter size and input dimensions.
The filter weights are still loaded from global memory in the inner loop. For small filters (3 x 3), moving them to constant memory via __constant__ provides a further boost because the constant cache broadcasts the same value to all threads in a warp.
cuDNN: production convolution
Nobody ships hand-written convolution kernels in production. NVIDIA’s cuDNN library provides highly tuned implementations that select the best algorithm at runtime based on input dimensions, data type, and available hardware.
cuDNN supports multiple convolution algorithms:
| Algorithm | Idea | Best for |
|---|---|---|
| IMPLICIT_GEMM | im2col fused into GEMM, no extra workspace | Memory-constrained cases |
| IMPLICIT_PRECOMP_GEMM | Precomputed im2col indices | Large spatial dimensions |
| GEMM | Explicit im2col + cuBLAS | General case, good baseline |
| WINOGRAD | Reduced multiplications via transform | 3x3 filters, stride 1 |
| WINOGRAD_NONFUSED | Winograd with separate transform steps | Very large batches |
| FFT | Frequency-domain multiplication | Large filters (5x5+) |
Winograd convolution
Winograd is the dominant algorithm for 3 x 3 convolutions in modern networks. A standard 3 x 3 convolution on a 4 x 4 input tile requires 36 multiplications. The Winograd F(2x2, 3x3) algorithm computes the same result with only 16 multiplications by transforming both input and filter into a different basis, performing element-wise multiplication, and transforming back.
The trade-off is more additions and the need for transform operations. On GPUs where multiply throughput is the bottleneck (especially with FP16 and tensor cores), Winograd’s 2.25x reduction in multiplications translates to significant speedups. cuDNN’s Winograd implementation handles the tile management, boundary conditions, and numerical stability that make hand-written Winograd impractical.
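The flavor of the algorithm is easiest to see in one dimension. The sketch below shows Winograd F(2,3): two outputs of a 3-tap filter computed with 4 multiplications instead of the direct method's 6. The 2D F(2x2, 3x3) variant nests this construction in both dimensions, which is where the 36 -> 16 reduction comes from. Input and filter values here are arbitrary:

```python
import numpy as np

d = np.array([1.0, 2.0, 3.0, 4.0])   # input tile (4 values -> 2 outputs)
g = np.array([0.5, -1.0, 2.0])       # 3-tap filter

# Direct computation: 6 multiplications
y_direct = np.array([
    d[0] * g[0] + d[1] * g[1] + d[2] * g[2],
    d[1] * g[0] + d[2] * g[1] + d[3] * g[2],
])

# Winograd F(2,3): 4 multiplications (m1..m4) plus extra additions.
# The filter-side terms (g sums) are computed once per filter and reused.
m1 = (d[0] - d[2]) * g[0]
m2 = (d[1] + d[2]) * (g[0] + g[1] + g[2]) / 2
m3 = (d[2] - d[1]) * (g[0] - g[1] + g[2]) / 2
m4 = (d[1] - d[3]) * g[2]
y_winograd = np.array([m1 + m2 + m3, m2 - m3 - m4])

assert np.allclose(y_direct, y_winograd)
```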
Python benchmark: custom conv vs PyTorch
```python
import torch
import torch.nn.functional as F
import time

def benchmark(fn, warmup=50, repeat=200):
    for _ in range(warmup):
        fn()
    torch.cuda.synchronize()
    start = time.perf_counter()
    for _ in range(repeat):
        fn()
    torch.cuda.synchronize()
    elapsed = time.perf_counter() - start
    return (elapsed / repeat) * 1000  # ms

device = torch.device("cuda")

# ResNet-50 first conv: 3 -> 64 channels, 7x7 filter, stride 2, 224x224 input
x = torch.randn(1, 3, 224, 224, device=device)
w = torch.randn(64, 3, 7, 7, device=device)

# Naive im2col via unfold + matmul
def custom_conv():
    patches = F.unfold(x, kernel_size=7, stride=2, padding=3)  # (1, 147, 12544)
    return (w.view(64, -1) @ patches).view(1, 64, 112, 112)

# cuDNN via PyTorch
def cudnn_conv():
    return F.conv2d(x, w, stride=2, padding=3)

t_custom = benchmark(custom_conv)
t_cudnn = benchmark(cudnn_conv)
print(f"Custom im2col + matmul: {t_custom:.3f} ms")
print(f"cuDNN conv2d:           {t_cudnn:.3f} ms")
print(f"Speedup: {t_custom / t_cudnn:.1f}x")
```
Typical results on an A100 (batch size 1, FP32):
| Method | Time (ms) | Relative |
|---|---|---|
| Naive direct (custom CUDA kernel) | 2.41 | 8.0x slower |
| im2col + matmul (unfold + @) | 0.82 | 2.7x slower |
| cuDNN conv2d (auto-tuned) | 0.30 | 1.0x (baseline) |
cuDNN’s advantage comes from algorithm selection (it picks Winograd or FFT when beneficial), fused operations, and hardware-specific tuning that accounts for cache sizes, tensor core availability, and memory bandwidth.
Profiling a ResNet-50 forward pass
Real networks are not a single convolution. ResNet-50 has 50 weighted layers organized into bottleneck blocks. Each block contains convolutions, batch normalization, ReLU activations, and residual additions. Understanding where time is actually spent requires profiling the full forward pass.
graph LR
subgraph "ResNet-50 Forward Pass Time Breakdown"
A["Conv layers
68%"] --> B["BatchNorm
14%"]
B --> C["ReLU + Add
8%"]
C --> D["Pooling
4%"]
D --> E["FC + Softmax
2%"]
E --> F["Memory ops
4%"]
end
style A fill:#e74c3c,color:#fff
style B fill:#f39c12,color:#fff
style C fill:#27ae60,color:#fff
style D fill:#2980b9,color:#fff
style E fill:#8e44ad,color:#fff
style F fill:#7f8c8d,color:#fff
The breakdown above comes from Nsight Systems profiling on an A100 with batch size 32, FP32. Convolution dominates at 68% of total time, but the remaining 32% in batch norm, activations, pooling, and memory operations is not negligible. This motivates operator fusion: combining conv + batch norm + ReLU into a single kernel to eliminate intermediate global memory writes.
Layer-by-layer breakdown within a bottleneck block
A single ResNet-50 bottleneck block has three convolutions: 1x1 (reduce channels), 3x3 (spatial processing), and 1x1 (expand channels). The 3x3 convolution dominates the block’s compute time despite having fewer output channels, because its spatial footprint is 9x larger per output element.
At about 57% of block time in this profile, the 3x3 convolution is the primary optimization target. This is where cuDNN's Winograd algorithm has the largest impact. Switching from IMPLICIT_GEMM to WINOGRAD for 3x3 layers typically saves 20-40% on the 3x3 convolution time alone.
Worked example: FLOPs and roofline for a ResNet-50 bottleneck
Consider the third stage of ResNet-50, where the input is 14 x 14 x 256.
1x1 conv (256 -> 128): FLOPs = 2 * 128 * 256 * 1 * 1 * 14 * 14 = 12,845,056
3x3 conv (128 -> 128): FLOPs = 2 * 128 * 128 * 3 * 3 * 14 * 14 = 57,802,752
1x1 conv (128 -> 256): FLOPs = 2 * 256 * 128 * 1 * 1 * 14 * 14 = 12,845,056
Total block FLOPs: 83,492,864 (about 83.5 MFLOP).
On an A100 with 19.5 TFLOPS FP32 peak, the compute lower bound is 83.5M / 19.5T ≈ 4.3 microseconds. The measured time for the full block is about 323 microseconds, a 75x gap between the compute bound and actual execution.
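The arithmetic above is easy to mis-type, so here it is as a check:

```python
# FLOPs for the three convolutions of the 14x14 bottleneck stage
spatial = 14 * 14
flops = (2 * 128 * 256 * spatial          # 1x1 reduce (256 -> 128)
         + 2 * 128 * 128 * 9 * spatial    # 3x3 (128 -> 128)
         + 2 * 256 * 128 * spatial)       # 1x1 expand (128 -> 256)
peak = 19.5e12                            # A100 FP32 peak, FLOP/s
bound_us = flops / peak * 1e6             # compute lower bound, microseconds
print(flops, round(bound_us, 1), round(323 / bound_us))
# 83492864 4.3 75
```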
This gap exists because:
- Memory bandwidth limits. The 1x1 convolutions are heavily memory-bound (low arithmetic intensity: only 256 FLOPs per element loaded).
- Kernel launch overhead. Six separate kernels (three convolutions, three batch norms) each incur 5-10 microseconds of launch latency.
- Unfused operations. Each intermediate tensor is written to and read back from global memory.
This is exactly why frameworks like TensorRT fuse conv + batch norm + ReLU into single kernels: it eliminates intermediate memory traffic and amortizes kernel launch costs.
Algorithm selection heuristics
cuDNN’s cudnnFindConvolutionForwardAlgorithm benchmarks all applicable algorithms and returns them ranked by execution time. The general patterns:
- 3x3 filters, stride 1: Winograd wins for moderate to large spatial dimensions (14x14 and above). Falls back to IMPLICIT_GEMM for very small spatial dimensions where transform overhead dominates.
- 1x1 filters: This is pure GEMM. IMPLICIT_GEMM or direct GEMM, depending on whether the im2col step can be fused.
- 5x5 and 7x7 filters: FFT-based algorithms become competitive. The larger the filter, the more FFT’s O(N log N) cost is amortized.
- Depthwise convolutions: Neither im2col nor Winograd applies well. Specialized direct kernels that exploit the per-channel independence are faster.
- FP16 with tensor cores: IMPLICIT_GEMM using tensor core HMMA instructions. Winograd has numerical issues with FP16 due to the transform’s sensitivity to precision.
Always benchmark with cudnnFindConvolutionForwardAlgorithm or PyTorch’s torch.backends.cudnn.benchmark = True rather than assuming which algorithm is best. Hardware generation, batch size, and channel counts all shift the optimal choice.
In practice
- Use cuDNN. Writing custom convolution kernels is an educational exercise. In production, cuDNN (via PyTorch, TensorFlow, or TensorRT) will outperform hand-written kernels by 3x to 10x. The library has years of hardware-specific tuning that is impractical to replicate.
- Enable cuDNN auto-tuning. Set torch.backends.cudnn.benchmark = True in PyTorch. The first iteration will be slow (cuDNN benchmarks all algorithms), but subsequent iterations use the fastest one. This is essential for training loops where the same shapes repeat.
- Fuse operations. Conv + batch norm + ReLU fusion eliminates two intermediate global memory round-trips. TensorRT does this automatically. In training, frameworks increasingly support fusion through torch.compile and similar JIT systems.
- Watch memory. im2col can expand memory usage by 9x for 3x3 filters. For large inputs or tight GPU memory, IMPLICIT_GEMM avoids the explicit column matrix. cuDNN selects this automatically when workspace memory is limited.
- Profile before optimizing. Use Nsight Systems to identify whether your bottleneck is compute (high SM utilization, long kernel times) or overhead (many short kernels, low GPU utilization between launches). The fix is different: compute-bound problems need better algorithms; overhead-bound problems need fusion and CUDA graphs.
- Consider data layout. NCHW is the default in PyTorch. NHWC is faster for tensor core operations because it aligns channel data contiguously for the matrix units. cuDNN supports both, and TensorRT often converts to NHWC internally.
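The layout difference is visible directly in the strides. A NumPy sketch (PyTorch's channels_last memory format is the framework-level equivalent):

```python
import numpy as np

x = np.zeros((1, 64, 56, 56), dtype=np.float32)          # NCHW layout
x_nhwc = np.ascontiguousarray(x.transpose(0, 2, 3, 1))   # NHWC layout

# In NCHW, moving to the next channel of the same pixel jumps a whole
# 56x56 plane. In NHWC, the 64 channel values of one pixel are adjacent,
# which is the contiguous inner dimension tensor-core GEMMs consume.
print(x.strides[1], x_nhwc.strides[3])   # 12544 bytes vs 4 bytes per channel step
```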
Common mistakes
| Mistake | Symptom | Fix |
|---|---|---|
| Not padding im2col for alignment | cuBLAS GEMM runs 30% slower | Pad column count to multiples of 8 (FP32) or 16 (FP16) |
| Forgetting to set cuDNN benchmark mode | Always using IMPLICIT_GEMM | Add torch.backends.cudnn.benchmark = True before training |
| Allocating im2col workspace per call | Excessive cudaMalloc overhead | Pre-allocate workspace once, reuse across layers |
| Using FP16 Winograd without testing | Numerical divergence in training | Stick to IMPLICIT_GEMM for FP16 or validate loss curves |
| Profiling with batch size 1 | Misleading: kernel launch overhead dominates | Profile with realistic batch sizes (16, 32, 64) |
What comes next
Convolution layers account for the majority of compute in modern vision networks, but the techniques here apply broadly. The im2col transformation shows how restructuring data for a well-optimized primitive (GEMM) often beats writing a specialized kernel. Shared memory tiling reduces redundant global loads for any stencil-like computation pattern. cuDNN’s algorithm selection demonstrates that the best implementation depends on the specific problem dimensions.
The next article, Case study: reduction and histogram, applies similar principles to two more fundamental GPU patterns: reducing large arrays to scalar summaries (loss computation, gradient aggregation) and building histograms from unstructured data (batch statistics, quantization calibration).