[autoparallel] adapt autoparallel with new analyzer (#3261)

* [autoparallel] adapt autoparallel with new analyzer

* fix all node handler tests

* polish

* polish
YuliangLiu0306 2023-03-30 17:47:24 +08:00 committed by GitHub
parent e78a1e949a
commit fee2af8610
36 changed files with 481 additions and 386 deletions
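Taken together, the changes replace the old `colossalai.fx` tracer entry points with the new `colossalai._analyzer` ones and turn shape propagation into an explicit pass. A minimal sketch of the updated workflow, assembled from the imports and calls that recur throughout this diff (the toy `nn.Linear` model is an assumption for illustration):

```python
import torch
import torch.nn as nn

from colossalai._analyzer.fx.graph_module import ColoGraphModule
from colossalai._analyzer.fx.passes.shape_prop import shape_prop_pass
from colossalai._analyzer.fx.tracer.tracer import ColoTracer

model = nn.Linear(16, 32)                      # toy model (assumption)
tracer = ColoTracer(bias_addition_split=True)  # new analyzer tracer
meta_args = {'input': torch.rand(4, 16).to('meta')}
graph = tracer.trace(model, meta_args=meta_args)
gm = ColoGraphModule(model, graph)
shape_prop_pass(gm, *meta_args.values())       # shape propagation is now explicit
```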


@@ -446,10 +446,7 @@ if version.parse(torch.__version__) >= version.parse('1.12.0'):
@register_meta(aten.embedding_dense_backward.default)
def meta_embedding_dense_backward(grad_output: torch.Tensor, indices: torch.Tensor, num_weights, padding_idx,
                                  scale_grad_by_freq):
-    return new((num_weights, grad_output.size(-1)),
-               dtype=grad_output.dtype,
-               device=grad_output.device,
-               layout=grad_output.layout)
+    return new((num_weights, grad_output.size(-1)), dtype=grad_output.dtype, layout=grad_output.layout)
# ============================== Dropout ===========================================
# https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/Dropout.cpp


@@ -51,7 +51,10 @@ def _normalize_tuple(x):
def _current_device(module):
-    return next(module.parameters()).device
+    try:
+        return next(module.parameters()).device
+    except StopIteration:
+        return torch.device('cpu')
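The guard matters because `next()` on an empty parameter iterator raises `StopIteration`. A standalone reproduction (independent of the patch):

```python
import torch
import torch.nn as nn

def current_device(module: nn.Module) -> torch.device:
    try:
        return next(module.parameters()).device
    except StopIteration:    # parameter-less modules such as nn.ReLU()
        return torch.device('cpu')

print(current_device(nn.Linear(2, 2)))  # device of the weights
print(current_device(nn.ReLU()))        # cpu, via the fallback
```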
@compatibility(is_backward_compatible=False)
@@ -120,15 +123,18 @@ class ShapeProp(torch.fx.Interpreter):
return t.to('meta')
            if isinstance(elem, MetaTensor):
+                if getattr(self, '_is_param', False):
+                    return torch.nn.Parameter(_convert_meta(elem._tensor))
                return _convert_meta(elem._tensor)
            elif isinstance(elem, torch.Tensor):
+                if isinstance(elem, torch.nn.Parameter):
+                    return torch.nn.Parameter(_convert_meta(elem))
                return _convert_meta(elem)
            else:
                return elem
        # unwrap_fn = lambda elem: elem._tensor if isinstance(elem, MetaTensor) else elem
        is_pure_tensor = lambda elem: isinstance(elem, MetaTensor) and not isinstance(elem, torch.nn.Parameter)
n_info = MetaInfo(n)
n_info.outputs = _normalize_tuple(r)
@@ -149,7 +155,11 @@ class ShapeProp(torch.fx.Interpreter):
        n_info.inputs = tuple(v for v in args if is_pure_tensor(v)) + \
                        tuple(v for v in kwargs.values() if is_pure_tensor(v))
-        n._meta_data = tree_map(unwrap_fn, _normalize_tuple(r))    # align with SPMD
+        # align with SPMD
+        if isinstance(r, (tuple, list)):
+            n._meta_data = tree_map(unwrap_fn, _normalize_tuple(r))
+        else:
+            n._meta_data = unwrap_fn(r)
n_info.global_ctx = self.global_hook.ctx
n_info.curr_ctx = self.global_hook.ctx.copy()
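The split reflects that `tree_map` is only needed for container outputs; single tensors are unwrapped directly. A quick illustration of `tree_map` on nested structures, using PyTorch's (private) pytree utility:

```python
import torch
from torch.utils._pytree import tree_map

out = (torch.ones(2), [torch.zeros(3)])
# tree_map visits every leaf tensor inside the nested structure.
print(tree_map(lambda t: t.shape, out))   # (torch.Size([2]), [torch.Size([3])])
```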
@@ -175,10 +185,48 @@ class ShapeProp(torch.fx.Interpreter):
        Return
            Any: The value returned by the function invocation
        """
+        convert_to_param = False
+        if target in (torch.transpose, torch.reshape) and isinstance(args[0], torch.nn.parameter.Parameter):
+            convert_to_param = True
        if target in self._custom_dispatch_func:
-            return self._custom_dispatch_func[target](*args, **kwargs)
+            res = self._custom_dispatch_func[target](*args, **kwargs)
        else:
-            return super().call_function(target, args, kwargs)
+            res = super().call_function(target, args, kwargs)
+        if convert_to_param:
+            return torch.nn.Parameter(res)
+        else:
+            return res
+
+    def call_method(self, target: 'Target', args: Tuple[Any, ...], kwargs: Dict[str, Any]) -> Any:
+        """
+        Execute a ``call_method`` node and return the result.
+
+        Args:
+            target (Target): The call target for this node. See
+                `Node <https://pytorch.org/docs/master/fx.html#torch.fx.Node>`__ for
+                details on semantics
+            args (Tuple): Tuple of positional args for this invocation
+            kwargs (Dict): Dict of keyword arguments for this invocation
+
+        Return
+            Any: The value returned by the method invocation
+        """
+        # args[0] is the `self` object for this method call
+        self_obj, *args_tail = args
+        target_method = getattr(self_obj.__class__, target)
+        convert_to_parameter = False
+        if target_method in (torch.Tensor.view, torch.Tensor.transpose) and isinstance(
+                args[0], torch.nn.parameter.Parameter):
+            convert_to_parameter = True
+        # Execute the method and return the result
+        assert isinstance(target, str)
+        res = getattr(self_obj, target)(*args_tail, **kwargs)
+        if convert_to_parameter:
+            return torch.nn.Parameter(res)
+        else:
+            return res

    def propagate(self, *args, device=None):
        """


@@ -21,111 +21,69 @@ def linear_impl(input, weight, bias=None):
@register_tracer_impl(F.conv1d, name='_bias_addition_impl')
-def conv1d_impl(input, weight, bias=None, stride=_single(1), padding=_single(0), dilation=_single(1), groups=1):
+def conv1d_impl(input, weight, **kwargs):
+    bias = kwargs.get('bias', None)
    if bias is None:
-        return F.conv1d(input, weight, stride=stride, padding=padding, dilation=dilation, groups=groups)
+        return F.conv1d(input, weight, **kwargs)
    else:
-        return F.conv1d(input, weight, stride=stride, padding=padding, dilation=dilation, groups=groups) + bias.reshape(
-            (-1, 1))
+        new_kwargs = dict(kwargs)
+        new_kwargs['bias'] = None
+        return F.conv1d(input, weight, **new_kwargs) + bias.reshape((-1, 1))
@register_tracer_impl(F.conv2d, name='_bias_addition_impl')
-def conv2d_impl(input, weight, bias=None, stride=_pair(1), padding=_pair(0), dilation=_pair(1), groups=1):
+def conv2d_impl(input, weight, **kwargs):
+    bias = kwargs.get('bias', None)
    if bias is None:
-        return F.conv2d(input, weight, stride=stride, padding=padding, dilation=dilation, groups=groups)
+        return F.conv2d(input, weight, **kwargs)
    else:
-        return F.conv2d(input, weight, stride=stride, padding=padding, dilation=dilation, groups=groups) + bias.reshape(
-            (-1, 1, 1))
+        new_kwargs = dict(kwargs)
+        new_kwargs['bias'] = None
+        return F.conv2d(input, weight, **new_kwargs) + bias.reshape((-1, 1, 1))
@register_tracer_impl(F.conv3d, name='_bias_addition_impl')
-def conv3d_impl(input, weight, bias=None, stride=_triple(1), padding=_triple(0), dilation=_triple(1), groups=1):
+def conv3d_impl(input, weight, **kwargs):
+    bias = kwargs.get('bias', None)
    if bias is None:
-        return F.conv3d(input, weight, stride=stride, padding=padding, dilation=dilation, groups=groups)
+        return F.conv3d(input, weight, **kwargs)
    else:
-        return F.conv3d(input, weight, stride=stride, padding=padding, dilation=dilation, groups=groups) + bias.reshape(
-            (-1, 1, 1, 1))
+        new_kwargs = dict(kwargs)
+        new_kwargs['bias'] = None
+        return F.conv3d(input, weight, **new_kwargs) + bias.reshape((-1, 1, 1, 1))
@register_tracer_impl(F.conv_transpose1d, name='_bias_addition_impl')
-def conv_transpose1d_impl(input,
-                          weight,
-                          bias=None,
-                          stride=_single(1),
-                          padding=_single(0),
-                          output_padding=_single(0),
-                          groups=1,
-                          dilation=_single(1)):
+def conv_transpose1d_impl(input, weight, **kwargs):
+    bias = kwargs.get('bias', None)
    if bias is None:
-        return F.conv_transpose1d(input,
-                                  weight,
-                                  stride=stride,
-                                  padding=padding,
-                                  output_padding=output_padding,
-                                  groups=groups,
-                                  dilation=dilation)
+        return F.conv_transpose1d(input, weight, **kwargs)
    else:
-        return F.conv_transpose1d(input,
-                                  weight,
-                                  stride=stride,
-                                  padding=padding,
-                                  output_padding=output_padding,
-                                  groups=groups,
-                                  dilation=dilation) + bias.reshape((-1, 1))
+        new_kwargs = dict(kwargs)
+        new_kwargs['bias'] = None
+        return F.conv_transpose1d(input, weight, **new_kwargs) + bias.reshape((-1, 1))
@register_tracer_impl(F.conv_transpose2d, name='_bias_addition_impl')
-def conv_transpose2d_impl(input,
-                          weight,
-                          bias=None,
-                          stride=_pair(1),
-                          padding=_pair(0),
-                          output_padding=_pair(0),
-                          groups=1,
-                          dilation=_pair(1)):
+def conv_transpose2d_impl(input, weight, **kwargs):
+    bias = kwargs.get('bias', None)
    if bias is None:
-        return F.conv_transpose2d(input,
-                                  weight,
-                                  stride=stride,
-                                  padding=padding,
-                                  output_padding=output_padding,
-                                  groups=groups,
-                                  dilation=dilation)
+        return F.conv_transpose2d(input, weight, **kwargs)
    else:
-        return F.conv_transpose2d(input,
-                                  weight,
-                                  stride=stride,
-                                  padding=padding,
-                                  output_padding=output_padding,
-                                  groups=groups,
-                                  dilation=dilation) + bias.reshape((-1, 1, 1))
+        new_kwargs = dict(kwargs)
+        new_kwargs['bias'] = None
+        return F.conv_transpose2d(input, weight, **new_kwargs) + bias.reshape((-1, 1, 1))
@register_tracer_impl(F.conv_transpose3d, name='_bias_addition_impl')
-def conv_transpose3d_impl(input,
-                          weight,
-                          bias=None,
-                          stride=_triple(1),
-                          padding=_triple(0),
-                          output_padding=_triple(0),
-                          groups=1,
-                          dilation=_triple(1)):
+def conv_transpose3d_impl(input, weight, **kwargs):
+    bias = kwargs.get('bias', None)
    if bias is None:
-        return F.conv_transpose3d(input,
-                                  weight,
-                                  stride=stride,
-                                  padding=padding,
-                                  output_padding=output_padding,
-                                  groups=groups,
-                                  dilation=dilation)
+        return F.conv_transpose3d(input, weight, **kwargs)
    else:
-        return F.conv_transpose3d(input,
-                                  weight,
-                                  stride=stride,
-                                  padding=padding,
-                                  output_padding=output_padding,
-                                  groups=groups,
-                                  dilation=dilation) + bias.reshape((-1, 1, 1, 1))
+        new_kwargs = dict(kwargs)
+        new_kwargs['bias'] = None
+        return F.conv_transpose3d(input, weight, **new_kwargs) + bias.reshape((-1, 1, 1, 1))
@register_tracer_impl(torch.addmm, name='_bias_addition_impl')
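The idea behind these tracer impls is that a biased conv equals the unbiased conv plus the bias broadcast over the remaining dims, which lets the tracer split bias addition into its own graph node. A standalone equivalence check (not the tracer itself):

```python
import torch
import torch.nn.functional as F

x = torch.randn(2, 3, 8, 8)
w = torch.randn(6, 3, 3, 3)
b = torch.randn(6)

fused = F.conv2d(x, w, bias=b)
split = F.conv2d(x, w, bias=None) + b.reshape((-1, 1, 1))
print(torch.allclose(fused, split, atol=1e-6))  # True
```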


@@ -70,14 +70,28 @@ class MetaInfo:
        if self._strategy is not None and self._target is not None:
            self.compute_metainfo()

-    def compute_sharded_opdata(self, operation_data: OperationData, sharding_spec: ShardingSpec) -> torch.Tensor:
+    def compute_sharded_opdata(self, operation_data: OperationData, sharding_spec: ShardingSpec):
        """
        Compute sharded opdata based on the given data and sharding spec.
        """
-        return OperationData(name=operation_data.name,
-                             data=torch.zeros(sharding_spec.get_sharded_shape_per_device(), device="meta"),
-                             type=operation_data.type,
-                             logical_shape=operation_data.logical_shape)
+        if isinstance(sharding_spec, ShardingSpec):
+            op_data = OperationData(name=operation_data.name,
+                                    data=torch.zeros(sharding_spec.get_sharded_shape_per_device(), device="meta"),
+                                    type=operation_data.type,
+                                    logical_shape=operation_data.logical_shape)
+        elif isinstance(sharding_spec, (list, tuple)):
+            data = operation_data.data
+            assert isinstance(data, (list, tuple)), f"Data should be a list or tuple, but got {type(data)}."
+            assert len(data) == len(sharding_spec), "Length of data and sharding spec should be the same."
+            sharded_data = []
+            for d, s in zip(data, sharding_spec):
+                sharded_data.append(torch.zeros(s.get_sharded_shape_per_device(), device="meta"))
+            op_data = OperationData(name=operation_data.name, data=sharded_data, type=operation_data.type)
+        else:
+            raise ValueError(f"Sharding spec should be ShardingSpec or list, but got {type(sharding_spec)}.")
+        return op_data
def compute_metainfo(self):
"""


@@ -387,12 +387,13 @@ def module_params_sharding_pass(gm: torch.fx.GraphModule, device_mesh: DeviceMesh):
    # This stream is created for overlapping the communication and computation.
    reduction_stream = torch.cuda.Stream()

-    def _add_hook_for_grad_communication(node, param):
+    def _add_hook_for_grad_communication(node, param, name=None):
        comm_actions = node.best_strategy.communication_actions

-        def _filter_param_to_hook(node, op_data, comm_action):
-            if node.op == 'call_module' and op_data.type == OperationDataType.PARAM and op_data.name == param.name and comm_action.comm_type == CommType.HOOK:
+        def _filter_param_to_hook(node, op_data, comm_action, name):
+            if node.op == 'call_module' and op_data.type == OperationDataType.PARAM and op_data.name == name and comm_action.comm_type == CommType.HOOK:
return True
if node.op == 'get_attr' and isinstance(
node._meta_data, torch.nn.parameter.Parameter) and comm_action.comm_type == CommType.HOOK:
@@ -402,7 +403,7 @@ def module_params_sharding_pass(gm: torch.fx.GraphModule, device_mesh: DeviceMesh):
for operation_data, comm_action in comm_actions.items():
comm_spec_to_use = comm_action.comm_spec
# register hook to the parameters
-            if _filter_param_to_hook(node, operation_data, comm_action):
+            if _filter_param_to_hook(node, operation_data, comm_action, name=name):
def wrapper(param, comm_spec, stream, overlap):
@@ -442,7 +443,7 @@ def module_params_sharding_pass(gm: torch.fx.GraphModule, device_mesh: DeviceMesh):
            param = _shard_param(param, target_sharding_spec)
            setattr(target_module, name, param)

-            _add_hook_for_grad_communication(node, param)
+            _add_hook_for_grad_communication(node, param, name)
sharded_buffer_dict = {}
# apply the sharding spec of buffers
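The HOOK-type communication actions rely on PyTorch's gradient hooks; in a real run the registered wrapper launches a communication op on `reduction_stream`. A minimal standalone illustration of the mechanism, with a stand-in action:

```python
import torch

p = torch.nn.Parameter(torch.ones(3))
# Stand-in for a communication action; the pass would e.g. all-reduce here.
p.register_hook(lambda grad: grad * 0.5)

(p * 2).sum().backward()
print(p.grad)   # tensor([1., 1., 1.]) instead of tensor([2., 2., 2.])
```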


@@ -81,7 +81,10 @@ class AddBMMFunctionHandler(NodeHandler):
    def get_strategy_generator(self) -> List[StrategyGenerator]:
        op_data_mapping = self.get_operation_data_mapping()
        generators = []
-        generators.append(BatchedMatMulStrategyGenerator(op_data_mapping, self.device_mesh))
+        generator = BatchedMatMulStrategyGenerator(op_data_mapping, self.device_mesh)
+        # addbmm will shrink the first batch dim
+        generator.squeeze_batch_dim = True
+        generators.append(generator)
return generators
def post_process(self, strategy: ShardingStrategy) -> Union[ShardingStrategy, List[ShardingStrategy]]:
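Why the batch dim is squeezed: `torch.addbmm` sums the batched matmul over the batch dimension, so a `(b, n, m) x (b, m, p)` bmm collapses to an `(n, p)` output. A quick shape check:

```python
import torch

inp = torch.zeros(3, 5)
b1 = torch.randn(4, 3, 2)
b2 = torch.randn(4, 2, 5)
print(torch.addbmm(inp, b1, b2).shape)   # torch.Size([3, 5])
print(torch.bmm(b1, b2).shape)           # torch.Size([4, 3, 5]) before the batch sum
```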


@@ -776,10 +776,6 @@ class BatchedMatMulStrategyGenerator(MatMulStrategyGenerator):
            bias_op_data = self.op_data['bias']
            assert bias_op_data.data.dim() < 3 and len(bias_op_data.logical_shape) == 2
-        if self.op_data['output'].data.dim() == 2:
-            # addbmm will shrink the first batch dim
-            self.squeeze_batch_dim = True
def update_compute_cost(self, strategy: ShardingStrategy) -> ShardingStrategy:
fwd_compute_cost = self.op_data['input'].data.shape[-1] * reduce(operator.mul,
self.op_data['output'].data.shape)


@@ -386,7 +386,7 @@ def meta_local_scalar_dense(self: torch.Tensor):
@register_meta(aten.where.self)
def meta_where_self(condition: torch.Tensor, self: torch.Tensor, other: torch.Tensor):
    result_type = torch.result_type(self, other)
-    return torch.empty_like(self, dtype=result_type)
+    return torch.empty_like(condition + self + other, dtype=result_type)
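The fix makes the meta kernel for `where` honor broadcasting: the output shape is the broadcast of all three inputs, which the `condition + self + other` expression reproduces. A standalone check:

```python
import torch

cond = torch.rand(4, 1) > 0.5
a = torch.randn(1, 5)
b = torch.randn(4, 5)
print(torch.where(cond, a, b).shape)                          # torch.Size([4, 5])
print(torch.broadcast_shapes(cond.shape, a.shape, b.shape))   # torch.Size([4, 5])
```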
@register_meta(aten.index.Tensor)


@@ -1,22 +1,20 @@
-from faulthandler import disable
from functools import partial
-from xml.dom import WrongDocumentErr
import pytest
import torch
import torch.multiprocessing as mp
import torch.nn as nn
-from typing_extensions import Self
+from colossalai._analyzer.fx.graph_module import ColoGraphModule
+from colossalai._analyzer.fx.passes.shape_prop import shape_prop_pass
+from colossalai._analyzer.fx.tracer.tracer import ColoTracer
from colossalai.auto_parallel.tensor_shard.node_handler import LinearFunctionHandler
from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
OperationData,
OperationDataType,
ShardingStrategy,
StrategiesVector,
)
from colossalai.device.device_mesh import DeviceMesh
-from colossalai.fx import ColoGraphModule, ColoTracer
from colossalai.initialize import launch
from colossalai.logging import disable_existing_loggers
from colossalai.testing import parameterize, rerun_if_address_is_in_use
@@ -96,7 +94,7 @@ def check_addmm_function_handler(rank, input_shape, model_cls, world_size, port):
                                     meta_arg_names=meta_arg_names,
                                     node_type='bias_module')
-    tracer = ColoTracer()
+    tracer = ColoTracer(bias_addition_split=True)
# graph():
# %input_1 : torch.Tensor [#users=1] = placeholder[target=input]
# %m1 : torch.Tensor [#users=1] = placeholder[target=m1]
@@ -109,6 +107,7 @@ def check_addmm_function_handler(rank, input_shape, model_cls, world_size, port):
# return add
graph = tracer.trace(model, meta_args=meta_args_for_tracer)
gm = ColoGraphModule(model, graph)
+    shape_prop_pass(gm, *meta_args_for_tracer.values())
# [input_1, m1, m2, addmm, output]
node_list = list(graph.nodes)
linear_node = node_list[4]


@@ -5,10 +5,12 @@ import torch
import torch.multiprocessing as mp
import torch.nn as nn
+from colossalai._analyzer.fx.graph_module import ColoGraphModule
+from colossalai._analyzer.fx.passes.shape_prop import shape_prop_pass
+from colossalai._analyzer.fx.tracer.tracer import ColoTracer
from colossalai.auto_parallel.tensor_shard.node_handler.batch_norm_handler import BatchNormModuleHandler
from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType, StrategiesVector
from colossalai.device.device_mesh import DeviceMesh
-from colossalai.fx import ColoGraphModule, ColoTracer
from colossalai.initialize import launch
from colossalai.logging import disable_existing_loggers
from colossalai.testing import assert_close, parameterize, rerun_if_address_is_in_use
@@ -38,13 +40,15 @@ def check_bn_module_handler(rank, world_size, port):
                                     strategy_number=strategy_number,
                                     input_args=[input],
                                     meta_arg_names=['input'])
-    tracer = ColoTracer()
+    tracer = ColoTracer(bias_addition_split=True)
    # graph():
    #     %input_1 : torch.Tensor [#users=1] = placeholder[target=input]
    #     %_0 : [#users=1] = call_module[target=0](args = (%input_1,), kwargs = {})
    #     return _0
-    graph = tracer.trace(model, meta_args={"input": torch.rand(4, 16, 64, 64).to('meta')})
+    meta_args = {"input": torch.rand(4, 16, 64, 64).to('meta')}
+    graph = tracer.trace(model, meta_args=meta_args)
    gm = ColoGraphModule(model, graph)
+    shape_prop_pass(gm, *meta_args.values())
bn_mod_node = list(graph.nodes)[1]
strategies_vector = StrategiesVector(bn_mod_node)


@@ -1,14 +1,14 @@
-from faulthandler import disable
from functools import partial
-from xml.dom import WrongDocumentErr
import pytest
import torch
import torch.multiprocessing as mp
import torch.nn as nn
import torch.nn.functional as F
-from typing_extensions import Self
+from colossalai._analyzer.fx.graph_module import ColoGraphModule
+from colossalai._analyzer.fx.passes.shape_prop import shape_prop_pass
+from colossalai._analyzer.fx.tracer.tracer import ColoTracer
from colossalai.auto_parallel.tensor_shard.node_handler import LinearFunctionHandler, LinearModuleHandler
from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
OperationData,
@@ -17,12 +17,10 @@ from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
    StrategiesVector,
)
from colossalai.device.device_mesh import DeviceMesh
-from colossalai.fx import ColoGraphModule, ColoTracer
from colossalai.initialize import launch
from colossalai.logging import disable_existing_loggers
from colossalai.testing import assert_close, parameterize, rerun_if_address_is_in_use
from colossalai.testing.pytest_wrapper import run_on_environment_flag
-from colossalai.testing.utils import parameterize
from colossalai.utils import free_port
from tests.test_auto_parallel.test_tensor_shard.test_node_handler.utils import numerical_test_for_node_strategy
@@ -66,7 +64,7 @@ def check_linear_module_handler(rank, world_size, port):
                                     meta_arg_names=meta_arg_names,
                                     node_type='bias_module')
-    tracer = ColoTracer()
+    tracer = ColoTracer(bias_addition_split=True)
# graph():
# %x : torch.Tensor [#users=1] = placeholder[target=x]
# %weight : [#users=1] = get_attr[target=weight]
@@ -74,8 +72,10 @@ def check_linear_module_handler(rank, world_size, port):
# %linear : [#users=1] = call_function[target=torch._C._nn.linear](args = (%x, %weight), kwargs = {})
# %add : [#users=1] = call_function[target=operator.add](args = (%linear, %bias), kwargs = {})
# return add
-    graph = tracer.trace(model, meta_args={"x": torch.rand(4, 4, 4, 16).to('meta')})
+    meta_args = {"x": torch.rand(4, 4, 4, 16).to('meta')}
+    graph = tracer.trace(model, meta_args=meta_args)
    gm = ColoGraphModule(model, graph)
+    shape_prop_pass(gm, *meta_args.values())
linear_mod_node = list(graph.nodes)[3]
strategies_vector = StrategiesVector(linear_mod_node)


@@ -1,13 +1,13 @@
-from faulthandler import disable
from functools import partial
-from xml.dom import WrongDocumentErr
import pytest
import torch
import torch.multiprocessing as mp
import torch.nn as nn
-from typing_extensions import Self
+from colossalai._analyzer.fx.graph_module import ColoGraphModule
+from colossalai._analyzer.fx.passes.shape_prop import shape_prop_pass
+from colossalai._analyzer.fx.tracer.tracer import ColoTracer
from colossalai.auto_parallel.tensor_shard.node_handler import LinearFunctionHandler, LinearModuleHandler
from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
OperationData,
@@ -16,12 +16,10 @@ from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
    StrategiesVector,
)
from colossalai.device.device_mesh import DeviceMesh
-from colossalai.fx import ColoGraphModule, ColoTracer
from colossalai.initialize import launch
from colossalai.logging import disable_existing_loggers
from colossalai.testing import assert_close, parameterize, rerun_if_address_is_in_use
from colossalai.testing.pytest_wrapper import run_on_environment_flag
-from colossalai.testing.utils import parameterize
from colossalai.utils import free_port
from tests.test_auto_parallel.test_tensor_shard.test_node_handler.utils import numerical_test_for_node_strategy
@@ -62,9 +60,11 @@ def check_linear_module_handler(rank, bias, world_size, port):
                                     meta_arg_names=meta_arg_names,
                                     node_type='bias_module')
-    tracer = ColoTracer()
-    graph = tracer.trace(model, meta_args={"x": torch.rand(4, 4, 4, 16).to('meta')})
+    tracer = ColoTracer(bias_addition_split=True)
+    meta_args = {"x": torch.rand(4, 4, 4, 16).to('meta')}
+    graph = tracer.trace(model, meta_args=meta_args)
    gm = ColoGraphModule(model, graph)
+    shape_prop_pass(gm, *meta_args.values())
linear_mod_node = list(graph.nodes)[3]
strategies_vector = StrategiesVector(linear_mod_node)


@@ -5,10 +5,12 @@ import torch
import torch.multiprocessing as mp
import torch.nn as nn
+from colossalai._analyzer.fx.graph_module import ColoGraphModule
+from colossalai._analyzer.fx.passes.shape_prop import shape_prop_pass
+from colossalai._analyzer.fx.tracer.tracer import ColoTracer
from colossalai.auto_parallel.tensor_shard.node_handler import BinaryElementwiseHandler
from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType, StrategiesVector
from colossalai.device.device_mesh import DeviceMesh
-from colossalai.fx import ColoGraphModule, ColoTracer
from colossalai.initialize import launch
from colossalai.logging import disable_existing_loggers
from colossalai.testing import assert_close, parameterize, rerun_if_address_is_in_use
@@ -52,10 +54,11 @@ def check_binary_elementwise_handler_with_tensor(rank, op, other_dim, world_size, port):
                                     input_args=input_args,
                                     meta_arg_names=meta_arg_names)
-    tracer = ColoTracer()
+    tracer = ColoTracer(bias_addition_split=True)
    meta_args = {'x1': torch.rand(4, 4).to('meta'), 'x2': torch.rand([4] * other_dim).to('meta')}
    graph = tracer.trace(model, meta_args=meta_args)
    gm = ColoGraphModule(model, graph)
+    shape_prop_pass(gm, *meta_args.values())
op_node = list(graph.nodes)[2]
strategies_vector = StrategiesVector(op_node)
@@ -172,12 +175,11 @@ def check_binary_elementwise_handler_with_int(rank, op, other_dim, model_cls, world_size, port):
                                     strategy_number=strategy_number,
                                     input_args=input_args,
                                     meta_arg_names=meta_arg_names)
-    tracer = ColoTracer()
+    tracer = ColoTracer(bias_addition_split=True)
    meta_args = {'x1': torch.rand(4, 4).to('meta')}
    graph = tracer.trace(model, meta_args=meta_args)
-    print(graph)
-    # assert False
    gm = ColoGraphModule(model, graph)
+    shape_prop_pass(gm, *meta_args.values())
if model_cls == BEOpModelWithNodeConst:
op_node = list(graph.nodes)[2]


@@ -5,10 +5,12 @@ import torch
import torch.multiprocessing as mp
import torch.nn as nn
+from colossalai._analyzer.fx.graph_module import ColoGraphModule
+from colossalai._analyzer.fx.passes.shape_prop import shape_prop_pass
+from colossalai._analyzer.fx.tracer.tracer import ColoTracer
from colossalai.auto_parallel.tensor_shard.node_handler import BMMFunctionHandler
from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType, StrategiesVector
from colossalai.device.device_mesh import DeviceMesh
-from colossalai.fx import ColoGraphModule, ColoTracer
from colossalai.initialize import launch
from colossalai.logging import disable_existing_loggers
from colossalai.testing import assert_close, parameterize, rerun_if_address_is_in_use
@@ -52,13 +54,11 @@ def check_2d_device_mesh(rank, module, world_size, port):
                                     strategy_number=strategy_number,
                                     input_args=input_args,
                                     meta_arg_names=meta_arg_names)
-    tracer = ColoTracer()
-    graph = tracer.trace(model,
-                         meta_args={
-                             "x1": torch.rand(4, 8, 16).to('meta'),
-                             'x2': torch.rand(4, 16, 8).to('meta')
-                         })
+    tracer = ColoTracer(bias_addition_split=True)
+    meta_args = {'x1': torch.rand(4, 8, 16).to('meta'), 'x2': torch.rand(4, 16, 8).to('meta')}
+    graph = tracer.trace(model, meta_args=meta_args)
    gm = ColoGraphModule(model, graph)
+    shape_prop_pass(gm, *meta_args.values())
linear_mod_node = list(graph.nodes)[2]
strategies_vector = StrategiesVector(linear_mod_node)
@@ -147,13 +147,11 @@ def check_1d_device_mesh(rank, module, world_size, port):
                                     strategy_number=strategy_number,
                                     input_args=input_args,
                                     meta_arg_names=meta_arg_names)
-    tracer = ColoTracer()
-    graph = tracer.trace(model,
-                         meta_args={
-                             "x1": torch.rand(4, 8, 16).to('meta'),
-                             'x2': torch.rand(4, 16, 8).to('meta')
-                         })
+    tracer = ColoTracer(bias_addition_split=True)
+    meta_args = {'x1': torch.rand(4, 8, 16).to('meta'), 'x2': torch.rand(4, 16, 8).to('meta')}
+    graph = tracer.trace(model, meta_args=meta_args)
    gm = ColoGraphModule(model, graph)
+    shape_prop_pass(gm, *meta_args.values())
linear_mod_node = list(graph.nodes)[2]
strategies_vector = StrategiesVector(linear_mod_node)
@@ -205,6 +203,7 @@ def check_1d_device_mesh(rank, module, world_size, port):
@run_on_environment_flag(name='AUTO_PARALLEL')
@parameterize('module', [BMMTensorMethodModule, BMMTorchFunctionModule])
@pytest.mark.dist
@rerun_if_address_is_in_use()
def test_bmm_handler(module):


@@ -5,10 +5,12 @@ import torch
import torch.multiprocessing as mp
import torch.nn as nn
+from colossalai._analyzer.fx.graph_module import ColoGraphModule
+from colossalai._analyzer.fx.passes.shape_prop import shape_prop_pass
+from colossalai._analyzer.fx.tracer.tracer import ColoTracer
from colossalai.auto_parallel.tensor_shard.node_handler.conv_handler import ConvFunctionHandler, ConvModuleHandler
from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType, StrategiesVector
from colossalai.device.device_mesh import DeviceMesh
-from colossalai.fx import ColoGraphModule, ColoTracer
from colossalai.initialize import launch
from colossalai.logging import disable_existing_loggers
from colossalai.testing import assert_close, parameterize, rerun_if_address_is_in_use
@@ -41,9 +43,11 @@ def check_conv_module_handler(rank, bias, world_size, port):
                                     strategy_number=strategy_number,
                                     input_args=[input],
                                     meta_arg_names=['input'])
-    tracer = ColoTracer()
-    graph = tracer.trace(model, meta_args={"input": torch.rand(4, 4, 64, 64).to('meta')})
+    tracer = ColoTracer(bias_addition_split=True)
+    meta_args = {'input': torch.rand(4, 4, 64, 64).to('meta')}
+    graph = tracer.trace(model, meta_args=meta_args)
    gm = ColoGraphModule(model, graph)
+    shape_prop_pass(gm, *meta_args.values())
conv_mod_node = list(graph.nodes)[1]
strategies_vector = StrategiesVector(conv_mod_node)
@@ -178,7 +182,7 @@ def check_conv_function_handler(rank, bias, world_size, port):
                                     meta_arg_names=meta_arg_names,
                                     input_kwargs=input_kwargs)
-    tracer = ColoTracer()
+    tracer = ColoTracer(bias_addition_split=True)
# graph():
# %input_1 : torch.Tensor [#users=1] = placeholder[target=input]
# %others : torch.Tensor [#users=1] = placeholder[target=others]
@@ -189,6 +193,7 @@ def check_conv_function_handler(rank, bias, world_size, port):
meta_args['bias'] = torch.rand(16).to('meta')
graph = tracer.trace(model, meta_args=meta_args)
gm = ColoGraphModule(model, graph)
+    shape_prop_pass(gm, *meta_args.values())
if bias:
conv_mod_node = list(graph.nodes)[3]


@@ -1,11 +1,13 @@
import torch
import torch.nn as nn
+from colossalai._analyzer.fx.graph_module import ColoGraphModule
+from colossalai._analyzer.fx.passes.shape_prop import shape_prop_pass
+from colossalai._analyzer.fx.tracer.tracer import ColoTracer
from colossalai.auto_parallel.tensor_shard.node_handler import DefaultReshapeHandler
from colossalai.auto_parallel.tensor_shard.node_handler.conv_handler import ConvFunctionHandler
from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType, StrategiesVector
from colossalai.device.device_mesh import DeviceMesh
-from colossalai.fx import ColoGraphModule, ColoTracer
from colossalai.testing.pytest_wrapper import run_on_environment_flag
@@ -23,19 +25,20 @@ class ReshapeModel(nn.Module):
@run_on_environment_flag(name='AUTO_PARALLEL')
def test_reshape_handler():
    model = ReshapeModel()
-    tracer = ColoTracer()
+    tracer = ColoTracer(bias_addition_split=True)
    # graph():
    #     %input_1 : torch.Tensor [#users=1] = placeholder[target=input]
    #     %other : torch.Tensor [#users=1] = placeholder[target=other]
    #     %conv2d : [#users=1] = call_function[target=torch.conv2d](args = (%input_1, %other), kwargs = {})
    #     %view : [#users=1] = call_method[target=view](args = (%conv2d, 2, -1), kwargs = {})
    #     return view
-    graph = tracer.trace(model,
-                         meta_args={
-                             "input": torch.rand(4, 4, 64, 64).to('meta'),
-                             "other": torch.rand(4, 16, 3, 3).to('meta'),
-                         })
+    meta_args = {
+        "input": torch.rand(4, 4, 64, 64).to('meta'),
+        "other": torch.rand(16, 4, 3, 3).to('meta'),
+    }
+    graph = tracer.trace(model, meta_args=meta_args)
    gm = ColoGraphModule(model, graph)
+    shape_prop_pass(gm, *meta_args.values())
physical_mesh_id = torch.arange(0, 4)
mesh_shape = (2, 2)
@@ -67,13 +70,13 @@ def test_reshape_handler():
    assert mapping['input'].name == "conv2d"
    assert mapping['input'].data.is_meta
-    assert mapping['input'].data.shape == torch.Size([4, 4, 62, 62])
+    assert mapping['input'].data.shape == torch.Size([4, 16, 62, 62])
    assert mapping['input'].type == OperationDataType.ARG
-    assert mapping['input'].logical_shape == torch.Size([4, 4, 62, 62])
+    assert mapping['input'].logical_shape == torch.Size([4, 16, 62, 62])
    assert mapping['output'].name == "view"
    assert mapping['output'].data.is_meta
-    assert mapping['output'].data.shape == torch.Size([2, 30752])
+    assert mapping['output'].data.shape == torch.Size([2, 123008])
    assert mapping['output'].type == OperationDataType.OUTPUT
assert mapping['output'].type == OperationDataType.OUTPUT
# reshape handler is a following strategy handler, so the number of strategies is equal to the predecessor node.
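The corrected shapes follow directly from the conv arithmetic: a 3x3 conv with stride 1 and no padding maps 64 to 62, and `view(2, -1)` folds the remaining 4*16*62*62 elements into 2 rows of 123008. A quick standalone check:

```python
import torch

x = torch.rand(4, 4, 64, 64)
w = torch.rand(16, 4, 3, 3)       # (out_channels, in_channels, kH, kW)
out = torch.conv2d(x, w)
print(out.shape)                  # torch.Size([4, 16, 62, 62])
print(out.view(2, -1).shape)      # torch.Size([2, 123008])
```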


@@ -5,13 +5,15 @@ import torch
import torch.multiprocessing as mp
import torch.nn as nn
+from colossalai._analyzer.fx.graph_module import ColoGraphModule
+from colossalai._analyzer.fx.passes.shape_prop import shape_prop_pass
+from colossalai._analyzer.fx.tracer.tracer import ColoTracer
from colossalai.auto_parallel.tensor_shard.node_handler.embedding_handler import (
    EmbeddingFunctionHandler,
    EmbeddingModuleHandler,
)
from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType, StrategiesVector
from colossalai.device.device_mesh import DeviceMesh
-from colossalai.fx import ColoGraphModule, ColoTracer
from colossalai.initialize import launch
from colossalai.logging import disable_existing_loggers
from colossalai.testing import assert_close, parameterize, rerun_if_address_is_in_use
@@ -60,9 +62,11 @@ def check_embedding_module_handler(rank, world_size, port):
                                     input_args=[input],
                                     meta_arg_names=['input'])
-    tracer = ColoTracer()
-    graph = tracer.trace(model, meta_args={"input": torch.rand(4, 16, 16).to('meta')})
+    tracer = ColoTracer(bias_addition_split=True)
+    meta_args = {"input": torch.randint(NUM_EMBEDDINGS, (4, 16, 16)).to('meta')}
+    graph = tracer.trace(model, meta_args=meta_args)
    gm = ColoGraphModule(model, graph)
+    shape_prop_pass(gm, *meta_args.values())
embedding_node = list(graph.nodes)[1]
strategies_vector = StrategiesVector(embedding_node)
@@ -171,18 +175,19 @@ def check_embedding_function_handler(rank, world_size, port):
                                     input_args=input_args,
                                     meta_arg_names=meta_arg_names,
                                     input_kwargs=input_kwargs)
-    tracer = ColoTracer()
+    tracer = ColoTracer(bias_addition_split=True)
    # graph():
    #     %input_1 : torch.Tensor [#users=1] = placeholder[target=input]
    #     %others : torch.Tensor [#users=1] = placeholder[target=others]
    #     %embedding : [#users=1] = call_function[target=torch.nn.functional.embedding](args = (%input_1, %others), kwargs = {padding_idx: None, max_norm: None, norm_type: 2.0, scale_grad_by_freq: False, sparse: False})
    #     return embedding
    meta_args = {
-        "input": torch.rand(4, 16, 16).to('meta'),
+        "input": torch.randint(NUM_EMBEDDINGS, (4, 16, 16)).to('meta'),
        "others": torch.rand(NUM_EMBEDDINGS, EMBEDDING_DIMS).to('meta')
    }
    graph = tracer.trace(model, meta_args=meta_args)
    gm = ColoGraphModule(model, graph)
+    shape_prop_pass(gm, *meta_args.values())
embedding_node = list(graph.nodes)[2]
strategies_vector = StrategiesVector(embedding_node)
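The switch from `torch.rand` to `torch.randint` matters because `F.embedding` indexes the weight table, so the traced input must be integer-typed. A standalone illustration:

```python
import torch
import torch.nn.functional as F

weight = torch.randn(10, 4)           # 10 embeddings of dim 4
idx = torch.randint(10, (2, 3))       # integer indices into the table
print(F.embedding(idx, weight).shape)  # torch.Size([2, 3, 4])
```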


@@ -1,10 +1,13 @@
import pytest
import torch
import torch.nn as nn
+from colossalai._analyzer.fx.graph_module import ColoGraphModule
+from colossalai._analyzer.fx.passes.shape_prop import shape_prop_pass
+from colossalai._analyzer.fx.tracer.tracer import ColoTracer
from colossalai.auto_parallel.tensor_shard.node_handler.getattr_handler import GetattrHandler
from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType, StrategiesVector
from colossalai.device.device_mesh import DeviceMesh
-from colossalai.fx import ColoGraphModule, ColoTracer
class GetattrModel(nn.Module):
@@ -18,15 +21,18 @@ class GetattrModel(nn.Module):
        return weight

+@pytest.mark.skip('ShapeProp is not compatible with PyTorch 1.11.0')
def test_getattr_handler():
    model = GetattrModel()
-    tracer = ColoTracer()
+    tracer = ColoTracer(bias_addition_split=True)
    # graph():
    #     %input_1 : torch.Tensor [#users=0] = placeholder[target=input]
    #     %conv_weight : [#users=1] = get_attr[target=conv.weight]
    #     return conv_weight
-    graph = tracer.trace(model, meta_args={'input': torch.rand(4, 4, 64, 64).to('meta')})
+    meta_args = {'input': torch.rand(4, 4, 64, 64).to('meta')}
+    graph = tracer.trace(model, meta_args=meta_args)
    gm = ColoGraphModule(model, graph)
+    shape_prop_pass(gm, *meta_args.values())
physical_mesh_id = torch.arange(0, 4)
mesh_shape = (2, 2)
device_mesh = DeviceMesh(physical_mesh_id, mesh_shape)


@@ -5,13 +5,15 @@ import torch
import torch.multiprocessing as mp
import torch.nn as nn
+from colossalai._analyzer.fx.graph_module import ColoGraphModule
+from colossalai._analyzer.fx.passes.shape_prop import shape_prop_pass
+from colossalai._analyzer.fx.tracer.tracer import ColoTracer
from colossalai.auto_parallel.tensor_shard.node_handler.default_reshape_handler import DefaultReshapeHandler
from colossalai.auto_parallel.tensor_shard.node_handler.getitem_handler import GetItemHandler
from colossalai.auto_parallel.tensor_shard.node_handler.linear_handler import LinearFunctionHandler
from colossalai.auto_parallel.tensor_shard.node_handler.placeholder_handler import PlaceholderHandler
from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType, StrategiesVector
from colossalai.device.device_mesh import DeviceMesh
-from colossalai.fx import ColoGraphModule, ColoTracer
from colossalai.fx.tracer.meta_patch.patched_module import linear
from colossalai.initialize import launch
from colossalai.logging import disable_existing_loggers
@@ -58,15 +60,15 @@ def check_getitem_from_tensor_handler(rank, getitem_index, world_size, port):
                                     meta_arg_names=['input', 'other'],
                                     node_type='following')
-    tracer = ColoTracer()
-    graph = tracer.trace(model,
-                         meta_args={
-                             "input": torch.rand(8, 16, 64, 32).to('meta'),
-                             "other": torch.rand(64, 32).to('meta'),
-                         })
+    tracer = ColoTracer(bias_addition_split=True)
+    meta_args = {
+        "input": torch.rand(8, 16, 64, 32).to('meta'),
+        "other": torch.rand(64, 32).to('meta'),
+    }
+    graph = tracer.trace(model, meta_args=meta_args)
    gm = ColoGraphModule(model, graph)
+    shape_prop_pass(gm, *list(meta_args.values()))
linear_mod_node = list(graph.nodes)[2]
getitem_mod_node = list(graph.nodes)[3]
getitem_strategies_vector = StrategiesVector(getitem_mod_node)
@@ -129,10 +131,12 @@ def test_getitem_from_tuple_handler():
    #     %split : [#users=1] = call_function[target=torch.functional.split](args = (%conv2d, 2), kwargs = {dim: 0})
    #     %getitem : [#users=1] = call_function[target=operator.getitem](args = (%split, 1), kwargs = {})
    #     return getitem
-    graph = tracer.trace(model, meta_args={
-        "input": torch.rand(4, 4, 64, 64).to('meta'),
-    })
+    meta_args = {
+        "input": torch.rand(4, 4, 64, 64).to('meta'),
+    }
+    graph = tracer.trace(model, meta_args=meta_args)
    gm = ColoGraphModule(model, graph)
+    shape_prop_pass(gm, *meta_args.values())
physical_mesh_id = torch.arange(0, 4)
mesh_shape = (2, 2)


@@ -5,10 +5,12 @@ import torch
import torch.multiprocessing as mp
import torch.nn as nn
+from colossalai._analyzer.fx.graph_module import ColoGraphModule
+from colossalai._analyzer.fx.passes.shape_prop import shape_prop_pass
+from colossalai._analyzer.fx.tracer.tracer import ColoTracer
from colossalai.auto_parallel.tensor_shard.node_handler.layer_norm_handler import LayerNormModuleHandler
from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType, StrategiesVector
from colossalai.device.device_mesh import DeviceMesh
-from colossalai.fx import ColoGraphModule, ColoTracer
from colossalai.fx.tracer.meta_patch.patched_module import linear
from colossalai.initialize import launch
from colossalai.logging import disable_existing_loggers
@@ -40,13 +42,15 @@ def check_ln_module_handler(rank, world_size, port):
                                     strategy_number=strategy_number,
                                     input_args=input_args,
                                     meta_arg_names=meta_arg_names)
-    tracer = ColoTracer()
+    tracer = ColoTracer(bias_addition_split=True)
    # graph():
    #     %input_1 : torch.Tensor [#users=1] = placeholder[target=input]
    #     %_0 : [#users=1] = call_module[target=0](args = (%input_1,), kwargs = {})
    #     return _0
-    graph = tracer.trace(model, meta_args={"input": torch.rand(4, 16).to('meta')})
+    meta_args = {"input": torch.rand(4, 16).to('meta')}
+    graph = tracer.trace(model, meta_args=meta_args)
    gm = ColoGraphModule(model, graph)
+    shape_prop_pass(gm, *meta_args.values())
ln_mod_node = list(graph.nodes)[1]
strategies_vector = StrategiesVector(ln_mod_node)


@@ -5,6 +5,9 @@ import torch
import torch.multiprocessing as mp
import torch.nn as nn
+from colossalai._analyzer.fx.graph_module import ColoGraphModule
+from colossalai._analyzer.fx.passes.shape_prop import shape_prop_pass
+from colossalai._analyzer.fx.tracer.tracer import ColoTracer
from colossalai.auto_parallel.tensor_shard.node_handler import LinearFunctionHandler, LinearModuleHandler
from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
OperationData,
@@ -13,7 +16,6 @@ from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
    StrategiesVector,
)
from colossalai.device.device_mesh import DeviceMesh
-from colossalai.fx import ColoGraphModule, ColoTracer
from colossalai.initialize import launch
from colossalai.logging import disable_existing_loggers
from colossalai.testing import assert_close, parameterize, rerun_if_address_is_in_use
@@ -49,9 +51,11 @@ def check_linear_module_handler(rank, bias, input_shape, world_size, port):
                                     input_args=input_args,
                                     meta_arg_names=meta_arg_names)
-    tracer = ColoTracer()
-    graph = tracer.trace(model, meta_args={"input": torch.rand(input_shape).to('meta')})
+    tracer = ColoTracer(bias_addition_split=True)
+    meta_args = {"input": torch.rand(input_shape).cuda()}
+    graph = tracer.trace(model, meta_args=meta_args)
    gm = ColoGraphModule(model, graph)
+    shape_prop_pass(gm, *meta_args.values())
linear_mod_node = list(graph.nodes)[1]
strategies_vector = StrategiesVector(linear_mod_node)
@@ -196,13 +200,12 @@ def check_linear_function_handler(rank, bias, input_shape, world_size, port):
                                     input_args=input_args,
                                     meta_arg_names=meta_arg_names)
-    tracer = ColoTracer()
-    graph = tracer.trace(model,
-                         meta_args={
-                             "input": torch.rand(input_shape).to('meta'),
-                             'others': torch.rand(32, 16).to('meta')
-                         })
+    tracer = ColoTracer(bias_addition_split=True)
+    meta_args = {'input': torch.rand(input_shape).to('meta'), 'others': torch.rand(32, 16).to('meta')}
+    graph = tracer.trace(model, meta_args=meta_args)
    gm = ColoGraphModule(model, graph)
+    shape_prop_pass(gm, *meta_args.values())
if bias:
linear_func_node = list(graph.nodes)[3]
else:


@@ -2,6 +2,9 @@ import pytest
import torch
import torch.nn as nn
+from colossalai._analyzer.fx.graph_module import ColoGraphModule
+from colossalai._analyzer.fx.passes.shape_prop import shape_prop_pass
+from colossalai._analyzer.fx.tracer.tracer import ColoTracer
from colossalai.auto_parallel.tensor_shard.node_handler.matmul_handler import (
MatMulHandler,
MatMulType,
@@ -15,7 +18,6 @@ from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
    StrategiesVector,
)
from colossalai.device.device_mesh import DeviceMesh
-from colossalai.fx import ColoGraphModule, ColoTracer
from colossalai.testing.utils import parameterize
@@ -57,9 +59,11 @@ def test_matmul_node_handler(tensor_shapes):
    model = MatMulModule()

-    tracer = ColoTracer()
-    graph = tracer.trace(model, meta_args={"x1": x1.to('meta'), 'x2': x2.to('meta')})
+    tracer = ColoTracer(bias_addition_split=True)
+    meta_args = {"x1": x1.to('meta'), 'x2': x2.to('meta')}
+    graph = tracer.trace(model, meta_args=meta_args)
    gm = ColoGraphModule(model, graph)
+    shape_prop_pass(gm, *meta_args.values())
physical_mesh_id = torch.arange(0, 4)
print(graph)
@@ -124,7 +128,6 @@ def test_matmul_node_handler(tensor_shapes):
input_sharding_spec = strategy.get_sharding_spec_by_name('x1')
other_sharding_spec = strategy.get_sharding_spec_by_name('x2')
output_sharding_spec = strategy.get_sharding_spec_by_name('matmul')
if matmul_type == MatMulType.DOT:
# dot product will produce a scaler
# results should fulfill:
@@ -159,7 +162,10 @@ def test_matmul_node_handler(tensor_shapes):
        if len(other_shape) > 1:
            assert other_sharding_spec.sharding_sequence[-1] == output_sharding_spec.sharding_sequence[-1]
        if len(input_shape) > 1:
-            assert input_sharding_spec.sharding_sequence[-2] == output_sharding_spec.sharding_sequence[-2]
+            if len(other_shape) == 1:
+                assert input_sharding_spec.sharding_sequence[-2] == output_sharding_spec.sharding_sequence[-1]
+            else:
+                assert input_sharding_spec.sharding_sequence[-2] == output_sharding_spec.sharding_sequence[-2]
if len(other_shape) > 2:
assert other_sharding_spec.sharding_sequence[-2] == input_sharding_spec.sharding_sequence[-1]
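The new branch covers matrix-vector products: with a 1-D `other`, `matmul` contracts away the last dim entirely, so the input's second-to-last dim lines up with the output's last dim. A standalone shape check:

```python
import torch

a = torch.randn(8, 4, 6)
v = torch.randn(6)                 # 1-D `other`
print(torch.matmul(a, v).shape)    # torch.Size([8, 4]); the trailing dim is gone
```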


@@ -2,10 +2,12 @@ import pytest
import torch
import torch.nn as nn
+from colossalai._analyzer.fx.graph_module import ColoGraphModule
+from colossalai._analyzer.fx.passes.shape_prop import shape_prop_pass
+from colossalai._analyzer.fx.tracer.tracer import ColoTracer
from colossalai.auto_parallel.tensor_shard.node_handler.normal_pooling_handler import NormPoolingHandler
from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType, StrategiesVector
from colossalai.device.device_mesh import DeviceMesh
-from colossalai.fx import ColoGraphModule, ColoTracer
from colossalai.fx.tracer.meta_patch.patched_module import linear
from colossalai.testing.pytest_wrapper import run_on_environment_flag
@@ -13,14 +15,16 @@ from colossalai.testing.pytest_wrapper import run_on_environment_flag
@run_on_environment_flag(name='AUTO_PARALLEL')
def test_norm_pool_handler():
    model = nn.Sequential(nn.MaxPool2d(4, padding=1).to('meta'))
-    tracer = ColoTracer()
+    tracer = ColoTracer(bias_addition_split=True)
    # graph():
    #     %input_1 : torch.Tensor [#users=1] = placeholder[target=input]
    #     %_0 : [#users=1] = call_module[target=0](args = (%input_1,), kwargs = {})
    #     return _0
-    graph = tracer.trace(model, meta_args={"input": torch.rand(4, 4, 64, 64).to('meta')})
+    meta_args = {"input": torch.rand(4, 4, 64, 64).to('meta')}
+    graph = tracer.trace(model, meta_args=meta_args)
    gm = ColoGraphModule(model, graph)
+    shape_prop_pass(gm, *meta_args.values())
physical_mesh_id = torch.arange(0, 4)
mesh_shape = (2, 2)


@@ -1,10 +1,13 @@
import pytest
import torch
import torch.nn as nn
+from colossalai._analyzer.fx.graph_module import ColoGraphModule
+from colossalai._analyzer.fx.passes.shape_prop import shape_prop_pass
+from colossalai._analyzer.fx.tracer.tracer import ColoTracer
from colossalai.auto_parallel.tensor_shard.node_handler.output_handler import OutputHandler
from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType, StrategiesVector
from colossalai.device.device_mesh import DeviceMesh
-from colossalai.fx import ColoGraphModule, ColoTracer
from colossalai.testing import assert_close, parameterize, rerun_if_address_is_in_use
@@ -18,19 +21,20 @@ class OutputModel(nn.Module):
        return x, y

+@pytest.mark.skip('ShapeProp is not compatible with PyTorch 1.11.0')
@parameterize('output_option', ['distributed', 'replicated'])
@rerun_if_address_is_in_use()
def test_output_handler(output_option):
    model = OutputModel()
-    tracer = ColoTracer()
+    tracer = ColoTracer(bias_addition_split=True)
    # graph():
    #     %x : torch.Tensor [#users=2] = placeholder[target=x]
    #     %mul : [#users=1] = call_function[target=operator.mul](args = (%x, 2), kwargs = {})
    #     return (x, mul)
-    graph = tracer.trace(model, meta_args={
-        "x": torch.rand(4, 4, 64, 64).to('meta'),
-    })
+    meta_args = {'x': torch.rand(4, 4, 64, 64).to('meta')}
+    graph = tracer.trace(model, meta_args=meta_args)
    gm = ColoGraphModule(model, graph)
+    shape_prop_pass(gm, *meta_args.values())
physical_mesh_id = torch.arange(0, 4)
mesh_shape = (2, 2)


@@ -5,12 +5,14 @@ import torch
import torch.multiprocessing as mp
import torch.nn as nn
+from colossalai._analyzer.fx.graph_module import ColoGraphModule
+from colossalai._analyzer.fx.passes.shape_prop import shape_prop_pass
+from colossalai._analyzer.fx.tracer.tracer import ColoTracer
from colossalai.auto_parallel.tensor_shard.node_handler import PermuteHandler, TransposeHandler
from colossalai.auto_parallel.tensor_shard.node_handler.conv_handler import ConvFunctionHandler
from colossalai.auto_parallel.tensor_shard.node_handler.linear_handler import LinearFunctionHandler
from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType, StrategiesVector
from colossalai.device.device_mesh import DeviceMesh
-from colossalai.fx import ColoGraphModule, ColoTracer
from colossalai.initialize import launch
from colossalai.logging import disable_existing_loggers
from colossalai.testing import assert_close, parameterize, rerun_if_address_is_in_use
@@ -88,7 +90,7 @@ def check_view_handler(rank, call_function, reshape_dims, model_cls, world_size, port):
                                     input_args=[input, other],
                                     meta_arg_names=['input', 'other'],
                                     node_type='following')
-    tracer = ColoTracer()
+    tracer = ColoTracer(bias_addition_split=True)
if model_cls.__name__ == 'ConvReshapeModel':
# graph():
# %input_1 : torch.Tensor [#users=1] = placeholder[target=input]
@@ -96,11 +98,11 @@ def check_view_handler(rank, call_function, reshape_dims, model_cls, world_size, port):
        #     %conv2d : [#users=1] = call_function[target=torch.conv2d](args = (%input_1, %other), kwargs = {bias: None})
        #     %permute : [#users=1] = call_function[target=torch.permute](args = (%conv2d, (0, 2, 1, 3)), kwargs = {})
        #     return permute
-        graph = tracer.trace(model,
-                             meta_args={
-                                 "input": torch.rand(8, 8, 66, 66).to('meta'),
-                                 "other": torch.rand(16, 8, 3, 3).to('meta'),
-                             })
+        meta_args = {
+            'input': torch.rand(8, 8, 66, 66).to('meta'),
+            'other': torch.rand(16, 8, 3, 3).to('meta'),
+        }
+        graph = tracer.trace(model, meta_args=meta_args)
if model_cls.__name__ == 'LinearReshapeModel':
# graph():
@@ -109,13 +111,14 @@ def check_view_handler(rank, call_function, reshape_dims, model_cls, world_size, port):
        #     %linear : [#users=1] = call_function[target=torch._C._nn.linear](args = (%input_1, %other), kwargs = {bias: None})
        #     %permute : [#users=1] = call_method[target=view](args = (%linear, 32, 4, 32, 32, 4), kwargs = {})
        #     return permute
-        graph = tracer.trace(model,
-                             meta_args={
-                                 "input": torch.rand(8, 16, 64, 32).to('meta'),
-                                 "other": torch.rand(64, 32).to('meta'),
-                             })
+        meta_args = {
+            'input': torch.rand(8, 16, 64, 32).to('meta'),
+            'other': torch.rand(64, 32).to('meta'),
+        }
+        graph = tracer.trace(model, meta_args=meta_args)
    gm = ColoGraphModule(model, graph)
+    shape_prop_pass(gm, *meta_args.values())
previous_mod_node = list(graph.nodes)[2]
reshape_node = list(graph.nodes)[3]


@@ -1,10 +1,13 @@
import pytest
import torch
import torch.nn as nn
+from colossalai._analyzer.fx.graph_module import ColoGraphModule
+from colossalai._analyzer.fx.passes.shape_prop import shape_prop_pass
+from colossalai._analyzer.fx.tracer.tracer import ColoTracer
from colossalai.auto_parallel.tensor_shard.node_handler.placeholder_handler import PlaceholderHandler
from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType, StrategiesVector
from colossalai.device.device_mesh import DeviceMesh
-from colossalai.fx import ColoGraphModule, ColoTracer
from colossalai.testing import assert_close, parameterize, rerun_if_address_is_in_use
@@ -17,18 +20,21 @@ class PlaceholderModel(nn.Module):
        return input

+@pytest.mark.skip('ShapeProp is not compatible with PyTorch 1.11.0')
@parameterize('placeholder_option', ['distributed', 'replicated'])
@rerun_if_address_is_in_use()
def test_placeholder_handler(placeholder_option):
    model = PlaceholderModel()
-    tracer = ColoTracer()
+    tracer = ColoTracer(bias_addition_split=True)
    # graph():
    #     %input_1 : torch.Tensor [#users=1] = placeholder[target=input]
    #     return input_1
-    graph = tracer.trace(model, meta_args={
-        "input": torch.rand(4, 4, 64, 64).to('meta'),
-    })
+    meta_args = {
+        "input": torch.rand(4, 4, 64, 64).to('meta'),
+    }
+    graph = tracer.trace(model, meta_args=meta_args)
    gm = ColoGraphModule(model, graph)
+    shape_prop_pass(gm, *meta_args.values())
physical_mesh_id = torch.arange(0, 4)
mesh_shape = (2, 2)


@@ -1,17 +1,15 @@
from functools import partial
import torch
import torch.multiprocessing as mp
import torch.nn as nn
+from colossalai._analyzer.fx.graph_module import ColoGraphModule
+from colossalai._analyzer.fx.passes.shape_prop import shape_prop_pass
+from colossalai._analyzer.fx.tracer.tracer import ColoTracer
from colossalai.auto_parallel.tensor_shard.node_handler import LinearFunctionHandler
from colossalai.auto_parallel.tensor_shard.options import ShardOption
from colossalai.auto_parallel.tensor_shard.sharding_strategy import StrategiesVector
from colossalai.device.device_mesh import DeviceMesh
-from colossalai.fx import ColoGraphModule, ColoTracer
from colossalai.testing import parameterize
from colossalai.testing.pytest_wrapper import run_on_environment_flag
-from colossalai.testing.utils import parameterize
class LinearModel(nn.Module):
@@ -30,13 +28,11 @@ def check_shard_option(shard_option):
    mesh_shape = (2, 2)
    device_mesh = DeviceMesh(physical_mesh_id, mesh_shape)

-    tracer = ColoTracer()
-    graph = tracer.trace(model,
-                         meta_args={
-                             "input": torch.rand(4, 4, 4, 16).to('meta'),
-                             'others': torch.rand(32, 16).to('meta')
-                         })
+    tracer = ColoTracer(bias_addition_split=True)
+    meta_args = {'input': torch.rand(4, 4, 4, 16).to('meta'), 'others': torch.rand(32, 16).to('meta')}
+    graph = tracer.trace(model, meta_args=meta_args)
    gm = ColoGraphModule(model, graph)
+    shape_prop_pass(gm, *meta_args.values())
linear_func_node = list(graph.nodes)[2]
strategies_vector = StrategiesVector(linear_func_node)


@@ -6,11 +6,13 @@ import torch.multiprocessing as mp
import torch.nn as nn
import torch.nn.functional as F
+from colossalai._analyzer.fx.graph_module import ColoGraphModule
+from colossalai._analyzer.fx.passes.shape_prop import shape_prop_pass
+from colossalai._analyzer.fx.tracer.tracer import ColoTracer
from colossalai.auto_parallel.tensor_shard.node_handler.linear_handler import LinearFunctionHandler
from colossalai.auto_parallel.tensor_shard.node_handler.softmax_handler import SoftmaxHandler
from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType, StrategiesVector
from colossalai.device.device_mesh import DeviceMesh
-from colossalai.fx import ColoGraphModule, ColoTracer
from colossalai.initialize import launch
from colossalai.logging import disable_existing_loggers
from colossalai.testing import assert_close, parameterize, rerun_if_address_is_in_use
@@ -54,7 +56,7 @@ def check_split_handler(rank, softmax_dim, model_cls, world_size, port):
                                     input_args=[input, other],
                                     meta_arg_names=['input', 'other'],
                                     node_type='following')
-    tracer = ColoTracer()
+    tracer = ColoTracer(bias_addition_split=True)
# graph():
# %input_1 : torch.Tensor [#users=1] = placeholder[target=input]
@@ -62,13 +64,14 @@ def check_split_handler(rank, softmax_dim, model_cls, world_size, port):
    #     %linear : [#users=1] = call_function[target=torch._C._nn.linear](args = (%input_1, %other), kwargs = {bias: None})
    #     %softmax : [#users=1] = call_method[target=split](args = (%linear,), kwargs = {})
    #     return split
-    graph = tracer.trace(model,
-                         meta_args={
-                             "input": torch.rand(8, 16, 64, 32).to('meta'),
-                             "other": torch.rand(64, 32).to('meta'),
-                         })
+    meta_args = {
+        'input': torch.rand(8, 16, 64, 32).to('meta'),
+        'other': torch.rand(64, 32).to('meta'),
+    }
+    graph = tracer.trace(model, meta_args=meta_args)
    gm = ColoGraphModule(model, graph)
+    shape_prop_pass(gm, *meta_args.values())
previous_mod_node = list(graph.nodes)[2]
split_node = list(graph.nodes)[3]


@@ -5,12 +5,14 @@ import torch
import torch.multiprocessing as mp
import torch.nn as nn
+from colossalai._analyzer.fx.graph_module import ColoGraphModule
+from colossalai._analyzer.fx.passes.shape_prop import shape_prop_pass
+from colossalai._analyzer.fx.tracer.tracer import ColoTracer
from colossalai.auto_parallel.tensor_shard.node_handler import SplitHandler
from colossalai.auto_parallel.tensor_shard.node_handler.conv_handler import ConvFunctionHandler
from colossalai.auto_parallel.tensor_shard.node_handler.linear_handler import LinearFunctionHandler
from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType, StrategiesVector
from colossalai.device.device_mesh import DeviceMesh
-from colossalai.fx import ColoGraphModule, ColoTracer
from colossalai.initialize import launch
from colossalai.logging import disable_existing_loggers
from colossalai.testing import assert_close, parameterize, rerun_if_address_is_in_use
@@ -76,7 +78,7 @@ def check_split_handler(rank, split_size, split_dim, model_cls, world_size, port):
                                     input_args=[input, other],
                                     meta_arg_names=['input', 'other'],
                                     node_type='following')
-    tracer = ColoTracer()
+    tracer = ColoTracer(bias_addition_split=True)
if model_cls.__name__ == 'ConvSplitModel':
# graph():
# %input_1 : torch.Tensor [#users=1] = placeholder[target=input]
@@ -84,11 +86,11 @@ def check_split_handler(rank, split_size, split_dim, model_cls, world_size, port):
        #     %conv2d : [#users=1] = call_function[target=torch.conv2d](args = (%input_1, %other), kwargs = {})
        #     %split : [#users=1] = call_method[target=split](args = (%conv2d,), kwargs = {})
        #     return split
-        graph = tracer.trace(model,
-                             meta_args={
-                                 "input": torch.rand(8, 8, 66, 66).to('meta'),
-                                 "other": torch.rand(16, 8, 3, 3).to('meta'),
-                             })
+        meta_args = {
+            'input': torch.rand(8, 8, 66, 66).to('meta'),
+            'other': torch.rand(16, 8, 3, 3).to('meta'),
+        }
+        graph = tracer.trace(model, meta_args=meta_args)
if model_cls.__name__ == 'LinearSplitModel':
# graph():
@@ -97,13 +99,14 @@ def check_split_handler(rank, split_size, split_dim, model_cls, world_size, port):
        #     %linear : [#users=1] = call_function[target=torch._C._nn.linear](args = (%input_1, %other), kwargs = {bias: None})
        #     %split : [#users=1] = call_method[target=split](args = (%linear,), kwargs = {})
        #     return split
-        graph = tracer.trace(model,
-                             meta_args={
-                                 "input": torch.rand(8, 16, 64, 32).to('meta'),
-                                 "other": torch.rand(64, 32).to('meta'),
-                             })
+        meta_args = {
+            'input': torch.rand(8, 16, 64, 32).to('meta'),
+            'other': torch.rand(64, 32).to('meta'),
+        }
+        graph = tracer.trace(model, meta_args=meta_args)
    gm = ColoGraphModule(model, graph)
+    shape_prop_pass(gm, *meta_args.values())
previous_mod_node = list(graph.nodes)[2]
split_node = list(graph.nodes)[3]


@@ -5,12 +5,13 @@ import torch
import torch.multiprocessing as mp
import torch.nn as nn
-from colossalai.auto_parallel.tensor_shard.node_handler.conv_handler import ConvFunctionHandler
+from colossalai._analyzer.fx.graph_module import ColoGraphModule
+from colossalai._analyzer.fx.passes.shape_prop import shape_prop_pass
+from colossalai._analyzer.fx.tracer.tracer import ColoTracer
from colossalai.auto_parallel.tensor_shard.node_handler.linear_handler import LinearFunctionHandler
from colossalai.auto_parallel.tensor_shard.node_handler.sum_handler import SumHandler
from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType, StrategiesVector
from colossalai.device.device_mesh import DeviceMesh
-from colossalai.fx import ColoGraphModule, ColoTracer
from colossalai.initialize import launch
from colossalai.logging import disable_existing_loggers
from colossalai.testing import assert_close, parameterize, rerun_if_address_is_in_use
@@ -58,7 +59,7 @@ def check_sum_handler(rank, sum_dims, keepdim, world_size, port):
                                     meta_arg_names=['input', 'other'],
                                     node_type='following')
-    tracer = ColoTracer()
+    tracer = ColoTracer(bias_addition_split=True)
# graph():
# %input_1 : torch.Tensor [#users=1] = placeholder[target=input]
@@ -66,12 +67,13 @@ def check_sum_handler(rank, sum_dims, keepdim, world_size, port):
    #     %linear : [#users=1] = call_function[target=torch._C._nn.linear](args = (%input_1, %other), kwargs = {bias: None})
    #     %sum_1 : [#users=1] = call_function[target=torch.sum](args = (%linear,), kwargs = {})
    #     return sum_1
-    graph = tracer.trace(model,
-                         meta_args={
-                             "input": torch.rand(8, 16, 64, 32).to('meta'),
-                             "other": torch.rand(64, 32).to('meta'),
-                         })
+    meta_args = {
+        "input": torch.rand(8, 16, 64, 32).to('meta'),
+        "other": torch.rand(64, 32).to('meta'),
+    }
+    graph = tracer.trace(model, meta_args=meta_args)
    gm = ColoGraphModule(model, graph)
+    shape_prop_pass(gm, *meta_args.values())
previous_mod_node = list(graph.nodes)[2]
sum_node = list(graph.nodes)[3]
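
The positional lookups above (`list(graph.nodes)[2]` and `[3]`) rely on the traced node order shown in the graph dump: two placeholders, then `linear`, then `sum_1`. An equivalent name-based lookup, shown only as an illustration:

    previous_mod_node = next(n for n in graph.nodes if n.name == 'linear')
    sum_node = next(n for n in graph.nodes if n.name == 'sum_1')
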
@ -116,107 +118,107 @@ def check_sum_handler(rank, sum_dims, keepdim, world_size, port):
# check strategy name
if sum_dims == (0, 2) and keepdim == False:
assert '[R, R, R, S1] -> [R, S1]_0' in strategy_name_list
assert '[R, S0, R, S1] -> [S0, S1]_1' in strategy_name_list
assert '[R, R, R, S1] -> [R, S1]_2' in strategy_name_list
assert '[R, R, R, S0] -> [R, S0]_3' in strategy_name_list
assert '[R, S1, R, S0] -> [S1, S0]_4' in strategy_name_list
assert '[R, R, R, S0] -> [R, S0]_5' in strategy_name_list
assert '[R, R, R, R] -> [R, R]_6' in strategy_name_list
assert '[R, S0, R, R] -> [S0, R]_7' in strategy_name_list
assert '[R, R, R, R] -> [R, R]_0' in strategy_name_list
assert '[R, S01, R, R] -> [S01, R]_1' in strategy_name_list
assert '[R, R, R, R] -> [R, R]_2' in strategy_name_list
assert '[R, R, R, R] -> [R, R]_3' in strategy_name_list
assert '[R, R, R, S01] -> [R, S01]_4' in strategy_name_list
assert '[R, R, R, S1] -> [R, S1]_5' in strategy_name_list
assert '[R, R, R, S0] -> [R, S0]_6' in strategy_name_list
assert '[R, R, R, R] -> [R, R]_7' in strategy_name_list
assert '[R, R, R, R] -> [R, R]_8' in strategy_name_list
assert '[R, R, R, R] -> [R, R]_9' in strategy_name_list
assert '[R, S1, R, R] -> [S1, R]_10' in strategy_name_list
assert '[R, R, R, R] -> [R, R]_11' in strategy_name_list
assert '[R, R, R, S1] -> [R, S1]_12' in strategy_name_list
assert '[R, R, R, S0] -> [R, S0]_13' in strategy_name_list
assert '[R, R, R, R] -> [R, R]_14' in strategy_name_list
assert '[R, R, R, R] -> [R, R]_15' in strategy_name_list
assert '[R, R, R, S0] -> [R, S0]_9' in strategy_name_list
assert '[R, R, R, S1] -> [R, S1]_10' in strategy_name_list
assert '[R, R, R, S1] -> [R, S1]_11' in strategy_name_list
assert '[R, S0, R, S1] -> [S0, S1]_12' in strategy_name_list
assert '[R, R, R, S1] -> [R, S1]_13' in strategy_name_list
assert '[R, R, R, S0] -> [R, S0]_14' in strategy_name_list
assert '[R, S1, R, S0] -> [S1, S0]_15' in strategy_name_list
assert '[R, R, R, S0] -> [R, S0]_16' in strategy_name_list
assert '[R, R, R, S1] -> [R, S1]_17' in strategy_name_list
assert '[R, R, R, R] -> [R, R]_18' in strategy_name_list
assert '[R, S01, R, R] -> [S01, R]_19' in strategy_name_list
assert '[R, R, R, R] -> [R, R]_17' in strategy_name_list
assert '[R, S0, R, R] -> [S0, R]_18' in strategy_name_list
assert '[R, R, R, R] -> [R, R]_19' in strategy_name_list
assert '[R, R, R, R] -> [R, R]_20' in strategy_name_list
assert '[R, R, R, R] -> [R, R]_21' in strategy_name_list
assert '[R, R, R, S01] -> [R, S01]_22' in strategy_name_list
assert '[R, S1, R, R] -> [S1, R]_21' in strategy_name_list
assert '[R, R, R, R] -> [R, R]_22' in strategy_name_list
assert '[R, R, R, R] -> [R, R]_23' in strategy_name_list
if sum_dims == (0, 2) and keepdim == True:
assert '[R, R, R, S1] -> [R, R, R, S1]_0' in strategy_name_list
assert '[R, S0, R, S1] -> [R, S0, R, S1]_1' in strategy_name_list
assert '[R, R, R, S1] -> [R, R, R, S1]_2' in strategy_name_list
assert '[R, R, R, S0] -> [R, R, R, S0]_3' in strategy_name_list
assert '[R, S1, R, S0] -> [R, S1, R, S0]_4' in strategy_name_list
assert '[R, R, R, S0] -> [R, R, R, S0]_5' in strategy_name_list
assert '[R, R, R, R] -> [R, R, R, R]_6' in strategy_name_list
assert '[R, S0, R, R] -> [R, S0, R, R]_7' in strategy_name_list
assert '[R, R, R, R] -> [R, R, R, R]_0' in strategy_name_list
assert '[R, S01, R, R] -> [R, S01, R, R]_1' in strategy_name_list
assert '[R, R, R, R] -> [R, R, R, R]_2' in strategy_name_list
assert '[R, R, R, R] -> [R, R, R, R]_3' in strategy_name_list
assert '[R, R, R, S01] -> [R, R, R, S01]_4' in strategy_name_list
assert '[R, R, R, S1] -> [R, R, R, S1]_5' in strategy_name_list
assert '[R, R, R, S0] -> [R, R, R, S0]_6' in strategy_name_list
assert '[R, R, R, R] -> [R, R, R, R]_7' in strategy_name_list
assert '[R, R, R, R] -> [R, R, R, R]_8' in strategy_name_list
assert '[R, R, R, R] -> [R, R, R, R]_9' in strategy_name_list
assert '[R, S1, R, R] -> [R, S1, R, R]_10' in strategy_name_list
assert '[R, R, R, R] -> [R, R, R, R]_11' in strategy_name_list
assert '[R, R, R, S1] -> [R, R, R, S1]_12' in strategy_name_list
assert '[R, R, R, S0] -> [R, R, R, S0]_13' in strategy_name_list
assert '[R, R, R, R] -> [R, R, R, R]_14' in strategy_name_list
assert '[R, R, R, R] -> [R, R, R, R]_15' in strategy_name_list
assert '[R, R, R, S0] -> [R, R, R, S0]_9' in strategy_name_list
assert '[R, R, R, S1] -> [R, R, R, S1]_10' in strategy_name_list
assert '[R, R, R, S1] -> [R, R, R, S1]_11' in strategy_name_list
assert '[R, S0, R, S1] -> [R, S0, R, S1]_12' in strategy_name_list
assert '[R, R, R, S1] -> [R, R, R, S1]_13' in strategy_name_list
assert '[R, R, R, S0] -> [R, R, R, S0]_14' in strategy_name_list
assert '[R, S1, R, S0] -> [R, S1, R, S0]_15' in strategy_name_list
assert '[R, R, R, S0] -> [R, R, R, S0]_16' in strategy_name_list
assert '[R, R, R, S1] -> [R, R, R, S1]_17' in strategy_name_list
assert '[R, R, R, R] -> [R, R, R, R]_18' in strategy_name_list
assert '[R, S01, R, R] -> [R, S01, R, R]_19' in strategy_name_list
assert '[R, R, R, R] -> [R, R, R, R]_17' in strategy_name_list
assert '[R, S0, R, R] -> [R, S0, R, R]_18' in strategy_name_list
assert '[R, R, R, R] -> [R, R, R, R]_19' in strategy_name_list
assert '[R, R, R, R] -> [R, R, R, R]_20' in strategy_name_list
assert '[R, R, R, R] -> [R, R, R, R]_21' in strategy_name_list
assert '[R, R, R, S01] -> [R, R, R, S01]_22' in strategy_name_list
assert '[R, S1, R, R] -> [R, S1, R, R]_21' in strategy_name_list
assert '[R, R, R, R] -> [R, R, R, R]_22' in strategy_name_list
assert '[R, R, R, R] -> [R, R, R, R]_23' in strategy_name_list
if sum_dims == 1 and keepdim == False:
assert '[S0, R, R, S1] -> [S0, R, S1]_0' in strategy_name_list
assert '[R, R, R, S1] -> [R, R, S1]_1' in strategy_name_list
assert '[R, R, S0, S1] -> [R, S0, S1]_2' in strategy_name_list
assert '[S1, R, R, S0] -> [S1, R, S0]_3' in strategy_name_list
assert '[R, R, R, S0] -> [R, R, S0]_4' in strategy_name_list
assert '[R, R, S1, S0] -> [R, S1, S0]_5' in strategy_name_list
assert '[S0, R, R, R] -> [S0, R, R]_6' in strategy_name_list
assert '[S01, R, R, R] -> [S01, R, R]_0' in strategy_name_list
assert '[R, R, R, R] -> [R, R, R]_1' in strategy_name_list
assert '[R, R, S01, R] -> [R, S01, R]_2' in strategy_name_list
assert '[R, R, R, R] -> [R, R, R]_3' in strategy_name_list
assert '[R, R, R, S01] -> [R, R, S01]_4' in strategy_name_list
assert '[R, R, R, S1] -> [R, R, S1]_5' in strategy_name_list
assert '[R, R, R, S0] -> [R, R, S0]_6' in strategy_name_list
assert '[R, R, R, R] -> [R, R, R]_7' in strategy_name_list
assert '[R, R, S0, R] -> [R, S0, R]_8' in strategy_name_list
assert '[S1, R, R, R] -> [S1, R, R]_9' in strategy_name_list
assert '[R, R, R, R] -> [R, R, R]_10' in strategy_name_list
assert '[R, R, S1, R] -> [R, S1, R]_11' in strategy_name_list
assert '[R, R, R, R] -> [R, R, R]_8' in strategy_name_list
assert '[R, R, R, S0] -> [R, R, S0]_9' in strategy_name_list
assert '[R, R, R, S1] -> [R, R, S1]_10' in strategy_name_list
assert '[S0, R, R, S1] -> [S0, R, S1]_11' in strategy_name_list
assert '[R, R, R, S1] -> [R, R, S1]_12' in strategy_name_list
assert '[R, R, R, S0] -> [R, R, S0]_13' in strategy_name_list
assert '[R, R, R, R] -> [R, R, R]_14' in strategy_name_list
assert '[R, R, R, R] -> [R, R, R]_15' in strategy_name_list
assert '[R, R, R, S0] -> [R, R, S0]_16' in strategy_name_list
assert '[R, R, R, S1] -> [R, R, S1]_17' in strategy_name_list
assert '[S01, R, R, R] -> [S01, R, R]_18' in strategy_name_list
assert '[R, R, R, R] -> [R, R, R]_19' in strategy_name_list
assert '[R, R, S01, R] -> [R, S01, R]_20' in strategy_name_list
assert '[R, R, S0, S1] -> [R, S0, S1]_13' in strategy_name_list
assert '[S1, R, R, S0] -> [S1, R, S0]_14' in strategy_name_list
assert '[R, R, R, S0] -> [R, R, S0]_15' in strategy_name_list
assert '[R, R, S1, S0] -> [R, S1, S0]_16' in strategy_name_list
assert '[S0, R, R, R] -> [S0, R, R]_17' in strategy_name_list
assert '[R, R, R, R] -> [R, R, R]_18' in strategy_name_list
assert '[R, R, S0, R] -> [R, S0, R]_19' in strategy_name_list
assert '[S1, R, R, R] -> [S1, R, R]_20' in strategy_name_list
assert '[R, R, R, R] -> [R, R, R]_21' in strategy_name_list
assert '[R, R, R, S01] -> [R, R, S01]_22' in strategy_name_list
assert '[R, R, S1, R] -> [R, S1, R]_22' in strategy_name_list
assert '[R, R, R, R] -> [R, R, R]_23' in strategy_name_list
if sum_dims == 1 and keepdim == True:
assert '[S0, R, R, S1] -> [S0, R, R, S1]_0' in strategy_name_list
assert '[R, R, R, S1] -> [R, R, R, S1]_1' in strategy_name_list
assert '[R, R, S0, S1] -> [R, R, S0, S1]_2' in strategy_name_list
assert '[S1, R, R, S0] -> [S1, R, R, S0]_3' in strategy_name_list
assert '[R, R, R, S0] -> [R, R, R, S0]_4' in strategy_name_list
assert '[R, R, S1, S0] -> [R, R, S1, S0]_5' in strategy_name_list
assert '[S0, R, R, R] -> [S0, R, R, R]_6' in strategy_name_list
assert '[S01, R, R, R] -> [S01, R, R, R]_0' in strategy_name_list
assert '[R, R, R, R] -> [R, R, R, R]_1' in strategy_name_list
assert '[R, R, S01, R] -> [R, R, S01, R]_2' in strategy_name_list
assert '[R, R, R, R] -> [R, R, R, R]_3' in strategy_name_list
assert '[R, R, R, S01] -> [R, R, R, S01]_4' in strategy_name_list
assert '[R, R, R, S1] -> [R, R, R, S1]_5' in strategy_name_list
assert '[R, R, R, S0] -> [R, R, R, S0]_6' in strategy_name_list
assert '[R, R, R, R] -> [R, R, R, R]_7' in strategy_name_list
assert '[R, R, S0, R] -> [R, R, S0, R]_8' in strategy_name_list
assert '[S1, R, R, R] -> [S1, R, R, R]_9' in strategy_name_list
assert '[R, R, R, R] -> [R, R, R, R]_10' in strategy_name_list
assert '[R, R, S1, R] -> [R, R, S1, R]_11' in strategy_name_list
assert '[R, R, R, R] -> [R, R, R, R]_8' in strategy_name_list
assert '[R, R, R, S0] -> [R, R, R, S0]_9' in strategy_name_list
assert '[R, R, R, S1] -> [R, R, R, S1]_10' in strategy_name_list
assert '[S0, R, R, S1] -> [S0, R, R, S1]_11' in strategy_name_list
assert '[R, R, R, S1] -> [R, R, R, S1]_12' in strategy_name_list
assert '[R, R, R, S0] -> [R, R, R, S0]_13' in strategy_name_list
assert '[R, R, R, R] -> [R, R, R, R]_14' in strategy_name_list
assert '[R, R, R, R] -> [R, R, R, R]_15' in strategy_name_list
assert '[R, R, R, S0] -> [R, R, R, S0]_16' in strategy_name_list
assert '[R, R, R, S1] -> [R, R, R, S1]_17' in strategy_name_list
assert '[S01, R, R, R] -> [S01, R, R, R]_18' in strategy_name_list
assert '[R, R, R, R] -> [R, R, R, R]_19' in strategy_name_list
assert '[R, R, S01, R] -> [R, R, S01, R]_20' in strategy_name_list
assert '[R, R, S0, S1] -> [R, R, S0, S1]_13' in strategy_name_list
assert '[S1, R, R, S0] -> [S1, R, R, S0]_14' in strategy_name_list
assert '[R, R, R, S0] -> [R, R, R, S0]_15' in strategy_name_list
assert '[R, R, S1, S0] -> [R, R, S1, S0]_16' in strategy_name_list
assert '[S0, R, R, R] -> [S0, R, R, R]_17' in strategy_name_list
assert '[R, R, R, R] -> [R, R, R, R]_18' in strategy_name_list
assert '[R, R, S0, R] -> [R, R, S0, R]_19' in strategy_name_list
assert '[S1, R, R, R] -> [S1, R, R, R]_20' in strategy_name_list
assert '[R, R, R, R] -> [R, R, R, R]_21' in strategy_name_list
assert '[R, R, R, S01] -> [R, R, R, S01]_22' in strategy_name_list
assert '[R, R, S1, R] -> [R, R, S1, R]_22' in strategy_name_list
assert '[R, R, R, R] -> [R, R, R, R]_23' in strategy_name_list
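
The strategy names encode an input-to-output sharding spec per tensor dimension: `R` is replicated, `S0`/`S1` shard along device-mesh axis 0/1, `S01` shards along both axes flattened, and the `_N` suffix is just the strategy index. A small helper (hypothetical, not part of the library) that turns such a spec into a local shard shape on the 2x2 mesh these tests use:

    MESH_SIZE = {'0': 2, '1': 2}    # mesh axis name -> number of devices

    def local_shape(global_shape, spec):
        """Local shard shape under a spec like ['R', 'S01', 'R', 'R']."""
        out = []
        for size, s in zip(global_shape, spec):
            div = 1
            for axis in s[1:]:          # 'R' names no axes; 'S01' names two
                div *= MESH_SIZE[axis]
            out.append(size // div)
        return out

    # '[R, S01, R, R] -> [S01, R]' for sum over dims (0, 2), keepdim=False:
    print(local_shape([8, 16, 64, 32], ['R', 'S01', 'R', 'R']))    # [8, 4, 64, 32]
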

View File

@ -1,10 +1,12 @@
import torch
import torch.nn as nn
from colossalai._analyzer.fx.graph_module import ColoGraphModule
from colossalai._analyzer.fx.passes.shape_prop import shape_prop_pass
from colossalai._analyzer.fx.tracer.tracer import ColoTracer
from colossalai.auto_parallel.tensor_shard.node_handler.tensor_constructor_handler import TensorConstructorHandler
from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType, StrategiesVector
from colossalai.device.device_mesh import DeviceMesh
from colossalai.fx import ColoGraphModule, ColoTracer
from colossalai.testing.pytest_wrapper import run_on_environment_flag
@ -22,7 +24,7 @@ class TensorConstructorModel(nn.Module):
@run_on_environment_flag(name='AUTO_PARALLEL')
def test_where_handler():
model = TensorConstructorModel()
tracer = ColoTracer()
tracer = ColoTracer(bias_addition_split=True)
# graph():
# %x : torch.Tensor [#users=2] = placeholder[target=x]
# %size : [#users=1] = call_method[target=size](args = (%x,), kwargs = {})
@ -30,10 +32,10 @@ def test_where_handler():
# %arange : [#users=1] = call_function[target=torch.arange](args = (%getitem,), kwargs = {})
# %add : [#users=1] = call_function[target=operator.add](args = (%x, %arange), kwargs = {})
# return add
graph = tracer.trace(model, meta_args={
"x": torch.rand(10).to('meta'),
})
meta_args = {'x': torch.rand(10).to('meta')}
graph = tracer.trace(model, meta_args=meta_args)
gm = ColoGraphModule(model, graph)
shape_prop_pass(gm, *meta_args.values())
physical_mesh_id = torch.arange(0, 4)
mesh_shape = (2, 2)
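
These two lines feed the `DeviceMesh` used throughout the handler tests: four physical devices arranged as a 2x2 logical mesh, which yields the two shardable mesh axes (`S0`, `S1`) seen in the strategy names. A sketch of the construction, assuming the `DeviceMesh` signature these tests import:

    import torch
    from colossalai.device.device_mesh import DeviceMesh

    physical_mesh_id = torch.arange(0, 4)    # devices 0..3
    mesh_shape = (2, 2)                      # 2x2 logical mesh -> axes 0 and 1
    device_mesh = DeviceMesh(physical_mesh_id, mesh_shape)
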

View File

@ -1,12 +1,13 @@
import torch
import torch.nn as nn
from colossalai._analyzer.fx.graph_module import ColoGraphModule
from colossalai._analyzer.fx.passes.shape_prop import shape_prop_pass
from colossalai._analyzer.fx.tracer.tracer import ColoTracer
from colossalai.auto_parallel.tensor_shard.node_handler.conv_handler import ConvFunctionHandler
from colossalai.auto_parallel.tensor_shard.node_handler.unary_elementwise_handler import UnaryElementwiseHandler
from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType, StrategiesVector
from colossalai.device.device_mesh import DeviceMesh
from colossalai.fx import ColoGraphModule, ColoTracer
from colossalai.fx.tracer.meta_patch.patched_module import linear
from colossalai.testing.pytest_wrapper import run_on_environment_flag
@ -25,19 +26,20 @@ class ReLuModel(nn.Module):
@run_on_environment_flag(name='AUTO_PARALLEL')
def test_elementwise_handler():
model = ReLuModel()
tracer = ColoTracer()
tracer = ColoTracer(bias_addition_split=True)
# graph():
# %input_1 : torch.Tensor [#users=1] = placeholder[target=input]
# %other : torch.Tensor [#users=1] = placeholder[target=other]
# %conv2d : [#users=1] = call_function[target=torch.conv2d](args = (%input_1, %other), kwargs = {})
# %act : [#users=1] = call_module[target=act](args = (%conv2d,), kwargs = {})
# return act
graph = tracer.trace(model,
meta_args={
"input": torch.rand(4, 4, 64, 64).to('meta'),
"other": torch.rand(4, 16, 3, 3).to('meta'),
})
meta_args = {
'input': torch.rand(4, 4, 64, 64).to('meta'),
'other': torch.rand(16, 4, 3, 3).to('meta'),
}
graph = tracer.trace(model, meta_args=meta_args)
gm = ColoGraphModule(model, graph)
shape_prop_pass(gm, *meta_args.values())
physical_mesh_id = torch.arange(0, 4)
mesh_shape = (2, 2)
@ -69,13 +71,13 @@ def test_elementwise_handler():
assert mapping['input'].name == "conv2d"
assert mapping['input'].data.is_meta
assert mapping['input'].data.shape == torch.Size([4, 4, 62, 62])
assert mapping['input'].data.shape == torch.Size([4, 16, 62, 62])
assert mapping['input'].type == OperationDataType.ARG
assert mapping['input'].logical_shape == torch.Size([4, 4, 62, 62])
assert mapping['input'].logical_shape == torch.Size([4, 16, 62, 62])
assert mapping['output'].name == "act"
assert mapping['output'].data.is_meta
assert mapping['output'].data.shape == torch.Size([4, 4, 62, 62])
assert mapping['output'].data.shape == torch.Size([4, 16, 62, 62])
assert mapping['output'].type == OperationDataType.OUTPUT
# act is a following-strategy handler, so its number of strategies equals that of the predecessor node.
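
The corrected shapes follow directly from the convolution arithmetic: the weight `(16, 4, 3, 3)` yields 16 output channels, and with stride 1 and no padding each spatial side becomes 64 - 3 + 1 = 62. A quick meta-tensor check (shape inference only; assumes the installed torch supports conv on the meta device):

    import torch
    import torch.nn.functional as F

    x = torch.rand(4, 4, 64, 64, device='meta')
    w = torch.rand(16, 4, 3, 3, device='meta')    # (out_ch, in_ch, kH, kW)
    print(F.conv2d(x, w).shape)    # torch.Size([4, 16, 62, 62])
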

View File

@ -5,12 +5,14 @@ import torch
import torch.multiprocessing as mp
import torch.nn as nn
from colossalai._analyzer.fx.graph_module import ColoGraphModule
from colossalai._analyzer.fx.passes.shape_prop import shape_prop_pass
from colossalai._analyzer.fx.tracer.tracer import ColoTracer
from colossalai.auto_parallel.tensor_shard.node_handler import ViewHandler
from colossalai.auto_parallel.tensor_shard.node_handler.conv_handler import ConvFunctionHandler
from colossalai.auto_parallel.tensor_shard.node_handler.linear_handler import LinearFunctionHandler
from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType, StrategiesVector
from colossalai.device.device_mesh import DeviceMesh
from colossalai.fx import ColoGraphModule, ColoTracer
from colossalai.initialize import launch
from colossalai.logging import disable_existing_loggers
from colossalai.testing import assert_close, parameterize, rerun_if_address_is_in_use
@ -74,7 +76,7 @@ def check_view_handler(rank, tgt_shape, model_cls, world_size, port):
input_args=[input, other],
meta_arg_names=['input', 'other'],
node_type='following')
tracer = ColoTracer()
tracer = ColoTracer(bias_addition_split=True)
if model_cls.__name__ == 'ConvViewModel':
# graph():
# %input_1 : torch.Tensor [#users=1] = placeholder[target=input]
@ -82,11 +84,8 @@ def check_view_handler(rank, tgt_shape, model_cls, world_size, port):
# %conv2d : [#users=1] = call_function[target=torch.conv2d](args = (%input_1, %other), kwargs = {})
# %view : [#users=1] = call_method[target=view](args = (%conv2d, 2, -1), kwargs = {})
# return view
graph = tracer.trace(model,
meta_args={
"input": torch.rand(8, 8, 66, 66).to('meta'),
"other": torch.rand(16, 8, 3, 3).to('meta'),
})
meta_args = {'input': torch.rand(8, 8, 66, 66).to('meta'), 'other': torch.rand(16, 8, 3, 3).to('meta')}
graph = tracer.trace(model, meta_args=meta_args)
if model_cls.__name__ == 'LinearViewModel':
# graph():
@ -95,13 +94,14 @@ def check_view_handler(rank, tgt_shape, model_cls, world_size, port):
# %linear : [#users=1] = call_function[target=torch._C._nn.linear](args = (%input_1, %other), kwargs = {bias: None})
# %view : [#users=1] = call_method[target=view](args = (%linear, 32, 4, 32, 32, 4), kwargs = {})
# return view
graph = tracer.trace(model,
meta_args={
"input": torch.rand(8, 16, 64, 32).to('meta'),
"other": torch.rand(64, 32).to('meta'),
})
meta_args = {
'input': torch.rand(8, 16, 64, 32).to('meta'),
'other': torch.rand(64, 32).to('meta'),
}
graph = tracer.trace(model, meta_args=meta_args)
gm = ColoGraphModule(model, graph)
shape_prop_pass(gm, *meta_args.values())
previous_mod_node = list(graph.nodes)[2]
view_node = list(graph.nodes)[3]
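
For `ConvViewModel`, the input `(8, 8, 66, 66)` and weight `(16, 8, 3, 3)` give a conv output of `(8, 16, 64, 64)` (66 - 3 + 1 = 64), so `view(2, -1)` flattens it to `(2, 262144)`. A one-line meta-tensor check of that arithmetic:

    import torch

    conv_out = torch.rand(8, 16, 64, 64, device='meta')
    print(conv_out.view(2, -1).shape)    # torch.Size([2, 262144])
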

View File

@ -1,12 +1,13 @@
import pytest
import torch
import torch.nn as nn
from colossalai.auto_parallel.tensor_shard.node_handler.where_handler import \
WhereHandler
from colossalai.auto_parallel.tensor_shard.sharding_strategy import (OperationData, OperationDataType, StrategiesVector)
from colossalai._analyzer.fx.graph_module import ColoGraphModule
from colossalai._analyzer.fx.passes.shape_prop import shape_prop_pass
from colossalai._analyzer.fx.tracer.tracer import ColoTracer
from colossalai.auto_parallel.tensor_shard.node_handler.where_handler import WhereHandler
from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType, StrategiesVector
from colossalai.device.device_mesh import DeviceMesh
from colossalai.fx import ColoGraphModule, ColoTracer
from colossalai.fx.tracer.meta_patch.patched_module import linear
class ConvModel(nn.Module):
@ -19,22 +20,24 @@ class ConvModel(nn.Module):
return output
@pytest.mark.skip('ShapeProp is not compatible with PyTorch 1.11.0')
def test_where_handler():
model = ConvModel()
tracer = ColoTracer()
tracer = ColoTracer(bias_addition_split=True)
# graph():
# %condition : torch.Tensor [#users=1] = placeholder[target=condition]
# %x : torch.Tensor [#users=1] = placeholder[target=x]
# %y : torch.Tensor [#users=1] = placeholder[target=y]
# %where : [#users=1] = call_function[target=torch.where](args = (%condition, %x, %y), kwargs = {})
# return where
graph = tracer.trace(model,
meta_args={
"condition": torch.rand(4, 4, 64, 64).to('meta'),
"x": torch.rand(4, 1, 64, 64).to('meta'),
"y": torch.rand(1, 4, 64, 64).to('meta')
})
meta_args = {
'condition': torch.rand(4, 4, 64, 64).to('meta'),
'x': torch.rand(4, 1, 64, 64).to('meta'),
'y': torch.rand(1, 4, 64, 64).to('meta')
}
graph = tracer.trace(model, meta_args=meta_args)
gm = ColoGraphModule(model, graph)
shape_prop_pass(gm, *meta_args.values())
physical_mesh_id = torch.arange(0, 4)
mesh_shape = (2, 2)
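
The `meta_args` mechanism works because meta tensors carry shape and dtype but no storage, so even large models can be traced without allocating real memory:

    import torch

    t = torch.rand(4, 4, 64, 64).to('meta')
    print(t.is_meta, t.shape)    # True torch.Size([4, 4, 64, 64])
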

View File

@ -4,6 +4,9 @@ from typing import Dict, List
import torch
from torch.fx import GraphModule
from colossalai._analyzer.fx.graph_module import ColoGraphModule
from colossalai._analyzer.fx.passes.shape_prop import shape_prop_pass
from colossalai._analyzer.fx.tracer.tracer import ColoTracer
from colossalai.auto_parallel.passes.runtime_apply_pass import runtime_apply_pass
from colossalai.auto_parallel.passes.runtime_preparation_pass import runtime_preparation_pass
from colossalai.auto_parallel.tensor_shard.options import SolverOptions
@ -11,7 +14,6 @@ from colossalai.auto_parallel.tensor_shard.solver import StrategiesConstructor
from colossalai.auto_parallel.tensor_shard.solver.cost_graph import CostGraph
from colossalai.auto_parallel.tensor_shard.solver.solver import Solver
from colossalai.device.device_mesh import DeviceMesh
from colossalai.fx.tracer.tracer import ColoTracer
from colossalai.tensor.shape_consistency import to_global
from colossalai.testing.comparison import assert_close
@ -79,14 +81,16 @@ def numerical_test_for_node_strategy(model: torch.nn.Module,
model_to_shard, args_to_shard, kwargs_to_shard = _build_model_to_compare(model, input_args, input_kwargs,
grad_to_shard_dict)
tracer = ColoTracer()
tracer = ColoTracer(bias_addition_split=True)
input_sample = {}
for input_arg, meta_arg_name in zip(input_args, meta_arg_names):
input_sample[meta_arg_name] = torch.rand(input_arg.shape).to('meta')
input_sample[meta_arg_name] = torch.empty(input_arg.shape, dtype=input_arg.dtype).to('meta')
for meta_kwarg_name, input_kwarg in input_kwargs.items():
input_sample[meta_kwarg_name] = torch.rand(input_kwarg.shape).to('meta')
input_sample[meta_kwarg_name] = torch.empty(input_kwarg.shape, dtype=input_kwarg.dtype).to('meta')
graph = tracer.trace(root=model_to_shard, meta_args=input_sample)
gm = GraphModule(model_to_shard, graph, model_to_shard.__class__.__name__)
gm = ColoGraphModule(model_to_shard, graph, model_to_shard.__class__.__name__)
shape_prop_pass(gm, *input_sample.values())
solver_options = SolverOptions()
strategies_constructor = StrategiesConstructor(graph, device_mesh, solver_options)
strategies_constructor.build_strategies_and_cost()
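
The switch from `torch.rand(...)` to `torch.empty(..., dtype=...)` above matters for non-float inputs: `torch.rand` only produces floating-point tensors, so integer arguments such as token ids would lose their dtype before shape propagation. A minimal illustration:

    import torch

    input_ids = torch.zeros(1, 16, dtype=torch.int64)
    old_meta = torch.rand(input_ids.shape).to('meta')                            # float32
    new_meta = torch.empty(input_ids.shape, dtype=input_ids.dtype).to('meta')    # int64
    print(old_meta.dtype, new_meta.dtype)    # torch.float32 torch.int64
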

View File

@ -1,11 +1,13 @@
import pytest
import torch
import transformers
from topo_utils import split_model_and_get_DAG, check_topo, MLP
from topo_utils import MLP, check_topo, split_model_and_get_DAG
BATCH_SIZE = 1
SEQ_LENGTH = 16
@pytest.mark.skip('ShapeProp is not compatible with PyTorch 1.11.0')
def test_opt():
MODEL_LIST = [
MLP,
@ -13,7 +15,10 @@ def test_opt():
]
CONFIGS = [
{'dim': 10, 'layers': 12},
{
'dim': 10,
'layers': 12
},
transformers.OPTConfig(vocab_size=100, hidden_size=128, num_hidden_layers=4, num_attention_heads=4),
]
@ -21,15 +26,15 @@ def test_opt():
x = torch.zeros((16, 10))
kwargs = dict(x=x)
return kwargs
def data_gen_OPT():
input_ids = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64)
attention_mask = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64)
kwargs = dict(input_ids=input_ids, attention_mask=attention_mask)
return kwargs
DATAGEN = [
data_gen_MLP,
data_gen_MLP,
data_gen_OPT,
]
@ -39,5 +44,6 @@ def test_opt():
# print(f'{top_mod=}\n----\n{topo=}')
check_topo(top_mod, topo)
if __name__ == '__main__':
test_opt()