
ColBERTCollator

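Collator for ColBERT models. It tokenizes each text column of the input features and returns the results as tensors named after the column and the tokenizer field (for example, query_input_ids).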

Parameters

  • tokenize_fn ('Callable')

    The function to tokenize the input text.

  • valid_label_columns ('list[str] | None') – defaults to None

    The names of the columns that contain the labels (scores or labels).

  • router_mapping ('dict[str, str] | dict[str, dict[str, str]] | None') – defaults to None

    The mapping from column names to router routes.

  • prompts ('dict[str, str] | dict[str, dict[str, str]] | None') – defaults to None

    The prompts to use for the columns.

  • include_prompt_lengths ('bool') – defaults to False

    Whether to include the prompt lengths in the batch.

  • all_special_ids ('set[int] | None') – defaults to None

    The special token IDs to use during tokenization.

  • _prompt_length_mapping ('dict[str, int] | None') – defaults to None

  • _warned_columns ('set[tuple[str]] | None') – defaults to None
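Both router_mapping and prompts accept either a flat mapping or a nested one. The sketch below shows one way such a lookup could work, assuming the flat form maps column names to values, the nested form is keyed by dataset name first and column name second (as in Sentence Transformers multi-dataset training), and a prompt is a plain prefix added to the raw text. resolve_prompt and apply_prompt are hypothetical helpers, not part of PyLate; the same lookup shape would apply to router_mapping.

```python
from typing import Optional, Union

# Flat form {column: prompt} or nested form {dataset_name: {column: prompt}}.
PromptConfig = Union[dict[str, str], dict[str, dict[str, str]]]


def resolve_prompt(
    prompts: Optional[PromptConfig],
    column: str,
    dataset_name: Optional[str] = None,
) -> Optional[str]:
    """Hypothetical helper: return the prompt configured for `column`, or None."""
    if not prompts:
        return None
    # Nested form (assumption): look up the dataset first, then the column.
    if dataset_name is not None and isinstance(prompts.get(dataset_name), dict):
        return prompts[dataset_name].get(column)
    # Flat form: the column name maps directly to a prompt string.
    value = prompts.get(column)
    return value if isinstance(value, str) else None


def apply_prompt(texts: list[str], prompt: Optional[str]) -> list[str]:
    """Hypothetical helper: prepend the prompt to every text (assumes plain prefixes)."""
    return texts if prompt is None else [prompt + text for text in texts]
```

For example, resolve_prompt({"query": "query: "}, "query") returns "query: ", while a nested configuration such as {"msmarco": {"query": "search_query: "}} only yields a prompt when dataset_name="msmarco" is passed.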

Examples

>>> import torch

>>> from pylate import models, utils

>>> model = models.ColBERT(
...     model_name_or_path="sentence-transformers/all-MiniLM-L6-v2", device="cpu"
... )

>>> collator = utils.ColBERTCollator(
...     tokenize_fn=model.tokenize,
... )

>>> features = [
...     {
...         "query": "fruits are healthy.",
...         "positive": "fruits are good for health.",
...         "negative": "fruits are bad for health.",
...         "label": [0.7, 0.3]
...     }
... ]

>>> features = collator(features=features)

>>> fields = [
...     "query_input_ids",
...     "positive_input_ids",
...     "negative_input_ids",
...     "query_attention_mask",
...     "positive_attention_mask",
...     "negative_attention_mask",
...     "query_token_type_ids",
...     "positive_token_type_ids",
...     "negative_token_type_ids",
... ]

>>> for field in fields:
...     assert field in features
...     assert isinstance(features[field], torch.Tensor)
...     assert features[field].ndim == 2

Methods

__call__

Collate a list of features into a batch.

Parameters

  • features ('list[dict]')
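
A simplified, pure-Python sketch of the steps __call__ is assumed to perform: remove the first matching label column so it is not tokenized as text, then tokenize every remaining column and store each tokenizer output under "<column>_<key>", which matches the field names checked in the example above. collate_sketch and toy_tokenize are hypothetical; the real collator also handles prompts and router mappings and returns torch tensors rather than lists.

```python
from typing import Callable, Sequence


def collate_sketch(
    features: list[dict],
    tokenize_fn: Callable[[list[str]], dict[str, list]],
    label_columns: Sequence[str] = ("label",),  # assumption: illustrative default
) -> dict[str, list]:
    """Hypothetical, simplified collation: no prompts, no router, no tensors."""
    batch: dict[str, list] = {}
    columns = list(features[0].keys())
    # Pull out the first matching label column so it is not tokenized as text.
    for label_column in label_columns:
        if label_column in columns:
            batch["label"] = [row[label_column] for row in features]
            columns.remove(label_column)
            break
    # Tokenize each remaining column and prefix every output key with the column name.
    for column in columns:
        tokenized = tokenize_fn([row[column] for row in features])
        for key, value in tokenized.items():
            batch[f"{column}_{key}"] = value
    return batch


def toy_tokenize(texts: list[str]) -> dict[str, list]:
    """Whitespace 'tokenizer' with a throwaway per-call vocabulary, padded to the longest text."""
    vocab: dict[str, int] = {}
    ids = [[vocab.setdefault(word, len(vocab) + 1) for word in text.split()] for text in texts]
    width = max((len(seq) for seq in ids), default=0)
    input_ids = [seq + [0] * (width - len(seq)) for seq in ids]
    attention_mask = [[1] * len(seq) + [0] * (width - len(seq)) for seq in ids]
    return {"input_ids": input_ids, "attention_mask": attention_mask}
```

Applied to the features from the example above, this yields a label entry plus query_, positive_ and negative_ input_ids and attention_mask entries; the token_type_ids fields in the example come from the real tokenizer, which the toy one does not emulate.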

maybe_warn_about_column_order

Warn the user if the columns are likely not in the expected order.

Parameters

  • column_names ('list[str]')
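One way such a check could work is sketched below, assuming it compares well-known column names against their usual positions (query first, then positive, then negative, as in the example above) and warns only once per distinct column layout, which is what the set[tuple[str]] type of _warned_columns suggests. The position table and the ColumnOrderChecker class are illustrative, not PyLate's implementation.

```python
import logging
from typing import Optional

logger = logging.getLogger(__name__)

# Illustrative table of usual positions; the library's actual table may differ.
EXPECTED_POSITION = {"query": 0, "positive": 1, "negative": 2}


class ColumnOrderChecker:
    """Hypothetical stand-in for the collator's column-order warning logic."""

    def __init__(self) -> None:
        # Column layouts that have already produced a warning.
        self._warned_columns: set[tuple[str, ...]] = set()

    def maybe_warn_about_column_order(self, column_names: list[str]) -> Optional[str]:
        """Log and return a warning if a known column is out of place, else return None."""
        layout = tuple(column_names)
        if layout in self._warned_columns:
            return None  # warn at most once per distinct layout
        for name, expected in EXPECTED_POSITION.items():
            if name in column_names and column_names.index(name) != expected:
                message = (
                    f"Column {name!r} is at position {column_names.index(name)} "
                    f"but usually comes at position {expected}; "
                    "check that the dataset columns are in the expected order."
                )
                logger.warning(message)
                self._warned_columns.add(layout)
                return message
        return None
```

Calling the check twice with ["positive", "query"] logs a warning about 'query' the first time and stays silent the second time, because that layout is already recorded.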