Mistral
Overview
Mistral 7B, released by Mistral AI in September 2023 [1], applied two key efficiency techniques to the LLaMA-style architecture: Sliding Window Attention (SWA) and Grouped-Query Attention (GQA). Despite having only 7.3 billion parameters, Mistral 7B matched or exceeded LLaMA 2 13B across most benchmarks, demonstrating that architectural efficiency can substitute for raw parameter count.
ZigLLM implements the Mistral architecture in src/models/mistral.zig, including the sliding window mask, GQA head repetition, and the Mixtral 8x7B MoE variant configuration.
Key Innovations
Sliding Window Attention
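Standard causal attention allows each token to attend to all preceding tokens, resulting in \( O(n^2) \) complexity for sequence length \( n \). Sliding Window Attention restricts each token to attend only within a local window of size \( W \):
\[ \text{Attention}(Q, K, V) = \text{softmax}\left( \frac{Q K^\top}{\sqrt{d_k}} + M_\text{SWA} \right) V \]
where \( d_k \) is the per-head dimension and the mask \( M_\text{SWA}[i,j] \) is:
\[ M_\text{SWA}[i,j] = \begin{cases} 0 & \text{if } 0 \le i - j < W \\ -\infty & \text{otherwise} \end{cases} \]
The window counts the current token, so with \( W = 2 \) token \( i \) sees only tokens \( i - 1 \) and \( i \).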
graph LR
    subgraph "Standard Causal Attention"
        A1["Token 1"] --> A1
        A2["Token 2"] --> A1
        A2 --> A2
        A3["Token 3"] --> A1
        A3 --> A2
        A3 --> A3
        A4["Token 4"] --> A1
        A4 --> A2
        A4 --> A3
        A4 --> A4
    end

graph LR
    subgraph "Sliding Window (W=2)"
        B2["Token 2"] --> B1["Token 1"]
        B2 --> B2
        B3["Token 3"] --> B2
        B3 --> B3
        B4["Token 4"] --> B3
        B4 --> B4
    end

Complexity Reduction
- Standard causal: \( O(n^2 \cdot d) \) time, \( O(n^2) \) memory for attention scores
- Sliding window: \( O(n \cdot W \cdot d) \) time, \( O(n \cdot W) \) memory
For Mistral 7B with \( W = 4096 \) and typical inference sequences of \( n = 8192 \)--\( 32768 \), this provides a 2--8x reduction in attention memory.
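As a quick check of the 2--8x figure, under the same dense accounting used in the two bullets above (an \( n \times n \) score matrix versus an \( n \times W \) one):

\[ \frac{n^2}{n \cdot W} = \frac{n}{W}, \qquad \frac{8192}{4096} = 2, \qquad \frac{32768}{4096} = 8 \]

Sequences no longer than \( W \) see no reduction, because the window then covers the entire causal prefix.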
Information Flow Beyond the Window
Although each layer only attends within window \( W \), information can propagate across the full sequence through multiple layers. After \( L \) layers, a token can theoretically access information from \( L \times W \) positions back. For Mistral 7B: \( 32 \times 4096 = 131{,}072 \) positions.
Grouped-Query Attention (GQA)
Mistral 7B uses 32 query heads but only 8 key-value heads, giving a 4:1 sharing ratio. Each group of 4 query heads shares a single set of K and V projections.
| Attention Type | Query Heads | KV Heads | KV Memory | Used By |
|---|---|---|---|---|
| MHA | 32 | 32 | 1.0x | LLaMA, GPT-2 |
| GQA | 32 | 8 | 0.25x | Mistral, LLaMA 2 70B |
| MQA | 32 | 1 | 0.03x | Falcon-7B |
The KV cache memory saving is proportional to the ratio \( n_\text{kv\_heads} / n_\text{heads} \).
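To make the ratio concrete, the sketch below estimates KV cache size in bytes. The function name `kvCacheBytes` and its parameters are illustrative and not part of ZigLLM; it assumes one K and one V entry per token, per layer, per KV head, and no sliding-window truncation of the cache.

```zig
const std = @import("std");

/// Bytes needed to cache K and V for `seq_len` tokens across all layers.
/// `bytes_per_elem` is 2 for f16/bf16 and 4 for f32.
/// Hypothetical helper for illustration only.
pub fn kvCacheBytes(
    n_layers: u64,
    n_kv_heads: u64,
    head_dim: u64,
    seq_len: u64,
    bytes_per_elem: u64,
) u64 {
    // Factor 2: one tensor for keys, one for values.
    return 2 * n_layers * n_kv_heads * head_dim * seq_len * bytes_per_elem;
}
```

With the Mistral 7B shape (32 layers, 8 KV heads, head dimension 128) and 16-bit elements, 8192 cached tokens come to 1 GiB; the same shape with 32 KV heads (MHA) would need 4 GiB.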
Configuration
MistralConfig Struct
pub const MistralConfig = struct {
    d_model: usize,              // 4096
    n_heads: usize,              // 32 (query heads)
    n_kv_heads: usize,           // 8 (key-value heads for GQA)
    n_layers: usize,             // 32
    vocab_size: usize,           // 32000
    max_seq_len: usize,          // 32768
    intermediate_size: usize,    // 14336 (SwiGLU FFN)
    rope_theta: f32,             // 10000.0
    sliding_window: ?usize,      // 4096 (null = full attention)
    num_experts: ?usize,         // null for base, 8 for Mixtral
    num_experts_per_tok: ?usize, // null for base, 2 for Mixtral
};
Variant Configurations
| Parameter | Mistral 7B | Mixtral 8x7B |
|---|---|---|
| d_model | 4096 | 4096 |
| n_heads | 32 | 32 |
| n_kv_heads | 8 | 8 |
| n_layers | 32 | 32 |
| vocab_size | 32000 | 32000 |
| max_seq_len | 32768 | 32768 |
| intermediate_size | 14336 | 14336 |
| rope_theta | 10000.0 | 1000000.0 |
| sliding_window | 4096 | null (full) |
| num_experts | null | 8 |
| num_experts_per_tok | null | 2 |
| Total params | 7.3B | 46.7B (12.9B active) |
Mixtral Active Parameters
Mixtral 8x7B has 46.7B total parameters but only activates 2 of 8 experts per token, resulting in ~12.9B active parameters per forward pass -- comparable compute to a 13B dense model with the capacity of a 47B model.
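The total and active figures can be reproduced from the configuration values alone. The sketch below is a back-of-the-envelope count, not ZigLLM code: the `Config` struct, `countParams`, `totalParams`, and `activeParams` are hypothetical names, and it assumes bias-free projections, a bias-free router with one logit per expert, untied embedding and output matrices, and two RMSNorm weight vectors per block.

```zig
const std = @import("std");

// Hypothetical subset of MistralConfig: only the fields that affect size.
const Config = struct {
    d_model: u64,
    n_heads: u64,
    n_kv_heads: u64,
    n_layers: u64,
    vocab_size: u64,
    intermediate_size: u64,
    num_experts: ?u64 = null,
    num_experts_per_tok: ?u64 = null,
};

const mixtral_8x7b = Config{
    .d_model = 4096,
    .n_heads = 32,
    .n_kv_heads = 8,
    .n_layers = 32,
    .vocab_size = 32000,
    .intermediate_size = 14336,
    .num_experts = 8,
    .num_experts_per_tok = 2,
};

/// Counts weights with `experts_counted` expert FFNs per layer.
fn countParams(c: Config, experts_counted: u64) u64 {
    const head_dim = c.d_model / c.n_heads;
    const kv_dim = c.n_kv_heads * head_dim;
    // q_proj and o_proj are [d_model, d_model]; k_proj and v_proj are [d_model, kv_dim].
    const attn = 2 * c.d_model * c.d_model + 2 * c.d_model * kv_dim;
    // gate_proj, up_proj, and down_proj of one SwiGLU FFN.
    const ffn = 3 * c.d_model * c.intermediate_size;
    // Router (assumed bias-free); dense models have none.
    const router = if (c.num_experts) |e| c.d_model * e else 0;
    const norms = 2 * c.d_model; // two RMSNorm weight vectors per block
    const per_layer = attn + experts_counted * ffn + router + norms;
    // Untied input embedding and lm_head, plus the final RMSNorm.
    return c.n_layers * per_layer + 2 * c.vocab_size * c.d_model + c.d_model;
}

pub fn totalParams(c: Config) u64 {
    return countParams(c, c.num_experts orelse 1);
}

pub fn activeParams(c: Config) u64 {
    return countParams(c, c.num_experts_per_tok orelse 1);
}
```

Under these assumptions the Mixtral configuration gives about 46.7B total and about 12.9B active parameters, matching the figures quoted above. Attention, embeddings, and norms are shared by every token; only the expert FFN term changes between the two counts.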
Architecture Components
GroupedQueryAttention
The core attention module handles the asymmetric Q/KV head counts.
pub const GroupedQueryAttention = struct {
    d_model: usize,
    n_heads: usize,      // 32 query heads
    n_kv_heads: usize,   // 8 KV heads
    head_dim: usize,     // d_model / n_heads = 128
    q_proj: Tensor(f32), // [d_model, d_model]
    k_proj: Tensor(f32), // [d_model, n_kv_heads * head_dim]
    v_proj: Tensor(f32), // [d_model, n_kv_heads * head_dim]
    o_proj: Tensor(f32), // [d_model, d_model]
};
The key implementation detail is the KV head repetition step, where each KV head is broadcast to serve multiple query heads:
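fn repeatKVHeads(self: *Self, kv_tensor: Tensor(f32)) !Tensor(f32) {
    const repeat_factor = self.n_heads / self.n_kv_heads; // 32 / 8 = 4
    // Excerpt: allocating `result` and binding `seq_len` and `kv_data`
    // from `kv_tensor` are elided. The offsets below assume a row-major
    // [seq_len, heads, head_dim] layout, which the source does not state.
    for (0..seq_len) |s| {
        for (0..self.n_kv_heads) |kv_head| {
            const kv_head_offset = (s * self.n_kv_heads + kv_head) * self.head_dim;
            for (0..repeat_factor) |rep| {
                const q_head = kv_head * repeat_factor + rep;
                const q_head_offset = (s * self.n_heads + q_head) * self.head_dim;
                // Copy one head's worth of values. @memcpy requires both
                // slices to have equal length, so bound them explicitly.
                @memcpy(
                    result[q_head_offset..][0..self.head_dim],
                    kv_data[kv_head_offset..][0..self.head_dim],
                );
            }
        }
    }
    // Wrapping `result` in the returned Tensor is elided.
}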
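The same head-repetition logic can be shown on plain slices, independent of the `Tensor` type. This is a minimal sketch rather than the ZigLLM implementation: `repeatKvHeads` and its signature are hypothetical, the row-major `[seq_len, heads, head_dim]` layout is an assumption, and it targets Zig 0.11 or later (range `for` loops, two-argument `@memcpy`).

```zig
const std = @import("std");

/// Expands `kv` from [seq_len, n_kv_heads, head_dim] to
/// [seq_len, n_heads, head_dim] (row-major). Caller frees the result.
/// Hypothetical helper for illustration only.
pub fn repeatKvHeads(
    allocator: std.mem.Allocator,
    kv: []const f32,
    seq_len: usize,
    n_heads: usize,
    n_kv_heads: usize,
    head_dim: usize,
) ![]f32 {
    std.debug.assert(n_heads % n_kv_heads == 0);
    std.debug.assert(kv.len == seq_len * n_kv_heads * head_dim);
    const repeat_factor = n_heads / n_kv_heads;

    const out = try allocator.alloc(f32, seq_len * n_heads * head_dim);
    for (0..seq_len) |s| {
        for (0..n_kv_heads) |kv_head| {
            const src_off = (s * n_kv_heads + kv_head) * head_dim;
            const src = kv[src_off..][0..head_dim];
            for (0..repeat_factor) |rep| {
                // Query heads kv_head*r .. kv_head*r + r-1 share this KV head.
                const q_head = kv_head * repeat_factor + rep;
                const dst_off = (s * n_heads + q_head) * head_dim;
                @memcpy(out[dst_off..][0..head_dim], src);
            }
        }
    }
    return out;
}
```

Materializing the repeated heads is the simplest approach but gives back part of the memory saving at compute time; an alternative is to index the shared KV head directly as `q_head / repeat_factor` inside the attention loop.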
Sliding Window Mask
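fn createSlidingWindowMask(self: *Self, seq_len: usize,
                           window_size: usize) !Tensor(f32) {
    // Excerpt: allocation and return of `mask` are elided.
    for (0..seq_len) |i| {
        for (0..seq_len) |j| {
            // Causal (j <= i) and within the last `window_size` positions,
            // counting the current token, as in the W=2 diagram above.
            // `j <= i` is tested first, so `i - j` cannot underflow.
            if (j <= i and (i - j) < window_size) {
                mask[i][j] = 0.0; // Attend
            } else {
                mask[i][j] = -std.math.inf(f32); // Mask out
            }
        }
    }
}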
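A self-contained version of the mask on a flat slice is sketched below. The function name `slidingWindowMask` and the flat row-major layout are assumptions for illustration, and the window is taken to include the current token, as in the W=2 diagram earlier on this page.

```zig
const std = @import("std");

/// Builds a row-major [seq_len, seq_len] additive mask: 0 where query i may
/// attend to key j, -inf elsewhere. Caller frees the result.
/// Hypothetical helper for illustration only.
pub fn slidingWindowMask(
    allocator: std.mem.Allocator,
    seq_len: usize,
    window_size: usize,
) ![]f32 {
    const mask = try allocator.alloc(f32, seq_len * seq_len);
    const neg_inf = -std.math.inf(f32);
    for (0..seq_len) |i| {
        for (0..seq_len) |j| {
            // `j <= i` is checked first, so `i - j` cannot underflow.
            const visible = j <= i and (i - j) < window_size;
            mask[i * seq_len + j] = if (visible) 0.0 else neg_inf;
        }
    }
    return mask;
}
```

The mask is added to the attention scores before softmax, so masked positions receive zero weight. Building the full `seq_len` by `seq_len` matrix is fine for illustration, but it does not by itself deliver the \( O(n \cdot W) \) memory bound; that requires computing scores only for the in-window keys.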
SwiGLU MLP
Mistral uses the same SwiGLU FFN as LLaMA:
pub const SwiGLUMLP = struct {
    const Self = @This();

    gate_proj: Tensor(f32), // [d_model, intermediate_size]
    up_proj: Tensor(f32),   // [d_model, intermediate_size]
    down_proj: Tensor(f32), // [intermediate_size, d_model]
    allocator: std.mem.Allocator, // referenced by forward() below

    // Computes (SiLU(input * gate_proj) ⊙ (input * up_proj)) * down_proj.
    pub fn forward(self: *Self, input: Tensor(f32)) !Tensor(f32) {
        const gate = try input.matmul(self.gate_proj, self.allocator);
        const gate_activated = try silu(gate); // SiLU = Swish
        const up = try input.matmul(self.up_proj, self.allocator);
        const combined = try elementwiseMul(gate_activated, up);
        return try combined.matmul(self.down_proj, self.allocator);
    }
};
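The gating step between the projections is purely elementwise: \( \text{SiLU}(g) \cdot u \) with \( \text{SiLU}(x) = x \cdot \sigma(x) \). The sketch below shows it on plain slices; `silu` and `swigluGate` here are standalone illustrative functions, not the tensor-level helpers called in the struct above.

```zig
const std = @import("std");

/// SiLU (Swish): x * sigmoid(x), written as x / (1 + e^-x).
pub fn silu(x: f32) f32 {
    return x / (1.0 + @exp(-x));
}

/// Gating step of SwiGLU: out[i] = silu(gate[i]) * up[i].
/// Hypothetical helper for illustration only.
pub fn swigluGate(gate: []const f32, up: []const f32, out: []f32) void {
    std.debug.assert(gate.len == up.len and up.len == out.len);
    for (gate, up, out) |g, u, *o| {
        o.* = silu(g) * u;
    }
}
```

A gate pre-activation near zero suppresses the corresponding `up` value, while a large positive one passes it through almost unchanged.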
Forward Pass
The MistralBlock implements the standard pre-norm residual pattern with an optional sliding window:
pub fn forward(self: *Self, input: Tensor(f32),
               rope_cache: ?Tensor(f32)) !Tensor(f32) {
    // 1. Pre-attention RMSNorm
    const normed_input = try self.input_layernorm.forward(input);
    // 2. GQA with sliding window
    const attn_output = try self.attention.forward(
        normed_input, rope_cache, self.config.sliding_window);
    // 3. Residual connection
    const after_attn = try self.addResidual(input, attn_output);
    // 4. Pre-MLP RMSNorm
    const normed_attn = try self.post_attention_layernorm.forward(after_attn);
    // 5. SwiGLU MLP
    const mlp_output = try self.mlp.forward(normed_attn);
    // 6. Residual connection
    return try self.addResidual(after_attn, mlp_output);
}
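Written as equations, the six steps above reduce to two residual updates, where \( x \) is the block input, \( h \) the state after attention, and \( y \) the block output (symbols chosen here for illustration):

\[ h = x + \text{GQA}(\text{RMSNorm}(x)), \qquad y = h + \text{SwiGLU}(\text{RMSNorm}(h)) \]

The two normalizations are separate layers (`input_layernorm` and `post_attention_layernorm` in the code), and each residual adds the un-normalized input back, which is what makes this a pre-norm block.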
The full model stacks \( N \) blocks:
pub fn forward(self: *Self, input_ids: []const u32) !Tensor(f32) {
    var hidden_states = try self.getEmbeddings(input_ids);
    const rope_cache = try self.createRoPECache(input_ids.len);
    for (self.blocks) |*block| {
        const new_states = try block.forward(hidden_states, rope_cache);
        hidden_states.deinit(self.allocator);
        hidden_states = new_states;
    }
    const normed = try self.norm.forward(hidden_states);
    return try normed.matmul(self.lm_head, self.allocator);
}
Mistral vs LLaMA Comparison
| Aspect | LLaMA 7B | Mistral 7B |
|---|---|---|
| Parameters | 6.7B | 7.3B |
| Context length | 2048 (LLaMA 1) / 4096 (LLaMA 2) | 32768 |
| Attention | MHA (32/32) | GQA (32 Q / 8 KV) |
| Window | Full causal | Sliding (W=4096) |
| KV cache per layer | \( 2 \times 32 \times 128 \times S \) | \( 2 \times 8 \times 128 \times S \) |
| KV cache ratio | 1.0x | 0.25x |
| FFN intermediate | 11008 | 14336 |
| Activation | SwiGLU | SwiGLU |
| Normalization | RMSNorm | RMSNorm |
When to Choose Mistral
Choose Mistral over LLaMA when you need long-context inference (> 4096 tokens) or when KV cache memory is a bottleneck (e.g., high-throughput serving with many concurrent sequences). The GQA head reduction provides significant memory savings during generation.
Mixtral (MoE Variant)
Mixtral 8x7B replaces each dense FFN with a Mixture-of-Experts layer containing 8 expert FFNs, of which 2 are activated per token via a learned router. This is covered in detail in Mixture of Experts. The key differences from base Mistral are:
- rope_theta increased to 1,000,000 for better long-range position encoding
- Sliding window disabled (full causal attention)
- 8 expert FFNs per layer, 2 active per token
- Total parameters: 46.7B, active parameters: ~12.9B
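The "2 active per token" routing can be sketched as a top-2 selection over the router logits followed by a softmax over just the two selected values, which is how the Mixtral paper describes its gate. The `Routed` struct and `routeTop2` function are hypothetical names for illustration, not ZigLLM's API.

```zig
const std = @import("std");

/// Indices of the two chosen experts and their mixing weights (sum to 1).
pub const Routed = struct {
    experts: [2]usize,
    weights: [2]f32,
};

/// Picks the two largest router logits and softmaxes over just those two.
/// Hypothetical helper for illustration only.
pub fn routeTop2(logits: []const f32) Routed {
    std.debug.assert(logits.len >= 2);
    var first: usize = if (logits[0] >= logits[1]) 0 else 1;
    var second: usize = 1 - first;
    for (logits[2..], 2..) |l, i| {
        if (l > logits[first]) {
            second = first;
            first = i;
        } else if (l > logits[second]) {
            second = i;
        }
    }
    // Two-way softmax, shifted by the larger logit for numerical stability.
    const e2 = @exp(logits[second] - logits[first]);
    const w1 = 1.0 / (1.0 + e2);
    return .{ .experts = .{ first, second }, .weights = .{ w1, 1.0 - w1 } };
}
```

The layer output for a token is then the weighted sum of the two selected experts' FFN outputs; the other six experts are not evaluated for that token.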
References
1. Jiang, A. Q. et al. "Mistral 7B." arXiv:2310.06825, 2023.
2. Jiang, A. Q. et al. "Mixtral of Experts." arXiv:2401.04088, 2024.
3. Ainslie, J. et al. "GQA: Training Generalized Multi-Query Transformer Models from Multi-Head Checkpoints." EMNLP, 2023.
4. Beltagy, I. et al. "Longformer: The Long-Document Transformer." arXiv:2004.05150, 2020.