mirror of https://github.com/hpcaitech/ColossalAI
You cannot select more than 25 topics.
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
134 lines
3.9 KiB
134 lines
3.9 KiB
import os
|
|
import threading
|
|
import time
|
|
import torch
|
|
from enum import Enum
|
|
from typing import List
|
|
from colossalai.gemini.stateful_tensor import StatefulTensor
|
|
from colossalai.gemini.ophooks import BaseOpHook
|
|
from colossalai.engine import Engine
|
|
from colossalai.utils.profiler.extention import ProfilerExtension
|
|
|
|
|
|
class DeviceType(Enum):
    """Kind of device a memory sample was taken on.

    The integer ``value`` of each member is what gets written into the
    exported trace event payload.
    """

    # Host memory.
    CPU = 0
    # GPU memory.
    CUDA = 1
|
|
|
|
|
|
def get_timestamp_us():
    """Return the current wall-clock time in whole microseconds since the epoch.

    The fractional part is truncated by ``int``.
    """
    seconds = time.time()
    return int(seconds * 1e6)
|
|
|
|
|
|
def generic_instant_event(name, pid, tid, timestamp, args):
    """Build an instant-event record (``ph='i'``) with thread scope (``s='t'``).

    Args:
        name: Event name.
        pid: Process id the event belongs to.
        tid: Thread id the event belongs to.
        timestamp: Event timestamp, stored under ``'ts'``.
        args: Arbitrary payload stored under ``'args'`` (not copied).

    Returns:
        A new dict with keys ``ph, s, name, pid, tid, ts, args`` in that order.
    """
    return dict(
        ph='i',
        s='t',
        name=name,
        pid=pid,
        tid=tid,
        ts=timestamp,
        args=args,
    )
|
|
|
|
|
|
class StatefulTensorMemoryEvent:
    """One memory sample for a single device, exportable as an instant event.

    Captures the current process id and calling thread id at construction
    time, together with the caller-supplied timestamp, device type and byte
    count.
    """

    EVENT_NAME = '[statefulTensorMemory]'

    def __init__(self, timestamp: int, device_type: DeviceType, bytes_: int) -> None:
        # Where the sample was taken.
        self.pid = os.getpid()
        self.tid = threading.get_ident()
        self.timestamp = timestamp
        self.device_type = device_type
        # Only CUDA samples carry a real device index; CPU samples use -1.
        if device_type == DeviceType.CUDA:
            self.device_id = torch.cuda.current_device()
        else:
            self.device_id = -1
        self.bytes = bytes_

    def state_dict(self):
        """Return this sample as an instant-event dict (see ``generic_instant_event``)."""
        payload = {
            'Device Type': self.device_type.value,
            'Device Id': self.device_id,
            'Bytes': self.bytes,
        }
        return generic_instant_event(
            StatefulTensorMemoryEvent.EVENT_NAME,
            self.pid,
            self.tid,
            self.timestamp,
            payload,
        )
|
|
|
|
|
|
class StatefulTensorMemoryTracer:
    """Collect samples of the stateful-tensor memory totals.

    Samples are only recorded between ``start_trace`` and ``stop_trace``.
    Each ``sample`` call made while tracing appends one CUDA event and one
    CPU event that share the same timestamp.
    """

    def __init__(self) -> None:
        self.events: List[StatefulTensorMemoryEvent] = []
        self._tracing = False

    def sample(self):
        """Record the current CUDA and CPU memory totals if tracing is active.

        When not tracing this is a no-op: it returns before querying the
        memory manager or taking a timestamp, since the result would be
        discarded anyway.
        """
        if not self._tracing:
            return
        total_mem = StatefulTensor.GST_MGR.total_mem
        cuda_mem = total_mem['cuda']
        cpu_mem = total_mem['cpu']
        timestamp = get_timestamp_us()
        self.events.append(StatefulTensorMemoryEvent(timestamp, DeviceType.CUDA, cuda_mem))
        self.events.append(StatefulTensorMemoryEvent(timestamp, DeviceType.CPU, cpu_mem))

    def start_trace(self):
        """Discard previously collected events and begin recording."""
        self.events.clear()
        self._tracing = True

    def stop_trace(self):
        """Stop recording; already collected events are kept."""
        self._tracing = False

    def state_dict(self):
        """Return the ``state_dict()`` of every collected event, in order."""
        return [event.state_dict() for event in self.events]
|
|
|
|
|
|
class StatefulTensorMemoryTracerHook(BaseOpHook):
    """Op hook that asks a tracer for a memory sample from each hook callback.

    Sampling is requested before and after forward, before and after
    backward, and after each iteration, but only while the hook is enabled.
    """

    def __init__(self, tracer: StatefulTensorMemoryTracer):
        super().__init__()
        self.tracer = tracer
        self._enable = False

    def _sample_if_enabled(self):
        # Shared body of every callback: forward to the tracer only when enabled.
        if not self._enable:
            return
        self.tracer.sample()

    def pre_fwd_exec(self, module: torch.nn.Module, *args):
        self._sample_if_enabled()

    def post_fwd_exec(self, module: torch.nn.Module, *args):
        self._sample_if_enabled()

    def pre_bwd_exec(self, module: torch.nn.Module, input_, output):
        self._sample_if_enabled()

    def post_bwd_exec(self, module: torch.nn.Module, input_):
        self._sample_if_enabled()

    def post_iter(self):
        self._sample_if_enabled()

    def enable(self):
        """Start forwarding callbacks to the tracer."""
        self._enable = True

    def disable(self):
        """Stop forwarding callbacks to the tracer."""
        self._enable = False
|
|
|
|
|
|
class StatefulTensorMemoryProfilerExtention(ProfilerExtension):
    """Profiler extension that adds stateful-tensor memory samples to a trace.

    Owns a ``StatefulTensorMemoryTracer`` plus a hook feeding it. The hook is
    added to ``engine`` the first time tracing is prepared.
    """

    def __init__(self, engine: Engine) -> None:
        self.engine = engine
        self.tracer = StatefulTensorMemoryTracer()
        self.hook = StatefulTensorMemoryTracerHook(self.tracer)
        self.hook_registered = False

    def prepare_trace(self):
        """Enable the hook and add it to the engine if that has not happened yet."""
        self.hook.enable()
        if self.hook_registered:
            return
        self.engine.add_hook(self.hook)
        self.hook_registered = True

    def start_trace(self):
        """Prepare the hook, then begin recording samples."""
        self.prepare_trace()
        self.tracer.start_trace()

    def stop_trace(self):
        """Stop recording, disable the hook and ask the engine to remove it."""
        self.tracer.stop_trace()
        self.hook.disable()
        if not self.hook_registered:
            return
        self.engine.remove_hook(self.hook)
        # FIXME(ver217): remove_hook is not implemented yet, so the hook is
        # still attached afterwards. Once it is implemented, also reset the flag:
        # self.hook_registered = False

    def extend_chrome_trace(self, trace: dict) -> dict:
        """Append collected events to ``trace['traceEvents']`` in place and return ``trace``."""
        trace['traceEvents'].extend(self.tracer.state_dict())
        return trace
|