ColossalAI/applications/Chat/coati/trainer/strategies/colossalai.py

216 lines
9.6 KiB
Python
Raw Normal View History

2023-03-28 12:25:36 +00:00
import warnings
from typing import Optional
2023-03-28 12:25:36 +00:00
import torch
import torch.distributed as dist
import torch.nn as nn
import colossalai
from colossalai.booster.plugin import GeminiPlugin, LowLevelZeroPlugin
from colossalai.booster.plugin.gemini_plugin import GeminiModel
from colossalai.booster.plugin.low_level_zero_plugin import LowLevelZeroModel
2023-03-28 12:25:36 +00:00
from colossalai.tensor import ProcessGroup, ShardSpec
from colossalai.utils import get_current_device
from colossalai.zero import ColoInitContext
from colossalai.zero.gemini.gemini_ddp import GeminiDDP
2023-03-28 12:25:36 +00:00
from .ddp import DDPStrategy
class LowLevelZeroStrategy(DDPStrategy):
"""
The strategy for training with ColossalAI.
Args:
stage(int): The stage to use in ZeRO. Choose in (1, 2)
precision(str): The precision to use. Choose in ('fp32', 'fp16').
seed(int): The seed for the random number generator.
placement_policy(str): The placement policy for gemini. Choose in ('cpu', 'cuda')
If it is cpu, parameters, gradients and optimizer states will be offloaded to CPU,
If it is cuda, they will not be offloaded, which means max CUDA memory will be used. It is the fastest.
reduce_bucket_size(int): The reduce bucket size in bytes. Only for ZeRO-1 and ZeRO-2.
overlap_communication(bool): Whether to overlap communication and computation. Only for ZeRO-1 and ZeRO-2.
initial_scale(float): The initial scale for the optimizer.
growth_factor(float): The growth factor for the optimizer.
backoff_factor(float): The backoff factor for the optimizer.
growth_interval(int): The growth interval for the optimizer.
hysteresis(int): The hysteresis for the optimizer.
min_scale(float): The minimum scale for the optimizer.
max_scale(float): The maximum scale for the optimizer.
max_norm(float): The maximum norm for the optimizer.
norm_type(float): The norm type for the optimizer.
"""
def __init__(self,
[chat] fix bugs and add unit tests (#4213) * style: rename replay buffer Experience replay is typically for off policy algorithms. Use this name in PPO maybe misleading. * fix: fix wrong zero2 default arg * test: update experience tests * style: rename zero_pad fn * fix: defer init in CycledDataLoader * test: add benchmark test * style: rename internal fn of generation * style: rename internal fn of lora * fix: remove unused loss fn * fix: remove unused utils fn * refactor: remove generate_with_actor fn * fix: fix type annotation * test: add models tests * fix: skip llama due to long execution time * style: modify dataset * style: apply formatter * perf: update reward dataset * fix: fix wrong IGNORE_INDEX in sft dataset * fix: remove DataCollatorForSupervisedDataset * test: add dataset tests * style: apply formatter * style: rename test_ci to test_train * feat: add llama in inference * test: add inference tests * test: change test scripts directory * fix: update ci * fix: fix typo * fix: skip llama due to oom * fix: fix file mod * style: apply formatter * refactor: remove duplicated llama_gptq * style: apply formatter * to: update rm test * feat: add tokenizer arg * feat: add download model script * test: update train tests * fix: modify gemini load and save pretrained * test: update checkpoint io test * to: modify nproc_per_node * fix: do not remove existing dir * fix: modify save path * test: add random choice * fix: fix sft path * fix: enlarge nproc_per_node to avoid oom * fix: add num_retry * fix: make lora config of rm and critic consistent * fix: add warning about lora weights * fix: skip some gpt2 tests * fix: remove grad ckpt in rm and critic due to errors * refactor: directly use Actor in train_sft * test: add more arguments * fix: disable grad ckpt when using lora * fix: fix save_pretrained and related tests * test: enable zero2 tests * revert: remove useless fn * style: polish code * test: modify test args
2023-08-02 02:17:36 +00:00
stage: int = 2,
precision: str = 'fp16',
seed: int = 42,
placement_policy: str = 'cuda',
reduce_bucket_size: int = 12 * 1024**2, # only for stage 1&2
overlap_communication: bool = True, # only for stage 1&2
initial_scale: float = 2**16,
growth_factor: float = 2,
backoff_factor: float = 0.5,
growth_interval: int = 1000,
hysteresis: int = 2,
min_scale: float = 1,
max_scale: float = 2**32,
max_norm: float = 0.0,
norm_type: float = 2.0
) -> None:
assert stage in (1, 2), f'Unsupported stage "{stage}"'
assert placement_policy in ('cpu', 'cuda'), f'Unsupported placement policy "{placement_policy}"'
assert precision in ('fp32', 'fp16'), f'Unsupported precision "{precision}"'
plugin_initializer = lambda: LowLevelZeroPlugin(
# zero_config
stage=stage,
precision=precision,
# zero_optim_config
reduce_bucket_size_in_m=reduce_bucket_size,
overlap_communication=overlap_communication,
cpu_offload=(placement_policy == 'cpu'),
# optim_config
initial_scale=initial_scale,
growth_factor=growth_factor,
backoff_factor=backoff_factor,
growth_interval=growth_interval,
hysteresis=hysteresis,
min_scale=min_scale,
max_scale=max_scale,
max_norm=max_norm,
norm_type=norm_type
)
super().__init__(seed, plugin_initializer)
def _post_init(self) -> None:
assert isinstance(self.plugin, LowLevelZeroPlugin), \
f'{type(self).__name__}\'s plugin is not initialized properly.'
def setup_distributed(self) -> None:
colossalai.launch_from_torch({}, seed=self.seed)
def unwrap_model(self, model: nn.Module) -> nn.Module:
assert isinstance(model, LowLevelZeroModel)
return model.module
def get_model_state_dict_shard(self, model: nn.Module, **config):
assert isinstance(model, LowLevelZeroModel)
yield from model.state_dict_shard(max_shard_size=1024, only_rank_0=False)
class GeminiStrategy(DDPStrategy):
2023-03-28 12:25:36 +00:00
"""
The strategy for training with ColossalAI.
Args:
seed(int): The seed for the random number generator.
shard_init(bool): Whether to shard the model parameters during initialization. Only for ZeRO-3.
This is not compatible with `from_pretrained()`. We temporarily disable this and will support it in the future.
2023-03-28 12:25:36 +00:00
placement_policy(str): The placement policy for gemini. Choose in ('cpu', 'cuda')
If it is cpu, parameters, gradients and optimizer states will be offloaded to CPU,
If it is cuda, they will not be offloaded, which means max CUDA memory will be used. It is the fastest.
pin_memory(bool): Whether to pin the memory for the data loader. Only for ZeRO-3.
force_outputs_fp32(bool): Whether to force the outputs to be fp32. Only for ZeRO-3.
search_range_m(int): The number of search range for the chunk size, divided by 2^20. Only for ZeRO-3.
2023-03-28 12:25:36 +00:00
hidden_dim(optional, int): The hidden dimension for the gemini. Only for ZeRO-3.
min_chunk_size_m(float): The minimum chunk size divided by 2^20. Only for ZeRO-3.
2023-03-28 12:25:36 +00:00
gpu_margin_mem_ratio(float): The margin memory ratio for the GPU. Only for ZeRO-3.
initial_scale(float): The initial scale for the optimizer.
growth_factor(float): The growth factor for the optimizer.
backoff_factor(float): The backoff factor for the optimizer.
growth_interval(int): The growth interval for the optimizer.
hysteresis(int): The hysteresis for the optimizer.
min_scale(float): The minimum scale for the optimizer.
max_scale(float): The maximum scale for the optimizer.
max_norm(float): The maximum norm for the optimizer.
norm_type(float): The norm type for the optimizer.
"""
def __init__(self,
seed: int = 42,
shard_init: bool = False, # only for stage 3
placement_policy: str = 'cuda',
pin_memory: bool = True, # only for stage 3
force_outputs_fp32: bool = False, # only for stage 3
search_range_m: int = 32, # only for stage 3
hidden_dim: Optional[int] = None, # only for stage 3
min_chunk_size_m: float = 32, # only for stage 3
gpu_margin_mem_ratio: float = 0.0, # only for stage 3
initial_scale: float = 2**16,
growth_factor: float = 2,
backoff_factor: float = 0.5,
growth_interval: int = 1000,
hysteresis: int = 2,
min_scale: float = 1,
max_scale: float = 2**32,
max_norm: float = 0.0,
norm_type: float = 2.0
) -> None:
2023-03-28 12:25:36 +00:00
assert placement_policy in ('cpu', 'cuda'), f'Unsupported placement policy "{placement_policy}"'
2023-03-28 12:25:36 +00:00
# TODO(ver217): support shard_init when using from_pretrained()
if shard_init:
warnings.warn(
f'Shard init is not supported model.from_pretrained() yet. '
'Please load weights after strategy.prepare()'
)
2023-03-28 12:25:36 +00:00
self.shard_init = shard_init
warnings.warn(f'Stage 3 only supports fp16. Precision is set to fp16.')
# NOTE: dist should be initialized before calling get_current_device()
plugin_initializer = lambda: GeminiPlugin(
# gemini_config
device=get_current_device(),
placement_policy=placement_policy,
precision='fp16',
pin_memory=pin_memory,
force_outputs_fp32=force_outputs_fp32,
strict_ddp_mode=shard_init,
search_range_m=search_range_m,
hidden_dim=hidden_dim,
min_chunk_size_m=min_chunk_size_m,
# zero_optim_config
gpu_margin_mem_ratio=gpu_margin_mem_ratio,
# optim_config
initial_scale=initial_scale,
growth_factor=growth_factor,
backoff_factor=backoff_factor,
growth_interval=growth_interval,
hysteresis=hysteresis,
min_scale=min_scale,
max_scale=max_scale,
max_norm=max_norm,
norm_type=norm_type
)
super().__init__(seed, plugin_initializer)
def _post_init(self) -> None:
assert isinstance(self.plugin, GeminiPlugin), \
f'{type(self).__name__}\'s plugin is not initialized properly.'
2023-03-28 12:25:36 +00:00
def setup_distributed(self) -> None:
colossalai.launch_from_torch({}, seed=self.seed)
def model_init_context(self):
world_size = dist.get_world_size()
shard_pg = ProcessGroup(tp_degree=world_size) if self.shard_init else None
default_dist_spec = ShardSpec([-1], [world_size]) if self.shard_init else None
return ColoInitContext(device=get_current_device(),
dtype=torch.half,
default_pg=shard_pg,
default_dist_spec=default_dist_spec)
2023-03-28 12:25:36 +00:00
def unwrap_model(self, model: nn.Module) -> nn.Module:
assert isinstance(model, GeminiModel)
ddp_model = model.unwrap()
assert isinstance(ddp_model, GeminiDDP)
return ddp_model.module