2023-06-25 09:36:21 +00:00
|
|
|
import functools
|
2023-03-28 12:25:36 +00:00
|
|
|
import warnings
|
2023-06-25 09:36:21 +00:00
|
|
|
from typing import Optional
|
2023-03-28 12:25:36 +00:00
|
|
|
|
|
|
|
import torch
|
|
|
|
import torch.distributed as dist
|
|
|
|
import torch.nn as nn
|
|
|
|
from transformers.tokenization_utils_base import PreTrainedTokenizerBase
|
|
|
|
|
|
|
|
import colossalai
|
2023-06-25 09:36:21 +00:00
|
|
|
from colossalai.booster.plugin import GeminiPlugin, LowLevelZeroPlugin
|
|
|
|
from colossalai.booster.plugin.gemini_plugin import GeminiModel
|
|
|
|
from colossalai.booster.plugin.low_level_zero_plugin import LowLevelZeroModel
|
2023-03-28 12:25:36 +00:00
|
|
|
from colossalai.tensor import ProcessGroup, ShardSpec
|
|
|
|
from colossalai.utils import get_current_device
|
2023-06-25 09:36:21 +00:00
|
|
|
from colossalai.zero import ColoInitContext
|
|
|
|
from colossalai.zero.gemini.gemini_ddp import GeminiDDP
|
2023-03-28 12:25:36 +00:00
|
|
|
|
|
|
|
from .ddp import DDPStrategy
|
|
|
|
|
|
|
|
|
|
|
|
class ColossalAIStrategy(DDPStrategy):
    """
    The strategy for training with ColossalAI.

    ZeRO stage 3 is run through ``GeminiPlugin``; stages 1 and 2 are run through
    ``LowLevelZeroPlugin``. The plugin is not constructed in ``__init__``: a factory
    is handed to ``DDPStrategy.__init__`` instead, because ``get_current_device()``
    requires the distributed environment to be initialized first.

    Args:
        stage(int): The ZeRO stage to use. One of (1, 2, 3).
        precision(str): The precision to use. One of ('fp32', 'fp16'). Stage 3 only supports fp16;
            requesting fp32 with stage 3 emits a warning and falls back to fp16.
        seed(int): The seed for the random number generator.
        shard_init(bool): Whether to shard the model parameters during initialization. Only for ZeRO-3.
            It is also forwarded to Gemini as ``strict_ddp_mode``. This is not compatible with
            ``from_pretrained()``: enabling it emits a warning, and weights should be loaded
            after ``strategy.prepare()``.
        placement_policy(str): The placement policy. One of ('cpu', 'cuda').
            If it is "cpu", parameters, gradients and optimizer states will be offloaded to CPU
            (for ZeRO-1/2 this enables the plugin's ``cpu_offload``).
            If it is "cuda", they will not be offloaded, which means max CUDA memory will be used.
            It is the fastest.
        pin_memory(bool): Whether to pin memory; forwarded to ``GeminiPlugin``. Only for ZeRO-3.
        force_outputs_fp32(bool): Whether to force the outputs to be fp32. Only for ZeRO-3.
        search_range_mb(int): The search range in MB for the chunk size. Only for ZeRO-3.
        hidden_dim(optional, int): The hidden dimension for Gemini. Only for ZeRO-3.
        min_chunk_size_mb(float): The minimum chunk size in MB. Only for ZeRO-3.
        gpu_margin_mem_ratio(float): The margin memory ratio for the GPU. Only for ZeRO-3.
        reduce_bucket_size(int): The reduce bucket size in bytes. Only for ZeRO-1 and ZeRO-2.
            ``LowLevelZeroPlugin`` takes this size in MB, so it is converted to whole MB
            (rounded up, at least 1) before being passed on.
        overlap_communication(bool): Whether to overlap communication and computation.
            Only for ZeRO-1 and ZeRO-2.
        initial_scale(float): The initial scale for the optimizer.
        growth_factor(float): The growth factor for the optimizer.
        backoff_factor(float): The backoff factor for the optimizer.
        growth_interval(int): The growth interval for the optimizer.
        hysteresis(int): The hysteresis for the optimizer.
        min_scale(float): The minimum scale for the optimizer.
        max_scale(float): The maximum scale for the optimizer.
        max_norm(float): The maximum norm for the optimizer.
        norm_type(float): The norm type for the optimizer.

        The optimizer settings above (``initial_scale`` .. ``norm_type``) are forwarded
        unchanged to whichever plugin is selected.
    """

    def __init__(
            self,
            stage: int = 3,
            precision: str = 'fp16',
            seed: int = 42,
            shard_init: bool = False,    # only for stage 3
            placement_policy: str = 'cuda',
            pin_memory: bool = True,    # only for stage 3
            force_outputs_fp32: bool = False,    # only for stage 3
            search_range_mb: int = 32,    # only for stage 3
            hidden_dim: Optional[int] = None,    # only for stage 3
            min_chunk_size_mb: float = 32,    # only for stage 3
            gpu_margin_mem_ratio: float = 0.0,    # only for stage 3
            reduce_bucket_size: int = 12 * 1024**2,    # only for stage 1&2, in bytes
            overlap_communication: bool = True,    # only for stage 1&2
            initial_scale: float = 2**16,
            growth_factor: float = 2,
            backoff_factor: float = 0.5,
            growth_interval: int = 1000,
            hysteresis: int = 2,
            min_scale: float = 1,
            max_scale: float = 2**32,
            max_norm: float = 0.0,
            norm_type: float = 2.0) -> None:
        assert stage in (1, 2, 3), f'Unsupported stage "{stage}"'
        assert placement_policy in ('cpu', 'cuda'), f'Unsupported placement policy "{placement_policy}"'
        assert precision in ('fp32', 'fp16'), f'Unsupported precision "{precision}"'

        # TODO(ver217): support shard_init when using from_pretrained()
        if shard_init:
            warnings.warn('Shard init is not supported model.from_pretrained() yet. '
                          'Please load weights after strategy.prepare()')
        if stage == 3 and precision == 'fp32':
            warnings.warn('Stage 3 only supports fp16. Precision is set to fp16.')
            precision = 'fp16'
        self.precision = precision
        self.shard_init = shard_init

        # Optimizer settings shared verbatim by both plugins.
        optim_kwargs = dict(
            initial_scale=initial_scale,
            growth_factor=growth_factor,
            backoff_factor=backoff_factor,
            growth_interval=growth_interval,
            hysteresis=hysteresis,
            min_scale=min_scale,
            max_scale=max_scale,
            max_norm=max_norm,
            norm_type=norm_type,
        )

        # NOTE: dist should be initialized before calling get_current_device(),
        # so the plugin is created lazily through a factory instead of here.
        if stage == 3:
            plugin_initializer = lambda: GeminiPlugin(
                # gemini_config
                device=get_current_device(),
                placement_policy=placement_policy,
                precision=precision,
                pin_memory=pin_memory,
                force_outputs_fp32=force_outputs_fp32,
                strict_ddp_mode=shard_init,
                search_range_mb=search_range_mb,
                hidden_dim=hidden_dim,
                min_chunk_size_mb=min_chunk_size_mb,
                # zero_optim_config
                gpu_margin_mem_ratio=gpu_margin_mem_ratio,
                # optim_config
                **optim_kwargs,
            )
        else:
            # LowLevelZeroPlugin expects the bucket size in MB (`reduce_bucket_size_in_m`),
            # while this strategy takes it in bytes. Passing the byte count straight through
            # would turn the default 12 MiB bucket into 12 TiB. Round up so that any
            # positive byte count maps to at least a 1 MB bucket.
            reduce_bucket_size_in_m = max(1, -(-reduce_bucket_size // 1024**2))
            plugin_initializer = lambda: LowLevelZeroPlugin(
                # zero_config
                stage=stage,
                precision=precision,
                # zero_optim_config
                reduce_bucket_size_in_m=reduce_bucket_size_in_m,
                overlap_communication=overlap_communication,
                cpu_offload=(placement_policy == 'cpu'),
                # optim_config
                **optim_kwargs,
            )

        super().__init__(seed, plugin_initializer)

    def _post_init(self) -> None:
        """Check that the plugin (presumably built by DDPStrategy from the factory) is a ZeRO plugin."""
        assert isinstance(self.plugin, (LowLevelZeroPlugin, GeminiPlugin)), \
            f'{type(self).__name__}\'s plugin is not initialized properly.'

    def setup_distributed(self) -> None:
        """Launch ColossalAI from torch-style environment variables with an empty config and ``self.seed``."""
        colossalai.launch_from_torch({}, seed=self.seed)

    def model_init_context(self):
        """Return the context manager under which the model should be constructed.

        For Gemini (ZeRO-3) this is a ``ColoInitContext`` that creates parameters on the
        current device in ``torch.half``. When ``shard_init`` is set, it also uses a
        process group with ``tp_degree`` equal to the world size and a ``ShardSpec`` over
        dim ``-1`` with ``world_size`` partitions. For other plugins the base-class
        context is returned.
        """
        if isinstance(self.plugin, GeminiPlugin):
            world_size = dist.get_world_size()
            shard_pg = ProcessGroup(tp_degree=world_size) if self.shard_init else None
            default_dist_spec = ShardSpec([-1], [world_size]) if self.shard_init else None
            return ColoInitContext(device=get_current_device(),
                                   dtype=torch.half,
                                   default_pg=shard_pg,
                                   default_dist_spec=default_dist_spec)
        return super().model_init_context()

    def unwrap_model(self, model: nn.Module) -> nn.Module:
        """Strip the booster wrappers and return the user's original ``nn.Module``.

        Raises:
            RuntimeError: if the plugin is neither Gemini nor low-level ZeRO.
        """
        if isinstance(self.plugin, GeminiPlugin):
            assert isinstance(model, GeminiModel)
            ddp_model = model.unwrap()
            assert isinstance(ddp_model, GeminiDDP)
            return ddp_model.module
        elif isinstance(self.plugin, LowLevelZeroPlugin):
            assert isinstance(model, LowLevelZeroModel)
            return model.module
        else:
            raise RuntimeError(f'Unsupported plugin {type(self.plugin)}')

    def save_pretrained(self,
                        model: nn.Module,
                        path: str,
                        only_rank0: bool = True,
                        tokenizer: Optional[PreTrainedTokenizerBase] = None) -> None:
        """Delegate to ``DDPStrategy.save_pretrained``; not supported with Gemini (ZeRO-3).

        Raises:
            RuntimeError: when the Gemini plugin is in use.
        """
        if isinstance(self.plugin, GeminiPlugin):
            raise RuntimeError('ColossalAI strategy with stage-3 does not support save_pretrained() now')
        super().save_pretrained(model, path, only_rank0, tokenizer)

    def get_model_state_dict_shard(self, model: nn.Module, **config):
        """Yield the model's state dict in shards.

        Non-Gemini plugins defer to ``DDPStrategy``. With Gemini, the boosted model is
        unwrapped to its ``GeminiDDP`` (the object that owns the sharded parameters) and
        its ``state_dict_shard`` is used with ``max_shard_size=1024`` and
        ``only_rank_0=False``. ``config`` is ignored in the Gemini branch.
        """
        if not isinstance(self.plugin, GeminiPlugin):
            yield from super().get_model_state_dict_shard(model, **config)
            return
        # TODO: LoRA weights are not merged before sharding. A previous version set
        # `merge_weights = True` and called `eval()` on every LoraLinear module of the
        # unwrapped model here.
        # Under GeminiPlugin the boosted model is a GeminiModel (never a LowLevelZeroModel);
        # the sharded state dict lives on the wrapped GeminiDDP, as in unwrap_model().
        assert isinstance(model, GeminiModel)
        ddp_model = model.unwrap()
        assert isinstance(ddp_model, GeminiDDP)
        yield from ddp_model.state_dict_shard(max_shard_size=1024, only_rank_0=False)
|