mirror of https://github.com/InternLM/InternLM
feat(train): add fsdp training option (#293)

* feat(fsdp): add training option for fsdp
* fix(fsdp): add mix-precision training
* fix failure in lint-check
* fix format problem
* restore 7B_sft
* fix load ckpt bug
* fix load ckpt bug2
* feat(solver/optimizer): add new file fsdp_optimizer.py
* fix(train.py): fix ci lint error
* fix(fsdp_optimizer.py): wait grad async
* fix bug for loading ckpts when zero1 < dp_size
* fix(context/parallel_context.py): only log warning for fsdp
* change ckpt name
* fix(model/modeling_internlm.py): fix checkpoint=False runtime error
* more wrap
* add support for FSDP with tp
* modify args_sanity_check for fsdp with pipeline and fsdp with moe
* fix(internlm/utils/parallel.py): fix circular import
* fix(internlm/train/training_internlm.py): remove set IS_TENSOR_PARALLEL attr
* fix(internlm/train/training_internlm.py): update wrap class and fix lint error
* fix(internlm/model): reset dropout_selective_checkpoint=True
* feat(configs/7B_sft.py): move fsdp config to parallel zero1
* feat(configs/7B_sft.py): adapt to old version config

---------

Co-authored-by: huangting4201 <1538303371@qq.com>
parent 582ee000bd
commit a075153adf
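The user-facing change is the shape of the `parallel.zero1` entry in the training config. Below is a minimal excerpt in the style of configs/7B_sft.py, showing only the keys touched by this diff; per the checks added further down, fsdp=True additionally requires pipeline size 1, torch>=2.0.1 and no MoE experts.

parallel = dict(
    # zero1 is now a dict: "size" keeps the old integer meaning, while "fsdp"
    # toggles wrapping the model with PyTorch FullyShardedDataParallel
    zero1=dict(size=8, fsdp=True),
    tensor=1,
    pipeline=dict(size=1, interleaved_overlap=True),
    sequence_parallel=False,
)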
@@ -154,7 +154,7 @@ pipeline parallel (dict):
 tensor parallel: tensor parallel size, usually the number of GPUs per node.
 """
 parallel = dict(
-    zero1=8,
+    zero1=dict(size=8, fsdp=False),
     tensor=1,
     pipeline=dict(size=1, interleaved_overlap=True),
     sequence_parallel=False,
@@ -11,6 +11,7 @@ from .process_group_initializer import (
     Initializer_Pipeline,
     Initializer_Tensor,
     Initializer_Zero1,
+    Initializer_Zero3_dp,
     ParallelMode,
     ProcessGroupInitializer,
 )
@@ -36,6 +37,7 @@ __all__ = [
     "Initializer_Data",
     "Initializer_Zero1",
     "Initializer_Nettest",
+    "Initializer_Zero3_dp",
     "ProcessGroupInitializer",
     "Initializer_Model",
     "seed",
@@ -328,6 +328,9 @@ class ParallelContext(metaclass=SingletonMeta):
             return False
         return self.is_last_rank(ParallelMode.PIPELINE)

+    def is_no_pp_or_last_stage(self):
+        return not self.is_initialized(ParallelMode.PIPELINE) or self.is_pipeline_last_stage()
+
     def get_world_size(self, parallel_mode: ParallelMode):
         """Returns the world size for `parallel_mode`.

@@ -429,6 +432,16 @@ class ParallelContext(metaclass=SingletonMeta):
         assert self.zero1_parallel_size > 0
         assert self.data_parallel_size % self.zero1_parallel_size == 0

+        # check for fsdp:
+        # if zo_size < dp_size, ckpts saving will introduce redundent storage for model weights
+        # because pytorch "ShardTensor" need to ensure current global rank equals to saved shard's global rank
+        # pytorch vision: 1.13.1+cu117
+        if self.data_parallel_size > self.zero1_parallel_size and self.config.parallel.zero1.get("fsdp", False):
+            logger.warning(
+                f"zo size: {self.zero1_parallel_size} < dp size: {self.data_parallel_size}, "
+                "will introduce redundancy when saving fsdp model ckpts, recommend setting them to same value"
+            )
+
     def _set_parallel_size_from_config(self, config: dict, key: str, attr_name: str):
         if key in config:
             ele = config[key]
@@ -495,6 +508,8 @@ class ParallelContext(metaclass=SingletonMeta):
         initializers.append(pgroup_initializer.Initializer_Model(*initializer_args))
         initializers.append(pgroup_initializer.Initializer_Tensor(*initializer_args))
         initializers.append(pgroup_initializer.Initializer_Zero1(*initializer_args))
+        if isinstance(self.config.parallel.zero1, dict) and self.config.parallel.zero1.get("fsdp", False):
+            initializers.append(pgroup_initializer.Initializer_Zero3_dp(*initializer_args))
         initializers.append(pgroup_initializer.Initializer_Nettest(*initializer_args))
         if self.pipeline_parallel_size > 1:
             initializers.append(pgroup_initializer.Initializer_Pipeline(*initializer_args))
@@ -37,6 +37,11 @@ class ParallelMode(Enum):
     # runntime network test
     NETTEST = "nettest"

+    # zero3-dp parallel
+    # if fsdp is activated and size of fsdp-parallel-size is less than dp-parallel-size
+    # then manual communication only happens between inter-fsdp-modules, while intra-modules reduction is done by fsdp
+    ZERO3_DP = "zero3_dp"
+
     # expert parallel
     EXPERT = "expert"

@@ -594,3 +599,62 @@ class Initializer_Expert_Data(ProcessGroupInitializer):
            )

        return groups
+
+
+class Initializer_Zero3_dp(ProcessGroupInitializer):
+    """A ProcessGroupInitializer for data parallelism.
+
+    Args:
+        rank (int): The rank of current process.
+        world_size (int): Size of whole communication world.
+        data_parallel_size (int): Size of data parallel.
+        pipeline_parallel_size (int): Size of pipeline parallel.
+        tensor_parallel_size (int): Size of tensor parallel.
+        zero1_parallel_size (int): Size of zero1 parallel.
+    """
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        assert self.data_parallel_size % self.zero1_parallel_size == 0
+
+        # the only difference between this initializer and DP_initializer
+        # when FSDP is enabled, only corresponding pairs are in the same actual DP group due to parameter sharding
+        # eg: when zero=4 and dp=8
+        # no fsdp: rank [0-7] share same model paramters, and [0-3], [4-7] are two separate zero group
+        # fsdp: params of (0, 4), (1, 5), (2, 6), (3, 7) are the same actually
+
+        self.data_parallel_size //= self.zero1_parallel_size
+        self.rank_num_per_dp_group = self.world_size // self.data_parallel_size
+
+        assert self.world_size % self.data_parallel_size == 0
+
+    def init_dist_group(self, use_cpu: bool = False):
+        """Initialize data parallel groups, and assign local_ranks and groups to each gpu.
+
+        Returns:
+            Tuple (local_rank, group_world_size, process_group, ranks_in_group, mode):
+                A Data parallelism's information tuple.
+        """
+        local_rank = None
+        ranks_in_group = None
+        process_group = None
+        cpu_group = None
+        group_world_size = None
+        mode = ParallelMode.ZERO3_DP
+
+        for i in range(self.rank_num_per_dp_group):
+            ranks = [i + j * self.rank_num_per_dp_group for j in range(self.data_parallel_size)]
+            group = dist.new_group(ranks)
+            if use_cpu:
+                group_cpu = dist.new_group(ranks, backend="gloo") if dist.get_backend() != "gloo" else group
+            else:
+                group_cpu = None
+
+            if self.rank in ranks:
+                local_rank = ranks.index(self.rank)
+                group_world_size = len(ranks)
+                process_group = group
+                cpu_group = group_cpu
+                ranks_in_group = ranks
+
+        return local_rank, group_world_size, process_group, cpu_group, ranks_in_group, mode
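The grouping arithmetic above is plain integer math, so it can be checked without launching any processes. Below is a minimal standalone sketch (no torch.distributed calls; group creation is replaced by returning the rank lists) that reproduces the example from the comment, zero1 size 4 with data-parallel size 8:

def zero3_dp_groups(world_size: int, dp_size: int, zero1_size: int):
    """Reproduce the rank layout of Initializer_Zero3_dp without any distributed setup."""
    assert dp_size % zero1_size == 0
    reduced_dp = dp_size // zero1_size                # data_parallel_size //= zero1_parallel_size
    rank_num_per_dp_group = world_size // reduced_dp  # stride between ranks holding the same shard
    return [
        [i + j * rank_num_per_dp_group for j in range(reduced_dp)]
        for i in range(rank_num_per_dp_group)
    ]

if __name__ == "__main__":
    # 8 GPUs, data parallel size 8, zero1 size 4
    print(zero3_dp_groups(world_size=8, dp_size=8, zero1_size=4))
    # -> [[0, 4], [1, 5], [2, 6], [3, 7]]: each pair holds the same parameter shard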
@@ -16,7 +16,7 @@ from internlm.utils.common import get_master_node
 from internlm.utils.logger import get_logger
 from internlm.utils.timeout import llm_timeout

-# check pacakge
+# check package
 try:
     import numa
     from numa import memory, schedule
@@ -65,14 +65,36 @@ def args_sanity_check():

     # procssing the parallel config in gpc
     if "zero1" not in gpc.config.parallel:
-        gpc.config.parallel._add_item("zero1", -1)
+        gpc.config.parallel._add_item("zero1", dict(size=-1, fsdp=False))

+    if isinstance(gpc.config.parallel.zero1, int):
+        zero1_size = gpc.config.parallel.zero1
+        gpc.config.parallel._add_item("zero1", dict(size=zero1_size, fsdp=False))
+
     if "pipeline" not in gpc.config.parallel:
-        gpc.config.parallel._add_item("pipeline", 1)
+        gpc.config.parallel._add_item("pipeline", dict(size=1, interleaved_overlap=False))

     if "tensor" not in gpc.config.parallel:
         gpc.config.parallel._add_item("tensor", 1)

+    if isinstance(gpc.config.parallel.pipeline, int):
+        pp = gpc.config.parallel.pipeline
+    else:
+        pp = gpc.config.parallel.pipeline.size
+
+    # check fsdp config
+    if "fsdp" not in gpc.config.parallel.zero1:
+        gpc.config.parallel.zero1._add_item("fsdp", False)
+
+    assert not (
+        gpc.config.parallel.zero1.fsdp and pp > 1
+    ), "FSDP is not supportted when pipeline size > 1, please set pipeline size to 1 or disabled FSDP"
+
+    if gpc.config.parallel.zero1.fsdp:
+        assert (
+            torch.__version__ >= "2.0.1"
+        ), f"requires torch>=2.0.1 when using fsdp but current version is {torch.__version__}"
+
     # processing the data config in gpc
     data = gpc.config.data

@@ -271,6 +293,9 @@ def args_sanity_check():
         model._add_item("moe_use_residual", False)
     if "moe_gate_k" not in model:
         model._add_item("moe_gate_k", 2)
+    assert not (
+        gpc.config.model.num_experts > 1 and gpc.config.parallel.zero1.fsdp
+    ), "FSDP does not support num_experts > 1"

     # process the parallel config
     if "sequence_parallel" not in gpc.config.parallel:
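The sanity check above accepts both the legacy integer form (zero1=8) and the new dict form, then enforces the FSDP constraints (no pipeline parallelism, torch>=2.0.1, no MoE experts). Here is a standalone sketch of the same up-conversion logic; gpc's Config object is replaced by an ordinary dict, which is an assumption of this sketch only.

def normalize_parallel_cfg(parallel: dict) -> dict:
    """Mimic the zero1/pipeline normalization done in args_sanity_check."""
    zero1 = parallel.get("zero1", -1)
    if isinstance(zero1, int):              # legacy style: zero1=8
        zero1 = dict(size=zero1, fsdp=False)
    zero1.setdefault("fsdp", False)
    parallel["zero1"] = zero1

    pipeline = parallel.get("pipeline", 1)
    pp = pipeline if isinstance(pipeline, int) else pipeline["size"]

    # constraint carried over from the diff: FSDP excludes pipeline parallelism
    assert not (zero1["fsdp"] and pp > 1), "FSDP is not supported when pipeline size > 1"
    return parallel

print(normalize_parallel_cfg({"zero1": 8, "pipeline": dict(size=1, interleaved_overlap=True)}))
# -> {'zero1': {'size': 8, 'fsdp': False}, 'pipeline': {'size': 1, 'interleaved_overlap': True}}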
@@ -6,7 +6,6 @@ from torch_scatter import scatter

 from internlm.core.context import ParallelMode
 from internlm.core.context import global_context as gpc
-from internlm.utils.parallel import is_no_pp_or_last_stage


 class AccPerplex:
@@ -138,7 +137,7 @@ class AccPerplex:
         self.total_log_probs += total_log_probs

     def get_metric(self, reset=True):
-        if is_no_pp_or_last_stage() and self.dp_pg is not None:
+        if gpc.is_no_pp_or_last_stage() and self.dp_pg is not None:
             torch.distributed.all_reduce(self.right, op=torch.distributed.ReduceOp.SUM, group=self.dp_pg)
             torch.distributed.all_reduce(self.total, op=torch.distributed.ReduceOp.SUM, group=self.dp_pg)
             torch.distributed.all_reduce(self.total_log_probs, op=torch.distributed.ReduceOp.SUM, group=self.dp_pg)
@@ -236,7 +235,7 @@ class LossWithTypeId:
         self.ds_token_num += token_num_type

     def get_metric(self, reset=True):
-        if is_no_pp_or_last_stage() and self.dp_pg is not None:
+        if gpc.is_no_pp_or_last_stage() and self.dp_pg is not None:
             torch.distributed.all_reduce(self.loss, op=torch.distributed.ReduceOp.SUM, group=self.dp_pg)
             torch.distributed.all_reduce(self.token_num, op=torch.distributed.ReduceOp.SUM, group=self.dp_pg)
             if hasattr(self, "total_type_count"):
@@ -1,6 +1,7 @@
 #!/usr/bin/env python
 # -*- encoding: utf-8 -*-

+from .fsdp_optimizer import FSDPadaptOptimizer
 from .hybrid_zero_optim import HybridZeroOptimizer, reload_zero_fp32_buff

-__all__ = ["HybridZeroOptimizer", "reload_zero_fp32_buff"]
+__all__ = ["FSDPadaptOptimizer", "HybridZeroOptimizer", "reload_zero_fp32_buff"]
@@ -0,0 +1,46 @@
+#!/usr/bin/env python
+# -*- encoding: utf-8 -*-
+
+import torch
+from torch.optim import Optimizer
+
+
+class BaseOptimizer(Optimizer):
+    """
+    Base Optimizer.
+    """
+
+    def __init__(self, optim: Optimizer):  # pylint: disable=W0231
+        self.optim = optim
+
+    @property
+    def param_groups(self):
+        return self.optim.param_groups
+
+    @property
+    def defaults(self):
+        return self.optim.defaults
+
+    def add_param_group(self, *args, **kwargs):
+        return self.optim.add_param_group(*args, **kwargs)
+
+    def step(self, *args, **kwargs):
+        return self.optim.step(*args, **kwargs)
+
+    def zero_grad(self, *args, **kwargs):
+        self.optim.zero_grad(*args, **kwargs)
+
+    def load_state_dict(self, *args, **kwargs):
+        self.optim.load_state_dict(*args, **kwargs)
+
+    def state_dict(self):
+        return self.optim.state_dict()
+
+    def backward(self, loss):
+        loss.backward()
+
+    def backward_by_grad(self, tensor, grad):
+        torch.autograd.backward(tensors=tensor, grad_tensors=grad)
+
+    def clip_grad_norm(self):
+        pass
@@ -0,0 +1,221 @@
+#!/usr/bin/env python
+# -*- encoding: utf-8 -*-
+
+from torch.optim import Optimizer
+
+from internlm.core.context import Config, ParallelMode
+from internlm.core.context import global_context as gpc
+from internlm.solver.optimizer.utils import (
+    DynamicGradScaler,
+    reduce_tensor,
+    release_param_grad,
+)
+from internlm.utils.logger import get_logger
+
+from .base_optimizer import BaseOptimizer
+from .utils import compute_norm
+
+logger = get_logger(__file__)
+
+
+class FSDPadaptOptimizer(BaseOptimizer):
+    """
+    optimizer for Pytorch FSDP if 'parallel.zero1.fsdp' is True in config file
+    reserve some necessary components of hybird-optim:
+    grad_scaler;
+    grad_clip and unscale;
+    state_dict and load_state_dict
+    """
+
+    def __init__(
+        self,
+        optimizer: Optimizer,
+        grad_scal_cfg: Config = None,
+        zero_cfg: Config = None,
+    ):
+        super().__init__(optim=optimizer)
+
+        # gradient scaler
+        self.grad_scaler = DynamicGradScaler(
+            initial_scale=grad_scal_cfg.fp16.initial_scale,
+            min_scale=grad_scal_cfg.fp16.min_scale,
+            growth_factor=grad_scal_cfg.growth_factor,
+            backoff_factor=grad_scal_cfg.backoff_factor,
+            growth_interval=grad_scal_cfg.fp16.growth_interval,
+            hysteresis=grad_scal_cfg.hysteresis,
+            max_scale=grad_scal_cfg.max_scale,
+        )
+
+        # clip gradient
+        self._clip_grad_norm = zero_cfg.clip_grad_norm
+
+        # fp16 and fp32 params
+        # fp16 share mem space with model.FlatParam, fp32 share mem space with optim.param_group
+        self._fp16_param_groups = dict()
+        self._fp32_param_tensor_groups = dict()
+
+        # init fp16 and fp32 params
+        for group_idx, param_group in enumerate(self.optim.param_groups):
+            group_params = param_group["params"]
+
+            # fp16 FlatParam storage
+            self._fp16_param_groups[group_idx] = group_params
+
+            # create copy of fp32 weight
+            fp32_tensor_param = [param.data.float() for param in group_params]
+            self._fp32_param_tensor_groups[group_idx] = fp32_tensor_param
+
+            # replace
+            param_group["params"] = fp32_tensor_param
+
+    @property
+    def loss_scale(self):
+        return self.grad_scaler.scale
+
+    def backward(self, loss, retain_graph=False):
+        loss = self.loss_scale * loss
+        loss.backward(retain_graph=retain_graph)
+
+    def _compute_norm_with_fsdp_flatten(self, group_id):
+        params = [p for p in self._fp16_param_groups[group_id] if p.untyped_storage().size() != 0]
+        gradients = [p.grad for p in params if p.untyped_storage().size() != 0]
+        norm_group = compute_norm(gradients=gradients, parameters=params, last_stage=True)
+
+        return norm_group
+
+    def zero_grad(self):
+        for _, param_group in self._fp16_param_groups.items():
+            for param in param_group:
+                param.grad = None
+
+    def step(self):
+        # in case that fsdp-zero3 size is not equal to dp size
+        # FSDP module will only reduce gradient within FSDP process group
+        # so manually reduce grad is essential between two parallel FSDP process group
+        for group_idx in range(len(self.param_groups)):
+            params = self._fp16_param_groups[group_idx]
+            for param in params:
+                if param.requires_grad and param.grad is not None:
+                    handle = reduce_tensor(tensor=param.grad, parallel_mode=ParallelMode.ZERO3_DP)
+                    handle.wait()
+
+        # compute norm
+        found_inf = False
+        norm_groups = {}
+        for group_idx in range(len(self.param_groups)):
+            group_name = self.param_groups[group_idx]["name"] if "name" in self.param_groups[group_idx] else "default"
+            group_name = f"{group_idx}_{group_name}"
+            norm_group = self._compute_norm_with_fsdp_flatten(group_idx)
+            if norm_group == -1:
+                found_inf = True
+            norm_groups[group_name] = norm_group
+
+        loss_scale = float(self.loss_scale.item())  # backup
+        self.grad_scaler.update(found_inf)
+        if found_inf:
+            if gpc.is_rank_for_log():
+                logger.warning("Overflow occurs, please check it.")
+            self.zero_grad()
+            return False, norm_groups
+
+        # get the global norm
+        global_norm_groups = {}
+        if self._clip_grad_norm > 0:
+            for group_name, norm in norm_groups.items():
+                global_norm_groups[group_name] = norm**0.5
+
+        # create gradient for fp32 params
+        for group_idx in range(len(self.param_groups)):
+            dtype = self._fp32_param_tensor_groups[group_idx][0].dtype
+            fp16_params = [p for p in self._fp16_param_groups[group_idx] if p.untyped_storage().size() != 0]
+            grad_fp32 = [p.grad.to(dtype) for p in fp16_params]
+
+            device = self._fp32_param_tensor_groups[group_idx][0].device
+            nonzero_fp32 = [p for p in self._fp32_param_tensor_groups[group_idx] if p.untyped_storage().size() != 0]
+            for p, g in zip(nonzero_fp32, grad_fp32):
+                p.grad = g.to(device)
+
+        # unscale
+        self._unscale_and_clip_grads(list(global_norm_groups.values()), loss_scale)
+
+        self.optim.step()
+        self.zero_grad()
+
+        for group_idx in range(len(self._fp16_param_groups)):
+            fp16_params = [p for p in self._fp16_param_groups[group_idx] if p.untyped_storage().size() != 0]
+            fp32_tensor_params = [
+                p for p in self._fp32_param_tensor_groups[group_idx] if p.untyped_storage().size() != 0
+            ]
+            # release fp32 grad
+            release_param_grad(fp32_tensor_params)
+            # update fp16 param
+            for p, q in zip(fp16_params, fp32_tensor_params):
+                p.data.copy_(q)
+
+        for group_name, global_norm in global_norm_groups.items():
+            global_norm_groups[group_name] = global_norm / loss_scale
+        return True, global_norm_groups
+
+    def clip_grad_norm(self, model, max_norm):
+        # will conduct in the step()
+        pass
+
+    #########################
+    # utils from hybirdzero #
+    #########################
+
+    def _unscale_and_clip_grads(self, total_norm_groups, loss_scale):
+        # compute combined scale factor for this group
+        combined_scale_groups = []
+
+        if self._clip_grad_norm > 0.0:
+            # norm is in fact norm*scale
+            for group_id, total_norm in enumerate(total_norm_groups):
+                combined_scale_groups.append(loss_scale)
+                clip = ((total_norm / loss_scale) + 1e-6) / self._clip_grad_norm
+                if clip > 1.0:
+                    combined_scale_groups[group_id] = clip * loss_scale
+
+        for group_id, param in self._fp32_param_tensor_groups.items():
+            for p in param:
+                if p.untyped_storage().size() != 0:
+                    p.grad.data.mul_(1.0 / combined_scale_groups[group_id])
+
+    def state_dict(self):
+        states = {}
+        grad_scaler = self.grad_scaler.state_dict()
+        states["grad_scaler"] = grad_scaler
+        optim_states = self.optim.state_dict()
+        states["base_optim_states"] = optim_states
+
+        flat_fp32_weights = {}
+        for group_idx, param in self._fp32_param_tensor_groups.items():
+            flat_fp32_weights[group_idx] = param
+        states["flat_fp32_weights"] = flat_fp32_weights
+
+        return states
+
+    def load_state_dict(self, states):
+        assert "grad_scaler" in states, "Not found grad_scaler state!"
+        grad_scaler = states["grad_scaler"]
+        self.grad_scaler.load_state_dict(grad_scaler)
+        optim_states = states["base_optim_states"]
+        self.optim.load_state_dict(optim_states)
+
+        # load fp32 optimizer weight
+        flat_fp32_weights = states["flat_fp32_weights"]
+        assert set(flat_fp32_weights.keys()) == set(self._fp32_param_tensor_groups)
+        for group_idx, param in flat_fp32_weights.items():
+            self_param = self._fp32_param_tensor_groups[group_idx]
+            assert len(self_param) == len(
+                param
+            ), f"The number of flat tensor is inconsistent, {len(self_param)} != {len(param)}"
+            for p, q in zip(self_param, param):
+                p.data.copy_(q.data)
+
+        # load fp16 model weight
+        for group_idx, param in flat_fp32_weights.items():
+            fp16_param = self._fp16_param_groups[group_idx]
+            fp32_param = self._fp32_param_tensor_groups[group_idx]
+            for p, q in zip(fp16_param, fp32_param):
+                p.data.copy_(q.data)
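FSDPadaptOptimizer.step() follows the usual mixed-precision master-weight recipe: scaled gradients accumulate on the half-precision (FSDP-flattened) parameters, are reduced across ParallelMode.ZERO3_DP when the zero1 group is smaller than the dp group, are unscaled, copied onto fp32 master tensors, stepped by the wrapped optimizer, and finally copied back. Below is a self-contained sketch of just the master-weight part on a single tensor; no FSDP, no distributed reduce, no clipping, and bfloat16 stands in for the half-precision side so it runs on any CPU.

import torch

# half-precision "model" parameter plus an fp32 master copy, as in __init__
p_half = torch.nn.Parameter(torch.randn(4, dtype=torch.bfloat16))
p_fp32 = p_half.data.float()
inner_opt = torch.optim.SGD([p_fp32], lr=0.1)

loss_scale = 1024.0
loss = (p_half * p_half).sum() * loss_scale   # backward(): scale the loss before autograd
loss.backward()

# step(): move unscaled grads onto the fp32 master weights, step, then copy back
p_fp32.grad = p_half.grad.float() / loss_scale
inner_opt.step()
inner_opt.zero_grad()
p_half.data.copy_(p_fp32)                     # refresh the low-precision param for the next forward
p_half.grad = None
print(p_half)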
@@ -33,53 +33,13 @@ from internlm.utils.logger import get_logger
 from internlm.utils.megatron_timers import megatron_timer as timer
 from internlm.utils.timeout import llm_timeout

+from .base_optimizer import BaseOptimizer
 from .utils import compute_norm

 inf = math.inf
 logger = get_logger(__file__)


-class BaseOptimizer(Optimizer):
-    """
-    Base Optimizer.
-    """
-
-    def __init__(self, optim: Optimizer):  # pylint: disable=W0231
-        self.optim = optim
-
-    @property
-    def param_groups(self):
-        return self.optim.param_groups
-
-    @property
-    def defaults(self):
-        return self.optim.defaults
-
-    def add_param_group(self, *args, **kwargs):
-        return self.optim.add_param_group(*args, **kwargs)
-
-    def step(self, *args, **kwargs):
-        return self.optim.step(*args, **kwargs)
-
-    def zero_grad(self, *args, **kwargs):
-        self.optim.zero_grad(*args, **kwargs)
-
-    def load_state_dict(self, *args, **kwargs):
-        self.optim.load_state_dict(*args, **kwargs)
-
-    def state_dict(self):
-        return self.optim.state_dict()
-
-    def backward(self, loss):
-        loss.backward()
-
-    def backward_by_grad(self, tensor, grad):
-        torch.autograd.backward(tensors=tensor, grad_tensors=grad)
-
-    def clip_grad_norm(self):
-        pass
-
-
 class HybridZeroOptimizer(BaseOptimizer):
     """
     Hybrid Zero Optimizer.
@@ -6,6 +6,7 @@ from .training_internlm import (
     initialize_optimizer,
     load_new_batch,
     record_current_batch_training_metrics,
+    wrap_FSDP_model,
 )

 __all__ = [
@@ -16,4 +17,5 @@ __all__ = [
     "initialize_optimizer",
     "load_new_batch",
     "record_current_batch_training_metrics",
+    "wrap_FSDP_model",
 ]
@@ -1,13 +1,22 @@
 #!/usr/bin/env python
 # -*- encoding: utf-8 -*-

+import functools
 import time
 from functools import partial
 from typing import Callable, Iterable, Union

 import torch
 import torch.distributed as dist
+from flash_attn.modules.embedding import ParallelGPT2Embeddings
+from flash_attn.modules.mlp import ParallelFusedMLP
 from torch import nn
+from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
+from torch.distributed.fsdp.fully_sharded_data_parallel import (
+    BackwardPrefetch,
+    ShardingStrategy,
+)
+from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy
 from torch.utils.data import ConcatDataset, DataLoader

 from internlm.core.context import ParallelMode
@@ -25,24 +34,29 @@ from internlm.data.packed_dataset import (
     get_packed_dataset_without_short_length,
 )
 from internlm.data.utils import DATASET_TYPE_IDS_MAP, unpack_data
+from internlm.model.embedding import Embedding1D
+from internlm.model.linear import (
+    FeedForward,
+    RewardModelLinear,
+    ScaleColumnParallelLinear,
+)
+from internlm.model.multi_head_attention import MHA
+from internlm.model.utils import try_import_RMSNorm
 from internlm.monitor import send_heartbeat, set_env_var
 from internlm.monitor.monitor import monitor_manager as mm
 from internlm.solver.beta2_scheduler import Beta2Scheduler
 from internlm.solver.lr_scheduler import FineTuneCosineAnnealingWarmupLR
-from internlm.solver.optimizer import HybridZeroOptimizer
+from internlm.solver.optimizer import FSDPadaptOptimizer, HybridZeroOptimizer
 from internlm.solver.optimizer.utils import ParamBcastSyncHandler
 from internlm.train.utils import create_param_groups
 from internlm.utils.common import DummyProfile
 from internlm.utils.logger import get_logger
 from internlm.utils.megatron_timers import megatron_timer as timer
-from internlm.utils.parallel import (
-    is_no_pp_or_last_stage,
-    sync_model_param,
-    sync_model_param_within_tp,
-)
+from internlm.utils.parallel import sync_model_param, sync_model_param_within_tp
 from internlm.utils.registry import MODEL_INITIALIZER
 from internlm.utils.timeout import llm_timeout

+RMSNorm = try_import_RMSNorm()
 logger = get_logger(__file__)

@@ -72,7 +86,7 @@ def initialize_model():
     else:
         model = NaiveAMPModel(
             model=model,
-            output_to_fp32=is_no_pp_or_last_stage(),
+            output_to_fp32=gpc.is_no_pp_or_last_stage(),
             dtype=gpc.config.model.get("dtype", torch.half),
             sync_buffer=False,
         )
@@ -90,6 +104,42 @@ def initialize_model():
     # state in the same dp group are all the same.
     set_mode(ParallelMode.DATA)

+    # if fsdp enabled, wrap the model
+    model = wrap_FSDP_model(model)
+
+    return model
+
+
+def wrap_FSDP_model(model: Union[nn.Module, nn.ModuleList]):
+    if gpc.config.parallel.zero1.fsdp:
+        # set wrap_policy for fsdp wrap
+        transformer_wrap_policy = functools.partial(
+            transformer_auto_wrap_policy,
+            transformer_layer_cls={
+                Embedding1D,
+                ParallelGPT2Embeddings,
+                MHA,
+                RMSNorm,
+                FeedForward,
+                ParallelFusedMLP,
+                RewardModelLinear,
+                ScaleColumnParallelLinear,
+            },
+        )
+
+        # wrap the model
+        grp = gpc.get_group(ParallelMode.ZERO1)
+        model = FSDP(
+            module=model,
+            process_group=grp,
+            sharding_strategy=ShardingStrategy.FULL_SHARD,
+            auto_wrap_policy=transformer_wrap_policy,
+            forward_prefetch=True,
+            backward_prefetch=BackwardPrefetch.BACKWARD_PRE,
+            limit_all_gathers=True,
+            use_orig_params=True,
+        )
+
     return model

@@ -118,12 +168,19 @@ def initialize_optimizer(model: Union[nn.Module, nn.ModuleList]):
         eps=adam_cfg.adam_eps,
     )

+    if not gpc.config.parallel.zero1.fsdp:
         optimizer = HybridZeroOptimizer(
             naive_optimizer,
             grad_scal_cfg=gpc.config.grad_scaler,
             zero_cfg=gpc.config.hybrid_zero_optimizer,
             param_bcast_sync_handler=param_bcast_sync_handler,
         )
+    else:
+        optimizer = FSDPadaptOptimizer(
+            naive_optimizer,
+            grad_scal_cfg=gpc.config.grad_scaler,
+            zero_cfg=gpc.config.hybrid_zero_optimizer,
+        )

     beta2_scheduler = Beta2Scheduler(optimizer=naive_optimizer, **gpc.config.beta2_scheduler)

@@ -360,7 +417,7 @@ def record_current_batch_training_metrics(
         timer.store_last_timers()
     if success_update in (0, True):
         train_state.num_consumed_tokens += batch[1].nelement() * gpc.get_world_size(ParallelMode.DATA)
-    if is_no_pp_or_last_stage():
+    if gpc.is_no_pp_or_last_stage():
         acc_perplex = metric.get_metric()

     if success_update and gpc.is_rank_for_log():
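wrap_FSDP_model above builds an auto-wrap policy from the transformer building blocks (Embedding1D, MHA, RMSNorm, FeedForward, ...) so that each of them becomes its own FSDP unit, and it shards across the ZERO1 process group rather than the full data-parallel group. A minimal sketch of the same construction on a toy model follows; the Block class is hypothetical, and the FSDP call is guarded because it needs an initialized process group (e.g. via torchrun).

import functools

import torch.distributed as dist
from torch import nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy


class Block(nn.Module):
    """Stand-in for MHA / FeedForward / RMSNorm etc. in the real model."""

    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(16, 16)

    def forward(self, x):
        return self.proj(x)


model = nn.Sequential(*[Block() for _ in range(4)])

# same construction as wrap_FSDP_model: every module whose class is listed gets its own FSDP unit
wrap_policy = functools.partial(transformer_auto_wrap_policy, transformer_layer_cls={Block})

if dist.is_initialized():  # FSDP needs a process group; InternLM passes gpc.get_group(ParallelMode.ZERO1)
    model = FSDP(model, auto_wrap_policy=wrap_policy, use_orig_params=True)

print(model)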
@@ -12,6 +12,9 @@ from enum import Enum
 from typing import Callable, Dict, Union

 import torch
+from torch.distributed._shard.api import load_with_process_group
+from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
+from torch.distributed.fsdp import StateDictType

 from internlm.core.context import ParallelMode
 from internlm.core.context import global_context as gpc
@@ -157,11 +160,38 @@ def get_model_topology(model):
     return topos


+def get_shard_state_dict(shard_model):
+    """
+    Only used for FSDP module saving.
+    It's a warper of model.state_dict() and with the context of 'FSDP.state_dict_type', the sharded parameter
+    (saved as model.flat_param_xx in sharded FSDP module) will be gathered at every gpu.
+    'offload_to_cpu' means that the model states are to be offloaded to cpu chunk by chunk, avoiding OOM in gpu
+
+    """
+
+    # FSDP model can only save with sharded shape SHARDED_STATE_DICT when set use_orig_params=True
+    with FSDP.state_dict_type(shard_model, StateDictType.SHARDED_STATE_DICT):
+        shard_states = shard_model.state_dict()
+
+    return shard_states
+
+
+def load_shard_state_dict(shard_model, shard_state, **kwargs):
+    """
+    Only used for FSDP module loading.
+
+    """
+
+    with FSDP.state_dict_type(shard_model, StateDictType.SHARDED_STATE_DICT):
+        missing_k, unexpected_keys = shard_model.load_state_dict(shard_state, kwargs)
+
+    return (missing_k, unexpected_keys)
+
+
 def try_load_internlm_ckpt(ckpt_mm, load_info, train_state: TrainState):
     load_content_str = ""
     load_ckpt_folder = load_info["path"]
     load_content: CheckpointLoadMask = load_info["content"]

     if gpc.is_rank_for_log():
         logger.info(f"Try load_ckpt_folder: {load_ckpt_folder}")

@@ -224,6 +254,10 @@ def save_model_checkpoint(folder, model):
         - folder
             - model_tp{tp_rank}_pp{pp_rank}.pt

+    If fsdp is activated, the saved weight is named:
+        - folder
+            - model_tp{tp_rank}_pp{pp_rank}_zo{zo_rank}
+
     If the tp is inconsistent with the saved one in the future use, the weight needs to be converted before loading.

     Args:
@@ -231,7 +265,11 @@ def save_model_checkpoint(folder, model):
         model: The model to be saved
     """

+    if gpc.config.parallel.zero1.fsdp:
+        states = get_shard_state_dict(model)
+    else:
         states = model.state_dict()

     # get non-expert parameters
     states = get_non_moe_state_dict(states)
     topo = get_model_topology(model)
@@ -247,12 +285,18 @@ def save_model_checkpoint(folder, model):
     # even if pp is not considered, it will definitely not be written on the same machine.
     should_save_rank_pair = set()  # (tp_rank, dp_rank)
     for i in range(tp_size):
+        if gpc.config.parallel.zero1.fsdp:
+            for j in range(dp_size):
+                should_save_rank_pair.add((i, j))
+        else:
             should_save_rank_pair.add((i, i % dp_size))

     if (tp_rank, dp_rank) in should_save_rank_pair:
-        fn = f"model_tp{tp_rank}_pp{pp_rank}.pt"
+        f_dp = f"_dp{dp_rank}" if gpc.config.parallel.zero1.fsdp else ""
+        fn = f"model_tp{tp_rank}_pp{pp_rank}{f_dp}.pt"
         fp = os.path.join(folder, fn)
         llm_save(fp, saved_obj=states)
+        if not gpc.config.parallel.zero1.fsdp or dp_rank == tp_rank % dp_size:
             topo_fn = f"topo_tp{tp_rank}_pp{pp_rank}.json"
             topo_fp = os.path.join(folder, topo_fn)
             llm_save(topo_fp, saved_obj=topo)
|
||||||
- folder
|
- folder
|
||||||
- model_tp{tp_rank}_pp{pp_rank}.pt
|
- model_tp{tp_rank}_pp{pp_rank}.pt
|
||||||
|
|
||||||
|
If fsdp is activated, the saved weight is named:
|
||||||
|
- folder
|
||||||
|
- model_tp{tp_rank}_pp{pp_rank}_zo{zo_rank}
|
||||||
|
|
||||||
If the tp is inconsistent with the saved one in the future use, the weight needs to be converted before loading.
|
If the tp is inconsistent with the saved one in the future use, the weight needs to be converted before loading.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
tp_size = gpc.get_world_size(ParallelMode.TENSOR)
|
tp_size = gpc.get_world_size(ParallelMode.TENSOR)
|
||||||
pp_size = gpc.get_world_size(ParallelMode.PIPELINE)
|
pp_size = gpc.get_world_size(ParallelMode.PIPELINE)
|
||||||
|
dp_size = gpc.get_world_size(ParallelMode.DATA)
|
||||||
tp_rank = gpc.get_local_rank(ParallelMode.TENSOR)
|
tp_rank = gpc.get_local_rank(ParallelMode.TENSOR)
|
||||||
pp_rank = gpc.get_local_rank(ParallelMode.PIPELINE)
|
pp_rank = gpc.get_local_rank(ParallelMode.PIPELINE)
|
||||||
|
dp_rank = gpc.get_local_rank(ParallelMode.DATA)
|
||||||
|
|
||||||
fns = get_fns(folder)
|
fns = get_fns(folder)
|
||||||
max_pp, max_tp = 0, 0
|
|
||||||
|
# avoid ckpt misuse between FSDP and no-FSDP
|
||||||
|
test_fn = list([f for f in fns if f.startswith("model_t") and not f.endswith(".md5")]).pop()
|
||||||
|
assert ("_dp" in test_fn and gpc.config.parallel.zero1.fsdp) or (
|
||||||
|
"_dp" not in test_fn and not gpc.config.parallel.zero1.fsdp
|
||||||
|
), "FSDP model wants to load no-FSDP ckpts or reverse"
|
||||||
|
|
||||||
|
max_pp, max_tp, max_zo = 0, 0, 0
|
||||||
for fn in fns:
|
for fn in fns:
|
||||||
if fn.startswith("model_t") and not fn.endswith(".md5"):
|
if fn.startswith("model_t") and not fn.endswith(".md5"):
|
||||||
segements = os.path.splitext(fn)[0].split("_")
|
segements = os.path.splitext(fn)[0].split("_")
|
||||||
|
if gpc.config.parallel.zero1.fsdp:
|
||||||
|
max_zo = max(max_zo, int(segements[-1][2:]))
|
||||||
|
max_pp = max(max_pp, int(segements[-2][2:]))
|
||||||
|
max_tp = max(max_tp, int(segements[-3][2:]))
|
||||||
|
else:
|
||||||
max_pp = max(max_pp, int(segements[-1][2:]))
|
max_pp = max(max_pp, int(segements[-1][2:]))
|
||||||
max_tp = max(max_tp, int(segements[-2][2:]))
|
max_tp = max(max_tp, int(segements[-2][2:]))
|
||||||
|
|
||||||
|
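The loader recovers the parallel layout purely from the checkpoint file names, so the parsing can be exercised offline. Here is a standalone sketch (plain strings, no gpc; the helper names are made up for illustration) of the naming scheme with and without the FSDP `_dp{dp_rank}` suffix and of the `segements[-1][2:]`-style parsing used above:

import os
from typing import List, Optional, Tuple


def ckpt_name(tp_rank: int, pp_rank: int, dp_rank: Optional[int] = None) -> str:
    # with FSDP every dp rank stores its own shard, hence the extra _dp{dp_rank} suffix
    f_dp = f"_dp{dp_rank}" if dp_rank is not None else ""
    return f"model_tp{tp_rank}_pp{pp_rank}{f_dp}.pt"


def parse_layout(fns: List[str]) -> Tuple[int, int, int]:
    max_tp = max_pp = max_zo = 0
    for fn in fns:
        if fn.startswith("model_t") and not fn.endswith(".md5"):
            segements = os.path.splitext(fn)[0].split("_")
            if "_dp" in fn:  # FSDP layout: model_tpX_ppY_dpZ.pt
                max_zo = max(max_zo, int(segements[-1][2:]))
                max_pp = max(max_pp, int(segements[-2][2:]))
                max_tp = max(max_tp, int(segements[-3][2:]))
            else:  # plain layout: model_tpX_ppY.pt
                max_pp = max(max_pp, int(segements[-1][2:]))
                max_tp = max(max_tp, int(segements[-2][2:]))
    return max_tp + 1, max_pp + 1, max_zo + 1


fns = [ckpt_name(t, 0, d) for t in range(2) for d in range(4)]
print(parse_layout(fns))  # -> (2, 1, 4): tp=2, pp=1, 4 FSDP shards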
@@ -298,9 +360,19 @@ def load_model_checkpoint(folder, model):
     assert (
         tp_size == max_tp + 1
     ), f"The weights are save for {max_tp+1} parallelism, while current has {tp_size} tensor parallelism"
+    if gpc.config.parallel.zero1.fsdp:
+        assert (
+            dp_size == max_zo + 1
+        ), f"The weights are save for {max_zo+1} FSDP shards , while current has {dp_size} FSDP shards"
+
+    if gpc.config.parallel.zero1.fsdp:
+        should_load_name = f"model_tp{tp_rank}_pp{pp_rank}_dp{dp_rank}.pt"
+    else:
         should_load_name = f"model_tp{tp_rank}_pp{pp_rank}.pt"
     fp = os.path.join(folder, should_load_name)

+    # for FSDP shards loading, we need to set process group
+    with load_with_process_group(gpc.get_group(ParallelMode.ZERO1)):
         states = llm_load(fp, map_location=get_current_device())

     """
@@ -316,6 +388,9 @@ def load_model_checkpoint(folder, model):
     # try to load expert parameter to separate files if model have moe layer
     try_load_moe_checkpoint(folder, model, states, tp_rank, pp_rank)

+    if gpc.config.parallel.zero1.fsdp:
+        missing_k, unexpected_keys = load_shard_state_dict(model, states, strict=False)
+    else:
         missing_k, unexpected_keys = model.load_state_dict(states, strict=False)
     if len(missing_k) != 0:
         logger.warning(f"Warning: missing keys {missing_k}")
@@ -50,10 +50,6 @@ def sync_model_param_within_tp(model):
         dist.broadcast(param, src=ranks[0], group=gpc.get_group(parallel_mode))


-def is_no_pp_or_last_stage():
-    return not gpc.is_initialized(ParallelMode.PIPELINE) or gpc.is_last_rank(ParallelMode.PIPELINE)
-
-
 def get_parallel_log_file_name():
     if gpc.is_rank_for_log():
         fn_prefix = "main_"  # Indicates a rank with more output information