refactor code

pull/319/head
Qu Wenwen 2023-09-19 10:57:20 +08:00
parent 98329da327
commit f76fd41325
4 changed files with 62 additions and 58 deletions

@@ -164,7 +164,7 @@ class NaiveAMPModel(nn.Module):
         assert isinstance(outputs, (Tensor, tuple))
         if isinstance(outputs, tuple):
             for output_data_ in outputs:
-                if isinstance(output_data_, Tensor) and output_data_.dtype is not self.dtype:
+                if isinstance(output_data_, Tensor):
                     outputs_.append(output_data_.to(self.dtype))
                 else:
                     outputs_.append(output_data_)
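
Note (not part of the diff): dropping the explicit dtype check keeps the behavior unchanged, because torch.Tensor.to returns the input tensor itself when it already has the requested dtype. The snippet below only illustrates that PyTorch behavior; it is not code from this repository.

import torch

t = torch.ones(2, dtype=torch.float16)
assert t.to(torch.float16) is t      # already fp16: .to() returns the same tensor
assert t.to(torch.float32) is not t  # a real cast allocates a new tensor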

@@ -31,7 +31,8 @@ from internlm.solver.beta2_scheduler import Beta2Scheduler
 from internlm.solver.lr_scheduler import FineTuneCosineAnnealingWarmupLR
 from internlm.solver.optimizer import HybridZeroOptimizer
 from internlm.solver.optimizer.utils import ParamBcastSyncHandler
-from internlm.utils.common import DummyProfile, create_param_groups
+from internlm.train.utils import create_param_groups
+from internlm.utils.common import DummyProfile
 from internlm.utils.logger import get_logger
 from internlm.utils.megatron_timers import megatron_timer as timer
 from internlm.utils.parallel import (

internlm/train/utils.py (new file, 58 additions)

@@ -0,0 +1,58 @@
+from typing import Dict, Tuple
+
+import torch
+
+
+def split_params_into_different_groups_for_optimizer(param_groups: Tuple[Dict]) -> Tuple[Dict]:
+    """Split parameters into different groups for the optimizer.
+
+    Compatible with multiple param groups; each group should have a name.
+
+    Args:
+        param_groups (Tuple[Dict]):
+            The list of parameter groups to split.
+    Returns:
+        Tuple[Dict]:
+            List of fp16/fp32 groups for the optimizer.
+    """
+    if isinstance(param_groups, tuple):
+        param_groups = list(param_groups)  # tuple cannot be modified
+    elif isinstance(param_groups, dict):
+        param_groups = [param_groups]
+    elif not isinstance(param_groups, list):
+        raise ValueError(f"Unknown param group type of {type(param_groups)}")
+
+    # create a companion fp32 group for each origin group and copy its attributes
+    fp32_groups = []
+    for group_param in param_groups:
+        fp32_group = {}
+        for ori_key in group_param.keys():
+            if ori_key == "name":
+                fp32_group["name"] = group_param["name"] + "_fp32"
+            elif ori_key == "params":
+                fp32_group[ori_key] = []
+            else:
+                fp32_group[ori_key] = group_param[ori_key]
+
+        # move fp32 params into the companion group, keep the rest in the origin group
+        new_params = []
+        for param in group_param["params"]:
+            if param.dtype == torch.float32:
+                fp32_group["params"].append(param)
+            else:
+                new_params.append(param)
+
+        # origin group without fp32 params
+        group_param["params"] = new_params
+        fp32_groups.append(fp32_group)
+
+    # append the fp32 groups after the origin groups
+    param_groups.extend(fp32_groups)
+
+    return tuple(param_groups)
+
+
+def create_param_groups(model, weight_decay):
+    parameters = {"params": list(model.parameters()), "name": "default", "weight_decay": weight_decay}
+    return split_params_into_different_groups_for_optimizer(parameters)
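
Usage sketch (not part of the diff): the split groups can be fed to any optimizer that accepts parameter-group dicts. The toy model and the AdamW choice below are illustrative assumptions, not the training code in this repository.

import torch

from internlm.train.utils import create_param_groups

# toy module: fp16 weights with one parameter promoted to fp32
model = torch.nn.Linear(8, 8).half()
model.bias.data = model.bias.data.float()

groups = create_param_groups(model, weight_decay=0.01)
# -> ({"name": "default", "params": [fp16 weight], "weight_decay": 0.01},
#     {"name": "default_fp32", "params": [fp32 bias], "weight_decay": 0.01})

optimizer = torch.optim.AdamW(groups)  # extra keys such as "name" are simply carried along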

@@ -7,7 +7,7 @@ import os
 import random
 from contextlib import contextmanager
 from datetime import datetime
-from typing import Dict, Tuple, Union
+from typing import Union

 import numpy as np
 import torch
@@ -236,58 +236,3 @@ class DummyProfile:

     def step(self):
         pass
-
-
-def split_params_into_different_groups_for_optimizer(param_groups: Tuple[Dict]) -> Tuple[Dict]:
-    """Split parameters into different MoE groups for optimizer
-
-    Compatible with multiple param groups, each should have a name
-
-    Args:
-        param_groups (Tuple[Dict]):
-            The list of parameter groups to split
-    Returns:
-        Tuple[Dict]:
-            list of MoE/non-MoE groups for optimizer
-    """
-    if isinstance(param_groups, tuple):
-        param_groups = list(param_groups)  # Tuple cannot be modified
-    elif isinstance(param_groups, dict):
-        param_groups = [param_groups]
-    elif not isinstance(param_groups, list):
-        raise ValueError(f"Unknown param group type of {type(param_groups)}")
-
-    fp32_group = {}
-    # Create fp32 and moe groups and copy origin attribute
-    for param_group in param_groups:
-        # copy attribute for fp32 group
-        fp32_group["name"] = "fp32"
-        fp32_group["gate"] = True
-        for ori_key in param_group.keys():
-            if ori_key != "name":
-                if ori_key == "params":
-                    fp32_group[ori_key] = []
-                else:
-                    fp32_group[ori_key] = param_group[ori_key]
-
-    # Assign param
-    for param_group in param_groups:
-        new_params = []
-        for param in param_group["params"]:
-            if param.dtype == torch.float32:
-                fp32_group["params"].append(param)
-            else:
-                new_params.append(param)
-        # origin group without fp32 or moe parameter
-        param_group["params"] = new_params
-
-    # append to origin group
-    param_groups.append(fp32_group)
-
-    return tuple(param_groups)
-
-
-def create_param_groups(model, weight_decay):
-    parameters = {"params": list(model.parameters()), "name": "default", "weight_decay": weight_decay}
-    return split_params_into_different_groups_for_optimizer(parameters)