mirror of https://github.com/hpcaitech/ColossalAI
[hotfix] fix the bug that large tensor exceed the maximum capacity of TensorBucket (#5879)
parent 7c2f79fa98
commit ea94c07b95
@@ -549,6 +549,13 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
                 working_param = real_working_params[group_id][idx]
                 param_to_gather = master_param.to(device).to(self._dtype)
                 pg = self.param_to_pg[working_param]
+                if param_to_gather.numel() > self.pg_to_tensor_bucket[pg].max_size:
+                    buffer_tensor = torch.empty_like(
+                        torch.cat([param_to_gather for _ in range(dist.get_world_size(pg))])
+                    )
+                    dist.all_gather_into_tensor(buffer_tensor, param_to_gather, pg)
+                    working_param.data.copy_(buffer_tensor[: working_param.numel()].reshape_as(working_param))
+                    continue
                 try:
                     self.pg_to_tensor_bucket[pg].add_to_bucket(param_to_gather, write_back_tensor=working_param)
                 except RuntimeError:
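For context, a minimal sketch (not ColossalAI's code) of the fallback this hotfix adds: when a flattened parameter shard is larger than the TensorBucket's max_size, it is all-gathered directly into a temporary buffer and copied back into the working parameter instead of being queued in the bucket. The helper name gather_oversized_param, the max_bucket_size argument, and the assumption of an already-initialized torch.distributed process group are illustrative, not taken from the repository.

# Sketch of the oversized-parameter fallback, assuming torch.distributed is
# already initialized. Function and argument names are illustrative.
import torch
import torch.distributed as dist

def gather_oversized_param(param_to_gather: torch.Tensor,
                           working_param: torch.Tensor,
                           pg: dist.ProcessGroup,
                           max_bucket_size: int) -> bool:
    """Return True if the shard bypassed the bucket via a direct all-gather."""
    if param_to_gather.numel() <= max_bucket_size:
        return False  # small enough: let the TensorBucket batch it as usual
    # Allocate a buffer large enough to hold one shard from every rank in the group.
    buffer_tensor = torch.empty_like(
        torch.cat([param_to_gather for _ in range(dist.get_world_size(pg))])
    )
    # Gather every rank's shard into the contiguous buffer.
    dist.all_gather_into_tensor(buffer_tensor, param_to_gather, pg)
    # The gathered buffer may be padded; only the first numel() elements
    # belong to this parameter, so slice before reshaping and writing back.
    working_param.data.copy_(
        buffer_tensor[: working_param.numel()].reshape_as(working_param)
    )
    return True

A tensor larger than max_size could not fit even in an empty bucket, so routing it through a direct dist.all_gather_into_tensor sidesteps the bucket's capacity limit, while smaller shards keep the batched all-gather path handled by the surrounding try/except.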