mirror of https://github.com/hpcaitech/ColossalAI
[chat] add distributed PPO trainer (#3740)
* Detached ppo (#9) * run the base * working on dist ppo * sync * detached trainer * update detached trainer. no maker update function * facing init problem * 1 maker 1 trainer detached run. but no model update * facing cuda problem * fix save functions * verified maker update * nothing * add ignore * analyize loss issue * remove some debug codes * facing 2m1t stuck issue * 2m1t verified * do not use torchrun * working on 2m2t * working on 2m2t * initialize strategy in ray actor env * facing actor's init order issue * facing ddp model update issue (need unwarp ddp) * unwrap ddp actor * checking 1m2t stuck problem * nothing * set timeout for trainer choosing. It solves the stuck problem! * delete some debug output * rename to sync with upstream * rename to sync with upstream * coati rename * nothing * I am going to detach the replaybuffer from trainer and make it a Ray Actor. Two benefits: 1. support TP trainer. 2. asynchronized buffer operations * experience_maker_holder performs target-revolving _send_experience() instead of length comparison. * move code to ray subfolder * working on pipeline inference * apply comments * working on pipeline strategy. in progress. * remove pipeline code. clean this branch * update remote parameters by state_dict. no test * nothing * state_dict sharding transfer * merge debug branch * gemini _unwrap_model fix * simplify code * simplify code & fix LoRALinear AttributeError * critic unwrapped state_dict --------- Co-authored-by: csric <richcsr256@gmail.com> * [chat] add perfomance evaluator and fix bugs (#10) * [chat] add performance evaluator for ray * [chat] refactor debug arg * [chat] support hf config * [chat] fix generation * [chat] add 1mmt dummy example * [chat] fix gemini ckpt * split experience to send (#11) Co-authored-by: csric <richcsr256@gmail.com> * [chat] refactor trainer and maker (#12) * [chat] refactor experience maker holder * [chat] refactor model init * [chat] refactor trainer args * [chat] refactor model init * [chat] refactor trainer * [chat] refactor experience sending logic and training loop args (#13) * [chat] refactor experience send logic * [chat] refactor trainer * [chat] refactor trainer * [chat] refactor experience maker * [chat] refactor pbar * [chat] refactor example folder (#14) * [chat] support quant (#15) * [chat] add quant * [chat] add quant example * prompt example (#16) * prompt example * prompt load csv data * remove legacy try --------- Co-authored-by: csric <richcsr256@gmail.com> * [chat] add mmmt dummy example and refactor experience sending (#17) * [chat] add mmmt dummy example * [chat] refactor naive strategy * [chat] fix struck problem * [chat] fix naive strategy * [chat] optimize experience maker sending logic * [chat] refactor sending assignment * [chat] refactor performance evaluator (#18) * Prompt Example & requires_grad state_dict & sharding state_dict (#19) * prompt example * prompt load csv data * remove legacy try * maker models require_grad set to False * working on zero redundancy update * mmmt_prompt example; naive strategy requires_grad state_dict & sharding; maker model requires_no_grad. * remove legacy examples * remove legacy examples * remove replay buffer tp state. bad design --------- Co-authored-by: csric <richcsr256@gmail.com> * state_dict sending adapts to new unwrap function (#20) * prompt example * prompt load csv data * remove legacy try * maker models require_grad set to False * working on zero redundancy update * mmmt_prompt example; naive strategy requires_grad state_dict & sharding; maker model requires_no_grad. * remove legacy examples * remove legacy examples * remove replay buffer tp state. bad design * opt benchmark * better script * nothing * [chat] strategy refactor unwrap model * [chat] strategy refactor save model * [chat] add docstr * [chat] refactor trainer save model * [chat] fix strategy typing * [chat] refactor trainer save model * [chat] update readme * [chat] fix unit test * working on lora reconstruction * state_dict sending adapts to new unwrap function * remove comments --------- Co-authored-by: csric <richcsr256@gmail.com> Co-authored-by: ver217 <lhx0217@gmail.com> * [chat-ray] add readme (#21) * add readme * transparent graph * add note background --------- Co-authored-by: csric <richcsr256@gmail.com> * [chat] get images from url (#22) * Refactor/chat ray (#23) * [chat] lora add todo * [chat] remove unused pipeline strategy * [chat] refactor example structure * [chat] setup ci for ray * [chat-ray] Support LoRA trainer. LoRA weights reconstruction. (#24) * lora support prototype * lora support * 1mmt lora & remove useless code --------- Co-authored-by: csric <richcsr256@gmail.com> * [chat] fix test ci for ray * [chat] fix test ci requirements for ray * [chat] fix ray runtime env * [chat] fix ray runtime env * [chat] fix example ci docker args * [chat] add debug info in trainer * [chat] add nccl debug info * [chat] skip ray test * [doc] fix typo --------- Co-authored-by: csric <59389055+CsRic@users.noreply.github.com> Co-authored-by: csric <richcsr256@gmail.com>pull/3911/head^2
parent
41fb7236aa
commit
b5f0566363
|
@ -20,7 +20,7 @@ jobs:
|
|||
runs-on: [self-hosted, gpu]
|
||||
container:
|
||||
image: hpcaitech/pytorch-cuda:1.12.0-11.3.0
|
||||
options: --gpus all --rm -v /data/scratch/github_actions/chat:/data/scratch/github_actions/chat
|
||||
options: --gpus all --rm -v /data/scratch/github_actions/chat:/data/scratch/github_actions/chat --shm-size=10.24gb
|
||||
timeout-minutes: 30
|
||||
defaults:
|
||||
run:
|
||||
|
|
|
@ -0,0 +1,178 @@
|
|||
import argparse
|
||||
import os
|
||||
import socket
|
||||
from functools import partial
|
||||
|
||||
import ray
|
||||
import torch
|
||||
from coati.quant import llama_load_quant, low_resource_init
|
||||
from coati.ray.detached_trainer_ppo import DetachedPPOTrainer
|
||||
from coati.ray.experience_maker_holder import ExperienceMakerHolder
|
||||
from coati.ray.utils import (
|
||||
get_actor_from_args,
|
||||
get_critic_from_args,
|
||||
get_receivers_per_sender,
|
||||
get_reward_model_from_args,
|
||||
get_strategy_from_args,
|
||||
)
|
||||
from torch.utils.data import DataLoader
|
||||
from transformers import AutoConfig, AutoTokenizer
|
||||
from transformers.modeling_utils import no_init_weights
|
||||
|
||||
|
||||
def get_free_port():
|
||||
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
|
||||
s.bind(('', 0))
|
||||
return s.getsockname()[1]
|
||||
|
||||
|
||||
def get_local_ip():
|
||||
with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as s:
|
||||
s.connect(('8.8.8.8', 80))
|
||||
return s.getsockname()[0]
|
||||
|
||||
|
||||
def main(args):
|
||||
master_addr = str(get_local_ip())
|
||||
# trainer_env_info
|
||||
trainer_port = str(get_free_port())
|
||||
env_info_trainers = [{
|
||||
'local_rank': '0',
|
||||
'rank': str(rank),
|
||||
'world_size': str(args.num_trainers),
|
||||
'master_port': trainer_port,
|
||||
'master_addr': master_addr
|
||||
} for rank in range(args.num_trainers)]
|
||||
|
||||
# maker_env_info
|
||||
maker_port = str(get_free_port())
|
||||
env_info_maker = {
|
||||
'local_rank': '0',
|
||||
'rank': '0',
|
||||
'world_size': '1',
|
||||
'master_port': maker_port,
|
||||
'master_addr': master_addr
|
||||
}
|
||||
|
||||
# configure tokenizer
|
||||
tokenizer = AutoTokenizer.from_pretrained(args.pretrain)
|
||||
tokenizer.pad_token = tokenizer.eos_token
|
||||
|
||||
def model_fn():
|
||||
actor_cfg = AutoConfig.from_pretrained(args.pretrain)
|
||||
critic_cfg = AutoConfig.from_pretrained(args.critic_pretrain)
|
||||
actor = get_actor_from_args(args.model, config=actor_cfg).requires_grad_(False).half().cuda()
|
||||
critic = get_critic_from_args(args.critic_model, config=critic_cfg).requires_grad_(False).half().cuda()
|
||||
reward_model = get_reward_model_from_args(args.critic_model,
|
||||
config=critic_cfg).requires_grad_(False).half().cuda()
|
||||
if args.initial_model_quant_ckpt is not None and args.model == 'llama':
|
||||
# quantize initial model
|
||||
with low_resource_init(), no_init_weights():
|
||||
initial_model = get_actor_from_args(args.model, config=actor_cfg)
|
||||
initial_model.model = llama_load_quant(initial_model.model, args.initial_model_quant_ckpt, args.quant_bits,
|
||||
args.quant_group_size).cuda().requires_grad_(False)
|
||||
else:
|
||||
initial_model = get_actor_from_args(args.model, config=actor_cfg).requires_grad_(False).half().cuda()
|
||||
return actor, critic, reward_model, initial_model
|
||||
|
||||
# configure Experience Maker
|
||||
experience_holder_ref = ExperienceMakerHolder.options(name="maker0", num_gpus=1, max_concurrency=2).remote(
|
||||
detached_trainer_name_list=[f'trainer{i}' for i in range(args.num_trainers)],
|
||||
strategy_fn=partial(get_strategy_from_args, args.maker_strategy),
|
||||
model_fn=model_fn,
|
||||
env_info=env_info_maker,
|
||||
kl_coef=0.1,
|
||||
debug=args.debug,
|
||||
# sync_models_from_trainers=True,
|
||||
# generation kwargs:
|
||||
max_length=512,
|
||||
do_sample=True,
|
||||
temperature=1.0,
|
||||
top_k=50,
|
||||
pad_token_id=tokenizer.pad_token_id,
|
||||
eos_token_id=tokenizer.eos_token_id,
|
||||
eval_performance=True,
|
||||
use_cache=True,
|
||||
)
|
||||
|
||||
def trainer_model_fn():
|
||||
actor = get_actor_from_args(args.model, config=AutoConfig.from_pretrained(args.pretrain)).half().cuda()
|
||||
critic = get_critic_from_args(args.critic_model,
|
||||
config=AutoConfig.from_pretrained(args.critic_pretrain)).half().cuda()
|
||||
return actor, critic
|
||||
|
||||
# configure Trainer
|
||||
trainer_refs = [
|
||||
DetachedPPOTrainer.options(name=f"trainer{i}", num_gpus=1, max_concurrency=2).remote(
|
||||
experience_maker_holder_name_list=[
|
||||
f'maker{x}' for x in get_receivers_per_sender(i, args.num_trainers, 1, allow_idle_sender=True)
|
||||
],
|
||||
strategy_fn=partial(get_strategy_from_args, args.trainer_strategy),
|
||||
model_fn=trainer_model_fn,
|
||||
env_info=env_info_trainer,
|
||||
train_batch_size=args.train_batch_size,
|
||||
buffer_limit=16,
|
||||
eval_performance=True,
|
||||
debug=args.debug,
|
||||
) for i, env_info_trainer in enumerate(env_info_trainers)
|
||||
]
|
||||
|
||||
dataset_size = args.experience_batch_size * 4
|
||||
|
||||
def data_gen_fn():
|
||||
input_ids = torch.randint(tokenizer.vocab_size, (256,), device=torch.cuda.current_device())
|
||||
attn_mask = torch.ones_like(input_ids)
|
||||
return {'input_ids': input_ids, 'attention_mask': attn_mask}
|
||||
|
||||
def build_dataloader(size):
|
||||
dataset = [data_gen_fn() for _ in range(size)]
|
||||
dataloader = DataLoader(dataset, batch_size=args.experience_batch_size)
|
||||
return dataloader
|
||||
|
||||
# uncomment this function if sync_models_from_trainers is True
|
||||
# ray.get([
|
||||
# trainer_ref.sync_models_to_remote_makers.remote()
|
||||
# for trainer_ref in trainer_refs
|
||||
# ])
|
||||
|
||||
wait_tasks = []
|
||||
|
||||
wait_tasks.append(
|
||||
experience_holder_ref.workingloop.remote(partial(build_dataloader, dataset_size),
|
||||
num_steps=args.experience_steps))
|
||||
|
||||
total_steps = args.experience_batch_size * args.experience_steps // (args.num_trainers * args.train_batch_size)
|
||||
for trainer_ref in trainer_refs:
|
||||
wait_tasks.append(trainer_ref.fit.remote(total_steps, args.update_steps, args.train_epochs))
|
||||
|
||||
ray.get(wait_tasks)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('--num_trainers', type=int, default=1)
|
||||
parser.add_argument('--trainer_strategy',
|
||||
choices=[
|
||||
'naive', 'ddp', 'colossalai_gemini', 'colossalai_zero2', 'colossalai_gemini_cpu',
|
||||
'colossalai_zero2_cpu'
|
||||
],
|
||||
default='naive')
|
||||
parser.add_argument('--maker_strategy', choices=['naive'], default='naive')
|
||||
parser.add_argument('--model', default='gpt2', choices=['gpt2', 'bloom', 'opt', 'llama'])
|
||||
parser.add_argument('--critic_model', default='gpt2', choices=['gpt2', 'bloom', 'opt', 'llama'])
|
||||
parser.add_argument('--pretrain', type=str, default=None)
|
||||
parser.add_argument('--critic_pretrain', type=str, default=None)
|
||||
parser.add_argument('--experience_steps', type=int, default=4)
|
||||
parser.add_argument('--experience_batch_size', type=int, default=8)
|
||||
parser.add_argument('--train_epochs', type=int, default=1)
|
||||
parser.add_argument('--update_steps', type=int, default=2)
|
||||
parser.add_argument('--train_batch_size', type=int, default=8)
|
||||
parser.add_argument('--lora_rank', type=int, default=0, help="low-rank adaptation matrices rank")
|
||||
|
||||
parser.add_argument('--initial_model_quant_ckpt', type=str, default=None)
|
||||
parser.add_argument('--quant_bits', type=int, default=4)
|
||||
parser.add_argument('--quant_group_size', type=int, default=128)
|
||||
parser.add_argument('--debug', action='store_true')
|
||||
args = parser.parse_args()
|
||||
ray.init(namespace=os.environ["RAY_NAMESPACE"], runtime_env={"env_vars": dict(os.environ)})
|
||||
main(args)
|
|
@ -0,0 +1,189 @@
|
|||
import argparse
|
||||
import os
|
||||
import socket
|
||||
from functools import partial
|
||||
|
||||
import ray
|
||||
import torch
|
||||
from coati.quant import llama_load_quant, low_resource_init
|
||||
from coati.ray.detached_trainer_ppo import DetachedPPOTrainer
|
||||
from coati.ray.experience_maker_holder import ExperienceMakerHolder
|
||||
from coati.ray.utils import (
|
||||
get_actor_from_args,
|
||||
get_critic_from_args,
|
||||
get_receivers_per_sender,
|
||||
get_reward_model_from_args,
|
||||
get_strategy_from_args,
|
||||
)
|
||||
from torch.utils.data import DataLoader
|
||||
from transformers import AutoConfig, AutoTokenizer
|
||||
from transformers.modeling_utils import no_init_weights
|
||||
|
||||
|
||||
def get_free_port():
|
||||
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
|
||||
s.bind(('', 0))
|
||||
return s.getsockname()[1]
|
||||
|
||||
|
||||
def get_local_ip():
|
||||
with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as s:
|
||||
s.connect(('8.8.8.8', 80))
|
||||
return s.getsockname()[0]
|
||||
|
||||
|
||||
def main(args):
|
||||
master_addr = str(get_local_ip())
|
||||
# trainer_env_info
|
||||
trainer_port = str(get_free_port())
|
||||
env_info_trainers = [{
|
||||
'local_rank': '0',
|
||||
'rank': str(rank),
|
||||
'world_size': str(args.num_trainers),
|
||||
'master_port': trainer_port,
|
||||
'master_addr': master_addr
|
||||
} for rank in range(args.num_trainers)]
|
||||
|
||||
# maker_env_info
|
||||
maker_port = str(get_free_port())
|
||||
env_info_makers = [{
|
||||
'local_rank': '0',
|
||||
'rank': str(rank),
|
||||
'world_size': str(args.num_makers),
|
||||
'master_port': maker_port,
|
||||
'master_addr': master_addr
|
||||
} for rank in range(args.num_makers)]
|
||||
|
||||
# configure tokenizer
|
||||
tokenizer = AutoTokenizer.from_pretrained(args.pretrain)
|
||||
tokenizer.pad_token = tokenizer.eos_token
|
||||
|
||||
def model_fn():
|
||||
actor_cfg = AutoConfig.from_pretrained(args.pretrain)
|
||||
critic_cfg = AutoConfig.from_pretrained(args.critic_pretrain)
|
||||
actor = get_actor_from_args(args.model, config=actor_cfg).requires_grad_(False).half().cuda()
|
||||
critic = get_critic_from_args(args.critic_model, config=critic_cfg).requires_grad_(False).half().cuda()
|
||||
reward_model = get_reward_model_from_args(args.critic_model,
|
||||
config=critic_cfg).requires_grad_(False).half().cuda()
|
||||
if args.initial_model_quant_ckpt is not None and args.model == 'llama':
|
||||
# quantize initial model
|
||||
with low_resource_init(), no_init_weights():
|
||||
initial_model = get_actor_from_args(args.model, config=actor_cfg)
|
||||
initial_model.model = llama_load_quant(initial_model.model, args.initial_model_quant_ckpt, args.quant_bits,
|
||||
args.quant_group_size).cuda().requires_grad_(False)
|
||||
else:
|
||||
initial_model = get_actor_from_args(args.model, config=actor_cfg).requires_grad_(False).half().cuda()
|
||||
return actor, critic, reward_model, initial_model
|
||||
|
||||
# configure Experience Maker
|
||||
experience_holder_refs = [
|
||||
ExperienceMakerHolder.options(name=f"maker{i}", num_gpus=1, max_concurrency=2).remote(
|
||||
detached_trainer_name_list=[
|
||||
f'trainer{x}'
|
||||
for x in get_receivers_per_sender(i, args.num_makers, args.num_trainers, allow_idle_sender=False)
|
||||
],
|
||||
strategy_fn=partial(get_strategy_from_args, args.maker_strategy),
|
||||
model_fn=model_fn,
|
||||
env_info=env_info_maker,
|
||||
kl_coef=0.1,
|
||||
debug=args.debug,
|
||||
# sync_models_from_trainers=True,
|
||||
# generation kwargs:
|
||||
max_length=512,
|
||||
do_sample=True,
|
||||
temperature=1.0,
|
||||
top_k=50,
|
||||
pad_token_id=tokenizer.pad_token_id,
|
||||
eos_token_id=tokenizer.eos_token_id,
|
||||
eval_performance=True,
|
||||
use_cache=True,
|
||||
)
|
||||
for i, env_info_maker in enumerate(env_info_makers)
|
||||
]
|
||||
|
||||
def trainer_model_fn():
|
||||
actor = get_actor_from_args(args.model, config=AutoConfig.from_pretrained(args.pretrain)).half().cuda()
|
||||
critic = get_critic_from_args(args.critic_model,
|
||||
config=AutoConfig.from_pretrained(args.critic_pretrain)).half().cuda()
|
||||
return actor, critic
|
||||
|
||||
# configure Trainer
|
||||
trainer_refs = [
|
||||
DetachedPPOTrainer.options(name=f"trainer{i}", num_gpus=1, max_concurrency=2).remote(
|
||||
experience_maker_holder_name_list=[
|
||||
f"maker{x}"
|
||||
for x in get_receivers_per_sender(i, args.num_trainers, args.num_makers, allow_idle_sender=True)
|
||||
],
|
||||
strategy_fn=partial(get_strategy_from_args, args.trainer_strategy),
|
||||
model_fn=trainer_model_fn,
|
||||
env_info=env_info_trainer,
|
||||
train_batch_size=args.train_batch_size,
|
||||
buffer_limit=16,
|
||||
eval_performance=True,
|
||||
debug=args.debug,
|
||||
)
|
||||
for i, env_info_trainer in enumerate(env_info_trainers)
|
||||
]
|
||||
|
||||
dataset_size = args.experience_batch_size * 4
|
||||
|
||||
def data_gen_fn():
|
||||
input_ids = torch.randint(tokenizer.vocab_size, (256,), device=torch.cuda.current_device())
|
||||
attn_mask = torch.ones_like(input_ids)
|
||||
return {'input_ids': input_ids, 'attention_mask': attn_mask}
|
||||
|
||||
def build_dataloader(size):
|
||||
dataset = [data_gen_fn() for _ in range(size)]
|
||||
dataloader = DataLoader(dataset, batch_size=args.experience_batch_size)
|
||||
return dataloader
|
||||
|
||||
# uncomment this function if sync_models_from_trainers is True
|
||||
# ray.get([
|
||||
# trainer_ref.sync_models_to_remote_makers.remote()
|
||||
# for trainer_ref in trainer_refs
|
||||
# ])
|
||||
|
||||
wait_tasks = []
|
||||
|
||||
for experience_holder_ref in experience_holder_refs:
|
||||
wait_tasks.append(
|
||||
experience_holder_ref.workingloop.remote(partial(build_dataloader, dataset_size),
|
||||
num_steps=args.experience_steps))
|
||||
|
||||
total_steps = args.experience_batch_size * args.experience_steps * \
|
||||
args.num_makers // (args.num_trainers * args.train_batch_size)
|
||||
for trainer_ref in trainer_refs:
|
||||
wait_tasks.append(trainer_ref.fit.remote(total_steps, args.update_steps, args.train_epochs))
|
||||
|
||||
ray.get(wait_tasks)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('--num_makers', type=int, default=1)
|
||||
parser.add_argument('--num_trainers', type=int, default=1)
|
||||
parser.add_argument('--trainer_strategy',
|
||||
choices=[
|
||||
'naive', 'ddp', 'colossalai_gemini', 'colossalai_zero2', 'colossalai_gemini_cpu',
|
||||
'colossalai_zero2_cpu'
|
||||
],
|
||||
default='naive')
|
||||
parser.add_argument('--maker_strategy', choices=['naive'], default='naive')
|
||||
parser.add_argument('--model', default='gpt2', choices=['gpt2', 'bloom', 'opt', 'llama'])
|
||||
parser.add_argument('--critic_model', default='gpt2', choices=['gpt2', 'bloom', 'opt', 'llama'])
|
||||
parser.add_argument('--pretrain', type=str, default=None)
|
||||
parser.add_argument('--critic_pretrain', type=str, default=None)
|
||||
parser.add_argument('--experience_steps', type=int, default=4)
|
||||
parser.add_argument('--experience_batch_size', type=int, default=8)
|
||||
parser.add_argument('--train_epochs', type=int, default=1)
|
||||
parser.add_argument('--update_steps', type=int, default=2)
|
||||
parser.add_argument('--train_batch_size', type=int, default=8)
|
||||
parser.add_argument('--lora_rank', type=int, default=0, help="low-rank adaptation matrices rank")
|
||||
|
||||
parser.add_argument('--initial_model_quant_ckpt', type=str, default=None)
|
||||
parser.add_argument('--quant_bits', type=int, default=4)
|
||||
parser.add_argument('--quant_group_size', type=int, default=128)
|
||||
parser.add_argument('--debug', action='store_true')
|
||||
args = parser.parse_args()
|
||||
ray.init(namespace=os.environ["RAY_NAMESPACE"], runtime_env={"env_vars": dict(os.environ)})
|
||||
main(args)
|
|
@ -61,7 +61,13 @@ class LoraLinear(lora.LoRALayer, nn.Module):
|
|||
if self.merge_weights and self.merged:
|
||||
# Make sure that the weights are not merged
|
||||
if self.r > 0:
|
||||
self.weight.data -= T(self.lora_B @ self.lora_A) * self.scaling
|
||||
if not hasattr(self, "lora_A") or not hasattr(self, "lora_B"):
|
||||
# FIXME(csric): temporary fix
|
||||
self.lora_A = nn.Parameter(self.weight.new_empty((self.r, self.in_features)))
|
||||
self.lora_B = nn.Parameter(self.weight.new_empty((self.out_features, self.r)))
|
||||
self.reset_parameters()
|
||||
else:
|
||||
self.weight.data -= T(self.lora_B @ self.lora_A) * self.scaling
|
||||
self.merged = False
|
||||
|
||||
def eval(self):
|
||||
|
|
|
@ -0,0 +1,7 @@
|
|||
from .llama_gptq import load_quant as llama_load_quant
|
||||
from .utils import low_resource_init
|
||||
|
||||
__all__ = [
|
||||
'llama_load_quant',
|
||||
'low_resource_init',
|
||||
]
|
|
@ -0,0 +1,5 @@
|
|||
from .loader import load_quant
|
||||
|
||||
__all__ = [
|
||||
'load_quant',
|
||||
]
|
|
@ -0,0 +1,26 @@
|
|||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from .model_utils import find_layers
|
||||
from .quant import make_quant
|
||||
|
||||
|
||||
def load_quant(model: nn.Module, checkpoint: str, wbits: int, groupsize: int):
|
||||
model = model.eval()
|
||||
layers = find_layers(model)
|
||||
|
||||
# ignore lm head
|
||||
layers = find_layers(model)
|
||||
for name in ['lm_head']:
|
||||
if name in layers:
|
||||
del layers[name]
|
||||
|
||||
make_quant(model, layers, wbits, groupsize)
|
||||
|
||||
if checkpoint.endswith('.safetensors'):
|
||||
from safetensors.torch import load_file as safe_load
|
||||
model.load_state_dict(safe_load(checkpoint))
|
||||
else:
|
||||
model.load_state_dict(torch.load(checkpoint))
|
||||
|
||||
return model
|
|
@ -0,0 +1,13 @@
|
|||
# copied from https://github.com/qwopqwop200/GPTQ-for-LLaMa/blob/past/modelutils.py
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
|
||||
def find_layers(module, layers=[nn.Conv2d, nn.Linear], name=''):
|
||||
if type(module) in layers:
|
||||
return {name: module}
|
||||
res = {}
|
||||
for name1, child in module.named_children():
|
||||
res.update(find_layers(child, layers=layers, name=name + '.' + name1 if name != '' else name1))
|
||||
return res
|
|
@ -0,0 +1,283 @@
|
|||
# copied from https://github.com/qwopqwop200/GPTQ-for-LLaMa/blob/past/quant.py
|
||||
|
||||
import math
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
|
||||
def quantize(x, scale, zero, maxq):
|
||||
q = torch.clamp(torch.round(x / scale) + zero, 0, maxq)
|
||||
return scale * (q - zero)
|
||||
|
||||
|
||||
class Quantizer(nn.Module):
|
||||
|
||||
def __init__(self, shape=1):
|
||||
super(Quantizer, self).__init__()
|
||||
self.register_buffer('maxq', torch.tensor(0))
|
||||
self.register_buffer('scale', torch.zeros(shape))
|
||||
self.register_buffer('zero', torch.zeros(shape))
|
||||
|
||||
def configure(self, bits, perchannel=False, sym=True, mse=False, norm=2.4, grid=100, maxshrink=.8):
|
||||
self.maxq = torch.tensor(2**bits - 1)
|
||||
self.perchannel = perchannel
|
||||
self.sym = sym
|
||||
self.mse = mse
|
||||
self.norm = norm
|
||||
self.grid = grid
|
||||
self.maxshrink = maxshrink
|
||||
|
||||
def find_params(self, x, weight=False):
|
||||
dev = x.device
|
||||
self.maxq = self.maxq.to(dev)
|
||||
|
||||
shape = x.shape
|
||||
if self.perchannel:
|
||||
if weight:
|
||||
x = x.flatten(1)
|
||||
else:
|
||||
if len(shape) == 4:
|
||||
x = x.permute([1, 0, 2, 3])
|
||||
x = x.flatten(1)
|
||||
if len(shape) == 3:
|
||||
x = x.reshape((-1, shape[-1])).t()
|
||||
if len(shape) == 2:
|
||||
x = x.t()
|
||||
else:
|
||||
x = x.flatten().unsqueeze(0)
|
||||
|
||||
tmp = torch.zeros(x.shape[0], device=dev)
|
||||
xmin = torch.minimum(x.min(1)[0], tmp)
|
||||
xmax = torch.maximum(x.max(1)[0], tmp)
|
||||
|
||||
if self.sym:
|
||||
xmax = torch.maximum(torch.abs(xmin), xmax)
|
||||
tmp = xmin < 0
|
||||
if torch.any(tmp):
|
||||
xmin[tmp] = -xmax[tmp]
|
||||
tmp = (xmin == 0) & (xmax == 0)
|
||||
xmin[tmp] = -1
|
||||
xmax[tmp] = +1
|
||||
|
||||
self.scale = (xmax - xmin) / self.maxq
|
||||
if self.sym:
|
||||
self.zero = torch.full_like(self.scale, (self.maxq + 1) / 2)
|
||||
else:
|
||||
self.zero = torch.round(-xmin / self.scale)
|
||||
|
||||
if self.mse:
|
||||
best = torch.full([x.shape[0]], float('inf'), device=dev)
|
||||
for i in range(int(self.maxshrink * self.grid)):
|
||||
p = 1 - i / self.grid
|
||||
xmin1 = p * xmin
|
||||
xmax1 = p * xmax
|
||||
scale1 = (xmax1 - xmin1) / self.maxq
|
||||
zero1 = torch.round(-xmin1 / scale1) if not self.sym else self.zero
|
||||
q = quantize(x, scale1.unsqueeze(1), zero1.unsqueeze(1), self.maxq)
|
||||
q -= x
|
||||
q.abs_()
|
||||
q.pow_(self.norm)
|
||||
err = torch.sum(q, 1)
|
||||
tmp = err < best
|
||||
if torch.any(tmp):
|
||||
best[tmp] = err[tmp]
|
||||
self.scale[tmp] = scale1[tmp]
|
||||
self.zero[tmp] = zero1[tmp]
|
||||
if not self.perchannel:
|
||||
if weight:
|
||||
tmp = shape[0]
|
||||
else:
|
||||
tmp = shape[1] if len(shape) != 3 else shape[2]
|
||||
self.scale = self.scale.repeat(tmp)
|
||||
self.zero = self.zero.repeat(tmp)
|
||||
|
||||
if weight:
|
||||
shape = [-1] + [1] * (len(shape) - 1)
|
||||
self.scale = self.scale.reshape(shape)
|
||||
self.zero = self.zero.reshape(shape)
|
||||
return
|
||||
if len(shape) == 4:
|
||||
self.scale = self.scale.reshape((1, -1, 1, 1))
|
||||
self.zero = self.zero.reshape((1, -1, 1, 1))
|
||||
if len(shape) == 3:
|
||||
self.scale = self.scale.reshape((1, 1, -1))
|
||||
self.zero = self.zero.reshape((1, 1, -1))
|
||||
if len(shape) == 2:
|
||||
self.scale = self.scale.unsqueeze(0)
|
||||
self.zero = self.zero.unsqueeze(0)
|
||||
|
||||
def quantize(self, x):
|
||||
if self.ready():
|
||||
return quantize(x, self.scale, self.zero, self.maxq)
|
||||
return x
|
||||
|
||||
def enabled(self):
|
||||
return self.maxq > 0
|
||||
|
||||
def ready(self):
|
||||
return torch.all(self.scale != 0)
|
||||
|
||||
|
||||
try:
|
||||
import quant_cuda
|
||||
except:
|
||||
print('CUDA extension not installed.')
|
||||
|
||||
# Assumes layer is perfectly divisible into 256 * 256 blocks
|
||||
|
||||
|
||||
class QuantLinear(nn.Module):
|
||||
|
||||
def __init__(self, bits, groupsize, infeatures, outfeatures):
|
||||
super().__init__()
|
||||
if bits not in [2, 3, 4, 8]:
|
||||
raise NotImplementedError("Only 2,3,4,8 bits are supported.")
|
||||
self.infeatures = infeatures
|
||||
self.outfeatures = outfeatures
|
||||
self.bits = bits
|
||||
if groupsize != -1 and groupsize < 32 and groupsize != int(math.pow(2, int(math.log2(groupsize)))):
|
||||
raise NotImplementedError("groupsize supports powers of 2 greater than 32. (e.g. : 32,64,128,etc)")
|
||||
groupsize = groupsize if groupsize != -1 else infeatures
|
||||
self.groupsize = groupsize
|
||||
self.register_buffer(
|
||||
'qzeros', torch.zeros((math.ceil(infeatures / groupsize), outfeatures // 256 * (bits * 8)),
|
||||
dtype=torch.int))
|
||||
self.register_buffer('scales', torch.zeros((math.ceil(infeatures / groupsize), outfeatures)))
|
||||
self.register_buffer('bias', torch.zeros(outfeatures))
|
||||
self.register_buffer('qweight', torch.zeros((infeatures // 256 * (bits * 8), outfeatures), dtype=torch.int))
|
||||
self._initialized_quant_state = False
|
||||
|
||||
def pack(self, linear, scales, zeros):
|
||||
scales = scales.t().contiguous()
|
||||
zeros = zeros.t().contiguous()
|
||||
scale_zeros = zeros * scales
|
||||
self.scales = scales.clone()
|
||||
if linear.bias is not None:
|
||||
self.bias = linear.bias.clone()
|
||||
|
||||
intweight = []
|
||||
for idx in range(self.infeatures):
|
||||
g_idx = idx // self.groupsize
|
||||
intweight.append(
|
||||
torch.round((linear.weight.data[:, idx] + scale_zeros[g_idx]) / self.scales[g_idx]).to(torch.int)[:,
|
||||
None])
|
||||
intweight = torch.cat(intweight, dim=1)
|
||||
intweight = intweight.t().contiguous()
|
||||
intweight = intweight.numpy().astype(np.uint32)
|
||||
qweight = np.zeros((intweight.shape[0] // 256 * (self.bits * 8), intweight.shape[1]), dtype=np.uint32)
|
||||
i = 0
|
||||
row = 0
|
||||
while row < qweight.shape[0]:
|
||||
if self.bits in [2, 4, 8]:
|
||||
for j in range(i, i + (32 // self.bits)):
|
||||
qweight[row] |= intweight[j] << (self.bits * (j - i))
|
||||
i += 32 // self.bits
|
||||
row += 1
|
||||
elif self.bits == 3:
|
||||
for j in range(i, i + 10):
|
||||
qweight[row] |= intweight[j] << (3 * (j - i))
|
||||
i += 10
|
||||
qweight[row] |= intweight[i] << 30
|
||||
row += 1
|
||||
qweight[row] |= (intweight[i] >> 2) & 1
|
||||
i += 1
|
||||
for j in range(i, i + 10):
|
||||
qweight[row] |= intweight[j] << (3 * (j - i) + 1)
|
||||
i += 10
|
||||
qweight[row] |= intweight[i] << 31
|
||||
row += 1
|
||||
qweight[row] |= (intweight[i] >> 1) & 0x3
|
||||
i += 1
|
||||
for j in range(i, i + 10):
|
||||
qweight[row] |= intweight[j] << (3 * (j - i) + 2)
|
||||
i += 10
|
||||
row += 1
|
||||
else:
|
||||
raise NotImplementedError("Only 2,3,4,8 bits are supported.")
|
||||
|
||||
qweight = qweight.astype(np.int32)
|
||||
self.qweight = torch.from_numpy(qweight)
|
||||
|
||||
zeros -= 1
|
||||
zeros = zeros.numpy().astype(np.uint32)
|
||||
qzeros = np.zeros((zeros.shape[0], zeros.shape[1] // 256 * (self.bits * 8)), dtype=np.uint32)
|
||||
i = 0
|
||||
col = 0
|
||||
while col < qzeros.shape[1]:
|
||||
if self.bits in [2, 4, 8]:
|
||||
for j in range(i, i + (32 // self.bits)):
|
||||
qzeros[:, col] |= zeros[:, j] << (self.bits * (j - i))
|
||||
i += 32 // self.bits
|
||||
col += 1
|
||||
elif self.bits == 3:
|
||||
for j in range(i, i + 10):
|
||||
qzeros[:, col] |= zeros[:, j] << (3 * (j - i))
|
||||
i += 10
|
||||
qzeros[:, col] |= zeros[:, i] << 30
|
||||
col += 1
|
||||
qzeros[:, col] |= (zeros[:, i] >> 2) & 1
|
||||
i += 1
|
||||
for j in range(i, i + 10):
|
||||
qzeros[:, col] |= zeros[:, j] << (3 * (j - i) + 1)
|
||||
i += 10
|
||||
qzeros[:, col] |= zeros[:, i] << 31
|
||||
col += 1
|
||||
qzeros[:, col] |= (zeros[:, i] >> 1) & 0x3
|
||||
i += 1
|
||||
for j in range(i, i + 10):
|
||||
qzeros[:, col] |= zeros[:, j] << (3 * (j - i) + 2)
|
||||
i += 10
|
||||
col += 1
|
||||
else:
|
||||
raise NotImplementedError("Only 2,3,4,8 bits are supported.")
|
||||
|
||||
qzeros = qzeros.astype(np.int32)
|
||||
self.qzeros = torch.from_numpy(qzeros)
|
||||
|
||||
def forward(self, x):
|
||||
intermediate_dtype = torch.float32
|
||||
|
||||
if not self._initialized_quant_state:
|
||||
# Do we even have a bias? Check for at least one non-zero element.
|
||||
if self.bias is not None and bool(torch.any(self.bias != 0)):
|
||||
# Then make sure it's the right type.
|
||||
self.bias.data = self.bias.data.to(intermediate_dtype)
|
||||
else:
|
||||
self.bias = None
|
||||
|
||||
outshape = list(x.shape)
|
||||
outshape[-1] = self.outfeatures
|
||||
x = x.reshape(-1, x.shape[-1])
|
||||
if self.bias is None:
|
||||
y = torch.zeros(x.shape[0], outshape[-1], dtype=intermediate_dtype, device=x.device)
|
||||
else:
|
||||
y = self.bias.clone().repeat(x.shape[0], 1)
|
||||
|
||||
output_dtype = x.dtype
|
||||
x = x.to(intermediate_dtype)
|
||||
if self.bits == 2:
|
||||
quant_cuda.vecquant2matmul(x, self.qweight, y, self.scales, self.qzeros, self.groupsize)
|
||||
elif self.bits == 3:
|
||||
quant_cuda.vecquant3matmul(x, self.qweight, y, self.scales, self.qzeros, self.groupsize)
|
||||
elif self.bits == 4:
|
||||
quant_cuda.vecquant4matmul(x, self.qweight, y, self.scales, self.qzeros, self.groupsize)
|
||||
elif self.bits == 8:
|
||||
quant_cuda.vecquant8matmul(x, self.qweight, y, self.scales, self.qzeros, self.groupsize)
|
||||
else:
|
||||
raise NotImplementedError("Only 2,3,4,8 bits are supported.")
|
||||
y = y.to(output_dtype)
|
||||
return y.reshape(outshape)
|
||||
|
||||
|
||||
def make_quant(module, names, bits, groupsize, name=''):
|
||||
if isinstance(module, QuantLinear):
|
||||
return
|
||||
for attr in dir(module):
|
||||
tmp = getattr(module, attr)
|
||||
name1 = name + '.' + attr if name != '' else attr
|
||||
if name1 in names:
|
||||
setattr(module, attr, QuantLinear(bits, groupsize, tmp.in_features, tmp.out_features))
|
||||
for name1, child in module.named_children():
|
||||
make_quant(child, names, bits, groupsize, name + '.' + name1 if name != '' else name1)
|
|
@ -0,0 +1,28 @@
|
|||
from contextlib import contextmanager
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
def _noop(*args, **kwargs):
|
||||
pass
|
||||
|
||||
|
||||
@contextmanager
|
||||
def low_resource_init():
|
||||
"""This context manager disables weight initialization and sets the default float dtype to half.
|
||||
"""
|
||||
old_kaiming_uniform_ = torch.nn.init.kaiming_uniform_
|
||||
old_uniform_ = torch.nn.init.uniform_
|
||||
old_normal_ = torch.nn.init.normal_
|
||||
dtype = torch.get_default_dtype()
|
||||
try:
|
||||
torch.nn.init.kaiming_uniform_ = _noop
|
||||
torch.nn.init.uniform_ = _noop
|
||||
torch.nn.init.normal_ = _noop
|
||||
torch.set_default_dtype(torch.half)
|
||||
yield
|
||||
finally:
|
||||
torch.nn.init.kaiming_uniform_ = old_kaiming_uniform_
|
||||
torch.nn.init.uniform_ = old_uniform_
|
||||
torch.nn.init.normal_ = old_normal_
|
||||
torch.set_default_dtype(dtype)
|
|
@ -0,0 +1,160 @@
|
|||
# Distributed PPO Training on Stage 3
|
||||
|
||||
## Detach Experience Makers and Trainers
|
||||
|
||||
We can completely separate the trainers and makers.
|
||||
|
||||
<p align="center">
|
||||
<img src="https://github.com/hpcaitech/public_assets/blob/main/applications/chat/basic_structure.png?raw=true" width=600/>
|
||||
</p>
|
||||
|
||||
- The experience maker performs inference, produces experience, and remotely delivers it to the trainer (1).
|
||||
- The trainer consumes experience to train models, and periodically transmits new model parameters to the maker (2.1, 2.2).
|
||||
- Using an experience buffer to overlap transmission and computing.
|
||||
|
||||
In this manner, each node will work continuously without model idle time, and different optimization strategies can be applied for inference and training to meet the needs of speed or storage. It is also helpful for scalability.
|
||||
|
||||
`DetachedPPOTrainer` and `ExperienceMakerHolder` are Ray Actors (distinguished from Actor Model), representing Trainer and Experience Maker on the graph above, respectively.
|
||||
|
||||
[More about Ray Core](https://docs.ray.io/en/latest/ray-core/walkthrough.html)
|
||||
|
||||
## Usage
|
||||
|
||||
See examples at `ColossalAI/application/Chat/examples/ray`
|
||||
|
||||
### Setup Makers
|
||||
|
||||
- define makers' environment variables :
|
||||
|
||||
```python
|
||||
env_info_makers = [{
|
||||
'local_rank': '0',
|
||||
'rank': str(rank),
|
||||
'world_size': str(num_makers),
|
||||
'master_port': maker_port,
|
||||
'master_addr': master_addr
|
||||
} for rank in range(num_makers)]
|
||||
|
||||
```
|
||||
- define maker models :
|
||||
```python
|
||||
def model_fn():
|
||||
actor = get_actor_from_args(...)
|
||||
critic = get_critic_from_args(...)
|
||||
reward_model = get_reward_model_from_args(...)
|
||||
initial_model = get_actor_from_args(...)
|
||||
return actor, critic, reward_model, initial_model
|
||||
|
||||
```
|
||||
- set experience_holder_refs :
|
||||
|
||||
```python
|
||||
experience_holder_refs = [
|
||||
ExperienceMakerHolder.options(
|
||||
name=f"maker_{i}",
|
||||
num_gpus=1,
|
||||
max_concurrency=2
|
||||
).remote(
|
||||
detached_trainer_name_list=[f"trainer_{x}" for x in target_trainers(...)],
|
||||
model_fn=model_fn,
|
||||
...)
|
||||
for i, env_info_maker in enumerate(env_info_makers)
|
||||
]
|
||||
```
|
||||
The names in the `detached_trainer_name_list` refer to the target trainers that the maker should send experience to.
|
||||
We set a trainer's name the same as a maker, by `.options(name="str")`. See below.
|
||||
|
||||
### Setup Trainers
|
||||
|
||||
- define trainers' environment variables :
|
||||
```python
|
||||
env_info_trainers = [{
|
||||
'local_rank': '0',
|
||||
'rank': str(rank),
|
||||
'world_size': str(num_trainers),
|
||||
'master_port': trainer_port,
|
||||
'master_addr': master_addr
|
||||
} for rank in range(num_trainers)]
|
||||
```
|
||||
- define trainer models :
|
||||
|
||||
```python
|
||||
def trainer_model_fn():
|
||||
actor = get_actor_from_args(...)
|
||||
critic = get_critic_from_args(...)
|
||||
return actor, critic
|
||||
```
|
||||
- set trainer_refs :
|
||||
```python
|
||||
trainer_refs = [
|
||||
DetachedPPOTrainer.options(
|
||||
name=f"trainer{i}",
|
||||
num_gpus=1,
|
||||
max_concurrency=2
|
||||
).remote(
|
||||
experience_maker_holder_name_list=[f"maker{x}" for x in target_makers(...)],
|
||||
model_fn = trainer_model_fn(),
|
||||
...)
|
||||
for i, env_info_trainer in enumerate(env_info_trainers)
|
||||
]
|
||||
```
|
||||
The names in `experience_maker_holder_name_list` refer to the target makers that the trainer should send updated models to.
|
||||
By setting `detached_trainer_name_list` and `experience_maker_holder_name_list`, we can customize the transmission graph.
|
||||
|
||||
### Launch Jobs
|
||||
- define data_loader :
|
||||
```python
|
||||
def data_loader_fn():
|
||||
return = torch.utils.data.DataLoader(dataset=dataset)
|
||||
|
||||
```
|
||||
- launch makers :
|
||||
```python
|
||||
wait_tasks = []
|
||||
for experience_holder_ref in experience_holder_refs:
|
||||
wait_tasks.append(
|
||||
experience_holder_ref.workingloop.remote(data_loader_fn(),
|
||||
num_steps=experience_steps))
|
||||
|
||||
```
|
||||
|
||||
- launch trainers :
|
||||
```python
|
||||
for trainer_ref in trainer_refs:
|
||||
wait_tasks.append(trainer_ref.fit.remote(total_steps, update_steps, train_epochs))
|
||||
```
|
||||
|
||||
- wait for done :
|
||||
```python
|
||||
ray.get(wait_tasks)
|
||||
```
|
||||
|
||||
## Flexible Structure
|
||||
|
||||
We can deploy different strategies to makers and trainers. Here are some notions.
|
||||
|
||||
### 2 Makers 1 Trainer
|
||||
<p align="center">
|
||||
<img src="https://github.com/hpcaitech/public_assets/blob/main/applications/chat/2m1t.png?raw=true" width=600/>
|
||||
</p>
|
||||
|
||||
### 2 Makers 2 Trainer
|
||||
<p align="center">
|
||||
<img src="https://github.com/hpcaitech/public_assets/blob/main/applications/chat/2m2t.png?raw=true" width=600/>
|
||||
</p>
|
||||
|
||||
### Maker Inference Quantization
|
||||
<p align="center">
|
||||
<img src="https://github.com/hpcaitech/public_assets/blob/main/applications/chat/2m2t_quantize.png?raw=true" width=600/>
|
||||
</p>
|
||||
|
||||
### Tensor Parallel
|
||||
|
||||
<p align="center">
|
||||
<img src="https://github.com/hpcaitech/public_assets/blob/main/applications/chat/tp_ddp_hybrid.png?raw=true" width=600/>
|
||||
</p>
|
||||
|
||||
## TODO
|
||||
|
||||
- [ ] Support LoRA
|
||||
- [ ] Support TP & PP
|
|
@ -1,2 +0,0 @@
|
|||
from .src.detached_replay_buffer import DetachedReplayBuffer
|
||||
from .src.detached_trainer_ppo import DetachedPPOTrainer
|
|
@ -0,0 +1,9 @@
|
|||
from .base import MakerCallback, TrainerCallback
|
||||
from .performance_evaluator import ExperienceMakerPerformanceEvaluator, TrainerPerformanceEvaluator
|
||||
|
||||
__all__ = [
|
||||
"TrainerCallback",
|
||||
"MakerCallback",
|
||||
"ExperienceMakerPerformanceEvaluator",
|
||||
"TrainerPerformanceEvaluator",
|
||||
]
|
|
@ -0,0 +1,66 @@
|
|||
from abc import ABC
|
||||
|
||||
from coati.experience_maker import Experience
|
||||
|
||||
|
||||
class TrainerCallback(ABC):
|
||||
"""
|
||||
Base callback class. It defines the interface for callbacks.
|
||||
"""
|
||||
|
||||
def on_fit_start(self) -> None:
|
||||
pass
|
||||
|
||||
def on_fit_end(self) -> None:
|
||||
pass
|
||||
|
||||
def on_episode_start(self, episode: int) -> None:
|
||||
pass
|
||||
|
||||
def on_episode_end(self, episode: int) -> None:
|
||||
pass
|
||||
|
||||
def on_epoch_start(self, epoch: int) -> None:
|
||||
pass
|
||||
|
||||
def on_epoch_end(self, epoch: int) -> None:
|
||||
pass
|
||||
|
||||
def on_batch_start(self) -> None:
|
||||
pass
|
||||
|
||||
def on_batch_end(self, metrics: dict, experience: Experience) -> None:
|
||||
pass
|
||||
|
||||
def on_update_start(self) -> None:
|
||||
pass
|
||||
|
||||
def on_update_end(self) -> None:
|
||||
pass
|
||||
|
||||
|
||||
class MakerCallback(ABC):
|
||||
|
||||
def on_loop_start(self) -> None:
|
||||
pass
|
||||
|
||||
def on_loop_end(self) -> None:
|
||||
pass
|
||||
|
||||
def on_make_experience_start(self) -> None:
|
||||
pass
|
||||
|
||||
def on_make_experience_end(self, experience: Experience) -> None:
|
||||
pass
|
||||
|
||||
def on_send_start(self) -> None:
|
||||
pass
|
||||
|
||||
def on_send_end(self) -> None:
|
||||
pass
|
||||
|
||||
def on_batch_start(self) -> None:
|
||||
pass
|
||||
|
||||
def on_batch_end(self) -> None:
|
||||
pass
|
|
@ -0,0 +1,212 @@
|
|||
from time import time
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from coati.experience_maker import Experience
|
||||
|
||||
from .base import MakerCallback, TrainerCallback
|
||||
|
||||
|
||||
def get_world_size() -> int:
|
||||
if dist.is_initialized():
|
||||
return dist.get_world_size()
|
||||
return 1
|
||||
|
||||
|
||||
def print_rank_0(*args, **kwargs) -> None:
|
||||
if not dist.is_initialized() or dist.get_rank() == 0:
|
||||
print(*args, **kwargs)
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def all_reduce_mean(x: float, world_size: int) -> float:
|
||||
if world_size == 1:
|
||||
return x
|
||||
tensor = torch.tensor([x], device=torch.cuda.current_device())
|
||||
dist.all_reduce(tensor)
|
||||
tensor = tensor / world_size
|
||||
return tensor.item()
|
||||
|
||||
|
||||
class Timer:
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.start_time: Optional[float] = None
|
||||
self.duration: float = 0.
|
||||
|
||||
def start(self) -> None:
|
||||
self.start_time = time()
|
||||
|
||||
def end(self) -> None:
|
||||
self.duration += time() - self.start_time
|
||||
|
||||
def reset(self) -> None:
|
||||
self.duration = 0.
|
||||
|
||||
|
||||
class ExperienceMakerPerformanceEvaluator(MakerCallback):
|
||||
|
||||
def __init__(self, actor_num_params: int, critic_num_params: int, initial_model_num_params: int,
|
||||
reward_model_num_params: int) -> None:
|
||||
super().__init__()
|
||||
self.world_size = get_world_size()
|
||||
self.actor_num_params = actor_num_params
|
||||
self.critic_num_params = critic_num_params
|
||||
self.initial_model_num_params = initial_model_num_params
|
||||
self.reward_model_num_params = reward_model_num_params
|
||||
|
||||
self.batch_timer = Timer()
|
||||
self.send_timer = Timer()
|
||||
self.make_experience_timer = Timer()
|
||||
self.total_samples: int = 0
|
||||
self.make_experience_flop: int = 0
|
||||
|
||||
print_rank_0(
|
||||
f'ExperienceMaker actor: {actor_num_params/1024**3:.2f}B, critic: {critic_num_params/1024**3:.2f}B, initial model: {initial_model_num_params/1024**3:.2f}B, reward model: {reward_model_num_params/1024**3:.2f}B, world size: {self.world_size}'
|
||||
)
|
||||
|
||||
def on_make_experience_start(self) -> None:
|
||||
self.make_experience_timer.start()
|
||||
|
||||
def on_make_experience_end(self, experience: Experience) -> None:
|
||||
self.make_experience_timer.end()
|
||||
|
||||
batch_size, seq_len = experience.sequences.shape
|
||||
|
||||
self.total_samples += batch_size
|
||||
|
||||
# actor generate
|
||||
num_actions = experience.action_mask.size(1)
|
||||
input_len = seq_len - num_actions
|
||||
total_seq_len = (input_len + seq_len - 1) * num_actions / 2
|
||||
self.make_experience_flop += self.actor_num_params * batch_size * total_seq_len * 2
|
||||
# actor forward
|
||||
self.make_experience_flop += self.actor_num_params * batch_size * seq_len * 2
|
||||
# critic forward
|
||||
self.make_experience_flop += self.critic_num_params * batch_size * seq_len * 2
|
||||
# initial model forward
|
||||
self.make_experience_flop += self.initial_model_num_params * batch_size * seq_len * 2
|
||||
# reward model forward
|
||||
self.make_experience_flop += self.reward_model_num_params * batch_size * seq_len * 2
|
||||
|
||||
def on_send_start(self) -> None:
|
||||
self.send_timer.start()
|
||||
|
||||
def on_send_end(self) -> None:
|
||||
self.send_timer.end()
|
||||
|
||||
def on_batch_start(self) -> None:
|
||||
self.batch_timer.start()
|
||||
|
||||
def on_batch_end(self) -> None:
|
||||
self.batch_timer.end()
|
||||
|
||||
def on_loop_end(self) -> None:
|
||||
avg_make_experience_duration = all_reduce_mean(self.make_experience_timer.duration, self.world_size)
|
||||
avg_overall_duration = all_reduce_mean(self.batch_timer.duration, self.world_size)
|
||||
avg_send_duration = all_reduce_mean(self.send_timer.duration, self.world_size)
|
||||
|
||||
avg_throughput = self.total_samples * self.world_size / (avg_overall_duration + 1e-12)
|
||||
avg_make_experience_tflops = self.make_experience_flop / 1e12 / (avg_make_experience_duration + 1e-12)
|
||||
avg_time_per_sample = (avg_overall_duration + 1e-12) / (self.total_samples * self.world_size)
|
||||
avg_make_experience_time_per_sample = (avg_make_experience_duration + 1e-12) / \
|
||||
(self.total_samples * self.world_size)
|
||||
avg_send_time_per_sample = (avg_send_duration + 1e-12) / (self.total_samples * self.world_size)
|
||||
|
||||
print_rank_0(
|
||||
'Making Experience Performance Summary:\n' + f'Throughput: {avg_throughput:.3f} samples/sec\n' +
|
||||
f'TFLOPS per GPU: {avg_make_experience_tflops:.3f}\n' +
|
||||
f'Sample time (overall): {avg_time_per_sample:.3f} s\n' +
|
||||
f'Sample time (make experience): {avg_make_experience_time_per_sample:.3f} s, {avg_make_experience_time_per_sample/avg_time_per_sample*100:.2f}%\n'
|
||||
+
|
||||
f'Sample time (send): {avg_send_time_per_sample:.3f} s, {avg_send_time_per_sample/avg_time_per_sample*100:.2f}%\n'
|
||||
)
|
||||
|
||||
|
||||
class TrainerPerformanceEvaluator(TrainerCallback):
|
||||
|
||||
def __init__(self,
|
||||
actor_num_params: int,
|
||||
critic_num_params: int,
|
||||
enable_grad_checkpoint: bool = False,
|
||||
ignore_first_episodes: int = 1) -> None:
|
||||
super().__init__()
|
||||
self.world_size = get_world_size()
|
||||
self.actor_num_params = actor_num_params
|
||||
self.critic_num_params = critic_num_params
|
||||
self.enable_grad_checkpoint = enable_grad_checkpoint
|
||||
self.ignore_first_episodes = ignore_first_episodes
|
||||
self.ignore_this_episode = False
|
||||
|
||||
self.episode_timer = Timer()
|
||||
self.batch_timer = Timer()
|
||||
self.update_timer = Timer()
|
||||
self.total_samples: int = 0
|
||||
self.learn_flop: int = 0
|
||||
|
||||
print_rank_0(
|
||||
f'Trainer actor: {self.actor_num_params/1024**3:.2f}B, critic: {self.critic_num_params/1024**3:.2f}B, world size: {self.world_size}'
|
||||
)
|
||||
|
||||
def on_episode_start(self, episodes: int) -> None:
|
||||
self.ignore_this_episode = episodes < self.ignore_first_episodes
|
||||
if self.ignore_this_episode:
|
||||
return
|
||||
self.episode_timer.start()
|
||||
|
||||
def on_episode_end(self, episodes: int) -> None:
|
||||
if self.ignore_this_episode:
|
||||
return
|
||||
self.episode_timer.end()
|
||||
|
||||
def on_batch_start(self) -> None:
|
||||
if self.ignore_this_episode:
|
||||
return
|
||||
self.batch_timer.start()
|
||||
|
||||
def on_batch_end(self, metrics: dict, experience: Experience) -> None:
|
||||
if self.ignore_this_episode:
|
||||
return
|
||||
self.batch_timer.end()
|
||||
|
||||
batch_size, seq_len = experience.sequences.shape
|
||||
|
||||
self.total_samples += batch_size
|
||||
|
||||
# actor forward-backward, 3 means forward(1) + backward(2)
|
||||
self.learn_flop += self.actor_num_params * batch_size * seq_len * 2 * (3 + int(self.enable_grad_checkpoint))
|
||||
# critic forward-backward
|
||||
self.learn_flop += self.critic_num_params * batch_size * seq_len * 2 * (3 + int(self.enable_grad_checkpoint))
|
||||
|
||||
def on_update_start(self) -> None:
|
||||
if self.ignore_this_episode:
|
||||
return
|
||||
self.update_timer.start()
|
||||
|
||||
def on_update_end(self) -> None:
|
||||
if self.ignore_this_episode:
|
||||
return
|
||||
self.update_timer.end()
|
||||
|
||||
def on_fit_end(self) -> None:
|
||||
if self.total_samples == 0:
|
||||
print_rank_0('No samples are collected, skip trainer performance evaluation')
|
||||
return
|
||||
avg_train_duration = all_reduce_mean(self.batch_timer.duration, self.world_size)
|
||||
avg_update_duration = all_reduce_mean(self.update_timer.duration, self.world_size)
|
||||
avg_episode_duration = all_reduce_mean(self.episode_timer.duration, self.world_size)
|
||||
|
||||
avg_throughput = self.total_samples * self.world_size / (avg_episode_duration + 1e-12)
|
||||
avg_learn_tflops = self.learn_flop / 1e12 / (avg_train_duration + 1e-12)
|
||||
avg_time_per_sample = (avg_episode_duration + 1e-12) / (self.total_samples * self.world_size)
|
||||
avg_train_time_per_sample = (avg_train_duration + 1e-12) / (self.total_samples * self.world_size)
|
||||
avg_update_time_per_sample = (avg_update_duration + 1e-12) / (self.total_samples * self.world_size)
|
||||
|
||||
print_rank_0(
|
||||
'Learning Performance Summary:\n' + f'Throughput: {avg_throughput:.3f} samples/sec\n' +
|
||||
f'TFLOPS per GPU: {avg_learn_tflops:.3f}\n' + f'Sample time (overall): {avg_time_per_sample:.3f} s\n' +
|
||||
f'Sample time (train): {avg_train_time_per_sample:.3f} s, {avg_train_time_per_sample/avg_time_per_sample*100:.2f}%\n'
|
||||
+
|
||||
f'Sample time (update): {avg_update_time_per_sample:.3f} s, {avg_update_time_per_sample/avg_time_per_sample*100:.2f}%\n'
|
||||
)
|
|
@ -1,22 +1,24 @@
|
|||
import torch
|
||||
import asyncio
|
||||
import copy
|
||||
import random
|
||||
from typing import List, Any
|
||||
from threading import Lock
|
||||
from typing import Any, List
|
||||
|
||||
import ray
|
||||
import torch
|
||||
from coati.experience_maker.base import Experience
|
||||
from coati.replay_buffer import ReplayBuffer
|
||||
from coati.replay_buffer.utils import BufferItem, make_experience_batch, split_experience_batch
|
||||
# from torch.multiprocessing import Queue
|
||||
from ray.util.queue import Queue
|
||||
import ray
|
||||
import asyncio
|
||||
from coati.experience_maker.base import Experience
|
||||
from coati.replay_buffer.utils import BufferItem, make_experience_batch, split_experience_batch
|
||||
from coati.replay_buffer import ReplayBuffer
|
||||
from threading import Lock
|
||||
import copy
|
||||
|
||||
|
||||
class DetachedReplayBuffer:
|
||||
'''
|
||||
Detached replay buffer. Share Experience across workers on the same node.
|
||||
Therefore a trainer node is expected to have only one instance.
|
||||
Detached replay buffer. Share Experience across workers on the same node.
|
||||
Therefore a trainer node is expected to have only one instance.
|
||||
It is ExperienceMakerHolder's duty to call append(exp) method, remotely.
|
||||
|
||||
|
||||
Args:
|
||||
sample_batch_size: Batch size when sampling. Exp won't enqueue until they formed a batch.
|
||||
tp_world_size: Number of workers in the same tp group
|
||||
|
@ -24,31 +26,25 @@ class DetachedReplayBuffer:
|
|||
cpu_offload: Whether to offload experience to cpu when sampling. Defaults to True.
|
||||
'''
|
||||
|
||||
def __init__(self, sample_batch_size: int, tp_world_size: int = 1, limit : int = 0, cpu_offload: bool = True) -> None:
|
||||
self.cpu_offload = cpu_offload
|
||||
def __init__(self, sample_batch_size: int, limit: int = 0) -> None:
|
||||
self.sample_batch_size = sample_batch_size
|
||||
self.limit = limit
|
||||
self.items = Queue(self.limit, actor_options={"num_cpus":1})
|
||||
self.batch_collector : List[BufferItem] = []
|
||||
|
||||
'''
|
||||
Workers in the same tp group share this buffer and need same sample for one step.
|
||||
Therefore a held_sample should be returned tp_world_size times before it could be dropped.
|
||||
worker_state records whether a worker got the held_sample
|
||||
'''
|
||||
self.tp_world_size = tp_world_size
|
||||
self.worker_state = [False] * self.tp_world_size
|
||||
self.held_sample = None
|
||||
self._worker_state_lock = Lock()
|
||||
self.items = Queue(self.limit, actor_options={"num_cpus": 1})
|
||||
self.batch_collector: List[BufferItem] = []
|
||||
|
||||
@torch.no_grad()
|
||||
def append(self, experience: Experience) -> None:
|
||||
'''
|
||||
Expected to be called remotely.
|
||||
'''
|
||||
if self.cpu_offload:
|
||||
experience.to_device(torch.device('cpu'))
|
||||
items = split_experience_batch(experience)
|
||||
self.extend(items)
|
||||
|
||||
@torch.no_grad()
|
||||
def extend(self, items: List[BufferItem]) -> None:
|
||||
'''
|
||||
Expected to be called remotely.
|
||||
'''
|
||||
self.batch_collector.extend(items)
|
||||
while len(self.batch_collector) >= self.sample_batch_size:
|
||||
items = self.batch_collector[:self.sample_batch_size]
|
||||
|
@ -62,19 +58,10 @@ class DetachedReplayBuffer:
|
|||
self.items = Queue(self.limit)
|
||||
self.worker_state = [False] * self.tp_world_size
|
||||
self.batch_collector = []
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def sample(self, worker_rank = 0, to_device = "cpu") -> Experience:
|
||||
self._worker_state_lock.acquire()
|
||||
if not any(self.worker_state):
|
||||
self.held_sample = self._sample_and_erase()
|
||||
self.worker_state[worker_rank] = True
|
||||
if all(self.worker_state):
|
||||
self.worker_state = [False] * self.tp_world_size
|
||||
ret = self.held_sample
|
||||
else:
|
||||
ret = copy.deepcopy(self.held_sample)
|
||||
self._worker_state_lock.release()
|
||||
def sample(self, worker_rank=0, to_device="cpu") -> Experience:
|
||||
ret = self._sample_and_erase()
|
||||
ret.to_device(to_device)
|
||||
return ret
|
||||
|
||||
|
@ -85,4 +72,4 @@ class DetachedReplayBuffer:
|
|||
|
||||
def get_length(self) -> int:
|
||||
ret = self.items.qsize()
|
||||
return ret
|
||||
return ret
|
|
@ -0,0 +1,179 @@
|
|||
import os
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, Callable, Dict, Iterable, List, Optional, Union
|
||||
|
||||
import ray
|
||||
import torch
|
||||
from coati.experience_maker import Experience
|
||||
from coati.replay_buffer.utils import BufferItem
|
||||
from torch.utils.data import DataLoader
|
||||
from tqdm import tqdm
|
||||
|
||||
from .callbacks import TrainerCallback
|
||||
from .detached_replay_buffer import DetachedReplayBuffer
|
||||
from .utils import is_rank_0
|
||||
|
||||
|
||||
class DetachedTrainer(ABC):
|
||||
'''
|
||||
Base class for detached rlhf trainers.
|
||||
'detach' means that the experience maker is detached compared to a normal Trainer.
|
||||
Please set name attribute during init:
|
||||
>>> trainer = DetachedTrainer.options(..., name = "xxx", ...).remote()
|
||||
So an ExperienceMakerHolder can reach the detached_replay_buffer by Actor's name.
|
||||
Args:
|
||||
detached_strategy (DetachedStrategy): the strategy to use for training
|
||||
detached_replay_buffer_ref (ObjectRef[DetachedReplayBuffer]): the replay buffer to use for training
|
||||
data_loader_pin_memory (bool, defaults to True): whether to pin memory for data loader
|
||||
callbacks (List[Callback], defaults to []): the callbacks to call during training process
|
||||
generate_kwargs (dict, optional): the kwargs to use while model generating
|
||||
|
||||
'''
|
||||
|
||||
def __init__(self,
|
||||
experience_maker_holder_name_list: List[str],
|
||||
train_batch_size: int = 8,
|
||||
buffer_limit: int = 0,
|
||||
dataloader_pin_memory: bool = True,
|
||||
callbacks: List[TrainerCallback] = [],
|
||||
debug: bool = False) -> None:
|
||||
super().__init__()
|
||||
self.detached_replay_buffer = DetachedReplayBuffer(train_batch_size, limit=buffer_limit)
|
||||
self.dataloader_pin_memory = dataloader_pin_memory
|
||||
self.callbacks = callbacks
|
||||
self.target_holder_name_list = experience_maker_holder_name_list
|
||||
self.target_holder_list = []
|
||||
self._is_target_holder_initialized = False
|
||||
self._debug = debug
|
||||
|
||||
def update_target_holder_list(self):
|
||||
# as the length of target_holder_list may be zero, we need to check it by a bool flag
|
||||
if not self._is_target_holder_initialized:
|
||||
for name in self.target_holder_name_list:
|
||||
self.target_holder_list.append(ray.get_actor(name, namespace=os.environ["RAY_NAMESPACE"]))
|
||||
self._is_target_holder_initialized = True
|
||||
|
||||
@abstractmethod
|
||||
def _update_remote_makers(self, fully_update: bool = False, **kwargs):
|
||||
pass
|
||||
|
||||
def sync_models_to_remote_makers(self, **kwargs):
|
||||
self._update_remote_makers(fully_update=True, **kwargs)
|
||||
|
||||
@abstractmethod
|
||||
def training_step(self, experience: Experience) -> Dict[str, Any]:
|
||||
pass
|
||||
|
||||
def _learn(self, update_steps: int, train_epochs: int) -> None:
|
||||
data = []
|
||||
# warmup
|
||||
pbar = tqdm(range(update_steps), desc=f'Train epoch [1/{train_epochs}]', disable=not is_rank_0())
|
||||
self._on_epoch_start(0)
|
||||
self._learn_epoch(pbar, data)
|
||||
self._on_epoch_end(0)
|
||||
# item is already a batch
|
||||
dataloader = DataLoader(data,
|
||||
batch_size=1,
|
||||
shuffle=True,
|
||||
pin_memory=self.dataloader_pin_memory,
|
||||
collate_fn=lambda x: x[0])
|
||||
for epoch in range(1, train_epochs):
|
||||
pbar = tqdm(dataloader, desc=f'Train epoch [{epoch + 1}/{train_epochs}]', disable=not is_rank_0())
|
||||
self._on_epoch_start(epoch)
|
||||
self._learn_epoch(pbar, data)
|
||||
self._on_epoch_end(epoch)
|
||||
|
||||
def _learn_epoch(self, pbar: tqdm, data: List[Experience]) -> None:
|
||||
is_warmup = len(data) == 0
|
||||
for x in pbar:
|
||||
if self._debug:
|
||||
print("[trainer] training step")
|
||||
# sample a batch and then train to avoid waiting
|
||||
experience = x if not is_warmup else self._buffer_sample()
|
||||
experience.to_device(torch.cuda.current_device())
|
||||
self._on_batch_start()
|
||||
metrics = self.training_step(experience)
|
||||
self._on_batch_end(metrics, experience)
|
||||
|
||||
if self._debug:
|
||||
print("[trainer] step over")
|
||||
experience.to_device("cpu")
|
||||
if is_warmup:
|
||||
data.append(experience)
|
||||
pbar.set_postfix(metrics)
|
||||
|
||||
def fit(self, total_steps: int, update_steps: int, train_epochs: int = 1) -> None:
|
||||
self._on_fit_start()
|
||||
for i in tqdm(range(total_steps // update_steps), desc='Trainer', disable=not is_rank_0()):
|
||||
self._on_episode_start(i)
|
||||
self._learn(update_steps, train_epochs)
|
||||
self._on_update_start()
|
||||
self._update_remote_makers()
|
||||
self._on_update_end()
|
||||
self._on_episode_end(i)
|
||||
self._on_fit_end()
|
||||
|
||||
@ray.method(concurrency_group="buffer_length")
|
||||
def buffer_get_length(self):
|
||||
# called by ExperienceMakerHolder
|
||||
if self._debug:
|
||||
print("[trainer] telling length")
|
||||
return self.detached_replay_buffer.get_length()
|
||||
|
||||
@ray.method(concurrency_group="buffer_append")
|
||||
def buffer_append(self, experience: Experience):
|
||||
# called by ExperienceMakerHolder
|
||||
if self._debug:
|
||||
print(f"[trainer] receiving exp.")
|
||||
self.detached_replay_buffer.append(experience)
|
||||
|
||||
@ray.method(concurrency_group="buffer_append")
|
||||
def buffer_extend(self, items: List[BufferItem]):
|
||||
# called by ExperienceMakerHolder
|
||||
if self._debug:
|
||||
print(f"[trainer] receiving exp.")
|
||||
self.detached_replay_buffer.extend(items)
|
||||
|
||||
@ray.method(concurrency_group="buffer_sample")
|
||||
def _buffer_sample(self):
|
||||
return self.detached_replay_buffer.sample()
|
||||
|
||||
def _on_fit_start(self) -> None:
|
||||
for callback in self.callbacks:
|
||||
callback.on_fit_start()
|
||||
|
||||
def _on_fit_end(self) -> None:
|
||||
for callback in self.callbacks:
|
||||
callback.on_fit_end()
|
||||
|
||||
def _on_episode_start(self, episode: int) -> None:
|
||||
for callback in self.callbacks:
|
||||
callback.on_episode_start(episode)
|
||||
|
||||
def _on_episode_end(self, episode: int) -> None:
|
||||
for callback in self.callbacks:
|
||||
callback.on_episode_end(episode)
|
||||
|
||||
def _on_epoch_start(self, epoch: int) -> None:
|
||||
for callback in self.callbacks:
|
||||
callback.on_epoch_start(epoch)
|
||||
|
||||
def _on_epoch_end(self, epoch: int) -> None:
|
||||
for callback in self.callbacks:
|
||||
callback.on_epoch_end(epoch)
|
||||
|
||||
def _on_batch_start(self) -> None:
|
||||
for callback in self.callbacks:
|
||||
callback.on_batch_start()
|
||||
|
||||
def _on_batch_end(self, metrics: dict, experience: Experience) -> None:
|
||||
for callback in self.callbacks:
|
||||
callback.on_batch_end(metrics, experience)
|
||||
|
||||
def _on_update_start(self) -> None:
|
||||
for callback in self.callbacks:
|
||||
callback.on_update_start()
|
||||
|
||||
def _on_update_end(self) -> None:
|
||||
for callback in self.callbacks:
|
||||
callback.on_update_end()
|
|
@ -1,24 +1,38 @@
|
|||
from typing import Any, Callable, Dict, List, Optional
|
||||
import torch
|
||||
from torch.optim import Adam
|
||||
from typing import Any, Callable, Dict, List, Optional, Tuple
|
||||
|
||||
import ray
|
||||
import torch
|
||||
from coati.experience_maker import Experience, NaiveExperienceMaker
|
||||
from coati.models.base import Actor, Critic
|
||||
from coati.models.generation_utils import update_model_kwargs_fn
|
||||
from coati.models.loss import PolicyLoss, ValueLoss
|
||||
from coati.trainer.strategies import ColossalAIStrategy, DDPStrategy, NaiveStrategy, Strategy
|
||||
from coati.trainer.callbacks import Callback
|
||||
from coati.trainer.strategies import ColossalAIStrategy, DDPStrategy, NaiveStrategy, Strategy
|
||||
from torch.optim import Adam
|
||||
|
||||
from colossalai.nn.optimizer import HybridAdam
|
||||
|
||||
import ray
|
||||
|
||||
|
||||
from .utils import is_rank_0, get_cuda_actor_critic_from_args, get_strategy_from_args, set_dist_env
|
||||
from .callbacks import TrainerCallback, TrainerPerformanceEvaluator
|
||||
from .detached_trainer_base import DetachedTrainer
|
||||
from .lora_constructor import LoRAConstructor
|
||||
from .utils import (
|
||||
get_actor_from_args,
|
||||
get_critic_from_args,
|
||||
get_model_numel,
|
||||
get_rank,
|
||||
get_strategy_from_args,
|
||||
is_rank_0,
|
||||
set_dist_env,
|
||||
state_dict_to,
|
||||
)
|
||||
|
||||
|
||||
@ray.remote(concurrency_groups={"buffer_length": 1, "buffer_append":1, "buffer_sample":1,"model_io": 1, "compute": 1})
|
||||
@ray.remote(concurrency_groups={
|
||||
"buffer_length": 1,
|
||||
"buffer_append": 1,
|
||||
"buffer_sample": 1,
|
||||
"model_io": 1,
|
||||
"compute": 1
|
||||
})
|
||||
class DetachedPPOTrainer(DetachedTrainer):
|
||||
'''
|
||||
Detached Trainer for PPO algorithm
|
||||
|
@ -40,86 +54,102 @@ class DetachedPPOTrainer(DetachedTrainer):
|
|||
generate_kwargs (dict, optional): the kwargs to use while model generating
|
||||
'''
|
||||
|
||||
def __init__(self,
|
||||
experience_maker_holder_name_list: List[str],
|
||||
strategy: str,
|
||||
model: str,
|
||||
env_info: Dict[str, str] = None,
|
||||
pretrained: str = None,
|
||||
lora_rank: int = 0,
|
||||
train_batch_size: int = 8,
|
||||
buffer_limit: int = 0,
|
||||
buffer_cpu_offload: bool = True,
|
||||
eps_clip: float = 0.2,
|
||||
value_clip: float = 0.4,
|
||||
experience_batch_size: int = 8,
|
||||
max_epochs: int = 10,
|
||||
dataloader_pin_memory: bool = True,
|
||||
callbacks: List[Callback] = [],
|
||||
**generate_kwargs) -> None:
|
||||
def __init__(
|
||||
self,
|
||||
experience_maker_holder_name_list: List[str],
|
||||
strategy_fn: Callable[[], Strategy],
|
||||
model_fn: Callable[[], Tuple[Actor, Critic]],
|
||||
env_info: Dict[str, str] = None,
|
||||
train_batch_size: int = 8,
|
||||
buffer_limit: int = 0,
|
||||
eps_clip: float = 0.2,
|
||||
value_clip: float = 0.4,
|
||||
dataloader_pin_memory: bool = True,
|
||||
callbacks: List[TrainerCallback] = [],
|
||||
eval_performance: bool = False,
|
||||
debug: bool = False,
|
||||
update_lora_weights: bool = False,
|
||||
) -> None:
|
||||
# set environment variables
|
||||
if env_info:
|
||||
set_dist_env(env_info=env_info)
|
||||
# configure strategy
|
||||
self.strategy = get_strategy_from_args(strategy)
|
||||
self.strategy = strategy_fn()
|
||||
# configure models, loss and optimizers
|
||||
with self.strategy.model_init_context():
|
||||
self.actor, self.critic = get_cuda_actor_critic_from_args(model, pretrained, lora_rank)
|
||||
self.actor, self.critic = model_fn()
|
||||
|
||||
if strategy != 'colossalai_gemini':
|
||||
self.actor.to(torch.float16).to(torch.cuda.current_device())
|
||||
self.critic.to(torch.float16).to(torch.cuda.current_device())
|
||||
if eval_performance:
|
||||
actor_numel = get_model_numel(self.actor)
|
||||
critic_numel = get_model_numel(self.critic)
|
||||
evaluator = TrainerPerformanceEvaluator(actor_numel, critic_numel)
|
||||
callbacks = callbacks + [evaluator]
|
||||
|
||||
if strategy.startswith('colossalai'):
|
||||
self.actor_optim = HybridAdam(self.actor.parameters(), lr=5e-6)
|
||||
self.critic_optim = HybridAdam(self.critic.parameters(), lr=5e-6)
|
||||
if isinstance(self.strategy, ColossalAIStrategy):
|
||||
self.actor_optim = HybridAdam(self.actor.parameters(), lr=1e-7)
|
||||
self.critic_optim = HybridAdam(self.critic.parameters(), lr=1e-7)
|
||||
else:
|
||||
self.actor_optim = Adam(self.actor.parameters(), lr=5e-6)
|
||||
self.critic_optim = Adam(self.critic.parameters(), lr=5e-6)
|
||||
self.actor_optim = Adam(self.actor.parameters(), lr=1e-7)
|
||||
self.critic_optim = Adam(self.critic.parameters(), lr=1e-7)
|
||||
|
||||
(self.actor, self.actor_optim), (self.critic, self.critic_optim) = \
|
||||
self.strategy.prepare((self.actor, self.actor_optim), (self.critic, self.critic_optim))
|
||||
generate_kwargs = _set_default_generate_kwargs(self.strategy, generate_kwargs, self.actor)
|
||||
|
||||
# configure trainer
|
||||
self.actor_loss_fn = PolicyLoss(eps_clip)
|
||||
self.critic_loss_fn = ValueLoss(value_clip)
|
||||
|
||||
super().__init__(experience_maker_holder_name_list,
|
||||
train_batch_size=train_batch_size,
|
||||
buffer_limit=buffer_limit,
|
||||
buffer_cpu_offload=buffer_cpu_offload,
|
||||
experience_batch_size=experience_batch_size,
|
||||
max_epochs=max_epochs,
|
||||
dataloader_pin_memory=dataloader_pin_memory,
|
||||
callbacks=callbacks,
|
||||
**generate_kwargs)
|
||||
debug=debug)
|
||||
if self._debug:
|
||||
print(f'[trainer{get_rank()}] will send state dict to {experience_maker_holder_name_list}')
|
||||
|
||||
self._update_lora_weights = update_lora_weights
|
||||
|
||||
@ray.method(concurrency_group="model_io")
|
||||
def _update_remote_makers(self):
|
||||
@torch.no_grad()
|
||||
def _update_remote_makers(self, fully_update: bool = False, **config):
|
||||
# TODO: balance duties
|
||||
if is_rank_0():
|
||||
self.update_target_holder_list(self.target_holder_name_list)
|
||||
if not fully_update:
|
||||
config['requires_grad_only'] = True
|
||||
self.update_target_holder_list()
|
||||
# mark start, ensure order
|
||||
tasks = []
|
||||
for target_holder in self.target_holder_list:
|
||||
tasks.append(target_holder.update_experience_maker.remote(chunk_start=True, fully_update=fully_update))
|
||||
ray.get(tasks)
|
||||
# sending loop
|
||||
tasks = []
|
||||
|
||||
for state_dict_shard in self._get_model_state_dict_shard(self.actor, fully_update=fully_update, **config):
|
||||
for target_holder in self.target_holder_list:
|
||||
# TODO: reduce malloc
|
||||
with torch.no_grad():
|
||||
ray.get(target_holder.update_experience_maker.remote(self._get_unwrapped_actor(), self._get_unwrapped_critic()))
|
||||
|
||||
@ray.method(concurrency_group="model_io")
|
||||
def initialize_remote_makers(self):
|
||||
# TODO: balance duties
|
||||
if is_rank_0():
|
||||
self.update_target_holder_list(self.target_holder_name_list)
|
||||
tasks.append(
|
||||
target_holder.update_experience_maker.remote(
|
||||
new_actor_state_dict=state_dict_shard,
|
||||
new_actor_lora_config_dict=self._get_model_lora_config_dict(self.actor),
|
||||
fully_update=fully_update))
|
||||
# sending loop
|
||||
for state_dict_shard in self._get_model_state_dict_shard(self.critic, fully_update=fully_update, **config):
|
||||
for target_holder in self.target_holder_list:
|
||||
# TODO: reduce malloc
|
||||
with torch.no_grad():
|
||||
ray.get(target_holder.initialize_experience_maker.remote(self._get_unwrapped_actor(), self._get_unwrapped_critic()))
|
||||
tasks.append(
|
||||
target_holder.update_experience_maker.remote(
|
||||
new_critic_state_dict=state_dict_shard,
|
||||
new_critic_lora_config_dict=self._get_model_lora_config_dict(self.critic),
|
||||
fully_update=fully_update))
|
||||
ray.get(tasks)
|
||||
# mark end
|
||||
for target_holder in self.target_holder_list:
|
||||
target_holder.update_experience_maker.remote(chunk_end=True, fully_update=fully_update)
|
||||
|
||||
@ray.method(concurrency_group="compute")
|
||||
def training_step(self, experience: Experience) -> Dict[str, float]:
|
||||
self.actor.train()
|
||||
self.critic.train()
|
||||
|
||||
experience.to_device(torch.cuda.current_device())
|
||||
num_actions = experience.action_mask.size(1)
|
||||
action_log_probs = self.actor(experience.sequences, num_actions, attention_mask=experience.attention_mask)
|
||||
actor_loss = self.actor_loss_fn(action_log_probs,
|
||||
|
@ -155,38 +185,16 @@ class DetachedPPOTrainer(DetachedTrainer):
|
|||
def strategy_save_critic_optim(self, path: str, only_rank0: bool = False) -> None:
|
||||
self.strategy.save_optimizer(self.critic_optim, path, only_rank0)
|
||||
|
||||
def _get_unwrapped_actor(self):
|
||||
if False:
|
||||
pass
|
||||
elif isinstance(self.strategy, ColossalAIStrategy):
|
||||
ret = Actor(self.strategy._unwrap_model(self.actor))
|
||||
return ret
|
||||
elif isinstance(self.strategy, DDPStrategy):
|
||||
return Actor(self.strategy._unwrap_actor(self.actor))
|
||||
elif isinstance(self.strategy, NaiveStrategy):
|
||||
return self.actor
|
||||
def _get_model_state_dict_shard(self, model: torch.nn.Module, fully_update=False, **config):
|
||||
for state_dict in self.strategy.get_model_state_dict_shard(model, **config):
|
||||
if not self._update_lora_weights or fully_update:
|
||||
yield state_dict_to(state_dict)
|
||||
else:
|
||||
state_dict_lora, _ = LoRAConstructor.filter_state_dict_lora(state_dict)
|
||||
yield state_dict_to(state_dict_lora)
|
||||
|
||||
def _get_unwrapped_critic(self):
|
||||
if False:
|
||||
pass
|
||||
elif isinstance(self.strategy, ColossalAIStrategy):
|
||||
ret = self.strategy._unwrap_model(self.critic)
|
||||
return ret
|
||||
elif isinstance(self.strategy, DDPStrategy):
|
||||
return self.critic.module
|
||||
elif isinstance(self.strategy, NaiveStrategy):
|
||||
return self.critic
|
||||
|
||||
|
||||
def _set_default_generate_kwargs(strategy: Strategy, generate_kwargs: dict, actor: Actor) -> None:
|
||||
origin_model = strategy._unwrap_actor(actor)
|
||||
new_kwargs = {**generate_kwargs}
|
||||
# use huggingface models method directly
|
||||
if 'prepare_inputs_fn' not in generate_kwargs and hasattr(origin_model, 'prepare_inputs_for_generation'):
|
||||
new_kwargs['prepare_inputs_fn'] = origin_model.prepare_inputs_for_generation
|
||||
|
||||
if 'update_model_kwargs_fn' not in generate_kwargs:
|
||||
new_kwargs['update_model_kwargs_fn'] = update_model_kwargs_fn
|
||||
|
||||
return new_kwargs
|
||||
|
||||
def _get_model_lora_config_dict(self, model: torch.nn.Module):
|
||||
if not self._update_lora_weights:
|
||||
return None
|
||||
unwrapped_model = self.strategy.unwrap_model(model)
|
||||
return LoRAConstructor.extract_lora_config(unwrapped_model)
|
|
@ -1,153 +0,0 @@
|
|||
import argparse
|
||||
from copy import deepcopy
|
||||
|
||||
import pandas as pd
|
||||
import torch
|
||||
from coati.trainer import PPOTrainer
|
||||
|
||||
|
||||
from coati.ray.src.experience_maker_holder import ExperienceMakerHolder
|
||||
from coati.ray.src.detached_trainer_ppo import DetachedPPOTrainer
|
||||
|
||||
from coati.trainer.strategies import ColossalAIStrategy, DDPStrategy, NaiveStrategy
|
||||
from coati.experience_maker import NaiveExperienceMaker
|
||||
from torch.optim import Adam
|
||||
from transformers import AutoTokenizer, BloomTokenizerFast
|
||||
from transformers.models.gpt2.tokenization_gpt2 import GPT2Tokenizer
|
||||
|
||||
from colossalai.nn.optimizer import HybridAdam
|
||||
|
||||
import ray
|
||||
import os
|
||||
import socket
|
||||
|
||||
def get_free_port():
|
||||
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
|
||||
s.bind(('', 0))
|
||||
return s.getsockname()[1]
|
||||
|
||||
|
||||
def get_local_ip():
|
||||
with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as s:
|
||||
s.connect(('8.8.8.8', 80))
|
||||
return s.getsockname()[0]
|
||||
|
||||
def main(args):
|
||||
master_addr = str(get_local_ip())
|
||||
# trainer_env_info
|
||||
trainer_port = str(get_free_port())
|
||||
env_info_trainer = {'local_rank' : '0',
|
||||
'rank' : '0',
|
||||
'world_size' : '1',
|
||||
'master_port' : trainer_port,
|
||||
'master_addr' : master_addr}
|
||||
|
||||
# maker_env_info
|
||||
maker_port = str(get_free_port())
|
||||
env_info_maker = {'local_rank' : '0',
|
||||
'rank' : '0',
|
||||
'world_size' : '1',
|
||||
'master_port' : maker_port,
|
||||
'master_addr' : master_addr}
|
||||
|
||||
# configure tokenizer
|
||||
if args.model == 'gpt2':
|
||||
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
|
||||
tokenizer.pad_token = tokenizer.eos_token
|
||||
elif args.model == 'bloom':
|
||||
tokenizer = BloomTokenizerFast.from_pretrained(args.pretrain)
|
||||
tokenizer.pad_token = tokenizer.eos_token
|
||||
elif args.model == 'opt':
|
||||
tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m")
|
||||
else:
|
||||
raise ValueError(f'Unsupported model "{args.model}"')
|
||||
|
||||
# configure Trainer
|
||||
trainer_ref = DetachedPPOTrainer.options(name="trainer1", num_gpus=1, max_concurrency=2).remote(
|
||||
experience_maker_holder_name_list=["maker1"],
|
||||
strategy=args.trainer_strategy,
|
||||
model=args.model,
|
||||
env_info = env_info_trainer,
|
||||
pretrained=args.pretrain,
|
||||
lora_rank=args.lora_rank,
|
||||
train_batch_size=args.train_batch_size,
|
||||
buffer_limit=16,
|
||||
experience_batch_size=args.experience_batch_size,
|
||||
max_epochs=args.max_epochs,
|
||||
#kwargs:
|
||||
max_length=128,
|
||||
do_sample=True,
|
||||
temperature=1.0,
|
||||
top_k=50,
|
||||
pad_token_id=tokenizer.pad_token_id,
|
||||
eos_token_id=tokenizer.eos_token_id,
|
||||
debug=args.debug,
|
||||
)
|
||||
|
||||
# configure Experience Maker
|
||||
experience_holder_ref = ExperienceMakerHolder.options(name="maker1", num_gpus=1, max_concurrency=2).remote(
|
||||
detached_trainer_name_list=["trainer1"],
|
||||
strategy=args.maker_strategy,
|
||||
env_info = env_info_maker,
|
||||
experience_batch_size=args.experience_batch_size,
|
||||
kl_coef=0.1,
|
||||
#kwargs:
|
||||
max_length=128,
|
||||
do_sample=True,
|
||||
temperature=1.0,
|
||||
top_k=50,
|
||||
pad_token_id=tokenizer.pad_token_id,
|
||||
eos_token_id=tokenizer.eos_token_id,
|
||||
debug=args.debug,
|
||||
)
|
||||
|
||||
# trainer send its actor and critic to experience holders.
|
||||
ray.get(trainer_ref.initialize_remote_makers.remote())
|
||||
|
||||
# configure sampler
|
||||
dataset = pd.read_csv(args.prompt_path)['prompt']
|
||||
|
||||
def tokenize_fn(texts):
|
||||
# MUST padding to max length to ensure inputs of all ranks have the same length
|
||||
# Different length may lead to hang when using gemini, as different generation steps
|
||||
batch = tokenizer(texts, return_tensors='pt', max_length=96, padding='max_length', truncation=True)
|
||||
return {k: v.cuda() for k, v in batch.items()}
|
||||
|
||||
trainer_done_ref = trainer_ref.fit.remote(num_episodes=args.num_episodes, max_timesteps=args.max_timesteps, update_timesteps=args.update_timesteps)
|
||||
num_exp_per_maker = args.num_episodes * args.max_timesteps // args.update_timesteps * args.max_epochs + 3 # +3 for fault tolerance
|
||||
maker_done_ref = experience_holder_ref.workingloop.remote(dataset, tokenize_fn, times=num_exp_per_maker)
|
||||
|
||||
ray.get([trainer_done_ref, maker_done_ref])
|
||||
|
||||
# save model checkpoint after fitting
|
||||
trainer_ref.strategy_save_actor.remote(args.save_path, only_rank0=True)
|
||||
# save optimizer checkpoint on all ranks
|
||||
if args.need_optim_ckpt:
|
||||
trainer_ref.strategy_save_actor_optim.remote('actor_optim_checkpoint_prompts_%d.pt' % (torch.cuda.current_device()),
|
||||
only_rank0=False)
|
||||
|
||||
if __name__ == '__main__':
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('prompt_path')
|
||||
parser.add_argument('--trainer_strategy',
|
||||
choices=['naive', 'ddp', 'colossalai_gemini', 'colossalai_zero2'],
|
||||
default='naive')
|
||||
parser.add_argument('--maker_strategy',
|
||||
choices=['naive', 'ddp', 'colossalai_gemini', 'colossalai_zero2'],
|
||||
default='naive')
|
||||
parser.add_argument('--model', default='gpt2', choices=['gpt2', 'bloom', 'opt'])
|
||||
parser.add_argument('--pretrain', type=str, default=None)
|
||||
parser.add_argument('--save_path', type=str, default='actor_checkpoint_prompts.pt')
|
||||
parser.add_argument('--need_optim_ckpt', type=bool, default=False)
|
||||
parser.add_argument('--num_episodes', type=int, default=10)
|
||||
parser.add_argument('--max_timesteps', type=int, default=10)
|
||||
parser.add_argument('--update_timesteps', type=int, default=10)
|
||||
parser.add_argument('--max_epochs', type=int, default=5)
|
||||
parser.add_argument('--train_batch_size', type=int, default=8)
|
||||
parser.add_argument('--experience_batch_size', type=int, default=8)
|
||||
parser.add_argument('--lora_rank', type=int, default=0, help="low-rank adaptation matrices rank")
|
||||
|
||||
parser.add_argument('--debug', action='store_true')
|
||||
args = parser.parse_args()
|
||||
ray.init(namespace=os.environ["RAY_NAMESPACE"])
|
||||
main(args)
|
|
@ -1,23 +0,0 @@
|
|||
set_n_least_used_CUDA_VISIBLE_DEVICES() {
|
||||
local n=${1:-"9999"}
|
||||
echo "GPU Memory Usage:"
|
||||
local FIRST_N_GPU_IDS=$(nvidia-smi --query-gpu=memory.used --format=csv \
|
||||
| tail -n +2 \
|
||||
| nl -v 0 \
|
||||
| tee /dev/tty \
|
||||
| sort -g -k 2 \
|
||||
| awk '{print $1}' \
|
||||
| head -n $n)
|
||||
export CUDA_VISIBLE_DEVICES=$(echo $FIRST_N_GPU_IDS | sed 's/ /,/g')
|
||||
echo "Now CUDA_VISIBLE_DEVICES is set to:"
|
||||
echo "CUDA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES"
|
||||
}
|
||||
|
||||
set_n_least_used_CUDA_VISIBLE_DEVICES 2
|
||||
|
||||
export RAY_NAMESPACE="admin"
|
||||
|
||||
python 1m1t.py "/path/to/prompts.csv" \
|
||||
--trainer_strategy colossalai_zero2 --maker_strategy naive --lora_rank 2 --pretrain "facebook/opt-350m" --model 'opt' \
|
||||
--num_episodes 10 --max_timesteps 10 --update_timesteps 10 \
|
||||
--max_epochs 10 --debug
|
|
@ -1,186 +0,0 @@
|
|||
import argparse
|
||||
from copy import deepcopy
|
||||
|
||||
import pandas as pd
|
||||
import torch
|
||||
from coati.trainer import PPOTrainer
|
||||
|
||||
|
||||
from coati.ray.src.experience_maker_holder import ExperienceMakerHolder
|
||||
from coati.ray.src.detached_trainer_ppo import DetachedPPOTrainer
|
||||
|
||||
from coati.trainer.strategies import ColossalAIStrategy, DDPStrategy, NaiveStrategy
|
||||
from coati.experience_maker import NaiveExperienceMaker
|
||||
from torch.optim import Adam
|
||||
from transformers import AutoTokenizer, BloomTokenizerFast
|
||||
from transformers.models.gpt2.tokenization_gpt2 import GPT2Tokenizer
|
||||
|
||||
from colossalai.nn.optimizer import HybridAdam
|
||||
|
||||
import ray
|
||||
import os
|
||||
import socket
|
||||
|
||||
|
||||
def get_free_port():
|
||||
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
|
||||
s.bind(('', 0))
|
||||
return s.getsockname()[1]
|
||||
|
||||
|
||||
def get_local_ip():
|
||||
with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as s:
|
||||
s.connect(('8.8.8.8', 80))
|
||||
return s.getsockname()[0]
|
||||
|
||||
def main(args):
|
||||
master_addr = str(get_local_ip())
|
||||
# trainer_env_info
|
||||
trainer_port = str(get_free_port())
|
||||
env_info_trainer_1 = {'local_rank' : '0',
|
||||
'rank' : '0',
|
||||
'world_size' : '2',
|
||||
'master_port' : trainer_port,
|
||||
'master_addr' : master_addr}
|
||||
env_info_trainer_2 = {'local_rank' : '0',
|
||||
'rank' : '1',
|
||||
'world_size' : '2',
|
||||
'master_port' : trainer_port,
|
||||
'master_addr' : master_addr}
|
||||
# maker_env_info
|
||||
maker_port = str(get_free_port())
|
||||
env_info_maker_1 = {'local_rank' : '0',
|
||||
'rank' : '0',
|
||||
'world_size' : '2',
|
||||
'master_port' : maker_port,
|
||||
'master_addr' : master_addr}
|
||||
print([env_info_trainer_1,
|
||||
env_info_trainer_2,
|
||||
env_info_maker_1])
|
||||
ray.init(dashboard_port = 1145)
|
||||
# configure tokenizer
|
||||
if args.model == 'gpt2':
|
||||
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
|
||||
tokenizer.pad_token = tokenizer.eos_token
|
||||
elif args.model == 'bloom':
|
||||
tokenizer = BloomTokenizerFast.from_pretrained(args.pretrain)
|
||||
tokenizer.pad_token = tokenizer.eos_token
|
||||
elif args.model == 'opt':
|
||||
tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m")
|
||||
else:
|
||||
raise ValueError(f'Unsupported model "{args.model}"')
|
||||
|
||||
# configure Trainer
|
||||
trainer_1_ref = DetachedPPOTrainer.options(name="trainer1", namespace=os.environ["RAY_NAMESPACE"], num_gpus=1, max_concurrency=2).remote(
|
||||
experience_maker_holder_name_list=["maker1"],
|
||||
strategy=args.trainer_strategy,
|
||||
model=args.model,
|
||||
env_info=env_info_trainer_1,
|
||||
pretrained=args.pretrain,
|
||||
lora_rank=args.lora_rank,
|
||||
train_batch_size=args.train_batch_size,
|
||||
buffer_limit=16,
|
||||
experience_batch_size=args.experience_batch_size,
|
||||
max_epochs=args.max_epochs,
|
||||
#kwargs:
|
||||
max_length=128,
|
||||
do_sample=True,
|
||||
temperature=1.0,
|
||||
top_k=50,
|
||||
pad_token_id=tokenizer.pad_token_id,
|
||||
eos_token_id=tokenizer.eos_token_id,
|
||||
debug=args.debug,
|
||||
)
|
||||
|
||||
trainer_2_ref = DetachedPPOTrainer.options(name="trainer2", namespace=os.environ["RAY_NAMESPACE"], num_gpus=1, max_concurrency=2).remote(
|
||||
experience_maker_holder_name_list=["maker1"],
|
||||
strategy=args.trainer_strategy,
|
||||
model=args.model,
|
||||
env_info=env_info_trainer_2,
|
||||
pretrained=args.pretrain,
|
||||
lora_rank=args.lora_rank,
|
||||
train_batch_size=args.train_batch_size,
|
||||
buffer_limit=16,
|
||||
experience_batch_size=args.experience_batch_size,
|
||||
max_epochs=args.max_epochs,
|
||||
#kwargs:
|
||||
max_length=128,
|
||||
do_sample=True,
|
||||
temperature=1.0,
|
||||
top_k=50,
|
||||
pad_token_id=tokenizer.pad_token_id,
|
||||
eos_token_id=tokenizer.eos_token_id,
|
||||
debug= args.debug,
|
||||
)
|
||||
|
||||
# configure Experience Maker
|
||||
experience_holder_1_ref = ExperienceMakerHolder.options(name="maker1", namespace=os.environ["RAY_NAMESPACE"], num_gpus=1, max_concurrency=2).remote(
|
||||
detached_trainer_name_list=["trainer1", "trainer2"],
|
||||
strategy=args.maker_strategy,
|
||||
env_info=env_info_maker_1,
|
||||
experience_batch_size=args.experience_batch_size,
|
||||
kl_coef=0.1,
|
||||
#kwargs:
|
||||
max_length=128,
|
||||
do_sample=True,
|
||||
temperature=1.0,
|
||||
top_k=50,
|
||||
pad_token_id=tokenizer.pad_token_id,
|
||||
eos_token_id=tokenizer.eos_token_id,
|
||||
debug=args.debug,
|
||||
)
|
||||
|
||||
# trainer send its actor and critic to experience holders.
|
||||
# TODO: balance duty
|
||||
ray.get(trainer_1_ref.initialize_remote_makers.remote())
|
||||
|
||||
# configure sampler
|
||||
dataset = pd.read_csv(args.prompt_path)['prompt']
|
||||
|
||||
def tokenize_fn(texts):
|
||||
# MUST padding to max length to ensure inputs of all ranks have the same length
|
||||
# Different length may lead to hang when using gemini, as different generation steps
|
||||
batch = tokenizer(texts, return_tensors='pt', max_length=96, padding='max_length', truncation=True)
|
||||
return {k: v.cuda() for k, v in batch.items()}
|
||||
|
||||
trainer_1_done_ref = trainer_1_ref.fit.remote(num_episodes=args.num_episodes, max_timesteps=args.max_timesteps, update_timesteps=args.update_timesteps)
|
||||
trainer_2_done_ref = trainer_2_ref.fit.remote(num_episodes=args.num_episodes, max_timesteps=args.max_timesteps, update_timesteps=args.update_timesteps)
|
||||
num_exp_per_maker = args.num_episodes * args.max_timesteps // args.update_timesteps * args.max_epochs * 2 + 3 # +3 for fault tolerance
|
||||
maker_1_done_ref = experience_holder_1_ref.workingloop.remote(dataset, tokenize_fn, times=num_exp_per_maker)
|
||||
|
||||
ray.get([trainer_1_done_ref, trainer_2_done_ref, maker_1_done_ref])
|
||||
# save model checkpoint after fitting
|
||||
trainer_1_ref.strategy_save_actor.remote(args.save_path, only_rank0=True)
|
||||
trainer_2_ref.strategy_save_actor.remote(args.save_path, only_rank0=True)
|
||||
# save optimizer checkpoint on all ranks
|
||||
if args.need_optim_ckpt:
|
||||
trainer_1_ref.strategy_save_actor_optim.remote('actor_optim_checkpoint_prompts_%d.pt' % (torch.cuda.current_device()),
|
||||
only_rank0=False)
|
||||
trainer_2_ref.strategy_save_actor_optim.remote('actor_optim_checkpoint_prompts_%d.pt' % (torch.cuda.current_device()),
|
||||
only_rank0=False)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('prompt_path')
|
||||
parser.add_argument('--trainer_strategy',
|
||||
choices=['naive', 'ddp', 'colossalai_gemini', 'colossalai_zero2'],
|
||||
default='naive')
|
||||
parser.add_argument('--maker_strategy',
|
||||
choices=['naive', 'ddp', 'colossalai_gemini', 'colossalai_zero2'],
|
||||
default='naive')
|
||||
parser.add_argument('--model', default='gpt2', choices=['gpt2', 'bloom', 'opt'])
|
||||
parser.add_argument('--pretrain', type=str, default=None)
|
||||
parser.add_argument('--save_path', type=str, default='actor_checkpoint_prompts.pt')
|
||||
parser.add_argument('--need_optim_ckpt', type=bool, default=False)
|
||||
parser.add_argument('--num_episodes', type=int, default=10)
|
||||
parser.add_argument('--max_timesteps', type=int, default=10)
|
||||
parser.add_argument('--update_timesteps', type=int, default=10)
|
||||
parser.add_argument('--max_epochs', type=int, default=5)
|
||||
parser.add_argument('--train_batch_size', type=int, default=8)
|
||||
parser.add_argument('--experience_batch_size', type=int, default=8)
|
||||
parser.add_argument('--lora_rank', type=int, default=0, help="low-rank adaptation matrices rank")
|
||||
|
||||
parser.add_argument('--debug', action='store_true')
|
||||
args = parser.parse_args()
|
||||
main(args)
|
|
@ -1,23 +0,0 @@
|
|||
set_n_least_used_CUDA_VISIBLE_DEVICES() {
|
||||
local n=${1:-"9999"}
|
||||
echo "GPU Memory Usage:"
|
||||
local FIRST_N_GPU_IDS=$(nvidia-smi --query-gpu=memory.used --format=csv \
|
||||
| tail -n +2 \
|
||||
| nl -v 0 \
|
||||
| tee /dev/tty \
|
||||
| sort -g -k 2 \
|
||||
| awk '{print $1}' \
|
||||
| head -n $n)
|
||||
export CUDA_VISIBLE_DEVICES=$(echo $FIRST_N_GPU_IDS | sed 's/ /,/g')
|
||||
echo "Now CUDA_VISIBLE_DEVICES is set to:"
|
||||
echo "CUDA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES"
|
||||
}
|
||||
|
||||
set_n_least_used_CUDA_VISIBLE_DEVICES 2
|
||||
|
||||
export RAY_NAMESPACE="admin"
|
||||
|
||||
python 1m2t.py "/path/to/prompts.csv" --model gpt2 \
|
||||
--maker_strategy naive --trainer_strategy ddp --lora_rank 2 \
|
||||
--num_episodes 10 --max_timesteps 10 --update_timesteps 10 \
|
||||
--max_epochs 10 #--debug
|
|
@ -1,140 +0,0 @@
|
|||
import argparse
|
||||
from copy import deepcopy
|
||||
|
||||
import pandas as pd
|
||||
import torch
|
||||
from coati.trainer import PPOTrainer
|
||||
|
||||
|
||||
from coati.ray.src.experience_maker_holder import ExperienceMakerHolder
|
||||
from coati.ray.src.detached_trainer_ppo import DetachedPPOTrainer
|
||||
|
||||
from coati.trainer.strategies import ColossalAIStrategy, DDPStrategy, NaiveStrategy
|
||||
from coati.experience_maker import NaiveExperienceMaker
|
||||
from torch.optim import Adam
|
||||
from transformers import AutoTokenizer, BloomTokenizerFast
|
||||
from transformers.models.gpt2.tokenization_gpt2 import GPT2Tokenizer
|
||||
|
||||
from colossalai.nn.optimizer import HybridAdam
|
||||
|
||||
import ray
|
||||
import os
|
||||
import socket
|
||||
|
||||
|
||||
def main(args):
|
||||
# configure tokenizer
|
||||
if args.model == 'gpt2':
|
||||
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
|
||||
tokenizer.pad_token = tokenizer.eos_token
|
||||
elif args.model == 'bloom':
|
||||
tokenizer = BloomTokenizerFast.from_pretrained(args.pretrain)
|
||||
tokenizer.pad_token = tokenizer.eos_token
|
||||
elif args.model == 'opt':
|
||||
tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m")
|
||||
else:
|
||||
raise ValueError(f'Unsupported model "{args.model}"')
|
||||
|
||||
# configure Trainer
|
||||
trainer_ref = DetachedPPOTrainer.options(name="trainer1", num_gpus=1, max_concurrency=2).remote(
|
||||
experience_maker_holder_name_list=["maker1", "maker2"],
|
||||
strategy=args.trainer_strategy,
|
||||
model=args.model,
|
||||
pretrained=args.pretrain,
|
||||
lora_rank=args.lora_rank,
|
||||
train_batch_size=args.train_batch_size,
|
||||
buffer_limit=16,
|
||||
experience_batch_size=args.experience_batch_size,
|
||||
max_epochs=args.max_epochs,
|
||||
#kwargs:
|
||||
max_length=128,
|
||||
do_sample=True,
|
||||
temperature=1.0,
|
||||
top_k=50,
|
||||
pad_token_id=tokenizer.pad_token_id,
|
||||
eos_token_id=tokenizer.eos_token_id,
|
||||
debug=args.debug,
|
||||
)
|
||||
|
||||
# configure Experience Maker
|
||||
experience_holder_1_ref = ExperienceMakerHolder.options(name="maker1", num_gpus=1, max_concurrency=2).remote(
|
||||
detached_trainer_name_list=["trainer1"],
|
||||
strategy=args.maker_strategy,
|
||||
experience_batch_size=args.experience_batch_size,
|
||||
kl_coef=0.1,
|
||||
#kwargs:
|
||||
max_length=128,
|
||||
do_sample=True,
|
||||
temperature=1.0,
|
||||
top_k=50,
|
||||
pad_token_id=tokenizer.pad_token_id,
|
||||
eos_token_id=tokenizer.eos_token_id,
|
||||
debug=args.debug,
|
||||
)
|
||||
|
||||
experience_holder_2_ref = ExperienceMakerHolder.options(name="maker2", num_gpus=1, max_concurrency=2).remote(
|
||||
detached_trainer_name_list=["trainer1"],
|
||||
strategy=args.maker_strategy,
|
||||
experience_batch_size=args.experience_batch_size,
|
||||
kl_coef=0.1,
|
||||
#kwargs:
|
||||
max_length=128,
|
||||
do_sample=True,
|
||||
temperature=1.0,
|
||||
top_k=50,
|
||||
pad_token_id=tokenizer.pad_token_id,
|
||||
eos_token_id=tokenizer.eos_token_id,
|
||||
debug=args.debug,
|
||||
)
|
||||
|
||||
# trainer send its actor and critic to experience holders.
|
||||
ray.get(trainer_ref.initialize_remote_makers.remote())
|
||||
|
||||
# configure sampler
|
||||
dataset = pd.read_csv(args.prompt_path)['prompt']
|
||||
|
||||
def tokenize_fn(texts):
|
||||
# MUST padding to max length to ensure inputs of all ranks have the same length
|
||||
# Different length may lead to hang when using gemini, as different generation steps
|
||||
batch = tokenizer(texts, return_tensors='pt', max_length=96, padding='max_length', truncation=True)
|
||||
return {k: v.cuda() for k, v in batch.items()}
|
||||
|
||||
trainer_done_ref = trainer_ref.fit.remote(num_episodes=args.num_episodes, max_timesteps=args.max_timesteps, update_timesteps=args.update_timesteps)
|
||||
num_exp_per_maker = args.num_episodes * args.max_timesteps // args.update_timesteps * args.max_epochs // 2 + 3 # +3 for fault tolerance
|
||||
maker_1_done_ref = experience_holder_1_ref.workingloop.remote(dataset, tokenize_fn, times=num_exp_per_maker)
|
||||
maker_2_done_ref = experience_holder_2_ref.workingloop.remote(dataset, tokenize_fn, times=num_exp_per_maker)
|
||||
|
||||
ray.get([trainer_done_ref, maker_1_done_ref, maker_2_done_ref])
|
||||
|
||||
# save model checkpoint after fitting
|
||||
trainer_ref.strategy_save_actor.remote(args.save_path, only_rank0=True)
|
||||
# save optimizer checkpoint on all ranks
|
||||
if args.need_optim_ckpt:
|
||||
trainer_ref.strategy_save_actor_optim.remote('actor_optim_checkpoint_prompts_%d.pt' % (torch.cuda.current_device()),
|
||||
only_rank0=False)
|
||||
|
||||
if __name__ == '__main__':
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('prompt_path')
|
||||
parser.add_argument('--trainer_strategy',
|
||||
choices=['naive', 'ddp', 'colossalai_gemini', 'colossalai_zero2'],
|
||||
default='naive')
|
||||
parser.add_argument('--maker_strategy',
|
||||
choices=['naive', 'ddp', 'colossalai_gemini', 'colossalai_zero2'],
|
||||
default='naive')
|
||||
parser.add_argument('--model', default='gpt2', choices=['gpt2', 'bloom', 'opt'])
|
||||
parser.add_argument('--pretrain', type=str, default=None)
|
||||
parser.add_argument('--save_path', type=str, default='actor_checkpoint_prompts.pt')
|
||||
parser.add_argument('--need_optim_ckpt', type=bool, default=False)
|
||||
parser.add_argument('--num_episodes', type=int, default=10)
|
||||
parser.add_argument('--max_timesteps', type=int, default=10)
|
||||
parser.add_argument('--update_timesteps', type=int, default=10)
|
||||
parser.add_argument('--max_epochs', type=int, default=5)
|
||||
parser.add_argument('--train_batch_size', type=int, default=8)
|
||||
parser.add_argument('--experience_batch_size', type=int, default=8)
|
||||
parser.add_argument('--lora_rank', type=int, default=0, help="low-rank adaptation matrices rank")
|
||||
|
||||
parser.add_argument('--debug', action='store_true')
|
||||
args = parser.parse_args()
|
||||
ray.init(namespace=os.environ["RAY_NAMESPACE"])
|
||||
main(args)
|
|
@ -1,23 +0,0 @@
|
|||
set_n_least_used_CUDA_VISIBLE_DEVICES() {
|
||||
local n=${1:-"9999"}
|
||||
echo "GPU Memory Usage:"
|
||||
local FIRST_N_GPU_IDS=$(nvidia-smi --query-gpu=memory.used --format=csv \
|
||||
| tail -n +2 \
|
||||
| nl -v 0 \
|
||||
| tee /dev/tty \
|
||||
| sort -g -k 2 \
|
||||
| awk '{print $1}' \
|
||||
| head -n $n)
|
||||
export CUDA_VISIBLE_DEVICES=$(echo $FIRST_N_GPU_IDS | sed 's/ /,/g')
|
||||
echo "Now CUDA_VISIBLE_DEVICES is set to:"
|
||||
echo "CUDA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES"
|
||||
}
|
||||
|
||||
set_n_least_used_CUDA_VISIBLE_DEVICES 3
|
||||
|
||||
export RAY_NAMESPACE="admin"
|
||||
|
||||
python 2m1t.py "/path/to/prompts.csv" \
|
||||
--trainer_strategy naive --maker_strategy naive --lora_rank 2 --pretrain "facebook/opt-350m" --model 'opt' \
|
||||
--num_episodes 10 --max_timesteps 10 --update_timesteps 10 \
|
||||
--max_epochs 10 # --debug
|
|
@ -1,209 +0,0 @@
|
|||
import argparse
|
||||
from copy import deepcopy
|
||||
|
||||
import pandas as pd
|
||||
import torch
|
||||
from coati.trainer import PPOTrainer
|
||||
|
||||
|
||||
from coati.ray.src.experience_maker_holder import ExperienceMakerHolder
|
||||
from coati.ray.src.detached_trainer_ppo import DetachedPPOTrainer
|
||||
|
||||
from coati.trainer.strategies import ColossalAIStrategy, DDPStrategy, NaiveStrategy
|
||||
from coati.experience_maker import NaiveExperienceMaker
|
||||
from torch.optim import Adam
|
||||
from transformers import AutoTokenizer, BloomTokenizerFast
|
||||
from transformers.models.gpt2.tokenization_gpt2 import GPT2Tokenizer
|
||||
|
||||
from colossalai.nn.optimizer import HybridAdam
|
||||
|
||||
import ray
|
||||
import os
|
||||
import socket
|
||||
|
||||
|
||||
def get_free_port():
|
||||
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
|
||||
s.bind(('', 0))
|
||||
return s.getsockname()[1]
|
||||
|
||||
|
||||
def get_local_ip():
|
||||
with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as s:
|
||||
s.connect(('8.8.8.8', 80))
|
||||
return s.getsockname()[0]
|
||||
|
||||
def main(args):
|
||||
master_addr = str(get_local_ip())
|
||||
# trainer_env_info
|
||||
trainer_port = str(get_free_port())
|
||||
env_info_trainer_1 = {'local_rank' : '0',
|
||||
'rank' : '0',
|
||||
'world_size' : '2',
|
||||
'master_port' : trainer_port,
|
||||
'master_addr' : master_addr}
|
||||
env_info_trainer_2 = {'local_rank' : '0',
|
||||
'rank' : '1',
|
||||
'world_size' : '2',
|
||||
'master_port' : trainer_port,
|
||||
'master_addr' : master_addr}
|
||||
# maker_env_info
|
||||
maker_port = str(get_free_port())
|
||||
env_info_maker_1 = {'local_rank' : '0',
|
||||
'rank' : '0',
|
||||
'world_size' : '2',
|
||||
'master_port' : maker_port,
|
||||
'master_addr' : master_addr}
|
||||
env_info_maker_2 = {'local_rank' : '0',
|
||||
'rank' : '1',
|
||||
'world_size' : '2',
|
||||
'master_port': maker_port,
|
||||
'master_addr' : master_addr}
|
||||
print([env_info_trainer_1,
|
||||
env_info_trainer_2,
|
||||
env_info_maker_1,
|
||||
env_info_maker_2])
|
||||
ray.init()
|
||||
# configure tokenizer
|
||||
if args.model == 'gpt2':
|
||||
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
|
||||
tokenizer.pad_token = tokenizer.eos_token
|
||||
elif args.model == 'bloom':
|
||||
tokenizer = BloomTokenizerFast.from_pretrained(args.pretrain)
|
||||
tokenizer.pad_token = tokenizer.eos_token
|
||||
elif args.model == 'opt':
|
||||
tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m")
|
||||
else:
|
||||
raise ValueError(f'Unsupported model "{args.model}"')
|
||||
|
||||
# configure Trainer
|
||||
trainer_1_ref = DetachedPPOTrainer.options(name="trainer1", namespace=os.environ["RAY_NAMESPACE"], num_gpus=1, max_concurrency=2).remote(
|
||||
experience_maker_holder_name_list=["maker1", "maker2"],
|
||||
strategy=args.trainer_strategy,
|
||||
model=args.model,
|
||||
env_info=env_info_trainer_1,
|
||||
pretrained=args.pretrain,
|
||||
lora_rank=args.lora_rank,
|
||||
train_batch_size=args.train_batch_size,
|
||||
buffer_limit=16,
|
||||
experience_batch_size=args.experience_batch_size,
|
||||
max_epochs=args.max_epochs,
|
||||
#kwargs:
|
||||
max_length=128,
|
||||
do_sample=True,
|
||||
temperature=1.0,
|
||||
top_k=50,
|
||||
pad_token_id=tokenizer.pad_token_id,
|
||||
eos_token_id=tokenizer.eos_token_id,
|
||||
debug=args.debug,
|
||||
)
|
||||
|
||||
trainer_2_ref = DetachedPPOTrainer.options(name="trainer2", namespace=os.environ["RAY_NAMESPACE"], num_gpus=1, max_concurrency=2).remote(
|
||||
experience_maker_holder_name_list=["maker1", "maker2"],
|
||||
strategy=args.trainer_strategy,
|
||||
model=args.model,
|
||||
env_info=env_info_trainer_2,
|
||||
pretrained=args.pretrain,
|
||||
lora_rank=args.lora_rank,
|
||||
train_batch_size=args.train_batch_size,
|
||||
buffer_limit=16,
|
||||
experience_batch_size=args.experience_batch_size,
|
||||
max_epochs=args.max_epochs,
|
||||
#kwargs:
|
||||
max_length=128,
|
||||
do_sample=True,
|
||||
temperature=1.0,
|
||||
top_k=50,
|
||||
pad_token_id=tokenizer.pad_token_id,
|
||||
eos_token_id=tokenizer.eos_token_id,
|
||||
debug=args.debug,
|
||||
)
|
||||
|
||||
# configure Experience Maker
|
||||
experience_holder_1_ref = ExperienceMakerHolder.options(name="maker1", namespace=os.environ["RAY_NAMESPACE"], num_gpus=1, max_concurrency=2).remote(
|
||||
detached_trainer_name_list=["trainer1", "trainer2"],
|
||||
strategy=args.maker_strategy,
|
||||
env_info=env_info_maker_1,
|
||||
experience_batch_size=args.experience_batch_size,
|
||||
kl_coef=0.1,
|
||||
#kwargs:
|
||||
max_length=128,
|
||||
do_sample=True,
|
||||
temperature=1.0,
|
||||
top_k=50,
|
||||
pad_token_id=tokenizer.pad_token_id,
|
||||
eos_token_id=tokenizer.eos_token_id,
|
||||
debug=args.debug,
|
||||
)
|
||||
|
||||
experience_holder_2_ref = ExperienceMakerHolder.options(name="maker2", namespace=os.environ["RAY_NAMESPACE"], num_gpus=1, max_concurrency=2).remote(
|
||||
detached_trainer_name_list=["trainer1", "trainer2"],
|
||||
strategy=args.maker_strategy,
|
||||
env_info=env_info_maker_2,
|
||||
experience_batch_size=args.experience_batch_size,
|
||||
kl_coef=0.1,
|
||||
#kwargs:
|
||||
max_length=128,
|
||||
do_sample=True,
|
||||
temperature=1.0,
|
||||
top_k=50,
|
||||
pad_token_id=tokenizer.pad_token_id,
|
||||
eos_token_id=tokenizer.eos_token_id,
|
||||
debug=args.debug,
|
||||
)
|
||||
|
||||
# trainer send its actor and critic to experience holders.
|
||||
# TODO: balance duty
|
||||
ray.get(trainer_1_ref.initialize_remote_makers.remote())
|
||||
|
||||
# configure sampler
|
||||
dataset = pd.read_csv(args.prompt_path)['prompt']
|
||||
|
||||
def tokenize_fn(texts):
|
||||
# MUST padding to max length to ensure inputs of all ranks have the same length
|
||||
# Different length may lead to hang when using gemini, as different generation steps
|
||||
batch = tokenizer(texts, return_tensors='pt', max_length=96, padding='max_length', truncation=True)
|
||||
return {k: v.cuda() for k, v in batch.items()}
|
||||
|
||||
trainer_1_done_ref = trainer_1_ref.fit.remote(num_episodes=args.num_episodes, max_timesteps=args.max_timesteps, update_timesteps=args.update_timesteps)
|
||||
trainer_2_done_ref = trainer_2_ref.fit.remote(num_episodes=args.num_episodes, max_timesteps=args.max_timesteps, update_timesteps=args.update_timesteps)
|
||||
num_exp_per_maker = args.num_episodes * args.max_timesteps // args.update_timesteps * args.max_epochs + 3 # +3 for fault tolerance
|
||||
maker_1_done_ref = experience_holder_1_ref.workingloop.remote(dataset, tokenize_fn, times=num_exp_per_maker)
|
||||
maker_2_done_ref = experience_holder_2_ref.workingloop.remote(dataset, tokenize_fn, times=num_exp_per_maker)
|
||||
|
||||
ray.get([trainer_1_done_ref, trainer_2_done_ref, maker_1_done_ref, maker_2_done_ref])
|
||||
# save model checkpoint after fitting
|
||||
trainer_1_ref.strategy_save_actor.remote(args.save_path, only_rank0=True)
|
||||
trainer_2_ref.strategy_save_actor.remote(args.save_path, only_rank0=True)
|
||||
# save optimizer checkpoint on all ranks
|
||||
if args.need_optim_ckpt:
|
||||
trainer_1_ref.strategy_save_actor_optim.remote('actor_optim_checkpoint_prompts_%d.pt' % (torch.cuda.current_device()),
|
||||
only_rank0=False)
|
||||
trainer_2_ref.strategy_save_actor_optim.remote('actor_optim_checkpoint_prompts_%d.pt' % (torch.cuda.current_device()),
|
||||
only_rank0=False)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('prompt_path')
|
||||
parser.add_argument('--trainer_strategy',
|
||||
choices=['naive', 'ddp', 'colossalai_gemini', 'colossalai_zero2'],
|
||||
default='naive')
|
||||
parser.add_argument('--maker_strategy',
|
||||
choices=['naive', 'ddp', 'colossalai_gemini', 'colossalai_zero2'],
|
||||
default='naive')
|
||||
parser.add_argument('--model', default='gpt2', choices=['gpt2', 'bloom', 'opt'])
|
||||
parser.add_argument('--pretrain', type=str, default=None)
|
||||
parser.add_argument('--save_path', type=str, default='actor_checkpoint_prompts.pt')
|
||||
parser.add_argument('--need_optim_ckpt', type=bool, default=False)
|
||||
parser.add_argument('--num_episodes', type=int, default=10)
|
||||
parser.add_argument('--max_timesteps', type=int, default=10)
|
||||
parser.add_argument('--update_timesteps', type=int, default=10)
|
||||
parser.add_argument('--max_epochs', type=int, default=5)
|
||||
parser.add_argument('--train_batch_size', type=int, default=8)
|
||||
parser.add_argument('--experience_batch_size', type=int, default=8)
|
||||
parser.add_argument('--lora_rank', type=int, default=0, help="low-rank adaptation matrices rank")
|
||||
|
||||
parser.add_argument('--debug', action='store_true')
|
||||
args = parser.parse_args()
|
||||
main(args)
|
|
@ -1,23 +0,0 @@
|
|||
set_n_least_used_CUDA_VISIBLE_DEVICES() {
|
||||
local n=${1:-"9999"}
|
||||
echo "GPU Memory Usage:"
|
||||
local FIRST_N_GPU_IDS=$(nvidia-smi --query-gpu=memory.used --format=csv \
|
||||
| tail -n +2 \
|
||||
| nl -v 0 \
|
||||
| tee /dev/tty \
|
||||
| sort -g -k 2 \
|
||||
| awk '{print $1}' \
|
||||
| head -n $n)
|
||||
export CUDA_VISIBLE_DEVICES=$(echo $FIRST_N_GPU_IDS | sed 's/ /,/g')
|
||||
echo "Now CUDA_VISIBLE_DEVICES is set to:"
|
||||
echo "CUDA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES"
|
||||
}
|
||||
|
||||
set_n_least_used_CUDA_VISIBLE_DEVICES 2
|
||||
|
||||
export RAY_NAMESPACE="admin"
|
||||
|
||||
python 2m2t.py "path/to/prompts.csv" \
|
||||
--maker_strategy naive --trainer_strategy colossalai_zero2 --lora_rank 2 \
|
||||
--num_episodes 10 --max_timesteps 10 --update_timesteps 10 \
|
||||
--max_epochs 10 --debug
|
|
@ -0,0 +1,271 @@
|
|||
import os
|
||||
import time
|
||||
import tracemalloc
|
||||
from copy import deepcopy
|
||||
from threading import Lock
|
||||
from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union
|
||||
|
||||
import ray
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from coati.experience_maker import Experience, ExperienceMaker, NaiveExperienceMaker
|
||||
from coati.models.base import Actor, Critic, RewardModel
|
||||
from coati.replay_buffer.utils import BufferItem, make_experience_batch, split_experience_batch
|
||||
from coati.trainer.callbacks import Callback
|
||||
from coati.trainer.strategies import Strategy
|
||||
from coati.trainer.strategies.sampler import DistributedSampler
|
||||
from ray.exceptions import GetTimeoutError
|
||||
from torch import Tensor
|
||||
from tqdm import tqdm
|
||||
|
||||
from .callbacks import ExperienceMakerPerformanceEvaluator, MakerCallback
|
||||
from .utils import (get_model_numel,
|
||||
get_rank,
|
||||
get_world_size,
|
||||
is_rank_0,
|
||||
set_dist_env,
|
||||
state_dict_to)
|
||||
from .lora_constructor import LoRAConstructor
|
||||
|
||||
@ray.remote(concurrency_groups={"experience_io": 1, "model_io": 1, "compute": 1})
|
||||
class ExperienceMakerHolder:
|
||||
'''
|
||||
Args:
|
||||
detached_trainer_name_list: str list to get ray actor handles
|
||||
strategy:
|
||||
kl_coef: the coefficient of kl divergence loss
|
||||
sync_models_from_trainers: whether to sync models from trainers. If True, you must call sync_models_to_remote_makers() in trainers to sync models.
|
||||
'''
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
detached_trainer_name_list: List[str],
|
||||
strategy_fn: Callable[[], Strategy],
|
||||
# a function returns (actor, critic, reward_model, initial_model)
|
||||
model_fn: Callable[[], Tuple[Actor, Critic, RewardModel, Actor]],
|
||||
env_info: Dict[str, str] = None,
|
||||
sync_models_from_trainers: bool = False,
|
||||
buffer_cpu_offload: bool = True,
|
||||
kl_coef: float = 0.1,
|
||||
callbacks: List[MakerCallback] = [],
|
||||
eval_performance: bool = False,
|
||||
debug: bool = False,
|
||||
update_lora_weights: bool = False,
|
||||
**generate_kwargs):
|
||||
# set environment variables
|
||||
if env_info:
|
||||
set_dist_env(env_info=env_info)
|
||||
self.target_trainer_list = []
|
||||
assert len(detached_trainer_name_list) > 0
|
||||
self._detached_trainer_name_list = detached_trainer_name_list
|
||||
self.strategy = strategy_fn()
|
||||
self.buffer_cpu_offload = buffer_cpu_offload
|
||||
self.kl_coef = kl_coef
|
||||
# init models
|
||||
with self.strategy.model_init_context():
|
||||
actor, critic, reward_model, initial_model = model_fn()
|
||||
self.generate_kwargs = _set_default_generate_kwargs(generate_kwargs, actor)
|
||||
if eval_performance:
|
||||
actor_numel = get_model_numel(actor)
|
||||
critic_numel = get_model_numel(critic)
|
||||
initial_model_numel = get_model_numel(initial_model)
|
||||
reward_model_numel = get_model_numel(reward_model)
|
||||
evaluator = ExperienceMakerPerformanceEvaluator(actor_numel, critic_numel, initial_model_numel,
|
||||
reward_model_numel)
|
||||
callbacks = callbacks + [evaluator]
|
||||
|
||||
actor, critic, reward_model, initial_model = self.strategy.prepare(actor, critic, reward_model, initial_model)
|
||||
self.experience_maker = NaiveExperienceMaker(actor, critic, reward_model, initial_model, self.kl_coef)
|
||||
self.callbacks = callbacks
|
||||
|
||||
self._model_visit_lock = Lock()
|
||||
|
||||
self._is_fully_initialized = not sync_models_from_trainers
|
||||
|
||||
self._debug = debug
|
||||
self._update_lora_weights = update_lora_weights
|
||||
if self._update_lora_weights:
|
||||
self.actor_lora_constructor = LoRAConstructor()
|
||||
self.critic_lora_constructor = LoRAConstructor()
|
||||
|
||||
self.target_auto_balance = False
|
||||
|
||||
self._target_idx = 0
|
||||
|
||||
if self._debug:
|
||||
print(f'[maker{get_rank()}] will send items to {self._detached_trainer_name_list}')
|
||||
if not self._is_fully_initialized:
|
||||
print(f'[maker{get_rank()}] Waiting for INIT')
|
||||
|
||||
def _get_ready(self):
|
||||
while not self._fully_initialized():
|
||||
time.sleep(1.0)
|
||||
|
||||
def _fully_initialized(self):
|
||||
return self._is_fully_initialized
|
||||
|
||||
def _init_target_trainer_list(self):
|
||||
if len(self.target_trainer_list) > 0:
|
||||
return
|
||||
for name in self._detached_trainer_name_list:
|
||||
self.target_trainer_list.append(ray.get_actor(name, namespace=os.environ["RAY_NAMESPACE"]))
|
||||
|
||||
# copy from ../trainer/base.py
|
||||
@ray.method(concurrency_group="compute")
|
||||
def _make_experience(self, inputs: Union[Tensor, Dict[str, Tensor]]) -> Experience:
|
||||
if isinstance(inputs, Tensor):
|
||||
return self.experience_maker.make_experience(inputs, **self.generate_kwargs)
|
||||
elif isinstance(inputs, dict):
|
||||
return self.experience_maker.make_experience(**inputs, **self.generate_kwargs)
|
||||
else:
|
||||
raise ValueError(f'Unsupported input type "{type(inputs)}"')
|
||||
|
||||
@ray.method(concurrency_group="experience_io")
|
||||
def _send_items(self, experience: Experience) -> None:
|
||||
self._init_target_trainer_list()
|
||||
items = split_experience_batch(experience)
|
||||
items_per_trainer = [[] for _ in range(len(self.target_trainer_list))]
|
||||
for item in items:
|
||||
items_per_trainer[self._target_idx].append(item)
|
||||
self._target_idx = (self._target_idx + 1) % len(self.target_trainer_list)
|
||||
for i, target_trainer in enumerate(self.target_trainer_list):
|
||||
if len(items_per_trainer[i]) > 0:
|
||||
target_trainer.buffer_extend.remote(items_per_trainer[i])
|
||||
|
||||
def _inference_step(self, batch) -> None:
|
||||
self._on_batch_start()
|
||||
with self._model_visit_lock:
|
||||
self._on_make_experience_start()
|
||||
experience = self._make_experience(batch)
|
||||
self._on_make_experience_end(experience)
|
||||
self._on_send_start()
|
||||
if self.buffer_cpu_offload:
|
||||
experience.to_device('cpu')
|
||||
self._send_items(experience)
|
||||
self._on_send_end()
|
||||
self._on_batch_end()
|
||||
|
||||
def workingloop(self, dataloader_fn: Callable[[], Iterable], num_epochs: int = 1, num_steps: int = 0):
|
||||
"""Working loop of the experience maker.
|
||||
|
||||
Args:
|
||||
dataloader_fn (Callable[[], Iterable]): A function that returns a dataloader.
|
||||
num_epochs (int, optional): Iterate the dataloader for number of epochs. Defaults to 1.
|
||||
num_steps (int, optional): Iterate the dataloader for number if steps. If this value > 0, num_epochs will be ignored. Defaults to 0.
|
||||
"""
|
||||
self._get_ready()
|
||||
self._on_loop_start()
|
||||
dataloader = dataloader_fn()
|
||||
if num_steps > 0:
|
||||
# ignore num epochs
|
||||
it = iter(dataloader)
|
||||
for _ in tqdm(range(num_steps), desc='ExperienceMaker', disable=not is_rank_0()):
|
||||
try:
|
||||
batch = next(it)
|
||||
except StopIteration:
|
||||
it = iter(dataloader)
|
||||
batch = next(it)
|
||||
self._inference_step(batch)
|
||||
else:
|
||||
with tqdm(total=num_epochs * len(dataloader), desc='ExperienceMaker', disable=not is_rank_0()) as pbar:
|
||||
for _ in range(num_epochs):
|
||||
for batch in dataloader:
|
||||
self._inference_step(batch)
|
||||
pbar.update()
|
||||
self._on_loop_end()
|
||||
|
||||
@ray.method(concurrency_group="model_io")
|
||||
def update_experience_maker(self,
|
||||
new_actor_state_dict: Dict[str, Any] = None,
|
||||
new_actor_lora_config_dict: Dict[str, Any] = None,
|
||||
new_critic_state_dict: Dict[str, Any] = None,
|
||||
new_critic_lora_config_dict: Dict[str, Any] = None,
|
||||
fully_update: bool = False,
|
||||
chunk_start: bool = None,
|
||||
chunk_end: bool = None):
|
||||
'''
|
||||
called by trainer
|
||||
chunk_start: Set True at the first call. Before sending state_dict calls
|
||||
chunk_end: Set True at the last call. After sending state_dict calls.
|
||||
fully_update: Set True if you want to sync models when initializing
|
||||
|
||||
TODO: load_state_dict integrate with model-sharding strategy
|
||||
'''
|
||||
_watch_memory = self._debug
|
||||
if chunk_start:
|
||||
if self._debug:
|
||||
print("[maker] UPDATE ")
|
||||
if _watch_memory:
|
||||
tracemalloc.start()
|
||||
self._model_visit_lock.acquire()
|
||||
|
||||
with torch.no_grad():
|
||||
if new_actor_state_dict is not None:
|
||||
if not self._update_lora_weights or fully_update:
|
||||
self.experience_maker.actor.model.load_state_dict(new_actor_state_dict, strict=False)
|
||||
else:
|
||||
new_actor_state_dict = state_dict_to(new_actor_state_dict, device=torch.cuda.current_device())
|
||||
state_dict_increasae = self.actor_lora_constructor.reconstruct_increase(new_actor_state_dict, new_actor_lora_config_dict)
|
||||
self.actor_lora_constructor.load_state_dict_increase(self.experience_maker.actor.model, state_dict_increasae)
|
||||
if new_critic_state_dict is not None:
|
||||
if not self._update_lora_weights or fully_update:
|
||||
self.experience_maker.critic.load_state_dict(new_critic_state_dict, strict=False)
|
||||
else:
|
||||
new_critic_state_dict = state_dict_to(new_critic_state_dict, device=torch.cuda.current_device())
|
||||
state_dict_increasae = self.critic_lora_constructor.reconstruct_increase(new_critic_state_dict, new_critic_lora_config_dict)
|
||||
self.critic_lora_constructor.load_state_dict_increase(self.experience_maker.critic, state_dict_increasae)
|
||||
|
||||
# the lock must be released after both actor and critic being updated
|
||||
if chunk_end:
|
||||
self._model_visit_lock.release()
|
||||
if _watch_memory:
|
||||
current, peak = tracemalloc.get_traced_memory()
|
||||
print(f"Current memory usage is {current / 10**6}MB; Peak was {peak / 10**6}MB")
|
||||
tracemalloc.stop()
|
||||
if fully_update:
|
||||
self._is_fully_initialized = True
|
||||
|
||||
def _on_make_experience_start(self) -> None:
|
||||
for callback in self.callbacks:
|
||||
callback.on_make_experience_start()
|
||||
|
||||
def _on_make_experience_end(self, experience: Experience) -> None:
|
||||
for callback in self.callbacks:
|
||||
callback.on_make_experience_end(experience)
|
||||
|
||||
def _on_loop_start(self) -> None:
|
||||
for callback in self.callbacks:
|
||||
callback.on_loop_start()
|
||||
|
||||
def _on_loop_end(self) -> None:
|
||||
for callback in self.callbacks:
|
||||
callback.on_loop_end()
|
||||
|
||||
def _on_send_start(self) -> None:
|
||||
for callback in self.callbacks:
|
||||
callback.on_send_start()
|
||||
|
||||
def _on_send_end(self) -> None:
|
||||
for callback in self.callbacks:
|
||||
callback.on_send_end()
|
||||
|
||||
def _on_batch_start(self) -> None:
|
||||
for callback in self.callbacks:
|
||||
callback.on_batch_start()
|
||||
|
||||
def _on_batch_end(self) -> None:
|
||||
for callback in self.callbacks:
|
||||
callback.on_batch_end()
|
||||
|
||||
|
||||
def _set_default_generate_kwargs(generate_kwargs: dict, actor: Actor) -> None:
|
||||
origin_model = actor.model
|
||||
new_kwargs = {**generate_kwargs}
|
||||
# use huggingface models method directly
|
||||
if 'prepare_inputs_fn' not in generate_kwargs and hasattr(origin_model, 'prepare_inputs_for_generation'):
|
||||
new_kwargs['prepare_inputs_fn'] = origin_model.prepare_inputs_for_generation
|
||||
|
||||
if 'update_model_kwargs_fn' not in generate_kwargs and hasattr(origin_model, '_update_model_kwargs_for_generation'):
|
||||
new_kwargs['update_model_kwargs_fn'] = origin_model._update_model_kwargs_for_generation
|
||||
|
||||
return new_kwargs
|
|
@ -0,0 +1,122 @@
|
|||
from typing import Any, Callable, Dict, List, Optional
|
||||
from collections import OrderedDict
|
||||
from dataclasses import dataclass
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from loralib.layers import LoRALayer
|
||||
from coati.models.lora import LoraLinear
|
||||
|
||||
|
||||
@dataclass
|
||||
class LoRAConfig:
|
||||
r: int = 0
|
||||
lora_alpha: int = 1
|
||||
lora_dropout: float = 0
|
||||
fan_in_fan_out: bool = False
|
||||
|
||||
|
||||
class LoRAConstructor:
|
||||
'''
|
||||
Tools for reconstructing a model from a remote LoRA model.
|
||||
(Transfering only LoRA data costs much less!)
|
||||
Usage:
|
||||
Step 1 (Sender):
|
||||
filter_state_dict_lora()
|
||||
|
||||
Step 2 (Sender, Optional):
|
||||
extract_lora_config()
|
||||
|
||||
Step 3 (Sender):
|
||||
send state_dict_lora and lora_config_dict
|
||||
|
||||
Step 4 (Receiver):
|
||||
reconstruct_increase()
|
||||
|
||||
Step 5 (Receiver):
|
||||
load_state_dict_increase()
|
||||
|
||||
'''
|
||||
|
||||
def __init__(self):
|
||||
self.lora_config_dict = None
|
||||
|
||||
def register_lora_config(self, lora_config_dict: Dict[str, Any]):
|
||||
self.lora_config_dict = lora_config_dict
|
||||
|
||||
def reconstruct_increase(self, state_dict_lora: Dict[str, Any], lora_config_dict: Dict[str, Any]):
|
||||
'''
|
||||
xxx.lora_A, xxx.lora_B -->> xxx.weight
|
||||
Warning: the xxx.weight here is the increment actually.
|
||||
'''
|
||||
if lora_config_dict is not None:
|
||||
self.register_lora_config(lora_config_dict)
|
||||
|
||||
state_dict_increasae = OrderedDict()
|
||||
config_iter = iter(self.lora_config_dict.items())
|
||||
lora_A, lora_B, layer_prefix = None, None, None
|
||||
for k, v in state_dict_lora.items():
|
||||
if k.rpartition('.')[-1] == 'lora_A':
|
||||
lora_A = v
|
||||
layer_prefix = k.rpartition('.')[0]
|
||||
elif k.rpartition('.')[-1] == 'lora_B':
|
||||
assert layer_prefix == k.rpartition('.')[0], "unmatched (lora_A, lora_B) pair"
|
||||
layer_prefix_2, config = next(config_iter)
|
||||
assert layer_prefix_2 == layer_prefix, "unmatched (state_dict, config_dict) pair"
|
||||
lora_B = v
|
||||
weight_data_increase = self._compute(lora_A, lora_B, config)
|
||||
state_dict_increasae[layer_prefix + '.weight'] = weight_data_increase
|
||||
lora_A, lora_B, layer_prefix = None, None, None
|
||||
else:
|
||||
raise ValueError('unexpected key')
|
||||
return state_dict_increasae
|
||||
|
||||
def _compute(self, lora_A, lora_B, config=LoRAConfig()):
|
||||
def T(w):
|
||||
return w.T if config.fan_in_fan_out else w
|
||||
if config.r > 0:
|
||||
scaling = config.lora_alpha / config.r
|
||||
weight_data_increase = T(lora_B @ lora_A) * scaling
|
||||
return weight_data_increase
|
||||
return 0
|
||||
|
||||
def load_state_dict_increase(self, model: nn.Module, state_dict_increasae: Dict[str, Any]):
|
||||
'''
|
||||
The final reconstruction step
|
||||
'''
|
||||
# naive approach
|
||||
model.load_state_dict({k: v + model.state_dict()[k] for k, v in state_dict_increasae.items()}, strict=False)
|
||||
|
||||
@staticmethod
|
||||
def filter_state_dict_lora(state_dict: Dict[str, Any], keep_non_lora=False):
|
||||
'''
|
||||
if keep_non_lora, also return non_lora state_dict
|
||||
'''
|
||||
state_dict_lora = OrderedDict()
|
||||
state_dict_non_lora = OrderedDict()
|
||||
for k, v in state_dict.items():
|
||||
if 'lora_A' in k or 'lora_B' in k:
|
||||
state_dict_lora[k] = v
|
||||
elif keep_non_lora:
|
||||
state_dict_non_lora[k] = v
|
||||
if keep_non_lora:
|
||||
return state_dict_lora, state_dict_non_lora
|
||||
else:
|
||||
return state_dict_lora, None
|
||||
|
||||
@staticmethod
|
||||
def extract_lora_config(model: nn.Module) -> Dict[str, LoRAConfig]:
|
||||
'''
|
||||
extract LoraLinear model.
|
||||
return OrderedDict(): name -> LoRAConfig
|
||||
'''
|
||||
lora_config_dict = OrderedDict()
|
||||
|
||||
for name, child in model.named_modules():
|
||||
if isinstance(child, LoraLinear):
|
||||
lora_config_dict[name] = LoRAConfig(r=child.r,
|
||||
lora_alpha=child.lora_alpha,
|
||||
lora_dropout=child.lora_dropout,
|
||||
fan_in_fan_out=child.fan_in_fan_out)
|
||||
|
||||
return lora_config_dict
|
|
@ -1,121 +0,0 @@
|
|||
from abc import ABC, abstractmethod
|
||||
from typing import Any, Callable, Dict, List, Optional, Union
|
||||
from tqdm import tqdm
|
||||
from coati.trainer.callbacks import Callback
|
||||
from coati.experience_maker import Experience
|
||||
import ray
|
||||
import os
|
||||
|
||||
from .detached_replay_buffer import DetachedReplayBuffer
|
||||
from .utils import is_rank_0
|
||||
|
||||
class DetachedTrainer(ABC):
|
||||
'''
|
||||
Base class for detached rlhf trainers.
|
||||
'detach' means that the experience maker is detached compared to a normal Trainer.
|
||||
Please set name attribute during init:
|
||||
>>> trainer = DetachedTrainer.options(..., name = "xxx", ...).remote()
|
||||
So an ExperienceMakerHolder can reach the detached_replay_buffer by Actor's name.
|
||||
Args:
|
||||
detached_strategy (DetachedStrategy): the strategy to use for training
|
||||
detached_replay_buffer_ref (ObjectRef[DetachedReplayBuffer]): the replay buffer to use for training
|
||||
experience_batch_size (int, defaults to 8): the batch size to use for experience generation
|
||||
max_epochs (int, defaults to 1): the number of epochs of training process
|
||||
data_loader_pin_memory (bool, defaults to True): whether to pin memory for data loader
|
||||
callbacks (List[Callback], defaults to []): the callbacks to call during training process
|
||||
generate_kwargs (dict, optional): the kwargs to use while model generating
|
||||
'''
|
||||
|
||||
def __init__(self,
|
||||
experience_maker_holder_name_list: List[str],
|
||||
train_batch_size: int = 8,
|
||||
buffer_limit: int = 0,
|
||||
buffer_cpu_offload: bool = True,
|
||||
experience_batch_size: int = 8,
|
||||
max_epochs: int = 1,
|
||||
dataloader_pin_memory: bool = True,
|
||||
callbacks: List[Callback] = [],
|
||||
**generate_kwargs) -> None:
|
||||
super().__init__()
|
||||
self.detached_replay_buffer = DetachedReplayBuffer(train_batch_size, limit=buffer_limit, cpu_offload=buffer_cpu_offload)
|
||||
self.experience_batch_size = experience_batch_size
|
||||
self.max_epochs = max_epochs
|
||||
self.dataloader_pin_memory = dataloader_pin_memory
|
||||
self.callbacks = callbacks
|
||||
self.generate_kwargs = generate_kwargs
|
||||
self.target_holder_name_list = experience_maker_holder_name_list
|
||||
self.target_holder_list = []
|
||||
|
||||
def update_target_holder_list(self, experience_maker_holder_name_list):
|
||||
self.target_holder_name_list = experience_maker_holder_name_list
|
||||
self.target_holder_list = []
|
||||
for name in self.target_holder_name_list:
|
||||
self.target_holder_list.append(ray.get_actor(name, namespace=os.environ["RAY_NAMESPACE"]))
|
||||
|
||||
@abstractmethod
|
||||
def _update_remote_makers(self):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def training_step(self, experience: Experience) -> Dict[str, Any]:
|
||||
pass
|
||||
|
||||
def _learn(self):
|
||||
pbar = tqdm(range(self.max_epochs), desc='Train epoch', disable=not is_rank_0())
|
||||
for _ in pbar:
|
||||
if 'debug' in self.generate_kwargs and self.generate_kwargs['debug'] == True:
|
||||
print("[trainer] sampling exp")
|
||||
experience = self._buffer_sample()
|
||||
if 'debug' in self.generate_kwargs and self.generate_kwargs['debug'] == True:
|
||||
print("[trainer] training step")
|
||||
metrics = self.training_step(experience)
|
||||
if 'debug' in self.generate_kwargs and self.generate_kwargs['debug'] == True:
|
||||
print("[trainer] step over")
|
||||
pbar.set_postfix(metrics)
|
||||
|
||||
def fit(self, num_episodes: int = 50000, max_timesteps: int = 500, update_timesteps: int = 5000) -> None:
|
||||
self._on_fit_start()
|
||||
for episode in range(num_episodes):
|
||||
self._on_episode_start(episode)
|
||||
for timestep in tqdm(range(max_timesteps // update_timesteps),
|
||||
desc=f'Episode [{episode+1}/{num_episodes}]',
|
||||
disable=not is_rank_0()):
|
||||
self._learn()
|
||||
self._update_remote_makers()
|
||||
self._on_episode_end(episode)
|
||||
self._on_fit_end()
|
||||
|
||||
@ray.method(concurrency_group="buffer_length")
|
||||
def buffer_get_length(self):
|
||||
# called by ExperienceMakerHolder
|
||||
if 'debug' in self.generate_kwargs and self.generate_kwargs['debug'] == True:
|
||||
print("[trainer] telling length")
|
||||
return self.detached_replay_buffer.get_length()
|
||||
|
||||
@ray.method(concurrency_group="buffer_append")
|
||||
def buffer_append(self, experience: Experience):
|
||||
# called by ExperienceMakerHolder
|
||||
if 'debug' in self.generate_kwargs and self.generate_kwargs['debug'] == True:
|
||||
# print(f"[trainer] receiving exp. Current buffer length: {self.detached_replay_buffer.get_length()}")
|
||||
print(f"[trainer] receiving exp.")
|
||||
self.detached_replay_buffer.append(experience)
|
||||
|
||||
@ray.method(concurrency_group="buffer_sample")
|
||||
def _buffer_sample(self):
|
||||
return self.detached_replay_buffer.sample()
|
||||
|
||||
def _on_fit_start(self) -> None:
|
||||
for callback in self.callbacks:
|
||||
callback.on_fit_start()
|
||||
|
||||
def _on_fit_end(self) -> None:
|
||||
for callback in self.callbacks:
|
||||
callback.on_fit_end()
|
||||
|
||||
def _on_episode_start(self, episode: int) -> None:
|
||||
for callback in self.callbacks:
|
||||
callback.on_episode_start(episode)
|
||||
|
||||
def _on_episode_end(self, episode: int) -> None:
|
||||
for callback in self.callbacks:
|
||||
callback.on_episode_end(episode)
|
|
@ -1,172 +0,0 @@
|
|||
import torch
|
||||
from typing import Any, Callable, Dict, List, Optional, Union
|
||||
import ray
|
||||
from ray.exceptions import GetTimeoutError
|
||||
from torch import Tensor
|
||||
import torch.nn as nn
|
||||
from coati.models.base import Actor, Critic, RewardModel
|
||||
from coati.trainer.strategies.sampler import DistributedSampler
|
||||
from coati.trainer.strategies import Strategy
|
||||
from coati.experience_maker import NaiveExperienceMaker, Experience, ExperienceMaker
|
||||
|
||||
from copy import deepcopy
|
||||
from threading import Lock
|
||||
import time
|
||||
import os
|
||||
|
||||
|
||||
from .utils import is_rank_0, get_strategy_from_args, set_dist_env
|
||||
|
||||
|
||||
@ray.remote(concurrency_groups={"experience_io": 1, "model_io": 1, "compute": 1})
|
||||
class ExperienceMakerHolder:
|
||||
'''
|
||||
Args:
|
||||
detached_trainer_name_list: str list to get ray actor handles
|
||||
strategy:
|
||||
experience_batch_size: batch size of generated experience
|
||||
kl_coef: the coefficient of kl divergence loss
|
||||
'''
|
||||
|
||||
def __init__(self,
|
||||
detached_trainer_name_list: List[str],
|
||||
strategy: str,
|
||||
env_info: Dict[str, str] = None,
|
||||
experience_batch_size: int = 8,
|
||||
kl_coef: float = 0.1,
|
||||
**generate_kwargs):
|
||||
# set environment variables
|
||||
if env_info:
|
||||
set_dist_env(env_info=env_info)
|
||||
self.target_trainer_list = []
|
||||
for name in detached_trainer_name_list:
|
||||
self.target_trainer_list.append(ray.get_actor(name, namespace=os.environ["RAY_NAMESPACE"]))
|
||||
self.strategy_str = strategy
|
||||
self.strategy = get_strategy_from_args(strategy)
|
||||
self.experience_batch_size = experience_batch_size
|
||||
self.kl_coef = kl_coef
|
||||
self.generate_kwargs = generate_kwargs
|
||||
# Need a trainer to give an actor and a critic via initialize_experience_maker(...)
|
||||
actor, critic, reward_model, initial_model = None, None, None, None
|
||||
self.experience_maker = NaiveExperienceMaker(actor, critic, reward_model, initial_model, self.kl_coef)
|
||||
self._model_visit_lock = Lock()
|
||||
self.fully_initialized = False
|
||||
if 'debug' in self.generate_kwargs and self.generate_kwargs['debug'] == True:
|
||||
print('[maker] Waiting for INIT')
|
||||
|
||||
def _get_ready(self):
|
||||
while not self.fully_initialized:
|
||||
time.sleep(1.0)
|
||||
|
||||
def update_target_trainer_list(self, detached_trainer_name_list):
|
||||
self.target_trainer_list = []
|
||||
for name in detached_trainer_name_list:
|
||||
self.target_trainer_list.append(ray.get_actor(name))
|
||||
|
||||
# copy from ../trainer/base.py
|
||||
@ray.method(concurrency_group="compute")
|
||||
def _make_experience(self, inputs: Union[Tensor, Dict[str, Tensor]]) -> Experience:
|
||||
self._get_ready()
|
||||
if isinstance(inputs, Tensor):
|
||||
return self.experience_maker.make_experience(inputs, **self.generate_kwargs)
|
||||
elif isinstance(inputs, dict):
|
||||
return self.experience_maker.make_experience(**inputs, **self.generate_kwargs)
|
||||
else:
|
||||
raise ValueError(f'Unsupported input type "{type(inputs)}"')
|
||||
|
||||
@ray.method(concurrency_group="experience_io")
|
||||
def _send_experience(self, experience):
|
||||
'''
|
||||
ignore it
|
||||
|
||||
# choose a trainer that has the least experience batch in its detached_replay_buffer
|
||||
chosen_trainer = None
|
||||
min_length = None
|
||||
if 'debug' in self.generate_kwargs and self.generate_kwargs['debug'] == True:
|
||||
print("[maker] choosing target trainer")
|
||||
while chosen_trainer is None:
|
||||
for target_trainer in self.target_trainer_list:
|
||||
try:
|
||||
temp_length = ray.get(target_trainer.buffer_get_length.remote(), timeout=0.1)
|
||||
if min_length is None:
|
||||
min_length = temp_length
|
||||
chosen_trainer = target_trainer
|
||||
else:
|
||||
if temp_length < min_length:
|
||||
min_length = temp_length
|
||||
chosen_trainer = target_trainer
|
||||
except GetTimeoutError:
|
||||
pass
|
||||
|
||||
if 'debug' in self.generate_kwargs and self.generate_kwargs['debug'] == True:
|
||||
print(f"[maker] sending exp to {chosen_trainer}")
|
||||
chosen_trainer.buffer_append.remote(experience)
|
||||
'''
|
||||
#
|
||||
if not hasattr(self, "_target_idx"):
|
||||
self._target_idx = 0
|
||||
chosen_trainer = self.target_trainer_list[self._target_idx]
|
||||
if 'debug' in self.generate_kwargs and self.generate_kwargs['debug'] == True:
|
||||
print(f"[maker] sending exp to {chosen_trainer}")
|
||||
chosen_trainer.buffer_append.remote(experience)
|
||||
self._target_idx = (self._target_idx + 1) % len(self.target_trainer_list)
|
||||
|
||||
def workingloop(self, dataset, tokenizer: Optional[Callable[[Any], dict]] = None, times=5000 * 50000):
|
||||
self._get_ready()
|
||||
sampler = self.strategy.setup_sampler(dataset)
|
||||
for _ in range(times):
|
||||
rand_prompts = sampler.sample(self.experience_batch_size)
|
||||
if tokenizer is not None:
|
||||
inputs = tokenizer(rand_prompts)
|
||||
else:
|
||||
inputs = rand_prompts
|
||||
self._model_visit_lock.acquire()
|
||||
experience = self._make_experience(inputs=inputs)
|
||||
self._model_visit_lock.release()
|
||||
self._send_experience(experience=experience)
|
||||
|
||||
@ray.method(concurrency_group="model_io")
|
||||
def initialize_experience_maker(self, init_actor: Actor, init_critic: Critic):
|
||||
'''
|
||||
called by trainer. Only once.
|
||||
'''
|
||||
# TODO: reduce malloc
|
||||
if self.fully_initialized:
|
||||
return
|
||||
if 'debug' in self.generate_kwargs and self.generate_kwargs['debug'] == True:
|
||||
print('[maker] INIT')
|
||||
with torch.no_grad():
|
||||
with self.strategy.model_init_context():
|
||||
actor = init_actor
|
||||
critic = init_critic
|
||||
initial_model = deepcopy(actor)
|
||||
reward_model = RewardModel(deepcopy(critic.model),
|
||||
deepcopy(critic.value_head)).to(torch.cuda.current_device())
|
||||
if self.strategy_str != 'colossalai_gemini':
|
||||
actor.to(torch.float16).to(torch.cuda.current_device())
|
||||
critic.to(torch.float16).to(torch.cuda.current_device())
|
||||
initial_model.to(torch.float16).to(torch.cuda.current_device())
|
||||
reward_model.to(torch.float16).to(torch.cuda.current_device())
|
||||
|
||||
self.experience_maker.actor = self.strategy.prepare(actor)
|
||||
self.experience_maker.critic = self.strategy.prepare(critic)
|
||||
self.experience_maker.initial_model = self.strategy.prepare(initial_model)
|
||||
self.experience_maker.reward_model = self.strategy.prepare(reward_model)
|
||||
self.fully_initialized = True
|
||||
|
||||
@ray.method(concurrency_group="model_io")
|
||||
def update_experience_maker(self, new_actor: Actor, new_critic: Critic):
|
||||
'''
|
||||
called by trainer
|
||||
'''
|
||||
# TODO: reduce malloc
|
||||
self._model_visit_lock.acquire()
|
||||
with torch.no_grad():
|
||||
if 'debug' in self.generate_kwargs and self.generate_kwargs['debug'] == True:
|
||||
print("[maker] UPDATE ")
|
||||
if self.strategy_str != 'colossalai_gemini':
|
||||
new_actor.to(torch.float16).to(torch.cuda.current_device())
|
||||
new_critic.to(torch.float16).to(torch.cuda.current_device())
|
||||
self.experience_maker.actor = self.strategy.prepare(new_actor)
|
||||
self.experience_maker.critic = self.strategy.prepare(new_critic)
|
||||
self._model_visit_lock.release()
|
|
@ -1,105 +0,0 @@
|
|||
# WIP
|
||||
|
||||
|
||||
from coati.trainer.strategies import Strategy
|
||||
from coati.trainer.strategies import NaiveStrategy
|
||||
from coati.models.base import Actor, RewardModel, Critic
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from torch._C._distributed_rpc import _is_current_rpc_agent_set
|
||||
|
||||
import colossalai
|
||||
from colossalai.pipeline.pipeline_process_group import ppg
|
||||
from colossalai.pipeline.rpc._pipeline_schedule import OneFOneBPipelineEngine
|
||||
from colossalai.fx import ColoTracer
|
||||
from colossalai.fx.passes.adding_split_node_pass import balanced_split_pass, split_with_split_nodes_pass
|
||||
from colossalai.pipeline.middleware.adaptor import get_fx_topology
|
||||
|
||||
|
||||
import os
|
||||
from functools import partial
|
||||
import random
|
||||
|
||||
rpc_is_initialized = _is_current_rpc_agent_set
|
||||
|
||||
class PipelineModel(torch.nn.Module):
|
||||
'''
|
||||
Actor has 2 kinds of jobs: forward and generate.
|
||||
better to just pipeline the inner model
|
||||
'''
|
||||
def __init__(self,
|
||||
model: torch.nn.Module,
|
||||
stage_num: int,
|
||||
num_microbatches: int,
|
||||
data_kwargs = None,
|
||||
):
|
||||
super().__init__()
|
||||
# create partition module
|
||||
def create_partition_module(pp_rank:int, stage_num: int, model, data_kwargs):
|
||||
model.eval()
|
||||
tracer = ColoTracer()
|
||||
meta_args = {k: v.to('meta') for k, v in data_kwargs.items()}
|
||||
graph = tracer.trace(root=model, meta_args=meta_args)
|
||||
gm = torch.fx.GraphModule(model, graph, model.__class__.__name__)
|
||||
annotated_model = balanced_split_pass(gm, stage_num)
|
||||
top_module, split_submodules = split_with_split_nodes_pass(annotated_model, merge_output=True)
|
||||
topo = get_fx_topology(top_module)
|
||||
for submodule in split_submodules:
|
||||
if isinstance(submodule, torch.fx.GraphModule):
|
||||
setattr(submodule, '_topo', topo)
|
||||
return split_submodules[pp_rank + 1]
|
||||
|
||||
def partition(model, data_kwargs: dict, pp_rank: int, chunk: int, stage_num: int):
|
||||
partition = create_partition_module(pp_rank, stage_num, model, data_kwargs)
|
||||
return partition
|
||||
self.inference_engine = OneFOneBPipelineEngine(
|
||||
partition_fn=partial(partition, model, data_kwargs),
|
||||
stage_num=stage_num,
|
||||
num_microbatches=num_microbatches,
|
||||
device='cuda',
|
||||
)
|
||||
|
||||
def forward(self,
|
||||
**model_inputs):
|
||||
return self.inference_engine.forward_backward(**model_inputs, forward_only=True)
|
||||
|
||||
|
||||
|
||||
class PPStrategy(NaiveStrategy):
|
||||
"""
|
||||
Strategy for Pipeline inference (inference only!)
|
||||
|
||||
master node only
|
||||
"""
|
||||
def __init__(
|
||||
self,
|
||||
seed: int = 42
|
||||
):
|
||||
self.seed = seed
|
||||
super().__init__()
|
||||
|
||||
|
||||
def setup_distributed(self) -> None:
|
||||
colossalai.launch_from_torch({}, seed=self.seed)
|
||||
ppg.set_global_info(rank = int(os.environ['RANK']),
|
||||
world_size=int(os.environ['WORLD_SIZE']),
|
||||
dp_degree=1,
|
||||
tp_degree=1,
|
||||
num_worker_threads=128,
|
||||
device="cuda")
|
||||
|
||||
def model_init_context(self):
|
||||
return super().model_init_context()
|
||||
|
||||
def setup_model(self, model: torch.nn.Module) -> torch.nn.Module:
|
||||
if isinstance(model, Actor) or \
|
||||
isinstance(model, RewardModel) or \
|
||||
isinstance(model, Critic):
|
||||
model.model = PipelineModel(model.model)
|
||||
|
||||
def set_seed(self, seed: int) -> None:
|
||||
random.seed(seed)
|
||||
np.random.seed(seed)
|
||||
torch.manual_seed(seed)
|
||||
|
|
@ -1,48 +0,0 @@
|
|||
import torch.distributed as dist
|
||||
from typing import Any, Callable, Dict, List, Optional
|
||||
from coati.models.bloom import BLOOMActor, BLOOMCritic
|
||||
from coati.models.gpt import GPTActor, GPTCritic
|
||||
from coati.models.opt import OPTActor, OPTCritic
|
||||
from coati.trainer.strategies import ColossalAIStrategy, DDPStrategy, NaiveStrategy
|
||||
import torch
|
||||
import os
|
||||
|
||||
def is_rank_0() -> bool:
|
||||
return not dist.is_initialized() or dist.get_rank() == 0
|
||||
|
||||
|
||||
def get_cuda_actor_critic_from_args(model: str, pretrained: str = None, lora_rank=0):
|
||||
if model == 'gpt2':
|
||||
actor = GPTActor(pretrained=pretrained, lora_rank=lora_rank).to(torch.cuda.current_device())
|
||||
critic = GPTCritic(pretrained=pretrained, lora_rank=lora_rank).to(torch.cuda.current_device())
|
||||
elif model == 'bloom':
|
||||
actor = BLOOMActor(pretrained=pretrained, lora_rank=lora_rank).to(torch.cuda.current_device())
|
||||
critic = BLOOMCritic(pretrained=pretrained, lora_rank=lora_rank).to(torch.cuda.current_device())
|
||||
elif model == 'opt':
|
||||
actor = OPTActor(pretrained=pretrained, lora_rank=lora_rank).to(torch.cuda.current_device())
|
||||
critic = OPTCritic(pretrained=pretrained, lora_rank=lora_rank).to(torch.cuda.current_device())
|
||||
else:
|
||||
raise ValueError(f'Unsupported model "{model}"')
|
||||
return actor, critic
|
||||
|
||||
|
||||
def get_strategy_from_args(strategy: str):
|
||||
if strategy == 'naive':
|
||||
strategy_ = NaiveStrategy()
|
||||
elif strategy == 'ddp':
|
||||
strategy_ = DDPStrategy()
|
||||
elif strategy == 'colossalai_gemini':
|
||||
strategy_ = ColossalAIStrategy(stage=3, placement_policy='cuda', initial_scale=2**5)
|
||||
elif strategy == 'colossalai_zero2':
|
||||
strategy_ = ColossalAIStrategy(stage=2, placement_policy='cuda')
|
||||
else:
|
||||
raise ValueError(f'Unsupported strategy "{strategy}"')
|
||||
return strategy_
|
||||
|
||||
|
||||
def set_dist_env(env_info: Dict[str, str]):
|
||||
os.environ["RANK"] = env_info['rank']
|
||||
os.environ["LOCAL_RANK"] = env_info['local_rank']
|
||||
os.environ["WORLD_SIZE"] = env_info['world_size']
|
||||
os.environ['MASTER_PORT'] = env_info['master_port']
|
||||
os.environ['MASTER_ADDR'] = env_info['master_addr']
|
|
@ -0,0 +1,152 @@
|
|||
import os
|
||||
from typing import Any, Callable, Dict, List, Optional
|
||||
from collections import OrderedDict
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
import torch.nn as nn
|
||||
from coati.models.bloom import BLOOMRM, BLOOMActor, BLOOMCritic
|
||||
from coati.models.gpt import GPTRM, GPTActor, GPTCritic
|
||||
from coati.models.llama import LlamaActor, LlamaCritic, LlamaRM
|
||||
from coati.models.opt import OPTRM, OPTActor, OPTCritic
|
||||
from coati.models.roberta import RoBERTaActor, RoBERTaCritic, RoBERTaRM
|
||||
from coati.trainer.strategies import ColossalAIStrategy, DDPStrategy, NaiveStrategy
|
||||
from coati.utils import prepare_llama_tokenizer_and_embedding
|
||||
from transformers import AutoTokenizer, BloomTokenizerFast, GPT2Tokenizer, LlamaTokenizer, RobertaTokenizer
|
||||
|
||||
|
||||
def is_rank_0() -> bool:
|
||||
return not dist.is_initialized() or dist.get_rank() == 0
|
||||
|
||||
|
||||
def get_rank() -> int:
|
||||
return dist.get_rank() if dist.is_initialized() else 0
|
||||
|
||||
|
||||
def get_world_size() -> int:
|
||||
return dist.get_world_size() if dist.is_initialized() else 1
|
||||
|
||||
|
||||
def get_actor_from_args(model: str, pretrained: str = None, config=None, lora_rank=0):
|
||||
if model == 'gpt2':
|
||||
actor = GPTActor(pretrained=pretrained, config=config, lora_rank=lora_rank)
|
||||
elif model == 'bloom':
|
||||
actor = BLOOMActor(pretrained=pretrained, config=config, lora_rank=lora_rank)
|
||||
elif model == 'opt':
|
||||
actor = OPTActor(pretrained=pretrained, config=config, lora_rank=lora_rank)
|
||||
elif model == 'llama':
|
||||
actor = LlamaActor(pretrained=pretrained, config=config, lora_rank=lora_rank)
|
||||
elif model == 'roberta':
|
||||
actor = RoBERTaActor(pretrained=pretrained, config=config, lora_rank=lora_rank)
|
||||
else:
|
||||
raise ValueError(f'Unsupported actor model "{model}"')
|
||||
return actor
|
||||
|
||||
|
||||
def get_critic_from_args(model: str, pretrained: str = None, config=None, lora_rank=0):
|
||||
if model == 'gpt2':
|
||||
critic = GPTCritic(pretrained=pretrained, lora_rank=lora_rank, config=config, use_action_mask=True)
|
||||
elif model == 'bloom':
|
||||
critic = BLOOMCritic(pretrained=pretrained, lora_rank=lora_rank, config=config, use_action_mask=True)
|
||||
elif model == 'opt':
|
||||
critic = OPTCritic(pretrained=pretrained, lora_rank=lora_rank, config=config, use_action_mask=True)
|
||||
elif model == 'llama':
|
||||
critic = LlamaCritic(pretrained=pretrained, lora_rank=lora_rank, config=config, use_action_mask=True)
|
||||
elif model == 'roberta':
|
||||
critic = RoBERTaCritic(pretrained=pretrained, lora_rank=lora_rank, config=config, use_action_mask=True)
|
||||
else:
|
||||
raise ValueError(f'Unsupported reward model "{model}"')
|
||||
return critic
|
||||
|
||||
|
||||
def get_reward_model_from_args(model: str, pretrained: str = None, config=None):
|
||||
if model == 'gpt2':
|
||||
reward_model = GPTRM(pretrained=pretrained, config=config)
|
||||
elif model == 'bloom':
|
||||
reward_model = BLOOMRM(pretrained=pretrained, config=config)
|
||||
elif model == 'opt':
|
||||
reward_model = OPTRM(pretrained=pretrained, config=config)
|
||||
elif model == 'llama':
|
||||
reward_model = LlamaRM(pretrained=pretrained, config=config)
|
||||
elif model == 'roberta':
|
||||
reward_model = RoBERTaRM(pretrained=pretrained, config=config)
|
||||
else:
|
||||
raise ValueError(f'Unsupported reward model "{model}"')
|
||||
return reward_model
|
||||
|
||||
|
||||
def get_strategy_from_args(strategy: str):
|
||||
if strategy == 'naive':
|
||||
strategy_ = NaiveStrategy()
|
||||
elif strategy == 'ddp':
|
||||
strategy_ = DDPStrategy()
|
||||
elif strategy == 'colossalai_gemini':
|
||||
strategy_ = ColossalAIStrategy(stage=3, placement_policy='cuda', initial_scale=2**5)
|
||||
elif strategy == 'colossalai_zero2':
|
||||
strategy_ = ColossalAIStrategy(stage=2, placement_policy='cuda')
|
||||
elif strategy == 'colossalai_gemini_cpu':
|
||||
strategy_ = ColossalAIStrategy(stage=3, placement_policy='cpu', initial_scale=2**5)
|
||||
elif strategy == 'colossalai_zero2_cpu':
|
||||
strategy_ = ColossalAIStrategy(stage=2, placement_policy='cpu')
|
||||
else:
|
||||
raise ValueError(f'Unsupported strategy "{strategy}"')
|
||||
return strategy_
|
||||
|
||||
|
||||
def get_tokenizer_from_args(model: str, **kwargs):
|
||||
if model == 'gpt2':
|
||||
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
|
||||
elif model == 'bloom':
|
||||
tokenizer = BloomTokenizerFast.from_pretrained('bigscience/bloom-560m')
|
||||
elif model == 'opt':
|
||||
tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m")
|
||||
elif model == 'llama':
|
||||
pretrain_path = kwargs["pretrain"]
|
||||
tokenizer = AutoTokenizer.from_pretrained(pretrain_path)
|
||||
elif model == 'roberta':
|
||||
tokenizer = RobertaTokenizer.from_pretrained("roberta-base")
|
||||
else:
|
||||
raise ValueError(f'Unsupported model "{model}"')
|
||||
|
||||
tokenizer.pad_token = tokenizer.eos_token
|
||||
return tokenizer
|
||||
|
||||
|
||||
def set_dist_env(env_info: Dict[str, str]):
|
||||
os.environ["RANK"] = env_info['rank']
|
||||
os.environ["LOCAL_RANK"] = env_info['local_rank']
|
||||
os.environ["WORLD_SIZE"] = env_info['world_size']
|
||||
os.environ['MASTER_PORT'] = env_info['master_port']
|
||||
os.environ['MASTER_ADDR'] = env_info['master_addr']
|
||||
|
||||
|
||||
def get_model_numel(model: nn.Module) -> int:
|
||||
numel = sum(p.numel() for p in model.parameters())
|
||||
return numel
|
||||
|
||||
|
||||
def get_receivers_per_sender(sender_idx: int, num_senders: int, num_receivers: int, allow_idle_sender: bool) -> list:
|
||||
target_receivers = []
|
||||
if num_senders <= num_receivers or allow_idle_sender:
|
||||
# a sender will send data to one or more than one receivers
|
||||
# a receiver only has one sender
|
||||
for i in range(num_receivers):
|
||||
if i % num_senders == sender_idx:
|
||||
target_receivers.append(i)
|
||||
else:
|
||||
# a sender will send data to one receiver
|
||||
# a receiver may have more than one sender
|
||||
target_receivers.append(sender_idx % num_receivers)
|
||||
return target_receivers
|
||||
|
||||
|
||||
def state_dict_to(state_dict: Dict[str, Any],
|
||||
dtype: torch.dtype = torch.float16,
|
||||
device: torch.device = torch.device('cpu')):
|
||||
'''
|
||||
keep state_dict intact
|
||||
'''
|
||||
new_state_dict = OrderedDict()
|
||||
for k, v in state_dict.items():
|
||||
new_state_dict[k] = v.to(dtype=dtype, device=device)
|
||||
return new_state_dict
|
|
@ -130,3 +130,7 @@ class Strategy(ABC):
|
|||
only_rank0: bool = True,
|
||||
tokenizer: Optional[PreTrainedTokenizerBase] = None) -> None:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_model_state_dict_shard(self, model: nn.Module, **config):
|
||||
pass
|
|
@ -186,3 +186,15 @@ class ColossalAIStrategy(DDPStrategy):
|
|||
if self.stage == 3:
|
||||
raise RuntimeError('ColossalAI strategy with stage-3 does not support save_pretrained() now')
|
||||
super().save_pretrained(model, path, only_rank0, tokenizer)
|
||||
|
||||
def get_model_state_dict_shard(self, model: nn.Module, **config):
|
||||
if self.stage != 3:
|
||||
yield from super().get_model_state_dict_shard(model, **config)
|
||||
else:
|
||||
# unwrapped_model = self._unwrap_model(model)
|
||||
# for module in unwrapped_model.modules():
|
||||
# if isinstance(module, LoraLinear):
|
||||
# module.merge_weights = True
|
||||
# module.eval()
|
||||
base_model: ZeroDDP = get_base_model(model)
|
||||
yield from base_model.state_dict_shard(max_shard_size=1024, only_rank_0=False)
|
||||
|
|
|
@ -26,19 +26,8 @@ class DDPStrategy(NaiveStrategy):
|
|||
super().__init__()
|
||||
|
||||
def setup_distributed(self) -> None:
|
||||
try:
|
||||
rank = int(os.environ['RANK'])
|
||||
local_rank = int(os.environ['LOCAL_RANK'])
|
||||
world_size = int(os.environ['WORLD_SIZE'])
|
||||
host = os.environ['MASTER_ADDR']
|
||||
port = int(os.environ['MASTER_PORT'])
|
||||
except KeyError as e:
|
||||
raise RuntimeError(
|
||||
f"Could not find {e} in the torch environment, visit https://www.colossalai.org/ for more information on launching with torch"
|
||||
)
|
||||
dist.init_process_group('nccl', init_method=f'tcp://[{host}]:{port}', world_size=world_size, rank=rank)
|
||||
self._try_init_dist(force=True)
|
||||
self.set_seed(self.seed)
|
||||
torch.cuda.set_device(local_rank)
|
||||
|
||||
def set_seed(self, seed: int) -> None:
|
||||
random.seed(seed)
|
||||
|
|
|
@ -1,10 +1,17 @@
|
|||
from typing import Any, Optional
|
||||
import os
|
||||
import sys
|
||||
from collections import OrderedDict
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
import torch.nn as nn
|
||||
import torch.optim as optim
|
||||
from coati.models.base import get_base_model
|
||||
from coati.replay_buffer import ReplayBuffer
|
||||
from coati.models.base import RewardModel
|
||||
from coati.models.lora import LoraLinear
|
||||
from coati.replay_buffer import ReplayBuffer
|
||||
from torch.optim import Optimizer
|
||||
from torch.utils.data import DataLoader
|
||||
from transformers.modeling_utils import PreTrainedModel
|
||||
|
@ -13,6 +20,15 @@ from transformers.tokenization_utils_base import PreTrainedTokenizerBase
|
|||
from .base import Strategy
|
||||
|
||||
|
||||
# TODO Move this to a util.py (Moving to ray.util introduces ringed import)
|
||||
def get_grad_required_state_dict(model: nn.Module):
|
||||
state_dict = OrderedDict()
|
||||
for name, parameter in model.named_parameters():
|
||||
if parameter.requires_grad:
|
||||
state_dict[name] = parameter.detach()
|
||||
return state_dict
|
||||
|
||||
|
||||
class NaiveStrategy(Strategy):
|
||||
"""
|
||||
Strategy for single GPU. No parallelism is used.
|
||||
|
@ -25,7 +41,7 @@ class NaiveStrategy(Strategy):
|
|||
optimizer.step()
|
||||
|
||||
def setup_distributed(self) -> None:
|
||||
pass
|
||||
self._try_init_dist(force=False)
|
||||
|
||||
def setup_model(self, model: nn.Module) -> nn.Module:
|
||||
return model
|
||||
|
@ -68,3 +84,45 @@ class NaiveStrategy(Strategy):
|
|||
unwrapped_model.save_pretrained(path)
|
||||
if tokenizer is not None:
|
||||
tokenizer.save_pretrained(path)
|
||||
|
||||
def get_model_state_dict_shard(self, model: nn.Module, **config):
|
||||
# TODO: implement sharding on naive strategy
|
||||
model = self.unwrap_model(model)
|
||||
if 'requires_grad_only' in config and config['requires_grad_only'] == True:
|
||||
state_dict = get_grad_required_state_dict(model)
|
||||
else:
|
||||
state_dict = model.state_dict()
|
||||
|
||||
if 'shard_size' in config:
|
||||
shard_size = config['shard_size']
|
||||
accumulate_size = 0
|
||||
state_dict_shard = OrderedDict()
|
||||
for name, param in state_dict.items():
|
||||
state_dict_shard[name] = param
|
||||
accumulate_size += param.numel() * param.element_size()
|
||||
if accumulate_size >= shard_size:
|
||||
accumulate_size = 0
|
||||
yield state_dict_shard
|
||||
state_dict_shard = OrderedDict()
|
||||
if accumulate_size > 0:
|
||||
yield state_dict_shard
|
||||
else:
|
||||
yield state_dict
|
||||
|
||||
def _try_init_dist(self, force: bool = False) -> None:
|
||||
try:
|
||||
rank = int(os.environ['RANK'])
|
||||
local_rank = int(os.environ['LOCAL_RANK'])
|
||||
world_size = int(os.environ['WORLD_SIZE'])
|
||||
host = os.environ['MASTER_ADDR']
|
||||
port = int(os.environ['MASTER_PORT'])
|
||||
dist.init_process_group('nccl', init_method=f'tcp://[{host}]:{port}', world_size=world_size, rank=rank)
|
||||
torch.cuda.set_device(local_rank)
|
||||
except KeyError as e:
|
||||
if force:
|
||||
raise RuntimeError(
|
||||
f"Could not find {e} in the torch environment, visit https://www.colossalai.org/ for more information on launching with torch"
|
||||
)
|
||||
except Exception as e:
|
||||
if force:
|
||||
raise e
|
||||
|
|
|
@ -27,6 +27,7 @@ class DistributedSampler:
|
|||
assert len(indices) == self.num_samples
|
||||
self.indices = indices
|
||||
|
||||
|
||||
def sample(self, batch_size: int) -> list:
|
||||
sampled_indices = np.random.choice(self.indices, batch_size, replace=False)
|
||||
return [self.dataset[idx] for idx in sampled_indices]
|
||||
|
|
|
@ -0,0 +1,175 @@
|
|||
import argparse
|
||||
import os
|
||||
import socket
|
||||
from functools import partial
|
||||
|
||||
import pandas as pd
|
||||
import ray
|
||||
import torch
|
||||
from coati.quant import llama_load_quant, low_resource_init
|
||||
from coati.ray.detached_trainer_ppo import DetachedPPOTrainer
|
||||
from coati.ray.experience_maker_holder import ExperienceMakerHolder
|
||||
from coati.ray.utils import (
|
||||
get_actor_from_args,
|
||||
get_critic_from_args,
|
||||
get_reward_model_from_args,
|
||||
get_strategy_from_args,
|
||||
get_tokenizer_from_args,
|
||||
)
|
||||
from torch.utils.data import DataLoader
|
||||
from transformers import AutoConfig
|
||||
from transformers.modeling_utils import no_init_weights
|
||||
|
||||
|
||||
def get_free_port():
|
||||
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
|
||||
s.bind(('', 0))
|
||||
return s.getsockname()[1]
|
||||
|
||||
|
||||
def get_local_ip():
|
||||
with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as s:
|
||||
s.connect(('8.8.8.8', 80))
|
||||
return s.getsockname()[0]
|
||||
|
||||
|
||||
def main(args):
|
||||
master_addr = str(get_local_ip())
|
||||
# trainer_env_info
|
||||
trainer_port = str(get_free_port())
|
||||
env_info_trainers = [{
|
||||
'local_rank': '0',
|
||||
'rank': str(rank),
|
||||
'world_size': str(args.num_trainers),
|
||||
'master_port': trainer_port,
|
||||
'master_addr': master_addr
|
||||
} for rank in range(args.num_trainers)]
|
||||
|
||||
# maker_env_info
|
||||
maker_port = str(get_free_port())
|
||||
env_info_maker = {
|
||||
'local_rank': '0',
|
||||
'rank': '0',
|
||||
'world_size': '1',
|
||||
'master_port': maker_port,
|
||||
'master_addr': master_addr
|
||||
}
|
||||
|
||||
# configure tokenizer
|
||||
tokenizer = get_tokenizer_from_args(args.model)
|
||||
|
||||
def trainer_model_fn():
|
||||
actor = get_actor_from_args(args.model, args.pretrain).half().cuda()
|
||||
critic = get_critic_from_args(args.model, args.critic_pretrain).half().cuda()
|
||||
return actor, critic
|
||||
|
||||
# configure Trainer
|
||||
trainer_refs = [
|
||||
DetachedPPOTrainer.options(name=f"trainer{i}", num_gpus=1, max_concurrency=2).remote(
|
||||
experience_maker_holder_name_list=["maker1"],
|
||||
strategy_fn=partial(get_strategy_from_args, args.trainer_strategy),
|
||||
model_fn=trainer_model_fn,
|
||||
env_info=env_info_trainer,
|
||||
train_batch_size=args.train_batch_size,
|
||||
buffer_limit=16,
|
||||
eval_performance=True,
|
||||
debug=args.debug,
|
||||
update_lora_weights=not (args.lora_rank == 0),
|
||||
) for i, env_info_trainer in enumerate(env_info_trainers)
|
||||
]
|
||||
|
||||
def model_fn():
|
||||
actor = get_actor_from_args(args.model, args.pretrain).requires_grad_(False).half().cuda()
|
||||
critic = get_critic_from_args(args.model, args.critic_pretrain).requires_grad_(False).half().cuda()
|
||||
reward_model = get_reward_model_from_args(args.model, args.critic_pretrain).requires_grad_(False).half().cuda()
|
||||
if args.initial_model_quant_ckpt is not None and args.model == 'llama':
|
||||
# quantize initial model
|
||||
actor_cfg = AutoConfig.from_pretrained(args.pretrain)
|
||||
with low_resource_init(), no_init_weights():
|
||||
initial_model = get_actor_from_args(args.model, config=actor_cfg)
|
||||
initial_model.model = llama_load_quant(initial_model.model, args.initial_model_quant_ckpt, args.quant_bits,
|
||||
args.quant_group_size).cuda().requires_grad_(False)
|
||||
else:
|
||||
initial_model = get_actor_from_args(args.model, args.pretrain).requires_grad_(False).half().cuda()
|
||||
return actor, critic, reward_model, initial_model
|
||||
|
||||
# configure Experience Maker
|
||||
experience_holder_ref = ExperienceMakerHolder.options(name="maker1", num_gpus=1, max_concurrency=2).remote(
|
||||
detached_trainer_name_list=[f'trainer{i}' for i in range(args.num_trainers)],
|
||||
strategy_fn=partial(get_strategy_from_args, args.maker_strategy),
|
||||
model_fn=model_fn,
|
||||
env_info=env_info_maker,
|
||||
experience_batch_size=args.experience_batch_size,
|
||||
kl_coef=0.1,
|
||||
debug=args.debug,
|
||||
update_lora_weights=not (args.lora_rank == 0),
|
||||
# sync_models_from_trainers=True,
|
||||
# generation kwargs:
|
||||
max_length=512,
|
||||
do_sample=True,
|
||||
temperature=1.0,
|
||||
top_k=50,
|
||||
pad_token_id=tokenizer.pad_token_id,
|
||||
eos_token_id=tokenizer.eos_token_id,
|
||||
eval_performance=True,
|
||||
use_cache=True,
|
||||
)
|
||||
|
||||
# uncomment this function if sync_models_from_trainers is True
|
||||
# ray.get([
|
||||
# trainer_ref.sync_models_to_remote_makers.remote()
|
||||
# for trainer_ref in trainer_refs
|
||||
# ])
|
||||
|
||||
wait_tasks = []
|
||||
|
||||
total_steps = args.experience_batch_size * args.experience_steps // (args.num_trainers * args.train_batch_size)
|
||||
for trainer_ref in trainer_refs:
|
||||
wait_tasks.append(trainer_ref.fit.remote(total_steps, args.update_steps, args.train_epochs))
|
||||
|
||||
dataset_size = args.experience_batch_size * 4
|
||||
|
||||
def build_dataloader():
|
||||
|
||||
def tokenize_fn(texts):
|
||||
batch = tokenizer(texts, return_tensors='pt', max_length=96, padding='max_length', truncation=True)
|
||||
return {k: v.cuda() for k, v in batch.items()}
|
||||
|
||||
dataset = pd.read_csv(args.prompt_path)['prompt']
|
||||
dataloader = DataLoader(dataset=dataset, batch_size=dataset_size, shuffle=True, collate_fn=tokenize_fn)
|
||||
return dataloader
|
||||
|
||||
wait_tasks.append(experience_holder_ref.workingloop.remote(build_dataloader, num_steps=args.experience_steps))
|
||||
|
||||
ray.get(wait_tasks)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('--prompt_path', type=str, default=None)
|
||||
parser.add_argument('--num_trainers', type=int, default=1)
|
||||
parser.add_argument('--trainer_strategy',
|
||||
choices=[
|
||||
'naive', 'ddp', 'colossalai_gemini', 'colossalai_zero2', 'colossalai_gemini_cpu',
|
||||
'colossalai_zero2_cpu'
|
||||
],
|
||||
default='naive')
|
||||
parser.add_argument('--maker_strategy', choices=['naive'], default='naive')
|
||||
parser.add_argument('--model', default='gpt2', choices=['gpt2', 'bloom', 'opt', 'llama'])
|
||||
parser.add_argument('--critic_model', default='gpt2', choices=['gpt2', 'bloom', 'opt', 'llama'])
|
||||
parser.add_argument('--pretrain', type=str, default=None)
|
||||
parser.add_argument('--critic_pretrain', type=str, default=None)
|
||||
parser.add_argument('--experience_steps', type=int, default=4)
|
||||
parser.add_argument('--experience_batch_size', type=int, default=8)
|
||||
parser.add_argument('--train_epochs', type=int, default=1)
|
||||
parser.add_argument('--update_steps', type=int, default=2)
|
||||
parser.add_argument('--train_batch_size', type=int, default=8)
|
||||
parser.add_argument('--lora_rank', type=int, default=0, help="low-rank adaptation matrices rank")
|
||||
|
||||
parser.add_argument('--initial_model_quant_ckpt', type=str, default=None)
|
||||
parser.add_argument('--quant_bits', type=int, default=4)
|
||||
parser.add_argument('--quant_group_size', type=int, default=128)
|
||||
parser.add_argument('--debug', action='store_true')
|
||||
args = parser.parse_args()
|
||||
ray.init(namespace=os.environ["RAY_NAMESPACE"], runtime_env={"env_vars": dict(os.environ)})
|
||||
main(args)
|
|
@ -0,0 +1,189 @@
|
|||
import argparse
|
||||
import os
|
||||
import socket
|
||||
from functools import partial
|
||||
|
||||
import pandas as pd
|
||||
import ray
|
||||
import torch
|
||||
from coati.quant import llama_load_quant, low_resource_init
|
||||
from coati.ray.detached_trainer_ppo import DetachedPPOTrainer
|
||||
from coati.ray.experience_maker_holder import ExperienceMakerHolder
|
||||
from coati.ray.utils import (
|
||||
get_actor_from_args,
|
||||
get_critic_from_args,
|
||||
get_receivers_per_sender,
|
||||
get_reward_model_from_args,
|
||||
get_strategy_from_args,
|
||||
)
|
||||
from torch.utils.data import DataLoader
|
||||
from transformers import AutoConfig, AutoTokenizer
|
||||
from transformers.modeling_utils import no_init_weights
|
||||
|
||||
|
||||
def get_free_port():
|
||||
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
|
||||
s.bind(('', 0))
|
||||
return s.getsockname()[1]
|
||||
|
||||
|
||||
def get_local_ip():
|
||||
with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as s:
|
||||
s.connect(('8.8.8.8', 80))
|
||||
return s.getsockname()[0]
|
||||
|
||||
|
||||
def main(args):
|
||||
master_addr = str(get_local_ip())
|
||||
# trainer_env_info
|
||||
trainer_port = str(get_free_port())
|
||||
env_info_trainers = [{
|
||||
'local_rank': '0',
|
||||
'rank': str(rank),
|
||||
'world_size': str(args.num_trainers),
|
||||
'master_port': trainer_port,
|
||||
'master_addr': master_addr
|
||||
} for rank in range(args.num_trainers)]
|
||||
|
||||
# maker_env_info
|
||||
maker_port = str(get_free_port())
|
||||
env_info_makers = [{
|
||||
'local_rank': '0',
|
||||
'rank': str(rank),
|
||||
'world_size': str(args.num_makers),
|
||||
'master_port': maker_port,
|
||||
'master_addr': master_addr
|
||||
} for rank in range(args.num_makers)]
|
||||
|
||||
# configure tokenizer
|
||||
tokenizer = AutoTokenizer.from_pretrained(args.pretrain)
|
||||
tokenizer.pad_token = tokenizer.eos_token
|
||||
|
||||
def model_fn():
|
||||
actor = get_actor_from_args(args.model, args.pretrain).requires_grad_(False).half().cuda()
|
||||
critic = get_critic_from_args(args.model, args.critic_pretrain).requires_grad_(False).half().cuda()
|
||||
reward_model = get_reward_model_from_args(args.model, args.critic_pretrain).requires_grad_(False).half().cuda()
|
||||
if args.initial_model_quant_ckpt is not None and args.model == 'llama':
|
||||
# quantize initial model
|
||||
actor_cfg = AutoConfig.from_pretrained(args.pretrain)
|
||||
with low_resource_init(), no_init_weights():
|
||||
initial_model = get_actor_from_args(args.model, config=actor_cfg)
|
||||
initial_model.model = llama_load_quant(initial_model.model, args.initial_model_quant_ckpt, args.quant_bits,
|
||||
args.quant_group_size).cuda().requires_grad_(False)
|
||||
else:
|
||||
initial_model = get_actor_from_args(args.model, args.pretrain).requires_grad_(False).half().cuda()
|
||||
return actor, critic, reward_model, initial_model
|
||||
|
||||
# configure Experience Maker
|
||||
experience_holder_refs = [
|
||||
ExperienceMakerHolder.options(name=f"maker{i}", num_gpus=1, max_concurrency=2).remote(
|
||||
detached_trainer_name_list=[
|
||||
f'trainer{x}'
|
||||
for x in get_receivers_per_sender(i, args.num_makers, args.num_trainers, allow_idle_sender=False)
|
||||
],
|
||||
strategy_fn=partial(get_strategy_from_args, args.maker_strategy),
|
||||
model_fn=model_fn,
|
||||
env_info=env_info_maker,
|
||||
kl_coef=0.1,
|
||||
debug=args.debug,
|
||||
update_lora_weights=not (args.lora_rank == 0),
|
||||
# sync_models_from_trainers=True,
|
||||
# generation kwargs:
|
||||
max_length=512,
|
||||
do_sample=True,
|
||||
temperature=1.0,
|
||||
top_k=50,
|
||||
pad_token_id=tokenizer.pad_token_id,
|
||||
eos_token_id=tokenizer.eos_token_id,
|
||||
eval_performance=True,
|
||||
use_cache=True,
|
||||
)
|
||||
for i, env_info_maker in enumerate(env_info_makers)
|
||||
]
|
||||
|
||||
def trainer_model_fn():
|
||||
actor = get_actor_from_args(args.model, args.pretrain, lora_rank=args.lora_rank).half().cuda()
|
||||
critic = get_critic_from_args(args.model, args.critic_pretrain, lora_rank=args.lora_rank).half().cuda()
|
||||
return actor, critic
|
||||
|
||||
# configure Trainer
|
||||
trainer_refs = [
|
||||
DetachedPPOTrainer.options(name=f"trainer{i}", num_gpus=1, max_concurrency=2).remote(
|
||||
experience_maker_holder_name_list=[
|
||||
f"maker{x}"
|
||||
for x in get_receivers_per_sender(i, args.num_trainers, args.num_makers, allow_idle_sender=True)
|
||||
],
|
||||
strategy_fn=partial(get_strategy_from_args, args.trainer_strategy),
|
||||
model_fn=trainer_model_fn,
|
||||
env_info=env_info_trainer,
|
||||
train_batch_size=args.train_batch_size,
|
||||
buffer_limit=16,
|
||||
eval_performance=True,
|
||||
debug=args.debug,
|
||||
update_lora_weights=not (args.lora_rank == 0),
|
||||
)
|
||||
for i, env_info_trainer in enumerate(env_info_trainers)
|
||||
]
|
||||
|
||||
dataset_size = args.experience_batch_size * 4
|
||||
|
||||
def build_dataloader():
|
||||
|
||||
def tokenize_fn(texts):
|
||||
batch = tokenizer(texts, return_tensors='pt', max_length=96, padding='max_length', truncation=True)
|
||||
return {k: v.cuda() for k, v in batch.items()}
|
||||
|
||||
dataset = pd.read_csv(args.prompt_path)['prompt']
|
||||
dataloader = DataLoader(dataset=dataset, batch_size=dataset_size, shuffle=True, collate_fn=tokenize_fn)
|
||||
return dataloader
|
||||
|
||||
# uncomment this function if sync_models_from_trainers is True
|
||||
# ray.get([
|
||||
# trainer_ref.sync_models_to_remote_makers.remote()
|
||||
# for trainer_ref in trainer_refs
|
||||
# ])
|
||||
|
||||
wait_tasks = []
|
||||
|
||||
for experience_holder_ref in experience_holder_refs:
|
||||
wait_tasks.append(experience_holder_ref.workingloop.remote(build_dataloader, num_steps=args.experience_steps))
|
||||
|
||||
total_steps = args.experience_batch_size * args.experience_steps * \
|
||||
args.num_makers // (args.num_trainers * args.train_batch_size)
|
||||
for trainer_ref in trainer_refs:
|
||||
wait_tasks.append(trainer_ref.fit.remote(total_steps, args.update_steps, args.train_epochs))
|
||||
|
||||
ray.get(wait_tasks)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('--prompt_path', type=str, default=None)
|
||||
parser.add_argument('--num_makers', type=int, default=1)
|
||||
parser.add_argument('--num_trainers', type=int, default=1)
|
||||
parser.add_argument('--trainer_strategy',
|
||||
choices=[
|
||||
'naive', 'ddp', 'colossalai_gemini', 'colossalai_zero2', 'colossalai_gemini_cpu',
|
||||
'colossalai_zero2_cpu'
|
||||
],
|
||||
default='naive')
|
||||
parser.add_argument('--maker_strategy', choices=['naive'], default='naive')
|
||||
parser.add_argument('--model', default='gpt2', choices=['gpt2', 'bloom', 'opt', 'llama'])
|
||||
parser.add_argument('--critic_model', default='gpt2', choices=['gpt2', 'bloom', 'opt', 'llama'])
|
||||
parser.add_argument('--pretrain', type=str, default=None)
|
||||
parser.add_argument('--critic_pretrain', type=str, default=None)
|
||||
parser.add_argument('--experience_steps', type=int, default=4)
|
||||
parser.add_argument('--experience_batch_size', type=int, default=8)
|
||||
parser.add_argument('--train_epochs', type=int, default=1)
|
||||
parser.add_argument('--update_steps', type=int, default=2)
|
||||
parser.add_argument('--train_batch_size', type=int, default=8)
|
||||
parser.add_argument('--lora_rank', type=int, default=0, help="low-rank adaptation matrices rank")
|
||||
|
||||
parser.add_argument('--initial_model_quant_ckpt', type=str, default=None)
|
||||
parser.add_argument('--quant_bits', type=int, default=4)
|
||||
parser.add_argument('--quant_group_size', type=int, default=128)
|
||||
parser.add_argument('--debug', action='store_true')
|
||||
args = parser.parse_args()
|
||||
|
||||
ray.init(namespace=os.environ["RAY_NAMESPACE"], runtime_env={"env_vars": dict(os.environ)})
|
||||
main(args)
|
|
@ -0,0 +1 @@
|
|||
ray
|
|
@ -0,0 +1,12 @@
|
|||
#!/bin/bash
|
||||
|
||||
set -xe
|
||||
BASE=$(realpath $(dirname $0))
|
||||
|
||||
export RAY_NAMESPACE=admin
|
||||
export DATA=/data/scratch/chatgpt/prompts.csv
|
||||
|
||||
# install requirements
|
||||
pip install -r ${BASE}/requirements.txt
|
||||
|
||||
python ${BASE}/mmmt_prompt.py --prompt_path $DATA --num_makers 2 --num_trainers 2 --trainer_strategy colossalai_gemini --model opt --critic_model opt --pretrain facebook/opt-350m --critic_pretrain facebook/opt-125m --experience_batch_size 4 --train_batch_size 2
|
|
@ -124,3 +124,6 @@ torchrun --standalone --nproc_per_node=2 ${BASE}/train_prompts.py --prompt_datas
|
|||
rm -rf ${BASE}/rm_ckpt_gpt.pt
|
||||
|
||||
rm -rf ${BASE}/actor_checkpoint_prompts.pt
|
||||
|
||||
# 3080 doesn't support P2P, skip this test
|
||||
# cd ${BASE}/ray && bash test_ci.sh && cd ${BASE}
|
||||
|
|
Loading…
Reference in New Issue