from abc import ABC, abstractmethod
from contextlib import nullcontext
from typing import Callable, Dict, List, Optional, Tuple, Union

import torch
import torch.nn as nn
from coati.replay_buffer import ReplayBuffer
from torch.optim import Optimizer
from torch.utils.data import DataLoader
from transformers.tokenization_utils_base import PreTrainedTokenizerBase

from colossalai.booster import Booster
from colossalai.booster.plugin import Plugin

from .sampler import DistributedSampler

_BoostArgSpec = Union[nn.Module, Tuple[nn.Module, Optimizer], Dict]


class Strategy(ABC):
    """
    Base class for training strategies.
    """

    def __init__(self, plugin_initializer: Callable[..., Optional[Plugin]] = lambda: None) -> None:
        super().__init__()
        # NOTE: dist must be initialized before Booster
        self.setup_distributed()
        self.plugin = plugin_initializer()
        self.booster = Booster(plugin=self.plugin)
        self._post_init()
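
    # Illustrative usage of ``plugin_initializer`` (an assumption for
    # demonstration, not something this base class prescribes): a concrete
    # subclass would typically bind a real plugin, e.g.
    #     super().__init__(plugin_initializer=lambda: LowLevelZeroPlugin(stage=2))
    # where ``LowLevelZeroPlugin`` is one of the ``colossalai.booster.plugin``
    # plugins; any callable returning ``Optional[Plugin]`` works.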

    @abstractmethod
    def _post_init(self) -> None:
        pass

    def backward(self, loss: torch.Tensor, model: nn.Module, optimizer: Optimizer, **kwargs) -> None:
        self.booster.backward(loss, optimizer)

    def optimizer_step(self, optimizer: Optimizer, **kwargs) -> None:
        optimizer.step()

    @abstractmethod
    def setup_distributed(self) -> None:
        pass

    @abstractmethod
    def setup_dataloader(self, replay_buffer: ReplayBuffer, pin_memory: bool = False) -> DataLoader:
        pass

    def model_init_context(self):
        return nullcontext()
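
    # Illustrative note: strategies that build very large models may override
    # ``model_init_context`` to return a context manager that defers or shards
    # parameter allocation (e.g. a lazy or meta-device init context). The
    # default above is a no-op ``nullcontext``; the override behaviour
    # described here is an assumption about subclasses, not enforced by this
    # base class.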

    def prepare(self, *boost_args: _BoostArgSpec) -> Union[List[_BoostArgSpec], _BoostArgSpec]:
        """Prepare [model | (model, optimizer) | Dict] based on each strategy.

        NOTE: the keys of the Dict form must be a subset of the arguments of
        ``self.booster.boost``.

        Example::
            >>> # e.g. also pass an lr_scheduler
            >>> result_dict = strategy.prepare(dict(model=model, lr_scheduler=lr_scheduler))
            >>> # when fine-tuning actor and critic
            >>> (actor, actor_optim), (critic, critic_optim), reward_model, initial_model = strategy.prepare((actor, actor_optim), (critic, critic_optim), reward_model, initial_model)
            >>> # or when training a reward model
            >>> (reward_model, reward_model_optim) = strategy.prepare((reward_model, reward_model_optim))
            >>> # or for inference only
            >>> actor, critic = strategy.prepare(actor, critic)

        Returns:
            Union[List[_BoostArgSpec], _BoostArgSpec]: [model | (model, optimizer) | Dict] in the original order.
        """
        rets = []
        for arg in boost_args:
            if isinstance(arg, nn.Module):
                model, *_ = self.booster.boost(arg)
                rets.append(model)
            elif isinstance(arg, tuple):
                try:
                    model, optimizer = arg
                except ValueError:
                    raise RuntimeError(f'Expected a (model, optimizer) pair, got a tuple of size {len(arg)}')
                model, optimizer, *_ = self.booster.boost(model=model, optimizer=optimizer)
                rets.append((model, optimizer))
            elif isinstance(arg, dict):
                model, optimizer, criterion, dataloader, lr_scheduler = self.booster.boost(**arg)
                boost_result = dict(model=model,
                                    optimizer=optimizer,
                                    criterion=criterion,
                                    dataloader=dataloader,
                                    lr_scheduler=lr_scheduler)
                # drop entries that boost returned as None
                boost_result = {key: value for key, value in boost_result.items() if value is not None}
                rets.append(boost_result)
            else:
                raise RuntimeError(f'Type {type(arg)} is not supported')

        return rets[0] if len(rets) == 1 else rets

    @staticmethod
    def unwrap_model(model: nn.Module) -> nn.Module:
        """Get the unwrapped model from a model wrapped by ``Strategy.prepare``.

        Args:
            model (nn.Module): the model to unwrap

        Returns:
            nn.Module: the original model
        """
        return model
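
    # Illustrative usage (an assumption about how callers use this hook, not a
    # guarantee of the base class): strategies whose boosted models come back
    # wrapped (e.g. in DDP-style containers) are expected to override this so
    # that
    #     wrapped = strategy.prepare(model)
    #     raw = strategy.unwrap_model(wrapped)
    # returns the underlying ``nn.Module`` suitable for saving or exporting.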

    def save_model(self,
                   model: nn.Module,
                   path: str,
                   only_rank0: bool = True,
                   **kwargs) -> None:
        self.booster.save_model(model, path, shard=not only_rank0, **kwargs)

    def load_model(self, model: nn.Module, path: str, strict: bool = True) -> None:
        self.booster.load_model(model, path, strict)

    def save_optimizer(self,
                       optimizer: Optimizer,
                       path: str,
                       only_rank0: bool = False,
                       **kwargs) -> None:
        self.booster.save_optimizer(optimizer, path, shard=not only_rank0, **kwargs)

    def load_optimizer(self, optimizer: Optimizer, path: str) -> None:
        self.booster.load_optimizer(optimizer, path)

    def setup_sampler(self, dataset) -> DistributedSampler:
        # FIXME(cwher): this is only invoked in train_on_ray and has not been
        # tested since adapting to the Booster API.
        return DistributedSampler(dataset, 1, 0)

    @abstractmethod
    def save_pretrained(self,
                        model: nn.Module,
                        path: str,
                        only_rank0: bool = True,
                        tokenizer: Optional[PreTrainedTokenizerBase] = None) -> None:
        pass

    @abstractmethod
    def get_model_state_dict_shard(self, model: nn.Module, **config):
        pass
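

# ---------------------------------------------------------------------------
# Minimal sketch of a concrete strategy (illustration only). Everything below
# is an assumption about how a subclass might fill in the abstract hooks; it
# is not part of the Strategy API. In particular, the class name and the
# ReplayBuffer attributes ``sample_batch_size`` / ``collate_fn`` are assumed
# here for demonstration.
# ---------------------------------------------------------------------------
class _NaiveStrategySketch(Strategy):
    """Single-process example: no real distributed setup, plain ``DataLoader``."""

    def _post_init(self) -> None:
        # nothing extra to validate in this sketch
        pass

    def setup_distributed(self) -> None:
        # a real strategy would initialize torch.distributed / colossalai here
        pass

    def setup_dataloader(self, replay_buffer: ReplayBuffer, pin_memory: bool = False) -> DataLoader:
        # assumes the replay buffer exposes a sample batch size and a collate_fn
        return DataLoader(replay_buffer,
                          batch_size=replay_buffer.sample_batch_size,
                          shuffle=True,
                          pin_memory=pin_memory,
                          collate_fn=replay_buffer.collate_fn)

    def save_pretrained(self,
                        model: nn.Module,
                        path: str,
                        only_rank0: bool = True,
                        tokenizer: Optional[PreTrainedTokenizerBase] = None) -> None:
        unwrapped = self.unwrap_model(model)
        if hasattr(unwrapped, "save_pretrained"):
            # HuggingFace-style models expose save_pretrained; otherwise skip
            unwrapped.save_pretrained(path)
        if tokenizer is not None:
            tokenizer.save_pretrained(path)

    def get_model_state_dict_shard(self, model: nn.Module, **config):
        # yield the full state dict as a single "shard" in this sketch
        yield self.unwrap_model(model).state_dict()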