mirror of https://github.com/hpcaitech/ColossalAI
[chatgpt] strategy add prepare method (#2766)
* [chatgpt] strategy add prepare method
* [chatgpt] refactor examples
* [chatgpt] refactor strategy.prepare
* [chatgpt] support save/load checkpoint
* [chatgpt] fix unwrap actor
* [chatgpt] fix unwrap actor
parent a2b43e393d
commit 4ee311c026
@@ -133,6 +133,9 @@ def main(args):
     tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
     tokenizer.pad_token = tokenizer.eos_token
 
+    (actor, actor_optim), (critic, critic_optim), reward_model, initial_model = strategy.prepare(
+        (actor, actor_optim), (critic, critic_optim), reward_model, initial_model)
+
     trainer = PPOTrainer(strategy,
                          actor,
                          critic,
@@ -126,6 +126,9 @@ def main(args):
     tokenizer = AutoTokenizer.from_pretrained('facebook/opt-350m')
     tokenizer.pad_token = tokenizer.eos_token
 
+    (actor, actor_optim), (critic, critic_optim), reward_model, initial_model = strategy.prepare(
+        (actor, actor_optim), (critic, critic_optim), reward_model, initial_model)
+
     trainer = PPOTrainer(strategy,
                          actor,
                          critic,
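
Both example hunks above make the same change: model/optimizer pairs are now routed through strategy.prepare(...) before the trainer is built, and the return value preserves the argument order. Below is a minimal, self-contained sketch of that unpacking convention; the pass-through strategy and the nn.Linear stand-ins are assumptions for illustration, not the real chatgpt Actor/Critic/RewardModel classes or strategies.

    # Standalone sketch of the prepare() calling convention used in both examples.
    import torch.nn as nn
    from torch.optim import Adam


    class PassThroughStrategy:
        """Mimics the prepare() contract: arguments come back in the same order."""

        def prepare(self, *models_or_pairs):
            rets = list(models_or_pairs)      # a real strategy would wrap/shard here
            return rets[0] if len(rets) == 1 else rets


    strategy = PassThroughStrategy()
    actor, critic = nn.Linear(8, 8), nn.Linear(8, 1)
    reward_model, initial_model = nn.Linear(8, 1), nn.Linear(8, 8)
    actor_optim, critic_optim = Adam(actor.parameters()), Adam(critic.parameters())

    # Same unpacking shape as in the updated example scripts:
    (actor, actor_optim), (critic, critic_optim), reward_model, initial_model = strategy.prepare(
        (actor, actor_optim), (critic, critic_optim), reward_model, initial_model)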
@@ -60,11 +60,6 @@ class PPOTrainer(Trainer):
                  dataloader_pin_memory: bool = True,
                  callbacks: List[Callback] = [],
                  **generate_kwargs) -> None:
-        self._set_default_generate_kwargs(generate_kwargs, actor)
-        actor = Actor(strategy.setup_model(actor.model))
-        critic = strategy.setup_model(critic)
-        reward_model = strategy.setup_model(reward_model)
-        initial_model = Actor(strategy.setup_model(initial_model.model))
         experience_maker = NaiveExperienceMaker(actor, critic, reward_model, initial_model, kl_coef)
         replay_buffer = NaiveReplayBuffer(train_batch_size, buffer_limit, buffer_cpu_offload)
         super().__init__(strategy, experience_maker, replay_buffer, experience_batch_size, max_epochs, tokenizer,
@@ -75,8 +70,9 @@ class PPOTrainer(Trainer):
         self.actor_loss_fn = PolicyLoss(eps_clip)
         self.critic_loss_fn = ValueLoss(value_clip)
 
-        self.actor_optim = strategy.setup_optimizer(actor_optim, self.actor.model)
-        self.critic_optim = strategy.setup_optimizer(critic_optim, self.critic)
+        self.actor_optim = actor_optim
+        self.critic_optim = critic_optim
+        self._set_default_generate_kwargs(generate_kwargs, actor)
 
     def training_step(self, experience: Experience) -> Dict[str, float]:
         self.actor.train()
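
The two PPOTrainer hunks shift responsibility: the constructor no longer wraps models with the strategy or builds optimizers itself, it simply stores what strategy.prepare() already produced. A toy sketch of that reduced constructor follows; ToyPPOTrainer is illustrative, not the real PPOTrainer.

    import torch.nn as nn
    from torch.optim import Adam, Optimizer


    class ToyPPOTrainer:
        def __init__(self, actor: nn.Module, critic: nn.Module,
                     actor_optim: Optimizer, critic_optim: Optimizer) -> None:
            # No strategy.setup_model()/setup_optimizer() calls here any more:
            # the caller is expected to have run strategy.prepare() already.
            self.actor = actor
            self.critic = critic
            self.actor_optim = actor_optim
            self.critic_optim = critic_optim


    actor, critic = nn.Linear(4, 4), nn.Linear(4, 1)
    trainer = ToyPPOTrainer(actor, critic, Adam(actor.parameters()), Adam(critic.parameters()))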
@@ -106,9 +102,10 @@ class PPOTrainer(Trainer):
         return {'actor_loss': actor_loss.item(), 'critic_loss': critic_loss.item()}
 
     def _set_default_generate_kwargs(self, generate_kwargs: dict, actor: Actor) -> None:
+        origin_model = self.strategy._unwrap_actor(actor)
         # use huggingface models method directly
-        if 'prepare_inputs_fn' not in generate_kwargs and hasattr(actor.model, 'prepare_inputs_for_generation'):
-            generate_kwargs['prepare_inputs_fn'] = actor.model.prepare_inputs_for_generation
+        if 'prepare_inputs_fn' not in generate_kwargs and hasattr(origin_model, 'prepare_inputs_for_generation'):
+            generate_kwargs['prepare_inputs_fn'] = origin_model.prepare_inputs_for_generation
 
         if 'update_model_kwargs_fn' not in generate_kwargs:
             generate_kwargs['update_model_kwargs_fn'] = update_model_kwargs_fn
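
The _set_default_generate_kwargs change exists because, after strategy.prepare(), the HuggingFace model can sit behind a wrapper (DDP or ZeroDDP) that does not forward prepare_inputs_for_generation, so the hasattr check must run on the unwrapped module. A standalone sketch of that effect, using a toy wrapper as a stand-in for DDP/ZeroDDP:

    import torch.nn as nn


    class ToyHFModel(nn.Module):
        def prepare_inputs_for_generation(self, input_ids, **kwargs):
            return {'input_ids': input_ids, **kwargs}


    class ToyWrapper(nn.Module):          # stands in for DDP / ZeroDDP
        def __init__(self, module: nn.Module) -> None:
            super().__init__()
            self.module = module


    wrapped = ToyWrapper(ToyHFModel())
    print(hasattr(wrapped, 'prepare_inputs_for_generation'))          # False
    print(hasattr(wrapped.module, 'prepare_inputs_for_generation'))   # True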
@@ -1,12 +1,17 @@
 from abc import ABC, abstractmethod
 from contextlib import nullcontext
+from typing import Any, List, Tuple, Union
 
 import torch
 import torch.nn as nn
 import torch.optim as optim
+from chatgpt.nn import Actor, Critic, RewardModel
 from chatgpt.replay_buffer import ReplayBuffer
+from torch.optim import Optimizer
 from torch.utils.data import DataLoader
 
+ModelOptimPair = Tuple[nn.Module, Optimizer]
+ModelOrModelOptimPair = Union[nn.Module, ModelOptimPair]
+
 
 class Strategy(ABC):
     """
@@ -18,11 +23,11 @@ class Strategy(ABC):
         self.setup_distributed()
 
     @abstractmethod
-    def backward(self, loss: torch.Tensor, model: nn.Module, optimizer: optim.Optimizer, **kwargs) -> None:
+    def backward(self, loss: torch.Tensor, model: nn.Module, optimizer: Optimizer, **kwargs) -> None:
         pass
 
     @abstractmethod
-    def optimizer_step(self, optimizer: optim.Optimizer, **kwargs) -> None:
+    def optimizer_step(self, optimizer: Optimizer, **kwargs) -> None:
         pass
 
     @abstractmethod
@@ -34,7 +39,7 @@ class Strategy(ABC):
         pass
 
     @abstractmethod
-    def setup_optimizer(self, optimizer: optim.Optimizer, model: nn.Module) -> optim.Optimizer:
+    def setup_optimizer(self, optimizer: Optimizer, model: nn.Module) -> Optimizer:
         pass
 
     @abstractmethod
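
The hunks above switch the abstract method annotations to the directly imported Optimizer name. A minimal toy class sketch of how a concrete strategy typically fills in these hooks with plain PyTorch (illustrative only, and deliberately not subclassing the real abstract Strategy so it stays self-contained; the real NaiveStrategy/DDPStrategy/ColossalAIStrategy implementations differ):

    import torch
    import torch.nn as nn
    from torch.optim import SGD, Optimizer


    class ToyStrategy:
        """Illustrative hooks only; not the repository's Strategy subclasses."""

        def setup_model(self, model: nn.Module) -> nn.Module:
            return model                       # a real strategy may wrap with DDP/ZeroDDP

        def setup_optimizer(self, optimizer: Optimizer, model: nn.Module) -> Optimizer:
            return optimizer                   # a real strategy may shard optimizer states

        def backward(self, loss: torch.Tensor, model: nn.Module, optimizer: Optimizer, **kwargs) -> None:
            loss.backward()

        def optimizer_step(self, optimizer: Optimizer, **kwargs) -> None:
            optimizer.step()


    strategy = ToyStrategy()
    model = strategy.setup_model(nn.Linear(4, 1))
    optim = strategy.setup_optimizer(SGD(model.parameters(), lr=0.1), model)
    loss = model(torch.randn(2, 4)).sum()
    strategy.backward(loss, model, optim)
    strategy.optimizer_step(optim)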
@@ -43,3 +48,78 @@ class Strategy(ABC):
 
     def model_init_context(self):
         return nullcontext()
+
+    def prepare(
+        self, *models_or_model_optim_pairs: ModelOrModelOptimPair
+    ) -> Union[List[ModelOrModelOptimPair], ModelOrModelOptimPair]:
+        """Prepare models or model-optimizer-pairs based on each strategy.
+
+        Example::
+            >>> # when fine-tuning actor and critic
+            >>> (actor, actor_optim), (critic, critic_optim), reward_model, initial_model = strategy.prepare((actor, actor_optim), (critic, critic_optim), reward_model, initial_model)
+            >>> # or when training reward model
+            >>> (reward_model, reward_model_optim) = strategy.prepare((reward_model, reward_model_optim))
+            >>> # or just inference
+            >>> actor, critic = strategy.prepare(actor, critic)
+
+        Returns:
+            Union[List[ModelOrModelOptimPair], ModelOrModelOptimPair]: Models or model-optimizer-pairs in the original order.
+        """
+
+        def prepare_model(model: nn.Module):
+            if isinstance(model, Actor):
+                return Actor(self.setup_model(self._unwrap_model(model)))
+            return self.setup_model(self._unwrap_model(model))
+
+        rets = []
+        for arg in models_or_model_optim_pairs:
+            if isinstance(arg, tuple):
+                assert len(arg) == 2, f'Expect (model, optimizer) pair, got a tuple with size "{len(arg)}"'
+                model, optimizer = arg
+                model = prepare_model(model)
+                optimizer = self.setup_optimizer(optimizer, self._unwrap_model(model))
+                rets.append((model, optimizer))
+            elif isinstance(arg, nn.Module):
+                rets.append(prepare_model(arg))
+            else:
+                raise RuntimeError(f'Expect model or (model, optimizer) pair, got {type(arg)}')
+
+        if len(rets) == 1:
+            return rets[0]
+        return rets
+
+    @staticmethod
+    def _unwrap_model(model: nn.Module) -> nn.Module:
+        """Useful for saving state dict. As actor is wrapped by Actor class again in `prepare()`, we should unwrap it before saving.
+
+        Args:
+            model (nn.Module): an actor or a critic
+        """
+        if isinstance(model, Actor):
+            return model.model
+        return model
+
+    @staticmethod
+    def _unwrap_actor(actor: Actor) -> nn.Module:
+        """Get `actor.model` from a wrapped (by `prepare()`) actor. Useful for getting original huggingface model.
+
+        Args:
+            actor (Actor): a wrapped actor
+        """
+        return Strategy._unwrap_model(actor)
+
+    @abstractmethod
+    def save_model(self, model: nn.Module, path: str, only_rank0: bool = False) -> None:
+        pass
+
+    @abstractmethod
+    def load_model(self, model: nn.Module, path: str, map_location: Any = None, strict: bool = True) -> None:
+        pass
+
+    @abstractmethod
+    def save_optimizer(self, optimizer: Optimizer, path: str, only_rank0: bool = False) -> None:
+        pass
+
+    @abstractmethod
+    def load_optimizer(self, optimizer: Optimizer, path: str, map_location: Any = None) -> None:
+        pass
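
prepare() wraps the actor's underlying network again in an Actor, so the new _unwrap_model/_unwrap_actor helpers are what later save/load and generation code use to reach the bare module. A standalone toy illustration of that layering (ToyActor is an assumption standing in for the real chatgpt Actor, which does considerably more):

    import torch.nn as nn


    class ToyActor(nn.Module):
        def __init__(self, model: nn.Module) -> None:
            super().__init__()
            self.model = model                 # wrapped network lives under .model


    def unwrap_model(model: nn.Module) -> nn.Module:
        """Same rule as Strategy._unwrap_model: peel off the Actor wrapper."""
        return model.model if isinstance(model, ToyActor) else model


    base = nn.Linear(4, 4)
    actor = ToyActor(base)                 # what prepare() hands back
    assert unwrap_model(actor) is base     # what save_model() should serialize
    assert unwrap_model(base) is base      # critics/reward models pass through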
@@ -1,18 +1,21 @@
 import warnings
-from typing import Optional
+from typing import Optional, Union
 
 import torch
 import torch.distributed as dist
 import torch.nn as nn
 import torch.optim as optim
+from chatgpt.nn import Actor
+from torch.optim import Optimizer
 
 import colossalai
 from colossalai.nn.optimizer import CPUAdam, HybridAdam
-from colossalai.nn.parallel import zero_model_wrapper, zero_optim_wrapper
+from colossalai.nn.parallel import ZeroDDP, zero_model_wrapper, zero_optim_wrapper
 from colossalai.tensor import ProcessGroup, ShardSpec
 from colossalai.utils import get_current_device
 from colossalai.utils.model.colo_init_context import ColoInitContext
 
+from .base import Strategy
 from .ddp import DDPStrategy
 
 
@@ -129,3 +132,23 @@ class ColossalAIStrategy(DDPStrategy):
 
     def optimizer_step(self, optimizer: optim.Optimizer, **kwargs) -> None:
         optimizer.step()
+
+    @staticmethod
+    def _unwrap_actor(actor: Actor) -> nn.Module:
+        model: Union[nn.Module, ZeroDDP] = Strategy._unwrap_actor(actor)
+        if isinstance(model, ZeroDDP):
+            return model.module
+        return model
+
+    def save_model(self, model: nn.Module, path: str, only_rank0: bool = False) -> None:
+        unwrapped_model = self._unwrap_model(model)
+        state_dict = unwrapped_model.state_dict()
+        if only_rank0 and dist.get_rank() != 0:
+            return
+        torch.save(state_dict, path)
+
+    def save_optimizer(self, optimizer: Optimizer, path: str, only_rank0: bool = False) -> None:
+        if only_rank0:
+            raise RuntimeError(
+                f'Optimizer states are sharded when using ColossalAIStrategy. Only rank0 is not supported.')
+        torch.save(optimizer.state_dict(), path)
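
In the hunk above, save_model calls state_dict() on every rank before the rank check (apparently so that gathering sharded parameters can involve every rank) and only rank 0 writes the file, while save_optimizer refuses only_rank0 because optimizer states stay sharded per rank. A standalone sketch of that rank-0 write guard; the is_initialized() fallback and the file name are assumptions of the sketch, not part of the diff:

    import torch
    import torch.distributed as dist
    import torch.nn as nn


    def save_on_rank0(model: nn.Module, path: str, only_rank0: bool = True) -> None:
        state_dict = model.state_dict()          # every rank materializes the weights
        rank = dist.get_rank() if dist.is_initialized() else 0
        if only_rank0 and rank != 0:
            return                               # non-zero ranks skip the write
        torch.save(state_dict, path)


    save_on_rank0(nn.Linear(4, 4), 'toy_actor.pt')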
@@ -5,10 +5,13 @@ import numpy as np
 import torch
 import torch.distributed as dist
 import torch.nn as nn
+from chatgpt.nn import Actor
 from chatgpt.replay_buffer import ReplayBuffer
 from torch.nn.parallel import DistributedDataParallel as DDP
+from torch.optim import Optimizer
 from torch.utils.data import DataLoader, DistributedSampler
 
+from .base import Strategy
 from .naive import NaiveStrategy
 
 
@@ -57,3 +60,18 @@ class DDPStrategy(NaiveStrategy):
                           sampler=sampler,
                           pin_memory=pin_memory,
                           collate_fn=replay_buffer.collate_fn)
+
+    @staticmethod
+    def _unwrap_actor(actor: Actor) -> nn.Module:
+        model: DDP = Strategy._unwrap_actor(actor)
+        return model.module
+
+    def save_model(self, model: nn.Module, path: str, only_rank0: bool = False) -> None:
+        if only_rank0 and dist.get_rank() != 0:
+            return
+        super().save_model(model, path, only_rank0)
+
+    def save_optimizer(self, optimizer: Optimizer, path: str, only_rank0: bool = False) -> None:
+        if only_rank0 and dist.get_rank() != 0:
+            return
+        super().save_optimizer(optimizer, path, only_rank0)
@@ -1,7 +1,10 @@
+from typing import Any
+
 import torch
 import torch.nn as nn
 import torch.optim as optim
 from chatgpt.replay_buffer import ReplayBuffer
+from torch.optim import Optimizer
 from torch.utils.data import DataLoader
 
 from .base import Strategy
@@ -34,3 +37,19 @@ class NaiveStrategy(Strategy):
                           drop_last=True,
                           pin_memory=pin_memory,
                           collate_fn=replay_buffer.collate_fn)
+
+    def save_model(self, model: nn.Module, path: str, only_rank0: bool = False) -> None:
+        unwrapped_model = self._unwrap_model(model)
+        torch.save(unwrapped_model.state_dict(), path)
+
+    def load_model(self, model: nn.Module, path: str, map_location: Any = None, strict: bool = True) -> None:
+        unwrapped_model = self._unwrap_model(model)
+        state_dict = torch.load(path, map_location=map_location)
+        unwrapped_model.load_state_dict(state_dict, strict=strict)
+
+    def save_optimizer(self, optimizer: Optimizer, path: str, only_rank0: bool = False) -> None:
+        torch.save(optimizer.state_dict(), path)
+
+    def load_optimizer(self, optimizer: Optimizer, path: str, map_location: Any = None) -> None:
+        state_dict = torch.load(path, map_location=map_location)
+        optimizer.load_state_dict(state_dict)
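
The NaiveStrategy checkpoint hunk is a plain torch.save/torch.load of state dicts for the unwrapped model and the optimizer. A standalone round-trip sketch of the same pattern (the file names and the tiny model are illustrative):

    import torch
    import torch.nn as nn
    from torch.optim import Adam

    model = nn.Linear(4, 2)
    optim = Adam(model.parameters(), lr=1e-3)

    # save: serialize state dicts, not the objects themselves
    torch.save(model.state_dict(), 'toy_model.pt')
    torch.save(optim.state_dict(), 'toy_optim.pt')

    # load: rebuild the objects, then restore their states
    restored_model = nn.Linear(4, 2)
    restored_model.load_state_dict(torch.load('toy_model.pt', map_location='cpu'), strict=True)

    restored_optim = Adam(restored_model.parameters(), lr=1e-3)
    restored_optim.load_state_dict(torch.load('toy_optim.pt', map_location='cpu'))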
@@ -68,6 +68,9 @@ def main(args):
     else:
         raise ValueError(f'Unsupported model "{args.model}"')
 
+    (actor, actor_optim), (critic, critic_optim), reward_model, initial_model = strategy.prepare(
+        (actor, actor_optim), (critic, critic_optim), reward_model, initial_model)
+
     # configure trainer
     trainer = PPOTrainer(
         strategy,
@@ -68,6 +68,9 @@ def main(args):
         batch = tokenizer(texts, return_tensors='pt', max_length=96, padding=True, truncation=True)
         return {k: v.cuda() for k, v in batch.items()}
 
+    (actor, actor_optim), (critic, critic_optim), reward_model, initial_model = strategy.prepare(
+        (actor, actor_optim), (critic, critic_optim), reward_model, initial_model)
+
     # configure trainer
     trainer = PPOTrainer(
         strategy,