ColBERTCollator
Collator for ColBERT model.
Parameters
- tokenize_fn ('Callable')
  The function used to tokenize the input text.
- valid_label_columns ('list[str] | None') – defaults to None
  The names of the columns that contain the labels (scores or labels).
- router_mapping ('dict[str, str] | dict[str, dict[str, str]] | None') – defaults to None
  The mapping of the columns to the router.
- prompts ('dict[str, str] | dict[str, dict[str, str]] | None') – defaults to None
  The prompts to use for the columns.
- include_prompt_lengths ('bool') – defaults to False
  Whether to include the prompt lengths in the batch.
- all_special_ids ('set[int] | None') – defaults to None
  The special ids to use for the tokenization.
- _prompt_length_mapping ('dict[str, int] | None') – defaults to None
- _warned_columns ('set[tuple[str]] | None') – defaults to None
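The prompts parameter associates prompts with columns. As a rough illustration only, the sketch below assumes the flat dict[str, str] form maps a column name to a string that is prepended to each text in that column before tokenization. The helper apply_prompts and the prompt strings are hypothetical and not part of PyLate; the nested dict[str, dict[str, str]] form is not covered.

```python
def apply_prompts(features: list[dict], prompts: dict[str, str]) -> list[dict]:
    """Prepend the prompt registered for a column to every string value in it.

    Assumption: a prompt is a plain prefix; non-string values (for example
    label lists) and columns without a prompt are passed through unchanged.
    """
    prompted = []
    for row in features:
        new_row = {}
        for column, value in row.items():
            if column in prompts and isinstance(value, str):
                new_row[column] = prompts[column] + value
            else:
                new_row[column] = value
        prompted.append(new_row)
    return prompted


rows = [
    {
        "query": "fruits are healthy.",
        "positive": "fruits are good for health.",
        "label": [0.7, 0.3],
    }
]
# Hypothetical prompt strings, chosen only for the illustration.
prompts = {"query": "query: ", "positive": "document: "}
prompted_rows = apply_prompts(rows, prompts)
print(prompted_rows[0]["query"])
```

The input rows are left untouched, so the same features can be reused with a different prompt mapping.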
Examples
>>> import torch
>>> from pylate import models, utils
>>> model = models.ColBERT(
...     model_name_or_path="sentence-transformers/all-MiniLM-L6-v2", device="cpu"
... )
>>> collator = utils.ColBERTCollator(
...     tokenize_fn=model.tokenize,
... )
>>> features = [
...     {
...         "query": "fruits are healthy.",
...         "positive": "fruits are good for health.",
...         "negative": "fruits are bad for health.",
...         "label": [0.7, 0.3]
...     }
... ]
>>> features = collator(features=features)
>>> fields = [
...     "query_input_ids",
...     "positive_input_ids",
...     "negative_input_ids",
...     "query_attention_mask",
...     "positive_attention_mask",
...     "negative_attention_mask",
...     "query_token_type_ids",
...     "positive_token_type_ids",
...     "negative_token_type_ids",
... ]
>>> for field in fields:
...     assert field in features
...     assert isinstance(features[field], torch.Tensor)
...     assert features[field].ndim == 2
Methods
__call__
Collate a list of features into a batch.
Parameters
- features ('list[dict]')
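To make the collation step concrete, here is a minimal sketch of how a list of feature dicts could be turned into a batch whose keys follow the `{column}_{key}` pattern seen in the example (query_input_ids, positive_attention_mask, and so on). It is not the PyLate implementation: toy_collate and whitespace_tokenize are hypothetical stand-ins, plain lists replace tensors, and the ("label", "score") default for label columns is an assumption.

```python
from typing import Callable


def whitespace_tokenize(texts: list[str]) -> dict[str, list[list[int]]]:
    """Stand-in tokenizer: one fake id per whitespace token, right-padded with 0."""
    ids = [[len(word) for word in text.split()] for text in texts]
    width = max(len(row) for row in ids)
    return {
        "input_ids": [row + [0] * (width - len(row)) for row in ids],
        "attention_mask": [[1] * len(row) + [0] * (width - len(row)) for row in ids],
    }


def toy_collate(
    features: list[dict],
    tokenize_fn: Callable[[list[str]], dict],
    valid_label_columns: tuple[str, ...] = ("label", "score"),  # assumed default
) -> dict:
    """Collect each column across the rows, then tokenize text columns."""
    batch = {}
    for column in features[0]:
        values = [row[column] for row in features]
        if column in valid_label_columns:
            # Label columns are copied into the batch without tokenization.
            batch[column] = values
            continue
        # Text columns: prefix every tokenizer output with the column name.
        for key, value in tokenize_fn(values).items():
            batch[f"{column}_{key}"] = value
    return batch


features = [
    {"query": "fruits are healthy.", "positive": "fruits are good for health.", "label": [0.7, 0.3]},
    {"query": "vegetables", "positive": "vegetables are good too.", "label": [0.6, 0.4]},
]
batch = toy_collate(features, whitespace_tokenize)
print(sorted(batch))
```

Each text column is padded independently, so query and positive fields can have different sequence lengths within the same batch.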
maybe_warn_about_column_order
Warn the user if the columns are likely not in the expected order.
Parameters
- column_names ('list[str]')
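One plausible shape for such a check is sketched below. It assumes that well-known column names have an expected position and that each column ordering is warned about only once, which the set[tuple[str]] type of _warned_columns hints at. The EXPECTED_POSITION table and the function warn_about_column_order are hypothetical; the actual heuristic in PyLate may differ.

```python
import warnings

# Hypothetical expectation: query first, then positive, then negative.
EXPECTED_POSITION = {"query": 0, "positive": 1, "negative": 2}


def warn_about_column_order(column_names: list[str], warned: set) -> bool:
    """Warn once per column ordering; return True if a warning was emitted."""
    key = tuple(column_names)
    if key in warned:
        return False  # this ordering has already been reported
    for name, expected in EXPECTED_POSITION.items():
        if name in column_names and column_names.index(name) != expected:
            warnings.warn(
                f"Column {name!r} is at position {column_names.index(name)}, "
                f"expected {expected}; columns may be consumed by position.",
                UserWarning,
                stacklevel=2,
            )
            warned.add(key)
            return True
    return False


warned = set()
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    first = warn_about_column_order(["positive", "query", "negative"], warned)
    second = warn_about_column_order(["positive", "query", "negative"], warned)
    ordered = warn_about_column_order(["query", "positive", "negative"], warned)
print(first, second, ordered, len(caught))
```

Remembering the offending ordering keeps a long training loop from repeating the same warning on every batch.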