mirror of https://github.com/hpcaitech/ColossalAI
[autoparallel] integrate auto parallel feature with new tracer (#3408)
* [autoparallel] integrate new analyzer in module level
* unify the profiling method
* polish
* fix no codegen bug
* fix pass bug
* fix liveness test
* polish
parent
573af84184
commit
ffcdbf0f65
|
@@ -235,7 +235,28 @@ def matmul_flop_jit(inputs: List[Any], outputs: List[Any]) -> Number:
    # Inputs contains the shapes of two matrices.
    input_shapes = [v.shape for v in inputs]
    assert len(input_shapes) == 2, input_shapes
    assert input_shapes[0][-1] == input_shapes[1][-2], input_shapes

    # There are three cases: 1) gemm, 2) gemv, 3) dot
    if all(len(shape) == 2 for shape in input_shapes):
        # gemm
        assert input_shapes[0][-1] == input_shapes[1][-2], input_shapes
    elif all(len(shape) == 1 for shape in input_shapes):
        # dot
        assert input_shapes[0][0] == input_shapes[1][0], input_shapes

        # expand shape
        input_shapes[0] = torch.Size([1, input_shapes[0][0]])
        input_shapes[1] = torch.Size([input_shapes[1][0], 1])
    else:
        # gemv
        if len(input_shapes[0]) == 1:
            assert input_shapes[0][0] == input_shapes[1][-2], input_shapes
            input_shapes.reverse()
        else:
            assert input_shapes[1][0] == input_shapes[0][-1], input_shapes

        # expand the shape of the vector to [batch size, 1]
        input_shapes[-1] = torch.Size([input_shapes[-1][-1], 1])
    flops = reduce(operator.mul, input_shapes[0]) * input_shapes[-1][-1]
    return flops
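For orientation, a small hedged sketch of what this counting rule yields for two of the three cases; the shapes and numbers below are illustrative and not part of the diff.

```python
import operator
from functools import reduce

import torch

# gemm: (4, 8) @ (8, 16) -> counted as prod(shape0) * shape1[-1] = 4 * 8 * 16 = 512
a, b = torch.empty(4, 8), torch.empty(8, 16)
assert reduce(operator.mul, a.shape) * b.shape[-1] == 512

# dot: (8,) . (8,) -> shapes are expanded to (1, 8) and (8, 1), giving 1 * 8 * 1 = 8
v = torch.empty(8)
expanded_0, expanded_1 = torch.Size([1, v.shape[0]]), torch.Size([v.shape[0], 1])
assert reduce(operator.mul, expanded_0) * expanded_1[-1] == 8
```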
@@ -1,8 +1,12 @@
from typing import Any, Callable, Dict, Iterable, List, Tuple

import torch

try:
    from torch.fx.graph import CodeGen
except:
    pass
from torch.fx.graph import (
    CodeGen,
    PythonCode,
    _custom_builtins,
    _format_target,
@@ -48,8 +52,8 @@ def _end_of_ckpt(node: Node, ckpt_level: int) -> bool:
    """
    Check if the node could end the ckpt region at `ckpt_level`
    """
    if len(node.meta['info'].to_recompute) > ckpt_level:
        return node.meta['info'].to_recompute[ckpt_level] is not None
    if len(node.meta['info'].activation_checkpoint) > ckpt_level:
        return node.meta['info'].activation_checkpoint[ckpt_level] is not None
    return True
@@ -90,8 +94,8 @@ def _find_nested_ckpt_regions(node_list: List[Node], ckpt_level: int = 0):
    current_region = None

    for idx, node in enumerate(node_list):
        if len(node.meta['info'].to_recompute) > ckpt_level:
            act_ckpt_label = node.meta['info'].to_recompute[ckpt_level]
        if len(node.meta['info'].activation_checkpoint) > ckpt_level:
            act_ckpt_label = node.meta['info'].activation_checkpoint[ckpt_level]

            # this activation checkpoint label is not set yet
            # meaning this is the first node of the activation ckpt region
@@ -152,12 +156,12 @@ def emit_ckpt_func(body,

    # label given by each layer, e.g. if you are currently at level (0, 1, 1)
    # the label will be '0_1_1'
    label = "_".join([str(idx) for idx in node_list[0].meta['info'].to_recompute[:ckpt_level + 1]])
    label = "_".join([str(idx) for idx in node_list[0].meta['info'].activation_checkpoint[:ckpt_level + 1]])
    ckpt_fn_def = _gen_ckpt_fn_def(label, inputs)
    ckpt_func.append(f'{ckpt_fn_def}\n')

    # if there is more level to fetch
    if ckpt_level + 1 < max(map(lambda node: len(node.meta['info'].to_recompute), node_list)):
    if ckpt_level + 1 < max(map(lambda node: len(node.meta['info'].activation_checkpoint), node_list)):
        ckpt_regions = _find_nested_ckpt_regions(node_list, ckpt_level + 1)
        start_idx = [item[0] for item in ckpt_regions]
        end_idx = [item[1] for item in ckpt_regions]
@@ -215,7 +219,6 @@ def emit_code_with_activation_checkpoint(body, ckpt_func, nodes, emit_node_func,
    ckpt_regions = _find_nested_ckpt_regions(nodes, 0)
    start_idx = [item[0] for item in ckpt_regions]
    end_idx = [item[1] for item in ckpt_regions]

    node_list = list(nodes)

    node_idx = 0
@@ -112,7 +112,7 @@ class MetaInfo:

    # should keep the same whenever manipulated
    # ============================= Invariant ==================================
    to_recompute: Tuple[torch.Tensor] = ()    # (region_0, region_1, ...) support nested codegen
    activation_checkpoint: Tuple[torch.Tensor] = ()    # (region_0, region_1, ...) support nested codegen
    to_offload: Optional[bool] = False
    sharding_spec: str = 'RR'
@@ -237,7 +237,14 @@ class ShapeProp(torch.fx.Interpreter):
        Returns:
            Any: The value returned from executing the Module
        """
        wrap_fn = lambda elem: MetaTensor(elem, device=device)

        # wrap_fn = lambda elem: MetaTensor(elem, device=device)
        def wrap_fn(elem, device=device):
            if isinstance(elem, torch.Tensor):
                return MetaTensor(elem, device=device)
            else:
                return elem

        with self._mode:
            return super().run(*tree_map(wrap_fn, args))
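A hedged illustration of why the lambda is replaced by a guarded function: `tree_map` visits every leaf of `args`, so non-tensor leaves (ints, `None`, strings) must pass through untouched. The sketch below is runnable on its own; `elem.to('meta')` stands in for the `MetaTensor` wrapper used in this file.

```python
import torch
from torch.utils._pytree import tree_map

def wrap_fn(elem, device='meta'):
    # only tensors get converted; scalars, None, etc. are returned unchanged
    if isinstance(elem, torch.Tensor):
        return elem.to(device)    # stand-in for MetaTensor(elem, device=device)
    return elem

args = (torch.rand(2, 2), 3, None)
wrapped = tree_map(wrap_fn, args)
assert wrapped[0].is_meta and wrapped[1] == 3 and wrapped[2] is None
```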
@@ -21,69 +21,111 @@ def linear_impl(input, weight, bias=None):


@register_tracer_impl(F.conv1d, name='_bias_addition_impl')
def conv1d_impl(input, weight, **kwargs):
    bias = getattr(kwargs, 'bias', None)
def conv1d_impl(input, weight, bias=None, stride=_single(1), padding=_single(0), dilation=_single(1), groups=1):
    if bias is None:
        return F.conv1d(input, weight, **kwargs)
        return F.conv1d(input, weight, stride=stride, padding=padding, dilation=dilation, groups=groups)
    else:
        new_kwargs = kwargs
        new_kwargs['bias'] = None
        return F.conv1d(input, weight, **kwargs) + bias.reshape((-1, 1))
        return F.conv1d(input, weight, stride=stride, padding=padding, dilation=dilation, groups=groups) + bias.reshape(
            (-1, 1))


@register_tracer_impl(F.conv2d, name='_bias_addition_impl')
def conv2d_impl(input, weight, **kwargs):
    bias = getattr(kwargs, 'bias', None)
def conv2d_impl(input, weight, bias=None, stride=_pair(1), padding=_pair(0), dilation=_pair(1), groups=1):
    if bias is None:
        return F.conv2d(input, weight, **kwargs)
        return F.conv2d(input, weight, stride=stride, padding=padding, dilation=dilation, groups=groups)
    else:
        new_kwargs = kwargs
        new_kwargs['bias'] = None
        return F.conv2d(input, weight, **kwargs) + bias.reshape((-1, 1, 1))
        return F.conv2d(input, weight, stride=stride, padding=padding, dilation=dilation, groups=groups) + bias.reshape(
            (-1, 1, 1))


@register_tracer_impl(F.conv3d, name='_bias_addition_impl')
def conv3d_impl(input, weight, **kwargs):
    bias = getattr(kwargs, 'bias', None)
def conv3d_impl(input, weight, bias=None, stride=_triple(1), padding=_triple(0), dilation=_triple(1), groups=1):
    if bias is None:
        return F.conv3d(input, weight, **kwargs)
        return F.conv3d(input, weight, stride=stride, padding=padding, dilation=dilation, groups=groups)
    else:
        new_kwargs = kwargs
        new_kwargs['bias'] = None
        return F.conv3d(input, weight, **new_kwargs) + bias.reshape((-1, 1, 1, 1))
        return F.conv3d(input, weight, stride=stride, padding=padding, dilation=dilation, groups=groups) + bias.reshape(
            (-1, 1, 1, 1))


@register_tracer_impl(F.conv_transpose1d, name='_bias_addition_impl')
def conv_transpose1d_impl(input, weight, **kwargs):
    bias = getattr(kwargs, 'bias', None)
def conv_transpose1d_impl(input,
                          weight,
                          bias=None,
                          stride=_single(1),
                          padding=_single(0),
                          output_padding=_single(0),
                          groups=1,
                          dilation=_single(1)):
    if bias is None:
        return F.conv_transpose1d(input, weight, **kwargs)
        return F.conv_transpose1d(input,
                                  weight,
                                  stride=stride,
                                  padding=padding,
                                  output_padding=output_padding,
                                  groups=groups,
                                  dilation=dilation)
    else:
        new_kwargs = kwargs
        new_kwargs['bias'] = None
        return F.conv_transpose1d(input, weight, **new_kwargs) + bias.reshape((-1, 1))
        return F.conv_transpose1d(input,
                                  weight,
                                  stride=stride,
                                  padding=padding,
                                  output_padding=output_padding,
                                  groups=groups,
                                  dilation=dilation) + bias.reshape((-1, 1))


@register_tracer_impl(F.conv_transpose2d, name='_bias_addition_impl')
def conv_transpose2d_impl(input, weight, **kwargs):
    bias = getattr(kwargs, 'bias', None)
def conv_transpose2d_impl(input,
                          weight,
                          bias=None,
                          stride=_pair(1),
                          padding=_pair(0),
                          output_padding=_pair(0),
                          groups=1,
                          dilation=_pair(1)):
    if bias is None:
        return F.conv_transpose2d(input, weight, **kwargs)
        return F.conv_transpose2d(input,
                                  weight,
                                  stride=stride,
                                  padding=padding,
                                  output_padding=output_padding,
                                  groups=groups,
                                  dilation=dilation)
    else:
        new_kwargs = kwargs
        new_kwargs['bias'] = None
        return F.conv_transpose2d(input, weight, **new_kwargs) + bias.reshape((-1, 1, 1))
        return F.conv_transpose2d(input,
                                  weight,
                                  stride=stride,
                                  padding=padding,
                                  output_padding=output_padding,
                                  groups=groups,
                                  dilation=dilation) + bias.reshape((-1, 1, 1))


@register_tracer_impl(F.conv_transpose3d, name='_bias_addition_impl')
def conv_transpose3d_impl(input, weight, **kwargs):
    bias = getattr(kwargs, 'bias', None)
def conv_transpose3d_impl(input,
                          weight,
                          bias=None,
                          stride=_triple(1),
                          padding=_triple(0),
                          output_padding=_triple(0),
                          groups=1,
                          dilation=_triple(1)):
    if bias is None:
        return F.conv_transpose3d(input, weight, **kwargs)
        return F.conv_transpose3d(input,
                                  weight,
                                  stride=stride,
                                  padding=padding,
                                  output_padding=output_padding,
                                  groups=groups,
                                  dilation=dilation)
    else:
        new_kwargs = kwargs
        new_kwargs['bias'] = None
        return F.conv_transpose3d(input, weight, **new_kwargs) + bias.reshape((-1, 1, 1, 1))
        return F.conv_transpose3d(input,
                                  weight,
                                  stride=stride,
                                  padding=padding,
                                  output_padding=output_padding,
                                  groups=groups,
                                  dilation=dilation) + bias.reshape((-1, 1, 1, 1))


@register_tracer_impl(torch.addmm, name='_bias_addition_impl')
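The intent of these `_bias_addition_impl` patches, as a hedged reading of the diff: during tracing the bias term is split out of the convolution call so that it appears as an explicit broadcast add in the graph, while remaining numerically equivalent to the fused call. A quick equivalence check with illustrative shapes:

```python
import torch
import torch.nn.functional as F

x = torch.randn(2, 3, 8, 8)
w = torch.randn(4, 3, 3, 3)
b = torch.randn(4)

fused = F.conv2d(x, w, bias=b)
split = F.conv2d(x, w) + b.reshape((-1, 1, 1))    # bias broadcast over H and W
assert torch.allclose(fused, split, atol=1e-6)
```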
@@ -155,7 +155,7 @@ class ColoTracer(Tracer):

    def create_node(self, *args, **kwargs) -> Node:
        node = super().create_node(*args, **kwargs)
        n_info = MetaInfo(node, mod_dir=self.mod_dir, to_recompute=tuple(self.ckpt_regions))
        n_info = MetaInfo(node, mod_dir=self.mod_dir, activation_checkpoint=tuple(self.ckpt_regions))
        return node

    def trace(self,
@@ -1,3 +1,3 @@
from .meta_registry import *
from .metainfo import *
from .registry import meta_register
from .shard_metainfo import *
@@ -2,9 +2,9 @@ from typing import Callable, List, Tuple

import torch

from colossalai._analyzer._subclasses.flop_tensor import ewise_flop_counter as elementwise_flop_counter
from colossalai._analyzer.fx.node_util import compute_size_in_bytes as activation_size
from colossalai.auto_parallel.tensor_shard.sharding_strategy import MemoryCost, OperationDataType, TrainCycleItem
from colossalai.fx.profiler.memory_utils import activation_size
from colossalai.fx.profiler.opcount import elementwise_flop_counter

from ..registry import meta_register
@@ -2,9 +2,9 @@ from typing import List, Tuple

import torch

from colossalai._analyzer._subclasses.flop_tensor import flop_mapping
from colossalai._analyzer.fx.node_util import compute_size_in_bytes as activation_size
from colossalai.auto_parallel.tensor_shard.sharding_strategy import MemoryCost, OperationDataType, TrainCycleItem
from colossalai.fx.profiler.memory_utils import activation_size
from colossalai.fx.profiler.opcount import flop_mapping

from ..constants import BCAST_FUNC_OP, NO_SAVE_ACTIVATION
from ..registry import meta_register
@@ -17,7 +17,7 @@ def binary_elementwise_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, Train
    """Meta information generator for binary elementwise operations
    NOTE: Some of the binary elementwise operations will discard the input activation after computation, as they
    don't need those tensors for back propagation, for example, if there are two tensors being sent for `torch.add`,
    they will be discarded right after add operation is done. We create a simple API in `MetaInfo` class to identify
    they will be discarded right after add operation is done. We create a simple API in `ShardMetaInfo` class to identify
    this behavior, it is critical for better memory estimation.

    Returns:
@@ -2,6 +2,8 @@ from typing import Callable, Dict, List, Tuple, Union

import torch

from colossalai._analyzer._subclasses.flop_tensor import flop_mapping
from colossalai._analyzer.fx.node_util import compute_size_in_bytes
from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
    MemoryCost,
    OperationData,
@@ -10,8 +12,6 @@ from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
    StrategiesVector,
    TrainCycleItem,
)
from colossalai.fx.profiler.memory_utils import activation_size
from colossalai.fx.profiler.opcount import flop_mapping
from colossalai.tensor.sharding_spec import ShardingSpec

from ..registry import meta_register
@@ -110,18 +110,18 @@ def convnd_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem, L
    # calculate memory cost
    # TODO: use profiler to check conv temp memory
    # NOTE: currently in SPMD solver we always believe that there will be a new tensor created in forward
    fwd_memory_cost = MemoryCost(
        activation=activation_size([input_tensor, output_tensor]),
        parameter=activation_size([weight_tensor, bias_tensor]) if has_bias else activation_size(weight_tensor),
        temp=0,
        buffer=0)
    fwd_memory_cost = MemoryCost(activation=compute_size_in_bytes([input_tensor, output_tensor]),
                                 parameter=compute_size_in_bytes([weight_tensor, bias_tensor])
                                 if has_bias else compute_size_in_bytes(weight_tensor),
                                 temp=0,
                                 buffer=0)

    bwd_memory_cost = MemoryCost(
        activation=activation_size([input_tensor, weight_tensor, bias_tensor])
        if has_bias else activation_size([input_tensor, weight_tensor]),
        parameter=activation_size([weight_tensor, bias_tensor]) if has_bias else activation_size(weight_tensor),
        temp=0,
        buffer=0)
    bwd_memory_cost = MemoryCost(activation=compute_size_in_bytes([input_tensor, weight_tensor, bias_tensor])
                                 if has_bias else compute_size_in_bytes([input_tensor, weight_tensor]),
                                 parameter=compute_size_in_bytes([weight_tensor, bias_tensor])
                                 if has_bias else compute_size_in_bytes(weight_tensor),
                                 temp=0,
                                 buffer=0)

    # total cost is the sum of forward and backward cost
    total_cost = MemoryCost(activation=fwd_memory_cost.activation + bwd_memory_cost.activation,
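The diff swaps `activation_size` for `compute_size_in_bytes` throughout the meta registry; both measure tensors in bytes. A hedged sketch of the semantics being relied on here (the real helper lives in `colossalai._analyzer.fx.node_util` and may handle more container types than this toy version):

```python
import torch

def size_in_bytes(tensors):
    # accept a single tensor or a list of tensors, count numel * element_size
    if isinstance(tensors, torch.Tensor):
        tensors = [tensors]
    return sum(t.numel() * t.element_size() for t in tensors)

weight = torch.empty(64, 32)    # fp32, 4 bytes per element
assert size_in_bytes(weight) == 64 * 32 * 4
assert size_in_bytes([weight, torch.empty(64)]) == 64 * 32 * 4 + 64 * 4
```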
@@ -2,9 +2,9 @@ from typing import List, Tuple

import torch

from colossalai._analyzer._subclasses.flop_tensor import flop_mapping
from colossalai._analyzer.fx.node_util import compute_size_in_bytes
from colossalai.auto_parallel.tensor_shard.sharding_strategy import MemoryCost, OperationDataType, TrainCycleItem
from colossalai.fx.profiler.memory_utils import activation_size
from colossalai.fx.profiler.opcount import flop_mapping

from ..registry import meta_register

@@ -34,11 +34,11 @@ def embedding_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem
    # NOTE: during the backward phase of torch.nn.Embedding, it seems when the input is large enough, it will
    # have a temp memory which is kind of weird and we don't know the reason yet, so currently we just assume
    # that there will be no temp memory, as the temp memory is significantly smaller than the gradient memory
    fwd_memory_cost = MemoryCost(activation=activation_size([input_tensor, output_tensor]),
    fwd_memory_cost = MemoryCost(activation=compute_size_in_bytes([input_tensor, output_tensor]),
                                 parameter=0,
                                 temp=0,
                                 buffer=0)
    bwd_memory_cost = MemoryCost(activation=activation_size([weight_tensor]), parameter=0, temp=0, buffer=0)
    bwd_memory_cost = MemoryCost(activation=compute_size_in_bytes([weight_tensor]), parameter=0, temp=0, buffer=0)

    total_memory_cost = MemoryCost(activation=fwd_memory_cost.activation + bwd_memory_cost.activation)
@@ -3,6 +3,8 @@ from typing import Callable, Dict, List, Tuple, Union

import torch

from colossalai._analyzer._subclasses.flop_tensor import flop_mapping
from colossalai._analyzer.fx.node_util import compute_size_in_bytes
from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
    MemoryCost,
    OperationData,
@@ -11,8 +13,6 @@ from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
    StrategiesVector,
    TrainCycleItem,
)
from colossalai.fx.profiler.memory_utils import activation_size
from colossalai.fx.profiler.opcount import flop_mapping
from colossalai.tensor.sharding_spec import ShardingSpec

from ..registry import meta_register
@@ -112,14 +112,14 @@ def linear_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem, L
    # NOTE: Linear don't have buffer and temp in forward and backward phase
    # the forward activation cost is the size of output_tensor, parameter cost is the size of weight_tensor and bias_tensor
    # NOTE: currently in SPMD solver we always believe that there will be a new tensor created in forward
    fwd_memory_cost = MemoryCost(activation=activation_size([input_tensor, output_tensor]),
                                 parameter=activation_size([weight_tensor, bias_tensor]),
    fwd_memory_cost = MemoryCost(activation=compute_size_in_bytes([input_tensor, output_tensor]),
                                 parameter=compute_size_in_bytes([weight_tensor, bias_tensor]),
                                 temp=0,
                                 buffer=0)

    # the backward activation cost is the size of input_tensor, weight_tensor and bias_tensor, parameter cost is 0
    bwd_memory_cost = MemoryCost(activation=activation_size([input_tensor, weight_tensor, bias_tensor]),
                                 parameter=activation_size([weight_tensor, bias_tensor]),
    bwd_memory_cost = MemoryCost(activation=compute_size_in_bytes([input_tensor, weight_tensor, bias_tensor]),
                                 parameter=compute_size_in_bytes([weight_tensor, bias_tensor]),
                                 temp=0,
                                 buffer=0)

@@ -148,14 +148,14 @@ def linear_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem, L
    # NOTE: Linear don't have buffer and temp in forward and backward phase
    # the forward activation cost is the size of output_tensor, parameter cost is the size of weight_tensor
    # NOTE: currently in SPMD solver we always believe that there will be a new tensor created in forward
    fwd_memory_cost = MemoryCost(activation=activation_size([input_tensor, output_tensor]),
                                 parameter=activation_size(weight_tensor),
    fwd_memory_cost = MemoryCost(activation=compute_size_in_bytes([input_tensor, output_tensor]),
                                 parameter=compute_size_in_bytes(weight_tensor),
                                 temp=0,
                                 buffer=0)

    # the backward activation cost is the size of input_tensor and weight_tensor, parameter cost is 0
    bwd_memory_cost = MemoryCost(activation=activation_size([input_tensor, weight_tensor]),
                                 parameter=activation_size(weight_tensor),
    bwd_memory_cost = MemoryCost(activation=compute_size_in_bytes([input_tensor, weight_tensor]),
                                 parameter=compute_size_in_bytes(weight_tensor),
                                 temp=0,
                                 buffer=0)
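To make the Linear bookkeeping above concrete, a hedged back-of-the-envelope example; the shapes are illustrative and the byte counts follow the same input/output/weight/bias accounting as the code.

```python
# nn.Linear(1024, 4096) applied to a (32, 1024) input, fp32 (4 bytes per element)
batch, in_f, out_f = 32, 1024, 4096

input_bytes = batch * in_f * 4      # 131,072
output_bytes = batch * out_f * 4    # 524,288
weight_bytes = out_f * in_f * 4     # 16,777,216
bias_bytes = out_f * 4              # 16,384

fwd_activation = input_bytes + output_bytes                 # input + output tensors
fwd_parameter = weight_bytes + bias_bytes                   # weight + bias
bwd_activation = input_bytes + weight_bytes + bias_bytes    # gradients produced in backward
print(fwd_activation, fwd_parameter, bwd_activation)
```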
@@ -210,48 +210,48 @@ def matmul_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem, L
    # Check dimension
    if all(len(tensor.shape) == 1 for tensor in input_tensors):
        # Dot
        fwd_compute_cost = flop_mapping[torch.ops.aten.dot.default](input_tensors, output_tensors)
        fwd_compute_cost = flop_mapping[torch.ops.aten.matmul.default](input_tensors, output_tensors)
        bwd_compute_cost = flop_mapping[torch.ops.aten.mul.Tensor](input_tensors[0], output_tensors) * 2

        fwd_mem_cost = MemoryCost(activation=activation_size(output_tensors), parameter=0, temp=0, buffer=0)
        bwd_mem_cost = MemoryCost(activation=activation_size(input_tensors), parameter=0, temp=0, buffer=0)
        fwd_mem_cost = MemoryCost(activation=compute_size_in_bytes(output_tensors), parameter=0, temp=0, buffer=0)
        bwd_mem_cost = MemoryCost(activation=compute_size_in_bytes(input_tensors), parameter=0, temp=0, buffer=0)

    elif len(input_tensors[0].shape) >= 2 and len(input_tensors[1].shape) == 1:
        # gemv case 1: matrix-vector multiplication
        # &
        # batched gemv case 1: batched matrix-vector multiplication

        fwd_compute_cost = flop_mapping[torch.ops.aten.mv.default](
        fwd_compute_cost = flop_mapping[torch.ops.aten.matmul.default](
            [input_tensors[0].reshape(-1, input_tensors[0].shape[-1]), input_tensors[1]], output_tensors)

        # combine the dimensions of output
        bwd_compute_cost = flop_mapping[torch.ops.aten.mul.Tensor](
            [output_tensors[0].reshape(-1), input_tensors[1]],
            output_tensors) + \
            flop_mapping[torch.ops.aten.mv.default](
            flop_mapping[torch.ops.aten.matmul.default](
            [input_tensors[0].reshape(-1, input_tensors[0].shape[-1]).transpose(0, 1), output_tensors[0].reshape(-1)],
            output_tensors)

        fwd_mem_cost = MemoryCost(activation=activation_size(output_tensors), parameter=0, temp=0, buffer=0)
        bwd_mem_cost = MemoryCost(activation=activation_size(input_tensors), parameter=0, temp=0, buffer=0)
        fwd_mem_cost = MemoryCost(activation=compute_size_in_bytes(output_tensors), parameter=0, temp=0, buffer=0)
        bwd_mem_cost = MemoryCost(activation=compute_size_in_bytes(input_tensors), parameter=0, temp=0, buffer=0)

    elif len(input_tensors[0].shape) == 1 and len(input_tensors[1].shape) == 2:
        # gemv case 2: vector-matrix multiplication
        fwd_compute_cost = flop_mapping[torch.ops.aten.mv.default](input_tensors, output_tensors)
        fwd_compute_cost = flop_mapping[torch.ops.aten.matmul.default](input_tensors, output_tensors)

        bwd_compute_cost = flop_mapping[torch.ops.aten.mul.Tensor]([output_tensors[0], input_tensors[0]], output_tensors) + \
            flop_mapping[torch.ops.aten.mv.default]([input_tensors[1], output_tensors[0]], output_tensors)
            flop_mapping[torch.ops.aten.matmul.default]([input_tensors[1], output_tensors[0]], output_tensors)

        fwd_mem_cost = MemoryCost(activation=activation_size(output_tensors), parameter=0, temp=0, buffer=0)
        bwd_mem_cost = MemoryCost(activation=activation_size(input_tensors),
        fwd_mem_cost = MemoryCost(activation=compute_size_in_bytes(output_tensors), parameter=0, temp=0, buffer=0)
        bwd_mem_cost = MemoryCost(activation=compute_size_in_bytes(input_tensors),
                                  parameter=0,
                                  temp=activation_size(input_tensors[1]),
                                  temp=compute_size_in_bytes(input_tensors[1]),
                                  buffer=0)

    elif len(input_tensors[0].shape) == 1 and len(input_tensors[1].shape) >= 3:
        # batched gemv case 2: vector-batched matrix multiplication

        fwd_compute_cost = flop_mapping[torch.ops.aten.mv.default](
        fwd_compute_cost = flop_mapping[torch.ops.aten.matmul.default](
            [input_tensors[1].transpose(-2, -1).reshape(-1, input_tensors[1].shape[-2]), input_tensors[0]],
            [output_tensors[0].reshape(-1)])

@@ -260,15 +260,15 @@ def matmul_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem, L
            [output_tensors[0].reshape(-1), input_tensors[0]],
            output_tensors
        ) + \
            flop_mapping[torch.ops.aten.mv.default](
            flop_mapping[torch.ops.aten.matmul.default](
            [input_tensors[1].transpose(-2, -1).reshape(-1, input_tensors[1].shape[-2]).transpose(0, 1), output_tensors[0].reshape(-1)],
            output_tensors
        )

        fwd_mem_cost = MemoryCost(activation=activation_size(output_tensors + [input_tensors[1]]))
        bwd_mem_cost = MemoryCost(activation=activation_size(input_tensors[0]),
        fwd_mem_cost = MemoryCost(activation=compute_size_in_bytes(output_tensors + [input_tensors[1]]))
        bwd_mem_cost = MemoryCost(activation=compute_size_in_bytes(input_tensors[0]),
                                  parameter=0,
                                  temp=activation_size(input_tensors[1]),
                                  temp=compute_size_in_bytes(input_tensors[1]),
                                  buffer=0)

    elif len(input_tensors[0].shape) >= 2 and len(input_tensors[1].shape) == 2:
@@ -287,8 +287,8 @@ def matmul_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem, L
            [input_tensors[0].reshape(-1, input_tensors[0].shape[-1])]
        )

        fwd_mem_cost = MemoryCost(activation=activation_size(output_tensors), parameter=0, temp=0, buffer=0)
        bwd_mem_cost = MemoryCost(activation=activation_size(input_tensors), parameter=0, temp=0, buffer=0)
        fwd_mem_cost = MemoryCost(activation=compute_size_in_bytes(output_tensors), parameter=0, temp=0, buffer=0)
        bwd_mem_cost = MemoryCost(activation=compute_size_in_bytes(input_tensors), parameter=0, temp=0, buffer=0)

    elif len(input_tensors[0].shape) == 2 and len(input_tensors[1].shape) >= 3:
        # batched gemm case 2: matrix-batched matrix multiplication
@@ -306,11 +306,12 @@ def matmul_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem, L
            [input_tensors[1].transpose(-2, -1).reshape(-1, input_tensors[1].shape[-2])]
        )

        fwd_mem_cost = MemoryCost(activation=activation_size(output_tensors) + activation_size(input_tensors[1]),
                                  temp=activation_size(output_tensors))
        bwd_mem_cost = MemoryCost(activation=activation_size(input_tensors[0]),
        fwd_mem_cost = MemoryCost(activation=compute_size_in_bytes(output_tensors) +
                                  compute_size_in_bytes(input_tensors[1]),
                                  temp=compute_size_in_bytes(output_tensors))
        bwd_mem_cost = MemoryCost(activation=compute_size_in_bytes(input_tensors[0]),
                                  parameter=0,
                                  temp=activation_size(input_tensors[1]) + activation_size(output_tensors))
                                  temp=compute_size_in_bytes(input_tensors[1]) + compute_size_in_bytes(output_tensors))

    elif all(len(tensor.shape) >= 3 for tensor in input_tensors):
        # Batched matrix-batched matrix multiplication
@@ -351,8 +352,8 @@ def matmul_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem, L
            [input_tensors[0].reshape(-1, input_dim_00, input_dim_01)]
        )

        fwd_mem_cost = MemoryCost(activation=activation_size(output_tensors))
        bwd_mem_cost = MemoryCost(activation=activation_size(input_tensors))
        fwd_mem_cost = MemoryCost(activation=compute_size_in_bytes(output_tensors))
        bwd_mem_cost = MemoryCost(activation=compute_size_in_bytes(input_tensors))

    else:
        # Case 2: batch dimensions are different
@@ -381,10 +382,10 @@ def matmul_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem, L
        )

        fwd_mem_cost = MemoryCost(
            activation=activation_size([output_tensors[0], extended_input_0, extended_input_1]))
        bwd_mem_cost = MemoryCost(activation=activation_size(input_tensors) -
                                  activation_size([extended_input_0, extended_input_1]),
                                  temp=activation_size([extended_input_0, extended_input_1]))
            activation=compute_size_in_bytes([output_tensors[0], extended_input_0, extended_input_1]))
        bwd_mem_cost = MemoryCost(activation=compute_size_in_bytes(input_tensors) -
                                  compute_size_in_bytes([extended_input_0, extended_input_1]),
                                  temp=compute_size_in_bytes([extended_input_0, extended_input_1]))

    # compute cost
    compute_cost = TrainCycleItem(fwd=fwd_compute_cost, bwd=bwd_compute_cost, total=fwd_compute_cost + bwd_compute_cost)
@@ -4,8 +4,6 @@ from typing import List, Tuple
import torch

from colossalai.auto_parallel.tensor_shard.sharding_strategy import MemoryCost, OperationDataType, TrainCycleItem
from colossalai.fx.profiler.memory_utils import activation_size
from colossalai.fx.profiler.opcount import flop_mapping

from ..registry import meta_register
@@ -2,6 +2,8 @@ from typing import Callable, Dict, List, Tuple, Union

import torch

from colossalai._analyzer._subclasses.flop_tensor import flop_mapping
from colossalai._analyzer.fx.node_util import compute_size_in_bytes
from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
    MemoryCost,
    OperationData,
@@ -10,8 +12,6 @@ from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
    StrategiesVector,
    TrainCycleItem,
)
from colossalai.fx.profiler.memory_utils import activation_size
from colossalai.fx.profiler.opcount import flop_mapping
from colossalai.tensor.sharding_spec import ShardingSpec

from ..registry import meta_register
@@ -77,17 +77,18 @@ def batchnormnd_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleIt
    # calculate memory cost
    # the fwd activation cost is output plus saved mean and saved inv std
    # NOTE: currently in SPMD solver we always believe that there will be a new tensor created in forward
    fwd_memory_cost = MemoryCost(activation=activation_size([input_tensor, output_tensor, mean_tensor, var_tensor]),
                                 parameter=activation_size([weight_tensor, bias_tensor]),
    fwd_memory_cost = MemoryCost(activation=compute_size_in_bytes(
        [input_tensor, output_tensor, mean_tensor, var_tensor]),
                                 parameter=compute_size_in_bytes([weight_tensor, bias_tensor]),
                                 temp=0,
                                 buffer=activation_size([mean_tensor, var_tensor]))
                                 buffer=compute_size_in_bytes([mean_tensor, var_tensor]))

    # the bwd memory cost is quite tricky here, BatchNorm will remove saved mean
    # and saved inv std during backward phase
    bwd_memory_cost = MemoryCost(activation=activation_size([input_tensor]),
                                 parameter=activation_size([weight_tensor, bias_tensor]),
                                 temp=activation_size([mean_tensor, var_tensor]),
                                 buffer=activation_size([mean_tensor, var_tensor]))
    bwd_memory_cost = MemoryCost(activation=compute_size_in_bytes([input_tensor]),
                                 parameter=compute_size_in_bytes([weight_tensor, bias_tensor]),
                                 temp=compute_size_in_bytes([mean_tensor, var_tensor]),
                                 buffer=compute_size_in_bytes([mean_tensor, var_tensor]))

    # total cost is the sum of forward and backward cost
    total_cost = MemoryCost(activation=fwd_memory_cost.activation + bwd_memory_cost.activation,
@@ -131,15 +132,16 @@ def layernorm_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem

    # memory cost
    # NOTE: currently in SPMD solver we always believe that there will be a new tensor created in forward
    fwd_memory_cost = MemoryCost(activation=activation_size([input_tensor, output_tensor, weight_tensor, bias_tensor]),
                                 parameter=activation_size([weight_tensor, bias_tensor]),
    fwd_memory_cost = MemoryCost(activation=compute_size_in_bytes(
        [input_tensor, output_tensor, weight_tensor, bias_tensor]),
                                 parameter=compute_size_in_bytes([weight_tensor, bias_tensor]),
                                 temp=0,
                                 buffer=activation_size([running_mean, running_var]))
                                 buffer=compute_size_in_bytes([running_mean, running_var]))

    bwd_memory_cost = MemoryCost(activation=activation_size([input_tensor, weight_tensor, bias_tensor]),
                                 parameter=activation_size([weight_tensor, bias_tensor]),
                                 temp=activation_size([running_mean, running_var]),
                                 buffer=activation_size([running_mean, running_var]))
    bwd_memory_cost = MemoryCost(activation=compute_size_in_bytes([input_tensor, weight_tensor, bias_tensor]),
                                 parameter=compute_size_in_bytes([weight_tensor, bias_tensor]),
                                 temp=compute_size_in_bytes([running_mean, running_var]),
                                 buffer=compute_size_in_bytes([running_mean, running_var]))

    total_cost = MemoryCost(activation=fwd_memory_cost.activation + bwd_memory_cost.activation,
                            parameter=fwd_memory_cost.parameter + bwd_memory_cost.parameter,
@@ -2,9 +2,9 @@ from typing import List, Tuple

import torch

from colossalai._analyzer._subclasses.flop_tensor import flop_mapping
from colossalai._analyzer.fx.node_util import compute_size_in_bytes
from colossalai.auto_parallel.tensor_shard.sharding_strategy import MemoryCost, OperationDataType, TrainCycleItem
from colossalai.fx.profiler.memory_utils import activation_size
from colossalai.fx.profiler.opcount import flop_mapping

from ..registry import meta_register

@@ -52,8 +52,8 @@ def avgpool_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem,
    compute_cost = TrainCycleItem(fwd=fwd_compute_cost, bwd=bwd_compute_cost, total=fwd_compute_cost + bwd_compute_cost)

    # calculate memory cost
    fwd_mem_cost = MemoryCost() if is_inplace else MemoryCost(activation=activation_size(output_tensor))
    bwd_mem_cost = MemoryCost() if is_inplace else MemoryCost(activation=activation_size(input_tensor))
    fwd_mem_cost = MemoryCost() if is_inplace else MemoryCost(activation=compute_size_in_bytes(output_tensor))
    bwd_mem_cost = MemoryCost() if is_inplace else MemoryCost(activation=compute_size_in_bytes(input_tensor))

    # total cost
    total_mem_cost = MemoryCost(activation=fwd_mem_cost.activation + bwd_mem_cost.activation)
@@ -114,11 +114,11 @@ def maxpool_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem,
    # calculate memory cost
    # NOTE: the index matrix will be discarded in backward phase
    # NOTE: currently in SPMD solver we always believe that there will be a new tensor created in forward
    fwd_mem_cost = MemoryCost(activation=activation_size([input_tensor, output_tensor, index_matrix]))
    fwd_mem_cost = MemoryCost(activation=compute_size_in_bytes([input_tensor, output_tensor, index_matrix]))

    # temp memory for backward is the index matrix to be discarded
    bwd_mem_cost = MemoryCost(activation=activation_size(input_tensor) - activation_size(index_matrix),
                              temp=activation_size(index_matrix))
    bwd_mem_cost = MemoryCost(activation=compute_size_in_bytes(input_tensor) - compute_size_in_bytes(index_matrix),
                              temp=compute_size_in_bytes(index_matrix))

    # total cost
    total_mem_cost = MemoryCost(activation=fwd_mem_cost.activation + bwd_mem_cost.activation, temp=bwd_mem_cost.temp)
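For context on the index matrix mentioned above, a hedged illustration: max pooling saves an index tensor for the backward scatter, which is what the estimator books as forward activation and backward `temp`. Shapes are illustrative only.

```python
import torch
import torch.nn.functional as F

x = torch.randn(2, 3, 8, 8, requires_grad=True)
out, indices = F.max_pool2d(x, kernel_size=2, return_indices=True)

# indices has the same shape as the pooled output; it routes gradients back
# to the max locations and can be dropped once backward has run
assert indices.shape == out.shape == (2, 3, 4, 4)
```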
@@ -2,9 +2,9 @@ from typing import Callable, List, Tuple

import torch

from colossalai._analyzer._subclasses.flop_tensor import flop_mapping
from colossalai._analyzer.fx.node_util import compute_size_in_bytes
from colossalai.auto_parallel.tensor_shard.sharding_strategy import MemoryCost, OperationDataType, TrainCycleItem
from colossalai.fx.profiler.memory_utils import activation_size
from colossalai.fx.profiler.opcount import flop_mapping

from ..registry import meta_register

@@ -35,11 +35,11 @@ def tensor_related_metainfo(bwd_mem_out_factor: float = 1, bwd_mem_tmp_factor: f

    # memory costs
    # NOTE: currently in SPMD solver we always believe that there will be a new tensor created in forward
    fwd_mem_cost = MemoryCost(activation=activation_size(outputs) * 2, parameter=0, temp=0, buffer=0)
    fwd_mem_cost = MemoryCost(activation=compute_size_in_bytes(outputs) * 2, parameter=0, temp=0, buffer=0)

    bwd_mem_cost = MemoryCost(activation=activation_size(outputs) * bwd_mem_out_factor,
    bwd_mem_cost = MemoryCost(activation=compute_size_in_bytes(outputs) * bwd_mem_out_factor,
                              parameter=0,
                              temp=activation_size(outputs) * bwd_mem_tmp_factor,
                              temp=compute_size_in_bytes(outputs) * bwd_mem_tmp_factor,
                              buffer=0)

    total_mem_cost = MemoryCost(activation=fwd_mem_cost.activation + bwd_mem_cost.activation,
@@ -2,9 +2,9 @@ from typing import List, Tuple

import torch

from colossalai._analyzer._subclasses.flop_tensor import flop_mapping
from colossalai._analyzer.fx.node_util import compute_size_in_bytes as activation_size
from colossalai.auto_parallel.tensor_shard.sharding_strategy import MemoryCost, OperationDataType, TrainCycleItem
from colossalai.fx.profiler.memory_utils import activation_size
from colossalai.fx.profiler.opcount import flop_mapping

from ..registry import meta_register
@@ -15,11 +15,11 @@ from colossalai.tensor.sharding_spec import ShardingSpec
from .constants import INPLACE_MODULE, INPLACE_OPS, NO_SAVE_ACTIVATION
from .registry import meta_register

__all__ = ['MetaInfo']
__all__ = ['ShardMetaInfo']


class MetaInfo:
    """MetaInfo class
class ShardMetaInfo:
    """ShardMetaInfo class
    This class is used to store meta info based on sharding strategy and the given
    target function.
    """
@@ -46,9 +46,9 @@ class MetaInfo:
        # target function
        self._target = target

        # compute metainfo if possible
        # compute shard_metainfo if possible
        if self._strategy is not None and self._target is not None:
            self.compute_metainfo()
            self.compute_shard_metainfo()

    @property
    def strategy(self) -> ShardingStrategy:
@@ -62,13 +62,13 @@ class MetaInfo:
    def strategy(self, strategy: ShardingStrategy) -> None:
        self._strategy = strategy
        if self._strategy is not None and self._target is not None:
            self.compute_metainfo()
            self.compute_shard_metainfo()

    @target.setter
    def target(self, target: Callable) -> None:
        self._target = target
        if self._strategy is not None and self._target is not None:
            self.compute_metainfo()
            self.compute_shard_metainfo()

    def compute_sharded_opdata(self, operation_data: OperationData, sharding_spec: ShardingSpec):
        """
@@ -93,7 +93,7 @@ class MetaInfo:

        return op_data

    def compute_metainfo(self):
    def compute_shard_metainfo(self):
        """
        Compute meta info based on sharding strategy and the given target function.
        """
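A hedged usage sketch of the renamed class, not runnable on its own: setting either the strategy or the target re-triggers `compute_shard_metainfo()` once both are present. `chosen_strategy` below is a placeholder for a `ShardingStrategy` produced elsewhere by the solver.

```python
import torch.nn.functional as F

shard_meta = ShardMetaInfo()                 # nothing computed yet
shard_meta.target = F.linear                 # still incomplete, no cost computed
shard_meta.strategy = chosen_strategy        # both set -> compute_shard_metainfo() runs

compute_cost = shard_meta.compute_cost       # TrainCycleItem(fwd, bwd, total)
memory_cost = shard_meta.memory_cost
```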
@@ -4,7 +4,7 @@ import torch
from torch.fx import GraphModule
from torch.fx.node import Node

from colossalai.auto_parallel.meta_profiler import MetaInfo
from colossalai.auto_parallel.meta_profiler import ShardMetaInfo
from colossalai.auto_parallel.passes.runtime_apply_pass import runtime_apply, runtime_comm_spec_apply
from colossalai.auto_parallel.tensor_shard.sharding_strategy import MemoryCost, TrainCycleItem
from colossalai.tensor.comm_spec import CommSpec
@@ -14,15 +14,15 @@ from colossalai.tensor.sharding_spec import ShardingSpec
shape_consistency_manager = ShapeConsistencyManager()


def _construct_meta_info(node: Node, origin_sharding_spec: ShardingSpec,
                         target_sharding_spec: ShardingSpec) -> MetaInfo:
def _construct_shard_meta_info(node: Node, origin_sharding_spec: ShardingSpec,
                               target_sharding_spec: ShardingSpec) -> ShardMetaInfo:
    # get comm_action_sequence and total_cost from shape_consistency_manager
    _, comm_action_sequence, total_cost = shape_consistency_manager.shape_consistency(
        origin_sharding_spec, target_sharding_spec)

    meta_info = MetaInfo()
    meta_info = ShardMetaInfo()
    # NOTE: the cost in shape_consistency_manager.mem_cost is the count in number of numel
    # get mem cost for MetaInfo
    # get mem cost for ShardMetaInfo
    mem_cost = shape_consistency_manager.mem_cost(comm_action_sequence)
    # extract user that has _meta_data and extract element length
    input_node = next(n for n in node._input_nodes if hasattr(n, '_meta_data'))
@@ -36,12 +36,12 @@ def _construct_meta_info(node: Node, origin_sharding_spec: ShardingSpec,

    meta_info.memory_cost = mem_cost

    # get computation cost for MetaInfo
    # get computation cost for ShardMetaInfo
    meta_info.compute_cost = TrainCycleItem(total_cost['forward'] * element_length,
                                            total_cost['backward'] * element_length,
                                            total_cost['total'] * element_length)

    # get tensor shape for MetaInfo
    # get tensor shape for ShardMetaInfo
    origin_sharding_spec: ShardingSpec
    target_sharding_spec: ShardingSpec
    input_shape = origin_sharding_spec.get_sharded_shape_per_device()
@@ -54,7 +54,7 @@ def _construct_meta_info(node: Node, origin_sharding_spec: ShardingSpec,
    return meta_info


def _runtime_apply_meta_info(node: Node, origin_spec_dict, sharding_spec_dict) -> MetaInfo:
def _runtime_apply_meta_info(node: Node, origin_spec_dict, sharding_spec_dict) -> ShardMetaInfo:
    """
    This method is used to construct `MetaInfo` for shape consistency node
    """
@@ -65,17 +65,17 @@ def _runtime_apply_meta_info(node: Node, origin_spec_dict, sharding_spec_dict) -
    origin_sharding_spec, target_sharding_spec = origin_spec_dict[node_index], sharding_spec_dict[node_index][
        user_node_index]

    return _construct_meta_info(node, origin_sharding_spec, target_sharding_spec)
    return _construct_shard_meta_info(node, origin_sharding_spec, target_sharding_spec)


def _runtime_comm_spec_apply_meta_info(node: Node, comm_actions_dict: Dict) -> MetaInfo:
def _runtime_comm_spec_apply_meta_info(node: Node, comm_actions_dict: Dict) -> ShardMetaInfo:
    # extract node_index and op_data_name
    node_index, op_data_name = node.args[2], node.args[3]

    comm_action = comm_actions_dict[node_index][op_data_name]
    if isinstance(comm_action.comm_spec, CommSpec):
        # this case is for all_reduce, there will be no memory cost
        meta_info = MetaInfo()
        meta_info = ShardMetaInfo()
        meta_info.memory_cost = TrainCycleItem(MemoryCost(), MemoryCost(), MemoryCost)
        output_node = next(n for n in node.users if hasattr(n, '_meta_data'))
        element_length = output_node._meta_data.element_size()
@@ -93,7 +93,7 @@ def _runtime_comm_spec_apply_meta_info(node: Node, comm_actions_dict: Dict) -> M
        # this case will be handled by shape consistency manager
        origin_sharding_spec, target_sharding_spec = comm_action.comm_spec['src_spec'], comm_action.comm_spec[
            'tgt_spec']
        meta_info = _construct_meta_info(node, origin_sharding_spec, target_sharding_spec)
        meta_info = _construct_shard_meta_info(node, origin_sharding_spec, target_sharding_spec)

    return meta_info

@@ -105,9 +105,9 @@ def comm_metainfo_pass(gm: GraphModule, sharding_spec_dict: Dict, origin_spec_di
    """
    for node in gm.graph.nodes:
        if node.target == runtime_apply:
            setattr(node, 'best_metainfo', _runtime_apply_meta_info(node, origin_spec_dict, sharding_spec_dict))
            setattr(node, 'best_strategy_info', _runtime_apply_meta_info(node, origin_spec_dict, sharding_spec_dict))
        elif node.target == runtime_comm_spec_apply:
            setattr(node, 'best_metainfo', _runtime_comm_spec_apply_meta_info(node, comm_actions_dict))
            setattr(node, 'best_strategy_info', _runtime_comm_spec_apply_meta_info(node, comm_actions_dict))
        else:
            pass
    return gm
@@ -7,7 +7,7 @@ import torch.fx
from torch.fx import GraphModule
from torch.fx.node import Node

from colossalai.auto_parallel.meta_profiler import MetaInfo
from colossalai.auto_parallel.meta_profiler import ShardMetaInfo
from colossalai.auto_parallel.passes.constants import OUTPUT_SAVED_MOD, OUTPUT_SAVED_OPS
from colossalai.fx._compatibility import compatibility
from colossalai.fx.profiler import GraphInfo
@@ -96,12 +96,12 @@ class MetaInfoProp:
        """
        Handle other kind of nodes
        """
        assert hasattr(node, 'best_metainfo'), f"Cannot find best_metainfo in node {node}, {node.op}"
        assert hasattr(node, 'best_strategy_info'), f"Cannot find best_strategy_info in node {node}, {node.op}"
        graph_info = GraphInfo()
        meta_info = node.best_metainfo
        meta_info: MetaInfo
        meta_info = node.best_strategy_info
        meta_info: ShardMetaInfo

        # set data_ptr for input_tensor in MetaInfo class
        # set data_ptr for input_tensor in ShardMetaInfo class
        input_tensors: List[torch.Tensor] = meta_info.fwd_in
        buffer_tensors: List[torch.Tensor] = meta_info.fwd_buffer
        output_tensors: List[torch.Tensor] = meta_info.fwd_out
@@ -4,7 +4,7 @@ from typing import Dict, List
import torch
from torch.fx.node import Node

from colossalai.auto_parallel.meta_profiler import MetaInfo
from colossalai._analyzer.fx.node_util import MetaInfo
from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
    CommAction,
    CommType,
@@ -128,9 +128,10 @@ def _shape_consistency_apply(gm: torch.fx.GraphModule):
                    runtime_apply,
                    args=(node, origin_dict_node, input_dict_node,
                          node_to_index_dict[node], user_node_index))
                if 'activation_checkpoint' in user_node.meta:
                    shape_consistency_node.meta['activation_checkpoint'] = user_node.meta['activation_checkpoint']

                if hasattr(user_node.meta['info'], 'activation_checkpoint'):
                    MetaInfo(shape_consistency_node,
                             mod_dir=user_node.meta['info'].mod_dir,
                             activation_checkpoint=tuple(user_node.meta['info'].activation_checkpoint))
                new_args = list(user_node.args)
                new_kwargs = dict(user_node.kwargs)
                # the origin node may be a positional argument or key word argument of user node
@@ -210,9 +211,10 @@ def _comm_spec_apply(gm: torch.fx.GraphModule):
                    # substitute the origin node with comm_spec_apply_node
                    new_kwargs[str(node)] = comm_spec_apply_node
                    user.kwargs = new_kwargs

            if 'activation_checkpoint' in node.meta:
                comm_spec_apply_node.meta['activation_checkpoint'] = node.meta['activation_checkpoint']
            if hasattr(node.meta['info'], 'activation_checkpoint'):
                MetaInfo(comm_spec_apply_node,
                         mod_dir=node.meta['info'].mod_dir,
                         activation_checkpoint=tuple(node.meta['info'].activation_checkpoint))

    return gm
@@ -6,6 +6,7 @@ import torch
from torch.fx import symbolic_trace
from torch.fx.node import Node

from colossalai._analyzer.fx.node_util import MetaInfo
from colossalai.auto_parallel.tensor_shard.constants import RESHAPE_FUNC_OP
from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
    CommAction,
@@ -74,9 +75,9 @@ def solution_annotatation_pass(gm: torch.fx.GraphModule, solution: List[int],
            origin_node_sharding_spec_dict[node_index] = strategies_vector[strategy_index].get_sharding_spec_by_name(
                str(node))

        # attach the corresponding metainfo if node has the attribute `metainfo_vector`
        if hasattr(node, 'metainfo_vector'):
            setattr(node, 'best_metainfo', node.metainfo_vector[strategy_index])
        # attach the corresponding metainfo if node has the attribute `strategies_info`
        if hasattr(node, 'strategies_info'):
            setattr(node, 'best_strategy_info', node.strategies_info[strategy_index])

    # the dict to get input sharding specs of user node
    sharding_spec_convert_dict = {}
@@ -172,8 +173,11 @@ def size_value_converting_pass(gm: torch.fx.GraphModule, device_mesh: DeviceMesh
            # It will be used to replace the original node with processing node in slice object
            node_pairs[node] = size_processing_node
            size_processing_node._meta_data = node._meta_data
            if 'activation_checkpoint' in node.meta:
                size_processing_node.meta['activation_checkpoint'] = node.meta['activation_checkpoint']

            if hasattr(node.meta['info'], 'activation_checkpoint'):
                MetaInfo(size_processing_node,
                         mod_dir=node.meta['info'].mod_dir,
                         activation_checkpoint=tuple(node.meta['info'].activation_checkpoint))

            user_list = list(node.users.keys())
            for user in user_list:
@@ -6,6 +6,10 @@ import torch.nn as nn
from torch.fx import GraphModule
from torch.fx.graph import Graph

from colossalai._analyzer.fx.codegen import ActivationCheckpointCodeGen
from colossalai._analyzer.fx.graph_module import ColoGraphModule
from colossalai._analyzer.fx.passes import shape_prop_pass
from colossalai._analyzer.fx.tracer.tracer import ColoTracer
from colossalai.auto_parallel.passes.runtime_apply_pass import runtime_apply_pass
from colossalai.auto_parallel.passes.runtime_preparation_pass import runtime_preparation_pass
from colossalai.auto_parallel.tensor_shard.options import DataloaderOption, ShardOption, SolverOptions, SolverPerference
@@ -13,8 +17,6 @@ from colossalai.auto_parallel.tensor_shard.sharding_strategy import CommAction
from colossalai.auto_parallel.tensor_shard.solver import CostGraph, GraphAnalyser, Solver, StrategiesConstructor
from colossalai.device.alpha_beta_profiler import AlphaBetaProfiler
from colossalai.device.device_mesh import DeviceMesh
from colossalai.fx.graph_module import ColoGraphModule
from colossalai.fx.tracer import ColoTracer
from colossalai.tensor.sharding_spec import ShardingSpec


@@ -126,6 +128,7 @@ def solve_solution(gm: ColoGraphModule, strategy_constructor: StrategiesConstruc


def transform_to_sharded_model(gm: ColoGraphModule,
                               meta_args: Dict,
                               solution: List[int],
                               device_mesh: DeviceMesh,
                               strategies_constructor: StrategiesConstructor,
@@ -142,6 +145,7 @@ def transform_to_sharded_model(gm: ColoGraphModule,
                                                                                        strategies_constructor,
                                                                                        overlap=overlap)
    gm = runtime_apply_pass(gm)
    shape_prop_pass(gm, *meta_args.values(), sharding_spec_dict, origin_spec_dict, comm_actions_dict)
    gm.recompile()
    sharding_spec_dicts = (sharding_spec_dict, origin_spec_dict, comm_actions_dict)

@@ -243,10 +247,13 @@ def initialize_model(model: nn.Module,
        solution will be used to debug or help to analyze the sharding result. Therefore, we will not just
        return a series of integers, but return the best strategies.
    '''
    tracer = ColoTracer(trace_act_ckpt=True)
    tracer = ColoTracer(trace_act_ckpt=True, bias_addition_split=True)

    graph = tracer.trace(root=model, meta_args=meta_args)
    graph.set_codegen(ActivationCheckpointCodeGen())
    gm = ColoGraphModule(model, graph, model.__class__.__name__)

    shape_prop_pass(gm, *meta_args.values())
    gm.recompile()

    strategies_constructor = build_strategy_constructor(graph,
@@ -261,7 +268,9 @@ def initialize_model(model: nn.Module,
    if save_solver_solution:
        torch.save(solution, solution_path)

    gm, sharding_spec_dicts = transform_to_sharded_model(gm, solution, device_mesh, strategies_constructor, overlap)
    gm, sharding_spec_dicts = transform_to_sharded_model(gm, meta_args, solution, device_mesh, strategies_constructor,
                                                         overlap)

    model_to_return = ModuleWrapper(gm, *sharding_spec_dicts)

    if return_solution:
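A hedged end-to-end sketch of the new tracing flow that `initialize_model` wires up in this commit; the import paths and calls mirror the diff, while the toy model and the `'input'` meta-argument name are illustrative.

```python
import torch
import torch.nn as nn

from colossalai._analyzer.fx.codegen import ActivationCheckpointCodeGen
from colossalai._analyzer.fx.graph_module import ColoGraphModule
from colossalai._analyzer.fx.passes import shape_prop_pass
from colossalai._analyzer.fx.tracer.tracer import ColoTracer

model = nn.Sequential(nn.Linear(16, 32), nn.ReLU(), nn.Linear(32, 8))
meta_args = {'input': torch.rand(4, 16).to('meta')}

tracer = ColoTracer(trace_act_ckpt=True, bias_addition_split=True)
graph = tracer.trace(root=model, meta_args=meta_args)
graph.set_codegen(ActivationCheckpointCodeGen())
gm = ColoGraphModule(model, graph, model.__class__.__name__)

# propagate shapes on meta tensors before strategy construction
shape_prop_pass(gm, *meta_args.values())
gm.recompile()
```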
@@ -2,8 +2,6 @@ from typing import Dict, List

import torch

from colossalai.auto_parallel.meta_profiler.metainfo import MetaInfo

from ..sharding_strategy import OperationData, OperationDataType, StrategiesVector
from .node_handler import MetaInfoModuleHandler, ModuleHandler
from .registry import operator_registry
@@ -4,7 +4,7 @@ from typing import Dict, List, Tuple, Union
import torch
from torch.fx.node import Node

from colossalai.auto_parallel.meta_profiler.metainfo import MetaInfo, meta_register
from colossalai.auto_parallel.meta_profiler.shard_metainfo import ShardMetaInfo, meta_register
from colossalai.auto_parallel.tensor_shard.options import ShardOption, SolverPerference
from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
OperationData,

@@ -258,7 +258,7 @@ class MetaInfoNodeHandler(NodeHandler):
def register_strategy(self, compute_resharding_cost: bool = True) -> StrategiesVector:
"""
This method is inherited from NodeHandler. It will register the strategies first,
and rewrite the memory_cost and compute_cost of the strategy using the MetaInfo class.
and rewrite the memory_cost and compute_cost of the strategy using the ShardMetaInfo class.
"""
super().register_strategy(compute_resharding_cost=compute_resharding_cost)
target = self.get_target_function()

@@ -266,15 +266,15 @@ class MetaInfoNodeHandler(NodeHandler):
# is not patched, we will use the default cost model to compute the cost.
# TODO: patch all torch functions and modules to make it clean
if meta_register.has(target.__class__) or meta_register.has(target):
metainfo_vector = []
strategies_info = []
for strategy in self.strategies_vector:
metainfo = MetaInfo(strategy, target)
metainfo = ShardMetaInfo(strategy, target)
strategy.compute_cost = metainfo.compute_cost
strategy.memory_cost = metainfo.memory_cost
metainfo_vector.append(metainfo)
strategies_info.append(metainfo)

# attach metainfos to the handler
setattr(self, "metainfo_vector", metainfo_vector)
setattr(self, "strategies_info", strategies_info)

else:
logger = get_dist_logger()

@@ -313,7 +313,7 @@ class MetaInfoModuleHandler(ModuleHandler):
def register_strategy(self, compute_resharding_cost: bool = True) -> StrategiesVector:
"""
This method is inherited from NodeHandler. It will register the strategies first,
and rewrite the memory_cost and compute_cost of the strategy using the MetaInfo class.
and rewrite the memory_cost and compute_cost of the strategy using the ShardMetaInfo class.
"""
super().register_strategy(compute_resharding_cost=compute_resharding_cost)
target = self.get_target_function()

@@ -321,15 +321,15 @@ class MetaInfoModuleHandler(ModuleHandler):
# is not patched, we will use the default cost model to compute the cost.
# TODO: patch all torch functions and modules to make it clean
if meta_register.has(target.__class__) or meta_register.has(target):
metainfo_vector = []
strategies_info = []
for strategy in self.strategies_vector:
metainfo = MetaInfo(strategy, target)
metainfo = ShardMetaInfo(strategy, target)
strategy.compute_cost = metainfo.compute_cost
strategy.memory_cost = metainfo.memory_cost
metainfo_vector.append(metainfo)
strategies_info.append(metainfo)

# attach metainfos to the handler
setattr(self, "metainfo_vector", metainfo_vector)
setattr(self, "strategies_info", strategies_info)

else:
logger = get_dist_logger()

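For readers skimming the diff: the core of the handler change above is that each registered sharding strategy now gets its compute and memory cost overwritten from a profiled ShardMetaInfo object, and the collected objects are exposed on the handler as strategies_info. Below is a minimal, self-contained sketch of that pattern; the dataclasses and the constant costs are stand-ins invented for illustration, not the real colossalai types.

from dataclasses import dataclass
from typing import List


@dataclass
class Strategy:
    """Stand-in for a sharding strategy with default (unprofiled) costs."""
    name: str
    compute_cost: float = 0.0
    memory_cost: float = 0.0


@dataclass
class ShardMetaInfoSketch:
    """Stand-in for ShardMetaInfo: pretends to profile a strategy for a target op."""
    strategy: Strategy
    compute_cost: float = 1.0
    memory_cost: float = 2.0


class HandlerSketch:

    def __init__(self, strategies: List[Strategy]):
        self.strategies_vector = strategies

    def register_strategy(self) -> List[Strategy]:
        strategies_info = []
        for strategy in self.strategies_vector:
            metainfo = ShardMetaInfoSketch(strategy)
            # overwrite the default costs with the profiled ones
            strategy.compute_cost = metainfo.compute_cost
            strategy.memory_cost = metainfo.memory_cost
            strategies_info.append(metainfo)
        # expose the per-strategy meta information for later passes
        setattr(self, "strategies_info", strategies_info)
        return self.strategies_vector


handler = HandlerSketch([Strategy("S0"), Strategy("S1")])
handler.register_strategy()
assert all(s.compute_cost == 1.0 for s in handler.strategies_vector)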
@@ -137,9 +137,9 @@ class StrategiesConstructor:
shard_option=self.solver_options.shard_option,
solver_perference=self.solver_options.solver_perference)
handler.register_strategy()
# attach metainfo_vector to node
if hasattr(handler, 'metainfo_vector'):
setattr(node, 'metainfo_vector', handler.metainfo_vector)
# attach strategies_info to node
if hasattr(handler, 'strategies_info'):
setattr(node, 'strategies_info', handler.strategies_info)

# call_function node
elif node.op == 'call_function':

@@ -150,9 +150,9 @@ class StrategiesConstructor:
shard_option=self.solver_options.shard_option,
solver_perference=self.solver_options.solver_perference)
handler.register_strategy()
# attach metainfo_vector to node
if hasattr(handler, 'metainfo_vector'):
setattr(node, 'metainfo_vector', handler.metainfo_vector)
# attach strategies_info to node
if hasattr(handler, 'strategies_info'):
setattr(node, 'strategies_info', handler.strategies_info)

# call_method node
elif node.op == 'call_method':

@@ -163,9 +163,9 @@ class StrategiesConstructor:
shard_option=self.solver_options.shard_option,
solver_perference=self.solver_options.solver_perference)
handler.register_strategy()
# attach metainfo_vector to node
if hasattr(handler, 'metainfo_vector'):
setattr(node, 'metainfo_vector', handler.metainfo_vector)
# attach strategies_info to node
if hasattr(handler, 'strategies_info'):
setattr(node, 'strategies_info', handler.strategies_info)

# output node
elif node.op == 'output':

@@ -1,10 +1,12 @@
import pytest
import torch
import torch.nn.functional as F

from colossalai._analyzer.fx.graph_module import ColoGraphModule
from colossalai._analyzer.fx.passes import shape_prop_pass
from colossalai._analyzer.fx.tracer.tracer import ColoTracer
from colossalai.auto_parallel.passes.runtime_preparation_pass import size_value_converting_pass
from colossalai.device.device_mesh import DeviceMesh
from colossalai.fx.graph_module import ColoGraphModule
from colossalai.fx.tracer import ColoTracer
from colossalai.tensor.sharding_spec import ShardingSpec

@@ -33,6 +35,7 @@ def recover_narrow(gm, narrow_node):
return gm

@pytest.mark.skip('ShapeProp is not compatible with PyTorch 1.11.0')
def test_size_value_converting_pass():
model = TestModule()
physical_mesh_id = torch.arange(0, 4)

@@ -40,14 +43,14 @@ def test_size_value_converting_pass():
device_mesh = DeviceMesh(physical_mesh_id, mesh_shape)
meta_args = {'x': torch.rand(4, 8).to('meta')}
input = torch.rand(4, 8)
tracer = ColoTracer()
tracer = ColoTracer(bias_addition_split=True)
graph = tracer.trace(root=model, meta_args=meta_args)

x_node = list(graph.nodes)[0]
x_sharding_spec = ShardingSpec(device_mesh, entire_shape=(4, 8), dim_partition_dict={0: [0]})
setattr(x_node, 'sharding_spec', x_sharding_spec)
gm = ColoGraphModule(model, graph)
gm = insert_narrow(gm, x_node)
shape_prop_pass(gm, *meta_args.values())
gm.recompile()
size = gm(input)
assert size == torch.Size([2, 8])

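The pattern that repeats across these updated tests is: build a ColoTracer(bias_addition_split=True), trace the model with meta arguments, wrap the graph in a graph module, run shape_prop_pass over the meta inputs, and recompile. The sketch below condenses that flow; the toy model and tensor shapes are placeholders, and it assumes the colossalai._analyzer imports shown in the diff are available in your installation.

import torch
import torch.nn as nn

from colossalai._analyzer.fx.graph_module import ColoGraphModule
from colossalai._analyzer.fx.passes import shape_prop_pass
from colossalai._analyzer.fx.tracer.tracer import ColoTracer


class TinyModel(nn.Module):
    # placeholder model, not part of the original tests

    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(8, 8)

    def forward(self, x):
        return self.linear(x)


def trace_with_shape_prop(model: nn.Module, meta_args: dict) -> ColoGraphModule:
    # bias_addition_split=True splits the bias addition out of linear/conv,
    # matching how the updated tests construct the tracer
    tracer = ColoTracer(bias_addition_split=True)
    graph = tracer.trace(root=model, meta_args=meta_args)
    gm = ColoGraphModule(model, graph, model.__class__.__name__)
    # propagate meta shapes through the traced graph before recompiling
    shape_prop_pass(gm, *meta_args.values())
    gm.recompile()
    return gm


if __name__ == '__main__':
    gm = trace_with_shape_prop(TinyModel(), {'x': torch.rand(4, 8).to('meta')})
    print(gm.graph)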
@@ -4,7 +4,12 @@ import pytest
import torch
import torch.multiprocessing as mp

from colossalai.auto_parallel.tensor_shard.initialize import initialize_model
try:
from colossalai.auto_parallel.tensor_shard.initialize import initialize_model
NO_CODEGEN = False
except:
NO_CODEGEN = True

from colossalai.device.device_mesh import DeviceMesh
from colossalai.initialize import launch
from colossalai.logging import disable_existing_loggers

@@ -77,6 +82,7 @@ def check_conv_module(rank, world_size, port):

@run_on_environment_flag(name='AUTO_PARALLEL')
@pytest.mark.skipif(NO_CODEGEN, reason='No codegen found')
@pytest.mark.dist
@rerun_if_address_is_in_use()
def test_bias_addition_module():

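The recurring change in the test files above and below wraps the initialize_model import in a try/except and records whether it succeeded, so the distributed tests are skipped rather than erroring out when the codegen-dependent module cannot be imported. A minimal sketch of that guard pattern, using a hypothetical optional module name in place of the real import:

import pytest

# Optional dependency guard: `some_optional_module` is a hypothetical name,
# standing in for an import that may fail (e.g. when codegen is unavailable).
try:
    from some_optional_module import initialize_model    # noqa: F401
    NO_CODEGEN = False
except ImportError:
    NO_CODEGEN = True


@pytest.mark.skipif(NO_CODEGEN, reason='No codegen found')
def test_needs_codegen():
    # body only runs when the optional import succeeded
    assert not NO_CODEGEN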
@@ -8,13 +8,15 @@ import torch.nn as nn
from torch.utils.checkpoint import checkpoint
from transformers.pytorch_utils import Conv1D

from colossalai.auto_parallel.tensor_shard.initialize import initialize_model
try:
from colossalai.auto_parallel.tensor_shard.initialize import initialize_model
NO_CODEGEN = False
except:
NO_CODEGEN = True

from colossalai.device.device_mesh import DeviceMesh
from colossalai.fx.graph_module import ColoGraphModule
from colossalai.fx.tracer import ColoTracer
from colossalai.initialize import launch
from colossalai.logging import disable_existing_loggers
from colossalai.tensor.shape_consistency import ShapeConsistencyManager
from colossalai.testing import rerun_if_address_is_in_use
from colossalai.testing.pytest_wrapper import run_on_environment_flag
from colossalai.utils import free_port

@@ -43,6 +45,7 @@ def check_act_ckpt(rank, world_size, port):
disable_existing_loggers()
launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
model = GPT2MLPWithCkpt(intermediate_size=4 * HIDDEN_SIZE, hidden_size=HIDDEN_SIZE)
input = torch.rand(1, 64, HIDDEN_SIZE)
input_sample = {
'hidden_states': torch.rand(1, 64, HIDDEN_SIZE).to('meta'),
}

@@ -54,10 +57,11 @@ def check_act_ckpt(rank, world_size, port):
gm = initialize_model(model, input_sample, device_mesh)
code = gm.module.graph.python_code('self').src
assert "runtime_comm_spec_apply_1 = colossalai_auto_parallel_passes_runtime_apply_pass_runtime_comm_spec_apply(linear_1, comm_actions_dict, 12, 'linear_1')" in code
assert "view_3 = colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_0, False, view_1, comm_actions_dict, use_reentrant=True)" in code
assert "view_3 = torch.utils.checkpoint.checkpoint(self.checkpoint_0, view_1, comm_actions_dict, use_reentrant=False)" in code

@run_on_environment_flag(name='AUTO_PARALLEL')
@pytest.mark.skipif(NO_CODEGEN, reason='No codegen found')
@pytest.mark.dist
@rerun_if_address_is_in_use()
def test_mlp_layer():

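The updated assertion above expects the generated code to call PyTorch's own torch.utils.checkpoint.checkpoint with use_reentrant=False, instead of colossalai's custom activation checkpoint wrapper. A small standalone example of non-reentrant activation checkpointing with that API (the block and shapes are arbitrary placeholders):

import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint


class MLPBlock(nn.Module):
    # arbitrary block used only to demonstrate checkpointing

    def __init__(self, dim: int = 64):
        super().__init__()
        self.fc1 = nn.Linear(dim, 4 * dim)
        self.fc2 = nn.Linear(4 * dim, dim)

    def forward(self, x):
        return self.fc2(torch.relu(self.fc1(x)))


block = MLPBlock()
x = torch.rand(2, 64, requires_grad=True)

# Non-reentrant checkpointing: activations inside `block` are recomputed
# during backward instead of being stored during the forward pass.
y = checkpoint(block, x, use_reentrant=False)
y.sum().backward()
print(x.grad.shape)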
@@ -6,7 +6,12 @@ import torch
import torch.multiprocessing as mp
from torch.nn.parallel import DistributedDataParallel as DDP

from colossalai.auto_parallel.tensor_shard.initialize import initialize_model
try:
from colossalai.auto_parallel.tensor_shard.initialize import initialize_model
NO_CODEGEN = False
except:
NO_CODEGEN = True

from colossalai.device.device_mesh import DeviceMesh
from colossalai.initialize import launch
from colossalai.logging import disable_existing_loggers

@@ -93,6 +98,7 @@ def check_compatibility_with_ddp(rank, world_size, port):

@run_on_environment_flag(name='AUTO_PARALLEL')
@pytest.mark.skipif(NO_CODEGEN, reason='No codegen found')
@pytest.mark.dist
@rerun_if_address_is_in_use()
def test_compatibility_with_ddp():

@@ -6,7 +6,12 @@ import torch
import torch.multiprocessing as mp
from torch.nn.parallel import DistributedDataParallel as DDP

from colossalai.auto_parallel.tensor_shard.initialize import initialize_model
try:
from colossalai.auto_parallel.tensor_shard.initialize import initialize_model
NO_CODEGEN = False
except:
NO_CODEGEN = True

from colossalai.device.device_mesh import DeviceMesh
from colossalai.initialize import launch
from colossalai.logging import disable_existing_loggers

@@ -101,6 +106,7 @@ def check_auto_parallel_with_gemini(rank, world_size, port):

@run_on_environment_flag(name='AUTO_PARALLEL')
@pytest.mark.skipif(NO_CODEGEN, reason='No codegen found')
@pytest.mark.dist
@rerun_if_address_is_in_use()
def test_auto_parallel_with_gemini():

@@ -5,8 +5,11 @@ import torch.nn as nn
from torch.fx import GraphModule
from transformers.pytorch_utils import Conv1D

from colossalai._analyzer.fx.graph_module import ColoGraphModule
from colossalai._analyzer.fx.passes import shape_prop_pass
# from colossalai.fx.tracer.tracer import ColoTracer
from colossalai._analyzer.fx.tracer.tracer import ColoTracer
from colossalai.auto_parallel.tensor_shard.utils.factory import find_repeat_blocks
from colossalai.fx.tracer.tracer import ColoTracer
from colossalai.testing import parameterize
from colossalai.testing.pytest_wrapper import run_on_environment_flag

@@ -83,11 +86,12 @@ def test_repeat_blocks(model_cls):

model = model_cls(4 * HIDDEN_DIM, HIDDEN_DIM, NUM_REPEAT_BLOCKS)

tracer = ColoTracer()
tracer = ColoTracer(bias_addition_split=True)
input_sample = {'x': torch.rand(BATCH_SIZE, SEQ_LENGTH, HIDDEN_DIM).to('meta')}
graph = tracer.trace(root=model, meta_args=input_sample)

gm = GraphModule(model, graph, model.__class__.__name__)
shape_prop_pass(gm, *input_sample.values())
gm.recompile()

node_list = list(graph.nodes)

@@ -10,15 +10,23 @@ import torch.multiprocessing as mp
import transformers
from torch.fx import GraphModule

from colossalai.auto_parallel.tensor_shard.initialize import (
ModuleWrapper,
build_strategy_constructor,
solve_solution,
transform_to_sharded_model,
)
from colossalai._analyzer.fx.passes.shape_prop import shape_prop_pass
# from colossalai.fx.tracer.tracer import ColoTracer
from colossalai._analyzer.fx.tracer.tracer import ColoTracer

try:
from colossalai.auto_parallel.tensor_shard.initialize import (
ModuleWrapper,
build_strategy_constructor,
solve_solution,
transform_to_sharded_model,
)
NO_CODEGEN = False
except:
NO_CODEGEN = True

from colossalai.auto_parallel.tensor_shard.sharding_strategy import ShardingSpec
from colossalai.device.device_mesh import DeviceMesh
from colossalai.fx.tracer.tracer import ColoTracer
from colossalai.initialize import launch
from colossalai.logging import disable_existing_loggers
from colossalai.tensor.shape_consistency import to_global

@@ -52,9 +60,8 @@ def _check_module_grad(module: torch.nn.Module, origin_param_dict: Dict[str, tor
param_sharding_spec = best_sharding_spec_dict[new_name]
grad_to_compare = copy.deepcopy(param_grad)
param_grad_global = to_global(grad_to_compare, param_sharding_spec)

try:
assert_close_loose(param_grad_global, origin_param_grad, rtol=1e-03, atol=1e-03)
assert_close_loose(param_grad_global, origin_param_grad, rtol=1e-03, atol=1e-05)
except:
difference = param_grad_global - origin_param_grad
avg_diff = difference.abs().sum() / difference.numel()

@@ -66,7 +73,7 @@ def check_attention_layer(rank, model_cls, world_size, port):
disable_existing_loggers()
launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')

config = transformers.GPT2Config(n_position=64, n_layer=1, n_head=16, n_embd=HIDDEN_DIM)
config = transformers.GPT2Config(n_position=64, n_layer=2, n_head=16, n_embd=HIDDEN_DIM)

if model_cls == GPT2MLP:
model = model_cls(intermediate_size=4 * config.hidden_size, config=config).to('cuda')

@@ -111,15 +118,17 @@ def check_attention_layer(rank, model_cls, world_size, port):
# [[0, 1]
# [2, 3]]
device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True)
tracer = ColoTracer()
tracer = ColoTracer(bias_addition_split=True)

graph = tracer.trace(root=model, meta_args=meta_input_sample)
gm = GraphModule(model, graph, model.__class__.__name__)
shape_prop_pass(gm, *meta_input_sample.values())
gm.recompile()

strategies_constructor = build_strategy_constructor(graph, device_mesh, 'standard', 'replicated', 'standard')
solution = solve_solution(gm, strategies_constructor, memory_budget=-1)
gm, sharding_spec_dicts = transform_to_sharded_model(gm, solution, device_mesh, strategies_constructor)
gm, sharding_spec_dicts = transform_to_sharded_model(gm, meta_input_sample, solution, device_mesh,
strategies_constructor)
gm = ModuleWrapper(gm, *sharding_spec_dicts)

nodes = [strategies_vector.node for strategies_vector in strategies_constructor.leaf_strategies]

@@ -176,6 +185,7 @@ def check_attention_layer(rank, model_cls, world_size, port):

@run_on_environment_flag(name='AUTO_PARALLEL')
@pytest.mark.skipif(NO_CODEGEN, reason="no codegen module")
@pytest.mark.dist
@parameterize('model_cls', [GPT2MLP, GPT2Block, GPT2Attention, GPT2Model])
@rerun_if_address_is_in_use()

@@ -3,11 +3,12 @@ import torch.nn as nn
import transformers
from torch.fx import GraphModule

from colossalai._analyzer.fx.passes.shape_prop import shape_prop_pass
from colossalai._analyzer.fx.tracer.tracer import ColoTracer
from colossalai.auto_parallel.tensor_shard.constants import BATCHNORM_MODULE_OP
from colossalai.auto_parallel.tensor_shard.options import SolverOptions
from colossalai.auto_parallel.tensor_shard.solver import CostGraph, GraphAnalyser, Solver, StrategiesConstructor
from colossalai.device.device_mesh import DeviceMesh
from colossalai.fx.tracer.tracer import ColoTracer
from colossalai.tensor.shape_consistency import ShapeConsistencyManager
from colossalai.testing import parameterize
from colossalai.testing.pytest_wrapper import run_on_environment_flag

@@ -21,7 +22,7 @@ HIDDEN_DIM = 384
@run_on_environment_flag(name='AUTO_PARALLEL')
@parameterize('model_cls', [GPT2Block, GPT2Attention, GPT2MLP, GPT2Model])
def test_self_attention_block(model_cls):
config = transformers.GPT2Config(n_position=64, n_layer=12, n_head=16, n_embd=HIDDEN_DIM)
config = transformers.GPT2Config(n_position=64, n_layer=2, n_head=16, n_embd=HIDDEN_DIM)
if model_cls == GPT2MLP:
model = model_cls(intermediate_size=4 * config.hidden_size, config=config)
else:

@@ -33,7 +34,7 @@ def test_self_attention_block(model_cls):
device_mesh = DeviceMesh(physical_mesh_id, mesh_shape)
shape_consistency_manager = ShapeConsistencyManager()

tracer = ColoTracer()
tracer = ColoTracer(bias_addition_split=True)
if model_cls == GPT2MLP:
input_sample = {
'hidden_states': torch.rand(BATCH_SIZE, SEQ_LENGTH, HIDDEN_DIM).to('meta'),

@@ -52,6 +53,7 @@ def test_self_attention_block(model_cls):
graph = tracer.trace(root=model, meta_args=input_sample)

gm = GraphModule(model, graph, model.__class__.__name__)
shape_prop_pass(gm, *input_sample.values())
print(gm.graph)
gm.recompile()
solver_options = SolverOptions()

@@ -1,8 +1,11 @@
import pytest
import torch
import torch.nn as nn

from colossalai._analyzer.fx.graph_module import ColoGraphModule
from colossalai._analyzer.fx.passes import shape_prop_pass
from colossalai._analyzer.fx.tracer.tracer import ColoTracer
from colossalai.auto_parallel.tensor_shard.solver import GraphAnalyser
from colossalai.fx import ColoGraphModule, ColoTracer

class LinearModel(nn.Module):

@@ -22,15 +25,14 @@ class LinearModel(nn.Module):
return out

@pytest.mark.skip('meta tensor has some bugs in 1.11')
def test_liveness_analysis():
model = LinearModel()
tracer = ColoTracer()
graph = tracer.trace(model,
meta_args={
'x1': torch.rand(4, 4, device='meta'),
'x2': torch.rand(4, 4, device='meta')
})
tracer = ColoTracer(bias_addition_split=True)
meta_args = {'x1': torch.rand(4, 4, device='meta'), 'x2': torch.rand(4, 4, device='meta')}
graph = tracer.trace(model, meta_args=meta_args)
gm = ColoGraphModule(root=model, graph=graph, class_name=model.__class__.__name__)
shape_prop_pass(gm, *meta_args.values())

graph_analyser = GraphAnalyser(gm)
liveness_list = graph_analyser.liveness_analysis()

@@ -24,7 +24,7 @@ from colossalai.utils import free_port
from tests.test_auto_parallel.test_tensor_shard.test_metainfo.utils import print_results

if torch.__version__ >= '1.12.0':
from colossalai.auto_parallel.meta_profiler import MetaInfo, meta_register
from colossalai.auto_parallel.meta_profiler import ShardMetaInfo, meta_register

@pytest.mark.skipif(torch.__version__ < '1.12.0', reason="need pytorch 1.12.0 or higher for aten level operations")

@@ -17,7 +17,7 @@ from colossalai.utils import free_port
from tests.test_auto_parallel.test_tensor_shard.test_metainfo.utils import mem_test_for_node_strategy

if torch.__version__ >= '1.12.0':
from colossalai.auto_parallel.meta_profiler import MetaInfo, meta_register
from colossalai.auto_parallel.meta_profiler import ShardMetaInfo, meta_register

class MyModule(nn.Module):

@@ -24,7 +24,7 @@ from colossalai.utils import free_port
from tests.test_auto_parallel.test_tensor_shard.test_metainfo.utils import print_results

if torch.__version__ >= '1.12.0':
from colossalai.auto_parallel.meta_profiler import MetaInfo, meta_register
from colossalai.auto_parallel.meta_profiler import ShardMetaInfo, meta_register

@pytest.mark.skipif(torch.__version__ < '1.12.0', reason="need pytorch 1.12.0 or higher for aten level operations")

@@ -23,7 +23,7 @@ from colossalai.utils import free_port
from tests.test_auto_parallel.test_tensor_shard.test_metainfo.utils import mem_test_for_node_strategy, print_results

if torch.__version__ >= '1.12.0':
from colossalai.auto_parallel.meta_profiler import MetaInfo, meta_register
from colossalai.auto_parallel.meta_profiler import ShardMetaInfo, meta_register

def _batchnorm_module_mem_test(rank, world_size, port):

@@ -24,7 +24,7 @@ from colossalai.utils import free_port
from tests.test_auto_parallel.test_tensor_shard.test_metainfo.utils import print_results

if torch.__version__ >= '1.12.0':
from colossalai.auto_parallel.meta_profiler import MetaInfo, meta_register
from colossalai.auto_parallel.meta_profiler import ShardMetaInfo, meta_register

class SplitModule(nn.Module):

@@ -22,7 +22,7 @@ from colossalai.utils import free_port
from tests.test_auto_parallel.test_tensor_shard.test_metainfo.utils import print_results

if torch.__version__ >= '1.12.0':
from colossalai.auto_parallel.meta_profiler import MetaInfo, meta_register
from colossalai.auto_parallel.meta_profiler import ShardMetaInfo, meta_register

@pytest.mark.skipif(torch.__version__ < '1.12.0', reason="need pytorch 1.12.0 or higher for aten level operations")

@@ -5,16 +5,19 @@ from typing import Dict, List
import torch
from torch.fx import GraphModule

from colossalai._analyzer.fx.graph_module import ColoGraphModule
from colossalai._analyzer.fx.passes import shape_prop_pass
# from colossalai.fx.tracer.tracer import ColoTracer
from colossalai._analyzer.fx.tracer.tracer import ColoTracer
from colossalai.auto_parallel.passes.runtime_apply_pass import runtime_apply_pass
from colossalai.auto_parallel.passes.runtime_preparation_pass import runtime_preparation_pass
from colossalai.auto_parallel.tensor_shard.options import SolverOptions
from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationDataType, TrainCycleItem
from colossalai.auto_parallel.tensor_shard.solver import StrategiesConstructor
from colossalai.device.device_mesh import DeviceMesh
from colossalai.fx.tracer.tracer import ColoTracer

if torch.__version__ >= '1.12.0':
from colossalai.auto_parallel.meta_profiler import MetaInfo
from colossalai.auto_parallel.meta_profiler import ShardMetaInfo

def mem_test_for_node_strategy(rank: int,

@@ -30,14 +33,16 @@ def mem_test_for_node_strategy(rank: int,
model_to_shard, args_to_shard, kwargs_to_shard = copy.deepcopy(model), copy.deepcopy(input_args), copy.deepcopy(
input_kwargs)

tracer = ColoTracer()
tracer = ColoTracer(bias_addition_split=True)
input_sample = {}
for input_arg, meta_arg_name in zip(input_args, meta_arg_names):
input_sample[meta_arg_name] = torch.rand(input_arg.shape).to('meta')
for meta_kwarg_name, input_kwarg in input_kwargs.items():
input_sample[meta_kwarg_name] = torch.rand(input_kwarg.shape).to('meta')
graph = tracer.trace(root=model_to_shard, meta_args=input_sample)
gm = GraphModule(model_to_shard, graph, model_to_shard.__class__.__name__)
gm = ColoGraphModule(model_to_shard, graph, model_to_shard.__class__.__name__)
shape_prop_pass(gm, *input_sample.values())
gm.recompile()
solver_options = SolverOptions()
strategies_constructor = StrategiesConstructor(graph, device_mesh, solver_options)
strategies_constructor.build_strategies_and_cost()

@@ -108,10 +113,10 @@ def mem_test_for_node_strategy(rank: int,

# estimated memory
if target_node.op == "call_module":
metainfo = MetaInfo(target_node.strategies_vector[strategy_index],
target_node.graph.owning_module.get_submodule(target_node.target))
metainfo = ShardMetaInfo(target_node.strategies_vector[strategy_index],
target_node.graph.owning_module.get_submodule(target_node.target))
else:
metainfo = MetaInfo(target_node.strategies_vector[strategy_index], target_node.target)
metainfo = ShardMetaInfo(target_node.strategies_vector[strategy_index], target_node.target)

print("estimated memory:")
print(

@@ -1,126 +0,0 @@
import torch

from colossalai.auto_parallel.tensor_shard.options import SolverOptions
from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationDataType
from colossalai.auto_parallel.tensor_shard.solver import CostGraph, GraphAnalyser, Solver, StrategiesConstructor
from colossalai.device.device_mesh import DeviceMesh
from colossalai.fx import ColoGraphModule, ColoTracer
from colossalai.testing.pytest_wrapper import run_on_environment_flag

def _param_resharding_cost_assertion(node):
for strategy in node.strategies_vector:
for prev_node, resharding_cost in strategy.resharding_costs.items():
if strategy.get_op_data_by_name(str(prev_node)).type == OperationDataType.PARAM:
for cost in resharding_cost:
assert cost.fwd == 0
assert cost.bwd == 0
assert cost.total == 0

class LinearModel(torch.nn.Module):

def __init__(self, in_features, out_features):
super().__init__()
self.linear = torch.nn.Linear(in_features, out_features)

def forward(self, x):
x = self.linear(x)
x = x * 2

return x

class ConvModel(torch.nn.Module):

def __init__(self, in_channels, out_channels, kernel_size, bias=True):
super().__init__()
self.conv = torch.nn.Conv2d(in_channels=in_channels,
out_channels=out_channels,
kernel_size=kernel_size,
bias=bias)

def forward(self, x):
x = self.conv(x)
x = x * 2

return x

@run_on_environment_flag(name='AUTO_PARALLEL')
def test_linear_module():
model = LinearModel(4, 8)
physical_mesh_id = torch.arange(0, 4)
mesh_shape = (2, 2)
# [[0, 1]
# [2, 3]]
device_mesh = DeviceMesh(physical_mesh_id, mesh_shape)
tracer = ColoTracer()
# graph():
# %x : torch.Tensor [#users=1] = placeholder[target=x]
# %linear_weight : [#users=1] = get_attr[target=linear.weight]
# %linear_bias : [#users=1] = get_attr[target=linear.bias]
# %linear : [#users=1] = call_function[target=torch._C._nn.linear](args = (%x, %linear_weight), kwargs = {})
# %add : [#users=1] = call_function[target=operator.add](args = (%linear, %linear_bias), kwargs = {})
# %mul : [#users=1] = call_function[target=operator.mul](args = (%add, 2), kwargs = {})
# return mul
graph = tracer.trace(root=model, meta_args={'x': torch.rand(4, 4).to('meta')})
# def forward(self, x : torch.Tensor):
# linear_weight = self.linear.weight
# linear_bias = self.linear.bias
# linear = torch._C._nn.linear(x, linear_weight); x = linear_weight = None
# add = linear + linear_bias; linear = linear_bias = None
# mul = add * 2; add = None
# return mul
gm = ColoGraphModule(model, graph)
gm.recompile()
node_list = list(graph.nodes)

solver_options = SolverOptions()
strategies_constructor = StrategiesConstructor(graph, device_mesh, solver_options)
strategies_constructor.build_strategies_and_cost()
linear_node = node_list[3]
_param_resharding_cost_assertion(linear_node)

@run_on_environment_flag(name='AUTO_PARALLEL')
def test_conv_module():
model = ConvModel(3, 6, 2)
physical_mesh_id = torch.arange(0, 4)
mesh_shape = (2, 2)
# [[0, 1]
# [2, 3]]
device_mesh = DeviceMesh(physical_mesh_id, mesh_shape)
tracer = ColoTracer()
# graph():
# %x : torch.Tensor [#users=1] = placeholder[target=x]
# %conv_weight : [#users=1] = get_attr[target=conv.weight]
# %conv_bias : [#users=1] = get_attr[target=conv.bias]
# %conv2d : [#users=1] = call_function[target=torch.conv2d](args = (%x, %conv_weight), kwargs = {})
# %view : [#users=1] = call_method[target=view](args = (%conv_bias, [1, -1, 1, 1]), kwargs = {})
# %add : [#users=1] = call_function[target=operator.add](args = (%conv2d, %view), kwargs = {})
# %mul : [#users=1] = call_function[target=operator.mul](args = (%add, 2), kwargs = {})
# return mul
graph = tracer.trace(root=model, meta_args={'x': torch.rand(4, 3, 64, 64).to('meta')})
# def forward(self, x : torch.Tensor):
# conv_weight = self.conv.weight
# conv_bias = self.conv.bias
# conv2d = torch.conv2d(x, conv_weight); x = conv_weight = None
# view = conv_bias.view([1, -1, 1, 1]); conv_bias = None
# add = conv2d + view; conv2d = view = None
# mul = add * 2; add = None
# return mul
gm = ColoGraphModule(model, graph)

gm.recompile()
node_list = list(graph.nodes)
conv_node = node_list[3]
solver_options = SolverOptions()
strategies_constructor = StrategiesConstructor(graph, device_mesh, solver_options)
strategies_constructor.build_strategies_and_cost()
_param_resharding_cost_assertion(conv_node)

if __name__ == '__main__':
test_linear_module()
test_conv_module()

@@ -1,86 +0,0 @@
import copy
from functools import partial

import pytest
import torch
import torch.multiprocessing as mp
import torch.nn as nn

from colossalai.auto_parallel.tensor_shard.initialize import initialize_model
from colossalai.device.device_mesh import DeviceMesh
from colossalai.initialize import launch
from colossalai.logging import disable_existing_loggers
from colossalai.testing import assert_close, rerun_if_address_is_in_use
from colossalai.testing.pytest_wrapper import run_on_environment_flag
from colossalai.utils import free_port

class ConvModel(nn.Module):

def __init__(self, c_in, c_out):
super().__init__()
self.conv = nn.Conv2d(c_in, c_out, kernel_size=3, padding=1, bias=False)

def forward(self, x):
x = self.conv(x)
x = torch.flatten(x)
return x

def check_apply(rank, world_size, port):
disable_existing_loggers()
launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
input = torch.rand(4, 4, 4, 4).cuda()
test_input = copy.deepcopy(input)
# graph():
# %x : torch.Tensor [#users=1] = placeholder[target=x]
# %conv : [#users=1] = call_module[target=conv](args = (%mul,), kwargs = {})
# return conv
model = ConvModel(4, 4).cuda()
test_model = copy.deepcopy(model)
physical_mesh_id = torch.arange(0, 4)
mesh_shape = (2, 2)
# [[0, 1]
# [2, 3]]
device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True)
meta_args = {'x': torch.rand(4, 4, 4, 4).to('meta')}
gm = initialize_model(model, meta_args, device_mesh)

output = gm(input)
origin_output = test_model(test_input)
assert output.equal(origin_output)
origin_loss = origin_output.sum()
loss = output.sum()

origin_loss.backward()
loss.backward()

grad_0 = test_model.conv.weight.grad.narrow(0, 0, 1)
grad_1 = test_model.conv.weight.grad.narrow(0, 1, 1)
grad_2 = test_model.conv.weight.grad.narrow(0, 2, 1)
grad_3 = test_model.conv.weight.grad.narrow(0, 3, 1)

if rank == 0:
assert_close(gm.module.conv.weight.grad.data, grad_0.data)
elif rank == 1:
assert_close(gm.module.conv.weight.grad.data, grad_1.data)
elif rank == 2:
assert_close(gm.module.conv.weight.grad.data, grad_2.data)
elif rank == 3:
assert_close(gm.module.conv.weight.grad.data, grad_3.data)
else:
raise ValueError(f'rank {rank} does not exist.')

# skip this test due to pulp not installed in CI environment
@run_on_environment_flag(name='AUTO_PARALLEL')
@pytest.mark.dist
@rerun_if_address_is_in_use()
def test_apply():
world_size = 4
run_func = partial(check_apply, world_size=world_size, port=free_port())
mp.spawn(run_func, nprocs=world_size)

if __name__ == '__main__':
test_apply()

@@ -2,11 +2,13 @@ import torch
from torch.fx import GraphModule
from torchvision.models import resnet50

from colossalai._analyzer.fx.passes import shape_prop_pass
# from colossalai.fx.tracer.tracer import ColoTracer
from colossalai._analyzer.fx.tracer.tracer import ColoTracer
from colossalai.auto_parallel.tensor_shard.constants import BATCHNORM_MODULE_OP
from colossalai.auto_parallel.tensor_shard.options import SolverOptions
from colossalai.auto_parallel.tensor_shard.solver import CostGraph, GraphAnalyser, Solver, StrategiesConstructor
from colossalai.device.device_mesh import DeviceMesh
from colossalai.fx.tracer.tracer import ColoTracer
from colossalai.tensor.shape_consistency import ShapeConsistencyManager
from colossalai.testing.pytest_wrapper import run_on_environment_flag

@@ -20,7 +22,7 @@ def test_cost_graph():
device_mesh = DeviceMesh(physical_mesh_id, mesh_shape)
shape_consistency_manager = ShapeConsistencyManager()

tracer = ColoTracer()
tracer = ColoTracer(bias_addition_split=True)
model = resnet50(num_classes=100000)
input_sample = {'x': torch.rand(128, 3, 224, 224).to('meta')}

@@ -50,6 +52,7 @@ def test_cost_graph():
# %fc : [#users=1] = call_module[target=fc](args = (%flatten,), kwargs = {})
# return fc
gm = GraphModule(model, graph, model.__class__.__name__)
shape_prop_pass(gm, *input_sample.values())
gm.recompile()

solver_options = SolverOptions()