feat(fsdp): add training option for fsdp

pull/292/head
zaglc 2023-09-04 18:01:30 +08:00
parent c516602e9a
commit 85c6ed6473
11 changed files with 306 additions and 19 deletions

View File

@@ -7,22 +7,23 @@ MLP_RATIO = 8 / 3
 NUM_LAYER = 32
 VOCAB_SIZE = 103168
-MODEL_ONLY_FOLDER = "local:llm_ckpts/xxxx"
+MODEL_ONLY_FOLDER = "local:llm_ckpts/20"
 # Ckpt folder format:
 # fs: 'local:/mnt/nfs/XXX'
 SAVE_CKPT_FOLDER = "local:llm_ckpts"
-LOAD_CKPT_FOLDER = "local:llm_ckpts/49"
+LOAD_CKPT_FOLDER = "local:llm_ckpts/20"
 # boto3 Ckpt folder format:
 # import os
 # BOTO3_IP = os.environ["BOTO3_IP"]  # boto3 bucket endpoint
 # SAVE_CKPT_FOLDER = f"boto3:s3://model_weights.{BOTO3_IP}/internlm"
 # LOAD_CKPT_FOLDER = f"boto3:s3://model_weights.{BOTO3_IP}/internlm/snapshot/1/"
-CHECKPOINT_EVERY = 50
+CHECKPOINT_EVERY = 20
 ckpt = dict(
     enable_save_ckpt=False,  # enable ckpt save.
     save_ckpt_folder=SAVE_CKPT_FOLDER,  # Path to save training ckpt.
-    # load_ckpt_folder=LOAD_CKPT_FOLDER,  # Ckpt path to resume training (load weights and scheduler/context states).
+    load_ckpt_folder=LOAD_CKPT_FOLDER,  # Ckpt path to resume training (load weights and scheduler/context states).
+    load_given_ckpt=False,
     # load_model_only_folder=MODEL_ONLY_FOLDER,  # Path to initialize with given model weights.
     load_optimizer=True,  # Whether to load optimizer states when continuing training.
     checkpoint_every=CHECKPOINT_EVERY,
@@ -32,7 +33,7 @@ ckpt = dict(
     oss_snapshot_freq=int(CHECKPOINT_EVERY / 2),  # snapshot ckpt save frequency.
 )
-TRAIN_FOLDER = "/path/to/dataset"
+TRAIN_FOLDER = "../../train_data"  # "/path/to/dataset"
 VALID_FOLDER = "/path/to/dataset"
 data = dict(
     seq_len=SEQ_LEN,
@@ -50,7 +51,7 @@ data = dict(
     rampup_batch_size="",
     # Datasets with less than 50 rows will be discarded
     min_length=50,
-    # train_folder=TRAIN_FOLDER,
+    train_folder=TRAIN_FOLDER,
     # valid_folder=VALID_FOLDER,
 )
@@ -111,7 +112,7 @@ beta2_scheduler = dict(
 )
 model = dict(
-    checkpoint=False,  # The proportion of layers for activation aheckpointing, the optional value are True/False/[0-1]
+    checkpoint=True,  # The proportion of layers for activation checkpointing, the optional values are True/False/[0-1]
     num_attention_heads=NUM_ATTENTION_HEAD,
     embed_split_hidden=True,
     vocab_size=VOCAB_SIZE,
@@ -121,7 +122,7 @@ model = dict(
     num_layers=NUM_LAYER,
     mlp_ratio=MLP_RATIO,
     apply_post_layer_norm=False,
-    dtype="torch.float16",  # Support: "torch.float16", "torch.half", "torch.bfloat16", "torch.float32", "torch.tf32"
+    dtype="torch.bfloat16",  # Support: "torch.float16", "torch.half", "torch.bfloat16", "torch.float32", "torch.tf32"
     norm_type="rmsnorm",
     layer_norm_epsilon=1e-5,
     use_flash_attn=True,
@@ -140,9 +141,11 @@ pipeline parallel (dict):
     tensor parallel: tensor parallel size, usually the number of GPUs per node.
 """
 parallel = dict(
-    zero1=8,
+    zero1=-1,
     pipeline=dict(size=1, interleaved_overlap=True),
+    tensor=2,
     sequence_parallel=False,
+    use_fsdp=False,
 )
 cudnn_deterministic = False
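For context, a minimal sketch of what an FSDP-enabled variant of this `parallel` block might look like (illustrative values only; of the settings shown, only `use_fsdp` and the pipeline restriction come from this commit):

parallel = dict(
    zero1=-1,  # shard across the full data-parallel group; a positive value shards within sub-groups instead
    pipeline=dict(size=1, interleaved_overlap=True),  # FSDP is auto-disabled when pipeline size > 1
    tensor=1,
    sequence_parallel=False,
    use_fsdp=True,  # wrap the model with PyTorch FSDP and use FSDPadaptOptimizer
)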

View File

@@ -12,6 +12,7 @@ from .process_group_initializer import (
     Initializer_Zero1,
     ParallelMode,
     ProcessGroupInitializer,
+    Initializer_Zero3_dp,
 )
 from .random import (
     add_seed,
@@ -34,6 +35,7 @@ __all__ = [
     "Initializer_Pipeline",
     "Initializer_Data",
     "Initializer_Zero1",
+    "Initializer_Zero3_dp",
     "ProcessGroupInitializer",
     "Initializer_Model",
     "seed",

View File

@@ -462,6 +462,8 @@ class ParallelContext(metaclass=SingletonMeta):
         initializers.append(pgroup_initializer.Initializer_Model(*initializer_args))
         initializers.append(pgroup_initializer.Initializer_Tensor(*initializer_args))
         initializers.append(pgroup_initializer.Initializer_Zero1(*initializer_args))
+        if self.config.parallel.use_fsdp:
+            initializers.append(pgroup_initializer.Initializer_Zero3_dp(*initializer_args))
         if self.pipeline_parallel_size > 1:
             initializers.append(pgroup_initializer.Initializer_Pipeline(*initializer_args))
         for initializer in initializers:

View File

@@ -31,6 +31,11 @@ class ParallelMode(Enum):
     # zero1 parallel
     ZERO1 = "zero1"
+
+    # zero3-dp parallel
+    # If FSDP is enabled and the FSDP parallel size is smaller than the data parallel size,
+    # gradients are reduced manually across FSDP process groups, while reduction inside each
+    # FSDP group is handled by FSDP itself.
+    ZERO3_DP = "zero3_dp"

 class ProcessGroupInitializer(ABC):
     """An object, knowing the parallelism configuration, that initializes parallel groups.
@@ -332,3 +337,62 @@ class Initializer_Zero1(ProcessGroupInitializer):
                 ranks_in_group = ranks
         return local_rank, group_world_size, process_group, cpu_group, ranks_in_group, mode
+
+
+class Initializer_Zero3_dp(ProcessGroupInitializer):
+    """A ProcessGroupInitializer for the data-parallel groups used together with FSDP (ZeRO3-DP).
+
+    Args:
+        rank (int): The rank of current process.
+        world_size (int): Size of whole communication world.
+        data_parallel_size (int): Size of data parallel.
+        pipeline_parallel_size (int): Size of pipeline parallel.
+        tensor_parallel_size (int): Size of tensor parallel.
+        zero1_parallel_size (int): Size of zero1 parallel.
+    """
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        assert self.data_parallel_size % self.zero1_parallel_size == 0
+
+        # The only difference between this initializer and the DP initializer:
+        # when FSDP is enabled, only the corresponding rank pairs belong to the same effective DP group,
+        # because parameters are sharded inside each FSDP group.
+        # e.g. with zero1=4 and dp=8:
+        #   no fsdp: ranks [0-7] hold the same model parameters, and [0-3], [4-7] form two separate zero groups
+        #   fsdp:    ranks (0, 4), (1, 5), (2, 6), (3, 7) hold the same parameter shards
+        self.data_parallel_size //= self.zero1_parallel_size
+        self.rank_num_per_dp_group = self.world_size // self.data_parallel_size
+
+        assert self.world_size % self.data_parallel_size == 0
+
+    def init_dist_group(self, use_cpu: bool = False):
+        """Initialize data parallel groups, and assign local_ranks and groups to each gpu.
+
+        Returns:
+            Tuple (local_rank, group_world_size, process_group, ranks_in_group, mode):
+                A data parallelism information tuple.
+        """
+        local_rank = None
+        ranks_in_group = None
+        process_group = None
+        cpu_group = None
+        group_world_size = None
+        mode = ParallelMode.ZERO3_DP
+
+        for i in range(self.rank_num_per_dp_group):
+            ranks = [i + j * self.rank_num_per_dp_group for j in range(self.data_parallel_size)]
+            group = dist.new_group(ranks)
+            if use_cpu:
+                group_cpu = dist.new_group(ranks, backend="gloo") if dist.get_backend() != "gloo" else group
+            else:
+                group_cpu = None
+
+            if self.rank in ranks:
+                local_rank = ranks.index(self.rank)
+                group_world_size = len(ranks)
+                process_group = group
+                cpu_group = group_cpu
+                ranks_in_group = ranks
+
+        return local_rank, group_world_size, process_group, cpu_group, ranks_in_group, mode
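To make the rank grouping above concrete, here is a small standalone sketch of the same index arithmetic (pure Python with assumed sizes; no process groups are created):

# Assumed: 8 GPUs, data_parallel_size=8, zero1_parallel_size=4 (the FSDP shard group size).
world_size = 8
data_parallel_size = 8
zero1_parallel_size = 4

zero3_dp_size = data_parallel_size // zero1_parallel_size  # 2
rank_num_per_dp_group = world_size // zero3_dp_size        # 4

groups = [
    [i + j * rank_num_per_dp_group for j in range(zero3_dp_size)]
    for i in range(rank_num_per_dp_group)
]
print(groups)  # [[0, 4], [1, 5], [2, 6], [3, 7]] -> each pair holds the same parameter shard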

View File

@@ -61,6 +61,17 @@ def args_sanity_check():
     if "tensor" not in gpc.config.parallel:
         gpc.config.parallel._add_item("tensor", 1)
+
+    if isinstance(gpc.config.parallel.pipeline, int):
+        pp = gpc.config.parallel.pipeline
+    else:
+        pp = gpc.config.parallel.pipeline.size
+
+    if "use_fsdp" not in gpc.config.parallel:
+        gpc.config.parallel._add_item("use_fsdp", False)
+    elif gpc.config.parallel.use_fsdp and pp > 1:
+        logger.warning("FSDP is not supported when pipeline parallelism is enabled; disabling FSDP.")
+        gpc.config.parallel._add_item("use_fsdp", False)
 
     # processing the data config in gpc
     data = gpc.config.data

View File

@@ -1,6 +1,6 @@
 #!/usr/bin/env python
 # -*- encoding: utf-8 -*-
-from .hybrid_zero_optim import HybridZeroOptimizer
+from .hybrid_zero_optim import HybridZeroOptimizer, FSDPadaptOptimizer
 
-__all__ = ["HybridZeroOptimizer"]
+__all__ = ["HybridZeroOptimizer", "FSDPadaptOptimizer"]

View File

@@ -80,6 +80,132 @@ class BaseOptimizer(Optimizer):
         pass
+
+
+class FSDPadaptOptimizer(BaseOptimizer):
+    """
+    Optimizer wrapper for PyTorch FSDP, used when 'use_fsdp' is True in the config file.
+    It keeps the necessary components of the hybrid-zero optimizer:
+        grad_scaler;
+        grad clipping and unscaling;
+        state_dict and load_state_dict.
+    """
+
+    def __init__(
+        self,
+        optimizer: Optimizer,
+        grad_scal_cfg: Config = None,
+        zero_cfg: Config = None,
+    ):
+        super().__init__(optim=optimizer)
+
+        # gradient scaler
+        self.grad_scaler = DynamicGradScaler(
+            initial_scale=grad_scal_cfg.fp16.initial_scale,
+            min_scale=grad_scal_cfg.fp16.min_scale,
+            growth_factor=grad_scal_cfg.growth_factor,
+            backoff_factor=grad_scal_cfg.backoff_factor,
+            growth_interval=grad_scal_cfg.fp16.growth_interval,
+            hysteresis=grad_scal_cfg.hysteresis,
+            max_scale=grad_scal_cfg.max_scale,
+        )
+
+        # gradient clipping
+        self._clip_grad_norm = zero_cfg.clip_grad_norm
+        self.use_fsdp = gpc.config.parallel.use_fsdp
+
+    @property
+    def loss_scale(self):
+        return self.grad_scaler.scale
+
+    def backward(self, loss, retain_graph=False):
+        loss = self.loss_scale * loss
+        loss.backward(retain_graph=retain_graph)
+
+    def step(self):
+        # In case the fsdp-zero3 size is smaller than the dp size, the FSDP module only reduces
+        # gradients within its own FSDP process group, so gradients must be reduced manually
+        # across the parallel FSDP process groups.
+        for group_idx in range(len(self.param_groups)):
+            params = self.param_groups[group_idx]["params"]
+            for param in params:
+                if param.requires_grad:
+                    reduce_tensor(tensor=param.grad, parallel_mode=ParallelMode.ZERO3_DP)
+
+        # compute norm
+        found_inf = False
+        norm_groups = []
+        for group_idx in range(len(self.param_groups)):
+            params = self.param_groups[group_idx]["params"]
+            gradients = [p.grad for p in params]
+            norm_group = compute_norm(gradients=gradients, parameters=params, last_stage=True)
+            if norm_group == -1:
+                found_inf = True
+                break
+            norm_groups.append(norm_group)
+
+        loss_scale = float(self.loss_scale.item())  # backup
+        self.grad_scaler.update(found_inf)
+        if found_inf:
+            if gpc.is_rank_for_log():
+                logger.warning("Overflow occurs, please check it.")
+            self.zero_grad()
+            return False, None
+
+        # unscale and clip gradients
+        global_norm = 0
+        if self._clip_grad_norm > 0:
+            global_norm = sum(norm_groups) ** 0.5
+        for group_idx in range(len(self.param_groups)):
+            params = self.param_groups[group_idx]["params"]
+            for p in params:
+                self._unscale_and_clip_grads(p.grad, global_norm, loss_scale)
+
+        self.optim.step()
+        self.zero_grad()
+
+        return True, [global_norm / loss_scale]
+
+    def clip_grad_norm(self, model, max_norm):
+        # gradient clipping is performed inside step()
+        pass
+
+    ##########################
+    # utils from hybrid zero #
+    ##########################
+
+    def _unscale_and_clip_grads(self, grad, total_norm, loss_scale):
+        # compute the combined scale factor for this group
+        combined_scale = loss_scale
+
+        if self._clip_grad_norm > 0.0:
+            # note: total_norm is in fact norm * loss_scale
+            clip = ((total_norm / loss_scale) + 1e-6) / self._clip_grad_norm
+            if clip > 1.0:
+                combined_scale = clip * loss_scale
+
+        grad.data.mul_(1.0 / combined_scale)
+
+    def state_dict(self):
+        states = {}
+        grad_scaler = self.grad_scaler.state_dict()
+        states["grad_scaler"] = grad_scaler
+        optim_states = self.optim.state_dict()
+        states["base_optim_states"] = optim_states
+        return states
+
+    def load_state_dict(self, states):
+        assert "grad_scaler" in states, "grad_scaler state not found!"
+        grad_scaler = states["grad_scaler"]
+        self.grad_scaler.load_state_dict(grad_scaler)
+        optim_states = states["base_optim_states"]
+        self.optim.load_state_dict(optim_states)
+
+
 class HybridZeroOptimizer(BaseOptimizer):
     """
     Hybrid Zero Optimizer.
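A quick numeric sketch of the `_unscale_and_clip_grads` arithmetic above (assumed values; the norm passed in already carries the loss scale, as computed in `step()`):

loss_scale = 1024.0
total_norm = 4096.0         # scaled norm, i.e. true grad norm 4.0 * loss_scale
clip_grad_norm = 1.0

combined_scale = loss_scale
clip = ((total_norm / loss_scale) + 1e-6) / clip_grad_norm  # ~4.0
if clip > 1.0:
    combined_scale = clip * loss_scale                      # ~4096.0

# grad.data.mul_(1.0 / combined_scale) then both removes the loss scale and
# clips the global gradient norm down to roughly clip_grad_norm (1.0 here).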

View File

@@ -6,6 +6,7 @@ from .training_internlm import (
     initialize_optimizer,
     load_new_batch,
     record_current_batch_training_metrics,
+    warp_FSDP_model,
 )
 
 __all__ = [
@@ -16,4 +17,5 @@ __all__ = [
     "initialize_optimizer",
     "load_new_batch",
     "record_current_batch_training_metrics",
+    "warp_FSDP_model",
 ]

View File

@@ -28,7 +28,7 @@ from internlm.monitor import set_env_var
 from internlm.monitor.monitor import monitor_manager as mm
 from internlm.solver.beta2_scheduler import Beta2Scheduler
 from internlm.solver.lr_scheduler import FineTuneCosineAnnealingWarmupLR
-from internlm.solver.optimizer import HybridZeroOptimizer
+from internlm.solver.optimizer import HybridZeroOptimizer, FSDPadaptOptimizer
 from internlm.solver.optimizer.utils import ParamBcastSyncHandler
 from internlm.utils.common import DummyProfile
 from internlm.utils.logger import get_logger
@@ -40,6 +40,24 @@ from internlm.utils.parallel import (
 )
 from internlm.utils.registry import MODEL_INITIALIZER
+
+import functools
+
+from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
+from torch.distributed.fsdp.fully_sharded_data_parallel import (
+    BackwardPrefetch,
+    CPUOffload,
+    MixedPrecision,
+    ShardingStrategy,
+)
+from torch.distributed.fsdp.wrap import (
+    enable_wrap,
+    size_based_auto_wrap_policy,
+    transformer_auto_wrap_policy,
+    wrap,
+)
+
+from internlm.model.modeling_internlm import PackedFlashBaseLayer1D, PackedFlashInternLm1D
 
 logger = get_logger(__file__)
@@ -83,6 +101,30 @@ def initialize_model():
     return model
+
+
+def warp_FSDP_model(model: Union[nn.Module, nn.ModuleList]):
+    if gpc.config.parallel.use_fsdp:
+        transformer_wrap_policy = functools.partial(
+            transformer_auto_wrap_policy,
+            transformer_layer_cls={PackedFlashBaseLayer1D, PackedFlashInternLm1D},
+        )
+        mx = MixedPrecision(
+            param_dtype=gpc.config.model.dtype,
+            reduce_dtype=gpc.config.model.dtype,
+            buffer_dtype=gpc.config.model.dtype,
+            keep_low_precision_grads=True,
+        )
+        grp = gpc.get_group(ParallelMode.ZERO1)
+        model = FSDP(
+            module=model,
+            process_group=grp,
+            sharding_strategy=ShardingStrategy.FULL_SHARD,
+            auto_wrap_policy=transformer_wrap_policy,
+            forward_prefetch=True,
+            backward_prefetch=BackwardPrefetch.BACKWARD_PRE,
+            # cpu_offload=CPUOffload(offload_params=True),
+            # mixed_precision=mx,
+            # device_id=torch.cuda.current_device(),
+        )
+
+    return model
+
 
 def initialize_optimizer(model: Union[nn.Module, nn.ModuleList]):
     """
     Initialize optimizer.
@@ -105,12 +147,19 @@ def initialize_optimizer(model: Union[nn.Module, nn.ModuleList]):
         eps=adam_cfg.adam_eps,
     )
-    optimizer = HybridZeroOptimizer(
-        naive_optimizer,
-        grad_scal_cfg=gpc.config.grad_scaler,
-        zero_cfg=gpc.config.hybrid_zero_optimizer,
-        param_bcast_sync_handler=param_bcast_sync_handler,
-    )
+    if not gpc.config.parallel.use_fsdp:
+        optimizer = HybridZeroOptimizer(
+            naive_optimizer,
+            grad_scal_cfg=gpc.config.grad_scaler,
+            zero_cfg=gpc.config.hybrid_zero_optimizer,
+            param_bcast_sync_handler=param_bcast_sync_handler,
+        )
+    else:
+        optimizer = FSDPadaptOptimizer(
+            naive_optimizer,
+            grad_scal_cfg=gpc.config.grad_scaler,
+            zero_cfg=gpc.config.hybrid_zero_optimizer,
+        )
 
     beta2_scheduler = Beta2Scheduler(optimizer=naive_optimizer, **gpc.config.beta2_scheduler)

View File

@@ -55,6 +55,26 @@ def get_model_topology(model):
     return topos
+
+
+def get_state_dict(model):
+    """
+    Only used for FSDP module saving.
+    It is a wrapper around model.state_dict(); under the 'FSDP.state_dict_type' context, the sharded
+    parameters (stored as model.flat_param_xx inside a sharded FSDP module) are gathered on every gpu.
+    'offload_to_cpu' means the model states are offloaded to cpu chunk by chunk, to avoid gpu OOM.
+    """
+    from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
+    from torch.distributed.fsdp import FullStateDictConfig, StateDictType  # , FullOptimStateDictConfig
+
+    # TODO: rank0_only would save memory on non-rank0 gpus, but when tp is enabled,
+    # model saving leaves out some parameters.
+    save_policy = FullStateDictConfig(offload_to_cpu=True, rank0_only=False)
+    with FSDP.state_dict_type(model, StateDictType.FULL_STATE_DICT, save_policy):
+        states = model.state_dict()
+
+    return states
+
 
 def save_model_checkpoint(folder, model):
     """
     Save the model according to the relationship between tp and dp. The principle is that the data of each tp
@@ -69,7 +89,11 @@ def save_model_checkpoint(folder, model):
         model: The model to be saved
     """
-    states = model.state_dict()
+    if gpc.config.parallel.use_fsdp:
+        states = get_state_dict(model)
+    else:
+        states = model.state_dict()
+
     topo = get_model_topology(model)
 
     if folder is not None:
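The diff only adds the save path; a rough sketch of the matching load side under the same FULL_STATE_DICT policy (illustrative, not part of this commit; the helper name is hypothetical):

import torch
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import FullStateDictConfig, StateDictType


def load_full_state_dict(model, ckpt_path):
    # Hypothetical helper: load an unsharded checkpoint into an FSDP-wrapped model.
    # Every rank loads the full state dict; FSDP re-shards the tensors on assignment.
    full_states = torch.load(ckpt_path, map_location="cpu")
    cfg = FullStateDictConfig(offload_to_cpu=True, rank0_only=False)
    with FSDP.state_dict_type(model, StateDictType.FULL_STATE_DICT, cfg):
        model.load_state_dict(full_states)
    return model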

View File

@@ -28,6 +28,7 @@ from internlm.train import (
     initialize_optimizer,
     load_new_batch,
     record_current_batch_training_metrics,
+    warp_FSDP_model,
 )
 from internlm.utils.common import (
     BatchSkipper,
@@ -123,6 +124,9 @@ def main(args):
     # Loading model weights must be done before zero is initialized.
     ckpt_manager.try_load_model(current_time)
 
+    # if fsdp is enabled, wrap the model
+    model = warp_FSDP_model(model)
+
     optimizer, beta2_scheduler, lr_scheduler = initialize_optimizer(model=model)
 
     # Loading other persistent training states.
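Taken together, the setup order implied by the hunks above for an FSDP run is roughly (a sketch distilled from this diff):

model = initialize_model()                 # build the model as before
ckpt_manager.try_load_model(current_time)  # weights are loaded before any sharding happens
model = warp_FSDP_model(model)             # no-op unless parallel.use_fsdp is True
optimizer, beta2_scheduler, lr_scheduler = initialize_optimizer(model=model)
# with use_fsdp=True, initialize_optimizer returns FSDPadaptOptimizer instead of HybridZeroOptimizer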