
KDProcessing

Dataset processing class for knowledge distillation training.

Parameters

  • queries (datasets.arrow_dataset.Dataset | datasets.dataset_dict.DatasetDict)

    Queries dataset, used to look up the text of each query.

  • documents (datasets.arrow_dataset.Dataset | datasets.dataset_dict.DatasetDict)

    Documents dataset, used to look up the text of each document.

  • split (str) – defaults to train

    Split to select from the queries and documents datasets. Ignored unless queries and documents are datasets.DatasetDict instances.

  • n_ways (int) – defaults to 32

    Number of teacher scores, and their corresponding documents, to keep per query for distillation.
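How split and n_ways act on the inputs can be sketched in plain Python. This is a minimal sketch, not PyLate's implementation: the helpers resolve_split and keep_n_ways are hypothetical, a plain dict stands in for datasets.DatasetDict, and it assumes the first n_ways entries are the ones kept, with document IDs truncated together with their scores.

```python
from collections.abc import Mapping
from typing import Any


def resolve_split(data: Any, split: str = "train") -> Any:
    """Hypothetical helper: select `split` from a DatasetDict-like mapping.

    Anything that is not a mapping (standing in for a plain Dataset)
    is returned unchanged, so `split` is ignored in that case.
    """
    if isinstance(data, Mapping):
        return data[split]
    return data


def keep_n_ways(
    document_ids: list[str], scores: list[float], n_ways: int = 32
) -> tuple[list[str], list[float]]:
    """Hypothetical helper: keep the first `n_ways` documents and their scores."""
    # Truncate both lists together so each score stays aligned with its document.
    return document_ids[:n_ways], scores[:n_ways]


# Toy data: a dict plays the role of a DatasetDict.
splits = {"train": ["q1", "q2"], "test": ["q3"]}
print(resolve_split(splits, split="test"))  # ['q3']
print(keep_n_ways(["d1", "d2", "d3"], [0.9, 0.4, 0.1], n_ways=2))  # (['d1', 'd2'], [0.9, 0.4])
```

Truncating the IDs and scores in one call is a design choice of the sketch: it makes it hard to keep a score whose document was dropped.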

Examples

>>> from datasets import load_dataset
>>> from pylate import utils

>>> train = load_dataset(
...    path="lightonai/lighton-ms-marco-mini",
...    name="train",
...    split="train",
... )

>>> queries = load_dataset(
...    path="lightonai/lighton-ms-marco-mini",
...    name="queries",
...    split="train",
... )

>>> documents = load_dataset(
...    path="lightonai/lighton-ms-marco-mini",
...    name="documents",
...    split="train",
... )

>>> train.set_transform(
...    utils.KDProcessing(
...        queries=queries, documents=documents
...    ).transform,
... )

>>> for sample in train:
...     assert "documents" in sample and isinstance(sample["documents"], list)
...     assert "query" in sample and isinstance(sample["query"], str)
...     assert "scores" in sample and isinstance(sample["scores"], list)

Methods

map

Process a single example.

Parameters

  • example (dict)
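A single-example version of this processing might look like the following. It is a hypothetical sketch, not the library's code: it assumes input columns named query_id, document_ids and scores, and uses plain dicts (query_texts, document_texts) as ID-to-text lookups in place of the queries and documents datasets.

```python
from typing import Any


def map_example(
    example: dict[str, Any],
    query_texts: dict[str, str],
    document_texts: dict[str, str],
    n_ways: int = 32,
) -> dict[str, Any]:
    """Hypothetical single-example version of KDProcessing.map.

    Assumes `example` holds a query ID, a list of document IDs and a list
    of teacher scores aligned with those IDs.
    """
    # Keep only the first n_ways documents and their matching scores.
    document_ids = example["document_ids"][:n_ways]
    scores = example["scores"][:n_ways]
    return {
        **example,
        "document_ids": document_ids,
        "scores": scores,
        # Look up the query text from its ID.
        "query": query_texts[example["query_id"]],
        # Look up each document's text; this sketch uses an empty string
        # for IDs missing from the lookup.
        "documents": [document_texts.get(doc_id, "") for doc_id in document_ids],
    }


example = {"query_id": "q1", "document_ids": ["d1", "d2", "d3"], "scores": [0.9, 0.5, 0.1]}
out = map_example(example, {"q1": "what is colbert"}, {"d1": "text one", "d2": "text two"}, n_ways=2)
print(out["query"], out["documents"], out["scores"])
# what is colbert ['text one', 'text two'] [0.9, 0.5]
```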

transform

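Update a batch of examples with the query and document texts. This is the function passed to Dataset.set_transform in the example above, which calls it with a dict mapping each column name to a list of values.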

Parameters

  • examples (dict)
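A minimal batched sketch, under stated assumptions: the function name transform_batch is hypothetical, the input columns are assumed to be query_id, document_ids and scores, and plain dicts stand in for the queries and documents datasets as ID-to-text lookups. Each output column holds one entry per row of the batch.

```python
from typing import Any


def transform_batch(
    examples: dict[str, list[Any]],
    query_texts: dict[str, str],
    document_texts: dict[str, str],
    n_ways: int = 32,
) -> dict[str, list[Any]]:
    """Hypothetical batched version of KDProcessing.transform.

    `examples` is columnar: each key maps to a list with one value per row.
    """
    # Truncate every row to n_ways documents and scores.
    document_ids = [ids[:n_ways] for ids in examples["document_ids"]]
    scores = [row[:n_ways] for row in examples["scores"]]
    return {
        **examples,
        "document_ids": document_ids,
        "scores": scores,
        # One query text per row.
        "query": [query_texts[qid] for qid in examples["query_id"]],
        # One list of document texts per row; unknown IDs become "".
        "documents": [[document_texts.get(d, "") for d in ids] for ids in document_ids],
    }


batch = {
    "query_id": ["q1", "q2"],
    "document_ids": [["d1", "d2"], ["d2", "d3"]],
    "scores": [[0.8, 0.2], [0.6, 0.3]],
}
query_lookup = {"q1": "first query", "q2": "second query"}
doc_lookup = {"d1": "doc one", "d2": "doc two", "d3": "doc three"}
result = transform_batch(batch, query_lookup, doc_lookup, n_ways=1)
print(result["query"])      # ['first query', 'second query']
print(result["documents"])  # [['doc one'], ['doc two']]
print(result["scores"])     # [[0.8], [0.6]]
```

Building a new dict leaves the caller's batch untouched; the library's own method may instead update examples in place.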