Datasets

PyLate is designed to be compatible with Hugging Face datasets, enabling seamless integration for tasks like knowledge distillation and contrastive model training. Below are examples of how to load and prepare datasets for these specific training objectives.

Contrastive Dataset

Contrastive training requires datasets that include a query, a positive document (relevant to the query), and a negative document (irrelevant to the query). This is the standard triplet format used by Sentence Transformers, making PyLate's contrastive training compatible with all existing triplet datasets.

Loading a pre-built contrastive dataset

You can directly download an existing contrastive dataset from the Hugging Face Hub, such as the msmarco-bm25 triplet dataset.

from datasets import load_dataset

dataset = load_dataset("sentence-transformers/msmarco-bm25", "triplet", split="train")

# train_test_split returns a DatasetDict, so index the splits explicitly
splits = dataset.train_test_split(test_size=0.001)
train_dataset, test_dataset = splits["train"], splits["test"]

Then we can shuffle the dataset:

train_dataset = train_dataset.shuffle(seed=42)

And select a subset of the dataset if needed:

train_dataset = train_dataset.select(range(10_000))
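
To sanity-check the result, you can inspect a single row; with the triplet configuration of this dataset, each row should be a plain dict exposing the query, positive, and negative columns:

print(train_dataset[0])
# e.g. {'query': '...', 'positive': '...', 'negative': '...'}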

Creating a contrastive dataset from a list

If you want to create a custom contrastive dataset, you can do so by manually specifying the query, positive, and negative samples.

from datasets import Dataset

dataset = [
    {
        "query": "example query 1",
        "positive": "example positive document 1",
        "negative": "example negative document 1",
    },
    {
        "query": "example query 2",
        "positive": "example positive document 2",
        "negative": "example negative document 2",
    },
    {
        "query": "example query 3",
        "positive": "example positive document 3",
        "negative": "example negative document 3",
    },
]

dataset = Dataset.from_list(mapping=dataset)

splits = dataset.train_test_split(test_size=0.3)
train_dataset, test_dataset = splits["train"], splits["test"]
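
If you want to reuse such a custom dataset later, one option is to write it to a Parquet file (the filename here is arbitrary), which can then be reloaded as shown in the next section:

# Persist the in-memory dataset to disk
dataset.to_parquet("dataset.parquet")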

Loading a contrastive dataset from a local parquet file

To load a local dataset stored in a Parquet file:

from datasets import load_dataset

dataset = load_dataset(
    path="parquet", 
    data_files="dataset.parquet", 
    split="train"
)

splits = dataset.train_test_split(test_size=0.001)
train_dataset, test_dataset = splits["train"], splits["test"]

Knowledge distillation dataset

For fine-tuning a model using knowledge distillation loss, three distinct dataset files are required: train, queries, and documents.

Info

Each file contains unique and complementary information necessary for the distillation process: train holds the teacher relevance scores, while queries and documents hold the corresponding texts.

Train

  • train: Contains three columns: ['query_id', 'document_ids', 'scores']
    • query_id refers to the query identifier.
    • document_ids is a list of document IDs relevant to the query.
    • scores corresponds to the relevance scores between the query and each document.

Example entry:

{
    "query_id": 54528,
    "document_ids": [
        6862419,
        335116,
        339186,
        7509316,
        7361291,
        7416534,
        5789936,
        5645247,
    ],
    "scores": [
        0.4546215673141326,
        0.6575686537173476,
        0.26825184192900203,
        0.5256195579370395,
        0.879939718687207,
        0.7894968184862693,
        0.6450100468854655,
        0.5823844608171467,
    ],
}
Warning

Ensure that the length of document_ids matches the length of scores.
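
A cheap way to enforce this invariant before training is to assert it for every row; in this minimal sketch, train is assumed to be a Dataset in the format described above:

# Fail fast on any row where document_ids and scores diverge in length
for row in train:
    assert len(row["document_ids"]) == len(row["scores"]), (
        f"Length mismatch for query_id {row['query_id']}"
    )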

Queries

  • queries: Contains two columns: ['query_id', 'text']

Example entry:

{"query_id": 749480, "text": "example query 1"}

Documents

  • documents: Contains two columns: ['document_id', 'text']

Example entry:

{
    "document_id": 136062,
    "text": "example document 1",
}

Loading a pre-built knowledge distillation dataset

You can directly download an existing knowledge distillation dataset from the Hugging Face Hub, such as the English MS MARCO dataset with BGE M3 scores or the French version. Simply load the different files by passing their respective configuration names to the load_dataset function:

from datasets import load_dataset

train = load_dataset(
    "lightonai/ms-marco-en-bge",
    "train",
    split="train",
)

queries = load_dataset(
    "lightonai/ms-marco-en-bge",
    "queries",
    split="train",
)

documents = load_dataset(
    "lightonai/ms-marco-en-bge",
    "documents",
    split="train",
)
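
These three datasets only become useful together once each training row can resolve its query and document texts next to the scores. As a minimal sketch based on PyLate's knowledge distillation example, the KDProcessing helper from pylate.utils performs this join lazily; check the training documentation for the exact, current API:

from pylate import utils

# Resolve query_id and document_ids into their texts on the fly
train.set_transform(
    utils.KDProcessing(queries=queries, documents=documents).transform,
)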

Knowledge distillation dataset from lists

You can also create custom datasets from lists in Python. This example demonstrates how to build the train, queries, and documents datasets:

from datasets import Dataset

train = [
    {
        "query_id": 54528,
        "document_ids": [
            6862419,
            335116,
            339186,
            7509316,
            7361291,
            7416534,
            5789936,
            5645247,
        ],
        "scores": [
            0.4546215673141326,
            0.6575686537173476,
            0.26825184192900203,
            0.5256195579370395,
            0.879939718687207,
            0.7894968184862693,
            0.6450100468854655,
            0.5823844608171467,
        ],
    },
    {
        "query_id": 749480,
        "document_ids": [
            6862419,
            335116,
            339186,
            7509316,
            7361291,
            7416534,
            5789936,
            5645247,
        ],
        "scores": [
            0.2546215673141326,
            0.7575686537173476,
            0.96825184192900203,
            0.0256195579370395,
            0.779939718687207,
            0.2894968184862693,
            0.1450100468854655,
            0.7823844608171467,
        ],
    },
]

train = Dataset.from_list(mapping=train)

documents = [
    {"document_id": 6862419, "text": "example document 1"},
    {"document_id": 335116, "text": "example document 2"},
    {"document_id": 339186, "text": "example document 3"},
    {"document_id": 7509316, "text": "example document 4"},
    {"document_id": 7361291, "text": "example document 5"},
    {"document_id": 7416534, "text": "example document 6"},
    {"document_id": 5789936, "text": "example document 7"},
    {"document_id": 5645247, "text": "example document 8"},
]

queries = [
    {"query_id": 749480, "text": "example query 1"},
    {"query_id": 54528, "text": "example query 2"},
]

documents = Dataset.from_list(mapping=documents)

queries = Dataset.from_list(mapping=queries)
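
Because the three datasets are linked only through IDs, it is worth verifying that every ID referenced in train resolves to an existing query and document; a minimal sketch:

# Check referential integrity between train, queries, and documents
known_documents = set(documents["document_id"])
known_queries = set(queries["query_id"])

for row in train:
    assert row["query_id"] in known_queries
    assert all(doc_id in known_documents for doc_id in row["document_ids"])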