mirror of https://github.com/InternLM/InternLM
feat(fsdp): add training option for fsdp
parent c516602e9a
commit 85c6ed6473
@@ -7,22 +7,23 @@ MLP_RATIO = 8 / 3
 NUM_LAYER = 32
 VOCAB_SIZE = 103168
 
-MODEL_ONLY_FOLDER = "local:llm_ckpts/xxxx"
+MODEL_ONLY_FOLDER = "local:llm_ckpts/20"
 # Ckpt folder format:
 # fs: 'local:/mnt/nfs/XXX'
 SAVE_CKPT_FOLDER = "local:llm_ckpts"
-LOAD_CKPT_FOLDER = "local:llm_ckpts/49"
+LOAD_CKPT_FOLDER = "local:llm_ckpts/20"
 
 # boto3 Ckpt folder format:
 # import os
 # BOTO3_IP = os.environ["BOTO3_IP"]  # boto3 bucket endpoint
 # SAVE_CKPT_FOLDER = f"boto3:s3://model_weights.{BOTO3_IP}/internlm"
 # LOAD_CKPT_FOLDER = f"boto3:s3://model_weights.{BOTO3_IP}/internlm/snapshot/1/"
-CHECKPOINT_EVERY = 50
+CHECKPOINT_EVERY = 20
 ckpt = dict(
     enable_save_ckpt=False,  # enable ckpt save.
     save_ckpt_folder=SAVE_CKPT_FOLDER,  # Path to save training ckpt.
-    # load_ckpt_folder=LOAD_CKPT_FOLDER,  # Ckpt path to resume training (load weights and scheduler/context states).
+    load_ckpt_folder=LOAD_CKPT_FOLDER,  # Ckpt path to resume training (load weights and scheduler/context states).
+    load_given_ckpt=False,
     # load_model_only_folder=MODEL_ONLY_FOLDER,  # Path to initialize with given model weights.
     load_optimizer=True,  # Whether to load optimizer states when continuing training.
     checkpoint_every=CHECKPOINT_EVERY,
@@ -32,7 +33,7 @@ ckpt = dict(
     oss_snapshot_freq=int(CHECKPOINT_EVERY / 2),  # snapshot ckpt save frequency.
 )
 
-TRAIN_FOLDER = "/path/to/dataset"
+TRAIN_FOLDER = "../../train_data"  # "/path/to/dataset"
 VALID_FOLDER = "/path/to/dataset"
 data = dict(
     seq_len=SEQ_LEN,
@@ -50,7 +51,7 @@ data = dict(
     rampup_batch_size="",
     # Datasets with fewer than 50 rows will be discarded
     min_length=50,
-    # train_folder=TRAIN_FOLDER,
+    train_folder=TRAIN_FOLDER,
     # valid_folder=VALID_FOLDER,
 )
@@ -111,7 +112,7 @@ beta2_scheduler = dict(
 )
 
 model = dict(
-    checkpoint=False,  # The proportion of layers for activation checkpointing; valid values: True, False, or a float in [0, 1]
+    checkpoint=True,  # The proportion of layers for activation checkpointing; valid values: True, False, or a float in [0, 1]
     num_attention_heads=NUM_ATTENTION_HEAD,
     embed_split_hidden=True,
     vocab_size=VOCAB_SIZE,
@@ -121,7 +122,7 @@ model = dict(
     num_layers=NUM_LAYER,
     mlp_ratio=MLP_RATIO,
     apply_post_layer_norm=False,
-    dtype="torch.float16",  # Support: "torch.float16", "torch.half", "torch.bfloat16", "torch.float32", "torch.tf32"
+    dtype="torch.bfloat16",  # Support: "torch.float16", "torch.half", "torch.bfloat16", "torch.float32", "torch.tf32"
     norm_type="rmsnorm",
     layer_norm_epsilon=1e-5,
     use_flash_attn=True,
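The `dtype` field is a plain string, so downstream code has to resolve it to a real torch dtype. A hypothetical resolver for the listed values, just to make the options concrete (this helper is illustrative, not the repo's implementation):

```python
import torch

# Hypothetical mapping from the supported config strings to torch dtypes.
DTYPE_MAP = {
    "torch.float16": torch.float16,
    "torch.half": torch.half,        # alias of float16
    "torch.bfloat16": torch.bfloat16,
    "torch.float32": torch.float32,
    "torch.tf32": torch.float32,     # tf32 is a matmul mode enabled via backend flags, not a tensor dtype
}

def resolve_dtype(name: str) -> torch.dtype:
    return DTYPE_MAP[name]
```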
@@ -140,9 +141,11 @@ pipeline parallel (dict):
 tensor parallel: tensor parallel size, usually the number of GPUs per node.
 """
 parallel = dict(
-    zero1=8,
+    zero1=-1,
     pipeline=dict(size=1, interleaved_overlap=True),
+    tensor=2,
     sequence_parallel=False,
+    use_fsdp=False,
 )
 
 cudnn_deterministic = False
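With this change, FSDP is toggled entirely from the `parallel` config. A minimal sketch of what an FSDP-enabled variant of the block above might look like (values are illustrative; `zero1=-1` is understood here as letting the ZeRO/FSDP group span the full data-parallel world):

```python
# Illustrative config fragment: enabling the new FSDP training path.
parallel = dict(
    zero1=-1,                                         # FSDP/ZeRO group size = data-parallel size
    pipeline=dict(size=1, interleaved_overlap=True),  # FSDP requires pipeline size 1 (see args_sanity_check)
    tensor=2,
    sequence_parallel=False,
    use_fsdp=True,                                    # take the FullyShardedDataParallel path
)
```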
@@ -12,6 +12,7 @@ from .process_group_initializer import (
     Initializer_Zero1,
     ParallelMode,
     ProcessGroupInitializer,
+    Initializer_Zero3_dp,
 )
 from .random import (
     add_seed,
@@ -34,6 +35,7 @@ __all__ = [
     "Initializer_Pipeline",
     "Initializer_Data",
     "Initializer_Zero1",
+    "Initializer_Zero3_dp",
     "ProcessGroupInitializer",
     "Initializer_Model",
     "seed",
@@ -462,6 +462,8 @@ class ParallelContext(metaclass=SingletonMeta):
         initializers.append(pgroup_initializer.Initializer_Model(*initializer_args))
         initializers.append(pgroup_initializer.Initializer_Tensor(*initializer_args))
         initializers.append(pgroup_initializer.Initializer_Zero1(*initializer_args))
+        if self.config.parallel.use_fsdp:
+            initializers.append(pgroup_initializer.Initializer_Zero3_dp(*initializer_args))
         if self.pipeline_parallel_size > 1:
             initializers.append(pgroup_initializer.Initializer_Pipeline(*initializer_args))
         for initializer in initializers:
@@ -31,6 +31,11 @@ class ParallelMode(Enum):
     # zero1 parallel
     ZERO1 = "zero1"
 
+    # zero3-dp parallel
+    # If FSDP is enabled and the FSDP parallel size is smaller than the data-parallel size,
+    # gradients are reduced manually only across FSDP groups; reduction within each group is done by FSDP.
+    ZERO3_DP = "zero3_dp"
+
 
 class ProcessGroupInitializer(ABC):
     """An object, knowing the parallelism configuration, that initializes parallel groups.
@@ -332,3 +337,62 @@ class Initializer_Zero1(ProcessGroupInitializer):
                 ranks_in_group = ranks
 
         return local_rank, group_world_size, process_group, cpu_group, ranks_in_group, mode
+
+
+class Initializer_Zero3_dp(ProcessGroupInitializer):
+    """A ProcessGroupInitializer for the data-parallel groups used alongside FSDP (ZeRO3).
+
+    Args:
+        rank (int): The rank of current process.
+        world_size (int): Size of whole communication world.
+        data_parallel_size (int): Size of data parallel.
+        pipeline_parallel_size (int): Size of pipeline parallel.
+        tensor_parallel_size (int): Size of tensor parallel.
+        zero1_parallel_size (int): Size of zero1 parallel.
+    """
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        assert self.data_parallel_size % self.zero1_parallel_size == 0
+
+        # The only difference from the plain DP initializer: when FSDP is enabled, parameter
+        # sharding means only corresponding ranks across FSDP groups hold the same shard,
+        # so only those ranks form an actual DP group.
+        # e.g. with zero=4 and dp=8:
+        #   no fsdp: ranks [0-7] share the same model parameters, and [0-3], [4-7] form two separate zero groups
+        #   fsdp:    the params of ranks (0, 4), (1, 5), (2, 6), (3, 7) are actually the same
+        self.data_parallel_size //= self.zero1_parallel_size
+        self.rank_num_per_dp_group = self.world_size // self.data_parallel_size
+
+        assert self.world_size % self.data_parallel_size == 0
+
+    def init_dist_group(self, use_cpu: bool = False):
+        """Initialize data parallel groups, and assign local_ranks and groups to each gpu.
+
+        Returns:
+            Tuple (local_rank, group_world_size, process_group, cpu_group, ranks_in_group, mode):
+                A data parallelism's information tuple.
+        """
+        local_rank = None
+        ranks_in_group = None
+        process_group = None
+        cpu_group = None
+        group_world_size = None
+        mode = ParallelMode.ZERO3_DP
+
+        for i in range(self.rank_num_per_dp_group):
+            ranks = [i + j * self.rank_num_per_dp_group for j in range(self.data_parallel_size)]
+            group = dist.new_group(ranks)
+            if use_cpu:
+                group_cpu = dist.new_group(ranks, backend="gloo") if dist.get_backend() != "gloo" else group
+            else:
+                group_cpu = None
+
+            if self.rank in ranks:
+                local_rank = ranks.index(self.rank)
+                group_world_size = len(ranks)
+                process_group = group
+                cpu_group = group_cpu
+                ranks_in_group = ranks
+
+        return local_rank, group_world_size, process_group, cpu_group, ranks_in_group, mode
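The grouping arithmetic in Initializer_Zero3_dp is easy to sanity-check without torch.distributed. A small standalone sketch reproducing the zero=4, dp=8 example from the comment above (all names here are illustrative):

```python
# Standalone sketch of the ZERO3_DP grouping arithmetic (no torch.distributed required).
world_size = 8
data_parallel_size = 8
zero1_parallel_size = 4

# Same arithmetic as Initializer_Zero3_dp.__init__:
data_parallel_size //= zero1_parallel_size                # inter-FSDP dp size: 8 // 4 = 2
rank_num_per_dp_group = world_size // data_parallel_size  # ranks per group slice: 8 // 2 = 4

groups = [
    [i + j * rank_num_per_dp_group for j in range(data_parallel_size)]
    for i in range(rank_num_per_dp_group)
]
print(groups)  # [[0, 4], [1, 5], [2, 6], [3, 7]], matching the (0, 4), (1, 5), (2, 6), (3, 7) pairs
```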
@@ -61,6 +61,17 @@ def args_sanity_check():
     if "tensor" not in gpc.config.parallel:
         gpc.config.parallel._add_item("tensor", 1)
 
+    if isinstance(gpc.config.parallel.pipeline, int):
+        pp = gpc.config.parallel.pipeline
+    else:
+        pp = gpc.config.parallel.pipeline.size
+
+    if "use_fsdp" not in gpc.config.parallel:
+        gpc.config.parallel._add_item("use_fsdp", False)
+    elif gpc.config.parallel.use_fsdp and pp > 1:
+        logger.warning("FSDP is not supported when pipeline parallelism is enabled; disabling FSDP.")
+        gpc.config.parallel._add_item("use_fsdp", False)
+
     # processing the data config in gpc
     data = gpc.config.data
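The net effect of this block is that `use_fsdp` always exists after the check and is forced off whenever pipeline parallelism is active. A hypothetical standalone mimic of that logic on a plain dict, for illustration only:

```python
def normalize_use_fsdp(parallel: dict) -> dict:
    """Mimics the args_sanity_check logic above on a plain dict (illustrative only)."""
    pipeline = parallel.get("pipeline", 1)
    pp = pipeline if isinstance(pipeline, int) else pipeline["size"]
    if "use_fsdp" not in parallel:
        parallel["use_fsdp"] = False
    elif parallel["use_fsdp"] and pp > 1:
        print("FSDP is not supported when pipeline parallelism is enabled; disabling FSDP.")
        parallel["use_fsdp"] = False
    return parallel

assert normalize_use_fsdp({"pipeline": 1})["use_fsdp"] is False
assert normalize_use_fsdp({"pipeline": dict(size=2), "use_fsdp": True})["use_fsdp"] is False
```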
@@ -1,6 +1,6 @@
 #!/usr/bin/env python
 # -*- encoding: utf-8 -*-
 
-from .hybrid_zero_optim import HybridZeroOptimizer
+from .hybrid_zero_optim import HybridZeroOptimizer, FSDPadaptOptimizer
 
-__all__ = ["HybridZeroOptimizer"]
+__all__ = ["HybridZeroOptimizer", "FSDPadaptOptimizer"]
@@ -80,6 +80,132 @@ class BaseOptimizer(Optimizer):
         pass
 
 
+class FSDPadaptOptimizer(BaseOptimizer):
+    """
+    Optimizer wrapper for PyTorch FSDP, used when 'use_fsdp' is True in the config file.
+    It keeps the necessary components of the hybrid-zero optimizer:
+        grad_scaler;
+        grad clip and unscale;
+        state_dict and load_state_dict
+    """
+
+    def __init__(
+        self,
+        optimizer: Optimizer,
+        grad_scal_cfg: Config = None,
+        zero_cfg: Config = None,
+    ):
+        super().__init__(optim=optimizer)
+
+        # gradient scaler
+        self.grad_scaler = DynamicGradScaler(
+            initial_scale=grad_scal_cfg.fp16.initial_scale,
+            min_scale=grad_scal_cfg.fp16.min_scale,
+            growth_factor=grad_scal_cfg.growth_factor,
+            backoff_factor=grad_scal_cfg.backoff_factor,
+            growth_interval=grad_scal_cfg.fp16.growth_interval,
+            hysteresis=grad_scal_cfg.hysteresis,
+            max_scale=grad_scal_cfg.max_scale,
+        )
+
+        # gradient clipping
+        self._clip_grad_norm = zero_cfg.clip_grad_norm
+        self.use_fsdp = gpc.config.parallel.use_fsdp
+
+    @property
+    def loss_scale(self):
+        return self.grad_scaler.scale
+
+    def backward(self, loss, retain_graph=False):
+        loss = self.loss_scale * loss
+        loss.backward(retain_graph=retain_graph)
+
+    def step(self):
+        # In case the fsdp-zero3 size is not equal to the dp size, the FSDP module only
+        # reduces gradients within its own FSDP process group, so gradients must also be
+        # reduced manually across parallel FSDP process groups.
+        for group_idx in range(len(self.param_groups)):
+            params = self.param_groups[group_idx]["params"]
+            for param in params:
+                if param.requires_grad:
+                    reduce_tensor(tensor=param.grad, parallel_mode=ParallelMode.ZERO3_DP)
+
+        # compute norm
+        found_inf = False
+        norm_groups = []
+        for group_idx in range(len(self.param_groups)):
+            params = self.param_groups[group_idx]["params"]
+            gradients = [p.grad for p in params]
+            norm_group = compute_norm(gradients=gradients, parameters=params, last_stage=True)
+            if norm_group == -1:
+                found_inf = True
+                break
+            norm_groups.append(norm_group)
+
+        loss_scale = float(self.loss_scale.item())  # back up the scale before the scaler update
+        self.grad_scaler.update(found_inf)
+        if found_inf:
+            if gpc.is_rank_for_log():
+                logger.warning("Overflow occurs, please check it.")
+            self.zero_grad()
+            return False, None
+
+        # computed unconditionally so the unscale step below never sees an undefined norm
+        global_norm = sum(norm_groups) ** 0.5
+
+        # unscale (and clip, when clip_grad_norm > 0)
+        for group_idx in range(len(self.param_groups)):
+            params = self.param_groups[group_idx]["params"]
+            for p in params:
+                self._unscale_and_clip_grads(p.grad, global_norm, loss_scale)
+
+        self.optim.step()
+        self.zero_grad()
+
+        return True, [global_norm / loss_scale]
+
+    def clip_grad_norm(self, model, max_norm):
+        # done inside step()
+        pass
+
+    ##########################
+    # utils from hybrid zero #
+    ##########################
+
+    def _unscale_and_clip_grads(self, grad, total_norm, loss_scale):
+        # compute combined scale factor for this group
+        combined_scale = loss_scale
+
+        if self._clip_grad_norm > 0.0:
+            # norm is in fact norm * scale
+            clip = ((total_norm / loss_scale) + 1e-6) / self._clip_grad_norm
+            if clip > 1.0:
+                combined_scale = clip * loss_scale
+
+        grad.data.mul_(1.0 / combined_scale)
+
+    def state_dict(self):
+        states = {}
+        grad_scaler = self.grad_scaler.state_dict()
+        states["grad_scaler"] = grad_scaler
+        optim_states = self.optim.state_dict()
+        states["base_optim_states"] = optim_states
+
+        return states
+
+    def load_state_dict(self, states):
+        assert "grad_scaler" in states, "Not found grad_scaler state!"
+        grad_scaler = states["grad_scaler"]
+        self.grad_scaler.load_state_dict(grad_scaler)
+        optim_states = states["base_optim_states"]
+        self.optim.load_state_dict(optim_states)
+
+
 class HybridZeroOptimizer(BaseOptimizer):
     """
     Hybrid Zero Optimizer.
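For orientation, a hedged sketch of how this optimizer would be driven from a bare training loop; `fsdp_model`, `batch`, and `criterion` are illustrative names, and the loop itself is not part of this commit:

```python
# Hedged sketch: driving FSDPadaptOptimizer from a minimal training loop.
# Assumes gpc and the config are initialized, and `fsdp_model`, `batch`, `criterion`
# are set up elsewhere (illustrative names, not from this commit).
import torch
from internlm.core.context import global_context as gpc
from internlm.solver.optimizer import FSDPadaptOptimizer

adam = torch.optim.AdamW(fsdp_model.parameters(), lr=1e-4)
optim = FSDPadaptOptimizer(
    adam,
    grad_scal_cfg=gpc.config.grad_scaler,       # supplies fp16.initial_scale, growth_factor, ...
    zero_cfg=gpc.config.hybrid_zero_optimizer,  # supplies clip_grad_norm
)

logits = fsdp_model(batch["input"])
loss = criterion(logits, batch["target"])
optim.backward(loss)               # scales the loss, then runs backward
success, grad_norm = optim.step()  # reduce across ZERO3_DP, unscale/clip, step; (False, None) on overflow
```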
@@ -6,6 +6,7 @@ from .training_internlm import (
     initialize_optimizer,
     load_new_batch,
     record_current_batch_training_metrics,
+    warp_FSDP_model,
 )
 
 __all__ = [
@@ -16,4 +17,5 @@ __all__ = [
     "initialize_optimizer",
     "load_new_batch",
     "record_current_batch_training_metrics",
+    "warp_FSDP_model",
 ]
@@ -28,7 +28,7 @@ from internlm.monitor import set_env_var
 from internlm.monitor.monitor import monitor_manager as mm
 from internlm.solver.beta2_scheduler import Beta2Scheduler
 from internlm.solver.lr_scheduler import FineTuneCosineAnnealingWarmupLR
-from internlm.solver.optimizer import HybridZeroOptimizer
+from internlm.solver.optimizer import HybridZeroOptimizer, FSDPadaptOptimizer
 from internlm.solver.optimizer.utils import ParamBcastSyncHandler
 from internlm.utils.common import DummyProfile
 from internlm.utils.logger import get_logger
@@ -40,6 +40,24 @@ from internlm.utils.parallel import (
 )
 from internlm.utils.registry import MODEL_INITIALIZER
 
+from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
+from torch.distributed.fsdp.fully_sharded_data_parallel import (
+    CPUOffload,
+    BackwardPrefetch,
+    ShardingStrategy,
+    MixedPrecision,
+)
+from torch.distributed.fsdp.wrap import (
+    size_based_auto_wrap_policy,
+    transformer_auto_wrap_policy,
+    enable_wrap,
+    wrap,
+)
+import functools
+from internlm.model.modeling_internlm import PackedFlashBaseLayer1D, PackedFlashInternLm1D
+
 
 logger = get_logger(__file__)
@@ -83,6 +101,30 @@ def initialize_model():
     return model
 
 
+def warp_FSDP_model(model: Union[nn.Module, nn.ModuleList]):
+    if gpc.config.parallel.use_fsdp:
+        transformer_wrap_policy = functools.partial(
+            transformer_auto_wrap_policy,
+            transformer_layer_cls={PackedFlashBaseLayer1D, PackedFlashInternLm1D},
+        )
+        mx = MixedPrecision(
+            param_dtype=gpc.config.model.dtype, reduce_dtype=gpc.config.model.dtype,
+            buffer_dtype=gpc.config.model.dtype, keep_low_precision_grads=True)
+        grp = gpc.get_group(ParallelMode.ZERO1)
+        model = FSDP(module=model,
+                     process_group=grp,
+                     sharding_strategy=ShardingStrategy.FULL_SHARD,
+                     auto_wrap_policy=transformer_wrap_policy,
+                     forward_prefetch=True,
+                     backward_prefetch=BackwardPrefetch.BACKWARD_PRE,
+                     # cpu_offload=CPUOffload(offload_params=True),
+                     # mixed_precision=mx,
+                     # device_id=torch.cuda.current_device(),
+                     )
+
+    return model
+
+
 def initialize_optimizer(model: Union[nn.Module, nn.ModuleList]):
     """
     Initialize optimizer.
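For readers unfamiliar with `transformer_auto_wrap_policy`: it puts each listed layer class into its own FSDP unit, so parameters are gathered one block at a time rather than all at once. A self-contained toy sketch of the same pattern (`ToyBlock` stands in for `PackedFlashBaseLayer1D`; the actual wrap call is commented out because it needs an initialized process group):

```python
import functools
import torch.nn as nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy

class ToyBlock(nn.Module):  # stand-in for PackedFlashBaseLayer1D
    def __init__(self):
        super().__init__()
        self.ff = nn.Linear(64, 64)

    def forward(self, x):
        return self.ff(x)

toy_model = nn.Sequential(*[ToyBlock() for _ in range(4)])
policy = functools.partial(transformer_auto_wrap_policy, transformer_layer_cls={ToyBlock})
# Requires torch.distributed to be initialized first (e.g. via torchrun):
# sharded = FSDP(toy_model, auto_wrap_policy=policy)
```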
@@ -105,12 +147,19 @@ def initialize_optimizer(model: Union[nn.Module, nn.ModuleList]):
         eps=adam_cfg.adam_eps,
     )
 
-    optimizer = HybridZeroOptimizer(
-        naive_optimizer,
-        grad_scal_cfg=gpc.config.grad_scaler,
-        zero_cfg=gpc.config.hybrid_zero_optimizer,
-        param_bcast_sync_handler=param_bcast_sync_handler,
-    )
+    if not gpc.config.parallel.use_fsdp:
+        optimizer = HybridZeroOptimizer(
+            naive_optimizer,
+            grad_scal_cfg=gpc.config.grad_scaler,
+            zero_cfg=gpc.config.hybrid_zero_optimizer,
+            param_bcast_sync_handler=param_bcast_sync_handler,
+        )
+    else:
+        optimizer = FSDPadaptOptimizer(
+            naive_optimizer,
+            grad_scal_cfg=gpc.config.grad_scaler,
+            zero_cfg=gpc.config.hybrid_zero_optimizer,
+        )
 
     beta2_scheduler = Beta2Scheduler(optimizer=naive_optimizer, **gpc.config.beta2_scheduler)
@@ -55,6 +55,26 @@ def get_model_topology(model):
     return topos
 
 
+def get_state_dict(model):
+    """
+    Only used for FSDP module saving.
+    A wrapper of model.state_dict(): inside the 'FSDP.state_dict_type' context, the sharded
+    parameters (stored as model.flat_param_xx in a sharded FSDP module) are gathered on every gpu.
+    'offload_to_cpu' means the gathered states are offloaded to cpu chunk by chunk, avoiding gpu OOM.
+    """
+    from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
+    from torch.distributed.fsdp import FullStateDictConfig, StateDictType  # , FullOptimStateDictConfig
+
+    # TODO: rank0_only=True would save memory on non-rank0 gpus, but when tp is enabled,
+    # model saving leaves out some parameters.
+    save_policy = FullStateDictConfig(offload_to_cpu=True, rank0_only=False)
+    with FSDP.state_dict_type(model, StateDictType.FULL_STATE_DICT, save_policy):
+        states = model.state_dict()
+
+    return states
+
+
 def save_model_checkpoint(folder, model):
     """
     Save the model according to the relationship between tp and dp. The principle is that the data of each tp
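A sketch of how this helper would fit into a save path; the function name and rank guard below are illustrative, not from the commit:

```python
# Illustrative save routine for an FSDP-wrapped model.
import torch
import torch.distributed as dist

def save_fsdp_checkpoint(fsdp_model, path="model_fsdp.pt"):
    states = get_state_dict(fsdp_model)  # full, unsharded state dict, offloaded to cpu
    # With rank0_only=False every rank holds the full dict, so one writer is enough.
    if dist.get_rank() == 0:
        torch.save(states, path)
```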
@@ -69,7 +89,11 @@ def save_model_checkpoint(folder, model):
         model: The model to be saved
     """
 
-    states = model.state_dict()
+    if gpc.config.parallel.use_fsdp:
+        states = get_state_dict(model)
+    else:
+        states = model.state_dict()
+
     topo = get_model_topology(model)
 
     if folder is not None:
train.py (4 additions)
@@ -28,6 +28,7 @@ from internlm.train import (
     initialize_optimizer,
     load_new_batch,
     record_current_batch_training_metrics,
+    warp_FSDP_model,
 )
 from internlm.utils.common import (
     BatchSkipper,
@@ -123,6 +124,9 @@ def main(args):
     # Loading model weights must be done before zero is initialized.
     ckpt_manager.try_load_model(current_time)
 
+    # if fsdp is enabled, wrap the model
+    model = warp_FSDP_model(model)
+
     optimizer, beta2_scheduler, lr_scheduler = initialize_optimizer(model=model)
 
     # Loading other persistent training states.
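The ordering here matters: weights are loaded before FSDP wrapping, so the unwrapped module's state_dict keys match the checkpoint, and the optimizer is created after wrapping, so it holds the flattened, sharded parameters. A condensed, hypothetical view of the setup sequence after this commit:

```python
# Hypothetical condensation of the setup order in main() (not literal code from train.py).
model = initialize_model()                  # build the raw model
ckpt_manager.try_load_model(current_time)   # load weights into the unwrapped model
model = warp_FSDP_model(model)              # no-op unless parallel.use_fsdp is True
optimizer, beta2_scheduler, lr_scheduler = initialize_optimizer(model=model)
```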