mirror of https://github.com/hpcaitech/ColossalAI
[analyzer] a minimal implementation of static graph analyzer (#2852)
* [hotfix] meta tensor default device. * [siu] add experimental submodules to main branch. * [siu] * [siu] * [analyzer] init. * [analyzer] readme. * [analyzer] readme. * [analyzer] readme. * [analyzer] readme. * [test] add test. * Update symbolic_trace.py * mark skip tests. * try except. * try except. * try except. * s * init * init * fix * skip * skip --------- Co-authored-by: Daniel Shao <superdainiu@MININT-PVARVID.fareast.corp.microsoft.com> Co-authored-by: Daniel Shao <superdainiu@Daniels-Mac.local>pull/3094/head
parent
5d5f475d75
commit
fff98f06ed
|
@ -0,0 +1,306 @@
|
|||
# Analyzer
|
||||
|
||||
# Overview
|
||||
The Analyzer is a collection of static graph utils including Colossal-AI FX. Features include:
|
||||
- MetaTensor -- enabling:
|
||||
- Ahead-of-time Profiling
|
||||
- Shape Propagation
|
||||
- Ideal Flop Counter
|
||||
- symbolic_trace()
|
||||
- Robust Control-flow Tracing / Recompile
|
||||
- Robust Activation Checkpoint Tracing / CodeGen
|
||||
- Easy-to-define Bias-Addition Split
|
||||
- symbolic_profile()
|
||||
- Support ``MetaTensorMode``, where all Tensor operations are executed symbolically.
|
||||
- Shape Inference Across Device and Unified ``MetaInfo``
|
||||
- Ideal Flop Counter https://dev-discuss.pytorch.org/t/the-ideal-pytorch-flop-counter-with-torch-dispatch/505
|
||||
|
||||
# Quickstart
|
||||
## Analyzer.FX
|
||||
**Reference:**
|
||||
|
||||
https://pytorch.org/docs/stable/fx.html [[paper](https://arxiv.org/pdf/2112.08429)]
|
||||
|
||||
|
||||
torch.FX is a toolkit for developers to use to transform nn.Module instances. FX consists of three main components: a symbolic tracer, an intermediate representation, and Python code generation. FX.Tracer hacks _\_\_torch_function\_\__ and use a Proxy object to propagate through any forward function of torch.nn.Module.
|
||||

|
||||
ColossalAI FX is modified from torch.FX, with the extra capability of ahead-of-time profiling enabled by the subclass of ``MetaTensor``.
|
||||
|
||||
### Analyzer.FX.symbolic_trace()
|
||||
A drawback of the original torch.FX implementation is that it is poor at handling control flow. All control flow is not PyTorch native operands and requires actual instances that specify the branches to execute on. For example,
|
||||
|
||||
```python
|
||||
class MyModule(nn.Module):
|
||||
def forward(self, x):
|
||||
if x.dim() == 3:
|
||||
return x * 2 + 1
|
||||
else:
|
||||
return x - 5
|
||||
```
|
||||
|
||||
The above function has the computation graph of
|
||||
|
||||

|
||||
|
||||
However, since Proxy does not have concrete data, applying ``x.dim()`` will return nothing. In the context of the auto-parallel system, at least the control-flow dependencies for tensor shape should be removed, since any searched strategy could only auto-parallelize a specific computation graph with the same tensor shape. It is native to attach concrete data onto a Proxy, and propagate them through control flow.
|
||||
|
||||

|
||||
|
||||
|
||||
With ``MetaTensor``, the computation during shape propagation can be virtualized. This speeds up tracing by avoiding allocating actual memory on devices.
|
||||
|
||||
#### Remarks
|
||||
There is no free lunch for PyTorch to unify all operands in both its repo and other repos in its eco-system. For example, the einops library currently has no intention to support torch.FX (See https://github.com/arogozhnikov/einops/issues/188). To support different PyTorch-based libraries without modifying source code, good practices can be to allow users to register their implementation to substitute the functions not supported by torch.FX, or to avoid entering incompatible submodules.
|
||||
|
||||
### Analyzer.FX.symbolic_profile()
|
||||
|
||||
``symbolic_profile`` is another important feature of Colossal-AI's auto-parallel system. Profiling DNN can be costly, as you need to allocate memory and execute on real devices. However, since the profiling requirements for auto-parallel is enough if we can detect when and where the intermediate activations (i.e. Tensor) are generated, we can profile the whole procedure without actually executing it. ``symbolic_profile``, as its name infers, profiles the whole network with symbolic information only.
|
||||
|
||||
```python
|
||||
with MetaTensorMode():
|
||||
model = MyModule().cuda()
|
||||
sample = torch.rand(100, 3, 224, 224).cuda()
|
||||
meta_args = dict(
|
||||
x = sample,
|
||||
)
|
||||
gm = symbolic_trace(model, meta_args=meta_args)
|
||||
gm = symbolic_profile(gm, sample)
|
||||
```
|
||||
|
||||
``symbolic_profile`` is enabled by ``ShapeProp`` and ``GraphProfile``.
|
||||
|
||||
#### ShapeProp
|
||||
Both Tensor Parallel and Activation Checkpoint solvers need to know the shape information ahead of time. Unlike PyTorch's implementation, this ``ShapeProp`` can be executed under MetaTensorMode. With this, all the preparation for auto-parallel solvers can be done in milliseconds.
|
||||
|
||||
Meanwhile, it is easy to keep track of the memory usage of each node when doing shape propagation. However, the drawbacks of FX is that not every ``call_function`` saves its input for backward, and different tensor that flows within one FX.Graph can actually have the same layout. This raises problems for fine-grained profiling.
|
||||
|
||||

|
||||
|
||||
To address this problem, I came up with a simulated environment enabled by ``torch.autograd.graph.saved_tensor_hooks`` and fake ``data_ptr`` (check ``_subclasses/meta_tensor.py`` for more details of ``data_ptr`` updates).
|
||||
|
||||
```python
|
||||
class sim_env(saved_tensors_hooks):
|
||||
"""
|
||||
A simulation of memory allocation and deallocation in the forward pass
|
||||
using ``saved_tensor_hooks``.
|
||||
|
||||
Attributes:
|
||||
ctx (Dict[int, torch.Tensor]): A dictionary that maps the
|
||||
data pointer of a tensor to the tensor itself. This is used
|
||||
to track the memory allocation and deallocation.
|
||||
|
||||
param_ctx (Dict[int, torch.Tensor]): A dictionary that maps the
|
||||
data pointer of all model parameters to the parameter itself.
|
||||
This avoids overestimating the memory usage of the intermediate activations.
|
||||
"""
|
||||
|
||||
def __init__(self, module: Optional[torch.nn.Module] = None):
|
||||
super().__init__(self.pack_hook, self.unpack_hook)
|
||||
self.ctx = {}
|
||||
self.param_ctx = {param.data_ptr(): param for param in module.parameters()}
|
||||
self.buffer_ctx = {buffer.data_ptr(): buffer for buffer in module.buffers()} if module else {}
|
||||
|
||||
def pack_hook(self, tensor: torch.Tensor):
|
||||
if tensor.data_ptr() not in self.param_ctx and tensor.data_ptr() not in self.buffer_ctx:
|
||||
self.ctx[tensor.data_ptr()] = tensor
|
||||
return tensor
|
||||
|
||||
def unpack_hook(self, tensor):
|
||||
return tensor
|
||||
```
|
||||
The ``ctx`` variable will keep track of all saved tensors with a unique identifier. It is likely that ``nn.Parameter`` is also counted in the ``ctx``, which is not desired. To avoid this, we can use ``param_ctx`` to keep track of all parameters in the model. The ``buffer_ctx`` is used to keep track of all buffers in the model. The ``local_ctx`` that is attached to each ``Node`` marks the memory usage of the stage to which the node belongs. With simple ``intersect``, ``union`` and ``subtract`` operations, we can get any memory-related information. For non-profileable nodes, you might add your customized profile rules to simulate the memory allocation. If a ``Graph`` is modified with some non-PyTorch functions, such as fused operands, you can register the shape propagation rule with the decorator.
|
||||
|
||||
```python
|
||||
@register_shape_impl(fuse_conv_bn)
|
||||
def fuse_conv_bn_shape_impl(*args, **kwargs):
|
||||
# infer output shape here
|
||||
return torch.empty(output_shape, device=output_device)
|
||||
```
|
||||
|
||||
An important notice is that ``ShapeProp`` will attach additional information to the graph, which will be exactly the input of ``Profiler``.
|
||||
|
||||
#### GraphProfiler
|
||||
``GraphProfiler`` executes at the node level, and profiles both forward and backward within one node. For example, ``FlopProfiler`` will profile the forward and backward FLOPs of a node, and ``CommunicationProfiler`` will profile the forward and backward communication cost of a node. The ``GraphProfiler`` will attach the profiling results to the ``Node``. These procedures are decoupled for better extensibility.
|
||||
|
||||
To provide a general insight of the profiled results, you can set ``verbose=True`` to print the summary as well.
|
||||
```python
|
||||
model = tm.resnet18()
|
||||
sample = torch.rand(100, 3, 224, 224)
|
||||
meta_args = dict(x=sample)
|
||||
gm = symbolic_trace(model, meta_args=meta_args)
|
||||
gm = symbolic_profile(gm, sample, verbose=True)
|
||||
|
||||
============================================================ Results =====================================================================
|
||||
Op type Op Accumulate size Incremental size Output size Temp size Param size Backward size Fwd FLOPs Bwd FLOPs
|
||||
------------- ---------------------------------------------- ----------------- ------------------ ------------- ----------- ------------ --------------- ------------- -------------
|
||||
placeholder x 4.59 Mb 0 b 4.59 Mb 0 b 0 b 0 b 0 FLOPs 0 FLOPs
|
||||
call_module conv_proj 4.59 Mb 0 b 0 b 4.59 Mb 2.25 Mb 4.59 Mb 924.84 MFLOPs 924.84 MFLOPs
|
||||
call_method reshape 4.59 Mb 0 b 0 b 4.59 Mb 0 b 4.59 Mb 0 FLOPs 0 FLOPs
|
||||
call_method permute 4.59 Mb 0 b 0 b 4.59 Mb 0 b 4.59 Mb 0 FLOPs 0 FLOPs
|
||||
get_attr class_token 4.59 Mb 0 b 0 b 0 b 0 b 0 b 0 FLOPs 0 FLOPs
|
||||
call_method expand 4.59 Mb 0 b 0 b 24.00 Kb 3.00 Kb 0 b 0 FLOPs 6.14 kFLOPs
|
||||
call_function cat 4.59 Mb 0 b 0 b 4.62 Mb 0 b 0 b 0 FLOPs 0 FLOPs
|
||||
get_attr encoder_pos_embedding 4.59 Mb 0 b 0 b 0 b 0 b 0 b 0 FLOPs 0 FLOPs
|
||||
call_function add 9.21 Mb 4.62 Mb 4.62 Mb 0 b 591.00 Kb 4.62 Mb 1.21 MFLOPs 1.21 MFLOPs
|
||||
call_module encoder_dropout 9.21 Mb 0 b 4.62 Mb 0 b 0 b 4.62 Mb 0 FLOPs 0 FLOPs
|
||||
call_module encoder_layers_encoder_layer_0_ln_1 9.22 Mb 12.31 Kb 0 b 4.62 Mb 6.00 Kb 4.62 Mb 6.05 MFLOPs 6.05 MFLOPs
|
||||
call_module encoder_layers_encoder_layer_0_self_attention 46.52 Mb 37.30 Mb 0 b 4.62 Mb 9.01 Mb 13.85 Mb 4.20 GFLOPs 8.40 GFLOPs
|
||||
call_function getitem 46.52 Mb 0 b 0 b 4.62 Mb 0 b 0 b 0 FLOPs 0 FLOPs
|
||||
call_function getitem_1 46.52 Mb 0 b 0 b 0 b 0 b 0 b 0 FLOPs 0 FLOPs
|
||||
call_module encoder_layers_encoder_layer_0_dropout 46.52 Mb 0 b 0 b 4.62 Mb 0 b 4.62 Mb 0 FLOPs 0 FLOPs
|
||||
call_function add_1 51.14 Mb 4.62 Mb 4.62 Mb 0 b 0 b 9.23 Mb 1.21 MFLOPs 0 FLOPs
|
||||
call_module encoder_layers_encoder_layer_0_ln_2 51.15 Mb 12.31 Kb 0 b 4.62 Mb 6.00 Kb 4.62 Mb 6.05 MFLOPs 6.05 MFLOPs
|
||||
call_module encoder_layers_encoder_layer_0_mlp_0 74.24 Mb 23.09 Mb 18.47 Mb 0 b 9.01 Mb 4.62 Mb 3.72 GFLOPs 7.44 GFLOPs
|
||||
call_module encoder_layers_encoder_layer_0_mlp_1 92.71 Mb 18.47 Mb 18.47 Mb 0 b 0 b 18.47 Mb 4.84 MFLOPs 4.84 MFLOPs
|
||||
call_module encoder_layers_encoder_layer_0_mlp_2 92.71 Mb 0 b 18.47 Mb 0 b 0 b 18.47 Mb 0 FLOPs 0 FLOPs
|
||||
call_module encoder_layers_encoder_layer_0_mlp_3 92.71 Mb 0 b 0 b 4.62 Mb 9.00 Mb 18.47 Mb 3.72 GFLOPs 7.44 GFLOPs
|
||||
call_module encoder_layers_encoder_layer_0_mlp_4 92.71 Mb 0 b 0 b 4.62 Mb 0 b 4.62 Mb 0 FLOPs 0 FLOPs
|
||||
call_function add_2 97.32 Mb 4.62 Mb 4.62 Mb 0 b 0 b 9.23 Mb 1.21 MFLOPs 0 FLOPs
|
||||
call_module encoder_layers_encoder_layer_1_ln_1 101.95 Mb 4.63 Mb 4.62 Mb 0 b 6.00 Kb 4.62 Mb 6.05 MFLOPs 6.05 MFLOPs
|
||||
call_module encoder_layers_encoder_layer_1_self_attention 134.63 Mb 32.68 Mb 0 b 4.62 Mb 9.01 Mb 13.85 Mb 4.20 GFLOPs 8.40 GFLOPs
|
||||
call_function getitem_2 134.63 Mb 0 b 0 b 4.62 Mb 0 b 0 b 0 FLOPs 0 FLOPs
|
||||
call_function getitem_3 134.63 Mb 0 b 0 b 0 b 0 b 0 b 0 FLOPs 0 FLOPs
|
||||
call_module encoder_layers_encoder_layer_1_dropout 134.63 Mb 0 b 0 b 4.62 Mb 0 b 4.62 Mb 0 FLOPs 0 FLOPs
|
||||
call_function add_3 139.25 Mb 4.62 Mb 4.62 Mb 0 b 0 b 9.23 Mb 1.21 MFLOPs 0 FLOPs
|
||||
call_module encoder_layers_encoder_layer_1_ln_2 139.26 Mb 12.31 Kb 0 b 4.62 Mb 6.00 Kb 4.62 Mb 6.05 MFLOPs 6.05 MFLOPs
|
||||
call_module encoder_layers_encoder_layer_1_mlp_0 162.35 Mb 23.09 Mb 18.47 Mb 0 b 9.01 Mb 4.62 Mb 3.72 GFLOPs 7.44 GFLOPs
|
||||
call_module encoder_layers_encoder_layer_1_mlp_1 180.82 Mb 18.47 Mb 18.47 Mb 0 b 0 b 18.47 Mb 4.84 MFLOPs 4.84 MFLOPs
|
||||
call_module encoder_layers_encoder_layer_1_mlp_2 180.82 Mb 0 b 18.47 Mb 0 b 0 b 18.47 Mb 0 FLOPs 0 FLOPs
|
||||
call_module encoder_layers_encoder_layer_1_mlp_3 180.82 Mb 0 b 0 b 4.62 Mb 9.00 Mb 18.47 Mb 3.72 GFLOPs 7.44 GFLOPs
|
||||
call_module encoder_layers_encoder_layer_1_mlp_4 180.82 Mb 0 b 0 b 4.62 Mb 0 b 4.62 Mb 0 FLOPs 0 FLOPs
|
||||
call_function add_4 185.43 Mb 4.62 Mb 4.62 Mb 0 b 0 b 9.23 Mb 1.21 MFLOPs 0 FLOPs
|
||||
call_module encoder_layers_encoder_layer_2_ln_1 190.06 Mb 4.63 Mb 4.62 Mb 0 b 6.00 Kb 4.62 Mb 6.05 MFLOPs 6.05 MFLOPs
|
||||
call_module encoder_layers_encoder_layer_2_self_attention 222.74 Mb 32.68 Mb 0 b 4.62 Mb 9.01 Mb 13.85 Mb 4.20 GFLOPs 8.40 GFLOPs
|
||||
call_function getitem_4 222.74 Mb 0 b 0 b 4.62 Mb 0 b 0 b 0 FLOPs 0 FLOPs
|
||||
call_function getitem_5 222.74 Mb 0 b 0 b 0 b 0 b 0 b 0 FLOPs 0 FLOPs
|
||||
call_module encoder_layers_encoder_layer_2_dropout 222.74 Mb 0 b 0 b 4.62 Mb 0 b 4.62 Mb 0 FLOPs 0 FLOPs
|
||||
call_function add_5 227.36 Mb 4.62 Mb 4.62 Mb 0 b 0 b 9.23 Mb 1.21 MFLOPs 0 FLOPs
|
||||
call_module encoder_layers_encoder_layer_2_ln_2 227.37 Mb 12.31 Kb 0 b 4.62 Mb 6.00 Kb 4.62 Mb 6.05 MFLOPs 6.05 MFLOPs
|
||||
call_module encoder_layers_encoder_layer_2_mlp_0 250.46 Mb 23.09 Mb 18.47 Mb 0 b 9.01 Mb 4.62 Mb 3.72 GFLOPs 7.44 GFLOPs
|
||||
call_module encoder_layers_encoder_layer_2_mlp_1 268.93 Mb 18.47 Mb 18.47 Mb 0 b 0 b 18.47 Mb 4.84 MFLOPs 4.84 MFLOPs
|
||||
call_module encoder_layers_encoder_layer_2_mlp_2 268.93 Mb 0 b 18.47 Mb 0 b 0 b 18.47 Mb 0 FLOPs 0 FLOPs
|
||||
call_module encoder_layers_encoder_layer_2_mlp_3 268.93 Mb 0 b 0 b 4.62 Mb 9.00 Mb 18.47 Mb 3.72 GFLOPs 7.44 GFLOPs
|
||||
call_module encoder_layers_encoder_layer_2_mlp_4 268.93 Mb 0 b 0 b 4.62 Mb 0 b 4.62 Mb 0 FLOPs 0 FLOPs
|
||||
call_function add_6 273.54 Mb 4.62 Mb 4.62 Mb 0 b 0 b 9.23 Mb 1.21 MFLOPs 0 FLOPs
|
||||
call_module encoder_layers_encoder_layer_3_ln_1 278.17 Mb 4.63 Mb 4.62 Mb 0 b 6.00 Kb 4.62 Mb 6.05 MFLOPs 6.05 MFLOPs
|
||||
call_module encoder_layers_encoder_layer_3_self_attention 310.86 Mb 32.68 Mb 0 b 4.62 Mb 9.01 Mb 13.85 Mb 4.20 GFLOPs 8.40 GFLOPs
|
||||
call_function getitem_6 310.86 Mb 0 b 0 b 4.62 Mb 0 b 0 b 0 FLOPs 0 FLOPs
|
||||
call_function getitem_7 310.86 Mb 0 b 0 b 0 b 0 b 0 b 0 FLOPs 0 FLOPs
|
||||
call_module encoder_layers_encoder_layer_3_dropout 310.86 Mb 0 b 0 b 4.62 Mb 0 b 4.62 Mb 0 FLOPs 0 FLOPs
|
||||
call_function add_7 315.47 Mb 4.62 Mb 4.62 Mb 0 b 0 b 9.23 Mb 1.21 MFLOPs 0 FLOPs
|
||||
call_module encoder_layers_encoder_layer_3_ln_2 315.48 Mb 12.31 Kb 0 b 4.62 Mb 6.00 Kb 4.62 Mb 6.05 MFLOPs 6.05 MFLOPs
|
||||
call_module encoder_layers_encoder_layer_3_mlp_0 338.57 Mb 23.09 Mb 18.47 Mb 0 b 9.01 Mb 4.62 Mb 3.72 GFLOPs 7.44 GFLOPs
|
||||
call_module encoder_layers_encoder_layer_3_mlp_1 357.04 Mb 18.47 Mb 18.47 Mb 0 b 0 b 18.47 Mb 4.84 MFLOPs 4.84 MFLOPs
|
||||
call_module encoder_layers_encoder_layer_3_mlp_2 357.04 Mb 0 b 18.47 Mb 0 b 0 b 18.47 Mb 0 FLOPs 0 FLOPs
|
||||
call_module encoder_layers_encoder_layer_3_mlp_3 357.04 Mb 0 b 0 b 4.62 Mb 9.00 Mb 18.47 Mb 3.72 GFLOPs 7.44 GFLOPs
|
||||
call_module encoder_layers_encoder_layer_3_mlp_4 357.04 Mb 0 b 0 b 4.62 Mb 0 b 4.62 Mb 0 FLOPs 0 FLOPs
|
||||
call_function add_8 361.66 Mb 4.62 Mb 4.62 Mb 0 b 0 b 9.23 Mb 1.21 MFLOPs 0 FLOPs
|
||||
call_module encoder_layers_encoder_layer_4_ln_1 366.29 Mb 4.63 Mb 4.62 Mb 0 b 6.00 Kb 4.62 Mb 6.05 MFLOPs 6.05 MFLOPs
|
||||
call_module encoder_layers_encoder_layer_4_self_attention 398.97 Mb 32.68 Mb 0 b 4.62 Mb 9.01 Mb 13.85 Mb 4.20 GFLOPs 8.40 GFLOPs
|
||||
call_function getitem_8 398.97 Mb 0 b 0 b 4.62 Mb 0 b 0 b 0 FLOPs 0 FLOPs
|
||||
call_function getitem_9 398.97 Mb 0 b 0 b 0 b 0 b 0 b 0 FLOPs 0 FLOPs
|
||||
call_module encoder_layers_encoder_layer_4_dropout 398.97 Mb 0 b 0 b 4.62 Mb 0 b 4.62 Mb 0 FLOPs 0 FLOPs
|
||||
call_function add_9 403.58 Mb 4.62 Mb 4.62 Mb 0 b 0 b 9.23 Mb 1.21 MFLOPs 0 FLOPs
|
||||
call_module encoder_layers_encoder_layer_4_ln_2 403.60 Mb 12.31 Kb 0 b 4.62 Mb 6.00 Kb 4.62 Mb 6.05 MFLOPs 6.05 MFLOPs
|
||||
call_module encoder_layers_encoder_layer_4_mlp_0 426.68 Mb 23.09 Mb 18.47 Mb 0 b 9.01 Mb 4.62 Mb 3.72 GFLOPs 7.44 GFLOPs
|
||||
call_module encoder_layers_encoder_layer_4_mlp_1 445.15 Mb 18.47 Mb 18.47 Mb 0 b 0 b 18.47 Mb 4.84 MFLOPs 4.84 MFLOPs
|
||||
call_module encoder_layers_encoder_layer_4_mlp_2 445.15 Mb 0 b 18.47 Mb 0 b 0 b 18.47 Mb 0 FLOPs 0 FLOPs
|
||||
call_module encoder_layers_encoder_layer_4_mlp_3 445.15 Mb 0 b 0 b 4.62 Mb 9.00 Mb 18.47 Mb 3.72 GFLOPs 7.44 GFLOPs
|
||||
call_module encoder_layers_encoder_layer_4_mlp_4 445.15 Mb 0 b 0 b 4.62 Mb 0 b 4.62 Mb 0 FLOPs 0 FLOPs
|
||||
call_function add_10 449.77 Mb 4.62 Mb 4.62 Mb 0 b 0 b 9.23 Mb 1.21 MFLOPs 0 FLOPs
|
||||
call_module encoder_layers_encoder_layer_5_ln_1 454.40 Mb 4.63 Mb 4.62 Mb 0 b 6.00 Kb 4.62 Mb 6.05 MFLOPs 6.05 MFLOPs
|
||||
call_module encoder_layers_encoder_layer_5_self_attention 487.08 Mb 32.68 Mb 0 b 4.62 Mb 9.01 Mb 13.85 Mb 4.20 GFLOPs 8.40 GFLOPs
|
||||
call_function getitem_10 487.08 Mb 0 b 0 b 4.62 Mb 0 b 0 b 0 FLOPs 0 FLOPs
|
||||
call_function getitem_11 487.08 Mb 0 b 0 b 0 b 0 b 0 b 0 FLOPs 0 FLOPs
|
||||
call_module encoder_layers_encoder_layer_5_dropout 487.08 Mb 0 b 0 b 4.62 Mb 0 b 4.62 Mb 0 FLOPs 0 FLOPs
|
||||
call_function add_11 491.70 Mb 4.62 Mb 4.62 Mb 0 b 0 b 9.23 Mb 1.21 MFLOPs 0 FLOPs
|
||||
call_module encoder_layers_encoder_layer_5_ln_2 491.71 Mb 12.31 Kb 0 b 4.62 Mb 6.00 Kb 4.62 Mb 6.05 MFLOPs 6.05 MFLOPs
|
||||
call_module encoder_layers_encoder_layer_5_mlp_0 514.79 Mb 23.09 Mb 18.47 Mb 0 b 9.01 Mb 4.62 Mb 3.72 GFLOPs 7.44 GFLOPs
|
||||
call_module encoder_layers_encoder_layer_5_mlp_1 533.26 Mb 18.47 Mb 18.47 Mb 0 b 0 b 18.47 Mb 4.84 MFLOPs 4.84 MFLOPs
|
||||
call_module encoder_layers_encoder_layer_5_mlp_2 533.26 Mb 0 b 18.47 Mb 0 b 0 b 18.47 Mb 0 FLOPs 0 FLOPs
|
||||
call_module encoder_layers_encoder_layer_5_mlp_3 533.26 Mb 0 b 0 b 4.62 Mb 9.00 Mb 18.47 Mb 3.72 GFLOPs 7.44 GFLOPs
|
||||
call_module encoder_layers_encoder_layer_5_mlp_4 533.26 Mb 0 b 0 b 4.62 Mb 0 b 4.62 Mb 0 FLOPs 0 FLOPs
|
||||
call_function add_12 537.88 Mb 4.62 Mb 4.62 Mb 0 b 0 b 9.23 Mb 1.21 MFLOPs 0 FLOPs
|
||||
call_module encoder_layers_encoder_layer_6_ln_1 542.51 Mb 4.63 Mb 4.62 Mb 0 b 6.00 Kb 4.62 Mb 6.05 MFLOPs 6.05 MFLOPs
|
||||
call_module encoder_layers_encoder_layer_6_self_attention 575.19 Mb 32.68 Mb 0 b 4.62 Mb 9.01 Mb 13.85 Mb 4.20 GFLOPs 8.40 GFLOPs
|
||||
call_function getitem_12 575.19 Mb 0 b 0 b 4.62 Mb 0 b 0 b 0 FLOPs 0 FLOPs
|
||||
call_function getitem_13 575.19 Mb 0 b 0 b 0 b 0 b 0 b 0 FLOPs 0 FLOPs
|
||||
call_module encoder_layers_encoder_layer_6_dropout 575.19 Mb 0 b 0 b 4.62 Mb 0 b 4.62 Mb 0 FLOPs 0 FLOPs
|
||||
call_function add_13 579.81 Mb 4.62 Mb 4.62 Mb 0 b 0 b 9.23 Mb 1.21 MFLOPs 0 FLOPs
|
||||
call_module encoder_layers_encoder_layer_6_ln_2 579.82 Mb 12.31 Kb 0 b 4.62 Mb 6.00 Kb 4.62 Mb 6.05 MFLOPs 6.05 MFLOPs
|
||||
call_module encoder_layers_encoder_layer_6_mlp_0 602.90 Mb 23.09 Mb 18.47 Mb 0 b 9.01 Mb 4.62 Mb 3.72 GFLOPs 7.44 GFLOPs
|
||||
call_module encoder_layers_encoder_layer_6_mlp_1 621.37 Mb 18.47 Mb 18.47 Mb 0 b 0 b 18.47 Mb 4.84 MFLOPs 4.84 MFLOPs
|
||||
call_module encoder_layers_encoder_layer_6_mlp_2 621.37 Mb 0 b 18.47 Mb 0 b 0 b 18.47 Mb 0 FLOPs 0 FLOPs
|
||||
call_module encoder_layers_encoder_layer_6_mlp_3 621.37 Mb 0 b 0 b 4.62 Mb 9.00 Mb 18.47 Mb 3.72 GFLOPs 7.44 GFLOPs
|
||||
call_module encoder_layers_encoder_layer_6_mlp_4 621.37 Mb 0 b 0 b 4.62 Mb 0 b 4.62 Mb 0 FLOPs 0 FLOPs
|
||||
call_function add_14 625.99 Mb 4.62 Mb 4.62 Mb 0 b 0 b 9.23 Mb 1.21 MFLOPs 0 FLOPs
|
||||
call_module encoder_layers_encoder_layer_7_ln_1 630.62 Mb 4.63 Mb 4.62 Mb 0 b 6.00 Kb 4.62 Mb 6.05 MFLOPs 6.05 MFLOPs
|
||||
call_module encoder_layers_encoder_layer_7_self_attention 663.30 Mb 32.68 Mb 0 b 4.62 Mb 9.01 Mb 13.85 Mb 4.20 GFLOPs 8.40 GFLOPs
|
||||
call_function getitem_14 663.30 Mb 0 b 0 b 4.62 Mb 0 b 0 b 0 FLOPs 0 FLOPs
|
||||
call_function getitem_15 663.30 Mb 0 b 0 b 0 b 0 b 0 b 0 FLOPs 0 FLOPs
|
||||
call_module encoder_layers_encoder_layer_7_dropout 663.30 Mb 0 b 0 b 4.62 Mb 0 b 4.62 Mb 0 FLOPs 0 FLOPs
|
||||
call_function add_15 667.92 Mb 4.62 Mb 4.62 Mb 0 b 0 b 9.23 Mb 1.21 MFLOPs 0 FLOPs
|
||||
call_module encoder_layers_encoder_layer_7_ln_2 667.93 Mb 12.31 Kb 0 b 4.62 Mb 6.00 Kb 4.62 Mb 6.05 MFLOPs 6.05 MFLOPs
|
||||
call_module encoder_layers_encoder_layer_7_mlp_0 691.02 Mb 23.09 Mb 18.47 Mb 0 b 9.01 Mb 4.62 Mb 3.72 GFLOPs 7.44 GFLOPs
|
||||
call_module encoder_layers_encoder_layer_7_mlp_1 709.48 Mb 18.47 Mb 18.47 Mb 0 b 0 b 18.47 Mb 4.84 MFLOPs 4.84 MFLOPs
|
||||
call_module encoder_layers_encoder_layer_7_mlp_2 709.48 Mb 0 b 18.47 Mb 0 b 0 b 18.47 Mb 0 FLOPs 0 FLOPs
|
||||
call_module encoder_layers_encoder_layer_7_mlp_3 709.48 Mb 0 b 0 b 4.62 Mb 9.00 Mb 18.47 Mb 3.72 GFLOPs 7.44 GFLOPs
|
||||
call_module encoder_layers_encoder_layer_7_mlp_4 709.48 Mb 0 b 0 b 4.62 Mb 0 b 4.62 Mb 0 FLOPs 0 FLOPs
|
||||
call_function add_16 714.10 Mb 4.62 Mb 4.62 Mb 0 b 0 b 9.23 Mb 1.21 MFLOPs 0 FLOPs
|
||||
call_module encoder_layers_encoder_layer_8_ln_1 718.73 Mb 4.63 Mb 4.62 Mb 0 b 6.00 Kb 4.62 Mb 6.05 MFLOPs 6.05 MFLOPs
|
||||
call_module encoder_layers_encoder_layer_8_self_attention 751.41 Mb 32.68 Mb 0 b 4.62 Mb 9.01 Mb 13.85 Mb 4.20 GFLOPs 8.40 GFLOPs
|
||||
call_function getitem_16 751.41 Mb 0 b 0 b 4.62 Mb 0 b 0 b 0 FLOPs 0 FLOPs
|
||||
call_function getitem_17 751.41 Mb 0 b 0 b 0 b 0 b 0 b 0 FLOPs 0 FLOPs
|
||||
call_module encoder_layers_encoder_layer_8_dropout 751.41 Mb 0 b 0 b 4.62 Mb 0 b 4.62 Mb 0 FLOPs 0 FLOPs
|
||||
call_function add_17 756.03 Mb 4.62 Mb 4.62 Mb 0 b 0 b 9.23 Mb 1.21 MFLOPs 0 FLOPs
|
||||
call_module encoder_layers_encoder_layer_8_ln_2 756.04 Mb 12.31 Kb 0 b 4.62 Mb 6.00 Kb 4.62 Mb 6.05 MFLOPs 6.05 MFLOPs
|
||||
call_module encoder_layers_encoder_layer_8_mlp_0 779.13 Mb 23.09 Mb 18.47 Mb 0 b 9.01 Mb 4.62 Mb 3.72 GFLOPs 7.44 GFLOPs
|
||||
call_module encoder_layers_encoder_layer_8_mlp_1 797.60 Mb 18.47 Mb 18.47 Mb 0 b 0 b 18.47 Mb 4.84 MFLOPs 4.84 MFLOPs
|
||||
call_module encoder_layers_encoder_layer_8_mlp_2 797.60 Mb 0 b 18.47 Mb 0 b 0 b 18.47 Mb 0 FLOPs 0 FLOPs
|
||||
call_module encoder_layers_encoder_layer_8_mlp_3 797.60 Mb 0 b 0 b 4.62 Mb 9.00 Mb 18.47 Mb 3.72 GFLOPs 7.44 GFLOPs
|
||||
call_module encoder_layers_encoder_layer_8_mlp_4 797.60 Mb 0 b 0 b 4.62 Mb 0 b 4.62 Mb 0 FLOPs 0 FLOPs
|
||||
call_function add_18 802.21 Mb 4.62 Mb 4.62 Mb 0 b 0 b 9.23 Mb 1.21 MFLOPs 0 FLOPs
|
||||
call_module encoder_layers_encoder_layer_9_ln_1 806.84 Mb 4.63 Mb 4.62 Mb 0 b 6.00 Kb 4.62 Mb 6.05 MFLOPs 6.05 MFLOPs
|
||||
call_module encoder_layers_encoder_layer_9_self_attention 839.52 Mb 32.68 Mb 0 b 4.62 Mb 9.01 Mb 13.85 Mb 4.20 GFLOPs 8.40 GFLOPs
|
||||
call_function getitem_18 839.52 Mb 0 b 0 b 4.62 Mb 0 b 0 b 0 FLOPs 0 FLOPs
|
||||
call_function getitem_19 839.52 Mb 0 b 0 b 0 b 0 b 0 b 0 FLOPs 0 FLOPs
|
||||
call_module encoder_layers_encoder_layer_9_dropout 839.52 Mb 0 b 0 b 4.62 Mb 0 b 4.62 Mb 0 FLOPs 0 FLOPs
|
||||
call_function add_19 844.14 Mb 4.62 Mb 4.62 Mb 0 b 0 b 9.23 Mb 1.21 MFLOPs 0 FLOPs
|
||||
call_module encoder_layers_encoder_layer_9_ln_2 844.15 Mb 12.31 Kb 0 b 4.62 Mb 6.00 Kb 4.62 Mb 6.05 MFLOPs 6.05 MFLOPs
|
||||
call_module encoder_layers_encoder_layer_9_mlp_0 867.24 Mb 23.09 Mb 18.47 Mb 0 b 9.01 Mb 4.62 Mb 3.72 GFLOPs 7.44 GFLOPs
|
||||
call_module encoder_layers_encoder_layer_9_mlp_1 885.71 Mb 18.47 Mb 18.47 Mb 0 b 0 b 18.47 Mb 4.84 MFLOPs 4.84 MFLOPs
|
||||
call_module encoder_layers_encoder_layer_9_mlp_2 885.71 Mb 0 b 18.47 Mb 0 b 0 b 18.47 Mb 0 FLOPs 0 FLOPs
|
||||
call_module encoder_layers_encoder_layer_9_mlp_3 885.71 Mb 0 b 0 b 4.62 Mb 9.00 Mb 18.47 Mb 3.72 GFLOPs 7.44 GFLOPs
|
||||
call_module encoder_layers_encoder_layer_9_mlp_4 885.71 Mb 0 b 0 b 4.62 Mb 0 b 4.62 Mb 0 FLOPs 0 FLOPs
|
||||
call_function add_20 890.32 Mb 4.62 Mb 4.62 Mb 0 b 0 b 9.23 Mb 1.21 MFLOPs 0 FLOPs
|
||||
call_module encoder_layers_encoder_layer_10_ln_1 894.95 Mb 4.63 Mb 4.62 Mb 0 b 6.00 Kb 4.62 Mb 6.05 MFLOPs 6.05 MFLOPs
|
||||
call_module encoder_layers_encoder_layer_10_self_attention 927.63 Mb 32.68 Mb 0 b 4.62 Mb 9.01 Mb 13.85 Mb 4.20 GFLOPs 8.40 GFLOPs
|
||||
call_function getitem_20 927.63 Mb 0 b 0 b 4.62 Mb 0 b 0 b 0 FLOPs 0 FLOPs
|
||||
call_function getitem_21 927.63 Mb 0 b 0 b 0 b 0 b 0 b 0 FLOPs 0 FLOPs
|
||||
call_module encoder_layers_encoder_layer_10_dropout 927.63 Mb 0 b 0 b 4.62 Mb 0 b 4.62 Mb 0 FLOPs 0 FLOPs
|
||||
call_function add_21 932.25 Mb 4.62 Mb 4.62 Mb 0 b 0 b 9.23 Mb 1.21 MFLOPs 0 FLOPs
|
||||
call_module encoder_layers_encoder_layer_10_ln_2 932.26 Mb 12.31 Kb 0 b 4.62 Mb 6.00 Kb 4.62 Mb 6.05 MFLOPs 6.05 MFLOPs
|
||||
call_module encoder_layers_encoder_layer_10_mlp_0 955.35 Mb 23.09 Mb 18.47 Mb 0 b 9.01 Mb 4.62 Mb 3.72 GFLOPs 7.44 GFLOPs
|
||||
call_module encoder_layers_encoder_layer_10_mlp_1 973.82 Mb 18.47 Mb 18.47 Mb 0 b 0 b 18.47 Mb 4.84 MFLOPs 4.84 MFLOPs
|
||||
call_module encoder_layers_encoder_layer_10_mlp_2 973.82 Mb 0 b 18.47 Mb 0 b 0 b 18.47 Mb 0 FLOPs 0 FLOPs
|
||||
call_module encoder_layers_encoder_layer_10_mlp_3 973.82 Mb 0 b 0 b 4.62 Mb 9.00 Mb 18.47 Mb 3.72 GFLOPs 7.44 GFLOPs
|
||||
call_module encoder_layers_encoder_layer_10_mlp_4 973.82 Mb 0 b 0 b 4.62 Mb 0 b 4.62 Mb 0 FLOPs 0 FLOPs
|
||||
call_function add_22 978.44 Mb 4.62 Mb 4.62 Mb 0 b 0 b 9.23 Mb 1.21 MFLOPs 0 FLOPs
|
||||
call_module encoder_layers_encoder_layer_11_ln_1 983.06 Mb 4.63 Mb 4.62 Mb 0 b 6.00 Kb 4.62 Mb 6.05 MFLOPs 6.05 MFLOPs
|
||||
call_module encoder_layers_encoder_layer_11_self_attention 1015.75 Mb 32.68 Mb 0 b 4.62 Mb 9.01 Mb 13.85 Mb 4.20 GFLOPs 8.40 GFLOPs
|
||||
call_function getitem_22 1015.75 Mb 0 b 0 b 4.62 Mb 0 b 0 b 0 FLOPs 0 FLOPs
|
||||
call_function getitem_23 1015.75 Mb 0 b 0 b 0 b 0 b 0 b 0 FLOPs 0 FLOPs
|
||||
call_module encoder_layers_encoder_layer_11_dropout 1015.75 Mb 0 b 0 b 4.62 Mb 0 b 4.62 Mb 0 FLOPs 0 FLOPs
|
||||
call_function add_23 1020.36 Mb 4.62 Mb 4.62 Mb 0 b 0 b 9.23 Mb 1.21 MFLOPs 0 FLOPs
|
||||
call_module encoder_layers_encoder_layer_11_ln_2 1020.38 Mb 12.31 Kb 0 b 4.62 Mb 6.00 Kb 4.62 Mb 6.05 MFLOPs 6.05 MFLOPs
|
||||
call_module encoder_layers_encoder_layer_11_mlp_0 1.02 Gb 23.09 Mb 18.47 Mb 0 b 9.01 Mb 4.62 Mb 3.72 GFLOPs 7.44 GFLOPs
|
||||
call_module encoder_layers_encoder_layer_11_mlp_1 1.04 Gb 18.47 Mb 18.47 Mb 0 b 0 b 18.47 Mb 4.84 MFLOPs 4.84 MFLOPs
|
||||
call_module encoder_layers_encoder_layer_11_mlp_2 1.04 Gb 0 b 18.47 Mb 0 b 0 b 18.47 Mb 0 FLOPs 0 FLOPs
|
||||
call_module encoder_layers_encoder_layer_11_mlp_3 1.04 Gb 0 b 0 b 4.62 Mb 9.00 Mb 18.47 Mb 3.72 GFLOPs 7.44 GFLOPs
|
||||
call_module encoder_layers_encoder_layer_11_mlp_4 1.04 Gb 0 b 0 b 4.62 Mb 0 b 4.62 Mb 0 FLOPs 0 FLOPs
|
||||
call_function add_24 1.04 Gb 4.62 Mb 4.62 Mb 0 b 0 b 9.23 Mb 1.21 MFLOPs 0 FLOPs
|
||||
call_module encoder_ln 1.04 Gb 36.31 Kb 24.00 Kb 0 b 6.00 Kb 4.62 Mb 6.05 MFLOPs 6.05 MFLOPs
|
||||
call_function getitem_24 1.04 Gb 0 b 24.00 Kb 0 b 0 b 4.62 Mb 0 FLOPs 0 FLOPs
|
||||
call_module heads_head 1.04 Gb 0 b 0 b 31.25 Kb 2.93 Mb 24.00 Kb 6.14 MFLOPs 12.30 MFLOPs
|
||||
output output 1.04 Gb 0 b 0 b 31.25 Kb 0 b 31.25 Kb 0 FLOPs 0 FLOPs
|
||||
```
|
|
@ -0,0 +1,4 @@
|
|||
from ._meta_registration import *
|
||||
from ._monkey_patch import *
|
||||
from .flop_tensor import flop_count, flop_mapping
|
||||
from .meta_tensor import MetaTensor, MetaTensorMode
|
|
@ -0,0 +1,481 @@
|
|||
# meta patch from https://github.com/pytorch/pytorch/blob/master/torch/_meta_registrations.py
|
||||
# should be activated for PyTorch version 1.12.0 and below
|
||||
# refer to https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/native_functions.yaml
|
||||
# for more meta_registrations
|
||||
|
||||
from typing import Callable, List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
from torch.utils._pytree import tree_map
|
||||
|
||||
aten = torch.ops.aten
|
||||
|
||||
meta_lib = torch.library.Library("aten", "IMPL", "Meta")
|
||||
|
||||
meta_table = {}
|
||||
|
||||
orig_empty = torch.empty
|
||||
orig_empty_strided = torch.empty_strided
|
||||
orig_empty_like = torch.empty_like
|
||||
|
||||
|
||||
def new(*args, **kwargs):
|
||||
return orig_empty(*args, **kwargs, device=torch.device('meta'))
|
||||
|
||||
|
||||
def new_strided(*args, **kwargs):
|
||||
return orig_empty_strided(*args, **kwargs, device=torch.device('meta'))
|
||||
|
||||
|
||||
def new_like(*args, **kwargs):
|
||||
return orig_empty_like(*args, **kwargs, device=torch.device('meta'))
|
||||
|
||||
|
||||
def register_meta(op, register_dispatcher=True):
|
||||
|
||||
def wrapper(f):
|
||||
|
||||
def add_func(op):
|
||||
meta_table[op] = f
|
||||
if register_dispatcher:
|
||||
name = (op.__name__ if op._overloadname != "default" else op.overloadpacket.__name__)
|
||||
try:
|
||||
meta_lib.impl(name, f)
|
||||
except:
|
||||
pass
|
||||
|
||||
tree_map(add_func, op)
|
||||
return f
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
# ============================== Convolutions ======================================
|
||||
# https://github.com/pytorch/pytorch/pull/79834
|
||||
@register_meta(aten.convolution.default)
|
||||
def meta_conv(
|
||||
input_tensor: torch.Tensor,
|
||||
weight: torch.Tensor,
|
||||
bias: torch.Tensor,
|
||||
stride: List[int],
|
||||
padding: List[int],
|
||||
dilation: List[int],
|
||||
is_transposed: bool,
|
||||
output_padding: List[int],
|
||||
groups: int,
|
||||
):
|
||||
|
||||
def _formula(ln: int, p: int, d: int, k: int, s: int) -> int:
|
||||
"""
|
||||
Formula to apply to calculate the length of some dimension of the output
|
||||
See: https://pytorch.org/docs/stable/generated/torch.nn.Conv2d.html
|
||||
Args:
|
||||
ln: length of the dimension
|
||||
p: padding in that dim
|
||||
d: dilation in that dim
|
||||
k: kernel size in that dim
|
||||
s: stride in that dim
|
||||
Returns:
|
||||
The output length
|
||||
"""
|
||||
return (ln + 2 * p - d * (k - 1) - 1) // s + 1
|
||||
|
||||
def _formula_transposed(ln: int, p: int, d: int, k: int, s: int, op: int) -> int:
|
||||
"""
|
||||
Formula to apply to calculate the length of some dimension of the output
|
||||
if transposed convolution is used.
|
||||
See: https://pytorch.org/docs/stable/generated/torch.nn.ConvTranspose2d.html
|
||||
Args:
|
||||
ln: length of the dimension
|
||||
p: padding in that dim
|
||||
d: dilation in that dim
|
||||
k: kernel size in that dim
|
||||
s: stride in that dim
|
||||
op: output padding in that dim
|
||||
Returns:
|
||||
The output length
|
||||
"""
|
||||
return (ln - 1) * s - 2 * p + d * (k - 1) + op + 1
|
||||
|
||||
def calc_conv_nd_return_shape(
|
||||
dims: torch.Size,
|
||||
kernel_size: torch.Size,
|
||||
stride: Union[List[int], int],
|
||||
padding: Union[List[int], int],
|
||||
dilation: Union[List[int], int],
|
||||
output_padding: Optional[Union[List[int], int]] = None,
|
||||
):
|
||||
ret_shape = []
|
||||
if isinstance(stride, int):
|
||||
stride = [stride] * len(dims)
|
||||
elif len(stride) == 1:
|
||||
stride = [stride[0]] * len(dims)
|
||||
|
||||
if isinstance(padding, int):
|
||||
padding = [padding] * len(dims)
|
||||
elif len(padding) == 1:
|
||||
padding = [padding[0]] * len(dims)
|
||||
|
||||
if isinstance(dilation, int):
|
||||
dilation = [dilation] * len(dims)
|
||||
elif len(dilation) == 1:
|
||||
dilation = [dilation[0]] * len(dims)
|
||||
|
||||
output_padding_list: Optional[List[int]] = None
|
||||
if output_padding:
|
||||
if isinstance(output_padding, int):
|
||||
output_padding_list = [output_padding] * len(dims)
|
||||
elif len(output_padding) == 1:
|
||||
output_padding_list = [output_padding[0]] * len(dims)
|
||||
else:
|
||||
output_padding_list = output_padding
|
||||
|
||||
for i in range(len(dims)):
|
||||
# If output_padding is present, we are dealing with a transposed convolution
|
||||
if output_padding_list:
|
||||
ret_shape.append(
|
||||
_formula_transposed(
|
||||
dims[i],
|
||||
padding[i],
|
||||
dilation[i],
|
||||
kernel_size[i],
|
||||
stride[i],
|
||||
output_padding_list[i],
|
||||
))
|
||||
else:
|
||||
ret_shape.append(_formula(dims[i], padding[i], dilation[i], kernel_size[i], stride[i]))
|
||||
return ret_shape
|
||||
|
||||
def pick_memory_format():
|
||||
if input_tensor.is_contiguous(memory_format=torch.channels_last):
|
||||
return torch.channels_last
|
||||
elif input_tensor.is_contiguous(memory_format=torch.contiguous_format):
|
||||
return torch.contiguous_format
|
||||
elif input_tensor.is_contiguous(memory_format=torch.preserve_format):
|
||||
return torch.preserve_format
|
||||
|
||||
kernel_size = weight.shape[2:]
|
||||
dims = input_tensor.shape[2:]
|
||||
if is_transposed:
|
||||
out_channels = groups * weight.shape[1]
|
||||
|
||||
shape_out = calc_conv_nd_return_shape(
|
||||
dims,
|
||||
kernel_size,
|
||||
stride,
|
||||
padding,
|
||||
dilation,
|
||||
output_padding,
|
||||
)
|
||||
|
||||
else:
|
||||
out_channels = weight.shape[0]
|
||||
if weight.shape[1] != input_tensor.shape[1] / groups:
|
||||
raise RuntimeError("Invalid channel dimensions")
|
||||
shape_out = calc_conv_nd_return_shape(dims, kernel_size, stride, padding, dilation)
|
||||
out = input_tensor.new_empty((input_tensor.shape[0], out_channels, *shape_out))
|
||||
mem_fmt = pick_memory_format()
|
||||
out = out.to(memory_format=mem_fmt) # type: ignore[call-overload]
|
||||
return out
|
||||
|
||||
|
||||
@register_meta(aten._convolution.default)
|
||||
def meta__conv(input_tensor: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor, stride: List[int],
|
||||
padding: List[int], dilation: List[int], is_transposed: bool, output_padding: List[int], groups: int,
|
||||
*extra_args):
|
||||
out = meta_conv(input_tensor, weight, bias, stride, padding, dilation, is_transposed, output_padding, groups)
|
||||
return out
|
||||
|
||||
|
||||
@register_meta(aten.convolution_backward.default)
|
||||
def meta_conv_backward(grad_output: torch.Tensor, input: torch.Tensor, weight: torch.Tensor, bias_sizes, stride,
|
||||
padding, dilation, transposed, output_padding, groups, output_mask):
|
||||
return new_like(input), new_like(weight), new((bias_sizes))
|
||||
|
||||
|
||||
# https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/AdaptiveAveragePooling.cpp
|
||||
@register_meta(aten._adaptive_avg_pool2d_backward.default)
|
||||
def meta_adaptive_avg_pool2d_backward(
|
||||
grad_output: torch.Tensor,
|
||||
input: torch.Tensor,
|
||||
):
|
||||
return new_like(input)
|
||||
|
||||
|
||||
# ================================ RNN =============================================
|
||||
# https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/cudnn/RNN.cpp
|
||||
@register_meta(aten._cudnn_rnn.default)
|
||||
def meta_cuda_rnn(
|
||||
input,
|
||||
weight,
|
||||
weight_stride0,
|
||||
weight_buf,
|
||||
hx,
|
||||
cx,
|
||||
mode,
|
||||
hidden_size,
|
||||
proj_size,
|
||||
num_layers,
|
||||
batch_first,
|
||||
dropout,
|
||||
train,
|
||||
bidirectional,
|
||||
batch_sizes,
|
||||
dropout_state,
|
||||
):
|
||||
|
||||
is_input_packed = len(batch_sizes) != 0
|
||||
if is_input_packed:
|
||||
seq_length = len(batch_sizes)
|
||||
mini_batch = batch_sizes[0]
|
||||
batch_sizes_sum = input.shape[0]
|
||||
else:
|
||||
seq_length = input.shape[1] if batch_first else input.shape[0]
|
||||
mini_batch = input.shape[0] if batch_first else input.shape[1]
|
||||
batch_sizes_sum = -1
|
||||
|
||||
num_directions = 2 if bidirectional else 1
|
||||
out_size = proj_size if proj_size != 0 else hidden_size
|
||||
if is_input_packed:
|
||||
out_shape = [batch_sizes_sum, out_size * num_directions]
|
||||
else:
|
||||
out_shape = ([mini_batch, seq_length, out_size *
|
||||
num_directions] if batch_first else [seq_length, mini_batch, out_size * num_directions])
|
||||
output = input.new_empty(out_shape)
|
||||
|
||||
cell_shape = [num_layers * num_directions, mini_batch, hidden_size]
|
||||
cy = new(0) if cx is None else cx.new_empty(cell_shape)
|
||||
|
||||
hy = hx.new_empty([num_layers * num_directions, mini_batch, out_size])
|
||||
|
||||
# TODO: Query cudnnGetRNNTrainingReserveSize (expose to python)
|
||||
reserve_shape = 0 if train else 0
|
||||
reserve = input.new_empty(reserve_shape, dtype=torch.uint8)
|
||||
|
||||
return output, hy, cy, reserve, weight_buf
|
||||
|
||||
|
||||
# https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/cudnn/RNN.cpp
|
||||
@register_meta(aten._cudnn_rnn_backward.default)
|
||||
def meta_cudnn_rnn_backward(input: torch.Tensor,
|
||||
weight: torch.Tensor,
|
||||
weight_stride0: int,
|
||||
hx: torch.Tensor,
|
||||
cx: Optional[torch.Tensor] = None,
|
||||
*args,
|
||||
**kwargs):
|
||||
return new_like(input), new_like(weight), new_like(hx), new_like(cx) if cx is not None else new(
|
||||
()) # (grad_input, grad_weight, grad_hx, grad_cx)
|
||||
|
||||
|
||||
# https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/Activation.cpp
|
||||
# ============================== Activations =======================================
|
||||
_unregistered_ewise = [
|
||||
aten.relu.default,
|
||||
aten.prelu.default,
|
||||
aten.hardswish.default,
|
||||
aten.hardtanh.default,
|
||||
aten.prelu_backward.default,
|
||||
aten.hardswish_backward.default,
|
||||
aten.hardtanh_backward.default,
|
||||
]
|
||||
|
||||
|
||||
@register_meta(_unregistered_ewise)
|
||||
def meta_unregistered_ewise(input: torch.Tensor, *args):
|
||||
return new_like(input)
|
||||
|
||||
|
||||
# ============================== Normalization =====================================
|
||||
# https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/cudnn/BatchNorm.cpp
|
||||
@register_meta(aten.native_batch_norm.default)
|
||||
def meta_bn(input: torch.Tensor, weight, bias, running_mean, running_var, training, momentum, eps):
|
||||
n_input = input.size(1)
|
||||
return new_like(input), new((n_input)), new((n_input))
|
||||
|
||||
|
||||
# https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/cudnn/BatchNorm.cpp
|
||||
@register_meta(aten.native_batch_norm_backward.default)
|
||||
def meta_bn_backward(dY: torch.Tensor, input: torch.Tensor, weight: torch.Tensor, running_mean, running_var, save_mean,
|
||||
save_invstd, train, eps, output_mask):
|
||||
return new_like(input), new_like(weight), new_like(weight) # (dX, dgamma, dbeta)
|
||||
|
||||
|
||||
# https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/cudnn/BatchNorm.cpp
|
||||
@register_meta(aten.cudnn_batch_norm.default)
|
||||
def meta_cudnn_bn(input: torch.Tensor, weight, bias, running_mean, running_var, training, momentum, eps):
|
||||
n_input = input.size(1)
|
||||
return new_like(input), new((n_input)), new((n_input)), new(
|
||||
(0), dtype=torch.uint8) # (output, running_mean, running_var, reserve)
|
||||
|
||||
|
||||
# https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/cudnn/BatchNorm.cpp
|
||||
# NB: CuDNN only implements the backward algorithm for batchnorm
|
||||
# in training mode (evaluation mode batchnorm has a different algorithm),
|
||||
# which is why this doesn't accept a 'training' parameter.
|
||||
@register_meta(aten.cudnn_batch_norm_backward.default)
|
||||
def meta_cudnn_bn_backward(dY: torch.Tensor, input: torch.Tensor, weight: torch.Tensor, running_mean, running_var,
|
||||
save_mean, save_invstd, eps, reserve):
|
||||
return new_like(input), new_like(weight), new_like(weight) # (dX, dgamma, dbeta)
|
||||
|
||||
|
||||
# https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/layer_norm.cpp
|
||||
@register_meta(aten.native_layer_norm.default)
|
||||
def meta_ln(input: torch.Tensor, normalized_shape, weight, bias, eps):
|
||||
bs, n_input = input.size(0), input.size(1)
|
||||
return new_like(input), new((bs, n_input, 1)), new((bs, n_input, 1)) # (output, running_mean, running_var)
|
||||
|
||||
|
||||
# https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/layer_norm.cpp
|
||||
@register_meta(aten.native_layer_norm_backward.default)
|
||||
def meta_ln_backward(dY: torch.Tensor, input: torch.Tensor, normalized_shape, mean, rstd, weight, bias,
|
||||
grad_input_mask):
|
||||
return new_like(input), new_like(weight), new_like(bias) # (dX, dgamma, dbeta)
|
||||
|
||||
|
||||
# ================================== Misc ==========================================
|
||||
# Maybe incorrect
|
||||
# https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/Im2Col.cpp
|
||||
@register_meta(aten.im2col.default)
|
||||
def meta_im2col(input: torch.Tensor, kernel_size, dilation, padding, stride):
|
||||
return new_like(input)
|
||||
|
||||
|
||||
# https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/native_functions.yaml
|
||||
@register_meta(aten.eye.m_out)
|
||||
def meta_eye(n: int, m: int, out: torch.Tensor):
|
||||
return out
|
||||
|
||||
|
||||
# https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/native_functions.yaml
|
||||
@register_meta(aten.roll.default)
|
||||
def meta_roll(input: torch.Tensor, shifts, dims):
|
||||
return input
|
||||
|
||||
|
||||
# https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/Scalar.cpp
|
||||
@register_meta(aten._local_scalar_dense.default)
|
||||
def meta_local_scalar_dense(self: torch.Tensor):
|
||||
return 0
|
||||
|
||||
|
||||
# https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/TensorCompare.cpp
|
||||
@register_meta(aten.where.self)
|
||||
def meta_where_self(condition: torch.Tensor, self: torch.Tensor, other: torch.Tensor):
|
||||
result_type = torch.result_type(self, other)
|
||||
return new_like(condition + self + other, dtype=result_type)
|
||||
|
||||
|
||||
@register_meta(aten.index.Tensor)
|
||||
def meta_index_Tensor(self, indices):
|
||||
assert indices, "at least one index must be provided"
|
||||
# aten::index is the internal advanced indexing implementation
|
||||
# checkIndexTensorTypes and expandTensors
|
||||
result: List[Optional[torch.Tensor]] = []
|
||||
for i, index in enumerate(indices):
|
||||
if index is not None:
|
||||
assert index.dtype in [torch.long, torch.int8, torch.bool],\
|
||||
"tensors used as indices must be long, byte or bool tensors"
|
||||
if index.dtype in [torch.int8, torch.bool]:
|
||||
nonzero = index.nonzero()
|
||||
k = len(result)
|
||||
assert k + index.ndim <= self.ndim, f"too many indices for tensor of dimension {self.ndim}"
|
||||
for j in range(index.ndim):
|
||||
assert index.shape[j] == self.shape[
|
||||
k +
|
||||
j], f"The shape of the mask {index.shape} at index {i} does not match the shape of the indexed tensor {self.shape} at index {k + j}"
|
||||
result.append(nonzero.select(1, j))
|
||||
else:
|
||||
result.append(index)
|
||||
else:
|
||||
result.append(index)
|
||||
indices = result
|
||||
assert len(indices) <= self.ndim, f"too many indices for tensor of dimension {self.ndim} (got {len(indices)})"
|
||||
# expand_outplace
|
||||
import torch._refs as refs
|
||||
|
||||
indices = list(refs._maybe_broadcast(*indices))
|
||||
# add missing null tensors
|
||||
while len(indices) < self.ndim:
|
||||
indices.append(None)
|
||||
|
||||
# hasContiguousSubspace
|
||||
# true if all non-null tensors are adjacent
|
||||
# See:
|
||||
# https://numpy.org/doc/stable/user/basics.indexing.html#combining-advanced-and-basic-indexing
|
||||
# https://stackoverflow.com/questions/53841497/why-does-numpy-mixed-basic-advanced-indexing-depend-on-slice-adjacency
|
||||
state = 0
|
||||
has_contiguous_subspace = False
|
||||
for index in indices:
|
||||
if state == 0:
|
||||
if index is not None:
|
||||
state = 1
|
||||
elif state == 1:
|
||||
if index is None:
|
||||
state = 2
|
||||
else:
|
||||
if index is not None:
|
||||
break
|
||||
else:
|
||||
has_contiguous_subspace = True
|
||||
|
||||
# transposeToFront
|
||||
# This is the logic that causes the newly inserted dimensions to show up
|
||||
# at the beginning of the tensor, if they're not contiguous
|
||||
if not has_contiguous_subspace:
|
||||
dims = []
|
||||
transposed_indices = []
|
||||
for i, index in enumerate(indices):
|
||||
if index is not None:
|
||||
dims.append(i)
|
||||
transposed_indices.append(index)
|
||||
for i, index in enumerate(indices):
|
||||
if index is None:
|
||||
dims.append(i)
|
||||
transposed_indices.append(index)
|
||||
self = self.permute(dims)
|
||||
indices = transposed_indices
|
||||
|
||||
# AdvancedIndex::AdvancedIndex
|
||||
# Now we can assume the indices have contiguous subspace
|
||||
# This is simplified from AdvancedIndex which goes to more effort
|
||||
# to put the input and indices in a form so that TensorIterator can
|
||||
# take them. If we write a ref for this, probably that logic should
|
||||
# get implemented
|
||||
before_shape: List[int] = []
|
||||
after_shape: List[int] = []
|
||||
replacement_shape: List[int] = []
|
||||
for dim, index in enumerate(indices):
|
||||
if index is None:
|
||||
if replacement_shape:
|
||||
after_shape.append(self.shape[dim])
|
||||
else:
|
||||
before_shape.append(self.shape[dim])
|
||||
else:
|
||||
replacement_shape = list(index.shape)
|
||||
return self.new_empty(before_shape + replacement_shape + after_shape)
|
||||
|
||||
|
||||
# ============================== Embedding =========================================
|
||||
# https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/Embedding.cpp
|
||||
@register_meta(aten.embedding_dense_backward.default)
|
||||
def meta_embedding_dense_backward(grad_output: torch.Tensor, indices: torch.Tensor, num_weights, padding_idx,
|
||||
scale_grad_by_freq):
|
||||
return new((num_weights, grad_output.size(-1)),
|
||||
dtype=grad_output.dtype,
|
||||
device=grad_output.device,
|
||||
layout=grad_output.layout)
|
||||
|
||||
|
||||
# ============================== Dropout ===========================================
|
||||
# https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/Dropout.cpp
|
||||
@register_meta(aten.native_dropout.default)
|
||||
def meta_native_dropout_default(input: torch.Tensor, p: float, train: bool = False):
|
||||
# notice that mask is bool
|
||||
return new_like(input), new_like(input, dtype=torch.bool) # (output, mask)
|
||||
|
||||
|
||||
# https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/Dropout.cpp
|
||||
@register_meta(aten.native_dropout_backward.default)
|
||||
def meta_native_dropout_backward_default(grad: torch.Tensor, mask: torch.Tensor, scale: float):
|
||||
return new_like(grad) # (grad_in)
|
|
@ -0,0 +1,88 @@
|
|||
import torch
|
||||
import torch.distributed as dist
|
||||
|
||||
aten = torch.ops.aten
|
||||
|
||||
__all__ = [
|
||||
"_TorchFactoryMethod",
|
||||
"_TorchOverrideableFactoryMethod",
|
||||
"_TorchNonOverrideableFactoryMethod",
|
||||
"_TensorPropertyMethod",
|
||||
"_DistCommMethod",
|
||||
"_AliasATen",
|
||||
"_InplaceATen",
|
||||
"_MaybeInplaceATen",
|
||||
]
|
||||
|
||||
_TorchOverrideableFactoryMethod = [
|
||||
"empty",
|
||||
"eye",
|
||||
"full",
|
||||
"ones",
|
||||
"rand",
|
||||
"randn",
|
||||
"zeros",
|
||||
]
|
||||
|
||||
_TorchNonOverrideableFactoryMethod = [
|
||||
"arange",
|
||||
"finfo",
|
||||
"linspace",
|
||||
"logspace",
|
||||
"randint",
|
||||
"randperm",
|
||||
"tensor",
|
||||
]
|
||||
|
||||
_TorchFactoryMethod = _TorchOverrideableFactoryMethod + _TorchNonOverrideableFactoryMethod
|
||||
|
||||
_TensorPropertyMethod = ["dtype", "shape", "device", "requires_grad", "grad", "grad_fn", "data"]
|
||||
|
||||
_DistCommMethod = [
|
||||
"all_gather",
|
||||
"all_reduce",
|
||||
"all_to_all",
|
||||
"broadcast",
|
||||
"gather",
|
||||
"reduce",
|
||||
"reduce_scatter",
|
||||
"scatter",
|
||||
]
|
||||
|
||||
# TODO: dive deep here
|
||||
# refer to https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/TensorShape.cpp
|
||||
_AliasATen = [
|
||||
aten.detach.default,
|
||||
aten.detach_.default,
|
||||
aten.t.default,
|
||||
aten.transpose.int,
|
||||
aten.view.default,
|
||||
aten._unsafe_view.default,
|
||||
aten._reshape_alias.default,
|
||||
]
|
||||
|
||||
_InplaceATen = [
|
||||
aten.add_.Tensor,
|
||||
aten.add_.Scalar,
|
||||
aten.sub_.Tensor,
|
||||
aten.sub_.Scalar,
|
||||
aten.mul_.Tensor,
|
||||
aten.mul_.Scalar,
|
||||
aten.div_.Tensor,
|
||||
aten.div_.Scalar,
|
||||
aten.pow_.Tensor,
|
||||
aten.pow_.Scalar,
|
||||
]
|
||||
|
||||
# use `MaybeInplace` because they call ``as_strided()`` or ``slice()``
|
||||
_MaybeInplaceATen = [
|
||||
aten.diagonal.default,
|
||||
aten.expand.default,
|
||||
aten.select.int,
|
||||
aten.slice.Tensor,
|
||||
aten.split.Tensor,
|
||||
aten.squeeze.default,
|
||||
aten.permute.default,
|
||||
aten.unsqueeze.default,
|
||||
aten.as_strided.default,
|
||||
]
|
|
@ -0,0 +1,536 @@
|
|||
# adopted from https://github.com/facebookresearch/fvcore/blob/main/fvcore/nn/jit_handles.py
|
||||
# ideas from https://pastebin.com/AkvAyJBw
|
||||
# and https://dev-discuss.pytorch.org/t/the-ideal-pytorch-flop-counter-with-torch-dispatch/505
|
||||
|
||||
import operator
|
||||
from collections import defaultdict
|
||||
from contextlib import contextmanager
|
||||
from enum import Enum, auto
|
||||
from functools import partial, reduce
|
||||
from numbers import Number
|
||||
from typing import Any, Callable, List, Optional, Union
|
||||
|
||||
import torch
|
||||
from torch.utils._pytree import tree_map
|
||||
|
||||
from .meta_tensor import MetaTensor
|
||||
|
||||
aten = torch.ops.aten
|
||||
|
||||
|
||||
class Phase(Enum):
|
||||
FWD = auto()
|
||||
BWD = auto()
|
||||
|
||||
|
||||
def normalize_tuple(x):
|
||||
if not isinstance(x, tuple):
|
||||
return (x,)
|
||||
return x
|
||||
|
||||
|
||||
def _format_flops(flop):
|
||||
K = 1e3
|
||||
M = 1e6
|
||||
B = 1e9
|
||||
T = 1e12
|
||||
if flop < K:
|
||||
return f'{flop:.2f}'
|
||||
elif flop < M:
|
||||
return f'{flop / K:.2f}K'
|
||||
elif flop < B:
|
||||
return f'{flop / M:.2f}M'
|
||||
elif flop < T:
|
||||
return f'{flop / B:.2f}B'
|
||||
else:
|
||||
return f'{flop / T:.2f}T'
|
||||
|
||||
|
||||
def flop_count(module: Union[torch.nn.Module, Callable] = None, *args, verbose: bool = False, **kwargs) -> Number:
|
||||
"""
|
||||
Count the number of floating point operations in a model.
|
||||
Ideas from https://pastebin.com/AkvAyJBw.
|
||||
Args:
|
||||
module (torch.nn.Module): A PyTorch model.
|
||||
*args: Input arguments to the model.
|
||||
verbose (bool): If True, print the number of flops for each module.
|
||||
**kwargs: Input keyword arguments to the model.
|
||||
Returns:
|
||||
Number: The total number of floating point operations (FWD + BWD).
|
||||
"""
|
||||
maybe_inplace = (getattr(module, 'inplace', False) or kwargs.get('inplace', False)
|
||||
or getattr(module, '__name__', None) in ('add_', 'mul_', 'div_', 'sub_'))
|
||||
|
||||
class DummyModule(torch.nn.Module):
|
||||
|
||||
def __init__(self, func):
|
||||
super().__init__()
|
||||
self.func = func
|
||||
self.__name__ = func.__name__
|
||||
|
||||
def forward(self, *args, **kwargs):
|
||||
return self.func(*args, **kwargs)
|
||||
|
||||
total_flop_count = {Phase.FWD: 0, Phase.BWD: 0}
|
||||
flop_counts = defaultdict(lambda: defaultdict(int))
|
||||
parents = ['Global']
|
||||
module = module if isinstance(module, torch.nn.Module) else DummyModule(module)
|
||||
|
||||
class FlopTensor(MetaTensor):
|
||||
_tensor: torch.Tensor
|
||||
|
||||
def __repr__(self):
|
||||
name = 'FlopParameter' if getattr(self, '_is_param', False) else 'FlopTensor'
|
||||
if self.grad_fn:
|
||||
return f"{name}(..., size={tuple(self.shape)}, device='{self.device}', dtype={self.dtype}, grad_fn={self.grad_fn})"
|
||||
return f"{name}(..., size={tuple(self.shape)}, device='{self.device}', dtype={self.dtype})"
|
||||
|
||||
@classmethod
|
||||
def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
|
||||
|
||||
# no_dispatch is only needed if you use enable_python_mode.
|
||||
# It prevents infinite recursion.
|
||||
rs = super().__torch_dispatch__(func, types, args, kwargs)
|
||||
|
||||
outs = normalize_tuple(rs)
|
||||
|
||||
if func in flop_mapping:
|
||||
nonlocal flop_counts, total_flop_count
|
||||
flop_count = flop_mapping[func](args, outs)
|
||||
for par in parents:
|
||||
flop_counts[par][func.__name__] += flop_count
|
||||
total_flop_count[cur_phase] += flop_count
|
||||
|
||||
def wrap(x):
|
||||
if isinstance(x, MetaTensor):
|
||||
x = FlopTensor(x)
|
||||
return x
|
||||
|
||||
rs = tree_map(wrap, rs)
|
||||
|
||||
return rs
|
||||
|
||||
def is_autogradable(x):
|
||||
return isinstance(x, torch.Tensor) and x.is_floating_point()
|
||||
|
||||
def create_backwards_push(name):
|
||||
|
||||
class PushState(torch.autograd.Function):
|
||||
|
||||
@staticmethod
|
||||
def forward(ctx, *args):
|
||||
args = tree_map(lambda x: x.clone() if isinstance(x, torch.Tensor) else x, args)
|
||||
if len(args) == 1:
|
||||
return args[0]
|
||||
return args
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, *grad_outs):
|
||||
nonlocal parents
|
||||
parents.append(name)
|
||||
return grad_outs
|
||||
|
||||
return PushState.apply
|
||||
|
||||
def create_backwards_pop(name):
|
||||
|
||||
class PopState(torch.autograd.Function):
|
||||
|
||||
@staticmethod
|
||||
def forward(ctx, *args):
|
||||
args = tree_map(lambda x: x.clone() if isinstance(x, torch.Tensor) else x, args)
|
||||
if len(args) == 1:
|
||||
return args[0]
|
||||
return args
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, *grad_outs):
|
||||
nonlocal parents
|
||||
assert (parents[-1] == name)
|
||||
parents.pop()
|
||||
return grad_outs
|
||||
|
||||
return PopState.apply
|
||||
|
||||
def enter_module(name):
|
||||
|
||||
def f(module, inputs):
|
||||
nonlocal parents
|
||||
parents.append(name)
|
||||
inputs = normalize_tuple(inputs)
|
||||
out = create_backwards_pop(name)(*inputs)
|
||||
return out
|
||||
|
||||
return f
|
||||
|
||||
def exit_module(name):
|
||||
|
||||
def f(module, inputs, outputs):
|
||||
nonlocal parents
|
||||
assert (parents[-1] == name)
|
||||
parents.pop()
|
||||
outputs = normalize_tuple(outputs)
|
||||
return create_backwards_push(name)(*outputs)
|
||||
|
||||
return f
|
||||
|
||||
@contextmanager
|
||||
def instrument_module(mod):
|
||||
registered = []
|
||||
for name, module in dict(mod.named_children()).items():
|
||||
registered.append(module.register_forward_pre_hook(enter_module(name)))
|
||||
registered.append(module.register_forward_hook(exit_module(name)))
|
||||
yield
|
||||
for handle in registered:
|
||||
handle.remove()
|
||||
|
||||
def display_flops():
|
||||
for mod in flop_counts.keys():
|
||||
print(f"Module: ", mod)
|
||||
for k, v in flop_counts[mod].items():
|
||||
print('\t', k, _format_flops(v))
|
||||
print()
|
||||
|
||||
def detach_variables(r):
|
||||
if isinstance(r, torch.Tensor):
|
||||
requires_grad = r.requires_grad
|
||||
r = r.detach()
|
||||
r.requires_grad = requires_grad
|
||||
return r
|
||||
|
||||
def wrap(r):
|
||||
if isinstance(r, torch.Tensor):
|
||||
data_ptr_fn = getattr(r, '_tensor', r).data_ptr
|
||||
r = FlopTensor(detach_variables(r))
|
||||
if maybe_inplace:
|
||||
r = r + 0
|
||||
r._tensor.data_ptr = data_ptr_fn
|
||||
return r
|
||||
|
||||
with instrument_module(module):
|
||||
cur_phase = Phase.FWD
|
||||
rst = module(*tree_map(wrap, args), **tree_map(wrap, kwargs))
|
||||
rst = tuple(r for r in normalize_tuple(rst) if is_autogradable(r) and r.requires_grad)
|
||||
cur_phase = Phase.BWD
|
||||
|
||||
if rst:
|
||||
grad = [torch.zeros_like(t) for t in rst]
|
||||
torch.autograd.backward(
|
||||
rst,
|
||||
grad,
|
||||
)
|
||||
|
||||
if verbose:
|
||||
display_flops()
|
||||
|
||||
return total_flop_count[Phase.FWD], total_flop_count[Phase.BWD]
|
||||
|
||||
|
||||
def matmul_flop_jit(inputs: List[Any], outputs: List[Any]) -> Number:
|
||||
"""
|
||||
Count flops for matmul.
|
||||
"""
|
||||
# Inputs should be a list of length 2.
|
||||
# Inputs contains the shapes of two matrices.
|
||||
input_shapes = [v.shape for v in inputs]
|
||||
assert len(input_shapes) == 2, input_shapes
|
||||
assert input_shapes[0][-1] == input_shapes[1][-2], input_shapes
|
||||
flops = reduce(operator.mul, input_shapes[0]) * input_shapes[-1][-1]
|
||||
return flops
|
||||
|
||||
|
||||
def addmm_flop_jit(inputs: List[Any], outputs: List[Any]) -> Number:
|
||||
"""
|
||||
Count flops for fully connected layers.
|
||||
"""
|
||||
# Count flop for nn.Linear
|
||||
# inputs is a list of length 3.
|
||||
input_shapes = [v.shape for v in inputs[1:3]]
|
||||
# input_shapes[0]: [batch size, input feature dimension]
|
||||
# input_shapes[1]: [input feature dimension, output feature dimension]
|
||||
assert len(input_shapes[0]) == 2, input_shapes[0]
|
||||
assert len(input_shapes[1]) == 2, input_shapes[1]
|
||||
batch_size, input_dim = input_shapes[0]
|
||||
output_dim = input_shapes[1][1]
|
||||
flops = batch_size * input_dim * output_dim
|
||||
return flops
|
||||
|
||||
|
||||
def linear_flop_jit(inputs: List[Any], outputs: List[Any]) -> Number:
|
||||
"""
|
||||
Count flops for the aten::linear operator.
|
||||
"""
|
||||
# Inputs is a list of length 3; unlike aten::addmm, it is the first
|
||||
# two elements that are relevant.
|
||||
input_shapes = [v.shape for v in inputs[0:2]]
|
||||
# input_shapes[0]: [dim0, dim1, ..., input_feature_dim]
|
||||
# input_shapes[1]: [output_feature_dim, input_feature_dim]
|
||||
assert input_shapes[0][-1] == input_shapes[1][-1]
|
||||
flops = reduce(operator.mul, input_shapes[0]) * input_shapes[1][0]
|
||||
return flops
|
||||
|
||||
|
||||
def bmm_flop_jit(inputs: List[Any], outputs: List[Any]) -> Number:
|
||||
"""
|
||||
Count flops for the bmm operation.
|
||||
"""
|
||||
# Inputs should be a list of length 2.
|
||||
# Inputs contains the shapes of two tensor.
|
||||
assert len(inputs) == 2, len(inputs)
|
||||
input_shapes = [v.shape for v in inputs]
|
||||
n, c, t = input_shapes[0]
|
||||
d = input_shapes[-1][-1]
|
||||
flops = n * c * t * d
|
||||
return flops
|
||||
|
||||
|
||||
def conv_flop_count(
|
||||
x_shape: List[int],
|
||||
w_shape: List[int],
|
||||
out_shape: List[int],
|
||||
transposed: bool = False,
|
||||
) -> Number:
|
||||
"""
|
||||
Count flops for convolution. Note only multiplication is
|
||||
counted. Computation for addition and bias is ignored.
|
||||
Flops for a transposed convolution are calculated as
|
||||
flops = (x_shape[2:] * prod(w_shape) * batch_size).
|
||||
Args:
|
||||
x_shape (list(int)): The input shape before convolution.
|
||||
w_shape (list(int)): The filter shape.
|
||||
out_shape (list(int)): The output shape after convolution.
|
||||
transposed (bool): is the convolution transposed
|
||||
Returns:
|
||||
int: the number of flops
|
||||
"""
|
||||
batch_size = x_shape[0]
|
||||
conv_shape = (x_shape if transposed else out_shape)[2:]
|
||||
flops = batch_size * reduce(operator.mul, w_shape) * reduce(operator.mul, conv_shape)
|
||||
return flops
|
||||
|
||||
|
||||
def conv_flop_jit(inputs: List[Any], outputs: List[Any]):
|
||||
"""
|
||||
Count flops for convolution.
|
||||
"""
|
||||
x, w = inputs[:2]
|
||||
x_shape, w_shape, out_shape = (x.shape, w.shape, outputs[0].shape)
|
||||
transposed = inputs[6]
|
||||
|
||||
return conv_flop_count(x_shape, w_shape, out_shape, transposed=transposed)
|
||||
|
||||
|
||||
def transpose_shape(shape):
|
||||
return [shape[1], shape[0]] + list(shape[2:])
|
||||
|
||||
|
||||
def conv_backward_flop_jit(inputs: List[Any], outputs: List[Any]):
|
||||
grad_out_shape, x_shape, w_shape = [i.shape for i in inputs[:3]]
|
||||
output_mask = inputs[-1]
|
||||
fwd_transposed = inputs[7]
|
||||
flop_count = 0
|
||||
|
||||
if output_mask[0]:
|
||||
grad_input_shape = outputs[0].shape
|
||||
flop_count += conv_flop_count(grad_out_shape, w_shape, grad_input_shape, not fwd_transposed)
|
||||
if output_mask[1]:
|
||||
grad_weight_shape = outputs[1].shape
|
||||
flop_count += conv_flop_count(transpose_shape(x_shape), grad_out_shape, grad_weight_shape, fwd_transposed)
|
||||
|
||||
return flop_count
|
||||
|
||||
|
||||
def norm_flop_counter(affine_arg_index: int, input_arg_index: int) -> Callable:
|
||||
"""
|
||||
Args:
|
||||
affine_arg_index: index of the affine argument in inputs
|
||||
"""
|
||||
|
||||
def norm_flop_jit(inputs: List[Any], outputs: List[Any]) -> Number:
|
||||
"""
|
||||
Count flops for norm layers.
|
||||
"""
|
||||
# Inputs[0] contains the shape of the input.
|
||||
input_shape = inputs[input_arg_index].shape
|
||||
|
||||
has_affine = inputs[affine_arg_index].shape is not None if hasattr(inputs[affine_arg_index],
|
||||
'shape') else inputs[affine_arg_index]
|
||||
assert 2 <= len(input_shape) <= 5, input_shape
|
||||
# 5 is just a rough estimate
|
||||
flop = reduce(operator.mul, input_shape) * (5 if has_affine else 4)
|
||||
return flop
|
||||
|
||||
return norm_flop_jit
|
||||
|
||||
|
||||
def batchnorm_flop_jit(inputs: List[Any], outputs: List[Any], training: bool = None) -> Number:
|
||||
if training is None:
|
||||
training = inputs[-3]
|
||||
assert isinstance(training, bool), "Signature of aten::batch_norm has changed!"
|
||||
if training:
|
||||
return norm_flop_counter(1, 0)(inputs, outputs) # pyre-ignore
|
||||
has_affine = inputs[1].shape is not None
|
||||
input_shape = reduce(operator.mul, inputs[0].shape)
|
||||
return input_shape * (2 if has_affine else 1)
|
||||
|
||||
|
||||
def ewise_flop_counter(input_scale: float = 1, output_scale: float = 0) -> Callable:
|
||||
"""
|
||||
Count flops by
|
||||
input_tensor.numel() * input_scale + output_tensor.numel() * output_scale
|
||||
Args:
|
||||
input_scale: scale of the input tensor (first argument)
|
||||
output_scale: scale of the output tensor (first element in outputs)
|
||||
"""
|
||||
|
||||
def ewise_flop(inputs: List[Any], outputs: List[Any]) -> Number:
|
||||
ret = 0
|
||||
if input_scale != 0:
|
||||
shape = inputs[0].shape
|
||||
ret += input_scale * reduce(operator.mul, shape) if shape else 0
|
||||
if output_scale != 0:
|
||||
shape = outputs[0].shape
|
||||
ret += output_scale * reduce(operator.mul, shape) if shape else 0
|
||||
return ret
|
||||
|
||||
return ewise_flop
|
||||
|
||||
|
||||
def zero_flop_jit(*args):
|
||||
"""
|
||||
Count flops for zero flop layers.
|
||||
"""
|
||||
return 0
|
||||
|
||||
|
||||
flop_mapping = {
|
||||
# gemm
|
||||
aten.mm.default: matmul_flop_jit,
|
||||
aten.matmul.default: matmul_flop_jit,
|
||||
aten.addmm.default: addmm_flop_jit,
|
||||
aten.bmm.default: bmm_flop_jit,
|
||||
|
||||
# convolution
|
||||
aten.convolution.default: conv_flop_jit,
|
||||
aten._convolution.default: conv_flop_jit,
|
||||
aten.convolution_backward.default: conv_backward_flop_jit,
|
||||
|
||||
# normalization
|
||||
aten.native_batch_norm.default: batchnorm_flop_jit,
|
||||
aten.native_batch_norm_backward.default: batchnorm_flop_jit,
|
||||
aten.cudnn_batch_norm.default: batchnorm_flop_jit,
|
||||
aten.cudnn_batch_norm_backward.default: partial(batchnorm_flop_jit, training=True),
|
||||
aten.native_layer_norm.default: norm_flop_counter(2, 0),
|
||||
aten.native_layer_norm_backward.default: norm_flop_counter(2, 0),
|
||||
|
||||
# pooling
|
||||
aten.avg_pool1d.default: ewise_flop_counter(1, 0),
|
||||
aten.avg_pool2d.default: ewise_flop_counter(1, 0),
|
||||
aten.avg_pool2d_backward.default: ewise_flop_counter(0, 1),
|
||||
aten.avg_pool3d.default: ewise_flop_counter(1, 0),
|
||||
aten.avg_pool3d_backward.default: ewise_flop_counter(0, 1),
|
||||
aten.max_pool1d.default: ewise_flop_counter(1, 0),
|
||||
aten.max_pool2d.default: ewise_flop_counter(1, 0),
|
||||
aten.max_pool3d.default: ewise_flop_counter(1, 0),
|
||||
aten.max_pool1d_with_indices.default: ewise_flop_counter(1, 0),
|
||||
aten.max_pool2d_with_indices.default: ewise_flop_counter(1, 0),
|
||||
aten.max_pool2d_with_indices_backward.default: ewise_flop_counter(0, 1),
|
||||
aten.max_pool3d_with_indices.default: ewise_flop_counter(1, 0),
|
||||
aten.max_pool3d_with_indices_backward.default: ewise_flop_counter(0, 1),
|
||||
aten._adaptive_avg_pool2d.default: ewise_flop_counter(1, 0),
|
||||
aten._adaptive_avg_pool2d_backward.default: ewise_flop_counter(0, 1),
|
||||
aten._adaptive_avg_pool3d.default: ewise_flop_counter(1, 0),
|
||||
aten._adaptive_avg_pool3d_backward.default: ewise_flop_counter(0, 1),
|
||||
aten.embedding_dense_backward.default: ewise_flop_counter(0, 1),
|
||||
aten.embedding.default: ewise_flop_counter(1, 0),
|
||||
}
|
||||
|
||||
ewise_flop_aten = [
|
||||
# basic op
|
||||
aten.add.Tensor,
|
||||
aten.add_.Tensor,
|
||||
aten.div.Tensor,
|
||||
aten.div_.Tensor,
|
||||
aten.div.Scalar,
|
||||
aten.div_.Scalar,
|
||||
aten.mul.Tensor,
|
||||
aten.mul.Scalar,
|
||||
aten.mul_.Tensor,
|
||||
aten.neg.default,
|
||||
aten.pow.Tensor_Scalar,
|
||||
aten.rsub.Scalar,
|
||||
aten.sum.default,
|
||||
aten.sum.dim_IntList,
|
||||
aten.mean.dim,
|
||||
|
||||
# activation op
|
||||
aten.hardswish.default,
|
||||
aten.hardswish_.default,
|
||||
aten.hardswish_backward.default,
|
||||
aten.hardtanh.default,
|
||||
aten.hardtanh_.default,
|
||||
aten.hardtanh_backward.default,
|
||||
aten.hardsigmoid_backward.default,
|
||||
aten.hardsigmoid.default,
|
||||
aten.gelu.default,
|
||||
aten.gelu_backward.default,
|
||||
aten.silu.default,
|
||||
aten.silu_.default,
|
||||
aten.silu_backward.default,
|
||||
aten.sigmoid.default,
|
||||
aten.sigmoid_backward.default,
|
||||
aten._softmax.default,
|
||||
aten._softmax_backward_data.default,
|
||||
aten.relu_.default,
|
||||
aten.relu.default,
|
||||
aten.tanh.default,
|
||||
aten.tanh_backward.default,
|
||||
aten.threshold_backward.default,
|
||||
|
||||
# dropout
|
||||
aten.native_dropout.default,
|
||||
aten.native_dropout_backward.default,
|
||||
|
||||
# distribution
|
||||
aten.bernoulli_.float,
|
||||
|
||||
# where
|
||||
aten.where.self,
|
||||
]
|
||||
for op in ewise_flop_aten:
|
||||
flop_mapping[op] = ewise_flop_counter(1, 0)
|
||||
|
||||
# fix-me: this will be removed in future
|
||||
zero_flop_aten = [
|
||||
aten.as_strided.default,
|
||||
aten.as_strided_.default,
|
||||
aten.cat.default,
|
||||
aten.clone.default,
|
||||
aten.copy_.default,
|
||||
aten.detach.default,
|
||||
aten.expand.default,
|
||||
aten.empty_like.default,
|
||||
aten.new_empty.default,
|
||||
aten.new_empty_strided.default,
|
||||
aten.ones_like.default,
|
||||
aten._reshape_alias.default,
|
||||
aten.select.int,
|
||||
aten.select_backward.default,
|
||||
aten.squeeze.dim,
|
||||
aten.slice.Tensor,
|
||||
aten.slice_backward.default,
|
||||
aten.split.Tensor,
|
||||
aten.permute.default,
|
||||
aten.t.default,
|
||||
aten.transpose.int,
|
||||
aten._to_copy.default,
|
||||
aten.unsqueeze.default,
|
||||
aten.unbind.int,
|
||||
aten._unsafe_view.default,
|
||||
aten.view.default,
|
||||
aten.zero_.default,
|
||||
aten.zeros_like.default,
|
||||
]
|
||||
|
||||
for op in zero_flop_aten:
|
||||
flop_mapping[op] = zero_flop_jit
|
|
@ -0,0 +1,207 @@
|
|||
import uuid
|
||||
from functools import partial
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from torch.types import _bool, _device, _dtype
|
||||
from torch.utils._pytree import tree_flatten, tree_map
|
||||
|
||||
from ._monkey_patch import _AliasATen, _DistCommMethod, _InplaceATen, _MaybeInplaceATen, _TorchOverrideableFactoryMethod
|
||||
|
||||
__all__ = ['MetaTensor', 'MetaTensorMode']
|
||||
|
||||
|
||||
def register_storage(r, data_ptr_fn=None):
|
||||
if isinstance(r, torch.Tensor):
|
||||
if data_ptr_fn is not None:
|
||||
r.data_ptr = data_ptr_fn
|
||||
elif not r.data_ptr():
|
||||
data_ptr = uuid.uuid1()
|
||||
r.data_ptr = lambda: data_ptr
|
||||
|
||||
|
||||
def _normalize_tuple(x):
|
||||
if not isinstance(x, tuple):
|
||||
return (x,)
|
||||
return x
|
||||
|
||||
|
||||
# a hack of inplace execution in PyTorch
|
||||
def _assert_alias(func):
|
||||
return func in (_AliasATen + _InplaceATen + _MaybeInplaceATen # TODO: check if should be this aggressive
|
||||
)
|
||||
|
||||
|
||||
class MetaTensor(torch.Tensor):
|
||||
"""
|
||||
A wrapping tensor that hacks ``torch.autograd`` without patching more ``torch.ops.aten`` ops.
|
||||
`device` is the device that ``MetaTensor`` is supposed to run on. Meta tensors give you the
|
||||
ability to run PyTorch code without having to actually do computation through tensors
|
||||
allocated on a `meta` device. Because the device is `meta`, meta tensors do not model
|
||||
device propagation. ``MetaTensor`` extends its usage by carrying an additional `device`
|
||||
which tracks devices that would have been used.
|
||||
|
||||
Reference:
|
||||
https://github.com/pytorch/pytorch/blob/master/torch/_subclasses/fake_tensor.py
|
||||
"""
|
||||
|
||||
_tensor: torch.Tensor
|
||||
|
||||
@staticmethod
|
||||
def __new__(cls, elem, device=None, data_ptr_fn=None):
|
||||
requires_grad = elem.requires_grad
|
||||
# Avoid multiple wrapping
|
||||
while isinstance(elem, MetaTensor):
|
||||
device = elem.device if device is None else device
|
||||
elem = elem._tensor
|
||||
|
||||
# The wrapping tensor (MetaTensor) shouldn't hold any
|
||||
# memory for the class in question, but it should still
|
||||
# advertise the same device as before
|
||||
r = torch.Tensor._make_wrapper_subclass(
|
||||
cls,
|
||||
elem.size(),
|
||||
strides=elem.stride(),
|
||||
storage_offset=elem.storage_offset(),
|
||||
dtype=elem.dtype,
|
||||
layout=elem.layout,
|
||||
device=device or (elem.device if elem.device.type != 'meta' else torch.device('cpu')),
|
||||
requires_grad=requires_grad) # deceive the frontend for aten selections
|
||||
r._tensor = elem
|
||||
# ...the real tensor is held as an element on the tensor.
|
||||
if not r._tensor.is_meta:
|
||||
val = elem.data_ptr()
|
||||
data_ptr_fn = lambda: val
|
||||
r._tensor = r._tensor.to(torch.device('meta'))
|
||||
|
||||
# only tensor not on `meta` should be copied to `meta`
|
||||
register_storage(r._tensor, data_ptr_fn)
|
||||
if isinstance(elem, torch.nn.Parameter):
|
||||
r = torch.nn.Parameter(r)
|
||||
return r
|
||||
|
||||
def __repr__(self):
|
||||
name = 'MetaParameter' if getattr(self, '_is_param', False) else 'MetaTensor'
|
||||
if self.grad_fn:
|
||||
return f"{name}(..., size={tuple(self.shape)}, device='{self.device}', dtype={self.dtype}, grad_fn={self.grad_fn})"
|
||||
return f"{name}(..., size={tuple(self.shape)}, device='{self.device}', dtype={self.dtype})"
|
||||
|
||||
@classmethod
|
||||
def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
|
||||
device = None
|
||||
|
||||
def unwrap(x):
|
||||
nonlocal device
|
||||
if isinstance(x, MetaTensor):
|
||||
device = x.device
|
||||
x = x._tensor
|
||||
elif isinstance(x, torch.Tensor):
|
||||
device = x.device
|
||||
x = x.to(torch.device('meta'))
|
||||
return x
|
||||
|
||||
args = tree_map(unwrap, args)
|
||||
kwargs = tree_map(unwrap, kwargs)
|
||||
|
||||
if 'device' in kwargs:
|
||||
device = kwargs['device']
|
||||
kwargs['device'] = torch.device('meta')
|
||||
|
||||
# run aten for backend=CPU but actually on backend=Meta
|
||||
# here we detect whether or not the execution generates a physical copy
|
||||
# of the input tensor
|
||||
ret = func(*args, **kwargs)
|
||||
|
||||
if _assert_alias(func):
|
||||
val = args[0].data_ptr()
|
||||
tree_map(partial(register_storage, data_ptr_fn=lambda: val), _normalize_tuple(ret))
|
||||
|
||||
# Now, we want to continue propagating this tensor, so we rewrap Tensors in
|
||||
# our custom tensor subclass
|
||||
def wrap(x):
|
||||
return MetaTensor(x, device=device) if isinstance(x, torch.Tensor) else x
|
||||
|
||||
return tree_map(wrap, ret)
|
||||
|
||||
def to(self, *args, **kwargs) -> torch.Tensor:
|
||||
"""An extension of `torch.Tensor.to()` to MetaTensor
|
||||
Returns:
|
||||
result (MetaTensor): MetaTensor
|
||||
Usage:
|
||||
>>> tensor = MetaTensor(torch.rand(10), device='cuda:100')
|
||||
>>> tensor.to(torch.uint8)
|
||||
MetaTensor(tensor(..., device='meta', size=(10,), dtype=torch.uint8), device='cuda:100')
|
||||
>>> tensor.to(torch.device('cuda:42'))
|
||||
MetaTensor(tensor(..., device='meta', size=(10,)), device='cuda:42')
|
||||
>>> tensor.to('vulkan')
|
||||
MetaTensor(tensor(..., device='meta', size=(10,)), device='vulkan')
|
||||
"""
|
||||
# this imitates c++ function in the way of @overload
|
||||
device = None
|
||||
|
||||
def replace(x):
|
||||
nonlocal device
|
||||
if isinstance(x, str) or isinstance(x, _device):
|
||||
device = x
|
||||
return torch.device('meta')
|
||||
return x
|
||||
|
||||
elem = self._tensor.to(*tree_map(replace, args), **tree_map(replace, kwargs))
|
||||
return MetaTensor(elem, device=device)
|
||||
|
||||
def cpu(self, *args, **kwargs):
|
||||
if self.device.type == 'cpu':
|
||||
return self.to(*args, **kwargs)
|
||||
return self.to(*args, device='cpu', **kwargs)
|
||||
|
||||
def cuda(self, device=None, non_blocking=False):
|
||||
if device is not None:
|
||||
return self.to(device=device, non_blocking=non_blocking)
|
||||
return self.to(device='cuda:0', non_blocking=non_blocking)
|
||||
|
||||
def data_ptr(self):
|
||||
return self._tensor.data_ptr()
|
||||
|
||||
|
||||
class MetaTensorMode(object):
|
||||
"""
|
||||
A context manager that enables MetaTensor mode.
|
||||
|
||||
Usage:
|
||||
>>> with MetaTensorMode():
|
||||
>>> # all torch.xxx and torch.distributed.xxx will be replaced by patched functions
|
||||
>>> # and the actual execution will be on torch.device('meta')
|
||||
>>> a = torch.rand(100000, 100000)
|
||||
>>> b = torch.rand(100000, 100000)
|
||||
>>> c = torch.mm(a, b)
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self.torch_overrides = {} # override torch.xxx
|
||||
self.dist_overrides = {} # override torch.distributed.xxx
|
||||
|
||||
def __enter__(self):
|
||||
|
||||
def _dummy(*args, **kwargs):
|
||||
pass
|
||||
|
||||
def _new(*args, orig_new=torch.empty, **kwargs):
|
||||
return MetaTensor(orig_new(*args, **{
|
||||
**kwargs, 'device': 'meta'
|
||||
}),
|
||||
device=kwargs.get('device', torch.device('cpu')))
|
||||
|
||||
for func in _TorchOverrideableFactoryMethod:
|
||||
self.torch_overrides[func] = getattr(torch, func)
|
||||
setattr(torch, func, partial(_new, orig_new=getattr(torch, func)))
|
||||
|
||||
for func in _DistCommMethod:
|
||||
self.dist_overrides[func] = getattr(dist, func)
|
||||
setattr(dist, func, _dummy)
|
||||
|
||||
def __exit__(self, exc_type, exc_value, traceback):
|
||||
for func, func_impl in self.torch_overrides.items():
|
||||
setattr(torch, func, func_impl)
|
||||
|
||||
for func, func_impl in self.dist_overrides.items():
|
||||
setattr(dist, func, func_impl)
|
|
@ -0,0 +1,7 @@
|
|||
from dataclasses import dataclass
|
||||
|
||||
|
||||
@dataclass
|
||||
class MeshConfig:
|
||||
TFLOPS: float = 1.9e12
|
||||
BANDWIDTH = 1.2e9
|
|
@ -0,0 +1,4 @@
|
|||
from .bias_addition import *
|
||||
from .node_util import MetaInfo
|
||||
from .symbolic_profile import symbolic_profile
|
||||
from .symbolic_trace import symbolic_trace
|
|
@ -0,0 +1,155 @@
|
|||
"""
|
||||
If FX.Graph is traced for auto-parallel module, some extra node will be added during
|
||||
graph construction to deal with the compatibility between bias-addition and all-reduce.
|
||||
"""
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from torch.nn.modules.utils import _pair, _single, _triple
|
||||
|
||||
from .symbolic_trace import register_tracer_impl
|
||||
|
||||
__all__ = []
|
||||
|
||||
|
||||
@register_tracer_impl(F.linear, name='_bias_addition_impl')
|
||||
def linear_impl(input, weight, bias=None):
|
||||
if bias is None:
|
||||
return F.linear(input, weight)
|
||||
else:
|
||||
return F.linear(input, weight) + bias
|
||||
|
||||
|
||||
@register_tracer_impl(F.conv1d, name='_bias_addition_impl')
|
||||
def conv1d_impl(input, weight, bias=None, stride=_single(1), padding=_single(0), dilation=_single(1), groups=1):
|
||||
if bias is None:
|
||||
return F.conv1d(input, weight, stride=stride, padding=padding, dilation=dilation, groups=groups)
|
||||
else:
|
||||
return F.conv1d(input, weight, stride=stride, padding=padding, dilation=dilation, groups=groups) + bias.reshape(
|
||||
(-1, 1))
|
||||
|
||||
|
||||
@register_tracer_impl(F.conv2d, name='_bias_addition_impl')
|
||||
def conv2d_impl(input, weight, bias=None, stride=_pair(1), padding=_pair(0), dilation=_pair(1), groups=1):
|
||||
if bias is None:
|
||||
return F.conv2d(input, weight, stride=stride, padding=padding, dilation=dilation, groups=groups)
|
||||
else:
|
||||
return F.conv2d(input, weight, stride=stride, padding=padding, dilation=dilation, groups=groups) + bias.reshape(
|
||||
(-1, 1, 1))
|
||||
|
||||
|
||||
@register_tracer_impl(F.conv3d, name='_bias_addition_impl')
|
||||
def conv3d_impl(input, weight, bias=None, stride=_triple(1), padding=_triple(0), dilation=_triple(1), groups=1):
|
||||
if bias is None:
|
||||
return F.conv3d(input, weight, stride=stride, padding=padding, dilation=dilation, groups=groups)
|
||||
else:
|
||||
return F.conv3d(input, weight, stride=stride, padding=padding, dilation=dilation, groups=groups) + bias.reshape(
|
||||
(-1, 1, 1, 1))
|
||||
|
||||
|
||||
@register_tracer_impl(F.conv_transpose1d, name='_bias_addition_impl')
|
||||
def conv_transpose1d_impl(input,
|
||||
weight,
|
||||
bias=None,
|
||||
stride=_single(1),
|
||||
padding=_single(0),
|
||||
output_padding=_single(0),
|
||||
groups=1,
|
||||
dilation=_single(1)):
|
||||
if bias is None:
|
||||
return F.conv_transpose1d(input,
|
||||
weight,
|
||||
stride=stride,
|
||||
padding=padding,
|
||||
output_padding=output_padding,
|
||||
groups=groups,
|
||||
dilation=dilation)
|
||||
else:
|
||||
return F.conv_transpose1d(input,
|
||||
weight,
|
||||
stride=stride,
|
||||
padding=padding,
|
||||
output_padding=output_padding,
|
||||
groups=groups,
|
||||
dilation=dilation) + bias.reshape((-1, 1))
|
||||
|
||||
|
||||
@register_tracer_impl(F.conv_transpose2d, name='_bias_addition_impl')
|
||||
def conv_transpose2d_impl(input,
|
||||
weight,
|
||||
bias=None,
|
||||
stride=_pair(1),
|
||||
padding=_pair(0),
|
||||
output_padding=_pair(0),
|
||||
groups=1,
|
||||
dilation=_pair(1)):
|
||||
if bias is None:
|
||||
return F.conv_transpose2d(input,
|
||||
weight,
|
||||
stride=stride,
|
||||
padding=padding,
|
||||
output_padding=output_padding,
|
||||
groups=groups,
|
||||
dilation=dilation)
|
||||
else:
|
||||
return F.conv_transpose2d(input,
|
||||
weight,
|
||||
stride=stride,
|
||||
padding=padding,
|
||||
output_padding=output_padding,
|
||||
groups=groups,
|
||||
dilation=dilation) + bias.reshape((-1, 1, 1))
|
||||
|
||||
|
||||
@register_tracer_impl(F.conv_transpose3d, name='_bias_addition_impl')
|
||||
def conv_transpose3d_impl(input,
|
||||
weight,
|
||||
bias=None,
|
||||
stride=_triple(1),
|
||||
padding=_triple(0),
|
||||
output_padding=_triple(0),
|
||||
groups=1,
|
||||
dilation=_triple(1)):
|
||||
if bias is None:
|
||||
return F.conv_transpose3d(input,
|
||||
weight,
|
||||
stride=stride,
|
||||
padding=padding,
|
||||
output_padding=output_padding,
|
||||
groups=groups,
|
||||
dilation=dilation)
|
||||
else:
|
||||
return F.conv_transpose3d(input,
|
||||
weight,
|
||||
stride=stride,
|
||||
padding=padding,
|
||||
output_padding=output_padding,
|
||||
groups=groups,
|
||||
dilation=dilation) + bias.reshape((-1, 1, 1, 1))
|
||||
|
||||
|
||||
@register_tracer_impl(torch.addmm, name='_bias_addition_impl')
|
||||
@register_tracer_impl(torch.Tensor.addmm, name='_bias_addition_impl')
|
||||
def addmm_impl(input, mat1, mat2, beta=1, alpha=1):
|
||||
if alpha != 1 and beta != 1:
|
||||
return F.linear(mat1, mat2.transpose(0, 1)) * alpha + input * beta
|
||||
elif alpha != 1:
|
||||
return F.linear(mat1, mat2.transpose(0, 1)) * alpha + input
|
||||
elif beta != 1:
|
||||
return F.linear(mat1, mat2.transpose(0, 1)) + input * beta
|
||||
else:
|
||||
return F.linear(mat1, mat2.transpose(0, 1)) + input
|
||||
|
||||
|
||||
@register_tracer_impl(torch.addbmm, name='_bias_addition_impl')
|
||||
@register_tracer_impl(torch.Tensor.addbmm, name='_bias_addition_impl')
|
||||
def addbmm_impl(input, batch1, batch2, beta=1, alpha=1):
|
||||
if alpha != 1 and beta != 1:
|
||||
return torch.bmm(batch1, batch2.transpose(1, 2)) * alpha + input * beta
|
||||
elif alpha != 1:
|
||||
return torch.bmm(batch1, batch2.transpose(1, 2)) * alpha + input
|
||||
elif beta != 1:
|
||||
return torch.bmm(batch1, batch2.transpose(1, 2)) + input * beta
|
||||
else:
|
||||
return torch.bmm(batch1, batch2.transpose(1, 2)) + input
|
|
@ -0,0 +1,456 @@
|
|||
from typing import Any, Callable, Dict, Iterable, List, Tuple
|
||||
|
||||
import torch
|
||||
from torch.fx.graph import (
|
||||
CodeGen,
|
||||
PythonCode,
|
||||
_custom_builtins,
|
||||
_format_target,
|
||||
_is_from_torch,
|
||||
_Namespace,
|
||||
_origin_type_map,
|
||||
_register_custom_builtin,
|
||||
inplace_methods,
|
||||
magic_methods,
|
||||
)
|
||||
from torch.fx.node import Argument, Node, _get_qualified_name, _type_repr, map_arg
|
||||
|
||||
import colossalai
|
||||
from colossalai.fx._compatibility import compatibility
|
||||
|
||||
_register_custom_builtin('colossalai', 'import colossalai', colossalai)
|
||||
|
||||
|
||||
def _gen_ckpt_fn_def(label, free_vars: List[str]) -> str:
|
||||
"""
|
||||
Generate the checkpoint function definition
|
||||
"""
|
||||
return f"def checkpoint_{label}({', '.join(['self'] + free_vars)}):"
|
||||
|
||||
|
||||
def _gen_ckpt_output(output_vars: List[str]) -> str:
|
||||
"""
|
||||
Generate the return statement for checkpoint region
|
||||
"""
|
||||
return f"return {', '.join(output_vars)}"
|
||||
|
||||
|
||||
def _gen_ckpt_usage(label, input_vars, output_vars, use_reentrant=True):
|
||||
"""
|
||||
Generate the checkpoint function call code text
|
||||
"""
|
||||
outputs = ', '.join(output_vars)
|
||||
inputs = ', '.join(input_vars)
|
||||
return f'{outputs} = torch.utils.checkpoint.checkpoint(self.checkpoint_{label}, {inputs}, use_reentrant={use_reentrant})'
|
||||
|
||||
|
||||
def _end_of_ckpt(node: Node, ckpt_level: int) -> bool:
|
||||
"""
|
||||
Check if the node could end the ckpt region at `ckpt_level`
|
||||
"""
|
||||
if len(node.meta['info'].to_recompute) > ckpt_level:
|
||||
return node.meta['info'].to_recompute[ckpt_level] is not None
|
||||
return True
|
||||
|
||||
|
||||
def _find_input_and_output_nodes(nodes: List[Node]):
|
||||
"""
|
||||
Find the input and output node names which are not found in the given list of nodes.
|
||||
"""
|
||||
input_nodes = []
|
||||
output_nodes = []
|
||||
|
||||
# if a node has an input node which is not in the node list
|
||||
# we treat that input node as the input of the checkpoint function
|
||||
for node in nodes:
|
||||
for input_node in node._input_nodes.keys():
|
||||
node_repr = repr(input_node)
|
||||
if input_node not in nodes and node_repr not in input_nodes:
|
||||
input_nodes.append(node_repr)
|
||||
|
||||
# if a node has a user node which is not in the node list
|
||||
# we treat that user node as the node receiving the current node output
|
||||
for node in nodes:
|
||||
for output_node in node.users.keys():
|
||||
node_repr = repr(node)
|
||||
if output_node not in nodes and node_repr not in output_nodes:
|
||||
output_nodes.append(node_repr)
|
||||
|
||||
return input_nodes, output_nodes
|
||||
|
||||
|
||||
def _find_nested_ckpt_regions(node_list: List[Node], ckpt_level: int = 0):
|
||||
"""
|
||||
Find the nested checkpoint regions given a list of consecutive nodes. The outputs
|
||||
will be list of tuples, each tuple is in the form of (start_index, end_index).
|
||||
"""
|
||||
ckpt_regions = []
|
||||
start = -1
|
||||
end = -1
|
||||
current_region = None
|
||||
|
||||
for idx, node in enumerate(node_list):
|
||||
if len(node.meta['info'].to_recompute) > ckpt_level:
|
||||
act_ckpt_label = node.meta['info'].to_recompute[ckpt_level]
|
||||
|
||||
# this activation checkpoint label is not set yet
|
||||
# meaning this is the first node of the activation ckpt region
|
||||
if current_region is None:
|
||||
current_region = act_ckpt_label
|
||||
start = idx
|
||||
|
||||
# if activation checkpoint has changed
|
||||
# we restart the tracking
|
||||
# e.g. node ckpt states = [ckpt1, ckpt2, ckpt2, ckpt2]
|
||||
if act_ckpt_label != current_region:
|
||||
assert start != -1
|
||||
ckpt_regions.append((start, idx - 1))
|
||||
current_region = act_ckpt_label
|
||||
start = idx
|
||||
end = -1
|
||||
|
||||
elif current_region is not None and _end_of_ckpt(node, ckpt_level):
|
||||
# used to check the case below
|
||||
# node ckpt states = [ckpt, ckpt, non-ckpt]
|
||||
end = idx - 1
|
||||
assert start != -1 and end != -1
|
||||
ckpt_regions.append((start, end))
|
||||
start = end = -1
|
||||
current_region = None
|
||||
|
||||
else:
|
||||
pass
|
||||
|
||||
if current_region is not None:
|
||||
end = len(node_list) - 1
|
||||
ckpt_regions.append((start, end))
|
||||
return ckpt_regions
|
||||
|
||||
|
||||
def emit_ckpt_func(body,
|
||||
ckpt_func,
|
||||
node_list: List[Node],
|
||||
emit_node_func,
|
||||
delete_unused_value_func,
|
||||
ckpt_level=0,
|
||||
in_ckpt=False):
|
||||
"""Emit ckpt fuction in nested way
|
||||
|
||||
Args:
|
||||
body: forward code - in recursive calls, this part will be checkpoint
|
||||
functions code
|
||||
ckpt_func: checkpoint functions code - in recursive calls, this part
|
||||
will be a buffer
|
||||
node_list (List[Node]): list of torch.fx.Node
|
||||
emit_node_func: function to emit a node
|
||||
delete_unused_value_func: function to delete unused value
|
||||
level (int, optional): checkpoint level. Defaults to 0.
|
||||
in_ckpt (bool, optional): indicates wether the func is in recursive
|
||||
call. Defaults to False.
|
||||
"""
|
||||
inputs, outputs = _find_input_and_output_nodes(node_list)
|
||||
|
||||
# label given by each layer, e.g. if you are currently at level (0, 1, 1)
|
||||
# the label will be '0_1_1'
|
||||
label = "_".join([str(idx) for idx in node_list[0].meta['info'].to_recompute[:ckpt_level + 1]])
|
||||
ckpt_fn_def = _gen_ckpt_fn_def(label, inputs)
|
||||
ckpt_func.append(f'{ckpt_fn_def}\n')
|
||||
|
||||
# if there is more level to fetch
|
||||
if ckpt_level + 1 < max(map(lambda node: len(node.meta['info'].to_recompute), node_list)):
|
||||
ckpt_regions = _find_nested_ckpt_regions(node_list, ckpt_level + 1)
|
||||
start_idx = [item[0] for item in ckpt_regions]
|
||||
end_idx = [item[1] for item in ckpt_regions]
|
||||
|
||||
# use ckpt_func_buffer to store nested checkpoint functions
|
||||
ckpt_func_buffer = []
|
||||
node_idx = 0
|
||||
while 1:
|
||||
if node_idx >= len(node_list):
|
||||
break
|
||||
|
||||
if node_idx in start_idx:
|
||||
ckpt_node_list = node_list[node_idx:end_idx[start_idx.index(node_idx)] + 1]
|
||||
emit_ckpt_func(ckpt_func, ckpt_func_buffer, ckpt_node_list, emit_node_func, delete_unused_value_func,
|
||||
ckpt_level + 1, True)
|
||||
node_idx += len(ckpt_node_list)
|
||||
|
||||
else:
|
||||
node = node_list[node_idx]
|
||||
emit_node_func(node, ckpt_func)
|
||||
ckpt_func[-1] = ' ' + ckpt_func[-1]
|
||||
delete_unused_value_func(node, ckpt_func)
|
||||
node_idx += 1
|
||||
|
||||
ckpt_func.append(' ' + _gen_ckpt_output(outputs) + '\n\n')
|
||||
ckpt_func += ckpt_func_buffer
|
||||
|
||||
# last level
|
||||
else:
|
||||
for node in node_list:
|
||||
emit_node_func(node, ckpt_func)
|
||||
ckpt_func[-1] = ' ' + ckpt_func[-1]
|
||||
delete_unused_value_func(node, ckpt_func)
|
||||
|
||||
ckpt_func.append(' ' + _gen_ckpt_output(outputs) + '\n\n')
|
||||
|
||||
usage = _gen_ckpt_usage(label, inputs, outputs, False) + '\n'
|
||||
if in_ckpt:
|
||||
usage = ' ' + usage
|
||||
body.append(usage)
|
||||
|
||||
|
||||
def emit_code_with_activation_checkpoint(body, ckpt_func, nodes, emit_node_func, delete_unused_value_func):
|
||||
"""Emit code with nested activation checkpoint
|
||||
When we detect some of the annotation is a , we will use
|
||||
this function to emit the activation checkpoint codes.
|
||||
|
||||
Args:
|
||||
body: forward code
|
||||
ckpt_func: checkpoint functions code
|
||||
nodes: graph.nodes
|
||||
emit_node_func: function to emit node
|
||||
delete_unused_value_func: function to remove the unused value
|
||||
"""
|
||||
ckpt_regions = _find_nested_ckpt_regions(nodes, 0)
|
||||
start_idx = [item[0] for item in ckpt_regions]
|
||||
end_idx = [item[1] for item in ckpt_regions]
|
||||
|
||||
node_list = list(nodes)
|
||||
|
||||
node_idx = 0
|
||||
while 1:
|
||||
# break if we finish the processing all the nodes
|
||||
if node_idx >= len(node_list):
|
||||
break
|
||||
|
||||
# process ckpt_regions
|
||||
if node_idx in start_idx:
|
||||
ckpt_node_list = node_list[node_idx:end_idx[start_idx.index(node_idx)] + 1]
|
||||
emit_ckpt_func(body, ckpt_func, ckpt_node_list, emit_node_func, delete_unused_value_func)
|
||||
node_idx += len(ckpt_node_list)
|
||||
|
||||
# process node in forward function
|
||||
else:
|
||||
node = node_list[node_idx]
|
||||
emit_node_func(node, body)
|
||||
delete_unused_value_func(node, body)
|
||||
node_idx += 1
|
||||
|
||||
|
||||
@compatibility(is_backward_compatible=True)
|
||||
class ActivationCheckpointCodeGen(CodeGen):
|
||||
|
||||
def _gen_python_code(self, nodes, root_module: str, namespace: _Namespace) -> PythonCode:
|
||||
free_vars: List[str] = []
|
||||
body: List[str] = []
|
||||
globals_: Dict[str, Any] = {}
|
||||
wrapped_fns: Dict[str, None] = {}
|
||||
|
||||
# Wrap string in list to pass by reference
|
||||
maybe_return_annotation: List[str] = ['']
|
||||
|
||||
def add_global(name_hint: str, obj: Any):
|
||||
"""Add an obj to be tracked as a global.
|
||||
We call this for names that reference objects external to the
|
||||
Graph, like functions or types.
|
||||
Returns: the global name that should be used to reference 'obj' in generated source.
|
||||
"""
|
||||
if _is_from_torch(obj) and obj != torch.device: # to support registering torch.device
|
||||
# HACK: workaround for how torch custom ops are registered. We
|
||||
# can't import them like normal modules so they must retain their
|
||||
# fully qualified name.
|
||||
return _get_qualified_name(obj)
|
||||
|
||||
# normalize the name hint to get a proper identifier
|
||||
global_name = namespace.create_name(name_hint, obj)
|
||||
|
||||
if global_name in globals_:
|
||||
assert globals_[global_name] is obj
|
||||
return global_name
|
||||
globals_[global_name] = obj
|
||||
return global_name
|
||||
|
||||
# Pre-fill the globals table with registered builtins.
|
||||
for name, (_, obj) in _custom_builtins.items():
|
||||
add_global(name, obj)
|
||||
|
||||
def type_repr(o: Any):
|
||||
if o == ():
|
||||
# Empty tuple is used for empty tuple type annotation Tuple[()]
|
||||
return '()'
|
||||
|
||||
typename = _type_repr(o)
|
||||
|
||||
if hasattr(o, '__origin__'):
|
||||
# This is a generic type, e.g. typing.List[torch.Tensor]
|
||||
origin_type = _origin_type_map.get(o.__origin__, o.__origin__)
|
||||
origin_typename = add_global(_type_repr(origin_type), origin_type)
|
||||
|
||||
if hasattr(o, '__args__'):
|
||||
# Assign global names for each of the inner type variables.
|
||||
args = [type_repr(arg) for arg in o.__args__]
|
||||
|
||||
if len(args) == 0:
|
||||
# Bare type, such as `typing.Tuple` with no subscript
|
||||
# This code-path used in Python < 3.9
|
||||
return origin_typename
|
||||
|
||||
return f'{origin_typename}[{",".join(args)}]'
|
||||
else:
|
||||
# Bare type, such as `typing.Tuple` with no subscript
|
||||
# This code-path used in Python 3.9+
|
||||
return origin_typename
|
||||
|
||||
# Common case: this is a regular module name like 'foo.bar.baz'
|
||||
return add_global(typename, o)
|
||||
|
||||
def _format_args(args: Tuple[Argument, ...], kwargs: Dict[str, Argument]) -> str:
|
||||
|
||||
def _get_repr(arg):
|
||||
# Handle NamedTuples (if it has `_fields`) via add_global.
|
||||
if isinstance(arg, tuple) and hasattr(arg, '_fields'):
|
||||
qualified_name = _get_qualified_name(type(arg))
|
||||
global_name = add_global(qualified_name, type(arg))
|
||||
return f"{global_name}{repr(tuple(arg))}"
|
||||
return repr(arg)
|
||||
|
||||
args_s = ', '.join(_get_repr(a) for a in args)
|
||||
kwargs_s = ', '.join(f'{k} = {_get_repr(v)}' for k, v in kwargs.items())
|
||||
if args_s and kwargs_s:
|
||||
return f'{args_s}, {kwargs_s}'
|
||||
return args_s or kwargs_s
|
||||
|
||||
# Run through reverse nodes and record the first instance of a use
|
||||
# of a given node. This represents the *last* use of the node in the
|
||||
# execution order of the program, which we will use to free unused
|
||||
# values
|
||||
node_to_last_use: Dict[Node, Node] = {}
|
||||
user_to_last_uses: Dict[Node, List[Node]] = {}
|
||||
|
||||
def register_last_uses(n: Node, user: Node):
|
||||
if n not in node_to_last_use:
|
||||
node_to_last_use[n] = user
|
||||
user_to_last_uses.setdefault(user, []).append(n)
|
||||
|
||||
for node in reversed(nodes):
|
||||
map_arg(node.args, lambda n: register_last_uses(n, node))
|
||||
map_arg(node.kwargs, lambda n: register_last_uses(n, node))
|
||||
|
||||
# NOTE: we add a variable to distinguish body and ckpt_func
|
||||
def delete_unused_values(user: Node, body):
|
||||
"""
|
||||
Delete values after their last use. This ensures that values that are
|
||||
not used in the remainder of the code are freed and the memory usage
|
||||
of the code is optimal.
|
||||
"""
|
||||
if user.op == 'placeholder':
|
||||
return
|
||||
if user.op == 'output':
|
||||
body.append('\n')
|
||||
return
|
||||
nodes_to_delete = user_to_last_uses.get(user, [])
|
||||
if len(nodes_to_delete):
|
||||
to_delete_str = ' = '.join([repr(n) for n in nodes_to_delete] + ['None'])
|
||||
body.append(f'; {to_delete_str}\n')
|
||||
else:
|
||||
body.append('\n')
|
||||
|
||||
# NOTE: we add a variable to distinguish body and ckpt_func
|
||||
def emit_node(node: Node, body):
|
||||
maybe_type_annotation = '' if node.type is None else f' : {type_repr(node.type)}'
|
||||
if node.op == 'placeholder':
|
||||
assert isinstance(node.target, str)
|
||||
maybe_default_arg = '' if not node.args else f' = {repr(node.args[0])}'
|
||||
free_vars.append(f'{node.target}{maybe_type_annotation}{maybe_default_arg}')
|
||||
raw_name = node.target.replace('*', '')
|
||||
if raw_name != repr(node):
|
||||
body.append(f'{repr(node)} = {raw_name}\n')
|
||||
return
|
||||
elif node.op == 'call_method':
|
||||
assert isinstance(node.target, str)
|
||||
body.append(f'{repr(node)}{maybe_type_annotation} = {_format_target(repr(node.args[0]), node.target)}'
|
||||
f'({_format_args(node.args[1:], node.kwargs)})')
|
||||
return
|
||||
elif node.op == 'call_function':
|
||||
assert callable(node.target)
|
||||
# pretty print operators
|
||||
if node.target.__module__ == '_operator' and node.target.__name__ in magic_methods:
|
||||
assert isinstance(node.args, tuple)
|
||||
body.append(f'{repr(node)}{maybe_type_annotation} = '
|
||||
f'{magic_methods[node.target.__name__].format(*(repr(a) for a in node.args))}')
|
||||
return
|
||||
|
||||
# pretty print inplace operators; required for jit.script to work properly
|
||||
# not currently supported in normal FX graphs, but generated by torchdynamo
|
||||
if node.target.__module__ == '_operator' and node.target.__name__ in inplace_methods:
|
||||
body.append(f'{inplace_methods[node.target.__name__].format(*(repr(a) for a in node.args))}; '
|
||||
f'{repr(node)}{maybe_type_annotation} = {repr(node.args[0])}')
|
||||
return
|
||||
|
||||
qualified_name = _get_qualified_name(node.target)
|
||||
global_name = add_global(qualified_name, node.target)
|
||||
# special case for getattr: node.args could be 2-argument or 3-argument
|
||||
# 2-argument: attribute access; 3-argument: fall through to attrib function call with default value
|
||||
if global_name == 'getattr' and \
|
||||
isinstance(node.args, tuple) and \
|
||||
isinstance(node.args[1], str) and \
|
||||
node.args[1].isidentifier() and \
|
||||
len(node.args) == 2:
|
||||
body.append(
|
||||
f'{repr(node)}{maybe_type_annotation} = {_format_target(repr(node.args[0]), node.args[1])}')
|
||||
return
|
||||
body.append(
|
||||
f'{repr(node)}{maybe_type_annotation} = {global_name}({_format_args(node.args, node.kwargs)})')
|
||||
if node.meta.get('is_wrapped', False):
|
||||
wrapped_fns.setdefault(global_name)
|
||||
return
|
||||
elif node.op == 'call_module':
|
||||
assert isinstance(node.target, str)
|
||||
body.append(f'{repr(node)}{maybe_type_annotation} = '
|
||||
f'{_format_target(root_module, node.target)}({_format_args(node.args, node.kwargs)})')
|
||||
return
|
||||
elif node.op == 'get_attr':
|
||||
assert isinstance(node.target, str)
|
||||
body.append(f'{repr(node)}{maybe_type_annotation} = {_format_target(root_module, node.target)}')
|
||||
return
|
||||
elif node.op == 'output':
|
||||
if node.type is not None:
|
||||
maybe_return_annotation[0] = f" -> {type_repr(node.type)}"
|
||||
body.append(self.generate_output(node.args[0]))
|
||||
return
|
||||
raise NotImplementedError(f'node: {node.op} {node.target}')
|
||||
|
||||
# Modified for activation checkpointing
|
||||
ckpt_func = []
|
||||
emit_code_with_activation_checkpoint(body, ckpt_func, nodes, emit_node, delete_unused_values)
|
||||
|
||||
if len(body) == 0:
|
||||
# If the Graph has no non-placeholder nodes, no lines for the body
|
||||
# have been emitted. To continue to have valid Python code, emit a
|
||||
# single pass statement
|
||||
body.append('pass\n')
|
||||
|
||||
if len(wrapped_fns) > 0:
|
||||
wrap_name = add_global('wrap', torch.fx.wrap)
|
||||
wrap_stmts = '\n'.join([f'{wrap_name}("{name}")' for name in wrapped_fns])
|
||||
else:
|
||||
wrap_stmts = ''
|
||||
|
||||
if self._body_transformer:
|
||||
body = self._body_transformer(body)
|
||||
|
||||
for name, value in self.additional_globals():
|
||||
add_global(name, value)
|
||||
|
||||
prologue = self.gen_fn_def(free_vars, maybe_return_annotation[0])
|
||||
prologue = ''.join(ckpt_func) + prologue
|
||||
prologue = prologue
|
||||
|
||||
code = ''.join(body)
|
||||
code = '\n'.join(' ' + line for line in code.split('\n'))
|
||||
fn_code = f"""
|
||||
{wrap_stmts}
|
||||
{prologue}
|
||||
{code}"""
|
||||
return PythonCode(fn_code, globals_)
|
|
@ -0,0 +1,173 @@
|
|||
import os
|
||||
import warnings
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, Optional, Union
|
||||
|
||||
import torch
|
||||
import torch.fx
|
||||
import torch.nn as nn
|
||||
from torch.fx.graph import PythonCode, _PyTreeCodeGen
|
||||
from torch.fx.graph_module import _exec_with_source, _forward_from_src, _WrappedCall
|
||||
from torch.nn.modules.module import _addindent
|
||||
|
||||
|
||||
class ColoGraphModule(torch.fx.GraphModule):
|
||||
"""
|
||||
ColoGraphGraphModule is an nn.Module generated from an fx.Graph.
|
||||
ColoGraphmodule has a ``graph`` attribute, as well as ``code`` and ``forward``
|
||||
attributes generated from that ``graph``.
|
||||
|
||||
The difference between ``ColoGraphModule`` and ``torch.fx.GraphModule`` is that
|
||||
``ColoGraphModule`` has a ``bind()`` function to bind customized functions
|
||||
(i.e. activation checkpoint) to ``code`` of ``nn.Module``. If you want to use
|
||||
specific features in Colossal-AI that are not supported by ``torch.fx.GraphModule``,
|
||||
you can use ``ColoGraphModule`` instead.
|
||||
|
||||
``colossalai.fx.symbolic_trace()`` will return a ``ColoGraphModule`` as default.
|
||||
|
||||
.. warning::
|
||||
|
||||
When ``graph`` is reassigned, ``code`` and ``forward`` will be automatically
|
||||
regenerated. However, if you edit the contents of the ``graph`` without reassigning
|
||||
the ``graph`` attribute itself, you must call ``recompile()`` to update the generated
|
||||
code.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
root: Union[torch.nn.Module, Dict[str, Any]],
|
||||
graph: torch.fx.Graph,
|
||||
class_name: str = 'GraphModule'):
|
||||
super().__init__(root, graph, class_name)
|
||||
|
||||
def bind(self, ckpt_def, globals):
|
||||
"""Bind function needed for correctly execute ``GraphModule.forward()``
|
||||
|
||||
We need to bind checkpoint functions to ``ColoGraphModule`` so that we could
|
||||
correctly execute ``GraphModule.forward()``
|
||||
|
||||
Args:
|
||||
ckpt_def (List[str]): definition before the forward function
|
||||
globals (Dict[str, Any]): global variables
|
||||
"""
|
||||
|
||||
ckpt_code = "\n".join(ckpt_def)
|
||||
globals_copy = globals.copy()
|
||||
_exec_with_source(ckpt_code, globals_copy)
|
||||
func_list = [func for func in globals_copy.keys() if "checkpoint" in func or "pack" in func]
|
||||
for func in func_list:
|
||||
tmp_func = globals_copy[func]
|
||||
setattr(self, func, tmp_func.__get__(self, self.__class__))
|
||||
del globals_copy[func]
|
||||
|
||||
def recompile(self) -> PythonCode:
|
||||
"""
|
||||
Recompile this GraphModule from its ``graph`` attribute. This should be
|
||||
called after editing the contained ``graph``, otherwise the generated
|
||||
code of this ``GraphModule`` will be out of date.
|
||||
"""
|
||||
if isinstance(self._graph._codegen, _PyTreeCodeGen):
|
||||
self._in_spec = self._graph._codegen.pytree_info.in_spec
|
||||
self._out_spec = self._graph._codegen.pytree_info.out_spec
|
||||
python_code = self._graph.python_code(root_module='self')
|
||||
self._code = python_code.src
|
||||
|
||||
# To split ckpt functions code and forward code
|
||||
_code_list = self._code.split("\n")
|
||||
_fwd_def = [item for item in _code_list if "def forward" in item][0]
|
||||
_fwd_idx = _code_list.index(_fwd_def)
|
||||
ckpt_def = _code_list[:_fwd_idx]
|
||||
self._code = "\n".join(_code_list[_fwd_idx:])
|
||||
|
||||
self.bind(ckpt_def, python_code.globals)
|
||||
|
||||
cls = type(self)
|
||||
cls.forward = _forward_from_src(self._code, python_code.globals)
|
||||
|
||||
# Determine whether this class explicitly defines a __call__ implementation
|
||||
# to wrap. If it does, save it in order to have wrapped_call invoke it.
|
||||
# If it does not, wrapped_call can use a dynamic call to super() instead.
|
||||
# In most cases, super().__call__ should be torch.nn.Module.__call__.
|
||||
# We do not want to hold a reference to Module.__call__ here; doing so will
|
||||
# bypass patching of torch.nn.Module.__call__ done while symbolic tracing.
|
||||
cls_call = cls.__call__ if "__call__" in vars(cls) else None
|
||||
|
||||
if '_wrapped_call' not in vars(cls):
|
||||
cls._wrapped_call = _WrappedCall(cls, cls_call) # type: ignore[attr-defined]
|
||||
|
||||
def call_wrapped(self, *args, **kwargs):
|
||||
return self._wrapped_call(self, *args, **kwargs)
|
||||
|
||||
cls.__call__ = call_wrapped
|
||||
|
||||
# reset self._code to original src, otherwise to_folder will be wrong
|
||||
self._code = python_code.src
|
||||
return python_code
|
||||
|
||||
def to_folder(self, folder: Union[str, os.PathLike], module_name: str = "FxModule"):
|
||||
"""Dumps out module to ``folder`` with ``module_name`` so that it can be
|
||||
imported with ``from <folder> import <module_name>``
|
||||
|
||||
Args:
|
||||
|
||||
folder (Union[str, os.PathLike]): The folder to write the code out to
|
||||
|
||||
module_name (str): Top-level name to use for the ``Module`` while
|
||||
writing out the code
|
||||
"""
|
||||
folder = Path(folder)
|
||||
Path(folder).mkdir(exist_ok=True)
|
||||
torch.save(self.state_dict(), folder / 'state_dict.pt')
|
||||
tab = " " * 4
|
||||
|
||||
# we add import colossalai here
|
||||
model_str = f"""
|
||||
import torch
|
||||
from torch.nn import *
|
||||
import colossalai
|
||||
|
||||
|
||||
class {module_name}(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
"""
|
||||
|
||||
def _gen_model_repr(module_name: str, module: torch.nn.Module) -> Optional[str]:
|
||||
safe_reprs = [nn.Linear, nn.Conv1d, nn.Conv2d, nn.Conv3d, nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d]
|
||||
if type(module) in safe_reprs:
|
||||
return f"{module.__repr__()}"
|
||||
else:
|
||||
return None
|
||||
|
||||
blobified_modules = []
|
||||
for module_name, module in self.named_children():
|
||||
module_str = _gen_model_repr(module_name, module)
|
||||
if module_str is None:
|
||||
module_file = folder / f'{module_name}.pt'
|
||||
torch.save(module, module_file)
|
||||
blobified_modules.append(module_name)
|
||||
module_repr = module.__repr__().replace('\r', ' ').replace('\n', ' ')
|
||||
module_str = f"torch.load(r'{module_file}') # {module_repr}"
|
||||
model_str += f"{tab*2}self.{module_name} = {module_str}\n"
|
||||
|
||||
for buffer_name, buffer in self._buffers.items():
|
||||
if buffer is None:
|
||||
continue
|
||||
model_str += f"{tab*2}self.register_buffer('{buffer_name}', torch.empty({list(buffer.shape)}, dtype={buffer.dtype}))\n"
|
||||
|
||||
for param_name, param in self._parameters.items():
|
||||
if param is None:
|
||||
continue
|
||||
model_str += f"{tab*2}self.{param_name} = torch.nn.Parameter(torch.empty({list(param.shape)}, dtype={param.dtype}))\n"
|
||||
|
||||
model_str += f"{tab*2}self.load_state_dict(torch.load(r'{folder}/state_dict.pt'))\n"
|
||||
model_str += f"{_addindent(self.code, 4)}\n"
|
||||
|
||||
module_file = folder / 'module.py'
|
||||
module_file.write_text(model_str)
|
||||
|
||||
init_file = folder / '__init__.py'
|
||||
init_file.write_text('from .module import *')
|
||||
|
||||
if len(blobified_modules) > 0:
|
||||
warnings.warn("Was not able to save the following children modules as reprs -"
|
||||
f"saved as pickled files instead: {blobified_modules}")
|
|
@ -0,0 +1,211 @@
|
|||
from dataclasses import dataclass, field
|
||||
from typing import Callable, ClassVar, Dict, List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
from torch.autograd.profiler_util import _format_memory, _format_time
|
||||
from torch.fx import Graph, GraphModule, Node
|
||||
|
||||
from colossalai._analyzer.envs import MeshConfig
|
||||
|
||||
|
||||
def intersect(a, b):
|
||||
return {k: a[k] for k in a if k in b}
|
||||
|
||||
|
||||
def subtract(a, b):
|
||||
return {k: a[k] for k in a if k not in b}
|
||||
|
||||
|
||||
def union(a, b):
|
||||
return {**a, **b}
|
||||
|
||||
|
||||
def compute_size_in_bytes(elem: torch.Tensor | Dict | List | Tuple | int) -> int:
|
||||
"""Compute the size of a tensor or a collection of tensors in bytes.
|
||||
|
||||
Args:
|
||||
elem (torch.Tensor | Dict | List | Tuple | int): Arbitrary nested ``torch.Tensor`` data structure.
|
||||
|
||||
Returns:
|
||||
int: The size of the tensor or the collection of tensors in bytes.
|
||||
"""
|
||||
nbytes = 0
|
||||
if isinstance(elem, torch.Tensor):
|
||||
if elem.is_quantized:
|
||||
nbytes += elem.numel() * torch._empty_affine_quantized([], dtype=elem.dtype).element_size()
|
||||
else:
|
||||
nbytes += elem.numel() * torch.tensor([], dtype=elem.dtype).element_size()
|
||||
elif isinstance(elem, dict):
|
||||
value_list = [v for _, v in elem.items()]
|
||||
nbytes += compute_size_in_bytes(value_list)
|
||||
elif isinstance(elem, tuple) or isinstance(elem, list) or isinstance(elem, set):
|
||||
for e in elem:
|
||||
nbytes += compute_size_in_bytes(e)
|
||||
return nbytes
|
||||
|
||||
|
||||
@dataclass
|
||||
class MetaInfo:
|
||||
r"""
|
||||
The base class to store all profiling and static graph analysis information
|
||||
needed for auto-parallel system in Colossal-AI.
|
||||
============================================================================
|
||||
-------------------------------
|
||||
| FX.Node | <-----
|
||||
[input/param] are ---> |[input/param] [grad_inp]| [grad_inp] contributes to the
|
||||
placeholders (might be | | \__________ | | profiled peak memory in backward
|
||||
saved for backward. | | \ | | pass. [grad_param] is calculated
|
||||
| | \ | | separately.
|
||||
| [interm] -------> [grad_int]| <-----
|
||||
| | \_________ | | [grad_interm] marks the peak
|
||||
| / \ \ | | memory in backward pass.
|
||||
[x] is not counted ---> | [x] [interm] --> [grad_int]| <-----
|
||||
in [interm] because | | \_____ | |
|
||||
it is not saved for | | \ | |
|
||||
backward. | [output] \ | | <----- [output] is potentially
|
||||
------------------------------- [input] for the next node.
|
||||
============================================================================
|
||||
|
||||
Accumulate Size = ALL_PREVIOUS_CTX U {Interm Size + Output Size}
|
||||
Output Size = ([output] in global_ctx and not is_alias)
|
||||
Temp Size = ([output] not in global_ctx and not is_alias)
|
||||
Backward Size = ([grad_inp])
|
||||
|
||||
Usage:
|
||||
>>> for node in graph.nodes:
|
||||
>>> n_info = MetaInfo(node) # will create a new MetaInfo instance and store in node.meta['info']
|
||||
>>> # if not exist, otherwise return the existing one
|
||||
>>> n_info.to_recompute = ... # set the to_recompute attribute
|
||||
|
||||
Remarks:
|
||||
This feature is experimental and all the entries are subject to change.
|
||||
"""
|
||||
|
||||
# reference
|
||||
node: Node
|
||||
|
||||
# directory
|
||||
mod_dir: str = ''
|
||||
|
||||
# ctx[data_ptr] = Tensor
|
||||
# mark the storage for ctx.save_for_backward
|
||||
global_ctx: Dict[str, torch.Tensor] = field(default_factory=lambda: {}) # globally shared
|
||||
curr_ctx: Dict[str, torch.Tensor] = field(default_factory=lambda: {}) # global_ctx till this node
|
||||
|
||||
# should be updated after each graph manipulation
|
||||
# ============================== Update ====================================
|
||||
# parameter and buffer within ``Node``
|
||||
parameters: Dict[str, torch.nn.Parameter] = field(default_factory=lambda: {})
|
||||
buffers: Dict[str, torch.Tensor] = field(default_factory=lambda: {})
|
||||
|
||||
inputs: Tuple[torch.Tensor] = ()
|
||||
outputs: Tuple[torch.Tensor] = ()
|
||||
is_alias: Tuple[bool] = () # whether the output is an alias of input
|
||||
|
||||
# compute cost
|
||||
fwd_flop: Optional[int] = 0
|
||||
bwd_flop: Optional[int] = 0
|
||||
|
||||
# communication cost (should be the size in bytes of communication)
|
||||
fwd_comm: Optional[int] = 0
|
||||
bwd_comm: Optional[int] = 0
|
||||
|
||||
# should keep the same whenever manipulated
|
||||
# ============================= Invariant ==================================
|
||||
to_recompute: Tuple[torch.Tensor] = () # (region_0, region_1, ...) support nested codegen
|
||||
to_offload: Optional[bool] = False
|
||||
sharding_spec: str = 'RR'
|
||||
|
||||
def __new__(cls, node: Node, **kwargs):
|
||||
orig_init = cls.__init__
|
||||
|
||||
# if initialized, return the existing one
|
||||
# should disable the __init__ function
|
||||
if node.meta.get('info', None) is not None:
|
||||
|
||||
def _dummy(self, *args, **kwargs):
|
||||
if getattr(self, '_is_init', False):
|
||||
self._is_init = True
|
||||
orig_init(self, *args, **kwargs)
|
||||
cls.__init__ = orig_init
|
||||
|
||||
cls.__init__ = _dummy
|
||||
return node.meta['info']
|
||||
return super().__new__(cls)
|
||||
|
||||
def __post_init__(self):
|
||||
self.node.meta['info'] = self
|
||||
|
||||
@property
|
||||
def fwd_time(self, tflops: float = MeshConfig.TFLOPS, bandwidth: float = MeshConfig.BANDWIDTH):
|
||||
return self.fwd_flop / tflops + self.fwd_comm / bandwidth
|
||||
|
||||
@property
|
||||
def bwd_time(self, tflops: float = MeshConfig.TFLOPS, bandwidth: float = MeshConfig.BANDWIDTH):
|
||||
return self.bwd_flop / tflops + self.bwd_comm / bandwidth
|
||||
|
||||
@property
|
||||
def param_size(self):
|
||||
return compute_size_in_bytes(self.parameters)
|
||||
|
||||
@property
|
||||
def buffer_size(self):
|
||||
return compute_size_in_bytes(self.buffers)
|
||||
|
||||
@property
|
||||
def output_size(self):
|
||||
"""Used in CheckpointSolver"""
|
||||
output_ctx = {
|
||||
o.data_ptr(): o
|
||||
for o, is_alias in zip(self.outputs, self.is_alias)
|
||||
if not is_alias and isinstance(o, torch.Tensor) and not isinstance(o, torch.nn.Parameter)
|
||||
}
|
||||
return compute_size_in_bytes(intersect(self.global_ctx, output_ctx))
|
||||
|
||||
@property
|
||||
def accumulate_size(self):
|
||||
"""Used in CheckpointSolver"""
|
||||
output_ctx = {
|
||||
o.data_ptr(): o
|
||||
for o, is_alias in zip(self.outputs, self.is_alias)
|
||||
if not is_alias and isinstance(o, torch.Tensor) and not isinstance(o, torch.nn.Parameter)
|
||||
}
|
||||
return compute_size_in_bytes(union(self.curr_ctx, intersect(self.global_ctx, output_ctx)))
|
||||
|
||||
@property
|
||||
def temp_size(self):
|
||||
"""Used in CheckpointSolver"""
|
||||
output_ctx = {
|
||||
o.data_ptr(): o
|
||||
for o, is_alias in zip(self.outputs, self.is_alias)
|
||||
if not is_alias and isinstance(o, torch.Tensor) and not isinstance(o, torch.nn.Parameter)
|
||||
}
|
||||
return compute_size_in_bytes(subtract(output_ctx, self.global_ctx))
|
||||
|
||||
@property
|
||||
def backward_size(self):
|
||||
"""Used in CheckpointSolver"""
|
||||
return compute_size_in_bytes(self.inputs)
|
||||
|
||||
def __repr__(self):
|
||||
s = f'Node {self.node.name}'
|
||||
if self.parameters:
|
||||
s += f'\n\thas parameter of size {_format_memory(self.param_size)}'
|
||||
if self.buffers:
|
||||
s += f'\n\thas buffer of size {_format_memory(self.buffer_size)}'
|
||||
if self.output_size:
|
||||
s += f'\n\thas output activation of size {_format_memory(self.output_size)}'
|
||||
if self.total_size:
|
||||
s += f'\n\thas total activation of size {_format_memory(self.total_size)}'
|
||||
if self.temp_size:
|
||||
s += f'\n\thas temp activation of size {_format_memory(self.temp_size)}'
|
||||
if self.backward_size:
|
||||
s += f'\n\thas backward activation of size {_format_memory(self.backward_size)}'
|
||||
s += f'\n\tfwd_flop = {self.fwd_flop}'\
|
||||
f'\n\tbwd_flop = {self.bwd_flop}'\
|
||||
f'\n\tfwd_comm = {self.fwd_comm}'\
|
||||
f'\n\tbwd_comm = {self.bwd_comm}'\
|
||||
f'\n\tto_recompute = {self.to_recompute}'\
|
||||
f'\n\tto_offload = {self.to_offload}'\
|
||||
f'\n\tsharding_spec = {self.sharding_spec}'
|
||||
return s
|
|
@ -0,0 +1,2 @@
|
|||
from .graph_profile import graph_profile_pass
|
||||
from .shape_prop import ShapeProp, shape_prop_pass, sim_env
|
|
@ -0,0 +1,347 @@
|
|||
from typing import Any, Dict, Iterator, List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.fx
|
||||
from torch.autograd.profiler_util import _format_memory, _format_time
|
||||
from torch.fx import GraphModule
|
||||
from torch.fx.node import Argument, Node, Target
|
||||
|
||||
from colossalai._analyzer._subclasses import flop_count
|
||||
from colossalai._analyzer.fx.node_util import MetaInfo
|
||||
|
||||
|
||||
def _format_flops(flops: float) -> str:
|
||||
"""Returns a formatted FLOP size string"""
|
||||
if flops > 1e12:
|
||||
return f'{flops / 1e12:.2f} TFLOPs'
|
||||
elif flops > 1e9:
|
||||
return f'{flops / 1e9:.2f} GFLOPs'
|
||||
elif flops > 1e6:
|
||||
return f'{flops / 1e6:.2f} MFLOPs'
|
||||
elif flops > 1e3:
|
||||
return f'{flops / 1e3:.2f} kFLOPs'
|
||||
return f'{flops} FLOPs'
|
||||
|
||||
|
||||
def _denormalize_tuple(t: Tuple[int, ...]) -> Tuple[int, ...]:
|
||||
return t[0] if len(t) == 1 else t
|
||||
|
||||
|
||||
def _normalize_tuple(x):
|
||||
if not isinstance(x, tuple):
|
||||
return (x,)
|
||||
return x
|
||||
|
||||
|
||||
def _current_device(module):
|
||||
return next(module.parameters()).device
|
||||
|
||||
|
||||
class GraphProfiler(torch.fx.Interpreter):
|
||||
"""
|
||||
Fetch shape argument from ``ShapeProp`` without re-executing
|
||||
the ``GraphModule`` from scratch.
|
||||
"""
|
||||
_profileable = [
|
||||
'call_function',
|
||||
'call_module',
|
||||
'call_method',
|
||||
]
|
||||
|
||||
def __init__(self, module: GraphModule, garbage_collect_values: bool = True):
|
||||
super().__init__(module, garbage_collect_values)
|
||||
|
||||
def run(self, *args, initial_env: Optional[Dict[Node, Any]] = None, enable_io_processing: bool = True) -> Any:
|
||||
"""
|
||||
Run `module` via interpretation and return the result.
|
||||
|
||||
Args:
|
||||
*args: The arguments to the Module to run, in positional order
|
||||
initial_env (Optional[Dict[Node, Any]]): An optional starting environment for execution.
|
||||
This is a dict mapping `Node` to any value. This can be used, for example, to
|
||||
pre-populate results for certain `Nodes` so as to do only partial evaluation within
|
||||
the interpreter.
|
||||
enable_io_processing (bool): If true, we process the inputs and outputs with graph's process_inputs and
|
||||
process_outputs function first before using them.
|
||||
|
||||
Returns:
|
||||
Any: The value returned from executing the Module
|
||||
"""
|
||||
self.env = initial_env if initial_env else {}
|
||||
|
||||
# Positional function args are consumed left-to-right by
|
||||
# `placeholder` nodes. Use an iterator to keep track of
|
||||
# position and extract those values.
|
||||
if enable_io_processing:
|
||||
args = self.module.graph.process_inputs(*args)
|
||||
self.args_iter: Iterator[Any] = iter(args)
|
||||
|
||||
for node in self.module.graph.nodes:
|
||||
|
||||
self.run_node(node) # No need to store.
|
||||
|
||||
if self.garbage_collect_values:
|
||||
for to_delete in self.user_to_last_uses.get(node, []):
|
||||
del self.env[to_delete]
|
||||
|
||||
if node.op == 'output':
|
||||
output_val = self.env[node]
|
||||
return self.module.graph.process_outputs(output_val) if enable_io_processing else output_val
|
||||
|
||||
def fetch_initial_env(self, device=None) -> Dict[Node, Any]:
|
||||
"""
|
||||
Fetch ``initial_env`` for execution. This is because ``ShapeProp``
|
||||
has already attached outputs of each ``Node`` to its ``MetaInfo``.
|
||||
|
||||
Args:
|
||||
device (torch.device): The device to place the execution, default to ``None``
|
||||
|
||||
Returns:
|
||||
Dict[Node, Any]: The initial environment for execution
|
||||
"""
|
||||
initial_env = {}
|
||||
for n in self.module.graph.nodes:
|
||||
initial_env[n] = _denormalize_tuple(MetaInfo(n).outputs)
|
||||
return initial_env
|
||||
|
||||
def propagate(self, *args, device=None):
|
||||
"""
|
||||
Run `module` via interpretation and profile the execution
|
||||
of each ``Node``.
|
||||
|
||||
Args:
|
||||
*args (Tensor): The sample input, not used
|
||||
device (torch.device): The device to place the execution, default to ``None``
|
||||
|
||||
Returns:
|
||||
Any: The value returned from executing the Module
|
||||
"""
|
||||
initial_env = self.fetch_initial_env(device)
|
||||
|
||||
return self.run(initial_env=initial_env)
|
||||
|
||||
def summary(self) -> str:
|
||||
"""
|
||||
Summarizes the profiled statistics of the `GraphModule` in
|
||||
tabular format. Note that this API requires the ``tabulate`` module
|
||||
to be installed.
|
||||
|
||||
Returns:
|
||||
str: The summary of the profiled statistics
|
||||
"""
|
||||
# https://github.com/pytorch/pytorch/blob/master/torch/fx/graph.py
|
||||
try:
|
||||
from tabulate import tabulate
|
||||
except ImportError:
|
||||
print("`summary` relies on the library `tabulate`, "
|
||||
"which could not be found on this machine. Run `pip "
|
||||
"install tabulate` to install the library.")
|
||||
|
||||
# Build up a list of summary information for each node
|
||||
node_summaries: List[List[Any]] = []
|
||||
last_n_info = None
|
||||
|
||||
for node in self.module.graph.nodes:
|
||||
node: Node
|
||||
n_info = MetaInfo(node)
|
||||
last_n_info = last_n_info or n_info
|
||||
node_summaries.append([
|
||||
node.op,
|
||||
str(node),
|
||||
_format_memory(n_info.accumulate_size),
|
||||
_format_memory(n_info.accumulate_size - last_n_info.accumulate_size),
|
||||
_format_memory(n_info.output_size),
|
||||
_format_memory(n_info.temp_size),
|
||||
_format_memory(n_info.param_size),
|
||||
_format_memory(n_info.backward_size),
|
||||
_format_flops(n_info.fwd_flop),
|
||||
_format_flops(n_info.bwd_flop),
|
||||
])
|
||||
last_n_info = n_info
|
||||
|
||||
# Use the ``tabulate`` library to create a well-formatted table
|
||||
# presenting our summary information
|
||||
headers: List[str] = [
|
||||
'Op type',
|
||||
'Op',
|
||||
'Accumulate size',
|
||||
'Incremental size',
|
||||
'Output size',
|
||||
'Temp size',
|
||||
'Param size',
|
||||
'Backward size',
|
||||
'Fwd FLOPs',
|
||||
'Bwd FLOPs',
|
||||
]
|
||||
|
||||
return tabulate(node_summaries, headers=headers, stralign='right')
|
||||
|
||||
|
||||
class CommunicationProfiler(GraphProfiler):
|
||||
"""
|
||||
TODO(lyl): Add this for all comm nodes
|
||||
"""
|
||||
|
||||
def __init__(self, module: GraphModule, garbage_collect_values: bool = True):
|
||||
raise NotImplementedError()
|
||||
|
||||
|
||||
class FlopProfiler(GraphProfiler):
|
||||
"""
|
||||
Execute an FX graph Node-by-Node and record the meta data of the result
|
||||
into the corresponding node.
|
||||
|
||||
Usage:
|
||||
>>> model = MyModule()
|
||||
>>> x = torch.rand(10, 10)
|
||||
>>> gm = colossalai.fx.symbolic_trace(model, meta_args = {'x': x}})
|
||||
>>> shape_interp = ShapeProp(gm) # must do this first
|
||||
>>> shape_interp.propagate(x)
|
||||
>>> profiler = FlopProfiler(gm)
|
||||
>>> profiler.propagate(x)
|
||||
|
||||
Args:
|
||||
module (GraphModule): The module to be executed
|
||||
|
||||
Hints:
|
||||
If you want to add a new flop count rule, you can first
|
||||
check the existing files in ``../_subclasses/flop_tensor.py``.
|
||||
If your flop count rules are incompatible with the existing
|
||||
ones, you can do so by adding a new method to this class
|
||||
with the ``@register_flop_count_impl`` decorator. The method
|
||||
should take (*args, **kwargs) instance as its input and
|
||||
generate flop count for both forward and backward as its
|
||||
output.
|
||||
|
||||
For example, if you want to add a flop count rule for
|
||||
``my_fn``, which is a hand-written operand not detected by
|
||||
PyTorch, you can do so by adding a new method to this
|
||||
class with the ``@register_flop_count_impl`` decorator:
|
||||
|
||||
>>> @register_flop_count_impl(my_fn)
|
||||
>>> def my_fn_flop_count_impl(*args, **kwargs):
|
||||
>>> return 0, 0
|
||||
"""
|
||||
_custom_flop_count_impl = {}
|
||||
|
||||
def run_node(self, n: torch.fx.Node) -> Any:
|
||||
"""
|
||||
Run a specific node ``n`` and profile its execution time and memory usage.
|
||||
Calls into call_function, call_method, and call_module only.
|
||||
|
||||
Args:
|
||||
n (Node): The Node to profile
|
||||
|
||||
Returns:
|
||||
Any: The output of the node
|
||||
|
||||
Raises:
|
||||
RuntimeError: If the node is not profileable.
|
||||
"""
|
||||
args, kwargs = self.fetch_args_kwargs_from_env(n)
|
||||
n_info = MetaInfo(n)
|
||||
|
||||
if n.op in self._profileable:
|
||||
try:
|
||||
(
|
||||
n_info.fwd_flop,
|
||||
n_info.bwd_flop,
|
||||
) = getattr(self, n.op)(n.target, args, kwargs)
|
||||
except Exception as e:
|
||||
raise RuntimeError(
|
||||
f'Error {str(e)} occurred when profiling node {n}, node.target = {n.target}. '
|
||||
f'Please refer to function\'s docstring to register the relevant profile_impl for this node!'
|
||||
) from e
|
||||
|
||||
# retain the autograd graph
|
||||
for param in self.module.parameters():
|
||||
param.grad = None
|
||||
|
||||
return _denormalize_tuple(n_info.outputs)
|
||||
|
||||
def call_function(self, target: 'Target', args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any:
|
||||
"""
|
||||
Execute a ``call_function`` node and return the profiling result.
|
||||
Dispatch to ``_custom_flop_count_impl`` if ``call_function`` should be
|
||||
profiled in a user-defined behavior.
|
||||
|
||||
Args:
|
||||
target (Target): The call target for this node. See
|
||||
`Node <https://pytorch.org/docs/master/fx.html#torch.fx.Node>`__ for
|
||||
details on semantics
|
||||
args (Tuple): Tuple of positional args for this invocation
|
||||
kwargs (Dict): Dict of keyword arguments for this invocation
|
||||
|
||||
Return
|
||||
flop_count (Tuple[int]): (fwd_flop, bwd_flop)
|
||||
"""
|
||||
assert not isinstance(target, str)
|
||||
|
||||
# Dispatch the impl for profiling, default will be ``flop_count``
|
||||
if target in self._custom_flop_count_impl:
|
||||
return self._custom_flop_count_impl[target](*args, **kwargs)
|
||||
else:
|
||||
return flop_count(target, *args, **kwargs)
|
||||
|
||||
def call_method(self, target: 'Target', args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any:
|
||||
"""
|
||||
Execute a ``call_method`` node and return the profiling result.
|
||||
|
||||
Args:
|
||||
target (Target): The call target for this node. See
|
||||
`Node <https://pytorch.org/docs/master/fx.html#torch.fx.Node>`__ for
|
||||
details on semantics
|
||||
args (Tuple): Tuple of positional args for this invocation
|
||||
kwargs (Dict): Dict of keyword arguments for this invocation
|
||||
|
||||
Return
|
||||
flop_count (Tuple[int]): (fwd_flop, bwd_flop)
|
||||
"""
|
||||
# Execute the method and return the result
|
||||
assert isinstance(target, str)
|
||||
return flop_count(getattr(torch.Tensor, target), *args, **kwargs)
|
||||
|
||||
def call_module(self, target: 'Target', args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any:
|
||||
"""
|
||||
Execute a ``call_module`` node and return the profiling result.
|
||||
|
||||
Args:
|
||||
target (Target): The call target for this node. See
|
||||
`Node <https://pytorch.org/docs/master/fx.html#torch.fx.Node>`__ for
|
||||
details on semantics
|
||||
args (Tuple): Tuple of positional args for this invocation
|
||||
kwargs (Dict): Dict of keyword arguments for this invocation
|
||||
|
||||
Return
|
||||
flop_count (Tuple[int]): (fwd_flop, bwd_flop)
|
||||
"""
|
||||
# Retrieve executed args and kwargs values from the environment
|
||||
|
||||
# Execute the method and return the result
|
||||
assert isinstance(target, str)
|
||||
submod = self.fetch_attr(target)
|
||||
return flop_count(submod, *args, **kwargs)
|
||||
|
||||
|
||||
def graph_profile_pass(module: GraphModule, *args, verbose=False) -> GraphModule:
|
||||
"""
|
||||
Run ``module`` via interpretation and profile the execution
|
||||
of each ``Node``.
|
||||
|
||||
Args:
|
||||
module (GraphModule): The GraphModule to profile
|
||||
*args (Any): The sample input, not used
|
||||
verbose (bool): Whether to print the profiling summary
|
||||
|
||||
Returns:
|
||||
GraphModule: The same GraphModule with profiling information
|
||||
"""
|
||||
for profiler_cls in (FlopProfiler,
|
||||
# CommunicationProfiler, # TODO: add communication profiling
|
||||
):
|
||||
profiler = profiler_cls(module)
|
||||
profiler.propagate(*args, device=_current_device(module))
|
||||
|
||||
if verbose:
|
||||
print(profiler.summary())
|
||||
return module
|
|
@ -0,0 +1,194 @@
|
|||
"""``torch.fx.ShapeProp``, but with ``MetaTensor``"""
|
||||
|
||||
from typing import Any, Callable, Dict, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.fx
|
||||
from torch.autograd.graph import saved_tensors_hooks
|
||||
from torch.utils._pytree import tree_map
|
||||
|
||||
from colossalai._analyzer._subclasses import MetaTensor, MetaTensorMode
|
||||
from colossalai._analyzer.fx.node_util import MetaInfo
|
||||
from colossalai.fx._compatibility import compatibility
|
||||
|
||||
Target = Union[Callable[..., Any], str]
|
||||
|
||||
|
||||
class sim_env(saved_tensors_hooks):
|
||||
"""
|
||||
A simulation of memory allocation and deallocation in the forward pass
|
||||
using ``saved_tensor_hooks``.
|
||||
|
||||
Attributes:
|
||||
ctx (Dict[int, torch.Tensor]): A dictionary that maps the
|
||||
data pointer of a tensor to the tensor itself. This is used
|
||||
to track the memory allocation and deallocation.
|
||||
|
||||
param_ctx (Dict[int, torch.Tensor]): A dictionary that maps the
|
||||
data pointer of all model parameters to the parameter itself.
|
||||
This avoids overestimating the memory usage of the intermediate activations.
|
||||
"""
|
||||
|
||||
def __init__(self, module: Optional[torch.nn.Module] = None):
|
||||
super().__init__(self.pack_hook, self.unpack_hook)
|
||||
self.ctx = {}
|
||||
self.param_ctx = {param.data_ptr(): param for param in module.parameters()}
|
||||
self.buffer_ctx = {buffer.data_ptr(): buffer for buffer in module.buffers()} if module else {}
|
||||
|
||||
def pack_hook(self, tensor: torch.Tensor):
|
||||
if tensor.data_ptr() not in self.param_ctx and tensor.data_ptr() not in self.buffer_ctx:
|
||||
self.ctx[tensor.data_ptr()] = tensor
|
||||
return tensor
|
||||
|
||||
def unpack_hook(self, tensor):
|
||||
return tensor
|
||||
|
||||
|
||||
def _normalize_tuple(x):
|
||||
if not isinstance(x, tuple):
|
||||
return (x,)
|
||||
return x
|
||||
|
||||
|
||||
def _current_device(module):
|
||||
return next(module.parameters()).device
|
||||
|
||||
|
||||
@compatibility(is_backward_compatible=False)
|
||||
class ShapeProp(torch.fx.Interpreter):
|
||||
"""
|
||||
Execute an FX graph Node-by-Node and record the meta data of the result
|
||||
into the corresponding node.
|
||||
|
||||
Usage:
|
||||
>>> model = MyModule()
|
||||
>>> x = torch.rand(10, 10)
|
||||
>>> gm = colossalai.fx.symbolic_trace(model, meta_args = {'x': x})
|
||||
>>> interp = ShapeProp(gm)
|
||||
>>> interp.propagate(x)
|
||||
|
||||
Args:
|
||||
module (GraphModule): The module to be executed
|
||||
|
||||
Hints:
|
||||
If you want to add a new shape propagation rule, you can do so by
|
||||
adding a new method to this class with the ``@register_shape_impl``
|
||||
decorator. The method should take (*args, **kwargs) instance as its
|
||||
input and generate output.
|
||||
|
||||
For example, if you want to add a shape propagation rule for
|
||||
``torch.nn.functional.linear``, you can do so by adding a new method
|
||||
to this class with the ``@register_shape_impl`` decorator (Since the
|
||||
``MetaTensorMode`` is compatible with ``torch.nn.functional.linear``,
|
||||
in practice you don't have to do as follows):
|
||||
|
||||
>>> @register_shape_impl(torch.nn.functional.linear)
|
||||
>>> def linear_shape_impl(*args, **kwargs):
|
||||
>>> # do something here
|
||||
>>> return torch.empty(output_shape, device=output_device)
|
||||
"""
|
||||
_custom_dispatch_func = {}
|
||||
_mode = MetaTensorMode()
|
||||
|
||||
def __init__(self, module: torch.fx.GraphModule, garbage_collect_values: bool = True):
|
||||
super().__init__(module, garbage_collect_values)
|
||||
self.global_hook = sim_env(module=self.module)
|
||||
|
||||
def run_node(self, n: torch.fx.Node) -> Any:
|
||||
"""
|
||||
Run a specific node ``n`` and return the result. Attach
|
||||
(
|
||||
``inputs``, ``outputs``, ``parameters``, ``buffers``
|
||||
) to ``n``.
|
||||
|
||||
Args:
|
||||
n (Node): The ``Node`` to execute
|
||||
|
||||
Returns:
|
||||
Any: The result of executing ``n``
|
||||
"""
|
||||
args, kwargs = self.fetch_args_kwargs_from_env(n)
|
||||
with self.global_hook:
|
||||
r = getattr(self, n.op)(n.target, args, kwargs)
|
||||
|
||||
unwrap_fn = lambda elem: elem._tensor if isinstance(elem, MetaTensor) else elem
|
||||
is_pure_tensor = lambda elem: isinstance(elem, MetaTensor) and not isinstance(elem, torch.nn.Parameter)
|
||||
n_info = MetaInfo(n)
|
||||
n_info.outputs = _normalize_tuple(r)
|
||||
|
||||
if n.op == 'call_module':
|
||||
submod = self.fetch_attr(n.target)
|
||||
n_info.parameters.update({k: MetaTensor(v) for k, v in submod.named_parameters()})
|
||||
n_info.buffers.update({k: MetaTensor(v) for k, v in submod.named_buffers()})
|
||||
|
||||
else:
|
||||
n_info.parameters.update({
|
||||
k.name: MetaTensor(v)
|
||||
for k, v in zip(n.args, args)
|
||||
if isinstance(k, torch.fx.Node) and isinstance(v, torch.nn.Parameter)
|
||||
})
|
||||
n_info.parameters.update({k: MetaTensor(v) for k, v in kwargs.items() if isinstance(v, torch.nn.Parameter)})
|
||||
|
||||
n_info.inputs = tuple(v for v in args if is_pure_tensor(v)) + \
|
||||
tuple(v for v in kwargs.values() if is_pure_tensor(v))
|
||||
|
||||
n._meta_data = tree_map(unwrap_fn, _normalize_tuple(r)) # align with SPMD
|
||||
|
||||
n_info.global_ctx = self.global_hook.ctx
|
||||
n_info.curr_ctx = self.global_hook.ctx.copy()
|
||||
|
||||
crit = lambda x: x.data_ptr() in self.global_hook.ctx if isinstance(x, torch.Tensor) else False
|
||||
n_info.is_alias = _normalize_tuple(tree_map(crit, n_info.outputs))
|
||||
return r
|
||||
|
||||
def call_function(self, target: 'Target', args: Tuple[Any, ...], kwargs: Dict[str, Any]) -> Any:
|
||||
"""
|
||||
Execute a ``call_function`` node and return the result.
|
||||
If the target of ``Node`` is registered with ``@register_shape_impl``,
|
||||
the registered function will be used to execute the node. This is common
|
||||
if we insert some customized kernels.
|
||||
|
||||
Args:
|
||||
target (Target): The call target for this node. See
|
||||
`Node <https://pytorch.org/docs/master/fx.html#torch.fx.Node>`__ for
|
||||
details on semantics
|
||||
args (Tuple): Tuple of positional args for this invocation
|
||||
kwargs (Dict): Dict of keyword arguments for this invocation
|
||||
|
||||
Return
|
||||
Any: The value returned by the function invocation
|
||||
"""
|
||||
if target in self._custom_dispatch_func:
|
||||
return self._custom_dispatch_func[target](*args, **kwargs)
|
||||
else:
|
||||
return super().call_function(target, args, kwargs)
|
||||
|
||||
def propagate(self, *args, device=None):
|
||||
"""
|
||||
Run `module` via interpretation and return the result and record the
|
||||
shape of each node.
|
||||
Args:
|
||||
*args (Tensor): The sample input.
|
||||
Returns:
|
||||
Any: The value returned from executing the Module
|
||||
"""
|
||||
wrap_fn = lambda elem: MetaTensor(elem, device=device)
|
||||
with self._mode:
|
||||
return super().run(*tree_map(wrap_fn, args))
|
||||
|
||||
|
||||
def shape_prop_pass(module: torch.fx.GraphModule, *args) -> torch.fx.GraphModule:
|
||||
"""
|
||||
Run ``module`` via interpretation and return the result and record the
|
||||
shape of each ``Node``.
|
||||
|
||||
Args:
|
||||
module (GraphModule): The GraphModule to profile
|
||||
*args (Any): The sample input
|
||||
|
||||
Returns:
|
||||
GraphModule: The same GraphModule with shape information
|
||||
"""
|
||||
|
||||
ShapeProp(module).propagate(*args, device=_current_device(module))
|
||||
return module
|
|
@ -0,0 +1,40 @@
|
|||
import torch
|
||||
import torch.fx
|
||||
from torch.fx import GraphModule
|
||||
|
||||
from .passes import ShapeProp, graph_profile_pass, shape_prop_pass
|
||||
from .passes.graph_profile import FlopProfiler
|
||||
|
||||
|
||||
def register_flop_count_impl(func):
|
||||
|
||||
def wrapper(impl):
|
||||
FlopProfiler._custom_flop_count_impl[func] = impl
|
||||
return impl
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
def register_shape_impl(func):
|
||||
|
||||
def wrapper(impl):
|
||||
ShapeProp._custom_dispatch_func[func] = impl
|
||||
return impl
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
def symbolic_profile(module: GraphModule, *args, verbose=False) -> GraphModule:
|
||||
"""Symbolically profile a model with sample inputs.
|
||||
|
||||
Args:
|
||||
module (GraphModule): The module to be profiled
|
||||
args (Tuple): The sample inputs
|
||||
verbose (bool): Whether to print the profiling result
|
||||
|
||||
Returns:
|
||||
GraphModule: The profiled module
|
||||
"""
|
||||
module = shape_prop_pass(module, *args)
|
||||
module = graph_profile_pass(module, *args, verbose=verbose)
|
||||
return module
|
|
@ -0,0 +1,620 @@
|
|||
import functools
|
||||
import inspect
|
||||
import operator
|
||||
from contextlib import contextmanager
|
||||
from typing import Any, Callable, Dict, Iterable, List, Optional, Set, Tuple, Type, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch.fx import Graph, Node, Proxy, Tracer
|
||||
from torch.fx.graph import _Namespace
|
||||
from torch.utils._pytree import tree_map
|
||||
|
||||
from colossalai._analyzer._subclasses import MetaTensor, _TensorPropertyMethod, _TorchFactoryMethod
|
||||
|
||||
from .codegen import ActivationCheckpointCodeGen
|
||||
from .graph_module import ColoGraphModule
|
||||
from .node_util import MetaInfo
|
||||
|
||||
Target = Union[Callable[..., Any], str]
|
||||
Argument = Optional[Union[Tuple[Any, ...], # actually Argument, but mypy can't represent recursive types
|
||||
List[Any], # actually Argument
|
||||
Dict[str, Any], # actually Argument
|
||||
slice, # Slice[Argument, Argument, Argument], but slice is not a templated type in typing
|
||||
'Node',]]
|
||||
zeros = torch.zeros
|
||||
|
||||
|
||||
def _truncate_suffix(s: str):
|
||||
import re
|
||||
|
||||
# FIXME: don't know why but torch.fx always gets a suffix like '_1' in the name
|
||||
return re.sub(r'_\d+$', '', s)
|
||||
|
||||
|
||||
def _default_device():
|
||||
return torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')
|
||||
|
||||
|
||||
def _current_device(module):
|
||||
try:
|
||||
return next(module.parameters()).device
|
||||
except:
|
||||
return _default_device()
|
||||
|
||||
|
||||
def register_tracer_impl(func: Callable[..., Any], name: Optional[str] = '_custom_impl'):
|
||||
|
||||
def wrapper(impl):
|
||||
assert hasattr(ColoTracer, name), f"Cannot register {func.__name__} in ColoTracer.{name}"
|
||||
getattr(ColoTracer, name)[func] = impl
|
||||
return impl
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
def register_leaf_module_impl(module: nn.Module):
|
||||
|
||||
def wrapper(impl):
|
||||
ColoTracer._custom_leaf_module_impl[module] = impl
|
||||
return impl
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
def register_leaf_module(module: nn.Module):
|
||||
ColoTracer._custom_leaf_module.add(module)
|
||||
|
||||
|
||||
def register_non_leaf_module(module: nn.Module):
|
||||
ColoTracer._custom_non_leaf_module.add(module)
|
||||
|
||||
|
||||
class ColoProxy(Proxy):
|
||||
_func_dispatch: Dict[Target, Callable[..., Any]] = {}
|
||||
|
||||
def __init__(self, *args, data=None, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self._meta_data = data
|
||||
|
||||
@property
|
||||
def meta_data(self):
|
||||
return self._meta_data
|
||||
|
||||
@meta_data.setter
|
||||
def meta_data(self, args):
|
||||
wrap_fn = lambda x: MetaTensor(x) if isinstance(x, torch.Tensor) else x
|
||||
self._meta_data = tree_map(wrap_fn, args)
|
||||
|
||||
@classmethod
|
||||
def __torch_function__(cls, orig_method, types, args=(), kwargs=None):
|
||||
kwargs = {} if kwargs is None else kwargs
|
||||
if orig_method in cls._func_dispatch:
|
||||
impl = cls._func_dispatch.pop(orig_method) # avoid recursion
|
||||
proxy = impl(*args, **kwargs)
|
||||
cls._func_dispatch[orig_method] = impl
|
||||
return proxy
|
||||
else:
|
||||
proxy = cls.from_torch_proxy(super().__torch_function__(orig_method, types, args, kwargs))
|
||||
unwrap_fn = lambda p: p.meta_data if isinstance(p, ColoProxy) else p
|
||||
if proxy.meta_data is None:
|
||||
proxy.meta_data = orig_method(*tree_map(unwrap_fn, args), **tree_map(unwrap_fn, kwargs))
|
||||
return proxy
|
||||
|
||||
@classmethod
|
||||
def from_torch_proxy(cls, proxy: Proxy):
|
||||
return cls(proxy.node, proxy.tracer)
|
||||
|
||||
def __repr__(self):
|
||||
return f"ColoProxy({self.node.name}, meta_data={self.meta_data})"
|
||||
|
||||
def __len__(self):
|
||||
return len(self.meta_data)
|
||||
|
||||
def __int__(self):
|
||||
return int(self.meta_data)
|
||||
|
||||
def __index__(self):
|
||||
try:
|
||||
return int(self.meta_data)
|
||||
except:
|
||||
return zeros(self.meta_data.shape, dtype=torch.bool).numpy().__index__()
|
||||
|
||||
def __float__(self):
|
||||
return float(self.meta_data)
|
||||
|
||||
def __bool__(self):
|
||||
return self.meta_data
|
||||
|
||||
def __getattr__(self, k):
|
||||
return ColoAttribute(self, k, getattr(self._meta_data, k, None))
|
||||
|
||||
def __setitem__(self, key, value):
|
||||
proxy = self.tracer.create_proxy('call_function', operator.setitem, (self, key, value), {})
|
||||
proxy.meta_data = self._meta_data
|
||||
return proxy
|
||||
|
||||
def __contains__(self, key):
|
||||
if self.node.op == "placeholder":
|
||||
# this is used to handle like
|
||||
# if x in kwargs
|
||||
# we don't handle this case for now
|
||||
return False
|
||||
return super().__contains__(key)
|
||||
|
||||
def __isinstancecheck__(self, type):
|
||||
return isinstance(self.meta_data, type)
|
||||
|
||||
def size(self, dim=None):
|
||||
if self._meta_data is None:
|
||||
return self._meta_data.size(*[dim] if dim else [])
|
||||
return self.tracer.create_proxy('call_method', 'size', (self, dim) if dim else (self,), {})
|
||||
|
||||
def dim(self):
|
||||
if self._meta_data is not None:
|
||||
return self._meta_data.dim()
|
||||
return self.tracer.create_proxy('call_method', 'dim', (self,), {})
|
||||
|
||||
@property
|
||||
def shape(self):
|
||||
if self._meta_data is not None:
|
||||
return self._meta_data.shape
|
||||
return self.tracer.create_proxy('call_function', getattr, (self, 'shape'), {})
|
||||
|
||||
@property
|
||||
def ndim(self):
|
||||
if self._meta_data is not None:
|
||||
return self._meta_data.ndim
|
||||
return self.tracer.create_proxy('call_function', getattr, (self, 'ndim'), {})
|
||||
|
||||
@property
|
||||
def device(self):
|
||||
if self._meta_data is not None:
|
||||
return self._meta_data.device
|
||||
return self.tracer.create_proxy('call_function', getattr, (self, 'device'), {})
|
||||
|
||||
@property
|
||||
def dtype(self):
|
||||
if self._meta_data is not None:
|
||||
return self._meta_data.dtype
|
||||
return self.tracer.create_proxy('call_function', getattr, (self, 'dtype'), {})
|
||||
|
||||
def to(self, *args, **kwargs):
|
||||
return self.tracer.create_proxy('call_method', 'to', (self, *args), {**kwargs})
|
||||
|
||||
def cpu(self, *args, **kwargs):
|
||||
return self.tracer.create_proxy('call_method', 'cpu', (self, *args), {**kwargs})
|
||||
|
||||
def cuda(self, *args, **kwargs):
|
||||
return self.tracer.create_proxy('call_method', 'cuda', (self, *args), {**kwargs})
|
||||
|
||||
|
||||
class ColoAttribute(ColoProxy):
|
||||
|
||||
def __init__(self, root, attr: str, data=None):
|
||||
self.root = root
|
||||
self.attr = attr
|
||||
self.tracer = root.tracer
|
||||
self._meta_data = data
|
||||
self._node: Optional[Node] = None
|
||||
|
||||
@property
|
||||
def node(self):
|
||||
# the node for attributes is added lazily, since most will just be method calls
|
||||
# which do not rely on the getitem call
|
||||
if self._node is None:
|
||||
self._node = self.tracer.create_proxy('call_function', getattr, (self.root, self.attr), {}).node
|
||||
return self._node
|
||||
|
||||
def __call__(self, *args, **kwargs):
|
||||
return self.tracer.create_proxy('call_method', self.attr, (self.root,) + args, kwargs)
|
||||
|
||||
def __repr__(self):
|
||||
return f"ColoAttribute({self.node.name}, attr={self.attr})"
|
||||
|
||||
|
||||
class ColoTracer(Tracer):
|
||||
_custom_leaf_module: Set[Type[nn.Module]] = set()
|
||||
_custom_leaf_module_impl: Dict[Type[nn.Module], Callable[..., Any]] = {}
|
||||
_custom_non_leaf_module: Set[Type[nn.Module]] = set()
|
||||
_custom_impl: Dict[Callable[..., Any], Callable[..., Any]] = {}
|
||||
_bias_addition_impl: Dict[Callable[..., Any], Callable[..., Any]] = {}
|
||||
_bias_addition_module = [
|
||||
torch.nn.Linear,
|
||||
torch.nn.Conv1d,
|
||||
torch.nn.Conv2d,
|
||||
torch.nn.Conv3d,
|
||||
torch.nn.ConvTranspose1d,
|
||||
torch.nn.ConvTranspose2d,
|
||||
torch.nn.ConvTranspose3d,
|
||||
]
|
||||
|
||||
def __init__(self, trace_act_ckpt: bool = False, bias_addition_split: bool = False, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.disable_module_getattr = False
|
||||
self.proxy_buffer_attributes = True
|
||||
|
||||
# whether the tracer will record the usage of torch.utils.checkpoint
|
||||
self.trace_act_ckpt = trace_act_ckpt
|
||||
self.ckpt_regions = []
|
||||
self.ckpt_idx = 0
|
||||
|
||||
self.mod_dir = ''
|
||||
|
||||
# whether the tracer should split the bias_add ops into two ops
|
||||
self.bias_addition_split = bias_addition_split
|
||||
|
||||
def is_leaf_module(self, m: nn.Module, module_qualified_name: str) -> bool:
|
||||
# if bias-addiction split is enabled, and module has bias, then it is not a leaf module
|
||||
# we will enter the module and split the bias-addition ops
|
||||
if self.bias_addition_split and type(m) in self._bias_addition_module and m.bias is not None:
|
||||
return False
|
||||
|
||||
# user can specify which modules are leaf modules and which are not
|
||||
return (type(m) not in self._custom_non_leaf_module
|
||||
and (type(m) in self._custom_leaf_module or super().is_leaf_module(m, module_qualified_name)))
|
||||
|
||||
def call_module(self, m: torch.nn.Module, forward: Callable[..., Any], args: Tuple[Any, ...],
|
||||
kwargs: Dict[str, Any]) -> Any:
|
||||
curr_dir = self.mod_dir
|
||||
self.mod_dir = 'self.' + self.path_of_module(m)
|
||||
rst = super().call_module(m, forward, args, kwargs)
|
||||
self.mod_dir = curr_dir
|
||||
return rst
|
||||
|
||||
def proxy(self, node: Node) -> 'ColoProxy':
|
||||
return ColoProxy(node, self)
|
||||
|
||||
def create_proxy(self,
|
||||
kind: str,
|
||||
target: Target,
|
||||
args: Tuple[Any, ...],
|
||||
kwargs: Dict[str, Any],
|
||||
name: Optional[str] = None,
|
||||
type_expr: Optional[Any] = None,
|
||||
proxy_factory_fn: Callable[[Node], 'Proxy'] = None):
|
||||
|
||||
proxy: ColoProxy = super().create_proxy(kind, target, args, kwargs, name, type_expr, proxy_factory_fn)
|
||||
unwrap_fn = lambda p: p.meta_data if isinstance(p, ColoProxy) else p
|
||||
if kind == 'placeholder':
|
||||
proxy.meta_data = self.meta_args[target] if target in self.meta_args else self.concrete_args.get(
|
||||
_truncate_suffix(target), None)
|
||||
elif kind == 'get_attr':
|
||||
self.disable_module_getattr = True
|
||||
try:
|
||||
attr_itr = self.root
|
||||
atoms = target.split(".")
|
||||
for atom in atoms:
|
||||
attr_itr = getattr(attr_itr, atom)
|
||||
proxy.meta_data = attr_itr
|
||||
finally:
|
||||
self.disable_module_getattr = False
|
||||
elif kind == 'call_function':
|
||||
proxy.meta_data = target(*tree_map(unwrap_fn, args), **tree_map(unwrap_fn, kwargs))
|
||||
elif kind == 'call_method':
|
||||
self.disable_module_getattr = True
|
||||
try:
|
||||
if target == '__call__':
|
||||
proxy.meta_data = unwrap_fn(args[0])(*tree_map(unwrap_fn, args[1:]), **tree_map(unwrap_fn, kwargs))
|
||||
else:
|
||||
if target not in _TensorPropertyMethod:
|
||||
proxy._meta_data = getattr(unwrap_fn(args[0]), target)(*tree_map(unwrap_fn, args[1:]),
|
||||
**tree_map(unwrap_fn, kwargs))
|
||||
finally:
|
||||
self.disable_module_getattr = False
|
||||
elif kind == 'call_module':
|
||||
mod = self.root.get_submodule(target)
|
||||
self.disable_module_getattr = True
|
||||
try:
|
||||
proxy.meta_data = self._custom_leaf_module_impl.get(type(mod),
|
||||
mod.forward)(*tree_map(unwrap_fn, args),
|
||||
**tree_map(unwrap_fn, kwargs))
|
||||
finally:
|
||||
self.disable_module_getattr = False
|
||||
return proxy
|
||||
|
||||
def create_node(self, *args, **kwargs) -> Node:
|
||||
node = super().create_node(*args, **kwargs)
|
||||
n_info = MetaInfo(node, mod_dir=self.mod_dir, to_recompute=tuple(self.ckpt_regions))
|
||||
return node
|
||||
|
||||
def trace(self,
|
||||
root: torch.nn.Module,
|
||||
concrete_args: Optional[Dict[str, torch.Tensor]] = {},
|
||||
meta_args: Optional[Dict[str, torch.Tensor]] = {}) -> Graph:
|
||||
|
||||
# check concrete and meta args have valid names
|
||||
sig = inspect.signature(root.forward)
|
||||
sig_names = set(sig.parameters.keys())
|
||||
meta_arg_names = set(meta_args.keys())
|
||||
concrete_arg_names = set(concrete_args.keys())
|
||||
|
||||
# update concrete args with default values
|
||||
for k, v in sig.parameters.items():
|
||||
if k in sig_names - meta_arg_names and \
|
||||
k not in concrete_args and \
|
||||
v.default is not inspect.Parameter.empty:
|
||||
concrete_args[k] = v.default
|
||||
|
||||
def _check_arg_name_valid(names: Iterable[str]):
|
||||
for name in names:
|
||||
if name not in sig_names:
|
||||
raise ValueError(f"Argument {name} is not in the signature of {root.__class__.__name__}.forward")
|
||||
|
||||
_check_arg_name_valid(meta_arg_names)
|
||||
_check_arg_name_valid(concrete_arg_names)
|
||||
|
||||
self.concrete_args = concrete_args
|
||||
self.meta_args = meta_args
|
||||
|
||||
with self._torch_factory_override(), self._tracer_override(), torch.no_grad():
|
||||
self.mod_dir = 'self'
|
||||
self.graph = super().trace(root, concrete_args=concrete_args)
|
||||
self.mod_dir = ''
|
||||
self.graph.lint()
|
||||
return self.graph
|
||||
|
||||
@contextmanager
|
||||
def _tracer_override(self):
|
||||
# override the tracer to support custom modules and checkpointing
|
||||
if self.trace_act_ckpt:
|
||||
orig_ckpt_func_apply = torch.utils.checkpoint.CheckpointFunction.apply
|
||||
orig_ckpt_func_without_reentrant = torch.utils.checkpoint._checkpoint_without_reentrant
|
||||
|
||||
def checkpoint(run_function, preserve_rng_state=False, *args):
|
||||
self.ckpt_regions.append(self.ckpt_idx)
|
||||
out = run_function(*args)
|
||||
self.ckpt_idx = self.ckpt_regions.pop(-1) + 1
|
||||
return out
|
||||
|
||||
# override the checkpoint function
|
||||
torch.utils.checkpoint.CheckpointFunction.apply = checkpoint
|
||||
torch.utils.checkpoint._checkpoint_without_reentrant = checkpoint
|
||||
|
||||
# override the custom functions
|
||||
ColoProxy._func_dispatch.update({k: v for k, v in self._custom_impl.items()})
|
||||
|
||||
# override the bias addition functions
|
||||
if self.bias_addition_split:
|
||||
ColoProxy._func_dispatch.update({k: v for k, v in self._bias_addition_impl.items()})
|
||||
|
||||
yield
|
||||
|
||||
if self.trace_act_ckpt:
|
||||
# recover the checkpoint function upon exit
|
||||
torch.utils.checkpoint.CheckpointFunction.apply = orig_ckpt_func_apply
|
||||
torch.utils.checkpoint._checkpoint_reentrant = orig_ckpt_func_without_reentrant
|
||||
|
||||
ColoProxy._func_dispatch = {}
|
||||
|
||||
@contextmanager
|
||||
def _torch_factory_override(self):
|
||||
# override the torch factory functions to create a proxy when the method
|
||||
# is called during ``symbolic_trace()``.
|
||||
def wrap_factory_method(target):
|
||||
|
||||
@functools.wraps(target)
|
||||
def wrapper(*args, **kwargs):
|
||||
is_proxy = any(isinstance(p, ColoProxy) for p in args) | any(
|
||||
isinstance(p, ColoProxy) for p in kwargs.values())
|
||||
if is_proxy:
|
||||
# if the arg is a proxy, then need to record this function called on this proxy
|
||||
# e.g. torch.ones(size) where size is an input proxy
|
||||
self.disable_module_getattr = True
|
||||
try:
|
||||
proxy = self.create_proxy('call_function', target, args, kwargs)
|
||||
finally:
|
||||
self.disable_module_getattr = False
|
||||
return proxy
|
||||
else:
|
||||
return target(*args, **kwargs)
|
||||
|
||||
return wrapper, target
|
||||
|
||||
overrides = {
|
||||
target: wrap_factory_method(getattr(torch, target))
|
||||
for target in _TorchFactoryMethod
|
||||
if callable(getattr(torch, target))
|
||||
}
|
||||
for name, (wrapper, orig) in overrides.items():
|
||||
setattr(torch, name, wrapper)
|
||||
|
||||
yield
|
||||
|
||||
# recover the torch factory functions upon exit
|
||||
for name, (wrapper, orig) in overrides.items():
|
||||
setattr(torch, name, orig)
|
||||
|
||||
def _post_check(self, non_concrete_arg_names: Set[str]):
|
||||
# This is necessary because concrete args are added as input to the traced module since
|
||||
# https://github.com/pytorch/pytorch/pull/55888.
|
||||
for node in self.graph.nodes:
|
||||
if node.op == "placeholder":
|
||||
# Removing default values for inputs as the forward pass will fail with them.
|
||||
if node.target in non_concrete_arg_names:
|
||||
node.args = ()
|
||||
# Without this, torch.jit.script fails because the inputs type is Optional[torch.Tensor].
|
||||
# It cannot infer on the attributes and methods the input should have, and fails.
|
||||
node.type = torch.Tensor
|
||||
# It is a concrete arg so it is not used and should be removed.
|
||||
else:
|
||||
if hasattr(torch.fx._symbolic_trace, "_assert_is_none"):
|
||||
# Newer versions of torch.fx emit an assert statement
|
||||
# for concrete arguments; delete those before we delete
|
||||
# the concrete arg.
|
||||
to_delete = []
|
||||
for user in node.users:
|
||||
if user.target == torch.fx._symbolic_trace._assert_is_none:
|
||||
to_delete.append(user)
|
||||
for user in to_delete:
|
||||
self.graph.erase_node(user)
|
||||
|
||||
self.graph.erase_node(node)
|
||||
|
||||
if node.op == "output":
|
||||
node.type = None
|
||||
self.graph.lint()
|
||||
|
||||
def getattr(self, attr, attr_val, parameter_proxy_cache):
|
||||
return self._module_getattr(attr, attr_val, parameter_proxy_cache)
|
||||
|
||||
def _module_getattr(self, attr, attr_val, parameter_proxy_cache):
|
||||
if getattr(self, "disable_module_getattr", False):
|
||||
return attr_val
|
||||
|
||||
def maybe_get_proxy_for_attr(attr_val, collection_to_search, parameter_proxy_cache):
|
||||
for n, p in collection_to_search:
|
||||
if attr_val is p:
|
||||
if n not in parameter_proxy_cache:
|
||||
kwargs = {}
|
||||
if 'proxy_factory_fn' in inspect.signature(self.create_proxy).parameters:
|
||||
kwargs['proxy_factory_fn'] = (None if not self.param_shapes_constant else
|
||||
lambda node: ColoProxy(self, node, n, attr_val))
|
||||
val_proxy = self.create_proxy('get_attr', n, (), {}, **kwargs) # type: ignore[arg-type]
|
||||
parameter_proxy_cache[n] = val_proxy
|
||||
return parameter_proxy_cache[n]
|
||||
return None
|
||||
|
||||
if self.proxy_buffer_attributes and isinstance(attr_val, torch.Tensor):
|
||||
maybe_buffer_proxy = maybe_get_proxy_for_attr(attr_val, self.root.named_buffers(), parameter_proxy_cache)
|
||||
if maybe_buffer_proxy is not None:
|
||||
return maybe_buffer_proxy
|
||||
|
||||
if isinstance(attr_val, torch.nn.Parameter):
|
||||
maybe_parameter_proxy = maybe_get_proxy_for_attr(attr_val, self.root.named_parameters(),
|
||||
parameter_proxy_cache)
|
||||
if maybe_parameter_proxy is not None:
|
||||
return maybe_parameter_proxy
|
||||
|
||||
return attr_val
|
||||
|
||||
|
||||
def symbolic_trace(
|
||||
root: Union[torch.nn.Module, Callable[..., Any]],
|
||||
concrete_args: Optional[Dict[str, Any]] = {},
|
||||
meta_args: Optional[Dict[str, Any]] = {},
|
||||
trace_act_ckpt: bool = False,
|
||||
bias_addition_split: bool = False,
|
||||
) -> ColoGraphModule:
|
||||
"""
|
||||
Traces a ``torch.nn.Module`` or a function and returns a ``GraphModule`` with ``Node``s and ``MetaInfo``
|
||||
attached to the ``Node``s.
|
||||
|
||||
Can be used to trace the usage of ``torch.utils.checkpoint`` and the path of module
|
||||
(https://github.com/pytorch/examples/blob/main/fx/module_tracer.py).
|
||||
|
||||
This tracer is able to trace basic control flow and for loops.
|
||||
|
||||
It will split the bias addition into two parts if ``bias_addition_split`` is set to be ``True``.
|
||||
(See ./bias_addition.py for more details).
|
||||
|
||||
Examples:
|
||||
1. Tracing a ``torch.nn.Module`` with control flow.
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
class MyModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.linear = torch.nn.Linear(2, 2)
|
||||
|
||||
def forward(self, x):
|
||||
if x.size(0) > 1:
|
||||
x = x.sum(dim=0)
|
||||
return self.linear(x)
|
||||
|
||||
traced = symbolic_trace(MyModule(), meta_args={'x': torch.randn(1, 2, 2)})
|
||||
|
||||
# traced code like:
|
||||
# def forward(self, x):
|
||||
# linear_1 = self.linear(x)
|
||||
# return linear_1
|
||||
|
||||
traced = symbolic_trace(MyModule(), meta_args={'x': torch.randn(2, 2, 2)})
|
||||
|
||||
# traced code like:
|
||||
# def forward(self, x):
|
||||
# sum = x.sum(dim=0); x = None
|
||||
# linear = self.linear(sum); sum = None
|
||||
# return linear
|
||||
|
||||
2. Tracing a ``torch.nn.Module`` with ``torch.utils.checkpoint``.
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
class MyModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.linear = torch.nn.Linear(2, 2)
|
||||
|
||||
def forward(self, x):
|
||||
def custom_forward(x):
|
||||
return self.linear(x)
|
||||
return torch.utils.checkpoint.checkpoint(custom_forward, x)
|
||||
|
||||
traced = symbolic_trace(MyModule(), meta_args={'x': torch.randn(1, 2, 2)}, trace_act_ckpt=True)
|
||||
|
||||
# traced code like:
|
||||
# def checkpoint_0(self, x):
|
||||
# linear = self.linear(x); x = None
|
||||
# return linear
|
||||
#
|
||||
# def forward(self, x):
|
||||
# linear = torch.utils.checkpoint.checkpoint(checkpoint_0, x); x = None
|
||||
# return linear
|
||||
|
||||
3. Tracing a ``torch.nn.Module`` with ``bias_addition_split``.
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
class MyModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.linear = torch.nn.Linear(2, 2, bias=True)
|
||||
|
||||
def forward(self, x):
|
||||
return self.linear(x)
|
||||
|
||||
traced = symbolic_trace(MyModule(), meta_args={'x': torch.randn(1, 2, 2)}, bias_addition_split=True)
|
||||
|
||||
# traced code like:
|
||||
# def forward(self, x):
|
||||
# linear_bias = self.linear.bias
|
||||
# linear_weight = self.linear.weight
|
||||
# linear = torch._C._nn.linear(x, linear_weight); x = linear_weight = None
|
||||
# add = linear + linear_bias; linear = linear_bias = None
|
||||
# return add
|
||||
|
||||
Args:
|
||||
root (Union[torch.nn.Module, Callable[..., Any]]): The ``torch.nn.Module`` or function to be traced.
|
||||
concrete_args (Optional[Dict[str, Any]], optional): Concrete arguments to be passed to the ``root``.
|
||||
Defaults to {}.
|
||||
meta_args (Optional[Dict[str, Any]], optional): Meta arguments to be passed to the ``root``. Mostly used
|
||||
for tracing control flow. Defaults to {}.
|
||||
trace_act_ckpt (bool, optional): Whether to trace the usage of ``torch.utils.checkpoint``.
|
||||
Defaults to False.
|
||||
bias_addition_split (bool, optional): Whether to split the bias addition into two parts. Defaults to False.
|
||||
|
||||
Returns:
|
||||
ColoGraphModule: A traced ``GraphModule`` that is ready for activation checkpoint ``CodeGen``.
|
||||
|
||||
Remarks:
|
||||
This part of ``symbolic_trace()`` is maintained by Colossal-AI team. If you encountered
|
||||
any unexpected error during tracing, feel free to raise an issue on Colossal-AI GitHub
|
||||
repo. We welcome any feedback and contributions to enhance the extensibility of
|
||||
Colossal-AI.
|
||||
"""
|
||||
if meta_args:
|
||||
device, orig_device = _default_device(), _current_device(root)
|
||||
wrap_fn = lambda elem: MetaTensor(elem, device=device) if isinstance(elem, torch.Tensor) else elem
|
||||
graph = ColoTracer(trace_act_ckpt=trace_act_ckpt,
|
||||
bias_addition_split=bias_addition_split).trace(root.to(device),
|
||||
concrete_args=concrete_args,
|
||||
meta_args=tree_map(wrap_fn, meta_args))
|
||||
if trace_act_ckpt:
|
||||
graph.set_codegen(ActivationCheckpointCodeGen())
|
||||
root.to(orig_device)
|
||||
else:
|
||||
graph = Tracer().trace(root, concrete_args=concrete_args)
|
||||
name = root.__class__.__name__ if isinstance(root, torch.nn.Module) else root.__name__
|
||||
return ColoGraphModule(root, graph, name)
|
|
@ -226,7 +226,7 @@ class ConcreteInfoProp(torch.fx.Interpreter):
|
|||
Returns:
|
||||
Any: The value returned from executing the Module
|
||||
"""
|
||||
return super().run(*args)
|
||||
return self.run(*args)
|
||||
|
||||
def summary(self, unit: str = 'MB') -> str:
|
||||
"""
|
||||
|
|
|
@ -288,13 +288,16 @@ class MetaInfoProp(torch.fx.Interpreter):
|
|||
def flops_repr(flop: int) -> str:
|
||||
return f"{flop:,} FLOPs"
|
||||
|
||||
accumulate_size = 0
|
||||
for node in self.module.graph.nodes:
|
||||
node: Node
|
||||
accumulate_size += calculate_fwd_out(node) + calculate_fwd_tmp(node)
|
||||
node_summaries.append([
|
||||
node.op,
|
||||
str(node),
|
||||
flops_repr(node.meta['fwd_flop']),
|
||||
flops_repr(node.meta['bwd_flop']),
|
||||
mem_repr(accumulate_size),
|
||||
mem_repr(calculate_fwd_in(node)),
|
||||
mem_repr(calculate_fwd_out(node)),
|
||||
mem_repr(calculate_fwd_tmp(node)),
|
||||
|
@ -309,6 +312,7 @@ class MetaInfoProp(torch.fx.Interpreter):
|
|||
'Op',
|
||||
'Forward FLOPs',
|
||||
'Backward FLOPs',
|
||||
'Accumulated Memory',
|
||||
'FWD_IN',
|
||||
'FWD_OUT',
|
||||
'FWD_TMP',
|
||||
|
|
|
@ -347,6 +347,7 @@ if version.parse(torch.__version__) >= version.parse('1.12.0'):
|
|||
aten.squeeze.dim,
|
||||
aten.slice.Tensor,
|
||||
aten.slice_backward.default,
|
||||
aten.stack.default,
|
||||
aten.split.Tensor,
|
||||
aten.permute.default,
|
||||
aten.t.default,
|
||||
|
|
|
@ -0,0 +1,113 @@
|
|||
import pytest
|
||||
import torch
|
||||
from torch.utils.checkpoint import checkpoint
|
||||
|
||||
try:
|
||||
from colossalai._analyzer.fx import symbolic_trace
|
||||
except:
|
||||
pass
|
||||
|
||||
|
||||
class LinearModel(torch.nn.Module):
|
||||
|
||||
def __init__(self, in_features, out_features, bias):
|
||||
super().__init__()
|
||||
self.linear = torch.nn.Linear(in_features, out_features, bias=bias)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.linear(x)
|
||||
return x
|
||||
|
||||
|
||||
class ConvModel(torch.nn.Module):
|
||||
|
||||
def __init__(self, in_channel, out_channels, kernel_size, bias) -> None:
|
||||
super().__init__()
|
||||
self.conv = torch.nn.Conv2d(in_channel,
|
||||
out_channels,
|
||||
kernel_size,
|
||||
bias=bias,
|
||||
padding=1,
|
||||
stride=2,
|
||||
dilation=2,
|
||||
groups=3)
|
||||
self.conv_transpose = torch.nn.ConvTranspose2d(in_channel,
|
||||
out_channels,
|
||||
kernel_size,
|
||||
bias=bias,
|
||||
padding=1,
|
||||
stride=2,
|
||||
dilation=2,
|
||||
groups=3)
|
||||
|
||||
def forward(self, x, select=0):
|
||||
if select == 0:
|
||||
x = self.conv(x)
|
||||
else:
|
||||
x = self.conv_transpose(x)
|
||||
return x
|
||||
|
||||
|
||||
class SiuModel(torch.nn.Module):
|
||||
|
||||
def __init__(self, bias) -> None:
|
||||
super().__init__()
|
||||
self.linear = LinearModel(3, 3, bias)
|
||||
self.conv = ConvModel(3, 6, 3, bias)
|
||||
|
||||
def forward(self, x, select=0):
|
||||
x = self.linear(x)
|
||||
x = checkpoint(self.conv, x, select)
|
||||
return x
|
||||
|
||||
|
||||
class AddmmModel(torch.nn.Module):
|
||||
|
||||
def __init__(self, alpha, beta) -> None:
|
||||
super().__init__()
|
||||
self.alpha = alpha
|
||||
self.beta = beta
|
||||
|
||||
def forward(self, x):
|
||||
x = torch.addmm(x, x, x, alpha=self.alpha, beta=self.beta)
|
||||
return x
|
||||
|
||||
|
||||
@pytest.mark.skipif(torch.__version__ < '1.12.0', reason='torch version < 12')
|
||||
@pytest.mark.parametrize("bias", [True, False])
|
||||
@pytest.mark.parametrize("bias_addition_split", [True, False])
|
||||
@pytest.mark.parametrize("shape", [(3, 3, 3), (3, 3, 3, 3)])
|
||||
@pytest.mark.parametrize("select", [0, 1])
|
||||
def test_siu_model(bias, bias_addition_split, shape, select):
|
||||
model = SiuModel(bias=bias)
|
||||
x = torch.rand(shape)
|
||||
gm = symbolic_trace(model,
|
||||
meta_args={'x': x},
|
||||
concrete_args={'select': select},
|
||||
trace_act_ckpt=True,
|
||||
bias_addition_split=bias_addition_split)
|
||||
assert torch.allclose(model(x, select), gm(x, select)), 'original model and traced model should be the same!'
|
||||
if bias and bias_addition_split:
|
||||
assert '+' in gm.code, 'bias addition should be split!'
|
||||
else:
|
||||
assert '+' not in gm.code, 'bias addition should not be split!'
|
||||
|
||||
|
||||
@pytest.mark.skipif(torch.__version__ < '1.12.0', reason='torch version < 12')
|
||||
@pytest.mark.parametrize("alpha", [1, 2])
|
||||
@pytest.mark.parametrize("beta", [1, 2])
|
||||
@pytest.mark.parametrize("bias_addition_split", [True, False])
|
||||
@pytest.mark.parametrize("shape", [(3, 3), (5, 5)])
|
||||
def test_addmm_model(alpha, beta, bias_addition_split, shape):
|
||||
model = AddmmModel(alpha=alpha, beta=beta)
|
||||
x = torch.rand(shape)
|
||||
gm = symbolic_trace(model, meta_args={'x': x}, trace_act_ckpt=True, bias_addition_split=bias_addition_split)
|
||||
assert torch.allclose(model(x), gm(x)), 'original model and traced model should be the same!'
|
||||
if (alpha == 1 and beta == 1) or not bias_addition_split:
|
||||
assert '*' not in gm.code, 'bias addition should not be split!'
|
||||
elif bias_addition_split:
|
||||
assert '+' in gm.code, 'bias addition should be split!'
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test_siu_model(True, True, (3, 3, 3))
|
|
@ -0,0 +1,78 @@
|
|||
import pytest
|
||||
import torch
|
||||
|
||||
try:
|
||||
from colossalai._analyzer.fx import symbolic_trace
|
||||
except:
|
||||
pass
|
||||
|
||||
|
||||
class LinearModel(torch.nn.Module):
|
||||
|
||||
def __init__(self, in_features, out_features, bias):
|
||||
super().__init__()
|
||||
self.linear = torch.nn.Linear(in_features, out_features, bias=bias)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.linear(x)
|
||||
return x
|
||||
|
||||
|
||||
class ConvModel(torch.nn.Module):
|
||||
|
||||
def __init__(self, in_channel, out_channels, kernel_size, bias) -> None:
|
||||
super().__init__()
|
||||
self.conv = torch.nn.Conv2d(in_channel,
|
||||
out_channels,
|
||||
kernel_size,
|
||||
bias=bias,
|
||||
padding=1,
|
||||
stride=2,
|
||||
dilation=2,
|
||||
groups=3)
|
||||
self.conv_transpose = torch.nn.ConvTranspose2d(out_channels,
|
||||
out_channels,
|
||||
kernel_size,
|
||||
bias=bias,
|
||||
padding=1,
|
||||
stride=2,
|
||||
dilation=2,
|
||||
groups=3)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.conv(x)
|
||||
x = self.conv_transpose(x)
|
||||
return x
|
||||
|
||||
|
||||
class AModel(torch.nn.Module):
|
||||
|
||||
def __init__(self, bias) -> None:
|
||||
super().__init__()
|
||||
self.linear_1 = LinearModel(3, 3, bias)
|
||||
self.linear_2 = LinearModel(3, 3, bias)
|
||||
self.conv = ConvModel(3, 6, 3, bias)
|
||||
|
||||
def forward(self, x):
|
||||
for i in range(x.shape[0]):
|
||||
x = self.linear_1(x)
|
||||
x = self.linear_2(x)
|
||||
x = self.conv(x)
|
||||
return x
|
||||
|
||||
|
||||
@pytest.mark.skipif(torch.__version__ < '1.12.0', reason='torch version < 12')
|
||||
@pytest.mark.parametrize("bias", [True, False])
|
||||
@pytest.mark.parametrize("bias_addition_split", [True, False])
|
||||
@pytest.mark.parametrize("shape", [(3, 3, 3), (3, 3, 3, 3)])
|
||||
def test_mod_dir(bias, bias_addition_split, shape):
|
||||
model = AModel(bias=bias)
|
||||
x = torch.rand(shape)
|
||||
gm = symbolic_trace(model, meta_args={'x': x}, bias_addition_split=bias_addition_split)
|
||||
for node in gm.graph.nodes:
|
||||
assert len(node.meta['info'].mod_dir), f"{node} should have non-trivial ``mod_dir``."
|
||||
print(node, node.meta['info'].mod_dir)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test_mod_dir(True, True, (3, 3, 3))
|
|
@ -0,0 +1,55 @@
|
|||
import torch
|
||||
import torch.nn as nn
|
||||
from torch.utils.checkpoint import checkpoint
|
||||
import pytest
|
||||
|
||||
try:
|
||||
from colossalai._analyzer.fx import symbolic_trace
|
||||
except:
|
||||
pass
|
||||
|
||||
|
||||
class MyModule(nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.a = nn.Linear(10, 10)
|
||||
self.b = nn.Linear(10, 10)
|
||||
self.c = nn.Linear(10, 10)
|
||||
self.d = nn.Linear(10, 10)
|
||||
self.e = nn.Linear(10, 10)
|
||||
|
||||
def checkpoint_0(self, x):
|
||||
return checkpoint(self.checkpoint_0_0, x) + checkpoint(self.checkpoint_0_1, x) + self.e(x)
|
||||
|
||||
def checkpoint_0_0(self, x):
|
||||
return checkpoint(self.checkpoint_0_0_0, x) + checkpoint(self.checkpoint_0_0_1, x)
|
||||
|
||||
def checkpoint_0_0_0(self, x):
|
||||
return self.a(x) + checkpoint(self.checkpoint_0_0_0_0, x, use_reentrant=False)
|
||||
|
||||
def checkpoint_0_0_0_0(self, x):
|
||||
return self.b(x)
|
||||
|
||||
def checkpoint_0_0_1(self, x):
|
||||
return self.b(x) + self.c(x)
|
||||
|
||||
def checkpoint_0_1(self, x):
|
||||
return self.d(x)
|
||||
|
||||
def forward(self, x):
|
||||
return checkpoint(self.checkpoint_0, x)
|
||||
|
||||
|
||||
@pytest.mark.skipif(torch.__version__ < '1.12.0', reason='torch version < 12')
|
||||
def test_nested_ckpt():
|
||||
model = MyModule()
|
||||
x = torch.rand(10, 10)
|
||||
gm = symbolic_trace(model, meta_args={'x': x}, trace_act_ckpt=True)
|
||||
assert torch.allclose(gm(x), model(x)), "The traced model should generate the same output as the original model."
|
||||
for ckpt_def in filter(lambda s: s.startswith('checkpoint'), dir(model)):
|
||||
assert ckpt_def in gm.code, f"Checkpoint {ckpt_def} should be in the traced code.\n Traced code = {gm.code}"
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_nested_ckpt()
|
|
@ -0,0 +1,63 @@
|
|||
import pytest
|
||||
import timm.models as tmm
|
||||
import torch
|
||||
import torchvision.models as tm
|
||||
from .zoo import tm_models, tmm_models
|
||||
|
||||
try:
|
||||
from colossalai._analyzer._subclasses import MetaTensorMode
|
||||
from colossalai._analyzer.fx import symbolic_trace
|
||||
from colossalai._analyzer.fx.passes.shape_prop import shape_prop_pass
|
||||
from colossalai._analyzer.fx.symbolic_profile import register_shape_impl
|
||||
|
||||
|
||||
@register_shape_impl(torch.nn.functional.linear)
|
||||
def linear_impl(*args, **kwargs):
|
||||
assert True
|
||||
return torch.nn.functional.linear(*args, **kwargs)
|
||||
except:
|
||||
pass
|
||||
|
||||
|
||||
def _check_gm_validity(gm: torch.fx.GraphModule):
|
||||
for node in gm.graph.nodes:
|
||||
assert node.meta['info'].outputs, f'In {gm.__class__.__name__}, {node} has no output shape.'
|
||||
if node.op in [
|
||||
# 'call_module', # can apply to params
|
||||
# 'call_function', # can apply to params
|
||||
# 'call_method', # can apply to params
|
||||
]:
|
||||
assert node.meta['info'].inputs, f'In {gm.__class__.__name__}, {node} has no input shape.'
|
||||
|
||||
|
||||
@pytest.mark.skipif(torch.__version__ < '1.12.0', reason='torch version < 12')
|
||||
@pytest.mark.parametrize('m', tm_models)
|
||||
def test_torchvision_shape_prop(m):
|
||||
with MetaTensorMode():
|
||||
model = m()
|
||||
data = torch.rand(100, 3, 224, 224)
|
||||
meta_args = {
|
||||
"x": data,
|
||||
}
|
||||
gm = symbolic_trace(model, meta_args=meta_args)
|
||||
shape_prop_pass(gm, data)
|
||||
_check_gm_validity(gm)
|
||||
|
||||
|
||||
@pytest.mark.skipif(torch.__version__ < '1.12.0', reason='torch version < 12')
|
||||
@pytest.mark.parametrize('m', tmm_models)
|
||||
def test_timm_shape_prop(m):
|
||||
with MetaTensorMode():
|
||||
model = m()
|
||||
data = torch.rand(100, 3, 224, 224)
|
||||
meta_args = {
|
||||
"x": data,
|
||||
}
|
||||
gm = symbolic_trace(model, meta_args=meta_args)
|
||||
shape_prop_pass(gm, data)
|
||||
_check_gm_validity(gm)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_torchvision_shape_prop(tm.resnet18)
|
||||
test_timm_shape_prop(tmm.vgg11)
|
|
@ -0,0 +1,49 @@
|
|||
import pytest
|
||||
import timm.models as tmm
|
||||
import torch
|
||||
import torchvision.models as tm
|
||||
from .zoo import tm_models, tmm_models
|
||||
|
||||
try:
|
||||
from colossalai._analyzer._subclasses import MetaTensorMode
|
||||
from colossalai._analyzer.fx import symbolic_profile, symbolic_trace
|
||||
except:
|
||||
pass
|
||||
|
||||
|
||||
def _check_gm_validity(gm: torch.fx.GraphModule):
|
||||
for node in gm.graph.nodes:
|
||||
assert len(node.meta['info'].global_ctx), f'In {gm.__class__.__name__}, {node} has empty global context.'
|
||||
|
||||
|
||||
@pytest.mark.skipif(torch.__version__ < '1.12.0', reason='torch version < 12')
|
||||
@pytest.mark.parametrize('m', tm_models)
|
||||
def test_torchvision_profile(m, verbose=False, bias_addition_split=False):
|
||||
with MetaTensorMode():
|
||||
model = m()
|
||||
data = torch.rand(8, 3, 224, 224)
|
||||
meta_args = {
|
||||
"x": data,
|
||||
}
|
||||
gm = symbolic_trace(model, meta_args=meta_args, bias_addition_split=bias_addition_split)
|
||||
symbolic_profile(gm, data, verbose=verbose)
|
||||
_check_gm_validity(gm)
|
||||
|
||||
|
||||
@pytest.mark.skipif(torch.__version__ < '1.12.0', reason='torch version < 12')
|
||||
@pytest.mark.parametrize('m', tmm_models)
|
||||
def test_timm_profile(m, verbose=False, bias_addition_split=False):
|
||||
with MetaTensorMode():
|
||||
model = m()
|
||||
data = torch.rand(8, 3, 224, 224)
|
||||
meta_args = {
|
||||
"x": data,
|
||||
}
|
||||
gm = symbolic_trace(model, meta_args=meta_args, bias_addition_split=bias_addition_split)
|
||||
symbolic_profile(gm, data, verbose=verbose)
|
||||
_check_gm_validity(gm)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_torchvision_profile(tm.vit_b_16, verbose=True, bias_addition_split=False)
|
||||
test_timm_profile(tmm.gmlp_b16_224, verbose=True, bias_addition_split=False)
|
|
@ -0,0 +1,53 @@
|
|||
import timm.models as tmm
|
||||
import torchvision.models as tm
|
||||
|
||||
# input shape: (batch_size, 3, 224, 224)
|
||||
tm_models = [
|
||||
tm.alexnet,
|
||||
tm.convnext_base,
|
||||
tm.densenet121,
|
||||
# tm.efficientnet_v2_s,
|
||||
# tm.googlenet, # output bad case
|
||||
# tm.inception_v3, # bad case
|
||||
tm.mobilenet_v2,
|
||||
tm.mobilenet_v3_small,
|
||||
tm.mnasnet0_5,
|
||||
tm.resnet18,
|
||||
tm.regnet_x_16gf,
|
||||
tm.resnext50_32x4d,
|
||||
tm.shufflenet_v2_x0_5,
|
||||
tm.squeezenet1_0,
|
||||
# tm.swin_s, # fx bad case
|
||||
tm.vgg11,
|
||||
tm.vit_b_16,
|
||||
tm.wide_resnet50_2,
|
||||
]
|
||||
|
||||
tmm_models = [
|
||||
tmm.beit_base_patch16_224,
|
||||
tmm.beitv2_base_patch16_224,
|
||||
tmm.cait_s24_224,
|
||||
tmm.coat_lite_mini,
|
||||
tmm.convit_base,
|
||||
tmm.deit3_base_patch16_224,
|
||||
tmm.dm_nfnet_f0,
|
||||
tmm.eca_nfnet_l0,
|
||||
tmm.efficientformer_l1,
|
||||
tmm.ese_vovnet19b_dw,
|
||||
tmm.gmixer_12_224,
|
||||
tmm.gmlp_b16_224,
|
||||
tmm.hardcorenas_a,
|
||||
tmm.hrnet_w18_small,
|
||||
tmm.inception_v3,
|
||||
tmm.mixer_b16_224,
|
||||
tmm.nf_ecaresnet101,
|
||||
tmm.nf_regnet_b0,
|
||||
# tmm.pit_b_224, # pretrained only
|
||||
tmm.regnetv_040,
|
||||
tmm.skresnet18,
|
||||
# tmm.swin_base_patch4_window7_224, # fx bad case
|
||||
# tmm.tnt_b_patch16_224, # bad case
|
||||
tmm.vgg11,
|
||||
tmm.vit_base_patch16_18x2_224,
|
||||
tmm.wide_resnet50_2,
|
||||
]
|
|
@ -0,0 +1,82 @@
|
|||
from typing import Any, Callable, Union
|
||||
import pytest
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
try:
|
||||
from colossalai._analyzer._subclasses import MetaTensor
|
||||
except:
|
||||
pass
|
||||
|
||||
aten = torch.ops.aten
|
||||
|
||||
registered_meta = {
|
||||
('aten.convolution.default', True): [ # (aten ops, requires_backward)
|
||||
(nn.Conv1d(in_channels=3, out_channels=4, kernel_size=2, padding=1, dilation=2), torch.rand(2, 3, 4)),
|
||||
(nn.Conv2d(in_channels=3, out_channels=4, kernel_size=2, padding=1, dilation=2), torch.rand(2, 3, 4, 4)),
|
||||
(nn.Conv3d(in_channels=3, out_channels=4, kernel_size=2, padding=1, dilation=2), torch.rand(2, 3, 4, 4, 4)),
|
||||
(nn.ConvTranspose1d(in_channels=3, out_channels=4, kernel_size=2, padding=1, dilation=2), torch.rand(2, 3, 4)),
|
||||
(nn.ConvTranspose2d(in_channels=3, out_channels=4, kernel_size=2, padding=1,
|
||||
dilation=2), torch.rand(2, 3, 4, 4)),
|
||||
(nn.ConvTranspose3d(in_channels=3, out_channels=4, kernel_size=2, padding=1,
|
||||
dilation=2), torch.rand(2, 3, 4, 4, 4)),
|
||||
],
|
||||
('aten.native_batch_norm.default', True): [
|
||||
(nn.BatchNorm1d(4), torch.rand(2, 4)),
|
||||
(nn.BatchNorm2d(4), torch.rand(1, 4, 4, 4)),
|
||||
(nn.BatchNorm3d(4), torch.rand(1, 4, 4, 4, 4)),
|
||||
],
|
||||
('aten.native_layer_norm.default', True): [(nn.LayerNorm(4), torch.rand(1, 2, 3, 4)),],
|
||||
('aten.avg_pool1d.default', True): [
|
||||
(nn.MaxPool1d(3, stride=2), torch.rand(4, 5, 5)),
|
||||
(nn.AvgPool1d(3, stride=2), torch.rand(4, 5, 5)),
|
||||
(nn.AdaptiveMaxPool1d(3), torch.rand(4, 5, 5)),
|
||||
(nn.AdaptiveAvgPool1d(3), torch.rand(4, 5, 5)),
|
||||
],
|
||||
('aten.avg_pool2d.default', True): [
|
||||
(nn.MaxPool2d((3, 2), stride=(2, 1)), torch.rand(2, 4, 5, 5)),
|
||||
(nn.AvgPool2d((3, 2), stride=(2, 1)), torch.rand(2, 4, 5, 5)),
|
||||
(nn.AdaptiveMaxPool2d((3, 2)), torch.rand(2, 4, 5, 5)),
|
||||
(nn.AdaptiveAvgPool2d((3, 2)), torch.rand(2, 4, 5, 5)),
|
||||
],
|
||||
('aten.relu.default', True): [
|
||||
(nn.ReLU(), torch.rand(4, 3, 1, 2)),
|
||||
(nn.LeakyReLU(), torch.rand(4, 3, 1, 2)),
|
||||
(nn.SiLU(), torch.rand(4, 3, 1, 2)),
|
||||
(nn.GELU(), torch.rand(4, 3, 1, 2)),
|
||||
(nn.ELU(), torch.rand(4, 3, 1, 2)),
|
||||
(nn.Sigmoid(), torch.rand(4, 3, 1, 2)),
|
||||
(nn.Tanh(), torch.rand(4, 3, 1, 2)),
|
||||
(nn.Hardswish(), torch.rand(4, 3, 1, 2)),
|
||||
]
|
||||
}
|
||||
|
||||
|
||||
def compare_all(tensor: torch.Tensor, meta_tensor: torch.Tensor) -> Any:
|
||||
assert tensor.shape == meta_tensor.shape, f'the shape of tensor ({tensor.shape}) and meta tensor ({meta_tensor.shape}) does not match.'
|
||||
assert tensor.dtype == meta_tensor.dtype, f'the dtype of tensor ({tensor.dtype}) and meta tensor ({meta_tensor.dtype}) does not match.'
|
||||
assert tensor.stride() == meta_tensor.stride(
|
||||
), f'the stride of tensor ({tensor.stride()}) and meta tensor ({meta_tensor.stride()}) does not match.'
|
||||
|
||||
|
||||
def run_and_compare(f: Union[nn.Module, Callable], x: torch.Tensor, requires_backward=False) -> Any:
|
||||
x.requires_grad = requires_backward
|
||||
meta_x = MetaTensor(x)
|
||||
x_out, meta_out = f(x), f(meta_x)
|
||||
compare_all(x_out, meta_out)
|
||||
if requires_backward:
|
||||
x_out.sum().backward()
|
||||
meta_out.sum().backward()
|
||||
compare_all(x.grad, meta_x.grad)
|
||||
|
||||
|
||||
@pytest.mark.skipif(torch.__version__ < '1.12.0', reason='torch version < 12')
|
||||
def test_meta_aten():
|
||||
for (aten_op, requires_backward), v in registered_meta.items():
|
||||
for f, x in v:
|
||||
run_and_compare(f, x, requires_backward)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test_meta_aten()
|
|
@ -0,0 +1,50 @@
|
|||
import pytest
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import torchvision.models as tm
|
||||
from .zoo import tm_models, tmm_models
|
||||
|
||||
try:
|
||||
from colossalai._analyzer._subclasses import MetaTensorMode, flop_count
|
||||
except:
|
||||
pass
|
||||
|
||||
|
||||
@pytest.mark.skipif(torch.__version__ < '1.12.0', reason='torch version < 12')
|
||||
@pytest.mark.parametrize('m', tm_models + tmm_models)
|
||||
def test_flop_count_module(m):
|
||||
x = torch.rand(2, 3, 224, 224)
|
||||
with MetaTensorMode(): # save time for testing
|
||||
module = m()
|
||||
rs_fwd, rs_bwd = flop_count(module, x, verbose=True)
|
||||
assert rs_fwd > 0, f'fwd flop count of {m.__name__} is {rs_fwd}'
|
||||
assert rs_bwd > 0, f'bwd flop count of {m.__name__} is {rs_bwd}'
|
||||
|
||||
|
||||
odd_cases = [
|
||||
(F.relu, (torch.rand(2, 3, 224, 224, requires_grad=True),), {
|
||||
'inplace': True
|
||||
}),
|
||||
(F.max_pool2d, (torch.rand(2, 3, 224, 224, requires_grad=True),), {
|
||||
'kernel_size': 3,
|
||||
'stride': 2,
|
||||
'padding': 1,
|
||||
'dilation': 2
|
||||
}),
|
||||
(torch.where, (torch.rand(2, 3, 224, 224) > 0.5, torch.rand(2, 3, 224, 224, requires_grad=True),
|
||||
torch.rand(2, 3, 224, 224, requires_grad=True)), {}),
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.skipif(torch.__version__ < '1.12.0', reason='torch version < 12')
|
||||
@pytest.mark.parametrize('func, args, kwargs', odd_cases)
|
||||
def test_flop_count_function(func, args, kwargs):
|
||||
rs_fwd, rs_bwd = flop_count(func, *args, **kwargs, verbose=True)
|
||||
assert rs_fwd > 0, f'fwd flop count of {func.__name__} is {rs_fwd}'
|
||||
assert rs_bwd > 0, f'bwd flop count of {func.__name__} is {rs_bwd}'
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test_flop_count_module(tm.resnet18, torch.rand(2, 3, 224, 224))
|
||||
test_flop_count_function(F.relu, (torch.rand(2, 3, 224, 224, requires_grad=True),), {'inplace': True})
|
|
@ -0,0 +1,38 @@
|
|||
import pytest
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
import torchvision.models as tm
|
||||
try:
|
||||
from colossalai._analyzer._subclasses import MetaTensor, MetaTensorMode
|
||||
except:
|
||||
pass
|
||||
from .zoo import tm_models, tmm_models
|
||||
|
||||
|
||||
def compare_all(tensor: torch.Tensor, meta_tensor: torch.Tensor):
|
||||
assert tensor.shape == meta_tensor.shape, f'the shape of tensor ({tensor.shape}) and meta tensor ({meta_tensor.shape}) does not match.'
|
||||
assert tensor.dtype == meta_tensor.dtype, f'the dtype of tensor ({tensor.dtype}) and meta tensor ({meta_tensor.dtype}) does not match.'
|
||||
assert tensor.stride() == meta_tensor.stride(
|
||||
), f'the stride of tensor ({tensor.stride()}) and meta tensor ({meta_tensor.stride()}) does not match.'
|
||||
|
||||
|
||||
def run_and_compare(model):
|
||||
x = torch.rand(2, 3, 224, 224, requires_grad=True)
|
||||
x_out = model(x)
|
||||
with MetaTensorMode():
|
||||
meta_x = torch.rand(2, 3, 224, 224, requires_grad=True)
|
||||
meta_out = model(meta_x)
|
||||
compare_all(x_out, meta_out)
|
||||
x_out.sum().backward()
|
||||
meta_out.sum().backward()
|
||||
compare_all(x.grad, meta_x.grad)
|
||||
|
||||
|
||||
@pytest.mark.skipif(torch.__version__ < '1.12.0', reason='torch version < 12')
|
||||
@pytest.mark.parametrize('m', tm_models + tmm_models)
|
||||
def test_meta_mode_shape(m):
|
||||
run_and_compare(m())
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test_meta_mode_shape(tm.resnet18)
|
|
@ -0,0 +1,53 @@
|
|||
import timm.models as tmm
|
||||
import torchvision.models as tm
|
||||
|
||||
# input shape: (batch_size, 3, 224, 224)
|
||||
tm_models = [
|
||||
tm.alexnet,
|
||||
tm.convnext_base,
|
||||
tm.densenet121,
|
||||
# tm.efficientnet_v2_s,
|
||||
# tm.googlenet, # output bad case
|
||||
# tm.inception_v3, # bad case
|
||||
tm.mobilenet_v2,
|
||||
tm.mobilenet_v3_small,
|
||||
tm.mnasnet0_5,
|
||||
tm.resnet18,
|
||||
tm.regnet_x_16gf,
|
||||
tm.resnext50_32x4d,
|
||||
tm.shufflenet_v2_x0_5,
|
||||
tm.squeezenet1_0,
|
||||
# tm.swin_s, # fx bad case
|
||||
tm.vgg11,
|
||||
tm.vit_b_16,
|
||||
tm.wide_resnet50_2,
|
||||
]
|
||||
|
||||
tmm_models = [
|
||||
tmm.beit_base_patch16_224,
|
||||
tmm.beitv2_base_patch16_224,
|
||||
tmm.cait_s24_224,
|
||||
tmm.coat_lite_mini,
|
||||
tmm.convit_base,
|
||||
tmm.deit3_base_patch16_224,
|
||||
tmm.dm_nfnet_f0,
|
||||
tmm.eca_nfnet_l0,
|
||||
tmm.efficientformer_l1,
|
||||
tmm.ese_vovnet19b_dw,
|
||||
tmm.gmixer_12_224,
|
||||
tmm.gmlp_b16_224,
|
||||
tmm.hardcorenas_a,
|
||||
tmm.hrnet_w18_small,
|
||||
tmm.inception_v3,
|
||||
tmm.mixer_b16_224,
|
||||
tmm.nf_ecaresnet101,
|
||||
tmm.nf_regnet_b0,
|
||||
# tmm.pit_b_224, # pretrained only
|
||||
tmm.regnetv_040,
|
||||
tmm.skresnet18,
|
||||
# tmm.swin_base_patch4_window7_224, # fx bad case
|
||||
# tmm.tnt_b_patch16_224, # bad case
|
||||
tmm.vgg11,
|
||||
tmm.vit_base_patch16_18x2_224,
|
||||
tmm.wide_resnet50_2,
|
||||
]
|
Loading…
Reference in New Issue