"""Process the WikiText dataset for training the BERT model, using the Hugging Face
datasets library.
"""

import time
import random
from typing import Iterator

import tokenizers
from datasets import load_dataset, Dataset

# path and name of each dataset on the Hugging Face Hub
DATASETS = {
    "wikitext-2": ("wikitext", "wikitext-2-raw-v1"),
    "wikitext-103": ("wikitext", "wikitext-103-raw-v1"),
}
PATH, NAME = DATASETS["wikitext-103"]
TOKENIZER_PATH = "wikitext-103_wordpiece.json"
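
# Hedged sketch, not called anywhere in this script: one way the WordPiece tokenizer file
# referenced by TOKENIZER_PATH could be produced with the `tokenizers` library. The
# vocabulary size, normalization settings, and special-token order below are assumptions.
# Padding must be enabled before saving because create_sample() reads tokenizer.padding["pad_id"].
def train_wordpiece_tokenizer(out_path: str = TOKENIZER_PATH, vocab_size: int = 30_000) -> None:
    """Illustrative only: train a WordPiece tokenizer on the raw WikiText lines and save it."""
    from tokenizers import Tokenizer, models, normalizers, pre_tokenizers, trainers

    raw = load_dataset(PATH, NAME, split="train")
    tok = Tokenizer(models.WordPiece(unk_token="[UNK]"))
    tok.normalizer = normalizers.BertNormalizer(lowercase=True)
    tok.pre_tokenizer = pre_tokenizers.BertPreTokenizer()
    trainer = trainers.WordPieceTrainer(
        vocab_size=vocab_size,
        special_tokens=["[PAD]", "[UNK]", "[CLS]", "[SEP]", "[MASK]"],
    )
    tok.train_from_iterator((line for line in raw["text"] if line.strip()), trainer=trainer)
    tok.enable_padding(pad_id=0, pad_token="[PAD]")  # "[PAD]" was given id 0 by the order above
    tok.save(out_path)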


def create_docs(path: str, name: str, tokenizer: tokenizers.Tokenizer) -> list[list[list[int]]]:
    """Load the WikiText dataset and tokenize its lines, grouped into documents."""
    dataset = load_dataset(path, name, split="train")
    # start with one (initially empty) document so the first text line always has a document to join
    docs: list[list[list[int]]] = [[]]
    for line in dataset["text"]:
        line = line.strip()
        if not line or line.startswith("="):
            docs.append([])  # blank line or section heading: start a new document
        else:
            tokens = tokenizer.encode(line).ids
            docs[-1].append(tokens)
    docs = [doc for doc in docs if doc]  # remove empty documents
    return docs
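
# Illustrative shape of the value returned by create_docs (token ids are made up):
#   docs[i]    -> one document (a contiguous block of non-heading lines)
#   docs[i][j] -> one line of that document as a list of WordPiece ids
#   e.g. docs = [[[2010, 3591, 1012], [1996, 4248, 2829]], ...]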


def create_dataset(
    docs: list[list[list[int]]],
    tokenizer: tokenizers.Tokenizer,
    max_seq_length: int = 512,
    doc_repeat: int = 10,
    mask_prob: float = 0.15,
    short_seq_prob: float = 0.1,
    max_predictions_per_seq: int = 20,
) -> Iterator[dict]:
    """Generate training samples from all documents, visiting each document doc_repeat times."""
    doc_indices = list(range(len(docs))) * doc_repeat
    for doc_idx in doc_indices:
        yield from generate_samples(
            doc_idx, docs, tokenizer, max_seq_length, mask_prob, short_seq_prob, max_predictions_per_seq
        )

def generate_samples(
    doc_idx: int,
    all_docs: list[list[list[int]]],
    tokenizer: tokenizers.Tokenizer,
    max_seq_length: int = 512,
    mask_prob: float = 0.15,
    short_seq_prob: float = 0.1,
    max_predictions_per_seq: int = 20,
) -> Iterator[dict]:
    """Generate training samples from a single document."""
    # number of tokens to extract from this doc, excluding [CLS], [SEP], [SEP]
    target_length = max_seq_length - 3
    if random.random() < short_seq_prob:
        # use a shorter sequence short_seq_prob (10% by default) of the time
        target_length = random.randint(2, target_length)

    # work on a shallow copy so the original document is left untouched
    chunks = list(all_docs[doc_idx])

    # exhaust chunks and create samples
    while chunks:
        # scan until the target token length is reached
        running_length = 0
        end = 1
        while end < len(chunks) and running_length < target_length:
            running_length += len(chunks[end - 1])
            end += 1
        # randomly separate the chunk into two segments
        sep = random.randint(1, end - 1) if end > 1 else 1
        sentence_a = [tok for chunk in chunks[:sep] for tok in chunk]
        sentence_b = [tok for chunk in chunks[sep:end] for tok in chunk]
        # sentence B: may be from another document
        if not sentence_b or random.random() < 0.5:
            # find another document (must not be the same as doc_idx)
            b_idx = random.randint(0, len(all_docs) - 2)
            if b_idx >= doc_idx:
                b_idx += 1
            # sentence B starts from a random position in the new document
            sentence_b = []
            running_length = len(sentence_a)
            i = random.randint(0, len(all_docs[b_idx]) - 1)
            while i < len(all_docs[b_idx]) and running_length < target_length:
                sentence_b.extend(all_docs[b_idx][i])
                running_length += len(all_docs[b_idx][i])
                i += 1
            is_random_next = True
            chunks = chunks[sep:]
        else:
            is_random_next = False
            chunks = chunks[end:]
        # create a sample from the pair
        yield create_sample(sentence_a, sentence_b, is_random_next, tokenizer, max_seq_length, mask_prob, max_predictions_per_seq)

def create_sample(
    sentence_a: list[int],
    sentence_b: list[int],
    is_random_next: bool,
    tokenizer: tokenizers.Tokenizer,
    max_seq_length: int = 512,
    mask_prob: float = 0.15,
    max_predictions_per_seq: int = 20,
) -> dict:
    """Create one training sample from a pair of token id sequences."""
    # collect the ids of the special tokens
    cls_id = tokenizer.token_to_id("[CLS]")
    sep_id = tokenizer.token_to_id("[SEP]")
    mask_id = tokenizer.token_to_id("[MASK]")
    pad_id = tokenizer.padding["pad_id"]  # requires padding to be enabled on the tokenizer
    # truncate the pair to fit the maximum sequence length
    truncate_seq_pair(sentence_a, sentence_b, max_seq_length - 3)
    num_pad = max_seq_length - len(sentence_a) - len(sentence_b) - 3
    # create the unmodified token sequence
    tokens = [cls_id] + sentence_a + [sep_id] + sentence_b + [sep_id] + ([pad_id] * num_pad)
    seg_id = [0] * (len(sentence_a) + 2) + [1] * (len(sentence_b) + 1) + [-1] * num_pad
    assert len(tokens) == len(seg_id) == max_seq_length
    # create the prediction targets
    cand_indices = [i for i, tok in enumerate(tokens) if tok not in [cls_id, sep_id, pad_id]]
    random.shuffle(cand_indices)
    num_predictions = int(round((len(sentence_a) + len(sentence_b)) * mask_prob))
    num_predictions = min(max_predictions_per_seq, max(1, num_predictions))
    mlm_positions = sorted(cand_indices[:num_predictions])
    mlm_labels = []
    for i in mlm_positions:
        mlm_labels.append(tokens[i])
        # with prob 0.8 replace with [MASK]; otherwise a second draw picks between a
        # random token (0.1 overall) and keeping the original token (0.1 overall)
        if random.random() < 0.8:
            tokens[i] = mask_id
        elif random.random() < 0.5:
            # random replacement; assumes the special tokens occupy ids below 4
            tokens[i] = random.randint(4, tokenizer.get_vocab_size() - 1)
    ret = {
        "tokens": tokens,
        "segment_ids": seg_id,
        "is_random_next": is_random_next,
        "masked_positions": mlm_positions,
        "masked_labels": mlm_labels,
    }
    return ret
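
# Illustrative layout of one sample (token ids are made up):
#   tokens:           [cls, a1, a2, sep, b1, b2, b3, sep, pad, pad, ...]
#   segment_ids:      [0,   0,  0,  0,   1,  1,  1,  1,   -1,  -1, ...]
#   is_random_next:   True if sentence B was drawn from a different document (the NSP label)
#   masked_positions: indices into `tokens` selected for masked language modelling
#   masked_labels:    the original token ids at those positions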


def truncate_seq_pair(sentence_a: list[int], sentence_b: list[int], max_num_tokens: int) -> None:
    """Truncate a pair of sequences in place until their combined length fits max_num_tokens."""
    while len(sentence_a) + len(sentence_b) > max_num_tokens:
        # pick the longer sentence to remove a token from
        candidate = sentence_a if len(sentence_a) > len(sentence_b) else sentence_b
        # remove one token from either end with equal probability
        if random.random() < 0.5:
            candidate.pop(0)
        else:
            candidate.pop()
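
# Example (illustrative): truncate_seq_pair([1, 2, 3, 4], [5, 6], 4) repeatedly removes a
# token from a random end of whichever list is currently longer, here leaving 2 tokens in
# each list so that the combined length is at most 4.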


if __name__ == "__main__":
    print(time.time(), "started")
    tokenizer = tokenizers.Tokenizer.from_file(TOKENIZER_PATH)
    print(time.time(), "loaded tokenizer")
    docs = create_docs(PATH, NAME, tokenizer)
    print(time.time(), "created docs with %d documents" % len(docs))
    dataset = Dataset.from_generator(create_dataset, gen_kwargs={"docs": docs, "tokenizer": tokenizer})
    print(time.time(), "created dataset from generator")
    # Save dataset to parquet file
    dataset.to_parquet("wikitext-103_train_data.parquet")
    print(time.time(), "saved dataset to parquet file")
    # Load dataset from parquet file
    dataset = Dataset.from_parquet("wikitext-103_train_data.parquet", streaming=True)
    print(time.time(), "loaded dataset from parquet file")
    # Print a few samples
    for i, sample in enumerate(dataset):
        print(i)
        print(sample)
        print()
        if i >= 3:
            break
    print(time.time(), "finished")

