Note 001: GEMM Optimization

Iterating on General Matrix Multiplication in CUDA for optimal performance on NVIDIA GPUs
CUDA
GEMM
Linear Algebra
Published

February 6, 2026

Introduction

General Matrix Multiply, or GEMM, is a linear algebra operation that comprises the majority of computing done by modern deep learning models. In this note, I will explain how we can iteratively optimize GEMM implementations in CUDA until we have almost saturated the capability of modern NVIDIA GPUs.

Mathematical definition

Formally, GEMM is defined as an operation on two input matrices \(A\) and \(B\), and an accumulation matrix \(C\), scaled by scalars \(\alpha\) and \(\beta\):

\[ C = \alpha \cdot (A \times B) + \beta \cdot C \]

Where:

  • \(A\) is an \(M \times K\) matrix.
  • \(B\) is a \(K \times N\) matrix.
  • \(C\) is an \(M \times N\) matrix.

In deep learning contexts, \(\beta\) is often 0 (overwriting the output) or 1 (accumulating gradients), and \(\alpha\) is typically 1.
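To make the definition concrete, here is a minimal CPU reference sketch. It assumes row-major storage and plain `float` values, and the name `gemm_reference` is hypothetical; it is only meant as a correctness baseline, not as one of the optimized kernels in this note.

```cpp
#include <vector>

// Reference GEMM on the CPU: C = alpha * (A x B) + beta * C.
// All matrices are row-major: A is M x K, B is K x N, C is M x N.
// (Hypothetical helper for illustration; not part of the GPU kernels.)
void gemm_reference(const std::vector<float>& A, const std::vector<float>& B,
                    std::vector<float>& C, int M, int N, int K,
                    float alpha, float beta) {
    for (int row = 0; row < M; ++row) {
        for (int col = 0; col < N; ++col) {
            // Dot product of row `row` of A with column `col` of B
            float acc = 0.0f;
            for (int i = 0; i < K; ++i) {
                acc += A[row * K + i] * B[i * N + col];
            }
            // Scale the product and blend in the existing value of C
            C[row * N + col] = alpha * acc + beta * C[row * N + col];
        }
    }
}
```

With \(\beta = 0\) the old contents of C are discarded, and with \(\beta = 1\) the product is added onto them.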

Why GEMM?

In modern Transformer architectures, GEMM operations account for the vast majority of total Floating Point Operations (FLOPs). The linear projections and feed-forward layers are matrix multiplications, and so is most of the Attention operation: \(\text{softmax}(\frac{Q \times K^T}{\sqrt{d}}) \times V\). Aside from the softmax operation, everything else in Attention can be represented as GEMM:

  1. Calculating the scaled attention scores (\(\frac{Q \times K^T}{\sqrt{d}}\)).
  2. Calculating the weighted sum of values (\(\text{scores} \times V\)).

Since GEMM dominates the runtime, even a small percentage improvement in kernel efficiency can realize massive savings in training and inference costs at scale.
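As a rough sketch of how these FLOP counts add up, the helpers below count one multiplication and one addition per term of each dot product. The function names and the single-head setup (Q, K, and V each of shape `seq_len` x `head_dim`) are assumptions made for illustration.

```cpp
#include <cstdint>

// FLOPs for one GEMM of shape (M x K) * (K x N): each of the M * N outputs
// needs K multiplications and K additions.
std::int64_t gemm_flops(std::int64_t M, std::int64_t N, std::int64_t K) {
    return 2 * M * N * K;
}

// FLOPs for the two attention GEMMs of a single head (hypothetical helper).
// Q, K, and V are each (seq_len x head_dim).
std::int64_t attention_gemm_flops(std::int64_t seq_len, std::int64_t head_dim) {
    std::int64_t scores = gemm_flops(seq_len, seq_len, head_dim);    // Q x K^T
    std::int64_t weighted = gemm_flops(seq_len, head_dim, seq_len);  // scores x V
    return scores + weighted;
}
```

For example, `gemm_flops(4096, 4096, 4096)` corresponds to the largest problem size considered in this note.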

Problem setup

As I iterate on GEMM kernels, I will test them on the General Matrix Multiplication test suite and infrastructure on LeetGPU (“LeetGPU: Competitive GPU Programming” 2026). As per the problem setup there, I will only be using native capabilities of the GPUs, so no libraries like CuTe or cuBLAS. The test suite is hidden, but the known constraints are that each of the matrix dimensions \(M\), \(N\), and \(K\) is between 16 and 4096. So the input matrices range from very small (256 elements) to fairly large (about 16.8 million elements). The platform reports the runtime of the kernel on one large test case whose dimensions are unknown to us. The input matrices A and B are given as type half (half-precision floating point, FP16). Lower-precision floats are common in AI workloads as they take up less space and allow for higher throughput. For improved accuracy, the GEMM output is accumulated in single-precision (FP32) floats, and the final result is then stored back as a half-precision float.

For each kernel, I will explain the algorithm, how it interacts with the GPU architecture and memory hierarchy, and show the full code in CUDA C++. Finally, I will discuss the arithmetic intensity of the kernel and benchmark its performance on the following NVIDIA GPUs: Turing T4 (2018), Ampere A100-80GB (2020), Hopper H100 (2022), Hopper H200 (2023), and Blackwell B200 (2024).

Assumed background

I will assume the reader understands the basics of the CUDA programming model. If not, I recommend reading the first 6 chapters of Programming Massively Parallel Processors (Kirk and Hwu 2022), an excellent resource and probably the canonical text on this topic.

1. Naive Matrix Multiplication

In a naive parallel computing model, we can have every thread be solely responsible for computing exactly one output element in the final matrix. Each thread would load the row from A and column from B that it needs for the dot product for that output element.

The numbered notes after the code explain key parts.

Annotated Code

#include <cuda_fp16.h>
#include <cuda_runtime.h>

__global__ void gemm_naive_kernel(const half* A, const half* B, half* C,
                                  int M, int N, int K, 
                                  float alpha, float beta) { 
    
    // Calculate global row and column indices for this thread
    int col = blockIdx.x * blockDim.x + threadIdx.x;
    int row = blockIdx.y * blockDim.y + threadIdx.y;
    
    // Boundary check: ensure we don't access memory outside the matrix
    if (row < M && col < N) {
        float val = 0;
        
        // The K-loop: Perform the dot product
        for (int i = 0; i < K; i++) {
            val += __half2float(A[row * K + i]) * __half2float(B[i * N + col]);
        }
        
        // Write result back to C
        val = alpha * val + beta * __half2float(C[row * N + col]);
        C[row * N + col] = __float2half(val);
    }
}

// Wrapper function to be called from Host
extern "C" void solve(const half* A, const half* B, half* C, int M, int N, int K, float alpha, float beta) { 

    dim3 block(16, 16);
    // Grid calculation: ensures we cover the entire matrix (ceiling division)
    dim3 grid(
        (N + 15) / 16,
        (M + 15) / 16
    );

    gemm_naive_kernel<<<grid, block>>>(A, B, C, M, N, K, alpha, beta);
    cudaDeviceSynchronize();
}
  1. half vs. float: We use half precision (FP16) for storage but perform accumulation in float (FP32). Storing halves lets us move data faster from global memory (only 2 bytes per element rather than 4), while accumulating in FP32 means we don’t lose small updates to FP16’s short 10-bit mantissa. (For example, imagine adding 0.01 to a running sum of 1000: in FP16 the gap between adjacent representable values near 1000 is 0.5, so the update is rounded away entirely.)
  2. The Bottleneck: The loads inside the K-loop are the performance killer. For every single element of C, we are fetching an entire row of A and an entire column of B from Global Memory (DRAM).
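The precision argument for FP32 accumulation can be checked on the CPU with a small simulation. This is a sketch under simplifying assumptions: `round_to_half_precision` is a hypothetical helper that keeps 11 significant bits (FP16's 10 stored mantissa bits plus the implicit leading bit) and ignores FP16's exponent range, subnormals, infinities, and NaN.

```cpp
#include <cmath>

// Round x to 11 significant bits, mimicking FP16's mantissa width.
// Simplification: exponent range, subnormals, infinities, and NaN are ignored.
float round_to_half_precision(float x) {
    if (x == 0.0f) return x;
    int exponent = 0;
    float mantissa = std::frexp(x, &exponent);  // x = mantissa * 2^exponent
    float scaled = std::ldexp(mantissa, 11);    // shift 11 bits above the point
    return std::ldexp(std::nearbyint(scaled), exponent - 11);
}

// Add `update` to `start` `count` times, rounding the running sum to
// FP16 precision after every addition.
float accumulate_in_half(float start, float update, int count) {
    float sum = round_to_half_precision(start);
    for (int i = 0; i < count; ++i) {
        sum = round_to_half_precision(sum + update);
    }
    return sum;
}

// The same accumulation with the running sum kept in FP32.
float accumulate_in_float(float start, float update, int count) {
    float sum = start;
    for (int i = 0; i < count; ++i) {
        sum += update;
    }
    return sum;
}
```

Calling `accumulate_in_half(1000.0f, 0.01f, 1000)` leaves the sum stuck at 1000, because every individual update is smaller than half the spacing between representable values, while the FP32 version ends up close to 1010.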

Arithmetic Intensity

For each output element of C, we load K elements of A and K elements of B in order to compute a dot product. For each pair of elements in the dot product, we multiply them together and then add the result to the running sum. Therefore, for every 2 halves we load from global memory (a total of 4 bytes), we perform 2 floating point operations. So our arithmetic intensity is 2 FLOPs divided by 4 bytes, or 0.5 FLOP/B.
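Written out per output element, and assuming that the single read and write of C and any cache hits are ignored, the count is:

\[ \text{AI}_{\text{naive}} = \frac{K \ \text{multiplications} + K \ \text{additions}}{(K + K) \ \text{halves} \times 2 \ \text{B}} = \frac{2K}{4K} = 0.5 \ \text{FLOP/B} \]

Note that \(K\) cancels, so this figure does not depend on the matrix size.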

Benchmarks

Below, we can see the runtime of our kernel on the same test suite for each GPU. We can also compare the arithmetic intensity of the kernel to the ridge point of each GPU (the arithmetic intensity at which kernels switch from memory-bound to compute-bound). This kernel is highly memory-bound on every GPU. Our first course of action to improve the performance of our kernel should be to rethink our memory access pattern.

If our arithmetic intensity is below the Ridge Point, kernels are memory bound. Above the Ridge Point, kernels are compute bound.
| GPU Model | Memory Bandwidth | Peak FP16 Compute | Ridge Point (FLOP/Byte) | Runtime (ms) |
|---|---|---|---|---|
| NVIDIA T4 | 320 GB/s | 65 TFLOPS | 203 | 8.49 |
| NVIDIA A100 (80GB) | 2,039 GB/s | 312 TFLOPS | 153 | 1.03 |
| NVIDIA H100 (SXM) | 3,350 GB/s | 989 TFLOPS | 295 | 0.54 |
| NVIDIA H200 (SXM) | 4,800 GB/s | 989 TFLOPS | 206 | 0.53 |
| NVIDIA B200 | 8,000 GB/s | 2,500 TFLOPS | 312 | 0.50 |

2. Tiled Matrix Multiplication

The main issue with our memory access pattern above was that we are redundantly accessing each row N times and each column M times. Why? Recall that the output C is an M x N matrix. Therefore for \(C_{1,1}\), we need to compute the dot product of row 1 of A with column 1 of B; then for \(C_{2,1}\), we need to compute the dot product of row 2 of A with column 1 of B again. So we retrieve column 1 of B from global memory a total of M times. Similarly, row 1 of A is retrieved from global memory a total of N times, since we access it once for each element in row 1 of the output.

Memory hierarchy of an A100-40GB (“Memory Hierarchy of GPUs” 2025)

When we execute our kernel, we pass it a launch configuration that defines the number of blocks in the grid and the number of threads per block, along with how each is indexed. Each block is assigned to exactly one Streaming Multiprocessor (SM) of the GPU, and several blocks can be resident on the same SM at any given time. So all threads in an individual block have access to the same Shared Memory and L1 Cache on their resident SM during execution. We can take advantage of this fast on-chip memory to reduce our global memory accesses: data loaded once can be reused by many threads. Designing a kernel around this kind of reuse is known as exploiting data locality.

Visualization of tiled matrix multiplication (Matthes et al. 2017)

In tiled matrix multiplication, we choose a tile size which will comprise the total threads in a single block. We will choose 16 x 16 as our tile size so that we have a nice total of 256 threads per block. (32 x 32 would also work, but that is already 1024 threads, the hardware maximum per block on these GPUs, so we cannot go any larger with one thread per tile element.) We then loop over a wide row in A and a wide column in B, one tile at a time, as shown above. During each loop iteration, we have a single tile in A and a single tile in B to process. Each thread is responsible for loading one element each from A and B into the block’s shared memory. Then in an inner loop, we compute the product of those tiles and add it to the running sum for the output tile. By the end of the outer loop, we have loaded and processed all elements required for the final value of every element in the 16 x 16 output tile, and so we can write to global memory.
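To see how much global memory traffic tiling saves, we can count the halves read from global memory by each approach. This is a counting sketch only: the function names are hypothetical, coarsening is ignored, and M, N, and K are assumed to be multiples of the tile width.

```cpp
#include <cstdint>

// Halves read from global memory by the naive kernel: each of the M * N
// threads reads K elements of A and K elements of B.
std::int64_t naive_global_loads(std::int64_t M, std::int64_t N, std::int64_t K) {
    return M * N * 2 * K;
}

// Halves read by the tiled kernel (no coarsening), assuming M, N, and K are
// multiples of `tile`: each of the M * N threads reads one element of A and
// one element of B per phase, and there are K / tile phases.
std::int64_t tiled_global_loads(std::int64_t M, std::int64_t N, std::int64_t K,
                                std::int64_t tile) {
    return M * N * 2 * (K / tile);
}
```

The ratio of the two counts equals the tile width, so a 16 x 16 tile cuts global reads by a factor of 16.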

One additional optimization we introduce here is thread coarsening. This means that each thread is tasked with doing more work independently. The advantage of this approach is that if our grid ends up launching more total blocks than the hardware can assign to its SMs, then the blocks will inevitably be queued for assignment and execution. In that case, the blocks will be executed serially anyway, so we may as well have threads do more work in the first place and reduce some redundant data loading and synchronization overhead. However, we must be careful not to coarsen so much that we are no longer taking full advantage of the hardware. For our tiled matrix multiplication kernel, it can make sense for large matrices to have some coarsening. This is because although we have reduced redundancy in global memory accesses, we still will access the same “wide row” in A in two different blocks for two side-by-side output tiles in C. We can experiment with having a thread coarsening factor of 2, which means each block will process two output tiles in C rather than one.

Annotated Code

#include <cuda_fp16.h>
#include <cuda_runtime.h>

#define TILE_WIDTH 16
#define COARSE_FACTOR 2

__global__ void gemm_tiled_kernel(const half* A, const half* B, half* C, int M, int N, int K, float alpha, float beta) {
    
    __shared__ float As[TILE_WIDTH][TILE_WIDTH];
    __shared__ float Bs[TILE_WIDTH][TILE_WIDTH];

    int row = TILE_WIDTH * blockIdx.y + threadIdx.y;
    int colStart = COARSE_FACTOR * TILE_WIDTH * blockIdx.x + threadIdx.x;

    float sum[COARSE_FACTOR]; 
    #pragma unroll
    for (int c = 0; c < COARSE_FACTOR; c++) {
        sum[c] = 0.0f;
    }

    // Loop over the K-dimension (shared dimension)
    for (int phase = 0; phase < (K + TILE_WIDTH - 1) / TILE_WIDTH; phase++) {
        
        // --- Load A ---
        // A is (M x K). 
        // Row comes from global 'row'. 
        // Col comes from 'phase' and 'threadIdx.x'.
        int a_col = phase * TILE_WIDTH + threadIdx.x;
        As[threadIdx.y][threadIdx.x] = 
            (row < M && a_col < K) ? 
            __half2float(A[row * K + a_col]) : 0.0f;

        #pragma unroll
        for (int c = 0; c < COARSE_FACTOR; c++) {
            int col = colStart + c * TILE_WIDTH;

            // --- Load B ---
            // B is (K x N). 
            // Row comes from 'phase' and 'threadIdx.y'. 
            // Col comes from global 'col'.
            int b_row = phase * TILE_WIDTH + threadIdx.y;
            
            Bs[threadIdx.y][threadIdx.x] = 
                (b_row < K && col < N) ?
                __half2float(B[b_row * N + col]) : 0.0f; 
            
            __syncthreads();

            for (int j = 0; j < TILE_WIDTH; j++) {
                sum[c] += As[threadIdx.y][j] * Bs[j][threadIdx.x];
            }
            __syncthreads();
        }
    }

    #pragma unroll
    for (int c = 0; c < COARSE_FACTOR; c++) {
        int col = colStart + c * TILE_WIDTH;
        if (row < M && col < N) {
            int idx = row * N + col; // C is (M x N), stride is N
            float initial_val = __half2float(C[idx]);
            C[idx] = __float2half(alpha * sum[c] + beta * initial_val);
        }
    }
}

extern "C" void solve(const half* A, const half* B, half* C, int M, int N, int K, float alpha,
                      float beta) {

    dim3 block(TILE_WIDTH, TILE_WIDTH);
    
    dim3 grid(
        (N + (TILE_WIDTH * COARSE_FACTOR) - 1) / (TILE_WIDTH * COARSE_FACTOR),
        (M + TILE_WIDTH - 1) / TILE_WIDTH
    ); 

    gemm_tiled_kernel<<<grid, block>>>(A, B, C, M, N, K, alpha, beta);
    cudaDeviceSynchronize();
}
  1. Shared memory: We declare our block shared memory. One can also dynamically pass the total size of block shared memory to the kernel at runtime if desired. In our case, we have a predetermined tile width. Note that we need to be cognizant of the total shared memory available on an SM. Our oldest GPU, the T4, has 64 KB of shared memory per SM. Here, we have two arrays of 16 x 16 floats each, so 512 total floats, or 2 KB. We’re well within the limits. I went ahead and converted the halves to floats at this stage since we’re so far within shared memory limits, but to halve the shared memory allocation (to 1 KB), we could declare the shared memory arrays as type half and convert them at compute time.
  2. Coarsening: We set COARSE_FACTOR to 2, so in each phase every thread loads one element of A and two elements of B (one per output tile), and computes 2 output elements in C. The A tile is loaded once and reused for both output tiles. Each block covers two horizontally adjacent output tiles, so we need to apply our coarsening factor to our column computation.
  3. Loop unrolling: #pragma unroll is a directive that asks the compiler to try to unroll the loop fully, especially if the total number of iterations is known at compile time. To unroll a loop means to duplicate the code in the loop body rather than perform a condition check and a jump back to the start of the loop body. This allows us to avoid the execution speed cost of checking the loop condition, with the tradeoff of increasing code size. From here on out, we will typically unroll any loop with a constant number of iterations.
  4. Boundary checks: Our tiles are a fixed size. So if our matrix dimensions are not all multiples of 16, some tiles won’t be fully contained within the input matrices and would try to access out-of-bounds indices. We simply set these values to 0 in shared memory so that they contribute 0 to the sum and don’t impact the result.
  5. __syncthreads(): This instruction forces each thread in the block to halt here and wait until every other thread in the block reaches this point. The first __syncthreads() call guards against a Read-After-Write hazard, and the one after it guards against a Write-After-Read hazard. In the first case, individual threads rely on reading shared memory that other threads in their block are writing to. In the second case, if we don’t have a barrier, then some threads risk proceeding to the next loop iteration and modifying shared memory before other threads have read it for their computation on the previous iteration.
  6. Another boundary check: When we write to C, we again need to check that we are within bounds, since some tiles may not be fully contained at the end of the grid.
  7. Grid calculation with coarsening: We adjust our grid calculation to account for the coarsening in the horizontal dimension; this impacts the total number of blocks we need horizontally.
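The grid calculation can be checked on the host without a GPU. The sketch below mirrors the ceiling divisions in `solve`; the `GridDims` struct and function names are hypothetical stand-ins for `dim3`.

```cpp
// Hypothetical stand-in for the x and y components of a dim3 grid.
struct GridDims {
    int x;
    int y;
};

// Ceiling division for positive integers.
int ceil_div(int a, int b) {
    return (a + b - 1) / b;
}

// Number of blocks needed to cover an M x N output when each block covers
// tile_width rows and tile_width * coarse_factor columns.
GridDims tiled_grid(int M, int N, int tile_width, int coarse_factor) {
    return GridDims{ceil_div(N, tile_width * coarse_factor), ceil_div(M, tile_width)};
}
```

For a 100 x 100 output with `TILE_WIDTH = 16` and `COARSE_FACTOR = 2`, this gives a 4 x 7 grid: the last block in each dimension hangs over the edge of C, which is why the kernel needs its boundary checks.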

Arithmetic Intensity

Now that we are reusing data loaded from global memory, our arithmetic intensity is higher. Let’s first ignore the coarsening factor. A single thread is computing a single output element in C, but it doesn’t have to load every element in the vectors of A and B that are used for that dot product. It only has to load one element of A and one element of B per tile, and then it benefits from the other 15 elements it needs from each matrix for each tile that were loaded by other threads. Therefore we reduced the number of global memory accesses by a factor of 16. But we are performing the same number of floating point operations, so our arithmetic intensity is simply 16 times higher than that of the naive kernel: 8 FLOP/B. Coarsening nudges this up slightly, because the A tile is loaded once and shared by both output tiles: with COARSE_FACTOR = 2, each thread loads 3 halves (6 bytes) per phase and performs 64 FLOPs, or about 10.7 FLOP/B. Either way, the kernel remains far below the ridge point of every GPU.

Benchmarks

The runtime improved with our increase in arithmetic intensity. The kernel is still memory-bound on every GPU, though. In the next section, we will address this by taking advantage of a fundamental hardware capability that happens to be available in every GPU in our test set.

| GPU Model | Memory Bandwidth | Peak FP16 Compute | Ridge Point (FLOP/Byte) | Runtime (ms) |
|---|---|---|---|---|
| NVIDIA T4 | 320 GB/s | 65 TFLOPS | 203 | 6.73 |
| NVIDIA A100 (80GB) | 2,039 GB/s | 312 TFLOPS | 153 | 0.72 |
| NVIDIA H100 (SXM) | 3,350 GB/s | 989 TFLOPS | 295 | 0.37 |
| NVIDIA H200 (SXM) | 4,800 GB/s | 989 TFLOPS | 206 | 0.36 |
| NVIDIA B200 | 8,000 GB/s | 2,500 TFLOPS | 312 | 0.33 |

3. Warp Matrix Multiply Accumulate

Every GPU in our test suite is modern enough to be equipped with Tensor Cores: specialized matrix-multiply-and-accumulate units that deliver far higher throughput than the general-purpose CUDA cores. Each SM has several of these Tensor Cores. In NVIDIA’s original description, an individual Tensor Core performs the operation \(D = A \times B + C\), where every matrix in the operation has size 4x4. We call the shape of this operation 4x4x4. Additionally, Tensor Cores natively handle mixed-precision: in the mode we use here, the input matrices A and B are half-precision (FP16), while the accumulators C and D can be either FP16 or FP32.

Tensor Core performing a 4x4x4 matrix multiply and accumulate operation (2024)

This capability is exposed to us as the Warp Matrix Multiply Accumulate API (WMMA). During program execution, a full warp of execution will use multiple Tensor Cores at a time in order to process a 16x16x16 MMA operation.
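To build intuition for how a 16x16x16 operation decomposes into 4x4x4 pieces, here is a CPU sketch that accumulates a 16 x 16 product using only 4x4x4 block multiply-accumulates and counts how many are needed. It illustrates the arithmetic only; how the hardware actually schedules work across Tensor Cores is not specified here, and the function name is hypothetical.

```cpp
#include <vector>

// Accumulate C += A x B for row-major 16 x 16 matrices using only 4x4x4
// block multiply-accumulate steps. Returns the number of block steps used.
// (Arithmetic illustration only; not a model of hardware scheduling.)
int mma_16x16x16_from_4x4x4(const std::vector<float>& A,
                            const std::vector<float>& B,
                            std::vector<float>& C) {
    const int n = 16;  // full operation size
    const int b = 4;   // block size
    int block_ops = 0;
    for (int bi = 0; bi < n; bi += b) {
        for (int bj = 0; bj < n; bj += b) {
            for (int bk = 0; bk < n; bk += b) {
                // One 4x4x4 step: D_block = A_block x B_block + C_block
                for (int i = 0; i < b; ++i) {
                    for (int j = 0; j < b; ++j) {
                        float acc = C[(bi + i) * n + (bj + j)];
                        for (int k = 0; k < b; ++k) {
                            acc += A[(bi + i) * n + (bk + k)] * B[(bk + k) * n + (bj + j)];
                        }
                        C[(bi + i) * n + (bj + j)] = acc;
                    }
                }
                ++block_ops;
            }
        }
    }
    return block_ops;
}
```

There are 4 blocks along each of the three dimensions, so one 16x16x16 operation corresponds to 4 x 4 x 4 = 64 of these 4x4x4 steps.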

There are several advantages of using WMMA rather than manually programming the matrix multiply and accumulate operation like we did in previous kernels.

  1. Single instruction: As opposed to issuing separate multiplication and addition instructions manually, the warp scheduler issues a single instruction to the Tensor Core hardware, which proceeds to take over the rest of the operation. GPUs have a limited rate at which they can feed instructions to the execution units, so this allows us to issue memory requests much faster and get closer to saturating the memory bus.
  2. Matrix loading: The load_matrix_sync instruction in WMMA is optimized to use 128-bit global loads. So it retrieves 16 bytes (8 halves) in a single transaction. Meanwhile, when we manually load half data, we are loading 2 bytes at a time unless we specify otherwise (discussed in a subsequent section, when we explicitly issue vectorized loads).
  3. Managed register layout: WMMA fragments are held in the registers of the warp’s threads, and the API decides how the matrix elements are distributed across them so that the Tensor Cores can be fed efficiently. We don’t have to hand-tune this layout ourselves, and because this kernel loads fragments straight from global memory, we also sidestep shared memory bank conflicts (discussed in a subsequent section). Fragments do still count toward register pressure (when we allocate too many register-resident values, they can spill over to slower memory if we exceed the register capacity), but with only a handful of fragments per warp that is not a concern for this kernel.

One disadvantage of WMMA is that we are locked into a small set of fixed operation shapes, here 16x16x16. Later on, we’ll adapt our kernel to handle arbitrary matrix sizes. For now, we’ll have our host code decide whether to use our WMMA kernel based on the input matrix sizes.

Annotated Code

#include <cuda_runtime.h>
#include <cuda_fp16.h>

#include <mma.h>

using namespace nvcuda;

#define WARP_SIZE 32

__global__ void gemm_wmma(const half* A, const half* B, half* C, int M, int N, int K, float alpha, float beta) {
   
    // Leading dimensions for Row-Major matrices
    int lead_dim_A = K; // A: M x K. Stride between rows is K
    int lead_dim_B = N; // B: K x N. Stride between rows is N
    int lead_dim_C = N; // C: M x N. Stride between rows is N

    // 2D grid tiling. We will have multiple warps worth of threads in the x dimension.
    // Hence warp_col is divided by warp size.
    int warp_row = blockDim.y * blockIdx.y + threadIdx.y;
    int warp_col = (blockDim.x * blockIdx.x + threadIdx.x) / WARP_SIZE;

    // Declare fragments
    wmma::fragment<wmma::matrix_a, 16, 16, 16, half, wmma::row_major> A_frag;
    wmma::fragment<wmma::matrix_b, 16, 16, 16, half, wmma::row_major> B_frag;
    wmma::fragment<wmma::accumulator, 16, 16, 16, float> accum_frag;
    wmma::fragment<wmma::accumulator, 16, 16, 16, half> C_frag;

    // Initialize the accumulator fragment for A * B with zeroes.
    wmma::fill_fragment(accum_frag, 0.0f);

    for (int i = 0; i < K; i += 16) {
        // Get the starting row and column of our 16 x 16 tiles in both A and B.
        int row_A = warp_row * 16;
        int col_A = i;
        int row_B = i;
        int col_B = warp_col * 16;

        // Check bounds
        if (row_A < M && col_A < K && row_B < K && col_B < N) {

            // Load matrices. 
            wmma::load_matrix_sync(A_frag, A + row_A * lead_dim_A + col_A, lead_dim_A);
            wmma::load_matrix_sync(B_frag, B + row_B * lead_dim_B + col_B, lead_dim_B);

            // Perform MMA. 
            wmma::mma_sync(accum_frag, A_frag, B_frag, accum_frag);
        }
    }

    int row_C = warp_row * 16;
    int col_C = warp_col * 16;

    // Complete the GEMM operation: scale and add result fragments, then write to global memory
    if (row_C < M && col_C < N) {
        wmma::load_matrix_sync(C_frag, C + row_C * lead_dim_C + col_C, lead_dim_C, wmma::mem_row_major);

        for (int i = 0; i < C_frag.num_elements; i++) {
            C_frag.x[i] = __float2half(alpha * accum_frag.x[i] + beta * __half2float(C_frag.x[i]));
        }

        // Store the result in global memory
        wmma::store_matrix_sync(C + row_C * lead_dim_C + col_C, C_frag, lead_dim_C, wmma::mem_row_major);
    }
}

// Same as in Tiled Matrix Multiplication
__global__ void gemm_tiled_kernel(const half* A, const half* B, half* C, int M, int N, int K, float alpha, float beta) { ... }


extern "C" void solve(const half* A, const half* B, half* C, int M, int N, int K, float alpha, float beta) {

    if (M % 16 == 0 && N % 16 == 0 && K % 16 == 0) {
        const int WARPS_X = 4, WARPS_Y = 4;
        dim3 blockDim(WARPS_X * WARP_SIZE, WARPS_Y);
        
        int num_col_tiles = N / 16;
        int num_row_tiles = M / 16;
        
        dim3 gridDim(
            (num_col_tiles + WARPS_X - 1) / WARPS_X,
            (num_row_tiles + WARPS_Y - 1) / WARPS_Y
        );
        
        gemm_wmma<<<gridDim, blockDim>>>(A, B, C, M, N, K, alpha, beta);
    } else {
        dim3 block(TILE_WIDTH, TILE_WIDTH);
        dim3 grid(
            (N + (TILE_WIDTH * COARSE_FACTOR) - 1) / (TILE_WIDTH * COARSE_FACTOR), 
            (M + TILE_WIDTH - 1) / TILE_WIDTH
        );

        gemm_tiled_kernel<<<grid, block>>>(A, B, C, M, N, K, alpha, beta);
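    }
    // Wait for the launched kernel to finish, as in the earlier solve() wrappers
    cudaDeviceSynchronize();
}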
  1. Fragments: The operand matrices must be held in registers before MMA is performed. Since MMA is a warp-wide operation, these registers are distributed between the threads of a warp. Each thread holds a fragment of the overall matrix. A fragment is a templated type that accepts parameters for: the matrix the fragment holds, the shape of the overall operation, the data type, and whether the data is row or column major for the operand matrices. We pass in 16 three times for the shape of the overall operation to indicate that the number of rows, the number of columns, and the dot product length are all 16.
  2. The K-loop: Each warp computes one 16 x 16 tile of A * B. We loop over rows of A and columns of B. Each row of A and column of B has K elements. Overall, we are computing a 16 x 16 output tile in C: C (16 x 16) = A (16 x K) * B (K x 16). However, we can only store and use 16 x 16 chunks of A and B at once for the MMA operation. Therefore we need to split K into chunks of 16. On each loop iteration, we accumulate C (16 x 16) += A (16 x 16) * B (16 x 16).
  3. Loading into a fragment: To load data into a fragment, we need to specify the fragment to load into, the pointer to the memory we are loading from, and the leading dimension of the matrix (so that the operation knows the stride length between rows for a row-major matrix, or between columns for a column-major matrix).
  4. Matrix Multiply Accumulate: Computes Arg1 = Arg2 * Arg3 + Arg4. We pass accum_frag as both the first and the last argument, so the product is accumulated in place.
  5. Modifying data within fragments: There are 16 x 16 = 256 elements in C_frag and 32 threads per warp. Each thread therefore holds 256 / 32 = 8 elements. So the loop will have 8 iterations. The fragment’s internal storage is opaque: we don’t know which thread holds each element. Luckily, this doesn’t matter for element-wise operations like scaling. What about combining accum_frag and C_frag element by element? Both are accumulator fragments of the same 16x16x16 shape, and this kernel relies on them sharing the same internal element layout, so that x[i] refers to the same matrix position in each. Note that they differ in element type (float vs. half) and that the WMMA API leaves the layout unspecified, so this correspondence is an assumption about the target architectures rather than a documented guarantee. A more conservative approach is to store the accumulator to shared memory and apply the scaling there.
  6. Storing back to global memory: Here we need to pass the pointer to memory that we are storing into, the fragment we are storing from, the leading dimension of the matrix, and whether the matrix is row or column major.
  7. Restrictions on WMMA: This kernel handles 16x16x16 operations only, so we need to check that our matrix dimensions are multiples of 16. If not, we’ll launch our tiled GEMM kernel. In a later section, we will adjust our WMMA kernel to handle arbitrary matrix dimensions.
  8. Block and grid dimensions: The block works out to be (128, 4), so we have 512 total threads per block. Each row in our block has 128 threads, so a total of 4 warps, and then we have 4 rows, so we essentially have a 4x4 grid of warps in each block. Since each warp computes a 16x16 output tile, each warp is handling the same output as each block did in our tiled GEMM kernel (before coarsening). Since each block has a 4x4 grid of warps, we are then computing a 64x64 output tile of C for each block. We know that our matrix dimensions are divisible by 16, but they may not be divisible by 64. So in the blocks at the edge of our grid, we may have some warps that fall out of bounds of C. Luckily we have the necessary boundary checks in our kernel, so we just need to do our ceiling division here to ensure our blocks fully cover C, without worrying about whether some of them go beyond the edges of C.
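The launch configuration for the WMMA path can be reproduced on the host as a quick check. The `LaunchDims` struct and function name below are hypothetical stand-ins for the two `dim3` values in `solve`, and the sketch assumes M and N are multiples of 16, as the host code requires.

```cpp
// Hypothetical stand-in for the block and grid dim3 values.
struct LaunchDims {
    int block_x;
    int block_y;
    int grid_x;
    int grid_y;
};

// Mirrors the WMMA launch configuration: a 4 x 4 arrangement of warps per
// block, with each warp covering one 16 x 16 output tile.
// Assumes M and N are multiples of 16.
LaunchDims wmma_launch_dims(int M, int N) {
    const int warp_size = 32;
    const int warps_x = 4;
    const int warps_y = 4;
    const int wmma_tile = 16;

    int num_col_tiles = N / wmma_tile;
    int num_row_tiles = M / wmma_tile;

    return LaunchDims{
        warps_x * warp_size,                      // 128 threads along x
        warps_y,                                  // 4 threads along y
        (num_col_tiles + warps_x - 1) / warps_x,  // ceiling division
        (num_row_tiles + warps_y - 1) / warps_y
    };
}
```

For an 80 x 80 output there are 5 tiles per dimension, so the grid is 2 x 2 and the last block in each dimension has 3 of its 4 warp rows or columns out of bounds.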

Arithmetic Intensity

To calculate the arithmetic intensity of this kernel, we will focus on the main loop where the loading from global memory and MMA operations happen. On each loop iteration, a warp collectively loads one 16x16 tile from each of A and B. So we retrieve 512 half-precision floats for a total of 1024 bytes. Then we are modifying the running sums for a 16x16 output tile in C. For each element in this output tile, we are taking a dot product of two 16-element vectors, so we perform 16 multiplications and 16 additions. Therefore we perform 32 FLOPs for each element in the 16x16 output tile, for a total of 8192 FLOPs. Therefore, our arithmetic intensity is approximately 8192 / 1024 = 8 FLOP/B.
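As a single expression per loop iteration and per warp, assuming the one-time read and write of C outside the loop are ignored:

\[ \text{AI}_{\text{WMMA}} = \frac{16 \times 16 \times (16 + 16) \ \text{FLOPs}}{2 \times (16 \times 16) \ \text{halves} \times 2 \ \text{B}} = \frac{8192}{1024} = 8 \ \text{FLOP/B} \]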

Notice that this is the same 8 FLOP/B we derived for our previous tiled matrix multiplication kernel. In this kernel, I avoided using shared memory so that I could have a very simple and clear WMMA implementation. However, in reality, we can make use of the same collaborative shared memory loading technique from our prior kernel to improve the arithmetic intensity of our WMMA kernel even further. I will do exactly this (among other improvements) in subsequent sections.

The other thing I observed with this kernel is that despite having the same arithmetic intensity as our tiled matrix multiplication, it is significantly faster. This is because WMMA is a hardware-native operation. In the section introduction, we discussed the anatomy of a WMMA operation and why it is so fast, but I’ll call out a few ways the arithmetic intensity here is misleading. First, although it is standard to count multiplication and addition as separate FLOPs, they are fused into a single operation on the hardware when using Tensor Cores. Second, we discussed that WMMA fragments live in registers instead of shared memory. This is not reflected in our arithmetic intensity (which only takes into account global memory accesses). In our tiled GEMM kernel, data fetched from global memory is first staged in shared memory, so we still have to read it again from shared memory before the compute cores can use it. Here, we load from global memory directly into the fragment registers that feed the Tensor Cores.

Benchmarks

| GPU Model | Memory Bandwidth | Peak FP16 Compute | Ridge Point (FLOP/Byte) | Runtime (ms) |
|---|---|---|---|---|
| NVIDIA T4 | 320 GB/s | 65 TFLOPS | 203 | 1.68 |
| NVIDIA A100 (80GB) | 2,039 GB/s | 312 TFLOPS | 153 | 0.17 |
| NVIDIA H100 (SXM) | 3,350 GB/s | 989 TFLOPS | 295 | 0.10 |
| NVIDIA H200 (SXM) | 4,800 GB/s | 989 TFLOPS | 206 | 0.10 |
| NVIDIA B200 | 8,000 GB/s | 2,500 TFLOPS | 312 | 0.10 |

4. Double Buffer

The next improvement we can make to our kernel is the use of a double buffer. The goal of a double buffer is to hide the latency of fetching data from global memory. In our current implementation, when threads request data from global memory, the compute cores have to pause while we wait for the data to arrive. Then we start computing, but our memory units are now sitting idle. When we’re done, we request data again and repeat the cycle. At any given time, either our compute cores or memory units are sitting idle.

Instead, before we compute the current tile, we can issue an asynchronous request to load data for the next tile. Then our memory system will load data for the next tile while we compute the current tile. On Ampere and newer GPUs, there is dedicated hardware support for these asynchronous global-to-shared copies (exposed in PTX as cp.async), so the copy can proceed while the issuing threads carry on with other work.

The double buffer is so named because we declare shared memory that is double the size of what we need to compute on. That way, we can use half of the buffer to load the next tiles of A and B from global memory to shared memory, and the other half of the buffer holds the currently loaded data that we feed to our Tensor Cores. We can track which half of the buffer is ready and which is being loaded. So our process is as follows within each loop iteration:

  1. Asynchronously request data for the next tile to the half of the buffer we are not about to use.
  2. WMMA compute on the current tile, using the half of the buffer that is ready.
  3. Barrier wait until the asynchronous request is complete. Then swap the stage index that tells us which half of the buffer is ready, and proceed to the next loop iteration.
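The three steps above can be sketched on the CPU to show the buffer bookkeeping. This is only an illustration of the stage-swapping logic: the function name is hypothetical, the "load" here is an ordinary synchronous copy, and the per-tile "compute" is a plain sum standing in for the WMMA work. On the GPU, the load would be asynchronous and would overlap with the compute.

```cpp
#include <cstddef>
#include <vector>

// Sums `data` tile by tile using two staging buffers (a double buffer).
// Synchronous CPU stand-in: only the stage bookkeeping matches the GPU scheme.
float double_buffered_sum(const std::vector<float>& data, std::size_t tile) {
    if (data.empty() || tile == 0) return 0.0f;
    const std::size_t num_tiles = (data.size() + tile - 1) / tile;
    std::vector<float> buffers[2] = {std::vector<float>(tile), std::vector<float>(tile)};

    // Copy tile `t` into buffer `stage`, zero-padding past the end of data.
    auto load_tile = [&](std::size_t t, int stage) {
        for (std::size_t i = 0; i < tile; ++i) {
            std::size_t idx = t * tile + i;
            buffers[stage][i] = (idx < data.size()) ? data[idx] : 0.0f;
        }
    };

    int stage = 0;
    load_tile(0, stage);  // Prologue: fill the first half before the loop
    float sum = 0.0f;
    for (std::size_t t = 0; t < num_tiles; ++t) {
        // Step 1: request the next tile into the half we are not using
        if (t + 1 < num_tiles) load_tile(t + 1, stage ^ 1);
        // Step 2: compute on the half that is ready
        for (float v : buffers[stage]) sum += v;
        // Step 3: (on the GPU, wait for the async copy here) then swap halves
        stage ^= 1;
    }
    return sum;
}
```

Note the prologue load before the loop: the first tile has to be fetched up front, since there is no earlier iteration to prefetch it.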

There are a few other optimizations related to the data loading and grid configuration that we’ll pack into this kernel that warrant some explanation ahead of time. First, we will have each block be composed of 4 warps in a 2 x 2 grid (so 128 total threads). Each warp will be responsible for computing a 32 x 32 output tile of C, so in total one block will compute a 64 x 64 output tile.

To accomplish this, we will still loop over the K-dimension in a wide row in A and wide column in B, just as pictured in the image from tiled matrix multiplication. However, we will specify the wide row in A to have 64 rows, and the wide column in B to have 64 columns. We still loop over K via increments of 16 at a time. So in each loop iteration over K, we will use a 64 x 16 chunk of A and a 16 x 64 chunk of B. This is the same process as tiled matrix multiplication, but we are now using a non-square tile.

Because we have a 2 x 2 grid of warps, each warp will use a 32 x 16 chunk of A and a 16 x 32 chunk of B, and perform 4 WMMA operations (since they only take matrices of size 16 x 16). We then add their output to our accumulator fragments (4 for each warp, since each WMMA operation accumulates to a different 16 x 16 output tile) in each loop iteration. By the time our K loop is complete, our block has fully computed the value of \(A \times B\) for a 64 x 64 tile of C.

The reason we do this is similar to why we loaded to shared memory in our tiled GEMM: we want to avoid redundant data loading and load as much data from global memory at once as we can usefully share across our block. By arranging our warps in a 2 x 2 grid, we are also able to reuse more memory than if they were arranged in a straight line: a 2 x 2 grid needs 64 rows of A and 64 columns of B per K step, whereas a 1 x 4 line covering the same number of output elements would need 32 rows of A and 128 columns of B. For the collaborative data loading, we will use the thread ID in the block to determine what part of the current A (64 x 16) and B (16 x 64) chunks this thread will load. Each of these chunks can be treated as 128 8-half vectors, so each thread should load 8 elements from each chunk. To reduce the number of instructions to load from global memory, we will employ vectorized loads to load 8 halves at once. Therefore, our A chunk can be viewed as 64 rows of 2 vectors, and our B chunk can be viewed as 16 rows of 8 vectors. We will use a vectorized store to global memory in the final section too, when possible.
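
To make the thread-to-vector assignment concrete, here is a small host-side sketch of the mapping, assuming 128 threads and 8-half vectors as described above. `VecCoord`, `a_vector_for_thread`, and `b_vector_for_thread` are hypothetical names used only for illustration.

```cpp
#include <cassert>

struct VecCoord { int row; int col; };

// A chunk: 64 rows x 16 halves = 2 vectors of 8 halves per row.
VecCoord a_vector_for_thread(int tid) {
    assert(tid >= 0 && tid < 128);
    return {tid / 2, (tid % 2) * 8}; // row 0..63, col 0 or 8
}

// B chunk: 16 rows x 64 halves = 8 vectors of 8 halves per row.
VecCoord b_vector_for_thread(int tid) {
    assert(tid >= 0 && tid < 128);
    return {tid / 8, (tid % 8) * 8}; // row 0..15, col 0, 8, ..., 56
}
```

The last thread (127) therefore loads the vector starting at row 63, column 8 of the A chunk and the vector starting at row 15, column 56 of the B chunk, so the 128 threads cover both chunks exactly once.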

Annotated Code

#include <cuda_runtime.h>
#include <cuda_fp16.h>
#include <mma.h>
#include <cuda_pipeline_primitives.h>

using namespace nvcuda;

// ------------- CONFIGURATION -------------
constexpr int BLOCK_M = 64;
constexpr int BLOCK_N = 64; // One block computes a 64 x 64 tile of the output matrix
constexpr int BLOCK_K = 16; // Accumulation step
constexpr int WARP_SIZE = 32;
constexpr int THREAD_COUNT = 128;
constexpr int WMMA = 16;

__global__ void gemm_buffer_kernel(const half* A, const half* B, half* C, int M, int N, int K, float alpha, float beta) {

    // ------------- INDEX CALCULATIONS -------------
    // Linear view for data loading: which worker out of 128 threads am I?
    int tid = threadIdx.x;

    // Global position: what tile of the output matrix am I calculating?
    int block_row_start = blockIdx.y * BLOCK_M;
    int block_col_start = blockIdx.x * BLOCK_N;

    // What warp am I in the 2x2 grid?
    int warp_id = tid / WARP_SIZE;
    int warp_row = (warp_id / 2) * 32;
    int warp_col = (warp_id % 2) * 32;


    // A tile: 64 x 16. Each row has 2 8-element vectors. 
    int row_A = tid / 2;       // 0 to 63
    int col_A = (tid % 2) * 8; // 0 or 8
    // B tile: 16 x 64. Each row has 8 8-element vectors. 
    int row_B = tid / 8;       // 0 to 15
    int col_B = (tid % 8) * 8; // 0, 8, 16, 24, 32, 40, 48, or 56
    // ----------------------------------------------


    // ------------- MEMORY INITIALIZATION ----------
    // Double Buffer: Shared Memory
    __shared__ half sA[2][BLOCK_M * BLOCK_K]; // 64 rows, 16 cols (K)
    __shared__ half sB[2][BLOCK_K * BLOCK_N]; // 16 rows (K), 64 cols

    // Declare fragments and initialize accumulator. 
    wmma::fragment<wmma::matrix_a, WMMA, WMMA, WMMA, half, wmma::row_major> a_frag;
    wmma::fragment<wmma::matrix_b, WMMA, WMMA, WMMA, half, wmma::row_major> b_frag;
    wmma::fragment<wmma::accumulator, WMMA, WMMA, WMMA, float> accum_frag[2][2];

    #pragma unroll
    for (int i = 0; i < 2; i++) {
        #pragma unroll
        for (int j = 0; j < 2; j++) {
            wmma::fill_fragment(accum_frag[i][j], 0.0f);
        }
    }

    // Pipeline setup
    int stage = 0; // Alternates between 0 and 1
    // ----------------------------------------------


    // ------------- PROLOGUE -------------
    // Load the first tile. 
    {
        const half* src_A = A + (block_row_start + row_A) * K + (0 + col_A);
        half* dst_A = &sA[stage][row_A * BLOCK_K + col_A];

        const half* src_B = B + (0 + row_B) * N + (block_col_start + col_B);
        half* dst_B = &sB[stage][row_B * BLOCK_N + col_B];

        // Async copy. int4 is the size of 8 half elements
        __pipeline_memcpy_async(dst_A, src_A, sizeof(int4));
        __pipeline_memcpy_async(dst_B, src_B, sizeof(int4)); 

        __pipeline_commit();
        __pipeline_wait_prior(0);
        __syncthreads();
    }
    // ------------------------------------

    // ------------- MAIN LOOP -------------
    #pragma unroll
    for (int k = 0; k < K; k += BLOCK_K) {

        int k_next = k + BLOCK_K;

        // 1. LOAD the next tile asynchronously
        if (k_next < K) {
            // Turns 1 into 0 or 0 into 1
            int next_stage = 1 - stage;

            const half* src_A = A + (block_row_start + row_A) * K + (k_next + col_A);
            half* dst_A = &sA[next_stage][row_A * BLOCK_K + col_A];
            
            const half* src_B = B + (k_next + row_B) * N + (block_col_start + col_B);
            half* dst_B = &sB[next_stage][row_B * BLOCK_N + col_B];

            __pipeline_memcpy_async(dst_A, src_A, sizeof(int4)); 
            __pipeline_memcpy_async(dst_B, src_B, sizeof(int4)); 

            __pipeline_commit();
        }

        // 2. MATH: process the current tile. Recall we have a 2 x 2 grid of 16 x 16 subtiles for each warp.
        #pragma unroll
        for (int i = 0; i < 2; i++) {
            #pragma unroll
            for (int j = 0; j < 2; j++) {
                // Calculate pointer into shared memory for this sub-tile
                int smem_row = warp_row + (i * 16);
                int smem_col = warp_col + (j * 16);

                // Load fragments from shared memory
                half* tile_ptr_A = &sA[stage][smem_row * BLOCK_K];
                half* tile_ptr_B = &sB[stage][smem_col];

                wmma::load_matrix_sync(a_frag, tile_ptr_A, BLOCK_K);
                wmma::load_matrix_sync(b_frag, tile_ptr_B, BLOCK_N);

                // Multiply matrices and accumulate
                wmma::mma_sync(accum_frag[i][j], a_frag, b_frag, accum_frag[i][j]);
            }
        }

        // 3. WAIT for next tile
        if (k + BLOCK_K < K) {
            __pipeline_wait_prior(0);
            __syncthreads();
            stage = 1 - stage;
        }
    }
    // ------------------------------------

    __syncthreads(); // Since the syncthreads above won't execute on the last iteration
   
    // ------- EPILOGUE: Store C ----------
    // Size: 64 * 64 floats = 64 * 64 * 4 bytes = 16 KB. Fits easily in modern L1/Shared
    __shared__ float sC[BLOCK_M * BLOCK_N];

    // Warps dump their fragments to shared memory, one 16x16 subtile at a time.
    #pragma unroll
    for (int i = 0; i < 2; i++) {
        #pragma unroll
        for (int j = 0; j < 2; j++) {
            float* subtile_ptr = sC + (warp_row + i * 16) * BLOCK_N + (warp_col + j * 16);
            wmma::store_matrix_sync(subtile_ptr, accum_frag[i][j], BLOCK_N, wmma::mem_row_major);
        }
    }

    // Wait for all threads to write to sC
    __syncthreads();

    #pragma unroll
    for (int i = tid * 8; i < BLOCK_M * BLOCK_N; i += THREAD_COUNT * 8) {
        int row = i / BLOCK_N;
        int col = i % BLOCK_N;

        int global_row = block_row_start + row;
        int global_col = block_col_start + col;

        __align__(16) half buffer[8]; // 16-byte aligned so the int4 reinterpretation below is valid

        // Boundary check
        if (global_row < M && (global_col + 7) < N) {
            #pragma unroll
            for (int j = 0; j < 8; j++) {
                float val = alpha * sC[i + j];

                if (beta != 0.0f) {
                    float old_c = __half2float(C[global_row * N + global_col + j]);
                    val += beta * old_c;
                } 

                buffer[j] = __float2half(val);
            }
            
            // Vectorized store
            *(int4*)&C[global_row * N + global_col] = *(int4*)buffer;

        } else {
            #pragma unroll
            for (int j = 0; j < 8; j++) {
                if (global_row < M && (global_col + j) < N) {
                    int out_idx = global_row * N + global_col + j;

                    float val = alpha * sC[i + j];

                    if (beta != 0.0f) {
                        float old_c = __half2float(C[out_idx]);
                        val += beta * old_c;
                    } 

                    C[out_idx] = __float2half(val);
                }
            }
        }
    }
}


// Same as in Tiled Matrix Multiplication
__global__ void gemm_tiled_kernel(const half* A, const half* B, half* C, int M, int N, int K, float alpha, float beta) { ... }


extern "C" void solve(const half* A, const half* B, half* C, int M, int N, int K, float alpha, float beta) {
    
    if (M % 64 == 0 && N % 64 == 0 && K % 16 == 0) {
        
        dim3 blockDim(THREAD_COUNT);
        dim3 gridDim(N / BLOCK_N, M / BLOCK_M);
        gemm_buffer_kernel<<<gridDim, blockDim>>>(A, B, C, M, N, K, alpha, beta);

    } else {

        dim3 block(TILE_WIDTH, TILE_WIDTH);
        dim3 grid(
            (N + (TILE_WIDTH * COARSE_FACTOR) - 1) / (TILE_WIDTH * COARSE_FACTOR), 
            (M + TILE_WIDTH - 1) / TILE_WIDTH
        );

        gemm_tiled_kernel<<<grid, block>>>(A, B, C, M, N, K, alpha, beta);

    }
}
  1. Warp Grid: As we have 128 threads per block, we have 4 warps per block, which we arrange in a 2 x 2 grid. Each block computes a 64 x 64 output tile, so we need to assign each warp a 32 x 32 output tile.
  2. Data Loading: We treat the A and B tiles, 64 x 16 and 16 x 64 respectively, as linear arrays of 128 8-element vectors. So each thread is responsible for loading 8 halves of A and 8 halves of B to shared memory.
  3. Accumulator Grid: The accumulator is a 2 x 2 grid because each warp is assigned a 32 x 32 output tile but can only compute 16 x 16 at a time.
  4. __pipeline_memcpy_async: Instructs the Async Copy Engine to copy data from global memory to shared memory. As this is an asynchronous operation, the command returns immediately and allows us to continue with other instructions while the memory loads. We issue a vectorized load for 8 halves worth of data at once (sizeof(int4)).
  5. __pipeline_commit: Marks the end of a batch of copy commands. Effectively, memcpy_async adds the copy instruction to our shopping cart, and commit places the order.
  6. __pipeline_wait_prior: Since we pass in 0, we are pausing thread execution until all asynchronous copies this thread has committed are complete (in our case, a single batch containing one copy for A and one for B). In any case, after this line, we have to issue syncthreads because each thread is collaboratively loading a piece of A and B that every thread will need for compute.
  7. Writing to the Double Buffer: We load into the half of the double buffer that we’re not using this loop iteration.
  8. Reading from the Double Buffer: We pull data for the WMMA operation from the half of the double buffer that is ready.
  9. Notice the location of this command in the main loop compared to in the prologue. We only had to issue it immediately after placing the copy command in the prologue because we needed to load the very first tile for compute. In the main loop, we don’t need to hold up threads on the copy completion until we have finished all compute for this iteration.
  10. Vectorized Store: In this loop, we complete our GEMM operation by taking our accumulated result of A x B, scaling it by alpha, adding it to beta * C, and finally storing it in global memory. We have 64 * 64 = 4096 total elements to process and store, and 128 threads to do this. So we must process 32 elements per thread. If we vectorize this into processing 8 elements per step, we need only 4 steps per thread. However, we have an else block here that covers the tail elements once we have fewer than 8 elements left and can’t do a vectorized store.

Arithmetic Intensity

We will examine a single iteration of the main loop. We load a 64 x 16 chunk of A and a 16 x 64 chunk of B from global memory, for a total of 2,048 halves, which is 4,096 bytes. Our output tile is 64 x 64, and on each loop iteration, we accumulate a dot product of two 16-element vectors into each element of the output tile. This dot product consists of 16 multiplications and 16 additions, so 32 FLOPs per element. In total then, we perform 64 * 64 * 32 = 131,072 FLOPs per loop iteration. Dividing this by our global memory load of 4,096 bytes, we get an arithmetic intensity of 32 FLOPs/B. This intensity is a consequence of the larger tile size, not of the double buffer, which mainly helps with hiding memory latency. So we should theoretically have two different improvements that speed up our runtime: reusing more data due to the larger tile size, and latency hiding due to the double buffer. Thankfully, the runtime confirms this, as we can see considerable speedup on all GPUs.
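
The calculation above can be written as a small helper, which makes it easy to try other tile shapes. This is a sketch under the same assumptions as the paragraph (half-precision inputs, one chunk of A and one chunk of B loaded per K step, and only global memory loads counted); `arithmetic_intensity` is a hypothetical name.

```cpp
#include <cassert>

// FLOPs per byte of global memory traffic for one K step of a
// block_m x block_n output tile with half-precision (2-byte) inputs.
double arithmetic_intensity(int block_m, int block_n, int block_k) {
    assert(block_m > 0 && block_n > 0 && block_k > 0);
    const double bytes_per_half = 2.0;
    // One multiply and one add per (row, col, k) triple.
    double flops = 2.0 * block_m * block_n * block_k;
    // One chunk of A (block_m x block_k) plus one chunk of B (block_k x block_n).
    double bytes = bytes_per_half *
                   (double(block_m) * block_k + double(block_k) * block_n);
    return flops / bytes;
}
```

Calling `arithmetic_intensity(64, 64, 16)` gives 32, matching the figure above, while a 32 x 32 tile gives 16. Note that `block_k` cancels out: the ratio simplifies to `block_m * block_n / (block_m + block_n)`, so under this model only the output tile shape changes the intensity.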

Benchmarks

| GPU Model | Memory Bandwidth | Peak FP16 Compute | Ridge Point (FLOP/Byte) | Runtime (ms) |
|---|---|---|---|---|
| NVIDIA T4 | 320 GB/s | 65 TFLOPS | 203 | 1.04 |
| NVIDIA A100 (80GB) | 2,039 GB/s | 312 TFLOPS | 153 | 0.12 |
| NVIDIA H100 (SXM) | 3,350 GB/s | 989 TFLOPS | 295 | 0.05 |
| NVIDIA H200 (SXM) | 4,800 GB/s | 989 TFLOPS | 206 | 0.05 |
| NVIDIA B200 | 8,000 GB/s | 2,500 TFLOPS | 312 | 0.05 |
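
For reference, the ridge point column is just peak compute divided by memory bandwidth. The sketch below reproduces it from the table's own numbers and compares it against an arithmetic intensity; `ridge_point` and `is_memory_bound` are hypothetical helper names, and the comparison is the simple roofline reading, which ignores cache effects.

```cpp
#include <cassert>

// Ridge point in FLOP/Byte from peak compute (TFLOPS) and bandwidth (GB/s).
double ridge_point(double peak_tflops, double bandwidth_gb_per_s) {
    assert(peak_tflops > 0 && bandwidth_gb_per_s > 0);
    return (peak_tflops * 1e12) / (bandwidth_gb_per_s * 1e9);
}

// Under a simple roofline model, a kernel whose arithmetic intensity is
// below the ridge point is limited by memory bandwidth, not by compute.
bool is_memory_bound(double intensity, double ridge) {
    return intensity < ridge;
}
```

With the T4 numbers, `ridge_point(65, 320)` is about 203 FLOP/Byte, so an intensity of 32 FLOPs/B with respect to global memory is still well on the memory-bound side of the roofline for every GPU in the table.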

5. Swizzling

So far, we’ve taken pretty good advantage of NVIDIA GPU architecture. Let’s go down the checklist:

I haven’t yet discussed caches in detail. Take a look at the below diagram.

Memory hierarchy of an A100-40GB (“Memory Hierarchy of GPUs” 2025)

There are two types of caches on the GPU: L1 and L2. A separate L1 cache exists on each Streaming Multiprocessor. It lives in the same physical on-chip storage as Shared Memory, but the two are logically separate. We can control the split between shared memory and L1 if we so choose, but we can’t control what goes into L1 like we can for shared memory. L1 is a cache, so it’s hardware-managed and caches global memory accesses automatically. It handles some level of spatial and temporal locality automatically for us.

We discussed spatial locality briefly in the tiled GEMM section, but didn’t put a name to it. When we retrieve data from global memory, the GPU memory controller never fetches just a few bytes. It always fetches an aligned chunk of memory called a Cache Line, typically 128 bytes, which goes through and into the L1 cache. Ideally, all of the threads in a warp access contiguous memory addresses (i.e. Thread 0 reads address X, Thread 1 reads X + 4, etc). This is known as memory coalescing, and reduces the number of requests the memory controller needs to make to global memory, since we are using most or all of the full Cache Line retrieved every time, rather than just a fraction. Temporal locality means that the L1 will cache recently used data until its capacity is full and needs to evict old data. That way, in case we access the same data multiple times in a short period of time, we don’t need to retrieve it again from global memory as it is still in the cache.
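
A simplified way to see the effect of coalescing is to count how many distinct 128-byte lines a warp touches for a given access pattern. The sketch below does this on the host; it assumes the 128-byte line size mentioned above and ignores finer details such as sectors, and `cache_lines_touched` is a hypothetical name.

```cpp
#include <cassert>
#include <cstdint>
#include <set>

// Number of distinct 128-byte lines touched when each of the 32 threads in
// a warp reads `access_bytes` bytes starting at base + lane * stride_bytes.
int cache_lines_touched(std::uint64_t base, std::uint64_t stride_bytes,
                        std::uint64_t access_bytes) {
    assert(access_bytes > 0);
    const std::uint64_t line_bytes = 128;
    std::set<std::uint64_t> lines;
    for (std::uint64_t lane = 0; lane < 32; ++lane) {
        std::uint64_t first = base + lane * stride_bytes;
        std::uint64_t last = first + access_bytes - 1;
        // An access may straddle a line boundary, so record every line it spans.
        for (std::uint64_t line = first / line_bytes; line <= last / line_bytes; ++line) {
            lines.insert(line);
        }
    }
    return static_cast<int>(lines.size());
}
```

With 4-byte reads at a 4-byte stride from an aligned base, the warp touches a single line. With the same reads at a 128-byte stride, it touches 32 lines and uses only 4 bytes of each.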

The L2 cache functions in a similar way but is much larger and global to the whole GPU. As a tradeoff, it is also much slower to access for a thread than its local L1 cache. We are already taking advantage of locality in our L1 cache in our previous kernels by ensuring threads in a warp are reading contiguous chunks of data. But we haven’t yet taken advantage of the L2 cache. The particular insight we need is that every block has access to the L2 cache. Ideally, we would figure out a way to establish some inter-block temporal locality: after one block accesses data from global memory, other blocks executing within a short time thereafter will reuse that data before it is evicted from the L2 cache.

Let’s think about what’s happening in our standard grid and tiling logic. Since we defined a 2D grid of blocks, and each block corresponds to a certain output tile in the matrix C, our blocks end up being scheduled in roughly row-major order of their block index. The hardware does not guarantee this order, but it is a useful working model. Look at the first row of tiles in matrix C below.

Visualization of tiled matrix multiplication (Matthes et al. 2017)

Imagine that our L2 cache can only fit 16 tiles. For each tile in that first row in C, we are repeatedly using the first row of tiles of A. However, we use a different column of B each time. By the time we’re on the fourth tile of C in the first row of tiles, we’re now loading in the fourth column of tiles of B, but we already have 16 tiles in our L2 cache (one row of tiles from A and three columns of tiles from B). So we have to evict some data to make room. We’ve been continuously reusing the first row of A, so that won’t be evicted; instead, we’ll evict the first column of B. But the next output tile we will compute for C after this one is the first tile in the second row, which would have reused the first column of B. Sadly, we just evicted it, so we’ll have to pull it from global memory again.

Instead, what we could do, given our L2 cache size, is split C into “newspaper columns”, each having a width of 2 tiles. We will adjust our block execution order so that we traverse the first newspaper column fully before we proceed to the second one. Now what happens? For the first two tiles of C, it’s the same logic as before. Our L2 cache now has the first row of A and first two columns of B (12 tiles). But now we hit the edge of our newspaper column, so we go down to the first tile in the second row of C. We load in the second row of A to the L2 cache, and now we have reached the cache capacity of 16 tiles. However, we are going to reuse the columns of B that are already in the L2 cache for the two output tiles in this second row. Therefore, we loaded 16 tiles a single time from global memory and computed 4 output tiles. By contrast, in row-major order the first 4 output tiles needed 20 tile loads (one row of A plus all four columns of B), and the second row of output tiles then had to reload columns of B that had already been evicted.
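
The tile counts in this thought experiment can be checked with a toy simulation. The sketch below models the L2 as a least-recently-used cache that holds a fixed number of tiles and counts how many tiles are fetched from global memory for a given order of output tiles. It is only a model: real L2 replacement policies and concurrent block execution are more complicated, and `count_tile_loads` is a hypothetical name.

```cpp
#include <algorithm>
#include <cassert>
#include <list>
#include <utility>
#include <vector>

// Counts tile fetches from global memory when output tiles of C are computed
// in the given order, with a toy LRU "L2" that holds `capacity` tiles.
// Tiles get unique ids: A(r, k) -> r*T + k, B(k, c) -> T*T + k*T + c,
// where T is the number of tiles per matrix dimension.
int count_tile_loads(const std::vector<std::pair<int, int>>& order,
                     int tiles_per_dim, int capacity) {
    assert(tiles_per_dim > 0 && capacity > 0);
    std::list<int> lru; // front = most recently used
    int loads = 0;
    auto touch = [&](int tile) {
        auto it = std::find(lru.begin(), lru.end(), tile);
        if (it != lru.end()) {
            lru.erase(it); // hit: recency is refreshed below
        } else {
            ++loads;       // miss: fetch from global memory
            if (static_cast<int>(lru.size()) == capacity) lru.pop_back();
        }
        lru.push_front(tile);
    };
    const int T = tiles_per_dim;
    for (const auto& rc : order) {
        for (int k = 0; k < T; ++k) {
            touch(rc.first * T + k);          // tile A(row, k)
            touch(T * T + k * T + rc.second); // tile B(k, col)
        }
    }
    return loads;
}
```

With 4 x 4 tiles and a capacity of 16, the row-major order `{(0,0), (0,1), (0,2), (0,3)}` costs 20 tile loads, while the newspaper-column order `{(0,0), (0,1), (1,0), (1,1)}` costs 16, matching the counts in the text.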

One way to think of this is that this is very similar to the rationale for tiled matrix multiplication. We are just adding another layer of tiling to the traversal. This block execution order is called grid swizzling and will allow us to get the most possible out of the L2 cache.
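
One way to compute such a traversal is to remap the linear block index into panel coordinates. The following is a host-side sketch of that remapping, not necessarily the exact arithmetic used in the kernel: `BlockCoord` and `swizzle_block` are hypothetical names, and shrinking the last panel when the grid width is not a multiple of the panel width is a choice made for this sketch so that every output tile is visited exactly once.

```cpp
#include <algorithm>
#include <cassert>

struct BlockCoord { int row; int col; };

// Maps a linear block index to (row, col) so that blocks sweep down a panel
// that is `panel_width` tiles wide before moving on to the next panel.
BlockCoord swizzle_block(int idx_linear, int grid_m_blocks, int grid_n_blocks,
                         int panel_width) {
    assert(panel_width > 0);
    assert(idx_linear >= 0 && idx_linear < grid_m_blocks * grid_n_blocks);
    int blocks_per_panel = panel_width * grid_m_blocks;
    int panel = idx_linear / blocks_per_panel;
    int col_start = panel * panel_width;
    // The last panel may be narrower than panel_width.
    int width = std::min(panel_width, grid_n_blocks - col_start);
    int idx_in_panel = idx_linear % blocks_per_panel;
    return {idx_in_panel / width, col_start + idx_in_panel % width};
}
```

For a 4 x 4 grid with a panel width of 2, linear indices 0, 1, 2, 3 map to tiles (0,0), (0,1), (1,0), (1,1), which is the order used in the example above.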

There is another memory bottleneck in our previous kernels that has to do with shared memory. To understand this bottleneck, we have to discuss the physical constraints of shared memory. Shared Memory is not a monolithic block of RAM. It is divided into physical banks. For the A100, Shared Memory in each SM is divided into 32 banks, each 4 bytes wide. These banks are effectively parallel lanes that the GPU can read from. The catch is that if multiple threads in a warp request different addresses that live in the same bank, then those requests have to be serialized. If the requests are each for memory in a different bank, then we can fully parallelize them.

Memory addresses are mapped to shared memory banks sequentially. So for 32 banks, we will have bytes 0-3 in Bank 0, 4-7 in Bank 1, …, 124-127 in Bank 31. And then bytes 128-131 wrap around and are placed in Bank 0 again. What we have been doing is defining a 2D array of shared memory that is exactly the size we need, such as a 64 x 64 array of shared memory to hold a 64 x 64 tile of half-precision float data. Since a half is 2 bytes, one row of this array consumes 128 bytes of shared memory, which spans all 32 banks exactly once (two adjacent halves share each 4-byte bank word). Therefore, when we access a row from this array, the accesses are spread evenly across the banks, so the request is highly parallelizable. But when we access a column from this array, it’s disastrous: consecutive rows are exactly 128 bytes apart, so every element in a column will be in the same bank! The request must be completely serialized.

The solution to this is shared memory swizzling: basically storing data to shared memory in a pattern that minimizes bank conflicts. In the below implementation, I use padding to add some dummy elements at the end of each row. In the above example, if we pad each row with 8 unused halves (16 bytes, or 4 banks), then the start of the second row will be in Bank 4, the start of the third row will be in Bank 8, and so on. So we won’t run into extreme bank conflicts with column access. The disadvantage of this approach is that it does add a slight shared memory footprint, which can be an issue if we’re already using it heavily and near capacity. This isn’t the case for our kernel, but in production libraries the additional memory footprint is undesirable, so an approach called XOR swizzling is used instead. In XOR swizzling, the XOR operator is used (since it is computationally inexpensive) to permute the bank mapping of data based on its row and column. Modern libraries handle this swizzling for us, but since we are not using them as part of the problem constraints, I will stick with padding based swizzling for readability.
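
The bank arithmetic can be checked with a short host-side sketch. It assumes the layout described above (32 banks, 4 bytes each, 2-byte halves stored row-major) and a simplified access pattern of one half per thread. `bank_of_half` and `distinct_banks_in_column` are hypothetical helper names, and real WMMA loads access shared memory in a more involved pattern than a plain column read.

```cpp
#include <cassert>
#include <set>

// Bank holding element (row, col) of a row-major array of 2-byte halves,
// assuming 32 banks that are each 4 bytes wide.
int bank_of_half(int row, int col, int row_stride_halves) {
    assert(row >= 0 && col >= 0 && row_stride_halves > 0);
    int byte_offset = (row * row_stride_halves + col) * 2;
    return (byte_offset / 4) % 32;
}

// Number of distinct banks hit when 32 threads each read one half from the
// same column in 32 consecutive rows.
int distinct_banks_in_column(int col, int row_stride_halves) {
    std::set<int> banks;
    for (int row = 0; row < 32; ++row) {
        banks.insert(bank_of_half(row, col, row_stride_halves));
    }
    return static_cast<int>(banks.size());
}
```

With an unpadded stride of 64 halves, all 32 rows of a column land in a single bank. With a padded stride of 72 halves, row 1 starts in Bank 4 and row 2 in Bank 8, and the column spreads over 8 distinct banks. Under this model the padding turns a 32-way conflict into a 4-way one, which reduces the conflict considerably without removing it.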

There is one more slight optimization included below. We double BLOCK_K to 32 and double our number of fragments for A and B. Doubling BLOCK_K means we load twice the data at the start of our K-loop and then have a new loop wrapping our warp math that executes exactly twice. The benefit is that we’re loading more data at once and have fewer total iterations in our K-loop as it increments by 32 rather than 16, so we have to issue the syncthreads and pipeline wait commands fewer times. Doubling the number of fragments means that in our warp math loop, when a warp computes WMMA 4 times for its 2 x 2 grid of subtiles, we can load the necessary data into fragments all at once (2 fragments of A and 2 of B) and then perform the 4 MMA operations. Previously, we reloaded an A fragment and a B fragment for each of the 4 subtiles, for 8 fragment loads instead of 4.

Annotated Code

#include <cuda_runtime.h>
#include <cuda_fp16.h>
#include <mma.h>
#include <cuda_pipeline_primitives.h>

using namespace nvcuda;

// ------------- CONFIGURATION -------------
constexpr int BLOCK_M = 64;
constexpr int BLOCK_N = 64; // One block computes a 64 x 64 tile of the output matrix
constexpr int BLOCK_K = 32; // Accumulation step will be in terms of 16 but we load 32 at once to hide latency
constexpr int WARP_SIZE = 32;
constexpr int THREAD_COUNT = 128;
constexpr int WMMA = 16;

// Pad the row stride to avoid bank conflicts in shared memory.
constexpr int SMEM_PAD = 8;

__global__ void gemm_swizzled_kernel(const half* A, const half* B, half* C, int M, int N, int K, float alpha, float beta) {

    // ------------- GRID SWIZZLING (L2 Cache Optimization) -------------
    // Remap the linear block index to a "Swizzled" 2D grid.

    // Usually 2, 4, or 8
    const int swizzle_factor = 4;

    // Calculate linear block ID and grid dimensions
    int idx_linear = blockIdx.y * gridDim.x + blockIdx.x;
    int grid_m_blocks = gridDim.y;
    int grid_n_blocks = gridDim.x;

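    // Swizzle logic: Map linear ID to (block_row, block_col) in a localized pattern.
    // This traverses the grid in 'thick columns' of width 'swizzle_factor'.
    int blocks_per_panel = swizzle_factor * grid_m_blocks;
    int panel_number = idx_linear / blocks_per_panel;
    int panel_col_start = panel_number * swizzle_factor;

    // The last panel is narrower when grid_n_blocks is not a multiple of swizzle_factor.
    // Using its true width keeps the mapping one-to-one, so no output tile is skipped
    // (an early return on out-of-range columns would leave some tiles uncomputed).
    int panel_width = min(swizzle_factor, grid_n_blocks - panel_col_start);
    int idx_in_panel = idx_linear % blocks_per_panel;
    int block_row = idx_in_panel / panel_width;
    int block_col = panel_col_start + idx_in_panel % panel_width;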

    // Calculate offsets based on swizzled coordinates
    int block_row_start = block_row * BLOCK_M;
    int block_col_start = block_col * BLOCK_N;
    // --------------------------------------------------------------------

    
    // ------------- INDEX CALCULATIONS -------------
    // Linear view for data loading: which worker out of 128 threads am I?
    int tid = threadIdx.x;

    // As we have 128 threads per block, we have 4 warps per block, which we arrange in a 2x2 grid.
    // As each block computes a 64 x 64 output tile, we need to assign each warp a 32 x 32 output tile.
    int warp_id = tid / WARP_SIZE;
    int warp_row = (warp_id / 2) * 32;
    int warp_col = (warp_id % 2) * 32;
    // ----------------------------------------------


    // ------------- MEMORY INITIALIZATION ----------
    // Double Buffer: Shared Memory. Padded to remove bank conflicts 
    __shared__ half sA[2][BLOCK_M * (BLOCK_K + SMEM_PAD)]; // 64 rows, 40 cols (K + pad)
    __shared__ half sB[2][BLOCK_K * (BLOCK_N + SMEM_PAD)]; // 32 rows (K), 72 cols (N + pad)

    // Declare fragments and initialize accumulator
    wmma::fragment<wmma::matrix_a, WMMA, WMMA, WMMA, half, wmma::row_major> a_frag[2];
    wmma::fragment<wmma::matrix_b, WMMA, WMMA, WMMA, half, wmma::row_major> b_frag[2];
    wmma::fragment<wmma::accumulator, WMMA, WMMA, WMMA, float> accum_frag[2][2];

    #pragma unroll
    for (int i = 0; i < 2; i++) {
        #pragma unroll
        for (int j = 0; j < 2; j++) {
            wmma::fill_fragment(accum_frag[i][j], 0.0f);
        }
    }

    // Pipeline setup
    int stage = 0; // Alternates between 0 and 1
    // ----------------------------------------------


    // ------------- PROLOGUE -------------
    // Load the first tile (k=0). A: 64x32. B: 32x64.
    // We have 128 threads. We need to load 64*32 = 2048 halves per matrix.
    // So each thread must load 16 halves (two int4 vectors of 8 halves each) from each matrix.

    const half* src_A_base = A + block_row_start * K;
    const half* src_B_base = B + block_col_start;

    auto load_tile_async = [&](int stage_idx, int k_step) {
        const half* A_ptr = src_A_base + k_step; // Adding row * K is handled in loop
        const half* B_ptr = src_B_base + k_step * N;

        half* sA_ptr = sA[stage_idx];
        half* sB_ptr = sB[stage_idx];

        #pragma unroll
        for (int i = 0; i < 2; i++) {
            // Calculate which vector of 8 halves this thread is moving
            int tid_offset = tid + i * THREAD_COUNT; // 0..127, then 128..255

            // Map linear ID to (row, col) for A (64x32)
            // Width is 32 (4 vectors of 8 halves).
            int vec_row_a = tid_offset / 4;
            int vec_col_a = (tid_offset % 4) * 8;

            if (vec_row_a < BLOCK_M) {
                // Async Copy
                __pipeline_memcpy_async(
                    &sA_ptr[vec_row_a * (BLOCK_K + SMEM_PAD) + vec_col_a], // Swizzled shared ptr
                    &A_ptr[vec_row_a * K + vec_col_a],                     // Global ptr
                    sizeof(int4)
                );
            }

            // Map linear ID to (row, col) for B (32x64)
            // Width is 64 (8 vectors of 8 halves).
            int vec_row_b = tid_offset / 8;
            int vec_col_b = (tid_offset % 8) * 8;

            if (vec_row_b < BLOCK_K) {
                // Async Copy
                __pipeline_memcpy_async(
                    &sB_ptr[vec_row_b * (BLOCK_N + SMEM_PAD) + vec_col_b], // Swizzled shared ptr
                    &B_ptr[vec_row_b * N + vec_col_b],                     // Global ptr
                    sizeof(int4)
                );
            }
        }
    };
    
    load_tile_async(stage, 0);
    __pipeline_commit();
    __pipeline_wait_prior(0);
    __syncthreads();
    // ------------------------------------

    // ------------- MAIN LOOP -------------
    for (int k = 0; k < K; k += BLOCK_K) {

        int k_next = k + BLOCK_K;

        // 1. LOAD the next tile asynchronously
        if (k_next < K) {
            // Turns 1 into 0 or 0 into 1
            int next_stage = 1 - stage;
            load_tile_async(next_stage, k_next);
            __pipeline_commit();
        }

        // 2. MATH: process the current tile. Recall we have a 2 x 2 grid of 16 x 16 subtiles for each warp.
        // BLOCK_K = 32, and WMMA accumulates 16x16x16 at a time, so we need to loop k_step 0..1.
        #pragma unroll
        for (int k_step = 0; k_step < BLOCK_K; k_step += WMMA) {
            
            // --- STEP A: Load Fragments into Registers (Pre-Load) ---
            // A Warp computes a 32 x 32 output tile.
            // This requires 32 rows of A (2 fragments) and 32 cols of B (2 fragments).
            
            // Load the 2 fragments of Matrix A needed for this warp
            #pragma unroll
            for (int i = 0; i < 2; i++) {
                int smem_row = warp_row + (i * 16);
                half* tile_ptr_A = &sA[stage][smem_row * (BLOCK_K + SMEM_PAD) + k_step];
                
                // Load into specific index [i]
                wmma::load_matrix_sync(a_frag[i], tile_ptr_A, BLOCK_K + SMEM_PAD);
            }

            // Load the 2 fragments of Matrix B needed for this warp
            #pragma unroll
            for (int j = 0; j < 2; j++) {
                int smem_col = warp_col + (j * 16);
                half* tile_ptr_B = &sB[stage][k_step * (BLOCK_N + SMEM_PAD) + smem_col];
                
                // Load into specific index [j]
                wmma::load_matrix_sync(b_frag[j], tile_ptr_B, BLOCK_N + SMEM_PAD);
            }

            // --- STEP B: Compute (Reuse Registers) ---
            #pragma unroll
            for (int i = 0; i < 2; i++) {
                #pragma unroll
                for (int j = 0; j < 2; j++) {
                    // Reuse a_frag[i] and b_frag[j] multiple times
                    wmma::mma_sync(accum_frag[i][j], a_frag[i], b_frag[j], accum_frag[i][j]);
                }
            }
        }
       

        // 3. WAIT for next tile
        if (k + BLOCK_K < K) {
            __pipeline_wait_prior(0);
            __syncthreads();
            stage = 1 - stage;
        }
    }
    // ------------------------------------

    __syncthreads(); // Since the syncthreads above won't execute on the last iteration
   
    // ------- EPILOGUE: Store C ----------
    // We need a Shared Memory buffer for the floats from the Accumulators.
    __shared__ float sC[BLOCK_M * BLOCK_N];

    // 1. Store Accumulators (Registers) -> Shared Memory (Float)
    // Each warp holds a 32x32 tile distributed across 2x2 fragments (16x16 each).
    #pragma unroll
    for (int i = 0; i < 2; i++) {
        #pragma unroll
        for (int j = 0; j < 2; j++) {
            // Calculate where this 16x16 fragment belongs in the 64x64 block
            int row_offset = warp_row + (i * 16);
            int col_offset = warp_col + (j * 16);
            
            float* smem_ptr = sC + row_offset * BLOCK_N + col_offset;

            // Store fragment to shared memory (Stride is BLOCK_N)
            wmma::store_matrix_sync(smem_ptr, accum_frag[i][j], BLOCK_N, wmma::mem_row_major);
        }
    }

    // Wait for all warps to finish writing to sC
    __syncthreads();

    // 2. Write Shared Memory (Float) -> Global Memory (Half)
    // Total Elements: 64 * 64 = 4096.
    // Threads: 128.
    // Elements per thread: 32.
    // Vectors per thread: 32 / 8 = 4 vectors (int4).

    #pragma unroll
    for (int v = 0; v < 4; v++) {
        // Calculate the linear index for this vector of 8 elements
        // Stride by THREAD_COUNT to ensure coalescing (Thread 0 takes 0..7, Thread 1 takes 8..15)
        int vec_idx = tid + v * THREAD_COUNT; 
        
        int base_idx = vec_idx * 8; // The starting element index
        int row = base_idx / BLOCK_N;
        int col = base_idx % BLOCK_N;

        int global_row = block_row_start + row;
        int global_col = block_col_start + col;

        // Boundary Check (Safe for arbitrary M/N)
        // We check if the whole vector of 8 fits
        if (global_row < M && global_col + 7 < N) {
            
            __align__(16) half out_buffer[8]; // Register buffer for formatting (16-byte aligned for the int4 store)

            // OPTIONAL: Beta Handling (Load old C)
            // If beta is non-zero, we must load the existing values from Global Memory first
            __align__(16) half old_c[8]; // 16-byte aligned for the int4 load below
            bool use_beta = (beta != 0.0f);

            if (use_beta) {
                 // Vectorized Load of old C
                *(int4*)old_c = *(int4*)&C[global_row * N + global_col];
            }

            // Compute scaling and conversion
            #pragma unroll
            for (int x = 0; x < 8; x++) {
                // Read float from Shared
                float val = sC[base_idx + x]; 
                
                // Apply Alpha
                val *= alpha;

                // Apply Beta
                if (use_beta) {
                    val += beta * __half2float(old_c[x]);
                }

                // Convert to Half
                out_buffer[x] = __float2half(val);
            }

            // Vectorized Store to Global Memory
            *(int4*)&C[global_row * N + global_col] = *(int4*)out_buffer;

        } else if (global_row < M) {
            // Edge Case: Partial vector write (at the edge of the matrix)
            for (int x = 0; x < 8; x++) {
                if (global_col + x < N) {
                    float val = alpha * sC[base_idx + x];
                    if (beta != 0.0f) {
                        val += beta * __half2float(C[global_row * N + global_col + x]);
                    }
                    C[global_row * N + global_col + x] = __float2half(val);
                }
            }
        }
    } 
}


// Same as before
__global__ void gemm_tiled_kernel(const half* A, const half* B, half* C, int M, int N, int K, float alpha, float beta) { ... }

extern "C" void solve(const half* A, const half* B, half* C, int M, int N, int K, float alpha, float beta) {
    if (M % 64 == 0 && N % 64 == 0 && K % 32 == 0) {
        dim3 blockDim(THREAD_COUNT);
        dim3 gridDim(N / BLOCK_N, M / BLOCK_M);
        gemm_swizzled_kernel<<<gridDim, blockDim>>>(A, B, C, M, N, K, alpha, beta);
    } else {
        dim3 blockDim(TILE_WIDTH, TILE_WIDTH);
        dim3 gridDim(
            (N + TILE_WIDTH - 1) / TILE_WIDTH,
            (M + TILE_WIDTH - 1) / TILE_WIDTH
        );

        gemm_tiled_kernel<<<gridDim, blockDim>>>(A, B, C, M, N, K, alpha, beta);
    }
}
  1. Newspaper Panels: Our “newspaper panels” are 4 blocks wide. This is a standard balanced choice: if the panel is too narrow, we are effectively traversing column-major, and if it is too wide, we might as well traverse row-major.
  2. Grid Swizzling: Computing the panel index first tells us which column the left edge of the current panel starts at. By the end of these few lines, we have the block’s new row and column index in the output matrix, which walks down each newspaper column before moving to the next one instead of going row-major.
  3. Shared Memory Swizzling: This is where we swizzle the shared memory, by adding padding to our shared memory declarations.
  4. Doubled Fragments: We double the number of fragments so that all the data the warp needs for its math loop can be loaded into separate fragments up front, and all the math can then be done in one go.
  5. Async Loading Lambda: We moved the async loading into a lambda for readability.
  6. Doubled Data Loading: Since we now load twice the data per K-loop iteration, we need a new k-step loop that performs the warp math twice.
  7. Warp Math: We load the data into our 4 fragments all at once beforehand, and then simply loop 4 times calling mma_sync to perform the math.
  8. Dimension Check: Since we doubled BLOCK_K, we need to change this dimension check too. It may seem disappointing that we are now handling even fewer matrices with our optimized kernel. Don’t worry, we’ll fix this in the next kernel!
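
The panel traversal from notes 1 and 2 can be sanity-checked on the host. Below is a minimal sketch in plain C++, not the kernel itself: `TileCoord`, `swizzle_block`, and `covers_every_tile_once` are hypothetical names, and narrowing the last panel when the grid width is not a multiple of the panel width is an assumption of this sketch, made so that every tile is assigned to exactly one block.

```cpp
#include <cassert>
#include <vector>

// Hypothetical helper type: a tile position in units of blocks.
struct TileCoord {
    int row;
    int col;
};

// Map a linear block ID to a (row, col) tile, walking each panel of
// `panel_width` block-columns top to bottom before moving to the next panel.
// Assumption: the last panel is narrowed when grid_n % panel_width != 0.
inline TileCoord swizzle_block(int idx_linear, int grid_m, int grid_n, int panel_width) {
    assert(idx_linear >= 0 && idx_linear < grid_m * grid_n);
    int blocks_per_panel = panel_width * grid_m;
    int panel = idx_linear / blocks_per_panel;
    int panel_col_start = panel * panel_width;
    int remaining_cols = grid_n - panel_col_start;
    int width = remaining_cols < panel_width ? remaining_cols : panel_width;
    int idx_in_panel = idx_linear % blocks_per_panel;
    return TileCoord{idx_in_panel / width, panel_col_start + idx_in_panel % width};
}

// A correct remapping must land on every tile of the grid exactly once.
inline bool covers_every_tile_once(int grid_m, int grid_n, int panel_width) {
    std::vector<int> hits(grid_m * grid_n, 0);
    for (int idx = 0; idx < grid_m * grid_n; ++idx) {
        TileCoord t = swizzle_block(idx, grid_m, grid_n, panel_width);
        if (t.row < 0 || t.row >= grid_m || t.col < 0 || t.col >= grid_n) return false;
        ++hits[t.row * grid_n + t.col];
    }
    for (int h : hits) {
        if (h != 1) return false;
    }
    return true;
}
```

With an 8 x 8 grid and a panel width of 4, linear IDs 0 to 3 cover the first row of the first panel, ID 4 starts its second row, and ID 32 is the first block of the second panel.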

Arithmetic Intensity

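The swizzling didn’t change our FLOP count or the volume of global memory traffic. We did double BLOCK_K, but that doubles the FLOPs per K-loop iteration and the bytes loaded from global memory per iteration by the same factor. So the arithmetic intensity is unchanged from our prior kernel: we’re still sitting at 32 FLOPs/B. Still, we see a considerable speedup, because these optimizations change how efficiently the traffic is served (better L2 cache reuse from the grid swizzle, fewer shared memory bank conflicts from the padding) rather than how much of it there is.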
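
To see why BLOCK_K cancels, here is a sketch of the arithmetic intensity of one K-loop iteration for one block. It assumes 2 FLOPs per multiply-accumulate and counts only the half-precision (2-byte) loads of the A and B tiles, ignoring the one-time traffic for C:

\[ \text{AI} = \frac{2 \cdot \text{BLOCK\_M} \cdot \text{BLOCK\_N} \cdot \text{BLOCK\_K}}{2 \cdot (\text{BLOCK\_M} \cdot \text{BLOCK\_K} + \text{BLOCK\_K} \cdot \text{BLOCK\_N})} = \frac{\text{BLOCK\_M} \cdot \text{BLOCK\_N}}{\text{BLOCK\_M} + \text{BLOCK\_N}} = \frac{64 \cdot 64}{64 + 64} = 32 \text{ FLOPs/B} \]

BLOCK_K appears in both the numerator and the denominator, so doubling it leaves the ratio unchanged.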

Benchmarks

| GPU Model | Memory Bandwidth | Peak FP16 Compute | Ridge Point (FLOP/Byte) | Runtime (ms) |
|---|---|---|---|---|
| NVIDIA T4 | 320 GB/s | 65 TFLOPS | 203 | 0.49 |
| NVIDIA A100 (80GB) | 2,039 GB/s | 312 TFLOPS | 153 | 0.07 |
| NVIDIA H100 (SXM) | 3,350 GB/s | 989 TFLOPS | 295 | 0.04 |
| NVIDIA H200 (SXM) | 4,800 GB/s | 989 TFLOPS | 206 | 0.04 |
| NVIDIA B200 | 8,000 GB/s | 2,500 TFLOPS | 312 | 0.03 |
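
The ridge point column is simply peak compute divided by memory bandwidth. A small host-side sketch of that arithmetic follows; `ridge_point` and `is_memory_bound` are hypothetical helper names, and the classification uses the simple roofline model, which is only a first-order approximation of real hardware.

```cpp
#include <cassert>

// Ridge point of the roofline model: peak compute over peak memory bandwidth.
// Inputs are in TFLOP/s and GB/s, so the result is in FLOP/byte.
inline double ridge_point(double peak_tflops, double bandwidth_gbs) {
    assert(bandwidth_gbs > 0.0);
    return (peak_tflops * 1e12) / (bandwidth_gbs * 1e9);
}

// In the roofline model, a kernel whose arithmetic intensity is below the
// ridge point is limited by memory bandwidth rather than by compute.
inline bool is_memory_bound(double intensity_flop_per_byte,
                            double peak_tflops, double bandwidth_gbs) {
    return intensity_flop_per_byte < ridge_point(peak_tflops, bandwidth_gbs);
}
```

For the T4 row, `ridge_point(65.0, 320.0)` gives about 203 FLOP/byte, so under this model a kernel at 32 FLOPs/B sits on the bandwidth-limited side of the roofline on every GPU in the table.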

6. Arbitrary Matrix Dimensions

It is unfortunate that we have made it this far without being able to fully remove our tiled GEMM kernel. This next kernel modifies the prior swizzled kernel to handle arbitrary dimensions in our input matrices. With boundary checks on every load and zero padding of shared memory, we can use WMMA 16x16x16 operations across the entire matrix, including the partial tiles at the edges. The padded zeroes are harmless: they add nothing to any dot product, so they don’t affect the final result.
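
The claim that zero padding is harmless can be checked on the host with a small reference implementation. This is a sketch in plain C++ with float data rather than the kernel’s half and WMMA types; `matmul`, `pad_with_zeros`, `round_up`, and `padding_is_harmless` are hypothetical names introduced only for this illustration.

```cpp
#include <cassert>
#include <vector>

// Reference product of row-major A (m x k) and B (k x n).
inline std::vector<float> matmul(const std::vector<float>& a, const std::vector<float>& b,
                                 int m, int n, int k) {
    std::vector<float> c(m * n, 0.0f);
    for (int i = 0; i < m; ++i)
        for (int j = 0; j < n; ++j)
            for (int p = 0; p < k; ++p)
                c[i * n + j] += a[i * k + p] * b[p * n + j];
    return c;
}

// Copy a rows x cols matrix into the top-left corner of a zero-filled matrix.
inline std::vector<float> pad_with_zeros(const std::vector<float>& src, int rows, int cols,
                                         int padded_rows, int padded_cols) {
    assert(padded_rows >= rows && padded_cols >= cols);
    std::vector<float> dst(padded_rows * padded_cols, 0.0f);
    for (int r = 0; r < rows; ++r)
        for (int c = 0; c < cols; ++c)
            dst[r * padded_cols + c] = src[r * cols + c];
    return dst;
}

inline int round_up(int x, int multiple) {
    return (x + multiple - 1) / multiple * multiple;
}

// True when the top-left m x n corner of the padded product equals the
// unpadded product. Small integer-valued inputs keep the float sums exact.
inline bool padding_is_harmless(int m, int n, int k, int tile) {
    std::vector<float> a(m * k), b(k * n);
    for (int i = 0; i < m * k; ++i) a[i] = static_cast<float>(i % 7) - 3.0f;
    for (int i = 0; i < k * n; ++i) b[i] = static_cast<float>(i % 5) - 2.0f;
    int pm = round_up(m, tile), pn = round_up(n, tile), pk = round_up(k, tile);
    std::vector<float> c = matmul(a, b, m, n, k);
    std::vector<float> cp = matmul(pad_with_zeros(a, m, k, pm, pk),
                                   pad_with_zeros(b, k, n, pk, pn), pm, pn, pk);
    for (int i = 0; i < m; ++i)
        for (int j = 0; j < n; ++j)
            if (c[i * n + j] != cp[i * pn + j]) return false;
    return true;
}
```

Calling `padding_is_harmless(17, 33, 20, 16)` pads the operands up to 32 x 32 and 32 x 48 and confirms that the 17 x 33 corner of the result is unchanged.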

Annotated Code

#include <cuda_runtime.h>
#include <cuda_fp16.h>
#include <mma.h>
#include <cuda_pipeline_primitives.h>

using namespace nvcuda;

// ------------- CONFIGURATION -------------
constexpr int BLOCK_M = 64;
constexpr int BLOCK_N = 64; // One block computes a 64 x 64 tile of the output matrix
constexpr int BLOCK_K = 32; // Accumulation step will be in terms of 16 but we load 32 at once to hide latency
constexpr int WARP_SIZE = 32;
constexpr int THREAD_COUNT = 128;
constexpr int WMMA = 16;

// Pad to avoid bank conflicts in shared memory.
constexpr int SMEM_PAD = 8;

__global__ void gemm_swizzled_all(const half* A, const half* B, half* C, int M, int N, int K, float alpha, float beta) {

    // ------------- GRID SWIZZLING (L2 Cache Optimization) -------------
    // Remap the linear block index to a "Swizzled" 2D grid.

    // Usually 2, 4, or 8
    const int swizzle_factor = 4;

    // Calculate linear block ID and grid dimensions
    int idx_linear = blockIdx.y * gridDim.x + blockIdx.x;
    int grid_m_blocks = gridDim.y;
    int grid_n_blocks = gridDim.x;

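    // Swizzle logic: Map linear ID to (block_row, block_col) in a localized pattern.
    // This traverses the grid in 'thick columns' (panels) of width 'swizzle_factor'.
    // The last panel is narrower when grid_n_blocks is not a multiple of swizzle_factor,
    // so its width is clamped; otherwise some tiles would never be assigned to a block.
    int blocks_per_panel = swizzle_factor * grid_m_blocks;
    int panel_number = idx_linear / blocks_per_panel;
    int panel_col_start = panel_number * swizzle_factor;
    int panel_width = min(swizzle_factor, grid_n_blocks - panel_col_start);
    int idx_in_panel = idx_linear % blocks_per_panel;
    int block_row = idx_in_panel / panel_width;
    int block_col = panel_col_start + idx_in_panel % panel_width;
    
    // Defensive guard: with the clamped panel width every block maps to a valid tile
    if (block_row >= grid_m_blocks || block_col >= grid_n_blocks) return;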

    // Calculate offsets based on swizzled coordinates
    int block_row_start = block_row * BLOCK_M;
    int block_col_start = block_col * BLOCK_N;
    // --------------------------------------------------------------------

    
    // ------------- INDEX CALCULATIONS -------------
    // Linear view for data loading: which worker out of 128 threads am I?
    int tid = threadIdx.x;

    // As we have 128 threads per block, we have 4 warps per block, which we arrange in a 2x2 grid.
    // As each block computes a 64 x 64 output tile, we need to assign each warp a 32 x 32 output tile.
    int warp_id = tid / WARP_SIZE;
    int warp_row = (warp_id / 2) * 32;
    int warp_col = (warp_id % 2) * 32;
    // ----------------------------------------------


    // ------------- MEMORY INITIALIZATION ----------
    // Double Buffer: Shared Memory. Padded to remove Bank Conflicts
    __shared__ half sA[2][BLOCK_M * (BLOCK_K + SMEM_PAD)]; // 64 rows, 40 cols (K + pad)
    __shared__ half sB[2][BLOCK_K * (BLOCK_N + SMEM_PAD)]; // 32 rows, 72 cols (N + pad)

    // Declare fragments and initialize accumulator
    wmma::fragment<wmma::matrix_a, WMMA, WMMA, WMMA, half, wmma::row_major> a_frag[2]; // x2 for K=32
    wmma::fragment<wmma::matrix_b, WMMA, WMMA, WMMA, half, wmma::row_major> b_frag[2];
    wmma::fragment<wmma::accumulator, WMMA, WMMA, WMMA, float> accum_frag[2][2];

    #pragma unroll
    for (int i = 0; i < 2; i++) {
        #pragma unroll
        for (int j = 0; j < 2; j++) {
            wmma::fill_fragment(accum_frag[i][j], 0.0f);
        }
    }

    // Pipeline setup
    int stage = 0; // Alternates between 0 and 1
    // ----------------------------------------------


    // ------------- PROLOGUE -------------
    // Load the first tile (k=0). A: 64x32. B: 32x64.
    // We have 128 threads. We need to load 64*32 = 2048 halves per matrix.
    // So each thread must load 16 halves from each matrix, i.e. two int4 vectors of 8 halves each.

    const half* src_A_base = A + block_row_start * K;
    const half* src_B_base = B + block_col_start;

    auto load_tile_async = [&](int stage_idx, int k_step) {
        const half* A_ptr = src_A_base + k_step; // Base pointer for this tile
        const half* B_ptr = src_B_base + k_step * N; 

        half* sA_ptr = sA[stage_idx];
        half* sB_ptr = sB[stage_idx];

        #pragma unroll
        for (int i = 0; i < 2; i++) {
            int tid_offset = tid + i * THREAD_COUNT;

            // --- LOAD MATRIX A (Row-Major: [M x K]) ---
            int vec_row_a = tid_offset / 4;        // Local Row (0..63)
            int vec_col_a = (tid_offset % 4) * 8;  // Local Col (0, 8, 16, 24)
            
            int global_row_a = block_row_start + vec_row_a;
            int global_col_a = k_step + vec_col_a;

            // Address of the shared memory destination
            half* dst_a = &sA_ptr[vec_row_a * (BLOCK_K + SMEM_PAD) + vec_col_a];

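            // 1. Check strict bounds and alignment (Is this whole vector inside the matrix?)
            // A 16-byte async copy needs a 16-byte aligned source, which every row
            // start only satisfies when K % 8 == 0; other shapes take the edge path.
            bool a_fully_valid = (K % 8 == 0) && (global_row_a < M) && (global_col_a + 7 < K);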

            if (a_fully_valid) {
                // Fast path: Async Copy
                 __pipeline_memcpy_async(dst_a, &A_ptr[vec_row_a * K + vec_col_a], sizeof(int4));
            } 
            else {
                // Slow / Edge path: Manual loading or Zeroing
                // We must ensure Shared Memory has 0s where the matrix has nothing
                #pragma unroll
                for(int v=0; v<8; v++) {
                    if (global_row_a < M && (global_col_a + v) < K) {
                        dst_a[v] = A_ptr[vec_row_a * K + (vec_col_a + v)];
                    } else {
                        // Pad with zeroes
                        dst_a[v] = __float2half(0.0f);
                    }
                }
            }

            // --- LOAD MATRIX B (Row-Major: [K x N]) ---
            int vec_row_b = tid_offset / 8;        // Local Row (0..31)
            int vec_col_b = (tid_offset % 8) * 8;  // Local Col (0..56)

            int global_row_b = k_step + vec_row_b;
            int global_col_b = block_col_start + vec_col_b;

            half* dst_b = &sB_ptr[vec_row_b * (BLOCK_N + SMEM_PAD) + vec_col_b];

            // As with A, the 16-byte async copy needs N % 8 == 0 for aligned row starts
            bool b_fully_valid = (N % 8 == 0) && (global_row_b < K) && (global_col_b + 7 < N);

            if (b_fully_valid) {
                 __pipeline_memcpy_async(dst_b, &B_ptr[vec_row_b * N + vec_col_b], sizeof(int4));
            } else {
                // Edge path
                #pragma unroll
                for(int v=0; v<8; v++) {
                    if (global_row_b < K && (global_col_b + v) < N) {
                        dst_b[v] = B_ptr[vec_row_b * N + (vec_col_b + v)];
                    } else {
                        dst_b[v] = __float2half(0.0f); // PAD WITH ZERO
                    }
                }
            }
        }
    };
    
    load_tile_async(stage, 0);
    __pipeline_commit();
    __pipeline_wait_prior(0);
    __syncthreads();
    // ------------------------------------

    // ------------- MAIN LOOP -------------
    for (int k = 0; k < K; k += BLOCK_K) {

        int k_next = k + BLOCK_K;

        // 1. LOAD the next tile asynchronously
        if (k_next < K) {
            // Turns 1 into 0 or 0 into 1
            int next_stage = 1 - stage;
            load_tile_async(next_stage, k_next);
            __pipeline_commit();
        }

        // 2. MATH: process the current tile. Recall we have a 2 x 2 grid of 16 x 16 subtiles for each warp.
        // BLOCK_K = 32, and WMMA accumulates 16x16x16 at a time, so this loop runs twice (k_step = 0, 16).
        #pragma unroll
        for (int k_step = 0; k_step < BLOCK_K; k_step += WMMA) {
            
            // --- STEP A: Load Fragments into Registers (Pre-Load) ---
            // A Warp computes a 32x32 output tile.
            // This requires 32 rows of A (2 fragments) and 32 cols of B (2 fragments).
            
            // Load the 2 fragments of Matrix A needed for this warp
            #pragma unroll
            for (int i = 0; i < 2; i++) {
                int smem_row = warp_row + (i * 16);
                half* tile_ptr_A = &sA[stage][smem_row * (BLOCK_K + SMEM_PAD) + k_step];
                
                // Load into specific index [i]
                wmma::load_matrix_sync(a_frag[i], tile_ptr_A, BLOCK_K + SMEM_PAD);
            }

            // Load the 2 fragments of Matrix B needed for this warp
            #pragma unroll
            for (int j = 0; j < 2; j++) {
                int smem_col = warp_col + (j * 16);
                half* tile_ptr_B = &sB[stage][k_step * (BLOCK_N + SMEM_PAD) + smem_col];
                
                // Load into specific index [j]
                wmma::load_matrix_sync(b_frag[j], tile_ptr_B, BLOCK_N + SMEM_PAD);
            }

            // --- STEP B: Compute (Reuse Registers) ---
            #pragma unroll
            for (int i = 0; i < 2; i++) {
                #pragma unroll
                for (int j = 0; j < 2; j++) {
                    // Reuse a_frag[i] and b_frag[j] multiple times
                    wmma::mma_sync(accum_frag[i][j], a_frag[i], b_frag[j], accum_frag[i][j]);
                }
            }
        }
       

        // 3. WAIT for next tile
        if (k + BLOCK_K < K) {
            __pipeline_wait_prior(0);
            __syncthreads();
            stage = 1 - stage;
        }
    }
    // ------------------------------------

    __syncthreads(); // Since the syncthreads above won't execute on the last iteration
   
    // ------- EPILOGUE: Store C ----------
    // We need a Shared Memory buffer for the floats from the Accumulators.
    __shared__ float sC[BLOCK_M * BLOCK_N];

    // 1. Store Accumulators (Registers) -> Shared Memory (Float)
    // Each warp holds a 32x32 tile distributed across 2x2 fragments (16x16 each).
    #pragma unroll
    for (int i = 0; i < 2; i++) {
        #pragma unroll
        for (int j = 0; j < 2; j++) {
            // Calculate where this 16x16 fragment belongs in the 64x64 block
            int row_offset = warp_row + (i * 16);
            int col_offset = warp_col + (j * 16);
            
            float* smem_ptr = sC + row_offset * BLOCK_N + col_offset;

            // Store fragment to shared memory (Stride is BLOCK_N)
            wmma::store_matrix_sync(smem_ptr, accum_frag[i][j], BLOCK_N, wmma::mem_row_major);
        }
    }

    // Wait for all warps to finish writing to sC
    __syncthreads();

    // 2. Write Shared Memory (Float) -> Global Memory (Half)
    // Total Elements: 64 * 64 = 4096.
    // Threads: 128.
    // Elements per thread: 32.
    // Vectors per thread: 32 / 8 = 4 vectors (int4).

    #pragma unroll
    for (int v = 0; v < 4; v++) {
        // Calculate the linear index for this vector of 8 elements
        // Stride by THREAD_COUNT to ensure coalescing (Thread 0 takes 0..7, Thread 1 takes 8..15)
        int vec_idx = tid + v * THREAD_COUNT; 
        
        int base_idx = vec_idx * 8; // The starting element index
        int row = base_idx / BLOCK_N;
        int col = base_idx % BLOCK_N;

        int global_row = block_row_start + row;
        int global_col = block_col_start + col;

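        // Boundary Check (Safe for arbitrary M/N)
        // We check that the whole vector of 8 fits and that rows of C are 16-byte
        // aligned (N % 8 == 0); otherwise the scalar branch below handles the elements
        if (N % 8 == 0 && global_row < M && global_col + 7 < N) {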
            
            // 16-byte aligned so the int4 casts below are valid
            __align__(16) half out_buffer[8]; // Register buffer for formatting

            // OPTIONAL: Beta Handling (Load old C)
            // If beta is non-zero, we must load the existing values from Global Memory first
            __align__(16) half old_c[8];
            bool use_beta = (beta != 0.0f);

            if (use_beta) {
                 // Vectorized Load of old C
                *(int4*)old_c = *(int4*)&C[global_row * N + global_col];
            }

            // Compute scaling and conversion
            #pragma unroll
            for (int x = 0; x < 8; x++) {
                // Read float from Shared
                float val = sC[base_idx + x]; 
                
                // Apply Alpha
                val *= alpha;

                // Apply Beta
                if (use_beta) {
                    val += beta * __half2float(old_c[x]);
                }

                // Convert to Half
                out_buffer[x] = __float2half(val);
            }

            // Vectorized Store to Global Memory
            *(int4*)&C[global_row * N + global_col] = *(int4*)out_buffer;

        } else if (global_row < M) {
            // Edge Case: Partial vector write (at the edge of the matrix)
            for (int x = 0; x < 8; x++) {
                if (global_col + x < N) {
                    float val = alpha * sC[base_idx + x];
                    if (beta != 0.0f) {
                        val += beta * __half2float(C[global_row * N + global_col + x]);
                    }
                    C[global_row * N + global_col + x] = __float2half(val);
                }
            }
        }
    } 
}


extern "C" void solve(const half* A, const half* B, half* C, int M, int N, int K, float alpha, float beta) {
    dim3 blockDim(THREAD_COUNT);
    dim3 gridDim((N + BLOCK_N - 1) / BLOCK_N, (M + BLOCK_M - 1) / BLOCK_M);
    gemm_swizzled_all<<<gridDim, blockDim>>>(A, B, C, M, N, K, alpha, beta);
}
  1. Valid flag: We check whether we are far enough within bounds to still do a vectorized load, or whether the load would run past the edge of the input matrix. If we are far enough within bounds, we issue our __pipeline_memcpy_async command as before.
  2. Zero padding: If we’re too close to the edge of the matrix, we loop element by element, loading from global memory where we’re still in bounds and padding with zeroes wherever we’re not.
  3. Boundary check for writing: We keep the same boundary checks as before in the epilogue, so that we never write zeroes or junk outside the bounds of C in global memory.
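
Beyond bounds, a 16-byte vector access also has an alignment requirement: `__pipeline_memcpy_async` with a size of 16 and `int4` loads and stores both need 16-byte aligned addresses. The host-side sketch below shows how the two conditions combine for a row-major half matrix. `can_use_vector_path` is a hypothetical helper, and the sketch assumes the base pointer itself is 16-byte aligned, as `cudaMalloc` allocations are.

```cpp
#include <cassert>
#include <cstddef>

constexpr int VEC_HALVES = 8;          // one int4 holds 8 half values
constexpr std::size_t HALF_BYTES = 2;  // size of one half in bytes
constexpr std::size_t VEC_BYTES = 16;  // size of one int4 in bytes

// Decide whether the 8 elements starting at (row, col) of a row-major
// rows x cols half matrix can be moved as a single 16-byte vector.
// Assumption: the matrix base pointer is 16-byte aligned.
inline bool can_use_vector_path(int row, int col, int rows, int cols) {
    bool in_bounds = row < rows && col + (VEC_HALVES - 1) < cols;
    std::size_t byte_offset =
        (static_cast<std::size_t>(row) * cols + col) * HALF_BYTES;
    bool aligned = byte_offset % VEC_BYTES == 0;
    return in_bounds && aligned;
}
```

With 100 columns, row 1 starts 200 bytes into the matrix, which is not a multiple of 16, so `can_use_vector_path(1, 0, 64, 100)` is false even though the vector is fully in bounds. When the column count is a multiple of 8, every row start is aligned and only the bounds check matters.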

Arithmetic Intensity

We’re not performing more FLOPs or global memory accesses than the previous kernel. However, we now avoid the tiled GEMM kernel entirely, so across all test cases our overall arithmetic intensity will be closer to the optimized kernel’s 32 FLOPs/B. We no longer funnel straggler test cases with awkward dimensions to the tiled kernel, which only reached 8 FLOPs/B.
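
One caveat for awkward dimensions is that edge tiles are still computed as full 64 x 64 tiles, with part of the work spent on padded zeroes. A rough host-side estimate of how much of the launched tile area corresponds to real output elements, using the same ceiling division as `solve`, might look like this; `ceil_div` and `tile_utilization` are hypothetical names for this sketch.

```cpp
#include <cassert>

constexpr int BLOCK_M = 64;
constexpr int BLOCK_N = 64;

// Ceiling division, as used for the launch grid in solve().
inline int ceil_div(int x, int y) {
    assert(y > 0);
    return (x + y - 1) / y;
}

// Fraction of the launched output-tile area that maps to real elements of C.
// A value of 1.0 means no padding; smaller values mean more work on zeroes.
inline double tile_utilization(int m, int n) {
    long long launched = 1LL * ceil_div(m, BLOCK_M) * BLOCK_M
                             * ceil_div(n, BLOCK_N) * BLOCK_N;
    return static_cast<double>(1LL * m * n) / static_cast<double>(launched);
}
```

A 4096 x 4096 output has a utilization of 1.0, while a 16 x 16 output fills only 1/16 of its single 64 x 64 tile, so the smallest test cases benefit least from the large block size.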

Benchmarks

I omit the benchmark table here because the runtimes on the LeetGPU test suite were the same as for the prior kernel, plus or minus some run-to-run variation. This matches my understanding that the reported runtime is for one particular test case, which was probably already compatible with the WMMA dimension checks in the previous kernel.

Final Performance Analysis

The final graph of kernel versus runtime on each GPU is below.

Figure 1: GPU Runtime by Kernel Optimization Step

LeetGPU has a leaderboard for each GPU for the GEMM problem, as well as a list of public solutions ordered by runtime. The leaderboard considers both private and public solutions (whether your solutions are public is a user preference; I left mine public because I benefited greatly from reading others’ solutions to understand their approaches). At the time of writing, on most of the GPUs, I am not in the top 3 on the leaderboard, but on all of them my solution is in the top 5. In particular, for the Blackwell B200, my solution sits at 1st place by a whopping 0.1 microsecond over the next best solution. Not bad!

“ShaderShinobi” is my pseudonym. I debated whether to omit the other leaderboard usernames for anonymity, but those usernames already look pretty pseudonymous. Also, I’m hoping that if either of those authors sees this post, they’ll contact me to nerd out about GPUs.

Further Optimizations

I can say with some confidence that the author of the next best solution had a generally superior kernel, though, because their solutions are public. In particular, they used Warp Group MMA, a capability introduced in the Hopper generation that is much more efficient than standard WMMA. The cleanest way to use Warp Group MMA is through an external library, which the problem constraints prohibit, so I considered it out of scope for this problem. Admirably, this author went ahead and called it directly with PTX code. I assumed this would be very messy, but their solution was still surprisingly nice to read. The architecture of Hopper and Blackwell GPUs is quite different from previous generations and more heavily optimized for GEMM operations. In a future post, I will explore Warp Group MMA, the CuTe library, and various optimizations only available on the current generation of GPUs.

References

2024. NVIDIA Technical Blog. NVIDIA. https://developer.nvidia.com/blog/programming-tensor-cores-cuda-9/.
Kirk, David B, and Wen-mei W Hwu. 2022. Programming Massively Parallel Processors: A Hands-on Approach. 4th ed. Morgan Kaufmann.
“LeetGPU: Competitive GPU Programming.” 2026. https://leetgpu.com.
Matthes, Alexander, Rene Widera, Erik Zenker, Benjamin Worpitz, Axel Huebl, and Michael Bussmann. 2017. “Tuning and Optimization for a Variety of Many-Core Architectures Without Changing a Single Line of Implementation Code Using the Alpaka Library.” In, 496–514. https://doi.org/10.1007/978-3-319-67630-2_36.
“Memory Hierarchy of GPUs.” 2025. Arc Compute. https://www.arccompute.io/arc-blog/gpu-101-memory-hierarchy.