|
|
|
@ -11,6 +11,9 @@ from colossalai.engine.ophooks import register_ophooks_recursively
|
|
|
|
|
from colossalai.engine.ophooks.zero_hook import ZeroHook
|
|
|
|
|
from colossalai.engine.paramhooks import BaseParamHookMgr
|
|
|
|
|
from colossalai.logging import get_dist_logger
|
|
|
|
|
from colossalai.utils import get_current_device
|
|
|
|
|
from colossalai.utils.memory_tracer.model_data_memtracer import GLOBAL_MODEL_DATA_TRACER
|
|
|
|
|
from colossalai.utils.memory_utils.utils import colo_cuda_memory_capacity
|
|
|
|
|
from colossalai.utils.memory_tracer.memstats_collector import MemStatsCollector
|
|
|
|
|
from colossalai.utils.memory_tracer.model_data_memtracer import \
|
|
|
|
|
GLOBAL_MODEL_DATA_TRACER
|
|
|
|
@ -131,6 +134,29 @@ class ShardedModelV2(nn.Module):
|
|
|
|
|
def cpu_offload(self):
|
|
|
|
|
return self._cpu_offload
|
|
|
|
|
|
|
|
|
|
def dump_memory_stats(self, filename: Optional[str] = 'dump_mem_stats.log') -> None:
|
|
|
|
|
"""
|
|
|
|
|
dummy memory tracer collected infomation to a file.
|
|
|
|
|
try:
|
|
|
|
|
# forward: model(inputs)
|
|
|
|
|
# backward: optimizer.backward()
|
|
|
|
|
except Exception as e:
|
|
|
|
|
model.dump_memory_stats()
|
|
|
|
|
exit(0)
|
|
|
|
|
"""
|
|
|
|
|
if self._use_memory_tracer:
|
|
|
|
|
self.logger.error(f'dump memort tracer collected infomation to a {filename}', ranks=[0])
|
|
|
|
|
if gpc.get_global_rank() == 0:
|
|
|
|
|
with open(filename, 'w+') as f:
|
|
|
|
|
f.write(f'cuda reserved {torch.cuda.memory_reserved(get_current_device())/1e9} GB\n')
|
|
|
|
|
f.write(f'cuda max allocated {torch.cuda.max_memory_allocated(get_current_device())/1e9} GB\n')
|
|
|
|
|
f.write('model data\n')
|
|
|
|
|
f.write(str(self._memstats_collector.model_data_cuda_GB))
|
|
|
|
|
f.write('\n')
|
|
|
|
|
f.write('non model data\n')
|
|
|
|
|
f.write(str(self._memstats_collector.non_model_data_cuda_GB))
|
|
|
|
|
f.write('\n')
|
|
|
|
|
|
|
|
|
|
def forward(self, *args: Any, **kwargs: Any) -> torch.Tensor:
|
|
|
|
|
if self._iter_cnter == 0 and self._memstats_collector:
|
|
|
|
|
# the opeartion will affect the flag in ZeroHook
|
|
|
|
@ -147,6 +173,7 @@ class ShardedModelV2(nn.Module):
|
|
|
|
|
|
|
|
|
|
def backward_by_grad(self, tensor, grad):
|
|
|
|
|
torch.autograd.backward(tensors=tensor, grad_tensors=grad)
|
|
|
|
|
|
|
|
|
|
self._post_backward_operations()
|
|
|
|
|
for ophook in self._ophook_list:
|
|
|
|
|
ophook.post_iter()
|
|
|
|
@ -154,9 +181,6 @@ class ShardedModelV2(nn.Module):
|
|
|
|
|
def _update_memstats(self):
|
|
|
|
|
if self._iter_cnter == 0 and self._memstats_collector:
|
|
|
|
|
self._memstats_collector.finish_collection()
|
|
|
|
|
self.logger.debug(f'model data cuda, {self._memstats_collector.model_data_cuda}')
|
|
|
|
|
self.logger.debug(f'non-model data cuda, {self._memstats_collector.non_model_data_cuda}')
|
|
|
|
|
|
|
|
|
|
if self._memstats_collector:
|
|
|
|
|
self._memstats_collector.reset_sampling_cnter()
|
|
|
|
|
# cuda margin space = cuda mem capacity - max fwd/bwd cuda mem used.
|
|
|
|
|