diff --git a/colossalai/utils/profiler/__init__.py b/colossalai/utils/profiler/__init__.py
index 810a43394..90eab67c4 100644
--- a/colossalai/utils/profiler/__init__.py
+++ b/colossalai/utils/profiler/__init__.py
@@ -1,6 +1,2 @@
-from .comm_profiler import CommProfiler
-from .pcie_profiler import PcieProfiler
-from .prof_utils import ProfilerContext, BaseProfiler
-from .mem_profiler import MemProfiler
-
-__all__ = ['BaseProfiler', 'CommProfiler', 'PcieProfiler', 'MemProfiler', 'ProfilerContext']
\ No newline at end of file
+from .legacy import *
+from .profiler import profile
diff --git a/colossalai/utils/profiler/extention.py b/colossalai/utils/profiler/extention.py
new file mode 100644
index 000000000..6726a683c
--- /dev/null
+++ b/colossalai/utils/profiler/extention.py
@@ -0,0 +1,20 @@
+from abc import ABC, abstractmethod
+
+
+class ProfilerExtension(ABC):
+
+    @abstractmethod
+    def prepare_trace(self):
+        pass
+
+    @abstractmethod
+    def start_trace(self):
+        pass
+
+    @abstractmethod
+    def stop_trace(self):
+        pass
+
+    @abstractmethod
+    def extend_chrome_trace(self, trace: dict) -> dict:
+        pass
diff --git a/colossalai/utils/profiler/legacy/__init__.py b/colossalai/utils/profiler/legacy/__init__.py
new file mode 100644
index 000000000..849c7fca3
--- /dev/null
+++ b/colossalai/utils/profiler/legacy/__init__.py
@@ -0,0 +1,6 @@
+from .comm_profiler import CommProfiler
+from .pcie_profiler import PcieProfiler
+from .prof_utils import ProfilerContext, BaseProfiler
+from .mem_profiler import MemProfiler
+
+__all__ = ['BaseProfiler', 'CommProfiler', 'PcieProfiler', 'MemProfiler', 'ProfilerContext']
diff --git a/colossalai/utils/profiler/comm_profiler.py b/colossalai/utils/profiler/legacy/comm_profiler.py
similarity index 100%
rename from colossalai/utils/profiler/comm_profiler.py
rename to colossalai/utils/profiler/legacy/comm_profiler.py
diff --git a/colossalai/utils/profiler/mem_profiler.py b/colossalai/utils/profiler/legacy/mem_profiler.py
similarity index 95%
rename from colossalai/utils/profiler/mem_profiler.py
rename to colossalai/utils/profiler/legacy/mem_profiler.py
index 662417dfd..c4d7ca2ef 100644
--- a/colossalai/utils/profiler/mem_profiler.py
+++ b/colossalai/utils/profiler/legacy/mem_profiler.py
@@ -3,7 +3,7 @@ from typing import Union
 from colossalai.engine import Engine
 from torch.utils.tensorboard import SummaryWriter
 from colossalai.engine.ophooks import MemTracerOpHook
-from colossalai.utils.profiler import BaseProfiler
+from colossalai.utils.profiler.legacy.prof_utils import BaseProfiler
 
 
 class MemProfiler(BaseProfiler):
diff --git a/colossalai/utils/profiler/pcie_profiler.py b/colossalai/utils/profiler/legacy/pcie_profiler.py
similarity index 100%
rename from colossalai/utils/profiler/pcie_profiler.py
rename to colossalai/utils/profiler/legacy/pcie_profiler.py
diff --git a/colossalai/utils/profiler/prof_utils.py b/colossalai/utils/profiler/legacy/prof_utils.py
similarity index 100%
rename from colossalai/utils/profiler/prof_utils.py
rename to colossalai/utils/profiler/legacy/prof_utils.py
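The `extention.py` ABC above defines the four hook points that the new `profile` wrapper (next diff) drives around `torch.profiler`'s own lifecycle. For orientation only (not part of this patch; the class name `WallClockExtension` is hypothetical), a minimal extension that stamps trace start/stop times into the exported Chrome trace as instant events could look like:

    import time

    from colossalai.utils.profiler.extention import ProfilerExtension


    class WallClockExtension(ProfilerExtension):
        """Toy extension: record trace start/stop wall-clock times and
        append them to the Chrome trace as instant events."""

        def __init__(self):
            self.start_ts = None
            self.stop_ts = None

        def prepare_trace(self):
            pass    # nothing to warm up in this toy example

        def start_trace(self):
            self.start_ts = int(time.time() * 1e6)    # Chrome traces use microseconds

        def stop_trace(self):
            self.stop_ts = int(time.time() * 1e6)

        def extend_chrome_trace(self, trace: dict) -> dict:
            # 'ph': 'i' marks a Chrome trace instant event
            trace['traceEvents'].extend([
                {'ph': 'i', 's': 't', 'name': 'traceStart', 'pid': 0, 'tid': 0, 'ts': self.start_ts, 'args': {}},
                {'ph': 'i', 's': 't', 'name': 'traceStop', 'pid': 0, 'tid': 0, 'ts': self.stop_ts, 'args': {}},
            ])
            return trace
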
diff --git a/colossalai/utils/profiler/profiler.py b/colossalai/utils/profiler/profiler.py
new file mode 100644
index 000000000..8f43a0b96
--- /dev/null
+++ b/colossalai/utils/profiler/profiler.py
@@ -0,0 +1,199 @@
+import gzip
+import json
+import os
+import tempfile
+from typing import Any, Callable, Iterable, List, Optional
+from colossalai.engine import Engine
+from torch.profiler import profile as torch_profile
+from torch.profiler.profiler import ProfilerAction
+from torch.autograd import ProfilerActivity
+from colossalai.utils.profiler.extention import ProfilerExtension
+from colossalai.utils.profiler.stateful_tensor_mem_extention import StatefulTensorMemoryProfilerExtention
+from colossalai.logging import get_dist_logger
+
+
+class profile(torch_profile):
+    """Profiler context manager.
+
+    Args:
+        activities (iterable): list of activity groups (CPU, CUDA) to use in profiling, supported values:
+            ``torch.profiler.ProfilerActivity.CPU``, ``torch.profiler.ProfilerActivity.CUDA``.
+            Default value: ProfilerActivity.CPU and (when available) ProfilerActivity.CUDA.
+        schedule (callable): callable that takes step (int) as a single parameter and returns
+            ``ProfilerAction`` value that specifies the profiler action to perform at each step.
+        on_trace_ready (callable): callable that is called at each step when ``schedule``
+            returns ``ProfilerAction.RECORD_AND_SAVE`` during the profiling.
+        engine (Optional[Engine], optional): An ``Engine`` instance. Defaults to None.
+        record_shapes (bool): save information about operator's input shapes.
+        profile_memory (bool): track tensor memory allocation/deallocation.
+        with_stack (bool): record source information (file and line number) for the ops.
+        with_flops (bool): use formula to estimate the FLOPs (floating point operations) of specific operators
+            (matrix multiplication and 2D convolution).
+        with_modules (bool): record module hierarchy (including function names)
+            corresponding to the callstack of the op. E.g. if module A's forward calls
+            module B's forward, which contains an aten::add op,
+            then aten::add's module hierarchy is A.B.
+            Note that this support exists, at the moment, only for TorchScript models
+            and not eager mode models.
+        profile_stateful_tensor_memory (bool): track stateful tensor memory usage. ``engine`` must not be None if you enable this.
+
+    .. note::
+        Use :func:`~torch.profiler.schedule` to generate the callable schedule.
+        Non-default schedules are useful when profiling long training jobs
+        and allow the user to obtain multiple traces at different iterations
+        of the training process.
+        The default schedule simply records all the events continuously for the
+        duration of the context manager.
+
+    .. note::
+        Use :func:`~torch.profiler.tensorboard_trace_handler` to generate result files for TensorBoard:
+
+        ``on_trace_ready=torch.profiler.tensorboard_trace_handler(dir_name)``
+
+        After profiling, result files can be found in the specified directory. Use the command:
+
+        ``tensorboard --logdir dir_name``
+
+        to see the results in TensorBoard.
+        For more information, see the
+        `PyTorch Profiler TensorBoard Plugin <https://github.com/pytorch/kineto/tree/master/tb_plugin>`__.
+
+    .. note::
+        Enabling shape and stack tracing results in additional overhead.
+        When record_shapes=True is specified, the profiler will temporarily hold references to the tensors;
+        that may further prevent certain optimizations that depend on the reference count and introduce
+        extra tensor copies.
+
+    Examples:
+
+    .. code-block:: python
+
+        with torch.profiler.profile(
+            activities=[
+                torch.profiler.ProfilerActivity.CPU,
+                torch.profiler.ProfilerActivity.CUDA,
+            ]
+        ) as p:
+            code_to_profile()
+        print(p.key_averages().table(
+            sort_by="self_cuda_time_total", row_limit=-1))
+
+    Using the profiler's ``schedule``, ``on_trace_ready`` and ``step`` functions:
+
+    .. code-block:: python
+
+        # A non-default profiler schedule allows the user to turn the profiler on and off
+        # on different iterations of the training loop;
+        # trace_handler is called every time a new trace becomes available
+        def trace_handler(prof):
+            print(prof.key_averages().table(
+                sort_by="self_cuda_time_total", row_limit=-1))
+            # prof.export_chrome_trace("/tmp/test_trace_" + str(prof.step_num) + ".json")
+
+        with torch.profiler.profile(
+            activities=[
+                torch.profiler.ProfilerActivity.CPU,
+                torch.profiler.ProfilerActivity.CUDA,
+            ],
+
+            # In this example with wait=1, warmup=1, active=2,
+            # the profiler will skip the first step/iteration,
+            # start warming up on the second, record
+            # the third and the fourth iterations,
+            # after which the trace will become available
+            # and on_trace_ready (when set) is called;
+            # the cycle repeats starting with the next step
+
+            schedule=torch.profiler.schedule(
+                wait=1,
+                warmup=1,
+                active=2),
+            on_trace_ready=trace_handler
+            # on_trace_ready=torch.profiler.tensorboard_trace_handler('./log')
+            # used when outputting for tensorboard
+        ) as p:
+            for iter in range(N):
+                code_iteration_to_profile(iter)
+                # send a signal to the profiler that the next iteration has started
+                p.step()
+    """
+
+    def __init__(self,
+                 *,
+                 activities: Optional[Iterable[ProfilerActivity]] = None,
+                 schedule: Optional[Callable[[int], ProfilerAction]] = None,
+                 on_trace_ready: Optional[Callable[..., Any]] = None,
+                 engine: Optional[Engine] = None,
+                 record_shapes: bool = False,
+                 profile_memory: bool = False,
+                 with_stack: bool = False,
+                 with_flops: bool = False,
+                 with_modules: bool = False,
+                 profile_stateful_tensor_memory: bool = False) -> None:
+        super().__init__(activities=activities,
+                         schedule=schedule,
+                         on_trace_ready=on_trace_ready,
+                         record_shapes=record_shapes,
+                         profile_memory=profile_memory,
+                         with_stack=with_stack,
+                         with_flops=with_flops,
+                         with_modules=with_modules)
+        self._logger = get_dist_logger()
+        self.extentions: List[ProfilerExtension] = []
+        if profile_stateful_tensor_memory:
+            if engine is None:
+                self._logger.warning('Ignore "profile_stateful_tensor_memory" since engine is None', ranks=[0])
+            else:
+                self.extentions.append(StatefulTensorMemoryProfilerExtention(engine))
+
+    def prepare_trace(self) -> None:
+        if hasattr(super(), 'prepare_trace'):
+            super().prepare_trace()
+        elif hasattr(super(), '_start_warmup'):
+            super()._start_warmup()
+        for ext in self.extentions:
+            ext.prepare_trace()
+
+    def _start_warmup(self):
+        self.prepare_trace()
+
+    def start_trace(self):
+        if hasattr(super(), '_start_trace'):
+            super()._start_trace()
+        elif hasattr(super(), 'start_trace'):
+            super().start_trace()
+        for ext in self.extentions:
+            ext.start_trace()
+
+    def _start_trace(self):
+        self.start_trace()
+
+    def stop_trace(self):
+        if hasattr(super(), '_stop_trace'):
+            super()._stop_trace()
+        elif hasattr(super(), 'stop_trace'):
+            super().stop_trace()
+        for ext in self.extentions:
+            ext.stop_trace()
+
+    def _stop_trace(self):
+        self.stop_trace()
+
+    def export_chrome_trace(self, path: str):
+        """
+        Exports the collected trace in Chrome JSON format.
+        """
+        assert self.profiler
+        fp = tempfile.NamedTemporaryFile('w+t', suffix='.json', delete=False)
+        fp.close()
+        retvalue = self.profiler.export_chrome_trace(fp.name)
+        with open(fp.name) as fin:
+            trace = json.load(fin)
+            for ext in self.extentions:
+                trace = ext.extend_chrome_trace(trace)
+            open_func = gzip.open if path.endswith('.gz') else open
+            with open_func(path, 'wt') as fout:
+                json.dump(trace, fout)
+
+        os.remove(fp.name)
+        return retvalue
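The wrapper above is a drop-in replacement for ``torch.profiler.profile`` that additionally drives its extensions and merges their events into the exported trace. A minimal usage sketch (not part of this patch; ``engine`` and ``run_step`` are placeholders for a real colossalai ``Engine`` and one training iteration):

    import torch

    from colossalai.utils.profiler import profile

    # wait 1 step, warm up for 1, record 2
    with profile(activities=[torch.profiler.ProfilerActivity.CPU],
                 schedule=torch.profiler.schedule(wait=1, warmup=1, active=2),
                 engine=engine,
                 profile_stateful_tensor_memory=True) as prof:
        for _ in range(4):
            run_step()
            prof.step()    # advance the profiler schedule each iteration

    # the '.gz' suffix makes export_chrome_trace gzip-compress the merged trace
    prof.export_chrome_trace('trace.json.gz')
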
diff --git a/colossalai/utils/profiler/stateful_tensor_mem_extention.py b/colossalai/utils/profiler/stateful_tensor_mem_extention.py
new file mode 100644
index 000000000..749823553
--- /dev/null
+++ b/colossalai/utils/profiler/stateful_tensor_mem_extention.py
@@ -0,0 +1,133 @@
+import os
+import threading
+import time
+import torch
+from enum import Enum
+from typing import List
+from colossalai.gemini.stateful_tensor import StatefulTensor
+from colossalai.engine.ophooks import BaseOpHook
+from colossalai.engine import Engine
+from colossalai.utils.profiler.extention import ProfilerExtension
+
+
+class DeviceType(Enum):
+    CPU = 0
+    CUDA = 1
+
+
+def get_timestamp_us():
+    return int(time.time() * 1e6)
+
+
+def generic_instant_event(name, pid, tid, timestamp, args):
+    return {'ph': 'i', 's': 't', 'name': name, 'pid': pid, 'tid': tid, 'ts': timestamp, 'args': args}
+
+
+class StatefulTensorMemoryEvent:
+    EVENT_NAME = '[statefulTensorMemory]'
+
+    def __init__(self, timestamp: int, device_type: DeviceType, bytes_: int) -> None:
+        self.pid = os.getpid()
+        self.tid = threading.get_ident()
+        self.timestamp = timestamp
+        self.device_type = device_type
+        self.device_id = torch.cuda.current_device() if device_type == DeviceType.CUDA else -1
+        self.bytes = bytes_
+
+    def state_dict(self):
+        return generic_instant_event(StatefulTensorMemoryEvent.EVENT_NAME, self.pid, self.tid, self.timestamp, {
+            'Device Type': self.device_type.value,
+            'Device Id': self.device_id,
+            'Bytes': self.bytes
+        })
+
+
+class StatefulTensorMemoryTracer:
+
+    def __init__(self) -> None:
+        self.events: List[StatefulTensorMemoryEvent] = []
+        self._tracing = False
+
+    def sample(self):
+        cuda_mem = StatefulTensor.GST_MGR.total_mem['cuda']
+        cpu_mem = StatefulTensor.GST_MGR.total_mem['cpu']
+        timestamp = get_timestamp_us()
+        if self._tracing:
+            self.events.append(StatefulTensorMemoryEvent(timestamp, DeviceType.CUDA, cuda_mem))
+            self.events.append(StatefulTensorMemoryEvent(timestamp, DeviceType.CPU, cpu_mem))
+
+    def start_trace(self):
+        self.events.clear()
+        self._tracing = True
+
+    def stop_trace(self):
+        self._tracing = False
+
+    def state_dict(self):
+        return [event.state_dict() for event in self.events]
+
+
+class StatefulTensorMemoryTracerHook(BaseOpHook):
+
+    def __init__(self, tracer: StatefulTensorMemoryTracer):
+        super().__init__()
+        self.tracer = tracer
+        self._enable = False
+
+    def pre_fwd_exec(self, module: torch.nn.Module, *args):
+        if self._enable:
+            self.tracer.sample()
+
+    def post_fwd_exec(self, module: torch.nn.Module, *args):
+        if self._enable:
+            self.tracer.sample()
+
+    def pre_bwd_exec(self, module: torch.nn.Module, input_, output):
+        if self._enable:
+            self.tracer.sample()
+
+    def post_bwd_exec(self, module: torch.nn.Module, input_):
+        if self._enable:
+            self.tracer.sample()
+
+    def post_iter(self):
+        if self._enable:
+            self.tracer.sample()
+
+    def enable(self):
+        self._enable = True
+
+    def disable(self):
+        self._enable = False
+
+
+class StatefulTensorMemoryProfilerExtention(ProfilerExtension):
+
+    def __init__(self, engine: Engine) -> None:
+        self.engine = engine
+        self.tracer = StatefulTensorMemoryTracer()
+        self.hook = StatefulTensorMemoryTracerHook(self.tracer)
+        self.hook_registered = False
+
+    def prepare_trace(self):
+        self.hook.enable()
+        if not self.hook_registered:
+            self.engine.add_hook(self.hook)
+            self.hook_registered = True
+
+    def start_trace(self):
+        self.prepare_trace()
+        self.tracer.start_trace()
+
+    def stop_trace(self):
+        self.tracer.stop_trace()
+        self.hook.disable()
+        if self.hook_registered:
+            self.engine.remove_hook(self.hook)
+            # remove_hook is not implemented now
+            # FIXME(ver217): uncomment below line when remove_hook is implemented
+            # self.hook_registered = False
+
+    def extend_chrome_trace(self, trace: dict) -> dict:
+        trace['traceEvents'].extend(self.tracer.state_dict())
+        return trace
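For completeness, a sketch of how the injected samples can be inspected after export (assumes the ``trace.json.gz`` produced in the usage sketch above; the ``Device Type`` values follow the ``DeviceType`` enum in this file, CPU = 0, CUDA = 1):

    import gzip
    import json

    # load the gzip-compressed Chrome trace written by export_chrome_trace
    with gzip.open('trace.json.gz', 'rt') as f:
        trace = json.load(f)

    # filter the instant events injected by StatefulTensorMemoryProfilerExtention
    for event in trace['traceEvents']:
        if event.get('name') == '[statefulTensorMemory]':
            device = 'cuda' if event['args']['Device Type'] == 1 else 'cpu'
            print(event['ts'], device, event['args']['Bytes'])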