[fx/meta/rpc] move _meta_registration.py to fx folder / register fx functions with compatibility checks / remove color debug (#1710)

* [fx] move meta registration

* [fx] fix tests.

* [fx] fix test.

* [fx] fix.

* [meta] refactor meta registration.py.

* [fx] add compatibility descriptions.

* [fx] polish import.

* [fx] add a decorator.

* [fx] fix tests.

* [fx] remove print.

* [fx] edit raise error.

* [fx] edit raise error.

* [fx] add type hint.

* [fx] fix import in experimental.

* [rpc] remove color debug.

* [meta] fix naming.
Super Daniel 2022-10-18 10:44:23 +08:00 committed by GitHub
parent e8d8eda5e7
commit 393f594051
32 changed files with 351 additions and 310 deletions

View File

@@ -1,10 +1,3 @@
-try:
-    from . import _meta_registrations
-    META_COMPATIBILITY = True
-except:
-    import torch
-    META_COMPATIBILITY = False
-    print(f'_meta_registrations seems to be incompatible with PyTorch {torch.__version__}.')
from .initialize import (initialize, launch, launch_from_openmpi, launch_from_slurm, launch_from_torch,
                         get_default_parser)

View File

@@ -1,3 +1,4 @@
-from .tracer import ColoTracer, meta_trace
+from ._compatibility import compatibility, is_compatible_with_meta
from .graph_module import ColoGraphModule
from .passes import MetaInfoProp
+from .tracer import ColoTracer, meta_trace

View File

@@ -0,0 +1,46 @@
+from typing import Callable
+
+import torch
+
+try:
+    from . import _meta_registrations
+    META_COMPATIBILITY = True
+except:
+    META_COMPATIBILITY = False
+
+
+def compatibility(is_backward_compatible: bool = False) -> Callable:
+    """A decorator to make a function compatible with different versions of PyTorch.
+
+    Args:
+        is_backward_compatible (bool, optional): Whether the function is backward compatible. Defaults to False.
+
+    Returns:
+        Callable: The decorated function
+    """
+
+    def decorator(func):
+        if META_COMPATIBILITY:
+            return func
+        else:
+            if is_backward_compatible:
+                return func
+            else:
+
+                def wrapper(*args, **kwargs):
+                    raise RuntimeError(f'Function `{func.__name__}` is not compatible with PyTorch {torch.__version__}')
+
+                return wrapper
+
+    return decorator
+
+
+def is_compatible_with_meta() -> bool:
+    """Check the meta compatibility. Normally it should be called before importing some of the `colossalai.fx`
+    modules. If the meta compatibility is not satisfied, the `colossalai.fx` modules will be replaced by its
+    experimental counterparts.
+
+    Returns:
+        bool: The meta compatibility
+    """
+    return META_COMPATIBILITY
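Note: a minimal usage sketch of the two new helpers (not part of the diff; `profile_on_meta` is a hypothetical placeholder name). `compatibility` gates a function on whether `_meta_registrations` imported successfully, and `is_compatible_with_meta()` is the runtime check the rest of this commit switches to:

import torch

from colossalai.fx import compatibility, is_compatible_with_meta


@compatibility(is_backward_compatible=False)
def profile_on_meta(x: torch.Tensor) -> torch.Tensor:
    # on PyTorch versions without meta support, calling this raises RuntimeError
    return torch.relu(x)


if is_compatible_with_meta():
    # meta-only modules are safe to import on this PyTorch version
    from colossalai.fx.profiler import MetaTensor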

View File

@@ -1,7 +1,10 @@
# meta patch from https://github.com/pytorch/pytorch/blob/master/torch/_meta_registrations.py
# should be activated for PyTorch version 1.12.0 and below
+# refer to https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/native_functions.yaml
+# for more meta_registrations
from typing import List, Optional, Tuple, Union
import torch
from torch.utils._pytree import tree_map
@@ -31,6 +34,7 @@ def register_meta(op, register_dispatcher=True):
    return wrapper


+# ============================== Convolutions ======================================
# https://github.com/pytorch/pytorch/pull/79834
@register_meta(aten.convolution.default)
def meta_conv(
@@ -165,6 +169,18 @@ def meta_conv_backward(grad_output: torch.Tensor, input: torch.Tensor, weight: t
    return torch.empty_like(input), torch.empty_like(weight), torch.empty((bias_sizes), device='meta')


+# https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/AdaptiveAveragePooling.cpp
+@register_meta(aten._adaptive_avg_pool2d_backward.default)
+def meta_adaptive_avg_pool2d_backward(
+    grad_output: torch.Tensor,
+    input: torch.Tensor,
+):
+    grad_input = torch.empty_like(input)
+    return grad_input
+
+
+# https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/Activation.cpp
+# ============================== Activations =======================================
@register_meta(aten.relu.default)
def meta_relu(input: torch.Tensor):
    return torch.empty_like(input)
@@ -192,11 +208,8 @@ def meta_hardtanh_backward(grad_out: torch.Tensor, input: torch.Tensor, min_val:
    return grad_in


-@register_meta(aten.roll.default)
-def meta_roll(input: torch.Tensor, shifts, dims):
-    return input
-
-
+# ============================== Normalization =====================================
+# https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/cudnn/BatchNorm.cpp
@register_meta(aten.native_batch_norm.default)
def meta_bn(input: torch.Tensor, weight, bias, running_mean, running_var, training, momentum, eps):
    n_input = input.size(1)
@@ -207,6 +220,7 @@ def meta_bn(input: torch.Tensor, weight, bias, running_mean, running_var, traini
    return output, running_mean, running_var


+# https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/cudnn/BatchNorm.cpp
@register_meta(aten.native_batch_norm_backward.default)
def meta_bn_backward(dY: torch.Tensor, input: torch.Tensor, weight: torch.Tensor, running_mean, running_var, save_mean,
                     save_invstd, train, eps, output_mask):
@@ -241,6 +255,7 @@ def meta_cudnn_bn_backward(dY: torch.Tensor, input: torch.Tensor, weight: torch.
    return dX, dgamma, dbeta


+# https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/layer_norm.cpp
@register_meta(aten.native_layer_norm.default)
def meta_ln(input: torch.Tensor, normalized_shape, weight, bias, eps):
    bs = input.size(0)
@@ -252,6 +267,7 @@ def meta_ln(input: torch.Tensor, normalized_shape, weight, bias, eps):
    return output, running_mean, running_var


+# https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/layer_norm.cpp
@register_meta(aten.native_layer_norm_backward.default)
def meta_ln_backward(dY: torch.Tensor, input: torch.Tensor, normalized_shape, mean, rstd, weight, bias,
                     grad_input_mask):
@@ -261,13 +277,18 @@ def meta_ln_backward(dY: torch.Tensor, input: torch.Tensor, normalized_shape, me
    return dX, dgamma, dbeta


-@register_meta(aten._adaptive_avg_pool2d_backward.default)
-def meta_adaptive_avg_pool2d_backward(
-    grad_output: torch.Tensor,
-    input: torch.Tensor,
-):
-    grad_input = torch.empty_like(input)
-    return torch.empty_like(input)
+# ================================== Misc ==========================================
+#https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/native_functions.yaml
+@register_meta(aten.roll.default)
+def meta_roll(input: torch.Tensor, shifts, dims):
+    return input
+
+
+# https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/TensorCompare.cpp
+@register_meta(aten.where.self)
+def meta_where_self(condition: torch.Tensor, self: torch.Tensor, other: torch.Tensor):
+    result_type = torch.result_type(self, other)
+    return torch.empty_like(self, dtype=result_type)


@register_meta(aten.index.Tensor)
@@ -360,6 +381,8 @@ def meta_index_Tensor(self, indices):
    return self.new_empty(before_shape + replacement_shape + after_shape)


+# ============================== Embedding =========================================
+# https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/Embedding.cpp
@register_meta(aten.embedding_dense_backward.default)
def meta_embedding_dense_backward(grad_output: torch.Tensor, indices: torch.Tensor, num_weights, padding_idx,
                                  scale_grad_by_freq):
@@ -369,13 +392,7 @@ def meta_embedding_dense_backward(grad_output: torch.Tensor, indices: torch.Tens
                              layout=grad_output.layout)


-# https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/TensorCompare.cpp
-@register_meta(aten.where.self)
-def meta_where_self(condition: torch.Tensor, self: torch.Tensor, other: torch.Tensor):
-    result_type = torch.result_type(self, other)
-    return torch.empty_like(self, dtype=result_type)
-
-
+# ============================== Dropout ===========================================
# https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/Dropout.cpp
@register_meta(aten.native_dropout.default)
def meta_native_dropout_default(input: torch.Tensor, p: float, train: bool = False):
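For context on the pattern above: each `@register_meta(...)` function is a shape-and-dtype-only kernel that runs when the corresponding aten op is dispatched on the 'meta' device, so no real memory is touched. A small plain-PyTorch illustration of the behaviour these registrations provide (not part of the diff):

import torch

x = torch.empty(2, 3, device='meta')    # no storage is allocated
y = torch.relu(x)                       # the meta kernel only infers shape and dtype
assert y.is_meta and y.shape == (2, 3)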

View File

@@ -1,16 +1,20 @@
-from typing import List, Tuple
import copy
-import torch
-from torch.fx import GraphModule, Node
-from colossalai.fx.graph_module import ColoGraphModule
-from colossalai.fx.profiler import parameter_size
import math
-from .linearize import linearize
-from .operation import ForwardCheck, ForwardEnable, ForwardNograd, Backward, Loss, Chain, Sequence, Function, Offload, Prefetch
+from typing import List, Tuple
+
+import torch
+
+from colossalai.fx import is_compatible_with_meta
+from colossalai.fx.codegen.activation_checkpoint_codegen import \
+    _find_nested_ckpt_regions
+from colossalai.fx.graph_module import ColoGraphModule
+from colossalai.fx.passes.algorithms.ckpt_solver_rotor import (_compute_table, _construct_chain, _rec)
from colossalai.fx.passes.meta_info_prop import MetaInfoProp
-from colossalai.fx.codegen.activation_checkpoint_codegen import _find_nested_ckpt_regions
-from colossalai.fx.passes.algorithms.ckpt_solver_rotor import _construct_chain, _compute_table, _rec
-from colossalai import META_COMPATIBILITY
+from colossalai.fx.profiler import parameter_size
+from torch.fx import GraphModule, Node
+
+from .linearize import linearize
+from .operation import (Backward, Chain, ForwardCheck, ForwardEnable, ForwardNograd, Function, Loss, Offload, Prefetch,
+                        Sequence)

INF = float("inf")
@@ -508,7 +512,7 @@ def solver_pofo(gm: ColoGraphModule,
    mem_limit -= parameter_size(gm)

    # prepare data
-    if META_COMPATIBILITY:
+    if is_compatible_with_meta():
        from colossalai.fx.profiler import MetaTensor
        data = MetaTensor(data, fake_device=next(gm.parameters()).device)
    MetaInfoProp(gm).run(data)
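The `is_compatible_with_meta()` branch above is the pattern this commit applies throughout: wrap the sample input in `MetaTensor` only when the meta registrations are usable, then run `MetaInfoProp`. A hedged end-to-end sketch of that flow (assumes a torchvision model whose forward argument is named `x`; not taken verbatim from the repo):

import torch
import torchvision.models as tm

from colossalai.fx import ColoGraphModule, ColoTracer, is_compatible_with_meta
from colossalai.fx.passes.meta_info_prop import MetaInfoProp

model = tm.resnet18()
data = torch.rand(2, 3, 224, 224, device='meta')

graph = ColoTracer().trace(model, meta_args={'x': data})
gm = ColoGraphModule(model, graph, model.__class__.__name__)

if is_compatible_with_meta():
    from colossalai.fx.profiler import MetaTensor
    # fake_device tells the profiler which real device the model would live on
    data = MetaTensor(data, fake_device=next(gm.parameters()).device)
MetaInfoProp(gm).run(data)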

View File

@@ -1,13 +1,12 @@
from dataclasses import asdict
-from colossalai.fx.profiler import GraphInfo
+from typing import Any, Dict, List, NamedTuple, Optional, Tuple
import torch
import torch.fx
-from torch.fx.node import Node, Argument, Target
+from colossalai.fx._compatibility import compatibility
+from colossalai.fx.profiler import (GraphInfo, profile_function, profile_method, profile_module)
+from torch.fx.node import Argument, Node, Target
from torch.utils._pytree import tree_flatten
-from typing import Any, List, Tuple, NamedTuple, Dict, Optional
-from torch.fx._compatibility import compatibility
-from colossalai.fx.profiler import profile_function, profile_module, profile_method, activation_size
-from torch.fx.graph_module import GraphModule

@compatibility(is_backward_compatible=True)

View File

@@ -1,11 +1,13 @@
from dataclasses import asdict
+from typing import Any, Dict, List, NamedTuple, Tuple
import torch
import torch.fx
-from torch.fx.node import Node, Argument, Target
+from colossalai.fx._compatibility import compatibility
+from colossalai.fx.profiler import (GraphInfo, activation_size, calculate_fwd_in, calculate_fwd_out, calculate_fwd_tmp,
+                                    profile_function, profile_method, profile_module)
+from torch.fx.node import Argument, Node, Target
from torch.utils._pytree import tree_map
-from typing import Any, List, Tuple, NamedTuple, Dict
-from torch.fx._compatibility import compatibility
-from colossalai.fx.profiler import GraphInfo, profile_function, profile_module, profile_method, activation_size, calculate_fwd_out, calculate_fwd_tmp, calculate_fwd_in

@compatibility(is_backward_compatible=True)

View File

@@ -1,11 +1,12 @@
-from ... import META_COMPATIBILITY
+from .._compatibility import is_compatible_with_meta

-if META_COMPATIBILITY:
+if is_compatible_with_meta():
+    from .memory import calculate_fwd_in, calculate_fwd_out, calculate_fwd_tmp
    from .opcount import flop_mapping
-    from .tensor import MetaTensor
    from .profiler import profile_function, profile_method, profile_module
-    from .memory import calculate_fwd_in, calculate_fwd_tmp, calculate_fwd_out
+    from .tensor import MetaTensor
else:
    from .experimental import meta_profiler_function, meta_profiler_module, profile_function, profile_method, profile_module, calculate_fwd_in, calculate_fwd_tmp, calculate_fwd_out

from .dataflow import GraphInfo
-from .memory import parameter_size, activation_size, is_inplace
+from .memory import activation_size, is_inplace, parameter_size

View File

@@ -1,79 +0,0 @@
-import torch
-from operator import add, floordiv, getitem, mul, neg, setitem, sub, pos
-from . import META_COMPATIBILITY
-
-__all__ = []
-
-if META_COMPATIBILITY:
-    aten = torch.ops.aten
-
-    ALIAS_ATEN = [
-        # inplace reshaping
-        aten.detach.default,
-        aten.t.default,
-        aten.transpose.int,
-        aten.view.default,
-        aten._unsafe_view.default,
-        aten._reshape_alias.default,
-    ]
-
-    INPLACE_NEW = [
-        aten.empty_like.default,
-        aten.new_empty_strided.default,
-    ]
-
-    INPLACE_MATH_ATEN = [
-        aten.add_.Tensor,
-        aten.sub_.Tensor,
-        aten.div_.Tensor,
-        aten.div_.Scalar,
-        aten.mul_.Tensor,
-        aten.bernoulli_.float,
-    ]
-
-    CLONE_ATEN = [
-        aten.clone.default,
-    ]
-
-    __all__ += ['INPLACE_ATEN', 'INPLACE_MATH_ATEN', 'CLONE_ATEN']
-
-else:
-    # TODO fill out the inplace ops
-    INPLACE_OPS = [
-        add,
-        sub,
-        mul,
-        floordiv,
-        neg,
-        pos,
-        getitem,
-        setitem,
-        getattr,
-        torch.Tensor.cpu,
-    ]
-
-    # TODO: list all call_methods that are inplace here
-    INPLACE_METHOD = [
-        'transpose',
-        'permute',
-        # TODO: reshape may return a copy of the data if the data is not contiguous
-        'reshape',
-        'dim',
-        'flatten',
-        'size',
-        'view',
-        'unsqueeze',
-        'to',
-        'type',
-        'flatten',
-    ]
-
-    # TODO: list all call_methods that are not inplace here
-    NON_INPLACE_METHOD = [
-        'chunk',
-        'contiguous',
-        'expand',
-        'mean',
-        'split',
-    ]
-
-    __all__ += ['INPLACE_OPS', 'INPLACE_METHOD', 'NON_INPLACE_METHOD']

View File

@@ -0,0 +1,32 @@
+import torch
+
+__all__ = ['ALIAS_ATEN', 'INPLACE_NEW', 'INPLACE_MATH_ATEN', 'CLONE_ATEN']
+
+aten = torch.ops.aten
+
+ALIAS_ATEN = [
+    aten.detach.default,
+    aten.t.default,
+    aten.transpose.int,
+    aten.view.default,
+    aten._unsafe_view.default,
+    aten._reshape_alias.default,
+]
+
+INPLACE_NEW = [
+    aten.empty_like.default,
+    aten.new_empty_strided.default,
+]
+
+INPLACE_MATH_ATEN = [
+    aten.add_.Tensor,
+    aten.sub_.Tensor,
+    aten.div_.Tensor,
+    aten.div_.Scalar,
+    aten.mul_.Tensor,
+    aten.bernoulli_.float,
+]
+
+CLONE_ATEN = [
+    aten.clone.default,
+]

View File

@@ -2,7 +2,10 @@ from dataclasses import dataclass, field
from enum import Enum
from functools import partial
from typing import Dict, List

from torch.fx import Graph, Node

+from .._compatibility import compatibility
+
from .memory import activation_size, is_inplace
@@ -12,6 +15,7 @@ class Phase(Enum):
    PLACEHOLDER = 2


+@compatibility(is_backward_compatible=True)
@dataclass
class GraphInfo:
    """
@@ -69,6 +73,7 @@ def is_phase(n: Node, phase: Phase) -> bool:
    return n.meta['phase'] == phase


+@compatibility(is_backward_compatible=False)
def autograd_graph_analysis(graph: Graph) -> GraphInfo:
    """Analyze the autograd node dependencies and find out the memory usage.
    Basically the input graph should have all nodes marked for keyword `phase`.

View File

@@ -1,5 +1,5 @@
-from .registry import meta_profiler_function, meta_profiler_module
-from .memory import calculate_fwd_in, calculate_fwd_tmp, calculate_fwd_out
+from .memory import calculate_fwd_in, calculate_fwd_out, calculate_fwd_tmp
+from .profiler import profile_function, profile_method, profile_module
from .profiler_function import *
from .profiler_module import *
-from .profiler import profile_function, profile_method, profile_module
+from .registry import meta_profiler_function, meta_profiler_module

View File

@@ -0,0 +1,44 @@
+from operator import add, floordiv, getitem, mul, neg, pos, setitem, sub
+
+import torch
+
+__all__ = ['INPLACE_OPS', 'INPLACE_METHOD', 'NON_INPLACE_METHOD']
+
+# TODO fill out the inplace ops
+INPLACE_OPS = [
+    add,
+    sub,
+    mul,
+    floordiv,
+    neg,
+    pos,
+    getitem,
+    setitem,
+    getattr,
+    torch.Tensor.cpu,
+]
+
+# TODO: list all call_methods that are inplace here
+INPLACE_METHOD = [
+    'transpose',
+    'permute',
+    # TODO: reshape may return a copy of the data if the data is not contiguous
+    'reshape',
+    'dim',
+    'flatten',
+    'size',
+    'view',
+    'unsqueeze',
+    'to',
+    'type',
+    'flatten',
+]
+
+# TODO: list all call_methods that are not inplace here
+NON_INPLACE_METHOD = [
+    'chunk',
+    'contiguous',
+    'expand',
+    'mean',
+    'split',
+]

View File

@@ -1,11 +1,15 @@
# for PyTorch 1.11 compatibility uses
+from typing import Dict, List, Tuple, Union
+
import torch
-from torch.fx import Node, GraphModule
-from typing import Union, Dict, List, Tuple
+from torch.fx import GraphModule, Node
+
+from ..._compatibility import compatibility

__all__ = ["calculate_fwd_in", "calculate_fwd_tmp", "calculate_fwd_out"]


+@compatibility(is_backward_compatible=True)
def calculate_fwd_in(n: Node) -> bool:
    """A helper function to calculate `fwd_in`
@@ -18,6 +22,7 @@ def calculate_fwd_in(n: Node) -> bool:
    return n.meta['save_fwd_in']


+@compatibility(is_backward_compatible=True)
def calculate_fwd_tmp(n: Node) -> int:
    """A helper function to calculate `fwd_tmp`
@@ -30,6 +35,7 @@ def calculate_fwd_tmp(n: Node) -> int:
    return n.meta["fwd_mem_tmp"]


+@compatibility(is_backward_compatible=True)
def calculate_fwd_out(n: Node) -> int:
    """A helper function to calculate `fwd_out`

View File

@@ -1,15 +1,19 @@
from dataclasses import dataclass
-from typing import Callable, Any, Dict, Tuple
+from typing import Any, Callable, Dict, Tuple

import torch
from torch.fx.node import Argument, Target

-from . import meta_profiler_function, meta_profiler_module
+from ..._compatibility import compatibility
from ..memory import activation_size
-from ..constant import INPLACE_METHOD, NON_INPLACE_METHOD, INPLACE_OPS
+from .constants import INPLACE_METHOD, INPLACE_OPS, NON_INPLACE_METHOD
+from .registry import meta_profiler_function, meta_profiler_module

__all__ = ['profile_function', 'profile_module', 'profile_method']

# this is for compatibility use
+@compatibility(is_backward_compatible=True)
@dataclass
class GraphInfo:
    """
@@ -69,6 +73,7 @@ def profile_YOUR_MODULE(self: torch.nn.Module, input: torch.Tensor) -> Tuple[int
    """


+@compatibility(is_backward_compatible=True)
def profile_function(target: 'Target') -> Callable:
    """
    Wrap a `call_function` node or `torch.nn.functional` in order to
@@ -106,6 +111,7 @@ def profile_function(target: 'Target') -> Callable:
    return f


+@compatibility(is_backward_compatible=True)
def profile_method(target: 'Target') -> Callable:
    """
    Wrap a `call_method` node
@@ -133,6 +139,7 @@ def profile_method(target: 'Target') -> Callable:
    return f


+@compatibility(is_backward_compatible=True)
def profile_module(module: torch.nn.Module) -> Callable:
    """
    Wrap a `call_module` node or `torch.nn` in order to

View File

@@ -2,7 +2,6 @@ import operator
from typing import Any, Tuple

import torch

from ..registry import meta_profiler_function
-from colossalai.fx.proxy import ColoProxy


@meta_profiler_function.register(operator.getitem)

View File

@@ -1,13 +1,16 @@
+from typing import Dict, List, Tuple, Union
+
import torch
-from torch.fx import Node, GraphModule
-from typing import Union, Dict, List, Tuple
+from torch.fx import GraphModule, Node

-from . import META_COMPATIBILITY
+from .._compatibility import compatibility, is_compatible_with_meta

__all__ = [
    'activation_size', 'parameter_size', 'is_inplace', "calculate_fwd_in", "calculate_fwd_tmp", "calculate_fwd_out"
]


+@compatibility(is_backward_compatible=True)
def activation_size(out: Union[torch.Tensor, Dict, List, Tuple, int]) -> int:
    """Calculate activation size of a node.
@@ -29,6 +32,7 @@ def activation_size(out: Union[torch.Tensor, Dict, List, Tuple, int]) -> int:
    return act_size


+@compatibility(is_backward_compatible=True)
def parameter_size(mod: torch.nn.Module) -> int:
    """Calculate parameter size of a node.
@@ -111,8 +115,8 @@ def is_inplace(n: Node):
    inplace = False
    if n.op == "call_function":
        inplace = n.kwargs.get("inplace", False)
-        if META_COMPATIBILITY:
-            from .constant import ALIAS_ATEN
+        if is_compatible_with_meta():
+            from .constants import ALIAS_ATEN
            if n.target in ALIAS_ATEN:
                inplace = True
    elif n.op == "call_module":

View File

@@ -1,10 +1,11 @@
# adopted from https://github.com/facebookresearch/fvcore/blob/main/fvcore/nn/jit_handles.py
# ideas from https://pastebin.com/AkvAyJBw
-from functools import partial, reduce
import operator
-from typing import Callable, List, Any
+from functools import partial, reduce
from numbers import Number
+from typing import Any, Callable, List

import torch

aten = torch.ops.aten

View File

@@ -1,16 +1,19 @@
+import time
from functools import partial
-from typing import Callable, Any, Dict, Tuple
+from typing import Any, Callable, Dict, Tuple

import torch
-from torch.nn.parameter import Parameter
from torch.fx import Graph, Node
from torch.fx.node import Argument, Target
+from torch.nn.parameter import Parameter
from torch.utils._pytree import tree_map

-from .dataflow import autograd_graph_analysis, is_phase, Phase, GraphInfo
+from .._compatibility import compatibility
+from .constants import ALIAS_ATEN
+from .dataflow import GraphInfo, Phase, autograd_graph_analysis, is_phase
from .memory import activation_size, parameter_size
-from .constant import ALIAS_ATEN
-from .tensor import MetaTensor
from .opcount import flop_mapping
-import time
+from .tensor import MetaTensor

__all__ = ['profile_function', 'profile_module', 'profile_method']
@@ -41,6 +44,7 @@ def detach_variables(x):
    return x


+@compatibility(is_backward_compatible=True)
def _profile_concrete(target: Callable, *args, **kwargs) -> Tuple[Tuple[Any, ...], GraphInfo]:
    """Profile a Callable function with args and kwargs on concrete devices by https://github.com/Cypher30
    To profile the actual forward memory, we first run target in the context torch.no_grad() to get
@@ -140,6 +144,7 @@ def _profile_concrete(target: Callable, *args, **kwargs) -> Tuple[Tuple[Any, ...
    return tree_map(detach_variables, out), graphinfo


+@compatibility(is_backward_compatible=False)
def _profile_meta(target: Callable, *args, **kwargs) -> Tuple[Tuple[Any, ...], GraphInfo]:
    """
    Profile a Callable function with args and kwargs on meta devices.
@@ -277,6 +282,7 @@ def _profile_meta(target: Callable, *args, **kwargs) -> Tuple[Tuple[Any, ...], G
    return tree_map(unwrap, out), graph_info


+@compatibility(is_backward_compatible=True)
def profile_function(target: 'Target', device: str = 'meta') -> Callable:
    """
    Wrap a `call_function` node or `torch.nn.functional` in order to
@@ -335,6 +341,7 @@ def profile_function(target: 'Target', device: str = 'meta') -> Callable:
    return f


+@compatibility(is_backward_compatible=True)
def profile_method(target: 'Target', device: str = 'meta') -> Callable:
    """
    Wrap a `call_method` node
@@ -353,6 +360,7 @@ def profile_method(target: 'Target', device: str = 'meta') -> Callable:
    return f


+@compatibility(is_backward_compatible=True)
def profile_module(module: torch.nn.Module, device: str = 'meta') -> Callable:
    """
    Wrap a `call_module` node or `torch.nn` in order to

View File

@@ -1,10 +1,13 @@
+import uuid
from copy import deepcopy
from typing import Optional

import torch
-from torch.utils._pytree import tree_map, tree_flatten
-from torch.types import _bool, _dtype, _device
-import uuid
+from torch.types import _bool, _device, _dtype
+from torch.utils._pytree import tree_flatten, tree_map

-from .constant import ALIAS_ATEN
+from .._compatibility import compatibility
+from .constants import ALIAS_ATEN

__all__ = ['MetaTensor']
@@ -15,6 +18,7 @@ def set_uuid(x):
    setattr(x, 'uuid', uuid.uuid4())


+@compatibility(is_backward_compatible=False)
class MetaTensor(torch.Tensor):
    """
    A wrapping tensor that hacks `torch.autograd` without patching more `torch.ops.aten` ops.

View File

@@ -1,23 +1,19 @@
-import threading
-from enum import Enum
-from typing import List, Any, Tuple, Dict, Callable
-from functools import partial
-from abc import ABC, abstractmethod
-import math
import inspect
+import math
+import threading
+from abc import ABC, abstractmethod
+from enum import Enum
+from functools import partial
+from typing import Any, Callable, Dict, List, Tuple

import torch
-from torch import nn
import torch.distributed.rpc as rpc
-from torch.futures import Future
-from torch._C._distributed_rpc import PyRRef
-from torch import autograd
-from torch import optim

from colossalai.pipeline.pipeline_process_group import ppg
-from colossalai.pipeline.rpc.utils import (color_debug, tensor_shape_list, get_batch_lengths, split_batch, type_detail,
-                                            pytree_map, pytree_filter, get_real_args_kwargs, use_color_debug)
+from colossalai.pipeline.rpc.utils import (get_batch_lengths, get_real_args_kwargs, pytree_filter, pytree_map,
+                                            split_batch, tensor_shape_list, type_detail)
+from torch import autograd, nn, optim
+from torch._C._distributed_rpc import PyRRef
+from torch.futures import Future


class Phase(Enum):
@@ -195,7 +191,6 @@ class WorkerBase(ABC):
            if isinstance(output, Future):
                output = output.wait()

-            # color_debug(f'rank {self.pp_rank}, output {type(output)}', 'get output', 'red')
            output_work_item.refcount += 1

            # all consumers have been satisfied, the work_item can be released
@@ -250,9 +245,6 @@ class WorkerBase(ABC):
                            self.num_microbatches, forward_only)
        with self.work_list_condition_lock:
            self.work_list[key] = work_item
-            if use_color_debug:
-                color_debug(f'rank {self.pp_rank} receive data from dataloader {self._get_store_len()}',
-                            'data dispatch', 'magenta')
            self.work_list_condition_lock.notify_all()

        # just for last pp_rank
@@ -273,9 +265,6 @@ class WorkerBase(ABC):
            work_item = WorkItem(self.pp_rank, Phase.BACKWARD, grad_wrt_loss, {}, output, microbatch_id, None,
                                 self.num_microbatches, False)
-            if use_color_debug:
-                color_debug(f'rank {self.pp_rank} propose backward', 'data dispatch', 'magenta')
            self.work_list[key] = work_item
            self.work_list_condition_lock.notify_all()
@@ -297,23 +286,14 @@ class WorkerBase(ABC):
            producer_worker_rref = self.pp_rank_to_worker_rref[producer_stage_id]
            subscribe_forward_futures[i] = producer_worker_rref.rpc_async().get_output_by_key(producer_output_key)

-        if use_color_debug:
-            color_debug(f'rank {self.pp_rank} get {len(subscribe_forward_futures)} futs from its producer',
-                        'data dispatch', 'magenta')

        work_item_from_producer = WorkItem(stage_id, Phase.FORWARD, subscribe_forward_futures, {}, output,
                                           microbatch_id, None, self.num_microbatches, forward_only)

-        # color_debug(f'rank {self.pp_rank} get value {tensor_shape_list(args)} from fut', 'data dispatch', 'magenta')
        # add work_item to work_list
        with self.work_list_condition_lock:
            key = UniqueKey(microbatch_id, Phase.FORWARD)
            assert key not in self.work_list
            self.work_list[key] = work_item_from_producer
-            if use_color_debug:
-                color_debug(
-                    f'rank_{self.pp_rank} load a new task to its work_list {key} {work_item_from_producer.phase} data: {tensor_shape_list(work_item_from_producer.args)}',
-                    'data dispatch', 'magenta')
            self.work_list_condition_lock.notify_all()

    def subscribe_consumer(self, microbatch_id: int):
@@ -328,10 +308,6 @@ class WorkerBase(ABC):
        subscribe_backward_futures: List[Future] = [None] * consumer_num
        output = self._get_future_by_device()

-        if use_color_debug:
-            color_debug(f'rank {self.pp_rank} get {len(subscribe_backward_futures)} futs from its consumer',
-                        'data dispatch', 'magenta')

        for i in range(consumer_num):
            consumer_stage_id = self.consumer_stage_ids[i]
            consumer_output_key = UniqueKey(microbatch_id, Phase.BACKWARD)
@@ -342,17 +318,11 @@ class WorkerBase(ABC):
        work_item_from_consumer = WorkItem(stage_id, Phase.BACKWARD, subscribe_backward_futures, {}, output,
                                           microbatch_id, None, self.num_microbatches, False)

-        # color_debug(f'rank {self.pp_rank} get value {tensor_shape_list(args)} from fut', 'data dispatch', 'magenta')
        # add work_item to work_list
        with self.work_list_condition_lock:
            key = UniqueKey(microbatch_id, Phase.BACKWARD)
            assert key not in self.work_list
            self.work_list[key] = work_item_from_consumer
-            if use_color_debug:
-                color_debug(
-                    f'rank_{self.pp_rank} load a new task to its work_list {key} {work_item_from_consumer.phase} data: {tensor_shape_list(work_item_from_consumer.args)}',
-                    'data dispatch', 'magenta')
            self.work_list_condition_lock.notify_all()

    def _get_producer_consumer(self) -> None:
@@ -406,11 +376,6 @@ class WorkerBase(ABC):
        is_first_stage = self.is_first_stage()
        is_last_stage = self.is_last_stage()

-        # if self.pp_rank == 0:
-        #     print(
-        #         f'I am rank_{self.pp_rank} microbatch_id : {microbatch_id} {phase} {self._get_store_len()} | {self.outstanding} {self.outstanding_range}'
-        #     )
        if phase == Phase.FORWARD:
            # remind its consumer to get data before forward
            if not is_last_stage:
@@ -470,8 +435,6 @@ class WorkerBase(ABC):
            else:
                consume_result = self.module_partition(*args, **kwargs)

-            # print(f'model{self.pp_rank + 1}(param_sum: {sum([p.sum().item() for p in self.module_partition.parameters()])}) input sum: {args[0].sum().item()} forward output sum: {consume_result.sum().item()}', )
            if is_last_stage and self.criterion:
                with self.label_lock:
                    self.label_lock.wait_for(lambda: microbatch_id in self.microbatch_id_to_labels)
@@ -539,10 +502,6 @@ class WorkerBase(ABC):
            pytree_map(stage_input_args, lambda x: consume_result.append(x.grad), process_types=torch.Tensor)
            pytree_map(stage_input_kwargs, lambda x: consume_result.append(x.grad), process_types=torch.Tensor)

-            # for input_node in stage_input_args:
-            #     if isinstance(input_node, torch.Tensor):
-            #         consume_result.append(input_node.grad)
        else:
            raise TypeError(f"Unknown phase appears in _consume_work_item_by_phase {phase}")
@@ -593,11 +552,6 @@ class WorkerBase(ABC):
            with self.work_list_condition_lock:
                work_item = self.work_list.pop(work_item_key)

-            if use_color_debug:
-                color_debug(
-                    f'rank {self.pp_rank} get a key : {work_item_key} work_item args: {tensor_shape_list(work_item.args)} {self._get_store_len()}',
-                    'work loop', 'green')
            with self.output_list_condition_lock:
                # assert work_item_key not in self.output_list
                self.output_list[work_item_key] = work_item
@@ -605,11 +559,6 @@ class WorkerBase(ABC):
            consume_result = self._consume_work_item_by_phase(work_item)

-            if use_color_debug:
-                color_debug(
-                    f'rank_{self.pp_rank} [{work_item.phase}] finish consuming, result is {tensor_shape_list(consume_result)} {self._get_store_len()} | {self.work_list.keys()} | {self.output_list.keys()}',
-                    'work loop', 'green')
            work_item.output.set_result(consume_result)

            # if is last step in one batch reset context and do step

View File

@@ -1,13 +1,12 @@
-from typing import List, Callable, Dict
import threading
+from typing import Callable, Dict, List

import torch
import torch.distributed as dist
-from torch.futures import Future
-from torch._C._distributed_rpc import PyRRef

-from colossalai.pipeline.rpc._pipeline_base import PipelineEngineBase, WorkerBase, UniqueKey, Phase, WorkItem
from colossalai.pipeline.pipeline_process_group import ppg
+from colossalai.pipeline.rpc._pipeline_base import (Phase, PipelineEngineBase, UniqueKey, WorkerBase, WorkItem)
+from torch._C._distributed_rpc import PyRRef
+from torch.futures import Future

# Implementation of different Pipeline schedule
# <strategy>Worker defines the worker for each stage

View File

@@ -1,25 +1,15 @@
-from typing import List, Any, Tuple, Dict, Callable, Type, Union
+import argparse
import os
import warnings
-import argparse
+from typing import Any, Callable, Dict, List, Tuple, Type, Union

import torch
-import torch.multiprocessing as mp
-from torch.futures import Future
import torch.distributed.rpc as rpc
-from torch._C._distributed_rpc import _is_current_rpc_agent_set
-from colorama import Back, Style
+import torch.multiprocessing as mp

from colossalai.initialize import launch
from colossalai.pipeline.pipeline_process_group import ppg
+from torch._C._distributed_rpc import _is_current_rpc_agent_set
+from torch.futures import Future

-# config for debug and test
-use_color_debug = False
-
-
-def color_debug(text, prefix=' ', color='blue'):
-    color = color.upper()
-    print(getattr(Back, color), prefix, Style.RESET_ALL, text)
-

def pytree_map(obj: Any, fn: Callable, process_types: Union[Type, Tuple[Type]] = (), map_all: bool = False) -> Any:

View File

@@ -1,18 +1,20 @@
import copy
-import colossalai
+
+import pytest
import torch
+import torch.fx
import torch.multiprocessing as mp
import torchvision.models as tm
-import torch.fx
-from colossalai.fx.passes.meta_info_prop import MetaInfoProp
+
+import colossalai
+from colossalai.core import global_context as gpc
from colossalai.fx import ColoGraphModule, ColoTracer
+from colossalai.fx._compatibility import is_compatible_with_meta
from colossalai.fx.passes.algorithms import solver_rotor
from colossalai.fx.passes.algorithms.operation import Sequence
-from colossalai.core import global_context as gpc
+from colossalai.fx.passes.meta_info_prop import MetaInfoProp
from colossalai.utils import free_port
-import pytest
-from colossalai import META_COMPATIBILITY
-if META_COMPATIBILITY:
+
+if is_compatible_with_meta():
    from colossalai.fx.profiler.tensor import MetaTensor

try:
@@ -34,7 +36,7 @@ def _run_C_solver_consistency_test(rank=0):
        graph = tracer.trace(model, meta_args={"x": data})
        graph.set_codegen(ActivationCheckpointCodeGen())
        gm = ColoGraphModule(model, graph, model.__class__.__name__)
-        if META_COMPATIBILITY:
+        if is_compatible_with_meta():
            data_meta = MetaTensor(data, fake_device=next(gm.parameters()).device)
            MetaInfoProp(gm).run(data_meta)

View File

@@ -1,20 +1,22 @@
-from typing import Callable
import copy
import re
+from typing import Callable

-import colossalai
+import pytest
import torch
import torch.multiprocessing as mp
import torchvision.models as tm
-from torch.fx import GraphModule

-from colossalai.fx import ColoTracer
-from colossalai.fx.graph_module import ColoGraphModule
-from colossalai.fx.passes.meta_info_prop import MetaInfoProp
-from colossalai.fx.passes.algorithms import chen_greedy, solver_rotor
-from colossalai.utils import free_port
+import colossalai
from colossalai.core import global_context as gpc
-import pytest
-from colossalai import META_COMPATIBILITY
-if META_COMPATIBILITY:
+from colossalai.fx import ColoTracer
+from colossalai.fx._compatibility import is_compatible_with_meta
+from colossalai.fx.graph_module import ColoGraphModule
+from colossalai.fx.passes.algorithms import chen_greedy, solver_rotor
+from colossalai.fx.passes.meta_info_prop import MetaInfoProp
+from colossalai.utils import free_port
+from torch.fx import GraphModule
+
+if is_compatible_with_meta():
    from colossalai.fx.profiler.tensor import MetaTensor

try:
@@ -54,8 +56,9 @@ def _is_graph_linearized(gm: GraphModule):
def check_backward_consistency(m: torch.nn.Module, gm: GraphModule, solver: Callable[[GraphModule], GraphModule],
                               model_cls: Callable[[], torch.nn.Module]):
    criterion = torch.nn.MSELoss()
-    data = torch.rand(2, 3, 32, 32)
-    label = torch.rand(2, 5)
+    m.cuda()
+    data = torch.rand(2, 3, 32, 32).cuda()
+    label = torch.rand(2, 5).cuda()
    loss = criterion(m(data), label)
    loss.backward()
    loss = criterion(gm(data), label)
@@ -77,7 +80,7 @@ def _run_ckpt_solver(rank):
        m = model_cls(num_classes=5)
        graph = tracer.trace(root=m)
        gm = ColoGraphModule(copy.deepcopy(m), graph, m.__class__.__name__)
-        MetaInfoProp(gm.cuda()).run(MetaTensor(data, fake_device='cuda'))
+        MetaInfoProp(gm.cuda()).run(MetaTensor(data).cuda())
        codegen = ActivationCheckpointCodeGen()
        gm.graph.set_codegen(codegen)
        if solver == solver_rotor:

View File

@@ -1,13 +1,14 @@
-from colossalai.fx.passes.meta_info_prop import MetaInfoProp
+import pytest
import torch
import torchvision.models as tm
from colossalai.fx import ColoTracer
+from colossalai.fx._compatibility import is_compatible_with_meta
from colossalai.fx.graph_module import ColoGraphModule
-from colossalai.fx.passes.algorithms import solver_rotor, linearize
-from colossalai.fx.passes.algorithms.operation import Loss, ForwardCheck, ForwardEnable, ForwardNograd
-import pytest
-from colossalai import META_COMPATIBILITY
-if META_COMPATIBILITY:
+from colossalai.fx.passes.algorithms import linearize, solver_rotor
+from colossalai.fx.passes.algorithms.operation import (ForwardCheck, ForwardEnable, ForwardNograd, Loss)
+from colossalai.fx.passes.meta_info_prop import MetaInfoProp
+
+if is_compatible_with_meta():
    from colossalai.fx.profiler.tensor import MetaTensor

try:

View File

@@ -1,13 +1,17 @@
-import torch
-import torch.nn as nn
import colossalai
import colossalai.nn as col_nn
-from torch.fx import symbolic_trace
-from colossalai.fx.passes.meta_info_prop import MetaInfoProp
-from colossalai.fx.passes.adding_split_node_pass import split_with_split_nodes_pass, uniform_split_pass
-from colossalai.fx.passes.utils import get_comm_size
-from colossalai import META_COMPATIBILITY
import pytest
+import torch
+import torch.nn as nn
+from colossalai.fx._compatibility import is_compatible_with_meta
+from colossalai.fx.passes.adding_split_node_pass import (split_with_split_nodes_pass, uniform_split_pass)
+from colossalai.fx.passes.meta_info_prop import MetaInfoProp
+from colossalai.fx.passes.utils import get_comm_size
+from torch.fx import symbolic_trace
+
+is_compatible = is_compatible_with_meta()
+if is_compatible:
+    from colossalai.fx.profiler import MetaTensor

MODEL_DIM = 16
BATCH_SIZE = 8
@@ -31,12 +35,12 @@ class MLP(torch.nn.Module):
        return x


-@pytest.mark.skipif(not META_COMPATIBILITY, reason='torch version is lower than 1.12.0')
def test_comm_size_compute():
-    from colossalai.fx.profiler import MetaTensor
    model = MLP(MODEL_DIM)
-    input_sample = MetaTensor(torch.rand(BATCH_SIZE, MODEL_DIM, device='meta'), fake_device='cpu')
+    input_sample = torch.rand(BATCH_SIZE, MODEL_DIM, device='meta')
    gm = symbolic_trace(model)
+    if is_compatible:
+        input_sample = MetaTensor(input_sample, fake_device=next(gm.parameters()).device)
    MetaInfoProp(gm).run(input_sample)
    annotated_model = uniform_split_pass(gm, PIPELINE_SIZE)
    split_model, split_submodules = split_with_split_nodes_pass(annotated_model)

View File

@@ -1,12 +1,11 @@
from typing import Any, Callable, Union

-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-from colossalai import META_COMPATIBILITY
import pytest
+import torch
+import torch.nn as nn
+from colossalai.fx._compatibility import is_compatible_with_meta

-if META_COMPATIBILITY:
+if is_compatible_with_meta():
    from colossalai.fx.profiler import MetaTensor

aten = torch.ops.aten
@@ -71,7 +70,7 @@ def run_and_compare(f: Union[nn.Module, Callable], x: torch.Tensor, requires_bac
        compare_all(x.grad, meta_x.grad)


-@pytest.mark.skipif(not META_COMPATIBILITY, reason='torch version is lower than 1.12.0')
+@pytest.mark.skipif(not is_compatible_with_meta(), reason='torch version is lower than 1.12.0')
def test_meta_aten():
    for (aten_op, requires_backward), v in registered_meta.items():
        for f, x in v:

View File

@@ -1,10 +1,10 @@
-import torchvision.models as tm
+import pytest
import timm.models as tmm
import torch
-from colossalai import META_COMPATIBILITY
-import pytest
+import torchvision.models as tm
+from colossalai.fx._compatibility import is_compatible_with_meta

-if META_COMPATIBILITY:
+if is_compatible_with_meta():
    from colossalai.fx.profiler import MetaTensor

tm_models = [
@@ -27,7 +27,7 @@ tmm_models = [
]


-@pytest.mark.skipif(not META_COMPATIBILITY, reason='torch version is lower than 1.12.0')
+@pytest.mark.skipif(not is_compatible_with_meta(), reason='torch version is lower than 1.12.0')
def test_torchvision_models():
    for m in tm_models:
        model = m()
@@ -35,7 +35,7 @@ def test_torchvision_models():
    model(MetaTensor(data, fake_device=torch.device('cpu'))).sum().backward()


-@pytest.mark.skipif(not META_COMPATIBILITY, reason='torch version is lower than 1.12.0')
+@pytest.mark.skipif(not is_compatible_with_meta(), reason='torch version is lower than 1.12.0')
def test_timm_models():
    for m in tmm_models:
        model = m()

View File

@@ -1,10 +1,10 @@
-import torchvision.models as tm
+import pytest
import timm.models as tmm
import torch
-from colossalai import META_COMPATIBILITY
-import pytest
+import torchvision.models as tm
+from colossalai.fx._compatibility import is_compatible_with_meta

-if META_COMPATIBILITY:
+if is_compatible_with_meta():
    from colossalai.fx import meta_trace

tm_models = [
@@ -27,7 +27,7 @@ tmm_models = [
]


-@pytest.mark.skipif(not META_COMPATIBILITY, reason='torch version is lower than 1.12.0')
+@pytest.mark.skipif(not is_compatible_with_meta(), reason='torch version is lower than 1.12.0')
def test_torchvision_models_trace():
    for m in tm_models:
        model = m()
@@ -35,7 +35,7 @@ def test_torchvision_models_trace():
    graph = meta_trace(model, torch.device('cpu'), data)


-@pytest.mark.skipif(not META_COMPATIBILITY, reason='torch version is lower than 1.12.0')
+@pytest.mark.skipif(not is_compatible_with_meta(), reason='torch version is lower than 1.12.0')
def test_timm_models_trace():
    for m in tmm_models:
        model = m()

View File

@@ -1,7 +1,10 @@
import torch
-from torch.fx import symbolic_trace
-from colossalai import META_COMPATIBILITY
+from colossalai.fx._compatibility import is_compatible_with_meta
from colossalai.fx.passes.meta_info_prop import MetaInfoProp, TensorMetadata
+from torch.fx import symbolic_trace
+
+if is_compatible_with_meta():
+    from colossalai.fx.profiler import MetaTensor

BATCH_SIZE = 2
DIM_IN = 4
@@ -18,8 +21,7 @@ def meta_check(meta_info_spec: TensorMetadata, orig_tensor: torch.Tensor):
def test_meta_info_prop():
    model = torch.nn.Linear(DIM_IN, DIM_OUT)
    input_sample = torch.rand(BATCH_SIZE, DIM_IN, device='meta')
-    if META_COMPATIBILITY:
-        from colossalai.fx.profiler import MetaTensor
+    if is_compatible_with_meta():
        input_sample = MetaTensor(input_sample, fake_device='cpu')
    orig_output = model(input_sample)
    gm = symbolic_trace(model)

View File

@@ -1,19 +1,17 @@
-import os
import argparse
+import os
import warnings

import torch
-from torch import nn
-import torch.multiprocessing as mp
-import torch.distributed.rpc as rpc
-from torch.optim import SGD, Adam, RMSprop, Optimizer
-from torch._C._distributed_rpc import _is_current_rpc_agent_set
import torch.distributed as dist
-from colorama import Back, Style
+import torch.distributed.rpc as rpc
+import torch.multiprocessing as mp

-from colossalai.pipeline.pipeline_process_group import ppg
-from colossalai.logging import disable_existing_loggers
from colossalai import launch
+from colossalai.logging import disable_existing_loggers
+from colossalai.pipeline.pipeline_process_group import ppg
+from torch import nn
+from torch._C._distributed_rpc import _is_current_rpc_agent_set
+from torch.optim import SGD, Adam, Optimizer, RMSprop

rpc_is_initialized = _is_current_rpc_agent_set