diff --git a/colossalai/__init__.py b/colossalai/__init__.py index 79ae1ba16..e00d783d5 100644 --- a/colossalai/__init__.py +++ b/colossalai/__init__.py @@ -1,10 +1,3 @@ -try: - from . import _meta_registrations - META_COMPATIBILITY = True -except: - import torch - META_COMPATIBILITY = False - print(f'_meta_registrations seems to be incompatible with PyTorch {torch.__version__}.') from .initialize import (initialize, launch, launch_from_openmpi, launch_from_slurm, launch_from_torch, get_default_parser) diff --git a/colossalai/fx/__init__.py b/colossalai/fx/__init__.py index b1850798c..5693f3eac 100644 --- a/colossalai/fx/__init__.py +++ b/colossalai/fx/__init__.py @@ -1,3 +1,4 @@ -from .tracer import ColoTracer, meta_trace +from ._compatibility import compatibility, is_compatible_with_meta from .graph_module import ColoGraphModule from .passes import MetaInfoProp +from .tracer import ColoTracer, meta_trace diff --git a/colossalai/fx/_compatibility.py b/colossalai/fx/_compatibility.py new file mode 100644 index 000000000..126403270 --- /dev/null +++ b/colossalai/fx/_compatibility.py @@ -0,0 +1,46 @@ +from typing import Callable + +import torch + +try: + from . import _meta_registrations + META_COMPATIBILITY = True +except: + META_COMPATIBILITY = False + + +def compatibility(is_backward_compatible: bool = False) -> Callable: + """A decorator that marks a function's compatibility with the installed PyTorch version. If meta-tensor support is unavailable and the function is not backward compatible, calling it raises a RuntimeError. + + Args: + is_backward_compatible (bool, optional): Whether the function is backward compatible. Defaults to False. + + Returns: + Callable: The decorated function + """ + + def decorator(func): + if META_COMPATIBILITY: + return func + else: + if is_backward_compatible: + return func + else: + + def wrapper(*args, **kwargs): + raise RuntimeError(f'Function `{func.__name__}` is not compatible with PyTorch {torch.__version__}') + + return wrapper + + return decorator + + +def is_compatible_with_meta() -> bool: + """Check the meta compatibility. Normally it should be called before importing some of the `colossalai.fx` + modules. If the meta compatibility is not satisfied, the `colossalai.fx` modules will be replaced by their + experimental counterparts.
+ + Returns: + bool: The meta compatibility + """ + return META_COMPATIBILITY diff --git a/colossalai/_meta_registrations.py b/colossalai/fx/_meta_registrations.py similarity index 91% rename from colossalai/_meta_registrations.py rename to colossalai/fx/_meta_registrations.py index 4e58c61c4..94387fbe0 100644 --- a/colossalai/_meta_registrations.py +++ b/colossalai/fx/_meta_registrations.py @@ -1,7 +1,10 @@ # meta patch from https://github.com/pytorch/pytorch/blob/master/torch/_meta_registrations.py # should be activated for PyTorch version 1.12.0 and below +# refer to https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/native_functions.yaml +# for more meta_registrations from typing import List, Optional, Tuple, Union + import torch from torch.utils._pytree import tree_map @@ -31,6 +34,7 @@ def register_meta(op, register_dispatcher=True): return wrapper +# ============================== Convolutions ====================================== # https://github.com/pytorch/pytorch/pull/79834 @register_meta(aten.convolution.default) def meta_conv( @@ -165,6 +169,18 @@ def meta_conv_backward(grad_output: torch.Tensor, input: torch.Tensor, weight: t return torch.empty_like(input), torch.empty_like(weight), torch.empty((bias_sizes), device='meta') +# https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/AdaptiveAveragePooling.cpp +@register_meta(aten._adaptive_avg_pool2d_backward.default) +def meta_adaptive_avg_pool2d_backward( + grad_output: torch.Tensor, + input: torch.Tensor, +): + grad_input = torch.empty_like(input) + return grad_input + + +# https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/Activation.cpp +# ============================== Activations ======================================= @register_meta(aten.relu.default) def meta_relu(input: torch.Tensor): return torch.empty_like(input) @@ -192,11 +208,8 @@ def meta_hardtanh_backward(grad_out: torch.Tensor, input: torch.Tensor, min_val: return grad_in -@register_meta(aten.roll.default) -def meta_roll(input: torch.Tensor, shifts, dims): - return input - - +# ============================== Normalization ===================================== +# https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/cudnn/BatchNorm.cpp @register_meta(aten.native_batch_norm.default) def meta_bn(input: torch.Tensor, weight, bias, running_mean, running_var, training, momentum, eps): n_input = input.size(1) @@ -207,6 +220,7 @@ def meta_bn(input: torch.Tensor, weight, bias, running_mean, running_var, traini return output, running_mean, running_var +# https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/cudnn/BatchNorm.cpp @register_meta(aten.native_batch_norm_backward.default) def meta_bn_backward(dY: torch.Tensor, input: torch.Tensor, weight: torch.Tensor, running_mean, running_var, save_mean, save_invstd, train, eps, output_mask): @@ -241,6 +255,7 @@ def meta_cudnn_bn_backward(dY: torch.Tensor, input: torch.Tensor, weight: torch. 
return dX, dgamma, dbeta +# https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/layer_norm.cpp @register_meta(aten.native_layer_norm.default) def meta_ln(input: torch.Tensor, normalized_shape, weight, bias, eps): bs = input.size(0) @@ -252,6 +267,7 @@ def meta_ln(input: torch.Tensor, normalized_shape, weight, bias, eps): return output, running_mean, running_var +# https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/layer_norm.cpp @register_meta(aten.native_layer_norm_backward.default) def meta_ln_backward(dY: torch.Tensor, input: torch.Tensor, normalized_shape, mean, rstd, weight, bias, grad_input_mask): @@ -261,13 +277,18 @@ def meta_ln_backward(dY: torch.Tensor, input: torch.Tensor, normalized_shape, me return dX, dgamma, dbeta -@register_meta(aten._adaptive_avg_pool2d_backward.default) -def meta_adaptive_avg_pool2d_backward( - grad_output: torch.Tensor, - input: torch.Tensor, -): - grad_input = torch.empty_like(input) - return torch.empty_like(input) +# ================================== Misc ========================================== +#https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/native_functions.yaml +@register_meta(aten.roll.default) +def meta_roll(input: torch.Tensor, shifts, dims): + return input + + +# https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/TensorCompare.cpp +@register_meta(aten.where.self) +def meta_where_self(condition: torch.Tensor, self: torch.Tensor, other: torch.Tensor): + result_type = torch.result_type(self, other) + return torch.empty_like(self, dtype=result_type) @register_meta(aten.index.Tensor) @@ -360,6 +381,8 @@ def meta_index_Tensor(self, indices): return self.new_empty(before_shape + replacement_shape + after_shape) +# ============================== Embedding ========================================= +# https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/Embedding.cpp @register_meta(aten.embedding_dense_backward.default) def meta_embedding_dense_backward(grad_output: torch.Tensor, indices: torch.Tensor, num_weights, padding_idx, scale_grad_by_freq): @@ -369,13 +392,7 @@ def meta_embedding_dense_backward(grad_output: torch.Tensor, indices: torch.Tens layout=grad_output.layout) -# https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/TensorCompare.cpp -@register_meta(aten.where.self) -def meta_where_self(condition: torch.Tensor, self: torch.Tensor, other: torch.Tensor): - result_type = torch.result_type(self, other) - return torch.empty_like(self, dtype=result_type) - - +# ============================== Dropout =========================================== # https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/Dropout.cpp @register_meta(aten.native_dropout.default) def meta_native_dropout_default(input: torch.Tensor, p: float, train: bool = False): diff --git a/colossalai/fx/passes/algorithms/ckpt_solver_pofo.py b/colossalai/fx/passes/algorithms/ckpt_solver_pofo.py index 841dd19a1..69e4e9f2c 100644 --- a/colossalai/fx/passes/algorithms/ckpt_solver_pofo.py +++ b/colossalai/fx/passes/algorithms/ckpt_solver_pofo.py @@ -1,16 +1,20 @@ -from typing import List, Tuple import copy -import torch -from torch.fx import GraphModule, Node -from colossalai.fx.graph_module import ColoGraphModule -from colossalai.fx.profiler import parameter_size import math -from .linearize import linearize -from .operation import ForwardCheck, ForwardEnable, ForwardNograd, Backward, Loss, Chain, Sequence, Function, Offload, Prefetch +from typing import List, Tuple + 
+import torch +from colossalai.fx import is_compatible_with_meta +from colossalai.fx.codegen.activation_checkpoint_codegen import \ + _find_nested_ckpt_regions +from colossalai.fx.graph_module import ColoGraphModule +from colossalai.fx.passes.algorithms.ckpt_solver_rotor import (_compute_table, _construct_chain, _rec) from colossalai.fx.passes.meta_info_prop import MetaInfoProp -from colossalai.fx.codegen.activation_checkpoint_codegen import _find_nested_ckpt_regions -from colossalai.fx.passes.algorithms.ckpt_solver_rotor import _construct_chain, _compute_table, _rec -from colossalai import META_COMPATIBILITY +from colossalai.fx.profiler import parameter_size +from torch.fx import GraphModule, Node + +from .linearize import linearize +from .operation import (Backward, Chain, ForwardCheck, ForwardEnable, ForwardNograd, Function, Loss, Offload, Prefetch, + Sequence) INF = float("inf") @@ -508,7 +512,7 @@ def solver_pofo(gm: ColoGraphModule, mem_limit -= parameter_size(gm) # prepare data - if META_COMPATIBILITY: + if is_compatible_with_meta(): from colossalai.fx.profiler import MetaTensor data = MetaTensor(data, fake_device=next(gm.parameters()).device) MetaInfoProp(gm).run(data) diff --git a/colossalai/fx/passes/concrete_info_prop.py b/colossalai/fx/passes/concrete_info_prop.py index 44dea6fc4..191d8d67d 100644 --- a/colossalai/fx/passes/concrete_info_prop.py +++ b/colossalai/fx/passes/concrete_info_prop.py @@ -1,13 +1,12 @@ from dataclasses import asdict -from colossalai.fx.profiler import GraphInfo +from typing import Any, Dict, List, NamedTuple, Optional, Tuple + import torch import torch.fx -from torch.fx.node import Node, Argument, Target +from colossalai.fx._compatibility import compatibility +from colossalai.fx.profiler import (GraphInfo, profile_function, profile_method, profile_module) +from torch.fx.node import Argument, Node, Target from torch.utils._pytree import tree_flatten -from typing import Any, List, Tuple, NamedTuple, Dict, Optional -from torch.fx._compatibility import compatibility -from colossalai.fx.profiler import profile_function, profile_module, profile_method, activation_size -from torch.fx.graph_module import GraphModule @compatibility(is_backward_compatible=True) diff --git a/colossalai/fx/passes/meta_info_prop.py b/colossalai/fx/passes/meta_info_prop.py index e7435fa4e..4fab5d041 100644 --- a/colossalai/fx/passes/meta_info_prop.py +++ b/colossalai/fx/passes/meta_info_prop.py @@ -1,11 +1,13 @@ from dataclasses import asdict +from typing import Any, Dict, List, NamedTuple, Tuple + import torch import torch.fx -from torch.fx.node import Node, Argument, Target +from colossalai.fx._compatibility import compatibility +from colossalai.fx.profiler import (GraphInfo, activation_size, calculate_fwd_in, calculate_fwd_out, calculate_fwd_tmp, + profile_function, profile_method, profile_module) +from torch.fx.node import Argument, Node, Target from torch.utils._pytree import tree_map -from typing import Any, List, Tuple, NamedTuple, Dict -from torch.fx._compatibility import compatibility -from colossalai.fx.profiler import GraphInfo, profile_function, profile_module, profile_method, activation_size, calculate_fwd_out, calculate_fwd_tmp, calculate_fwd_in @compatibility(is_backward_compatible=True) diff --git a/colossalai/fx/profiler/__init__.py b/colossalai/fx/profiler/__init__.py index fc02e0c46..b520ff124 100644 --- a/colossalai/fx/profiler/__init__.py +++ b/colossalai/fx/profiler/__init__.py @@ -1,11 +1,12 @@ -from ... 
import META_COMPATIBILITY -if META_COMPATIBILITY: +from .._compatibility import is_compatible_with_meta + +if is_compatible_with_meta(): + from .memory import calculate_fwd_in, calculate_fwd_out, calculate_fwd_tmp from .opcount import flop_mapping - from .tensor import MetaTensor from .profiler import profile_function, profile_method, profile_module - from .memory import calculate_fwd_in, calculate_fwd_tmp, calculate_fwd_out + from .tensor import MetaTensor else: from .experimental import meta_profiler_function, meta_profiler_module, profile_function, profile_method, profile_module, calculate_fwd_in, calculate_fwd_tmp, calculate_fwd_out from .dataflow import GraphInfo -from .memory import parameter_size, activation_size, is_inplace +from .memory import activation_size, is_inplace, parameter_size diff --git a/colossalai/fx/profiler/constant.py b/colossalai/fx/profiler/constant.py deleted file mode 100644 index d923346fb..000000000 --- a/colossalai/fx/profiler/constant.py +++ /dev/null @@ -1,79 +0,0 @@ -import torch -from operator import add, floordiv, getitem, mul, neg, setitem, sub, pos -from . import META_COMPATIBILITY - -__all__ = [] - -if META_COMPATIBILITY: - aten = torch.ops.aten - - ALIAS_ATEN = [ - # inplace reshaping - aten.detach.default, - aten.t.default, - aten.transpose.int, - aten.view.default, - aten._unsafe_view.default, - aten._reshape_alias.default, - ] - - INPLACE_NEW = [ - aten.empty_like.default, - aten.new_empty_strided.default, - ] - - INPLACE_MATH_ATEN = [ - aten.add_.Tensor, - aten.sub_.Tensor, - aten.div_.Tensor, - aten.div_.Scalar, - aten.mul_.Tensor, - aten.bernoulli_.float, - ] - - CLONE_ATEN = [ - aten.clone.default, - ] - - __all__ += ['INPLACE_ATEN', 'INPLACE_MATH_ATEN', 'CLONE_ATEN'] - -else: - # TODO fill out the inplace ops - INPLACE_OPS = [ - add, - sub, - mul, - floordiv, - neg, - pos, - getitem, - setitem, - getattr, - torch.Tensor.cpu, - ] - - # TODO: list all call_methods that are inplace here - INPLACE_METHOD = [ - 'transpose', - 'permute', - # TODO: reshape may return a copy of the data if the data is not contiguous - 'reshape', - 'dim', - 'flatten', - 'size', - 'view', - 'unsqueeze', - 'to', - 'type', - 'flatten', - ] - - # TODO: list all call_methods that are not inplace here - NON_INPLACE_METHOD = [ - 'chunk', - 'contiguous', - 'expand', - 'mean', - 'split', - ] - __all__ += ['INPLACE_OPS', 'INPLACE_METHOD', 'NON_INPLACE_METHOD'] diff --git a/colossalai/fx/profiler/constants.py b/colossalai/fx/profiler/constants.py new file mode 100644 index 000000000..38214e219 --- /dev/null +++ b/colossalai/fx/profiler/constants.py @@ -0,0 +1,32 @@ +import torch + +__all__ = ['ALIAS_ATEN', 'INPLACE_NEW', 'INPLACE_MATH_ATEN', 'CLONE_ATEN'] + +aten = torch.ops.aten + +ALIAS_ATEN = [ + aten.detach.default, + aten.t.default, + aten.transpose.int, + aten.view.default, + aten._unsafe_view.default, + aten._reshape_alias.default, +] + +INPLACE_NEW = [ + aten.empty_like.default, + aten.new_empty_strided.default, +] + +INPLACE_MATH_ATEN = [ + aten.add_.Tensor, + aten.sub_.Tensor, + aten.div_.Tensor, + aten.div_.Scalar, + aten.mul_.Tensor, + aten.bernoulli_.float, +] + +CLONE_ATEN = [ + aten.clone.default, +] diff --git a/colossalai/fx/profiler/dataflow.py b/colossalai/fx/profiler/dataflow.py index fe870b673..f7009a84a 100644 --- a/colossalai/fx/profiler/dataflow.py +++ b/colossalai/fx/profiler/dataflow.py @@ -2,7 +2,10 @@ from dataclasses import dataclass, field from enum import Enum from functools import partial from typing import Dict, List + from torch.fx import 
Graph, Node + +from .._compatibility import compatibility from .memory import activation_size, is_inplace @@ -12,6 +15,7 @@ class Phase(Enum): PLACEHOLDER = 2 +@compatibility(is_backward_compatible=True) @dataclass class GraphInfo: """ @@ -69,6 +73,7 @@ def is_phase(n: Node, phase: Phase) -> bool: return n.meta['phase'] == phase +@compatibility(is_backward_compatible=False) def autograd_graph_analysis(graph: Graph) -> GraphInfo: """Analyze the autograd node dependencies and find out the memory usage. Basically the input graph should have all nodes marked for keyword `phase`. diff --git a/colossalai/fx/profiler/experimental/__init__.py b/colossalai/fx/profiler/experimental/__init__.py index 3dfdd2758..fbb6ff624 100644 --- a/colossalai/fx/profiler/experimental/__init__.py +++ b/colossalai/fx/profiler/experimental/__init__.py @@ -1,5 +1,5 @@ -from .registry import meta_profiler_function, meta_profiler_module -from .memory import calculate_fwd_in, calculate_fwd_tmp, calculate_fwd_out +from .memory import calculate_fwd_in, calculate_fwd_out, calculate_fwd_tmp +from .profiler import profile_function, profile_method, profile_module from .profiler_function import * from .profiler_module import * -from .profiler import profile_function, profile_method, profile_module +from .registry import meta_profiler_function, meta_profiler_module diff --git a/colossalai/fx/profiler/experimental/constants.py b/colossalai/fx/profiler/experimental/constants.py new file mode 100644 index 000000000..57ff3fd91 --- /dev/null +++ b/colossalai/fx/profiler/experimental/constants.py @@ -0,0 +1,44 @@ +from operator import add, floordiv, getitem, mul, neg, pos, setitem, sub + +import torch + +__all__ = ['INPLACE_OPS', 'INPLACE_METHOD', 'NON_INPLACE_METHOD'] + +# TODO fill out the inplace ops +INPLACE_OPS = [ + add, + sub, + mul, + floordiv, + neg, + pos, + getitem, + setitem, + getattr, + torch.Tensor.cpu, +] + +# TODO: list all call_methods that are inplace here +INPLACE_METHOD = [ + 'transpose', + 'permute', + # TODO: reshape may return a copy of the data if the data is not contiguous + 'reshape', + 'dim', + 'flatten', + 'size', + 'view', + 'unsqueeze', + 'to', + 'type', + 'flatten', +] + +# TODO: list all call_methods that are not inplace here +NON_INPLACE_METHOD = [ + 'chunk', + 'contiguous', + 'expand', + 'mean', + 'split', +] diff --git a/colossalai/fx/profiler/experimental/memory.py b/colossalai/fx/profiler/experimental/memory.py index 601c4cf36..1e53ed0bf 100644 --- a/colossalai/fx/profiler/experimental/memory.py +++ b/colossalai/fx/profiler/experimental/memory.py @@ -1,11 +1,15 @@ # for PyTorch 1.11 compatibility uses +from typing import Dict, List, Tuple, Union + import torch -from torch.fx import Node, GraphModule -from typing import Union, Dict, List, Tuple +from torch.fx import GraphModule, Node + +from ..._compatibility import compatibility __all__ = ["calculate_fwd_in", "calculate_fwd_tmp", "calculate_fwd_out"] +@compatibility(is_backward_compatible=True) def calculate_fwd_in(n: Node) -> bool: """A helper function to calculate `fwd_in` @@ -18,6 +22,7 @@ def calculate_fwd_in(n: Node) -> bool: return n.meta['save_fwd_in'] +@compatibility(is_backward_compatible=True) def calculate_fwd_tmp(n: Node) -> int: """A helper function to calculate `fwd_tmp` @@ -30,6 +35,7 @@ def calculate_fwd_tmp(n: Node) -> int: return n.meta["fwd_mem_tmp"] +@compatibility(is_backward_compatible=True) def calculate_fwd_out(n: Node) -> int: """A helper function to calculate `fwd_out` diff --git 
a/colossalai/fx/profiler/experimental/profiler.py b/colossalai/fx/profiler/experimental/profiler.py index c7c3f81dd..fbeea5128 100644 --- a/colossalai/fx/profiler/experimental/profiler.py +++ b/colossalai/fx/profiler/experimental/profiler.py @@ -1,15 +1,19 @@ from dataclasses import dataclass -from typing import Callable, Any, Dict, Tuple +from typing import Any, Callable, Dict, Tuple + import torch from torch.fx.node import Argument, Target -from . import meta_profiler_function, meta_profiler_module + +from ..._compatibility import compatibility from ..memory import activation_size -from ..constant import INPLACE_METHOD, NON_INPLACE_METHOD, INPLACE_OPS +from .constants import INPLACE_METHOD, INPLACE_OPS, NON_INPLACE_METHOD +from .registry import meta_profiler_function, meta_profiler_module __all__ = ['profile_function', 'profile_module', 'profile_method'] # this is for compatibility use +@compatibility(is_backward_compatible=True) @dataclass class GraphInfo: """ @@ -69,6 +73,7 @@ def profile_YOUR_MODULE(self: torch.nn.Module, input: torch.Tensor) -> Tuple[int """ +@compatibility(is_backward_compatible=True) def profile_function(target: 'Target') -> Callable: """ Wrap a `call_function` node or `torch.nn.functional` in order to @@ -106,6 +111,7 @@ def profile_function(target: 'Target') -> Callable: return f +@compatibility(is_backward_compatible=True) def profile_method(target: 'Target') -> Callable: """ Wrap a `call_method` node @@ -133,6 +139,7 @@ def profile_method(target: 'Target') -> Callable: return f +@compatibility(is_backward_compatible=True) def profile_module(module: torch.nn.Module) -> Callable: """ Wrap a `call_module` node or `torch.nn` in order to diff --git a/colossalai/fx/profiler/experimental/profiler_function/python_ops.py b/colossalai/fx/profiler/experimental/profiler_function/python_ops.py index 15e8aa675..1e8561206 100644 --- a/colossalai/fx/profiler/experimental/profiler_function/python_ops.py +++ b/colossalai/fx/profiler/experimental/profiler_function/python_ops.py @@ -2,7 +2,6 @@ import operator from typing import Any, Tuple import torch from ..registry import meta_profiler_function -from colossalai.fx.proxy import ColoProxy @meta_profiler_function.register(operator.getitem) diff --git a/colossalai/fx/profiler/memory.py b/colossalai/fx/profiler/memory.py index 884de33e0..1a3f127f1 100644 --- a/colossalai/fx/profiler/memory.py +++ b/colossalai/fx/profiler/memory.py @@ -1,13 +1,16 @@ +from typing import Dict, List, Tuple, Union + import torch -from torch.fx import Node, GraphModule -from typing import Union, Dict, List, Tuple -from . import META_COMPATIBILITY +from torch.fx import GraphModule, Node + +from .._compatibility import compatibility, is_compatible_with_meta __all__ = [ 'activation_size', 'parameter_size', 'is_inplace', "calculate_fwd_in", "calculate_fwd_tmp", "calculate_fwd_out" ] +@compatibility(is_backward_compatible=True) def activation_size(out: Union[torch.Tensor, Dict, List, Tuple, int]) -> int: """Calculate activation size of a node. @@ -29,6 +32,7 @@ def activation_size(out: Union[torch.Tensor, Dict, List, Tuple, int]) -> int: return act_size +@compatibility(is_backward_compatible=True) def parameter_size(mod: torch.nn.Module) -> int: """Calculate parameter size of a node. 
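# A minimal usage sketch of the two profiler helpers decorated above (values are
# illustrative only; activation_size roughly sums numel * element_size over every tensor
# in a possibly nested output, and parameter_size does the same over a module's parameters):
import torch
from colossalai.fx.profiler import activation_size, parameter_size

linear = torch.nn.Linear(4, 8)
out = linear(torch.rand(2, 4))
print(activation_size(out))      # byte size of the fp32 output, i.e. roughly 2 * 8 * 4
print(parameter_size(linear))    # byte size of weight + bias, i.e. roughly (4 * 8 + 8) * 4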
@@ -111,8 +115,8 @@ def is_inplace(n: Node): inplace = False if n.op == "call_function": inplace = n.kwargs.get("inplace", False) - if META_COMPATIBILITY: - from .constant import ALIAS_ATEN + if is_compatible_with_meta(): + from .constants import ALIAS_ATEN if n.target in ALIAS_ATEN: inplace = True elif n.op == "call_module": diff --git a/colossalai/fx/profiler/opcount.py b/colossalai/fx/profiler/opcount.py index 22298ef26..8bd972ff3 100644 --- a/colossalai/fx/profiler/opcount.py +++ b/colossalai/fx/profiler/opcount.py @@ -1,10 +1,11 @@ # adopted from https://github.com/facebookresearch/fvcore/blob/main/fvcore/nn/jit_handles.py # ideas from https://pastebin.com/AkvAyJBw -from functools import partial, reduce import operator -from typing import Callable, List, Any +from functools import partial, reduce from numbers import Number +from typing import Any, Callable, List + import torch aten = torch.ops.aten diff --git a/colossalai/fx/profiler/profiler.py b/colossalai/fx/profiler/profiler.py index 30284f64a..608cc9e4d 100644 --- a/colossalai/fx/profiler/profiler.py +++ b/colossalai/fx/profiler/profiler.py @@ -1,16 +1,19 @@ +import time from functools import partial -from typing import Callable, Any, Dict, Tuple +from typing import Any, Callable, Dict, Tuple + import torch -from torch.nn.parameter import Parameter from torch.fx import Graph, Node from torch.fx.node import Argument, Target +from torch.nn.parameter import Parameter from torch.utils._pytree import tree_map -from .dataflow import autograd_graph_analysis, is_phase, Phase, GraphInfo + +from .._compatibility import compatibility +from .constants import ALIAS_ATEN +from .dataflow import GraphInfo, Phase, autograd_graph_analysis, is_phase from .memory import activation_size, parameter_size -from .constant import ALIAS_ATEN -from .tensor import MetaTensor from .opcount import flop_mapping -import time +from .tensor import MetaTensor __all__ = ['profile_function', 'profile_module', 'profile_method'] @@ -41,6 +44,7 @@ def detach_variables(x): return x +@compatibility(is_backward_compatible=True) def _profile_concrete(target: Callable, *args, **kwargs) -> Tuple[Tuple[Any, ...], GraphInfo]: """Profile a Callable function with args and kwargs on concrete devices by https://github.com/Cypher30 To profile the actual forward memory, we first run target in the context torch.no_grad() to get @@ -140,6 +144,7 @@ def _profile_concrete(target: Callable, *args, **kwargs) -> Tuple[Tuple[Any, ... return tree_map(detach_variables, out), graphinfo +@compatibility(is_backward_compatible=False) def _profile_meta(target: Callable, *args, **kwargs) -> Tuple[Tuple[Any, ...], GraphInfo]: """ Profile a Callable function with args and kwargs on meta devices. 
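# A short sketch of what the `compatibility` guard applied throughout this file does in
# practice (the decorated function below is hypothetical, not part of the library):
from colossalai.fx._compatibility import compatibility, is_compatible_with_meta

@compatibility(is_backward_compatible=False)
def meta_only_pass(graph):
    # Callable only when colossalai.fx._meta_registrations imported successfully;
    # on older PyTorch a call raises RuntimeError naming the installed version.
    ...

if is_compatible_with_meta():
    from colossalai.fx.profiler import MetaTensor    # meta-device profiling path
# otherwise colossalai.fx.profiler falls back to its experimental implementation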
@@ -277,6 +282,7 @@ def _profile_meta(target: Callable, *args, **kwargs) -> Tuple[Tuple[Any, ...], G return tree_map(unwrap, out), graph_info +@compatibility(is_backward_compatible=True) def profile_function(target: 'Target', device: str = 'meta') -> Callable: """ Wrap a `call_function` node or `torch.nn.functional` in order to @@ -335,6 +341,7 @@ def profile_function(target: 'Target', device: str = 'meta') -> Callable: return f +@compatibility(is_backward_compatible=True) def profile_method(target: 'Target', device: str = 'meta') -> Callable: """ Wrap a `call_method` node @@ -353,6 +360,7 @@ def profile_method(target: 'Target', device: str = 'meta') -> Callable: return f +@compatibility(is_backward_compatible=True) def profile_module(module: torch.nn.Module, device: str = 'meta') -> Callable: """ Wrap a `call_module` node or `torch.nn` in order to diff --git a/colossalai/fx/profiler/tensor.py b/colossalai/fx/profiler/tensor.py index b380512a6..d4a078c2a 100644 --- a/colossalai/fx/profiler/tensor.py +++ b/colossalai/fx/profiler/tensor.py @@ -1,10 +1,13 @@ +import uuid from copy import deepcopy from typing import Optional + import torch -from torch.utils._pytree import tree_map, tree_flatten -from torch.types import _bool, _dtype, _device -import uuid -from .constant import ALIAS_ATEN +from torch.types import _bool, _device, _dtype +from torch.utils._pytree import tree_flatten, tree_map + +from .._compatibility import compatibility +from .constants import ALIAS_ATEN __all__ = ['MetaTensor'] @@ -15,6 +18,7 @@ def set_uuid(x): setattr(x, 'uuid', uuid.uuid4()) +@compatibility(is_backward_compatible=False) class MetaTensor(torch.Tensor): """ A wrapping tensor that hacks `torch.autograd` without patching more `torch.ops.aten` ops. diff --git a/colossalai/pipeline/rpc/_pipeline_base.py b/colossalai/pipeline/rpc/_pipeline_base.py index fd5b1b2d1..830e2bf2d 100644 --- a/colossalai/pipeline/rpc/_pipeline_base.py +++ b/colossalai/pipeline/rpc/_pipeline_base.py @@ -1,23 +1,19 @@ -import threading -from enum import Enum -from typing import List, Any, Tuple, Dict, Callable -from functools import partial -from abc import ABC, abstractmethod -import math import inspect +import math +import threading +from abc import ABC, abstractmethod +from enum import Enum +from functools import partial +from typing import Any, Callable, Dict, List, Tuple import torch -from torch import nn import torch.distributed.rpc as rpc -from torch.futures import Future -from torch._C._distributed_rpc import PyRRef - -from torch import autograd -from torch import optim - from colossalai.pipeline.pipeline_process_group import ppg -from colossalai.pipeline.rpc.utils import (color_debug, tensor_shape_list, get_batch_lengths, split_batch, type_detail, - pytree_map, pytree_filter, get_real_args_kwargs, use_color_debug) +from colossalai.pipeline.rpc.utils import (get_batch_lengths, get_real_args_kwargs, pytree_filter, pytree_map, + split_batch, tensor_shape_list, type_detail) +from torch import autograd, nn, optim +from torch._C._distributed_rpc import PyRRef +from torch.futures import Future class Phase(Enum): @@ -195,7 +191,6 @@ class WorkerBase(ABC): if isinstance(output, Future): output = output.wait() - # color_debug(f'rank {self.pp_rank}, output {type(output)}', 'get output', 'red') output_work_item.refcount += 1 # all consumers have been satisfied, the work_item can be released @@ -250,9 +245,6 @@ class WorkerBase(ABC): self.num_microbatches, forward_only) with self.work_list_condition_lock: self.work_list[key] = work_item - if 
use_color_debug: - color_debug(f'rank {self.pp_rank} receive data from dataloader {self._get_store_len()}', - 'data dispatch', 'magenta') self.work_list_condition_lock.notify_all() # just for last pp_rank @@ -273,9 +265,6 @@ class WorkerBase(ABC): work_item = WorkItem(self.pp_rank, Phase.BACKWARD, grad_wrt_loss, {}, output, microbatch_id, None, self.num_microbatches, False) - if use_color_debug: - color_debug(f'rank {self.pp_rank} propose backward', 'data dispatch', 'magenta') - self.work_list[key] = work_item self.work_list_condition_lock.notify_all() @@ -297,23 +286,14 @@ class WorkerBase(ABC): producer_worker_rref = self.pp_rank_to_worker_rref[producer_stage_id] subscribe_forward_futures[i] = producer_worker_rref.rpc_async().get_output_by_key(producer_output_key) - if use_color_debug: - color_debug(f'rank {self.pp_rank} get {len(subscribe_forward_futures)} futs from its producer', - 'data dispatch', 'magenta') - work_item_from_producer = WorkItem(stage_id, Phase.FORWARD, subscribe_forward_futures, {}, output, microbatch_id, None, self.num_microbatches, forward_only) - # color_debug(f'rank {self.pp_rank} get value {tensor_shape_list(args)} from fut', 'data dispatch', 'magenta') # add work_item to work_list with self.work_list_condition_lock: key = UniqueKey(microbatch_id, Phase.FORWARD) assert key not in self.work_list self.work_list[key] = work_item_from_producer - if use_color_debug: - color_debug( - f'rank_{self.pp_rank} load a new task to its work_list {key} {work_item_from_producer.phase} data: {tensor_shape_list(work_item_from_producer.args)}', - 'data dispatch', 'magenta') self.work_list_condition_lock.notify_all() def subscribe_consumer(self, microbatch_id: int): @@ -328,10 +308,6 @@ class WorkerBase(ABC): subscribe_backward_futures: List[Future] = [None] * consumer_num output = self._get_future_by_device() - if use_color_debug: - color_debug(f'rank {self.pp_rank} get {len(subscribe_backward_futures)} futs from its consumer', - 'data dispatch', 'magenta') - for i in range(consumer_num): consumer_stage_id = self.consumer_stage_ids[i] consumer_output_key = UniqueKey(microbatch_id, Phase.BACKWARD) @@ -342,17 +318,11 @@ class WorkerBase(ABC): work_item_from_consumer = WorkItem(stage_id, Phase.BACKWARD, subscribe_backward_futures, {}, output, microbatch_id, None, self.num_microbatches, False) - # color_debug(f'rank {self.pp_rank} get value {tensor_shape_list(args)} from fut', 'data dispatch', 'magenta') - # add work_item to work_list with self.work_list_condition_lock: key = UniqueKey(microbatch_id, Phase.BACKWARD) assert key not in self.work_list self.work_list[key] = work_item_from_consumer - if use_color_debug: - color_debug( - f'rank_{self.pp_rank} load a new task to its work_list {key} {work_item_from_consumer.phase} data: {tensor_shape_list(work_item_from_consumer.args)}', - 'data dispatch', 'magenta') self.work_list_condition_lock.notify_all() def _get_producer_consumer(self) -> None: @@ -406,11 +376,6 @@ class WorkerBase(ABC): is_first_stage = self.is_first_stage() is_last_stage = self.is_last_stage() - # if self.pp_rank == 0: - # print( - # f'I am rank_{self.pp_rank} microbatch_id : {microbatch_id} {phase} {self._get_store_len()} | {self.outstanding} {self.outstanding_range}' - # ) - if phase == Phase.FORWARD: # remind its consumer to get data before forward if not is_last_stage: @@ -470,8 +435,6 @@ class WorkerBase(ABC): else: consume_result = self.module_partition(*args, **kwargs) - # print(f'model{self.pp_rank + 1}(param_sum: {sum([p.sum().item() for p in 
self.module_partition.parameters()])}) input sum: {args[0].sum().item()} forward output sum: {consume_result.sum().item()}', ) - if is_last_stage and self.criterion: with self.label_lock: self.label_lock.wait_for(lambda: microbatch_id in self.microbatch_id_to_labels) @@ -539,10 +502,6 @@ class WorkerBase(ABC): pytree_map(stage_input_args, lambda x: consume_result.append(x.grad), process_types=torch.Tensor) pytree_map(stage_input_kwargs, lambda x: consume_result.append(x.grad), process_types=torch.Tensor) - # for input_node in stage_input_args: - # if isinstance(input_node, torch.Tensor): - # consume_result.append(input_node.grad) - else: raise TypeError(f"Unknown phase appears in _consume_work_item_by_phase {phase}") @@ -593,11 +552,6 @@ class WorkerBase(ABC): with self.work_list_condition_lock: work_item = self.work_list.pop(work_item_key) - if use_color_debug: - color_debug( - f'rank {self.pp_rank} get a key : {work_item_key} work_item args: {tensor_shape_list(work_item.args)} {self._get_store_len()}', - 'work loop', 'green') - with self.output_list_condition_lock: # assert work_item_key not in self.output_list self.output_list[work_item_key] = work_item @@ -605,11 +559,6 @@ class WorkerBase(ABC): consume_result = self._consume_work_item_by_phase(work_item) - if use_color_debug: - color_debug( - f'rank_{self.pp_rank} [{work_item.phase}] finish consuming, result is {tensor_shape_list(consume_result)} {self._get_store_len()} | {self.work_list.keys()} | {self.output_list.keys()}', - 'work loop', 'green') - work_item.output.set_result(consume_result) # if is last step in one batch reset context and do step diff --git a/colossalai/pipeline/rpc/_pipeline_schedule.py b/colossalai/pipeline/rpc/_pipeline_schedule.py index fb4feb26d..0ab3a3694 100644 --- a/colossalai/pipeline/rpc/_pipeline_schedule.py +++ b/colossalai/pipeline/rpc/_pipeline_schedule.py @@ -1,13 +1,12 @@ -from typing import List, Callable, Dict import threading +from typing import Callable, Dict, List import torch import torch.distributed as dist -from torch.futures import Future -from torch._C._distributed_rpc import PyRRef - -from colossalai.pipeline.rpc._pipeline_base import PipelineEngineBase, WorkerBase, UniqueKey, Phase, WorkItem from colossalai.pipeline.pipeline_process_group import ppg +from colossalai.pipeline.rpc._pipeline_base import (Phase, PipelineEngineBase, UniqueKey, WorkerBase, WorkItem) +from torch._C._distributed_rpc import PyRRef +from torch.futures import Future # Implementation of different Pipeline schedule # Worker defines the worker for each stage diff --git a/colossalai/pipeline/rpc/utils.py b/colossalai/pipeline/rpc/utils.py index 887166467..361f6faf7 100644 --- a/colossalai/pipeline/rpc/utils.py +++ b/colossalai/pipeline/rpc/utils.py @@ -1,25 +1,15 @@ -from typing import List, Any, Tuple, Dict, Callable, Type, Union +import argparse import os import warnings -import argparse +from typing import Any, Callable, Dict, List, Tuple, Type, Union import torch -import torch.multiprocessing as mp -from torch.futures import Future import torch.distributed.rpc as rpc -from torch._C._distributed_rpc import _is_current_rpc_agent_set -from colorama import Back, Style - +import torch.multiprocessing as mp from colossalai.initialize import launch from colossalai.pipeline.pipeline_process_group import ppg - -# config for debug and test -use_color_debug = False - - -def color_debug(text, prefix=' ', color='blue'): - color = color.upper() - print(getattr(Back, color), prefix, Style.RESET_ALL, text) +from 
torch._C._distributed_rpc import _is_current_rpc_agent_set +from torch.futures import Future def pytree_map(obj: Any, fn: Callable, process_types: Union[Type, Tuple[Type]] = (), map_all: bool = False) -> Any: diff --git a/tests/test_fx/test_ckpt_solvers/test_C_solver_consistency.py b/tests/test_fx/test_ckpt_solvers/test_C_solver_consistency.py index 41ed6fd8c..773cf151d 100644 --- a/tests/test_fx/test_ckpt_solvers/test_C_solver_consistency.py +++ b/tests/test_fx/test_ckpt_solvers/test_C_solver_consistency.py @@ -1,18 +1,20 @@ import copy + +import colossalai +import pytest import torch +import torch.fx import torch.multiprocessing as mp import torchvision.models as tm -import torch.fx -import colossalai -from colossalai.fx.passes.meta_info_prop import MetaInfoProp +from colossalai.core import global_context as gpc from colossalai.fx import ColoGraphModule, ColoTracer +from colossalai.fx._compatibility import is_compatible_with_meta from colossalai.fx.passes.algorithms import solver_rotor from colossalai.fx.passes.algorithms.operation import Sequence -from colossalai.core import global_context as gpc +from colossalai.fx.passes.meta_info_prop import MetaInfoProp from colossalai.utils import free_port -import pytest -from colossalai import META_COMPATIBILITY -if META_COMPATIBILITY: + +if is_compatible_with_meta(): from colossalai.fx.profiler.tensor import MetaTensor try: @@ -34,7 +36,7 @@ def _run_C_solver_consistency_test(rank=0): graph = tracer.trace(model, meta_args={"x": data}) graph.set_codegen(ActivationCheckpointCodeGen()) gm = ColoGraphModule(model, graph, model.__class__.__name__) - if META_COMPATIBILITY: + if is_compatible_with_meta(): data_meta = MetaTensor(data, fake_device=next(gm.parameters()).device) MetaInfoProp(gm).run(data_meta) diff --git a/tests/test_fx/test_ckpt_solvers/test_ckpt_torchvision.py b/tests/test_fx/test_ckpt_solvers/test_ckpt_torchvision.py index ff61e604c..3914d57be 100644 --- a/tests/test_fx/test_ckpt_solvers/test_ckpt_torchvision.py +++ b/tests/test_fx/test_ckpt_solvers/test_ckpt_torchvision.py @@ -1,20 +1,22 @@ -from typing import Callable import copy import re +from typing import Callable + +import colossalai +import pytest import torch import torch.multiprocessing as mp import torchvision.models as tm -from torch.fx import GraphModule -import colossalai -from colossalai.fx import ColoTracer -from colossalai.fx.graph_module import ColoGraphModule -from colossalai.fx.passes.meta_info_prop import MetaInfoProp -from colossalai.fx.passes.algorithms import chen_greedy, solver_rotor -from colossalai.utils import free_port from colossalai.core import global_context as gpc -import pytest -from colossalai import META_COMPATIBILITY -if META_COMPATIBILITY: +from colossalai.fx import ColoTracer +from colossalai.fx._compatibility import is_compatible_with_meta +from colossalai.fx.graph_module import ColoGraphModule +from colossalai.fx.passes.algorithms import chen_greedy, solver_rotor +from colossalai.fx.passes.meta_info_prop import MetaInfoProp +from colossalai.utils import free_port +from torch.fx import GraphModule + +if is_compatible_with_meta(): from colossalai.fx.profiler.tensor import MetaTensor try: @@ -54,8 +56,9 @@ def _is_graph_linearized(gm: GraphModule): def check_backward_consistency(m: torch.nn.Module, gm: GraphModule, solver: Callable[[GraphModule], GraphModule], model_cls: Callable[[], torch.nn.Module]): criterion = torch.nn.MSELoss() - data = torch.rand(2, 3, 32, 32) - label = torch.rand(2, 5) + m.cuda() + data = torch.rand(2, 3, 32, 32).cuda() + 
label = torch.rand(2, 5).cuda() loss = criterion(m(data), label) loss.backward() loss = criterion(gm(data), label) @@ -77,7 +80,7 @@ def _run_ckpt_solver(rank): m = model_cls(num_classes=5) graph = tracer.trace(root=m) gm = ColoGraphModule(copy.deepcopy(m), graph, m.__class__.__name__) - MetaInfoProp(gm.cuda()).run(MetaTensor(data, fake_device='cuda')) + MetaInfoProp(gm.cuda()).run(MetaTensor(data).cuda()) codegen = ActivationCheckpointCodeGen() gm.graph.set_codegen(codegen) if solver == solver_rotor: diff --git a/tests/test_fx/test_ckpt_solvers/test_linearize.py b/tests/test_fx/test_ckpt_solvers/test_linearize.py index ec30d0e76..a803f8c07 100644 --- a/tests/test_fx/test_ckpt_solvers/test_linearize.py +++ b/tests/test_fx/test_ckpt_solvers/test_linearize.py @@ -1,13 +1,14 @@ -from colossalai.fx.passes.meta_info_prop import MetaInfoProp +import pytest import torch import torchvision.models as tm from colossalai.fx import ColoTracer +from colossalai.fx._compatibility import is_compatible_with_meta from colossalai.fx.graph_module import ColoGraphModule -from colossalai.fx.passes.algorithms import solver_rotor, linearize -from colossalai.fx.passes.algorithms.operation import Loss, ForwardCheck, ForwardEnable, ForwardNograd -import pytest -from colossalai import META_COMPATIBILITY -if META_COMPATIBILITY: +from colossalai.fx.passes.algorithms import linearize, solver_rotor +from colossalai.fx.passes.algorithms.operation import (ForwardCheck, ForwardEnable, ForwardNograd, Loss) +from colossalai.fx.passes.meta_info_prop import MetaInfoProp + +if is_compatible_with_meta(): from colossalai.fx.profiler.tensor import MetaTensor try: diff --git a/tests/test_fx/test_comm_size_compute.py b/tests/test_fx/test_comm_size_compute.py index bc4348c97..8825bbb46 100644 --- a/tests/test_fx/test_comm_size_compute.py +++ b/tests/test_fx/test_comm_size_compute.py @@ -1,13 +1,17 @@ -import torch -import torch.nn as nn import colossalai import colossalai.nn as col_nn -from torch.fx import symbolic_trace -from colossalai.fx.passes.meta_info_prop import MetaInfoProp -from colossalai.fx.passes.adding_split_node_pass import split_with_split_nodes_pass, uniform_split_pass -from colossalai.fx.passes.utils import get_comm_size -from colossalai import META_COMPATIBILITY import pytest +import torch +import torch.nn as nn +from colossalai.fx._compatibility import is_compatible_with_meta +from colossalai.fx.passes.adding_split_node_pass import (split_with_split_nodes_pass, uniform_split_pass) +from colossalai.fx.passes.meta_info_prop import MetaInfoProp +from colossalai.fx.passes.utils import get_comm_size +from torch.fx import symbolic_trace + +is_compatible = is_compatible_with_meta() +if is_compatible: + from colossalai.fx.profiler import MetaTensor MODEL_DIM = 16 BATCH_SIZE = 8 @@ -31,12 +35,12 @@ class MLP(torch.nn.Module): return x -@pytest.mark.skipif(not META_COMPATIBILITY, reason='torch version is lower than 1.12.0') def test_comm_size_compute(): - from colossalai.fx.profiler import MetaTensor model = MLP(MODEL_DIM) - input_sample = MetaTensor(torch.rand(BATCH_SIZE, MODEL_DIM, device='meta'), fake_device='cpu') + input_sample = torch.rand(BATCH_SIZE, MODEL_DIM, device='meta') gm = symbolic_trace(model) + if is_compatible: + input_sample = MetaTensor(input_sample, fake_device=next(gm.parameters()).device) MetaInfoProp(gm).run(input_sample) annotated_model = uniform_split_pass(gm, PIPELINE_SIZE) split_model, split_submodules = split_with_split_nodes_pass(annotated_model) diff --git a/tests/test_fx/test_meta/test_aten.py 
b/tests/test_fx/test_meta/test_aten.py index 61eda1d67..209ded89c 100644 --- a/tests/test_fx/test_meta/test_aten.py +++ b/tests/test_fx/test_meta/test_aten.py @@ -1,12 +1,11 @@ from typing import Any, Callable, Union -import torch -import torch.nn as nn -import torch.nn.functional as F -from colossalai import META_COMPATIBILITY import pytest +import torch +import torch.nn as nn +from colossalai.fx._compatibility import is_compatible_with_meta -if META_COMPATIBILITY: +if is_compatible_with_meta(): from colossalai.fx.profiler import MetaTensor aten = torch.ops.aten @@ -71,7 +70,7 @@ def run_and_compare(f: Union[nn.Module, Callable], x: torch.Tensor, requires_bac compare_all(x.grad, meta_x.grad) -@pytest.mark.skipif(not META_COMPATIBILITY, reason='torch version is lower than 1.12.0') +@pytest.mark.skipif(not is_compatible_with_meta(), reason='torch version is lower than 1.12.0') def test_meta_aten(): for (aten_op, requires_backward), v in registered_meta.items(): for f, x in v: diff --git a/tests/test_fx/test_meta/test_backward.py b/tests/test_fx/test_meta/test_backward.py index 84ac56881..351c02c57 100644 --- a/tests/test_fx/test_meta/test_backward.py +++ b/tests/test_fx/test_meta/test_backward.py @@ -1,10 +1,10 @@ -import torchvision.models as tm +import pytest import timm.models as tmm import torch -from colossalai import META_COMPATIBILITY -import pytest +import torchvision.models as tm +from colossalai.fx._compatibility import is_compatible_with_meta -if META_COMPATIBILITY: +if is_compatible_with_meta(): from colossalai.fx.profiler import MetaTensor tm_models = [ @@ -27,7 +27,7 @@ tmm_models = [ ] -@pytest.mark.skipif(not META_COMPATIBILITY, reason='torch version is lower than 1.12.0') +@pytest.mark.skipif(not is_compatible_with_meta(), reason='torch version is lower than 1.12.0') def test_torchvision_models(): for m in tm_models: model = m() @@ -35,7 +35,7 @@ def test_torchvision_models(): model(MetaTensor(data, fake_device=torch.device('cpu'))).sum().backward() -@pytest.mark.skipif(not META_COMPATIBILITY, reason='torch version is lower than 1.12.0') +@pytest.mark.skipif(not is_compatible_with_meta(), reason='torch version is lower than 1.12.0') def test_timm_models(): for m in tmm_models: model = m() diff --git a/tests/test_fx/test_meta/test_meta_trace.py b/tests/test_fx/test_meta/test_meta_trace.py index 67b69f1da..404b6d27d 100644 --- a/tests/test_fx/test_meta/test_meta_trace.py +++ b/tests/test_fx/test_meta/test_meta_trace.py @@ -1,10 +1,10 @@ -import torchvision.models as tm +import pytest import timm.models as tmm import torch -from colossalai import META_COMPATIBILITY -import pytest +import torchvision.models as tm +from colossalai.fx._compatibility import is_compatible_with_meta -if META_COMPATIBILITY: +if is_compatible_with_meta(): from colossalai.fx import meta_trace tm_models = [ @@ -27,7 +27,7 @@ tmm_models = [ ] -@pytest.mark.skipif(not META_COMPATIBILITY, reason='torch version is lower than 1.12.0') +@pytest.mark.skipif(not is_compatible_with_meta(), reason='torch version is lower than 1.12.0') def test_torchvision_models_trace(): for m in tm_models: model = m() @@ -35,7 +35,7 @@ def test_torchvision_models_trace(): graph = meta_trace(model, torch.device('cpu'), data) -@pytest.mark.skipif(not META_COMPATIBILITY, reason='torch version is lower than 1.12.0') +@pytest.mark.skipif(not is_compatible_with_meta(), reason='torch version is lower than 1.12.0') def test_timm_models_trace(): for m in tmm_models: model = m() diff --git a/tests/test_fx/test_meta_info_prop.py 
b/tests/test_fx/test_meta_info_prop.py index 7f1051987..6fac180d8 100644 --- a/tests/test_fx/test_meta_info_prop.py +++ b/tests/test_fx/test_meta_info_prop.py @@ -1,7 +1,10 @@ import torch -from torch.fx import symbolic_trace -from colossalai import META_COMPATIBILITY +from colossalai.fx._compatibility import is_compatible_with_meta from colossalai.fx.passes.meta_info_prop import MetaInfoProp, TensorMetadata +from torch.fx import symbolic_trace + +if is_compatible_with_meta(): + from colossalai.fx.profiler import MetaTensor BATCH_SIZE = 2 DIM_IN = 4 @@ -18,8 +21,7 @@ def meta_check(meta_info_spec: TensorMetadata, orig_tensor: torch.Tensor): def test_meta_info_prop(): model = torch.nn.Linear(DIM_IN, DIM_OUT) input_sample = torch.rand(BATCH_SIZE, DIM_IN, device='meta') - if META_COMPATIBILITY: - from colossalai.fx.profiler import MetaTensor + if is_compatible_with_meta(): input_sample = MetaTensor(input_sample, fake_device='cpu') orig_output = model(input_sample) gm = symbolic_trace(model) diff --git a/tests/test_pipeline/rpc_test_utils.py b/tests/test_pipeline/rpc_test_utils.py index 5c332f270..fe0333bde 100644 --- a/tests/test_pipeline/rpc_test_utils.py +++ b/tests/test_pipeline/rpc_test_utils.py @@ -1,19 +1,17 @@ -import os import argparse +import os import warnings import torch -from torch import nn -import torch.multiprocessing as mp -import torch.distributed.rpc as rpc -from torch.optim import SGD, Adam, RMSprop, Optimizer -from torch._C._distributed_rpc import _is_current_rpc_agent_set import torch.distributed as dist -from colorama import Back, Style - -from colossalai.pipeline.pipeline_process_group import ppg -from colossalai.logging import disable_existing_loggers +import torch.distributed.rpc as rpc +import torch.multiprocessing as mp from colossalai import launch +from colossalai.logging import disable_existing_loggers +from colossalai.pipeline.pipeline_process_group import ppg +from torch import nn +from torch._C._distributed_rpc import _is_current_rpc_agent_set +from torch.optim import SGD, Adam, Optimizer, RMSprop rpc_is_initialized = _is_current_rpc_agent_set
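# An end-to-end sketch of the MetaInfoProp workflow that this patch standardizes across
# the tests above (model choice and input shape are illustrative only):
import torch
import torchvision.models as tm
from colossalai.fx import ColoGraphModule, ColoTracer, is_compatible_with_meta
from colossalai.fx.passes.meta_info_prop import MetaInfoProp

model = tm.resnet18()
data = torch.rand(2, 3, 224, 224, device='meta')
graph = ColoTracer().trace(model, meta_args={'x': data})
gm = ColoGraphModule(model, graph, model.__class__.__name__)
if is_compatible_with_meta():
    from colossalai.fx.profiler import MetaTensor
    # wrap the meta input so the profiler can pretend it lives on a concrete device
    data = MetaTensor(data, fake_device=next(gm.parameters()).device)
MetaInfoProp(gm).run(data)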