[chat] refactor strategy class with booster api (#3987)

* refactor: adapt boost API in base and naive strategies

* fix: initialize plugin after setup_distributed

* fix: fix save_pretrained fn

* refactor: adapt boost API in DDPStrategy

* to: add _post_init check

* to: fix ddp backward, modify ddp dataloader and unwrap

* feat: adapt boost API in ColossalAIStrategy

* fix: call setup_distributed before use get_current_device

* fix: fix save_model and save_optimizer

* test: remove save_sharded_optimizer test

* style: apply formatter

* fix: fix stage check and add comments

* feat: allow dict type arg in strategy.prepare

* to: temporarily remove lr_scheduler for testing

* style: simplify init of ColossalAIStrategy

* fix: fix lr_scheduler in sft and rm

* style: modify comments

* test: add train_prompts tests

* fix: fix inference only case and use in train_prompts

* test: skip failed tests in ci

* style: fix CodeFactor check

* fix: do not use model.to('cpu') with GeminiPlugin

* test: enable colossalai_gemini tests

* test: set CUDA_VISIBLE_DEVICES in ci

* docs: add note
pull/4046/merge
Wenhao Chen 2023-06-25 17:36:21 +08:00 committed by GitHub
parent b463651f3e
commit 153b957a1b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
13 changed files with 350 additions and 290 deletions

View File

@ -19,8 +19,10 @@ from colossalai.nn.optimizer import HybridAdam
def get_model_numel(model: nn.Module, strategy: Strategy) -> int: def get_model_numel(model: nn.Module, strategy: Strategy) -> int:
numel = sum(p.numel() for p in model.parameters()) numel = sum(p.numel() for p in model.parameters())
if isinstance(strategy, ColossalAIStrategy) and strategy.stage == 3 and strategy.shard_init: if isinstance(strategy, ColossalAIStrategy):
numel *= dist.get_world_size() from colossalai.booster.plugin import GeminiPlugin
if isinstance(strategy.plugin, GeminiPlugin) and strategy.shard_init:
numel *= dist.get_world_size()
return numel return numel

View File

@ -17,7 +17,7 @@ from colossalai.utils import get_current_device
from .base import Trainer from .base import Trainer
from .callbacks import Callback from .callbacks import Callback
from .strategies import Strategy from .strategies import ColossalAIStrategy, Strategy
from .utils import is_rank_0, to_device from .utils import is_rank_0, to_device
@ -71,6 +71,11 @@ class PPOTrainer(Trainer):
offload_inference_models: bool = True, offload_inference_models: bool = True,
callbacks: List[Callback] = [], callbacks: List[Callback] = [],
**generate_kwargs) -> None: **generate_kwargs) -> None:
if isinstance(strategy, ColossalAIStrategy):
from colossalai.booster.plugin import GeminiPlugin
assert not (isinstance(strategy.plugin, GeminiPlugin) and offload_inference_models), \
"GeminiPlugin is not compatible with manual model.to('cpu')"
experience_maker = NaiveExperienceMaker(actor, critic, reward_model, initial_model, kl_coef) experience_maker = NaiveExperienceMaker(actor, critic, reward_model, initial_model, kl_coef)
replay_buffer = NaiveReplayBuffer(train_batch_size, buffer_limit, buffer_cpu_offload) replay_buffer = NaiveReplayBuffer(train_batch_size, buffer_limit, buffer_cpu_offload)
generate_kwargs = _set_default_generate_kwargs(strategy, generate_kwargs, actor) generate_kwargs = _set_default_generate_kwargs(strategy, generate_kwargs, actor)
@ -105,6 +110,8 @@ class PPOTrainer(Trainer):
def _learn(self): def _learn(self):
# replay buffer may be empty at first, we should rebuild at each training # replay buffer may be empty at first, we should rebuild at each training
if not self.sample_replay_buffer: if not self.sample_replay_buffer:
# HACK(cwher): according to the design of boost API, dataloader should also be boosted,
# but it is impractical to adapt this pattern in RL training. Thus, I left dataloader unboosted.
dataloader = self.strategy.setup_dataloader(self.replay_buffer, self.dataloader_pin_memory) dataloader = self.strategy.setup_dataloader(self.replay_buffer, self.dataloader_pin_memory)
if self.sample_replay_buffer: if self.sample_replay_buffer:
pbar = tqdm(range(self.max_epochs), desc='Train epoch', disable=not is_rank_0()) pbar = tqdm(range(self.max_epochs), desc='Train epoch', disable=not is_rank_0())

View File

@ -1,13 +1,12 @@
from datetime import datetime from datetime import datetime
from typing import List, Optional from typing import Callable, List
import pandas as pd import pandas as pd
import torch import torch
import torch.distributed as dist from torch.optim import Optimizer
from torch.optim import Optimizer, lr_scheduler from torch.optim.lr_scheduler import _LRScheduler
from torch.utils.data import DataLoader, Dataset, DistributedSampler from torch.utils.data import DataLoader
from tqdm import tqdm from tqdm import tqdm
from transformers.tokenization_utils_base import PreTrainedTokenizerBase
from .base import Trainer from .base import Trainer
from .callbacks import Callback from .callbacks import Callback
@ -22,7 +21,8 @@ class RewardModelTrainer(Trainer):
Args: Args:
model (torch.nn.Module): the model to train model (torch.nn.Module): the model to train
strategy (Strategy): the strategy to use for training strategy (Strategy): the strategy to use for training
optim(Optimizer): the optimizer to use for training optim (Optimizer): the optimizer to use for training
lr_scheduler (_LRScheduler): the lr scheduler to use for training
loss_fn (callable): the loss function to use for training loss_fn (callable): the loss function to use for training
train_dataloader (DataLoader): the dataloader to use for training train_dataloader (DataLoader): the dataloader to use for training
valid_dataloader (DataLoader): the dataloader to use for validation valid_dataloader (DataLoader): the dataloader to use for validation
@ -37,7 +37,8 @@ class RewardModelTrainer(Trainer):
model, model,
strategy: Strategy, strategy: Strategy,
optim: Optimizer, optim: Optimizer,
loss_fn, lr_scheduler: _LRScheduler,
loss_fn: Callable,
train_dataloader: DataLoader, train_dataloader: DataLoader,
valid_dataloader: DataLoader, valid_dataloader: DataLoader,
eval_dataloader: DataLoader, eval_dataloader: DataLoader,
@ -53,7 +54,7 @@ class RewardModelTrainer(Trainer):
self.model = model self.model = model
self.loss_fn = loss_fn self.loss_fn = loss_fn
self.optimizer = optim self.optimizer = optim
self.scheduler = lr_scheduler.CosineAnnealingLR(self.optimizer, self.train_dataloader.__len__() // 100) self.scheduler = lr_scheduler
def eval_acc(self, dataloader): def eval_acc(self, dataloader):
dist = 0 dist = 0
@ -116,7 +117,8 @@ class RewardModelTrainer(Trainer):
# eval # eval
dist, acc = self.eval_acc(self.eval_dataloader) dist, acc = self.eval_acc(self.eval_dataloader)
if is_rank_0(): if is_rank_0():
log = pd.DataFrame([[step_bar.n, loss.item(), dist, acc]], columns=['step', 'loss', 'dist', 'acc']) log = pd.DataFrame([[step_bar.n, loss.item(), dist, acc]],
columns=['step', 'loss', 'dist', 'acc'])
log.to_csv('log.csv', mode='a', header=False, index=False) log.to_csv('log.csv', mode='a', header=False, index=False)
epoch_bar.update() epoch_bar.update()
step_bar.set_postfix({'dist': dist, 'acc': acc}) step_bar.set_postfix({'dist': dist, 'acc': acc})

View File

@ -1,15 +1,13 @@
import math
import time import time
from typing import List, Optional from typing import List
import torch import torch
import torch.distributed as dist import torch.distributed as dist
import wandb import wandb
from torch.optim import Optimizer from torch.optim import Optimizer
from torch.optim.lr_scheduler import _LRScheduler
from torch.utils.data import DataLoader from torch.utils.data import DataLoader
from tqdm import tqdm from tqdm import tqdm
from transformers.tokenization_utils_base import PreTrainedTokenizerBase
from transformers.trainer import get_scheduler
from .base import Trainer from .base import Trainer
from .callbacks import Callback from .callbacks import Callback
@ -38,14 +36,17 @@ class SFTTrainer(Trainer):
model, model,
strategy: Strategy, strategy: Strategy,
optim: Optimizer, optim: Optimizer,
lr_scheduler: _LRScheduler,
train_dataloader: DataLoader, train_dataloader: DataLoader,
eval_dataloader: DataLoader = None, eval_dataloader: DataLoader = None,
max_epochs: int = 2, max_epochs: int = 2,
accumulation_steps: int = 8, accumulation_steps: int = 8,
callbacks: List[Callback] = [], callbacks: List[Callback] = [],
) -> None: ) -> None:
if accumulation_steps > 1 and isinstance(strategy, ColossalAIStrategy) and strategy.stage == 3: if accumulation_steps > 1 and isinstance(strategy, ColossalAIStrategy):
raise ValueError("Accumulation steps are not supported in stage 3 of ColossalAI") from colossalai.booster.plugin import GeminiPlugin
assert not isinstance(strategy.plugin, GeminiPlugin), \
"Accumulation steps are not supported in stage 3 of ColossalAI"
super().__init__(strategy, max_epochs, callbacks=callbacks) super().__init__(strategy, max_epochs, callbacks=callbacks)
self.train_dataloader = train_dataloader self.train_dataloader = train_dataloader
self.eval_dataloader = eval_dataloader self.eval_dataloader = eval_dataloader
@ -53,13 +54,8 @@ class SFTTrainer(Trainer):
self.optimizer = optim self.optimizer = optim
self.accumulation_steps = accumulation_steps self.accumulation_steps = accumulation_steps
num_update_steps_per_epoch = len(train_dataloader) // self.accumulation_steps
max_steps = math.ceil(self.max_epochs * num_update_steps_per_epoch)
self.scheduler = get_scheduler("cosine", self.scheduler = lr_scheduler
self.optimizer,
num_warmup_steps=math.ceil(max_steps * 0.03),
num_training_steps=max_steps)
def fit(self, logger, use_wandb: bool = False): def fit(self, logger, use_wandb: bool = False):
if use_wandb: if use_wandb:

View File

@ -1,6 +1,6 @@
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from contextlib import nullcontext from contextlib import nullcontext
from typing import Any, List, Optional, Tuple, Union from typing import Callable, Dict, List, Optional, Tuple, Union
import torch import torch
import torch.nn as nn import torch.nn as nn
@ -9,10 +9,12 @@ from torch.optim import Optimizer
from torch.utils.data import DataLoader from torch.utils.data import DataLoader
from transformers.tokenization_utils_base import PreTrainedTokenizerBase from transformers.tokenization_utils_base import PreTrainedTokenizerBase
from colossalai.booster import Booster
from colossalai.booster.plugin import Plugin
from .sampler import DistributedSampler from .sampler import DistributedSampler
ModelOptimPair = Tuple[nn.Module, Optimizer] _BoostArgSpec = Union[nn.Module, Tuple[nn.Module, Optimizer], Dict]
ModelOrModelOptimPair = Union[nn.Module, ModelOptimPair]
class Strategy(ABC): class Strategy(ABC):
@ -20,30 +22,28 @@ class Strategy(ABC):
Base class for training strategies. Base class for training strategies.
""" """
def __init__(self) -> None: def __init__(self, plugin_initializer: Callable[..., Optional[Plugin]] = lambda: None) -> None:
super().__init__() super().__init__()
# NOTE: dist must be initialized before Booster
self.setup_distributed() self.setup_distributed()
self.plugin = plugin_initializer()
self.booster = Booster(plugin=self.plugin)
self._post_init()
@abstractmethod @abstractmethod
def _post_init(self) -> None:
pass
def backward(self, loss: torch.Tensor, model: nn.Module, optimizer: Optimizer, **kwargs) -> None: def backward(self, loss: torch.Tensor, model: nn.Module, optimizer: Optimizer, **kwargs) -> None:
pass self.booster.backward(loss, optimizer)
@abstractmethod
def optimizer_step(self, optimizer: Optimizer, **kwargs) -> None: def optimizer_step(self, optimizer: Optimizer, **kwargs) -> None:
pass optimizer.step()
@abstractmethod @abstractmethod
def setup_distributed(self) -> None: def setup_distributed(self) -> None:
pass pass
@abstractmethod
def setup_model(self, model: nn.Module) -> nn.Module:
pass
@abstractmethod
def setup_optimizer(self, optimizer: Optimizer, model: nn.Module) -> Optimizer:
pass
@abstractmethod @abstractmethod
def setup_dataloader(self, replay_buffer: ReplayBuffer, pin_memory: bool = False) -> DataLoader: def setup_dataloader(self, replay_buffer: ReplayBuffer, pin_memory: bool = False) -> DataLoader:
pass pass
@ -51,12 +51,13 @@ class Strategy(ABC):
def model_init_context(self): def model_init_context(self):
return nullcontext() return nullcontext()
def prepare( def prepare(self, *boost_args: _BoostArgSpec) -> Union[List[_BoostArgSpec], _BoostArgSpec]:
self, *models_or_model_optim_pairs: ModelOrModelOptimPair """Prepare [model | (model, optimizer) | Dict] based on each strategy.
) -> Union[List[ModelOrModelOptimPair], ModelOrModelOptimPair]: NOTE: the keys of Dict must be a subset of `self.booster.boost`'s arguments.
"""Prepare models or model-optimizer-pairs based on each strategy.
Example:: Example::
>>> # e.g., include lr_scheduler
>>> result_dict = strategy.prepare(dict(model=model, lr_scheduler=lr_scheduler))
>>> # when fine-tuning actor and critic >>> # when fine-tuning actor and critic
>>> (actor, actor_optim), (critic, critic_optim), reward_model, initial_model = strategy.prepare((actor, actor_optim), (critic, critic_optim), reward_model, initial_model) >>> (actor, actor_optim), (critic, critic_optim), reward_model, initial_model = strategy.prepare((actor, actor_optim), (critic, critic_optim), reward_model, initial_model)
>>> # or when training reward model >>> # or when training reward model
@ -65,25 +66,39 @@ class Strategy(ABC):
>>> actor, critic = strategy.prepare(actor, critic) >>> actor, critic = strategy.prepare(actor, critic)
Returns: Returns:
Union[List[ModelOrModelOptimPair], ModelOrModelOptimPair]: Models or model-optimizer-pairs in the original order. Union[List[_BoostArgSpec], _BoostArgSpec]: [model | (model, optimizer) | Dict] in the original order.
""" """
rets = [] rets = []
for arg in models_or_model_optim_pairs: for arg in boost_args:
if isinstance(arg, tuple): if isinstance(arg, nn.Module):
assert len(arg) == 2, f'Expect (model, optimizer) pair, got a tuple with size "{len(arg)}"' model, *_ = self.booster.boost(arg)
model, optimizer = arg rets.append(model)
model = self.setup_model(model) elif isinstance(arg, tuple):
optimizer = self.setup_optimizer(optimizer, model) try:
model, optimizer = arg
except ValueError:
raise RuntimeError(f'Expect (model, optimizer) pair, got a tuple with size "{len(arg)}"')
model, optimizer, *_ = self.booster.boost(model=model,
optimizer=optimizer)
rets.append((model, optimizer)) rets.append((model, optimizer))
elif isinstance(arg, nn.Module): elif isinstance(arg, Dict):
rets.append(self.setup_model(model)) model, optimizer, criterion, dataloader, lr_scheduler = self.booster.boost(**arg)
boost_result = dict(model=model,
optimizer=optimizer,
criterion=criterion,
dataloader=dataloader,
lr_scheduler=lr_scheduler)
# remove None values
boost_result = {
key: value
for key, value in boost_result.items() if value is not None
}
rets.append(boost_result)
else: else:
raise RuntimeError(f'Expect model or (model, optimizer) pair, got {type(arg)}') raise RuntimeError(f'Type {type(arg)} is not supported')
if len(rets) == 1: return rets[0] if len(rets) == 1 else rets
return rets[0]
return rets
@staticmethod @staticmethod
def unwrap_model(model: nn.Module) -> nn.Module: def unwrap_model(model: nn.Module) -> nn.Module:
@ -97,23 +112,30 @@ class Strategy(ABC):
""" """
return model return model
@abstractmethod def save_model(self,
def save_model(self, model: nn.Module, path: str, only_rank0: bool = True) -> None: model: nn.Module,
pass path: str,
only_rank0: bool = True,
**kwargs
) -> None:
self.booster.save_model(model, path, shard=not only_rank0, **kwargs)
@abstractmethod def load_model(self, model: nn.Module, path: str, strict: bool = True) -> None:
def load_model(self, model: nn.Module, path: str, map_location: Any = None, strict: bool = True) -> None: self.booster.load_model(model, path, strict)
pass
@abstractmethod def save_optimizer(self,
def save_optimizer(self, optimizer: Optimizer, path: str, only_rank0: bool = False) -> None: optimizer: Optimizer,
pass path: str,
only_rank0: bool = False,
**kwargs
) -> None:
self.booster.save_optimizer(optimizer, path, shard=not only_rank0, **kwargs)
@abstractmethod def load_optimizer(self, optimizer: Optimizer, path: str) -> None:
def load_optimizer(self, optimizer: Optimizer, path: str, map_location: Any = None) -> None: self.booster.load_optimizer(optimizer, path)
pass
def setup_sampler(self, dataset) -> DistributedSampler: def setup_sampler(self, dataset) -> DistributedSampler:
# FIXME(cwher): this is only invoked in train_on_ray, not tested after adapt Boost API.
return DistributedSampler(dataset, 1, 0) return DistributedSampler(dataset, 1, 0)
@abstractmethod @abstractmethod

View File

@ -1,24 +1,23 @@
import functools
import warnings import warnings
from typing import Optional, Union from typing import Optional
import torch import torch
import torch.distributed as dist import torch.distributed as dist
import torch.nn as nn import torch.nn as nn
import torch.optim as optim
from torch.optim import Optimizer
from transformers.tokenization_utils_base import PreTrainedTokenizerBase from transformers.tokenization_utils_base import PreTrainedTokenizerBase
import colossalai import colossalai
from colossalai.logging import get_dist_logger from colossalai.booster.plugin import GeminiPlugin, LowLevelZeroPlugin
from colossalai.nn.optimizer import CPUAdam, HybridAdam from colossalai.booster.plugin.gemini_plugin import GeminiModel
from colossalai.booster.plugin.low_level_zero_plugin import LowLevelZeroModel
from colossalai.tensor import ProcessGroup, ShardSpec from colossalai.tensor import ProcessGroup, ShardSpec
from colossalai.utils import get_current_device from colossalai.utils import get_current_device
from colossalai.zero import ColoInitContext, ZeroDDP, zero_model_wrapper, zero_optim_wrapper from colossalai.zero import ColoInitContext
from colossalai.zero.gemini.gemini_ddp import GeminiDDP
from .ddp import DDPStrategy from .ddp import DDPStrategy
logger = get_dist_logger(__name__)
class ColossalAIStrategy(DDPStrategy): class ColossalAIStrategy(DDPStrategy):
""" """
@ -62,7 +61,6 @@ class ColossalAIStrategy(DDPStrategy):
placement_policy: str = 'cuda', placement_policy: str = 'cuda',
pin_memory: bool = True, # only for stage 3 pin_memory: bool = True, # only for stage 3
force_outputs_fp32: bool = False, # only for stage 3 force_outputs_fp32: bool = False, # only for stage 3
scatter_after_inference: bool = False, # only for stage 3
search_range_mb: int = 32, # only for stage 3 search_range_mb: int = 32, # only for stage 3
hidden_dim: Optional[int] = None, # only for stage 3 hidden_dim: Optional[int] = None, # only for stage 3
min_chunk_size_mb: float = 32, # only for stage 3 min_chunk_size_mb: float = 32, # only for stage 3
@ -78,50 +76,76 @@ class ColossalAIStrategy(DDPStrategy):
max_scale: float = 2**32, max_scale: float = 2**32,
max_norm: float = 0.0, max_norm: float = 0.0,
norm_type: float = 2.0) -> None: norm_type: float = 2.0) -> None:
super().__init__(seed)
assert stage in (1, 2, 3), f'Unsupported stage "{stage}"'
assert placement_policy in ('cpu', 'cuda'), f'Unsupported placement policy "{placement_policy}"' assert placement_policy in ('cpu', 'cuda'), f'Unsupported placement policy "{placement_policy}"'
assert precision in ('fp32', 'fp16'), f'Unsupported precision "{precision}"' assert precision in ('fp32', 'fp16'), f'Unsupported precision "{precision}"'
self.stage = stage
# TODO(ver217): support shard_init when using from_pretrained() # TODO(ver217): support shard_init when using from_pretrained()
if shard_init: if shard_init:
warnings.warn( warnings.warn(
f'Shard init is not supported model.from_pretrained() yet. Please load weights after strategy.prepare()' f'Shard init is not supported model.from_pretrained() yet. '
'Please load weights after strategy.prepare()'
) )
if stage == 3 and precision == 'fp32': if stage == 3 and precision == 'fp32':
warnings.warn(f'Stage 3 only supports fp16. Precision is set to fp16.') warnings.warn(f'Stage 3 only supports fp16. Precision is set to fp16.')
precision = 'fp16' precision = 'fp16'
self.precision = precision self.precision = precision
self.shard_init = shard_init self.shard_init = shard_init
self.gemini_config = dict(device=get_current_device(),
placement_policy=placement_policy, optim_kwargs = dict(
pin_memory=pin_memory, initial_scale=initial_scale,
force_outputs_fp32=force_outputs_fp32, growth_factor=growth_factor,
strict_ddp_mode=shard_init, backoff_factor=backoff_factor,
search_range_mb=search_range_mb, growth_interval=growth_interval,
hidden_dim=hidden_dim, hysteresis=hysteresis,
min_chunk_size_mb=min_chunk_size_mb, min_scale=min_scale,
scatter_after_inference=scatter_after_inference) max_scale=max_scale,
max_norm=max_norm,
norm_type=norm_type
)
# NOTE: dist should be initialized before calling get_current_device()
if stage == 3: if stage == 3:
self.zero_optim_config = dict(gpu_margin_mem_ratio=gpu_margin_mem_ratio) plugin_initializer = lambda: GeminiPlugin(
# gemini_config
device=get_current_device(),
placement_policy=placement_policy,
precision=precision,
pin_memory=pin_memory,
force_outputs_fp32=force_outputs_fp32,
strict_ddp_mode=shard_init,
search_range_mb=search_range_mb,
hidden_dim=hidden_dim,
min_chunk_size_mb=min_chunk_size_mb,
# zero_optim_config
gpu_margin_mem_ratio=gpu_margin_mem_ratio,
# optim_config
**optim_kwargs
)
else: else:
self.zero_optim_config = dict(reduce_bucket_size=reduce_bucket_size, plugin_initializer = lambda: LowLevelZeroPlugin(
overlap_communication=overlap_communication, # zero_config
cpu_offload=(placement_policy == 'cpu')) stage=stage,
self.optim_kwargs = dict(initial_scale=initial_scale, precision=precision,
growth_factor=growth_factor, # zero_optim_config
backoff_factor=backoff_factor, reduce_bucket_size_in_m=reduce_bucket_size,
growth_interval=growth_interval, overlap_communication=overlap_communication,
hysteresis=hysteresis, cpu_offload=(placement_policy == 'cpu'),
min_scale=min_scale, # optim_config
max_scale=max_scale, **optim_kwargs
max_norm=max_norm, )
norm_type=norm_type)
super().__init__(seed, plugin_initializer)
def _post_init(self) -> None:
assert isinstance(self.plugin, (LowLevelZeroPlugin, GeminiPlugin)), \
f'{type(self).__name__}\'s plugin is not initialized properly.'
def setup_distributed(self) -> None: def setup_distributed(self) -> None:
colossalai.launch_from_torch({}, seed=self.seed) colossalai.launch_from_torch({}, seed=self.seed)
def model_init_context(self): def model_init_context(self):
if self.stage == 3: if isinstance(self.plugin, GeminiPlugin):
world_size = dist.get_world_size() world_size = dist.get_world_size()
shard_pg = ProcessGroup(tp_degree=world_size) if self.shard_init else None shard_pg = ProcessGroup(tp_degree=world_size) if self.shard_init else None
default_dist_spec = ShardSpec([-1], [world_size]) if self.shard_init else None default_dist_spec = ShardSpec([-1], [world_size]) if self.shard_init else None
@ -131,61 +155,29 @@ class ColossalAIStrategy(DDPStrategy):
default_dist_spec=default_dist_spec) default_dist_spec=default_dist_spec)
return super().model_init_context() return super().model_init_context()
def setup_model(self, model: nn.Module) -> nn.Module:
model = zero_model_wrapper(model, zero_stage=self.stage, gemini_config=self.gemini_config)
if self.stage != 3 and self.precision == 'fp16':
model = model.half().cuda()
return model
def setup_optimizer(self, optimizer: optim.Optimizer, model: nn.Module) -> optim.Optimizer:
assert isinstance(optimizer, (CPUAdam, HybridAdam)), f'Unsupported optimizer {type(optimizer)}'
return zero_optim_wrapper(model, optimizer, optim_config=self.zero_optim_config, **self.optim_kwargs)
def backward(self, loss: torch.Tensor, model: nn.Module, optimizer: optim.Optimizer, **kwargs) -> None:
optimizer.backward(loss)
def optimizer_step(self, optimizer: optim.Optimizer, **kwargs) -> None:
optimizer.step()
def save_model(self, model: nn.Module, path: str, only_rank0: bool = True) -> None:
if only_rank0 and dist.get_rank() != 0 and self.stage != 3:
return
if self.stage == 3:
assert isinstance(model, ZeroDDP)
# for stage 3, state_dict() method should be called on every rank
state_dict = model.state_dict(only_rank_0=only_rank0)
else:
# only_rank0 is false or rank == 0
state_dict = model.state_dict()
if only_rank0 and dist.get_rank() != 0:
return
torch.save(state_dict, path)
def save_optimizer(self, optimizer: Optimizer, path: str, only_rank0: bool = False) -> None:
if only_rank0:
raise RuntimeError(
f'Optimizer states are sharded when using ColossalAIStrategy. Only rank0 is not supported.')
torch.save(optimizer.state_dict(), path)
def unwrap_model(self, model: nn.Module) -> nn.Module: def unwrap_model(self, model: nn.Module) -> nn.Module:
if self.stage == 3: if isinstance(self.plugin, GeminiPlugin):
assert isinstance(model, ZeroDDP) assert isinstance(model, GeminiModel)
ddp_model = model.unwrap()
assert isinstance(ddp_model, GeminiDDP)
return ddp_model.module
elif isinstance(self.plugin, LowLevelZeroPlugin):
assert isinstance(model, LowLevelZeroModel)
return model.module return model.module
return model else:
raise RuntimeError(f'Unsupported plugin {type(self.plugin)}')
def save_pretrained(self, def save_pretrained(self,
model: nn.Module, model: nn.Module,
path: str, path: str,
only_rank0: bool = True, only_rank0: bool = True,
tokenizer: Optional[PreTrainedTokenizerBase] = None) -> None: tokenizer: Optional[PreTrainedTokenizerBase] = None) -> None:
if self.stage == 3: if isinstance(self.plugin, GeminiPlugin):
raise RuntimeError('ColossalAI strategy with stage-3 does not support save_pretrained() now') raise RuntimeError('ColossalAI strategy with stage-3 does not support save_pretrained() now')
super().save_pretrained(model, path, only_rank0, tokenizer) super().save_pretrained(model, path, only_rank0, tokenizer)
def get_model_state_dict_shard(self, model: nn.Module, **config): def get_model_state_dict_shard(self, model: nn.Module, **config):
if self.stage != 3: if not isinstance(self.plugin, GeminiPlugin):
yield from super().get_model_state_dict_shard(model, **config) yield from super().get_model_state_dict_shard(model, **config)
else: else:
# unwrapped_model = self._unwrap_model(model) # unwrapped_model = self._unwrap_model(model)
@ -193,5 +185,5 @@ class ColossalAIStrategy(DDPStrategy):
# if isinstance(module, LoraLinear): # if isinstance(module, LoraLinear):
# module.merge_weights = True # module.merge_weights = True
# module.eval() # module.eval()
assert isinstance(model, ZeroDDP) assert isinstance(model, LowLevelZeroModel)
yield from model.state_dict_shard(max_shard_size=1024, only_rank_0=False) yield from model.state_dict_shard(max_shard_size=1024, only_rank_0=False)

View File

@ -1,17 +1,18 @@
import os
import random import random
from typing import Optional from typing import Callable, Optional
import numpy as np import numpy as np
import torch import torch
import torch.distributed as dist import torch.distributed as dist
import torch.nn as nn import torch.nn as nn
from coati.replay_buffer import ReplayBuffer from coati.replay_buffer import ReplayBuffer
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.optim import Optimizer from torch.optim import Optimizer
from torch.utils.data import DataLoader from torch.utils.data import DataLoader
from transformers.tokenization_utils_base import PreTrainedTokenizerBase from transformers.tokenization_utils_base import PreTrainedTokenizerBase
from colossalai.booster.plugin import TorchDDPPlugin
from colossalai.booster.plugin.torch_ddp_plugin import TorchDDPModel
from .naive import NaiveStrategy from .naive import NaiveStrategy
from .sampler import DistributedSampler from .sampler import DistributedSampler
@ -21,9 +22,16 @@ class DDPStrategy(NaiveStrategy):
Strategy for distributed training using torch.distributed. Strategy for distributed training using torch.distributed.
""" """
def __init__(self, seed: int = 42) -> None: def __init__(self,
seed: int = 42,
plugin_initializer: Callable = TorchDDPPlugin
) -> None:
self.seed = seed self.seed = seed
super().__init__() super().__init__(plugin_initializer)
def _post_init(self) -> None:
assert isinstance(self.plugin, TorchDDPPlugin), \
f'{type(self).__name__}\'s plugin is not initialized properly.'
def setup_distributed(self) -> None: def setup_distributed(self) -> None:
self._try_init_dist(force=True) self._try_init_dist(force=True)
@ -34,43 +42,24 @@ class DDPStrategy(NaiveStrategy):
np.random.seed(seed) np.random.seed(seed)
torch.manual_seed(seed) torch.manual_seed(seed)
def setup_model(self, model: nn.Module) -> nn.Module: def backward(self, loss: torch.Tensor, model: nn.Module, optimizer: Optimizer, **kwargs) -> None:
device = torch.cuda.current_device() self.booster.backward(loss, optimizer)
return DDP(model, device_ids=[device])
def setup_dataloader(self, replay_buffer: ReplayBuffer, pin_memory: bool = False) -> DataLoader: def setup_dataloader(self, replay_buffer: ReplayBuffer, pin_memory: bool = False) -> DataLoader:
# DDP only mode, replay buffers on each rank are different. return self.plugin.prepare_dataloader(replay_buffer,
# sampler = DistributedSampler(replay_buffer, batch_size=replay_buffer.sample_batch_size,
# num_replicas=dist.get_world_size(), shuffle=True,
# rank=dist.get_rank(), drop_last=True,
# shuffle=True, pin_memory=pin_memory,
# seed=self.seed, collate_fn=replay_buffer.collate_fn)
# drop_last=True)
return DataLoader(
replay_buffer,
batch_size=replay_buffer.sample_batch_size,
# sampler=sampler,
shuffle=True,
drop_last=True,
pin_memory=pin_memory,
collate_fn=replay_buffer.collate_fn)
def save_model(self, model: nn.Module, path: str, only_rank0: bool = True) -> None:
if only_rank0 and dist.get_rank() != 0:
return
super().save_model(model, path, only_rank0)
def save_optimizer(self, optimizer: Optimizer, path: str, only_rank0: bool = False) -> None:
if only_rank0 and dist.get_rank() != 0:
return
super().save_optimizer(optimizer, path, only_rank0)
def setup_sampler(self, dataset) -> DistributedSampler: def setup_sampler(self, dataset) -> DistributedSampler:
# FIXME(cwher): this is only invoked in train_on_ray, not tested after adapt Boost API.
return DistributedSampler(dataset, dist.get_world_size(), dist.get_rank()) return DistributedSampler(dataset, dist.get_world_size(), dist.get_rank())
def unwrap_model(self, model: nn.Module) -> nn.Module: def unwrap_model(self, model: nn.Module) -> nn.Module:
assert isinstance(model, DDP) assert isinstance(model, TorchDDPModel), "model is not wrapped by TorchDDPModel."
return model.module return model.unwrap()
def save_pretrained(self, def save_pretrained(self,
model: nn.Module, model: nn.Module,

View File

@ -1,16 +1,10 @@
import os import os
import sys
from collections import OrderedDict from collections import OrderedDict
from typing import Any, Dict, Optional from typing import Optional
import torch import torch
import torch.distributed as dist import torch.distributed as dist
import torch.nn as nn import torch.nn as nn
import torch.optim as optim
from coati.models.base import get_base_model
from coati.replay_buffer import ReplayBuffer
from coati.models.base import RewardModel
from coati.models.lora import LoraLinear
from coati.replay_buffer import ReplayBuffer from coati.replay_buffer import ReplayBuffer
from torch.optim import Optimizer from torch.optim import Optimizer
from torch.utils.data import DataLoader from torch.utils.data import DataLoader
@ -34,20 +28,18 @@ class NaiveStrategy(Strategy):
Strategy for single GPU. No parallelism is used. Strategy for single GPU. No parallelism is used.
""" """
def backward(self, loss: torch.Tensor, model: nn.Module, optimizer: optim.Optimizer, **kwargs) -> None: def _post_init(self) -> None:
loss.backward() assert self.plugin is None, \
f'{type(self).__name__}\'s plugin is not initialized properly.'
def optimizer_step(self, optimizer: optim.Optimizer, **kwargs) -> None:
optimizer.step()
def setup_distributed(self) -> None: def setup_distributed(self) -> None:
self._try_init_dist(force=False) self._try_init_dist(force=False)
def setup_model(self, model: nn.Module) -> nn.Module: def backward(self, loss: torch.Tensor, model: nn.Module, optimizer: Optimizer, **kwargs) -> None:
return model # HACK: self.booster.backward(loss, optimizer) can't work if plugin is None,
# it would run `optimizer.backward(loss)`, which is not compatible with torch.optim.Optimizer
def setup_optimizer(self, optimizer: optim.Optimizer, model: nn.Module) -> optim.Optimizer: assert self.plugin is None, "DO NOT call this method if plugin is not None"
return optimizer loss.backward()
def setup_dataloader(self, replay_buffer: ReplayBuffer, pin_memory: bool = False) -> DataLoader: def setup_dataloader(self, replay_buffer: ReplayBuffer, pin_memory: bool = False) -> DataLoader:
return DataLoader(replay_buffer, return DataLoader(replay_buffer,
@ -57,22 +49,6 @@ class NaiveStrategy(Strategy):
pin_memory=pin_memory, pin_memory=pin_memory,
collate_fn=replay_buffer.collate_fn) collate_fn=replay_buffer.collate_fn)
def save_model(self, model: nn.Module, path: str, only_rank0: bool = True) -> None:
state_dict = model.state_dict()
torch.save(state_dict, path)
def load_model(self, model: nn.Module, path: str, map_location: Any = None, strict: bool = True) -> None:
unwrapped_model = self.unwrap_model(model)
state_dict = torch.load(path, map_location=map_location)
unwrapped_model.load_state_dict(state_dict, strict=strict)
def save_optimizer(self, optimizer: Optimizer, path: str, only_rank0: bool = False) -> None:
torch.save(optimizer.state_dict(), path)
def load_optimizer(self, optimizer: Optimizer, path: str, map_location: Any = None) -> None:
state_dict = torch.load(path, map_location=map_location)
optimizer.load_state_dict(state_dict)
def save_pretrained(self, def save_pretrained(self,
model: nn.Module, model: nn.Module,
path: str, path: str,

View File

@ -1,5 +1,22 @@
#!/usr/bin/env bash #!/usr/bin/env bash
set_n_least_used_CUDA_VISIBLE_DEVICES() {
local n=${1:-"9999"}
echo "GPU Memory Usage:"
local FIRST_N_GPU_IDS=$(nvidia-smi --query-gpu=memory.used --format=csv |
tail -n +2 |
nl -v 0 |
tee /dev/tty |
sort -g -k 2 |
awk '{print $1}' |
head -n $n)
export CUDA_VISIBLE_DEVICES=$(echo $FIRST_N_GPU_IDS | sed 's/ /,/g')
echo "Now CUDA_VISIBLE_DEVICES is set to:"
echo "CUDA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES"
}
set_n_least_used_CUDA_VISIBLE_DEVICES 4
set -xue set -xue
if [ -z "$SFT_DATASET" ]; then if [ -z "$SFT_DATASET" ]; then
@ -26,109 +43,137 @@ pip install -r ${BASE}/requirements.txt
wandb init -m offline wandb init -m offline
# FIXME: This is a hack to skip tests that are not working (tested at commit b3ab7fbabf)
# - gpt2-ddp: RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation
# - llama-*: Repository Not Found for url: https://huggingface.co/{...}/resolve/main/tokenizer.model.
# - roberta-*: RuntimeError: CUDA error: CUBLAS_STATUS_NOT_INITIALIZED when calling `cublasCreate(handle)`
SKIPPED_TESTS=(
"gpt2-ddp"
"llama-naive" "llama-ddp" "llama-colossalai_gemini" "llama-colossalai_zero2"
"roberta-naive" "roberta-ddp" "roberta-colossalai_gemini" "roberta-colossalai_zero2"
)
# These tests are quick and do not have any dependencies
for model in 'gpt2' 'bloom' 'opt' 'llama' 'roberta'; do
for strategy in 'naive' 'ddp' 'colossalai_gemini' 'colossalai_zero2'; do
if [[ " ${SKIPPED_TESTS[*]} " =~ " ${model}-${strategy} " ]]; then
echo "[Test]: Skipped $model-$strategy"
continue
fi
torchrun --standalone --nproc_per_node=2 ${BASE}/train_prompts.py \
--prompt_dataset $PROMPT_PATH --pretrain_dataset $PRETRAIN_DATASET \
--strategy $strategy --model $model \
--num_episodes 1 --max_timesteps 2 \
--update_timesteps 2 --max_epochs 1 --train_batch_size 2
done
done
# train sft # train sft
torchrun --standalone --nproc_per_node=4 ${BASE}/train_sft.py --pretrain 'bigscience/bloom-560m' \ torchrun --standalone --nproc_per_node=4 ${BASE}/train_sft.py --pretrain 'bigscience/bloom-560m' \
--model 'bloom' --strategy colossalai_zero2 --lora_rank 4\ --model 'bloom' --strategy colossalai_zero2 --lora_rank 4 \
--dataset $SFT_DATASET --max_datasets_size 512 --max_epochs 1 \ --dataset $SFT_DATASET --max_datasets_size 512 --max_epochs 1 \
--save_path ${BASE}/output --save_path ${BASE}/output
rm -rf ${BASE}/output rm -rf ${BASE}/output
torchrun --standalone --nproc_per_node=4 ${BASE}/train_sft.py --pretrain 'gpt2' \ torchrun --standalone --nproc_per_node=4 ${BASE}/train_sft.py --pretrain 'gpt2' \
--model 'gpt2' --strategy colossalai_zero2 \ --model 'gpt2' --strategy colossalai_zero2 \
--dataset $SFT_DATASET --max_datasets_size 512 --max_epochs 1 \ --dataset $SFT_DATASET --max_datasets_size 512 --max_epochs 1 \
--save_path ${BASE}/output --save_path ${BASE}/output
rm -rf ${BASE}/output rm -rf ${BASE}/output
torchrun --standalone --nproc_per_node=4 ${BASE}/train_sft.py --pretrain 'facebook/opt-350m' \ torchrun --standalone --nproc_per_node=4 ${BASE}/train_sft.py --pretrain 'facebook/opt-350m' \
--model 'opt' --strategy colossalai_zero2 --lora_rank 4\ --model 'opt' --strategy colossalai_zero2 --lora_rank 4 \
--dataset $SFT_DATASET --max_datasets_size 512 --max_epochs 1 \ --dataset $SFT_DATASET --max_datasets_size 512 --max_epochs 1 \
--save_path ${BASE}/output --save_path ${BASE}/output
rm -rf ${BASE}/output rm -rf ${BASE}/output
torchrun --standalone --nproc_per_node=4 ${BASE}/train_sft.py --pretrain 'gpt2' \ torchrun --standalone --nproc_per_node=4 ${BASE}/train_sft.py --pretrain 'gpt2' \
--model 'gpt2' --strategy ddp --lora_rank 4\ --model 'gpt2' --strategy ddp --lora_rank 4 \
--dataset $SFT_DATASET --max_datasets_size 512 --max_epochs 1 \ --dataset $SFT_DATASET --max_datasets_size 512 --max_epochs 1 \
--save_path ${BASE}/output --save_path ${BASE}/output
#torchrun --standalone --nproc_per_node=4 ${BASE}/train_sft.py --pretrain 'facebook/opt-350m' \ # torchrun --standalone --nproc_per_node=4 ${BASE}/train_sft.py --pretrain 'facebook/opt-350m' \
# --model 'opt' --strategy naive \ # --model 'opt' --strategy naive \
# --dataset $SFT_DATASET --max_datasets_size 512 --max_epochs 1 \ # --dataset $SFT_DATASET --max_datasets_size 512 --max_epochs 1 \
# --save_path ${BASE}/output # --save_path ${BASE}/output
rm -rf ${BASE}/output rm -rf ${BASE}/output
# train rm # train rm
torchrun --standalone --nproc_per_node=2 ${BASE}/train_reward_model.py \ torchrun --standalone --nproc_per_node=2 ${BASE}/train_reward_model.py \
--pretrain 'facebook/opt-350m' --model 'opt' \ --pretrain 'facebook/opt-350m' --model 'opt' \
--strategy colossalai_zero2 --loss_fn 'log_sig'\ --strategy colossalai_zero2 --loss_fn 'log_sig' \
--dataset 'Anthropic/hh-rlhf' --subset 'harmless-base' \ --dataset 'Anthropic/hh-rlhf' --subset 'harmless-base' \
--test True --lora_rank 0 \ --test True --lora_rank 0 \
--save_path ${BASE}/rm_ckpt_opt.pt --save_path ${BASE}/rm_ckpt_opt.pt
torchrun --standalone --nproc_per_node=2 ${BASE}/train_reward_model.py \ torchrun --standalone --nproc_per_node=2 ${BASE}/train_reward_model.py \
--pretrain 'gpt2' --model 'gpt2' \ --pretrain 'gpt2' --model 'gpt2' \
--strategy colossalai_zero2 --loss_fn 'log_exp' \ --strategy colossalai_zero2 --loss_fn 'log_exp' \
--dataset 'Dahoas/rm-static' \ --dataset 'Dahoas/rm-static' \
--test True --lora_rank 0 \ --test True --lora_rank 0 \
--save_path ${BASE}/rm_ckpt_gpt.pt --save_path ${BASE}/rm_ckpt_gpt.pt
torchrun --standalone --nproc_per_node=2 ${BASE}/train_reward_model.py \ torchrun --standalone --nproc_per_node=2 ${BASE}/train_reward_model.py \
--pretrain 'gpt2' --model 'gpt2' \ --pretrain 'gpt2' --model 'gpt2' \
--strategy ddp --loss_fn 'log_exp' \ --strategy ddp --loss_fn 'log_exp' \
--dataset 'Dahoas/rm-static' \ --dataset 'Dahoas/rm-static' \
--test True --lora_rank 4 \ --test True --lora_rank 4 \
--save_path ${BASE}/rm_ckpt.pt --save_path ${BASE}/rm_ckpt.pt
rm -rf ${BASE}/rm_ckpt.pt rm -rf ${BASE}/rm_ckpt.pt
torchrun --standalone --nproc_per_node=2 ${BASE}/train_reward_model.py \ torchrun --standalone --nproc_per_node=2 ${BASE}/train_reward_model.py \
--pretrain 'bigscience/bloom-560m' --model 'bloom' \ --pretrain 'bigscience/bloom-560m' --model 'bloom' \
--strategy colossalai_zero2 --loss_fn 'log_sig' \ --strategy colossalai_zero2 --loss_fn 'log_sig' \
--dataset 'Anthropic/hh-rlhf' --subset 'harmless-base' \ --dataset 'Anthropic/hh-rlhf' --subset 'harmless-base' \
--test True --lora_rank 4 \ --test True --lora_rank 4 \
--save_path ${BASE}/rm_ckpt.pt --save_path ${BASE}/rm_ckpt.pt
rm -rf ${BASE}/rm_ckpt.pt rm -rf ${BASE}/rm_ckpt.pt
torchrun --standalone --nproc_per_node=2 ${BASE}/train_reward_model.py \ torchrun --standalone --nproc_per_node=2 ${BASE}/train_reward_model.py \
--pretrain 'microsoft/deberta-v3-large' --model 'deberta' \ --pretrain 'microsoft/deberta-v3-large' --model 'deberta' \
--strategy colossalai_zero2 --loss_fn 'log_sig' \ --strategy colossalai_zero2 --loss_fn 'log_sig' \
--dataset 'Anthropic/hh-rlhf' --subset 'harmless-base' \ --dataset 'Anthropic/hh-rlhf' --subset 'harmless-base' \
--test True --lora_rank 4 \ --test True --lora_rank 4 \
--save_path ${BASE}/rm_ckpt.pt --save_path ${BASE}/rm_ckpt.pt
rm -rf ${BASE}/rm_ckpt.pt rm -rf ${BASE}/rm_ckpt.pt
torchrun --standalone --nproc_per_node=2 ${BASE}/train_reward_model.py \ torchrun --standalone --nproc_per_node=2 ${BASE}/train_reward_model.py \
--pretrain 'roberta-base' --model 'roberta' \ --pretrain 'roberta-base' --model 'roberta' \
--strategy colossalai_zero2 --loss_fn 'log_exp'\ --strategy colossalai_zero2 --loss_fn 'log_exp' \
--dataset 'Anthropic/hh-rlhf' --subset 'harmless-base'\ --dataset 'Anthropic/hh-rlhf' --subset 'harmless-base' \
--test True --lora_rank 4 \ --test True --lora_rank 4 \
--save_path ${BASE}/rm_ckpt.pt --save_path ${BASE}/rm_ckpt.pt
rm -rf ${BASE}/rm_ckpt.pt rm -rf ${BASE}/rm_ckpt.pt
torchrun --standalone --nproc_per_node=2 ${BASE}/train_prompts.py --prompt_dataset $PROMPT_PATH --pretrain_dataset $PRETRAIN_DATASET \ torchrun --standalone --nproc_per_node=2 ${BASE}/train_prompts.py \
--strategy colossalai_zero2 --num_episodes 1 --max_timesteps 2 \ --prompt_dataset $PROMPT_PATH --pretrain_dataset $PRETRAIN_DATASET \
--update_timesteps 2 --max_epochs 1 --train_batch_size 2 \ --strategy colossalai_zero2 --num_episodes 1 --max_timesteps 2 \
--pretrain 'facebook/opt-350m' --model opt \ --update_timesteps 2 --max_epochs 1 --train_batch_size 2 \
--rm_pretrain 'facebook/opt-350m' \ --pretrain 'facebook/opt-350m' --model opt \
--rm_path ${BASE}/rm_ckpt_opt.pt \ --rm_pretrain 'facebook/opt-350m' \
--save_path ${BASE}/actor_checkpoint_prompts.pt --rm_path ${BASE}/rm_ckpt_opt.pt \
--save_path ${BASE}/actor_checkpoint_prompts.pt
rm -rf ${BASE}/rm_ckpt_opt.pt rm -rf ${BASE}/rm_ckpt_opt.pt
torchrun --standalone --nproc_per_node=2 ${BASE}/train_prompts.py --prompt_dataset $PROMPT_PATH --pretrain_dataset $PRETRAIN_DATASET \ torchrun --standalone --nproc_per_node=2 ${BASE}/train_prompts.py \
--strategy colossalai_zero2 --num_episodes 1 --max_timesteps 2 \ --prompt_dataset $PROMPT_PATH --pretrain_dataset $PRETRAIN_DATASET \
--update_timesteps 2 --max_epochs 1 --train_batch_size 2 \ --strategy colossalai_zero2 --num_episodes 1 --max_timesteps 2 \
--pretrain 'gpt2' --model gpt2 \ --update_timesteps 2 --max_epochs 1 --train_batch_size 2 \
--rm_pretrain 'gpt2' \ --pretrain 'gpt2' --model gpt2 \
--rm_path ${BASE}/rm_ckpt_gpt.pt \ --rm_pretrain 'gpt2' \
--save_path ${BASE}/actor_checkpoint_prompts.pt --rm_path ${BASE}/rm_ckpt_gpt.pt \
--save_path ${BASE}/actor_checkpoint_prompts.pt
torchrun --standalone --nproc_per_node=2 ${BASE}/train_prompts.py --prompt_dataset $PROMPT_PATH --pretrain_dataset $PRETRAIN_DATASET \ torchrun --standalone --nproc_per_node=2 ${BASE}/train_prompts.py \
--strategy colossalai_gemini --num_episodes 1 --max_timesteps 2 \ --prompt_dataset $PROMPT_PATH --pretrain_dataset $PRETRAIN_DATASET \
--update_timesteps 2 --max_epochs 1 --train_batch_size 2 \ --strategy colossalai_gemini --num_episodes 1 --max_timesteps 2 \
--pretrain 'gpt2' --model gpt2 \ --update_timesteps 2 --max_epochs 1 --train_batch_size 2 \
--rm_pretrain 'gpt2' \ --pretrain 'gpt2' --model gpt2 \
--rm_path ${BASE}/rm_ckpt_gpt.pt \ --rm_pretrain 'gpt2' \
--save_path ${BASE}/actor_checkpoint_prompts.pt --rm_path ${BASE}/rm_ckpt_gpt.pt \
--save_path ${BASE}/actor_checkpoint_prompts.pt
rm -rf ${BASE}/rm_ckpt_gpt.pt rm -rf ${BASE}/rm_ckpt_gpt.pt
rm -rf ${BASE}/actor_checkpoint_prompts.pt rm -rf ${BASE}/actor_checkpoint_prompts.pt

View File

@ -1,6 +1,5 @@
import argparse import argparse
import pandas as pd
import torch import torch
import torch.distributed as dist import torch.distributed as dist
from coati.dataset import DataCollatorForSupervisedDataset, PromptDataset, SupervisedDataset from coati.dataset import DataCollatorForSupervisedDataset, PromptDataset, SupervisedDataset
@ -51,7 +50,7 @@ def main(args):
else: else:
raise ValueError(f'Unsupported actor model "{args.model}"') raise ValueError(f'Unsupported actor model "{args.model}"')
if args.rm_model == None: if args.rm_model is None:
rm_model_name = args.model rm_model_name = args.model
else: else:
rm_model_name = args.rm_model rm_model_name = args.rm_model
@ -163,7 +162,9 @@ def main(args):
batch_size=args.ptx_batch_size, batch_size=args.ptx_batch_size,
collate_fn=data_collator) collate_fn=data_collator)
(actor, actor_optim), (critic, critic_optim) = strategy.prepare((actor, actor_optim), (critic, critic_optim)) # NOTE: For small models like opt-1.3b, reward model and initial model are not required to be parallelized.
(actor, actor_optim), (critic, critic_optim), reward_model, initial_model = \
strategy.prepare((actor, actor_optim), (critic, critic_optim), reward_model, initial_model)
# configure trainer # configure trainer
trainer = PPOTrainer( trainer = PPOTrainer(
@ -185,6 +186,7 @@ def main(args):
top_k=50, top_k=50,
pad_token_id=tokenizer.pad_token_id, pad_token_id=tokenizer.pad_token_id,
eos_token_id=tokenizer.eos_token_id, eos_token_id=tokenizer.eos_token_id,
offload_inference_models=args.strategy != 'colossalai_gemini'
) )
trainer.fit(prompt_dataloader=prompt_dataloader, trainer.fit(prompt_dataloader=prompt_dataloader,

View File

@ -18,6 +18,7 @@ from coati.trainer.strategies import ColossalAIStrategy, DDPStrategy, NaiveStrat
from coati.utils import prepare_llama_tokenizer_and_embedding from coati.utils import prepare_llama_tokenizer_and_embedding
from datasets import load_dataset from datasets import load_dataset
from torch.optim import Adam from torch.optim import Adam
from torch.optim.lr_scheduler import CosineAnnealingLR
from torch.utils.data import DataLoader from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler from torch.utils.data.distributed import DistributedSampler
from transformers import AutoTokenizer, BloomTokenizerFast, DebertaV2Tokenizer, LlamaTokenizer, RobertaTokenizer from transformers import AutoTokenizer, BloomTokenizerFast, DebertaV2Tokenizer, LlamaTokenizer, RobertaTokenizer
@ -165,10 +166,17 @@ def train(args):
batch_size=args.batch_size, batch_size=args.batch_size,
pin_memory=True) pin_memory=True)
(model, optim) = strategy.prepare((model, optim)) lr_scheduler = CosineAnnealingLR(optim, train_dataloader.__len__() // 100)
strategy_dict = strategy.prepare(
dict(model=model, optimizer=optim, lr_scheduler=lr_scheduler)
)
model = strategy_dict['model']
optim = strategy_dict['optimizer']
lr_scheduler = strategy_dict['lr_scheduler']
trainer = RewardModelTrainer(model=model, trainer = RewardModelTrainer(model=model,
strategy=strategy, strategy=strategy,
optim=optim, optim=optim,
lr_scheduler=lr_scheduler,
loss_fn=loss_fn, loss_fn=loss_fn,
train_dataloader=train_dataloader, train_dataloader=train_dataloader,
valid_dataloader=valid_dataloader, valid_dataloader=valid_dataloader,

View File

@ -1,4 +1,5 @@
import argparse import argparse
import math
import os import os
import loralib as lora import loralib as lora
@ -19,6 +20,7 @@ from transformers.models.gpt2.modeling_gpt2 import GPT2LMHeadModel
from transformers.models.gpt2.tokenization_gpt2 import GPT2Tokenizer from transformers.models.gpt2.tokenization_gpt2 import GPT2Tokenizer
from transformers.models.opt.configuration_opt import OPTConfig from transformers.models.opt.configuration_opt import OPTConfig
from transformers.models.opt.modeling_opt import OPTForCausalLM from transformers.models.opt.modeling_opt import OPTForCausalLM
from transformers.trainer import get_scheduler
from colossalai.logging import get_dist_logger from colossalai.logging import get_dist_logger
from colossalai.nn.optimizer import HybridAdam from colossalai.nn.optimizer import HybridAdam
@ -152,10 +154,22 @@ def train(args):
else: else:
eval_dataloader = None eval_dataloader = None
(model, optim) = strategy.prepare((model, optim)) num_update_steps_per_epoch = len(train_dataloader) // args.accumulation_steps
max_steps = math.ceil(args.max_epochs * num_update_steps_per_epoch)
lr_scheduler = get_scheduler("cosine",
optim,
num_warmup_steps=math.ceil(max_steps * 0.03),
num_training_steps=max_steps)
strategy_dict = strategy.prepare(
dict(model=model, optimizer=optim, lr_scheduler=lr_scheduler)
)
model = strategy_dict['model']
optim = strategy_dict['optimizer']
lr_scheduler = strategy_dict['lr_scheduler']
trainer = SFTTrainer(model=model, trainer = SFTTrainer(model=model,
strategy=strategy, strategy=strategy,
optim=optim, optim=optim,
lr_scheduler=lr_scheduler,
train_dataloader=train_dataloader, train_dataloader=train_dataloader,
eval_dataloader=eval_dataloader, eval_dataloader=eval_dataloader,
max_epochs=args.max_epochs, max_epochs=args.max_epochs,

View File

@ -60,10 +60,15 @@ def run_test_checkpoint(strategy):
rank0_dirname = rank0_dirname[0] rank0_dirname = rank0_dirname[0]
model_path = os.path.join(rank0_dirname, 'model.pt') model_path = os.path.join(rank0_dirname, 'model.pt')
optim_path = os.path.join(rank0_dirname, f'optim-r{dist.get_rank()}.pt')
strategy.save_model(actor, model_path, only_rank0=True) strategy.save_model(actor, model_path, only_rank0=True)
strategy.save_optimizer(actor_optim, optim_path, only_rank0=False)
optim_path = os.path.join(rank0_dirname, f'optim.pt')
strategy.save_optimizer(actor_optim, optim_path, only_rank0=True)
# FIXME(cwher): Sharded optimizer checkpoint is not supported yet.
# at "ColossalAI/colossalai/checkpoint_io/general_checkpoint_io.py", line 62
# optim_path = os.path.join(rank0_dirname, f'optim-r{dist.get_rank()}.pt')
# strategy.save_optimizer(actor_optim, optim_path, only_rank0=False)
dist.barrier() dist.barrier()