[zero] optimize the optimizer step time (#4221)

* optimize the optimizer step time

* fix corner case

* polish

* replace all-reduce with all-gather (see the sketch after these notes)

* set comm device to cuda
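
A minimal sketch of that communication swap, assuming each rank holds an equal-sized 1-D shard of the flat parameter; `shard`, `rank`, `world_size`, and both helper names are illustrative, not taken from the patch:

import torch
import torch.distributed as dist

def rebuild_via_all_reduce(shard: torch.Tensor, rank: int, world_size: int) -> torch.Tensor:
    # All-reduce pattern: zero-pad the local shard into a full-size buffer
    # and sum across ranks. Every rank communicates the full tensor, most
    # of which is zeros.
    full = torch.zeros(shard.numel() * world_size, device=shard.device, dtype=shard.dtype)
    full[rank * shard.numel():(rank + 1) * shard.numel()] = shard
    dist.all_reduce(full)  # defaults to SUM over the default group
    return full

def rebuild_via_all_gather(shard: torch.Tensor, world_size: int) -> torch.Tensor:
    # All-gather pattern: each rank contributes only its shard and receives
    # everyone else's, with no zero-filled padding to allocate or sum.
    gathered = [torch.empty_like(shard) for _ in range(world_size)]
    dist.all_gather(gathered, shard)
    return torch.cat(gathered)

Both produce the same full tensor, but with ring algorithms all-reduce moves roughly twice the full tensor per rank while all-gather moves roughly one, so the swap trims communication as well as the wasted zero-fill.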
LuGY 2023-07-18 14:44:13 +08:00 committed by Hongxin Liu
parent 1a49a5ea00
commit 45b08f08cb
1 changed file with 6 additions and 5 deletions

@@ -412,15 +412,16 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
             release_param_grad(self._master_param_groups_of_current_rank[group_id])

         # update working partition updated by the current rank
+        dtype = real_working_params[0][0].dtype
         for group_id in range(self.num_param_groups):
             master_working_param = self.optim.param_groups[group_id]['params']
             for idx, splited_param in enumerate(master_working_param):
-                full_master_param = [torch.zeros_like(splited_param).cuda() for _ in range(self._world_size)]
-                dist.all_gather(full_master_param, splited_param.cuda(), group=self.dp_pg)
                 working_param = real_working_params[group_id][idx]
-                full_master_param = flatten(full_master_param)[:working_param.numel()].reshape_as(working_param)
-                working_param.data.copy_(full_master_param)
+                all_splited_param = [
+                    torch.zeros(splited_param.shape, device="cuda", dtype=dtype) for _ in range(self._world_size)
+                ]
+                dist.all_gather(all_splited_param, splited_param.cuda().to(dtype), group=self.dp_pg)
+                working_param.data.copy_(flatten(all_splited_param)[:working_param.numel()].reshape_as(working_param))
             self.optim.param_groups[group_id]['params'] = self._master_param_groups_of_current_rank[group_id]
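
For context, here is a self-contained sketch of the update the new lines perform, with `torch.cat` standing in for the repo's `flatten` helper; `fp32_shard` and `update_working_param` are illustrative names, and an initialized CUDA-capable process group is assumed:

import torch
import torch.distributed as dist

def update_working_param(fp32_shard: torch.Tensor, working_param: torch.Tensor,
                         world_size: int) -> None:
    # fp32_shard plays the role of splited_param (this rank's fp32 master
    # shard); working_param holds the model's low-precision weights.
    dtype = working_param.dtype  # e.g. torch.float16
    # Cast to the working dtype *before* the collective, so the all_gather
    # moves low-precision payloads instead of fp32.
    shard = fp32_shard.cuda().to(dtype)
    gathered = [torch.empty(shard.shape, device="cuda", dtype=dtype)
                for _ in range(world_size)]
    dist.all_gather(gathered, shard)
    # Master shards may be padded; keep only the leading numel() elements.
    flat = torch.cat([g.flatten() for g in gathered])
    working_param.data.copy_(flat[:working_param.numel()].reshape_as(working_param))

Gathering in the working dtype halves the payload relative to gathering fp32 master shards, which is presumably where most of the step-time saving comes from; allocating the receive buffers directly with device="cuda" matches the "set comm device to cuda" note above.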