[autoparallel] fixed wrong generated strategy for dot op (#1746)

* [autoparallel] fixed wrong generated strategy for dot op

* polish code
Frank Lee authored 2 years ago, committed by GitHub
parent 993b8875b6
commit 8b8937d901

@ -1,7 +1,8 @@
from .batch_norm_handler import BatchNormModuleHandler
from .bmm_handler import BMMFunctionHandler
from .conv_handler import ConvFunctionHandler, ConvModuleHandler
from .dot_handler import LinearFunctionHandler, LinearModuleHandler
from .layer_norm_handler import LayerNormModuleHandler
from .linear_handler import LinearFunctionHandler, LinearModuleHandler
from .normal_pooling_handler import NormPoolingHandler
from .output_handler import OuputHandler
from .placeholder_handler import PlacehodlerHandler
@ -11,7 +12,7 @@ from .unary_elementwise_handler import UnaryElementwiseHandler
from .where_handler import WhereHandler
__all__ = [
'LinearFunctionHandler', 'LinearModuleHandler', 'LayerNormModuleHandler', 'BatchNormModuleHandler',
'ConvModuleHandler', 'ConvFunctionHandler', 'UnaryElementwiseHandler', 'ReshapeHandler', 'PlacehodlerHandler',
'OuputHandler', 'WhereHandler', 'NormPoolingHandler', 'operator_registry'
'LinearFunctionHandler', 'LinearModuleHandler', 'BMMFunctionHandler', 'LayerNormModuleHandler',
'BatchNormModuleHandler', 'ConvModuleHandler', 'ConvFunctionHandler', 'UnaryElementwiseHandler', 'ReshapeHandler',
'PlacehodlerHandler', 'OuputHandler', 'WhereHandler', 'NormPoolingHandler', 'operator_registry'
]
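
As a usage note, the re-export above changes where downstream code imports the handler from; a minimal sketch, with the import path taken from the diff (the tests further down use the same form):

# BMMFunctionHandler now lives in its own bmm_handler module and is
# re-exported by the node_handler package, so callers import it from
# the package itself rather than from dot_handler (which no longer defines it).
from colossalai.auto_parallel.tensor_shard.node_handler import BMMFunctionHandler, LinearFunctionHandler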

@ -0,0 +1,33 @@
from typing import Dict, List
import torch
from ..sharding_strategy import OperationData, OperationDataType
from .node_handler import NodeHandler
from .registry import operator_registry
from .strategy import BatchedMatMulStrategyGenerator, StrategyGenerator
@operator_registry.register(torch.bmm)
@operator_registry.register(torch.Tensor.bmm)
class BMMFunctionHandler(NodeHandler):
def get_operation_data_mapping(self) -> Dict[str, OperationData]:
physical_input_operand = OperationData(name=str(self.node.args[0]),
type=OperationDataType.ARG,
data=self.node.args[0]._meta_data)
physical_other_operand = OperationData(name=str(self.node.args[1]),
type=OperationDataType.ARG,
data=self.node.args[1]._meta_data)
physical_output = OperationData(name=str(self.node), type=OperationDataType.OUTPUT, data=self.node._meta_data)
mapping = {"input": physical_input_operand, "other": physical_other_operand, "output": physical_output}
return mapping
def get_strategy_generator(self) -> List[StrategyGenerator]:
generators = []
op_data_mapping = self.get_operation_data_mapping()
generators = []
generators.append(BatchedMatMulStrategyGenerator(op_data_mapping, self.device_mesh))
return generators
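
For context on what the new handler targets: torch.bmm contracts the last dimension of the first operand with the middle dimension of the second, batch by batch. A minimal shape sketch in plain PyTorch (variable names are illustrative, not part of the diff):

import torch

# Batched matmul shape contract: [b, i, k] x [b, k, j] -> [b, i, j]
x1 = torch.rand(4, 8, 16)     # [b=4, i=8,  k=16]
x2 = torch.rand(4, 16, 32)    # [b=4, k=16, j=32]
out = torch.bmm(x1, x2)
assert out.shape == torch.Size([4, 8, 32])
# the Tensor-method form is registered by the handler as well
assert x1.bmm(x2).shape == out.shape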

@ -1,19 +1,18 @@
from copy import deepcopy
from typing import Dict, List, Union
import torch
import torch.nn.functional as F
from colossalai.auto_parallel.tensor_shard.utils import (switch_partition_dim, update_partition_dim)
from colossalai.auto_parallel.tensor_shard.utils import tranpose_partition_dim, update_partition_dim
from colossalai.logging import get_dist_logger
from colossalai.tensor.sharding_spec import ShardingNotDivisibleError
from ..sharding_strategy import (OperationData, OperationDataType, ShardingStrategy)
from ..sharding_strategy import OperationData, OperationDataType, ShardingStrategy
from .node_handler import ModuleHandler, NodeHandler
from .registry import operator_registry
from .strategy import (BatchedMatMulStrategyGenerator, LinearProjectionStrategyGenerator, StrategyGenerator)
from .strategy import LinearProjectionStrategyGenerator, StrategyGenerator
__all__ = ['LinearModuleHandler', 'LinearFunctionHandler', 'BMMFunctionHandler']
__all__ = ['LinearModuleHandler', 'LinearFunctionHandler']
def _update_sharding_spec_for_transposed_weight_for_linear(strategy: ShardingStrategy,
@ -31,7 +30,7 @@ def _update_sharding_spec_for_transposed_weight_for_linear(strategy: ShardingStr
op_data = strategy.get_op_data_by_name(weight_name)
assert op_data.logical_shape != op_data.data.shape, \
"Expected the logical and physical shape of the linear operator's weight to be different, but found them to be the same"
switch_partition_dim(sharding_spec, 0, -1)
tranpose_partition_dim(sharding_spec, 0, -1)
return strategy
@ -104,8 +103,6 @@ def _convert_logical_sharding_to_physical_sharding_spec_for_linear(strategy: Sha
dim_mapping={},
physical_shape=output_op_data.data.shape,
inplace=True)
print(input_op_data.data.shape)
print(output_op_data.data.shape)
sharding_strategies.append(strategy_copy)
return sharding_strategies
@ -144,7 +141,7 @@ class LinearModuleHandler(ModuleHandler):
mapping = {"input": physical_input_operand, "other": physical_other_operand, "output": physical_output}
if self.named_parameters['bias'] is not None:
if 'bias' in self.named_parameters is not None:
physical_bias_operand = OperationData(name="bias",
type=OperationDataType.PARAM,
data=self.named_parameters['bias'])
@ -229,30 +226,3 @@ class LinearFunctionHandler(NodeHandler):
input_name=str(self.node.args[0]),
output_name=str(self.node))
return strategies
@operator_registry.register(torch.bmm)
@operator_registry.register(torch.Tensor.bmm)
class BMMFunctionHandler(NodeHandler):
def get_operation_data_mapping(self) -> Dict[str, OperationData]:
# use transposed shape for strategies
# the strategies will be transformed back to its original shape in self.post_process
physical_input_operand = OperationData(name=str(self.node.args[0]),
type=OperationDataType.ARG,
data=self.node.args[0]._meta_data)
physical_other_operand = OperationData(name=str(self.node.args[1]),
type=OperationDataType.ARG,
data=self.node.args[1]._meta_data)
physical_output = OperationData(name=str(self.node), type=OperationDataType.OUTPUT, data=self.node._meta_data)
mapping = {"input": physical_input_operand, "other": physical_other_operand, "output": physical_output}
return mapping
def get_strategy_generator(self) -> List[StrategyGenerator]:
generators = []
op_data_mapping = self.get_operation_data_mapping()
generators = []
generators.append(BatchedMatMulStrategyGenerator(op_data_mapping, self.device_mesh))
return generators

@ -2,9 +2,8 @@ import operator
from functools import reduce
from typing import List
from colossalai.auto_parallel.tensor_shard.sharding_strategy import (MemoryCost, ShardingStrategy, TrainCycleItem)
from colossalai.auto_parallel.tensor_shard.utils import \
ignore_sharding_exception
from colossalai.auto_parallel.tensor_shard.sharding_strategy import MemoryCost, ShardingStrategy, TrainCycleItem
from colossalai.auto_parallel.tensor_shard.utils import ignore_sharding_exception
from colossalai.tensor.shape_consistency import CollectiveCommPattern
from .strategy_generator import StrategyGenerator
@ -227,6 +226,7 @@ class LinearProjectionStrategyGenerator(MatMulStrategyGenerator):
sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping)
# set communication action
communication_action_mapping = {}
input_comm_spec = self.get_communication_spec(
sharding_spec=sharding_spec_mapping["input"],
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
@ -235,12 +235,16 @@ class LinearProjectionStrategyGenerator(MatMulStrategyGenerator):
sharding_spec_mapping["output"],
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=mesh_dim_0)
bias_comm_spec = self.get_communication_spec(
sharding_spec_mapping["bias"],
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=mesh_dim_0)
communication_action_mapping = {"input": input_comm_spec, "other": other_comm_spec, "bias": bias_comm_spec}
communication_action_mapping['input'] = input_comm_spec
communication_action_mapping['other'] = other_comm_spec
if self.has_bias:
bias_comm_spec = self.get_communication_spec(
sharding_spec_mapping["bias"],
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=mesh_dim_0)
communication_action_mapping['bias'] = bias_comm_spec
return self.get_sharding_strategy(name=name,
sharding_spec_mapping=sharding_spec_mapping,
@ -268,6 +272,7 @@ class LinearProjectionStrategyGenerator(MatMulStrategyGenerator):
sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping)
# get communication action mapping
communication_action_mapping = {}
input_comm_spec = self.get_communication_spec(
sharding_spec=sharding_spec_mapping["input"],
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
@ -276,12 +281,16 @@ class LinearProjectionStrategyGenerator(MatMulStrategyGenerator):
sharding_spec=sharding_spec_mapping["output"],
communication_pattern=CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD,
logical_process_axis=mesh_dim_1)
bias_comm_spec = self.get_communication_spec(
sharding_spec=sharding_spec_mapping["bias"],
communication_pattern=CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD,
logical_process_axis=mesh_dim_1)
communication_action_mapping = {"input": input_comm_spec, 'output': output_comm_spec, 'bias': bias_comm_spec}
communication_action_mapping['input'] = input_comm_spec
communication_action_mapping['output'] = output_comm_spec
if self.has_bias:
bias_comm_spec = self.get_communication_spec(
sharding_spec=sharding_spec_mapping["bias"],
communication_pattern=CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD,
logical_process_axis=mesh_dim_1)
communication_action_mapping['bias'] = bias_comm_spec
return self.get_sharding_strategy(name=name,
sharding_spec_mapping=sharding_spec_mapping,
@ -310,6 +319,7 @@ class LinearProjectionStrategyGenerator(MatMulStrategyGenerator):
sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping)
# get communication actions
communication_action_mapping = {}
output_comm_spec = self.get_communication_spec(
sharding_spec=sharding_spec_mapping['output'],
communication_pattern=CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD,
@ -318,7 +328,8 @@ class LinearProjectionStrategyGenerator(MatMulStrategyGenerator):
sharding_spec=sharding_spec_mapping['input'],
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=mesh_dim_1)
communication_action_mapping = {"output": output_comm_spec, "input": input_comm_spec}
communication_action_mapping["input"] = input_comm_spec
communication_action_mapping['output'] = output_comm_spec
return self.get_sharding_strategy(name=name,
sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping=communication_action_mapping)
@ -342,11 +353,13 @@ class LinearProjectionStrategyGenerator(MatMulStrategyGenerator):
sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping)
# get communication action
communication_action_mapping = {}
output_comm_spec = self.get_communication_spec(
sharding_spec=sharding_spec_mapping['output'],
communication_pattern=CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD,
logical_process_axis=mesh_dim)
communication_action_mapping = {'output': output_comm_spec}
communication_action_mapping['output'] = output_comm_spec
return self.get_sharding_strategy(name=name,
sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping=communication_action_mapping)
@ -372,11 +385,13 @@ class LinearProjectionStrategyGenerator(MatMulStrategyGenerator):
sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping)
# get communication actions
communication_action_mapping = {}
input_comm_spec = self.get_communication_spec(
sharding_spec=sharding_spec_mapping['input'],
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=mesh_dim)
communication_action_mapping = {'input': input_comm_spec}
communication_action_mapping['input'] = input_comm_spec
return self.get_sharding_strategy(name=name,
sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping=communication_action_mapping)
@ -398,19 +413,22 @@ class LinearProjectionStrategyGenerator(MatMulStrategyGenerator):
sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping)
# get communication action
communication_action_mapping = {}
other_comm_spec = self.get_communication_spec(
sharding_spec=sharding_spec_mapping['other'],
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=[mesh_dim_0, mesh_dim_1])
bias_comm_spec = self.get_communication_spec(
sharding_spec=sharding_spec_mapping['bias'],
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=[mesh_dim_0, mesh_dim_1])
communication_action_mapping['other'] = other_comm_spec
communcation_action_mapping = {"other": other_comm_spec, "bias": bias_comm_spec}
if self.has_bias:
bias_comm_spec = self.get_communication_spec(
sharding_spec=sharding_spec_mapping['bias'],
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=[mesh_dim_0, mesh_dim_1])
communication_action_mapping['bias'] = bias_comm_spec
return self.get_sharding_strategy(name=name,
sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping=communcation_action_mapping)
communication_action_mapping=communication_action_mapping)
@ignore_sharding_exception
def split_lhs_2nd_dim_1d(self, mesh_dim_0, mesh_dim_1):
@ -430,11 +448,12 @@ class LinearProjectionStrategyGenerator(MatMulStrategyGenerator):
sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping)
# get communication action
communication_action_mapping = {}
output_comm_spec = self.get_communication_spec(
sharding_spec=sharding_spec_mapping['output'],
communication_pattern=CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD,
logical_process_axis=[mesh_dim_0, mesh_dim_1])
communication_action_mapping = {'output': output_comm_spec}
communication_action_mapping['output'] = output_comm_spec
return self.get_sharding_strategy(name=name,
sharding_spec_mapping=sharding_spec_mapping,
@ -460,11 +479,12 @@ class LinearProjectionStrategyGenerator(MatMulStrategyGenerator):
sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping)
# get communication action
communication_action_mapping = {}
input_comm_spec = self.get_communication_spec(
sharding_spec=sharding_spec_mapping['input'],
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=[mesh_dim_0, mesh_dim_1])
communication_action_mapping = {'input': input_comm_spec}
communication_action_mapping['input'] = input_comm_spec
return self.get_sharding_strategy(name=name,
sharding_spec_mapping=sharding_spec_mapping,
@ -492,7 +512,7 @@ class BatchedMatMulStrategyGenerator(MatMulStrategyGenerator):
"""
Generate sharding strategies for the batched matrix multiplication.
A batched matrix multiplication can be viewed as
A batched matrix multiplication can be viewed as
[b, i, k] x [b, k, j] -> [b, i, j]
"""
@ -642,7 +662,6 @@ class BatchedMatMulStrategyGenerator(MatMulStrategyGenerator):
"bias": {},
"output": {
0: [mesh_dim_0],
-2: [mesh_dim_1]
}
}
sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict)

@ -5,13 +5,13 @@ from .sharding import (
enumerate_all_possible_1d_sharding,
enumerate_all_possible_2d_sharding,
generate_sharding_size,
switch_partition_dim,
tranpose_partition_dim,
update_partition_dim,
)
__all__ = [
'BroadcastType', 'get_broadcast_shape', 'is_broadcastable', 'recover_sharding_spec_for_broadcast_shape',
'generate_resharding_costs', 'generate_sharding_spec', 'ignore_sharding_exception', 'check_sharding_spec_validity'
'switch_partition_dim', 'update_partition_dim', 'enumerate_all_possible_1d_sharding',
'tranpose_partition_dim', 'update_partition_dim', 'enumerate_all_possible_1d_sharding',
'enumerate_all_possible_2d_sharding', 'generate_sharding_size'
]

@ -36,9 +36,10 @@ def ignore_sharding_exception(func):
def check_sharding_spec_validity(sharding_spec: ShardingSpec, tensor: torch.Tensor):
"""
This function checks whether the ShardingSpec is valid for the physical tensor.
This check includes 2 items:
This check includes 3 items:
1. the sharding spec covers all dimensions of the physical tensor
2. the sharding spec for each dimension is divisible by the number of devices.
3. the sharding spec's entire shape must match the tensor shape
#
"""
# make sure all dims are covered in sharding spec
@ -65,3 +66,6 @@ def check_sharding_spec_validity(sharding_spec: ShardingSpec, tensor: torch.Tens
assert dim_size >= num_devices and dim_size % num_devices == 0, \
f'The dimension at index {i} has value {dim_size}, but it is sharded over {num_devices} devices.'
# make sure the entire shape matches the physical tensor shape
assert sharding_spec.entire_shape == tensor.shape
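
To make the added third item concrete, here is a hedged, standalone sketch of the three checks on a toy spec; it uses only plain Python and torch, and every name in it (toy_check_sharding_spec, mesh_axis_sizes, the 'S0'/'R' strings) is illustrative rather than the real ShardingSpec API:

import torch

def toy_check_sharding_spec(sharding_sequence, entire_shape, tensor, mesh_axis_sizes):
    # 1. the spec must cover every dimension of the physical tensor
    assert len(sharding_sequence) == tensor.dim()
    # 2. every sharded dimension must be divisible by the devices sharding it
    for dim, entry in enumerate(sharding_sequence):
        if entry == 'R':    # replicated dimension, nothing to check
            continue
        num_devices = 1
        for mesh_axis in entry[1:]:    # e.g. 'S01' -> mesh axes 0 and 1
            num_devices *= mesh_axis_sizes[int(mesh_axis)]
        assert tensor.shape[dim] % num_devices == 0
    # 3. the spec's entire shape must match the physical tensor shape
    assert entire_shape == tensor.shape

toy_check_sharding_spec(['S0', 'R'], torch.Size([4, 16]), torch.rand(4, 16), mesh_axis_sizes=[2, 2])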

@ -8,12 +8,12 @@ import torch
from colossalai.tensor.sharding_spec import ShardingSpec
__all__ = [
'switch_partition_dim', 'update_partition_dim', 'enumerate_all_possible_1d_sharding',
'tranpose_partition_dim', 'update_partition_dim', 'enumerate_all_possible_1d_sharding',
'enumerate_all_possible_2d_sharding', 'generate_sharding_size'
]
def switch_partition_dim(sharding_spec: ShardingSpec, dim1: int, dim2: int) -> ShardingSpec:
def tranpose_partition_dim(sharding_spec: ShardingSpec, dim1: int, dim2: int) -> ShardingSpec:
"""
Switch the sharding mesh dimensions for two tensor dimensions. This operation is in-place.
@ -22,19 +22,26 @@ def switch_partition_dim(sharding_spec: ShardingSpec, dim1: int, dim2: int) -> S
dim1 (int): the tensor dimension to switch
dim2 (int): the tensor dimension to switch
"""
assert len(sharding_spec.entire_shape) == 2
assert len(sharding_spec.entire_shape) >= 2, \
'The entire_shape of the sharding spec must have at least 2 dimensions'
dim_partition_dict = sharding_spec.dim_partition_dict
# transpose the dim partition
dim1_partition = dim_partition_dict.pop(dim1, None)
dim2_partition = dim_partition_dict.pop(dim2, None)
if dim1_partition:
dim_partition_dict[dim2] = dim1_partition
if dim2_partition:
dim_partition_dict[dim1] = dim2_partition
# get the transposed shape
new_shape = list(sharding_spec.entire_shape[:])
new_shape[dim2], new_shape[dim1] = new_shape[dim1], new_shape[dim2]
new_shape = torch.Size(new_shape)
# re-init the sharding spec
sharding_spec.__init__(sharding_spec.device_mesh, sharding_spec.entire_shape, dim_partition_dict)
sharding_spec.__init__(sharding_spec.device_mesh, new_shape, dim_partition_dict)
return sharding_spec
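
The function above swaps both the mesh-axis assignments and the corresponding entries of entire_shape. A standalone sketch of that behavior with plain dicts, mirroring tranpose_partition_dim(spec, 0, -1) without the real ShardingSpec class (values below are illustrative):

import torch

# dim 0 sharded over mesh axis 0; the last dim is replicated
dim_partition_dict = {0: [0]}
entire_shape = torch.Size([32, 16])
dim1, dim2 = 0, -1

# transpose the dim partition, as in the diff
dim1_partition = dim_partition_dict.pop(dim1, None)
dim2_partition = dim_partition_dict.pop(dim2, None)
if dim1_partition:
    dim_partition_dict[dim2] = dim1_partition
if dim2_partition:
    dim_partition_dict[dim1] = dim2_partition

# transpose the entire shape as well
new_shape = list(entire_shape)
new_shape[dim1], new_shape[dim2] = new_shape[dim2], new_shape[dim1]
new_shape = torch.Size(new_shape)

assert dim_partition_dict == {-1: [0]}
assert new_shape == torch.Size([16, 32])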

@ -2,12 +2,10 @@ import pytest
import torch
import torch.nn as nn
from colossalai.auto_parallel.tensor_shard.node_handler.dot_handler import \
BMMFunctionHandler
from colossalai.auto_parallel.tensor_shard.sharding_strategy import (OperationData, OperationDataType, StrategiesVector)
from colossalai.auto_parallel.tensor_shard.node_handler import BMMFunctionHandler
from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType, StrategiesVector
from colossalai.device.device_mesh import DeviceMesh
from colossalai.fx import ColoGraphModule, ColoTracer
from colossalai.testing.pytest_wrapper import run_on_environment_flag
class BMMTensorMethodModule(nn.Module):
@ -91,6 +89,16 @@ def test_2d_device_mesh(module):
assert 'Sb0R = Sb0Sk1 x Sb0Sk1' in strategy_name_list
assert 'Sb1R = Sb1Sk0 x Sb1Sk0' in strategy_name_list
for strategy in strategies_vector:
input_sharding_spec = strategy.get_sharding_spec_by_name('x1')
other_sharding_spec = strategy.get_sharding_spec_by_name('x2')
output_sharding_spec = strategy.get_sharding_spec_by_name('bmm')
# make sure the sharding matches across different operation data
assert input_sharding_spec.sharding_sequence[:-1] == output_sharding_spec.sharding_sequence[:-1]
assert other_sharding_spec.sharding_sequence[1] == input_sharding_spec.sharding_sequence[-1]
assert other_sharding_spec.sharding_sequence[-1] == output_sharding_spec.sharding_sequence[-1]
@pytest.mark.parametrize('module', [BMMTensorMethodModule, BMMTorchFunctionModule])
def test_1d_device_mesh(module):
@ -145,6 +153,16 @@ def test_1d_device_mesh(module):
# one batch dim
assert 'Sb0 = Sb0 x Sb0' in strategy_name_list
for strategy in strategies_vector:
input_sharding_spec = strategy.get_sharding_spec_by_name('x1')
other_sharding_spec = strategy.get_sharding_spec_by_name('x2')
output_sharding_spec = strategy.get_sharding_spec_by_name('bmm')
# make sure the sharding matches across different operation data
assert input_sharding_spec.sharding_sequence[:-1] == output_sharding_spec.sharding_sequence[:-1]
assert other_sharding_spec.sharding_sequence[1] == input_sharding_spec.sharding_sequence[-1]
assert other_sharding_spec.sharding_sequence[-1] == output_sharding_spec.sharding_sequence[-1]
if __name__ == '__main__':
test_1d_device_mesh(BMMTensorMethodModule)

@ -5,8 +5,10 @@ from colossalai.auto_parallel.tensor_shard.node_handler.conv_handler import Conv
from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType, StrategiesVector
from colossalai.device.device_mesh import DeviceMesh
from colossalai.fx import ColoGraphModule, ColoTracer
from colossalai.testing.pytest_wrapper import run_on_environment_flag
@run_on_environment_flag(name='AUTO_PARALLEL')
def test_conv_module_handler():
model = nn.Sequential(nn.Conv2d(4, 16, 3, padding=1).to('meta'))
tracer = ColoTracer()
@ -108,6 +110,7 @@ class ConvModel(nn.Module):
return x
@run_on_environment_flag(name='AUTO_PARALLEL')
def test_conv_function_handler():
model = ConvModel()
tracer = ColoTracer()

@ -1,14 +1,13 @@
import torch
import torch.nn as nn
from colossalai.auto_parallel.tensor_shard.node_handler.conv_handler import \
ConvFunctionHandler
from colossalai.auto_parallel.tensor_shard.node_handler.getitem_handler import \
GetItemHandler
from colossalai.auto_parallel.tensor_shard.sharding_strategy import (OperationData, OperationDataType, StrategiesVector)
from colossalai.auto_parallel.tensor_shard.node_handler.conv_handler import ConvFunctionHandler
from colossalai.auto_parallel.tensor_shard.node_handler.getitem_handler import GetItemHandler
from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType, StrategiesVector
from colossalai.device.device_mesh import DeviceMesh
from colossalai.fx import ColoGraphModule, ColoTracer
from colossalai.fx.tracer.meta_patch.patched_module import linear
from colossalai.testing.pytest_wrapper import run_on_environment_flag
class GetItemModel(nn.Module):
@ -22,6 +21,7 @@ class GetItemModel(nn.Module):
return x
@run_on_environment_flag(name='AUTO_PARALLEL')
def test_getitem_function_handler():
model = GetItemModel()
tracer = ColoTracer()

@ -1,7 +1,8 @@
import torch
import torch.nn as nn
from typing_extensions import Self
from colossalai.auto_parallel.tensor_shard.node_handler.dot_handler import LinearFunctionHandler, LinearModuleHandler
from colossalai.auto_parallel.tensor_shard.node_handler import LinearFunctionHandler, LinearModuleHandler
from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
OperationData,
OperationDataType,
@ -10,10 +11,12 @@ from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
)
from colossalai.device.device_mesh import DeviceMesh
from colossalai.fx import ColoGraphModule, ColoTracer
from colossalai.testing.utils import parameterize
def test_linear_module_handler():
model = nn.Sequential(nn.Linear(16, 32).to('meta'))
@parameterize('bias', [True, False])
def test_linear_module_handler(bias):
model = nn.Sequential(nn.Linear(16, 32, bias=bias).to('meta'))
tracer = ColoTracer()
graph = tracer.trace(model, meta_args={"input": torch.rand(2, 2, 4, 16).to('meta')})
@ -50,11 +53,12 @@ def test_linear_module_handler():
assert mapping['other'].type == OperationDataType.PARAM
assert mapping['other'].logical_shape == torch.Size([16, 32])
assert mapping['bias'].name == "bias"
assert mapping['bias'].data.is_meta
assert mapping['bias'].data.shape == torch.Size([32])
assert mapping['bias'].type == OperationDataType.PARAM
assert mapping['bias'].logical_shape == torch.Size([32])
if bias:
assert mapping['bias'].name == "bias"
assert mapping['bias'].data.is_meta
assert mapping['bias'].data.shape == torch.Size([32])
assert mapping['bias'].type == OperationDataType.PARAM
assert mapping['bias'].logical_shape == torch.Size([32])
assert mapping['output'].name == "_0"
assert mapping['output'].data.is_meta
@ -91,18 +95,23 @@ def test_linear_module_handler():
strategy: ShardingStrategy
input_sharding_spec = strategy.get_sharding_spec_by_name('input_1')
weight_sharding_spec = strategy.get_sharding_spec_by_name('weight')
bias_sharding_spec = strategy.get_sharding_spec_by_name('bias')
output_sharding_spec = strategy.get_sharding_spec_by_name('_0')
if bias:
bias_sharding_spec = strategy.get_sharding_spec_by_name('bias')
# make sure the sharding matches across different operation data
assert input_sharding_spec.sharding_sequence[:-1] == output_sharding_spec.sharding_sequence[:-1]
assert weight_sharding_spec.sharding_sequence[1] == input_sharding_spec.sharding_sequence[-1]
assert weight_sharding_spec.sharding_sequence[0] == output_sharding_spec.sharding_sequence[-1]
assert bias_sharding_spec.sharding_sequence[-1] == output_sharding_spec.sharding_sequence[-1]
if bias:
assert bias_sharding_spec.sharding_sequence[-1] == output_sharding_spec.sharding_sequence[-1]
def test_linear_function_handler():
model = nn.Linear(16, 32).to('meta')
@parameterize('bias', [True, False])
def test_linear_function_handler(bias):
model = nn.Linear(16, 32, bias=bias).to('meta')
tracer = ColoTracer()
graph = tracer.trace(model, meta_args={"input": torch.rand(2, 2, 4, 16).to('meta')})
gm = ColoGraphModule(model, graph)
@ -111,7 +120,11 @@ def test_linear_function_handler():
print(graph)
mesh_shape = (2, 2)
device_mesh = DeviceMesh(physical_mesh_id, mesh_shape)
linear_func_node = list(graph.nodes)[3]
if bias:
linear_func_node = list(graph.nodes)[3]
else:
linear_func_node = list(graph.nodes)[2]
strategies_vector = StrategiesVector(linear_func_node)
# build handler
@ -120,8 +133,6 @@ def test_linear_function_handler():
# # check operation data mapping
mapping = handler.get_operation_data_mapping()
print(mapping['input'].logical_shape)
assert mapping['input'].name == "input_1"
assert mapping['input'].data.is_meta
assert mapping['input'].data.shape == torch.Size([2, 2, 4, 16])
@ -134,11 +145,12 @@ def test_linear_function_handler():
assert mapping['other'].type == OperationDataType.PARAM
assert mapping['other'].logical_shape == torch.Size([16, 32])
assert mapping['bias'].name == "bias"
assert mapping['bias'].data.is_meta
assert mapping['bias'].data.shape == torch.Size([32])
assert mapping['bias'].type == OperationDataType.PARAM
assert mapping['other'].logical_shape == torch.Size([16, 32])
if bias:
assert mapping['bias'].name == "bias"
assert mapping['bias'].data.is_meta
assert mapping['bias'].data.shape == torch.Size([32])
assert mapping['bias'].type == OperationDataType.PARAM
assert mapping['other'].logical_shape == torch.Size([16, 32])
assert mapping['output'].name == "linear"
assert mapping['output'].data.is_meta
@ -172,17 +184,20 @@ def test_linear_function_handler():
for strategy in strategies_vector:
strategy: ShardingStrategy
print(strategy)
input_sharding_spec = strategy.get_sharding_spec_by_name('input_1')
weight_sharding_spec = strategy.get_sharding_spec_by_name('weight')
bias_sharding_spec = strategy.get_sharding_spec_by_name('bias')
output_sharding_spec = strategy.get_sharding_spec_by_name('linear')
if bias:
bias_sharding_spec = strategy.get_sharding_spec_by_name('bias')
# make sure the sharding matches across different operation data
assert input_sharding_spec.sharding_sequence[:-1] == output_sharding_spec.sharding_sequence[:-1]
assert weight_sharding_spec.sharding_sequence[1] == input_sharding_spec.sharding_sequence[-1]
assert weight_sharding_spec.sharding_sequence[0] == output_sharding_spec.sharding_sequence[-1]
assert bias_sharding_spec.sharding_sequence[-1] == output_sharding_spec.sharding_sequence[-1]
if bias:
assert bias_sharding_spec.sharding_sequence[-1] == output_sharding_spec.sharding_sequence[-1]
if __name__ == '__main__':

@ -1,13 +1,12 @@
import torch
import torch.nn as nn
from colossalai.auto_parallel.tensor_shard.node_handler.conv_handler import \
ConvFunctionHandler
from colossalai.auto_parallel.tensor_shard.node_handler.reshape_handler import \
ReshapeHandler
from colossalai.auto_parallel.tensor_shard.sharding_strategy import (OperationData, OperationDataType, StrategiesVector)
from colossalai.auto_parallel.tensor_shard.node_handler.conv_handler import ConvFunctionHandler
from colossalai.auto_parallel.tensor_shard.node_handler.reshape_handler import ReshapeHandler
from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType, StrategiesVector
from colossalai.device.device_mesh import DeviceMesh
from colossalai.fx import ColoGraphModule, ColoTracer
from colossalai.testing.pytest_wrapper import run_on_environment_flag
class ReshapeModel(nn.Module):
@ -21,6 +20,7 @@ class ReshapeModel(nn.Module):
return reshape_node
@run_on_environment_flag(name='AUTO_PARALLEL')
def test_reshape_handler():
model = ReshapeModel()
tracer = ColoTracer()

@ -1,13 +1,13 @@
import torch
import torch.nn as nn
from colossalai.auto_parallel.tensor_shard.node_handler.conv_handler import \
ConvFunctionHandler
from colossalai.auto_parallel.tensor_shard.node_handler.unary_elementwise_handler import \
UnaryElementwiseHandler
from colossalai.auto_parallel.tensor_shard.sharding_strategy import (OperationData, OperationDataType, StrategiesVector)
from colossalai.auto_parallel.tensor_shard.node_handler.conv_handler import ConvFunctionHandler
from colossalai.auto_parallel.tensor_shard.node_handler.unary_elementwise_handler import UnaryElementwiseHandler
from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType, StrategiesVector
from colossalai.device.device_mesh import DeviceMesh
from colossalai.fx import ColoGraphModule, ColoTracer
from colossalai.fx.tracer.meta_patch.patched_module import linear
from colossalai.testing.pytest_wrapper import run_on_environment_flag
class ReLuModel(nn.Module):
@ -22,6 +22,7 @@ class ReLuModel(nn.Module):
return relu_node
@run_on_environment_flag(name='AUTO_PARALLEL')
def test_elementwise_handler():
model = ReLuModel()
tracer = ColoTracer()
