mirror of https://github.com/hpcaitech/ColossalAI
[hotfix] fix the bug that large tensor exceed the maximum capacity of TensorBucket (#5879)
parent
7c2f79fa98
commit
ea94c07b95
|
@ -549,6 +549,13 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
|
|||
working_param = real_working_params[group_id][idx]
|
||||
param_to_gather = master_param.to(device).to(self._dtype)
|
||||
pg = self.param_to_pg[working_param]
|
||||
if param_to_gather.numel() > self.pg_to_tensor_bucket[pg].max_size:
|
||||
buffer_tensor = torch.empty_like(
|
||||
torch.cat([param_to_gather for _ in range(dist.get_world_size(pg))])
|
||||
)
|
||||
dist.all_gather_into_tensor(buffer_tensor, param_to_gather, pg)
|
||||
working_param.data.copy_(buffer_tensor[: working_param.numel()].reshape_as(working_param))
|
||||
continue
|
||||
try:
|
||||
self.pg_to_tensor_bucket[pg].add_to_bucket(param_to_gather, write_back_tensor=working_param)
|
||||
except RuntimeError:
|
||||
|
|
Loading…
Reference in New Issue