[gemini] update ddp strict mode (#2518)

* [zero] add strict ddp mode for chunk init

* [gemini] update gpt example
pull/2520/head
HELSON 2 years ago committed by GitHub
parent 0af793836c
commit 707b11d4a0
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -2,6 +2,7 @@ import math
from typing import Dict, List, Optional, Tuple
import numpy as np
import torch.distributed as dist
import torch.nn as nn
from colossalai.gemini.memory_tracer import MemStats, OrderedParamGenerator
@ -13,8 +14,14 @@ def _filter_exlarge_params(model: nn.Module, size_dict: Dict[int, List[int]]) ->
"""
Filter those parameters whose size is too large (more than 3x standard deviations) from others.
"""
params_size = [p.numel() for p in model.parameters() if not is_ddp_ignored(p)]
params_size_arr = np.array(params_size)
agg_size_list = []
for key in size_dict:
agg_size_list.extend(size_dict[key])
if len(agg_size_list) == 0:
return
params_size_arr = np.array(agg_size_list)
std = np.std(params_size_arr)
mean = np.mean(params_size_arr)
@ -38,7 +45,15 @@ def _get_unused_byte(size_list: List[int], chunk_size: int) -> int:
return left + acc
def classify_params_by_dp_degree(param_order: OrderedParamGenerator) -> Dict[int, List[ColoParameter]]:
def _tensor_numel(local_param: ColoParameter, strict_ddp_flag: bool):
if strict_ddp_flag:
return local_param.numel_global()
else:
return local_param.numel()
def classify_params_by_dp_degree(param_order: OrderedParamGenerator,
strict_ddp_flag: bool = False) -> Dict[int, List[ColoParameter]]:
"""classify_params_by_dp_degree
Classify the parameters by their dp degree
@ -56,7 +71,10 @@ def classify_params_by_dp_degree(param_order: OrderedParamGenerator) -> Dict[int
if is_ddp_ignored(param):
continue
param_key = param.process_group.dp_world_size()
if strict_ddp_flag:
param_key = dist.get_world_size()
else:
param_key = param.process_group.dp_world_size()
if param_key not in params_dict:
params_dict[param_key] = []
@ -71,14 +89,18 @@ def search_chunk_configuration(
search_interval_byte: int, # hidden size is the best value for the interval
min_chunk_size_mb: float = 32,
filter_exlarge_params: bool = True,
memstas: Optional[MemStats] = None) -> Tuple[Dict, int]:
strict_ddp_flag: bool = False,
memstas: Optional[MemStats] = None) -> Tuple[Dict, int, int]:
"""search_chunk_configuration
Args:
model (nn.Module): torch module
search_range_mb (float): searching range in mega byte.
search_interval_byte (int): searching interval in byte.
min_chunk_size_mb (float, optional): the minimum size of a distributed chunk.
filter_exlarge_params (bool, optional): filter extreme large parameters. Defaults to True.
strict_ddp_flag (bool, optional): whether to enable the strict ddp mode.
all parameters keep replicated in this mode.
Returns:
Tuple[Dict, int]: chunk config (a dict of dp_degree -> chunk init args) and its memory chunk waste in byte.
@ -96,17 +118,20 @@ def search_chunk_configuration(
min_chunk_size_byte = round(min_chunk_size_mb * 1024**2)
assert search_range_byte >= 0
params_dict = classify_params_by_dp_degree(param_order)
params_dict = classify_params_by_dp_degree(param_order, strict_ddp_flag)
config_dict: Dict[int, Dict] = dict()
total_param_size = 0
size_dict: Dict[int, List[int]] = dict()
for dp_degree in params_dict:
params_list = params_dict[dp_degree]
size_list = [p.numel() for p in params_list]
size_list = [_tensor_numel(p, strict_ddp_flag) for p in params_list]
group_acc_size = sum(size_list)
total_param_size += group_acc_size
# let small parameters keep gathered in CUDA all the time
total_size = sum(size_list)
if total_size < min_chunk_size_byte:
config_dict[dp_degree] = dict(chunk_size=total_size, keep_gathered=True)
if group_acc_size < min_chunk_size_byte:
config_dict[dp_degree] = dict(chunk_size=group_acc_size, keep_gathered=True)
else:
size_dict[dp_degree] = size_list
@ -134,4 +159,4 @@ def search_chunk_configuration(
continue
config_dict[dp_degree] = dict(chunk_size=best_chunk_size, keep_gathered=False)
return config_dict, min_chunk_waste
return config_dict, total_param_size, min_chunk_waste

@ -19,38 +19,24 @@ def safe_div(a, b):
def init_chunk_manager(model: nn.Module,
init_device: Optional[torch.device] = None,
hidden_dim: Optional[int] = None,
search_range_mb: Optional[float] = None,
min_chunk_size_mb: Optional[float] = None,
filter_exlarge_params: Optional[bool] = None) -> ChunkManager:
kwargs_dict = dict()
**kwargs) -> ChunkManager:
if hidden_dim:
search_interval_byte = hidden_dim
else:
search_interval_byte = 1024 # 1kb
kwargs_dict["search_interval_byte"] = search_interval_byte
if search_range_mb:
kwargs_dict["search_range_mb"] = search_range_mb
if min_chunk_size_mb:
kwargs_dict["min_chunk_size_mb"] = min_chunk_size_mb
if filter_exlarge_params:
kwargs_dict["filter_exlarge_params"] = filter_exlarge_params
params_sizes = [p.numel() for p in model.parameters() if not is_ddp_ignored(p)]
total_size = sum(params_sizes) / 1024**2
search_interval_byte = 1024 # defaults to 1kb
kwargs["search_interval_byte"] = search_interval_byte
dist.barrier()
begin = time()
config_dict, wasted_size = search_chunk_configuration(model, **kwargs_dict)
config_dict, total_size, wasted_size = search_chunk_configuration(model, **kwargs)
dist.barrier()
end = time()
span_s = end - begin
wasted_size /= 1024**2
mb_size = 1024**2
total_size /= mb_size
wasted_size /= mb_size
if dist.get_rank() == 0:
print("searching chunk configuration is completed in {:.2f} s.\n".format(span_s),

@ -234,11 +234,14 @@ class ZeroDDP(ColoDDP):
for p in module.parameters():
param_order.append(p)
ddp_pg = ColoProcessGroup()
for p in param_order.generate():
assert isinstance(p, ColoParameter)
if strict_ddp_mode and not p.is_replicate():
p.set_dist_spec(ReplicaSpec())
if strict_ddp_mode:
if not p.is_replicate():
p.set_dist_spec(ReplicaSpec())
p.set_process_group(pg=ddp_pg)
if is_ddp_ignored(p):
p.data = p.data.to(device=get_current_device(), dtype=torch.float16)

@ -20,7 +20,7 @@ class GeminiDDP(ZeroDDP):
strict_ddp_mode: bool = False,
search_range_mb: int = 32,
hidden_dim: Optional[int] = None,
min_chunk_size_mb: Optional[float] = None,
min_chunk_size_mb: float = 32,
memstats: Optional[MemStats] = None) -> None:
"""
A torch.Module warpper using ZeRO-DP and Genimi.
@ -53,6 +53,7 @@ class GeminiDDP(ZeroDDP):
init_device=device,
hidden_dim=hidden_dim,
search_range_mb=search_range_mb,
min_chunk_size_mb=min_chunk_size_mb)
min_chunk_size_mb=min_chunk_size_mb,
strict_ddp_flag=strict_ddp_mode)
gemini_manager = GeminiManager(placement_policy, chunk_manager, memstats)
super().__init__(module, gemini_manager, pin_memory, force_outputs_fp32, strict_ddp_mode)

@ -1,3 +1,4 @@
import math
from copy import copy
from functools import lru_cache
from typing import Callable, Optional, Set
@ -303,6 +304,11 @@ class ColoTensor(torch.Tensor):
else:
return size_list[args[0]]
def numel_global(self):
"""Returns the number of elements in the tensor when it's replicated.
"""
return math.prod(self.size_global())
# Some API for dist spec check
def is_replicate(self):

@ -263,7 +263,7 @@ def main():
if args.distplan == "colossalai":
# all param must use the same process group.
world_size = torch.distributed.get_world_size()
shard_pg = ProcessGroup(tp_degree=world_size)
shard_pg = ProcessGroup(tp_degree=world_size) if args.shardinit else None
default_dist_spec = ShardSpec([-1], [world_size]) if args.shardinit else None
# build GPT model

@ -35,7 +35,7 @@ def init_ddp(module: torch.nn.Module) -> ColoDDP:
def init_ddpv2(module: torch.nn.Module) -> ZeroDDP:
chunk_config, _ = search_chunk_configuration(module, 4, 1024)
chunk_config, *_ = search_chunk_configuration(module, 4, 1024)
chunk_manager = ChunkManager(chunk_config)
gemini_manager = GeminiManager('cuda', chunk_manager)
return ZeroDDP(module, gemini_manager)

@ -58,7 +58,7 @@ def exam_gpt_fwd_bwd(placement_policy,
torch_p.data.copy_(p.data)
world_size = torch.distributed.get_world_size()
config_dict, _ = search_chunk_configuration(model, search_range_mb=1, search_interval_byte=100)
config_dict, *_ = search_chunk_configuration(model, search_range_mb=1, search_interval_byte=100)
config_dict[world_size]['chunk_size'] = 5000
config_dict[world_size]['keep_gathered'] = keep_gather
chunk_manager = ChunkManager(config_dict)

@ -62,7 +62,7 @@ def run_gemini_use_rmt(placement_policy, keep_gather, model_name: str, use_grad_
assert len(step_list) == 4
world_size = torch.distributed.get_world_size()
config_dict, _ = search_chunk_configuration(model, search_range_mb=1, search_interval_byte=100)
config_dict, *_ = search_chunk_configuration(model, search_range_mb=1, search_interval_byte=100)
config_dict[world_size]['chunk_size'] = 5000
config_dict[world_size]['keep_gathered'] = keep_gather
chunk_manager = ChunkManager(config_dict)

@ -58,7 +58,7 @@ def exam_grad_clipping(placement_policy, model_name: str):
p.data.copy_(torch_p.data)
world_size = torch.distributed.get_world_size()
config_dict, _ = search_chunk_configuration(model, search_range_mb=1, search_interval_byte=100)
config_dict, *_ = search_chunk_configuration(model, search_range_mb=1, search_interval_byte=100)
config_dict[world_size]['chunk_size'] = 5000
config_dict[world_size]['keep_gathered'] = False
if placement_policy != 'cuda':

@ -57,7 +57,7 @@ def exam_inference(placement_policy, model_name: str):
p.data.copy_(torch_p.data)
world_size = torch.distributed.get_world_size()
config_dict, _ = search_chunk_configuration(model, search_range_mb=1, search_interval_byte=100)
config_dict, *_ = search_chunk_configuration(model, search_range_mb=1, search_interval_byte=100)
config_dict[world_size]['chunk_size'] = 5000
config_dict[world_size]['keep_gathered'] = False
if placement_policy != 'cuda':

@ -63,7 +63,7 @@ def exam_model_step(placement_policy, model_name: str):
p.data.copy_(torch_p.data)
world_size = torch.distributed.get_world_size()
config_dict, _ = search_chunk_configuration(model, search_range_mb=1, search_interval_byte=100)
config_dict, *_ = search_chunk_configuration(model, search_range_mb=1, search_interval_byte=100)
config_dict[world_size]['chunk_size'] = 5000
config_dict[world_size]['keep_gathered'] = False
if placement_policy != 'cuda':

@ -6,7 +6,7 @@ import torch.distributed as dist
import torch.multiprocessing as mp
import colossalai
from colossalai.gemini.chunk import search_chunk_configuration
from colossalai.gemini.chunk import init_chunk_manager, search_chunk_configuration
from colossalai.tensor import ComputePattern, ComputeSpec, ProcessGroup, ShardSpec
from colossalai.testing import rerun_if_address_is_in_use
from colossalai.utils import free_port, get_current_device
@ -23,7 +23,6 @@ def init_1d_row_spec(model, pg: ProcessGroup):
def exam_search_chunk_size():
world_size = torch.distributed.get_world_size()
pg_tp = ProcessGroup(tp_degree=world_size)
@ -34,11 +33,11 @@ def exam_search_chunk_size():
with ColoInitContext(device=get_current_device()):
model = model_builder()
init_1d_row_spec(model, pg_tp)
config_dict, _ = search_chunk_configuration(model,
search_range_mb=1,
search_interval_byte=16,
min_chunk_size_mb=0,
filter_exlarge_params=True)
config_dict, *_ = search_chunk_configuration(model,
search_range_mb=1,
search_interval_byte=16,
min_chunk_size_mb=0,
filter_exlarge_params=True)
for key in config_dict:
chunk_size = config_dict[key]['chunk_size']
@ -48,9 +47,68 @@ def exam_search_chunk_size():
assert chunk_size == 1024
def exam_search_strict_ddp():
world_size = torch.distributed.get_world_size()
default_shard_pg = ProcessGroup(tp_degree=world_size)
default_shard_spec = ShardSpec([-1], [world_size])
get_components_func = non_distributed_component_funcs.get_callable('gpt2')
model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func()
# get the chunk configuration over replicated models
with ColoInitContext(device=get_current_device()):
ddp_model = model_builder()
re_dict, re_total, re_wasted = search_chunk_configuration(ddp_model,
search_range_mb=1,
search_interval_byte=16,
min_chunk_size_mb=0,
filter_exlarge_params=True,
strict_ddp_flag=False)
# get the chunk configuration over sharded ddp models
with ColoInitContext(device=get_current_device(), default_pg=default_shard_pg,
default_dist_spec=default_shard_spec):
sharded_ddp_model = model_builder()
sh_dict, sh_total, sh_wasted = search_chunk_configuration(sharded_ddp_model,
search_range_mb=1,
search_interval_byte=16,
min_chunk_size_mb=0,
filter_exlarge_params=True,
strict_ddp_flag=True)
assert re_dict == sh_dict
for key in re_dict:
assert re_dict[key] == sh_dict[key]
assert re_total == sh_total
assert re_wasted == sh_wasted
def exam_chunk_manager():
world_size = torch.distributed.get_world_size()
default_shard_pg = ProcessGroup(tp_degree=world_size)
default_shard_spec = ShardSpec([-1], [world_size])
get_components_func = non_distributed_component_funcs.get_callable('gpt2')
model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func()
with ColoInitContext(device=get_current_device(), default_pg=default_shard_pg,
default_dist_spec=default_shard_spec):
sharded_ddp_model = model_builder()
chunk_manager = init_chunk_manager(sharded_ddp_model,
get_current_device(),
hidden_dim=16,
search_range_mb=1,
min_chunk_size_mb=0,
filter_exlarge_params=True,
strict_ddp_flag=True)
config_dict = chunk_manager.dp_degree_chunk_size_dict
assert len(config_dict) == 1
assert config_dict[world_size] == 31616
def run_dist(rank, world_size, port):
colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
exam_search_chunk_size()
exam_search_strict_ddp()
exam_chunk_manager()
@pytest.mark.dist

@ -41,7 +41,7 @@ def exam_state_dict(placement_policy, keep_gathered, model_name: str):
torch_p.data.copy_(p.data)
world_size = torch.distributed.get_world_size()
config_dict, _ = search_chunk_configuration(model, search_range_mb=1, search_interval_byte=100)
config_dict, *_ = search_chunk_configuration(model, search_range_mb=1, search_interval_byte=100)
config_dict[world_size]['chunk_size'] = 5000
config_dict[world_size]['keep_gathered'] = keep_gathered
chunk_manager = ChunkManager(config_dict)
@ -73,7 +73,7 @@ def exam_load_state_dict(placement_policy, keep_gathered, model_name: str):
torch_model = model_builder() # get a different model
world_size = torch.distributed.get_world_size()
config_dict, _ = search_chunk_configuration(model, search_range_mb=1, search_interval_byte=100)
config_dict, *_ = search_chunk_configuration(model, search_range_mb=1, search_interval_byte=100)
config_dict[world_size]['chunk_size'] = 5000
config_dict[world_size]['keep_gathered'] = keep_gathered

@ -33,7 +33,7 @@ def exam_zero_optim_state_dict(placement_policy, keep_gathered):
torch_model = model_builder() # get a different model
world_size = torch.distributed.get_world_size()
config_dict, _ = search_chunk_configuration(model, search_range_mb=1, search_interval_byte=100)
config_dict, *_ = search_chunk_configuration(model, search_range_mb=1, search_interval_byte=100)
config_dict[world_size]['chunk_size'] = 5000
config_dict[world_size]['keep_gathered'] = keep_gathered

@ -85,7 +85,7 @@ def run_gpt(placement_policy, tp_init_spec_func=None):
tp_init_spec_func(model, pg)
dp_world_size = pg.dp_world_size()
config_dict, _ = search_chunk_configuration(model, search_range_mb=1, search_interval_byte=100)
config_dict, *_ = search_chunk_configuration(model, search_range_mb=1, search_interval_byte=100)
config_dict[dp_world_size]['chunk_size'] = 5000
config_dict[dp_world_size]['keep_gathered'] = False
if placement_policy != 'cuda':

Loading…
Cancel
Save