feat(train): add fsdp training option (#293)

* feat(fsdp): add training option for fsdp

* fix(fsdp): add mixed-precision training

* fix failure in lint-check

* fix format problem

* restore 7B_sft

* fix load ckpt bug

* fix load ckpt bug2

* feat(solver/optimizer): add new file fsdp_optimizer.py

* fix(train.py): fix ci lint error

* fix(fsdp_optimizer.py): wait grad async

* fix bug for loading ckpts when zero1 < dp_size

* fix(context/parallel_context.py): only log warning for fsdp

* change ckpt name

* fix(model/modeling_internlm.py): fix checkpoint=False runtime error

* more wrap

* add support for FSDP with tp

* modify args_sanity_check for fsdp with pipeline and fsdp with moe

* fix(internlm/utils/parallel.py): fix circular import

* fix(internlm/train/training_internlm.py): remove set IS_TENSOR_PARALLEL attr

* fix(internlm/train/training_internlm.py): update wrap class and fix lint error

* fix(internlm/model): reset dropout_selective_checkpoint=True

* feat(configs/7B_sft.py): move fsdp config to parallel zero1

* feat(configs/7B_sft.py): adapt to old version config

---------

Co-authored-by: huangting4201 <1538303371@qq.com>
pull/407/head
zaglc 2023-10-09 18:59:31 +08:00 committed by GitHub
parent 582ee000bd
commit a075153adf
14 changed files with 546 additions and 83 deletions

View File

@ -154,7 +154,7 @@ pipeline parallel (dict):
tensor parallel: tensor parallel size, usually the number of GPUs per node. tensor parallel: tensor parallel size, usually the number of GPUs per node.
""" """
parallel = dict( parallel = dict(
zero1=8, zero1=dict(size=8, fsdp=False),
tensor=1, tensor=1,
pipeline=dict(size=1, interleaved_overlap=True), pipeline=dict(size=1, interleaved_overlap=True),
sequence_parallel=False, sequence_parallel=False,
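For reference, the new dict-style `zero1` field is the switch for FSDP. Below is a minimal, illustrative sketch of the same parallel block with FSDP enabled (the values are examples only); per the sanity checks later in this commit, fsdp=True requires pipeline size 1, num_experts 1, and torch>=2.0.1.

# [illustrative sketch, not part of this commit]
parallel = dict(
    zero1=dict(size=8, fsdp=True),  # shard optimizer states over 8 ranks and wrap the model with FSDP
    tensor=1,
    pipeline=dict(size=1, interleaved_overlap=True),  # FSDP is only supported with pipeline size 1
    sequence_parallel=False,
)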

View File

@ -11,6 +11,7 @@ from .process_group_initializer import (
Initializer_Pipeline, Initializer_Pipeline,
Initializer_Tensor, Initializer_Tensor,
Initializer_Zero1, Initializer_Zero1,
Initializer_Zero3_dp,
ParallelMode, ParallelMode,
ProcessGroupInitializer, ProcessGroupInitializer,
) )
@ -36,6 +37,7 @@ __all__ = [
"Initializer_Data", "Initializer_Data",
"Initializer_Zero1", "Initializer_Zero1",
"Initializer_Nettest", "Initializer_Nettest",
"Initializer_Zero3_dp",
"ProcessGroupInitializer", "ProcessGroupInitializer",
"Initializer_Model", "Initializer_Model",
"seed", "seed",

View File

@ -328,6 +328,9 @@ class ParallelContext(metaclass=SingletonMeta):
return False return False
return self.is_last_rank(ParallelMode.PIPELINE) return self.is_last_rank(ParallelMode.PIPELINE)
def is_no_pp_or_last_stage(self):
return not self.is_initialized(ParallelMode.PIPELINE) or self.is_pipeline_last_stage()
def get_world_size(self, parallel_mode: ParallelMode): def get_world_size(self, parallel_mode: ParallelMode):
"""Returns the world size for `parallel_mode`. """Returns the world size for `parallel_mode`.
@ -429,6 +432,16 @@ class ParallelContext(metaclass=SingletonMeta):
assert self.zero1_parallel_size > 0 assert self.zero1_parallel_size > 0
assert self.data_parallel_size % self.zero1_parallel_size == 0 assert self.data_parallel_size % self.zero1_parallel_size == 0
# check for fsdp:
# if zero1 size < dp size, checkpoint saving will introduce redundant storage for model weights
# because pytorch "ShardedTensor" needs the current global rank to equal the saved shard's global rank
# pytorch version: 1.13.1+cu117
if self.data_parallel_size > self.zero1_parallel_size and self.config.parallel.zero1.get("fsdp", False):
logger.warning(
f"zero1 size: {self.zero1_parallel_size} < dp size: {self.data_parallel_size}, "
"will introduce redundancy when saving fsdp model ckpts, recommend setting them to the same value"
)
def _set_parallel_size_from_config(self, config: dict, key: str, attr_name: str): def _set_parallel_size_from_config(self, config: dict, key: str, attr_name: str):
if key in config: if key in config:
ele = config[key] ele = config[key]
@ -495,6 +508,8 @@ class ParallelContext(metaclass=SingletonMeta):
initializers.append(pgroup_initializer.Initializer_Model(*initializer_args)) initializers.append(pgroup_initializer.Initializer_Model(*initializer_args))
initializers.append(pgroup_initializer.Initializer_Tensor(*initializer_args)) initializers.append(pgroup_initializer.Initializer_Tensor(*initializer_args))
initializers.append(pgroup_initializer.Initializer_Zero1(*initializer_args)) initializers.append(pgroup_initializer.Initializer_Zero1(*initializer_args))
if isinstance(self.config.parallel.zero1, dict) and self.config.parallel.zero1.get("fsdp", False):
initializers.append(pgroup_initializer.Initializer_Zero3_dp(*initializer_args))
initializers.append(pgroup_initializer.Initializer_Nettest(*initializer_args)) initializers.append(pgroup_initializer.Initializer_Nettest(*initializer_args))
if self.pipeline_parallel_size > 1: if self.pipeline_parallel_size > 1:
initializers.append(pgroup_initializer.Initializer_Pipeline(*initializer_args)) initializers.append(pgroup_initializer.Initializer_Pipeline(*initializer_args))
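To make the warning above concrete: with FSDP enabled, every rank in the ZERO1 group saves its own shard, and ranks in different ZERO1 groups hold identical shards, so each weight shard ends up on disk dp_size // zero1_size times. A tiny standalone sketch (the helper name is invented for illustration):

# [illustrative sketch, not part of this commit]
def fsdp_ckpt_copies(dp_size: int, zero1_size: int) -> int:
    # each of the dp_size // zero1_size FSDP replicas writes its own copy of every shard
    assert dp_size % zero1_size == 0
    return dp_size // zero1_size

print(fsdp_ckpt_copies(dp_size=8, zero1_size=4))  # 2 -> every weight shard is stored twice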

View File

@ -37,6 +37,11 @@ class ParallelMode(Enum):
# runntime network test # runntime network test
NETTEST = "nettest" NETTEST = "nettest"
# zero3-dp parallel
# if fsdp is activated and the fsdp parallel size is less than the dp parallel size,
# then gradients are reduced manually only across FSDP process groups, while reduction inside each group is done by fsdp
ZERO3_DP = "zero3_dp"
# expert parallel # expert parallel
EXPERT = "expert" EXPERT = "expert"
@ -594,3 +599,62 @@ class Initializer_Expert_Data(ProcessGroupInitializer):
) )
return groups return groups
class Initializer_Zero3_dp(ProcessGroupInitializer):
"""A ProcessGroupInitializer for data parallelism.
Args:
rank (int): The rank of current process.
world_size (int): Size of whole communication world.
data_parallel_size (int): Size of data parallel.
pipeline_parallel_size (int): Size of pipeline parallel.
tensor_parallel_size (int): Size of tensor parallel.
zero1_parallel_size (int): Size of zero1 parallel.
"""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
assert self.data_parallel_size % self.zero1_parallel_size == 0
# the only difference between this initializer and the DP initializer:
# when FSDP is enabled, only corresponding rank pairs belong to the same actual DP group, due to parameter sharding
# e.g. with zero=4 and dp=8:
# no fsdp: ranks [0-7] share the same model parameters, and [0-3], [4-7] form two separate zero groups
# fsdp: the params of (0, 4), (1, 5), (2, 6), (3, 7) are actually the same
self.data_parallel_size //= self.zero1_parallel_size
self.rank_num_per_dp_group = self.world_size // self.data_parallel_size
assert self.world_size % self.data_parallel_size == 0
def init_dist_group(self, use_cpu: bool = False):
"""Initialize data parallel groups, and assign local_ranks and groups to each gpu.
Returns:
Tuple (local_rank, group_world_size, process_group, ranks_in_group, mode):
A Data parallelism's information tuple.
"""
local_rank = None
ranks_in_group = None
process_group = None
cpu_group = None
group_world_size = None
mode = ParallelMode.ZERO3_DP
for i in range(self.rank_num_per_dp_group):
ranks = [i + j * self.rank_num_per_dp_group for j in range(self.data_parallel_size)]
group = dist.new_group(ranks)
if use_cpu:
group_cpu = dist.new_group(ranks, backend="gloo") if dist.get_backend() != "gloo" else group
else:
group_cpu = None
if self.rank in ranks:
local_rank = ranks.index(self.rank)
group_world_size = len(ranks)
process_group = group
cpu_group = group_cpu
ranks_in_group = ranks
return local_rank, group_world_size, process_group, cpu_group, ranks_in_group, mode
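As a standalone illustration of the grouping computed by Initializer_Zero3_dp (matching the zero=4, dp=8 example in the comments above), the ZERO3_DP groups pair exactly those ranks whose FSDP shards are replicas of each other:

# [illustrative sketch, not part of this commit]
world_size, dp_size, zero1_size = 8, 8, 4
zero3_dp_size = dp_size // zero1_size                 # 2 replicas of each FSDP shard
rank_num_per_dp_group = world_size // zero3_dp_size   # 4 ranks per group row
groups = [
    [i + j * rank_num_per_dp_group for j in range(zero3_dp_size)]
    for i in range(rank_num_per_dp_group)
]
print(groups)  # [[0, 4], [1, 5], [2, 6], [3, 7]] -> manual grad reduce happens inside each pair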

View File

@ -16,7 +16,7 @@ from internlm.utils.common import get_master_node
from internlm.utils.logger import get_logger from internlm.utils.logger import get_logger
from internlm.utils.timeout import llm_timeout from internlm.utils.timeout import llm_timeout
# check pacakge # check package
try: try:
import numa import numa
from numa import memory, schedule from numa import memory, schedule
@ -65,14 +65,36 @@ def args_sanity_check():
# procssing the parallel config in gpc # procssing the parallel config in gpc
if "zero1" not in gpc.config.parallel: if "zero1" not in gpc.config.parallel:
gpc.config.parallel._add_item("zero1", -1) gpc.config.parallel._add_item("zero1", dict(size=-1, fsdp=False))
if isinstance(gpc.config.parallel.zero1, int):
zero1_size = gpc.config.parallel.zero1
gpc.config.parallel._add_item("zero1", dict(size=zero1_size, fsdp=False))
if "pipeline" not in gpc.config.parallel: if "pipeline" not in gpc.config.parallel:
gpc.config.parallel._add_item("pipeline", 1) gpc.config.parallel._add_item("pipeline", dict(size=1, interleaved_overlap=False))
if "tensor" not in gpc.config.parallel: if "tensor" not in gpc.config.parallel:
gpc.config.parallel._add_item("tensor", 1) gpc.config.parallel._add_item("tensor", 1)
if isinstance(gpc.config.parallel.pipeline, int):
pp = gpc.config.parallel.pipeline
else:
pp = gpc.config.parallel.pipeline.size
# check fsdp config
if "fsdp" not in gpc.config.parallel.zero1:
gpc.config.parallel.zero1._add_item("fsdp", False)
assert not (
gpc.config.parallel.zero1.fsdp and pp > 1
), "FSDP is not supportted when pipeline size > 1, please set pipeline size to 1 or disabled FSDP"
if gpc.config.parallel.zero1.fsdp:
assert (
torch.__version__ >= "2.0.1"
), f"requires torch>=2.0.1 when using fsdp but current version is {torch.__version__}"
# processing the data config in gpc # processing the data config in gpc
data = gpc.config.data data = gpc.config.data
@ -271,6 +293,9 @@ def args_sanity_check():
model._add_item("moe_use_residual", False) model._add_item("moe_use_residual", False)
if "moe_gate_k" not in model: if "moe_gate_k" not in model:
model._add_item("moe_gate_k", 2) model._add_item("moe_gate_k", 2)
assert not (
gpc.config.model.num_experts > 1 and gpc.config.parallel.zero1.fsdp
), "FSDP does not support num_experts > 1"
# process the parallel config # process the parallel config
if "sequence_parallel" not in gpc.config.parallel: if "sequence_parallel" not in gpc.config.parallel:

View File

@ -6,7 +6,6 @@ from torch_scatter import scatter
from internlm.core.context import ParallelMode from internlm.core.context import ParallelMode
from internlm.core.context import global_context as gpc from internlm.core.context import global_context as gpc
from internlm.utils.parallel import is_no_pp_or_last_stage
class AccPerplex: class AccPerplex:
@ -138,7 +137,7 @@ class AccPerplex:
self.total_log_probs += total_log_probs self.total_log_probs += total_log_probs
def get_metric(self, reset=True): def get_metric(self, reset=True):
if is_no_pp_or_last_stage() and self.dp_pg is not None: if gpc.is_no_pp_or_last_stage() and self.dp_pg is not None:
torch.distributed.all_reduce(self.right, op=torch.distributed.ReduceOp.SUM, group=self.dp_pg) torch.distributed.all_reduce(self.right, op=torch.distributed.ReduceOp.SUM, group=self.dp_pg)
torch.distributed.all_reduce(self.total, op=torch.distributed.ReduceOp.SUM, group=self.dp_pg) torch.distributed.all_reduce(self.total, op=torch.distributed.ReduceOp.SUM, group=self.dp_pg)
torch.distributed.all_reduce(self.total_log_probs, op=torch.distributed.ReduceOp.SUM, group=self.dp_pg) torch.distributed.all_reduce(self.total_log_probs, op=torch.distributed.ReduceOp.SUM, group=self.dp_pg)
@ -236,7 +235,7 @@ class LossWithTypeId:
self.ds_token_num += token_num_type self.ds_token_num += token_num_type
def get_metric(self, reset=True): def get_metric(self, reset=True):
if is_no_pp_or_last_stage() and self.dp_pg is not None: if gpc.is_no_pp_or_last_stage() and self.dp_pg is not None:
torch.distributed.all_reduce(self.loss, op=torch.distributed.ReduceOp.SUM, group=self.dp_pg) torch.distributed.all_reduce(self.loss, op=torch.distributed.ReduceOp.SUM, group=self.dp_pg)
torch.distributed.all_reduce(self.token_num, op=torch.distributed.ReduceOp.SUM, group=self.dp_pg) torch.distributed.all_reduce(self.token_num, op=torch.distributed.ReduceOp.SUM, group=self.dp_pg)
if hasattr(self, "total_type_count"): if hasattr(self, "total_type_count"):

View File

@ -1,6 +1,7 @@
#!/usr/bin/env python #!/usr/bin/env python
# -*- encoding: utf-8 -*- # -*- encoding: utf-8 -*-
from .fsdp_optimizer import FSDPadaptOptimizer
from .hybrid_zero_optim import HybridZeroOptimizer, reload_zero_fp32_buff from .hybrid_zero_optim import HybridZeroOptimizer, reload_zero_fp32_buff
__all__ = ["HybridZeroOptimizer", "reload_zero_fp32_buff"] __all__ = ["FSDPadaptOptimizer", "HybridZeroOptimizer", "reload_zero_fp32_buff"]

View File

@ -0,0 +1,46 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
import torch
from torch.optim import Optimizer
class BaseOptimizer(Optimizer):
"""
Base Optimizer.
"""
def __init__(self, optim: Optimizer): # pylint: disable=W0231
self.optim = optim
@property
def param_groups(self):
return self.optim.param_groups
@property
def defaults(self):
return self.optim.defaults
def add_param_group(self, *args, **kwargs):
return self.optim.add_param_group(*args, **kwargs)
def step(self, *args, **kwargs):
return self.optim.step(*args, **kwargs)
def zero_grad(self, *args, **kwargs):
self.optim.zero_grad(*args, **kwargs)
def load_state_dict(self, *args, **kwargs):
self.optim.load_state_dict(*args, **kwargs)
def state_dict(self):
return self.optim.state_dict()
def backward(self, loss):
loss.backward()
def backward_by_grad(self, tensor, grad):
torch.autograd.backward(tensors=tensor, grad_tensors=grad)
def clip_grad_norm(self):
pass
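base_optimizer.py is new here (the class is moved out of hybrid_zero_optim.py further below); a brief usage sketch of the delegation pattern, assuming torch's AdamW and the BaseOptimizer defined above:

# [illustrative sketch, not part of this commit]
import torch

params = [torch.nn.Parameter(torch.randn(4, 4))]
opt = BaseOptimizer(torch.optim.AdamW(params, lr=1e-4))
loss = params[0].sum()
opt.backward(loss)  # plain loss.backward()
opt.step()          # delegates to AdamW.step()
opt.zero_grad()     # delegates to AdamW.zero_grad()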

View File

@ -0,0 +1,221 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
from torch.optim import Optimizer
from internlm.core.context import Config, ParallelMode
from internlm.core.context import global_context as gpc
from internlm.solver.optimizer.utils import (
DynamicGradScaler,
reduce_tensor,
release_param_grad,
)
from internlm.utils.logger import get_logger
from .base_optimizer import BaseOptimizer
from .utils import compute_norm
logger = get_logger(__file__)
class FSDPadaptOptimizer(BaseOptimizer):
"""
Optimizer for PyTorch FSDP, used when 'parallel.zero1.fsdp' is True in the config file.
It reserves the necessary components of the hybrid-zero optimizer:
grad_scaler;
grad clip and unscale;
state_dict and load_state_dict
"""
def __init__(
self,
optimizer: Optimizer,
grad_scal_cfg: Config = None,
zero_cfg: Config = None,
):
super().__init__(optim=optimizer)
# gradient scaler
self.grad_scaler = DynamicGradScaler(
initial_scale=grad_scal_cfg.fp16.initial_scale,
min_scale=grad_scal_cfg.fp16.min_scale,
growth_factor=grad_scal_cfg.growth_factor,
backoff_factor=grad_scal_cfg.backoff_factor,
growth_interval=grad_scal_cfg.fp16.growth_interval,
hysteresis=grad_scal_cfg.hysteresis,
max_scale=grad_scal_cfg.max_scale,
)
# clip gradient
self._clip_grad_norm = zero_cfg.clip_grad_norm
# fp16 and fp32 params
# fp16 params share memory with model.FlatParam, fp32 params share memory with optim.param_groups
self._fp16_param_groups = dict()
self._fp32_param_tensor_groups = dict()
# init fp16 and fp32 params
for group_idx, param_group in enumerate(self.optim.param_groups):
group_params = param_group["params"]
# fp16 FlatParam storage
self._fp16_param_groups[group_idx] = group_params
# create copy of fp32 weight
fp32_tensor_param = [param.data.float() for param in group_params]
self._fp32_param_tensor_groups[group_idx] = fp32_tensor_param
# replace
param_group["params"] = fp32_tensor_param
@property
def loss_scale(self):
return self.grad_scaler.scale
def backward(self, loss, retain_graph=False):
loss = self.loss_scale * loss
loss.backward(retain_graph=retain_graph)
def _compute_norm_with_fsdp_flatten(self, group_id):
params = [p for p in self._fp16_param_groups[group_id] if p.untyped_storage().size() != 0]
gradients = [p.grad for p in params if p.untyped_storage().size() != 0]
norm_group = compute_norm(gradients=gradients, parameters=params, last_stage=True)
return norm_group
def zero_grad(self):
for _, param_group in self._fp16_param_groups.items():
for param in param_group:
param.grad = None
def step(self):
# in case the fsdp-zero3 size is not equal to the dp size,
# the FSDP module will only reduce gradients within its FSDP process group,
# so manually reducing grads across parallel FSDP process groups is essential
for group_idx in range(len(self.param_groups)):
params = self._fp16_param_groups[group_idx]
for param in params:
if param.requires_grad and param.grad is not None:
handle = reduce_tensor(tensor=param.grad, parallel_mode=ParallelMode.ZERO3_DP)
handle.wait()
# compute norm
found_inf = False
norm_groups = {}
for group_idx in range(len(self.param_groups)):
group_name = self.param_groups[group_idx]["name"] if "name" in self.param_groups[group_idx] else "default"
group_name = f"{group_idx}_{group_name}"
norm_group = self._compute_norm_with_fsdp_flatten(group_idx)
if norm_group == -1:
found_inf = True
norm_groups[group_name] = norm_group
loss_scale = float(self.loss_scale.item()) # backup
self.grad_scaler.update(found_inf)
if found_inf:
if gpc.is_rank_for_log():
logger.warning("Overflow occurs, please check it.")
self.zero_grad()
return False, norm_groups
# get the global norm
global_norm_groups = {}
if self._clip_grad_norm > 0:
for group_name, norm in norm_groups.items():
global_norm_groups[group_name] = norm**0.5
# create gradient for fp32 params
for group_idx in range(len(self.param_groups)):
dtype = self._fp32_param_tensor_groups[group_idx][0].dtype
fp16_params = [p for p in self._fp16_param_groups[group_idx] if p.untyped_storage().size() != 0]
grad_fp32 = [p.grad.to(dtype) for p in fp16_params]
device = self._fp32_param_tensor_groups[group_idx][0].device
nonzero_fp32 = [p for p in self._fp32_param_tensor_groups[group_idx] if p.untyped_storage().size() != 0]
for p, g in zip(nonzero_fp32, grad_fp32):
p.grad = g.to(device)
# unscale
self._unscale_and_clip_grads(list(global_norm_groups.values()), loss_scale)
self.optim.step()
self.zero_grad()
for group_idx in range(len(self._fp16_param_groups)):
fp16_params = [p for p in self._fp16_param_groups[group_idx] if p.untyped_storage().size() != 0]
fp32_tensor_params = [
p for p in self._fp32_param_tensor_groups[group_idx] if p.untyped_storage().size() != 0
]
# release fp32 grad
release_param_grad(fp32_tensor_params)
# update fp16 param
for p, q in zip(fp16_params, fp32_tensor_params):
p.data.copy_(q)
for group_name, global_norm in global_norm_groups.items():
global_norm_groups[group_name] = global_norm / loss_scale
return True, global_norm_groups
def clip_grad_norm(self, model, max_norm):
# clipping is conducted inside step()
pass
##########################
# utils from hybrid zero #
##########################
def _unscale_and_clip_grads(self, total_norm_groups, loss_scale):
# compute combined scale factor for this group
combined_scale_groups = []
if self._clip_grad_norm > 0.0:
# norm is in fact norm*scale
for group_id, total_norm in enumerate(total_norm_groups):
combined_scale_groups.append(loss_scale)
clip = ((total_norm / loss_scale) + 1e-6) / self._clip_grad_norm
if clip > 1.0:
combined_scale_groups[group_id] = clip * loss_scale
for group_id, param in self._fp32_param_tensor_groups.items():
for p in param:
if p.untyped_storage().size() != 0:
p.grad.data.mul_(1.0 / combined_scale_groups[group_id])
def state_dict(self):
states = {}
grad_scaler = self.grad_scaler.state_dict()
states["grad_scaler"] = grad_scaler
optim_states = self.optim.state_dict()
states["base_optim_states"] = optim_states
flat_fp32_weights = {}
for group_idx, param in self._fp32_param_tensor_groups.items():
flat_fp32_weights[group_idx] = param
states["flat_fp32_weights"] = flat_fp32_weights
return states
def load_state_dict(self, states):
assert "grad_scaler" in states, "Not found grad_scaler state!"
grad_scaler = states["grad_scaler"]
self.grad_scaler.load_state_dict(grad_scaler)
optim_states = states["base_optim_states"]
self.optim.load_state_dict(optim_states)
# load fp32 optimizer weight
flat_fp32_weights = states["flat_fp32_weights"]
assert set(flat_fp32_weights.keys()) == set(self._fp32_param_tensor_groups)
for group_idx, param in flat_fp32_weights.items():
self_param = self._fp32_param_tensor_groups[group_idx]
assert len(self_param) == len(
param
), f"The number of flat tensor is inconsistent, {len(self_param)} != {len(param)}"
for p, q in zip(self_param, param):
p.data.copy_(q.data)
# load fp16 model weight
for group_idx, param in flat_fp32_weights.items():
fp16_param = self._fp16_param_groups[group_idx]
fp32_param = self._fp32_param_tensor_groups[group_idx]
for p, q in zip(fp16_param, fp32_param):
p.data.copy_(q.data)
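To make the arithmetic in _unscale_and_clip_grads easier to follow, a small standalone numerical sketch using the same formula: the norm computed from fp16 grads is really true_norm * loss_scale, so dividing by loss_scale recovers the true norm, and grads are later multiplied by 1.0 / combined_scale.

# [illustrative sketch, not part of this commit]
loss_scale = 65536.0
total_norm = 131072.0   # norm of the scaled fp16 grads, i.e. true_norm * loss_scale
clip_grad_norm = 1.0

combined_scale = loss_scale
clip = ((total_norm / loss_scale) + 1e-6) / clip_grad_norm  # true norm / clip threshold
if clip > 1.0:
    combined_scale = clip * loss_scale

print(combined_scale)  # ~131072.07: grads shrink by ~2x on top of unscaling, so the clipped norm is ~1.0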

View File

@ -33,53 +33,13 @@ from internlm.utils.logger import get_logger
from internlm.utils.megatron_timers import megatron_timer as timer from internlm.utils.megatron_timers import megatron_timer as timer
from internlm.utils.timeout import llm_timeout from internlm.utils.timeout import llm_timeout
from .base_optimizer import BaseOptimizer
from .utils import compute_norm from .utils import compute_norm
inf = math.inf inf = math.inf
logger = get_logger(__file__) logger = get_logger(__file__)
class BaseOptimizer(Optimizer):
"""
Base Optimizer.
"""
def __init__(self, optim: Optimizer): # pylint: disable=W0231
self.optim = optim
@property
def param_groups(self):
return self.optim.param_groups
@property
def defaults(self):
return self.optim.defaults
def add_param_group(self, *args, **kwargs):
return self.optim.add_param_group(*args, **kwargs)
def step(self, *args, **kwargs):
return self.optim.step(*args, **kwargs)
def zero_grad(self, *args, **kwargs):
self.optim.zero_grad(*args, **kwargs)
def load_state_dict(self, *args, **kwargs):
self.optim.load_state_dict(*args, **kwargs)
def state_dict(self):
return self.optim.state_dict()
def backward(self, loss):
loss.backward()
def backward_by_grad(self, tensor, grad):
torch.autograd.backward(tensors=tensor, grad_tensors=grad)
def clip_grad_norm(self):
pass
class HybridZeroOptimizer(BaseOptimizer): class HybridZeroOptimizer(BaseOptimizer):
""" """
Hybrid Zero Optimizer. Hybrid Zero Optimizer.

View File

@ -6,6 +6,7 @@ from .training_internlm import (
initialize_optimizer, initialize_optimizer,
load_new_batch, load_new_batch,
record_current_batch_training_metrics, record_current_batch_training_metrics,
wrap_FSDP_model,
) )
__all__ = [ __all__ = [
@ -16,4 +17,5 @@ __all__ = [
"initialize_optimizer", "initialize_optimizer",
"load_new_batch", "load_new_batch",
"record_current_batch_training_metrics", "record_current_batch_training_metrics",
"wrap_FSDP_model",
] ]

View File

@ -1,13 +1,22 @@
#!/usr/bin/env python #!/usr/bin/env python
# -*- encoding: utf-8 -*- # -*- encoding: utf-8 -*-
import functools
import time import time
from functools import partial from functools import partial
from typing import Callable, Iterable, Union from typing import Callable, Iterable, Union
import torch import torch
import torch.distributed as dist import torch.distributed as dist
from flash_attn.modules.embedding import ParallelGPT2Embeddings
from flash_attn.modules.mlp import ParallelFusedMLP
from torch import nn from torch import nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.fully_sharded_data_parallel import (
BackwardPrefetch,
ShardingStrategy,
)
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy
from torch.utils.data import ConcatDataset, DataLoader from torch.utils.data import ConcatDataset, DataLoader
from internlm.core.context import ParallelMode from internlm.core.context import ParallelMode
@ -25,24 +34,29 @@ from internlm.data.packed_dataset import (
get_packed_dataset_without_short_length, get_packed_dataset_without_short_length,
) )
from internlm.data.utils import DATASET_TYPE_IDS_MAP, unpack_data from internlm.data.utils import DATASET_TYPE_IDS_MAP, unpack_data
from internlm.model.embedding import Embedding1D
from internlm.model.linear import (
FeedForward,
RewardModelLinear,
ScaleColumnParallelLinear,
)
from internlm.model.multi_head_attention import MHA
from internlm.model.utils import try_import_RMSNorm
from internlm.monitor import send_heartbeat, set_env_var from internlm.monitor import send_heartbeat, set_env_var
from internlm.monitor.monitor import monitor_manager as mm from internlm.monitor.monitor import monitor_manager as mm
from internlm.solver.beta2_scheduler import Beta2Scheduler from internlm.solver.beta2_scheduler import Beta2Scheduler
from internlm.solver.lr_scheduler import FineTuneCosineAnnealingWarmupLR from internlm.solver.lr_scheduler import FineTuneCosineAnnealingWarmupLR
from internlm.solver.optimizer import HybridZeroOptimizer from internlm.solver.optimizer import FSDPadaptOptimizer, HybridZeroOptimizer
from internlm.solver.optimizer.utils import ParamBcastSyncHandler from internlm.solver.optimizer.utils import ParamBcastSyncHandler
from internlm.train.utils import create_param_groups from internlm.train.utils import create_param_groups
from internlm.utils.common import DummyProfile from internlm.utils.common import DummyProfile
from internlm.utils.logger import get_logger from internlm.utils.logger import get_logger
from internlm.utils.megatron_timers import megatron_timer as timer from internlm.utils.megatron_timers import megatron_timer as timer
from internlm.utils.parallel import ( from internlm.utils.parallel import sync_model_param, sync_model_param_within_tp
is_no_pp_or_last_stage,
sync_model_param,
sync_model_param_within_tp,
)
from internlm.utils.registry import MODEL_INITIALIZER from internlm.utils.registry import MODEL_INITIALIZER
from internlm.utils.timeout import llm_timeout from internlm.utils.timeout import llm_timeout
RMSNorm = try_import_RMSNorm()
logger = get_logger(__file__) logger = get_logger(__file__)
@ -72,7 +86,7 @@ def initialize_model():
else: else:
model = NaiveAMPModel( model = NaiveAMPModel(
model=model, model=model,
output_to_fp32=is_no_pp_or_last_stage(), output_to_fp32=gpc.is_no_pp_or_last_stage(),
dtype=gpc.config.model.get("dtype", torch.half), dtype=gpc.config.model.get("dtype", torch.half),
sync_buffer=False, sync_buffer=False,
) )
@ -90,6 +104,42 @@ def initialize_model():
# state in the same dp group are all the same. # state in the same dp group are all the same.
set_mode(ParallelMode.DATA) set_mode(ParallelMode.DATA)
# if fsdp enabled, wrap the model
model = wrap_FSDP_model(model)
return model
def wrap_FSDP_model(model: Union[nn.Module, nn.ModuleList]):
if gpc.config.parallel.zero1.fsdp:
# set wrap_policy for fsdp wrap
transformer_wrap_policy = functools.partial(
transformer_auto_wrap_policy,
transformer_layer_cls={
Embedding1D,
ParallelGPT2Embeddings,
MHA,
RMSNorm,
FeedForward,
ParallelFusedMLP,
RewardModelLinear,
ScaleColumnParallelLinear,
},
)
# wrap the model
grp = gpc.get_group(ParallelMode.ZERO1)
model = FSDP(
module=model,
process_group=grp,
sharding_strategy=ShardingStrategy.FULL_SHARD,
auto_wrap_policy=transformer_wrap_policy,
forward_prefetch=True,
backward_prefetch=BackwardPrefetch.BACKWARD_PRE,
limit_all_gathers=True,
use_orig_params=True,
)
return model return model
@ -118,12 +168,19 @@ def initialize_optimizer(model: Union[nn.Module, nn.ModuleList]):
eps=adam_cfg.adam_eps, eps=adam_cfg.adam_eps,
) )
if not gpc.config.parallel.zero1.fsdp:
optimizer = HybridZeroOptimizer( optimizer = HybridZeroOptimizer(
naive_optimizer, naive_optimizer,
grad_scal_cfg=gpc.config.grad_scaler, grad_scal_cfg=gpc.config.grad_scaler,
zero_cfg=gpc.config.hybrid_zero_optimizer, zero_cfg=gpc.config.hybrid_zero_optimizer,
param_bcast_sync_handler=param_bcast_sync_handler, param_bcast_sync_handler=param_bcast_sync_handler,
) )
else:
optimizer = FSDPadaptOptimizer(
naive_optimizer,
grad_scal_cfg=gpc.config.grad_scaler,
zero_cfg=gpc.config.hybrid_zero_optimizer,
)
beta2_scheduler = Beta2Scheduler(optimizer=naive_optimizer, **gpc.config.beta2_scheduler) beta2_scheduler = Beta2Scheduler(optimizer=naive_optimizer, **gpc.config.beta2_scheduler)
@ -360,7 +417,7 @@ def record_current_batch_training_metrics(
timer.store_last_timers() timer.store_last_timers()
if success_update in (0, True): if success_update in (0, True):
train_state.num_consumed_tokens += batch[1].nelement() * gpc.get_world_size(ParallelMode.DATA) train_state.num_consumed_tokens += batch[1].nelement() * gpc.get_world_size(ParallelMode.DATA)
if is_no_pp_or_last_stage(): if gpc.is_no_pp_or_last_stage():
acc_perplex = metric.get_metric() acc_perplex = metric.get_metric()
if success_update and gpc.is_rank_for_log(): if success_update and gpc.is_rank_for_log():

View File

@ -12,6 +12,9 @@ from enum import Enum
from typing import Callable, Dict, Union from typing import Callable, Dict, Union
import torch import torch
from torch.distributed._shard.api import load_with_process_group
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import StateDictType
from internlm.core.context import ParallelMode from internlm.core.context import ParallelMode
from internlm.core.context import global_context as gpc from internlm.core.context import global_context as gpc
@ -157,11 +160,38 @@ def get_model_topology(model):
return topos return topos
def get_shard_state_dict(shard_model):
"""
Only used for FSDP module saving.
It is a wrapper of model.state_dict(): within the 'FSDP.state_dict_type' context, the sharded parameters
(saved as model.flat_param_xx in the sharded FSDP module) are gathered on every gpu.
'offload_to_cpu' means the model states are offloaded to cpu chunk by chunk, avoiding OOM on gpu.
"""
# an FSDP model can only be saved with the sharded shape (SHARDED_STATE_DICT) when use_orig_params=True is set
with FSDP.state_dict_type(shard_model, StateDictType.SHARDED_STATE_DICT):
shard_states = shard_model.state_dict()
return shard_states
def load_shard_state_dict(shard_model, shard_state, **kwargs):
"""
Only used for FSDP module loading.
"""
with FSDP.state_dict_type(shard_model, StateDictType.SHARDED_STATE_DICT):
missing_k, unexpected_keys = shard_model.load_state_dict(shard_state, **kwargs)
return (missing_k, unexpected_keys)
def try_load_internlm_ckpt(ckpt_mm, load_info, train_state: TrainState): def try_load_internlm_ckpt(ckpt_mm, load_info, train_state: TrainState):
load_content_str = "" load_content_str = ""
load_ckpt_folder = load_info["path"] load_ckpt_folder = load_info["path"]
load_content: CheckpointLoadMask = load_info["content"] load_content: CheckpointLoadMask = load_info["content"]
if gpc.is_rank_for_log(): if gpc.is_rank_for_log():
logger.info(f"Try load_ckpt_folder: {load_ckpt_folder}") logger.info(f"Try load_ckpt_folder: {load_ckpt_folder}")
@ -224,6 +254,10 @@ def save_model_checkpoint(folder, model):
- folder - folder
- model_tp{tp_rank}_pp{pp_rank}.pt - model_tp{tp_rank}_pp{pp_rank}.pt
If fsdp is activated, the saved weight is named:
- folder
- model_tp{tp_rank}_pp{pp_rank}_dp{dp_rank}.pt
If the tp is inconsistent with the saved one in the future use, the weight needs to be converted before loading. If the tp is inconsistent with the saved one in the future use, the weight needs to be converted before loading.
Args: Args:
@ -231,7 +265,11 @@ def save_model_checkpoint(folder, model):
model: The model to be saved model: The model to be saved
""" """
if gpc.config.parallel.zero1.fsdp:
states = get_shard_state_dict(model)
else:
states = model.state_dict() states = model.state_dict()
# get non-expert parameters # get non-expert parameters
states = get_non_moe_state_dict(states) states = get_non_moe_state_dict(states)
topo = get_model_topology(model) topo = get_model_topology(model)
@ -247,12 +285,18 @@ def save_model_checkpoint(folder, model):
# even if pp is not considered, it will definitely not be written on the same machine. # even if pp is not considered, it will definitely not be written on the same machine.
should_save_rank_pair = set() # (tp_rank, dp_rank) should_save_rank_pair = set() # (tp_rank, dp_rank)
for i in range(tp_size): for i in range(tp_size):
if gpc.config.parallel.zero1.fsdp:
for j in range(dp_size):
should_save_rank_pair.add((i, j))
else:
should_save_rank_pair.add((i, i % dp_size)) should_save_rank_pair.add((i, i % dp_size))
if (tp_rank, dp_rank) in should_save_rank_pair: if (tp_rank, dp_rank) in should_save_rank_pair:
fn = f"model_tp{tp_rank}_pp{pp_rank}.pt" f_dp = f"_dp{dp_rank}" if gpc.config.parallel.zero1.fsdp else ""
fn = f"model_tp{tp_rank}_pp{pp_rank}{f_dp}.pt"
fp = os.path.join(folder, fn) fp = os.path.join(folder, fn)
llm_save(fp, saved_obj=states) llm_save(fp, saved_obj=states)
if not gpc.config.parallel.zero1.fsdp or dp_rank == tp_rank % dp_size:
topo_fn = f"topo_tp{tp_rank}_pp{pp_rank}.json" topo_fn = f"topo_tp{tp_rank}_pp{pp_rank}.json"
topo_fp = os.path.join(folder, topo_fn) topo_fp = os.path.join(folder, topo_fn)
llm_save(topo_fp, saved_obj=topo) llm_save(topo_fp, saved_obj=topo)
@ -276,19 +320,37 @@ def load_model_checkpoint(folder, model):
- folder - folder
- model_tp{tp_rank}_pp{pp_rank}.pt - model_tp{tp_rank}_pp{pp_rank}.pt
If fsdp is activated, the saved weight is named:
- folder
- model_tp{tp_rank}_pp{pp_rank}_dp{dp_rank}.pt
If the tp is inconsistent with the saved one in the future use, the weight needs to be converted before loading. If the tp is inconsistent with the saved one in the future use, the weight needs to be converted before loading.
""" """
tp_size = gpc.get_world_size(ParallelMode.TENSOR) tp_size = gpc.get_world_size(ParallelMode.TENSOR)
pp_size = gpc.get_world_size(ParallelMode.PIPELINE) pp_size = gpc.get_world_size(ParallelMode.PIPELINE)
dp_size = gpc.get_world_size(ParallelMode.DATA)
tp_rank = gpc.get_local_rank(ParallelMode.TENSOR) tp_rank = gpc.get_local_rank(ParallelMode.TENSOR)
pp_rank = gpc.get_local_rank(ParallelMode.PIPELINE) pp_rank = gpc.get_local_rank(ParallelMode.PIPELINE)
dp_rank = gpc.get_local_rank(ParallelMode.DATA)
fns = get_fns(folder) fns = get_fns(folder)
max_pp, max_tp = 0, 0
# avoid ckpt misuse between FSDP and non-FSDP runs
test_fn = list([f for f in fns if f.startswith("model_t") and not f.endswith(".md5")]).pop()
assert ("_dp" in test_fn and gpc.config.parallel.zero1.fsdp) or (
"_dp" not in test_fn and not gpc.config.parallel.zero1.fsdp
), "FSDP model wants to load no-FSDP ckpts or reverse"
max_pp, max_tp, max_zo = 0, 0, 0
for fn in fns: for fn in fns:
if fn.startswith("model_t") and not fn.endswith(".md5"): if fn.startswith("model_t") and not fn.endswith(".md5"):
segements = os.path.splitext(fn)[0].split("_") segements = os.path.splitext(fn)[0].split("_")
if gpc.config.parallel.zero1.fsdp:
max_zo = max(max_zo, int(segements[-1][2:]))
max_pp = max(max_pp, int(segements[-2][2:]))
max_tp = max(max_tp, int(segements[-3][2:]))
else:
max_pp = max(max_pp, int(segements[-1][2:])) max_pp = max(max_pp, int(segements[-1][2:]))
max_tp = max(max_tp, int(segements[-2][2:])) max_tp = max(max_tp, int(segements[-2][2:]))
@ -298,9 +360,19 @@ def load_model_checkpoint(folder, model):
assert ( assert (
tp_size == max_tp + 1 tp_size == max_tp + 1
), f"The weights are save for {max_tp+1} parallelism, while current has {tp_size} tensor parallelism" ), f"The weights are save for {max_tp+1} parallelism, while current has {tp_size} tensor parallelism"
if gpc.config.parallel.zero1.fsdp:
assert (
dp_size == max_zo + 1
), f"The weights are save for {max_zo+1} FSDP shards , while current has {dp_size} FSDP shards"
if gpc.config.parallel.zero1.fsdp:
should_load_name = f"model_tp{tp_rank}_pp{pp_rank}_dp{dp_rank}.pt"
else:
should_load_name = f"model_tp{tp_rank}_pp{pp_rank}.pt" should_load_name = f"model_tp{tp_rank}_pp{pp_rank}.pt"
fp = os.path.join(folder, should_load_name) fp = os.path.join(folder, should_load_name)
# for loading FSDP shards, we need to set the process group
with load_with_process_group(gpc.get_group(ParallelMode.ZERO1)):
states = llm_load(fp, map_location=get_current_device()) states = llm_load(fp, map_location=get_current_device())
""" """
@ -316,6 +388,9 @@ def load_model_checkpoint(folder, model):
# try to load expert parameter to separate files if model have moe layer # try to load expert parameter to separate files if model have moe layer
try_load_moe_checkpoint(folder, model, states, tp_rank, pp_rank) try_load_moe_checkpoint(folder, model, states, tp_rank, pp_rank)
if gpc.config.parallel.zero1.fsdp:
missing_k, unexpected_keys = load_shard_state_dict(model, states, strict=False)
else:
missing_k, unexpected_keys = model.load_state_dict(states, strict=False) missing_k, unexpected_keys = model.load_state_dict(states, strict=False)
if len(missing_k) != 0: if len(missing_k) != 0:
logger.warning(f"Warning: missing keys {missing_k}") logger.warning(f"Warning: missing keys {missing_k}")

View File

@ -50,10 +50,6 @@ def sync_model_param_within_tp(model):
dist.broadcast(param, src=ranks[0], group=gpc.get_group(parallel_mode)) dist.broadcast(param, src=ranks[0], group=gpc.get_group(parallel_mode))
def is_no_pp_or_last_stage():
return not gpc.is_initialized(ParallelMode.PIPELINE) or gpc.is_last_rank(ParallelMode.PIPELINE)
def get_parallel_log_file_name(): def get_parallel_log_file_name():
if gpc.is_rank_for_log(): if gpc.is_rank_for_log():
fn_prefix = "main_" # Indicates a rank with more output information fn_prefix = "main_" # Indicates a rank with more output information