NextPlaid-ONNX

ColBERT multi-vector encoding with ONNX Runtime.
A Python CLI for exporting HuggingFace models to ONNX, and a Rust crate for fast inference.

Rust Crate · Python Export · Models


Rust Crate

High-performance ColBERT inference with ONNX Runtime. Supports CPU, CUDA, TensorRT, CoreML, and DirectML.

Quick Start

Add to Cargo.toml:

[dependencies]
next-plaid-onnx = "0.2"

Encode documents and queries:

use next_plaid_onnx::Colbert;

let model = Colbert::new("models/GTE-ModernColBERT-v1")?;

// Documents → Vec<Array2<f32>> with shape [num_tokens, embedding_dim]
let doc_embeddings = model.encode_documents(
    &["Paris is the capital of France."],
    None,  // No pooling
)?;

// Queries → Vec<Array2<f32>> with shape [query_length, embedding_dim]
let query_embeddings = model.encode_queries(
    &["What is the capital of France?"],
)?;

Builder Pattern

use next_plaid_onnx::{Colbert, ExecutionProvider};

let model = Colbert::builder("models/GTE-ModernColBERT-v1")
    .with_quantized(true)                              // Use model_int8.onnx (~2x speedup)
    .with_parallel(25)                                 // 25 ONNX sessions
    .with_batch_size(2)                                // Documents per inference call
    .with_execution_provider(ExecutionProvider::Cuda)   // Force specific backend
    .with_query_length(32)                             // Override config max query tokens
    .with_document_length(512)                         // Override config max document tokens
    .build()?;

Hardware Acceleration

# NVIDIA CUDA
next-plaid-onnx = { version = "0.2", features = ["cuda"] }

# NVIDIA TensorRT (optimized CUDA)
next-plaid-onnx = { version = "0.2", features = ["tensorrt"] }

# Apple Silicon / CoreML
next-plaid-onnx = { version = "0.2", features = ["coreml"] }

# Windows DirectML (DirectX 12)
next-plaid-onnx = { version = "0.2", features = ["directml"] }

ExecutionProvider::Auto tries providers in order: CUDA → TensorRT → CoreML → DirectML → CPU. Set NEXT_PLAID_FORCE_CPU=1 to bypass all GPU providers.
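A minimal sketch of explicit provider selection, reusing the model directory from Quick Start (Auto is already the default and is shown here only for clarity):

use next_plaid_onnx::{Colbert, ExecutionProvider};

// Auto probes GPU providers first and falls back to CPU if none is available.
// Setting NEXT_PLAID_FORCE_CPU=1 in the environment skips the GPU probes entirely.
let model = Colbert::builder("models/GTE-ModernColBERT-v1")
    .with_execution_provider(ExecutionProvider::Auto)
    .build()?;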

Token Pooling

Reduce token count with hierarchical clustering (Ward's method):

// Keep ~50% of tokens
let pooled = model.encode_documents(&docs, Some(2))?;

// Keep ~33% of tokens
let pooled = model.encode_documents(&docs, Some(3))?;

The CLS token is always preserved, and each cluster is replaced by its centroid embedding.
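The effect on token counts can be checked directly; a small sketch reusing the model from Quick Start:

let docs = ["Paris is the capital of France."];
let full = model.encode_documents(&docs, None)?;
let halved = model.encode_documents(&docs, Some(2))?;

// Pooling with factor 2 keeps roughly half the rows; the embedding dim is unchanged.
println!("{} tokens -> {} tokens", full[0].nrows(), halved[0].nrows());
assert_eq!(full[0].ncols(), halved[0].ncols());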

Performance Tuning

Mode               | Config                                                                 | Best for
Single-session CPU | .with_threads(8).with_batch_size(32)                                   | Simple workloads
Parallel CPU       | .with_quantized(true).with_parallel(25).with_batch_size(2)             | High throughput
GPU                | .with_execution_provider(ExecutionProvider::Cuda).with_batch_size(64)  | Large batches
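For example, the single-session CPU row maps onto the builder like this (a sketch, using the same model directory as above):

let model = Colbert::builder("models/GTE-ModernColBERT-v1")
    .with_threads(8)       // single session, 8 threads
    .with_batch_size(32)   // 32 documents per inference call
    .build()?;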

API Reference

Colbert

impl Colbert {
    pub fn new<P: AsRef<Path>>(model_dir: P) -> Result<Self>;
    pub fn builder<P: AsRef<Path>>(model_dir: P) -> ColbertBuilder;

    pub fn encode_documents(
        &self, documents: &[&str], pool_factor: Option<usize>,
    ) -> Result<Vec<Array2<f32>>>;

    pub fn encode_queries(&self, queries: &[&str]) -> Result<Vec<Array2<f32>>>;

    pub fn config(&self) -> &ColbertConfig;
    pub fn embedding_dim(&self) -> usize;
    pub fn batch_size(&self) -> usize;
    pub fn num_sessions(&self) -> usize;
}

ColbertBuilder

impl ColbertBuilder {
    pub fn with_parallel(self, num_sessions: usize) -> Self;
    pub fn with_threads(self, num_threads: usize) -> Self;
    pub fn with_batch_size(self, batch_size: usize) -> Self;
    pub fn with_execution_provider(self, provider: ExecutionProvider) -> Self;
    pub fn with_quantized(self, quantized: bool) -> Self;
    pub fn with_query_length(self, query_length: usize) -> Self;
    pub fn with_document_length(self, document_length: usize) -> Self;
    pub fn build(self) -> Result<Colbert>;
}

ExecutionProvider

pub enum ExecutionProvider {
    Auto,      // Best available (default)
    Cpu,       // CPU only
    Cuda,      // NVIDIA CUDA (requires `cuda` feature)
    TensorRT,  // NVIDIA TensorRT (requires `tensorrt` feature)
    CoreML,    // Apple Silicon (requires `coreml` feature)
    DirectML,  // Windows DirectX 12 (requires `directml` feature)
}

ColbertConfig

Loaded automatically from onnx_config.json in the model directory.

pub struct ColbertConfig {
    pub query_prefix: String,           // "[Q] "
    pub document_prefix: String,        // "[D] "
    pub query_length: usize,            // 32
    pub document_length: usize,         // 180
    pub do_query_expansion: bool,       // true
    pub embedding_dim: usize,           // 128
    pub uses_token_type_ids: bool,      // false for ModernBERT
    pub mask_token_id: u32,
    pub pad_token_id: u32,
    pub skiplist_words: Vec<String>,    // Punctuation to filter from docs
    pub do_lower_case: bool,
}
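The loaded values can be inspected at runtime through Colbert::config(); a short sketch:

let cfg = model.config();
println!("embedding dim:       {}", cfg.embedding_dim);
println!("max query tokens:    {}", cfg.query_length);
println!("max document tokens: {}", cfg.document_length);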

Hierarchical Clustering Module

SciPy-compatible hierarchical clustering, used internally for token pooling:

use next_plaid_onnx::hierarchy::{pdist_cosine, linkage, fcluster, LinkageMethod, FclusterCriterion};

// `embeddings_flat`: row-major token embeddings, `n_samples` rows of `n_features` values
let distances = pdist_cosine(&embeddings_flat, n_samples, n_features);
// Condensed distance matrix -> agglomerative linkage (Ward's method)
let linkage_matrix = linkage(&distances, n_samples, LinkageMethod::Ward);
// Cut the dendrogram into at most `k` flat clusters
let labels = fcluster(&linkage_matrix, n_samples, FclusterCriterion::MaxClust, k as f64);

Supported methods: Ward, Single, Complete, Average, Weighted.


How Encoding Works

flowchart LR
    subgraph Document
        D1["[D] Paris is the capital of France."]
        D1 --> D2["Tokenize + truncate"]
        D2 --> D3["ONNX inference"]
        D3 --> D4["Filter skiplist + padding"]
        D4 --> D5["[num_tokens, 128]"]
    end

    subgraph Query
        Q1["[Q] What is the capital?"]
        Q1 --> Q2["Tokenize + MASK expansion"]
        Q2 --> Q3["ONNX inference"]
        Q3 --> Q5["[query_length, 128]"]
    end

    style D1 fill:#4a90d9,stroke:#357abd,color:#fff
    style D2 fill:#50b86c,stroke:#3d9956,color:#fff
    style D3 fill:#e8913a,stroke:#d07a2e,color:#fff
    style D4 fill:#50b86c,stroke:#3d9956,color:#fff
    style D5 fill:#9b59b6,stroke:#8445a0,color:#fff
    style Q1 fill:#4a90d9,stroke:#357abd,color:#fff
    style Q2 fill:#50b86c,stroke:#3d9956,color:#fff
    style Q3 fill:#e8913a,stroke:#d07a2e,color:#fff
    style Q5 fill:#9b59b6,stroke:#8445a0,color:#fff

Documents: Text is prefixed with [D], tokenized, run through the ONNX model, then skiplist tokens (punctuation) and padding are removed. Output shape varies per document.

Queries: Text is prefixed with [Q], tokenized, then padded with [MASK] tokens to query_length. MASK tokens enable query augmentation: the model learns useful expansions during training. Output shape is fixed.
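A quick way to see both behaviors in code (a sketch, reusing the model from Quick Start):

let docs = model.encode_documents(&["Paris is the capital of France."], None)?;
let queries = model.encode_queries(&["What is the capital of France?"])?;

// Document row counts vary; query matrices always have query_length rows.
println!("document tokens: {}", docs[0].nrows());
assert_eq!(queries[0].nrows(), model.config().query_length);
assert_eq!(docs[0].ncols(), model.embedding_dim());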


Python Export

A CLI tool for exporting HuggingFace ColBERT models to ONNX format. Creates both FP32 and INT8 quantized versions by default.

Install

pip install pylate-onnx-export

Requires Python 3.10-3.12.

Export a model

# Export (creates model.onnx + model_int8.onnx)
pylate-onnx-export lightonai/GTE-ModernColBERT-v1

# Export to specific directory
pylate-onnx-export lightonai/GTE-ModernColBERT-v1 -o ./my-models

# Export + push to HuggingFace Hub
pylate-onnx-export lightonai/GTE-ModernColBERT-v1 -o ./my-models --push-to-hub myorg/my-onnx-model

# FP32 only
pylate-onnx-export lightonai/GTE-ModernColBERT-v1 --no-quantize

# Quantize existing model
colbert-quantize ./models/GTE-ModernColBERT-v1

Python API

from colbert_export import export_model, quantize_model, push_to_hub

# Export
output_dir = export_model("lightonai/GTE-ModernColBERT-v1", quantize=True)

# Quantize existing
quantized_path = quantize_model("./models/GTE-ModernColBERT-v1")

# Push to Hub
push_to_hub("./models/GTE-ModernColBERT-v1", "myorg/my-onnx-model")

Output structure

models/<model-name>/
├── model.onnx           # FP32 ONNX model
├── model_int8.onnx      # INT8 quantized (~4x smaller, ~2x faster)
├── tokenizer.json       # HuggingFace fast tokenizer
└── onnx_config.json     # Model config for inference

Export pipeline

HuggingFace model
       ↓
PyLate ColBERT (adds [Q]/[D] tokens, extends embeddings)
       ↓
ColBERTForONNX (transformer + linear projection + L2 norm)
       ↓
torch.onnx.export (opset 14, dynamic batch/sequence axes)
       ↓
INT8 dynamic quantization (default)

CLI reference

pylate-onnx-export [OPTIONS] MODEL_NAME

Arguments:
  MODEL_NAME              HuggingFace model name or local path

Options:
  -o, --output-dir DIR    Output directory (default: ./models/<model-name>)
  --no-quantize           Skip INT8 quantization
  -f, --force             Force re-export even if the output already exists
  --push-to-hub REPO_ID   Push to HuggingFace Hub
  --private               Make Hub repository private
  --quiet                 Suppress progress messages

Models

Any PyLate-compatible ColBERT model works when exported to ONNX. These are ready to use:

Model                                    | Use case
lightonai/LateOn-Code-edge               | Code search, lightweight
lightonai/LateOn-Code                    | Code search, accurate
lightonai/answerai-colbert-small-v1-onnx | Text retrieval, lightweight
lightonai/GTE-ModernColBERT-v1           | Text retrieval, accurate

License

Apache-2.0