[hotfix] fix chunk comm src rank (#1072)

pull/1076/head
ver217 3 years ago committed by GitHub
parent bfdc5ccb7b
commit 98cdbf49c6
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@@ -47,6 +47,7 @@ class Chunk:
         self.utilized_size = 0
         self.src_rank = src_rank
         self.is_src_rank = gpc.get_local_rank(ParallelMode.DATA) == src_rank
+        self.global_src_rank = gpc.get_ranks_in_group(ParallelMode.DATA)[src_rank]
         self.dtype = dtype
         self.device = init_device or get_current_device()
         self.data = torch.empty(chunk_size, dtype=dtype, device=self.device)
@@ -87,7 +88,7 @@ class Chunk:
         if not self.is_src_rank:
             self.data.storage().resize_(self.size)
         self.data.data = self.data.to(get_current_device())
-        dist.broadcast(self.data, self.src_rank, group=gpc.get_group(ParallelMode.DATA))
+        dist.broadcast(self.data, self.global_src_rank, group=gpc.get_group(ParallelMode.DATA))
         self._update_tensors_ptr()
         if not self.is_src_rank:
             self._update_tensors_state(TensorState.HOLD, prev_state=TensorState.FREE)
@@ -101,7 +102,7 @@ class Chunk:
         if is_all_reduce:
             dist.all_reduce(self.data, group=gpc.get_group(ParallelMode.DATA))
         else:
-            dist.reduce(self.data, self.src_rank, group=gpc.get_group(ParallelMode.DATA))
+            dist.reduce(self.data, self.global_src_rank, group=gpc.get_group(ParallelMode.DATA))
         self._update_tensors_ptr()
         self._update_tensors_state(TensorState.HOLD)

Loading…
Cancel
Save