mirror of https://github.com/hpcaitech/ColossalAI
[chat] remove naive strategy and split colossalai strategy (#4094)

* feat: remove on_learn_epoch fn as not used
* revert: add _on_learn_epoch fn
* to: remove the use of NaiveStrategy
* test: remove NaiveStrategy tests
* feat: remove NaiveStrategy
* style: modify comments and params
* feat: split ColossalAIStrategy into LowLevelZeroStrategy and GeminiStrategy
* fix: remove naive
* fix: align with modified colossal strategy
* fix: fix ddp _try_init_dist arg

pull/4122/head
parent
b03d64d010
commit
edd75a59ea
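
Before the diff itself, a hypothetical migration sketch (not part of the commit; names are taken from the new coati strategy API introduced below) showing how call sites change:

```python
# Hypothetical migration sketch: the old ColossalAIStrategy is split by ZeRO stage,
# and NaiveStrategy is dropped in favour of DDPStrategy.
from coati.trainer.strategies import DDPStrategy, GeminiStrategy, LowLevelZeroStrategy

# before: ColossalAIStrategy(stage=3, placement_policy='cuda', initial_scale=2**5)
gemini = GeminiStrategy(placement_policy='cuda', initial_scale=2**5)

# before: ColossalAIStrategy(stage=2, placement_policy='cuda')
zero2 = LowLevelZeroStrategy(stage=2, placement_policy='cuda')

# before: NaiveStrategy()
strategy = DDPStrategy()
```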
@@ -287,7 +287,7 @@ If you only have a single 24G GPU, you can use the following script. `batch_size
torchrun --standalone --nproc_per_node=1 train_sft.py \
--pretrain "/path/to/LLaMa-7B/" \
--model 'llama' \
--strategy naive \
--strategy ddp \
--log_interval 10 \
--save_path /path/to/Coati-7B \
--dataset /path/to/data.json \

@@ -8,7 +8,7 @@ from coati.models.base import RewardModel
from coati.models.opt import OPTActor, OPTCritic
from coati.trainer import PPOTrainer
from coati.trainer.callbacks import PerformanceEvaluator
from coati.trainer.strategies import ColossalAIStrategy, DDPStrategy, Strategy
from coati.trainer.strategies import DDPStrategy, GeminiStrategy, LowLevelZeroStrategy, Strategy
from torch.optim import Adam
from torch.utils.data import DataLoader
from transformers import AutoTokenizer

@@ -19,10 +19,8 @@ from colossalai.nn.optimizer import HybridAdam
def get_model_numel(model: nn.Module, strategy: Strategy) -> int:
numel = sum(p.numel() for p in model.parameters())
if isinstance(strategy, ColossalAIStrategy):
from colossalai.booster.plugin import GeminiPlugin
if isinstance(strategy.plugin, GeminiPlugin) and strategy.shard_init:
numel *= dist.get_world_size()
if isinstance(strategy, GeminiStrategy) and strategy.shard_init:
numel *= dist.get_world_size()
return numel
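
The world-size scaling above exists because, with `shard_init`, each rank only holds a slice of every parameter; a small hypothetical example (numbers invented for illustration):

```python
# Hypothetical illustration, assuming 4 ranks and GeminiStrategy(shard_init=True):
# the locally visible parameter count must be scaled back up to the full model size.
local_numel = 1_750_000_000          # parameters held on this rank
world_size = 4
total_numel = local_numel * world_size   # what get_model_numel() reports: 7B
```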

@@ -78,17 +76,17 @@ def main(args):
if args.strategy == 'ddp':
strategy = DDPStrategy()
elif args.strategy == 'colossalai_gemini':
strategy = ColossalAIStrategy(stage=3, placement_policy='cuda', initial_scale=2**5)
strategy = GeminiStrategy(placement_policy='cuda', initial_scale=2**5)
elif args.strategy == 'colossalai_gemini_cpu':
strategy = ColossalAIStrategy(stage=3, placement_policy='cpu', initial_scale=2**5)
strategy = GeminiStrategy(placement_policy='cpu', initial_scale=2**5)
elif args.strategy == 'colossalai_zero2':
strategy = ColossalAIStrategy(stage=2, placement_policy='cuda')
strategy = LowLevelZeroStrategy(stage=2, placement_policy='cuda')
elif args.strategy == 'colossalai_zero2_cpu':
strategy = ColossalAIStrategy(stage=2, placement_policy='cpu')
strategy = LowLevelZeroStrategy(stage=2, placement_policy='cpu')
elif args.strategy == 'colossalai_zero1':
strategy = ColossalAIStrategy(stage=1, placement_policy='cuda')
strategy = LowLevelZeroStrategy(stage=1, placement_policy='cuda')
elif args.strategy == 'colossalai_zero1_cpu':
strategy = ColossalAIStrategy(stage=1, placement_policy='cpu')
strategy = LowLevelZeroStrategy(stage=1, placement_policy='cpu')
else:
raise ValueError(f'Unsupported strategy "{args.strategy}"')

@@ -83,8 +83,8 @@ def main(args):
env_info=env_info_maker,
kl_coef=0.1,
debug=args.debug,
# sync_models_from_trainers=True,
# generation kwargs:
# sync_models_from_trainers=True,
# generation kwargs:
max_length=512,
do_sample=True,
temperature=1.0,

@@ -153,10 +153,10 @@ if __name__ == '__main__':
parser.add_argument('--num_trainers', type=int, default=1)
parser.add_argument('--trainer_strategy',
choices=[
'naive', 'ddp', 'colossalai_gemini', 'colossalai_zero2', 'colossalai_gemini_cpu',
'ddp', 'colossalai_gemini', 'colossalai_zero2', 'colossalai_gemini_cpu',
'colossalai_zero2_cpu'
],
default='naive')
default='ddp')
parser.add_argument('--maker_strategy', choices=['naive'], default='naive')
parser.add_argument('--model', default='gpt2', choices=['gpt2', 'bloom', 'opt', 'llama'])
parser.add_argument('--critic_model', default='gpt2', choices=['gpt2', 'bloom', 'opt', 'llama'])

@@ -87,8 +87,8 @@ def main(args):
env_info=env_info_maker,
kl_coef=0.1,
debug=args.debug,
# sync_models_from_trainers=True,
# generation kwargs:
# sync_models_from_trainers=True,
# generation kwargs:
max_length=512,
do_sample=True,
temperature=1.0,

@@ -164,10 +164,10 @@ if __name__ == '__main__':
parser.add_argument('--num_trainers', type=int, default=1)
parser.add_argument('--trainer_strategy',
choices=[
'naive', 'ddp', 'colossalai_gemini', 'colossalai_zero2', 'colossalai_gemini_cpu',
'ddp', 'colossalai_gemini', 'colossalai_zero2', 'colossalai_gemini_cpu',
'colossalai_zero2_cpu'
],
default='naive')
default='ddp')
parser.add_argument('--maker_strategy', choices=['naive'], default='naive')
parser.add_argument('--model', default='gpt2', choices=['gpt2', 'bloom', 'opt', 'llama'])
parser.add_argument('--critic_model', default='gpt2', choices=['gpt2', 'bloom', 'opt', 'llama'])

@@ -6,7 +6,7 @@ from coati.experience_maker import Experience, NaiveExperienceMaker
from coati.models.base import Actor, Critic
from coati.models.loss import PolicyLoss, ValueLoss
from coati.trainer.callbacks import Callback
from coati.trainer.strategies import ColossalAIStrategy, DDPStrategy, NaiveStrategy, Strategy
from coati.trainer.strategies import DDPStrategy, GeminiStrategy, LowLevelZeroStrategy, Strategy
from torch.optim import Adam

from colossalai.nn.optimizer import HybridAdam

@@ -85,7 +85,7 @@ class DetachedPPOTrainer(DetachedTrainer):
evaluator = TrainerPerformanceEvaluator(actor_numel, critic_numel)
callbacks = callbacks + [evaluator]

if isinstance(self.strategy, ColossalAIStrategy):
if isinstance(self.strategy, (LowLevelZeroStrategy, GeminiStrategy)):
self.actor_optim = HybridAdam(self.actor.parameters(), lr=1e-7)
self.critic_optim = HybridAdam(self.critic.parameters(), lr=1e-7)
else:

@@ -1,6 +1,6 @@
import os
from typing import Any, Callable, Dict, List, Optional
from collections import OrderedDict
from typing import Any, Callable, Dict, List, Optional

import torch
import torch.distributed as dist

@@ -10,7 +10,7 @@ from coati.models.gpt import GPTRM, GPTActor, GPTCritic
from coati.models.llama import LlamaActor, LlamaCritic, LlamaRM
from coati.models.opt import OPTRM, OPTActor, OPTCritic
from coati.models.roberta import RoBERTaActor, RoBERTaCritic, RoBERTaRM
from coati.trainer.strategies import ColossalAIStrategy, DDPStrategy, NaiveStrategy
from coati.trainer.strategies import DDPStrategy, GeminiStrategy, LowLevelZeroStrategy
from coati.utils import prepare_llama_tokenizer_and_embedding
from transformers import AutoTokenizer, BloomTokenizerFast, GPT2Tokenizer, LlamaTokenizer, RobertaTokenizer

@@ -76,18 +76,16 @@ def get_reward_model_from_args(model: str, pretrained: str = None, config=None):

def get_strategy_from_args(strategy: str):
if strategy == 'naive':
strategy_ = NaiveStrategy()
elif strategy == 'ddp':
if strategy == 'ddp':
strategy_ = DDPStrategy()
elif strategy == 'colossalai_gemini':
strategy_ = ColossalAIStrategy(stage=3, placement_policy='cuda', initial_scale=2**5)
strategy_ = GeminiStrategy(placement_policy='cuda', initial_scale=2**5)
elif strategy == 'colossalai_zero2':
strategy_ = ColossalAIStrategy(stage=2, placement_policy='cuda')
strategy_ = LowLevelZeroStrategy(stage=2, placement_policy='cuda')
elif strategy == 'colossalai_gemini_cpu':
strategy_ = ColossalAIStrategy(stage=3, placement_policy='cpu', initial_scale=2**5)
strategy_ = GeminiStrategy(placement_policy='cpu', initial_scale=2**5)
elif strategy == 'colossalai_zero2_cpu':
strategy_ = ColossalAIStrategy(stage=2, placement_policy='cpu')
strategy_ = LowLevelZeroStrategy(stage=2, placement_policy='cpu')
else:
raise ValueError(f'Unsupported strategy "{strategy}"')
return strategy_
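
A hedged usage sketch of the updated helper (hypothetical call, assuming `get_strategy_from_args` above is imported; not part of the diff):

```python
# The 'naive' branch is gone, so an unknown name now falls through to the
# ValueError instead of silently constructing NaiveStrategy.
from coati.trainer.strategies import DDPStrategy

strategy = get_strategy_from_args('ddp')
assert isinstance(strategy, DDPStrategy)

try:
    get_strategy_from_args('naive')   # removed by this commit
except ValueError as e:
    print(e)                          # Unsupported strategy "naive"
```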

@@ -1,7 +1,7 @@
import os

import torch.distributed as dist
from coati.trainer.strategies import ColossalAIStrategy, Strategy
from coati.trainer.strategies import GeminiStrategy, LowLevelZeroStrategy, Strategy
from coati.trainer.utils import is_rank_0
from torch import nn
from torch.optim import Optimizer

@@ -69,7 +69,7 @@ class SaveCheckpoint(Callback):
# save optimizer
if self.model_dict[model][1] is None:
continue
only_rank0 = not isinstance(self.strategy, ColossalAIStrategy)
only_rank0 = not isinstance(self.strategy, (LowLevelZeroStrategy, GeminiStrategy))
rank = 0 if is_rank_0() else dist.get_rank()
optim_path = os.path.join(base_path, f'{model}-optim-rank-{rank}.pt')
self.strategy.save_optimizer(optimizer=self.model_dict[model][1], path=optim_path, only_rank0=only_rank0)

@@ -15,7 +15,7 @@ from colossalai.utils import get_current_device
from .base import OnPolicyTrainer
from .callbacks import Callback
from .strategies import ColossalAIStrategy, Strategy
from .strategies import GeminiStrategy, Strategy
from .utils import is_rank_0, to_device

@@ -82,9 +82,8 @@ class PPOTrainer(OnPolicyTrainer):
callbacks: List[Callback] = [],
**generate_kwargs
) -> None:
if isinstance(strategy, ColossalAIStrategy):
from colossalai.booster.plugin import GeminiPlugin
assert not (isinstance(strategy.plugin, GeminiPlugin) and offload_inference_models), \
if isinstance(strategy, GeminiStrategy):
assert not offload_inference_models, \
"GeminiPlugin is not compatible with manual model.to('cpu')"

buffer = NaiveReplayBuffer(train_batch_size, buffer_limit, buffer_cpu_offload)

@@ -12,7 +12,7 @@ from torch.utils.data import DataLoader
from colossalai.logging import DistributedLogger

from .base import SLTrainer
from .strategies import ColossalAIStrategy, Strategy
from .strategies import GeminiStrategy, Strategy
from .utils import is_rank_0, to_device

@@ -38,9 +38,8 @@ class SFTTrainer(SLTrainer):
max_epochs: int = 2,
accumulation_steps: int = 8,
) -> None:
if accumulation_steps > 1 and isinstance(strategy, ColossalAIStrategy):
from colossalai.booster.plugin import GeminiPlugin
assert not isinstance(strategy.plugin, GeminiPlugin), \
if accumulation_steps > 1:
assert not isinstance(strategy, GeminiStrategy), \
"Accumulation steps are not supported in stage 3 of ColossalAI"

super().__init__(strategy, max_epochs, model, optim)
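
The practical implication of the new guard, as a hypothetical caller-side sketch (strategy names from this commit; the choice of stage 2 is an assumption):

```python
# Gradient accumulation is rejected under GeminiStrategy (the old stage-3 path),
# so pick LowLevelZeroStrategy when accumulation_steps > 1.
from coati.trainer.strategies import GeminiStrategy, LowLevelZeroStrategy

accumulation_steps = 8
if accumulation_steps > 1:
    strategy = LowLevelZeroStrategy(stage=2, placement_policy='cuda')
else:
    strategy = GeminiStrategy(placement_policy='cuda')
```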

@@ -1,6 +1,8 @@
from .base import Strategy
from .colossalai import ColossalAIStrategy
from .colossalai import GeminiStrategy, LowLevelZeroStrategy
from .ddp import DDPStrategy
from .naive import NaiveStrategy

__all__ = ['Strategy', 'NaiveStrategy', 'DDPStrategy', 'ColossalAIStrategy']
__all__ = [
'Strategy', 'DDPStrategy',
'LowLevelZeroStrategy', 'GeminiStrategy'
]

@@ -18,25 +18,17 @@ from colossalai.zero.gemini.gemini_ddp import GeminiDDP
from .ddp import DDPStrategy

class ColossalAIStrategy(DDPStrategy):
class LowLevelZeroStrategy(DDPStrategy):
"""
The strategy for training with ColossalAI.

Args:
stage(int): The stage to use in ZeRO. Choose in (1, 2, 3)
precision(str): The precision to use. Choose in ('fp32', 'fp16'). Stage 3 only supports fp16.
stage(int): The stage to use in ZeRO. Choose in (1, 2)
precision(str): The precision to use. Choose in ('fp32', 'fp16').
seed(int): The seed for the random number generator.
shard_init(bool): Whether to shard the model parameters during initialization. Only for ZeRO-3.
This is not compatible with `from_pretrained()`. We temporarily disable this and will support it in the future.
placement_policy(str): The placement policy for gemini. Choose in ('cpu', 'cuda')
If it is “cpu”, parameters, gradients and optimizer states will be offloaded to CPU,
If it is “cuda”, they will not be offloaded, which means max CUDA memory will be used. It is the fastest.
pin_memory(bool): Whether to pin the memory for the data loader. Only for ZeRO-3.
force_outputs_fp32(bool): Whether to force the outputs to be fp32. Only for ZeRO-3.
search_range_m(int): The number of search range for the chunk size, divided by 2^20. Only for ZeRO-3.
hidden_dim(optional, int): The hidden dimension for the gemini. Only for ZeRO-3.
min_chunk_size_m(float): The minimum chunk size divided by 2^20. Only for ZeRO-3.
gpu_margin_mem_ratio(float): The margin memory ratio for the GPU. Only for ZeRO-3.
reduce_bucket_size(int): The reduce bucket size in bytes. Only for ZeRO-1 and ZeRO-2.
overlap_communication(bool): Whether to overlap communication and computation. Only for ZeRO-1 and ZeRO-2.
initial_scale(float): The initial scale for the optimizer.

@@ -51,132 +43,185 @@ class ColossalAIStrategy(DDPStrategy):

"""

def __init__(
self,
stage: int = 3,
precision: str = 'fp16',
seed: int = 42,
shard_init: bool = False, # only for stage 3
placement_policy: str = 'cuda',
pin_memory: bool = True, # only for stage 3
force_outputs_fp32: bool = False, # only for stage 3
search_range_m: int = 32, # only for stage 3
hidden_dim: Optional[int] = None, # only for stage 3
min_chunk_size_m: float = 32, # only for stage 3
gpu_margin_mem_ratio: float = 0.0, # only for stage 3
reduce_bucket_size: int = 12 * 1024**2, # only for stage 1&2
overlap_communication: bool = True, # only for stage 1&2
initial_scale: float = 2**16,
growth_factor: float = 2,
backoff_factor: float = 0.5,
growth_interval: int = 1000,
hysteresis: int = 2,
min_scale: float = 1,
max_scale: float = 2**32,
max_norm: float = 0.0,
norm_type: float = 2.0) -> None:
def __init__(self,
stage: int = 3,
precision: str = 'fp16',
seed: int = 42,
placement_policy: str = 'cuda',
reduce_bucket_size: int = 12 * 1024**2, # only for stage 1&2
overlap_communication: bool = True, # only for stage 1&2
initial_scale: float = 2**16,
growth_factor: float = 2,
backoff_factor: float = 0.5,
growth_interval: int = 1000,
hysteresis: int = 2,
min_scale: float = 1,
max_scale: float = 2**32,
max_norm: float = 0.0,
norm_type: float = 2.0
) -> None:

assert stage in (1, 2, 3), f'Unsupported stage "{stage}"'
assert stage in (1, 2), f'Unsupported stage "{stage}"'
assert placement_policy in ('cpu', 'cuda'), f'Unsupported placement policy "{placement_policy}"'
assert precision in ('fp32', 'fp16'), f'Unsupported precision "{precision}"'

# TODO(ver217): support shard_init when using from_pretrained()
if shard_init:
warnings.warn(f'Shard init is not supported model.from_pretrained() yet. '
'Please load weights after strategy.prepare()')
if stage == 3 and precision == 'fp32':
warnings.warn(f'Stage 3 only supports fp16. Precision is set to fp16.')
precision = 'fp16'
self.precision = precision
self.shard_init = shard_init

optim_kwargs = dict(initial_scale=initial_scale,
growth_factor=growth_factor,
backoff_factor=backoff_factor,
growth_interval=growth_interval,
hysteresis=hysteresis,
min_scale=min_scale,
max_scale=max_scale,
max_norm=max_norm,
norm_type=norm_type)
# NOTE: dist should be initialized before calling get_current_device()
if stage == 3:
plugin_initializer = lambda: GeminiPlugin(
# gemini_config
device=get_current_device(),
placement_policy=placement_policy,
precision=precision,
pin_memory=pin_memory,
force_outputs_fp32=force_outputs_fp32,
strict_ddp_mode=shard_init,
search_range_m=search_range_m,
hidden_dim=hidden_dim,
min_chunk_size_m=min_chunk_size_m,
# zero_optim_config
gpu_margin_mem_ratio=gpu_margin_mem_ratio,
# optim_config
**optim_kwargs)
else:
plugin_initializer = lambda: LowLevelZeroPlugin(
# zero_config
stage=stage,
precision=precision,
# zero_optim_config
reduce_bucket_size_in_m=reduce_bucket_size,
overlap_communication=overlap_communication,
cpu_offload=(placement_policy == 'cpu'),
# optim_config
**optim_kwargs)
plugin_initializer = lambda: LowLevelZeroPlugin(
# zero_config
stage=stage,
precision=precision,
# zero_optim_config
reduce_bucket_size_in_m=reduce_bucket_size,
overlap_communication=overlap_communication,
cpu_offload=(placement_policy == 'cpu'),
# optim_config
initial_scale=initial_scale,
growth_factor=growth_factor,
backoff_factor=backoff_factor,
growth_interval=growth_interval,
hysteresis=hysteresis,
min_scale=min_scale,
max_scale=max_scale,
max_norm=max_norm,
norm_type=norm_type
)

super().__init__(seed, plugin_initializer)

def _post_init(self) -> None:
assert isinstance(self.plugin, (LowLevelZeroPlugin, GeminiPlugin)), \
assert isinstance(self.plugin, LowLevelZeroPlugin), \
f'{type(self).__name__}\'s plugin is not initialized properly.'

def setup_distributed(self) -> None:
colossalai.launch_from_torch({}, seed=self.seed)

def unwrap_model(self, model: nn.Module) -> nn.Module:
assert isinstance(model, LowLevelZeroModel)
return model.module

def get_model_state_dict_shard(self, model: nn.Module, **config):
assert isinstance(model, LowLevelZeroModel)
yield from model.state_dict_shard(max_shard_size=1024, only_rank_0=False)

class GeminiStrategy(DDPStrategy):
"""
The strategy for training with ColossalAI.

Args:
seed(int): The seed for the random number generator.
shard_init(bool): Whether to shard the model parameters during initialization. Only for ZeRO-3.
This is not compatible with `from_pretrained()`. We temporarily disable this and will support it in the future.
placement_policy(str): The placement policy for gemini. Choose in ('cpu', 'cuda')
If it is “cpu”, parameters, gradients and optimizer states will be offloaded to CPU,
If it is “cuda”, they will not be offloaded, which means max CUDA memory will be used. It is the fastest.
pin_memory(bool): Whether to pin the memory for the data loader. Only for ZeRO-3.
force_outputs_fp32(bool): Whether to force the outputs to be fp32. Only for ZeRO-3.
search_range_m(int): The number of search range for the chunk size, divided by 2^20. Only for ZeRO-3.
hidden_dim(optional, int): The hidden dimension for the gemini. Only for ZeRO-3.
min_chunk_size_m(float): The minimum chunk size divided by 2^20. Only for ZeRO-3.
gpu_margin_mem_ratio(float): The margin memory ratio for the GPU. Only for ZeRO-3.
initial_scale(float): The initial scale for the optimizer.
growth_factor(float): The growth factor for the optimizer.
backoff_factor(float): The backoff factor for the optimizer.
growth_interval(int): The growth interval for the optimizer.
hysteresis(int): The hysteresis for the optimizer.
min_scale(float): The minimum scale for the optimizer.
max_scale(float): The maximum scale for the optimizer.
max_norm(float): The maximum norm for the optimizer.
norm_type(float): The norm type for the optimizer.

"""

def __init__(self,
seed: int = 42,
shard_init: bool = False, # only for stage 3
placement_policy: str = 'cuda',
pin_memory: bool = True, # only for stage 3
force_outputs_fp32: bool = False, # only for stage 3
search_range_m: int = 32, # only for stage 3
hidden_dim: Optional[int] = None, # only for stage 3
min_chunk_size_m: float = 32, # only for stage 3
gpu_margin_mem_ratio: float = 0.0, # only for stage 3
initial_scale: float = 2**16,
growth_factor: float = 2,
backoff_factor: float = 0.5,
growth_interval: int = 1000,
hysteresis: int = 2,
min_scale: float = 1,
max_scale: float = 2**32,
max_norm: float = 0.0,
norm_type: float = 2.0
) -> None:

assert placement_policy in ('cpu', 'cuda'), f'Unsupported placement policy "{placement_policy}"'

# TODO(ver217): support shard_init when using from_pretrained()
if shard_init:
warnings.warn(
f'Shard init is not supported model.from_pretrained() yet. '
'Please load weights after strategy.prepare()'
)
self.shard_init = shard_init

warnings.warn(f'Stage 3 only supports fp16. Precision is set to fp16.')

# NOTE: dist should be initialized before calling get_current_device()
plugin_initializer = lambda: GeminiPlugin(
# gemini_config
device=get_current_device(),
placement_policy=placement_policy,
precision='fp16',
pin_memory=pin_memory,
force_outputs_fp32=force_outputs_fp32,
strict_ddp_mode=shard_init,
search_range_m=search_range_m,
hidden_dim=hidden_dim,
min_chunk_size_m=min_chunk_size_m,
# zero_optim_config
gpu_margin_mem_ratio=gpu_margin_mem_ratio,
# optim_config
initial_scale=initial_scale,
growth_factor=growth_factor,
backoff_factor=backoff_factor,
growth_interval=growth_interval,
hysteresis=hysteresis,
min_scale=min_scale,
max_scale=max_scale,
max_norm=max_norm,
norm_type=norm_type
)

super().__init__(seed, plugin_initializer)

def _post_init(self) -> None:
assert isinstance(self.plugin, GeminiPlugin), \
f'{type(self).__name__}\'s plugin is not initialized properly.'

def setup_distributed(self) -> None:
colossalai.launch_from_torch({}, seed=self.seed)

def model_init_context(self):
if isinstance(self.plugin, GeminiPlugin):
world_size = dist.get_world_size()
shard_pg = ProcessGroup(tp_degree=world_size) if self.shard_init else None
default_dist_spec = ShardSpec([-1], [world_size]) if self.shard_init else None
return ColoInitContext(device=get_current_device(),
dtype=torch.half,
default_pg=shard_pg,
default_dist_spec=default_dist_spec)
return super().model_init_context()
world_size = dist.get_world_size()
shard_pg = ProcessGroup(tp_degree=world_size) if self.shard_init else None
default_dist_spec = ShardSpec([-1], [world_size]) if self.shard_init else None
return ColoInitContext(device=get_current_device(),
dtype=torch.half,
default_pg=shard_pg,
default_dist_spec=default_dist_spec)

def unwrap_model(self, model: nn.Module) -> nn.Module:
if isinstance(self.plugin, GeminiPlugin):
assert isinstance(model, GeminiModel)
ddp_model = model.unwrap()
assert isinstance(ddp_model, GeminiDDP)
return ddp_model.module
elif isinstance(self.plugin, LowLevelZeroPlugin):
assert isinstance(model, LowLevelZeroModel)
return model.module
else:
raise RuntimeError(f'Unsupported plugin {type(self.plugin)}')
assert isinstance(model, GeminiModel)
ddp_model = model.unwrap()
assert isinstance(ddp_model, GeminiDDP)
return ddp_model.module

def save_pretrained(self,
model: nn.Module,
path: str,
only_rank0: bool = True,
tokenizer: Optional[PreTrainedTokenizerBase] = None) -> None:
if isinstance(self.plugin, GeminiPlugin):
raise RuntimeError('ColossalAI strategy with stage-3 does not support save_pretrained() now')
super().save_pretrained(model, path, only_rank0, tokenizer)
raise RuntimeError('ColossalAI strategy with stage-3 does not support save_pretrained() now')

def get_model_state_dict_shard(self, model: nn.Module, **config):
if not isinstance(self.plugin, GeminiPlugin):
yield from super().get_model_state_dict_shard(model, **config)
else:
# unwrapped_model = self._unwrap_model(model)
# for module in unwrapped_model.modules():
# if isinstance(module, LoraLinear):
# module.merge_weights = True
# module.eval()
assert isinstance(model, LowLevelZeroModel)
yield from model.state_dict_shard(max_shard_size=1024, only_rank_0=False)
assert isinstance(self.plugin, GeminiPlugin)
yield from super().get_model_state_dict_shard(model, **config)
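
A hypothetical construction sketch for the two new classes (arguments taken from the signatures above; not part of the diff):

```python
# LowLevelZeroStrategy now only covers ZeRO stages 1 and 2; the Gemini path is its own class.
zero1 = LowLevelZeroStrategy(stage=1, placement_policy='cuda')
zero2_cpu = LowLevelZeroStrategy(stage=2, placement_policy='cpu')   # ZeRO-2 with CPU offload
gemini = GeminiStrategy(placement_policy='cuda', initial_scale=2**5)

# LowLevelZeroStrategy(stage=3) would now fail the `stage in (1, 2)` assertion.
```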

@@ -1,4 +1,6 @@
import os
import random
from collections import OrderedDict
from typing import Callable, Optional

import numpy as np

@@ -6,18 +8,27 @@ import torch
import torch.distributed as dist
import torch.nn as nn
from coati.replay_buffer import ReplayBuffer
from torch.optim import Optimizer
from torch.utils.data import DataLoader
from transformers.modeling_utils import PreTrainedModel
from transformers.tokenization_utils_base import PreTrainedTokenizerBase

from colossalai.booster.plugin import TorchDDPPlugin
from colossalai.booster.plugin.torch_ddp_plugin import TorchDDPModel

from .naive import NaiveStrategy
from .base import Strategy
from .sampler import DistributedSampler

class DDPStrategy(NaiveStrategy):
# TODO Move this to a util.py (Moving to ray.util introduces ringed import)
def get_grad_required_state_dict(model: nn.Module):
state_dict = OrderedDict()
for name, parameter in model.named_parameters():
if parameter.requires_grad:
state_dict[name] = parameter.detach()
return state_dict

class DDPStrategy(Strategy):
"""
Strategy for distributed training using torch.distributed.
"""

@@ -29,6 +40,24 @@ class DDPStrategy(NaiveStrategy):
self.seed = seed
super().__init__(plugin_initializer)

def _try_init_dist(self, force: bool = False) -> None:
try:
rank = int(os.environ['RANK'])
local_rank = int(os.environ['LOCAL_RANK'])
world_size = int(os.environ['WORLD_SIZE'])
host = os.environ['MASTER_ADDR']
port = int(os.environ['MASTER_PORT'])
dist.init_process_group('nccl', init_method=f'tcp://[{host}]:{port}', world_size=world_size, rank=rank)
torch.cuda.set_device(local_rank)
except KeyError as e:
if force:
raise RuntimeError(
f"Could not find {e} in the torch environment, visit https://www.colossalai.org/ for more information on launching with torch"
)
except Exception as e:
if force:
raise e

def _post_init(self) -> None:
assert isinstance(self.plugin, TorchDDPPlugin), \
|

@@ -42,9 +71,6 @@ class DDPStrategy(NaiveStrategy):
np.random.seed(seed)
torch.manual_seed(seed)

def backward(self, loss: torch.Tensor, model: nn.Module, optimizer: Optimizer, **kwargs) -> None:
self.booster.backward(loss, optimizer)

def setup_dataloader(self, replay_buffer: ReplayBuffer, pin_memory: bool = False) -> DataLoader:
return self.plugin.prepare_dataloader(replay_buffer,
batch_size=replay_buffer.sample_batch_size,

@@ -68,4 +94,32 @@ class DDPStrategy(NaiveStrategy):
tokenizer: Optional[PreTrainedTokenizerBase] = None) -> None:
if only_rank0 and dist.get_rank() != 0:
return
super().save_pretrained(model, path, only_rank0, tokenizer)
unwrapped_model = self.unwrap_model(model)
assert isinstance(unwrapped_model, PreTrainedModel)
unwrapped_model.save_pretrained(path)
if tokenizer is not None:
tokenizer.save_pretrained(path)

def get_model_state_dict_shard(self, model: nn.Module, **config):
# TODO: implement sharding on naive strategy
model = self.unwrap_model(model)
if 'requires_grad_only' in config and config['requires_grad_only'] == True:
state_dict = get_grad_required_state_dict(model)
else:
state_dict = model.state_dict()

if 'shard_size' in config:
shard_size = config['shard_size']
accumulate_size = 0
state_dict_shard = OrderedDict()
for name, param in state_dict.items():
state_dict_shard[name] = param
accumulate_size += param.numel() * param.element_size()
if accumulate_size >= shard_size:
accumulate_size = 0
yield state_dict_shard
state_dict_shard = OrderedDict()
if accumulate_size > 0:
yield state_dict_shard
else:
|

@@ -1,103 +0,0 @@
import os
from collections import OrderedDict
from typing import Optional

import torch
import torch.distributed as dist
import torch.nn as nn
from coati.replay_buffer import ReplayBuffer
from torch.optim import Optimizer
from torch.utils.data import DataLoader
from transformers.modeling_utils import PreTrainedModel
from transformers.tokenization_utils_base import PreTrainedTokenizerBase

from .base import Strategy

# TODO Move this to a util.py (Moving to ray.util introduces ringed import)
def get_grad_required_state_dict(model: nn.Module):
state_dict = OrderedDict()
for name, parameter in model.named_parameters():
if parameter.requires_grad:
state_dict[name] = parameter.detach()
return state_dict

class NaiveStrategy(Strategy):
"""
Strategy for single GPU. No parallelism is used.
"""

def _post_init(self) -> None:
assert self.plugin is None, \
f'{type(self).__name__}\'s plugin is not initialized properly.'

def setup_distributed(self) -> None:
self._try_init_dist(force=False)

def backward(self, loss: torch.Tensor, model: nn.Module, optimizer: Optimizer, **kwargs) -> None:
# HACK: self.booster.backward(loss, optimizer) can't work if plugin is None,
# it would run `optimizer.backward(loss)`, which is not compatible with torch.optim.Optimizer
assert self.plugin is None, "DO NOT call this method if plugin is not None"
loss.backward()

def setup_dataloader(self, replay_buffer: ReplayBuffer, pin_memory: bool = False) -> DataLoader:
return DataLoader(replay_buffer,
batch_size=replay_buffer.sample_batch_size,
shuffle=True,
drop_last=True,
pin_memory=pin_memory,
collate_fn=replay_buffer.collate_fn)

def save_pretrained(self,
model: nn.Module,
path: str,
only_rank0: bool = True,
tokenizer: Optional[PreTrainedTokenizerBase] = None) -> None:
unwrapped_model = self.unwrap_model(model)
assert isinstance(unwrapped_model, PreTrainedModel)
unwrapped_model.save_pretrained(path)
if tokenizer is not None:
tokenizer.save_pretrained(path)

def get_model_state_dict_shard(self, model: nn.Module, **config):
# TODO: implement sharding on naive strategy
model = self.unwrap_model(model)
if 'requires_grad_only' in config and config['requires_grad_only'] == True:
state_dict = get_grad_required_state_dict(model)
else:
state_dict = model.state_dict()

if 'shard_size' in config:
shard_size = config['shard_size']
accumulate_size = 0
state_dict_shard = OrderedDict()
for name, param in state_dict.items():
state_dict_shard[name] = param
accumulate_size += param.numel() * param.element_size()
if accumulate_size >= shard_size:
accumulate_size = 0
yield state_dict_shard
state_dict_shard = OrderedDict()
if accumulate_size > 0:
yield state_dict_shard
else:
yield state_dict

def _try_init_dist(self, force: bool = False) -> None:
try:
rank = int(os.environ['RANK'])
local_rank = int(os.environ['LOCAL_RANK'])
world_size = int(os.environ['WORLD_SIZE'])
host = os.environ['MASTER_ADDR']
port = int(os.environ['MASTER_PORT'])
dist.init_process_group('nccl', init_method=f'tcp://[{host}]:{port}', world_size=world_size, rank=rank)
torch.cuda.set_device(local_rank)
except KeyError as e:
if force:
raise RuntimeError(
f"Could not find {e} in the torch environment, visit https://www.colossalai.org/ for more information on launching with torch"
)
except Exception as e:
if force:
raise e

@@ -69,7 +69,7 @@ torchrun --standalone --nproc_per_node=4 train_sft.py \
--grad_checkpoint
```
### Arg List
- --strategy: the strategy using for training, choices=['naive', 'ddp', 'colossalai_gemini', 'colossalai_zero2'], default='colossalai_zero2'
- --strategy: the strategy using for training, choices=['ddp', 'colossalai_gemini', 'colossalai_zero2'], default='colossalai_zero2'
- --model: model type, choices=['gpt2', 'bloom', 'opt', 'llama'], default='bloom'
- --pretrain: pretrain model, type=str, default=None
- --max_datasets_size: the max size of dataset, type=int, default=None

@@ -118,7 +118,7 @@ Model performance in [Anthropics paper](https://arxiv.org/abs/2204.05862):
<div align=left>We also train the reward model based on LLaMA-7B, which reaches the ACC of 72.06% after 1 epoch, performing almost the same as Anthropic's best RM.

### Arg List
- --strategy: the strategy using for training, choices=['naive', 'ddp', 'colossalai_gemini', 'colossalai_zero2'], default='colossalai_zero2'
- --strategy: the strategy using for training, choices=['ddp', 'colossalai_gemini', 'colossalai_zero2'], default='colossalai_zero2'
- --model: model type, choices=['gpt2', 'bloom', 'opt', 'llama'], default='bloom'
- --pretrain: pretrain model, type=str, default=None
- --model_path: the path of rm model(if continue to train), type=str, default=None

@@ -160,7 +160,7 @@ Prompt dataset: the instruction dataset mentioned in the above figure which includes the instructions
Pretrain dataset: the pretrain dataset including the instruction and corresponding response, e.g. you can use the [InstructWild Data](https://github.com/XueFuzhao/InstructionWild/tree/main/data) in stage 1 supervised instructs tuning.

### Arg List
- --strategy: the strategy using for training, choices=['naive', 'ddp', 'colossalai_gemini', 'colossalai_zero2'], default='colossalai_zero2'
- --strategy: the strategy using for training, choices=['ddp', 'colossalai_gemini', 'colossalai_zero2'], default='colossalai_zero2'
- --model: model type of actor, choices=['gpt2', 'bloom', 'opt', 'llama'], default='bloom'
- --pretrain: pretrain model, type=str, default=None
- --rm_model: reward model type, type=str, choices=['gpt2', 'bloom', 'opt', 'llama'], default=None

@@ -9,7 +9,7 @@ from coati.models.gpt import GPTRM, GPTActor, GPTCritic
from coati.models.llama import LlamaActor, LlamaCritic, LlamaRM
from coati.models.opt import OPTRM, OPTActor, OPTCritic
from coati.trainer import PPOTrainer
from coati.trainer.strategies import ColossalAIStrategy, DDPStrategy, NaiveStrategy
from coati.trainer.strategies import DDPStrategy, GeminiStrategy, LowLevelZeroStrategy
from coati.utils import prepare_llama_tokenizer_and_embedding
from easy_dataset import EasyPromptsDataset, EasySupervisedDataset
from easy_models import BLOOMActor

@@ -24,14 +24,12 @@ from colossalai.nn.optimizer import HybridAdam
def main(args):
# configure strategy
if args.strategy == 'naive':
strategy = NaiveStrategy()
elif args.strategy == 'ddp':
if args.strategy == 'ddp':
strategy = DDPStrategy()
elif args.strategy == 'colossalai_gemini':
strategy = ColossalAIStrategy(stage=3, placement_policy='cpu', initial_scale=2**5)
strategy = GeminiStrategy(placement_policy='cpu', initial_scale=2**5)
elif args.strategy == 'colossalai_zero2':
strategy = ColossalAIStrategy(stage=2, placement_policy='cpu')
strategy = LowLevelZeroStrategy(stage=2, placement_policy='cpu')
else:
raise ValueError(f'Unsupported strategy "{args.strategy}"')

@@ -202,8 +200,8 @@ if __name__ == '__main__':
parser.add_argument('--prompt_path', type=str, default=None, help='path to the prompt dataset')
parser.add_argument('--pretrain_dataset', type=str, default=None, help='path to the pretrained dataset')
parser.add_argument('--strategy',
choices=['naive', 'ddp', 'colossalai_gemini', 'colossalai_zero2'],
default='naive',
choices=['ddp', 'colossalai_gemini', 'colossalai_zero2'],
default='ddp',
help='strategy to use')
parser.add_argument('--model', default='gpt2', choices=['gpt2', 'bloom', 'opt', 'llama'])
parser.add_argument('--pretrain', type=str, default=None)

@@ -11,7 +11,7 @@ from coati.models.gpt import GPTLM
from coati.models.llama import LlamaLM
from coati.models.opt import OPTLM
from coati.trainer import SFTTrainer
from coati.trainer.strategies import ColossalAIStrategy, DDPStrategy, NaiveStrategy
from coati.trainer.strategies import DDPStrategy, GeminiStrategy, LowLevelZeroStrategy
from coati.utils import prepare_llama_tokenizer_and_embedding
from datasets import load_dataset
from easy_dataset import EasyDataset

@@ -30,14 +30,12 @@ from colossalai.tensor import ColoParameter
def train(args):
# configure strategy
if args.strategy == 'naive':
strategy = NaiveStrategy()
elif args.strategy == 'ddp':
if args.strategy == 'ddp':
strategy = DDPStrategy()
elif args.strategy == 'colossalai_gemini':
strategy = ColossalAIStrategy(stage=3, placement_policy='cuda')
strategy = GeminiStrategy(placement_policy='cuda')
elif args.strategy == 'colossalai_zero2':
strategy = ColossalAIStrategy(stage=2, placement_policy='cuda')
strategy = LowLevelZeroStrategy(stage=2, placement_policy='cuda')
else:
raise ValueError(f'Unsupported strategy "{args.strategy}"')

@@ -45,15 +43,15 @@ def train(args):
with strategy.model_init_context():
print('Warning: currently only bloom is tested, gpt2,llama and opt are not tested')
model = AutoModelForCausalLM.from_pretrained(args.pretrain).to(torch.cuda.current_device())
#if the args.save_path exists and args.save_path+'/adapter_config.json' exists, we'll load the adapter_config.json
if os.path.exists(args.save_path) and os.path.exists(args.save_path+'/adapter_config.json') \
and os.path.exists(args.save_path+'/adapter_model.bin'):
# if the args.save_path exists and args.save_path+'/adapter_config.json' exists, we'll load the adapter_config.json
if os.path.exists(args.save_path) and os.path.exists(args.save_path + '/adapter_config.json') \
and os.path.exists(args.save_path + '/adapter_model.bin'):
print("loading from saved peft model ", args.save_path)
model = PeftModel.from_pretrained(model, args.save_path)
else:
#we'll use peft lora library to do the lora
# we'll use peft lora library to do the lora
lora_rank = args.lora_rank if args.lora_rank > 0 else 32
#config lora with rank of lora_rank
# config lora with rank of lora_rank
lora_config = LoraConfig(task_type=TaskType.CAUSAL_LM,
inference_mode=False,
r=lora_rank,

@@ -170,8 +168,8 @@ def train(args):
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('--strategy',
choices=['naive', 'ddp', 'colossalai_gemini', 'colossalai_zero2'],
default='naive')
choices=['ddp', 'colossalai_gemini', 'colossalai_zero2'],
default='ddp')
parser.add_argument('--model', choices=['gpt2', 'bloom', 'opt', 'llama'], default='bloom')
parser.add_argument('--pretrain', type=str, default=None)
parser.add_argument('--dataset', type=str, default=None)

@@ -15,7 +15,7 @@ from coati.models.lora import LoRAModule
from coati.models.loss import PolicyLoss, ValueLoss
from coati.models.opt import OPTActor, OPTCritic
from coati.models.utils import compute_reward
from coati.trainer.strategies import ColossalAIStrategy, DDPStrategy, NaiveStrategy
from coati.trainer.strategies import DDPStrategy, GeminiStrategy, LowLevelZeroStrategy
from ray.util.placement_group import placement_group
from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy
from torch.optim import Adam

@@ -99,19 +99,17 @@ class BasePPORole(DistributedTorchRayActor):
def _init_strategy(self, strategy: str):
# configure strategy
if strategy == 'naive':
self._strategy = NaiveStrategy()
elif strategy == 'ddp':
if strategy == 'ddp':
self._strategy = DDPStrategy()
elif strategy == 'colossalai_gemini':
self._strategy = ColossalAIStrategy(stage=3, placement_policy='cuda', initial_scale=2**5)
self._strategy = GeminiStrategy(placement_policy='cuda', initial_scale=2**5)
elif strategy == 'colossalai_zero2':
self._strategy = ColossalAIStrategy(stage=2, placement_policy='cuda')
self._strategy = LowLevelZeroStrategy(stage=2, placement_policy='cuda')
else:
raise ValueError(f'Unsupported strategy "{strategy}"')

def _init_optimizer(self):
if isinstance(self._strategy, ColossalAIStrategy):
if isinstance(self._strategy, (GeminiStrategy, LowLevelZeroStrategy)):
self._optimizer = HybridAdam(self._model.parameters(), lr=5e-6)
else:
self._optimizer = Adam(self._model.parameters(), lr=5e-6)

@@ -534,8 +532,8 @@ if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('--prompt_csv_url', type=str)
parser.add_argument('--strategy',
choices=['naive', 'ddp', 'colossalai_gemini', 'colossalai_zero2'],
default='naive')
choices=['ddp', 'colossalai_gemini', 'colossalai_zero2'],
default='ddp')
parser.add_argument('--model', default='gpt2', choices=['gpt2', 'bloom', 'opt'])
parser.add_argument('--pretrain', type=str, default='gpt2')
parser.add_argument('--save_path', type=str, default='actor_checkpoint_prompts.pt')

@@ -103,8 +103,8 @@ def main(args):
kl_coef=0.1,
debug=args.debug,
update_lora_weights=not (args.lora_rank == 0),
# sync_models_from_trainers=True,
# generation kwargs:
# sync_models_from_trainers=True,
# generation kwargs:
max_length=512,
do_sample=True,
temperature=1.0,

@@ -150,10 +150,10 @@ if __name__ == '__main__':
parser.add_argument('--num_trainers', type=int, default=1)
parser.add_argument('--trainer_strategy',
choices=[
'naive', 'ddp', 'colossalai_gemini', 'colossalai_zero2', 'colossalai_gemini_cpu',
'ddp', 'colossalai_gemini', 'colossalai_zero2', 'colossalai_gemini_cpu',
'colossalai_zero2_cpu'
],
default='naive')
default='ddp')
parser.add_argument('--maker_strategy', choices=['naive'], default='naive')
parser.add_argument('--model', default='gpt2', choices=['gpt2', 'bloom', 'opt', 'llama'])
parser.add_argument('--critic_model', default='gpt2', choices=['gpt2', 'bloom', 'opt', 'llama'])

@@ -87,8 +87,8 @@ def main(args):
kl_coef=0.1,
debug=args.debug,
update_lora_weights=not (args.lora_rank == 0),
# sync_models_from_trainers=True,
# generation kwargs:
# sync_models_from_trainers=True,
# generation kwargs:
max_length=512,
do_sample=True,
temperature=1.0,

@@ -163,10 +163,10 @@ if __name__ == '__main__':
parser.add_argument('--num_trainers', type=int, default=1)
parser.add_argument('--trainer_strategy',
choices=[
'naive', 'ddp', 'colossalai_gemini', 'colossalai_zero2', 'colossalai_gemini_cpu',
'ddp', 'colossalai_gemini', 'colossalai_zero2', 'colossalai_gemini_cpu',
'colossalai_zero2_cpu'
],
default='naive')
default='ddp')
parser.add_argument('--maker_strategy', choices=['naive'], default='naive')
parser.add_argument('--model', default='gpt2', choices=['gpt2', 'bloom', 'opt', 'llama'])
parser.add_argument('--critic_model', default='gpt2', choices=['gpt2', 'bloom', 'opt', 'llama'])

@@ -49,13 +49,13 @@ wandb init -m offline
# - roberta-*: RuntimeError: CUDA error: CUBLAS_STATUS_NOT_INITIALIZED when calling `cublasCreate(handle)`
SKIPPED_TESTS=(
"gpt2-ddp"
"llama-naive" "llama-ddp" "llama-colossalai_gemini" "llama-colossalai_zero2"
"roberta-naive" "roberta-ddp" "roberta-colossalai_gemini" "roberta-colossalai_zero2"
"llama-ddp" "llama-colossalai_gemini" "llama-colossalai_zero2"
"roberta-ddp" "roberta-colossalai_gemini" "roberta-colossalai_zero2"
)

# These tests are quick and do not have any dependencies
for model in 'gpt2' 'bloom' 'opt' 'llama' 'roberta'; do
for strategy in 'naive' 'ddp' 'colossalai_gemini' 'colossalai_zero2'; do
for strategy in 'ddp' 'colossalai_gemini' 'colossalai_zero2'; do
if [[ " ${SKIPPED_TESTS[*]} " =~ " ${model}-${strategy} " ]]; then
echo "[Test]: Skipped $model-$strategy"
continue

@@ -91,12 +91,6 @@ torchrun --standalone --nproc_per_node=4 ${BASE}/train_sft.py --pretrain 'gpt2' \
--model 'gpt2' --strategy ddp --lora_rank 4 \
--dataset $SFT_DATASET --max_datasets_size 512 --max_epochs 1 \
--save_path ${BASE}/output

# torchrun --standalone --nproc_per_node=4 ${BASE}/train_sft.py --pretrain 'facebook/opt-350m' \
# --model 'opt' --strategy naive \
# --dataset $SFT_DATASET --max_datasets_size 512 --max_epochs 1 \
# --save_path ${BASE}/output

rm -rf ${BASE}/output

# train rm

@@ -144,9 +138,9 @@ torchrun --standalone --nproc_per_node=2 ${BASE}/train_reward_model.py \
--dataset 'Anthropic/hh-rlhf' --subset 'harmless-base' \
--test True --lora_rank 4 \
--save_path ${BASE}/rm_ckpt.pt

rm -rf ${BASE}/rm_ckpt.pt

# train rl
torchrun --standalone --nproc_per_node=2 ${BASE}/train_prompts.py \
--prompt_dataset $PROMPT_PATH --pretrain_dataset $PRETRAIN_DATASET \
--strategy colossalai_zero2 --num_episodes 1 \

@@ -9,7 +9,7 @@ from coati.models.llama import LlamaActor, LlamaCritic, LlamaRM
from coati.models.opt import OPTRM, OPTActor, OPTCritic
from coati.models.roberta import RoBERTaActor, RoBERTaCritic, RoBERTaRM
from coati.trainer import PPOTrainer
from coati.trainer.strategies import ColossalAIStrategy, DDPStrategy, NaiveStrategy
from coati.trainer.strategies import DDPStrategy, GeminiStrategy, LowLevelZeroStrategy
from coati.utils import prepare_llama_tokenizer_and_embedding
from torch.optim import Adam
from torch.utils.data import DataLoader

@@ -21,14 +21,12 @@ from colossalai.nn.optimizer import HybridAdam
def main(args):
# configure strategy
if args.strategy == 'naive':
strategy = NaiveStrategy()
elif args.strategy == 'ddp':
if args.strategy == 'ddp':
strategy = DDPStrategy()
elif args.strategy == 'colossalai_gemini':
strategy = ColossalAIStrategy(stage=3, placement_policy='cuda', initial_scale=2**5)
strategy = GeminiStrategy(placement_policy='cuda', initial_scale=2**5)
elif args.strategy == 'colossalai_zero2':
strategy = ColossalAIStrategy(stage=2, placement_policy='cuda')
strategy = LowLevelZeroStrategy(stage=2, placement_policy='cuda')
else:
raise ValueError(f'Unsupported strategy "{args.strategy}"')

@@ -208,7 +206,7 @@ if __name__ == '__main__':
parser.add_argument('--prompt_dataset', type=str, default=None, help='path to the prompt dataset')
parser.add_argument('--pretrain_dataset', type=str, default=None, help='path to the pretrained dataset')
parser.add_argument('--strategy',
choices=['naive', 'ddp', 'colossalai_gemini', 'colossalai_zero2'],
choices=['ddp', 'colossalai_gemini', 'colossalai_zero2'],
default='colossalai_zero2',
help='strategy to use')
parser.add_argument('--model', default='gpt2', choices=['gpt2', 'bloom', 'opt', 'llama', 'roberta'])

@@ -14,7 +14,7 @@ from coati.models.llama import LlamaRM
from coati.models.opt import OPTRM
from coati.models.roberta import RoBERTaRM
from coati.trainer import RewardModelTrainer
from coati.trainer.strategies import ColossalAIStrategy, DDPStrategy, NaiveStrategy
from coati.trainer.strategies import DDPStrategy, GeminiStrategy, LowLevelZeroStrategy
from coati.utils import prepare_llama_tokenizer_and_embedding
from datasets import load_dataset
from torch.optim import Adam

@@ -29,14 +29,12 @@ from colossalai.nn.optimizer import HybridAdam
def train(args):
# configure strategy
if args.strategy == 'naive':
strategy = NaiveStrategy()
elif args.strategy == 'ddp':
if args.strategy == 'ddp':
strategy = DDPStrategy()
elif args.strategy == 'colossalai_gemini':
strategy = ColossalAIStrategy(stage=3, placement_policy='cuda')
strategy = GeminiStrategy(placement_policy='cuda')
elif args.strategy == 'colossalai_zero2':
strategy = ColossalAIStrategy(stage=2, placement_policy='cuda')
strategy = LowLevelZeroStrategy(stage=2, placement_policy='cuda')
else:
raise ValueError(f'Unsupported strategy "{args.strategy}"')

@@ -195,7 +193,7 @@ def train(args):
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('--strategy',
choices=['naive', 'ddp', 'colossalai_gemini', 'colossalai_zero2'],
choices=['ddp', 'colossalai_gemini', 'colossalai_zero2'],
default='colossalai_zero2')
parser.add_argument('--model', choices=['gpt2', 'bloom', 'opt', 'deberta', 'llama', 'roberta'], default='bloom')
parser.add_argument('--pretrain', type=str, default=None)

@@ -8,7 +8,7 @@ import torch.distributed as dist
from coati.dataset import DataCollatorForSupervisedDataset, SFTDataset, SupervisedDataset
from coati.models import convert_to_lora_module
from coati.trainer import SFTTrainer
from coati.trainer.strategies import ColossalAIStrategy, DDPStrategy, NaiveStrategy
from coati.trainer.strategies import DDPStrategy, GeminiStrategy, LowLevelZeroStrategy
from coati.utils import prepare_llama_tokenizer_and_embedding
from datasets import load_dataset
from torch.optim import Adam

@@ -29,18 +29,16 @@ from colossalai.tensor import ColoParameter
def train(args):
# configure strategy
if args.strategy == 'naive':
strategy = NaiveStrategy()
elif args.strategy == 'ddp':
if args.strategy == 'ddp':
strategy = DDPStrategy()
elif args.strategy == 'colossalai_gemini':
raise NotImplementedError(
'Gemini is not supported .from_pretrained() yet. We will update this after checkpoint io is ready.')
strategy = ColossalAIStrategy(stage=3, placement_policy='cuda')
strategy = GeminiStrategy(placement_policy='cuda')
elif args.strategy == 'colossalai_zero2':
strategy = ColossalAIStrategy(stage=2, placement_policy='cuda')
strategy = LowLevelZeroStrategy(stage=2, placement_policy='cuda')
elif args.strategy == 'colossalai_zero2_cpu':
strategy = ColossalAIStrategy(stage=2, placement_policy='cpu')
strategy = LowLevelZeroStrategy(stage=2, placement_policy='cpu')
else:
raise ValueError(f'Unsupported strategy "{args.strategy}"')

@@ -66,7 +64,7 @@ def train(args):
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
tokenizer.pad_token = tokenizer.eos_token
elif args.model == 'bloom':
tokenizer = BloomTokenizerFast.from_pretrained(args.pretrain)
tokenizer = BloomTokenizerFast.from_pretrained('bigscience/bloom-560m')
tokenizer.pad_token = tokenizer.eos_token
elif args.model == 'opt':
tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m")

@@ -190,7 +188,7 @@ def train(args):
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('--strategy',
choices=['naive', 'ddp', 'colossalai_gemini', 'colossalai_zero2', 'colossalai_zero2_cpu'],
choices=['ddp', 'colossalai_gemini', 'colossalai_zero2', 'colossalai_zero2_cpu'],
default='colossalai_zero2')
parser.add_argument('--model', choices=['gpt2', 'bloom', 'opt', 'llama'], default='bloom')
parser.add_argument('--pretrain', type=str, default=None)

@@ -7,7 +7,7 @@ import torch
import torch.distributed as dist
from coati.models.gpt import GPTActor
from coati.models.utils import calc_action_log_probs
from coati.trainer.strategies import ColossalAIStrategy, DDPStrategy
from coati.trainer.strategies import DDPStrategy, GeminiStrategy, LowLevelZeroStrategy
from transformers.models.gpt2.configuration_gpt2 import GPT2Config

from colossalai.nn.optimizer import HybridAdam

@@ -28,9 +28,9 @@ def run_test_checkpoint(strategy):
if strategy == 'ddp':
strategy = DDPStrategy()
elif strategy == 'colossalai_gemini':
strategy = ColossalAIStrategy(stage=3, placement_policy='cuda', initial_scale=2**5)
strategy = GeminiStrategy(placement_policy='cuda', initial_scale=2**5)
elif strategy == 'colossalai_zero2':
strategy = ColossalAIStrategy(stage=2, placement_policy='cuda')
strategy = LowLevelZeroStrategy(stage=2, placement_policy='cuda')
else:
raise ValueError(f'Unsupported strategy "{strategy}"')

@@ -8,7 +8,7 @@ from coati.experience_maker import NaiveExperienceMaker
from coati.models.base import RewardModel
from coati.models.gpt import GPTActor, GPTCritic
from coati.replay_buffer import NaiveReplayBuffer
from coati.trainer.strategies import ColossalAIStrategy, DDPStrategy
from coati.trainer.strategies import DDPStrategy, GeminiStrategy
from transformers.models.gpt2.configuration_gpt2 import GPT2Config

from colossalai.testing import rerun_if_address_is_in_use, spawn

@@ -39,7 +39,7 @@ def run_test_data(strategy):
if strategy == 'ddp':
strategy = DDPStrategy()
elif strategy == 'colossalai':
strategy = ColossalAIStrategy(placement_policy='cuda')
strategy = GeminiStrategy(placement_policy='cuda')
else:
raise ValueError(f'Unsupported strategy "{strategy}"')