import math
from typing import Dict, List, Tuple

import numpy as np
import torch.nn as nn

from colossalai.tensor import ColoParameter


def in_ddp(param: nn.Parameter) -> bool:
    """Return True unless the parameter carries a truthy ``_ddp_to_ignore`` attribute."""
    return not getattr(param, '_ddp_to_ignore', False)


def _filter_exlarge_params(model: nn.Module, size_dict: Dict[int, List[int]]) -> None:
    """Drop extra-large parameter sizes from ``size_dict`` in place.

    The threshold is ``mean + 3 * std`` of the element counts of all
    parameters of ``model`` for which :func:`in_ddp` is true. Every size in
    ``size_dict`` that exceeds this threshold is removed.

    Args:
        model: the model whose parameters define the size statistics.
        size_dict: mapping from a group key to the list of parameter sizes in
            that group; its lists are replaced by the filtered lists.
    """
    params_size = [p.numel() for p in model.parameters() if in_ddp(p)]
    if not params_size:
        # No statistics can be computed; np.mean / np.std on an empty array
        # would only produce NaN and a RuntimeWarning.
        return

    params_size_arr = np.array(params_size)
    upper_limit = np.mean(params_size_arr) + 3 * np.std(params_size_arr)

    for key, org_list in size_dict.items():
        size_dict[key] = [size for size in org_list if size <= upper_limit]


def _get_unused_byte(size_list: List[int], chunk_size: int) -> int:
    """Get the unused space left when packing ``size_list`` into chunks of ``chunk_size``.

    Sizes are placed in order. When a size does not fit into the space left in
    the current chunk, that left-over space is counted as waste and a new
    chunk is opened. The space remaining in the last chunk is counted as well.

    Args:
        size_list: sizes to pack, in packing order.
        chunk_size: capacity of every chunk, in the same unit as ``size_list``.

    Returns:
        The total unused space over all opened chunks.
    """
    acc = 0    # waste accumulated in chunks that are already closed
    left = 0    # free space in the currently open chunk
    for s in size_list:
        if s > left:
            acc += left
            left = chunk_size
        left -= s
    return left + acc


def clasify_params(model: nn.Module) -> Dict[int, List[ColoParameter]]:
    """Classify each parameter by the DP world size of its process group.

    Parameters for which :func:`in_ddp` is false are skipped.

    Args:
        model: the model; every parameter must be a ``ColoParameter``.

    Returns:
        A dict mapping ``param.process_group.dp_world_size()`` to the list of
        parameters sharing that value, in ``model.parameters()`` order.
    """
    params_dict: Dict[int, List[ColoParameter]] = dict()
    for param in model.parameters():
        assert isinstance(param, ColoParameter), "please init model in the ColoInitContext"
        if not in_ddp(param):
            continue

        param_key = param.process_group.dp_world_size()
        params_dict.setdefault(param_key, []).append(param)

    return params_dict


def search_chunk_configuration(
        model: nn.Module,
        search_range_mb: float,
        search_interval_byte: int,    # hidden size is the best value for the interval
        min_chunk_size_mb: float = 32,
        filter_exlarge_params: bool = True) -> Tuple[Dict, int]:
    """Search for the chunk size that wastes the least space.

    Parameters are grouped by :func:`clasify_params`. A group whose total size
    is below the minimum chunk size gets a single chunk of exactly that total
    size with ``keep_gathered=True``. All remaining groups share one chunk
    size, chosen by trying every candidate in
    ``[start, start + search_range]`` with step ``search_interval_byte`` and
    keeping the smallest candidate with the least total unused space.

    NOTE(review): the thresholds are named ``*_mb`` / ``*_byte`` but are
    compared against element counts from ``numel()``; confirm the intended unit.

    Args:
        model: the model whose parameters are ``ColoParameter`` instances.
        search_range_mb: width of the search window, scaled by ``1024**2``.
        search_interval_byte: step between candidate chunk sizes; must be positive.
        min_chunk_size_mb: lower bound of the chunk size, scaled by ``1024**2``.
        filter_exlarge_params: whether to exclude extra-large parameters
            (see :func:`_filter_exlarge_params`) from the search.

    Returns:
        A tuple ``(config_dict, min_chunk_waste)`` where ``config_dict`` maps
        each group key to ``dict(chunk_size=..., keep_gathered=...)`` and
        ``min_chunk_waste`` is the unused space of the selected chunk size.
    """
    search_range_byte = round(search_range_mb * 1024**2)
    min_chunk_size_byte = round(min_chunk_size_mb * 1024**2)
    assert search_range_byte >= 0
    # A zero step would divide by zero below; a negative one would make the
    # candidate range empty and silently return an infinite waste.
    assert search_interval_byte > 0, "search_interval_byte must be a positive integer"

    params_dict = clasify_params(model)
    config_dict: Dict[int, Dict] = dict()

    size_dict: Dict[int, List[int]] = dict()
    for key, params_list in params_dict.items():
        size_list = [p.numel() for p in params_list]
        # let small parameters keep gathered in CUDA all the time
        total_size = sum(size_list)
        if total_size < min_chunk_size_byte:
            config_dict[key] = dict(chunk_size=total_size, keep_gathered=True)
        else:
            size_dict[key] = size_list

    if filter_exlarge_params:
        _filter_exlarge_params(model, size_dict)

    max_size = min_chunk_size_byte
    for size_list in size_dict.values():
        # The filter above may have emptied a group; default=0 keeps max()
        # from raising ValueError on an empty list.
        max_size = max(max_size, max(size_list, default=0))
    # Round the starting candidate up to a multiple of the search interval.
    start_size = int(math.ceil(max_size / search_interval_byte) * search_interval_byte)

    min_chunk_waste = float('+inf')
    best_chunk_size = start_size

    for chunk_size in range(start_size, start_size + search_range_byte + 1, search_interval_byte):
        temp_waste = sum(_get_unused_byte(size_list, chunk_size) for size_list in size_dict.values())
        # Strict comparison: ties are resolved in favour of the smaller chunk size.
        if temp_waste < min_chunk_waste:
            min_chunk_waste = temp_waste
            best_chunk_size = chunk_size

    for key in params_dict:
        if key in config_dict:
            continue
        config_dict[key] = dict(chunk_size=best_chunk_size, keep_gathered=False)

    return config_dict, min_chunk_waste