mirror of https://github.com/hpcaitech/ColossalAI
[utils] refactor profiler (#837)
* add model data profiler * add a subclass of torch.profiler.profile * refactor folder structure * remove redundant codes * polish code * use GeminiMemoryManager * fix import path * fix stm profiler ext * polish comments * remove useless file
pull/853/head
parent
62f059251b
commit
232142f402
|
@ -1,6 +1,2 @@
|
|||
from .comm_profiler import CommProfiler
|
||||
from .pcie_profiler import PcieProfiler
|
||||
from .prof_utils import ProfilerContext, BaseProfiler
|
||||
from .mem_profiler import MemProfiler
|
||||
|
||||
__all__ = ['BaseProfiler', 'CommProfiler', 'PcieProfiler', 'MemProfiler', 'ProfilerContext']
|
||||
from .legacy import *
|
||||
from .profiler import profile
|
||||
|
|
|
@ -0,0 +1,20 @@
|
|||
from abc import ABC, abstractmethod
|
||||
|
||||
|
||||
class ProfilerExtension(ABC):
    """Interface for add-ons that take part in a profiler's trace lifecycle.

    Every hook is abstract, so a concrete extension has to implement all four
    before it can be instantiated. The abstract bodies themselves do nothing.
    """

    @abstractmethod
    def prepare_trace(self):
        """Hook for the preparation phase of a trace."""

    @abstractmethod
    def start_trace(self):
        """Hook for the moment trace recording starts."""

    @abstractmethod
    def stop_trace(self):
        """Hook for the moment trace recording stops."""

    @abstractmethod
    def extend_chrome_trace(self, trace: dict) -> dict:
        """Hook for augmenting an exported Chrome trace.

        Args:
            trace (dict): the trace content to extend.

        Returns:
            dict: the trace the caller should keep using.
        """
|
|
@ -0,0 +1,6 @@
|
|||
from .comm_profiler import CommProfiler
|
||||
from .pcie_profiler import PcieProfiler
|
||||
from .prof_utils import ProfilerContext, BaseProfiler
|
||||
from .mem_profiler import MemProfiler
|
||||
|
||||
__all__ = ['BaseProfiler', 'CommProfiler', 'PcieProfiler', 'MemProfiler', 'ProfilerContext']
|
|
@ -3,7 +3,7 @@ from typing import Union
|
|||
from colossalai.engine import Engine
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
from colossalai.engine.ophooks import MemTracerOpHook
|
||||
from colossalai.utils.profiler import BaseProfiler
|
||||
from colossalai.utils.profiler.legacy.prof_utils import BaseProfiler
|
||||
|
||||
|
||||
class MemProfiler(BaseProfiler):
|
|
@ -0,0 +1,201 @@
|
|||
import os
|
||||
from typing import List
|
||||
from colossalai.engine import Engine
|
||||
from torch.profiler import profile as torch_profile
|
||||
from torch.profiler.profiler import ProfilerAction
|
||||
from typing import Any, Callable, Iterable, Optional
|
||||
from torch.autograd import ProfilerActivity
|
||||
import json
|
||||
import os
|
||||
import tempfile
|
||||
import gzip
|
||||
from colossalai.utils.profiler.extention import ProfilerExtension
|
||||
from colossalai.utils.profiler.stateful_tensor_mem_extention import StatefulTensorMemoryProfilerExtention
|
||||
from colossalai.logging import get_dist_logger
|
||||
|
||||
|
||||
class profile(torch_profile):
    """Profiler context manager.

    Args:
        activities (iterable): list of activity groups (CPU, CUDA) to use in profiling, supported values:
            ``torch.profiler.ProfilerActivity.CPU``, ``torch.profiler.ProfilerActivity.CUDA``.
            Default value: ProfilerActivity.CPU and (when available) ProfilerActivity.CUDA.
        schedule (callable): callable that takes step (int) as a single parameter and returns
            ``ProfilerAction`` value that specifies the profiler action to perform at each step.
        on_trace_ready (callable): callable that is called at each step when ``schedule``
            returns ``ProfilerAction.RECORD_AND_SAVE`` during the profiling.
        engine (Optional[Engine], optional): An ``Engine`` instance. Defaults to None.
        record_shapes (bool): save information about operator's input shapes.
        profile_memory (bool): track tensor memory allocation/deallocation.
        with_stack (bool): record source information (file and line number) for the ops.
        with_flops (bool): use formula to estimate the FLOPs (floating point operations) of specific operators
            (matrix multiplication and 2D convolution).
        with_modules (bool): record module hierarchy (including function names)
            corresponding to the callstack of the op. e.g. If module A's forward calls
            module B's forward which contains an aten::add op,
            then aten::add's module hierarchy is A.B
            Note that this support exists, at the moment, only for TorchScript models
            and not eager mode models.
        profile_stateful_tensor_memory (bool): track stateful tensor memory usage. Requires ``engine``;
            when ``engine`` is None the option is ignored and a warning is logged.

    .. note::
        Use :func:`~torch.profiler.schedule` to generate the callable schedule.
        Non-default schedules are useful when profiling long training jobs
        and allow the user to obtain multiple traces at the different iterations
        of the training process.
        The default schedule simply records all the events continuously for the
        duration of the context manager.

    .. note::
        Use :func:`~torch.profiler.tensorboard_trace_handler` to generate result files for TensorBoard:

        ``on_trace_ready=torch.profiler.tensorboard_trace_handler(dir_name)``

        After profiling, result files can be found in the specified directory. Use the command:

        ``tensorboard --logdir dir_name``

        to see the results in TensorBoard.
        For more information, see
        `PyTorch Profiler TensorBoard Plugin <https://github.com/pytorch/kineto/tree/master/tb_plugin>`__

    .. note::
        Enabling shape and stack tracing results in additional overhead.
        When record_shapes=True is specified, profiler will temporarily hold references to the tensors;
        that may further prevent certain optimizations that depend on the reference count and introduce
        extra tensor copies.

    Examples:

    .. code-block:: python

        with torch.profiler.profile(
            activities=[
                torch.profiler.ProfilerActivity.CPU,
                torch.profiler.ProfilerActivity.CUDA,
            ]
        ) as p:
            code_to_profile()
        print(p.key_averages().table(
            sort_by="self_cuda_time_total", row_limit=-1))

    Using the profiler's ``schedule``, ``on_trace_ready`` and ``step`` functions:

    .. code-block:: python

        # Non-default profiler schedule allows user to turn profiler on and off
        # on different iterations of the training loop;
        # trace_handler is called every time a new trace becomes available
        def trace_handler(prof):
            print(prof.key_averages().table(
                sort_by="self_cuda_time_total", row_limit=-1))
            # prof.export_chrome_trace("/tmp/test_trace_" + str(prof.step_num) + ".json")

        with torch.profiler.profile(
            activities=[
                torch.profiler.ProfilerActivity.CPU,
                torch.profiler.ProfilerActivity.CUDA,
            ],

            # In this example with wait=1, warmup=1, active=2,
            # profiler will skip the first step/iteration,
            # start warming up on the second, record
            # the third and the fourth iterations,
            # after which the trace will become available
            # and on_trace_ready (when set) is called;
            # the cycle repeats starting with the next step

            schedule=torch.profiler.schedule(
                wait=1,
                warmup=1,
                active=2),
            on_trace_ready=trace_handler
            # on_trace_ready=torch.profiler.tensorboard_trace_handler('./log')
            # used when outputting for tensorboard
            ) as p:
                for iter in range(N):
                    code_iteration_to_profile(iter)
                    # send a signal to the profiler that the next iteration has started
                    p.step()
    """

    def __init__(self,
                 *,
                 activities: Optional[Iterable[ProfilerActivity]] = None,
                 schedule: Optional[Callable[[int], ProfilerAction]] = None,
                 on_trace_ready: Optional[Callable[..., Any]] = None,
                 engine: Optional[Engine] = None,
                 record_shapes: bool = False,
                 profile_memory: bool = False,
                 with_stack: bool = False,
                 with_flops: bool = False,
                 with_modules: bool = False,
                 profile_stateful_tensor_memory: bool = False) -> None:
        super().__init__(activities=activities,
                         schedule=schedule,
                         on_trace_ready=on_trace_ready,
                         record_shapes=record_shapes,
                         profile_memory=profile_memory,
                         with_stack=with_stack,
                         with_flops=with_flops,
                         with_modules=with_modules)
        self._logger = get_dist_logger()
        # Extensions are driven alongside the parent profiler in the trace hooks below.
        self.extentions: List[ProfilerExtension] = []
        if profile_stateful_tensor_memory:
            if engine is None:
                # The message names the keyword the caller actually passed.
                self._logger.warning('Ignore "profile_stateful_tensor_memory" since engine is None', ranks=[0])
            else:
                self.extentions.append(StatefulTensorMemoryProfilerExtention(engine))

    def prepare_trace(self) -> None:
        """Prepare the parent profiler, then every registered extension.

        The parent hook is looked up by name (``prepare_trace`` first, then
        ``_start_warmup``) -- presumably because torch releases differ in how
        they name it; confirm against the supported torch versions.
        """
        if hasattr(super(), 'prepare_trace'):
            super().prepare_trace()
        elif hasattr(super(), '_start_warmup'):
            super()._start_warmup()
        for ext in self.extentions:
            ext.prepare_trace()

    def _start_warmup(self):
        """Alias that routes the private hook name to :meth:`prepare_trace`."""
        self.prepare_trace()

    def start_trace(self):
        """Start the parent profiler's trace, then every registered extension's.

        The parent hook is looked up by name (``_start_trace`` first, then
        ``start_trace``).
        """
        if hasattr(super(), '_start_trace'):
            super()._start_trace()
        elif hasattr(super(), 'start_trace'):
            super().start_trace()
        for ext in self.extentions:
            ext.start_trace()

    def _start_trace(self):
        """Alias that routes the private hook name to :meth:`start_trace`."""
        self.start_trace()

    def stop_trace(self):
        """Stop the parent profiler's trace, then every registered extension's.

        The parent hook is looked up by name (``_stop_trace`` first, then
        ``stop_trace``).
        """
        if hasattr(super(), '_stop_trace'):
            super()._stop_trace()
        elif hasattr(super(), 'stop_trace'):
            super().stop_trace()
        for ext in self.extentions:
            ext.stop_trace()

    def _stop_trace(self):
        """Alias that routes the private hook name to :meth:`stop_trace`."""
        self.stop_trace()

    def export_chrome_trace(self, path: str):
        """
        Exports the collected trace in Chrome JSON format.

        The underlying profiler first writes its trace to a temporary file; that
        JSON is loaded, passed through ``extend_chrome_trace`` of every registered
        extension in order, and the result is written to ``path``.

        Args:
            path (str): destination file. A name ending in ``.gz`` is written
                through ``gzip``; anything else is written as plain text.

        Returns:
            Whatever the underlying profiler's ``export_chrome_trace`` returned.

        Raises:
            AssertionError: if ``self.profiler`` is not set.
        """
        assert self.profiler
        # Only a unique file name is needed: the handle is closed right away so the
        # underlying profiler can write to that path on its own.
        fp = tempfile.NamedTemporaryFile('w+t', suffix='.json', delete=False)
        fp.close()
        try:
            retvalue = self.profiler.export_chrome_trace(fp.name)
            with open(fp.name) as fin:
                trace = json.load(fin)
            for ext in self.extentions:
                trace = ext.extend_chrome_trace(trace)
            open_func = gzip.open if path.endswith('.gz') else open
            with open_func(path, 'wt') as fout:
                json.dump(trace, fout)
        finally:
            # Remove the intermediate file even when exporting, parsing, an
            # extension, or the final write fails.
            os.remove(fp.name)
        return retvalue
|
|
@ -0,0 +1,133 @@
|
|||
import os
|
||||
import threading
|
||||
import time
|
||||
import torch
|
||||
from enum import Enum
|
||||
from typing import List
|
||||
from colossalai.gemini.stateful_tensor import StatefulTensor
|
||||
from colossalai.engine.ophooks import BaseOpHook
|
||||
from colossalai.engine import Engine
|
||||
from colossalai.utils.profiler.extention import ProfilerExtension
|
||||
|
||||
|
||||
class DeviceType(Enum):
    """Device kinds a stateful-tensor memory sample can refer to."""

    # The integer values are emitted verbatim as 'Device Type' in exported events.
    CPU = 0
    CUDA = 1
|
||||
|
||||
|
||||
def get_timestamp_us():
    """Return the current ``time.time()`` reading as whole microseconds (int)."""
    seconds = time.time()
    return int(seconds * 1e6)
|
||||
|
||||
|
||||
def generic_instant_event(name, pid, tid, timestamp, args):
    """Build one trace-event dict with phase ``'i'`` and scope ``'t'``.

    Args:
        name: value stored under ``'name'``.
        pid: value stored under ``'pid'``.
        tid: value stored under ``'tid'``.
        timestamp: value stored under ``'ts'``.
        args: value stored under ``'args'``.

    Returns:
        dict: the event, with keys in the order
        ``ph, s, name, pid, tid, ts, args``.
    """
    # Fixed part first: 'ph' and 's' are constant for every event built here.
    event = {'ph': 'i', 's': 't'}
    event['name'] = name
    event['pid'] = pid
    event['tid'] = tid
    event['ts'] = timestamp
    event['args'] = args
    return event
|
||||
|
||||
|
||||
class StatefulTensorMemoryEvent:
    """One stateful-tensor memory sample, convertible to a trace-event dict."""

    # Name given to every event produced by this class.
    EVENT_NAME = '[statefulTensorMemory]'

    def __init__(self, timestamp: int, device_type: DeviceType, bytes_: int) -> None:
        """Record a sample together with the creating process and thread ids.

        Args:
            timestamp (int): stored as the event timestamp.
            device_type (DeviceType): device the sample refers to.
            bytes_ (int): stored as the sampled byte count.
        """
        self.timestamp = timestamp
        self.device_type = device_type
        self.bytes = bytes_
        # Process and thread ids are captured at construction time.
        self.pid = os.getpid()
        self.tid = threading.get_ident()
        # Only CUDA samples carry a real device index; everything else gets -1.
        if device_type == DeviceType.CUDA:
            self.device_id = torch.cuda.current_device()
        else:
            self.device_id = -1

    def state_dict(self):
        """Return this sample as an instant-event dict."""
        details = {
            'Device Type': self.device_type.value,
            'Device Id': self.device_id,
            'Bytes': self.bytes
        }
        return generic_instant_event(StatefulTensorMemoryEvent.EVENT_NAME, self.pid, self.tid, self.timestamp,
                                     details)
|
||||
|
||||
|
||||
class StatefulTensorMemoryTracer:
    """Collects stateful-tensor memory samples while tracing is switched on."""

    def __init__(self) -> None:
        """Start with tracing off and no recorded events."""
        self._tracing = False
        self.events: List[StatefulTensorMemoryEvent] = []

    def sample(self):
        """Read the current CUDA and CPU totals; record both when tracing.

        The totals and the timestamp are read on every call; the two events
        (CUDA first, then CPU) are appended only while tracing is active.
        """
        cuda_mem = StatefulTensor.GST_MGR.total_mem['cuda']
        cpu_mem = StatefulTensor.GST_MGR.total_mem['cpu']
        timestamp = get_timestamp_us()
        if not self._tracing:
            return
        for device, amount in ((DeviceType.CUDA, cuda_mem), (DeviceType.CPU, cpu_mem)):
            self.events.append(StatefulTensorMemoryEvent(timestamp, device, amount))

    def start_trace(self):
        """Drop previously recorded events and switch tracing on."""
        self.events.clear()
        self._tracing = True

    def stop_trace(self):
        """Switch tracing off; recorded events are kept."""
        self._tracing = False

    def state_dict(self):
        """Return the ``state_dict()`` of every recorded event, in order."""
        return list(map(lambda event: event.state_dict(), self.events))
|
||||
|
||||
|
||||
class StatefulTensorMemoryTracerHook(BaseOpHook):
    """Op hook that asks a tracer to sample at each hook point while enabled."""

    def __init__(self, tracer: StatefulTensorMemoryTracer):
        """Bind the hook to ``tracer``; the hook starts out disabled."""
        super().__init__()
        self.tracer = tracer
        self._enable = False

    def pre_fwd_exec(self, module: torch.nn.Module, *args):
        """Sample when enabled; ``module`` and ``args`` are not used."""
        if not self._enable:
            return
        self.tracer.sample()

    def post_fwd_exec(self, module: torch.nn.Module, *args):
        """Sample when enabled; ``module`` and ``args`` are not used."""
        if not self._enable:
            return
        self.tracer.sample()

    def pre_bwd_exec(self, module: torch.nn.Module, input_, output):
        """Sample when enabled; the arguments are not used."""
        if not self._enable:
            return
        self.tracer.sample()

    def post_bwd_exec(self, module: torch.nn.Module, input_):
        """Sample when enabled; the arguments are not used."""
        if not self._enable:
            return
        self.tracer.sample()

    def post_iter(self):
        """Sample when enabled."""
        if not self._enable:
            return
        self.tracer.sample()

    def enable(self):
        """Turn sampling on for all hook points."""
        self._enable = True

    def disable(self):
        """Turn sampling off for all hook points."""
        self._enable = False
|
||||
|
||||
|
||||
class StatefulTensorMemoryProfilerExtention(ProfilerExtension):
    """Profiler extension that adds stateful-tensor memory samples to a trace.

    It owns a tracer and an op hook bound to that tracer, registers the hook on
    the given engine, and appends the tracer's events to exported Chrome traces.
    """

    def __init__(self, engine: Engine) -> None:
        """Create the tracer and hook; nothing is registered on ``engine`` yet."""
        self.engine = engine
        self.tracer = StatefulTensorMemoryTracer()
        self.hook = StatefulTensorMemoryTracerHook(self.tracer)
        self.hook_registered = False

    def prepare_trace(self):
        """Enable the hook and register it on the engine the first time."""
        self.hook.enable()
        if self.hook_registered:
            return
        self.engine.add_hook(self.hook)
        self.hook_registered = True

    def start_trace(self):
        """Run :meth:`prepare_trace`, then start the tracer."""
        self.prepare_trace()
        self.tracer.start_trace()

    def stop_trace(self):
        """Stop the tracer, disable the hook, and ask the engine to remove it."""
        self.tracer.stop_trace()
        self.hook.disable()
        if not self.hook_registered:
            return
        self.engine.remove_hook(self.hook)
        # remove_hook is not implemented now
        # FIXME(ver217): uncomment below line when remove_hook is implemented
        # self.hook_registered = False

    def extend_chrome_trace(self, trace: dict) -> dict:
        """Append the tracer's events to ``trace['traceEvents']`` in place.

        Args:
            trace (dict): trace containing a ``'traceEvents'`` list.

        Returns:
            dict: the same ``trace`` object.
        """
        trace_events = trace['traceEvents']
        trace_events.extend(self.tracer.state_dict())
        return trace
|
Loading…
Reference in New Issue