[autoparallel] fixed broken node handler tests (#1708)

pull/1711/head
Frank Lee 2022-10-14 18:25:59 +08:00 committed by GitHub
parent 1468e4bcfc
commit 22a115406b
8 changed files with 53 additions and 49 deletions


@@ -22,10 +22,6 @@ class BatchNormStrategyGenerator(StrategyGenerator):
     In this generator, both methods will be considered.
     """
 
-    @property
-    def has_bias(self):
-        return 'bias' in self.op_data
-
     def validate(self) -> bool:
         '''
         In sanity check, we need make sure the input data having correct dimension size.


@@ -17,10 +17,6 @@ class ConvStrategyGenerator(StrategyGenerator):
     The operation data is defined as `output = input x other + bias`.
     """
 
-    @property
-    def has_bias(self):
-        return 'bias' in self.op_data
-
     def validate(self) -> bool:
         '''
         In sanity check, we need make sure the input data having correct dimension size.


@@ -19,10 +19,6 @@ class GetItemStrategyGenerator(FollowingStrategyGenerator):
     3. args_0._meta_data: Tuple[torch.Tensor], args_1._meta_data: int
     """
 
-    @property
-    def has_bias(self):
-        return 'bias' in self.op_data
-
     def validate(self) -> bool:
         return super().validate()


@@ -18,10 +18,6 @@ class LayerNormGenerator(StrategyGenerator):
     The operation data is defined as `output = input x other + bias`.
     """
 
-    @property
-    def has_bias(self):
-        return 'bias' in self.op_data
-
     def validate(self) -> bool:
         return super().validate()


@@ -14,10 +14,6 @@ class MatMulStrategyGenerator(StrategyGenerator):
     The operation data is defined as `output = input x other + bias`.
     """
 
-    @property
-    def has_bias(self):
-        return 'bias' in self.op_data
-
     def update_memory_cost(self, strategy: ShardingStrategy) -> ShardingStrategy:
         size_mapping = {
             'input': self._compute_size_in_bytes(strategy, "input"),
@@ -512,11 +508,13 @@ class BatchedMatMulStrategyGenerator(MatMulStrategyGenerator):
         sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict)
 
         # get communication actions
-        bias_comm_spec = self.get_communication_spec(
-            sharding_spec=sharding_spec_mapping['bias'],
-            communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
-            logical_process_axis=mesh_dim)
-        communication_action_mapping = {"bias": bias_comm_spec}
+        communication_action_mapping = {}
+        if self.has_bias:
+            bias_comm_spec = self.get_communication_spec(
+                sharding_spec=sharding_spec_mapping['bias'],
+                communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
+                logical_process_axis=mesh_dim)
+            communication_action_mapping['bias'] = bias_comm_spec
         return self.get_sharding_strategy(name=name,
                                           sharding_spec_mapping=sharding_spec_mapping,
                                           communication_action_mapping=communication_action_mapping)
@@ -538,11 +536,14 @@ class BatchedMatMulStrategyGenerator(MatMulStrategyGenerator):
         sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict)
 
         # get communication actions
-        bias_comm_spec = self.get_communication_spec(
-            sharding_spec=sharding_spec_mapping['bias'],
-            communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
-            logical_process_axis=[mesh_dim_0, mesh_dim_1])
-        communication_action_mapping = {"bias": bias_comm_spec}
+        communication_action_mapping = {}
+        if self.has_bias:
+            bias_comm_spec = self.get_communication_spec(
+                sharding_spec=sharding_spec_mapping['bias'],
+                communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
+                logical_process_axis=[mesh_dim_0, mesh_dim_1])
+            communication_action_mapping['bias'] = bias_comm_spec
+
         return self.get_sharding_strategy(name=name,
                                           sharding_spec_mapping=sharding_spec_mapping,
                                           communication_action_mapping=communication_action_mapping)
@@ -566,15 +567,20 @@ class BatchedMatMulStrategyGenerator(MatMulStrategyGenerator):
         sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict)
 
         # get communication actions
+        communication_action_mapping = {}
         other_comm_spec = self.get_communication_spec(
             sharding_spec=sharding_spec_mapping['other'],
             communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
             logical_process_axis=mesh_dim_1)
-        bias_comm_spec = self.get_communication_spec(
-            sharding_spec=sharding_spec_mapping['bias'],
-            communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
-            logical_process_axis=[mesh_dim_0, mesh_dim_1])
-        communication_action_mapping = {'other': other_comm_spec, 'bias': bias_comm_spec}
+        communication_action_mapping['other'] = other_comm_spec
+
+        if self.has_bias:
+            bias_comm_spec = self.get_communication_spec(
+                sharding_spec=sharding_spec_mapping['bias'],
+                communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
+                logical_process_axis=[mesh_dim_0, mesh_dim_1])
+            communication_action_mapping['bias'] = bias_comm_spec
+
         return self.get_sharding_strategy(name=name,
                                           sharding_spec_mapping=sharding_spec_mapping,
                                           communication_action_mapping=communication_action_mapping)
@@ -600,15 +606,20 @@ class BatchedMatMulStrategyGenerator(MatMulStrategyGenerator):
         sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict)
 
         # get communication actions
+        communication_action_mapping = {}
         input_comm_spec = self.get_communication_spec(
             sharding_spec=sharding_spec_mapping['input'],
             communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
             logical_process_axis=mesh_dim_1)
-        bias_comm_spec = self.get_communication_spec(
-            sharding_spec=sharding_spec_mapping['bias'],
-            communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
-            logical_process_axis=mesh_dim_0)
-        communication_action_mapping = {'input': input_comm_spec, 'bias': bias_comm_spec}
+        communication_action_mapping['input'] = input_comm_spec
+
+        if self.has_bias:
+            bias_comm_spec = self.get_communication_spec(
+                sharding_spec=sharding_spec_mapping['bias'],
+                communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
+                logical_process_axis=mesh_dim_0)
+            communication_action_mapping['bias'] = bias_comm_spec
+
         return self.get_sharding_strategy(name=name,
                                           sharding_spec_mapping=sharding_spec_mapping,
                                           communication_action_mapping=communication_action_mapping)
@@ -633,15 +644,20 @@ class BatchedMatMulStrategyGenerator(MatMulStrategyGenerator):
         sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict)
 
         # get communication actions
+        communication_action_mapping = {}
         output_comm_spec = self.get_communication_spec(
             sharding_spec=sharding_spec_mapping['output'],
             communication_pattern=CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD,
             logical_process_axis=mesh_dim_1)
-        bias_comm_spec = self.get_communication_spec(
-            sharding_spec=sharding_spec_mapping['bias'],
-            communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
-            logical_process_axis=mesh_dim_0)
-        communication_action_mapping = {'output': output_comm_spec, 'bias': bias_comm_spec}
+        communication_action_mapping['output'] = output_comm_spec
+
+        if self.has_bias:
+            bias_comm_spec = self.get_communication_spec(
+                sharding_spec=sharding_spec_mapping['bias'],
+                communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
+                logical_process_axis=mesh_dim_0)
+            communication_action_mapping['bias'] = bias_comm_spec
+
         return self.get_sharding_strategy(name=name,
                                           sharding_spec_mapping=sharding_spec_mapping,
                                           communication_action_mapping=communication_action_mapping)
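The removed lines above indexed sharding_spec_mapping['bias'] unconditionally, but a batched matmul such as torch.bmm has no bias operand, so the lookup raised a KeyError and broke the node handler tests. The replacement builds communication_action_mapping incrementally and only adds a bias entry when has_bias is true. Below is a minimal sketch of that guard, using plain dictionaries as stand-ins for the real sharding-spec and communication-spec objects; the values and helper strings are illustrative, not the actual ColossalAI API.

# Illustrative stand-ins: the real generator builds ShardingSpec objects via
# to_sharding_spec_mapping and CommSpec objects via get_communication_spec.
op_data = {'input': None, 'other': None, 'output': None}    # no 'bias' entry for torch.bmm
sharding_spec_mapping = {name: f'spec<{name}>' for name in op_data}

has_bias = 'bias' in op_data    # mirrors the shared has_bias property

communication_action_mapping = {}
if has_bias:
    # only referenced when the operand exists, so bias-free ops no longer raise KeyError
    communication_action_mapping['bias'] = f"allreduce_bwd({sharding_spec_mapping['bias']})"

print(communication_action_mapping)    # prints {} for a bias-free batched matmul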


@@ -24,6 +24,13 @@ class StrategyGenerator(ABC):
         self.op_data = operation_data_mapping
         self.device_mesh = device_mesh
 
+    @property
+    def has_bias(self):
+        """
+        A utility method to check for the existence of bias operand for convenience.
+        """
+        return 'bias' in self.op_data
+
     def is_param(self, op_data_name):
         other_data = self.op_data[op_data_name]
         return other_data.type == OperationDataType.PARAM
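With has_bias defined once on StrategyGenerator, every generator that previously carried its own copy now inherits the same check. A minimal sketch of that inheritance, assuming a simplified constructor; the real class is abstract and also receives a device mesh and other state.

class StrategyGenerator:
    # simplified stand-in for the real base class

    def __init__(self, operation_data_mapping):
        self.op_data = operation_data_mapping

    @property
    def has_bias(self):
        # shared utility: the bias operand is optional for many ops
        return 'bias' in self.op_data


class MatMulStrategyGenerator(StrategyGenerator):
    pass    # inherits has_bias instead of redefining it


gen = MatMulStrategyGenerator({'input': None, 'other': None, 'output': None})
assert gen.has_bias is False    # a torch.bmm-style call has no bias operand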


@@ -22,7 +22,6 @@ class BMMTorchFunctionModule(nn.Module):
         return torch.bmm(x1, x2)
 
 
 @run_on_environment_flag(name='AUTO_PARALLEL')
 @pytest.mark.parametrize('module', [BMMTensorMethodModule, BMMTorchFunctionModule])
 def test_2d_device_mesh(module):
@@ -93,7 +92,6 @@ def test_2d_device_mesh(module):
     assert 'Sb1R = Sb1Sk0 x Sb1Sk0' in strategy_name_list
 
 
 @run_on_environment_flag(name='AUTO_PARALLEL')
 @pytest.mark.parametrize('module', [BMMTensorMethodModule, BMMTorchFunctionModule])
 def test_1d_device_mesh(module):
     model = module()


@@ -11,7 +11,6 @@ from colossalai.fx.tracer.meta_patch.patched_module import linear
 from colossalai.testing.pytest_wrapper import run_on_environment_flag
 
 
 @run_on_environment_flag(name='AUTO_PARALLEL')
 def test_norm_pool_handler():
     model = nn.Sequential(nn.MaxPool2d(4, padding=1).to('meta'))
     tracer = ColoTracer()
@@ -50,7 +49,7 @@ def test_norm_pool_handler():
     assert mapping['output'].data.shape == torch.Size([4, 4, 16, 16])
     assert mapping['output'].type == OperationDataType.OUTPUT
 
-    strategies_vector = handler.register_strategy()
+    strategies_vector = handler.register_strategy(compute_resharding_cost=False)
     strategy_name_list = [val.name for val in strategies_vector]
     assert len(strategy_name_list) == 9
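Both test modules are gated by run_on_environment_flag(name='AUTO_PARALLEL') and only run when that flag is enabled in the environment. A rough sketch of how such a gate can be written with pytest, assuming a simple presence check on the variable; this illustrates the idea and is not the actual implementation in colossalai.testing.pytest_wrapper.

import os

import pytest


def run_on_environment_flag(name):
    # illustrative only: skip the decorated test unless the named
    # environment variable is present when pytest runs
    return pytest.mark.skipif(name not in os.environ, reason=f'{name} is not set')


@run_on_environment_flag(name='AUTO_PARALLEL')
def test_example():
    assert True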