ColossalAI/colossalai/gemini/chunk/search_utils.py


import math
from typing import Dict, List, Tuple

import numpy as np
import torch.nn as nn

from colossalai.tensor import ColoParameter


def in_ddp(param: nn.Parameter) -> bool:
    """Check whether a parameter participates in DDP (i.e. is not marked to be ignored)."""
    return not getattr(param, '_ddp_to_ignore', False)
def _filter_exlarge_params(model: nn.Module, size_dict: Dict[int, List[int]]) -> None:
    """Filter out parameters whose size is far larger than that of the other parameters."""
    params_size = [p.numel() for p in model.parameters() if in_ddp(p)]
    params_size_arr = np.array(params_size)

    std = np.std(params_size_arr)
    mean = np.mean(params_size_arr)
    # any parameter more than 3 standard deviations above the mean is treated as an outlier
    upper_limit = mean + 3 * std

    for key in size_dict:
        org_list = size_dict[key]
        size_dict[key] = list(filter(lambda x: x <= upper_limit, org_list))
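
# Hedged illustration (not part of the module): if most parameters hold around
# 1e6 elements while a single embedding holds 5e7, the mean + 3 * std cutoff
# above drops the embedding from the candidate size lists, so the chunk-size
# search below is not pushed up to the embedding's size by one outlier.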


def _get_unused_byte(size_list: List[int], chunk_size: int) -> int:
    """Get the number of unused bytes for a certain chunk size."""
    acc = 0
    left = 0
    for s in size_list:
        if s > left:
            # the current chunk cannot hold this tensor:
            # its remaining space is wasted and a new chunk is opened
            acc += left
            left = chunk_size
        left -= s
    # the unused tail of the last chunk counts as waste too
    return left + acc
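
# A worked illustration of the packing above (hedged example, not library code):
# with size_list = [7, 5, 3] and chunk_size = 10, the first chunk takes the
# 7-element tensor (3 left); 5 does not fit, so those 3 count as waste and a
# fresh chunk takes 5 and then 3 (2 left). Total waste: 3 + 2 = 5, i.e.
# _get_unused_byte([7, 5, 3], 10) == 5.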


def clasify_params(model: nn.Module) -> Dict[int, List[ColoParameter]]:
    """Classify each parameter by the data-parallel world size of its process group."""
    params_dict: Dict[int, List[ColoParameter]] = dict()
    for param in model.parameters():
        assert isinstance(param, ColoParameter), "please init model in the ColoInitContext"
        if not in_ddp(param):
            continue

        param_key = param.process_group.dp_world_size()
        if param_key not in params_dict:
            params_dict[param_key] = []
        params_dict[param_key].append(param)

    return params_dict
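
# Hedged illustration: for a model whose parameters all share one process group
# with dp_world_size() == 4, this returns {4: [param0, param1, ...]} holding
# every DDP-managed ColoParameter.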


def search_chunk_configuration(
        model: nn.Module,
        search_range_mb: float,
        search_interval_byte: int,    # hidden size is the best value for the interval
        min_chunk_size_mb: float = 32,
        filter_exlarge_params: bool = True) -> Tuple[Dict, int]:
    search_range_byte = round(search_range_mb * 1024**2)
    min_chunk_size_byte = round(min_chunk_size_mb * 1024**2)
    assert search_range_byte >= 0

    params_dict = clasify_params(model)
    config_dict: Dict[int, Dict] = dict()

    size_dict: Dict[int, List[int]] = dict()
    for key in params_dict:
        params_list = params_dict[key]
        # note: sizes are measured in number of elements (numel)
        size_list = [p.numel() for p in params_list]
        # let small parameters keep gathered in CUDA all the time
        total_size = sum(size_list)
        if total_size < min_chunk_size_byte:
            config_dict[key] = dict(chunk_size=total_size, keep_gathered=True)
        else:
            size_dict[key] = size_list

    if filter_exlarge_params:
        _filter_exlarge_params(model, size_dict)

    # the searched chunk size must be able to hold the largest remaining parameter
    max_size = min_chunk_size_byte
    for key in size_dict:
        max_size = max(max_size, max(size_dict[key]))
    start_size = int(math.ceil(max_size / search_interval_byte) * search_interval_byte)

    # grid-search the chunk size that wastes the least space across all groups
    min_chunk_waste = float('+inf')
    best_chunk_size = start_size

    for chunk_size in range(start_size, start_size + search_range_byte + 1, search_interval_byte):
        temp_waste = 0
        for key in size_dict:
            temp_waste += _get_unused_byte(size_dict[key], chunk_size)
        if temp_waste < min_chunk_waste:
            min_chunk_waste = temp_waste
            best_chunk_size = chunk_size

    # every group not already configured as keep_gathered shares the best chunk size
    for key in params_dict:
        if key in config_dict:
            continue
        config_dict[key] = dict(chunk_size=best_chunk_size, keep_gathered=False)

    return config_dict, min_chunk_waste
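

# Minimal usage sketch (an assumption, not part of this module): `model` must be
# built inside ColoInitContext so that every parameter is a ColoParameter, and
# `hidden_dim` stands in for the model's hidden size, following the hint above
# that the hidden size is the best search interval.
#
#   config_dict, wasted_size = search_chunk_configuration(
#       model,
#       search_range_mb=32,
#       search_interval_byte=hidden_dim,
#       min_chunk_size_mb=32,
#   )
#   # config_dict maps each DP world size to its chunk settings, e.g.
#   #   {4: {'chunk_size': 50331648, 'keep_gathered': False}}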