Experimental playground for benchmarking language model (LM) architectures, layers, and tricks on smaller datasets. Designed for flexible experimentation and exploration.
See the techniques/ directory for explanations of various techniques implemented in this repository.
Latest: NorMuon optimizer; TOP loss (predicting the order of upcoming tokens)
BlaGPT is a flexible Transformer implementation that lets you turn the following features on/off in its config (a hypothetical config sketch follows the feature list below).
The results below are the numbers after one epoch of training on fineweb10B, mostly with the default parameters. The goal is to see how each technique behaves without fiddling with the model and hyperparameters too much.
Multi-token prediction - link
Weight tying - link
Grouped query attention - link
Capping logits - link
QKV bias - link
Zero-init projection layer - link
Post and pre-RMSNorm - link
Setting base theta to 1_000_000 - llama3 - increased the final validation loss - best 3.3324
Z-loss regularization - link - increased the final validation loss by 0.02 - loss: 3.3527
KV-Shifting attention - link - seems to improve performance - loss: 3.3310 -> 3.3138 - peak memory consumption: 42858 MiB
Dilated Attention (LongNet) - link
Multi-Head Latent Attention - link - loss: 3.3479 - peak memory consumption: 42192 MiB
Per token output bias - link - loss: 3.3257 - peak memory consumption: 42120 MiB
DyT Norm - link - didn't really work; the loss got stuck too high
Forgetting Transformer (vanilla and Pro versions) - link - vanilla loss: 3.3243, Pro loss: OOM
Multi-Token Attention - link - loss: 3.3357 - peak memory: 42136 MiB
Differential Attention - link - best_model_loss: 3.2411 -> loss: 3.2460 - peak memory: 41521 MiB
Softpick - link - loss: 3.3446 - peak memory: 59417 MiB
Canon Layer - link - loss: 3.3217 - peak memory: 43199 MiB
Parallel Transformer Block - link - loss: 3.3473 - peak memory: 40302 MiB
Per Layer Token Embedding - link - loss: 3.2411 - peak memory: 40916 MiB
PolyNorm - link - best_model_loss: 3.2411 -> loss: 3.3017 - peak memory: 40895 MiB
PolyReLU - link - best_model_loss: 3.2411 -> loss: 3.2642 - peak memory: 40890 MiB
TOP loss - link - best_model_loss: 3.2411 -> loss: 3.2636 - peak memory: 47816 MiB
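The features above are plain config switches. As a rough illustration, here is a hypothetical sketch of such a config; the field names are assumptions for illustration only, not the repo's actual option names (see the BlaGPT config class in the code for those).

```python
# Hypothetical config sketch (field names are illustrative, not BlaGPT's real API).
from dataclasses import dataclass
from typing import Optional


@dataclass
class BlaGPTConfigSketch:
    n_layer: int = 12
    n_head: int = 12
    n_embd: int = 768

    # Toggles roughly mirroring the feature list above.
    multi_token_prediction: bool = False   # predict several future tokens per position
    tie_weights: bool = True               # weight tying between input and output embeddings
    n_kv_head: int = 4                     # < n_head enables grouped-query attention
    logit_cap: Optional[float] = 30.0      # soft-cap logits; None disables capping
    qkv_bias: bool = False                 # bias terms on the QKV projections
    zero_init_proj: bool = True            # zero-init the residual projection layers
    pre_and_post_rmsnorm: bool = True      # RMSNorm before and after each block
    rope_theta: float = 10_000.0           # RoPE base theta (e.g. 1_000_000 as in Llama 3)
    z_loss_weight: float = 0.0             # > 0 enables z-loss regularization
```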
MegaByte - link - loss: 3.810
FTP (heavily modified) - link - loss: 3.901
Rene - link - loss: 3.340
Rwkv7 - link - loss: 4.450
Zamba2 - link - Zamba2 > Rene > Rwkv7
Hourglass Transformer (modified) - link - Hourglass > MegaByte > FTP - loss: 3.710
Hymba - link - train step time is significantly slower than the transformers. Best validation loss so far: 4.7505
Tokenformer (in BlaGPT model) - link - loss: 3.390
LLaDa (dLLM) - link - val-loss: 8.6930, cross-entropy loss: 4.2891 (comparable to other models; estimated with llada_validation_cross_entropy.py)
Avey - link - loss: 3.323, peak memory: 51962 MiB (batch size 8), step_time: 2871ms (very slow to train and uses >3x more memory than other models)
LFM2 - link - TBD
Hourglass Transformer (modified) - link - val_loss: 1.0048, train_time: 2671049ms, step_avg: 524.76ms
AUNet - link - val_loss: 1.1502, train_time: 7246104ms, step_avg: 1423.60ms
SpaceByte - link - val_loss: 1.6755, train_time: 2154923ms, step_avg: 423.36ms, peak memory consumption: 27781 MiB
HNet - link - val_loss: 1.4554, train_time: 2207809ms, step_avg: 433.75ms, peak memory consumption: 23948 MiB
PaLMForeachSOAP - link - almost 2x slower than Adam but gives the best results
Ademamix - link - Unstable even after trying different learning rates.
Adopt - link - straight up NaN
CAdamW - link - loss: 3.3517
AdamW with independent weight decay - link - loss: 3.320
Adam - loss: 3.3224
AdamW - loss: 3.3310, peak VRAM: 42053 MiB, step_time: 533ms
DeMo - link - saves 7 GB per GPU, but the loss is higher than baseline and the step time is slower than Adam - loss: 3.4676, peak VRAM: 41534 MiB, step_time: 820ms
Adam-Mini - link - loss is higher than Adam and AdamW and it is also slower (?), though it saved a bit of VRAM - loss: 3.3324, peak VRAM: 41534 MiB, step_time: 610ms
MARS - link - loss: 3.3459, peak VRAM: 40953 MiB, step_time: 628ms
Muon - link - loss: 3.2923, peak VRAM: 40332 MB, step_time: 620.24ms
AdaMuon - link - Adaptive Muon with second-moment estimation (default optimizer) - See detailed explanation
BiClip - link - (not working well) loss: 7.2292, peak VRAM: 39751 MiB, step_time: 510ms
NorMuon - link - best_model_loss: 3.2411 -> loss: 3.4630, peak VRAM: 44154 MiB, step_time: 387.46ms - See detailed explanation
- Implement the model
- Return the loss in the forward function
- Add the model to model_registry.py
- And start training
See one of the implementations for details.
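For orientation, here is a minimal sketch of what such a model could look like. The exact return signature expected by the trainer and the registration mechanism in model_registry.py are assumptions here; mirror an existing model in the repo rather than this sketch.

```python
# Minimal sketch of a new model (assumed return signature and registration;
# copy the pattern of an existing model in the repo instead of this verbatim).
from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F


class TinyLM(nn.Module):
    """Toy LM: token embedding -> linear head, returning the loss from forward()."""

    def __init__(self, vocab_size: int = 50304, d_model: int = 256):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, d_model)
        self.head = nn.Linear(d_model, vocab_size, bias=False)

    def forward(self, idx: torch.Tensor, targets: Optional[torch.Tensor] = None):
        logits = self.head(self.embed(idx))  # (B, T, vocab)
        loss = None
        if targets is not None:
            loss = F.cross_entropy(
                logits.view(-1, logits.size(-1)), targets.view(-1)
            )
        # The trainer expects forward() to return the loss.
        return logits, loss


# Registration (hypothetical): add TinyLM to model_registry.py following
# whatever pattern the existing entries use, then pass its name to train.py.
```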
- Get the data by running data/fineweb10B_cached.py
- Start training with:
torchrun --standalone --nproc_per_node=8 train.py --run_name pre_post_norm --model_name blagpt
- (Optional) Run the learning rate finder before training:
torchrun --standalone --nproc_per_node=8 find_lr.py --model_name blagpt
# Output
Results:
Steepest gradient learning rate: 3.31e-06
Elbow point learning rate: 1.20e-01
Plot saved to: logs/lr_finder_blagpt/lr_finder_plot.png
Results saved to: logs/lr_finder_blagpt/lr_finder_results.pt
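For context, these two suggestions come from a learning-rate range test: the LR is swept while the loss is recorded, and candidate LRs are then read off the loss curve. The sketch below shows one common way to do that selection; it is a conceptual illustration under assumed heuristics, not the repository's find_lr.py implementation.

```python
# Conceptual LR-range-test selection (not the actual find_lr.py code).
import numpy as np


def suggest_lrs(lrs: np.ndarray, losses: np.ndarray, smooth: int = 5):
    # Moving-average smoothing of the recorded losses.
    kernel = np.ones(smooth) / smooth
    smoothed = np.convolve(losses, kernel, mode="valid")
    log_lrs = np.log10(lrs[: len(smoothed)])

    # (a) "Steepest gradient" LR: where the smoothed loss drops fastest
    #     with respect to log(lr).
    grads = np.gradient(smoothed, log_lrs)
    steepest_lr = 10 ** log_lrs[np.argmin(grads)]

    # (b) Crude "elbow": the largest LR whose smoothed loss is still within
    #     10% of the minimum, i.e. just before the curve bends upward.
    ok = np.where(smoothed <= 1.1 * smoothed.min())[0]
    elbow_lr = 10 ** log_lrs[ok[-1]] if len(ok) else steepest_lr

    return steepest_lr, elbow_lr
```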
- Check best_model_config.py for the best model configuration so far.
- You can run the training with the best model config by running:
torchrun --standalone --nproc_per_node=8 train.py --run_name best_model --model_name best
The initial code is based on:
NanoGPT - link
Modded NanoGPT - link
Thanks to @xumingyu2021 for the memory-friendly implementation of Differential Attention.