2022-10-18 08:31:22 +00:00
|
|
|
import math
|
2022-12-12 10:06:16 +00:00
|
|
|
from typing import Dict, List, Optional, Tuple
|
2022-10-18 08:31:22 +00:00
|
|
|
|
|
|
|
import numpy as np
|
2023-01-28 06:35:25 +00:00
|
|
|
import torch.distributed as dist
|
2022-10-18 08:31:22 +00:00
|
|
|
import torch.nn as nn
|
2023-08-24 01:29:25 +00:00
|
|
|
from torch.distributed import ProcessGroup
|
2022-10-18 08:31:22 +00:00
|
|
|
|
|
|
|
from colossalai.tensor import ColoParameter
|
2023-01-11 04:22:45 +00:00
|
|
|
from colossalai.utils import is_ddp_ignored
|
2023-04-04 05:48:16 +00:00
|
|
|
from colossalai.zero.gemini.memory_tracer import MemStats, OrderedParamGenerator
|
2022-10-18 08:31:22 +00:00
|
|
|
|
|
|
|
|
|
|
|
def _filter_exlarge_params(model: nn.Module, size_dict: Dict[int, List[int]]) -> None:
|
2023-04-17 04:44:17 +00:00
|
|
|
"""_filter_exlarge_params
|
|
|
|
|
2022-12-09 07:00:39 +00:00
|
|
|
Filter those parameters whose size is too large (more than 3x standard deviations) from others.
|
2023-04-17 04:44:17 +00:00
|
|
|
|
|
|
|
Args:
|
|
|
|
model (nn.Module): the model.
|
|
|
|
size_dict (Dict[int, List[int]]): the size dict of parameters.
|
2022-10-18 08:31:22 +00:00
|
|
|
"""
|
2023-01-28 06:35:25 +00:00
|
|
|
agg_size_list = []
|
|
|
|
for key in size_dict:
|
|
|
|
agg_size_list.extend(size_dict[key])
|
|
|
|
|
|
|
|
if len(agg_size_list) == 0:
|
|
|
|
return
|
|
|
|
|
|
|
|
params_size_arr = np.array(agg_size_list)
|
2022-10-18 08:31:22 +00:00
|
|
|
|
|
|
|
std = np.std(params_size_arr)
|
|
|
|
mean = np.mean(params_size_arr)
|
|
|
|
upper_limit = mean + 3 * std
|
|
|
|
|
|
|
|
for key in size_dict:
|
|
|
|
org_list = size_dict[key]
|
|
|
|
size_dict[key] = list(filter(lambda x: x <= upper_limit, org_list))
|
|
|
|
|
|
|
|
|
|
|
|
def _get_unused_byte(size_list: List[int], chunk_size: int) -> int:
|
2023-04-17 04:44:17 +00:00
|
|
|
"""_get_unused_byte
|
|
|
|
|
|
|
|
Get unused byte for a certain chunk size.
|
|
|
|
|
|
|
|
Args:
|
|
|
|
size_list (List[int]): the size list of parameters.
|
|
|
|
chunk_size (int): the chunk size.
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
int: the unused byte.
|
2022-10-18 08:31:22 +00:00
|
|
|
"""
|
|
|
|
acc = 0
|
|
|
|
left = 0
|
|
|
|
for s in size_list:
|
|
|
|
if s > left:
|
|
|
|
acc += left
|
|
|
|
left = chunk_size
|
|
|
|
left -= s
|
|
|
|
return left + acc
|
|
|
|
|
|
|
|
|
2023-08-24 01:29:25 +00:00
|
|
|
def _tensor_numel(local_param: ColoParameter) -> int:
    """_tensor_numel

    Get the number of elements of a tensor.

    Args:
        local_param (ColoParameter): The local parameter.

    Returns:
        int: the number of elements, as reported by ``local_param.numel()``.
    """
    # TODO(ver217): support dtensor here
    return local_param.numel()
|
2023-01-28 06:35:25 +00:00
|
|
|
|
|
|
|
|
2023-09-19 06:20:26 +00:00
|
|
|
def classify_params_by_dp_degree(
    param_order: OrderedParamGenerator, process_group: Optional[ProcessGroup]
) -> Dict[int, List[ColoParameter]]:
    """classify_params_by_dp_degree

    Classify the parameters by their dp degree.

    Args:
        param_order (OrderedParamGenerator): the order in which parameters are visited;
            its ``generate()`` yields the parameters.
        process_group (ProcessGroup, optional): the group passed to ``dist.get_world_size``;
            its world size is used as the dp degree. ``None`` is forwarded as-is.

    Returns:
        Dict[int, List[ColoParameter]]: a dict contains the classification results.
            The keys are dp_degrees and the values are parameters, in visiting order.
            Parameters for which ``is_ddp_ignored`` is true are skipped.
    """
    params_dict: Dict[int, List[ColoParameter]] = dict()
    # Every kept parameter currently shares the group's world size as its dp degree.
    # Query it once, lazily, so a model with no kept parameters never calls into dist.
    dp_degree: Optional[int] = None
    for param in param_order.generate():
        if is_ddp_ignored(param):
            continue
        if dp_degree is None:
            dp_degree = dist.get_world_size(process_group)
        params_dict.setdefault(dp_degree, []).append(param)
    return params_dict
|
|
|
|
|
|
|
|
|
|
|
|
def search_chunk_configuration(
    model: nn.Module,
    search_range_m: float,
    search_interval: int,  # hidden size is the best value for the interval
    min_chunk_size_m: float = 32,
    filter_exlarge_params: bool = True,
    strict_ddp_flag: bool = False,
    process_group: Optional[ProcessGroup] = None,
    memstas: Optional[MemStats] = None,
) -> Tuple[Dict, int, int]:
    """search_chunk_configuration

    Search the chunk configuration for a model.

    Sizes are measured in numbers of elements (``numel``). Parameter groups whose total
    size is below the minimum chunk size are kept gathered in one chunk. All other groups
    share a single chunk size. That size is chosen from
    ``range(start, start + search_range + 1, search_interval)`` to minimise the unused space
    left by in-order packing, then rounded up to a multiple of every dp degree.

    Args:
        model (nn.Module): torch module
        search_range_m (float): searching range, in units of 2**20 elements.
        search_interval (int): searching interval.
        min_chunk_size_m (float, optional): the minimum size of a distributed chunk, in units
            of 2**20 elements. Defaults to 32.
        filter_exlarge_params (bool, optional): filter extreme large parameters out of the
            search. Defaults to True.
        strict_ddp_flag (bool, optional): whether to enable the strict ddp mode.
            all parameters keep replicated in this mode. (Not read by this function.)
        process_group (ProcessGroup, optional): forwarded to ``classify_params_by_dp_degree``
            to determine the dp degree. Defaults to None.
        memstas (MemStats, optional): if given, ``memstas.param_order()`` provides the
            parameter visiting order; otherwise ``model.parameters()`` order is used.

    Returns:
        Tuple[Dict, int, int]: a tuple of three values:
            - the chunk config, a dict of dp_degree -> chunk init args;
            - the total number of parameter elements;
            - the chunk waste (in elements) of the chosen chunk size.
    """
    if memstas is not None:
        param_order = memstas.param_order()
    else:
        # build the param visited order right now
        param_order = OrderedParamGenerator()
        for p in model.parameters():
            param_order.append(p)

    search_range = round(search_range_m * 1024**2)
    min_chunk_size = round(min_chunk_size_m * 1024**2)
    assert search_range >= 0

    params_dict = classify_params_by_dp_degree(param_order, process_group)
    dp_degrees = list(params_dict.keys())
    # np.lcm.reduce cannot reduce an empty list (e.g. every param is ddp-ignored);
    # 1 is the neutral element. int() keeps the final chunk size a plain Python int.
    size_lcm = int(np.lcm.reduce(dp_degrees)) if dp_degrees else 1
    config_dict: Dict[int, Dict] = dict()
    total_param_size = 0

    size_dict: Dict[int, List[int]] = dict()
    for dp_degree, params_list in params_dict.items():
        size_list = [_tensor_numel(p) for p in params_list]
        group_acc_size = sum(size_list)
        total_param_size += group_acc_size

        # let small parameters keep gathered in CUDA all the time
        if group_acc_size < min_chunk_size:
            config_dict[dp_degree] = dict(chunk_size=group_acc_size, keep_gathered=True)
        else:
            size_dict[dp_degree] = size_list

    if filter_exlarge_params:
        _filter_exlarge_params(model, size_dict)

    # filtering may leave a group with no sizes; default= avoids max() of an empty list
    max_size = min_chunk_size
    for sizes in size_dict.values():
        max_size = max(max_size, max(sizes, default=max_size))
    start_size = int(math.ceil(max_size / search_interval) * search_interval)

    min_chunk_waste = float("+inf")
    best_chunk_size = start_size

    for chunk_size in range(start_size, start_size + search_range + 1, search_interval):
        temp_waste = sum(_get_unused_byte(sizes, chunk_size) for sizes in size_dict.values())
        if temp_waste < min_chunk_waste:
            min_chunk_waste = temp_waste
            best_chunk_size = chunk_size

    # the chunk size needs to be divisible by every group's dp degree
    best_chunk_size = best_chunk_size + (-best_chunk_size % size_lcm)
    for dp_degree in params_dict:
        if dp_degree in config_dict:
            continue
        config_dict[dp_degree] = dict(chunk_size=best_chunk_size, keep_gathered=False)

    return config_dict, total_param_size, min_chunk_waste
|