
Shared memory and tiling: the key to fast matrix operations

In this series (28 parts)
  1. GPUs: from pixels to parallel supercomputers
  2. Your first CUDA program: kernels, threads, and grids
  3. Thread hierarchy in CUDA: threads, blocks, warps, and grids
  4. Warps and warp divergence: the hidden performance trap
  5. CUDA memory hierarchy: where your data lives matters
  6. Memory coalescing: the most important optimization you will learn
  7. Shared memory and tiling: the key to fast matrix operations
  8. Debugging and profiling CUDA programs
  9. Device functions, host functions, and CUDA function qualifiers
  10. Synchronization and atomic operations in CUDA
  11. Parallel prefix sum and reduction: the core parallel primitives
  12. Concurrent data structures on the GPU
  13. CUDA streams and asynchronous execution
  14. CUDA events and fine-grained synchronization
  15. Dynamic parallelism: kernels launching kernels
  16. Unified virtual memory: one pointer for CPU and GPU
  17. Multi-GPU programming and peer access
  18. Memory allocation patterns and multi-dimensional arrays in CUDA
  19. Texture and constant memory: specialized caches
  20. Occupancy, register pressure, and performance tuning
  21. Case study: matrix multiplication from naive to cuBLAS speed
  22. Case study: implementing a convolution layer in CUDA
  23. Case study: reduction and histogram at scale
  24. Heterogeneous computing: CPU and GPU working together
  25. Advanced memory patterns: pinned memory, zero-copy, and more
  26. Advanced stream patterns and concurrent kernel execution
  27. Performance case studies and optimization patterns
  28. Where to go from here: CUDA ecosystem and next steps

Prerequisites

This article assumes you have read the earlier parts of this series, in particular the articles on the CUDA memory hierarchy and memory coalescing. You should be comfortable writing basic CUDA kernels and launching them with a grid/block configuration.

The global memory bottleneck

Matrix multiplication is the most important operation in scientific computing and deep learning. Multiplying two N x N matrices requires 2N³ floating-point operations (N³ multiply-adds). A naive implementation reads each element from global memory far more often than necessary.

Consider C = A * B where all matrices are N x N. To compute one element C[row][col], the kernel reads an entire row of A and an entire column of B. That is 2N reads from global memory for 2N FLOPs. The arithmetic intensity is 2N / (2N * 4 bytes) = 0.25 FLOPs/byte for FP32. On an A100 with 2039 GB/s bandwidth and 19.5 TFLOPS peak, the memory system can only feed about 500 GFLOPS of useful compute. That is 2.5% utilization.

The problem is not the arithmetic. The problem is that every thread re-reads data that other threads in the same block have already loaded. Shared memory fixes this by letting threads cooperate: load a chunk once, reuse it many times.
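
These roofline numbers are easy to reproduce. A back-of-the-envelope check in Python, using the A100 figures quoted above:

```python
# Memory-bound ceiling for the naive FP32 matmul, using the A100
# figures quoted above (2039 GB/s bandwidth, 19.5 TFLOPS FP32 peak).
bandwidth_gbs = 2039.0
peak_tflops = 19.5

# Naive kernel: 2N FLOPs per output element against 2N four-byte reads.
arithmetic_intensity = 2.0 / (2.0 * 4.0)   # 0.25 FLOPs/byte

achievable_gflops = bandwidth_gbs * arithmetic_intensity
utilization = achievable_gflops / (peak_tflops * 1000.0)

print(f"ceiling: {achievable_gflops:.0f} GFLOPS ({utilization:.1%} of peak)")
```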

What shared memory gives you

Shared memory is an on-chip SRAM that sits physically next to the compute units. It provides roughly 19 TB/s of bandwidth on an A100 (compared to 2 TB/s for global memory). The tradeoff: it is small (up to 164 KB per SM on Ampere, configurable) and scoped to a single thread block.

Key properties:

  • Low latency: roughly 20-30 cycles versus 200-400 cycles for a global memory miss.
  • High bandwidth: each SM has 32 banks that can each serve one 4-byte word per cycle.
  • Block-scoped lifetime: data in shared memory is visible to all threads in the same block and is destroyed when the block finishes.
  • Programmer-managed: you explicitly load data into shared memory and synchronize threads.

The programming model is straightforward. Declare shared memory with __shared__, load from global memory, call __syncthreads(), then read from shared memory.

Tiling: the core idea

Tiling splits the computation into small tiles that fit in shared memory. Instead of each thread independently reading a full row and column from global memory, threads in a block cooperatively load a tile of A and a tile of B into shared memory, compute a partial result, then move to the next tile.

graph LR
  subgraph "Global Memory"
      A["Matrix A (N x N)"]
      B["Matrix B (N x N)"]
  end

  subgraph "Step 1: Load tile 0"
      SA1["Shared A tile
(TILE x TILE)"]
      SB1["Shared B tile
(TILE x TILE)"]
  end

  subgraph "Step 2: Load tile 1"
      SA2["Shared A tile
(TILE x TILE)"]
      SB2["Shared B tile
(TILE x TILE)"]
  end

  subgraph "Registers"
      C["Partial sum for C
accumulates across tiles"]
  end

  A -->|"block cooperatively loads"| SA1
  B -->|"block cooperatively loads"| SB1
  SA1 -->|"compute partial product"| C
  SB1 -->|"compute partial product"| C
  A -->|"next tile"| SA2
  B -->|"next tile"| SB2
  SA2 -->|"accumulate"| C
  SB2 -->|"accumulate"| C

For an N x N matrix with tile size T, each element of A and B is loaded from global memory N/T times instead of N times. That is a T-fold reduction in global memory traffic. With T = 32, the arithmetic intensity jumps from 0.25 to 8 FLOPs/byte, enough to reach several TFLOPS on modern hardware.
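
The traffic reduction can be sanity-checked with a few lines of Python (assuming N is a multiple of the tile size):

```python
# Global-memory traffic for C = A * B, N x N FP32, with and without tiling.
N, TILE = 4096, 32
bytes_per_float = 4

flops = 2 * N**3
# Naive: every element of A and B is read N times
# (once per output element in its row or column).
naive_bytes = 2 * N * N * N * bytes_per_float
# Tiled: every element is read once per tile step, i.e. N / TILE times.
tiled_bytes = 2 * N * N * (N // TILE) * bytes_per_float

print(flops / naive_bytes)   # 0.25 FLOPs/byte
print(flops / tiled_bytes)   # 8.0 FLOPs/byte, a TILE-fold improvement
```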

The algorithm for one thread block computing a TILE x TILE sub-matrix of C:

  1. For each tile index t from 0 to N/TILE - 1:
    • Each thread loads one element of A[row][t * TILE + tx] into shared memory.
    • Each thread loads one element of B[t * TILE + ty][col] into shared memory.
    • Call __syncthreads() so all loads complete.
    • Each thread loops over k from 0 to TILE - 1, accumulating sharedA[ty][k] * sharedB[k][tx].
    • Call __syncthreads() before the next tile overwrites shared memory.
  2. Write the accumulated result to C[row][col] in global memory.
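
The same loop structure can be simulated on the CPU with NumPy to convince yourself it computes the right answer. A sketch (the helper name matmul_tiled is my own, and N is assumed to be a multiple of the tile size):

```python
import numpy as np

def matmul_tiled(A, B, tile):
    """CPU simulation of the per-block tiling loop: each output tile
    accumulates partial products over N // tile steps, exactly as one
    thread block does with its shared-memory tiles."""
    N = A.shape[0]
    C = np.zeros((N, N), dtype=A.dtype)
    for i in range(0, N, tile):              # block row
        for j in range(0, N, tile):          # block column
            for t in range(0, N, tile):      # tile loop (step 1 above)
                C[i:i+tile, j:j+tile] += (
                    A[i:i+tile, t:t+tile] @ B[t:t+tile, j:j+tile]
                )
    return C

rng = np.random.default_rng(0)
A = rng.standard_normal((8, 8)).astype(np.float32)
B = rng.standard_normal((8, 8)).astype(np.float32)
assert np.allclose(matmul_tiled(A, B, tile=2), A @ B, atol=1e-4)
```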

Example: tiled matmul by hand

Consider A (4 x 4) multiplied by B (4 x 4) with tile size 2 x 2. The output C is also 4 x 4. Focus on the thread block responsible for computing C[0:2][0:2] (the top-left 2 x 2 sub-matrix).

A = | a00  a01  a02  a03 |    B = | b00  b01  b02  b03 |
    | a10  a11  a12  a13 |        | b10  b11  b12  b13 |
    | a20  a21  a22  a23 |        | b20  b21  b22  b23 |
    | a30  a31  a32  a33 |        | b30  b31  b32  b33 |

Step 1 (tile t=0):

Load into shared memory:

sharedA = | a00  a01 |    sharedB = | b00  b01 |
          | a10  a11 |              | b10  b11 |

Partial sum for C[0][0] after step 1:

C[0][0] += a00 * b00 + a01 * b10

This uses k=0 and k=1 from the inner loop over the tile.

Step 2 (tile t=1):

Load next tile into shared memory:

sharedA = | a02  a03 |    sharedB = | b20  b21 |
          | a12  a13 |              | b30  b31 |

Partial sum for C[0][0] after step 2:

C[0][0] += a02 * b20 + a03 * b30

Final result: C[0][0] = a00*b00 + a01*b10 + a02*b20 + a03*b30, which matches the standard matrix multiply definition.

Each element of A and B was loaded from global memory exactly once per tile, shared across all threads in the block. Without tiling, thread (0,0) and thread (0,1) would both independently read the same row of A from global memory.
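
The two-step accumulation can be checked numerically with a quick NumPy sketch (random values standing in for the symbolic a's and b's):

```python
import numpy as np

rng = np.random.default_rng(1)
A = rng.standard_normal((4, 4)).astype(np.float32)
B = rng.standard_normal((4, 4)).astype(np.float32)

# Step 1 (tile t=0): partial product from the top-left 2x2 tiles.
c00 = A[0, 0] * B[0, 0] + A[0, 1] * B[1, 0]
# Step 2 (tile t=1): partial product from the next tiles along k.
c00 += A[0, 2] * B[2, 0] + A[0, 3] * B[3, 0]

# The two partial sums add up to the full dot product.
assert np.isclose(c00, (A @ B)[0, 0], atol=1e-5)
```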

CUDA C++: naive matmul

The naive kernel has each thread compute one element of C by reading a full row and column from global memory:

__global__ void matmulNaive(const float* A, const float* B, float* C, int N) {
    int row = blockIdx.y * blockDim.y + threadIdx.y;
    int col = blockIdx.x * blockDim.x + threadIdx.x;

    if (row < N && col < N) {
        float sum = 0.0f;
        for (int k = 0; k < N; k++) {
            sum += A[row * N + k] * B[k * N + col];
        }
        C[row * N + col] = sum;
    }
}

// Launch
dim3 block(16, 16);
dim3 grid((N + 15) / 16, (N + 15) / 16);
matmulNaive<<<grid, block>>>(d_A, d_B, d_C, N);

This kernel achieves roughly 200-500 GFLOPS on an A100 for large N. The bottleneck is global memory bandwidth, not compute.

CUDA C++: tiled matmul with shared memory

The tiled version loads TILE x TILE blocks into shared memory before computing:

#define TILE 32

__global__ void matmulTiled(const float* A, const float* B, float* C, int N) {
    __shared__ float sA[TILE][TILE];
    __shared__ float sB[TILE][TILE];

    int row = blockIdx.y * TILE + threadIdx.y;
    int col = blockIdx.x * TILE + threadIdx.x;
    float sum = 0.0f;

    for (int t = 0; t < (N + TILE - 1) / TILE; t++) {
        // Cooperative load into shared memory
        int aCol = t * TILE + threadIdx.x;
        int bRow = t * TILE + threadIdx.y;

        sA[threadIdx.y][threadIdx.x] = (row < N && aCol < N)
            ? A[row * N + aCol] : 0.0f;
        sB[threadIdx.y][threadIdx.x] = (bRow < N && col < N)
            ? B[bRow * N + col] : 0.0f;

        __syncthreads();

        // Compute partial product from this tile
        for (int k = 0; k < TILE; k++) {
            sum += sA[threadIdx.y][k] * sB[k][threadIdx.x];
        }

        __syncthreads();
    }

    if (row < N && col < N) {
        C[row * N + col] = sum;
    }
}

// Launch with 32x32 threads per block
dim3 block(TILE, TILE);
dim3 grid((N + TILE - 1) / TILE, (N + TILE - 1) / TILE);
matmulTiled<<<grid, block>>>(d_A, d_B, d_C, N);

Two __syncthreads() calls are critical. The first ensures all threads have finished writing to shared memory before any thread reads from it. The second ensures all threads have finished reading before the next iteration overwrites the tiles.

Benchmarking: measuring GFLOPS

To compute GFLOPS for matrix multiply:

// FLOP count for C = A * B where all are N x N
double flops = 2.0 * N * N * N;

cudaEvent_t start, stop;
cudaEventCreate(&start);
cudaEventCreate(&stop);

cudaEventRecord(start);
matmulTiled<<<grid, block>>>(d_A, d_B, d_C, N);
cudaEventRecord(stop);
cudaEventSynchronize(stop);

float ms = 0;
cudaEventElapsedTime(&ms, start, stop);
double gflops = (flops / (ms / 1000.0)) / 1e9;
printf("N=%d: %.1f GFLOPS (%.2f ms)\n", N, gflops, ms);

Typical results on an A100 (FP32):

Matrix size (N)   Naive GFLOPS   Tiled GFLOPS   cuBLAS GFLOPS
256               180            620            1,200
512               250            1,800          4,500
1024              320            3,200          9,800
2048              380            4,100          14,200
4096              410            4,500          17,800

The tiled kernel is 8-12x faster than naive. cuBLAS is another 3-4x faster because it uses register tiling, vectorized loads, double buffering, and warp-level primitives that go beyond the scope of this article.

Python (CuPy): tiled kernel vs cuBLAS

CuPy lets you write custom CUDA kernels in Python and compare directly against cuBLAS:

import cupy as cp
import numpy as np
import time

tiled_kernel = cp.RawKernel(r'''
#define TILE 32
extern "C" __global__
void matmulTiled(const float* A, const float* B, float* C, int N) {
    __shared__ float sA[TILE][TILE];
    __shared__ float sB[TILE][TILE];

    int row = blockIdx.y * TILE + threadIdx.y;
    int col = blockIdx.x * TILE + threadIdx.x;
    float sum = 0.0f;

    for (int t = 0; t < (N + TILE - 1) / TILE; t++) {
        int aCol = t * TILE + threadIdx.x;
        int bRow = t * TILE + threadIdx.y;
        sA[threadIdx.y][threadIdx.x] = (row < N && aCol < N) ? A[row * N + aCol] : 0.0f;
        sB[threadIdx.y][threadIdx.x] = (bRow < N && col < N) ? B[bRow * N + col] : 0.0f;
        __syncthreads();
        for (int k = 0; k < TILE; k++)
            sum += sA[threadIdx.y][k] * sB[k][threadIdx.x];
        __syncthreads();
    }
    if (row < N && col < N) C[row * N + col] = sum;
}
''', 'matmulTiled')

def bench_tiled(N, repeats=10):
    A = cp.random.randn(N, N, dtype=cp.float32)
    B = cp.random.randn(N, N, dtype=cp.float32)
    C = cp.zeros((N, N), dtype=cp.float32)
    block = (32, 32)
    grid = ((N + 31) // 32, (N + 31) // 32)
    # Warmup
    tiled_kernel(grid, block, (A, B, C, np.int32(N)))
    cp.cuda.Device().synchronize()
    start = time.perf_counter()
    for _ in range(repeats):
        tiled_kernel(grid, block, (A, B, C, np.int32(N)))
    cp.cuda.Device().synchronize()
    elapsed = (time.perf_counter() - start) / repeats
    gflops = (2.0 * N**3) / elapsed / 1e9
    return gflops

def bench_cublas(N, repeats=10):
    A = cp.random.randn(N, N, dtype=cp.float32)
    B = cp.random.randn(N, N, dtype=cp.float32)
    # Warmup
    cp.matmul(A, B)
    cp.cuda.Device().synchronize()
    start = time.perf_counter()
    for _ in range(repeats):
        cp.matmul(A, B)
    cp.cuda.Device().synchronize()
    elapsed = (time.perf_counter() - start) / repeats
    gflops = (2.0 * N**3) / elapsed / 1e9
    return gflops

sizes = [256, 512, 1024, 2048, 4096]
for N in sizes:
    tg = bench_tiled(N)
    cg = bench_cublas(N)
    print(f"N={N:5d}  tiled={tg:8.1f} GFLOPS  cuBLAS={cg:8.1f} GFLOPS")

CuPy’s cp.matmul calls cuBLAS internally, giving you the highly optimized vendor implementation as a baseline.

Performance comparison

The curve tells a clear story. The naive kernel flatlines because it is memory-bound. The tiled kernel scales further because it has reduced global memory traffic by a factor of TILE. cuBLAS scales to near peak because it applies many more optimizations on top of tiling.

Shared memory bank conflicts

Shared memory on NVIDIA GPUs is organized into 32 banks, each 4 bytes wide. In each clock cycle, each bank can serve one address. When two or more threads in the same warp access different addresses that map to the same bank, the accesses are serialized. This is a bank conflict.

The bank for address addr is:

bank = (addr / 4) % 32

For a float array s, element s[i] lives in bank i % 32.

graph TD
  subgraph "32 Shared Memory Banks"
      B0["Bank 0
s[0], s[32], s[64]..."]
      B1["Bank 1
s[1], s[33], s[65]..."]
      B2["Bank 2
s[2], s[34], s[66]..."]
      B3["Bank 3
..."]
      BD["..."]
      B30["Bank 30
s[30], s[62], s[94]..."]
      B31["Bank 31
s[31], s[63], s[95]..."]
  end

  subgraph "Warp (32 threads)"
      T0["Thread 0"]
      T1["Thread 1"]
      T2["Thread 2"]
      TD2["..."]
      T31["Thread 31"]
  end

  T0 -->|"s[0]"| B0
  T1 -->|"s[1]"| B1
  T2 -->|"s[2]"| B2
  T31 -->|"s[31]"| B31

  style B0 fill:#00CC96,stroke:#333
  style B1 fill:#00CC96,stroke:#333
  style B2 fill:#00CC96,stroke:#333
  style B30 fill:#00CC96,stroke:#333
  style B31 fill:#00CC96,stroke:#333

With stride-1 access, each thread maps to a distinct bank: zero conflicts, one transaction. When the stride is a multiple of 32, every thread hits the same bank: a 32-way conflict, 32 serialized transactions.

Bank conflict patterns

Access pattern        Stride   Banks hit     Conflict degree   Transactions   Fix
s[threadIdx.x]        1        32 distinct   None              1              ✓ Already optimal
s[threadIdx.x * 2]    2        16 distinct   2-way             2              Rearrange data layout
s[threadIdx.x * 3]    3        32 distinct   None              1              ✓ Odd strides avoid conflicts
s[threadIdx.x * 4]    4        8 distinct    4-way             4              Add 1-element padding per row
s[threadIdx.x * 8]    8        4 distinct    8-way             8              Restructure to stride-1
s[threadIdx.x * 16]   16       2 distinct    16-way            16             Restructure to stride-1
s[threadIdx.x * 32]   32       1 bank        32-way            32             ⚠ Worst case: pad or swizzle

Example: detecting bank conflicts

Pattern 1: 32 threads access s[threadIdx.x * 2].

Thread 0 reads s[0] (bank 0). Thread 1 reads s[2] (bank 2). Thread 2 reads s[4] (bank 4). Thread 16 reads s[32] (bank 0). Thread 17 reads s[34] (bank 2).

Threads 0 and 16 both hit bank 0. Threads 1 and 17 both hit bank 2. Every bank is hit by exactly two threads. This is a 2-way bank conflict: 2 transactions instead of 1.

Pattern 2: 32 threads access s[threadIdx.x * 32].

Thread 0 reads s[0] (bank 0). Thread 1 reads s[32] (bank 0). Thread 2 reads s[64] (bank 0). Every thread hits bank 0 because (threadIdx.x * 32) % 32 == 0 for all threadIdx.x. This is a 32-way bank conflict: 32 serialized transactions. Performance drops to 1/32 of peak shared memory bandwidth.
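
Both patterns can be checked by brute force: count how many of a warp's 32 threads land in each bank for a given stride.

```python
from collections import Counter

def conflict_degree(stride, banks=32):
    """Worst-case number of threads in a 32-thread warp that hit the
    same bank when thread t accesses s[t * stride] (4-byte elements)."""
    hits = Counter((t * stride) % banks for t in range(32))
    return max(hits.values())

assert conflict_degree(1) == 1     # stride 1: conflict-free
assert conflict_degree(2) == 2     # pattern 1: 2-way conflict
assert conflict_degree(3) == 1     # odd strides are conflict-free
assert conflict_degree(16) == 16   # 16-way conflict
assert conflict_degree(32) == 32   # pattern 2: worst case
```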

The padding trick

Bank conflicts in 2D shared memory arrays are common. When you declare __shared__ float s[32][32] and threads in a warp read a column (varying row index, fixed column), each consecutive row element is 32 floats apart in linear memory. That is a stride of 32, which means every access hits the same bank.

The fix is simple: add one padding element per row.

// Without padding: 32-way bank conflicts on column access
__shared__ float s[32][32];
// Column access: s[0][col], s[1][col], s[2][col], ...
// Bank for s[row][col] = (row * 32 + col) % 32 = col  (same bank!)

// With padding: zero bank conflicts on column access
__shared__ float s[32][32 + 1];
// Bank for s[row][col] = (row * 33 + col) % 32
// row=0: bank = col
// row=1: bank = (33 + col) % 32 = (1 + col) % 32
// row=2: bank = (66 + col) % 32 = (2 + col) % 32
// Each row shifts by 1 bank. All 32 rows hit distinct banks.

The cost is 32 extra floats (128 bytes) of shared memory per tile. The benefit is eliminating serialization that would otherwise divide bandwidth by 32. In the tiled matmul kernel, applying padding to both sA and sB typically yields a 10-15% speedup for large matrices.
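
The shift introduced by the padding can be verified directly: compute which banks a warp touches when it reads one column of the tile.

```python
def column_banks(width, col=0):
    """Banks touched when a warp reads column `col` across 32 rows of a
    2D float tile with the given row width (in elements)."""
    return {(row * width + col) % 32 for row in range(32)}

assert len(column_banks(32)) == 1    # unpadded [32][32]: one bank, 32-way conflict
assert len(column_banks(33)) == 32   # padded [32][33]: all 32 banks, conflict-free
```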

Updated declaration:

__shared__ float sA[TILE][TILE + 1];  // +1 padding
__shared__ float sB[TILE][TILE + 1];  // +1 padding

The rest of the kernel stays identical. Indexing does not change because padding only affects the physical layout, not the logical indices.

Putting it all together: optimized tiled matmul

#define TILE 32

__global__ void matmulOptimized(const float* A, const float* B, float* C, int N) {
    // Padded shared memory to avoid bank conflicts on column access
    __shared__ float sA[TILE][TILE + 1];
    __shared__ float sB[TILE][TILE + 1];

    int row = blockIdx.y * TILE + threadIdx.y;
    int col = blockIdx.x * TILE + threadIdx.x;
    float sum = 0.0f;

    int numTiles = (N + TILE - 1) / TILE;

    for (int t = 0; t < numTiles; t++) {
        int aCol = t * TILE + threadIdx.x;
        int bRow = t * TILE + threadIdx.y;

        sA[threadIdx.y][threadIdx.x] = (row < N && aCol < N)
            ? A[row * N + aCol] : 0.0f;
        sB[threadIdx.y][threadIdx.x] = (bRow < N && col < N)
            ? B[bRow * N + col] : 0.0f;

        __syncthreads();

        #pragma unroll
        for (int k = 0; k < TILE; k++) {
            sum += sA[threadIdx.y][k] * sB[k][threadIdx.x];
        }

        __syncthreads();
    }

    if (row < N && col < N) {
        C[row * N + col] = sum;
    }
}

The three additions over the basic tiled version:

  1. Padding (TILE + 1): eliminates bank conflicts when threads read columns of sB.
  2. #pragma unroll: hints the compiler to fully unroll the inner loop, reducing loop overhead and enabling instruction-level parallelism.
  3. Boundary checks with zero-fill: handles matrices whose dimensions are not multiples of TILE by loading zeros for out-of-bounds elements.

In practice

Shared memory tiling is the foundation of every high-performance GPU kernel, not just matrix multiply. Convolutions, reductions, scans, and stencil operations all use the same pattern: cooperatively load, synchronize, compute, synchronize, repeat.

Production considerations:

  • Tile size selection. 16 x 16 or 32 x 32 are standard choices. Larger tiles improve reuse but consume more shared memory, reducing occupancy. Profile both and pick the one with higher throughput.
  • Double buffering. Overlap global memory loads for the next tile with computation on the current tile. This hides memory latency and typically adds another 15-25% throughput.
  • Vectorized loads. Use float4 to load 16 bytes at once instead of 4. This reduces the number of load instructions and improves coalescing.
  • Register tiling. Each thread computes a small sub-tile (e.g., 4 x 4 or 8 x 8) instead of a single element. This increases arithmetic intensity per thread and reduces shared memory traffic. cuBLAS uses this extensively.
  • Do not rewrite cuBLAS. For standard GEMM, cuBLAS and cuBLASLt will always be faster than a hand-written kernel. Write custom tiled kernels for fused operations (e.g., matmul + bias + ReLU) where calling separate cuBLAS and cuDNN kernels would require extra global memory round-trips.
  • Profile bank conflicts. Use ncu --metrics l1tex__data_bank_conflicts_pipe_lsu_mem_shared_op_ld.sum to measure actual bank conflicts. Zero conflicts does not always mean optimal, but non-zero conflicts almost always means something is wrong.

Common mistakes

Mistake                                    Symptom                                                   Fix
Missing __syncthreads() after load         Race condition: some threads read stale data              Add barrier between load and compute phases
Missing __syncthreads() before next tile   Race condition: next tile overwrites data still in use    Add barrier between compute and next load
TILE too large for shared memory           Kernel fails to launch (too many resources requested)     Reduce TILE or adjust shared memory config
Not checking boundaries                    Garbage values when N is not a multiple of TILE           Load 0.0f for out-of-bounds indices
Forgetting padding on column access        32-way bank conflicts, roughly 30% performance loss       Add +1 to the inner dimension

What comes next

Tiled shared memory is the single biggest optimization you can apply to memory-bound GPU kernels. But reaching cuBLAS-level performance requires additional techniques: register tiling, warp-level primitives (tensor cores on Ampere+), and double buffering. Before optimizing further, you need tools to measure where time is actually spent.

The next article, Debugging and profiling CUDA programs, covers Nsight Compute, Nsight Systems, and systematic approaches to finding and fixing performance bottlenecks in CUDA kernels.
