ColossalAI/applications/Chat/coati/trainer/strategies/base.py

from abc import ABC, abstractmethod
from contextlib import nullcontext
from typing import Callable, Dict, List, Optional, Tuple, Union

import torch
import torch.nn as nn
from coati.replay_buffer import ReplayBuffer
from torch.optim import Optimizer
from torch.utils.data import DataLoader
from transformers.tokenization_utils_base import PreTrainedTokenizerBase

from colossalai.booster import Booster
from colossalai.booster.plugin import Plugin

from .sampler import DistributedSampler

_BoostArgSpec = Union[nn.Module, Tuple[nn.Module, Optimizer], Dict]


class Strategy(ABC):
"""
Base class for training strategies.
"""
def __init__(self, plugin_initializer: Callable[..., Optional[Plugin]] = lambda: None) -> None:
2023-03-28 12:25:36 +00:00
super().__init__()
# NOTE: dist must be initialized before Booster
2023-03-28 12:25:36 +00:00
self.setup_distributed()
self.plugin = plugin_initializer()
self.booster = Booster(plugin=self.plugin)
self._post_init()
2023-03-28 12:25:36 +00:00
@abstractmethod
def _post_init(self) -> None:
2023-03-28 12:25:36 +00:00
pass
    def backward(self, loss: torch.Tensor, model: nn.Module, optimizer: Optimizer, **kwargs) -> None:
        # Delegate to the booster so plugin-specific backward logic is applied.
        self.booster.backward(loss, optimizer)

    def optimizer_step(self, optimizer: Optimizer, **kwargs) -> None:
        optimizer.step()

    @abstractmethod
    def setup_distributed(self) -> None:
        pass

    @abstractmethod
    def setup_dataloader(self, replay_buffer: ReplayBuffer, pin_memory: bool = False) -> DataLoader:
        pass

    def model_init_context(self):
        # Context under which models should be instantiated; the base class uses a no-op context.
        return nullcontext()
    def prepare(self, *boost_args: _BoostArgSpec) -> Union[List[_BoostArgSpec], _BoostArgSpec]:
        """Prepare [model | (model, optimizer) | Dict] based on each strategy.
        NOTE: the keys of Dict must be a subset of `self.booster.boost`'s arguments.

        Example::
            >>> # e.g., include lr_scheduler
            >>> result_dict = strategy.prepare(dict(model=model, lr_scheduler=lr_scheduler))

            >>> # when fine-tuning actor and critic
            >>> (actor, actor_optim), (critic, critic_optim), reward_model, initial_model = strategy.prepare((actor, actor_optim), (critic, critic_optim), reward_model, initial_model)
            >>> # or when training reward model
            >>> (reward_model, reward_model_optim) = strategy.prepare((reward_model, reward_model_optim))
            >>> # or just inference
            >>> actor, critic = strategy.prepare(actor, critic)

        Returns:
            Union[List[_BoostArgSpec], _BoostArgSpec]: [model | (model, optimizer) | Dict] in the original order.
        """
        rets = []
        for arg in boost_args:
            if isinstance(arg, nn.Module):
                model, *_ = self.booster.boost(arg)
                rets.append(model)
            elif isinstance(arg, tuple):
                try:
                    model, optimizer = arg
                except ValueError:
                    raise RuntimeError(f'Expected a (model, optimizer) pair, got a tuple of size {len(arg)}')
                model, optimizer, *_ = self.booster.boost(model=model, optimizer=optimizer)
                rets.append((model, optimizer))
            elif isinstance(arg, dict):
                model, optimizer, criterion, dataloader, lr_scheduler = self.booster.boost(**arg)
                boost_result = dict(model=model,
                                    optimizer=optimizer,
                                    criterion=criterion,
                                    dataloader=dataloader,
                                    lr_scheduler=lr_scheduler)
                # remove None values
                boost_result = {key: value for key, value in boost_result.items() if value is not None}
                rets.append(boost_result)
            else:
                raise RuntimeError(f'Type {type(arg)} is not supported')

        return rets[0] if len(rets) == 1 else rets
    @staticmethod
    def unwrap_model(model: nn.Module) -> nn.Module:
        """Get the unwrapped model from a wrapped model made by Strategy.prepare.

        Args:
            model (nn.Module): the model to unwrap

        Returns:
            nn.Module: the original model
        """
        return model
    def save_model(self,
                   model: nn.Module,
                   path: str,
                   only_rank0: bool = True,
                   **kwargs) -> None:
        self.booster.save_model(model, path, shard=not only_rank0, **kwargs)

    def load_model(self, model: nn.Module, path: str, strict: bool = True) -> None:
        self.booster.load_model(model, path, strict)

    def save_optimizer(self,
                       optimizer: Optimizer,
                       path: str,
                       only_rank0: bool = False,
                       **kwargs) -> None:
        self.booster.save_optimizer(optimizer, path, shard=not only_rank0, **kwargs)

    def load_optimizer(self, optimizer: Optimizer, path: str) -> None:
        self.booster.load_optimizer(optimizer, path)

    def setup_sampler(self, dataset) -> DistributedSampler:
        # FIXME(cwher): this is only invoked in train_on_ray and has not been tested after adapting to the Booster API.
        return DistributedSampler(dataset, 1, 0)
    @abstractmethod
    def save_pretrained(self,
                        model: nn.Module,
                        path: str,
                        only_rank0: bool = True,
                        tokenizer: Optional[PreTrainedTokenizerBase] = None) -> None:
        pass
    @abstractmethod
    def get_model_state_dict_shard(self, model: nn.Module, **config):
        pass
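

# ---------------------------------------------------------------------------
# Illustrative sketch (added for exposition, not part of the original module):
# a minimal, single-process Strategy subclass showing how the abstract hooks
# fit together. The class name `_ExampleStrategy`, the fixed batch size, and
# the DataLoader wiring below are assumptions for demonstration only; the real
# strategies shipped with coati implement these hooks with distributed setup,
# plugin selection, and sharded checkpointing.
# ---------------------------------------------------------------------------
class _ExampleStrategy(Strategy):
    """Hypothetical strategy used only to illustrate the Strategy interface."""

    def _post_init(self) -> None:
        pass

    def setup_distributed(self) -> None:
        # A real strategy would initialize the distributed environment here
        # (called by Strategy.__init__ before the Booster is created).
        pass

    def setup_dataloader(self, replay_buffer: ReplayBuffer, pin_memory: bool = False) -> DataLoader:
        # Assumes the replay buffer behaves like a map-style dataset; the batch
        # size and collate function are illustrative placeholders.
        return DataLoader(replay_buffer,
                          batch_size=8,
                          pin_memory=pin_memory,
                          collate_fn=getattr(replay_buffer, 'collate_fn', None))

    def save_pretrained(self,
                        model: nn.Module,
                        path: str,
                        only_rank0: bool = True,
                        tokenizer: Optional[PreTrainedTokenizerBase] = None) -> None:
        # Unwrap first so the underlying HuggingFace-style model (if any) is saved.
        unwrapped = self.unwrap_model(model)
        if hasattr(unwrapped, 'save_pretrained'):
            unwrapped.save_pretrained(path)
        if tokenizer is not None:
            tokenizer.save_pretrained(path)

    def get_model_state_dict_shard(self, model: nn.Module, **config):
        # Yield the full state dict as a single "shard" for simplicity.
        yield self.unwrap_model(model).state_dict()


# Usage sketch (assuming `actor` and `actor_optim` already exist):
#     strategy = _ExampleStrategy()
#     actor, actor_optim = strategy.prepare((actor, actor_optim))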