
Case study: matrix multiplication from naive to cuBLAS speed

In this series (28 parts)
  1. GPUs: from pixels to parallel supercomputers
  2. Your first CUDA program: kernels, threads, and grids
  3. Thread hierarchy in CUDA: threads, blocks, warps, and grids
  4. Warps and warp divergence: the hidden performance trap
  5. CUDA memory hierarchy: where your data lives matters
  6. Memory coalescing: the most important optimization you will learn
  7. Shared memory and tiling: the key to fast matrix operations
  8. Debugging and profiling CUDA programs
  9. Device functions, host functions, and CUDA function qualifiers
  10. Synchronization and atomic operations in CUDA
  11. Parallel prefix sum and reduction: the core parallel primitives
  12. Concurrent data structures on the GPU
  13. CUDA streams and asynchronous execution
  14. CUDA events and fine-grained synchronization
  15. Dynamic parallelism: kernels launching kernels
  16. Unified virtual memory: one pointer for CPU and GPU
  17. Multi-GPU programming and peer access
  18. Memory allocation patterns and multi-dimensional arrays in CUDA
  19. Texture and constant memory: specialized caches
  20. Occupancy, register pressure, and performance tuning
  21. Case study: matrix multiplication from naive to cuBLAS speed
  22. Case study: implementing a convolution layer in CUDA
  23. Case study: reduction and histogram at scale
  24. Heterogeneous computing: CPU and GPU working together
  25. Advanced memory patterns: pinned memory, zero-copy, and more
  26. Advanced stream patterns and concurrent kernel execution
  27. Performance case studies and optimization patterns
  28. Where to go from here: CUDA ecosystem and next steps

Prerequisites

This article assumes you have read the earlier parts of this series, in particular the articles on shared memory and tiling and on occupancy and performance tuning.

You should be comfortable writing tiled CUDA kernels and reasoning about arithmetic intensity. This article puts all those concepts together in a single, progressively optimized kernel.

Why matrix multiplication is the GPU benchmark

Matrix multiplication (GEMM) is the single most important kernel in GPU computing. Every dense layer in a neural network, every attention head in a transformer, every linear solve in scientific simulation reduces to matrix multiplication. If your GEMM is slow, everything built on top of it is slow.

GEMM is also uniquely suited for benchmarking because it has predictable arithmetic: multiplying two N x N matrices requires exactly 2N³ floating-point operations (N³ multiply-add pairs, each contributing one multiply and one add). This makes it easy to compute achieved GFLOPS and compare against hardware peak.

The gap between a naive implementation and an expert-tuned one is enormous. On an A100, a naive kernel achieves roughly 300 GFLOPS. cuBLAS achieves over 18,000 GFLOPS. That is a 60x difference from the same hardware running the same mathematical operation. This article walks through each optimization step that closes that gap.

The optimization progression

  • Naive global memory: approx 300 GFLOPS (1x baseline)
  • Tiled shared memory: approx 2,500 GFLOPS (approx 8x)
  • Vectorized float4 loads: approx 5,000 GFLOPS (approx 17x)
  • Double buffering: approx 7,500 GFLOPS (approx 25x)
  • Tensor cores (WMMA): approx 14,000 GFLOPS (approx 47x)
  • cuBLAS SGEMM: approx 18,000 GFLOPS (approx 60x)

Each step targets a different bottleneck. The naive kernel is bandwidth-limited by redundant global memory reads. Tiling reduces traffic. Vectorized loads use the memory bus more efficiently. Double buffering overlaps memory access with computation. Tensor cores use specialized hardware for matrix math. cuBLAS combines all of these with architecture-specific tuning.

Step 1: naive global memory kernel

The simplest correct implementation. Each thread computes one element of the output matrix by reading an entire row of A and an entire column of B from global memory.

__global__ void matmul_naive(const float* A, const float* B, float* C, int N) {
    int row = blockIdx.y * blockDim.y + threadIdx.y;
    int col = blockIdx.x * blockDim.x + threadIdx.x;

    if (row < N && col < N) {
        float sum = 0.0f;
        for (int k = 0; k < N; k++) {
            sum += A[row * N + k] * B[k * N + col];
        }
        C[row * N + col] = sum;
    }
}

// Launch: 32x32 thread blocks
// dim3 block(32, 32);
// dim3 grid((N + 31) / 32, (N + 31) / 32);
// matmul_naive<<<grid, block>>>(d_A, d_B, d_C, N);

Why it is slow. For an N x N matmul, each element of A is read N times across all threads computing a row of C. Each element of B is read N times across all threads computing a column of C. Total global memory reads: 2N³ elements. The arithmetic intensity is 2N³ FLOPs / (2N³ * 4 bytes) = 0.25 FLOPs/byte. On an A100 with 2,039 GB/s bandwidth, this caps throughput at roughly 500 GFLOPS. In practice, cache effects and imperfect coalescing on B bring this down to around 300 GFLOPS.

Example: operation count for N = 1024

Total FLOPs: 2 * 1024³ = 2,147,483,648 (roughly 2.15 GFLOP of work).

At 300 GFLOPS (naive kernel): 2.15 / 300 = 7.2 ms.

At 18,000 GFLOPS (cuBLAS): 2.15 / 18,000 = 0.12 ms.

That is a 60x speedup from algorithmic and hardware optimization alone.
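These figures follow directly from the formulas above; a few lines of Python (illustrative arithmetic only, using the A100 numbers quoted earlier) confirm them:

```python
N = 1024
flops = 2 * N**3                      # FLOPs for one N x N GEMM
print(flops)                          # 2147483648

# Time at a given sustained rate
t_naive_ms = flops / 300e9 * 1e3      # ~7.2 ms at 300 GFLOPS
t_cublas_ms = flops / 18000e9 * 1e3   # ~0.12 ms at 18,000 GFLOPS

# Bandwidth-bound ceiling for the naive kernel (roofline):
# arithmetic intensity 0.25 FLOPs/byte x 2,039 GB/s ~= 510 GFLOPS
cap_gflops = 0.25 * 2039
```

The last line is why the naive kernel cannot exceed roughly 500 GFLOPS no matter how well it is written: the roofline caps a kernel at arithmetic intensity times bandwidth.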

Step 2: tiled shared memory kernel

The core insight: threads in the same block need overlapping data. Load it once into shared memory and reuse it.

#define TILE 32

__global__ void matmul_tiled(const float* A, const float* B, float* C, int N) {
    __shared__ float sA[TILE][TILE];
    __shared__ float sB[TILE][TILE];

    int row = blockIdx.y * TILE + threadIdx.y;
    int col = blockIdx.x * TILE + threadIdx.x;
    float sum = 0.0f;

    // Assumes N is a multiple of TILE; general N needs bounds checks
    for (int t = 0; t < N / TILE; t++) {
        sA[threadIdx.y][threadIdx.x] = A[row * N + t * TILE + threadIdx.x];
        sB[threadIdx.y][threadIdx.x] = B[(t * TILE + threadIdx.y) * N + col];
        __syncthreads();

        for (int k = 0; k < TILE; k++) {
            sum += sA[threadIdx.y][k] * sB[k][threadIdx.x];
        }
        __syncthreads();
    }

    C[row * N + col] = sum;
}

Example: memory reuse with tiling

In the naive kernel, each element of A is read from global memory N times total (once for each column of C computed by threads in the same row). With tiling, each element is loaded into shared memory once per tile step, and each tile covers T columns of output. So each element is read N/T times from global memory.

For N = 1024 and T = 32: each element is read 1024 / 32 = 32 times from global memory, compared to 1024 times in the naive version. That is a 32x reduction in global memory traffic.

The arithmetic intensity increases from 0.25 FLOPs/byte to T/4 = 8 FLOPs/byte. On an A100, the kernel can now reach roughly 2,500 GFLOPS before hitting other bottlenecks.
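The traffic reduction and the T/4 intensity formula can be sanity-checked with a small model of per-element traffic (a sketch; setting T = 1 recovers the naive kernel):

```python
def bytes_per_output(N, T, elem_bytes=4):
    """Global bytes read per output element for tile size T (T = 1 is naive)."""
    loads_per_block = (N // T) * 2 * T * T     # two T x T tiles per tile step
    return loads_per_block * elem_bytes / (T * T)

def arithmetic_intensity(N, T):
    return 2 * N / bytes_per_output(N, T)      # 2N FLOPs per output element

N = 1024
print(bytes_per_output(N, 1) / bytes_per_output(N, 32))        # 32.0x less traffic
print(arithmetic_intensity(N, 1), arithmetic_intensity(N, 32))  # 0.25 8.0
```

Note that the intensity depends only on T, not on N: larger tiles buy more reuse, up to the limits of shared memory capacity and occupancy.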

Step 3: vectorized loads with float4

The tiled kernel loads one float (4 bytes) per thread per memory transaction. The memory bus can transfer 128 bits (16 bytes) at once. Using float4 loads, each thread fetches four floats in a single transaction, cutting the number of memory instructions by 4x.

#define TILE 32
#define FLOAT4_TILE (TILE / 4)

__global__ void matmul_vectorized(const float* A, const float* B, float* C, int N) {
    __shared__ float sA[TILE][TILE];
    __shared__ float sB[TILE][TILE];

    int row = blockIdx.y * TILE + threadIdx.y;
    int col = blockIdx.x * TILE + threadIdx.x;

    // Thread coordinates for vectorized loading
    int ty = threadIdx.y;
    int tx = threadIdx.x;

    float sum = 0.0f;

    for (int t = 0; t < N / TILE; t++) {
        // Vectorized load: the first TILE*TILE/4 threads (by linear ID) each
        // load 4 consecutive floats into shared memory; assumes N % TILE == 0
        int linearIdx = ty * blockDim.x + tx;
        int loadRow = linearIdx / (TILE / 4);
        int loadCol = (linearIdx % (TILE / 4)) * 4;

        // Load A tile: 4 floats at a time
        if (loadRow < TILE && (t * TILE + loadCol + 3) < N) {
            float4 a4 = *reinterpret_cast<const float4*>(
                &A[(blockIdx.y * TILE + loadRow) * N + t * TILE + loadCol]);
            sA[loadRow][loadCol]     = a4.x;
            sA[loadRow][loadCol + 1] = a4.y;
            sA[loadRow][loadCol + 2] = a4.z;
            sA[loadRow][loadCol + 3] = a4.w;
        }

        // Load B tile: 4 floats at a time
        if (loadRow < TILE && (blockIdx.x * TILE + loadCol + 3) < N) {
            float4 b4 = *reinterpret_cast<const float4*>(
                &B[(t * TILE + loadRow) * N + blockIdx.x * TILE + loadCol]);
            sB[loadRow][loadCol]     = b4.x;
            sB[loadRow][loadCol + 1] = b4.y;
            sB[loadRow][loadCol + 2] = b4.z;
            sB[loadRow][loadCol + 3] = b4.w;
        }

        __syncthreads();

        for (int k = 0; k < TILE; k++) {
            sum += sA[ty][k] * sB[k][tx];
        }
        __syncthreads();
    }

    C[row * N + col] = sum;
}

The float4 load issues a single 128-bit memory instruction instead of four 32-bit instructions. This reduces instruction overhead and makes better use of the memory controller. The kernel also benefits from improved instruction-level parallelism: while the hardware processes one load, independent arithmetic instructions can execute.

On an A100, the vectorized tiled kernel reaches roughly 5,000 GFLOPS for large matrices. The remaining gap to cuBLAS comes from compute-side inefficiencies: the kernel still stalls while waiting for shared memory loads to complete before computing.
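The per-tile instruction saving is easy to quantify (illustrative arithmetic only):

```python
TILE = 32
tile_elems = TILE * TILE            # 1024 floats per shared-memory tile
scalar_load_insts = tile_elems      # one 32-bit load instruction per element
float4_load_insts = tile_elems // 4 # one 128-bit load instruction per 4 elements

# Per tile of A (and likewise for B): 1024 loads shrink to 256
print(scalar_load_insts, float4_load_insts)   # 1024 256
```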

Step 4: double buffering

Double buffering overlaps data loading with computation. While threads compute on the current tile from one shared memory buffer, they simultaneously load the next tile into a second buffer. This hides the latency of global memory loads behind useful arithmetic.

#define TILE 32

__global__ void matmul_double_buffered(const float* A, const float* B, float* C, int N) {
    __shared__ float sA[2][TILE][TILE];
    __shared__ float sB[2][TILE][TILE];

    int row = blockIdx.y * TILE + threadIdx.y;
    int col = blockIdx.x * TILE + threadIdx.x;
    int ty = threadIdx.y;
    int tx = threadIdx.x;
    float sum = 0.0f;

    int numTiles = N / TILE;   // assumes N is a multiple of TILE
    int buf = 0;

    // Prefetch first tile into buffer 0
    sA[0][ty][tx] = A[row * N + tx];
    sB[0][ty][tx] = B[ty * N + col];
    __syncthreads();

    for (int t = 0; t < numTiles; t++) {
        int next = 1 - buf;

        // Load next tile into alternate buffer (if not last tile)
        if (t + 1 < numTiles) {
            int nextTile = t + 1;
            sA[next][ty][tx] = A[row * N + nextTile * TILE + tx];
            sB[next][ty][tx] = B[(nextTile * TILE + ty) * N + col];
        }

        // Compute on current buffer
        for (int k = 0; k < TILE; k++) {
            sum += sA[buf][ty][k] * sB[buf][k][tx];
        }

        __syncthreads();
        buf = next;
    }

    C[row * N + col] = sum;
}

The key insight is that __syncthreads() only needs to ensure the next buffer is fully loaded before we switch to it. While threads execute the multiply-accumulate loop on the current buffer, the hardware pipelines the global memory loads for the next tile. This overlap is critical for hiding the 200-400 cycle global memory latency.

Double buffering costs 2x shared memory (two copies of each tile). For 32 x 32 tiles with FP32, that is 2 * 2 * 32 * 32 * 4 = 16 KB, well within the shared memory budget of modern GPUs. The performance improvement is substantial: roughly 7,500 GFLOPS on an A100, because the kernel now keeps the compute units busy during what was previously idle time.
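A quick budget check for other tile sizes (a sketch; the helper name is ours, not a CUDA API):

```python
def smem_bytes(tile, n_buffers=2, n_matrices=2, elem_bytes=4):
    """Shared memory for double-buffered square FP32 tiles of A and B."""
    return n_buffers * n_matrices * tile * tile * elem_bytes

print(smem_bytes(32) / 1024)   # 16.0 KB, vs 8 KB single-buffered
print(smem_bytes(64) / 1024)   # 64.0 KB: close to the per-block limit on many GPUs
```

This is why tile size, buffering depth, and occupancy trade off against each other: doubling the tile edge quadruples the shared memory bill.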

Step 5: tensor cores

Starting with Volta (SM 7.0), NVIDIA GPUs include tensor cores: specialized hardware units that compute small matrix multiply-accumulate operations (D = A * B + C) in a single instruction. The WMMA (Warp Matrix Multiply-Accumulate) API exposes them at the warp level as 16x16x16 tile operations (among other shapes).

#include <mma.h>
using namespace nvcuda;

// Each warp computes a 16x16 output tile
// Accumulation in FP32 for numerical stability
__global__ void matmul_tensor_core(const half* A, const half* B, float* C, int N) {
    // Warp-level tile dimensions for WMMA
    const int WMMA_M = 16;
    const int WMMA_N = 16;
    const int WMMA_K = 16;

    // One warp per 16x16 output tile. The 32 lanes of a warp span threadIdx.x,
    // so only the x index is divided by the warp size
    int warpM = (blockIdx.y * blockDim.y + threadIdx.y) * WMMA_M;
    int warpN = (blockIdx.x * blockDim.x + threadIdx.x) / 32 * WMMA_N;

    // Declare fragments
    wmma::fragment<wmma::matrix_a, WMMA_M, WMMA_N, WMMA_K, half, wmma::row_major> a_frag;
    wmma::fragment<wmma::matrix_b, WMMA_M, WMMA_N, WMMA_K, half, wmma::row_major> b_frag;
    wmma::fragment<wmma::accumulator, WMMA_M, WMMA_N, WMMA_K, float> c_frag;

    // Initialize accumulator to zero
    wmma::fill_fragment(c_frag, 0.0f);

    // Loop over K dimension in steps of WMMA_K
    for (int k = 0; k < N; k += WMMA_K) {
        // Load matrix fragments from global memory
        wmma::load_matrix_sync(a_frag, A + warpM * N + k, N);
        wmma::load_matrix_sync(b_frag, B + k * N + warpN, N);

        // Perform the matrix multiply-accumulate
        wmma::mma_sync(c_frag, a_frag, b_frag, c_frag);
    }

    // Store the result
    wmma::store_matrix_sync(C + warpM * N + warpN, c_frag, N, wmma::mem_row_major);
}

Tensor cores achieve their throughput by operating on matrix fragments rather than individual elements. A single wmma::mma_sync call performs 16 * 16 * 16 * 2 = 8,192 floating-point operations. On an A100 with 432 tensor cores running at 1.41 GHz, the theoretical FP16 peak with FP32 accumulation is 312 TFLOPS.
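Both figures check out numerically; note the per-core rate below is inferred from the quoted peak rather than taken from a spec sheet:

```python
flops_per_mma = 16 * 16 * 16 * 2           # one wmma::mma_sync: 8,192 FLOPs
print(flops_per_mma)                        # 8192

# A100 FP16 tensor-core peak, from the numbers in the text above
tensor_cores = 432
clock_hz = 1.41e9
flops_per_core_per_cycle = 512              # 256 FMA = 512 FLOPs per core per cycle
peak_tflops = tensor_cores * clock_hz * flops_per_core_per_cycle / 1e12
print(round(peak_tflops))                   # 312
```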

The WMMA kernel above is simplified. A production kernel would tile shared memory loads, use double buffering, and carefully manage fragment layout to minimize register pressure. Even so, this basic tensor core kernel reaches roughly 14,000 GFLOPS because the tensor core hardware is fundamentally more efficient at matrix math than standard FP32 CUDA cores.

⚠ Tensor cores require FP16 inputs (or BF16/TF32/INT8 depending on architecture). If your application requires FP32 precision throughout, you can use TF32 on Ampere, which runs on tensor cores with a 10-bit mantissa (instead of FP32's 23) while keeping the FP32 exponent range and storage format, so no data conversion is needed.

Step 6: cuBLAS

cuBLAS is NVIDIA’s hand-tuned BLAS library. It selects the optimal kernel for your specific GPU, matrix dimensions, and data types at runtime. Writing code that matches cuBLAS performance by hand is extremely difficult and rarely worth the effort.

#include <cublas_v2.h>

void matmul_cublas(const float* d_A, const float* d_B, float* d_C,
                   int M, int N, int K) {
    cublasHandle_t handle;
    cublasCreate(&handle);

    float alpha = 1.0f;
    float beta = 0.0f;

    // cuBLAS uses column-major order
    // To compute C = A * B in row-major:
    // call cublasSgemm with transposed arguments
    cublasSgemm(handle,
                CUBLAS_OP_N,    // op(B)
                CUBLAS_OP_N,    // op(A)
                N, M, K,       // dimensions (column-major convention)
                &alpha,
                d_B, N,        // B, leading dimension
                d_A, K,        // A, leading dimension
                &beta,
                d_C, N);       // C, leading dimension

    cublasDestroy(handle);
}

A common pitfall: cuBLAS uses column-major layout (Fortran convention), while C/C++ arrays are row-major. The call above works correctly for row-major matrices by swapping the argument order, effectively computing Cᵀ = Bᵀ * Aᵀ.
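The transpose identity behind this trick can be verified on the CPU with NumPy (illustrative; cuBLAS never materializes any transpose, it simply reinterprets the buffers):

```python
import numpy as np

rng = np.random.default_rng(0)
M, K, N = 3, 4, 5
A = rng.standard_normal((M, K)).astype(np.float32)
B = rng.standard_normal((K, N)).astype(np.float32)

# A row-major C = A @ B occupies the same memory as a column-major C^T.
# Reinterpreting a row-major buffer as column-major is a free transpose,
# so asking for B^T @ A^T in the column-major world yields C in row-major.
C_colmajor_view = B.T @ A.T      # what the swapped cublasSgemm call computes
C = C_colmajor_view.T            # reinterpret back to row-major: equals A @ B

assert np.allclose(C, A @ B)
```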

cuBLAS achieves roughly 18,000 GFLOPS on an A100 for large FP32 matrices. Internally, it uses all the techniques described above plus architecture-specific tricks: software pipelining, register-level tiling, warp-specialization (separate warps for loading and computing), and autotuning across hundreds of kernel variants.

Performance comparison

Key observations from the chart:

  • Small matrices underperform everywhere. At N = 256, even cuBLAS struggles because there is not enough parallelism to saturate the GPU. Kernel launch overhead becomes significant.
  • Naive performance is flat. Once the matrix is large enough to fill the GPU, the naive kernel is purely bandwidth-bound and cannot improve with larger problem sizes.
  • cuBLAS scales best. The library’s autotuner selects progressively better kernels as the matrix grows. At N = 4096 and above, it reaches over 90% of FP32 peak.
  • The gap between each optimization step is multiplicative. Going from naive to tiled is roughly 8x. Vectorization adds another 2x. Double buffering adds 1.5x. Tensor cores add further gains, reached here through cuBLAS’s internal use of TF32 rather than as a separate FP32 kernel.

Roofline positioning

The roofline model explains why each kernel performs the way it does:

Kernel            Arithmetic Intensity (FLOPs/byte)   Bottleneck                Achieved GFLOPS (N=4096)
Naive             0.25                                Memory bandwidth          ~305
Tiled (T=32)      8.0                                 Instruction overhead      ~2,500
Vectorized        8.0                                 Instruction throughput    ~5,000
Double buffered   8.0                                 Compute/memory overlap    ~7,500
cuBLAS            8.0+                                Approaches compute peak   ~18,000

All tiled variants have the same arithmetic intensity (the data reuse is identical), but they differ in how efficiently they utilize the compute units. The naive kernel sits on the memory-bound slope of the roofline. All tiled variants sit in the compute-bound region, but with varying amounts of pipeline overhead.

Python: NumPy vs CuPy benchmark

For many workloads, you do not need to write CUDA C++ at all. CuPy provides a NumPy-compatible interface that calls cuBLAS under the hood.

import numpy as np
import cupy as cp
import time

def benchmark_matmul(N, num_runs=100):
    # NumPy (CPU)
    a_cpu = np.random.randn(N, N).astype(np.float32)
    b_cpu = np.random.randn(N, N).astype(np.float32)

    # Warmup
    _ = a_cpu @ b_cpu

    start = time.perf_counter()
    for _ in range(num_runs):
        _ = a_cpu @ b_cpu
    cpu_time = (time.perf_counter() - start) / num_runs

    # CuPy (GPU, calls cuBLAS internally)
    a_gpu = cp.asarray(a_cpu)
    b_gpu = cp.asarray(b_cpu)

    # Warmup and sync
    _ = a_gpu @ b_gpu
    cp.cuda.Stream.null.synchronize()

    start = time.perf_counter()
    for _ in range(num_runs):
        _ = a_gpu @ b_gpu
    cp.cuda.Stream.null.synchronize()
    gpu_time = (time.perf_counter() - start) / num_runs

    flops = 2 * N**3
    cpu_gflops = flops / cpu_time / 1e9
    gpu_gflops = flops / gpu_time / 1e9

    print(f"N={N:5d} | CPU: {cpu_gflops:8.1f} GFLOPS ({cpu_time*1000:7.2f} ms) "
          f"| GPU: {gpu_gflops:8.1f} GFLOPS ({gpu_time*1000:7.2f} ms) "
          f"| Speedup: {cpu_time/gpu_time:6.1f}x")

for size in [512, 1024, 2048, 4096]:
    benchmark_matmul(size)

Typical output on an A100 with a 32-core Xeon:

N=  512 | CPU:    120.5 GFLOPS (   2.23 ms) | GPU:   5200.0 GFLOPS (  0.052 ms) | Speedup:   42.9x
N= 1024 | CPU:    135.2 GFLOPS (  15.88 ms) | GPU:  15800.0 GFLOPS (  0.136 ms) | Speedup:  116.8x
N= 2048 | CPU:    140.1 GFLOPS ( 122.54 ms) | GPU:  17500.0 GFLOPS (  0.981 ms) | Speedup:  124.9x
N= 4096 | CPU:    138.7 GFLOPS ( 990.40 ms) | GPU:  18100.0 GFLOPS (  7.590 ms) | Speedup:  130.5x

The GPU advantage grows with matrix size because larger matrices expose more parallelism and amortize kernel launch overhead. For N = 4096, the GPU is over 130x faster than a multi-threaded CPU BLAS.

In practice

Use cuBLAS for production GEMM. Writing custom matrix multiply kernels is an educational exercise. For production code, cuBLAS (or cuBLASLt for more control) is almost always the right choice. It handles edge cases, non-square matrices, batched operations, and mixed precision automatically.

Profile before optimizing. Use nsys profile or ncu to determine whether your kernel is compute-bound or memory-bound before applying optimizations. The roofline model tells you which optimizations will help.

Consider TF32 on Ampere and later. TF32 runs FP32 GEMMs on tensor cores with reduced mantissa precision (10 bits instead of 23). For deep learning training this is usually acceptable, and frameworks often enable it by default. In cuBLAS you control it explicitly: call cublasSetMathMode(handle, CUBLAS_TF32_TENSOR_OP_MATH) to enable TF32, or keep the default CUBLAS_DEFAULT_MATH for full FP32, which scientific computing usually requires.

Batch small matrix multiplies. If you need to multiply many small matrices (common in batched attention), use cublasSgemmBatched or cublasSgemmStridedBatched. These launch a single kernel that processes all matrices in parallel, avoiding the overhead of many small kernel launches.
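The batched shape convention can be sketched on the CPU with NumPy's stacked matmul (illustrative only; the cuBLAS batched calls take pointer arrays or strides rather than a 3-D array):

```python
import numpy as np

batch, M, K, N = 64, 8, 8, 8
A = np.random.default_rng(1).standard_normal((batch, M, K)).astype(np.float32)
B = np.random.default_rng(2).standard_normal((batch, K, N)).astype(np.float32)

# One "batched GEMM": 64 independent 8x8 multiplies in a single operation,
# rather than 64 separate kernel launches
C = A @ B

assert C.shape == (batch, M, N)
assert np.allclose(C[0], A[0] @ B[0])
```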

Shared memory bank conflicts still matter. Even with vectorized loads and double buffering, bank conflicts in shared memory can reduce throughput by 2-8x. Pad shared memory arrays (__shared__ float sA[TILE][TILE + 1]) to avoid conflicts when threads access the same bank.

Matrix alignment affects vectorized loads. float4 loads require 16-byte alignment. If your matrix dimensions are not multiples of 4, you need boundary checks or padding. cuBLAS handles this internally.

Do not assume hand-written kernels will beat cuBLAS. NVIDIA employs teams of engineers who optimize cuBLAS for each GPU architecture. Matching their performance requires months of work and architecture-specific tuning. Your time is better spent on higher-level algorithmic improvements.

What comes next

The next article applies the optimization techniques from this case study to a real workload: Case study: CNN forward pass. That article implements convolution as implicit GEMM, builds a complete forward pass for a small convolutional neural network, and profiles it end-to-end with Nsight Systems.
