[autoparallel] update CommSpec to CommActions (#1768)

* [autoparallel] update CommSpec to CommActions

* polish code
YuliangLiu0306 2022-10-28 09:57:43 +08:00 committed by GitHub
parent 16b0abf94f
commit b0f7c8bde8
7 changed files with 267 additions and 122 deletions


@@ -202,16 +202,17 @@ class LinearFunctionHandler(NodeHandler):
mapping = {"input": physical_input_operand, "other": physical_other_operand, "output": physical_output}
if self.node.args[2] is not None:
if 'bias' in self.node.kwargs and self.node.kwargs['bias'] is not None:
# check if the other operand is a parameter
if isinstance(self.node.args[2]._meta_data, torch.nn.parameter.Parameter):
if isinstance(self.node.kwargs["bias"]._meta_data, torch.nn.parameter.Parameter):
data_type = OperationDataType.PARAM
else:
data_type = OperationDataType.ARG
physical_bias_operand = OperationData(name=str(self.node.args[2]),
physical_bias_operand = OperationData(name=str(self.node.kwargs["bias"]),
type=data_type,
data=self.node.args[2]._meta_data)
data=self.node.kwargs["bias"]._meta_data)
mapping['bias'] = physical_bias_operand
return mapping
def post_process(self, strategy: ShardingStrategy):
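Editor's note: the hunk above fixes how LinearFunctionHandler resolves the bias operand of a traced linear call, reading it from self.node.kwargs['bias'] (guarded by a membership check) instead of the positional self.node.args[2], which presumably matches how the bias is recorded on the traced node. A minimal sketch of the resulting classification logic, using a hypothetical stand-in node rather than the real fx node class:

    import torch

    # Hypothetical stand-in for a traced node carrying meta data; the real
    # handler works on ColoTracer fx nodes, which are not reproduced here.
    class _FakeNode:
        def __init__(self, meta=None):
            self._meta_data = meta
            self.kwargs = {}

    def classify_bias(node):
        """Mirror the kwargs-based lookup above: a bias passed as a keyword is
        treated as a parameter when its meta tensor is an nn.Parameter."""
        bias = node.kwargs.get('bias', None)
        if bias is None:
            return None
        if isinstance(bias._meta_data, torch.nn.parameter.Parameter):
            return 'PARAM'    # stands in for OperationDataType.PARAM
        return 'ARG'          # stands in for OperationDataType.ARG

    linear_node = _FakeNode()
    linear_node.kwargs['bias'] = _FakeNode(torch.nn.Parameter(torch.empty(4)))
    print(classify_bias(linear_node))    # PARAM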


@@ -3,7 +3,12 @@ import operator
from functools import reduce
from typing import List
from colossalai.auto_parallel.tensor_shard.sharding_strategy import MemoryCost, ShardingStrategy, TrainCycleItem
from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
CommType,
MemoryCost,
ShardingStrategy,
TrainCycleItem,
)
from colossalai.tensor.shape_consistency import CollectiveCommPattern
from .strategy_generator import StrategyGenerator
@@ -204,12 +209,13 @@ class BatchNormStrategyGenerator(StrategyGenerator):
# For SyncBN case, we don't need to do communication for weight and bias.
# TODO: the communication happens interally at SyncBN operation. We need to replace the BN operation
# to SyncBN operation instead of inserting a communication node.
output_comm_spec = self.get_communication_spec(
output_comm_action = self.get_communication_action(
sharding_spec=sharding_spec_mapping["output"],
communication_pattern=CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD,
logical_process_axis=mesh_dim_0)
logical_process_axis=mesh_dim_0,
comm_type=CommType.AFTER)
communication_action_mapping = {"output": output_comm_spec}
communication_action_mapping = {"output": output_comm_action}
return self.get_sharding_strategy(name=name,
sharding_spec_mapping=sharding_spec_mapping,
@@ -238,12 +244,13 @@ class BatchNormStrategyGenerator(StrategyGenerator):
# For SyncBN case, we don't need to do communication for gradients of weight and bias.
# TODO: the communication happens interally at SyncBN operation. We need to replace the BN operation
# to SyncBN operation instead of inserting a communication node.
output_comm_spec = self.get_communication_spec(
output_comm_action = self.get_communication_action(
sharding_spec=sharding_spec_mapping["output"],
communication_pattern=CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD,
logical_process_axis=[mesh_dim_0, mesh_dim_1])
logical_process_axis=[mesh_dim_0, mesh_dim_1],
comm_type=CommType.AFTER)
communication_action_mapping = {"output": output_comm_spec}
communication_action_mapping = {"output": output_comm_action}
return self.get_sharding_strategy(name=name,
sharding_spec_mapping=sharding_spec_mapping,
@@ -282,12 +289,13 @@ class BatchNormStrategyGenerator(StrategyGenerator):
# For SyncBN case, we don't need to do communication for gradients of weight and bias.
# TODO: the communication happens interally at SyncBN operation. We need to replace the BN operation
# to SyncBN operation instead of inserting a communication node.
output_comm_spec = self.get_communication_spec(
output_comm_action = self.get_communication_action(
sharding_spec=sharding_spec_mapping["output"],
communication_pattern=CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD,
logical_process_axis=[mesh_dim_0])
logical_process_axis=[mesh_dim_0],
comm_type=CommType.AFTER)
communication_action_mapping = {"output": output_comm_spec}
communication_action_mapping = {"output": output_comm_action}
return self.get_sharding_strategy(name=name,
sharding_spec_mapping=sharding_spec_mapping,
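Editor's note: the three BatchNormStrategyGenerator hunks above replace get_communication_spec with get_communication_action and tag the output all-reduce with CommType.AFTER, i.e. the collective is applied after the node produces its output. The action container itself is not shown in this diff; the sketch below is only an illustrative approximation of such a wrapper, with the class and field names assumed:

    from dataclasses import dataclass
    from enum import Enum, auto
    from typing import Any

    class CommType(Enum):
        # Illustrative only: when a communication is applied relative to the node.
        BEFORE = auto()   # applied to an input argument before the op runs
        AFTER = auto()    # applied to the op's output after it runs
        HOOK = auto()     # registered as a gradient hook on a parameter

    @dataclass
    class CommAction:
        # Approximation of what get_communication_action() might return.
        comm_spec: Any = None        # the underlying CommSpec (pattern, mesh axis, ...)
        comm_type: CommType = None
        arg_index: int = -1          # positional argument to communicate (BEFORE)
        key_for_kwarg: Any = None    # or the keyword argument name (e.g. 'bias')

Under this reading, everything the old CommSpec carried still lives in comm_spec; the new fields only record when and on which operand the collective should be applied.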


@@ -1,7 +1,12 @@
import copy
from typing import List
from colossalai.auto_parallel.tensor_shard.sharding_strategy import (MemoryCost, ShardingStrategy, TrainCycleItem)
from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
CommType,
MemoryCost,
ShardingStrategy,
TrainCycleItem,
)
from colossalai.tensor.shape_consistency import CollectiveCommPattern
from .strategy_generator import FollowingStrategyGenerator
@@ -83,11 +88,13 @@ class TensorStrategyGenerator(GetItemStrategyGenerator):
}
sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping)
if gather_input:
input_communication_spec = self.get_communication_spec(
input_communication_action = self.get_communication_action(
sharding_spec_mapping["input"],
communication_pattern=CollectiveCommPattern.GATHER_FWD_SPLIT_BWD,
logical_process_axis=logical_process_axis)
communication_action_mapping["input"] = input_communication_spec
logical_process_axis=logical_process_axis,
comm_type=CommType.BEFORE,
arg_index=0)
communication_action_mapping["input"] = input_communication_action
name = f'{sharding_spec_mapping["output"].sharding_sequence} = {sharding_spec_mapping["input"].sharding_sequence}'
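Editor's note: in the getitem strategy above, the gather on the input becomes a CommType.BEFORE action with arg_index=0, i.e. the communication is meant to be applied to the node's first positional argument before the operation executes. A toy sketch of how such BEFORE actions could be consumed at dispatch time (the field names and the run_comm callback are assumptions, not the solver's real API):

    from collections import namedtuple

    # Illustrative action record; the real action fields are assumptions here.
    Action = namedtuple('Action', ['comm_type', 'arg_index', 'comm_spec'])

    def apply_before_actions(args, actions, run_comm):
        """Rewrite positional args for BEFORE-typed actions before the op runs.
        run_comm(tensor, comm_spec) stands for the actual collective call."""
        args = list(args)
        for action in actions.values():
            if action.comm_type == 'BEFORE' and action.arg_index is not None:
                args[action.arg_index] = run_comm(args[action.arg_index], action.comm_spec)
        return tuple(args)

    # Toy usage: "gather" argument 0 by doubling it, just to show the rewiring.
    new_args = apply_before_actions((10, 20),
                                    {'input': Action('BEFORE', 0, None)},
                                    run_comm=lambda t, spec: t * 2)
    print(new_args)    # (20, 20)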


@@ -3,9 +3,16 @@ import operator
from functools import reduce
from typing import List
from colossalai.auto_parallel.tensor_shard.sharding_strategy import (MemoryCost, ShardingStrategy, TrainCycleItem)
from colossalai.auto_parallel.tensor_shard.utils import (enumerate_all_possible_1d_sharding,
enumerate_all_possible_2d_sharding)
from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
CommType,
MemoryCost,
ShardingStrategy,
TrainCycleItem,
)
from colossalai.auto_parallel.tensor_shard.utils import (
enumerate_all_possible_1d_sharding,
enumerate_all_possible_2d_sharding,
)
from colossalai.tensor.shape_consistency import CollectiveCommPattern
from .strategy_generator import StrategyGenerator
@@ -107,18 +114,20 @@ class LayerNormGenerator(StrategyGenerator):
total_mesh_dim_list = total_mesh_dim_list[0]
communication_action_mapping = {}
other_comm_spec = self.get_communication_spec(
other_comm_action = self.get_communication_action(
sharding_spec=sharding_spec_mapping["other"],
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=total_mesh_dim_list)
communication_action_mapping["other"] = other_comm_spec
logical_process_axis=total_mesh_dim_list,
comm_type=CommType.HOOK)
communication_action_mapping["other"] = other_comm_action
if self.has_bias:
bias_comm_spec = self.get_communication_spec(
bias_comm_action = self.get_communication_action(
sharding_spec=sharding_spec_mapping["bias"],
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=total_mesh_dim_list)
communication_action_mapping["bias"] = bias_comm_spec
logical_process_axis=total_mesh_dim_list,
comm_type=CommType.HOOK)
communication_action_mapping["bias"] = bias_comm_action
strategy = self.get_sharding_strategy(name=name,
sharding_spec_mapping=sharding_spec_mapping,
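Editor's note: in the LayerNorm generator the weight and bias actions are tagged CommType.HOOK. Both operands are parameters, so the IDENTITY_FWD_ALLREDUCE_BWD pattern naturally maps to a gradient hook that all-reduces the parameter's gradient during backward. A minimal sketch of that idea with plain torch.distributed (it assumes an initialized process group and simplifies away the device-mesh plumbing):

    import torch
    import torch.distributed as dist

    def attach_grad_allreduce_hook(param: torch.nn.Parameter, group=None):
        """Register a backward hook that all-reduces the parameter's gradient,
        which is roughly what a HOOK-typed communication action expresses."""
        def _allreduce(grad: torch.Tensor) -> torch.Tensor:
            grad = grad.contiguous()    # collectives want dense, contiguous buffers
            dist.all_reduce(grad, op=dist.ReduceOp.SUM, group=group)
            return grad                 # returned value replaces the incoming gradient
        return param.register_hook(_allreduce)

Under this reading, HOOK actions need no arg_index or key_for_kwarg because they attach to the parameter object itself rather than to a call-site argument.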


@@ -1,8 +1,14 @@
import operator
from ast import arg
from functools import reduce
from typing import List
from colossalai.auto_parallel.tensor_shard.sharding_strategy import MemoryCost, ShardingStrategy, TrainCycleItem
from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
CommType,
MemoryCost,
ShardingStrategy,
TrainCycleItem,
)
from colossalai.auto_parallel.tensor_shard.utils import ignore_sharding_exception
from colossalai.tensor.shape_consistency import CollectiveCommPattern
@@ -77,11 +83,12 @@ class DotProductStrategyGenerator(MatMulStrategyGenerator):
sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict)
# get communication action
output_comm_spec = self.get_communication_spec(
output_comm_action = self.get_communication_action(
sharding_spec=sharding_spec_mapping['output'],
communication_pattern=CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD,
logical_process_axis=mesh_dim)
communication_action_mapping = {"output": output_comm_spec}
logical_process_axis=mesh_dim,
comm_type=CommType.AFTER)
communication_action_mapping = {"output": output_comm_action}
return self.get_sharding_strategy(name=name,
sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping=communication_action_mapping)
@@ -124,15 +131,35 @@ class MatVecStrategyGenerator(MatMulStrategyGenerator):
sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict)
# get communication action
other_comm_spec = self.get_communication_spec(
sharding_spec=sharding_spec_mapping['other'],
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=mesh_dim)
bias_comm_spec = self.get_communication_spec(
sharding_spec=sharding_spec_mapping['bias'],
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=mesh_dim)
communication_action_mapping = {'other': other_comm_spec, 'bias': bias_comm_spec}
if self.is_param('other'):
other_comm_action = self.get_communication_action(
sharding_spec=sharding_spec_mapping['other'],
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=mesh_dim,
comm_type=CommType.HOOK)
else:
other_comm_action = self.get_communication_action(
sharding_spec=sharding_spec_mapping['other'],
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=mesh_dim,
comm_type=CommType.BEFORE,
arg_index=1)
if self.has_bias:
if self.is_param('bias'):
bias_comm_action = self.get_communication_action(
sharding_spec=sharding_spec_mapping['bias'],
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=mesh_dim,
comm_type=CommType.HOOK)
else:
bias_comm_action = self.get_communication_action(
sharding_spec=sharding_spec_mapping['bias'],
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=mesh_dim,
comm_type=CommType.BEFORE,
arg_index=2)
communication_action_mapping = {'other': other_comm_action, 'bias': bias_comm_action}
return self.get_sharding_strategy(name=name,
sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping=communication_action_mapping)
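Editor's note: the MatVec hunk above introduces the branching that recurs through the rest of this file. If an operand is a parameter (self.is_param(...)), its backward all-reduce is emitted as a CommType.HOOK action; otherwise it becomes a CommType.BEFORE action that records where the operand sits in the call (arg_index=1 for other, arg_index=2 for bias here). The commit writes the branch out at each site; a possible helper that factors it out is sketched below (not part of the commit, and the generator interface is taken on faith from the calls above):

    from colossalai.auto_parallel.tensor_shard.sharding_strategy import CommType

    def build_backward_allreduce_action(generator, operand, sharding_spec, pattern,
                                        mesh_axis, arg_index=None, key_for_kwarg=None):
        # Parameters get their backward all-reduce as a gradient HOOK; plain
        # arguments get a BEFORE action that records where the tensor sits in
        # the call (positional index or keyword name). `generator` is assumed
        # to expose is_param() and get_communication_action() as used above.
        if generator.is_param(operand):
            return generator.get_communication_action(sharding_spec=sharding_spec,
                                                      communication_pattern=pattern,
                                                      logical_process_axis=mesh_axis,
                                                      comm_type=CommType.HOOK)
        extra = ({'key_for_kwarg': key_for_kwarg} if key_for_kwarg is not None
                 else {'arg_index': arg_index})
        return generator.get_communication_action(sharding_spec=sharding_spec,
                                                  communication_pattern=pattern,
                                                  logical_process_axis=mesh_axis,
                                                  comm_type=CommType.BEFORE,
                                                  **extra)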
@@ -227,24 +254,45 @@ class LinearProjectionStrategyGenerator(MatMulStrategyGenerator):
# set communication action
communication_action_mapping = {}
input_comm_spec = self.get_communication_spec(
input_comm_action = self.get_communication_action(
sharding_spec=sharding_spec_mapping["input"],
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=mesh_dim_1)
other_comm_spec = self.get_communication_spec(
sharding_spec_mapping["output"],
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=mesh_dim_0)
logical_process_axis=mesh_dim_1,
comm_type=CommType.BEFORE,
arg_index=0)
communication_action_mapping['input'] = input_comm_spec
communication_action_mapping['other'] = other_comm_spec
if self.is_param('other'):
other_comm_action = self.get_communication_action(
sharding_spec_mapping["output"],
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=mesh_dim_0,
comm_type=CommType.HOOK)
else:
other_comm_action = self.get_communication_action(
sharding_spec_mapping["output"],
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=mesh_dim_0,
comm_type=CommType.BEFORE,
arg_index=1)
communication_action_mapping['input'] = input_comm_action
communication_action_mapping['other'] = other_comm_action
if self.has_bias:
bias_comm_spec = self.get_communication_spec(
sharding_spec_mapping["bias"],
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=mesh_dim_0)
communication_action_mapping['bias'] = bias_comm_spec
if self.is_param('bias'):
bias_comm_action = self.get_communication_action(
sharding_spec_mapping["bias"],
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=mesh_dim_0,
comm_type=CommType.HOOK)
else:
bias_comm_action = self.get_communication_action(
sharding_spec_mapping["bias"],
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=mesh_dim_0,
comm_type=CommType.BEFORE,
key_for_kwarg='bias')
communication_action_mapping['bias'] = bias_comm_action
return self.get_sharding_strategy(name=name,
sharding_spec_mapping=sharding_spec_mapping,
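Editor's note: one detail in the LinearProjection hunk above is that a non-parameter bias is communicated with comm_type=CommType.BEFORE and key_for_kwarg='bias' rather than an arg_index, which lines up with the handler change at the top of this commit that now reads the bias operand from node.kwargs['bias']. Complementing the positional sketch earlier, a keyword-targeted BEFORE action might be applied like this (again with assumed field names):

    from types import SimpleNamespace

    def apply_kwarg_action(kwargs, action, run_comm):
        """Rewrite the keyword argument named by key_for_kwarg (e.g. 'bias')
        before the op runs; run_comm stands for the real collective call."""
        new_kwargs = dict(kwargs)
        key = action.key_for_kwarg
        new_kwargs[key] = run_comm(new_kwargs[key], action.comm_spec)
        return new_kwargs

    demo = apply_kwarg_action({'bias': 1.0},
                              SimpleNamespace(key_for_kwarg='bias', comm_spec=None),
                              run_comm=lambda t, spec: t * 2)
    print(demo)    # {'bias': 2.0}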
@@ -273,24 +321,45 @@ class LinearProjectionStrategyGenerator(MatMulStrategyGenerator):
# get communication action mapping
communication_action_mapping = {}
input_comm_spec = self.get_communication_spec(
sharding_spec=sharding_spec_mapping["input"],
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=mesh_dim_0)
output_comm_spec = self.get_communication_spec(
output_comm_action = self.get_communication_action(
sharding_spec=sharding_spec_mapping["output"],
communication_pattern=CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD,
logical_process_axis=mesh_dim_1)
logical_process_axis=mesh_dim_1,
comm_type=CommType.AFTER)
communication_action_mapping['input'] = input_comm_spec
communication_action_mapping['output'] = output_comm_spec
if self.is_param('other'):
other_comm_action = self.get_communication_action(
sharding_spec_mapping["output"],
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=mesh_dim_0,
comm_type=CommType.HOOK)
else:
other_comm_action = self.get_communication_action(
sharding_spec_mapping["output"],
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=mesh_dim_0,
comm_type=CommType.BEFORE,
arg_index=1)
communication_action_mapping['other'] = other_comm_action
communication_action_mapping['output'] = output_comm_action
if self.has_bias:
bias_comm_spec = self.get_communication_spec(
sharding_spec=sharding_spec_mapping["bias"],
communication_pattern=CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD,
logical_process_axis=mesh_dim_1)
communication_action_mapping['bias'] = bias_comm_spec
if self.is_param('bias'):
bias_comm_action = self.get_communication_action(
sharding_spec_mapping["bias"],
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=mesh_dim_0,
comm_type=CommType.HOOK)
else:
bias_comm_action = self.get_communication_action(
sharding_spec_mapping["bias"],
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=mesh_dim_0,
comm_type=CommType.BEFORE,
key_for_kwarg='bias')
communication_action_mapping['bias'] = bias_comm_action
return self.get_sharding_strategy(name=name,
sharding_spec_mapping=sharding_spec_mapping,
@@ -320,16 +389,19 @@ class LinearProjectionStrategyGenerator(MatMulStrategyGenerator):
# get communication actions
communication_action_mapping = {}
output_comm_spec = self.get_communication_spec(
output_comm_action = self.get_communication_action(
sharding_spec=sharding_spec_mapping['output'],
communication_pattern=CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD,
logical_process_axis=mesh_dim_0)
input_comm_spec = self.get_communication_spec(
logical_process_axis=mesh_dim_0,
comm_type=CommType.AFTER)
input_comm_action = self.get_communication_action(
sharding_spec=sharding_spec_mapping['input'],
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=mesh_dim_1)
communication_action_mapping["input"] = input_comm_spec
communication_action_mapping['output'] = output_comm_spec
logical_process_axis=mesh_dim_1,
comm_type=CommType.BEFORE,
arg_index=0)
communication_action_mapping["input"] = input_comm_action
communication_action_mapping['output'] = output_comm_action
return self.get_sharding_strategy(name=name,
sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping=communication_action_mapping)
@@ -354,12 +426,13 @@ class LinearProjectionStrategyGenerator(MatMulStrategyGenerator):
# get communication action
communication_action_mapping = {}
output_comm_spec = self.get_communication_spec(
output_comm_action = self.get_communication_action(
sharding_spec=sharding_spec_mapping['output'],
communication_pattern=CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD,
logical_process_axis=mesh_dim)
logical_process_axis=mesh_dim,
comm_type=CommType.AFTER)
communication_action_mapping['output'] = output_comm_spec
communication_action_mapping['output'] = output_comm_action
return self.get_sharding_strategy(name=name,
sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping=communication_action_mapping)
@@ -386,12 +459,14 @@ class LinearProjectionStrategyGenerator(MatMulStrategyGenerator):
# get communication actions
communication_action_mapping = {}
input_comm_spec = self.get_communication_spec(
input_comm_action = self.get_communication_action(
sharding_spec=sharding_spec_mapping['input'],
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=mesh_dim)
logical_process_axis=mesh_dim,
comm_type=CommType.BEFORE,
arg_index=0)
communication_action_mapping['input'] = input_comm_spec
communication_action_mapping['input'] = input_comm_action
return self.get_sharding_strategy(name=name,
sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping=communication_action_mapping)
@@ -414,18 +489,36 @@ class LinearProjectionStrategyGenerator(MatMulStrategyGenerator):
# get communication action
communication_action_mapping = {}
other_comm_spec = self.get_communication_spec(
sharding_spec=sharding_spec_mapping['other'],
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=[mesh_dim_0, mesh_dim_1])
communication_action_mapping['other'] = other_comm_spec
if self.is_param('other'):
other_comm_action = self.get_communication_action(
sharding_spec=sharding_spec_mapping['other'],
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=[mesh_dim_0, mesh_dim_1],
comm_type=CommType.HOOK)
else:
other_comm_action = self.get_communication_action(
sharding_spec=sharding_spec_mapping['other'],
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=[mesh_dim_0, mesh_dim_1],
comm_type=CommType.BEFORE,
arg_index=1)
communication_action_mapping['other'] = other_comm_action
if self.has_bias:
bias_comm_spec = self.get_communication_spec(
sharding_spec=sharding_spec_mapping['bias'],
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=[mesh_dim_0, mesh_dim_1])
communication_action_mapping['bias'] = bias_comm_spec
if self.is_param('bias'):
bias_comm_action = self.get_communication_action(
sharding_spec=sharding_spec_mapping['bias'],
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=[mesh_dim_0, mesh_dim_1],
comm_type=CommType.HOOK)
else:
bias_comm_action = self.get_communication_action(
sharding_spec=sharding_spec_mapping['bias'],
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=[mesh_dim_0, mesh_dim_1],
comm_type=CommType.BEFORE,
key_for_kwarg='bias')
communication_action_mapping['bias'] = bias_comm_action
return self.get_sharding_strategy(name=name,
sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping=communication_action_mapping)
@@ -449,11 +542,12 @@ class LinearProjectionStrategyGenerator(MatMulStrategyGenerator):
# get communication action
communication_action_mapping = {}
output_comm_spec = self.get_communication_spec(
output_comm_action = self.get_communication_action(
sharding_spec=sharding_spec_mapping['output'],
communication_pattern=CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD,
logical_process_axis=[mesh_dim_0, mesh_dim_1])
communication_action_mapping['output'] = output_comm_spec
logical_process_axis=[mesh_dim_0, mesh_dim_1],
comm_type=CommType.AFTER)
communication_action_mapping['output'] = output_comm_action
return self.get_sharding_strategy(name=name,
sharding_spec_mapping=sharding_spec_mapping,
@@ -480,11 +574,13 @@ class LinearProjectionStrategyGenerator(MatMulStrategyGenerator):
# get communication action
communication_action_mapping = {}
input_comm_spec = self.get_communication_spec(
input_comm_action = self.get_communication_action(
sharding_spec=sharding_spec_mapping['input'],
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=[mesh_dim_0, mesh_dim_1])
communication_action_mapping['input'] = input_comm_spec
logical_process_axis=[mesh_dim_0, mesh_dim_1],
comm_type=CommType.BEFORE,
arg_index=0)
communication_action_mapping['input'] = input_comm_action
return self.get_sharding_strategy(name=name,
sharding_spec_mapping=sharding_spec_mapping,
@@ -516,8 +612,13 @@ class BatchedMatMulStrategyGenerator(MatMulStrategyGenerator):
[b, i, k] x [b, k, j] -> [b, i, j]
The bias term is considered to have a 2D logical shape.
Note: This class will be used to generate strategies for torch.bmm
and torch.addbmm. However, the result of torch.addbmm is not correct,
some extra runtime apply actions are required to keep numerical correctness.
"""
# TODO: torch.addbmm correctness issue need to be fixed.
def __init__(self, *args, **kwargs):
self.squeeze_batch_dim = False
super().__init__(*args, **kwargs)
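Editor's note: the expanded docstring flags that torch.addbmm strategies are not yet numerically correct without extra runtime actions. For context, addbmm reduces the batched product over the batch dimension and then adds the 2D input term, so its logical shapes differ from torch.bmm. A small, self-contained reminder of the semantics in plain PyTorch, independent of the strategy generator:

    import torch

    b, i, k, j = 4, 3, 5, 2
    inp = torch.randn(i, j)                    # 2D term added to the reduced product
    batch1 = torch.randn(b, i, k)
    batch2 = torch.randn(b, k, j)

    out = torch.addbmm(inp, batch1, batch2)            # shape [i, j]
    ref = inp + torch.bmm(batch1, batch2).sum(dim=0)   # equivalent formulation
    assert torch.allclose(out, ref, atol=1e-5)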
@@ -566,16 +667,16 @@ class BatchedMatMulStrategyGenerator(MatMulStrategyGenerator):
self._pop_batch_dim_sharding_for_output(dim_partition_dict)
sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict)
print(sharding_spec_mapping)
# get communication actions
communication_action_mapping = {}
if self.has_bias:
bias_comm_spec = self.get_communication_spec(
bias_comm_action = self.get_communication_action(
sharding_spec=sharding_spec_mapping['bias'],
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=mesh_dim)
communication_action_mapping['bias'] = bias_comm_spec
logical_process_axis=mesh_dim,
comm_type=CommType.BEFORE,
arg_index=0)
communication_action_mapping['bias'] = bias_comm_action
return self.get_sharding_strategy(name=name,
sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping=communication_action_mapping)
@@ -602,11 +703,13 @@ class BatchedMatMulStrategyGenerator(MatMulStrategyGenerator):
# get communication actions
communication_action_mapping = {}
if self.has_bias:
bias_comm_spec = self.get_communication_spec(
bias_comm_action = self.get_communication_action(
sharding_spec=sharding_spec_mapping['bias'],
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=[mesh_dim_0, mesh_dim_1])
communication_action_mapping['bias'] = bias_comm_spec
logical_process_axis=[mesh_dim_0, mesh_dim_1],
comm_type=CommType.BEFORE,
arg_index=0)
communication_action_mapping['bias'] = bias_comm_action
return self.get_sharding_strategy(name=name,
sharding_spec_mapping=sharding_spec_mapping,
@@ -637,18 +740,24 @@ class BatchedMatMulStrategyGenerator(MatMulStrategyGenerator):
# get communication actions
communication_action_mapping = {}
other_comm_spec = self.get_communication_spec(
other_comm_action = self.get_communication_action(
sharding_spec=sharding_spec_mapping['other'],
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=mesh_dim_1)
communication_action_mapping['other'] = other_comm_spec
logical_process_axis=mesh_dim_1,
comm_type=CommType.BEFORE,
arg_index=1)
communication_action_mapping['other'] = other_comm_action
if self.has_bias:
bias_comm_spec = self.get_communication_spec(
bias_comm_action = self.get_communication_action(
sharding_spec=sharding_spec_mapping['bias'],
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=[mesh_dim_0, mesh_dim_1])
communication_action_mapping['bias'] = bias_comm_spec
logical_process_axis=[mesh_dim_0, mesh_dim_1],
comm_type=CommType.BEFORE,
arg_index=0)
communication_action_mapping['bias'] = bias_comm_action
# for addbmm case, other is the third argument instead of second.
communication_action_mapping['other'].arg_index += 1
return self.get_sharding_strategy(name=name,
sharding_spec_mapping=sharding_spec_mapping,
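Editor's note: the 'other'.arg_index += 1 adjustment here (and the matching 'input'.arg_index += 1 in the next hunk) accounts for torch.addbmm taking an extra leading argument: in torch.bmm the batched operands sit at positions 0 and 1, while in torch.addbmm(input, batch1, batch2) they shift to positions 1 and 2 and position 0 becomes the bias-like term. Spelled out as a small runnable mapping:

    # Positional layout behind the arg_index shift (torch API facts):
    #   torch.bmm(input, mat2)              -> batched operands at positions 0 and 1
    #   torch.addbmm(input, batch1, batch2) -> batched operands at positions 1 and 2
    BMM_ARG_POS = {'input': 0, 'other': 1}
    ADDBMM_ARG_POS = {name: pos + 1 for name, pos in BMM_ARG_POS.items()}
    ADDBMM_ARG_POS['bias'] = 0    # addbmm's first argument plays the bias role
    print(ADDBMM_ARG_POS)         # {'input': 1, 'other': 2, 'bias': 0}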
@@ -679,18 +788,23 @@ class BatchedMatMulStrategyGenerator(MatMulStrategyGenerator):
# get communication actions
communication_action_mapping = {}
input_comm_spec = self.get_communication_spec(
input_comm_action = self.get_communication_action(
sharding_spec=sharding_spec_mapping['input'],
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=mesh_dim_1)
communication_action_mapping['input'] = input_comm_spec
logical_process_axis=mesh_dim_1,
comm_type=CommType.BEFORE,
arg_index=0)
communication_action_mapping['input'] = input_comm_action
if self.has_bias:
bias_comm_spec = self.get_communication_spec(
bias_comm_action = self.get_communication_action(
sharding_spec=sharding_spec_mapping['bias'],
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=mesh_dim_0)
communication_action_mapping['bias'] = bias_comm_spec
logical_process_axis=mesh_dim_0,
comm_type=CommType.BEFORE)
communication_action_mapping['bias'] = bias_comm_action
# for addbmm case, other is the second argument instead of first.
communication_action_mapping['input'].arg_index += 1
return self.get_sharding_strategy(name=name,
sharding_spec_mapping=sharding_spec_mapping,
@@ -719,18 +833,21 @@ class BatchedMatMulStrategyGenerator(MatMulStrategyGenerator):
# get communication actions
communication_action_mapping = {}
output_comm_spec = self.get_communication_spec(
output_comm_action = self.get_communication_action(
sharding_spec=sharding_spec_mapping['output'],
communication_pattern=CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD,
logical_process_axis=mesh_dim_1)
communication_action_mapping['output'] = output_comm_spec
logical_process_axis=mesh_dim_1,
comm_type=CommType.AFTER)
communication_action_mapping['output'] = output_comm_action
if self.has_bias:
bias_comm_spec = self.get_communication_spec(
bias_comm_action = self.get_communication_action(
sharding_spec=sharding_spec_mapping['bias'],
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=mesh_dim_0)
communication_action_mapping['bias'] = bias_comm_spec
logical_process_axis=mesh_dim_0,
comm_type=CommType.BEFORE,
arg_index=0)
communication_action_mapping['bias'] = bias_comm_action
return self.get_sharding_strategy(name=name,
sharding_spec_mapping=sharding_spec_mapping,
@@ -771,6 +888,5 @@ class BatchedMatMulStrategyGenerator(MatMulStrategyGenerator):
# split two batch dim
strategy_list.append(self.split_two_batch_dim(0, 1))
strategy_list.append(self.split_two_batch_dim(1, 0))
return strategy_list


@@ -41,7 +41,7 @@ def _split(tensor, comm_spec):
dim = comm_spec.shard_dim
length = tensor.shape[comm_spec.shard_dim] // len(rank_list)
start = length * rank_list.index(dist.get_rank())
output = torch.narrow(tensor, dim, start, length)
output = torch.narrow(tensor, dim, start, length).contiguous()
return output
@@ -76,6 +76,8 @@ def _all_reduce(tensor, comm_spec):
process_groups_list = comm_spec.device_mesh.process_groups_dict[comm_spec.logical_process_axis]
for rank_list, process_group in process_groups_list:
if dist.get_rank() in rank_list:
if not tensor.is_contiguous():
tensor = tensor.contiguous()
dist.all_reduce(tensor, op=ReduceOp.SUM, group=process_group)
return tensor
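Editor's note: both changes in this communication-ops file harden the collectives against non-contiguous inputs: _split now returns a contiguous copy of the narrowed slice (torch.narrow returns a strided view), and _all_reduce copies the tensor first when needed, since the distributed collectives generally expect dense, contiguous buffers. A quick illustration of why the copy matters:

    import torch

    full = torch.arange(12.0).reshape(3, 4)
    shard = torch.narrow(full, 1, 1, 2)          # narrow along dim 1: a strided view, not a copy
    print(shard.is_contiguous())                 # False
    print(shard.contiguous().is_contiguous())    # True: safe to hand to dist collectives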


@@ -11,6 +11,7 @@ from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
)
from colossalai.device.device_mesh import DeviceMesh
from colossalai.fx import ColoGraphModule, ColoTracer
from colossalai.testing.pytest_wrapper import run_on_environment_flag
from colossalai.testing.utils import parameterize
@@ -109,6 +110,7 @@ def test_linear_module_handler(bias):
assert bias_sharding_spec.sharding_sequence[-1] == output_sharding_spec.sharding_sequence[-1]
@run_on_environment_flag(name='AUTO_PARALLEL')
@parameterize('bias', [True, False])
def test_linear_function_handler(bias):
model = nn.Linear(16, 32, bias=bias).to('meta')