@@ -164,8 +164,6 @@ class Chunk:
         self.l2_norm = None
 
         self.grad_chunk = None
-        # the async all-reduce/reduce-scatter work of this grad chunk (None means sync)
-        self.grad_reduce_work = None
 
     @property
     def memory_usage(self) -> Dict[str, int]:
@@ -376,49 +374,34 @@ class Chunk:
         if self.is_gathered:
             self.__scatter()
 
-    def reduce(self, async_op: bool = False):
+    def reduce(self):
         """Reduce scatter all the gradients. It's an operation done in CUDA."""
         # sanity check
         assert self.is_gathered
-        assert self.grad_reduce_work is None
         if self.pg_size == 1:
             # tricky code here
             # just move cuda_global_chunk to cuda_shard
             # the communication is not necessary
             self.__scatter()
             if self.extra_dp_group is not None:
-                self.grad_reduce_work = dist.all_reduce(self.cuda_shard, group=self.extra_dp_group, async_op=async_op)
+                dist.all_reduce(self.cuda_shard, group=self.extra_dp_group)
         elif self.keep_gathered:
             # we use all-reduce here
-            self.grad_reduce_work = dist.all_reduce(self.cuda_global_chunk, group=self.torch_pg, async_op=async_op)
-            if self.extra_dp_group is not None:  # cannot guranatee the order of multiple all-reduce
-                self.wait_async_reduce()
-                self.grad_reduce_work = dist.all_reduce(
-                    self.cuda_global_chunk, group=self.extra_dp_group, async_op=async_op
-                )
+            dist.all_reduce(self.cuda_global_chunk, group=self.torch_pg)
+            if self.extra_dp_group is not None:
+                dist.all_reduce(self.cuda_global_chunk, group=self.extra_dp_group)
         else:
             self.cuda_shard = torch.empty(
                 self.shard_size, dtype=self.dtype, device=get_accelerator().get_current_device()
             )
-
             input_list = list(torch.chunk(self.cuda_global_chunk, chunks=self.pg_size, dim=0))
-            self.grad_reduce_work = dist.reduce_scatter(
-                self.cuda_shard, input_list, group=self.torch_pg, async_op=async_op
-            )
-
+            dist.reduce_scatter(self.cuda_shard, input_list, group=self.torch_pg)
             if self.extra_dp_group is not None:
-                self.wait_async_reduce()
-                self.grad_reduce_work = dist.all_reduce(self.cuda_shard, group=self.extra_dp_group, async_op=async_op)
-
+                dist.all_reduce(self.cuda_shard, group=self.extra_dp_group)
             free_storage(self.cuda_global_chunk)
         self.is_gathered = False
         self.__update_tensors_state(TensorState.HOLD)
 
-    def wait_async_reduce(self) -> None:
-        if self.grad_reduce_work is not None:
-            self.grad_reduce_work.wait()
-            self.grad_reduce_work = None
-
     def tensor_trans_state(self, tensor: torch.Tensor, tensor_state: TensorState) -> None:
         """
         Make a transition of the tensor into the next state.
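
Note (not part of the patch): the lines removed above implemented deferred gradient reduction by keeping the Work handle that torch.distributed returns when a collective is launched with async_op=True, and blocking on it later in wait_async_reduce(). The following is a minimal standalone sketch of that pattern, not code from chunk.py; it assumes an initialized NCCL process group, and grad_buffer, grad_shard, and work are illustrative names only.

    import torch
    import torch.distributed as dist

    def sharded_grad_reduce(grad_buffer: torch.Tensor) -> torch.Tensor:
        # split the flat gradient buffer into one equal shard per rank
        # (grad_buffer length must be divisible by the world size)
        world_size = dist.get_world_size()
        shards = list(torch.chunk(grad_buffer, chunks=world_size, dim=0))
        grad_shard = torch.empty_like(shards[dist.get_rank()])

        # async_op=True returns a Work handle instead of blocking the caller;
        # group=None uses the default process group
        work = dist.reduce_scatter(grad_shard, shards, group=None, async_op=True)

        # ... independent computation can overlap with the communication here ...

        # block only when the reduced shard is actually needed
        if work is not None:
            work.wait()
        return grad_shard

With async_op=False, which is the behaviour this revert returns to, each collective blocks as soon as it is issued, so no handle needs to be stored on the chunk and no explicit wait step is required.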