mirror of https://github.com/hpcaitech/ColossalAI
[chat] refactor strategy class with booster api (#3987)
* refactor: adapt boost API in base and naive strategies
* fix: initialize plugin after setup_distributed
* fix: fix save_pretrained fn
* refactor: adapt boost API in DDPStrategy
* to: add _post_init check
* to: fix ddp backward, modify ddp dataloader and unwrap
* feat: adapt boost API in ColossalAIStrategy
* fix: call setup_distributed before use get_current_device
* fix: fix save_model and save_optimizer
* test: remove save_sharded_optimizer test
* style: apply formatter
* fix: fix stage check and add comments
* feat: allow dict type arg in strategy.prepare
* to: temporarily remove lr_scheduler for testing
* style: simplify init of ColossalAIStrategy
* fix: fix lr_scheduler in sft and rm
* style: modify comments
* test: add train_prompts tests
* fix: fix inference only case and use in train_prompts
* test: skip failed tests in ci
* style: fix CodeFactor check
* fix: do not use model.to('cpu') with GeminiPlugin
* test: enable colossalai_gemini tests
* test: set CUDA_VISIBLE_DEVICES in ci
* docs: add note
parent
b463651f3e
commit
153b957a1b
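Note: this commit routes every strategy through colossalai.booster.Booster. A minimal sketch of the intended call pattern after the refactor, assuming a model-building helper and a prepared optimizer (`build_model` and the loop variables are illustrative, not from the diff):

    from coati.trainer.strategies import ColossalAIStrategy

    strategy = ColossalAIStrategy(stage=2, precision='fp16')   # builds a LowLevelZeroPlugin-backed Booster
    with strategy.model_init_context():
        model = build_model()                                  # hypothetical helper
    (model, optim) = strategy.prepare((model, optim))          # dispatches to booster.boost()
    loss = model(**batch)
    strategy.backward(loss, model, optim)                      # booster.backward(loss, optimizer)
    strategy.optimizer_step(optim)                             # plain optimizer.step()
    strategy.save_model(model, 'actor.pt', only_rank0=True)    # booster.save_model(shard=not only_rank0)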
@@ -19,8 +19,10 @@ from colossalai.nn.optimizer import HybridAdam

 def get_model_numel(model: nn.Module, strategy: Strategy) -> int:
     numel = sum(p.numel() for p in model.parameters())
-    if isinstance(strategy, ColossalAIStrategy) and strategy.stage == 3 and strategy.shard_init:
-        numel *= dist.get_world_size()
+    if isinstance(strategy, ColossalAIStrategy):
+        from colossalai.booster.plugin import GeminiPlugin
+        if isinstance(strategy.plugin, GeminiPlugin) and strategy.shard_init:
+            numel *= dist.get_world_size()
     return numel

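Note: under GeminiPlugin with shard_init, each rank materializes only a 1/world_size shard of every parameter at init time, so the local sum(p.numel()) undercounts the real model size and is scaled back up by dist.get_world_size(). A toy check of the arithmetic (numbers are illustrative):

    world_size = 4                    # dist.get_world_size() in the real script
    local_numel = 330_000_000         # per-rank sum(p.numel() for p in model.parameters())
    global_numel = local_numel * world_size   # 1.32e9: the unsharded parameter count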
@@ -17,7 +17,7 @@ from colossalai.utils import get_current_device

 from .base import Trainer
 from .callbacks import Callback
-from .strategies import Strategy
+from .strategies import ColossalAIStrategy, Strategy
 from .utils import is_rank_0, to_device

@@ -71,6 +71,11 @@ class PPOTrainer(Trainer):
                  offload_inference_models: bool = True,
                  callbacks: List[Callback] = [],
                  **generate_kwargs) -> None:
+        if isinstance(strategy, ColossalAIStrategy):
+            from colossalai.booster.plugin import GeminiPlugin
+            assert not (isinstance(strategy.plugin, GeminiPlugin) and offload_inference_models), \
+                "GeminiPlugin is not compatible with manual model.to('cpu')"
+
         experience_maker = NaiveExperienceMaker(actor, critic, reward_model, initial_model, kl_coef)
         replay_buffer = NaiveReplayBuffer(train_batch_size, buffer_limit, buffer_cpu_offload)
         generate_kwargs = _set_default_generate_kwargs(strategy, generate_kwargs, actor)

@@ -105,6 +110,8 @@ class PPOTrainer(Trainer):
     def _learn(self):
         # replay buffer may be empty at first, we should rebuild at each training
         if not self.sample_replay_buffer:
+            # HACK(cwher): according to the design of boost API, dataloader should also be boosted,
+            # but it is impractical to adapt this pattern in RL training. Thus, I left dataloader unboosted.
             dataloader = self.strategy.setup_dataloader(self.replay_buffer, self.dataloader_pin_memory)
         if self.sample_replay_buffer:
             pbar = tqdm(range(self.max_epochs), desc='Train epoch', disable=not is_rank_0())

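Note: Gemini manages parameter placement itself, so the trainer can no longer offload inference-only models with a manual model.to('cpu'); the new assert rejects that combination up front. Callers pick the flag from the chosen strategy, as the prompts training script does later in this commit:

    offload_inference_models=args.strategy != 'colossalai_gemini'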
@@ -1,13 +1,12 @@
 from datetime import datetime
-from typing import List, Optional
+from typing import Callable, List

 import pandas as pd
 import torch
 import torch.distributed as dist
-from torch.optim import Optimizer, lr_scheduler
-from torch.utils.data import DataLoader, Dataset, DistributedSampler
+from torch.optim import Optimizer
+from torch.optim.lr_scheduler import _LRScheduler
+from torch.utils.data import DataLoader
 from tqdm import tqdm
 from transformers.tokenization_utils_base import PreTrainedTokenizerBase

 from .base import Trainer
 from .callbacks import Callback

@@ -22,7 +21,8 @@ class RewardModelTrainer(Trainer):
     Args:
         model (torch.nn.Module): the model to train
         strategy (Strategy): the strategy to use for training
-        optim(Optimizer): the optimizer to use for training
+        optim (Optimizer): the optimizer to use for training
+        lr_scheduler (_LRScheduler): the lr scheduler to use for training
         loss_fn (callable): the loss function to use for training
         train_dataloader (DataLoader): the dataloader to use for training
         valid_dataloader (DataLoader): the dataloader to use for validation

@@ -37,7 +37,8 @@ class RewardModelTrainer(Trainer):
                  model,
                  strategy: Strategy,
                  optim: Optimizer,
-                 loss_fn,
+                 lr_scheduler: _LRScheduler,
+                 loss_fn: Callable,
                  train_dataloader: DataLoader,
                  valid_dataloader: DataLoader,
                  eval_dataloader: DataLoader,

@@ -53,7 +54,7 @@ class RewardModelTrainer(Trainer):
         self.model = model
         self.loss_fn = loss_fn
         self.optimizer = optim
-        self.scheduler = lr_scheduler.CosineAnnealingLR(self.optimizer, self.train_dataloader.__len__() // 100)
+        self.scheduler = lr_scheduler

     def eval_acc(self, dataloader):
         dist = 0

@@ -116,7 +117,8 @@ class RewardModelTrainer(Trainer):
                 # eval
                 dist, acc = self.eval_acc(self.eval_dataloader)
                 if is_rank_0():
-                    log = pd.DataFrame([[step_bar.n, loss.item(), dist, acc]], columns=['step', 'loss', 'dist', 'acc'])
+                    log = pd.DataFrame([[step_bar.n, loss.item(), dist, acc]],
+                                       columns=['step', 'loss', 'dist', 'acc'])
                     log.to_csv('log.csv', mode='a', header=False, index=False)
             epoch_bar.update()
             step_bar.set_postfix({'dist': dist, 'acc': acc})

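Note: the trainer no longer builds its own CosineAnnealingLR; the scheduler is created by the caller and boosted together with the model and optimizer. The matching call site from the reward-model training script later in this commit:

    lr_scheduler = CosineAnnealingLR(optim, train_dataloader.__len__() // 100)
    strategy_dict = strategy.prepare(dict(model=model, optimizer=optim, lr_scheduler=lr_scheduler))
    model, optim, lr_scheduler = strategy_dict['model'], strategy_dict['optimizer'], strategy_dict['lr_scheduler']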
@@ -1,15 +1,13 @@
 import math
 import time
-from typing import List, Optional
+from typing import List

 import torch
 import torch.distributed as dist
 import wandb
 from torch.optim import Optimizer
+from torch.optim.lr_scheduler import _LRScheduler
 from torch.utils.data import DataLoader
 from tqdm import tqdm
 from transformers.tokenization_utils_base import PreTrainedTokenizerBase
-from transformers.trainer import get_scheduler

 from .base import Trainer
 from .callbacks import Callback

@@ -38,14 +36,17 @@ class SFTTrainer(Trainer):
                  model,
                  strategy: Strategy,
                  optim: Optimizer,
+                 lr_scheduler: _LRScheduler,
                  train_dataloader: DataLoader,
                  eval_dataloader: DataLoader = None,
                  max_epochs: int = 2,
                  accumulation_steps: int = 8,
                  callbacks: List[Callback] = [],
                  ) -> None:
-        if accumulation_steps > 1 and isinstance(strategy, ColossalAIStrategy) and strategy.stage == 3:
-            raise ValueError("Accumulation steps are not supported in stage 3 of ColossalAI")
+        if accumulation_steps > 1 and isinstance(strategy, ColossalAIStrategy):
+            from colossalai.booster.plugin import GeminiPlugin
+            assert not isinstance(strategy.plugin, GeminiPlugin), \
+                "Accumulation steps are not supported in stage 3 of ColossalAI"
         super().__init__(strategy, max_epochs, callbacks=callbacks)
         self.train_dataloader = train_dataloader
         self.eval_dataloader = eval_dataloader

@@ -53,13 +54,8 @@ class SFTTrainer(Trainer):
         self.optimizer = optim

         self.accumulation_steps = accumulation_steps
-        num_update_steps_per_epoch = len(train_dataloader) // self.accumulation_steps
-        max_steps = math.ceil(self.max_epochs * num_update_steps_per_epoch)
-
-        self.scheduler = get_scheduler("cosine",
-                                       self.optimizer,
-                                       num_warmup_steps=math.ceil(max_steps * 0.03),
-                                       num_training_steps=max_steps)
+        self.scheduler = lr_scheduler

     def fit(self, logger, use_wandb: bool = False):
         if use_wandb:

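Note: as in the reward-model trainer, scheduler construction moves out to the caller. The SFT training script later in this commit computes the step budget and builds the cosine schedule before strategy.prepare:

    num_update_steps_per_epoch = len(train_dataloader) // args.accumulation_steps
    max_steps = math.ceil(args.max_epochs * num_update_steps_per_epoch)
    lr_scheduler = get_scheduler("cosine",
                                 optim,
                                 num_warmup_steps=math.ceil(max_steps * 0.03),
                                 num_training_steps=max_steps)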
@@ -1,6 +1,6 @@
 from abc import ABC, abstractmethod
 from contextlib import nullcontext
-from typing import Any, List, Optional, Tuple, Union
+from typing import Callable, Dict, List, Optional, Tuple, Union

 import torch
 import torch.nn as nn

@@ -9,10 +9,12 @@ from torch.optim import Optimizer
 from torch.utils.data import DataLoader
 from transformers.tokenization_utils_base import PreTrainedTokenizerBase

+from colossalai.booster import Booster
+from colossalai.booster.plugin import Plugin
+
 from .sampler import DistributedSampler

-ModelOptimPair = Tuple[nn.Module, Optimizer]
-ModelOrModelOptimPair = Union[nn.Module, ModelOptimPair]
+_BoostArgSpec = Union[nn.Module, Tuple[nn.Module, Optimizer], Dict]


 class Strategy(ABC):

@@ -20,30 +22,28 @@ class Strategy(ABC):
     Base class for training strategies.
     """

-    def __init__(self) -> None:
+    def __init__(self, plugin_initializer: Callable[..., Optional[Plugin]] = lambda: None) -> None:
         super().__init__()
+        # NOTE: dist must be initialized before Booster
+        self.setup_distributed()
+        self.plugin = plugin_initializer()
+        self.booster = Booster(plugin=self.plugin)
+        self._post_init()

     @abstractmethod
-    def backward(self, loss: torch.Tensor, model: nn.Module, optimizer: Optimizer, **kwargs) -> None:
+    def _post_init(self) -> None:
         pass

-    @abstractmethod
-    def optimizer_step(self, optimizer: Optimizer, **kwargs) -> None:
-        pass
+    def backward(self, loss: torch.Tensor, model: nn.Module, optimizer: Optimizer, **kwargs) -> None:
+        self.booster.backward(loss, optimizer)
+
+    def optimizer_step(self, optimizer: Optimizer, **kwargs) -> None:
+        optimizer.step()

     @abstractmethod
     def setup_distributed(self) -> None:
         pass

     @abstractmethod
-    def setup_model(self, model: nn.Module) -> nn.Module:
-        pass
-
-    @abstractmethod
-    def setup_optimizer(self, optimizer: Optimizer, model: nn.Module) -> Optimizer:
-        pass
-
-    @abstractmethod
     def setup_dataloader(self, replay_buffer: ReplayBuffer, pin_memory: bool = False) -> DataLoader:
         pass

@@ -51,12 +51,13 @@ class Strategy(ABC):
     def model_init_context(self):
         return nullcontext()

-    def prepare(
-        self, *models_or_model_optim_pairs: ModelOrModelOptimPair
-    ) -> Union[List[ModelOrModelOptimPair], ModelOrModelOptimPair]:
-        """Prepare models or model-optimizer-pairs based on each strategy.
+    def prepare(self, *boost_args: _BoostArgSpec) -> Union[List[_BoostArgSpec], _BoostArgSpec]:
+        """Prepare [model | (model, optimizer) | Dict] based on each strategy.
+        NOTE: the keys of Dict must be a subset of `self.booster.boost`'s arguments.

         Example::
+            >>> # e.g., include lr_scheduler
+            >>> result_dict = strategy.prepare(dict(model=model, lr_scheduler=lr_scheduler))
             >>> # when fine-tuning actor and critic
             >>> (actor, actor_optim), (critic, critic_optim), reward_model, initial_model = strategy.prepare((actor, actor_optim), (critic, critic_optim), reward_model, initial_model)
             >>> # or when training reward model

@@ -65,25 +66,39 @@ class Strategy(ABC):
             >>> actor, critic = strategy.prepare(actor, critic)

         Returns:
-            Union[List[ModelOrModelOptimPair], ModelOrModelOptimPair]: Models or model-optimizer-pairs in the original order.
+            Union[List[_BoostArgSpec], _BoostArgSpec]: [model | (model, optimizer) | Dict] in the original order.
         """

         rets = []
-        for arg in models_or_model_optim_pairs:
-            if isinstance(arg, tuple):
-                assert len(arg) == 2, f'Expect (model, optimizer) pair, got a tuple with size "{len(arg)}"'
-                model, optimizer = arg
-                model = self.setup_model(model)
-                optimizer = self.setup_optimizer(optimizer, model)
+        for arg in boost_args:
+            if isinstance(arg, nn.Module):
+                model, *_ = self.booster.boost(arg)
+                rets.append(model)
+            elif isinstance(arg, tuple):
+                try:
+                    model, optimizer = arg
+                except ValueError:
+                    raise RuntimeError(f'Expect (model, optimizer) pair, got a tuple with size "{len(arg)}"')
+                model, optimizer, *_ = self.booster.boost(model=model,
+                                                          optimizer=optimizer)
                 rets.append((model, optimizer))
-            elif isinstance(arg, nn.Module):
-                rets.append(self.setup_model(model))
+            elif isinstance(arg, Dict):
+                model, optimizer, criterion, dataloader, lr_scheduler = self.booster.boost(**arg)
+                boost_result = dict(model=model,
+                                    optimizer=optimizer,
+                                    criterion=criterion,
+                                    dataloader=dataloader,
+                                    lr_scheduler=lr_scheduler)
+                # remove None values
+                boost_result = {
+                    key: value
+                    for key, value in boost_result.items() if value is not None
+                }
+                rets.append(boost_result)
             else:
-                raise RuntimeError(f'Expect model or (model, optimizer) pair, got {type(arg)}')
+                raise RuntimeError(f'Type {type(arg)} is not supported')

-        if len(rets) == 1:
-            return rets[0]
-        return rets
+        return rets[0] if len(rets) == 1 else rets

     @staticmethod
     def unwrap_model(model: nn.Module) -> nn.Module:

@@ -97,23 +112,30 @@ class Strategy(ABC):
         """
         return model

-    @abstractmethod
-    def save_model(self, model: nn.Module, path: str, only_rank0: bool = True) -> None:
-        pass
+    def save_model(self,
+                   model: nn.Module,
+                   path: str,
+                   only_rank0: bool = True,
+                   **kwargs
+                   ) -> None:
+        self.booster.save_model(model, path, shard=not only_rank0, **kwargs)

-    @abstractmethod
-    def load_model(self, model: nn.Module, path: str, map_location: Any = None, strict: bool = True) -> None:
-        pass
+    def load_model(self, model: nn.Module, path: str, strict: bool = True) -> None:
+        self.booster.load_model(model, path, strict)

-    @abstractmethod
-    def save_optimizer(self, optimizer: Optimizer, path: str, only_rank0: bool = False) -> None:
-        pass
+    def save_optimizer(self,
+                       optimizer: Optimizer,
+                       path: str,
+                       only_rank0: bool = False,
+                       **kwargs
+                       ) -> None:
+        self.booster.save_optimizer(optimizer, path, shard=not only_rank0, **kwargs)

-    @abstractmethod
-    def load_optimizer(self, optimizer: Optimizer, path: str, map_location: Any = None) -> None:
-        pass
+    def load_optimizer(self, optimizer: Optimizer, path: str) -> None:
+        self.booster.load_optimizer(optimizer, path)

     def setup_sampler(self, dataset) -> DistributedSampler:
+        # FIXME(cwher): this is only invoked in train_on_ray, not tested after adapt Boost API.
         return DistributedSampler(dataset, 1, 0)

     @abstractmethod

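Note: prepare now accepts three argument shapes, all routed through booster.boost. A compact sketch of each (variable names illustrative):

    actor = strategy.prepare(actor)                                  # bare nn.Module
    (actor, actor_optim) = strategy.prepare((actor, actor_optim))    # (model, optimizer) pair
    ret = strategy.prepare(dict(model=actor, lr_scheduler=sched))    # dict, keys forwarded to booster.boost(**arg)
    actor, sched = ret['model'], ret['lr_scheduler']                 # None entries are dropped from the result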
@@ -1,24 +1,23 @@
 import functools
 import warnings
-from typing import Optional, Union
+from typing import Optional

 import torch
 import torch.distributed as dist
 import torch.nn as nn
 import torch.optim as optim
 from torch.optim import Optimizer
 from transformers.tokenization_utils_base import PreTrainedTokenizerBase

 import colossalai
 from colossalai.logging import get_dist_logger
-from colossalai.nn.optimizer import CPUAdam, HybridAdam
+from colossalai.booster.plugin import GeminiPlugin, LowLevelZeroPlugin
+from colossalai.booster.plugin.gemini_plugin import GeminiModel
+from colossalai.booster.plugin.low_level_zero_plugin import LowLevelZeroModel
 from colossalai.tensor import ProcessGroup, ShardSpec
 from colossalai.utils import get_current_device
-from colossalai.zero import ColoInitContext, ZeroDDP, zero_model_wrapper, zero_optim_wrapper
+from colossalai.zero import ColoInitContext
+from colossalai.zero.gemini.gemini_ddp import GeminiDDP

 from .ddp import DDPStrategy

 logger = get_dist_logger(__name__)


 class ColossalAIStrategy(DDPStrategy):
     """

@@ -62,7 +61,6 @@ class ColossalAIStrategy(DDPStrategy):
                  placement_policy: str = 'cuda',
                  pin_memory: bool = True,    # only for stage 3
                  force_outputs_fp32: bool = False,    # only for stage 3
-                 scatter_after_inference: bool = False,    # only for stage 3
                  search_range_mb: int = 32,    # only for stage 3
                  hidden_dim: Optional[int] = None,    # only for stage 3
                  min_chunk_size_mb: float = 32,    # only for stage 3

@@ -78,50 +76,76 @@ class ColossalAIStrategy(DDPStrategy):
                  max_scale: float = 2**32,
                  max_norm: float = 0.0,
                  norm_type: float = 2.0) -> None:
-        super().__init__(seed)

         assert stage in (1, 2, 3), f'Unsupported stage "{stage}"'
         assert placement_policy in ('cpu', 'cuda'), f'Unsupported placement policy "{placement_policy}"'
+        assert precision in ('fp32', 'fp16'), f'Unsupported precision "{precision}"'
         self.stage = stage

         # TODO(ver217): support shard_init when using from_pretrained()
         if shard_init:
             warnings.warn(
-                f'Shard init is not supported model.from_pretrained() yet. Please load weights after strategy.prepare()'
+                f'Shard init is not supported model.from_pretrained() yet. '
+                'Please load weights after strategy.prepare()'
             )
         if stage == 3 and precision == 'fp32':
             warnings.warn(f'Stage 3 only supports fp16. Precision is set to fp16.')
             precision = 'fp16'
         self.precision = precision
         self.shard_init = shard_init
-        self.gemini_config = dict(device=get_current_device(),
-                                  placement_policy=placement_policy,
-                                  pin_memory=pin_memory,
-                                  force_outputs_fp32=force_outputs_fp32,
-                                  strict_ddp_mode=shard_init,
-                                  search_range_mb=search_range_mb,
-                                  hidden_dim=hidden_dim,
-                                  min_chunk_size_mb=min_chunk_size_mb,
-                                  scatter_after_inference=scatter_after_inference)

+        optim_kwargs = dict(
+            initial_scale=initial_scale,
+            growth_factor=growth_factor,
+            backoff_factor=backoff_factor,
+            growth_interval=growth_interval,
+            hysteresis=hysteresis,
+            min_scale=min_scale,
+            max_scale=max_scale,
+            max_norm=max_norm,
+            norm_type=norm_type
+        )
+        # NOTE: dist should be initialized before calling get_current_device()
         if stage == 3:
-            self.zero_optim_config = dict(gpu_margin_mem_ratio=gpu_margin_mem_ratio)
+            plugin_initializer = lambda: GeminiPlugin(
+                # gemini_config
+                device=get_current_device(),
+                placement_policy=placement_policy,
+                precision=precision,
+                pin_memory=pin_memory,
+                force_outputs_fp32=force_outputs_fp32,
+                strict_ddp_mode=shard_init,
+                search_range_mb=search_range_mb,
+                hidden_dim=hidden_dim,
+                min_chunk_size_mb=min_chunk_size_mb,
+                # zero_optim_config
+                gpu_margin_mem_ratio=gpu_margin_mem_ratio,
+                # optim_config
+                **optim_kwargs
+            )
         else:
-            self.zero_optim_config = dict(reduce_bucket_size=reduce_bucket_size,
-                                          overlap_communication=overlap_communication,
-                                          cpu_offload=(placement_policy == 'cpu'))
-        self.optim_kwargs = dict(initial_scale=initial_scale,
-                                 growth_factor=growth_factor,
-                                 backoff_factor=backoff_factor,
-                                 growth_interval=growth_interval,
-                                 hysteresis=hysteresis,
-                                 min_scale=min_scale,
-                                 max_scale=max_scale,
-                                 max_norm=max_norm,
-                                 norm_type=norm_type)
+            plugin_initializer = lambda: LowLevelZeroPlugin(
+                # zero_config
+                stage=stage,
+                precision=precision,
+                # zero_optim_config
+                reduce_bucket_size_in_m=reduce_bucket_size,
+                overlap_communication=overlap_communication,
+                cpu_offload=(placement_policy == 'cpu'),
+                # optim_config
+                **optim_kwargs
+            )
+
+        super().__init__(seed, plugin_initializer)
+
+    def _post_init(self) -> None:
+        assert isinstance(self.plugin, (LowLevelZeroPlugin, GeminiPlugin)), \
+            f'{type(self).__name__}\'s plugin is not initialized properly.'

     def setup_distributed(self) -> None:
         colossalai.launch_from_torch({}, seed=self.seed)

     def model_init_context(self):
-        if self.stage == 3:
+        if isinstance(self.plugin, GeminiPlugin):
             world_size = dist.get_world_size()
             shard_pg = ProcessGroup(tp_degree=world_size) if self.shard_init else None
             default_dist_spec = ShardSpec([-1], [world_size]) if self.shard_init else None

@@ -131,61 +155,29 @@ class ColossalAIStrategy(DDPStrategy):
                                   default_dist_spec=default_dist_spec)
         return super().model_init_context()

-    def setup_model(self, model: nn.Module) -> nn.Module:
-
-        model = zero_model_wrapper(model, zero_stage=self.stage, gemini_config=self.gemini_config)
-
-        if self.stage != 3 and self.precision == 'fp16':
-            model = model.half().cuda()
-        return model
-
-    def setup_optimizer(self, optimizer: optim.Optimizer, model: nn.Module) -> optim.Optimizer:
-        assert isinstance(optimizer, (CPUAdam, HybridAdam)), f'Unsupported optimizer {type(optimizer)}'
-        return zero_optim_wrapper(model, optimizer, optim_config=self.zero_optim_config, **self.optim_kwargs)
-
-    def backward(self, loss: torch.Tensor, model: nn.Module, optimizer: optim.Optimizer, **kwargs) -> None:
-        optimizer.backward(loss)
-
-    def optimizer_step(self, optimizer: optim.Optimizer, **kwargs) -> None:
-        optimizer.step()
-
-    def save_model(self, model: nn.Module, path: str, only_rank0: bool = True) -> None:
-        if only_rank0 and dist.get_rank() != 0 and self.stage != 3:
-            return
-        if self.stage == 3:
-            assert isinstance(model, ZeroDDP)
-            # for stage 3, state_dict() method should be called on every rank
-            state_dict = model.state_dict(only_rank_0=only_rank0)
-        else:
-            # only_rank0 is false or rank == 0
-            state_dict = model.state_dict()
-        if only_rank0 and dist.get_rank() != 0:
-            return
-        torch.save(state_dict, path)
-
-    def save_optimizer(self, optimizer: Optimizer, path: str, only_rank0: bool = False) -> None:
-        if only_rank0:
-            raise RuntimeError(
-                f'Optimizer states are sharded when using ColossalAIStrategy. Only rank0 is not supported.')
-        torch.save(optimizer.state_dict(), path)
-
     def unwrap_model(self, model: nn.Module) -> nn.Module:
-        if self.stage == 3:
-            assert isinstance(model, ZeroDDP)
+        if isinstance(self.plugin, GeminiPlugin):
+            assert isinstance(model, GeminiModel)
+            ddp_model = model.unwrap()
+            assert isinstance(ddp_model, GeminiDDP)
+            return ddp_model.module
+        elif isinstance(self.plugin, LowLevelZeroPlugin):
+            assert isinstance(model, LowLevelZeroModel)
             return model.module
-        return model
+        else:
+            raise RuntimeError(f'Unsupported plugin {type(self.plugin)}')

     def save_pretrained(self,
                         model: nn.Module,
                         path: str,
                         only_rank0: bool = True,
                         tokenizer: Optional[PreTrainedTokenizerBase] = None) -> None:
-        if self.stage == 3:
+        if isinstance(self.plugin, GeminiPlugin):
             raise RuntimeError('ColossalAI strategy with stage-3 does not support save_pretrained() now')
         super().save_pretrained(model, path, only_rank0, tokenizer)

     def get_model_state_dict_shard(self, model: nn.Module, **config):
-        if self.stage != 3:
+        if not isinstance(self.plugin, GeminiPlugin):
             yield from super().get_model_state_dict_shard(model, **config)
         else:
             # unwrapped_model = self._unwrap_model(model)

@@ -193,5 +185,5 @@ class ColossalAIStrategy(DDPStrategy):
             #     if isinstance(module, LoraLinear):
             #         module.merge_weights = True
             #         module.eval()
-            assert isinstance(model, ZeroDDP)
+            assert isinstance(model, LowLevelZeroModel)
             yield from model.state_dict_shard(max_shard_size=1024, only_rank_0=False)

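Note: stage selection now maps directly onto Booster plugins: stage 3 builds a GeminiPlugin (fp32 is coerced to fp16, per the warning above) and stages 1-2 build a LowLevelZeroPlugin, with the shared AMP arguments packed into optim_kwargs. Illustrative constructions:

    gemini = ColossalAIStrategy(stage=3, placement_policy='cuda')   # GeminiPlugin under the hood
    zero2 = ColossalAIStrategy(stage=2, precision='fp16')           # LowLevelZeroPlugin under the hood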
@@ -1,17 +1,18 @@
 import os
 import random
-from typing import Optional
+from typing import Callable, Optional

 import numpy as np
 import torch
 import torch.distributed as dist
 import torch.nn as nn
 from coati.replay_buffer import ReplayBuffer
-from torch.nn.parallel import DistributedDataParallel as DDP
 from torch.optim import Optimizer
 from torch.utils.data import DataLoader
 from transformers.tokenization_utils_base import PreTrainedTokenizerBase

+from colossalai.booster.plugin import TorchDDPPlugin
+from colossalai.booster.plugin.torch_ddp_plugin import TorchDDPModel
+
 from .naive import NaiveStrategy
 from .sampler import DistributedSampler

@@ -21,9 +22,16 @@ class DDPStrategy(NaiveStrategy):
     Strategy for distributed training using torch.distributed.
     """

-    def __init__(self, seed: int = 42) -> None:
+    def __init__(self,
+                 seed: int = 42,
+                 plugin_initializer: Callable = TorchDDPPlugin
+                 ) -> None:
         self.seed = seed
-        super().__init__()
+        super().__init__(plugin_initializer)
+
+    def _post_init(self) -> None:
+        assert isinstance(self.plugin, TorchDDPPlugin), \
+            f'{type(self).__name__}\'s plugin is not initialized properly.'

     def setup_distributed(self) -> None:
         self._try_init_dist(force=True)

@@ -34,43 +42,24 @@ class DDPStrategy(NaiveStrategy):
         np.random.seed(seed)
         torch.manual_seed(seed)

-    def setup_model(self, model: nn.Module) -> nn.Module:
-        device = torch.cuda.current_device()
-        return DDP(model, device_ids=[device])
+    def backward(self, loss: torch.Tensor, model: nn.Module, optimizer: Optimizer, **kwargs) -> None:
+        self.booster.backward(loss, optimizer)

     def setup_dataloader(self, replay_buffer: ReplayBuffer, pin_memory: bool = False) -> DataLoader:
         # DDP only mode, replay buffers on each rank are different.
-        # sampler = DistributedSampler(replay_buffer,
-        #                              num_replicas=dist.get_world_size(),
-        #                              rank=dist.get_rank(),
-        #                              shuffle=True,
-        #                              seed=self.seed,
-        #                              drop_last=True)
-        return DataLoader(
-            replay_buffer,
-            batch_size=replay_buffer.sample_batch_size,
-            # sampler=sampler,
-            shuffle=True,
-            drop_last=True,
-            pin_memory=pin_memory,
-            collate_fn=replay_buffer.collate_fn)
-
-    def save_model(self, model: nn.Module, path: str, only_rank0: bool = True) -> None:
-        if only_rank0 and dist.get_rank() != 0:
-            return
-        super().save_model(model, path, only_rank0)
-
-    def save_optimizer(self, optimizer: Optimizer, path: str, only_rank0: bool = False) -> None:
-        if only_rank0 and dist.get_rank() != 0:
-            return
-        super().save_optimizer(optimizer, path, only_rank0)
+        return self.plugin.prepare_dataloader(replay_buffer,
+                                              batch_size=replay_buffer.sample_batch_size,
+                                              shuffle=True,
+                                              drop_last=True,
+                                              pin_memory=pin_memory,
+                                              collate_fn=replay_buffer.collate_fn)

     def setup_sampler(self, dataset) -> DistributedSampler:
+        # FIXME(cwher): this is only invoked in train_on_ray, not tested after adapt Boost API.
         return DistributedSampler(dataset, dist.get_world_size(), dist.get_rank())

     def unwrap_model(self, model: nn.Module) -> nn.Module:
-        assert isinstance(model, DDP)
-        return model.module
+        assert isinstance(model, TorchDDPModel), "model is not wrapped by TorchDDPModel."
+        return model.unwrap()

     def save_pretrained(self,
                         model: nn.Module,

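Note: after boosting, a DDP model is wrapped as TorchDDPModel rather than a raw DistributedDataParallel, so unwrapping goes through model.unwrap() instead of model.module, and dataloaders come from the plugin. Illustrative round trip:

    wrapped = strategy.prepare(model)        # returns a TorchDDPModel
    raw = strategy.unwrap_model(wrapped)     # back to the original nn.Module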
@@ -1,16 +1,10 @@
 import os
-import sys
-from collections import OrderedDict
-from typing import Any, Dict, Optional
+from typing import Optional

 import torch
-import torch.distributed as dist
 import torch.nn as nn
-import torch.optim as optim
-from coati.models.base import get_base_model
-from coati.replay_buffer import ReplayBuffer
+from coati.models.base import RewardModel
+from coati.models.lora import LoraLinear
+from coati.replay_buffer import ReplayBuffer
 from torch.optim import Optimizer
 from torch.utils.data import DataLoader

@@ -34,20 +28,18 @@ class NaiveStrategy(Strategy):
     Strategy for single GPU. No parallelism is used.
     """

-    def backward(self, loss: torch.Tensor, model: nn.Module, optimizer: optim.Optimizer, **kwargs) -> None:
-        loss.backward()
-
-    def optimizer_step(self, optimizer: optim.Optimizer, **kwargs) -> None:
-        optimizer.step()
+    def _post_init(self) -> None:
+        assert self.plugin is None, \
+            f'{type(self).__name__}\'s plugin is not initialized properly.'

     def setup_distributed(self) -> None:
         self._try_init_dist(force=False)

-    def setup_model(self, model: nn.Module) -> nn.Module:
-        return model
-
-    def setup_optimizer(self, optimizer: optim.Optimizer, model: nn.Module) -> optim.Optimizer:
-        return optimizer
+    def backward(self, loss: torch.Tensor, model: nn.Module, optimizer: Optimizer, **kwargs) -> None:
+        # HACK: self.booster.backward(loss, optimizer) can't work if plugin is None,
+        # it would run `optimizer.backward(loss)`, which is not compatible with torch.optim.Optimizer
+        assert self.plugin is None, "DO NOT call this method if plugin is not None"
+        loss.backward()

     def setup_dataloader(self, replay_buffer: ReplayBuffer, pin_memory: bool = False) -> DataLoader:
         return DataLoader(replay_buffer,

@@ -57,22 +49,6 @@ class NaiveStrategy(Strategy):
                           pin_memory=pin_memory,
                           collate_fn=replay_buffer.collate_fn)

-    def save_model(self, model: nn.Module, path: str, only_rank0: bool = True) -> None:
-        state_dict = model.state_dict()
-        torch.save(state_dict, path)
-
-    def load_model(self, model: nn.Module, path: str, map_location: Any = None, strict: bool = True) -> None:
-        unwrapped_model = self.unwrap_model(model)
-        state_dict = torch.load(path, map_location=map_location)
-        unwrapped_model.load_state_dict(state_dict, strict=strict)
-
-    def save_optimizer(self, optimizer: Optimizer, path: str, only_rank0: bool = False) -> None:
-        torch.save(optimizer.state_dict(), path)
-
-    def load_optimizer(self, optimizer: Optimizer, path: str, map_location: Any = None) -> None:
-        state_dict = torch.load(path, map_location=map_location)
-        optimizer.load_state_dict(state_dict)
-
     def save_pretrained(self,
                         model: nn.Module,
                         path: str,

@@ -1,5 +1,22 @@
 #!/usr/bin/env bash

+set_n_least_used_CUDA_VISIBLE_DEVICES() {
+    local n=${1:-"9999"}
+    echo "GPU Memory Usage:"
+    local FIRST_N_GPU_IDS=$(nvidia-smi --query-gpu=memory.used --format=csv |
+        tail -n +2 |
+        nl -v 0 |
+        tee /dev/tty |
+        sort -g -k 2 |
+        awk '{print $1}' |
+        head -n $n)
+    export CUDA_VISIBLE_DEVICES=$(echo $FIRST_N_GPU_IDS | sed 's/ /,/g')
+    echo "Now CUDA_VISIBLE_DEVICES is set to:"
+    echo "CUDA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES"
+}
+
+set_n_least_used_CUDA_VISIBLE_DEVICES 4
+
 set -xue

 if [ -z "$SFT_DATASET" ]; then

@@ -26,109 +43,137 @@ pip install -r ${BASE}/requirements.txt

 wandb init -m offline

+# FIXME: This is a hack to skip tests that are not working (tested at commit b3ab7fbabf)
+#  - gpt2-ddp: RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation
+#  - llama-*: Repository Not Found for url: https://huggingface.co/{...}/resolve/main/tokenizer.model.
+#  - roberta-*: RuntimeError: CUDA error: CUBLAS_STATUS_NOT_INITIALIZED when calling `cublasCreate(handle)`
+SKIPPED_TESTS=(
+    "gpt2-ddp"
+    "llama-naive" "llama-ddp" "llama-colossalai_gemini" "llama-colossalai_zero2"
+    "roberta-naive" "roberta-ddp" "roberta-colossalai_gemini" "roberta-colossalai_zero2"
+)
+
+# These tests are quick and do not have any dependencies
+for model in 'gpt2' 'bloom' 'opt' 'llama' 'roberta'; do
+    for strategy in 'naive' 'ddp' 'colossalai_gemini' 'colossalai_zero2'; do
+        if [[ " ${SKIPPED_TESTS[*]} " =~ " ${model}-${strategy} " ]]; then
+            echo "[Test]: Skipped $model-$strategy"
+            continue
+        fi
+        torchrun --standalone --nproc_per_node=2 ${BASE}/train_prompts.py \
+            --prompt_dataset $PROMPT_PATH --pretrain_dataset $PRETRAIN_DATASET \
+            --strategy $strategy --model $model \
+            --num_episodes 1 --max_timesteps 2 \
+            --update_timesteps 2 --max_epochs 1 --train_batch_size 2
+    done
+done
+
 # train sft
 torchrun --standalone --nproc_per_node=4 ${BASE}/train_sft.py --pretrain 'bigscience/bloom-560m' \
-    --model 'bloom' --strategy colossalai_zero2 --lora_rank 4\
+    --model 'bloom' --strategy colossalai_zero2 --lora_rank 4 \
     --dataset $SFT_DATASET --max_datasets_size 512 --max_epochs 1 \
     --save_path ${BASE}/output
 rm -rf ${BASE}/output

 torchrun --standalone --nproc_per_node=4 ${BASE}/train_sft.py --pretrain 'gpt2' \
     --model 'gpt2' --strategy colossalai_zero2 \
     --dataset $SFT_DATASET --max_datasets_size 512 --max_epochs 1 \
     --save_path ${BASE}/output
 rm -rf ${BASE}/output

 torchrun --standalone --nproc_per_node=4 ${BASE}/train_sft.py --pretrain 'facebook/opt-350m' \
-    --model 'opt' --strategy colossalai_zero2 --lora_rank 4\
+    --model 'opt' --strategy colossalai_zero2 --lora_rank 4 \
    --dataset $SFT_DATASET --max_datasets_size 512 --max_epochs 1 \
    --save_path ${BASE}/output
 rm -rf ${BASE}/output

 torchrun --standalone --nproc_per_node=4 ${BASE}/train_sft.py --pretrain 'gpt2' \
-    --model 'gpt2' --strategy ddp --lora_rank 4\
+    --model 'gpt2' --strategy ddp --lora_rank 4 \
     --dataset $SFT_DATASET --max_datasets_size 512 --max_epochs 1 \
     --save_path ${BASE}/output

-#torchrun --standalone --nproc_per_node=4 ${BASE}/train_sft.py --pretrain 'facebook/opt-350m' \
-#    --model 'opt' --strategy naive \
-#    --dataset $SFT_DATASET --max_datasets_size 512 --max_epochs 1 \
-#    --save_path ${BASE}/output
+# torchrun --standalone --nproc_per_node=4 ${BASE}/train_sft.py --pretrain 'facebook/opt-350m' \
+#     --model 'opt' --strategy naive \
+#     --dataset $SFT_DATASET --max_datasets_size 512 --max_epochs 1 \
+#     --save_path ${BASE}/output

 rm -rf ${BASE}/output

 # train rm
 torchrun --standalone --nproc_per_node=2 ${BASE}/train_reward_model.py \
     --pretrain 'facebook/opt-350m' --model 'opt' \
-    --strategy colossalai_zero2 --loss_fn 'log_sig'\
+    --strategy colossalai_zero2 --loss_fn 'log_sig' \
     --dataset 'Anthropic/hh-rlhf' --subset 'harmless-base' \
     --test True --lora_rank 0 \
     --save_path ${BASE}/rm_ckpt_opt.pt

 torchrun --standalone --nproc_per_node=2 ${BASE}/train_reward_model.py \
     --pretrain 'gpt2' --model 'gpt2' \
     --strategy colossalai_zero2 --loss_fn 'log_exp' \
     --dataset 'Dahoas/rm-static' \
     --test True --lora_rank 0 \
     --save_path ${BASE}/rm_ckpt_gpt.pt

 torchrun --standalone --nproc_per_node=2 ${BASE}/train_reward_model.py \
     --pretrain 'gpt2' --model 'gpt2' \
     --strategy ddp --loss_fn 'log_exp' \
     --dataset 'Dahoas/rm-static' \
     --test True --lora_rank 4 \
     --save_path ${BASE}/rm_ckpt.pt
 rm -rf ${BASE}/rm_ckpt.pt

 torchrun --standalone --nproc_per_node=2 ${BASE}/train_reward_model.py \
     --pretrain 'bigscience/bloom-560m' --model 'bloom' \
     --strategy colossalai_zero2 --loss_fn 'log_sig' \
     --dataset 'Anthropic/hh-rlhf' --subset 'harmless-base' \
     --test True --lora_rank 4 \
     --save_path ${BASE}/rm_ckpt.pt
 rm -rf ${BASE}/rm_ckpt.pt

 torchrun --standalone --nproc_per_node=2 ${BASE}/train_reward_model.py \
     --pretrain 'microsoft/deberta-v3-large' --model 'deberta' \
     --strategy colossalai_zero2 --loss_fn 'log_sig' \
     --dataset 'Anthropic/hh-rlhf' --subset 'harmless-base' \
     --test True --lora_rank 4 \
     --save_path ${BASE}/rm_ckpt.pt
 rm -rf ${BASE}/rm_ckpt.pt

 torchrun --standalone --nproc_per_node=2 ${BASE}/train_reward_model.py \
     --pretrain 'roberta-base' --model 'roberta' \
-    --strategy colossalai_zero2 --loss_fn 'log_exp'\
-    --dataset 'Anthropic/hh-rlhf' --subset 'harmless-base'\
+    --strategy colossalai_zero2 --loss_fn 'log_exp' \
+    --dataset 'Anthropic/hh-rlhf' --subset 'harmless-base' \
     --test True --lora_rank 4 \
     --save_path ${BASE}/rm_ckpt.pt

 rm -rf ${BASE}/rm_ckpt.pt

-torchrun --standalone --nproc_per_node=2 ${BASE}/train_prompts.py --prompt_dataset $PROMPT_PATH --pretrain_dataset $PRETRAIN_DATASET \
+torchrun --standalone --nproc_per_node=2 ${BASE}/train_prompts.py \
+    --prompt_dataset $PROMPT_PATH --pretrain_dataset $PRETRAIN_DATASET \
     --strategy colossalai_zero2 --num_episodes 1 --max_timesteps 2 \
     --update_timesteps 2 --max_epochs 1 --train_batch_size 2 \
     --pretrain 'facebook/opt-350m' --model opt \
     --rm_pretrain 'facebook/opt-350m' \
     --rm_path ${BASE}/rm_ckpt_opt.pt \
     --save_path ${BASE}/actor_checkpoint_prompts.pt
 rm -rf ${BASE}/rm_ckpt_opt.pt

-torchrun --standalone --nproc_per_node=2 ${BASE}/train_prompts.py --prompt_dataset $PROMPT_PATH --pretrain_dataset $PRETRAIN_DATASET \
+torchrun --standalone --nproc_per_node=2 ${BASE}/train_prompts.py \
+    --prompt_dataset $PROMPT_PATH --pretrain_dataset $PRETRAIN_DATASET \
     --strategy colossalai_zero2 --num_episodes 1 --max_timesteps 2 \
     --update_timesteps 2 --max_epochs 1 --train_batch_size 2 \
     --pretrain 'gpt2' --model gpt2 \
     --rm_pretrain 'gpt2' \
     --rm_path ${BASE}/rm_ckpt_gpt.pt \
     --save_path ${BASE}/actor_checkpoint_prompts.pt

-torchrun --standalone --nproc_per_node=2 ${BASE}/train_prompts.py --prompt_dataset $PROMPT_PATH --pretrain_dataset $PRETRAIN_DATASET \
+torchrun --standalone --nproc_per_node=2 ${BASE}/train_prompts.py \
+    --prompt_dataset $PROMPT_PATH --pretrain_dataset $PRETRAIN_DATASET \
     --strategy colossalai_gemini --num_episodes 1 --max_timesteps 2 \
     --update_timesteps 2 --max_epochs 1 --train_batch_size 2 \
     --pretrain 'gpt2' --model gpt2 \
     --rm_pretrain 'gpt2' \
     --rm_path ${BASE}/rm_ckpt_gpt.pt \
     --save_path ${BASE}/actor_checkpoint_prompts.pt
 rm -rf ${BASE}/rm_ckpt_gpt.pt

 rm -rf ${BASE}/actor_checkpoint_prompts.pt

@@ -1,6 +1,5 @@
 import argparse

 import pandas as pd
 import torch
-import torch.distributed as dist
 from coati.dataset import DataCollatorForSupervisedDataset, PromptDataset, SupervisedDataset

@@ -51,7 +50,7 @@ def main(args):
     else:
         raise ValueError(f'Unsupported actor model "{args.model}"')

-    if args.rm_model == None:
+    if args.rm_model is None:
         rm_model_name = args.model
     else:
         rm_model_name = args.rm_model

@@ -163,7 +162,9 @@ def main(args):
                                     batch_size=args.ptx_batch_size,
                                     collate_fn=data_collator)

-    (actor, actor_optim), (critic, critic_optim) = strategy.prepare((actor, actor_optim), (critic, critic_optim))
+    # NOTE: For small models like opt-1.3b, reward model and initial model are not required to be parallelized.
+    (actor, actor_optim), (critic, critic_optim), reward_model, initial_model = \
+        strategy.prepare((actor, actor_optim), (critic, critic_optim), reward_model, initial_model)

     # configure trainer
     trainer = PPOTrainer(

@@ -185,6 +186,7 @@ def main(args):
         top_k=50,
         pad_token_id=tokenizer.pad_token_id,
         eos_token_id=tokenizer.eos_token_id,
+        offload_inference_models=args.strategy != 'colossalai_gemini'
     )

     trainer.fit(prompt_dataloader=prompt_dataloader,

@@ -18,6 +18,7 @@ from coati.trainer.strategies import ColossalAIStrategy, DDPStrategy, NaiveStrat
 from coati.utils import prepare_llama_tokenizer_and_embedding
 from datasets import load_dataset
 from torch.optim import Adam
+from torch.optim.lr_scheduler import CosineAnnealingLR
 from torch.utils.data import DataLoader
 from torch.utils.data.distributed import DistributedSampler
 from transformers import AutoTokenizer, BloomTokenizerFast, DebertaV2Tokenizer, LlamaTokenizer, RobertaTokenizer

@@ -165,10 +166,17 @@ def train(args):
                                   batch_size=args.batch_size,
                                   pin_memory=True)

-    (model, optim) = strategy.prepare((model, optim))
+    lr_scheduler = CosineAnnealingLR(optim, train_dataloader.__len__() // 100)
+    strategy_dict = strategy.prepare(
+        dict(model=model, optimizer=optim, lr_scheduler=lr_scheduler)
+    )
+    model = strategy_dict['model']
+    optim = strategy_dict['optimizer']
+    lr_scheduler = strategy_dict['lr_scheduler']
     trainer = RewardModelTrainer(model=model,
                                  strategy=strategy,
                                  optim=optim,
+                                 lr_scheduler=lr_scheduler,
                                  loss_fn=loss_fn,
                                  train_dataloader=train_dataloader,
                                  valid_dataloader=valid_dataloader,

@@ -1,4 +1,5 @@
 import argparse
+import math
 import os

 import loralib as lora

@@ -19,6 +20,7 @@ from transformers.models.gpt2.modeling_gpt2 import GPT2LMHeadModel
 from transformers.models.gpt2.tokenization_gpt2 import GPT2Tokenizer
 from transformers.models.opt.configuration_opt import OPTConfig
 from transformers.models.opt.modeling_opt import OPTForCausalLM
+from transformers.trainer import get_scheduler

 from colossalai.logging import get_dist_logger
 from colossalai.nn.optimizer import HybridAdam

@@ -152,10 +154,22 @@ def train(args):
     else:
         eval_dataloader = None

-    (model, optim) = strategy.prepare((model, optim))
+    num_update_steps_per_epoch = len(train_dataloader) // args.accumulation_steps
+    max_steps = math.ceil(args.max_epochs * num_update_steps_per_epoch)
+    lr_scheduler = get_scheduler("cosine",
+                                 optim,
+                                 num_warmup_steps=math.ceil(max_steps * 0.03),
+                                 num_training_steps=max_steps)
+    strategy_dict = strategy.prepare(
+        dict(model=model, optimizer=optim, lr_scheduler=lr_scheduler)
+    )
+    model = strategy_dict['model']
+    optim = strategy_dict['optimizer']
+    lr_scheduler = strategy_dict['lr_scheduler']
     trainer = SFTTrainer(model=model,
                          strategy=strategy,
                          optim=optim,
+                         lr_scheduler=lr_scheduler,
                          train_dataloader=train_dataloader,
                          eval_dataloader=eval_dataloader,
                          max_epochs=args.max_epochs,

@@ -60,10 +60,15 @@ def run_test_checkpoint(strategy):
     rank0_dirname = rank0_dirname[0]

     model_path = os.path.join(rank0_dirname, 'model.pt')
-    optim_path = os.path.join(rank0_dirname, f'optim-r{dist.get_rank()}.pt')
-
     strategy.save_model(actor, model_path, only_rank0=True)
-    strategy.save_optimizer(actor_optim, optim_path, only_rank0=False)
+
+    optim_path = os.path.join(rank0_dirname, f'optim.pt')
+    strategy.save_optimizer(actor_optim, optim_path, only_rank0=True)
+
+    # FIXME(cwher): Sharded optimizer checkpoint is not supported yet.
+    #  at "ColossalAI/colossalai/checkpoint_io/general_checkpoint_io.py", line 62
+    # optim_path = os.path.join(rank0_dirname, f'optim-r{dist.get_rank()}.pt')
+    # strategy.save_optimizer(actor_optim, optim_path, only_rank0=False)

     dist.barrier()