XTRScores

XTR contrastive scoring with global top-k token retrieval.

For each query token, the top-k matches are selected globally across the tokens of all Q*N in-batch documents (simulating retrieval from an index). Returns the full (Q, Q*N) cross-product score matrix in query-major order: scores[i, j*N + n] is query i scored against query j's n-th document. The positive for query i sits at column i*N.
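A minimal sketch of this scoring scheme in plain torch may help make the semantics concrete. This is an illustrative reimplementation, not the library's code: the exact reduction over query tokens (sum vs. mean) and the handling of ties at the top-k threshold are assumptions here.

```python
import torch


def xtr_scores_sketch(queries: torch.Tensor, documents: torch.Tensor, k: int = 2) -> torch.Tensor:
    """Hypothetical sketch of XTR global top-k scoring.

    queries:   (Q, Lq, D)     query token embeddings
    documents: (Q, N, Ld, D)  N documents per query, Ld tokens each
    Returns a (Q, Q*N) score matrix in query-major order.
    """
    Q, Lq, D = queries.shape
    _, N, Ld, _ = documents.shape
    docs = documents.reshape(Q * N, Ld, D)  # flatten to Q*N in-batch documents

    # Token-level similarities: (Q, Lq, Q*N, Ld)
    sim = torch.einsum("qld,mtd->qlmt", queries, docs)

    # Global top-k per query token across all Q*N*Ld document tokens:
    # keep similarities at or above the k-th largest value, zero the rest.
    flat = sim.reshape(Q, Lq, -1)
    threshold = flat.topk(k, dim=-1).values[..., -1:]
    retained = torch.where(flat >= threshold, flat, torch.zeros_like(flat))
    retained = retained.reshape(Q, Lq, Q * N, Ld)

    # Per document: max over its tokens for each query token,
    # then reduce over query tokens (a sum is assumed here).
    return retained.max(dim=-1).values.sum(dim=1)  # (Q, Q*N)
```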

Parameters

  • k ('int') – defaults to 256

    Number of top token matches to retain per query token across all Q*N documents.

  • document_chunk_size ('int | None') – defaults to None

    If set, the matmul + masked_fill phase is iterated over document_chunk_size docs at a time (out of Q*N total). The resulting chunks are concatenated before the global top-k, so scoring semantics are unchanged. Useful to trim the transient matmul peak at large effective batch sizes. Default None runs the full matmul in one shot.

Examples

>>> import torch

>>> queries_embeddings = torch.tensor([
...     [[1., 0.], [0., 0.]],
...     [[0., 1.], [0., 0.]],
... ])

>>> documents_embeddings = torch.tensor([
...     [[[1., 0.], [0., 1.]], [[0., 1.], [1., 0.]]],
...     [[[0., 1.], [1., 0.]], [[1., 0.], [0., 1.]]],
... ])

>>> scores = XTRScores(k=2)(
...     queries_embeddings=queries_embeddings,
...     documents_embeddings=documents_embeddings,
... )
>>> scores.shape
torch.Size([2, 4])

Methods

call

Score queries against the in-batch documents, returning the (Q, Q*N) score matrix.

Parameters

  • queries_embeddings ('list | np.ndarray | torch.Tensor')
  • documents_embeddings ('list | np.ndarray | torch.Tensor')
  • queries_mask ('torch.Tensor | None') – defaults to None
  • documents_mask ('torch.Tensor | None') – defaults to None
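The masks are presumably standard padding masks. A hedged sketch of the usual convention (assumed here, not confirmed by this page): masked document tokens are filled with -inf before the top-k so they can never be retrieved.

```python
import torch

# (Q, Lq, M, Ld) token similarities over M flattened documents
sim = torch.randn(2, 3, 4, 5)

# Boolean padding mask per document token; here the last token
# of every document is treated as padding.
documents_mask = torch.ones(4, 5, dtype=torch.bool)
documents_mask[:, -1] = False

# Padded document tokens get -inf similarity, excluding them
# from the global top-k selection.
sim = sim.masked_fill(~documents_mask[None, None], float("-inf"))
```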
compile

Notes

Adapted from PrimeQA (Copyright 2026 IBM PrimeQA Authors, licensed under the Apache License, Version 2.0). Changes from the original implementation:

  • Extricated the scoring function from the end-to-end modeling class that also handled contrastive loss computation.
  • Fixed a bug in the original implementation where the alignment mask was not being applied correctly.

References