PaLM - Pytorch
Implementation of the specific Transformer architecture from PaLM: Scaling Language Modeling with Pathways, in fewer than 200 lines of code.
The PaLM paper reports state-of-the-art results across a wide range of language tasks.
This implementation obviously will not scale to that size; it is meant for educational purposes, to show how simple the architecture really is.
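According to the paper, PaLM's blocks differ from a vanilla Transformer mainly in three ways: a "parallel" formulation (attention and feed-forward both read the same normalized input and their outputs are summed into the residual), multi-query attention (many query heads, one shared key/value head), and SwiGLU activations. The following is a minimal, simplified sketch of such a block in plain PyTorch. `ParallelBlockSketch` is a hypothetical name, not the class used in `palm_pytorch`, and the sketch leaves out rotary position embeddings and uses a standard `nn.LayerNorm` for brevity.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class ParallelBlockSketch(nn.Module):
    """Simplified PaLM-style block (hypothetical sketch, no rotary embeddings)."""

    def __init__(self, dim, heads=8, dim_head=64, ff_mult=4):
        super().__init__()
        self.heads, self.dim_head = heads, dim_head
        inner, hidden = heads * dim_head, dim * ff_mult
        self.norm = nn.LayerNorm(dim)  # one norm shared by both branches
        # multi-query attention: `heads` query heads, a single key/value head
        self.to_q = nn.Linear(dim, inner, bias=False)
        self.to_kv = nn.Linear(dim, 2 * dim_head, bias=False)
        self.attn_out = nn.Linear(inner, dim, bias=False)
        # SwiGLU feed-forward: project to (value, gate), combine with SiLU
        self.ff_in = nn.Linear(dim, 2 * hidden, bias=False)
        self.ff_out = nn.Linear(hidden, dim, bias=False)

    def forward(self, x):
        b, n, _ = x.shape
        h = self.norm(x)

        # queries: (b, heads, n, dim_head); keys/values: (b, 1, n, dim_head)
        q = self.to_q(h).view(b, n, self.heads, self.dim_head).transpose(1, 2)
        k, v = self.to_kv(h).unsqueeze(1).chunk(2, dim=-1)
        scores = q @ k.transpose(-1, -2) / self.dim_head ** 0.5  # (b, heads, n, n)
        # causal mask: position i may only attend to positions <= i
        causal = torch.triu(torch.ones(n, n, device=x.device), diagonal=1).bool()
        scores = scores.masked_fill(causal, float("-inf"))
        attn = scores.softmax(dim=-1) @ v  # single k/v head broadcasts over heads
        attn = self.attn_out(attn.transpose(1, 2).reshape(b, n, -1))

        val, gate = self.ff_in(h).chunk(2, dim=-1)
        ff = self.ff_out(val * F.silu(gate))

        # parallel formulation: both branches are added to the residual at once
        return x + attn + ff
```

A full model would stack `depth` such blocks between a token embedding and a final norm plus output projection.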
Install
```bash
$ pip install PaLM-pytorch
```
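Usage
```python
import torch
from palm_pytorch import PaLM

palm = PaLM(
    num_tokens = 20000,  # vocabulary size
    dim = 512,           # model (embedding) dimension
    depth = 12,          # number of transformer blocks
    heads = 8,           # number of attention heads
    dim_head = 64,       # dimension of each attention head
)

tokens = torch.randint(0, 20000, (1, 2048))  # (batch, sequence length) token ids
logits = palm(tokens) # (1, 2048, 20000)
```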
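The logits give, at every position, a score for each token in the vocabulary. For language-model training they are typically compared against the same sequence shifted by one position. The sketch below shows this standard next-token objective; `next_token_loss` is a hypothetical helper, not part of `palm_pytorch`, and the repository's training script may wrap the loss differently.

```python
import torch
import torch.nn.functional as F


def next_token_loss(model, tokens):
    """Autoregressive loss: predict token t + 1 from tokens 0..t.

    `model` is any callable mapping (batch, seq) token ids to
    (batch, seq, num_tokens) logits, such as a PaLM instance.
    """
    inputs, targets = tokens[:, :-1], tokens[:, 1:]  # shift by one position
    logits = model(inputs)                           # (batch, seq - 1, num_tokens)
    # cross_entropy expects the class dimension second: (batch, num_tokens, seq - 1)
    return F.cross_entropy(logits.transpose(1, 2), targets)
```

With the model from the usage example, `next_token_loss(palm, tokens).backward()` would compute the gradients for one training step.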
The 540B-parameter PaLM model from the paper would be configured as:
```python
palm = PaLM(
    num_tokens = 256000,
    dim = 18432,
    depth = 118,
    heads = 48,
    dim_head = 256,
)
```
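As a sanity check on this configuration, the parameter count can be estimated by hand. The sketch below assumes the architecture described in the paper: multi-query attention (one shared key/value head), a SwiGLU feed-forward with hidden size `4 * dim`, no bias terms, and input and output embeddings that share weights. `estimate_palm_params` is a hypothetical helper. Because it ignores normalization parameters, it will not exactly match `sum(p.numel() for p in palm.parameters())`.

```python
def estimate_palm_params(num_tokens, dim, depth, heads, dim_head, ff_mult=4):
    """Rough weight count for a PaLM-style decoder (hypothetical helper).

    Assumes multi-query attention, a SwiGLU feed-forward of width ff_mult * dim,
    no biases, and tied input/output embeddings. Norm parameters are ignored.
    """
    inner = heads * dim_head
    hidden = ff_mult * dim
    attn = dim * inner              # query projection
    attn += 2 * dim * dim_head      # single key head + single value head
    attn += inner * dim             # attention output projection
    ff = dim * 2 * hidden + hidden * dim  # SwiGLU input (value + gate) and output
    embedding = num_tokens * dim          # shared with the output layer
    return depth * (attn + ff) + embedding


small = estimate_palm_params(num_tokens=20000, dim=512, depth=12, heads=8, dim_head=64)
large = estimate_palm_params(num_tokens=256000, dim=18432, depth=118, heads=48, dim_head=256)
print(f"{small / 1e6:.1f}M")  # 55.1M
print(f"{large / 1e9:.1f}B")  # 540.4B
```

Under these assumptions the large configuration comes out at roughly 540 billion weights, consistent with the model's name, and almost 90% of each block's weights sit in the feed-forward layers.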
Test on Enwik8
```bash
$ python train.py
```
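Enwik8 is the first 10^8 bytes of an English Wikipedia XML dump and is usually modeled at the byte level. The sketch below shows one common way to turn it into training windows. It assumes byte-level tokens (so `num_tokens = 256`) and a simple 90/10 train/validation split. `split_bytes` and `byte_windows` are hypothetical helpers, and `train.py` may prepare the data differently.

```python
import random


def split_bytes(data: bytes, train_fraction: float = 0.9):
    """Split raw bytes into contiguous train and validation segments."""
    cut = int(len(data) * train_fraction)
    return data[:cut], data[cut:]


def byte_windows(data: bytes, seq_len: int, batch_size: int, rng: random.Random):
    """Sample random windows of seq_len + 1 bytes for next-byte prediction.

    Each byte value (0-255) is used directly as a token id; the extra byte
    provides the target for the last input position.
    """
    starts = [rng.randrange(0, len(data) - seq_len) for _ in range(batch_size)]
    return [list(data[s : s + seq_len + 1]) for s in starts]


# Stand-in for the enwik8 contents, e.g. open(<path to enwik8>, "rb").read()
sample = bytes(range(256)) * 8
train, val = split_bytes(sample)
batch = byte_windows(train, seq_len=16, batch_size=4, rng=random.Random(0))
```

Each window can then be converted to a tensor and passed to a next-token loss.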
Todo
- offer a Triton optimized version of PaLM, bringing in https://github.com/lucidrains/triton-transformer
Citations
```bibtex
@article{chowdhery2022PaLM,
    title   = {PaLM: Scaling Language Modeling with Pathways},
    author  = {Chowdhery, Aakanksha and others},
    year    = {2022}
}
```