"""Profiler extension that records StatefulTensor memory usage as trace events."""
import os
import threading
import time
import torch
from enum import Enum
from typing import List
from colossalai.gemini.stateful_tensor import StatefulTensor
from colossalai.engine.ophooks import BaseOpHook
from colossalai.engine import Engine
from colossalai.utils.profiler.extention import ProfilerExtension


class DeviceType(Enum):
    """Kind of device a memory sample refers to."""
    CPU = 0
    CUDA = 1


def get_timestamp_us():
    """Return the current ``time.time()`` value converted to integer microseconds."""
    micros = time.time() * 1e6
    return int(micros)


def generic_instant_event(name, pid, tid, timestamp, args):
    """Build the dict describing one instant event (``'ph': 'i'``, ``'s': 't'``).

    Args:
        name: Value stored under ``'name'``.
        pid: Value stored under ``'pid'``.
        tid: Value stored under ``'tid'``.
        timestamp: Value stored under ``'ts'``.
        args: Value stored under ``'args'``.

    Returns:
        A new dict holding the fixed phase/scope markers plus the given fields.
    """
    event = {
        'ph': 'i',
        's': 't',
        'name': name,
        'pid': pid,
        'tid': tid,
        'ts': timestamp,
        'args': args,
    }
    return event


class StatefulTensorMemoryEvent:
    """A single memory reading for one device type at one timestamp."""

    EVENT_NAME = '[statefulTensorMemory]'

    def __init__(self, timestamp: int, device_type: DeviceType, bytes_: int) -> None:
        """Capture the reading together with the current process and thread ids.

        Args:
            timestamp: Timestamp stored on the event (later emitted as ``'ts'``).
            device_type: Device kind the reading belongs to.
            bytes_: Memory amount stored on the event (later emitted as ``'Bytes'``).
        """
        self.pid = os.getpid()
        self.tid = threading.get_ident()
        self.timestamp = timestamp
        self.device_type = device_type
        # Only CUDA readings carry a real device index; everything else uses -1.
        if device_type == DeviceType.CUDA:
            self.device_id = torch.cuda.current_device()
        else:
            self.device_id = -1
        self.bytes = bytes_

    def state_dict(self):
        """Return this reading as an instant-event dict named ``EVENT_NAME``."""
        payload = {
            'Device Type': self.device_type.value,
            'Device Id': self.device_id,
            'Bytes': self.bytes,
        }
        return generic_instant_event(
            StatefulTensorMemoryEvent.EVENT_NAME,
            self.pid,
            self.tid,
            self.timestamp,
            payload,
        )


class StatefulTensorMemoryTracer:
    """Accumulates ``StatefulTensorMemoryEvent`` objects while tracing is on."""

    def __init__(self) -> None:
        """Start with no recorded events and tracing switched off."""
        self.events: List[StatefulTensorMemoryEvent] = []
        self._tracing = False

    def sample(self):
        """Read CUDA and CPU totals from ``StatefulTensor.GST_MGR`` and record them.

        The totals and the timestamp are always read; the two events (CUDA
        first, then CPU, sharing one timestamp) are appended only while
        tracing is active.
        """
        cuda_total = StatefulTensor.GST_MGR.total_mem['cuda']
        cpu_total = StatefulTensor.GST_MGR.total_mem['cpu']
        now_us = get_timestamp_us()
        if not self._tracing:
            return
        self.events.append(StatefulTensorMemoryEvent(now_us, DeviceType.CUDA, cuda_total))
        self.events.append(StatefulTensorMemoryEvent(now_us, DeviceType.CPU, cpu_total))

    def start_trace(self):
        """Drop previously recorded events and switch tracing on."""
        self.events.clear()
        self._tracing = True

    def stop_trace(self):
        """Switch tracing off; recorded events are kept."""
        self._tracing = False

    def state_dict(self):
        """Return the recorded events as a list of event dicts, in record order."""
        serialized = []
        for recorded in self.events:
            serialized.append(recorded.state_dict())
        return serialized


class StatefulTensorMemoryTracerHook(BaseOpHook):
    """Op hook that asks the tracer for a sample at every hook point when enabled."""

    def __init__(self, tracer: StatefulTensorMemoryTracer):
        """Bind the hook to ``tracer``; the hook starts disabled."""
        super().__init__()
        self.tracer = tracer
        self._enable = False

    def _sample_if_enabled(self) -> None:
        """Take one tracer sample, but only while the hook is enabled."""
        if self._enable:
            self.tracer.sample()

    def pre_fwd_exec(self, module: torch.nn.Module, *args):
        """Sample before a module's forward pass."""
        self._sample_if_enabled()

    def post_fwd_exec(self, module: torch.nn.Module, *args):
        """Sample after a module's forward pass."""
        self._sample_if_enabled()

    def pre_bwd_exec(self, module: torch.nn.Module, input_, output):
        """Sample before a module's backward pass."""
        self._sample_if_enabled()

    def post_bwd_exec(self, module: torch.nn.Module, input_):
        """Sample after a module's backward pass."""
        self._sample_if_enabled()

    def post_iter(self):
        """Sample at the end of an iteration."""
        self._sample_if_enabled()

    def enable(self):
        """Allow the hook points to trigger sampling."""
        self._enable = True

    def disable(self):
        """Stop the hook points from triggering sampling."""
        self._enable = False


class StatefulTensorMemoryProfilerExtention(ProfilerExtension):
    """Profiler extension wiring a memory tracer and its hook into an engine."""

    def __init__(self, engine: Engine) -> None:
        """Create the tracer and hook for ``engine``; nothing is registered yet."""
        self.engine = engine
        self.tracer = StatefulTensorMemoryTracer()
        self.hook = StatefulTensorMemoryTracerHook(self.tracer)
        self.hook_registered = False

    def prepare_trace(self):
        """Enable the hook and register it on the engine if not done already."""
        self.hook.enable()
        if self.hook_registered:
            return
        self.engine.add_hook(self.hook)
        self.hook_registered = True

    def start_trace(self):
        """Make sure the hook is active, then start recording events."""
        self.prepare_trace()
        self.tracer.start_trace()

    def stop_trace(self):
        """Stop recording, disable the hook and ask the engine to remove it."""
        self.tracer.stop_trace()
        self.hook.disable()
        if self.hook_registered:
            self.engine.remove_hook(self.hook)
            # remove_hook is not implemented now
            # FIXME(ver217): uncomment below line when remove_hook is implemented
            # self.hook_registered = False

    def extend_chrome_trace(self, trace: dict) -> dict:
        """Append the recorded events to ``trace['traceEvents']`` and return ``trace``.

        Args:
            trace: Trace dict that already has a ``'traceEvents'`` list;
                it is modified in place.

        Returns:
            The same ``trace`` object that was passed in.
        """
        trace_events = trace['traceEvents']
        trace_events.extend(self.tracer.state_dict())
        return trace