diff --git a/colossalai/auto_parallel/solver/_utils.py b/colossalai/auto_parallel/solver/_utils.py
index dd2a09053..77faa706a 100644
--- a/colossalai/auto_parallel/solver/_utils.py
+++ b/colossalai/auto_parallel/solver/_utils.py
@@ -24,7 +24,7 @@ def generate_sharding_spec(input_: Union[Node, torch.Tensor], device_mesh: Devic
     """
     if isinstance(input_, Node):
-        assert hasattr(input_, '_meta_data'), f'The given node has not attribte _meta_data'
+        assert hasattr(input_, '_meta_data'), f'The given node has no attribute _meta_data'
         meta_tensor = input_._meta_data
         assert meta_tensor is not None, "The given node's _meta_data attribute is None"
         shape = meta_tensor.shape
@@ -47,7 +47,8 @@ def generate_sharding_spec(input_: Union[Node, torch.Tensor], device_mesh: Devic
 def generate_resharding_costs(nodes: List[Node],
                               sharding_specs: List[ShardingSpec],
                               count_backward: Optional[bool] = True,
-                              dtype: Optional[torch.dtype] = None):
+                              dtype: Optional[torch.dtype] = None,
+                              index=None):
     '''
     Compute the resharding costs with this specific strategy.
 
@@ -68,6 +69,9 @@ def generate_resharding_costs(nodes: List[Node],
         resharding_costs[input_node] = []
         for strategy in input_node.strategies_vector:
             input_sharding_spec = strategy.output_sharding_spec
+            if not isinstance(input_sharding_spec, ShardingSpec):
+                assert isinstance(input_sharding_spec, list), 'only ShardingSpec or List[ShardingSpec] is expected.'
+                input_sharding_spec = input_sharding_spec[index]
             assert isinstance(input_sharding_spec, ShardingSpec), f'The input node should NOT be a tuple of tensor.'
             try:
                 # compute the resharding cost
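For context on the new `index` argument: when a predecessor node produces a tuple of tensors, its strategies appear to record a list of output sharding specs (hence the `List[ShardingSpec]` assertion above), and each consumer selects the element it actually reads. A minimal, self-contained sketch of that selection logic, with plain strings standing in for `ShardingSpec` and illustrative names that are not part of the library's API:

```python
from typing import List, Optional, Union

def resolve_output_spec(output_spec: Union[str, List[str]], index: Optional[int] = None) -> str:
    """Pick the spec this consumer reads from a (possibly tuple-output) producer."""
    if isinstance(output_spec, list):
        # a tuple-output producer records one spec per element, chosen by position
        assert index is not None, 'index is required when the producer yields a tuple'
        return output_spec[index]
    return output_spec

assert resolve_output_spec('S0R') == 'S0R'
assert resolve_output_spec(['S0R', 'RR'], index=1) == 'RR'
```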
diff --git a/colossalai/auto_parallel/solver/constants.py b/colossalai/auto_parallel/solver/constants.py
index ecaa74ca7..d9f06bf70 100644
--- a/colossalai/auto_parallel/solver/constants.py
+++ b/colossalai/auto_parallel/solver/constants.py
@@ -9,12 +9,22 @@ __all__ = [
 ELEMENTWISE_MODULE_OP = [torch.nn.Dropout, torch.nn.ReLU]
 ELEMENTWISE_FUNC_OP = [
-    torch.abs, torch.cos, torch.exp, operator.neg, torch.multiply, torch.nn.functional.relu,
-    torch.nn.functional.dropout, torch.flatten
+    torch.abs,
+    torch.cos,
+    torch.exp,
+    operator.neg,
+    torch.multiply,
+    torch.nn.functional.relu,
+    torch.nn.functional.dropout,
+    torch.flatten,
+    # softmax should not be here
+    torch.nn.functional.softmax
 ]
 ELEMENTWISE_METHOD_OP = [
     torch.Tensor.to,
     torch.Tensor.type,
+    # TODO: contiguous may need some extra processing.
+    torch.Tensor.contiguous
 ]
 RESHAPE_FUNC_OP = [torch.flatten, torch.reshape]
 RESHAPE_METHOD_OP = [
@@ -26,7 +36,7 @@ RESHAPE_METHOD_OP = [
 ]
 BCAST_FUNC_OP = [
     torch.add, torch.sub, torch.mul, torch.div, torch.floor_divide, torch.true_divide, operator.add, operator.sub,
-    operator.mul, operator.floordiv, operator.truediv, torch.matmul
+    operator.mul, operator.floordiv, operator.truediv, torch.matmul, torch.where, operator.pow, torch.pow, torch.tanh
 ]
 CONV_MODULE_OP = [
     torch.nn.Conv1d, torch.nn.Conv2d, torch.nn.Conv3d, torch.nn.ConvTranspose1d, torch.nn.ConvTranspose2d,
@@ -41,6 +51,34 @@ LINEAR_FUNC_OP = [torch.nn.functional.linear, torch.matmul, torch.bmm]
 BATCHNORM_MODULE_OP = [torch.nn.BatchNorm1d, torch.nn.BatchNorm2d, torch.nn.BatchNorm3d, torch.nn.SyncBatchNorm]
 LAYERNORM_MODULE_OP = [torch.nn.LayerNorm]
 POOL_MODULE_OP = [torch.nn.MaxPool1d, torch.nn.MaxPool2d, torch.nn.MaxPool3d, torch.nn.AdaptiveAvgPool2d]
-NON_PARAM_FUNC_OP = RESHAPE_FUNC_OP + ELEMENTWISE_FUNC_OP
+NON_PARAM_FUNC_OP = [
+    torch.flatten,
+    torch.reshape,
+    torch.abs,
+    torch.cos,
+    torch.exp,
+    operator.neg,
+    torch.multiply,
+    torch.nn.functional.relu,
+    torch.nn.functional.dropout,
+    torch.flatten,
+    torch.where,
+    operator.pow,
+    torch.pow,
+    torch.tanh,
+    torch.add,
+    torch.sub,
+    torch.mul,
+    torch.div,
+    torch.floor_divide,
+    torch.true_divide,
+    operator.add,
+    operator.sub,
+    operator.mul,
+    operator.floordiv,
+    operator.truediv,
+    # softmax should not be here
+    torch.nn.functional.softmax
+]
 
 INFINITY_COST = 1e13
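The rewritten `NON_PARAM_FUNC_OP` repeats every entry of the other lists by value. A possible refactor (a sketch only, assuming the order of entries does not matter and that `torch.matmul` is still meant to be excluded, presumably because it is listed in `LINEAR_FUNC_OP` and routed to the dot handler) would derive it from the lists defined above:

```python
# Hypothetical alternative to the hard-coded list; the membership is identical.
NON_PARAM_FUNC_OP = (RESHAPE_FUNC_OP
                     + ELEMENTWISE_FUNC_OP
                     + [op for op in BCAST_FUNC_OP if op is not torch.matmul])
```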
diff --git a/colossalai/auto_parallel/solver/op_handler/dot_handler.py b/colossalai/auto_parallel/solver/op_handler/dot_handler.py
index 29beb116c..b6be639f4 100644
--- a/colossalai/auto_parallel/solver/op_handler/dot_handler.py
+++ b/colossalai/auto_parallel/solver/op_handler/dot_handler.py
@@ -431,7 +431,7 @@ class DotHandler(OperatorHandler):
         sharding_spec_for_ouput = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_input)
 
         # generate resharding cost for this strategy
-        resharding_costs = self._generate_resharding_costs([sharding_spec_for_input])
+        resharding_costs = self._generate_resharding_costs([sharding_spec_for_input, sharding_spec_for_weight])
 
         # compute computation cost
         total_sharding_size = self.device_mesh.shape[mesh_dim_0] * self.device_mesh.shape[mesh_dim_1]
@@ -473,7 +473,7 @@ class DotHandler(OperatorHandler):
         sharding_spec_for_ouput = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)
 
         # generate resharding cost for this strategy
-        resharding_costs = self._generate_resharding_costs([sharding_spec_for_input])
+        resharding_costs = self._generate_resharding_costs([sharding_spec_for_input, sharding_spec_for_weight])
 
         # compute the computation cost of this strategy
         total_sharding_size = self.device_mesh.shape[mesh_dim_0] * self.device_mesh.shape[mesh_dim_1]
@@ -510,7 +510,7 @@ class DotHandler(OperatorHandler):
         sharding_spec_for_ouput = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_input)
 
         # generate resharding cost for this strategy
-        resharding_costs = self._generate_resharding_costs([sharding_spec_for_input])
+        resharding_costs = self._generate_resharding_costs([sharding_spec_for_input, sharding_spec_for_weight])
 
         # compute the computation cost of this strategy
         total_sharding_size = self.device_mesh.shape[mesh_dim_0] * self.device_mesh.shape[mesh_dim_1]
@@ -548,7 +548,7 @@ class DotHandler(OperatorHandler):
         sharding_spec_for_ouput = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)
 
         # generate resharding cost for this strategy
-        resharding_costs = self._generate_resharding_costs([sharding_spec_for_input])
+        resharding_costs = self._generate_resharding_costs([sharding_spec_for_input, sharding_spec_for_weight])
 
         # compute the computation cost of this strategy
         total_sharding_size = self.device_mesh.shape[mesh_dim]
@@ -583,7 +583,7 @@ class DotHandler(OperatorHandler):
         sharding_spec_for_ouput = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)
 
         # generate resharding cost for this strategy
-        resharding_costs = self._generate_resharding_costs([sharding_spec_for_input])
+        resharding_costs = self._generate_resharding_costs([sharding_spec_for_input, sharding_spec_for_weight])
 
         # compute the computation cost of this strategy
         total_sharding_size = self.device_mesh.shape[mesh_dim]
@@ -619,7 +619,7 @@ class DotHandler(OperatorHandler):
         sharding_spec_for_ouput = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)
 
         # generate resharding cost for this strategy
-        resharding_costs = self._generate_resharding_costs([sharding_spec_for_input])
+        resharding_costs = self._generate_resharding_costs([sharding_spec_for_input, sharding_spec_for_weight])
 
         # compute the computation cost of this strategy
         total_sharding_size = self.device_mesh.shape[mesh_dim_0] * self.device_mesh.shape[mesh_dim_1]
@@ -655,7 +655,7 @@ class DotHandler(OperatorHandler):
         sharding_spec_for_ouput = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)
 
         # generate resharding cost for this strategy
-        resharding_costs = self._generate_resharding_costs([sharding_spec_for_input])
+        resharding_costs = self._generate_resharding_costs([sharding_spec_for_input, sharding_spec_for_weight])
 
         # compute the computation cost of this strategy
         total_sharding_size = self.device_mesh.shape[mesh_dim_0] * self.device_mesh.shape[mesh_dim_1]
@@ -692,7 +692,7 @@ class DotHandler(OperatorHandler):
         sharding_spec_for_ouput = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)
 
         # generate resharding cost for this strategy
-        resharding_costs = self._generate_resharding_costs([sharding_spec_for_input])
+        resharding_costs = self._generate_resharding_costs([sharding_spec_for_input, sharding_spec_for_weight])
 
         # compute the computation cost of this strategy
         total_sharding_size = self.device_mesh.shape[mesh_dim_0] * self.device_mesh.shape[mesh_dim_1]
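The reason every `DotHandler` strategy now also passes `sharding_spec_for_weight` is hinted at by the comment in `operator_handler.py` below: when a weight is shared, the weight node is itself a predecessor whose resharding cost has to be counted. A small, self-contained example of the kind of graph where that matters (illustrative only, not taken from the repository):

```python
import torch
import torch.nn as nn

class TiedMatmul(nn.Module):
    """Two matmuls consume the same Parameter: in the traced graph the weight
    node has two consumers, each of whose strategies may prefer a different
    sharding spec for it, so resharding the weight is not free."""

    def __init__(self, dim: int):
        super().__init__()
        self.weight = nn.Parameter(torch.randn(dim, dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        y = torch.matmul(x, self.weight)        # first consumer of the shared weight
        return torch.matmul(y, self.weight)     # second consumer of the same weight
```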
diff --git a/colossalai/auto_parallel/solver/op_handler/operator_handler.py b/colossalai/auto_parallel/solver/op_handler/operator_handler.py
index 1e3234a56..c0d5e9143 100644
--- a/colossalai/auto_parallel/solver/op_handler/operator_handler.py
+++ b/colossalai/auto_parallel/solver/op_handler/operator_handler.py
@@ -121,7 +121,13 @@ class OperatorHandler(ABC):
 
     def _generate_resharding_costs(self, sharding_specs):
         # The resharding_cost of weight is counted due to sharing weight cases.
-        dtype = self.node._meta_data.dtype
+        if hasattr(self.node._meta_data, 'dtype'):
+            dtype = self.node._meta_data.dtype
+        else:
+            assert isinstance(self.node._meta_data,
+                              tuple), f'Only torch.Tensor, torch.fx.Node and tuple of torch.Tensor are expected'
+            dtype = self.node._meta_data[0].dtype
+
         nodes = self.predecessor_node
         return generate_resharding_costs(nodes=nodes,
                                          sharding_specs=sharding_specs,
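The new dtype lookup guards against nodes whose `_meta_data` is a tuple of meta tensors rather than a single tensor. A quick, self-contained illustration of how such a tuple can arise and how the dtype is then read (`split` is just one plausible source of tuple outputs, not necessarily one the solver handles today):

```python
import torch

# A meta tensor split into chunks yields a tuple of meta tensors.
meta = torch.empty(4, 8, device='meta')
outputs = meta.split(2, dim=0)

# Mirror of the handler's fallback: read the dtype from the first element.
dtype = outputs.dtype if hasattr(outputs, 'dtype') else outputs[0].dtype
assert dtype == torch.float32
```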
diff --git a/colossalai/auto_parallel/solver/op_handler/reshape_handler.py b/colossalai/auto_parallel/solver/op_handler/reshape_handler.py
index 19b99ad77..b57b1e83d 100644
--- a/colossalai/auto_parallel/solver/op_handler/reshape_handler.py
+++ b/colossalai/auto_parallel/solver/op_handler/reshape_handler.py
@@ -1,9 +1,14 @@
+import colorsys
 from .operator_handler import OperatorHandler
 from colossalai.tensor.sharding_spec import ShardingSpec
 from colossalai.auto_parallel.solver.sharding_strategy import ShardingStrategy, StrategiesVector
 from colossalai.tensor.shape_consistency import ShapeConsistencyManager
 from copy import deepcopy
 import math
+from colossalai.auto_parallel.solver._utils import exception_handler
+import warnings
+import torch
+from ..constants import INFINITY_COST
 
 
 class ReshapeHandler(OperatorHandler):
@@ -19,6 +24,7 @@ class ReshapeHandler(OperatorHandler):
     def _generate_compute_cost(self, *args, **kwargs):
         return super()._generate_compute_cost(*args, **kwargs)
 
+    @exception_handler
     def register_strategy(self):
         # TODO: add strategies with more output sharding specs other than only fully replicated.
         input_node = self.strategies_vector.predecessor_nodes[0]
@@ -37,11 +43,23 @@ class ReshapeHandler(OperatorHandler):
                 continue
             sharding_spec_checklist.append(input_sharding_spec)
             dim_partition_dict_for_output = {}
-            output_sharding_spec = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)
+            if isinstance(self.output_data, tuple):
+                dim_partition_dict_for_output = [{} for _ in range(len(self.output_data))]
+            try:
+                if isinstance(self.output_data, tuple):
+                    output_sharding_spec = []
+                    for output, dim_partition_dict in zip(self.output_data, dim_partition_dict_for_output):
+                        output_sharding_spec.append(self._generate_sharding_spec(output, dim_partition_dict))
+                else:
+                    output_sharding_spec = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)
+            except AssertionError as e:
+                warnings.warn(f'{e}')
+                continue
             name = f'{input_sharding_spec.sharding_sequence} -> FULLY REPLICATED'
             # TODO: use meta_info_prop to profile memory cost and compute cost
             compute_cost = 0
-            memory_cost = self.node._meta_data.numel()
+            # node._meta_data may be a tuple, so memory cost profiling is skipped for now
+            memory_cost = 0
 
             # compute the communication cost, in reshape op, the communication happens during casting the input sharding spec to fully replicating.
             dim_partition_dict_for_replicate_input = {}
@@ -56,7 +74,7 @@ class ReshapeHandler(OperatorHandler):
             resharding_costs = self._generate_resharding_costs([input_sharding_spec])
 
             # to prevent the resharding happening, set their resharding cost to inf.
-            resharding_costs[input_node] = [0 if cost == 0 else math.inf for cost in resharding_costs[input_node]]
+            resharding_costs[input_node] = [0 if cost == 0 else INFINITY_COST for cost in resharding_costs[input_node]]
             sharding_strategy = ShardingStrategy(name,
                                                  output_sharding_spec,
                                                  compute_cost=compute_cost,
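Presumably the switch from `math.inf` to `INFINITY_COST` (1e13) keeps the "forbidden" resharding edges finite when the costs are fed to the solver: a literal `inf` coefficient can break the optimization, while a large big-M value merely makes the strategy unattractive. A tiny sketch of the clamping behaviour:

```python
import math

INFINITY_COST = 1e13    # finite "big-M" stand-in for a forbidden resharding

costs = [0.0, 3.5, math.inf]
clamped = [0 if cost == 0 else INFINITY_COST for cost in costs]
assert clamped == [0, INFINITY_COST, INFINITY_COST]
```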
diff --git a/tests/test_auto_parallel/test_solver_with_gpt.py b/tests/test_auto_parallel/test_solver_with_gpt.py
new file mode 100644
index 000000000..9001d2ce3
--- /dev/null
+++ b/tests/test_auto_parallel/test_solver_with_gpt.py
@@ -0,0 +1,80 @@
+import torch
+from torch.fx import GraphModule
+import torch.nn as nn
+import pytest
+
+from colossalai.fx.tracer.tracer import ColoTracer
+from colossalai.auto_parallel.solver.sharding_strategy import ShardingStrategy, StrategiesVector
+from colossalai.tensor.shape_consistency import ShapeConsistencyManager
+from colossalai.device.device_mesh import DeviceMesh
+from colossalai.auto_parallel.solver.strategies_constructor import StrategiesConstructor
+from colossalai.auto_parallel.solver.cost_graph import CostGraph
+from copy import deepcopy
+from colossalai.auto_parallel.solver import Solver
+import transformers
+from colossalai.auto_parallel.solver.constants import *
+from colossalai.auto_parallel.solver.graph_analysis import GraphAnalyser
+from colossalai.auto_parallel.solver.options import SolverOptions
+
+BATCH_SIZE = 8
+SEQ_LENGTH = 8
+
+
+@pytest.mark.skip("for higher testing speed")
+def test_cost_graph():
+    physical_mesh_id = torch.arange(0, 8)
+    mesh_shape = (2, 4)
+    # [[0, 1, 2, 3],
+    #  [4, 5, 6, 7]]
+    device_mesh = DeviceMesh(physical_mesh_id, mesh_shape)
+    shape_consistency_manager = ShapeConsistencyManager()
+
+    tracer = ColoTracer()
+    config = transformers.GPT2Config(n_positions=1024, n_layer=1, n_head=12)
+    model = transformers.GPT2LMHeadModel(config=config)
+    input_ids = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64)
+    token_type_ids = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64)
+    attention_mask = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64)
+    kwargs = dict(input_ids=input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask)
+    meta_args = {k: v.to('meta') for k, v in kwargs.items()}
+
+    graph = tracer.trace(root=model, meta_args=meta_args)
+    gm = GraphModule(model, graph, model.__class__.__name__)
+    gm.recompile()
+    graph_analyser = GraphAnalyser(gm)
+    liveness_list = graph_analyser.liveness_analysis()
+    solver_options = SolverOptions(fast=True)
+    strategies_constructor = StrategiesConstructor(graph, device_mesh, solver_options)
+    print(graph)
+    strategies_constructor.build_strategies_and_cost()
+    for check_node, strategies_vector in strategies_constructor.strategy_map.items():
+        print(check_node, len(strategies_vector))
+    cost_graph = CostGraph(strategies_constructor.leaf_strategies)
+    cost_graph.simplify_graph()
+    # solver = Solver(gm.graph, strategies_constructor, cost_graph, graph_analyser, memory_budget=1620017824.0)
+    solver = Solver(gm.graph, strategies_constructor, cost_graph, graph_analyser)
+
+    ret = solver.call_solver_serialized_args()
+    print(ret)
+    strategies_list = list(ret[0])
+    print(strategies_list)
+    computation_cost = 0
+    communication_cost = 0
+    memory_cost = 0
+    nodes = [strategies_vector.node for strategies_vector in strategies_constructor.leaf_strategies]
+    for index, node in enumerate(nodes):
+        print(node.name, node.strategies_vector[strategies_list[index]].name)
+        computation_cost += node.strategies_vector[strategies_list[index]].compute_cost
+        communication_cost += node.strategies_vector[strategies_list[index]].communication_cost
+        node_memory_cost = node.strategies_vector[strategies_list[index]].memory_cost
+        if isinstance(node_memory_cost, tuple):
+            node_memory_cost = node_memory_cost[0]
+        memory_cost += node_memory_cost
+
+    print(f'computation cost is {computation_cost}')
+    print(f'communication cost is {communication_cost}')
+    print(f'memory cost is {memory_cost}')
+
+
+if __name__ == '__main__':
+    test_cost_graph()
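Note that the new test is skipped by default via `@pytest.mark.skip("for higher testing speed")`. Because of the `__main__` guard it can still be exercised directly with `python tests/test_auto_parallel/test_solver_with_gpt.py`, which requires the `transformers` package to be installed.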