diff --git a/applications/ChatGPT/benchmarks/benchmark_gpt_dummy.py b/applications/ChatGPT/benchmarks/benchmark_gpt_dummy.py
index 3e66e4e7a..b5730c7c7 100644
--- a/applications/ChatGPT/benchmarks/benchmark_gpt_dummy.py
+++ b/applications/ChatGPT/benchmarks/benchmark_gpt_dummy.py
@@ -133,6 +133,9 @@ def main(args):
     tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
     tokenizer.pad_token = tokenizer.eos_token
 
+    (actor, actor_optim), (critic, critic_optim), reward_model, initial_model = strategy.prepare(
+        (actor, actor_optim), (critic, critic_optim), reward_model, initial_model)
+
     trainer = PPOTrainer(strategy,
                          actor,
                          critic,
diff --git a/applications/ChatGPT/benchmarks/benchmark_opt_lora_dummy.py b/applications/ChatGPT/benchmarks/benchmark_opt_lora_dummy.py
index 8cee5489e..6777cb770 100644
--- a/applications/ChatGPT/benchmarks/benchmark_opt_lora_dummy.py
+++ b/applications/ChatGPT/benchmarks/benchmark_opt_lora_dummy.py
@@ -126,6 +126,9 @@ def main(args):
     tokenizer = AutoTokenizer.from_pretrained('facebook/opt-350m')
     tokenizer.pad_token = tokenizer.eos_token
 
+    (actor, actor_optim), (critic, critic_optim), reward_model, initial_model = strategy.prepare(
+        (actor, actor_optim), (critic, critic_optim), reward_model, initial_model)
+
     trainer = PPOTrainer(strategy,
                          actor,
                          critic,
diff --git a/applications/ChatGPT/chatgpt/trainer/ppo.py b/applications/ChatGPT/chatgpt/trainer/ppo.py
index b1d11b224..2c1fd2fb6 100644
--- a/applications/ChatGPT/chatgpt/trainer/ppo.py
+++ b/applications/ChatGPT/chatgpt/trainer/ppo.py
@@ -60,11 +60,6 @@ class PPOTrainer(Trainer):
                  dataloader_pin_memory: bool = True,
                  callbacks: List[Callback] = [],
                  **generate_kwargs) -> None:
-        self._set_default_generate_kwargs(generate_kwargs, actor)
-        actor = Actor(strategy.setup_model(actor.model))
-        critic = strategy.setup_model(critic)
-        reward_model = strategy.setup_model(reward_model)
-        initial_model = Actor(strategy.setup_model(initial_model.model))
         experience_maker = NaiveExperienceMaker(actor, critic, reward_model, initial_model, kl_coef)
         replay_buffer = NaiveReplayBuffer(train_batch_size, buffer_limit, buffer_cpu_offload)
         super().__init__(strategy, experience_maker, replay_buffer, experience_batch_size, max_epochs, tokenizer,
@@ -75,8 +70,9 @@ class PPOTrainer(Trainer):
         self.actor_loss_fn = PolicyLoss(eps_clip)
         self.critic_loss_fn = ValueLoss(value_clip)
 
-        self.actor_optim = strategy.setup_optimizer(actor_optim, self.actor.model)
-        self.critic_optim = strategy.setup_optimizer(critic_optim, self.critic)
+        self.actor_optim = actor_optim
+        self.critic_optim = critic_optim
+        self._set_default_generate_kwargs(generate_kwargs, actor)
 
     def training_step(self, experience: Experience) -> Dict[str, float]:
         self.actor.train()
@@ -106,9 +102,10 @@ class PPOTrainer(Trainer):
         return {'actor_loss': actor_loss.item(), 'critic_loss': critic_loss.item()}
 
     def _set_default_generate_kwargs(self, generate_kwargs: dict, actor: Actor) -> None:
+        origin_model = self.strategy._unwrap_actor(actor)
         # use huggingface models method directly
-        if 'prepare_inputs_fn' not in generate_kwargs and hasattr(actor.model, 'prepare_inputs_for_generation'):
-            generate_kwargs['prepare_inputs_fn'] = actor.model.prepare_inputs_for_generation
+        if 'prepare_inputs_fn' not in generate_kwargs and hasattr(origin_model, 'prepare_inputs_for_generation'):
+            generate_kwargs['prepare_inputs_fn'] = origin_model.prepare_inputs_for_generation
 
         if 'update_model_kwargs_fn' not in generate_kwargs:
             generate_kwargs['update_model_kwargs_fn'] = update_model_kwargs_fn
diff --git a/applications/ChatGPT/chatgpt/trainer/strategies/base.py b/applications/ChatGPT/chatgpt/trainer/strategies/base.py
index 3a2923b8c..2c6aefcd9 100644
--- a/applications/ChatGPT/chatgpt/trainer/strategies/base.py
+++ b/applications/ChatGPT/chatgpt/trainer/strategies/base.py
@@ -1,12 +1,17 @@
 from abc import ABC, abstractmethod
 from contextlib import nullcontext
+from typing import Any, List, Tuple, Union
 
 import torch
 import torch.nn as nn
-import torch.optim as optim
+from chatgpt.nn import Actor, Critic, RewardModel
 from chatgpt.replay_buffer import ReplayBuffer
+from torch.optim import Optimizer
 from torch.utils.data import DataLoader
 
+ModelOptimPair = Tuple[nn.Module, Optimizer]
+ModelOrModelOptimPair = Union[nn.Module, ModelOptimPair]
+
 
 class Strategy(ABC):
     """
@@ -18,11 +23,11 @@
         self.setup_distributed()
 
     @abstractmethod
-    def backward(self, loss: torch.Tensor, model: nn.Module, optimizer: optim.Optimizer, **kwargs) -> None:
+    def backward(self, loss: torch.Tensor, model: nn.Module, optimizer: Optimizer, **kwargs) -> None:
         pass
 
     @abstractmethod
-    def optimizer_step(self, optimizer: optim.Optimizer, **kwargs) -> None:
+    def optimizer_step(self, optimizer: Optimizer, **kwargs) -> None:
         pass
 
     @abstractmethod
@@ -34,7 +39,7 @@
         pass
 
     @abstractmethod
-    def setup_optimizer(self, optimizer: optim.Optimizer, model: nn.Module) -> optim.Optimizer:
+    def setup_optimizer(self, optimizer: Optimizer, model: nn.Module) -> Optimizer:
         pass
 
     @abstractmethod
@@ -43,3 +48,78 @@
 
     def model_init_context(self):
         return nullcontext()
+
+    def prepare(
+        self, *models_or_model_optim_pairs: ModelOrModelOptimPair
+    ) -> Union[List[ModelOrModelOptimPair], ModelOrModelOptimPair]:
+        """Prepare models or model-optimizer-pairs based on each strategy.
+
+        Example::
+            >>> # when fine-tuning actor and critic
+            >>> (actor, actor_optim), (critic, critic_optim), reward_model, initial_model = strategy.prepare((actor, actor_optim), (critic, critic_optim), reward_model, initial_model)
+            >>> # or when training reward model
+            >>> (reward_model, reward_model_optim) = strategy.prepare((reward_model, reward_model_optim))
+            >>> # or just inference
+            >>> actor, critic = strategy.prepare(actor, critic)
+
+        Returns:
+            Union[List[ModelOrModelOptimPair], ModelOrModelOptimPair]: Models or model-optimizer-pairs in the original order.
+        """
+
+        def prepare_model(model: nn.Module):
+            if isinstance(model, Actor):
+                return Actor(self.setup_model(self._unwrap_model(model)))
+            return self.setup_model(self._unwrap_model(model))
+
+        rets = []
+        for arg in models_or_model_optim_pairs:
+            if isinstance(arg, tuple):
+                assert len(arg) == 2, f'Expect (model, optimizer) pair, got a tuple with size "{len(arg)}"'
+                model, optimizer = arg
+                model = prepare_model(model)
+                optimizer = self.setup_optimizer(optimizer, self._unwrap_model(model))
+                rets.append((model, optimizer))
+            elif isinstance(arg, nn.Module):
+                rets.append(prepare_model(arg))
+            else:
+                raise RuntimeError(f'Expect model or (model, optimizer) pair, got {type(arg)}')
+
+        if len(rets) == 1:
+            return rets[0]
+        return rets
+
+    @staticmethod
+    def _unwrap_model(model: nn.Module) -> nn.Module:
+        """Useful for saving state dict. As actor is wrapped by Actor class again in `prepare()`, we should unwrap it before saving.
+
+        Args:
+            model (nn.Module): an actor or a critic
+        """
+        if isinstance(model, Actor):
+            return model.model
+        return model
+
+    @staticmethod
+    def _unwrap_actor(actor: Actor) -> nn.Module:
+        """Get `actor.model` from a wrapped (by `prepare()`) actor. Useful for getting original huggingface model.
+
+        Args:
+            actor (Actor): a wrapped actor
+        """
+        return Strategy._unwrap_model(actor)
+
+    @abstractmethod
+    def save_model(self, model: nn.Module, path: str, only_rank0: bool = False) -> None:
+        pass
+
+    @abstractmethod
+    def load_model(self, model: nn.Module, path: str, map_location: Any = None, strict: bool = True) -> None:
+        pass
+
+    @abstractmethod
+    def save_optimizer(self, optimizer: Optimizer, path: str, only_rank0: bool = False) -> None:
+        pass
+
+    @abstractmethod
+    def load_optimizer(self, optimizer: Optimizer, path: str, map_location: Any = None) -> None:
+        pass
diff --git a/applications/ChatGPT/chatgpt/trainer/strategies/colossalai.py b/applications/ChatGPT/chatgpt/trainer/strategies/colossalai.py
index 578844bdb..bf4ecdfdf 100644
--- a/applications/ChatGPT/chatgpt/trainer/strategies/colossalai.py
+++ b/applications/ChatGPT/chatgpt/trainer/strategies/colossalai.py
@@ -1,18 +1,21 @@
 import warnings
-from typing import Optional
+from typing import Optional, Union
 
 import torch
 import torch.distributed as dist
 import torch.nn as nn
 import torch.optim as optim
+from chatgpt.nn import Actor
+from torch.optim import Optimizer
 
 import colossalai
 from colossalai.nn.optimizer import CPUAdam, HybridAdam
-from colossalai.nn.parallel import zero_model_wrapper, zero_optim_wrapper
+from colossalai.nn.parallel import ZeroDDP, zero_model_wrapper, zero_optim_wrapper
 from colossalai.tensor import ProcessGroup, ShardSpec
 from colossalai.utils import get_current_device
 from colossalai.utils.model.colo_init_context import ColoInitContext
 
+from .base import Strategy
 from .ddp import DDPStrategy
 
 
@@ -129,3 +132,23 @@
 
     def optimizer_step(self, optimizer: optim.Optimizer, **kwargs) -> None:
         optimizer.step()
+
+    @staticmethod
+    def _unwrap_actor(actor: Actor) -> nn.Module:
+        model: Union[nn.Module, ZeroDDP] = Strategy._unwrap_actor(actor)
+        if isinstance(model, ZeroDDP):
+            return model.module
+        return model
+
+    def save_model(self, model: nn.Module, path: str, only_rank0: bool = False) -> None:
+        unwrapped_model = self._unwrap_model(model)
+        state_dict = unwrapped_model.state_dict()
+        if only_rank0 and dist.get_rank() != 0:
+            return
+        torch.save(state_dict, path)
+
+    def save_optimizer(self, optimizer: Optimizer, path: str, only_rank0: bool = False) -> None:
+        if only_rank0:
+            raise RuntimeError(
+                f'Optimizer states are sharded when using ColossalAIStrategy. Only rank0 is not supported.')
+        torch.save(optimizer.state_dict(), path)
diff --git a/applications/ChatGPT/chatgpt/trainer/strategies/ddp.py b/applications/ChatGPT/chatgpt/trainer/strategies/ddp.py
index b636515b4..7ceb3a3ca 100644
--- a/applications/ChatGPT/chatgpt/trainer/strategies/ddp.py
+++ b/applications/ChatGPT/chatgpt/trainer/strategies/ddp.py
@@ -5,10 +5,13 @@
 import numpy as np
 import torch
 import torch.distributed as dist
 import torch.nn as nn
+from chatgpt.nn import Actor
 from chatgpt.replay_buffer import ReplayBuffer
 from torch.nn.parallel import DistributedDataParallel as DDP
+from torch.optim import Optimizer
 from torch.utils.data import DataLoader, DistributedSampler
 
+from .base import Strategy
 from .naive import NaiveStrategy
 
@@ -57,3 +60,18 @@
                           sampler=sampler,
                           pin_memory=pin_memory,
                           collate_fn=replay_buffer.collate_fn)
+
+    @staticmethod
+    def _unwrap_actor(actor: Actor) -> nn.Module:
+        model: DDP = Strategy._unwrap_actor(actor)
+        return model.module
+
+    def save_model(self, model: nn.Module, path: str, only_rank0: bool = False) -> None:
+        if only_rank0 and dist.get_rank() != 0:
+            return
+        super().save_model(model, path, only_rank0)
+
+    def save_optimizer(self, optimizer: Optimizer, path: str, only_rank0: bool = False) -> None:
+        if only_rank0 and dist.get_rank() != 0:
+            return
+        super().save_optimizer(optimizer, path, only_rank0)
diff --git a/applications/ChatGPT/chatgpt/trainer/strategies/naive.py b/applications/ChatGPT/chatgpt/trainer/strategies/naive.py
index 1bb472ae6..99b8d6635 100644
--- a/applications/ChatGPT/chatgpt/trainer/strategies/naive.py
+++ b/applications/ChatGPT/chatgpt/trainer/strategies/naive.py
@@ -1,7 +1,10 @@
+from typing import Any
+
 import torch
 import torch.nn as nn
 import torch.optim as optim
 from chatgpt.replay_buffer import ReplayBuffer
+from torch.optim import Optimizer
 from torch.utils.data import DataLoader
 
 from .base import Strategy
@@ -34,3 +37,19 @@
                           drop_last=True,
                           pin_memory=pin_memory,
                           collate_fn=replay_buffer.collate_fn)
+
+    def save_model(self, model: nn.Module, path: str, only_rank0: bool = False) -> None:
+        unwrapped_model = self._unwrap_model(model)
+        torch.save(unwrapped_model.state_dict(), path)
+
+    def load_model(self, model: nn.Module, path: str, map_location: Any = None, strict: bool = True) -> None:
+        unwrapped_model = self._unwrap_model(model)
+        state_dict = torch.load(path, map_location=map_location)
+        unwrapped_model.load_state_dict(state_dict, strict=strict)
+
+    def save_optimizer(self, optimizer: Optimizer, path: str, only_rank0: bool = False) -> None:
+        torch.save(optimizer.state_dict(), path)
+
+    def load_optimizer(self, optimizer: Optimizer, path: str, map_location: Any = None) -> None:
+        state_dict = torch.load(path, map_location=map_location)
+        optimizer.load_state_dict(state_dict)
diff --git a/applications/ChatGPT/examples/train_dummy.py b/applications/ChatGPT/examples/train_dummy.py
index a14117ed5..f98b4792d 100644
--- a/applications/ChatGPT/examples/train_dummy.py
+++ b/applications/ChatGPT/examples/train_dummy.py
@@ -68,6 +68,9 @@ def main(args):
     else:
         raise ValueError(f'Unsupported model "{args.model}"')
 
+    (actor, actor_optim), (critic, critic_optim), reward_model, initial_model = strategy.prepare(
+        (actor, actor_optim), (critic, critic_optim), reward_model, initial_model)
+
     # configure trainer
     trainer = PPOTrainer(
         strategy,
diff --git a/applications/ChatGPT/examples/train_prompts.py b/applications/ChatGPT/examples/train_prompts.py
index cf351b91a..e79b2acf1 100644
--- a/applications/ChatGPT/examples/train_prompts.py
+++ b/applications/ChatGPT/examples/train_prompts.py
@@ -68,6 +68,9 @@ def main(args):
         batch = tokenizer(texts, return_tensors='pt', max_length=96, padding=True, truncation=True)
         return {k: v.cuda() for k, v in batch.items()}
 
+    (actor, actor_optim), (critic, critic_optim), reward_model, initial_model = strategy.prepare(
+        (actor, actor_optim), (critic, critic_optim), reward_model, initial_model)
+
     # configure trainer
     trainer = PPOTrainer(
         strategy,
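
Not part of the patch: a minimal usage sketch of the `Strategy.prepare` / `save_model` / `save_optimizer` API introduced above, assuming the `chatgpt` package from `applications/ChatGPT` is importable; the `nn.Linear` placeholders stand in for real actor/critic/reward models and are illustrative only.

# Usage sketch (not from the patch): exercises Strategy.prepare and the new
# checkpoint helpers with placeholder nn.Linear modules instead of real LMs.
import torch.nn as nn
from torch.optim import Adam

from chatgpt.nn import Actor                                # wrapped/unwrapped by prepare()
from chatgpt.trainer.strategies.naive import NaiveStrategy  # single-process strategy from this diff

strategy = NaiveStrategy()

with strategy.model_init_context():
    actor = Actor(nn.Linear(8, 8))          # placeholder standing in for a HF causal LM
    critic = nn.Linear(8, 1)                # placeholder critic
    reward_model = nn.Linear(8, 1)          # placeholder reward model
    initial_model = Actor(nn.Linear(8, 8))  # frozen reference policy

actor_optim = Adam(actor.parameters(), lr=5e-6)
critic_optim = Adam(critic.parameters(), lr=5e-6)

# prepare() sets up each model for the strategy and returns everything in the original
# order; (model, optimizer) tuples come back as tuples, bare modules come back as modules.
(actor, actor_optim), (critic, critic_optim), reward_model, initial_model = strategy.prepare(
    (actor, actor_optim), (critic, critic_optim), reward_model, initial_model)

# Checkpointing goes through the strategy, which unwraps the Actor before saving.
strategy.save_model(actor, 'actor.pt')
strategy.save_optimizer(actor_optim, 'actor_optim.pt')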