Skip to content

Inference Pipeline API

The inference module provides high-level inference pipeline and sampling utilities.

InferencePipeline

High-level interface for running model inference.

pub struct InferencePipeline<M: Model> {
    model: M,
    tokenizer: Tokenizer,
    device: Device,
}

impl<M: Model> InferencePipeline<M> {
    /// Create new pipeline
    pub fn new(model: M, tokenizer: Tokenizer, device: Device) -> Self;

    /// Generate text from prompt
    pub fn generate(&self, prompt: &str, config: &GenerationConfig) -> Result<String>;

    /// Generate with streaming callback
    pub fn generate_streaming<F>(&self, prompt: &str, config: &GenerationConfig, callback: F) -> Result<String>
    where
        F: FnMut(&str);

    /// Run forward pass with raw inputs
    pub fn forward(&self, inputs: &ModelInputs) -> Result<ModelOutputs>;

    /// Get next token logits
    pub fn get_logits(&self, input_ids: &Tensor) -> Result<Tensor>;
}

Basic Usage

use unillm::inference::InferencePipeline;
use unillm::models_v2::llama::{LlamaModelV2, LlamaConfig};

// Create model and tokenizer
let model = LlamaModelV2::from_weights(config, weights)?;
let tokenizer = Tokenizer::from_gguf("model.gguf")?;

// Create pipeline
let pipeline = InferencePipeline::new(model, tokenizer, Device::auto());

// Generate text
let response = pipeline.generate("Hello, world!", &GenerationConfig::default())?;

Streaming Generation

let config = GenerationConfig {
    max_new_tokens: 100,
    ..Default::default()
};

pipeline.generate_streaming("Once upon a time", &config, |token| {
    print!("{}", token);
    std::io::stdout().flush().unwrap();
})?;

Sampler

Token sampling strategies for text generation.

pub struct Sampler {
    temperature: f32,
    top_p: f32,
    top_k: Option<usize>,
    repetition_penalty: f32,
}

impl Sampler {
    /// Create new sampler with config
    pub fn new(config: &GenerationConfig) -> Self;

    /// Sample next token from logits
    pub fn sample(&self, logits: &Tensor, past_tokens: &[u32]) -> Result<u32>;

    /// Apply temperature scaling
    pub fn apply_temperature(&self, logits: &Tensor) -> Result<Tensor>;

    /// Apply top-k filtering
    pub fn apply_top_k(&self, logits: &Tensor) -> Result<Tensor>;

    /// Apply top-p (nucleus) filtering
    pub fn apply_top_p(&self, logits: &Tensor) -> Result<Tensor>;

    /// Apply repetition penalty
    pub fn apply_repetition_penalty(&self, logits: &Tensor, past_tokens: &[u32]) -> Result<Tensor>;
}

Sampling Strategies

use unillm::sampler::Sampler;

// Greedy sampling (temperature = 0)
let greedy = Sampler::new(&GenerationConfig {
    do_sample: false,
    ..Default::default()
});

// Temperature sampling
let temp_sampler = Sampler::new(&GenerationConfig {
    do_sample: true,
    temperature: 0.7,
    ..Default::default()
});

// Top-p (nucleus) sampling
let nucleus_sampler = Sampler::new(&GenerationConfig {
    do_sample: true,
    temperature: 0.8,
    top_p: 0.9,
    ..Default::default()
});

// Combined sampling
let combined = Sampler::new(&GenerationConfig {
    do_sample: true,
    temperature: 0.7,
    top_p: 0.9,
    top_k: Some(50),
    repetition_penalty: 1.1,
    ..Default::default()
});

Manual Sampling

let logits = model.get_logits(&input_ids)?;
let sampler = Sampler::new(&gen_config);

// Apply all transformations
let scaled = sampler.apply_temperature(&logits)?;
let filtered = sampler.apply_top_k(&scaled)?;
let nucleus = sampler.apply_top_p(&filtered)?;
let penalized = sampler.apply_repetition_penalty(&nucleus, &past_tokens)?;

// Sample token
let next_token = sampler.sample(&penalized, &past_tokens)?;

Tokenizer

Text tokenization and detokenization.

pub struct Tokenizer {
    // Internal implementation
}

impl Tokenizer {
    /// Load tokenizer from GGUF file
    pub fn from_gguf(path: &str) -> Result<Self>;

    /// Load tokenizer from HuggingFace tokenizers JSON
    pub fn from_hf_tokenizers(path: &str) -> Result<Self>;

    /// Encode text to token IDs
    pub fn encode(&self, text: &str) -> Result<Vec<u32>>;

    /// Decode token IDs to text
    pub fn decode(&self, ids: &[u32]) -> Result<String>;

    /// Decode single token
    pub fn decode_token(&self, id: u32) -> Result<String>;

    /// Get vocabulary size
    pub fn vocab_size(&self) -> usize;

    /// Get special tokens
    pub fn special_tokens(&self) -> &SpecialTokens;
}

Tokenization

use unillm::tokenizer::Tokenizer;

let tokenizer = Tokenizer::from_gguf("model.gguf")?;

// Encode text
let ids = tokenizer.encode("Hello, world!")?;
println!("Token IDs: {:?}", ids);

// Decode back to text
let text = tokenizer.decode(&ids)?;
println!("Decoded: {}", text);

// Single token decode
for id in &ids {
    let token = tokenizer.decode_token(*id)?;
    println!("{}: {:?}", id, token);
}

Special Tokens

pub struct SpecialTokens {
    pub bos_token_id: u32,
    pub eos_token_id: u32,
    pub pad_token_id: u32,
    pub unk_token_id: u32,
}

let special = tokenizer.special_tokens();
println!("BOS: {}", special.bos_token_id);
println!("EOS: {}", special.eos_token_id);

GenerationConfig

Configuration for text generation.

#[derive(Debug, Clone)]
pub struct GenerationConfig {
    /// Maximum new tokens to generate
    pub max_new_tokens: usize,

    /// Sampling temperature (0.0 = greedy)
    pub temperature: f32,

    /// Nucleus sampling threshold
    pub top_p: f32,

    /// Top-k sampling (None = disabled)
    pub top_k: Option<usize>,

    /// Enable sampling vs greedy
    pub do_sample: bool,

    /// Repetition penalty
    pub repetition_penalty: f32,

    /// Stop sequences
    pub stop_sequences: Vec<String>,

    /// End of sequence token ID
    pub eos_token_id: u32,

    /// Padding token ID
    pub pad_token_id: u32,
}

Preset Configurations

impl GenerationConfig {
    /// Greedy decoding (deterministic)
    pub fn greedy() -> Self {
        Self {
            do_sample: false,
            temperature: 0.0,
            ..Default::default()
        }
    }

    /// Creative writing
    pub fn creative() -> Self {
        Self {
            do_sample: true,
            temperature: 1.0,
            top_p: 0.95,
            repetition_penalty: 1.2,
            ..Default::default()
        }
    }

    /// Balanced (general purpose)
    pub fn balanced() -> Self {
        Self {
            do_sample: true,
            temperature: 0.7,
            top_p: 0.9,
            top_k: Some(50),
            repetition_penalty: 1.1,
            ..Default::default()
        }
    }

    /// Code generation
    pub fn code() -> Self {
        Self {
            do_sample: true,
            temperature: 0.2,
            top_p: 0.95,
            max_new_tokens: 512,
            ..Default::default()
        }
    }
}

KV Cache

Key-value cache for efficient autoregressive generation.

pub struct KVCache {
    // Internal implementation
}

impl KVCache {
    /// Create new cache
    pub fn new(num_layers: usize, max_seq_len: usize, num_heads: usize, head_dim: usize) -> Self;

    /// Get cached key/value for a layer
    pub fn get(&self, layer: usize) -> Option<(&Tensor, &Tensor)>;

    /// Update cache with new key/value
    pub fn update(&mut self, layer: usize, key: Tensor, value: Tensor);

    /// Get current sequence length
    pub fn seq_len(&self) -> usize;

    /// Clear cache
    pub fn clear(&mut self);
}

Using KV Cache

let mut cache = KVCache::new(32, 2048, 32, 128);

// First token - no cache
let outputs = model.forward_with_cache(&inputs, &mut cache)?;

// Subsequent tokens - use cache
for _ in 0..max_tokens {
    let outputs = model.forward_with_cache(&next_input, &mut cache)?;
    // Cache automatically updated
}

Batch Inference

Running inference on multiple inputs.

// Create batched inputs
let batch_size = 4;
let prompts = vec![
    "Hello",
    "How are you",
    "What is AI",
    "Tell me a story",
];

// Tokenize all prompts
let batch_ids: Vec<Vec<u32>> = prompts.iter()
    .map(|p| tokenizer.encode(p))
    .collect::<Result<_>>()?;

// Pad to same length
let max_len = batch_ids.iter().map(|ids| ids.len()).max().unwrap();
let padded = pad_sequences(&batch_ids, max_len, tokenizer.special_tokens().pad_token_id);

// Create batched tensor
let input_tensor = Tensor::from_slice(&padded, &[batch_size, max_len])?;

// Run batched forward pass
let outputs = model.forward(&ModelInputs::text(input_tensor))?;

Error Handling

use anyhow::Result;

fn generate_with_fallback(pipeline: &InferencePipeline<impl Model>, prompt: &str) -> Result<String> {
    // Try with preferred settings
    let config = GenerationConfig::balanced();

    match pipeline.generate(prompt, &config) {
        Ok(response) => Ok(response),
        Err(e) => {
            eprintln!("Generation failed: {}", e);
            // Fallback to greedy
            pipeline.generate(prompt, &GenerationConfig::greedy())
        }
    }
}

Examples

Complete Generation Pipeline

use unillm::prelude::*;

fn main() -> Result<()> {
    // Load model
    let weights = WeightLoader::from_gguf("model.gguf")?;
    let config = LlamaConfig::from_gguf_metadata(weights.metadata())?;
    let model = LlamaModelV2::from_weights(config, weights)?;

    // Load tokenizer
    let tokenizer = Tokenizer::from_gguf("model.gguf")?;

    // Create pipeline
    let pipeline = InferencePipeline::new(model, tokenizer, Device::auto());

    // Generate
    let response = pipeline.generate(
        "Explain quantum computing in simple terms:",
        &GenerationConfig::balanced(),
    )?;

    println!("{}", response);
    Ok(())
}

Chat Interface

fn chat_loop(pipeline: &InferencePipeline<impl Model>) -> Result<()> {
    let mut history = String::new();

    loop {
        print!("User: ");
        let input = read_line()?;

        if input == "quit" {
            break;
        }

        history.push_str(&format!("User: {}\nAssistant: ", input));

        let response = pipeline.generate(
            &history,
            &GenerationConfig {
                max_new_tokens: 256,
                stop_sequences: vec!["User:".to_string()],
                ..GenerationConfig::balanced()
            },
        )?;

        println!("Assistant: {}", response);
        history.push_str(&format!("{}\n", response));
    }

    Ok(())
}