XTR
XTR retriever.
Performs XTR (ConteXtualized Token Retriever) scoring: each document is scored using only the tokens returned by the initial token retrieval, and the score for any query token that has no retrieved match in a document is filled in by min imputation. This differs from the ColBERT retriever, which performs a full MaxSim rerank using cached document embeddings.
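The scoring described above can be illustrated with a small NumPy sketch. This is not PyLate's implementation: the function name `xtr_scores`, its arguments, and the brute-force top-k search (standing in for the index lookup) are all hypothetical. It also assumes that the imputed value is the lowest similarity retrieved for that query token and that the final score is the mean over query tokens.

```python
import numpy as np


def xtr_scores(query_emb, doc_token_emb, doc_token_owner, k_token):
    """Score documents from retrieved tokens only, with min imputation.

    Hypothetical illustration, not the library's code.

    query_emb:       (n_query_tokens, dim) array
    doc_token_emb:   (n_doc_tokens, dim) array, all document tokens stacked
    doc_token_owner: (n_doc_tokens,) int array, document index owning each token
    k_token:         number of document tokens retrieved per query token
    """
    sims = query_emb @ doc_token_emb.T  # (n_query_tokens, n_doc_tokens)
    n_query_tokens = sims.shape[0]

    # Token retrieval: brute-force top-k_token per query token
    # (assumption: a real index would do this approximately).
    top_idx = np.argsort(-sims, axis=1)[:, :k_token]
    top_sims = np.take_along_axis(sims, top_idx, axis=1)

    # Assumed imputation value: the lowest retrieved similarity per query token.
    impute = top_sims.min(axis=1)

    # Only documents owning at least one retrieved token are candidates.
    candidates = np.unique(doc_token_owner[top_idx])
    scores = {}
    for doc in candidates:
        per_token = impute.copy()
        for i in range(n_query_tokens):
            # Similarities of this document's tokens retrieved for query token i.
            hits = top_sims[i][doc_token_owner[top_idx[i]] == doc]
            if hits.size:
                per_token[i] = hits.max()
        scores[int(doc)] = float(per_token.mean())
    return scores
```

Unlike a full MaxSim rerank, nothing here touches a document's embeddings beyond the tokens that the retrieval step already returned, which is why the index in the example below can be built with `store_embeddings=False`.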
Parameters
- index ('BaseIndex')
  The index to use for retrieval.
Examples
>>> from pylate import indexes, models, retrieve

>>> model = models.ColBERT(
...     model_name_or_path="sentence-transformers/all-MiniLM-L6-v2",
...     device="cpu",
... )

>>> documents_ids = [f"document_id_{i}" for i in range(20)]
>>> documents = [f"This is the content of document {i}." for i in range(20)]

>>> documents_embeddings = model.encode(
...     sentences=documents,
...     batch_size=1,
...     is_query=False,
... )

>>> index = indexes.ScaNN(
...     override=True,
...     index_name="xtr_scann",
...     store_embeddings=False,
... )

>>> index = index.add_documents(
...     documents_ids=documents_ids,
...     documents_embeddings=documents_embeddings,
...     batch_size=1,
... )

>>> retriever = retrieve.XTR(index=index)

>>> queries_embeddings = model.encode(
...     ["fruits are healthy.", "fruits are good for health."],
...     batch_size=1,
...     is_query=True,
... )

>>> results = retriever.retrieve(
...     queries_embeddings=queries_embeddings,
...     k=2,
...     device="cpu",
... )

>>> assert isinstance(results, list)
>>> assert len(results) == 2

>>> queries_embeddings = model.encode(
...     "fruits are healthy.",
...     batch_size=1,
...     is_query=True,
... )

>>> results = retriever.retrieve(
...     queries_embeddings=queries_embeddings,
...     k=2,
...     device="cpu",
... )

>>> assert isinstance(results, list)
>>> assert len(results) == 1
Methods
retrieve
Retrieve documents for a list of queries.
Parameters
- queries_embeddings ('list[list | np.ndarray | torch.Tensor]')
- k ('int') – defaults to 10
- k_token ('int | None') – defaults to None
- device ('str | None') – defaults to None
- batch_size ('int | None') – defaults to None
- subset ('list[list[str]] | list[str] | None') – defaults to None
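The type of `subset` admits either one flat list of document ids or one list per query; how the library interprets each form is not stated here. One plausible reading, sketched below as a hypothetical helper (`normalize_subset` is not part of PyLate), is that a flat list applies to every query while a nested list supplies one filter per query.

```python
def normalize_subset(subset, n_queries):
    """Hypothetical helper: expand `subset` to one entry per query.

    Assumed semantics (not confirmed by the documentation above):
    - None        -> no filtering for any query
    - list[str]   -> the same allowed ids for every query
    - list[list]  -> one list of allowed ids per query
    """
    if subset is None:
        return [None] * n_queries
    if all(isinstance(item, str) for item in subset):
        # Flat list: copy it for each query so the entries are independent.
        return [list(subset) for _ in range(n_queries)]
    if len(subset) != n_queries:
        raise ValueError("expected one subset per query")
    return [list(ids) for ids in subset]
```

Normalizing up front like this would let the rest of a retrieval loop treat every query the same way, whichever form the caller passed.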