@@ -11,6 +11,9 @@ from colossalai.engine.ophooks import register_ophooks_recursively
from colossalai.engine.ophooks.zero_hook import ZeroHook
from colossalai.engine.paramhooks import BaseParamHookMgr
from colossalai.logging import get_dist_logger
from colossalai.utils import get_current_device
from colossalai.utils.memory_tracer.model_data_memtracer import GLOBAL_MODEL_DATA_TRACER
from colossalai.utils.memory_utils.utils import colo_cuda_memory_capacity
from colossalai.utils.memory_tracer.memstats_collector import MemStatsCollector
from colossalai.utils.memory_tracer.model_data_memtracer import \
    GLOBAL_MODEL_DATA_TRACER
@@ -131,6 +134,29 @@ class ShardedModelV2(nn.Module):
    def cpu_offload(self):
        return self._cpu_offload

    def dump_memory_stats(self, filename: Optional[str] = 'dump_mem_stats.log') -> None:
        """
        Dump the memory tracer's collected information to a file.
        An example usage in a training script:
        try:
            # forward: model(inputs)
            # backward: optimizer.backward()
        except Exception as e:
            model.dump_memory_stats()
            exit(0)
        """
        if self._use_memory_tracer:
            self.logger.error(f'dump memory tracer collected information to {filename}', ranks=[0])
            if gpc.get_global_rank() == 0:
                with open(filename, 'w+') as f:
                    f.write(f'cuda reserved {torch.cuda.memory_reserved(get_current_device())/1e9} GB\n')
                    f.write(f'cuda max allocated {torch.cuda.max_memory_allocated(get_current_device())/1e9} GB\n')
                    f.write('model data\n')
                    f.write(str(self._memstats_collector.model_data_cuda_GB))
                    f.write('\n')
                    f.write('non model data\n')
                    f.write(str(self._memstats_collector.non_model_data_cuda_GB))
                    f.write('\n')
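
    # Usage sketch for dump_memory_stats (illustrative only; `model`, `inputs`, `labels`,
    # `criterion` and `optimizer` are hypothetical caller-side names, not defined in this diff):
    #
    #   try:
    #       out = model(inputs)              # forward
    #       loss = criterion(out, labels)
    #       optimizer.backward(loss)         # backward, as in the docstring above
    #       optimizer.step()
    #   except RuntimeError:                 # e.g. a CUDA out-of-memory error
    #       model.dump_memory_stats()        # write the tracer report before exiting
    #       exit(0)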

    def forward(self, *args: Any, **kwargs: Any) -> torch.Tensor:
        if self._iter_cnter == 0 and self._memstats_collector:
            # the operation will affect the flag in ZeroHook
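            # memory stats are sampled on the first iteration only; _update_memstats()
            # further down calls finish_collection() while _iter_cnter is still 0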
@@ -147,6 +173,7 @@ class ShardedModelV2(nn.Module):
    def backward_by_grad(self, tensor, grad):
        torch.autograd.backward(tensors=tensor, grad_tensors=grad)
        self._post_backward_operations()
        for ophook in self._ophook_list:
            ophook.post_iter()
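
    # Call-pattern sketch (assumption, not shown in this diff): a caller that already holds
    # the gradient of an output tensor, e.g. a pipeline engine, could drive the sharded
    # backward pass directly:
    #
    #   out = sharded_model(micro_batch)
    #   sharded_model.backward_by_grad(out, grad_of_out)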
@@ -154,9 +181,6 @@ class ShardedModelV2(nn.Module):
    def _update_memstats(self):
        if self._iter_cnter == 0 and self._memstats_collector:
            self._memstats_collector.finish_collection()
            self.logger.debug(f'model data cuda, {self._memstats_collector.model_data_cuda}')
            self.logger.debug(f'non-model data cuda, {self._memstats_collector.non_model_data_cuda}')

        if self._memstats_collector:
            self._memstats_collector.reset_sampling_cnter()
            # cuda margin space = cuda mem capacity - max fwd/bwd cuda mem used.
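            # A sketch of that formula (assumption: the collector exposes the peak CUDA usage
            # it sampled during forward/backward; `peak_fwd_bwd_cuda_used` is an illustrative name):
            #   cuda_margin_space = colo_cuda_memory_capacity() - peak_fwd_bwd_cuda_used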