Merge branch 'hpcaitech:main' into coati/support-pp

pull/5994/head
Tong Li 2024-08-13 11:59:53 +08:00 committed by GitHub
commit 8806efd047
8 changed files with 51 additions and 52 deletions

View File

@@ -169,7 +169,7 @@ def tokenize_prompt(
     template.messages = template.messages[:-1]
     # Prepare data
-    prompt = template.get_prompt(length=len(template.messages) - 1, add_generation_prompt=True)
+    prompt = template.get_prompt(length=len(template.messages), add_generation_prompt=True)
     tokenized = tokenizer([prompt], add_special_tokens=False)["input_ids"][0]
     if tokenizer.bos_token_id is not None:
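The change above drops the `- 1`, so the prompt is rendered from every message that remains after the final turn has already been stripped. The coati `template.get_prompt` call is repo-specific; as a rough illustration of what `add_generation_prompt=True` means, here is a minimal sketch using Hugging Face's `apply_chat_template` (the checkpoint name is only an example, not something used by this repo).

```python
# Illustration only: Hugging Face's apply_chat_template, used as a stand-in for
# coati's template.get_prompt, to show the effect of add_generation_prompt=True.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("HuggingFaceH4/zephyr-7b-beta")  # example chat model

messages = [
    {"role": "user", "content": "what are some pranks with a pen i can do?"},
    {"role": "assistant", "content": "Are you looking for practical joke ideas?"},
    {"role": "user", "content": "Yes, something harmless please."},
]

# add_generation_prompt=True appends the assistant header at the end,
# cueing the model to generate the next assistant reply.
prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
print(prompt)
```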

View File

@@ -138,6 +138,7 @@ def disable_dropout(model: torch.nn.Module):
     Returns:
         None
     """
-    for module in model.modules():
-        if isinstance(module, torch.nn.Dropout):
-            module.p = 0.0
+    if model is not None:
+        for module in model.modules():
+            if isinstance(module, torch.nn.Dropout):
+                module.p = 0.0
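For reference, a self-contained sketch of the guarded helper shown above. The quick check confirms every `torch.nn.Dropout` ends up with `p=0.0`, and that passing `None` (e.g. when no reference model is used) no longer raises, which the unguarded version did.

```python
# A minimal, runnable sketch of the guarded disable_dropout shown in the diff above.
import torch

def disable_dropout(model: torch.nn.Module):
    """Set p=0.0 on every Dropout layer; do nothing when model is None."""
    if model is not None:
        for module in model.modules():
            if isinstance(module, torch.nn.Dropout):
                module.p = 0.0

net = torch.nn.Sequential(torch.nn.Linear(8, 8), torch.nn.Dropout(p=0.1))
disable_dropout(net)
assert all(m.p == 0.0 for m in net.modules() if isinstance(m, torch.nn.Dropout))
disable_dropout(None)  # guarded: no AttributeError when the model is absent
```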

View File

@@ -462,26 +462,24 @@ Stage1 is supervised instructs fine-tuning (SFT). This step is a crucial part of
 #### Step 1: Data Collection
-The first step in Stage 1 is to collect a dataset of human demonstrations of the following format.
+The first step in Stage 1 is to collect a dataset of human demonstrations of the following JSONL format.
 ```json
-[
-    {"messages":
-      [
-        {
-          "from": "user",
-          "content": "what are some pranks with a pen i can do?"
-        },
-        {
-          "from": "assistant",
-          "content": "Are you looking for practical joke ideas?"
-        },
-        ...
-      ]
-    },
-    ...
-]
+{"messages":
+  [
+    {
+      "from": "user",
+      "content": "what are some pranks with a pen i can do?"
+    },
+    {
+      "from": "assistant",
+      "content": "Are you looking for practical joke ideas?"
+    },
+    ...
+  ]
+},
+...
 ```
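To make the new format concrete, here is a short sketch (not part of the repo; the file name is made up) that writes and reads conversations as JSON Lines, one `{"messages": [...]}` object per line:

```python
# Hypothetical example of producing and consuming the JSONL conversation format above.
import json

conversations = [
    {"messages": [
        {"from": "user", "content": "what are some pranks with a pen i can do?"},
        {"from": "assistant", "content": "Are you looking for practical joke ideas?"},
    ]},
]

with open("sft_demo.jsonl", "w", encoding="utf-8") as f:
    for sample in conversations:
        f.write(json.dumps(sample, ensure_ascii=False) + "\n")  # one JSON object per line

with open("sft_demo.jsonl", encoding="utf-8") as f:
    for line in f:
        sample = json.loads(line)
        assert sample["messages"][0]["from"] == "user"
```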

View File

@@ -151,7 +151,6 @@ def main(args):
         chat_io.prompt_for_output("assistant")
         prompt = conv.get_prompt(add_generation_prompt=True)
-        print(prompt + "<end_of_prompt>")
         input_ids = tokenizer(prompt, return_tensors="pt", add_special_tokens=False)["input_ids"].to(
             torch.cuda.current_device()
         )

View File

@@ -502,7 +502,7 @@ if __name__ == "__main__":
     parser.add_argument("--disable_loss_mask", default=False, action="store_true")
     parser.add_argument("--max_length", type=int, default=2048)
     parser.add_argument("--max_seq_len", type=int, default=256)
-    parser.add_argument("--log_dir", default="logs", type=str)
+    parser.add_argument("--log_dir", default=None, type=str)
    parser.add_argument("--use_wandb", default=False, action="store_true")
    parser.add_argument("--grad_checkpoint", default=False, action="store_true")
    parser.add_argument("--use_flash_attn", default=False, action="store_true")

View File

@@ -2,7 +2,7 @@ transformers==4.39.3
 tqdm
 datasets==2.14.7
 loralib
-colossalai==0.4.0
+colossalai>=0.4.0
 torch>=2.1.0
 langchain
 tokenizers

View File

@@ -15,7 +15,7 @@ set_n_least_used_CUDA_VISIBLE_DEVICES() {
   echo "CUDA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES"
 }
-set_n_least_used_CUDA_VISIBLE_DEVICES 4
+set_n_least_used_CUDA_VISIBLE_DEVICES 2
 set -xu
@@ -119,11 +119,11 @@ for lora_rank in ${LORA_RANK[@]}; do
      lora_config=""
    fi
    if [[ $plugin == "3d" ]]; then
-      tp='4'
+      tp='2'
      bs='8'
    fi
    if [[ $plugin == "tp_zero2" ]]; then
-      tp='4'
+      tp='2'
      bs='8'
      zero_stage='2'
      plugin='3d'
@@ -136,13 +136,13 @@ for lora_rank in ${LORA_RANK[@]}; do
    fi
    if [[ $plugin == "pp" ]]; then
      bs='8'
-      pp='4'
+      pp='2'
      plugin='3d'
    fi
    if [[ $plugin == "sp_split_gather" ]]; then
      enable_sequence_parallelism='--enable_sequence_parallelism'
      sp_mode='split_gather'
-      tp='4'
+      tp='2'
      sp='1'
      bs='8'
      plugin='3d'
@@ -150,7 +150,7 @@ for lora_rank in ${LORA_RANK[@]}; do
    if [[ $plugin == "sp_ring" ]]; then
      enable_sequence_parallelism='--enable_sequence_parallelism'
      sp_mode='ring'
-      tp='4'
+      tp='2'
      sp='1'
      bs='8'
      plugin='3d'
@@ -159,7 +159,7 @@ for lora_rank in ${LORA_RANK[@]}; do
      enable_sequence_parallelism='--enable_sequence_parallelism'
      sp_mode='all_to_all'
      tp='1'
-      sp='4'
+      sp='2'
      bs='8'
      plugin='3d'
    fi
@@ -175,7 +175,7 @@ for lora_rank in ${LORA_RANK[@]}; do
    for split in $(seq -f "%05g" 0 0); do
      dataset+=("$TEMP_DIR/rlhf_data/tokenized_${model}_sft/arrow/part-$split")
    done
-    colossalai run --nproc_per_node 4 --master_port 31332 $EXAMPLES_DIR/training_scripts/train_sft.py \
+    colossalai run --nproc_per_node 2 --master_port 31332 $EXAMPLES_DIR/training_scripts/train_sft.py \
      --pretrain $pretrain \
      --tokenizer_dir $tokenizer_dir \
      --dataset ${dataset[@]} \
@@ -242,7 +242,7 @@ for lora_rank in ${LORA_RANK[@]}; do
      lora_config=""
    fi
    if [[ $plugin == "3d" ]]; then
-      tp='4'
+      tp='2'
      bs='8'
    fi
    grad_accu='2'
@@ -256,7 +256,7 @@ for lora_rank in ${LORA_RANK[@]}; do
    for split in $(seq -f "%05g" 0 0); do
      dataset+=("$TEMP_DIR/rlhf_data/tokenized_${model}_preference/arrow/part-$split")
    done
-    colossalai run --nproc_per_node 4 --master_port 31332 $EXAMPLES_DIR/training_scripts/train_rm.py \
+    colossalai run --nproc_per_node 2 --master_port 31332 $EXAMPLES_DIR/training_scripts/train_rm.py \
      --pretrain $pretrain \
      --tokenizer_dir $tokenizer_dir \
      --dataset ${dataset[@]} \
@@ -325,7 +325,7 @@ for lora_rank in ${LORA_RANK[@]}; do
      lora_config=""
    fi
    if [[ $plugin == "3d" ]]; then
-      tp='4'
+      tp='2'
      bs='16'
      ebs='32'
    fi
@@ -350,7 +350,7 @@ for lora_rank in ${LORA_RANK[@]}; do
    for split in $(seq -f "%05g" 0 0); do
      ptx_dataset+=("$TEMP_DIR/rlhf_data/tokenized_${model}_sft/arrow/part-$split")
    done
-    colossalai run --nproc_per_node 4 --master_port 31332 $EXAMPLES_DIR/training_scripts/train_ppo.py \
+    colossalai run --nproc_per_node 2 --master_port 31332 $EXAMPLES_DIR/training_scripts/train_ppo.py \
      --pretrain $pretrain \
      --rm_pretrain $pretrain \
      --tokenizer_dir $tokenizer_dir \
@@ -417,7 +417,7 @@ for lora_rank in ${LORA_RANK[@]}; do
    tp='1'
    bs='2'
    if [[ $plugin == "3d" ]]; then
-      tp='4'
+      tp='2'
      bs='8'
    fi
    if [[ $plugin == "zero2" ]]; then
@@ -442,7 +442,7 @@ for lora_rank in ${LORA_RANK[@]}; do
    for split in $(seq -f "%05g" 0 0); do
      dataset+=("$TEMP_DIR/rlhf_data/tokenized_${model}_preference/arrow/part-$split")
    done
-    colossalai run --nproc_per_node 4 --master_port 31332 $EXAMPLES_DIR/training_scripts/train_dpo.py \
+    colossalai run --nproc_per_node 2 --master_port 31332 $EXAMPLES_DIR/training_scripts/train_dpo.py \
      --pretrain $pretrain \
      --tokenizer_dir $tokenizer_dir \
      --dataset ${dataset[@]} \
@@ -500,7 +500,7 @@ for lora_rank in ${LORA_RANK[@]}; do
    tp='1'
    bs='2'
    if [[ $plugin == "3d" ]]; then
-      tp='4'
+      tp='2'
      bs='8'
    fi
    if [[ $plugin == "zero2" ]]; then
@@ -525,7 +525,7 @@ for lora_rank in ${LORA_RANK[@]}; do
    for split in $(seq -f "%05g" 0 0); do
      dataset+=("$TEMP_DIR/rlhf_data/tokenized_${model}_preference/arrow/part-$split")
    done
-    colossalai run --nproc_per_node 4 --master_port 31332 $EXAMPLES_DIR/training_scripts/train_orpo.py \
+    colossalai run --nproc_per_node 2 --master_port 31332 $EXAMPLES_DIR/training_scripts/train_orpo.py \
      --pretrain $pretrain \
      --tokenizer_dir $tokenizer_dir \
      --dataset ${dataset[@]} \
@@ -583,7 +583,7 @@ for lora_rank in ${LORA_RANK[@]}; do
    tp='1'
    bs='2'
    if [[ $plugin == "3d" ]]; then
-      tp='4'
+      tp='2'
      bs='8'
    fi
    if [[ $plugin == "zero2" ]]; then
@@ -608,7 +608,7 @@ for lora_rank in ${LORA_RANK[@]}; do
    for split in $(seq -f "%05g" 0 0); do
      dataset+=("$TEMP_DIR/rlhf_data/tokenized_${model}_kto/arrow/part-$split")
    done
-    colossalai run --nproc_per_node 4 --master_port 31332 $EXAMPLES_DIR/training_scripts/train_kto.py \
+    colossalai run --nproc_per_node 2 --master_port 31332 $EXAMPLES_DIR/training_scripts/train_kto.py \
      --pretrain $pretrain \
      --tokenizer_dir $tokenizer_dir \
      --dataset ${dataset[@]} \
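The parallel sizes in this test script shrink from 4 to 2 together with `--nproc_per_node`, since the tensor-, pipeline- and sequence-parallel groups have to fit into the launched world size. The helper below is not from the repo; it is only a rough consistency check under that assumption (the exact group layout depends on the plugin and the sp mode).

```python
# Hypothetical sanity check: the parallel grid must divide the launched world size.
def check_parallel_config(nproc_per_node: int, tp: int = 1, pp: int = 1, sp: int = 1) -> int:
    """Return the implied data-parallel size, or raise if the grid does not fit."""
    grid = tp * pp * sp
    if nproc_per_node % grid != 0:
        raise ValueError(f"tp*pp*sp={grid} does not divide world size {nproc_per_node}")
    return nproc_per_node // grid

print(check_parallel_config(2, tp=2))        # 3d / tp_zero2 cases: dp=1
print(check_parallel_config(2, pp=2))        # pp case: dp=1
print(check_parallel_config(2, tp=1, sp=2))  # all_to_all sequence-parallel case: dp=1
# check_parallel_config(2, tp=4) would raise: the old tp=4 cannot fit on 2 GPUs
```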

View File

@@ -42,7 +42,7 @@ try:
         return output
 except ImportError:
-    warnings.warn("Please install apex from source (https://github.com/NVIDIA/apex) to use the fused layernorm kernel")
+    warnings.warn("Please install apex from source (https://github.com/NVIDIA/apex) to use the fused RMSNorm kernel")
 FAST_LAYERNORM_SUPPORTED_SIZE = [
     1024,
@@ -270,12 +270,6 @@ class FusedRMSNorm(BaseLayerNorm):
         Returns:
             nn.Module: FusedRMSNorm module.
         """
-        try:
-            pass
-        except ImportError:
-            raise ImportError(
-                "Please install apex from source (https://github.com/NVIDIA/apex) to use the fused RMS normalization kernel"
-            )
         LazyInitContext.materialize(module)
@@ -284,11 +278,18 @@ class FusedRMSNorm(BaseLayerNorm):
         eps = module.variance_epsilon if hasattr(module, "variance_epsilon") else module.eps
         elementwise_affine = getattr(module, "elementwise_affine", True)
-        rmsnorm = FusedRMSNormWithHook(
-            normalized_shape=normalized_shape,
-            eps=eps,
-            elementwise_affine=elementwise_affine,
-        )
+        try:
+            rmsnorm = FusedRMSNormWithHook(
+                normalized_shape=normalized_shape,
+                eps=eps,
+                elementwise_affine=elementwise_affine,
+            )
+        except ImportError:
+            warnings.warn(
+                "Module replacement failed.\
+                Please install apex from source (https://github.com/NVIDIA/apex) to use the fused RMS normalization kernel"
+            )
+            return module
         rmsnorm.weight = module.weight
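The last hunk replaces the hard `ImportError` with a warn-and-fall-back path. Below is a simplified, self-contained sketch of that pattern, not the shardformer code; the `apex.normalization.FusedRMSNorm` import and the `LayerNorm` placeholder are only for illustration.

```python
# Sketch of the fallback pattern: try to build the fused module, and if apex is
# missing, warn and return the original module instead of raising.
import warnings
import torch

def replace_with_fused_rmsnorm(module: torch.nn.Module) -> torch.nn.Module:
    try:
        from apex.normalization import FusedRMSNorm  # raises ImportError without apex
        fused = FusedRMSNorm(module.weight.shape, eps=1e-6)
        fused.weight = module.weight
        return fused
    except ImportError:
        warnings.warn(
            "Module replacement failed. Please install apex from source "
            "(https://github.com/NVIDIA/apex) to use the fused RMS normalization kernel"
        )
        return module  # graceful fallback: keep the original, unfused layer

layer = torch.nn.LayerNorm(16)             # stand-in module that owns a .weight
layer = replace_with_fused_rmsnorm(layer)  # without apex: warning + original layer returned
```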