ColossalAI/colossalai/utils/profiler/stateful_tensor_mem_extenti...

134 lines
3.9 KiB
Python

import os
import threading
import time
import torch
from enum import Enum
from typing import List
from colossalai.gemini.stateful_tensor import StatefulTensor
from colossalai.gemini.ophooks import BaseOpHook
from colossalai.engine import Engine
from colossalai.utils.profiler.extention import ProfilerExtension
class DeviceType(Enum):
CPU = 0
CUDA = 1
def get_timestamp_us():
return int(time.time() * 1e6)
def generic_instant_event(name, pid, tid, timestamp, args):
return {'ph': 'i', 's': 't', 'name': name, 'pid': pid, 'tid': tid, 'ts': timestamp, 'args': args}
class StatefulTensorMemoryEvent:
EVENT_NAME = '[statefulTensorMemory]'
def __init__(self, timestamp: int, device_type: DeviceType, bytes_: int) -> None:
self.pid = os.getpid()
self.tid = threading.get_ident()
self.timestamp = timestamp
self.device_type = device_type
self.device_id = torch.cuda.current_device() if device_type == DeviceType.CUDA else -1
self.bytes = bytes_
def state_dict(self):
return generic_instant_event(StatefulTensorMemoryEvent.EVENT_NAME, self.pid, self.tid, self.timestamp, {
'Device Type': self.device_type.value,
'Device Id': self.device_id,
'Bytes': self.bytes
})
class StatefulTensorMemoryTracer:
def __init__(self) -> None:
self.events: List[StatefulTensorMemoryEvent] = []
self._tracing = False
def sample(self):
cuda_mem = StatefulTensor.GST_MGR.total_mem['cuda']
cpu_mem = StatefulTensor.GST_MGR.total_mem['cpu']
timestamp = get_timestamp_us()
if self._tracing:
self.events.append(StatefulTensorMemoryEvent(timestamp, DeviceType.CUDA, cuda_mem))
self.events.append(StatefulTensorMemoryEvent(timestamp, DeviceType.CPU, cpu_mem))
def start_trace(self):
self.events.clear()
self._tracing = True
def stop_trace(self):
self._tracing = False
def state_dict(self):
return [event.state_dict() for event in self.events]
class StatefulTensorMemoryTracerHook(BaseOpHook):
def __init__(self, tracer: StatefulTensorMemoryTracer):
super().__init__()
self.tracer = tracer
self._enable = False
def pre_fwd_exec(self, module: torch.nn.Module, *args):
if self._enable:
self.tracer.sample()
def post_fwd_exec(self, module: torch.nn.Module, *args):
if self._enable:
self.tracer.sample()
def pre_bwd_exec(self, module: torch.nn.Module, input_, output):
if self._enable:
self.tracer.sample()
def post_bwd_exec(self, module: torch.nn.Module, input_):
if self._enable:
self.tracer.sample()
def post_iter(self):
if self._enable:
self.tracer.sample()
def enable(self):
self._enable = True
def disable(self):
self._enable = False
class StatefulTensorMemoryProfilerExtention(ProfilerExtension):
def __init__(self, engine: Engine) -> None:
self.engine = engine
self.tracer = StatefulTensorMemoryTracer()
self.hook = StatefulTensorMemoryTracerHook(self.tracer)
self.hook_registered = False
def prepare_trace(self):
self.hook.enable()
if not self.hook_registered:
self.engine.add_hook(self.hook)
self.hook_registered = True
def start_trace(self):
self.prepare_trace()
self.tracer.start_trace()
def stop_trace(self):
self.tracer.stop_trace()
self.hook.disable()
if self.hook_registered:
self.engine.remove_hook(self.hook)
# remove_hook is not implemented now
# FIXME(ver217): uncomment below line when remove_hook is implemented
# self.hook_registered = False
def extend_chrome_trace(self, trace: dict) -> dict:
trace['traceEvents'].extend(self.tracer.state_dict())
return trace