mirror of https://github.com/hpcaitech/ColossalAI
[autoparallel] fixed wrong generated strategy for dot op (#1746)
* [autoparallel] fixed wrong generated strategy for dot op
* polish code

pull/1747/head
parent 993b8875b6
commit 8b8937d901
@@ -1,7 +1,8 @@
 from .batch_norm_handler import BatchNormModuleHandler
+from .bmm_handler import BMMFunctionHandler
 from .conv_handler import ConvFunctionHandler, ConvModuleHandler
-from .dot_handler import LinearFunctionHandler, LinearModuleHandler
 from .layer_norm_handler import LayerNormModuleHandler
+from .linear_handler import LinearFunctionHandler, LinearModuleHandler
 from .normal_pooling_handler import NormPoolingHandler
 from .output_handler import OuputHandler
 from .placeholder_handler import PlacehodlerHandler
@@ -11,7 +12,7 @@ from .unary_elementwise_handler import UnaryElementwiseHandler
 from .where_handler import WhereHandler
 
 __all__ = [
-    'LinearFunctionHandler', 'LinearModuleHandler', 'LayerNormModuleHandler', 'BatchNormModuleHandler',
-    'ConvModuleHandler', 'ConvFunctionHandler', 'UnaryElementwiseHandler', 'ReshapeHandler', 'PlacehodlerHandler',
-    'OuputHandler', 'WhereHandler', 'NormPoolingHandler', 'operator_registry'
+    'LinearFunctionHandler', 'LinearModuleHandler', 'BMMFunctionHandler', 'LayerNormModuleHandler',
+    'BatchNormModuleHandler', 'ConvModuleHandler', 'ConvFunctionHandler', 'UnaryElementwiseHandler', 'ReshapeHandler',
+    'PlacehodlerHandler', 'OuputHandler', 'WhereHandler', 'NormPoolingHandler', 'operator_registry'
 ]
@@ -0,0 +1,33 @@
+from typing import Dict, List
+
+import torch
+
+from ..sharding_strategy import OperationData, OperationDataType
+from .node_handler import NodeHandler
+from .registry import operator_registry
+from .strategy import BatchedMatMulStrategyGenerator, StrategyGenerator
+
+
+@operator_registry.register(torch.bmm)
+@operator_registry.register(torch.Tensor.bmm)
+class BMMFunctionHandler(NodeHandler):
+
+    def get_operation_data_mapping(self) -> Dict[str, OperationData]:
+        physical_input_operand = OperationData(name=str(self.node.args[0]),
+                                               type=OperationDataType.ARG,
+                                               data=self.node.args[0]._meta_data)
+
+        physical_other_operand = OperationData(name=str(self.node.args[1]),
+                                               type=OperationDataType.ARG,
+                                               data=self.node.args[1]._meta_data)
+        physical_output = OperationData(name=str(self.node), type=OperationDataType.OUTPUT, data=self.node._meta_data)
+
+        mapping = {"input": physical_input_operand, "other": physical_other_operand, "output": physical_output}
+        return mapping
+
+    def get_strategy_generator(self) -> List[StrategyGenerator]:
+        generators = []
+        op_data_mapping = self.get_operation_data_mapping()
+        generators = []
+        generators.append(BatchedMatMulStrategyGenerator(op_data_mapping, self.device_mesh))
+        return generators
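Both the free function `torch.bmm` and the bound method `torch.Tensor.bmm` are registered to the new handler above, so either call form traced in an FX graph resolves to the same strategy logic. A hedged, self-contained sketch of that decorator-registry idea (a stand-in class, not ColossalAI's actual `Registry` implementation):

```python
import torch


# Illustrative stand-in registry (assumption: not the colossalai implementation).
class OperatorRegistry:

    def __init__(self):
        self._store = {}

    def register(self, op):
        # used as a decorator: @registry.register(torch.bmm)
        def wrapper(handler_cls):
            self._store[op] = handler_cls
            return handler_cls

        return wrapper

    def get(self, op):
        return self._store[op]


operator_registry = OperatorRegistry()


@operator_registry.register(torch.bmm)
@operator_registry.register(torch.Tensor.bmm)
class FakeBMMHandler:
    """Placeholder handler class for the sketch."""


# both call forms of bmm map to the same handler class
assert operator_registry.get(torch.bmm) is FakeBMMHandler
assert operator_registry.get(torch.Tensor.bmm) is FakeBMMHandler
```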
@@ -1,19 +1,18 @@
 from copy import deepcopy
 from typing import Dict, List, Union
 
 import torch
 import torch.nn.functional as F
 
-from colossalai.auto_parallel.tensor_shard.utils import (switch_partition_dim, update_partition_dim)
+from colossalai.auto_parallel.tensor_shard.utils import tranpose_partition_dim, update_partition_dim
 from colossalai.logging import get_dist_logger
 from colossalai.tensor.sharding_spec import ShardingNotDivisibleError
 
-from ..sharding_strategy import (OperationData, OperationDataType, ShardingStrategy)
+from ..sharding_strategy import OperationData, OperationDataType, ShardingStrategy
 from .node_handler import ModuleHandler, NodeHandler
 from .registry import operator_registry
-from .strategy import (BatchedMatMulStrategyGenerator, LinearProjectionStrategyGenerator, StrategyGenerator)
+from .strategy import LinearProjectionStrategyGenerator, StrategyGenerator
 
-__all__ = ['LinearModuleHandler', 'LinearFunctionHandler', 'BMMFunctionHandler']
+__all__ = ['LinearModuleHandler', 'LinearFunctionHandler']
 
 
 def _update_sharding_spec_for_transposed_weight_for_linear(strategy: ShardingStrategy,
@@ -31,7 +30,7 @@ def _update_sharding_spec_for_transposed_weight_for_linear(strategy: ShardingStr
     op_data = strategy.get_op_data_by_name(weight_name)
     assert op_data.logical_shape != op_data.data.shape, \
         "Expected the logical and physical shape of the linear operator's weight to be different, but found them to be the same"
-    switch_partition_dim(sharding_spec, 0, -1)
+    tranpose_partition_dim(sharding_spec, 0, -1)
     return strategy
 
 
@@ -104,8 +103,6 @@ def _convert_logical_sharding_to_physical_sharding_spec_for_linear(strategy: Sha
                              dim_mapping={},
                              physical_shape=output_op_data.data.shape,
                              inplace=True)
-        print(input_op_data.data.shape)
-        print(output_op_data.data.shape)
         sharding_strategies.append(strategy_copy)
     return sharding_strategies
 
@@ -144,7 +141,7 @@ class LinearModuleHandler(ModuleHandler):
 
         mapping = {"input": physical_input_operand, "other": physical_other_operand, "output": physical_output}
 
-        if self.named_parameters['bias'] is not None:
+        if 'bias' in self.named_parameters is not None:
             physical_bias_operand = OperationData(name="bias",
                                                   type=OperationDataType.PARAM,
                                                   data=self.named_parameters['bias'])
@@ -229,30 +226,3 @@ class LinearFunctionHandler(NodeHandler):
                                                   input_name=str(self.node.args[0]),
                                                   output_name=str(self.node))
         return strategies
-
-
-@operator_registry.register(torch.bmm)
-@operator_registry.register(torch.Tensor.bmm)
-class BMMFunctionHandler(NodeHandler):
-
-    def get_operation_data_mapping(self) -> Dict[str, OperationData]:
-        # use transposed shape for strategies
-        # the strategies will be transformed back to its original shape in self.post_process
-        physical_input_operand = OperationData(name=str(self.node.args[0]),
-                                               type=OperationDataType.ARG,
-                                               data=self.node.args[0]._meta_data)
-
-        physical_other_operand = OperationData(name=str(self.node.args[1]),
-                                               type=OperationDataType.ARG,
-                                               data=self.node.args[1]._meta_data)
-        physical_output = OperationData(name=str(self.node), type=OperationDataType.OUTPUT, data=self.node._meta_data)
-
-        mapping = {"input": physical_input_operand, "other": physical_other_operand, "output": physical_output}
-        return mapping
-
-    def get_strategy_generator(self) -> List[StrategyGenerator]:
-        generators = []
-        op_data_mapping = self.get_operation_data_mapping()
-        generators = []
-        generators.append(BatchedMatMulStrategyGenerator(op_data_mapping, self.device_mesh))
-        return generators
@@ -2,9 +2,8 @@ import operator
 from functools import reduce
 from typing import List
 
-from colossalai.auto_parallel.tensor_shard.sharding_strategy import (MemoryCost, ShardingStrategy, TrainCycleItem)
-from colossalai.auto_parallel.tensor_shard.utils import \
-    ignore_sharding_exception
+from colossalai.auto_parallel.tensor_shard.sharding_strategy import MemoryCost, ShardingStrategy, TrainCycleItem
+from colossalai.auto_parallel.tensor_shard.utils import ignore_sharding_exception
 from colossalai.tensor.shape_consistency import CollectiveCommPattern
 
 from .strategy_generator import StrategyGenerator
@@ -227,6 +226,7 @@ class LinearProjectionStrategyGenerator(MatMulStrategyGenerator):
         sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping)
 
         # set communication action
+        communication_action_mapping = {}
         input_comm_spec = self.get_communication_spec(
             sharding_spec=sharding_spec_mapping["input"],
             communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
@@ -235,12 +235,16 @@ class LinearProjectionStrategyGenerator(MatMulStrategyGenerator):
             sharding_spec_mapping["output"],
             communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
             logical_process_axis=mesh_dim_0)
-        bias_comm_spec = self.get_communication_spec(
-            sharding_spec_mapping["bias"],
-            communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
-            logical_process_axis=mesh_dim_0)
 
-        communication_action_mapping = {"input": input_comm_spec, "other": other_comm_spec, "bias": bias_comm_spec}
+        communication_action_mapping['input'] = input_comm_spec
+        communication_action_mapping['other'] = other_comm_spec
+
+        if self.has_bias:
+            bias_comm_spec = self.get_communication_spec(
+                sharding_spec_mapping["bias"],
+                communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
+                logical_process_axis=mesh_dim_0)
+            communication_action_mapping['bias'] = bias_comm_spec
 
         return self.get_sharding_strategy(name=name,
                                           sharding_spec_mapping=sharding_spec_mapping,
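The change above (repeated in the hunks below) replaces a dict literal that always contained a `'bias'` entry with incremental assignments guarded by `self.has_bias`, which is what makes the generator valid for bias-free linear layers. A minimal sketch of that pattern with placeholder values (the helper name is hypothetical, not part of the commit):

```python
# Hypothetical helper illustrating the guarded-mapping pattern from the diff.
def build_comm_mapping(input_spec, other_spec, bias_spec=None):
    communication_action_mapping = {}
    communication_action_mapping['input'] = input_spec
    communication_action_mapping['other'] = other_spec
    # mirrors the `if self.has_bias:` guard introduced above
    if bias_spec is not None:
        communication_action_mapping['bias'] = bias_spec
    return communication_action_mapping


assert 'bias' not in build_comm_mapping('input_spec', 'other_spec')
assert 'bias' in build_comm_mapping('input_spec', 'other_spec', bias_spec='bias_spec')
```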
@@ -268,6 +272,7 @@ class LinearProjectionStrategyGenerator(MatMulStrategyGenerator):
         sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping)
 
         # get communication action mapping
+        communication_action_mapping = {}
         input_comm_spec = self.get_communication_spec(
             sharding_spec=sharding_spec_mapping["input"],
             communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
@@ -276,12 +281,16 @@ class LinearProjectionStrategyGenerator(MatMulStrategyGenerator):
             sharding_spec=sharding_spec_mapping["output"],
             communication_pattern=CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD,
             logical_process_axis=mesh_dim_1)
-        bias_comm_spec = self.get_communication_spec(
-            sharding_spec=sharding_spec_mapping["bias"],
-            communication_pattern=CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD,
-            logical_process_axis=mesh_dim_1)
 
-        communication_action_mapping = {"input": input_comm_spec, 'output': output_comm_spec, 'bias': bias_comm_spec}
+        communication_action_mapping['input'] = input_comm_spec
+        communication_action_mapping['output'] = output_comm_spec
+
+        if self.has_bias:
+            bias_comm_spec = self.get_communication_spec(
+                sharding_spec=sharding_spec_mapping["bias"],
+                communication_pattern=CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD,
+                logical_process_axis=mesh_dim_1)
+            communication_action_mapping['bias'] = bias_comm_spec
 
         return self.get_sharding_strategy(name=name,
                                           sharding_spec_mapping=sharding_spec_mapping,
@@ -310,6 +319,7 @@ class LinearProjectionStrategyGenerator(MatMulStrategyGenerator):
         sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping)
 
         # get communication actions
+        communication_action_mapping = {}
         output_comm_spec = self.get_communication_spec(
             sharding_spec=sharding_spec_mapping['output'],
             communication_pattern=CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD,
@@ -318,7 +328,8 @@ class LinearProjectionStrategyGenerator(MatMulStrategyGenerator):
             sharding_spec=sharding_spec_mapping['input'],
             communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
             logical_process_axis=mesh_dim_1)
-        communication_action_mapping = {"output": output_comm_spec, "input": input_comm_spec}
+        communication_action_mapping["input"] = input_comm_spec
+        communication_action_mapping['output'] = output_comm_spec
         return self.get_sharding_strategy(name=name,
                                           sharding_spec_mapping=sharding_spec_mapping,
                                           communication_action_mapping=communication_action_mapping)
@@ -342,11 +353,13 @@ class LinearProjectionStrategyGenerator(MatMulStrategyGenerator):
         sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping)
 
         # get communication action
+        communication_action_mapping = {}
         output_comm_spec = self.get_communication_spec(
             sharding_spec=sharding_spec_mapping['output'],
             communication_pattern=CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD,
             logical_process_axis=mesh_dim)
-        communication_action_mapping = {'output': output_comm_spec}
+
+        communication_action_mapping['output'] = output_comm_spec
         return self.get_sharding_strategy(name=name,
                                           sharding_spec_mapping=sharding_spec_mapping,
                                           communication_action_mapping=communication_action_mapping)
@@ -372,11 +385,13 @@ class LinearProjectionStrategyGenerator(MatMulStrategyGenerator):
         sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping)
 
         # get communication actions
+        communication_action_mapping = {}
         input_comm_spec = self.get_communication_spec(
             sharding_spec=sharding_spec_mapping['input'],
             communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
             logical_process_axis=mesh_dim)
-        communication_action_mapping = {'input': input_comm_spec}
+
+        communication_action_mapping['input'] = input_comm_spec
         return self.get_sharding_strategy(name=name,
                                           sharding_spec_mapping=sharding_spec_mapping,
                                           communication_action_mapping=communication_action_mapping)
@@ -398,19 +413,22 @@ class LinearProjectionStrategyGenerator(MatMulStrategyGenerator):
         sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping)
 
         # get communication action
+        communication_action_mapping = {}
         other_comm_spec = self.get_communication_spec(
             sharding_spec=sharding_spec_mapping['other'],
             communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
             logical_process_axis=[mesh_dim_0, mesh_dim_1])
-        bias_comm_spec = self.get_communication_spec(
-            sharding_spec=sharding_spec_mapping['bias'],
-            communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
-            logical_process_axis=[mesh_dim_0, mesh_dim_1])
+        communication_action_mapping['other'] = other_comm_spec
 
-        communcation_action_mapping = {"other": other_comm_spec, "bias": bias_comm_spec}
+        if self.has_bias:
+            bias_comm_spec = self.get_communication_spec(
+                sharding_spec=sharding_spec_mapping['bias'],
+                communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
+                logical_process_axis=[mesh_dim_0, mesh_dim_1])
+            communication_action_mapping['bias'] = bias_comm_spec
         return self.get_sharding_strategy(name=name,
                                           sharding_spec_mapping=sharding_spec_mapping,
-                                          communication_action_mapping=communcation_action_mapping)
+                                          communication_action_mapping=communication_action_mapping)
 
     @ignore_sharding_exception
     def split_lhs_2nd_dim_1d(self, mesh_dim_0, mesh_dim_1):
@@ -430,11 +448,12 @@ class LinearProjectionStrategyGenerator(MatMulStrategyGenerator):
         sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping)
 
         # get communication action
+        communication_action_mapping = {}
         output_comm_spec = self.get_communication_spec(
             sharding_spec=sharding_spec_mapping['output'],
             communication_pattern=CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD,
             logical_process_axis=[mesh_dim_0, mesh_dim_1])
-        communication_action_mapping = {'output': output_comm_spec}
+        communication_action_mapping['output'] = output_comm_spec
 
         return self.get_sharding_strategy(name=name,
                                           sharding_spec_mapping=sharding_spec_mapping,
@@ -460,11 +479,12 @@ class LinearProjectionStrategyGenerator(MatMulStrategyGenerator):
         sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping)
 
         # get communication action
+        communication_action_mapping = {}
         input_comm_spec = self.get_communication_spec(
             sharding_spec=sharding_spec_mapping['input'],
             communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
             logical_process_axis=[mesh_dim_0, mesh_dim_1])
-        communication_action_mapping = {'input': input_comm_spec}
+        communication_action_mapping['input'] = input_comm_spec
 
         return self.get_sharding_strategy(name=name,
                                           sharding_spec_mapping=sharding_spec_mapping,
@@ -492,7 +512,7 @@ class BatchedMatMulStrategyGenerator(MatMulStrategyGenerator):
     """
     Generate sharding strategies for the batched matrix multiplication.
 
-    A batched matrix multiplication can be viewed as 
+    A batched matrix multiplication can be viewed as
     [b, i, k] x [b, k, j] -> [b, i, j]
     """
 
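The docstring's shape rule can be checked directly with PyTorch; a small example of the `[b, i, k] x [b, k, j] -> [b, i, j]` contraction that `BatchedMatMulStrategyGenerator` shards:

```python
import torch

b, i, k, j = 4, 8, 16, 32
lhs = torch.rand(b, i, k)    # [b, i, k]
rhs = torch.rand(b, k, j)    # [b, k, j]

out = torch.bmm(lhs, rhs)    # function form
out_method = lhs.bmm(rhs)    # Tensor.bmm form, also handled by BMMFunctionHandler

assert out.shape == torch.Size([b, i, j])
assert torch.allclose(out, out_method)
```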
@@ -642,7 +662,6 @@ class BatchedMatMulStrategyGenerator(MatMulStrategyGenerator):
             "bias": {},
             "output": {
                 0: [mesh_dim_0],
-                -2: [mesh_dim_1]
             }
         }
         sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict)
@@ -5,13 +5,13 @@ from .sharding import (
     enumerate_all_possible_1d_sharding,
     enumerate_all_possible_2d_sharding,
     generate_sharding_size,
-    switch_partition_dim,
+    tranpose_partition_dim,
     update_partition_dim,
 )
 
 __all__ = [
     'BroadcastType', 'get_broadcast_shape', 'is_broadcastable', 'recover_sharding_spec_for_broadcast_shape',
     'generate_resharding_costs', 'generate_sharding_spec', 'ignore_sharding_exception', 'check_sharding_spec_validity'
-    'switch_partition_dim', 'update_partition_dim', 'enumerate_all_possible_1d_sharding',
+    'tranpose_partition_dim', 'update_partition_dim', 'enumerate_all_possible_1d_sharding',
     'enumerate_all_possible_2d_sharding', 'generate_sharding_size'
 ]
@@ -36,9 +36,10 @@ def ignore_sharding_exception(func):
 def check_sharding_spec_validity(sharding_spec: ShardingSpec, tensor: torch.Tensor):
     """
     This function checks whether the ShardingSpec is valid for the physical tensor.
-    This check includes 2 items:
+    This check includes 3 items:
     1. the sharding spec covers all dimensions of the physical tensor
     2. the sharding spec for each dimension is divisible by the number of devices.
+    3. the sharding spec's entire shape must match the tensor shape
     #
     """
     # make sure all dims are covered in sharding spec
@@ -65,3 +66,6 @@ def check_sharding_spec_validity(sharding_spec: ShardingSpec, tensor: torch.Tens
 
         assert dim_size >= num_devices and dim_size % num_devices == 0, \
             f'The dimension at index {i} has value {dim_size}, but it is sharded over {num_devices} devices.'
+
+    # make sure the entire shape matches the physical tensor shape
+    assert sharding_spec.entire_shape == tensor.shape
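The divisibility rule asserted above is easy to state on its own; a hedged sketch (plain functions, not the `ShardingSpec` API) of checks 2 and 3 from the updated docstring:

```python
import torch


def is_dim_shardable(dim_size: int, num_devices: int) -> bool:
    # check 2: a dimension sharded over N devices must hold at least N elements
    # and split evenly across them
    return dim_size >= num_devices and dim_size % num_devices == 0


assert is_dim_shardable(16, 4)        # 16 rows over 4 devices -> 4 rows each
assert not is_dim_shardable(6, 4)     # 6 rows cannot be split evenly over 4 devices

# check 3: the spec's recorded entire shape must equal the physical tensor shape
entire_shape = torch.Size([2, 4])
assert entire_shape == torch.empty(2, 4).shape
```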
@@ -8,12 +8,12 @@ import torch
 from colossalai.tensor.sharding_spec import ShardingSpec
 
 __all__ = [
-    'switch_partition_dim', 'update_partition_dim', 'enumerate_all_possible_1d_sharding',
+    'tranpose_partition_dim', 'update_partition_dim', 'enumerate_all_possible_1d_sharding',
     'enumerate_all_possible_2d_sharding', 'generate_sharding_size'
 ]
 
 
-def switch_partition_dim(sharding_spec: ShardingSpec, dim1: int, dim2: int) -> ShardingSpec:
+def tranpose_partition_dim(sharding_spec: ShardingSpec, dim1: int, dim2: int) -> ShardingSpec:
     """
     Switch the sharding mesh dimensions for two tensor dimensions. This operation is in-place.
 
@@ -22,19 +22,26 @@ def switch_partition_dim(sharding_spec: ShardingSpec, dim1: int, dim2: int) -> S
         dim1 (int): the tensor dimension to switch
         dim2 (int): the tensor dimension to switch
     """
-    assert len(sharding_spec.entire_shape) == 2
+    assert len(sharding_spec.entire_shape) >= 2, \
+        'The entire_shape of the sharding spec must have at least 2 dimensions'
     dim_partition_dict = sharding_spec.dim_partition_dict
 
+    # transpose the dim partition
    dim1_partition = dim_partition_dict.pop(dim1, None)
    dim2_partition = dim_partition_dict.pop(dim2, None)
 
    if dim1_partition:
        dim_partition_dict[dim2] = dim1_partition
 
    if dim2_partition:
        dim_partition_dict[dim1] = dim2_partition
 
+    # get the transposed shape
+    new_shape = list(sharding_spec.entire_shape[:])
+    new_shape[dim2], new_shape[dim1] = new_shape[dim1], new_shape[dim2]
+    new_shape = torch.Size(new_shape)
+
     # re-init the sharding spec
-    sharding_spec.__init__(sharding_spec.device_mesh, sharding_spec.entire_shape, dim_partition_dict)
+    sharding_spec.__init__(sharding_spec.device_mesh, new_shape, dim_partition_dict)
     return sharding_spec
 
 
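What `tranpose_partition_dim` now does, in plain terms: swap the partition entries of the two tensor dimensions and swap those dimensions in the recorded shape before re-initialising the spec. A hedged sketch using a bare dict and `torch.Size` instead of the `ShardingSpec` API:

```python
import torch


def transpose_partition(dim_partition_dict, shape, dim1, dim2):
    # swap which mesh axes partition dim1 and dim2
    dim1_partition = dim_partition_dict.pop(dim1, None)
    dim2_partition = dim_partition_dict.pop(dim2, None)
    if dim1_partition:
        dim_partition_dict[dim2] = dim1_partition
    if dim2_partition:
        dim_partition_dict[dim1] = dim2_partition

    # swap the two dimensions in the recorded shape as well
    new_shape = list(shape)
    new_shape[dim1], new_shape[dim2] = new_shape[dim2], new_shape[dim1]
    return dim_partition_dict, torch.Size(new_shape)


# a [16, 32] weight sharded on dim 0 over mesh axis 0 becomes sharded on the last dim
partitions, shape = transpose_partition({0: [0]}, torch.Size([16, 32]), 0, -1)
assert partitions == {-1: [0]}
assert shape == torch.Size([32, 16])
```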
@@ -2,12 +2,10 @@ import pytest
 import torch
 import torch.nn as nn
 
-from colossalai.auto_parallel.tensor_shard.node_handler.dot_handler import \
-    BMMFunctionHandler
-from colossalai.auto_parallel.tensor_shard.sharding_strategy import (OperationData, OperationDataType, StrategiesVector)
+from colossalai.auto_parallel.tensor_shard.node_handler import BMMFunctionHandler
+from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType, StrategiesVector
 from colossalai.device.device_mesh import DeviceMesh
 from colossalai.fx import ColoGraphModule, ColoTracer
 from colossalai.testing.pytest_wrapper import run_on_environment_flag
 
 
 class BMMTensorMethodModule(nn.Module):
@@ -91,6 +89,16 @@ def test_2d_device_mesh(module):
     assert 'Sb0R = Sb0Sk1 x Sb0Sk1' in strategy_name_list
     assert 'Sb1R = Sb1Sk0 x Sb1Sk0' in strategy_name_list
 
+    for strategy in strategies_vector:
+        input_sharding_spec = strategy.get_sharding_spec_by_name('x1')
+        other_sharding_spec = strategy.get_sharding_spec_by_name('x2')
+        output_sharding_spec = strategy.get_sharding_spec_by_name('bmm')
+
+        # make sure the sharding matches across different operation data
+        assert input_sharding_spec.sharding_sequence[:-1] == output_sharding_spec.sharding_sequence[:-1]
+        assert other_sharding_spec.sharding_sequence[1] == input_sharding_spec.sharding_sequence[-1]
+        assert other_sharding_spec.sharding_sequence[-1] == output_sharding_spec.sharding_sequence[-1]
+
 
 @pytest.mark.parametrize('module', [BMMTensorMethodModule, BMMTorchFunctionModule])
 def test_1d_device_mesh(module):
@@ -145,6 +153,16 @@ def test_1d_device_mesh(module):
     # one batch dim
     assert 'Sb0 = Sb0 x Sb0' in strategy_name_list
 
+    for strategy in strategies_vector:
+        input_sharding_spec = strategy.get_sharding_spec_by_name('x1')
+        other_sharding_spec = strategy.get_sharding_spec_by_name('x2')
+        output_sharding_spec = strategy.get_sharding_spec_by_name('bmm')
+
+        # make sure the sharding matches across different operation data
+        assert input_sharding_spec.sharding_sequence[:-1] == output_sharding_spec.sharding_sequence[:-1]
+        assert other_sharding_spec.sharding_sequence[1] == input_sharding_spec.sharding_sequence[-1]
+        assert other_sharding_spec.sharding_sequence[-1] == output_sharding_spec.sharding_sequence[-1]
+
 
 if __name__ == '__main__':
     test_1d_device_mesh(BMMTensorMethodModule)
@@ -5,8 +5,10 @@ from colossalai.auto_parallel.tensor_shard.node_handler.conv_handler import Conv
 from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType, StrategiesVector
 from colossalai.device.device_mesh import DeviceMesh
 from colossalai.fx import ColoGraphModule, ColoTracer
+from colossalai.testing.pytest_wrapper import run_on_environment_flag
 
 
+@run_on_environment_flag(name='AUTO_PARALLEL')
 def test_conv_module_handler():
     model = nn.Sequential(nn.Conv2d(4, 16, 3, padding=1).to('meta'))
     tracer = ColoTracer()
@@ -108,6 +110,7 @@ class ConvModel(nn.Module):
         return x
 
 
+@run_on_environment_flag(name='AUTO_PARALLEL')
 def test_conv_function_handler():
     model = ConvModel()
     tracer = ColoTracer()
@@ -1,14 +1,13 @@
 import torch
 import torch.nn as nn
 
-from colossalai.auto_parallel.tensor_shard.node_handler.conv_handler import \
-    ConvFunctionHandler
-from colossalai.auto_parallel.tensor_shard.node_handler.getitem_handler import \
-    GetItemHandler
-from colossalai.auto_parallel.tensor_shard.sharding_strategy import (OperationData, OperationDataType, StrategiesVector)
+from colossalai.auto_parallel.tensor_shard.node_handler.conv_handler import ConvFunctionHandler
+from colossalai.auto_parallel.tensor_shard.node_handler.getitem_handler import GetItemHandler
+from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType, StrategiesVector
 from colossalai.device.device_mesh import DeviceMesh
 from colossalai.fx import ColoGraphModule, ColoTracer
 from colossalai.fx.tracer.meta_patch.patched_module import linear
+from colossalai.testing.pytest_wrapper import run_on_environment_flag
 
 
 class GetItemModel(nn.Module):
@@ -22,6 +21,7 @@ class GetItemModel(nn.Module):
         return x
 
 
+@run_on_environment_flag(name='AUTO_PARALLEL')
 def test_getitem_function_handler():
     model = GetItemModel()
     tracer = ColoTracer()
@@ -1,7 +1,8 @@
 import torch
 import torch.nn as nn
+from typing_extensions import Self
 
-from colossalai.auto_parallel.tensor_shard.node_handler.dot_handler import LinearFunctionHandler, LinearModuleHandler
+from colossalai.auto_parallel.tensor_shard.node_handler import LinearFunctionHandler, LinearModuleHandler
 from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
     OperationData,
     OperationDataType,
@@ -10,10 +11,12 @@ from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
 )
 from colossalai.device.device_mesh import DeviceMesh
 from colossalai.fx import ColoGraphModule, ColoTracer
+from colossalai.testing.utils import parameterize
 
 
-def test_linear_module_handler():
-    model = nn.Sequential(nn.Linear(16, 32).to('meta'))
+@parameterize('bias', [True, False])
+def test_linear_module_handler(bias):
+    model = nn.Sequential(nn.Linear(16, 32, bias=bias).to('meta'))
 
     tracer = ColoTracer()
     graph = tracer.trace(model, meta_args={"input": torch.rand(2, 2, 4, 16).to('meta')})
@@ -50,11 +53,12 @@ def test_linear_module_handler():
     assert mapping['other'].type == OperationDataType.PARAM
     assert mapping['other'].logical_shape == torch.Size([16, 32])
 
-    assert mapping['bias'].name == "bias"
-    assert mapping['bias'].data.is_meta
-    assert mapping['bias'].data.shape == torch.Size([32])
-    assert mapping['bias'].type == OperationDataType.PARAM
-    assert mapping['bias'].logical_shape == torch.Size([32])
+    if bias:
+        assert mapping['bias'].name == "bias"
+        assert mapping['bias'].data.is_meta
+        assert mapping['bias'].data.shape == torch.Size([32])
+        assert mapping['bias'].type == OperationDataType.PARAM
+        assert mapping['bias'].logical_shape == torch.Size([32])
 
     assert mapping['output'].name == "_0"
     assert mapping['output'].data.is_meta
@@ -91,18 +95,23 @@ def test_linear_module_handler():
         strategy: ShardingStrategy
         input_sharding_spec = strategy.get_sharding_spec_by_name('input_1')
         weight_sharding_spec = strategy.get_sharding_spec_by_name('weight')
-        bias_sharding_spec = strategy.get_sharding_spec_by_name('bias')
         output_sharding_spec = strategy.get_sharding_spec_by_name('_0')
 
+        if bias:
+            bias_sharding_spec = strategy.get_sharding_spec_by_name('bias')
+
         # make sure the sharding matches across different operation data
         assert input_sharding_spec.sharding_sequence[:-1] == output_sharding_spec.sharding_sequence[:-1]
         assert weight_sharding_spec.sharding_sequence[1] == input_sharding_spec.sharding_sequence[-1]
         assert weight_sharding_spec.sharding_sequence[0] == output_sharding_spec.sharding_sequence[-1]
-        assert bias_sharding_spec.sharding_sequence[-1] == output_sharding_spec.sharding_sequence[-1]
+
+        if bias:
+            assert bias_sharding_spec.sharding_sequence[-1] == output_sharding_spec.sharding_sequence[-1]
 
 
-def test_linear_function_handler():
-    model = nn.Linear(16, 32).to('meta')
+@parameterize('bias', [True, False])
+def test_linear_function_handler(bias):
+    model = nn.Linear(16, 32, bias=bias).to('meta')
     tracer = ColoTracer()
     graph = tracer.trace(model, meta_args={"input": torch.rand(2, 2, 4, 16).to('meta')})
     gm = ColoGraphModule(model, graph)
@@ -111,7 +120,11 @@ def test_linear_function_handler():
     print(graph)
     mesh_shape = (2, 2)
     device_mesh = DeviceMesh(physical_mesh_id, mesh_shape)
-    linear_func_node = list(graph.nodes)[3]
+
+    if bias:
+        linear_func_node = list(graph.nodes)[3]
+    else:
+        linear_func_node = list(graph.nodes)[2]
     strategies_vector = StrategiesVector(linear_func_node)
 
     # build handler
@@ -120,8 +133,6 @@ def test_linear_function_handler():
     # # check operation data mapping
     mapping = handler.get_operation_data_mapping()
 
-    print(mapping['input'].logical_shape)
-
     assert mapping['input'].name == "input_1"
     assert mapping['input'].data.is_meta
     assert mapping['input'].data.shape == torch.Size([2, 2, 4, 16])
@@ -134,11 +145,12 @@ def test_linear_function_handler():
     assert mapping['other'].type == OperationDataType.PARAM
     assert mapping['other'].logical_shape == torch.Size([16, 32])
 
-    assert mapping['bias'].name == "bias"
-    assert mapping['bias'].data.is_meta
-    assert mapping['bias'].data.shape == torch.Size([32])
-    assert mapping['bias'].type == OperationDataType.PARAM
-    assert mapping['other'].logical_shape == torch.Size([16, 32])
+    if bias:
+        assert mapping['bias'].name == "bias"
+        assert mapping['bias'].data.is_meta
+        assert mapping['bias'].data.shape == torch.Size([32])
+        assert mapping['bias'].type == OperationDataType.PARAM
+        assert mapping['other'].logical_shape == torch.Size([16, 32])
 
     assert mapping['output'].name == "linear"
     assert mapping['output'].data.is_meta
@@ -172,17 +184,20 @@ def test_linear_function_handler():
 
     for strategy in strategies_vector:
         strategy: ShardingStrategy
-        print(strategy)
         input_sharding_spec = strategy.get_sharding_spec_by_name('input_1')
         weight_sharding_spec = strategy.get_sharding_spec_by_name('weight')
-        bias_sharding_spec = strategy.get_sharding_spec_by_name('bias')
         output_sharding_spec = strategy.get_sharding_spec_by_name('linear')
 
+        if bias:
+            bias_sharding_spec = strategy.get_sharding_spec_by_name('bias')
+
         # make sure the sharding matches across different operation data
         assert input_sharding_spec.sharding_sequence[:-1] == output_sharding_spec.sharding_sequence[:-1]
         assert weight_sharding_spec.sharding_sequence[1] == input_sharding_spec.sharding_sequence[-1]
         assert weight_sharding_spec.sharding_sequence[0] == output_sharding_spec.sharding_sequence[-1]
-        assert bias_sharding_spec.sharding_sequence[-1] == output_sharding_spec.sharding_sequence[-1]
+
+        if bias:
+            assert bias_sharding_spec.sharding_sequence[-1] == output_sharding_spec.sharding_sequence[-1]
 
 
 if __name__ == '__main__':
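The new `bias=False` coverage matters because `nn.Linear` simply has no `bias` parameter in that case, which is why the handlers and tests now guard every bias access. A hedged aside using plain `pytest.mark.parametrize` rather than `colossalai.testing.parameterize`:

```python
import pytest
import torch
import torch.nn as nn


@pytest.mark.parametrize('bias', [True, False])
def test_linear_bias_presence(bias):
    layer = nn.Linear(16, 32, bias=bias)
    named_parameters = dict(layer.named_parameters())
    if bias:
        assert named_parameters['bias'].shape == torch.Size([32])
    else:
        # no 'bias' entry at all, so unconditional lookups would raise KeyError
        assert 'bias' not in named_parameters
```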
@@ -1,13 +1,12 @@
 import torch
 import torch.nn as nn
 
-from colossalai.auto_parallel.tensor_shard.node_handler.conv_handler import \
-    ConvFunctionHandler
-from colossalai.auto_parallel.tensor_shard.node_handler.reshape_handler import \
-    ReshapeHandler
-from colossalai.auto_parallel.tensor_shard.sharding_strategy import (OperationData, OperationDataType, StrategiesVector)
+from colossalai.auto_parallel.tensor_shard.node_handler.conv_handler import ConvFunctionHandler
+from colossalai.auto_parallel.tensor_shard.node_handler.reshape_handler import ReshapeHandler
+from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType, StrategiesVector
 from colossalai.device.device_mesh import DeviceMesh
 from colossalai.fx import ColoGraphModule, ColoTracer
+from colossalai.testing.pytest_wrapper import run_on_environment_flag
 
 
 class ReshapeModel(nn.Module):
@@ -21,6 +20,7 @@ class ReshapeModel(nn.Module):
         return reshape_node
 
 
+@run_on_environment_flag(name='AUTO_PARALLEL')
 def test_reshape_handler():
     model = ReshapeModel()
     tracer = ColoTracer()
|
@ -1,13 +1,13 @@
|
|||
import torch
|
||||
import torch.nn as nn
|
||||
from colossalai.auto_parallel.tensor_shard.node_handler.conv_handler import \
|
||||
ConvFunctionHandler
|
||||
from colossalai.auto_parallel.tensor_shard.node_handler.unary_elementwise_handler import \
|
||||
UnaryElementwiseHandler
|
||||
from colossalai.auto_parallel.tensor_shard.sharding_strategy import (OperationData, OperationDataType, StrategiesVector)
|
||||
|
||||
from colossalai.auto_parallel.tensor_shard.node_handler.conv_handler import ConvFunctionHandler
|
||||
from colossalai.auto_parallel.tensor_shard.node_handler.unary_elementwise_handler import UnaryElementwiseHandler
|
||||
from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType, StrategiesVector
|
||||
from colossalai.device.device_mesh import DeviceMesh
|
||||
from colossalai.fx import ColoGraphModule, ColoTracer
|
||||
from colossalai.fx.tracer.meta_patch.patched_module import linear
|
||||
from colossalai.testing.pytest_wrapper import run_on_environment_flag
|
||||
|
||||
|
||||
class ReLuModel(nn.Module):
|
||||
|
@@ -22,6 +22,7 @@ class ReLuModel(nn.Module):
         return relu_node
 
 
+@run_on_environment_flag(name='AUTO_PARALLEL')
 def test_elementwise_handler():
     model = ReLuModel()
     tracer = ColoTracer()