diff --git a/internlm/core/naive_amp.py b/internlm/core/naive_amp.py
index 04f3ee2..5134245 100644
--- a/internlm/core/naive_amp.py
+++ b/internlm/core/naive_amp.py
@@ -164,7 +164,7 @@ class NaiveAMPModel(nn.Module):
         assert isinstance(outputs, (Tensor, tuple))
         if isinstance(outputs, tuple):
             for output_data_ in outputs:
-                if isinstance(output_data_, Tensor) and output_data_.dtype is not self.dtype:
+                if isinstance(output_data_, Tensor):
                     outputs_.append(output_data_.to(self.dtype))
                 else:
                     outputs_.append(output_data_)
diff --git a/internlm/train/training_internlm.py b/internlm/train/training_internlm.py
index 416c8f2..883129d 100644
--- a/internlm/train/training_internlm.py
+++ b/internlm/train/training_internlm.py
@@ -31,7 +31,8 @@ from internlm.solver.beta2_scheduler import Beta2Scheduler
 from internlm.solver.lr_scheduler import FineTuneCosineAnnealingWarmupLR
 from internlm.solver.optimizer import HybridZeroOptimizer
 from internlm.solver.optimizer.utils import ParamBcastSyncHandler
-from internlm.utils.common import DummyProfile, create_param_groups
+from internlm.train.utils import create_param_groups
+from internlm.utils.common import DummyProfile
 from internlm.utils.logger import get_logger
 from internlm.utils.megatron_timers import megatron_timer as timer
 from internlm.utils.parallel import (
diff --git a/internlm/train/utils.py b/internlm/train/utils.py
new file mode 100644
index 0000000..2021395
--- /dev/null
+++ b/internlm/train/utils.py
@@ -0,0 +1,58 @@
+from typing import Dict, Tuple
+
+import torch
+
+
+def split_params_into_different_groups_for_optimizer(param_groups: Tuple[Dict]) -> Tuple[Dict]:
+    """Split parameters into different groups for optimizer.
+    Compatible with multiple param groups; each group should have a name.
+
+    Args:
+        param_groups (Tuple[Dict]):
+            The list of parameter groups to split
+
+    Returns:
+        Tuple[Dict]:
+            list of fp16/fp32 groups for optimizer
+    """
+    if isinstance(param_groups, tuple):
+        param_groups = list(param_groups)  # Tuple cannot be modified
+    elif isinstance(param_groups, dict):
+        param_groups = [param_groups]
+    elif not isinstance(param_groups, list):
+        raise ValueError(f"Unknown param group type of {type(param_groups)}")
+
+    # Create an fp32 group for each origin group and copy its attributes
+    for group_param in list(param_groups):  # iterate over a snapshot: the appends below must not extend this loop
+        fp32_group = {}
+
+        # copy attribute for fp32 group
+        for ori_key in group_param.keys():
+            if ori_key == "name":
+                fp32_group["name"] = group_param[ori_key] + "_fp32"
+            else:
+                if ori_key == "params":
+                    fp32_group[ori_key] = []
+                else:
+                    fp32_group[ori_key] = group_param[ori_key]
+
+        # Assign param
+        new_params = []
+        for param in group_param["params"]:
+            if param.dtype == torch.float32:
+                fp32_group["params"].append(param)
+            else:
+                new_params.append(param)
+
+        # origin group without fp32
+        group_param["params"] = new_params
+        # append to origin group
+        param_groups.append(fp32_group)
+
+    return tuple(param_groups)
+
+
+def create_param_groups(model, weight_decay):
+    parameters = {"params": list(model.parameters()), "name": "default", "weight_decay": weight_decay}
+
+    return split_params_into_different_groups_for_optimizer(parameters)
diff --git a/internlm/utils/common.py b/internlm/utils/common.py
index 4b9b047..f3b58c0 100644
--- a/internlm/utils/common.py
+++ b/internlm/utils/common.py
@@ -7,7 +7,7 @@ import os
 import random
 from contextlib import contextmanager
 from datetime import datetime
-from typing import Dict, Tuple, Union
+from typing import Union
 
 import numpy as np
 import torch
@@ -236,58 +236,3 @@ class DummyProfile:
 
     def step(self):
         pass
-
-
-def split_params_into_different_groups_for_optimizer(param_groups: Tuple[Dict]) -> Tuple[Dict]:
-    """Split parameters into different MoE groups for optimizer
-    Compatiable with muiltiple param groups, each should have a name
-
-    Args:
-        param_groups (Tuple[Dict]):
-            The list of parameter groups to split
-
-    Returns:
-        Tuple[Dict]:
-            list of MoE/non-MoE groups for optimizer
-    """
-    if isinstance(param_groups, tuple):
-        param_groups = list(param_groups)  # Tuple cannot be modified
-    elif isinstance(param_groups, dict):
-        param_groups = [param_groups]
-    elif not isinstance(param_groups, list):
-        raise ValueError(f"Unknown param group type of {type(param_groups)}")
-
-    fp32_group = {}
-    # Create fp32 and moe groups and copy origin attribute
-    for param_group in param_groups:
-        # copy attribute for fp32 group
-        fp32_group["name"] = "fp32"
-        fp32_group["gate"] = True
-        for ori_key in param_group.keys():
-            if ori_key != "name":
-                if ori_key == "params":
-                    fp32_group[ori_key] = []
-                else:
-                    fp32_group[ori_key] = param_group[ori_key]
-
-    # Assign param
-    for param_group in param_groups:
-        new_params = []
-        for param in param_group["params"]:
-            if param.dtype == torch.float32:
-                fp32_group["params"].append(param)
-            else:
-                new_params.append(param)
-        # origin group without fp32 or moe parameter
-        param_group["params"] = new_params
-
-    # append to origin group
-    param_groups.append(fp32_group)
-
-    return tuple(param_groups)
-
-
-def create_param_groups(model, weight_decay):
-    parameters = {"params": list(model.parameters()), "name": "default", "weight_decay": weight_decay}
-
-    return split_params_into_different_groups_for_optimizer(parameters)
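
Usage note (not part of the patch): a minimal sketch of how the relocated helper is expected to be wired into optimizer construction. The toy model, dtype mix, and the plain torch.optim.AdamW call are illustrative assumptions; only create_param_groups / split_params_into_different_groups_for_optimizer come from the new internlm/train/utils.py above.

    import torch
    from torch import nn

    from internlm.train.utils import create_param_groups

    # Toy module with mixed-precision parameters: the linear layer is cast to
    # fp16 while the LayerNorm keeps fp32 weights, so the split yields both a
    # low-precision group and an fp32 group.
    model = nn.Sequential(nn.Linear(8, 8).half(), nn.LayerNorm(8))

    # Returns a tuple of param-group dicts: the original "default" group keeps
    # the fp16 params, while a separate fp32 group collects the float32 params
    # so the optimizer can treat them differently.
    param_groups = create_param_groups(model, weight_decay=0.01)

    # Plain torch optimizer shown only for illustration; InternLM's training
    # code passes these groups into its HybridZeroOptimizer setup instead.
    optimizer = torch.optim.AdamW(param_groups, lr=1e-4)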