diff --git a/.github/workflows/run_chatgpt_examples.yml b/.github/workflows/run_chatgpt_examples.yml
index 9d7c1ff99..af59c8db2 100644
--- a/.github/workflows/run_chatgpt_examples.yml
+++ b/.github/workflows/run_chatgpt_examples.yml
@@ -34,6 +34,7 @@ jobs:
 
       - name: Execute Examples
         run: |
+          cd applications/ChatGPT
           ./examples/test_ci.sh
         env:
           NCCL_SHM_DISABLE: 1
diff --git a/.github/workflows/run_chatgpt_unit_tests.yml b/.github/workflows/run_chatgpt_unit_tests.yml
index 3ac0d2d8c..8dcf21fe2 100644
--- a/.github/workflows/run_chatgpt_unit_tests.yml
+++ b/.github/workflows/run_chatgpt_unit_tests.yml
@@ -35,6 +35,7 @@ jobs:
 
      - name: Execute Unit Testing
         run: |
+          cd applications/ChatGPT
           pytest tests/
         env:
           NCCL_SHM_DISABLE: 1
diff --git a/applications/ChatGPT/benchmarks/benchmark_gpt_dummy.py b/applications/ChatGPT/benchmarks/benchmark_gpt_dummy.py
index 8474f3ba7..3e66e4e7a 100644
--- a/applications/ChatGPT/benchmarks/benchmark_gpt_dummy.py
+++ b/applications/ChatGPT/benchmarks/benchmark_gpt_dummy.py
@@ -5,7 +5,6 @@ import torch
 import torch.distributed as dist
 import torch.nn as nn
 from chatgpt.nn import GPTActor, GPTCritic, RewardModel
-from chatgpt.nn.generation_utils import gpt_prepare_inputs_fn, update_model_kwargs_fn
 from chatgpt.trainer import PPOTrainer
 from chatgpt.trainer.callbacks import PerformanceEvaluator
 from chatgpt.trainer.strategies import ColossalAIStrategy, DDPStrategy, Strategy
@@ -151,8 +150,6 @@ def main(args):
                          top_k=50,
                          pad_token_id=tokenizer.pad_token_id,
                          eos_token_id=tokenizer.eos_token_id,
-                         prepare_inputs_fn=gpt_prepare_inputs_fn,
-                         update_model_kwargs_fn=update_model_kwargs_fn,
                          callbacks=[performance_evaluator])
 
     random_prompts = torch.randint(tokenizer.vocab_size, (1000, 400), device=torch.cuda.current_device())
diff --git a/applications/ChatGPT/benchmarks/benchmark_opt_lora_dummy.py b/applications/ChatGPT/benchmarks/benchmark_opt_lora_dummy.py
index accbc4155..8cee5489e 100644
--- a/applications/ChatGPT/benchmarks/benchmark_opt_lora_dummy.py
+++ b/applications/ChatGPT/benchmarks/benchmark_opt_lora_dummy.py
@@ -5,7 +5,6 @@ import torch
 import torch.distributed as dist
 import torch.nn as nn
 from chatgpt.nn import OPTActor, OPTCritic, RewardModel
-from chatgpt.nn.generation_utils import opt_prepare_inputs_fn, update_model_kwargs_fn
 from chatgpt.trainer import PPOTrainer
 from chatgpt.trainer.callbacks import PerformanceEvaluator
 from chatgpt.trainer.strategies import ColossalAIStrategy, DDPStrategy, Strategy
@@ -144,8 +143,6 @@ def main(args):
                          top_k=50,
                          pad_token_id=tokenizer.pad_token_id,
                          eos_token_id=tokenizer.eos_token_id,
-                         prepare_inputs_fn=opt_prepare_inputs_fn,
-                         update_model_kwargs_fn=update_model_kwargs_fn,
                          callbacks=[performance_evaluator])
 
     random_prompts = torch.randint(tokenizer.vocab_size, (1000, 400), device=torch.cuda.current_device())
diff --git a/applications/ChatGPT/chatgpt/trainer/ppo.py b/applications/ChatGPT/chatgpt/trainer/ppo.py
index 85beb223e..b1d11b224 100644
--- a/applications/ChatGPT/chatgpt/trainer/ppo.py
+++ b/applications/ChatGPT/chatgpt/trainer/ppo.py
@@ -3,6 +3,7 @@ from typing import Any, Callable, Dict, List, Optional
 import torch.nn as nn
 from chatgpt.experience_maker import Experience, NaiveExperienceMaker
 from chatgpt.nn import Actor, Critic, PolicyLoss, ValueLoss
+from chatgpt.nn.generation_utils import update_model_kwargs_fn
 from chatgpt.replay_buffer import NaiveReplayBuffer
 from torch.optim import Optimizer
 
@@ -59,6 +60,7 @@ class PPOTrainer(Trainer):
                  dataloader_pin_memory: bool = True,
                  callbacks: List[Callback] = [],
                  **generate_kwargs) -> None:
+        self._set_default_generate_kwargs(generate_kwargs, actor)
         actor = Actor(strategy.setup_model(actor.model))
         critic = strategy.setup_model(critic)
         reward_model = strategy.setup_model(reward_model)
@@ -102,3 +104,11 @@ class PPOTrainer(Trainer):
         self.critic_optim.zero_grad()
 
         return {'actor_loss': actor_loss.item(), 'critic_loss': critic_loss.item()}
+
+    def _set_default_generate_kwargs(self, generate_kwargs: dict, actor: Actor) -> None:
+        # use huggingface models method directly
+        if 'prepare_inputs_fn' not in generate_kwargs and hasattr(actor.model, 'prepare_inputs_for_generation'):
+            generate_kwargs['prepare_inputs_fn'] = actor.model.prepare_inputs_for_generation
+
+        if 'update_model_kwargs_fn' not in generate_kwargs:
+            generate_kwargs['update_model_kwargs_fn'] = update_model_kwargs_fn
diff --git a/applications/ChatGPT/examples/train_dummy.py b/applications/ChatGPT/examples/train_dummy.py
index 313be2c3b..a14117ed5 100644
--- a/applications/ChatGPT/examples/train_dummy.py
+++ b/applications/ChatGPT/examples/train_dummy.py
@@ -3,12 +3,6 @@ from copy import deepcopy
 
 import torch
 from chatgpt.nn import BLOOMActor, BLOOMCritic, GPTActor, GPTCritic, OPTActor, OPTCritic, RewardModel
-from chatgpt.nn.generation_utils import (
-    bloom_prepare_inputs_fn,
-    gpt_prepare_inputs_fn,
-    opt_prepare_inputs_fn,
-    update_model_kwargs_fn,
-)
 from chatgpt.trainer import PPOTrainer
 from chatgpt.trainer.strategies import ColossalAIStrategy, DDPStrategy, NaiveStrategy
 from torch.optim import Adam
@@ -66,36 +60,33 @@ def main(args):
     if args.model == 'gpt2':
         tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
         tokenizer.pad_token = tokenizer.eos_token
-        prepare_inputs_fn = gpt_prepare_inputs_fn
     elif args.model == 'bloom':
         tokenizer = BloomTokenizerFast.from_pretrained(args.pretrain)
         tokenizer.pad_token = tokenizer.eos_token
-        prepare_inputs_fn = bloom_prepare_inputs_fn
     elif args.model == 'opt':
         tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m")
-        prepare_inputs_fn = opt_prepare_inputs_fn
     else:
         raise ValueError(f'Unsupported model "{args.model}"')
 
     # configure trainer
-    trainer = PPOTrainer(strategy,
-                         actor,
-                         critic,
-                         reward_model,
-                         initial_model,
-                         actor_optim,
-                         critic_optim,
-                         max_epochs=args.max_epochs,
-                         train_batch_size=args.train_batch_size,
-                         tokenizer=preprocess_batch,
-                         max_length=128,
-                         do_sample=True,
-                         temperature=1.0,
-                         top_k=50,
-                         pad_token_id=tokenizer.pad_token_id,
-                         eos_token_id=tokenizer.eos_token_id,
-                         prepare_inputs_fn=prepare_inputs_fn,
-                         update_model_kwargs_fn=update_model_kwargs_fn)
+    trainer = PPOTrainer(
+        strategy,
+        actor,
+        critic,
+        reward_model,
+        initial_model,
+        actor_optim,
+        critic_optim,
+        max_epochs=args.max_epochs,
+        train_batch_size=args.train_batch_size,
+        tokenizer=preprocess_batch,
+        max_length=128,
+        do_sample=True,
+        temperature=1.0,
+        top_k=50,
+        pad_token_id=tokenizer.pad_token_id,
+        eos_token_id=tokenizer.eos_token_id,
+    )
 
     random_prompts = torch.randint(tokenizer.vocab_size, (1000, 64), device=torch.cuda.current_device())
     trainer.fit(random_prompts,
diff --git a/applications/ChatGPT/examples/train_prompts.py b/applications/ChatGPT/examples/train_prompts.py
index 994b10fe0..cf351b91a 100644
--- a/applications/ChatGPT/examples/train_prompts.py
+++ b/applications/ChatGPT/examples/train_prompts.py
@@ -3,7 +3,6 @@ from copy import deepcopy
 
 import pandas as pd
 from chatgpt.nn import BLOOMActor, BLOOMCritic, GPTActor, GPTCritic, OPTActor, OPTCritic, RewardModel
-from chatgpt.nn.generation_utils import gpt_prepare_inputs_fn, update_model_kwargs_fn
 from chatgpt.trainer import PPOTrainer
 from chatgpt.trainer.strategies import ColossalAIStrategy, DDPStrategy, NaiveStrategy
 from torch.optim import Adam
@@ -70,24 +69,24 @@ def main(args):
         return {k: v.cuda() for k, v in batch.items()}
 
     # configure trainer
-    trainer = PPOTrainer(strategy,
-                         actor,
-                         critic,
-                         reward_model,
-                         initial_model,
-                         actor_optim,
-                         critic_optim,
-                         max_epochs=args.max_epochs,
-                         train_batch_size=args.train_batch_size,
-                         tokenizer=tokenize_fn,
-                         max_length=128,
-                         do_sample=True,
-                         temperature=1.0,
-                         top_k=50,
-                         pad_token_id=tokenizer.pad_token_id,
-                         eos_token_id=tokenizer.eos_token_id,
-                         prepare_inputs_fn=gpt_prepare_inputs_fn,
-                         update_model_kwargs_fn=update_model_kwargs_fn)
+    trainer = PPOTrainer(
+        strategy,
+        actor,
+        critic,
+        reward_model,
+        initial_model,
+        actor_optim,
+        critic_optim,
+        max_epochs=args.max_epochs,
+        train_batch_size=args.train_batch_size,
+        tokenizer=tokenize_fn,
+        max_length=128,
+        do_sample=True,
+        temperature=1.0,
+        top_k=50,
+        pad_token_id=tokenizer.pad_token_id,
+        eos_token_id=tokenizer.eos_token_id,
+    )
 
     trainer.fit(dataset,
                 num_episodes=args.num_episodes,