diff --git a/colossalai/gemini/chunk/search_utils.py b/colossalai/gemini/chunk/search_utils.py index 572c3d945..57a708135 100644 --- a/colossalai/gemini/chunk/search_utils.py +++ b/colossalai/gemini/chunk/search_utils.py @@ -2,6 +2,7 @@ import math from typing import Dict, List, Optional, Tuple import numpy as np +import torch.distributed as dist import torch.nn as nn from colossalai.gemini.memory_tracer import MemStats, OrderedParamGenerator @@ -13,8 +14,14 @@ def _filter_exlarge_params(model: nn.Module, size_dict: Dict[int, List[int]]) -> """ Filter those parameters whose size is too large (more than 3x standard deviations) from others. """ - params_size = [p.numel() for p in model.parameters() if not is_ddp_ignored(p)] - params_size_arr = np.array(params_size) + agg_size_list = [] + for key in size_dict: + agg_size_list.extend(size_dict[key]) + + if len(agg_size_list) == 0: + return + + params_size_arr = np.array(agg_size_list) std = np.std(params_size_arr) mean = np.mean(params_size_arr) @@ -38,7 +45,15 @@ def _get_unused_byte(size_list: List[int], chunk_size: int) -> int: return left + acc -def classify_params_by_dp_degree(param_order: OrderedParamGenerator) -> Dict[int, List[ColoParameter]]: +def _tensor_numel(local_param: ColoParameter, strict_ddp_flag: bool): + if strict_ddp_flag: + return local_param.numel_global() + else: + return local_param.numel() + + +def classify_params_by_dp_degree(param_order: OrderedParamGenerator, + strict_ddp_flag: bool = False) -> Dict[int, List[ColoParameter]]: """classify_params_by_dp_degree Classify the parameters by their dp degree @@ -56,7 +71,10 @@ def classify_params_by_dp_degree(param_order: OrderedParamGenerator) -> Dict[int if is_ddp_ignored(param): continue - param_key = param.process_group.dp_world_size() + if strict_ddp_flag: + param_key = dist.get_world_size() + else: + param_key = param.process_group.dp_world_size() if param_key not in params_dict: params_dict[param_key] = [] @@ -71,14 +89,18 @@ def search_chunk_configuration( search_interval_byte: int, # hidden size is the best value for the interval min_chunk_size_mb: float = 32, filter_exlarge_params: bool = True, - memstas: Optional[MemStats] = None) -> Tuple[Dict, int]: + strict_ddp_flag: bool = False, + memstas: Optional[MemStats] = None) -> Tuple[Dict, int, int]: """search_chunk_configuration Args: model (nn.Module): torch module search_range_mb (float): searching range in mega byte. search_interval_byte (int): searching interval in byte. + min_chunk_size_mb (float, optional): the minimum size of a distributed chunk. filter_exlarge_params (bool, optional): filter extreme large parameters. Defaults to True. + strict_ddp_flag (bool, optional): whether to enable the strict ddp mode. + all parameters keep replicated in this mode. Returns: Tuple[Dict, int]: chunk config (a dict of dp_degree -> chunk init args) and its memory chunk waste in byte. @@ -96,17 +118,20 @@ def search_chunk_configuration( min_chunk_size_byte = round(min_chunk_size_mb * 1024**2) assert search_range_byte >= 0 - params_dict = classify_params_by_dp_degree(param_order) + params_dict = classify_params_by_dp_degree(param_order, strict_ddp_flag) config_dict: Dict[int, Dict] = dict() + total_param_size = 0 size_dict: Dict[int, List[int]] = dict() for dp_degree in params_dict: params_list = params_dict[dp_degree] - size_list = [p.numel() for p in params_list] + size_list = [_tensor_numel(p, strict_ddp_flag) for p in params_list] + group_acc_size = sum(size_list) + total_param_size += group_acc_size + # let small parameters keep gathered in CUDA all the time - total_size = sum(size_list) - if total_size < min_chunk_size_byte: - config_dict[dp_degree] = dict(chunk_size=total_size, keep_gathered=True) + if group_acc_size < min_chunk_size_byte: + config_dict[dp_degree] = dict(chunk_size=group_acc_size, keep_gathered=True) else: size_dict[dp_degree] = size_list @@ -134,4 +159,4 @@ def search_chunk_configuration( continue config_dict[dp_degree] = dict(chunk_size=best_chunk_size, keep_gathered=False) - return config_dict, min_chunk_waste + return config_dict, total_param_size, min_chunk_waste diff --git a/colossalai/gemini/chunk/utils.py b/colossalai/gemini/chunk/utils.py index ebfdee778..83512b8e0 100644 --- a/colossalai/gemini/chunk/utils.py +++ b/colossalai/gemini/chunk/utils.py @@ -19,38 +19,24 @@ def safe_div(a, b): def init_chunk_manager(model: nn.Module, init_device: Optional[torch.device] = None, hidden_dim: Optional[int] = None, - search_range_mb: Optional[float] = None, - min_chunk_size_mb: Optional[float] = None, - filter_exlarge_params: Optional[bool] = None) -> ChunkManager: - kwargs_dict = dict() - + **kwargs) -> ChunkManager: if hidden_dim: search_interval_byte = hidden_dim else: - search_interval_byte = 1024 # 1kb - kwargs_dict["search_interval_byte"] = search_interval_byte - - if search_range_mb: - kwargs_dict["search_range_mb"] = search_range_mb - - if min_chunk_size_mb: - kwargs_dict["min_chunk_size_mb"] = min_chunk_size_mb - - if filter_exlarge_params: - kwargs_dict["filter_exlarge_params"] = filter_exlarge_params - - params_sizes = [p.numel() for p in model.parameters() if not is_ddp_ignored(p)] - total_size = sum(params_sizes) / 1024**2 + search_interval_byte = 1024 # defaults to 1kb + kwargs["search_interval_byte"] = search_interval_byte dist.barrier() begin = time() - config_dict, wasted_size = search_chunk_configuration(model, **kwargs_dict) + config_dict, total_size, wasted_size = search_chunk_configuration(model, **kwargs) dist.barrier() end = time() span_s = end - begin - wasted_size /= 1024**2 + mb_size = 1024**2 + total_size /= mb_size + wasted_size /= mb_size if dist.get_rank() == 0: print("searching chunk configuration is completed in {:.2f} s.\n".format(span_s), diff --git a/colossalai/nn/parallel/data_parallel.py b/colossalai/nn/parallel/data_parallel.py index a742946f4..24d59e177 100644 --- a/colossalai/nn/parallel/data_parallel.py +++ b/colossalai/nn/parallel/data_parallel.py @@ -234,11 +234,14 @@ class ZeroDDP(ColoDDP): for p in module.parameters(): param_order.append(p) + ddp_pg = ColoProcessGroup() for p in param_order.generate(): assert isinstance(p, ColoParameter) - if strict_ddp_mode and not p.is_replicate(): - p.set_dist_spec(ReplicaSpec()) + if strict_ddp_mode: + if not p.is_replicate(): + p.set_dist_spec(ReplicaSpec()) + p.set_process_group(pg=ddp_pg) if is_ddp_ignored(p): p.data = p.data.to(device=get_current_device(), dtype=torch.float16) diff --git a/colossalai/nn/parallel/gemini_parallel.py b/colossalai/nn/parallel/gemini_parallel.py index 868a3960f..636f1ec74 100644 --- a/colossalai/nn/parallel/gemini_parallel.py +++ b/colossalai/nn/parallel/gemini_parallel.py @@ -20,7 +20,7 @@ class GeminiDDP(ZeroDDP): strict_ddp_mode: bool = False, search_range_mb: int = 32, hidden_dim: Optional[int] = None, - min_chunk_size_mb: Optional[float] = None, + min_chunk_size_mb: float = 32, memstats: Optional[MemStats] = None) -> None: """ A torch.Module warpper using ZeRO-DP and Genimi. @@ -53,6 +53,7 @@ class GeminiDDP(ZeroDDP): init_device=device, hidden_dim=hidden_dim, search_range_mb=search_range_mb, - min_chunk_size_mb=min_chunk_size_mb) + min_chunk_size_mb=min_chunk_size_mb, + strict_ddp_flag=strict_ddp_mode) gemini_manager = GeminiManager(placement_policy, chunk_manager, memstats) super().__init__(module, gemini_manager, pin_memory, force_outputs_fp32, strict_ddp_mode) diff --git a/colossalai/tensor/colo_tensor.py b/colossalai/tensor/colo_tensor.py index 3712d6a0a..b27f5dea7 100644 --- a/colossalai/tensor/colo_tensor.py +++ b/colossalai/tensor/colo_tensor.py @@ -1,3 +1,4 @@ +import math from copy import copy from functools import lru_cache from typing import Callable, Optional, Set @@ -303,6 +304,11 @@ class ColoTensor(torch.Tensor): else: return size_list[args[0]] + def numel_global(self): + """Returns the number of elements in the tensor when it's replicated. + """ + return math.prod(self.size_global()) + # Some API for dist spec check def is_replicate(self): diff --git a/examples/language/gpt/gemini/train_gpt_demo.py b/examples/language/gpt/gemini/train_gpt_demo.py index 285706596..02857ae9c 100644 --- a/examples/language/gpt/gemini/train_gpt_demo.py +++ b/examples/language/gpt/gemini/train_gpt_demo.py @@ -263,7 +263,7 @@ def main(): if args.distplan == "colossalai": # all param must use the same process group. world_size = torch.distributed.get_world_size() - shard_pg = ProcessGroup(tp_degree=world_size) + shard_pg = ProcessGroup(tp_degree=world_size) if args.shardinit else None default_dist_spec = ShardSpec([-1], [world_size]) if args.shardinit else None # build GPT model diff --git a/tests/test_ddp/test_ddp_ignore_params.py b/tests/test_ddp/test_ddp_ignore_params.py index 2be962e1a..679c8b0f6 100644 --- a/tests/test_ddp/test_ddp_ignore_params.py +++ b/tests/test_ddp/test_ddp_ignore_params.py @@ -35,7 +35,7 @@ def init_ddp(module: torch.nn.Module) -> ColoDDP: def init_ddpv2(module: torch.nn.Module) -> ZeroDDP: - chunk_config, _ = search_chunk_configuration(module, 4, 1024) + chunk_config, *_ = search_chunk_configuration(module, 4, 1024) chunk_manager = ChunkManager(chunk_config) gemini_manager = GeminiManager('cuda', chunk_manager) return ZeroDDP(module, gemini_manager) diff --git a/tests/test_gemini/update/test_fwd_bwd.py b/tests/test_gemini/update/test_fwd_bwd.py index af98878e9..0d35ba83d 100644 --- a/tests/test_gemini/update/test_fwd_bwd.py +++ b/tests/test_gemini/update/test_fwd_bwd.py @@ -58,7 +58,7 @@ def exam_gpt_fwd_bwd(placement_policy, torch_p.data.copy_(p.data) world_size = torch.distributed.get_world_size() - config_dict, _ = search_chunk_configuration(model, search_range_mb=1, search_interval_byte=100) + config_dict, *_ = search_chunk_configuration(model, search_range_mb=1, search_interval_byte=100) config_dict[world_size]['chunk_size'] = 5000 config_dict[world_size]['keep_gathered'] = keep_gather chunk_manager = ChunkManager(config_dict) diff --git a/tests/test_gemini/update/test_gemini_use_rmt.py b/tests/test_gemini/update/test_gemini_use_rmt.py index 7fce84a50..8cf17a0a7 100644 --- a/tests/test_gemini/update/test_gemini_use_rmt.py +++ b/tests/test_gemini/update/test_gemini_use_rmt.py @@ -62,7 +62,7 @@ def run_gemini_use_rmt(placement_policy, keep_gather, model_name: str, use_grad_ assert len(step_list) == 4 world_size = torch.distributed.get_world_size() - config_dict, _ = search_chunk_configuration(model, search_range_mb=1, search_interval_byte=100) + config_dict, *_ = search_chunk_configuration(model, search_range_mb=1, search_interval_byte=100) config_dict[world_size]['chunk_size'] = 5000 config_dict[world_size]['keep_gathered'] = keep_gather chunk_manager = ChunkManager(config_dict) diff --git a/tests/test_gemini/update/test_grad_clip.py b/tests/test_gemini/update/test_grad_clip.py index fda1cf8cf..d97ba9439 100644 --- a/tests/test_gemini/update/test_grad_clip.py +++ b/tests/test_gemini/update/test_grad_clip.py @@ -58,7 +58,7 @@ def exam_grad_clipping(placement_policy, model_name: str): p.data.copy_(torch_p.data) world_size = torch.distributed.get_world_size() - config_dict, _ = search_chunk_configuration(model, search_range_mb=1, search_interval_byte=100) + config_dict, *_ = search_chunk_configuration(model, search_range_mb=1, search_interval_byte=100) config_dict[world_size]['chunk_size'] = 5000 config_dict[world_size]['keep_gathered'] = False if placement_policy != 'cuda': diff --git a/tests/test_gemini/update/test_inference.py b/tests/test_gemini/update/test_inference.py index aec945fc9..443155865 100644 --- a/tests/test_gemini/update/test_inference.py +++ b/tests/test_gemini/update/test_inference.py @@ -57,7 +57,7 @@ def exam_inference(placement_policy, model_name: str): p.data.copy_(torch_p.data) world_size = torch.distributed.get_world_size() - config_dict, _ = search_chunk_configuration(model, search_range_mb=1, search_interval_byte=100) + config_dict, *_ = search_chunk_configuration(model, search_range_mb=1, search_interval_byte=100) config_dict[world_size]['chunk_size'] = 5000 config_dict[world_size]['keep_gathered'] = False if placement_policy != 'cuda': diff --git a/tests/test_gemini/update/test_optim.py b/tests/test_gemini/update/test_optim.py index 07e6e65f2..cd3aa6051 100644 --- a/tests/test_gemini/update/test_optim.py +++ b/tests/test_gemini/update/test_optim.py @@ -63,7 +63,7 @@ def exam_model_step(placement_policy, model_name: str): p.data.copy_(torch_p.data) world_size = torch.distributed.get_world_size() - config_dict, _ = search_chunk_configuration(model, search_range_mb=1, search_interval_byte=100) + config_dict, *_ = search_chunk_configuration(model, search_range_mb=1, search_interval_byte=100) config_dict[world_size]['chunk_size'] = 5000 config_dict[world_size]['keep_gathered'] = False if placement_policy != 'cuda': diff --git a/tests/test_gemini/update/test_search.py b/tests/test_gemini/update/test_search.py index e0b4e207f..2fcdd5380 100644 --- a/tests/test_gemini/update/test_search.py +++ b/tests/test_gemini/update/test_search.py @@ -6,7 +6,7 @@ import torch.distributed as dist import torch.multiprocessing as mp import colossalai -from colossalai.gemini.chunk import search_chunk_configuration +from colossalai.gemini.chunk import init_chunk_manager, search_chunk_configuration from colossalai.tensor import ComputePattern, ComputeSpec, ProcessGroup, ShardSpec from colossalai.testing import rerun_if_address_is_in_use from colossalai.utils import free_port, get_current_device @@ -23,7 +23,6 @@ def init_1d_row_spec(model, pg: ProcessGroup): def exam_search_chunk_size(): - world_size = torch.distributed.get_world_size() pg_tp = ProcessGroup(tp_degree=world_size) @@ -34,11 +33,11 @@ def exam_search_chunk_size(): with ColoInitContext(device=get_current_device()): model = model_builder() init_1d_row_spec(model, pg_tp) - config_dict, _ = search_chunk_configuration(model, - search_range_mb=1, - search_interval_byte=16, - min_chunk_size_mb=0, - filter_exlarge_params=True) + config_dict, *_ = search_chunk_configuration(model, + search_range_mb=1, + search_interval_byte=16, + min_chunk_size_mb=0, + filter_exlarge_params=True) for key in config_dict: chunk_size = config_dict[key]['chunk_size'] @@ -48,9 +47,68 @@ def exam_search_chunk_size(): assert chunk_size == 1024 +def exam_search_strict_ddp(): + world_size = torch.distributed.get_world_size() + default_shard_pg = ProcessGroup(tp_degree=world_size) + default_shard_spec = ShardSpec([-1], [world_size]) + + get_components_func = non_distributed_component_funcs.get_callable('gpt2') + model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func() + # get the chunk configuration over replicated models + with ColoInitContext(device=get_current_device()): + ddp_model = model_builder() + re_dict, re_total, re_wasted = search_chunk_configuration(ddp_model, + search_range_mb=1, + search_interval_byte=16, + min_chunk_size_mb=0, + filter_exlarge_params=True, + strict_ddp_flag=False) + # get the chunk configuration over sharded ddp models + with ColoInitContext(device=get_current_device(), default_pg=default_shard_pg, + default_dist_spec=default_shard_spec): + sharded_ddp_model = model_builder() + sh_dict, sh_total, sh_wasted = search_chunk_configuration(sharded_ddp_model, + search_range_mb=1, + search_interval_byte=16, + min_chunk_size_mb=0, + filter_exlarge_params=True, + strict_ddp_flag=True) + assert re_dict == sh_dict + for key in re_dict: + assert re_dict[key] == sh_dict[key] + + assert re_total == sh_total + assert re_wasted == sh_wasted + + +def exam_chunk_manager(): + world_size = torch.distributed.get_world_size() + default_shard_pg = ProcessGroup(tp_degree=world_size) + default_shard_spec = ShardSpec([-1], [world_size]) + + get_components_func = non_distributed_component_funcs.get_callable('gpt2') + model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func() + + with ColoInitContext(device=get_current_device(), default_pg=default_shard_pg, + default_dist_spec=default_shard_spec): + sharded_ddp_model = model_builder() + chunk_manager = init_chunk_manager(sharded_ddp_model, + get_current_device(), + hidden_dim=16, + search_range_mb=1, + min_chunk_size_mb=0, + filter_exlarge_params=True, + strict_ddp_flag=True) + config_dict = chunk_manager.dp_degree_chunk_size_dict + assert len(config_dict) == 1 + assert config_dict[world_size] == 31616 + + def run_dist(rank, world_size, port): colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') exam_search_chunk_size() + exam_search_strict_ddp() + exam_chunk_manager() @pytest.mark.dist diff --git a/tests/test_gemini/update/test_zeroddp_state_dict.py b/tests/test_gemini/update/test_zeroddp_state_dict.py index 266b8eab1..00d835842 100644 --- a/tests/test_gemini/update/test_zeroddp_state_dict.py +++ b/tests/test_gemini/update/test_zeroddp_state_dict.py @@ -41,7 +41,7 @@ def exam_state_dict(placement_policy, keep_gathered, model_name: str): torch_p.data.copy_(p.data) world_size = torch.distributed.get_world_size() - config_dict, _ = search_chunk_configuration(model, search_range_mb=1, search_interval_byte=100) + config_dict, *_ = search_chunk_configuration(model, search_range_mb=1, search_interval_byte=100) config_dict[world_size]['chunk_size'] = 5000 config_dict[world_size]['keep_gathered'] = keep_gathered chunk_manager = ChunkManager(config_dict) @@ -73,7 +73,7 @@ def exam_load_state_dict(placement_policy, keep_gathered, model_name: str): torch_model = model_builder() # get a different model world_size = torch.distributed.get_world_size() - config_dict, _ = search_chunk_configuration(model, search_range_mb=1, search_interval_byte=100) + config_dict, *_ = search_chunk_configuration(model, search_range_mb=1, search_interval_byte=100) config_dict[world_size]['chunk_size'] = 5000 config_dict[world_size]['keep_gathered'] = keep_gathered diff --git a/tests/test_gemini/update/test_zerooptim_state_dict.py b/tests/test_gemini/update/test_zerooptim_state_dict.py index 7f53415bf..dc3dda9d6 100644 --- a/tests/test_gemini/update/test_zerooptim_state_dict.py +++ b/tests/test_gemini/update/test_zerooptim_state_dict.py @@ -33,7 +33,7 @@ def exam_zero_optim_state_dict(placement_policy, keep_gathered): torch_model = model_builder() # get a different model world_size = torch.distributed.get_world_size() - config_dict, _ = search_chunk_configuration(model, search_range_mb=1, search_interval_byte=100) + config_dict, *_ = search_chunk_configuration(model, search_range_mb=1, search_interval_byte=100) config_dict[world_size]['chunk_size'] = 5000 config_dict[world_size]['keep_gathered'] = keep_gathered diff --git a/tests/test_tensor/test_tp_with_zero.py b/tests/test_tensor/test_tp_with_zero.py index 83645bc6e..1a6d23f6a 100644 --- a/tests/test_tensor/test_tp_with_zero.py +++ b/tests/test_tensor/test_tp_with_zero.py @@ -85,7 +85,7 @@ def run_gpt(placement_policy, tp_init_spec_func=None): tp_init_spec_func(model, pg) dp_world_size = pg.dp_world_size() - config_dict, _ = search_chunk_configuration(model, search_range_mb=1, search_interval_byte=100) + config_dict, *_ = search_chunk_configuration(model, search_range_mb=1, search_interval_byte=100) config_dict[dp_world_size]['chunk_size'] = 5000 config_dict[dp_world_size]['keep_gathered'] = False if placement_policy != 'cuda':