diff --git a/applications/Chat/benchmarks/README.md b/applications/Chat/benchmarks/README.md
index b4e28ba1d..bc8ad8ba9 100644
--- a/applications/Chat/benchmarks/README.md
+++ b/applications/Chat/benchmarks/README.md
@@ -1,70 +1,5 @@
 # Benchmarks
 
-## Benchmark GPT on dummy prompt data
-
-We provide various GPT models (string in parentheses is the corresponding model name used in this script):
-
-- GPT2-S (s)
-- GPT2-M (m)
-- GPT2-L (l)
-- GPT2-XL (xl)
-- GPT2-4B (4b)
-- GPT2-6B (6b)
-- GPT2-8B (8b)
-- GPT2-10B (10b)
-- GPT2-12B (12b)
-- GPT2-15B (15b)
-- GPT2-18B (18b)
-- GPT2-20B (20b)
-- GPT2-24B (24b)
-- GPT2-28B (28b)
-- GPT2-32B (32b)
-- GPT2-36B (36b)
-- GPT2-40B (40b)
-- GPT3 (175b)
-
-We also provide various training strategies:
-
-- ddp: torch DDP
-- colossalai_gemini: ColossalAI GeminiDDP with `placement_policy="cuda"`, like zero3
-- colossalai_gemini_cpu: ColossalAI GeminiDDP with `placement_policy="cpu"`, like zero3-offload
-- colossalai_zero2: ColossalAI zero2
-- colossalai_zero2_cpu: ColossalAI zero2-offload
-- colossalai_zero1: ColossalAI zero1
-- colossalai_zero1_cpu: ColossalAI zero1-offload
-
-We only support `torchrun` to launch now. E.g.
-
-```shell
-# run GPT2-S on single-node single-GPU with min batch size
-torchrun --standalone --nproc_per_node 1 benchmark_gpt_dummy.py --model s --strategy ddp --experience_batch_size 1 --train_batch_size 1
-# run GPT2-XL on single-node 4-GPU
-torchrun --standalone --nproc_per_node 4 benchmark_gpt_dummy.py --model xl --strategy colossalai_zero2
-# run GPT3 on 8-node 8-GPU
-torchrun --nnodes 8 --nproc_per_node 8 \
-    --rdzv_id=$JOB_ID --rdzv_backend=c10d --rdzv_endpoint=$HOST_NODE_ADDR \
-    benchmark_gpt_dummy.py --model 175b --strategy colossalai_gemini
-```
-
-> ⚠ Batch sizes in CLI args and outputed throughput/TFLOPS are all values of per GPU.
-
-In this benchmark, we assume the model architectures/sizes of actor and critic are the same for simplicity. But in practice, to reduce training cost, we may use a smaller critic.
-
-We also provide a simple shell script to run a set of benchmarks. But it only supports benchmark on single node. However, it's easy to run on multi-nodes by modifying launch command in this script.
-
-Usage:
-
-```shell
-# run for GPUS=(1 2 4 8) x strategy=("ddp" "colossalai_zero2" "colossalai_gemini" "colossalai_zero2_cpu" "colossalai_gemini_cpu") x model=("s" "m" "l" "xl" "2b" "4b" "6b" "8b" "10b") x batch_size=(1 2 4 8 16 32 64 128 256)
-./benchmark_gpt_dummy.sh
-# run for GPUS=2 x strategy=("ddp" "colossalai_zero2" "colossalai_gemini" "colossalai_zero2_cpu" "colossalai_gemini_cpu") x model=("s" "m" "l" "xl" "2b" "4b" "6b" "8b" "10b") x batch_size=(1 2 4 8 16 32 64 128 256)
-./benchmark_gpt_dummy.sh 2
-# run for GPUS=2 x strategy=ddp x model=("s" "m" "l" "xl" "2b" "4b" "6b" "8b" "10b") x batch_size=(1 2 4 8 16 32 64 128 256)
-./benchmark_gpt_dummy.sh 2 ddp
-# run for GPUS=2 x strategy=ddp x model=l x batch_size=(1 2 4 8 16 32 64 128 256)
-./benchmark_gpt_dummy.sh 2 ddp l
-```
-
 ## Benchmark OPT with LoRA on dummy prompt data
 
 We provide various OPT models (string in parentheses is the corresponding model name used in this script):
@@ -80,15 +15,21 @@ We provide various OPT models (string in parentheses is the corresponding model
 - OPT-10B (10b)
 - OPT-13B (13b)
 
+We also provide various training strategies:
+
+- ddp: torch DDP
+- colossalai_gemini: ColossalAI GeminiDDP with `placement_policy="cuda"`, like zero3
+- colossalai_gemini_cpu: ColossalAI GeminiDDP with `placement_policy="cpu"`, like zero3-offload
+- colossalai_zero2: ColossalAI zero2
+- colossalai_zero2_cpu: ColossalAI zero2-offload
+- colossalai_zero1: ColossalAI zero1
+- colossalai_zero1_cpu: ColossalAI zero1-offload
+
 We only support `torchrun` to launch now. E.g.
 
 ```shell
 # run OPT-125M with no lora (lora_rank=0) on single-node single-GPU with min batch size
-torchrun --standalone --nproc_per_node 1 benchmark_opt_lora_dummy.py --model 125m --strategy ddp --experience_batch_size 1 --train_batch_size 1 --lora_rank 0
-# run OPT-350M with lora_rank=4 on single-node 4-GPU
-torchrun --standalone --nproc_per_node 4 benchmark_opt_lora_dummy.py --model 350m --strategy colossalai_zero2 --lora_rank 4
+torchrun --standalone --nproc_per_node 1 benchmark_opt_lora_dummy.py --model 125m --critic_model 125m --strategy ddp --experience_batch_size 1 --train_batch_size 1 --lora_rank 0
+# run Actor (OPT-1.3B) and Critic (OPT-350M) with lora_rank=4 on single-node 4-GPU
+torchrun --standalone --nproc_per_node 4 benchmark_opt_lora_dummy.py --model 1.3b --critic_model 350m --strategy colossalai_zero2 --lora_rank 4
 ```
-
-> ⚠ Batch sizes in CLI args and outputed throughput/TFLOPS are all values of per GPU.
-
-In this benchmark, we assume the model architectures/sizes of actor and critic are the same for simplicity. But in practice, to reduce training cost, we may use a smaller critic.
diff --git a/applications/Chat/benchmarks/benchmark_gpt_dummy.py b/applications/Chat/benchmarks/benchmark_gpt_dummy.py
deleted file mode 100644
index e41ef239d..000000000
--- a/applications/Chat/benchmarks/benchmark_gpt_dummy.py
+++ /dev/null
@@ -1,186 +0,0 @@
-import argparse
-from copy import deepcopy
-
-import torch
-import torch.distributed as dist
-import torch.nn as nn
-from coati.models.base import RewardModel
-from coati.models.gpt import GPTActor, GPTCritic
-from coati.trainer import PPOTrainer
-from coati.trainer.callbacks import PerformanceEvaluator
-from coati.trainer.strategies import ColossalAIStrategy, DDPStrategy, Strategy
-from torch.optim import Adam
-from transformers.models.gpt2.configuration_gpt2 import GPT2Config
-from transformers.models.gpt2.tokenization_gpt2 import GPT2Tokenizer
-
-from colossalai.nn.optimizer import HybridAdam
-
-
-def get_model_numel(model: nn.Module, strategy: Strategy) -> int:
-    numel = sum(p.numel() for p in model.parameters())
-    if isinstance(strategy, ColossalAIStrategy) and strategy.stage == 3 and strategy.shard_init:
-        numel *= dist.get_world_size()
-    return numel
-
-
-def preprocess_batch(samples) -> dict:
-    input_ids = torch.stack(samples)
-    attention_mask = torch.ones_like(input_ids, dtype=torch.long)
-    return {'input_ids': input_ids, 'attention_mask': attention_mask}
-
-
-def print_rank_0(*args, **kwargs) -> None:
-    if dist.get_rank() == 0:
-        print(*args, **kwargs)
-
-
-def print_model_numel(model_dict: dict) -> None:
-    B = 1024**3
-    M = 1024**2
-    K = 1024
-    outputs = ''
-    for name, numel in model_dict.items():
-        outputs += f'{name}: '
-        if numel >= B:
-            outputs += f'{numel / B:.2f} B\n'
-        elif numel >= M:
-            outputs += f'{numel / M:.2f} M\n'
-        elif numel >= K:
-            outputs += f'{numel / K:.2f} K\n'
-        else:
-            outputs += f'{numel}\n'
-    print_rank_0(outputs)
-
-
-def get_gpt_config(model_name: str) -> GPT2Config:
-    model_map = {
-        's': GPT2Config(),
-        'm': GPT2Config(n_embd=1024, n_layer=24, n_head=16),
-        'l': GPT2Config(n_embd=1280, n_layer=36, n_head=20),
-        'xl': GPT2Config(n_embd=1600, n_layer=48, n_head=25),
-        '2b': GPT2Config(n_embd=2048, n_layer=40, n_head=16),
-        '4b': GPT2Config(n_embd=2304, n_layer=64, n_head=16),
-        '6b': GPT2Config(n_embd=4096, n_layer=30, n_head=16),
-        '8b': GPT2Config(n_embd=4096, n_layer=40, n_head=16),
-        '10b': GPT2Config(n_embd=4096, n_layer=50, n_head=16),
-        '12b': GPT2Config(n_embd=4096, n_layer=60, n_head=16),
-        '15b': GPT2Config(n_embd=4096, n_layer=78, n_head=16),
-        '18b': GPT2Config(n_embd=4096, n_layer=90, n_head=16),
-        '20b': GPT2Config(n_embd=8192, n_layer=25, n_head=16),
-        '24b': GPT2Config(n_embd=8192, n_layer=30, n_head=16),
-        '28b': GPT2Config(n_embd=8192, n_layer=35, n_head=16),
-        '32b': GPT2Config(n_embd=8192, n_layer=40, n_head=16),
-        '36b': GPT2Config(n_embd=8192, n_layer=45, n_head=16),
-        '40b': GPT2Config(n_embd=8192, n_layer=50, n_head=16),
-        '175b': GPT2Config(n_positions=2048, n_embd=12288, n_layer=96, n_head=96),
-    }
-    try:
-        return model_map[model_name]
-    except KeyError:
-        raise ValueError(f'Unknown model "{model_name}"')
-
-
-def main(args):
-    if args.strategy == 'ddp':
-        strategy = DDPStrategy()
-    elif args.strategy == 'colossalai_gemini':
-        strategy = ColossalAIStrategy(stage=3, placement_policy='cuda', initial_scale=2**5)
-    elif args.strategy == 'colossalai_gemini_cpu':
-        strategy = ColossalAIStrategy(stage=3, placement_policy='cpu', initial_scale=2**5)
-    elif args.strategy == 'colossalai_zero2':
-        strategy = ColossalAIStrategy(stage=2, placement_policy='cuda')
-    elif args.strategy == 'colossalai_zero2_cpu':
-        strategy = ColossalAIStrategy(stage=2, placement_policy='cpu')
-    elif args.strategy == 'colossalai_zero1':
-        strategy = ColossalAIStrategy(stage=1, placement_policy='cuda')
-    elif args.strategy == 'colossalai_zero1_cpu':
-        strategy = ColossalAIStrategy(stage=1, placement_policy='cpu')
-    else:
-        raise ValueError(f'Unsupported strategy "{args.strategy}"')
-
-    model_config = get_gpt_config(args.model)
-
-    with strategy.model_init_context():
-        actor = GPTActor(config=model_config).cuda()
-        critic = GPTCritic(config=model_config).cuda()
-
-        initial_model = deepcopy(actor).cuda()
-        reward_model = RewardModel(deepcopy(critic.model), deepcopy(critic.value_head)).cuda()
-
-    actor_numel = get_model_numel(actor, strategy)
-    critic_numel = get_model_numel(critic, strategy)
-    initial_model_numel = get_model_numel(initial_model, strategy)
-    reward_model_numel = get_model_numel(reward_model, strategy)
-    print_model_numel({
-        'Actor': actor_numel,
-        'Critic': critic_numel,
-        'Initial model': initial_model_numel,
-        'Reward model': reward_model_numel
-    })
-    performance_evaluator = PerformanceEvaluator(actor_numel,
-                                                 critic_numel,
-                                                 initial_model_numel,
-                                                 reward_model_numel,
-                                                 enable_grad_checkpoint=False,
-                                                 ignore_episodes=1)
-
-    if args.strategy.startswith('colossalai'):
-        actor_optim = HybridAdam(actor.parameters(), lr=5e-6)
-        critic_optim = HybridAdam(critic.parameters(), lr=5e-6)
-    else:
-        actor_optim = Adam(actor.parameters(), lr=5e-6)
-        critic_optim = Adam(critic.parameters(), lr=5e-6)
-
-    tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
-    tokenizer.pad_token = tokenizer.eos_token
-
-    (actor, actor_optim), (critic, critic_optim), reward_model, initial_model = strategy.prepare(
-        (actor, actor_optim), (critic, critic_optim), reward_model, initial_model)
-
-    trainer = PPOTrainer(strategy,
-                         actor,
-                         critic,
-                         reward_model,
-                         initial_model,
-                         actor_optim,
-                         critic_optim,
-                         max_epochs=args.max_epochs,
-                         train_batch_size=args.train_batch_size,
-                         experience_batch_size=args.experience_batch_size,
-                         tokenizer=preprocess_batch,
-                         max_length=512,
-                         do_sample=True,
-                         temperature=1.0,
-                         top_k=50,
-                         pad_token_id=tokenizer.pad_token_id,
-                         eos_token_id=tokenizer.eos_token_id,
-                         callbacks=[performance_evaluator])
-
-    random_prompts = torch.randint(tokenizer.vocab_size, (1000, 1, 400), device=torch.cuda.current_device())
-    random_attention_mask = torch.randint(1, (1000, 1, 400), device=torch.cuda.current_device()).to(torch.bool)
-    random_pretrain = [{'input_ids':random_prompts[i], 'labels':random_prompts[i], 'attention_mask':random_attention_mask[i]} for i in range(1000)]
-    trainer.fit(random_prompts, random_pretrain,
-                num_episodes=args.num_episodes,
-                max_timesteps=args.max_timesteps,
-                update_timesteps=args.update_timesteps)
-
-    print_rank_0(f'Peak CUDA mem: {torch.cuda.max_memory_allocated()/1024**3:.2f} GB')
-
-
-if __name__ == '__main__':
-    parser = argparse.ArgumentParser()
-    parser.add_argument('--model', default='s')
-    parser.add_argument('--strategy',
-                        choices=[
-                            'ddp', 'colossalai_gemini', 'colossalai_gemini_cpu', 'colossalai_zero2',
-                            'colossalai_zero2_cpu', 'colossalai_zero1', 'colossalai_zero1_cpu'
-                        ],
-                        default='ddp')
-    parser.add_argument('--num_episodes', type=int, default=3)
-    parser.add_argument('--max_timesteps', type=int, default=8)
-    parser.add_argument('--update_timesteps', type=int, default=8)
-    parser.add_argument('--max_epochs', type=int, default=3)
-    parser.add_argument('--train_batch_size', type=int, default=8)
-    parser.add_argument('--experience_batch_size', type=int, default=8)
-    args = parser.parse_args()
-    main(args)
diff --git a/applications/Chat/benchmarks/benchmark_gpt_dummy.sh b/applications/Chat/benchmarks/benchmark_gpt_dummy.sh
deleted file mode 100755
index d70f88725..000000000
--- a/applications/Chat/benchmarks/benchmark_gpt_dummy.sh
+++ /dev/null
@@ -1,45 +0,0 @@
-#!/usr/bin/env bash
-# Usage: $0
-set -xu
-
-BASE=$(realpath $(dirname $0))
-
-
-PY_SCRIPT=${BASE}/benchmark_gpt_dummy.py
-export OMP_NUM_THREADS=8
-
-function tune_batch_size() {
-    # we found when experience batch size is equal to train batch size
-    # peak CUDA memory usage of making experience phase is less than or equal to that of training phase
-    # thus, experience batch size can be larger than or equal to train batch size
-    for bs in 1 2 4 8 16 32 64 128 256; do
-        torchrun --standalone --nproc_per_node $1 $PY_SCRIPT --model $2 --strategy $3 --experience_batch_size $bs --train_batch_size $bs || return 1
-    done
-}
-
-if [ $# -eq 0 ]; then
-    num_gpus=(1 2 4 8)
-else
-    num_gpus=($1)
-fi
-
-if [ $# -le 1 ]; then
-    strategies=("ddp" "colossalai_zero2" "colossalai_gemini" "colossalai_zero2_cpu" "colossalai_gemini_cpu")
-else
-    strategies=($2)
-fi
-
-if [ $# -le 2 ]; then
-    models=("s" "m" "l" "xl" "2b" "4b" "6b" "8b" "10b")
-else
-    models=($3)
-fi
-
-
-for num_gpu in ${num_gpus[@]}; do
-    for strategy in ${strategies[@]}; do
-        for model in ${models[@]}; do
-            tune_batch_size $num_gpu $model $strategy || break
-        done
-    done
-done
diff --git a/applications/Chat/benchmarks/benchmark_opt_lora_dummy.py b/applications/Chat/benchmarks/benchmark_opt_lora_dummy.py
index 7e03b6953..a991e8558 100644
--- a/applications/Chat/benchmarks/benchmark_opt_lora_dummy.py
+++ b/applications/Chat/benchmarks/benchmark_opt_lora_dummy.py
@@ -140,8 +140,7 @@ def main(args):
                          ptx_coef=0,
                          max_epochs=args.max_epochs,
                          train_batch_size=args.train_batch_size,
-                         experience_batch_size=args.experience_batch_size,
-                         tokenizer=preprocess_batch,
+                         offload_inference_models=args.offload_inference_models,
                          max_length=512,
                          do_sample=True,
                          temperature=1.0,
@@ -179,10 +178,11 @@ if __name__ == '__main__':
     parser.add_argument('--num_episodes', type=int, default=3)
     parser.add_argument('--max_timesteps', type=int, default=8)
     parser.add_argument('--update_timesteps', type=int, default=8)
-    parser.add_argument('--max_epochs', type=int, default=3)
+    parser.add_argument('--max_epochs', type=int, default=1)
     parser.add_argument('--train_batch_size', type=int, default=8)
     parser.add_argument('--experience_batch_size', type=int, default=8)
    parser.add_argument('--lora_rank', type=int, default=0)
     parser.add_argument('--cuda_mem_frac', type=float, default=1.0)
+    parser.add_argument('--offload_inference_models', action='store_true', default=False)
     args = parser.parse_args()
     main(args)
diff --git a/applications/Chat/coati/trainer/base.py b/applications/Chat/coati/trainer/base.py
index d67679949..ac3a878be 100644
--- a/applications/Chat/coati/trainer/base.py
+++ b/applications/Chat/coati/trainer/base.py
@@ -15,7 +15,6 @@ class Trainer(ABC):
     Args:
         strategy (Strategy):the strategy to use for training
         max_epochs (int, defaults to 1): the number of epochs of training process
-        tokenizer (Callable, optional): the tokenizer to use for tokenizing the input
         dataloader_pin_memory (bool, defaults to True): whether to pin memory for data loader
         callbacks (List[Callback], defaults to []): the callbacks to call during training process
         generate_kwargs (dict, optional): the kwargs to use while model generating
@@ -24,14 +23,12 @@ class Trainer(ABC):
     def __init__(self,
                  strategy: Strategy,
                  max_epochs: int = 1,
-                 tokenizer: Optional[Callable[[Any], dict]] = None,
                  dataloader_pin_memory: bool = True,
                  callbacks: List[Callback] = [],
                  **generate_kwargs) -> None:
         super().__init__()
         self.strategy = strategy
         self.max_epochs = max_epochs
-        self.tokenizer = tokenizer
         self.generate_kwargs = generate_kwargs
         self.dataloader_pin_memory = dataloader_pin_memory
         self.callbacks = callbacks
diff --git a/applications/Chat/coati/trainer/ppo.py b/applications/Chat/coati/trainer/ppo.py
index b8a9f879b..f9ab4a556 100644
--- a/applications/Chat/coati/trainer/ppo.py
+++ b/applications/Chat/coati/trainer/ppo.py
@@ -4,7 +4,7 @@ import torch
 import torch.nn as nn
 from coati.experience_maker import Experience, NaiveExperienceMaker
 from coati.models.base import Actor, Critic
-from coati.models.loss import PolicyLoss, ValueLoss
+from coati.models.loss import GPTLMLoss, PolicyLoss, ValueLoss
 from coati.replay_buffer import NaiveReplayBuffer
 from torch import Tensor
 from torch.optim import Optimizer
@@ -12,10 +12,12 @@ from torch.utils.data import DistributedSampler
 from tqdm import tqdm
 from transformers.tokenization_utils_base import PreTrainedTokenizerBase
 
+from colossalai.utils import get_current_device
+
 from .base import Trainer
 from .callbacks import Callback
 from .strategies import Strategy
-from .utils import is_rank_0
+from .utils import is_rank_0, to_device
 
 
 class PPOTrainer(Trainer):
@@ -38,11 +40,10 @@
         vf_coef (float, defaults to 1.0): the coefficient of value loss
         ptx_coef (float, defaults to 0.9): the coefficient of ptx loss
         value_clip (float, defaults to 0.4): the clip coefficient of value loss
-        experience_batch_size (int, defaults to 8): the batch size to use for experience generation
         max_epochs (int, defaults to 1): the number of epochs of training process
-        tokenizer (Callable, optional): the tokenizer to use for tokenizing the input
         sample_replay_buffer (bool, defaults to False): whether to sample from replay buffer
         dataloader_pin_memory (bool, defaults to True): whether to pin memory for data loader
+        offload_inference_models (bool, defaults to True): whether to offload inference models to cpu during training process
         callbacks (List[Callback], defaults to []): the callbacks to call during training process
         generate_kwargs (dict, optional): the kwargs to use while model generating
     """
@@ -63,22 +64,21 @@
                  eps_clip: float = 0.2,
                  vf_coef: float = 1.0,
                  value_clip: float = 0.4,
-                 experience_batch_size: int = 8,
                  max_epochs: int = 1,
-                 tokenizer: Optional[Callable[[Any], dict]] = None,
                  sample_replay_buffer: bool = False,
                  dataloader_pin_memory: bool = True,
+                 offload_inference_models: bool = True,
                  callbacks: List[Callback] = [],
                  **generate_kwargs) -> None:
         experience_maker = NaiveExperienceMaker(actor, critic, reward_model, initial_model, kl_coef)
         replay_buffer = NaiveReplayBuffer(train_batch_size, buffer_limit, buffer_cpu_offload)
         generate_kwargs = _set_default_generate_kwargs(strategy, generate_kwargs, actor)
-        super().__init__(strategy, max_epochs, tokenizer, dataloader_pin_memory, callbacks, **generate_kwargs)
+        super().__init__(strategy, max_epochs, dataloader_pin_memory, callbacks, **generate_kwargs)
         self.experience_maker = experience_maker
         self.replay_buffer = replay_buffer
-        self.experience_batch_size = experience_batch_size
         self.sample_replay_buffer = sample_replay_buffer
+        self.offload_inference_models = offload_inference_models
         self.actor = actor
         self.critic = critic
@@ -86,11 +86,13 @@
         self.actor_loss_fn = PolicyLoss(eps_clip)
         self.critic_loss_fn = ValueLoss(value_clip)
         self.vf_coef = vf_coef
-        self.ptx_loss_fn = nn.CrossEntropyLoss(ignore_index=-100)
+        self.ptx_loss_fn = GPTLMLoss()
         self.ptx_coef = ptx_coef
         self.actor_optim = actor_optim
         self.critic_optim = critic_optim
 
+        self.device = get_current_device()
+
     def _make_experience(self, inputs: Union[Tensor, Dict[str, Tensor]]) -> Experience:
         if isinstance(inputs, Tensor):
             return self.experience_maker.make_experience(inputs, **self.generate_kwargs)
@@ -99,20 +101,15 @@
         else:
             raise ValueError(f'Unsupported input type "{type(inputs)}"')
 
-    def _sample_prompts(self, prompts) -> list:
-        indices = list(range(len(prompts)))
-        sampled_indices = self.strategy.experience_sampler.choice(indices, self.experience_batch_size, replace=False)
-        return [prompts[i] for i in sampled_indices]
-
     def _learn(self):
         # replay buffer may be empty at first, we should rebuild at each training
         if not self.sample_replay_buffer:
             dataloader = self.strategy.setup_dataloader(self.replay_buffer, self.dataloader_pin_memory)
-        device = torch.cuda.current_device()
         if self.sample_replay_buffer:
             pbar = tqdm(range(self.max_epochs), desc='Train epoch', disable=not is_rank_0())
             for _ in pbar:
                 experience = self.replay_buffer.sample()
+                experience.to_device(self.device)
                 metrics = self.training_step(experience)
                 pbar.set_postfix(metrics)
         else:
@@ -123,7 +120,7 @@
                 pbar = tqdm(dataloader, desc=f'Train epoch [{epoch+1}/{self.max_epochs}]', disable=not is_rank_0())
                 for experience in pbar:
                     self._on_learn_batch_start()
-                    experience.to_device(device)
+                    experience.to_device(self.device)
                     metrics = self.training_step(experience)
                     self._on_learn_batch_end(metrics, experience)
                     pbar.set_postfix(metrics)
@@ -147,14 +144,17 @@
                 time += 1
                 prompts = next(iter(self.prompt_dataloader))
                 self._on_make_experience_start()
-                self.experience_maker.initial_model.to(torch.cuda.current_device())
-                self.experience_maker.reward_model.to(torch.cuda.current_device())
+                if self.offload_inference_models:
+                    # TODO(ver217): this may be controlled by strategy if they are prepared by strategy
+                    self.experience_maker.initial_model.to(self.device)
+                    self.experience_maker.reward_model.to(self.device)
                 experience = self._make_experience(prompts)
                 self._on_make_experience_end(experience)
                 self.replay_buffer.append(experience)
                 if time % update_timesteps == 0:
-                    self.experience_maker.initial_model.to('cpu')
-                    self.experience_maker.reward_model.to('cpu')
+                    if self.offload_inference_models:
+                        self.experience_maker.initial_model.to('cpu')
+                        self.experience_maker.reward_model.to('cpu')
                     self._learn()
                     self.replay_buffer.clear()
             self._on_episode_end(episode)
@@ -174,11 +174,10 @@
         # ptx loss
         if self.ptx_coef != 0:
             batch = next(iter(self.pretrain_dataloader))
-            ptx = batch['input_ids'].to(torch.cuda.current_device())
-            label = batch['labels'].to(torch.cuda.current_device())[:, 1:]
-            attention_mask = batch['attention_mask'].to(torch.cuda.current_device())
-            ptx_log_probs = self.actor.get_base_model()(ptx, attention_mask=attention_mask)['logits'][..., :-1, :]
-            ptx_loss = self.ptx_loss_fn(ptx_log_probs.view(-1, ptx_log_probs.size(-1)), label.view(-1))
+            batch = to_device(batch, self.device)
+            ptx_log_probs = self.actor.get_base_model()(batch['input_ids'],
+                                                        attention_mask=batch['attention_mask'])['logits']
+            ptx_loss = self.ptx_loss_fn(ptx_log_probs, batch['labels'])
             actor_loss = ptx_loss * self.ptx_coef + actor_loss * (1 - self.ptx_coef)
 
         self.strategy.backward(actor_loss, self.actor, self.actor_optim)
diff --git a/applications/Chat/coati/trainer/sft.py b/applications/Chat/coati/trainer/sft.py
index 350553108..ceb7a0574 100644
--- a/applications/Chat/coati/trainer/sft.py
+++ b/applications/Chat/coati/trainer/sft.py
@@ -1,6 +1,6 @@
 import math
 import time
-from typing import Optional, List
+from typing import List, Optional
 
 import loralib as lora
 import torch
@@ -18,8 +18,8 @@ from transformers.trainer import get_scheduler
 
 from colossalai.logging import get_dist_logger
 
-from .callbacks import Callback
 from .base import Trainer
+from .callbacks import Callback
 from .strategies import Strategy
 from .utils import is_rank_0
 
@@ -70,9 +70,10 @@ class SFTTrainer(Trainer):
                                        num_warmup_steps=math.ceil(max_steps * 0.03),
                                        num_training_steps=max_steps)
 
-    def fit(self, logger, log_interval=10):
-        wandb.init(project="Coati", name=time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()))
-        wandb.watch(self.model)
+    def fit(self, logger, use_wandb: bool = False):
+        if use_wandb:
+            wandb.init(project="Coati", name=time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()))
+            wandb.watch(self.model)
         total_loss = 0
         # epoch_bar = tqdm(range(self.epochs), desc='Epochs', disable=not is_rank_0())
         step_bar = tqdm(range(len(self.train_dataloader) // self.accimulation_steps * self.max_epochs),
@@ -111,7 +112,7 @@ class SFTTrainer(Trainer):
                     self.strategy.optimizer_step(self.optimizer)
                     self.optimizer.zero_grad()
                    self.scheduler.step()
-                    if is_rank_0():
+                    if is_rank_0() and use_wandb:
                        wandb.log({
                            "loss": total_loss / self.accimulation_steps,
                            "lr": self.scheduler.get_last_lr()[0],
diff --git a/applications/Chat/coati/trainer/utils.py b/applications/Chat/coati/trainer/utils.py
index 1b17a0421..9cccb5c92 100644
--- a/applications/Chat/coati/trainer/utils.py
+++ b/applications/Chat/coati/trainer/utils.py
@@ -1,14 +1,19 @@
-import torch.distributed as dist
-from typing import Any, Callable, Dict, List, Optional
-from coati.models.bloom import BLOOMActor, BLOOMCritic
-from coati.models.gpt import GPTActor, GPTCritic
-from coati.models.opt import OPTActor, OPTCritic
-from coati.trainer.strategies import ColossalAIStrategy, DDPStrategy, NaiveStrategy
+from typing import Any
+
 import torch
-import os
+import torch.distributed as dist
+from torch.utils._pytree import tree_map
 
 
 def is_rank_0() -> bool:
     return not dist.is_initialized() or dist.get_rank() == 0
+
+
+def to_device(x: Any, device: torch.device) -> Any:
+
+    def _to(t: Any):
+        if isinstance(t, torch.Tensor):
+            return t.to(device)
+        return t
+
+    return tree_map(_to, x)
diff --git a/applications/Chat/examples/train_dummy.py b/applications/Chat/examples/train_dummy.py
deleted file mode 100644
index 5f34c80f0..000000000
--- a/applications/Chat/examples/train_dummy.py
+++ /dev/null
@@ -1,156 +0,0 @@
-import argparse
-from copy import deepcopy
-
-import torch
-from coati.models.base import RewardModel
-from coati.models.bloom import BLOOMActor, BLOOMCritic
-from coati.models.gpt import GPTActor, GPTCritic
-from coati.models.opt import OPTActor, OPTCritic
-from coati.models.roberta import RoBERTaActor, RoBERTaCritic
-from coati.trainer import PPOTrainer
-from coati.trainer.callbacks import SaveCheckpoint
-from coati.trainer.strategies import ColossalAIStrategy, DDPStrategy, NaiveStrategy
-from torch.optim import Adam
-from transformers import AutoTokenizer, BloomTokenizerFast, RobertaTokenizer
-from transformers.models.gpt2.tokenization_gpt2 import GPT2Tokenizer
-
-from colossalai.nn.optimizer import HybridAdam
-
-
-def preprocess_batch(samples):
-    input_ids = torch.stack(samples)
-    attention_mask = torch.ones_like(input_ids, dtype=torch.long)
-    return {'input_ids': input_ids, 'attention_mask': attention_mask}
-
-
-def main(args):
-    # configure strategy
-    if args.strategy == 'naive':
-        strategy = NaiveStrategy()
-    elif args.strategy == 'ddp':
-        strategy = DDPStrategy()
-    elif args.strategy == 'colossalai_gemini':
-        strategy = ColossalAIStrategy(stage=3, placement_policy='cuda', initial_scale=2**5)
-    elif args.strategy == 'colossalai_zero2':
-        strategy = ColossalAIStrategy(stage=2, placement_policy='cuda')
-    else:
-        raise ValueError(f'Unsupported strategy "{args.strategy}"')
-
-    # configure model
-    with strategy.model_init_context():
-        if args.model == 'gpt2':
-            actor = GPTActor(pretrained=args.pretrain, lora_rank=args.lora_rank).to(torch.cuda.current_device())
-            critic = GPTCritic(pretrained=args.pretrain, lora_rank=args.lora_rank).to(torch.cuda.current_device())
-        elif args.model == 'bloom':
-            actor = BLOOMActor(pretrained=args.pretrain, lora_rank=args.lora_rank).to(torch.cuda.current_device())
-            critic = BLOOMCritic(pretrained=args.pretrain, lora_rank=args.lora_rank).to(torch.cuda.current_device())
-        elif args.model == 'opt':
-            actor = OPTActor(pretrained=args.pretrain, lora_rank=args.lora_rank).to(torch.cuda.current_device())
-            critic = OPTCritic(pretrained=args.pretrain, lora_rank=args.lora_rank).to(torch.cuda.current_device())
-        elif args.model == 'roberta':
-            actor = RoBERTaActor(pretrained=args.pretrain, lora_rank=args.lora_rank).to(torch.cuda.current_device())
-            critic = RoBERTaCritic(pretrained=args.pretrain, lora_rank=args.lora_rank).to(torch.cuda.current_device())
-        else:
-            raise ValueError(f'Unsupported model "{args.model}"')
-
-        initial_model = deepcopy(actor).to(torch.cuda.current_device())
-        reward_model = RewardModel(deepcopy(critic.model), deepcopy(critic.value_head)).to(torch.cuda.current_device())
-
-    # configure optimizer
-    if args.strategy.startswith('colossalai'):
-        actor_optim = HybridAdam(actor.parameters(), lr=5e-6)
-        critic_optim = HybridAdam(critic.parameters(), lr=5e-6)
-    else:
-        actor_optim = Adam(actor.parameters(), lr=5e-6)
-        critic_optim = Adam(critic.parameters(), lr=5e-6)
-
-    # configure tokenizer
-    if args.model == 'gpt2':
-        tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
-        tokenizer.pad_token = tokenizer.eos_token
-    elif args.model == 'bloom':
-        tokenizer = BloomTokenizerFast.from_pretrained(args.pretrain)
-        tokenizer.pad_token = tokenizer.eos_token
-    elif args.model == 'opt':
-        tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m")
-    elif args.model == 'roberta':
-        tokenizer = RobertaTokenizer.from_pretrained("roberta-base")
-    else:
-        raise ValueError(f'Unsupported model "{args.model}"')
-
-    (actor, actor_optim), (critic, critic_optim), reward_model, initial_model = strategy.prepare(
-        (actor, actor_optim), (critic, critic_optim), reward_model, initial_model)
-
-    callbacks = []
-    if args.save_ckpt_path:
-        ckpt_callback = SaveCheckpoint(
-            args.save_ckpt_path,
-            args.save_ckpt_interval,
-            strategy,
-            actor,
-            critic,
-            actor_optim,
-            critic_optim,
-        )
-        callbacks.append(ckpt_callback)
-
-    # configure trainer
-
-    trainer = PPOTrainer(strategy,
-                         actor,
-                         critic,
-                         reward_model,
-                         initial_model,
-                         actor_optim,
-                         critic_optim,
-                         max_epochs=args.max_epochs,
-                         train_batch_size=args.train_batch_size,
-                         tokenizer=preprocess_batch,
-                         max_length=128,
-                         do_sample=True,
-                         temperature=1.0,
-                         top_k=50,
-                         pad_token_id=tokenizer.pad_token_id,
-                         eos_token_id=tokenizer.eos_token_id,
-                         callbacks=callbacks)
-
-    random_prompts = torch.randint(tokenizer.vocab_size, (1000, 1, 64), device=torch.cuda.current_device())
-    random_attention_mask = torch.randint(1, (1000, 1, 64), device=torch.cuda.current_device()).to(torch.bool)
-    random_pretrain = [{'input_ids':random_prompts[i], 'labels':random_prompts[i], 'attention_mask':random_attention_mask[i]} for i in range(1000)]
-    trainer.fit(random_prompts, random_pretrain,
-                num_episodes=args.num_episodes,
-                max_timesteps=args.max_timesteps,
-                update_timesteps=args.update_timesteps)
-
-    # save model checkpoint after fitting
-    trainer.save_model(args.save_path, only_rank0=True)
-    # save optimizer checkpoint on all ranks
-    if args.need_optim_ckpt:
-        strategy.save_optimizer(actor_optim,
-                                'actor_optim_checkpoint_dummy_%d.pt' % (torch.cuda.current_device()),
-                                only_rank0=False)
-
-
-if __name__ == '__main__':
-    parser = argparse.ArgumentParser()
-    parser.add_argument('--strategy',
-                        choices=['naive', 'ddp', 'colossalai_gemini', 'colossalai_zero2'],
-                        default='naive')
-    parser.add_argument('--model', type=str, default='gpt2', choices=['gpt2', 'bloom', 'opt', 'roberta'])
-    parser.add_argument('--pretrain', type=str, default=None)
-    parser.add_argument('--save_path', type=str, default='actor_checkpoint_dummy')
-    parser.add_argument('--need_optim_ckpt', type=bool, default=False)
-    parser.add_argument('--num_episodes', type=int, default=50)
-    parser.add_argument('--max_timesteps', type=int, default=10)
-    parser.add_argument('--update_timesteps', type=int, default=10)
-    parser.add_argument('--max_epochs', type=int, default=5)
-    parser.add_argument('--train_batch_size', type=int, default=8)
-    parser.add_argument('--experience_batch_size', type=int, default=8)
-    parser.add_argument('--lora_rank', type=int, default=0, help="low-rank adaptation matrices rank")
-    parser.add_argument('--save_ckpt_path',
-                        type=str,
-                        default=None,
-                        help="path to save checkpoint, None means not to save")
-    parser.add_argument('--save_ckpt_interval', type=int, default=1, help="the interval of episode to save checkpoint")
-    args = parser.parse_args()
-    main(args)
diff --git a/applications/Chat/examples/train_dummy.sh b/applications/Chat/examples/train_dummy.sh
deleted file mode 100755
index 595da573e..000000000
--- a/applications/Chat/examples/train_dummy.sh
+++ /dev/null
@@ -1,18 +0,0 @@
-set_n_least_used_CUDA_VISIBLE_DEVICES() {
-    local n=${1:-"9999"}
-    echo "GPU Memory Usage:"
-    local FIRST_N_GPU_IDS=$(nvidia-smi --query-gpu=memory.used --format=csv \
-        | tail -n +2 \
-        | nl -v 0 \
-        | tee /dev/tty \
-        | sort -g -k 2 \
-        | awk '{print $1}' \
-        | head -n $n)
-    export CUDA_VISIBLE_DEVICES=$(echo $FIRST_N_GPU_IDS | sed 's/ /,/g')
-    echo "Now CUDA_VISIBLE_DEVICES is set to:"
-    echo "CUDA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES"
-}
-
-set_n_least_used_CUDA_VISIBLE_DEVICES 2
-
-torchrun --standalone --nproc_per_node=2 train_dummy.py --strategy colossalai_zero2
diff --git a/applications/Chat/examples/train_prompts.py b/applications/Chat/examples/train_prompts.py
index 2086ff003..c0455f3a7 100644
--- a/applications/Chat/examples/train_prompts.py
+++ b/applications/Chat/examples/train_prompts.py
@@ -71,9 +71,8 @@ def main(args):
     if args.rm_path is not None:
         reward_model.load_state_dict(state_dict)
 
-    if args.strategy != 'colossalai_gemini':
-        initial_model.to(torch.float16).to(torch.cuda.current_device())
-        reward_model.to(torch.float16).to(torch.cuda.current_device())
+    initial_model.to(torch.float16).to(torch.cuda.current_device())
+    reward_model.to(torch.float16).to(torch.cuda.current_device())
 
     with strategy.model_init_context():
         if args.model == 'gpt2':
@@ -148,9 +147,12 @@ def main(args):
     prompt_dataloader = DataLoader(prompt_dataset,
                                    shuffle=(prompt_sampler is None),
                                    sampler=prompt_sampler,
-                                   batch_size=args.train_batch_size)
+                                   batch_size=args.experience_batch_size)
 
-    pretrain_dataset = SupervisedDataset(tokenizer=tokenizer, data_path=args.pretrain_dataset, max_datasets_size=16384)
+    pretrain_dataset = SupervisedDataset(tokenizer=tokenizer,
+                                         data_path=args.pretrain_dataset,
+                                         max_datasets_size=16384,
+                                         max_length=args.max_input_len)
     if dist.is_initialized() and dist.get_world_size() > 1:
         pretrain_sampler = DistributedSampler(pretrain_dataset, shuffle=True, seed=42, drop_last=True)
     else:
@@ -161,12 +163,6 @@
                                          batch_size=args.ptx_batch_size,
                                          collate_fn=data_collator)
 
-    def tokenize_fn(texts):
-        # MUST padding to max length to ensure inputs of all ranks have the same length
-        # Different length may lead to hang when using gemini, as different generation steps
-        batch = tokenizer(texts, return_tensors='pt', max_length=96, padding='max_length', truncation=True)
-        return {k: v.to(torch.cuda.current_device()) for k, v in batch.items()}
-
     (actor, actor_optim), (critic, critic_optim) = strategy.prepare((actor, actor_optim), (critic, critic_optim))
 
     # configure trainer
@@ -182,9 +178,8 @@
                          ptx_coef=args.ptx_coef,
                          max_epochs=args.max_epochs,
                          train_batch_size=args.train_batch_size,
-                         experience_batch_size=args.experience_batch_size,
-                         tokenizer=tokenize_fn,
-                         max_length=128,
+                         max_length=args.max_seq_len,
+                         use_cache=True,
                          do_sample=True,
                          temperature=1.0,
                         top_k=50,
@@ -232,5 +227,7 @@
     parser.add_argument('--lora_rank', type=int, default=0, help="low-rank adaptation matrices rank")
     parser.add_argument('--kl_coef', type=float, default=0.1)
     parser.add_argument('--ptx_coef', type=float, default=0.9)
+    parser.add_argument('--max_input_len', type=int, default=96)
+    parser.add_argument('--max_seq_len', type=int, default=128)
     args = parser.parse_args()
     main(args)
diff --git a/applications/Chat/examples/train_sft.py b/applications/Chat/examples/train_sft.py
index d7502c23b..d08cf7786 100644
--- a/applications/Chat/examples/train_sft.py
+++ b/applications/Chat/examples/train_sft.py
@@ -156,7 +156,7 @@ def train(args):
                          max_epochs=args.max_epochs,
                          accimulation_steps=args.accimulation_steps)
 
-    trainer.fit(logger=logger, log_interval=args.log_interval)
+    trainer.fit(logger=logger, use_wandb=args.use_wandb)
     # save model checkpoint after fitting on only rank0
     trainer.save_model(path=args.save_path, only_rank0=True, tokenizer=tokenizer)
 
@@ -185,5 +185,6 @@ if __name__ == '__main__':
     parser.add_argument('--log_interval', type=int, default=100, help="how many steps to log")
     parser.add_argument('--lr', type=float, default=5e-6)
     parser.add_argument('--accimulation_steps', type=int, default=8)
+    parser.add_argument('--use_wandb', default=False, action='store_true')
     args = parser.parse_args()
     train(args)
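For reference, the `to_device` helper added in `applications/Chat/coati/trainer/utils.py` is what lets the PPO trainer move replay-buffer experiences and pretrain batches onto `self.device` without writing a `.to(...)` call per field: `torch.utils._pytree.tree_map` walks any nested dict/list/tuple and moves only the tensor leaves. The standalone sketch below mirrors that helper on a dummy batch; the sample tensors, the vocabulary size, and the device-selection line are illustrative assumptions, not part of the patch.

```python
# Standalone sketch of the to_device helper introduced in coati/trainer/utils.py.
# The helper body mirrors the patched file; the demo batch below is made up.
from typing import Any

import torch
from torch.utils._pytree import tree_map


def to_device(x: Any, device: torch.device) -> Any:
    # Recursively move every tensor leaf of a nested structure to `device`,
    # leaving non-tensor leaves (ints, strings, None, ...) untouched.
    def _to(t: Any):
        if isinstance(t, torch.Tensor):
            return t.to(device)
        return t

    return tree_map(_to, x)


if __name__ == '__main__':
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    batch = {
        'input_ids': torch.randint(0, 50257, (2, 8)),    # assumed GPT-2-sized vocab, for illustration only
        'labels': torch.randint(0, 50257, (2, 8)),
        'attention_mask': torch.ones(2, 8, dtype=torch.long),
    }
    batch = to_device(batch, device)
    print({k: v.device for k, v in batch.items()})
```

Because `tree_map` recurses through standard containers (dicts, lists, tuples), one call handles a whole batch regardless of how it is nested, which is why the trainer can simply call `to_device(batch, self.device)` before computing the ptx loss.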