From e3551443751e5ff1ced4215afe76fb8e22ded06b Mon Sep 17 00:00:00 2001
From: csric <59389055+CsRic@users.noreply.github.com>
Date: Mon, 17 Apr 2023 14:46:50 +0800
Subject: [PATCH] [chatgpt] Detached PPO Training (#3195)

* run the base
* working on dist ppo
* sync
* detached trainer
* update detached trainer. no maker update function
* facing init problem
* 1 maker 1 trainer detached run. but no model update
* facing cuda problem
* fix save functions
* verified maker update
* nothing
* add ignore
* analyze loss issue
* remove some debug codes
* facing 2m1t stuck issue
* 2m1t verified
* do not use torchrun
* working on 2m2t
* working on 2m2t
* initialize strategy in ray actor env
* facing actor's init order issue
* facing ddp model update issue (need to unwrap ddp)
* unwrap ddp actor
* checking 1m2t stuck problem
* nothing
* set timeout for trainer choosing. It solves the stuck problem!
* delete some debug output
* rename to sync with upstream
* rename to sync with upstream
* coati rename
* nothing
* I am going to detach the replay buffer from the trainer and make it a Ray Actor. Two benefits: 1. support TP trainer. 2. asynchronous buffer operations
* experience_maker_holder performs target-revolving _send_experience() instead of length comparison.
* move code to ray subfolder
* working on pipeline inference
* apply comments

---------

Co-authored-by: csric
---
 applications/Chat/.gitignore                  |   2 +
 applications/Chat/coati/ray/__init__.py       |   2 +
 applications/Chat/coati/ray/example/1m1t.py   | 153 +++++++++++++
 applications/Chat/coati/ray/example/1m1t.sh   |  23 ++
 applications/Chat/coati/ray/example/1m2t.py   | 186 ++++++++++++++++
 applications/Chat/coati/ray/example/1m2t.sh   |  23 ++
 applications/Chat/coati/ray/example/2m1t.py   | 140 ++++++++++++
 applications/Chat/coati/ray/example/2m1t.sh   |  23 ++
 applications/Chat/coati/ray/example/2m2t.py   | 209 ++++++++++++++++++
 applications/Chat/coati/ray/example/2m2t.sh   |  23 ++
 applications/Chat/coati/ray/src/__init__.py   |   0
 .../coati/ray/src/detached_replay_buffer.py   |  88 ++++++++
 .../coati/ray/src/detached_trainer_base.py    | 121 ++++++++++
 .../coati/ray/src/detached_trainer_ppo.py     | 192 ++++++++++++++++
 .../coati/ray/src/experience_maker_holder.py  | 172 ++++++++++++++
 .../Chat/coati/ray/src/pipeline_strategy.py   | 105 +++++++++
 applications/Chat/coati/ray/src/utils.py      |  48 ++++
 applications/Chat/coati/trainer/utils.py      |   9 +
 applications/Chat/coati/utils/__init__.py     |   2 +-
 applications/Chat/examples/train_prompts.sh   |   2 +
 20 files changed, 1522 insertions(+), 1 deletion(-)
 create mode 100644 applications/Chat/coati/ray/__init__.py
 create mode 100644 applications/Chat/coati/ray/example/1m1t.py
 create mode 100644 applications/Chat/coati/ray/example/1m1t.sh
 create mode 100644 applications/Chat/coati/ray/example/1m2t.py
 create mode 100644 applications/Chat/coati/ray/example/1m2t.sh
 create mode 100644 applications/Chat/coati/ray/example/2m1t.py
 create mode 100644 applications/Chat/coati/ray/example/2m1t.sh
 create mode 100644 applications/Chat/coati/ray/example/2m2t.py
 create mode 100644 applications/Chat/coati/ray/example/2m2t.sh
 create mode 100644 applications/Chat/coati/ray/src/__init__.py
 create mode 100644 applications/Chat/coati/ray/src/detached_replay_buffer.py
 create mode 100644 applications/Chat/coati/ray/src/detached_trainer_base.py
 create mode 100644 applications/Chat/coati/ray/src/detached_trainer_ppo.py
 create mode 100644 applications/Chat/coati/ray/src/experience_maker_holder.py
 create mode 100644 applications/Chat/coati/ray/src/pipeline_strategy.py
 create mode
100644 applications/Chat/coati/ray/src/utils.py diff --git a/applications/Chat/.gitignore b/applications/Chat/.gitignore index 1ec5f53a8..2b9b4f345 100644 --- a/applications/Chat/.gitignore +++ b/applications/Chat/.gitignore @@ -144,3 +144,5 @@ docs/.build # wandb log example/wandb/ + +examples/awesome-chatgpt-prompts/ \ No newline at end of file diff --git a/applications/Chat/coati/ray/__init__.py b/applications/Chat/coati/ray/__init__.py new file mode 100644 index 000000000..5802c05bc --- /dev/null +++ b/applications/Chat/coati/ray/__init__.py @@ -0,0 +1,2 @@ +from .src.detached_replay_buffer import DetachedReplayBuffer +from .src.detached_trainer_ppo import DetachedPPOTrainer diff --git a/applications/Chat/coati/ray/example/1m1t.py b/applications/Chat/coati/ray/example/1m1t.py new file mode 100644 index 000000000..a65273705 --- /dev/null +++ b/applications/Chat/coati/ray/example/1m1t.py @@ -0,0 +1,153 @@ +import argparse +from copy import deepcopy + +import pandas as pd +import torch +from coati.trainer import PPOTrainer + + +from coati.ray.src.experience_maker_holder import ExperienceMakerHolder +from coati.ray.src.detached_trainer_ppo import DetachedPPOTrainer + +from coati.trainer.strategies import ColossalAIStrategy, DDPStrategy, NaiveStrategy +from coati.experience_maker import NaiveExperienceMaker +from torch.optim import Adam +from transformers import AutoTokenizer, BloomTokenizerFast +from transformers.models.gpt2.tokenization_gpt2 import GPT2Tokenizer + +from colossalai.nn.optimizer import HybridAdam + +import ray +import os +import socket + +def get_free_port(): + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: + s.bind(('', 0)) + return s.getsockname()[1] + + +def get_local_ip(): + with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as s: + s.connect(('8.8.8.8', 80)) + return s.getsockname()[0] + +def main(args): + master_addr = str(get_local_ip()) + # trainer_env_info + trainer_port = str(get_free_port()) + env_info_trainer = {'local_rank' : '0', + 'rank' : '0', + 'world_size' : '1', + 'master_port' : trainer_port, + 'master_addr' : master_addr} + + # maker_env_info + maker_port = str(get_free_port()) + env_info_maker = {'local_rank' : '0', + 'rank' : '0', + 'world_size' : '1', + 'master_port' : maker_port, + 'master_addr' : master_addr} + + # configure tokenizer + if args.model == 'gpt2': + tokenizer = GPT2Tokenizer.from_pretrained('gpt2') + tokenizer.pad_token = tokenizer.eos_token + elif args.model == 'bloom': + tokenizer = BloomTokenizerFast.from_pretrained(args.pretrain) + tokenizer.pad_token = tokenizer.eos_token + elif args.model == 'opt': + tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m") + else: + raise ValueError(f'Unsupported model "{args.model}"') + + # configure Trainer + trainer_ref = DetachedPPOTrainer.options(name="trainer1", num_gpus=1, max_concurrency=2).remote( + experience_maker_holder_name_list=["maker1"], + strategy=args.trainer_strategy, + model=args.model, + env_info = env_info_trainer, + pretrained=args.pretrain, + lora_rank=args.lora_rank, + train_batch_size=args.train_batch_size, + buffer_limit=16, + experience_batch_size=args.experience_batch_size, + max_epochs=args.max_epochs, + #kwargs: + max_length=128, + do_sample=True, + temperature=1.0, + top_k=50, + pad_token_id=tokenizer.pad_token_id, + eos_token_id=tokenizer.eos_token_id, + debug=args.debug, + ) + + # configure Experience Maker + experience_holder_ref = ExperienceMakerHolder.options(name="maker1", num_gpus=1, max_concurrency=2).remote( + 
detached_trainer_name_list=["trainer1"], + strategy=args.maker_strategy, + env_info = env_info_maker, + experience_batch_size=args.experience_batch_size, + kl_coef=0.1, + #kwargs: + max_length=128, + do_sample=True, + temperature=1.0, + top_k=50, + pad_token_id=tokenizer.pad_token_id, + eos_token_id=tokenizer.eos_token_id, + debug=args.debug, + ) + + # trainer send its actor and critic to experience holders. + ray.get(trainer_ref.initialize_remote_makers.remote()) + + # configure sampler + dataset = pd.read_csv(args.prompt_path)['prompt'] + + def tokenize_fn(texts): + # MUST padding to max length to ensure inputs of all ranks have the same length + # Different length may lead to hang when using gemini, as different generation steps + batch = tokenizer(texts, return_tensors='pt', max_length=96, padding='max_length', truncation=True) + return {k: v.cuda() for k, v in batch.items()} + + trainer_done_ref = trainer_ref.fit.remote(num_episodes=args.num_episodes, max_timesteps=args.max_timesteps, update_timesteps=args.update_timesteps) + num_exp_per_maker = args.num_episodes * args.max_timesteps // args.update_timesteps * args.max_epochs + 3 # +3 for fault tolerance + maker_done_ref = experience_holder_ref.workingloop.remote(dataset, tokenize_fn, times=num_exp_per_maker) + + ray.get([trainer_done_ref, maker_done_ref]) + + # save model checkpoint after fitting + trainer_ref.strategy_save_actor.remote(args.save_path, only_rank0=True) + # save optimizer checkpoint on all ranks + if args.need_optim_ckpt: + trainer_ref.strategy_save_actor_optim.remote('actor_optim_checkpoint_prompts_%d.pt' % (torch.cuda.current_device()), + only_rank0=False) + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + parser.add_argument('prompt_path') + parser.add_argument('--trainer_strategy', + choices=['naive', 'ddp', 'colossalai_gemini', 'colossalai_zero2'], + default='naive') + parser.add_argument('--maker_strategy', + choices=['naive', 'ddp', 'colossalai_gemini', 'colossalai_zero2'], + default='naive') + parser.add_argument('--model', default='gpt2', choices=['gpt2', 'bloom', 'opt']) + parser.add_argument('--pretrain', type=str, default=None) + parser.add_argument('--save_path', type=str, default='actor_checkpoint_prompts.pt') + parser.add_argument('--need_optim_ckpt', type=bool, default=False) + parser.add_argument('--num_episodes', type=int, default=10) + parser.add_argument('--max_timesteps', type=int, default=10) + parser.add_argument('--update_timesteps', type=int, default=10) + parser.add_argument('--max_epochs', type=int, default=5) + parser.add_argument('--train_batch_size', type=int, default=8) + parser.add_argument('--experience_batch_size', type=int, default=8) + parser.add_argument('--lora_rank', type=int, default=0, help="low-rank adaptation matrices rank") + + parser.add_argument('--debug', action='store_true') + args = parser.parse_args() + ray.init(namespace=os.environ["RAY_NAMESPACE"]) + main(args) diff --git a/applications/Chat/coati/ray/example/1m1t.sh b/applications/Chat/coati/ray/example/1m1t.sh new file mode 100644 index 000000000..f7c5054c8 --- /dev/null +++ b/applications/Chat/coati/ray/example/1m1t.sh @@ -0,0 +1,23 @@ +set_n_least_used_CUDA_VISIBLE_DEVICES() { + local n=${1:-"9999"} + echo "GPU Memory Usage:" + local FIRST_N_GPU_IDS=$(nvidia-smi --query-gpu=memory.used --format=csv \ + | tail -n +2 \ + | nl -v 0 \ + | tee /dev/tty \ + | sort -g -k 2 \ + | awk '{print $1}' \ + | head -n $n) + export CUDA_VISIBLE_DEVICES=$(echo $FIRST_N_GPU_IDS | sed 's/ /,/g') + echo "Now 
CUDA_VISIBLE_DEVICES is set to:" + echo "CUDA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES" +} + +set_n_least_used_CUDA_VISIBLE_DEVICES 2 + +export RAY_NAMESPACE="admin" + +python 1m1t.py "/path/to/prompts.csv" \ + --trainer_strategy colossalai_zero2 --maker_strategy naive --lora_rank 2 --pretrain "facebook/opt-350m" --model 'opt' \ + --num_episodes 10 --max_timesteps 10 --update_timesteps 10 \ + --max_epochs 10 --debug diff --git a/applications/Chat/coati/ray/example/1m2t.py b/applications/Chat/coati/ray/example/1m2t.py new file mode 100644 index 000000000..3883c364a --- /dev/null +++ b/applications/Chat/coati/ray/example/1m2t.py @@ -0,0 +1,186 @@ +import argparse +from copy import deepcopy + +import pandas as pd +import torch +from coati.trainer import PPOTrainer + + +from coati.ray.src.experience_maker_holder import ExperienceMakerHolder +from coati.ray.src.detached_trainer_ppo import DetachedPPOTrainer + +from coati.trainer.strategies import ColossalAIStrategy, DDPStrategy, NaiveStrategy +from coati.experience_maker import NaiveExperienceMaker +from torch.optim import Adam +from transformers import AutoTokenizer, BloomTokenizerFast +from transformers.models.gpt2.tokenization_gpt2 import GPT2Tokenizer + +from colossalai.nn.optimizer import HybridAdam + +import ray +import os +import socket + + +def get_free_port(): + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: + s.bind(('', 0)) + return s.getsockname()[1] + + +def get_local_ip(): + with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as s: + s.connect(('8.8.8.8', 80)) + return s.getsockname()[0] + +def main(args): + master_addr = str(get_local_ip()) + # trainer_env_info + trainer_port = str(get_free_port()) + env_info_trainer_1 = {'local_rank' : '0', + 'rank' : '0', + 'world_size' : '2', + 'master_port' : trainer_port, + 'master_addr' : master_addr} + env_info_trainer_2 = {'local_rank' : '0', + 'rank' : '1', + 'world_size' : '2', + 'master_port' : trainer_port, + 'master_addr' : master_addr} + # maker_env_info + maker_port = str(get_free_port()) + env_info_maker_1 = {'local_rank' : '0', + 'rank' : '0', + 'world_size' : '2', + 'master_port' : maker_port, + 'master_addr' : master_addr} + print([env_info_trainer_1, + env_info_trainer_2, + env_info_maker_1]) + ray.init(dashboard_port = 1145) + # configure tokenizer + if args.model == 'gpt2': + tokenizer = GPT2Tokenizer.from_pretrained('gpt2') + tokenizer.pad_token = tokenizer.eos_token + elif args.model == 'bloom': + tokenizer = BloomTokenizerFast.from_pretrained(args.pretrain) + tokenizer.pad_token = tokenizer.eos_token + elif args.model == 'opt': + tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m") + else: + raise ValueError(f'Unsupported model "{args.model}"') + + # configure Trainer + trainer_1_ref = DetachedPPOTrainer.options(name="trainer1", namespace=os.environ["RAY_NAMESPACE"], num_gpus=1, max_concurrency=2).remote( + experience_maker_holder_name_list=["maker1"], + strategy=args.trainer_strategy, + model=args.model, + env_info=env_info_trainer_1, + pretrained=args.pretrain, + lora_rank=args.lora_rank, + train_batch_size=args.train_batch_size, + buffer_limit=16, + experience_batch_size=args.experience_batch_size, + max_epochs=args.max_epochs, + #kwargs: + max_length=128, + do_sample=True, + temperature=1.0, + top_k=50, + pad_token_id=tokenizer.pad_token_id, + eos_token_id=tokenizer.eos_token_id, + debug=args.debug, + ) + + trainer_2_ref = DetachedPPOTrainer.options(name="trainer2", namespace=os.environ["RAY_NAMESPACE"], num_gpus=1, max_concurrency=2).remote( + 
experience_maker_holder_name_list=["maker1"], + strategy=args.trainer_strategy, + model=args.model, + env_info=env_info_trainer_2, + pretrained=args.pretrain, + lora_rank=args.lora_rank, + train_batch_size=args.train_batch_size, + buffer_limit=16, + experience_batch_size=args.experience_batch_size, + max_epochs=args.max_epochs, + #kwargs: + max_length=128, + do_sample=True, + temperature=1.0, + top_k=50, + pad_token_id=tokenizer.pad_token_id, + eos_token_id=tokenizer.eos_token_id, + debug= args.debug, + ) + + # configure Experience Maker + experience_holder_1_ref = ExperienceMakerHolder.options(name="maker1", namespace=os.environ["RAY_NAMESPACE"], num_gpus=1, max_concurrency=2).remote( + detached_trainer_name_list=["trainer1", "trainer2"], + strategy=args.maker_strategy, + env_info=env_info_maker_1, + experience_batch_size=args.experience_batch_size, + kl_coef=0.1, + #kwargs: + max_length=128, + do_sample=True, + temperature=1.0, + top_k=50, + pad_token_id=tokenizer.pad_token_id, + eos_token_id=tokenizer.eos_token_id, + debug=args.debug, + ) + + # trainer send its actor and critic to experience holders. + # TODO: balance duty + ray.get(trainer_1_ref.initialize_remote_makers.remote()) + + # configure sampler + dataset = pd.read_csv(args.prompt_path)['prompt'] + + def tokenize_fn(texts): + # MUST padding to max length to ensure inputs of all ranks have the same length + # Different length may lead to hang when using gemini, as different generation steps + batch = tokenizer(texts, return_tensors='pt', max_length=96, padding='max_length', truncation=True) + return {k: v.cuda() for k, v in batch.items()} + + trainer_1_done_ref = trainer_1_ref.fit.remote(num_episodes=args.num_episodes, max_timesteps=args.max_timesteps, update_timesteps=args.update_timesteps) + trainer_2_done_ref = trainer_2_ref.fit.remote(num_episodes=args.num_episodes, max_timesteps=args.max_timesteps, update_timesteps=args.update_timesteps) + num_exp_per_maker = args.num_episodes * args.max_timesteps // args.update_timesteps * args.max_epochs * 2 + 3 # +3 for fault tolerance + maker_1_done_ref = experience_holder_1_ref.workingloop.remote(dataset, tokenize_fn, times=num_exp_per_maker) + + ray.get([trainer_1_done_ref, trainer_2_done_ref, maker_1_done_ref]) + # save model checkpoint after fitting + trainer_1_ref.strategy_save_actor.remote(args.save_path, only_rank0=True) + trainer_2_ref.strategy_save_actor.remote(args.save_path, only_rank0=True) + # save optimizer checkpoint on all ranks + if args.need_optim_ckpt: + trainer_1_ref.strategy_save_actor_optim.remote('actor_optim_checkpoint_prompts_%d.pt' % (torch.cuda.current_device()), + only_rank0=False) + trainer_2_ref.strategy_save_actor_optim.remote('actor_optim_checkpoint_prompts_%d.pt' % (torch.cuda.current_device()), + only_rank0=False) + + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + parser.add_argument('prompt_path') + parser.add_argument('--trainer_strategy', + choices=['naive', 'ddp', 'colossalai_gemini', 'colossalai_zero2'], + default='naive') + parser.add_argument('--maker_strategy', + choices=['naive', 'ddp', 'colossalai_gemini', 'colossalai_zero2'], + default='naive') + parser.add_argument('--model', default='gpt2', choices=['gpt2', 'bloom', 'opt']) + parser.add_argument('--pretrain', type=str, default=None) + parser.add_argument('--save_path', type=str, default='actor_checkpoint_prompts.pt') + parser.add_argument('--need_optim_ckpt', type=bool, default=False) + parser.add_argument('--num_episodes', type=int, default=10) + 
parser.add_argument('--max_timesteps', type=int, default=10) + parser.add_argument('--update_timesteps', type=int, default=10) + parser.add_argument('--max_epochs', type=int, default=5) + parser.add_argument('--train_batch_size', type=int, default=8) + parser.add_argument('--experience_batch_size', type=int, default=8) + parser.add_argument('--lora_rank', type=int, default=0, help="low-rank adaptation matrices rank") + + parser.add_argument('--debug', action='store_true') + args = parser.parse_args() + main(args) diff --git a/applications/Chat/coati/ray/example/1m2t.sh b/applications/Chat/coati/ray/example/1m2t.sh new file mode 100644 index 000000000..669f41410 --- /dev/null +++ b/applications/Chat/coati/ray/example/1m2t.sh @@ -0,0 +1,23 @@ +set_n_least_used_CUDA_VISIBLE_DEVICES() { + local n=${1:-"9999"} + echo "GPU Memory Usage:" + local FIRST_N_GPU_IDS=$(nvidia-smi --query-gpu=memory.used --format=csv \ + | tail -n +2 \ + | nl -v 0 \ + | tee /dev/tty \ + | sort -g -k 2 \ + | awk '{print $1}' \ + | head -n $n) + export CUDA_VISIBLE_DEVICES=$(echo $FIRST_N_GPU_IDS | sed 's/ /,/g') + echo "Now CUDA_VISIBLE_DEVICES is set to:" + echo "CUDA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES" +} + +set_n_least_used_CUDA_VISIBLE_DEVICES 2 + +export RAY_NAMESPACE="admin" + +python 1m2t.py "/path/to/prompts.csv" --model gpt2 \ + --maker_strategy naive --trainer_strategy ddp --lora_rank 2 \ + --num_episodes 10 --max_timesteps 10 --update_timesteps 10 \ + --max_epochs 10 #--debug \ No newline at end of file diff --git a/applications/Chat/coati/ray/example/2m1t.py b/applications/Chat/coati/ray/example/2m1t.py new file mode 100644 index 000000000..b655de1ab --- /dev/null +++ b/applications/Chat/coati/ray/example/2m1t.py @@ -0,0 +1,140 @@ +import argparse +from copy import deepcopy + +import pandas as pd +import torch +from coati.trainer import PPOTrainer + + +from coati.ray.src.experience_maker_holder import ExperienceMakerHolder +from coati.ray.src.detached_trainer_ppo import DetachedPPOTrainer + +from coati.trainer.strategies import ColossalAIStrategy, DDPStrategy, NaiveStrategy +from coati.experience_maker import NaiveExperienceMaker +from torch.optim import Adam +from transformers import AutoTokenizer, BloomTokenizerFast +from transformers.models.gpt2.tokenization_gpt2 import GPT2Tokenizer + +from colossalai.nn.optimizer import HybridAdam + +import ray +import os +import socket + + +def main(args): + # configure tokenizer + if args.model == 'gpt2': + tokenizer = GPT2Tokenizer.from_pretrained('gpt2') + tokenizer.pad_token = tokenizer.eos_token + elif args.model == 'bloom': + tokenizer = BloomTokenizerFast.from_pretrained(args.pretrain) + tokenizer.pad_token = tokenizer.eos_token + elif args.model == 'opt': + tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m") + else: + raise ValueError(f'Unsupported model "{args.model}"') + + # configure Trainer + trainer_ref = DetachedPPOTrainer.options(name="trainer1", num_gpus=1, max_concurrency=2).remote( + experience_maker_holder_name_list=["maker1", "maker2"], + strategy=args.trainer_strategy, + model=args.model, + pretrained=args.pretrain, + lora_rank=args.lora_rank, + train_batch_size=args.train_batch_size, + buffer_limit=16, + experience_batch_size=args.experience_batch_size, + max_epochs=args.max_epochs, + #kwargs: + max_length=128, + do_sample=True, + temperature=1.0, + top_k=50, + pad_token_id=tokenizer.pad_token_id, + eos_token_id=tokenizer.eos_token_id, + debug=args.debug, + ) + + # configure Experience Maker + experience_holder_1_ref = 
ExperienceMakerHolder.options(name="maker1", num_gpus=1, max_concurrency=2).remote( + detached_trainer_name_list=["trainer1"], + strategy=args.maker_strategy, + experience_batch_size=args.experience_batch_size, + kl_coef=0.1, + #kwargs: + max_length=128, + do_sample=True, + temperature=1.0, + top_k=50, + pad_token_id=tokenizer.pad_token_id, + eos_token_id=tokenizer.eos_token_id, + debug=args.debug, + ) + + experience_holder_2_ref = ExperienceMakerHolder.options(name="maker2", num_gpus=1, max_concurrency=2).remote( + detached_trainer_name_list=["trainer1"], + strategy=args.maker_strategy, + experience_batch_size=args.experience_batch_size, + kl_coef=0.1, + #kwargs: + max_length=128, + do_sample=True, + temperature=1.0, + top_k=50, + pad_token_id=tokenizer.pad_token_id, + eos_token_id=tokenizer.eos_token_id, + debug=args.debug, + ) + + # trainer send its actor and critic to experience holders. + ray.get(trainer_ref.initialize_remote_makers.remote()) + + # configure sampler + dataset = pd.read_csv(args.prompt_path)['prompt'] + + def tokenize_fn(texts): + # MUST padding to max length to ensure inputs of all ranks have the same length + # Different length may lead to hang when using gemini, as different generation steps + batch = tokenizer(texts, return_tensors='pt', max_length=96, padding='max_length', truncation=True) + return {k: v.cuda() for k, v in batch.items()} + + trainer_done_ref = trainer_ref.fit.remote(num_episodes=args.num_episodes, max_timesteps=args.max_timesteps, update_timesteps=args.update_timesteps) + num_exp_per_maker = args.num_episodes * args.max_timesteps // args.update_timesteps * args.max_epochs // 2 + 3 # +3 for fault tolerance + maker_1_done_ref = experience_holder_1_ref.workingloop.remote(dataset, tokenize_fn, times=num_exp_per_maker) + maker_2_done_ref = experience_holder_2_ref.workingloop.remote(dataset, tokenize_fn, times=num_exp_per_maker) + + ray.get([trainer_done_ref, maker_1_done_ref, maker_2_done_ref]) + + # save model checkpoint after fitting + trainer_ref.strategy_save_actor.remote(args.save_path, only_rank0=True) + # save optimizer checkpoint on all ranks + if args.need_optim_ckpt: + trainer_ref.strategy_save_actor_optim.remote('actor_optim_checkpoint_prompts_%d.pt' % (torch.cuda.current_device()), + only_rank0=False) + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + parser.add_argument('prompt_path') + parser.add_argument('--trainer_strategy', + choices=['naive', 'ddp', 'colossalai_gemini', 'colossalai_zero2'], + default='naive') + parser.add_argument('--maker_strategy', + choices=['naive', 'ddp', 'colossalai_gemini', 'colossalai_zero2'], + default='naive') + parser.add_argument('--model', default='gpt2', choices=['gpt2', 'bloom', 'opt']) + parser.add_argument('--pretrain', type=str, default=None) + parser.add_argument('--save_path', type=str, default='actor_checkpoint_prompts.pt') + parser.add_argument('--need_optim_ckpt', type=bool, default=False) + parser.add_argument('--num_episodes', type=int, default=10) + parser.add_argument('--max_timesteps', type=int, default=10) + parser.add_argument('--update_timesteps', type=int, default=10) + parser.add_argument('--max_epochs', type=int, default=5) + parser.add_argument('--train_batch_size', type=int, default=8) + parser.add_argument('--experience_batch_size', type=int, default=8) + parser.add_argument('--lora_rank', type=int, default=0, help="low-rank adaptation matrices rank") + + parser.add_argument('--debug', action='store_true') + args = parser.parse_args() + 
ray.init(namespace=os.environ["RAY_NAMESPACE"]) + main(args) diff --git a/applications/Chat/coati/ray/example/2m1t.sh b/applications/Chat/coati/ray/example/2m1t.sh new file mode 100644 index 000000000..a207d4118 --- /dev/null +++ b/applications/Chat/coati/ray/example/2m1t.sh @@ -0,0 +1,23 @@ +set_n_least_used_CUDA_VISIBLE_DEVICES() { + local n=${1:-"9999"} + echo "GPU Memory Usage:" + local FIRST_N_GPU_IDS=$(nvidia-smi --query-gpu=memory.used --format=csv \ + | tail -n +2 \ + | nl -v 0 \ + | tee /dev/tty \ + | sort -g -k 2 \ + | awk '{print $1}' \ + | head -n $n) + export CUDA_VISIBLE_DEVICES=$(echo $FIRST_N_GPU_IDS | sed 's/ /,/g') + echo "Now CUDA_VISIBLE_DEVICES is set to:" + echo "CUDA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES" +} + +set_n_least_used_CUDA_VISIBLE_DEVICES 3 + +export RAY_NAMESPACE="admin" + +python 2m1t.py "/path/to/prompts.csv" \ + --trainer_strategy naive --maker_strategy naive --lora_rank 2 --pretrain "facebook/opt-350m" --model 'opt' \ + --num_episodes 10 --max_timesteps 10 --update_timesteps 10 \ + --max_epochs 10 # --debug diff --git a/applications/Chat/coati/ray/example/2m2t.py b/applications/Chat/coati/ray/example/2m2t.py new file mode 100644 index 000000000..435c71915 --- /dev/null +++ b/applications/Chat/coati/ray/example/2m2t.py @@ -0,0 +1,209 @@ +import argparse +from copy import deepcopy + +import pandas as pd +import torch +from coati.trainer import PPOTrainer + + +from coati.ray.src.experience_maker_holder import ExperienceMakerHolder +from coati.ray.src.detached_trainer_ppo import DetachedPPOTrainer + +from coati.trainer.strategies import ColossalAIStrategy, DDPStrategy, NaiveStrategy +from coati.experience_maker import NaiveExperienceMaker +from torch.optim import Adam +from transformers import AutoTokenizer, BloomTokenizerFast +from transformers.models.gpt2.tokenization_gpt2 import GPT2Tokenizer + +from colossalai.nn.optimizer import HybridAdam + +import ray +import os +import socket + + +def get_free_port(): + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: + s.bind(('', 0)) + return s.getsockname()[1] + + +def get_local_ip(): + with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as s: + s.connect(('8.8.8.8', 80)) + return s.getsockname()[0] + +def main(args): + master_addr = str(get_local_ip()) + # trainer_env_info + trainer_port = str(get_free_port()) + env_info_trainer_1 = {'local_rank' : '0', + 'rank' : '0', + 'world_size' : '2', + 'master_port' : trainer_port, + 'master_addr' : master_addr} + env_info_trainer_2 = {'local_rank' : '0', + 'rank' : '1', + 'world_size' : '2', + 'master_port' : trainer_port, + 'master_addr' : master_addr} + # maker_env_info + maker_port = str(get_free_port()) + env_info_maker_1 = {'local_rank' : '0', + 'rank' : '0', + 'world_size' : '2', + 'master_port' : maker_port, + 'master_addr' : master_addr} + env_info_maker_2 = {'local_rank' : '0', + 'rank' : '1', + 'world_size' : '2', + 'master_port': maker_port, + 'master_addr' : master_addr} + print([env_info_trainer_1, + env_info_trainer_2, + env_info_maker_1, + env_info_maker_2]) + ray.init() + # configure tokenizer + if args.model == 'gpt2': + tokenizer = GPT2Tokenizer.from_pretrained('gpt2') + tokenizer.pad_token = tokenizer.eos_token + elif args.model == 'bloom': + tokenizer = BloomTokenizerFast.from_pretrained(args.pretrain) + tokenizer.pad_token = tokenizer.eos_token + elif args.model == 'opt': + tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m") + else: + raise ValueError(f'Unsupported model "{args.model}"') + + # configure Trainer + 
trainer_1_ref = DetachedPPOTrainer.options(name="trainer1", namespace=os.environ["RAY_NAMESPACE"], num_gpus=1, max_concurrency=2).remote( + experience_maker_holder_name_list=["maker1", "maker2"], + strategy=args.trainer_strategy, + model=args.model, + env_info=env_info_trainer_1, + pretrained=args.pretrain, + lora_rank=args.lora_rank, + train_batch_size=args.train_batch_size, + buffer_limit=16, + experience_batch_size=args.experience_batch_size, + max_epochs=args.max_epochs, + #kwargs: + max_length=128, + do_sample=True, + temperature=1.0, + top_k=50, + pad_token_id=tokenizer.pad_token_id, + eos_token_id=tokenizer.eos_token_id, + debug=args.debug, + ) + + trainer_2_ref = DetachedPPOTrainer.options(name="trainer2", namespace=os.environ["RAY_NAMESPACE"], num_gpus=1, max_concurrency=2).remote( + experience_maker_holder_name_list=["maker1", "maker2"], + strategy=args.trainer_strategy, + model=args.model, + env_info=env_info_trainer_2, + pretrained=args.pretrain, + lora_rank=args.lora_rank, + train_batch_size=args.train_batch_size, + buffer_limit=16, + experience_batch_size=args.experience_batch_size, + max_epochs=args.max_epochs, + #kwargs: + max_length=128, + do_sample=True, + temperature=1.0, + top_k=50, + pad_token_id=tokenizer.pad_token_id, + eos_token_id=tokenizer.eos_token_id, + debug=args.debug, + ) + + # configure Experience Maker + experience_holder_1_ref = ExperienceMakerHolder.options(name="maker1", namespace=os.environ["RAY_NAMESPACE"], num_gpus=1, max_concurrency=2).remote( + detached_trainer_name_list=["trainer1", "trainer2"], + strategy=args.maker_strategy, + env_info=env_info_maker_1, + experience_batch_size=args.experience_batch_size, + kl_coef=0.1, + #kwargs: + max_length=128, + do_sample=True, + temperature=1.0, + top_k=50, + pad_token_id=tokenizer.pad_token_id, + eos_token_id=tokenizer.eos_token_id, + debug=args.debug, + ) + + experience_holder_2_ref = ExperienceMakerHolder.options(name="maker2", namespace=os.environ["RAY_NAMESPACE"], num_gpus=1, max_concurrency=2).remote( + detached_trainer_name_list=["trainer1", "trainer2"], + strategy=args.maker_strategy, + env_info=env_info_maker_2, + experience_batch_size=args.experience_batch_size, + kl_coef=0.1, + #kwargs: + max_length=128, + do_sample=True, + temperature=1.0, + top_k=50, + pad_token_id=tokenizer.pad_token_id, + eos_token_id=tokenizer.eos_token_id, + debug=args.debug, + ) + + # trainer send its actor and critic to experience holders. 
+ # TODO: balance duty + ray.get(trainer_1_ref.initialize_remote_makers.remote()) + + # configure sampler + dataset = pd.read_csv(args.prompt_path)['prompt'] + + def tokenize_fn(texts): + # MUST padding to max length to ensure inputs of all ranks have the same length + # Different length may lead to hang when using gemini, as different generation steps + batch = tokenizer(texts, return_tensors='pt', max_length=96, padding='max_length', truncation=True) + return {k: v.cuda() for k, v in batch.items()} + + trainer_1_done_ref = trainer_1_ref.fit.remote(num_episodes=args.num_episodes, max_timesteps=args.max_timesteps, update_timesteps=args.update_timesteps) + trainer_2_done_ref = trainer_2_ref.fit.remote(num_episodes=args.num_episodes, max_timesteps=args.max_timesteps, update_timesteps=args.update_timesteps) + num_exp_per_maker = args.num_episodes * args.max_timesteps // args.update_timesteps * args.max_epochs + 3 # +3 for fault tolerance + maker_1_done_ref = experience_holder_1_ref.workingloop.remote(dataset, tokenize_fn, times=num_exp_per_maker) + maker_2_done_ref = experience_holder_2_ref.workingloop.remote(dataset, tokenize_fn, times=num_exp_per_maker) + + ray.get([trainer_1_done_ref, trainer_2_done_ref, maker_1_done_ref, maker_2_done_ref]) + # save model checkpoint after fitting + trainer_1_ref.strategy_save_actor.remote(args.save_path, only_rank0=True) + trainer_2_ref.strategy_save_actor.remote(args.save_path, only_rank0=True) + # save optimizer checkpoint on all ranks + if args.need_optim_ckpt: + trainer_1_ref.strategy_save_actor_optim.remote('actor_optim_checkpoint_prompts_%d.pt' % (torch.cuda.current_device()), + only_rank0=False) + trainer_2_ref.strategy_save_actor_optim.remote('actor_optim_checkpoint_prompts_%d.pt' % (torch.cuda.current_device()), + only_rank0=False) + + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + parser.add_argument('prompt_path') + parser.add_argument('--trainer_strategy', + choices=['naive', 'ddp', 'colossalai_gemini', 'colossalai_zero2'], + default='naive') + parser.add_argument('--maker_strategy', + choices=['naive', 'ddp', 'colossalai_gemini', 'colossalai_zero2'], + default='naive') + parser.add_argument('--model', default='gpt2', choices=['gpt2', 'bloom', 'opt']) + parser.add_argument('--pretrain', type=str, default=None) + parser.add_argument('--save_path', type=str, default='actor_checkpoint_prompts.pt') + parser.add_argument('--need_optim_ckpt', type=bool, default=False) + parser.add_argument('--num_episodes', type=int, default=10) + parser.add_argument('--max_timesteps', type=int, default=10) + parser.add_argument('--update_timesteps', type=int, default=10) + parser.add_argument('--max_epochs', type=int, default=5) + parser.add_argument('--train_batch_size', type=int, default=8) + parser.add_argument('--experience_batch_size', type=int, default=8) + parser.add_argument('--lora_rank', type=int, default=0, help="low-rank adaptation matrices rank") + + parser.add_argument('--debug', action='store_true') + args = parser.parse_args() + main(args) diff --git a/applications/Chat/coati/ray/example/2m2t.sh b/applications/Chat/coati/ray/example/2m2t.sh new file mode 100644 index 000000000..fb4024766 --- /dev/null +++ b/applications/Chat/coati/ray/example/2m2t.sh @@ -0,0 +1,23 @@ +set_n_least_used_CUDA_VISIBLE_DEVICES() { + local n=${1:-"9999"} + echo "GPU Memory Usage:" + local FIRST_N_GPU_IDS=$(nvidia-smi --query-gpu=memory.used --format=csv \ + | tail -n +2 \ + | nl -v 0 \ + | tee /dev/tty \ + | sort -g -k 2 \ + | awk '{print $1}' \ + | 
head -n $n) + export CUDA_VISIBLE_DEVICES=$(echo $FIRST_N_GPU_IDS | sed 's/ /,/g') + echo "Now CUDA_VISIBLE_DEVICES is set to:" + echo "CUDA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES" +} + +set_n_least_used_CUDA_VISIBLE_DEVICES 2 + +export RAY_NAMESPACE="admin" + +python 2m2t.py "path/to/prompts.csv" \ + --maker_strategy naive --trainer_strategy colossalai_zero2 --lora_rank 2 \ + --num_episodes 10 --max_timesteps 10 --update_timesteps 10 \ + --max_epochs 10 --debug \ No newline at end of file diff --git a/applications/Chat/coati/ray/src/__init__.py b/applications/Chat/coati/ray/src/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/applications/Chat/coati/ray/src/detached_replay_buffer.py b/applications/Chat/coati/ray/src/detached_replay_buffer.py new file mode 100644 index 000000000..855eee48c --- /dev/null +++ b/applications/Chat/coati/ray/src/detached_replay_buffer.py @@ -0,0 +1,88 @@ +import torch +import random +from typing import List, Any +# from torch.multiprocessing import Queue +from ray.util.queue import Queue +import ray +import asyncio +from coati.experience_maker.base import Experience +from coati.replay_buffer.utils import BufferItem, make_experience_batch, split_experience_batch +from coati.replay_buffer import ReplayBuffer +from threading import Lock +import copy + +class DetachedReplayBuffer: + ''' + Detached replay buffer. Share Experience across workers on the same node. + Therefore a trainer node is expected to have only one instance. + It is ExperienceMakerHolder's duty to call append(exp) method, remotely. + + Args: + sample_batch_size: Batch size when sampling. Exp won't enqueue until they formed a batch. + tp_world_size: Number of workers in the same tp group + limit: Limit of number of experience sample BATCHs. A number <= 0 means unlimited. Defaults to 0. + cpu_offload: Whether to offload experience to cpu when sampling. Defaults to True. + ''' + + def __init__(self, sample_batch_size: int, tp_world_size: int = 1, limit : int = 0, cpu_offload: bool = True) -> None: + self.cpu_offload = cpu_offload + self.sample_batch_size = sample_batch_size + self.limit = limit + self.items = Queue(self.limit, actor_options={"num_cpus":1}) + self.batch_collector : List[BufferItem] = [] + + ''' + Workers in the same tp group share this buffer and need same sample for one step. + Therefore a held_sample should be returned tp_world_size times before it could be dropped. + worker_state records wheter a worker got the held_sample + ''' + self.tp_world_size = tp_world_size + self.worker_state = [False] * self.tp_world_size + self.held_sample = None + self._worker_state_lock = Lock() + + @torch.no_grad() + def append(self, experience: Experience) -> None: + ''' + Expected to be called remotely. 
+ ''' + if self.cpu_offload: + experience.to_device(torch.device('cpu')) + items = split_experience_batch(experience) + self.batch_collector.extend(items) + while len(self.batch_collector) >= self.sample_batch_size: + items = self.batch_collector[:self.sample_batch_size] + experience = make_experience_batch(items) + self.items.put(experience, block=True) + self.batch_collector = self.batch_collector[self.sample_batch_size:] + + def clear(self) -> None: + # self.items.close() + self.items.shutdown() + self.items = Queue(self.limit) + self.worker_state = [False] * self.tp_world_size + self.batch_collector = [] + + @torch.no_grad() + def sample(self, worker_rank = 0, to_device = "cpu") -> Experience: + self._worker_state_lock.acquire() + if not any(self.worker_state): + self.held_sample = self._sample_and_erase() + self.worker_state[worker_rank] = True + if all(self.worker_state): + self.worker_state = [False] * self.tp_world_size + ret = self.held_sample + else: + ret = copy.deepcopy(self.held_sample) + self._worker_state_lock.release() + ret.to_device(to_device) + return ret + + @torch.no_grad() + def _sample_and_erase(self) -> Experience: + ret = self.items.get(block=True) + return ret + + def get_length(self) -> int: + ret = self.items.qsize() + return ret \ No newline at end of file diff --git a/applications/Chat/coati/ray/src/detached_trainer_base.py b/applications/Chat/coati/ray/src/detached_trainer_base.py new file mode 100644 index 000000000..f1ed1ec71 --- /dev/null +++ b/applications/Chat/coati/ray/src/detached_trainer_base.py @@ -0,0 +1,121 @@ +from abc import ABC, abstractmethod +from typing import Any, Callable, Dict, List, Optional, Union +from tqdm import tqdm +from coati.trainer.callbacks import Callback +from coati.experience_maker import Experience +import ray +import os + +from .detached_replay_buffer import DetachedReplayBuffer +from .utils import is_rank_0 + +class DetachedTrainer(ABC): + ''' + Base class for detached rlhf trainers. + 'detach' means that the experience maker is detached compared to a normal Trainer. + Please set name attribute during init: + >>> trainer = DetachedTrainer.options(..., name = "xxx", ...).remote() + So an ExperienceMakerHolder can reach the detached_replay_buffer by Actor's name. 
+ Args: + detached_strategy (DetachedStrategy): the strategy to use for training + detached_replay_buffer_ref (ObjectRef[DetachedReplayBuffer]): the replay buffer to use for training + experience_batch_size (int, defaults to 8): the batch size to use for experience generation + max_epochs (int, defaults to 1): the number of epochs of training process + data_loader_pin_memory (bool, defaults to True): whether to pin memory for data loader + callbacks (List[Callback], defaults to []): the callbacks to call during training process + generate_kwargs (dict, optional): the kwargs to use while model generating + ''' + + def __init__(self, + experience_maker_holder_name_list: List[str], + train_batch_size: int = 8, + buffer_limit: int = 0, + buffer_cpu_offload: bool = True, + experience_batch_size: int = 8, + max_epochs: int = 1, + dataloader_pin_memory: bool = True, + callbacks: List[Callback] = [], + **generate_kwargs) -> None: + super().__init__() + self.detached_replay_buffer = DetachedReplayBuffer(train_batch_size, limit=buffer_limit, cpu_offload=buffer_cpu_offload) + self.experience_batch_size = experience_batch_size + self.max_epochs = max_epochs + self.dataloader_pin_memory = dataloader_pin_memory + self.callbacks = callbacks + self.generate_kwargs = generate_kwargs + self.target_holder_name_list = experience_maker_holder_name_list + self.target_holder_list = [] + + def update_target_holder_list(self, experience_maker_holder_name_list): + self.target_holder_name_list = experience_maker_holder_name_list + self.target_holder_list = [] + for name in self.target_holder_name_list: + self.target_holder_list.append(ray.get_actor(name, namespace=os.environ["RAY_NAMESPACE"])) + + @abstractmethod + def _update_remote_makers(self): + pass + + @abstractmethod + def training_step(self, experience: Experience) -> Dict[str, Any]: + pass + + def _learn(self): + pbar = tqdm(range(self.max_epochs), desc='Train epoch', disable=not is_rank_0()) + for _ in pbar: + if 'debug' in self.generate_kwargs and self.generate_kwargs['debug'] == True: + print("[trainer] sampling exp") + experience = self._buffer_sample() + if 'debug' in self.generate_kwargs and self.generate_kwargs['debug'] == True: + print("[trainer] training step") + metrics = self.training_step(experience) + if 'debug' in self.generate_kwargs and self.generate_kwargs['debug'] == True: + print("[trainer] step over") + pbar.set_postfix(metrics) + + def fit(self, num_episodes: int = 50000, max_timesteps: int = 500, update_timesteps: int = 5000) -> None: + self._on_fit_start() + for episode in range(num_episodes): + self._on_episode_start(episode) + for timestep in tqdm(range(max_timesteps // update_timesteps), + desc=f'Episode [{episode+1}/{num_episodes}]', + disable=not is_rank_0()): + self._learn() + self._update_remote_makers() + self._on_episode_end(episode) + self._on_fit_end() + + @ray.method(concurrency_group="buffer_length") + def buffer_get_length(self): + # called by ExperienceMakerHolder + if 'debug' in self.generate_kwargs and self.generate_kwargs['debug'] == True: + print("[trainer] telling length") + return self.detached_replay_buffer.get_length() + + @ray.method(concurrency_group="buffer_append") + def buffer_append(self, experience: Experience): + # called by ExperienceMakerHolder + if 'debug' in self.generate_kwargs and self.generate_kwargs['debug'] == True: + # print(f"[trainer] receiving exp. 
Current buffer length: {self.detached_replay_buffer.get_length()}") + print(f"[trainer] receiving exp.") + self.detached_replay_buffer.append(experience) + + @ray.method(concurrency_group="buffer_sample") + def _buffer_sample(self): + return self.detached_replay_buffer.sample() + + def _on_fit_start(self) -> None: + for callback in self.callbacks: + callback.on_fit_start() + + def _on_fit_end(self) -> None: + for callback in self.callbacks: + callback.on_fit_end() + + def _on_episode_start(self, episode: int) -> None: + for callback in self.callbacks: + callback.on_episode_start(episode) + + def _on_episode_end(self, episode: int) -> None: + for callback in self.callbacks: + callback.on_episode_end(episode) diff --git a/applications/Chat/coati/ray/src/detached_trainer_ppo.py b/applications/Chat/coati/ray/src/detached_trainer_ppo.py new file mode 100644 index 000000000..90e5e4377 --- /dev/null +++ b/applications/Chat/coati/ray/src/detached_trainer_ppo.py @@ -0,0 +1,192 @@ +from typing import Any, Callable, Dict, List, Optional +import torch +from torch.optim import Adam + +from coati.experience_maker import Experience, NaiveExperienceMaker +from coati.models.base import Actor, Critic +from coati.models.generation_utils import update_model_kwargs_fn +from coati.models.loss import PolicyLoss, ValueLoss +from coati.trainer.strategies import ColossalAIStrategy, DDPStrategy, NaiveStrategy, Strategy +from coati.trainer.callbacks import Callback + +from colossalai.nn.optimizer import HybridAdam + +import ray + + +from .utils import is_rank_0, get_cuda_actor_critic_from_args, get_strategy_from_args, set_dist_env +from .detached_trainer_base import DetachedTrainer + + +@ray.remote(concurrency_groups={"buffer_length": 1, "buffer_append":1, "buffer_sample":1,"model_io": 1, "compute": 1}) +class DetachedPPOTrainer(DetachedTrainer): + ''' + Detached Trainer for PPO algorithm + Args: + strategy (Strategy): the strategy to use for training + model (str) : for actor / critic init + pretrained (str) : for actor / critic init + lora_rank (int) : for actor / critic init + train_batch_size (int, defaults to 8): the batch size to use for training + train_batch_size (int, defaults to 8): the batch size to use for training + buffer_limit (int, defaults to 0): the max_size limitaiton of replay buffer + buffer_cpu_offload (bool, defaults to True): whether to offload replay buffer to cpu + eps_clip (float, defaults to 0.2): the clip coefficient of policy loss + value_clip (float, defaults to 0.4): the clip coefficient of value loss + experience_batch_size (int, defaults to 8): the batch size to use for experience generation + max_epochs (int, defaults to 1): the number of epochs of training process + dataloader_pin_memory (bool, defaults to True): whether to pin memory for data loader + callbacks (List[Callback], defaults to []): the callbacks to call during training process + generate_kwargs (dict, optional): the kwargs to use while model generating + ''' + + def __init__(self, + experience_maker_holder_name_list: List[str], + strategy: str, + model: str, + env_info: Dict[str, str] = None, + pretrained: str = None, + lora_rank: int = 0, + train_batch_size: int = 8, + buffer_limit: int = 0, + buffer_cpu_offload: bool = True, + eps_clip: float = 0.2, + value_clip: float = 0.4, + experience_batch_size: int = 8, + max_epochs: int = 10, + dataloader_pin_memory: bool = True, + callbacks: List[Callback] = [], + **generate_kwargs) -> None: + # set environment variables + if env_info: + set_dist_env(env_info=env_info) + # 
configure strategy + self.strategy = get_strategy_from_args(strategy) + # configure models, loss and optimizers + with self.strategy.model_init_context(): + self.actor, self.critic = get_cuda_actor_critic_from_args(model, pretrained, lora_rank) + + if strategy != 'colossalai_gemini': + self.actor.to(torch.float16).to(torch.cuda.current_device()) + self.critic.to(torch.float16).to(torch.cuda.current_device()) + + if strategy.startswith('colossalai'): + self.actor_optim = HybridAdam(self.actor.parameters(), lr=5e-6) + self.critic_optim = HybridAdam(self.critic.parameters(), lr=5e-6) + else: + self.actor_optim = Adam(self.actor.parameters(), lr=5e-6) + self.critic_optim = Adam(self.critic.parameters(), lr=5e-6) + + (self.actor, self.actor_optim), (self.critic, self.critic_optim) = \ + self.strategy.prepare((self.actor, self.actor_optim), (self.critic, self.critic_optim)) + generate_kwargs = _set_default_generate_kwargs(self.strategy, generate_kwargs, self.actor) + + self.actor_loss_fn = PolicyLoss(eps_clip) + self.critic_loss_fn = ValueLoss(value_clip) + + super().__init__(experience_maker_holder_name_list, + train_batch_size=train_batch_size, + buffer_limit=buffer_limit, + buffer_cpu_offload=buffer_cpu_offload, + experience_batch_size=experience_batch_size, + max_epochs=max_epochs, + dataloader_pin_memory=dataloader_pin_memory, + callbacks=callbacks, + **generate_kwargs) + + @ray.method(concurrency_group="model_io") + def _update_remote_makers(self): + # TODO: balance duties + if is_rank_0(): + self.update_target_holder_list(self.target_holder_name_list) + for target_holder in self.target_holder_list: + # TODO: reduce malloc + with torch.no_grad(): + ray.get(target_holder.update_experience_maker.remote(self._get_unwrapped_actor(), self._get_unwrapped_critic())) + + @ray.method(concurrency_group="model_io") + def initialize_remote_makers(self): + # TODO: balance duties + if is_rank_0(): + self.update_target_holder_list(self.target_holder_name_list) + for target_holder in self.target_holder_list: + # TODO: reduce malloc + with torch.no_grad(): + ray.get(target_holder.initialize_experience_maker.remote(self._get_unwrapped_actor(), self._get_unwrapped_critic())) + + @ray.method(concurrency_group="compute") + def training_step(self, experience: Experience) -> Dict[str, float]: + self.actor.train() + self.critic.train() + + experience.to_device(torch.cuda.current_device()) + num_actions = experience.action_mask.size(1) + action_log_probs = self.actor(experience.sequences, num_actions, attention_mask=experience.attention_mask) + actor_loss = self.actor_loss_fn(action_log_probs, + experience.action_log_probs, + experience.advantages, + action_mask=experience.action_mask) + self.strategy.backward(actor_loss, self.actor, self.actor_optim) + self.strategy.optimizer_step(self.actor_optim) + self.actor_optim.zero_grad() + + values = self.critic(experience.sequences, + action_mask=experience.action_mask, + attention_mask=experience.attention_mask) + critic_loss = self.critic_loss_fn(values, + experience.values, + experience.reward, + action_mask=experience.action_mask) + + self.strategy.backward(critic_loss, self.critic, self.critic_optim) + self.strategy.optimizer_step(self.critic_optim) + self.critic_optim.zero_grad() + return {'actor_loss': actor_loss.item(), 'critic_loss': critic_loss.item()} + + def strategy_save_actor(self, path: str, only_rank0: bool = False) -> None: + self.strategy.save_model(self.actor, path, only_rank0) + + def strategy_save_critic(self, path: str, only_rank0: bool = False) -> 
None: + self.strategy.save_model(self.critic, path, only_rank0) + + def strategy_save_actor_optim(self, path: str, only_rank0: bool = False) -> None: + self.strategy.save_optimizer(self.actor_optim, path, only_rank0) + + def strategy_save_critic_optim(self, path: str, only_rank0: bool = False) -> None: + self.strategy.save_optimizer(self.critic_optim, path, only_rank0) + + def _get_unwrapped_actor(self): + if False: + pass + elif isinstance(self.strategy, ColossalAIStrategy): + ret = Actor(self.strategy._unwrap_model(self.actor)) + return ret + elif isinstance(self.strategy, DDPStrategy): + return Actor(self.strategy._unwrap_actor(self.actor)) + elif isinstance(self.strategy, NaiveStrategy): + return self.actor + + def _get_unwrapped_critic(self): + if False: + pass + elif isinstance(self.strategy, ColossalAIStrategy): + ret = self.strategy._unwrap_model(self.critic) + return ret + elif isinstance(self.strategy, DDPStrategy): + return self.critic.module + elif isinstance(self.strategy, NaiveStrategy): + return self.critic + + +def _set_default_generate_kwargs(strategy: Strategy, generate_kwargs: dict, actor: Actor) -> None: + origin_model = strategy._unwrap_actor(actor) + new_kwargs = {**generate_kwargs} + # use huggingface models method directly + if 'prepare_inputs_fn' not in generate_kwargs and hasattr(origin_model, 'prepare_inputs_for_generation'): + new_kwargs['prepare_inputs_fn'] = origin_model.prepare_inputs_for_generation + + if 'update_model_kwargs_fn' not in generate_kwargs: + new_kwargs['update_model_kwargs_fn'] = update_model_kwargs_fn + + return new_kwargs + \ No newline at end of file diff --git a/applications/Chat/coati/ray/src/experience_maker_holder.py b/applications/Chat/coati/ray/src/experience_maker_holder.py new file mode 100644 index 000000000..696773e84 --- /dev/null +++ b/applications/Chat/coati/ray/src/experience_maker_holder.py @@ -0,0 +1,172 @@ +import torch +from typing import Any, Callable, Dict, List, Optional, Union +import ray +from ray.exceptions import GetTimeoutError +from torch import Tensor +import torch.nn as nn +from coati.models.base import Actor, Critic, RewardModel +from coati.trainer.strategies.sampler import DistributedSampler +from coati.trainer.strategies import Strategy +from coati.experience_maker import NaiveExperienceMaker, Experience, ExperienceMaker + +from copy import deepcopy +from threading import Lock +import time +import os + + +from .utils import is_rank_0, get_strategy_from_args, set_dist_env + + +@ray.remote(concurrency_groups={"experience_io": 1, "model_io": 1, "compute": 1}) +class ExperienceMakerHolder: + ''' + Args: + detached_trainer_name_list: str list to get ray actor handleskkk + strategy: + experience_batch_size: batch size of generated experience + kl_coef: the coefficient of kl divergence loss + ''' + + def __init__(self, + detached_trainer_name_list: List[str], + strategy: str, + env_info: Dict[str, str] = None, + experience_batch_size: int = 8, + kl_coef: float = 0.1, + **generate_kwargs): + # set environment variables + if env_info: + set_dist_env(env_info=env_info) + self.target_trainer_list = [] + for name in detached_trainer_name_list: + self.target_trainer_list.append(ray.get_actor(name, namespace=os.environ["RAY_NAMESPACE"])) + self.strategy_str = strategy + self.strategy = get_strategy_from_args(strategy) + self.experience_batch_size = experience_batch_size + self.kl_coef = kl_coef + self.generate_kwargs = generate_kwargs + # Need a trainer to give an actor and a critic via initialize_experience_maker(...) 
+ actor, critic, reward_model, initial_model = None, None, None, None + self.experience_maker = NaiveExperienceMaker(actor, critic, reward_model, initial_model, self.kl_coef) + self._model_visit_lock = Lock() + self.fully_initialized = False + if 'debug' in self.generate_kwargs and self.generate_kwargs['debug'] == True: + print('[maker] Waiting for INIT') + + def _get_ready(self): + while not self.fully_initialized: + time.sleep(1.0) + + def update_target_trainer_list(self, detached_trainer_name_list): + self.target_trainer_list = [] + for name in detached_trainer_name_list: + self.target_trainer_list.append(ray.get_actor(name)) + + # copy from ../trainer/base.py + @ray.method(concurrency_group="compute") + def _make_experience(self, inputs: Union[Tensor, Dict[str, Tensor]]) -> Experience: + self._get_ready() + if isinstance(inputs, Tensor): + return self.experience_maker.make_experience(inputs, **self.generate_kwargs) + elif isinstance(inputs, dict): + return self.experience_maker.make_experience(**inputs, **self.generate_kwargs) + else: + raise ValueError(f'Unsupported input type "{type(inputs)}"') + + @ray.method(concurrency_group="experience_io") + def _send_experience(self, experience): + ''' + ignore it + + # choose a trainer that has the least experience batch in its detached_replay_buffer + chosen_trainer = None + min_length = None + if 'debug' in self.generate_kwargs and self.generate_kwargs['debug'] == True: + print("[maker] choosing tartget trainer") + while chosen_trainer is None: + for target_trainer in self.target_trainer_list: + try: + temp_length = ray.get(target_trainer.buffer_get_length.remote(), timeout=0.1) + if min_length is None: + min_length = temp_length + chosen_trainer = target_trainer + else: + if temp_length < min_length: + min_length = temp_length + chosen_trainer = target_trainer + except GetTimeoutError: + pass + + if 'debug' in self.generate_kwargs and self.generate_kwargs['debug'] == True: + print(f"[maker] sending exp to {chosen_trainer}") + chosen_trainer.buffer_append.remote(experience) + ''' + # + if not hasattr(self, "_target_idx"): + self._target_idx = 0 + chosen_trainer = self.target_trainer_list[self._target_idx] + if 'debug' in self.generate_kwargs and self.generate_kwargs['debug'] == True: + print(f"[maker] sending exp to {chosen_trainer}") + chosen_trainer.buffer_append.remote(experience) + self._target_idx = (self._target_idx + 1) % len(self.target_trainer_list) + + def workingloop(self, dataset, tokenizer: Optional[Callable[[Any], dict]] = None, times=5000 * 50000): + self._get_ready() + sampler = self.strategy.setup_sampler(dataset) + for _ in range(times): + rand_prompts = sampler.sample(self.experience_batch_size) + if tokenizer is not None: + inputs = tokenizer(rand_prompts) + else: + inputs = rand_prompts + self._model_visit_lock.acquire() + experience = self._make_experience(inputs=inputs) + self._model_visit_lock.release() + self._send_experience(experience=experience) + + @ray.method(concurrency_group="model_io") + def initialize_experience_maker(self, init_actor: Actor, init_critic: Critic): + ''' + called by trainer. Only once. 
+ ''' + # TODO: reduce malloc + if self.fully_initialized: + return + if 'debug' in self.generate_kwargs and self.generate_kwargs['debug'] == True: + print('[maker] INIT') + with torch.no_grad(): + with self.strategy.model_init_context(): + actor = init_actor + critic = init_critic + initial_model = deepcopy(actor) + reward_model = RewardModel(deepcopy(critic.model), + deepcopy(critic.value_head)).to(torch.cuda.current_device()) + if self.strategy_str != 'colossalai_gemini': + actor.to(torch.float16).to(torch.cuda.current_device()) + critic.to(torch.float16).to(torch.cuda.current_device()) + initial_model.to(torch.float16).to(torch.cuda.current_device()) + reward_model.to(torch.float16).to(torch.cuda.current_device()) + + self.experience_maker.actor = self.strategy.prepare(actor) + self.experience_maker.critic = self.strategy.prepare(critic) + self.experience_maker.initial_model = self.strategy.prepare(initial_model) + self.experience_maker.reward_model = self.strategy.prepare(reward_model) + self.fully_initialized = True + + @ray.method(concurrency_group="model_io") + def update_experience_maker(self, new_actor: Actor, new_critic: Critic): + ''' + called by trainer + ''' + # TODO: reduce malloc + self._model_visit_lock.acquire() + with torch.no_grad(): + if 'debug' in self.generate_kwargs and self.generate_kwargs['debug'] == True: + print("[maker] UPDATE ") + if self.strategy_str != 'colossalai_gemini': + new_actor.to(torch.float16).to(torch.cuda.current_device()) + new_critic.to(torch.float16).to(torch.cuda.current_device()) + self.experience_maker.actor = self.strategy.prepare(new_actor) + self.experience_maker.critic = self.strategy.prepare(new_critic) + self._model_visit_lock.release() diff --git a/applications/Chat/coati/ray/src/pipeline_strategy.py b/applications/Chat/coati/ray/src/pipeline_strategy.py new file mode 100644 index 000000000..1780839c6 --- /dev/null +++ b/applications/Chat/coati/ray/src/pipeline_strategy.py @@ -0,0 +1,105 @@ +# WIP + + +from coati.trainer.strategies import Strategy +from coati.trainer.strategies import NaiveStrategy +from coati.models.base import Actor, RewardModel, Critic + +import numpy as np +import torch +from torch._C._distributed_rpc import _is_current_rpc_agent_set + +import colossalai +from colossalai.pipeline.pipeline_process_group import ppg +from colossalai.pipeline.rpc._pipeline_schedule import OneFOneBPipelineEngine +from colossalai.fx import ColoTracer +from colossalai.fx.passes.adding_split_node_pass import balanced_split_pass, split_with_split_nodes_pass +from colossalai.pipeline.middleware.adaptor import get_fx_topology + + +import os +from functools import partial +import random + +rpc_is_initialized = _is_current_rpc_agent_set + +class PipelineModel(torch.nn.Module): + ''' + Actor has 2 kinds of jobs: forward and generate. 
+    so it is better to pipeline only the inner model.
+    '''
+    def __init__(self,
+                 model: torch.nn.Module,
+                 stage_num: int,
+                 num_microbatches: int,
+                 data_kwargs=None,
+                 ):
+        super().__init__()
+        # build the partition module for the local pipeline stage;
+        # data_kwargs supplies example inputs used for FX tracing
+        def create_partition_module(pp_rank: int, stage_num: int, model, data_kwargs):
+            model.eval()
+            tracer = ColoTracer()
+            meta_args = {k: v.to('meta') for k, v in data_kwargs.items()}
+            graph = tracer.trace(root=model, meta_args=meta_args)
+            gm = torch.fx.GraphModule(model, graph, model.__class__.__name__)
+            annotated_model = balanced_split_pass(gm, stage_num)
+            top_module, split_submodules = split_with_split_nodes_pass(annotated_model, merge_output=True)
+            topo = get_fx_topology(top_module)
+            for submodule in split_submodules:
+                if isinstance(submodule, torch.fx.GraphModule):
+                    setattr(submodule, '_topo', topo)
+            return split_submodules[pp_rank + 1]
+
+        def partition(model, data_kwargs: dict, pp_rank: int, chunk: int, stage_num: int):
+            return create_partition_module(pp_rank, stage_num, model, data_kwargs)
+
+        self.inference_engine = OneFOneBPipelineEngine(
+            partition_fn=partial(partition, model, data_kwargs),
+            stage_num=stage_num,
+            num_microbatches=num_microbatches,
+            device='cuda',
+        )
+
+    def forward(self, **model_inputs):
+        return self.inference_engine.forward_backward(**model_inputs, forward_only=True)
+
+
+class PPStrategy(NaiveStrategy):
+    """
+    Strategy for pipeline inference (inference only!), run on the master node only.
+    """
+
+    def __init__(self, seed: int = 42):
+        self.seed = seed
+        super().__init__()
+
+    def setup_distributed(self) -> None:
+        colossalai.launch_from_torch({}, seed=self.seed)
+        ppg.set_global_info(rank=int(os.environ['RANK']),
+                            world_size=int(os.environ['WORLD_SIZE']),
+                            dp_degree=1,
+                            tp_degree=1,
+                            num_worker_threads=128,
+                            device="cuda")
+
+    def model_init_context(self):
+        return super().model_init_context()
+
+    def setup_model(self, model: torch.nn.Module) -> torch.nn.Module:
+        if isinstance(model, (Actor, RewardModel, Critic)):
+            # TODO (WIP): stage_num and num_microbatches still need to be plumbed through
+            model.model = PipelineModel(model.model)
+        return model
+
+    def set_seed(self, seed: int) -> None:
+        random.seed(seed)
+        np.random.seed(seed)
+        torch.manual_seed(seed)
+
diff --git a/applications/Chat/coati/ray/src/utils.py b/applications/Chat/coati/ray/src/utils.py
new file mode 100644
index 000000000..c750879b6
--- /dev/null
+++ b/applications/Chat/coati/ray/src/utils.py
@@ -0,0 +1,48 @@
+import torch.distributed as dist
+from typing import Any, Callable, Dict, List, Optional
+from coati.models.bloom import BLOOMActor, BLOOMCritic
+from coati.models.gpt import GPTActor, GPTCritic
+from coati.models.opt import OPTActor, OPTCritic
+from coati.trainer.strategies import ColossalAIStrategy, DDPStrategy, NaiveStrategy
+import torch
+import os
+
+def is_rank_0() -> bool:
+    return not dist.is_initialized() or dist.get_rank() == 0
+
+
+def get_cuda_actor_critic_from_args(model: str, pretrained: Optional[str] = None, lora_rank: int = 0):
+    if model == 'gpt2':
+        actor = GPTActor(pretrained=pretrained, lora_rank=lora_rank).to(torch.cuda.current_device())
+        critic = GPTCritic(pretrained=pretrained, lora_rank=lora_rank).to(torch.cuda.current_device())
+    elif model == 'bloom':
+        actor = BLOOMActor(pretrained=pretrained, lora_rank=lora_rank).to(torch.cuda.current_device())
+        critic = BLOOMCritic(pretrained=pretrained, lora_rank=lora_rank).to(torch.cuda.current_device())
+    elif model == 'opt':
+        actor = OPTActor(pretrained=pretrained, lora_rank=lora_rank).to(torch.cuda.current_device())
+        critic = OPTCritic(pretrained=pretrained, lora_rank=lora_rank).to(torch.cuda.current_device())
+    else:
+        raise ValueError(f'Unsupported model "{model}"')
+    return actor, critic
+
+
+def get_strategy_from_args(strategy: str):
+    if strategy == 'naive':
+        strategy_ = NaiveStrategy()
+    elif strategy == 'ddp':
+        strategy_ = DDPStrategy()
+    elif strategy == 'colossalai_gemini':
+        strategy_ = ColossalAIStrategy(stage=3, placement_policy='cuda', initial_scale=2**5)
+    elif strategy == 'colossalai_zero2':
+        strategy_ = ColossalAIStrategy(stage=2, placement_policy='cuda')
+    else:
+        raise ValueError(f'Unsupported strategy "{strategy}"')
+    return strategy_
+
+
+def set_dist_env(env_info: Dict[str, str]):
+    os.environ["RANK"] = env_info['rank']
+    os.environ["LOCAL_RANK"] = env_info['local_rank']
+    os.environ["WORLD_SIZE"] = env_info['world_size']
+    os.environ['MASTER_PORT'] = env_info['master_port']
+    os.environ['MASTER_ADDR'] = env_info['master_addr']
diff --git a/applications/Chat/coati/trainer/utils.py b/applications/Chat/coati/trainer/utils.py
index 6c9f7f085..1b17a0421 100644
--- a/applications/Chat/coati/trainer/utils.py
+++ b/applications/Chat/coati/trainer/utils.py
@@ -1,5 +1,14 @@
 import torch.distributed as dist
+from typing import Any, Callable, Dict, List, Optional
+from coati.models.bloom import BLOOMActor, BLOOMCritic
+from coati.models.gpt import GPTActor, GPTCritic
+from coati.models.opt import OPTActor, OPTCritic
+from coati.trainer.strategies import ColossalAIStrategy, DDPStrategy, NaiveStrategy
+import torch
+import os
 
 
 def is_rank_0() -> bool:
     return not dist.is_initialized() or dist.get_rank() == 0
+
+
diff --git a/applications/Chat/coati/utils/__init__.py b/applications/Chat/coati/utils/__init__.py
index e75401d38..112b82b97 100644
--- a/applications/Chat/coati/utils/__init__.py
+++ b/applications/Chat/coati/utils/__init__.py
@@ -1,3 +1,3 @@
 from .tokenizer_utils import prepare_llama_tokenizer_and_embedding, smart_tokenizer_and_embedding_resize
 
-__all__ = ['smart_tokenizer_and_embedding_resize', 'prepare_llama_tokenizer_and_embedding']
+__all__ = ['smart_tokenizer_and_embedding_resize', 'prepare_llama_tokenizer_and_embedding']
\ No newline at end of file
diff --git a/applications/Chat/examples/train_prompts.sh b/applications/Chat/examples/train_prompts.sh
index b750cf358..8e1ce67ec 100755
--- a/applications/Chat/examples/train_prompts.sh
+++ b/applications/Chat/examples/train_prompts.sh
@@ -15,4 +15,6 @@ set_n_least_used_CUDA_VISIBLE_DEVICES() {
 
 set_n_least_used_CUDA_VISIBLE_DEVICES 2
 
+# torchrun --standalone --nproc_per_node=2 train_prompts.py prompts.csv --strategy colossalai_zero2
+
 torchrun --standalone --nproc_per_node=2 train_prompts.py --prompt_path /path/to/data.json --strategy colossalai_zero2
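
Reviewer note: the "target-revolving" dispatch in _send_experience() is the behavioral core of this patch's maker-to-trainer hand-off. Below is a minimal, self-contained sketch of that round-robin pattern, for illustration only: RoundRobinDispatcher and the plain-list buffers are hypothetical stand-ins, whereas in the patch the targets are Ray trainer actor handles and the send is target_trainer.buffer_append.remote(experience).

# Illustrative sketch (not part of the patch): round-robin dispatch over targets.
class RoundRobinDispatcher:
    def __init__(self, targets):
        self.targets = list(targets)   # in the patch: Ray trainer actor handles
        self._idx = 0                  # mirrors self._target_idx in ExperienceMakerHolder

    def send(self, item):
        target = self.targets[self._idx]
        target.append(item)            # in the patch: target.buffer_append.remote(experience)
        self._idx = (self._idx + 1) % len(self.targets)


if __name__ == '__main__':
    buffers = [[], []]                 # two fake trainer buffers (e.g. a 1-maker-2-trainer layout)
    dispatcher = RoundRobinDispatcher(buffers)
    for i in range(5):
        dispatcher.send(i)
    print(buffers)                     # [[0, 2, 4], [1, 3]]

The design choice mirrors the commit message: instead of polling every trainer's buffer length and picking the least-loaded one (the commented-out block kept in _send_experience), the maker simply rotates through its target trainers, which avoids the timeout-prone ray.get polling.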