Merge pull request #4056 from Fridge003/hotfix/fix_gemini_chunk_config_searching

[gemini] Rename arguments in chunk configuration searching
pull/4046/merge
Baizhou Zhang 2023-06-25 17:37:37 +08:00 committed by GitHub
commit 2c8ae37f61
17 changed files with 62 additions and 64 deletions

@@ -181,11 +181,11 @@ class GeminiPlugin(DPPluginBase):
         pin_memory (bool, optional): use pin memory on CPU. Defaults to False.
         force_outputs_fp32 (bool, optional): force outputs are fp32. Defaults to False.
         strict_ddp_mode (bool, optional): use strict ddp mode (only use dp without other parallelism). Defaults to False.
-        search_range_mb (int, optional): chunk size searching range in MegaByte. Defaults to 32.
+        search_range_m (int, optional): chunk size searching range divided by 2^20. Defaults to 32.
         hidden_dim (int, optional): the hidden dimension of DNN.
             Users can provide this argument to speed up searching.
             If users do not know this argument before training, it is ok. We will use a default value 1024.
-        min_chunk_size_mb (float, optional): the minimum chunk size in MegaByte.
+        min_chunk_size_m (float, optional): the minimum chunk size divided by 2^20.
             If the aggregate size of parameters is still smaller than the minimum chunk size,
             all parameters will be compacted into one small chunk.
         memstats (MemStats, optional) the memory statistics collector by a runtime memory tracer.
@@ -214,9 +214,9 @@ class GeminiPlugin(DPPluginBase):
                  pin_memory: bool = False,
                  force_outputs_fp32: bool = False,
                  strict_ddp_mode: bool = False,
-                 search_range_mb: int = 32,
+                 search_range_m: int = 32,
                  hidden_dim: Optional[int] = None,
-                 min_chunk_size_mb: float = 32,
+                 min_chunk_size_m: float = 32,
                  memstats: Optional[MemStats] = None,
                  gpu_margin_mem_ratio: float = 0.0,
                  initial_scale: float = 2**32,
@@ -238,9 +238,9 @@ class GeminiPlugin(DPPluginBase):
                                  pin_memory=pin_memory,
                                  force_outputs_fp32=force_outputs_fp32,
                                  strict_ddp_mode=strict_ddp_mode,
-                                 search_range_mb=search_range_mb,
+                                 search_range_m=search_range_m,
                                  hidden_dim=hidden_dim,
-                                 min_chunk_size_mb=min_chunk_size_mb,
+                                 min_chunk_size_m=min_chunk_size_m,
                                  memstats=memstats,
                                  mixed_precision=PRECISION_STR_TO_DTYPE[precision],
                                  )
@@ -295,10 +295,7 @@ class GeminiPlugin(DPPluginBase):
         if optimizer is not None and \
                 not isinstance(optimizer, OptimizerWrapper):
-            optimizer = GeminiOptimizer(model.unwrap(),
-                                        optimizer,
-                                        self.zero_optim_config,
-                                        self.optim_kwargs,
+            optimizer = GeminiOptimizer(model.unwrap(), optimizer, self.zero_optim_config, self.optim_kwargs,
                                         self.verbose)
         return model, optimizer, criterion, dataloader, lr_scheduler
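
For downstream users, the rename only touches keyword arguments. Below is a minimal usage sketch with the renamed plugin arguments; the booster import paths are assumptions based on the same ColossalAI release and may differ between versions:

# Hedged sketch: constructing GeminiPlugin with the renamed arguments.
# `search_range_m` and `min_chunk_size_m` replace the old `*_mb` names;
# both values are scaled by 2^20 internally.
from colossalai.booster import Booster
from colossalai.booster.plugin import GeminiPlugin

plugin = GeminiPlugin(
    placement_policy='cpu',
    pin_memory=True,
    hidden_dim=1024,       # optional hint that narrows the chunk-size search
    search_range_m=32,     # was: search_range_mb=32
    min_chunk_size_m=32,   # was: min_chunk_size_mb=32
)
booster = Booster(plugin=plugin)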

@@ -114,9 +114,9 @@ def classify_params_by_dp_degree(param_order: OrderedParamGenerator,
 
 def search_chunk_configuration(
         model: nn.Module,
-        search_range_mb: float,
-        search_interval_byte: int,    # hidden size is the best value for the interval
-        min_chunk_size_mb: float = 32,
+        search_range_m: float,
+        search_interval: int,    # hidden size is the best value for the interval
+        min_chunk_size_m: float = 32,
         filter_exlarge_params: bool = True,
         strict_ddp_flag: bool = False,
         memstas: Optional[MemStats] = None) -> Tuple[Dict, int, int]:
@@ -126,9 +126,9 @@ def search_chunk_configuration(
 
     Args:
         model (nn.Module): torch module
-        search_range_mb (float): searching range in mega byte.
-        search_interval_byte (int): searching interval in byte.
-        min_chunk_size_mb (float, optional): the minimum size of a distributed chunk.
+        search_range_m (float): searching range divided by 2^20.
+        search_interval (int): searching interval.
+        min_chunk_size_m (float, optional): the minimum size of a distributed chunk, divided by 2^20.
         filter_exlarge_params (bool, optional): filter extreme large parameters. Defaults to True.
         strict_ddp_flag (bool, optional): whether to enable the strict ddp mode.
             all parameters keep replicated in this mode.
@@ -145,9 +145,9 @@ def search_chunk_configuration(
     for p in model.parameters():
         param_order.append(p)
 
-    search_range_byte = round(search_range_mb * 1024**2)
-    min_chunk_size_byte = round(min_chunk_size_mb * 1024**2)
-    assert search_range_byte >= 0
+    search_range = round(search_range_m * 1024**2)
+    min_chunk_size = round(min_chunk_size_m * 1024**2)
+    assert search_range >= 0
 
     params_dict = classify_params_by_dp_degree(param_order, strict_ddp_flag)
     size_lcm = np.lcm.reduce(list(params_dict.keys()))
@@ -162,7 +162,7 @@ def search_chunk_configuration(
         total_param_size += group_acc_size
 
         # let small parameters keep gathered in CUDA all the time
-        if group_acc_size < min_chunk_size_byte:
+        if group_acc_size < min_chunk_size:
             config_dict[dp_degree] = dict(chunk_size=group_acc_size, keep_gathered=True)
         else:
             size_dict[dp_degree] = size_list
@@ -170,15 +170,15 @@ def search_chunk_configuration(
     if filter_exlarge_params:
         _filter_exlarge_params(model, size_dict)
 
-    max_size = min_chunk_size_byte
+    max_size = min_chunk_size
     for key in size_dict:
         max_size = max(max_size, max(size_dict[key]))
-    start_size = int(math.ceil(max_size / search_interval_byte) * search_interval_byte)
+    start_size = int(math.ceil(max_size / search_interval) * search_interval)
 
     min_chunk_waste = float('+inf')
     best_chunk_size = start_size
 
-    for chunk_size in range(start_size, start_size + search_range_byte + 1, search_interval_byte):
+    for chunk_size in range(start_size, start_size + search_range + 1, search_interval):
         temp_waste = 0
         for key in size_dict:
             temp_waste += _get_unused_byte(size_dict[key], chunk_size)
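
The searcher can also be called directly. A sketch of a direct call with the new parameter names, mirroring the test updates later in this commit; per the annotation above it returns a (config_dict, total_size, wasted_size) tuple, and the import path is an assumption that may vary across releases:

# Hedged sketch of a direct call with the renamed parameters.
# search_range_m=1 scans a window of 1 * 2^20 candidate sizes;
# search_interval is the step between candidates (ideally the hidden size).
from colossalai.zero.gemini.chunk import search_chunk_configuration

config_dict, total_size, wasted_size = search_chunk_configuration(
    model,                       # an already-constructed nn.Module
    search_range_m=1,
    search_interval=100,
    min_chunk_size_m=0,
    filter_exlarge_params=True,
)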

@@ -23,10 +23,10 @@ def init_chunk_manager(model: nn.Module,
                        verbose: bool = False,
                        **kwargs) -> ChunkManager:
     if hidden_dim:
-        search_interval_byte = hidden_dim
+        search_interval = hidden_dim
     else:
-        search_interval_byte = 1024    # defaults to 1kb
-    kwargs["search_interval_byte"] = search_interval_byte
+        search_interval = 1024    # defaults to 1024
+    kwargs["search_interval"] = search_interval
 
     dist.barrier()
     begin = time()
@@ -36,13 +36,13 @@ def init_chunk_manager(model: nn.Module,
     dist.barrier()
     end = time()
     span_s = end - begin
-    mb_size = 1024**2
-    total_size /= mb_size
-    wasted_size /= mb_size
+    mega_unit = 1024**2
+    total_size /= mega_unit
+    wasted_size /= mega_unit
 
     if verbose and dist.get_rank() == 0:
         print("searching chunk configuration is completed in {:.2f} s.\n".format(span_s),
-              "used number: {:.2f} MB, wasted number: {:.2f} MB\n".format(total_size, wasted_size),
+              "used number: {:.2f} * 2^20, wasted number: {:.2f} * 2^20\n".format(total_size, wasted_size),
               "total wasted percentage is {:.2f}%".format(100 * safe_div(wasted_size, total_size + wasted_size)),
               sep='',
               flush=True)
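
The log change above captures the motivation for the rename: the searcher counts parameter elements rather than bytes, so labelling the 2^20-scaled totals "MB" was misleading. For example, search_range_m=1 means a window of 1 * 2^20 = 1,048,576 elements, whatever their dtype. A sketch of building a ChunkManager with the new names; the import path and init_device value are assumptions:

# Hedged sketch: building a ChunkManager with the renamed arguments.
import torch
from colossalai.zero.gemini.chunk import init_chunk_manager

chunk_manager = init_chunk_manager(
    model=model,
    init_device=torch.device('cuda'),
    hidden_dim=1024,       # becomes the search interval when provided
    search_range_m=1,      # window of 1 * 2^20 elements
    min_chunk_size_m=0,
)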

@@ -739,9 +739,9 @@ class GeminiDDP(ZeroDDP):
                  force_outputs_fp32: bool = False,
                  strict_ddp_mode: bool = False,
                  scatter_after_inference: bool = True,
-                 search_range_mb: int = 32,
+                 search_range_m: int = 32,
                  hidden_dim: Optional[int] = None,
-                 min_chunk_size_mb: float = 32,
+                 min_chunk_size_m: float = 32,
                  memstats: Optional[MemStats] = None,
                  mixed_precision: torch.dtype = torch.float16,
                  verbose: bool = False) -> None:
@@ -763,24 +763,24 @@ class GeminiDDP(ZeroDDP):
             placement_policy (str, optional): "cpu", "cuda", "auto". Defaults to "cpu".
             pin_memory (bool, optional): use pin memory on CPU. Defaults to False.
             force_outputs_fp32 (bool, optional): force outputs are fp32. Defaults to False.
-            search_range_mb (int, optional): chunk size searching range in MegaByte. Defaults to 32.
+            search_range_m (int, optional): chunk size searching range divided by 2^20. Defaults to 32.
             hidden_dim (int, optional): the hidden dimension of DNN.
                 Users can provide this argument to speed up searching.
                 If users do not know this argument before training, it is ok. We will use a default value 1024.
-            min_chunk_size_mb (float, optional): the minimum chunk size in MegaByte.
+            min_chunk_size_m (float, optional): the minimum chunk size divided by 2^20.
                 If the aggregate size of parameters is still smaller than the minimum chunk size,
                 all parameters will be compacted into one small chunk.
             memstats (MemStats, optional) the memory statistics collector by a runtime memory tracer.
         """
         # some ugly hotfix for the compatibility with Lightning
-        if search_range_mb is None:
-            search_range_mb = 32
+        if search_range_m is None:
+            search_range_m = 32
 
         chunk_manager = init_chunk_manager(model=module,
                                            init_device=device,
                                            hidden_dim=hidden_dim,
-                                           search_range_mb=search_range_mb,
-                                           min_chunk_size_mb=min_chunk_size_mb,
+                                           search_range_m=search_range_m,
+                                           min_chunk_size_m=min_chunk_size_m,
                                            strict_ddp_flag=strict_ddp_mode,
                                            verbose=verbose)
         gemini_manager = GeminiManager(placement_policy, chunk_manager, memstats)
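
The same rename applies when wrapping a module with GeminiDDP directly. A sketch against the constructor shown above; the import path is an assumption that may differ across releases:

# Hedged sketch: wrapping a module with GeminiDDP using the new names.
# Note the Lightning-compatibility hotfix above: search_range_m=None
# silently falls back to the default of 32.
import torch
from colossalai.zero import GeminiDDP

model = GeminiDDP(
    module,
    device=torch.device('cuda'),
    placement_policy='auto',
    search_range_m=32,       # was: search_range_mb=32
    min_chunk_size_m=32,     # was: min_chunk_size_mb=32
    mixed_precision=torch.float16,
)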

@@ -60,7 +60,7 @@ def exam_fwd_bwd(model_name: str, memory_budget: float, solver_name: str):
                          placement_policy='cpu',
                          pin_memory=True,
                          hidden_dim=8192,
-                         search_range_mb=128)
+                         search_range_m=128)
     gemini_model = zero_model_wrapper(gemini_model, 3, gemini_config)
     optim_config = dict(reduce_bucket_size=12 * 1024 * 1024, overlap_communication=True, verbose=True)
     gemini_optim = zero_optim_wrapper(gemini_model, hybrid_optimizer, optim_config=optim_config)
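
Call sites that build a gemini_config dict for zero_model_wrapper, as this test does, only need the keyword renamed; everything else is unchanged:

# Hedged sketch of the config-dict migration, using identifiers from this diff.
gemini_config = dict(device=get_current_device(),
                     placement_policy='cpu',
                     pin_memory=True,
                     hidden_dim=8192,
                     search_range_m=128)    # was: search_range_mb=128
gemini_model = zero_model_wrapper(gemini_model, 3, gemini_config)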

@@ -75,7 +75,7 @@ def check_auto_parallel_with_gemini(rank, world_size, port):
                          device=get_current_device(),
                          placement_policy='cpu',
                          pin_memory=True,
-                         search_range_mb=128)
+                         search_range_m=128)
     post_process_colo_init_ctx(gm, device=get_current_device(), default_pg=dp_process_group)
     gm = zero_model_wrapper(gm, zero_stage=3, gemini_config=gemini_config)

@@ -30,7 +30,7 @@ def exam_state_dict_with_origin(placement_policy, model_name, use_safetensors: bool):
     bert_model.config.save_pretrained(save_directory=pretrained_path)
 
     # TODO(ver217): use boost api
-    config_dict, *_ = search_chunk_configuration(bert_model, search_range_mb=1, search_interval_byte=100)
+    config_dict, *_ = search_chunk_configuration(bert_model, search_range_m=1, search_interval=100)
     chunk_manager = ChunkManager(config_dict)
     gemini_manager = GeminiManager(placement_policy, chunk_manager)
     bert_model = ZeroDDP(bert_model, gemini_manager)
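
All of the test updates that follow share the same manual pipeline: search a chunk configuration, build a ChunkManager and a GeminiManager from it, then wrap the model in ZeroDDP. Condensed for reference, with identifiers taken from the diffs themselves:

# Hedged sketch of the manual wiring repeated throughout these tests.
config_dict, *_ = search_chunk_configuration(model, search_range_m=1, search_interval=100)
chunk_manager = ChunkManager(config_dict)
gemini_manager = GeminiManager('cpu', chunk_manager)    # placement policy, e.g. 'cpu'
model = ZeroDDP(model, gemini_manager, pin_memory=True)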

@@ -79,7 +79,7 @@ def run_gpt(placement_policy, tp_init_spec_func=None):
         tp_init_spec_func(model, pg)
 
     dp_world_size = pg.dp_world_size()
-    config_dict, *_ = search_chunk_configuration(model, search_range_mb=1, search_interval_byte=100)
+    config_dict, *_ = search_chunk_configuration(model, search_range_m=1, search_interval=100)
     config_dict[dp_world_size]['chunk_size'] = 5000
     config_dict[dp_world_size]['keep_gathered'] = False
     if placement_policy != 'cuda':
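
The in-place overrides above work because search_chunk_configuration returns a dict keyed by data-parallel degree, with one chunk_size/keep_gathered entry per degree (see the config_dict[dp_degree] = dict(...) line earlier in this diff). The shape, with illustrative values:

# config_dict maps dp_degree -> {'chunk_size': <elements per chunk>, 'keep_gathered': <bool>}
config_dict, *_ = search_chunk_configuration(model, search_range_m=1, search_interval=100)
config_dict[dp_world_size]['chunk_size'] = 5000      # force a fixed chunk size
config_dict[dp_world_size]['keep_gathered'] = False  # allow chunks to be scattered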

@@ -52,7 +52,7 @@ def exam_gpt_fwd_bwd(
         torch_p.data.copy_(p.data)
 
     world_size = torch.distributed.get_world_size()
-    config_dict, *_ = search_chunk_configuration(model, search_range_mb=1, search_interval_byte=100)
+    config_dict, *_ = search_chunk_configuration(model, search_range_m=1, search_interval=100)
     config_dict[world_size]['chunk_size'] = 5000
     config_dict[world_size]['keep_gathered'] = keep_gather
     chunk_manager = ChunkManager(config_dict)
@@ -113,7 +113,7 @@ def exam_gpt_inference(
         torch_p.data.copy_(p.data)
 
     world_size = torch.distributed.get_world_size()
-    config_dict, *_ = search_chunk_configuration(model, search_range_mb=1, search_interval_byte=100)
+    config_dict, *_ = search_chunk_configuration(model, search_range_m=1, search_interval=100)
     config_dict[world_size]['chunk_size'] = 5000
     config_dict[world_size]['keep_gathered'] = keep_gather
     chunk_manager = ChunkManager(config_dict)

@@ -56,7 +56,7 @@ def run_gemini_use_rmt(placement_policy, keep_gather, model_name: str, use_grad_
     assert len(step_list) == 4
 
     world_size = torch.distributed.get_world_size()
-    config_dict, *_ = search_chunk_configuration(model, search_range_mb=1, search_interval_byte=100)
+    config_dict, *_ = search_chunk_configuration(model, search_range_m=1, search_interval=100)
     config_dict[world_size]['chunk_size'] = 5000
     config_dict[world_size]['keep_gathered'] = keep_gather
     chunk_manager = ChunkManager(config_dict)

@@ -51,7 +51,7 @@ def exam_grad_clipping(placement_policy, model_name: str):
         p.data.copy_(torch_p.data)
 
     world_size = torch.distributed.get_world_size()
-    config_dict, *_ = search_chunk_configuration(model, search_range_mb=1, search_interval_byte=100)
+    config_dict, *_ = search_chunk_configuration(model, search_range_m=1, search_interval=100)
     config_dict[world_size]['chunk_size'] = 5000
     config_dict[world_size]['keep_gathered'] = False
     if placement_policy != 'cuda':

@@ -34,7 +34,7 @@ def check_param(model: ZeroDDP, torch_model: torch.nn.Module):
 
 def multi_chunk_init(model: torch.nn.Module, placement_policy: str):
     world_size = dist.get_world_size()
-    config_dict, *_ = search_chunk_configuration(model, search_range_mb=1, search_interval_byte=100)
+    config_dict, *_ = search_chunk_configuration(model, search_range_m=1, search_interval=100)
     config_dict[world_size]['chunk_size'] = 5000
     config_dict[world_size]['keep_gathered'] = False
     if placement_policy != 'cuda':

@@ -73,7 +73,7 @@ def exam_model_step(placement_policy, model_name: str, mixed_precision: torch.dtype):
         p.data.copy_(torch_p.data)
 
     world_size = torch.distributed.get_world_size()
-    config_dict, *_ = search_chunk_configuration(model, search_range_mb=1, search_interval_byte=100)
+    config_dict, *_ = search_chunk_configuration(model, search_range_m=1, search_interval=100)
     config_dict[world_size]['chunk_size'] = 5000
     config_dict[world_size]['keep_gathered'] = False
     if placement_policy != 'cuda':
@@ -130,7 +130,7 @@ def exam_tiny_example(placement_policy, model_name: str, mixed_precision: torch.dtype):
     for torch_p, p in zip(torch_model.parameters(), model.parameters()):
         p.data.copy_(torch_p.data)
 
-    chunk_manager = init_chunk_manager(model=model, init_device=get_current_device(), search_range_mb=1)
+    chunk_manager = init_chunk_manager(model=model, init_device=get_current_device(), search_range_m=1)
     gemini_manager = GeminiManager(placement_policy, chunk_manager)
     model = ZeroDDP(model, gemini_manager, pin_memory=True, mixed_precision=mixed_precision)
     optimizer = HybridAdam(model.parameters(), lr=1e-3)

@@ -30,9 +30,9 @@ def exam_search_chunk_size():
         model = model_builder()
     init_1d_row_spec(model, pg_tp)
     config_dict, *_ = search_chunk_configuration(model,
-                                                 search_range_mb=1,
-                                                 search_interval_byte=16,
-                                                 min_chunk_size_mb=0,
+                                                 search_range_m=1,
+                                                 search_interval=16,
+                                                 min_chunk_size_m=0,
                                                  filter_exlarge_params=True)
 
     for key in config_dict:
@@ -54,9 +54,9 @@ def exam_search_strict_ddp():
     with ColoInitContext(device=get_current_device()):
         ddp_model = model_builder()
     re_dict, re_total, re_wasted = search_chunk_configuration(ddp_model,
-                                                              search_range_mb=1,
-                                                              search_interval_byte=16,
-                                                              min_chunk_size_mb=0,
+                                                              search_range_m=1,
+                                                              search_interval=16,
+                                                              min_chunk_size_m=0,
                                                               filter_exlarge_params=True,
                                                               strict_ddp_flag=False)
     # get the chunk configuration over sharded ddp models
@@ -64,9 +64,9 @@ def exam_search_strict_ddp():
                          default_dist_spec=default_shard_spec):
         sharded_ddp_model = model_builder()
     sh_dict, sh_total, sh_wasted = search_chunk_configuration(sharded_ddp_model,
-                                                              search_range_mb=1,
-                                                              search_interval_byte=16,
-                                                              min_chunk_size_mb=0,
+                                                              search_range_m=1,
+                                                              search_interval=16,
+                                                              min_chunk_size_m=0,
                                                               filter_exlarge_params=True,
                                                               strict_ddp_flag=True)
     assert re_dict == sh_dict
@@ -91,8 +91,8 @@ def exam_chunk_manager():
     chunk_manager = init_chunk_manager(sharded_ddp_model,
                                        get_current_device(),
                                        hidden_dim=16,
-                                       search_range_mb=1,
-                                       min_chunk_size_mb=0,
+                                       search_range_m=1,
+                                       min_chunk_size_m=0,
                                        filter_exlarge_params=True,
                                        strict_ddp_flag=True)
     config_dict = chunk_manager.dp_degree_chunk_size_dict
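
Judging by the attribute read above, the manager keeps the chosen chunk size per data-parallel degree. A sketch of inspecting it, under the assumption that dp_degree_chunk_size_dict maps dp degree to chunk size in elements:

# Hedged sketch: reading the search result back from the manager.
for dp_degree, chunk_size in chunk_manager.dp_degree_chunk_size_dict.items():
    print(f"dp degree {dp_degree}: chunk size {chunk_size} elements")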

@@ -35,7 +35,7 @@ def exam_state_dict(placement_policy, keep_gathered, model_name: str):
         torch_p.data.copy_(p.data)
 
     world_size = torch.distributed.get_world_size()
-    config_dict, *_ = search_chunk_configuration(model, search_range_mb=1, search_interval_byte=100)
+    config_dict, *_ = search_chunk_configuration(model, search_range_m=1, search_interval=100)
     config_dict[world_size]['chunk_size'] = 5000
     config_dict[world_size]['keep_gathered'] = keep_gathered
     chunk_manager = ChunkManager(config_dict)
@@ -67,7 +67,7 @@ def exam_load_state_dict(placement_policy, keep_gathered, model_name: str):
     torch_model = model_builder()    # get a different model
 
     world_size = torch.distributed.get_world_size()
-    config_dict, *_ = search_chunk_configuration(model, search_range_mb=1, search_interval_byte=100)
+    config_dict, *_ = search_chunk_configuration(model, search_range_m=1, search_interval=100)
     config_dict[world_size]['chunk_size'] = 5000
     config_dict[world_size]['keep_gathered'] = keep_gathered

@@ -22,7 +22,7 @@ def exam_state_dict(placement_policy, model_name: str):
     model_size = sum(p.numel() * p.element_size() for p in model.parameters()) / 1024**2
 
-    config_dict, *_ = search_chunk_configuration(model, search_range_mb=1, search_interval_byte=100)
+    config_dict, *_ = search_chunk_configuration(model, search_range_m=1, search_interval=100)
     chunk_manager = ChunkManager(config_dict)
     gemini_manager = GeminiManager(placement_policy, chunk_manager)
     model = ZeroDDP(model, gemini_manager)
@@ -38,6 +38,7 @@ def exam_state_dict(placement_policy, model_name: str):
         assert key in zero_dict, f"{key} not in ZeRO dictionary."
         assert torch.equal(value, zero_dict[key]), f"{key} not equal."
 
+
 def run_dist(rank, world_size, port):
     config = {}
     colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')

@@ -27,7 +27,7 @@ def exam_zero_optim_state_dict(placement_policy, keep_gathered):
     torch_model = model_builder()    # get a different model
 
     world_size = torch.distributed.get_world_size()
-    config_dict, *_ = search_chunk_configuration(model, search_range_mb=1, search_interval_byte=100)
+    config_dict, *_ = search_chunk_configuration(model, search_range_m=1, search_interval=100)
     config_dict[world_size]['chunk_size'] = 5000
     config_dict[world_size]['keep_gathered'] = keep_gathered