rerank
Rerank the documents for each query based on the query embeddings.
Parameters
- documents_ids (list[list[int | str]])
  The document ids, given as one list of candidate ids per query.
- queries_embeddings (list[list[float | int] | numpy.ndarray | torch.Tensor])
  The query embeddings, given as a list with one entry per query.
- documents_embeddings (list[list[float | int] | numpy.ndarray | torch.Tensor])
  The document embeddings, given as a list with one entry per query; each entry holds the embeddings of that query's candidate documents, in the same order as documents_ids.
- device (str) – defaults to None
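To make the shape of the inputs and the result concrete, here is a minimal pure-Python sketch of a reranker. It assumes ColBERT-style late-interaction (MaxSim) scoring, which this page does not spell out, and the names `maxsim_score` and `rerank_sketch` are hypothetical helpers for illustration, not part of the library. Real embeddings would be arrays or tensors rather than nested lists.

```python
def maxsim_score(query_tokens, document_tokens):
    """MaxSim: for each query token, take its best dot product with any
    document token, then sum those maxima over the query tokens."""
    total = 0.0
    for q in query_tokens:
        total += max(sum(a * b for a, b in zip(q, d)) for d in document_tokens)
    return total


def rerank_sketch(documents_ids, queries_embeddings, documents_embeddings):
    """Hypothetical stand-in for rerank, for illustration only."""
    results = []
    for ids, query_tokens, candidates in zip(
        documents_ids, queries_embeddings, documents_embeddings
    ):
        scored = [
            {"id": doc_id, "score": maxsim_score(query_tokens, doc_tokens)}
            for doc_id, doc_tokens in zip(ids, candidates)
        ]
        # Highest score first.
        scored.sort(key=lambda item: item["score"], reverse=True)
        results.append(scored)
    return results


# Toy data: one query with two token embeddings, two candidate documents.
query = [[1.0, 0.0], [0.0, 1.0]]
doc_a = [[1.0, 0.0], [0.0, 1.0]]  # matches both query tokens
doc_b = [[1.0, 0.0]]              # matches only the first query token

toy_ranking = rerank_sketch(
    documents_ids=[["b", "a"]],
    queries_embeddings=[query],
    documents_embeddings=[[doc_b, doc_a]],
)
```

In this toy case document "a" is ranked ahead of document "b", because it has a matching token for each of the two query tokens.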
Examples
>>> from pylate import models, rank
>>> model = models.ColBERT(
...     model_name_or_path="sentence-transformers/all-MiniLM-L6-v2", device="cpu"
... )
>>> queries = [
...     "query A",
...     "query B",
... ]
>>> documents = [
...     ["document A", "document B"],
...     ["document A", "document C", "document B"],
... ]
>>> documents_ids = [
...     [1, 2],
...     [1, 3, 2],
... ]
>>> queries_embeddings = model.encode(
...     queries,
...     is_query=True,
...     batch_size=1,
... )
>>> documents_embeddings = model.encode(
...     documents,
...     is_query=False,
...     batch_size=1,
... )
>>> reranked_documents = rank.rerank(
...     documents_ids=documents_ids,
...     queries_embeddings=queries_embeddings,
...     documents_embeddings=documents_embeddings,
... )
>>> assert isinstance(reranked_documents, list)
>>> assert len(reranked_documents) == 2
>>> assert len(reranked_documents[0]) == 2
>>> assert len(reranked_documents[1]) == 3
>>> assert isinstance(reranked_documents[0], list)
>>> assert isinstance(reranked_documents[0][0], dict)
>>> assert "id" in reranked_documents[0][0]
>>> assert "score" in reranked_documents[0][0]
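The assertions above show that the result is a list with one entry per query, each entry being a list of dicts with "id" and "score" keys. A short sketch of how such a result might be consumed follows. The scores are made-up placeholder values, and `top_k_ids` is a hypothetical helper, not a library function.

```python
# Hypothetical result with the same structure as reranked_documents;
# the scores are placeholders, not real model output.
example_result = [
    [{"id": 2, "score": 0.9}, {"id": 1, "score": 0.4}],
    [{"id": 3, "score": 0.8}, {"id": 1, "score": 0.7}, {"id": 2, "score": 0.1}],
]


def top_k_ids(reranked, k):
    """Return the ids of the k best-scoring documents for each query."""
    top = []
    for query_results in reranked:
        # Sort explicitly so the helper does not rely on the input order.
        ordered = sorted(query_results, key=lambda d: d["score"], reverse=True)
        top.append([d["id"] for d in ordered[:k]])
    return top


best_ids = top_k_ids(example_result, k=2)
```

Sorting inside the helper is a defensive choice: it keeps the helper correct even if the caller passes results that are not already ordered by score.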