|
|
|
@ -47,6 +47,7 @@ class Chunk:
|
|
|
|
|
self.utilized_size = 0
|
|
|
|
|
self.src_rank = src_rank
|
|
|
|
|
self.is_src_rank = gpc.get_local_rank(ParallelMode.DATA) == src_rank
|
|
|
|
|
self.global_src_rank = gpc.get_ranks_in_group(ParallelMode.DATA)[src_rank]
|
|
|
|
|
self.dtype = dtype
|
|
|
|
|
self.device = init_device or get_current_device()
|
|
|
|
|
self.data = torch.empty(chunk_size, dtype=dtype, device=self.device)
|
|
|
|
@ -87,7 +88,7 @@ class Chunk:
|
|
|
|
|
if not self.is_src_rank:
|
|
|
|
|
self.data.storage().resize_(self.size)
|
|
|
|
|
self.data.data = self.data.to(get_current_device())
|
|
|
|
|
dist.broadcast(self.data, self.src_rank, group=gpc.get_group(ParallelMode.DATA))
|
|
|
|
|
dist.broadcast(self.data, self.global_src_rank, group=gpc.get_group(ParallelMode.DATA))
|
|
|
|
|
self._update_tensors_ptr()
|
|
|
|
|
if not self.is_src_rank:
|
|
|
|
|
self._update_tensors_state(TensorState.HOLD, prev_state=TensorState.FREE)
|
|
|
|
@ -101,7 +102,7 @@ class Chunk:
|
|
|
|
|
if is_all_reduce:
|
|
|
|
|
dist.all_reduce(self.data, group=gpc.get_group(ParallelMode.DATA))
|
|
|
|
|
else:
|
|
|
|
|
dist.reduce(self.data, self.src_rank, group=gpc.get_group(ParallelMode.DATA))
|
|
|
|
|
dist.reduce(self.data, self.global_src_rank, group=gpc.get_group(ParallelMode.DATA))
|
|
|
|
|
self._update_tensors_ptr()
|
|
|
|
|
self._update_tensors_state(TensorState.HOLD)
|
|
|
|
|
|
|
|
|
|