mirror of https://github.com/hpcaitech/ColossalAI
[zero] optimize the optimizer step time (#4221)
* optimize the optimizer step time
* fix corner case
* polish
* replace all-reduce with all-gather
* set comm device to cuda

Branch: pull/4359/head
parent 1a49a5ea00
commit 45b08f08cb
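The last two commit-message bullets go together: the updated master shards are exchanged with an all-gather collective, and the communication buffers live on the CUDA device so the NCCL backend can carry the collective. A minimal sketch of that pattern, assuming an already-initialized process group; `shard`, `dp_group`, and `world_size` are illustrative names, not identifiers taken from this commit:

    import torch
    import torch.distributed as dist

    def gather_master_shard(shard: torch.Tensor, dtype: torch.dtype,
                            world_size: int, dp_group) -> list:
        # NCCL collectives require CUDA tensors, so move the local shard to the
        # GPU and cast it to the working dtype before communicating.
        shard = shard.cuda().to(dtype)
        bufs = [torch.zeros_like(shard) for _ in range(world_size)]
        # After this call every rank holds all world_size shards, in rank order.
        dist.all_gather(bufs, shard, group=dp_group)
        return bufs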
@@ -412,15 +412,16 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
                 release_param_grad(self._master_param_groups_of_current_rank[group_id])
 
         # update working partition updated by the current rank
+        dtype = real_working_params[0][0].dtype
         for group_id in range(self.num_param_groups):
             master_working_param = self.optim.param_groups[group_id]['params']
-
             for idx, splited_param in enumerate(master_working_param):
-                full_master_param = [torch.zeros_like(splited_param).cuda() for _ in range(self._world_size)]
-                dist.all_gather(full_master_param, splited_param.cuda(), group=self.dp_pg)
                 working_param = real_working_params[group_id][idx]
-                full_master_param = flatten(full_master_param)[:working_param.numel()].reshape_as(working_param)
-                working_param.data.copy_(full_master_param)
+                all_splited_param = [
+                    torch.zeros(splited_param.shape, device="cuda", dtype=dtype) for _ in range(self._world_size)
+                ]
+                dist.all_gather(all_splited_param, splited_param.cuda().to(dtype), group=self.dp_pg)
+                working_param.data.copy_(flatten(all_splited_param)[:working_param.numel()].reshape_as(working_param))
 
             self.optim.param_groups[group_id]['params'] = self._master_param_groups_of_current_rank[group_id]
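On the new side of the hunk, each rank all-gathers its split of the master parameters and rebuilds the full working tensor by flattening the gathered pieces, truncating the padding, and reshaping to the working shape. A self-contained sketch of that reconstruction step (no process group needed; `world_size`, `working_param`, and `shards` are illustrative, and torch.cat stands in for the repository's flatten helper):

    import torch

    world_size = 4
    working_param = torch.randn(3, 5)  # 15 elements, not divisible by world_size

    # ZeRO pads the flattened parameter so it splits evenly across ranks;
    # each rank keeps exactly one of these shards as its master partition.
    padded_numel = ((working_param.numel() + world_size - 1) // world_size) * world_size
    flat = torch.zeros(padded_numel)
    flat[:working_param.numel()] = working_param.view(-1)
    shards = list(flat.chunk(world_size))  # what dist.all_gather would collect

    # Mirror of the updated copy_ line in the diff: flatten the gathered shards,
    # drop the padding by truncating to numel(), and reshape to the working shape.
    rebuilt = torch.cat([s.view(-1) for s in shards])[:working_param.numel()].reshape_as(working_param)
    assert torch.equal(rebuilt, working_param)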