mirror of https://github.com/hpcaitech/ColossalAI
[chatgpt] Detached PPO Training (#3195)
* run the base
* working on dist ppo
* sync
* detached trainer
* update detached trainer. no maker update function
* facing init problem
* 1 maker 1 trainer detached run. but no model update
* facing cuda problem
* fix save functions
* verified maker update
* nothing
* add ignore
* analyze loss issue
* remove some debug code
* facing 2m1t stuck issue
* 2m1t verified
* do not use torchrun
* working on 2m2t
* working on 2m2t
* initialize strategy in ray actor env
* facing actor's init order issue
* facing ddp model update issue (need to unwrap ddp)
* unwrap ddp actor
* checking 1m2t stuck problem
* nothing
* set timeout for trainer choosing. It solves the stuck problem!
* delete some debug output
* rename to sync with upstream
* rename to sync with upstream
* coati rename
* nothing
* Detach the replay buffer from the trainer and make it a Ray actor. Two benefits: 1. support TP trainers; 2. asynchronous buffer operations
* experience_maker_holder performs target-revolving _send_experience() instead of length comparison
* move code to ray subfolder
* working on pipeline inference
* apply comments

---------

Co-authored-by: csric <richcsr256@gmail.com>
parent d329c294ec
commit e355144375
@ -144,3 +144,5 @@ docs/.build

# wandb log
example/wandb/

examples/awesome-chatgpt-prompts/
@ -0,0 +1,2 @@
from .src.detached_replay_buffer import DetachedReplayBuffer
from .src.detached_trainer_ppo import DetachedPPOTrainer
@ -0,0 +1,153 @@
import argparse
from copy import deepcopy

import pandas as pd
import torch
from coati.trainer import PPOTrainer

from coati.ray.src.experience_maker_holder import ExperienceMakerHolder
from coati.ray.src.detached_trainer_ppo import DetachedPPOTrainer

from coati.trainer.strategies import ColossalAIStrategy, DDPStrategy, NaiveStrategy
from coati.experience_maker import NaiveExperienceMaker
from torch.optim import Adam
from transformers import AutoTokenizer, BloomTokenizerFast
from transformers.models.gpt2.tokenization_gpt2 import GPT2Tokenizer

from colossalai.nn.optimizer import HybridAdam

import ray
import os
import socket


def get_free_port():
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
        s.bind(('', 0))
        return s.getsockname()[1]


def get_local_ip():
    with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as s:
        s.connect(('8.8.8.8', 80))
        return s.getsockname()[0]


def main(args):
    master_addr = str(get_local_ip())
    # trainer_env_info
    trainer_port = str(get_free_port())
    env_info_trainer = {'local_rank': '0',
                        'rank': '0',
                        'world_size': '1',
                        'master_port': trainer_port,
                        'master_addr': master_addr}

    # maker_env_info
    maker_port = str(get_free_port())
    env_info_maker = {'local_rank': '0',
                      'rank': '0',
                      'world_size': '1',
                      'master_port': maker_port,
                      'master_addr': master_addr}

    # configure tokenizer
    if args.model == 'gpt2':
        tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
        tokenizer.pad_token = tokenizer.eos_token
    elif args.model == 'bloom':
        tokenizer = BloomTokenizerFast.from_pretrained(args.pretrain)
        tokenizer.pad_token = tokenizer.eos_token
    elif args.model == 'opt':
        tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m")
    else:
        raise ValueError(f'Unsupported model "{args.model}"')

    # configure Trainer
    trainer_ref = DetachedPPOTrainer.options(name="trainer1", num_gpus=1, max_concurrency=2).remote(
        experience_maker_holder_name_list=["maker1"],
        strategy=args.trainer_strategy,
        model=args.model,
        env_info=env_info_trainer,
        pretrained=args.pretrain,
        lora_rank=args.lora_rank,
        train_batch_size=args.train_batch_size,
        buffer_limit=16,
        experience_batch_size=args.experience_batch_size,
        max_epochs=args.max_epochs,
        # kwargs:
        max_length=128,
        do_sample=True,
        temperature=1.0,
        top_k=50,
        pad_token_id=tokenizer.pad_token_id,
        eos_token_id=tokenizer.eos_token_id,
        debug=args.debug,
    )

    # configure Experience Maker
    experience_holder_ref = ExperienceMakerHolder.options(name="maker1", num_gpus=1, max_concurrency=2).remote(
        detached_trainer_name_list=["trainer1"],
        strategy=args.maker_strategy,
        env_info=env_info_maker,
        experience_batch_size=args.experience_batch_size,
        kl_coef=0.1,
        # kwargs:
        max_length=128,
        do_sample=True,
        temperature=1.0,
        top_k=50,
        pad_token_id=tokenizer.pad_token_id,
        eos_token_id=tokenizer.eos_token_id,
        debug=args.debug,
    )

    # the trainer sends its actor and critic to the experience holders
    ray.get(trainer_ref.initialize_remote_makers.remote())

    # configure sampler
    dataset = pd.read_csv(args.prompt_path)['prompt']

    def tokenize_fn(texts):
        # MUST pad to max length to ensure inputs of all ranks have the same length
        # Different lengths may cause a hang when using gemini, since ranks would take different numbers of generation steps
        batch = tokenizer(texts, return_tensors='pt', max_length=96, padding='max_length', truncation=True)
        return {k: v.cuda() for k, v in batch.items()}

    trainer_done_ref = trainer_ref.fit.remote(num_episodes=args.num_episodes, max_timesteps=args.max_timesteps, update_timesteps=args.update_timesteps)
    num_exp_per_maker = args.num_episodes * args.max_timesteps // args.update_timesteps * args.max_epochs + 3    # +3 for fault tolerance
    maker_done_ref = experience_holder_ref.workingloop.remote(dataset, tokenize_fn, times=num_exp_per_maker)

    ray.get([trainer_done_ref, maker_done_ref])

    # save model checkpoint after fitting
    trainer_ref.strategy_save_actor.remote(args.save_path, only_rank0=True)
    # save optimizer checkpoint on all ranks
    if args.need_optim_ckpt:
        trainer_ref.strategy_save_actor_optim.remote('actor_optim_checkpoint_prompts_%d.pt' % (torch.cuda.current_device()),
                                                     only_rank0=False)


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('prompt_path')
    parser.add_argument('--trainer_strategy',
                        choices=['naive', 'ddp', 'colossalai_gemini', 'colossalai_zero2'],
                        default='naive')
    parser.add_argument('--maker_strategy',
                        choices=['naive', 'ddp', 'colossalai_gemini', 'colossalai_zero2'],
                        default='naive')
    parser.add_argument('--model', default='gpt2', choices=['gpt2', 'bloom', 'opt'])
    parser.add_argument('--pretrain', type=str, default=None)
    parser.add_argument('--save_path', type=str, default='actor_checkpoint_prompts.pt')
    parser.add_argument('--need_optim_ckpt', type=bool, default=False)
    parser.add_argument('--num_episodes', type=int, default=10)
    parser.add_argument('--max_timesteps', type=int, default=10)
    parser.add_argument('--update_timesteps', type=int, default=10)
    parser.add_argument('--max_epochs', type=int, default=5)
    parser.add_argument('--train_batch_size', type=int, default=8)
    parser.add_argument('--experience_batch_size', type=int, default=8)
    parser.add_argument('--lora_rank', type=int, default=0, help="low-rank adaptation matrices rank")

    parser.add_argument('--debug', action='store_true')
    args = parser.parse_args()
    ray.init(namespace=os.environ["RAY_NAMESPACE"])
    main(args)
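The `times` budget handed to the maker's working loop mirrors the trainer's consumption rate: one fit() call runs num_episodes * (max_timesteps // update_timesteps) learn phases, and each learn phase samples max_epochs batches from the detached buffer. The snippet below is only a worked check of the formula already used in the script, plugged with the argparse defaults above.

# Worked example with the script's default arguments (restates the formula, adds nothing new):
num_episodes, max_timesteps, update_timesteps, max_epochs = 10, 10, 10, 5
consumed_batches = num_episodes * max_timesteps // update_timesteps * max_epochs    # 10 * 10 // 10 * 5 = 50
num_exp_per_maker = consumed_batches + 3                                            # 53, "+3 for fault tolerance"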
@ -0,0 +1,23 @@
set_n_least_used_CUDA_VISIBLE_DEVICES() {
    local n=${1:-"9999"}
    echo "GPU Memory Usage:"
    local FIRST_N_GPU_IDS=$(nvidia-smi --query-gpu=memory.used --format=csv \
        | tail -n +2 \
        | nl -v 0 \
        | tee /dev/tty \
        | sort -g -k 2 \
        | awk '{print $1}' \
        | head -n $n)
    export CUDA_VISIBLE_DEVICES=$(echo $FIRST_N_GPU_IDS | sed 's/ /,/g')
    echo "Now CUDA_VISIBLE_DEVICES is set to:"
    echo "CUDA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES"
}

set_n_least_used_CUDA_VISIBLE_DEVICES 2

export RAY_NAMESPACE="admin"

python 1m1t.py "/path/to/prompts.csv" \
    --trainer_strategy colossalai_zero2 --maker_strategy naive --lora_rank 2 --pretrain "facebook/opt-350m" --model 'opt' \
    --num_episodes 10 --max_timesteps 10 --update_timesteps 10 \
    --max_epochs 10 --debug
@ -0,0 +1,186 @@
import argparse
from copy import deepcopy

import pandas as pd
import torch
from coati.trainer import PPOTrainer

from coati.ray.src.experience_maker_holder import ExperienceMakerHolder
from coati.ray.src.detached_trainer_ppo import DetachedPPOTrainer

from coati.trainer.strategies import ColossalAIStrategy, DDPStrategy, NaiveStrategy
from coati.experience_maker import NaiveExperienceMaker
from torch.optim import Adam
from transformers import AutoTokenizer, BloomTokenizerFast
from transformers.models.gpt2.tokenization_gpt2 import GPT2Tokenizer

from colossalai.nn.optimizer import HybridAdam

import ray
import os
import socket


def get_free_port():
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
        s.bind(('', 0))
        return s.getsockname()[1]


def get_local_ip():
    with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as s:
        s.connect(('8.8.8.8', 80))
        return s.getsockname()[0]


def main(args):
    master_addr = str(get_local_ip())
    # trainer_env_info
    trainer_port = str(get_free_port())
    env_info_trainer_1 = {'local_rank': '0',
                          'rank': '0',
                          'world_size': '2',
                          'master_port': trainer_port,
                          'master_addr': master_addr}
    env_info_trainer_2 = {'local_rank': '0',
                          'rank': '1',
                          'world_size': '2',
                          'master_port': trainer_port,
                          'master_addr': master_addr}
    # maker_env_info
    maker_port = str(get_free_port())
    env_info_maker_1 = {'local_rank': '0',
                        'rank': '0',
                        'world_size': '2',
                        'master_port': maker_port,
                        'master_addr': master_addr}
    print([env_info_trainer_1,
           env_info_trainer_2,
           env_info_maker_1])
    ray.init(dashboard_port=1145)
    # configure tokenizer
    if args.model == 'gpt2':
        tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
        tokenizer.pad_token = tokenizer.eos_token
    elif args.model == 'bloom':
        tokenizer = BloomTokenizerFast.from_pretrained(args.pretrain)
        tokenizer.pad_token = tokenizer.eos_token
    elif args.model == 'opt':
        tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m")
    else:
        raise ValueError(f'Unsupported model "{args.model}"')

    # configure Trainer
    trainer_1_ref = DetachedPPOTrainer.options(name="trainer1", namespace=os.environ["RAY_NAMESPACE"], num_gpus=1, max_concurrency=2).remote(
        experience_maker_holder_name_list=["maker1"],
        strategy=args.trainer_strategy,
        model=args.model,
        env_info=env_info_trainer_1,
        pretrained=args.pretrain,
        lora_rank=args.lora_rank,
        train_batch_size=args.train_batch_size,
        buffer_limit=16,
        experience_batch_size=args.experience_batch_size,
        max_epochs=args.max_epochs,
        # kwargs:
        max_length=128,
        do_sample=True,
        temperature=1.0,
        top_k=50,
        pad_token_id=tokenizer.pad_token_id,
        eos_token_id=tokenizer.eos_token_id,
        debug=args.debug,
    )

    trainer_2_ref = DetachedPPOTrainer.options(name="trainer2", namespace=os.environ["RAY_NAMESPACE"], num_gpus=1, max_concurrency=2).remote(
        experience_maker_holder_name_list=["maker1"],
        strategy=args.trainer_strategy,
        model=args.model,
        env_info=env_info_trainer_2,
        pretrained=args.pretrain,
        lora_rank=args.lora_rank,
        train_batch_size=args.train_batch_size,
        buffer_limit=16,
        experience_batch_size=args.experience_batch_size,
        max_epochs=args.max_epochs,
        # kwargs:
        max_length=128,
        do_sample=True,
        temperature=1.0,
        top_k=50,
        pad_token_id=tokenizer.pad_token_id,
        eos_token_id=tokenizer.eos_token_id,
        debug=args.debug,
    )

    # configure Experience Maker
    experience_holder_1_ref = ExperienceMakerHolder.options(name="maker1", namespace=os.environ["RAY_NAMESPACE"], num_gpus=1, max_concurrency=2).remote(
        detached_trainer_name_list=["trainer1", "trainer2"],
        strategy=args.maker_strategy,
        env_info=env_info_maker_1,
        experience_batch_size=args.experience_batch_size,
        kl_coef=0.1,
        # kwargs:
        max_length=128,
        do_sample=True,
        temperature=1.0,
        top_k=50,
        pad_token_id=tokenizer.pad_token_id,
        eos_token_id=tokenizer.eos_token_id,
        debug=args.debug,
    )

    # the trainer sends its actor and critic to the experience holders
    # TODO: balance duty
    ray.get(trainer_1_ref.initialize_remote_makers.remote())

    # configure sampler
    dataset = pd.read_csv(args.prompt_path)['prompt']

    def tokenize_fn(texts):
        # MUST pad to max length to ensure inputs of all ranks have the same length
        # Different lengths may cause a hang when using gemini, since ranks would take different numbers of generation steps
        batch = tokenizer(texts, return_tensors='pt', max_length=96, padding='max_length', truncation=True)
        return {k: v.cuda() for k, v in batch.items()}

    trainer_1_done_ref = trainer_1_ref.fit.remote(num_episodes=args.num_episodes, max_timesteps=args.max_timesteps, update_timesteps=args.update_timesteps)
    trainer_2_done_ref = trainer_2_ref.fit.remote(num_episodes=args.num_episodes, max_timesteps=args.max_timesteps, update_timesteps=args.update_timesteps)
    num_exp_per_maker = args.num_episodes * args.max_timesteps // args.update_timesteps * args.max_epochs * 2 + 3    # +3 for fault tolerance
    maker_1_done_ref = experience_holder_1_ref.workingloop.remote(dataset, tokenize_fn, times=num_exp_per_maker)

    ray.get([trainer_1_done_ref, trainer_2_done_ref, maker_1_done_ref])
    # save model checkpoint after fitting
    trainer_1_ref.strategy_save_actor.remote(args.save_path, only_rank0=True)
    trainer_2_ref.strategy_save_actor.remote(args.save_path, only_rank0=True)
    # save optimizer checkpoint on all ranks
    if args.need_optim_ckpt:
        trainer_1_ref.strategy_save_actor_optim.remote('actor_optim_checkpoint_prompts_%d.pt' % (torch.cuda.current_device()),
                                                       only_rank0=False)
        trainer_2_ref.strategy_save_actor_optim.remote('actor_optim_checkpoint_prompts_%d.pt' % (torch.cuda.current_device()),
                                                       only_rank0=False)


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('prompt_path')
    parser.add_argument('--trainer_strategy',
                        choices=['naive', 'ddp', 'colossalai_gemini', 'colossalai_zero2'],
                        default='naive')
    parser.add_argument('--maker_strategy',
                        choices=['naive', 'ddp', 'colossalai_gemini', 'colossalai_zero2'],
                        default='naive')
    parser.add_argument('--model', default='gpt2', choices=['gpt2', 'bloom', 'opt'])
    parser.add_argument('--pretrain', type=str, default=None)
    parser.add_argument('--save_path', type=str, default='actor_checkpoint_prompts.pt')
    parser.add_argument('--need_optim_ckpt', type=bool, default=False)
    parser.add_argument('--num_episodes', type=int, default=10)
    parser.add_argument('--max_timesteps', type=int, default=10)
    parser.add_argument('--update_timesteps', type=int, default=10)
    parser.add_argument('--max_epochs', type=int, default=5)
    parser.add_argument('--train_batch_size', type=int, default=8)
    parser.add_argument('--experience_batch_size', type=int, default=8)
    parser.add_argument('--lora_rank', type=int, default=0, help="low-rank adaptation matrices rank")

    parser.add_argument('--debug', action='store_true')
    args = parser.parse_args()
    main(args)
@ -0,0 +1,23 @@
set_n_least_used_CUDA_VISIBLE_DEVICES() {
    local n=${1:-"9999"}
    echo "GPU Memory Usage:"
    local FIRST_N_GPU_IDS=$(nvidia-smi --query-gpu=memory.used --format=csv \
        | tail -n +2 \
        | nl -v 0 \
        | tee /dev/tty \
        | sort -g -k 2 \
        | awk '{print $1}' \
        | head -n $n)
    export CUDA_VISIBLE_DEVICES=$(echo $FIRST_N_GPU_IDS | sed 's/ /,/g')
    echo "Now CUDA_VISIBLE_DEVICES is set to:"
    echo "CUDA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES"
}

set_n_least_used_CUDA_VISIBLE_DEVICES 2

export RAY_NAMESPACE="admin"

python 1m2t.py "/path/to/prompts.csv" --model gpt2 \
    --maker_strategy naive --trainer_strategy ddp --lora_rank 2 \
    --num_episodes 10 --max_timesteps 10 --update_timesteps 10 \
    --max_epochs 10 #--debug
@ -0,0 +1,140 @@
import argparse
from copy import deepcopy

import pandas as pd
import torch
from coati.trainer import PPOTrainer

from coati.ray.src.experience_maker_holder import ExperienceMakerHolder
from coati.ray.src.detached_trainer_ppo import DetachedPPOTrainer

from coati.trainer.strategies import ColossalAIStrategy, DDPStrategy, NaiveStrategy
from coati.experience_maker import NaiveExperienceMaker
from torch.optim import Adam
from transformers import AutoTokenizer, BloomTokenizerFast
from transformers.models.gpt2.tokenization_gpt2 import GPT2Tokenizer

from colossalai.nn.optimizer import HybridAdam

import ray
import os
import socket


def main(args):
    # configure tokenizer
    if args.model == 'gpt2':
        tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
        tokenizer.pad_token = tokenizer.eos_token
    elif args.model == 'bloom':
        tokenizer = BloomTokenizerFast.from_pretrained(args.pretrain)
        tokenizer.pad_token = tokenizer.eos_token
    elif args.model == 'opt':
        tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m")
    else:
        raise ValueError(f'Unsupported model "{args.model}"')

    # configure Trainer
    trainer_ref = DetachedPPOTrainer.options(name="trainer1", num_gpus=1, max_concurrency=2).remote(
        experience_maker_holder_name_list=["maker1", "maker2"],
        strategy=args.trainer_strategy,
        model=args.model,
        pretrained=args.pretrain,
        lora_rank=args.lora_rank,
        train_batch_size=args.train_batch_size,
        buffer_limit=16,
        experience_batch_size=args.experience_batch_size,
        max_epochs=args.max_epochs,
        # kwargs:
        max_length=128,
        do_sample=True,
        temperature=1.0,
        top_k=50,
        pad_token_id=tokenizer.pad_token_id,
        eos_token_id=tokenizer.eos_token_id,
        debug=args.debug,
    )

    # configure Experience Maker
    experience_holder_1_ref = ExperienceMakerHolder.options(name="maker1", num_gpus=1, max_concurrency=2).remote(
        detached_trainer_name_list=["trainer1"],
        strategy=args.maker_strategy,
        experience_batch_size=args.experience_batch_size,
        kl_coef=0.1,
        # kwargs:
        max_length=128,
        do_sample=True,
        temperature=1.0,
        top_k=50,
        pad_token_id=tokenizer.pad_token_id,
        eos_token_id=tokenizer.eos_token_id,
        debug=args.debug,
    )

    experience_holder_2_ref = ExperienceMakerHolder.options(name="maker2", num_gpus=1, max_concurrency=2).remote(
        detached_trainer_name_list=["trainer1"],
        strategy=args.maker_strategy,
        experience_batch_size=args.experience_batch_size,
        kl_coef=0.1,
        # kwargs:
        max_length=128,
        do_sample=True,
        temperature=1.0,
        top_k=50,
        pad_token_id=tokenizer.pad_token_id,
        eos_token_id=tokenizer.eos_token_id,
        debug=args.debug,
    )

    # the trainer sends its actor and critic to the experience holders
    ray.get(trainer_ref.initialize_remote_makers.remote())

    # configure sampler
    dataset = pd.read_csv(args.prompt_path)['prompt']

    def tokenize_fn(texts):
        # MUST pad to max length to ensure inputs of all ranks have the same length
        # Different lengths may cause a hang when using gemini, since ranks would take different numbers of generation steps
        batch = tokenizer(texts, return_tensors='pt', max_length=96, padding='max_length', truncation=True)
        return {k: v.cuda() for k, v in batch.items()}

    trainer_done_ref = trainer_ref.fit.remote(num_episodes=args.num_episodes, max_timesteps=args.max_timesteps, update_timesteps=args.update_timesteps)
    num_exp_per_maker = args.num_episodes * args.max_timesteps // args.update_timesteps * args.max_epochs // 2 + 3    # +3 for fault tolerance
    maker_1_done_ref = experience_holder_1_ref.workingloop.remote(dataset, tokenize_fn, times=num_exp_per_maker)
    maker_2_done_ref = experience_holder_2_ref.workingloop.remote(dataset, tokenize_fn, times=num_exp_per_maker)

    ray.get([trainer_done_ref, maker_1_done_ref, maker_2_done_ref])

    # save model checkpoint after fitting
    trainer_ref.strategy_save_actor.remote(args.save_path, only_rank0=True)
    # save optimizer checkpoint on all ranks
    if args.need_optim_ckpt:
        trainer_ref.strategy_save_actor_optim.remote('actor_optim_checkpoint_prompts_%d.pt' % (torch.cuda.current_device()),
                                                     only_rank0=False)


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('prompt_path')
    parser.add_argument('--trainer_strategy',
                        choices=['naive', 'ddp', 'colossalai_gemini', 'colossalai_zero2'],
                        default='naive')
    parser.add_argument('--maker_strategy',
                        choices=['naive', 'ddp', 'colossalai_gemini', 'colossalai_zero2'],
                        default='naive')
    parser.add_argument('--model', default='gpt2', choices=['gpt2', 'bloom', 'opt'])
    parser.add_argument('--pretrain', type=str, default=None)
    parser.add_argument('--save_path', type=str, default='actor_checkpoint_prompts.pt')
    parser.add_argument('--need_optim_ckpt', type=bool, default=False)
    parser.add_argument('--num_episodes', type=int, default=10)
    parser.add_argument('--max_timesteps', type=int, default=10)
    parser.add_argument('--update_timesteps', type=int, default=10)
    parser.add_argument('--max_epochs', type=int, default=5)
    parser.add_argument('--train_batch_size', type=int, default=8)
    parser.add_argument('--experience_batch_size', type=int, default=8)
    parser.add_argument('--lora_rank', type=int, default=0, help="low-rank adaptation matrices rank")

    parser.add_argument('--debug', action='store_true')
    args = parser.parse_args()
    ray.init(namespace=os.environ["RAY_NAMESPACE"])
    main(args)
@ -0,0 +1,23 @@
set_n_least_used_CUDA_VISIBLE_DEVICES() {
    local n=${1:-"9999"}
    echo "GPU Memory Usage:"
    local FIRST_N_GPU_IDS=$(nvidia-smi --query-gpu=memory.used --format=csv \
        | tail -n +2 \
        | nl -v 0 \
        | tee /dev/tty \
        | sort -g -k 2 \
        | awk '{print $1}' \
        | head -n $n)
    export CUDA_VISIBLE_DEVICES=$(echo $FIRST_N_GPU_IDS | sed 's/ /,/g')
    echo "Now CUDA_VISIBLE_DEVICES is set to:"
    echo "CUDA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES"
}

set_n_least_used_CUDA_VISIBLE_DEVICES 3

export RAY_NAMESPACE="admin"

python 2m1t.py "/path/to/prompts.csv" \
    --trainer_strategy naive --maker_strategy naive --lora_rank 2 --pretrain "facebook/opt-350m" --model 'opt' \
    --num_episodes 10 --max_timesteps 10 --update_timesteps 10 \
    --max_epochs 10 # --debug
@ -0,0 +1,209 @@
import argparse
from copy import deepcopy

import pandas as pd
import torch
from coati.trainer import PPOTrainer

from coati.ray.src.experience_maker_holder import ExperienceMakerHolder
from coati.ray.src.detached_trainer_ppo import DetachedPPOTrainer

from coati.trainer.strategies import ColossalAIStrategy, DDPStrategy, NaiveStrategy
from coati.experience_maker import NaiveExperienceMaker
from torch.optim import Adam
from transformers import AutoTokenizer, BloomTokenizerFast
from transformers.models.gpt2.tokenization_gpt2 import GPT2Tokenizer

from colossalai.nn.optimizer import HybridAdam

import ray
import os
import socket


def get_free_port():
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
        s.bind(('', 0))
        return s.getsockname()[1]


def get_local_ip():
    with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as s:
        s.connect(('8.8.8.8', 80))
        return s.getsockname()[0]


def main(args):
    master_addr = str(get_local_ip())
    # trainer_env_info
    trainer_port = str(get_free_port())
    env_info_trainer_1 = {'local_rank': '0',
                          'rank': '0',
                          'world_size': '2',
                          'master_port': trainer_port,
                          'master_addr': master_addr}
    env_info_trainer_2 = {'local_rank': '0',
                          'rank': '1',
                          'world_size': '2',
                          'master_port': trainer_port,
                          'master_addr': master_addr}
    # maker_env_info
    maker_port = str(get_free_port())
    env_info_maker_1 = {'local_rank': '0',
                        'rank': '0',
                        'world_size': '2',
                        'master_port': maker_port,
                        'master_addr': master_addr}
    env_info_maker_2 = {'local_rank': '0',
                        'rank': '1',
                        'world_size': '2',
                        'master_port': maker_port,
                        'master_addr': master_addr}
    print([env_info_trainer_1,
           env_info_trainer_2,
           env_info_maker_1,
           env_info_maker_2])
    ray.init()
    # configure tokenizer
    if args.model == 'gpt2':
        tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
        tokenizer.pad_token = tokenizer.eos_token
    elif args.model == 'bloom':
        tokenizer = BloomTokenizerFast.from_pretrained(args.pretrain)
        tokenizer.pad_token = tokenizer.eos_token
    elif args.model == 'opt':
        tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m")
    else:
        raise ValueError(f'Unsupported model "{args.model}"')

    # configure Trainer
    trainer_1_ref = DetachedPPOTrainer.options(name="trainer1", namespace=os.environ["RAY_NAMESPACE"], num_gpus=1, max_concurrency=2).remote(
        experience_maker_holder_name_list=["maker1", "maker2"],
        strategy=args.trainer_strategy,
        model=args.model,
        env_info=env_info_trainer_1,
        pretrained=args.pretrain,
        lora_rank=args.lora_rank,
        train_batch_size=args.train_batch_size,
        buffer_limit=16,
        experience_batch_size=args.experience_batch_size,
        max_epochs=args.max_epochs,
        # kwargs:
        max_length=128,
        do_sample=True,
        temperature=1.0,
        top_k=50,
        pad_token_id=tokenizer.pad_token_id,
        eos_token_id=tokenizer.eos_token_id,
        debug=args.debug,
    )

    trainer_2_ref = DetachedPPOTrainer.options(name="trainer2", namespace=os.environ["RAY_NAMESPACE"], num_gpus=1, max_concurrency=2).remote(
        experience_maker_holder_name_list=["maker1", "maker2"],
        strategy=args.trainer_strategy,
        model=args.model,
        env_info=env_info_trainer_2,
        pretrained=args.pretrain,
        lora_rank=args.lora_rank,
        train_batch_size=args.train_batch_size,
        buffer_limit=16,
        experience_batch_size=args.experience_batch_size,
        max_epochs=args.max_epochs,
        # kwargs:
        max_length=128,
        do_sample=True,
        temperature=1.0,
        top_k=50,
        pad_token_id=tokenizer.pad_token_id,
        eos_token_id=tokenizer.eos_token_id,
        debug=args.debug,
    )

    # configure Experience Maker
    experience_holder_1_ref = ExperienceMakerHolder.options(name="maker1", namespace=os.environ["RAY_NAMESPACE"], num_gpus=1, max_concurrency=2).remote(
        detached_trainer_name_list=["trainer1", "trainer2"],
        strategy=args.maker_strategy,
        env_info=env_info_maker_1,
        experience_batch_size=args.experience_batch_size,
        kl_coef=0.1,
        # kwargs:
        max_length=128,
        do_sample=True,
        temperature=1.0,
        top_k=50,
        pad_token_id=tokenizer.pad_token_id,
        eos_token_id=tokenizer.eos_token_id,
        debug=args.debug,
    )

    experience_holder_2_ref = ExperienceMakerHolder.options(name="maker2", namespace=os.environ["RAY_NAMESPACE"], num_gpus=1, max_concurrency=2).remote(
        detached_trainer_name_list=["trainer1", "trainer2"],
        strategy=args.maker_strategy,
        env_info=env_info_maker_2,
        experience_batch_size=args.experience_batch_size,
        kl_coef=0.1,
        # kwargs:
        max_length=128,
        do_sample=True,
        temperature=1.0,
        top_k=50,
        pad_token_id=tokenizer.pad_token_id,
        eos_token_id=tokenizer.eos_token_id,
        debug=args.debug,
    )

    # the trainer sends its actor and critic to the experience holders
    # TODO: balance duty
    ray.get(trainer_1_ref.initialize_remote_makers.remote())

    # configure sampler
    dataset = pd.read_csv(args.prompt_path)['prompt']

    def tokenize_fn(texts):
        # MUST pad to max length to ensure inputs of all ranks have the same length
        # Different lengths may cause a hang when using gemini, since ranks would take different numbers of generation steps
        batch = tokenizer(texts, return_tensors='pt', max_length=96, padding='max_length', truncation=True)
        return {k: v.cuda() for k, v in batch.items()}

    trainer_1_done_ref = trainer_1_ref.fit.remote(num_episodes=args.num_episodes, max_timesteps=args.max_timesteps, update_timesteps=args.update_timesteps)
    trainer_2_done_ref = trainer_2_ref.fit.remote(num_episodes=args.num_episodes, max_timesteps=args.max_timesteps, update_timesteps=args.update_timesteps)
    num_exp_per_maker = args.num_episodes * args.max_timesteps // args.update_timesteps * args.max_epochs + 3    # +3 for fault tolerance
    maker_1_done_ref = experience_holder_1_ref.workingloop.remote(dataset, tokenize_fn, times=num_exp_per_maker)
    maker_2_done_ref = experience_holder_2_ref.workingloop.remote(dataset, tokenize_fn, times=num_exp_per_maker)

    ray.get([trainer_1_done_ref, trainer_2_done_ref, maker_1_done_ref, maker_2_done_ref])
    # save model checkpoint after fitting
    trainer_1_ref.strategy_save_actor.remote(args.save_path, only_rank0=True)
    trainer_2_ref.strategy_save_actor.remote(args.save_path, only_rank0=True)
    # save optimizer checkpoint on all ranks
    if args.need_optim_ckpt:
        trainer_1_ref.strategy_save_actor_optim.remote('actor_optim_checkpoint_prompts_%d.pt' % (torch.cuda.current_device()),
                                                       only_rank0=False)
        trainer_2_ref.strategy_save_actor_optim.remote('actor_optim_checkpoint_prompts_%d.pt' % (torch.cuda.current_device()),
                                                       only_rank0=False)


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('prompt_path')
    parser.add_argument('--trainer_strategy',
                        choices=['naive', 'ddp', 'colossalai_gemini', 'colossalai_zero2'],
                        default='naive')
    parser.add_argument('--maker_strategy',
                        choices=['naive', 'ddp', 'colossalai_gemini', 'colossalai_zero2'],
                        default='naive')
    parser.add_argument('--model', default='gpt2', choices=['gpt2', 'bloom', 'opt'])
    parser.add_argument('--pretrain', type=str, default=None)
    parser.add_argument('--save_path', type=str, default='actor_checkpoint_prompts.pt')
    parser.add_argument('--need_optim_ckpt', type=bool, default=False)
    parser.add_argument('--num_episodes', type=int, default=10)
    parser.add_argument('--max_timesteps', type=int, default=10)
    parser.add_argument('--update_timesteps', type=int, default=10)
    parser.add_argument('--max_epochs', type=int, default=5)
    parser.add_argument('--train_batch_size', type=int, default=8)
    parser.add_argument('--experience_batch_size', type=int, default=8)
    parser.add_argument('--lora_rank', type=int, default=0, help="low-rank adaptation matrices rank")

    parser.add_argument('--debug', action='store_true')
    args = parser.parse_args()
    main(args)
@ -0,0 +1,23 @@
set_n_least_used_CUDA_VISIBLE_DEVICES() {
    local n=${1:-"9999"}
    echo "GPU Memory Usage:"
    local FIRST_N_GPU_IDS=$(nvidia-smi --query-gpu=memory.used --format=csv \
        | tail -n +2 \
        | nl -v 0 \
        | tee /dev/tty \
        | sort -g -k 2 \
        | awk '{print $1}' \
        | head -n $n)
    export CUDA_VISIBLE_DEVICES=$(echo $FIRST_N_GPU_IDS | sed 's/ /,/g')
    echo "Now CUDA_VISIBLE_DEVICES is set to:"
    echo "CUDA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES"
}

set_n_least_used_CUDA_VISIBLE_DEVICES 2

export RAY_NAMESPACE="admin"

python 2m2t.py "path/to/prompts.csv" \
    --maker_strategy naive --trainer_strategy colossalai_zero2 --lora_rank 2 \
    --num_episodes 10 --max_timesteps 10 --update_timesteps 10 \
    --max_epochs 10 --debug
@ -0,0 +1,88 @@
import torch
import random
from typing import List, Any
# from torch.multiprocessing import Queue
from ray.util.queue import Queue
import ray
import asyncio
from coati.experience_maker.base import Experience
from coati.replay_buffer.utils import BufferItem, make_experience_batch, split_experience_batch
from coati.replay_buffer import ReplayBuffer
from threading import Lock
import copy


class DetachedReplayBuffer:
    '''
    Detached replay buffer. Shares Experience across workers on the same node.
    Therefore a trainer node is expected to have only one instance.
    It is the ExperienceMakerHolder's duty to call the append(exp) method, remotely.

    Args:
        sample_batch_size: batch size when sampling. Experiences are not enqueued until they form a full batch.
        tp_world_size: number of workers in the same tp group.
        limit: limit on the number of experience sample BATCHES. A number <= 0 means unlimited. Defaults to 0.
        cpu_offload: whether to offload experience to cpu when sampling. Defaults to True.
    '''

    def __init__(self, sample_batch_size: int, tp_world_size: int = 1, limit: int = 0, cpu_offload: bool = True) -> None:
        self.cpu_offload = cpu_offload
        self.sample_batch_size = sample_batch_size
        self.limit = limit
        self.items = Queue(self.limit, actor_options={"num_cpus": 1})
        self.batch_collector: List[BufferItem] = []

        '''
        Workers in the same tp group share this buffer and need the same sample for each step.
        Therefore a held_sample should be returned tp_world_size times before it can be dropped.
        worker_state records whether a given worker has already received the held_sample.
        '''
        self.tp_world_size = tp_world_size
        self.worker_state = [False] * self.tp_world_size
        self.held_sample = None
        self._worker_state_lock = Lock()

    @torch.no_grad()
    def append(self, experience: Experience) -> None:
        '''
        Expected to be called remotely.
        '''
        if self.cpu_offload:
            experience.to_device(torch.device('cpu'))
        items = split_experience_batch(experience)
        self.batch_collector.extend(items)
        while len(self.batch_collector) >= self.sample_batch_size:
            items = self.batch_collector[:self.sample_batch_size]
            experience = make_experience_batch(items)
            self.items.put(experience, block=True)
            self.batch_collector = self.batch_collector[self.sample_batch_size:]

    def clear(self) -> None:
        # self.items.close()
        self.items.shutdown()
        self.items = Queue(self.limit)
        self.worker_state = [False] * self.tp_world_size
        self.batch_collector = []

    @torch.no_grad()
    def sample(self, worker_rank=0, to_device="cpu") -> Experience:
        self._worker_state_lock.acquire()
        if not any(self.worker_state):
            self.held_sample = self._sample_and_erase()
        self.worker_state[worker_rank] = True
        if all(self.worker_state):
            self.worker_state = [False] * self.tp_world_size
            ret = self.held_sample
        else:
            ret = copy.deepcopy(self.held_sample)
        self._worker_state_lock.release()
        ret.to_device(to_device)
        return ret

    @torch.no_grad()
    def _sample_and_erase(self) -> Experience:
        ret = self.items.get(block=True)
        return ret

    def get_length(self) -> int:
        ret = self.items.qsize()
        return ret
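The tp_world_size bookkeeping above is easiest to see in isolation: a new batch is pulled from the queue only when no rank currently holds one, and each rank must fetch the held batch exactly once before the next pull. The following standalone sketch uses a simplified toy buffer (plain dicts and queue.Queue instead of coati's Experience and ray.util.queue.Queue) purely to illustrate that contract; it is not the PR's class.

# Simplified, self-contained illustration of the shared-sample protocol (hypothetical toy class).
from threading import Lock
from queue import Queue
import copy

class TinySharedBuffer:

    def __init__(self, tp_world_size=2):
        self.items = Queue()
        self.tp_world_size = tp_world_size
        self.worker_state = [False] * tp_world_size
        self.held_sample = None
        self._lock = Lock()

    def append(self, sample):
        self.items.put(sample)

    def sample(self, worker_rank=0):
        with self._lock:
            if not any(self.worker_state):
                self.held_sample = self.items.get()        # pull a new batch only when nobody holds one
            self.worker_state[worker_rank] = True
            if all(self.worker_state):
                self.worker_state = [False] * self.tp_world_size
                return self.held_sample                    # last rank of the tp group gets the original
            return copy.deepcopy(self.held_sample)         # earlier ranks get copies of the same batch

buf = TinySharedBuffer(tp_world_size=2)
buf.append({"step": 0})
assert buf.sample(worker_rank=0) == buf.sample(worker_rank=1)    # both ranks see the same batch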
@ -0,0 +1,121 @@
from abc import ABC, abstractmethod
from typing import Any, Callable, Dict, List, Optional, Union
from tqdm import tqdm
from coati.trainer.callbacks import Callback
from coati.experience_maker import Experience
import ray
import os

from .detached_replay_buffer import DetachedReplayBuffer
from .utils import is_rank_0


class DetachedTrainer(ABC):
    '''
    Base class for detached RLHF trainers.
    'Detached' means that the experience maker is detached, compared to a normal Trainer.
    Please set the name attribute during init:
        >>> trainer = DetachedTrainer.options(..., name="xxx", ...).remote()
    so that an ExperienceMakerHolder can reach the detached_replay_buffer by the actor's name.

    Args:
        detached_strategy (DetachedStrategy): the strategy to use for training
        detached_replay_buffer_ref (ObjectRef[DetachedReplayBuffer]): the replay buffer to use for training
        experience_batch_size (int, defaults to 8): the batch size to use for experience generation
        max_epochs (int, defaults to 1): the number of epochs of the training process
        data_loader_pin_memory (bool, defaults to True): whether to pin memory for the data loader
        callbacks (List[Callback], defaults to []): the callbacks to call during the training process
        generate_kwargs (dict, optional): the kwargs to use while the model is generating
    '''

    def __init__(self,
                 experience_maker_holder_name_list: List[str],
                 train_batch_size: int = 8,
                 buffer_limit: int = 0,
                 buffer_cpu_offload: bool = True,
                 experience_batch_size: int = 8,
                 max_epochs: int = 1,
                 dataloader_pin_memory: bool = True,
                 callbacks: List[Callback] = [],
                 **generate_kwargs) -> None:
        super().__init__()
        self.detached_replay_buffer = DetachedReplayBuffer(train_batch_size, limit=buffer_limit, cpu_offload=buffer_cpu_offload)
        self.experience_batch_size = experience_batch_size
        self.max_epochs = max_epochs
        self.dataloader_pin_memory = dataloader_pin_memory
        self.callbacks = callbacks
        self.generate_kwargs = generate_kwargs
        self.target_holder_name_list = experience_maker_holder_name_list
        self.target_holder_list = []

    def update_target_holder_list(self, experience_maker_holder_name_list):
        self.target_holder_name_list = experience_maker_holder_name_list
        self.target_holder_list = []
        for name in self.target_holder_name_list:
            self.target_holder_list.append(ray.get_actor(name, namespace=os.environ["RAY_NAMESPACE"]))

    @abstractmethod
    def _update_remote_makers(self):
        pass

    @abstractmethod
    def training_step(self, experience: Experience) -> Dict[str, Any]:
        pass

    def _learn(self):
        pbar = tqdm(range(self.max_epochs), desc='Train epoch', disable=not is_rank_0())
        for _ in pbar:
            if self.generate_kwargs.get('debug', False):
                print("[trainer] sampling exp")
            experience = self._buffer_sample()
            if self.generate_kwargs.get('debug', False):
                print("[trainer] training step")
            metrics = self.training_step(experience)
            if self.generate_kwargs.get('debug', False):
                print("[trainer] step over")
            pbar.set_postfix(metrics)

    def fit(self, num_episodes: int = 50000, max_timesteps: int = 500, update_timesteps: int = 5000) -> None:
        self._on_fit_start()
        for episode in range(num_episodes):
            self._on_episode_start(episode)
            for timestep in tqdm(range(max_timesteps // update_timesteps),
                                 desc=f'Episode [{episode+1}/{num_episodes}]',
                                 disable=not is_rank_0()):
                self._learn()
                self._update_remote_makers()
            self._on_episode_end(episode)
        self._on_fit_end()

    @ray.method(concurrency_group="buffer_length")
    def buffer_get_length(self):
        # called by ExperienceMakerHolder
        if self.generate_kwargs.get('debug', False):
            print("[trainer] telling length")
        return self.detached_replay_buffer.get_length()

    @ray.method(concurrency_group="buffer_append")
    def buffer_append(self, experience: Experience):
        # called by ExperienceMakerHolder
        if self.generate_kwargs.get('debug', False):
            # print(f"[trainer] receiving exp. Current buffer length: {self.detached_replay_buffer.get_length()}")
            print("[trainer] receiving exp.")
        self.detached_replay_buffer.append(experience)

    @ray.method(concurrency_group="buffer_sample")
    def _buffer_sample(self):
        return self.detached_replay_buffer.sample()

    def _on_fit_start(self) -> None:
        for callback in self.callbacks:
            callback.on_fit_start()

    def _on_fit_end(self) -> None:
        for callback in self.callbacks:
            callback.on_fit_end()

    def _on_episode_start(self, episode: int) -> None:
        for callback in self.callbacks:
            callback.on_episode_start(episode)

    def _on_episode_end(self, episode: int) -> None:
        for callback in self.callbacks:
            callback.on_episode_end(episode)
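update_target_holder_list() resolves experience-maker holders purely by their Ray actor names, which is why the launch scripts pass name=... (and sometimes namespace=...) in .options() and export RAY_NAMESPACE. The sketch below is a minimal, self-contained illustration of that named-actor pattern on a local Ray cluster; the Holder class is a toy stand-in, not a coati class.

import ray

ray.init(namespace="admin")    # same idea as `export RAY_NAMESPACE="admin"` in the launch scripts

@ray.remote
class Holder:
    def ping(self):
        return "pong"

holder = Holder.options(name="maker1").remote()              # register the actor under a name
same_holder = ray.get_actor("maker1", namespace="admin")     # later, resolve the handle again purely by name
print(ray.get(same_holder.ping.remote()))                    # -> "pong"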
@ -0,0 +1,192 @@
|
||||||
|
from typing import Any, Callable, Dict, List, Optional
|
||||||
|
import torch
|
||||||
|
from torch.optim import Adam
|
||||||
|
|
||||||
|
from coati.experience_maker import Experience, NaiveExperienceMaker
|
||||||
|
from coati.models.base import Actor, Critic
|
||||||
|
from coati.models.generation_utils import update_model_kwargs_fn
|
||||||
|
from coati.models.loss import PolicyLoss, ValueLoss
|
||||||
|
from coati.trainer.strategies import ColossalAIStrategy, DDPStrategy, NaiveStrategy, Strategy
|
||||||
|
from coati.trainer.callbacks import Callback
|
||||||
|
|
||||||
|
from colossalai.nn.optimizer import HybridAdam
|
||||||
|
|
||||||
|
import ray
|
||||||
|
|
||||||
|
|
||||||
|
from .utils import is_rank_0, get_cuda_actor_critic_from_args, get_strategy_from_args, set_dist_env
|
||||||
|
from .detached_trainer_base import DetachedTrainer
|
||||||
|
|
||||||
|
|
||||||
|
@ray.remote(concurrency_groups={"buffer_length": 1, "buffer_append":1, "buffer_sample":1,"model_io": 1, "compute": 1})
|
||||||
|
class DetachedPPOTrainer(DetachedTrainer):
|
||||||
|
'''
|
||||||
|
Detached Trainer for PPO algorithm
|
||||||
|
Args:
|
||||||
|
strategy (Strategy): the strategy to use for training
|
||||||
|
model (str) : for actor / critic init
|
||||||
|
pretrained (str) : for actor / critic init
|
||||||
|
lora_rank (int) : for actor / critic init
|
||||||
|
train_batch_size (int, defaults to 8): the batch size to use for training
|
||||||
|
train_batch_size (int, defaults to 8): the batch size to use for training
|
||||||
|
buffer_limit (int, defaults to 0): the max_size limitaiton of replay buffer
|
||||||
|
buffer_cpu_offload (bool, defaults to True): whether to offload replay buffer to cpu
|
||||||
|
eps_clip (float, defaults to 0.2): the clip coefficient of policy loss
|
||||||
|
value_clip (float, defaults to 0.4): the clip coefficient of value loss
|
||||||
|
experience_batch_size (int, defaults to 8): the batch size to use for experience generation
|
||||||
|
max_epochs (int, defaults to 1): the number of epochs of training process
|
||||||
|
dataloader_pin_memory (bool, defaults to True): whether to pin memory for data loader
|
||||||
|
callbacks (List[Callback], defaults to []): the callbacks to call during training process
|
||||||
|
generate_kwargs (dict, optional): the kwargs to use while model generating
|
||||||
|
'''
|
||||||
|
|
||||||
|
def __init__(self,
|
||||||
|
experience_maker_holder_name_list: List[str],
|
||||||
|
strategy: str,
|
||||||
|
model: str,
|
||||||
|
env_info: Dict[str, str] = None,
|
||||||
|
pretrained: str = None,
|
||||||
|
lora_rank: int = 0,
|
||||||
|
train_batch_size: int = 8,
|
||||||
|
buffer_limit: int = 0,
|
||||||
|
buffer_cpu_offload: bool = True,
|
||||||
|
eps_clip: float = 0.2,
|
||||||
|
value_clip: float = 0.4,
|
||||||
|
experience_batch_size: int = 8,
|
||||||
|
max_epochs: int = 10,
|
||||||
|
dataloader_pin_memory: bool = True,
|
||||||
|
callbacks: List[Callback] = [],
|
||||||
|
**generate_kwargs) -> None:
|
||||||
|
# set environment variables
|
||||||
|
if env_info:
|
||||||
|
set_dist_env(env_info=env_info)
|
||||||
|
# configure strategy
|
||||||
|
self.strategy = get_strategy_from_args(strategy)
|
||||||
|
# configure models, loss and optimizers
|
||||||
|
with self.strategy.model_init_context():
|
||||||
|
self.actor, self.critic = get_cuda_actor_critic_from_args(model, pretrained, lora_rank)
|
||||||
|
|
||||||
|
if strategy != 'colossalai_gemini':
|
||||||
|
self.actor.to(torch.float16).to(torch.cuda.current_device())
|
||||||
|
self.critic.to(torch.float16).to(torch.cuda.current_device())
|
||||||
|
|
||||||
|
if strategy.startswith('colossalai'):
|
||||||
|
self.actor_optim = HybridAdam(self.actor.parameters(), lr=5e-6)
|
||||||
|
self.critic_optim = HybridAdam(self.critic.parameters(), lr=5e-6)
|
||||||
|
else:
|
||||||
|
self.actor_optim = Adam(self.actor.parameters(), lr=5e-6)
|
||||||
|
self.critic_optim = Adam(self.critic.parameters(), lr=5e-6)
|
||||||
|
|
||||||
|
(self.actor, self.actor_optim), (self.critic, self.critic_optim) = \
|
||||||
|
self.strategy.prepare((self.actor, self.actor_optim), (self.critic, self.critic_optim))
|
||||||
|
generate_kwargs = _set_default_generate_kwargs(self.strategy, generate_kwargs, self.actor)
|
||||||
|
|
||||||
|
self.actor_loss_fn = PolicyLoss(eps_clip)
|
||||||
|
self.critic_loss_fn = ValueLoss(value_clip)
|
||||||
|
|
||||||
|
super().__init__(experience_maker_holder_name_list,
|
||||||
|
train_batch_size=train_batch_size,
|
||||||
|
buffer_limit=buffer_limit,
|
||||||
|
buffer_cpu_offload=buffer_cpu_offload,
|
||||||
|
experience_batch_size=experience_batch_size,
|
||||||
|
max_epochs=max_epochs,
|
||||||
|
dataloader_pin_memory=dataloader_pin_memory,
|
||||||
|
callbacks=callbacks,
|
||||||
|
**generate_kwargs)
|
||||||
|
|
||||||
|
@ray.method(concurrency_group="model_io")
|
||||||
|
def _update_remote_makers(self):
|
||||||
|
# TODO: balance duties
|
||||||
|
if is_rank_0():
|
||||||
|
self.update_target_holder_list(self.target_holder_name_list)
|
||||||
|
for target_holder in self.target_holder_list:
|
||||||
|
# TODO: reduce malloc
|
||||||
|
with torch.no_grad():
|
||||||
|
ray.get(target_holder.update_experience_maker.remote(self._get_unwrapped_actor(), self._get_unwrapped_critic()))
|
||||||
|
|
||||||
|
@ray.method(concurrency_group="model_io")
|
||||||
|
def initialize_remote_makers(self):
|
||||||
|
# TODO: balance duties
|
||||||
|
if is_rank_0():
|
||||||
|
self.update_target_holder_list(self.target_holder_name_list)
|
||||||
|
for target_holder in self.target_holder_list:
|
||||||
|
# TODO: reduce malloc
|
||||||
|
with torch.no_grad():
|
||||||
|
ray.get(target_holder.initialize_experience_maker.remote(self._get_unwrapped_actor(), self._get_unwrapped_critic()))
|
||||||
|
|
||||||
|
    @ray.method(concurrency_group="compute")
    def training_step(self, experience: Experience) -> Dict[str, float]:
        self.actor.train()
        self.critic.train()

        experience.to_device(torch.cuda.current_device())
        num_actions = experience.action_mask.size(1)

        # policy update: clipped PPO surrogate over the generated tokens
        action_log_probs = self.actor(experience.sequences, num_actions, attention_mask=experience.attention_mask)
        actor_loss = self.actor_loss_fn(action_log_probs,
                                        experience.action_log_probs,
                                        experience.advantages,
                                        action_mask=experience.action_mask)
        self.strategy.backward(actor_loss, self.actor, self.actor_optim)
        self.strategy.optimizer_step(self.actor_optim)
        self.actor_optim.zero_grad()

        # value update: regress the critic towards the recorded rewards
        values = self.critic(experience.sequences,
                             action_mask=experience.action_mask,
                             attention_mask=experience.attention_mask)
        critic_loss = self.critic_loss_fn(values,
                                          experience.values,
                                          experience.reward,
                                          action_mask=experience.action_mask)
        self.strategy.backward(critic_loss, self.critic, self.critic_optim)
        self.strategy.optimizer_step(self.critic_optim)
        self.critic_optim.zero_grad()

        return {'actor_loss': actor_loss.item(), 'critic_loss': critic_loss.item()}

    def strategy_save_actor(self, path: str, only_rank0: bool = False) -> None:
        self.strategy.save_model(self.actor, path, only_rank0)

    def strategy_save_critic(self, path: str, only_rank0: bool = False) -> None:
        self.strategy.save_model(self.critic, path, only_rank0)

    def strategy_save_actor_optim(self, path: str, only_rank0: bool = False) -> None:
        self.strategy.save_optimizer(self.actor_optim, path, only_rank0)

    def strategy_save_critic_optim(self, path: str, only_rank0: bool = False) -> None:
        self.strategy.save_optimizer(self.critic_optim, path, only_rank0)

    def _get_unwrapped_actor(self):
        # return the bare (non-DDP, non-sharded) actor so it can be shipped to the experience maker
        if isinstance(self.strategy, ColossalAIStrategy):
            return Actor(self.strategy._unwrap_model(self.actor))
        elif isinstance(self.strategy, DDPStrategy):
            return Actor(self.strategy._unwrap_actor(self.actor))
        elif isinstance(self.strategy, NaiveStrategy):
            return self.actor

    def _get_unwrapped_critic(self):
        if isinstance(self.strategy, ColossalAIStrategy):
            return self.strategy._unwrap_model(self.critic)
        elif isinstance(self.strategy, DDPStrategy):
            return self.critic.module
        elif isinstance(self.strategy, NaiveStrategy):
            return self.critic


def _set_default_generate_kwargs(strategy: Strategy, generate_kwargs: dict, actor: Actor) -> dict:
    origin_model = strategy._unwrap_actor(actor)
    new_kwargs = {**generate_kwargs}
    # use the huggingface model's methods directly when available
    if 'prepare_inputs_fn' not in generate_kwargs and hasattr(origin_model, 'prepare_inputs_for_generation'):
        new_kwargs['prepare_inputs_fn'] = origin_model.prepare_inputs_for_generation

    if 'update_model_kwargs_fn' not in generate_kwargs:
        new_kwargs['update_model_kwargs_fn'] = update_model_kwargs_fn

    return new_kwargs
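training_step above delegates the actual loss computation to actor_loss_fn and critic_loss_fn, which are defined elsewhere in coati. As a reference, here is a minimal sketch of the standard clipped PPO surrogate and clipped value loss such callables typically implement; the function names and clip ranges are illustrative assumptions, not the project's exact code.

import torch

def clipped_ppo_actor_loss(log_probs, old_log_probs, advantages, action_mask=None, clip_eps=0.2):
    # probability ratio between the new and the old policy, per generated token
    ratio = (log_probs - old_log_probs).exp()
    surr1 = ratio * advantages
    surr2 = ratio.clamp(1 - clip_eps, 1 + clip_eps) * advantages
    loss = -torch.min(surr1, surr2)
    if action_mask is not None:
        mask = action_mask.float()
        return (loss * mask).sum() / mask.sum()   # average only over generated tokens
    return loss.mean()

def clipped_value_loss(values, old_values, returns, clip_eps=0.4):
    # clip the value update around the old estimate and keep the worse of the two errors
    values_clipped = old_values + (values - old_values).clamp(-clip_eps, clip_eps)
    surr1 = (values_clipped - returns) ** 2
    surr2 = (values - returns) ** 2
    return torch.max(surr1, surr2).mean()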
@@ -0,0 +1,172 @@
import torch
from typing import Any, Callable, Dict, List, Optional, Union
import ray
from ray.exceptions import GetTimeoutError
from torch import Tensor
import torch.nn as nn
from coati.models.base import Actor, Critic, RewardModel
from coati.trainer.strategies.sampler import DistributedSampler
from coati.trainer.strategies import Strategy
from coati.experience_maker import NaiveExperienceMaker, Experience, ExperienceMaker

from copy import deepcopy
from threading import Lock
import time
import os

from .utils import is_rank_0, get_strategy_from_args, set_dist_env


@ray.remote(concurrency_groups={"experience_io": 1, "model_io": 1, "compute": 1})
class ExperienceMakerHolder:
    '''
    Args:
        detached_trainer_name_list: list of trainer names used to get Ray actor handles
        strategy: name of the strategy used to run the local experience-making models
        experience_batch_size: batch size of generated experience
        kl_coef: the coefficient of the KL-divergence penalty
    '''

    def __init__(self,
                 detached_trainer_name_list: List[str],
                 strategy: str,
                 env_info: Optional[Dict[str, str]] = None,
                 experience_batch_size: int = 8,
                 kl_coef: float = 0.1,
                 **generate_kwargs):
        # set environment variables before the strategy initializes torch.distributed
        if env_info:
            set_dist_env(env_info=env_info)
        self.target_trainer_list = []
        for name in detached_trainer_name_list:
            self.target_trainer_list.append(ray.get_actor(name, namespace=os.environ["RAY_NAMESPACE"]))
        self.strategy_str = strategy
        self.strategy = get_strategy_from_args(strategy)
        self.experience_batch_size = experience_batch_size
        self.kl_coef = kl_coef
        self.generate_kwargs = generate_kwargs
        # the models are filled in later by a trainer via initialize_experience_maker(...)
        actor, critic, reward_model, initial_model = None, None, None, None
        self.experience_maker = NaiveExperienceMaker(actor, critic, reward_model, initial_model, self.kl_coef)
        self._model_visit_lock = Lock()
        self.fully_initialized = False
        if self.generate_kwargs.get('debug', False):
            print('[maker] Waiting for INIT')

    def _get_ready(self):
        while not self.fully_initialized:
            time.sleep(1.0)

    def update_target_trainer_list(self, detached_trainer_name_list):
        self.target_trainer_list = []
        for name in detached_trainer_name_list:
            self.target_trainer_list.append(ray.get_actor(name))

    # copied from ../trainer/base.py
    @ray.method(concurrency_group="compute")
    def _make_experience(self, inputs: Union[Tensor, Dict[str, Tensor]]) -> Experience:
        self._get_ready()
        if isinstance(inputs, Tensor):
            return self.experience_maker.make_experience(inputs, **self.generate_kwargs)
        elif isinstance(inputs, dict):
            return self.experience_maker.make_experience(**inputs, **self.generate_kwargs)
        else:
            raise ValueError(f'Unsupported input type "{type(inputs)}"')

    @ray.method(concurrency_group="experience_io")
    def _send_experience(self, experience):
        # Disabled alternative, kept for reference: pick the trainer whose detached
        # replay buffer currently holds the fewest experience batches.
        #
        #     chosen_trainer = None
        #     min_length = None
        #     if self.generate_kwargs.get('debug', False):
        #         print("[maker] choosing target trainer")
        #     while chosen_trainer is None:
        #         for target_trainer in self.target_trainer_list:
        #             try:
        #                 temp_length = ray.get(target_trainer.buffer_get_length.remote(), timeout=0.1)
        #                 if min_length is None or temp_length < min_length:
        #                     min_length = temp_length
        #                     chosen_trainer = target_trainer
        #             except GetTimeoutError:
        #                 pass
        #     if self.generate_kwargs.get('debug', False):
        #         print(f"[maker] sending exp to {chosen_trainer}")
        #     chosen_trainer.buffer_append.remote(experience)

        # current behaviour: send experience to the target trainers in round-robin order
        if not hasattr(self, "_target_idx"):
            self._target_idx = 0
        chosen_trainer = self.target_trainer_list[self._target_idx]
        if self.generate_kwargs.get('debug', False):
            print(f"[maker] sending exp to {chosen_trainer}")
        chosen_trainer.buffer_append.remote(experience)
        self._target_idx = (self._target_idx + 1) % len(self.target_trainer_list)

    def workingloop(self, dataset, tokenizer: Optional[Callable[[Any], dict]] = None, times=5000 * 50000):
        self._get_ready()
        sampler = self.strategy.setup_sampler(dataset)
        for _ in range(times):
            rand_prompts = sampler.sample(self.experience_batch_size)
            if tokenizer is not None:
                inputs = tokenizer(rand_prompts)
            else:
                inputs = rand_prompts
            # hold the model lock so a concurrent weight update from the trainer
            # cannot race with experience generation
            with self._model_visit_lock:
                experience = self._make_experience(inputs=inputs)
            self._send_experience(experience=experience)

    @ray.method(concurrency_group="model_io")
    def initialize_experience_maker(self, init_actor: Actor, init_critic: Critic):
        '''
        Called by a trainer, exactly once.
        '''
        # TODO: reduce malloc
        if self.fully_initialized:
            return
        if self.generate_kwargs.get('debug', False):
            print('[maker] INIT')
        with torch.no_grad():
            with self.strategy.model_init_context():
                actor = init_actor
                critic = init_critic
                initial_model = deepcopy(actor)
                reward_model = RewardModel(deepcopy(critic.model),
                                           deepcopy(critic.value_head)).to(torch.cuda.current_device())
                if self.strategy_str != 'colossalai_gemini':
                    actor.to(torch.float16).to(torch.cuda.current_device())
                    critic.to(torch.float16).to(torch.cuda.current_device())
                    initial_model.to(torch.float16).to(torch.cuda.current_device())
                    reward_model.to(torch.float16).to(torch.cuda.current_device())

            self.experience_maker.actor = self.strategy.prepare(actor)
            self.experience_maker.critic = self.strategy.prepare(critic)
            self.experience_maker.initial_model = self.strategy.prepare(initial_model)
            self.experience_maker.reward_model = self.strategy.prepare(reward_model)
        self.fully_initialized = True

    @ray.method(concurrency_group="model_io")
    def update_experience_maker(self, new_actor: Actor, new_critic: Critic):
        '''
        Called by a trainer to refresh the maker's actor and critic weights.
        '''
        # TODO: reduce malloc
        with self._model_visit_lock:
            with torch.no_grad():
                if self.generate_kwargs.get('debug', False):
                    print("[maker] UPDATE")
                if self.strategy_str != 'colossalai_gemini':
                    new_actor.to(torch.float16).to(torch.cuda.current_device())
                    new_critic.to(torch.float16).to(torch.cuda.current_device())
                self.experience_maker.actor = self.strategy.prepare(new_actor)
                self.experience_maker.critic = self.strategy.prepare(new_critic)
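The @ray.remote concurrency groups above let experience generation ("compute"), experience shipping ("experience_io") and weight updates ("model_io") proceed on separate threads of the same actor. The sketch below shows one way such a holder could be launched; the actor names, the Ray namespace and the prompt_dataset/tokenizer objects are assumptions for illustration, and the DetachedPPOTrainer actors must already be registered under the listed names.

import os
import ray

# hypothetical namespace; the holder resolves trainer handles via os.environ["RAY_NAMESPACE"]
ray.init(namespace="coati-detached-ppo")
os.environ["RAY_NAMESPACE"] = "coati-detached-ppo"

maker = ExperienceMakerHolder.options(name="maker1").remote(
    detached_trainer_name_list=["trainer1"],   # names the DetachedPPOTrainer actors were registered under
    strategy="naive",
    experience_batch_size=8,
    kl_coef=0.1,
)

# Once a trainer has called initialize_experience_maker(actor, critic) on this holder,
# the loop keeps generating experience from sampled prompts and round-robins it into
# the trainers' detached replay buffers.
maker.workingloop.remote(prompt_dataset, tokenizer)   # prompt_dataset / tokenizer: user-provided, assumed here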
@@ -0,0 +1,105 @@
# WIP


from coati.trainer.strategies import Strategy
from coati.trainer.strategies import NaiveStrategy
from coati.models.base import Actor, RewardModel, Critic

import numpy as np
import torch
from torch._C._distributed_rpc import _is_current_rpc_agent_set

import colossalai
from colossalai.pipeline.pipeline_process_group import ppg
from colossalai.pipeline.rpc._pipeline_schedule import OneFOneBPipelineEngine
from colossalai.fx import ColoTracer
from colossalai.fx.passes.adding_split_node_pass import balanced_split_pass, split_with_split_nodes_pass
from colossalai.pipeline.middleware.adaptor import get_fx_topology

import os
from functools import partial
import random

rpc_is_initialized = _is_current_rpc_agent_set


class PipelineModel(torch.nn.Module):
    '''
    The Actor does two kinds of work, forward and generate, so it is simpler to
    pipeline only the inner model and leave the Actor wrapper untouched.
    '''

    def __init__(self,
                 model: torch.nn.Module,
                 stage_num: int,
                 num_microbatches: int,
                 data_kwargs=None):
        super().__init__()

        # trace the model with ColoTracer, split it into `stage_num` balanced stages
        # and return the submodule owned by this pipeline rank
        def create_partition_module(pp_rank: int, stage_num: int, model, data_kwargs):
            model.eval()
            tracer = ColoTracer()
            meta_args = {k: v.to('meta') for k, v in data_kwargs.items()}
            graph = tracer.trace(root=model, meta_args=meta_args)
            gm = torch.fx.GraphModule(model, graph, model.__class__.__name__)
            annotated_model = balanced_split_pass(gm, stage_num)
            top_module, split_submodules = split_with_split_nodes_pass(annotated_model, merge_output=True)
            topo = get_fx_topology(top_module)
            for submodule in split_submodules:
                if isinstance(submodule, torch.fx.GraphModule):
                    setattr(submodule, '_topo', topo)
            return split_submodules[pp_rank + 1]

        def partition(model, data_kwargs: dict, pp_rank: int, chunk: int, stage_num: int):
            return create_partition_module(pp_rank, stage_num, model, data_kwargs)

        self.inference_engine = OneFOneBPipelineEngine(
            partition_fn=partial(partition, model, data_kwargs),
            stage_num=stage_num,
            num_microbatches=num_microbatches,
            device='cuda',
        )

    def forward(self, **model_inputs):
        # inference only: run the pipeline schedule without a backward pass
        return self.inference_engine.forward_backward(**model_inputs, forward_only=True)


class PPStrategy(NaiveStrategy):
    """
    Strategy for pipeline inference (inference only!), driven from the master node.
    """

    def __init__(self, seed: int = 42):
        self.seed = seed
        super().__init__()

    def setup_distributed(self) -> None:
        colossalai.launch_from_torch({}, seed=self.seed)
        ppg.set_global_info(rank=int(os.environ['RANK']),
                            world_size=int(os.environ['WORLD_SIZE']),
                            dp_degree=1,
                            tp_degree=1,
                            num_worker_threads=128,
                            device="cuda")

    def model_init_context(self):
        return super().model_init_context()

    def setup_model(self, model: torch.nn.Module) -> torch.nn.Module:
        if isinstance(model, (Actor, RewardModel, Critic)):
            # WIP: stage_num, num_microbatches and data_kwargs still need to be
            # threaded through to PipelineModel here
            model.model = PipelineModel(model.model)
        return model

    def set_seed(self, seed: int) -> None:
        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)
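The OneFOneBPipelineEngine used above distributes the traced stages across RPC workers and streams micro-batches through them. As a purely conceptual, single-process illustration of that micro-batching idea (not the ColossalAI engine itself), the sketch below splits a batch into num_microbatches chunks and pushes each chunk through a list of stage modules in order.

from typing import List

import torch
import torch.nn as nn


def microbatched_forward(stages: List[nn.Module], inputs: torch.Tensor, num_microbatches: int) -> torch.Tensor:
    # forward-only pipelining: in the real engine each stage lives on a different
    # rank and stages overlap in time; here they simply run one after another
    outputs = []
    for chunk in inputs.chunk(num_microbatches, dim=0):
        hidden = chunk
        for stage in stages:
            hidden = stage(hidden)
        outputs.append(hidden)
    return torch.cat(outputs, dim=0)


# toy two-stage "model", purely for illustration
stages = [nn.Linear(16, 32), nn.Sequential(nn.ReLU(), nn.Linear(32, 4))]
out = microbatched_forward(stages, torch.randn(8, 16), num_microbatches=4)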
@@ -0,0 +1,48 @@
import torch.distributed as dist
from typing import Any, Callable, Dict, List, Optional
from coati.models.bloom import BLOOMActor, BLOOMCritic
from coati.models.gpt import GPTActor, GPTCritic
from coati.models.opt import OPTActor, OPTCritic
from coati.trainer.strategies import ColossalAIStrategy, DDPStrategy, NaiveStrategy
import torch
import os


def is_rank_0() -> bool:
    return not dist.is_initialized() or dist.get_rank() == 0


def get_cuda_actor_critic_from_args(model: str, pretrained: str = None, lora_rank=0):
    if model == 'gpt2':
        actor = GPTActor(pretrained=pretrained, lora_rank=lora_rank).to(torch.cuda.current_device())
        critic = GPTCritic(pretrained=pretrained, lora_rank=lora_rank).to(torch.cuda.current_device())
    elif model == 'bloom':
        actor = BLOOMActor(pretrained=pretrained, lora_rank=lora_rank).to(torch.cuda.current_device())
        critic = BLOOMCritic(pretrained=pretrained, lora_rank=lora_rank).to(torch.cuda.current_device())
    elif model == 'opt':
        actor = OPTActor(pretrained=pretrained, lora_rank=lora_rank).to(torch.cuda.current_device())
        critic = OPTCritic(pretrained=pretrained, lora_rank=lora_rank).to(torch.cuda.current_device())
    else:
        raise ValueError(f'Unsupported model "{model}"')
    return actor, critic


def get_strategy_from_args(strategy: str):
    if strategy == 'naive':
        strategy_ = NaiveStrategy()
    elif strategy == 'ddp':
        strategy_ = DDPStrategy()
    elif strategy == 'colossalai_gemini':
        strategy_ = ColossalAIStrategy(stage=3, placement_policy='cuda', initial_scale=2**5)
    elif strategy == 'colossalai_zero2':
        strategy_ = ColossalAIStrategy(stage=2, placement_policy='cuda')
    else:
        raise ValueError(f'Unsupported strategy "{strategy}"')
    return strategy_


def set_dist_env(env_info: Dict[str, str]):
    os.environ["RANK"] = env_info['rank']
    os.environ["LOCAL_RANK"] = env_info['local_rank']
    os.environ["WORLD_SIZE"] = env_info['world_size']
    os.environ['MASTER_PORT'] = env_info['master_port']
    os.environ['MASTER_ADDR'] = env_info['master_addr']
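A minimal sketch of how these helpers combine inside a Ray actor process; the env_info values mirror the dicts built in the launch script, and the 'gpt2'/'naive' choices are just examples.

env_info = {'local_rank': '0', 'rank': '0', 'world_size': '1',
            'master_port': '29500', 'master_addr': '127.0.0.1'}   # example values
set_dist_env(env_info)                        # export RANK/WORLD_SIZE/... before launching the strategy
strategy = get_strategy_from_args('naive')    # or 'ddp', 'colossalai_gemini', 'colossalai_zero2'
actor, critic = get_cuda_actor_critic_from_args('gpt2', pretrained=None, lora_rank=0)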
@@ -1,5 +1,14 @@
import torch.distributed as dist
from typing import Any, Callable, Dict, List, Optional
from coati.models.bloom import BLOOMActor, BLOOMCritic
from coati.models.gpt import GPTActor, GPTCritic
from coati.models.opt import OPTActor, OPTCritic
from coati.trainer.strategies import ColossalAIStrategy, DDPStrategy, NaiveStrategy
import torch
import os


def is_rank_0() -> bool:
    return not dist.is_initialized() or dist.get_rank() == 0
@@ -15,4 +15,6 @@ set_n_least_used_CUDA_VISIBLE_DEVICES() {

set_n_least_used_CUDA_VISIBLE_DEVICES 2

# torchrun --standalone --nproc_per_node=2 train_prompts.py prompts.csv --strategy colossalai_zero2
torchrun --standalone --nproc_per_node=2 train_prompts.py --prompt_path /path/to/data.json --strategy colossalai_zero2