From 58ad76d4665032bbe548d066116d1c572ce98979 Mon Sep 17 00:00:00 2001
From: hxwang
Date: Wed, 29 May 2024 02:22:04 +0000
Subject: [PATCH] [refactor] remove legacy async reduce scatter code

---
 colossalai/zero/gemini/chunk/chunk.py       | 29 +++++----------------
 colossalai/zero/gemini/chunk/manager.py     |  4 +--
 tests/test_zero/test_gemini/test_chunkv2.py |  8 ++----
 3 files changed, 10 insertions(+), 31 deletions(-)

diff --git a/colossalai/zero/gemini/chunk/chunk.py b/colossalai/zero/gemini/chunk/chunk.py
index ed4566fe0..6e8533555 100644
--- a/colossalai/zero/gemini/chunk/chunk.py
+++ b/colossalai/zero/gemini/chunk/chunk.py
@@ -164,8 +164,6 @@ class Chunk:
         self.l2_norm = None
 
         self.grad_chunk = None
-        # the async all-reduce/reduce-scatter work of this grad chunk (None means sync)
-        self.grad_reduce_work = None
 
     @property
     def memory_usage(self) -> Dict[str, int]:
@@ -376,49 +374,34 @@ class Chunk:
         if self.is_gathered:
             self.__scatter()
 
-    def reduce(self, async_op: bool = False):
+    def reduce(self):
         """Reduce scatter all the gradients. It's an operation done in CUDA."""
         # sanity check
         assert self.is_gathered
-        assert self.grad_reduce_work is None
         if self.pg_size == 1:
             # tricky code here
             # just move cuda_global_chunk to cuda_shard
             # the communication is not necessary
             self.__scatter()
             if self.extra_dp_group is not None:
-                self.grad_reduce_work = dist.all_reduce(self.cuda_shard, group=self.extra_dp_group, async_op=async_op)
+                dist.all_reduce(self.cuda_shard, group=self.extra_dp_group)
         elif self.keep_gathered:
             # we use all-reduce here
-            self.grad_reduce_work = dist.all_reduce(self.cuda_global_chunk, group=self.torch_pg, async_op=async_op)
+            dist.all_reduce(self.cuda_global_chunk, group=self.torch_pg)
             if self.extra_dp_group is not None:  # cannot guranatee the order of multiple all-reduce
-                self.wait_async_reduce()
-                self.grad_reduce_work = dist.all_reduce(
-                    self.cuda_global_chunk, group=self.extra_dp_group, async_op=async_op
-                )
+                dist.all_reduce(self.cuda_global_chunk, group=self.extra_dp_group)
         else:
             self.cuda_shard = torch.empty(
                 self.shard_size, dtype=self.dtype, device=get_accelerator().get_current_device()
             )
-
             input_list = list(torch.chunk(self.cuda_global_chunk, chunks=self.pg_size, dim=0))
-            self.grad_reduce_work = dist.reduce_scatter(
-                self.cuda_shard, input_list, group=self.torch_pg, async_op=async_op
-            )
-
+            dist.reduce_scatter(self.cuda_shard, input_list, group=self.torch_pg)
             if self.extra_dp_group is not None:
-                self.wait_async_reduce()
-                self.grad_reduce_work = dist.all_reduce(self.cuda_shard, group=self.extra_dp_group, async_op=async_op)
-
+                dist.all_reduce(self.cuda_shard, group=self.extra_dp_group)
         free_storage(self.cuda_global_chunk)
         self.is_gathered = False
         self.__update_tensors_state(TensorState.HOLD)
 
-    def wait_async_reduce(self) -> None:
-        if self.grad_reduce_work is not None:
-            self.grad_reduce_work.wait()
-            self.grad_reduce_work = None
-
     def tensor_trans_state(self, tensor: torch.Tensor, tensor_state: TensorState) -> None:
         """
         Make a transition of the tensor into the next state.
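
With the async path removed, Chunk.reduce() issues its collectives with the default async_op=False: a reduce-scatter over torch_pg (or an all-reduce when the chunk stays gathered), optionally followed by an all-reduce over extra_dp_group. The sketch below is not ColossalAI code; it only illustrates that synchronous pattern with placeholder names (sync_reduce, flat_grads, shard) and assumes a 1-D buffer whose length divides evenly across the process group.

import torch
import torch.distributed as dist


def sync_reduce(flat_grads: torch.Tensor, pg: dist.ProcessGroup, extra_dp_group=None) -> torch.Tensor:
    """Reduce-scatter a 1-D gradient buffer into this rank's shard (async_op left at False)."""
    world_size = dist.get_world_size(pg)
    shard = torch.empty(flat_grads.numel() // world_size, dtype=flat_grads.dtype, device=flat_grads.device)

    # One equal-sized slice per rank; reduce_scatter sums slice i onto rank i.
    input_list = list(torch.chunk(flat_grads, chunks=world_size, dim=0))
    dist.reduce_scatter(shard, input_list, group=pg)  # returns None, no work handle to keep

    # Optional second reduction across an extra data-parallel group,
    # mirroring the extra_dp_group branch in the patch above.
    if extra_dp_group is not None:
        dist.all_reduce(shard, group=extra_dp_group)
    return shard
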
diff --git a/colossalai/zero/gemini/chunk/manager.py b/colossalai/zero/gemini/chunk/manager.py
index 6ec595914..5ad83d20f 100644
--- a/colossalai/zero/gemini/chunk/manager.py
+++ b/colossalai/zero/gemini/chunk/manager.py
@@ -143,12 +143,12 @@ class ChunkManager:
         chunk = self.tensor_chunk_map[tensor]
         chunk.tensor_trans_state(tensor, state)
 
-    def reduce_chunk(self, chunk: Chunk, async_op: bool = False) -> bool:
+    def reduce_chunk(self, chunk: Chunk) -> bool:
         """Reduce or all reduce the chunk."""
         if not chunk.can_reduce:
             return False
         self.__sub_memory_usage(chunk.memory_usage)
-        chunk.reduce(async_op=async_op)
+        chunk.reduce()
         self.__sub_accessed_chunk(chunk)
         self.__add_memory_usage(chunk.memory_usage)
         return True
diff --git a/tests/test_zero/test_gemini/test_chunkv2.py b/tests/test_zero/test_gemini/test_chunkv2.py
index 51b20c400..257311328 100644
--- a/tests/test_zero/test_gemini/test_chunkv2.py
+++ b/tests/test_zero/test_gemini/test_chunkv2.py
@@ -34,8 +34,7 @@ def check_equal(param, param_cp):
 @parameterize("init_device", [None, torch.device("cpu")])
 @parameterize("keep_gathered", [True, False])
 @parameterize("pin_memory", [True, False])
-@parameterize("async_op", [True, False])
-def exam_chunk_basic(init_device, keep_gathered, pin_memory, async_op):
+def exam_chunk_basic(init_device, keep_gathered, pin_memory):
     world_size = torch.distributed.get_world_size()
     pg = _get_default_group()
     my_chunk = Chunk(
@@ -95,12 +94,9 @@ def exam_chunk_basic(init_device, keep_gathered, pin_memory, async_op):
     assert my_chunk.tensor_state_cnter[TensorState.READY_FOR_REDUCE] == 4
     assert my_chunk.can_reduce
-    my_chunk.reduce(async_op)
+    my_chunk.reduce()
     assert my_chunk.tensor_state_cnter[TensorState.HOLD] == 4
 
-    if async_op:
-        my_chunk.wait_async_reduce()
-
     if keep_gathered is False:
         assert my_chunk.cuda_shard.size(0) == 1024 // world_size
         assert my_chunk.device_type == "cuda"
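
On the caller side, ChunkManager.reduce_chunk(chunk) no longer takes async_op, and the test drops its wait_async_reduce() branch. The standalone sketch below is not the ColossalAI test; it assumes two CUDA devices, the NCCL backend, and an arbitrary rendezvous port. It spawns two workers and checks a synchronous reduce-scatter the same way the simplified exam_chunk_basic checks its shard size.

import os

import torch
import torch.distributed as dist
import torch.multiprocessing as mp


def _worker(rank: int, world_size: int) -> None:
    os.environ["MASTER_ADDR"] = "127.0.0.1"
    os.environ["MASTER_PORT"] = "29511"  # arbitrary free port
    torch.cuda.set_device(rank)
    dist.init_process_group("nccl", rank=rank, world_size=world_size)

    numel = 1024  # same flat size as the chunk used in the test
    flat_grads = torch.full((numel,), float(rank + 1), device="cuda")
    shard = torch.empty(numel // world_size, device="cuda")

    # Synchronous reduce-scatter: no work handle is returned, so there is
    # nothing to wait on before checking the result.
    dist.reduce_scatter(shard, list(torch.chunk(flat_grads, world_size)))

    assert shard.size(0) == numel // world_size
    # Each shard element holds the sum of the per-rank values 1..world_size.
    expected = float(sum(range(1, world_size + 1)))
    assert torch.allclose(shard, torch.full_like(shard, expected))

    dist.destroy_process_group()


if __name__ == "__main__":
    mp.spawn(_worker, args=(2,), nprocs=2)
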