mirror of https://github.com/InternLM/InternLM
feat(fsdp): add training option for fsdp
parent
c516602e9a
commit
85c6ed6473
|
@ -7,22 +7,23 @@ MLP_RATIO = 8 / 3
|
|||
NUM_LAYER = 32
|
||||
VOCAB_SIZE = 103168
|
||||
|
||||
MODEL_ONLY_FOLDER = "local:llm_ckpts/xxxx"
|
||||
MODEL_ONLY_FOLDER = "local:llm_ckpts/20"
|
||||
# Ckpt folder format:
|
||||
# fs: 'local:/mnt/nfs/XXX'
|
||||
SAVE_CKPT_FOLDER = "local:llm_ckpts"
|
||||
LOAD_CKPT_FOLDER = "local:llm_ckpts/49"
|
||||
LOAD_CKPT_FOLDER = "local:llm_ckpts/20"
|
||||
|
||||
# boto3 Ckpt folder format:
|
||||
# import os
|
||||
# BOTO3_IP = os.environ["BOTO3_IP"] # boto3 bucket endpoint
|
||||
# SAVE_CKPT_FOLDER = f"boto3:s3://model_weights.{BOTO3_IP}/internlm"
|
||||
# LOAD_CKPT_FOLDER = f"boto3:s3://model_weights.{BOTO3_IP}/internlm/snapshot/1/"
|
||||
CHECKPOINT_EVERY = 50
|
||||
CHECKPOINT_EVERY = 20
|
||||
ckpt = dict(
|
||||
enable_save_ckpt=False, # enable ckpt save.
|
||||
save_ckpt_folder=SAVE_CKPT_FOLDER, # Path to save training ckpt.
|
||||
# load_ckpt_folder=LOAD_CKPT_FOLDER, # Ckpt path to resume training(load weights and scheduler/context states).
|
||||
load_ckpt_folder=LOAD_CKPT_FOLDER, # Ckpt path to resume training(load weights and scheduler/context states).
|
||||
load_given_ckpt = False,
|
||||
# load_model_only_folder=MODEL_ONLY_FOLDER, # Path to initialize with given model weights.
|
||||
load_optimizer=True, # Wheter to load optimizer states when continuing training.
|
||||
checkpoint_every=CHECKPOINT_EVERY,
|
||||
|
@ -32,7 +33,7 @@ ckpt = dict(
|
|||
oss_snapshot_freq=int(CHECKPOINT_EVERY / 2), # snapshot ckpt save frequency.
|
||||
)
|
||||
|
||||
TRAIN_FOLDER = "/path/to/dataset"
|
||||
TRAIN_FOLDER = "../../train_data"#"/path/to/dataset"
|
||||
VALID_FOLDER = "/path/to/dataset"
|
||||
data = dict(
|
||||
seq_len=SEQ_LEN,
|
||||
|
@ -50,7 +51,7 @@ data = dict(
|
|||
rampup_batch_size="",
|
||||
# Datasets with less than 50 rows will be discarded
|
||||
min_length=50,
|
||||
# train_folder=TRAIN_FOLDER,
|
||||
train_folder=TRAIN_FOLDER,
|
||||
# valid_folder=VALID_FOLDER,
|
||||
)
|
||||
|
||||
|
@ -111,7 +112,7 @@ beta2_scheduler = dict(
|
|||
)
|
||||
|
||||
model = dict(
|
||||
checkpoint=False, # The proportion of layers for activation aheckpointing, the optional value are True/False/[0-1]
|
||||
checkpoint=True, # The proportion of layers for activation checkpointing, the optional value are True/False/[0-1]
|
||||
num_attention_heads=NUM_ATTENTION_HEAD,
|
||||
embed_split_hidden=True,
|
||||
vocab_size=VOCAB_SIZE,
|
||||
|
@ -121,7 +122,7 @@ model = dict(
|
|||
num_layers=NUM_LAYER,
|
||||
mlp_ratio=MLP_RATIO,
|
||||
apply_post_layer_norm=False,
|
||||
dtype="torch.float16", # Support: "torch.float16", "torch.half", "torch.bfloat16", "torch.float32", "torch.tf32"
|
||||
dtype="torch.bfloat16", # Support: "torch.float16", "torch.half", "torch.bfloat16", "torch.float32", "torch.tf32"
|
||||
norm_type="rmsnorm",
|
||||
layer_norm_epsilon=1e-5,
|
||||
use_flash_attn=True,
|
||||
|
@ -140,9 +141,11 @@ pipeline parallel (dict):
|
|||
tensor parallel: tensor parallel size, usually the number of GPUs per node.
|
||||
"""
|
||||
parallel = dict(
|
||||
zero1=8,
|
||||
zero1=-1,
|
||||
pipeline=dict(size=1, interleaved_overlap=True),
|
||||
tensor=2,
|
||||
sequence_parallel=False,
|
||||
use_fsdp = False,
|
||||
)
|
||||
|
||||
cudnn_deterministic = False
|
||||
|
|
|
@ -12,6 +12,7 @@ from .process_group_initializer import (
|
|||
Initializer_Zero1,
|
||||
ParallelMode,
|
||||
ProcessGroupInitializer,
|
||||
Initializer_Zero3_dp,
|
||||
)
|
||||
from .random import (
|
||||
add_seed,
|
||||
|
@ -34,6 +35,7 @@ __all__ = [
|
|||
"Initializer_Pipeline",
|
||||
"Initializer_Data",
|
||||
"Initializer_Zero1",
|
||||
"Initializer_Zero3_dp",
|
||||
"ProcessGroupInitializer",
|
||||
"Initializer_Model",
|
||||
"seed",
|
||||
|
|
|
@ -462,6 +462,8 @@ class ParallelContext(metaclass=SingletonMeta):
|
|||
initializers.append(pgroup_initializer.Initializer_Model(*initializer_args))
|
||||
initializers.append(pgroup_initializer.Initializer_Tensor(*initializer_args))
|
||||
initializers.append(pgroup_initializer.Initializer_Zero1(*initializer_args))
|
||||
if self.config.parallel.use_fsdp:
|
||||
initializers.append(pgroup_initializer.Initializer_Zero3_dp(*initializer_args))
|
||||
if self.pipeline_parallel_size > 1:
|
||||
initializers.append(pgroup_initializer.Initializer_Pipeline(*initializer_args))
|
||||
for initializer in initializers:
|
||||
|
|
|
@ -31,6 +31,11 @@ class ParallelMode(Enum):
|
|||
# zero1 parallel
|
||||
ZERO1 = "zero1"
|
||||
|
||||
# zero3-dp parallel
|
||||
# if fsdp is activated and size of fsdp-parallel-size is less than dp-parallel-size
|
||||
# then manual communication only happens between inter-fsdp-modules, while intra-modules reduction is done by fsdp
|
||||
ZERO3_DP = "zero3_dp"
|
||||
|
||||
|
||||
class ProcessGroupInitializer(ABC):
|
||||
"""An object, knowing the parallelism configuration, that initializes parallel groups.
|
||||
|
@ -332,3 +337,62 @@ class Initializer_Zero1(ProcessGroupInitializer):
|
|||
ranks_in_group = ranks
|
||||
|
||||
return local_rank, group_world_size, process_group, cpu_group, ranks_in_group, mode
|
||||
|
||||
class Initializer_Zero3_dp(ProcessGroupInitializer):
|
||||
"""A ProcessGroupInitializer for data parallelism.
|
||||
|
||||
Args:
|
||||
rank (int): The rank of current process.
|
||||
world_size (int): Size of whole communication world.
|
||||
data_parallel_size (int): Size of data parallel.
|
||||
pipeline_parallel_size (int): Size of pipeline parallel.
|
||||
tensor_parallel_size (int): Size of tensor parallel.
|
||||
zero1_parallel_size (int): Size of zero1 parallel.
|
||||
"""
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
assert self.data_parallel_size % self.zero1_parallel_size == 0
|
||||
|
||||
# the only difference between this initializer and DP_initializer
|
||||
# when FSDP is enabled, only corresponding pairs are in the same actual DP group due to parameter sharding
|
||||
# eg: when zero=4 and dp=8
|
||||
# no fsdp: rank [0-7] share same model paramters, and [0-3], [4-7] are two separate zero group
|
||||
# fsdp: params of (0, 4), (1, 5), (2, 6), (3, 7) are the same actually
|
||||
|
||||
self.data_parallel_size //= self.zero1_parallel_size
|
||||
self.rank_num_per_dp_group = self.world_size // self.data_parallel_size
|
||||
|
||||
assert self.world_size % self.data_parallel_size == 0
|
||||
|
||||
|
||||
def init_dist_group(self, use_cpu: bool = False):
|
||||
"""Initialize data parallel groups, and assign local_ranks and groups to each gpu.
|
||||
|
||||
Returns:
|
||||
Tuple (local_rank, group_world_size, process_group, ranks_in_group, mode):
|
||||
A Data parallelism's information tuple.
|
||||
"""
|
||||
local_rank = None
|
||||
ranks_in_group = None
|
||||
process_group = None
|
||||
cpu_group = None
|
||||
group_world_size = None
|
||||
mode = ParallelMode.ZERO3_DP
|
||||
|
||||
for i in range(self.rank_num_per_dp_group):
|
||||
ranks = [i + j * self.rank_num_per_dp_group for j in range(self.data_parallel_size)]
|
||||
group = dist.new_group(ranks)
|
||||
if use_cpu:
|
||||
group_cpu = dist.new_group(ranks, backend="gloo") if dist.get_backend() != "gloo" else group
|
||||
else:
|
||||
group_cpu = None
|
||||
|
||||
if self.rank in ranks:
|
||||
local_rank = ranks.index(self.rank)
|
||||
group_world_size = len(ranks)
|
||||
process_group = group
|
||||
cpu_group = group_cpu
|
||||
ranks_in_group = ranks
|
||||
|
||||
return local_rank, group_world_size, process_group, cpu_group, ranks_in_group, mode
|
|
@ -61,6 +61,17 @@ def args_sanity_check():
|
|||
if "tensor" not in gpc.config.parallel:
|
||||
gpc.config.parallel._add_item("tensor", 1)
|
||||
|
||||
if isinstance(gpc.config.parallel.pipeline, int):
|
||||
pp = gpc.config.parallel.pipeline
|
||||
else:
|
||||
pp = gpc.config.parallel.pipeline.size
|
||||
|
||||
if "use_fsdp" not in gpc.config.parallel:
|
||||
gpc.config.parallel._add_item("use_fsdp", False)
|
||||
elif gpc.config.parallel.use_fsdp and pp > 1:
|
||||
logger.warning("FSDP not support when pipeline parallel is enabled, auto-close FSDP")
|
||||
gpc.config.parallel._add_item("use_fsdp", False)
|
||||
|
||||
# processing the data config in gpc
|
||||
data = gpc.config.data
|
||||
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- encoding: utf-8 -*-
|
||||
|
||||
from .hybrid_zero_optim import HybridZeroOptimizer
|
||||
from .hybrid_zero_optim import HybridZeroOptimizer, FSDPadaptOptimizer
|
||||
|
||||
__all__ = ["HybridZeroOptimizer"]
|
||||
__all__ = ["HybridZeroOptimizer", "FSDPadaptOptimizer"]
|
||||
|
|
|
@ -80,6 +80,132 @@ class BaseOptimizer(Optimizer):
|
|||
pass
|
||||
|
||||
|
||||
class FSDPadaptOptimizer(BaseOptimizer):
|
||||
'''
|
||||
optimizer for Pytorch FSDP if 'use_fsdp' is True in config file
|
||||
reserve some necessary components of hybird-optim:
|
||||
grad_scaler;
|
||||
grad_clip and unscale;
|
||||
state_dict and load_state_dict
|
||||
'''
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
optimizer: Optimizer,
|
||||
grad_scal_cfg: Config = None,
|
||||
zero_cfg: Config = None,
|
||||
):
|
||||
super().__init__(optim=optimizer)
|
||||
|
||||
# gradient scaler
|
||||
self.grad_scaler = DynamicGradScaler(
|
||||
initial_scale=grad_scal_cfg.fp16.initial_scale,
|
||||
min_scale=grad_scal_cfg.fp16.min_scale,
|
||||
growth_factor=grad_scal_cfg.growth_factor,
|
||||
backoff_factor=grad_scal_cfg.backoff_factor,
|
||||
growth_interval=grad_scal_cfg.fp16.growth_interval,
|
||||
hysteresis=grad_scal_cfg.hysteresis,
|
||||
max_scale=grad_scal_cfg.max_scale,
|
||||
)
|
||||
|
||||
# clip gradient
|
||||
self._clip_grad_norm = zero_cfg.clip_grad_norm
|
||||
self.use_fsdp = gpc.config.parallel.use_fsdp
|
||||
|
||||
@property
|
||||
def loss_scale(self):
|
||||
return self.grad_scaler.scale
|
||||
|
||||
def backward(self, loss, retain_graph=False):
|
||||
loss = self.loss_scale * loss
|
||||
loss.backward(retain_graph=retain_graph)
|
||||
|
||||
def step(self):
|
||||
# in case that fsdp-zero3 size is not equal to dp size
|
||||
# FSDP module will only reduce gradient within FSDP process group
|
||||
# so manually reduce grad is essential between two parallel FSDP process group
|
||||
for group_idx in range(len(self.param_groups)):
|
||||
params = self.param_groups[group_idx]["params"]
|
||||
for param in params:
|
||||
if param.requires_grad:
|
||||
reduce_tensor(tensor=param.grad, parallel_mode=ParallelMode.ZERO3_DP)
|
||||
|
||||
# compute norm
|
||||
found_inf = False
|
||||
norm_groups = []
|
||||
for group_idx in range(len(self.param_groups)):
|
||||
params = self.param_groups[group_idx]["params"]
|
||||
gradients = [p.grad for p in params]
|
||||
norm_group = compute_norm(
|
||||
gradients=gradients,
|
||||
parameters=params,
|
||||
last_stage=True
|
||||
)
|
||||
if norm_group == -1:
|
||||
found_inf = True
|
||||
break
|
||||
norm_groups.append(norm_group)
|
||||
|
||||
loss_scale = float(self.loss_scale.item()) # backup
|
||||
self.grad_scaler.update(found_inf)
|
||||
if found_inf:
|
||||
if gpc.is_rank_for_log():
|
||||
logger.warning("Overflow occurs, please check it.")
|
||||
self.zero_grad()
|
||||
return False, None
|
||||
|
||||
if self._clip_grad_norm > 0:
|
||||
global_norm = sum(norm_groups) ** 0.5
|
||||
|
||||
# unscale
|
||||
for group_idx in range(len(self.param_groups)):
|
||||
params = self.param_groups[group_idx]["params"]
|
||||
for p in params:
|
||||
self._unscale_and_clip_grads(p.grad, global_norm, loss_scale)
|
||||
|
||||
self.optim.step()
|
||||
self.zero_grad()
|
||||
|
||||
return True, [global_norm / loss_scale]
|
||||
|
||||
def clip_grad_norm(self, model, max_norm):
|
||||
# will conduct in the step()
|
||||
pass
|
||||
|
||||
#########################
|
||||
# utils from hybirdzero #
|
||||
#########################
|
||||
|
||||
def _unscale_and_clip_grads(self, grad, total_norm, loss_scale):
|
||||
# compute combined scale factor for this group
|
||||
combined_scale = loss_scale
|
||||
|
||||
if self._clip_grad_norm > 0.0:
|
||||
# norm is in fact norm*scale
|
||||
clip = ((total_norm / loss_scale) + 1e-6) / self._clip_grad_norm
|
||||
if clip > 1.0:
|
||||
combined_scale = clip * loss_scale
|
||||
|
||||
# for grad in grad_groups_flat:
|
||||
grad.data.mul_(1.0 / combined_scale)
|
||||
|
||||
def state_dict(self):
|
||||
states = {}
|
||||
grad_scaler = self.grad_scaler.state_dict()
|
||||
states["grad_scaler"] = grad_scaler
|
||||
optim_states = self.optim.state_dict()
|
||||
states["base_optim_states"] = optim_states
|
||||
|
||||
return states
|
||||
|
||||
def load_state_dict(self, states):
|
||||
assert "grad_scaler" in states, "Not found grad_scaler state!"
|
||||
grad_scaler = states["grad_scaler"]
|
||||
self.grad_scaler.load_state_dict(grad_scaler)
|
||||
optim_states = states["base_optim_states"]
|
||||
self.optim.load_state_dict(optim_states)
|
||||
|
||||
|
||||
class HybridZeroOptimizer(BaseOptimizer):
|
||||
"""
|
||||
Hybrid Zero Optimizer.
|
||||
|
|
|
@ -6,6 +6,7 @@ from .training_internlm import (
|
|||
initialize_optimizer,
|
||||
load_new_batch,
|
||||
record_current_batch_training_metrics,
|
||||
warp_FSDP_model,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
|
@ -16,4 +17,5 @@ __all__ = [
|
|||
"initialize_optimizer",
|
||||
"load_new_batch",
|
||||
"record_current_batch_training_metrics",
|
||||
"warp_FSDP_model",
|
||||
]
|
||||
|
|
|
@ -28,7 +28,7 @@ from internlm.monitor import set_env_var
|
|||
from internlm.monitor.monitor import monitor_manager as mm
|
||||
from internlm.solver.beta2_scheduler import Beta2Scheduler
|
||||
from internlm.solver.lr_scheduler import FineTuneCosineAnnealingWarmupLR
|
||||
from internlm.solver.optimizer import HybridZeroOptimizer
|
||||
from internlm.solver.optimizer import HybridZeroOptimizer, FSDPadaptOptimizer
|
||||
from internlm.solver.optimizer.utils import ParamBcastSyncHandler
|
||||
from internlm.utils.common import DummyProfile
|
||||
from internlm.utils.logger import get_logger
|
||||
|
@ -40,6 +40,24 @@ from internlm.utils.parallel import (
|
|||
)
|
||||
from internlm.utils.registry import MODEL_INITIALIZER
|
||||
|
||||
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
|
||||
from torch.distributed.fsdp.fully_sharded_data_parallel import (
|
||||
CPUOffload,
|
||||
BackwardPrefetch,
|
||||
ShardingStrategy,
|
||||
MixedPrecision,
|
||||
BackwardPrefetch,
|
||||
)
|
||||
from torch.distributed.fsdp.wrap import (
|
||||
size_based_auto_wrap_policy,
|
||||
transformer_auto_wrap_policy,
|
||||
enable_wrap,
|
||||
wrap,
|
||||
)
|
||||
import functools
|
||||
from internlm.model.modeling_internlm import PackedFlashBaseLayer1D, PackedFlashInternLm1D
|
||||
|
||||
|
||||
logger = get_logger(__file__)
|
||||
|
||||
|
||||
|
@ -83,6 +101,30 @@ def initialize_model():
|
|||
return model
|
||||
|
||||
|
||||
def warp_FSDP_model(model: Union[nn.Module, nn.ModuleList]):
|
||||
if gpc.config.parallel.use_fsdp:
|
||||
transformer_wrap_policy = functools.partial(
|
||||
transformer_auto_wrap_policy,
|
||||
transformer_layer_cls = {PackedFlashBaseLayer1D, PackedFlashInternLm1D}
|
||||
)
|
||||
mx = MixedPrecision(
|
||||
param_dtype=gpc.config.model.dtype, reduce_dtype=gpc.config.model.dtype,
|
||||
buffer_dtype=gpc.config.model.dtype, keep_low_precision_grads=True)
|
||||
grp = gpc.get_group(ParallelMode.ZERO1)
|
||||
model = FSDP(module=model,
|
||||
process_group=grp,
|
||||
sharding_strategy=ShardingStrategy.FULL_SHARD,
|
||||
auto_wrap_policy=transformer_wrap_policy,
|
||||
forward_prefetch=True,
|
||||
backward_prefetch=BackwardPrefetch.BACKWARD_PRE,
|
||||
#cpu_offload=CPUOfload(offload_params=True)
|
||||
#mixed_precision=mx,
|
||||
#device_id=torch.cuda.current_device()
|
||||
)
|
||||
|
||||
return model
|
||||
|
||||
|
||||
def initialize_optimizer(model: Union[nn.Module, nn.ModuleList]):
|
||||
"""
|
||||
Initialize optimizer.
|
||||
|
@ -105,12 +147,19 @@ def initialize_optimizer(model: Union[nn.Module, nn.ModuleList]):
|
|||
eps=adam_cfg.adam_eps,
|
||||
)
|
||||
|
||||
optimizer = HybridZeroOptimizer(
|
||||
naive_optimizer,
|
||||
grad_scal_cfg=gpc.config.grad_scaler,
|
||||
zero_cfg=gpc.config.hybrid_zero_optimizer,
|
||||
param_bcast_sync_handler=param_bcast_sync_handler,
|
||||
)
|
||||
if not gpc.config.parallel.use_fsdp:
|
||||
optimizer = HybridZeroOptimizer(
|
||||
naive_optimizer,
|
||||
grad_scal_cfg=gpc.config.grad_scaler,
|
||||
zero_cfg=gpc.config.hybrid_zero_optimizer,
|
||||
param_bcast_sync_handler=param_bcast_sync_handler,
|
||||
)
|
||||
else:
|
||||
optimizer = FSDPadaptOptimizer(
|
||||
naive_optimizer,
|
||||
grad_scal_cfg=gpc.config.grad_scaler,
|
||||
zero_cfg=gpc.config.hybrid_zero_optimizer,
|
||||
)
|
||||
|
||||
beta2_scheduler = Beta2Scheduler(optimizer=naive_optimizer, **gpc.config.beta2_scheduler)
|
||||
|
||||
|
|
|
@ -55,6 +55,26 @@ def get_model_topology(model):
|
|||
return topos
|
||||
|
||||
|
||||
def get_state_dict(model):
|
||||
"""
|
||||
Only used for FSDP module saving.
|
||||
It's a warper of model.state_dict() and with the context of 'FSDP.state_dict_type', the sharded parameter
|
||||
(saved as model.flat_param_xx in sharded FSDP module) will be gathered at every gpu.
|
||||
'offload_to_cpu' means that the model states are to be offloaded to cpu chunk by chunk, avoiding OOM in gpu
|
||||
|
||||
"""
|
||||
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
|
||||
from torch.distributed.fsdp import FullStateDictConfig, StateDictType# , FullOptimStateDictConfig
|
||||
|
||||
# TODO: rank0_only can save memory for non-rank0 gpu, but when tp is enabled, model saving will left some parameters
|
||||
save_policy = FullStateDictConfig(offload_to_cpu=True, rank0_only=False)
|
||||
with FSDP.state_dict_type(
|
||||
model, StateDictType.FULL_STATE_DICT, save_policy):
|
||||
states = model.state_dict()
|
||||
|
||||
return states
|
||||
|
||||
|
||||
def save_model_checkpoint(folder, model):
|
||||
"""
|
||||
Save the model according to the relationship between tp and dp. The principle is that the data of each tp
|
||||
|
@ -69,7 +89,11 @@ def save_model_checkpoint(folder, model):
|
|||
model: The model to be saved
|
||||
"""
|
||||
|
||||
states = model.state_dict()
|
||||
if gpc.config.parallel.use_fsdp:
|
||||
states = get_state_dict(model)
|
||||
else:
|
||||
states = model.state_dict()
|
||||
|
||||
topo = get_model_topology(model)
|
||||
|
||||
if folder is not None:
|
||||
|
|
4
train.py
4
train.py
|
@ -28,6 +28,7 @@ from internlm.train import (
|
|||
initialize_optimizer,
|
||||
load_new_batch,
|
||||
record_current_batch_training_metrics,
|
||||
warp_FSDP_model,
|
||||
)
|
||||
from internlm.utils.common import (
|
||||
BatchSkipper,
|
||||
|
@ -123,6 +124,9 @@ def main(args):
|
|||
# Loading model weights must be done before zero is initialized.
|
||||
ckpt_manager.try_load_model(current_time)
|
||||
|
||||
# if fsdp enabled, warp the model
|
||||
model = warp_FSDP_model(model)
|
||||
|
||||
optimizer, beta2_scheduler, lr_scheduler = initialize_optimizer(model=model)
|
||||
|
||||
# Loading other persistent training states.
|
||||
|
|
Loading…
Reference in New Issue