import math
from typing import Dict, List, Optional, Tuple

import numpy as np
import torch.nn as nn

from colossalai.gemini.memory_tracer import MemStats, OrderedParamGenerator
from colossalai.tensor import ColoParameter


def in_ddp(param: nn.Parameter) -> bool:
    """Return True unless the parameter carries a truthy ``_ddp_to_ignore`` attribute."""
    return not getattr(param, '_ddp_to_ignore', False)


def _filter_exlarge_params(model: nn.Module, size_dict: Dict[int, List[int]]) -> None:
    """
    Filter those parameters whose size is too large (more than 3x standard deviations) from others.

    The threshold ``mean + 3 * std`` is computed over the ``numel()`` of every
    parameter of ``model`` for which ``in_ddp`` is true. ``size_dict`` is
    modified in place: each list keeps only the sizes not above the threshold.

    Args:
        model (nn.Module): the module whose parameters define the size statistics.
        size_dict (Dict[int, List[int]]): parameter sizes grouped by key; updated in place.
    """
    params_size = [p.numel() for p in model.parameters() if in_ddp(p)]
    if not params_size:
        # No statistics available: mean/std of an empty array are NaN, and
        # comparing against NaN would silently drop every size.
        return

    params_size_arr = np.array(params_size)
    upper_limit = np.mean(params_size_arr) + 3 * np.std(params_size_arr)

    for key in size_dict:
        size_dict[key] = [size for size in size_dict[key] if size <= upper_limit]


def _get_unused_byte(size_list: List[int], chunk_size: int) -> int:
    """Get unused byte for a certain chunk size.

    Sizes are packed in order. Whenever the next size does not fit into the
    space left in the current chunk, that leftover is counted as waste and a
    fresh chunk of ``chunk_size`` is opened.

    Args:
        size_list (List[int]): sizes to pack, in visiting order.
        chunk_size (int): capacity of every chunk.

    Returns:
        int: accumulated leftover of all closed chunks plus the leftover of the last one.
    """
    acc = 0
    left = 0
    for s in size_list:
        if s > left:
            # The current chunk cannot hold this item: close it and open a new one.
            acc += left
            left = chunk_size
        left -= s
    return left + acc


def classify_params_by_dp_degree(param_order: OrderedParamGenerator) -> Dict[int, List[ColoParameter]]:
    """classify_params_by_dp_degree

    Classify the parameters by their dp degree

    Args:
        param_order (OrderedParamGenerator): the order in which the params are visited

    Returns:
        Dict[int, List[ColoParameter]]: a dict contains the classification results.
        The keys are dp_degrees and the values are parameters.
    """
    params_dict: Dict[int, List[ColoParameter]] = dict()
    for param in param_order.generate():
        assert isinstance(param, ColoParameter), "please init model in the ColoInitContext"
        if not in_ddp(param):
            continue

        param_key = param.process_group.dp_world_size()
        params_dict.setdefault(param_key, []).append(param)

    return params_dict


def search_chunk_configuration(
        model: nn.Module,
        search_range_mb: float,
        search_interval_byte: int,    # hidden size is the best value for the interval
        min_chunk_size_mb: float = 32,
        filter_exlarge_params: bool = True,
        memstas: Optional[MemStats] = None) -> Tuple[Dict, int]:
    """search_chunk_configuration

    Args:
        model (nn.Module): torch module
        search_range_mb (float): searching range in mega byte.
        search_interval_byte (int): searching interval in byte. Must be positive.
        min_chunk_size_mb (float, optional): lower bound of the searched chunk size in mega byte.
            A dp-degree group whose total parameter size is below this bound gets a single
            chunk of exactly that total size with ``keep_gathered=True``. Defaults to 32.
        filter_exlarge_params (bool, optional): filter extreme large parameters. Defaults to True.
        memstas (MemStats, optional): when given, its ``param_order()`` supplies the parameter
            visiting order; otherwise the order of ``model.parameters()`` is used.

    Returns:
        Tuple[Dict, int]: chunk config (a dict of dp_degree -> chunk init args) and its memory chunk waste in byte.
        Parameter sizes are taken from ``numel()``, so the waste is expressed in the same unit.

    Raises:
        ValueError: if ``search_interval_byte`` is not positive.
    """
    if search_interval_byte <= 0:
        # A zero interval divides by zero below; a negative one yields an empty search range.
        raise ValueError(f"search_interval_byte must be positive, got {search_interval_byte}")

    if memstas is not None:
        param_order = memstas.param_order()
    else:
        # build the param visited order right now
        param_order = OrderedParamGenerator()
        for p in model.parameters():
            param_order.append(p)

    search_range_byte = round(search_range_mb * 1024**2)
    min_chunk_size_byte = round(min_chunk_size_mb * 1024**2)
    assert search_range_byte >= 0

    params_dict = classify_params_by_dp_degree(param_order)
    config_dict: Dict[int, Dict] = dict()

    size_dict: Dict[int, List[int]] = dict()
    for dp_degree, params_list in params_dict.items():
        size_list = [p.numel() for p in params_list]
        # let small parameters keep gathered in CUDA all the time
        total_size = sum(size_list)
        if total_size < min_chunk_size_byte:
            config_dict[dp_degree] = dict(chunk_size=total_size, keep_gathered=True)
        else:
            size_dict[dp_degree] = size_list

    if filter_exlarge_params:
        _filter_exlarge_params(model, size_dict)

    # The smallest candidate must hold the largest remaining parameter. A group
    # may have been emptied by the filter above, hence the ``default``.
    max_size = min_chunk_size_byte
    for sizes in size_dict.values():
        max_size = max(max_size, max(sizes, default=0))
    # Round the starting candidate up to a multiple of the search interval.
    start_size = int(math.ceil(max_size / search_interval_byte) * search_interval_byte)

    min_chunk_waste = float('+inf')
    best_chunk_size = start_size

    for chunk_size in range(start_size, start_size + search_range_byte + 1, search_interval_byte):
        temp_waste = sum(_get_unused_byte(sizes, chunk_size) for sizes in size_dict.values())
        # Strict comparison keeps the smallest chunk size among equally wasteful ones.
        if temp_waste < min_chunk_waste:
            min_chunk_waste = temp_waste
            best_chunk_size = chunk_size

    for dp_degree in params_dict:
        if dp_degree in config_dict:
            # Already configured as an always-gathered small group.
            continue
        config_dict[dp_degree] = dict(chunk_size=best_chunk_size, keep_gathered=False)

    return config_dict, min_chunk_waste