A Hands-on Learning Journey
note: this post should’ve been posted a year ago but I was trying to finish my undergrad thesis 🙂
Background
As a late undergraduate student interested in bioinformatics in general, and computer-aided drug discovery in particular, I’ve been naturally drawn to explore the application of machine learning (ML) in these fields. Despite the current focus on Large Language Models (LLMs), I believe that domain-specific applications offer a more promising field to explore.
Having no prior background in ML/AI, I initially struggled with the traditional approach of learning all prerequisites before delving into a more advanced subject, especially for such a fast-paced and dynamic field of study. This approach has been ineffective and inefficient for me in general, which is why I’ve adopted a more hands-on approach where I can interact with real-world objects and processes to learn theory and practice dialectically and iteratively.
So, I decided to do some projects that align with the above goals, but I didn’t want to just rebuild what already existed, like training an LSTM for generating SMILES (e.g. LSTM_Chem), or pre-training a SMILES-based model where a much better model already exists, like ChemBERTa.
Three weeks ago, I had an idea to train a sentence transformer based on a chemical “language”, which, as far as I could find at the time, did not yet exist. While trying to do so, I found a wonderful and human-readable new language called SELFIES, developed by the Aspuru-Guzik group. I found this language fascinating and worth exploring because of its robustness and because, at least so far, it has proven versatile and easy to train models on. For more information on SELFIES, you can read this blogpost or check out their GitHub.
My initial attempt focused on training a sentence transformer based on SELFIES, with the goal of enabling rapid molecule similarity search and clustering. This approach potentially offers advantages over traditional fingerprinting algorithms like MACCS, as the embeddings are context-aware. I decided to fine-tune a lightweight NLP-trained miniLM model by Nils Reimers, as I was unsure about training from scratch and didn’t even know about pre-training at that time.
The next challenges were how to properly build molecule pairs that are diverse yet informative, and how to label them. After tackling those, I trained the model on a dataset built from natural compounds taken from COCONUTDB. After some initial training, I pushed the model to Hugging Face to get some feedback. Thankfully, Tom Aarsen provided valuable suggestions, including training a custom tokenizer, exploring Matryoshka embeddings, and considering training from scratch. Implementing Tom’s suggestions, and sharing my experiences doing so, is the main content of this article.
Lastly, before going into the details, it’s important to note that this is a hands-on learning project and, given my limited knowledge, it may not meet rigorous scientific standards. Like any learning journey, it’s messy, and I am constrained by financial, computational, and time limitations. I’ve had to make compromises, such as conducting incomplete experiments and chunking datasets. However, I’m eager to receive any feedback so that I can improve both myself and future models/projects. Feel free to leave it in the comment section or contact me personally.
Why pretrain a base model?
In my initial experiments, I attempted to fine-tune the miniLM-L6 model using various optimizers and schedulers. The results showed that the default AdamW with a warm-up ratio of 0.1 and Ranger21 with no warm-down performed best in terms of losses and balance between angular, magnitude, and dot product-based metrics.
Now, when I randomly initialized nreimers’ miniLM (L6, A12, H384, 4H), its baseline metrics for Matryoshka embedding loss (dims: [384, 192, 96]) on a sample natural product eval set were:
Then I trained this initialized model on the sample train set with both Ranger21 using MADGRAD and the default AdamW with a warm-up ratio of 0.1. I used a batch size of 64, with lr = 4e-5 for Ranger21 with MADGRAD. The sample contains 18,155 examples (0.8 train | 0.2 test).
Based on both optimizers’ metrics, the model performed best on the angular metrics, worse on the magnitude-based ones, and worst on the dot-product metrics. Despite scoring lower overall, Ranger21 with MADGRAD seems to produce more balanced metrics, judging by their standard deviations.
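To make the three metric families concrete, here is a minimal sketch that truncates two embeddings to the Matryoshka dimensions quoted above and compares them. It assumes that "angular" means cosine similarity and that "magnitude-based" means a distance such as Euclidean; the function names are illustrative and are not taken from the evaluation code used for the numbers in this post.

```python
import math

MATRYOSHKA_DIMS = [384, 192, 96]  # dimensions quoted in the text


def dot(a, b):
    """Dot-product similarity: sensitive to both direction and length."""
    return sum(x * y for x, y in zip(a, b))


def cosine(a, b):
    """Angular similarity: depends only on the direction of the vectors."""
    return dot(a, b) / (math.sqrt(dot(a, a)) * math.sqrt(dot(b, b)))


def euclidean(a, b):
    """Magnitude-based distance (assumed here to be Euclidean)."""
    return math.sqrt(sum((x - y) ** 2 for x, y in zip(a, b)))


def similarity_report(a, b, dims=MATRYOSHKA_DIMS):
    """Compare two embeddings after keeping only the first d components."""
    report = {}
    for d in dims:
        ta, tb = a[:d], b[:d]  # Matryoshka-style truncation
        report[d] = {
            "cosine": cosine(ta, tb),
            "euclidean": euclidean(ta, tb),
            "dot": dot(ta, tb),
        }
    return report


if __name__ == "__main__":
    # Two vectors with the same direction but different lengths:
    # cosine stays at 1 while dot and Euclidean change with the dimension.
    for d, scores in similarity_report([1.0] * 384, [2.0] * 384).items():
        print(d, scores)
```

A model can score well on the angular metric while the other two lag behind, which is the kind of imbalance described above.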
We can also take a bird’s-eye view of the randomized model, and of the model after test training with AdamW and Ranger21 (showing only the first three layers’ heads), using BertViz:
input_text: "[C] [N] [C] [C] [C] [C@H1] [Ring1] [Branch1] [C] [=C] [N] [=C] [C] [=C] [Ring1] [=Branch1]" # (nicotine)
I observed that the attention patterns trained via Ranger21 are richer, which led me to choose this optimizer for training. There are also redundant attention heads, which led me to reduce their number, as discussed later in this post.
Based on these observations, I’ve decided to invest more time in pre-training the model before fine-tuning it for specific tasks. This approach should yield two models:
- A base model with a general understanding of SELFIES token relationships
- A fine-tuned model optimized for embedding and feature extraction, which I hope will have more balanced metrics so that we can use it to cluster molecules better.
Since the base model won’t be fine-tuned only as an embedding model, it can potentially be fine-tuned for various tasks analogous to POS tagging (quick labeling of functional groups or molecular scaffolds), QA (like synthesis/reactant-product prediction), classification (at token, sentence, and pair level), and more to be explored.
Dataset and Preprocessing
Data Sources
The dataset combines two sources of molecular data:
- Natural compounds from COCONUTDB (Sorokina et al., 2021)
- Bioactive compounds from ChemBL34 (Zdrazil et al., 2023)
Data Preparation
- Fetching: Canonical SMILES (Simplified Molecular Input Line Entry System) representations were extracted from both databases.
- De-duplication:
– Each dataset was de-duplicated internally.
– The combined dataset (“All”) was further de-duplicated to ensure unique entries.
- Validity Check and Conversion: A dual validity check was performed using RDKit and by converting the SMILES into SELFIES.
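The preparation steps above can be sketched as follows. This is only an outline under stated assumptions: `to_selfies` is a hypothetical callable that returns a SELFIES string, or `None` when a SMILES fails either validity check, and the toy lookup table in the usage example stands in for the real chemistry libraries.

```python
def deduplicate(items):
    """Drop repeated entries while keeping first-seen order."""
    return list(dict.fromkeys(items))


def prepare(coconut_smiles, chembl_smiles, to_selfies):
    """De-duplicate each source, merge, de-duplicate again, then convert.

    `to_selfies` is a hypothetical helper: it should return a SELFIES
    string, or None if the SMILES is rejected by a validity check.
    """
    merged = deduplicate(deduplicate(coconut_smiles) + deduplicate(chembl_smiles))
    converted = []
    for smiles in merged:
        selfies_str = to_selfies(smiles)
        if selfies_str is not None:  # keep only entries that passed both checks
            converted.append(selfies_str)
    return converted


if __name__ == "__main__":
    # Toy converter for illustration only; it is not a real SMILES parser.
    toy_table = {"CCO": "[C][C][O]", "C": "[C]"}
    print(prepare(["CCO", "CCO", "not-a-smiles"], ["C", "CCO"], toy_table.get))
```

In a real pipeline, `to_selfies` could wrap RDKit's `Chem.MolFromSmiles` (which returns `None` for unparsable input) followed by `selfies.encoder`; neither library is called in this sketch.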
Filtering and Chunking
- Filtering by Lipinski’s Rule of Five or its subsets (e.g., Mw < 500 and LogP < 5) was omitted to maintain broader coverage for potential future expansion to organic and inorganic molecules such as those in PubChem and ZINC20.
- The dataset was chunked into 13 parts, each containing 203,458 molecules, to accommodate the 6-hour time limit on Paperspace’s Gradient.
- Any leftover data was randomly distributed across the 13 chunks to ensure even distribution.
Validation Set
- 10% of each chunk was set aside for validation.
- These validation sets were combined into a main test set, totalling 810,108 examples.
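A minimal sketch of the chunking and hold-out steps described above. Shuffling before chunking, the fixed seeds, and the function names are assumptions made for illustration; only the 13 chunks, the random placement of leftovers, and the 10% validation share come from the text.

```python
import random


def chunk_with_leftovers(items, n_chunks, chunk_size, seed=0):
    """Cut `items` into n_chunks equal parts, then scatter the remainder."""
    rng = random.Random(seed)
    shuffled = list(items)
    rng.shuffle(shuffled)  # assumption: data is shuffled before chunking
    chunks = [
        shuffled[i * chunk_size:(i + 1) * chunk_size] for i in range(n_chunks)
    ]
    # Anything beyond n_chunks * chunk_size goes to a randomly chosen chunk.
    for leftover in shuffled[n_chunks * chunk_size:]:
        rng.choice(chunks).append(leftover)
    return chunks


def split_validation(chunk, val_fraction=0.1, seed=0):
    """Hold out a fraction of one chunk; returns (train, validation)."""
    rng = random.Random(seed)
    shuffled = list(chunk)
    rng.shuffle(shuffled)
    n_val = int(len(shuffled) * val_fraction)
    return shuffled[n_val:], shuffled[:n_val]


if __name__ == "__main__":
    chunks = chunk_with_leftovers(range(1000), n_chunks=13, chunk_size=76)
    train, val = split_validation(chunks[0])
    print([len(c) for c in chunks], len(train), len(val))
```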
The Tokenizer: Wider Coverage or Biological Relevancy?
The tokenizer combines my own tokenizer, trained on the merged COCONUTDB+ChemBL34 SELFIES dataset, with vocabulary from zpn’s word-level tokenizer trained on PubChem. This approach was chosen to ensure comprehensive coverage while maintaining relevance to biological compounds. The tokenizer was modified to ensure 100% coverage of all training data and adapted to suit the BertTokenizer format, using whitespace to split input tokens.
When using or fine-tuning this model, it’s crucial to separate each SELFIES token with a whitespace. For example:
[C] [N] [C] [C] [C] [C@H1] [Ring1] [Branch1] [C] [=C] [N] [=C] [C] [=C] [Ring1] [=Branch1]
Initially, I was torn between focusing solely on natural products and bioactives from the specified datasets and providing a baseline vocabulary for future expansion. In the end, I settled on a wider-coverage tokenizer first, since having a broad vocabulary now is more straightforward than expanding a limited vocabulary later. This approach will hopefully allow application to diverse chemical spaces beyond the initial dataset.
To ensure coverage, the tokenizer underwent evaluation to cover all tokens in the training data. Unrecognized tokens were identified and incorporated into the tokenizer. Additionally, my previous pre-training issues, such as improper tokenization of dot symbol prefixes in complex molecules (e.g., “.[Cl]”), were addressed and resolved.
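The whitespace formatting and the coverage check can be sketched with a small regular expression. Treating a leading dot as part of the following token (so that ".[Cl]" is one token) is an assumption about how the dot-prefix issue was resolved, and the function names are illustrative.

```python
import re

# A SELFIES token is a bracketed symbol. An optional leading dot marks a
# disconnected fragment, e.g. ".[Cl]" (assumed to be kept as one token).
TOKEN_PATTERN = re.compile(r"\.?\[[^\[\]]+\]")


def split_selfies(selfies_str):
    """Return the list of tokens in a SELFIES string."""
    return TOKEN_PATTERN.findall(selfies_str)


def to_model_input(selfies_str):
    """Join tokens with single spaces, as the tokenizer expects."""
    return " ".join(split_selfies(selfies_str))


def find_uncovered(corpus, vocab):
    """List tokens that occur in the corpus but are missing from vocab."""
    seen = set()
    for selfies_str in corpus:
        seen.update(split_selfies(selfies_str))
    return sorted(seen - set(vocab))


if __name__ == "__main__":
    print(to_model_input("[C][N][C][=C]"))
    print(find_uncovered(["[C][O].[Cl]"], vocab=["[C]", "[O]"]))
```

Any token reported by `find_uncovered` would need to be added to the vocabulary before training to reach full coverage.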
Model architecture choice: balancing performance and efficiency
Determining Max Length
Maximum sequence length: 510
97th percentile of sequence length: 115.00
99th percentile of sequence length: 171.00
Distribution details:
50th percentile: 46.00
75th percentile: 58.00
99th percentile: 171.00
Based on these statistics, I opted to use 512 as the maximum sequence length for the model, because:
- Coverage: a max length of 512 covers even the longest sequences in the dataset (maximum of 510).
- Future-proofing: while 99% of sequences are 171 tokens or shorter, the 512 token limit allows for potential future expansion to more complex molecules.
- BERT compatibility: since I was planning to do both MLM and NSP, mimicking the original BERT approach, I had to ensure that the split molecule pairs stay mostly intact.
- Sequence integrity: The 512 token limit ensures that when splitting molecule pairs for NSP tasks, most sequences will remain intact, preserving important structural information.
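Length statistics like the ones above can be computed with a few lines of Python. This sketch uses a nearest-rank percentile, which may differ slightly from an interpolating implementation, and it assumes the reported lengths exclude the [CLS] and [SEP] tokens. Under that assumption, the longest sequence (510 tokens) plus two special tokens is exactly 512.

```python
import math


def percentile(values, q):
    """Nearest-rank percentile (q between 0 and 100) of a non-empty list."""
    ordered = sorted(values)
    rank = max(1, math.ceil(q * len(ordered) / 100))
    return ordered[rank - 1]


def summarize_lengths(lengths):
    """Return the maximum and a few percentiles of token counts."""
    summary = {"max": max(lengths)}
    for q in (50, 75, 97, 99):
        summary[f"p{q}"] = percentile(lengths, q)
    return summary


def fits(n_tokens, max_length=512, n_special=2):
    """Check whether a sequence plus special tokens fits the model input."""
    return n_tokens + n_special <= max_length


if __name__ == "__main__":
    toy_lengths = [12, 30, 46, 46, 58, 60, 115, 171, 510]
    print(summarize_lengths(toy_lengths))
    print(fits(510), fits(511))
```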
Architecture
Important note: In all these experiments, a 15% uniform masking ratio was used and the objective is both MLM and NSP. Hereafter, parameters are shortened:
L: Layers
A: Number of attention heads
H: Hidden size
I: Intermediate size (always 4*H)
Initially, I probed performance using the 1st chunk of my training data. I started with a model configuration of L6, A12, H384, I1536:
As noted earlier, because some attention heads were either idle or seemed to be doing the same thing, I decided to reduce their number. I tested a smaller configuration (L6, A4, H256, I1024) using Ranger21 with MADGRAD:
It seemed to perform better than the much larger one. As the losses in a previous failed pretraining run using the small configuration never dropped below 0.31, I increased the number of layers to 8:
These results led me to test various combinations using a sample dataset (5% of all train set, V = 3095):
In determining the optimal model, I considered the following constraints:
- Sufficient performance
- Computational demands and training time within my hardware capabilities and 5-hour time limit
- Interpretable number of attention heads (avoiding over-diffusion seen in prior failed pretraining using L6, A4, H256, 1024)
Based on these criteria, I chose Model 1 (L8, A4, H320, I1280) for further development.
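To compare the configurations mentioned in this section, the sketch below estimates the per-head dimension and the parameter count of a standard BERT encoder (embeddings, encoder layers, and pooler, without the MLM head). It assumes that V = 3095 is the vocabulary size, 512 positions, and two token types; these are illustrative estimates, not measured values from the trained models.

```python
def head_dim(hidden_size, num_heads):
    """Size of each attention head; hidden size must divide evenly."""
    if hidden_size % num_heads:
        raise ValueError("hidden size must be divisible by the number of heads")
    return hidden_size // num_heads


def bert_encoder_params(vocab_size, hidden, layers, intermediate,
                        max_positions=512, type_vocab=2):
    """Estimate the weights in a standard BERT encoder with pooler."""
    # Token, position, and token-type embeddings plus one LayerNorm.
    embeddings = (vocab_size + max_positions + type_vocab) * hidden + 2 * hidden
    # Query, key, value, and output projections plus one LayerNorm.
    attention = 4 * (hidden * hidden + hidden) + 2 * hidden
    # Two feed-forward projections with biases plus one LayerNorm.
    feed_forward = 2 * hidden * intermediate + intermediate + hidden + 2 * hidden
    pooler = hidden * hidden + hidden
    return embeddings + layers * (attention + feed_forward) + pooler


# (layers, heads, hidden, intermediate) for the configurations in the text.
CONFIGS = {
    "L6, A12, H384, I1536": (6, 12, 384, 1536),
    "L6, A4, H256, I1024": (6, 4, 256, 1024),
    "L8, A4, H320, I1280": (8, 4, 320, 1280),
}

if __name__ == "__main__":
    for name, (n_layers, n_heads, hidden, inter) in CONFIGS.items():
        size = bert_encoder_params(3095, hidden, n_layers, inter)  # V assumed
        print(name, "head_dim =", head_dim(hidden, n_heads), "params =", size)
```

With four heads, each head in the chosen model works on an 80-dimensional slice, compared with 32 dimensions per head in the original 12-head configuration.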
Pretraining Strategy: MLM-only or with NSP?
At first, I wanted to use the original BERT approach, adapted to the molecular context. After some trial and error in preparing the masked and NSP-split data, I discovered only three days ago that RoBERTa, and ChemBERTa itself, used the MLM task without NSP, since NSP might be too easy a task and MLM alone achieved the same, if not better, performance. So this model is trained on the MLM task only. Nevertheless, to provide a more concrete comparison, I will train a second base model on a customized MLM and NSP task after this.
The main parameter in the MLM task is the masking rate. Typically, 15% of the sequence is selected; of those tokens, 80% are replaced with the “[MASK]” token, 10% with a random token from the vocabulary, and 10% are left unchanged. But I wanted to try adjusting the masking rate based on each molecule’s complexity, mapping the complexity value to the masking rate smoothly.
Examples are generated “on-the-fly” rather than written to a file first, to avoid an I/O bottleneck and the upload/download limits of the compute server. For this task, I adapted the pre-training methods for SELFIES-specific use.
Dynamic MLM-rate
The key method in this project is the implementation of a dynamic masking rate based on molecular complexity. I think we can heuristically infer a molecule’s complexity from the syntactic characteristics of SELFIES. The simplest tokens have a single character, such as “[N]” (l = 1, ignoring the brackets); a more complex one would be “.[N+1]” (l = 4); atoms that are relatively rare compared to the CHONS group look like “[Na]” (l = 2); and ionized metals look like “[Fe+3]” (l = 4). To normalize these lengths and capture the density of multi-character tokens, we divide the sum of all token lengths by the molecule’s length in tokens, and each SELFIES string’s complexity is the logarithm of that ratio. I will refer to this simple score as the “complexity score” hereafter. We then normalize the score and use it to determine a variable masking probability ranging from 15% to 45%. Additionally, we employ three different masking strategies to introduce further variability. This approach aims to create a more challenging and diverse training dataset, potentially leading to a more robust and generalizable model for molecular representation learning.
1. Complexity Score Calculation
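The raw complexity score is the natural logarithm of the sum of token lengths divided by the number of tokens (in the example outputs below, token lengths count the brackets):
raw complexity score = ln(sum of token lengths / number of tokens)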
Example outputs:
Sentence A:
Tokens: ['[C]', '[C]', '[=Branch1]', '[C]', '[=O]', '[O]', '[C]']
Sum of token lengths: 29
Number of tokens: 7
Raw complexity score: 1.4214
==================================================
Sentence B:
Tokens: ['[C]', '[N+1]', '[Branch1]', '[C]', '[C]', '[Branch1]', '[C]', '[C]', '[C]']
Sum of token lengths: 41
Number of tokens: 9
Raw complexity score: 1.5163
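The example outputs above can be reproduced with a short function. The printed scores match the natural logarithm of the mean token length when the brackets are counted as part of each token (29 / 7 for Sentence A and 41 / 9 for Sentence B); the function name is illustrative.

```python
import math


def raw_complexity(tokens):
    """Natural log of the mean token length, brackets included."""
    total_length = sum(len(token) for token in tokens)
    return math.log(total_length / len(tokens))


sentence_a = ["[C]", "[C]", "[=Branch1]", "[C]", "[=O]", "[O]", "[C]"]
sentence_b = ["[C]", "[N+1]", "[Branch1]", "[C]", "[C]",
              "[Branch1]", "[C]", "[C]", "[C]"]

if __name__ == "__main__":
    for name, tokens in (("A", sentence_a), ("B", sentence_b)):
        print(name, sum(len(t) for t in tokens), len(tokens),
              round(raw_complexity(tokens), 4))
```

Sentence B scores higher because it contains a charged atom and two branch tokens, which raise the average token length.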
Before using these scores, we need to normalize them, so we first look at how they are distributed in the dataset.
Minimum complexity score: 1.10
Maximum complexity score: 2.08
Mean complexity score: 1.58
Median complexity score: 1.58
99th percentile of complexity score: 1.69
Recommended normalization range:
Minimum (1st percentile): 1.39
Maximum (99th percentile): 1.69
Normalization function example:
def normalize_complexity_score(score):
    return max(0, min(1, (score - 1.39) / (1.69 - 1.39)))
2. Normalization
The raw score is then normalized to a range of 0–1 using predefined minimum (1.39) and maximum (1.69) normalization values:
3. Mapping to Masking Probability
To adjust the masking probability smoothly, with less masking for simpler molecules and more for relatively complex ones, I looked for a suitable function by plotting linear, quadratic, and exponential mappings with various steps, as seen below:
I decided to use the quadratic mapping with 0.3 steps, ensuring a masking range from 15% to 45%.
The normalized score is then mapped to a masking probability using a quadratic function:
This results in a masking probability between 15% and 45%, with more complex molecules having a higher masking probability.
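The exact quadratic function is not reproduced in the text, so the sketch below shows one form that is consistent with the description: a base rate of 15% plus 0.30 times the squared normalized score. Both the form and the constant names are assumptions; only the 1.39 and 1.69 bounds and the 15% to 45% range come from the text.

```python
MIN_SCORE, MAX_SCORE = 1.39, 1.69  # 1st and 99th percentile from the text
BASE_RATE = 0.15                   # masking rate for the simplest molecules
RATE_SPAN = 0.30                   # assumed meaning of the "0.3 steps"


def normalize(score):
    """Clamp the raw complexity score into the range 0..1."""
    return max(0.0, min(1.0, (score - MIN_SCORE) / (MAX_SCORE - MIN_SCORE)))


def masking_probability(score):
    """Assumed quadratic mapping from raw score to masking probability."""
    return BASE_RATE + RATE_SPAN * normalize(score) ** 2


if __name__ == "__main__":
    for raw in (1.10, 1.4214, 1.5163, 1.58, 1.69, 2.08):
        print(raw, round(masking_probability(raw), 4))
```

Because the mapping is quadratic, molecules near the median complexity stay closer to the 15% end, and the rate climbs faster only for the most complex molecules.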
4. Multi-Strategy Masking
Three different masking strategies are employed for each SELFIES string:
Original Strategy:
- 80% chance to mask the token
- 10% chance to keep the original token
- 10% chance to replace with a random token
Alternative Strategy 1:
- 70% chance to mask the token
- 15% chance to keep the original token
- 15% chance to replace with a random token
Alternative Strategy 2:
- 60% chance to mask the token
- 20% chance to keep the original token
- 20% chance to replace with a random token
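The three strategies differ only in how a token that has already been selected for masking is corrupted. A minimal sketch, with illustrative names and a toy vocabulary:

```python
import random

# (mask, keep, random-replacement) probabilities for each strategy.
STRATEGIES = {
    "original": (0.80, 0.10, 0.10),
    "alternative_1": (0.70, 0.15, 0.15),
    "alternative_2": (0.60, 0.20, 0.20),
}


def corrupt_token(token, strategy, vocab, rng, mask_token="[MASK]"):
    """Decide what a selected token becomes under the given strategy."""
    p_mask, p_keep, _ = STRATEGIES[strategy]
    roll = rng.random()
    if roll < p_mask:
        return mask_token         # replace with the mask token
    if roll < p_mask + p_keep:
        return token              # keep the original token
    return rng.choice(vocab)      # replace with a random vocabulary token


if __name__ == "__main__":
    rng = random.Random(0)
    toy_vocab = ["[C]", "[N]", "[O]", "[=C]"]
    for name in STRATEGIES:
        print(name, [corrupt_token("[C]", name, toy_vocab, rng) for _ in range(8)])
```

Note that a random replacement can occasionally draw the original token again, so the share of unchanged tokens is slightly higher than the nominal "keep" probability.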
5. View Augmentation
- Each SELFIES string is processed three times, once with each masking strategy.
- Note: These three repeats are views of the same training example, not additional data; they act as inexpensive regularisation but do not increase the effective sample size or alter the underlying molecular distribution. Dataset size is therefore reported as the number of unique molecules.
6. Masking Process
- Tokens are randomly selected for masking based on the calculated masking probability.
- Special tokens ([CLS] and [SEP]) are never masked.
- The number of tokens to be masked is determined by the masking probability and the length of the SELFIES string.
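The position selection described in this list can be sketched as below. Rounding the count and masking at least one token are assumptions, and the generator shows one way to produce the three views lazily instead of writing them to disk.

```python
import random

SPECIAL_TOKENS = {"[CLS]", "[SEP]"}


def select_mask_positions(tokens, mask_prob, rng):
    """Pick which positions to corrupt, never touching special tokens."""
    candidates = [i for i, tok in enumerate(tokens) if tok not in SPECIAL_TOKENS]
    # Assumption: the count is the rounded product, with a minimum of one.
    n_to_mask = max(1, round(mask_prob * len(candidates)))
    return sorted(rng.sample(candidates, min(n_to_mask, len(candidates))))


def iter_masked_views(tokens, mask_prob, rng, n_views=3):
    """Lazily yield one fresh selection of positions per view."""
    for _ in range(n_views):
        yield select_mask_positions(tokens, mask_prob, rng)


if __name__ == "__main__":
    rng = random.Random(42)
    example = ["[CLS]"] + ["[C]", "[N]", "[C]", "[=C]", "[O]"] * 4 + ["[SEP]"]
    for positions in iter_masked_views(example, 0.30, rng):
        print(positions)
```

Because the positions are drawn at generation time, a different seed or a later epoch produces different masks for the same molecule.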
This methodology aims to create a diverse and challenging dataset for masked language modeling of molecular structures, adapting the masking intensity to the complexity of each molecule and employing multiple masking strategies to improve model robustness and generalization. Also, besides masking differently based on complexity scores, the on-the-fly data generation should mean that the data are masked differently in each run and batch, but this still needs further confirmation.
Training Hyperparameters
Batch size = 128
Num of Epoch = 1
Total steps on all chunks = 56,966
Training time on each chunk = 03h:24m / ~205 mins
I am using the Ranger21 optimizer with these settings:
Core optimizer = madgrad
Learning rate of 2e-05
num_epochs of training = ** 1 epochs **
using AdaBelief for variance computation
Warm-up: linear warmup, over 964 iterations (0.22)
Lookahead active, merging every 5 steps, with blend factor of 0.5
Norm Loss active, factor = 0.0001
Stable weight decay of 0.01
Gradient Centralization = On
Adaptive Gradient Clipping = True
clipping value of 0.01
steps for clipping = 0.001
I turned off the warm-down, since in prior experiments it led to unstable losses in my case.
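As a quick sanity check on the numbers above, the warm-up length in the optimizer log is consistent with 22% of the steps in one chunk. This assumes the 56,966 total steps are split evenly across the 13 chunks and that the warm-up is counted per chunk.

```python
TOTAL_STEPS = 56_966   # total steps on all chunks, from the text
N_CHUNKS = 13          # number of chunks, from the text
WARMUP_RATIO = 0.22    # ratio shown in the optimizer log

# Assumption: every chunk runs the same number of steps.
steps_per_chunk = TOTAL_STEPS // N_CHUNKS
warmup_steps = int(WARMUP_RATIO * steps_per_chunk)

print(steps_per_chunk, warmup_steps)  # 4382 964
```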
