mirror of https://github.com/hpcaitech/ColossalAI
[refactor] remove legacy async reduce scatter code
parent fee35678e5
commit 58ad76d466
@@ -164,8 +164,6 @@ class Chunk:
         self.l2_norm = None
 
         self.grad_chunk = None
-        # the async all-reduce/reduce-scatter work of this grad chunk (None means sync)
-        self.grad_reduce_work = None
 
     @property
     def memory_usage(self) -> Dict[str, int]:
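The removed grad_reduce_work field stored the torch.distributed work handle returned by a collective launched with async_op=True; None meant the chunk's gradient reduction was synchronous. A minimal sketch of that handle pattern outside ColossalAI, using a single-process gloo group purely so it runs standalone (real Gemini training launches one rank per GPU, e.g. via torchrun, with NCCL):

import os

import torch
import torch.distributed as dist

# Single-process group so the sketch runs on its own.
os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29521")
dist.init_process_group(backend="gloo", rank=0, world_size=1)

grad_buffer = torch.ones(8)

# async_op=True returns a work handle instead of blocking the caller,
# which is what the removed grad_reduce_work attribute held on to.
work = dist.all_reduce(grad_buffer, async_op=True)

# ... other computation could overlap with the communication here ...

work.wait()   # the removed wait_async_reduce() did this, then reset the handle
work = None

dist.destroy_process_group()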
@@ -376,49 +374,34 @@ class Chunk:
         if self.is_gathered:
             self.__scatter()
 
-    def reduce(self, async_op: bool = False):
+    def reduce(self):
         """Reduce scatter all the gradients. It's an operation done in CUDA."""
         # sanity check
         assert self.is_gathered
-        assert self.grad_reduce_work is None
         if self.pg_size == 1:
             # tricky code here
             # just move cuda_global_chunk to cuda_shard
             # the communication is not necessary
             self.__scatter()
             if self.extra_dp_group is not None:
-                self.grad_reduce_work = dist.all_reduce(self.cuda_shard, group=self.extra_dp_group, async_op=async_op)
+                dist.all_reduce(self.cuda_shard, group=self.extra_dp_group)
         elif self.keep_gathered:
             # we use all-reduce here
-            self.grad_reduce_work = dist.all_reduce(self.cuda_global_chunk, group=self.torch_pg, async_op=async_op)
+            dist.all_reduce(self.cuda_global_chunk, group=self.torch_pg)
             if self.extra_dp_group is not None:  # cannot guranatee the order of multiple all-reduce
-                self.wait_async_reduce()
-                self.grad_reduce_work = dist.all_reduce(
-                    self.cuda_global_chunk, group=self.extra_dp_group, async_op=async_op
-                )
+                dist.all_reduce(self.cuda_global_chunk, group=self.extra_dp_group)
         else:
             self.cuda_shard = torch.empty(
                 self.shard_size, dtype=self.dtype, device=get_accelerator().get_current_device()
             )
 
             input_list = list(torch.chunk(self.cuda_global_chunk, chunks=self.pg_size, dim=0))
-            self.grad_reduce_work = dist.reduce_scatter(
-                self.cuda_shard, input_list, group=self.torch_pg, async_op=async_op
-            )
+            dist.reduce_scatter(self.cuda_shard, input_list, group=self.torch_pg)
 
             if self.extra_dp_group is not None:
-                self.wait_async_reduce()
-                self.grad_reduce_work = dist.all_reduce(self.cuda_shard, group=self.extra_dp_group, async_op=async_op)
+                dist.all_reduce(self.cuda_shard, group=self.extra_dp_group)
 
             free_storage(self.cuda_global_chunk)
             self.is_gathered = False
         self.__update_tensors_state(TensorState.HOLD)
 
-    def wait_async_reduce(self) -> None:
-        if self.grad_reduce_work is not None:
-            self.grad_reduce_work.wait()
-            self.grad_reduce_work = None
-
     def tensor_trans_state(self, tensor: torch.Tensor, tensor_state: TensorState) -> None:
         """
         Make a transition of the tensor into the next state.
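The else branch above is the sharded path: the flat gradient chunk is split into pg_size equal slices and reduce-scattered so every rank keeps only the summed slice it owns. A single-process sketch of that arithmetic (no process group needed; pg_size and the chunk length are chosen arbitrarily for illustration):

import torch

pg_size = 4            # hypothetical number of ranks in the zero process group
chunk_numel = 1024     # hypothetical flat chunk length, divisible by pg_size

# one full gradient chunk per rank, as each rank holds before reduction
per_rank_chunks = [torch.randn(chunk_numel) for _ in range(pg_size)]

# reduce-scatter with the default sum op is equivalent to:
# sum all per-rank copies, then give rank i the i-th slice
reduced = torch.stack(per_rank_chunks).sum(dim=0)
shards = list(torch.chunk(reduced, chunks=pg_size, dim=0))

for rank, shard in enumerate(shards):
    # matches the test's expectation: each shard holds chunk_numel // pg_size elements
    assert shard.numel() == chunk_numel // pg_size

print(f"each rank keeps a shard of {chunk_numel // pg_size} elements")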
@@ -143,12 +143,12 @@ class ChunkManager:
         chunk = self.tensor_chunk_map[tensor]
         chunk.tensor_trans_state(tensor, state)
 
-    def reduce_chunk(self, chunk: Chunk, async_op: bool = False) -> bool:
+    def reduce_chunk(self, chunk: Chunk) -> bool:
         """Reduce or all reduce the chunk."""
         if not chunk.can_reduce:
             return False
         self.__sub_memory_usage(chunk.memory_usage)
-        chunk.reduce(async_op=async_op)
+        chunk.reduce()
         self.__sub_accessed_chunk(chunk)
         self.__add_memory_usage(chunk.memory_usage)
         return True
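reduce_chunk brackets the reduction with memory accounting because reducing changes the chunk's footprint: the gathered buffer is freed and only a shard remains. A generic, runnable sketch of that subtract-mutate-re-add pattern with hypothetical names (a toy stand-in, not the ChunkManager API):

from collections import Counter
from typing import Dict


class _StubChunk:
    """Toy stand-in for Chunk: reducing drops the gathered buffer to a shard."""

    def __init__(self, numel: int, pg_size: int) -> None:
        self.can_reduce = True
        self._bytes = numel * 2          # fp16 gathered buffer
        self._shard_bytes = self._bytes // pg_size

    @property
    def memory_usage(self) -> Dict[str, int]:
        return {"cuda": self._bytes}

    def reduce(self) -> None:
        self._bytes = self._shard_bytes  # only the local shard remains


def reduce_with_accounting(total: Counter, chunk) -> bool:
    # mirror of the reduce_chunk flow: skip if not ready, otherwise
    # re-account the chunk because the reduce changes its footprint
    if not chunk.can_reduce:
        return False
    for device, nbytes in chunk.memory_usage.items():
        total[device] -= nbytes
    chunk.reduce()
    for device, nbytes in chunk.memory_usage.items():
        total[device] += nbytes
    return True


total = Counter({"cuda": 2048})
chunk = _StubChunk(numel=1024, pg_size=4)
reduce_with_accounting(total, chunk)
print(total)  # Counter({'cuda': 512}): only the shard is still accounted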
@@ -34,8 +34,7 @@ def check_equal(param, param_cp):
 @parameterize("init_device", [None, torch.device("cpu")])
 @parameterize("keep_gathered", [True, False])
 @parameterize("pin_memory", [True, False])
-@parameterize("async_op", [True, False])
-def exam_chunk_basic(init_device, keep_gathered, pin_memory, async_op):
+def exam_chunk_basic(init_device, keep_gathered, pin_memory):
     world_size = torch.distributed.get_world_size()
     pg = _get_default_group()
     my_chunk = Chunk(
@@ -95,12 +94,9 @@ def exam_chunk_basic(init_device, keep_gathered, pin_memory, async_op):
 
     assert my_chunk.tensor_state_cnter[TensorState.READY_FOR_REDUCE] == 4
     assert my_chunk.can_reduce
-    my_chunk.reduce(async_op)
+    my_chunk.reduce()
     assert my_chunk.tensor_state_cnter[TensorState.HOLD] == 4
 
-    if async_op:
-        my_chunk.wait_async_reduce()
-
     if keep_gathered is False:
         assert my_chunk.cuda_shard.size(0) == 1024 // world_size
         assert my_chunk.device_type == "cuda"
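The stacked @parameterize decorators run exam_chunk_basic once per combination of argument values, so dropping the two-valued async_op axis halves the number of generated test cases. A rough, self-contained sketch of how such a decorator can expand keyword arguments (an illustration only, not ColossalAI's colossalai.testing.parameterize implementation):

import functools
from typing import Any, Callable, List


def parameterize(arg_name: str, values: List[Any]) -> Callable:
    """Re-run the wrapped function once per value of `arg_name`.
    Stacking several of these yields the cartesian product of all axes."""

    def decorator(func: Callable) -> Callable:
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            for value in values:
                func(*args, **{**kwargs, arg_name: value})

        return wrapper

    return decorator


@parameterize("keep_gathered", [True, False])
@parameterize("pin_memory", [True, False])
def exam_demo(keep_gathered, pin_memory):
    print(f"case: keep_gathered={keep_gathered}, pin_memory={pin_memory}")


exam_demo()  # runs 4 cases; adding another two-valued axis would double that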