import dataclasses

import datasets
import torch
import torch.nn as nn
import tqdm


@dataclasses.dataclass
class BertConfig:
    """Configuration for BERT model."""
    vocab_size: int = 30522
    num_layers: int = 12
    hidden_size: int = 768
    num_heads: int = 12
    dropout_prob: float = 0.1
    pad_id: int = 0
    max_seq_len: int = 512
    num_types: int = 2


class BertBlock(nn.Module):
    """One transformer block in BERT."""
    def __init__(self, hidden_size: int, num_heads: int, dropout_prob: float):
        super().__init__()
        self.attention = nn.MultiheadAttention(hidden_size, num_heads,
                                              dropout=dropout_prob, batch_first=True)
        self.attn_norm = nn.LayerNorm(hidden_size)
        self.ff_norm = nn.LayerNorm(hidden_size)
        self.dropout = nn.Dropout(dropout_prob)
        self.feed_forward = nn.Sequential(
            nn.Linear(hidden_size, 4 * hidden_size),
            nn.GELU(),
            nn.Linear(4 * hidden_size, hidden_size),
        )

    def forward(self, x: torch.Tensor, pad_mask: torch.Tensor) -> torch.Tensor:
        # self-attention with padding mask, then dropout, residual, and post-norm
        attn_output, _ = self.attention(x, x, x, key_padding_mask=pad_mask)
        x = self.attn_norm(x + self.dropout(attn_output))
        # feed-forward with GeLU activation and post-norm
        ff_output = self.feed_forward(x)
        x = self.ff_norm(x + self.dropout(ff_output))
        return x
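# Illustrative usage (hypothetical sizes, not executed): a BertBlock keeps the
# (batch, seq, hidden) shape of its input, e.g.
#   block = BertBlock(hidden_size=768, num_heads=12, dropout_prob=0.1)
#   hidden = torch.randn(2, 16, 768)                 # (batch, seq, hidden)
#   pad_mask = torch.zeros(2, 16, dtype=torch.bool)  # True marks padding positions
#   out = block(hidden, pad_mask)                    # -> shape (2, 16, 768)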


class BertPooler(nn.Module):
    """Pooler layer for BERT to process the [CLS] token output."""
    def __init__(self, hidden_size: int):
        super().__init__()
        self.dense = nn.Linear(hidden_size, hidden_size)
        self.activation = nn.Tanh()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.dense(x)
        x = self.activation(x)
        return x


class BertModel(nn.Module):
    """Backbone of the BERT model."""
    def __init__(self, config: BertConfig):
        super().__init__()
        # embedding layers
        self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size,
                                            padding_idx=config.pad_id)
        self.type_embeddings = nn.Embedding(config.num_types, config.hidden_size)
        self.position_embeddings = nn.Embedding(config.max_seq_len, config.hidden_size)
        self.embeddings_norm = nn.LayerNorm(config.hidden_size)
        self.embeddings_dropout = nn.Dropout(config.dropout_prob)
        # transformer blocks
        self.blocks = nn.ModuleList([
            BertBlock(config.hidden_size, config.num_heads, config.dropout_prob)
            for _ in range(config.num_layers)
        ])
        # [CLS] pooler layer
        self.pooler = BertPooler(config.hidden_size)

    def forward(self, input_ids: torch.Tensor, token_type_ids: torch.Tensor, pad_id: int = 0
                ) -> tuple[torch.Tensor, torch.Tensor]:
        # create attention mask for padding tokens
        pad_mask = input_ids == pad_id
        # convert integer tokens to embedding vectors
        batch_size, seq_len = input_ids.shape
        position_ids = torch.arange(seq_len, device=input_ids.device).unsqueeze(0)
        position_embeddings = self.position_embeddings(position_ids)
        type_embeddings = self.type_embeddings(token_type_ids)
        token_embeddings = self.word_embeddings(input_ids)
        x = token_embeddings + type_embeddings + position_embeddings
        x = self.embeddings_norm(x)
        x = self.embeddings_dropout(x)
        # process the sequence with transformer blocks
        for block in self.blocks:
            x = block(x, pad_mask)
        # pool the hidden state of the `[CLS]` token
        pooled_output = self.pooler(x[:, 0, :])
        return x, pooled_output
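# Illustrative usage (hypothetical small config for brevity, not executed): given token and
# segment ids of shape (batch, seq), BertModel returns per-token hidden states and a pooled
# [CLS] vector, e.g.
#   cfg = BertConfig(num_layers=2, hidden_size=64, num_heads=4)
#   bert = BertModel(cfg)
#   ids = torch.randint(1, cfg.vocab_size, (2, 16))  # (batch, seq), avoiding pad_id 0
#   types = torch.zeros(2, 16, dtype=torch.long)     # single-segment input
#   hidden, pooled = bert(ids, types)                # shapes (2, 16, 64) and (2, 64)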


class BertPretrainingModel(nn.Module):
    """BERT with masked language modeling (MLM) and next sentence prediction (NSP) heads."""
    def __init__(self, config: BertConfig):
        super().__init__()
        self.bert = BertModel(config)
        self.mlm_head = nn.Sequential(
            nn.Linear(config.hidden_size, config.hidden_size),
            nn.GELU(),
            nn.LayerNorm(config.hidden_size),
            nn.Linear(config.hidden_size, config.vocab_size),
        )
        self.nsp_head = nn.Linear(config.hidden_size, 2)

    def forward(self, input_ids: torch.Tensor, token_type_ids: torch.Tensor, pad_id: int = 0
                ) -> tuple[torch.Tensor, torch.Tensor]:
        # Process the sequence with the BERT model backbone
        x, pooled_output = self.bert(input_ids, token_type_ids, pad_id)
        # Predict the masked tokens for the MLM task and the classification for the NSP task
        mlm_logits = self.mlm_head(x)
        nsp_logits = self.nsp_head(pooled_output)
        return mlm_logits, nsp_logits
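# Output shapes (for reference): mlm_logits is (batch, seq, vocab_size), one distribution over
# the vocabulary per position; nsp_logits is (batch, 2), the is-next vs. random-next scores.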


# Training parameters
epochs = 10
learning_rate = 1e-4
batch_size = 32

# Load dataset and set up dataloader
dataset = datasets.Dataset.from_parquet("wikitext-2_train_data.parquet")


def collate_fn(batch: list[dict]):
    """Custom collate function to handle variable-length sequences in the dataset."""
    # always at max length: tokens, segment_ids; always singleton: is_random_next
    input_ids = torch.tensor([item["tokens"] for item in batch])
    token_type_ids = torch.tensor([item["segment_ids"] for item in batch]).abs()
    is_random_next = torch.tensor([item["is_random_next"] for item in batch]).long()
    # variable length: masked_positions, masked_labels
    masked_pos = [(idx, pos) for idx, item in enumerate(batch) for pos in item["masked_positions"]]
    masked_labels = torch.tensor([label for item in batch for label in item["masked_labels"]])
    return input_ids, token_type_ids, is_random_next, masked_pos, masked_labels
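# Illustrative shapes of one collated batch (assuming the parquet rows hold fixed-length,
# already-tokenized examples, as the field names above suggest):
#   input_ids, token_type_ids: LongTensors of shape (batch_size, max_seq_len)
#   is_random_next:            LongTensor of shape (batch_size,) holding 0/1 NSP labels
#   masked_pos:                list of (batch_index, token_position) pairs, variable length
#   masked_labels:             LongTensor with one original token id per masked position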

dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True,
                                         collate_fn=collate_fn, num_workers=8)

# Train the model
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = BertPretrainingModel(BertConfig()).to(device)
model.train()
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.1)
loss_fn = nn.CrossEntropyLoss()

for epoch in range(epochs):
    pbar = tqdm.tqdm(dataloader, desc=f"Epoch {epoch+1}/{epochs}")
    for batch in pbar:
        # get batched data
        input_ids, token_type_ids, is_random_next, masked_pos, masked_labels = batch
        input_ids = input_ids.to(device)
        token_type_ids = token_type_ids.to(device)
        is_random_next = is_random_next.to(device)
        masked_labels = masked_labels.to(device)
        # extract output from model
        mlm_logits, nsp_logits = model(input_ids, token_type_ids)
        # MLM loss: masked_pos is a list of (batch_index, token_position) pairs; use them to
        # gather the matching rows from mlm_logits of shape (B, S, V) via advanced indexing
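        # e.g. masked_pos = [(0, 5), (0, 9), (1, 3)] picks the logits at (batch 0, pos 5),
        # (batch 0, pos 9) and (batch 1, pos 3), giving a (3, V) tensor aligned with the
        # three entries of masked_labels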
        batch_indices, token_positions = zip(*masked_pos)
        mlm_logits = mlm_logits[batch_indices, token_positions]
        mlm_loss = loss_fn(mlm_logits, masked_labels)
        # Compute the loss for the NSP task
        nsp_loss = loss_fn(nsp_logits, is_random_next)
        # backward with total loss
        total_loss = mlm_loss + nsp_loss
        pbar.set_postfix(MLM=mlm_loss.item(), NSP=nsp_loss.item(), Total=total_loss.item())
        optimizer.zero_grad()
        total_loss.backward()
        optimizer.step()
    pbar.close()
    # decay the learning rate once per epoch (StepLR with step_size=1, gamma=0.1)
    scheduler.step()

# Save the full pretraining model and the BERT backbone separately
torch.save(model.state_dict(), "bert_pretraining_model.pth")
torch.save(model.bert.state_dict(), "bert_model.pth")
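# Illustrative reload of the saved weights (e.g. for later fine-tuning of the backbone):
#   model = BertPretrainingModel(BertConfig())
#   model.load_state_dict(torch.load("bert_pretraining_model.pth"))
#   backbone = BertModel(BertConfig())
#   backbone.load_state_dict(torch.load("bert_model.pth"))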

