mirror of https://github.com/hpcaitech/ColossalAI
[autoparallel] added bias comm spec to matmul strategy (#1664)
parent 746f8f979d
commit 247a9dbca9
@@ -1,3 +1,4 @@
+from audioop import bias
 import operator
 from functools import reduce
 from ..sharding_strategy import ShardingStrategy_V2, TrainCycleItem, MemoryCost
@@ -121,7 +122,7 @@ class MatVecStrategyGenerator(MatMulStrategyGenerator):
         name = f'S{mesh_dim}R = S{mesh_dim}R x R'

         # get sharding spec
-        dim_partition_dict = {"input": {0: [mesh_dim]}, "other": {}, "output": {0: [mesh_dim]}}
+        dim_partition_dict = {"input": {0: [mesh_dim]}, "other": {}, "output": {0: [mesh_dim]}, "bias": {}}
         sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict)

         # get communication action
@@ -129,7 +130,11 @@ class MatVecStrategyGenerator(MatMulStrategyGenerator):
             sharding_spec=sharding_spec_mapping['other'],
             communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
             logical_process_axis=mesh_dim)
-        communication_action_mapping = {'other': other_comm_spec}
+        bias_comm_spec = self.get_communication_spec(
+            sharding_spec=sharding_spec_mapping['bias'],
+            communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
+            logical_process_axis=mesh_dim)
+        communication_action_mapping = {'other': other_comm_spec, 'bias': bias_comm_spec}
         return self.get_sharding_strategy(name=name,
                                           sharding_spec_mapping=sharding_spec_mapping,
                                           communication_action_mapping=communication_action_mapping)
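
The recurring change in this commit pairs the replicated bias with an IDENTITY_FWD_ALLREDUCE_BWD action: no communication in the forward pass, and an all-reduce of the bias gradient in the backward pass. The reason is that a parameter kept replicated while the data it touches is sharded accumulates only a partial gradient on each device. A minimal single-process check of that arithmetic, with the "ranks" simulated by slicing the batch (plain PyTorch; illustrative, not code from this repo):

```python
import torch

x = torch.randn(4, 8)
w = torch.randn(8, 3)

# full-batch bias gradient
b_full = torch.zeros(3, requires_grad=True)
(x @ w + b_full).sum().backward()

# "shard" the batch across two simulated ranks; each rank sees only a
# partial bias gradient, and summing them is exactly what the backward
# all-reduce performs
partial_sum = torch.zeros(3)
for shard in x.chunk(2, dim=0):
    b_shard = torch.zeros(3, requires_grad=True)
    (shard @ w + b_shard).sum().backward()
    partial_sum += b_shard.grad

assert torch.allclose(b_full.grad, partial_sum)
```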
@@ -236,8 +241,12 @@ class LinearProjectionStrategyGenerator(MatMulStrategyGenerator):
             sharding_spec_mapping["output"],
             communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
             logical_process_axis=mesh_dim_0)
+        bias_comm_spec = self.get_communication_spec(
+            sharding_spec_mapping["bias"],
+            communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
+            logical_process_axis=mesh_dim_0)

-        communication_action_mapping = {"input": input_comm_spec, "other": other_comm_spec}
+        communication_action_mapping = {"input": input_comm_spec, "other": other_comm_spec, "bias": bias_comm_spec}

         return self.get_sharding_strategy(name=name,
                                           sharding_spec_mapping=sharding_spec_mapping,
@@ -272,8 +281,12 @@ class LinearProjectionStrategyGenerator(MatMulStrategyGenerator):
             sharding_spec=sharding_spec_mapping["output"],
             communication_pattern=CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD,
             logical_process_axis=mesh_dim_1)
+        bias_comm_spec = self.get_communication_spec(
+            sharding_spec=sharding_spec_mapping["bias"],
+            communication_pattern=CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD,
+            logical_process_axis=mesh_dim_1)

-        communication_action_mapping = {"input": input_comm_spec, 'output': output_comm_spec}
+        communication_action_mapping = {"input": input_comm_spec, 'output': output_comm_spec, 'bias': bias_comm_spec}

         return self.get_sharding_strategy(name=name,
                                           sharding_spec_mapping=sharding_spec_mapping,
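
The two collective patterns appearing in these hunks are mirror images: IDENTITY_FWD_ALLREDUCE_BWD does nothing in forward and all-reduces gradients in backward (used for operands kept replicated), while ALLREDUCE_FWD_IDENTITY_BWD all-reduces activations in forward and does nothing in backward (used where the forward matmul produces partial sums). A sketch of how the first pattern can be expressed as an autograd function, assuming an initialized torch.distributed process group; the class name is illustrative, not the repo's implementation:

```python
import torch
import torch.distributed as dist

class IdentityFwdAllReduceBwd(torch.autograd.Function):
    """Forward is a no-op; backward sum-reduces the gradient across ranks."""

    @staticmethod
    def forward(ctx, tensor):
        # the tensor is already correctly placed, so nothing to do here
        return tensor

    @staticmethod
    def backward(ctx, grad_output):
        # each rank holds a partial gradient; summing yields the full one
        dist.all_reduce(grad_output, op=dist.ReduceOp.SUM)
        return grad_output
```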
@@ -390,8 +403,12 @@ class LinearProjectionStrategyGenerator(MatMulStrategyGenerator):
             sharding_spec=sharding_spec_mapping['other'],
             communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
             logical_process_axis=[mesh_dim_0, mesh_dim_1])
+        bias_comm_spec = self.get_communication_spec(
+            sharding_spec=sharding_spec_mapping['bias'],
+            communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
+            logical_process_axis=[mesh_dim_0, mesh_dim_1])

-        communication_action_mapping = {"other": other_comm_spec}
+        communication_action_mapping = {"other": other_comm_spec, "bias": bias_comm_spec}
         return self.get_sharding_strategy(name=name,
                                           sharding_spec_mapping=sharding_spec_mapping,
                                           communication_action_mapping=communication_action_mapping)
@@ -486,40 +503,22 @@ class BatchedMatMulStrategyGenerator(MatMulStrategyGenerator):
     def update_compute_cost(self, strategy: ShardingStrategy_V2) -> ShardingStrategy_V2:
         return self.op_data['input'].data.shape[-1] * reduce(operator.mul, self.op_data['output'].data.shape)

-    def split_one_batch_dim(self):
-        device_mesh_is_1d = True
-        if len(self.device_mesh.mesh_shape) == 1:
-            mesh_dim = 0
-        elif 1 in self.device_mesh.mesh_shape:
-            mesh_dim = self.device_mesh.mesh_shape.index(1)
-        else:
-            device_mesh_is_1d = False
-
-        if device_mesh_is_1d:
-            name = f'Sb{mesh_dim} = Sb{mesh_dim} x Sb{mesh_dim}'
-
-            # get sharding_spec
-            dim_partition_dict = {
-                "input": {
-                    0: [mesh_dim]
-                },
-                "other": {
-                    0: [mesh_dim]
-                },
-                "bias": {},
-                "output": {
-                    0: [mesh_dim]
-                }
-            }
-            sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict)
-
-            # get communication actions
-            communication_action_mapping = {}
-            return self.get_sharding_strategy(name=name,
-                                              sharding_spec_mapping=sharding_spec_mapping,
-                                              communication_action_mapping=communication_action_mapping)
-        else:
-            return None
+    def split_one_batch_dim(self, mesh_dim):
+        name = f'Sb{mesh_dim} = Sb{mesh_dim} x Sb{mesh_dim}'
+
+        # get sharding_spec
+        dim_partition_dict = {"input": {0: [mesh_dim]}, "other": {0: [mesh_dim]}, "bias": {}, "output": {0: [mesh_dim]}}
+        sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict)
+
+        # get communication actions
+        bias_comm_spec = self.get_communication_spec(
+            sharding_spec=sharding_spec_mapping['bias'],
+            communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
+            logical_process_axis=mesh_dim)
+        communication_action_mapping = {"bias": bias_comm_spec}
+        return self.get_sharding_strategy(name=name,
+                                          sharding_spec_mapping=sharding_spec_mapping,
+                                          communication_action_mapping=communication_action_mapping)

     def split_two_batch_dim(self, mesh_dim_0, mesh_dim_1):
         name = f'Sb{mesh_dim_0}{mesh_dim_1} = Sb{mesh_dim_0}{mesh_dim_1} x Sb{mesh_dim_0}{mesh_dim_1}'
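
For readers unfamiliar with the notation: dim_partition_dict maps each operand name to a dict from tensor-dimension index to the list of device-mesh axes that dimension is sharded over, and an empty dict means the operand is fully replicated. A hedged illustration of the Sb = Sb x Sb layout above, with hypothetical shapes on a 1D mesh of 4 devices (not code from the repo):

```python
# {0: [0]} shards tensor dim 0 (the batch) across mesh axis 0; {} replicates.
dim_partition_dict = {
    "input": {0: [0]},     # e.g. (16, 32, 64) -> (4, 32, 64) per device
    "other": {0: [0]},     # e.g. (16, 64, 8)  -> (4, 64, 8)  per device
    "bias": {},            # e.g. (8,) kept whole on every device
    "output": {0: [0]},    # e.g. (16, 32, 8)  -> (4, 32, 8)  per device
}
```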
@@ -538,7 +537,11 @@ class BatchedMatMulStrategyGenerator(MatMulStrategyGenerator):
         sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict)

         # get communication actions
-        communication_action_mapping = {}
+        bias_comm_spec = self.get_communication_spec(
+            sharding_spec=sharding_spec_mapping['bias'],
+            communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
+            logical_process_axis=[mesh_dim_0, mesh_dim_1])
+        communication_action_mapping = {"bias": bias_comm_spec}
         return self.get_sharding_strategy(name=name,
                                           sharding_spec_mapping=sharding_spec_mapping,
                                           communication_action_mapping=communication_action_mapping)
@@ -566,7 +569,11 @@ class BatchedMatMulStrategyGenerator(MatMulStrategyGenerator):
             sharding_spec=sharding_spec_mapping['other'],
             communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
             logical_process_axis=mesh_dim_1)
-        communication_action_mapping = {'other': other_comm_spec}
+        bias_comm_spec = self.get_communication_spec(
+            sharding_spec=sharding_spec_mapping['bias'],
+            communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
+            logical_process_axis=[mesh_dim_0, mesh_dim_1])
+        communication_action_mapping = {'other': other_comm_spec, 'bias': bias_comm_spec}
         return self.get_sharding_strategy(name=name,
                                           sharding_spec_mapping=sharding_spec_mapping,
                                           communication_action_mapping=communication_action_mapping)
@@ -596,7 +603,11 @@ class BatchedMatMulStrategyGenerator(MatMulStrategyGenerator):
             sharding_spec=sharding_spec_mapping['input'],
             communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
             logical_process_axis=mesh_dim_1)
-        communication_action_mapping = {'input': input_comm_spec}
+        bias_comm_spec = self.get_communication_spec(
+            sharding_spec=sharding_spec_mapping['bias'],
+            communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
+            logical_process_axis=mesh_dim_0)
+        communication_action_mapping = {'input': input_comm_spec, 'bias': bias_comm_spec}
         return self.get_sharding_strategy(name=name,
                                           sharding_spec_mapping=sharding_spec_mapping,
                                           communication_action_mapping=communication_action_mapping)
@@ -625,21 +636,31 @@ class BatchedMatMulStrategyGenerator(MatMulStrategyGenerator):
             sharding_spec=sharding_spec_mapping['output'],
             communication_pattern=CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD,
             logical_process_axis=mesh_dim_1)
-        communication_action_mapping = {'output': output_comm_spec}
+        bias_comm_spec = self.get_communication_spec(
+            sharding_spec=sharding_spec_mapping['bias'],
+            communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
+            logical_process_axis=mesh_dim_0)
+        communication_action_mapping = {'output': output_comm_spec, 'bias': bias_comm_spec}
         return self.get_sharding_strategy(name=name,
                                           sharding_spec_mapping=sharding_spec_mapping,
                                           communication_action_mapping=communication_action_mapping)

     def generate(self) -> List[ShardingStrategy_V2]:
         strategy_list = []
-        # split only the batch dimension
-        # Sb = Sb x Sb
-        # can be None as it is only for 1D device mesh
-        strategy = self.split_one_batch_dim()
-        if strategy:
-            # only for 1D device mesh
-            strategy_list.append(strategy)
+        device_mesh_is_1d = True
+        if len(self.device_mesh.mesh_shape) == 2 and 1 not in self.device_mesh.mesh_shape:
+            device_mesh_is_1d = False
+
+        if device_mesh_is_1d:
+            # split only the batch dimension
+            # Sb = Sb x Sb
+            # can be None as it is only for 1D device mesh
+            if len(self.device_mesh.mesh_shape) == 1:
+                mesh_dim = 0
+            else:
+                mesh_dim = self.device_mesh.mesh_shape.index(1)
+            # only for 1D device mesh
+            strategy_list.append(self.split_one_batch_dim(mesh_dim))
         else:
             # for 2D device mesh
             # split batch dim of two inputs and the i dim of the first tensor
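
The mesh-dimension selection that generate() now performs before calling split_one_batch_dim can be sanity-checked in isolation. A small sketch with a stub standing in for the repo's DeviceMesh (illustrative only):

```python
class StubMesh:
    """Stand-in for DeviceMesh; only mesh_shape matters here."""

    def __init__(self, mesh_shape):
        self.mesh_shape = mesh_shape

def pick_batch_mesh_dim(device_mesh):
    # mirrors the branch added to generate(): the mesh is effectively 1D
    # when it has a single axis, or when one of its two axes has size 1
    if len(device_mesh.mesh_shape) == 1:
        return 0
    if 1 in device_mesh.mesh_shape:
        return device_mesh.mesh_shape.index(1)
    return None    # genuinely 2D: use the two-batch-dim strategies instead

assert pick_batch_mesh_dim(StubMesh((8,))) == 0
assert pick_batch_mesh_dim(StubMesh((1, 4))) == 0
assert pick_batch_mesh_dim(StubMesh((4, 1))) == 1
assert pick_batch_mesh_dim(StubMesh((2, 4))) is None
```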