BLAS Integration
Matrix multiplication dominates transformer inference runtime. For a 7-billion-parameter LLaMA model generating a single token, the forward pass executes hundreds of GEMM (General Matrix Multiply) calls on matrices with thousands of rows and columns. Hand-written triple loops are much slower than vendor-optimized BLAS libraries, which combine cache blocking, register tiling, SIMD, and multithreading. ZigLlama provides a uniform BlasInterface that selects the best available library at initialization time and falls back to a pure-Zig SIMD implementation when no external library is present.
1. What Is BLAS
BLAS (Basic Linear Algebra Subprograms) is a specification, not a single library: it defines a standard set of low-level routines for vector, matrix-vector, and matrix-matrix operations [1]. First published in 1979, BLAS remains the de facto interface through which high-performance numerical software accesses optimized linear algebra kernels.
BLAS
A collection of routines organized into three levels, each operating on progressively higher-dimensional objects, with the property that every conforming implementation produces the same result (within floating-point rounding) for the same input.
2. Level Hierarchy
BLAS operations are classified by the rank of the operands and the resulting computational complexity:
| Level | Operands | Canonical Operation | Complexity | Example Routine |
|---|---|---|---|---|
| 1 | vector -- vector | \( \mathbf{y} \leftarrow \alpha \mathbf{x} + \mathbf{y} \) | \( O(n) \) | saxpy |
| 2 | matrix -- vector | \( \mathbf{y} \leftarrow \alpha \mathbf{A}\mathbf{x} + \beta \mathbf{y} \) | \( O(n^2) \) | sgemv |
| 3 | matrix -- matrix | \( \mathbf{C} \leftarrow \alpha \mathbf{A}\mathbf{B} + \beta \mathbf{C} \) | \( O(n^3) \) | sgemm |
Arithmetic Intensity
| Level | Flops | Memory Accesses | Arithmetic Intensity |
|---|---|---|---|
| 1 | \( O(n) \) | \( O(n) \) | \( O(1) \) -- memory-bound |
| 2 | \( O(n^2) \) | \( O(n^2) \) | \( O(1) \) -- memory-bound |
| 3 | \( O(n^3) \) | \( O(n^2) \) | \( O(n) \) -- compute-bound |
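Arithmetic intensity is the ratio of floating-point operations to memory accesses. Level 1 and Level 2 routines perform only a constant number of operations per element loaded, so memory bandwidth limits them. Level 3 operations can instead be compute-bound: each element of an \( n \times n \) operand takes part in \( O(n) \) operations, so an implementation that keeps blocks of data in cache and registers long enough to exploit that reuse is limited by the CPU's floating-point units rather than by memory bandwidth. Exploiting this reuse is why GEMM implementations invest heavily in register tiling, cache blocking, and SIMD vectorization [2].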
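As a worked example, assuming square \( n \times n \) single-precision operands and counting only compulsory traffic (each matrix element moved once between memory and the CPU, and ignoring the \( O(n) \) vector traffic of sgemv):

\[
\begin{aligned}
\text{sgemv:}\quad & \frac{2n^2 \ \text{flops}}{4n^2 \ \text{bytes (read }\mathbf{A}\text{)}} \approx 0.5 \ \text{flop/byte} \\[4pt]
\text{sgemm:}\quad & \frac{2n^3 \ \text{flops}}{4 \times 4n^2 \ \text{bytes (read }\mathbf{A}, \mathbf{B}, \mathbf{C}\text{; write }\mathbf{C}\text{)}} = \frac{n}{8} \ \text{flop/byte}
\end{aligned}
\]

For \( n = 4096 \) this gives 512 flop/byte for sgemm, while sgemv stays near 0.5 flop/byte no matter how large \( n \) grows.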
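3. GEMM -- General Matrix Multiply
GEMM is the most performance-critical routine in transformer inference. It computes
\[ \mathbf{C} \leftarrow \alpha\, \text{op}(\mathbf{A})\, \text{op}(\mathbf{B}) + \beta\, \mathbf{C} \]
where \( \text{op}(\mathbf{X}) \) is either \( \mathbf{X} \), \( \mathbf{X}^{\!\top} \), or \( \mathbf{X}^H \) (the conjugate transpose, which equals \( \mathbf{X}^{\!\top} \) for real matrices).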
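Written element-wise, with \( \text{op}(\mathbf{A}) \) of size \( M \times K \), \( \text{op}(\mathbf{B}) \) of size \( K \times N \), and \( \mathbf{C} \) of size \( M \times N \), each output entry needs \( K \) multiply-add pairs. One GEMM therefore costs about \( 2MNK \) flops, ignoring the \( O(MN) \) work of applying \( \alpha \) and \( \beta \):

\[
C_{ij} \leftarrow \alpha \sum_{l=1}^{K} \text{op}(\mathbf{A})_{il}\, \text{op}(\mathbf{B})_{lj} + \beta\, C_{ij},
\qquad 1 \le i \le M,\ 1 \le j \le N,
\qquad \text{flops} \approx 2MNK.
\]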
3.1 Parameter Reference
| Parameter | Type | Description |
|---|---|---|
| layout | MatrixLayout (enum) | row_major or column_major |
| transa | BlasOperation (enum) | Operation applied to \( \mathbf{A} \): N (none), T (transpose), or C (conjugate transpose, identical to T for real f32 data) |
| transb | BlasOperation (enum) | Operation applied to \( \mathbf{B} \): N, T, or C |
| M | u32 | Rows of \( \text{op}(\mathbf{A}) \) and \( \mathbf{C} \) |
| N | u32 | Columns of \( \text{op}(\mathbf{B}) \) and \( \mathbf{C} \) |
| K | u32 | Columns of \( \text{op}(\mathbf{A}) \), rows of \( \text{op}(\mathbf{B}) \) |
| alpha | f32 | Scalar multiplier for \( \text{op}(\mathbf{A})\,\text{op}(\mathbf{B}) \) |
| A | []const f32 | Matrix data for \( \mathbf{A} \) |
| lda | u32 | Leading dimension of A |
| B | []const f32 | Matrix data for \( \mathbf{B} \) |
| ldb | u32 | Leading dimension of B |
| beta | f32 | Scalar multiplier for \( \mathbf{C} \) (0.0 = overwrite; C need not be initialized) |
| C | []f32 | Output matrix data, updated in place |
| ldc | u32 | Leading dimension of C |
Leading Dimension
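The leading dimension is the stride, in elements, between consecutive columns (column-major) or consecutive rows (row-major) in memory. For an \( M \times N \) matrix stored without padding, lda = N in row-major order and lda = M in column-major order. It can be larger than that minimum when rows (or columns) are padded for alignment, or when the operand is a submatrix of a larger array.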
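A minimal sketch of how a leading dimension is used for addressing. The helper names (rowMajorIndex, colMajorIndex, subMatrixAt) are illustrative and not part of ZigLlama; the point is that a submatrix view needs only a base offset plus the parent matrix's leading dimension, with no copying.

```zig
const std = @import("std");

/// Offset of element (i, j) in a row-major matrix whose rows start `ld` elements apart.
fn rowMajorIndex(i: usize, j: usize, ld: usize) usize {
    return i * ld + j;
}

/// Offset of element (i, j) in a column-major matrix whose columns start `ld` elements apart.
fn colMajorIndex(i: usize, j: usize, ld: usize) usize {
    return j * ld + i;
}

/// Reads element (i, j) of a submatrix whose top-left corner is (row0, col0)
/// inside a larger row-major matrix with leading dimension `ld`.
/// The view is just a base offset; the stride stays the parent's `ld`.
fn subMatrixAt(data: []const f32, ld: usize, row0: usize, col0: usize, i: usize, j: usize) f32 {
    const base = rowMajorIndex(row0, col0, ld);
    return data[base + rowMajorIndex(i, j, ld)];
}
```

For example, in a 4 × 6 row-major matrix (ld = 6), the block starting at row 1, column 2 begins at offset 8, and its element (1, 2) is the parent's element (2, 4) at offset 16. Passing `data[8..]` with ldb = 6 to a GEMM would describe that same block.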
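3.2 Transformer GEMM Usage
The main GEMMs in one transformer layer have the shapes below, where \( s \) is the number of tokens processed in the call, \( d \) the model dimension, \( d_h \) the per-head dimension (\( d \) divided by the number of heads), and \( d_{ff} \) the feed-forward hidden size. During single-token decoding with a KV cache, the query side has a single row, so these calls have \( M = 1 \): they are effectively matrix-vector products and therefore memory-bound (Section 2).
| Transformer Operation | M | N | K | Notes |
|---|---|---|---|---|
| Q/K/V projection | \( s \) | \( d_h \) | \( d \) | Per attention head |
| Attention scores | \( s \) | \( s \) | \( d_h \) | \( \mathbf{Q}\mathbf{K}^{\!\top} \), per head |
| Attention output | \( s \) | \( d_h \) | \( s \) | Weighted sum of values, per head |
| FF up-projection | \( s \) | \( d_{ff} \) | \( d \) | \( d_{ff} = 4d \) in the original Transformer; LLaMA's SwiGLU block has separate gate and up projections of this shape |
| FF down-projection | \( s \) | \( d \) | \( d_{ff} \) | |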
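To make the table concrete, the sketch below counts \( 2MNK \) flops per GEMM using LLaMA-7B's dimensions: \( d = 4096 \), 32 heads (so \( d_h = 128 \)), and a feed-forward hidden size of 11008 (the original Transformer used \( 4d \)). These constants are stated assumptions of this example, not values read by ZigLlama.

```zig
const std = @import("std");

/// FLOPs for one GEMM: M*N*K multiply-add pairs, two flops each.
fn gemmFlops(m: u64, n: u64, k: u64) u64 {
    return 2 * m * n * k;
}

// Assumed LLaMA-7B dimensions (illustrative constants, not read from a checkpoint).
const d_model: u64 = 4096;
const n_heads: u64 = 32;
const d_head: u64 = d_model / n_heads; // 128
const d_ff: u64 = 11008; // SwiGLU hidden size

/// Q, K and V projections of one layer for `s` tokens: three GEMMs per head.
fn qkvFlops(s: u64) u64 {
    return 3 * n_heads * gemmFlops(s, d_head, d_model);
}

/// Attention scores Q K^T of one layer for `s` tokens, all heads.
fn scoreFlops(s: u64) u64 {
    return n_heads * gemmFlops(s, s, d_head);
}

/// FF up-projection of one layer for `s` tokens (the gate projection costs the same).
fn ffUpFlops(s: u64) u64 {
    return gemmFlops(s, d_ff, d_model);
}
```

For a single decoded token, qkvFlops(1) is 100,663,296 flops per layer, the same as three full \( d \times d \) projections; the per-head split changes the GEMM shapes, not the total work.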
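4. Supported Libraries
ZigLlama recognizes the following BLAS implementations. Expected speedups are relative to the GenericBlas baseline and depend on hardware and matrix shapes. In the current BlasManager.init (Section 7.2), only the GenericBlas and OpenBLAS backends are constructed; selecting MKL, Accelerate, or ATLAS logs a warning and falls back to GenericBlas.
| Library | Platform | Expected Speedup (vs. GenericBlas) | Detection |
|---|---|---|---|
| GenericBlas | All | 1.0x (baseline) | Always available |
| OpenBLAS | Linux, Windows, macOS | ~4.0x | Dynamic library probe |
| Intel MKL | Linux, Windows | ~6.0x | Dynamic library probe |
| Apple Accelerate | macOS, iOS | ~5.5x | Compile-time OS check |
| ATLAS | Linux | ~3.5x | Dynamic library probe |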
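Detection follows a fixed preference order. The OS switch is resolved at compile time, while detectOpenBlas and detectMkl (not shown) probe for shared libraries at runtime. ATLAS is not probed here, so detection never selects it.

```zig
fn detectAvailableLibrary() BlasLibrary {
    // The target OS is known at compile time.
    const target_os = @import("builtin").target.os.tag;
    return switch (target_os) {
        .macos, .ios => .accelerate,
        .linux, .windows => blk: {
            // Runtime probes, in preference order.
            if (detectOpenBlas()) break :blk .openblas;
            if (detectMkl()) break :blk .mkl;
            break :blk .generic;
        },
        else => .generic,
    };
}
```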
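5. BlasInterface Vtable
BlasInterface is a runtime-polymorphic wrapper that lets ZigLlama swap BLAS backends without changing any calling code. It pairs a type-erased context pointer with a table of function pointers (the same pattern std.mem.Allocator uses), and every function receives the context as its first argument: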
```zig
pub const BlasInterface = struct {
    vtable: *const VTable,
    context: *anyopaque, // points at the concrete backend (e.g. GenericBlas)

    pub const VTable = struct {
        // Level 1 BLAS
        dot: *const fn (*anyopaque, u32, []const f32, []const f32) f32,
        axpy: *const fn (*anyopaque, u32, f32, []const f32, []f32) void,
        scal: *const fn (*anyopaque, u32, f32, []f32) void,
        nrm2: *const fn (*anyopaque, u32, []const f32) f32,
        // Level 2 BLAS: layout, trans, m, n, alpha, a, lda, x, beta, y
        gemv: *const fn (*anyopaque, MatrixLayout, BlasOperation,
                         u32, u32, f32, []const f32, u32,
                         []const f32, f32, []f32) void,
        // Level 3 BLAS: layout, transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc
        gemm: *const fn (*anyopaque, MatrixLayout, BlasOperation, BlasOperation,
                         u32, u32, u32, f32, []const f32, u32,
                         []const f32, u32, f32, []f32, u32) void,
        deinit: *const fn (*anyopaque) void,
    };
};
```
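A minimal, self-contained sketch of the same context-plus-vtable pattern, cut down to a single dot routine. DotInterface and CountingBackend are hypothetical names for illustration only; the backend counts its calls to show that per-backend state is reached through the type-erased context pointer.

```zig
const std = @import("std");

/// Hypothetical one-routine interface with the same layout as BlasInterface.
const DotInterface = struct {
    vtable: *const VTable,
    context: *anyopaque,

    const VTable = struct {
        dot: *const fn (*anyopaque, u32, []const f32, []const f32) f32,
    };

    /// Callers use this wrapper and never see the concrete backend type.
    fn dot(self: DotInterface, n: u32, x: []const f32, y: []const f32) f32 {
        return self.vtable.dot(self.context, n, x, y);
    }
};

/// Toy backend with some state, reached via the context pointer.
const CountingBackend = struct {
    calls: u32 = 0,

    fn dotImpl(ctx: *anyopaque, n: u32, x: []const f32, y: []const f32) f32 {
        // Recover the concrete type from the type-erased pointer.
        const self: *CountingBackend = @ptrCast(@alignCast(ctx));
        self.calls += 1;
        var sum: f32 = 0.0;
        for (x[0..n], y[0..n]) |a, b| sum += a * b;
        return sum;
    }

    const vtable = DotInterface.VTable{ .dot = &dotImpl };

    /// The backend must outlive the returned interface, since it stores `self`.
    fn interface(self: *CountingBackend) DotInterface {
        return .{ .vtable = &vtable, .context = self };
    }
};
```

Because the interface stores a pointer to the backend, the backend must stay alive (and must not move) for as long as the interface is used.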
5.1 Operation Summary
| Method (context omitted) | BLAS Name | Formula |
|---|---|---|
| dot(n, x, y) | sdot | \( \sum_{i=1}^{n} x_i y_i \) |
| axpy(n, alpha, x, y) | saxpy | \( \mathbf{y} \leftarrow \alpha\mathbf{x} + \mathbf{y} \) |
| scal(n, alpha, x) | sscal | \( \mathbf{x} \leftarrow \alpha\mathbf{x} \) |
| nrm2(n, x) | snrm2 | \( \lVert \mathbf{x} \rVert_2 \) |
| gemv(...) | sgemv | \( \mathbf{y} \leftarrow \alpha\, \text{op}(\mathbf{A})\mathbf{x} + \beta\mathbf{y} \) |
| gemm(...) | sgemm | \( \mathbf{C} \leftarrow \alpha\, \text{op}(\mathbf{A})\,\text{op}(\mathbf{B}) + \beta\mathbf{C} \) |
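The nrm2 formula hides a numerical detail: computing \( \sqrt{\sum x_i^2} \) directly overflows in f32 once any \( |x_i| \) exceeds about \( 1.8 \times 10^{19} \), because its square exceeds f32's maximum of about \( 3.4 \times 10^{38} \). The sketch below is not necessarily how ZigLlama implements nrm2; it uses the scaled sum-of-squares technique of the classic reference snrm2, which keeps a running maximum magnitude and accumulates squares of ratios no larger than 1.

```zig
const std = @import("std");

/// Euclidean norm via a running scale and scaled sum of squares, so large
/// elements are never squared directly.
/// Invariant: sum of squares seen so far == scale * scale * ssq.
fn nrm2(x: []const f32) f32 {
    var scale: f32 = 0.0;
    var ssq: f32 = 1.0;
    for (x) |v| {
        if (v == 0.0) continue;
        const a = @abs(v);
        if (scale < a) {
            // New largest magnitude: rescale the accumulated sum to it.
            ssq = 1.0 + ssq * (scale / a) * (scale / a);
            scale = a;
        } else {
            ssq += (a / scale) * (a / scale);
        }
    }
    return scale * @sqrt(ssq);
}
```

For example, nrm2 of {3e30, 4e30} returns approximately 5e30, where the naive formula would produce infinity.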
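6. GenericBlas -- Pure Zig SIMD Fallback
When no external library is available, GenericBlas provides a pure-Zig implementation. Its dot product uses Zig's portable @Vector SIMD types rather than target-specific intrinsics, while its GEMM is a straightforward triple loop (Section 6.2).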
6.1 Dot Product with SIMD
```zig
const std = @import("std");

fn genericDot(context: *anyopaque, n: u32, x: []const f32, y: []const f32) f32 {
    _ = context; // GenericBlas needs no per-call state
    std.debug.assert(x.len >= n and y.len >= n);
    var result: f32 = 0.0;
    const simd_width = 8; // AVX2: 8 x f32 per 256-bit register
    const simd_end = (n / simd_width) * simd_width;
    var i: u32 = 0;
    // Take the vector path only if the target suggests at least 8 f32 lanes.
    if (comptime std.simd.suggestVectorLength(f32)) |vec_len| {
        if (vec_len >= simd_width) {
            const Vec = @Vector(simd_width, f32);
            var sum_vec: Vec = @splat(0.0);
            while (i < simd_end) : (i += simd_width) {
                const x_vec: Vec = x[i..][0..simd_width].*;
                const y_vec: Vec = y[i..][0..simd_width].*;
                sum_vec += x_vec * y_vec;
            }
            // Horizontal reduction of the per-lane partial sums
            result = @reduce(.Add, sum_vec);
        }
    }
    // Scalar tail (or the whole vector when the SIMD path was skipped)
    while (i < n) : (i += 1) result += x[i] * y[i];
    return result;
}
```
Zig SIMD Model
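Zig's @Vector(N, T) is a portable SIMD vector type. When N lanes of T fit in one hardware register, operations on it compile to single SIMD instructions: on x86-64 with AVX2, @Vector(8, f32) occupies one 256-bit ymm register (the __m256 type in C intrinsics). Longer vectors are split across several registers, and targets without SIMD get scalar code. The compiler handles register allocation and spilling.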
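The same @Vector pattern covers routines that write results back, such as axpy. This is an illustrative sketch (simdAxpy is not a ZigLlama function); like genericDot it fixes the width at 8 lanes, and it stores each vector result back into the slice through an array pointer.

```zig
const std = @import("std");

/// y <- alpha * x + y over the first `n` elements, 8 lanes at a time.
fn simdAxpy(n: usize, alpha: f32, x: []const f32, y: []f32) void {
    std.debug.assert(x.len >= n and y.len >= n);
    const width = 8;
    const Vec = @Vector(width, f32);
    const alpha_vec: Vec = @splat(alpha);
    var i: usize = 0;
    while (i + width <= n) : (i += width) {
        const xv: Vec = x[i..][0..width].*;
        const yv: Vec = y[i..][0..width].*;
        // Vectors coerce back to fixed-size arrays, so the result stores in place.
        y[i..][0..width].* = alpha_vec * xv + yv;
    }
    // Scalar tail for the last n % width elements.
    while (i < n) : (i += 1) y[i] += alpha * x[i];
}
```

With n = 10, the first 8 elements go through one vector iteration and the last 2 through the scalar tail.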
6.2 Generic GEMM
The generic GEMM is a straightforward triple loop that supports row-major and column-major layouts, all transpose combinations, and alpha/beta scaling. It computes each output element's dot product first and then applies the BLAS update rule, so beta = 0 overwrites C instead of scaling it:
```zig
fn genericGemm(context: *anyopaque, layout: MatrixLayout,
               transa: BlasOperation, transb: BlasOperation,
               m: u32, n: u32, k: u32,
               alpha: f32, a: []const f32, lda: u32,
               b: []const f32, ldb: u32,
               beta: f32, c: []f32, ldc: u32) void {
    _ = context; // GenericBlas needs no per-call state
    for (0..m) |i| {
        for (0..n) |j| {
            // sum = row i of op(A) dotted with column j of op(B).
            // indexFor (not shown) maps element (row, col) of op(X) to a
            // storage offset, honoring both the layout and the transpose flag.
            var sum: f32 = 0.0;
            for (0..k) |l| {
                const a_idx = indexFor(layout, transa, i, l, lda);
                const b_idx = indexFor(layout, transb, l, j, ldb);
                sum += a[a_idx] * b[b_idx];
            }
            const c_idx = if (layout == .column_major) j * ldc + i
                          else i * ldc + j;
            // BLAS semantics: beta == 0 overwrites C, so NaN or uninitialized
            // values already in C must not propagate (NaN * 0 is still NaN).
            c[c_idx] = if (beta == 0.0) alpha * sum
                       else alpha * sum + beta * c[c_idx];
        }
    }
}
```
Performance Note
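The generic GEMM performs \( 2MNK \) flops with no cache blocking, register tiling, or multithreading, so for large matrices its operands are re-read from memory far more often than necessary. For production inference on matrices larger than about 256 × 256, an external BLAS library is strongly recommended; see [2] and [3] for how optimized implementations organize these loops.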
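To show what "cache blocking" means in code, here is a simplified sketch of loop tiling in the spirit of [2], for the row-major, no-transpose case only. It is not ZigLlama's code and omits the operand packing and register micro-kernels that real libraries add; the tile size of 64 is an illustrative assumption, since real libraries tune it per cache level.

```zig
const std = @import("std");

/// Illustrative tile edge (64 x 64 f32 = 16 KiB per tile).
const tile = 64;

/// Cache-blocked row-major GEMM without transposes:
/// C (m x n) <- alpha * A (m x k) * B (k x n) + beta * C.
fn blockedGemm(m: usize, n: usize, k: usize, alpha: f32,
               a: []const f32, lda: usize, b: []const f32, ldb: usize,
               beta: f32, c: []f32, ldc: usize) void {
    // Apply beta once up front; beta == 0 overwrites so stale NaNs cannot leak through.
    for (0..m) |i| {
        for (0..n) |j| {
            const idx = i * ldc + j;
            c[idx] = if (beta == 0.0) 0.0 else beta * c[idx];
        }
    }
    var i0: usize = 0;
    while (i0 < m) : (i0 += tile) {
        const i_end = @min(i0 + tile, m);
        var l0: usize = 0;
        while (l0 < k) : (l0 += tile) {
            const l_end = @min(l0 + tile, k);
            var j0: usize = 0;
            while (j0 < n) : (j0 += tile) {
                const j_end = @min(j0 + tile, n);
                // i-l-j order inside a tile: the innermost loop walks
                // a row of B and a row of C contiguously.
                for (i0..i_end) |i| {
                    for (l0..l_end) |l| {
                        const a_il = alpha * a[i * lda + l];
                        for (j0..j_end) |j| {
                            c[i * ldc + j] += a_il * b[l * ldb + j];
                        }
                    }
                }
            }
        }
    }
}
```

The tiles keep a block of B and C resident in cache while they are reused, and the i-l-j order turns genericGemm's strided walk through B into unit-stride access.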
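7. Detection and Selection
7.1 Platform Detection
Only the OS check inside detectAvailableLibrary is resolved at compile time; the library probes and the CPU count query run each time detect() is called.
```zig
pub fn detect() BlasConfig {
    const detected_library = detectAvailableLibrary();
    // Assume 4 CPUs if the count cannot be queried.
    const cpu_count = @as(u32, @intCast(std.Thread.getCpuCount() catch 4));
    return BlasConfig{
        .library = detected_library,
        // One thread per logical CPU minus one, never fewer than one
        // (a conditional, so a reported count of 0 cannot underflow).
        .num_threads = if (cpu_count > 1) cpu_count - 1 else 1,
        .use_threading = cpu_count > 1,
        .memory_alignment = 64, // bytes: one cache line on typical x86-64 CPUs
        .prefer_column_major = true,
    };
}
```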
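7.2 BlasManager Initialization
The BlasManager wraps library selection, high-level tensor operations, and performance statistics. Because a backend's interface() stores a pointer to the backend as the type-erased context, each backend is heap-allocated here; a backend held in a local variable would leave that pointer dangling once init returns:
```zig
pub fn init(allocator: std.mem.Allocator, config: BlasConfig) !BlasManager {
    // Assumes the backend's vtable deinit frees the allocation made here.
    const blas = switch (config.library) {
        .generic => blk: {
            const g = try allocator.create(GenericBlas);
            errdefer allocator.destroy(g);
            g.* = try GenericBlas.init(allocator, config);
            break :blk g.interface();
        },
        .openblas => blk: {
            const o = try allocator.create(OpenBlas);
            errdefer allocator.destroy(o);
            o.* = try OpenBlas.init(allocator, config);
            break :blk o.interface();
        },
        .mkl, .accelerate, .atlas => blk: {
            std.log.warn("{s} not yet implemented, falling back to generic",
                .{@tagName(config.library)});
            const g = try allocator.create(GenericBlas);
            errdefer allocator.destroy(g);
            g.* = try GenericBlas.init(allocator, config);
            break :blk g.interface();
        },
    };
    return BlasManager{
        .blas = blas,
        .config = config,
        .allocator = allocator,
        .stats = BlasStats.init(),
    };
}
```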
7.3 Selection Flowchart
```mermaid
flowchart TD
    START["BlasConfig.detect()"] --> OS{"Target OS?"}
    OS -->|macOS/iOS| ACC["Use Accelerate"]
    OS -->|Linux/Windows| LIB{"OpenBLAS found?"}
    LIB -->|Yes| OB["Use OpenBLAS"]
    LIB -->|No| MKL{"Intel MKL found?"}
    MKL -->|Yes| MK["Use MKL"]
    MKL -->|No| GEN["Use GenericBlas (Pure Zig)"]
    OS -->|Other| GEN
```

8. Performance Statistics
BlasStats tracks per-operation counts, cumulative time, and FLOP counts to help identify bottlenecks:
```zig
pub const BlasStats = struct {
    operation_counts: std.EnumMap(BlasOpType, u64),
    total_time_ns: std.EnumMap(BlasOpType, u64),
    total_flops: std.EnumMap(BlasOpType, u64),

    /// Achieved throughput for `op`. FLOPs per nanosecond is numerically
    /// equal to GFLOP/s, so no unit conversion is needed.
    pub fn getGflops(self: *const BlasStats, op: BlasOpType) f64 {
        const flops = self.total_flops.get(op) orelse 0;
        // 1 avoids dividing by zero; an unrecorded op then reports 0.
        const time_ns = self.total_time_ns.get(op) orelse 1;
        return @as(f64, @floatFromInt(flops)) / @as(f64, @floatFromInt(time_ns));
    }
};
```
Typical output:
```text
=== BLAS Performance Statistics ===
GEMM: 960 ops, 12.34 GFLOPS, 245.6 us avg
GEMV: 64 ops, 8.21 GFLOPS, 18.3 us avg
DOT: 128 ops, 6.50 GFLOPS, 2.1 us avg
===================================
```
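The section does not show how entries get recorded. A minimal sketch, assuming each call adds one to its count, its elapsed nanoseconds to the time total, and its FLOP count (\( 2MNK \) for GEMM) to the FLOP total; Stats and OpType are hypothetical stand-ins for BlasStats and BlasOpType. In practice elapsed_ns could come from std.time.Timer (start it before the call, read it after). The avgMicros helper shows how the "us avg" column can be derived.

```zig
const std = @import("std");

const OpType = enum { gemm, gemv, dot };

/// Hypothetical stand-in for BlasStats showing how calls could be recorded.
const Stats = struct {
    counts: std.EnumMap(OpType, u64) = .{},
    time_ns: std.EnumMap(OpType, u64) = .{},
    flops: std.EnumMap(OpType, u64) = .{},

    /// Adds one call that took `elapsed_ns` and performed `op_flops` flops.
    fn record(self: *Stats, op: OpType, elapsed_ns: u64, op_flops: u64) void {
        self.counts.put(op, (self.counts.get(op) orelse 0) + 1);
        self.time_ns.put(op, (self.time_ns.get(op) orelse 0) + elapsed_ns);
        self.flops.put(op, (self.flops.get(op) orelse 0) + op_flops);
    }

    /// Flops per nanosecond, numerically equal to GFLOP/s.
    fn gflops(self: *const Stats, op: OpType) f64 {
        const total = self.flops.get(op) orelse 0;
        const ns = self.time_ns.get(op) orelse 1; // unrecorded ops report 0
        return @as(f64, @floatFromInt(total)) / @as(f64, @floatFromInt(ns));
    }

    /// Mean wall time per call in microseconds.
    fn avgMicros(self: *const Stats, op: OpType) f64 {
        const calls = self.counts.get(op) orelse return 0.0;
        const ns = self.time_ns.get(op) orelse 0;
        return @as(f64, @floatFromInt(ns)) / @as(f64, @floatFromInt(calls)) / 1000.0;
    }
};
```

For example, recording two 64 × 64 × 64 GEMMs that took 1,000 ns and 3,000 ns gives 1,048,576 flops in 4,000 ns, i.e. about 262 GFLOP/s and 2.0 us per call.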
References
1. Lawson, C. et al. "Basic Linear Algebra Subprograms for Fortran Usage." ACM TOMS, 5(3), 1979.
2. Goto, K. and van de Geijn, R. "Anatomy of High-Performance Matrix Multiplication." ACM TOMS, 34(3), 2008.
3. Van Zee, F. G. and van de Geijn, R. "BLIS: A Framework for Rapidly Instantiating BLAS Functionality." ACM TOMS, 41(3), 2015.
4. Intel. "oneMKL Developer Reference." Intel Corporation, 2024.