Merge pull request #4056 from Fridge003/hotfix/fix_gemini_chunk_config_searching

[gemini] Rename arguments in chunk configuration searching
pull/4046/merge
Baizhou Zhang 2023-06-25 17:37:37 +08:00 committed by GitHub
commit 2c8ae37f61
17 changed files with 62 additions and 64 deletions

View File

@@ -181,11 +181,11 @@ class GeminiPlugin(DPPluginBase):
        pin_memory (bool, optional): use pin memory on CPU. Defaults to False.
        force_outputs_fp32 (bool, optional): force outputs are fp32. Defaults to False.
        strict_ddp_mode (bool, optional): use strict ddp mode (only use dp without other parallelism). Defaults to False.
-       search_range_mb (int, optional): chunk size searching range in MegaByte. Defaults to 32.
+       search_range_m (int, optional): chunk size searching range divided by 2^20. Defaults to 32.
        hidden_dim (int, optional): the hidden dimension of DNN.
            Users can provide this argument to speed up searching.
            If users do not know this argument before training, it is ok. We will use a default value 1024.
-       min_chunk_size_mb (float, optional): the minimum chunk size in MegaByte.
+       min_chunk_size_m (float, optional): the minimum chunk size divided by 2^20.
            If the aggregate size of parameters is still smaller than the minimum chunk size,
            all parameters will be compacted into one small chunk.
        memstats (MemStats, optional) the memory statistics collector by a runtime memory tracer.
@@ -214,9 +214,9 @@ class GeminiPlugin(DPPluginBase):
                 pin_memory: bool = False,
                 force_outputs_fp32: bool = False,
                 strict_ddp_mode: bool = False,
-                search_range_mb: int = 32,
+                search_range_m: int = 32,
                 hidden_dim: Optional[int] = None,
-                min_chunk_size_mb: float = 32,
+                min_chunk_size_m: float = 32,
                 memstats: Optional[MemStats] = None,
                 gpu_margin_mem_ratio: float = 0.0,
                 initial_scale: float = 2**32,
@@ -238,9 +238,9 @@ class GeminiPlugin(DPPluginBase):
            pin_memory=pin_memory,
            force_outputs_fp32=force_outputs_fp32,
            strict_ddp_mode=strict_ddp_mode,
-           search_range_mb=search_range_mb,
+           search_range_m=search_range_m,
            hidden_dim=hidden_dim,
-           min_chunk_size_mb=min_chunk_size_mb,
+           min_chunk_size_m=min_chunk_size_m,
            memstats=memstats,
            mixed_precision=PRECISION_STR_TO_DTYPE[precision],
        )
@ -295,10 +295,7 @@ class GeminiPlugin(DPPluginBase):
if optimizer is not None and \ if optimizer is not None and \
not isinstance(optimizer, OptimizerWrapper): not isinstance(optimizer, OptimizerWrapper):
optimizer = GeminiOptimizer(model.unwrap(), optimizer = GeminiOptimizer(model.unwrap(), optimizer, self.zero_optim_config, self.optim_kwargs,
optimizer,
self.zero_optim_config,
self.optim_kwargs,
self.verbose) self.verbose)
return model, optimizer, criterion, dataloader, lr_scheduler return model, optimizer, criterion, dataloader, lr_scheduler
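
For plugin users the change is purely a keyword rename. A minimal usage sketch with the new names (the import paths follow the usual booster API; every numeric value below is a placeholder chosen for illustration, not a default introduced by this PR):

from colossalai.booster import Booster
from colossalai.booster.plugin import GeminiPlugin

# Hypothetical configuration showing the renamed keyword arguments.
plugin = GeminiPlugin(
    pin_memory=True,
    hidden_dim=1024,        # hidden size of the model, speeds up the chunk search
    search_range_m=64,      # formerly search_range_mb
    min_chunk_size_m=32,    # formerly min_chunk_size_mb
)
booster = Booster(plugin=plugin)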

View File

@@ -114,9 +114,9 @@ def classify_params_by_dp_degree(param_order: OrderedParamGenerator,
def search_chunk_configuration(
        model: nn.Module,
-       search_range_mb: float,
-       search_interval_byte: int,    # hidden size is the best value for the interval
-       min_chunk_size_mb: float = 32,
+       search_range_m: float,
+       search_interval: int,    # hidden size is the best value for the interval
+       min_chunk_size_m: float = 32,
        filter_exlarge_params: bool = True,
        strict_ddp_flag: bool = False,
        memstas: Optional[MemStats] = None) -> Tuple[Dict, int, int]:
@@ -126,9 +126,9 @@ def search_chunk_configuration(
    Args:
        model (nn.Module): torch module
-       search_range_mb (float): searching range in mega byte.
-       search_interval_byte (int): searching interval in byte.
-       min_chunk_size_mb (float, optional): the minimum size of a distributed chunk.
+       search_range_m (float): searching range divided by 2^20.
+       search_interval (int): searching interval.
+       min_chunk_size_m (float, optional): the minimum size of a distributed chunk, divided by 2^20..
        filter_exlarge_params (bool, optional): filter extreme large parameters. Defaults to True.
        strict_ddp_flag (bool, optional): whether to enable the strict ddp mode.
            all parameters keep replicated in this mode.
@@ -145,9 +145,9 @@ def search_chunk_configuration(
    for p in model.parameters():
        param_order.append(p)

-   search_range_byte = round(search_range_mb * 1024**2)
-   min_chunk_size_byte = round(min_chunk_size_mb * 1024**2)
-   assert search_range_byte >= 0
+   search_range = round(search_range_m * 1024**2)
+   min_chunk_size = round(min_chunk_size_m * 1024**2)
+   assert search_range >= 0

    params_dict = classify_params_by_dp_degree(param_order, strict_ddp_flag)
    size_lcm = np.lcm.reduce(list(params_dict.keys()))
@@ -162,7 +162,7 @@ def search_chunk_configuration(
        total_param_size += group_acc_size

        # let small parameters keep gathered in CUDA all the time
-       if group_acc_size < min_chunk_size_byte:
+       if group_acc_size < min_chunk_size:
            config_dict[dp_degree] = dict(chunk_size=group_acc_size, keep_gathered=True)
        else:
            size_dict[dp_degree] = size_list
@@ -170,15 +170,15 @@ def search_chunk_configuration(
    if filter_exlarge_params:
        _filter_exlarge_params(model, size_dict)

-   max_size = min_chunk_size_byte
+   max_size = min_chunk_size
    for key in size_dict:
        max_size = max(max_size, max(size_dict[key]))
-   start_size = int(math.ceil(max_size / search_interval_byte) * search_interval_byte)
+   start_size = int(math.ceil(max_size / search_interval) * search_interval)

    min_chunk_waste = float('+inf')
    best_chunk_size = start_size

-   for chunk_size in range(start_size, start_size + search_range_byte + 1, search_interval_byte):
+   for chunk_size in range(start_size, start_size + search_range + 1, search_interval):
        temp_waste = 0
        for key in size_dict:
            temp_waste += _get_unused_byte(size_dict[key], chunk_size)
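
The renamed locals feed the same search as before: every candidate chunk size within search_range bytes of the start size, stepped by search_interval, is scored by how many bytes it would leave unused. A self-contained sketch of that scoring loop; the helper is rewritten here for illustration only (greedy packing, assumed to mirror _get_unused_byte) and the sizes are made up:

import math
from typing import List

def unused_bytes(size_list: List[int], chunk_size: int) -> int:
    # Greedy packing: open a new chunk whenever the next tensor no longer fits;
    # the space left behind in closed chunks plus the tail of the last chunk is waste.
    left, waste = 0, 0
    for size in size_list:
        if size > left:
            waste += left
            left = chunk_size
        left -= size
    return waste + left

sizes = [90, 50, 120, 40]        # toy parameter sizes in bytes
search_interval = 16             # normally the model's hidden dimension
search_range = 1 * 1024**2       # what search_range_m=1 expands to
start_size = int(math.ceil(max(sizes) / search_interval) * search_interval)

best_size, best_waste = start_size, float('+inf')
for chunk_size in range(start_size, start_size + search_range + 1, search_interval):
    waste = unused_bytes(sizes, chunk_size)
    if waste < best_waste:
        best_size, best_waste = chunk_size, waste
print(best_size, best_waste)     # the chunk size with minimum total waste wins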

View File

@@ -23,10 +23,10 @@ def init_chunk_manager(model: nn.Module,
                       verbose: bool = False,
                       **kwargs) -> ChunkManager:
    if hidden_dim:
-       search_interval_byte = hidden_dim
+       search_interval = hidden_dim
    else:
-       search_interval_byte = 1024    # defaults to 1kb
-   kwargs["search_interval_byte"] = search_interval_byte
+       search_interval = 1024    # defaults to 1024
+   kwargs["search_interval"] = search_interval

    dist.barrier()
    begin = time()
@@ -36,13 +36,13 @@ def init_chunk_manager(model: nn.Module,
    dist.barrier()
    end = time()
    span_s = end - begin
-   mb_size = 1024**2
-   total_size /= mb_size
-   wasted_size /= mb_size
+   mega_unit = 1024**2
+   total_size /= mega_unit
+   wasted_size /= mega_unit

    if verbose and dist.get_rank() == 0:
        print("searching chunk configuration is completed in {:.2f} s.\n".format(span_s),
-             "used number: {:.2f} MB, wasted number: {:.2f} MB\n".format(total_size, wasted_size),
+             "used number: {:.2f} * 2^20, wasted number: {:.2f} * 2^20\n".format(total_size, wasted_size),
              "total wasted percentage is {:.2f}%".format(100 * safe_div(wasted_size, total_size + wasted_size)),
              sep='',
              flush=True)
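
The log-string change ("MB" → "* 2^20") states the unit precisely: the reported numbers are byte counts divided by 2^20 (mebibytes), not by 10^6, which is what the old "_mb" names implied. A two-line illustration of the difference (arbitrary example value, not taken from the patch):

chunk_bytes = 64 * 2**20            # 67,108,864 bytes
print(chunk_bytes / 1024**2)        # 64.0       -> what the log now reports as "64 * 2^20"
print(chunk_bytes / 10**6)          # 67.108864  -> what "64 MB" would literally mean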

View File

@@ -739,9 +739,9 @@ class GeminiDDP(ZeroDDP):
                 force_outputs_fp32: bool = False,
                 strict_ddp_mode: bool = False,
                 scatter_after_inference: bool = True,
-                search_range_mb: int = 32,
+                search_range_m: int = 32,
                 hidden_dim: Optional[int] = None,
-                min_chunk_size_mb: float = 32,
+                min_chunk_size_m: float = 32,
                 memstats: Optional[MemStats] = None,
                 mixed_precision: torch.dtype = torch.float16,
                 verbose: bool = False) -> None:
@@ -763,24 +763,24 @@ class GeminiDDP(ZeroDDP):
            placement_policy (str, optional): "cpu", "cuda", "auto". Defaults to "cpu".
            pin_memory (bool, optional): use pin memory on CPU. Defaults to False.
            force_outputs_fp32 (bool, optional): force outputs are fp32. Defaults to False.
-           search_range_mb (int, optional): chunk size searching range in MegaByte. Defaults to 32.
+           search_range_m (int, optional): chunk size searching range divided by 2^20. Defaults to 32.
            hidden_dim (int, optional): the hidden dimension of DNN.
                Users can provide this argument to speed up searching.
                If users do not know this argument before training, it is ok. We will use a default value 1024.
-           min_chunk_size_mb (float, optional): the minimum chunk size in MegaByte.
+           min_chunk_size_m (float, optional): the minimum chunk size divided by 2^20.
                If the aggregate size of parameters is still smaller than the minimum chunk size,
                all parameters will be compacted into one small chunk.
            memstats (MemStats, optional) the memory statistics collector by a runtime memory tracer.
        """
        # some ugly hotfix for the compatibility with Lightning
-       if search_range_mb is None:
-           search_range_mb = 32
+       if search_range_m is None:
+           search_range_m = 32

        chunk_manager = init_chunk_manager(model=module,
                                           init_device=device,
                                           hidden_dim=hidden_dim,
-                                          search_range_mb=search_range_mb,
-                                          min_chunk_size_mb=min_chunk_size_mb,
+                                          search_range_m=search_range_m,
+                                          min_chunk_size_m=min_chunk_size_m,
                                           strict_ddp_flag=strict_ddp_mode,
                                           verbose=verbose)
        gemini_manager = GeminiManager(placement_policy, chunk_manager, memstats)
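
Code that constructs GeminiDDP directly needs the same rename; the old keywords no longer exist, so passing them raises a TypeError. A hedged before/after sketch (the import path is assumed for this release, module and device stand in for the user's model and target device, and all other arguments are elided):

from colossalai.zero import GeminiDDP    # assumed import path

# Before this PR:
#   ddp_model = GeminiDDP(module, device, search_range_mb=32, min_chunk_size_mb=32)
# After this PR:
ddp_model = GeminiDDP(module, device, search_range_m=32, min_chunk_size_m=32)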

View File

@@ -60,7 +60,7 @@ def exam_fwd_bwd(model_name: str, memory_budget: float, solver_name: str):
                         placement_policy='cpu',
                         pin_memory=True,
                         hidden_dim=8192,
-                        search_range_mb=128)
+                        search_range_m=128)
    gemini_model = zero_model_wrapper(gemini_model, 3, gemini_config)
    optim_config = dict(reduce_bucket_size=12 * 1024 * 1024, overlap_communication=True, verbose=True)
    gemini_optim = zero_optim_wrapper(gemini_model, hybrid_optimizer, optim_config=optim_config)

View File

@@ -75,7 +75,7 @@ def check_auto_parallel_with_gemini(rank, world_size, port):
                         device=get_current_device(),
                         placement_policy='cpu',
                         pin_memory=True,
-                        search_range_mb=128)
+                        search_range_m=128)
    post_process_colo_init_ctx(gm, device=get_current_device(), default_pg=dp_process_group)
    gm = zero_model_wrapper(gm, zero_stage=3, gemini_config=gemini_config)

View File

@@ -30,7 +30,7 @@ def exam_state_dict_with_origin(placement_policy, model_name, use_safetensors: b
    bert_model.config.save_pretrained(save_directory=pretrained_path)
    # TODO(ver217): use boost api
-   config_dict, *_ = search_chunk_configuration(bert_model, search_range_mb=1, search_interval_byte=100)
+   config_dict, *_ = search_chunk_configuration(bert_model, search_range_m=1, search_interval=100)
    chunk_manager = ChunkManager(config_dict)
    gemini_manager = GeminiManager(placement_policy, chunk_manager)
    bert_model = ZeroDDP(bert_model, gemini_manager)

View File

@@ -79,7 +79,7 @@ def run_gpt(placement_policy, tp_init_spec_func=None):
        tp_init_spec_func(model, pg)

    dp_world_size = pg.dp_world_size()
-   config_dict, *_ = search_chunk_configuration(model, search_range_mb=1, search_interval_byte=100)
+   config_dict, *_ = search_chunk_configuration(model, search_range_m=1, search_interval=100)
    config_dict[dp_world_size]['chunk_size'] = 5000
    config_dict[dp_world_size]['keep_gathered'] = False
    if placement_policy != 'cuda':

View File

@@ -52,7 +52,7 @@ def exam_gpt_fwd_bwd(
        torch_p.data.copy_(p.data)

    world_size = torch.distributed.get_world_size()
-   config_dict, *_ = search_chunk_configuration(model, search_range_mb=1, search_interval_byte=100)
+   config_dict, *_ = search_chunk_configuration(model, search_range_m=1, search_interval=100)
    config_dict[world_size]['chunk_size'] = 5000
    config_dict[world_size]['keep_gathered'] = keep_gather
    chunk_manager = ChunkManager(config_dict)
@@ -113,7 +113,7 @@ def exam_gpt_inference(
        torch_p.data.copy_(p.data)

    world_size = torch.distributed.get_world_size()
-   config_dict, *_ = search_chunk_configuration(model, search_range_mb=1, search_interval_byte=100)
+   config_dict, *_ = search_chunk_configuration(model, search_range_m=1, search_interval=100)
    config_dict[world_size]['chunk_size'] = 5000
    config_dict[world_size]['keep_gathered'] = keep_gather
    chunk_manager = ChunkManager(config_dict)

View File

@@ -56,7 +56,7 @@ def run_gemini_use_rmt(placement_policy, keep_gather, model_name: str, use_grad_
        assert len(step_list) == 4

    world_size = torch.distributed.get_world_size()
-   config_dict, *_ = search_chunk_configuration(model, search_range_mb=1, search_interval_byte=100)
+   config_dict, *_ = search_chunk_configuration(model, search_range_m=1, search_interval=100)
    config_dict[world_size]['chunk_size'] = 5000
    config_dict[world_size]['keep_gathered'] = keep_gather
    chunk_manager = ChunkManager(config_dict)

View File

@@ -51,7 +51,7 @@ def exam_grad_clipping(placement_policy, model_name: str):
        p.data.copy_(torch_p.data)

    world_size = torch.distributed.get_world_size()
-   config_dict, *_ = search_chunk_configuration(model, search_range_mb=1, search_interval_byte=100)
+   config_dict, *_ = search_chunk_configuration(model, search_range_m=1, search_interval=100)
    config_dict[world_size]['chunk_size'] = 5000
    config_dict[world_size]['keep_gathered'] = False
    if placement_policy != 'cuda':

View File

@@ -34,7 +34,7 @@ def check_param(model: ZeroDDP, torch_model: torch.nn.Module):

def multi_chunk_init(model: torch.nn.Module, placement_policy: str):
    world_size = dist.get_world_size()
-   config_dict, *_ = search_chunk_configuration(model, search_range_mb=1, search_interval_byte=100)
+   config_dict, *_ = search_chunk_configuration(model, search_range_m=1, search_interval=100)
    config_dict[world_size]['chunk_size'] = 5000
    config_dict[world_size]['keep_gathered'] = False
    if placement_policy != 'cuda':

View File

@@ -73,7 +73,7 @@ def exam_model_step(placement_policy, model_name: str, mixed_precision: torch.dt
        p.data.copy_(torch_p.data)

    world_size = torch.distributed.get_world_size()
-   config_dict, *_ = search_chunk_configuration(model, search_range_mb=1, search_interval_byte=100)
+   config_dict, *_ = search_chunk_configuration(model, search_range_m=1, search_interval=100)
    config_dict[world_size]['chunk_size'] = 5000
    config_dict[world_size]['keep_gathered'] = False
    if placement_policy != 'cuda':
@@ -130,7 +130,7 @@ def exam_tiny_example(placement_policy, model_name: str, mixed_precision: torch.
    for torch_p, p in zip(torch_model.parameters(), model.parameters()):
        p.data.copy_(torch_p.data)

-   chunk_manager = init_chunk_manager(model=model, init_device=get_current_device(), search_range_mb=1)
+   chunk_manager = init_chunk_manager(model=model, init_device=get_current_device(), search_range_m=1)
    gemini_manager = GeminiManager(placement_policy, chunk_manager)
    model = ZeroDDP(model, gemini_manager, pin_memory=True, mixed_precision=mixed_precision)
    optimizer = HybridAdam(model.parameters(), lr=1e-3)

View File

@@ -30,9 +30,9 @@ def exam_search_chunk_size():
        model = model_builder()
    init_1d_row_spec(model, pg_tp)
    config_dict, *_ = search_chunk_configuration(model,
-                                                search_range_mb=1,
-                                                search_interval_byte=16,
-                                                min_chunk_size_mb=0,
+                                                search_range_m=1,
+                                                search_interval=16,
+                                                min_chunk_size_m=0,
                                                 filter_exlarge_params=True)

    for key in config_dict:
@@ -54,9 +54,9 @@ def exam_search_strict_ddp():
    with ColoInitContext(device=get_current_device()):
        ddp_model = model_builder()
    re_dict, re_total, re_wasted = search_chunk_configuration(ddp_model,
-                                                             search_range_mb=1,
-                                                             search_interval_byte=16,
-                                                             min_chunk_size_mb=0,
+                                                             search_range_m=1,
+                                                             search_interval=16,
+                                                             min_chunk_size_m=0,
                                                              filter_exlarge_params=True,
                                                              strict_ddp_flag=False)
    # get the chunk configuration over sharded ddp models
@@ -64,9 +64,9 @@ def exam_search_strict_ddp():
                         default_dist_spec=default_shard_spec):
        sharded_ddp_model = model_builder()
    sh_dict, sh_total, sh_wasted = search_chunk_configuration(sharded_ddp_model,
-                                                             search_range_mb=1,
-                                                             search_interval_byte=16,
-                                                             min_chunk_size_mb=0,
+                                                             search_range_m=1,
+                                                             search_interval=16,
+                                                             min_chunk_size_m=0,
                                                              filter_exlarge_params=True,
                                                              strict_ddp_flag=True)
    assert re_dict == sh_dict
@@ -91,8 +91,8 @@ def exam_chunk_manager():
    chunk_manager = init_chunk_manager(sharded_ddp_model,
                                       get_current_device(),
                                       hidden_dim=16,
-                                      search_range_mb=1,
-                                      min_chunk_size_mb=0,
+                                      search_range_m=1,
+                                      min_chunk_size_m=0,
                                       filter_exlarge_params=True,
                                       strict_ddp_flag=True)
    config_dict = chunk_manager.dp_degree_chunk_size_dict

View File

@@ -35,7 +35,7 @@ def exam_state_dict(placement_policy, keep_gathered, model_name: str):
        torch_p.data.copy_(p.data)

    world_size = torch.distributed.get_world_size()
-   config_dict, *_ = search_chunk_configuration(model, search_range_mb=1, search_interval_byte=100)
+   config_dict, *_ = search_chunk_configuration(model, search_range_m=1, search_interval=100)
    config_dict[world_size]['chunk_size'] = 5000
    config_dict[world_size]['keep_gathered'] = keep_gathered
    chunk_manager = ChunkManager(config_dict)
@@ -67,7 +67,7 @@ def exam_load_state_dict(placement_policy, keep_gathered, model_name: str):
        torch_model = model_builder()    # get a different model

    world_size = torch.distributed.get_world_size()
-   config_dict, *_ = search_chunk_configuration(model, search_range_mb=1, search_interval_byte=100)
+   config_dict, *_ = search_chunk_configuration(model, search_range_m=1, search_interval=100)
    config_dict[world_size]['chunk_size'] = 5000
    config_dict[world_size]['keep_gathered'] = keep_gathered

View File

@@ -22,7 +22,7 @@ def exam_state_dict(placement_policy, model_name: str):
    model_size = sum(p.numel() * p.element_size() for p in model.parameters()) / 1024**2

-   config_dict, *_ = search_chunk_configuration(model, search_range_mb=1, search_interval_byte=100)
+   config_dict, *_ = search_chunk_configuration(model, search_range_m=1, search_interval=100)
    chunk_manager = ChunkManager(config_dict)
    gemini_manager = GeminiManager(placement_policy, chunk_manager)
    model = ZeroDDP(model, gemini_manager)
@@ -38,6 +38,7 @@ def exam_state_dict(placement_policy, model_name: str):
        assert key in zero_dict, f"{key} not in ZeRO dictionary."
        assert torch.equal(value, zero_dict[key]), f"{key} not equal."

def run_dist(rank, world_size, port):
    config = {}
    colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')

View File

@@ -27,7 +27,7 @@ def exam_zero_optim_state_dict(placement_policy, keep_gathered):
        torch_model = model_builder()    # get a different model

    world_size = torch.distributed.get_world_size()
-   config_dict, *_ = search_chunk_configuration(model, search_range_mb=1, search_interval_byte=100)
+   config_dict, *_ = search_chunk_configuration(model, search_range_m=1, search_interval=100)
    config_dict[world_size]['chunk_size'] = 5000
    config_dict[world_size]['keep_gathered'] = keep_gathered