mirror of https://github.com/hpcaitech/ColossalAI
Merge pull request #4056 from Fridge003/hotfix/fix_gemini_chunk_config_searching
[gemini] Rename arguments in chunk configuration searching
commit 2c8ae37f61
@@ -181,11 +181,11 @@ class GeminiPlugin(DPPluginBase):
 pin_memory (bool, optional): use pin memory on CPU. Defaults to False.
 force_outputs_fp32 (bool, optional): force outputs are fp32. Defaults to False.
 strict_ddp_mode (bool, optional): use strict ddp mode (only use dp without other parallelism). Defaults to False.
-search_range_mb (int, optional): chunk size searching range in MegaByte. Defaults to 32.
+search_range_m (int, optional): chunk size searching range divided by 2^20. Defaults to 32.
 hidden_dim (int, optional): the hidden dimension of DNN.
 Users can provide this argument to speed up searching.
 If users do not know this argument before training, it is ok. We will use a default value 1024.
-min_chunk_size_mb (float, optional): the minimum chunk size in MegaByte.
+min_chunk_size_m (float, optional): the minimum chunk size divided by 2^20.
 If the aggregate size of parameters is still smaller than the minimum chunk size,
 all parameters will be compacted into one small chunk.
 memstats (MemStats, optional) the memory statistics collector by a runtime memory tracer.

@@ -214,9 +214,9 @@ class GeminiPlugin(DPPluginBase):
 pin_memory: bool = False,
 force_outputs_fp32: bool = False,
 strict_ddp_mode: bool = False,
-search_range_mb: int = 32,
+search_range_m: int = 32,
 hidden_dim: Optional[int] = None,
-min_chunk_size_mb: float = 32,
+min_chunk_size_m: float = 32,
 memstats: Optional[MemStats] = None,
 gpu_margin_mem_ratio: float = 0.0,
 initial_scale: float = 2**32,

@@ -238,9 +238,9 @@ class GeminiPlugin(DPPluginBase):
 pin_memory=pin_memory,
 force_outputs_fp32=force_outputs_fp32,
 strict_ddp_mode=strict_ddp_mode,
-search_range_mb=search_range_mb,
+search_range_m=search_range_m,
 hidden_dim=hidden_dim,
-min_chunk_size_mb=min_chunk_size_mb,
+min_chunk_size_m=min_chunk_size_m,
 memstats=memstats,
 mixed_precision=PRECISION_STR_TO_DTYPE[precision],
 )
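For call sites that need updating, here is a minimal sketch of the plugin with the renamed keywords; the Booster import path and all values are illustrative assumptions, not part of this diff, and it belongs inside a training script that has already called colossalai.launch:

    from colossalai.booster import Booster              # assumed import path
    from colossalai.booster.plugin import GeminiPlugin  # assumed import path

    # search_range_m / min_chunk_size_m replace search_range_mb / min_chunk_size_mb;
    # both count multiples of 2^20 rather than megabytes.
    plugin = GeminiPlugin(
        pin_memory=True,
        search_range_m=64,       # chunk-size search range
        hidden_dim=1024,         # a known hidden size speeds up the search
        min_chunk_size_m=32,     # smaller parameter groups collapse into one gathered chunk
    )
    booster = Booster(plugin=plugin)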
@@ -295,10 +295,7 @@ class GeminiPlugin(DPPluginBase):

 if optimizer is not None and \
 not isinstance(optimizer, OptimizerWrapper):
-optimizer = GeminiOptimizer(model.unwrap(),
-optimizer,
-self.zero_optim_config,
-self.optim_kwargs,
+optimizer = GeminiOptimizer(model.unwrap(), optimizer, self.zero_optim_config, self.optim_kwargs,
 self.verbose)

 return model, optimizer, criterion, dataloader, lr_scheduler

@@ -114,9 +114,9 @@ def classify_params_by_dp_degree(param_order: OrderedParamGenerator,

 def search_chunk_configuration(
 model: nn.Module,
-search_range_mb: float,
-search_interval_byte: int,    # hidden size is the best value for the interval
-min_chunk_size_mb: float = 32,
+search_range_m: float,
+search_interval: int,    # hidden size is the best value for the interval
+min_chunk_size_m: float = 32,
 filter_exlarge_params: bool = True,
 strict_ddp_flag: bool = False,
 memstas: Optional[MemStats] = None) -> Tuple[Dict, int, int]:

@@ -126,9 +126,9 @@ def search_chunk_configuration(

 Args:
 model (nn.Module): torch module
-search_range_mb (float): searching range in mega byte.
-search_interval_byte (int): searching interval in byte.
-min_chunk_size_mb (float, optional): the minimum size of a distributed chunk.
+search_range_m (float): searching range divided by 2^20.
+search_interval (int): searching interval.
+min_chunk_size_m (float, optional): the minimum size of a distributed chunk, divided by 2^20..
 filter_exlarge_params (bool, optional): filter extreme large parameters. Defaults to True.
 strict_ddp_flag (bool, optional): whether to enable the strict ddp mode.
 all parameters keep replicated in this mode.

@@ -145,9 +145,9 @@ def search_chunk_configuration(
 for p in model.parameters():
 param_order.append(p)

-search_range_byte = round(search_range_mb * 1024**2)
-min_chunk_size_byte = round(min_chunk_size_mb * 1024**2)
-assert search_range_byte >= 0
+search_range = round(search_range_m * 1024**2)
+min_chunk_size = round(min_chunk_size_m * 1024**2)
+assert search_range >= 0

 params_dict = classify_params_by_dp_degree(param_order, strict_ddp_flag)
 size_lcm = np.lcm.reduce(list(params_dict.keys()))
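The rename drops the "mb"/"byte" wording while the 2^20 scaling stays the same, as the hunk above shows. A quick worked check of that conversion (standalone Python, illustrative values):

    # search_range_m and min_chunk_size_m are multiples of 2^20, not megabytes of memory.
    search_range_m = 32
    min_chunk_size_m = 32

    search_range = round(search_range_m * 1024**2)        # 33_554_432
    min_chunk_size = round(min_chunk_size_m * 1024**2)    # 33_554_432
    assert search_range == 32 * 2**20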
@@ -162,7 +162,7 @@ def search_chunk_configuration(
 total_param_size += group_acc_size

 # let small parameters keep gathered in CUDA all the time
-if group_acc_size < min_chunk_size_byte:
+if group_acc_size < min_chunk_size:
 config_dict[dp_degree] = dict(chunk_size=group_acc_size, keep_gathered=True)
 else:
 size_dict[dp_degree] = size_list

@@ -170,15 +170,15 @@ def search_chunk_configuration(
 if filter_exlarge_params:
 _filter_exlarge_params(model, size_dict)

-max_size = min_chunk_size_byte
+max_size = min_chunk_size
 for key in size_dict:
 max_size = max(max_size, max(size_dict[key]))
-start_size = int(math.ceil(max_size / search_interval_byte) * search_interval_byte)
+start_size = int(math.ceil(max_size / search_interval) * search_interval)

 min_chunk_waste = float('+inf')
 best_chunk_size = start_size

-for chunk_size in range(start_size, start_size + search_range_byte + 1, search_interval_byte):
+for chunk_size in range(start_size, start_size + search_range + 1, search_interval):
 temp_waste = 0
 for key in size_dict:
 temp_waste += _get_unused_byte(size_dict[key], chunk_size)
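For readers following the renamed search loop above, a self-contained sketch of the same waste-minimizing scan; unused_space below is a hypothetical stand-in for the _get_unused_byte helper referenced in the diff, and the sizes are made up:

    import math
    from typing import Dict, List

    def unused_space(size_list: List[int], chunk_size: int) -> int:
        # Space left over when tensors of these sizes are packed greedily into fixed-size chunks.
        wasted, left = 0, 0
        for s in size_list:
            if s > left:
                wasted += left        # the tail of the current chunk cannot hold the next tensor
                left = chunk_size
            left -= s
        return wasted + left

    def pick_chunk_size(size_dict: Dict[int, List[int]], max_size: int,
                        search_range: int, search_interval: int) -> int:
        # Start at the largest group size rounded up to the interval, then scan the range.
        start_size = int(math.ceil(max_size / search_interval) * search_interval)
        best_size, min_waste = start_size, float('+inf')
        for chunk_size in range(start_size, start_size + search_range + 1, search_interval):
            waste = sum(unused_space(sizes, chunk_size) for sizes in size_dict.values())
            if waste < min_waste:
                min_waste, best_size = waste, chunk_size
        return best_size

    # Two data-parallel groups with a few tensor sizes each (made-up numbers).
    sizes = {1: [1200, 800, 4096], 2: [2048, 2048]}
    print(pick_chunk_size(sizes, max_size=4096, search_range=2**20, search_interval=1024))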
@@ -23,10 +23,10 @@ def init_chunk_manager(model: nn.Module,
 verbose: bool = False,
 **kwargs) -> ChunkManager:
 if hidden_dim:
-search_interval_byte = hidden_dim
+search_interval = hidden_dim
 else:
-search_interval_byte = 1024    # defaults to 1kb
-kwargs["search_interval_byte"] = search_interval_byte
+search_interval = 1024    # defaults to 1024
+kwargs["search_interval"] = search_interval

 dist.barrier()
 begin = time()

@@ -36,13 +36,13 @@ def init_chunk_manager(model: nn.Module,
 dist.barrier()
 end = time()
 span_s = end - begin
-mb_size = 1024**2
-total_size /= mb_size
-wasted_size /= mb_size
+mega_unit = 1024**2
+total_size /= mega_unit
+wasted_size /= mega_unit

 if verbose and dist.get_rank() == 0:
 print("searching chunk configuration is completed in {:.2f} s.\n".format(span_s),
-"used number: {:.2f} MB, wasted number: {:.2f} MB\n".format(total_size, wasted_size),
+"used number: {:.2f} * 2^20, wasted number: {:.2f} * 2^20\n".format(total_size, wasted_size),
 "total wasted percentage is {:.2f}%".format(100 * safe_div(wasted_size, total_size + wasted_size)),
 sep='',
 flush=True)
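The log change above matches the rename: the figures are printed divided by 2^20 and the "MB" label is dropped. If those totals are element counts rather than bytes (which the rename suggests), the byte footprint additionally depends on the dtype; a rough illustration with assumed numbers:

    numel = 125_000_000                                      # illustrative parameter count
    print(f"used number: {numel / 2**20:.2f} * 2^20")        # style of the new log line
    print(f"fp16 footprint: {numel * 2 / 2**20:.2f} MB")     # a byte-based figure, for contrast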
@@ -739,9 +739,9 @@ class GeminiDDP(ZeroDDP):
 force_outputs_fp32: bool = False,
 strict_ddp_mode: bool = False,
 scatter_after_inference: bool = True,
-search_range_mb: int = 32,
+search_range_m: int = 32,
 hidden_dim: Optional[int] = None,
-min_chunk_size_mb: float = 32,
+min_chunk_size_m: float = 32,
 memstats: Optional[MemStats] = None,
 mixed_precision: torch.dtype = torch.float16,
 verbose: bool = False) -> None:

@@ -763,24 +763,24 @@ class GeminiDDP(ZeroDDP):
 placement_policy (str, optional): "cpu", "cuda", "auto". Defaults to "cpu".
 pin_memory (bool, optional): use pin memory on CPU. Defaults to False.
 force_outputs_fp32 (bool, optional): force outputs are fp32. Defaults to False.
-search_range_mb (int, optional): chunk size searching range in MegaByte. Defaults to 32.
+search_range_m (int, optional): chunk size searching range divided by 2^20. Defaults to 32.
 hidden_dim (int, optional): the hidden dimension of DNN.
 Users can provide this argument to speed up searching.
 If users do not know this argument before training, it is ok. We will use a default value 1024.
-min_chunk_size_mb (float, optional): the minimum chunk size in MegaByte.
+min_chunk_size_m (float, optional): the minimum chunk size divided by 2^20.
 If the aggregate size of parameters is still smaller than the minimum chunk size,
 all parameters will be compacted into one small chunk.
 memstats (MemStats, optional) the memory statistics collector by a runtime memory tracer.
 """
 # some ugly hotfix for the compatibility with Lightning
-if search_range_mb is None:
-search_range_mb = 32
+if search_range_m is None:
+search_range_m = 32

 chunk_manager = init_chunk_manager(model=module,
 init_device=device,
 hidden_dim=hidden_dim,
-search_range_mb=search_range_mb,
-min_chunk_size_mb=min_chunk_size_mb,
+search_range_m=search_range_m,
+min_chunk_size_m=min_chunk_size_m,
 strict_ddp_flag=strict_ddp_mode,
 verbose=verbose)
 gemini_manager = GeminiManager(placement_policy, chunk_manager, memstats)
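Likewise, for code that constructs GeminiDDP directly, a minimal sketch with the renamed keywords; the surrounding setup (colossalai.launch, imports, and building module under Gemini's init context) is assumed and omitted, and the values are illustrative:

    # Assumes a launched distributed environment and an existing torch.nn.Module `module`.
    model = GeminiDDP(module,
                      device=get_current_device(),
                      placement_policy='cpu',
                      pin_memory=True,
                      hidden_dim=1024,
                      search_range_m=32,        # was search_range_mb
                      min_chunk_size_m=32)      # was min_chunk_size_mb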
@@ -60,7 +60,7 @@ def exam_fwd_bwd(model_name: str, memory_budget: float, solver_name: str):
 placement_policy='cpu',
 pin_memory=True,
 hidden_dim=8192,
-search_range_mb=128)
+search_range_m=128)
 gemini_model = zero_model_wrapper(gemini_model, 3, gemini_config)
 optim_config = dict(reduce_bucket_size=12 * 1024 * 1024, overlap_communication=True, verbose=True)
 gemini_optim = zero_optim_wrapper(gemini_model, hybrid_optimizer, optim_config=optim_config)

@@ -75,7 +75,7 @@ def check_auto_parallel_with_gemini(rank, world_size, port):
 device=get_current_device(),
 placement_policy='cpu',
 pin_memory=True,
-search_range_mb=128)
+search_range_m=128)

 post_process_colo_init_ctx(gm, device=get_current_device(), default_pg=dp_process_group)
 gm = zero_model_wrapper(gm, zero_stage=3, gemini_config=gemini_config)

@@ -30,7 +30,7 @@ def exam_state_dict_with_origin(placement_policy, model_name, use_safetensors: b
 bert_model.config.save_pretrained(save_directory=pretrained_path)

 # TODO(ver217): use boost api
-config_dict, *_ = search_chunk_configuration(bert_model, search_range_mb=1, search_interval_byte=100)
+config_dict, *_ = search_chunk_configuration(bert_model, search_range_m=1, search_interval=100)
 chunk_manager = ChunkManager(config_dict)
 gemini_manager = GeminiManager(placement_policy, chunk_manager)
 bert_model = ZeroDDP(bert_model, gemini_manager)
@@ -79,7 +79,7 @@ def run_gpt(placement_policy, tp_init_spec_func=None):
 tp_init_spec_func(model, pg)

 dp_world_size = pg.dp_world_size()
-config_dict, *_ = search_chunk_configuration(model, search_range_mb=1, search_interval_byte=100)
+config_dict, *_ = search_chunk_configuration(model, search_range_m=1, search_interval=100)
 config_dict[dp_world_size]['chunk_size'] = 5000
 config_dict[dp_world_size]['keep_gathered'] = False
 if placement_policy != 'cuda':

@@ -52,7 +52,7 @@ def exam_gpt_fwd_bwd(
 torch_p.data.copy_(p.data)

 world_size = torch.distributed.get_world_size()
-config_dict, *_ = search_chunk_configuration(model, search_range_mb=1, search_interval_byte=100)
+config_dict, *_ = search_chunk_configuration(model, search_range_m=1, search_interval=100)
 config_dict[world_size]['chunk_size'] = 5000
 config_dict[world_size]['keep_gathered'] = keep_gather
 chunk_manager = ChunkManager(config_dict)

@@ -113,7 +113,7 @@ def exam_gpt_inference(
 torch_p.data.copy_(p.data)

 world_size = torch.distributed.get_world_size()
-config_dict, *_ = search_chunk_configuration(model, search_range_mb=1, search_interval_byte=100)
+config_dict, *_ = search_chunk_configuration(model, search_range_m=1, search_interval=100)
 config_dict[world_size]['chunk_size'] = 5000
 config_dict[world_size]['keep_gathered'] = keep_gather
 chunk_manager = ChunkManager(config_dict)

@@ -56,7 +56,7 @@ def run_gemini_use_rmt(placement_policy, keep_gather, model_name: str, use_grad_
 assert len(step_list) == 4

 world_size = torch.distributed.get_world_size()
-config_dict, *_ = search_chunk_configuration(model, search_range_mb=1, search_interval_byte=100)
+config_dict, *_ = search_chunk_configuration(model, search_range_m=1, search_interval=100)
 config_dict[world_size]['chunk_size'] = 5000
 config_dict[world_size]['keep_gathered'] = keep_gather
 chunk_manager = ChunkManager(config_dict)

@@ -51,7 +51,7 @@ def exam_grad_clipping(placement_policy, model_name: str):
 p.data.copy_(torch_p.data)

 world_size = torch.distributed.get_world_size()
-config_dict, *_ = search_chunk_configuration(model, search_range_mb=1, search_interval_byte=100)
+config_dict, *_ = search_chunk_configuration(model, search_range_m=1, search_interval=100)
 config_dict[world_size]['chunk_size'] = 5000
 config_dict[world_size]['keep_gathered'] = False
 if placement_policy != 'cuda':
@@ -34,7 +34,7 @@ def check_param(model: ZeroDDP, torch_model: torch.nn.Module):

 def multi_chunk_init(model: torch.nn.Module, placement_policy: str):
 world_size = dist.get_world_size()
-config_dict, *_ = search_chunk_configuration(model, search_range_mb=1, search_interval_byte=100)
+config_dict, *_ = search_chunk_configuration(model, search_range_m=1, search_interval=100)
 config_dict[world_size]['chunk_size'] = 5000
 config_dict[world_size]['keep_gathered'] = False
 if placement_policy != 'cuda':

@@ -73,7 +73,7 @@ def exam_model_step(placement_policy, model_name: str, mixed_precision: torch.dt
 p.data.copy_(torch_p.data)

 world_size = torch.distributed.get_world_size()
-config_dict, *_ = search_chunk_configuration(model, search_range_mb=1, search_interval_byte=100)
+config_dict, *_ = search_chunk_configuration(model, search_range_m=1, search_interval=100)
 config_dict[world_size]['chunk_size'] = 5000
 config_dict[world_size]['keep_gathered'] = False
 if placement_policy != 'cuda':

@@ -130,7 +130,7 @@ def exam_tiny_example(placement_policy, model_name: str, mixed_precision: torch.
 for torch_p, p in zip(torch_model.parameters(), model.parameters()):
 p.data.copy_(torch_p.data)

-chunk_manager = init_chunk_manager(model=model, init_device=get_current_device(), search_range_mb=1)
+chunk_manager = init_chunk_manager(model=model, init_device=get_current_device(), search_range_m=1)
 gemini_manager = GeminiManager(placement_policy, chunk_manager)
 model = ZeroDDP(model, gemini_manager, pin_memory=True, mixed_precision=mixed_precision)
 optimizer = HybridAdam(model.parameters(), lr=1e-3)
@@ -30,9 +30,9 @@ def exam_search_chunk_size():
 model = model_builder()
 init_1d_row_spec(model, pg_tp)
 config_dict, *_ = search_chunk_configuration(model,
-search_range_mb=1,
-search_interval_byte=16,
-min_chunk_size_mb=0,
+search_range_m=1,
+search_interval=16,
+min_chunk_size_m=0,
 filter_exlarge_params=True)

 for key in config_dict:

@@ -54,9 +54,9 @@ def exam_search_strict_ddp():
 with ColoInitContext(device=get_current_device()):
 ddp_model = model_builder()
 re_dict, re_total, re_wasted = search_chunk_configuration(ddp_model,
-search_range_mb=1,
-search_interval_byte=16,
-min_chunk_size_mb=0,
+search_range_m=1,
+search_interval=16,
+min_chunk_size_m=0,
 filter_exlarge_params=True,
 strict_ddp_flag=False)
 # get the chunk configuration over sharded ddp models

@@ -64,9 +64,9 @@ def exam_search_strict_ddp():
 default_dist_spec=default_shard_spec):
 sharded_ddp_model = model_builder()
 sh_dict, sh_total, sh_wasted = search_chunk_configuration(sharded_ddp_model,
-search_range_mb=1,
-search_interval_byte=16,
-min_chunk_size_mb=0,
+search_range_m=1,
+search_interval=16,
+min_chunk_size_m=0,
 filter_exlarge_params=True,
 strict_ddp_flag=True)
 assert re_dict == sh_dict

@@ -91,8 +91,8 @@ def exam_chunk_manager():
 chunk_manager = init_chunk_manager(sharded_ddp_model,
 get_current_device(),
 hidden_dim=16,
-search_range_mb=1,
-min_chunk_size_mb=0,
+search_range_m=1,
+min_chunk_size_m=0,
 filter_exlarge_params=True,
 strict_ddp_flag=True)
 config_dict = chunk_manager.dp_degree_chunk_size_dict
@@ -35,7 +35,7 @@ def exam_state_dict(placement_policy, keep_gathered, model_name: str):
 torch_p.data.copy_(p.data)

 world_size = torch.distributed.get_world_size()
-config_dict, *_ = search_chunk_configuration(model, search_range_mb=1, search_interval_byte=100)
+config_dict, *_ = search_chunk_configuration(model, search_range_m=1, search_interval=100)
 config_dict[world_size]['chunk_size'] = 5000
 config_dict[world_size]['keep_gathered'] = keep_gathered
 chunk_manager = ChunkManager(config_dict)

@@ -67,7 +67,7 @@ def exam_load_state_dict(placement_policy, keep_gathered, model_name: str):
 torch_model = model_builder()    # get a different model

 world_size = torch.distributed.get_world_size()
-config_dict, *_ = search_chunk_configuration(model, search_range_mb=1, search_interval_byte=100)
+config_dict, *_ = search_chunk_configuration(model, search_range_m=1, search_interval=100)
 config_dict[world_size]['chunk_size'] = 5000
 config_dict[world_size]['keep_gathered'] = keep_gathered

@@ -22,7 +22,7 @@ def exam_state_dict(placement_policy, model_name: str):

 model_size = sum(p.numel() * p.element_size() for p in model.parameters()) / 1024**2

-config_dict, *_ = search_chunk_configuration(model, search_range_mb=1, search_interval_byte=100)
+config_dict, *_ = search_chunk_configuration(model, search_range_m=1, search_interval=100)
 chunk_manager = ChunkManager(config_dict)
 gemini_manager = GeminiManager(placement_policy, chunk_manager)
 model = ZeroDDP(model, gemini_manager)

@@ -38,6 +38,7 @@ def exam_state_dict(placement_policy, model_name: str):
 assert key in zero_dict, f"{key} not in ZeRO dictionary."
 assert torch.equal(value, zero_dict[key]), f"{key} not equal."

+
 def run_dist(rank, world_size, port):
 config = {}
 colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')

@@ -27,7 +27,7 @@ def exam_zero_optim_state_dict(placement_policy, keep_gathered):
 torch_model = model_builder()    # get a different model

 world_size = torch.distributed.get_world_size()
-config_dict, *_ = search_chunk_configuration(model, search_range_mb=1, search_interval_byte=100)
+config_dict, *_ = search_chunk_configuration(model, search_range_m=1, search_interval=100)
 config_dict[world_size]['chunk_size'] = 5000
 config_dict[world_size]['keep_gathered'] = keep_gathered