[analyzer] a minimal implementation of static graph analyzer (#2852)

* [hotfix] meta tensor default device.

* [siu] add experimental submodules to main branch.

* [siu]

* [analyzer] init.

* [analyzer] readme.

* [test] add test.

* Update symbolic_trace.py

* mark skip tests.

* try except.

* s

* init

* fix

* skip

---------

Co-authored-by: Daniel Shao <superdainiu@MININT-PVARVID.fareast.corp.microsoft.com>
Co-authored-by: Daniel Shao <superdainiu@Daniels-Mac.local>
Super Daniel 2023-03-10 13:21:05 +08:00 committed by GitHub
parent 5d5f475d75
commit fff98f06ed
32 changed files with 4471 additions and 1 deletion


@@ -0,0 +1,306 @@
# Analyzer
# Overview
The Analyzer is a collection of static graph utilities, including Colossal-AI FX. Features include:
- MetaTensor -- enabling:
- Ahead-of-time Profiling
- Shape Propagation
- Ideal Flop Counter
- symbolic_trace()
- Robust Control-flow Tracing / Recompile
- Robust Activation Checkpoint Tracing / CodeGen
- Easy-to-define Bias-Addition Split
- symbolic_profile()
- Support ``MetaTensorMode``, where all Tensor operations are executed symbolically.
- Shape Inference Across Device and Unified ``MetaInfo``
- Ideal Flop Counter https://dev-discuss.pytorch.org/t/the-ideal-pytorch-flop-counter-with-torch-dispatch/505
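As a quick taste of the ideal FLOP counter listed above, the sketch below combines ``flop_count`` with a small model. The import path is an assumption (it depends on where the analyzer's ``_subclasses`` package lands in the Colossal-AI tree); ``flop_count`` wraps its inputs in meta tensors internally, so no real device computation is performed.
```python
import torch
import torch.nn as nn

# Hypothetical import path -- adjust to your Colossal-AI installation.
from colossalai._analyzer._subclasses import flop_count

model = nn.Sequential(nn.Linear(128, 256), nn.GELU(), nn.Linear(256, 10))
sample = torch.rand(32, 128)

# Returns (forward FLOPs, backward FLOPs); verbose=True prints a per-module breakdown.
fwd_flops, bwd_flops = flop_count(model, sample, verbose=True)
print(f"forward: {fwd_flops}, backward: {bwd_flops}")
```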
# Quickstart
## Analyzer.FX
**Reference:**
https://pytorch.org/docs/stable/fx.html [[paper](https://arxiv.org/pdf/2112.08429)]
torch.FX is a toolkit for developers to transform ``nn.Module`` instances. FX consists of three main components: a symbolic tracer, an intermediate representation, and Python code generation. FX.Tracer overrides ``__torch_function__`` and uses a ``Proxy`` object to propagate through the forward function of any ``torch.nn.Module``.
![image](https://user-images.githubusercontent.com/78588128/212531495-bbb934dd-dbbb-4578-8869-6171973f7dd8.png)
Colossal-AI FX is modified from torch.FX, with the extra capability of ahead-of-time profiling enabled by its ``MetaTensor`` subclass.
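For readers new to torch.FX, here is a minimal sketch of its three components using the stock PyTorch API (no Colossal-AI specifics involved):
```python
import torch
import torch.fx


class TinyModel(torch.nn.Module):

    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(8, 4)

    def forward(self, x):
        return torch.relu(self.linear(x))


gm = torch.fx.symbolic_trace(TinyModel())    # symbolic tracer
print(gm.graph)    # intermediate representation
print(gm.code)     # regenerated Python code
```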
### Analyzer.FX.symbolic_trace()
A drawback of the original torch.FX implementation is that it handles control flow poorly. Control flow is not a PyTorch-native operand, and tracing through it requires actual instances that specify which branch to execute. For example,
```python
class MyModule(nn.Module):
def forward(self, x):
if x.dim() == 3:
return x * 2 + 1
else:
return x - 5
```
The forward function above has the following computation graph:
![image](https://user-images.githubusercontent.com/78588128/212532631-dba30734-577b-4418-8dc9-004d7983abc5.png)
However, since a Proxy does not hold concrete data, applying ``x.dim()`` will return nothing. In the context of the auto-parallel system, at least the control-flow dependencies on tensor shape should be removed, since any searched strategy can only auto-parallelize a specific computation graph with the same tensor shape. It is natural to attach concrete data to a Proxy and propagate it through the control flow.
![image](https://user-images.githubusercontent.com/78588128/212533403-1b620986-1c3a-420a-87c6-d08c9702135d.png)
With ``MetaTensor``, the computation during shape propagation can be virtualized. This speeds up tracing by avoiding allocating actual memory on devices.
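For instance, the branch in ``MyModule`` above can be resolved at trace time by supplying a meta argument. The snippet below is a sketch: the import path of ``symbolic_trace`` is assumed, and only the taken branch ends up in the traced graph.
```python
import torch

# Hypothetical import path -- adjust to your Colossal-AI installation.
from colossalai._analyzer.fx import symbolic_trace

model = MyModule()
# The sample carries shape information, so ``x.dim() == 3`` is decidable during tracing.
sample = torch.rand(4, 8, 16)
gm = symbolic_trace(model, meta_args=dict(x=sample))
print(gm.code)    # contains the ``x * 2 + 1`` branch only
```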
#### Remarks
There is no free lunch: PyTorch cannot unify all operands across its own repo and the other repos in its ecosystem. For example, the einops library currently has no intention of supporting torch.FX (see https://github.com/arogozhnikov/einops/issues/188). To support different PyTorch-based libraries without modifying their source code, a good practice is to let users register their own implementations to substitute for the functions torch.FX does not support, or to avoid entering incompatible submodules altogether.
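Both practices can be illustrated with the stock torch.FX API (the analyzer may expose its own registration hooks, which are not shown here):
```python
import torch
import torch.fx


# Practice 1: register a substitute so the tracer records the call as a single
# node instead of tracing into an unsupported function.
@torch.fx.wrap
def rearrange_like(x):
    # stand-in for an incompatible third-party function, e.g. einops.rearrange
    return x.reshape(x.shape[0], -1)


# Practice 2: avoid entering an incompatible submodule by marking it as a leaf.
class SkipIncompatibleTracer(torch.fx.Tracer):

    def is_leaf_module(self, m, module_qualified_name):
        if m.__class__.__name__ == "ThirdPartyBlock":    # hypothetical module name
            return True
        return super().is_leaf_module(m, module_qualified_name)
```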
### Analyzer.FX.symbolic_profile()
``symbolic_profile`` is another important feature of Colossal-AI's auto-parallel system. Profiling a DNN can be costly, as it requires allocating memory and executing on real devices. However, the profiling requirements of auto-parallel are met as long as we can detect when and where the intermediate activations (i.e. tensors) are generated, so the whole procedure can be profiled without actually executing it. ``symbolic_profile``, as its name implies, profiles the whole network using symbolic information only.
```python
with MetaTensorMode():
    model = MyModule().cuda()
    sample = torch.rand(100, 3, 224, 224).cuda()
meta_args = dict(
    x=sample,
)
gm = symbolic_trace(model, meta_args=meta_args)
gm = symbolic_profile(gm, sample)
```
``symbolic_profile`` is enabled by ``ShapeProp`` and ``GraphProfiler``.
#### ShapeProp
Both Tensor Parallel and Activation Checkpoint solvers need to know the shape information ahead of time. Unlike PyTorch's implementation, this ``ShapeProp`` can be executed under MetaTensorMode. With this, all the preparation for auto-parallel solvers can be done in milliseconds.
Meanwhile, it is easy to keep track of the memory usage of each node during shape propagation. However, a drawback of FX is that not every ``call_function`` saves its input for backward, and different tensors that flow within one FX.Graph can actually alias the same memory. This raises problems for fine-grained profiling.
![image](https://user-images.githubusercontent.com/78588128/215312957-7eb6cbc3-61b2-49cf-95a4-6b859149eb8d.png)
To address this problem, I came up with a simulated environment enabled by ``torch.autograd.graph.saved_tensors_hooks`` and a fake ``data_ptr`` (check ``_subclasses/meta_tensor.py`` for more details of the ``data_ptr`` updates).
```python
from typing import Dict, Optional

import torch
from torch.autograd.graph import saved_tensors_hooks


class sim_env(saved_tensors_hooks):
    """
    A simulation of memory allocation and deallocation in the forward pass
    using ``saved_tensors_hooks``.

    Attributes:
        ctx (Dict[int, torch.Tensor]): A dictionary that maps the
            data pointer of a tensor to the tensor itself. This is used
            to track the memory allocation and deallocation.
        param_ctx (Dict[int, torch.Tensor]): A dictionary that maps the
            data pointer of all model parameters to the parameter itself.
            This avoids overestimating the memory usage of the intermediate activations.
    """

    def __init__(self, module: Optional[torch.nn.Module] = None):
        super().__init__(self.pack_hook, self.unpack_hook)
        self.ctx = {}
        self.param_ctx = {param.data_ptr(): param for param in module.parameters()}
        self.buffer_ctx = {buffer.data_ptr(): buffer for buffer in module.buffers()} if module else {}

    def pack_hook(self, tensor: torch.Tensor):
        if tensor.data_ptr() not in self.param_ctx and tensor.data_ptr() not in self.buffer_ctx:
            self.ctx[tensor.data_ptr()] = tensor
        return tensor

    def unpack_hook(self, tensor):
        return tensor
```
The ``ctx`` variable will keep track of all saved tensors with a unique identifier. It is likely that ``nn.Parameter`` is also counted in the ``ctx``, which is not desired. To avoid this, we can use ``param_ctx`` to keep track of all parameters in the model. The ``buffer_ctx`` is used to keep track of all buffers in the model. The ``local_ctx`` that is attached to each ``Node`` marks the memory usage of the stage to which the node belongs. With simple ``intersect``, ``union`` and ``subtract`` operations, we can get any memory-related information. For non-profileable nodes, you might add your customized profile rules to simulate the memory allocation. If a ``Graph`` is modified with some non-PyTorch functions, such as fused operands, you can register the shape propagation rule with the decorator.
```python
@register_shape_impl(fuse_conv_bn)
def fuse_conv_bn_shape_impl(*args, **kwargs):
    # infer output shape here
    return torch.empty(output_shape, device=output_device)
```
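For the ``intersect``/``union``/``subtract`` bookkeeping mentioned above, a minimal sketch of how activation memory could be derived from the tracked pointers (the helper and its arguments are illustrative, not the analyzer's actual API):
```python
def activation_memory(saved: dict, params: dict, buffers: dict) -> int:
    """``saved``/``params``/``buffers`` map data_ptr -> tensor, as collected by ``sim_env``."""
    activation_ptrs = set(saved) - set(params) - set(buffers)    # subtract parameters and buffers
    return sum(saved[p].numel() * saved[p].element_size() for p in activation_ptrs)
```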
An important note is that ``ShapeProp`` attaches additional information to the graph, which is exactly the input of the ``Profiler``.
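A sketch of how that attached information could be inspected afterwards (the exact keys under ``node.meta`` depend on the analyzer's implementation; ``tensor_meta`` is the conventional torch.fx name and is assumed here):
```python
for node in gm.graph.nodes:
    meta = node.meta.get('tensor_meta', None)
    if meta is not None:
        print(node.op, node.name, meta.shape, meta.dtype)
```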
#### GraphProfiler
``GraphProfiler`` executes at the node level, and profiles both forward and backward within one node. For example, ``FlopProfiler`` will profile the forward and backward FLOPs of a node, and ``CommunicationProfiler`` will profile the forward and backward communication cost of a node. The ``GraphProfiler`` will attach the profiling results to the ``Node``. These procedures are decoupled for better extensibility.
To get a general overview of the profiled results, you can also set ``verbose=True`` to print a summary.
```python
model = tm.resnet18()
sample = torch.rand(100, 3, 224, 224)
meta_args = dict(x=sample)
gm = symbolic_trace(model, meta_args=meta_args)
gm = symbolic_profile(gm, sample, verbose=True)
============================================================ Results =====================================================================
Op type Op Accumulate size Incremental size Output size Temp size Param size Backward size Fwd FLOPs Bwd FLOPs
------------- ---------------------------------------------- ----------------- ------------------ ------------- ----------- ------------ --------------- ------------- -------------
placeholder x 4.59 Mb 0 b 4.59 Mb 0 b 0 b 0 b 0 FLOPs 0 FLOPs
call_module conv_proj 4.59 Mb 0 b 0 b 4.59 Mb 2.25 Mb 4.59 Mb 924.84 MFLOPs 924.84 MFLOPs
call_method reshape 4.59 Mb 0 b 0 b 4.59 Mb 0 b 4.59 Mb 0 FLOPs 0 FLOPs
call_method permute 4.59 Mb 0 b 0 b 4.59 Mb 0 b 4.59 Mb 0 FLOPs 0 FLOPs
get_attr class_token 4.59 Mb 0 b 0 b 0 b 0 b 0 b 0 FLOPs 0 FLOPs
call_method expand 4.59 Mb 0 b 0 b 24.00 Kb 3.00 Kb 0 b 0 FLOPs 6.14 kFLOPs
call_function cat 4.59 Mb 0 b 0 b 4.62 Mb 0 b 0 b 0 FLOPs 0 FLOPs
get_attr encoder_pos_embedding 4.59 Mb 0 b 0 b 0 b 0 b 0 b 0 FLOPs 0 FLOPs
call_function add 9.21 Mb 4.62 Mb 4.62 Mb 0 b 591.00 Kb 4.62 Mb 1.21 MFLOPs 1.21 MFLOPs
call_module encoder_dropout 9.21 Mb 0 b 4.62 Mb 0 b 0 b 4.62 Mb 0 FLOPs 0 FLOPs
call_module encoder_layers_encoder_layer_0_ln_1 9.22 Mb 12.31 Kb 0 b 4.62 Mb 6.00 Kb 4.62 Mb 6.05 MFLOPs 6.05 MFLOPs
call_module encoder_layers_encoder_layer_0_self_attention 46.52 Mb 37.30 Mb 0 b 4.62 Mb 9.01 Mb 13.85 Mb 4.20 GFLOPs 8.40 GFLOPs
call_function getitem 46.52 Mb 0 b 0 b 4.62 Mb 0 b 0 b 0 FLOPs 0 FLOPs
call_function getitem_1 46.52 Mb 0 b 0 b 0 b 0 b 0 b 0 FLOPs 0 FLOPs
call_module encoder_layers_encoder_layer_0_dropout 46.52 Mb 0 b 0 b 4.62 Mb 0 b 4.62 Mb 0 FLOPs 0 FLOPs
call_function add_1 51.14 Mb 4.62 Mb 4.62 Mb 0 b 0 b 9.23 Mb 1.21 MFLOPs 0 FLOPs
call_module encoder_layers_encoder_layer_0_ln_2 51.15 Mb 12.31 Kb 0 b 4.62 Mb 6.00 Kb 4.62 Mb 6.05 MFLOPs 6.05 MFLOPs
call_module encoder_layers_encoder_layer_0_mlp_0 74.24 Mb 23.09 Mb 18.47 Mb 0 b 9.01 Mb 4.62 Mb 3.72 GFLOPs 7.44 GFLOPs
call_module encoder_layers_encoder_layer_0_mlp_1 92.71 Mb 18.47 Mb 18.47 Mb 0 b 0 b 18.47 Mb 4.84 MFLOPs 4.84 MFLOPs
call_module encoder_layers_encoder_layer_0_mlp_2 92.71 Mb 0 b 18.47 Mb 0 b 0 b 18.47 Mb 0 FLOPs 0 FLOPs
call_module encoder_layers_encoder_layer_0_mlp_3 92.71 Mb 0 b 0 b 4.62 Mb 9.00 Mb 18.47 Mb 3.72 GFLOPs 7.44 GFLOPs
call_module encoder_layers_encoder_layer_0_mlp_4 92.71 Mb 0 b 0 b 4.62 Mb 0 b 4.62 Mb 0 FLOPs 0 FLOPs
call_function add_2 97.32 Mb 4.62 Mb 4.62 Mb 0 b 0 b 9.23 Mb 1.21 MFLOPs 0 FLOPs
call_module encoder_layers_encoder_layer_1_ln_1 101.95 Mb 4.63 Mb 4.62 Mb 0 b 6.00 Kb 4.62 Mb 6.05 MFLOPs 6.05 MFLOPs
call_module encoder_layers_encoder_layer_1_self_attention 134.63 Mb 32.68 Mb 0 b 4.62 Mb 9.01 Mb 13.85 Mb 4.20 GFLOPs 8.40 GFLOPs
call_function getitem_2 134.63 Mb 0 b 0 b 4.62 Mb 0 b 0 b 0 FLOPs 0 FLOPs
call_function getitem_3 134.63 Mb 0 b 0 b 0 b 0 b 0 b 0 FLOPs 0 FLOPs
call_module encoder_layers_encoder_layer_1_dropout 134.63 Mb 0 b 0 b 4.62 Mb 0 b 4.62 Mb 0 FLOPs 0 FLOPs
call_function add_3 139.25 Mb 4.62 Mb 4.62 Mb 0 b 0 b 9.23 Mb 1.21 MFLOPs 0 FLOPs
call_module encoder_layers_encoder_layer_1_ln_2 139.26 Mb 12.31 Kb 0 b 4.62 Mb 6.00 Kb 4.62 Mb 6.05 MFLOPs 6.05 MFLOPs
call_module encoder_layers_encoder_layer_1_mlp_0 162.35 Mb 23.09 Mb 18.47 Mb 0 b 9.01 Mb 4.62 Mb 3.72 GFLOPs 7.44 GFLOPs
call_module encoder_layers_encoder_layer_1_mlp_1 180.82 Mb 18.47 Mb 18.47 Mb 0 b 0 b 18.47 Mb 4.84 MFLOPs 4.84 MFLOPs
call_module encoder_layers_encoder_layer_1_mlp_2 180.82 Mb 0 b 18.47 Mb 0 b 0 b 18.47 Mb 0 FLOPs 0 FLOPs
call_module encoder_layers_encoder_layer_1_mlp_3 180.82 Mb 0 b 0 b 4.62 Mb 9.00 Mb 18.47 Mb 3.72 GFLOPs 7.44 GFLOPs
call_module encoder_layers_encoder_layer_1_mlp_4 180.82 Mb 0 b 0 b 4.62 Mb 0 b 4.62 Mb 0 FLOPs 0 FLOPs
call_function add_4 185.43 Mb 4.62 Mb 4.62 Mb 0 b 0 b 9.23 Mb 1.21 MFLOPs 0 FLOPs
call_module encoder_layers_encoder_layer_2_ln_1 190.06 Mb 4.63 Mb 4.62 Mb 0 b 6.00 Kb 4.62 Mb 6.05 MFLOPs 6.05 MFLOPs
call_module encoder_layers_encoder_layer_2_self_attention 222.74 Mb 32.68 Mb 0 b 4.62 Mb 9.01 Mb 13.85 Mb 4.20 GFLOPs 8.40 GFLOPs
call_function getitem_4 222.74 Mb 0 b 0 b 4.62 Mb 0 b 0 b 0 FLOPs 0 FLOPs
call_function getitem_5 222.74 Mb 0 b 0 b 0 b 0 b 0 b 0 FLOPs 0 FLOPs
call_module encoder_layers_encoder_layer_2_dropout 222.74 Mb 0 b 0 b 4.62 Mb 0 b 4.62 Mb 0 FLOPs 0 FLOPs
call_function add_5 227.36 Mb 4.62 Mb 4.62 Mb 0 b 0 b 9.23 Mb 1.21 MFLOPs 0 FLOPs
call_module encoder_layers_encoder_layer_2_ln_2 227.37 Mb 12.31 Kb 0 b 4.62 Mb 6.00 Kb 4.62 Mb 6.05 MFLOPs 6.05 MFLOPs
call_module encoder_layers_encoder_layer_2_mlp_0 250.46 Mb 23.09 Mb 18.47 Mb 0 b 9.01 Mb 4.62 Mb 3.72 GFLOPs 7.44 GFLOPs
call_module encoder_layers_encoder_layer_2_mlp_1 268.93 Mb 18.47 Mb 18.47 Mb 0 b 0 b 18.47 Mb 4.84 MFLOPs 4.84 MFLOPs
call_module encoder_layers_encoder_layer_2_mlp_2 268.93 Mb 0 b 18.47 Mb 0 b 0 b 18.47 Mb 0 FLOPs 0 FLOPs
call_module encoder_layers_encoder_layer_2_mlp_3 268.93 Mb 0 b 0 b 4.62 Mb 9.00 Mb 18.47 Mb 3.72 GFLOPs 7.44 GFLOPs
call_module encoder_layers_encoder_layer_2_mlp_4 268.93 Mb 0 b 0 b 4.62 Mb 0 b 4.62 Mb 0 FLOPs 0 FLOPs
call_function add_6 273.54 Mb 4.62 Mb 4.62 Mb 0 b 0 b 9.23 Mb 1.21 MFLOPs 0 FLOPs
call_module encoder_layers_encoder_layer_3_ln_1 278.17 Mb 4.63 Mb 4.62 Mb 0 b 6.00 Kb 4.62 Mb 6.05 MFLOPs 6.05 MFLOPs
call_module encoder_layers_encoder_layer_3_self_attention 310.86 Mb 32.68 Mb 0 b 4.62 Mb 9.01 Mb 13.85 Mb 4.20 GFLOPs 8.40 GFLOPs
call_function getitem_6 310.86 Mb 0 b 0 b 4.62 Mb 0 b 0 b 0 FLOPs 0 FLOPs
call_function getitem_7 310.86 Mb 0 b 0 b 0 b 0 b 0 b 0 FLOPs 0 FLOPs
call_module encoder_layers_encoder_layer_3_dropout 310.86 Mb 0 b 0 b 4.62 Mb 0 b 4.62 Mb 0 FLOPs 0 FLOPs
call_function add_7 315.47 Mb 4.62 Mb 4.62 Mb 0 b 0 b 9.23 Mb 1.21 MFLOPs 0 FLOPs
call_module encoder_layers_encoder_layer_3_ln_2 315.48 Mb 12.31 Kb 0 b 4.62 Mb 6.00 Kb 4.62 Mb 6.05 MFLOPs 6.05 MFLOPs
call_module encoder_layers_encoder_layer_3_mlp_0 338.57 Mb 23.09 Mb 18.47 Mb 0 b 9.01 Mb 4.62 Mb 3.72 GFLOPs 7.44 GFLOPs
call_module encoder_layers_encoder_layer_3_mlp_1 357.04 Mb 18.47 Mb 18.47 Mb 0 b 0 b 18.47 Mb 4.84 MFLOPs 4.84 MFLOPs
call_module encoder_layers_encoder_layer_3_mlp_2 357.04 Mb 0 b 18.47 Mb 0 b 0 b 18.47 Mb 0 FLOPs 0 FLOPs
call_module encoder_layers_encoder_layer_3_mlp_3 357.04 Mb 0 b 0 b 4.62 Mb 9.00 Mb 18.47 Mb 3.72 GFLOPs 7.44 GFLOPs
call_module encoder_layers_encoder_layer_3_mlp_4 357.04 Mb 0 b 0 b 4.62 Mb 0 b 4.62 Mb 0 FLOPs 0 FLOPs
call_function add_8 361.66 Mb 4.62 Mb 4.62 Mb 0 b 0 b 9.23 Mb 1.21 MFLOPs 0 FLOPs
call_module encoder_layers_encoder_layer_4_ln_1 366.29 Mb 4.63 Mb 4.62 Mb 0 b 6.00 Kb 4.62 Mb 6.05 MFLOPs 6.05 MFLOPs
call_module encoder_layers_encoder_layer_4_self_attention 398.97 Mb 32.68 Mb 0 b 4.62 Mb 9.01 Mb 13.85 Mb 4.20 GFLOPs 8.40 GFLOPs
call_function getitem_8 398.97 Mb 0 b 0 b 4.62 Mb 0 b 0 b 0 FLOPs 0 FLOPs
call_function getitem_9 398.97 Mb 0 b 0 b 0 b 0 b 0 b 0 FLOPs 0 FLOPs
call_module encoder_layers_encoder_layer_4_dropout 398.97 Mb 0 b 0 b 4.62 Mb 0 b 4.62 Mb 0 FLOPs 0 FLOPs
call_function add_9 403.58 Mb 4.62 Mb 4.62 Mb 0 b 0 b 9.23 Mb 1.21 MFLOPs 0 FLOPs
call_module encoder_layers_encoder_layer_4_ln_2 403.60 Mb 12.31 Kb 0 b 4.62 Mb 6.00 Kb 4.62 Mb 6.05 MFLOPs 6.05 MFLOPs
call_module encoder_layers_encoder_layer_4_mlp_0 426.68 Mb 23.09 Mb 18.47 Mb 0 b 9.01 Mb 4.62 Mb 3.72 GFLOPs 7.44 GFLOPs
call_module encoder_layers_encoder_layer_4_mlp_1 445.15 Mb 18.47 Mb 18.47 Mb 0 b 0 b 18.47 Mb 4.84 MFLOPs 4.84 MFLOPs
call_module encoder_layers_encoder_layer_4_mlp_2 445.15 Mb 0 b 18.47 Mb 0 b 0 b 18.47 Mb 0 FLOPs 0 FLOPs
call_module encoder_layers_encoder_layer_4_mlp_3 445.15 Mb 0 b 0 b 4.62 Mb 9.00 Mb 18.47 Mb 3.72 GFLOPs 7.44 GFLOPs
call_module encoder_layers_encoder_layer_4_mlp_4 445.15 Mb 0 b 0 b 4.62 Mb 0 b 4.62 Mb 0 FLOPs 0 FLOPs
call_function add_10 449.77 Mb 4.62 Mb 4.62 Mb 0 b 0 b 9.23 Mb 1.21 MFLOPs 0 FLOPs
call_module encoder_layers_encoder_layer_5_ln_1 454.40 Mb 4.63 Mb 4.62 Mb 0 b 6.00 Kb 4.62 Mb 6.05 MFLOPs 6.05 MFLOPs
call_module encoder_layers_encoder_layer_5_self_attention 487.08 Mb 32.68 Mb 0 b 4.62 Mb 9.01 Mb 13.85 Mb 4.20 GFLOPs 8.40 GFLOPs
call_function getitem_10 487.08 Mb 0 b 0 b 4.62 Mb 0 b 0 b 0 FLOPs 0 FLOPs
call_function getitem_11 487.08 Mb 0 b 0 b 0 b 0 b 0 b 0 FLOPs 0 FLOPs
call_module encoder_layers_encoder_layer_5_dropout 487.08 Mb 0 b 0 b 4.62 Mb 0 b 4.62 Mb 0 FLOPs 0 FLOPs
call_function add_11 491.70 Mb 4.62 Mb 4.62 Mb 0 b 0 b 9.23 Mb 1.21 MFLOPs 0 FLOPs
call_module encoder_layers_encoder_layer_5_ln_2 491.71 Mb 12.31 Kb 0 b 4.62 Mb 6.00 Kb 4.62 Mb 6.05 MFLOPs 6.05 MFLOPs
call_module encoder_layers_encoder_layer_5_mlp_0 514.79 Mb 23.09 Mb 18.47 Mb 0 b 9.01 Mb 4.62 Mb 3.72 GFLOPs 7.44 GFLOPs
call_module encoder_layers_encoder_layer_5_mlp_1 533.26 Mb 18.47 Mb 18.47 Mb 0 b 0 b 18.47 Mb 4.84 MFLOPs 4.84 MFLOPs
call_module encoder_layers_encoder_layer_5_mlp_2 533.26 Mb 0 b 18.47 Mb 0 b 0 b 18.47 Mb 0 FLOPs 0 FLOPs
call_module encoder_layers_encoder_layer_5_mlp_3 533.26 Mb 0 b 0 b 4.62 Mb 9.00 Mb 18.47 Mb 3.72 GFLOPs 7.44 GFLOPs
call_module encoder_layers_encoder_layer_5_mlp_4 533.26 Mb 0 b 0 b 4.62 Mb 0 b 4.62 Mb 0 FLOPs 0 FLOPs
call_function add_12 537.88 Mb 4.62 Mb 4.62 Mb 0 b 0 b 9.23 Mb 1.21 MFLOPs 0 FLOPs
call_module encoder_layers_encoder_layer_6_ln_1 542.51 Mb 4.63 Mb 4.62 Mb 0 b 6.00 Kb 4.62 Mb 6.05 MFLOPs 6.05 MFLOPs
call_module encoder_layers_encoder_layer_6_self_attention 575.19 Mb 32.68 Mb 0 b 4.62 Mb 9.01 Mb 13.85 Mb 4.20 GFLOPs 8.40 GFLOPs
call_function getitem_12 575.19 Mb 0 b 0 b 4.62 Mb 0 b 0 b 0 FLOPs 0 FLOPs
call_function getitem_13 575.19 Mb 0 b 0 b 0 b 0 b 0 b 0 FLOPs 0 FLOPs
call_module encoder_layers_encoder_layer_6_dropout 575.19 Mb 0 b 0 b 4.62 Mb 0 b 4.62 Mb 0 FLOPs 0 FLOPs
call_function add_13 579.81 Mb 4.62 Mb 4.62 Mb 0 b 0 b 9.23 Mb 1.21 MFLOPs 0 FLOPs
call_module encoder_layers_encoder_layer_6_ln_2 579.82 Mb 12.31 Kb 0 b 4.62 Mb 6.00 Kb 4.62 Mb 6.05 MFLOPs 6.05 MFLOPs
call_module encoder_layers_encoder_layer_6_mlp_0 602.90 Mb 23.09 Mb 18.47 Mb 0 b 9.01 Mb 4.62 Mb 3.72 GFLOPs 7.44 GFLOPs
call_module encoder_layers_encoder_layer_6_mlp_1 621.37 Mb 18.47 Mb 18.47 Mb 0 b 0 b 18.47 Mb 4.84 MFLOPs 4.84 MFLOPs
call_module encoder_layers_encoder_layer_6_mlp_2 621.37 Mb 0 b 18.47 Mb 0 b 0 b 18.47 Mb 0 FLOPs 0 FLOPs
call_module encoder_layers_encoder_layer_6_mlp_3 621.37 Mb 0 b 0 b 4.62 Mb 9.00 Mb 18.47 Mb 3.72 GFLOPs 7.44 GFLOPs
call_module encoder_layers_encoder_layer_6_mlp_4 621.37 Mb 0 b 0 b 4.62 Mb 0 b 4.62 Mb 0 FLOPs 0 FLOPs
call_function add_14 625.99 Mb 4.62 Mb 4.62 Mb 0 b 0 b 9.23 Mb 1.21 MFLOPs 0 FLOPs
call_module encoder_layers_encoder_layer_7_ln_1 630.62 Mb 4.63 Mb 4.62 Mb 0 b 6.00 Kb 4.62 Mb 6.05 MFLOPs 6.05 MFLOPs
call_module encoder_layers_encoder_layer_7_self_attention 663.30 Mb 32.68 Mb 0 b 4.62 Mb 9.01 Mb 13.85 Mb 4.20 GFLOPs 8.40 GFLOPs
call_function getitem_14 663.30 Mb 0 b 0 b 4.62 Mb 0 b 0 b 0 FLOPs 0 FLOPs
call_function getitem_15 663.30 Mb 0 b 0 b 0 b 0 b 0 b 0 FLOPs 0 FLOPs
call_module encoder_layers_encoder_layer_7_dropout 663.30 Mb 0 b 0 b 4.62 Mb 0 b 4.62 Mb 0 FLOPs 0 FLOPs
call_function add_15 667.92 Mb 4.62 Mb 4.62 Mb 0 b 0 b 9.23 Mb 1.21 MFLOPs 0 FLOPs
call_module encoder_layers_encoder_layer_7_ln_2 667.93 Mb 12.31 Kb 0 b 4.62 Mb 6.00 Kb 4.62 Mb 6.05 MFLOPs 6.05 MFLOPs
call_module encoder_layers_encoder_layer_7_mlp_0 691.02 Mb 23.09 Mb 18.47 Mb 0 b 9.01 Mb 4.62 Mb 3.72 GFLOPs 7.44 GFLOPs
call_module encoder_layers_encoder_layer_7_mlp_1 709.48 Mb 18.47 Mb 18.47 Mb 0 b 0 b 18.47 Mb 4.84 MFLOPs 4.84 MFLOPs
call_module encoder_layers_encoder_layer_7_mlp_2 709.48 Mb 0 b 18.47 Mb 0 b 0 b 18.47 Mb 0 FLOPs 0 FLOPs
call_module encoder_layers_encoder_layer_7_mlp_3 709.48 Mb 0 b 0 b 4.62 Mb 9.00 Mb 18.47 Mb 3.72 GFLOPs 7.44 GFLOPs
call_module encoder_layers_encoder_layer_7_mlp_4 709.48 Mb 0 b 0 b 4.62 Mb 0 b 4.62 Mb 0 FLOPs 0 FLOPs
call_function add_16 714.10 Mb 4.62 Mb 4.62 Mb 0 b 0 b 9.23 Mb 1.21 MFLOPs 0 FLOPs
call_module encoder_layers_encoder_layer_8_ln_1 718.73 Mb 4.63 Mb 4.62 Mb 0 b 6.00 Kb 4.62 Mb 6.05 MFLOPs 6.05 MFLOPs
call_module encoder_layers_encoder_layer_8_self_attention 751.41 Mb 32.68 Mb 0 b 4.62 Mb 9.01 Mb 13.85 Mb 4.20 GFLOPs 8.40 GFLOPs
call_function getitem_16 751.41 Mb 0 b 0 b 4.62 Mb 0 b 0 b 0 FLOPs 0 FLOPs
call_function getitem_17 751.41 Mb 0 b 0 b 0 b 0 b 0 b 0 FLOPs 0 FLOPs
call_module encoder_layers_encoder_layer_8_dropout 751.41 Mb 0 b 0 b 4.62 Mb 0 b 4.62 Mb 0 FLOPs 0 FLOPs
call_function add_17 756.03 Mb 4.62 Mb 4.62 Mb 0 b 0 b 9.23 Mb 1.21 MFLOPs 0 FLOPs
call_module encoder_layers_encoder_layer_8_ln_2 756.04 Mb 12.31 Kb 0 b 4.62 Mb 6.00 Kb 4.62 Mb 6.05 MFLOPs 6.05 MFLOPs
call_module encoder_layers_encoder_layer_8_mlp_0 779.13 Mb 23.09 Mb 18.47 Mb 0 b 9.01 Mb 4.62 Mb 3.72 GFLOPs 7.44 GFLOPs
call_module encoder_layers_encoder_layer_8_mlp_1 797.60 Mb 18.47 Mb 18.47 Mb 0 b 0 b 18.47 Mb 4.84 MFLOPs 4.84 MFLOPs
call_module encoder_layers_encoder_layer_8_mlp_2 797.60 Mb 0 b 18.47 Mb 0 b 0 b 18.47 Mb 0 FLOPs 0 FLOPs
call_module encoder_layers_encoder_layer_8_mlp_3 797.60 Mb 0 b 0 b 4.62 Mb 9.00 Mb 18.47 Mb 3.72 GFLOPs 7.44 GFLOPs
call_module encoder_layers_encoder_layer_8_mlp_4 797.60 Mb 0 b 0 b 4.62 Mb 0 b 4.62 Mb 0 FLOPs 0 FLOPs
call_function add_18 802.21 Mb 4.62 Mb 4.62 Mb 0 b 0 b 9.23 Mb 1.21 MFLOPs 0 FLOPs
call_module encoder_layers_encoder_layer_9_ln_1 806.84 Mb 4.63 Mb 4.62 Mb 0 b 6.00 Kb 4.62 Mb 6.05 MFLOPs 6.05 MFLOPs
call_module encoder_layers_encoder_layer_9_self_attention 839.52 Mb 32.68 Mb 0 b 4.62 Mb 9.01 Mb 13.85 Mb 4.20 GFLOPs 8.40 GFLOPs
call_function getitem_18 839.52 Mb 0 b 0 b 4.62 Mb 0 b 0 b 0 FLOPs 0 FLOPs
call_function getitem_19 839.52 Mb 0 b 0 b 0 b 0 b 0 b 0 FLOPs 0 FLOPs
call_module encoder_layers_encoder_layer_9_dropout 839.52 Mb 0 b 0 b 4.62 Mb 0 b 4.62 Mb 0 FLOPs 0 FLOPs
call_function add_19 844.14 Mb 4.62 Mb 4.62 Mb 0 b 0 b 9.23 Mb 1.21 MFLOPs 0 FLOPs
call_module encoder_layers_encoder_layer_9_ln_2 844.15 Mb 12.31 Kb 0 b 4.62 Mb 6.00 Kb 4.62 Mb 6.05 MFLOPs 6.05 MFLOPs
call_module encoder_layers_encoder_layer_9_mlp_0 867.24 Mb 23.09 Mb 18.47 Mb 0 b 9.01 Mb 4.62 Mb 3.72 GFLOPs 7.44 GFLOPs
call_module encoder_layers_encoder_layer_9_mlp_1 885.71 Mb 18.47 Mb 18.47 Mb 0 b 0 b 18.47 Mb 4.84 MFLOPs 4.84 MFLOPs
call_module encoder_layers_encoder_layer_9_mlp_2 885.71 Mb 0 b 18.47 Mb 0 b 0 b 18.47 Mb 0 FLOPs 0 FLOPs
call_module encoder_layers_encoder_layer_9_mlp_3 885.71 Mb 0 b 0 b 4.62 Mb 9.00 Mb 18.47 Mb 3.72 GFLOPs 7.44 GFLOPs
call_module encoder_layers_encoder_layer_9_mlp_4 885.71 Mb 0 b 0 b 4.62 Mb 0 b 4.62 Mb 0 FLOPs 0 FLOPs
call_function add_20 890.32 Mb 4.62 Mb 4.62 Mb 0 b 0 b 9.23 Mb 1.21 MFLOPs 0 FLOPs
call_module encoder_layers_encoder_layer_10_ln_1 894.95 Mb 4.63 Mb 4.62 Mb 0 b 6.00 Kb 4.62 Mb 6.05 MFLOPs 6.05 MFLOPs
call_module encoder_layers_encoder_layer_10_self_attention 927.63 Mb 32.68 Mb 0 b 4.62 Mb 9.01 Mb 13.85 Mb 4.20 GFLOPs 8.40 GFLOPs
call_function getitem_20 927.63 Mb 0 b 0 b 4.62 Mb 0 b 0 b 0 FLOPs 0 FLOPs
call_function getitem_21 927.63 Mb 0 b 0 b 0 b 0 b 0 b 0 FLOPs 0 FLOPs
call_module encoder_layers_encoder_layer_10_dropout 927.63 Mb 0 b 0 b 4.62 Mb 0 b 4.62 Mb 0 FLOPs 0 FLOPs
call_function add_21 932.25 Mb 4.62 Mb 4.62 Mb 0 b 0 b 9.23 Mb 1.21 MFLOPs 0 FLOPs
call_module encoder_layers_encoder_layer_10_ln_2 932.26 Mb 12.31 Kb 0 b 4.62 Mb 6.00 Kb 4.62 Mb 6.05 MFLOPs 6.05 MFLOPs
call_module encoder_layers_encoder_layer_10_mlp_0 955.35 Mb 23.09 Mb 18.47 Mb 0 b 9.01 Mb 4.62 Mb 3.72 GFLOPs 7.44 GFLOPs
call_module encoder_layers_encoder_layer_10_mlp_1 973.82 Mb 18.47 Mb 18.47 Mb 0 b 0 b 18.47 Mb 4.84 MFLOPs 4.84 MFLOPs
call_module encoder_layers_encoder_layer_10_mlp_2 973.82 Mb 0 b 18.47 Mb 0 b 0 b 18.47 Mb 0 FLOPs 0 FLOPs
call_module encoder_layers_encoder_layer_10_mlp_3 973.82 Mb 0 b 0 b 4.62 Mb 9.00 Mb 18.47 Mb 3.72 GFLOPs 7.44 GFLOPs
call_module encoder_layers_encoder_layer_10_mlp_4 973.82 Mb 0 b 0 b 4.62 Mb 0 b 4.62 Mb 0 FLOPs 0 FLOPs
call_function add_22 978.44 Mb 4.62 Mb 4.62 Mb 0 b 0 b 9.23 Mb 1.21 MFLOPs 0 FLOPs
call_module encoder_layers_encoder_layer_11_ln_1 983.06 Mb 4.63 Mb 4.62 Mb 0 b 6.00 Kb 4.62 Mb 6.05 MFLOPs 6.05 MFLOPs
call_module encoder_layers_encoder_layer_11_self_attention 1015.75 Mb 32.68 Mb 0 b 4.62 Mb 9.01 Mb 13.85 Mb 4.20 GFLOPs 8.40 GFLOPs
call_function getitem_22 1015.75 Mb 0 b 0 b 4.62 Mb 0 b 0 b 0 FLOPs 0 FLOPs
call_function getitem_23 1015.75 Mb 0 b 0 b 0 b 0 b 0 b 0 FLOPs 0 FLOPs
call_module encoder_layers_encoder_layer_11_dropout 1015.75 Mb 0 b 0 b 4.62 Mb 0 b 4.62 Mb 0 FLOPs 0 FLOPs
call_function add_23 1020.36 Mb 4.62 Mb 4.62 Mb 0 b 0 b 9.23 Mb 1.21 MFLOPs 0 FLOPs
call_module encoder_layers_encoder_layer_11_ln_2 1020.38 Mb 12.31 Kb 0 b 4.62 Mb 6.00 Kb 4.62 Mb 6.05 MFLOPs 6.05 MFLOPs
call_module encoder_layers_encoder_layer_11_mlp_0 1.02 Gb 23.09 Mb 18.47 Mb 0 b 9.01 Mb 4.62 Mb 3.72 GFLOPs 7.44 GFLOPs
call_module encoder_layers_encoder_layer_11_mlp_1 1.04 Gb 18.47 Mb 18.47 Mb 0 b 0 b 18.47 Mb 4.84 MFLOPs 4.84 MFLOPs
call_module encoder_layers_encoder_layer_11_mlp_2 1.04 Gb 0 b 18.47 Mb 0 b 0 b 18.47 Mb 0 FLOPs 0 FLOPs
call_module encoder_layers_encoder_layer_11_mlp_3 1.04 Gb 0 b 0 b 4.62 Mb 9.00 Mb 18.47 Mb 3.72 GFLOPs 7.44 GFLOPs
call_module encoder_layers_encoder_layer_11_mlp_4 1.04 Gb 0 b 0 b 4.62 Mb 0 b 4.62 Mb 0 FLOPs 0 FLOPs
call_function add_24 1.04 Gb 4.62 Mb 4.62 Mb 0 b 0 b 9.23 Mb 1.21 MFLOPs 0 FLOPs
call_module encoder_ln 1.04 Gb 36.31 Kb 24.00 Kb 0 b 6.00 Kb 4.62 Mb 6.05 MFLOPs 6.05 MFLOPs
call_function getitem_24 1.04 Gb 0 b 24.00 Kb 0 b 0 b 4.62 Mb 0 FLOPs 0 FLOPs
call_module heads_head 1.04 Gb 0 b 0 b 31.25 Kb 2.93 Mb 24.00 Kb 6.14 MFLOPs 12.30 MFLOPs
output output 1.04 Gb 0 b 0 b 31.25 Kb 0 b 31.25 Kb 0 FLOPs 0 FLOPs
```


@@ -0,0 +1,4 @@
from ._meta_registration import *
from ._monkey_patch import *
from .flop_tensor import flop_count, flop_mapping
from .meta_tensor import MetaTensor, MetaTensorMode


@@ -0,0 +1,481 @@
# meta patch from https://github.com/pytorch/pytorch/blob/master/torch/_meta_registrations.py
# should be activated for PyTorch version 1.12.0 and below
# refer to https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/native_functions.yaml
# for more meta_registrations
from typing import Callable, List, Optional, Tuple, Union
import torch
from torch.utils._pytree import tree_map
aten = torch.ops.aten
meta_lib = torch.library.Library("aten", "IMPL", "Meta")
meta_table = {}
orig_empty = torch.empty
orig_empty_strided = torch.empty_strided
orig_empty_like = torch.empty_like
def new(*args, **kwargs):
return orig_empty(*args, **kwargs, device=torch.device('meta'))
def new_strided(*args, **kwargs):
return orig_empty_strided(*args, **kwargs, device=torch.device('meta'))
def new_like(*args, **kwargs):
return orig_empty_like(*args, **kwargs, device=torch.device('meta'))
def register_meta(op, register_dispatcher=True):
def wrapper(f):
def add_func(op):
meta_table[op] = f
if register_dispatcher:
name = (op.__name__ if op._overloadname != "default" else op.overloadpacket.__name__)
try:
meta_lib.impl(name, f)
except:
pass
tree_map(add_func, op)
return f
return wrapper
# ============================== Convolutions ======================================
# https://github.com/pytorch/pytorch/pull/79834
@register_meta(aten.convolution.default)
def meta_conv(
input_tensor: torch.Tensor,
weight: torch.Tensor,
bias: torch.Tensor,
stride: List[int],
padding: List[int],
dilation: List[int],
is_transposed: bool,
output_padding: List[int],
groups: int,
):
def _formula(ln: int, p: int, d: int, k: int, s: int) -> int:
"""
Formula to apply to calculate the length of some dimension of the output
See: https://pytorch.org/docs/stable/generated/torch.nn.Conv2d.html
Args:
ln: length of the dimension
p: padding in that dim
d: dilation in that dim
k: kernel size in that dim
s: stride in that dim
Returns:
The output length
"""
return (ln + 2 * p - d * (k - 1) - 1) // s + 1
def _formula_transposed(ln: int, p: int, d: int, k: int, s: int, op: int) -> int:
"""
Formula to apply to calculate the length of some dimension of the output
if transposed convolution is used.
See: https://pytorch.org/docs/stable/generated/torch.nn.ConvTranspose2d.html
Args:
ln: length of the dimension
p: padding in that dim
d: dilation in that dim
k: kernel size in that dim
s: stride in that dim
op: output padding in that dim
Returns:
The output length
"""
return (ln - 1) * s - 2 * p + d * (k - 1) + op + 1
def calc_conv_nd_return_shape(
dims: torch.Size,
kernel_size: torch.Size,
stride: Union[List[int], int],
padding: Union[List[int], int],
dilation: Union[List[int], int],
output_padding: Optional[Union[List[int], int]] = None,
):
ret_shape = []
if isinstance(stride, int):
stride = [stride] * len(dims)
elif len(stride) == 1:
stride = [stride[0]] * len(dims)
if isinstance(padding, int):
padding = [padding] * len(dims)
elif len(padding) == 1:
padding = [padding[0]] * len(dims)
if isinstance(dilation, int):
dilation = [dilation] * len(dims)
elif len(dilation) == 1:
dilation = [dilation[0]] * len(dims)
output_padding_list: Optional[List[int]] = None
if output_padding:
if isinstance(output_padding, int):
output_padding_list = [output_padding] * len(dims)
elif len(output_padding) == 1:
output_padding_list = [output_padding[0]] * len(dims)
else:
output_padding_list = output_padding
for i in range(len(dims)):
# If output_padding is present, we are dealing with a transposed convolution
if output_padding_list:
ret_shape.append(
_formula_transposed(
dims[i],
padding[i],
dilation[i],
kernel_size[i],
stride[i],
output_padding_list[i],
))
else:
ret_shape.append(_formula(dims[i], padding[i], dilation[i], kernel_size[i], stride[i]))
return ret_shape
def pick_memory_format():
if input_tensor.is_contiguous(memory_format=torch.channels_last):
return torch.channels_last
elif input_tensor.is_contiguous(memory_format=torch.contiguous_format):
return torch.contiguous_format
elif input_tensor.is_contiguous(memory_format=torch.preserve_format):
return torch.preserve_format
kernel_size = weight.shape[2:]
dims = input_tensor.shape[2:]
if is_transposed:
out_channels = groups * weight.shape[1]
shape_out = calc_conv_nd_return_shape(
dims,
kernel_size,
stride,
padding,
dilation,
output_padding,
)
else:
out_channels = weight.shape[0]
if weight.shape[1] != input_tensor.shape[1] / groups:
raise RuntimeError("Invalid channel dimensions")
shape_out = calc_conv_nd_return_shape(dims, kernel_size, stride, padding, dilation)
out = input_tensor.new_empty((input_tensor.shape[0], out_channels, *shape_out))
mem_fmt = pick_memory_format()
out = out.to(memory_format=mem_fmt) # type: ignore[call-overload]
return out
@register_meta(aten._convolution.default)
def meta__conv(input_tensor: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor, stride: List[int],
padding: List[int], dilation: List[int], is_transposed: bool, output_padding: List[int], groups: int,
*extra_args):
out = meta_conv(input_tensor, weight, bias, stride, padding, dilation, is_transposed, output_padding, groups)
return out
@register_meta(aten.convolution_backward.default)
def meta_conv_backward(grad_output: torch.Tensor, input: torch.Tensor, weight: torch.Tensor, bias_sizes, stride,
padding, dilation, transposed, output_padding, groups, output_mask):
return new_like(input), new_like(weight), new((bias_sizes))
# https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/AdaptiveAveragePooling.cpp
@register_meta(aten._adaptive_avg_pool2d_backward.default)
def meta_adaptive_avg_pool2d_backward(
grad_output: torch.Tensor,
input: torch.Tensor,
):
return new_like(input)
# ================================ RNN =============================================
# https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/cudnn/RNN.cpp
@register_meta(aten._cudnn_rnn.default)
def meta_cuda_rnn(
input,
weight,
weight_stride0,
weight_buf,
hx,
cx,
mode,
hidden_size,
proj_size,
num_layers,
batch_first,
dropout,
train,
bidirectional,
batch_sizes,
dropout_state,
):
is_input_packed = len(batch_sizes) != 0
if is_input_packed:
seq_length = len(batch_sizes)
mini_batch = batch_sizes[0]
batch_sizes_sum = input.shape[0]
else:
seq_length = input.shape[1] if batch_first else input.shape[0]
mini_batch = input.shape[0] if batch_first else input.shape[1]
batch_sizes_sum = -1
num_directions = 2 if bidirectional else 1
out_size = proj_size if proj_size != 0 else hidden_size
if is_input_packed:
out_shape = [batch_sizes_sum, out_size * num_directions]
else:
out_shape = ([mini_batch, seq_length, out_size *
num_directions] if batch_first else [seq_length, mini_batch, out_size * num_directions])
output = input.new_empty(out_shape)
cell_shape = [num_layers * num_directions, mini_batch, hidden_size]
cy = new(0) if cx is None else cx.new_empty(cell_shape)
hy = hx.new_empty([num_layers * num_directions, mini_batch, out_size])
# TODO: Query cudnnGetRNNTrainingReserveSize (expose to python)
reserve_shape = 0 if train else 0
reserve = input.new_empty(reserve_shape, dtype=torch.uint8)
return output, hy, cy, reserve, weight_buf
# https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/cudnn/RNN.cpp
@register_meta(aten._cudnn_rnn_backward.default)
def meta_cudnn_rnn_backward(input: torch.Tensor,
weight: torch.Tensor,
weight_stride0: int,
hx: torch.Tensor,
cx: Optional[torch.Tensor] = None,
*args,
**kwargs):
return new_like(input), new_like(weight), new_like(hx), new_like(cx) if cx is not None else new(
()) # (grad_input, grad_weight, grad_hx, grad_cx)
# https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/Activation.cpp
# ============================== Activations =======================================
_unregistered_ewise = [
aten.relu.default,
aten.prelu.default,
aten.hardswish.default,
aten.hardtanh.default,
aten.prelu_backward.default,
aten.hardswish_backward.default,
aten.hardtanh_backward.default,
]
@register_meta(_unregistered_ewise)
def meta_unregistered_ewise(input: torch.Tensor, *args):
return new_like(input)
# ============================== Normalization =====================================
# https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/cudnn/BatchNorm.cpp
@register_meta(aten.native_batch_norm.default)
def meta_bn(input: torch.Tensor, weight, bias, running_mean, running_var, training, momentum, eps):
n_input = input.size(1)
return new_like(input), new((n_input)), new((n_input))
# https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/cudnn/BatchNorm.cpp
@register_meta(aten.native_batch_norm_backward.default)
def meta_bn_backward(dY: torch.Tensor, input: torch.Tensor, weight: torch.Tensor, running_mean, running_var, save_mean,
save_invstd, train, eps, output_mask):
return new_like(input), new_like(weight), new_like(weight) # (dX, dgamma, dbeta)
# https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/cudnn/BatchNorm.cpp
@register_meta(aten.cudnn_batch_norm.default)
def meta_cudnn_bn(input: torch.Tensor, weight, bias, running_mean, running_var, training, momentum, eps):
n_input = input.size(1)
return new_like(input), new((n_input)), new((n_input)), new(
(0), dtype=torch.uint8) # (output, running_mean, running_var, reserve)
# https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/cudnn/BatchNorm.cpp
# NB: CuDNN only implements the backward algorithm for batchnorm
# in training mode (evaluation mode batchnorm has a different algorithm),
# which is why this doesn't accept a 'training' parameter.
@register_meta(aten.cudnn_batch_norm_backward.default)
def meta_cudnn_bn_backward(dY: torch.Tensor, input: torch.Tensor, weight: torch.Tensor, running_mean, running_var,
save_mean, save_invstd, eps, reserve):
return new_like(input), new_like(weight), new_like(weight) # (dX, dgamma, dbeta)
# https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/layer_norm.cpp
@register_meta(aten.native_layer_norm.default)
def meta_ln(input: torch.Tensor, normalized_shape, weight, bias, eps):
bs, n_input = input.size(0), input.size(1)
return new_like(input), new((bs, n_input, 1)), new((bs, n_input, 1)) # (output, running_mean, running_var)
# https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/layer_norm.cpp
@register_meta(aten.native_layer_norm_backward.default)
def meta_ln_backward(dY: torch.Tensor, input: torch.Tensor, normalized_shape, mean, rstd, weight, bias,
grad_input_mask):
return new_like(input), new_like(weight), new_like(bias) # (dX, dgamma, dbeta)
# ================================== Misc ==========================================
# Maybe incorrect
# https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/Im2Col.cpp
@register_meta(aten.im2col.default)
def meta_im2col(input: torch.Tensor, kernel_size, dilation, padding, stride):
return new_like(input)
# https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/native_functions.yaml
@register_meta(aten.eye.m_out)
def meta_eye(n: int, m: int, out: torch.Tensor):
return out
# https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/native_functions.yaml
@register_meta(aten.roll.default)
def meta_roll(input: torch.Tensor, shifts, dims):
return input
# https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/Scalar.cpp
@register_meta(aten._local_scalar_dense.default)
def meta_local_scalar_dense(self: torch.Tensor):
return 0
# https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/TensorCompare.cpp
@register_meta(aten.where.self)
def meta_where_self(condition: torch.Tensor, self: torch.Tensor, other: torch.Tensor):
result_type = torch.result_type(self, other)
return new_like(condition + self + other, dtype=result_type)
@register_meta(aten.index.Tensor)
def meta_index_Tensor(self, indices):
assert indices, "at least one index must be provided"
# aten::index is the internal advanced indexing implementation
# checkIndexTensorTypes and expandTensors
result: List[Optional[torch.Tensor]] = []
for i, index in enumerate(indices):
if index is not None:
assert index.dtype in [torch.long, torch.int8, torch.bool],\
"tensors used as indices must be long, byte or bool tensors"
if index.dtype in [torch.int8, torch.bool]:
nonzero = index.nonzero()
k = len(result)
assert k + index.ndim <= self.ndim, f"too many indices for tensor of dimension {self.ndim}"
for j in range(index.ndim):
assert index.shape[j] == self.shape[
k +
j], f"The shape of the mask {index.shape} at index {i} does not match the shape of the indexed tensor {self.shape} at index {k + j}"
result.append(nonzero.select(1, j))
else:
result.append(index)
else:
result.append(index)
indices = result
assert len(indices) <= self.ndim, f"too many indices for tensor of dimension {self.ndim} (got {len(indices)})"
# expand_outplace
import torch._refs as refs
indices = list(refs._maybe_broadcast(*indices))
# add missing null tensors
while len(indices) < self.ndim:
indices.append(None)
# hasContiguousSubspace
# true if all non-null tensors are adjacent
# See:
# https://numpy.org/doc/stable/user/basics.indexing.html#combining-advanced-and-basic-indexing
# https://stackoverflow.com/questions/53841497/why-does-numpy-mixed-basic-advanced-indexing-depend-on-slice-adjacency
state = 0
has_contiguous_subspace = False
for index in indices:
if state == 0:
if index is not None:
state = 1
elif state == 1:
if index is None:
state = 2
else:
if index is not None:
break
else:
has_contiguous_subspace = True
# transposeToFront
# This is the logic that causes the newly inserted dimensions to show up
# at the beginning of the tensor, if they're not contiguous
if not has_contiguous_subspace:
dims = []
transposed_indices = []
for i, index in enumerate(indices):
if index is not None:
dims.append(i)
transposed_indices.append(index)
for i, index in enumerate(indices):
if index is None:
dims.append(i)
transposed_indices.append(index)
self = self.permute(dims)
indices = transposed_indices
# AdvancedIndex::AdvancedIndex
# Now we can assume the indices have contiguous subspace
# This is simplified from AdvancedIndex which goes to more effort
# to put the input and indices in a form so that TensorIterator can
# take them. If we write a ref for this, probably that logic should
# get implemented
before_shape: List[int] = []
after_shape: List[int] = []
replacement_shape: List[int] = []
for dim, index in enumerate(indices):
if index is None:
if replacement_shape:
after_shape.append(self.shape[dim])
else:
before_shape.append(self.shape[dim])
else:
replacement_shape = list(index.shape)
return self.new_empty(before_shape + replacement_shape + after_shape)
# ============================== Embedding =========================================
# https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/Embedding.cpp
@register_meta(aten.embedding_dense_backward.default)
def meta_embedding_dense_backward(grad_output: torch.Tensor, indices: torch.Tensor, num_weights, padding_idx,
scale_grad_by_freq):
return new((num_weights, grad_output.size(-1)),
dtype=grad_output.dtype,
device=grad_output.device,
layout=grad_output.layout)
# ============================== Dropout ===========================================
# https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/Dropout.cpp
@register_meta(aten.native_dropout.default)
def meta_native_dropout_default(input: torch.Tensor, p: float, train: bool = False):
# notice that mask is bool
return new_like(input), new_like(input, dtype=torch.bool) # (output, mask)
# https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/Dropout.cpp
@register_meta(aten.native_dropout_backward.default)
def meta_native_dropout_backward_default(grad: torch.Tensor, mask: torch.Tensor, scale: float):
return new_like(grad) # (grad_in)


@@ -0,0 +1,88 @@
import torch
import torch.distributed as dist
aten = torch.ops.aten
__all__ = [
"_TorchFactoryMethod",
"_TorchOverrideableFactoryMethod",
"_TorchNonOverrideableFactoryMethod",
"_TensorPropertyMethod",
"_DistCommMethod",
"_AliasATen",
"_InplaceATen",
"_MaybeInplaceATen",
]
_TorchOverrideableFactoryMethod = [
"empty",
"eye",
"full",
"ones",
"rand",
"randn",
"zeros",
]
_TorchNonOverrideableFactoryMethod = [
"arange",
"finfo",
"linspace",
"logspace",
"randint",
"randperm",
"tensor",
]
_TorchFactoryMethod = _TorchOverrideableFactoryMethod + _TorchNonOverrideableFactoryMethod
_TensorPropertyMethod = ["dtype", "shape", "device", "requires_grad", "grad", "grad_fn", "data"]
_DistCommMethod = [
"all_gather",
"all_reduce",
"all_to_all",
"broadcast",
"gather",
"reduce",
"reduce_scatter",
"scatter",
]
# TODO: dive deep here
# refer to https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/TensorShape.cpp
_AliasATen = [
aten.detach.default,
aten.detach_.default,
aten.t.default,
aten.transpose.int,
aten.view.default,
aten._unsafe_view.default,
aten._reshape_alias.default,
]
_InplaceATen = [
aten.add_.Tensor,
aten.add_.Scalar,
aten.sub_.Tensor,
aten.sub_.Scalar,
aten.mul_.Tensor,
aten.mul_.Scalar,
aten.div_.Tensor,
aten.div_.Scalar,
aten.pow_.Tensor,
aten.pow_.Scalar,
]
# use `MaybeInplace` because they call ``as_strided()`` or ``slice()``
_MaybeInplaceATen = [
aten.diagonal.default,
aten.expand.default,
aten.select.int,
aten.slice.Tensor,
aten.split.Tensor,
aten.squeeze.default,
aten.permute.default,
aten.unsqueeze.default,
aten.as_strided.default,
]


@@ -0,0 +1,536 @@
# adopted from https://github.com/facebookresearch/fvcore/blob/main/fvcore/nn/jit_handles.py
# ideas from https://pastebin.com/AkvAyJBw
# and https://dev-discuss.pytorch.org/t/the-ideal-pytorch-flop-counter-with-torch-dispatch/505
import operator
from collections import defaultdict
from contextlib import contextmanager
from enum import Enum, auto
from functools import partial, reduce
from numbers import Number
from typing import Any, Callable, List, Optional, Union
import torch
from torch.utils._pytree import tree_map
from .meta_tensor import MetaTensor
aten = torch.ops.aten
class Phase(Enum):
FWD = auto()
BWD = auto()
def normalize_tuple(x):
if not isinstance(x, tuple):
return (x,)
return x
def _format_flops(flop):
K = 1e3
M = 1e6
B = 1e9
T = 1e12
if flop < K:
return f'{flop:.2f}'
elif flop < M:
return f'{flop / K:.2f}K'
elif flop < B:
return f'{flop / M:.2f}M'
elif flop < T:
return f'{flop / B:.2f}B'
else:
return f'{flop / T:.2f}T'
def flop_count(module: Union[torch.nn.Module, Callable] = None, *args, verbose: bool = False, **kwargs) -> Number:
"""
Count the number of floating point operations in a model.
Ideas from https://pastebin.com/AkvAyJBw.
Args:
module (torch.nn.Module): A PyTorch model.
*args: Input arguments to the model.
verbose (bool): If True, print the number of flops for each module.
**kwargs: Input keyword arguments to the model.
Returns:
Number: The total number of floating point operations (FWD + BWD).
"""
maybe_inplace = (getattr(module, 'inplace', False) or kwargs.get('inplace', False)
or getattr(module, '__name__', None) in ('add_', 'mul_', 'div_', 'sub_'))
class DummyModule(torch.nn.Module):
def __init__(self, func):
super().__init__()
self.func = func
self.__name__ = func.__name__
def forward(self, *args, **kwargs):
return self.func(*args, **kwargs)
total_flop_count = {Phase.FWD: 0, Phase.BWD: 0}
flop_counts = defaultdict(lambda: defaultdict(int))
parents = ['Global']
module = module if isinstance(module, torch.nn.Module) else DummyModule(module)
class FlopTensor(MetaTensor):
_tensor: torch.Tensor
def __repr__(self):
name = 'FlopParameter' if getattr(self, '_is_param', False) else 'FlopTensor'
if self.grad_fn:
return f"{name}(..., size={tuple(self.shape)}, device='{self.device}', dtype={self.dtype}, grad_fn={self.grad_fn})"
return f"{name}(..., size={tuple(self.shape)}, device='{self.device}', dtype={self.dtype})"
@classmethod
def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
# no_dispatch is only needed if you use enable_python_mode.
# It prevents infinite recursion.
rs = super().__torch_dispatch__(func, types, args, kwargs)
outs = normalize_tuple(rs)
if func in flop_mapping:
nonlocal flop_counts, total_flop_count
flop_count = flop_mapping[func](args, outs)
for par in parents:
flop_counts[par][func.__name__] += flop_count
total_flop_count[cur_phase] += flop_count
def wrap(x):
if isinstance(x, MetaTensor):
x = FlopTensor(x)
return x
rs = tree_map(wrap, rs)
return rs
def is_autogradable(x):
return isinstance(x, torch.Tensor) and x.is_floating_point()
def create_backwards_push(name):
class PushState(torch.autograd.Function):
@staticmethod
def forward(ctx, *args):
args = tree_map(lambda x: x.clone() if isinstance(x, torch.Tensor) else x, args)
if len(args) == 1:
return args[0]
return args
@staticmethod
def backward(ctx, *grad_outs):
nonlocal parents
parents.append(name)
return grad_outs
return PushState.apply
def create_backwards_pop(name):
class PopState(torch.autograd.Function):
@staticmethod
def forward(ctx, *args):
args = tree_map(lambda x: x.clone() if isinstance(x, torch.Tensor) else x, args)
if len(args) == 1:
return args[0]
return args
@staticmethod
def backward(ctx, *grad_outs):
nonlocal parents
assert (parents[-1] == name)
parents.pop()
return grad_outs
return PopState.apply
def enter_module(name):
def f(module, inputs):
nonlocal parents
parents.append(name)
inputs = normalize_tuple(inputs)
out = create_backwards_pop(name)(*inputs)
return out
return f
def exit_module(name):
def f(module, inputs, outputs):
nonlocal parents
assert (parents[-1] == name)
parents.pop()
outputs = normalize_tuple(outputs)
return create_backwards_push(name)(*outputs)
return f
@contextmanager
def instrument_module(mod):
registered = []
for name, module in dict(mod.named_children()).items():
registered.append(module.register_forward_pre_hook(enter_module(name)))
registered.append(module.register_forward_hook(exit_module(name)))
yield
for handle in registered:
handle.remove()
def display_flops():
for mod in flop_counts.keys():
print(f"Module: ", mod)
for k, v in flop_counts[mod].items():
print('\t', k, _format_flops(v))
print()
def detach_variables(r):
if isinstance(r, torch.Tensor):
requires_grad = r.requires_grad
r = r.detach()
r.requires_grad = requires_grad
return r
def wrap(r):
if isinstance(r, torch.Tensor):
data_ptr_fn = getattr(r, '_tensor', r).data_ptr
r = FlopTensor(detach_variables(r))
if maybe_inplace:
r = r + 0
r._tensor.data_ptr = data_ptr_fn
return r
with instrument_module(module):
cur_phase = Phase.FWD
rst = module(*tree_map(wrap, args), **tree_map(wrap, kwargs))
rst = tuple(r for r in normalize_tuple(rst) if is_autogradable(r) and r.requires_grad)
cur_phase = Phase.BWD
if rst:
grad = [torch.zeros_like(t) for t in rst]
torch.autograd.backward(
rst,
grad,
)
if verbose:
display_flops()
return total_flop_count[Phase.FWD], total_flop_count[Phase.BWD]
def matmul_flop_jit(inputs: List[Any], outputs: List[Any]) -> Number:
"""
Count flops for matmul.
"""
# Inputs should be a list of length 2.
# Inputs contains the shapes of two matrices.
input_shapes = [v.shape for v in inputs]
assert len(input_shapes) == 2, input_shapes
assert input_shapes[0][-1] == input_shapes[1][-2], input_shapes
flops = reduce(operator.mul, input_shapes[0]) * input_shapes[-1][-1]
return flops
def addmm_flop_jit(inputs: List[Any], outputs: List[Any]) -> Number:
"""
Count flops for fully connected layers.
"""
# Count flop for nn.Linear
# inputs is a list of length 3.
input_shapes = [v.shape for v in inputs[1:3]]
# input_shapes[0]: [batch size, input feature dimension]
# input_shapes[1]: [input feature dimension, output feature dimension]
assert len(input_shapes[0]) == 2, input_shapes[0]
assert len(input_shapes[1]) == 2, input_shapes[1]
batch_size, input_dim = input_shapes[0]
output_dim = input_shapes[1][1]
flops = batch_size * input_dim * output_dim
return flops
def linear_flop_jit(inputs: List[Any], outputs: List[Any]) -> Number:
"""
Count flops for the aten::linear operator.
"""
# Inputs is a list of length 3; unlike aten::addmm, it is the first
# two elements that are relevant.
input_shapes = [v.shape for v in inputs[0:2]]
# input_shapes[0]: [dim0, dim1, ..., input_feature_dim]
# input_shapes[1]: [output_feature_dim, input_feature_dim]
assert input_shapes[0][-1] == input_shapes[1][-1]
flops = reduce(operator.mul, input_shapes[0]) * input_shapes[1][0]
return flops
def bmm_flop_jit(inputs: List[Any], outputs: List[Any]) -> Number:
"""
Count flops for the bmm operation.
"""
# Inputs should be a list of length 2.
# Inputs contains the shapes of two tensor.
assert len(inputs) == 2, len(inputs)
input_shapes = [v.shape for v in inputs]
n, c, t = input_shapes[0]
d = input_shapes[-1][-1]
flops = n * c * t * d
return flops
def conv_flop_count(
x_shape: List[int],
w_shape: List[int],
out_shape: List[int],
transposed: bool = False,
) -> Number:
"""
Count flops for convolution. Note only multiplication is
counted. Computation for addition and bias is ignored.
Flops for a transposed convolution are calculated as
flops = (x_shape[2:] * prod(w_shape) * batch_size).
Args:
x_shape (list(int)): The input shape before convolution.
w_shape (list(int)): The filter shape.
out_shape (list(int)): The output shape after convolution.
transposed (bool): is the convolution transposed
Returns:
int: the number of flops
"""
batch_size = x_shape[0]
conv_shape = (x_shape if transposed else out_shape)[2:]
flops = batch_size * reduce(operator.mul, w_shape) * reduce(operator.mul, conv_shape)
return flops
def conv_flop_jit(inputs: List[Any], outputs: List[Any]):
"""
Count flops for convolution.
"""
x, w = inputs[:2]
x_shape, w_shape, out_shape = (x.shape, w.shape, outputs[0].shape)
transposed = inputs[6]
return conv_flop_count(x_shape, w_shape, out_shape, transposed=transposed)
def transpose_shape(shape):
return [shape[1], shape[0]] + list(shape[2:])
def conv_backward_flop_jit(inputs: List[Any], outputs: List[Any]):
grad_out_shape, x_shape, w_shape = [i.shape for i in inputs[:3]]
output_mask = inputs[-1]
fwd_transposed = inputs[7]
flop_count = 0
if output_mask[0]:
grad_input_shape = outputs[0].shape
flop_count += conv_flop_count(grad_out_shape, w_shape, grad_input_shape, not fwd_transposed)
if output_mask[1]:
grad_weight_shape = outputs[1].shape
flop_count += conv_flop_count(transpose_shape(x_shape), grad_out_shape, grad_weight_shape, fwd_transposed)
return flop_count
def norm_flop_counter(affine_arg_index: int, input_arg_index: int) -> Callable:
"""
Args:
affine_arg_index: index of the affine argument in inputs
"""
def norm_flop_jit(inputs: List[Any], outputs: List[Any]) -> Number:
"""
Count flops for norm layers.
"""
# Inputs[0] contains the shape of the input.
input_shape = inputs[input_arg_index].shape
has_affine = inputs[affine_arg_index].shape is not None if hasattr(inputs[affine_arg_index],
'shape') else inputs[affine_arg_index]
assert 2 <= len(input_shape) <= 5, input_shape
# 5 is just a rough estimate
flop = reduce(operator.mul, input_shape) * (5 if has_affine else 4)
return flop
return norm_flop_jit
def batchnorm_flop_jit(inputs: List[Any], outputs: List[Any], training: bool = None) -> Number:
if training is None:
training = inputs[-3]
assert isinstance(training, bool), "Signature of aten::batch_norm has changed!"
if training:
return norm_flop_counter(1, 0)(inputs, outputs) # pyre-ignore
has_affine = inputs[1].shape is not None
input_shape = reduce(operator.mul, inputs[0].shape)
return input_shape * (2 if has_affine else 1)
def ewise_flop_counter(input_scale: float = 1, output_scale: float = 0) -> Callable:
"""
Count flops by
input_tensor.numel() * input_scale + output_tensor.numel() * output_scale
Args:
input_scale: scale of the input tensor (first argument)
output_scale: scale of the output tensor (first element in outputs)
"""
def ewise_flop(inputs: List[Any], outputs: List[Any]) -> Number:
ret = 0
if input_scale != 0:
shape = inputs[0].shape
ret += input_scale * reduce(operator.mul, shape) if shape else 0
if output_scale != 0:
shape = outputs[0].shape
ret += output_scale * reduce(operator.mul, shape) if shape else 0
return ret
return ewise_flop
def zero_flop_jit(*args):
"""
Count flops for zero flop layers.
"""
return 0
flop_mapping = {
# gemm
aten.mm.default: matmul_flop_jit,
aten.matmul.default: matmul_flop_jit,
aten.addmm.default: addmm_flop_jit,
aten.bmm.default: bmm_flop_jit,
# convolution
aten.convolution.default: conv_flop_jit,
aten._convolution.default: conv_flop_jit,
aten.convolution_backward.default: conv_backward_flop_jit,
# normalization
aten.native_batch_norm.default: batchnorm_flop_jit,
aten.native_batch_norm_backward.default: batchnorm_flop_jit,
aten.cudnn_batch_norm.default: batchnorm_flop_jit,
aten.cudnn_batch_norm_backward.default: partial(batchnorm_flop_jit, training=True),
aten.native_layer_norm.default: norm_flop_counter(2, 0),
aten.native_layer_norm_backward.default: norm_flop_counter(2, 0),
# pooling
aten.avg_pool1d.default: ewise_flop_counter(1, 0),
aten.avg_pool2d.default: ewise_flop_counter(1, 0),
aten.avg_pool2d_backward.default: ewise_flop_counter(0, 1),
aten.avg_pool3d.default: ewise_flop_counter(1, 0),
aten.avg_pool3d_backward.default: ewise_flop_counter(0, 1),
aten.max_pool1d.default: ewise_flop_counter(1, 0),
aten.max_pool2d.default: ewise_flop_counter(1, 0),
aten.max_pool3d.default: ewise_flop_counter(1, 0),
aten.max_pool1d_with_indices.default: ewise_flop_counter(1, 0),
aten.max_pool2d_with_indices.default: ewise_flop_counter(1, 0),
aten.max_pool2d_with_indices_backward.default: ewise_flop_counter(0, 1),
aten.max_pool3d_with_indices.default: ewise_flop_counter(1, 0),
aten.max_pool3d_with_indices_backward.default: ewise_flop_counter(0, 1),
aten._adaptive_avg_pool2d.default: ewise_flop_counter(1, 0),
aten._adaptive_avg_pool2d_backward.default: ewise_flop_counter(0, 1),
aten._adaptive_avg_pool3d.default: ewise_flop_counter(1, 0),
aten._adaptive_avg_pool3d_backward.default: ewise_flop_counter(0, 1),
aten.embedding_dense_backward.default: ewise_flop_counter(0, 1),
aten.embedding.default: ewise_flop_counter(1, 0),
}
ewise_flop_aten = [
# basic op
aten.add.Tensor,
aten.add_.Tensor,
aten.div.Tensor,
aten.div_.Tensor,
aten.div.Scalar,
aten.div_.Scalar,
aten.mul.Tensor,
aten.mul.Scalar,
aten.mul_.Tensor,
aten.neg.default,
aten.pow.Tensor_Scalar,
aten.rsub.Scalar,
aten.sum.default,
aten.sum.dim_IntList,
aten.mean.dim,
# activation op
aten.hardswish.default,
aten.hardswish_.default,
aten.hardswish_backward.default,
aten.hardtanh.default,
aten.hardtanh_.default,
aten.hardtanh_backward.default,
aten.hardsigmoid_backward.default,
aten.hardsigmoid.default,
aten.gelu.default,
aten.gelu_backward.default,
aten.silu.default,
aten.silu_.default,
aten.silu_backward.default,
aten.sigmoid.default,
aten.sigmoid_backward.default,
aten._softmax.default,
aten._softmax_backward_data.default,
aten.relu_.default,
aten.relu.default,
aten.tanh.default,
aten.tanh_backward.default,
aten.threshold_backward.default,
# dropout
aten.native_dropout.default,
aten.native_dropout_backward.default,
# distribution
aten.bernoulli_.float,
# where
aten.where.self,
]
for op in ewise_flop_aten:
flop_mapping[op] = ewise_flop_counter(1, 0)
# FIXME: this will be removed in the future
zero_flop_aten = [
aten.as_strided.default,
aten.as_strided_.default,
aten.cat.default,
aten.clone.default,
aten.copy_.default,
aten.detach.default,
aten.expand.default,
aten.empty_like.default,
aten.new_empty.default,
aten.new_empty_strided.default,
aten.ones_like.default,
aten._reshape_alias.default,
aten.select.int,
aten.select_backward.default,
aten.squeeze.dim,
aten.slice.Tensor,
aten.slice_backward.default,
aten.split.Tensor,
aten.permute.default,
aten.t.default,
aten.transpose.int,
aten._to_copy.default,
aten.unsqueeze.default,
aten.unbind.int,
aten._unsafe_view.default,
aten.view.default,
aten.zero_.default,
aten.zeros_like.default,
]
for op in zero_flop_aten:
flop_mapping[op] = zero_flop_jit
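
To make the table above concrete, here is a minimal, hypothetical sketch of how ``flop_mapping`` could be consulted for one dispatched ``aten`` call; ``count_flops`` is not part of this PR and simply flattens the arguments before delegating to the registered counter.

```python
from torch.utils._pytree import tree_flatten

def count_flops(func, args, kwargs, outputs):
    # Hypothetical helper: look up `func` in flop_mapping and estimate its flops.
    if func not in flop_mapping:
        return 0
    flat_args, _ = tree_flatten((args, kwargs))
    flat_outs, _ = tree_flatten(outputs)
    return flop_mapping[func](flat_args, flat_outs)

# e.g. aten.mm.default on (64, 128) x (128, 256) operands would be counted as
# roughly 64 * 128 * 256 flops under the matmul rule registered above.
```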


@@ -0,0 +1,207 @@
import uuid
from functools import partial
import torch
import torch.distributed as dist
from torch.types import _bool, _device, _dtype
from torch.utils._pytree import tree_flatten, tree_map
from ._monkey_patch import _AliasATen, _DistCommMethod, _InplaceATen, _MaybeInplaceATen, _TorchOverrideableFactoryMethod
__all__ = ['MetaTensor', 'MetaTensorMode']
def register_storage(r, data_ptr_fn=None):
if isinstance(r, torch.Tensor):
if data_ptr_fn is not None:
r.data_ptr = data_ptr_fn
elif not r.data_ptr():
data_ptr = uuid.uuid1()
r.data_ptr = lambda: data_ptr
def _normalize_tuple(x):
if not isinstance(x, tuple):
return (x,)
return x
# a hack for detecting in-place / aliasing execution in PyTorch
def _assert_alias(func):
# TODO: check whether this list should be this aggressive
return func in (_AliasATen + _InplaceATen + _MaybeInplaceATen)
class MetaTensor(torch.Tensor):
"""
A wrapping tensor that hacks ``torch.autograd`` without patching more ``torch.ops.aten`` ops.
`device` is the device that ``MetaTensor`` is supposed to run on. Meta tensors give you the
ability to run PyTorch code without having to actually do computation through tensors
allocated on a `meta` device. Because the device is `meta`, meta tensors do not model
device propagation. ``MetaTensor`` extends its usage by carrying an additional `device`
which tracks devices that would have been used.
Reference:
https://github.com/pytorch/pytorch/blob/master/torch/_subclasses/fake_tensor.py
"""
_tensor: torch.Tensor
@staticmethod
def __new__(cls, elem, device=None, data_ptr_fn=None):
requires_grad = elem.requires_grad
# Avoid multiple wrapping
while isinstance(elem, MetaTensor):
device = elem.device if device is None else device
elem = elem._tensor
# The wrapping tensor (MetaTensor) shouldn't hold any
# memory for the class in question, but it should still
# advertise the same device as before
r = torch.Tensor._make_wrapper_subclass(
cls,
elem.size(),
strides=elem.stride(),
storage_offset=elem.storage_offset(),
dtype=elem.dtype,
layout=elem.layout,
device=device or (elem.device if elem.device.type != 'meta' else torch.device('cpu')),
requires_grad=requires_grad) # deceive the frontend for aten selections
r._tensor = elem
# ...the real tensor is held as an element on the tensor.
if not r._tensor.is_meta:
val = elem.data_ptr()
data_ptr_fn = lambda: val
r._tensor = r._tensor.to(torch.device('meta'))
# only tensor not on `meta` should be copied to `meta`
register_storage(r._tensor, data_ptr_fn)
if isinstance(elem, torch.nn.Parameter):
r = torch.nn.Parameter(r)
return r
def __repr__(self):
name = 'MetaParameter' if getattr(self, '_is_param', False) else 'MetaTensor'
if self.grad_fn:
return f"{name}(..., size={tuple(self.shape)}, device='{self.device}', dtype={self.dtype}, grad_fn={self.grad_fn})"
return f"{name}(..., size={tuple(self.shape)}, device='{self.device}', dtype={self.dtype})"
@classmethod
def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
device = None
def unwrap(x):
nonlocal device
if isinstance(x, MetaTensor):
device = x.device
x = x._tensor
elif isinstance(x, torch.Tensor):
device = x.device
x = x.to(torch.device('meta'))
return x
args = tree_map(unwrap, args)
kwargs = tree_map(unwrap, kwargs)
if 'device' in kwargs:
device = kwargs['device']
kwargs['device'] = torch.device('meta')
# run aten for backend=CPU but actually on backend=Meta
# here we detect whether or not the execution generates a physical copy
# of the input tensor
ret = func(*args, **kwargs)
if _assert_alias(func):
val = args[0].data_ptr()
tree_map(partial(register_storage, data_ptr_fn=lambda: val), _normalize_tuple(ret))
# Now, we want to continue propagating this tensor, so we rewrap Tensors in
# our custom tensor subclass
def wrap(x):
return MetaTensor(x, device=device) if isinstance(x, torch.Tensor) else x
return tree_map(wrap, ret)
def to(self, *args, **kwargs) -> torch.Tensor:
"""An extension of `torch.Tensor.to()` to MetaTensor
Returns:
result (MetaTensor): MetaTensor
Usage:
>>> tensor = MetaTensor(torch.rand(10), device='cuda:100')
>>> tensor.to(torch.uint8)
MetaTensor(tensor(..., device='meta', size=(10,), dtype=torch.uint8), device='cuda:100')
>>> tensor.to(torch.device('cuda:42'))
MetaTensor(tensor(..., device='meta', size=(10,)), device='cuda:42')
>>> tensor.to('vulkan')
MetaTensor(tensor(..., device='meta', size=(10,)), device='vulkan')
"""
# this imitates c++ function in the way of @overload
device = None
def replace(x):
nonlocal device
if isinstance(x, str) or isinstance(x, _device):
device = x
return torch.device('meta')
return x
elem = self._tensor.to(*tree_map(replace, args), **tree_map(replace, kwargs))
return MetaTensor(elem, device=device)
def cpu(self, *args, **kwargs):
if self.device.type == 'cpu':
return self.to(*args, **kwargs)
return self.to(*args, device='cpu', **kwargs)
def cuda(self, device=None, non_blocking=False):
if device is not None:
return self.to(device=device, non_blocking=non_blocking)
return self.to(device='cuda:0', non_blocking=non_blocking)
def data_ptr(self):
return self._tensor.data_ptr()
class MetaTensorMode(object):
"""
A context manager that enables MetaTensor mode.
Usage:
>>> with MetaTensorMode():
>>> # all torch.xxx and torch.distributed.xxx will be replaced by patched functions
>>> # and the actual execution will be on torch.device('meta')
>>> a = torch.rand(100000, 100000)
>>> b = torch.rand(100000, 100000)
>>> c = torch.mm(a, b)
"""
def __init__(self):
self.torch_overrides = {} # override torch.xxx
self.dist_overrides = {} # override torch.distributed.xxx
def __enter__(self):
def _dummy(*args, **kwargs):
pass
def _new(*args, orig_new=torch.empty, **kwargs):
return MetaTensor(orig_new(*args, **{
**kwargs, 'device': 'meta'
}),
device=kwargs.get('device', torch.device('cpu')))
for func in _TorchOverrideableFactoryMethod:
self.torch_overrides[func] = getattr(torch, func)
setattr(torch, func, partial(_new, orig_new=getattr(torch, func)))
for func in _DistCommMethod:
self.dist_overrides[func] = getattr(dist, func)
setattr(dist, func, _dummy)
def __exit__(self, exc_type, exc_value, traceback):
for func, func_impl in self.torch_overrides.items():
setattr(torch, func, func_impl)
for func, func_impl in self.dist_overrides.items():
setattr(dist, func, func_impl)
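
A rough usage sketch of the two subclasses above (not part of the diff); the import path matches the one used by ``shape_prop.py`` later in this PR, and the exact behavior is indicative only.

```python
import torch
import torch.nn as nn
from colossalai._analyzer._subclasses import MetaTensor, MetaTensorMode

model = nn.Linear(1024, 1024)
x = MetaTensor(torch.empty(16, 1024, device='meta'), device='cuda:0')
y = model(x)        # shapes are propagated, but no real memory is allocated
print(y.shape)      # torch.Size([16, 1024])

with MetaTensorMode():
    # factory functions such as torch.rand are patched to return MetaTensor,
    # so even huge allocations stay symbolic
    a = torch.rand(100000, 100000)
    b = torch.mm(a, a)
```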


@@ -0,0 +1,7 @@
from dataclasses import dataclass
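# Rough hardware assumptions used by the analyzer to convert flop / communication counts
# into time estimates (see ``MetaInfo.fwd_time`` later in this PR): TFLOPS is compute
# throughput in FLOP/s and BANDWIDTH is communication bandwidth in bytes/s; the defaults
# below are illustrative single-device numbers.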
@dataclass
class MeshConfig:
TFLOPS: float = 1.9e12
BANDWIDTH: float = 1.2e9


@@ -0,0 +1,4 @@
from .bias_addition import *
from .node_util import MetaInfo
from .symbolic_profile import symbolic_profile
from .symbolic_trace import symbolic_trace


@@ -0,0 +1,155 @@
"""
If an FX.Graph is traced for an auto-parallel module, some extra nodes will be added during
graph construction to handle the compatibility between bias addition and all-reduce.
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.modules.utils import _pair, _single, _triple
from .symbolic_trace import register_tracer_impl
__all__ = []
@register_tracer_impl(F.linear, name='_bias_addition_impl')
def linear_impl(input, weight, bias=None):
if bias is None:
return F.linear(input, weight)
else:
return F.linear(input, weight) + bias
@register_tracer_impl(F.conv1d, name='_bias_addition_impl')
def conv1d_impl(input, weight, bias=None, stride=_single(1), padding=_single(0), dilation=_single(1), groups=1):
if bias is None:
return F.conv1d(input, weight, stride=stride, padding=padding, dilation=dilation, groups=groups)
else:
return F.conv1d(input, weight, stride=stride, padding=padding, dilation=dilation, groups=groups) + bias.reshape(
(-1, 1))
@register_tracer_impl(F.conv2d, name='_bias_addition_impl')
def conv2d_impl(input, weight, bias=None, stride=_pair(1), padding=_pair(0), dilation=_pair(1), groups=1):
if bias is None:
return F.conv2d(input, weight, stride=stride, padding=padding, dilation=dilation, groups=groups)
else:
return F.conv2d(input, weight, stride=stride, padding=padding, dilation=dilation, groups=groups) + bias.reshape(
(-1, 1, 1))
@register_tracer_impl(F.conv3d, name='_bias_addition_impl')
def conv3d_impl(input, weight, bias=None, stride=_triple(1), padding=_triple(0), dilation=_triple(1), groups=1):
if bias is None:
return F.conv3d(input, weight, stride=stride, padding=padding, dilation=dilation, groups=groups)
else:
return F.conv3d(input, weight, stride=stride, padding=padding, dilation=dilation, groups=groups) + bias.reshape(
(-1, 1, 1, 1))
@register_tracer_impl(F.conv_transpose1d, name='_bias_addition_impl')
def conv_transpose1d_impl(input,
weight,
bias=None,
stride=_single(1),
padding=_single(0),
output_padding=_single(0),
groups=1,
dilation=_single(1)):
if bias is None:
return F.conv_transpose1d(input,
weight,
stride=stride,
padding=padding,
output_padding=output_padding,
groups=groups,
dilation=dilation)
else:
return F.conv_transpose1d(input,
weight,
stride=stride,
padding=padding,
output_padding=output_padding,
groups=groups,
dilation=dilation) + bias.reshape((-1, 1))
@register_tracer_impl(F.conv_transpose2d, name='_bias_addition_impl')
def conv_transpose2d_impl(input,
weight,
bias=None,
stride=_pair(1),
padding=_pair(0),
output_padding=_pair(0),
groups=1,
dilation=_pair(1)):
if bias is None:
return F.conv_transpose2d(input,
weight,
stride=stride,
padding=padding,
output_padding=output_padding,
groups=groups,
dilation=dilation)
else:
return F.conv_transpose2d(input,
weight,
stride=stride,
padding=padding,
output_padding=output_padding,
groups=groups,
dilation=dilation) + bias.reshape((-1, 1, 1))
@register_tracer_impl(F.conv_transpose3d, name='_bias_addition_impl')
def conv_transpose3d_impl(input,
weight,
bias=None,
stride=_triple(1),
padding=_triple(0),
output_padding=_triple(0),
groups=1,
dilation=_triple(1)):
if bias is None:
return F.conv_transpose3d(input,
weight,
stride=stride,
padding=padding,
output_padding=output_padding,
groups=groups,
dilation=dilation)
else:
return F.conv_transpose3d(input,
weight,
stride=stride,
padding=padding,
output_padding=output_padding,
groups=groups,
dilation=dilation) + bias.reshape((-1, 1, 1, 1))
@register_tracer_impl(torch.addmm, name='_bias_addition_impl')
@register_tracer_impl(torch.Tensor.addmm, name='_bias_addition_impl')
def addmm_impl(input, mat1, mat2, beta=1, alpha=1):
if alpha != 1 and beta != 1:
return F.linear(mat1, mat2.transpose(0, 1)) * alpha + input * beta
elif alpha != 1:
return F.linear(mat1, mat2.transpose(0, 1)) * alpha + input
elif beta != 1:
return F.linear(mat1, mat2.transpose(0, 1)) + input * beta
else:
return F.linear(mat1, mat2.transpose(0, 1)) + input
@register_tracer_impl(torch.addbmm, name='_bias_addition_impl')
@register_tracer_impl(torch.Tensor.addbmm, name='_bias_addition_impl')
def addbmm_impl(input, batch1, batch2, beta=1, alpha=1):
if alpha != 1 and beta != 1:
return torch.bmm(batch1, batch2.transpose(1, 2)) * alpha + input * beta
elif alpha != 1:
return torch.bmm(batch1, batch2.transpose(1, 2)) * alpha + input
elif beta != 1:
return torch.bmm(batch1, batch2.transpose(1, 2)) + input * beta
else:
return torch.bmm(batch1, batch2.transpose(1, 2)) + input
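
As a quick sanity check (not part of the diff), the rewritten ``addmm_impl`` above can be compared numerically against ``torch.addmm`` to confirm the bias-addition split is semantics-preserving:

```python
import torch

inp = torch.rand(3, 5)
mat1 = torch.rand(3, 4)
mat2 = torch.rand(4, 5)

ref = torch.addmm(inp, mat1, mat2, beta=0.5, alpha=2.0)
out = addmm_impl(inp, mat1, mat2, beta=0.5, alpha=2.0)    # defined above
assert torch.allclose(ref, out, atol=1e-6)
```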


@@ -0,0 +1,456 @@
from typing import Any, Callable, Dict, Iterable, List, Tuple
import torch
from torch.fx.graph import (
CodeGen,
PythonCode,
_custom_builtins,
_format_target,
_is_from_torch,
_Namespace,
_origin_type_map,
_register_custom_builtin,
inplace_methods,
magic_methods,
)
from torch.fx.node import Argument, Node, _get_qualified_name, _type_repr, map_arg
import colossalai
from colossalai.fx._compatibility import compatibility
_register_custom_builtin('colossalai', 'import colossalai', colossalai)
def _gen_ckpt_fn_def(label, free_vars: List[str]) -> str:
"""
Generate the checkpoint function definition
"""
return f"def checkpoint_{label}({', '.join(['self'] + free_vars)}):"
def _gen_ckpt_output(output_vars: List[str]) -> str:
"""
Generate the return statement for checkpoint region
"""
return f"return {', '.join(output_vars)}"
def _gen_ckpt_usage(label, input_vars, output_vars, use_reentrant=True):
"""
Generate the checkpoint function call code text
"""
outputs = ', '.join(output_vars)
inputs = ', '.join(input_vars)
return f'{outputs} = torch.utils.checkpoint.checkpoint(self.checkpoint_{label}, {inputs}, use_reentrant={use_reentrant})'
def _end_of_ckpt(node: Node, ckpt_level: int) -> bool:
"""
Check if the node could end the ckpt region at `ckpt_level`
"""
if len(node.meta['info'].to_recompute) > ckpt_level:
return node.meta['info'].to_recompute[ckpt_level] is not None
return True
def _find_input_and_output_nodes(nodes: List[Node]):
"""
Find the input and output node names which are not found in the given list of nodes.
"""
input_nodes = []
output_nodes = []
# if a node has an input node which is not in the node list
# we treat that input node as the input of the checkpoint function
for node in nodes:
for input_node in node._input_nodes.keys():
node_repr = repr(input_node)
if input_node not in nodes and node_repr not in input_nodes:
input_nodes.append(node_repr)
# if a node has a user node which is not in the node list
# we treat that user node as the node receiving the current node output
for node in nodes:
for output_node in node.users.keys():
node_repr = repr(node)
if output_node not in nodes and node_repr not in output_nodes:
output_nodes.append(node_repr)
return input_nodes, output_nodes
def _find_nested_ckpt_regions(node_list: List[Node], ckpt_level: int = 0):
"""
Find the nested checkpoint regions given a list of consecutive nodes. The outputs
will be list of tuples, each tuple is in the form of (start_index, end_index).
"""
ckpt_regions = []
start = -1
end = -1
current_region = None
for idx, node in enumerate(node_list):
if len(node.meta['info'].to_recompute) > ckpt_level:
act_ckpt_label = node.meta['info'].to_recompute[ckpt_level]
# this activation checkpoint label is not set yet
# meaning this is the first node of the activation ckpt region
if current_region is None:
current_region = act_ckpt_label
start = idx
# if activation checkpoint has changed
# we restart the tracking
# e.g. node ckpt states = [ckpt1, ckpt2, ckpt2, ckpt2]
if act_ckpt_label != current_region:
assert start != -1
ckpt_regions.append((start, idx - 1))
current_region = act_ckpt_label
start = idx
end = -1
elif current_region is not None and _end_of_ckpt(node, ckpt_level):
# used to check the case below
# node ckpt states = [ckpt, ckpt, non-ckpt]
end = idx - 1
assert start != -1 and end != -1
ckpt_regions.append((start, end))
start = end = -1
current_region = None
else:
pass
if current_region is not None:
end = len(node_list) - 1
ckpt_regions.append((start, end))
return ckpt_regions
def emit_ckpt_func(body,
ckpt_func,
node_list: List[Node],
emit_node_func,
delete_unused_value_func,
ckpt_level=0,
in_ckpt=False):
"""Emit ckpt fuction in nested way
Args:
body: forward code - in recursive calls, this part will be checkpoint
functions code
ckpt_func: checkpoint functions code - in recursive calls, this part
will be a buffer
node_list (List[Node]): list of torch.fx.Node
emit_node_func: function to emit a node
delete_unused_value_func: function to delete unused value
ckpt_level (int, optional): checkpoint level. Defaults to 0.
in_ckpt (bool, optional): indicates whether the func is in a recursive
call. Defaults to False.
"""
inputs, outputs = _find_input_and_output_nodes(node_list)
# label given by each layer, e.g. if you are currently at level (0, 1, 1)
# the label will be '0_1_1'
label = "_".join([str(idx) for idx in node_list[0].meta['info'].to_recompute[:ckpt_level + 1]])
ckpt_fn_def = _gen_ckpt_fn_def(label, inputs)
ckpt_func.append(f'{ckpt_fn_def}\n')
# if there is more level to fetch
if ckpt_level + 1 < max(map(lambda node: len(node.meta['info'].to_recompute), node_list)):
ckpt_regions = _find_nested_ckpt_regions(node_list, ckpt_level + 1)
start_idx = [item[0] for item in ckpt_regions]
end_idx = [item[1] for item in ckpt_regions]
# use ckpt_func_buffer to store nested checkpoint functions
ckpt_func_buffer = []
node_idx = 0
while 1:
if node_idx >= len(node_list):
break
if node_idx in start_idx:
ckpt_node_list = node_list[node_idx:end_idx[start_idx.index(node_idx)] + 1]
emit_ckpt_func(ckpt_func, ckpt_func_buffer, ckpt_node_list, emit_node_func, delete_unused_value_func,
ckpt_level + 1, True)
node_idx += len(ckpt_node_list)
else:
node = node_list[node_idx]
emit_node_func(node, ckpt_func)
ckpt_func[-1] = ' ' + ckpt_func[-1]
delete_unused_value_func(node, ckpt_func)
node_idx += 1
ckpt_func.append(' ' + _gen_ckpt_output(outputs) + '\n\n')
ckpt_func += ckpt_func_buffer
# last level
else:
for node in node_list:
emit_node_func(node, ckpt_func)
ckpt_func[-1] = ' ' + ckpt_func[-1]
delete_unused_value_func(node, ckpt_func)
ckpt_func.append(' ' + _gen_ckpt_output(outputs) + '\n\n')
usage = _gen_ckpt_usage(label, inputs, outputs, False) + '\n'
if in_ckpt:
usage = ' ' + usage
body.append(usage)
def emit_code_with_activation_checkpoint(body, ckpt_func, nodes, emit_node_func, delete_unused_value_func):
"""Emit code with nested activation checkpoint
When activation checkpoint annotations are detected on the nodes, we will use
this function to emit the activation checkpoint codes.
Args:
body: forward code
ckpt_func: checkpoint functions code
nodes: graph.nodes
emit_node_func: function to emit node
delete_unused_value_func: function to remove the unused value
"""
ckpt_regions = _find_nested_ckpt_regions(nodes, 0)
start_idx = [item[0] for item in ckpt_regions]
end_idx = [item[1] for item in ckpt_regions]
node_list = list(nodes)
node_idx = 0
while 1:
# break if we finish the processing all the nodes
if node_idx >= len(node_list):
break
# process ckpt_regions
if node_idx in start_idx:
ckpt_node_list = node_list[node_idx:end_idx[start_idx.index(node_idx)] + 1]
emit_ckpt_func(body, ckpt_func, ckpt_node_list, emit_node_func, delete_unused_value_func)
node_idx += len(ckpt_node_list)
# process node in forward function
else:
node = node_list[node_idx]
emit_node_func(node, body)
delete_unused_value_func(node, body)
node_idx += 1
@compatibility(is_backward_compatible=True)
class ActivationCheckpointCodeGen(CodeGen):
def _gen_python_code(self, nodes, root_module: str, namespace: _Namespace) -> PythonCode:
free_vars: List[str] = []
body: List[str] = []
globals_: Dict[str, Any] = {}
wrapped_fns: Dict[str, None] = {}
# Wrap string in list to pass by reference
maybe_return_annotation: List[str] = ['']
def add_global(name_hint: str, obj: Any):
"""Add an obj to be tracked as a global.
We call this for names that reference objects external to the
Graph, like functions or types.
Returns: the global name that should be used to reference 'obj' in generated source.
"""
if _is_from_torch(obj) and obj != torch.device: # to support registering torch.device
# HACK: workaround for how torch custom ops are registered. We
# can't import them like normal modules so they must retain their
# fully qualified name.
return _get_qualified_name(obj)
# normalize the name hint to get a proper identifier
global_name = namespace.create_name(name_hint, obj)
if global_name in globals_:
assert globals_[global_name] is obj
return global_name
globals_[global_name] = obj
return global_name
# Pre-fill the globals table with registered builtins.
for name, (_, obj) in _custom_builtins.items():
add_global(name, obj)
def type_repr(o: Any):
if o == ():
# Empty tuple is used for empty tuple type annotation Tuple[()]
return '()'
typename = _type_repr(o)
if hasattr(o, '__origin__'):
# This is a generic type, e.g. typing.List[torch.Tensor]
origin_type = _origin_type_map.get(o.__origin__, o.__origin__)
origin_typename = add_global(_type_repr(origin_type), origin_type)
if hasattr(o, '__args__'):
# Assign global names for each of the inner type variables.
args = [type_repr(arg) for arg in o.__args__]
if len(args) == 0:
# Bare type, such as `typing.Tuple` with no subscript
# This code-path used in Python < 3.9
return origin_typename
return f'{origin_typename}[{",".join(args)}]'
else:
# Bare type, such as `typing.Tuple` with no subscript
# This code-path used in Python 3.9+
return origin_typename
# Common case: this is a regular module name like 'foo.bar.baz'
return add_global(typename, o)
def _format_args(args: Tuple[Argument, ...], kwargs: Dict[str, Argument]) -> str:
def _get_repr(arg):
# Handle NamedTuples (if it has `_fields`) via add_global.
if isinstance(arg, tuple) and hasattr(arg, '_fields'):
qualified_name = _get_qualified_name(type(arg))
global_name = add_global(qualified_name, type(arg))
return f"{global_name}{repr(tuple(arg))}"
return repr(arg)
args_s = ', '.join(_get_repr(a) for a in args)
kwargs_s = ', '.join(f'{k} = {_get_repr(v)}' for k, v in kwargs.items())
if args_s and kwargs_s:
return f'{args_s}, {kwargs_s}'
return args_s or kwargs_s
# Run through reverse nodes and record the first instance of a use
# of a given node. This represents the *last* use of the node in the
# execution order of the program, which we will use to free unused
# values
node_to_last_use: Dict[Node, Node] = {}
user_to_last_uses: Dict[Node, List[Node]] = {}
def register_last_uses(n: Node, user: Node):
if n not in node_to_last_use:
node_to_last_use[n] = user
user_to_last_uses.setdefault(user, []).append(n)
for node in reversed(nodes):
map_arg(node.args, lambda n: register_last_uses(n, node))
map_arg(node.kwargs, lambda n: register_last_uses(n, node))
# NOTE: we add a variable to distinguish body and ckpt_func
def delete_unused_values(user: Node, body):
"""
Delete values after their last use. This ensures that values that are
not used in the remainder of the code are freed and the memory usage
of the code is optimal.
"""
if user.op == 'placeholder':
return
if user.op == 'output':
body.append('\n')
return
nodes_to_delete = user_to_last_uses.get(user, [])
if len(nodes_to_delete):
to_delete_str = ' = '.join([repr(n) for n in nodes_to_delete] + ['None'])
body.append(f'; {to_delete_str}\n')
else:
body.append('\n')
# NOTE: we add a variable to distinguish body and ckpt_func
def emit_node(node: Node, body):
maybe_type_annotation = '' if node.type is None else f' : {type_repr(node.type)}'
if node.op == 'placeholder':
assert isinstance(node.target, str)
maybe_default_arg = '' if not node.args else f' = {repr(node.args[0])}'
free_vars.append(f'{node.target}{maybe_type_annotation}{maybe_default_arg}')
raw_name = node.target.replace('*', '')
if raw_name != repr(node):
body.append(f'{repr(node)} = {raw_name}\n')
return
elif node.op == 'call_method':
assert isinstance(node.target, str)
body.append(f'{repr(node)}{maybe_type_annotation} = {_format_target(repr(node.args[0]), node.target)}'
f'({_format_args(node.args[1:], node.kwargs)})')
return
elif node.op == 'call_function':
assert callable(node.target)
# pretty print operators
if node.target.__module__ == '_operator' and node.target.__name__ in magic_methods:
assert isinstance(node.args, tuple)
body.append(f'{repr(node)}{maybe_type_annotation} = '
f'{magic_methods[node.target.__name__].format(*(repr(a) for a in node.args))}')
return
# pretty print inplace operators; required for jit.script to work properly
# not currently supported in normal FX graphs, but generated by torchdynamo
if node.target.__module__ == '_operator' and node.target.__name__ in inplace_methods:
body.append(f'{inplace_methods[node.target.__name__].format(*(repr(a) for a in node.args))}; '
f'{repr(node)}{maybe_type_annotation} = {repr(node.args[0])}')
return
qualified_name = _get_qualified_name(node.target)
global_name = add_global(qualified_name, node.target)
# special case for getattr: node.args could be 2-argument or 3-argument
# 2-argument: attribute access; 3-argument: fall through to a getattr function call with default value
if global_name == 'getattr' and \
isinstance(node.args, tuple) and \
isinstance(node.args[1], str) and \
node.args[1].isidentifier() and \
len(node.args) == 2:
body.append(
f'{repr(node)}{maybe_type_annotation} = {_format_target(repr(node.args[0]), node.args[1])}')
return
body.append(
f'{repr(node)}{maybe_type_annotation} = {global_name}({_format_args(node.args, node.kwargs)})')
if node.meta.get('is_wrapped', False):
wrapped_fns.setdefault(global_name)
return
elif node.op == 'call_module':
assert isinstance(node.target, str)
body.append(f'{repr(node)}{maybe_type_annotation} = '
f'{_format_target(root_module, node.target)}({_format_args(node.args, node.kwargs)})')
return
elif node.op == 'get_attr':
assert isinstance(node.target, str)
body.append(f'{repr(node)}{maybe_type_annotation} = {_format_target(root_module, node.target)}')
return
elif node.op == 'output':
if node.type is not None:
maybe_return_annotation[0] = f" -> {type_repr(node.type)}"
body.append(self.generate_output(node.args[0]))
return
raise NotImplementedError(f'node: {node.op} {node.target}')
# Modified for activation checkpointing
ckpt_func = []
emit_code_with_activation_checkpoint(body, ckpt_func, nodes, emit_node, delete_unused_values)
if len(body) == 0:
# If the Graph has no non-placeholder nodes, no lines for the body
# have been emitted. To continue to have valid Python code, emit a
# single pass statement
body.append('pass\n')
if len(wrapped_fns) > 0:
wrap_name = add_global('wrap', torch.fx.wrap)
wrap_stmts = '\n'.join([f'{wrap_name}("{name}")' for name in wrapped_fns])
else:
wrap_stmts = ''
if self._body_transformer:
body = self._body_transformer(body)
for name, value in self.additional_globals():
add_global(name, value)
prologue = self.gen_fn_def(free_vars, maybe_return_annotation[0])
prologue = ''.join(ckpt_func) + prologue
code = ''.join(body)
code = '\n'.join(' ' + line for line in code.split('\n'))
fn_code = f"""
{wrap_stmts}
{prologue}
{code}"""
return PythonCode(fn_code, globals_)
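
For orientation, the source emitted by this CodeGen for a module with a single checkpointed region is expected to look roughly like the sketch below (all names are hypothetical); the ``checkpoint_*`` function is emitted above ``forward`` and later bound to the module by ``ColoGraphModule.bind()``.

```python
def checkpoint_0(self, x):
    linear_1 = self.linear_1(x)
    relu = torch.nn.functional.relu(linear_1)
    return relu

def forward(self, x):
    relu = torch.utils.checkpoint.checkpoint(self.checkpoint_0, x, use_reentrant=False)
    linear_2 = self.linear_2(relu)
    return linear_2
```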


@@ -0,0 +1,173 @@
import os
import warnings
from pathlib import Path
from typing import Any, Dict, Optional, Union
import torch
import torch.fx
import torch.nn as nn
from torch.fx.graph import PythonCode, _PyTreeCodeGen
from torch.fx.graph_module import _exec_with_source, _forward_from_src, _WrappedCall
from torch.nn.modules.module import _addindent
class ColoGraphModule(torch.fx.GraphModule):
"""
ColoGraphModule is an nn.Module generated from an fx.Graph.
ColoGraphModule has a ``graph`` attribute, as well as ``code`` and ``forward``
attributes generated from that ``graph``.
The difference between ``ColoGraphModule`` and ``torch.fx.GraphModule`` is that
``ColoGraphModule`` has a ``bind()`` function to bind customized functions
(i.e. activation checkpoint) to ``code`` of ``nn.Module``. If you want to use
specific features in Colossal-AI that are not supported by ``torch.fx.GraphModule``,
you can use ``ColoGraphModule`` instead.
``colossalai.fx.symbolic_trace()`` will return a ``ColoGraphModule`` as default.
.. warning::
When ``graph`` is reassigned, ``code`` and ``forward`` will be automatically
regenerated. However, if you edit the contents of the ``graph`` without reassigning
the ``graph`` attribute itself, you must call ``recompile()`` to update the generated
code.
"""
def __init__(self,
root: Union[torch.nn.Module, Dict[str, Any]],
graph: torch.fx.Graph,
class_name: str = 'GraphModule'):
super().__init__(root, graph, class_name)
def bind(self, ckpt_def, globals):
"""Bind function needed for correctly execute ``GraphModule.forward()``
We need to bind checkpoint functions to ``ColoGraphModule`` so that we could
correctly execute ``GraphModule.forward()``
Args:
ckpt_def (List[str]): definition before the forward function
globals (Dict[str, Any]): global variables
"""
ckpt_code = "\n".join(ckpt_def)
globals_copy = globals.copy()
_exec_with_source(ckpt_code, globals_copy)
func_list = [func for func in globals_copy.keys() if "checkpoint" in func or "pack" in func]
for func in func_list:
tmp_func = globals_copy[func]
setattr(self, func, tmp_func.__get__(self, self.__class__))
del globals_copy[func]
def recompile(self) -> PythonCode:
"""
Recompile this GraphModule from its ``graph`` attribute. This should be
called after editing the contained ``graph``, otherwise the generated
code of this ``GraphModule`` will be out of date.
"""
if isinstance(self._graph._codegen, _PyTreeCodeGen):
self._in_spec = self._graph._codegen.pytree_info.in_spec
self._out_spec = self._graph._codegen.pytree_info.out_spec
python_code = self._graph.python_code(root_module='self')
self._code = python_code.src
# To split ckpt functions code and forward code
_code_list = self._code.split("\n")
_fwd_def = [item for item in _code_list if "def forward" in item][0]
_fwd_idx = _code_list.index(_fwd_def)
ckpt_def = _code_list[:_fwd_idx]
self._code = "\n".join(_code_list[_fwd_idx:])
self.bind(ckpt_def, python_code.globals)
cls = type(self)
cls.forward = _forward_from_src(self._code, python_code.globals)
# Determine whether this class explicitly defines a __call__ implementation
# to wrap. If it does, save it in order to have wrapped_call invoke it.
# If it does not, wrapped_call can use a dynamic call to super() instead.
# In most cases, super().__call__ should be torch.nn.Module.__call__.
# We do not want to hold a reference to Module.__call__ here; doing so will
# bypass patching of torch.nn.Module.__call__ done while symbolic tracing.
cls_call = cls.__call__ if "__call__" in vars(cls) else None
if '_wrapped_call' not in vars(cls):
cls._wrapped_call = _WrappedCall(cls, cls_call) # type: ignore[attr-defined]
def call_wrapped(self, *args, **kwargs):
return self._wrapped_call(self, *args, **kwargs)
cls.__call__ = call_wrapped
# reset self._code to original src, otherwise to_folder will be wrong
self._code = python_code.src
return python_code
def to_folder(self, folder: Union[str, os.PathLike], module_name: str = "FxModule"):
"""Dumps out module to ``folder`` with ``module_name`` so that it can be
imported with ``from <folder> import <module_name>``
Args:
folder (Union[str, os.PathLike]): The folder to write the code out to
module_name (str): Top-level name to use for the ``Module`` while
writing out the code
"""
folder = Path(folder)
Path(folder).mkdir(exist_ok=True)
torch.save(self.state_dict(), folder / 'state_dict.pt')
tab = " " * 4
# we add import colossalai here
model_str = f"""
import torch
from torch.nn import *
import colossalai
class {module_name}(torch.nn.Module):
def __init__(self):
super().__init__()
"""
def _gen_model_repr(module_name: str, module: torch.nn.Module) -> Optional[str]:
safe_reprs = [nn.Linear, nn.Conv1d, nn.Conv2d, nn.Conv3d, nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d]
if type(module) in safe_reprs:
return f"{module.__repr__()}"
else:
return None
blobified_modules = []
for module_name, module in self.named_children():
module_str = _gen_model_repr(module_name, module)
if module_str is None:
module_file = folder / f'{module_name}.pt'
torch.save(module, module_file)
blobified_modules.append(module_name)
module_repr = module.__repr__().replace('\r', ' ').replace('\n', ' ')
module_str = f"torch.load(r'{module_file}') # {module_repr}"
model_str += f"{tab*2}self.{module_name} = {module_str}\n"
for buffer_name, buffer in self._buffers.items():
if buffer is None:
continue
model_str += f"{tab*2}self.register_buffer('{buffer_name}', torch.empty({list(buffer.shape)}, dtype={buffer.dtype}))\n"
for param_name, param in self._parameters.items():
if param is None:
continue
model_str += f"{tab*2}self.{param_name} = torch.nn.Parameter(torch.empty({list(param.shape)}, dtype={param.dtype}))\n"
model_str += f"{tab*2}self.load_state_dict(torch.load(r'{folder}/state_dict.pt'))\n"
model_str += f"{_addindent(self.code, 4)}\n"
module_file = folder / 'module.py'
module_file.write_text(model_str)
init_file = folder / '__init__.py'
init_file.write_text('from .module import *')
if len(blobified_modules) > 0:
warnings.warn("Was not able to save the following children modules as reprs -"
f"saved as pickled files instead: {blobified_modules}")


@@ -0,0 +1,211 @@
from dataclasses import dataclass, field
from typing import Callable, ClassVar, Dict, List, Optional, Tuple, Union
import torch
from torch.autograd.profiler_util import _format_memory, _format_time
from torch.fx import Graph, GraphModule, Node
from colossalai._analyzer.envs import MeshConfig
def intersect(a, b):
return {k: a[k] for k in a if k in b}
def subtract(a, b):
return {k: a[k] for k in a if k not in b}
def union(a, b):
return {**a, **b}
def compute_size_in_bytes(elem: Union[torch.Tensor, Dict, List, Tuple, int]) -> int:
"""Compute the size of a tensor or a collection of tensors in bytes.
Args:
elem (torch.Tensor | Dict | List | Tuple | int): Arbitrary nested ``torch.Tensor`` data structure.
Returns:
int: The size of the tensor or the collection of tensors in bytes.
"""
nbytes = 0
if isinstance(elem, torch.Tensor):
if elem.is_quantized:
nbytes += elem.numel() * torch._empty_affine_quantized([], dtype=elem.dtype).element_size()
else:
nbytes += elem.numel() * torch.tensor([], dtype=elem.dtype).element_size()
elif isinstance(elem, dict):
value_list = [v for _, v in elem.items()]
nbytes += compute_size_in_bytes(value_list)
elif isinstance(elem, tuple) or isinstance(elem, list) or isinstance(elem, set):
for e in elem:
nbytes += compute_size_in_bytes(e)
return nbytes
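# For example, compute_size_in_bytes(torch.empty(4, 8)) == 4 * 8 * 4 == 128 bytes with the
# default float32 dtype, and nested containers are summed recursively:
# compute_size_in_bytes({'a': torch.empty(2, 2), 'b': [torch.empty(3)]}) == 16 + 12 == 28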
@dataclass
class MetaInfo:
r"""
The base class to store all profiling and static graph analysis information
needed for auto-parallel system in Colossal-AI.
============================================================================
-------------------------------
| FX.Node | <-----
[input/param] are ---> |[input/param] [grad_inp]| [grad_inp] contributes to the
placeholders (might be | | \__________ | | profiled peak memory in backward
saved for backward. | | \ | | pass. [grad_param] is calculated
| | \ | | separately.
| [interm] -------> [grad_int]| <-----
| | \_________ | | [grad_interm] marks the peak
| / \ \ | | memory in backward pass.
[x] is not counted ---> | [x] [interm] --> [grad_int]| <-----
in [interm] because | | \_____ | |
it is not saved for | | \ | |
backward. | [output] \ | | <----- [output] is potentially
------------------------------- [input] for the next node.
============================================================================
Accumulate Size = ALL_PREVIOUS_CTX U {Interm Size + Output Size}
Output Size = ([output] in global_ctx and not is_alias)
Temp Size = ([output] not in global_ctx and not is_alias)
Backward Size = ([grad_inp])
Usage:
>>> for node in graph.nodes:
>>> n_info = MetaInfo(node) # will create a new MetaInfo instance and store in node.meta['info']
>>> # if not exist, otherwise return the existing one
>>> n_info.to_recompute = ... # set the to_recompute attribute
Remarks:
This feature is experimental and all the entries are subject to change.
"""
# reference
node: Node
# directory
mod_dir: str = ''
# ctx[data_ptr] = Tensor
# mark the storage for ctx.save_for_backward
global_ctx: Dict[str, torch.Tensor] = field(default_factory=lambda: {}) # globally shared
curr_ctx: Dict[str, torch.Tensor] = field(default_factory=lambda: {}) # global_ctx till this node
# should be updated after each graph manipulation
# ============================== Update ====================================
# parameter and buffer within ``Node``
parameters: Dict[str, torch.nn.Parameter] = field(default_factory=lambda: {})
buffers: Dict[str, torch.Tensor] = field(default_factory=lambda: {})
inputs: Tuple[torch.Tensor] = ()
outputs: Tuple[torch.Tensor] = ()
is_alias: Tuple[bool] = () # whether the output is an alias of input
# compute cost
fwd_flop: Optional[int] = 0
bwd_flop: Optional[int] = 0
# communication cost (should be the size in bytes of communication)
fwd_comm: Optional[int] = 0
bwd_comm: Optional[int] = 0
# should keep the same whenever manipulated
# ============================= Invariant ==================================
to_recompute: Tuple[torch.Tensor] = () # (region_0, region_1, ...) support nested codegen
to_offload: Optional[bool] = False
sharding_spec: str = 'RR'
def __new__(cls, node: Node, **kwargs):
orig_init = cls.__init__
# if initialized, return the existing one
# should disable the __init__ function
if node.meta.get('info', None) is not None:
def _dummy(self, *args, **kwargs):
if getattr(self, '_is_init', False):
self._is_init = True
orig_init(self, *args, **kwargs)
cls.__init__ = orig_init
cls.__init__ = _dummy
return node.meta['info']
return super().__new__(cls)
def __post_init__(self):
self.node.meta['info'] = self
@property
def fwd_time(self, tflops: float = MeshConfig.TFLOPS, bandwidth: float = MeshConfig.BANDWIDTH):
return self.fwd_flop / tflops + self.fwd_comm / bandwidth
@property
def bwd_time(self, tflops: float = MeshConfig.TFLOPS, bandwidth: float = MeshConfig.BANDWIDTH):
return self.bwd_flop / tflops + self.bwd_comm / bandwidth
@property
def param_size(self):
return compute_size_in_bytes(self.parameters)
@property
def buffer_size(self):
return compute_size_in_bytes(self.buffers)
@property
def output_size(self):
"""Used in CheckpointSolver"""
output_ctx = {
o.data_ptr(): o
for o, is_alias in zip(self.outputs, self.is_alias)
if not is_alias and isinstance(o, torch.Tensor) and not isinstance(o, torch.nn.Parameter)
}
return compute_size_in_bytes(intersect(self.global_ctx, output_ctx))
@property
def accumulate_size(self):
"""Used in CheckpointSolver"""
output_ctx = {
o.data_ptr(): o
for o, is_alias in zip(self.outputs, self.is_alias)
if not is_alias and isinstance(o, torch.Tensor) and not isinstance(o, torch.nn.Parameter)
}
return compute_size_in_bytes(union(self.curr_ctx, intersect(self.global_ctx, output_ctx)))
@property
def temp_size(self):
"""Used in CheckpointSolver"""
output_ctx = {
o.data_ptr(): o
for o, is_alias in zip(self.outputs, self.is_alias)
if not is_alias and isinstance(o, torch.Tensor) and not isinstance(o, torch.nn.Parameter)
}
return compute_size_in_bytes(subtract(output_ctx, self.global_ctx))
@property
def backward_size(self):
"""Used in CheckpointSolver"""
return compute_size_in_bytes(self.inputs)
def __repr__(self):
s = f'Node {self.node.name}'
if self.parameters:
s += f'\n\thas parameter of size {_format_memory(self.param_size)}'
if self.buffers:
s += f'\n\thas buffer of size {_format_memory(self.buffer_size)}'
if self.output_size:
s += f'\n\thas output activation of size {_format_memory(self.output_size)}'
if self.accumulate_size:
s += f'\n\thas accumulated activation of size {_format_memory(self.accumulate_size)}'
if self.temp_size:
s += f'\n\thas temp activation of size {_format_memory(self.temp_size)}'
if self.backward_size:
s += f'\n\thas backward activation of size {_format_memory(self.backward_size)}'
s += f'\n\tfwd_flop = {self.fwd_flop}'\
f'\n\tbwd_flop = {self.bwd_flop}'\
f'\n\tfwd_comm = {self.fwd_comm}'\
f'\n\tbwd_comm = {self.bwd_comm}'\
f'\n\tto_recompute = {self.to_recompute}'\
f'\n\tto_offload = {self.to_offload}'\
f'\n\tsharding_spec = {self.sharding_spec}'
return s
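
A sketch of how the stored metadata might be inspected once the shape and profiling passes defined later in this PR have run; ``gm`` stands for a traced ``ColoGraphModule``.

```python
for node in gm.graph.nodes:
    info = MetaInfo(node)    # returns the instance already stored in node.meta['info']
    print(node.name, info.fwd_flop, info.bwd_flop, info.output_size)
```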


@@ -0,0 +1,2 @@
from .graph_profile import graph_profile_pass
from .shape_prop import ShapeProp, shape_prop_pass, sim_env


@@ -0,0 +1,347 @@
from typing import Any, Dict, Iterator, List, Optional, Tuple, Union
import torch
import torch.fx
from torch.autograd.profiler_util import _format_memory, _format_time
from torch.fx import GraphModule
from torch.fx.node import Argument, Node, Target
from colossalai._analyzer._subclasses import flop_count
from colossalai._analyzer.fx.node_util import MetaInfo
def _format_flops(flops: float) -> str:
"""Returns a formatted FLOP size string"""
if flops > 1e12:
return f'{flops / 1e12:.2f} TFLOPs'
elif flops > 1e9:
return f'{flops / 1e9:.2f} GFLOPs'
elif flops > 1e6:
return f'{flops / 1e6:.2f} MFLOPs'
elif flops > 1e3:
return f'{flops / 1e3:.2f} kFLOPs'
return f'{flops} FLOPs'
def _denormalize_tuple(t: Tuple[int, ...]) -> Tuple[int, ...]:
return t[0] if len(t) == 1 else t
def _normalize_tuple(x):
if not isinstance(x, tuple):
return (x,)
return x
def _current_device(module):
return next(module.parameters()).device
class GraphProfiler(torch.fx.Interpreter):
"""
Fetch shape argument from ``ShapeProp`` without re-executing
the ``GraphModule`` from scratch.
"""
_profileable = [
'call_function',
'call_module',
'call_method',
]
def __init__(self, module: GraphModule, garbage_collect_values: bool = True):
super().__init__(module, garbage_collect_values)
def run(self, *args, initial_env: Optional[Dict[Node, Any]] = None, enable_io_processing: bool = True) -> Any:
"""
Run `module` via interpretation and return the result.
Args:
*args: The arguments to the Module to run, in positional order
initial_env (Optional[Dict[Node, Any]]): An optional starting environment for execution.
This is a dict mapping `Node` to any value. This can be used, for example, to
pre-populate results for certain `Nodes` so as to do only partial evaluation within
the interpreter.
enable_io_processing (bool): If true, we process the inputs and outputs with graph's process_inputs and
process_outputs function first before using them.
Returns:
Any: The value returned from executing the Module
"""
self.env = initial_env if initial_env else {}
# Positional function args are consumed left-to-right by
# `placeholder` nodes. Use an iterator to keep track of
# position and extract those values.
if enable_io_processing:
args = self.module.graph.process_inputs(*args)
self.args_iter: Iterator[Any] = iter(args)
for node in self.module.graph.nodes:
self.run_node(node) # No need to store.
if self.garbage_collect_values:
for to_delete in self.user_to_last_uses.get(node, []):
del self.env[to_delete]
if node.op == 'output':
output_val = self.env[node]
return self.module.graph.process_outputs(output_val) if enable_io_processing else output_val
def fetch_initial_env(self, device=None) -> Dict[Node, Any]:
"""
Fetch ``initial_env`` for execution. This is because ``ShapeProp``
has already attached outputs of each ``Node`` to its ``MetaInfo``.
Args:
device (torch.device): The device to place the execution, default to ``None``
Returns:
Dict[Node, Any]: The initial environment for execution
"""
initial_env = {}
for n in self.module.graph.nodes:
initial_env[n] = _denormalize_tuple(MetaInfo(n).outputs)
return initial_env
def propagate(self, *args, device=None):
"""
Run `module` via interpretation and profile the execution
of each ``Node``.
Args:
*args (Tensor): The sample input, not used
device (torch.device): The device to place the execution, default to ``None``
Returns:
Any: The value returned from executing the Module
"""
initial_env = self.fetch_initial_env(device)
return self.run(initial_env=initial_env)
def summary(self) -> str:
"""
Summarizes the profiled statistics of the `GraphModule` in
tabular format. Note that this API requires the ``tabulate`` module
to be installed.
Returns:
str: The summary of the profiled statistics
"""
# https://github.com/pytorch/pytorch/blob/master/torch/fx/graph.py
try:
from tabulate import tabulate
except ImportError:
print("`summary` relies on the library `tabulate`, "
"which could not be found on this machine. Run `pip "
"install tabulate` to install the library.")
# Build up a list of summary information for each node
node_summaries: List[List[Any]] = []
last_n_info = None
for node in self.module.graph.nodes:
node: Node
n_info = MetaInfo(node)
last_n_info = last_n_info or n_info
node_summaries.append([
node.op,
str(node),
_format_memory(n_info.accumulate_size),
_format_memory(n_info.accumulate_size - last_n_info.accumulate_size),
_format_memory(n_info.output_size),
_format_memory(n_info.temp_size),
_format_memory(n_info.param_size),
_format_memory(n_info.backward_size),
_format_flops(n_info.fwd_flop),
_format_flops(n_info.bwd_flop),
])
last_n_info = n_info
# Use the ``tabulate`` library to create a well-formatted table
# presenting our summary information
headers: List[str] = [
'Op type',
'Op',
'Accumulate size',
'Incremental size',
'Output size',
'Temp size',
'Param size',
'Backward size',
'Fwd FLOPs',
'Bwd FLOPs',
]
return tabulate(node_summaries, headers=headers, stralign='right')
class CommunicationProfiler(GraphProfiler):
"""
TODO(lyl): Add this for all comm nodes
"""
def __init__(self, module: GraphModule, garbage_collect_values: bool = True):
raise NotImplementedError()
class FlopProfiler(GraphProfiler):
"""
Execute an FX graph Node-by-Node and record the meta data of the result
into the corresponding node.
Usage:
>>> model = MyModule()
>>> x = torch.rand(10, 10)
>>> gm = colossalai.fx.symbolic_trace(model, meta_args = {'x': x})
>>> shape_interp = ShapeProp(gm) # must do this first
>>> shape_interp.propagate(x)
>>> profiler = FlopProfiler(gm)
>>> profiler.propagate(x)
Args:
module (GraphModule): The module to be executed
Hints:
If you want to add a new flop count rule, you can first
check the existing files in ``../_subclasses/flop_tensor.py``.
If your flop count rules are incompatible with the existing
ones, you can do so by adding a new method to this class
with the ``@register_flop_count_impl`` decorator. The method
should take (*args, **kwargs) as its input and
generate the flop count for both forward and backward as its
output.
For example, if you want to add a flop count rule for
``my_fn``, which is a hand-written operator not detected by
PyTorch, you can do so by adding a new method to this
class with the ``@register_flop_count_impl`` decorator:
>>> @register_flop_count_impl(my_fn)
>>> def my_fn_flop_count_impl(*args, **kwargs):
>>> return 0, 0
"""
_custom_flop_count_impl = {}
def run_node(self, n: torch.fx.Node) -> Any:
"""
Run a specific node ``n`` and profile its execution time and memory usage.
Calls into call_function, call_method, and call_module only.
Args:
n (Node): The Node to profile
Returns:
Any: The output of the node
Raises:
RuntimeError: If the node is not profileable.
"""
args, kwargs = self.fetch_args_kwargs_from_env(n)
n_info = MetaInfo(n)
if n.op in self._profileable:
try:
(
n_info.fwd_flop,
n_info.bwd_flop,
) = getattr(self, n.op)(n.target, args, kwargs)
except Exception as e:
raise RuntimeError(
f'Error {str(e)} occurred when profiling node {n}, node.target = {n.target}. '
f'Please refer to function\'s docstring to register the relevant profile_impl for this node!'
) from e
# reset gradients accumulated while profiling this node
for param in self.module.parameters():
param.grad = None
return _denormalize_tuple(n_info.outputs)
def call_function(self, target: 'Target', args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any:
"""
Execute a ``call_function`` node and return the profiling result.
Dispatch to ``_custom_flop_count_impl`` if ``call_function`` should be
profiled in a user-defined behavior.
Args:
target (Target): The call target for this node. See
`Node <https://pytorch.org/docs/master/fx.html#torch.fx.Node>`__ for
details on semantics
args (Tuple): Tuple of positional args for this invocation
kwargs (Dict): Dict of keyword arguments for this invocation
Return
flop_count (Tuple[int]): (fwd_flop, bwd_flop)
"""
assert not isinstance(target, str)
# Dispatch the impl for profiling, default will be ``flop_count``
if target in self._custom_flop_count_impl:
return self._custom_flop_count_impl[target](*args, **kwargs)
else:
return flop_count(target, *args, **kwargs)
def call_method(self, target: 'Target', args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any:
"""
Execute a ``call_method`` node and return the profiling result.
Args:
target (Target): The call target for this node. See
`Node <https://pytorch.org/docs/master/fx.html#torch.fx.Node>`__ for
details on semantics
args (Tuple): Tuple of positional args for this invocation
kwargs (Dict): Dict of keyword arguments for this invocation
Return
flop_count (Tuple[int]): (fwd_flop, bwd_flop)
"""
# Execute the method and return the result
assert isinstance(target, str)
return flop_count(getattr(torch.Tensor, target), *args, **kwargs)
def call_module(self, target: 'Target', args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any:
"""
Execute a ``call_module`` node and return the profiling result.
Args:
target (Target): The call target for this node. See
`Node <https://pytorch.org/docs/master/fx.html#torch.fx.Node>`__ for
details on semantics
args (Tuple): Tuple of positional args for this invocation
kwargs (Dict): Dict of keyword arguments for this invocation
Return
flop_count (Tuple[int]): (fwd_flop, bwd_flop)
"""
# Retrieve executed args and kwargs values from the environment
# Execute the method and return the result
assert isinstance(target, str)
submod = self.fetch_attr(target)
return flop_count(submod, *args, **kwargs)
def graph_profile_pass(module: GraphModule, *args, verbose=False) -> GraphModule:
"""
Run ``module`` via interpretation and profile the execution
of each ``Node``.
Args:
module (GraphModule): The GraphModule to profile
*args (Any): The sample input, not used
verbose (bool): Whether to print the profiling summary
Returns:
GraphModule: The same GraphModule with profiling information
"""
for profiler_cls in (FlopProfiler,
# CommunicationProfiler, # TODO: add communication profiling
):
profiler = profiler_cls(module)
profiler.propagate(*args, device=_current_device(module))
if verbose:
print(profiler.summary())
return module


@@ -0,0 +1,194 @@
"""``torch.fx.ShapeProp``, but with ``MetaTensor``"""
from typing import Any, Callable, Dict, Optional, Tuple, Union
import torch
import torch.fx
from torch.autograd.graph import saved_tensors_hooks
from torch.utils._pytree import tree_map
from colossalai._analyzer._subclasses import MetaTensor, MetaTensorMode
from colossalai._analyzer.fx.node_util import MetaInfo
from colossalai.fx._compatibility import compatibility
Target = Union[Callable[..., Any], str]
class sim_env(saved_tensors_hooks):
"""
A simulation of memory allocation and deallocation in the forward pass
using ``saved_tensor_hooks``.
Attributes:
ctx (Dict[int, torch.Tensor]): A dictionary that maps the
data pointer of a tensor to the tensor itself. This is used
to track the memory allocation and deallocation.
param_ctx (Dict[int, torch.Tensor]): A dictionary that maps the
data pointer of all model parameters to the parameter itself.
This avoids overestimating the memory usage of the intermediate activations.
"""
def __init__(self, module: Optional[torch.nn.Module] = None):
super().__init__(self.pack_hook, self.unpack_hook)
self.ctx = {}
self.param_ctx = {param.data_ptr(): param for param in module.parameters()} if module else {}
self.buffer_ctx = {buffer.data_ptr(): buffer for buffer in module.buffers()} if module else {}
def pack_hook(self, tensor: torch.Tensor):
if tensor.data_ptr() not in self.param_ctx and tensor.data_ptr() not in self.buffer_ctx:
self.ctx[tensor.data_ptr()] = tensor
return tensor
def unpack_hook(self, tensor):
return tensor
def _normalize_tuple(x):
if not isinstance(x, tuple):
return (x,)
return x
def _current_device(module):
return next(module.parameters()).device
@compatibility(is_backward_compatible=False)
class ShapeProp(torch.fx.Interpreter):
"""
Execute an FX graph Node-by-Node and record the meta data of the result
into the corresponding node.
Usage:
>>> model = MyModule()
>>> x = torch.rand(10, 10)
>>> gm = colossalai.fx.symbolic_trace(model, meta_args = {'x': x})
>>> interp = ShapeProp(gm)
>>> interp.propagate(x)
Args:
module (GraphModule): The module to be executed
Hints:
If you want to add a new shape propagation rule, you can do so by
adding a new method to this class with the ``@register_shape_impl``
decorator. The method should take (*args, **kwargs) instance as its
input and generate output.
For example, if you want to add a shape propagation rule for
``torch.nn.functional.linear``, you can do so by adding a new method
to this class with the ``@register_shape_impl`` decorator (Since the
``MetaTensorMode`` is compatible with ``torch.nn.functional.linear``,
in practice you don't have to do as follows):
>>> @register_shape_impl(torch.nn.functional.linear)
>>> def linear_shape_impl(*args, **kwargs):
>>> # do something here
>>> return torch.empty(output_shape, device=output_device)
"""
_custom_dispatch_func = {}
_mode = MetaTensorMode()
def __init__(self, module: torch.fx.GraphModule, garbage_collect_values: bool = True):
super().__init__(module, garbage_collect_values)
self.global_hook = sim_env(module=self.module)
def run_node(self, n: torch.fx.Node) -> Any:
"""
Run a specific node ``n`` and return the result. Attach
(
``inputs``, ``outputs``, ``parameters``, ``buffers``
) to ``n``.
Args:
n (Node): The ``Node`` to execute
Returns:
Any: The result of executing ``n``
"""
args, kwargs = self.fetch_args_kwargs_from_env(n)
with self.global_hook:
r = getattr(self, n.op)(n.target, args, kwargs)
unwrap_fn = lambda elem: elem._tensor if isinstance(elem, MetaTensor) else elem
is_pure_tensor = lambda elem: isinstance(elem, MetaTensor) and not isinstance(elem, torch.nn.Parameter)
n_info = MetaInfo(n)
n_info.outputs = _normalize_tuple(r)
if n.op == 'call_module':
submod = self.fetch_attr(n.target)
n_info.parameters.update({k: MetaTensor(v) for k, v in submod.named_parameters()})
n_info.buffers.update({k: MetaTensor(v) for k, v in submod.named_buffers()})
else:
n_info.parameters.update({
k.name: MetaTensor(v)
for k, v in zip(n.args, args)
if isinstance(k, torch.fx.Node) and isinstance(v, torch.nn.Parameter)
})
n_info.parameters.update({k: MetaTensor(v) for k, v in kwargs.items() if isinstance(v, torch.nn.Parameter)})
n_info.inputs = tuple(v for v in args if is_pure_tensor(v)) + \
tuple(v for v in kwargs.values() if is_pure_tensor(v))
n._meta_data = tree_map(unwrap_fn, _normalize_tuple(r)) # align with SPMD
n_info.global_ctx = self.global_hook.ctx
n_info.curr_ctx = self.global_hook.ctx.copy()
crit = lambda x: x.data_ptr() in self.global_hook.ctx if isinstance(x, torch.Tensor) else False
n_info.is_alias = _normalize_tuple(tree_map(crit, n_info.outputs))
return r
def call_function(self, target: 'Target', args: Tuple[Any, ...], kwargs: Dict[str, Any]) -> Any:
"""
Execute a ``call_function`` node and return the result.
If the target of ``Node`` is registered with ``@register_shape_impl``,
the registered function will be used to execute the node. This is common
if we insert some customized kernels.
Args:
target (Target): The call target for this node. See
`Node <https://pytorch.org/docs/master/fx.html#torch.fx.Node>`__ for
details on semantics
args (Tuple): Tuple of positional args for this invocation
kwargs (Dict): Dict of keyword arguments for this invocation
Return
Any: The value returned by the function invocation
"""
if target in self._custom_dispatch_func:
return self._custom_dispatch_func[target](*args, **kwargs)
else:
return super().call_function(target, args, kwargs)
def propagate(self, *args, device=None):
"""
Run `module` via interpretation and return the result and record the
shape of each node.
Args:
*args (Tensor): The sample input.
Returns:
Any: The value returned from executing the Module
"""
wrap_fn = lambda elem: MetaTensor(elem, device=device)
with self._mode:
return super().run(*tree_map(wrap_fn, args))
def shape_prop_pass(module: torch.fx.GraphModule, *args) -> torch.fx.GraphModule:
"""
Run ``module`` via interpretation and record the shape of each ``Node``.
Args:
module (GraphModule): The GraphModule to profile
*args (Any): The sample input
Returns:
GraphModule: The same GraphModule with shape information
"""
ShapeProp(module).propagate(*args, device=_current_device(module))
return module
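
A minimal sketch of running shape propagation on its own, mirroring the ``ShapeProp`` docstring above; ``model`` and ``x`` are placeholders and the import paths are assumed from the package layout in this PR.

```python
import torch
from colossalai._analyzer.fx import symbolic_trace
from colossalai._analyzer.fx.passes import shape_prop_pass

gm = symbolic_trace(model, meta_args={'x': x})
gm = shape_prop_pass(gm, x)
for node in gm.graph.nodes:
    print(node.name, getattr(node, '_meta_data', None))
```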


@@ -0,0 +1,40 @@
import torch
import torch.fx
from torch.fx import GraphModule
from .passes import ShapeProp, graph_profile_pass, shape_prop_pass
from .passes.graph_profile import FlopProfiler
def register_flop_count_impl(func):
def wrapper(impl):
FlopProfiler._custom_flop_count_impl[func] = impl
return impl
return wrapper
def register_shape_impl(func):
def wrapper(impl):
ShapeProp._custom_dispatch_func[func] = impl
return impl
return wrapper
def symbolic_profile(module: GraphModule, *args, verbose=False) -> GraphModule:
"""Symbolically profile a model with sample inputs.
Args:
module (GraphModule): The module to be profiled
args (Tuple): The sample inputs
verbose (bool): Whether to print the profiling result
Returns:
GraphModule: The profiled module
"""
module = shape_prop_pass(module, *args)
module = graph_profile_pass(module, *args, verbose=verbose)
return module
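
A sketch of how ``symbolic_profile()`` is exercised by the tests in this PR: build the model under ``MetaTensorMode`` (so no real memory is allocated), trace it with a sample input, then profile it symbolically.

```python
import torch
import torchvision.models as tm

from colossalai._analyzer._subclasses import MetaTensorMode
from colossalai._analyzer.fx import symbolic_profile, symbolic_trace

with MetaTensorMode():
    # parameters are materialized as meta tensors, so this is cheap even for large models
    model = tm.resnet18()

data = torch.rand(8, 3, 224, 224)
gm = symbolic_trace(model, meta_args={'x': data})
gm = symbolic_profile(gm, data, verbose=True)    # verbose=True prints the per-node profiling result
```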


@ -0,0 +1,620 @@
import functools
import inspect
import operator
from contextlib import contextmanager
from typing import Any, Callable, Dict, Iterable, List, Optional, Set, Tuple, Type, Union
import torch
import torch.nn as nn
from torch.fx import Graph, Node, Proxy, Tracer
from torch.fx.graph import _Namespace
from torch.utils._pytree import tree_map
from colossalai._analyzer._subclasses import MetaTensor, _TensorPropertyMethod, _TorchFactoryMethod
from .codegen import ActivationCheckpointCodeGen
from .graph_module import ColoGraphModule
from .node_util import MetaInfo
Target = Union[Callable[..., Any], str]
Argument = Optional[Union[Tuple[Any, ...], # actually Argument, but mypy can't represent recursive types
List[Any], # actually Argument
Dict[str, Any], # actually Argument
slice, # Slice[Argument, Argument, Argument], but slice is not a templated type in typing
'Node',]]
zeros = torch.zeros    # keep a reference to the original factory, since ``torch.zeros`` is overridden during tracing
def _truncate_suffix(s: str):
import re
# FIXME: don't know why but torch.fx always gets a suffix like '_1' in the name
return re.sub(r'_\d+$', '', s)
def _default_device():
return torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')
def _current_device(module):
try:
return next(module.parameters()).device
    except (StopIteration, AttributeError):
        # parameter-less modules (or plain callables) fall back to the default device
        return _default_device()
def register_tracer_impl(func: Callable[..., Any], name: Optional[str] = '_custom_impl'):
def wrapper(impl):
assert hasattr(ColoTracer, name), f"Cannot register {func.__name__} in ColoTracer.{name}"
getattr(ColoTracer, name)[func] = impl
return impl
return wrapper
def register_leaf_module_impl(module: Type[nn.Module]):
def wrapper(impl):
ColoTracer._custom_leaf_module_impl[module] = impl
return impl
return wrapper
def register_leaf_module(module: Type[nn.Module]):
ColoTracer._custom_leaf_module.add(module)
def register_non_leaf_module(module: Type[nn.Module]):
ColoTracer._custom_non_leaf_module.add(module)
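
A sketch of how the registration helpers above can be used: mark a user-defined module as a leaf so the tracer does not step into it, and supply a meta implementation so shape propagation still works. ``FusedGELU`` is a hypothetical custom-kernel wrapper, and the import path of the helpers is assumed here.

```python
import torch
import torch.nn as nn

from colossalai._analyzer.fx import symbolic_trace
# NOTE: assumed import path -- these helpers are defined in this file of the PR
from colossalai._analyzer.fx.tracer import register_leaf_module, register_leaf_module_impl


class FusedGELU(nn.Module):    # hypothetical wrapper around a custom kernel

    def forward(self, x):
        return torch.nn.functional.gelu(x)


class Net(nn.Module):

    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(4, 4)
        self.act = FusedGELU()

    def forward(self, x):
        return self.act(self.linear(x))


register_leaf_module(FusedGELU)    # keep the tracer out of FusedGELU.forward


@register_leaf_module_impl(FusedGELU)
def fused_gelu_meta(x):
    # called with the (meta) inputs of the leaf module; must return outputs with the right shape/dtype
    return torch.nn.functional.gelu(x)


gm = symbolic_trace(Net(), meta_args={'x': torch.rand(2, 4)})
print(gm.code)    # ``act`` shows up as a single call_module node
```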
class ColoProxy(Proxy):
_func_dispatch: Dict[Target, Callable[..., Any]] = {}
def __init__(self, *args, data=None, **kwargs):
super().__init__(*args, **kwargs)
self._meta_data = data
@property
def meta_data(self):
return self._meta_data
@meta_data.setter
def meta_data(self, args):
wrap_fn = lambda x: MetaTensor(x) if isinstance(x, torch.Tensor) else x
self._meta_data = tree_map(wrap_fn, args)
@classmethod
def __torch_function__(cls, orig_method, types, args=(), kwargs=None):
kwargs = {} if kwargs is None else kwargs
if orig_method in cls._func_dispatch:
impl = cls._func_dispatch.pop(orig_method) # avoid recursion
proxy = impl(*args, **kwargs)
cls._func_dispatch[orig_method] = impl
return proxy
else:
proxy = cls.from_torch_proxy(super().__torch_function__(orig_method, types, args, kwargs))
unwrap_fn = lambda p: p.meta_data if isinstance(p, ColoProxy) else p
if proxy.meta_data is None:
proxy.meta_data = orig_method(*tree_map(unwrap_fn, args), **tree_map(unwrap_fn, kwargs))
return proxy
@classmethod
def from_torch_proxy(cls, proxy: Proxy):
return cls(proxy.node, proxy.tracer)
def __repr__(self):
return f"ColoProxy({self.node.name}, meta_data={self.meta_data})"
def __len__(self):
return len(self.meta_data)
def __int__(self):
return int(self.meta_data)
def __index__(self):
try:
return int(self.meta_data)
except:
return zeros(self.meta_data.shape, dtype=torch.bool).numpy().__index__()
def __float__(self):
return float(self.meta_data)
def __bool__(self):
return self.meta_data
def __getattr__(self, k):
return ColoAttribute(self, k, getattr(self._meta_data, k, None))
def __setitem__(self, key, value):
proxy = self.tracer.create_proxy('call_function', operator.setitem, (self, key, value), {})
proxy.meta_data = self._meta_data
return proxy
def __contains__(self, key):
if self.node.op == "placeholder":
            # this handles expressions like ``if x in kwargs``;
            # such membership tests on placeholders are not supported yet, so we return False
return False
return super().__contains__(key)
def __isinstancecheck__(self, type):
return isinstance(self.meta_data, type)
def size(self, dim=None):
        if self._meta_data is not None:
return self._meta_data.size(*[dim] if dim else [])
return self.tracer.create_proxy('call_method', 'size', (self, dim) if dim else (self,), {})
def dim(self):
if self._meta_data is not None:
return self._meta_data.dim()
return self.tracer.create_proxy('call_method', 'dim', (self,), {})
@property
def shape(self):
if self._meta_data is not None:
return self._meta_data.shape
return self.tracer.create_proxy('call_function', getattr, (self, 'shape'), {})
@property
def ndim(self):
if self._meta_data is not None:
return self._meta_data.ndim
return self.tracer.create_proxy('call_function', getattr, (self, 'ndim'), {})
@property
def device(self):
if self._meta_data is not None:
return self._meta_data.device
return self.tracer.create_proxy('call_function', getattr, (self, 'device'), {})
@property
def dtype(self):
if self._meta_data is not None:
return self._meta_data.dtype
return self.tracer.create_proxy('call_function', getattr, (self, 'dtype'), {})
def to(self, *args, **kwargs):
return self.tracer.create_proxy('call_method', 'to', (self, *args), {**kwargs})
def cpu(self, *args, **kwargs):
return self.tracer.create_proxy('call_method', 'cpu', (self, *args), {**kwargs})
def cuda(self, *args, **kwargs):
return self.tracer.create_proxy('call_method', 'cuda', (self, *args), {**kwargs})
class ColoAttribute(ColoProxy):
def __init__(self, root, attr: str, data=None):
self.root = root
self.attr = attr
self.tracer = root.tracer
self._meta_data = data
self._node: Optional[Node] = None
@property
def node(self):
# the node for attributes is added lazily, since most will just be method calls
# which do not rely on the getitem call
if self._node is None:
self._node = self.tracer.create_proxy('call_function', getattr, (self.root, self.attr), {}).node
return self._node
def __call__(self, *args, **kwargs):
return self.tracer.create_proxy('call_method', self.attr, (self.root,) + args, kwargs)
def __repr__(self):
return f"ColoAttribute({self.node.name}, attr={self.attr})"
class ColoTracer(Tracer):
_custom_leaf_module: Set[Type[nn.Module]] = set()
_custom_leaf_module_impl: Dict[Type[nn.Module], Callable[..., Any]] = {}
_custom_non_leaf_module: Set[Type[nn.Module]] = set()
_custom_impl: Dict[Callable[..., Any], Callable[..., Any]] = {}
_bias_addition_impl: Dict[Callable[..., Any], Callable[..., Any]] = {}
_bias_addition_module = [
torch.nn.Linear,
torch.nn.Conv1d,
torch.nn.Conv2d,
torch.nn.Conv3d,
torch.nn.ConvTranspose1d,
torch.nn.ConvTranspose2d,
torch.nn.ConvTranspose3d,
]
def __init__(self, trace_act_ckpt: bool = False, bias_addition_split: bool = False, *args, **kwargs):
super().__init__(*args, **kwargs)
self.disable_module_getattr = False
self.proxy_buffer_attributes = True
# whether the tracer will record the usage of torch.utils.checkpoint
self.trace_act_ckpt = trace_act_ckpt
self.ckpt_regions = []
self.ckpt_idx = 0
self.mod_dir = ''
# whether the tracer should split the bias_add ops into two ops
self.bias_addition_split = bias_addition_split
def is_leaf_module(self, m: nn.Module, module_qualified_name: str) -> bool:
        # if bias-addition split is enabled and the module has a bias, it is not a leaf module;
        # we will enter the module and split the bias-addition ops
if self.bias_addition_split and type(m) in self._bias_addition_module and m.bias is not None:
return False
# user can specify which modules are leaf modules and which are not
return (type(m) not in self._custom_non_leaf_module
and (type(m) in self._custom_leaf_module or super().is_leaf_module(m, module_qualified_name)))
def call_module(self, m: torch.nn.Module, forward: Callable[..., Any], args: Tuple[Any, ...],
kwargs: Dict[str, Any]) -> Any:
curr_dir = self.mod_dir
self.mod_dir = 'self.' + self.path_of_module(m)
rst = super().call_module(m, forward, args, kwargs)
self.mod_dir = curr_dir
return rst
def proxy(self, node: Node) -> 'ColoProxy':
return ColoProxy(node, self)
def create_proxy(self,
kind: str,
target: Target,
args: Tuple[Any, ...],
kwargs: Dict[str, Any],
name: Optional[str] = None,
type_expr: Optional[Any] = None,
proxy_factory_fn: Callable[[Node], 'Proxy'] = None):
proxy: ColoProxy = super().create_proxy(kind, target, args, kwargs, name, type_expr, proxy_factory_fn)
unwrap_fn = lambda p: p.meta_data if isinstance(p, ColoProxy) else p
if kind == 'placeholder':
proxy.meta_data = self.meta_args[target] if target in self.meta_args else self.concrete_args.get(
_truncate_suffix(target), None)
elif kind == 'get_attr':
self.disable_module_getattr = True
try:
attr_itr = self.root
atoms = target.split(".")
for atom in atoms:
attr_itr = getattr(attr_itr, atom)
proxy.meta_data = attr_itr
finally:
self.disable_module_getattr = False
elif kind == 'call_function':
proxy.meta_data = target(*tree_map(unwrap_fn, args), **tree_map(unwrap_fn, kwargs))
elif kind == 'call_method':
self.disable_module_getattr = True
try:
if target == '__call__':
proxy.meta_data = unwrap_fn(args[0])(*tree_map(unwrap_fn, args[1:]), **tree_map(unwrap_fn, kwargs))
else:
if target not in _TensorPropertyMethod:
proxy._meta_data = getattr(unwrap_fn(args[0]), target)(*tree_map(unwrap_fn, args[1:]),
**tree_map(unwrap_fn, kwargs))
finally:
self.disable_module_getattr = False
elif kind == 'call_module':
mod = self.root.get_submodule(target)
self.disable_module_getattr = True
try:
proxy.meta_data = self._custom_leaf_module_impl.get(type(mod),
mod.forward)(*tree_map(unwrap_fn, args),
**tree_map(unwrap_fn, kwargs))
finally:
self.disable_module_getattr = False
return proxy
def create_node(self, *args, **kwargs) -> Node:
node = super().create_node(*args, **kwargs)
        # ``MetaInfo`` registers itself on ``node.meta['info']``, so the reference does not need to be kept
        n_info = MetaInfo(node, mod_dir=self.mod_dir, to_recompute=tuple(self.ckpt_regions))
return node
def trace(self,
root: torch.nn.Module,
concrete_args: Optional[Dict[str, torch.Tensor]] = {},
meta_args: Optional[Dict[str, torch.Tensor]] = {}) -> Graph:
# check concrete and meta args have valid names
sig = inspect.signature(root.forward)
sig_names = set(sig.parameters.keys())
meta_arg_names = set(meta_args.keys())
concrete_arg_names = set(concrete_args.keys())
        # copy so that the caller's dict (or the shared default argument) is never mutated
        concrete_args = dict(concrete_args)
        # update concrete args with default values
for k, v in sig.parameters.items():
if k in sig_names - meta_arg_names and \
k not in concrete_args and \
v.default is not inspect.Parameter.empty:
concrete_args[k] = v.default
def _check_arg_name_valid(names: Iterable[str]):
for name in names:
if name not in sig_names:
raise ValueError(f"Argument {name} is not in the signature of {root.__class__.__name__}.forward")
_check_arg_name_valid(meta_arg_names)
_check_arg_name_valid(concrete_arg_names)
self.concrete_args = concrete_args
self.meta_args = meta_args
with self._torch_factory_override(), self._tracer_override(), torch.no_grad():
self.mod_dir = 'self'
self.graph = super().trace(root, concrete_args=concrete_args)
self.mod_dir = ''
self.graph.lint()
return self.graph
@contextmanager
def _tracer_override(self):
# override the tracer to support custom modules and checkpointing
if self.trace_act_ckpt:
orig_ckpt_func_apply = torch.utils.checkpoint.CheckpointFunction.apply
orig_ckpt_func_without_reentrant = torch.utils.checkpoint._checkpoint_without_reentrant
def checkpoint(run_function, preserve_rng_state=False, *args):
self.ckpt_regions.append(self.ckpt_idx)
out = run_function(*args)
self.ckpt_idx = self.ckpt_regions.pop(-1) + 1
return out
# override the checkpoint function
torch.utils.checkpoint.CheckpointFunction.apply = checkpoint
torch.utils.checkpoint._checkpoint_without_reentrant = checkpoint
# override the custom functions
ColoProxy._func_dispatch.update({k: v for k, v in self._custom_impl.items()})
# override the bias addition functions
if self.bias_addition_split:
ColoProxy._func_dispatch.update({k: v for k, v in self._bias_addition_impl.items()})
yield
if self.trace_act_ckpt:
# recover the checkpoint function upon exit
torch.utils.checkpoint.CheckpointFunction.apply = orig_ckpt_func_apply
            torch.utils.checkpoint._checkpoint_without_reentrant = orig_ckpt_func_without_reentrant
ColoProxy._func_dispatch = {}
@contextmanager
def _torch_factory_override(self):
# override the torch factory functions to create a proxy when the method
# is called during ``symbolic_trace()``.
def wrap_factory_method(target):
@functools.wraps(target)
def wrapper(*args, **kwargs):
is_proxy = any(isinstance(p, ColoProxy) for p in args) | any(
isinstance(p, ColoProxy) for p in kwargs.values())
if is_proxy:
# if the arg is a proxy, then need to record this function called on this proxy
# e.g. torch.ones(size) where size is an input proxy
self.disable_module_getattr = True
try:
proxy = self.create_proxy('call_function', target, args, kwargs)
finally:
self.disable_module_getattr = False
return proxy
else:
return target(*args, **kwargs)
return wrapper, target
overrides = {
target: wrap_factory_method(getattr(torch, target))
for target in _TorchFactoryMethod
if callable(getattr(torch, target))
}
for name, (wrapper, orig) in overrides.items():
setattr(torch, name, wrapper)
yield
# recover the torch factory functions upon exit
for name, (wrapper, orig) in overrides.items():
setattr(torch, name, orig)
def _post_check(self, non_concrete_arg_names: Set[str]):
# This is necessary because concrete args are added as input to the traced module since
# https://github.com/pytorch/pytorch/pull/55888.
for node in self.graph.nodes:
if node.op == "placeholder":
# Removing default values for inputs as the forward pass will fail with them.
if node.target in non_concrete_arg_names:
node.args = ()
# Without this, torch.jit.script fails because the inputs type is Optional[torch.Tensor].
# It cannot infer on the attributes and methods the input should have, and fails.
node.type = torch.Tensor
# It is a concrete arg so it is not used and should be removed.
else:
if hasattr(torch.fx._symbolic_trace, "_assert_is_none"):
# Newer versions of torch.fx emit an assert statement
# for concrete arguments; delete those before we delete
# the concrete arg.
to_delete = []
for user in node.users:
if user.target == torch.fx._symbolic_trace._assert_is_none:
to_delete.append(user)
for user in to_delete:
self.graph.erase_node(user)
self.graph.erase_node(node)
if node.op == "output":
node.type = None
self.graph.lint()
def getattr(self, attr, attr_val, parameter_proxy_cache):
return self._module_getattr(attr, attr_val, parameter_proxy_cache)
def _module_getattr(self, attr, attr_val, parameter_proxy_cache):
if getattr(self, "disable_module_getattr", False):
return attr_val
def maybe_get_proxy_for_attr(attr_val, collection_to_search, parameter_proxy_cache):
for n, p in collection_to_search:
if attr_val is p:
if n not in parameter_proxy_cache:
kwargs = {}
if 'proxy_factory_fn' in inspect.signature(self.create_proxy).parameters:
kwargs['proxy_factory_fn'] = (None if not self.param_shapes_constant else
                                                          lambda node: ColoProxy(node, self, data=attr_val))
val_proxy = self.create_proxy('get_attr', n, (), {}, **kwargs) # type: ignore[arg-type]
parameter_proxy_cache[n] = val_proxy
return parameter_proxy_cache[n]
return None
if self.proxy_buffer_attributes and isinstance(attr_val, torch.Tensor):
maybe_buffer_proxy = maybe_get_proxy_for_attr(attr_val, self.root.named_buffers(), parameter_proxy_cache)
if maybe_buffer_proxy is not None:
return maybe_buffer_proxy
if isinstance(attr_val, torch.nn.Parameter):
maybe_parameter_proxy = maybe_get_proxy_for_attr(attr_val, self.root.named_parameters(),
parameter_proxy_cache)
if maybe_parameter_proxy is not None:
return maybe_parameter_proxy
return attr_val
def symbolic_trace(
root: Union[torch.nn.Module, Callable[..., Any]],
concrete_args: Optional[Dict[str, Any]] = {},
meta_args: Optional[Dict[str, Any]] = {},
trace_act_ckpt: bool = False,
bias_addition_split: bool = False,
) -> ColoGraphModule:
"""
    Trace a ``torch.nn.Module`` or a function and return a ``ColoGraphModule`` with
    ``MetaInfo`` attached to every ``Node``.
    The tracer can record the usage of ``torch.utils.checkpoint`` and the module path of each
    ``Node`` (https://github.com/pytorch/examples/blob/main/fx/module_tracer.py), and it is
    able to trace basic control flow and for loops.
    If ``bias_addition_split`` is set to ``True``, bias additions are split into separate
    ops (see ./bias_addition.py for more details).
Examples:
1. Tracing a ``torch.nn.Module`` with control flow.
.. code-block:: python
class MyModule(torch.nn.Module):
def __init__(self):
super().__init__()
self.linear = torch.nn.Linear(2, 2)
def forward(self, x):
if x.size(0) > 1:
x = x.sum(dim=0)
return self.linear(x)
traced = symbolic_trace(MyModule(), meta_args={'x': torch.randn(1, 2, 2)})
# traced code like:
# def forward(self, x):
# linear_1 = self.linear(x)
# return linear_1
traced = symbolic_trace(MyModule(), meta_args={'x': torch.randn(2, 2, 2)})
# traced code like:
# def forward(self, x):
# sum = x.sum(dim=0); x = None
# linear = self.linear(sum); sum = None
# return linear
2. Tracing a ``torch.nn.Module`` with ``torch.utils.checkpoint``.
.. code-block:: python
class MyModule(torch.nn.Module):
def __init__(self):
super().__init__()
self.linear = torch.nn.Linear(2, 2)
def forward(self, x):
def custom_forward(x):
return self.linear(x)
return torch.utils.checkpoint.checkpoint(custom_forward, x)
traced = symbolic_trace(MyModule(), meta_args={'x': torch.randn(1, 2, 2)}, trace_act_ckpt=True)
# traced code like:
# def checkpoint_0(self, x):
# linear = self.linear(x); x = None
# return linear
#
# def forward(self, x):
# linear = torch.utils.checkpoint.checkpoint(checkpoint_0, x); x = None
# return linear
3. Tracing a ``torch.nn.Module`` with ``bias_addition_split``.
.. code-block:: python
class MyModule(torch.nn.Module):
def __init__(self):
super().__init__()
self.linear = torch.nn.Linear(2, 2, bias=True)
def forward(self, x):
return self.linear(x)
traced = symbolic_trace(MyModule(), meta_args={'x': torch.randn(1, 2, 2)}, bias_addition_split=True)
# traced code like:
# def forward(self, x):
# linear_bias = self.linear.bias
# linear_weight = self.linear.weight
# linear = torch._C._nn.linear(x, linear_weight); x = linear_weight = None
# add = linear + linear_bias; linear = linear_bias = None
# return add
Args:
root (Union[torch.nn.Module, Callable[..., Any]]): The ``torch.nn.Module`` or function to be traced.
concrete_args (Optional[Dict[str, Any]], optional): Concrete arguments to be passed to the ``root``.
Defaults to {}.
meta_args (Optional[Dict[str, Any]], optional): Meta arguments to be passed to the ``root``. Mostly used
for tracing control flow. Defaults to {}.
trace_act_ckpt (bool, optional): Whether to trace the usage of ``torch.utils.checkpoint``.
Defaults to False.
bias_addition_split (bool, optional): Whether to split the bias addition into two parts. Defaults to False.
Returns:
ColoGraphModule: A traced ``GraphModule`` that is ready for activation checkpoint ``CodeGen``.
Remarks:
        This part of ``symbolic_trace()`` is maintained by the Colossal-AI team. If you encounter
        any unexpected error during tracing, feel free to raise an issue on the Colossal-AI GitHub
        repo. We welcome any feedback and contributions that enhance the extensibility of
        Colossal-AI.
"""
if meta_args:
device, orig_device = _default_device(), _current_device(root)
wrap_fn = lambda elem: MetaTensor(elem, device=device) if isinstance(elem, torch.Tensor) else elem
graph = ColoTracer(trace_act_ckpt=trace_act_ckpt,
bias_addition_split=bias_addition_split).trace(root.to(device),
concrete_args=concrete_args,
meta_args=tree_map(wrap_fn, meta_args))
if trace_act_ckpt:
graph.set_codegen(ActivationCheckpointCodeGen())
root.to(orig_device)
else:
graph = Tracer().trace(root, concrete_args=concrete_args)
name = root.__class__.__name__ if isinstance(root, torch.nn.Module) else root.__name__
return ColoGraphModule(root, graph, name)
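
Besides ``meta_args``, non-tensor arguments that steer control flow can be pinned through ``concrete_args``, as the tests in this PR do with their ``select`` switch. A minimal sketch (``Switch`` is an illustrative module, not part of this PR):

```python
import torch
import torch.nn as nn

from colossalai._analyzer.fx import symbolic_trace


class Switch(nn.Module):

    def __init__(self):
        super().__init__()
        self.a = nn.Linear(4, 4)
        self.b = nn.Linear(4, 4)

    def forward(self, x, select=0):
        if select == 0:
            return self.a(x)
        return self.b(x)


model = Switch()
x = torch.rand(2, 4)
# ``select`` is baked in as a concrete value, so only the chosen branch is traced
gm = symbolic_trace(model, meta_args={'x': x}, concrete_args={'select': 1})
print(gm.code)    # the traced forward only calls ``self.b``
assert torch.allclose(gm(x, 1), model(x, 1))
```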


@ -226,7 +226,7 @@ class ConcreteInfoProp(torch.fx.Interpreter):
Returns:
Any: The value returned from executing the Module
"""
return super().run(*args)
return self.run(*args)
def summary(self, unit: str = 'MB') -> str:
"""


@ -288,13 +288,16 @@ class MetaInfoProp(torch.fx.Interpreter):
def flops_repr(flop: int) -> str:
return f"{flop:,} FLOPs"
accumulate_size = 0
for node in self.module.graph.nodes:
node: Node
accumulate_size += calculate_fwd_out(node) + calculate_fwd_tmp(node)
node_summaries.append([
node.op,
str(node),
flops_repr(node.meta['fwd_flop']),
flops_repr(node.meta['bwd_flop']),
mem_repr(accumulate_size),
mem_repr(calculate_fwd_in(node)),
mem_repr(calculate_fwd_out(node)),
mem_repr(calculate_fwd_tmp(node)),
@ -309,6 +312,7 @@ class MetaInfoProp(torch.fx.Interpreter):
'Op',
'Forward FLOPs',
'Backward FLOPs',
'Accumulated Memory',
'FWD_IN',
'FWD_OUT',
'FWD_TMP',


@ -347,6 +347,7 @@ if version.parse(torch.__version__) >= version.parse('1.12.0'):
aten.squeeze.dim,
aten.slice.Tensor,
aten.slice_backward.default,
aten.stack.default,
aten.split.Tensor,
aten.permute.default,
aten.t.default,


@ -0,0 +1,113 @@
import pytest
import torch
from torch.utils.checkpoint import checkpoint
try:
from colossalai._analyzer.fx import symbolic_trace
except:
pass
class LinearModel(torch.nn.Module):
def __init__(self, in_features, out_features, bias):
super().__init__()
self.linear = torch.nn.Linear(in_features, out_features, bias=bias)
def forward(self, x):
x = self.linear(x)
return x
class ConvModel(torch.nn.Module):
def __init__(self, in_channel, out_channels, kernel_size, bias) -> None:
super().__init__()
self.conv = torch.nn.Conv2d(in_channel,
out_channels,
kernel_size,
bias=bias,
padding=1,
stride=2,
dilation=2,
groups=3)
self.conv_transpose = torch.nn.ConvTranspose2d(in_channel,
out_channels,
kernel_size,
bias=bias,
padding=1,
stride=2,
dilation=2,
groups=3)
def forward(self, x, select=0):
if select == 0:
x = self.conv(x)
else:
x = self.conv_transpose(x)
return x
class SiuModel(torch.nn.Module):
def __init__(self, bias) -> None:
super().__init__()
self.linear = LinearModel(3, 3, bias)
self.conv = ConvModel(3, 6, 3, bias)
def forward(self, x, select=0):
x = self.linear(x)
x = checkpoint(self.conv, x, select)
return x
class AddmmModel(torch.nn.Module):
def __init__(self, alpha, beta) -> None:
super().__init__()
self.alpha = alpha
self.beta = beta
def forward(self, x):
x = torch.addmm(x, x, x, alpha=self.alpha, beta=self.beta)
return x
@pytest.mark.skipif(torch.__version__ < '1.12.0', reason='torch version < 12')
@pytest.mark.parametrize("bias", [True, False])
@pytest.mark.parametrize("bias_addition_split", [True, False])
@pytest.mark.parametrize("shape", [(3, 3, 3), (3, 3, 3, 3)])
@pytest.mark.parametrize("select", [0, 1])
def test_siu_model(bias, bias_addition_split, shape, select):
model = SiuModel(bias=bias)
x = torch.rand(shape)
gm = symbolic_trace(model,
meta_args={'x': x},
concrete_args={'select': select},
trace_act_ckpt=True,
bias_addition_split=bias_addition_split)
assert torch.allclose(model(x, select), gm(x, select)), 'original model and traced model should be the same!'
if bias and bias_addition_split:
assert '+' in gm.code, 'bias addition should be split!'
else:
assert '+' not in gm.code, 'bias addition should not be split!'
@pytest.mark.skipif(torch.__version__ < '1.12.0', reason='torch version < 12')
@pytest.mark.parametrize("alpha", [1, 2])
@pytest.mark.parametrize("beta", [1, 2])
@pytest.mark.parametrize("bias_addition_split", [True, False])
@pytest.mark.parametrize("shape", [(3, 3), (5, 5)])
def test_addmm_model(alpha, beta, bias_addition_split, shape):
model = AddmmModel(alpha=alpha, beta=beta)
x = torch.rand(shape)
gm = symbolic_trace(model, meta_args={'x': x}, trace_act_ckpt=True, bias_addition_split=bias_addition_split)
assert torch.allclose(model(x), gm(x)), 'original model and traced model should be the same!'
if (alpha == 1 and beta == 1) or not bias_addition_split:
assert '*' not in gm.code, 'bias addition should not be split!'
elif bias_addition_split:
assert '+' in gm.code, 'bias addition should be split!'
if __name__ == '__main__':
    test_siu_model(True, True, (3, 3, 3), 0)


@ -0,0 +1,78 @@
import pytest
import torch
try:
from colossalai._analyzer.fx import symbolic_trace
except:
pass
class LinearModel(torch.nn.Module):
def __init__(self, in_features, out_features, bias):
super().__init__()
self.linear = torch.nn.Linear(in_features, out_features, bias=bias)
def forward(self, x):
x = self.linear(x)
return x
class ConvModel(torch.nn.Module):
def __init__(self, in_channel, out_channels, kernel_size, bias) -> None:
super().__init__()
self.conv = torch.nn.Conv2d(in_channel,
out_channels,
kernel_size,
bias=bias,
padding=1,
stride=2,
dilation=2,
groups=3)
self.conv_transpose = torch.nn.ConvTranspose2d(out_channels,
out_channels,
kernel_size,
bias=bias,
padding=1,
stride=2,
dilation=2,
groups=3)
def forward(self, x):
x = self.conv(x)
x = self.conv_transpose(x)
return x
class AModel(torch.nn.Module):
def __init__(self, bias) -> None:
super().__init__()
self.linear_1 = LinearModel(3, 3, bias)
self.linear_2 = LinearModel(3, 3, bias)
self.conv = ConvModel(3, 6, 3, bias)
def forward(self, x):
for i in range(x.shape[0]):
x = self.linear_1(x)
x = self.linear_2(x)
x = self.conv(x)
return x
@pytest.mark.skipif(torch.__version__ < '1.12.0', reason='torch version < 12')
@pytest.mark.parametrize("bias", [True, False])
@pytest.mark.parametrize("bias_addition_split", [True, False])
@pytest.mark.parametrize("shape", [(3, 3, 3), (3, 3, 3, 3)])
def test_mod_dir(bias, bias_addition_split, shape):
model = AModel(bias=bias)
x = torch.rand(shape)
gm = symbolic_trace(model, meta_args={'x': x}, bias_addition_split=bias_addition_split)
for node in gm.graph.nodes:
assert len(node.meta['info'].mod_dir), f"{node} should have non-trivial ``mod_dir``."
print(node, node.meta['info'].mod_dir)
if __name__ == '__main__':
test_mod_dir(True, True, (3, 3, 3))


@ -0,0 +1,55 @@
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint
import pytest
try:
from colossalai._analyzer.fx import symbolic_trace
except:
pass
class MyModule(nn.Module):
def __init__(self):
super().__init__()
self.a = nn.Linear(10, 10)
self.b = nn.Linear(10, 10)
self.c = nn.Linear(10, 10)
self.d = nn.Linear(10, 10)
self.e = nn.Linear(10, 10)
def checkpoint_0(self, x):
return checkpoint(self.checkpoint_0_0, x) + checkpoint(self.checkpoint_0_1, x) + self.e(x)
def checkpoint_0_0(self, x):
return checkpoint(self.checkpoint_0_0_0, x) + checkpoint(self.checkpoint_0_0_1, x)
def checkpoint_0_0_0(self, x):
return self.a(x) + checkpoint(self.checkpoint_0_0_0_0, x, use_reentrant=False)
def checkpoint_0_0_0_0(self, x):
return self.b(x)
def checkpoint_0_0_1(self, x):
return self.b(x) + self.c(x)
def checkpoint_0_1(self, x):
return self.d(x)
def forward(self, x):
return checkpoint(self.checkpoint_0, x)
@pytest.mark.skipif(torch.__version__ < '1.12.0', reason='torch version < 12')
def test_nested_ckpt():
model = MyModule()
x = torch.rand(10, 10)
gm = symbolic_trace(model, meta_args={'x': x}, trace_act_ckpt=True)
assert torch.allclose(gm(x), model(x)), "The traced model should generate the same output as the original model."
for ckpt_def in filter(lambda s: s.startswith('checkpoint'), dir(model)):
assert ckpt_def in gm.code, f"Checkpoint {ckpt_def} should be in the traced code.\n Traced code = {gm.code}"
if __name__ == "__main__":
test_nested_ckpt()


@ -0,0 +1,63 @@
import pytest
import timm.models as tmm
import torch
import torchvision.models as tm
from .zoo import tm_models, tmm_models
try:
from colossalai._analyzer._subclasses import MetaTensorMode
from colossalai._analyzer.fx import symbolic_trace
from colossalai._analyzer.fx.passes.shape_prop import shape_prop_pass
from colossalai._analyzer.fx.symbolic_profile import register_shape_impl
@register_shape_impl(torch.nn.functional.linear)
def linear_impl(*args, **kwargs):
assert True
return torch.nn.functional.linear(*args, **kwargs)
except:
pass
def _check_gm_validity(gm: torch.fx.GraphModule):
for node in gm.graph.nodes:
assert node.meta['info'].outputs, f'In {gm.__class__.__name__}, {node} has no output shape.'
if node.op in [
# 'call_module', # can apply to params
# 'call_function', # can apply to params
# 'call_method', # can apply to params
]:
assert node.meta['info'].inputs, f'In {gm.__class__.__name__}, {node} has no input shape.'
@pytest.mark.skipif(torch.__version__ < '1.12.0', reason='torch version < 12')
@pytest.mark.parametrize('m', tm_models)
def test_torchvision_shape_prop(m):
with MetaTensorMode():
model = m()
data = torch.rand(100, 3, 224, 224)
meta_args = {
"x": data,
}
gm = symbolic_trace(model, meta_args=meta_args)
shape_prop_pass(gm, data)
_check_gm_validity(gm)
@pytest.mark.skipif(torch.__version__ < '1.12.0', reason='torch version < 12')
@pytest.mark.parametrize('m', tmm_models)
def test_timm_shape_prop(m):
with MetaTensorMode():
model = m()
data = torch.rand(100, 3, 224, 224)
meta_args = {
"x": data,
}
gm = symbolic_trace(model, meta_args=meta_args)
shape_prop_pass(gm, data)
_check_gm_validity(gm)
if __name__ == "__main__":
test_torchvision_shape_prop(tm.resnet18)
test_timm_shape_prop(tmm.vgg11)


@ -0,0 +1,49 @@
import pytest
import timm.models as tmm
import torch
import torchvision.models as tm
from .zoo import tm_models, tmm_models
try:
from colossalai._analyzer._subclasses import MetaTensorMode
from colossalai._analyzer.fx import symbolic_profile, symbolic_trace
except:
pass
def _check_gm_validity(gm: torch.fx.GraphModule):
for node in gm.graph.nodes:
assert len(node.meta['info'].global_ctx), f'In {gm.__class__.__name__}, {node} has empty global context.'
@pytest.mark.skipif(torch.__version__ < '1.12.0', reason='torch version < 12')
@pytest.mark.parametrize('m', tm_models)
def test_torchvision_profile(m, verbose=False, bias_addition_split=False):
with MetaTensorMode():
model = m()
data = torch.rand(8, 3, 224, 224)
meta_args = {
"x": data,
}
gm = symbolic_trace(model, meta_args=meta_args, bias_addition_split=bias_addition_split)
symbolic_profile(gm, data, verbose=verbose)
_check_gm_validity(gm)
@pytest.mark.skipif(torch.__version__ < '1.12.0', reason='torch version < 12')
@pytest.mark.parametrize('m', tmm_models)
def test_timm_profile(m, verbose=False, bias_addition_split=False):
with MetaTensorMode():
model = m()
data = torch.rand(8, 3, 224, 224)
meta_args = {
"x": data,
}
gm = symbolic_trace(model, meta_args=meta_args, bias_addition_split=bias_addition_split)
symbolic_profile(gm, data, verbose=verbose)
_check_gm_validity(gm)
if __name__ == "__main__":
test_torchvision_profile(tm.vit_b_16, verbose=True, bias_addition_split=False)
test_timm_profile(tmm.gmlp_b16_224, verbose=True, bias_addition_split=False)


@ -0,0 +1,53 @@
import timm.models as tmm
import torchvision.models as tm
# input shape: (batch_size, 3, 224, 224)
tm_models = [
tm.alexnet,
tm.convnext_base,
tm.densenet121,
# tm.efficientnet_v2_s,
# tm.googlenet, # output bad case
# tm.inception_v3, # bad case
tm.mobilenet_v2,
tm.mobilenet_v3_small,
tm.mnasnet0_5,
tm.resnet18,
tm.regnet_x_16gf,
tm.resnext50_32x4d,
tm.shufflenet_v2_x0_5,
tm.squeezenet1_0,
# tm.swin_s, # fx bad case
tm.vgg11,
tm.vit_b_16,
tm.wide_resnet50_2,
]
tmm_models = [
tmm.beit_base_patch16_224,
tmm.beitv2_base_patch16_224,
tmm.cait_s24_224,
tmm.coat_lite_mini,
tmm.convit_base,
tmm.deit3_base_patch16_224,
tmm.dm_nfnet_f0,
tmm.eca_nfnet_l0,
tmm.efficientformer_l1,
tmm.ese_vovnet19b_dw,
tmm.gmixer_12_224,
tmm.gmlp_b16_224,
tmm.hardcorenas_a,
tmm.hrnet_w18_small,
tmm.inception_v3,
tmm.mixer_b16_224,
tmm.nf_ecaresnet101,
tmm.nf_regnet_b0,
# tmm.pit_b_224, # pretrained only
tmm.regnetv_040,
tmm.skresnet18,
# tmm.swin_base_patch4_window7_224, # fx bad case
# tmm.tnt_b_patch16_224, # bad case
tmm.vgg11,
tmm.vit_base_patch16_18x2_224,
tmm.wide_resnet50_2,
]


@ -0,0 +1,82 @@
from typing import Any, Callable, Union
import pytest
import torch
import torch.nn as nn
try:
from colossalai._analyzer._subclasses import MetaTensor
except:
pass
aten = torch.ops.aten
registered_meta = {
('aten.convolution.default', True): [ # (aten ops, requires_backward)
(nn.Conv1d(in_channels=3, out_channels=4, kernel_size=2, padding=1, dilation=2), torch.rand(2, 3, 4)),
(nn.Conv2d(in_channels=3, out_channels=4, kernel_size=2, padding=1, dilation=2), torch.rand(2, 3, 4, 4)),
(nn.Conv3d(in_channels=3, out_channels=4, kernel_size=2, padding=1, dilation=2), torch.rand(2, 3, 4, 4, 4)),
(nn.ConvTranspose1d(in_channels=3, out_channels=4, kernel_size=2, padding=1, dilation=2), torch.rand(2, 3, 4)),
(nn.ConvTranspose2d(in_channels=3, out_channels=4, kernel_size=2, padding=1,
dilation=2), torch.rand(2, 3, 4, 4)),
(nn.ConvTranspose3d(in_channels=3, out_channels=4, kernel_size=2, padding=1,
dilation=2), torch.rand(2, 3, 4, 4, 4)),
],
('aten.native_batch_norm.default', True): [
(nn.BatchNorm1d(4), torch.rand(2, 4)),
(nn.BatchNorm2d(4), torch.rand(1, 4, 4, 4)),
(nn.BatchNorm3d(4), torch.rand(1, 4, 4, 4, 4)),
],
('aten.native_layer_norm.default', True): [(nn.LayerNorm(4), torch.rand(1, 2, 3, 4)),],
('aten.avg_pool1d.default', True): [
(nn.MaxPool1d(3, stride=2), torch.rand(4, 5, 5)),
(nn.AvgPool1d(3, stride=2), torch.rand(4, 5, 5)),
(nn.AdaptiveMaxPool1d(3), torch.rand(4, 5, 5)),
(nn.AdaptiveAvgPool1d(3), torch.rand(4, 5, 5)),
],
('aten.avg_pool2d.default', True): [
(nn.MaxPool2d((3, 2), stride=(2, 1)), torch.rand(2, 4, 5, 5)),
(nn.AvgPool2d((3, 2), stride=(2, 1)), torch.rand(2, 4, 5, 5)),
(nn.AdaptiveMaxPool2d((3, 2)), torch.rand(2, 4, 5, 5)),
(nn.AdaptiveAvgPool2d((3, 2)), torch.rand(2, 4, 5, 5)),
],
('aten.relu.default', True): [
(nn.ReLU(), torch.rand(4, 3, 1, 2)),
(nn.LeakyReLU(), torch.rand(4, 3, 1, 2)),
(nn.SiLU(), torch.rand(4, 3, 1, 2)),
(nn.GELU(), torch.rand(4, 3, 1, 2)),
(nn.ELU(), torch.rand(4, 3, 1, 2)),
(nn.Sigmoid(), torch.rand(4, 3, 1, 2)),
(nn.Tanh(), torch.rand(4, 3, 1, 2)),
(nn.Hardswish(), torch.rand(4, 3, 1, 2)),
]
}
def compare_all(tensor: torch.Tensor, meta_tensor: torch.Tensor) -> Any:
assert tensor.shape == meta_tensor.shape, f'the shape of tensor ({tensor.shape}) and meta tensor ({meta_tensor.shape}) does not match.'
assert tensor.dtype == meta_tensor.dtype, f'the dtype of tensor ({tensor.dtype}) and meta tensor ({meta_tensor.dtype}) does not match.'
assert tensor.stride() == meta_tensor.stride(
), f'the stride of tensor ({tensor.stride()}) and meta tensor ({meta_tensor.stride()}) does not match.'
def run_and_compare(f: Union[nn.Module, Callable], x: torch.Tensor, requires_backward=False) -> Any:
x.requires_grad = requires_backward
meta_x = MetaTensor(x)
x_out, meta_out = f(x), f(meta_x)
compare_all(x_out, meta_out)
if requires_backward:
x_out.sum().backward()
meta_out.sum().backward()
compare_all(x.grad, meta_x.grad)
@pytest.mark.skipif(torch.__version__ < '1.12.0', reason='torch version < 12')
def test_meta_aten():
for (aten_op, requires_backward), v in registered_meta.items():
for f, x in v:
run_and_compare(f, x, requires_backward)
if __name__ == '__main__':
test_meta_aten()


@ -0,0 +1,50 @@
import pytest
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as tm
from .zoo import tm_models, tmm_models
try:
from colossalai._analyzer._subclasses import MetaTensorMode, flop_count
except:
pass
@pytest.mark.skipif(torch.__version__ < '1.12.0', reason='torch version < 12')
@pytest.mark.parametrize('m', tm_models + tmm_models)
def test_flop_count_module(m):
x = torch.rand(2, 3, 224, 224)
with MetaTensorMode(): # save time for testing
module = m()
rs_fwd, rs_bwd = flop_count(module, x, verbose=True)
assert rs_fwd > 0, f'fwd flop count of {m.__name__} is {rs_fwd}'
assert rs_bwd > 0, f'bwd flop count of {m.__name__} is {rs_bwd}'
odd_cases = [
(F.relu, (torch.rand(2, 3, 224, 224, requires_grad=True),), {
'inplace': True
}),
(F.max_pool2d, (torch.rand(2, 3, 224, 224, requires_grad=True),), {
'kernel_size': 3,
'stride': 2,
'padding': 1,
'dilation': 2
}),
(torch.where, (torch.rand(2, 3, 224, 224) > 0.5, torch.rand(2, 3, 224, 224, requires_grad=True),
torch.rand(2, 3, 224, 224, requires_grad=True)), {}),
]
@pytest.mark.skipif(torch.__version__ < '1.12.0', reason='torch version < 12')
@pytest.mark.parametrize('func, args, kwargs', odd_cases)
def test_flop_count_function(func, args, kwargs):
rs_fwd, rs_bwd = flop_count(func, *args, **kwargs, verbose=True)
assert rs_fwd > 0, f'fwd flop count of {func.__name__} is {rs_fwd}'
assert rs_bwd > 0, f'bwd flop count of {func.__name__} is {rs_bwd}'
if __name__ == '__main__':
    test_flop_count_module(tm.resnet18)
test_flop_count_function(F.relu, (torch.rand(2, 3, 224, 224, requires_grad=True),), {'inplace': True})


@ -0,0 +1,38 @@
import pytest
import torch
import torch.distributed as dist
import torchvision.models as tm
try:
from colossalai._analyzer._subclasses import MetaTensor, MetaTensorMode
except:
pass
from .zoo import tm_models, tmm_models
def compare_all(tensor: torch.Tensor, meta_tensor: torch.Tensor):
assert tensor.shape == meta_tensor.shape, f'the shape of tensor ({tensor.shape}) and meta tensor ({meta_tensor.shape}) does not match.'
assert tensor.dtype == meta_tensor.dtype, f'the dtype of tensor ({tensor.dtype}) and meta tensor ({meta_tensor.dtype}) does not match.'
assert tensor.stride() == meta_tensor.stride(
), f'the stride of tensor ({tensor.stride()}) and meta tensor ({meta_tensor.stride()}) does not match.'
def run_and_compare(model):
x = torch.rand(2, 3, 224, 224, requires_grad=True)
x_out = model(x)
with MetaTensorMode():
meta_x = torch.rand(2, 3, 224, 224, requires_grad=True)
meta_out = model(meta_x)
compare_all(x_out, meta_out)
x_out.sum().backward()
meta_out.sum().backward()
compare_all(x.grad, meta_x.grad)
@pytest.mark.skipif(torch.__version__ < '1.12.0', reason='torch version < 12')
@pytest.mark.parametrize('m', tm_models + tmm_models)
def test_meta_mode_shape(m):
run_and_compare(m())
if __name__ == '__main__':
test_meta_mode_shape(tm.resnet18)


@ -0,0 +1,53 @@
import timm.models as tmm
import torchvision.models as tm
# input shape: (batch_size, 3, 224, 224)
tm_models = [
tm.alexnet,
tm.convnext_base,
tm.densenet121,
# tm.efficientnet_v2_s,
# tm.googlenet, # output bad case
# tm.inception_v3, # bad case
tm.mobilenet_v2,
tm.mobilenet_v3_small,
tm.mnasnet0_5,
tm.resnet18,
tm.regnet_x_16gf,
tm.resnext50_32x4d,
tm.shufflenet_v2_x0_5,
tm.squeezenet1_0,
# tm.swin_s, # fx bad case
tm.vgg11,
tm.vit_b_16,
tm.wide_resnet50_2,
]
tmm_models = [
tmm.beit_base_patch16_224,
tmm.beitv2_base_patch16_224,
tmm.cait_s24_224,
tmm.coat_lite_mini,
tmm.convit_base,
tmm.deit3_base_patch16_224,
tmm.dm_nfnet_f0,
tmm.eca_nfnet_l0,
tmm.efficientformer_l1,
tmm.ese_vovnet19b_dw,
tmm.gmixer_12_224,
tmm.gmlp_b16_224,
tmm.hardcorenas_a,
tmm.hrnet_w18_small,
tmm.inception_v3,
tmm.mixer_b16_224,
tmm.nf_ecaresnet101,
tmm.nf_regnet_b0,
# tmm.pit_b_224, # pretrained only
tmm.regnetv_040,
tmm.skresnet18,
# tmm.swin_base_patch4_window7_224, # fx bad case
# tmm.tnt_b_patch16_224, # bad case
tmm.vgg11,
tmm.vit_base_patch16_18x2_224,
tmm.wide_resnet50_2,
]