[refactor] remove legacy async reduce scatter code

pull/5760/head
hxwang 6 months ago
parent fee35678e5
commit 58ad76d466

@ -164,8 +164,6 @@ class Chunk:
self.l2_norm = None self.l2_norm = None
self.grad_chunk = None self.grad_chunk = None
# the async all-reduce/reduce-scatter work of this grad chunk (None means sync)
self.grad_reduce_work = None
@property @property
def memory_usage(self) -> Dict[str, int]: def memory_usage(self) -> Dict[str, int]:
@ -376,49 +374,34 @@ class Chunk:
if self.is_gathered: if self.is_gathered:
self.__scatter() self.__scatter()
def reduce(self, async_op: bool = False): def reduce(self):
"""Reduce scatter all the gradients. It's an operation done in CUDA.""" """Reduce scatter all the gradients. It's an operation done in CUDA."""
# sanity check # sanity check
assert self.is_gathered assert self.is_gathered
assert self.grad_reduce_work is None
if self.pg_size == 1: if self.pg_size == 1:
# tricky code here # tricky code here
# just move cuda_global_chunk to cuda_shard # just move cuda_global_chunk to cuda_shard
# the communication is not necessary # the communication is not necessary
self.__scatter() self.__scatter()
if self.extra_dp_group is not None: if self.extra_dp_group is not None:
self.grad_reduce_work = dist.all_reduce(self.cuda_shard, group=self.extra_dp_group, async_op=async_op) dist.all_reduce(self.cuda_shard, group=self.extra_dp_group)
elif self.keep_gathered: elif self.keep_gathered:
# we use all-reduce here # we use all-reduce here
self.grad_reduce_work = dist.all_reduce(self.cuda_global_chunk, group=self.torch_pg, async_op=async_op) dist.all_reduce(self.cuda_global_chunk, group=self.torch_pg)
if self.extra_dp_group is not None: # cannot guranatee the order of multiple all-reduce if self.extra_dp_group is not None: # cannot guranatee the order of multiple all-reduce
self.wait_async_reduce() dist.all_reduce(self.cuda_global_chunk, group=self.extra_dp_group)
self.grad_reduce_work = dist.all_reduce(
self.cuda_global_chunk, group=self.extra_dp_group, async_op=async_op
)
else: else:
self.cuda_shard = torch.empty( self.cuda_shard = torch.empty(
self.shard_size, dtype=self.dtype, device=get_accelerator().get_current_device() self.shard_size, dtype=self.dtype, device=get_accelerator().get_current_device()
) )
input_list = list(torch.chunk(self.cuda_global_chunk, chunks=self.pg_size, dim=0)) input_list = list(torch.chunk(self.cuda_global_chunk, chunks=self.pg_size, dim=0))
self.grad_reduce_work = dist.reduce_scatter( dist.reduce_scatter(self.cuda_shard, input_list, group=self.torch_pg)
self.cuda_shard, input_list, group=self.torch_pg, async_op=async_op
)
if self.extra_dp_group is not None: if self.extra_dp_group is not None:
self.wait_async_reduce() dist.all_reduce(self.cuda_shard, group=self.extra_dp_group)
self.grad_reduce_work = dist.all_reduce(self.cuda_shard, group=self.extra_dp_group, async_op=async_op)
free_storage(self.cuda_global_chunk) free_storage(self.cuda_global_chunk)
self.is_gathered = False self.is_gathered = False
self.__update_tensors_state(TensorState.HOLD) self.__update_tensors_state(TensorState.HOLD)
def wait_async_reduce(self) -> None:
if self.grad_reduce_work is not None:
self.grad_reduce_work.wait()
self.grad_reduce_work = None
def tensor_trans_state(self, tensor: torch.Tensor, tensor_state: TensorState) -> None: def tensor_trans_state(self, tensor: torch.Tensor, tensor_state: TensorState) -> None:
""" """
Make a transition of the tensor into the next state. Make a transition of the tensor into the next state.

@ -143,12 +143,12 @@ class ChunkManager:
chunk = self.tensor_chunk_map[tensor] chunk = self.tensor_chunk_map[tensor]
chunk.tensor_trans_state(tensor, state) chunk.tensor_trans_state(tensor, state)
def reduce_chunk(self, chunk: Chunk, async_op: bool = False) -> bool: def reduce_chunk(self, chunk: Chunk) -> bool:
"""Reduce or all reduce the chunk.""" """Reduce or all reduce the chunk."""
if not chunk.can_reduce: if not chunk.can_reduce:
return False return False
self.__sub_memory_usage(chunk.memory_usage) self.__sub_memory_usage(chunk.memory_usage)
chunk.reduce(async_op=async_op) chunk.reduce()
self.__sub_accessed_chunk(chunk) self.__sub_accessed_chunk(chunk)
self.__add_memory_usage(chunk.memory_usage) self.__add_memory_usage(chunk.memory_usage)
return True return True

@ -34,8 +34,7 @@ def check_equal(param, param_cp):
@parameterize("init_device", [None, torch.device("cpu")]) @parameterize("init_device", [None, torch.device("cpu")])
@parameterize("keep_gathered", [True, False]) @parameterize("keep_gathered", [True, False])
@parameterize("pin_memory", [True, False]) @parameterize("pin_memory", [True, False])
@parameterize("async_op", [True, False]) def exam_chunk_basic(init_device, keep_gathered, pin_memory):
def exam_chunk_basic(init_device, keep_gathered, pin_memory, async_op):
world_size = torch.distributed.get_world_size() world_size = torch.distributed.get_world_size()
pg = _get_default_group() pg = _get_default_group()
my_chunk = Chunk( my_chunk = Chunk(
@ -95,12 +94,9 @@ def exam_chunk_basic(init_device, keep_gathered, pin_memory, async_op):
assert my_chunk.tensor_state_cnter[TensorState.READY_FOR_REDUCE] == 4 assert my_chunk.tensor_state_cnter[TensorState.READY_FOR_REDUCE] == 4
assert my_chunk.can_reduce assert my_chunk.can_reduce
my_chunk.reduce(async_op) my_chunk.reduce()
assert my_chunk.tensor_state_cnter[TensorState.HOLD] == 4 assert my_chunk.tensor_state_cnter[TensorState.HOLD] == 4
if async_op:
my_chunk.wait_async_reduce()
if keep_gathered is False: if keep_gathered is False:
assert my_chunk.cuda_shard.size(0) == 1024 // world_size assert my_chunk.cuda_shard.size(0) == 1024 // world_size
assert my_chunk.device_type == "cuda" assert my_chunk.device_type == "cuda"

Loading…
Cancel
Save