# Pretrain a BERT Model from Scratch – MachineLearningMastery.com

import dataclasses

import datasets

import torch

import torch.nn as nn

import tqdm

@dataclasses.dataclass
class BertConfig:
    """Hyperparameters for building a BERT model.

    Every field has a default, so ``BertConfig()`` yields a 12-layer,
    768-wide, 12-head model with a 30522-token vocabulary; override any
    field by keyword.
    """

    # rows in the word-embedding table and width of the MLM output layer
    vocab_size: int = 30522
    # number of stacked transformer blocks
    num_layers: int = 12
    # width of embeddings and hidden states
    hidden_size: int = 768
    # attention heads per block
    num_heads: int = 12
    # dropout rate used in the embedding layer and in every block
    dropout_prob: float = 0.1
    # token id used for padding (the word embedding's padding_idx)
    pad_id: int = 0
    # rows in the learned position-embedding table (longest supported input)
    max_seq_len: int = 512
    # number of segment (token type) embeddings
    num_types: int = 2

class BertBlock(nn.Module):
    """One post-norm transformer encoder block of BERT.

    Both sub-layers (multi-head self-attention, then the position-wise
    feed-forward network) use the same residual pattern:
    ``x = LayerNorm(x + Dropout(sublayer(x)))``.
    """

    def __init__(self, hidden_size: int, num_heads: int, dropout_prob: float):
        """Build the block.

        Args:
            hidden_size: Width of the input/output hidden states. PyTorch's
                ``nn.MultiheadAttention`` requires it to be divisible by
                ``num_heads``.
            num_heads: Number of attention heads.
            dropout_prob: Dropout rate for the attention weights and for the
                output of each sub-layer before its residual add.
        """
        super().__init__()
        self.attention = nn.MultiheadAttention(hidden_size, num_heads,
                                               dropout=dropout_prob, batch_first=True)
        self.attn_norm = nn.LayerNorm(hidden_size)
        self.ff_norm = nn.LayerNorm(hidden_size)
        # shared by both residual branches (nn.Dropout has no parameters)
        self.dropout = nn.Dropout(dropout_prob)
        # position-wise feed-forward network with a 4x hidden-size expansion
        self.feed_forward = nn.Sequential(
            nn.Linear(hidden_size, 4 * hidden_size),
            nn.GELU(),
            nn.Linear(4 * hidden_size, hidden_size),
        )

    def forward(self, x: torch.Tensor, pad_mask: torch.Tensor) -> torch.Tensor:
        """Apply the block.

        Args:
            x: Hidden states of shape (batch, seq_len, hidden_size); the
                attention layer is built with ``batch_first=True``.
            pad_mask: Boolean key-padding mask of shape (batch, seq_len);
                ``True`` marks padding positions that attention ignores.

        Returns:
            Tensor with the same shape as ``x``.
        """
        # self-attention with padding mask; the attention weights are discarded,
        # so ask MultiheadAttention not to compute them
        attn_output, _ = self.attention(x, x, x, key_padding_mask=pad_mask,
                                        need_weights=False)
        # dropout on the sub-layer output before the residual add + post-norm,
        # the same treatment the feed-forward branch gets below
        x = self.attn_norm(x + self.dropout(attn_output))

        # feed-forward with GELU activation and post-norm
        ff_output = self.feed_forward(x)
        x = self.ff_norm(x + self.dropout(ff_output))
        return x

class BertPooler(nn.Module):
    """Pooler that turns the final [CLS] hidden state into a sentence vector.

    Computes ``tanh(dense(x))``; the caller is responsible for selecting the
    [CLS] position before calling this module.
    """

    def __init__(self, hidden_size: int):
        super().__init__()
        self.dense = nn.Linear(hidden_size, hidden_size)
        self.activation = nn.Tanh()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``tanh(dense(x))``; the output has the same shape as ``x``."""
        return self.activation(self.dense(x))

class BertModel(nn.Module):
    """BERT encoder backbone: embeddings, a stack of BertBlocks and a [CLS] pooler."""

    def __init__(self, config: BertConfig):
        super().__init__()
        # embedding layers: word + segment type + learned absolute position.
        # padding_idx makes nn.Embedding keep the pad-token row out of gradient updates.
        self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size,
                                            padding_idx=config.pad_id)
        self.type_embeddings = nn.Embedding(config.num_types, config.hidden_size)
        self.position_embeddings = nn.Embedding(config.max_seq_len, config.hidden_size)
        self.embeddings_norm = nn.LayerNorm(config.hidden_size)
        self.embeddings_dropout = nn.Dropout(config.dropout_prob)

        # transformer blocks
        self.blocks = nn.ModuleList([
            BertBlock(config.hidden_size, config.num_heads, config.dropout_prob)
            for _ in range(config.num_layers)
        ])

        # [CLS] pooler layer
        self.pooler = BertPooler(config.hidden_size)

    def forward(self, input_ids: torch.Tensor, token_type_ids: torch.Tensor, pad_id: int = 0
                ) -> tuple[torch.Tensor, torch.Tensor]:
        """Encode a batch of token sequences.

        Args:
            input_ids: Integer token ids of shape (batch, seq_len).
            token_type_ids: Segment ids of the same shape, each below
                ``config.num_types``.
            pad_id: Token id treated as padding when building the attention
                mask; it should equal the ``config.pad_id`` the model was
                built with.

        Returns:
            ``(hidden_states, pooled_output)``: per-token states of shape
            (batch, seq_len, hidden_size) and the pooled [CLS] vector of shape
            (batch, hidden_size).

        Raises:
            ValueError: if seq_len exceeds the position-embedding table size
                (``config.max_seq_len``).
        """
        seq_len = input_ids.size(1)
        max_len = self.position_embeddings.num_embeddings
        if seq_len > max_len:
            # fail with a clear message instead of an out-of-range embedding lookup
            raise ValueError(
                f"sequence length {seq_len} exceeds max_seq_len {max_len}")

        # attention mask: True where the key is a padding token
        pad_mask = input_ids == pad_id

        # convert integer tokens to embedding vectors
        position_ids = torch.arange(seq_len, device=input_ids.device).unsqueeze(0)
        x = (self.word_embeddings(input_ids)
             + self.type_embeddings(token_type_ids)
             + self.position_embeddings(position_ids))  # (1, S, H) broadcasts over batch
        x = self.embeddings_norm(x)
        x = self.embeddings_dropout(x)

        # process the sequence with transformer blocks
        for block in self.blocks:
            x = block(x, pad_mask)

        # pool the hidden state of the `[CLS]` token (position 0)
        pooled_output = self.pooler(x[:, 0, :])
        return x, pooled_output

class BertPretrainingModel(nn.Module):
    """BERT backbone topped with the two pre-training heads.

    * MLM head: dense -> GELU -> LayerNorm -> projection onto the vocabulary,
      applied to every token position.
    * NSP head: a 2-way linear classifier over the pooled [CLS] vector.
    """

    def __init__(self, config: BertConfig):
        super().__init__()
        hidden = config.hidden_size
        vocab = config.vocab_size

        self.bert = BertModel(config)
        self.mlm_head = nn.Sequential(
            nn.Linear(hidden, hidden),
            nn.GELU(),
            nn.LayerNorm(hidden),
            nn.Linear(hidden, vocab),
        )
        self.nsp_head = nn.Linear(hidden, 2)

    def forward(self, input_ids: torch.Tensor, token_type_ids: torch.Tensor, pad_id: int = 0
                ) -> tuple[torch.Tensor, torch.Tensor]:
        """Return ``(mlm_logits, nsp_logits)``.

        ``mlm_logits`` has shape (batch, seq_len, vocab_size) and ``nsp_logits``
        has shape (batch, 2). ``pad_id`` is forwarded to the backbone to build
        its padding mask.
        """
        sequence_output, cls_vector = self.bert(input_ids, token_type_ids, pad_id)
        return self.mlm_head(sequence_output), self.nsp_head(cls_vector)

# Training parameters
epochs = 10
learning_rate = 1e-4  # initial AdamW learning rate (was mistyped as 1e4)
batch_size = 32

# Load the training examples from a local parquet file. Each row supplies the
# "tokens", "segment_ids", "is_random_next", "masked_positions" and
# "masked_labels" fields consumed by collate_fn; the DataLoader is built later.
dataset = datasets.Dataset.from_parquet("wikitext-2_train_data.parquet")

def collate_fn(batch: list[dict]):
    """Collate dataset rows into the inputs and targets of one pre-training step.

    Each row provides ``tokens`` and ``segment_ids`` (already padded to a common
    length, since they are stacked directly), a scalar ``is_random_next``, and
    variable-length ``masked_positions`` / ``masked_labels`` lists that are
    aligned element by element.

    Returns:
        ``(input_ids, token_type_ids, is_random_next, masked_pos, masked_labels)``:
        ``input_ids`` and ``token_type_ids`` are (batch, seq_len) int64 tensors,
        ``is_random_next`` is a (batch,) int64 tensor of 0/1 NSP targets,
        ``masked_pos`` is a list of ``(row_index, position)`` tuples, and
        ``masked_labels`` is a 1-D int64 tensor aligned with ``masked_pos``.
    """
    # always at max length: tokens, segment_ids; always singleton: is_random_next
    input_ids = torch.tensor([item["tokens"] for item in batch], dtype=torch.long)
    # NOTE(review): .abs() folds negative segment ids onto positive ones; presumably
    # the preprocessing marks a segment with a negative value — confirm there.
    token_type_ids = torch.tensor([item["segment_ids"] for item in batch],
                                  dtype=torch.long).abs()
    is_random_next = torch.tensor([item["is_random_next"] for item in batch],
                                  dtype=torch.long)

    # variable length: masked_positions, masked_labels
    masked_pos = [(idx, pos) for idx, item in enumerate(batch)
                  for pos in item["masked_positions"]]
    # explicit dtype: an empty label list would otherwise become a float tensor,
    # which CrossEntropyLoss rejects as class-index targets
    masked_labels = torch.tensor([label for item in batch for label in item["masked_labels"]],
                                 dtype=torch.long)
    return input_ids, token_type_ids, is_random_next, masked_pos, masked_labels

if __name__ == "__main__":
    # Guard the training code: with num_workers > 0 the DataLoader starts worker
    # processes, and on platforms using the "spawn" start method each worker
    # re-imports this module — unguarded code would run again in every worker.
    dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True,
                                             collate_fn=collate_fn, num_workers=8)

    # train the model
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = BertPretrainingModel(BertConfig()).to(device)
    model.train()
    optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)
    # StepLR(step_size=1) multiplies the LR by gamma on every scheduler.step();
    # it is stepped once per epoch below (it used to be stepped per batch, which
    # collapsed the LR by 10x after every batch).
    # NOTE(review): gamma=0.1 per epoch is still very aggressive over 10 epochs.
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.1)
    loss_fn = nn.CrossEntropyLoss()

    for epoch in range(epochs):
        # iterating the tqdm wrapper advances the bar itself; no manual update()
        pbar = tqdm.tqdm(dataloader, desc=f"Epoch {epoch+1}/{epochs}")
        for batch in pbar:
            # get batched data
            input_ids, token_type_ids, is_random_next, masked_pos, masked_labels = batch
            input_ids = input_ids.to(device)
            token_type_ids = token_type_ids.to(device)
            is_random_next = is_random_next.to(device)
            masked_labels = masked_labels.to(device)

            # extract output from model
            mlm_logits, nsp_logits = model(input_ids, token_type_ids)

            # MLM loss: masked_pos is a list of (batch_index, position) tuples;
            # gather those rows from mlm_logits of shape (B, S, V)
            if masked_pos:
                batch_indices, token_positions = zip(*masked_pos)
                mlm_logits = mlm_logits[batch_indices, token_positions]
                mlm_loss = loss_fn(mlm_logits, masked_labels)
            else:
                # no masked tokens in this batch (zip(*[]) would fail): no MLM signal
                mlm_loss = nsp_logits.new_zeros(())

            # Compute the loss for the NSP task
            nsp_loss = loss_fn(nsp_logits, is_random_next)

            # backward with total loss
            total_loss = mlm_loss + nsp_loss
            pbar.set_postfix(MLM=mlm_loss.item(), NSP=nsp_loss.item(), Total=total_loss.item())
            optimizer.zero_grad()
            total_loss.backward()
            optimizer.step()
        pbar.close()
        # decay the learning rate once per epoch
        scheduler.step()

    # Save the model
    torch.save(model.state_dict(), "bert_pretraining_model.pth")
    torch.save(model.bert.state_dict(), "bert_model.pth")