[autoparallel] update CommSpec to CommActions (#1768)

* [autoparallel] update CommSpec to CommActions

* polish code
pull/1769/head^2
YuliangLiu0306 2022-10-28 09:57:43 +08:00 committed by GitHub
parent 16b0abf94f
commit b0f7c8bde8
7 changed files with 267 additions and 122 deletions
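
Across the hunks below, every strategy generator call site is switched from `get_communication_spec` to `get_communication_action`, and two new pieces of information are threaded through each call: `comm_type` (one of `CommType.BEFORE`, `CommType.AFTER`, `CommType.HOOK`, judging by the call sites: before the op on an input, after the op on the output, or as a hook for parameters) and, for `BEFORE` actions, which operand the collective attaches to (`arg_index` for positional args, `key_for_kwarg` for keyword args such as `bias`). The toy sketch below only illustrates that extra payload; the class and field definitions are hypothetical stand-ins, not the real Colossal-AI `CommSpec`/`CommAction` classes.

```python
from dataclasses import dataclass
from enum import Enum
from typing import List, Optional, Union


class CommType(Enum):
    BEFORE = 0   # collective applied to one of the op's inputs, before the op runs
    AFTER = 1    # collective applied to the op's output, after the op runs
    HOOK = 2     # collective registered as a (gradient) hook, used for parameters


@dataclass
class ToyCommAction:
    # Hypothetical stand-in for illustration only.
    communication_pattern: str                   # e.g. "ALLREDUCE_FWD_IDENTITY_BWD"
    logical_process_axis: Union[int, List[int]]  # device-mesh axis (or axes) to communicate over
    comm_type: CommType                          # when the collective is applied
    arg_index: Optional[int] = None              # positional operand it targets (BEFORE only)
    key_for_kwarg: Optional[str] = None          # or the kwarg it targets, e.g. 'bias'


# The old CommSpec carried only the first two fields; the CommAction-style call sites
# below additionally record when the collective runs and which operand it belongs to.
print(ToyCommAction("IDENTITY_FWD_ALLREDUCE_BWD", 0, CommType.BEFORE, arg_index=1))
```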

View File

@@ -202,16 +202,17 @@ class LinearFunctionHandler(NodeHandler):
         mapping = {"input": physical_input_operand, "other": physical_other_operand, "output": physical_output}
-        if self.node.args[2] is not None:
+        if 'bias' in self.node.kwargs and self.node.kwargs['bias'] is not None:
             # check if the other operand is a parameter
-            if isinstance(self.node.args[2]._meta_data, torch.nn.parameter.Parameter):
+            if isinstance(self.node.kwargs["bias"]._meta_data, torch.nn.parameter.Parameter):
                 data_type = OperationDataType.PARAM
             else:
                 data_type = OperationDataType.ARG
-            physical_bias_operand = OperationData(name=str(self.node.args[2]),
+            physical_bias_operand = OperationData(name=str(self.node.kwargs["bias"]),
                                                   type=data_type,
-                                                  data=self.node.args[2]._meta_data)
+                                                  data=self.node.kwargs["bias"]._meta_data)
             mapping['bias'] = physical_bias_operand
         return mapping

     def post_process(self, strategy: ShardingStrategy):

View File

@@ -3,7 +3,12 @@ import operator
 from functools import reduce
 from typing import List

-from colossalai.auto_parallel.tensor_shard.sharding_strategy import MemoryCost, ShardingStrategy, TrainCycleItem
+from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
+    CommType,
+    MemoryCost,
+    ShardingStrategy,
+    TrainCycleItem,
+)
 from colossalai.tensor.shape_consistency import CollectiveCommPattern

 from .strategy_generator import StrategyGenerator
@@ -204,12 +209,13 @@ class BatchNormStrategyGenerator(StrategyGenerator):
         # For SyncBN case, we don't need to do communication for weight and bias.
         # TODO: the communication happens interally at SyncBN operation. We need to replace the BN operation
         # to SyncBN operation instead of inserting a communication node.
-        output_comm_spec = self.get_communication_spec(
+        output_comm_action = self.get_communication_action(
             sharding_spec=sharding_spec_mapping["output"],
             communication_pattern=CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD,
-            logical_process_axis=mesh_dim_0)
-        communication_action_mapping = {"output": output_comm_spec}
+            logical_process_axis=mesh_dim_0,
+            comm_type=CommType.AFTER)
+        communication_action_mapping = {"output": output_comm_action}
         return self.get_sharding_strategy(name=name,
                                           sharding_spec_mapping=sharding_spec_mapping,
@@ -238,12 +244,13 @@ class BatchNormStrategyGenerator(StrategyGenerator):
         # For SyncBN case, we don't need to do communication for gradients of weight and bias.
         # TODO: the communication happens interally at SyncBN operation. We need to replace the BN operation
         # to SyncBN operation instead of inserting a communication node.
-        output_comm_spec = self.get_communication_spec(
+        output_comm_action = self.get_communication_action(
             sharding_spec=sharding_spec_mapping["output"],
             communication_pattern=CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD,
-            logical_process_axis=[mesh_dim_0, mesh_dim_1])
-        communication_action_mapping = {"output": output_comm_spec}
+            logical_process_axis=[mesh_dim_0, mesh_dim_1],
+            comm_type=CommType.AFTER)
+        communication_action_mapping = {"output": output_comm_action}
         return self.get_sharding_strategy(name=name,
                                           sharding_spec_mapping=sharding_spec_mapping,
@@ -282,12 +289,13 @@ class BatchNormStrategyGenerator(StrategyGenerator):
         # For SyncBN case, we don't need to do communication for gradients of weight and bias.
         # TODO: the communication happens interally at SyncBN operation. We need to replace the BN operation
         # to SyncBN operation instead of inserting a communication node.
-        output_comm_spec = self.get_communication_spec(
+        output_comm_action = self.get_communication_action(
             sharding_spec=sharding_spec_mapping["output"],
             communication_pattern=CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD,
-            logical_process_axis=[mesh_dim_0])
-        communication_action_mapping = {"output": output_comm_spec}
+            logical_process_axis=[mesh_dim_0],
+            comm_type=CommType.AFTER)
+        communication_action_mapping = {"output": output_comm_action}
         return self.get_sharding_strategy(name=name,
                                           sharding_spec_mapping=sharding_spec_mapping,

View File

@@ -1,7 +1,12 @@
 import copy
 from typing import List

-from colossalai.auto_parallel.tensor_shard.sharding_strategy import (MemoryCost, ShardingStrategy, TrainCycleItem)
+from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
+    CommType,
+    MemoryCost,
+    ShardingStrategy,
+    TrainCycleItem,
+)
 from colossalai.tensor.shape_consistency import CollectiveCommPattern

 from .strategy_generator import FollowingStrategyGenerator
@@ -83,11 +88,13 @@ class TensorStrategyGenerator(GetItemStrategyGenerator):
         }
         sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping)
         if gather_input:
-            input_communication_spec = self.get_communication_spec(
+            input_communication_action = self.get_communication_action(
                 sharding_spec_mapping["input"],
                 communication_pattern=CollectiveCommPattern.GATHER_FWD_SPLIT_BWD,
-                logical_process_axis=logical_process_axis)
-            communication_action_mapping["input"] = input_communication_spec
+                logical_process_axis=logical_process_axis,
+                comm_type=CommType.BEFORE,
+                arg_index=0)
+            communication_action_mapping["input"] = input_communication_action

         name = f'{sharding_spec_mapping["output"].sharding_sequence} = {sharding_spec_mapping["input"].sharding_sequence}'

View File

@@ -3,9 +3,16 @@ import operator
 from functools import reduce
 from typing import List

-from colossalai.auto_parallel.tensor_shard.sharding_strategy import (MemoryCost, ShardingStrategy, TrainCycleItem)
-from colossalai.auto_parallel.tensor_shard.utils import (enumerate_all_possible_1d_sharding,
-                                                         enumerate_all_possible_2d_sharding)
+from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
+    CommType,
+    MemoryCost,
+    ShardingStrategy,
+    TrainCycleItem,
+)
+from colossalai.auto_parallel.tensor_shard.utils import (
+    enumerate_all_possible_1d_sharding,
+    enumerate_all_possible_2d_sharding,
+)
 from colossalai.tensor.shape_consistency import CollectiveCommPattern

 from .strategy_generator import StrategyGenerator
@@ -107,18 +114,20 @@ class LayerNormGenerator(StrategyGenerator):
             total_mesh_dim_list = total_mesh_dim_list[0]
         communication_action_mapping = {}
-        other_comm_spec = self.get_communication_spec(
+        other_comm_action = self.get_communication_action(
             sharding_spec=sharding_spec_mapping["other"],
             communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
-            logical_process_axis=total_mesh_dim_list)
-        communication_action_mapping["other"] = other_comm_spec
+            logical_process_axis=total_mesh_dim_list,
+            comm_type=CommType.HOOK)
+        communication_action_mapping["other"] = other_comm_action
         if self.has_bias:
-            bias_comm_spec = self.get_communication_spec(
+            bias_comm_action = self.get_communication_action(
                 sharding_spec=sharding_spec_mapping["bias"],
                 communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
-                logical_process_axis=total_mesh_dim_list)
-            communication_action_mapping["bias"] = bias_comm_spec
+                logical_process_axis=total_mesh_dim_list,
+                comm_type=CommType.HOOK)
+            communication_action_mapping["bias"] = bias_comm_action
         strategy = self.get_sharding_strategy(name=name,
                                               sharding_spec_mapping=sharding_spec_mapping,

View File

@@ -1,8 +1,14 @@
 import operator
+from ast import arg
 from functools import reduce
 from typing import List

-from colossalai.auto_parallel.tensor_shard.sharding_strategy import MemoryCost, ShardingStrategy, TrainCycleItem
+from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
+    CommType,
+    MemoryCost,
+    ShardingStrategy,
+    TrainCycleItem,
+)
 from colossalai.auto_parallel.tensor_shard.utils import ignore_sharding_exception
 from colossalai.tensor.shape_consistency import CollectiveCommPattern
@@ -77,11 +83,12 @@ class DotProductStrategyGenerator(MatMulStrategyGenerator):
         sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict)
         # get communication action
-        output_comm_spec = self.get_communication_spec(
+        output_comm_action = self.get_communication_action(
             sharding_spec=sharding_spec_mapping['output'],
             communication_pattern=CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD,
-            logical_process_axis=mesh_dim)
-        communication_action_mapping = {"output": output_comm_spec}
+            logical_process_axis=mesh_dim,
+            comm_type=CommType.AFTER)
+        communication_action_mapping = {"output": output_comm_action}
         return self.get_sharding_strategy(name=name,
                                           sharding_spec_mapping=sharding_spec_mapping,
                                           communication_action_mapping=communication_action_mapping)
@@ -124,15 +131,35 @@ class MatVecStrategyGenerator(MatMulStrategyGenerator):
         sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict)
         # get communication action
-        other_comm_spec = self.get_communication_spec(
-            sharding_spec=sharding_spec_mapping['other'],
-            communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
-            logical_process_axis=mesh_dim)
-        bias_comm_spec = self.get_communication_spec(
-            sharding_spec=sharding_spec_mapping['bias'],
-            communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
-            logical_process_axis=mesh_dim)
-        communication_action_mapping = {'other': other_comm_spec, 'bias': bias_comm_spec}
+        if self.is_param('other'):
+            other_comm_action = self.get_communication_action(
+                sharding_spec=sharding_spec_mapping['other'],
+                communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
+                logical_process_axis=mesh_dim,
+                comm_type=CommType.HOOK)
+        else:
+            other_comm_action = self.get_communication_action(
+                sharding_spec=sharding_spec_mapping['other'],
+                communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
+                logical_process_axis=mesh_dim,
+                comm_type=CommType.BEFORE,
+                arg_index=1)
+        if self.has_bias:
+            if self.is_param('bias'):
+                bias_comm_action = self.get_communication_action(
+                    sharding_spec=sharding_spec_mapping['bias'],
+                    communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
+                    logical_process_axis=mesh_dim,
+                    comm_type=CommType.HOOK)
+            else:
+                bias_comm_action = self.get_communication_action(
+                    sharding_spec=sharding_spec_mapping['bias'],
+                    communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
+                    logical_process_axis=mesh_dim,
+                    comm_type=CommType.BEFORE,
+                    arg_index=2)
+        communication_action_mapping = {'other': other_comm_action, 'bias': bias_comm_action}
         return self.get_sharding_strategy(name=name,
                                           sharding_spec_mapping=sharding_spec_mapping,
                                           communication_action_mapping=communication_action_mapping)
@@ -227,24 +254,45 @@ class LinearProjectionStrategyGenerator(MatMulStrategyGenerator):
         # set communication action
         communication_action_mapping = {}
-        input_comm_spec = self.get_communication_spec(
-            sharding_spec=sharding_spec_mapping["input"],
-            communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
-            logical_process_axis=mesh_dim_1)
-        other_comm_spec = self.get_communication_spec(
-            sharding_spec_mapping["output"],
-            communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
-            logical_process_axis=mesh_dim_0)
-        communication_action_mapping['input'] = input_comm_spec
-        communication_action_mapping['other'] = other_comm_spec
+        input_comm_action = self.get_communication_action(
+            sharding_spec=sharding_spec_mapping["input"],
+            communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
+            logical_process_axis=mesh_dim_1,
+            comm_type=CommType.BEFORE,
+            arg_index=0)
+        if self.is_param('other'):
+            other_comm_action = self.get_communication_action(
+                sharding_spec_mapping["output"],
+                communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
+                logical_process_axis=mesh_dim_0,
+                comm_type=CommType.HOOK)
+        else:
+            other_comm_action = self.get_communication_action(
+                sharding_spec_mapping["output"],
+                communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
+                logical_process_axis=mesh_dim_0,
+                comm_type=CommType.BEFORE,
+                arg_index=1)
+        communication_action_mapping['input'] = input_comm_action
+        communication_action_mapping['other'] = other_comm_action
         if self.has_bias:
-            bias_comm_spec = self.get_communication_spec(
-                sharding_spec_mapping["bias"],
-                communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
-                logical_process_axis=mesh_dim_0)
-            communication_action_mapping['bias'] = bias_comm_spec
+            if self.is_param('bias'):
+                bias_comm_action = self.get_communication_action(
+                    sharding_spec_mapping["bias"],
+                    communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
+                    logical_process_axis=mesh_dim_0,
+                    comm_type=CommType.HOOK)
+            else:
+                bias_comm_action = self.get_communication_action(
+                    sharding_spec_mapping["bias"],
+                    communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
+                    logical_process_axis=mesh_dim_0,
+                    comm_type=CommType.BEFORE,
+                    key_for_kwarg='bias')
+            communication_action_mapping['bias'] = bias_comm_action
         return self.get_sharding_strategy(name=name,
                                           sharding_spec_mapping=sharding_spec_mapping,
@@ -273,24 +321,45 @@ class LinearProjectionStrategyGenerator(MatMulStrategyGenerator):
         # get communication action mapping
         communication_action_mapping = {}
-        input_comm_spec = self.get_communication_spec(
-            sharding_spec=sharding_spec_mapping["input"],
-            communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
-            logical_process_axis=mesh_dim_0)
-        output_comm_spec = self.get_communication_spec(
-            sharding_spec=sharding_spec_mapping["output"],
-            communication_pattern=CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD,
-            logical_process_axis=mesh_dim_1)
-        communication_action_mapping['input'] = input_comm_spec
-        communication_action_mapping['output'] = output_comm_spec
+        output_comm_action = self.get_communication_action(
+            sharding_spec=sharding_spec_mapping["output"],
+            communication_pattern=CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD,
+            logical_process_axis=mesh_dim_1,
+            comm_type=CommType.AFTER)
+        if self.is_param('other'):
+            other_comm_action = self.get_communication_action(
+                sharding_spec_mapping["output"],
+                communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
+                logical_process_axis=mesh_dim_0,
+                comm_type=CommType.HOOK)
+        else:
+            other_comm_action = self.get_communication_action(
+                sharding_spec_mapping["output"],
+                communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
+                logical_process_axis=mesh_dim_0,
+                comm_type=CommType.BEFORE,
+                arg_index=1)
+        communication_action_mapping['other'] = other_comm_action
+        communication_action_mapping['output'] = output_comm_action
         if self.has_bias:
-            bias_comm_spec = self.get_communication_spec(
-                sharding_spec=sharding_spec_mapping["bias"],
-                communication_pattern=CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD,
-                logical_process_axis=mesh_dim_1)
-            communication_action_mapping['bias'] = bias_comm_spec
+            if self.is_param('bias'):
+                bias_comm_action = self.get_communication_action(
+                    sharding_spec_mapping["bias"],
+                    communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
+                    logical_process_axis=mesh_dim_0,
+                    comm_type=CommType.HOOK)
+            else:
+                bias_comm_action = self.get_communication_action(
+                    sharding_spec_mapping["bias"],
+                    communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
+                    logical_process_axis=mesh_dim_0,
+                    comm_type=CommType.BEFORE,
+                    key_for_kwarg='bias')
+            communication_action_mapping['bias'] = bias_comm_action
         return self.get_sharding_strategy(name=name,
                                           sharding_spec_mapping=sharding_spec_mapping,
@@ -320,16 +389,19 @@ class LinearProjectionStrategyGenerator(MatMulStrategyGenerator):
         # get communication actions
         communication_action_mapping = {}
-        output_comm_spec = self.get_communication_spec(
+        output_comm_action = self.get_communication_action(
             sharding_spec=sharding_spec_mapping['output'],
             communication_pattern=CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD,
-            logical_process_axis=mesh_dim_0)
-        input_comm_spec = self.get_communication_spec(
+            logical_process_axis=mesh_dim_0,
+            comm_type=CommType.AFTER)
+        input_comm_action = self.get_communication_action(
             sharding_spec=sharding_spec_mapping['input'],
             communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
-            logical_process_axis=mesh_dim_1)
-        communication_action_mapping["input"] = input_comm_spec
-        communication_action_mapping['output'] = output_comm_spec
+            logical_process_axis=mesh_dim_1,
+            comm_type=CommType.BEFORE,
+            arg_index=0)
+        communication_action_mapping["input"] = input_comm_action
+        communication_action_mapping['output'] = output_comm_action
         return self.get_sharding_strategy(name=name,
                                           sharding_spec_mapping=sharding_spec_mapping,
                                           communication_action_mapping=communication_action_mapping)
@@ -354,12 +426,13 @@ class LinearProjectionStrategyGenerator(MatMulStrategyGenerator):
         # get communication action
         communication_action_mapping = {}
-        output_comm_spec = self.get_communication_spec(
+        output_comm_action = self.get_communication_action(
             sharding_spec=sharding_spec_mapping['output'],
             communication_pattern=CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD,
-            logical_process_axis=mesh_dim)
-        communication_action_mapping['output'] = output_comm_spec
+            logical_process_axis=mesh_dim,
+            comm_type=CommType.AFTER)
+        communication_action_mapping['output'] = output_comm_action
         return self.get_sharding_strategy(name=name,
                                           sharding_spec_mapping=sharding_spec_mapping,
                                           communication_action_mapping=communication_action_mapping)
@@ -386,12 +459,14 @@ class LinearProjectionStrategyGenerator(MatMulStrategyGenerator):
         # get communication actions
         communication_action_mapping = {}
-        input_comm_spec = self.get_communication_spec(
+        input_comm_action = self.get_communication_action(
             sharding_spec=sharding_spec_mapping['input'],
             communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
-            logical_process_axis=mesh_dim)
-        communication_action_mapping['input'] = input_comm_spec
+            logical_process_axis=mesh_dim,
+            comm_type=CommType.BEFORE,
+            arg_index=0)
+        communication_action_mapping['input'] = input_comm_action
         return self.get_sharding_strategy(name=name,
                                           sharding_spec_mapping=sharding_spec_mapping,
                                           communication_action_mapping=communication_action_mapping)
@@ -414,18 +489,36 @@ class LinearProjectionStrategyGenerator(MatMulStrategyGenerator):
         # get communication action
         communication_action_mapping = {}
-        other_comm_spec = self.get_communication_spec(
-            sharding_spec=sharding_spec_mapping['other'],
-            communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
-            logical_process_axis=[mesh_dim_0, mesh_dim_1])
-        communication_action_mapping['other'] = other_comm_spec
+        if self.is_param('other'):
+            other_comm_action = self.get_communication_action(
+                sharding_spec=sharding_spec_mapping['other'],
+                communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
+                logical_process_axis=[mesh_dim_0, mesh_dim_1],
+                comm_type=CommType.HOOK)
+        else:
+            other_comm_action = self.get_communication_action(
+                sharding_spec=sharding_spec_mapping['other'],
+                communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
+                logical_process_axis=[mesh_dim_0, mesh_dim_1],
+                comm_type=CommType.BEFORE,
+                arg_index=1)
+        communication_action_mapping['other'] = other_comm_action
         if self.has_bias:
-            bias_comm_spec = self.get_communication_spec(
-                sharding_spec=sharding_spec_mapping['bias'],
-                communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
-                logical_process_axis=[mesh_dim_0, mesh_dim_1])
-            communication_action_mapping['bias'] = bias_comm_spec
+            if self.is_param('bias'):
+                bias_comm_action = self.get_communication_action(
+                    sharding_spec=sharding_spec_mapping['bias'],
+                    communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
+                    logical_process_axis=[mesh_dim_0, mesh_dim_1],
+                    comm_type=CommType.HOOK)
+            else:
+                bias_comm_action = self.get_communication_action(
+                    sharding_spec=sharding_spec_mapping['bias'],
+                    communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
+                    logical_process_axis=[mesh_dim_0, mesh_dim_1],
+                    comm_type=CommType.BEFORE,
+                    key_for_kwarg='bias')
+            communication_action_mapping['bias'] = bias_comm_action
         return self.get_sharding_strategy(name=name,
                                           sharding_spec_mapping=sharding_spec_mapping,
                                           communication_action_mapping=communication_action_mapping)
@@ -449,11 +542,12 @@ class LinearProjectionStrategyGenerator(MatMulStrategyGenerator):
         # get communication action
         communication_action_mapping = {}
-        output_comm_spec = self.get_communication_spec(
+        output_comm_action = self.get_communication_action(
             sharding_spec=sharding_spec_mapping['output'],
             communication_pattern=CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD,
-            logical_process_axis=[mesh_dim_0, mesh_dim_1])
-        communication_action_mapping['output'] = output_comm_spec
+            logical_process_axis=[mesh_dim_0, mesh_dim_1],
+            comm_type=CommType.AFTER)
+        communication_action_mapping['output'] = output_comm_action
         return self.get_sharding_strategy(name=name,
                                           sharding_spec_mapping=sharding_spec_mapping,
@@ -480,11 +574,13 @@ class LinearProjectionStrategyGenerator(MatMulStrategyGenerator):
         # get communication action
         communication_action_mapping = {}
-        input_comm_spec = self.get_communication_spec(
+        input_comm_action = self.get_communication_action(
             sharding_spec=sharding_spec_mapping['input'],
             communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
-            logical_process_axis=[mesh_dim_0, mesh_dim_1])
-        communication_action_mapping['input'] = input_comm_spec
+            logical_process_axis=[mesh_dim_0, mesh_dim_1],
+            comm_type=CommType.BEFORE,
+            arg_index=0)
+        communication_action_mapping['input'] = input_comm_action
         return self.get_sharding_strategy(name=name,
                                           sharding_spec_mapping=sharding_spec_mapping,
@@ -516,8 +612,13 @@ class BatchedMatMulStrategyGenerator(MatMulStrategyGenerator):
     [b, i, k] x [b, k, j] -> [b, i, j]
     The bias term is considered to have a 2D logical shape.
+
+    Note: This class will be used to generate strategies for torch.bmm
+    and torch.addbmm. However, the result of torch.addbmm is not correct,
+    some extra runtime apply actions are required to keep numerical correctness.
     """

+    # TODO: torch.addbmm correctness issue need to be fixed.
     def __init__(self, *args, **kwargs):
         self.squeeze_batch_dim = False
         super().__init__(*args, **kwargs)
@@ -566,16 +667,16 @@
         self._pop_batch_dim_sharding_for_output(dim_partition_dict)
         sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict)
-        print(sharding_spec_mapping)
         # get communication actions
         communication_action_mapping = {}
         if self.has_bias:
-            bias_comm_spec = self.get_communication_spec(
+            bias_comm_action = self.get_communication_action(
                 sharding_spec=sharding_spec_mapping['bias'],
                 communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
-                logical_process_axis=mesh_dim)
-            communication_action_mapping['bias'] = bias_comm_spec
+                logical_process_axis=mesh_dim,
+                comm_type=CommType.BEFORE,
+                arg_index=0)
+            communication_action_mapping['bias'] = bias_comm_action
         return self.get_sharding_strategy(name=name,
                                           sharding_spec_mapping=sharding_spec_mapping,
                                           communication_action_mapping=communication_action_mapping)
@@ -602,11 +703,13 @@
         # get communication actions
         communication_action_mapping = {}
         if self.has_bias:
-            bias_comm_spec = self.get_communication_spec(
+            bias_comm_action = self.get_communication_action(
                 sharding_spec=sharding_spec_mapping['bias'],
                 communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
-                logical_process_axis=[mesh_dim_0, mesh_dim_1])
-            communication_action_mapping['bias'] = bias_comm_spec
+                logical_process_axis=[mesh_dim_0, mesh_dim_1],
+                comm_type=CommType.BEFORE,
+                arg_index=0)
+            communication_action_mapping['bias'] = bias_comm_action
         return self.get_sharding_strategy(name=name,
                                           sharding_spec_mapping=sharding_spec_mapping,
@@ -637,18 +740,24 @@
         # get communication actions
         communication_action_mapping = {}
-        other_comm_spec = self.get_communication_spec(
+        other_comm_action = self.get_communication_action(
             sharding_spec=sharding_spec_mapping['other'],
             communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
-            logical_process_axis=mesh_dim_1)
-        communication_action_mapping['other'] = other_comm_spec
+            logical_process_axis=mesh_dim_1,
+            comm_type=CommType.BEFORE,
+            arg_index=1)
+        communication_action_mapping['other'] = other_comm_action
         if self.has_bias:
-            bias_comm_spec = self.get_communication_spec(
+            bias_comm_action = self.get_communication_action(
                 sharding_spec=sharding_spec_mapping['bias'],
                 communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
-                logical_process_axis=[mesh_dim_0, mesh_dim_1])
-            communication_action_mapping['bias'] = bias_comm_spec
+                logical_process_axis=[mesh_dim_0, mesh_dim_1],
+                comm_type=CommType.BEFORE,
+                arg_index=0)
+            communication_action_mapping['bias'] = bias_comm_action
+            # for addbmm case, other is the third argument instead of second.
+            communication_action_mapping['other'].arg_index += 1
         return self.get_sharding_strategy(name=name,
                                           sharding_spec_mapping=sharding_spec_mapping,
@@ -679,18 +788,23 @@
         # get communication actions
         communication_action_mapping = {}
-        input_comm_spec = self.get_communication_spec(
+        input_comm_action = self.get_communication_action(
             sharding_spec=sharding_spec_mapping['input'],
             communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
-            logical_process_axis=mesh_dim_1)
-        communication_action_mapping['input'] = input_comm_spec
+            logical_process_axis=mesh_dim_1,
+            comm_type=CommType.BEFORE,
+            arg_index=0)
+        communication_action_mapping['input'] = input_comm_action
         if self.has_bias:
-            bias_comm_spec = self.get_communication_spec(
+            bias_comm_action = self.get_communication_action(
                 sharding_spec=sharding_spec_mapping['bias'],
                 communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
-                logical_process_axis=mesh_dim_0)
-            communication_action_mapping['bias'] = bias_comm_spec
+                logical_process_axis=mesh_dim_0,
+                comm_type=CommType.BEFORE)
+            communication_action_mapping['bias'] = bias_comm_action
+            # for addbmm case, other is the second argument instead of first.
+            communication_action_mapping['input'].arg_index += 1
         return self.get_sharding_strategy(name=name,
                                           sharding_spec_mapping=sharding_spec_mapping,
@@ -719,18 +833,21 @@
         # get communication actions
         communication_action_mapping = {}
-        output_comm_spec = self.get_communication_spec(
+        output_comm_action = self.get_communication_action(
             sharding_spec=sharding_spec_mapping['output'],
             communication_pattern=CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD,
-            logical_process_axis=mesh_dim_1)
-        communication_action_mapping['output'] = output_comm_spec
+            logical_process_axis=mesh_dim_1,
+            comm_type=CommType.AFTER)
+        communication_action_mapping['output'] = output_comm_action
         if self.has_bias:
-            bias_comm_spec = self.get_communication_spec(
+            bias_comm_action = self.get_communication_action(
                 sharding_spec=sharding_spec_mapping['bias'],
                 communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
-                logical_process_axis=mesh_dim_0)
-            communication_action_mapping['bias'] = bias_comm_spec
+                logical_process_axis=mesh_dim_0,
+                comm_type=CommType.BEFORE,
+                arg_index=0)
+            communication_action_mapping['bias'] = bias_comm_action
         return self.get_sharding_strategy(name=name,
                                           sharding_spec_mapping=sharding_spec_mapping,
@@ -771,6 +888,5 @@ class BatchedMatMulStrategyGenerator(MatMulStrategyGenerator):
         # split two batch dim
         strategy_list.append(self.split_two_batch_dim(0, 1))
-        strategy_list.append(self.split_two_batch_dim(1, 0))

         return strategy_list
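
The `arg_index += 1` adjustments in the addbmm-related hunks above follow from PyTorch's argument order: `torch.bmm(input, mat2)` puts the two matmul operands at positions 0 and 1, while `torch.addbmm(input, batch1, batch2)` prepends the bias-like `input`, shifting them to positions 1 and 2. A quick standalone check with plain PyTorch (independent of this PR):

```python
import torch

b1 = torch.randn(4, 3, 5)
b2 = torch.randn(4, 5, 2)
bias = torch.randn(3, 2)

out_bmm = torch.bmm(b1, b2)              # matmul operands at arg positions 0 and 1
out_addbmm = torch.addbmm(bias, b1, b2)  # bias-like term at 0, operands shifted to 1 and 2
print(out_bmm.shape)                     # torch.Size([4, 3, 2])
print(out_addbmm.shape)                  # torch.Size([3, 2]) -- batch dim is reduced
```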

View File

@@ -41,7 +41,7 @@ def _split(tensor, comm_spec):
             dim = comm_spec.shard_dim
             length = tensor.shape[comm_spec.shard_dim] // len(rank_list)
             start = length * rank_list.index(dist.get_rank())
-            output = torch.narrow(tensor, dim, start, length)
+            output = torch.narrow(tensor, dim, start, length).contiguous()
             return output
@@ -76,6 +76,8 @@ def _all_reduce(tensor, comm_spec):
     process_groups_list = comm_spec.device_mesh.process_groups_dict[comm_spec.logical_process_axis]
     for rank_list, process_group in process_groups_list:
         if dist.get_rank() in rank_list:
+            if not tensor.is_contiguous():
+                tensor = tensor.contiguous()
             dist.all_reduce(tensor, op=ReduceOp.SUM, group=process_group)
     return tensor
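
The two `.contiguous()` additions above guard against strided views reaching the collective: `torch.narrow` along a non-leading dimension returns a non-contiguous view, and communication backends such as NCCL generally expect contiguous buffers. A small standalone illustration (no distributed setup required):

```python
import torch

x = torch.arange(12.).reshape(3, 4)
view = torch.narrow(x, 1, 0, 2)           # slice along dim 1, as _split does when shard_dim != 0
print(view.is_contiguous())               # False: the narrowed view keeps the parent's strides
print(view.contiguous().is_contiguous())  # True: .contiguous() materializes a dense copy
```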

View File

@@ -11,6 +11,7 @@ from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
 )
 from colossalai.device.device_mesh import DeviceMesh
 from colossalai.fx import ColoGraphModule, ColoTracer
+from colossalai.testing.pytest_wrapper import run_on_environment_flag
 from colossalai.testing.utils import parameterize
@@ -109,6 +110,7 @@ def test_linear_module_handler(bias):
     assert bias_sharding_spec.sharding_sequence[-1] == output_sharding_spec.sharding_sequence[-1]

+@run_on_environment_flag(name='AUTO_PARALLEL')
 @parameterize('bias', [True, False])
 def test_linear_function_handler(bias):
     model = nn.Linear(16, 32, bias=bias).to('meta')