mirror of https://github.com/hpcaitech/ColossalAI
[fx/meta/rpc] move _meta_registration.py to fx folder / register fx functions with compatibility checks / remove color debug (#1710)
* [fx] move meta registration * [fx] fix tests. * [fx] fix test. * [fx] fix. * [meta] refactor meta registration.py. * [fx] add compatibility descriptions. * [fx] polish import. * [fx] add a decorator. * [fx] fix tests. * [fx] remove print. * [fx] edit raise error. * [fx] edit raise error. * [fx] add type hint. * [fx] fix import in experimental. * [rpc] remove color debug. * [meta] fix naming.pull/1724/head
parent
e8d8eda5e7
commit
393f594051
|
@ -1,10 +1,3 @@
|
|||
try:
|
||||
from . import _meta_registrations
|
||||
META_COMPATIBILITY = True
|
||||
except:
|
||||
import torch
|
||||
META_COMPATIBILITY = False
|
||||
print(f'_meta_registrations seems to be incompatible with PyTorch {torch.__version__}.')
|
||||
from .initialize import (initialize, launch, launch_from_openmpi, launch_from_slurm, launch_from_torch,
|
||||
get_default_parser)
|
||||
|
||||
|
|
|
@ -1,3 +1,4 @@
|
|||
from .tracer import ColoTracer, meta_trace
|
||||
from ._compatibility import compatibility, is_compatible_with_meta
|
||||
from .graph_module import ColoGraphModule
|
||||
from .passes import MetaInfoProp
|
||||
from .tracer import ColoTracer, meta_trace
|
||||
|
|
|
@ -0,0 +1,46 @@
|
|||
from typing import Callable
|
||||
|
||||
import torch
|
||||
|
||||
try:
|
||||
from . import _meta_registrations
|
||||
META_COMPATIBILITY = True
|
||||
except:
|
||||
META_COMPATIBILITY = False
|
||||
|
||||
|
||||
def compatibility(is_backward_compatible: bool = False) -> Callable:
|
||||
"""A decorator to make a function compatible with different versions of PyTorch.
|
||||
|
||||
Args:
|
||||
is_backward_compatible (bool, optional): Whether the function is backward compatible. Defaults to False.
|
||||
|
||||
Returns:
|
||||
Callable: The decorated function
|
||||
"""
|
||||
|
||||
def decorator(func):
|
||||
if META_COMPATIBILITY:
|
||||
return func
|
||||
else:
|
||||
if is_backward_compatible:
|
||||
return func
|
||||
else:
|
||||
|
||||
def wrapper(*args, **kwargs):
|
||||
raise RuntimeError(f'Function `{func.__name__}` is not compatible with PyTorch {torch.__version__}')
|
||||
|
||||
return wrapper
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
def is_compatible_with_meta() -> bool:
|
||||
"""Check the meta compatibility. Normally it should be called before importing some of the `colossalai.fx`
|
||||
modules. If the meta compatibility is not satisfied, the `colossalai.fx` modules will be replaced by its
|
||||
experimental counterparts.
|
||||
|
||||
Returns:
|
||||
bool: The meta compatibility
|
||||
"""
|
||||
return META_COMPATIBILITY
|
|
@ -1,7 +1,10 @@
|
|||
# meta patch from https://github.com/pytorch/pytorch/blob/master/torch/_meta_registrations.py
|
||||
# should be activated for PyTorch version 1.12.0 and below
|
||||
# refer to https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/native_functions.yaml
|
||||
# for more meta_registrations
|
||||
|
||||
from typing import List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
from torch.utils._pytree import tree_map
|
||||
|
||||
|
@ -31,6 +34,7 @@ def register_meta(op, register_dispatcher=True):
|
|||
return wrapper
|
||||
|
||||
|
||||
# ============================== Convolutions ======================================
|
||||
# https://github.com/pytorch/pytorch/pull/79834
|
||||
@register_meta(aten.convolution.default)
|
||||
def meta_conv(
|
||||
|
@ -165,6 +169,18 @@ def meta_conv_backward(grad_output: torch.Tensor, input: torch.Tensor, weight: t
|
|||
return torch.empty_like(input), torch.empty_like(weight), torch.empty((bias_sizes), device='meta')
|
||||
|
||||
|
||||
# https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/AdaptiveAveragePooling.cpp
|
||||
@register_meta(aten._adaptive_avg_pool2d_backward.default)
|
||||
def meta_adaptive_avg_pool2d_backward(
|
||||
grad_output: torch.Tensor,
|
||||
input: torch.Tensor,
|
||||
):
|
||||
grad_input = torch.empty_like(input)
|
||||
return grad_input
|
||||
|
||||
|
||||
# https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/Activation.cpp
|
||||
# ============================== Activations =======================================
|
||||
@register_meta(aten.relu.default)
|
||||
def meta_relu(input: torch.Tensor):
|
||||
return torch.empty_like(input)
|
||||
|
@ -192,11 +208,8 @@ def meta_hardtanh_backward(grad_out: torch.Tensor, input: torch.Tensor, min_val:
|
|||
return grad_in
|
||||
|
||||
|
||||
@register_meta(aten.roll.default)
|
||||
def meta_roll(input: torch.Tensor, shifts, dims):
|
||||
return input
|
||||
|
||||
|
||||
# ============================== Normalization =====================================
|
||||
# https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/cudnn/BatchNorm.cpp
|
||||
@register_meta(aten.native_batch_norm.default)
|
||||
def meta_bn(input: torch.Tensor, weight, bias, running_mean, running_var, training, momentum, eps):
|
||||
n_input = input.size(1)
|
||||
|
@ -207,6 +220,7 @@ def meta_bn(input: torch.Tensor, weight, bias, running_mean, running_var, traini
|
|||
return output, running_mean, running_var
|
||||
|
||||
|
||||
# https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/cudnn/BatchNorm.cpp
|
||||
@register_meta(aten.native_batch_norm_backward.default)
|
||||
def meta_bn_backward(dY: torch.Tensor, input: torch.Tensor, weight: torch.Tensor, running_mean, running_var, save_mean,
|
||||
save_invstd, train, eps, output_mask):
|
||||
|
@ -241,6 +255,7 @@ def meta_cudnn_bn_backward(dY: torch.Tensor, input: torch.Tensor, weight: torch.
|
|||
return dX, dgamma, dbeta
|
||||
|
||||
|
||||
# https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/layer_norm.cpp
|
||||
@register_meta(aten.native_layer_norm.default)
|
||||
def meta_ln(input: torch.Tensor, normalized_shape, weight, bias, eps):
|
||||
bs = input.size(0)
|
||||
|
@ -252,6 +267,7 @@ def meta_ln(input: torch.Tensor, normalized_shape, weight, bias, eps):
|
|||
return output, running_mean, running_var
|
||||
|
||||
|
||||
# https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/layer_norm.cpp
|
||||
@register_meta(aten.native_layer_norm_backward.default)
|
||||
def meta_ln_backward(dY: torch.Tensor, input: torch.Tensor, normalized_shape, mean, rstd, weight, bias,
|
||||
grad_input_mask):
|
||||
|
@ -261,13 +277,18 @@ def meta_ln_backward(dY: torch.Tensor, input: torch.Tensor, normalized_shape, me
|
|||
return dX, dgamma, dbeta
|
||||
|
||||
|
||||
@register_meta(aten._adaptive_avg_pool2d_backward.default)
|
||||
def meta_adaptive_avg_pool2d_backward(
|
||||
grad_output: torch.Tensor,
|
||||
input: torch.Tensor,
|
||||
):
|
||||
grad_input = torch.empty_like(input)
|
||||
return torch.empty_like(input)
|
||||
# ================================== Misc ==========================================
|
||||
#https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/native_functions.yaml
|
||||
@register_meta(aten.roll.default)
|
||||
def meta_roll(input: torch.Tensor, shifts, dims):
|
||||
return input
|
||||
|
||||
|
||||
# https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/TensorCompare.cpp
|
||||
@register_meta(aten.where.self)
|
||||
def meta_where_self(condition: torch.Tensor, self: torch.Tensor, other: torch.Tensor):
|
||||
result_type = torch.result_type(self, other)
|
||||
return torch.empty_like(self, dtype=result_type)
|
||||
|
||||
|
||||
@register_meta(aten.index.Tensor)
|
||||
|
@ -360,6 +381,8 @@ def meta_index_Tensor(self, indices):
|
|||
return self.new_empty(before_shape + replacement_shape + after_shape)
|
||||
|
||||
|
||||
# ============================== Embedding =========================================
|
||||
# https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/Embedding.cpp
|
||||
@register_meta(aten.embedding_dense_backward.default)
|
||||
def meta_embedding_dense_backward(grad_output: torch.Tensor, indices: torch.Tensor, num_weights, padding_idx,
|
||||
scale_grad_by_freq):
|
||||
|
@ -369,13 +392,7 @@ def meta_embedding_dense_backward(grad_output: torch.Tensor, indices: torch.Tens
|
|||
layout=grad_output.layout)
|
||||
|
||||
|
||||
# https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/TensorCompare.cpp
|
||||
@register_meta(aten.where.self)
|
||||
def meta_where_self(condition: torch.Tensor, self: torch.Tensor, other: torch.Tensor):
|
||||
result_type = torch.result_type(self, other)
|
||||
return torch.empty_like(self, dtype=result_type)
|
||||
|
||||
|
||||
# ============================== Dropout ===========================================
|
||||
# https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/Dropout.cpp
|
||||
@register_meta(aten.native_dropout.default)
|
||||
def meta_native_dropout_default(input: torch.Tensor, p: float, train: bool = False):
|
|
@ -1,16 +1,20 @@
|
|||
from typing import List, Tuple
|
||||
import copy
|
||||
import torch
|
||||
from torch.fx import GraphModule, Node
|
||||
from colossalai.fx.graph_module import ColoGraphModule
|
||||
from colossalai.fx.profiler import parameter_size
|
||||
import math
|
||||
from .linearize import linearize
|
||||
from .operation import ForwardCheck, ForwardEnable, ForwardNograd, Backward, Loss, Chain, Sequence, Function, Offload, Prefetch
|
||||
from typing import List, Tuple
|
||||
|
||||
import torch
|
||||
from colossalai.fx import is_compatible_with_meta
|
||||
from colossalai.fx.codegen.activation_checkpoint_codegen import \
|
||||
_find_nested_ckpt_regions
|
||||
from colossalai.fx.graph_module import ColoGraphModule
|
||||
from colossalai.fx.passes.algorithms.ckpt_solver_rotor import (_compute_table, _construct_chain, _rec)
|
||||
from colossalai.fx.passes.meta_info_prop import MetaInfoProp
|
||||
from colossalai.fx.codegen.activation_checkpoint_codegen import _find_nested_ckpt_regions
|
||||
from colossalai.fx.passes.algorithms.ckpt_solver_rotor import _construct_chain, _compute_table, _rec
|
||||
from colossalai import META_COMPATIBILITY
|
||||
from colossalai.fx.profiler import parameter_size
|
||||
from torch.fx import GraphModule, Node
|
||||
|
||||
from .linearize import linearize
|
||||
from .operation import (Backward, Chain, ForwardCheck, ForwardEnable, ForwardNograd, Function, Loss, Offload, Prefetch,
|
||||
Sequence)
|
||||
|
||||
INF = float("inf")
|
||||
|
||||
|
@ -508,7 +512,7 @@ def solver_pofo(gm: ColoGraphModule,
|
|||
mem_limit -= parameter_size(gm)
|
||||
|
||||
# prepare data
|
||||
if META_COMPATIBILITY:
|
||||
if is_compatible_with_meta():
|
||||
from colossalai.fx.profiler import MetaTensor
|
||||
data = MetaTensor(data, fake_device=next(gm.parameters()).device)
|
||||
MetaInfoProp(gm).run(data)
|
||||
|
|
|
@ -1,13 +1,12 @@
|
|||
from dataclasses import asdict
|
||||
from colossalai.fx.profiler import GraphInfo
|
||||
from typing import Any, Dict, List, NamedTuple, Optional, Tuple
|
||||
|
||||
import torch
|
||||
import torch.fx
|
||||
from torch.fx.node import Node, Argument, Target
|
||||
from colossalai.fx._compatibility import compatibility
|
||||
from colossalai.fx.profiler import (GraphInfo, profile_function, profile_method, profile_module)
|
||||
from torch.fx.node import Argument, Node, Target
|
||||
from torch.utils._pytree import tree_flatten
|
||||
from typing import Any, List, Tuple, NamedTuple, Dict, Optional
|
||||
from torch.fx._compatibility import compatibility
|
||||
from colossalai.fx.profiler import profile_function, profile_module, profile_method, activation_size
|
||||
from torch.fx.graph_module import GraphModule
|
||||
|
||||
|
||||
@compatibility(is_backward_compatible=True)
|
||||
|
|
|
@ -1,11 +1,13 @@
|
|||
from dataclasses import asdict
|
||||
from typing import Any, Dict, List, NamedTuple, Tuple
|
||||
|
||||
import torch
|
||||
import torch.fx
|
||||
from torch.fx.node import Node, Argument, Target
|
||||
from colossalai.fx._compatibility import compatibility
|
||||
from colossalai.fx.profiler import (GraphInfo, activation_size, calculate_fwd_in, calculate_fwd_out, calculate_fwd_tmp,
|
||||
profile_function, profile_method, profile_module)
|
||||
from torch.fx.node import Argument, Node, Target
|
||||
from torch.utils._pytree import tree_map
|
||||
from typing import Any, List, Tuple, NamedTuple, Dict
|
||||
from torch.fx._compatibility import compatibility
|
||||
from colossalai.fx.profiler import GraphInfo, profile_function, profile_module, profile_method, activation_size, calculate_fwd_out, calculate_fwd_tmp, calculate_fwd_in
|
||||
|
||||
|
||||
@compatibility(is_backward_compatible=True)
|
||||
|
|
|
@ -1,11 +1,12 @@
|
|||
from ... import META_COMPATIBILITY
|
||||
if META_COMPATIBILITY:
|
||||
from .._compatibility import is_compatible_with_meta
|
||||
|
||||
if is_compatible_with_meta():
|
||||
from .memory import calculate_fwd_in, calculate_fwd_out, calculate_fwd_tmp
|
||||
from .opcount import flop_mapping
|
||||
from .tensor import MetaTensor
|
||||
from .profiler import profile_function, profile_method, profile_module
|
||||
from .memory import calculate_fwd_in, calculate_fwd_tmp, calculate_fwd_out
|
||||
from .tensor import MetaTensor
|
||||
else:
|
||||
from .experimental import meta_profiler_function, meta_profiler_module, profile_function, profile_method, profile_module, calculate_fwd_in, calculate_fwd_tmp, calculate_fwd_out
|
||||
|
||||
from .dataflow import GraphInfo
|
||||
from .memory import parameter_size, activation_size, is_inplace
|
||||
from .memory import activation_size, is_inplace, parameter_size
|
||||
|
|
|
@ -1,79 +0,0 @@
|
|||
import torch
|
||||
from operator import add, floordiv, getitem, mul, neg, setitem, sub, pos
|
||||
from . import META_COMPATIBILITY
|
||||
|
||||
__all__ = []
|
||||
|
||||
if META_COMPATIBILITY:
|
||||
aten = torch.ops.aten
|
||||
|
||||
ALIAS_ATEN = [
|
||||
# inplace reshaping
|
||||
aten.detach.default,
|
||||
aten.t.default,
|
||||
aten.transpose.int,
|
||||
aten.view.default,
|
||||
aten._unsafe_view.default,
|
||||
aten._reshape_alias.default,
|
||||
]
|
||||
|
||||
INPLACE_NEW = [
|
||||
aten.empty_like.default,
|
||||
aten.new_empty_strided.default,
|
||||
]
|
||||
|
||||
INPLACE_MATH_ATEN = [
|
||||
aten.add_.Tensor,
|
||||
aten.sub_.Tensor,
|
||||
aten.div_.Tensor,
|
||||
aten.div_.Scalar,
|
||||
aten.mul_.Tensor,
|
||||
aten.bernoulli_.float,
|
||||
]
|
||||
|
||||
CLONE_ATEN = [
|
||||
aten.clone.default,
|
||||
]
|
||||
|
||||
__all__ += ['INPLACE_ATEN', 'INPLACE_MATH_ATEN', 'CLONE_ATEN']
|
||||
|
||||
else:
|
||||
# TODO fill out the inplace ops
|
||||
INPLACE_OPS = [
|
||||
add,
|
||||
sub,
|
||||
mul,
|
||||
floordiv,
|
||||
neg,
|
||||
pos,
|
||||
getitem,
|
||||
setitem,
|
||||
getattr,
|
||||
torch.Tensor.cpu,
|
||||
]
|
||||
|
||||
# TODO: list all call_methods that are inplace here
|
||||
INPLACE_METHOD = [
|
||||
'transpose',
|
||||
'permute',
|
||||
# TODO: reshape may return a copy of the data if the data is not contiguous
|
||||
'reshape',
|
||||
'dim',
|
||||
'flatten',
|
||||
'size',
|
||||
'view',
|
||||
'unsqueeze',
|
||||
'to',
|
||||
'type',
|
||||
'flatten',
|
||||
]
|
||||
|
||||
# TODO: list all call_methods that are not inplace here
|
||||
NON_INPLACE_METHOD = [
|
||||
'chunk',
|
||||
'contiguous',
|
||||
'expand',
|
||||
'mean',
|
||||
'split',
|
||||
]
|
||||
__all__ += ['INPLACE_OPS', 'INPLACE_METHOD', 'NON_INPLACE_METHOD']
|
|
@ -0,0 +1,32 @@
|
|||
import torch
|
||||
|
||||
__all__ = ['ALIAS_ATEN', 'INPLACE_NEW', 'INPLACE_MATH_ATEN', 'CLONE_ATEN']
|
||||
|
||||
aten = torch.ops.aten
|
||||
|
||||
ALIAS_ATEN = [
|
||||
aten.detach.default,
|
||||
aten.t.default,
|
||||
aten.transpose.int,
|
||||
aten.view.default,
|
||||
aten._unsafe_view.default,
|
||||
aten._reshape_alias.default,
|
||||
]
|
||||
|
||||
INPLACE_NEW = [
|
||||
aten.empty_like.default,
|
||||
aten.new_empty_strided.default,
|
||||
]
|
||||
|
||||
INPLACE_MATH_ATEN = [
|
||||
aten.add_.Tensor,
|
||||
aten.sub_.Tensor,
|
||||
aten.div_.Tensor,
|
||||
aten.div_.Scalar,
|
||||
aten.mul_.Tensor,
|
||||
aten.bernoulli_.float,
|
||||
]
|
||||
|
||||
CLONE_ATEN = [
|
||||
aten.clone.default,
|
||||
]
|
|
@ -2,7 +2,10 @@ from dataclasses import dataclass, field
|
|||
from enum import Enum
|
||||
from functools import partial
|
||||
from typing import Dict, List
|
||||
|
||||
from torch.fx import Graph, Node
|
||||
|
||||
from .._compatibility import compatibility
|
||||
from .memory import activation_size, is_inplace
|
||||
|
||||
|
||||
|
@ -12,6 +15,7 @@ class Phase(Enum):
|
|||
PLACEHOLDER = 2
|
||||
|
||||
|
||||
@compatibility(is_backward_compatible=True)
|
||||
@dataclass
|
||||
class GraphInfo:
|
||||
"""
|
||||
|
@ -69,6 +73,7 @@ def is_phase(n: Node, phase: Phase) -> bool:
|
|||
return n.meta['phase'] == phase
|
||||
|
||||
|
||||
@compatibility(is_backward_compatible=False)
|
||||
def autograd_graph_analysis(graph: Graph) -> GraphInfo:
|
||||
"""Analyze the autograd node dependencies and find out the memory usage.
|
||||
Basically the input graph should have all nodes marked for keyword `phase`.
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
from .registry import meta_profiler_function, meta_profiler_module
|
||||
from .memory import calculate_fwd_in, calculate_fwd_tmp, calculate_fwd_out
|
||||
from .memory import calculate_fwd_in, calculate_fwd_out, calculate_fwd_tmp
|
||||
from .profiler import profile_function, profile_method, profile_module
|
||||
from .profiler_function import *
|
||||
from .profiler_module import *
|
||||
from .profiler import profile_function, profile_method, profile_module
|
||||
from .registry import meta_profiler_function, meta_profiler_module
|
||||
|
|
|
@ -0,0 +1,44 @@
|
|||
from operator import add, floordiv, getitem, mul, neg, pos, setitem, sub
|
||||
|
||||
import torch
|
||||
|
||||
__all__ = ['INPLACE_OPS', 'INPLACE_METHOD', 'NON_INPLACE_METHOD']
|
||||
|
||||
# TODO fill out the inplace ops
|
||||
INPLACE_OPS = [
|
||||
add,
|
||||
sub,
|
||||
mul,
|
||||
floordiv,
|
||||
neg,
|
||||
pos,
|
||||
getitem,
|
||||
setitem,
|
||||
getattr,
|
||||
torch.Tensor.cpu,
|
||||
]
|
||||
|
||||
# TODO: list all call_methods that are inplace here
|
||||
INPLACE_METHOD = [
|
||||
'transpose',
|
||||
'permute',
|
||||
# TODO: reshape may return a copy of the data if the data is not contiguous
|
||||
'reshape',
|
||||
'dim',
|
||||
'flatten',
|
||||
'size',
|
||||
'view',
|
||||
'unsqueeze',
|
||||
'to',
|
||||
'type',
|
||||
'flatten',
|
||||
]
|
||||
|
||||
# TODO: list all call_methods that are not inplace here
|
||||
NON_INPLACE_METHOD = [
|
||||
'chunk',
|
||||
'contiguous',
|
||||
'expand',
|
||||
'mean',
|
||||
'split',
|
||||
]
|
|
@ -1,11 +1,15 @@
|
|||
# for PyTorch 1.11 compatibility uses
|
||||
from typing import Dict, List, Tuple, Union
|
||||
|
||||
import torch
|
||||
from torch.fx import Node, GraphModule
|
||||
from typing import Union, Dict, List, Tuple
|
||||
from torch.fx import GraphModule, Node
|
||||
|
||||
from ..._compatibility import compatibility
|
||||
|
||||
__all__ = ["calculate_fwd_in", "calculate_fwd_tmp", "calculate_fwd_out"]
|
||||
|
||||
|
||||
@compatibility(is_backward_compatible=True)
|
||||
def calculate_fwd_in(n: Node) -> bool:
|
||||
"""A helper function to calculate `fwd_in`
|
||||
|
||||
|
@ -18,6 +22,7 @@ def calculate_fwd_in(n: Node) -> bool:
|
|||
return n.meta['save_fwd_in']
|
||||
|
||||
|
||||
@compatibility(is_backward_compatible=True)
|
||||
def calculate_fwd_tmp(n: Node) -> int:
|
||||
"""A helper function to calculate `fwd_tmp`
|
||||
|
||||
|
@ -30,6 +35,7 @@ def calculate_fwd_tmp(n: Node) -> int:
|
|||
return n.meta["fwd_mem_tmp"]
|
||||
|
||||
|
||||
@compatibility(is_backward_compatible=True)
|
||||
def calculate_fwd_out(n: Node) -> int:
|
||||
"""A helper function to calculate `fwd_out`
|
||||
|
||||
|
|
|
@ -1,15 +1,19 @@
|
|||
from dataclasses import dataclass
|
||||
from typing import Callable, Any, Dict, Tuple
|
||||
from typing import Any, Callable, Dict, Tuple
|
||||
|
||||
import torch
|
||||
from torch.fx.node import Argument, Target
|
||||
from . import meta_profiler_function, meta_profiler_module
|
||||
|
||||
from ..._compatibility import compatibility
|
||||
from ..memory import activation_size
|
||||
from ..constant import INPLACE_METHOD, NON_INPLACE_METHOD, INPLACE_OPS
|
||||
from .constants import INPLACE_METHOD, INPLACE_OPS, NON_INPLACE_METHOD
|
||||
from .registry import meta_profiler_function, meta_profiler_module
|
||||
|
||||
__all__ = ['profile_function', 'profile_module', 'profile_method']
|
||||
|
||||
|
||||
# this is for compatibility use
|
||||
@compatibility(is_backward_compatible=True)
|
||||
@dataclass
|
||||
class GraphInfo:
|
||||
"""
|
||||
|
@ -69,6 +73,7 @@ def profile_YOUR_MODULE(self: torch.nn.Module, input: torch.Tensor) -> Tuple[int
|
|||
"""
|
||||
|
||||
|
||||
@compatibility(is_backward_compatible=True)
|
||||
def profile_function(target: 'Target') -> Callable:
|
||||
"""
|
||||
Wrap a `call_function` node or `torch.nn.functional` in order to
|
||||
|
@ -106,6 +111,7 @@ def profile_function(target: 'Target') -> Callable:
|
|||
return f
|
||||
|
||||
|
||||
@compatibility(is_backward_compatible=True)
|
||||
def profile_method(target: 'Target') -> Callable:
|
||||
"""
|
||||
Wrap a `call_method` node
|
||||
|
@ -133,6 +139,7 @@ def profile_method(target: 'Target') -> Callable:
|
|||
return f
|
||||
|
||||
|
||||
@compatibility(is_backward_compatible=True)
|
||||
def profile_module(module: torch.nn.Module) -> Callable:
|
||||
"""
|
||||
Wrap a `call_module` node or `torch.nn` in order to
|
||||
|
|
|
@ -2,7 +2,6 @@ import operator
|
|||
from typing import Any, Tuple
|
||||
import torch
|
||||
from ..registry import meta_profiler_function
|
||||
from colossalai.fx.proxy import ColoProxy
|
||||
|
||||
|
||||
@meta_profiler_function.register(operator.getitem)
|
||||
|
|
|
@ -1,13 +1,16 @@
|
|||
from typing import Dict, List, Tuple, Union
|
||||
|
||||
import torch
|
||||
from torch.fx import Node, GraphModule
|
||||
from typing import Union, Dict, List, Tuple
|
||||
from . import META_COMPATIBILITY
|
||||
from torch.fx import GraphModule, Node
|
||||
|
||||
from .._compatibility import compatibility, is_compatible_with_meta
|
||||
|
||||
__all__ = [
|
||||
'activation_size', 'parameter_size', 'is_inplace', "calculate_fwd_in", "calculate_fwd_tmp", "calculate_fwd_out"
|
||||
]
|
||||
|
||||
|
||||
@compatibility(is_backward_compatible=True)
|
||||
def activation_size(out: Union[torch.Tensor, Dict, List, Tuple, int]) -> int:
|
||||
"""Calculate activation size of a node.
|
||||
|
||||
|
@ -29,6 +32,7 @@ def activation_size(out: Union[torch.Tensor, Dict, List, Tuple, int]) -> int:
|
|||
return act_size
|
||||
|
||||
|
||||
@compatibility(is_backward_compatible=True)
|
||||
def parameter_size(mod: torch.nn.Module) -> int:
|
||||
"""Calculate parameter size of a node.
|
||||
|
||||
|
@ -111,8 +115,8 @@ def is_inplace(n: Node):
|
|||
inplace = False
|
||||
if n.op == "call_function":
|
||||
inplace = n.kwargs.get("inplace", False)
|
||||
if META_COMPATIBILITY:
|
||||
from .constant import ALIAS_ATEN
|
||||
if is_compatible_with_meta():
|
||||
from .constants import ALIAS_ATEN
|
||||
if n.target in ALIAS_ATEN:
|
||||
inplace = True
|
||||
elif n.op == "call_module":
|
||||
|
|
|
@ -1,10 +1,11 @@
|
|||
# adopted from https://github.com/facebookresearch/fvcore/blob/main/fvcore/nn/jit_handles.py
|
||||
# ideas from https://pastebin.com/AkvAyJBw
|
||||
|
||||
from functools import partial, reduce
|
||||
import operator
|
||||
from typing import Callable, List, Any
|
||||
from functools import partial, reduce
|
||||
from numbers import Number
|
||||
from typing import Any, Callable, List
|
||||
|
||||
import torch
|
||||
|
||||
aten = torch.ops.aten
|
||||
|
|
|
@ -1,16 +1,19 @@
|
|||
import time
|
||||
from functools import partial
|
||||
from typing import Callable, Any, Dict, Tuple
|
||||
from typing import Any, Callable, Dict, Tuple
|
||||
|
||||
import torch
|
||||
from torch.nn.parameter import Parameter
|
||||
from torch.fx import Graph, Node
|
||||
from torch.fx.node import Argument, Target
|
||||
from torch.nn.parameter import Parameter
|
||||
from torch.utils._pytree import tree_map
|
||||
from .dataflow import autograd_graph_analysis, is_phase, Phase, GraphInfo
|
||||
|
||||
from .._compatibility import compatibility
|
||||
from .constants import ALIAS_ATEN
|
||||
from .dataflow import GraphInfo, Phase, autograd_graph_analysis, is_phase
|
||||
from .memory import activation_size, parameter_size
|
||||
from .constant import ALIAS_ATEN
|
||||
from .tensor import MetaTensor
|
||||
from .opcount import flop_mapping
|
||||
import time
|
||||
from .tensor import MetaTensor
|
||||
|
||||
__all__ = ['profile_function', 'profile_module', 'profile_method']
|
||||
|
||||
|
@ -41,6 +44,7 @@ def detach_variables(x):
|
|||
return x
|
||||
|
||||
|
||||
@compatibility(is_backward_compatible=True)
|
||||
def _profile_concrete(target: Callable, *args, **kwargs) -> Tuple[Tuple[Any, ...], GraphInfo]:
|
||||
"""Profile a Callable function with args and kwargs on concrete devices by https://github.com/Cypher30
|
||||
To profile the actual forward memory, we first run target in the context torch.no_grad() to get
|
||||
|
@ -140,6 +144,7 @@ def _profile_concrete(target: Callable, *args, **kwargs) -> Tuple[Tuple[Any, ...
|
|||
return tree_map(detach_variables, out), graphinfo
|
||||
|
||||
|
||||
@compatibility(is_backward_compatible=False)
|
||||
def _profile_meta(target: Callable, *args, **kwargs) -> Tuple[Tuple[Any, ...], GraphInfo]:
|
||||
"""
|
||||
Profile a Callable function with args and kwargs on meta devices.
|
||||
|
@ -277,6 +282,7 @@ def _profile_meta(target: Callable, *args, **kwargs) -> Tuple[Tuple[Any, ...], G
|
|||
return tree_map(unwrap, out), graph_info
|
||||
|
||||
|
||||
@compatibility(is_backward_compatible=True)
|
||||
def profile_function(target: 'Target', device: str = 'meta') -> Callable:
|
||||
"""
|
||||
Wrap a `call_function` node or `torch.nn.functional` in order to
|
||||
|
@ -335,6 +341,7 @@ def profile_function(target: 'Target', device: str = 'meta') -> Callable:
|
|||
return f
|
||||
|
||||
|
||||
@compatibility(is_backward_compatible=True)
|
||||
def profile_method(target: 'Target', device: str = 'meta') -> Callable:
|
||||
"""
|
||||
Wrap a `call_method` node
|
||||
|
@ -353,6 +360,7 @@ def profile_method(target: 'Target', device: str = 'meta') -> Callable:
|
|||
return f
|
||||
|
||||
|
||||
@compatibility(is_backward_compatible=True)
|
||||
def profile_module(module: torch.nn.Module, device: str = 'meta') -> Callable:
|
||||
"""
|
||||
Wrap a `call_module` node or `torch.nn` in order to
|
||||
|
|
|
@ -1,10 +1,13 @@
|
|||
import uuid
|
||||
from copy import deepcopy
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
from torch.utils._pytree import tree_map, tree_flatten
|
||||
from torch.types import _bool, _dtype, _device
|
||||
import uuid
|
||||
from .constant import ALIAS_ATEN
|
||||
from torch.types import _bool, _device, _dtype
|
||||
from torch.utils._pytree import tree_flatten, tree_map
|
||||
|
||||
from .._compatibility import compatibility
|
||||
from .constants import ALIAS_ATEN
|
||||
|
||||
__all__ = ['MetaTensor']
|
||||
|
||||
|
@ -15,6 +18,7 @@ def set_uuid(x):
|
|||
setattr(x, 'uuid', uuid.uuid4())
|
||||
|
||||
|
||||
@compatibility(is_backward_compatible=False)
|
||||
class MetaTensor(torch.Tensor):
|
||||
"""
|
||||
A wrapping tensor that hacks `torch.autograd` without patching more `torch.ops.aten` ops.
|
||||
|
|
|
@ -1,23 +1,19 @@
|
|||
import threading
|
||||
from enum import Enum
|
||||
from typing import List, Any, Tuple, Dict, Callable
|
||||
from functools import partial
|
||||
from abc import ABC, abstractmethod
|
||||
import math
|
||||
import inspect
|
||||
import math
|
||||
import threading
|
||||
from abc import ABC, abstractmethod
|
||||
from enum import Enum
|
||||
from functools import partial
|
||||
from typing import Any, Callable, Dict, List, Tuple
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
import torch.distributed.rpc as rpc
|
||||
from torch.futures import Future
|
||||
from torch._C._distributed_rpc import PyRRef
|
||||
|
||||
from torch import autograd
|
||||
from torch import optim
|
||||
|
||||
from colossalai.pipeline.pipeline_process_group import ppg
|
||||
from colossalai.pipeline.rpc.utils import (color_debug, tensor_shape_list, get_batch_lengths, split_batch, type_detail,
|
||||
pytree_map, pytree_filter, get_real_args_kwargs, use_color_debug)
|
||||
from colossalai.pipeline.rpc.utils import (get_batch_lengths, get_real_args_kwargs, pytree_filter, pytree_map,
|
||||
split_batch, tensor_shape_list, type_detail)
|
||||
from torch import autograd, nn, optim
|
||||
from torch._C._distributed_rpc import PyRRef
|
||||
from torch.futures import Future
|
||||
|
||||
|
||||
class Phase(Enum):
|
||||
|
@ -195,7 +191,6 @@ class WorkerBase(ABC):
|
|||
if isinstance(output, Future):
|
||||
output = output.wait()
|
||||
|
||||
# color_debug(f'rank {self.pp_rank}, output {type(output)}', 'get output', 'red')
|
||||
output_work_item.refcount += 1
|
||||
|
||||
# all consumers have been satisfied, the work_item can be released
|
||||
|
@ -250,9 +245,6 @@ class WorkerBase(ABC):
|
|||
self.num_microbatches, forward_only)
|
||||
with self.work_list_condition_lock:
|
||||
self.work_list[key] = work_item
|
||||
if use_color_debug:
|
||||
color_debug(f'rank {self.pp_rank} receive data from dataloader {self._get_store_len()}',
|
||||
'data dispatch', 'magenta')
|
||||
self.work_list_condition_lock.notify_all()
|
||||
|
||||
# just for last pp_rank
|
||||
|
@ -273,9 +265,6 @@ class WorkerBase(ABC):
|
|||
work_item = WorkItem(self.pp_rank, Phase.BACKWARD, grad_wrt_loss, {}, output, microbatch_id, None,
|
||||
self.num_microbatches, False)
|
||||
|
||||
if use_color_debug:
|
||||
color_debug(f'rank {self.pp_rank} propose backward', 'data dispatch', 'magenta')
|
||||
|
||||
self.work_list[key] = work_item
|
||||
self.work_list_condition_lock.notify_all()
|
||||
|
||||
|
@ -297,23 +286,14 @@ class WorkerBase(ABC):
|
|||
producer_worker_rref = self.pp_rank_to_worker_rref[producer_stage_id]
|
||||
subscribe_forward_futures[i] = producer_worker_rref.rpc_async().get_output_by_key(producer_output_key)
|
||||
|
||||
if use_color_debug:
|
||||
color_debug(f'rank {self.pp_rank} get {len(subscribe_forward_futures)} futs from its producer',
|
||||
'data dispatch', 'magenta')
|
||||
|
||||
work_item_from_producer = WorkItem(stage_id, Phase.FORWARD, subscribe_forward_futures, {}, output,
|
||||
microbatch_id, None, self.num_microbatches, forward_only)
|
||||
|
||||
# color_debug(f'rank {self.pp_rank} get value {tensor_shape_list(args)} from fut', 'data dispatch', 'magenta')
|
||||
# add work_item to work_list
|
||||
with self.work_list_condition_lock:
|
||||
key = UniqueKey(microbatch_id, Phase.FORWARD)
|
||||
assert key not in self.work_list
|
||||
self.work_list[key] = work_item_from_producer
|
||||
if use_color_debug:
|
||||
color_debug(
|
||||
f'rank_{self.pp_rank} load a new task to its work_list {key} {work_item_from_producer.phase} data: {tensor_shape_list(work_item_from_producer.args)}',
|
||||
'data dispatch', 'magenta')
|
||||
self.work_list_condition_lock.notify_all()
|
||||
|
||||
def subscribe_consumer(self, microbatch_id: int):
|
||||
|
@ -328,10 +308,6 @@ class WorkerBase(ABC):
|
|||
subscribe_backward_futures: List[Future] = [None] * consumer_num
|
||||
output = self._get_future_by_device()
|
||||
|
||||
if use_color_debug:
|
||||
color_debug(f'rank {self.pp_rank} get {len(subscribe_backward_futures)} futs from its consumer',
|
||||
'data dispatch', 'magenta')
|
||||
|
||||
for i in range(consumer_num):
|
||||
consumer_stage_id = self.consumer_stage_ids[i]
|
||||
consumer_output_key = UniqueKey(microbatch_id, Phase.BACKWARD)
|
||||
|
@ -342,17 +318,11 @@ class WorkerBase(ABC):
|
|||
work_item_from_consumer = WorkItem(stage_id, Phase.BACKWARD, subscribe_backward_futures, {}, output,
|
||||
microbatch_id, None, self.num_microbatches, False)
|
||||
|
||||
# color_debug(f'rank {self.pp_rank} get value {tensor_shape_list(args)} from fut', 'data dispatch', 'magenta')
|
||||
|
||||
# add work_item to work_list
|
||||
with self.work_list_condition_lock:
|
||||
key = UniqueKey(microbatch_id, Phase.BACKWARD)
|
||||
assert key not in self.work_list
|
||||
self.work_list[key] = work_item_from_consumer
|
||||
if use_color_debug:
|
||||
color_debug(
|
||||
f'rank_{self.pp_rank} load a new task to its work_list {key} {work_item_from_consumer.phase} data: {tensor_shape_list(work_item_from_consumer.args)}',
|
||||
'data dispatch', 'magenta')
|
||||
self.work_list_condition_lock.notify_all()
|
||||
|
||||
def _get_producer_consumer(self) -> None:
|
||||
|
@ -406,11 +376,6 @@ class WorkerBase(ABC):
|
|||
is_first_stage = self.is_first_stage()
|
||||
is_last_stage = self.is_last_stage()
|
||||
|
||||
# if self.pp_rank == 0:
|
||||
# print(
|
||||
# f'I am rank_{self.pp_rank} microbatch_id : {microbatch_id} {phase} {self._get_store_len()} | {self.outstanding} {self.outstanding_range}'
|
||||
# )
|
||||
|
||||
if phase == Phase.FORWARD:
|
||||
# remind its consumer to get data before forward
|
||||
if not is_last_stage:
|
||||
|
@ -470,8 +435,6 @@ class WorkerBase(ABC):
|
|||
else:
|
||||
consume_result = self.module_partition(*args, **kwargs)
|
||||
|
||||
# print(f'model{self.pp_rank + 1}(param_sum: {sum([p.sum().item() for p in self.module_partition.parameters()])}) input sum: {args[0].sum().item()} forward output sum: {consume_result.sum().item()}', )
|
||||
|
||||
if is_last_stage and self.criterion:
|
||||
with self.label_lock:
|
||||
self.label_lock.wait_for(lambda: microbatch_id in self.microbatch_id_to_labels)
|
||||
|
@ -539,10 +502,6 @@ class WorkerBase(ABC):
|
|||
pytree_map(stage_input_args, lambda x: consume_result.append(x.grad), process_types=torch.Tensor)
|
||||
pytree_map(stage_input_kwargs, lambda x: consume_result.append(x.grad), process_types=torch.Tensor)
|
||||
|
||||
# for input_node in stage_input_args:
|
||||
# if isinstance(input_node, torch.Tensor):
|
||||
# consume_result.append(input_node.grad)
|
||||
|
||||
else:
|
||||
raise TypeError(f"Unknown phase appears in _consume_work_item_by_phase {phase}")
|
||||
|
||||
|
@ -593,11 +552,6 @@ class WorkerBase(ABC):
|
|||
with self.work_list_condition_lock:
|
||||
work_item = self.work_list.pop(work_item_key)
|
||||
|
||||
if use_color_debug:
|
||||
color_debug(
|
||||
f'rank {self.pp_rank} get a key : {work_item_key} work_item args: {tensor_shape_list(work_item.args)} {self._get_store_len()}',
|
||||
'work loop', 'green')
|
||||
|
||||
with self.output_list_condition_lock:
|
||||
# assert work_item_key not in self.output_list
|
||||
self.output_list[work_item_key] = work_item
|
||||
|
@ -605,11 +559,6 @@ class WorkerBase(ABC):
|
|||
|
||||
consume_result = self._consume_work_item_by_phase(work_item)
|
||||
|
||||
if use_color_debug:
|
||||
color_debug(
|
||||
f'rank_{self.pp_rank} [{work_item.phase}] finish consuming, result is {tensor_shape_list(consume_result)} {self._get_store_len()} | {self.work_list.keys()} | {self.output_list.keys()}',
|
||||
'work loop', 'green')
|
||||
|
||||
work_item.output.set_result(consume_result)
|
||||
|
||||
# if is last step in one batch reset context and do step
|
||||
|
|
|
@ -1,13 +1,12 @@
|
|||
from typing import List, Callable, Dict
|
||||
import threading
|
||||
from typing import Callable, Dict, List
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from torch.futures import Future
|
||||
from torch._C._distributed_rpc import PyRRef
|
||||
|
||||
from colossalai.pipeline.rpc._pipeline_base import PipelineEngineBase, WorkerBase, UniqueKey, Phase, WorkItem
|
||||
from colossalai.pipeline.pipeline_process_group import ppg
|
||||
from colossalai.pipeline.rpc._pipeline_base import (Phase, PipelineEngineBase, UniqueKey, WorkerBase, WorkItem)
|
||||
from torch._C._distributed_rpc import PyRRef
|
||||
from torch.futures import Future
|
||||
|
||||
# Implementation of different Pipeline schedule
|
||||
# <strategy>Worker defines the worker for each stage
|
||||
|
|
|
@ -1,25 +1,15 @@
|
|||
from typing import List, Any, Tuple, Dict, Callable, Type, Union
|
||||
import argparse
|
||||
import os
|
||||
import warnings
|
||||
import argparse
|
||||
from typing import Any, Callable, Dict, List, Tuple, Type, Union
|
||||
|
||||
import torch
|
||||
import torch.multiprocessing as mp
|
||||
from torch.futures import Future
|
||||
import torch.distributed.rpc as rpc
|
||||
from torch._C._distributed_rpc import _is_current_rpc_agent_set
|
||||
from colorama import Back, Style
|
||||
|
||||
import torch.multiprocessing as mp
|
||||
from colossalai.initialize import launch
|
||||
from colossalai.pipeline.pipeline_process_group import ppg
|
||||
|
||||
# config for debug and test
|
||||
use_color_debug = False
|
||||
|
||||
|
||||
def color_debug(text, prefix=' ', color='blue'):
|
||||
color = color.upper()
|
||||
print(getattr(Back, color), prefix, Style.RESET_ALL, text)
|
||||
from torch._C._distributed_rpc import _is_current_rpc_agent_set
|
||||
from torch.futures import Future
|
||||
|
||||
|
||||
def pytree_map(obj: Any, fn: Callable, process_types: Union[Type, Tuple[Type]] = (), map_all: bool = False) -> Any:
|
||||
|
|
|
@ -1,18 +1,20 @@
|
|||
import copy
|
||||
|
||||
import colossalai
|
||||
import pytest
|
||||
import torch
|
||||
import torch.fx
|
||||
import torch.multiprocessing as mp
|
||||
import torchvision.models as tm
|
||||
import torch.fx
|
||||
import colossalai
|
||||
from colossalai.fx.passes.meta_info_prop import MetaInfoProp
|
||||
from colossalai.core import global_context as gpc
|
||||
from colossalai.fx import ColoGraphModule, ColoTracer
|
||||
from colossalai.fx._compatibility import is_compatible_with_meta
|
||||
from colossalai.fx.passes.algorithms import solver_rotor
|
||||
from colossalai.fx.passes.algorithms.operation import Sequence
|
||||
from colossalai.core import global_context as gpc
|
||||
from colossalai.fx.passes.meta_info_prop import MetaInfoProp
|
||||
from colossalai.utils import free_port
|
||||
import pytest
|
||||
from colossalai import META_COMPATIBILITY
|
||||
if META_COMPATIBILITY:
|
||||
|
||||
if is_compatible_with_meta():
|
||||
from colossalai.fx.profiler.tensor import MetaTensor
|
||||
|
||||
try:
|
||||
|
@ -34,7 +36,7 @@ def _run_C_solver_consistency_test(rank=0):
|
|||
graph = tracer.trace(model, meta_args={"x": data})
|
||||
graph.set_codegen(ActivationCheckpointCodeGen())
|
||||
gm = ColoGraphModule(model, graph, model.__class__.__name__)
|
||||
if META_COMPATIBILITY:
|
||||
if is_compatible_with_meta():
|
||||
data_meta = MetaTensor(data, fake_device=next(gm.parameters()).device)
|
||||
MetaInfoProp(gm).run(data_meta)
|
||||
|
||||
|
|
|
@ -1,20 +1,22 @@
|
|||
from typing import Callable
|
||||
import copy
|
||||
import re
|
||||
from typing import Callable
|
||||
|
||||
import colossalai
|
||||
import pytest
|
||||
import torch
|
||||
import torch.multiprocessing as mp
|
||||
import torchvision.models as tm
|
||||
from torch.fx import GraphModule
|
||||
import colossalai
|
||||
from colossalai.fx import ColoTracer
|
||||
from colossalai.fx.graph_module import ColoGraphModule
|
||||
from colossalai.fx.passes.meta_info_prop import MetaInfoProp
|
||||
from colossalai.fx.passes.algorithms import chen_greedy, solver_rotor
|
||||
from colossalai.utils import free_port
|
||||
from colossalai.core import global_context as gpc
|
||||
import pytest
|
||||
from colossalai import META_COMPATIBILITY
|
||||
if META_COMPATIBILITY:
|
||||
from colossalai.fx import ColoTracer
|
||||
from colossalai.fx._compatibility import is_compatible_with_meta
|
||||
from colossalai.fx.graph_module import ColoGraphModule
|
||||
from colossalai.fx.passes.algorithms import chen_greedy, solver_rotor
|
||||
from colossalai.fx.passes.meta_info_prop import MetaInfoProp
|
||||
from colossalai.utils import free_port
|
||||
from torch.fx import GraphModule
|
||||
|
||||
if is_compatible_with_meta():
|
||||
from colossalai.fx.profiler.tensor import MetaTensor
|
||||
|
||||
try:
|
||||
|
@ -54,8 +56,9 @@ def _is_graph_linearized(gm: GraphModule):
|
|||
def check_backward_consistency(m: torch.nn.Module, gm: GraphModule, solver: Callable[[GraphModule], GraphModule],
|
||||
model_cls: Callable[[], torch.nn.Module]):
|
||||
criterion = torch.nn.MSELoss()
|
||||
data = torch.rand(2, 3, 32, 32)
|
||||
label = torch.rand(2, 5)
|
||||
m.cuda()
|
||||
data = torch.rand(2, 3, 32, 32).cuda()
|
||||
label = torch.rand(2, 5).cuda()
|
||||
loss = criterion(m(data), label)
|
||||
loss.backward()
|
||||
loss = criterion(gm(data), label)
|
||||
|
@ -77,7 +80,7 @@ def _run_ckpt_solver(rank):
|
|||
m = model_cls(num_classes=5)
|
||||
graph = tracer.trace(root=m)
|
||||
gm = ColoGraphModule(copy.deepcopy(m), graph, m.__class__.__name__)
|
||||
MetaInfoProp(gm.cuda()).run(MetaTensor(data, fake_device='cuda'))
|
||||
MetaInfoProp(gm.cuda()).run(MetaTensor(data).cuda())
|
||||
codegen = ActivationCheckpointCodeGen()
|
||||
gm.graph.set_codegen(codegen)
|
||||
if solver == solver_rotor:
|
||||
|
|
|
@ -1,13 +1,14 @@
|
|||
from colossalai.fx.passes.meta_info_prop import MetaInfoProp
|
||||
import pytest
|
||||
import torch
|
||||
import torchvision.models as tm
|
||||
from colossalai.fx import ColoTracer
|
||||
from colossalai.fx._compatibility import is_compatible_with_meta
|
||||
from colossalai.fx.graph_module import ColoGraphModule
|
||||
from colossalai.fx.passes.algorithms import solver_rotor, linearize
|
||||
from colossalai.fx.passes.algorithms.operation import Loss, ForwardCheck, ForwardEnable, ForwardNograd
|
||||
import pytest
|
||||
from colossalai import META_COMPATIBILITY
|
||||
if META_COMPATIBILITY:
|
||||
from colossalai.fx.passes.algorithms import linearize, solver_rotor
|
||||
from colossalai.fx.passes.algorithms.operation import (ForwardCheck, ForwardEnable, ForwardNograd, Loss)
|
||||
from colossalai.fx.passes.meta_info_prop import MetaInfoProp
|
||||
|
||||
if is_compatible_with_meta():
|
||||
from colossalai.fx.profiler.tensor import MetaTensor
|
||||
|
||||
try:
|
||||
|
|
|
@ -1,13 +1,17 @@
|
|||
import torch
|
||||
import torch.nn as nn
|
||||
import colossalai
|
||||
import colossalai.nn as col_nn
|
||||
from torch.fx import symbolic_trace
|
||||
from colossalai.fx.passes.meta_info_prop import MetaInfoProp
|
||||
from colossalai.fx.passes.adding_split_node_pass import split_with_split_nodes_pass, uniform_split_pass
|
||||
from colossalai.fx.passes.utils import get_comm_size
|
||||
from colossalai import META_COMPATIBILITY
|
||||
import pytest
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from colossalai.fx._compatibility import is_compatible_with_meta
|
||||
from colossalai.fx.passes.adding_split_node_pass import (split_with_split_nodes_pass, uniform_split_pass)
|
||||
from colossalai.fx.passes.meta_info_prop import MetaInfoProp
|
||||
from colossalai.fx.passes.utils import get_comm_size
|
||||
from torch.fx import symbolic_trace
|
||||
|
||||
is_compatible = is_compatible_with_meta()
|
||||
if is_compatible:
|
||||
from colossalai.fx.profiler import MetaTensor
|
||||
|
||||
MODEL_DIM = 16
|
||||
BATCH_SIZE = 8
|
||||
|
@ -31,12 +35,12 @@ class MLP(torch.nn.Module):
|
|||
return x
|
||||
|
||||
|
||||
@pytest.mark.skipif(not META_COMPATIBILITY, reason='torch version is lower than 1.12.0')
|
||||
def test_comm_size_compute():
|
||||
from colossalai.fx.profiler import MetaTensor
|
||||
model = MLP(MODEL_DIM)
|
||||
input_sample = MetaTensor(torch.rand(BATCH_SIZE, MODEL_DIM, device='meta'), fake_device='cpu')
|
||||
input_sample = torch.rand(BATCH_SIZE, MODEL_DIM, device='meta')
|
||||
gm = symbolic_trace(model)
|
||||
if is_compatible:
|
||||
input_sample = MetaTensor(input_sample, fake_device=next(gm.parameters()).device)
|
||||
MetaInfoProp(gm).run(input_sample)
|
||||
annotated_model = uniform_split_pass(gm, PIPELINE_SIZE)
|
||||
split_model, split_submodules = split_with_split_nodes_pass(annotated_model)
|
||||
|
|
|
@ -1,12 +1,11 @@
|
|||
from typing import Any, Callable, Union
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from colossalai import META_COMPATIBILITY
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from colossalai.fx._compatibility import is_compatible_with_meta
|
||||
|
||||
if META_COMPATIBILITY:
|
||||
if is_compatible_with_meta():
|
||||
from colossalai.fx.profiler import MetaTensor
|
||||
|
||||
aten = torch.ops.aten
|
||||
|
@ -71,7 +70,7 @@ def run_and_compare(f: Union[nn.Module, Callable], x: torch.Tensor, requires_bac
|
|||
compare_all(x.grad, meta_x.grad)
|
||||
|
||||
|
||||
@pytest.mark.skipif(not META_COMPATIBILITY, reason='torch version is lower than 1.12.0')
|
||||
@pytest.mark.skipif(not is_compatible_with_meta(), reason='torch version is lower than 1.12.0')
|
||||
def test_meta_aten():
|
||||
for (aten_op, requires_backward), v in registered_meta.items():
|
||||
for f, x in v:
|
||||
|
|
|
@ -1,10 +1,10 @@
|
|||
import torchvision.models as tm
|
||||
import pytest
|
||||
import timm.models as tmm
|
||||
import torch
|
||||
from colossalai import META_COMPATIBILITY
|
||||
import pytest
|
||||
import torchvision.models as tm
|
||||
from colossalai.fx._compatibility import is_compatible_with_meta
|
||||
|
||||
if META_COMPATIBILITY:
|
||||
if is_compatible_with_meta():
|
||||
from colossalai.fx.profiler import MetaTensor
|
||||
|
||||
tm_models = [
|
||||
|
@ -27,7 +27,7 @@ tmm_models = [
|
|||
]
|
||||
|
||||
|
||||
@pytest.mark.skipif(not META_COMPATIBILITY, reason='torch version is lower than 1.12.0')
|
||||
@pytest.mark.skipif(not is_compatible_with_meta(), reason='torch version is lower than 1.12.0')
|
||||
def test_torchvision_models():
|
||||
for m in tm_models:
|
||||
model = m()
|
||||
|
@ -35,7 +35,7 @@ def test_torchvision_models():
|
|||
model(MetaTensor(data, fake_device=torch.device('cpu'))).sum().backward()
|
||||
|
||||
|
||||
@pytest.mark.skipif(not META_COMPATIBILITY, reason='torch version is lower than 1.12.0')
|
||||
@pytest.mark.skipif(not is_compatible_with_meta(), reason='torch version is lower than 1.12.0')
|
||||
def test_timm_models():
|
||||
for m in tmm_models:
|
||||
model = m()
|
||||
|
|
|
@ -1,10 +1,10 @@
|
|||
import torchvision.models as tm
|
||||
import pytest
|
||||
import timm.models as tmm
|
||||
import torch
|
||||
from colossalai import META_COMPATIBILITY
|
||||
import pytest
|
||||
import torchvision.models as tm
|
||||
from colossalai.fx._compatibility import is_compatible_with_meta
|
||||
|
||||
if META_COMPATIBILITY:
|
||||
if is_compatible_with_meta():
|
||||
from colossalai.fx import meta_trace
|
||||
|
||||
tm_models = [
|
||||
|
@ -27,7 +27,7 @@ tmm_models = [
|
|||
]
|
||||
|
||||
|
||||
@pytest.mark.skipif(not META_COMPATIBILITY, reason='torch version is lower than 1.12.0')
|
||||
@pytest.mark.skipif(not is_compatible_with_meta(), reason='torch version is lower than 1.12.0')
|
||||
def test_torchvision_models_trace():
|
||||
for m in tm_models:
|
||||
model = m()
|
||||
|
@ -35,7 +35,7 @@ def test_torchvision_models_trace():
|
|||
graph = meta_trace(model, torch.device('cpu'), data)
|
||||
|
||||
|
||||
@pytest.mark.skipif(not META_COMPATIBILITY, reason='torch version is lower than 1.12.0')
|
||||
@pytest.mark.skipif(not is_compatible_with_meta(), reason='torch version is lower than 1.12.0')
|
||||
def test_timm_models_trace():
|
||||
for m in tmm_models:
|
||||
model = m()
|
||||
|
|
|
@ -1,7 +1,10 @@
|
|||
import torch
|
||||
from torch.fx import symbolic_trace
|
||||
from colossalai import META_COMPATIBILITY
|
||||
from colossalai.fx._compatibility import is_compatible_with_meta
|
||||
from colossalai.fx.passes.meta_info_prop import MetaInfoProp, TensorMetadata
|
||||
from torch.fx import symbolic_trace
|
||||
|
||||
if is_compatible_with_meta():
|
||||
from colossalai.fx.profiler import MetaTensor
|
||||
|
||||
BATCH_SIZE = 2
|
||||
DIM_IN = 4
|
||||
|
@ -18,8 +21,7 @@ def meta_check(meta_info_spec: TensorMetadata, orig_tensor: torch.Tensor):
|
|||
def test_meta_info_prop():
|
||||
model = torch.nn.Linear(DIM_IN, DIM_OUT)
|
||||
input_sample = torch.rand(BATCH_SIZE, DIM_IN, device='meta')
|
||||
if META_COMPATIBILITY:
|
||||
from colossalai.fx.profiler import MetaTensor
|
||||
if is_compatible_with_meta():
|
||||
input_sample = MetaTensor(input_sample, fake_device='cpu')
|
||||
orig_output = model(input_sample)
|
||||
gm = symbolic_trace(model)
|
||||
|
|
|
@ -1,19 +1,17 @@
|
|||
import os
|
||||
import argparse
|
||||
import os
|
||||
import warnings
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
import torch.multiprocessing as mp
|
||||
import torch.distributed.rpc as rpc
|
||||
from torch.optim import SGD, Adam, RMSprop, Optimizer
|
||||
from torch._C._distributed_rpc import _is_current_rpc_agent_set
|
||||
import torch.distributed as dist
|
||||
from colorama import Back, Style
|
||||
|
||||
from colossalai.pipeline.pipeline_process_group import ppg
|
||||
from colossalai.logging import disable_existing_loggers
|
||||
import torch.distributed.rpc as rpc
|
||||
import torch.multiprocessing as mp
|
||||
from colossalai import launch
|
||||
from colossalai.logging import disable_existing_loggers
|
||||
from colossalai.pipeline.pipeline_process_group import ppg
|
||||
from torch import nn
|
||||
from torch._C._distributed_rpc import _is_current_rpc_agent_set
|
||||
from torch.optim import SGD, Adam, Optimizer, RMSprop
|
||||
|
||||
rpc_is_initialized = _is_current_rpc_agent_set
|
||||
|
||||
|
|
Loading…
Reference in New Issue