[autoparallel] added bias comm spec to matmul strategy (#1664)

pull/1669/head
Frank Lee 2022-09-29 11:08:05 +08:00 committed by GitHub
parent 746f8f979d
commit 247a9dbca9
1 changed file with 68 additions and 47 deletions
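
Note: the change applies one pattern throughout the file: every matmul sharding strategy that already builds communication specs for its operands now also builds one for the bias and registers it in communication_action_mapping. A minimal sketch of that pattern, with stand-in definitions (CollectiveCommPattern and get_communication_spec here only mirror the names used in the diff and are not the library's real API):

    from enum import Enum, auto

    class CollectiveCommPattern(Enum):
        IDENTITY_FWD_ALLREDUCE_BWD = auto()
        ALLREDUCE_FWD_IDENTITY_BWD = auto()

    def get_communication_spec(sharding_spec, communication_pattern, logical_process_axis):
        # stand-in: the real generator method returns a CommSpec-like object
        return (sharding_spec, communication_pattern, logical_process_axis)

    def build_comm_actions(sharding_spec_mapping, mesh_dim):
        # before this commit only 'other' (the weight) received a comm spec here;
        # the replicated bias now gets the same identity-fwd / all-reduce-bwd spec
        other_comm_spec = get_communication_spec(
            sharding_spec=sharding_spec_mapping['other'],
            communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
            logical_process_axis=mesh_dim)
        bias_comm_spec = get_communication_spec(
            sharding_spec=sharding_spec_mapping['bias'],
            communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
            logical_process_axis=mesh_dim)
        return {'other': other_comm_spec, 'bias': bias_comm_spec}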

@@ -1,3 +1,4 @@
+from audioop import bias
import operator
from functools import reduce
from ..sharding_strategy import ShardingStrategy_V2, TrainCycleItem, MemoryCost
@@ -121,7 +122,7 @@ class MatVecStrategyGenerator(MatMulStrategyGenerator):
name = f'S{mesh_dim}R = S{mesh_dim}R x R'
# get sharding spec
-dim_partition_dict = {"input": {0: [mesh_dim]}, "other": {}, "output": {0: [mesh_dim]}}
+dim_partition_dict = {"input": {0: [mesh_dim]}, "other": {}, "output": {0: [mesh_dim]}, "bias": {}}
sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict)
# get communication action
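
Note on the dim_partition_dict change above: the dict maps each operand to {tensor dim: [device-mesh axes it is sharded over]}, and an empty dict means the operand is fully replicated, so the new "bias": {} entry declares a replicated bias. An illustrative reading with hypothetical shapes (not taken from the commit):

    # with a 1D device mesh of 4 devices and an input of shape (8, 16)
    dim_partition_dict = {
        "input": {0: [0]},   # shard input dim 0 (the batch) over mesh axis 0
        "other": {},         # weight replicated on every device
        "output": {0: [0]},  # output keeps the batch sharding
        "bias": {},          # bias replicated as well (the entry added by this commit)
    }
    # each device then holds a (2, 16) slice of the input plus the full weight and bias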
@@ -129,7 +130,11 @@ class MatVecStrategyGenerator(MatMulStrategyGenerator):
sharding_spec=sharding_spec_mapping['other'],
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=mesh_dim)
-communication_action_mapping = {'other': other_comm_spec}
+bias_comm_spec = self.get_communication_spec(
+sharding_spec=sharding_spec_mapping['bias'],
+communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
+logical_process_axis=mesh_dim)
+communication_action_mapping = {'other': other_comm_spec, 'bias': bias_comm_spec}
return self.get_sharding_strategy(name=name,
sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping=communication_action_mapping)
@@ -236,8 +241,12 @@ class LinearProjectionStrategyGenerator(MatMulStrategyGenerator):
sharding_spec_mapping["output"],
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=mesh_dim_0)
+bias_comm_spec = self.get_communication_spec(
+sharding_spec_mapping["bias"],
+communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
+logical_process_axis=mesh_dim_0)
-communication_action_mapping = {"input": input_comm_spec, "other": other_comm_spec}
+communication_action_mapping = {"input": input_comm_spec, "other": other_comm_spec, "bias": bias_comm_spec}
return self.get_sharding_strategy(name=name,
sharding_spec_mapping=sharding_spec_mapping,
@@ -272,8 +281,12 @@ class LinearProjectionStrategyGenerator(MatMulStrategyGenerator):
sharding_spec=sharding_spec_mapping["output"],
communication_pattern=CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD,
logical_process_axis=mesh_dim_1)
+bias_comm_spec = self.get_communication_spec(
+sharding_spec=sharding_spec_mapping["bias"],
+communication_pattern=CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD,
+logical_process_axis=mesh_dim_1)
-communication_action_mapping = {"input": input_comm_spec, 'output': output_comm_spec}
+communication_action_mapping = {"input": input_comm_spec, 'output': output_comm_spec, 'bias': bias_comm_spec}
return self.get_sharding_strategy(name=name,
sharding_spec_mapping=sharding_spec_mapping,
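
Note on the two collective patterns used for the bias specs (a reading of the pattern names and of standard tensor parallelism, not text from the commit): IDENTITY_FWD_ALLREDUCE_BWD fits a replicated operand such as the bias, whose forward pass needs no communication but whose gradient must be summed over the mesh axis, while ALLREDUCE_FWD_IDENTITY_BWD fits a value produced as partial sums on each device. A toy sketch with plain lists standing in for per-device values:

    def all_reduce(per_device_values):
        # toy all-reduce: every device ends up with the sum
        total = sum(per_device_values)
        return [total for _ in per_device_values]

    # IDENTITY_FWD_ALLREDUCE_BWD (replicated bias): forward is a no-op,
    # backward sums the per-device bias gradients.
    bias_grads = [1.0, 3.0]
    assert all_reduce(bias_grads) == [4.0, 4.0]

    # ALLREDUCE_FWD_IDENTITY_BWD (partial result): forward sums the partial
    # outputs of the sharded contraction, backward needs no extra communication.
    partial_outputs = [1.0, 2.0]
    assert all_reduce(partial_outputs) == [3.0, 3.0]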
@@ -390,8 +403,12 @@ class LinearProjectionStrategyGenerator(MatMulStrategyGenerator):
sharding_spec=sharding_spec_mapping['other'],
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=[mesh_dim_0, mesh_dim_1])
+bias_comm_spec = self.get_communication_spec(
+sharding_spec=sharding_spec_mapping['bias'],
+communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
+logical_process_axis=[mesh_dim_0, mesh_dim_1])
-communcation_action_mapping = {"other": other_comm_spec}
+communcation_action_mapping = {"other": other_comm_spec, "bias": bias_comm_spec}
return self.get_sharding_strategy(name=name,
sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping=communcation_action_mapping)
@@ -486,40 +503,22 @@ class BatchedMatMulStrategyGenerator(MatMulStrategyGenerator):
def update_compute_cost(self, strategy: ShardingStrategy_V2) -> ShardingStrategy_V2:
return self.op_data['input'].data.shape[-1] * reduce(operator.mul, self.op_data['output'].data.shape)
-def split_one_batch_dim(self):
-device_mesh_is_1d = True
-if len(self.device_mesh.mesh_shape) == 1:
-mesh_dim = 0
-elif 1 in self.device_mesh.mesh_shape:
-mesh_dim = self.device_mesh.mesh_shape.index(1)
-else:
-device_mesh_is_1d = False
+def split_one_batch_dim(self, mesh_dim):
+name = f'Sb{mesh_dim} = Sb{mesh_dim} x Sb{mesh_dim}'
-if device_mesh_is_1d:
-name = f'Sb{mesh_dim} = Sb{mesh_dim} x Sb{mesh_dim}'
+# get sharding_spec
+dim_partition_dict = {"input": {0: [mesh_dim]}, "other": {0: [mesh_dim]}, "bias": {}, "output": {0: [mesh_dim]}}
+sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict)
-# get sharding_spec
-dim_partition_dict = {
-"input": {
-0: [mesh_dim]
-},
-"other": {
-0: [mesh_dim]
-},
-"bias": {},
-"output": {
-0: [mesh_dim]
-}
-}
-sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict)
-# get communication actions
-communication_action_mapping = {}
-return self.get_sharding_strategy(name=name,
-sharding_spec_mapping=sharding_spec_mapping,
-communication_action_mapping=communication_action_mapping)
-else:
-return None
+# get communication actions
+bias_comm_spec = self.get_communication_spec(
+sharding_spec=sharding_spec_mapping['bias'],
+communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
+logical_process_axis=mesh_dim)
+communication_action_mapping = {"bias": bias_comm_spec}
+return self.get_sharding_strategy(name=name,
+sharding_spec_mapping=sharding_spec_mapping,
+communication_action_mapping=communication_action_mapping)
def split_two_batch_dim(self, mesh_dim_0, mesh_dim_1):
name = f'Sb{mesh_dim_0}{mesh_dim_1} = Sb{mesh_dim_0}{mesh_dim_1} x Sb{mesh_dim_0}{mesh_dim_1}'
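
Note on the hunk above: besides the bias comm spec, split_one_batch_dim no longer detects a 1D device mesh itself and no longer returns None; it now takes mesh_dim as an argument, and the mesh check plus the choice of mesh_dim move into generate(), as shown in the last hunk of this diff. A small sketch mirroring that caller-side selection (mesh_shape is a plain tuple standing in for self.device_mesh.mesh_shape):

    def pick_batch_mesh_dim(mesh_shape):
        if len(mesh_shape) == 1:
            return 0                    # 1D mesh: use its only axis
        if 1 in mesh_shape:
            return mesh_shape.index(1)  # degenerate 2D mesh: use the size-1 axis
        return None                     # true 2D mesh: split_one_batch_dim is not used

    assert pick_batch_mesh_dim((4,)) == 0
    assert pick_batch_mesh_dim((4, 1)) == 1
    assert pick_batch_mesh_dim((2, 4)) is None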
@@ -538,7 +537,11 @@ class BatchedMatMulStrategyGenerator(MatMulStrategyGenerator):
sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict)
# get communication actions
-communication_action_mapping = {}
+bias_comm_spec = self.get_communication_spec(
+sharding_spec=sharding_spec_mapping['bias'],
+communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
+logical_process_axis=[mesh_dim_0, mesh_dim_1])
+communication_action_mapping = {"bias": bias_comm_spec}
return self.get_sharding_strategy(name=name,
sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping=communication_action_mapping)
@@ -566,7 +569,11 @@ class BatchedMatMulStrategyGenerator(MatMulStrategyGenerator):
sharding_spec=sharding_spec_mapping['other'],
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=mesh_dim_1)
-communication_action_mapping = {'other': other_comm_spec}
+bias_comm_spec = self.get_communication_spec(
+sharding_spec=sharding_spec_mapping['bias'],
+communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
+logical_process_axis=[mesh_dim_0, mesh_dim_1])
+communication_action_mapping = {'other': other_comm_spec, 'bias': bias_comm_spec}
return self.get_sharding_strategy(name=name,
sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping=communication_action_mapping)
@@ -596,7 +603,11 @@ class BatchedMatMulStrategyGenerator(MatMulStrategyGenerator):
sharding_spec=sharding_spec_mapping['input'],
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=mesh_dim_1)
-communication_action_mapping = {'input': input_comm_spec}
+bias_comm_spec = self.get_communication_spec(
+sharding_spec=sharding_spec_mapping['bias'],
+communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
+logical_process_axis=mesh_dim_0)
+communication_action_mapping = {'input': input_comm_spec, 'bias': bias_comm_spec}
return self.get_sharding_strategy(name=name,
sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping=communication_action_mapping)
@@ -625,21 +636,31 @@ class BatchedMatMulStrategyGenerator(MatMulStrategyGenerator):
sharding_spec=sharding_spec_mapping['output'],
communication_pattern=CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD,
logical_process_axis=mesh_dim_1)
-communication_action_mapping = {'output': output_comm_spec}
+bias_comm_spec = self.get_communication_spec(
+sharding_spec=sharding_spec_mapping['bias'],
+communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
+logical_process_axis=mesh_dim_0)
+communication_action_mapping = {'output': output_comm_spec, 'bias': bias_comm_spec}
return self.get_sharding_strategy(name=name,
sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping=communication_action_mapping)
def generate(self) -> List[ShardingStrategy_V2]:
strategy_list = []
+device_mesh_is_1d = True
+if len(self.device_mesh.mesh_shape) == 2 and 1 not in self.device_mesh.mesh_shape:
+device_mesh_is_1d = False
-# split only the batch dimension
-# Sb = Sb x Sb
-# can be None as it is only for 1D device mesh
-strategy = self.split_one_batch_dim()
-if strategy:
+if device_mesh_is_1d:
+# split only the batch dimension
+# Sb = Sb x Sb
+# can be None as it is only for 1D device mesh
+# only for 1D device mesh
-strategy_list.append(strategy)
+if len(self.device_mesh.mesh_shape) == 1:
+mesh_dim = 0
+else:
+mesh_dim = self.device_mesh.mesh_shape.index(1)
+strategy_list.append(self.split_one_batch_dim(mesh_dim))
else:
# for 2D device mesh
# split batch dim of two inputs and the i dim of the first tensor