@@ -2,9 +2,8 @@ import operator
 from functools import reduce
 from typing import List
 
-from colossalai.auto_parallel.tensor_shard.sharding_strategy import (MemoryCost, ShardingStrategy, TrainCycleItem)
-from colossalai.auto_parallel.tensor_shard.utils import \
-    ignore_sharding_exception
+from colossalai.auto_parallel.tensor_shard.sharding_strategy import MemoryCost, ShardingStrategy, TrainCycleItem
+from colossalai.auto_parallel.tensor_shard.utils import ignore_sharding_exception
 from colossalai.tensor.shape_consistency import CollectiveCommPattern
 
 from .strategy_generator import StrategyGenerator
@@ -227,6 +226,7 @@ class LinearProjectionStrategyGenerator(MatMulStrategyGenerator):
         sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping)
 
         # set communication action
+        communication_action_mapping = {}
         input_comm_spec = self.get_communication_spec(
             sharding_spec=sharding_spec_mapping["input"],
             communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
@@ -235,12 +235,16 @@ class LinearProjectionStrategyGenerator(MatMulStrategyGenerator):
             sharding_spec_mapping["output"],
             communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
             logical_process_axis=mesh_dim_0)
-        bias_comm_spec = self.get_communication_spec(
-            sharding_spec_mapping["bias"],
-            communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
-            logical_process_axis=mesh_dim_0)
 
-        communication_action_mapping = {"input": input_comm_spec, "other": other_comm_spec, "bias": bias_comm_spec}
+        communication_action_mapping['input'] = input_comm_spec
+        communication_action_mapping['other'] = other_comm_spec
+
+        if self.has_bias:
+            bias_comm_spec = self.get_communication_spec(
+                sharding_spec_mapping["bias"],
+                communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
+                logical_process_axis=mesh_dim_0)
+            communication_action_mapping['bias'] = bias_comm_spec
+
         return self.get_sharding_strategy(name=name,
                                           sharding_spec_mapping=sharding_spec_mapping,
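
# Editor's note: a minimal, self-contained sketch of the refactor the hunks
# above apply. The mapping is now built incrementally instead of as one dict
# literal, so a bias communication spec is only created when the layer really
# has a bias (the old unconditional lookup of the "bias" sharding spec would
# presumably fail for bias-free linears). `build_comm_actions`, `specs` and
# `get_comm_spec` are illustrative stand-ins, not the real ColossalAI API.
def build_comm_actions(specs, has_bias, get_comm_spec, pattern, axis):
    comm_actions = {
        "input": get_comm_spec(specs["input"], pattern, axis),
        "other": get_comm_spec(specs["other"], pattern, axis),
    }
    if has_bias:
        # only touch specs["bias"] when a bias entry actually exists
        comm_actions["bias"] = get_comm_spec(specs["bias"], pattern, axis)
    return comm_actions
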
@@ -268,6 +272,7 @@ class LinearProjectionStrategyGenerator(MatMulStrategyGenerator):
         sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping)
 
         # get communication action mapping
+        communication_action_mapping = {}
         input_comm_spec = self.get_communication_spec(
             sharding_spec=sharding_spec_mapping["input"],
             communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
@@ -276,12 +281,16 @@ class LinearProjectionStrategyGenerator(MatMulStrategyGenerator):
             sharding_spec=sharding_spec_mapping["output"],
             communication_pattern=CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD,
             logical_process_axis=mesh_dim_1)
-        bias_comm_spec = self.get_communication_spec(
-            sharding_spec=sharding_spec_mapping["bias"],
-            communication_pattern=CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD,
-            logical_process_axis=mesh_dim_1)
 
-        communication_action_mapping = {"input": input_comm_spec, 'output': output_comm_spec, 'bias': bias_comm_spec}
+        communication_action_mapping['input'] = input_comm_spec
+        communication_action_mapping['output'] = output_comm_spec
+
+        if self.has_bias:
+            bias_comm_spec = self.get_communication_spec(
+                sharding_spec=sharding_spec_mapping["bias"],
+                communication_pattern=CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD,
+                logical_process_axis=mesh_dim_1)
+            communication_action_mapping['bias'] = bias_comm_spec
+
         return self.get_sharding_strategy(name=name,
                                           sharding_spec_mapping=sharding_spec_mapping,
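
# Editor's note: the two enum values above are direction-mirrored. Reading the
# names, IDENTITY_FWD_ALLREDUCE_BWD is a no-op forward and all-reduces
# gradients backward (suits replicated inputs), while
# ALLREDUCE_FWD_IDENTITY_BWD all-reduces partial sums forward and is a no-op
# backward (suits outputs of a contraction sharded along the reduced
# dimension). A toy autograd sketch of the latter, assuming an initialized
# torch.distributed process group; this is not ColossalAI's implementation.
import torch
import torch.distributed as dist


class AllReduceFwdIdentityBwd(torch.autograd.Function):

    @staticmethod
    def forward(ctx, partial_out):
        # forward: each rank holds a partial sum over its shard of the
        # contracted dimension; sum them across the group
        dist.all_reduce(partial_out)
        return partial_out

    @staticmethod
    def backward(ctx, grad_output):
        # backward: identity, every rank already has the full gradient
        return grad_output
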
@@ -310,6 +319,7 @@ class LinearProjectionStrategyGenerator(MatMulStrategyGenerator):
         sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping)
 
         # get communication actions
+        communication_action_mapping = {}
         output_comm_spec = self.get_communication_spec(
             sharding_spec=sharding_spec_mapping['output'],
             communication_pattern=CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD,
@@ -318,7 +328,8 @@ class LinearProjectionStrategyGenerator(MatMulStrategyGenerator):
             sharding_spec=sharding_spec_mapping['input'],
             communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
             logical_process_axis=mesh_dim_1)
-        communication_action_mapping = {"output": output_comm_spec, "input": input_comm_spec}
+        communication_action_mapping["input"] = input_comm_spec
+        communication_action_mapping['output'] = output_comm_spec
         return self.get_sharding_strategy(name=name,
                                           sharding_spec_mapping=sharding_spec_mapping,
                                           communication_action_mapping=communication_action_mapping)
@@ -342,11 +353,13 @@ class LinearProjectionStrategyGenerator(MatMulStrategyGenerator):
         sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping)
 
         # get communication action
+        communication_action_mapping = {}
         output_comm_spec = self.get_communication_spec(
             sharding_spec=sharding_spec_mapping['output'],
             communication_pattern=CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD,
             logical_process_axis=mesh_dim)
-        communication_action_mapping = {'output': output_comm_spec}
+
+        communication_action_mapping['output'] = output_comm_spec
         return self.get_sharding_strategy(name=name,
                                           sharding_spec_mapping=sharding_spec_mapping,
                                           communication_action_mapping=communication_action_mapping)
@@ -372,11 +385,13 @@ class LinearProjectionStrategyGenerator(MatMulStrategyGenerator):
         sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping)
 
         # get communication actions
+        communication_action_mapping = {}
         input_comm_spec = self.get_communication_spec(
             sharding_spec=sharding_spec_mapping['input'],
             communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
             logical_process_axis=mesh_dim)
-        communication_action_mapping = {'input': input_comm_spec}
+
+        communication_action_mapping['input'] = input_comm_spec
         return self.get_sharding_strategy(name=name,
                                           sharding_spec_mapping=sharding_spec_mapping,
                                           communication_action_mapping=communication_action_mapping)
@@ -398,19 +413,22 @@ class LinearProjectionStrategyGenerator(MatMulStrategyGenerator):
         sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping)
 
         # get communication action
+        communication_action_mapping = {}
         other_comm_spec = self.get_communication_spec(
             sharding_spec=sharding_spec_mapping['other'],
             communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
             logical_process_axis=[mesh_dim_0, mesh_dim_1])
-        bias_comm_spec = self.get_communication_spec(
-            sharding_spec=sharding_spec_mapping['bias'],
-            communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
-            logical_process_axis=[mesh_dim_0, mesh_dim_1])
+        communication_action_mapping['other'] = other_comm_spec
 
-        communcation_action_mapping = {"other": other_comm_spec, "bias": bias_comm_spec}
+        if self.has_bias:
+            bias_comm_spec = self.get_communication_spec(
+                sharding_spec=sharding_spec_mapping['bias'],
+                communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
+                logical_process_axis=[mesh_dim_0, mesh_dim_1])
+            communication_action_mapping['bias'] = bias_comm_spec
         return self.get_sharding_strategy(name=name,
                                           sharding_spec_mapping=sharding_spec_mapping,
-                                          communication_action_mapping=communcation_action_mapping)
+                                          communication_action_mapping=communication_action_mapping)
 
     @ignore_sharding_exception
     def split_lhs_2nd_dim_1d(self, mesh_dim_0, mesh_dim_1):
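
# Editor's note (assumption): when logical_process_axis is a list such as
# [mesh_dim_0, mesh_dim_1], I read it as flattening both device-mesh axes into
# one process group, so the collective spans every device in the mesh rather
# than a single row or column. Rough group-size arithmetic on a hypothetical
# 2 x 4 mesh:
mesh_shape = (2, 4)
assert mesh_shape[0] == 2                  # collective along one axis: 2 ranks
assert mesh_shape[0] * mesh_shape[1] == 8  # flattened axes: all 8 ranks
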
@@ -430,11 +448,12 @@ class LinearProjectionStrategyGenerator(MatMulStrategyGenerator):
         sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping)
 
         # get communication action
+        communication_action_mapping = {}
         output_comm_spec = self.get_communication_spec(
             sharding_spec=sharding_spec_mapping['output'],
             communication_pattern=CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD,
             logical_process_axis=[mesh_dim_0, mesh_dim_1])
-        communication_action_mapping = {'output': output_comm_spec}
+        communication_action_mapping['output'] = output_comm_spec
 
         return self.get_sharding_strategy(name=name,
                                           sharding_spec_mapping=sharding_spec_mapping,
@@ -460,11 +479,12 @@ class LinearProjectionStrategyGenerator(MatMulStrategyGenerator):
         sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping)
 
         # get communication action
+        communication_action_mapping = {}
         input_comm_spec = self.get_communication_spec(
             sharding_spec=sharding_spec_mapping['input'],
             communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
             logical_process_axis=[mesh_dim_0, mesh_dim_1])
-        communication_action_mapping = {'input': input_comm_spec}
+        communication_action_mapping['input'] = input_comm_spec
 
         return self.get_sharding_strategy(name=name,
                                           sharding_spec_mapping=sharding_spec_mapping,
@@ -492,7 +512,7 @@ class BatchedMatMulStrategyGenerator(MatMulStrategyGenerator):
         """
         Generate sharding strategies for the batched matrix multiplication.
 
-        A batched matrix multiplication can be viewed as
+        A batched matrix multiplication can be viewed as
         [b, i, k] x [b, k, j] -> [b, i, j]
         """
 
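
# Editor's note: a quick shape check of the [b, i, k] x [b, k, j] -> [b, i, j]
# view described in the docstring above, using torch.bmm.
import torch

lhs = torch.randn(16, 32, 64)      # [b, i, k]
rhs = torch.randn(16, 64, 128)     # [b, k, j]
out = torch.bmm(lhs, rhs)
assert out.shape == (16, 32, 128)  # [b, i, j]
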
@@ -642,7 +662,6 @@ class BatchedMatMulStrategyGenerator(MatMulStrategyGenerator):
-            "bias": {},
             "output": {
                 0: [mesh_dim_0],
                 -2: [mesh_dim_1]
             }
         }
         sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict)
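
# Editor's note: in a dim_partition_dict, keys are tensor dimensions (negative
# indices count from the end) and values list the device-mesh axes sharding
# that dimension. For the [b, i, j] output above, the hunk's mapping shards
# the batch dim over mesh_dim_0 and dim -2 (the i dim) over mesh_dim_1. On a
# hypothetical 2 x 4 mesh, a [16, 32, 128] output is stored per device as
# [16 // 2, 32 // 4, 128] = [8, 8, 128].
mesh_dim_0, mesh_dim_1 = 0, 1    # hypothetical mesh axes for illustration
dim_partition_dict = {0: [mesh_dim_0], -2: [mesh_dim_1]}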