mirror of https://github.com/hpcaitech/ColossalAI
[autoparallel] added bias comm spec to matmul strategy (#1664)
parent
746f8f979d
commit
247a9dbca9
|
@ -1,3 +1,4 @@
|
|||
from audioop import bias
|
||||
import operator
|
||||
from functools import reduce
|
||||
from ..sharding_strategy import ShardingStrategy_V2, TrainCycleItem, MemoryCost
|
||||
|
@ -121,7 +122,7 @@ class MatVecStrategyGenerator(MatMulStrategyGenerator):
|
|||
name = f'S{mesh_dim}R = S{mesh_dim}R x R'
|
||||
|
||||
# get sharding spec
|
||||
dim_partition_dict = {"input": {0: [mesh_dim]}, "other": {}, "output": {0: [mesh_dim]}}
|
||||
dim_partition_dict = {"input": {0: [mesh_dim]}, "other": {}, "output": {0: [mesh_dim]}, "bias": {}}
|
||||
sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict)
|
||||
|
||||
# get communication action
|
||||
|
@ -129,7 +130,11 @@ class MatVecStrategyGenerator(MatMulStrategyGenerator):
|
|||
sharding_spec=sharding_spec_mapping['other'],
|
||||
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
|
||||
logical_process_axis=mesh_dim)
|
||||
communication_action_mapping = {'other': other_comm_spec}
|
||||
bias_comm_spec = self.get_communication_spec(
|
||||
sharding_spec=sharding_spec_mapping['bias'],
|
||||
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
|
||||
logical_process_axis=mesh_dim)
|
||||
communication_action_mapping = {'other': other_comm_spec, 'bias': bias_comm_spec}
|
||||
return self.get_sharding_strategy(name=name,
|
||||
sharding_spec_mapping=sharding_spec_mapping,
|
||||
communication_action_mapping=communication_action_mapping)
|
||||
|
@ -236,8 +241,12 @@ class LinearProjectionStrategyGenerator(MatMulStrategyGenerator):
|
|||
sharding_spec_mapping["output"],
|
||||
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
|
||||
logical_process_axis=mesh_dim_0)
|
||||
bias_comm_spec = self.get_communication_spec(
|
||||
sharding_spec_mapping["bias"],
|
||||
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
|
||||
logical_process_axis=mesh_dim_0)
|
||||
|
||||
communication_action_mapping = {"input": input_comm_spec, "other": other_comm_spec}
|
||||
communication_action_mapping = {"input": input_comm_spec, "other": other_comm_spec, "bias": bias_comm_spec}
|
||||
|
||||
return self.get_sharding_strategy(name=name,
|
||||
sharding_spec_mapping=sharding_spec_mapping,
|
||||
|
@ -272,8 +281,12 @@ class LinearProjectionStrategyGenerator(MatMulStrategyGenerator):
|
|||
sharding_spec=sharding_spec_mapping["output"],
|
||||
communication_pattern=CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD,
|
||||
logical_process_axis=mesh_dim_1)
|
||||
bias_comm_spec = self.get_communication_spec(
|
||||
sharding_spec=sharding_spec_mapping["bias"],
|
||||
communication_pattern=CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD,
|
||||
logical_process_axis=mesh_dim_1)
|
||||
|
||||
communication_action_mapping = {"input": input_comm_spec, 'output': output_comm_spec}
|
||||
communication_action_mapping = {"input": input_comm_spec, 'output': output_comm_spec, 'bias': bias_comm_spec}
|
||||
|
||||
return self.get_sharding_strategy(name=name,
|
||||
sharding_spec_mapping=sharding_spec_mapping,
|
||||
|
@ -390,8 +403,12 @@ class LinearProjectionStrategyGenerator(MatMulStrategyGenerator):
|
|||
sharding_spec=sharding_spec_mapping['other'],
|
||||
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
|
||||
logical_process_axis=[mesh_dim_0, mesh_dim_1])
|
||||
bias_comm_spec = self.get_communication_spec(
|
||||
sharding_spec=sharding_spec_mapping['bias'],
|
||||
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
|
||||
logical_process_axis=[mesh_dim_0, mesh_dim_1])
|
||||
|
||||
communcation_action_mapping = {"other": other_comm_spec}
|
||||
communcation_action_mapping = {"other": other_comm_spec, "bias": bias_comm_spec}
|
||||
return self.get_sharding_strategy(name=name,
|
||||
sharding_spec_mapping=sharding_spec_mapping,
|
||||
communication_action_mapping=communcation_action_mapping)
|
||||
|
@ -486,40 +503,22 @@ class BatchedMatMulStrategyGenerator(MatMulStrategyGenerator):
|
|||
def update_compute_cost(self, strategy: ShardingStrategy_V2) -> ShardingStrategy_V2:
|
||||
return self.op_data['input'].data.shape[-1] * reduce(operator.mul, self.op_data['output'].data.shape)
|
||||
|
||||
def split_one_batch_dim(self):
|
||||
device_mesh_is_1d = True
|
||||
if len(self.device_mesh.mesh_shape) == 1:
|
||||
mesh_dim = 0
|
||||
elif 1 in self.device_mesh.mesh_shape:
|
||||
mesh_dim = self.device_mesh.mesh_shape.index(1)
|
||||
else:
|
||||
device_mesh_is_1d = False
|
||||
def split_one_batch_dim(self, mesh_dim):
|
||||
name = f'Sb{mesh_dim} = Sb{mesh_dim} x Sb{mesh_dim}'
|
||||
|
||||
if device_mesh_is_1d:
|
||||
name = f'Sb{mesh_dim} = Sb{mesh_dim} x Sb{mesh_dim}'
|
||||
# get sharding_spec
|
||||
dim_partition_dict = {"input": {0: [mesh_dim]}, "other": {0: [mesh_dim]}, "bias": {}, "output": {0: [mesh_dim]}}
|
||||
sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict)
|
||||
|
||||
# get sharding_spec
|
||||
dim_partition_dict = {
|
||||
"input": {
|
||||
0: [mesh_dim]
|
||||
},
|
||||
"other": {
|
||||
0: [mesh_dim]
|
||||
},
|
||||
"bias": {},
|
||||
"output": {
|
||||
0: [mesh_dim]
|
||||
}
|
||||
}
|
||||
sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict)
|
||||
|
||||
# get communication actions
|
||||
communication_action_mapping = {}
|
||||
return self.get_sharding_strategy(name=name,
|
||||
sharding_spec_mapping=sharding_spec_mapping,
|
||||
communication_action_mapping=communication_action_mapping)
|
||||
else:
|
||||
return None
|
||||
# get communication actions
|
||||
bias_comm_spec = self.get_communication_spec(
|
||||
sharding_spec=sharding_spec_mapping['bias'],
|
||||
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
|
||||
logical_process_axis=mesh_dim)
|
||||
communication_action_mapping = {"bias": bias_comm_spec}
|
||||
return self.get_sharding_strategy(name=name,
|
||||
sharding_spec_mapping=sharding_spec_mapping,
|
||||
communication_action_mapping=communication_action_mapping)
|
||||
|
||||
def split_two_batch_dim(self, mesh_dim_0, mesh_dim_1):
|
||||
name = f'Sb{mesh_dim_0}{mesh_dim_1} = Sb{mesh_dim_0}{mesh_dim_1} x Sb{mesh_dim_0}{mesh_dim_1}'
|
||||
|
@ -538,7 +537,11 @@ class BatchedMatMulStrategyGenerator(MatMulStrategyGenerator):
|
|||
sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict)
|
||||
|
||||
# get communication actions
|
||||
communication_action_mapping = {}
|
||||
bias_comm_spec = self.get_communication_spec(
|
||||
sharding_spec=sharding_spec_mapping['bias'],
|
||||
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
|
||||
logical_process_axis=[mesh_dim_0, mesh_dim_1])
|
||||
communication_action_mapping = {"bias": bias_comm_spec}
|
||||
return self.get_sharding_strategy(name=name,
|
||||
sharding_spec_mapping=sharding_spec_mapping,
|
||||
communication_action_mapping=communication_action_mapping)
|
||||
|
@ -566,7 +569,11 @@ class BatchedMatMulStrategyGenerator(MatMulStrategyGenerator):
|
|||
sharding_spec=sharding_spec_mapping['other'],
|
||||
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
|
||||
logical_process_axis=mesh_dim_1)
|
||||
communication_action_mapping = {'other': other_comm_spec}
|
||||
bias_comm_spec = self.get_communication_spec(
|
||||
sharding_spec=sharding_spec_mapping['bias'],
|
||||
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
|
||||
logical_process_axis=[mesh_dim_0, mesh_dim_1])
|
||||
communication_action_mapping = {'other': other_comm_spec, 'bias': bias_comm_spec}
|
||||
return self.get_sharding_strategy(name=name,
|
||||
sharding_spec_mapping=sharding_spec_mapping,
|
||||
communication_action_mapping=communication_action_mapping)
|
||||
|
@ -596,7 +603,11 @@ class BatchedMatMulStrategyGenerator(MatMulStrategyGenerator):
|
|||
sharding_spec=sharding_spec_mapping['input'],
|
||||
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
|
||||
logical_process_axis=mesh_dim_1)
|
||||
communication_action_mapping = {'input': input_comm_spec}
|
||||
bias_comm_spec = self.get_communication_spec(
|
||||
sharding_spec=sharding_spec_mapping['bias'],
|
||||
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
|
||||
logical_process_axis=mesh_dim_0)
|
||||
communication_action_mapping = {'input': input_comm_spec, 'bias': bias_comm_spec}
|
||||
return self.get_sharding_strategy(name=name,
|
||||
sharding_spec_mapping=sharding_spec_mapping,
|
||||
communication_action_mapping=communication_action_mapping)
|
||||
|
@ -625,21 +636,31 @@ class BatchedMatMulStrategyGenerator(MatMulStrategyGenerator):
|
|||
sharding_spec=sharding_spec_mapping['output'],
|
||||
communication_pattern=CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD,
|
||||
logical_process_axis=mesh_dim_1)
|
||||
communication_action_mapping = {'output': output_comm_spec}
|
||||
bias_comm_spec = self.get_communication_spec(
|
||||
sharding_spec=sharding_spec_mapping['bias'],
|
||||
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
|
||||
logical_process_axis=mesh_dim_0)
|
||||
communication_action_mapping = {'output': output_comm_spec, 'bias': bias_comm_spec}
|
||||
return self.get_sharding_strategy(name=name,
|
||||
sharding_spec_mapping=sharding_spec_mapping,
|
||||
communication_action_mapping=communication_action_mapping)
|
||||
|
||||
def generate(self) -> List[ShardingStrategy_V2]:
|
||||
strategy_list = []
|
||||
device_mesh_is_1d = True
|
||||
if len(self.device_mesh.mesh_shape) == 2 and 1 not in self.device_mesh.mesh_shape:
|
||||
device_mesh_is_1d = False
|
||||
|
||||
# split only the batch dimension
|
||||
# Sb = Sb x Sb
|
||||
# can be None as it is only for 1D device mesh
|
||||
strategy = self.split_one_batch_dim()
|
||||
if strategy:
|
||||
if device_mesh_is_1d:
|
||||
# split only the batch dimension
|
||||
# Sb = Sb x Sb
|
||||
# can be None as it is only for 1D device mesh
|
||||
# only for 1D device mesh
|
||||
strategy_list.append(strategy)
|
||||
if len(self.device_mesh.mesh_shape) == 1:
|
||||
mesh_dim = 0
|
||||
else:
|
||||
mesh_dim = self.device_mesh.mesh_shape.index(1)
|
||||
strategy_list.append(self.split_one_batch_dim(mesh_dim))
|
||||
else:
|
||||
# for 2D device mesh
|
||||
# split batch dim of two inputs and the i dim of the first tensor
|
||||
|
|
Loading…
Reference in New Issue