mirror of https://github.com/hpcaitech/ColossalAI
[hotfix] fix chunk comm src rank (#1072)
parent
bfdc5ccb7b
commit
98cdbf49c6
|
@ -47,6 +47,7 @@ class Chunk:
|
||||||
self.utilized_size = 0
|
self.utilized_size = 0
|
||||||
self.src_rank = src_rank
|
self.src_rank = src_rank
|
||||||
self.is_src_rank = gpc.get_local_rank(ParallelMode.DATA) == src_rank
|
self.is_src_rank = gpc.get_local_rank(ParallelMode.DATA) == src_rank
|
||||||
|
self.global_src_rank = gpc.get_ranks_in_group(ParallelMode.DATA)[src_rank]
|
||||||
self.dtype = dtype
|
self.dtype = dtype
|
||||||
self.device = init_device or get_current_device()
|
self.device = init_device or get_current_device()
|
||||||
self.data = torch.empty(chunk_size, dtype=dtype, device=self.device)
|
self.data = torch.empty(chunk_size, dtype=dtype, device=self.device)
|
||||||
|
@ -87,7 +88,7 @@ class Chunk:
|
||||||
if not self.is_src_rank:
|
if not self.is_src_rank:
|
||||||
self.data.storage().resize_(self.size)
|
self.data.storage().resize_(self.size)
|
||||||
self.data.data = self.data.to(get_current_device())
|
self.data.data = self.data.to(get_current_device())
|
||||||
dist.broadcast(self.data, self.src_rank, group=gpc.get_group(ParallelMode.DATA))
|
dist.broadcast(self.data, self.global_src_rank, group=gpc.get_group(ParallelMode.DATA))
|
||||||
self._update_tensors_ptr()
|
self._update_tensors_ptr()
|
||||||
if not self.is_src_rank:
|
if not self.is_src_rank:
|
||||||
self._update_tensors_state(TensorState.HOLD, prev_state=TensorState.FREE)
|
self._update_tensors_state(TensorState.HOLD, prev_state=TensorState.FREE)
|
||||||
|
@ -101,7 +102,7 @@ class Chunk:
|
||||||
if is_all_reduce:
|
if is_all_reduce:
|
||||||
dist.all_reduce(self.data, group=gpc.get_group(ParallelMode.DATA))
|
dist.all_reduce(self.data, group=gpc.get_group(ParallelMode.DATA))
|
||||||
else:
|
else:
|
||||||
dist.reduce(self.data, self.src_rank, group=gpc.get_group(ParallelMode.DATA))
|
dist.reduce(self.data, self.global_src_rank, group=gpc.get_group(ParallelMode.DATA))
|
||||||
self._update_tensors_ptr()
|
self._update_tensors_ptr()
|
||||||
self._update_tensors_state(TensorState.HOLD)
|
self._update_tensors_state(TensorState.HOLD)
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue