XTRScores¶
XTR contrastive scoring with global top-k token retrieval.
For each query token, the top-k matches are selected globally across all Q*N in-batch document tokens (simulating retrieval from an index). Returns the full (Q, Q*N) cross-product score matrix with query-major ordering: scores[i, j*N + k] is query i against query j's k-th document. The positive for query i sits at column i*N.
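The global top-k semantics described above can be sketched in plain PyTorch. This is an illustrative reimplementation, not the library's code: the threshold-based top-k masking and the MaxSim-then-mean reduction are assumptions about how XTR-style scoring is typically realized.

```python
import torch

def xtr_global_topk_scores(queries, documents, k):
    # queries:   (Q, T_q, D)     query token embeddings
    # documents: (Q, N, T_d, D)  N documents per query
    Q, T_q, D = queries.shape
    _, N, T_d, _ = documents.shape
    # Flatten the Q*N in-batch documents into one pool, query-major order:
    # pool index j*N + k corresponds to query j's k-th document.
    pool = documents.reshape(Q * N, T_d, D)
    # Token-level similarities of every query token against every pooled
    # document token: (Q, T_q, Q*N, T_d).
    sims = torch.einsum("qtd,msd->qtms", queries, pool)
    # Global top-k per query token across all Q*N*T_d document tokens
    # (simulating retrieval from an index); ties at the threshold survive.
    flat = sims.reshape(Q, T_q, -1)
    kth = flat.topk(k, dim=-1).values[..., -1:]  # k-th largest similarity
    flat = torch.where(flat >= kth, flat, torch.zeros_like(flat))
    sims = flat.reshape(Q, T_q, Q * N, T_d)
    # MaxSim over each document's tokens, mean over query tokens: (Q, Q*N).
    return sims.max(dim=-1).values.mean(dim=1)
```

On the tensors from the Examples section below this returns a (2, 4) matrix, matching the documented output shape.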
Parameters¶
- k ('int') – defaults to 256. Number of top token matches to retain per query token across all Q*N documents.
- document_chunk_size ('int | None') – defaults to None. If set, the matmul + masked_fill phase is iterated over document_chunk_size docs at a time (out of Q*N total). The resulting chunks are concatenated before the global top-k, so scoring semantics are unchanged. Useful to trim the transient matmul peak at large effective batch sizes. Default None runs the full matmul in one shot.
Examples¶
>>> import torch
>>> queries_embeddings = torch.tensor([
... [[1., 0.], [0., 0.]],
... [[0., 1.], [0., 0.]],
... ])
>>> documents_embeddings = torch.tensor([
... [[[1., 0.], [0., 1.]], [[0., 1.], [1., 0.]]],
... [[[0., 1.], [1., 0.]], [[1., 0.], [0., 1.]]],
... ])
>>> scores = XTRScores(k=2)(
... queries_embeddings=queries_embeddings,
... documents_embeddings=documents_embeddings,
... )
>>> scores.shape
torch.Size([2, 4])
Methods¶
__call__

Call self as a function.

Parameters

- queries_embeddings ('list | np.ndarray | torch.Tensor')
- documents_embeddings ('list | np.ndarray | torch.Tensor')
- queries_mask ('torch.Tensor | None') – defaults to None
- documents_mask ('torch.Tensor | None') – defaults to None
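The padding masks accepted here would typically enter the computation as a masked_fill on the token-similarity tensor, so padded document tokens can never be retrieved into the top-k. A minimal sketch, assuming a boolean mask of shape (M, T_d) with True marking real tokens (both the shape and the polarity are assumptions, not the library's documented contract):

```python
import torch

def mask_padded_doc_tokens(sims, documents_mask):
    # sims: (Q, T_q, M, T_d) token similarities.
    # documents_mask: (M, T_d), True where the document token is real,
    # False where it is padding.
    # Fill padding positions with -inf so they lose every top-k comparison.
    return sims.masked_fill(~documents_mask[None, None, :, :], float("-inf"))
```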
compile
Notes¶
Adapted from PrimeQA (Copyright 2026 IBM PrimeQA Authors, licensed under the Apache License, Version 2.0). Changes from the original implementation:
- Extricated the scoring function from the end-to-end modeling class that also handled contrastive loss computation.
- Fixed a bug in the original implementation where the alignment mask was not being applied correctly.