ColossalAI/colossalai/zero/gemini/chunk/search_utils.py

192 lines
6.1 KiB
Python

import math
from typing import Dict, List, Optional, Tuple
import numpy as np
import torch.distributed as dist
import torch.nn as nn
from torch.distributed import ProcessGroup
from colossalai.tensor import ColoParameter
from colossalai.utils import is_ddp_ignored
from colossalai.zero.gemini.memory_tracer import MemStats, OrderedParamGenerator
def _filter_exlarge_params(model: nn.Module, size_dict: Dict[int, List[int]]) -> None:
"""_filter_exlarge_params
Filter those parameters whose size is too large (more than 3x standard deviations) from others.
Args:
model (nn.Module): the model.
size_dict (Dict[int, List[int]]): the size dict of parameters.
"""
agg_size_list = []
for key in size_dict:
agg_size_list.extend(size_dict[key])
if len(agg_size_list) == 0:
return
params_size_arr = np.array(agg_size_list)
std = np.std(params_size_arr)
mean = np.mean(params_size_arr)
upper_limit = mean + 3 * std
for key in size_dict:
org_list = size_dict[key]
size_dict[key] = list(filter(lambda x: x <= upper_limit, org_list))
def _get_unused_byte(size_list: List[int], chunk_size: int) -> int:
"""_get_unused_byte
Get unused byte for a certain chunk size.
Args:
size_list (List[int]): the size list of parameters.
chunk_size (int): the chunk size.
Returns:
int: the unused byte.
"""
acc = 0
left = 0
for s in size_list:
if s > left:
acc += left
left = chunk_size
left -= s
return left + acc
def _tensor_numel(local_param: ColoParameter) -> int:
    """_tensor_numel

    Get the number of elements of a parameter.

    Args:
        local_param (ColoParameter): the parameter to measure.

    Returns:
        int: the value returned by ``local_param.numel()``.
    """
    # TODO(ver217): support dtensor here
    element_count = local_param.numel()
    return element_count
def classify_params_by_dp_degree(
    param_order: OrderedParamGenerator, process_group: ProcessGroup
) -> Dict[int, List[ColoParameter]]:
    """classify_params_by_dp_degree

    Group the parameters by their dp degree.

    Parameters for which ``is_ddp_ignored`` is true are skipped. Every other
    parameter is filed under ``dist.get_world_size(process_group)``.

    Args:
        param_order (OrderedParamGenerator): yields the parameters, via
            ``generate()``, in the order they should be classified.
        process_group (ProcessGroup): the group whose world size is used as
            the dp degree.

    Returns:
        Dict[int, List[ColoParameter]]: a dict contains the classification results.
        The keys are dp_degrees and the values are parameters, in the order
        they were generated.
    """
    grouped: Dict[int, List[ColoParameter]] = {}
    for param in param_order.generate():
        if not is_ddp_ignored(param):
            # Queried per parameter, so it is never called when nothing is kept.
            dp_degree = dist.get_world_size(process_group)
            grouped.setdefault(dp_degree, []).append(param)
    return grouped
def search_chunk_configuration(
    model: nn.Module,
    search_range_m: float,
    search_interval: int,  # hidden size is the best value for the interval
    min_chunk_size_m: float = 32,
    filter_exlarge_params: bool = True,
    strict_ddp_flag: bool = False,
    process_group: Optional[ProcessGroup] = None,
    memstas: Optional[MemStats] = None,
) -> Tuple[Dict, int, int]:
    """search_chunk_configuration

    Search the chunk configuration for a model.

    Candidate chunk sizes are tried from a start size up to the start size
    plus the search range, in steps of ``search_interval``; the candidate with
    the least unused space wins. Sizes are parameter element counts as
    reported by ``_tensor_numel``.

    Args:
        model (nn.Module): torch module
        search_range_m (float): searching range divided by 2^20.
        search_interval (int): searching interval.
        min_chunk_size_m (float, optional): the minimum size of a distributed chunk,
            divided by 2^20. Defaults to 32.
        filter_exlarge_params (bool, optional): filter extreme large parameters
            before searching. Defaults to True.
        strict_ddp_flag (bool, optional): not read by this function.
            Defaults to False.
        process_group (ProcessGroup, optional): forwarded to
            ``classify_params_by_dp_degree``. Defaults to None.
        memstas (MemStats, optional): when given, its ``param_order()`` supplies
            the parameter order; otherwise the order of ``model.parameters()``
            is used. Defaults to None.

    Returns:
        Tuple[Dict, int, int]: the chunk config (a dict of dp_degree -> chunk
        init args), the total size of the classified parameters, and the
        least unused space found during the search. The unused space is
        measured for the winning candidate before that candidate is rounded
        up to a multiple of the dp degrees.
    """
    if memstas is None:
        # build the param visited order right now
        param_order = OrderedParamGenerator()
        for param in model.parameters():
            param_order.append(param)
    else:
        param_order = memstas.param_order()

    scale = 1024**2
    search_range = round(search_range_m * scale)
    min_chunk_size = round(min_chunk_size_m * scale)
    assert search_range >= 0

    params_dict = classify_params_by_dp_degree(param_order, process_group)
    size_lcm = np.lcm.reduce(list(params_dict))

    config_dict: Dict[int, Dict] = {}
    size_dict: Dict[int, List[int]] = {}
    total_param_size = 0
    for dp_degree, group_params in params_dict.items():
        numels = [_tensor_numel(p) for p in group_params]
        group_total = sum(numels)
        total_param_size += group_total

        if group_total >= min_chunk_size:
            # Large enough to take part in the chunk size search.
            size_dict[dp_degree] = numels
            continue
        # let small parameters keep gathered in CUDA all the time
        config_dict[dp_degree] = {"chunk_size": group_total, "keep_gathered": True}

    if filter_exlarge_params:
        _filter_exlarge_params(model, size_dict)

    # The first candidate must hold the largest remaining parameter and must
    # not be smaller than the minimum chunk size.
    largest_per_group = [max(sizes) for sizes in size_dict.values()]
    max_size = max([min_chunk_size, *largest_per_group])
    interval_count = math.ceil(max_size / search_interval)
    start_size = int(interval_count * search_interval)

    min_chunk_waste = float("+inf")
    best_chunk_size = start_size
    stop_size = start_size + search_range + 1
    for candidate in range(start_size, stop_size, search_interval):
        waste = sum(_get_unused_byte(sizes, candidate) for sizes in size_dict.values())
        # Strict comparison: on a tie the smaller candidate is kept.
        if waste < min_chunk_waste:
            min_chunk_waste, best_chunk_size = waste, candidate

    # the chunk size needs to be divided by each groups sizes
    best_chunk_size += -best_chunk_size % size_lcm

    for dp_degree in params_dict:
        if dp_degree not in config_dict:
            config_dict[dp_degree] = {"chunk_size": best_chunk_size, "keep_gathered": False}

    return config_dict, total_param_size, min_chunk_waste