[chatgpt] strategy add prepare method (#2766)

* [chatgpt] strategy add prepare method

* [chatgpt] refactor examples

* [chatgpt] refactor strategy.prepare

* [chatgpt] support save/load checkpoint

* [chatgpt] fix unwrap actor

* [chatgpt] fix unwrap actor
ver217 authored 2023-02-17 11:27:27 +08:00 (committed by GitHub)
parent a2b43e393d
commit 4ee311c026
9 changed files with 164 additions and 15 deletions
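
In short, this PR moves model and optimizer wrapping out of PPOTrainer and into a new Strategy.prepare method that callers invoke once, up front, and adds checkpoint save/load hooks to the strategies. A hedged usage sketch of the new API (the import path and the model/optimizer objects are assumptions based on the example scripts touched below):

# Sketch only: actor, critic, reward_model, initial_model and their optimizers
# are assumed to be constructed as in the example training scripts.
from chatgpt.trainer.strategies import NaiveStrategy

strategy = NaiveStrategy()

# Fine-tuning: mix (model, optimizer) pairs and bare models in one call;
# everything comes back wrapped for the chosen strategy, in the original order.
(actor, actor_optim), (critic, critic_optim), reward_model, initial_model = strategy.prepare(
    (actor, actor_optim), (critic, critic_optim), reward_model, initial_model)

# Inference only: bare models are accepted too.
actor, critic = strategy.prepare(actor, critic)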

View File

@@ -133,6 +133,9 @@ def main(args):
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
tokenizer.pad_token = tokenizer.eos_token
(actor, actor_optim), (critic, critic_optim), reward_model, initial_model = strategy.prepare(
(actor, actor_optim), (critic, critic_optim), reward_model, initial_model)
trainer = PPOTrainer(strategy,
actor,
critic,

View File

@@ -126,6 +126,9 @@ def main(args):
tokenizer = AutoTokenizer.from_pretrained('facebook/opt-350m')
tokenizer.pad_token = tokenizer.eos_token
(actor, actor_optim), (critic, critic_optim), reward_model, initial_model = strategy.prepare(
(actor, actor_optim), (critic, critic_optim), reward_model, initial_model)
trainer = PPOTrainer(strategy,
actor,
critic,

View File

@@ -60,11 +60,6 @@ class PPOTrainer(Trainer):
dataloader_pin_memory: bool = True,
callbacks: List[Callback] = [],
**generate_kwargs) -> None:
self._set_default_generate_kwargs(generate_kwargs, actor)
actor = Actor(strategy.setup_model(actor.model))
critic = strategy.setup_model(critic)
reward_model = strategy.setup_model(reward_model)
initial_model = Actor(strategy.setup_model(initial_model.model))
experience_maker = NaiveExperienceMaker(actor, critic, reward_model, initial_model, kl_coef)
replay_buffer = NaiveReplayBuffer(train_batch_size, buffer_limit, buffer_cpu_offload)
super().__init__(strategy, experience_maker, replay_buffer, experience_batch_size, max_epochs, tokenizer,
@@ -75,8 +70,9 @@ class PPOTrainer(Trainer):
self.actor_loss_fn = PolicyLoss(eps_clip)
self.critic_loss_fn = ValueLoss(value_clip)
self.actor_optim = strategy.setup_optimizer(actor_optim, self.actor.model)
self.critic_optim = strategy.setup_optimizer(critic_optim, self.critic)
self.actor_optim = actor_optim
self.critic_optim = critic_optim
self._set_default_generate_kwargs(generate_kwargs, actor)
def training_step(self, experience: Experience) -> Dict[str, float]:
self.actor.train()
@@ -106,9 +102,10 @@ class PPOTrainer(Trainer):
return {'actor_loss': actor_loss.item(), 'critic_loss': critic_loss.item()}
def _set_default_generate_kwargs(self, generate_kwargs: dict, actor: Actor) -> None:
origin_model = self.strategy._unwrap_actor(actor)
# use huggingface models method directly
if 'prepare_inputs_fn' not in generate_kwargs and hasattr(actor.model, 'prepare_inputs_for_generation'):
generate_kwargs['prepare_inputs_fn'] = actor.model.prepare_inputs_for_generation
if 'prepare_inputs_fn' not in generate_kwargs and hasattr(origin_model, 'prepare_inputs_for_generation'):
generate_kwargs['prepare_inputs_fn'] = origin_model.prepare_inputs_for_generation
if 'update_model_kwargs_fn' not in generate_kwargs:
generate_kwargs['update_model_kwargs_fn'] = update_model_kwargs_fn
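
The practical effect on PPOTrainer is that it no longer wraps models or optimizers itself; it receives objects that have already been through strategy.prepare and only reaches for the underlying huggingface model when filling in generation defaults. A rough illustration of the wrapping layers under a DDP-style strategy (layer names inferred from the strategy code later in this diff; not a verbatim excerpt):

# Illustration only, assuming a DDP-like strategy:
#   actor              -> chatgpt.nn.Actor (re-wrapped by strategy.prepare)
#   actor.model        -> torch DistributedDataParallel
#   actor.model.module -> the original huggingface model
origin_model = strategy._unwrap_actor(actor)

# _set_default_generate_kwargs looks up generation helpers on the unwrapped model,
# because the DDP/ZeRO wrapper does not expose prepare_inputs_for_generation.
generate_kwargs.setdefault('prepare_inputs_fn', origin_model.prepare_inputs_for_generation)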

View File

@@ -1,12 +1,17 @@
from abc import ABC, abstractmethod
from contextlib import nullcontext
from typing import Any, List, Tuple, Union
import torch
import torch.nn as nn
import torch.optim as optim
from chatgpt.nn import Actor, Critic, RewardModel
from chatgpt.replay_buffer import ReplayBuffer
from torch.optim import Optimizer
from torch.utils.data import DataLoader
ModelOptimPair = Tuple[nn.Module, Optimizer]
ModelOrModelOptimPair = Union[nn.Module, ModelOptimPair]
class Strategy(ABC):
"""
@@ -18,11 +23,11 @@ class Strategy(ABC):
self.setup_distributed()
@abstractmethod
def backward(self, loss: torch.Tensor, model: nn.Module, optimizer: optim.Optimizer, **kwargs) -> None:
def backward(self, loss: torch.Tensor, model: nn.Module, optimizer: Optimizer, **kwargs) -> None:
pass
@abstractmethod
def optimizer_step(self, optimizer: optim.Optimizer, **kwargs) -> None:
def optimizer_step(self, optimizer: Optimizer, **kwargs) -> None:
pass
@abstractmethod
@@ -34,7 +39,7 @@ class Strategy(ABC):
pass
@abstractmethod
def setup_optimizer(self, optimizer: optim.Optimizer, model: nn.Module) -> optim.Optimizer:
def setup_optimizer(self, optimizer: Optimizer, model: nn.Module) -> Optimizer:
pass
@abstractmethod
@@ -43,3 +48,78 @@
def model_init_context(self):
return nullcontext()
def prepare(
self, *models_or_model_optim_pairs: ModelOrModelOptimPair
) -> Union[List[ModelOrModelOptimPair], ModelOrModelOptimPair]:
"""Prepare models or model-optimizer-pairs based on each strategy.
Example::
>>> # when fine-tuning actor and critic
>>> (actor, actor_optim), (critic, critic_optim), reward_model, initial_model = strategy.prepare((actor, actor_optim), (critic, critic_optim), reward_model, initial_model)
>>> # or when training reward model
>>> (reward_model, reward_model_optim) = strategy.prepare((reward_model, reward_model_optim))
>>> # or just inference
>>> actor, critic = strategy.prepare(actor, critic)
Returns:
Union[List[ModelOrModelOptimPair], ModelOrModelOptimPair]: Models or model-optimizer-pairs in the original order.
"""
def prepare_model(model: nn.Module):
if isinstance(model, Actor):
return Actor(self.setup_model(self._unwrap_model(model)))
return self.setup_model(self._unwrap_model(model))
rets = []
for arg in models_or_model_optim_pairs:
if isinstance(arg, tuple):
assert len(arg) == 2, f'Expect (model, optimizer) pair, got a tuple with size "{len(arg)}"'
model, optimizer = arg
model = prepare_model(model)
optimizer = self.setup_optimizer(optimizer, self._unwrap_model(model))
rets.append((model, optimizer))
elif isinstance(arg, nn.Module):
rets.append(prepare_model(arg))
else:
raise RuntimeError(f'Expect model or (model, optimizer) pair, got {type(arg)}')
if len(rets) == 1:
return rets[0]
return rets
@staticmethod
def _unwrap_model(model: nn.Module) -> nn.Module:
"""Useful for saving state dict. As actor is wrapped by Actor class again in `prepare()`, we should unwrap it before saving.
Args:
model (nn.Module): an actor or a critic
"""
if isinstance(model, Actor):
return model.model
return model
@staticmethod
def _unwrap_actor(actor: Actor) -> nn.Module:
"""Get `actor.model` from a wrapped (by `prepare()`) actor. Useful for getting original huggingface model.
Args:
actor (Actor): a wrapped actor
"""
return Strategy._unwrap_model(actor)
@abstractmethod
def save_model(self, model: nn.Module, path: str, only_rank0: bool = False) -> None:
pass
@abstractmethod
def load_model(self, model: nn.Module, path: str, map_location: Any = None, strict: bool = True) -> None:
pass
@abstractmethod
def save_optimizer(self, optimizer: Optimizer, path: str, only_rank0: bool = False) -> None:
pass
@abstractmethod
def load_optimizer(self, optimizer: Optimizer, path: str, map_location: Any = None) -> None:
pass
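
Together with the commit "support save/load checkpoint", these four abstract hooks give every strategy a uniform checkpoint interface. A hedged usage sketch, assuming the models have already been through prepare() and using illustrative paths:

# Illustrative round trip with the new Strategy hooks; rank handling and
# unwrapping are defined by the concrete subclasses further down in this diff.
strategy.save_model(actor, 'actor.pt', only_rank0=True)
strategy.save_optimizer(actor_optim, 'actor_optim.pt', only_rank0=True)

# ...later, after rebuilding the same actor/optimizer and calling prepare():
strategy.load_model(actor, 'actor.pt', map_location='cpu', strict=True)
strategy.load_optimizer(actor_optim, 'actor_optim.pt', map_location='cpu')

Note that only_rank0=True is not valid for optimizer states under ColossalAIStrategy, whose ZeRO shards are saved per rank (see below).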

View File

@@ -1,18 +1,21 @@
import warnings
from typing import Optional
from typing import Optional, Union
import torch
import torch.distributed as dist
import torch.nn as nn
import torch.optim as optim
from chatgpt.nn import Actor
from torch.optim import Optimizer
import colossalai
from colossalai.nn.optimizer import CPUAdam, HybridAdam
from colossalai.nn.parallel import zero_model_wrapper, zero_optim_wrapper
from colossalai.nn.parallel import ZeroDDP, zero_model_wrapper, zero_optim_wrapper
from colossalai.tensor import ProcessGroup, ShardSpec
from colossalai.utils import get_current_device
from colossalai.utils.model.colo_init_context import ColoInitContext
from .base import Strategy
from .ddp import DDPStrategy
@@ -129,3 +132,23 @@ class ColossalAIStrategy(DDPStrategy):
def optimizer_step(self, optimizer: optim.Optimizer, **kwargs) -> None:
optimizer.step()
@staticmethod
def _unwrap_actor(actor: Actor) -> nn.Module:
model: Union[nn.Module, ZeroDDP] = Strategy._unwrap_actor(actor)
if isinstance(model, ZeroDDP):
return model.module
return model
def save_model(self, model: nn.Module, path: str, only_rank0: bool = False) -> None:
unwrapped_model = self._unwrap_model(model)
state_dict = unwrapped_model.state_dict()
if only_rank0 and dist.get_rank() != 0:
return
torch.save(state_dict, path)
def save_optimizer(self, optimizer: Optimizer, path: str, only_rank0: bool = False) -> None:
if only_rank0:
raise RuntimeError(
f'Optimizer states are sharded when using ColossalAIStrategy. Only rank0 is not supported.')
torch.save(optimizer.state_dict(), path)
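
Because ZeRO shards optimizer states across ranks, save_optimizer refuses only_rank0=True here; each rank has to persist its own shard. A hedged sketch of per-rank optimizer checkpointing (the path pattern is an assumption):

# Illustration only: no single rank holds the full optimizer state under ZeRO,
# so every rank writes its own shard.
import torch.distributed as dist

rank = dist.get_rank()
strategy.save_optimizer(actor_optim, f'actor_optim_rank{rank}.pt', only_rank0=False)

# Model weights, by contrast, can still be written by rank 0 alone:
strategy.save_model(actor, 'actor_checkpoint.pt', only_rank0=True)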

View File

@@ -5,10 +5,13 @@ import numpy as np
import torch
import torch.distributed as dist
import torch.nn as nn
from chatgpt.nn import Actor
from chatgpt.replay_buffer import ReplayBuffer
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.optim import Optimizer
from torch.utils.data import DataLoader, DistributedSampler
from .base import Strategy
from .naive import NaiveStrategy
@@ -57,3 +60,18 @@ class DDPStrategy(NaiveStrategy):
sampler=sampler,
pin_memory=pin_memory,
collate_fn=replay_buffer.collate_fn)
@staticmethod
def _unwrap_actor(actor: Actor) -> nn.Module:
model: DDP = Strategy._unwrap_actor(actor)
return model.module
def save_model(self, model: nn.Module, path: str, only_rank0: bool = False) -> None:
if only_rank0 and dist.get_rank() != 0:
return
super().save_model(model, path, only_rank0)
def save_optimizer(self, optimizer: Optimizer, path: str, only_rank0: bool = False) -> None:
if only_rank0 and dist.get_rank() != 0:
return
super().save_optimizer(optimizer, path, only_rank0)
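
Under DDPStrategy the only_rank0 flag simply makes non-zero ranks return early, so every process can call the save methods unconditionally and a single checkpoint file is written. For example (illustrative paths):

# Safe to call from every rank; with only_rank0=True, ranks other than 0 return
# before touching the filesystem.
strategy.save_model(actor, 'actor_checkpoint.pt', only_rank0=True)
strategy.save_optimizer(actor_optim, 'actor_optim_checkpoint.pt', only_rank0=True)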

View File

@@ -1,7 +1,10 @@
from typing import Any
import torch
import torch.nn as nn
import torch.optim as optim
from chatgpt.replay_buffer import ReplayBuffer
from torch.optim import Optimizer
from torch.utils.data import DataLoader
from .base import Strategy
@@ -34,3 +37,19 @@ class NaiveStrategy(Strategy):
drop_last=True,
pin_memory=pin_memory,
collate_fn=replay_buffer.collate_fn)
def save_model(self, model: nn.Module, path: str, only_rank0: bool = False) -> None:
unwrapped_model = self._unwrap_model(model)
torch.save(unwrapped_model.state_dict(), path)
def load_model(self, model: nn.Module, path: str, map_location: Any = None, strict: bool = True) -> None:
unwrapped_model = self._unwrap_model(model)
state_dict = torch.load(path, map_location=map_location)
unwrapped_model.load_state_dict(state_dict, strict=strict)
def save_optimizer(self, optimizer: Optimizer, path: str, only_rank0: bool = False) -> None:
torch.save(optimizer.state_dict(), path)
def load_optimizer(self, optimizer: Optimizer, path: str, map_location: Any = None) -> None:
state_dict = torch.load(path, map_location=map_location)
optimizer.load_state_dict(state_dict)
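
Since save_model always serializes the unwrapped model, a NaiveStrategy checkpoint is interchangeable with a plain huggingface state dict. A hedged round-trip sketch (assuming the actor wraps a GPT2LMHeadModel, as in the GPT2 example above):

# Assumption for illustration: the actor's inner model is a GPT2LMHeadModel.
import torch
from transformers import GPT2LMHeadModel

strategy.save_model(actor, 'actor_checkpoint.pt')   # state dict of the unwrapped model

plain = GPT2LMHeadModel.from_pretrained('gpt2')
plain.load_state_dict(torch.load('actor_checkpoint.pt', map_location='cpu'))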

View File

@@ -68,6 +68,9 @@ def main(args):
else:
raise ValueError(f'Unsupported model "{args.model}"')
(actor, actor_optim), (critic, critic_optim), reward_model, initial_model = strategy.prepare(
(actor, actor_optim), (critic, critic_optim), reward_model, initial_model)
# configure trainer
trainer = PPOTrainer(
strategy,

View File

@@ -68,6 +68,9 @@ def main(args):
batch = tokenizer(texts, return_tensors='pt', max_length=96, padding=True, truncation=True)
return {k: v.cuda() for k, v in batch.items()}
(actor, actor_optim), (critic, critic_optim), reward_model, initial_model = strategy.prepare(
(actor, actor_optim), (critic, critic_optim), reward_model, initial_model)
# configure trainer
trainer = PPOTrainer(
strategy,