[Chat] fix readme (#5989)

* fix readme

* fix readme, tokenization fully tested

* fix readme, tokenization fully tested

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: root <root@notebook-8f919155-6035-47b4-9c6f-1be133b9e2c9-0.notebook-8f919155-6035-47b4-9c6f-1be133b9e2c9.colossal-ai.svc.cluster.local>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
pull/5995/head
YeAnbang 2024-08-12 14:55:17 +08:00 committed by GitHub
parent b4d2377d4c
commit ed97d3a5d3
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
7 changed files with 38 additions and 40 deletions

View File

@ -169,7 +169,7 @@ def tokenize_prompt(
template.messages = template.messages[:-1]
# Prepare data
prompt = template.get_prompt(length=len(template.messages) - 1, add_generation_prompt=True)
prompt = template.get_prompt(length=len(template.messages), add_generation_prompt=True)
tokenized = tokenizer([prompt], add_special_tokens=False)["input_ids"][0]
if tokenizer.bos_token_id is not None:

View File

@ -138,6 +138,7 @@ def disable_dropout(model: torch.nn.Module):
Returns:
None
"""
for module in model.modules():
if isinstance(module, torch.nn.Dropout):
module.p = 0.0
if model is not None:
for module in model.modules():
if isinstance(module, torch.nn.Dropout):
module.p = 0.0

View File

@ -462,26 +462,24 @@ Stage1 is supervised instructs fine-tuning (SFT). This step is a crucial part of
#### Step 1: Data Collection
The first step in Stage 1 is to collect a dataset of human demonstrations of the following format.
The first step in Stage 1 is to collect a dataset of human demonstrations of the following JSONL format.
```json
[
{"messages":
[
{
"from": "user",
"content": "what are some pranks with a pen i can do?"
},
{
"from": "assistant",
"content": "Are you looking for practical joke ideas?"
},
...
]
{"messages":
[
{
"from": "user",
"content": "what are some pranks with a pen i can do?"
},
{
"from": "assistant",
"content": "Are you looking for practical joke ideas?"
},
...
]
]
},
...
```

View File

@ -151,7 +151,6 @@ def main(args):
chat_io.prompt_for_output("assistant")
prompt = conv.get_prompt(add_generation_prompt=True)
print(prompt + "<end_of_prompt>")
input_ids = tokenizer(prompt, return_tensors="pt", add_special_tokens=False)["input_ids"].to(
torch.cuda.current_device()
)

View File

@ -502,7 +502,7 @@ if __name__ == "__main__":
parser.add_argument("--disable_loss_mask", default=False, action="store_true")
parser.add_argument("--max_length", type=int, default=2048)
parser.add_argument("--max_seq_len", type=int, default=256)
parser.add_argument("--log_dir", default="logs", type=str)
parser.add_argument("--log_dir", default=None, type=str)
parser.add_argument("--use_wandb", default=False, action="store_true")
parser.add_argument("--grad_checkpoint", default=False, action="store_true")
parser.add_argument("--use_flash_attn", default=False, action="store_true")

View File

@ -2,7 +2,7 @@ transformers==4.39.3
tqdm
datasets==2.14.7
loralib
colossalai==0.4.0
colossalai>=0.4.0
torch>=2.1.0
langchain
tokenizers

View File

@ -15,7 +15,7 @@ set_n_least_used_CUDA_VISIBLE_DEVICES() {
echo "CUDA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES"
}
set_n_least_used_CUDA_VISIBLE_DEVICES 4
set_n_least_used_CUDA_VISIBLE_DEVICES 2
set -xu
@ -119,11 +119,11 @@ for lora_rank in ${LORA_RANK[@]}; do
lora_config=""
fi
if [[ $plugin == "3d" ]]; then
tp='4'
tp='2'
bs='8'
fi
if [[ $plugin == "tp_zero2" ]]; then
tp='4'
tp='2'
bs='8'
zero_stage='2'
plugin='3d'
@ -136,13 +136,13 @@ for lora_rank in ${LORA_RANK[@]}; do
fi
if [[ $plugin == "pp" ]]; then
bs='8'
pp='4'
pp='2'
plugin='3d'
fi
if [[ $plugin == "sp_split_gather" ]]; then
enable_sequence_parallelism='--enable_sequence_parallelism'
sp_mode='split_gather'
tp='4'
tp='2'
sp='1'
bs='8'
plugin='3d'
@ -150,7 +150,7 @@ for lora_rank in ${LORA_RANK[@]}; do
if [[ $plugin == "sp_ring" ]]; then
enable_sequence_parallelism='--enable_sequence_parallelism'
sp_mode='ring'
tp='4'
tp='2'
sp='1'
bs='8'
plugin='3d'
@ -159,7 +159,7 @@ for lora_rank in ${LORA_RANK[@]}; do
enable_sequence_parallelism='--enable_sequence_parallelism'
sp_mode='all_to_all'
tp='1'
sp='4'
sp='2'
bs='8'
plugin='3d'
fi
@ -175,7 +175,7 @@ for lora_rank in ${LORA_RANK[@]}; do
for split in $(seq -f "%05g" 0 0); do
dataset+=("$TEMP_DIR/rlhf_data/tokenized_${model}_sft/arrow/part-$split")
done
colossalai run --nproc_per_node 4 --master_port 31332 $EXAMPLES_DIR/training_scripts/train_sft.py \
colossalai run --nproc_per_node 2 --master_port 31332 $EXAMPLES_DIR/training_scripts/train_sft.py \
--pretrain $pretrain \
--tokenizer_dir $tokenizer_dir \
--dataset ${dataset[@]} \
@ -242,7 +242,7 @@ for lora_rank in ${LORA_RANK[@]}; do
lora_config=""
fi
if [[ $plugin == "3d" ]]; then
tp='4'
tp='2'
bs='8'
fi
grad_accu='2'
@ -256,7 +256,7 @@ for lora_rank in ${LORA_RANK[@]}; do
for split in $(seq -f "%05g" 0 0); do
dataset+=("$TEMP_DIR/rlhf_data/tokenized_${model}_preference/arrow/part-$split")
done
colossalai run --nproc_per_node 4 --master_port 31332 $EXAMPLES_DIR/training_scripts/train_rm.py \
colossalai run --nproc_per_node 2 --master_port 31332 $EXAMPLES_DIR/training_scripts/train_rm.py \
--pretrain $pretrain \
--tokenizer_dir $tokenizer_dir \
--dataset ${dataset[@]} \
@ -325,7 +325,7 @@ for lora_rank in ${LORA_RANK[@]}; do
lora_config=""
fi
if [[ $plugin == "3d" ]]; then
tp='4'
tp='2'
bs='16'
ebs='32'
fi
@ -350,7 +350,7 @@ for lora_rank in ${LORA_RANK[@]}; do
for split in $(seq -f "%05g" 0 0); do
ptx_dataset+=("$TEMP_DIR/rlhf_data/tokenized_${model}_sft/arrow/part-$split")
done
colossalai run --nproc_per_node 4 --master_port 31332 $EXAMPLES_DIR/training_scripts/train_ppo.py \
colossalai run --nproc_per_node 2 --master_port 31332 $EXAMPLES_DIR/training_scripts/train_ppo.py \
--pretrain $pretrain \
--rm_pretrain $pretrain \
--tokenizer_dir $tokenizer_dir \
@ -417,7 +417,7 @@ for lora_rank in ${LORA_RANK[@]}; do
tp='1'
bs='2'
if [[ $plugin == "3d" ]]; then
tp='4'
tp='2'
bs='8'
fi
if [[ $plugin == "zero2" ]]; then
@ -442,7 +442,7 @@ for lora_rank in ${LORA_RANK[@]}; do
for split in $(seq -f "%05g" 0 0); do
dataset+=("$TEMP_DIR/rlhf_data/tokenized_${model}_preference/arrow/part-$split")
done
colossalai run --nproc_per_node 4 --master_port 31332 $EXAMPLES_DIR/training_scripts/train_dpo.py \
colossalai run --nproc_per_node 2 --master_port 31332 $EXAMPLES_DIR/training_scripts/train_dpo.py \
--pretrain $pretrain \
--tokenizer_dir $tokenizer_dir \
--dataset ${dataset[@]} \
@ -500,7 +500,7 @@ for lora_rank in ${LORA_RANK[@]}; do
tp='1'
bs='2'
if [[ $plugin == "3d" ]]; then
tp='4'
tp='2'
bs='8'
fi
if [[ $plugin == "zero2" ]]; then
@ -525,7 +525,7 @@ for lora_rank in ${LORA_RANK[@]}; do
for split in $(seq -f "%05g" 0 0); do
dataset+=("$TEMP_DIR/rlhf_data/tokenized_${model}_preference/arrow/part-$split")
done
colossalai run --nproc_per_node 4 --master_port 31332 $EXAMPLES_DIR/training_scripts/train_orpo.py \
colossalai run --nproc_per_node 2 --master_port 31332 $EXAMPLES_DIR/training_scripts/train_orpo.py \
--pretrain $pretrain \
--tokenizer_dir $tokenizer_dir \
--dataset ${dataset[@]} \
@ -583,7 +583,7 @@ for lora_rank in ${LORA_RANK[@]}; do
tp='1'
bs='2'
if [[ $plugin == "3d" ]]; then
tp='4'
tp='2'
bs='8'
fi
if [[ $plugin == "zero2" ]]; then
@ -608,7 +608,7 @@ for lora_rank in ${LORA_RANK[@]}; do
for split in $(seq -f "%05g" 0 0); do
dataset+=("$TEMP_DIR/rlhf_data/tokenized_${model}_kto/arrow/part-$split")
done
colossalai run --nproc_per_node 4 --master_port 31332 $EXAMPLES_DIR/training_scripts/train_kto.py \
colossalai run --nproc_per_node 2 --master_port 31332 $EXAMPLES_DIR/training_scripts/train_kto.py \
--pretrain $pretrain \
--tokenizer_dir $tokenizer_dir \
--dataset ${dataset[@]} \