Browse Source

[hotfix] fix the bug where a large tensor exceeds the maximum capacity of TensorBucket (#5879)

pull/5881/head
Haze188 5 months ago committed by GitHub
parent
commit
ea94c07b95
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
  1. 7
      colossalai/zero/low_level/low_level_optim.py

7
colossalai/zero/low_level/low_level_optim.py

@@ -549,6 +549,13 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
working_param = real_working_params[group_id][idx] working_param = real_working_params[group_id][idx]
param_to_gather = master_param.to(device).to(self._dtype) param_to_gather = master_param.to(device).to(self._dtype)
pg = self.param_to_pg[working_param] pg = self.param_to_pg[working_param]
if param_to_gather.numel() > self.pg_to_tensor_bucket[pg].max_size:
buffer_tensor = torch.empty_like(
torch.cat([param_to_gather for _ in range(dist.get_world_size(pg))])
)
dist.all_gather_into_tensor(buffer_tensor, param_to_gather, pg)
working_param.data.copy_(buffer_tensor[: working_param.numel()].reshape_as(working_param))
continue
try: try:
self.pg_to_tensor_bucket[pg].add_to_bucket(param_to_gather, write_back_tensor=working_param) self.pg_to_tensor_bucket[pg].add_to_bucket(param_to_gather, write_back_tensor=working_param)
except RuntimeError: except RuntimeError:

Loading…
Cancel
Save