mirror of https://github.com/hpcaitech/ColossalAI
[fx/meta/rpc] move _meta_registration.py to fx folder / register fx functions with compatibility checks / remove color debug (#1710)
* [fx] move meta registration
* [fx] fix tests.
* [fx] fix test.
* [fx] fix.
* [meta] refactor meta registration.py.
* [fx] add compatibility descriptions.
* [fx] polish import.
* [fx] add a decorator.
* [fx] fix tests.
* [fx] remove print.
* [fx] edit raise error.
* [fx] edit raise error.
* [fx] add type hint.
* [fx] fix import in experimental.
* [rpc] remove color debug.
* [meta] fix naming.
parent e8d8eda5e7
commit 393f594051
@@ -1,10 +1,3 @@
-try:
-    from . import _meta_registrations
-    META_COMPATIBILITY = True
-except:
-    import torch
-    META_COMPATIBILITY = False
-    print(f'_meta_registrations seems to be incompatible with PyTorch {torch.__version__}.')
 from .initialize import (initialize, launch, launch_from_openmpi, launch_from_slurm, launch_from_torch,
                          get_default_parser)
@@ -1,3 +1,4 @@
-from .tracer import ColoTracer, meta_trace
+from ._compatibility import compatibility, is_compatible_with_meta
 from .graph_module import ColoGraphModule
 from .passes import MetaInfoProp
+from .tracer import ColoTracer, meta_trace
@@ -0,0 +1,46 @@
+from typing import Callable
+
+import torch
+
+try:
+    from . import _meta_registrations
+    META_COMPATIBILITY = True
+except:
+    META_COMPATIBILITY = False
+
+
+def compatibility(is_backward_compatible: bool = False) -> Callable:
+    """A decorator to make a function compatible with different versions of PyTorch.
+
+    Args:
+        is_backward_compatible (bool, optional): Whether the function is backward compatible. Defaults to False.
+
+    Returns:
+        Callable: The decorated function
+    """
+
+    def decorator(func):
+        if META_COMPATIBILITY:
+            return func
+        else:
+            if is_backward_compatible:
+                return func
+            else:
+
+                def wrapper(*args, **kwargs):
+                    raise RuntimeError(f'Function `{func.__name__}` is not compatible with PyTorch {torch.__version__}')
+
+                return wrapper
+
+    return decorator
+
+
+def is_compatible_with_meta() -> bool:
+    """Check the meta compatibility. Normally it should be called before importing some of the `colossalai.fx`
+    modules. If the meta compatibility is not satisfied, the `colossalai.fx` modules will be replaced by its
+    experimental counterparts.
+
+    Returns:
+        bool: The meta compatibility
+    """
+    return META_COMPATIBILITY
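For downstream callers, the intended usage is the pattern that appears throughout the rest of this diff. A minimal sketch (`shape_only_pass` is a hypothetical name, not part of the patch):

    from colossalai.fx._compatibility import compatibility, is_compatible_with_meta

    # gated functions raise a RuntimeError when *called* on an incompatible
    # PyTorch, instead of breaking the whole package at import time
    @compatibility(is_backward_compatible=False)
    def shape_only_pass(gm):
        ...

    # callers branch on the same flag before importing meta-only modules
    if is_compatible_with_meta():
        from colossalai.fx.profiler import MetaTensor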
@@ -1,7 +1,10 @@
 # meta patch from https://github.com/pytorch/pytorch/blob/master/torch/_meta_registrations.py
 # should be activated for PyTorch version 1.12.0 and below
+# refer to https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/native_functions.yaml
+# for more meta_registrations
+
 from typing import List, Optional, Tuple, Union
+
 import torch
 from torch.utils._pytree import tree_map

@@ -31,6 +34,7 @@ def register_meta(op, register_dispatcher=True):
     return wrapper


+# ============================== Convolutions ======================================
 # https://github.com/pytorch/pytorch/pull/79834
 @register_meta(aten.convolution.default)
 def meta_conv(

@@ -165,6 +169,18 @@ def meta_conv_backward(grad_output: torch.Tensor, input: torch.Tensor, weight: t
     return torch.empty_like(input), torch.empty_like(weight), torch.empty((bias_sizes), device='meta')


+# https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/AdaptiveAveragePooling.cpp
+@register_meta(aten._adaptive_avg_pool2d_backward.default)
+def meta_adaptive_avg_pool2d_backward(
+    grad_output: torch.Tensor,
+    input: torch.Tensor,
+):
+    grad_input = torch.empty_like(input)
+    return grad_input
+
+
+# https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/Activation.cpp
+# ============================== Activations =======================================
 @register_meta(aten.relu.default)
 def meta_relu(input: torch.Tensor):
     return torch.empty_like(input)

@@ -192,11 +208,8 @@ def meta_hardtanh_backward(grad_out: torch.Tensor, input: torch.Tensor, min_val:
     return grad_in


-@register_meta(aten.roll.default)
-def meta_roll(input: torch.Tensor, shifts, dims):
-    return input
-
-
+# ============================== Normalization =====================================
+# https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/cudnn/BatchNorm.cpp
 @register_meta(aten.native_batch_norm.default)
 def meta_bn(input: torch.Tensor, weight, bias, running_mean, running_var, training, momentum, eps):
     n_input = input.size(1)

@@ -207,6 +220,7 @@ def meta_bn(input: torch.Tensor, weight, bias, running_mean, running_var, traini
     return output, running_mean, running_var


+# https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/cudnn/BatchNorm.cpp
 @register_meta(aten.native_batch_norm_backward.default)
 def meta_bn_backward(dY: torch.Tensor, input: torch.Tensor, weight: torch.Tensor, running_mean, running_var, save_mean,
                      save_invstd, train, eps, output_mask):

@@ -241,6 +255,7 @@ def meta_cudnn_bn_backward(dY: torch.Tensor, input: torch.Tensor, weight: torch.
     return dX, dgamma, dbeta


+# https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/layer_norm.cpp
 @register_meta(aten.native_layer_norm.default)
 def meta_ln(input: torch.Tensor, normalized_shape, weight, bias, eps):
     bs = input.size(0)

@@ -252,6 +267,7 @@ def meta_ln(input: torch.Tensor, normalized_shape, weight, bias, eps):
     return output, running_mean, running_var


+# https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/layer_norm.cpp
 @register_meta(aten.native_layer_norm_backward.default)
 def meta_ln_backward(dY: torch.Tensor, input: torch.Tensor, normalized_shape, mean, rstd, weight, bias,
                      grad_input_mask):

@@ -261,13 +277,18 @@ def meta_ln_backward(dY: torch.Tensor, input: torch.Tensor, normalized_shape, me
     return dX, dgamma, dbeta


-@register_meta(aten._adaptive_avg_pool2d_backward.default)
-def meta_adaptive_avg_pool2d_backward(
-    grad_output: torch.Tensor,
-    input: torch.Tensor,
-):
-    grad_input = torch.empty_like(input)
-    return torch.empty_like(input)
+# ================================== Misc ==========================================
+# https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/native_functions.yaml
+@register_meta(aten.roll.default)
+def meta_roll(input: torch.Tensor, shifts, dims):
+    return input
+
+
+# https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/TensorCompare.cpp
+@register_meta(aten.where.self)
+def meta_where_self(condition: torch.Tensor, self: torch.Tensor, other: torch.Tensor):
+    result_type = torch.result_type(self, other)
+    return torch.empty_like(self, dtype=result_type)


 @register_meta(aten.index.Tensor)

@@ -360,6 +381,8 @@ def meta_index_Tensor(self, indices):
     return self.new_empty(before_shape + replacement_shape + after_shape)


+# ============================== Embedding =========================================
+# https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/Embedding.cpp
 @register_meta(aten.embedding_dense_backward.default)
 def meta_embedding_dense_backward(grad_output: torch.Tensor, indices: torch.Tensor, num_weights, padding_idx,
                                   scale_grad_by_freq):

@@ -369,13 +392,7 @@ def meta_embedding_dense_backward(grad_output: torch.Tensor, indices: torch.Tens
                                   layout=grad_output.layout)


-# https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/TensorCompare.cpp
-@register_meta(aten.where.self)
-def meta_where_self(condition: torch.Tensor, self: torch.Tensor, other: torch.Tensor):
-    result_type = torch.result_type(self, other)
-    return torch.empty_like(self, dtype=result_type)
-
-
+# ============================== Dropout ===========================================
 # https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/Dropout.cpp
 @register_meta(aten.native_dropout.default)
 def meta_native_dropout_default(input: torch.Tensor, p: float, train: bool = False):
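Each `register_meta` entry above is a shape-only kernel: it returns empty tensors of the correct size so the op can run on the `meta` device without allocating storage or doing real compute. A minimal illustration in plain PyTorch, assuming a meta kernel for convolution is registered (which is exactly what this patch provides on PyTorch 1.12 and below):

    import torch

    # meta tensors carry only shape/dtype metadata; no storage is allocated
    x = torch.rand(32, 3, 224, 224, device='meta')
    w = torch.rand(64, 3, 7, 7, device='meta')

    # dispatches to the meta kernel, which computes output sizes only
    y = torch.nn.functional.conv2d(x, w, stride=2, padding=3)
    print(y.shape)    # torch.Size([32, 64, 112, 112])
    print(y.is_meta)  # True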
@@ -1,16 +1,20 @@
-from typing import List, Tuple
 import copy
-import torch
-from torch.fx import GraphModule, Node
-from colossalai.fx.graph_module import ColoGraphModule
-from colossalai.fx.profiler import parameter_size
 import math
-from .linearize import linearize
-from .operation import ForwardCheck, ForwardEnable, ForwardNograd, Backward, Loss, Chain, Sequence, Function, Offload, Prefetch
+from typing import List, Tuple
+
+import torch
+from colossalai.fx import is_compatible_with_meta
+from colossalai.fx.codegen.activation_checkpoint_codegen import \
+    _find_nested_ckpt_regions
+from colossalai.fx.graph_module import ColoGraphModule
+from colossalai.fx.passes.algorithms.ckpt_solver_rotor import (_compute_table, _construct_chain, _rec)
 from colossalai.fx.passes.meta_info_prop import MetaInfoProp
-from colossalai.fx.codegen.activation_checkpoint_codegen import _find_nested_ckpt_regions
-from colossalai.fx.passes.algorithms.ckpt_solver_rotor import _construct_chain, _compute_table, _rec
-from colossalai import META_COMPATIBILITY
+from colossalai.fx.profiler import parameter_size
+from torch.fx import GraphModule, Node
+
+from .linearize import linearize
+from .operation import (Backward, Chain, ForwardCheck, ForwardEnable, ForwardNograd, Function, Loss, Offload, Prefetch,
+                        Sequence)

 INF = float("inf")

@@ -508,7 +512,7 @@ def solver_pofo(gm: ColoGraphModule,
     mem_limit -= parameter_size(gm)

     # prepare data
-    if META_COMPATIBILITY:
+    if is_compatible_with_meta():
         from colossalai.fx.profiler import MetaTensor
         data = MetaTensor(data, fake_device=next(gm.parameters()).device)
         MetaInfoProp(gm).run(data)
@@ -1,13 +1,12 @@
 from dataclasses import asdict
-from colossalai.fx.profiler import GraphInfo
+from typing import Any, Dict, List, NamedTuple, Optional, Tuple

 import torch
 import torch.fx
-from torch.fx.node import Node, Argument, Target
+from colossalai.fx._compatibility import compatibility
+from colossalai.fx.profiler import (GraphInfo, profile_function, profile_method, profile_module)
+from torch.fx.node import Argument, Node, Target
 from torch.utils._pytree import tree_flatten
-from typing import Any, List, Tuple, NamedTuple, Dict, Optional
-from torch.fx._compatibility import compatibility
-from colossalai.fx.profiler import profile_function, profile_module, profile_method, activation_size
-from torch.fx.graph_module import GraphModule


 @compatibility(is_backward_compatible=True)

@@ -1,11 +1,13 @@
 from dataclasses import asdict
+from typing import Any, Dict, List, NamedTuple, Tuple

 import torch
 import torch.fx
-from torch.fx.node import Node, Argument, Target
+from colossalai.fx._compatibility import compatibility
+from colossalai.fx.profiler import (GraphInfo, activation_size, calculate_fwd_in, calculate_fwd_out, calculate_fwd_tmp,
+                                    profile_function, profile_method, profile_module)
+from torch.fx.node import Argument, Node, Target
 from torch.utils._pytree import tree_map
-from typing import Any, List, Tuple, NamedTuple, Dict
-from torch.fx._compatibility import compatibility
-from colossalai.fx.profiler import GraphInfo, profile_function, profile_module, profile_method, activation_size, calculate_fwd_out, calculate_fwd_tmp, calculate_fwd_in


 @compatibility(is_backward_compatible=True)
@@ -1,11 +1,12 @@
-from ... import META_COMPATIBILITY
-if META_COMPATIBILITY:
+from .._compatibility import is_compatible_with_meta
+
+if is_compatible_with_meta():
+    from .memory import calculate_fwd_in, calculate_fwd_out, calculate_fwd_tmp
     from .opcount import flop_mapping
-    from .tensor import MetaTensor
     from .profiler import profile_function, profile_method, profile_module
-    from .memory import calculate_fwd_in, calculate_fwd_tmp, calculate_fwd_out
+    from .tensor import MetaTensor
 else:
     from .experimental import meta_profiler_function, meta_profiler_module, profile_function, profile_method, profile_module, calculate_fwd_in, calculate_fwd_tmp, calculate_fwd_out

 from .dataflow import GraphInfo
-from .memory import parameter_size, activation_size, is_inplace
+from .memory import activation_size, is_inplace, parameter_size
@@ -1,79 +0,0 @@
-import torch
-from operator import add, floordiv, getitem, mul, neg, setitem, sub, pos
-from . import META_COMPATIBILITY
-
-__all__ = []
-
-if META_COMPATIBILITY:
-    aten = torch.ops.aten
-
-    ALIAS_ATEN = [
-        # inplace reshaping
-        aten.detach.default,
-        aten.t.default,
-        aten.transpose.int,
-        aten.view.default,
-        aten._unsafe_view.default,
-        aten._reshape_alias.default,
-    ]
-
-    INPLACE_NEW = [
-        aten.empty_like.default,
-        aten.new_empty_strided.default,
-    ]
-
-    INPLACE_MATH_ATEN = [
-        aten.add_.Tensor,
-        aten.sub_.Tensor,
-        aten.div_.Tensor,
-        aten.div_.Scalar,
-        aten.mul_.Tensor,
-        aten.bernoulli_.float,
-    ]
-
-    CLONE_ATEN = [
-        aten.clone.default,
-    ]
-
-    __all__ += ['INPLACE_ATEN', 'INPLACE_MATH_ATEN', 'CLONE_ATEN']
-
-else:
-    # TODO fill out the inplace ops
-    INPLACE_OPS = [
-        add,
-        sub,
-        mul,
-        floordiv,
-        neg,
-        pos,
-        getitem,
-        setitem,
-        getattr,
-        torch.Tensor.cpu,
-    ]
-
-    # TODO: list all call_methods that are inplace here
-    INPLACE_METHOD = [
-        'transpose',
-        'permute',
-        # TODO: reshape may return a copy of the data if the data is not contiguous
-        'reshape',
-        'dim',
-        'flatten',
-        'size',
-        'view',
-        'unsqueeze',
-        'to',
-        'type',
-        'flatten',
-    ]
-
-    # TODO: list all call_methods that are not inplace here
-    NON_INPLACE_METHOD = [
-        'chunk',
-        'contiguous',
-        'expand',
-        'mean',
-        'split',
-    ]
-
-    __all__ += ['INPLACE_OPS', 'INPLACE_METHOD', 'NON_INPLACE_METHOD']
@@ -0,0 +1,32 @@
+import torch
+
+__all__ = ['ALIAS_ATEN', 'INPLACE_NEW', 'INPLACE_MATH_ATEN', 'CLONE_ATEN']
+
+aten = torch.ops.aten
+
+ALIAS_ATEN = [
+    aten.detach.default,
+    aten.t.default,
+    aten.transpose.int,
+    aten.view.default,
+    aten._unsafe_view.default,
+    aten._reshape_alias.default,
+]
+
+INPLACE_NEW = [
+    aten.empty_like.default,
+    aten.new_empty_strided.default,
+]
+
+INPLACE_MATH_ATEN = [
+    aten.add_.Tensor,
+    aten.sub_.Tensor,
+    aten.div_.Tensor,
+    aten.div_.Scalar,
+    aten.mul_.Tensor,
+    aten.bernoulli_.float,
+]
+
+CLONE_ATEN = [
+    aten.clone.default,
+]
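These op lists feed the profiler's memory accounting; for example `is_inplace` in `memory.py` (hunk further below) treats any `call_function` node whose aten target appears in `ALIAS_ATEN` as aliasing its input rather than allocating. A toy check, assuming the new module resolves as `colossalai.fx.profiler.constants`:

    import torch
    from colossalai.fx.profiler.constants import ALIAS_ATEN, INPLACE_MATH_ATEN

    # view-like ops share storage with their input...
    assert torch.ops.aten.view.default in ALIAS_ATEN
    # ...and in-place math mutates its operand instead of allocating an output
    assert torch.ops.aten.add_.Tensor in INPLACE_MATH_ATEN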
@@ -2,7 +2,10 @@ from dataclasses import dataclass, field
 from enum import Enum
 from functools import partial
 from typing import Dict, List

 from torch.fx import Graph, Node

+from .._compatibility import compatibility
 from .memory import activation_size, is_inplace

@@ -12,6 +15,7 @@ class Phase(Enum):
     PLACEHOLDER = 2


+@compatibility(is_backward_compatible=True)
 @dataclass
 class GraphInfo:
     """

@@ -69,6 +73,7 @@ def is_phase(n: Node, phase: Phase) -> bool:
     return n.meta['phase'] == phase


+@compatibility(is_backward_compatible=False)
 def autograd_graph_analysis(graph: Graph) -> GraphInfo:
     """Analyze the autograd node dependencies and find out the memory usage.
     Basically the input graph should have all nodes marked for keyword `phase`.
@@ -1,5 +1,5 @@
-from .registry import meta_profiler_function, meta_profiler_module
-from .memory import calculate_fwd_in, calculate_fwd_tmp, calculate_fwd_out
+from .memory import calculate_fwd_in, calculate_fwd_out, calculate_fwd_tmp
+from .profiler import profile_function, profile_method, profile_module
 from .profiler_function import *
 from .profiler_module import *
-from .profiler import profile_function, profile_method, profile_module
+from .registry import meta_profiler_function, meta_profiler_module
@@ -0,0 +1,44 @@
+from operator import add, floordiv, getitem, mul, neg, pos, setitem, sub
+
+import torch
+
+__all__ = ['INPLACE_OPS', 'INPLACE_METHOD', 'NON_INPLACE_METHOD']
+
+# TODO fill out the inplace ops
+INPLACE_OPS = [
+    add,
+    sub,
+    mul,
+    floordiv,
+    neg,
+    pos,
+    getitem,
+    setitem,
+    getattr,
+    torch.Tensor.cpu,
+]
+
+# TODO: list all call_methods that are inplace here
+INPLACE_METHOD = [
+    'transpose',
+    'permute',
+    # TODO: reshape may return a copy of the data if the data is not contiguous
+    'reshape',
+    'dim',
+    'flatten',
+    'size',
+    'view',
+    'unsqueeze',
+    'to',
+    'type',
+    'flatten',
+]
+
+# TODO: list all call_methods that are not inplace here
+NON_INPLACE_METHOD = [
+    'chunk',
+    'contiguous',
+    'expand',
+    'mean',
+    'split',
+]
@@ -1,11 +1,15 @@
 # for PyTorch 1.11 compatibility uses
+from typing import Dict, List, Tuple, Union
+
 import torch
-from torch.fx import Node, GraphModule
-from typing import Union, Dict, List, Tuple
+from torch.fx import GraphModule, Node
+
+from ..._compatibility import compatibility

 __all__ = ["calculate_fwd_in", "calculate_fwd_tmp", "calculate_fwd_out"]


+@compatibility(is_backward_compatible=True)
 def calculate_fwd_in(n: Node) -> bool:
     """A helper function to calculate `fwd_in`

@@ -18,6 +22,7 @@ def calculate_fwd_in(n: Node) -> bool:
     return n.meta['save_fwd_in']


+@compatibility(is_backward_compatible=True)
 def calculate_fwd_tmp(n: Node) -> int:
     """A helper function to calculate `fwd_tmp`

@@ -30,6 +35,7 @@ def calculate_fwd_tmp(n: Node) -> int:
     return n.meta["fwd_mem_tmp"]


+@compatibility(is_backward_compatible=True)
 def calculate_fwd_out(n: Node) -> int:
     """A helper function to calculate `fwd_out`
@@ -1,15 +1,19 @@
 from dataclasses import dataclass
-from typing import Callable, Any, Dict, Tuple
+from typing import Any, Callable, Dict, Tuple

 import torch
 from torch.fx.node import Argument, Target
-from . import meta_profiler_function, meta_profiler_module
+
+from ..._compatibility import compatibility
 from ..memory import activation_size
-from ..constant import INPLACE_METHOD, NON_INPLACE_METHOD, INPLACE_OPS
+from .constants import INPLACE_METHOD, INPLACE_OPS, NON_INPLACE_METHOD
+from .registry import meta_profiler_function, meta_profiler_module

 __all__ = ['profile_function', 'profile_module', 'profile_method']


 # this is for compatibility use
+@compatibility(is_backward_compatible=True)
 @dataclass
 class GraphInfo:
     """

@@ -69,6 +73,7 @@ def profile_YOUR_MODULE(self: torch.nn.Module, input: torch.Tensor) -> Tuple[int
     """


+@compatibility(is_backward_compatible=True)
 def profile_function(target: 'Target') -> Callable:
     """
     Wrap a `call_function` node or `torch.nn.functional` in order to

@@ -106,6 +111,7 @@ def profile_function(target: 'Target') -> Callable:
     return f


+@compatibility(is_backward_compatible=True)
 def profile_method(target: 'Target') -> Callable:
     """
     Wrap a `call_method` node

@@ -133,6 +139,7 @@ def profile_method(target: 'Target') -> Callable:
     return f


+@compatibility(is_backward_compatible=True)
 def profile_module(module: torch.nn.Module) -> Callable:
     """
     Wrap a `call_module` node or `torch.nn` in order to
@@ -2,7 +2,6 @@ import operator
 from typing import Any, Tuple
 import torch
 from ..registry import meta_profiler_function
-from colossalai.fx.proxy import ColoProxy


 @meta_profiler_function.register(operator.getitem)
@@ -1,13 +1,16 @@
+from typing import Dict, List, Tuple, Union
+
 import torch
-from torch.fx import Node, GraphModule
-from typing import Union, Dict, List, Tuple
-from . import META_COMPATIBILITY
+from torch.fx import GraphModule, Node
+
+from .._compatibility import compatibility, is_compatible_with_meta

 __all__ = [
     'activation_size', 'parameter_size', 'is_inplace', "calculate_fwd_in", "calculate_fwd_tmp", "calculate_fwd_out"
 ]


+@compatibility(is_backward_compatible=True)
 def activation_size(out: Union[torch.Tensor, Dict, List, Tuple, int]) -> int:
     """Calculate activation size of a node.

@@ -29,6 +32,7 @@ def activation_size(out: Union[torch.Tensor, Dict, List, Tuple, int]) -> int:
     return act_size


+@compatibility(is_backward_compatible=True)
 def parameter_size(mod: torch.nn.Module) -> int:
     """Calculate parameter size of a node.

@@ -111,8 +115,8 @@ def is_inplace(n: Node):
     inplace = False
     if n.op == "call_function":
         inplace = n.kwargs.get("inplace", False)
-        if META_COMPATIBILITY:
-            from .constant import ALIAS_ATEN
+        if is_compatible_with_meta():
+            from .constants import ALIAS_ATEN
             if n.target in ALIAS_ATEN:
                 inplace = True
     elif n.op == "call_module":
@@ -1,10 +1,11 @@
 # adopted from https://github.com/facebookresearch/fvcore/blob/main/fvcore/nn/jit_handles.py
 # ideas from https://pastebin.com/AkvAyJBw
-from functools import partial, reduce
 import operator
-from typing import Callable, List, Any
+from functools import partial, reduce
 from numbers import Number
+from typing import Any, Callable, List

 import torch

 aten = torch.ops.aten
@@ -1,16 +1,19 @@
+import time
 from functools import partial
-from typing import Callable, Any, Dict, Tuple
+from typing import Any, Callable, Dict, Tuple

 import torch
-from torch.nn.parameter import Parameter
 from torch.fx import Graph, Node
 from torch.fx.node import Argument, Target
+from torch.nn.parameter import Parameter
 from torch.utils._pytree import tree_map
-from .dataflow import autograd_graph_analysis, is_phase, Phase, GraphInfo
+
+from .._compatibility import compatibility
+from .constants import ALIAS_ATEN
+from .dataflow import GraphInfo, Phase, autograd_graph_analysis, is_phase
 from .memory import activation_size, parameter_size
-from .constant import ALIAS_ATEN
-from .tensor import MetaTensor
 from .opcount import flop_mapping
-import time
+from .tensor import MetaTensor

 __all__ = ['profile_function', 'profile_module', 'profile_method']

@@ -41,6 +44,7 @@ def detach_variables(x):
     return x


+@compatibility(is_backward_compatible=True)
 def _profile_concrete(target: Callable, *args, **kwargs) -> Tuple[Tuple[Any, ...], GraphInfo]:
     """Profile a Callable function with args and kwargs on concrete devices by https://github.com/Cypher30
     To profile the actual forward memory, we first run target in the context torch.no_grad() to get

@@ -140,6 +144,7 @@ def _profile_concrete(target: Callable, *args, **kwargs) -> Tuple[Tuple[Any, ...
     return tree_map(detach_variables, out), graphinfo


+@compatibility(is_backward_compatible=False)
 def _profile_meta(target: Callable, *args, **kwargs) -> Tuple[Tuple[Any, ...], GraphInfo]:
     """
     Profile a Callable function with args and kwargs on meta devices.

@@ -277,6 +282,7 @@ def _profile_meta(target: Callable, *args, **kwargs) -> Tuple[Tuple[Any, ...], G
     return tree_map(unwrap, out), graph_info


+@compatibility(is_backward_compatible=True)
 def profile_function(target: 'Target', device: str = 'meta') -> Callable:
     """
     Wrap a `call_function` node or `torch.nn.functional` in order to

@@ -335,6 +341,7 @@ def profile_function(target: 'Target', device: str = 'meta') -> Callable:
     return f


+@compatibility(is_backward_compatible=True)
 def profile_method(target: 'Target', device: str = 'meta') -> Callable:
     """
     Wrap a `call_method` node

@@ -353,6 +360,7 @@ def profile_method(target: 'Target', device: str = 'meta') -> Callable:
     return f


+@compatibility(is_backward_compatible=True)
 def profile_module(module: torch.nn.Module, device: str = 'meta') -> Callable:
     """
     Wrap a `call_module` node or `torch.nn` in order to
@@ -1,10 +1,13 @@
+import uuid
 from copy import deepcopy
 from typing import Optional

 import torch
-from torch.utils._pytree import tree_map, tree_flatten
-from torch.types import _bool, _dtype, _device
-import uuid
-from .constant import ALIAS_ATEN
+from torch.types import _bool, _device, _dtype
+from torch.utils._pytree import tree_flatten, tree_map
+
+from .._compatibility import compatibility
+from .constants import ALIAS_ATEN

 __all__ = ['MetaTensor']

@@ -15,6 +18,7 @@ def set_uuid(x):
     setattr(x, 'uuid', uuid.uuid4())


+@compatibility(is_backward_compatible=False)
 class MetaTensor(torch.Tensor):
     """
     A wrapping tensor that hacks `torch.autograd` without patching more `torch.ops.aten` ops.
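For orientation, this is how `MetaTensor` is driven elsewhere in this diff: a minimal sketch in which the `fake_device` argument makes a meta-backed tensor report a concrete device, so device checks against module parameters still pass during shape propagation:

    import torch
    from colossalai.fx.profiler import MetaTensor

    # real tensor used only as a shape/dtype template
    data = torch.rand(8, 3, 224, 224)

    # storage lives on 'meta', but the wrapper pretends to be on cuda:0
    data = MetaTensor(data, fake_device='cuda:0')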
@@ -1,23 +1,19 @@
-import threading
-from enum import Enum
-from typing import List, Any, Tuple, Dict, Callable
-from functools import partial
-from abc import ABC, abstractmethod
-import math
 import inspect
+import math
+import threading
+from abc import ABC, abstractmethod
+from enum import Enum
+from functools import partial
+from typing import Any, Callable, Dict, List, Tuple

 import torch
-from torch import nn
 import torch.distributed.rpc as rpc
-from torch.futures import Future
-from torch._C._distributed_rpc import PyRRef
-from torch import autograd
-from torch import optim
 from colossalai.pipeline.pipeline_process_group import ppg
-from colossalai.pipeline.rpc.utils import (color_debug, tensor_shape_list, get_batch_lengths, split_batch, type_detail,
-                                           pytree_map, pytree_filter, get_real_args_kwargs, use_color_debug)
+from colossalai.pipeline.rpc.utils import (get_batch_lengths, get_real_args_kwargs, pytree_filter, pytree_map,
+                                           split_batch, tensor_shape_list, type_detail)
+from torch import autograd, nn, optim
+from torch._C._distributed_rpc import PyRRef
+from torch.futures import Future


 class Phase(Enum):
@@ -195,7 +191,6 @@ class WorkerBase(ABC):
         if isinstance(output, Future):
             output = output.wait()

-        # color_debug(f'rank {self.pp_rank}, output {type(output)}', 'get output', 'red')
         output_work_item.refcount += 1

         # all consumers have been satisfied, the work_item can be released

@@ -250,9 +245,6 @@ class WorkerBase(ABC):
                              self.num_microbatches, forward_only)
         with self.work_list_condition_lock:
             self.work_list[key] = work_item
-            if use_color_debug:
-                color_debug(f'rank {self.pp_rank} receive data from dataloader {self._get_store_len()}',
-                            'data dispatch', 'magenta')
             self.work_list_condition_lock.notify_all()

         # just for last pp_rank

@@ -273,9 +265,6 @@ class WorkerBase(ABC):
             work_item = WorkItem(self.pp_rank, Phase.BACKWARD, grad_wrt_loss, {}, output, microbatch_id, None,
                                  self.num_microbatches, False)

-            if use_color_debug:
-                color_debug(f'rank {self.pp_rank} propose backward', 'data dispatch', 'magenta')
-
             self.work_list[key] = work_item
             self.work_list_condition_lock.notify_all()
@@ -297,23 +286,14 @@ class WorkerBase(ABC):
             producer_worker_rref = self.pp_rank_to_worker_rref[producer_stage_id]
             subscribe_forward_futures[i] = producer_worker_rref.rpc_async().get_output_by_key(producer_output_key)

-        if use_color_debug:
-            color_debug(f'rank {self.pp_rank} get {len(subscribe_forward_futures)} futs from its producer',
-                        'data dispatch', 'magenta')
-
         work_item_from_producer = WorkItem(stage_id, Phase.FORWARD, subscribe_forward_futures, {}, output,
                                            microbatch_id, None, self.num_microbatches, forward_only)

-        # color_debug(f'rank {self.pp_rank} get value {tensor_shape_list(args)} from fut', 'data dispatch', 'magenta')
         # add work_item to work_list
         with self.work_list_condition_lock:
             key = UniqueKey(microbatch_id, Phase.FORWARD)
             assert key not in self.work_list
             self.work_list[key] = work_item_from_producer
-            if use_color_debug:
-                color_debug(
-                    f'rank_{self.pp_rank} load a new task to its work_list {key} {work_item_from_producer.phase} data: {tensor_shape_list(work_item_from_producer.args)}',
-                    'data dispatch', 'magenta')
             self.work_list_condition_lock.notify_all()

     def subscribe_consumer(self, microbatch_id: int):

@@ -328,10 +308,6 @@ class WorkerBase(ABC):
         subscribe_backward_futures: List[Future] = [None] * consumer_num
         output = self._get_future_by_device()

-        if use_color_debug:
-            color_debug(f'rank {self.pp_rank} get {len(subscribe_backward_futures)} futs from its consumer',
-                        'data dispatch', 'magenta')
-
         for i in range(consumer_num):
             consumer_stage_id = self.consumer_stage_ids[i]
             consumer_output_key = UniqueKey(microbatch_id, Phase.BACKWARD)

@@ -342,17 +318,11 @@ class WorkerBase(ABC):
         work_item_from_consumer = WorkItem(stage_id, Phase.BACKWARD, subscribe_backward_futures, {}, output,
                                            microbatch_id, None, self.num_microbatches, False)

-        # color_debug(f'rank {self.pp_rank} get value {tensor_shape_list(args)} from fut', 'data dispatch', 'magenta')
-
         # add work_item to work_list
         with self.work_list_condition_lock:
             key = UniqueKey(microbatch_id, Phase.BACKWARD)
             assert key not in self.work_list
             self.work_list[key] = work_item_from_consumer
-            if use_color_debug:
-                color_debug(
-                    f'rank_{self.pp_rank} load a new task to its work_list {key} {work_item_from_consumer.phase} data: {tensor_shape_list(work_item_from_consumer.args)}',
-                    'data dispatch', 'magenta')
             self.work_list_condition_lock.notify_all()

     def _get_producer_consumer(self) -> None:
@@ -406,11 +376,6 @@ class WorkerBase(ABC):
         is_first_stage = self.is_first_stage()
         is_last_stage = self.is_last_stage()

-        # if self.pp_rank == 0:
-        #     print(
-        #         f'I am rank_{self.pp_rank} microbatch_id : {microbatch_id} {phase} {self._get_store_len()} | {self.outstanding} {self.outstanding_range}'
-        #     )
-
         if phase == Phase.FORWARD:
             # remind its consumer to get data before forward
             if not is_last_stage:

@@ -470,8 +435,6 @@ class WorkerBase(ABC):
             else:
                 consume_result = self.module_partition(*args, **kwargs)

-            # print(f'model{self.pp_rank + 1}(param_sum: {sum([p.sum().item() for p in self.module_partition.parameters()])}) input sum: {args[0].sum().item()} forward output sum: {consume_result.sum().item()}', )
-
             if is_last_stage and self.criterion:
                 with self.label_lock:
                     self.label_lock.wait_for(lambda: microbatch_id in self.microbatch_id_to_labels)

@@ -539,10 +502,6 @@ class WorkerBase(ABC):
                 pytree_map(stage_input_args, lambda x: consume_result.append(x.grad), process_types=torch.Tensor)
                 pytree_map(stage_input_kwargs, lambda x: consume_result.append(x.grad), process_types=torch.Tensor)

-                # for input_node in stage_input_args:
-                #     if isinstance(input_node, torch.Tensor):
-                #         consume_result.append(input_node.grad)
-
         else:
             raise TypeError(f"Unknown phase appears in _consume_work_item_by_phase {phase}")

@@ -593,11 +552,6 @@ class WorkerBase(ABC):
             with self.work_list_condition_lock:
                 work_item = self.work_list.pop(work_item_key)

-            if use_color_debug:
-                color_debug(
-                    f'rank {self.pp_rank} get a key : {work_item_key} work_item args: {tensor_shape_list(work_item.args)} {self._get_store_len()}',
-                    'work loop', 'green')
-
             with self.output_list_condition_lock:
                 # assert work_item_key not in self.output_list
                 self.output_list[work_item_key] = work_item

@@ -605,11 +559,6 @@ class WorkerBase(ABC):

             consume_result = self._consume_work_item_by_phase(work_item)

-            if use_color_debug:
-                color_debug(
-                    f'rank_{self.pp_rank} [{work_item.phase}] finish consuming, result is {tensor_shape_list(consume_result)} {self._get_store_len()} | {self.work_list.keys()} | {self.output_list.keys()}',
-                    'work loop', 'green')
-
             work_item.output.set_result(consume_result)

             # if is last step in one batch reset context and do step
@@ -1,13 +1,12 @@
-from typing import List, Callable, Dict
 import threading
+from typing import Callable, Dict, List
+
 import torch
 import torch.distributed as dist
-from torch.futures import Future
-from torch._C._distributed_rpc import PyRRef
-
-from colossalai.pipeline.rpc._pipeline_base import PipelineEngineBase, WorkerBase, UniqueKey, Phase, WorkItem
 from colossalai.pipeline.pipeline_process_group import ppg
+from colossalai.pipeline.rpc._pipeline_base import (Phase, PipelineEngineBase, UniqueKey, WorkerBase, WorkItem)
+from torch._C._distributed_rpc import PyRRef
+from torch.futures import Future

 # Implementation of different Pipeline schedule
 # <strategy>Worker defines the worker for each stage
@@ -1,25 +1,15 @@
-from typing import List, Any, Tuple, Dict, Callable, Type, Union
+import argparse
 import os
 import warnings
-import argparse
+from typing import Any, Callable, Dict, List, Tuple, Type, Union

 import torch
-import torch.multiprocessing as mp
-from torch.futures import Future
 import torch.distributed.rpc as rpc
-from torch._C._distributed_rpc import _is_current_rpc_agent_set
-from colorama import Back, Style
+import torch.multiprocessing as mp

 from colossalai.initialize import launch
 from colossalai.pipeline.pipeline_process_group import ppg
+from torch._C._distributed_rpc import _is_current_rpc_agent_set
+from torch.futures import Future

-# config for debug and test
-use_color_debug = False
-
-
-def color_debug(text, prefix=' ', color='blue'):
-    color = color.upper()
-    print(getattr(Back, color), prefix, Style.RESET_ALL, text)
-

 def pytree_map(obj: Any, fn: Callable, process_types: Union[Type, Tuple[Type]] = (), map_all: bool = False) -> Any:
@@ -1,18 +1,20 @@
 import copy
+
+import colossalai
+import pytest
 import torch
+import torch.fx
 import torch.multiprocessing as mp
 import torchvision.models as tm
-import torch.fx
-import colossalai
-from colossalai.fx.passes.meta_info_prop import MetaInfoProp
+from colossalai.core import global_context as gpc
 from colossalai.fx import ColoGraphModule, ColoTracer
+from colossalai.fx._compatibility import is_compatible_with_meta
 from colossalai.fx.passes.algorithms import solver_rotor
 from colossalai.fx.passes.algorithms.operation import Sequence
-from colossalai.core import global_context as gpc
+from colossalai.fx.passes.meta_info_prop import MetaInfoProp
 from colossalai.utils import free_port
-import pytest
-from colossalai import META_COMPATIBILITY
-if META_COMPATIBILITY:
+
+if is_compatible_with_meta():
     from colossalai.fx.profiler.tensor import MetaTensor

 try:

@@ -34,7 +36,7 @@ def _run_C_solver_consistency_test(rank=0):
     graph = tracer.trace(model, meta_args={"x": data})
     graph.set_codegen(ActivationCheckpointCodeGen())
     gm = ColoGraphModule(model, graph, model.__class__.__name__)
-    if META_COMPATIBILITY:
+    if is_compatible_with_meta():
         data_meta = MetaTensor(data, fake_device=next(gm.parameters()).device)
         MetaInfoProp(gm).run(data_meta)
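Taken together, this test exercises the canonical flow the patch standardizes: trace with meta args, then run `MetaInfoProp` over a `MetaTensor` input. As a free-standing sketch (a ResNet-18 stands in for any traceable model; `fake_device='cpu'` is an assumption for a CPU-only run):

    import torch
    import torchvision.models as tm
    from colossalai.fx import ColoGraphModule, ColoTracer
    from colossalai.fx._compatibility import is_compatible_with_meta
    from colossalai.fx.passes.meta_info_prop import MetaInfoProp

    model = tm.resnet18()
    data = torch.rand(2, 3, 224, 224, device='meta')

    graph = ColoTracer().trace(model, meta_args={"x": data})
    gm = ColoGraphModule(model, graph, model.__class__.__name__)

    if is_compatible_with_meta():
        from colossalai.fx.profiler import MetaTensor
        # annotates each node with memory/FLOP estimates, without real compute
        MetaInfoProp(gm).run(MetaTensor(data, fake_device='cpu'))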
@@ -1,20 +1,22 @@
-from typing import Callable
 import copy
 import re
+from typing import Callable
+
+import colossalai
+import pytest
 import torch
 import torch.multiprocessing as mp
 import torchvision.models as tm
-from torch.fx import GraphModule
-import colossalai
-from colossalai.fx import ColoTracer
-from colossalai.fx.graph_module import ColoGraphModule
-from colossalai.fx.passes.meta_info_prop import MetaInfoProp
-from colossalai.fx.passes.algorithms import chen_greedy, solver_rotor
-from colossalai.utils import free_port
 from colossalai.core import global_context as gpc
-import pytest
-from colossalai import META_COMPATIBILITY
-if META_COMPATIBILITY:
+from colossalai.fx import ColoTracer
+from colossalai.fx._compatibility import is_compatible_with_meta
+from colossalai.fx.graph_module import ColoGraphModule
+from colossalai.fx.passes.algorithms import chen_greedy, solver_rotor
+from colossalai.fx.passes.meta_info_prop import MetaInfoProp
+from colossalai.utils import free_port
+from torch.fx import GraphModule
+
+if is_compatible_with_meta():
     from colossalai.fx.profiler.tensor import MetaTensor

 try:

@@ -54,8 +56,9 @@ def _is_graph_linearized(gm: GraphModule):
 def check_backward_consistency(m: torch.nn.Module, gm: GraphModule, solver: Callable[[GraphModule], GraphModule],
                                model_cls: Callable[[], torch.nn.Module]):
     criterion = torch.nn.MSELoss()
-    data = torch.rand(2, 3, 32, 32)
-    label = torch.rand(2, 5)
+    m.cuda()
+    data = torch.rand(2, 3, 32, 32).cuda()
+    label = torch.rand(2, 5).cuda()
     loss = criterion(m(data), label)
     loss.backward()
     loss = criterion(gm(data), label)

@@ -77,7 +80,7 @@ def _run_ckpt_solver(rank):
     m = model_cls(num_classes=5)
     graph = tracer.trace(root=m)
     gm = ColoGraphModule(copy.deepcopy(m), graph, m.__class__.__name__)
-    MetaInfoProp(gm.cuda()).run(MetaTensor(data, fake_device='cuda'))
+    MetaInfoProp(gm.cuda()).run(MetaTensor(data).cuda())
     codegen = ActivationCheckpointCodeGen()
     gm.graph.set_codegen(codegen)
     if solver == solver_rotor:
@@ -1,13 +1,14 @@
-from colossalai.fx.passes.meta_info_prop import MetaInfoProp
+import pytest
 import torch
 import torchvision.models as tm
 from colossalai.fx import ColoTracer
+from colossalai.fx._compatibility import is_compatible_with_meta
 from colossalai.fx.graph_module import ColoGraphModule
-from colossalai.fx.passes.algorithms import solver_rotor, linearize
-from colossalai.fx.passes.algorithms.operation import Loss, ForwardCheck, ForwardEnable, ForwardNograd
-import pytest
-from colossalai import META_COMPATIBILITY
-if META_COMPATIBILITY:
+from colossalai.fx.passes.algorithms import linearize, solver_rotor
+from colossalai.fx.passes.algorithms.operation import (ForwardCheck, ForwardEnable, ForwardNograd, Loss)
+from colossalai.fx.passes.meta_info_prop import MetaInfoProp
+
+if is_compatible_with_meta():
     from colossalai.fx.profiler.tensor import MetaTensor

 try:
@@ -1,13 +1,17 @@
-import torch
-import torch.nn as nn
 import colossalai
 import colossalai.nn as col_nn
-from torch.fx import symbolic_trace
-from colossalai.fx.passes.meta_info_prop import MetaInfoProp
-from colossalai.fx.passes.adding_split_node_pass import split_with_split_nodes_pass, uniform_split_pass
-from colossalai.fx.passes.utils import get_comm_size
-from colossalai import META_COMPATIBILITY
 import pytest
+import torch
+import torch.nn as nn
+from colossalai.fx._compatibility import is_compatible_with_meta
+from colossalai.fx.passes.adding_split_node_pass import (split_with_split_nodes_pass, uniform_split_pass)
+from colossalai.fx.passes.meta_info_prop import MetaInfoProp
+from colossalai.fx.passes.utils import get_comm_size
+from torch.fx import symbolic_trace
+
+is_compatible = is_compatible_with_meta()
+if is_compatible:
+    from colossalai.fx.profiler import MetaTensor
 
 MODEL_DIM = 16
 BATCH_SIZE = 8

@@ -31,12 +35,12 @@ class MLP(torch.nn.Module):
         return x
 
 
-@pytest.mark.skipif(not META_COMPATIBILITY, reason='torch version is lower than 1.12.0')
 def test_comm_size_compute():
-    from colossalai.fx.profiler import MetaTensor
     model = MLP(MODEL_DIM)
-    input_sample = MetaTensor(torch.rand(BATCH_SIZE, MODEL_DIM, device='meta'), fake_device='cpu')
+    input_sample = torch.rand(BATCH_SIZE, MODEL_DIM, device='meta')
     gm = symbolic_trace(model)
+    if is_compatible:
+        input_sample = MetaTensor(input_sample, fake_device=next(gm.parameters()).device)
     MetaInfoProp(gm).run(input_sample)
     annotated_model = uniform_split_pass(gm, PIPELINE_SIZE)
     split_model, split_submodules = split_with_split_nodes_pass(annotated_model)

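The fallback above keeps the plain `device='meta'` input on older PyTorch and only wraps it in a `MetaTensor` when the compatibility check passes; the fake device then follows the traced module's parameters. Condensed, with an illustrative stand-in module:

import torch
from torch.fx import symbolic_trace

from colossalai.fx._compatibility import is_compatible_with_meta
from colossalai.fx.passes.meta_info_prop import MetaInfoProp

is_compatible = is_compatible_with_meta()
if is_compatible:
    from colossalai.fx.profiler import MetaTensor

model = torch.nn.Linear(16, 16)    # stand-in for the MLP used in this test
input_sample = torch.rand(8, 16, device='meta')
gm = symbolic_trace(model)
if is_compatible:
    # Match the fake device to wherever the traced parameters live.
    input_sample = MetaTensor(input_sample, fake_device=next(gm.parameters()).device)
MetaInfoProp(gm).run(input_sample)
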
@@ -1,12 +1,11 @@
 from typing import Any, Callable, Union
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-from colossalai import META_COMPATIBILITY
 
 import pytest
+import torch
+import torch.nn as nn
+from colossalai.fx._compatibility import is_compatible_with_meta
 
-if META_COMPATIBILITY:
+if is_compatible_with_meta():
     from colossalai.fx.profiler import MetaTensor
 
 aten = torch.ops.aten

@@ -71,7 +70,7 @@ def run_and_compare(f: Union[nn.Module, Callable], x: torch.Tensor, requires_bac
         compare_all(x.grad, meta_x.grad)
 
 
-@pytest.mark.skipif(not META_COMPATIBILITY, reason='torch version is lower than 1.12.0')
+@pytest.mark.skipif(not is_compatible_with_meta(), reason='torch version is lower than 1.12.0')
 def test_meta_aten():
     for (aten_op, requires_backward), v in registered_meta.items():
         for f, x in v:

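Since the flag is now behind a function call, the skip condition is evaluated once when the test module is imported. A toy stand-in for the registered aten-op checks (the op choice and test name here are illustrative, not from the suite):

import pytest
import torch

from colossalai.fx._compatibility import is_compatible_with_meta


@pytest.mark.skipif(not is_compatible_with_meta(), reason='torch version is lower than 1.12.0')
def test_meta_add():
    # Running an aten op on meta tensors yields shape/dtype only, no data.
    x = torch.empty(4, 4, device='meta')
    y = torch.ops.aten.add.Tensor(x, x)
    assert y.is_meta and y.shape == (4, 4)
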
@@ -1,10 +1,10 @@
-import torchvision.models as tm
+import pytest
 import timm.models as tmm
 import torch
-from colossalai import META_COMPATIBILITY
-import pytest
+import torchvision.models as tm
+from colossalai.fx._compatibility import is_compatible_with_meta
 
-if META_COMPATIBILITY:
+if is_compatible_with_meta():
     from colossalai.fx.profiler import MetaTensor
 
 tm_models = [

@@ -27,7 +27,7 @@ tmm_models = [
 ]
 
 
-@pytest.mark.skipif(not META_COMPATIBILITY, reason='torch version is lower than 1.12.0')
+@pytest.mark.skipif(not is_compatible_with_meta(), reason='torch version is lower than 1.12.0')
 def test_torchvision_models():
     for m in tm_models:
         model = m()

@@ -35,7 +35,7 @@ def test_torchvision_models():
         model(MetaTensor(data, fake_device=torch.device('cpu'))).sum().backward()
 
 
-@pytest.mark.skipif(not META_COMPATIBILITY, reason='torch version is lower than 1.12.0')
+@pytest.mark.skipif(not is_compatible_with_meta(), reason='torch version is lower than 1.12.0')
 def test_timm_models():
     for m in tmm_models:
         model = m()

@@ -1,10 +1,10 @@
-import torchvision.models as tm
+import pytest
 import timm.models as tmm
 import torch
-from colossalai import META_COMPATIBILITY
-import pytest
+import torchvision.models as tm
+from colossalai.fx._compatibility import is_compatible_with_meta
 
-if META_COMPATIBILITY:
+if is_compatible_with_meta():
     from colossalai.fx import meta_trace
 
 tm_models = [

@@ -27,7 +27,7 @@ tmm_models = [
 ]
 
 
-@pytest.mark.skipif(not META_COMPATIBILITY, reason='torch version is lower than 1.12.0')
+@pytest.mark.skipif(not is_compatible_with_meta(), reason='torch version is lower than 1.12.0')
 def test_torchvision_models_trace():
     for m in tm_models:
         model = m()

@@ -35,7 +35,7 @@ def test_torchvision_models_trace():
         graph = meta_trace(model, torch.device('cpu'), data)
 
 
-@pytest.mark.skipif(not META_COMPATIBILITY, reason='torch version is lower than 1.12.0')
+@pytest.mark.skipif(not is_compatible_with_meta(), reason='torch version is lower than 1.12.0')
 def test_timm_models_trace():
     for m in tmm_models:
         model = m()

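For reference, a minimal invocation of `meta_trace` in the style exercised above; the model choice and input shape are assumptions, since the `tm_models`/`tmm_models` lists and the data construction fall outside these hunks:

import torch
import torchvision.models as tm

from colossalai.fx._compatibility import is_compatible_with_meta

if is_compatible_with_meta():
    from colossalai.fx import meta_trace

    model = tm.resnet18()
    # Inputs live on the meta device, so tracing touches no real memory;
    # the second argument gives the fake device the trace should assume.
    data = torch.rand(2, 3, 224, 224, device='meta')
    graph = meta_trace(model, torch.device('cpu'), data)
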
@@ -1,7 +1,10 @@
 import torch
-from torch.fx import symbolic_trace
-from colossalai import META_COMPATIBILITY
+from colossalai.fx._compatibility import is_compatible_with_meta
 from colossalai.fx.passes.meta_info_prop import MetaInfoProp, TensorMetadata
+from torch.fx import symbolic_trace
+
+if is_compatible_with_meta():
+    from colossalai.fx.profiler import MetaTensor
 
 BATCH_SIZE = 2
 DIM_IN = 4

@@ -18,8 +21,7 @@ def meta_check(meta_info_spec: TensorMetadata, orig_tensor: torch.Tensor):
 def test_meta_info_prop():
     model = torch.nn.Linear(DIM_IN, DIM_OUT)
     input_sample = torch.rand(BATCH_SIZE, DIM_IN, device='meta')
-    if META_COMPATIBILITY:
-        from colossalai.fx.profiler import MetaTensor
+    if is_compatible_with_meta():
         input_sample = MetaTensor(input_sample, fake_device='cpu')
     orig_output = model(input_sample)
     gm = symbolic_trace(model)

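The same guard appears inline in `test_meta_info_prop`: the `MetaTensor` import moves to module scope and only the wrapping stays conditional. End to end, a sketch with illustrative dimensions:

import torch
from torch.fx import symbolic_trace

from colossalai.fx._compatibility import is_compatible_with_meta
from colossalai.fx.passes.meta_info_prop import MetaInfoProp

if is_compatible_with_meta():
    from colossalai.fx.profiler import MetaTensor

model = torch.nn.Linear(4, 8)                    # DIM_IN=4, DIM_OUT=8 (illustrative)
input_sample = torch.rand(2, 4, device='meta')   # BATCH_SIZE=2 (illustrative)
if is_compatible_with_meta():
    # fake_device='cpu': the tensor reports a CPU device but owns no storage.
    input_sample = MetaTensor(input_sample, fake_device='cpu')
orig_output = model(input_sample)                # forward runs on the meta device
gm = symbolic_trace(model)
MetaInfoProp(gm).run(input_sample)               # annotates the graph with TensorMetadata
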
@@ -1,19 +1,17 @@
-import os
 import argparse
+import os
 import warnings
 
 import torch
-from torch import nn
-import torch.multiprocessing as mp
-import torch.distributed.rpc as rpc
-from torch.optim import SGD, Adam, RMSprop, Optimizer
-from torch._C._distributed_rpc import _is_current_rpc_agent_set
 import torch.distributed as dist
-from colorama import Back, Style
-from colossalai.pipeline.pipeline_process_group import ppg
-from colossalai.logging import disable_existing_loggers
+import torch.distributed.rpc as rpc
+import torch.multiprocessing as mp
 from colossalai import launch
+from colossalai.logging import disable_existing_loggers
+from colossalai.pipeline.pipeline_process_group import ppg
+from torch import nn
+from torch._C._distributed_rpc import _is_current_rpc_agent_set
+from torch.optim import SGD, Adam, Optimizer, RMSprop
 
 rpc_is_initialized = _is_current_rpc_agent_set