From aa0f6686f90f3f5aad3b6c30efd0b5f97be42443 Mon Sep 17 00:00:00 2001
From: YuliangLiu0306 <72588413+YuliangLiu0306@users.noreply.github.com>
Date: Sun, 29 Jan 2023 11:13:15 +0800
Subject: [PATCH] [autoparallel] accelerate gpt2 training (#2495)

---
 .../passes/runtime_preparation_pass.py            | 14 ++++++++------
 .../tensor_shard/node_handler/matmul_handler.py   |  2 ++
 .../strategy/matmul_strategy_generator.py         | 14 +++++++-------
 colossalai/device/device_mesh.py                  |  2 +-
 colossalai/tensor/comm_spec.py                    |  6 +++---
 5 files changed, 21 insertions(+), 17 deletions(-)

diff --git a/colossalai/auto_parallel/passes/runtime_preparation_pass.py b/colossalai/auto_parallel/passes/runtime_preparation_pass.py
index 1c25e4c94..988970957 100644
--- a/colossalai/auto_parallel/passes/runtime_preparation_pass.py
+++ b/colossalai/auto_parallel/passes/runtime_preparation_pass.py
@@ -387,14 +387,15 @@ def _module_params_sharding(gm: torch.fx.GraphModule, device_mesh: DeviceMesh):
                     # register hook to the parameters
                     if operation_data.type == OperationDataType.PARAM and operation_data.name == name and comm_action.comm_type == CommType.HOOK:
 
-                        def wrapper(param, comm_spec):
+                        def wrapper(param, comm_spec, stream):
 
                             def hook_fn(grad):
-                                _all_reduce(grad, comm_spec, async_op=False)
+                                with torch.cuda.stream(stream):
+                                    _all_reduce(grad, comm_spec, async_op=True)
 
                             param.register_hook(hook_fn)
 
-                        wrapper(param, comm_spec_to_use)
+                        wrapper(param, comm_spec_to_use, reduction_stream)
 
     sharded_buffer_dict = {}
     # apply the sharding spec of buffers
@@ -440,14 +441,15 @@ def _module_params_sharding(gm: torch.fx.GraphModule, device_mesh: DeviceMesh):
                     # register hook to the parameters
                     if isinstance(node._meta_data, torch.nn.parameter.Parameter) and comm_action.comm_type == CommType.HOOK:
 
-                        def wrapper(param, comm_spec):
+                        def wrapper(param, comm_spec, stream):
 
                             def hook_fn(grad):
-                                _all_reduce(grad, comm_spec, async_op=False)
+                                with torch.cuda.stream(stream):
+                                    _all_reduce(grad, comm_spec, async_op=True)
 
                             param.register_hook(hook_fn)
 
-                        wrapper(target, comm_spec_to_use)
+                        wrapper(target, comm_spec_to_use, reduction_stream)
 
     return gm
 
diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/matmul_handler.py b/colossalai/auto_parallel/tensor_shard/node_handler/matmul_handler.py
index d3f9fd01d..131c35156 100644
--- a/colossalai/auto_parallel/tensor_shard/node_handler/matmul_handler.py
+++ b/colossalai/auto_parallel/tensor_shard/node_handler/matmul_handler.py
@@ -483,4 +483,6 @@ class MatMulHandler(NodeHandler):
                 raise TypeError(
                     f"Found unexpected output type {type(output)} from the recover method of BmmTransform")
         strategies = recovered_stragies
+        for index, strategies in enumerate(strategies):
+            strategies.name = f"{strategies.name}_{index}"
         return strategies
diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/strategy/matmul_strategy_generator.py b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/matmul_strategy_generator.py
index fa2246f95..9aa95b43a 100644
--- a/colossalai/auto_parallel/tensor_shard/node_handler/strategy/matmul_strategy_generator.py
+++ b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/matmul_strategy_generator.py
@@ -247,12 +247,12 @@ class LinearProjectionStrategyGenerator(MatMulStrategyGenerator):
         strategies.append(self.split_rhs_space_both_contract(1, 0))
 
         # RR= RS x SR
-        strategies.append(self.recompute_split_both_contract(0))
-        strategies.append(self.recompute_split_both_contract(1))
+        # strategies.append(self.recompute_split_both_contract(0))
+        # strategies.append(self.recompute_split_both_contract(1))
 
-        # RS = RR x RS
-        strategies.append(self.split_rhs_space_only(0))
-        strategies.append(self.split_rhs_space_only(1))
+        # # RS = RR x RS
+        # strategies.append(self.split_rhs_space_only(0))
+        # strategies.append(self.split_rhs_space_only(1))
 
         # S01R = S01R x RR
         strategies.append(self.split_lhs_1st_dim_1d(0, 1))
@@ -263,8 +263,8 @@ class LinearProjectionStrategyGenerator(MatMulStrategyGenerator):
         # RS01 = RR x RS01
         strategies.append(self.split_rhs_2nd_dim_1d(0, 1))
 
-        # RR = RR x RR
-        strategies.append(self.non_split())
+        # # RR = RR x RR
+        # strategies.append(self.non_split())
 
         return strategies
 
diff --git a/colossalai/device/device_mesh.py b/colossalai/device/device_mesh.py
index b5a97eded..22a01dddb 100644
--- a/colossalai/device/device_mesh.py
+++ b/colossalai/device/device_mesh.py
@@ -98,7 +98,7 @@ class DeviceMesh:
         return DeviceMesh(self.physical_mesh_id,
                           tuple(flatten_mesh_shape),
                           mesh_alpha=[max(self.mesh_alpha)] * (flatten_mesh_shape_size - 1),
-                          mesh_beta=[min(self.mesh_beta)] * (flatten_mesh_shape_size - 1),
+                          mesh_beta=[max(self.mesh_beta)] * (flatten_mesh_shape_size - 1),
                           init_process_group=self.init_process_group,
                           need_flatten=False)
 
diff --git a/colossalai/tensor/comm_spec.py b/colossalai/tensor/comm_spec.py
index 3c9e0fd56..b31c06994 100644
--- a/colossalai/tensor/comm_spec.py
+++ b/colossalai/tensor/comm_spec.py
@@ -463,7 +463,7 @@ class CommSpec:
         if self.comm_pattern == CollectiveCommPattern.GATHER_FWD_SPLIT_BWD:
             forward_communication_cost = self.device_mesh.all_gather_cost(comm_size, self.logical_process_axis)
             # give a tiny cost to shard
-            backward_communication_cost = 10
+            backward_communication_cost = 100
 
         if self.comm_pattern == CollectiveCommPattern.ALL2ALL_FWD_ALL2ALL_BWD:
             forward_communication_cost = self.device_mesh.all_to_all_cost(comm_size, self.logical_process_axis)
@@ -481,13 +481,13 @@ class CommSpec:
         if self.comm_pattern == CollectiveCommPattern.SPLIT_FWD_GATHER_BWD:
             # give a tiny cost to shard
-            forward_communication_cost = 10
+            forward_communication_cost = 100
             backward_communication_cost = self.device_mesh.all_gather_cost(comm_size, self.logical_process_axis)
 
         if self.comm_pattern == CollectiveCommPattern.MIXGATHER_FWD_SPLIT_BWD:
             # no need for axis because all devices are used in mix_gather
             forward_communication_cost = self.device_mesh.mix_gather_cost(comm_size)
-            backward_communication_cost = 10
+            backward_communication_cost = 100
 
         if self.forward_only:
             cost_dict["forward"] = forward_communication_cost
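
Note: the main speed-up in this patch comes from moving the gradient all-reduce hook off the
default stream and making it asynchronous, so communication overlaps with the rest of the
backward pass. Below is a minimal, self-contained sketch of that pattern in plain PyTorch,
not the ColossalAI implementation: the helper name register_async_grad_allreduce, the
wait_stream/record_stream safety calls, and the use of the default process group are
illustrative assumptions; only the dedicated reduction_stream idea mirrors the patch.

import torch
import torch.distributed as dist


def register_async_grad_allreduce(model: torch.nn.Module, reduction_stream: torch.cuda.Stream):
    """Register backward hooks that all-reduce each gradient on a dedicated CUDA stream."""

    def make_hook(stream):

        def hook_fn(grad):
            # Ensure the gradient produced on the current (backward) stream is visible
            # to the reduction stream before the collective starts.
            stream.wait_stream(torch.cuda.current_stream())
            with torch.cuda.stream(stream):
                # Keep the grad buffer alive until the side-stream collective finishes.
                grad.record_stream(stream)
                dist.all_reduce(grad, async_op=True)
            return grad

        return hook_fn

    for param in model.parameters():
        if param.requires_grad:
            param.register_hook(make_hook(reduction_stream))


# Usage (assumes torch.distributed is already initialized with a CUDA backend):
#   reduction_stream = torch.cuda.Stream()
#   register_async_grad_allreduce(model, reduction_stream)
#   loss.backward()
#   torch.cuda.current_stream().wait_stream(reduction_stream)  # join before optimizer.step()
#   optimizer.step()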