ColossalAI/colossalai/legacy/utils/profiler/stateful_tensor_mem_extenti...

135 lines
3.9 KiB
Python

import os
import threading
import time
from enum import Enum
from typing import List
import torch
from colossalai.gemini.ophooks import BaseOpHook
from colossalai.gemini.stateful_tensor import StatefulTensor
from colossalai.legacy.engine import Engine
from colossalai.legacy.utils.profiler.extention import ProfilerExtension
class DeviceType(Enum):
    """Kind of device a memory sample refers to.

    The integer ``.value`` (not the member name) is what gets written into the
    "Device Type" field of the emitted trace events, so these numbers are part
    of the trace output format.
    """

    CPU = 0  # host-side memory total
    CUDA = 1  # memory total on the CUDA device
def get_timestamp_us():
    """Return the current wall-clock time as integer microseconds since the Unix epoch.

    ``time.time_ns()`` is converted with exact integer division. The previous
    ``int(time.time() * 1e6)`` went through a double; at current epoch
    magnitudes float rounding followed by truncation can shift the result
    by a microsecond.
    """
    return time.time_ns() // 1000
def generic_instant_event(name, pid, tid, timestamp, args):
    """Build a Chrome-trace instant event dictionary.

    ``"ph": "i"`` marks an instant event, and ``"s": "t"`` gives it thread
    scope in the Chrome Trace Event Format. ``timestamp`` is stored under
    ``"ts"`` unchanged. ``args`` is attached by reference, not copied.
    """
    event = dict(
        ph="i",
        s="t",
        name=name,
        pid=pid,
        tid=tid,
        ts=timestamp,
        args=args,
    )
    return event
class StatefulTensorMemoryEvent:
    """A single memory-usage sample for one device, exportable as a Chrome-trace event.

    The process id, thread id and (for CUDA samples) the current CUDA device
    index are captured when the event is constructed, on the calling thread.
    """

    EVENT_NAME = "[statefulTensorMemory]"

    def __init__(self, timestamp: int, device_type: DeviceType, bytes_: int) -> None:
        self.pid = os.getpid()
        self.tid = threading.get_ident()
        self.timestamp = timestamp
        self.device_type = device_type
        # Only CUDA samples carry a real device index; everything else gets -1.
        if device_type == DeviceType.CUDA:
            self.device_id = torch.cuda.current_device()
        else:
            self.device_id = -1
        self.bytes = bytes_

    def state_dict(self):
        """Serialize this sample as a Chrome-trace instant-event dict."""
        payload = {
            "Device Type": self.device_type.value,
            "Device Id": self.device_id,
            "Bytes": self.bytes,
        }
        return generic_instant_event(
            StatefulTensorMemoryEvent.EVENT_NAME,
            self.pid,
            self.tid,
            self.timestamp,
            payload,
        )
class StatefulTensorMemoryTracer:
    """Collect stateful-tensor memory samples between ``start_trace`` and ``stop_trace``.

    Each call to :meth:`sample` while tracing reads
    ``StatefulTensor.GST_MGR.total_mem`` and records two events that share one
    timestamp: one for CUDA and one for CPU.
    """

    def __init__(self) -> None:
        self.events: List[StatefulTensorMemoryEvent] = []
        self._tracing = False

    def sample(self):
        """Record the current CUDA and CPU totals if tracing; otherwise do nothing.

        Hook callbacks call this method often. When not tracing it returns
        before reading the global memory manager or the clock, because that
        sample would be discarded anyway.
        """
        if not self._tracing:
            return
        total_mem = StatefulTensor.GST_MGR.total_mem
        cuda_mem = total_mem["cuda"]
        cpu_mem = total_mem["cpu"]
        timestamp = get_timestamp_us()
        self.events.append(StatefulTensorMemoryEvent(timestamp, DeviceType.CUDA, cuda_mem))
        self.events.append(StatefulTensorMemoryEvent(timestamp, DeviceType.CPU, cpu_mem))

    def start_trace(self):
        """Drop previously recorded events (in place) and begin recording."""
        self.events.clear()
        self._tracing = True

    def stop_trace(self):
        """Stop recording. Already-recorded events are kept."""
        self._tracing = False

    def state_dict(self):
        """Return every recorded event as a list of Chrome-trace event dicts."""
        return [event.state_dict() for event in self.events]
class StatefulTensorMemoryTracerHook(BaseOpHook):
    """Op hook that asks a tracer for a memory sample at module boundaries.

    A sample is requested before and after each forward and backward execution
    and after each iteration. Nothing is sampled until :meth:`enable` is called.
    """

    def __init__(self, tracer: StatefulTensorMemoryTracer):
        super().__init__()
        self.tracer = tracer
        self._enable = False

    def _sample_if_enabled(self):
        # Shared body of every hook callback below.
        if not self._enable:
            return
        self.tracer.sample()

    def pre_fwd_exec(self, module: torch.nn.Module, *args):
        self._sample_if_enabled()

    def post_fwd_exec(self, module: torch.nn.Module, *args):
        self._sample_if_enabled()

    def pre_bwd_exec(self, module: torch.nn.Module, input_, output):
        self._sample_if_enabled()

    def post_bwd_exec(self, module: torch.nn.Module, input_):
        self._sample_if_enabled()

    def post_iter(self):
        self._sample_if_enabled()

    def enable(self):
        """Start forwarding hook callbacks to the tracer."""
        self._enable = True

    def disable(self):
        """Stop forwarding hook callbacks to the tracer."""
        self._enable = False
class StatefulTensorMemoryProfilerExtention(ProfilerExtension):
    """Profiler extension that merges stateful-tensor memory samples into a Chrome trace.

    It owns a tracer plus a hook that feeds it. The hook is added to ``engine``
    the first time tracing is prepared.
    """

    def __init__(self, engine: Engine) -> None:
        self.engine = engine
        self.tracer = StatefulTensorMemoryTracer()
        self.hook = StatefulTensorMemoryTracerHook(self.tracer)
        self.hook_registered = False

    def prepare_trace(self):
        """Enable the hook and register it with the engine if not done already."""
        self.hook.enable()
        if self.hook_registered:
            return
        self.engine.add_hook(self.hook)
        self.hook_registered = True

    def start_trace(self):
        """Prepare the hook, then have the tracer begin a fresh recording."""
        self.prepare_trace()
        self.tracer.start_trace()

    def stop_trace(self):
        """Stop recording, disable the hook and ask the engine to remove it."""
        self.tracer.stop_trace()
        self.hook.disable()
        if not self.hook_registered:
            return
        self.engine.remove_hook(self.hook)
        # Engine.remove_hook is not implemented yet, so hook_registered is
        # deliberately left True. prepare_trace() therefore won't call add_hook again.
        # FIXME(ver217): set `self.hook_registered = False` here once remove_hook is implemented.

    def extend_chrome_trace(self, trace: dict) -> dict:
        """Append the recorded memory events to ``trace["traceEvents"]`` in place and return ``trace``."""
        trace_events = trace["traceEvents"]
        trace_events.extend(self.tracer.state_dict())
        return trace