from typing import Any, Dict, List, Optional import torch from colossalai.gemini.memory_tracer import OrderedParamGenerator class MemStats(object): def __init__(self) -> None: """ Store the non model data statistics used for Gemini and ZeroOptimizer. """ # (preop_step, List[param]) self._step_param_dict = dict() # (param, List[preop_step]) self._param_step_dict = dict() # (preop_step, non_model_data) self._step_nmd_dict = dict() self._param_runtime_order = OrderedParamGenerator() self._preop_step = 0 self._prev_overall_cuda = -1 self._prev_md_cuda = -1 # old version self.param_non_model_data_map: Dict(Any, List[int]) = {} self._model_data_cuda_list = [] self._model_data_cpu_list = [] self._overall_cuda_list = [] self._overall_cpu_list = [] self._non_model_data_cuda_list = [] self._non_model_data_cpu_list = [] def record_max_cuda_non_model_data(self): if self._prev_overall_cuda != -1 and self._prev_md_cuda != -1: self._step_nmd_dict[self._preop_step] = self._prev_overall_cuda - self._prev_md_cuda def record_max_cuda_model_data(self, val): self._prev_md_cuda = val def record_max_cuda_overall_data(self, val): self._prev_overall_cuda = val def param_order(self): if self._param_runtime_order.is_empty(): raise RuntimeError else: return self._param_runtime_order def append_overall_data(self, device_type: str, val: float): if device_type == 'cuda': self._overall_cuda_list.append(val) elif device_type == 'cpu': self._overall_cpu_list.append(val) else: raise TypeError def append_model_data(self, device_type: str, val: float): if device_type == 'cuda': self._model_data_cuda_list.append(val) elif device_type == 'cpu': self._model_data_cpu_list.append(val) else: raise TypeError def last_model_data(self, device_type: str): if len(self._model_data_cuda_list) == 0: return None if device_type == 'cuda': return self._model_data_cuda_list[-1] elif device_type == 'cpu': return self._model_data_cpu_list[-1] else: raise TypeError def append_non_model_data(self, device_type: str, val=None): if device_type == 'cuda': if val is None: if len(self._overall_cuda_list) == 0 or len(self._model_data_cuda_list) == 0: return self._non_model_data_cuda_list.append(self._overall_cuda_list[-1] - self._model_data_cuda_list[-1]) else: self._non_model_data_cuda_list.append(val) elif device_type == 'cpu': if val is None: if len(self._overall_cuda_list) == 0 or len(self._model_data_cuda_list) == 0: return self._non_model_data_cpu_list.append(self._overall_cpu_list[-1] - self._model_data_cpu_list[-1]) else: self._non_model_data_cuda_list.append(val) else: raise TypeError def overall_mem_stats(self, device_type: str) -> List[int]: if device_type == 'cuda': return self._overall_cuda_list elif device_type == 'cpu': return self._overall_cpu_list else: raise TypeError def model_data_list(self, device_type: str) -> List[int]: if device_type == 'cuda': return self._model_data_cuda_list elif device_type == 'cpu': return self._model_data_cpu_list else: raise TypeError def non_model_data_list(self, device_type: str) -> List[int]: if device_type == 'cuda': return self._non_model_data_cuda_list elif device_type == 'cpu': return self._non_model_data_cpu_list else: raise TypeError def max_non_model_data(self, device_type: str) -> float: if device_type == 'cuda': return max(self._non_model_data_cuda_list) elif device_type == 'cpu': return max(self._non_model_data_cpu_list) else: raise TypeError def max_overall_cuda(self, device_type: str) -> float: if device_type == 'cuda': return max(self._overall_cuda_list) elif device_type == 'cpu': return max(self._overall_cpu_list) else: raise TypeError def increase_preop_step(self, param_list: List[torch.nn.Parameter]): """ the time step is increased. param list is used between current and the next time step. Args: param_list (List[torch.nn.Parameter]): a list of torch paramters. """ for p in param_list: if p not in self._param_step_dict: self._param_step_dict[p] = [self._preop_step] else: self._param_step_dict[p].append(self._preop_step) self._param_runtime_order.append(p) self._step_param_dict[self._preop_step] = param_list self._preop_step += 1 def param_used_timestep(self, param: torch.nn.Parameter) -> Optional[List[int]]: """param_used_timestep get the timestep list using the param Args: param (torch.nn.Parameter): a torch param Returns: Optional[List[int]]: a list of int indicates the time step of preop hook. """ if param not in self._param_step_dict: return None else: return self._param_step_dict[param] def clear(self): self._model_data_cuda_list = [] self._overall_cuda_list = [] self._model_data_cpu_list = [] self._overall_cpu_list = [] self._non_model_data_cpu_list = [] self._non_model_data_cuda_list = [] self._param_runtime_order.clear() self._step_param_dict.clear() self._param_step_dict.clear() self._step_nmd_dict.clear() self._preop_step = 0 self._prev_overall_cuda = -1 self._prev_md_cuda = -1