From 4e8bd39d8facc3b2c61ab0e53cc5320da79ec41e Mon Sep 17 00:00:00 2001
From: cx <759046501@qq.com>
Date: Fri, 11 Aug 2023 17:46:07 +0800
Subject: [PATCH] refactor(solver/optimizer): improve optimizer memory (#193)

* refactor(solver/optimizer): improve optimizer memory

* feat(data): remove useless dataset type ids map
---
 internlm/data/utils.py                         |  2 +-
 internlm/initialize/launch.py                  |  2 +-
 internlm/solver/optimizer/hybrid_zero_optim.py | 12 ++++++------
 3 files changed, 8 insertions(+), 8 deletions(-)

diff --git a/internlm/data/utils.py b/internlm/data/utils.py
index a86984a..3eee9d9 100644
--- a/internlm/data/utils.py
+++ b/internlm/data/utils.py
@@ -5,7 +5,7 @@ import torch
 
 from internlm.core.context import global_context as gpc
 
-DATASET_TYPE_IDS_MAP = {"en": 0, "cn": 1, "code": 2, "ja": 3, "ar": 4, "kaoshi": 5}
+DATASET_TYPE_IDS_MAP = {"en": 0, "cn": 1}
 
 
 def get_dataset_type_id(path):
diff --git a/internlm/initialize/launch.py b/internlm/initialize/launch.py
index 33b5d15..d3ea708 100644
--- a/internlm/initialize/launch.py
+++ b/internlm/initialize/launch.py
@@ -39,7 +39,7 @@ def get_default_parser():
     parser.add_argument("--local_rank", type=int, help="local rank on the node")
     parser.add_argument("--backend", type=str, default="nccl", help="backend for distributed communication")
     parser.add_argument("--seed", type=int, default=1024)
-    parser.add_argument("--profiling", default=True, action="store_true", help="enable/diable profiling.")
+    parser.add_argument("--profiling", default=False, action="store_true", help="enable/disable profiling.")
     return parser
 
 
diff --git a/internlm/solver/optimizer/hybrid_zero_optim.py b/internlm/solver/optimizer/hybrid_zero_optim.py
index 618b772..9d42a98 100644
--- a/internlm/solver/optimizer/hybrid_zero_optim.py
+++ b/internlm/solver/optimizer/hybrid_zero_optim.py
@@ -9,6 +9,7 @@ from torch.optim import Optimizer
 
 from internlm.core.context import Config, ParallelMode
 from internlm.core.context import global_context as gpc
+from internlm.monitor import send_alert_message
 from internlm.solver.optimizer.store import (
     BucketStore,
     GradientStore,
@@ -28,7 +29,6 @@ from internlm.solver.optimizer.utils import (
 from internlm.utils.common import get_current_device
 from internlm.utils.logger import get_logger
 from internlm.utils.megatron_timers import megatron_timer as timer
-from internlm.monitor import send_alert_message
 
 from .utils import compute_norm
 
@@ -556,14 +556,16 @@ class HybridZeroOptimizer(BaseOptimizer):
             # The following operations are performed only on the rank to which parameters are assigned.
             if not self.param_group_has_params[group_id]:
                 continue
-            gradients = self._grad_store.get_averaged_gradients_by_group(group_id)
 
             # create flat gradient for the flat fp32 params
-            fp16_avg_grads = gradients
-            flat_fp16_avg_grads = flatten(fp16_avg_grads)
+            gradients = self._grad_store.get_averaged_gradients_by_group(group_id)
+            flat_fp16_avg_grads = flatten(gradients)
+            self._grad_store.reset_average_gradients_by_group(group_id)
+            del gradients  # release cuda memory
 
             dtype = self._fp32_flat_param_groups_of_current_rank[group_id].dtype
             flat_fp32_avg_grads = flat_fp16_avg_grads.to(dtype)
+            del flat_fp16_avg_grads  # release cuda memory
 
             param_shape = self._fp32_flat_param_groups_of_current_rank[group_id].shape
             assert (
@@ -573,8 +575,6 @@ class HybridZeroOptimizer(BaseOptimizer):
             single_grad_partition_groups.append(flat_fp32_avg_grads)
             device = self._fp32_flat_param_groups_of_current_rank[group_id].device
             self._fp32_flat_param_groups_of_current_rank[group_id].grad = flat_fp32_avg_grads.to(device)
-            self._grad_store._averaged_gradients[group_id] = []
-            self._grad_store._averaged_gradients[group_id] = []
 
         # unscale and clip grads
         # get the global norm
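
Note on the pattern in the hybrid_zero_optim.py hunks (not part of the patch itself): peak CUDA memory drops because each intermediate gradient buffer is released as soon as it has been consumed, instead of keeping the per-parameter fp16 grads, the flat fp16 grad, and the flat fp32 grad alive simultaneously. Below is a minimal standalone sketch of that idea; it uses torch's _flatten_dense_tensors in place of InternLM's flatten helper, and the names avg_grads / fp32_master are hypothetical, not from the repository.

import torch
from torch._utils import _flatten_dense_tensors as flatten

def build_fp32_flat_grad(avg_grads, fp32_master):
    # Flatten the per-parameter fp16 averaged gradients into one contiguous buffer.
    flat_fp16 = flatten(avg_grads)
    # Drop the per-parameter references before allocating the fp32 copy,
    # mirroring reset_average_gradients_by_group(...) plus `del gradients`.
    avg_grads.clear()
    # Cast to the fp32 master dtype, then free the fp16 flat buffer as well.
    flat_fp32 = flat_fp16.to(fp32_master.dtype)
    del flat_fp16
    assert flat_fp32.shape == fp32_master.shape
    return flat_fp32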