ColossalAI/colossalai/legacy/utils/profiler/stateful_tensor_mem_extenti...

135 lines
3.9 KiB
Python
Raw Normal View History

import os
import threading
import time
from enum import Enum
from typing import List
import torch
from colossalai.gemini.ophooks import BaseOpHook
from colossalai.gemini.stateful_tensor import StatefulTensor
from colossalai.legacy.engine import Engine
from colossalai.legacy.utils.profiler.extention import ProfilerExtension
class DeviceType(Enum):
    """Kind of device a memory sample was taken for.

    The integer values are what ``StatefulTensorMemoryEvent.state_dict`` writes
    into the trace under the ``"Device Type"`` key.
    """

    # Host memory.
    CPU = 0
    # CUDA device memory.
    CUDA = 1
def get_timestamp_us() -> int:
    """Return the current wall-clock time in whole microseconds since the epoch.

    The clock is read with ``time.time_ns()`` and scaled with integer division,
    so the result never passes through floating-point arithmetic and cannot be
    perturbed by float rounding.

    Returns:
        int: microseconds since the epoch, truncated toward zero.
    """
    return time.time_ns() // 1_000
def generic_instant_event(name, pid, tid, timestamp, args):
    """Build the dictionary describing one instant event for a trace.

    Args:
        name: value stored under ``"name"``.
        pid: value stored under ``"pid"``.
        tid: value stored under ``"tid"``.
        timestamp: value stored under ``"ts"``.
        args: value stored under ``"args"``; stored as-is, not copied.

    Returns:
        dict: the event, with the fixed fields ``"ph": "i"`` and ``"s": "t"``
        followed by the caller-supplied fields.
    """
    # Fixed header first so the key order matches the literal form exactly.
    event = {"ph": "i", "s": "t"}
    event.update(name=name, pid=pid, tid=tid, ts=timestamp, args=args)
    return event
class StatefulTensorMemoryEvent:
    """One memory sample for a single device type, recorded at one timestamp.

    The process id and thread identifier are captured when the event is
    constructed, not when it is serialized.
    """

    # Name under which every event of this kind appears in the trace.
    EVENT_NAME = "[statefulTensorMemory]"

    def __init__(self, timestamp: int, device_type: DeviceType, bytes_: int) -> None:
        """Record a sample.

        Args:
            timestamp: time of the sample; emitted as the event's ``"ts"``.
            device_type: which kind of device the sample describes.
            bytes_: memory amount for that device; emitted as ``"Bytes"``.
        """
        self.pid = os.getpid()
        self.tid = threading.get_ident()
        self.timestamp = timestamp
        self.device_type = device_type
        # Only CUDA samples carry a real device index; every other type gets -1.
        if device_type == DeviceType.CUDA:
            self.device_id = torch.cuda.current_device()
        else:
            self.device_id = -1
        self.bytes = bytes_

    def state_dict(self):
        """Return this sample as an instant-event dictionary."""
        details = {
            "Device Type": self.device_type.value,
            "Device Id": self.device_id,
            "Bytes": self.bytes,
        }
        return generic_instant_event(
            StatefulTensorMemoryEvent.EVENT_NAME,
            self.pid,
            self.tid,
            self.timestamp,
            details,
        )
class StatefulTensorMemoryTracer:
    """Collects ``StatefulTensorMemoryEvent`` samples while tracing is active."""

    def __init__(self) -> None:
        # Samples gathered since the most recent start_trace().
        self.events: List[StatefulTensorMemoryEvent] = []
        # True only between start_trace() and stop_trace().
        self._tracing = False

    def sample(self):
        """Record one CUDA and one CPU memory sample if tracing is active.

        Both events share a single timestamp. When tracing is not active the
        call returns immediately, without reading the memory manager or the
        clock.
        """
        if not self._tracing:
            return
        total_mem = StatefulTensor.GST_MGR.total_mem
        # Read both values before taking the timestamp so nothing is appended
        # if either lookup fails.
        cuda_mem = total_mem["cuda"]
        cpu_mem = total_mem["cpu"]
        timestamp = get_timestamp_us()
        self.events.append(StatefulTensorMemoryEvent(timestamp, DeviceType.CUDA, cuda_mem))
        self.events.append(StatefulTensorMemoryEvent(timestamp, DeviceType.CPU, cpu_mem))

    def start_trace(self):
        """Discard previously collected events and begin recording."""
        self.events.clear()
        self._tracing = True

    def stop_trace(self):
        """Stop recording; already collected events are kept."""
        self._tracing = False

    def state_dict(self):
        """Return the collected events as a list of event dictionaries, in order."""
        return [event.state_dict() for event in self.events]
class StatefulTensorMemoryTracerHook(BaseOpHook):
    """Op hook that asks a tracer for a memory sample at each callback.

    Every callback behaves the same way: it samples when the hook is enabled
    and does nothing otherwise. The hook starts out disabled.
    """

    def __init__(self, tracer: StatefulTensorMemoryTracer):
        super().__init__()
        self.tracer = tracer
        self._enable = False

    def _sample_if_enabled(self):
        """Forward to ``tracer.sample()`` unless the hook is disabled."""
        if not self._enable:
            return
        self.tracer.sample()

    def pre_fwd_exec(self, module: torch.nn.Module, *args):
        self._sample_if_enabled()

    def post_fwd_exec(self, module: torch.nn.Module, *args):
        self._sample_if_enabled()

    def pre_bwd_exec(self, module: torch.nn.Module, input_, output):
        self._sample_if_enabled()

    def post_bwd_exec(self, module: torch.nn.Module, input_):
        self._sample_if_enabled()

    def post_iter(self):
        self._sample_if_enabled()

    def enable(self):
        """Turn sampling on for subsequent callbacks."""
        self._enable = True

    def disable(self):
        """Turn sampling off for subsequent callbacks."""
        self._enable = False
class StatefulTensorMemoryProfilerExtention(ProfilerExtension):
    """Profiler extension that adds stateful-tensor memory samples to a trace.

    It owns a tracer and a hook bound to that tracer, and attaches the hook to
    the given engine the first time tracing is prepared.
    """

    def __init__(self, engine: Engine) -> None:
        self.engine = engine
        self.tracer = StatefulTensorMemoryTracer()
        self.hook = StatefulTensorMemoryTracerHook(self.tracer)
        # Whether self.hook has already been handed to engine.add_hook().
        self.hook_registered = False

    def prepare_trace(self):
        """Enable the hook and register it with the engine if not yet registered."""
        self.hook.enable()
        if self.hook_registered:
            return
        self.engine.add_hook(self.hook)
        self.hook_registered = True

    def start_trace(self):
        """Prepare the hook, then start recording with a cleared event list."""
        self.prepare_trace()
        self.tracer.start_trace()

    def stop_trace(self):
        """Stop recording, disable the hook and ask the engine to remove it."""
        self.tracer.stop_trace()
        self.hook.disable()
        if not self.hook_registered:
            return
        self.engine.remove_hook(self.hook)
        # remove_hook is not implemented now
        # FIXME(ver217): uncomment below line when remove_hook is implemented
        # self.hook_registered = False

    def extend_chrome_trace(self, trace: dict) -> dict:
        """Append the collected events to ``trace["traceEvents"]``.

        The list is extended in place and the same ``trace`` object is returned.
        """
        collected = self.tracer.state_dict()
        trace["traceEvents"].extend(collected)
        return trace