mirror of https://github.com/hpcaitech/ColossalAI

[autoparallel] fixed broken node handler tests (#1708)

parent 1468e4bcfc
commit 22a115406b
@@ -22,10 +22,6 @@ class BatchNormStrategyGenerator(StrategyGenerator):
     In this generator, both methods will be considered.
     """
 
-    @property
-    def has_bias(self):
-        return 'bias' in self.op_data
-
     def validate(self) -> bool:
         '''
         In sanity check, we need make sure the input data having correct dimension size.
@@ -17,10 +17,6 @@ class ConvStrategyGenerator(StrategyGenerator):
     The operation data is defined as `output = input x other + bias`.
     """
 
-    @property
-    def has_bias(self):
-        return 'bias' in self.op_data
-
     def validate(self) -> bool:
         '''
         In sanity check, we need make sure the input data having correct dimension size.
@@ -19,10 +19,6 @@ class GetItemStrategyGenerator(FollowingStrategyGenerator):
     3. args_0._meta_data: Tuple[torch.Tensor], args_1._meta_data: int
     """
 
-    @property
-    def has_bias(self):
-        return 'bias' in self.op_data
-
     def validate(self) -> bool:
         return super().validate()
 
@@ -18,10 +18,6 @@ class LayerNormGenerator(StrategyGenerator):
    The operation data is defined as `output = input x other + bias`.
    """
 
-    @property
-    def has_bias(self):
-        return 'bias' in self.op_data
-
    def validate(self) -> bool:
        return super().validate()
 
@@ -14,10 +14,6 @@ class MatMulStrategyGenerator(StrategyGenerator):
     The operation data is defined as `output = input x other + bias`.
     """
 
-    @property
-    def has_bias(self):
-        return 'bias' in self.op_data
-
     def update_memory_cost(self, strategy: ShardingStrategy) -> ShardingStrategy:
         size_mapping = {
             'input': self._compute_size_in_bytes(strategy, "input"),
@@ -512,11 +508,13 @@ class BatchedMatMulStrategyGenerator(MatMulStrategyGenerator):
         sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict)
 
         # get communication actions
-        bias_comm_spec = self.get_communication_spec(
-            sharding_spec=sharding_spec_mapping['bias'],
-            communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
-            logical_process_axis=mesh_dim)
-        communication_action_mapping = {"bias": bias_comm_spec}
+        communication_action_mapping = {}
+        if self.has_bias:
+            bias_comm_spec = self.get_communication_spec(
+                sharding_spec=sharding_spec_mapping['bias'],
+                communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
+                logical_process_axis=mesh_dim)
+            communication_action_mapping['bias'] = bias_comm_spec
         return self.get_sharding_strategy(name=name,
                                           sharding_spec_mapping=sharding_spec_mapping,
                                           communication_action_mapping=communication_action_mapping)
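The change repeated across the BatchedMatMulStrategyGenerator hunks above and below is a single guard: the bias communication action is only built when a bias operand actually exists, instead of unconditionally indexing sharding_spec_mapping['bias']. Below is a minimal, self-contained sketch of that pattern; CommSpecSketch and GeneratorSketch are simplified stand-ins for the ColossalAI classes, not the real API.

from dataclasses import dataclass
from typing import Dict


@dataclass
class CommSpecSketch:
    # simplified stand-in for a communication spec: operand name, collective pattern, mesh axis
    operand: str
    pattern: str
    mesh_axis: object


class GeneratorSketch:
    """Illustrative only: mirrors the guarded bias handling introduced by this commit."""

    def __init__(self, op_data: Dict[str, object]):
        self.op_data = op_data

    @property
    def has_bias(self) -> bool:
        # same membership check the base StrategyGenerator now provides
        return 'bias' in self.op_data

    def build_comm_actions(self, mesh_dim) -> Dict[str, CommSpecSketch]:
        communication_action_mapping = {}
        # only build the bias entry when a bias operand is present; a bias-free module
        # (e.g. one wrapping torch.bmm) would otherwise hit a KeyError here
        if self.has_bias:
            communication_action_mapping['bias'] = CommSpecSketch('bias', 'IDENTITY_FWD_ALLREDUCE_BWD', mesh_dim)
        return communication_action_mapping


# usage: a bias-free op simply produces no bias communication action
print(GeneratorSketch({'input': None, 'other': None}).build_comm_actions(mesh_dim=0))    # {}

Presumably this unconditional 'bias' lookup is what broke the bmm node handler tests that this commit fixes, since torch.bmm has no bias operand.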
@@ -538,11 +536,14 @@ class BatchedMatMulStrategyGenerator(MatMulStrategyGenerator):
         sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict)
 
         # get communication actions
-        bias_comm_spec = self.get_communication_spec(
-            sharding_spec=sharding_spec_mapping['bias'],
-            communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
-            logical_process_axis=[mesh_dim_0, mesh_dim_1])
-        communication_action_mapping = {"bias": bias_comm_spec}
+        communication_action_mapping = {}
+        if self.has_bias:
+            bias_comm_spec = self.get_communication_spec(
+                sharding_spec=sharding_spec_mapping['bias'],
+                communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
+                logical_process_axis=[mesh_dim_0, mesh_dim_1])
+            communication_action_mapping['bias'] = bias_comm_spec
+
         return self.get_sharding_strategy(name=name,
                                           sharding_spec_mapping=sharding_spec_mapping,
                                           communication_action_mapping=communication_action_mapping)
@@ -566,15 +567,20 @@ class BatchedMatMulStrategyGenerator(MatMulStrategyGenerator):
         sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict)
 
         # get communication actions
+        communication_action_mapping = {}
         other_comm_spec = self.get_communication_spec(
             sharding_spec=sharding_spec_mapping['other'],
             communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
             logical_process_axis=mesh_dim_1)
-        bias_comm_spec = self.get_communication_spec(
-            sharding_spec=sharding_spec_mapping['bias'],
-            communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
-            logical_process_axis=[mesh_dim_0, mesh_dim_1])
-        communication_action_mapping = {'other': other_comm_spec, 'bias': bias_comm_spec}
+        communication_action_mapping['other'] = other_comm_spec
+
+        if self.has_bias:
+            bias_comm_spec = self.get_communication_spec(
+                sharding_spec=sharding_spec_mapping['bias'],
+                communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
+                logical_process_axis=[mesh_dim_0, mesh_dim_1])
+            communication_action_mapping['bias'] = bias_comm_spec
+
         return self.get_sharding_strategy(name=name,
                                           sharding_spec_mapping=sharding_spec_mapping,
                                           communication_action_mapping=communication_action_mapping)
@@ -600,15 +606,20 @@ class BatchedMatMulStrategyGenerator(MatMulStrategyGenerator):
         sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict)
 
         # get communication actions
+        communication_action_mapping = {}
         input_comm_spec = self.get_communication_spec(
             sharding_spec=sharding_spec_mapping['input'],
             communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
             logical_process_axis=mesh_dim_1)
-        bias_comm_spec = self.get_communication_spec(
-            sharding_spec=sharding_spec_mapping['bias'],
-            communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
-            logical_process_axis=mesh_dim_0)
-        communication_action_mapping = {'input': input_comm_spec, 'bias': bias_comm_spec}
+        communication_action_mapping['input'] = input_comm_spec
+
+        if self.has_bias:
+            bias_comm_spec = self.get_communication_spec(
+                sharding_spec=sharding_spec_mapping['bias'],
+                communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
+                logical_process_axis=mesh_dim_0)
+            communication_action_mapping['bias'] = bias_comm_spec
+
         return self.get_sharding_strategy(name=name,
                                           sharding_spec_mapping=sharding_spec_mapping,
                                           communication_action_mapping=communication_action_mapping)
@@ -633,15 +644,20 @@ class BatchedMatMulStrategyGenerator(MatMulStrategyGenerator):
         sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict)
 
         # get communication actions
+        communication_action_mapping = {}
         output_comm_spec = self.get_communication_spec(
             sharding_spec=sharding_spec_mapping['output'],
             communication_pattern=CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD,
             logical_process_axis=mesh_dim_1)
-        bias_comm_spec = self.get_communication_spec(
-            sharding_spec=sharding_spec_mapping['bias'],
-            communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
-            logical_process_axis=mesh_dim_0)
-        communication_action_mapping = {'output': output_comm_spec, 'bias': bias_comm_spec}
+        communication_action_mapping['output'] = output_comm_spec
+
+        if self.has_bias:
+            bias_comm_spec = self.get_communication_spec(
+                sharding_spec=sharding_spec_mapping['bias'],
+                communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
+                logical_process_axis=mesh_dim_0)
+            communication_action_mapping['bias'] = bias_comm_spec
+
         return self.get_sharding_strategy(name=name,
                                           sharding_spec_mapping=sharding_spec_mapping,
                                           communication_action_mapping=communication_action_mapping)
@@ -24,6 +24,13 @@ class StrategyGenerator(ABC):
         self.op_data = operation_data_mapping
         self.device_mesh = device_mesh
 
+    @property
+    def has_bias(self):
+        """
+        A utility method to check for the existence of bias operand for convenience.
+        """
+        return 'bias' in self.op_data
+
     def is_param(self, op_data_name):
         other_data = self.op_data[op_data_name]
         return other_data.type == OperationDataType.PARAM
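The hunk above is the counterpart of the removals in the per-op generators earlier in the diff: has_bias now lives once on the StrategyGenerator base class, and the subclasses no longer re-declare it. A rough, self-contained sketch of the resulting shape follows; the *Sketch class names and update_memory_cost_keys helper are illustrative stand-ins, not the real ColossalAI classes.

from abc import ABC


class StrategyGeneratorSketch(ABC):
    """Simplified stand-in for StrategyGenerator after this commit."""

    def __init__(self, operation_data_mapping, device_mesh=None):
        self.op_data = operation_data_mapping
        self.device_mesh = device_mesh

    @property
    def has_bias(self):
        # utility shared by every subclass instead of being copied into each generator
        return 'bias' in self.op_data


class MatMulStrategyGeneratorSketch(StrategyGeneratorSketch):
    # no local has_bias any more; the inherited property is used directly
    def update_memory_cost_keys(self):
        keys = ['input', 'other', 'output']
        if self.has_bias:
            keys.append('bias')
        return keys


# usage: a matmul without a bias operand reports has_bias == False
print(MatMulStrategyGeneratorSketch({'input': None, 'other': None}).has_bias)    # False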
@@ -22,7 +22,6 @@ class BMMTorchFunctionModule(nn.Module):
         return torch.bmm(x1, x2)
 
 
-
 @run_on_environment_flag(name='AUTO_PARALLEL')
 @pytest.mark.parametrize('module', [BMMTensorMethodModule, BMMTorchFunctionModule])
 def test_2d_device_mesh(module):
@@ -93,7 +92,6 @@ def test_2d_device_mesh(module):
     assert 'Sb1R = Sb1Sk0 x Sb1Sk0' in strategy_name_list
 
-
 @run_on_environment_flag(name='AUTO_PARALLEL')
 @pytest.mark.parametrize('module', [BMMTensorMethodModule, BMMTorchFunctionModule])
 def test_1d_device_mesh(module):
     model = module()
@@ -11,7 +11,6 @@ from colossalai.fx.tracer.meta_patch.patched_module import linear
 from colossalai.testing.pytest_wrapper import run_on_environment_flag
 
-
 @run_on_environment_flag(name='AUTO_PARALLEL')
 def test_norm_pool_handler():
     model = nn.Sequential(nn.MaxPool2d(4, padding=1).to('meta'))
     tracer = ColoTracer()
@@ -50,7 +49,7 @@ def test_norm_pool_handler():
     assert mapping['output'].data.shape == torch.Size([4, 4, 16, 16])
     assert mapping['output'].type == OperationDataType.OUTPUT
 
-    strategies_vector = handler.register_strategy()
+    strategies_vector = handler.register_strategy(compute_resharding_cost=False)
     strategy_name_list = [val.name for val in strategies_vector]
     assert len(strategy_name_list) == 9
 
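The test hunks all follow the same gating convention: the handler tests only run when the AUTO_PARALLEL environment flag is set, and they now pass compute_resharding_cost=False when registering strategies. The snippet below is a generic, self-contained pytest sketch of that environment-flag gate using plain pytest.mark.skipif rather than ColossalAI's run_on_environment_flag helper; the test body is a placeholder standing in for the real handler.register_strategy(...) checks.

import os

import pytest

# mimic run_on_environment_flag(name='AUTO_PARALLEL') with a plain skipif marker
requires_auto_parallel = pytest.mark.skipif(
    os.environ.get('AUTO_PARALLEL') != '1',
    reason="set AUTO_PARALLEL=1 to run auto-parallel handler tests",
)


@requires_auto_parallel
def test_strategy_count_sketch():
    # placeholder: the real test asserts that nine strategies are registered for MaxPool2d
    strategy_name_list = [f'strategy_{i}' for i in range(9)]
    assert len(strategy_name_list) == 9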