From 420e883d760b11326bfa4e80de4f420f95a11ee7 Mon Sep 17 00:00:00 2001
From: zaglc
Date: Thu, 7 Sep 2023 17:11:01 +0800
Subject: [PATCH] fix(fsdp): add mix-precision training

---
 .../solver/optimizer/hybrid_zero_optim.py | 90 +++++++++++++++++--
 1 file changed, 81 insertions(+), 9 deletions(-)

diff --git a/internlm/solver/optimizer/hybrid_zero_optim.py b/internlm/solver/optimizer/hybrid_zero_optim.py
index 6cb4f3d..cc1deba 100644
--- a/internlm/solver/optimizer/hybrid_zero_optim.py
+++ b/internlm/solver/optimizer/hybrid_zero_optim.py
@@ -112,6 +112,28 @@ class FSDPadaptOptimizer(BaseOptimizer):
         self._clip_grad_norm = zero_cfg.clip_grad_norm
         self.use_fsdp = gpc.config.parallel.use_fsdp
 
+        # mark whether a module is part of TP or not
+        is_tensor_parallel_dict = dict()
+
+        # fp16 and fp32 params
+        # fp16 share mem space with model.FlatParam, fp32 share mem space with optim.param_group
+        self._fp16_param_groups = dict()
+        self._fp32_param_tensor_groups = dict()
+
+        # init fp16 and fp32 params
+        for group_idx, param_group in enumerate(self.optim.param_groups):
+            group_params = param_group["params"]
+
+            # fp16 FlatParam storage
+            self._fp16_param_groups[group_idx] = group_params
+
+            # create copy of fp32 weight
+            fp32_tensor_param = [param.data.float().requires_grad_(True) for param in group_params]
+            self._fp32_param_tensor_groups[group_idx] = fp32_tensor_param
+
+            # replace
+            param_group["params"] = fp32_tensor_param
+
     @property
     def loss_scale(self):
         return self.grad_scaler.scale
@@ -120,12 +142,28 @@ class FSDPadaptOptimizer(BaseOptimizer):
         loss = self.loss_scale * loss
         loss.backward(retain_graph=retain_graph)
 
+    def _compute_norm_with_fsdp_flatten(self, group_id):
+        params = self._fp16_param_groups[group_id]
+        gradients = [p.grad for p in params]
+        norm_group = compute_norm(
+            gradients=gradients,
+            parameters=params,
+            last_stage=True
+        )
+
+        return norm_group
+
+    def zero_grad(self):
+        for _, param_group in self._fp16_param_groups.items():
+            for param in param_group:
+                param.grad = None
+
     def step(self):
         # in case that fsdp-zero3 size is not equal to dp size
         # FSDP module will only reduce gradient within FSDP process group
         # so manually reduce grad is essential between two parallel FSDP process group
         for group_idx in range(len(self.param_groups)):
-            params = self.param_groups[group_idx]["params"]
+            params = self._fp16_param_groups[group_idx]
            for param in params:
                 if param.requires_grad:
                     reduce_tensor(tensor=param.grad, parallel_mode=ParallelMode.ZERO3_DP)
@@ -134,13 +172,7 @@ class FSDPadaptOptimizer(BaseOptimizer):
         found_inf = False
         norm_groups = []
         for group_idx in range(len(self.param_groups)):
-            params = self.param_groups[group_idx]["params"]
-            gradients = [p.grad for p in params]
-            norm_group = compute_norm(
-                gradients=gradients,
-                parameters=params,
-                last_stage=True
-            )
+            norm_group = self._compute_norm_with_fsdp_flatten(group_idx)
             if norm_group == -1:
                 found_inf = True
                 break
@@ -157,15 +189,32 @@ class FSDPadaptOptimizer(BaseOptimizer):
         if self._clip_grad_norm > 0:
             global_norm = sum(norm_groups) ** 0.5
 
+        # create gradient for fp32 params
+        for group_idx in range(len(self.param_groups)):
+            dtype = self._fp32_param_tensor_groups[group_idx][0].dtype
+            fp16_params = self._fp16_param_groups[group_idx]
+            grad_fp32 = [p.grad.to(dtype) for p in fp16_params]
+
+            device = self._fp32_param_tensor_groups[group_idx][0].device
+            for p, g in zip(self._fp32_param_tensor_groups[group_idx], grad_fp32):
+                p.grad = g.to(device)
+
         # unscale
         for group_idx in range(len(self.param_groups)):
-            params = self.param_groups[group_idx]["params"]
+            params = self._fp32_param_tensor_groups[group_idx]
             for p in params:
                 self._unscale_and_clip_grads(p.grad, global_norm, loss_scale)
 
         self.optim.step()
         self.zero_grad()
 
+        # update fp16 param
+        for group_idx in range(len(self._fp16_param_groups)):
+            fp16_params = self._fp16_param_groups[group_idx]
+            fp32_tensor_params = self._fp32_param_tensor_groups[group_idx]
+            for p, q in zip(fp16_params, fp32_tensor_params):
+                p.data.copy_(q)
+
         return True, [global_norm / loss_scale]
 
     def clip_grad_norm(self, model, max_norm):
@@ -196,6 +245,11 @@ class FSDPadaptOptimizer(BaseOptimizer):
         optim_states = self.optim.state_dict()
         states["base_optim_states"] = optim_states
 
+        flat_fp32_weights = {}
+        for group_idx, param in self._fp32_param_tensor_groups.items():
+            flat_fp32_weights[group_idx] = param
+        states["flat_fp32_weights"] = flat_fp32_weights
+
         return states
 
     def load_state_dict(self, states):
@@ -205,6 +259,24 @@ class FSDPadaptOptimizer(BaseOptimizer):
         optim_states = states["base_optim_states"]
         self.optim.load_state_dict(optim_states)
 
+        # load fp32 optimizer weight
+        flat_fp32_weights = states["flat_fp32_weights"]
+        assert set(flat_fp32_weights.keys()) == set(self._fp32_param_tensor_groups)
+        for group_idx, param in flat_fp32_weights.items():
+            self_param = self._fp32_param_tensor_groups[group_idx]
+            assert (
+                len(self_param) == len(param)
+            ), f"The number of flat tensor is inconsistent, {len(self_param)} != {len(param)}"
+            for p, q in zip(self_param, param):
+                p.data.copy_(q.data)
+
+        # load fp16 model weight
+        for group_idx, param in flat_fp32_weights.items():
+            fp16_param = self._fp16_param_groups[group_idx]
+            fp32_param = self._fp32_param_tensor_groups[group_idx]
+            for p, q in zip(fp16_param, fp32_param):
+                p.data.copy_(q.data)
+
 
 class HybridZeroOptimizer(BaseOptimizer):
     """
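Not part of the patch: the snippet below is a minimal, standalone sketch of the fp32 master-weight flow that this diff wires into FSDPadaptOptimizer, using only plain torch on a toy module. It mirrors the pieces added above: build fp32 copies of the fp16 params at init, move fp16 grads onto the fp32 copies before the wrapped optimizer steps, copy the updated fp32 weights back into the fp16 params, and round-trip the fp32 copies through a checkpoint dict with the same "base_optim_states" / "flat_fp32_weights" keys. The helper names (build_master_copies, master_step, save_state, load_state) and the toy Linear model are illustrative assumptions, not code from the repository; FSDP flat parameters, the ZERO3_DP gradient reduction, inf/nan checks, and gradient clipping are left out.

import torch

def build_master_copies(optim):
    # replace each param group's fp16 params with fp32 master copies,
    # keeping the original fp16 params around for backward and weight sync
    fp16_groups, fp32_groups = {}, {}
    for idx, group in enumerate(optim.param_groups):
        fp16_params = group["params"]
        fp32_params = [p.data.float().requires_grad_(True) for p in fp16_params]
        fp16_groups[idx], fp32_groups[idx] = fp16_params, fp32_params
        group["params"] = fp32_params  # the wrapped optimizer now only sees fp32
    return fp16_groups, fp32_groups

def master_step(optim, fp16_groups, fp32_groups, loss_scale=1.0):
    # copy fp16 grads onto the fp32 masters (unscaling on the way),
    # step in fp32, then sync the fp16 weights from the fp32 result
    for idx, fp32_params in fp32_groups.items():
        for p16, p32 in zip(fp16_groups[idx], fp32_params):
            p32.grad = p16.grad.to(dtype=p32.dtype, device=p32.device) / loss_scale
    optim.step()
    for idx, fp32_params in fp32_groups.items():
        for p16, p32 in zip(fp16_groups[idx], fp32_params):
            p16.data.copy_(p32.data)
            p16.grad = None  # equivalent of zero_grad() on the fp16 side

def save_state(optim, fp32_groups):
    # analogue of state_dict(): wrapped optimizer state plus fp32 master weights
    return {"base_optim_states": optim.state_dict(),
            "flat_fp32_weights": dict(fp32_groups)}

def load_state(states, optim, fp16_groups, fp32_groups):
    # analogue of load_state_dict(): restore fp32 masters, then refresh fp16
    optim.load_state_dict(states["base_optim_states"])
    for idx, saved in states["flat_fp32_weights"].items():
        for live, ckpt in zip(fp32_groups[idx], saved):
            live.data.copy_(ckpt.data)
        for p16, p32 in zip(fp16_groups[idx], fp32_groups[idx]):
            p16.data.copy_(p32.data)

# toy usage: synthetic fp16 grads stand in for a real FSDP backward pass
model = torch.nn.Linear(4, 4).to(torch.float16)
optim = torch.optim.SGD(model.parameters(), lr=1e-2)
fp16_groups, fp32_groups = build_master_copies(optim)
for p in model.parameters():
    p.grad = torch.randn(p.shape).to(p.dtype)
master_step(optim, fp16_groups, fp32_groups, loss_scale=1.0)
ckpt = save_state(optim, fp32_groups)
load_state(ckpt, optim, fp16_groups, fp32_groups)

Keeping a separate fp32 master copy is what lets small updates survive: an fp16-only step would lose increments below fp16 precision, while the fp32 copy accumulates them and the fp16 weights are refreshed from it after every step.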