diff --git a/colossalai/auto_parallel/solver/op_handler/dot_handler_v2.py b/colossalai/auto_parallel/solver/op_handler/dot_handler_v2.py index 743a8582f..c5b858fe2 100644 --- a/colossalai/auto_parallel/solver/op_handler/dot_handler_v2.py +++ b/colossalai/auto_parallel/solver/op_handler/dot_handler_v2.py @@ -1,46 +1,12 @@ import torch -import torch.nn as nn import torch.nn.functional as F from .node_handler import ModuleHandler, NodeHandler -from ..sharding_strategy import ShardingStrategy_V2, StrategyGenerator_V2, OperationDataType, OperationData +from ..sharding_strategy import ShardingStrategy_V2, OperationDataType, OperationData +from ..strategy import LinearProjectionStrategyGenerator, StrategyGenerator_V2 from typing import List, Dict from .registry import operator_registry -__all__ = ['LinearModuleHandler'] - - -class DotProductStrategyGenerator(StrategyGenerator_V2): - """TODO: to be implemented""" - pass - - -class MatVecStrategyGenerator(StrategyGenerator_V2): - """TODO: to be implemented""" - pass - - -class LinearProjectionStrategyGenerator(StrategyGenerator_V2): - - def update_compute_cost(self, strategy: ShardingStrategy_V2) -> ShardingStrategy_V2: - """TODO: to be implemented""" - pass - - def update_memory_cost(self, strategy: ShardingStrategy_V2) -> ShardingStrategy_V2: - """TODO: to be implemented""" - pass - - def generate(self, operand_mapping: Dict[str, OperationData]) -> List[ShardingStrategy_V2]: - """TODO: to be implemented""" - pass - - def validate(self, *args, **kwargs) -> bool: - """TODO: to be implemented""" - pass - - -class BatchedMatMulStrategyGenerator(StrategyGenerator_V2): - """TODO: to be implemented""" - pass +__all__ = ['LinearModuleHandler', 'LinearFunctionHandler'] @operator_registry.register(torch.nn.Linear) @@ -49,9 +15,10 @@ class LinearModuleHandler(ModuleHandler): A LinearModuleHandler which deals with the sharding strategies for nn.Linear module. """ - def register_strategy_generator(self) -> List[StrategyGenerator_V2]: + def get_strategy_generator(self) -> List[StrategyGenerator_V2]: + op_data_mapping = self.get_operation_data_mapping() generators = [] - generators.append(LinearProjectionStrategyGenerator(self.device_mesh)) + generators.append(LinearProjectionStrategyGenerator(op_data_mapping, self.device_mesh)) return generators def get_operation_data_mapping(self) -> Dict[str, OperationData]: @@ -97,9 +64,10 @@ class LinearFunctionHandler(NodeHandler): A LinearModuleHandler which deals with the sharding strategies for nn.Linear module. 
""" - def register_strategy_generator(self) -> List[StrategyGenerator_V2]: + def get_strategy_generator(self) -> List[StrategyGenerator_V2]: + op_data_mapping = self.get_operation_data_mapping() generators = [] - generators.append(LinearProjectionStrategyGenerator(self.device_mesh)) + generators.append(LinearProjectionStrategyGenerator(op_data_mapping, self.device_mesh)) return generators def get_operation_data_mapping(self) -> Dict[str, OperationData]: @@ -108,8 +76,15 @@ class LinearFunctionHandler(NodeHandler): physical_input_operand = OperationData(name=str(self.node.args[0]), type=OperationDataType.ARG, data=self.node.args[0]._meta_data) + + # check if the other operand is a parameter + if isinstance(self.node.args[1]._meta_data, torch.nn.parameter.Parameter): + data_type = OperationDataType.PARAM + else: + data_type = OperationDataType.ARG + physical_other_operand = OperationData(name=str(self.node.args[1]), - type=OperationDataType.ARG, + type=data_type, data=self.node.args[1]._meta_data, logical_shape=self.node.args[1]._meta_data.shape[::-1]) physical_output = OperationData(name=str(self.node), type=OperationDataType.OUTPUT, data=self.node._meta_data) @@ -117,8 +92,13 @@ class LinearFunctionHandler(NodeHandler): mapping = {"input": physical_input_operand, "other": physical_other_operand, "output": physical_output} if self.node.args[2] is not None: + # check if the other operand is a parameter + if isinstance(self.node.args[2]._meta_data, torch.nn.parameter.Parameter): + data_type = OperationDataType.PARAM + else: + data_type = OperationDataType.ARG physical_bias_operand = OperationData(name=str(self.node.args[2]), - type=OperationDataType.ARG, + type=data_type, data=self.node.args[2]._meta_data) mapping['bias'] = physical_bias_operand return mapping diff --git a/colossalai/auto_parallel/solver/op_handler/node_handler.py b/colossalai/auto_parallel/solver/op_handler/node_handler.py index 1b49f2028..a509664fc 100644 --- a/colossalai/auto_parallel/solver/op_handler/node_handler.py +++ b/colossalai/auto_parallel/solver/op_handler/node_handler.py @@ -2,7 +2,8 @@ from abc import ABC, abstractmethod from torch.fx.node import Node from colossalai.device.device_mesh import DeviceMesh from typing import Dict, List -from ..sharding_strategy import ShardingStrategy, ShardingStrategy_V2, StrategiesVector, OperationData, StrategyGenerator_V2 +from ..sharding_strategy import ShardingStrategy_V2, StrategiesVector, OperationData +from ..strategy import StrategyGenerator_V2 class NodeHandler(ABC): @@ -26,14 +27,14 @@ class NodeHandler(ABC): self.successor_node = list(node.users.keys()) self.device_mesh = device_mesh self.strategies_vector = strategies_vector - self.strategy_generator = self.register_strategy_generator() def register_strategy(self) -> StrategiesVector: """ Register different sharding strategies for the current node. """ - operand_mapping = self.get_operand_mapping() - for generator in self.strategy_generator: + strategy_generators = self.get_strategy_generator() + operand_mapping = self.get_operation_data_mapping() + for generator in strategy_generators: strategies = generator.generate(operand_mapping) self.strategies_vector.extend(strategies) @@ -46,7 +47,7 @@ class NodeHandler(ABC): return strategy @abstractmethod - def register_strategy_generator(self) -> List[StrategyGenerator_V2]: + def get_strategy_generator(self) -> List[StrategyGenerator_V2]: """ Define which generators should be used by this NodeHandler object. 
""" @@ -81,6 +82,8 @@ class ModuleHandler(NodeHandler): def __init__(self, *args, **kwargs) -> None: super().__init__(*args, **kwargs) + print("created") + # set attributes to access module parameters for convenience assert self.node.graph.owning_module is not None, \ f'The graph is not associated with a module, please make sure it can be used to instantiate a GraphModule object.' diff --git a/colossalai/auto_parallel/solver/sharding_strategy.py b/colossalai/auto_parallel/solver/sharding_strategy.py index a3094d496..c63aae863 100644 --- a/colossalai/auto_parallel/solver/sharding_strategy.py +++ b/colossalai/auto_parallel/solver/sharding_strategy.py @@ -7,6 +7,7 @@ from functools import reduce from colossalai.device.device_mesh import DeviceMesh from colossalai.tensor.sharding_spec import ShardingSpec +from colossalai.tensor.shape_consistency import CollectiveCommPattern, CommSpec from typing import Dict, List, Union, Tuple, Any from torch.fx.node import Node from .constants import * @@ -90,18 +91,12 @@ class TrainCycleItem: total: Any -class CommunicationType(Enum): - FWD_ALL_REDUCE = 0 - BWD_ALL_REDUCE = 1 - - @dataclass -class CommunicationAction: +class MemoryCost: """ - The actions """ - type: CommunicationType - mesh_dim: int + activation: int = 0 + parameter: int = 0 @dataclass @@ -126,7 +121,7 @@ class ShardingStrategy_V2: communication_cost: TrainCycleItem = None memory_cost: TrainCycleItem = None input_resharding_costs: Dict[OperationData, List[float]] = None - communication_actions: Dict[OperationData, List[CommunicationAction]] = None + communication_actions: Dict[OperationData, CommSpec] = None @property def input_sharding_specs(self) -> Dict[OperationData, ShardingSpec]: @@ -152,79 +147,6 @@ class ShardingStrategy_V2: return specs -class StrategyGenerator_V2(ABC): - """ - StrategyGenerator is used to generate the same group of sharding strategies. - - TODO: remove the original strategy_generator.py after refactoring - """ - - def __init__(self, device_mesh: DeviceMesh): - self.device_mesh = device_mesh - - def update_communication_cost(self, strategy: ShardingStrategy_V2) -> ShardingStrategy_V2: - """ - Compute the communication cost involved in the forward and backward iteration. - """ - - comm_cost = TrainCycleItem(fwd=0, bwd=0) - - def _compute_and_add(data: OperationData, action: CommunicationAction): - sharded_shape = strategy.sharding_specs[data].get_sharded_shape_per_device() - dtype = operand.data.dtype - size_per_elem_bytes = torch.tensor([], dtype=dtype).element_size() - num_bytes = size_per_elem_bytes * reduce(operator.mul, sharded_shape) - cost = self.device_mesh.all_reduce_cost(num_bytes=num_bytes, mesh_dim=action.mesh_dim) - - # compute the fwd - if action.type == CommunicationType.FWD_ALL_REDUCE: - comm_cost.fwd += cost - elif action.type == CommunicationType.BWD_ALL_REDUCE: - comm_cost.fwd += cost - else: - raise ValueError(f"Found unknown CommunicationType {action.type}") - - # check if communication action exists - # if so, loop over each action and compute the cost of each action - if strategy.communication_actions is not None: - for operand, actions in strategy.communication_actions: - for action in actions: - _compute_and_add(operand, action) - - # update the communication cost attribute in-place - strategy.communication_cost = comm_cost - return strategy - - @abstractmethod - def update_compute_cost(self, strategy: ShardingStrategy_V2) -> ShardingStrategy_V2: - """ - Customize this method to compute the computation flops. 
-        """
-        pass
-
-    @abstractmethod
-    def update_memory_cost(self, strategy: ShardingStrategy_V2) -> ShardingStrategy_V2:
-        """
-        Customize this method to compute the memory cost in bytes.
-        """
-        pass
-
-    @abstractmethod
-    def generate(self, operand_mapping: Dict[str, OperationData]) -> List[ShardingStrategy_V2]:
-        """
-        Generate all possible sharding strategies for this operation.
-        """
-        pass
-
-    @abstractmethod
-    def validate(self, *args, **kwargs) -> bool:
-        """
-        Validate if the operands are of desired shape.
-        If True, means this generator can be used for the current operation.
-        """
-        pass
-
-
 class StrategiesVector(list):
     '''
     Each node in fx graph will have a corresponding StrategiesVector, to store all the possible
diff --git a/colossalai/auto_parallel/solver/strategy/__init__.py b/colossalai/auto_parallel/solver/strategy/__init__.py
new file mode 100644
index 000000000..634b3e5af
--- /dev/null
+++ b/colossalai/auto_parallel/solver/strategy/__init__.py
@@ -0,0 +1,7 @@
+from .strategy_generator import StrategyGenerator_V2
+from .matmul_strategy_generator import DotProductStrategyGenerator, MatVecStrategyGenerator, LinearProjectionStrategyGenerator, BatchedMatMulStrategyGenerator
+
+__all__ = [
+    'StrategyGenerator_V2', 'DotProductStrategyGenerator', 'MatVecStrategyGenerator',
+    'LinearProjectionStrategyGenerator', 'BatchedMatMulStrategyGenerator'
+]
diff --git a/colossalai/auto_parallel/solver/strategy/matmul_strategy_generator.py b/colossalai/auto_parallel/solver/strategy/matmul_strategy_generator.py
new file mode 100644
index 000000000..d1b561cb5
--- /dev/null
+++ b/colossalai/auto_parallel/solver/strategy/matmul_strategy_generator.py
@@ -0,0 +1,362 @@
+import operator
+import torch
+from functools import reduce
+from ..sharding_strategy import ShardingStrategy_V2, TrainCycleItem, MemoryCost
+from colossalai.tensor.shape_consistency import CollectiveCommPattern
+from .strategy_generator import StrategyGenerator_V2
+from typing import List
+
+
+class DotProductStrategyGenerator(StrategyGenerator_V2):
+    """TODO: to be implemented"""
+    pass
+
+
+class MatVecStrategyGenerator(StrategyGenerator_V2):
+    """TODO: to be implemented"""
+    pass
+
+
+class LinearProjectionStrategyGenerator(StrategyGenerator_V2):
+
+    def update_compute_cost(self, strategy: ShardingStrategy_V2) -> ShardingStrategy_V2:
+        # C = AB
+        # C: [M, N], A: [M, P], B: [P, N]
+        # fwd cost = MNP (only count mul)
+        # bwd: 2 x fwd_cost
+        sharded_input_shape = strategy.sharding_specs[self.op_data['input']].get_sharded_shape_per_device()
+        sharded_other_shape = strategy.sharding_specs[self.op_data['other']].get_sharded_shape_per_device()
+        dim_m_val = reduce(operator.mul, sharded_input_shape[:-1])
+        dim_n_val = sharded_other_shape[-1]
+        dim_p_val = sharded_other_shape[0]
+
+        fwd_compute_cost = dim_m_val * dim_n_val * dim_p_val
+        bwd_compute_cost = fwd_compute_cost * 2
+        compute_cost = TrainCycleItem(fwd=fwd_compute_cost,
+                                      bwd=bwd_compute_cost,
+                                      total=fwd_compute_cost + bwd_compute_cost)
+        strategy.compute_cost = compute_cost
+
+    def update_memory_cost(self, strategy: ShardingStrategy_V2) -> ShardingStrategy_V2:
+        input_size = self._compute_size_in_bytes(strategy, "input")
+        other_size = self._compute_size_in_bytes(strategy, "other")
+
+        if "bias" in self.op_data:
+            bias_size = self._compute_size_in_bytes(strategy, "bias")
+        else:
+            bias_size = 0
+        output_size = self._compute_size_in_bytes(strategy, "output")
+
+        fwd_mem_cost = MemoryCost(activation=output_size,
+                                  parameter=other_size + bias_size)
+        bwd_mem_cost = MemoryCost(activation=input_size + other_size + bias_size, parameter=other_size)
+        total_mem_cost = MemoryCost(activation=input_size + 2 * output_size + bias_size,
+                                    parameter=other_size + bias_size)
+        memory_cost = TrainCycleItem(fwd=fwd_mem_cost, bwd=bwd_mem_cost, total=total_mem_cost)
+        strategy.memory_cost = memory_cost
+
+    def generate(self) -> List[ShardingStrategy_V2]:
+        strategies = []
+
+        # SS = SR x RS
+        strategies.append(self.split_lhs_space_rhs_space(0, 1))
+        strategies.append(self.split_lhs_space_rhs_space(1, 0))
+
+        # SR = SS x SR
+        strategies.append(self.split_lhs_space_both_contract(0, 1))
+        strategies.append(self.split_lhs_space_both_contract(1, 0))
+
+        # RS = RS x SS
+        strategies.append(self.split_rhs_space_both_contract(0, 1))
+        strategies.append(self.split_rhs_space_both_contract(1, 0))
+
+        # RR = RS x SR
+        strategies.append(self.recompute_split_both_contract(0))
+        strategies.append(self.recompute_split_both_contract(1))
+
+        # RS = RR x RS
+        strategies.append(self.split_rhs_space_only(0))
+        strategies.append(self.split_rhs_space_only(1))
+
+        # S01R = S01R x RR
+        strategies.append(self.split_lhs_1st_dim_1d(0, 1))
+
+        # RR = RS01 x S01R
+        strategies.append(self.split_lhs_2nd_dim_1d(0, 1))
+
+        # RS01 = RR x RS01
+        strategies.append(self.split_rhs_2nd_dim_1d(0, 1))
+
+        # update meta info on cost
+        for strategy in strategies:
+            self.update_communication_cost(strategy)
+            self.update_compute_cost(strategy)
+            self.update_memory_cost(strategy)
+
+        return strategies
+
+    def split_lhs_space_rhs_space(self, mesh_dim_0, mesh_dim_1):
+        # handle case SS = SR x RS
+        name = f'S{mesh_dim_0}S{mesh_dim_1} = S{mesh_dim_0}R x RS{mesh_dim_1}'
+        dim_partition_dict_mapping = {
+            "input": {
+                0: [mesh_dim_0]
+            },
+            "other": {
+                self.dim_q: [mesh_dim_1]
+            },
+            "bias": {
+                -1: [mesh_dim_1]
+            },
+            "output": {
+                0: [mesh_dim_0],
+                -1: [mesh_dim_1]
+            },
+        }
+        sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping)
+
+        # set communication action
+        input_comm_spec = self.get_communication_spec(
+            sharding_spec=sharding_spec_mapping["input"],
+            communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
+            logical_process_axis=mesh_dim_1)
+        other_comm_spec = self.get_communication_spec(
+            sharding_spec=sharding_spec_mapping["output"],
+            communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
+            logical_process_axis=mesh_dim_0)
+
+        communication_action_mapping = {"input": input_comm_spec, "other": other_comm_spec}
+
+        return self.get_sharding_strategy(name=name,
+                                          sharding_spec_mapping=sharding_spec_mapping,
+                                          communication_action_mapping=communication_action_mapping)
+
+    def split_lhs_space_both_contract(self, mesh_dim_0, mesh_dim_1):
+        # handle the case SR = SS x SR
+        name = f'S{mesh_dim_0}R = S{mesh_dim_0}S{mesh_dim_1} x S{mesh_dim_1}R'
+
+        # get sharding spec mapping
+        dim_partition_dict_mapping = {
+            "input": {
+                0: [mesh_dim_0],
+                -1: [mesh_dim_1]
+            },
+            "other": {
+                self.dim_p: [mesh_dim_1]
+            },
+            "bias": {},
+            "output": {
+                0: [mesh_dim_0]
+            },
+        }
+        sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping)
+
+        # get communication action mapping
+        input_comm_spec = self.get_communication_spec(
+            sharding_spec=sharding_spec_mapping["input"],
+            communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
+            logical_process_axis=mesh_dim_0)
+        output_comm_spec = self.get_communication_spec(
+            sharding_spec=sharding_spec_mapping["output"],
+            communication_pattern=CollectiveCommPattern.REDUCE_FWD_IDENTITY_BWD,
+            logical_process_axis=mesh_dim_1)
+
+        communication_action_mapping = {"input": input_comm_spec, 'output': output_comm_spec}
+
+        return self.get_sharding_strategy(name=name,
+                                          sharding_spec_mapping=sharding_spec_mapping,
+                                          communication_action_mapping=communication_action_mapping)
+
+    def split_rhs_space_both_contract(self, mesh_dim_0, mesh_dim_1):
+        name = f'RS{mesh_dim_1} = RS{mesh_dim_0} x S{mesh_dim_0}S{mesh_dim_1}'
+
+        # get sharding specs
+        dim_partition_dict_mapping = {
+            "input": {
+                -1: [mesh_dim_0]
+            },
+            "other": {
+                self.dim_p: [mesh_dim_0],
+                self.dim_q: [mesh_dim_1]
+            },
+            "bias": {
+                -1: [mesh_dim_1]
+            },
+            "output": {
+                -1: [mesh_dim_1]
+            },
+        }
+        sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping)
+
+        # get communication actions
+        output_comm_spec = self.get_communication_spec(
+            sharding_spec=sharding_spec_mapping['output'],
+            communication_pattern=CollectiveCommPattern.REDUCE_FWD_IDENTITY_BWD,
+            logical_process_axis=mesh_dim_0)
+        input_comm_spec = self.get_communication_spec(
+            sharding_spec=sharding_spec_mapping['input'],
+            communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
+            logical_process_axis=mesh_dim_1)
+        communication_action_mapping = {"output": output_comm_spec, "input": input_comm_spec}
+        return self.get_sharding_strategy(name=name,
+                                          sharding_spec_mapping=sharding_spec_mapping,
+                                          communication_action_mapping=communication_action_mapping)
+
+    def recompute_split_both_contract(self, mesh_dim):
+        name = f'RR = RS{mesh_dim} x S{mesh_dim}R'
+
+        # get sharding spec
+        dim_partition_dict_mapping = {
+            "input": {
+                -1: [mesh_dim]
+            },
+            "other": {
+                self.dim_p: [mesh_dim]
+            },
+            "bias": {},
+            "output": {},
+        }
+
+        sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping)
+
+        # get communication action
+        output_comm_spec = self.get_communication_spec(
+            sharding_spec=sharding_spec_mapping['output'],
+            communication_pattern=CollectiveCommPattern.REDUCE_FWD_IDENTITY_BWD,
+            logical_process_axis=mesh_dim)
+        communication_action_mapping = {'output': output_comm_spec}
+        return self.get_sharding_strategy(name=name,
+                                          sharding_spec_mapping=sharding_spec_mapping,
+                                          communication_action_mapping=communication_action_mapping)
+
+    def split_rhs_space_only(self, mesh_dim):
+        name = f'RS{mesh_dim} = RR x RS{mesh_dim}'
+
+        # get sharding spec
+        dim_partition_dict_mapping = {
+            "input": {},
+            "other": {
+                self.dim_q: [mesh_dim]
+            },
+            "bias": {
+                -1: [mesh_dim]
+            },
+            "output": {
+                -1: [mesh_dim]
+            },
+        }
+
+        sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping)
+
+        # get communication actions
+        input_comm_spec = self.get_communication_spec(
+            sharding_spec=sharding_spec_mapping['input'],
+            communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
+            logical_process_axis=mesh_dim)
+        communication_action_mapping = {'input': input_comm_spec}
+        return self.get_sharding_strategy(name=name,
+                                          sharding_spec_mapping=sharding_spec_mapping,
+                                          communication_action_mapping=communication_action_mapping)
+
+    def split_lhs_1st_dim_1d(self, mesh_dim_0, mesh_dim_1):
+        name = f'S{mesh_dim_0}{mesh_dim_1}R = S{mesh_dim_0}{mesh_dim_1}R x RR'
+        # get sharding spec
+        dim_partition_dict_mapping = {
+            "input": {
+                0: [mesh_dim_0, mesh_dim_1]
+            },
+            "other": {},
+            "bias": {},
+            "output": {
+                0: [mesh_dim_0, mesh_dim_1]
+            },
+        }
+        sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping)
+
+        # get communication action
+        other_comm_spec = self.get_communication_spec(
+            sharding_spec=sharding_spec_mapping['other'],
+            communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
+            logical_process_axis=[mesh_dim_0, mesh_dim_1])
+
+        communication_action_mapping = {"other": other_comm_spec}
+        return self.get_sharding_strategy(name=name,
+                                          sharding_spec_mapping=sharding_spec_mapping,
+                                          communication_action_mapping=communication_action_mapping)
+
+    def split_lhs_2nd_dim_1d(self, mesh_dim_0, mesh_dim_1):
+        name = f'RR = RS{mesh_dim_0}{mesh_dim_1} x S{mesh_dim_0}{mesh_dim_1}R'
+
+        # get sharding spec
+        dim_partition_dict_mapping = {
+            "input": {
+                -1: [mesh_dim_0, mesh_dim_1]
+            },
+            "other": {
+                self.dim_p: [mesh_dim_0, mesh_dim_1]
+            },
+            "bias": {},
+            "output": {},
+        }
+        sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping)
+
+        # get communication action
+        output_comm_spec = self.get_communication_spec(
+            sharding_spec=sharding_spec_mapping['output'],
+            communication_pattern=CollectiveCommPattern.REDUCE_FWD_IDENTITY_BWD,
+            logical_process_axis=[mesh_dim_0, mesh_dim_1])
+        communication_action_mapping = {'output': output_comm_spec}
+
+        return self.get_sharding_strategy(name=name,
+                                          sharding_spec_mapping=sharding_spec_mapping,
+                                          communication_action_mapping=communication_action_mapping)
+
+    def split_rhs_2nd_dim_1d(self, mesh_dim_0, mesh_dim_1):
+        name = f'RS{mesh_dim_0}{mesh_dim_1} = RR x RS{mesh_dim_0}{mesh_dim_1}'
+
+        # get sharding spec
+        dim_partition_dict_mapping = {
+            "input": {},
+            "other": {
+                self.dim_q: [mesh_dim_0, mesh_dim_1]
+            },
+            "bias": {
+                -1: [mesh_dim_0, mesh_dim_1]
+            },
+            "output": {
+                -1: [mesh_dim_0, mesh_dim_1]
+            },
+        }
+        sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping)
+
+        # get communication action
+        input_comm_spec = self.get_communication_spec(
+            sharding_spec=sharding_spec_mapping['input'],
+            communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
+            logical_process_axis=[mesh_dim_0, mesh_dim_1])
+        communication_action_mapping = {'input': input_comm_spec}
+
+        return self.get_sharding_strategy(name=name,
+                                          sharding_spec_mapping=sharding_spec_mapping,
+                                          communication_action_mapping=communication_action_mapping)
+
+    def validate(self) -> bool:
+        assert "input" in self.op_data
+        assert "other" in self.op_data
+
+        # make sure the other operand is a 2D tensor
+        input_data = self.op_data['input']
+        other_data = self.op_data['other']
+        assert input_data.data.dim() > 0 and other_data.data.dim() == 2
+        assert other_data.logical_shape[0] == input_data.logical_shape[-1]
+
+        # if a bias is present, its last dim must match the last dim of the other operand
+        has_bias = "bias" in self.op_data
+
+        if has_bias:
+            bias_data = self.op_data['bias']
+            assert bias_data.logical_shape[-1] == other_data.logical_shape[-1]
+
+
+class BatchedMatMulStrategyGenerator(StrategyGenerator_V2):
+    """TODO: to be implemented"""
+    pass
diff --git a/colossalai/auto_parallel/solver/strategy/strategy_generator.py b/colossalai/auto_parallel/solver/strategy/strategy_generator.py
new file mode 100644
index 000000000..6b73ba0ce
--- /dev/null
+++ b/colossalai/auto_parallel/solver/strategy/strategy_generator.py
@@ -0,0 +1,153 @@
+import operator
+import torch
+from colossalai.tensor.sharding_spec import ShardingSpec
+from functools import reduce
+from abc import ABC, abstractmethod
+from colossalai.tensor.shape_consistency import CollectiveCommPattern, CommSpec
+from colossalai.device.device_mesh import DeviceMesh
+from typing import Dict, List, Union, Any
+from ..sharding_strategy import OperationData, ShardingStrategy_V2, TrainCycleItem
+
+
+class StrategyGenerator_V2(ABC):
+    """
+    StrategyGenerator is used to generate the same group of sharding strategies.
+
+    TODO: remove the original strategy_generator.py after refactoring
+    """
+
+    def __init__(self, operation_data_mapping: Dict[str, OperationData], device_mesh: DeviceMesh):
+        self.op_data = operation_data_mapping
+        self.device_mesh = device_mesh
+
+    def get_sharding_strategy(self, name: str, sharding_spec_mapping: Dict[str, ShardingSpec],
+                              communication_action_mapping: Dict[str, CommSpec]):
+        """
+        A factory method to produce a ShardingStrategy object.
+
+        Args:
+            sharding_spec_mapping (Dict[str, ShardingSpec]): the mapping between the operation data name and the ShardingSpec object.
+            communication_action_mapping (Dict[str, CommSpec]): the mapping between the operation data name and the CommSpec object.
+        """
+        sharding_specs = self.replace_op_name_with_op_data(sharding_spec_mapping)
+        communication_actions = self.replace_op_name_with_op_data(communication_action_mapping)
+        return ShardingStrategy_V2(name=name,
+                                   sharding_specs=sharding_specs,
+                                   communication_actions=communication_actions)
+
+    def to_sharding_spec_mapping(self, mapping: Dict[str, Dict[int, List[int]]]):
+        """
+        A utility method to convert the dim partition dict to a ShardingSpec object.
+
+        Args:
+            mapping (Dict[str, Dict[int, List[int]]]): the key of the mapping is the operation data name and the value is a dim partition dictionary.
+        """
+        results = {}
+        for op_data_name, dim_partition_dict in mapping.items():
+            op_data = self.op_data[op_data_name]
+            sharding_spec = ShardingSpec(device_mesh=self.device_mesh,
+                                         entire_shape=op_data.logical_shape,
+                                         dim_partition_dict=dim_partition_dict)
+            results[op_data_name] = sharding_spec
+        return results
+
+    def replace_op_name_with_op_data(self, mapping: Dict[str, Any]):
+        """
+        Convert the key of the dictionary from the operation data name to an OperationData object.
+        """
+        results = {}
+        for k, v in mapping.items():
+            op_data = self.op_data[k]
+            results[op_data] = v
+        return results
+
+    def get_communication_spec(self, sharding_spec: ShardingSpec, communication_pattern: CollectiveCommPattern,
+                               logical_process_axis: Union[int, List[int]]):
+        """
+        A factory method to produce a CommSpec object.
+        """
+        # use the flattened device mesh when the same action is applied to two mesh axes
+        if isinstance(logical_process_axis, list) and len(logical_process_axis) == 2:
+            sharding_spec.device_mesh = sharding_spec.device_mesh.flatten()
+            logical_process_axis = 0
+        return CommSpec(comm_pattern=communication_pattern,
+                        sharding_spec=sharding_spec,
+                        logical_process_axis=logical_process_axis)
+
+    def update_communication_cost(self, strategy: ShardingStrategy_V2) -> ShardingStrategy_V2:
+        """
+        Compute the communication cost involved in the forward and backward iteration.
+        """
+
+        comm_cost = TrainCycleItem(fwd=0, bwd=0)
+
+        def _compute_and_add(data: OperationData, comm_spec: CommSpec):
+            num_ele_in_comm = comm_spec.get_comm_cost()
+            dtype = data.data.dtype
+            size_per_elem_bytes = torch.tensor([], dtype=dtype).element_size()
+            cost = size_per_elem_bytes * num_ele_in_comm
+
+            # add the cost to the forward or backward pass accordingly
+            # TODO: comm_spec.get_comm_cost should return a TrainCycleItem instead of the total cost.
+            # it works fine here because only REDUCE_FWD_IDENTITY_BWD and IDENTITY_FWD_ALLREDUCE_BWD are used,
+            # so the total cost is either for fwd or bwd.
+            if comm_spec.comm_pattern == CollectiveCommPattern.REDUCE_FWD_IDENTITY_BWD:
+                comm_cost.fwd += cost
+            elif comm_spec.comm_pattern == CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD:
+                comm_cost.bwd += cost
+            else:
+                raise ValueError(f"Found unknown CollectiveCommPattern {comm_spec.comm_pattern}")
+
+        # check if communication action exists
+        # if so, loop over each action and compute the cost of each action
+        if strategy.communication_actions is not None:
+            for operand, comm_spec in strategy.communication_actions.items():
+                _compute_and_add(operand, comm_spec)
+
+        # update the communication cost attribute in-place
+        strategy.communication_cost = comm_cost
+        return strategy
+
+    @abstractmethod
+    def update_compute_cost(self, strategy: ShardingStrategy_V2) -> ShardingStrategy_V2:
+        """
+        Customize this method to compute the computation flops.
+        """
+        pass
+
+    @abstractmethod
+    def update_memory_cost(self, strategy: ShardingStrategy_V2) -> ShardingStrategy_V2:
+        """
+        Customize this method to compute the memory cost in bytes.
+        """
+        pass
+
+    def _compute_size_in_bytes(self, strategy: ShardingStrategy_V2, key: str):
+        """
+        Compute the size of a tensor in bytes.
+
+        Args:
+            strategy (ShardingStrategy): the ShardingStrategy generated.
+            key (str): the name of the operation data defined by the generator.
+
+        """
+        op_data = self.op_data[key]
+        sharded_shape = strategy.sharding_specs[op_data].get_sharded_shape_per_device()
+        dtype = self.op_data[key].data.dtype
+        size_per_elem_bytes = torch.tensor([], dtype=dtype).element_size()
+        return reduce(operator.mul, sharded_shape) * size_per_elem_bytes
+
+    @abstractmethod
+    def generate(self) -> List[ShardingStrategy_V2]:
+        """
+        Generate all possible sharding strategies for this operation.
+        """
+        pass
+
+    @abstractmethod
+    def validate(self, *args, **kwargs) -> bool:
+        """
+        Validate if the operands are of desired shape.
+        If True, means this generator can be used for the current operation.
+        """
+        pass
diff --git a/tests/test_auto_parallel/test_linear_handler_v2.py b/tests/test_auto_parallel/test_linear_handler_v2.py
index 22fced5af..3b020cba3 100644
--- a/tests/test_auto_parallel/test_linear_handler_v2.py
+++ b/tests/test_auto_parallel/test_linear_handler_v2.py
@@ -84,13 +84,13 @@ def test_linear_function_handler():
     assert mapping['other'].name == "weight"
     assert mapping['other'].data.is_meta
     assert mapping['other'].data.shape == torch.Size([20, 10])
-    assert mapping['other'].type == OperationDataType.ARG
+    assert mapping['other'].type == OperationDataType.PARAM
     assert mapping['other'].logical_shape == torch.Size([10, 20])
 
     assert mapping['bias'].name == "bias"
     assert mapping['bias'].data.is_meta
     assert mapping['bias'].data.shape == torch.Size([20])
-    assert mapping['bias'].type == OperationDataType.ARG
+    assert mapping['bias'].type == OperationDataType.PARAM
     assert mapping['other'].logical_shape == torch.Size([10, 20])
 
     assert mapping['output'].name == "linear"
@@ -100,5 +100,5 @@
 
 
 if __name__ == '__main__':
-    # test_linear_module_handler()
+    test_linear_module_handler()
     test_linear_function_handler()