feat(fsdp): add training option for fsdp

pull/292/head
zaglc 2023-09-04 18:01:30 +08:00
parent c516602e9a
commit 85c6ed6473
11 changed files with 306 additions and 19 deletions

View File

@ -7,22 +7,23 @@ MLP_RATIO = 8 / 3
NUM_LAYER = 32
VOCAB_SIZE = 103168
MODEL_ONLY_FOLDER = "local:llm_ckpts/xxxx"
MODEL_ONLY_FOLDER = "local:llm_ckpts/20"
# Ckpt folder format:
# fs: 'local:/mnt/nfs/XXX'
SAVE_CKPT_FOLDER = "local:llm_ckpts"
LOAD_CKPT_FOLDER = "local:llm_ckpts/49"
LOAD_CKPT_FOLDER = "local:llm_ckpts/20"
# boto3 Ckpt folder format:
# import os
# BOTO3_IP = os.environ["BOTO3_IP"] # boto3 bucket endpoint
# SAVE_CKPT_FOLDER = f"boto3:s3://model_weights.{BOTO3_IP}/internlm"
# LOAD_CKPT_FOLDER = f"boto3:s3://model_weights.{BOTO3_IP}/internlm/snapshot/1/"
CHECKPOINT_EVERY = 50
CHECKPOINT_EVERY = 20
ckpt = dict(
enable_save_ckpt=False, # enable ckpt save.
save_ckpt_folder=SAVE_CKPT_FOLDER, # Path to save training ckpt.
# load_ckpt_folder=LOAD_CKPT_FOLDER, # Ckpt path to resume training(load weights and scheduler/context states).
load_ckpt_folder=LOAD_CKPT_FOLDER, # Ckpt path to resume training(load weights and scheduler/context states).
load_given_ckpt = False,
# load_model_only_folder=MODEL_ONLY_FOLDER, # Path to initialize with given model weights.
load_optimizer=True, # Wheter to load optimizer states when continuing training.
checkpoint_every=CHECKPOINT_EVERY,
@ -32,7 +33,7 @@ ckpt = dict(
oss_snapshot_freq=int(CHECKPOINT_EVERY / 2), # snapshot ckpt save frequency.
)
TRAIN_FOLDER = "/path/to/dataset"
TRAIN_FOLDER = "../../train_data"#"/path/to/dataset"
VALID_FOLDER = "/path/to/dataset"
data = dict(
seq_len=SEQ_LEN,
@ -50,7 +51,7 @@ data = dict(
rampup_batch_size="",
# Datasets with less than 50 rows will be discarded
min_length=50,
# train_folder=TRAIN_FOLDER,
train_folder=TRAIN_FOLDER,
# valid_folder=VALID_FOLDER,
)
@ -111,7 +112,7 @@ beta2_scheduler = dict(
)
model = dict(
checkpoint=False, # The proportion of layers for activation aheckpointing, the optional value are True/False/[0-1]
checkpoint=True, # The proportion of layers for activation checkpointing, the optional value are True/False/[0-1]
num_attention_heads=NUM_ATTENTION_HEAD,
embed_split_hidden=True,
vocab_size=VOCAB_SIZE,
@ -121,7 +122,7 @@ model = dict(
num_layers=NUM_LAYER,
mlp_ratio=MLP_RATIO,
apply_post_layer_norm=False,
dtype="torch.float16", # Support: "torch.float16", "torch.half", "torch.bfloat16", "torch.float32", "torch.tf32"
dtype="torch.bfloat16", # Support: "torch.float16", "torch.half", "torch.bfloat16", "torch.float32", "torch.tf32"
norm_type="rmsnorm",
layer_norm_epsilon=1e-5,
use_flash_attn=True,
@ -140,9 +141,11 @@ pipeline parallel (dict):
tensor parallel: tensor parallel size, usually the number of GPUs per node.
"""
parallel = dict(
zero1=8,
zero1=-1,
pipeline=dict(size=1, interleaved_overlap=True),
tensor=2,
sequence_parallel=False,
use_fsdp = False,
)
cudnn_deterministic = False

View File

@ -12,6 +12,7 @@ from .process_group_initializer import (
Initializer_Zero1,
ParallelMode,
ProcessGroupInitializer,
Initializer_Zero3_dp,
)
from .random import (
add_seed,
@ -34,6 +35,7 @@ __all__ = [
"Initializer_Pipeline",
"Initializer_Data",
"Initializer_Zero1",
"Initializer_Zero3_dp",
"ProcessGroupInitializer",
"Initializer_Model",
"seed",

View File

@ -462,6 +462,8 @@ class ParallelContext(metaclass=SingletonMeta):
initializers.append(pgroup_initializer.Initializer_Model(*initializer_args))
initializers.append(pgroup_initializer.Initializer_Tensor(*initializer_args))
initializers.append(pgroup_initializer.Initializer_Zero1(*initializer_args))
if self.config.parallel.use_fsdp:
initializers.append(pgroup_initializer.Initializer_Zero3_dp(*initializer_args))
if self.pipeline_parallel_size > 1:
initializers.append(pgroup_initializer.Initializer_Pipeline(*initializer_args))
for initializer in initializers:

View File

@ -31,6 +31,11 @@ class ParallelMode(Enum):
# zero1 parallel
ZERO1 = "zero1"
# zero3-dp parallel
# if fsdp is activated and size of fsdp-parallel-size is less than dp-parallel-size
# then manual communication only happens between inter-fsdp-modules, while intra-modules reduction is done by fsdp
ZERO3_DP = "zero3_dp"
class ProcessGroupInitializer(ABC):
"""An object, knowing the parallelism configuration, that initializes parallel groups.
@ -332,3 +337,62 @@ class Initializer_Zero1(ProcessGroupInitializer):
ranks_in_group = ranks
return local_rank, group_world_size, process_group, cpu_group, ranks_in_group, mode
class Initializer_Zero3_dp(ProcessGroupInitializer):
"""A ProcessGroupInitializer for data parallelism.
Args:
rank (int): The rank of current process.
world_size (int): Size of whole communication world.
data_parallel_size (int): Size of data parallel.
pipeline_parallel_size (int): Size of pipeline parallel.
tensor_parallel_size (int): Size of tensor parallel.
zero1_parallel_size (int): Size of zero1 parallel.
"""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
assert self.data_parallel_size % self.zero1_parallel_size == 0
# the only difference between this initializer and DP_initializer
# when FSDP is enabled, only corresponding pairs are in the same actual DP group due to parameter sharding
# eg: when zero=4 and dp=8
# no fsdp: rank [0-7] share same model paramters, and [0-3], [4-7] are two separate zero group
# fsdp: params of (0, 4), (1, 5), (2, 6), (3, 7) are the same actually
self.data_parallel_size //= self.zero1_parallel_size
self.rank_num_per_dp_group = self.world_size // self.data_parallel_size
assert self.world_size % self.data_parallel_size == 0
def init_dist_group(self, use_cpu: bool = False):
"""Initialize data parallel groups, and assign local_ranks and groups to each gpu.
Returns:
Tuple (local_rank, group_world_size, process_group, ranks_in_group, mode):
A Data parallelism's information tuple.
"""
local_rank = None
ranks_in_group = None
process_group = None
cpu_group = None
group_world_size = None
mode = ParallelMode.ZERO3_DP
for i in range(self.rank_num_per_dp_group):
ranks = [i + j * self.rank_num_per_dp_group for j in range(self.data_parallel_size)]
group = dist.new_group(ranks)
if use_cpu:
group_cpu = dist.new_group(ranks, backend="gloo") if dist.get_backend() != "gloo" else group
else:
group_cpu = None
if self.rank in ranks:
local_rank = ranks.index(self.rank)
group_world_size = len(ranks)
process_group = group
cpu_group = group_cpu
ranks_in_group = ranks
return local_rank, group_world_size, process_group, cpu_group, ranks_in_group, mode

View File

@ -61,6 +61,17 @@ def args_sanity_check():
if "tensor" not in gpc.config.parallel:
gpc.config.parallel._add_item("tensor", 1)
if isinstance(gpc.config.parallel.pipeline, int):
pp = gpc.config.parallel.pipeline
else:
pp = gpc.config.parallel.pipeline.size
if "use_fsdp" not in gpc.config.parallel:
gpc.config.parallel._add_item("use_fsdp", False)
elif gpc.config.parallel.use_fsdp and pp > 1:
logger.warning("FSDP not support when pipeline parallel is enabled, auto-close FSDP")
gpc.config.parallel._add_item("use_fsdp", False)
# processing the data config in gpc
data = gpc.config.data

View File

@ -1,6 +1,6 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
from .hybrid_zero_optim import HybridZeroOptimizer
from .hybrid_zero_optim import HybridZeroOptimizer, FSDPadaptOptimizer
__all__ = ["HybridZeroOptimizer"]
__all__ = ["HybridZeroOptimizer", "FSDPadaptOptimizer"]

View File

@ -80,6 +80,132 @@ class BaseOptimizer(Optimizer):
pass
class FSDPadaptOptimizer(BaseOptimizer):
'''
optimizer for Pytorch FSDP if 'use_fsdp' is True in config file
reserve some necessary components of hybird-optim:
grad_scaler;
grad_clip and unscale;
state_dict and load_state_dict
'''
def __init__(
self,
optimizer: Optimizer,
grad_scal_cfg: Config = None,
zero_cfg: Config = None,
):
super().__init__(optim=optimizer)
# gradient scaler
self.grad_scaler = DynamicGradScaler(
initial_scale=grad_scal_cfg.fp16.initial_scale,
min_scale=grad_scal_cfg.fp16.min_scale,
growth_factor=grad_scal_cfg.growth_factor,
backoff_factor=grad_scal_cfg.backoff_factor,
growth_interval=grad_scal_cfg.fp16.growth_interval,
hysteresis=grad_scal_cfg.hysteresis,
max_scale=grad_scal_cfg.max_scale,
)
# clip gradient
self._clip_grad_norm = zero_cfg.clip_grad_norm
self.use_fsdp = gpc.config.parallel.use_fsdp
@property
def loss_scale(self):
return self.grad_scaler.scale
def backward(self, loss, retain_graph=False):
loss = self.loss_scale * loss
loss.backward(retain_graph=retain_graph)
def step(self):
# in case that fsdp-zero3 size is not equal to dp size
# FSDP module will only reduce gradient within FSDP process group
# so manually reduce grad is essential between two parallel FSDP process group
for group_idx in range(len(self.param_groups)):
params = self.param_groups[group_idx]["params"]
for param in params:
if param.requires_grad:
reduce_tensor(tensor=param.grad, parallel_mode=ParallelMode.ZERO3_DP)
# compute norm
found_inf = False
norm_groups = []
for group_idx in range(len(self.param_groups)):
params = self.param_groups[group_idx]["params"]
gradients = [p.grad for p in params]
norm_group = compute_norm(
gradients=gradients,
parameters=params,
last_stage=True
)
if norm_group == -1:
found_inf = True
break
norm_groups.append(norm_group)
loss_scale = float(self.loss_scale.item()) # backup
self.grad_scaler.update(found_inf)
if found_inf:
if gpc.is_rank_for_log():
logger.warning("Overflow occurs, please check it.")
self.zero_grad()
return False, None
if self._clip_grad_norm > 0:
global_norm = sum(norm_groups) ** 0.5
# unscale
for group_idx in range(len(self.param_groups)):
params = self.param_groups[group_idx]["params"]
for p in params:
self._unscale_and_clip_grads(p.grad, global_norm, loss_scale)
self.optim.step()
self.zero_grad()
return True, [global_norm / loss_scale]
def clip_grad_norm(self, model, max_norm):
# will conduct in the step()
pass
#########################
# utils from hybirdzero #
#########################
def _unscale_and_clip_grads(self, grad, total_norm, loss_scale):
# compute combined scale factor for this group
combined_scale = loss_scale
if self._clip_grad_norm > 0.0:
# norm is in fact norm*scale
clip = ((total_norm / loss_scale) + 1e-6) / self._clip_grad_norm
if clip > 1.0:
combined_scale = clip * loss_scale
# for grad in grad_groups_flat:
grad.data.mul_(1.0 / combined_scale)
def state_dict(self):
states = {}
grad_scaler = self.grad_scaler.state_dict()
states["grad_scaler"] = grad_scaler
optim_states = self.optim.state_dict()
states["base_optim_states"] = optim_states
return states
def load_state_dict(self, states):
assert "grad_scaler" in states, "Not found grad_scaler state!"
grad_scaler = states["grad_scaler"]
self.grad_scaler.load_state_dict(grad_scaler)
optim_states = states["base_optim_states"]
self.optim.load_state_dict(optim_states)
class HybridZeroOptimizer(BaseOptimizer):
"""
Hybrid Zero Optimizer.

View File

@ -6,6 +6,7 @@ from .training_internlm import (
initialize_optimizer,
load_new_batch,
record_current_batch_training_metrics,
warp_FSDP_model,
)
__all__ = [
@ -16,4 +17,5 @@ __all__ = [
"initialize_optimizer",
"load_new_batch",
"record_current_batch_training_metrics",
"warp_FSDP_model",
]

View File

@ -28,7 +28,7 @@ from internlm.monitor import set_env_var
from internlm.monitor.monitor import monitor_manager as mm
from internlm.solver.beta2_scheduler import Beta2Scheduler
from internlm.solver.lr_scheduler import FineTuneCosineAnnealingWarmupLR
from internlm.solver.optimizer import HybridZeroOptimizer
from internlm.solver.optimizer import HybridZeroOptimizer, FSDPadaptOptimizer
from internlm.solver.optimizer.utils import ParamBcastSyncHandler
from internlm.utils.common import DummyProfile
from internlm.utils.logger import get_logger
@ -40,6 +40,24 @@ from internlm.utils.parallel import (
)
from internlm.utils.registry import MODEL_INITIALIZER
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.fully_sharded_data_parallel import (
CPUOffload,
BackwardPrefetch,
ShardingStrategy,
MixedPrecision,
BackwardPrefetch,
)
from torch.distributed.fsdp.wrap import (
size_based_auto_wrap_policy,
transformer_auto_wrap_policy,
enable_wrap,
wrap,
)
import functools
from internlm.model.modeling_internlm import PackedFlashBaseLayer1D, PackedFlashInternLm1D
logger = get_logger(__file__)
@ -83,6 +101,30 @@ def initialize_model():
return model
def warp_FSDP_model(model: Union[nn.Module, nn.ModuleList]):
if gpc.config.parallel.use_fsdp:
transformer_wrap_policy = functools.partial(
transformer_auto_wrap_policy,
transformer_layer_cls = {PackedFlashBaseLayer1D, PackedFlashInternLm1D}
)
mx = MixedPrecision(
param_dtype=gpc.config.model.dtype, reduce_dtype=gpc.config.model.dtype,
buffer_dtype=gpc.config.model.dtype, keep_low_precision_grads=True)
grp = gpc.get_group(ParallelMode.ZERO1)
model = FSDP(module=model,
process_group=grp,
sharding_strategy=ShardingStrategy.FULL_SHARD,
auto_wrap_policy=transformer_wrap_policy,
forward_prefetch=True,
backward_prefetch=BackwardPrefetch.BACKWARD_PRE,
#cpu_offload=CPUOfload(offload_params=True)
#mixed_precision=mx,
#device_id=torch.cuda.current_device()
)
return model
def initialize_optimizer(model: Union[nn.Module, nn.ModuleList]):
"""
Initialize optimizer.
@ -105,12 +147,19 @@ def initialize_optimizer(model: Union[nn.Module, nn.ModuleList]):
eps=adam_cfg.adam_eps,
)
optimizer = HybridZeroOptimizer(
naive_optimizer,
grad_scal_cfg=gpc.config.grad_scaler,
zero_cfg=gpc.config.hybrid_zero_optimizer,
param_bcast_sync_handler=param_bcast_sync_handler,
)
if not gpc.config.parallel.use_fsdp:
optimizer = HybridZeroOptimizer(
naive_optimizer,
grad_scal_cfg=gpc.config.grad_scaler,
zero_cfg=gpc.config.hybrid_zero_optimizer,
param_bcast_sync_handler=param_bcast_sync_handler,
)
else:
optimizer = FSDPadaptOptimizer(
naive_optimizer,
grad_scal_cfg=gpc.config.grad_scaler,
zero_cfg=gpc.config.hybrid_zero_optimizer,
)
beta2_scheduler = Beta2Scheduler(optimizer=naive_optimizer, **gpc.config.beta2_scheduler)

View File

@ -55,6 +55,26 @@ def get_model_topology(model):
return topos
def get_state_dict(model):
"""
Only used for FSDP module saving.
It's a warper of model.state_dict() and with the context of 'FSDP.state_dict_type', the sharded parameter
(saved as model.flat_param_xx in sharded FSDP module) will be gathered at every gpu.
'offload_to_cpu' means that the model states are to be offloaded to cpu chunk by chunk, avoiding OOM in gpu
"""
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import FullStateDictConfig, StateDictType# , FullOptimStateDictConfig
# TODO: rank0_only can save memory for non-rank0 gpu, but when tp is enabled, model saving will left some parameters
save_policy = FullStateDictConfig(offload_to_cpu=True, rank0_only=False)
with FSDP.state_dict_type(
model, StateDictType.FULL_STATE_DICT, save_policy):
states = model.state_dict()
return states
def save_model_checkpoint(folder, model):
"""
Save the model according to the relationship between tp and dp. The principle is that the data of each tp
@ -69,7 +89,11 @@ def save_model_checkpoint(folder, model):
model: The model to be saved
"""
states = model.state_dict()
if gpc.config.parallel.use_fsdp:
states = get_state_dict(model)
else:
states = model.state_dict()
topo = get_model_topology(model)
if folder is not None:

View File

@ -28,6 +28,7 @@ from internlm.train import (
initialize_optimizer,
load_new_batch,
record_current_batch_training_metrics,
warp_FSDP_model,
)
from internlm.utils.common import (
BatchSkipper,
@ -123,6 +124,9 @@ def main(args):
# Loading model weights must be done before zero is initialized.
ckpt_manager.try_load_model(current_time)
# if fsdp enabled, warp the model
model = warp_FSDP_model(model)
optimizer, beta2_scheduler, lr_scheduler = initialize_optimizer(model=model)
# Loading other persistent training states.