diff --git a/colossalai/auto_parallel/passes/runtime_preparation_pass.py b/colossalai/auto_parallel/passes/runtime_preparation_pass.py
index 29b6a6db6..c762bdca7 100644
--- a/colossalai/auto_parallel/passes/runtime_preparation_pass.py
+++ b/colossalai/auto_parallel/passes/runtime_preparation_pass.py
@@ -11,6 +11,7 @@ from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
     OperationDataType,
     ShardingStrategy,
 )
+from colossalai.auto_parallel.tensor_shard.solver.strategies_constructor import StrategiesConstructor
 from colossalai.device.device_mesh import DeviceMesh
 from colossalai.tensor.comm_spec import _all_reduce
 from colossalai.tensor.shape_consistency import ShapeConsistencyManager
@@ -19,13 +20,23 @@ from colossalai.tensor.sharding_spec import ShardingSpec
 shape_consistency_manager = ShapeConsistencyManager()
 
 
-def _solution_annotatation(gm: torch.fx.GraphModule, solution: List[int]):
+def _solution_annotatation(gm: torch.fx.GraphModule,
+                           solution: List[int],
+                           strategies_constructor: StrategiesConstructor = None):
     """
     This method is used to stick the solution strategy to the nodes and add the information
     required in runtime into graph as placeholder nodes.
     """
     mod_graph = gm.graph
-    nodes = tuple(mod_graph.nodes)
+    # TODO: In a future PR, strategies_constructor should be a required argument
+    # instead of an optional one, because we don't need to consider nodes with
+    # no strategy in the runtime preparation pass.
+    if strategies_constructor is not None:
+        nodes = [strategies_vector.node for strategies_vector in strategies_constructor.leaf_strategies]
+        no_strategy_nodes = strategies_constructor.no_strategy_nodes
+    else:
+        nodes = tuple(mod_graph.nodes)
+        no_strategy_nodes = []
 
     # the dict to get origin sharding spec of node
     origin_node_sharding_spec_dict = {}
@@ -44,7 +55,10 @@ def _solution_annotatation(gm: torch.fx.GraphModule, solution: List[int]):
     for index, node in enumerate(nodes):
         target_sharding_specs = []
         for user_node in node.strategies_vector.successor_nodes:
-            target_sharding_spec = user_node.best_strategy.get_sharding_spec_by_name(str(node.name))
+            if user_node in no_strategy_nodes:
+                target_sharding_spec = node.best_strategy.get_sharding_spec_by_name(str(node.name))
+            else:
+                target_sharding_spec = user_node.best_strategy.get_sharding_spec_by_name(str(node.name))
             target_sharding_specs.append(target_sharding_spec)
         sharding_spec_convert_dict[index] = target_sharding_specs
         setattr(node, 'target_sharding_specs', target_sharding_specs)
@@ -136,13 +150,17 @@ def _node_args_converting(gm: torch.fx.GraphModule, device_mesh: DeviceMesh):
                     new_args.append(arg)
 
                 for dim, shard_dims in output_dim_partition_dict.items():
-                    # we will skip the dim with -1 value
-                    if new_args[dim + 1] == -1:
-                        continue
                     total_shard_size = 1
                     for shard_dim in shard_dims:
                         total_shard_size *= device_mesh.shape[shard_dim]
-                    new_args[dim + 1] //= total_shard_size
+                    # There are two ways to use torch.view:
+                    #     1. torch.view(input, *shape)
+                    #     2. torch.view(input, shape)
+                    if isinstance(new_args[1], int):
+                        new_args[dim + 1] //= total_shard_size
+                    else:
+                        new_args[1] = list(new_args[1])
+                        new_args[1][dim] //= total_shard_size
                 node.args = tuple(new_args)
 
             elif node.op == 'call_function':
@@ -193,12 +211,12 @@ def _module_params_sharding(gm: torch.fx.GraphModule, device_mesh: DeviceMesh):
                     origin_sharding_spec = ShardingSpec(device_mesh, param.shape, {})
                     setattr(param, 'sharding_spec', origin_sharding_spec)
                     # TODO: build a ColoParamter class to manager the distributed parameters
-                    param_sharded = torch.nn.Parameter(
-                        shape_consistency_manager.apply_for_autoparallel_runtime(param.data, param.sharding_spec,
-                                                                                 target_sharding_spec).detach().clone())
-                else:
-                    param_sharded = param
-                setattr(target_module, name, param_sharded)
+                    # we can use .data here because all these operations happen before the real training
+                    # loop, so they do not need to be tracked in the autograd graph.
+                    param.data = shape_consistency_manager.apply_for_autoparallel_runtime(
+                        param.data, param.sharding_spec, target_sharding_spec).detach().clone()
+
+                setattr(target_module, name, param)
                 comm_actions = node.best_strategy.communication_actions
                 for operation_data, comm_action in comm_actions.items():
                     comm_spec_to_use = comm_action.comm_spec
@@ -212,7 +230,7 @@ def _module_params_sharding(gm: torch.fx.GraphModule, device_mesh: DeviceMesh):
 
                         param.register_hook(hook_fn)
 
-                    wrapper(param_sharded, comm_spec_to_use)
+                    wrapper(param, comm_spec_to_use)
 
         sharded_buffer_dict = {}
         # apply the sharding spec of buffers
@@ -242,12 +260,13 @@ def _module_params_sharding(gm: torch.fx.GraphModule, device_mesh: DeviceMesh):
                 origin_sharding_spec = ShardingSpec(device_mesh, target.shape, {})
                 setattr(target, 'sharding_spec', origin_sharding_spec)
                 # TODO: build a ColoParamter class to manager the distributed parameters
-                target_sharded = torch.nn.Parameter(
-                    shape_consistency_manager.apply_for_autoparallel_runtime(target.data, target.sharding_spec,
-                                                                             target_sharding_spec).detach().clone())
-            else:
-                target_sharded = target
-            setattr(target_module, atoms[-1], target_sharded)
+                # we can use .data here because all these operations happen before the real training
+                # loop, so they do not need to be tracked in the autograd graph.
+                target.data = shape_consistency_manager.apply_for_autoparallel_runtime(
+                    target.data, target.sharding_spec, target_sharding_spec).detach().clone()
+
+            assert hasattr(target_module, atoms[-1])
+            setattr(target_module, atoms[-1], target)
 
             comm_actions = node.best_strategy.communication_actions
             for operation_data, comm_action in comm_actions.items():
@@ -262,7 +281,7 @@ def _module_params_sharding(gm: torch.fx.GraphModule, device_mesh: DeviceMesh):
 
                         param.register_hook(hook_fn)
 
-                    wrapper(target_sharded, comm_spec_to_use)
+                    wrapper(target, comm_spec_to_use)
 
     return gm
 
@@ -273,9 +292,12 @@ def implicit_comm_action_apply(gm: torch.fx.GraphModule):
     pass
 
 
-def runtime_preparation_pass(gm: torch.fx.GraphModule, solution: List[int], device_mesh: DeviceMesh):
+def runtime_preparation_pass(gm: torch.fx.GraphModule,
+                             solution: List[int],
+                             device_mesh: DeviceMesh,
+                             strategies_constructor: StrategiesConstructor = None):
     gm, sharding_spec_convert_dict, origin_node_sharding_spec_dict, comm_actions_dict = _solution_annotatation(
-        gm, solution)
+        gm, solution, strategies_constructor)
     gm = _node_args_converting(gm, device_mesh)
     # TODO: the pass below should be uncommented after the implementation of implicit_comm_action_apply_pass completed.
     # gm = implicit_comm_action_apply(gm)
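Note on the torch.view branch added to _node_args_converting above: a view call can be traced either with unpacked dimensions (x.view(d0, d1)) or with a single shape sequence (x.view((d0, d1))), so the sharded dimension has to be rescaled in a different argument slot depending on which form was used. A minimal standalone sketch of that rewrite, where rescale_view_args is a hypothetical helper and the first argument stands in for the FX input node:

    def rescale_view_args(args, dim, total_shard_size):
        # args looks like (input_node, d0, d1, ...) or (input_node, (d0, d1, ...))
        new_args = list(args)
        if isinstance(new_args[1], int):
            # form 1: x.view(d0, d1, ...) -> the sharded dim sits at position dim + 1
            new_args[dim + 1] //= total_shard_size
        else:
            # form 2: x.view(shape) -> rewrite the entry inside the shape sequence
            shape = list(new_args[1])
            shape[dim] //= total_shard_size
            new_args[1] = shape
        return tuple(new_args)

    print(rescale_view_args(('x', 4, 8), dim=1, total_shard_size=2))    # ('x', 4, 4)
    print(rescale_view_args(('x', (4, 8)), dim=1, total_shard_size=2))  # ('x', [4, 4])
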
diff --git a/colossalai/auto_parallel/tensor_shard/solver/strategies_constructor.py b/colossalai/auto_parallel/tensor_shard/solver/strategies_constructor.py
index adfd03d7d..9d1ff7fd1 100644
--- a/colossalai/auto_parallel/tensor_shard/solver/strategies_constructor.py
+++ b/colossalai/auto_parallel/tensor_shard/solver/strategies_constructor.py
@@ -41,6 +41,7 @@ class StrategiesConstructor:
         self.leaf_strategies = []
         self.strategy_map = {}
         self.solver_options = solver_options
+        self.no_strategy_nodes = []
 
     def remove_duplicated_strategy(self, strategies_vector):
         '''
@@ -78,12 +79,11 @@ class StrategiesConstructor:
 
             return _check_no_strategy_for_data(node._meta_data)
 
-        no_strategy_node = []
         for node in self.nodes:
             strategies_vector = StrategiesVector(node)
 
             if _check_no_strategy_for_node(node):
-                no_strategy_node.append(node)
+                self.no_strategy_nodes.append(node)
                 pass
 
             # placeholder node
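The no_strategy_nodes list recorded by StrategiesConstructor above is what the new fallback in _solution_annotatation consumes: when a successor node never received a strategy of its own, the producer's best-strategy sharding spec is reused as the target spec instead of querying the successor. A toy illustration of that selection logic, using hypothetical stand-in objects rather than real FX nodes or sharding specs:

    from dataclasses import dataclass, field

    @dataclass
    class FakeStrategy:
        specs: dict = field(default_factory=dict)

        def get_sharding_spec_by_name(self, name):
            return self.specs[name]

    @dataclass
    class FakeNode:
        name: str
        best_strategy: FakeStrategy = None

    producer = FakeNode('linear1', FakeStrategy({'linear1': 'S0R'}))
    consumer_with_strategy = FakeNode('relu', FakeStrategy({'linear1': 'RR'}))
    consumer_without_strategy = FakeNode('size_1')    # never got a strategy
    no_strategy_nodes = [consumer_without_strategy]

    def target_spec(node, user_node):
        # mirrors the branch added in _solution_annotatation
        if user_node in no_strategy_nodes:
            return node.best_strategy.get_sharding_spec_by_name(node.name)
        return user_node.best_strategy.get_sharding_spec_by_name(node.name)

    print(target_spec(producer, consumer_with_strategy))     # 'RR'  (consumer decides)
    print(target_spec(producer, consumer_without_strategy))  # 'S0R' (fall back to producer)
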
diff --git a/tests/test_auto_parallel/test_tensor_shard/test_gptmlp_runtime.py b/tests/test_auto_parallel/test_tensor_shard/test_gptmlp_runtime.py
new file mode 100644
index 000000000..d573c6590
--- /dev/null
+++ b/tests/test_auto_parallel/test_tensor_shard/test_gptmlp_runtime.py
@@ -0,0 +1,214 @@
+import copy
+import random
+from functools import partial
+from typing import Optional, Tuple, Union
+
+import numpy as np
+import pytest
+import torch
+import torch.multiprocessing as mp
+import torch.nn as nn
+import transformers
+from torch.fx import GraphModule
+from transformers.activations import ACT2FN
+from transformers.models.gpt2.modeling_gpt2 import GPT2MLP
+from transformers.pytorch_utils import Conv1D
+
+from colossalai.auto_parallel.passes.runtime_apply_pass import runtime_apply_pass
+from colossalai.auto_parallel.passes.runtime_preparation_pass import runtime_preparation_pass
+from colossalai.auto_parallel.tensor_shard.constants import BATCHNORM_MODULE_OP
+from colossalai.auto_parallel.tensor_shard.solver import (
+    CostGraph,
+    GraphAnalyser,
+    Solver,
+    SolverOptions,
+    StrategiesConstructor,
+)
+from colossalai.device.device_mesh import DeviceMesh
+from colossalai.fx.tracer.tracer import ColoTracer
+from colossalai.initialize import launch
+from colossalai.logging import disable_existing_loggers
+from colossalai.tensor.shape_consistency import ShapeConsistencyManager, to_global
+from colossalai.testing import assert_close, assert_close_loose, parameterize, rerun_if_address_is_in_use
+from colossalai.testing.pytest_wrapper import run_on_environment_flag
+from colossalai.utils import free_port
+
+BATCH_SIZE = 1
+SEQ_LENGTH = 32
+HIDDEN_DIM = 768
+
+seed = 128
+torch.manual_seed(seed)
+torch.cuda.manual_seed_all(seed)
+np.random.seed(seed)
+random.seed(seed)
+torch.backends.cudnn.deterministic = True
+torch.backends.cudnn.benchmark = False
+
+
+class GPT2MLP(nn.Module):
+
+    def __init__(self, intermediate_size, config):
+        super().__init__()
+        embed_dim = config.hidden_size
+        self.c_fc = Conv1D(intermediate_size, embed_dim)
+        self.c_proj = Conv1D(embed_dim, intermediate_size)
+        self.act = ACT2FN[config.activation_function]
+        # We temporarily disable the Dropout layer because the rng state needs special
+        # handling to reproduce the correct result.
+        # self.dropout = nn.Dropout(config.resid_pdrop)
+
+    def forward(self, hidden_states: Optional[Tuple[torch.FloatTensor]]) -> torch.FloatTensor:
+        hidden_states = self.c_fc(hidden_states)
+        hidden_states = self.act(hidden_states)
+        hidden_states = self.c_proj(hidden_states)
+        # TODO: the rng state needs to be fixed for the distributed runtime
+        # hidden_states = self.dropout(hidden_states)
+        return hidden_states
+
+
+def check_mlp_layer(rank, model_cls, world_size, port):
+    disable_existing_loggers()
+    launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
+
+    config = transformers.GPT2Config(n_position=64, n_layer=4, n_head=16, n_embd=HIDDEN_DIM)
+    model = model_cls(intermediate_size=4 * config.hidden_size, config=config).to('cuda')
+    input = torch.rand(BATCH_SIZE, SEQ_LENGTH, HIDDEN_DIM).to('cuda')
+    test_model = copy.deepcopy(model)
+    test_input = copy.deepcopy(input)
+    physical_mesh_id = torch.arange(0, 4)
+    mesh_shape = (2, 2)
+    # [[0, 1]
+    #  [2, 3]]
+    device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True)
+    shape_consistency_manager = ShapeConsistencyManager()
+
+    tracer = ColoTracer()
+
+    input_sample = {
+        'hidden_states': torch.rand(BATCH_SIZE, SEQ_LENGTH, HIDDEN_DIM).to('meta'),
+    }
+
+    graph = tracer.trace(root=model, meta_args=input_sample)
+    print(graph)
+    gm = GraphModule(model, graph, model.__class__.__name__)
+    gm.recompile()
+    print(gm)
+    graph_analyser = GraphAnalyser(gm)
+    liveness_list = graph_analyser.liveness_analysis()
+    solver_options = SolverOptions()
+    strategies_constructor = StrategiesConstructor(graph, device_mesh, solver_options)
+    strategies_constructor.build_strategies_and_cost()
+
+    cost_graph = CostGraph(strategies_constructor.leaf_strategies)
+    cost_graph.simplify_graph()
+    solver = Solver(gm.graph, strategies_constructor, cost_graph, graph_analyser, memory_budget=-1)
+    ret = solver.call_solver_serialized_args()
+
+    solution = list(ret[0])
+    gm, sharding_spec_dict, origin_spec_dict, comm_actions_dict = runtime_preparation_pass(
+        gm, solution, device_mesh, strategies_constructor)
+    gm = runtime_apply_pass(gm)
+    gm.recompile()
+    cuda_rng_state = torch.cuda.get_rng_state()
+    cpu_rng_state = torch.get_rng_state()
+    origin_output = test_model(test_input)
+    torch.cuda.set_rng_state(cuda_rng_state)
+    torch.set_rng_state(cpu_rng_state)
+    output = gm(input, sharding_spec_dict, origin_spec_dict, comm_actions_dict)
+    assert_close(output, origin_output, rtol=1e-03, atol=1e-04)
+
+    #*******************backward starting*******************
+    cuda_rng_state = torch.cuda.get_rng_state()
+    output.sum().backward()
+    torch.cuda.set_rng_state(cuda_rng_state)
+    origin_output.sum().backward()
+    origin_param_dict = dict(test_model.named_parameters())
+
+    if rank == 0:
+        print("*******************backward starting*******************")
+        for name, param in model.named_parameters():
+            param_grad = param.grad
+            origin_param_grad = origin_param_dict[name].grad
+            origin_param_size = origin_param_grad.shape[-1]
+            print(name, param_grad, origin_param_grad)
+            if name == 'c_fc.bias':
+                assert_close_loose(param_grad,
+                                   origin_param_grad.narrow(0, 0, origin_param_size // 2),
+                                   rtol=1e-03,
+                                   atol=1e-03)
+            else:
+                assert_close_loose(param_grad, origin_param_grad, rtol=1e-03, atol=1e-03)
+        print("*******************backward finished*******************")
+
+    if rank == 1:
+        for name, param in model.named_parameters():
+            param_grad = param.grad
+            origin_param_grad = origin_param_dict[name].grad
+            origin_param_size = origin_param_grad.shape[-1]
+            if name == 'c_fc.bias':
+                assert_close_loose(param_grad,
+                                   origin_param_grad.narrow(0, origin_param_size // 2, origin_param_size // 2),
+                                   rtol=1e-03,
+                                   atol=1e-03)
+            else:
+                assert_close_loose(param_grad, origin_param_grad, rtol=1e-03, atol=1e-03)
+
+    if rank == 2:
+        for name, param in model.named_parameters():
+            param_grad = param.grad
+            origin_param_grad = origin_param_dict[name].grad
+            origin_param_size = origin_param_grad.shape[-1]
+            if name == 'c_fc.bias':
+                assert_close_loose(param_grad,
+                                   origin_param_grad.narrow(0, 0, origin_param_size // 2),
+                                   rtol=1e-03,
+                                   atol=1e-03)
+            else:
+                assert_close_loose(param_grad, origin_param_grad, rtol=1e-03, atol=1e-03)
+
+    if rank == 3:
+        for name, param in model.named_parameters():
+            param_grad = param.grad
+            origin_param_grad = origin_param_dict[name].grad
+            origin_param_size = origin_param_grad.shape[-1]
+            if name == 'c_fc.bias':
+                assert_close_loose(param_grad,
+                                   origin_param_grad.narrow(0, origin_param_size // 2, origin_param_size // 2),
+                                   rtol=1e-03,
+                                   atol=1e-03)
+            else:
+                assert_close_loose(param_grad, origin_param_grad, rtol=1e-03, atol=1e-03)
+
+    #*******************backward finished*******************
+
+    #*******************strategy selected*******************
+    if rank == 0:
+        print("*******************strategy selected*******************")
+        strategies_list = solver.last_s_val
+        nodes = [strategies_vector.node for strategies_vector in strategies_constructor.leaf_strategies]
+        computation_cost = 0
+        communication_cost = 0
+        memory_cost = 0
+        for index, node in enumerate(nodes):
+            print(node.name, node.strategies_vector[strategies_list[index]].name)
+            computation_cost += node.strategies_vector[strategies_list[index]].compute_cost.total
+            communication_cost += node.strategies_vector[strategies_list[index]].communication_cost.total
+            node_memory_cost = node.strategies_vector[strategies_list[index]].memory_cost.total
+            if isinstance(node_memory_cost, tuple):
+                node_memory_cost = node_memory_cost[0]
+            memory_cost += node_memory_cost.activation + node_memory_cost.parameter
+
+        print(f'computation cost is {computation_cost}')
+        print(f'communication cost is {communication_cost}')
+        print(f'memory cost is {memory_cost}')
+
+
+@run_on_environment_flag(name='AUTO_PARALLEL')
+@pytest.mark.dist
+@parameterize('model_cls', [GPT2MLP])
+@rerun_if_address_is_in_use()
+def test_mlp_layer(model_cls):
+    world_size = 4
+    run_func = partial(check_mlp_layer, model_cls=model_cls, world_size=world_size, port=free_port())
+    mp.spawn(run_func, nprocs=world_size)
+
+
+if __name__ == '__main__':
+    test_mlp_layer()
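On the per-rank gradient assertions in the test above: with the (2, 2) mesh [[0, 1], [2, 3]], the test expects c_fc.bias to be split in half along its only dimension, so ranks 0 and 2 should hold the gradient of the first half and ranks 1 and 3 the gradient of the second half, while every other parameter is compared against its full-size gradient. A small standalone illustration of the narrow-based slicing used in those assertions (toy values, not real gradients):

    import torch

    full_grad = torch.arange(8.)          # stand-in for origin_param_dict['c_fc.bias'].grad
    half = full_grad.shape[-1] // 2

    expected_shard = {
        0: full_grad.narrow(0, 0, half),      # first half
        1: full_grad.narrow(0, half, half),   # second half
        2: full_grad.narrow(0, 0, half),
        3: full_grad.narrow(0, half, half),
    }
    for rank, shard in expected_shard.items():
        print(rank, shard.tolist())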