from typing import Callable, Any, Dict, Tuple
import torch
from torch.fx.node import Argument, Target
from . import meta_profiler_function, meta_profiler_module
from ..memory import activation_size, INPLACE_METHOD, NON_INPLACE_METHOD, INPLACE_OPS

__all__ = ['profile_function', 'profile_module', 'profile_method']

# Shown when `profile_function` is given a target with no registered profiler.
CALL_FUNCTION_MSG = \
"""
Colossal-AI hasn't supported profiling for {}, you might manually patch it with the following code.\n
from colossalai.fx.profiler.experimental import meta_profiler_function
@meta_profiler_function.register(YOUR_FUNCTION)
def profile_YOUR_FUNCTION(input: torch.Tensor, *args) -> Tuple[int, int]:
    flops = ...
    macs = ...
    return flops, macs
"""
# Shown when `profile_method` is given a method name found in neither list.
CALL_METHOD_MSG = 'Please check if {} is an inplace method. If so, add target to INPLACE_METHOD={}. Otherwise, add target to NON_INPLACE_METHOD={}'
# Shown when `profile_module` is given a module type with no registered profiler.
CALL_MODULE_MSG = \
"""
Colossal-AI hasn't supported profiling for {}, you might manually patch it with the following code.\n
from colossalai.fx.profiler.experimental import meta_profiler_module
@meta_profiler_module.register(YOUR_MODULE)
def profile_YOUR_MODULE(self: torch.nn.Module, input: torch.Tensor) -> Tuple[int, int]:
    flops = ...
    macs = ...
    return flops, macs
"""


def profile_function(target: 'Target') -> Callable:
    """
    Wrap a `call_function` node or `torch.nn.functional` in order to
    record the memory cost and FLOPs of the execution.
    Unfortunately, backward memory cost and FLOPs are estimated results.

    Args:
        target: the callable to wrap. A profiler must be registered in
            `meta_profiler_function` under either the callable itself or its
            `__name__`.

    Returns:
        A wrapper with the same call signature as `target` that returns
        `(out, (fwd_flop, bwd_flop), (fwd_tmp, fwd_out, bwd_tmp, bwd_out))`,
        where `bwd_flop` is estimated as `2 * fwd_flop`, `bwd_tmp` as
        `fwd_tmp + fwd_out`, and `bwd_out` is always 0.

    Raises:
        AssertionError: when the wrapper is called and no profiler is
            registered for `target`.

    Warnings:
        You may only use tensors with `device=meta` for this wrapped function.
        Only original `torch.nn.functional` are available.

    Examples:
        >>> input = torch.rand(100, 100, 100, 100, device='meta')
        >>> func = torch.nn.functional.relu
        >>> output, (fwd_flop, bwd_flop), (fwd_tmp, fwd_out, bwd_tmp, bwd_out) = profile_function(func)(input, inplace=False)
    """
    # Bind the wrapped callable up front so the closure does not rely on a
    # name that is assigned after the inner function is defined.
    func = target

    def f(*args: Tuple[Argument, ...], **kwargs: Dict[str, Any]) -> Any:
        assert meta_profiler_function.has(target) or meta_profiler_function.has(
            target.__name__), CALL_FUNCTION_MSG.format(target)

        fwd_tmp = 0
        fwd_out = 0
        out = func(*args, **kwargs)
        # An in-place op reuses its input buffer, so only out-of-place calls
        # contribute a new output activation.
        if target not in INPLACE_OPS and not kwargs.get('inplace', False):
            fwd_out = activation_size(out)

        # Prefer a profiler registered under the callable itself; fall back
        # to one registered under its name.
        if meta_profiler_function.has(target):
            profiler = meta_profiler_function.get(target)
        else:
            profiler = meta_profiler_function.get(target.__name__)
        fwd_flop, _ = profiler(*args, **kwargs)
        return out, (fwd_flop, fwd_flop * 2), (fwd_tmp, fwd_out, fwd_tmp + fwd_out, 0)

    f.__name__ = target.__name__
    return f


def profile_method(target: 'Target') -> Callable:
    """
    Wrap a `call_method` node in order to record the memory cost and FLOPs
    of the execution.

    Args:
        target: the method name (a `str`) to look up on the first positional
            argument of the wrapper.

    Returns:
        A wrapper called as `f(self_obj, *args, **kwargs)` that returns
        `(out, (0, 0), (fwd_tmp, fwd_out, fwd_tmp + fwd_out, 0))`. FLOPs are
        always reported as zero for method calls.

    Raises:
        AssertionError: when the wrapper is called and `target` is not a
            `str`, or is listed in neither `INPLACE_METHOD` nor
            `NON_INPLACE_METHOD`.

    Warnings:
        This is not fully implemented and you may follow the error message to debug.
    """

    def f(*args: Tuple[Argument, ...], **kwargs: Dict[str, Any]) -> Any:
        # args[0] is the `self` object for this method call
        self_obj, *args_tail = args

        # Validate the target before running it so that an unsupported method
        # is reported with the actionable message rather than after (or
        # instead of) executing it.
        assert isinstance(target, str), f'{target} instance is not str.'
        is_inplace = target in INPLACE_METHOD
        assert is_inplace or target in NON_INPLACE_METHOD, CALL_METHOD_MSG.format(
            target, INPLACE_METHOD, NON_INPLACE_METHOD)

        # execute the method and return the result
        out = getattr(self_obj, target)(*args_tail, **kwargs)

        # call_method has no parameters and are MOSTLY(?) inplace, and has no FLOPs or MACs.
        # Methods listed in INPLACE_METHOD book the result as `fwd_out`; all
        # others book it as `fwd_tmp`.
        out_size = activation_size(out)
        if is_inplace:
            fwd_tmp, fwd_out = 0, out_size
        else:
            fwd_tmp, fwd_out = out_size, 0
        return out, (0, 0), (fwd_tmp, fwd_out, fwd_tmp + fwd_out, 0)

    return f


def profile_module(module: torch.nn.Module) -> Callable:
    """
    Wrap a `call_module` node or `torch.nn` in order to
    record the memory cost and FLOPs of the execution.

    Args:
        module: the module instance to wrap. A profiler must be registered in
            `meta_profiler_module` for `type(module)`.

    Returns:
        A wrapper with the call signature of `module.forward` that returns
        `(out, (fwd_flop, bwd_flop), (fwd_tmp, fwd_out, bwd_tmp, bwd_out))`,
        where `bwd_flop` is estimated as `2 * fwd_flop`, `bwd_tmp` as
        `fwd_tmp + fwd_out`, and `bwd_out` is always 0.

    Raises:
        AssertionError: when the wrapper is called and no profiler is
            registered for `type(module)`.

    Warnings:
        You may only use tensors with `device=meta` for this wrapped function.
        Only original `torch.nn` are available.

    Example:
        >>> input = torch.rand(4, 3, 224, 224, device='meta')
        >>> mod = torch.nn.Conv2d(3, 128, 3)
        >>> output, (fwd_flop, bwd_flop), (fwd_tmp, fwd_out, bwd_tmp, bwd_out) = profile_module(mod)(input)
    """
    # Bind the forward method up front so the closure does not rely on a name
    # that is assigned after the inner function is defined.
    func = module.forward

    def f(*args: Tuple[Argument, ...], **kwargs: Dict[str, Any]) -> Any:
        assert meta_profiler_module.has(type(module)), CALL_MODULE_MSG.format(type(module))

        fwd_tmp = 0
        fwd_out = 0
        out = func(*args, **kwargs)
        # Only an out-of-place module allocates a new output activation. A
        # module flagged `inplace=True` overwrites its input, so it adds none.
        # This mirrors the rule used by `profile_function`.
        if not getattr(module, 'inplace', False):
            fwd_out = activation_size(out)
        profiler = meta_profiler_module.get(type(module))
        fwd_flop, _ = profiler(module, *args, **kwargs)
        return out, (fwd_flop, fwd_flop * 2), (fwd_tmp, fwd_out, fwd_tmp + fwd_out, 0)

    f.__name__ = module.__class__.__name__
    return f