From ea94c07b959e8895b713d6dd68b168ea37db6b7b Mon Sep 17 00:00:00 2001 From: Haze188 Date: Tue, 2 Jul 2024 12:42:02 +0800 Subject: [PATCH] [hotfix] fix the bug that large tensor exceed the maximum capacity of TensorBucket (#5879) --- colossalai/zero/low_level/low_level_optim.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/colossalai/zero/low_level/low_level_optim.py b/colossalai/zero/low_level/low_level_optim.py index e06cf0581..bdc91b51f 100644 --- a/colossalai/zero/low_level/low_level_optim.py +++ b/colossalai/zero/low_level/low_level_optim.py @@ -549,6 +549,13 @@ class LowLevelZeroOptimizer(OptimizerWrapper): working_param = real_working_params[group_id][idx] param_to_gather = master_param.to(device).to(self._dtype) pg = self.param_to_pg[working_param] + if param_to_gather.numel() > self.pg_to_tensor_bucket[pg].max_size: + buffer_tensor = torch.empty_like( + torch.cat([param_to_gather for _ in range(dist.get_world_size(pg))]) + ) + dist.all_gather_into_tensor(buffer_tensor, param_to_gather, pg) + working_param.data.copy_(buffer_tensor[: working_param.numel()].reshape_as(working_param)) + continue try: self.pg_to_tensor_bucket[pg].add_to_bucket(param_to_gather, write_back_tensor=working_param) except RuntimeError: