# Fine-Tuning a BERT Model – MachineLearningMastery.com

import collections

import dataclasses

import functools

import torch

import torch.nn as nn

import torch.optim as optim

import tqdm

from datasets import load_dataset

from tokenizers import Tokenizer

from torch import Tensor

# BERT config and model defined previously

@dataclasses.dataclass
class BertConfig:
    """Configuration for BERT model.

    The defaults match the BERT-base architecture (12 layers, 768 hidden
    units, 12 attention heads); this script builds its model from
    ``BertConfig()`` unchanged.
    """

    vocab_size: int = 30522    # rows in the word-embedding table
    num_layers: int = 12       # number of stacked transformer blocks
    hidden_size: int = 768     # embedding / hidden-state dimension
    num_heads: int = 12        # attention heads per block; must divide hidden_size
    dropout_prob: float = 0.1  # dropout rate for embeddings and blocks
    pad_id: int = 0            # padding_idx passed to the word embedding
    max_seq_len: int = 512     # size of the learned position-embedding table
    num_types: int = 2         # number of token-type (segment) ids

class BertBlock(nn.Module):
    """One transformer block in BERT.

    Post-norm layout: multi-head self-attention, residual add, LayerNorm;
    then a 4x-wide GELU feed-forward network, dropout, residual add,
    LayerNorm.
    """

    def __init__(self, hidden_size: int, num_heads: int, dropout_prob: float):
        super().__init__()
        # batch_first=True: tensors are laid out as (batch, seq_len, hidden_size).
        # The dropout argument here applies to the attention weights.
        self.attention = nn.MultiheadAttention(hidden_size, num_heads,
                                               dropout=dropout_prob, batch_first=True)
        self.attn_norm = nn.LayerNorm(hidden_size)
        self.ff_norm = nn.LayerNorm(hidden_size)
        self.dropout = nn.Dropout(dropout_prob)
        self.feed_forward = nn.Sequential(
            nn.Linear(hidden_size, 4 * hidden_size),
            nn.GELU(),
            nn.Linear(4 * hidden_size, hidden_size),
        )

    def forward(self, x: Tensor, pad_mask: Tensor) -> Tensor:
        """Run the attention and feed-forward sub-layers.

        Args:
            x: hidden states of shape (batch, seq_len, hidden_size).
            pad_mask: (batch, seq_len) key-padding mask; True marks padding
                positions that are excluded as attention keys.

        Returns:
            Tensor with the same shape as ``x``.
        """
        # self-attention with padding mask and post-norm.
        # NOTE(review): reference BERT also applies dropout to attn_output
        # before the residual; kept as-is to match how the pretrained
        # weights were trained — confirm before changing.
        attn_output, _ = self.attention(x, x, x, key_padding_mask=pad_mask)
        x = self.attn_norm(x + attn_output)
        # feed-forward with GeLU activation and post-norm
        ff_output = self.feed_forward(x)
        x = self.ff_norm(x + self.dropout(ff_output))
        return x

class BertPooler(nn.Module):
    """Pooler layer for BERT to process the [CLS] token output.

    Applies a square linear projection followed by tanh, so every output
    element lies in [-1, 1].
    """

    def __init__(self, hidden_size: int):
        super().__init__()
        self.dense = nn.Linear(hidden_size, hidden_size)
        self.activation = nn.Tanh()

    def forward(self, x: Tensor) -> Tensor:
        """Return ``tanh(dense(x))``; the last dimension of ``x`` is hidden_size."""
        x = self.dense(x)
        x = self.activation(x)
        return x

class BertModel(nn.Module):
    """Backbone of BERT model.

    Sums word, token-type and learned position embeddings, applies
    LayerNorm and dropout, runs the result through ``config.num_layers``
    BertBlock layers and pools the hidden state at position 0 ([CLS]).
    """

    def __init__(self, config: BertConfig):
        super().__init__()
        # embedding layers
        self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size,
                                            padding_idx=config.pad_id)
        self.type_embeddings = nn.Embedding(config.num_types, config.hidden_size)
        self.position_embeddings = nn.Embedding(config.max_seq_len, config.hidden_size)
        self.embeddings_norm = nn.LayerNorm(config.hidden_size)
        self.embeddings_dropout = nn.Dropout(config.dropout_prob)
        # transformer blocks
        self.blocks = nn.ModuleList([
            BertBlock(config.hidden_size, config.num_heads, config.dropout_prob)
            for _ in range(config.num_layers)
        ])
        # [CLS] pooler layer
        self.pooler = BertPooler(config.hidden_size)

    def forward(self, input_ids: Tensor, token_type_ids: Tensor, pad_id: int = 0,
                ) -> tuple[Tensor, Tensor]:
        """Encode a batch of token sequences.

        Args:
            input_ids: (batch, seq_len) integer token ids. seq_len must not
                exceed ``config.max_seq_len``, the size of the position table.
            token_type_ids: segment ids with the same shape as ``input_ids``.
            pad_id: token id treated as padding; those positions are masked
                out as attention keys in every block.

        Returns:
            ``(sequence_output, pooled_output)`` with shapes
            (batch, seq_len, hidden_size) and (batch, hidden_size).
        """
        # create attention mask for padding tokens (True = ignore as key)
        pad_mask = input_ids == pad_id
        # convert integer tokens to embedding vectors
        seq_len = input_ids.shape[1]
        # positions 0..seq_len-1, broadcast over the batch dimension
        position_ids = torch.arange(seq_len, device=input_ids.device).unsqueeze(0)
        position_embeddings = self.position_embeddings(position_ids)
        type_embeddings = self.type_embeddings(token_type_ids)
        token_embeddings = self.word_embeddings(input_ids)
        x = token_embeddings + type_embeddings + position_embeddings
        x = self.embeddings_norm(x)
        x = self.embeddings_dropout(x)
        # process the sequence with transformer blocks
        for block in self.blocks:
            x = block(x, pad_mask)
        # pool the hidden state of the `[CLS]` token
        pooled_output = self.pooler(x[:, 0, :])
        return x, pooled_output

# Define new BERT model for question answering

class BertForQuestionAnswering(nn.Module):
    """BERT model for SQuAD question answering.

    A linear head over the backbone's per-token hidden states produces a
    start logit and an end logit for every sequence position.
    """

    def __init__(self, config: BertConfig):
        super().__init__()
        self.bert = BertModel(config)
        # Two outputs: start and end position logits
        self.qa_outputs = nn.Linear(config.hidden_size, 2)

    def forward(self,
                input_ids: Tensor,
                token_type_ids: Tensor,
                pad_id: int = 0,
                ) -> tuple[Tensor, Tensor]:
        """Return ``(start_logits, end_logits)``, each of shape (batch_size, seq_len).

        ``pad_id`` is forwarded to the backbone to build the padding mask.
        """
        # Sequence output from BERT: (batch_size, seq_len, hidden_size).
        # The pooled [CLS] vector is not used for span prediction.
        seq_output, _ = self.bert(input_ids, token_type_ids, pad_id=pad_id)
        # Project to start and end logits
        logits = self.qa_outputs(seq_output)  # (batch_size, seq_len, 2)
        start_logits = logits[:, :, 0]  # (batch_size, seq_len)
        end_logits = logits[:, :, 1]  # (batch_size, seq_len)
        return start_logits, end_logits

# Load SQuAD dataset for question answering
dataset = load_dataset("squad")

# Load the pretrained WordPiece tokenizer from a local JSON file
# (presumably the same tokenizer used to pretrain bert_model.pth — confirm).
TOKENIZER_PATH = "wikitext-2_wordpiece.json"
tokenizer = Tokenizer.from_file(TOKENIZER_PATH)

# Setup collate function to tokenize question-context pairs for the model

def _answer_token_span(item: dict, context_enc, tokenizer) -> "tuple[int, int] | None":
    """Return the inclusive (start, end) answer span in context-token indices, or None.

    Uses the annotated character offset (``answers["answer_start"]``) when
    present, so the labelled occurrence is chosen and the answer is never
    tokenised out of context. Otherwise it falls back to searching for the
    answer's own token sequence and takes the first occurrence.
    """
    answers = item["answers"]
    if len(answers["text"]) == 0:
        return None  # unanswerable / no annotation
    answer_text = answers["text"][0]

    char_starts = answers.get("answer_start")
    if char_starts:
        char_start = char_starts[0]
        char_end = char_start + len(answer_text)
        first = last = None
        # Offsets are (start, end) character spans into the context string.
        for idx, (tok_s, tok_e) in enumerate(context_enc.offsets):
            if tok_e <= char_start:
                continue  # token lies entirely before the answer
            if tok_s >= char_end:
                break  # token lies entirely after the answer
            if first is None:
                first = idx
            last = idx
        return None if first is None else (first, last)

    # Fallback: match the standalone-tokenised answer inside the context ids.
    answer_ids = tokenizer.encode(answer_text).ids
    if not answer_ids:
        return None
    context_ids = context_enc.ids
    width = len(answer_ids)
    for i in range(len(context_ids) - width + 1):
        if context_ids[i:i + width] == answer_ids:
            return i, i + width - 1
    return None


def collate(batch: list[dict], tokenizer: "Tokenizer", max_len: int,
            ) -> tuple[Tensor, Tensor, Tensor, Tensor]:
    """Collate question-context pairs for the model.

    Each item becomes ``[CLS] question [SEP] context [SEP]``, truncated or
    right-padded to exactly ``max_len`` tokens. Token-type ids are 0 for
    ``[CLS] question [SEP]`` and 1 for the context, its ``[SEP]`` and padding.

    Args:
        batch: rows with "question", "context" and "answers" keys; "answers"
            holds a "text" list and, for SQuAD, an "answer_start" list of
            character offsets into the context.
        tokenizer: object providing ``token_to_id`` and ``encode`` (returning
            ``.ids`` and ``.offsets``); encodings are assumed NOT to include
            special tokens already.
        max_len: fixed output sequence length.

    Returns:
        ``(input_ids, token_type_ids, start_positions, end_positions)``: two
        (batch, max_len) tensors and two (batch,) tensors of token positions.
        Items without an answer, or whose answer ends at or beyond
        ``max_len``, get start = end = 0 (the [CLS] position).
    """
    cls_id = tokenizer.token_to_id("[CLS]")
    sep_id = tokenizer.token_to_id("[SEP]")
    pad_id = tokenizer.token_to_id("[PAD]")

    input_ids_list, token_type_ids_list = [], []
    start_positions, end_positions = [], []

    for item in batch:
        # Tokenize question and context
        question_ids = tokenizer.encode(item["question"]).ids
        context_enc = tokenizer.encode(item["context"])
        context_ids = context_enc.ids

        # Build input: [CLS] question [SEP] context [SEP]
        input_ids = [cls_id, *question_ids, sep_id, *context_ids, sep_id]
        token_type_ids = [0] * (len(question_ids) + 2) + [1] * (len(context_ids) + 1)

        # Truncate or pad to max length
        if len(input_ids) > max_len:
            input_ids = input_ids[:max_len]
            token_type_ids = token_type_ids[:max_len]
        else:
            num_pad = max_len - len(input_ids)
            input_ids.extend([pad_id] * num_pad)
            token_type_ids.extend([1] * num_pad)

        # Find answer position in tokens: the answer may be absent or clipped
        start_pos = end_pos = 0
        span = _answer_token_span(item, context_enc, tokenizer)
        if span is not None:
            offset = len(question_ids) + 2  # skip [CLS], question, [SEP]
            start_pos, end_pos = span[0] + offset, span[1] + offset
            if end_pos >= max_len:
                start_pos = end_pos = 0  # answer is clipped, hence no answer

        input_ids_list.append(input_ids)
        token_type_ids_list.append(token_type_ids)
        start_positions.append(start_pos)
        end_positions.append(end_pos)

    return (torch.tensor(input_ids_list),
            torch.tensor(token_type_ids_list),
            torch.tensor(start_positions),
            torch.tensor(end_positions))

batch_size = 16
max_len = 384  # Longer for Q&A to accommodate context

# Bind the tokenizer and sequence length so DataLoader can call collate(batch).
collate_fn = functools.partial(collate, tokenizer=tokenizer, max_len=max_len)

train_loader = torch.utils.data.DataLoader(dataset["train"], batch_size=batch_size,
                                           shuffle=True, collate_fn=collate_fn)
val_loader = torch.utils.data.DataLoader(dataset["validation"], batch_size=batch_size,
                                         shuffle=False, collate_fn=collate_fn)

# Create Q&A model with a pretrained foundation BERT model
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
config = BertConfig()
model = BertForQuestionAnswering(config)
model.to(device)
# Only the backbone (model.bert) is restored from the checkpoint; the
# qa_outputs head keeps its fresh initialisation.
# NOTE(review): torch.load unpickles the file — only load checkpoints you
# trust (PyTorch >= 1.13 also accepts weights_only=True for state dicts).
model.bert.load_state_dict(torch.load("bert_model.pth", map_location=device))

# Training setup
loss_fn = nn.CrossEntropyLoss()
# Typical BERT fine-tuning learning rate; 2e5 (no minus sign) would diverge.
optimizer = optim.AdamW(model.parameters(), lr=2e-5)
num_epochs = 3
# collate() pads with the tokenizer's [PAD] id, so the model must mask the
# same id rather than relying on its pad_id default.
pad_id = tokenizer.token_to_id("[PAD]")

for epoch in range(num_epochs):
    model.train()
    # Training
    with tqdm.tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs}") as pbar:
        for batch in pbar:
            # get batched data
            input_ids, token_type_ids, start_positions, end_positions = (
                tensor.to(device) for tensor in batch
            )
            # forward pass
            start_logits, end_logits = model(input_ids, token_type_ids, pad_id=pad_id)
            # backward pass: sum of start- and end-position cross-entropy
            optimizer.zero_grad()
            start_loss = loss_fn(start_logits, start_positions)
            end_loss = loss_fn(end_logits, end_positions)
            loss = start_loss + end_loss
            loss.backward()
            optimizer.step()
            # update progress bar; iterating over pbar already advances it
            # once per batch, so no explicit pbar.update() is needed
            pbar.set_postfix(loss=float(loss))

    # Validation: keep track of the average loss and exact-span accuracy
    model.eval()
    val_loss, num_matches, num_batches, num_samples = 0, 0, 0, 0
    with torch.no_grad():
        for batch in val_loader:
            # get batched data
            input_ids, token_type_ids, start_positions, end_positions = (
                tensor.to(device) for tensor in batch
            )
            # forward pass on validation data
            start_logits, end_logits = model(input_ids, token_type_ids, pad_id=pad_id)
            # compute loss
            start_loss = loss_fn(start_logits, start_positions)
            end_loss = loss_fn(end_logits, end_positions)
            loss = start_loss + end_loss
            val_loss += loss.item()
            num_batches += 1
            # a sample counts as correct only if both start and end match exactly
            pred_start = start_logits.argmax(dim=1)
            pred_end = end_logits.argmax(dim=1)
            match = (pred_start == start_positions) & (pred_end == end_positions)
            num_matches += match.sum().item()
            num_samples += len(start_positions)
    # max(..., 1) guards against an empty validation loader
    avg_loss = val_loss / max(num_batches, 1)
    acc = num_matches / max(num_samples, 1)
    print(f"Validation {epoch+1}/{num_epochs}: acc {acc:.4f}, avg loss {avg_loss:.4f}")

# Save the fine-tuned model
torch.save(model.state_dict(), "bert_model_squad.pth")

# Written By
#
# More From Author
#
# You May Also Like