mirror of https://github.com/hpcaitech/ColossalAI
[autoparallel] update CommSpec to CommActions (#1768)
* [autoparallel] update CommSpec to CommActions
* polish code
parent 16b0abf94f
commit b0f7c8bde8
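The hunks below replace get_communication_spec calls with get_communication_action: the action keeps the same sharding spec, collective pattern, and logical process axis, and adds scheduling metadata, namely a comm_type (BEFORE, AFTER, or HOOK), an optional arg_index for positional arguments, and an optional key_for_kwarg for keyword arguments such as 'bias'. The standalone Python sketch below only illustrates that shape; the CommAction dataclass, the string-valued sharding specs, and the axis numbers are assumptions made for illustration, not the actual colossalai definitions.

from dataclasses import dataclass
from enum import Enum, auto
from typing import Any, Optional


class CommType(Enum):
    # Assumed semantics of the three scheduling modes used in the diff:
    BEFORE = auto()    # collective applied to an input argument before the op runs
    AFTER = auto()     # collective applied to the output after the op runs
    HOOK = auto()      # collective registered as a gradient hook on a parameter


@dataclass
class CommAction:
    """Hypothetical stand-in for the object returned by get_communication_action."""
    sharding_spec: Any
    communication_pattern: str
    logical_process_axis: Any
    comm_type: CommType
    arg_index: Optional[int] = None        # which positional arg to transform (BEFORE)
    key_for_kwarg: Optional[str] = None    # which keyword arg to transform, e.g. 'bias'


# Example mirroring the linear-projection hunks (values are hypothetical):
output_comm_action = CommAction(sharding_spec="S0R",
                                communication_pattern="ALLREDUCE_FWD_IDENTITY_BWD",
                                logical_process_axis=1,
                                comm_type=CommType.AFTER)
bias_comm_action = CommAction(sharding_spec="R",
                              communication_pattern="IDENTITY_FWD_ALLREDUCE_BWD",
                              logical_process_axis=0,
                              comm_type=CommType.BEFORE,
                              key_for_kwarg='bias')
communication_action_mapping = {"output": output_comm_action, "bias": bias_comm_action}
print(communication_action_mapping["output"].comm_type)   # CommType.AFTER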
@@ -202,16 +202,17 @@ class LinearFunctionHandler(NodeHandler):
         mapping = {"input": physical_input_operand, "other": physical_other_operand, "output": physical_output}
 
-        if self.node.args[2] is not None:
+        if 'bias' in self.node.kwargs and self.node.kwargs['bias'] is not None:
             # check if the other operand is a parameter
-            if isinstance(self.node.args[2]._meta_data, torch.nn.parameter.Parameter):
+            if isinstance(self.node.kwargs["bias"]._meta_data, torch.nn.parameter.Parameter):
                 data_type = OperationDataType.PARAM
             else:
                 data_type = OperationDataType.ARG
-            physical_bias_operand = OperationData(name=str(self.node.args[2]),
+            physical_bias_operand = OperationData(name=str(self.node.kwargs["bias"]),
                                                   type=data_type,
-                                                  data=self.node.args[2]._meta_data)
+                                                  data=self.node.kwargs["bias"]._meta_data)
             mapping['bias'] = physical_bias_operand
 
         return mapping
 
     def post_process(self, strategy: ShardingStrategy):
@@ -3,7 +3,12 @@ import operator
 from functools import reduce
 from typing import List
 
-from colossalai.auto_parallel.tensor_shard.sharding_strategy import MemoryCost, ShardingStrategy, TrainCycleItem
+from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
+    CommType,
+    MemoryCost,
+    ShardingStrategy,
+    TrainCycleItem,
+)
 from colossalai.tensor.shape_consistency import CollectiveCommPattern
 
 from .strategy_generator import StrategyGenerator
@@ -204,12 +209,13 @@ class BatchNormStrategyGenerator(StrategyGenerator):
         # For SyncBN case, we don't need to do communication for weight and bias.
         # TODO: the communication happens interally at SyncBN operation. We need to replace the BN operation
         # to SyncBN operation instead of inserting a communication node.
-        output_comm_spec = self.get_communication_spec(
+        output_comm_action = self.get_communication_action(
             sharding_spec=sharding_spec_mapping["output"],
             communication_pattern=CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD,
-            logical_process_axis=mesh_dim_0)
+            logical_process_axis=mesh_dim_0,
+            comm_type=CommType.AFTER)
 
-        communication_action_mapping = {"output": output_comm_spec}
+        communication_action_mapping = {"output": output_comm_action}
 
         return self.get_sharding_strategy(name=name,
                                           sharding_spec_mapping=sharding_spec_mapping,
@@ -238,12 +244,13 @@ class BatchNormStrategyGenerator(StrategyGenerator):
         # For SyncBN case, we don't need to do communication for gradients of weight and bias.
         # TODO: the communication happens interally at SyncBN operation. We need to replace the BN operation
         # to SyncBN operation instead of inserting a communication node.
-        output_comm_spec = self.get_communication_spec(
+        output_comm_action = self.get_communication_action(
             sharding_spec=sharding_spec_mapping["output"],
             communication_pattern=CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD,
-            logical_process_axis=[mesh_dim_0, mesh_dim_1])
+            logical_process_axis=[mesh_dim_0, mesh_dim_1],
+            comm_type=CommType.AFTER)
 
-        communication_action_mapping = {"output": output_comm_spec}
+        communication_action_mapping = {"output": output_comm_action}
 
         return self.get_sharding_strategy(name=name,
                                           sharding_spec_mapping=sharding_spec_mapping,
@@ -282,12 +289,13 @@ class BatchNormStrategyGenerator(StrategyGenerator):
         # For SyncBN case, we don't need to do communication for gradients of weight and bias.
         # TODO: the communication happens interally at SyncBN operation. We need to replace the BN operation
         # to SyncBN operation instead of inserting a communication node.
-        output_comm_spec = self.get_communication_spec(
+        output_comm_action = self.get_communication_action(
             sharding_spec=sharding_spec_mapping["output"],
             communication_pattern=CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD,
-            logical_process_axis=[mesh_dim_0])
+            logical_process_axis=[mesh_dim_0],
+            comm_type=CommType.AFTER)
 
-        communication_action_mapping = {"output": output_comm_spec}
+        communication_action_mapping = {"output": output_comm_action}
 
         return self.get_sharding_strategy(name=name,
                                           sharding_spec_mapping=sharding_spec_mapping,
@@ -1,7 +1,12 @@
 import copy
 from typing import List
 
-from colossalai.auto_parallel.tensor_shard.sharding_strategy import (MemoryCost, ShardingStrategy, TrainCycleItem)
+from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
+    CommType,
+    MemoryCost,
+    ShardingStrategy,
+    TrainCycleItem,
+)
 from colossalai.tensor.shape_consistency import CollectiveCommPattern
 
 from .strategy_generator import FollowingStrategyGenerator
@@ -83,11 +88,13 @@ class TensorStrategyGenerator(GetItemStrategyGenerator):
         }
         sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping)
         if gather_input:
-            input_communication_spec = self.get_communication_spec(
+            input_communication_action = self.get_communication_action(
                 sharding_spec_mapping["input"],
                 communication_pattern=CollectiveCommPattern.GATHER_FWD_SPLIT_BWD,
-                logical_process_axis=logical_process_axis)
-            communication_action_mapping["input"] = input_communication_spec
+                logical_process_axis=logical_process_axis,
+                comm_type=CommType.BEFORE,
+                arg_index=0)
+            communication_action_mapping["input"] = input_communication_action
 
         name = f'{sharding_spec_mapping["output"].sharding_sequence} = {sharding_spec_mapping["input"].sharding_sequence}'
 
@@ -3,9 +3,16 @@ import operator
 from functools import reduce
 from typing import List
 
-from colossalai.auto_parallel.tensor_shard.sharding_strategy import (MemoryCost, ShardingStrategy, TrainCycleItem)
-from colossalai.auto_parallel.tensor_shard.utils import (enumerate_all_possible_1d_sharding,
-                                                          enumerate_all_possible_2d_sharding)
+from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
+    CommType,
+    MemoryCost,
+    ShardingStrategy,
+    TrainCycleItem,
+)
+from colossalai.auto_parallel.tensor_shard.utils import (
+    enumerate_all_possible_1d_sharding,
+    enumerate_all_possible_2d_sharding,
+)
 from colossalai.tensor.shape_consistency import CollectiveCommPattern
 
 from .strategy_generator import StrategyGenerator
@@ -107,18 +114,20 @@ class LayerNormGenerator(StrategyGenerator):
             total_mesh_dim_list = total_mesh_dim_list[0]
         communication_action_mapping = {}
 
-        other_comm_spec = self.get_communication_spec(
+        other_comm_action = self.get_communication_action(
             sharding_spec=sharding_spec_mapping["other"],
             communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
-            logical_process_axis=total_mesh_dim_list)
-        communication_action_mapping["other"] = other_comm_spec
+            logical_process_axis=total_mesh_dim_list,
+            comm_type=CommType.HOOK)
+        communication_action_mapping["other"] = other_comm_action
 
         if self.has_bias:
-            bias_comm_spec = self.get_communication_spec(
+            bias_comm_action = self.get_communication_action(
                 sharding_spec=sharding_spec_mapping["bias"],
                 communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
-                logical_process_axis=total_mesh_dim_list)
-            communication_action_mapping["bias"] = bias_comm_spec
+                logical_process_axis=total_mesh_dim_list,
+                comm_type=CommType.HOOK)
+            communication_action_mapping["bias"] = bias_comm_action
 
         strategy = self.get_sharding_strategy(name=name,
                                               sharding_spec_mapping=sharding_spec_mapping,
@@ -1,8 +1,14 @@
 import operator
 from ast import arg
 from functools import reduce
 from typing import List
 
-from colossalai.auto_parallel.tensor_shard.sharding_strategy import MemoryCost, ShardingStrategy, TrainCycleItem
+from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
+    CommType,
+    MemoryCost,
+    ShardingStrategy,
+    TrainCycleItem,
+)
 from colossalai.auto_parallel.tensor_shard.utils import ignore_sharding_exception
 from colossalai.tensor.shape_consistency import CollectiveCommPattern
 
@@ -77,11 +83,12 @@ class DotProductStrategyGenerator(MatMulStrategyGenerator):
         sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict)
 
         # get communication action
-        output_comm_spec = self.get_communication_spec(
+        output_comm_action = self.get_communication_action(
             sharding_spec=sharding_spec_mapping['output'],
             communication_pattern=CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD,
-            logical_process_axis=mesh_dim)
-        communication_action_mapping = {"output": output_comm_spec}
+            logical_process_axis=mesh_dim,
+            comm_type=CommType.AFTER)
+        communication_action_mapping = {"output": output_comm_action}
         return self.get_sharding_strategy(name=name,
                                           sharding_spec_mapping=sharding_spec_mapping,
                                           communication_action_mapping=communication_action_mapping)
@@ -124,15 +131,35 @@ class MatVecStrategyGenerator(MatMulStrategyGenerator):
         sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict)
 
         # get communication action
-        other_comm_spec = self.get_communication_spec(
-            sharding_spec=sharding_spec_mapping['other'],
-            communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
-            logical_process_axis=mesh_dim)
-        bias_comm_spec = self.get_communication_spec(
-            sharding_spec=sharding_spec_mapping['bias'],
-            communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
-            logical_process_axis=mesh_dim)
-        communication_action_mapping = {'other': other_comm_spec, 'bias': bias_comm_spec}
+        if self.is_param('other'):
+            other_comm_action = self.get_communication_action(
+                sharding_spec=sharding_spec_mapping['other'],
+                communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
+                logical_process_axis=mesh_dim,
+                comm_type=CommType.HOOK)
+        else:
+            other_comm_action = self.get_communication_action(
+                sharding_spec=sharding_spec_mapping['other'],
+                communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
+                logical_process_axis=mesh_dim,
+                comm_type=CommType.BEFORE,
+                arg_index=1)
+        if self.has_bias:
+            if self.is_param('bias'):
+                bias_comm_action = self.get_communication_action(
+                    sharding_spec=sharding_spec_mapping['bias'],
+                    communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
+                    logical_process_axis=mesh_dim,
+                    comm_type=CommType.HOOK)
+            else:
+                bias_comm_action = self.get_communication_action(
+                    sharding_spec=sharding_spec_mapping['bias'],
+                    communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
+                    logical_process_axis=mesh_dim,
+                    comm_type=CommType.BEFORE,
+                    arg_index=2)
+        communication_action_mapping = {'other': other_comm_action, 'bias': bias_comm_action}
 
         return self.get_sharding_strategy(name=name,
                                           sharding_spec_mapping=sharding_spec_mapping,
                                           communication_action_mapping=communication_action_mapping)
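In the MatVec hunk above and the linear-projection hunks below, the branch on self.is_param(...) picks CommType.HOOK for parameters (the backward all-reduce is attached to the parameter's gradient) and CommType.BEFORE with an arg_index or key_for_kwarg for ordinary arguments. The snippet below is a minimal, single-process illustration of the gradient-hook idea only; fake_all_reduce is a stand-in, not the runtime's actual collective.

import torch

w = torch.nn.Parameter(torch.randn(4, 3))

def fake_all_reduce(grad):
    # stand-in for dist.all_reduce over a mesh axis; here it just returns the grad
    return grad

# HOOK-style handling: the collective runs when the parameter's gradient is produced,
# instead of inserting a communication node before or after the op itself.
w.register_hook(fake_all_reduce)

loss = (torch.randn(2, 4) @ w).sum()
loss.backward()
print(w.grad.shape)    # torch.Size([4, 3]); fake_all_reduce ran on this gradient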
@@ -227,24 +254,45 @@ class LinearProjectionStrategyGenerator(MatMulStrategyGenerator):
 
         # set communication action
         communication_action_mapping = {}
-        input_comm_spec = self.get_communication_spec(
+        input_comm_action = self.get_communication_action(
             sharding_spec=sharding_spec_mapping["input"],
             communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
-            logical_process_axis=mesh_dim_1)
-        other_comm_spec = self.get_communication_spec(
-            sharding_spec_mapping["output"],
-            communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
-            logical_process_axis=mesh_dim_0)
-        communication_action_mapping['input'] = input_comm_spec
-        communication_action_mapping['other'] = other_comm_spec
+            logical_process_axis=mesh_dim_1,
+            comm_type=CommType.BEFORE,
+            arg_index=0)
+
+        if self.is_param('other'):
+            other_comm_action = self.get_communication_action(
+                sharding_spec_mapping["output"],
+                communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
+                logical_process_axis=mesh_dim_0,
+                comm_type=CommType.HOOK)
+        else:
+            other_comm_action = self.get_communication_action(
+                sharding_spec_mapping["output"],
+                communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
+                logical_process_axis=mesh_dim_0,
+                comm_type=CommType.BEFORE,
+                arg_index=1)
+
+        communication_action_mapping['input'] = input_comm_action
+        communication_action_mapping['other'] = other_comm_action
 
         if self.has_bias:
-            bias_comm_spec = self.get_communication_spec(
-                sharding_spec_mapping["bias"],
-                communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
-                logical_process_axis=mesh_dim_0)
-            communication_action_mapping['bias'] = bias_comm_spec
+            if self.is_param('bias'):
+                bias_comm_action = self.get_communication_action(
+                    sharding_spec_mapping["bias"],
+                    communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
+                    logical_process_axis=mesh_dim_0,
+                    comm_type=CommType.HOOK)
+            else:
+                bias_comm_action = self.get_communication_action(
+                    sharding_spec_mapping["bias"],
+                    communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
+                    logical_process_axis=mesh_dim_0,
+                    comm_type=CommType.BEFORE,
+                    key_for_kwarg='bias')
+            communication_action_mapping['bias'] = bias_comm_action
 
         return self.get_sharding_strategy(name=name,
                                           sharding_spec_mapping=sharding_spec_mapping,
@@ -273,24 +321,45 @@ class LinearProjectionStrategyGenerator(MatMulStrategyGenerator):
 
         # get communication action mapping
         communication_action_mapping = {}
-        input_comm_spec = self.get_communication_spec(
-            sharding_spec=sharding_spec_mapping["input"],
-            communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
-            logical_process_axis=mesh_dim_0)
-        output_comm_spec = self.get_communication_spec(
+
+        output_comm_action = self.get_communication_action(
             sharding_spec=sharding_spec_mapping["output"],
             communication_pattern=CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD,
-            logical_process_axis=mesh_dim_1)
+            logical_process_axis=mesh_dim_1,
+            comm_type=CommType.AFTER)
 
-        communication_action_mapping['input'] = input_comm_spec
-        communication_action_mapping['output'] = output_comm_spec
+        if self.is_param('other'):
+            other_comm_action = self.get_communication_action(
+                sharding_spec_mapping["output"],
+                communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
+                logical_process_axis=mesh_dim_0,
+                comm_type=CommType.HOOK)
+        else:
+            other_comm_action = self.get_communication_action(
+                sharding_spec_mapping["output"],
+                communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
+                logical_process_axis=mesh_dim_0,
+                comm_type=CommType.BEFORE,
+                arg_index=1)
+
+        communication_action_mapping['other'] = other_comm_action
+        communication_action_mapping['output'] = output_comm_action
 
         if self.has_bias:
-            bias_comm_spec = self.get_communication_spec(
-                sharding_spec=sharding_spec_mapping["bias"],
-                communication_pattern=CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD,
-                logical_process_axis=mesh_dim_1)
-            communication_action_mapping['bias'] = bias_comm_spec
+            if self.is_param('bias'):
+                bias_comm_action = self.get_communication_action(
+                    sharding_spec_mapping["bias"],
+                    communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
+                    logical_process_axis=mesh_dim_0,
+                    comm_type=CommType.HOOK)
+            else:
+                bias_comm_action = self.get_communication_action(
+                    sharding_spec_mapping["bias"],
+                    communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
+                    logical_process_axis=mesh_dim_0,
+                    comm_type=CommType.BEFORE,
+                    key_for_kwarg='bias')
+            communication_action_mapping['bias'] = bias_comm_action
 
         return self.get_sharding_strategy(name=name,
                                           sharding_spec_mapping=sharding_spec_mapping,
@@ -320,16 +389,19 @@ class LinearProjectionStrategyGenerator(MatMulStrategyGenerator):
 
         # get communication actions
         communication_action_mapping = {}
-        output_comm_spec = self.get_communication_spec(
+        output_comm_action = self.get_communication_action(
             sharding_spec=sharding_spec_mapping['output'],
             communication_pattern=CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD,
-            logical_process_axis=mesh_dim_0)
-        input_comm_spec = self.get_communication_spec(
+            logical_process_axis=mesh_dim_0,
+            comm_type=CommType.AFTER)
+        input_comm_action = self.get_communication_action(
             sharding_spec=sharding_spec_mapping['input'],
             communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
-            logical_process_axis=mesh_dim_1)
-        communication_action_mapping["input"] = input_comm_spec
-        communication_action_mapping['output'] = output_comm_spec
+            logical_process_axis=mesh_dim_1,
+            comm_type=CommType.BEFORE,
+            arg_index=0)
+        communication_action_mapping["input"] = input_comm_action
+        communication_action_mapping['output'] = output_comm_action
         return self.get_sharding_strategy(name=name,
                                           sharding_spec_mapping=sharding_spec_mapping,
                                           communication_action_mapping=communication_action_mapping)
@@ -354,12 +426,13 @@ class LinearProjectionStrategyGenerator(MatMulStrategyGenerator):
 
         # get communication action
         communication_action_mapping = {}
-        output_comm_spec = self.get_communication_spec(
+        output_comm_action = self.get_communication_action(
             sharding_spec=sharding_spec_mapping['output'],
             communication_pattern=CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD,
-            logical_process_axis=mesh_dim)
+            logical_process_axis=mesh_dim,
+            comm_type=CommType.AFTER)
 
-        communication_action_mapping['output'] = output_comm_spec
+        communication_action_mapping['output'] = output_comm_action
         return self.get_sharding_strategy(name=name,
                                           sharding_spec_mapping=sharding_spec_mapping,
                                           communication_action_mapping=communication_action_mapping)
@@ -386,12 +459,14 @@ class LinearProjectionStrategyGenerator(MatMulStrategyGenerator):
 
         # get communication actions
         communication_action_mapping = {}
-        input_comm_spec = self.get_communication_spec(
+        input_comm_action = self.get_communication_action(
             sharding_spec=sharding_spec_mapping['input'],
             communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
-            logical_process_axis=mesh_dim)
+            logical_process_axis=mesh_dim,
+            comm_type=CommType.BEFORE,
+            arg_index=0)
 
-        communication_action_mapping['input'] = input_comm_spec
+        communication_action_mapping['input'] = input_comm_action
         return self.get_sharding_strategy(name=name,
                                           sharding_spec_mapping=sharding_spec_mapping,
                                           communication_action_mapping=communication_action_mapping)
@@ -414,18 +489,36 @@ class LinearProjectionStrategyGenerator(MatMulStrategyGenerator):
 
         # get communication action
         communication_action_mapping = {}
-        other_comm_spec = self.get_communication_spec(
-            sharding_spec=sharding_spec_mapping['other'],
-            communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
-            logical_process_axis=[mesh_dim_0, mesh_dim_1])
-        communication_action_mapping['other'] = other_comm_spec
+        if self.is_param('other'):
+            other_comm_action = self.get_communication_action(
+                sharding_spec=sharding_spec_mapping['other'],
+                communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
+                logical_process_axis=[mesh_dim_0, mesh_dim_1],
+                comm_type=CommType.HOOK)
+        else:
+            other_comm_action = self.get_communication_action(
+                sharding_spec=sharding_spec_mapping['other'],
+                communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
+                logical_process_axis=[mesh_dim_0, mesh_dim_1],
+                comm_type=CommType.BEFORE,
+                arg_index=1)
+        communication_action_mapping['other'] = other_comm_action
 
         if self.has_bias:
-            bias_comm_spec = self.get_communication_spec(
-                sharding_spec=sharding_spec_mapping['bias'],
-                communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
-                logical_process_axis=[mesh_dim_0, mesh_dim_1])
-            communication_action_mapping['bias'] = bias_comm_spec
+            if self.is_param('bias'):
+                bias_comm_action = self.get_communication_action(
+                    sharding_spec=sharding_spec_mapping['bias'],
+                    communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
+                    logical_process_axis=[mesh_dim_0, mesh_dim_1],
+                    comm_type=CommType.HOOK)
+            else:
+                bias_comm_action = self.get_communication_action(
+                    sharding_spec=sharding_spec_mapping['bias'],
+                    communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
+                    logical_process_axis=[mesh_dim_0, mesh_dim_1],
+                    comm_type=CommType.BEFORE,
+                    key_for_kwarg='bias')
+            communication_action_mapping['bias'] = bias_comm_action
         return self.get_sharding_strategy(name=name,
                                           sharding_spec_mapping=sharding_spec_mapping,
                                           communication_action_mapping=communication_action_mapping)
@@ -449,11 +542,12 @@ class LinearProjectionStrategyGenerator(MatMulStrategyGenerator):
 
         # get communication action
         communication_action_mapping = {}
-        output_comm_spec = self.get_communication_spec(
+        output_comm_action = self.get_communication_action(
             sharding_spec=sharding_spec_mapping['output'],
             communication_pattern=CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD,
-            logical_process_axis=[mesh_dim_0, mesh_dim_1])
-        communication_action_mapping['output'] = output_comm_spec
+            logical_process_axis=[mesh_dim_0, mesh_dim_1],
+            comm_type=CommType.AFTER)
+        communication_action_mapping['output'] = output_comm_action
 
         return self.get_sharding_strategy(name=name,
                                           sharding_spec_mapping=sharding_spec_mapping,
@@ -480,11 +574,13 @@ class LinearProjectionStrategyGenerator(MatMulStrategyGenerator):
 
         # get communication action
         communication_action_mapping = {}
-        input_comm_spec = self.get_communication_spec(
+        input_comm_action = self.get_communication_action(
             sharding_spec=sharding_spec_mapping['input'],
             communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
-            logical_process_axis=[mesh_dim_0, mesh_dim_1])
-        communication_action_mapping['input'] = input_comm_spec
+            logical_process_axis=[mesh_dim_0, mesh_dim_1],
+            comm_type=CommType.BEFORE,
+            arg_index=0)
+        communication_action_mapping['input'] = input_comm_action
 
         return self.get_sharding_strategy(name=name,
                                           sharding_spec_mapping=sharding_spec_mapping,
@@ -516,8 +612,13 @@ class BatchedMatMulStrategyGenerator(MatMulStrategyGenerator):
     [b, i, k] x [b, k, j] -> [b, i, j]
 
     The bias term is considered to have a 2D logical shape.
+
+    Note: This class will be used to generate strategies for torch.bmm
+    and torch.addbmm. However, the result of torch.addbmm is not correct,
+    some extra runtime apply actions are required to keep numerical correctness.
     """
 
+    # TODO: torch.addbmm correctness issue need to be fixed.
     def __init__(self, *args, **kwargs):
         self.squeeze_batch_dim = False
         super().__init__(*args, **kwargs)
@@ -566,16 +667,16 @@ class BatchedMatMulStrategyGenerator(MatMulStrategyGenerator):
             self._pop_batch_dim_sharding_for_output(dim_partition_dict)
         sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict)
 
         print(sharding_spec_mapping)
 
         # get communication actions
         communication_action_mapping = {}
         if self.has_bias:
-            bias_comm_spec = self.get_communication_spec(
+            bias_comm_action = self.get_communication_action(
                 sharding_spec=sharding_spec_mapping['bias'],
                 communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
-                logical_process_axis=mesh_dim)
-            communication_action_mapping['bias'] = bias_comm_spec
+                logical_process_axis=mesh_dim,
+                comm_type=CommType.BEFORE,
+                arg_index=0)
+            communication_action_mapping['bias'] = bias_comm_action
         return self.get_sharding_strategy(name=name,
                                           sharding_spec_mapping=sharding_spec_mapping,
                                           communication_action_mapping=communication_action_mapping)
@@ -602,11 +703,13 @@ class BatchedMatMulStrategyGenerator(MatMulStrategyGenerator):
         # get communication actions
         communication_action_mapping = {}
         if self.has_bias:
-            bias_comm_spec = self.get_communication_spec(
+            bias_comm_action = self.get_communication_action(
                 sharding_spec=sharding_spec_mapping['bias'],
                 communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
-                logical_process_axis=[mesh_dim_0, mesh_dim_1])
-            communication_action_mapping['bias'] = bias_comm_spec
+                logical_process_axis=[mesh_dim_0, mesh_dim_1],
+                comm_type=CommType.BEFORE,
+                arg_index=0)
+            communication_action_mapping['bias'] = bias_comm_action
 
         return self.get_sharding_strategy(name=name,
                                           sharding_spec_mapping=sharding_spec_mapping,
@@ -637,18 +740,24 @@ class BatchedMatMulStrategyGenerator(MatMulStrategyGenerator):
 
         # get communication actions
        communication_action_mapping = {}
-        other_comm_spec = self.get_communication_spec(
+        other_comm_action = self.get_communication_action(
             sharding_spec=sharding_spec_mapping['other'],
             communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
-            logical_process_axis=mesh_dim_1)
-        communication_action_mapping['other'] = other_comm_spec
+            logical_process_axis=mesh_dim_1,
+            comm_type=CommType.BEFORE,
+            arg_index=1)
+        communication_action_mapping['other'] = other_comm_action
 
         if self.has_bias:
-            bias_comm_spec = self.get_communication_spec(
+            bias_comm_action = self.get_communication_action(
                 sharding_spec=sharding_spec_mapping['bias'],
                 communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
-                logical_process_axis=[mesh_dim_0, mesh_dim_1])
-            communication_action_mapping['bias'] = bias_comm_spec
+                logical_process_axis=[mesh_dim_0, mesh_dim_1],
+                comm_type=CommType.BEFORE,
+                arg_index=0)
+            communication_action_mapping['bias'] = bias_comm_action
+            # for addbmm case, other is the third argument instead of second.
+            communication_action_mapping['other'].arg_index += 1
 
         return self.get_sharding_strategy(name=name,
                                           sharding_spec_mapping=sharding_spec_mapping,
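The arg_index adjustment above reflects the argument layout of the two ops this generator serves: in torch.bmm(input, mat2) the "other" matrix is positional argument 1, while in torch.addbmm(input, batch1, batch2) the bias-like tensor comes first and the matrices shift to positions 1 and 2. A small sanity check of those signatures (shapes chosen arbitrarily):

import torch

a = torch.randn(2, 3, 4)     # batch of left matrices
b = torch.randn(2, 4, 5)     # batch of right matrices ("other")
bias = torch.zeros(3, 5)

out_bmm = torch.bmm(a, b)              # other is positional arg 1
out_addbmm = torch.addbmm(bias, a, b)  # other is positional arg 2, hence arg_index += 1

print(out_bmm.shape)      # torch.Size([2, 3, 5])
print(out_addbmm.shape)   # torch.Size([3, 5]) -- batch dim reduced and bias added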
@@ -679,18 +788,23 @@ class BatchedMatMulStrategyGenerator(MatMulStrategyGenerator):
 
         # get communication actions
         communication_action_mapping = {}
-        input_comm_spec = self.get_communication_spec(
+        input_comm_action = self.get_communication_action(
             sharding_spec=sharding_spec_mapping['input'],
             communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
-            logical_process_axis=mesh_dim_1)
-        communication_action_mapping['input'] = input_comm_spec
+            logical_process_axis=mesh_dim_1,
+            comm_type=CommType.BEFORE,
+            arg_index=0)
+        communication_action_mapping['input'] = input_comm_action
 
         if self.has_bias:
-            bias_comm_spec = self.get_communication_spec(
+            bias_comm_action = self.get_communication_action(
                 sharding_spec=sharding_spec_mapping['bias'],
                 communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
-                logical_process_axis=mesh_dim_0)
-            communication_action_mapping['bias'] = bias_comm_spec
+                logical_process_axis=mesh_dim_0,
+                comm_type=CommType.BEFORE)
+            communication_action_mapping['bias'] = bias_comm_action
+            # for addbmm case, other is the second argument instead of first.
+            communication_action_mapping['input'].arg_index += 1
 
         return self.get_sharding_strategy(name=name,
                                           sharding_spec_mapping=sharding_spec_mapping,
@@ -719,18 +833,21 @@ class BatchedMatMulStrategyGenerator(MatMulStrategyGenerator):
 
         # get communication actions
         communication_action_mapping = {}
-        output_comm_spec = self.get_communication_spec(
+        output_comm_action = self.get_communication_action(
             sharding_spec=sharding_spec_mapping['output'],
             communication_pattern=CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD,
-            logical_process_axis=mesh_dim_1)
-        communication_action_mapping['output'] = output_comm_spec
+            logical_process_axis=mesh_dim_1,
+            comm_type=CommType.AFTER)
+        communication_action_mapping['output'] = output_comm_action
 
         if self.has_bias:
-            bias_comm_spec = self.get_communication_spec(
+            bias_comm_action = self.get_communication_action(
                 sharding_spec=sharding_spec_mapping['bias'],
                 communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
-                logical_process_axis=mesh_dim_0)
-            communication_action_mapping['bias'] = bias_comm_spec
+                logical_process_axis=mesh_dim_0,
+                comm_type=CommType.BEFORE,
+                arg_index=0)
+            communication_action_mapping['bias'] = bias_comm_action
 
         return self.get_sharding_strategy(name=name,
                                           sharding_spec_mapping=sharding_spec_mapping,
@@ -771,6 +888,5 @@ class BatchedMatMulStrategyGenerator(MatMulStrategyGenerator):
 
         # split two batch dim
         strategy_list.append(self.split_two_batch_dim(0, 1))
-        strategy_list.append(self.split_two_batch_dim(1, 0))
 
         return strategy_list
@@ -41,7 +41,7 @@ def _split(tensor, comm_spec):
             dim = comm_spec.shard_dim
             length = tensor.shape[comm_spec.shard_dim] // len(rank_list)
             start = length * rank_list.index(dist.get_rank())
-            output = torch.narrow(tensor, dim, start, length)
+            output = torch.narrow(tensor, dim, start, length).contiguous()
             return output
 
 
@@ -76,6 +76,8 @@ def _all_reduce(tensor, comm_spec):
     process_groups_list = comm_spec.device_mesh.process_groups_dict[comm_spec.logical_process_axis]
     for rank_list, process_group in process_groups_list:
         if dist.get_rank() in rank_list:
+            if not tensor.is_contiguous():
+                tensor = tensor.contiguous()
             dist.all_reduce(tensor, op=ReduceOp.SUM, group=process_group)
             return tensor
 
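The two shape-consistency hunks above force tensors to be contiguous: torch.narrow returns a strided view, and collective communication generally expects a dense buffer. A quick single-process demonstration of why the guard in _all_reduce is needed (no process group involved here):

import torch

x = torch.arange(12.0).reshape(3, 4)
shard = torch.narrow(x, 1, 0, 2)      # view over columns 0-1, strides (4, 1)
print(shard.is_contiguous())          # False -- not safe to hand to a collective as-is

if not shard.is_contiguous():         # same guard the diff adds in _all_reduce
    shard = shard.contiguous()        # dense copy owning its own storage
print(shard.is_contiguous())          # True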
@@ -11,6 +11,7 @@ from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
 )
 from colossalai.device.device_mesh import DeviceMesh
 from colossalai.fx import ColoGraphModule, ColoTracer
+from colossalai.testing.pytest_wrapper import run_on_environment_flag
 from colossalai.testing.utils import parameterize
 
 
@@ -109,6 +110,7 @@ def test_linear_module_handler(bias):
     assert bias_sharding_spec.sharding_sequence[-1] == output_sharding_spec.sharding_sequence[-1]
 
 
+@run_on_environment_flag(name='AUTO_PARALLEL')
 @parameterize('bias', [True, False])
 def test_linear_function_handler(bias):
     model = nn.Linear(16, 32, bias=bias).to('meta')