From 98cdbf49c60d5869098afbea9c1baeb7217a86b8 Mon Sep 17 00:00:00 2001
From: ver217
Date: Tue, 7 Jun 2022 11:54:56 +0800
Subject: [PATCH] [hotfix] fix chunk comm src rank (#1072)

---
 colossalai/tensor/chunk.py | 5 +++--
 1 file changed, 3 insertions(+), 2 deletions(-)

diff --git a/colossalai/tensor/chunk.py b/colossalai/tensor/chunk.py
index d0c8315dc..12cdb7957 100644
--- a/colossalai/tensor/chunk.py
+++ b/colossalai/tensor/chunk.py
@@ -47,6 +47,7 @@ class Chunk:
         self.utilized_size = 0
         self.src_rank = src_rank
         self.is_src_rank = gpc.get_local_rank(ParallelMode.DATA) == src_rank
+        self.global_src_rank = gpc.get_ranks_in_group(ParallelMode.DATA)[src_rank]
         self.dtype = dtype
         self.device = init_device or get_current_device()
         self.data = torch.empty(chunk_size, dtype=dtype, device=self.device)
@@ -87,7 +88,7 @@ class Chunk:
         if not self.is_src_rank:
             self.data.storage().resize_(self.size)
         self.data.data = self.data.to(get_current_device())
-        dist.broadcast(self.data, self.src_rank, group=gpc.get_group(ParallelMode.DATA))
+        dist.broadcast(self.data, self.global_src_rank, group=gpc.get_group(ParallelMode.DATA))
         self._update_tensors_ptr()
         if not self.is_src_rank:
             self._update_tensors_state(TensorState.HOLD, prev_state=TensorState.FREE)
@@ -101,7 +102,7 @@ class Chunk:
         if is_all_reduce:
             dist.all_reduce(self.data, group=gpc.get_group(ParallelMode.DATA))
         else:
-            dist.reduce(self.data, self.src_rank, group=gpc.get_group(ParallelMode.DATA))
+            dist.reduce(self.data, self.global_src_rank, group=gpc.get_group(ParallelMode.DATA))
         self._update_tensors_ptr()
         self._update_tensors_state(TensorState.HOLD)
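
The rationale for the fix: torch.distributed collectives such as dist.broadcast(tensor, src, group=...) and dist.reduce(tensor, dst, group=...) interpret src/dst as global ranks in the world, even when a group argument restricts the participants. Chunk.src_rank is a rank local to the DATA parallel group, so passing it directly roots the collective at the wrong process whenever the group's ranks do not coincide with global ranks starting at 0. The patch therefore precomputes the global rank once via gpc.get_ranks_in_group(ParallelMode.DATA)[src_rank]. Below is a minimal self-contained sketch of the same local-to-global translation; the helper name broadcast_in_group, the gloo backend, and the 2x2 group layout are illustrative assumptions, not ColossalAI or patch code.

    # Sketch of the rank-mapping issue this hotfix addresses. Helper name,
    # backend, and group layout are hypothetical, not ColossalAI API.
    import torch
    import torch.distributed as dist

    def broadcast_in_group(tensor: torch.Tensor, local_src: int,
                           group_ranks: list, group) -> None:
        # dist.broadcast expects `src` to be a GLOBAL rank even when `group`
        # is given, so a subgroup-local rank must be translated first --
        # the same mapping the patch stores as Chunk.global_src_rank.
        global_src = group_ranks[local_src]
        dist.broadcast(tensor, src=global_src, group=group)

    def run(rank: int, world_size: int = 4) -> None:
        dist.init_process_group("gloo", rank=rank, world_size=world_size)
        # Every process must create every group, in the same order.
        groups = [dist.new_group([0, 1]), dist.new_group([2, 3])]
        group_ranks = [0, 1] if rank < 2 else [2, 3]
        group = groups[0] if rank < 2 else groups[1]
        t = torch.full((4,), float(rank))
        broadcast_in_group(t, local_src=0, group_ranks=group_ranks, group=group)
        # Ranks 2 and 3 now hold the value from global rank 2. Passing
        # local_src=0 straight to dist.broadcast would instead root the
        # broadcast at global rank 0, which is not even in their group.

On ranks 2 and 3 in this layout, local rank 0 corresponds to global rank 2, which is exactly the case the unfixed code got wrong: before the patch, every Chunk broadcast and reduce was rooted at a global rank numerically equal to the group-local src_rank.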