|
|
|
@ -549,6 +549,13 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
|
|
|
|
|
working_param = real_working_params[group_id][idx] |
|
|
|
|
param_to_gather = master_param.to(device).to(self._dtype) |
|
|
|
|
pg = self.param_to_pg[working_param] |
|
|
|
|
if param_to_gather.numel() > self.pg_to_tensor_bucket[pg].max_size: |
|
|
|
|
buffer_tensor = torch.empty_like( |
|
|
|
|
torch.cat([param_to_gather for _ in range(dist.get_world_size(pg))]) |
|
|
|
|
) |
|
|
|
|
dist.all_gather_into_tensor(buffer_tensor, param_to_gather, pg) |
|
|
|
|
working_param.data.copy_(buffer_tensor[: working_param.numel()].reshape_as(working_param)) |
|
|
|
|
continue |
|
|
|
|
try: |
|
|
|
|
self.pg_to_tensor_bucket[pg].add_to_bucket(param_to_gather, write_back_tensor=working_param) |
|
|
|
|
except RuntimeError: |
|
|
|
|