From f3f2511e74d7fe0271fbf450ae85579acdde05b8 Mon Sep 17 00:00:00 2001
From: huangting4201 <1538303371@qq.com>
Date: Tue, 26 Sep 2023 15:46:47 +0800
Subject: [PATCH] feat(solver/optimizer): add new file fsdp_optimizer.py

---
 internlm/solver/optimizer/__init__.py       |   9 +-
 internlm/solver/optimizer/base_optimizer.py |  46 ++++
 internlm/solver/optimizer/fsdp_optimizer.py | 220 ++++++++++++++++
 .../solver/optimizer/hybrid_zero_optim.py   | 244 +-----------------
 internlm/train/training_internlm.py         |   3 +
 train.py                                    |   3 -
 6 files changed, 273 insertions(+), 252 deletions(-)
 create mode 100644 internlm/solver/optimizer/base_optimizer.py
 create mode 100644 internlm/solver/optimizer/fsdp_optimizer.py

diff --git a/internlm/solver/optimizer/__init__.py b/internlm/solver/optimizer/__init__.py
index c4a1eb7..7c6a1c6 100644
--- a/internlm/solver/optimizer/__init__.py
+++ b/internlm/solver/optimizer/__init__.py
@@ -1,10 +1,7 @@
 #!/usr/bin/env python
 # -*- encoding: utf-8 -*-
 
-from .hybrid_zero_optim import (
-    FSDPadaptOptimizer,
-    HybridZeroOptimizer,
-    reload_zero_fp32_buff,
-)
+from .fsdp_optimizer import FSDPadaptOptimizer
+from .hybrid_zero_optim import HybridZeroOptimizer, reload_zero_fp32_buff
 
-__all__ = ["HybridZeroOptimizer", "FSDPadaptOptimizer", "reload_zero_fp32_buff"]
+__all__ = ["FSDPadaptOptimizer", "HybridZeroOptimizer", "reload_zero_fp32_buff"]
diff --git a/internlm/solver/optimizer/base_optimizer.py b/internlm/solver/optimizer/base_optimizer.py
new file mode 100644
index 0000000..61d26ca
--- /dev/null
+++ b/internlm/solver/optimizer/base_optimizer.py
@@ -0,0 +1,46 @@
+#!/usr/bin/env python
+# -*- encoding: utf-8 -*-
+
+import torch
+from torch.optim import Optimizer
+
+
+class BaseOptimizer(Optimizer):
+    """
+    Base Optimizer.
+    """
+
+    def __init__(self, optim: Optimizer):  # pylint: disable=W0231
+        self.optim = optim
+
+    @property
+    def param_groups(self):
+        return self.optim.param_groups
+
+    @property
+    def defaults(self):
+        return self.optim.defaults
+
+    def add_param_group(self, *args, **kwargs):
+        return self.optim.add_param_group(*args, **kwargs)
+
+    def step(self, *args, **kwargs):
+        return self.optim.step(*args, **kwargs)
+
+    def zero_grad(self, *args, **kwargs):
+        self.optim.zero_grad(*args, **kwargs)
+
+    def load_state_dict(self, *args, **kwargs):
+        self.optim.load_state_dict(*args, **kwargs)
+
+    def state_dict(self):
+        return self.optim.state_dict()
+
+    def backward(self, loss):
+        loss.backward()
+
+    def backward_by_grad(self, tensor, grad):
+        torch.autograd.backward(tensors=tensor, grad_tensors=grad)
+
+    def clip_grad_norm(self):
+        pass
diff --git a/internlm/solver/optimizer/fsdp_optimizer.py b/internlm/solver/optimizer/fsdp_optimizer.py
new file mode 100644
index 0000000..c08b584
--- /dev/null
+++ b/internlm/solver/optimizer/fsdp_optimizer.py
@@ -0,0 +1,220 @@
+#!/usr/bin/env python
+# -*- encoding: utf-8 -*-
+
+from torch.optim import Optimizer
+
+from internlm.core.context import Config, ParallelMode
+from internlm.core.context import global_context as gpc
+from internlm.solver.optimizer.utils import (
+    DynamicGradScaler,
+    reduce_tensor,
+    release_param_grad,
+)
+from internlm.utils.logger import get_logger
+
+from .base_optimizer import BaseOptimizer
+from .utils import compute_norm
+
+logger = get_logger(__file__)
+
+
+class FSDPadaptOptimizer(BaseOptimizer):
+    """
+    Optimizer for PyTorch FSDP, used when 'use_fsdp' is True in the config file.
+    It reserves the necessary components of the hybrid-zero optimizer:
+        grad_scaler;
+        grad clipping and unscaling;
+        state_dict and load_state_dict.
+    """
+
+    def __init__(
+        self,
+        optimizer: Optimizer,
+        grad_scal_cfg: Config = None,
+        zero_cfg: Config = None,
+    ):
+        super().__init__(optim=optimizer)
+
+        # gradient scaler
+        self.grad_scaler = DynamicGradScaler(
+            initial_scale=grad_scal_cfg.fp16.initial_scale,
+            min_scale=grad_scal_cfg.fp16.min_scale,
+            growth_factor=grad_scal_cfg.growth_factor,
+            backoff_factor=grad_scal_cfg.backoff_factor,
+            growth_interval=grad_scal_cfg.fp16.growth_interval,
+            hysteresis=grad_scal_cfg.hysteresis,
+            max_scale=grad_scal_cfg.max_scale,
+        )
+
+        # gradient clipping
+        self._clip_grad_norm = zero_cfg.clip_grad_norm
+        self.use_fsdp = gpc.config.parallel.use_fsdp
+
+        # mark whether a module is part of TP or not
+        # TODO: is_tensor_parallel_dict = dict()
+
+        # fp16 and fp32 params
+        # fp16 params share mem space with model.FlatParam; fp32 params share mem space with optim.param_groups
+        self._fp16_param_groups = dict()
+        self._fp32_param_tensor_groups = dict()
+
+        # init fp16 and fp32 params
+        for group_idx, param_group in enumerate(self.optim.param_groups):
+            group_params = param_group["params"]
+
+            # fp16 FlatParam storage
+            self._fp16_param_groups[group_idx] = group_params
+
+            # create a copy of the fp32 weights
+            fp32_tensor_param = [param.data.float().requires_grad_(True) for param in group_params]
+            self._fp32_param_tensor_groups[group_idx] = fp32_tensor_param
+
+            # replace the optimizer's params with the fp32 copies
+            param_group["params"] = fp32_tensor_param
+
+    @property
+    def loss_scale(self):
+        return self.grad_scaler.scale
+
+    def backward(self, loss, retain_graph=False):
+        loss = self.loss_scale * loss
+        loss.backward(retain_graph=retain_graph)
+
+    def _compute_norm_with_fsdp_flatten(self, group_id):
+        params = self._fp16_param_groups[group_id]
+        gradients = [p.grad for p in params]
+        norm_group = compute_norm(gradients=gradients, parameters=params, last_stage=True)
+
+        return norm_group
+
+    def zero_grad(self):
+        for _, param_group in self._fp16_param_groups.items():
+            for param in param_group:
+                param.grad = None
+
+    def step(self):
+        # In case the fsdp-zero3 size is not equal to the dp size, the FSDP module only reduces
+        # gradients within its own FSDP process group, so gradients must also be reduced manually
+        # across parallel FSDP process groups.
+        for group_idx in range(len(self.param_groups)):
+            params = self._fp16_param_groups[group_idx]
+            for param in params:
+                if param.requires_grad:
+                    reduce_tensor(tensor=param.grad, parallel_mode=ParallelMode.ZERO3_DP)
+
+        # compute norm
+        found_inf = False
+        norm_groups = {}
+        for group_idx in range(len(self.param_groups)):
+            group_name = self.param_groups[group_idx]["name"] if "name" in self.param_groups[group_idx] else "default"
+            group_name = f"{group_idx}_{group_name}"
+            norm_group = self._compute_norm_with_fsdp_flatten(group_idx)
+            if norm_group == -1:
+                found_inf = True
+            norm_groups[group_name] = norm_group
+
+        loss_scale = float(self.loss_scale.item())  # backup
+        self.grad_scaler.update(found_inf)
+        if found_inf:
+            if gpc.is_rank_for_log():
+                logger.warning("Overflow occurs, please check it.")
+            self.zero_grad()
+            return False, norm_groups
+
+        # get the global norm
+        global_norm_groups = {}
+        if self._clip_grad_norm > 0:
+            for group_name, norm in norm_groups.items():
+                global_norm_groups[group_name] = norm**0.5
+
+        # create gradients for the fp32 params
+        for group_idx in range(len(self.param_groups)):
+            dtype = self._fp32_param_tensor_groups[group_idx][0].dtype
+            fp16_params = self._fp16_param_groups[group_idx]
+            grad_fp32 = [p.grad.to(dtype) for p in fp16_params]
+
+            device = self._fp32_param_tensor_groups[group_idx][0].device
+            for p, g in zip(self._fp32_param_tensor_groups[group_idx], grad_fp32):
+                p.grad = g.to(device)
+
+        # unscale and clip gradients
+        self._unscale_and_clip_grads(list(global_norm_groups.values()), loss_scale)
+
+        self.optim.step()
+        self.zero_grad()
+
+        for group_idx in range(len(self._fp16_param_groups)):
+            fp16_params = self._fp16_param_groups[group_idx]
+            fp32_tensor_params = self._fp32_param_tensor_groups[group_idx]
+            # release fp32 grads
+            release_param_grad(fp32_tensor_params)
+            # update fp16 params
+            for p, q in zip(fp16_params, fp32_tensor_params):
+                p.data.copy_(q)
+
+        for group_name, global_norm in global_norm_groups.items():
+            global_norm_groups[group_name] = global_norm / loss_scale
+        return True, global_norm_groups
+
+    def clip_grad_norm(self, model, max_norm):
+        # gradient clipping is performed in step()
+        pass
+
+    ##########################
+    # utils from hybrid zero #
+    ##########################
+
+    def _unscale_and_clip_grads(self, total_norm_groups, loss_scale):
+        # compute the combined scale factor for each group
+        combined_scale_groups = []
+
+        if self._clip_grad_norm > 0.0:
+            # norm is in fact norm*scale
+            for group_id, total_norm in enumerate(total_norm_groups):
+                combined_scale_groups.append(loss_scale)
+                clip = ((total_norm / loss_scale) + 1e-6) / self._clip_grad_norm
+                if clip > 1.0:
+                    combined_scale_groups[group_id] = clip * loss_scale
+
+        for group_id, grads in self._fp32_param_tensor_groups.items():
+            for g in grads:
+                g.grad.data.mul_(1.0 / combined_scale_groups[group_id])
+
+    def state_dict(self):
+        states = {}
+        grad_scaler = self.grad_scaler.state_dict()
+        states["grad_scaler"] = grad_scaler
+        optim_states = self.optim.state_dict()
+        states["base_optim_states"] = optim_states
+
+        flat_fp32_weights = {}
+        for group_idx, param in self._fp32_param_tensor_groups.items():
+            flat_fp32_weights[group_idx] = param
+        states["flat_fp32_weights"] = flat_fp32_weights
+
+        return states
+
+    def load_state_dict(self, states):
+        assert "grad_scaler" in states, "Not found grad_scaler state!"
+        grad_scaler = states["grad_scaler"]
+        self.grad_scaler.load_state_dict(grad_scaler)
+        optim_states = states["base_optim_states"]
+        self.optim.load_state_dict(optim_states)
+
+        # load fp32 optimizer weights
+        flat_fp32_weights = states["flat_fp32_weights"]
+        assert set(flat_fp32_weights.keys()) == set(self._fp32_param_tensor_groups)
+        for group_idx, param in flat_fp32_weights.items():
+            self_param = self._fp32_param_tensor_groups[group_idx]
+            assert len(self_param) == len(
+                param
+            ), f"The number of flat tensor is inconsistent, {len(self_param)} != {len(param)}"
+            for p, q in zip(self_param, param):
+                p.data.copy_(q.data)
+
+        # load fp16 model weights
+        for group_idx, param in flat_fp32_weights.items():
+            fp16_param = self._fp16_param_groups[group_idx]
+            fp32_param = self._fp32_param_tensor_groups[group_idx]
+            for p, q in zip(fp16_param, fp32_param):
+                p.data.copy_(q.data)
diff --git a/internlm/solver/optimizer/hybrid_zero_optim.py b/internlm/solver/optimizer/hybrid_zero_optim.py
index 3b384f3..9408fc3 100644
--- a/internlm/solver/optimizer/hybrid_zero_optim.py
+++ b/internlm/solver/optimizer/hybrid_zero_optim.py
@@ -34,255 +34,13 @@ from internlm.utils.logger import get_logger
 from internlm.utils.megatron_timers import megatron_timer as timer
 from internlm.utils.timeout import llm_timeout
 
+from .base_optimizer import BaseOptimizer
 from .utils import compute_norm
 
 inf = math.inf
 logger = get_logger(__file__)
 
 
-class BaseOptimizer(Optimizer):
-    """
-    Base Optimizer.
- """ - - def __init__(self, optim: Optimizer): # pylint: disable=W0231 - self.optim = optim - - @property - def param_groups(self): - return self.optim.param_groups - - @property - def defaults(self): - return self.optim.defaults - - def add_param_group(self, *args, **kwargs): - return self.optim.add_param_group(*args, **kwargs) - - def step(self, *args, **kwargs): - return self.optim.step(*args, **kwargs) - - def zero_grad(self, *args, **kwargs): - self.optim.zero_grad(*args, **kwargs) - - def load_state_dict(self, *args, **kwargs): - self.optim.load_state_dict(*args, **kwargs) - - def state_dict(self): - return self.optim.state_dict() - - def backward(self, loss): - loss.backward() - - def backward_by_grad(self, tensor, grad): - torch.autograd.backward(tensors=tensor, grad_tensors=grad) - - def clip_grad_norm(self): - pass - - -class FSDPadaptOptimizer(BaseOptimizer): - """ - optimizer for Pytorch FSDP if 'use_fsdp' is True in config file - reserve some necessary components of hybird-optim: - grad_scaler; - grad_clip and unscale; - state_dict and load_state_dict - """ - - def __init__( - self, - optimizer: Optimizer, - grad_scal_cfg: Config = None, - zero_cfg: Config = None, - ): - super().__init__(optim=optimizer) - - # gradient scaler - self.grad_scaler = DynamicGradScaler( - initial_scale=grad_scal_cfg.fp16.initial_scale, - min_scale=grad_scal_cfg.fp16.min_scale, - growth_factor=grad_scal_cfg.growth_factor, - backoff_factor=grad_scal_cfg.backoff_factor, - growth_interval=grad_scal_cfg.fp16.growth_interval, - hysteresis=grad_scal_cfg.hysteresis, - max_scale=grad_scal_cfg.max_scale, - ) - - # clip gradient - self._clip_grad_norm = zero_cfg.clip_grad_norm - self.use_fsdp = gpc.config.parallel.use_fsdp - - # mark whether a module is part of TP or not - # TODO: is_tensor_parallel_dict = dict() - - # fp16 and fp32 params - # fp16 share mem space with model.FlatParam, fp32 share mem space with optim.param_group - self._fp16_param_groups = dict() - self._fp32_param_tensor_groups = dict() - - # init fp16 and fp32 params - for group_idx, param_group in enumerate(self.optim.param_groups): - group_params = param_group["params"] - - # fp16 FlatParam storage - self._fp16_param_groups[group_idx] = group_params - - # create copy of fp32 weight - fp32_tensor_param = [param.data.float().requires_grad_(True) for param in group_params] - self._fp32_param_tensor_groups[group_idx] = fp32_tensor_param - - # replace - param_group["params"] = fp32_tensor_param - - @property - def loss_scale(self): - return self.grad_scaler.scale - - def backward(self, loss, retain_graph=False): - loss = self.loss_scale * loss - loss.backward(retain_graph=retain_graph) - - def _compute_norm_with_fsdp_flatten(self, group_id): - params = self._fp16_param_groups[group_id] - gradients = [p.grad for p in params] - norm_group = compute_norm(gradients=gradients, parameters=params, last_stage=True) - - return norm_group - - def zero_grad(self): - for _, param_group in self._fp16_param_groups.items(): - for param in param_group: - param.grad = None - - def step(self): - # in case that fsdp-zero3 size is not equal to dp size - # FSDP module will only reduce gradient within FSDP process group - # so manually reduce grad is essential between two parallel FSDP process group - for group_idx in range(len(self.param_groups)): - params = self._fp16_param_groups[group_idx] - for param in params: - if param.requires_grad: - reduce_tensor(tensor=param.grad, parallel_mode=ParallelMode.ZERO3_DP) - - # compute norm - found_inf = False - norm_groups 
= {} - for group_idx in range(len(self.param_groups)): - group_name = self.param_groups[group_idx]["name"] if "name" in self.param_groups[group_idx] else "default" - group_name = f"{group_idx}_{group_name}" - norm_group = self._compute_norm_with_fsdp_flatten(group_idx) - if norm_group == -1: - found_inf = True - norm_groups[group_name] = norm_group - - loss_scale = float(self.loss_scale.item()) # backup - self.grad_scaler.update(found_inf) - if found_inf: - if gpc.is_rank_for_log(): - logger.warning("Overflow occurs, please check it.") - self.zero_grad() - return False, norm_groups - - # get the global norm - global_norm_groups = {} - if self._clip_grad_norm > 0: - for group_name, norm in norm_groups.items(): - global_norm_groups[group_name] = norm**0.5 - - # create gradient for fp32 params - for group_idx in range(len(self.param_groups)): - dtype = self._fp32_param_tensor_groups[group_idx][0].dtype - fp16_params = self._fp16_param_groups[group_idx] - grad_fp32 = [p.grad.to(dtype) for p in fp16_params] - - device = self._fp32_param_tensor_groups[group_idx][0].device - for p, g in zip(self._fp32_param_tensor_groups[group_idx], grad_fp32): - p.grad = g.to(device) - - # unscale - self._unscale_and_clip_grads(list(global_norm_groups.values()), loss_scale) - - self.optim.step() - self.zero_grad() - - for group_idx in range(len(self._fp16_param_groups)): - fp16_params = self._fp16_param_groups[group_idx] - fp32_tensor_params = self._fp32_param_tensor_groups[group_idx] - # release fp32 grad - release_param_grad(fp32_tensor_params) - # update fp16 param - for p, q in zip(fp16_params, fp32_tensor_params): - p.data.copy_(q) - - for group_name, global_norm in global_norm_groups.items(): - global_norm_groups[group_name] = global_norm / loss_scale - return True, global_norm_groups - - def clip_grad_norm(self, model, max_norm): - # will conduct in the step() - pass - - ######################### - # utils from hybirdzero # - ######################### - - def _unscale_and_clip_grads(self, total_norm_groups, loss_scale): - # compute combined scale factor for this group - combined_scale_groups = [] - - if self._clip_grad_norm > 0.0: - # norm is in fact norm*scale - for group_id, total_norm in enumerate(total_norm_groups): - combined_scale_groups.append(loss_scale) - clip = ((total_norm / loss_scale) + 1e-6) / self._clip_grad_norm - if clip > 1.0: - combined_scale_groups[group_id] = clip * loss_scale - - for group_id, grads in self._fp32_param_tensor_groups.items(): - for g in grads: - g.grad.data.mul_(1.0 / combined_scale_groups[group_id]) - - def state_dict(self): - states = {} - grad_scaler = self.grad_scaler.state_dict() - states["grad_scaler"] = grad_scaler - optim_states = self.optim.state_dict() - states["base_optim_states"] = optim_states - - flat_fp32_weights = {} - for group_idx, param in self._fp32_param_tensor_groups.items(): - flat_fp32_weights[group_idx] = param - states["flat_fp32_weights"] = flat_fp32_weights - - return states - - def load_state_dict(self, states): - assert "grad_scaler" in states, "Not found grad_scaler state!" 
- grad_scaler = states["grad_scaler"] - self.grad_scaler.load_state_dict(grad_scaler) - optim_states = states["base_optim_states"] - self.optim.load_state_dict(optim_states) - - # load fp32 optimizer weight - flat_fp32_weights = states["flat_fp32_weights"] - assert set(flat_fp32_weights.keys()) == set(self._fp32_param_tensor_groups) - for group_idx, param in flat_fp32_weights.items(): - self_param = self._fp32_param_tensor_groups[group_idx] - assert len(self_param) == len( - param - ), f"The number of flat tensor is inconsistent, {len(self_param)} != {len(param)}" - for p, q in zip(self_param, param): - p.data.copy_(q.data) - - # load fp16 model weight - for group_idx, param in flat_fp32_weights.items(): - fp16_param = self._fp16_param_groups[group_idx] - fp32_param = self._fp32_param_tensor_groups[group_idx] - for p, q in zip(fp16_param, fp32_param): - p.data.copy_(q.data) - - class HybridZeroOptimizer(BaseOptimizer): """ Hybrid Zero Optimizer. diff --git a/internlm/train/training_internlm.py b/internlm/train/training_internlm.py index b82aef2..ba08abf 100644 --- a/internlm/train/training_internlm.py +++ b/internlm/train/training_internlm.py @@ -100,6 +100,9 @@ def initialize_model(): # state in the same dp group are all the same. set_mode(ParallelMode.DATA) + # if fsdp enabled, wrap the model + model = wrap_FSDP_model(model) + return model diff --git a/train.py b/train.py index 7043fa0..9c61acf 100644 --- a/train.py +++ b/train.py @@ -111,9 +111,6 @@ def main(args): # initialize and resume train state train_state = TrainState(gpc.config, train_dl.batch_sampler) - # if fsdp enabled, wrap the model - model = wrap_FSDP_model(model) - optimizer, beta2_scheduler, lr_scheduler = initialize_optimizer(model=model) ckpt_manager = CheckpointManager(