You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
ColossalAI/colossalai/zero/gemini/memory_tracer/memory_stats.py

127 lines
4.0 KiB

from typing import List, Optional
import torch
from .param_runtime_order import OrderedParamGenerator
class MemStats:
    """Store the non model data statistics used for Gemini and GeminiOptimizer.

    The tracker is organised around "preop steps": an integer counter that is
    advanced by :meth:`increase_preop_step`. For every step it remembers which
    parameters were passed in, and :meth:`calc_max_cuda_non_model_data`
    derives the CUDA non-model-data volume of the previous step as
    ``overall - model_data`` from the most recently recorded values.
    """

    def __init__(self) -> None:
        # preop_step -> list of params handed to increase_preop_step at that step
        self._step_param_dict = {}
        # param -> list of preop_steps at which the param was used
        self._param_step_dict = {}
        # preop_step -> non model data used during preop_step ~ (preop_step + 1)
        self._step_nmd_dict = {}
        # params in the order they were passed to increase_preop_step
        self._param_runtime_order = OrderedParamGenerator()

        self._preop_step = 0

        # -1 is the "nothing recorded yet" sentinel for the two `_prev_*` values.
        self._prev_overall_cuda = -1
        self._max_overall_cuda = 0
        self._prev_md_cuda = -1

        # old version
        self._model_data_cuda_list = []
        self._model_data_cpu_list = []

        self._overall_cuda_list = []
        self._overall_cpu_list = []

        self._non_model_data_cuda_list = []
        self._non_model_data_cpu_list = []

    def calc_max_cuda_non_model_data(self):
        """Derive and store the CUDA non model data of the previous preop step.

        Does nothing until both an overall value and a model-data value have
        been recorded (i.e. neither is still the ``-1`` sentinel).
        """
        if self._prev_overall_cuda == -1 or self._prev_md_cuda == -1:
            return

        max_cuda_non_model_data = self._prev_overall_cuda - self._prev_md_cuda
        # The sample is filed under the step *before* the current counter value.
        self._step_nmd_dict[self._preop_step - 1] = max_cuda_non_model_data
        # compatibility of the old version.
        self._non_model_data_cuda_list.append(max_cuda_non_model_data)

    def record_max_cuda_model_data(self, val):
        """Remember ``val`` as the latest CUDA model-data measurement."""
        self._prev_md_cuda = val

    def record_max_cuda_overall_data(self, val):
        """Remember ``val`` as the latest overall CUDA measurement.

        Also folds ``val`` into the running maximum exposed by
        :attr:`max_overall_cuda`.
        """
        self._prev_overall_cuda = val
        self._max_overall_cuda = max(self._max_overall_cuda, val)

    @property
    def max_overall_cuda(self):
        """Largest value ever passed to :meth:`record_max_cuda_overall_data`."""
        return self._max_overall_cuda

    def increase_preop_step(self, param_list: List[torch.nn.Parameter]):
        """
        the time step is increased. param list is used between current and the next
        time step.

        Args:
            param_list (List[torch.nn.Parameter]): a list of torch parameters.
        """
        # Snapshot the input so that later mutation of the caller's list (or a
        # one-shot iterable being exhausted by the loop below) cannot corrupt
        # what is recorded for this step.
        params = list(param_list)
        step = self._preop_step

        for p in params:
            self._param_step_dict.setdefault(p, []).append(step)
            self._param_runtime_order.append(p)

        self._step_param_dict[step] = params
        self._preop_step = step + 1

    def param_used_step(self, param: torch.nn.Parameter) -> Optional[List[int]]:
        """param_used_step
        get the timestep list using the param

        Args:
            param (torch.nn.Parameter): a torch param

        Returns:
            Optional[List[int]]: a list of int indicates the time step of preop hook,
            or None if the param was never passed to increase_preop_step.
        """
        return self._param_step_dict.get(param)

    def param_order(self):
        """Return the recorded runtime order of parameters.

        Raises:
            RuntimeError: if no parameter has been recorded yet.
        """
        if self._param_runtime_order.is_empty():
            raise RuntimeError("param runtime order is empty: no parameters have been recorded yet")
        return self._param_runtime_order

    def _non_model_data_by_device(self, device_type: str) -> List[int]:
        """Select the legacy non-model-data list for ``device_type``.

        Raises:
            TypeError: if ``device_type`` is neither "cuda" nor "cpu".
        """
        if device_type == "cuda":
            return self._non_model_data_cuda_list
        if device_type == "cpu":
            return self._non_model_data_cpu_list
        raise TypeError(f"unsupported device_type {device_type!r}; expected 'cuda' or 'cpu'")

    def non_model_data_list(self, device_type: str) -> List[int]:
        """Return the recorded non model data samples for ``device_type``.

        Args:
            device_type (str): "cuda" or "cpu".

        Raises:
            TypeError: if ``device_type`` is neither "cuda" nor "cpu".
        """
        return self._non_model_data_by_device(device_type)

    def max_non_model_data(self, device_type: str) -> float:
        """Return the largest recorded non model data sample for ``device_type``.

        Args:
            device_type (str): "cuda" or "cpu".

        Raises:
            TypeError: if ``device_type`` is neither "cuda" nor "cpu".
            ValueError: if no sample has been recorded for that device yet
                (``max`` of an empty list).
        """
        return max(self._non_model_data_by_device(device_type))

    def clear(self):
        """Drop all recorded samples and restart the preop step counter.

        NOTE(review): ``_max_overall_cuda`` is not reset here, so the running
        peak survives a clear -- confirm this is intended before changing it.
        """
        self._model_data_cuda_list = []
        self._overall_cuda_list = []

        self._model_data_cpu_list = []
        self._overall_cpu_list = []

        self._non_model_data_cpu_list = []
        self._non_model_data_cuda_list = []

        self._param_runtime_order.clear()
        self._step_param_dict.clear()
        self._param_step_dict.clear()
        self._step_nmd_dict.clear()
        self._preop_step = 0

        self._prev_overall_cuda = -1
        self._prev_md_cuda = -1