Add Peft support for SFT and Prompts model training
The original implementation simply uses loralib and merges the LoRA layers into the final model. Hugging Face PEFT is a better LoRA implementation, and models built with it are easy to train and distribute.
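To make the difference concrete, here is a minimal, framework-free sketch of the LoRA update itself. It is generic LoRA math, not code from this repository, loralib, or peft. A frozen weight W is augmented with a low-rank product B A scaled by alpha / r. Keeping A and B separate, as peft adapters typically do, means only the small adapter has to be saved and shipped. Merging folds the product back into W, which is what the original implementation does. Matrices are plain nested lists so the example needs no dependencies.

```python
from typing import List

Matrix = List[List[float]]


def matmul(a: Matrix, b: Matrix) -> Matrix:
    """Plain matrix product of an (n x k) and a (k x m) matrix."""
    return [
        [sum(a[i][k] * b[k][j] for k in range(len(b))) for j in range(len(b[0]))]
        for i in range(len(a))
    ]


def add(a: Matrix, b: Matrix) -> Matrix:
    """Element-wise sum of two equally shaped matrices."""
    return [[x + y for x, y in zip(ra, rb)] for ra, rb in zip(a, b)]


def scale(a: Matrix, s: float) -> Matrix:
    """Multiply every element by the scalar s."""
    return [[x * s for x in row] for row in a]


def lora_forward(W: Matrix, A: Matrix, B: Matrix, x: Matrix, alpha: float, r: int) -> Matrix:
    """Adapter kept separate: W x + (alpha / r) * B (A x).

    W: (d_out x d_in) frozen base weight
    A: (r x d_in) and B: (d_out x r) trainable low-rank factors
    x: (d_in x 1) input column vector
    """
    base = matmul(W, x)
    delta = scale(matmul(B, matmul(A, x)), alpha / r)
    return add(base, delta)


def merge_lora(W: Matrix, A: Matrix, B: Matrix, alpha: float, r: int) -> Matrix:
    """Fold the adapter into the base weight: W + (alpha / r) * B A."""
    return add(W, scale(matmul(B, A), alpha / r))


W = [[1.0, 0.0], [0.0, 1.0]]
A = [[1.0, 2.0]]
B = [[0.5], [1.0]]
x = [[1.0], [1.0]]

separate = lora_forward(W, A, B, x, alpha=2.0, r=1)
merged = matmul(merge_lora(W, A, B, alpha=2.0, r=1), x)
# separate == merged == [[4.0], [7.0]]
```

Both paths give the same output. The separate path adds only r * (d_in + d_out) trainable parameters per layer, while the merged path leaves a single dense weight with no adapter to distribute.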
Since the reward model is relatively small, I keep it as in the original implementation. I suggest training the full model to obtain a proper reward/critic model.
Preliminary installation
Since the current PyPI peft package (0.2) has some bugs, please install peft from source:
```shell
git clone https://github.com/huggingface/peft
cd peft
pip install .
```
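After installing, you can check which peft version Python actually picks up. This is a minimal sketch that uses only the standard library's `importlib.metadata`. The helper names are hypothetical, and the "0.2" check simply mirrors the warning above. The exact version string a source build reports (it may carry a `.dev` suffix) is an assumption.

```python
from importlib.metadata import PackageNotFoundError, version
from typing import Optional


def installed_version(package: str) -> Optional[str]:
    """Return the installed distribution version of `package`, or None if absent."""
    try:
        return version(package)
    except PackageNotFoundError:
        return None


def is_pypi_0_2(ver: Optional[str]) -> bool:
    """True if `ver` looks like a 0.2.x release, which this README warns against."""
    if ver is None:
        return False
    return ver.split(".")[:2] == ["0", "2"]


if __name__ == "__main__":
    ver = installed_version("peft")
    if ver is None:
        print("peft is not installed")
    elif is_pypi_0_2(ver):
        print(f"peft {ver} detected; consider reinstalling from source")
    else:
        print(f"peft {ver} detected")
```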
Usage
For SFT training, run train_peft_sft.py.
Its arguments are almost identical to those of train_sft.py, except for a new eval_dataset argument for when you have an evaluation dataset file. The data file is a plain text file; please check the expected format in easy_dataset.py.
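An optional evaluation file like this is commonly exposed as an argument that defaults to None, so evaluation is skipped when it is not given. The following is a minimal argparse sketch of that pattern. The flag names `--dataset` and `--eval_dataset` and the helper functions are assumptions, so check train_peft_sft.py for the real interface.

```python
import argparse
from typing import List, Optional


def build_parser() -> argparse.ArgumentParser:
    # Hypothetical subset of the SFT arguments; the names are assumptions.
    parser = argparse.ArgumentParser(description="PEFT SFT training (sketch)")
    parser.add_argument("--dataset", type=str, required=True,
                        help="path to the plain-text training data file")
    parser.add_argument("--eval_dataset", type=str, default=None,
                        help="optional path to an evaluation data file")
    return parser


def parse_args(argv: Optional[List[str]] = None) -> argparse.Namespace:
    return build_parser().parse_args(argv)


def should_evaluate(args: argparse.Namespace) -> bool:
    # Evaluation runs only when an eval file was supplied.
    return args.eval_dataset is not None
```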
For stage-3 RLHF training, run train_peft_prompts.py. Its arguments are almost identical to those of train_prompts.py; the only difference is that plain text files are used to specify the prompt data file and the pretraining data file. The models are defined in easy_models.py. Currently only BLOOM models have been tested, but GPT-2, OPT, and LLaMA should in principle also be supported.
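The support matrix above can be summarized as a lookup from a short model name to the matching Hugging Face transformers causal-LM class. This is a hypothetical sketch, not the contents of easy_models.py. The mapping, the `resolve_model` helper, and the "tested" flag (set only for BLOOM, per this README) are illustrative assumptions. Class names are kept as strings so the example runs without transformers installed.

```python
from typing import Dict, Tuple

# Hypothetical mapping: short name -> (transformers class name, tested per this README).
MODEL_CLASSES: Dict[str, Tuple[str, bool]] = {
    "bloom": ("BloomForCausalLM", True),
    "gpt2": ("GPT2LMHeadModel", False),
    "opt": ("OPTForCausalLM", False),
    "llama": ("LlamaForCausalLM", False),
}


def resolve_model(name: str) -> Tuple[str, bool]:
    """Return (class name, tested flag) for a model family, case-insensitively."""
    key = name.lower()
    if key not in MODEL_CLASSES:
        raise ValueError(f"unsupported model {name!r}; expected one of {sorted(MODEL_CLASSES)}")
    return MODEL_CLASSES[key]
```

With transformers installed, the returned name could be turned into a class with `getattr(transformers, cls_name)` and loaded via its `from_pretrained` method. Untested families would be worth a warning before a long training run.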
Data format
Please refer to the formats in test_sft.txt, test_prompts.txt, and test_pretrained.txt.