
XTR

XTR retriever.

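Performs XTR (conteXtualized Token Retriever) scoring. Each document is scored only from the tokens retrieved for it in the initial token-level search. When none of a document's tokens were retrieved for a given query token, that query token's score is filled in by min imputation, i.e. the lowest similarity retrieved for that query token. This differs from the ColBERT retriever, which reranks candidates with a full MaxSim computation over cached document embeddings.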
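
The scoring idea above can be sketched in plain NumPy. This is a minimal illustration, not PyLate's implementation, and it makes several assumptions. Token embeddings are assumed to be unit-normalized. A brute-force token search stands in for the ScaNN index. `k_token` is assumed to mean "number of document tokens retrieved per query token", since its exact semantics are not documented here. The helpers `xtr_scores` and `maxsim_scores` are hypothetical names. Scores are summed over query tokens, as in ColBERT's MaxSim. Dividing by the number of query tokens, as the XTR paper does, would not change the ranking for a given query.

```python
import numpy as np


def xtr_scores(query: np.ndarray, docs: dict[str, np.ndarray], k_token: int) -> dict[str, float]:
    """Hypothetical sketch: score documents from retrieved tokens only, with min imputation."""
    # Flatten every document token into one matrix and remember which document owns each row.
    owners: list[str] = []
    for doc_id, emb in docs.items():
        owners.extend([doc_id] * len(emb))
    tokens = np.vstack(list(docs.values()))

    # Brute-force stand-in for the approximate token index: dot-product similarities.
    sims = query @ tokens.T  # shape: (n_query_tokens, n_doc_tokens)

    retrieved: list[dict[str, float]] = []
    imputed = np.empty(query.shape[0])
    for i, row in enumerate(sims):
        top = np.argsort(-row)[:k_token]  # the k_token most similar document tokens
        imputed[i] = row[top].min()  # min imputation value for query token i
        best: dict[str, float] = {}
        for t in top:
            doc_id = owners[t]
            best[doc_id] = max(best.get(doc_id, -np.inf), float(row[t]))
        retrieved.append(best)

    # Only documents owning at least one retrieved token become candidates.
    candidates = set().union(*retrieved)
    return {
        doc_id: float(sum(best.get(doc_id, imputed[i]) for i, best in enumerate(retrieved)))
        for doc_id in candidates
    }


def maxsim_scores(query: np.ndarray, docs: dict[str, np.ndarray]) -> dict[str, float]:
    """Full MaxSim over every document token, as a ColBERT-style reranker would compute it."""
    return {doc_id: float((query @ emb.T).max(axis=1).sum()) for doc_id, emb in docs.items()}


# Toy unit-normalized 2-d embeddings (illustrative values only).
query = np.array([[1.0, 0.0], [0.0, 1.0]])
docs = {
    "a": np.array([[1.0, 0.0], [0.6, 0.8]]),
    "b": np.array([[0.0, 1.0]]),
    "c": np.array([[0.8, 0.6]]),
}

xtr = xtr_scores(query, docs, k_token=2)
full = maxsim_scores(query, docs)
```

On this toy data no token of document `b` is retrieved for the first query token. That token's score is therefore imputed as 0.8, so `b` gets an XTR score of about 1.8 against a full MaxSim score of 1.0. An unretrieved token's true similarity cannot exceed the lowest retrieved score, so in this sketch the imputed value acts as an upper bound and XTR scores never fall below the corresponding MaxSim scores.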

Parameters

  • index ('BaseIndex')

    The index to use for retrieval.

Examples

>>> from pylate import indexes, models, retrieve

>>> model = models.ColBERT(
...     model_name_or_path="sentence-transformers/all-MiniLM-L6-v2",
...     device="cpu",
... )

>>> documents_ids = [f"document_id_{i}" for i in range(20)]
>>> documents = [f"This is the content of document {i}." for i in range(20)]

>>> documents_embeddings = model.encode(
...     sentences=documents,
...     batch_size=1,
...     is_query=False,
... )

>>> index = indexes.ScaNN(
...     override=True,
...     index_name="xtr_scann",
...     store_embeddings=False,
... )

>>> index = index.add_documents(
...     documents_ids=documents_ids,
...     documents_embeddings=documents_embeddings,
...     batch_size=1,
... )

>>> retriever = retrieve.XTR(index=index)

>>> queries_embeddings = model.encode(
...     ["fruits are healthy.", "fruits are good for health."],
...     batch_size=1,
...     is_query=True,
... )

>>> results = retriever.retrieve(
...     queries_embeddings=queries_embeddings,
...     k=2,
...     device="cpu",
... )

>>> assert isinstance(results, list)
>>> assert len(results) == 2

>>> queries_embeddings = model.encode(
...     "fruits are healthy.",
...     batch_size=1,
...     is_query=True,
... )

>>> results = retriever.retrieve(
...     queries_embeddings=queries_embeddings,
...     k=2,
...     device="cpu",
... )

>>> assert isinstance(results, list)
>>> assert len(results) == 1

Methods

retrieve

Retrieve documents for a list of queries.

Parameters

  • queries_embeddings ('list[list | np.ndarray | torch.Tensor]')
  • k ('int') – defaults to 10
  • k_token ('int | None') – defaults to None
  • device ('str | None') – defaults to None
  • batch_size ('int | None') – defaults to None
  • subset ('list[list[str]] | list[str] | None') – defaults to None