Wav2Vec2¶
Wav2Vec2 is Meta's self-supervised speech representation model. It is pre-trained on unlabeled audio and then fine-tuned for speech recognition and other speech tasks.
Overview¶
| Property | Value |
|---|---|
| Architecture | CNN Feature Extractor + Transformer |
| Parameters | 95M - 2B |
| Audio Input | 16kHz raw waveform |
| Training | Self-supervised (contrastive) |
| Tasks | ASR, Speaker ID, Emotion |
| Languages | 53 (XLSR-53), 128 (XLS-R) |
Quick Start¶
use unillm::models_v2::wav2vec2::{Wav2Vec2ModelV2, Wav2Vec2Config};
use unillm::weight_loader_core::WeightLoader;
use unillm::{Model, ModelInputs};
// Load model
let weights = WeightLoader::from_safetensors("wav2vec2-base.safetensors")?;
let config = Wav2Vec2Config::default();
let model = Wav2Vec2ModelV2::from_weights(config, weights)?;
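// Extract features
// `load_audio` and `create_attention_mask` are helpers not shown here;
// the waveform must be sampled at 16 kHz
let audio = load_audio("speech.wav")?;
let inputs = ModelInputs::Audio {
    input_features: audio.to_tensor()?,
    attention_mask: Some(create_attention_mask(&audio)?),
};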
let outputs = model.forward(&inputs)?;
// outputs.embeddings: [batch, seq_len, hidden_size]
Configuration¶
model_config!(Wav2Vec2Config {
// Feature extractor (CNN)
conv_dim: Vec<usize> = vec![512, 512, 512, 512, 512, 512, 512],
conv_stride: Vec<usize> = vec![5, 2, 2, 2, 2, 2, 2],
conv_kernel: Vec<usize> = vec![10, 3, 3, 3, 3, 2, 2],
// Transformer
hidden_size: usize = 768,
intermediate_size: usize = 3072,
num_hidden_layers: usize = 12,
num_attention_heads: usize = 12,
// Quantizer (used only during pre-training)
num_codevector_groups: usize = 2,
num_codevectors_per_group: usize = 320,
// General
vocab_size: usize = 32, // CTC vocabulary
hidden_dropout: f32 = 0.1,
attention_dropout: f32 = 0.1,
feat_extract_norm: String = "group".to_string(),
});
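Several of these fields have to agree with each other: the three `conv_*` vectors need the same length, and `hidden_size` must divide evenly across the attention heads. The sketch below shows those checks on a plain stand-in struct. `ShapeCheck` and its methods are hypothetical names used for illustration only; they are not part of `unillm`, and the library may validate its configuration differently.

```rust
/// Hypothetical stand-in for the `Wav2Vec2Config` fields that must agree.
/// This is a plain struct for illustration, not the library type.
#[derive(Debug, Clone)]
struct ShapeCheck {
    conv_dim: Vec<usize>,
    conv_stride: Vec<usize>,
    conv_kernel: Vec<usize>,
    hidden_size: usize,
    num_attention_heads: usize,
    num_codevector_groups: usize,
    num_codevectors_per_group: usize,
}

impl ShapeCheck {
    /// Values copied from the Base defaults listed above.
    fn base() -> Self {
        ShapeCheck {
            conv_dim: vec![512; 7],
            conv_stride: vec![5, 2, 2, 2, 2, 2, 2],
            conv_kernel: vec![10, 3, 3, 3, 3, 2, 2],
            hidden_size: 768,
            num_attention_heads: 12,
            num_codevector_groups: 2,
            num_codevectors_per_group: 320,
        }
    }

    /// Returns a description of the first inconsistency found, if any.
    fn validate(&self) -> Result<(), String> {
        let layers = self.conv_dim.len();
        if self.conv_stride.len() != layers || self.conv_kernel.len() != layers {
            return Err(format!(
                "conv_dim, conv_stride and conv_kernel must all have {} entries",
                layers
            ));
        }
        if self.num_attention_heads == 0 || self.hidden_size % self.num_attention_heads != 0 {
            return Err(format!(
                "hidden_size {} is not divisible by {} attention heads",
                self.hidden_size, self.num_attention_heads
            ));
        }
        Ok(())
    }

    /// Total downsampling factor of the CNN: the product of all strides.
    fn total_stride(&self) -> usize {
        self.conv_stride.iter().product()
    }

    /// Number of distinct quantized targets: entries per group raised to
    /// the number of groups (product quantization).
    fn codebook_combinations(&self) -> usize {
        self.num_codevectors_per_group
            .pow(self.num_codevector_groups as u32)
    }
}
```

With the Base defaults, `total_stride()` is 320 and `codebook_combinations()` is 102,400. Changing only `hidden_size` to 1280 while keeping 12 heads fails validation, which is why larger variants also need a matching head count.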
Model Sizes¶
| Variant | Layers | Hidden | Params |
|---|---|---|---|
| Base | 12 | 768 | 95M |
| Large | 24 | 1024 | 317M |
| XLS-R 300M | 24 | 1024 | 317M |
| XLS-R 1B | 48 | 1280 | 965M |
| XLS-R 2B | 48 | 1920 | 2B |
Architecture¶
Feature Extractor¶
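A 7-layer 1D CNN converts the raw waveform into latent representations. The strides multiply to 320, so at 16 kHz the extractor emits roughly one 512-dimensional frame every 20 ms: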
Raw Audio (16kHz)
│
▼
┌─────────────────┐
│ Conv1D (k=10) │ 1 → 512, stride 5
├─────────────────┤
│ Conv1D (k=3) │ 512 → 512, stride 2
├─────────────────┤
│ Conv1D (k=3) │ 512 → 512, stride 2
├─────────────────┤
│ ... 4 more │
└─────────────────┘
│
▼
Latent Representations
[batch, ≈ num_samples/320, 512]
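The exact number of frames is slightly below `num_samples / 320` because each convolution trims the edges. As a minimal sketch, assuming unpadded convolutions with the Base kernel sizes and strides from the configuration section, the frame count and the receptive field of a single frame can be computed as follows. The function names are illustrative and not part of the library.

```rust
/// Kernel sizes and strides of the seven convolution layers (Base defaults).
const CONV_KERNEL: [usize; 7] = [10, 3, 3, 3, 3, 2, 2];
const CONV_STRIDE: [usize; 7] = [5, 2, 2, 2, 2, 2, 2];

/// Number of latent frames produced for `num_samples` input samples,
/// assuming unpadded convolutions: out = (in - kernel) / stride + 1.
fn num_frames(num_samples: usize) -> usize {
    let mut len = num_samples;
    for (&kernel, &stride) in CONV_KERNEL.iter().zip(CONV_STRIDE.iter()) {
        if len < kernel {
            return 0; // input too short to produce any output at this layer
        }
        len = (len - kernel) / stride + 1;
    }
    len
}

/// Number of input samples that influence a single latent frame.
fn receptive_field() -> usize {
    let mut field = 1;
    let mut jump = 1; // spacing, in samples, between neighbouring inputs of the current layer
    for (&kernel, &stride) in CONV_KERNEL.iter().zip(CONV_STRIDE.iter()) {
        field += (kernel - 1) * jump;
        jump *= stride;
    }
    field
}
```

Under these assumptions one second of 16 kHz audio (16,000 samples) yields 49 frames, and each frame covers 400 samples, or 25 ms of audio.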
Transformer Encoder¶
struct Wav2Vec2Encoder {
    feature_projection: Linear, // 512 → 768
    layers: Vec<TransformerLayer>,
    layer_norm: LayerNorm,
}
impl Wav2Vec2Encoder {
    fn forward(&self, features: &Tensor) -> Result<Tensor> {
        // Project CNN features to the Transformer width
        let hidden = ops_fn::linear(features, &self.feature_projection, None)?;
        // Add positional encoding
        // (`mut` because every layer below overwrites the hidden states)
        let mut hidden = self.add_positional_encoding(&hidden)?;
        // Transformer layers
        for layer in &self.layers {
            hidden = self.forward_layer(&hidden, layer)?;
        }
        // Final layer norm over the last dimension
        ops_fn::layer_norm(&hidden, &self.layer_norm, None, 1e-5)
    }
}
Self-Supervised Learning¶
Contrastive Task¶
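Wav2Vec2 is pre-trained in two steps:
1. Mask spans of the latent representations before they enter the Transformer.
2. For each masked position, identify the correct quantized target among a set of distractors.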
┌─────────────────────────────────────────┐
│ Raw Audio │
└─────────────────────────────────────────┘
│
▼
┌─────────────────────────────────────────┐
│ CNN Feature Extractor │
└─────────────────────────────────────────┘
│
┌──────┴──────┐
│ │
▼ ▼
┌───────┐ ┌───────────┐
│ Mask │ │ Quantizer │
└───────┘ └───────────┘
│ │
▼ │
┌───────────┐ │
│Transformer│ │
└───────────┘ │
│ │
└──────┬──────┘
│
▼
Contrastive Loss
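The loss at the bottom of the diagram can be written, for a single masked time step, as a softmax cross-entropy over cosine similarities between the Transformer output and the candidate quantized vectors. The sketch below works on plain `f32` slices and is only meant to show the computation. The function names are illustrative, the temperature is a free parameter (the paper uses 0.1), and the codebook diversity term of the full pre-training objective is left out.

```rust
/// Cosine similarity between two vectors of equal length.
fn cosine(a: &[f32], b: &[f32]) -> f32 {
    let dot: f32 = a.iter().zip(b).map(|(x, y)| x * y).sum();
    let norm_a = a.iter().map(|x| x * x).sum::<f32>().sqrt();
    let norm_b = b.iter().map(|x| x * x).sum::<f32>().sqrt();
    // Guard against division by zero for all-zero vectors.
    dot / (norm_a * norm_b).max(1e-8)
}

/// Contrastive loss for one masked time step.
///
/// `context` is the Transformer output at the masked position, `target` is
/// the true quantized latent, and `distractors` are quantized latents
/// sampled from other masked positions.
fn contrastive_loss(
    context: &[f32],
    target: &[f32],
    distractors: &[Vec<f32>],
    temperature: f32,
) -> f32 {
    // Logit 0 belongs to the true target, the rest to the distractors.
    let mut logits = vec![cosine(context, target) / temperature];
    logits.extend(distractors.iter().map(|d| cosine(context, d) / temperature));

    // Numerically stable log-sum-exp over all candidates.
    let max = logits.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
    let log_sum_exp = max + logits.iter().map(|l| (l - max).exp()).sum::<f32>().ln();

    // Negative log-probability of picking the true target.
    log_sum_exp - logits[0]
}
```

When the context is equally similar to all candidates, the loss equals the log of the number of candidates. It approaches zero as the context aligns with the true target and moves away from the distractors.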
Speech Recognition¶
With CTC Head¶
// Fine-tuned model with CTC head
let model = Wav2Vec2ForCTC::from_weights(config, weights)?;
let audio = load_audio("speech.wav")?;
let inputs = ModelInputs::Audio {
input_features: audio.to_tensor()?,
attention_mask: None,
};
// Get logits
let outputs = model.forward(&inputs)?;
let logits = outputs.logits; // [batch, seq_len, vocab_size]
// Decode with CTC
let transcription = ctc_decode(&logits, &tokenizer)?;
CTC Decoding¶
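// Greedy CTC decoding for a single utterance.
// `BLANK_TOKEN` must be the id of the CTC blank symbol in the model's vocabulary.
fn ctc_decode(logits: &Tensor, tokenizer: &Tokenizer) -> Result<String> {
    // Get argmax predictions over the vocabulary axis
    let predictions = ops_fn::argmax(logits, -1)?;
    // Collapse repeated tokens, then remove blanks
    let mut decoded = Vec::new();
    let mut prev_token = None;
    for token_id in predictions.to_vec::<i64>()? {
        if token_id != BLANK_TOKEN && Some(token_id) != prev_token {
            decoded.push(token_id);
        }
        // Track blanks too, so "l, blank, l" decodes to two separate tokens
        prev_token = Some(token_id);
    }
    tokenizer.decode(&decoded)
}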
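To see the collapse rule in isolation, here is a minimal sketch of the same greedy decoding on plain vectors, with no dependency on the tensor or tokenizer types. The blank index of 0, the toy vocabulary, and the `one_hot_logits` helper are assumptions made for the example.

```rust
/// Index of the CTC blank symbol. Assumed to be 0 here; check your vocabulary.
const BLANK: usize = 0;

/// Greedy CTC decoding for one utterance.
/// `logits` has shape [seq_len][vocab_size]; `vocab[i]` is the text of token i.
fn ctc_greedy_decode(logits: &[Vec<f32>], vocab: &[&str]) -> String {
    let mut text = String::new();
    let mut prev: Option<usize> = None;
    for frame in logits {
        // Argmax over the vocabulary axis.
        let best = frame
            .iter()
            .enumerate()
            .max_by(|a, b| a.1.total_cmp(b.1))
            .map(|(index, _)| index);
        if let Some(token) = best {
            // Emit only when the token is not blank and differs from the previous frame.
            if token != BLANK && Some(token) != prev {
                text.push_str(vocab[token]);
            }
            prev = Some(token);
        }
    }
    text
}

/// Example helper: builds one-hot logits for a sequence of token ids.
fn one_hot_logits(ids: &[usize], vocab_size: usize) -> Vec<Vec<f32>> {
    ids.iter()
        .map(|&id| {
            let mut frame = vec![0.0_f32; vocab_size];
            frame[id] = 1.0;
            frame
        })
        .collect()
}
```

With the vocabulary `["", "h", "e", "l", "o"]`, the frame sequence `h h _ e l l _ l o o _` (where `_` is the blank) decodes to `hello`: the repeated `l` frames collapse into one, and the blank between them is what allows the second `l` to be emitted.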
Feature Extraction¶
For Downstream Tasks¶
// Extract representations for classification, etc.
let outputs = model.forward(&inputs)?;
match outputs {
    // `..` ignores the model's own pooled output; pooling is done manually below
    ModelOutputs::Embeddings { embeddings, .. } => {
        // embeddings: [batch, seq_len, hidden_size]
        // Use for sequence tasks
        // Mean pooling over the time axis for classification
        let pooled = embeddings.mean(1)?;
        // Classification head
        let logits = classifier.forward(&pooled)?;
    }
    _ => unreachable!(),
}
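A plain mean over the time axis also averages in padded frames when a batch contains utterances of different lengths. The sketch below shows mean pooling that skips padded positions, written on plain vectors for a single utterance. The function name and the frame-level boolean mask are assumptions for illustration; the library may expose this differently.

```rust
/// Mean-pool frame embeddings for one utterance, ignoring padded frames.
/// `embeddings` has shape [seq_len][hidden_size]; `mask[t]` is true for real frames.
fn masked_mean_pool(embeddings: &[Vec<f32>], mask: &[bool]) -> Vec<f32> {
    let hidden_size = embeddings.first().map_or(0, |frame| frame.len());
    let mut sum = vec![0.0_f32; hidden_size];
    let mut count = 0usize;

    for (frame, &keep) in embeddings.iter().zip(mask) {
        if keep {
            for (acc, value) in sum.iter_mut().zip(frame) {
                *acc += value;
            }
            count += 1;
        }
    }

    // Leave the zero vector untouched if every frame was masked out.
    if count > 0 {
        for acc in sum.iter_mut() {
            *acc /= count as f32;
        }
    }
    sum
}
```

For frames `[1, 2]`, `[3, 4]`, `[100, 100]` with the last frame marked as padding, the pooled vector is `[2, 3]` rather than being pulled toward the padded values.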
Speaker Identification¶
// Fine-tuned for speaker ID
let speaker_model = Wav2Vec2ForSpeakerID::from_weights(config, weights)?;
let embeddings = speaker_model.extract_speaker_embedding(&audio)?;
// embeddings: [batch, embedding_dim]
// Compare with enrolled speakers
let similarities = cosine_similarity(&embeddings, &enrolled_embeddings)?;
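The comparison step is ordinary cosine similarity between the query embedding and each enrolled embedding. A minimal sketch on plain vectors follows. The `best_match` helper and the idea of returning the top-scoring index are assumptions for illustration and not the library's `cosine_similarity` signature.

```rust
/// Cosine similarity between two embeddings of equal length.
fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
    let dot: f32 = a.iter().zip(b).map(|(x, y)| x * y).sum();
    let norm_a = a.iter().map(|x| x * x).sum::<f32>().sqrt();
    let norm_b = b.iter().map(|x| x * x).sum::<f32>().sqrt();
    // Guard against division by zero for all-zero vectors.
    dot / (norm_a * norm_b).max(1e-8)
}

/// Returns the index and score of the enrolled speaker most similar to `query`,
/// or `None` when no speakers are enrolled.
fn best_match(query: &[f32], enrolled: &[Vec<f32>]) -> Option<(usize, f32)> {
    enrolled
        .iter()
        .enumerate()
        .map(|(index, embedding)| (index, cosine_similarity(query, embedding)))
        .max_by(|a, b| a.1.total_cmp(&b.1))
}
```

In practice the best score is usually compared against a threshold tuned on held-out data, so that unknown speakers are rejected instead of being assigned to the nearest enrolled voice.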
Memory Requirements¶
| Variant | F32 | F16 |
|---|---|---|
| Base | 380 MB | 190 MB |
| Large | 1.3 GB | 650 MB |
| XLS-R 1B | 3.9 GB | 2.0 GB |
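The figures in this table appear to correspond to parameter count times bytes per parameter, so they cover weights only and not activations or decoding buffers. A small sketch of that arithmetic, with an illustrative function name:

```rust
/// Approximate weight memory in megabytes (10^6 bytes) for a given
/// parameter count and per-parameter width (4 for F32, 2 for F16).
fn weight_memory_mb(num_params: u64, bytes_per_param: u64) -> f64 {
    (num_params * bytes_per_param) as f64 / 1e6
}
```

For the Base model, `weight_memory_mb(95_000_000, 4)` gives 380 MB and `weight_memory_mb(95_000_000, 2)` gives 190 MB, matching the first row. Long audio inputs add activation memory on top of this.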
Performance¶
Word Error Rate (ASR)¶
| Model | LibriSpeech clean | LibriSpeech other |
|---|---|---|
| Base | 3.4% | 8.0% |
| Large | 2.1% | 4.8% |
| XLS-R 1B | 1.9% | 4.0% |
Use Cases¶
Ideal For¶
- Speech recognition - High-quality ASR
- Speaker identification - Voice biometrics
- Emotion recognition - Sentiment from speech
- Language identification - Detect spoken language
- Low-resource languages - Transfer learning
Comparison¶
| Task | Wav2Vec2 | Whisper |
|---|---|---|
| Real-time ASR | Better | Good |
| Multi-lingual | Good (XLS-R) | Excellent |
| Zero-shot | Needs fine-tuning | Yes |
| Feature extraction | Excellent | Limited |
Related Models¶
HuBERT¶
Similar architecture, but pre-trained by predicting cluster assignments of masked frames instead of using a contrastive loss:
use unillm::models_v2::hubert::{HuBertModelV2, HuBertConfig};
let model = HuBertModelV2::from_weights(config, weights)?;
// Same interface as Wav2Vec2
XLS-R¶
Multilingual Wav2Vec2:
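// XLS-R 1B (pre-trained on 128 languages)
let config = Wav2Vec2Config {
    hidden_size: 1280,
    intermediate_size: 5120,
    num_hidden_layers: 48,
    // 1280 is not divisible by the default 12 heads
    num_attention_heads: 16,
    // XLS-R checkpoints use layer norm in the feature extractor
    feat_extract_norm: "layer".to_string(),
    ..Default::default()
};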
Best Practices¶
- Normalize audio - Zero-mean, unit-variance normalization per utterance
- Use attention mask - For padded batches of variable-length audio
- Fine-tune for task - Pre-trained features need task-specific adaptation
- Consider XLS-R - For multilingual applications
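The first two practices can be sketched together: normalize each waveform on its own, then pad the batch to a common length and record which samples are real. The code below is a minimal illustration on plain vectors. The function names and the epsilon value are assumptions, and the mask here is sample-level, whereas the model may expect a different layout.

```rust
/// Normalize a waveform to zero mean and unit variance.
fn normalize(samples: &[f32]) -> Vec<f32> {
    if samples.is_empty() {
        return Vec::new();
    }
    let n = samples.len() as f32;
    let mean = samples.iter().sum::<f32>() / n;
    let variance = samples.iter().map(|x| (x - mean).powi(2)).sum::<f32>() / n;
    // Small epsilon (assumed value) avoids dividing by zero on silent input.
    let scale = (variance + 1e-7).sqrt();
    samples.iter().map(|x| (x - mean) / scale).collect()
}

/// Normalize every waveform, pad with zeros to the longest one, and return
/// the padded batch together with a mask that is true for real samples.
fn pad_batch(batch: &[Vec<f32>]) -> (Vec<Vec<f32>>, Vec<Vec<bool>>) {
    let max_len = batch.iter().map(|wave| wave.len()).max().unwrap_or(0);
    let mut padded = Vec::with_capacity(batch.len());
    let mut masks = Vec::with_capacity(batch.len());

    for wave in batch {
        // Normalize before padding so the zeros do not skew mean and variance.
        let mut samples = normalize(wave);
        let mut mask = vec![true; samples.len()];
        samples.resize(max_len, 0.0);
        mask.resize(max_len, false);
        padded.push(samples);
        masks.push(mask);
    }
    (padded, masks)
}
```

Normalizing before padding matters: computing statistics over the padded signal would shrink the variance of short utterances and shift their mean toward zero.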