|
|
|
@ -2,6 +2,7 @@ import math
|
|
|
|
|
from typing import Dict, List, Optional, Tuple
|
|
|
|
|
|
|
|
|
|
import numpy as np
|
|
|
|
|
import torch.distributed as dist
|
|
|
|
|
import torch.nn as nn
|
|
|
|
|
|
|
|
|
|
from colossalai.gemini.memory_tracer import MemStats, OrderedParamGenerator
|
|
|
|
@ -13,8 +14,14 @@ def _filter_exlarge_params(model: nn.Module, size_dict: Dict[int, List[int]]) ->
|
|
|
|
|
"""
|
|
|
|
|
Filter those parameters whose size is too large (more than 3x standard deviations) from others.
|
|
|
|
|
"""
|
|
|
|
|
params_size = [p.numel() for p in model.parameters() if not is_ddp_ignored(p)]
|
|
|
|
|
params_size_arr = np.array(params_size)
|
|
|
|
|
agg_size_list = []
|
|
|
|
|
for key in size_dict:
|
|
|
|
|
agg_size_list.extend(size_dict[key])
|
|
|
|
|
|
|
|
|
|
if len(agg_size_list) == 0:
|
|
|
|
|
return
|
|
|
|
|
|
|
|
|
|
params_size_arr = np.array(agg_size_list)
|
|
|
|
|
|
|
|
|
|
std = np.std(params_size_arr)
|
|
|
|
|
mean = np.mean(params_size_arr)
|
|
|
|
@ -38,7 +45,15 @@ def _get_unused_byte(size_list: List[int], chunk_size: int) -> int:
|
|
|
|
|
return left + acc
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def classify_params_by_dp_degree(param_order: OrderedParamGenerator) -> Dict[int, List[ColoParameter]]:
|
|
|
|
|
def _tensor_numel(local_param: ColoParameter, strict_ddp_flag: bool):
|
|
|
|
|
if strict_ddp_flag:
|
|
|
|
|
return local_param.numel_global()
|
|
|
|
|
else:
|
|
|
|
|
return local_param.numel()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def classify_params_by_dp_degree(param_order: OrderedParamGenerator,
|
|
|
|
|
strict_ddp_flag: bool = False) -> Dict[int, List[ColoParameter]]:
|
|
|
|
|
"""classify_params_by_dp_degree
|
|
|
|
|
|
|
|
|
|
Classify the parameters by their dp degree
|
|
|
|
@ -56,7 +71,10 @@ def classify_params_by_dp_degree(param_order: OrderedParamGenerator) -> Dict[int
|
|
|
|
|
if is_ddp_ignored(param):
|
|
|
|
|
continue
|
|
|
|
|
|
|
|
|
|
param_key = param.process_group.dp_world_size()
|
|
|
|
|
if strict_ddp_flag:
|
|
|
|
|
param_key = dist.get_world_size()
|
|
|
|
|
else:
|
|
|
|
|
param_key = param.process_group.dp_world_size()
|
|
|
|
|
|
|
|
|
|
if param_key not in params_dict:
|
|
|
|
|
params_dict[param_key] = []
|
|
|
|
@ -71,14 +89,18 @@ def search_chunk_configuration(
|
|
|
|
|
search_interval_byte: int, # hidden size is the best value for the interval
|
|
|
|
|
min_chunk_size_mb: float = 32,
|
|
|
|
|
filter_exlarge_params: bool = True,
|
|
|
|
|
memstas: Optional[MemStats] = None) -> Tuple[Dict, int]:
|
|
|
|
|
strict_ddp_flag: bool = False,
|
|
|
|
|
memstas: Optional[MemStats] = None) -> Tuple[Dict, int, int]:
|
|
|
|
|
"""search_chunk_configuration
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
|
model (nn.Module): torch module
|
|
|
|
|
search_range_mb (float): searching range in mega byte.
|
|
|
|
|
search_interval_byte (int): searching interval in byte.
|
|
|
|
|
min_chunk_size_mb (float, optional): the minimum size of a distributed chunk.
|
|
|
|
|
filter_exlarge_params (bool, optional): filter extreme large parameters. Defaults to True.
|
|
|
|
|
strict_ddp_flag (bool, optional): whether to enable the strict ddp mode.
|
|
|
|
|
All parameters are kept replicated in this mode.
|
|
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
|
Tuple[Dict, int, int]: chunk config (a dict of dp_degree -> chunk init args), the total number of parameter elements, and the memory chunk waste in byte.
|
|
|
|
@ -96,17 +118,20 @@ def search_chunk_configuration(
|
|
|
|
|
min_chunk_size_byte = round(min_chunk_size_mb * 1024**2)
|
|
|
|
|
assert search_range_byte >= 0
|
|
|
|
|
|
|
|
|
|
params_dict = classify_params_by_dp_degree(param_order)
|
|
|
|
|
params_dict = classify_params_by_dp_degree(param_order, strict_ddp_flag)
|
|
|
|
|
config_dict: Dict[int, Dict] = dict()
|
|
|
|
|
total_param_size = 0
|
|
|
|
|
|
|
|
|
|
size_dict: Dict[int, List[int]] = dict()
|
|
|
|
|
for dp_degree in params_dict:
|
|
|
|
|
params_list = params_dict[dp_degree]
|
|
|
|
|
size_list = [p.numel() for p in params_list]
|
|
|
|
|
size_list = [_tensor_numel(p, strict_ddp_flag) for p in params_list]
|
|
|
|
|
group_acc_size = sum(size_list)
|
|
|
|
|
total_param_size += group_acc_size
|
|
|
|
|
|
|
|
|
|
# let small parameters keep gathered in CUDA all the time
|
|
|
|
|
total_size = sum(size_list)
|
|
|
|
|
if total_size < min_chunk_size_byte:
|
|
|
|
|
config_dict[dp_degree] = dict(chunk_size=total_size, keep_gathered=True)
|
|
|
|
|
if group_acc_size < min_chunk_size_byte:
|
|
|
|
|
config_dict[dp_degree] = dict(chunk_size=group_acc_size, keep_gathered=True)
|
|
|
|
|
else:
|
|
|
|
|
size_dict[dp_degree] = size_list
|
|
|
|
|
|
|
|
|
@ -134,4 +159,4 @@ def search_chunk_configuration(
|
|
|
|
|
continue
|
|
|
|
|
config_dict[dp_degree] = dict(chunk_size=best_chunk_size, keep_gathered=False)
|
|
|
|
|
|
|
|
|
|
return config_dict, min_chunk_waste
|
|
|
|
|
return config_dict, total_param_size, min_chunk_waste
|
|
|
|
|