[fx/meta/rpc] move _meta_registration.py to fx folder / register fx functions with compatibility checks / remove color debug (#1710)

* [fx] move meta registration

* [fx] fix tests.

* [fx] fix test.

* [fx] fix.

* [meta] refactor meta registration.py.

* [fx] add compatibility descriptions.

* [fx] polish import.

* [fx] add a decorator.

* [fx] fix tests.

* [fx] remove print.

* [fx] edit raise error.

* [fx] edit raise error.

* [fx] add type hint.

* [fx] fix import in experimental.

* [rpc] remove color debug.

* [meta] fix naming.
Super Daniel 2022-10-18 10:44:23 +08:00 committed by GitHub
parent e8d8eda5e7
commit 393f594051
32 changed files with 351 additions and 310 deletions

View File

@ -1,10 +1,3 @@
try:
from . import _meta_registrations
META_COMPATIBILITY = True
except:
import torch
META_COMPATIBILITY = False
print(f'_meta_registrations seems to be incompatible with PyTorch {torch.__version__}.')
from .initialize import (initialize, launch, launch_from_openmpi, launch_from_slurm, launch_from_torch,
get_default_parser)

View File

@ -1,3 +1,4 @@
from .tracer import ColoTracer, meta_trace
from ._compatibility import compatibility, is_compatible_with_meta
from .graph_module import ColoGraphModule
from .passes import MetaInfoProp
from .tracer import ColoTracer, meta_trace

View File

@ -0,0 +1,46 @@
from typing import Callable
import torch
try:
from . import _meta_registrations
META_COMPATIBILITY = True
except Exception:
META_COMPATIBILITY = False
def compatibility(is_backward_compatible: bool = False) -> Callable:
"""A decorator to make a function compatible with different versions of PyTorch.
Args:
is_backward_compatible (bool, optional): Whether the function is backward compatible. Defaults to False.
Returns:
Callable: The decorated function
"""
def decorator(func):
if META_COMPATIBILITY:
return func
else:
if is_backward_compatible:
return func
else:
def wrapper(*args, **kwargs):
raise RuntimeError(f'Function `{func.__name__}` is not compatible with PyTorch {torch.__version__}')
return wrapper
return decorator
def is_compatible_with_meta() -> bool:
"""Check the meta compatibility. Normally it should be called before importing some of the `colossalai.fx`
modules. If the meta compatibility is not satisfied, the `colossalai.fx` modules will be replaced by their
experimental counterparts.
Returns:
bool: The meta compatibility
"""
return META_COMPATIBILITY
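
For orientation, a minimal sketch of how these two helpers are meant to be used downstream; `experimental_pass` is a hypothetical name, and the import path follows the layout introduced in this commit:

from colossalai.fx._compatibility import compatibility, is_compatible_with_meta

@compatibility(is_backward_compatible=False)
def experimental_pass(gm):
    # callable only when _meta_registrations imported cleanly; otherwise the
    # decorator swaps in a wrapper that raises RuntimeError at call time
    return gm

if is_compatible_with_meta():
    print('meta-tensor support available; experimental fallbacks are not needed')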

View File

@ -1,7 +1,10 @@
# meta patch from https://github.com/pytorch/pytorch/blob/master/torch/_meta_registrations.py
# should be activated for PyTorch version 1.12.0 and below
# refer to https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/native_functions.yaml
# for more meta_registrations
from typing import List, Optional, Tuple, Union
import torch
from torch.utils._pytree import tree_map
@ -31,6 +34,7 @@ def register_meta(op, register_dispatcher=True):
return wrapper
# ============================== Convolutions ======================================
# https://github.com/pytorch/pytorch/pull/79834
@register_meta(aten.convolution.default)
def meta_conv(
@ -165,6 +169,18 @@ def meta_conv_backward(grad_output: torch.Tensor, input: torch.Tensor, weight: t
return torch.empty_like(input), torch.empty_like(weight), torch.empty((bias_sizes), device='meta')
# https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/AdaptiveAveragePooling.cpp
@register_meta(aten._adaptive_avg_pool2d_backward.default)
def meta_adaptive_avg_pool2d_backward(
grad_output: torch.Tensor,
input: torch.Tensor,
):
grad_input = torch.empty_like(input)
return grad_input
# https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/Activation.cpp
# ============================== Activations =======================================
@register_meta(aten.relu.default)
def meta_relu(input: torch.Tensor):
return torch.empty_like(input)
@ -192,11 +208,8 @@ def meta_hardtanh_backward(grad_out: torch.Tensor, input: torch.Tensor, min_val:
return grad_in
@register_meta(aten.roll.default)
def meta_roll(input: torch.Tensor, shifts, dims):
return input
# ============================== Normalization =====================================
# https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/cudnn/BatchNorm.cpp
@register_meta(aten.native_batch_norm.default)
def meta_bn(input: torch.Tensor, weight, bias, running_mean, running_var, training, momentum, eps):
n_input = input.size(1)
@ -207,6 +220,7 @@ def meta_bn(input: torch.Tensor, weight, bias, running_mean, running_var, traini
return output, running_mean, running_var
# https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/cudnn/BatchNorm.cpp
@register_meta(aten.native_batch_norm_backward.default)
def meta_bn_backward(dY: torch.Tensor, input: torch.Tensor, weight: torch.Tensor, running_mean, running_var, save_mean,
save_invstd, train, eps, output_mask):
@ -241,6 +255,7 @@ def meta_cudnn_bn_backward(dY: torch.Tensor, input: torch.Tensor, weight: torch.
return dX, dgamma, dbeta
# https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/layer_norm.cpp
@register_meta(aten.native_layer_norm.default)
def meta_ln(input: torch.Tensor, normalized_shape, weight, bias, eps):
bs = input.size(0)
@ -252,6 +267,7 @@ def meta_ln(input: torch.Tensor, normalized_shape, weight, bias, eps):
return output, running_mean, running_var
# https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/layer_norm.cpp
@register_meta(aten.native_layer_norm_backward.default)
def meta_ln_backward(dY: torch.Tensor, input: torch.Tensor, normalized_shape, mean, rstd, weight, bias,
grad_input_mask):
@ -261,13 +277,18 @@ def meta_ln_backward(dY: torch.Tensor, input: torch.Tensor, normalized_shape, me
return dX, dgamma, dbeta
@register_meta(aten._adaptive_avg_pool2d_backward.default)
def meta_adaptive_avg_pool2d_backward(
grad_output: torch.Tensor,
input: torch.Tensor,
):
grad_input = torch.empty_like(input)
return torch.empty_like(input)
# ================================== Misc ==========================================
# https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/native_functions.yaml
@register_meta(aten.roll.default)
def meta_roll(input: torch.Tensor, shifts, dims):
return input
# https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/TensorCompare.cpp
@register_meta(aten.where.self)
def meta_where_self(condition: torch.Tensor, self: torch.Tensor, other: torch.Tensor):
result_type = torch.result_type(self, other)
return torch.empty_like(self, dtype=result_type)
@register_meta(aten.index.Tensor)
@ -360,6 +381,8 @@ def meta_index_Tensor(self, indices):
return self.new_empty(before_shape + replacement_shape + after_shape)
# ============================== Embedding =========================================
# https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/Embedding.cpp
@register_meta(aten.embedding_dense_backward.default)
def meta_embedding_dense_backward(grad_output: torch.Tensor, indices: torch.Tensor, num_weights, padding_idx,
scale_grad_by_freq):
@ -369,13 +392,7 @@ def meta_embedding_dense_backward(grad_output: torch.Tensor, indices: torch.Tens
layout=grad_output.layout)
# https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/TensorCompare.cpp
@register_meta(aten.where.self)
def meta_where_self(condition: torch.Tensor, self: torch.Tensor, other: torch.Tensor):
result_type = torch.result_type(self, other)
return torch.empty_like(self, dtype=result_type)
# ============================== Dropout ===========================================
# https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/Dropout.cpp
@register_meta(aten.native_dropout.default)
def meta_native_dropout_default(input: torch.Tensor, p: float, train: bool = False):
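
Every hunk in this file follows the same registration pattern: return an empty tensor on the meta device with the shape and dtype the real kernel would produce, so tracing propagates metadata without computing anything. A hedged sketch for one more elementwise op, written as if inside this file where `register_meta` and `aten` are in scope (`aten.gelu.default` is used purely as an example target, not one this commit registers):

@register_meta(aten.gelu.default)
def meta_gelu(input: torch.Tensor):
    # propagate shape and dtype only; nothing is computed on the meta device
    return torch.empty_like(input)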

View File

@ -1,16 +1,20 @@
from typing import List, Tuple
import copy
import torch
from torch.fx import GraphModule, Node
from colossalai.fx.graph_module import ColoGraphModule
from colossalai.fx.profiler import parameter_size
import math
from .linearize import linearize
from .operation import ForwardCheck, ForwardEnable, ForwardNograd, Backward, Loss, Chain, Sequence, Function, Offload, Prefetch
from typing import List, Tuple
import torch
from colossalai.fx import is_compatible_with_meta
from colossalai.fx.codegen.activation_checkpoint_codegen import \
_find_nested_ckpt_regions
from colossalai.fx.graph_module import ColoGraphModule
from colossalai.fx.passes.algorithms.ckpt_solver_rotor import (_compute_table, _construct_chain, _rec)
from colossalai.fx.passes.meta_info_prop import MetaInfoProp
from colossalai.fx.codegen.activation_checkpoint_codegen import _find_nested_ckpt_regions
from colossalai.fx.passes.algorithms.ckpt_solver_rotor import _construct_chain, _compute_table, _rec
from colossalai import META_COMPATIBILITY
from colossalai.fx.profiler import parameter_size
from torch.fx import GraphModule, Node
from .linearize import linearize
from .operation import (Backward, Chain, ForwardCheck, ForwardEnable, ForwardNograd, Function, Loss, Offload, Prefetch,
Sequence)
INF = float("inf")
@ -508,7 +512,7 @@ def solver_pofo(gm: ColoGraphModule,
mem_limit -= parameter_size(gm)
# prepare data
if META_COMPATIBILITY:
if is_compatible_with_meta():
from colossalai.fx.profiler import MetaTensor
data = MetaTensor(data, fake_device=next(gm.parameters()).device)
MetaInfoProp(gm).run(data)

View File

@ -1,13 +1,12 @@
from dataclasses import asdict
from colossalai.fx.profiler import GraphInfo
from typing import Any, Dict, List, NamedTuple, Optional, Tuple
import torch
import torch.fx
from torch.fx.node import Node, Argument, Target
from colossalai.fx._compatibility import compatibility
from colossalai.fx.profiler import (GraphInfo, profile_function, profile_method, profile_module)
from torch.fx.node import Argument, Node, Target
from torch.utils._pytree import tree_flatten
from typing import Any, List, Tuple, NamedTuple, Dict, Optional
from torch.fx._compatibility import compatibility
from colossalai.fx.profiler import profile_function, profile_module, profile_method, activation_size
from torch.fx.graph_module import GraphModule
@compatibility(is_backward_compatible=True)

View File

@ -1,11 +1,13 @@
from dataclasses import asdict
from typing import Any, Dict, List, NamedTuple, Tuple
import torch
import torch.fx
from torch.fx.node import Node, Argument, Target
from colossalai.fx._compatibility import compatibility
from colossalai.fx.profiler import (GraphInfo, activation_size, calculate_fwd_in, calculate_fwd_out, calculate_fwd_tmp,
profile_function, profile_method, profile_module)
from torch.fx.node import Argument, Node, Target
from torch.utils._pytree import tree_map
from typing import Any, List, Tuple, NamedTuple, Dict
from torch.fx._compatibility import compatibility
from colossalai.fx.profiler import GraphInfo, profile_function, profile_module, profile_method, activation_size, calculate_fwd_out, calculate_fwd_tmp, calculate_fwd_in
@compatibility(is_backward_compatible=True)

View File

@ -1,11 +1,12 @@
from ... import META_COMPATIBILITY
if META_COMPATIBILITY:
from .._compatibility import is_compatible_with_meta
if is_compatible_with_meta():
from .memory import calculate_fwd_in, calculate_fwd_out, calculate_fwd_tmp
from .opcount import flop_mapping
from .tensor import MetaTensor
from .profiler import profile_function, profile_method, profile_module
from .memory import calculate_fwd_in, calculate_fwd_tmp, calculate_fwd_out
from .tensor import MetaTensor
else:
from .experimental import meta_profiler_function, meta_profiler_module, profile_function, profile_method, profile_module, calculate_fwd_in, calculate_fwd_tmp, calculate_fwd_out
from .dataflow import GraphInfo
from .memory import parameter_size, activation_size, is_inplace
from .memory import activation_size, is_inplace, parameter_size

View File

@ -1,79 +0,0 @@
import torch
from operator import add, floordiv, getitem, mul, neg, setitem, sub, pos
from . import META_COMPATIBILITY
__all__ = []
if META_COMPATIBILITY:
aten = torch.ops.aten
ALIAS_ATEN = [
# inplace reshaping
aten.detach.default,
aten.t.default,
aten.transpose.int,
aten.view.default,
aten._unsafe_view.default,
aten._reshape_alias.default,
]
INPLACE_NEW = [
aten.empty_like.default,
aten.new_empty_strided.default,
]
INPLACE_MATH_ATEN = [
aten.add_.Tensor,
aten.sub_.Tensor,
aten.div_.Tensor,
aten.div_.Scalar,
aten.mul_.Tensor,
aten.bernoulli_.float,
]
CLONE_ATEN = [
aten.clone.default,
]
__all__ += ['INPLACE_ATEN', 'INPLACE_MATH_ATEN', 'CLONE_ATEN']
else:
# TODO fill out the inplace ops
INPLACE_OPS = [
add,
sub,
mul,
floordiv,
neg,
pos,
getitem,
setitem,
getattr,
torch.Tensor.cpu,
]
# TODO: list all call_methods that are inplace here
INPLACE_METHOD = [
'transpose',
'permute',
# TODO: reshape may return a copy of the data if the data is not contiguous
'reshape',
'dim',
'flatten',
'size',
'view',
'unsqueeze',
'to',
'type',
'flatten',
]
# TODO: list all call_methods that are not inplace here
NON_INPLACE_METHOD = [
'chunk',
'contiguous',
'expand',
'mean',
'split',
]
__all__ += ['INPLACE_OPS', 'INPLACE_METHOD', 'NON_INPLACE_METHOD']

View File

@ -0,0 +1,32 @@
import torch
__all__ = ['ALIAS_ATEN', 'INPLACE_NEW', 'INPLACE_MATH_ATEN', 'CLONE_ATEN']
aten = torch.ops.aten
ALIAS_ATEN = [
aten.detach.default,
aten.t.default,
aten.transpose.int,
aten.view.default,
aten._unsafe_view.default,
aten._reshape_alias.default,
]
INPLACE_NEW = [
aten.empty_like.default,
aten.new_empty_strided.default,
]
INPLACE_MATH_ATEN = [
aten.add_.Tensor,
aten.sub_.Tensor,
aten.div_.Tensor,
aten.div_.Scalar,
aten.mul_.Tensor,
aten.bernoulli_.float,
]
CLONE_ATEN = [
aten.clone.default,
]
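
These tables are consumed by the profiler to decide whether an `aten` call merely aliases its input; a small sketch of that check, mirroring the `is_inplace` change further down in this commit (the helper name is hypothetical, and the import path assumes the new `colossalai.fx.profiler.constants` module this file creates):

from colossalai.fx.profiler.constants import ALIAS_ATEN

def aliases_input(n) -> bool:
    # view-like targets (detach, t, transpose, view, ...) return aliases of
    # their input, so their outputs should not count as fresh activation memory
    return n.op == 'call_function' and n.target in ALIAS_ATEN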

View File

@ -2,7 +2,10 @@ from dataclasses import dataclass, field
from enum import Enum
from functools import partial
from typing import Dict, List
from torch.fx import Graph, Node
from .._compatibility import compatibility
from .memory import activation_size, is_inplace
@ -12,6 +15,7 @@ class Phase(Enum):
PLACEHOLDER = 2
@compatibility(is_backward_compatible=True)
@dataclass
class GraphInfo:
"""
@ -69,6 +73,7 @@ def is_phase(n: Node, phase: Phase) -> bool:
return n.meta['phase'] == phase
@compatibility(is_backward_compatible=False)
def autograd_graph_analysis(graph: Graph) -> GraphInfo:
"""Analyze the autograd node dependencies and find out the memory usage.
Basically the input graph should have all nodes marked for keyword `phase`.

View File

@ -1,5 +1,5 @@
from .registry import meta_profiler_function, meta_profiler_module
from .memory import calculate_fwd_in, calculate_fwd_tmp, calculate_fwd_out
from .memory import calculate_fwd_in, calculate_fwd_out, calculate_fwd_tmp
from .profiler import profile_function, profile_method, profile_module
from .profiler_function import *
from .profiler_module import *
from .profiler import profile_function, profile_method, profile_module
from .registry import meta_profiler_function, meta_profiler_module

View File

@ -0,0 +1,44 @@
from operator import add, floordiv, getitem, mul, neg, pos, setitem, sub
import torch
__all__ = ['INPLACE_OPS', 'INPLACE_METHOD', 'NON_INPLACE_METHOD']
# TODO fill out the inplace ops
INPLACE_OPS = [
add,
sub,
mul,
floordiv,
neg,
pos,
getitem,
setitem,
getattr,
torch.Tensor.cpu,
]
# TODO: list all call_methods that are inplace here
INPLACE_METHOD = [
'transpose',
'permute',
# TODO: reshape may return a copy of the data if the data is not contiguous
'reshape',
'dim',
'flatten',
'size',
'view',
'unsqueeze',
'to',
'type',
]
# TODO: list all call_methods that are not inplace here
NON_INPLACE_METHOD = [
'chunk',
'contiguous',
'expand',
'mean',
'split',
]

View File

@ -1,11 +1,15 @@
# for PyTorch 1.11 compatibility uses
from typing import Dict, List, Tuple, Union
import torch
from torch.fx import Node, GraphModule
from typing import Union, Dict, List, Tuple
from torch.fx import GraphModule, Node
from ..._compatibility import compatibility
__all__ = ["calculate_fwd_in", "calculate_fwd_tmp", "calculate_fwd_out"]
@compatibility(is_backward_compatible=True)
def calculate_fwd_in(n: Node) -> bool:
"""A helper function to calculate `fwd_in`
@ -18,6 +22,7 @@ def calculate_fwd_in(n: Node) -> bool:
return n.meta['save_fwd_in']
@compatibility(is_backward_compatible=True)
def calculate_fwd_tmp(n: Node) -> int:
"""A helper function to calculate `fwd_tmp`
@ -30,6 +35,7 @@ def calculate_fwd_tmp(n: Node) -> int:
return n.meta["fwd_mem_tmp"]
@compatibility(is_backward_compatible=True)
def calculate_fwd_out(n: Node) -> int:
"""A helper function to calculate `fwd_out`

View File

@ -1,15 +1,19 @@
from dataclasses import dataclass
from typing import Callable, Any, Dict, Tuple
from typing import Any, Callable, Dict, Tuple
import torch
from torch.fx.node import Argument, Target
from . import meta_profiler_function, meta_profiler_module
from ..._compatibility import compatibility
from ..memory import activation_size
from ..constant import INPLACE_METHOD, NON_INPLACE_METHOD, INPLACE_OPS
from .constants import INPLACE_METHOD, INPLACE_OPS, NON_INPLACE_METHOD
from .registry import meta_profiler_function, meta_profiler_module
__all__ = ['profile_function', 'profile_module', 'profile_method']
# this is for compatibility use
@compatibility(is_backward_compatible=True)
@dataclass
class GraphInfo:
"""
@ -69,6 +73,7 @@ def profile_YOUR_MODULE(self: torch.nn.Module, input: torch.Tensor) -> Tuple[int
"""
@compatibility(is_backward_compatible=True)
def profile_function(target: 'Target') -> Callable:
"""
Wrap a `call_function` node or `torch.nn.functional` in order to
@ -106,6 +111,7 @@ def profile_function(target: 'Target') -> Callable:
return f
@compatibility(is_backward_compatible=True)
def profile_method(target: 'Target') -> Callable:
"""
Wrap a `call_method` node
@ -133,6 +139,7 @@ def profile_method(target: 'Target') -> Callable:
return f
@compatibility(is_backward_compatible=True)
def profile_module(module: torch.nn.Module) -> Callable:
"""
Wrap a `call_module` node or `torch.nn` in order to

View File

@ -2,7 +2,6 @@ import operator
from typing import Any, Tuple
import torch
from ..registry import meta_profiler_function
from colossalai.fx.proxy import ColoProxy
@meta_profiler_function.register(operator.getitem)

View File

@ -1,13 +1,16 @@
from typing import Dict, List, Tuple, Union
import torch
from torch.fx import Node, GraphModule
from typing import Union, Dict, List, Tuple
from . import META_COMPATIBILITY
from torch.fx import GraphModule, Node
from .._compatibility import compatibility, is_compatible_with_meta
__all__ = [
'activation_size', 'parameter_size', 'is_inplace', "calculate_fwd_in", "calculate_fwd_tmp", "calculate_fwd_out"
]
@compatibility(is_backward_compatible=True)
def activation_size(out: Union[torch.Tensor, Dict, List, Tuple, int]) -> int:
"""Calculate activation size of a node.
@ -29,6 +32,7 @@ def activation_size(out: Union[torch.Tensor, Dict, List, Tuple, int]) -> int:
return act_size
@compatibility(is_backward_compatible=True)
def parameter_size(mod: torch.nn.Module) -> int:
"""Calculate parameter size of a node.
@ -111,8 +115,8 @@ def is_inplace(n: Node):
inplace = False
if n.op == "call_function":
inplace = n.kwargs.get("inplace", False)
if META_COMPATIBILITY:
from .constant import ALIAS_ATEN
if is_compatible_with_meta():
from .constants import ALIAS_ATEN
if n.target in ALIAS_ATEN:
inplace = True
elif n.op == "call_module":

View File

@ -1,10 +1,11 @@
# adopted from https://github.com/facebookresearch/fvcore/blob/main/fvcore/nn/jit_handles.py
# ideas from https://pastebin.com/AkvAyJBw
from functools import partial, reduce
import operator
from typing import Callable, List, Any
from functools import partial, reduce
from numbers import Number
from typing import Any, Callable, List
import torch
aten = torch.ops.aten

View File

@ -1,16 +1,19 @@
import time
from functools import partial
from typing import Callable, Any, Dict, Tuple
from typing import Any, Callable, Dict, Tuple
import torch
from torch.nn.parameter import Parameter
from torch.fx import Graph, Node
from torch.fx.node import Argument, Target
from torch.nn.parameter import Parameter
from torch.utils._pytree import tree_map
from .dataflow import autograd_graph_analysis, is_phase, Phase, GraphInfo
from .._compatibility import compatibility
from .constants import ALIAS_ATEN
from .dataflow import GraphInfo, Phase, autograd_graph_analysis, is_phase
from .memory import activation_size, parameter_size
from .constant import ALIAS_ATEN
from .tensor import MetaTensor
from .opcount import flop_mapping
import time
from .tensor import MetaTensor
__all__ = ['profile_function', 'profile_module', 'profile_method']
@ -41,6 +44,7 @@ def detach_variables(x):
return x
@compatibility(is_backward_compatible=True)
def _profile_concrete(target: Callable, *args, **kwargs) -> Tuple[Tuple[Any, ...], GraphInfo]:
"""Profile a Callable function with args and kwargs on concrete devices by https://github.com/Cypher30
To profile the actual forward memory, we first run target in the context torch.no_grad() to get
@ -140,6 +144,7 @@ def _profile_concrete(target: Callable, *args, **kwargs) -> Tuple[Tuple[Any, ...
return tree_map(detach_variables, out), graphinfo
@compatibility(is_backward_compatible=False)
def _profile_meta(target: Callable, *args, **kwargs) -> Tuple[Tuple[Any, ...], GraphInfo]:
"""
Profile a Callable function with args and kwargs on meta devices.
@ -277,6 +282,7 @@ def _profile_meta(target: Callable, *args, **kwargs) -> Tuple[Tuple[Any, ...], G
return tree_map(unwrap, out), graph_info
@compatibility(is_backward_compatible=True)
def profile_function(target: 'Target', device: str = 'meta') -> Callable:
"""
Wrap a `call_function` node or `torch.nn.functional` in order to
@ -335,6 +341,7 @@ def profile_function(target: 'Target', device: str = 'meta') -> Callable:
return f
@compatibility(is_backward_compatible=True)
def profile_method(target: 'Target', device: str = 'meta') -> Callable:
"""
Wrap a `call_method` node
@ -353,6 +360,7 @@ def profile_method(target: 'Target', device: str = 'meta') -> Callable:
return f
@compatibility(is_backward_compatible=True)
def profile_module(module: torch.nn.Module, device: str = 'meta') -> Callable:
"""
Wrap a `call_module` node or `torch.nn` in order to

View File

@ -1,10 +1,13 @@
import uuid
from copy import deepcopy
from typing import Optional
import torch
from torch.utils._pytree import tree_map, tree_flatten
from torch.types import _bool, _dtype, _device
import uuid
from .constant import ALIAS_ATEN
from torch.types import _bool, _device, _dtype
from torch.utils._pytree import tree_flatten, tree_map
from .._compatibility import compatibility
from .constants import ALIAS_ATEN
__all__ = ['MetaTensor']
@ -15,6 +18,7 @@ def set_uuid(x):
setattr(x, 'uuid', uuid.uuid4())
@compatibility(is_backward_compatible=False)
class MetaTensor(torch.Tensor):
"""
A wrapping tensor that hacks `torch.autograd` without patching more `torch.ops.aten` ops.
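
The test files later in this commit show the intended usage, which makes a reasonable sketch: wrap an input living on the meta device, record its real placement with `fake_device`, and run the model so autograd executes end to end without allocating real memory (the model and input shape below follow the torchvision test pattern and are illustrative only):

import torch
import torchvision.models as tm
from colossalai.fx.profiler import MetaTensor

model = tm.resnet18()
data = torch.rand(2, 3, 224, 224, device='meta')
# fake_device records where the tensor would really live
model(MetaTensor(data, fake_device='cpu')).sum().backward()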

View File

@ -1,23 +1,19 @@
import threading
from enum import Enum
from typing import List, Any, Tuple, Dict, Callable
from functools import partial
from abc import ABC, abstractmethod
import math
import inspect
import math
import threading
from abc import ABC, abstractmethod
from enum import Enum
from functools import partial
from typing import Any, Callable, Dict, List, Tuple
import torch
from torch import nn
import torch.distributed.rpc as rpc
from torch.futures import Future
from torch._C._distributed_rpc import PyRRef
from torch import autograd
from torch import optim
from colossalai.pipeline.pipeline_process_group import ppg
from colossalai.pipeline.rpc.utils import (color_debug, tensor_shape_list, get_batch_lengths, split_batch, type_detail,
pytree_map, pytree_filter, get_real_args_kwargs, use_color_debug)
from colossalai.pipeline.rpc.utils import (get_batch_lengths, get_real_args_kwargs, pytree_filter, pytree_map,
split_batch, tensor_shape_list, type_detail)
from torch import autograd, nn, optim
from torch._C._distributed_rpc import PyRRef
from torch.futures import Future
class Phase(Enum):
@ -195,7 +191,6 @@ class WorkerBase(ABC):
if isinstance(output, Future):
output = output.wait()
# color_debug(f'rank {self.pp_rank}, output {type(output)}', 'get output', 'red')
output_work_item.refcount += 1
# all consumers have been satisfied, the work_item can be released
@ -250,9 +245,6 @@ class WorkerBase(ABC):
self.num_microbatches, forward_only)
with self.work_list_condition_lock:
self.work_list[key] = work_item
if use_color_debug:
color_debug(f'rank {self.pp_rank} receive data from dataloader {self._get_store_len()}',
'data dispatch', 'magenta')
self.work_list_condition_lock.notify_all()
# just for last pp_rank
@ -273,9 +265,6 @@ class WorkerBase(ABC):
work_item = WorkItem(self.pp_rank, Phase.BACKWARD, grad_wrt_loss, {}, output, microbatch_id, None,
self.num_microbatches, False)
if use_color_debug:
color_debug(f'rank {self.pp_rank} propose backward', 'data dispatch', 'magenta')
self.work_list[key] = work_item
self.work_list_condition_lock.notify_all()
@ -297,23 +286,14 @@ class WorkerBase(ABC):
producer_worker_rref = self.pp_rank_to_worker_rref[producer_stage_id]
subscribe_forward_futures[i] = producer_worker_rref.rpc_async().get_output_by_key(producer_output_key)
if use_color_debug:
color_debug(f'rank {self.pp_rank} get {len(subscribe_forward_futures)} futs from its producer',
'data dispatch', 'magenta')
work_item_from_producer = WorkItem(stage_id, Phase.FORWARD, subscribe_forward_futures, {}, output,
microbatch_id, None, self.num_microbatches, forward_only)
# color_debug(f'rank {self.pp_rank} get value {tensor_shape_list(args)} from fut', 'data dispatch', 'magenta')
# add work_item to work_list
with self.work_list_condition_lock:
key = UniqueKey(microbatch_id, Phase.FORWARD)
assert key not in self.work_list
self.work_list[key] = work_item_from_producer
if use_color_debug:
color_debug(
f'rank_{self.pp_rank} load a new task to its work_list {key} {work_item_from_producer.phase} data: {tensor_shape_list(work_item_from_producer.args)}',
'data dispatch', 'magenta')
self.work_list_condition_lock.notify_all()
def subscribe_consumer(self, microbatch_id: int):
@ -328,10 +308,6 @@ class WorkerBase(ABC):
subscribe_backward_futures: List[Future] = [None] * consumer_num
output = self._get_future_by_device()
if use_color_debug:
color_debug(f'rank {self.pp_rank} get {len(subscribe_backward_futures)} futs from its consumer',
'data dispatch', 'magenta')
for i in range(consumer_num):
consumer_stage_id = self.consumer_stage_ids[i]
consumer_output_key = UniqueKey(microbatch_id, Phase.BACKWARD)
@ -342,17 +318,11 @@ class WorkerBase(ABC):
work_item_from_consumer = WorkItem(stage_id, Phase.BACKWARD, subscribe_backward_futures, {}, output,
microbatch_id, None, self.num_microbatches, False)
# color_debug(f'rank {self.pp_rank} get value {tensor_shape_list(args)} from fut', 'data dispatch', 'magenta')
# add work_item to work_list
with self.work_list_condition_lock:
key = UniqueKey(microbatch_id, Phase.BACKWARD)
assert key not in self.work_list
self.work_list[key] = work_item_from_consumer
if use_color_debug:
color_debug(
f'rank_{self.pp_rank} load a new task to its work_list {key} {work_item_from_consumer.phase} data: {tensor_shape_list(work_item_from_consumer.args)}',
'data dispatch', 'magenta')
self.work_list_condition_lock.notify_all()
def _get_producer_consumer(self) -> None:
@ -406,11 +376,6 @@ class WorkerBase(ABC):
is_first_stage = self.is_first_stage()
is_last_stage = self.is_last_stage()
# if self.pp_rank == 0:
# print(
# f'I am rank_{self.pp_rank} microbatch_id : {microbatch_id} {phase} {self._get_store_len()} | {self.outstanding} {self.outstanding_range}'
# )
if phase == Phase.FORWARD:
# remind its consumer to get data before forward
if not is_last_stage:
@ -470,8 +435,6 @@ class WorkerBase(ABC):
else:
consume_result = self.module_partition(*args, **kwargs)
# print(f'model{self.pp_rank + 1}(param_sum: {sum([p.sum().item() for p in self.module_partition.parameters()])}) input sum: {args[0].sum().item()} forward output sum: {consume_result.sum().item()}', )
if is_last_stage and self.criterion:
with self.label_lock:
self.label_lock.wait_for(lambda: microbatch_id in self.microbatch_id_to_labels)
@ -539,10 +502,6 @@ class WorkerBase(ABC):
pytree_map(stage_input_args, lambda x: consume_result.append(x.grad), process_types=torch.Tensor)
pytree_map(stage_input_kwargs, lambda x: consume_result.append(x.grad), process_types=torch.Tensor)
# for input_node in stage_input_args:
# if isinstance(input_node, torch.Tensor):
# consume_result.append(input_node.grad)
else:
raise TypeError(f"Unknown phase appears in _consume_work_item_by_phase {phase}")
@ -593,11 +552,6 @@ class WorkerBase(ABC):
with self.work_list_condition_lock:
work_item = self.work_list.pop(work_item_key)
if use_color_debug:
color_debug(
f'rank {self.pp_rank} get a key : {work_item_key} work_item args: {tensor_shape_list(work_item.args)} {self._get_store_len()}',
'work loop', 'green')
with self.output_list_condition_lock:
# assert work_item_key not in self.output_list
self.output_list[work_item_key] = work_item
@ -605,11 +559,6 @@ class WorkerBase(ABC):
consume_result = self._consume_work_item_by_phase(work_item)
if use_color_debug:
color_debug(
f'rank_{self.pp_rank} [{work_item.phase}] finish consuming, result is {tensor_shape_list(consume_result)} {self._get_store_len()} | {self.work_list.keys()} | {self.output_list.keys()}',
'work loop', 'green')
work_item.output.set_result(consume_result)
# if is last step in one batch reset context and do step

View File

@ -1,13 +1,12 @@
from typing import List, Callable, Dict
import threading
from typing import Callable, Dict, List
import torch
import torch.distributed as dist
from torch.futures import Future
from torch._C._distributed_rpc import PyRRef
from colossalai.pipeline.rpc._pipeline_base import PipelineEngineBase, WorkerBase, UniqueKey, Phase, WorkItem
from colossalai.pipeline.pipeline_process_group import ppg
from colossalai.pipeline.rpc._pipeline_base import (Phase, PipelineEngineBase, UniqueKey, WorkerBase, WorkItem)
from torch._C._distributed_rpc import PyRRef
from torch.futures import Future
# Implementation of different Pipeline schedule
# <strategy>Worker defines the worker for each stage

View File

@ -1,25 +1,15 @@
from typing import List, Any, Tuple, Dict, Callable, Type, Union
import argparse
import os
import warnings
import argparse
from typing import Any, Callable, Dict, List, Tuple, Type, Union
import torch
import torch.multiprocessing as mp
from torch.futures import Future
import torch.distributed.rpc as rpc
from torch._C._distributed_rpc import _is_current_rpc_agent_set
from colorama import Back, Style
import torch.multiprocessing as mp
from colossalai.initialize import launch
from colossalai.pipeline.pipeline_process_group import ppg
# config for debug and test
use_color_debug = False
def color_debug(text, prefix=' ', color='blue'):
color = color.upper()
print(getattr(Back, color), prefix, Style.RESET_ALL, text)
from torch._C._distributed_rpc import _is_current_rpc_agent_set
from torch.futures import Future
def pytree_map(obj: Any, fn: Callable, process_types: Union[Type, Tuple[Type]] = (), map_all: bool = False) -> Any:
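
Its call sites in `_pipeline_base.py` above show the intent: walk an arbitrarily nested structure and apply `fn` only to leaves of the given types. A small example grounded on the backward-pass call in this commit (the sample structure is made up):

import torch
from colossalai.pipeline.rpc.utils import pytree_map

args = ([torch.ones(2, requires_grad=True)], {'labels': 3})
tensors = []
# only torch.Tensor leaves are passed to fn; other values are just traversed
pytree_map(args, lambda t: tensors.append(t), process_types=torch.Tensor)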

View File

@ -1,18 +1,20 @@
import copy
import colossalai
import pytest
import torch
import torch.fx
import torch.multiprocessing as mp
import torchvision.models as tm
import torch.fx
import colossalai
from colossalai.fx.passes.meta_info_prop import MetaInfoProp
from colossalai.core import global_context as gpc
from colossalai.fx import ColoGraphModule, ColoTracer
from colossalai.fx._compatibility import is_compatible_with_meta
from colossalai.fx.passes.algorithms import solver_rotor
from colossalai.fx.passes.algorithms.operation import Sequence
from colossalai.core import global_context as gpc
from colossalai.fx.passes.meta_info_prop import MetaInfoProp
from colossalai.utils import free_port
import pytest
from colossalai import META_COMPATIBILITY
if META_COMPATIBILITY:
if is_compatible_with_meta():
from colossalai.fx.profiler.tensor import MetaTensor
try:
@ -34,7 +36,7 @@ def _run_C_solver_consistency_test(rank=0):
graph = tracer.trace(model, meta_args={"x": data})
graph.set_codegen(ActivationCheckpointCodeGen())
gm = ColoGraphModule(model, graph, model.__class__.__name__)
if META_COMPATIBILITY:
if is_compatible_with_meta():
data_meta = MetaTensor(data, fake_device=next(gm.parameters()).device)
MetaInfoProp(gm).run(data_meta)

View File

@ -1,20 +1,22 @@
from typing import Callable
import copy
import re
from typing import Callable
import colossalai
import pytest
import torch
import torch.multiprocessing as mp
import torchvision.models as tm
from torch.fx import GraphModule
import colossalai
from colossalai.fx import ColoTracer
from colossalai.fx.graph_module import ColoGraphModule
from colossalai.fx.passes.meta_info_prop import MetaInfoProp
from colossalai.fx.passes.algorithms import chen_greedy, solver_rotor
from colossalai.utils import free_port
from colossalai.core import global_context as gpc
import pytest
from colossalai import META_COMPATIBILITY
if META_COMPATIBILITY:
from colossalai.fx import ColoTracer
from colossalai.fx._compatibility import is_compatible_with_meta
from colossalai.fx.graph_module import ColoGraphModule
from colossalai.fx.passes.algorithms import chen_greedy, solver_rotor
from colossalai.fx.passes.meta_info_prop import MetaInfoProp
from colossalai.utils import free_port
from torch.fx import GraphModule
if is_compatible_with_meta():
from colossalai.fx.profiler.tensor import MetaTensor
try:
@ -54,8 +56,9 @@ def _is_graph_linearized(gm: GraphModule):
def check_backward_consistency(m: torch.nn.Module, gm: GraphModule, solver: Callable[[GraphModule], GraphModule],
model_cls: Callable[[], torch.nn.Module]):
criterion = torch.nn.MSELoss()
data = torch.rand(2, 3, 32, 32)
label = torch.rand(2, 5)
m.cuda()
data = torch.rand(2, 3, 32, 32).cuda()
label = torch.rand(2, 5).cuda()
loss = criterion(m(data), label)
loss.backward()
loss = criterion(gm(data), label)
@ -77,7 +80,7 @@ def _run_ckpt_solver(rank):
m = model_cls(num_classes=5)
graph = tracer.trace(root=m)
gm = ColoGraphModule(copy.deepcopy(m), graph, m.__class__.__name__)
MetaInfoProp(gm.cuda()).run(MetaTensor(data, fake_device='cuda'))
MetaInfoProp(gm.cuda()).run(MetaTensor(data).cuda())
codegen = ActivationCheckpointCodeGen()
gm.graph.set_codegen(codegen)
if solver == solver_rotor:

View File

@ -1,13 +1,14 @@
from colossalai.fx.passes.meta_info_prop import MetaInfoProp
import pytest
import torch
import torchvision.models as tm
from colossalai.fx import ColoTracer
from colossalai.fx._compatibility import is_compatible_with_meta
from colossalai.fx.graph_module import ColoGraphModule
from colossalai.fx.passes.algorithms import solver_rotor, linearize
from colossalai.fx.passes.algorithms.operation import Loss, ForwardCheck, ForwardEnable, ForwardNograd
import pytest
from colossalai import META_COMPATIBILITY
if META_COMPATIBILITY:
from colossalai.fx.passes.algorithms import linearize, solver_rotor
from colossalai.fx.passes.algorithms.operation import (ForwardCheck, ForwardEnable, ForwardNograd, Loss)
from colossalai.fx.passes.meta_info_prop import MetaInfoProp
if is_compatible_with_meta():
from colossalai.fx.profiler.tensor import MetaTensor
try:

View File

@ -1,13 +1,17 @@
import torch
import torch.nn as nn
import colossalai
import colossalai.nn as col_nn
from torch.fx import symbolic_trace
from colossalai.fx.passes.meta_info_prop import MetaInfoProp
from colossalai.fx.passes.adding_split_node_pass import split_with_split_nodes_pass, uniform_split_pass
from colossalai.fx.passes.utils import get_comm_size
from colossalai import META_COMPATIBILITY
import pytest
import torch
import torch.nn as nn
from colossalai.fx._compatibility import is_compatible_with_meta
from colossalai.fx.passes.adding_split_node_pass import (split_with_split_nodes_pass, uniform_split_pass)
from colossalai.fx.passes.meta_info_prop import MetaInfoProp
from colossalai.fx.passes.utils import get_comm_size
from torch.fx import symbolic_trace
is_compatible = is_compatible_with_meta()
if is_compatible:
from colossalai.fx.profiler import MetaTensor
MODEL_DIM = 16
BATCH_SIZE = 8
@ -31,12 +35,12 @@ class MLP(torch.nn.Module):
return x
@pytest.mark.skipif(not META_COMPATIBILITY, reason='torch version is lower than 1.12.0')
def test_comm_size_compute():
from colossalai.fx.profiler import MetaTensor
model = MLP(MODEL_DIM)
input_sample = MetaTensor(torch.rand(BATCH_SIZE, MODEL_DIM, device='meta'), fake_device='cpu')
input_sample = torch.rand(BATCH_SIZE, MODEL_DIM, device='meta')
gm = symbolic_trace(model)
if is_compatible:
input_sample = MetaTensor(input_sample, fake_device=next(gm.parameters()).device)
MetaInfoProp(gm).run(input_sample)
annotated_model = uniform_split_pass(gm, PIPELINE_SIZE)
split_model, split_submodules = split_with_split_nodes_pass(annotated_model)

View File

@ -1,12 +1,11 @@
from typing import Any, Callable, Union
import torch
import torch.nn as nn
import torch.nn.functional as F
from colossalai import META_COMPATIBILITY
import pytest
import torch
import torch.nn as nn
from colossalai.fx._compatibility import is_compatible_with_meta
if META_COMPATIBILITY:
if is_compatible_with_meta():
from colossalai.fx.profiler import MetaTensor
aten = torch.ops.aten
@ -71,7 +70,7 @@ def run_and_compare(f: Union[nn.Module, Callable], x: torch.Tensor, requires_bac
compare_all(x.grad, meta_x.grad)
@pytest.mark.skipif(not META_COMPATIBILITY, reason='torch version is lower than 1.12.0')
@pytest.mark.skipif(not is_compatible_with_meta(), reason='torch version is lower than 1.12.0')
def test_meta_aten():
for (aten_op, requires_backward), v in registered_meta.items():
for f, x in v:

View File

@ -1,10 +1,10 @@
import torchvision.models as tm
import pytest
import timm.models as tmm
import torch
from colossalai import META_COMPATIBILITY
import pytest
import torchvision.models as tm
from colossalai.fx._compatibility import is_compatible_with_meta
if META_COMPATIBILITY:
if is_compatible_with_meta():
from colossalai.fx.profiler import MetaTensor
tm_models = [
@ -27,7 +27,7 @@ tmm_models = [
]
@pytest.mark.skipif(not META_COMPATIBILITY, reason='torch version is lower than 1.12.0')
@pytest.mark.skipif(not is_compatible_with_meta(), reason='torch version is lower than 1.12.0')
def test_torchvision_models():
for m in tm_models:
model = m()
@ -35,7 +35,7 @@ def test_torchvision_models():
model(MetaTensor(data, fake_device=torch.device('cpu'))).sum().backward()
@pytest.mark.skipif(not META_COMPATIBILITY, reason='torch version is lower than 1.12.0')
@pytest.mark.skipif(not is_compatible_with_meta(), reason='torch version is lower than 1.12.0')
def test_timm_models():
for m in tmm_models:
model = m()

View File

@ -1,10 +1,10 @@
import torchvision.models as tm
import pytest
import timm.models as tmm
import torch
from colossalai import META_COMPATIBILITY
import pytest
import torchvision.models as tm
from colossalai.fx._compatibility import is_compatible_with_meta
if META_COMPATIBILITY:
if is_compatible_with_meta():
from colossalai.fx import meta_trace
tm_models = [
@ -27,7 +27,7 @@ tmm_models = [
]
@pytest.mark.skipif(not META_COMPATIBILITY, reason='torch version is lower than 1.12.0')
@pytest.mark.skipif(not is_compatible_with_meta(), reason='torch version is lower than 1.12.0')
def test_torchvision_models_trace():
for m in tm_models:
model = m()
@ -35,7 +35,7 @@ def test_torchvision_models_trace():
graph = meta_trace(model, torch.device('cpu'), data)
@pytest.mark.skipif(not META_COMPATIBILITY, reason='torch version is lower than 1.12.0')
@pytest.mark.skipif(not is_compatible_with_meta(), reason='torch version is lower than 1.12.0')
def test_timm_models_trace():
for m in tmm_models:
model = m()

View File

@ -1,7 +1,10 @@
import torch
from torch.fx import symbolic_trace
from colossalai import META_COMPATIBILITY
from colossalai.fx._compatibility import is_compatible_with_meta
from colossalai.fx.passes.meta_info_prop import MetaInfoProp, TensorMetadata
from torch.fx import symbolic_trace
if is_compatible_with_meta():
from colossalai.fx.profiler import MetaTensor
BATCH_SIZE = 2
DIM_IN = 4
@ -18,8 +21,7 @@ def meta_check(meta_info_spec: TensorMetadata, orig_tensor: torch.Tensor):
def test_meta_info_prop():
model = torch.nn.Linear(DIM_IN, DIM_OUT)
input_sample = torch.rand(BATCH_SIZE, DIM_IN, device='meta')
if META_COMPATIBILITY:
from colossalai.fx.profiler import MetaTensor
if is_compatible_with_meta():
input_sample = MetaTensor(input_sample, fake_device='cpu')
orig_output = model(input_sample)
gm = symbolic_trace(model)

View File

@ -1,19 +1,17 @@
import os
import argparse
import os
import warnings
import torch
from torch import nn
import torch.multiprocessing as mp
import torch.distributed.rpc as rpc
from torch.optim import SGD, Adam, RMSprop, Optimizer
from torch._C._distributed_rpc import _is_current_rpc_agent_set
import torch.distributed as dist
from colorama import Back, Style
from colossalai.pipeline.pipeline_process_group import ppg
from colossalai.logging import disable_existing_loggers
import torch.distributed.rpc as rpc
import torch.multiprocessing as mp
from colossalai import launch
from colossalai.logging import disable_existing_loggers
from colossalai.pipeline.pipeline_process_group import ppg
from torch import nn
from torch._C._distributed_rpc import _is_current_rpc_agent_set
from torch.optim import SGD, Adam, Optimizer, RMSprop
rpc_is_initialized = _is_current_rpc_agent_set