ColBERT¶
ColBERT retriever.
Parameters¶
- index ('Voyager | PLAID')
Examples¶
>>> from pylate import indexes, models, retrieve
>>> model = models.ColBERT(
... model_name_or_path="sentence-transformers/all-MiniLM-L6-v2",
... device="cpu",
... )
>>> documents_ids = [f"document_id_{i}" for i in range(20)]
>>> documents = [f"This is the content of document {i}." for i in range(20)]
>>> documents_embeddings = model.encode(
... sentences=documents,
... batch_size=1,
... is_query=False,
... )
>>> index = indexes.PLAID(
... index_folder="test_indexes",
... index_name="colbert",
... override=True,
... )
>>> index = index.add_documents(
... documents_ids=documents_ids,
... documents_embeddings=documents_embeddings,
... )
Computing centroids of embeddings.
Creating FastPlaid index.
>>> retriever = retrieve.ColBERT(index=index)
>>> queries_embeddings = model.encode(
... ["fruits are healthy.", "fruits are good for health."],
... batch_size=1,
... is_query=True,
... )
>>> results = retriever.retrieve(
... queries_embeddings=queries_embeddings,
... k=2,
... device="cpu",
... )
>>> assert isinstance(results, list)
>>> assert len(results) == 2
>>> queries_embeddings = model.encode(
... "fruits are healthy.",
... batch_size=1,
... is_query=True,
... )
>>> results = retriever.retrieve(
... queries_embeddings=queries_embeddings,
... k=2,
... device="cpu",
... )
>>> assert isinstance(results, list)
>>> assert len(results) == 1
>>> results = retriever.retrieve(
... queries_embeddings=queries_embeddings,
... k=2,
... device="cpu",
... subset=["document_id_10"],
... )
Methods¶
retrieve
Retrieve documents for a list of queries.
Parameters
- queries_embeddings ('list[list | np.ndarray | torch.Tensor]')
- k ('int') – defaults to
10 - k_token ('int') – defaults to
100 - device ('str | None') – defaults to
None - batch_size ('int') – defaults to
50 - subset ('list[list[str]] | list[str] | None') – defaults to
None