[autoparallel] adapt autoparallel with new analyzer (#3261)

* [autoparallel] adapt autoparallel with new analyzer

* fix all node handler tests

* polish

* polish
pull/3367/head
YuliangLiu0306 2023-03-30 17:47:24 +08:00 committed by GitHub
parent e78a1e949a
commit fee2af8610
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
36 changed files with 481 additions and 386 deletions

View File

@ -446,10 +446,7 @@ if version.parse(torch.__version__) >= version.parse('1.12.0'):
@register_meta(aten.embedding_dense_backward.default) @register_meta(aten.embedding_dense_backward.default)
def meta_embedding_dense_backward(grad_output: torch.Tensor, indices: torch.Tensor, num_weights, padding_idx, def meta_embedding_dense_backward(grad_output: torch.Tensor, indices: torch.Tensor, num_weights, padding_idx,
scale_grad_by_freq): scale_grad_by_freq):
return new((num_weights, grad_output.size(-1)), return new((num_weights, grad_output.size(-1)), dtype=grad_output.dtype, layout=grad_output.layout)
dtype=grad_output.dtype,
device=grad_output.device,
layout=grad_output.layout)
# ============================== Dropout =========================================== # ============================== Dropout ===========================================
# https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/Dropout.cpp # https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/Dropout.cpp

View File

@ -51,7 +51,10 @@ def _normalize_tuple(x):
def _current_device(module): def _current_device(module):
return next(module.parameters()).device try:
return next(module.parameters()).device
except StopIteration:
return torch.device('cpu')
@compatibility(is_backward_compatible=False) @compatibility(is_backward_compatible=False)
@ -120,15 +123,18 @@ class ShapeProp(torch.fx.Interpreter):
return t.to('meta') return t.to('meta')
if isinstance(elem, MetaTensor): if isinstance(elem, MetaTensor):
if getattr(self, '_is_param', False):
return torch.nn.Parameter(_convert_meta(elem._tensor))
return _convert_meta(elem._tensor) return _convert_meta(elem._tensor)
elif isinstance(elem, torch.Tensor): elif isinstance(elem, torch.Tensor):
if isinstance(elem, torch.nn.Parameter):
return torch.nn.Parameter(_convert_meta(elem))
return _convert_meta(elem) return _convert_meta(elem)
else: else:
return elem return elem
# unwrap_fn = lambda elem: elem._tensor if isinstance(elem, MetaTensor) else elem
is_pure_tensor = lambda elem: isinstance(elem, MetaTensor) and not isinstance(elem, torch.nn.Parameter) is_pure_tensor = lambda elem: isinstance(elem, MetaTensor) and not isinstance(elem, torch.nn.Parameter)
n_info = MetaInfo(n) n_info = MetaInfo(n)
n_info.outputs = _normalize_tuple(r) n_info.outputs = _normalize_tuple(r)
@ -149,7 +155,11 @@ class ShapeProp(torch.fx.Interpreter):
n_info.inputs = tuple(v for v in args if is_pure_tensor(v)) + \ n_info.inputs = tuple(v for v in args if is_pure_tensor(v)) + \
tuple(v for v in kwargs.values() if is_pure_tensor(v)) tuple(v for v in kwargs.values() if is_pure_tensor(v))
n._meta_data = tree_map(unwrap_fn, _normalize_tuple(r)) # align with SPMD # align with SPMD
if isinstance(r, (tuple, list)):
n._meta_data = tree_map(unwrap_fn, _normalize_tuple(r))
else:
n._meta_data = unwrap_fn(r)
n_info.global_ctx = self.global_hook.ctx n_info.global_ctx = self.global_hook.ctx
n_info.curr_ctx = self.global_hook.ctx.copy() n_info.curr_ctx = self.global_hook.ctx.copy()
@ -175,10 +185,48 @@ class ShapeProp(torch.fx.Interpreter):
Return Return
Any: The value returned by the function invocation Any: The value returned by the function invocation
""" """
convert_to_param = False
if target in (torch.transpose, torch.reshape) and isinstance(args[0], torch.nn.parameter.Parameter):
convert_to_param = True
if target in self._custom_dispatch_func: if target in self._custom_dispatch_func:
return self._custom_dispatch_func[target](*args, **kwargs) res = self._custom_dispatch_func[target](*args, **kwargs)
else: else:
return super().call_function(target, args, kwargs) res = super().call_function(target, args, kwargs)
if convert_to_param:
return torch.nn.Parameter(res)
else:
return res
def call_method(self, target: 'Target', args: Tuple[Any, ...], kwargs: Dict[str, Any]) -> Any:
"""
Execute a ``call_method`` node and return the result.
Args:
target (Target): The call target for this node. See
`Node <https://pytorch.org/docs/master/fx.html#torch.fx.Node>`__ for
details on semantics
args (Tuple): Tuple of positional args for this invocation
kwargs (Dict): Dict of keyword arguments for this invocation
Return
Any: The value returned by the method invocation
"""
# args[0] is the `self` object for this method call
self_obj, *args_tail = args
target_method = getattr(self_obj.__class__, target)
convert_to_parameter = False
if target_method in (torch.Tensor.view, torch.Tensor.transpose) and isinstance(
args[0], torch.nn.parameter.Parameter):
convert_to_parameter = True
# Execute the method and return the result
assert isinstance(target, str)
res = getattr(self_obj, target)(*args_tail, **kwargs)
if convert_to_parameter:
return torch.nn.Parameter(res)
else:
return res
def propagate(self, *args, device=None): def propagate(self, *args, device=None):
""" """

View File

@ -21,111 +21,69 @@ def linear_impl(input, weight, bias=None):
@register_tracer_impl(F.conv1d, name='_bias_addition_impl') @register_tracer_impl(F.conv1d, name='_bias_addition_impl')
def conv1d_impl(input, weight, bias=None, stride=_single(1), padding=_single(0), dilation=_single(1), groups=1): def conv1d_impl(input, weight, **kwargs):
bias = getattr(kwargs, 'bias', None)
if bias is None: if bias is None:
return F.conv1d(input, weight, stride=stride, padding=padding, dilation=dilation, groups=groups) return F.conv1d(input, weight, **kwargs)
else: else:
return F.conv1d(input, weight, stride=stride, padding=padding, dilation=dilation, groups=groups) + bias.reshape( new_kwargs = kwargs
(-1, 1)) new_kwargs['bias'] = None
return F.conv1d(input, weight, **kwargs) + bias.reshape((-1, 1))
@register_tracer_impl(F.conv2d, name='_bias_addition_impl') @register_tracer_impl(F.conv2d, name='_bias_addition_impl')
def conv2d_impl(input, weight, bias=None, stride=_pair(1), padding=_pair(0), dilation=_pair(1), groups=1): def conv2d_impl(input, weight, **kwargs):
bias = getattr(kwargs, 'bias', None)
if bias is None: if bias is None:
return F.conv2d(input, weight, stride=stride, padding=padding, dilation=dilation, groups=groups) return F.conv2d(input, weight, **kwargs)
else: else:
return F.conv2d(input, weight, stride=stride, padding=padding, dilation=dilation, groups=groups) + bias.reshape( new_kwargs = kwargs
(-1, 1, 1)) new_kwargs['bias'] = None
return F.conv2d(input, weight, **kwargs) + bias.reshape((-1, 1, 1))
@register_tracer_impl(F.conv3d, name='_bias_addition_impl') @register_tracer_impl(F.conv3d, name='_bias_addition_impl')
def conv3d_impl(input, weight, bias=None, stride=_triple(1), padding=_triple(0), dilation=_triple(1), groups=1): def conv3d_impl(input, weight, **kwargs):
bias = getattr(kwargs, 'bias', None)
if bias is None: if bias is None:
return F.conv3d(input, weight, stride=stride, padding=padding, dilation=dilation, groups=groups) return F.conv3d(input, weight, **kwargs)
else: else:
return F.conv3d(input, weight, stride=stride, padding=padding, dilation=dilation, groups=groups) + bias.reshape( new_kwargs = kwargs
(-1, 1, 1, 1)) new_kwargs['bias'] = None
return F.conv3d(input, weight, **new_kwargs) + bias.reshape((-1, 1, 1, 1))
@register_tracer_impl(F.conv_transpose1d, name='_bias_addition_impl') @register_tracer_impl(F.conv_transpose1d, name='_bias_addition_impl')
def conv_transpose1d_impl(input, def conv_transpose1d_impl(input, weight, **kwargs):
weight, bias = getattr(kwargs, 'bias', None)
bias=None,
stride=_single(1),
padding=_single(0),
output_padding=_single(0),
groups=1,
dilation=_single(1)):
if bias is None: if bias is None:
return F.conv_transpose1d(input, return F.conv_transpose1d(input, weight, **kwargs)
weight,
stride=stride,
padding=padding,
output_padding=output_padding,
groups=groups,
dilation=dilation)
else: else:
return F.conv_transpose1d(input, new_kwargs = kwargs
weight, new_kwargs['bias'] = None
stride=stride, return F.conv_transpose1d(input, weight, **new_kwargs) + bias.reshape((-1, 1))
padding=padding,
output_padding=output_padding,
groups=groups,
dilation=dilation) + bias.reshape((-1, 1))
@register_tracer_impl(F.conv_transpose2d, name='_bias_addition_impl') @register_tracer_impl(F.conv_transpose2d, name='_bias_addition_impl')
def conv_transpose2d_impl(input, def conv_transpose2d_impl(input, weight, **kwargs):
weight, bias = getattr(kwargs, 'bias', None)
bias=None,
stride=_pair(1),
padding=_pair(0),
output_padding=_pair(0),
groups=1,
dilation=_pair(1)):
if bias is None: if bias is None:
return F.conv_transpose2d(input, return F.conv_transpose2d(input, weight, **kwargs)
weight,
stride=stride,
padding=padding,
output_padding=output_padding,
groups=groups,
dilation=dilation)
else: else:
return F.conv_transpose2d(input, new_kwargs = kwargs
weight, new_kwargs['bias'] = None
stride=stride, return F.conv_transpose2d(input, weight, **new_kwargs) + bias.reshape((-1, 1, 1))
padding=padding,
output_padding=output_padding,
groups=groups,
dilation=dilation) + bias.reshape((-1, 1, 1))
@register_tracer_impl(F.conv_transpose3d, name='_bias_addition_impl') @register_tracer_impl(F.conv_transpose3d, name='_bias_addition_impl')
def conv_transpose3d_impl(input, def conv_transpose3d_impl(input, weight, **kwargs):
weight, bias = getattr(kwargs, 'bias', None)
bias=None,
stride=_triple(1),
padding=_triple(0),
output_padding=_triple(0),
groups=1,
dilation=_triple(1)):
if bias is None: if bias is None:
return F.conv_transpose3d(input, return F.conv_transpose3d(input, weight, **kwargs)
weight,
stride=stride,
padding=padding,
output_padding=output_padding,
groups=groups,
dilation=dilation)
else: else:
return F.conv_transpose3d(input, new_kwargs = kwargs
weight, new_kwargs['bias'] = None
stride=stride, return F.conv_transpose3d(input, weight, **new_kwargs) + bias.reshape((-1, 1, 1, 1))
padding=padding,
output_padding=output_padding,
groups=groups,
dilation=dilation) + bias.reshape((-1, 1, 1, 1))
@register_tracer_impl(torch.addmm, name='_bias_addition_impl') @register_tracer_impl(torch.addmm, name='_bias_addition_impl')

View File

@ -70,14 +70,28 @@ class MetaInfo:
if self._strategy is not None and self._target is not None: if self._strategy is not None and self._target is not None:
self.compute_metainfo() self.compute_metainfo()
def compute_sharded_opdata(self, operation_data: OperationData, sharding_spec: ShardingSpec) -> torch.Tensor: def compute_sharded_opdata(self, operation_data: OperationData, sharding_spec: ShardingSpec):
""" """
Compute sharded opdata based on the given data and sharding spec. Compute sharded opdata based on the given data and sharding spec.
""" """
return OperationData(name=operation_data.name,
data=torch.zeros(sharding_spec.get_sharded_shape_per_device(), device="meta"), if isinstance(sharding_spec, ShardingSpec):
type=operation_data.type, op_data = OperationData(name=operation_data.name,
logical_shape=operation_data.logical_shape) data=torch.zeros(sharding_spec.get_sharded_shape_per_device(), device="meta"),
type=operation_data.type,
logical_shape=operation_data.logical_shape)
elif isinstance(sharding_spec, (list, tuple)):
data = operation_data.data
assert isinstance(data, (list, tuple)), f"Data Should be list or tuple, but got {type(data)}."
assert len(data) == len(sharding_spec), f"Length of data and sharding spec should be the same."
sharded_data = []
for d, s in zip(data, sharding_spec):
sharded_data.append(torch.zeros(s.get_sharded_shape_per_device(), device="meta"))
op_data = OperationData(name=operation_data.name, data=sharded_data, type=operation_data.type)
else:
raise ValueError(f"Sharding spec should be ShardingSpec or list, but got {type(sharding_spec)}.")
return op_data
def compute_metainfo(self): def compute_metainfo(self):
""" """

View File

@ -387,12 +387,13 @@ def module_params_sharding_pass(gm: torch.fx.GraphModule, device_mesh: DeviceMes
# This stream is created for overlaping the communication and computation. # This stream is created for overlaping the communication and computation.
reduction_stream = torch.cuda.Stream() reduction_stream = torch.cuda.Stream()
def _add_hook_for_grad_communication(node, param): def _add_hook_for_grad_communication(node, param, name=None):
comm_actions = node.best_strategy.communication_actions comm_actions = node.best_strategy.communication_actions
def _filter_param_to_hook(node, op_data, comm_action): def _filter_param_to_hook(node, op_data, comm_action, name):
if node.op == 'call_module' and op_data.type == OperationDataType.PARAM and op_data.name == param.name and comm_action.comm_type == CommType.HOOK:
if node.op == 'call_module' and op_data.type == OperationDataType.PARAM and op_data.name == name and comm_action.comm_type == CommType.HOOK:
return True return True
if node.op == 'get_attr' and isinstance( if node.op == 'get_attr' and isinstance(
node._meta_data, torch.nn.parameter.Parameter) and comm_action.comm_type == CommType.HOOK: node._meta_data, torch.nn.parameter.Parameter) and comm_action.comm_type == CommType.HOOK:
@ -402,7 +403,7 @@ def module_params_sharding_pass(gm: torch.fx.GraphModule, device_mesh: DeviceMes
for operation_data, comm_action in comm_actions.items(): for operation_data, comm_action in comm_actions.items():
comm_spec_to_use = comm_action.comm_spec comm_spec_to_use = comm_action.comm_spec
# register hook to the parameters # register hook to the parameters
if _filter_param_to_hook(node, operation_data, comm_action): if _filter_param_to_hook(node, operation_data, comm_action, name=name):
def wrapper(param, comm_spec, stream, overlap): def wrapper(param, comm_spec, stream, overlap):
@ -442,7 +443,7 @@ def module_params_sharding_pass(gm: torch.fx.GraphModule, device_mesh: DeviceMes
param = _shard_param(param, target_sharding_spec) param = _shard_param(param, target_sharding_spec)
setattr(target_module, name, param) setattr(target_module, name, param)
_add_hook_for_grad_communication(node, param) _add_hook_for_grad_communication(node, param, name)
sharded_buffer_dict = {} sharded_buffer_dict = {}
# apply the sharding spec of buffers # apply the sharding spec of buffers

View File

@ -81,7 +81,10 @@ class AddBMMFunctionHandler(NodeHandler):
def get_strategy_generator(self) -> List[StrategyGenerator]: def get_strategy_generator(self) -> List[StrategyGenerator]:
op_data_mapping = self.get_operation_data_mapping() op_data_mapping = self.get_operation_data_mapping()
generators = [] generators = []
generators.append(BatchedMatMulStrategyGenerator(op_data_mapping, self.device_mesh)) generator = BatchedMatMulStrategyGenerator(op_data_mapping, self.device_mesh)
# addbmm will shrink the first batch dim
generator.squeeze_batch_dim = True
generators.append(generator)
return generators return generators
def post_process(self, strategy: ShardingStrategy) -> Union[ShardingStrategy, List[ShardingStrategy]]: def post_process(self, strategy: ShardingStrategy) -> Union[ShardingStrategy, List[ShardingStrategy]]:

View File

@ -776,10 +776,6 @@ class BatchedMatMulStrategyGenerator(MatMulStrategyGenerator):
bias_op_data = self.op_data['bias'] bias_op_data = self.op_data['bias']
assert bias_op_data.data.dim() < 3 and len(bias_op_data.logical_shape) == 2 assert bias_op_data.data.dim() < 3 and len(bias_op_data.logical_shape) == 2
if self.op_data['output'].data.dim() == 2:
# addbmm will shrink the first batch dim
self.squeeze_batch_dim = True
def update_compute_cost(self, strategy: ShardingStrategy) -> ShardingStrategy: def update_compute_cost(self, strategy: ShardingStrategy) -> ShardingStrategy:
fwd_compute_cost = self.op_data['input'].data.shape[-1] * reduce(operator.mul, fwd_compute_cost = self.op_data['input'].data.shape[-1] * reduce(operator.mul,
self.op_data['output'].data.shape) self.op_data['output'].data.shape)

View File

@ -386,7 +386,7 @@ def meta_local_scalar_dense(self: torch.Tensor):
@register_meta(aten.where.self) @register_meta(aten.where.self)
def meta_where_self(condition: torch.Tensor, self: torch.Tensor, other: torch.Tensor): def meta_where_self(condition: torch.Tensor, self: torch.Tensor, other: torch.Tensor):
result_type = torch.result_type(self, other) result_type = torch.result_type(self, other)
return torch.empty_like(self, dtype=result_type) return torch.empty_like(condition + self + other, dtype=result_type)
@register_meta(aten.index.Tensor) @register_meta(aten.index.Tensor)

View File

@ -1,22 +1,20 @@
from faulthandler import disable
from functools import partial from functools import partial
from xml.dom import WrongDocumentErr
import pytest import pytest
import torch import torch
import torch.multiprocessing as mp import torch.multiprocessing as mp
import torch.nn as nn import torch.nn as nn
from typing_extensions import Self
from colossalai._analyzer.fx.graph_module import ColoGraphModule
from colossalai._analyzer.fx.passes.shape_prop import shape_prop_pass
from colossalai._analyzer.fx.tracer.tracer import ColoTracer
from colossalai.auto_parallel.tensor_shard.node_handler import LinearFunctionHandler from colossalai.auto_parallel.tensor_shard.node_handler import LinearFunctionHandler
from colossalai.auto_parallel.tensor_shard.sharding_strategy import ( from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
OperationData,
OperationDataType, OperationDataType,
ShardingStrategy, ShardingStrategy,
StrategiesVector, StrategiesVector,
) )
from colossalai.device.device_mesh import DeviceMesh from colossalai.device.device_mesh import DeviceMesh
from colossalai.fx import ColoGraphModule, ColoTracer
from colossalai.initialize import launch from colossalai.initialize import launch
from colossalai.logging import disable_existing_loggers from colossalai.logging import disable_existing_loggers
from colossalai.testing import parameterize, rerun_if_address_is_in_use from colossalai.testing import parameterize, rerun_if_address_is_in_use
@ -96,7 +94,7 @@ def check_addmm_function_handler(rank, input_shape, model_cls, world_size, port)
meta_arg_names=meta_arg_names, meta_arg_names=meta_arg_names,
node_type='bias_module') node_type='bias_module')
tracer = ColoTracer() tracer = ColoTracer(bias_addition_split=True)
# graph(): # graph():
# %input_1 : torch.Tensor [#users=1] = placeholder[target=input] # %input_1 : torch.Tensor [#users=1] = placeholder[target=input]
# %m1 : torch.Tensor [#users=1] = placeholder[target=m1] # %m1 : torch.Tensor [#users=1] = placeholder[target=m1]
@ -109,6 +107,7 @@ def check_addmm_function_handler(rank, input_shape, model_cls, world_size, port)
# return add # return add
graph = tracer.trace(model, meta_args=meta_args_for_tracer) graph = tracer.trace(model, meta_args=meta_args_for_tracer)
gm = ColoGraphModule(model, graph) gm = ColoGraphModule(model, graph)
shape_prop_pass(gm, *meta_args_for_tracer.values())
# [input_1, m1, m2, addmm, output] # [input_1, m1, m2, addmm, output]
node_list = list(graph.nodes) node_list = list(graph.nodes)
linear_node = node_list[4] linear_node = node_list[4]

View File

@ -5,10 +5,12 @@ import torch
import torch.multiprocessing as mp import torch.multiprocessing as mp
import torch.nn as nn import torch.nn as nn
from colossalai._analyzer.fx.graph_module import ColoGraphModule
from colossalai._analyzer.fx.passes.shape_prop import shape_prop_pass
from colossalai._analyzer.fx.tracer.tracer import ColoTracer
from colossalai.auto_parallel.tensor_shard.node_handler.batch_norm_handler import BatchNormModuleHandler from colossalai.auto_parallel.tensor_shard.node_handler.batch_norm_handler import BatchNormModuleHandler
from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType, StrategiesVector from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType, StrategiesVector
from colossalai.device.device_mesh import DeviceMesh from colossalai.device.device_mesh import DeviceMesh
from colossalai.fx import ColoGraphModule, ColoTracer
from colossalai.initialize import launch from colossalai.initialize import launch
from colossalai.logging import disable_existing_loggers from colossalai.logging import disable_existing_loggers
from colossalai.testing import assert_close, parameterize, rerun_if_address_is_in_use from colossalai.testing import assert_close, parameterize, rerun_if_address_is_in_use
@ -38,13 +40,15 @@ def check_bn_module_handler(rank, world_size, port):
strategy_number=strategy_number, strategy_number=strategy_number,
input_args=[input], input_args=[input],
meta_arg_names=['input']) meta_arg_names=['input'])
tracer = ColoTracer() tracer = ColoTracer(bias_addition_split=True)
# graph(): # graph():
# %input_1 : torch.Tensor [#users=1] = placeholder[target=input] # %input_1 : torch.Tensor [#users=1] = placeholder[target=input]
# %_0 : [#users=1] = call_module[target=0](args = (%input_1,), kwargs = {}) # %_0 : [#users=1] = call_module[target=0](args = (%input_1,), kwargs = {})
# return _0 # return _0
graph = tracer.trace(model, meta_args={"input": torch.rand(4, 16, 64, 64).to('meta')}) meta_args = {"input": torch.rand(4, 16, 64, 64).to('meta')}
graph = tracer.trace(model, meta_args=meta_args)
gm = ColoGraphModule(model, graph) gm = ColoGraphModule(model, graph)
shape_prop_pass(gm, *meta_args.values())
bn_mod_node = list(graph.nodes)[1] bn_mod_node = list(graph.nodes)[1]
strategies_vector = StrategiesVector(bn_mod_node) strategies_vector = StrategiesVector(bn_mod_node)

View File

@ -1,14 +1,14 @@
from faulthandler import disable
from functools import partial from functools import partial
from xml.dom import WrongDocumentErr
import pytest import pytest
import torch import torch
import torch.multiprocessing as mp import torch.multiprocessing as mp
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
from typing_extensions import Self
from colossalai._analyzer.fx.graph_module import ColoGraphModule
from colossalai._analyzer.fx.passes.shape_prop import shape_prop_pass
from colossalai._analyzer.fx.tracer.tracer import ColoTracer
from colossalai.auto_parallel.tensor_shard.node_handler import LinearFunctionHandler, LinearModuleHandler from colossalai.auto_parallel.tensor_shard.node_handler import LinearFunctionHandler, LinearModuleHandler
from colossalai.auto_parallel.tensor_shard.sharding_strategy import ( from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
OperationData, OperationData,
@ -17,12 +17,10 @@ from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
StrategiesVector, StrategiesVector,
) )
from colossalai.device.device_mesh import DeviceMesh from colossalai.device.device_mesh import DeviceMesh
from colossalai.fx import ColoGraphModule, ColoTracer
from colossalai.initialize import launch from colossalai.initialize import launch
from colossalai.logging import disable_existing_loggers from colossalai.logging import disable_existing_loggers
from colossalai.testing import assert_close, parameterize, rerun_if_address_is_in_use from colossalai.testing import assert_close, parameterize, rerun_if_address_is_in_use
from colossalai.testing.pytest_wrapper import run_on_environment_flag from colossalai.testing.pytest_wrapper import run_on_environment_flag
from colossalai.testing.utils import parameterize
from colossalai.utils import free_port from colossalai.utils import free_port
from tests.test_auto_parallel.test_tensor_shard.test_node_handler.utils import numerical_test_for_node_strategy from tests.test_auto_parallel.test_tensor_shard.test_node_handler.utils import numerical_test_for_node_strategy
@ -66,7 +64,7 @@ def check_linear_module_handler(rank, world_size, port):
meta_arg_names=meta_arg_names, meta_arg_names=meta_arg_names,
node_type='bias_module') node_type='bias_module')
tracer = ColoTracer() tracer = ColoTracer(bias_addition_split=True)
# graph(): # graph():
# %x : torch.Tensor [#users=1] = placeholder[target=x] # %x : torch.Tensor [#users=1] = placeholder[target=x]
# %weight : [#users=1] = get_attr[target=weight] # %weight : [#users=1] = get_attr[target=weight]
@ -74,8 +72,10 @@ def check_linear_module_handler(rank, world_size, port):
# %linear : [#users=1] = call_function[target=torch._C._nn.linear](args = (%x, %weight), kwargs = {}) # %linear : [#users=1] = call_function[target=torch._C._nn.linear](args = (%x, %weight), kwargs = {})
# %add : [#users=1] = call_function[target=operator.add](args = (%linear, %bias), kwargs = {}) # %add : [#users=1] = call_function[target=operator.add](args = (%linear, %bias), kwargs = {})
# return add # return add
graph = tracer.trace(model, meta_args={"x": torch.rand(4, 4, 4, 16).to('meta')}) meta_args = {"x": torch.rand(4, 4, 4, 16).to('meta')}
graph = tracer.trace(model, meta_args=meta_args)
gm = ColoGraphModule(model, graph) gm = ColoGraphModule(model, graph)
shape_prop_pass(gm, *meta_args.values())
linear_mod_node = list(graph.nodes)[3] linear_mod_node = list(graph.nodes)[3]
strategies_vector = StrategiesVector(linear_mod_node) strategies_vector = StrategiesVector(linear_mod_node)

View File

@ -1,13 +1,13 @@
from faulthandler import disable
from functools import partial from functools import partial
from xml.dom import WrongDocumentErr
import pytest import pytest
import torch import torch
import torch.multiprocessing as mp import torch.multiprocessing as mp
import torch.nn as nn import torch.nn as nn
from typing_extensions import Self
from colossalai._analyzer.fx.graph_module import ColoGraphModule
from colossalai._analyzer.fx.passes.shape_prop import shape_prop_pass
from colossalai._analyzer.fx.tracer.tracer import ColoTracer
from colossalai.auto_parallel.tensor_shard.node_handler import LinearFunctionHandler, LinearModuleHandler from colossalai.auto_parallel.tensor_shard.node_handler import LinearFunctionHandler, LinearModuleHandler
from colossalai.auto_parallel.tensor_shard.sharding_strategy import ( from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
OperationData, OperationData,
@ -16,12 +16,10 @@ from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
StrategiesVector, StrategiesVector,
) )
from colossalai.device.device_mesh import DeviceMesh from colossalai.device.device_mesh import DeviceMesh
from colossalai.fx import ColoGraphModule, ColoTracer
from colossalai.initialize import launch from colossalai.initialize import launch
from colossalai.logging import disable_existing_loggers from colossalai.logging import disable_existing_loggers
from colossalai.testing import assert_close, parameterize, rerun_if_address_is_in_use from colossalai.testing import assert_close, parameterize, rerun_if_address_is_in_use
from colossalai.testing.pytest_wrapper import run_on_environment_flag from colossalai.testing.pytest_wrapper import run_on_environment_flag
from colossalai.testing.utils import parameterize
from colossalai.utils import free_port from colossalai.utils import free_port
from tests.test_auto_parallel.test_tensor_shard.test_node_handler.utils import numerical_test_for_node_strategy from tests.test_auto_parallel.test_tensor_shard.test_node_handler.utils import numerical_test_for_node_strategy
@ -62,9 +60,11 @@ def check_linear_module_handler(rank, bias, world_size, port):
meta_arg_names=meta_arg_names, meta_arg_names=meta_arg_names,
node_type='bias_module') node_type='bias_module')
tracer = ColoTracer() tracer = ColoTracer(bias_addition_split=True)
graph = tracer.trace(model, meta_args={"x": torch.rand(4, 4, 4, 16).to('meta')}) meta_args = {"x": torch.rand(4, 4, 4, 16).to('meta')}
graph = tracer.trace(model, meta_args=meta_args)
gm = ColoGraphModule(model, graph) gm = ColoGraphModule(model, graph)
shape_prop_pass(gm, *meta_args.values())
linear_mod_node = list(graph.nodes)[3] linear_mod_node = list(graph.nodes)[3]
strategies_vector = StrategiesVector(linear_mod_node) strategies_vector = StrategiesVector(linear_mod_node)

View File

@ -5,10 +5,12 @@ import torch
import torch.multiprocessing as mp import torch.multiprocessing as mp
import torch.nn as nn import torch.nn as nn
from colossalai._analyzer.fx.graph_module import ColoGraphModule
from colossalai._analyzer.fx.passes.shape_prop import shape_prop_pass
from colossalai._analyzer.fx.tracer.tracer import ColoTracer
from colossalai.auto_parallel.tensor_shard.node_handler import BinaryElementwiseHandler from colossalai.auto_parallel.tensor_shard.node_handler import BinaryElementwiseHandler
from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType, StrategiesVector from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType, StrategiesVector
from colossalai.device.device_mesh import DeviceMesh from colossalai.device.device_mesh import DeviceMesh
from colossalai.fx import ColoGraphModule, ColoTracer
from colossalai.initialize import launch from colossalai.initialize import launch
from colossalai.logging import disable_existing_loggers from colossalai.logging import disable_existing_loggers
from colossalai.testing import assert_close, parameterize, rerun_if_address_is_in_use from colossalai.testing import assert_close, parameterize, rerun_if_address_is_in_use
@ -52,10 +54,11 @@ def check_binary_elementwise_handler_with_tensor(rank, op, other_dim, world_size
input_args=input_args, input_args=input_args,
meta_arg_names=meta_arg_names) meta_arg_names=meta_arg_names)
tracer = ColoTracer() tracer = ColoTracer(bias_addition_split=True)
meta_args = {'x1': torch.rand(4, 4).to('meta'), 'x2': torch.rand([4] * other_dim).to('meta')} meta_args = {'x1': torch.rand(4, 4).to('meta'), 'x2': torch.rand([4] * other_dim).to('meta')}
graph = tracer.trace(model, meta_args=meta_args) graph = tracer.trace(model, meta_args=meta_args)
gm = ColoGraphModule(model, graph) gm = ColoGraphModule(model, graph)
shape_prop_pass(gm, *meta_args.values())
op_node = list(graph.nodes)[2] op_node = list(graph.nodes)[2]
strategies_vector = StrategiesVector(op_node) strategies_vector = StrategiesVector(op_node)
@ -172,12 +175,11 @@ def check_binary_elementwise_handler_with_int(rank, op, other_dim, model_cls, wo
strategy_number=strategy_number, strategy_number=strategy_number,
input_args=input_args, input_args=input_args,
meta_arg_names=meta_arg_names) meta_arg_names=meta_arg_names)
tracer = ColoTracer() tracer = ColoTracer(bias_addition_split=True)
meta_args = {'x1': torch.rand(4, 4).to('meta')} meta_args = {'x1': torch.rand(4, 4).to('meta')}
graph = tracer.trace(model, meta_args=meta_args) graph = tracer.trace(model, meta_args=meta_args)
print(graph)
# assert False
gm = ColoGraphModule(model, graph) gm = ColoGraphModule(model, graph)
shape_prop_pass(gm, *meta_args.values())
if model_cls == BEOpModelWithNodeConst: if model_cls == BEOpModelWithNodeConst:
op_node = list(graph.nodes)[2] op_node = list(graph.nodes)[2]

View File

@ -5,10 +5,12 @@ import torch
import torch.multiprocessing as mp import torch.multiprocessing as mp
import torch.nn as nn import torch.nn as nn
from colossalai._analyzer.fx.graph_module import ColoGraphModule
from colossalai._analyzer.fx.passes.shape_prop import shape_prop_pass
from colossalai._analyzer.fx.tracer.tracer import ColoTracer
from colossalai.auto_parallel.tensor_shard.node_handler import BMMFunctionHandler from colossalai.auto_parallel.tensor_shard.node_handler import BMMFunctionHandler
from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType, StrategiesVector from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType, StrategiesVector
from colossalai.device.device_mesh import DeviceMesh from colossalai.device.device_mesh import DeviceMesh
from colossalai.fx import ColoGraphModule, ColoTracer
from colossalai.initialize import launch from colossalai.initialize import launch
from colossalai.logging import disable_existing_loggers from colossalai.logging import disable_existing_loggers
from colossalai.testing import assert_close, parameterize, rerun_if_address_is_in_use from colossalai.testing import assert_close, parameterize, rerun_if_address_is_in_use
@ -52,13 +54,11 @@ def check_2d_device_mesh(rank, module, world_size, port):
strategy_number=strategy_number, strategy_number=strategy_number,
input_args=input_args, input_args=input_args,
meta_arg_names=meta_arg_names) meta_arg_names=meta_arg_names)
tracer = ColoTracer() tracer = ColoTracer(bias_addition_split=True)
graph = tracer.trace(model, meta_args = {'x1': torch.rand(4, 8, 16).to('meta'), 'x2': torch.rand(4, 16, 8).to('meta')}
meta_args={ graph = tracer.trace(model, meta_args=meta_args)
"x1": torch.rand(4, 8, 16).to('meta'),
'x2': torch.rand(4, 16, 8).to('meta')
})
gm = ColoGraphModule(model, graph) gm = ColoGraphModule(model, graph)
shape_prop_pass(gm, *meta_args.values())
linear_mod_node = list(graph.nodes)[2] linear_mod_node = list(graph.nodes)[2]
strategies_vector = StrategiesVector(linear_mod_node) strategies_vector = StrategiesVector(linear_mod_node)
@ -147,13 +147,11 @@ def check_1d_device_mesh(rank, module, world_size, port):
strategy_number=strategy_number, strategy_number=strategy_number,
input_args=input_args, input_args=input_args,
meta_arg_names=meta_arg_names) meta_arg_names=meta_arg_names)
tracer = ColoTracer() tracer = ColoTracer(bias_addition_split=True)
graph = tracer.trace(model, meta_args = {'x1': torch.rand(4, 8, 16).to('meta'), 'x2': torch.rand(4, 16, 8).to('meta')}
meta_args={ graph = tracer.trace(model, meta_args=meta_args)
"x1": torch.rand(4, 8, 16).to('meta'),
'x2': torch.rand(4, 16, 8).to('meta')
})
gm = ColoGraphModule(model, graph) gm = ColoGraphModule(model, graph)
shape_prop_pass(gm, *meta_args.values())
linear_mod_node = list(graph.nodes)[2] linear_mod_node = list(graph.nodes)[2]
strategies_vector = StrategiesVector(linear_mod_node) strategies_vector = StrategiesVector(linear_mod_node)
@ -205,6 +203,7 @@ def check_1d_device_mesh(rank, module, world_size, port):
@run_on_environment_flag(name='AUTO_PARALLEL') @run_on_environment_flag(name='AUTO_PARALLEL')
@parameterize('module', [BMMTensorMethodModule, BMMTorchFunctionModule]) @parameterize('module', [BMMTensorMethodModule, BMMTorchFunctionModule])
@parameterize('module', [BMMTensorMethodModule, BMMTorchFunctionModule])
@pytest.mark.dist @pytest.mark.dist
@rerun_if_address_is_in_use() @rerun_if_address_is_in_use()
def test_bmm_handler(module): def test_bmm_handler(module):

View File

@ -5,10 +5,12 @@ import torch
import torch.multiprocessing as mp import torch.multiprocessing as mp
import torch.nn as nn import torch.nn as nn
from colossalai._analyzer.fx.graph_module import ColoGraphModule
from colossalai._analyzer.fx.passes.shape_prop import shape_prop_pass
from colossalai._analyzer.fx.tracer.tracer import ColoTracer
from colossalai.auto_parallel.tensor_shard.node_handler.conv_handler import ConvFunctionHandler, ConvModuleHandler from colossalai.auto_parallel.tensor_shard.node_handler.conv_handler import ConvFunctionHandler, ConvModuleHandler
from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType, StrategiesVector from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType, StrategiesVector
from colossalai.device.device_mesh import DeviceMesh from colossalai.device.device_mesh import DeviceMesh
from colossalai.fx import ColoGraphModule, ColoTracer
from colossalai.initialize import launch from colossalai.initialize import launch
from colossalai.logging import disable_existing_loggers from colossalai.logging import disable_existing_loggers
from colossalai.testing import assert_close, parameterize, rerun_if_address_is_in_use from colossalai.testing import assert_close, parameterize, rerun_if_address_is_in_use
@ -41,9 +43,11 @@ def check_conv_module_handler(rank, bias, world_size, port):
strategy_number=strategy_number, strategy_number=strategy_number,
input_args=[input], input_args=[input],
meta_arg_names=['input']) meta_arg_names=['input'])
tracer = ColoTracer() tracer = ColoTracer(bias_addition_split=True)
graph = tracer.trace(model, meta_args={"input": torch.rand(4, 4, 64, 64).to('meta')}) meta_args = {'input': torch.rand(4, 4, 64, 64).to('meta')}
graph = tracer.trace(model, meta_args=meta_args)
gm = ColoGraphModule(model, graph) gm = ColoGraphModule(model, graph)
shape_prop_pass(gm, *meta_args.values())
conv_mod_node = list(graph.nodes)[1] conv_mod_node = list(graph.nodes)[1]
strategies_vector = StrategiesVector(conv_mod_node) strategies_vector = StrategiesVector(conv_mod_node)
@ -178,7 +182,7 @@ def check_conv_function_handler(rank, bias, world_size, port):
meta_arg_names=meta_arg_names, meta_arg_names=meta_arg_names,
input_kwargs=input_kwargs) input_kwargs=input_kwargs)
tracer = ColoTracer() tracer = ColoTracer(bias_addition_split=True)
# graph(): # graph():
# %input_1 : torch.Tensor [#users=1] = placeholder[target=input] # %input_1 : torch.Tensor [#users=1] = placeholder[target=input]
# %others : torch.Tensor [#users=1] = placeholder[target=others] # %others : torch.Tensor [#users=1] = placeholder[target=others]
@ -189,6 +193,7 @@ def check_conv_function_handler(rank, bias, world_size, port):
meta_args['bias'] = torch.rand(16).to('meta') meta_args['bias'] = torch.rand(16).to('meta')
graph = tracer.trace(model, meta_args=meta_args) graph = tracer.trace(model, meta_args=meta_args)
gm = ColoGraphModule(model, graph) gm = ColoGraphModule(model, graph)
shape_prop_pass(gm, *meta_args.values())
if bias: if bias:
conv_mod_node = list(graph.nodes)[3] conv_mod_node = list(graph.nodes)[3]

View File

@ -1,11 +1,13 @@
import torch import torch
import torch.nn as nn import torch.nn as nn
from colossalai._analyzer.fx.graph_module import ColoGraphModule
from colossalai._analyzer.fx.passes.shape_prop import shape_prop_pass
from colossalai._analyzer.fx.tracer.tracer import ColoTracer
from colossalai.auto_parallel.tensor_shard.node_handler import DefaultReshapeHandler from colossalai.auto_parallel.tensor_shard.node_handler import DefaultReshapeHandler
from colossalai.auto_parallel.tensor_shard.node_handler.conv_handler import ConvFunctionHandler from colossalai.auto_parallel.tensor_shard.node_handler.conv_handler import ConvFunctionHandler
from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType, StrategiesVector from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType, StrategiesVector
from colossalai.device.device_mesh import DeviceMesh from colossalai.device.device_mesh import DeviceMesh
from colossalai.fx import ColoGraphModule, ColoTracer
from colossalai.testing.pytest_wrapper import run_on_environment_flag from colossalai.testing.pytest_wrapper import run_on_environment_flag
@ -23,19 +25,20 @@ class ReshapeModel(nn.Module):
@run_on_environment_flag(name='AUTO_PARALLEL') @run_on_environment_flag(name='AUTO_PARALLEL')
def test_reshape_handler(): def test_reshape_handler():
model = ReshapeModel() model = ReshapeModel()
tracer = ColoTracer() tracer = ColoTracer(bias_addition_split=True)
# graph(): # graph():
# %input_1 : torch.Tensor [#users=1] = placeholder[target=input] # %input_1 : torch.Tensor [#users=1] = placeholder[target=input]
# %other : torch.Tensor [#users=1] = placeholder[target=other] # %other : torch.Tensor [#users=1] = placeholder[target=other]
# %conv2d : [#users=1] = call_function[target=torch.conv2d](args = (%input_1, %other), kwargs = {}) # %conv2d : [#users=1] = call_function[target=torch.conv2d](args = (%input_1, %other), kwargs = {})
# %view : [#users=1] = call_method[target=view](args = (%conv2d, 2, -1), kwargs = {}) # %view : [#users=1] = call_method[target=view](args = (%conv2d, 2, -1), kwargs = {})
# return view # return view
graph = tracer.trace(model, meta_args = {
meta_args={ "input": torch.rand(4, 4, 64, 64).to('meta'),
"input": torch.rand(4, 4, 64, 64).to('meta'), "other": torch.rand(16, 4, 3, 3).to('meta'),
"other": torch.rand(4, 16, 3, 3).to('meta'), }
}) graph = tracer.trace(model, meta_args=meta_args)
gm = ColoGraphModule(model, graph) gm = ColoGraphModule(model, graph)
shape_prop_pass(gm, *meta_args.values())
physical_mesh_id = torch.arange(0, 4) physical_mesh_id = torch.arange(0, 4)
mesh_shape = (2, 2) mesh_shape = (2, 2)
@ -67,13 +70,13 @@ def test_reshape_handler():
assert mapping['input'].name == "conv2d" assert mapping['input'].name == "conv2d"
assert mapping['input'].data.is_meta assert mapping['input'].data.is_meta
assert mapping['input'].data.shape == torch.Size([4, 4, 62, 62]) assert mapping['input'].data.shape == torch.Size([4, 16, 62, 62])
assert mapping['input'].type == OperationDataType.ARG assert mapping['input'].type == OperationDataType.ARG
assert mapping['input'].logical_shape == torch.Size([4, 4, 62, 62]) assert mapping['input'].logical_shape == torch.Size([4, 16, 62, 62])
assert mapping['output'].name == "view" assert mapping['output'].name == "view"
assert mapping['output'].data.is_meta assert mapping['output'].data.is_meta
assert mapping['output'].data.shape == torch.Size([2, 30752]) assert mapping['output'].data.shape == torch.Size([2, 123008])
assert mapping['output'].type == OperationDataType.OUTPUT assert mapping['output'].type == OperationDataType.OUTPUT
# reshape handler is a following strategy handler, so the number of strategies is equal to the predecessor node. # reshape handler is a following strategy handler, so the number of strategies is equal to the predecessor node.

View File

@ -5,13 +5,15 @@ import torch
import torch.multiprocessing as mp import torch.multiprocessing as mp
import torch.nn as nn import torch.nn as nn
from colossalai._analyzer.fx.graph_module import ColoGraphModule
from colossalai._analyzer.fx.passes.shape_prop import shape_prop_pass
from colossalai._analyzer.fx.tracer.tracer import ColoTracer
from colossalai.auto_parallel.tensor_shard.node_handler.embedding_handler import ( from colossalai.auto_parallel.tensor_shard.node_handler.embedding_handler import (
EmbeddingFunctionHandler, EmbeddingFunctionHandler,
EmbeddingModuleHandler, EmbeddingModuleHandler,
) )
from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType, StrategiesVector from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType, StrategiesVector
from colossalai.device.device_mesh import DeviceMesh from colossalai.device.device_mesh import DeviceMesh
from colossalai.fx import ColoGraphModule, ColoTracer
from colossalai.initialize import launch from colossalai.initialize import launch
from colossalai.logging import disable_existing_loggers from colossalai.logging import disable_existing_loggers
from colossalai.testing import assert_close, parameterize, rerun_if_address_is_in_use from colossalai.testing import assert_close, parameterize, rerun_if_address_is_in_use
@ -60,9 +62,11 @@ def check_embedding_module_handler(rank, world_size, port):
input_args=[input], input_args=[input],
meta_arg_names=['input']) meta_arg_names=['input'])
tracer = ColoTracer() tracer = ColoTracer(bias_addition_split=True)
graph = tracer.trace(model, meta_args={"input": torch.rand(4, 16, 16).to('meta')}) meta_args = {"input": torch.randint(NUM_EMBEDDINGS, (4, 16, 16)).to('meta')}
graph = tracer.trace(model, meta_args=meta_args)
gm = ColoGraphModule(model, graph) gm = ColoGraphModule(model, graph)
shape_prop_pass(gm, *meta_args.values())
embedding_node = list(graph.nodes)[1] embedding_node = list(graph.nodes)[1]
strategies_vector = StrategiesVector(embedding_node) strategies_vector = StrategiesVector(embedding_node)
@ -171,18 +175,19 @@ def check_embedding_function_handler(rank, world_size, port):
input_args=input_args, input_args=input_args,
meta_arg_names=meta_arg_names, meta_arg_names=meta_arg_names,
input_kwargs=input_kwargs) input_kwargs=input_kwargs)
tracer = ColoTracer() tracer = ColoTracer(bias_addition_split=True)
# graph(): # graph():
# %input_1 : torch.Tensor [#users=1] = placeholder[target=input] # %input_1 : torch.Tensor [#users=1] = placeholder[target=input]
# %others : torch.Tensor [#users=1] = placeholder[target=others] # %others : torch.Tensor [#users=1] = placeholder[target=others]
# %embedding : [#users=1] = call_function[target=torch.nn.functional.embedding](args = (%input_1, %others), kwargs = {padding_idx: None, max_norm: None, norm_type: 2.0, scale_grad_by_freq: False, sparse: False}) # %embedding : [#users=1] = call_function[target=torch.nn.functional.embedding](args = (%input_1, %others), kwargs = {padding_idx: None, max_norm: None, norm_type: 2.0, scale_grad_by_freq: False, sparse: False})
# return embedding # return embedding
meta_args = { meta_args = {
"input": torch.rand(4, 16, 16).to('meta'), "input": torch.randint(NUM_EMBEDDINGS, (4, 16, 16)).to('meta'),
"others": torch.rand(NUM_EMBEDDINGS, EMBEDDING_DIMS).to('meta') "others": torch.rand(NUM_EMBEDDINGS, EMBEDDING_DIMS).to('meta')
} }
graph = tracer.trace(model, meta_args=meta_args) graph = tracer.trace(model, meta_args=meta_args)
gm = ColoGraphModule(model, graph) gm = ColoGraphModule(model, graph)
shape_prop_pass(gm, *meta_args.values())
embedding_node = list(graph.nodes)[2] embedding_node = list(graph.nodes)[2]
strategies_vector = StrategiesVector(embedding_node) strategies_vector = StrategiesVector(embedding_node)

View File

@ -1,10 +1,13 @@
import pytest
import torch import torch
import torch.nn as nn import torch.nn as nn
from colossalai._analyzer.fx.graph_module import ColoGraphModule
from colossalai._analyzer.fx.passes.shape_prop import shape_prop_pass
from colossalai._analyzer.fx.tracer.tracer import ColoTracer
from colossalai.auto_parallel.tensor_shard.node_handler.getattr_handler import GetattrHandler from colossalai.auto_parallel.tensor_shard.node_handler.getattr_handler import GetattrHandler
from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType, StrategiesVector from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType, StrategiesVector
from colossalai.device.device_mesh import DeviceMesh from colossalai.device.device_mesh import DeviceMesh
from colossalai.fx import ColoGraphModule, ColoTracer
class GetattrModel(nn.Module): class GetattrModel(nn.Module):
@ -18,15 +21,18 @@ class GetattrModel(nn.Module):
return weight return weight
@pytest.mark.skip('ShapeProp is not compatible with PyTorch 1.11.0')
def test_getattr_handler(): def test_getattr_handler():
model = GetattrModel() model = GetattrModel()
tracer = ColoTracer() tracer = ColoTracer(bias_addition_split=True)
# graph(): # graph():
# %input_1 : torch.Tensor [#users=0] = placeholder[target=input] # %input_1 : torch.Tensor [#users=0] = placeholder[target=input]
# %conv_weight : [#users=1] = get_attr[target=conv.weight] # %conv_weight : [#users=1] = get_attr[target=conv.weight]
# return conv_weight # return conv_weight
graph = tracer.trace(model, meta_args={'input': torch.rand(4, 4, 64, 64).to('meta')}) meta_args = {'input': torch.rand(4, 4, 64, 64).to('meta')}
graph = tracer.trace(model, meta_args=meta_args)
gm = ColoGraphModule(model, graph) gm = ColoGraphModule(model, graph)
shape_prop_pass(gm, *meta_args.values())
physical_mesh_id = torch.arange(0, 4) physical_mesh_id = torch.arange(0, 4)
mesh_shape = (2, 2) mesh_shape = (2, 2)
device_mesh = DeviceMesh(physical_mesh_id, mesh_shape) device_mesh = DeviceMesh(physical_mesh_id, mesh_shape)

View File

@ -5,13 +5,15 @@ import torch
import torch.multiprocessing as mp import torch.multiprocessing as mp
import torch.nn as nn import torch.nn as nn
from colossalai._analyzer.fx.graph_module import ColoGraphModule
from colossalai._analyzer.fx.passes.shape_prop import shape_prop_pass
from colossalai._analyzer.fx.tracer.tracer import ColoTracer
from colossalai.auto_parallel.tensor_shard.node_handler.default_reshape_handler import DefaultReshapeHandler from colossalai.auto_parallel.tensor_shard.node_handler.default_reshape_handler import DefaultReshapeHandler
from colossalai.auto_parallel.tensor_shard.node_handler.getitem_handler import GetItemHandler from colossalai.auto_parallel.tensor_shard.node_handler.getitem_handler import GetItemHandler
from colossalai.auto_parallel.tensor_shard.node_handler.linear_handler import LinearFunctionHandler from colossalai.auto_parallel.tensor_shard.node_handler.linear_handler import LinearFunctionHandler
from colossalai.auto_parallel.tensor_shard.node_handler.placeholder_handler import PlaceholderHandler from colossalai.auto_parallel.tensor_shard.node_handler.placeholder_handler import PlaceholderHandler
from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType, StrategiesVector from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType, StrategiesVector
from colossalai.device.device_mesh import DeviceMesh from colossalai.device.device_mesh import DeviceMesh
from colossalai.fx import ColoGraphModule, ColoTracer
from colossalai.fx.tracer.meta_patch.patched_module import linear from colossalai.fx.tracer.meta_patch.patched_module import linear
from colossalai.initialize import launch from colossalai.initialize import launch
from colossalai.logging import disable_existing_loggers from colossalai.logging import disable_existing_loggers
@ -58,15 +60,15 @@ def check_getitem_from_tensor_handler(rank, getitem_index, world_size, port):
meta_arg_names=['input', 'other'], meta_arg_names=['input', 'other'],
node_type='following') node_type='following')
tracer = ColoTracer() tracer = ColoTracer(bias_addition_split=True)
meta_args = {
graph = tracer.trace(model, "input": torch.rand(8, 16, 64, 32).to('meta'),
meta_args={ "other": torch.rand(64, 32).to('meta'),
"input": torch.rand(8, 16, 64, 32).to('meta'), }
"other": torch.rand(64, 32).to('meta'), graph = tracer.trace(model, meta_args=meta_args)
})
gm = ColoGraphModule(model, graph) gm = ColoGraphModule(model, graph)
shape_prop_pass(gm, *list(meta_args.values()))
linear_mod_node = list(graph.nodes)[2] linear_mod_node = list(graph.nodes)[2]
getitem_mod_node = list(graph.nodes)[3] getitem_mod_node = list(graph.nodes)[3]
getitem_strategies_vector = StrategiesVector(getitem_mod_node) getitem_strategies_vector = StrategiesVector(getitem_mod_node)
@ -129,10 +131,12 @@ def test_getitem_from_tuple_handler():
# %split : [#users=1] = call_function[target=torch.functional.split](args = (%conv2d, 2), kwargs = {dim: 0}) # %split : [#users=1] = call_function[target=torch.functional.split](args = (%conv2d, 2), kwargs = {dim: 0})
# %getitem : [#users=1] = call_function[target=operator.getitem](args = (%split, 1), kwargs = {}) # %getitem : [#users=1] = call_function[target=operator.getitem](args = (%split, 1), kwargs = {})
# return getitem # return getitem
graph = tracer.trace(model, meta_args={ meta_args = {
"input": torch.rand(4, 4, 64, 64).to('meta'), "input": torch.rand(4, 4, 64, 64).to('meta'),
}) }
graph = tracer.trace(model, meta_args=meta_args)
gm = ColoGraphModule(model, graph) gm = ColoGraphModule(model, graph)
shape_prop_pass(gm, *meta_args.values())
physical_mesh_id = torch.arange(0, 4) physical_mesh_id = torch.arange(0, 4)
mesh_shape = (2, 2) mesh_shape = (2, 2)

View File

@ -5,10 +5,12 @@ import torch
import torch.multiprocessing as mp import torch.multiprocessing as mp
import torch.nn as nn import torch.nn as nn
from colossalai._analyzer.fx.graph_module import ColoGraphModule
from colossalai._analyzer.fx.passes.shape_prop import shape_prop_pass
from colossalai._analyzer.fx.tracer.tracer import ColoTracer
from colossalai.auto_parallel.tensor_shard.node_handler.layer_norm_handler import LayerNormModuleHandler from colossalai.auto_parallel.tensor_shard.node_handler.layer_norm_handler import LayerNormModuleHandler
from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType, StrategiesVector from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType, StrategiesVector
from colossalai.device.device_mesh import DeviceMesh from colossalai.device.device_mesh import DeviceMesh
from colossalai.fx import ColoGraphModule, ColoTracer
from colossalai.fx.tracer.meta_patch.patched_module import linear from colossalai.fx.tracer.meta_patch.patched_module import linear
from colossalai.initialize import launch from colossalai.initialize import launch
from colossalai.logging import disable_existing_loggers from colossalai.logging import disable_existing_loggers
@ -40,13 +42,15 @@ def check_ln_module_handler(rank, world_size, port):
strategy_number=strategy_number, strategy_number=strategy_number,
input_args=input_args, input_args=input_args,
meta_arg_names=meta_arg_names) meta_arg_names=meta_arg_names)
tracer = ColoTracer() tracer = ColoTracer(bias_addition_split=True)
# graph(): # graph():
# %input_1 : torch.Tensor [#users=1] = placeholder[target=input] # %input_1 : torch.Tensor [#users=1] = placeholder[target=input]
# %_0 : [#users=1] = call_module[target=0](args = (%input_1,), kwargs = {}) # %_0 : [#users=1] = call_module[target=0](args = (%input_1,), kwargs = {})
# return _0 # return _0
graph = tracer.trace(model, meta_args={"input": torch.rand(4, 16).to('meta')}) meta_args = {"input": torch.rand(4, 16).to('meta')}
graph = tracer.trace(model, meta_args=meta_args)
gm = ColoGraphModule(model, graph) gm = ColoGraphModule(model, graph)
shape_prop_pass(gm, *meta_args.values())
ln_mod_node = list(graph.nodes)[1] ln_mod_node = list(graph.nodes)[1]
strategies_vector = StrategiesVector(ln_mod_node) strategies_vector = StrategiesVector(ln_mod_node)

View File

@ -5,6 +5,9 @@ import torch
import torch.multiprocessing as mp import torch.multiprocessing as mp
import torch.nn as nn import torch.nn as nn
from colossalai._analyzer.fx.graph_module import ColoGraphModule
from colossalai._analyzer.fx.passes.shape_prop import shape_prop_pass
from colossalai._analyzer.fx.tracer.tracer import ColoTracer
from colossalai.auto_parallel.tensor_shard.node_handler import LinearFunctionHandler, LinearModuleHandler from colossalai.auto_parallel.tensor_shard.node_handler import LinearFunctionHandler, LinearModuleHandler
from colossalai.auto_parallel.tensor_shard.sharding_strategy import ( from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
OperationData, OperationData,
@ -13,7 +16,6 @@ from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
StrategiesVector, StrategiesVector,
) )
from colossalai.device.device_mesh import DeviceMesh from colossalai.device.device_mesh import DeviceMesh
from colossalai.fx import ColoGraphModule, ColoTracer
from colossalai.initialize import launch from colossalai.initialize import launch
from colossalai.logging import disable_existing_loggers from colossalai.logging import disable_existing_loggers
from colossalai.testing import assert_close, parameterize, rerun_if_address_is_in_use from colossalai.testing import assert_close, parameterize, rerun_if_address_is_in_use
@ -49,9 +51,11 @@ def check_linear_module_handler(rank, bias, input_shape, world_size, port):
input_args=input_args, input_args=input_args,
meta_arg_names=meta_arg_names) meta_arg_names=meta_arg_names)
tracer = ColoTracer() tracer = ColoTracer(bias_addition_split=True)
graph = tracer.trace(model, meta_args={"input": torch.rand(input_shape).to('meta')}) meta_args = {"input": torch.rand(input_shape).cuda()}
graph = tracer.trace(model, meta_args=meta_args)
gm = ColoGraphModule(model, graph) gm = ColoGraphModule(model, graph)
shape_prop_pass(gm, *meta_args.values())
linear_mod_node = list(graph.nodes)[1] linear_mod_node = list(graph.nodes)[1]
strategies_vector = StrategiesVector(linear_mod_node) strategies_vector = StrategiesVector(linear_mod_node)
@ -196,13 +200,12 @@ def check_linear_function_handler(rank, bias, input_shape, world_size, port):
input_args=input_args, input_args=input_args,
meta_arg_names=meta_arg_names) meta_arg_names=meta_arg_names)
tracer = ColoTracer() tracer = ColoTracer(bias_addition_split=True)
graph = tracer.trace(model, meta_args = {'input': torch.rand(input_shape).to('meta'), 'others': torch.rand(32, 16).to('meta')}
meta_args={ graph = tracer.trace(model, meta_args=meta_args)
"input": torch.rand(input_shape).to('meta'),
'others': torch.rand(32, 16).to('meta')
})
gm = ColoGraphModule(model, graph) gm = ColoGraphModule(model, graph)
shape_prop_pass(gm, *meta_args.values())
if bias: if bias:
linear_func_node = list(graph.nodes)[3] linear_func_node = list(graph.nodes)[3]
else: else:

View File

@ -2,6 +2,9 @@ import pytest
import torch import torch
import torch.nn as nn import torch.nn as nn
from colossalai._analyzer.fx.graph_module import ColoGraphModule
from colossalai._analyzer.fx.passes.shape_prop import shape_prop_pass
from colossalai._analyzer.fx.tracer.tracer import ColoTracer
from colossalai.auto_parallel.tensor_shard.node_handler.matmul_handler import ( from colossalai.auto_parallel.tensor_shard.node_handler.matmul_handler import (
MatMulHandler, MatMulHandler,
MatMulType, MatMulType,
@ -15,7 +18,6 @@ from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
StrategiesVector, StrategiesVector,
) )
from colossalai.device.device_mesh import DeviceMesh from colossalai.device.device_mesh import DeviceMesh
from colossalai.fx import ColoGraphModule, ColoTracer
from colossalai.testing.utils import parameterize from colossalai.testing.utils import parameterize
@ -57,9 +59,11 @@ def test_matmul_node_handler(tensor_shapes):
model = MatMulModule() model = MatMulModule()
tracer = ColoTracer() tracer = ColoTracer(bias_addition_split=True)
graph = tracer.trace(model, meta_args={"x1": x1.to('meta'), 'x2': x2.to('meta')}) meta_args = {"x1": x1.to('meta'), 'x2': x2.to('meta')}
graph = tracer.trace(model, meta_args=meta_args)
gm = ColoGraphModule(model, graph) gm = ColoGraphModule(model, graph)
shape_prop_pass(gm, *meta_args.values())
physical_mesh_id = torch.arange(0, 4) physical_mesh_id = torch.arange(0, 4)
print(graph) print(graph)
@ -124,7 +128,6 @@ def test_matmul_node_handler(tensor_shapes):
input_sharding_spec = strategy.get_sharding_spec_by_name('x1') input_sharding_spec = strategy.get_sharding_spec_by_name('x1')
other_sharding_spec = strategy.get_sharding_spec_by_name('x2') other_sharding_spec = strategy.get_sharding_spec_by_name('x2')
output_sharding_spec = strategy.get_sharding_spec_by_name('matmul') output_sharding_spec = strategy.get_sharding_spec_by_name('matmul')
if matmul_type == MatMulType.DOT: if matmul_type == MatMulType.DOT:
# dot product will produce a scaler # dot product will produce a scaler
# results should fulfill: # results should fulfill:
@ -159,7 +162,10 @@ def test_matmul_node_handler(tensor_shapes):
if len(other_shape) > 1: if len(other_shape) > 1:
assert other_sharding_spec.sharding_sequence[-1] == output_sharding_spec.sharding_sequence[-1] assert other_sharding_spec.sharding_sequence[-1] == output_sharding_spec.sharding_sequence[-1]
if len(input_shape) > 1: if len(input_shape) > 1:
assert input_sharding_spec.sharding_sequence[-2] == output_sharding_spec.sharding_sequence[-2] if len(other_shape) == 1:
assert input_sharding_spec.sharding_sequence[-2] == output_sharding_spec.sharding_sequence[-1]
else:
assert input_sharding_spec.sharding_sequence[-2] == output_sharding_spec.sharding_sequence[-2]
if len(other_shape) > 2: if len(other_shape) > 2:
assert other_sharding_spec.sharding_sequence[-2] == input_sharding_spec.sharding_sequence[-1] assert other_sharding_spec.sharding_sequence[-2] == input_sharding_spec.sharding_sequence[-1]

View File

@ -2,10 +2,12 @@ import pytest
import torch import torch
import torch.nn as nn import torch.nn as nn
from colossalai._analyzer.fx.graph_module import ColoGraphModule
from colossalai._analyzer.fx.passes.shape_prop import shape_prop_pass
from colossalai._analyzer.fx.tracer.tracer import ColoTracer
from colossalai.auto_parallel.tensor_shard.node_handler.normal_pooling_handler import NormPoolingHandler from colossalai.auto_parallel.tensor_shard.node_handler.normal_pooling_handler import NormPoolingHandler
from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType, StrategiesVector from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType, StrategiesVector
from colossalai.device.device_mesh import DeviceMesh from colossalai.device.device_mesh import DeviceMesh
from colossalai.fx import ColoGraphModule, ColoTracer
from colossalai.fx.tracer.meta_patch.patched_module import linear from colossalai.fx.tracer.meta_patch.patched_module import linear
from colossalai.testing.pytest_wrapper import run_on_environment_flag from colossalai.testing.pytest_wrapper import run_on_environment_flag
@ -13,14 +15,16 @@ from colossalai.testing.pytest_wrapper import run_on_environment_flag
@run_on_environment_flag(name='AUTO_PARALLEL') @run_on_environment_flag(name='AUTO_PARALLEL')
def test_norm_pool_handler(): def test_norm_pool_handler():
model = nn.Sequential(nn.MaxPool2d(4, padding=1).to('meta')) model = nn.Sequential(nn.MaxPool2d(4, padding=1).to('meta'))
tracer = ColoTracer() tracer = ColoTracer(bias_addition_split=True)
# graph(): # graph():
# %input_1 : torch.Tensor [#users=1] = placeholder[target=input] # %input_1 : torch.Tensor [#users=1] = placeholder[target=input]
# %_0 : [#users=1] = call_module[target=0](args = (%input_1,), kwargs = {}) # %_0 : [#users=1] = call_module[target=0](args = (%input_1,), kwargs = {})
# return _0 # return _0
graph = tracer.trace(model, meta_args={"input": torch.rand(4, 4, 64, 64).to('meta')}) meta_args = {"input": torch.rand(4, 4, 64, 64).to('meta')}
graph = tracer.trace(model, meta_args=meta_args)
gm = ColoGraphModule(model, graph) gm = ColoGraphModule(model, graph)
shape_prop_pass(gm, *meta_args.values())
physical_mesh_id = torch.arange(0, 4) physical_mesh_id = torch.arange(0, 4)
mesh_shape = (2, 2) mesh_shape = (2, 2)

View File

@ -1,10 +1,13 @@
import pytest
import torch import torch
import torch.nn as nn import torch.nn as nn
from colossalai._analyzer.fx.graph_module import ColoGraphModule
from colossalai._analyzer.fx.passes.shape_prop import shape_prop_pass
from colossalai._analyzer.fx.tracer.tracer import ColoTracer
from colossalai.auto_parallel.tensor_shard.node_handler.output_handler import OutputHandler from colossalai.auto_parallel.tensor_shard.node_handler.output_handler import OutputHandler
from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType, StrategiesVector from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType, StrategiesVector
from colossalai.device.device_mesh import DeviceMesh from colossalai.device.device_mesh import DeviceMesh
from colossalai.fx import ColoGraphModule, ColoTracer
from colossalai.testing import assert_close, parameterize, rerun_if_address_is_in_use from colossalai.testing import assert_close, parameterize, rerun_if_address_is_in_use
@ -18,19 +21,20 @@ class OutputModel(nn.Module):
return x, y return x, y
@pytest.mark.skip('ShapeProp is not compatible with PyTorch 1.11.0')
@parameterize('output_option', ['distributed', 'replicated']) @parameterize('output_option', ['distributed', 'replicated'])
@rerun_if_address_is_in_use() @rerun_if_address_is_in_use()
def test_output_handler(output_option): def test_output_handler(output_option):
model = OutputModel() model = OutputModel()
tracer = ColoTracer() tracer = ColoTracer(bias_addition_split=True)
# graph(): # graph():
# %x : torch.Tensor [#users=2] = placeholder[target=x] # %x : torch.Tensor [#users=2] = placeholder[target=x]
# %mul : [#users=1] = call_function[target=operator.mul](args = (%x, 2), kwargs = {}) # %mul : [#users=1] = call_function[target=operator.mul](args = (%x, 2), kwargs = {})
# return (x, mul) # return (x, mul)
graph = tracer.trace(model, meta_args={ meta_args = {'x': torch.rand(4, 4, 64, 64).to('meta')}
"x": torch.rand(4, 4, 64, 64).to('meta'), graph = tracer.trace(model, meta_args=meta_args)
})
gm = ColoGraphModule(model, graph) gm = ColoGraphModule(model, graph)
shape_prop_pass(gm, *meta_args.values())
physical_mesh_id = torch.arange(0, 4) physical_mesh_id = torch.arange(0, 4)
mesh_shape = (2, 2) mesh_shape = (2, 2)

View File

@ -5,12 +5,14 @@ import torch
import torch.multiprocessing as mp import torch.multiprocessing as mp
import torch.nn as nn import torch.nn as nn
from colossalai._analyzer.fx.graph_module import ColoGraphModule
from colossalai._analyzer.fx.passes.shape_prop import shape_prop_pass
from colossalai._analyzer.fx.tracer.tracer import ColoTracer
from colossalai.auto_parallel.tensor_shard.node_handler import PermuteHandler, TransposeHandler from colossalai.auto_parallel.tensor_shard.node_handler import PermuteHandler, TransposeHandler
from colossalai.auto_parallel.tensor_shard.node_handler.conv_handler import ConvFunctionHandler from colossalai.auto_parallel.tensor_shard.node_handler.conv_handler import ConvFunctionHandler
from colossalai.auto_parallel.tensor_shard.node_handler.linear_handler import LinearFunctionHandler from colossalai.auto_parallel.tensor_shard.node_handler.linear_handler import LinearFunctionHandler
from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType, StrategiesVector from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType, StrategiesVector
from colossalai.device.device_mesh import DeviceMesh from colossalai.device.device_mesh import DeviceMesh
from colossalai.fx import ColoGraphModule, ColoTracer
from colossalai.initialize import launch from colossalai.initialize import launch
from colossalai.logging import disable_existing_loggers from colossalai.logging import disable_existing_loggers
from colossalai.testing import assert_close, parameterize, rerun_if_address_is_in_use from colossalai.testing import assert_close, parameterize, rerun_if_address_is_in_use
@ -88,7 +90,7 @@ def check_view_handler(rank, call_function, reshape_dims, model_cls, world_size,
input_args=[input, other], input_args=[input, other],
meta_arg_names=['input', 'other'], meta_arg_names=['input', 'other'],
node_type='following') node_type='following')
tracer = ColoTracer() tracer = ColoTracer(bias_addition_split=True)
if model_cls.__name__ == 'ConvReshapeModel': if model_cls.__name__ == 'ConvReshapeModel':
# graph(): # graph():
# %input_1 : torch.Tensor [#users=1] = placeholder[target=input] # %input_1 : torch.Tensor [#users=1] = placeholder[target=input]
@ -96,11 +98,11 @@ def check_view_handler(rank, call_function, reshape_dims, model_cls, world_size,
# %conv2d : [#users=1] = call_function[target=torch.conv2d](args = (%input_1, %other), kwargs = {bias: None}) # %conv2d : [#users=1] = call_function[target=torch.conv2d](args = (%input_1, %other), kwargs = {bias: None})
# %permute : [#users=1] = call_function[target=torch.permute](args = (%conv2d, (0, 2, 1, 3)), kwargs = {}) # %permute : [#users=1] = call_function[target=torch.permute](args = (%conv2d, (0, 2, 1, 3)), kwargs = {})
# return permute # return permute
graph = tracer.trace(model, meta_args = {
meta_args={ 'input': torch.rand(8, 8, 66, 66).to('meta'),
"input": torch.rand(8, 8, 66, 66).to('meta'), 'other': torch.rand(16, 8, 3, 3).to('meta'),
"other": torch.rand(16, 8, 3, 3).to('meta'), }
}) graph = tracer.trace(model, meta_args=meta_args)
if model_cls.__name__ == 'LinearReshapeModel': if model_cls.__name__ == 'LinearReshapeModel':
# graph(): # graph():
@ -109,13 +111,14 @@ def check_view_handler(rank, call_function, reshape_dims, model_cls, world_size,
# %linear : [#users=1] = call_function[target=torch._C._nn.linear](args = (%input_1, %other), kwargs = {bias: None}) # %linear : [#users=1] = call_function[target=torch._C._nn.linear](args = (%input_1, %other), kwargs = {bias: None})
# %permute : [#users=1] = call_method[target=view](args = (%linear, 32, 4, 32, 32, 4), kwargs = {}) # %permute : [#users=1] = call_method[target=view](args = (%linear, 32, 4, 32, 32, 4), kwargs = {})
# return permute # return permute
graph = tracer.trace(model, meta_args = {
meta_args={ 'input': torch.rand(8, 16, 64, 32).to('meta'),
"input": torch.rand(8, 16, 64, 32).to('meta'), 'other': torch.rand(64, 32).to('meta'),
"other": torch.rand(64, 32).to('meta'), }
}) graph = tracer.trace(model, meta_args=meta_args)
gm = ColoGraphModule(model, graph) gm = ColoGraphModule(model, graph)
shape_prop_pass(gm, *meta_args.values())
previous_mod_node = list(graph.nodes)[2] previous_mod_node = list(graph.nodes)[2]
reshape_node = list(graph.nodes)[3] reshape_node = list(graph.nodes)[3]

View File

@ -1,10 +1,13 @@
import pytest
import torch import torch
import torch.nn as nn import torch.nn as nn
from colossalai._analyzer.fx.graph_module import ColoGraphModule
from colossalai._analyzer.fx.passes.shape_prop import shape_prop_pass
from colossalai._analyzer.fx.tracer.tracer import ColoTracer
from colossalai.auto_parallel.tensor_shard.node_handler.placeholder_handler import PlaceholderHandler from colossalai.auto_parallel.tensor_shard.node_handler.placeholder_handler import PlaceholderHandler
from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType, StrategiesVector from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType, StrategiesVector
from colossalai.device.device_mesh import DeviceMesh from colossalai.device.device_mesh import DeviceMesh
from colossalai.fx import ColoGraphModule, ColoTracer
from colossalai.testing import assert_close, parameterize, rerun_if_address_is_in_use from colossalai.testing import assert_close, parameterize, rerun_if_address_is_in_use
@ -17,18 +20,21 @@ class PlaceholderModel(nn.Module):
return input return input
@pytest.mark.skip('ShapeProp is not compatible with PyTorch 1.11.0')
@parameterize('placeholder_option', ['distributed', 'replicated']) @parameterize('placeholder_option', ['distributed', 'replicated'])
@rerun_if_address_is_in_use() @rerun_if_address_is_in_use()
def test_placeholder_handler(placeholder_option): def test_placeholder_handler(placeholder_option):
model = PlaceholderModel() model = PlaceholderModel()
tracer = ColoTracer() tracer = ColoTracer(bias_addition_split=True)
# graph(): # graph():
# %input_1 : torch.Tensor [#users=1] = placeholder[target=input] # %input_1 : torch.Tensor [#users=1] = placeholder[target=input]
# return input_1 # return input_1
graph = tracer.trace(model, meta_args={ meta_args = {
"input": torch.rand(4, 4, 64, 64).to('meta'), "input": torch.rand(4, 4, 64, 64).to('meta'),
}) }
graph = tracer.trace(model, meta_args=meta_args)
gm = ColoGraphModule(model, graph) gm = ColoGraphModule(model, graph)
shape_prop_pass(gm, *meta_args.values())
physical_mesh_id = torch.arange(0, 4) physical_mesh_id = torch.arange(0, 4)
mesh_shape = (2, 2) mesh_shape = (2, 2)

View File

@ -1,17 +1,15 @@
from functools import partial
import torch import torch
import torch.multiprocessing as mp import torch.multiprocessing as mp
import torch.nn as nn import torch.nn as nn
from colossalai._analyzer.fx.graph_module import ColoGraphModule
from colossalai._analyzer.fx.passes.shape_prop import shape_prop_pass
from colossalai._analyzer.fx.tracer.tracer import ColoTracer
from colossalai.auto_parallel.tensor_shard.node_handler import LinearFunctionHandler from colossalai.auto_parallel.tensor_shard.node_handler import LinearFunctionHandler
from colossalai.auto_parallel.tensor_shard.options import ShardOption from colossalai.auto_parallel.tensor_shard.options import ShardOption
from colossalai.auto_parallel.tensor_shard.sharding_strategy import StrategiesVector from colossalai.auto_parallel.tensor_shard.sharding_strategy import StrategiesVector
from colossalai.device.device_mesh import DeviceMesh from colossalai.device.device_mesh import DeviceMesh
from colossalai.fx import ColoGraphModule, ColoTracer
from colossalai.testing import parameterize
from colossalai.testing.pytest_wrapper import run_on_environment_flag from colossalai.testing.pytest_wrapper import run_on_environment_flag
from colossalai.testing.utils import parameterize
class LinearModel(nn.Module): class LinearModel(nn.Module):
@ -30,13 +28,11 @@ def check_shard_option(shard_option):
mesh_shape = (2, 2) mesh_shape = (2, 2)
device_mesh = DeviceMesh(physical_mesh_id, mesh_shape) device_mesh = DeviceMesh(physical_mesh_id, mesh_shape)
tracer = ColoTracer() tracer = ColoTracer(bias_addition_split=True)
graph = tracer.trace(model, meta_args = {'input': torch.rand(4, 4, 4, 16).to('meta'), 'others': torch.rand(32, 16).to('meta')}
meta_args={ graph = tracer.trace(model, meta_args=meta_args)
"input": torch.rand(4, 4, 4, 16).to('meta'),
'others': torch.rand(32, 16).to('meta')
})
gm = ColoGraphModule(model, graph) gm = ColoGraphModule(model, graph)
shape_prop_pass(gm, *meta_args.values())
linear_func_node = list(graph.nodes)[2] linear_func_node = list(graph.nodes)[2]
strategies_vector = StrategiesVector(linear_func_node) strategies_vector = StrategiesVector(linear_func_node)

View File

@ -6,11 +6,13 @@ import torch.multiprocessing as mp
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
from colossalai._analyzer.fx.graph_module import ColoGraphModule
from colossalai._analyzer.fx.passes.shape_prop import shape_prop_pass
from colossalai._analyzer.fx.tracer.tracer import ColoTracer
from colossalai.auto_parallel.tensor_shard.node_handler.linear_handler import LinearFunctionHandler from colossalai.auto_parallel.tensor_shard.node_handler.linear_handler import LinearFunctionHandler
from colossalai.auto_parallel.tensor_shard.node_handler.softmax_handler import SoftmaxHandler from colossalai.auto_parallel.tensor_shard.node_handler.softmax_handler import SoftmaxHandler
from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType, StrategiesVector from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType, StrategiesVector
from colossalai.device.device_mesh import DeviceMesh from colossalai.device.device_mesh import DeviceMesh
from colossalai.fx import ColoGraphModule, ColoTracer
from colossalai.initialize import launch from colossalai.initialize import launch
from colossalai.logging import disable_existing_loggers from colossalai.logging import disable_existing_loggers
from colossalai.testing import assert_close, parameterize, rerun_if_address_is_in_use from colossalai.testing import assert_close, parameterize, rerun_if_address_is_in_use
@ -54,7 +56,7 @@ def check_split_handler(rank, softmax_dim, model_cls, world_size, port):
input_args=[input, other], input_args=[input, other],
meta_arg_names=['input', 'other'], meta_arg_names=['input', 'other'],
node_type='following') node_type='following')
tracer = ColoTracer() tracer = ColoTracer(bias_addition_split=True)
# graph(): # graph():
# %input_1 : torch.Tensor [#users=1] = placeholder[target=input] # %input_1 : torch.Tensor [#users=1] = placeholder[target=input]
@ -62,13 +64,14 @@ def check_split_handler(rank, softmax_dim, model_cls, world_size, port):
# %linear : [#users=1] = call_function[target=torch._C._nn.linear](args = (%input_1, %other), kwargs = {bias: None}) # %linear : [#users=1] = call_function[target=torch._C._nn.linear](args = (%input_1, %other), kwargs = {bias: None})
# %softmax : [#users=1] = call_method[target=split](args = (%linear,), kwargs = {}) # %softmax : [#users=1] = call_method[target=split](args = (%linear,), kwargs = {})
# return split # return split
graph = tracer.trace(model, meta_args = {
meta_args={ 'input': torch.rand(8, 16, 64, 32).to('meta'),
"input": torch.rand(8, 16, 64, 32).to('meta'), 'other': torch.rand(64, 32).to('meta'),
"other": torch.rand(64, 32).to('meta'), }
}) graph = tracer.trace(model, meta_args=meta_args)
gm = ColoGraphModule(model, graph) gm = ColoGraphModule(model, graph)
shape_prop_pass(gm, *meta_args.values())
previous_mod_node = list(graph.nodes)[2] previous_mod_node = list(graph.nodes)[2]
split_node = list(graph.nodes)[3] split_node = list(graph.nodes)[3]

View File

@ -5,12 +5,14 @@ import torch
import torch.multiprocessing as mp import torch.multiprocessing as mp
import torch.nn as nn import torch.nn as nn
from colossalai._analyzer.fx.graph_module import ColoGraphModule
from colossalai._analyzer.fx.passes.shape_prop import shape_prop_pass
from colossalai._analyzer.fx.tracer.tracer import ColoTracer
from colossalai.auto_parallel.tensor_shard.node_handler import SplitHandler from colossalai.auto_parallel.tensor_shard.node_handler import SplitHandler
from colossalai.auto_parallel.tensor_shard.node_handler.conv_handler import ConvFunctionHandler from colossalai.auto_parallel.tensor_shard.node_handler.conv_handler import ConvFunctionHandler
from colossalai.auto_parallel.tensor_shard.node_handler.linear_handler import LinearFunctionHandler from colossalai.auto_parallel.tensor_shard.node_handler.linear_handler import LinearFunctionHandler
from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType, StrategiesVector from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType, StrategiesVector
from colossalai.device.device_mesh import DeviceMesh from colossalai.device.device_mesh import DeviceMesh
from colossalai.fx import ColoGraphModule, ColoTracer
from colossalai.initialize import launch from colossalai.initialize import launch
from colossalai.logging import disable_existing_loggers from colossalai.logging import disable_existing_loggers
from colossalai.testing import assert_close, parameterize, rerun_if_address_is_in_use from colossalai.testing import assert_close, parameterize, rerun_if_address_is_in_use
@ -76,7 +78,7 @@ def check_split_handler(rank, split_size, split_dim, model_cls, world_size, port
input_args=[input, other], input_args=[input, other],
meta_arg_names=['input', 'other'], meta_arg_names=['input', 'other'],
node_type='following') node_type='following')
tracer = ColoTracer() tracer = ColoTracer(bias_addition_split=True)
if model_cls.__name__ == 'ConvSplitModel': if model_cls.__name__ == 'ConvSplitModel':
# graph(): # graph():
# %input_1 : torch.Tensor [#users=1] = placeholder[target=input] # %input_1 : torch.Tensor [#users=1] = placeholder[target=input]
@ -84,11 +86,11 @@ def check_split_handler(rank, split_size, split_dim, model_cls, world_size, port
# %conv2d : [#users=1] = call_function[target=torch.conv2d](args = (%input_1, %other), kwargs = {}) # %conv2d : [#users=1] = call_function[target=torch.conv2d](args = (%input_1, %other), kwargs = {})
# %split : [#users=1] = call_method[target=split](args = (%conv2d,), kwargs = {}) # %split : [#users=1] = call_method[target=split](args = (%conv2d,), kwargs = {})
# return split # return split
graph = tracer.trace(model, meta_args = {
meta_args={ 'input': torch.rand(8, 8, 66, 66).to('meta'),
"input": torch.rand(8, 8, 66, 66).to('meta'), 'other': torch.rand(16, 8, 3, 3).to('meta'),
"other": torch.rand(16, 8, 3, 3).to('meta'), }
}) graph = tracer.trace(model, meta_args=meta_args)
if model_cls.__name__ == 'LinearSplitModel': if model_cls.__name__ == 'LinearSplitModel':
# graph(): # graph():
@ -97,13 +99,14 @@ def check_split_handler(rank, split_size, split_dim, model_cls, world_size, port
# %linear : [#users=1] = call_function[target=torch._C._nn.linear](args = (%input_1, %other), kwargs = {bias: None}) # %linear : [#users=1] = call_function[target=torch._C._nn.linear](args = (%input_1, %other), kwargs = {bias: None})
# %split : [#users=1] = call_method[target=split](args = (%linear,), kwargs = {}) # %split : [#users=1] = call_method[target=split](args = (%linear,), kwargs = {})
# return split # return split
graph = tracer.trace(model, meta_args = {
meta_args={ 'input': torch.rand(8, 16, 64, 32).to('meta'),
"input": torch.rand(8, 16, 64, 32).to('meta'), 'other': torch.rand(64, 32).to('meta'),
"other": torch.rand(64, 32).to('meta'), }
}) graph = tracer.trace(model, meta_args=meta_args)
gm = ColoGraphModule(model, graph) gm = ColoGraphModule(model, graph)
shape_prop_pass(gm, *meta_args.values())
previous_mod_node = list(graph.nodes)[2] previous_mod_node = list(graph.nodes)[2]
split_node = list(graph.nodes)[3] split_node = list(graph.nodes)[3]

View File

@ -5,12 +5,13 @@ import torch
import torch.multiprocessing as mp import torch.multiprocessing as mp
import torch.nn as nn import torch.nn as nn
from colossalai.auto_parallel.tensor_shard.node_handler.conv_handler import ConvFunctionHandler from colossalai._analyzer.fx.graph_module import ColoGraphModule
from colossalai._analyzer.fx.passes.shape_prop import shape_prop_pass
from colossalai._analyzer.fx.tracer.tracer import ColoTracer
from colossalai.auto_parallel.tensor_shard.node_handler.linear_handler import LinearFunctionHandler from colossalai.auto_parallel.tensor_shard.node_handler.linear_handler import LinearFunctionHandler
from colossalai.auto_parallel.tensor_shard.node_handler.sum_handler import SumHandler from colossalai.auto_parallel.tensor_shard.node_handler.sum_handler import SumHandler
from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType, StrategiesVector from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType, StrategiesVector
from colossalai.device.device_mesh import DeviceMesh from colossalai.device.device_mesh import DeviceMesh
from colossalai.fx import ColoGraphModule, ColoTracer
from colossalai.initialize import launch from colossalai.initialize import launch
from colossalai.logging import disable_existing_loggers from colossalai.logging import disable_existing_loggers
from colossalai.testing import assert_close, parameterize, rerun_if_address_is_in_use from colossalai.testing import assert_close, parameterize, rerun_if_address_is_in_use
@ -58,7 +59,7 @@ def check_sum_handler(rank, sum_dims, keepdim, world_size, port):
meta_arg_names=['input', 'other'], meta_arg_names=['input', 'other'],
node_type='following') node_type='following')
tracer = ColoTracer() tracer = ColoTracer(bias_addition_split=True)
# graph(): # graph():
# %input_1 : torch.Tensor [#users=1] = placeholder[target=input] # %input_1 : torch.Tensor [#users=1] = placeholder[target=input]
@ -66,12 +67,13 @@ def check_sum_handler(rank, sum_dims, keepdim, world_size, port):
# %linear : [#users=1] = call_function[target=torch._C._nn.linear](args = (%input_1, %other), kwargs = {bias: None}) # %linear : [#users=1] = call_function[target=torch._C._nn.linear](args = (%input_1, %other), kwargs = {bias: None})
# %sum_1 : [#users=1] = call_function[target=torch.sum](args = (%linear,), kwargs = {}) # %sum_1 : [#users=1] = call_function[target=torch.sum](args = (%linear,), kwargs = {})
# return sum_1 # return sum_1
graph = tracer.trace(model, meta_args = {
meta_args={ "input": torch.rand(8, 16, 64, 32).to('meta'),
"input": torch.rand(8, 16, 64, 32).to('meta'), "other": torch.rand(64, 32).to('meta'),
"other": torch.rand(64, 32).to('meta'), }
}) graph = tracer.trace(model, meta_args=meta_args)
gm = ColoGraphModule(model, graph) gm = ColoGraphModule(model, graph)
shape_prop_pass(gm, *meta_args.values())
previous_mod_node = list(graph.nodes)[2] previous_mod_node = list(graph.nodes)[2]
sum_node = list(graph.nodes)[3] sum_node = list(graph.nodes)[3]
@ -116,107 +118,107 @@ def check_sum_handler(rank, sum_dims, keepdim, world_size, port):
# check strategy name # check strategy name
if sum_dims == (0, 2) and keepdim == False: if sum_dims == (0, 2) and keepdim == False:
assert '[R, R, R, S1] -> [R, S1]_0' in strategy_name_list assert '[R, R, R, R] -> [R, R]_0' in strategy_name_list
assert '[R, S0, R, S1] -> [S0, S1]_1' in strategy_name_list assert '[R, S01, R, R] -> [S01, R]_1' in strategy_name_list
assert '[R, R, R, S1] -> [R, S1]_2' in strategy_name_list assert '[R, R, R, R] -> [R, R]_2' in strategy_name_list
assert '[R, R, R, S0] -> [R, S0]_3' in strategy_name_list assert '[R, R, R, R] -> [R, R]_3' in strategy_name_list
assert '[R, S1, R, S0] -> [S1, S0]_4' in strategy_name_list assert '[R, R, R, S01] -> [R, S01]_4' in strategy_name_list
assert '[R, R, R, S0] -> [R, S0]_5' in strategy_name_list assert '[R, R, R, S1] -> [R, S1]_5' in strategy_name_list
assert '[R, R, R, R] -> [R, R]_6' in strategy_name_list assert '[R, R, R, S0] -> [R, S0]_6' in strategy_name_list
assert '[R, S0, R, R] -> [S0, R]_7' in strategy_name_list assert '[R, R, R, R] -> [R, R]_7' in strategy_name_list
assert '[R, R, R, R] -> [R, R]_8' in strategy_name_list assert '[R, R, R, R] -> [R, R]_8' in strategy_name_list
assert '[R, R, R, R] -> [R, R]_9' in strategy_name_list assert '[R, R, R, S0] -> [R, S0]_9' in strategy_name_list
assert '[R, S1, R, R] -> [S1, R]_10' in strategy_name_list assert '[R, R, R, S1] -> [R, S1]_10' in strategy_name_list
assert '[R, R, R, R] -> [R, R]_11' in strategy_name_list assert '[R, R, R, S1] -> [R, S1]_11' in strategy_name_list
assert '[R, R, R, S1] -> [R, S1]_12' in strategy_name_list assert '[R, S0, R, S1] -> [S0, S1]_12' in strategy_name_list
assert '[R, R, R, S0] -> [R, S0]_13' in strategy_name_list assert '[R, R, R, S1] -> [R, S1]_13' in strategy_name_list
assert '[R, R, R, R] -> [R, R]_14' in strategy_name_list assert '[R, R, R, S0] -> [R, S0]_14' in strategy_name_list
assert '[R, R, R, R] -> [R, R]_15' in strategy_name_list assert '[R, S1, R, S0] -> [S1, S0]_15' in strategy_name_list
assert '[R, R, R, S0] -> [R, S0]_16' in strategy_name_list assert '[R, R, R, S0] -> [R, S0]_16' in strategy_name_list
assert '[R, R, R, S1] -> [R, S1]_17' in strategy_name_list assert '[R, R, R, R] -> [R, R]_17' in strategy_name_list
assert '[R, R, R, R] -> [R, R]_18' in strategy_name_list assert '[R, S0, R, R] -> [S0, R]_18' in strategy_name_list
assert '[R, S01, R, R] -> [S01, R]_19' in strategy_name_list assert '[R, R, R, R] -> [R, R]_19' in strategy_name_list
assert '[R, R, R, R] -> [R, R]_20' in strategy_name_list assert '[R, R, R, R] -> [R, R]_20' in strategy_name_list
assert '[R, R, R, R] -> [R, R]_21' in strategy_name_list assert '[R, S1, R, R] -> [S1, R]_21' in strategy_name_list
assert '[R, R, R, S01] -> [R, S01]_22' in strategy_name_list assert '[R, R, R, R] -> [R, R]_22' in strategy_name_list
assert '[R, R, R, R] -> [R, R]_23' in strategy_name_list assert '[R, R, R, R] -> [R, R]_23' in strategy_name_list
if sum_dims == (0, 2) and keepdim == True: if sum_dims == (0, 2) and keepdim == True:
assert '[R, R, R, S1] -> [R, R, R, S1]_0' in strategy_name_list assert '[R, R, R, R] -> [R, R, R, R]_0' in strategy_name_list
assert '[R, S0, R, S1] -> [R, S0, R, S1]_1' in strategy_name_list assert '[R, S01, R, R] -> [R, S01, R, R]_1' in strategy_name_list
assert '[R, R, R, S1] -> [R, R, R, S1]_2' in strategy_name_list assert '[R, R, R, R] -> [R, R, R, R]_2' in strategy_name_list
assert '[R, R, R, S0] -> [R, R, R, S0]_3' in strategy_name_list assert '[R, R, R, R] -> [R, R, R, R]_3' in strategy_name_list
assert '[R, S1, R, S0] -> [R, S1, R, S0]_4' in strategy_name_list assert '[R, R, R, S01] -> [R, R, R, S01]_4' in strategy_name_list
assert '[R, R, R, S0] -> [R, R, R, S0]_5' in strategy_name_list assert '[R, R, R, S1] -> [R, R, R, S1]_5' in strategy_name_list
assert '[R, R, R, R] -> [R, R, R, R]_6' in strategy_name_list assert '[R, R, R, S0] -> [R, R, R, S0]_6' in strategy_name_list
assert '[R, S0, R, R] -> [R, S0, R, R]_7' in strategy_name_list assert '[R, R, R, R] -> [R, R, R, R]_7' in strategy_name_list
assert '[R, R, R, R] -> [R, R, R, R]_8' in strategy_name_list assert '[R, R, R, R] -> [R, R, R, R]_8' in strategy_name_list
assert '[R, R, R, R] -> [R, R, R, R]_9' in strategy_name_list assert '[R, R, R, S0] -> [R, R, R, S0]_9' in strategy_name_list
assert '[R, S1, R, R] -> [R, S1, R, R]_10' in strategy_name_list assert '[R, R, R, S1] -> [R, R, R, S1]_10' in strategy_name_list
assert '[R, R, R, R] -> [R, R, R, R]_11' in strategy_name_list assert '[R, R, R, S1] -> [R, R, R, S1]_11' in strategy_name_list
assert '[R, R, R, S1] -> [R, R, R, S1]_12' in strategy_name_list assert '[R, S0, R, S1] -> [R, S0, R, S1]_12' in strategy_name_list
assert '[R, R, R, S0] -> [R, R, R, S0]_13' in strategy_name_list assert '[R, R, R, S1] -> [R, R, R, S1]_13' in strategy_name_list
assert '[R, R, R, R] -> [R, R, R, R]_14' in strategy_name_list assert '[R, R, R, S0] -> [R, R, R, S0]_14' in strategy_name_list
assert '[R, R, R, R] -> [R, R, R, R]_15' in strategy_name_list assert '[R, S1, R, S0] -> [R, S1, R, S0]_15' in strategy_name_list
assert '[R, R, R, S0] -> [R, R, R, S0]_16' in strategy_name_list assert '[R, R, R, S0] -> [R, R, R, S0]_16' in strategy_name_list
assert '[R, R, R, S1] -> [R, R, R, S1]_17' in strategy_name_list assert '[R, R, R, R] -> [R, R, R, R]_17' in strategy_name_list
assert '[R, R, R, R] -> [R, R, R, R]_18' in strategy_name_list assert '[R, S0, R, R] -> [R, S0, R, R]_18' in strategy_name_list
assert '[R, S01, R, R] -> [R, S01, R, R]_19' in strategy_name_list assert '[R, R, R, R] -> [R, R, R, R]_19' in strategy_name_list
assert '[R, R, R, R] -> [R, R, R, R]_20' in strategy_name_list assert '[R, R, R, R] -> [R, R, R, R]_20' in strategy_name_list
assert '[R, R, R, R] -> [R, R, R, R]_21' in strategy_name_list assert '[R, S1, R, R] -> [R, S1, R, R]_21' in strategy_name_list
assert '[R, R, R, S01] -> [R, R, R, S01]_22' in strategy_name_list assert '[R, R, R, R] -> [R, R, R, R]_22' in strategy_name_list
assert '[R, R, R, R] -> [R, R, R, R]_23' in strategy_name_list assert '[R, R, R, R] -> [R, R, R, R]_23' in strategy_name_list
if sum_dims == 1 and keepdim == False: if sum_dims == 1 and keepdim == False:
assert '[S0, R, R, S1] -> [S0, R, S1]_0' in strategy_name_list assert '[S01, R, R, R] -> [S01, R, R]_0' in strategy_name_list
assert '[R, R, R, S1] -> [R, R, S1]_1' in strategy_name_list assert '[R, R, R, R] -> [R, R, R]_1' in strategy_name_list
assert '[R, R, S0, S1] -> [R, S0, S1]_2' in strategy_name_list assert '[R, R, S01, R] -> [R, S01, R]_2' in strategy_name_list
assert '[S1, R, R, S0] -> [S1, R, S0]_3' in strategy_name_list assert '[R, R, R, R] -> [R, R, R]_3' in strategy_name_list
assert '[R, R, R, S0] -> [R, R, S0]_4' in strategy_name_list assert '[R, R, R, S01] -> [R, R, S01]_4' in strategy_name_list
assert '[R, R, S1, S0] -> [R, S1, S0]_5' in strategy_name_list assert '[R, R, R, S1] -> [R, R, S1]_5' in strategy_name_list
assert '[S0, R, R, R] -> [S0, R, R]_6' in strategy_name_list assert '[R, R, R, S0] -> [R, R, S0]_6' in strategy_name_list
assert '[R, R, R, R] -> [R, R, R]_7' in strategy_name_list assert '[R, R, R, R] -> [R, R, R]_7' in strategy_name_list
assert '[R, R, S0, R] -> [R, S0, R]_8' in strategy_name_list assert '[R, R, R, R] -> [R, R, R]_8' in strategy_name_list
assert '[S1, R, R, R] -> [S1, R, R]_9' in strategy_name_list assert '[R, R, R, S0] -> [R, R, S0]_9' in strategy_name_list
assert '[R, R, R, R] -> [R, R, R]_10' in strategy_name_list assert '[R, R, R, S1] -> [R, R, S1]_10' in strategy_name_list
assert '[R, R, S1, R] -> [R, S1, R]_11' in strategy_name_list assert '[S0, R, R, S1] -> [S0, R, S1]_11' in strategy_name_list
assert '[R, R, R, S1] -> [R, R, S1]_12' in strategy_name_list assert '[R, R, R, S1] -> [R, R, S1]_12' in strategy_name_list
assert '[R, R, R, S0] -> [R, R, S0]_13' in strategy_name_list assert '[R, R, S0, S1] -> [R, S0, S1]_13' in strategy_name_list
assert '[R, R, R, R] -> [R, R, R]_14' in strategy_name_list assert '[S1, R, R, S0] -> [S1, R, S0]_14' in strategy_name_list
assert '[R, R, R, R] -> [R, R, R]_15' in strategy_name_list assert '[R, R, R, S0] -> [R, R, S0]_15' in strategy_name_list
assert '[R, R, R, S0] -> [R, R, S0]_16' in strategy_name_list assert '[R, R, S1, S0] -> [R, S1, S0]_16' in strategy_name_list
assert '[R, R, R, S1] -> [R, R, S1]_17' in strategy_name_list assert '[S0, R, R, R] -> [S0, R, R]_17' in strategy_name_list
assert '[S01, R, R, R] -> [S01, R, R]_18' in strategy_name_list assert '[R, R, R, R] -> [R, R, R]_18' in strategy_name_list
assert '[R, R, R, R] -> [R, R, R]_19' in strategy_name_list assert '[R, R, S0, R] -> [R, S0, R]_19' in strategy_name_list
assert '[R, R, S01, R] -> [R, S01, R]_20' in strategy_name_list assert '[S1, R, R, R] -> [S1, R, R]_20' in strategy_name_list
assert '[R, R, R, R] -> [R, R, R]_21' in strategy_name_list assert '[R, R, R, R] -> [R, R, R]_21' in strategy_name_list
assert '[R, R, R, S01] -> [R, R, S01]_22' in strategy_name_list assert '[R, R, S1, R] -> [R, S1, R]_22' in strategy_name_list
assert '[R, R, R, R] -> [R, R, R]_23' in strategy_name_list assert '[R, R, R, R] -> [R, R, R]_23' in strategy_name_list
if sum_dims == 1 and keepdim == True: if sum_dims == 1 and keepdim == True:
assert '[S0, R, R, S1] -> [S0, R, R, S1]_0' in strategy_name_list assert '[S01, R, R, R] -> [S01, R, R, R]_0' in strategy_name_list
assert '[R, R, R, S1] -> [R, R, R, S1]_1' in strategy_name_list assert '[R, R, R, R] -> [R, R, R, R]_1' in strategy_name_list
assert '[R, R, S0, S1] -> [R, R, S0, S1]_2' in strategy_name_list assert '[R, R, S01, R] -> [R, R, S01, R]_2' in strategy_name_list
assert '[S1, R, R, S0] -> [S1, R, R, S0]_3' in strategy_name_list assert '[R, R, R, R] -> [R, R, R, R]_3' in strategy_name_list
assert '[R, R, R, S0] -> [R, R, R, S0]_4' in strategy_name_list assert '[R, R, R, S01] -> [R, R, R, S01]_4' in strategy_name_list
assert '[R, R, S1, S0] -> [R, R, S1, S0]_5' in strategy_name_list assert '[R, R, R, S1] -> [R, R, R, S1]_5' in strategy_name_list
assert '[S0, R, R, R] -> [S0, R, R, R]_6' in strategy_name_list assert '[R, R, R, S0] -> [R, R, R, S0]_6' in strategy_name_list
assert '[R, R, R, R] -> [R, R, R, R]_7' in strategy_name_list assert '[R, R, R, R] -> [R, R, R, R]_7' in strategy_name_list
assert '[R, R, S0, R] -> [R, R, S0, R]_8' in strategy_name_list assert '[R, R, R, R] -> [R, R, R, R]_8' in strategy_name_list
assert '[S1, R, R, R] -> [S1, R, R, R]_9' in strategy_name_list assert '[R, R, R, S0] -> [R, R, R, S0]_9' in strategy_name_list
assert '[R, R, R, R] -> [R, R, R, R]_10' in strategy_name_list assert '[R, R, R, S1] -> [R, R, R, S1]_10' in strategy_name_list
assert '[R, R, S1, R] -> [R, R, S1, R]_11' in strategy_name_list assert '[S0, R, R, S1] -> [S0, R, R, S1]_11' in strategy_name_list
assert '[R, R, R, S1] -> [R, R, R, S1]_12' in strategy_name_list assert '[R, R, R, S1] -> [R, R, R, S1]_12' in strategy_name_list
assert '[R, R, R, S0] -> [R, R, R, S0]_13' in strategy_name_list assert '[R, R, S0, S1] -> [R, R, S0, S1]_13' in strategy_name_list
assert '[R, R, R, R] -> [R, R, R, R]_14' in strategy_name_list assert '[S1, R, R, S0] -> [S1, R, R, S0]_14' in strategy_name_list
assert '[R, R, R, R] -> [R, R, R, R]_15' in strategy_name_list assert '[R, R, R, S0] -> [R, R, R, S0]_15' in strategy_name_list
assert '[R, R, R, S0] -> [R, R, R, S0]_16' in strategy_name_list assert '[R, R, S1, S0] -> [R, R, S1, S0]_16' in strategy_name_list
assert '[R, R, R, S1] -> [R, R, R, S1]_17' in strategy_name_list assert '[S0, R, R, R] -> [S0, R, R, R]_17' in strategy_name_list
assert '[S01, R, R, R] -> [S01, R, R, R]_18' in strategy_name_list assert '[R, R, R, R] -> [R, R, R, R]_18' in strategy_name_list
assert '[R, R, R, R] -> [R, R, R, R]_19' in strategy_name_list assert '[R, R, S0, R] -> [R, R, S0, R]_19' in strategy_name_list
assert '[R, R, S01, R] -> [R, R, S01, R]_20' in strategy_name_list assert '[S1, R, R, R] -> [S1, R, R, R]_20' in strategy_name_list
assert '[R, R, R, R] -> [R, R, R, R]_21' in strategy_name_list assert '[R, R, R, R] -> [R, R, R, R]_21' in strategy_name_list
assert '[R, R, R, S01] -> [R, R, R, S01]_22' in strategy_name_list assert '[R, R, S1, R] -> [R, R, S1, R]_22' in strategy_name_list
assert '[R, R, R, R] -> [R, R, R, R]_23' in strategy_name_list assert '[R, R, R, R] -> [R, R, R, R]_23' in strategy_name_list

View File

@ -1,10 +1,12 @@
import torch import torch
import torch.nn as nn import torch.nn as nn
from colossalai._analyzer.fx.graph_module import ColoGraphModule
from colossalai._analyzer.fx.passes.shape_prop import shape_prop_pass
from colossalai._analyzer.fx.tracer.tracer import ColoTracer
from colossalai.auto_parallel.tensor_shard.node_handler.tensor_constructor_handler import TensorConstructorHandler from colossalai.auto_parallel.tensor_shard.node_handler.tensor_constructor_handler import TensorConstructorHandler
from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType, StrategiesVector from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType, StrategiesVector
from colossalai.device.device_mesh import DeviceMesh from colossalai.device.device_mesh import DeviceMesh
from colossalai.fx import ColoGraphModule, ColoTracer
from colossalai.testing.pytest_wrapper import run_on_environment_flag from colossalai.testing.pytest_wrapper import run_on_environment_flag
@ -22,7 +24,7 @@ class TensorConstructorModel(nn.Module):
@run_on_environment_flag(name='AUTO_PARALLEL') @run_on_environment_flag(name='AUTO_PARALLEL')
def test_where_handler(): def test_where_handler():
model = TensorConstructorModel() model = TensorConstructorModel()
tracer = ColoTracer() tracer = ColoTracer(bias_addition_split=True)
# graph(): # graph():
# %x : torch.Tensor [#users=2] = placeholder[target=x] # %x : torch.Tensor [#users=2] = placeholder[target=x]
# %size : [#users=1] = call_method[target=size](args = (%x,), kwargs = {}) # %size : [#users=1] = call_method[target=size](args = (%x,), kwargs = {})
@ -30,10 +32,10 @@ def test_where_handler():
# %arange : [#users=1] = call_function[target=torch.arange](args = (%getitem,), kwargs = {}) # %arange : [#users=1] = call_function[target=torch.arange](args = (%getitem,), kwargs = {})
# %add : [#users=1] = call_function[target=operator.add](args = (%x, %arange), kwargs = {}) # %add : [#users=1] = call_function[target=operator.add](args = (%x, %arange), kwargs = {})
# return add # return add
graph = tracer.trace(model, meta_args={ meta_args = {'x': torch.rand(10).to('meta')}
"x": torch.rand(10).to('meta'), graph = tracer.trace(model, meta_args=meta_args)
})
gm = ColoGraphModule(model, graph) gm = ColoGraphModule(model, graph)
shape_prop_pass(gm, *meta_args.values())
physical_mesh_id = torch.arange(0, 4) physical_mesh_id = torch.arange(0, 4)
mesh_shape = (2, 2) mesh_shape = (2, 2)

View File

@ -1,12 +1,13 @@
import torch import torch
import torch.nn as nn import torch.nn as nn
from colossalai._analyzer.fx.graph_module import ColoGraphModule
from colossalai._analyzer.fx.passes.shape_prop import shape_prop_pass
from colossalai._analyzer.fx.tracer.tracer import ColoTracer
from colossalai.auto_parallel.tensor_shard.node_handler.conv_handler import ConvFunctionHandler from colossalai.auto_parallel.tensor_shard.node_handler.conv_handler import ConvFunctionHandler
from colossalai.auto_parallel.tensor_shard.node_handler.unary_elementwise_handler import UnaryElementwiseHandler from colossalai.auto_parallel.tensor_shard.node_handler.unary_elementwise_handler import UnaryElementwiseHandler
from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType, StrategiesVector from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType, StrategiesVector
from colossalai.device.device_mesh import DeviceMesh from colossalai.device.device_mesh import DeviceMesh
from colossalai.fx import ColoGraphModule, ColoTracer
from colossalai.fx.tracer.meta_patch.patched_module import linear
from colossalai.testing.pytest_wrapper import run_on_environment_flag from colossalai.testing.pytest_wrapper import run_on_environment_flag
@ -25,19 +26,20 @@ class ReLuModel(nn.Module):
@run_on_environment_flag(name='AUTO_PARALLEL') @run_on_environment_flag(name='AUTO_PARALLEL')
def test_elementwise_handler(): def test_elementwise_handler():
model = ReLuModel() model = ReLuModel()
tracer = ColoTracer() tracer = ColoTracer(bias_addition_split=True)
# graph(): # graph():
# %input_1 : torch.Tensor [#users=1] = placeholder[target=input] # %input_1 : torch.Tensor [#users=1] = placeholder[target=input]
# %other : torch.Tensor [#users=1] = placeholder[target=other] # %other : torch.Tensor [#users=1] = placeholder[target=other]
# %conv2d : [#users=1] = call_function[target=torch.conv2d](args = (%input_1, %other), kwargs = {}) # %conv2d : [#users=1] = call_function[target=torch.conv2d](args = (%input_1, %other), kwargs = {})
# %act : [#users=1] = call_module[target=act](args = (%conv2d,), kwargs = {}) # %act : [#users=1] = call_module[target=act](args = (%conv2d,), kwargs = {})
# return act # return act
graph = tracer.trace(model, meta_args = {
meta_args={ 'input': torch.rand(4, 4, 64, 64).to('meta'),
"input": torch.rand(4, 4, 64, 64).to('meta'), 'other': torch.rand(16, 4, 3, 3).to('meta'),
"other": torch.rand(4, 16, 3, 3).to('meta'), }
}) graph = tracer.trace(model, meta_args=meta_args)
gm = ColoGraphModule(model, graph) gm = ColoGraphModule(model, graph)
shape_prop_pass(gm, *meta_args.values())
physical_mesh_id = torch.arange(0, 4) physical_mesh_id = torch.arange(0, 4)
mesh_shape = (2, 2) mesh_shape = (2, 2)
@ -69,13 +71,13 @@ def test_elementwise_handler():
assert mapping['input'].name == "conv2d" assert mapping['input'].name == "conv2d"
assert mapping['input'].data.is_meta assert mapping['input'].data.is_meta
assert mapping['input'].data.shape == torch.Size([4, 4, 62, 62]) assert mapping['input'].data.shape == torch.Size([4, 16, 62, 62])
assert mapping['input'].type == OperationDataType.ARG assert mapping['input'].type == OperationDataType.ARG
assert mapping['input'].logical_shape == torch.Size([4, 4, 62, 62]) assert mapping['input'].logical_shape == torch.Size([4, 16, 62, 62])
assert mapping['output'].name == "act" assert mapping['output'].name == "act"
assert mapping['output'].data.is_meta assert mapping['output'].data.is_meta
assert mapping['output'].data.shape == torch.Size([4, 4, 62, 62]) assert mapping['output'].data.shape == torch.Size([4, 16, 62, 62])
assert mapping['output'].type == OperationDataType.OUTPUT assert mapping['output'].type == OperationDataType.OUTPUT
# getitem is a following strategy handler, so the number of strategies is equal to the predecessor node. # getitem is a following strategy handler, so the number of strategies is equal to the predecessor node.

View File

@ -5,12 +5,14 @@ import torch
import torch.multiprocessing as mp import torch.multiprocessing as mp
import torch.nn as nn import torch.nn as nn
from colossalai._analyzer.fx.graph_module import ColoGraphModule
from colossalai._analyzer.fx.passes.shape_prop import shape_prop_pass
from colossalai._analyzer.fx.tracer.tracer import ColoTracer
from colossalai.auto_parallel.tensor_shard.node_handler import ViewHandler from colossalai.auto_parallel.tensor_shard.node_handler import ViewHandler
from colossalai.auto_parallel.tensor_shard.node_handler.conv_handler import ConvFunctionHandler from colossalai.auto_parallel.tensor_shard.node_handler.conv_handler import ConvFunctionHandler
from colossalai.auto_parallel.tensor_shard.node_handler.linear_handler import LinearFunctionHandler from colossalai.auto_parallel.tensor_shard.node_handler.linear_handler import LinearFunctionHandler
from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType, StrategiesVector from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType, StrategiesVector
from colossalai.device.device_mesh import DeviceMesh from colossalai.device.device_mesh import DeviceMesh
from colossalai.fx import ColoGraphModule, ColoTracer
from colossalai.initialize import launch from colossalai.initialize import launch
from colossalai.logging import disable_existing_loggers from colossalai.logging import disable_existing_loggers
from colossalai.testing import assert_close, parameterize, rerun_if_address_is_in_use from colossalai.testing import assert_close, parameterize, rerun_if_address_is_in_use
@ -74,7 +76,7 @@ def check_view_handler(rank, tgt_shape, model_cls, world_size, port):
input_args=[input, other], input_args=[input, other],
meta_arg_names=['input', 'other'], meta_arg_names=['input', 'other'],
node_type='following') node_type='following')
tracer = ColoTracer() tracer = ColoTracer(bias_addition_split=True)
if model_cls.__name__ == 'ConvViewModel': if model_cls.__name__ == 'ConvViewModel':
# graph(): # graph():
# %input_1 : torch.Tensor [#users=1] = placeholder[target=input] # %input_1 : torch.Tensor [#users=1] = placeholder[target=input]
@ -82,11 +84,8 @@ def check_view_handler(rank, tgt_shape, model_cls, world_size, port):
# %conv2d : [#users=1] = call_function[target=torch.conv2d](args = (%input_1, %other), kwargs = {}) # %conv2d : [#users=1] = call_function[target=torch.conv2d](args = (%input_1, %other), kwargs = {})
# %view : [#users=1] = call_method[target=view](args = (%conv2d, 2, -1), kwargs = {}) # %view : [#users=1] = call_method[target=view](args = (%conv2d, 2, -1), kwargs = {})
# return view # return view
graph = tracer.trace(model, meta_args = {'input': torch.rand(8, 8, 66, 66).to('meta'), 'other': torch.rand(16, 8, 3, 3).to('meta')}
meta_args={ graph = tracer.trace(model, meta_args=meta_args)
"input": torch.rand(8, 8, 66, 66).to('meta'),
"other": torch.rand(16, 8, 3, 3).to('meta'),
})
if model_cls.__name__ == 'LinearViewModel': if model_cls.__name__ == 'LinearViewModel':
# graph(): # graph():
@ -95,13 +94,14 @@ def check_view_handler(rank, tgt_shape, model_cls, world_size, port):
# %linear : [#users=1] = call_function[target=torch._C._nn.linear](args = (%input_1, %other), kwargs = {bias: None}) # %linear : [#users=1] = call_function[target=torch._C._nn.linear](args = (%input_1, %other), kwargs = {bias: None})
# %view : [#users=1] = call_method[target=view](args = (%linear, 32, 4, 32, 32, 4), kwargs = {}) # %view : [#users=1] = call_method[target=view](args = (%linear, 32, 4, 32, 32, 4), kwargs = {})
# return view # return view
graph = tracer.trace(model, meta_args = {
meta_args={ 'input': torch.rand(8, 16, 64, 32).to('meta'),
"input": torch.rand(8, 16, 64, 32).to('meta'), 'other': torch.rand(64, 32).to('meta'),
"other": torch.rand(64, 32).to('meta'), }
}) graph = tracer.trace(model, meta_args=meta_args)
gm = ColoGraphModule(model, graph) gm = ColoGraphModule(model, graph)
shape_prop_pass(gm, *meta_args.values())
previous_mod_node = list(graph.nodes)[2] previous_mod_node = list(graph.nodes)[2]
view_node = list(graph.nodes)[3] view_node = list(graph.nodes)[3]

View File

@ -1,12 +1,13 @@
import pytest
import torch import torch
import torch.nn as nn import torch.nn as nn
from colossalai.auto_parallel.tensor_shard.node_handler.where_handler import \ from colossalai._analyzer.fx.graph_module import ColoGraphModule
WhereHandler from colossalai._analyzer.fx.passes.shape_prop import shape_prop_pass
from colossalai.auto_parallel.tensor_shard.sharding_strategy import (OperationData, OperationDataType, StrategiesVector) from colossalai._analyzer.fx.tracer.tracer import ColoTracer
from colossalai.auto_parallel.tensor_shard.node_handler.where_handler import WhereHandler
from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType, StrategiesVector
from colossalai.device.device_mesh import DeviceMesh from colossalai.device.device_mesh import DeviceMesh
from colossalai.fx import ColoGraphModule, ColoTracer
from colossalai.fx.tracer.meta_patch.patched_module import linear
class ConvModel(nn.Module): class ConvModel(nn.Module):
@ -19,22 +20,24 @@ class ConvModel(nn.Module):
return output return output
@pytest.mark.skip('ShapeProp is not compatible with PyTorch 1.11.0')
def test_where_handler(): def test_where_handler():
model = ConvModel() model = ConvModel()
tracer = ColoTracer() tracer = ColoTracer(bias_addition_split=True)
# graph(): # graph():
# %condition : torch.Tensor [#users=1] = placeholder[target=condition] # %condition : torch.Tensor [#users=1] = placeholder[target=condition]
# %x : torch.Tensor [#users=1] = placeholder[target=x] # %x : torch.Tensor [#users=1] = placeholder[target=x]
# %y : torch.Tensor [#users=1] = placeholder[target=y] # %y : torch.Tensor [#users=1] = placeholder[target=y]
# %where : [#users=1] = call_function[target=torch.where](args = (%condition, %x, %y), kwargs = {}) # %where : [#users=1] = call_function[target=torch.where](args = (%condition, %x, %y), kwargs = {})
# return where # return where
graph = tracer.trace(model, meta_args = {
meta_args={ 'condition': torch.rand(4, 4, 64, 64).to('meta'),
"condition": torch.rand(4, 4, 64, 64).to('meta'), 'x': torch.rand(4, 1, 64, 64).to('meta'),
"x": torch.rand(4, 1, 64, 64).to('meta'), 'y': torch.rand(1, 4, 64, 64).to('meta')
"y": torch.rand(1, 4, 64, 64).to('meta') }
}) graph = tracer.trace(model, meta_args=meta_args)
gm = ColoGraphModule(model, graph) gm = ColoGraphModule(model, graph)
shape_prop_pass(gm, *meta_args.values())
physical_mesh_id = torch.arange(0, 4) physical_mesh_id = torch.arange(0, 4)
mesh_shape = (2, 2) mesh_shape = (2, 2)

View File

@ -4,6 +4,9 @@ from typing import Dict, List
import torch import torch
from torch.fx import GraphModule from torch.fx import GraphModule
from colossalai._analyzer.fx.graph_module import ColoGraphModule
from colossalai._analyzer.fx.passes.shape_prop import shape_prop_pass
from colossalai._analyzer.fx.tracer.tracer import ColoTracer
from colossalai.auto_parallel.passes.runtime_apply_pass import runtime_apply_pass from colossalai.auto_parallel.passes.runtime_apply_pass import runtime_apply_pass
from colossalai.auto_parallel.passes.runtime_preparation_pass import runtime_preparation_pass from colossalai.auto_parallel.passes.runtime_preparation_pass import runtime_preparation_pass
from colossalai.auto_parallel.tensor_shard.options import SolverOptions from colossalai.auto_parallel.tensor_shard.options import SolverOptions
@ -11,7 +14,6 @@ from colossalai.auto_parallel.tensor_shard.solver import StrategiesConstructor
from colossalai.auto_parallel.tensor_shard.solver.cost_graph import CostGraph from colossalai.auto_parallel.tensor_shard.solver.cost_graph import CostGraph
from colossalai.auto_parallel.tensor_shard.solver.solver import Solver from colossalai.auto_parallel.tensor_shard.solver.solver import Solver
from colossalai.device.device_mesh import DeviceMesh from colossalai.device.device_mesh import DeviceMesh
from colossalai.fx.tracer.tracer import ColoTracer
from colossalai.tensor.shape_consistency import to_global from colossalai.tensor.shape_consistency import to_global
from colossalai.testing.comparison import assert_close from colossalai.testing.comparison import assert_close
@ -79,14 +81,16 @@ def numerical_test_for_node_strategy(model: torch.nn.Module,
model_to_shard, args_to_shard, kwargs_to_shard = _build_model_to_compare(model, input_args, input_kwargs, model_to_shard, args_to_shard, kwargs_to_shard = _build_model_to_compare(model, input_args, input_kwargs,
grad_to_shard_dict) grad_to_shard_dict)
tracer = ColoTracer() tracer = ColoTracer(bias_addition_split=True)
input_sample = {} input_sample = {}
for input_arg, meta_arg_name in zip(input_args, meta_arg_names): for input_arg, meta_arg_name in zip(input_args, meta_arg_names):
input_sample[meta_arg_name] = torch.rand(input_arg.shape).to('meta') input_sample[meta_arg_name] = torch.empty(input_arg.shape, dtype=input_arg.dtype).to('meta')
for meta_kwarg_name, input_kwarg in input_kwargs.items(): for meta_kwarg_name, input_kwarg in input_kwargs.items():
input_sample[meta_kwarg_name] = torch.rand(input_kwarg.shape).to('meta') input_sample[meta_kwarg_name] = torch.empty(input_kwarg.shape, dtype=input_kwarg.dtype).to('meta')
graph = tracer.trace(root=model_to_shard, meta_args=input_sample) graph = tracer.trace(root=model_to_shard, meta_args=input_sample)
gm = GraphModule(model_to_shard, graph, model_to_shard.__class__.__name__) gm = ColoGraphModule(model_to_shard, graph, model_to_shard.__class__.__name__)
shape_prop_pass(gm, *input_sample.values())
solver_options = SolverOptions() solver_options = SolverOptions()
strategies_constructor = StrategiesConstructor(graph, device_mesh, solver_options) strategies_constructor = StrategiesConstructor(graph, device_mesh, solver_options)
strategies_constructor.build_strategies_and_cost() strategies_constructor.build_strategies_and_cost()

View File

@ -1,11 +1,13 @@
import pytest import pytest
import torch import torch
import transformers import transformers
from topo_utils import split_model_and_get_DAG, check_topo, MLP from topo_utils import MLP, check_topo, split_model_and_get_DAG
BATCH_SIZE = 1 BATCH_SIZE = 1
SEQ_LENGHT = 16 SEQ_LENGHT = 16
@pytest.mark.skip('ShapeProp is not compatible with PyTorch 1.11.0')
def test_opt(): def test_opt():
MODEL_LIST = [ MODEL_LIST = [
MLP, MLP,
@ -13,7 +15,10 @@ def test_opt():
] ]
CONFIGS = [ CONFIGS = [
{'dim': 10, 'layers': 12}, {
'dim': 10,
'layers': 12
},
transformers.OPTConfig(vocab_size=100, hidden_size=128, num_hidden_layers=4, num_attention_heads=4), transformers.OPTConfig(vocab_size=100, hidden_size=128, num_hidden_layers=4, num_attention_heads=4),
] ]
@ -21,15 +26,15 @@ def test_opt():
x = torch.zeros((16, 10)) x = torch.zeros((16, 10))
kwargs = dict(x=x) kwargs = dict(x=x)
return kwargs return kwargs
def data_gen_OPT(): def data_gen_OPT():
input_ids = torch.zeros((BATCH_SIZE, SEQ_LENGHT), dtype=torch.int64) input_ids = torch.zeros((BATCH_SIZE, SEQ_LENGHT), dtype=torch.int64)
attention_mask = torch.zeros((BATCH_SIZE, SEQ_LENGHT), dtype=torch.int64) attention_mask = torch.zeros((BATCH_SIZE, SEQ_LENGHT), dtype=torch.int64)
kwargs = dict(input_ids=input_ids, attention_mask=attention_mask) kwargs = dict(input_ids=input_ids, attention_mask=attention_mask)
return kwargs return kwargs
DATAGEN = [ DATAGEN = [
data_gen_MLP, data_gen_MLP,
data_gen_OPT, data_gen_OPT,
] ]
@ -39,5 +44,6 @@ def test_opt():
# print(f'{top_mod=}\n----\n{topo=}') # print(f'{top_mod=}\n----\n{topo=}')
check_topo(top_mod, topo) check_topo(top_mod, topo)
if __name__ == '__main__': if __name__ == '__main__':
test_opt() test_opt()