[chatgpt] strategy add prepare method (#2766)

* [chatgpt] strategy add prepare method

* [chatgpt] refactor examples

* [chatgpt] refactor strategy.prepare

* [chatgpt] support save/load checkpoint

* [chatgpt] fix unwrap actor

* [chatgpt] fix unwrap actor
ver217 2023-02-17 11:27:27 +08:00 committed by GitHub
parent a2b43e393d
commit 4ee311c026
9 changed files with 164 additions and 15 deletions
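
Taken together, the diff below routes every model and optimizer through a new `Strategy.prepare()` method before `PPOTrainer` sees them, and adds `save_model`/`load_model`/`save_optimizer`/`load_optimizer` to the strategies. A minimal sketch of the resulting API, assuming `NaiveStrategy` is importable from `chatgpt.trainer.strategies` and takes no constructor arguments, and using a small GPT-2 actor (the import path and the GPT-2 choice are illustrative, not part of this diff):

    import torch
    from chatgpt.nn import Actor
    from chatgpt.trainer.strategies import NaiveStrategy  # assumed import path
    from transformers import GPT2LMHeadModel

    strategy = NaiveStrategy()
    with strategy.model_init_context():
        actor = Actor(GPT2LMHeadModel.from_pretrained('gpt2'))
    actor_optim = torch.optim.Adam(actor.parameters(), lr=1e-5)

    # models/optimizers now pass through the strategy before the trainer sees them
    actor, actor_optim = strategy.prepare((actor, actor_optim))

    # new checkpoint API on Strategy
    strategy.save_model(actor, 'actor_checkpoint.pt', only_rank0=True)
    strategy.save_optimizer(actor_optim, 'actor_optim_checkpoint.pt')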


@@ -133,6 +133,9 @@ def main(args):
     tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
     tokenizer.pad_token = tokenizer.eos_token
 
+    (actor, actor_optim), (critic, critic_optim), reward_model, initial_model = strategy.prepare(
+        (actor, actor_optim), (critic, critic_optim), reward_model, initial_model)
+
     trainer = PPOTrainer(strategy,
                          actor,
                          critic,


@@ -126,6 +126,9 @@ def main(args):
     tokenizer = AutoTokenizer.from_pretrained('facebook/opt-350m')
     tokenizer.pad_token = tokenizer.eos_token
 
+    (actor, actor_optim), (critic, critic_optim), reward_model, initial_model = strategy.prepare(
+        (actor, actor_optim), (critic, critic_optim), reward_model, initial_model)
+
     trainer = PPOTrainer(strategy,
                          actor,
                          critic,


@@ -60,11 +60,6 @@ class PPOTrainer(Trainer):
                  dataloader_pin_memory: bool = True,
                  callbacks: List[Callback] = [],
                  **generate_kwargs) -> None:
-        self._set_default_generate_kwargs(generate_kwargs, actor)
-        actor = Actor(strategy.setup_model(actor.model))
-        critic = strategy.setup_model(critic)
-        reward_model = strategy.setup_model(reward_model)
-        initial_model = Actor(strategy.setup_model(initial_model.model))
         experience_maker = NaiveExperienceMaker(actor, critic, reward_model, initial_model, kl_coef)
         replay_buffer = NaiveReplayBuffer(train_batch_size, buffer_limit, buffer_cpu_offload)
         super().__init__(strategy, experience_maker, replay_buffer, experience_batch_size, max_epochs, tokenizer,
@@ -75,8 +70,9 @@ class PPOTrainer(Trainer):
         self.actor_loss_fn = PolicyLoss(eps_clip)
         self.critic_loss_fn = ValueLoss(value_clip)
-        self.actor_optim = strategy.setup_optimizer(actor_optim, self.actor.model)
-        self.critic_optim = strategy.setup_optimizer(critic_optim, self.critic)
+        self.actor_optim = actor_optim
+        self.critic_optim = critic_optim
+        self._set_default_generate_kwargs(generate_kwargs, actor)
 
     def training_step(self, experience: Experience) -> Dict[str, float]:
         self.actor.train()
@@ -106,9 +102,10 @@ class PPOTrainer(Trainer):
         return {'actor_loss': actor_loss.item(), 'critic_loss': critic_loss.item()}
 
     def _set_default_generate_kwargs(self, generate_kwargs: dict, actor: Actor) -> None:
+        origin_model = self.strategy._unwrap_actor(actor)
         # use huggingface models method directly
-        if 'prepare_inputs_fn' not in generate_kwargs and hasattr(actor.model, 'prepare_inputs_for_generation'):
-            generate_kwargs['prepare_inputs_fn'] = actor.model.prepare_inputs_for_generation
+        if 'prepare_inputs_fn' not in generate_kwargs and hasattr(origin_model, 'prepare_inputs_for_generation'):
+            generate_kwargs['prepare_inputs_fn'] = origin_model.prepare_inputs_for_generation
 
         if 'update_model_kwargs_fn' not in generate_kwargs:
             generate_kwargs['update_model_kwargs_fn'] = update_model_kwargs_fn


@@ -1,12 +1,17 @@
 from abc import ABC, abstractmethod
 from contextlib import nullcontext
+from typing import Any, List, Tuple, Union
 
 import torch
 import torch.nn as nn
-import torch.optim as optim
+from chatgpt.nn import Actor, Critic, RewardModel
 from chatgpt.replay_buffer import ReplayBuffer
+from torch.optim import Optimizer
 from torch.utils.data import DataLoader
 
+ModelOptimPair = Tuple[nn.Module, Optimizer]
+ModelOrModelOptimPair = Union[nn.Module, ModelOptimPair]
+
 
 class Strategy(ABC):
     """
@@ -18,11 +23,11 @@ class Strategy(ABC):
         self.setup_distributed()
 
     @abstractmethod
-    def backward(self, loss: torch.Tensor, model: nn.Module, optimizer: optim.Optimizer, **kwargs) -> None:
+    def backward(self, loss: torch.Tensor, model: nn.Module, optimizer: Optimizer, **kwargs) -> None:
         pass
 
     @abstractmethod
-    def optimizer_step(self, optimizer: optim.Optimizer, **kwargs) -> None:
+    def optimizer_step(self, optimizer: Optimizer, **kwargs) -> None:
         pass
 
     @abstractmethod
@@ -34,7 +39,7 @@ class Strategy(ABC):
         pass
 
     @abstractmethod
-    def setup_optimizer(self, optimizer: optim.Optimizer, model: nn.Module) -> optim.Optimizer:
+    def setup_optimizer(self, optimizer: Optimizer, model: nn.Module) -> Optimizer:
         pass
 
     @abstractmethod
@@ -43,3 +48,78 @@ class Strategy(ABC):
     def model_init_context(self):
         return nullcontext()
 
+    def prepare(
+        self, *models_or_model_optim_pairs: ModelOrModelOptimPair
+    ) -> Union[List[ModelOrModelOptimPair], ModelOrModelOptimPair]:
+        """Prepare models or model-optimizer-pairs based on each strategy.
+
+        Example::
+            >>> # when fine-tuning actor and critic
+            >>> (actor, actor_optim), (critic, critic_optim), reward_model, initial_model = strategy.prepare((actor, actor_optim), (critic, critic_optim), reward_model, initial_model)
+            >>> # or when training reward model
+            >>> (reward_model, reward_model_optim) = strategy.prepare((reward_model, reward_model_optim))
+            >>> # or just inference
+            >>> actor, critic = strategy.prepare(actor, critic)
+
+        Returns:
+            Union[List[ModelOrModelOptimPair], ModelOrModelOptimPair]: Models or model-optimizer-pairs in the original order.
+        """
+
+        def prepare_model(model: nn.Module):
+            if isinstance(model, Actor):
+                return Actor(self.setup_model(self._unwrap_model(model)))
+            return self.setup_model(self._unwrap_model(model))
+
+        rets = []
+        for arg in models_or_model_optim_pairs:
+            if isinstance(arg, tuple):
+                assert len(arg) == 2, f'Expect (model, optimizer) pair, got a tuple with size "{len(arg)}"'
+                model, optimizer = arg
+                model = prepare_model(model)
+                optimizer = self.setup_optimizer(optimizer, self._unwrap_model(model))
+                rets.append((model, optimizer))
+            elif isinstance(arg, nn.Module):
+                rets.append(prepare_model(arg))
+            else:
+                raise RuntimeError(f'Expect model or (model, optimizer) pair, got {type(arg)}')
+
+        if len(rets) == 1:
+            return rets[0]
+        return rets
+
+    @staticmethod
+    def _unwrap_model(model: nn.Module) -> nn.Module:
+        """Useful for saving state dict. As actor is wrapped by Actor class again in `prepare()`, we should unwrap it before saving.
+
+        Args:
+            model (nn.Module): an actor or a critic
+        """
+        if isinstance(model, Actor):
+            return model.model
+        return model
+
+    @staticmethod
+    def _unwrap_actor(actor: Actor) -> nn.Module:
+        """Get `actor.model` from a wrapped (by `prepare()`) actor. Useful for getting original huggingface model.
+
+        Args:
+            actor (Actor): a wrapped actor
+        """
+        return Strategy._unwrap_model(actor)
+
+    @abstractmethod
+    def save_model(self, model: nn.Module, path: str, only_rank0: bool = False) -> None:
+        pass
+
+    @abstractmethod
+    def load_model(self, model: nn.Module, path: str, map_location: Any = None, strict: bool = True) -> None:
+        pass
+
+    @abstractmethod
+    def save_optimizer(self, optimizer: Optimizer, path: str, only_rank0: bool = False) -> None:
+        pass
+
+    @abstractmethod
+    def load_optimizer(self, optimizer: Optimizer, path: str, map_location: Any = None) -> None:
+        pass
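
To make the wrapper layering above concrete: `prepare()` keeps the `Actor` facade but re-wraps the inner module via `setup_model()`, while `_unwrap_model()`/`_unwrap_actor()` walk back to the original HuggingFace module, which is what the trainer's `_set_default_generate_kwargs` relies on. A small sketch under `NaiveStrategy` (assumed to leave modules unwrapped; a tiny randomly initialized GPT-2 config is used so nothing is downloaded):

    from chatgpt.nn import Actor
    from chatgpt.trainer.strategies import NaiveStrategy  # assumed import path
    from transformers import GPT2Config, GPT2LMHeadModel

    strategy = NaiveStrategy()
    hf_model = GPT2LMHeadModel(GPT2Config(n_layer=2, n_head=2, n_embd=64))
    actor = strategy.prepare(Actor(hf_model))  # a single bare model comes back as a model

    # the Actor facade is preserved; the original module is reachable again via unwrap
    assert isinstance(actor, Actor)
    assert strategy._unwrap_actor(actor) is hf_model  # holds if setup_model() returns the module unchanged
    assert hasattr(strategy._unwrap_actor(actor), 'prepare_inputs_for_generation')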


@@ -1,18 +1,21 @@
 import warnings
-from typing import Optional
+from typing import Optional, Union
 
 import torch
 import torch.distributed as dist
 import torch.nn as nn
 import torch.optim as optim
+from chatgpt.nn import Actor
+from torch.optim import Optimizer
 
 import colossalai
 from colossalai.nn.optimizer import CPUAdam, HybridAdam
-from colossalai.nn.parallel import zero_model_wrapper, zero_optim_wrapper
+from colossalai.nn.parallel import ZeroDDP, zero_model_wrapper, zero_optim_wrapper
 from colossalai.tensor import ProcessGroup, ShardSpec
 from colossalai.utils import get_current_device
 from colossalai.utils.model.colo_init_context import ColoInitContext
 
+from .base import Strategy
 from .ddp import DDPStrategy
@@ -129,3 +132,23 @@ class ColossalAIStrategy(DDPStrategy):
     def optimizer_step(self, optimizer: optim.Optimizer, **kwargs) -> None:
         optimizer.step()
 
+    @staticmethod
+    def _unwrap_actor(actor: Actor) -> nn.Module:
+        model: Union[nn.Module, ZeroDDP] = Strategy._unwrap_actor(actor)
+        if isinstance(model, ZeroDDP):
+            return model.module
+        return model
+
+    def save_model(self, model: nn.Module, path: str, only_rank0: bool = False) -> None:
+        unwrapped_model = self._unwrap_model(model)
+        state_dict = unwrapped_model.state_dict()
+        if only_rank0 and dist.get_rank() != 0:
+            return
+        torch.save(state_dict, path)
+
+    def save_optimizer(self, optimizer: Optimizer, path: str, only_rank0: bool = False) -> None:
+        if only_rank0:
+            raise RuntimeError(
+                f'Optimizer states are sharded when using ColossalAIStrategy. Only rank0 is not supported.')
+        torch.save(optimizer.state_dict(), path)


@@ -5,10 +5,13 @@ import numpy as np
 import torch
 import torch.distributed as dist
 import torch.nn as nn
+from chatgpt.nn import Actor
 from chatgpt.replay_buffer import ReplayBuffer
 from torch.nn.parallel import DistributedDataParallel as DDP
+from torch.optim import Optimizer
 from torch.utils.data import DataLoader, DistributedSampler
 
+from .base import Strategy
 from .naive import NaiveStrategy
@@ -57,3 +60,18 @@ class DDPStrategy(NaiveStrategy):
                           sampler=sampler,
                           pin_memory=pin_memory,
                           collate_fn=replay_buffer.collate_fn)
+
+    @staticmethod
+    def _unwrap_actor(actor: Actor) -> nn.Module:
+        model: DDP = Strategy._unwrap_actor(actor)
+        return model.module
+
+    def save_model(self, model: nn.Module, path: str, only_rank0: bool = False) -> None:
+        if only_rank0 and dist.get_rank() != 0:
+            return
+        super().save_model(model, path, only_rank0)
+
+    def save_optimizer(self, optimizer: Optimizer, path: str, only_rank0: bool = False) -> None:
+        if only_rank0 and dist.get_rank() != 0:
+            return
+        super().save_optimizer(optimizer, path, only_rank0)
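
The `DDPStrategy._unwrap_actor` override exists because after `prepare()` an actor is an `Actor` wrapping a `DistributedDataParallel` module, so two layers must be peeled to reach the HuggingFace model. A sketch of just that unwrap chain, using a single-process gloo group and an `nn.Linear` stand-in (the process-group setup and the stand-in module are illustrative only):

    import os
    import torch.distributed as dist
    import torch.nn as nn
    from chatgpt.nn import Actor
    from chatgpt.trainer.strategies import DDPStrategy  # assumed import path
    from torch.nn.parallel import DistributedDataParallel as DDP

    # a one-process "world" so DDP can wrap a CPU module without torchrun
    os.environ.setdefault('MASTER_ADDR', '127.0.0.1')
    os.environ.setdefault('MASTER_PORT', '29500')
    dist.init_process_group('gloo', rank=0, world_size=1)

    inner = nn.Linear(4, 4)
    wrapped = Actor(DDP(inner))  # roughly what prepare() builds under DDPStrategy

    # Strategy._unwrap_actor peels the Actor facade; this override also peels DDP
    assert DDPStrategy._unwrap_actor(wrapped) is inner

    dist.destroy_process_group()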


@@ -1,7 +1,10 @@
+from typing import Any
+
 import torch
 import torch.nn as nn
 import torch.optim as optim
 from chatgpt.replay_buffer import ReplayBuffer
+from torch.optim import Optimizer
 from torch.utils.data import DataLoader
 
 from .base import Strategy
@@ -34,3 +37,19 @@ class NaiveStrategy(Strategy):
                           drop_last=True,
                           pin_memory=pin_memory,
                           collate_fn=replay_buffer.collate_fn)
+
+    def save_model(self, model: nn.Module, path: str, only_rank0: bool = False) -> None:
+        unwrapped_model = self._unwrap_model(model)
+        torch.save(unwrapped_model.state_dict(), path)
+
+    def load_model(self, model: nn.Module, path: str, map_location: Any = None, strict: bool = True) -> None:
+        unwrapped_model = self._unwrap_model(model)
+        state_dict = torch.load(path, map_location=map_location)
+        unwrapped_model.load_state_dict(state_dict, strict=strict)
+
+    def save_optimizer(self, optimizer: Optimizer, path: str, only_rank0: bool = False) -> None:
+        torch.save(optimizer.state_dict(), path)
+
+    def load_optimizer(self, optimizer: Optimizer, path: str, map_location: Any = None) -> None:
+        state_dict = torch.load(path, map_location=map_location)
+        optimizer.load_state_dict(state_dict)
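
The `NaiveStrategy` methods above are plain `torch.save`/`torch.load` wrappers that first strip the `Actor` facade via `_unwrap_model()`. A small round-trip sketch with throwaway modules (the paths and the `nn.Linear` stand-ins are illustrative):

    import torch
    import torch.nn as nn
    from chatgpt.trainer.strategies import NaiveStrategy  # assumed import path

    strategy = NaiveStrategy()
    model = nn.Linear(4, 4)
    optim = torch.optim.Adam(model.parameters(), lr=1e-3)

    strategy.save_model(model, 'model_checkpoint.pt')
    strategy.save_optimizer(optim, 'optim_checkpoint.pt')

    # load back into fresh instances; map_location/strict mirror the new signatures
    model2 = nn.Linear(4, 4)
    optim2 = torch.optim.Adam(model2.parameters(), lr=1e-3)
    strategy.load_model(model2, 'model_checkpoint.pt', map_location='cpu', strict=True)
    strategy.load_optimizer(optim2, 'optim_checkpoint.pt', map_location='cpu')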


@@ -68,6 +68,9 @@ def main(args):
     else:
         raise ValueError(f'Unsupported model "{args.model}"')
 
+    (actor, actor_optim), (critic, critic_optim), reward_model, initial_model = strategy.prepare(
+        (actor, actor_optim), (critic, critic_optim), reward_model, initial_model)
+
     # configure trainer
     trainer = PPOTrainer(
         strategy,


@@ -68,6 +68,9 @@ def main(args):
         batch = tokenizer(texts, return_tensors='pt', max_length=96, padding=True, truncation=True)
         return {k: v.cuda() for k, v in batch.items()}
 
+    (actor, actor_optim), (critic, critic_optim), reward_model, initial_model = strategy.prepare(
+        (actor, actor_optim), (critic, critic_optim), reward_model, initial_model)
+
     # configure trainer
     trainer = PPOTrainer(
         strategy,