mirror of https://github.com/hpcaitech/ColossalAI
[chatgpt] strategy add prepare method (#2766)
* [chatgpt] strategy add prepare method
* [chatgpt] refactor examples
* [chatgpt] refactor strategy.prepare
* [chatgpt] support save/load checkpoint
* [chatgpt] fix unwrap actor
* [chatgpt] fix unwrap actor

parent a2b43e393d
commit 4ee311c026

@@ -133,6 +133,9 @@ def main(args):
     tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
     tokenizer.pad_token = tokenizer.eos_token
 
+    (actor, actor_optim), (critic, critic_optim), reward_model, initial_model = strategy.prepare(
+        (actor, actor_optim), (critic, critic_optim), reward_model, initial_model)
+
     trainer = PPOTrainer(strategy,
                          actor,
                          critic,

@@ -126,6 +126,9 @@ def main(args):
     tokenizer = AutoTokenizer.from_pretrained('facebook/opt-350m')
     tokenizer.pad_token = tokenizer.eos_token
 
+    (actor, actor_optim), (critic, critic_optim), reward_model, initial_model = strategy.prepare(
+        (actor, actor_optim), (critic, critic_optim), reward_model, initial_model)
+
     trainer = PPOTrainer(strategy,
                          actor,
                          critic,

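Taken together, the two example hunks above converge on a single calling pattern: build the raw models and optimizers first, then hand them to strategy.prepare before constructing the trainer. The following is a minimal sketch of that pattern, not code from this diff; the strategies import path, the concrete strategy class, and the Adam hyperparameters are assumptions used only for illustration.

# Sketch only: import path and hyperparameters are assumed, not taken from the diff.
import torch.optim as optim
from chatgpt.trainer.strategies import NaiveStrategy   # assumed import path

strategy = NaiveStrategy()
with strategy.model_init_context():
    # actor, critic, reward_model, initial_model are built here,
    # exactly as in the example scripts (construction omitted).
    ...

actor_optim = optim.Adam(actor.parameters(), lr=5e-6)     # assumed lr
critic_optim = optim.Adam(critic.parameters(), lr=5e-6)   # assumed lr

# prepare() wraps each model for the chosen strategy and returns everything
# in the original order, pairing each optimizer with its wrapped model.
(actor, actor_optim), (critic, critic_optim), reward_model, initial_model = strategy.prepare(
    (actor, actor_optim), (critic, critic_optim), reward_model, initial_model)
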
@@ -60,11 +60,6 @@ class PPOTrainer(Trainer):
                  dataloader_pin_memory: bool = True,
                  callbacks: List[Callback] = [],
                  **generate_kwargs) -> None:
-        self._set_default_generate_kwargs(generate_kwargs, actor)
-        actor = Actor(strategy.setup_model(actor.model))
-        critic = strategy.setup_model(critic)
-        reward_model = strategy.setup_model(reward_model)
-        initial_model = Actor(strategy.setup_model(initial_model.model))
         experience_maker = NaiveExperienceMaker(actor, critic, reward_model, initial_model, kl_coef)
         replay_buffer = NaiveReplayBuffer(train_batch_size, buffer_limit, buffer_cpu_offload)
         super().__init__(strategy, experience_maker, replay_buffer, experience_batch_size, max_epochs, tokenizer,

@@ -75,8 +70,9 @@ class PPOTrainer(Trainer):
         self.actor_loss_fn = PolicyLoss(eps_clip)
         self.critic_loss_fn = ValueLoss(value_clip)
 
-        self.actor_optim = strategy.setup_optimizer(actor_optim, self.actor.model)
-        self.critic_optim = strategy.setup_optimizer(critic_optim, self.critic)
+        self.actor_optim = actor_optim
+        self.critic_optim = critic_optim
+        self._set_default_generate_kwargs(generate_kwargs, actor)
 
     def training_step(self, experience: Experience) -> Dict[str, float]:
         self.actor.train()

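The net effect of the two PPOTrainer hunks above is that the trainer no longer wraps anything itself: the strategy.setup_model/setup_optimizer calls move out of __init__, and the caller is expected to pass in models and optimizers that have already been through strategy.prepare. Calling prepare first also matters for _set_default_generate_kwargs, which now unwraps the already-wrapped actor. A hedged sketch of the resulting call site follows; only the first three positional arguments are visible in these hunks, so the rest are elided.

# Caller-side view after this refactor (sketch; trailing arguments elided).
(actor, actor_optim), (critic, critic_optim), reward_model, initial_model = strategy.prepare(
    (actor, actor_optim), (critic, critic_optim), reward_model, initial_model)

trainer = PPOTrainer(strategy,
                     actor,
                     critic,
                     # ... remaining arguments as in the example scripts,
                     # including actor_optim and critic_optim, which the
                     # trainer now stores as-is instead of re-wrapping.
                     )
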
@@ -106,9 +102,10 @@ class PPOTrainer(Trainer):
         return {'actor_loss': actor_loss.item(), 'critic_loss': critic_loss.item()}
 
     def _set_default_generate_kwargs(self, generate_kwargs: dict, actor: Actor) -> None:
+        origin_model = self.strategy._unwrap_actor(actor)
         # use huggingface models method directly
-        if 'prepare_inputs_fn' not in generate_kwargs and hasattr(actor.model, 'prepare_inputs_for_generation'):
-            generate_kwargs['prepare_inputs_fn'] = actor.model.prepare_inputs_for_generation
+        if 'prepare_inputs_fn' not in generate_kwargs and hasattr(origin_model, 'prepare_inputs_for_generation'):
+            generate_kwargs['prepare_inputs_fn'] = origin_model.prepare_inputs_for_generation
 
         if 'update_model_kwargs_fn' not in generate_kwargs:
             generate_kwargs['update_model_kwargs_fn'] = update_model_kwargs_fn

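The _unwrap_actor indirection above matters because, after strategy.prepare, actor.model may be a DistributedDataParallel or ZeroDDP wrapper rather than the HuggingFace model, and such wrappers do not re-export arbitrary methods like prepare_inputs_for_generation. A self-contained toy (plain PyTorch, no chatgpt imports; the class names are invented stand-ins) illustrating why the hasattr check has to run on the unwrapped model:

import torch.nn as nn

class Inner(nn.Module):
    """Stand-in for a HuggingFace model exposing the generation hook."""
    def prepare_inputs_for_generation(self, *args, **kwargs):
        return {}

class Wrapper(nn.Module):
    """Stand-in for DDP/ZeroDDP: holds the real model as .module."""
    def __init__(self, module: nn.Module):
        super().__init__()
        self.module = module

wrapped = Wrapper(Inner())
assert not hasattr(wrapped, 'prepare_inputs_for_generation')    # wrapper hides the hook
assert hasattr(wrapped.module, 'prepare_inputs_for_generation')  # unwrapping restores it
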
@@ -1,12 +1,17 @@
 from abc import ABC, abstractmethod
 from contextlib import nullcontext
+from typing import Any, List, Tuple, Union
 
 import torch
 import torch.nn as nn
-import torch.optim as optim
+from chatgpt.nn import Actor, Critic, RewardModel
 from chatgpt.replay_buffer import ReplayBuffer
+from torch.optim import Optimizer
 from torch.utils.data import DataLoader
 
+ModelOptimPair = Tuple[nn.Module, Optimizer]
+ModelOrModelOptimPair = Union[nn.Module, ModelOptimPair]
+
 
 class Strategy(ABC):
     """

@@ -18,11 +23,11 @@ class Strategy(ABC):
         self.setup_distributed()
 
     @abstractmethod
-    def backward(self, loss: torch.Tensor, model: nn.Module, optimizer: optim.Optimizer, **kwargs) -> None:
+    def backward(self, loss: torch.Tensor, model: nn.Module, optimizer: Optimizer, **kwargs) -> None:
         pass
 
     @abstractmethod
-    def optimizer_step(self, optimizer: optim.Optimizer, **kwargs) -> None:
+    def optimizer_step(self, optimizer: Optimizer, **kwargs) -> None:
         pass
 
     @abstractmethod

@@ -34,7 +39,7 @@ class Strategy(ABC):
         pass
 
     @abstractmethod
-    def setup_optimizer(self, optimizer: optim.Optimizer, model: nn.Module) -> optim.Optimizer:
+    def setup_optimizer(self, optimizer: Optimizer, model: nn.Module) -> Optimizer:
         pass
 
     @abstractmethod

@@ -43,3 +48,78 @@ class Strategy(ABC):
 
     def model_init_context(self):
         return nullcontext()
+
+    def prepare(
+        self, *models_or_model_optim_pairs: ModelOrModelOptimPair
+    ) -> Union[List[ModelOrModelOptimPair], ModelOrModelOptimPair]:
+        """Prepare models or model-optimizer-pairs based on each strategy.
+
+        Example::
+            >>> # when fine-tuning actor and critic
+            >>> (actor, actor_optim), (critic, critic_optim), reward_model, initial_model = strategy.prepare((actor, actor_optim), (critic, critic_optim), reward_model, initial_model)
+            >>> # or when training reward model
+            >>> (reward_model, reward_model_optim) = strategy.prepare((reward_model, reward_model_optim))
+            >>> # or just inference
+            >>> actor, critic = strategy.prepare(actor, critic)
+
+        Returns:
+            Union[List[ModelOrModelOptimPair], ModelOrModelOptimPair]: Models or model-optimizer-pairs in the original order.
+        """
+
+        def prepare_model(model: nn.Module):
+            if isinstance(model, Actor):
+                return Actor(self.setup_model(self._unwrap_model(model)))
+            return self.setup_model(self._unwrap_model(model))
+
+        rets = []
+        for arg in models_or_model_optim_pairs:
+            if isinstance(arg, tuple):
+                assert len(arg) == 2, f'Expect (model, optimizer) pair, got a tuple with size "{len(arg)}"'
+                model, optimizer = arg
+                model = prepare_model(model)
+                optimizer = self.setup_optimizer(optimizer, self._unwrap_model(model))
+                rets.append((model, optimizer))
+            elif isinstance(arg, nn.Module):
+                rets.append(prepare_model(arg))
+            else:
+                raise RuntimeError(f'Expect model or (model, optimizer) pair, got {type(arg)}')
+
+        if len(rets) == 1:
+            return rets[0]
+        return rets
+
+    @staticmethod
+    def _unwrap_model(model: nn.Module) -> nn.Module:
+        """Useful for saving state dict. As actor is wrapped by Actor class again in `prepare()`, we should unwrap it before saving.
+
+        Args:
+            model (nn.Module): an actor or a critic
+        """
+        if isinstance(model, Actor):
+            return model.model
+        return model
+
+    @staticmethod
+    def _unwrap_actor(actor: Actor) -> nn.Module:
+        """Get `actor.model` from a wrapped (by `prepare()`) actor. Useful for getting original huggingface model.
+
+        Args:
+            actor (Actor): a wrapped actor
+        """
+        return Strategy._unwrap_model(actor)
+
+    @abstractmethod
+    def save_model(self, model: nn.Module, path: str, only_rank0: bool = False) -> None:
+        pass
+
+    @abstractmethod
+    def load_model(self, model: nn.Module, path: str, map_location: Any = None, strict: bool = True) -> None:
+        pass
+
+    @abstractmethod
+    def save_optimizer(self, optimizer: Optimizer, path: str, only_rank0: bool = False) -> None:
+        pass
+
+    @abstractmethod
+    def load_optimizer(self, optimizer: Optimizer, path: str, map_location: Any = None) -> None:
+        pass

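Beyond prepare(), this hunk also fixes the checkpoint surface every strategy must implement: save_model, load_model, save_optimizer, load_optimizer. A minimal round-trip sketch against that interface, reusing the actor/actor_optim names from the examples; the file names and the map_location choice are assumptions for illustration.

# Checkpoint round trip against the new abstract API (sketch; paths assumed).
strategy.save_model(actor, 'actor.pt', only_rank0=True)
strategy.save_optimizer(actor_optim, 'actor_optim.pt')

# ... later, e.g. when resuming training ...
strategy.load_model(actor, 'actor.pt', map_location='cpu', strict=True)
strategy.load_optimizer(actor_optim, 'actor_optim.pt', map_location='cpu')
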
@@ -1,18 +1,21 @@
 import warnings
-from typing import Optional
+from typing import Optional, Union
 
 import torch
 import torch.distributed as dist
 import torch.nn as nn
 import torch.optim as optim
+from chatgpt.nn import Actor
+from torch.optim import Optimizer
 
 import colossalai
 from colossalai.nn.optimizer import CPUAdam, HybridAdam
-from colossalai.nn.parallel import zero_model_wrapper, zero_optim_wrapper
+from colossalai.nn.parallel import ZeroDDP, zero_model_wrapper, zero_optim_wrapper
 from colossalai.tensor import ProcessGroup, ShardSpec
 from colossalai.utils import get_current_device
 from colossalai.utils.model.colo_init_context import ColoInitContext
 
+from .base import Strategy
 from .ddp import DDPStrategy
 
 

@@ -129,3 +132,23 @@ class ColossalAIStrategy(DDPStrategy):
 
     def optimizer_step(self, optimizer: optim.Optimizer, **kwargs) -> None:
         optimizer.step()
+
+    @staticmethod
+    def _unwrap_actor(actor: Actor) -> nn.Module:
+        model: Union[nn.Module, ZeroDDP] = Strategy._unwrap_actor(actor)
+        if isinstance(model, ZeroDDP):
+            return model.module
+        return model
+
+    def save_model(self, model: nn.Module, path: str, only_rank0: bool = False) -> None:
+        unwrapped_model = self._unwrap_model(model)
+        state_dict = unwrapped_model.state_dict()
+        if only_rank0 and dist.get_rank() != 0:
+            return
+        torch.save(state_dict, path)
+
+    def save_optimizer(self, optimizer: Optimizer, path: str, only_rank0: bool = False) -> None:
+        if only_rank0:
+            raise RuntimeError(
+                f'Optimizer states are sharded when using ColossalAIStrategy. Only rank0 is not supported.')
+        torch.save(optimizer.state_dict(), path)

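Note the asymmetry introduced here: save_model honours only_rank0, whereas optimizer state under ColossalAI ZeRO stays sharded per rank, so save_optimizer(..., only_rank0=True) deliberately raises. A hedged sketch of per-rank optimizer checkpoints; the rank-tagged file naming is an assumption, not something this diff prescribes.

import torch.distributed as dist

# Every rank writes its own optimizer shard; only rank 0 writes the model weights.
rank = dist.get_rank()
strategy.save_model(actor, 'actor.pt', only_rank0=True)
strategy.save_optimizer(actor_optim, f'actor_optim.rank{rank}.pt', only_rank0=False)
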
@@ -5,10 +5,13 @@ import numpy as np
 import torch
 import torch.distributed as dist
 import torch.nn as nn
+from chatgpt.nn import Actor
 from chatgpt.replay_buffer import ReplayBuffer
 from torch.nn.parallel import DistributedDataParallel as DDP
+from torch.optim import Optimizer
 from torch.utils.data import DataLoader, DistributedSampler
 
+from .base import Strategy
 from .naive import NaiveStrategy
 
 

@@ -57,3 +60,18 @@ class DDPStrategy(NaiveStrategy):
                           sampler=sampler,
                           pin_memory=pin_memory,
                           collate_fn=replay_buffer.collate_fn)
+
+    @staticmethod
+    def _unwrap_actor(actor: Actor) -> nn.Module:
+        model: DDP = Strategy._unwrap_actor(actor)
+        return model.module
+
+    def save_model(self, model: nn.Module, path: str, only_rank0: bool = False) -> None:
+        if only_rank0 and dist.get_rank() != 0:
+            return
+        super().save_model(model, path, only_rank0)
+
+    def save_optimizer(self, optimizer: Optimizer, path: str, only_rank0: bool = False) -> None:
+        if only_rank0 and dist.get_rank() != 0:
+            return
+        super().save_optimizer(optimizer, path, only_rank0)

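For DDPStrategy the unwrap chain after prepare() is worth spelling out: Strategy._unwrap_actor strips the Actor wrapper and returns actor.model, which at that point is a DistributedDataParallel, and the override above then returns its .module, i.e. the original HuggingFace model. A short sketch of that relationship, assuming actor has already been through strategy.prepare under a DDPStrategy:

# After strategy.prepare((actor, actor_optim), ...) with a DDPStrategy:
#   actor               -> chatgpt.nn.Actor
#   actor.model         -> torch.nn.parallel.DistributedDataParallel
#   actor.model.module  -> the original HuggingFace model
original = strategy._unwrap_actor(actor)
assert original is actor.model.module
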
@@ -1,7 +1,10 @@
+from typing import Any
+
 import torch
 import torch.nn as nn
 import torch.optim as optim
 from chatgpt.replay_buffer import ReplayBuffer
+from torch.optim import Optimizer
 from torch.utils.data import DataLoader
 
 from .base import Strategy

@@ -34,3 +37,19 @@ class NaiveStrategy(Strategy):
                           drop_last=True,
                           pin_memory=pin_memory,
                           collate_fn=replay_buffer.collate_fn)
+
+    def save_model(self, model: nn.Module, path: str, only_rank0: bool = False) -> None:
+        unwrapped_model = self._unwrap_model(model)
+        torch.save(unwrapped_model.state_dict(), path)
+
+    def load_model(self, model: nn.Module, path: str, map_location: Any = None, strict: bool = True) -> None:
+        unwrapped_model = self._unwrap_model(model)
+        state_dict = torch.load(path, map_location=map_location)
+        unwrapped_model.load_state_dict(state_dict, strict=strict)
+
+    def save_optimizer(self, optimizer: Optimizer, path: str, only_rank0: bool = False) -> None:
+        torch.save(optimizer.state_dict(), path)
+
+    def load_optimizer(self, optimizer: Optimizer, path: str, map_location: Any = None) -> None:
+        state_dict = torch.load(path, map_location=map_location)
+        optimizer.load_state_dict(state_dict)

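Because save_model stores the unwrapped model's state dict, a checkpoint written by NaiveStrategy should be loadable into a plain HuggingFace model outside of the chatgpt package. A hedged sketch, assuming the actor wrapped a GPT2LMHeadModel built from the stock 'gpt2' configuration as in the GPT example and that 'actor.pt' was written by strategy.save_model:

import torch
from transformers import GPT2LMHeadModel

# Assumes the checkpoint's keys match the default 'gpt2' architecture.
model = GPT2LMHeadModel.from_pretrained('gpt2')
state_dict = torch.load('actor.pt', map_location='cpu')
model.load_state_dict(state_dict)
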
@@ -68,6 +68,9 @@ def main(args):
     else:
         raise ValueError(f'Unsupported model "{args.model}"')
 
+    (actor, actor_optim), (critic, critic_optim), reward_model, initial_model = strategy.prepare(
+        (actor, actor_optim), (critic, critic_optim), reward_model, initial_model)
+
     # configure trainer
     trainer = PPOTrainer(
         strategy,

@@ -68,6 +68,9 @@ def main(args):
         batch = tokenizer(texts, return_tensors='pt', max_length=96, padding=True, truncation=True)
         return {k: v.cuda() for k, v in batch.items()}
 
+    (actor, actor_optim), (critic, critic_optim), reward_model, initial_model = strategy.prepare(
+        (actor, actor_optim), (critic, critic_optim), reward_model, initial_model)
+
     # configure trainer
     trainer = PPOTrainer(
         strategy,