From 45b08f08cb8581986e513ef9162d93a8c07fd250 Mon Sep 17 00:00:00 2001
From: LuGY <74758262+Gy-Lu@users.noreply.github.com>
Date: Tue, 18 Jul 2023 14:44:13 +0800
Subject: [PATCH] [zero] optimize the optimizer step time (#4221)

* optimize the optimizer step time

* fix corner case

* polish

* replace all-reduce with all-gather

* set comm device to cuda
---
 colossalai/zero/low_level/low_level_optim.py | 11 ++++++-----
 1 file changed, 6 insertions(+), 5 deletions(-)

diff --git a/colossalai/zero/low_level/low_level_optim.py b/colossalai/zero/low_level/low_level_optim.py
index 023db122f..2b3f50ed4 100644
--- a/colossalai/zero/low_level/low_level_optim.py
+++ b/colossalai/zero/low_level/low_level_optim.py
@@ -412,15 +412,16 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
             release_param_grad(self._master_param_groups_of_current_rank[group_id])
 
         # update working partition updated by the current rank
+        dtype = real_working_params[0][0].dtype
         for group_id in range(self.num_param_groups):
             master_working_param = self.optim.param_groups[group_id]['params']
-
             for idx, splited_param in enumerate(master_working_param):
-                full_master_param = [torch.zeros_like(splited_param).cuda() for _ in range(self._world_size)]
-                dist.all_gather(full_master_param, splited_param.cuda(), group=self.dp_pg)
                 working_param = real_working_params[group_id][idx]
-                full_master_param = flatten(full_master_param)[:working_param.numel()].reshape_as(working_param)
-                working_param.data.copy_(full_master_param)
+                all_splited_param = [
+                    torch.zeros(splited_param.shape, device="cuda", dtype=dtype) for _ in range(self._world_size)
+                ]
+                dist.all_gather(all_splited_param, splited_param.cuda().to(dtype), group=self.dp_pg)
+                working_param.data.copy_(flatten(all_splited_param)[:working_param.numel()].reshape_as(working_param))
 
             self.optim.param_groups[group_id]['params'] = self._master_param_groups_of_current_rank[group_id]
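
The hunk above refreshes each working parameter by all-gathering the rank-local master shards, already cast to the working dtype, and copying the flattened result back in place. Below is a minimal, self-contained sketch of that gather-and-copy pattern, not ColossalAI's API: rebuild_working_param is a hypothetical helper, torch.cat stands in for the optimizer's flatten utility, and a single-process gloo group on CPU replaces the CUDA dp_pg group used by the real optimizer.

# Illustrative sketch only; see the assumptions stated above.
import os

import torch
import torch.distributed as dist


def rebuild_working_param(splited_param: torch.Tensor,
                          working_param: torch.Tensor,
                          pg: dist.ProcessGroup) -> None:
    """Gather every rank's master shard (cast to the working dtype first so the
    collective moves lower-precision data), then flatten, drop any padding, and
    copy the result into the working parameter in place."""
    world_size = dist.get_world_size(pg)
    dtype = working_param.dtype
    shards = [torch.zeros_like(splited_param, dtype=dtype) for _ in range(world_size)]
    dist.all_gather(shards, splited_param.to(dtype), group=pg)
    full = torch.cat([s.flatten() for s in shards])[:working_param.numel()]
    working_param.data.copy_(full.reshape_as(working_param))


if __name__ == "__main__":
    # Single-process demo over gloo; a real ZeRO run would use NCCL on CUDA tensors.
    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", "29500")
    dist.init_process_group("gloo", rank=0, world_size=1)

    working = torch.zeros(3, 3, dtype=torch.float16)     # fp16 working parameter to refresh
    master_shard = torch.arange(9, dtype=torch.float32)  # this rank's fp32 master shard (world_size == 1)
    rebuild_working_param(master_shard, working, dist.group.WORLD)
    print(working)

    dist.destroy_process_group()

Casting the shard to the working dtype before the collective (the .to(dtype) added in the hunk) also shrinks the gathered volume whenever the master copy is kept in fp32 and the working copy in fp16.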