diff --git a/colossalai/auto_parallel/solver/op_handler/dot_handler_v2.py b/colossalai/auto_parallel/solver/op_handler/dot_handler_v2.py
new file mode 100644
index 000000000..743a8582f
--- /dev/null
+++ b/colossalai/auto_parallel/solver/op_handler/dot_handler_v2.py
@@ -0,0 +1,139 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from .node_handler import ModuleHandler, NodeHandler
+from ..sharding_strategy import ShardingStrategy_V2, StrategyGenerator_V2, OperationDataType, OperationData
+from typing import List, Dict
+from .registry import operator_registry
+
+__all__ = ['LinearModuleHandler']
+
+
+class DotProductStrategyGenerator(StrategyGenerator_V2):
+    """TODO: to be implemented"""
+    pass
+
+
+class MatVecStrategyGenerator(StrategyGenerator_V2):
+    """TODO: to be implemented"""
+    pass
+
+
+class LinearProjectionStrategyGenerator(StrategyGenerator_V2):
+
+    def update_compute_cost(self, strategy: ShardingStrategy_V2) -> ShardingStrategy_V2:
+        """TODO: to be implemented"""
+        pass
+
+    def update_memory_cost(self, strategy: ShardingStrategy_V2) -> ShardingStrategy_V2:
+        """TODO: to be implemented"""
+        pass
+
+    def generate(self, operand_mapping: Dict[str, OperationData]) -> List[ShardingStrategy_V2]:
+        """TODO: to be implemented"""
+        pass
+
+    def validate(self, *args, **kwargs) -> bool:
+        """TODO: to be implemented"""
+        pass
+
+
+class BatchedMatMulStrategyGenerator(StrategyGenerator_V2):
+    """TODO: to be implemented"""
+    pass
+
+
+@operator_registry.register(torch.nn.Linear)
+class LinearModuleHandler(ModuleHandler):
+    """
+    A LinearModuleHandler which deals with the sharding strategies for the nn.Linear module.
+    """
+
+    def register_strategy_generator(self) -> List[StrategyGenerator_V2]:
+        generators = []
+        generators.append(LinearProjectionStrategyGenerator(self.device_mesh))
+        return generators
+
+    def get_operation_data_mapping(self) -> Dict[str, OperationData]:
+        # use the transposed shape for strategy generation
+        # the strategies will be transformed back to the weight's original shape in self.post_process
+        physical_input_operand = OperationData(name=str(self.node.args[0]),
+                                               type=OperationDataType.ARG,
+                                               data=self.node.args[0]._meta_data)
+        physical_other_operand = OperationData(name="weight",
+                                               type=OperationDataType.PARAM,
+                                               data=self.named_parameters['weight'],
+                                               logical_shape=self.named_parameters['weight'].shape[::-1])
+        physical_output = OperationData(name=str(self.node), type=OperationDataType.OUTPUT, data=self.node._meta_data)
+
+        mapping = {"input": physical_input_operand, "other": physical_other_operand, "output": physical_output}
+
+        if self.named_parameters['bias'] is not None:
+            physical_bias_operand = OperationData(name="bias",
+                                                  type=OperationDataType.PARAM,
+                                                  data=self.named_parameters['bias'])
+            mapping['bias'] = physical_bias_operand
+        return mapping
+
+    def post_process(self, strategy: ShardingStrategy_V2):
+        """
+        Convert the sharding spec of the weight parameter back to its original shape.
+ """ + for op_data, sharding_spec in strategy.input_sharding_specs.items(): + if op_data.name == "weight": + assert op_data.logical_shape != op_data.data.shape + dim_partition_dict = sharding_spec.dim_partition_dict + # switch first and last dim of the linear module weight + dim_partition_dict[0], dim_partition_dict[-1] = dim_partition_dict[-1], dim_partition_dict[0] + + # re-init the sharding spec + sharding_spec.__init__(sharding_spec.device_mesh, sharding_spec.entire_shape, dim_partition_dict) + return strategy + + +@operator_registry.register(F.linear) +class LinearFunctionHandler(NodeHandler): + """ + A LinearModuleHandler which deals with the sharding strategies for nn.Linear module. + """ + + def register_strategy_generator(self) -> List[StrategyGenerator_V2]: + generators = [] + generators.append(LinearProjectionStrategyGenerator(self.device_mesh)) + return generators + + def get_operation_data_mapping(self) -> Dict[str, OperationData]: + # use transposed shape for strategies + # the strategies will be transformed back to its original shape in self.post_process + physical_input_operand = OperationData(name=str(self.node.args[0]), + type=OperationDataType.ARG, + data=self.node.args[0]._meta_data) + physical_other_operand = OperationData(name=str(self.node.args[1]), + type=OperationDataType.ARG, + data=self.node.args[1]._meta_data, + logical_shape=self.node.args[1]._meta_data.shape[::-1]) + physical_output = OperationData(name=str(self.node), type=OperationDataType.OUTPUT, data=self.node._meta_data) + + mapping = {"input": physical_input_operand, "other": physical_other_operand, "output": physical_output} + + if self.node.args[2] is not None: + physical_bias_operand = OperationData(name=str(self.node.args[2]), + type=OperationDataType.ARG, + data=self.node.args[2]._meta_data) + mapping['bias'] = physical_bias_operand + return mapping + + def post_process(self, strategy: ShardingStrategy_V2): + """ + Convert the sharding spec of the weight parameter back to its original shape. 
+ """ + for op_data, sharding_spec in strategy.input_sharding_specs.items(): + if op_data.name == str(self.node.args[1]): + assert op_data.logical_shape != op_data.data.shape + dim_partition_dict = sharding_spec.dim_partition_dict + # switch first and last dim of the linear module weight + dim_partition_dict[0], dim_partition_dict[-1] = dim_partition_dict[-1], dim_partition_dict[0] + + # re-init the sharding spec + sharding_spec.__init__(sharding_spec.device_mesh, sharding_spec.entire_shape, dim_partition_dict) + return strategy diff --git a/colossalai/auto_parallel/solver/op_handler/node_handler.py b/colossalai/auto_parallel/solver/op_handler/node_handler.py index 8deafeb55..1b49f2028 100644 --- a/colossalai/auto_parallel/solver/op_handler/node_handler.py +++ b/colossalai/auto_parallel/solver/op_handler/node_handler.py @@ -2,7 +2,7 @@ from abc import ABC, abstractmethod from torch.fx.node import Node from colossalai.device.device_mesh import DeviceMesh from typing import Dict, List -from ..sharding_strategy import StrategiesVector, Operand, StrategyGenerator_V2 +from ..sharding_strategy import ShardingStrategy, ShardingStrategy_V2, StrategiesVector, OperationData, StrategyGenerator_V2 class NodeHandler(ABC): @@ -36,8 +36,15 @@ class NodeHandler(ABC): for generator in self.strategy_generator: strategies = generator.generate(operand_mapping) self.strategies_vector.extend(strategies) + + self.strategies_vector = map(self.post_process, self.strategies_vector) return self.strategies_vector + def post_process(self, strategy: ShardingStrategy_V2): + # tranform the strategy generated + # e.g. to process the sharding strategy for the transposed weights + return strategy + @abstractmethod def register_strategy_generator(self) -> List[StrategyGenerator_V2]: """ @@ -46,21 +53,40 @@ class NodeHandler(ABC): pass @abstractmethod - def get_operand_mapping(self) -> Dict[str, Operand]: + def get_operation_data_mapping(self) -> Dict[str, OperationData]: """ - Returns the mapping between the logical operand name to its physical operands. - A logical operand is defined by the strategy generator, for example, a matrix multiplication - operation has two operands "input" and "other". For a nn.Linear module, the physical operand for "input" is - the module input and the physical operand for "other" is the module weight. + Returns the mapping between the logical operation data to its physical data. + A logical operation data is a data associated with an operation, which can be input and output. It is + defined by the strategy generator, for example, a matrix multiplication operation has two operands "input" + and "other" and one result "output". For a nn.Linear module, the physical operand for "input" is + the module input, the physical operand for "other" is the module weight, and the physical result for "output" + is the module output. Note that the operand name is specified by the StrategyGenerator object. 
diff --git a/colossalai/auto_parallel/solver/sharding_strategy.py b/colossalai/auto_parallel/solver/sharding_strategy.py
index e7429d386..a3094d496 100644
--- a/colossalai/auto_parallel/solver/sharding_strategy.py
+++ b/colossalai/auto_parallel/solver/sharding_strategy.py
@@ -1,6 +1,10 @@
 from dataclasses import dataclass
 from abc import ABC, abstractmethod
 from enum import Enum
+import operator
+import torch
+from functools import reduce
+
 from colossalai.device.device_mesh import DeviceMesh
 from colossalai.tensor.sharding_spec import ShardingSpec
 from typing import Dict, List, Union, Tuple, Any
@@ -40,18 +44,35 @@ class ShardingStrategy:
     input_shardings: List[ShardingSpec] = None
 
 
-class OperandType(Enum):
+class OperationDataType(Enum):
     """
-    An operand can come from the argument list of an operator or the parameter list of a module.
+    The data of an operation can come from the argument list of an operator or the parameter list of a module.
     """
     ARG = 0
     PARAM = 1
+    OUTPUT = 2
 
 
 @dataclass
-class Operand:
+class OperationData:
+    """
+    OperationData is the data related to an operator; it can be an operand or the output.
+
+    Args:
+        name (str): the name of the operation-related data
+        type (OperationDataType): the type of the operation data
+        data (torch.Tensor): the value of this data, usually a meta tensor
+        logical_shape (Tuple[int]): the logical shape of the data, which can differ from its actual shape in memory
+ """ name: str - type: OperandType + type: OperationDataType + data: torch.Tensor + logical_shape: Tuple[int] = None + + def __post_init__(self): + # if no logical shape is specified, use the data shape as the logical shape + if self.logical_shape is None: + self.logical_shape = self.data.shape @dataclass @@ -69,6 +90,20 @@ class TrainCycleItem: total: Any +class CommunicationType(Enum): + FWD_ALL_REDUCE = 0 + BWD_ALL_REDUCE = 1 + + +@dataclass +class CommunicationAction: + """ + The actions + """ + type: CommunicationType + mesh_dim: int + + @dataclass class ShardingStrategy_V2: """ @@ -86,12 +121,35 @@ class ShardingStrategy_V2: strategy.(default to None) """ name: str - output_sharding_spec: ShardingSpec + sharding_specs: Dict[OperationData, ShardingSpec] = None compute_cost: TrainCycleItem = None communication_cost: TrainCycleItem = None memory_cost: TrainCycleItem = None - input_sharding_specs: Dict[Operand, ShardingSpec] = None - input_resharding_costs: Dict[Operand, List[float]] = None + input_resharding_costs: Dict[OperationData, List[float]] = None + communication_actions: Dict[OperationData, List[CommunicationAction]] = None + + @property + def input_sharding_specs(self) -> Dict[OperationData, ShardingSpec]: + specs = {} + specs.update(self._get_sharding_spec(OperationDataType.ARG)) + specs.update(self._get_sharding_spec(OperationDataType.PARAM)) + return specs + + @property + def argument_sharding_specs(self) -> Dict[OperationData, ShardingSpec]: + return self._get_sharding_spec(OperationDataType.ARG) + + @property + def param_sharding_specs(self) -> Dict[OperationData, ShardingSpec]: + return self._get_sharding_spec(OperationDataType.PARAM) + + @property + def output_sharding_specs(self) -> Dict[OperationData, ShardingSpec]: + return self._get_sharding_spec(OperationDataType.OUTPUT) + + def _get_sharding_spec(self, operation_data_type: OperationDataType): + specs = {k: v for k, v in self.sharding_specs.items() if k.type == operation_data_type} + return specs class StrategyGenerator_V2(ABC): @@ -104,9 +162,57 @@ class StrategyGenerator_V2(ABC): def __init__(self, device_mesh: DeviceMesh): self.device_mesh = device_mesh - @abstractmethod - def generate(self, operand_mapping: Dict[str, Operand]) -> List[ShardingStrategy_V2]: + def update_communication_cost(self, strategy: ShardingStrategy_V2) -> ShardingStrategy_V2: """ + Compute the communication cost involved in the forward and backward iteration. 
+ """ + + comm_cost = TrainCycleItem(fwd=0, bwd=0) + + def _compute_and_add(data: OperationData, action: CommunicationAction): + sharded_shape = strategy.sharding_specs[data].get_sharded_shape_per_device() + dtype = operand.data.dtype + size_per_elem_bytes = torch.tensor([], dtype=dtype).element_size() + num_bytes = size_per_elem_bytes * reduce(operator.mul, sharded_shape) + cost = self.device_mesh.all_reduce_cost(num_bytes=num_bytes, mesh_dim=action.mesh_dim) + + # compute the fwd + if action.type == CommunicationType.FWD_ALL_REDUCE: + comm_cost.fwd += cost + elif action.type == CommunicationType.BWD_ALL_REDUCE: + comm_cost.fwd += cost + else: + raise ValueError(f"Found unknown CommunicationType {action.type}") + + # check if communication action exists + # if so, loop over each action and compute the cost of each action + if strategy.communication_actions is not None: + for operand, actions in strategy.communication_actions: + for action in actions: + _compute_and_add(operand, action) + + # update the communication cost attribute in-place + strategy.communication_cost = comm_cost + return strategy + + @abstractmethod + def update_compute_cost(self, strategy: ShardingStrategy_V2) -> ShardingStrategy_V2: + """ + Customize this method to compute the computation flops. + """ + pass + + @abstractmethod + def update_memory_cost(self, strategy: ShardingStrategy_V2) -> ShardingStrategy_V2: + """ + Customize this method to compute the memory cost in bytes. + """ + pass + + @abstractmethod + def generate(self, operand_mapping: Dict[str, OperationData]) -> List[ShardingStrategy_V2]: + """ + Generate all possible sharding strategies for this operation. """ pass diff --git a/tests/test_auto_parallel/test_linear_handler_v2.py b/tests/test_auto_parallel/test_linear_handler_v2.py new file mode 100644 index 000000000..22fced5af --- /dev/null +++ b/tests/test_auto_parallel/test_linear_handler_v2.py @@ -0,0 +1,104 @@ +from colossalai.fx.tracer.meta_patch.patched_module import linear +import torch +import torch.nn as nn +from colossalai.fx import ColoTracer, ColoGraphModule +from colossalai.auto_parallel.solver.op_handler.dot_handler_v2 import LinearModuleHandler, LinearFunctionHandler +from colossalai.auto_parallel.solver.sharding_strategy import OperationData, OperationDataType, StrategiesVector +from colossalai.device.device_mesh import DeviceMesh + + +def test_linear_module_handler(): + model = nn.Sequential(nn.Linear(10, 20).to('meta')) + tracer = ColoTracer() + graph = tracer.trace(model, meta_args={"input": torch.rand(4, 10).to('meta')}) + gm = ColoGraphModule(model, graph) + physical_mesh_id = torch.arange(0, 4) + + print(graph) + mesh_shape = (2, 2) + device_mesh = DeviceMesh(physical_mesh_id, mesh_shape) + linear_mod_node = list(graph.nodes)[1] + strategies_vector = StrategiesVector(linear_mod_node) + + # build handler + handler = LinearModuleHandler(node=linear_mod_node, device_mesh=device_mesh, strategies_vector=strategies_vector) + + # check operation data mapping + mapping = handler.get_operation_data_mapping() + + for name, op_data in mapping.items(): + op_data: OperationData + # make sure they have valid values + assert op_data.logical_shape is not None + assert op_data.data is not None + + assert mapping['input'].name == "input_1" + assert mapping['input'].data.is_meta + assert mapping['input'].data.shape == torch.Size([4, 10]) + assert mapping['input'].type == OperationDataType.ARG + assert mapping['input'].logical_shape == torch.Size([4, 10]) + + assert mapping['other'].name == "weight" + 
+    assert mapping['other'].data.is_meta
+    assert mapping['other'].data.shape == torch.Size([20, 10])
+    assert mapping['other'].type == OperationDataType.PARAM
+    assert mapping['other'].logical_shape == torch.Size([10, 20])
+
+    assert mapping['bias'].name == "bias"
+    assert mapping['bias'].data.is_meta
+    assert mapping['bias'].data.shape == torch.Size([20])
+    assert mapping['bias'].type == OperationDataType.PARAM
+    assert mapping['bias'].logical_shape == torch.Size([20])
+
+    assert mapping['output'].name == "_0"
+    assert mapping['output'].data.is_meta
+    assert mapping['output'].data.shape == torch.Size([4, 20])
+    assert mapping['output'].type == OperationDataType.OUTPUT
+
+
+def test_linear_function_handler():
+    model = nn.Linear(10, 20).to('meta')
+    tracer = ColoTracer()
+    graph = tracer.trace(model, meta_args={"input": torch.rand(4, 10).to('meta')})
+    gm = ColoGraphModule(model, graph)
+    physical_mesh_id = torch.arange(0, 4)
+
+    print(graph)
+    mesh_shape = (2, 2)
+    device_mesh = DeviceMesh(physical_mesh_id, mesh_shape)
+    linear_func_node = list(graph.nodes)[3]
+    strategies_vector = StrategiesVector(linear_func_node)
+
+    # build handler
+    handler = LinearFunctionHandler(node=linear_func_node, device_mesh=device_mesh, strategies_vector=strategies_vector)
+
+    # check operation data mapping
+    mapping = handler.get_operation_data_mapping()
+
+    assert mapping['input'].name == "input_1"
+    assert mapping['input'].data.is_meta
+    assert mapping['input'].data.shape == torch.Size([4, 10])
+    assert mapping['input'].type == OperationDataType.ARG
+    assert mapping['input'].logical_shape == torch.Size([4, 10])
+
+    assert mapping['other'].name == "weight"
+    assert mapping['other'].data.is_meta
+    assert mapping['other'].data.shape == torch.Size([20, 10])
+    assert mapping['other'].type == OperationDataType.ARG
+    assert mapping['other'].logical_shape == torch.Size([10, 20])
+
+    assert mapping['bias'].name == "bias"
+    assert mapping['bias'].data.is_meta
+    assert mapping['bias'].data.shape == torch.Size([20])
+    assert mapping['bias'].type == OperationDataType.ARG
+    assert mapping['bias'].logical_shape == torch.Size([20])
+
+    assert mapping['output'].name == "linear"
+    assert mapping['output'].data.is_meta
+    assert mapping['output'].data.shape == torch.Size([4, 20])
+    assert mapping['output'].type == OperationDataType.OUTPUT
+
+
+if __name__ == '__main__':
+    test_linear_module_handler()
+    test_linear_function_handler()
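
The hard-coded node indices in the tests (list(graph.nodes)[1] and [3]) come from the order in which ColoTracer emits nodes for the two models. The sketch below (illustrative only, not part of the patch) prints the node names so the indexing is easy to verify; the expected name lists are inferred from the assertions above.

import torch
import torch.nn as nn
from colossalai.fx import ColoTracer

# a minimal sketch to inspect the traced node order used by the tests above
seq = nn.Sequential(nn.Linear(10, 20).to('meta'))
seq_graph = ColoTracer().trace(seq, meta_args={"input": torch.rand(4, 10).to('meta')})
print([node.name for node in seq_graph.nodes])      # expected: ['input_1', '_0', 'output'] -> module call at index 1

lin = nn.Linear(10, 20).to('meta')
lin_graph = ColoTracer().trace(lin, meta_args={"input": torch.rand(4, 10).to('meta')})
print([node.name for node in lin_graph.nodes])      # expected: ['input_1', 'weight', 'bias', 'linear', 'output'] -> F.linear at index 3
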