mirror of https://github.com/InternLM/InternLM

refactor(solver/optimizer): improve optimizer memory (#193)

* refactor(solver/optimizer): improve optimizer memory
* feat(data): remove useless dataset type ids map

pull/188/head^2
parent 5f3133fac8
commit 4e8bd39d8f
@@ -5,7 +5,7 @@ import torch
 
 from internlm.core.context import global_context as gpc
 
-DATASET_TYPE_IDS_MAP = {"en": 0, "cn": 1, "code": 2, "ja": 3, "ar": 4, "kaoshi": 5}
+DATASET_TYPE_IDS_MAP = {"en": 0, "cn": 1}
 
 
 def get_dataset_type_id(path):
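The map now only distinguishes English ("en") and Chinese ("cn") data; the unused type ids for code, Japanese, Arabic, and exam ("kaoshi") data are dropped. For context, below is a minimal sketch of how such a map is typically consumed. The lookup rule is hypothetical and is not the repo's actual get_dataset_type_id implementation:

# Hypothetical sketch of a path -> dataset type id lookup built on the trimmed map.
DATASET_TYPE_IDS_MAP = {"en": 0, "cn": 1}

def get_dataset_type_id(path: str) -> int:
    # assumed rule: the dataset type name appears as a directory component of the path
    for name, type_id in DATASET_TYPE_IDS_MAP.items():
        if name in path.split("/"):
            return type_id
    raise KeyError(f"no dataset type id matched for path: {path}")

print(get_dataset_type_id("/data/tokenized/en/train"))  # -> 0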
@@ -39,7 +39,7 @@ def get_default_parser():
     parser.add_argument("--local_rank", type=int, help="local rank on the node")
     parser.add_argument("--backend", type=str, default="nccl", help="backend for distributed communication")
     parser.add_argument("--seed", type=int, default=1024)
-    parser.add_argument("--profiling", default=True, action="store_true", help="enable/diable profiling.")
+    parser.add_argument("--profiling", default=False, action="store_true", help="enable/disable profiling.")
     return parser
 
 
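Besides the spelling fix, the default matters: with action="store_true", argparse sets the destination to the default when the flag is absent and to True when it is present, so default=True meant profiling could never be switched off from the command line. A quick standalone check (not repo code):

import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--profiling", default=False, action="store_true", help="enable/disable profiling.")

print(parser.parse_args([]).profiling)               # False: off unless explicitly requested
print(parser.parse_args(["--profiling"]).profiling)  # True: the flag now actually toggles profiling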
@@ -9,6 +9,7 @@ from torch.optim import Optimizer
 
 from internlm.core.context import Config, ParallelMode
 from internlm.core.context import global_context as gpc
+from internlm.monitor import send_alert_message
 from internlm.solver.optimizer.store import (
     BucketStore,
     GradientStore,
@@ -28,7 +29,6 @@ from internlm.solver.optimizer.utils import (
 from internlm.utils.common import get_current_device
 from internlm.utils.logger import get_logger
 from internlm.utils.megatron_timers import megatron_timer as timer
-from internlm.monitor import send_alert_message
 
 from .utils import compute_norm
 
@@ -556,14 +556,16 @@ class HybridZeroOptimizer(BaseOptimizer):
             # The following operations are performed only on the rank to which parameters are assigned.
             if not self.param_group_has_params[group_id]:
                 continue
-            gradients = self._grad_store.get_averaged_gradients_by_group(group_id)
 
             # create flat gradient for the flat fp32 params
-            fp16_avg_grads = gradients
-            flat_fp16_avg_grads = flatten(fp16_avg_grads)
+            gradients = self._grad_store.get_averaged_gradients_by_group(group_id)
+            flat_fp16_avg_grads = flatten(gradients)
+            self._grad_store.reset_average_gradients_by_group(group_id)
+            del gradients  # release cuda memory
 
             dtype = self._fp32_flat_param_groups_of_current_rank[group_id].dtype
             flat_fp32_avg_grads = flat_fp16_avg_grads.to(dtype)
+            del flat_fp16_avg_grads  # release cuda memory
 
             param_shape = self._fp32_flat_param_groups_of_current_rank[group_id].shape
             assert (
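This is the core of the memory improvement in the HybridZeroOptimizer step: the averaged fp16 gradients are fetched, flattened, and converted to the fp32 master dtype, and each intermediate buffer is released (the store is reset and the Python references are deleted) as soon as it is no longer needed, so the caching allocator can reuse that GPU memory before the optimizer update. Below is a standalone sketch of the pattern, using torch.cat in place of the repo's flatten helper:

import torch

# Stand-in for one parameter group's averaged fp16 gradient shards.
fp16_grads = [torch.randn(1024, dtype=torch.float16) for _ in range(4)]

flat_fp16 = torch.cat([g.reshape(-1) for g in fp16_grads])  # one contiguous fp16 buffer
del fp16_grads                                              # drop the per-tensor fp16 copies early

flat_fp32 = flat_fp16.to(torch.float32)  # dtype of the flat fp32 master parameters
del flat_fp16                            # the fp16 flat buffer is no longer needed

# flat_fp32 would then be attached as the .grad of the flat fp32 parameter group
# before calling the wrapped optimizer's step().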
@@ -573,8 +575,6 @@ class HybridZeroOptimizer(BaseOptimizer):
             single_grad_partition_groups.append(flat_fp32_avg_grads)
             device = self._fp32_flat_param_groups_of_current_rank[group_id].device
             self._fp32_flat_param_groups_of_current_rank[group_id].grad = flat_fp32_avg_grads.to(device)
-            self._grad_store._averaged_gradients[group_id] = []
-            self._grad_store._averaged_gradients[group_id] = []
 
             # unscale and clip grads
             # get the global norm
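The duplicated manual resets of _averaged_gradients[group_id] go away because the store is now cleared via reset_average_gradients_by_group(group_id) right after flattening (see the previous hunk). A minimal sketch of what that store method plausibly does, inferred from the removed lines rather than copied from the repo:

class GradientStore:
    """Sketch only: holds the averaged gradients of each parameter group."""

    def __init__(self):
        self._averaged_gradients = {}  # group_id -> list of gradient tensors

    def get_averaged_gradients_by_group(self, group_id):
        return self._averaged_gradients.setdefault(group_id, [])

    def reset_average_gradients_by_group(self, group_id):
        # does once, inside the store, what the removed lines did twice at the call site
        self._averaged_gradients[group_id] = []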