mirror of https://github.com/hpcaitech/ColossalAI
[hotfix] fix chunk comm src rank (#1072)
parent bfdc5ccb7b
commit 98cdbf49c6
@@ -47,6 +47,7 @@ class Chunk:
         self.utilized_size = 0
         self.src_rank = src_rank
         self.is_src_rank = gpc.get_local_rank(ParallelMode.DATA) == src_rank
+        self.global_src_rank = gpc.get_ranks_in_group(ParallelMode.DATA)[src_rank]
         self.dtype = dtype
         self.device = init_device or get_current_device()
         self.data = torch.empty(chunk_size, dtype=dtype, device=self.device)
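The point of the added line: `src_rank` is the source's index *within* the data-parallel group, but `torch.distributed` collectives address processes by their global (world) rank, so the constructor now precomputes the translation once. A minimal sketch of what the new line computes, with hypothetical numbers (the group layout below is an assumption for illustration, not taken from ColossalAI):

```python
# Hypothetical layout: a world of 8 processes split into two data-parallel
# groups; suppose this process sits in the group of global ranks [4, 5, 6, 7].
ranks_in_group = [4, 5, 6, 7]    # stand-in for gpc.get_ranks_in_group(ParallelMode.DATA)
src_rank = 0                     # local index inside the group

global_src_rank = ranks_in_group[src_rank]
assert global_src_rank == 4      # a global rank, not the local index 0
```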
@@ -87,7 +88,7 @@ class Chunk:
         if not self.is_src_rank:
             self.data.storage().resize_(self.size)
         self.data.data = self.data.to(get_current_device())
-        dist.broadcast(self.data, self.src_rank, group=gpc.get_group(ParallelMode.DATA))
+        dist.broadcast(self.data, self.global_src_rank, group=gpc.get_group(ParallelMode.DATA))
         self._update_tensors_ptr()
         if not self.is_src_rank:
             self._update_tensors_state(TensorState.HOLD, prev_state=TensorState.FREE)
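Why the broadcast needed the fix: `dist.broadcast(tensor, src, group=...)` interprets `src` as a global rank even when `group` covers only a subset of processes, so passing the group-local index broadcasts from the wrong process (or fails) whenever the data-parallel group does not start at global rank 0. A hedged sketch of the corrected call pattern; `dist.get_process_group_ranks` (available in recent PyTorch) stands in for the rank list that gpc keeps here:

```python
import torch
import torch.distributed as dist

def broadcast_in_group(tensor: torch.Tensor, local_src: int, group) -> None:
    # Translate the group-local index into a global rank before the
    # collective call; src is a world rank even when group= is given.
    global_src = dist.get_process_group_ranks(group)[local_src]
    dist.broadcast(tensor, src=global_src, group=group)
```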
@@ -101,7 +102,7 @@ class Chunk:
         if is_all_reduce:
             dist.all_reduce(self.data, group=gpc.get_group(ParallelMode.DATA))
         else:
-            dist.reduce(self.data, self.src_rank, group=gpc.get_group(ParallelMode.DATA))
+            dist.reduce(self.data, self.global_src_rank, group=gpc.get_group(ParallelMode.DATA))
         self._update_tensors_ptr()
         self._update_tensors_state(TensorState.HOLD)

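`dist.reduce` follows the same convention for its destination argument, which is why the local index is replaced here as well, while `dist.all_reduce` takes no src/dst and was already correct. A matching sketch under the same assumptions as above:

```python
import torch
import torch.distributed as dist

def reduce_in_group(tensor: torch.Tensor, local_dst: int, group) -> None:
    # dst, like broadcast's src, must be a global rank even with group= set.
    global_dst = dist.get_process_group_ranks(group)[local_dst]
    dist.reduce(tensor, dst=global_dst, op=dist.ReduceOp.SUM, group=group)
```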