mirror of https://github.com/InternLM/InternLM

feat(fsdp): add training option for fsdp
parent: c516602e9a
commit: 85c6ed6473
@@ -7,22 +7,23 @@ MLP_RATIO = 8 / 3
 NUM_LAYER = 32
 VOCAB_SIZE = 103168

-MODEL_ONLY_FOLDER = "local:llm_ckpts/xxxx"
+MODEL_ONLY_FOLDER = "local:llm_ckpts/20"
 # Ckpt folder format:
 # fs: 'local:/mnt/nfs/XXX'
 SAVE_CKPT_FOLDER = "local:llm_ckpts"
-LOAD_CKPT_FOLDER = "local:llm_ckpts/49"
+LOAD_CKPT_FOLDER = "local:llm_ckpts/20"

 # boto3 Ckpt folder format:
 # import os
 # BOTO3_IP = os.environ["BOTO3_IP"] # boto3 bucket endpoint
 # SAVE_CKPT_FOLDER = f"boto3:s3://model_weights.{BOTO3_IP}/internlm"
 # LOAD_CKPT_FOLDER = f"boto3:s3://model_weights.{BOTO3_IP}/internlm/snapshot/1/"
-CHECKPOINT_EVERY = 50
+CHECKPOINT_EVERY = 20
 ckpt = dict(
     enable_save_ckpt=False,  # enable ckpt save.
     save_ckpt_folder=SAVE_CKPT_FOLDER,  # Path to save training ckpt.
-    # load_ckpt_folder=LOAD_CKPT_FOLDER, # Ckpt path to resume training(load weights and scheduler/context states).
+    load_ckpt_folder=LOAD_CKPT_FOLDER,  # Ckpt path to resume training (load weights and scheduler/context states).
+    load_given_ckpt=False,
     # load_model_only_folder=MODEL_ONLY_FOLDER, # Path to initialize with given model weights.
     load_optimizer=True,  # Wheter to load optimizer states when continuing training.
     checkpoint_every=CHECKPOINT_EVERY,
@@ -32,7 +33,7 @@ ckpt = dict(
     oss_snapshot_freq=int(CHECKPOINT_EVERY / 2),  # snapshot ckpt save frequency.
 )

-TRAIN_FOLDER = "/path/to/dataset"
+TRAIN_FOLDER = "../../train_data"  # "/path/to/dataset"
 VALID_FOLDER = "/path/to/dataset"
 data = dict(
     seq_len=SEQ_LEN,
@@ -50,7 +51,7 @@ data = dict(
     rampup_batch_size="",
     # Datasets with less than 50 rows will be discarded
     min_length=50,
-    # train_folder=TRAIN_FOLDER,
+    train_folder=TRAIN_FOLDER,
     # valid_folder=VALID_FOLDER,
 )

@@ -111,7 +112,7 @@ beta2_scheduler = dict(
 )

 model = dict(
-    checkpoint=False,  # The proportion of layers for activation aheckpointing, the optional value are True/False/[0-1]
+    checkpoint=True,  # The proportion of layers for activation checkpointing, the optional value are True/False/[0-1]
     num_attention_heads=NUM_ATTENTION_HEAD,
     embed_split_hidden=True,
     vocab_size=VOCAB_SIZE,
@@ -121,7 +122,7 @@ model = dict(
     num_layers=NUM_LAYER,
     mlp_ratio=MLP_RATIO,
     apply_post_layer_norm=False,
-    dtype="torch.float16",  # Support: "torch.float16", "torch.half", "torch.bfloat16", "torch.float32", "torch.tf32"
+    dtype="torch.bfloat16",  # Support: "torch.float16", "torch.half", "torch.bfloat16", "torch.float32", "torch.tf32"
     norm_type="rmsnorm",
     layer_norm_epsilon=1e-5,
     use_flash_attn=True,
@@ -140,9 +141,11 @@ pipeline parallel (dict):
 tensor parallel: tensor parallel size, usually the number of GPUs per node.
 """
 parallel = dict(
-    zero1=8,
+    zero1=-1,
     pipeline=dict(size=1, interleaved_overlap=True),
+    tensor=2,
     sequence_parallel=False,
+    use_fsdp=False,
 )

 cudnn_deterministic = False
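
The last hunk above introduces the `use_fsdp` switch (off by default) alongside the existing parallel settings. As a rough illustration only, not taken from the commit, enabling the new code path amounts to flipping that flag; `zero1=-1` keeps the sharding group as large as the data-parallel group, and the other sizes here are placeholders:

# Illustrative config sketch (not part of the commit).
parallel = dict(
    zero1=-1,  # <= 0: shard over the whole data-parallel group; > 0: shard over a smaller subgroup
    pipeline=dict(size=1, interleaved_overlap=True),  # FSDP is auto-disabled when the pipeline size > 1 (see the sanity check below)
    tensor=1,
    sequence_parallel=False,
    use_fsdp=True,  # enable the PyTorch FSDP training path added by this commit
)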

@@ -12,6 +12,7 @@ from .process_group_initializer import (
     Initializer_Zero1,
     ParallelMode,
     ProcessGroupInitializer,
+    Initializer_Zero3_dp,
 )
 from .random import (
     add_seed,
@@ -34,6 +35,7 @@ __all__ = [
     "Initializer_Pipeline",
     "Initializer_Data",
     "Initializer_Zero1",
+    "Initializer_Zero3_dp",
     "ProcessGroupInitializer",
     "Initializer_Model",
     "seed",

@@ -462,6 +462,8 @@ class ParallelContext(metaclass=SingletonMeta):
         initializers.append(pgroup_initializer.Initializer_Model(*initializer_args))
         initializers.append(pgroup_initializer.Initializer_Tensor(*initializer_args))
         initializers.append(pgroup_initializer.Initializer_Zero1(*initializer_args))
+        if self.config.parallel.use_fsdp:
+            initializers.append(pgroup_initializer.Initializer_Zero3_dp(*initializer_args))
         if self.pipeline_parallel_size > 1:
             initializers.append(pgroup_initializer.Initializer_Pipeline(*initializer_args))
         for initializer in initializers:

@@ -31,6 +31,11 @@ class ParallelMode(Enum):
     # zero1 parallel
     ZERO1 = "zero1"

+    # zero3-dp parallel
+    # if fsdp is activated and the fsdp parallel size is less than the dp parallel size,
+    # then manual communication only happens between inter-fsdp modules, while intra-module reduction is done by fsdp
+    ZERO3_DP = "zero3_dp"
+

 class ProcessGroupInitializer(ABC):
     """An object, knowing the parallelism configuration, that initializes parallel groups.
@@ -332,3 +337,62 @@ class Initializer_Zero1(ProcessGroupInitializer):
                     ranks_in_group = ranks

         return local_rank, group_world_size, process_group, cpu_group, ranks_in_group, mode
+
+
+class Initializer_Zero3_dp(ProcessGroupInitializer):
+    """A ProcessGroupInitializer for zero3-dp parallelism (the FSDP case).
+
+    Args:
+        rank (int): The rank of current process.
+        world_size (int): Size of whole communication world.
+        data_parallel_size (int): Size of data parallel.
+        pipeline_parallel_size (int): Size of pipeline parallel.
+        tensor_parallel_size (int): Size of tensor parallel.
+        zero1_parallel_size (int): Size of zero1 parallel.
+    """
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        assert self.data_parallel_size % self.zero1_parallel_size == 0
+
+        # The only difference between this initializer and the DP initializer:
+        # when FSDP is enabled, only corresponding pairs are in the same actual DP group due to parameter sharding.
+        # e.g. when zero=4 and dp=8:
+        #     no fsdp: ranks [0-7] share the same model parameters, and [0-3], [4-7] are two separate zero groups
+        #        fsdp: the params of (0, 4), (1, 5), (2, 6), (3, 7) are actually the same
+        self.data_parallel_size //= self.zero1_parallel_size
+        self.rank_num_per_dp_group = self.world_size // self.data_parallel_size
+
+        assert self.world_size % self.data_parallel_size == 0
+
+    def init_dist_group(self, use_cpu: bool = False):
+        """Initialize data parallel groups, and assign local_ranks and groups to each gpu.
+
+        Returns:
+            Tuple (local_rank, group_world_size, process_group, ranks_in_group, mode):
+                A data parallelism's information tuple.
+        """
+        local_rank = None
+        ranks_in_group = None
+        process_group = None
+        cpu_group = None
+        group_world_size = None
+        mode = ParallelMode.ZERO3_DP
+
+        for i in range(self.rank_num_per_dp_group):
+            ranks = [i + j * self.rank_num_per_dp_group for j in range(self.data_parallel_size)]
+            group = dist.new_group(ranks)
+            if use_cpu:
+                group_cpu = dist.new_group(ranks, backend="gloo") if dist.get_backend() != "gloo" else group
+            else:
+                group_cpu = None
+
+            if self.rank in ranks:
+                local_rank = ranks.index(self.rank)
+                group_world_size = len(ranks)
+                process_group = group
+                cpu_group = group_cpu
+                ranks_in_group = ranks
+
+        return local_rank, group_world_size, process_group, cpu_group, ranks_in_group, mode
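
The comment block in `Initializer_Zero3_dp.__init__` is the heart of the new group layout: when the FSDP (zero1) group is smaller than the data-parallel group, only the ranks that hold the same parameter shard still need an explicit gradient all-reduce. A standalone sketch with the numbers from that comment (zero1=4, dp=8, world size 8, tp=pp=1) reproduces the groups the initializer builds, without any torch.distributed calls:

# Illustrative only: mirrors the index arithmetic in Initializer_Zero3_dp.
world_size, dp, zero1 = 8, 8, 4

zero3_dp_size = dp // zero1                # 2: ranks that end up holding identical shards
num_groups = world_size // zero3_dp_size   # 4: one ZERO3_DP group per shard position

groups = [
    [i + j * num_groups for j in range(zero3_dp_size)]
    for i in range(num_groups)
]
print(groups)  # [[0, 4], [1, 5], [2, 6], [3, 7]] -- the (0, 4), (1, 5), (2, 6), (3, 7) pairs from the comment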

@@ -61,6 +61,17 @@ def args_sanity_check():
     if "tensor" not in gpc.config.parallel:
         gpc.config.parallel._add_item("tensor", 1)

+    if isinstance(gpc.config.parallel.pipeline, int):
+        pp = gpc.config.parallel.pipeline
+    else:
+        pp = gpc.config.parallel.pipeline.size
+
+    if "use_fsdp" not in gpc.config.parallel:
+        gpc.config.parallel._add_item("use_fsdp", False)
+    elif gpc.config.parallel.use_fsdp and pp > 1:
+        logger.warning("FSDP is not supported when pipeline parallelism is enabled; disabling FSDP automatically")
+        gpc.config.parallel._add_item("use_fsdp", False)
+
     # processing the data config in gpc
     data = gpc.config.data

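
The new block first normalizes the pipeline setting, which the config may express either as a plain int or as a dict, and only then decides whether FSDP can stay enabled. A hedged illustration of the two accepted forms (plain dicts standing in for `gpc.config.parallel`, values made up):

# Both forms are handled by the isinstance() check above.
cfg_int = dict(pipeline=2, use_fsdp=True)   # pp = 2 -> use_fsdp is forced back to False, with a warning
cfg_dict = dict(pipeline=dict(size=1, interleaved_overlap=True), use_fsdp=True)  # pp = 1 -> use_fsdp stays True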

@@ -1,6 +1,6 @@
 #!/usr/bin/env python
 # -*- encoding: utf-8 -*-

-from .hybrid_zero_optim import HybridZeroOptimizer
+from .hybrid_zero_optim import HybridZeroOptimizer, FSDPadaptOptimizer

-__all__ = ["HybridZeroOptimizer"]
+__all__ = ["HybridZeroOptimizer", "FSDPadaptOptimizer"]

@@ -80,6 +80,132 @@ class BaseOptimizer(Optimizer):
         pass


+class FSDPadaptOptimizer(BaseOptimizer):
+    """
+    Optimizer wrapper for PyTorch FSDP, used when 'use_fsdp' is True in the config file.
+    It keeps the necessary components of the hybrid-zero optimizer:
+        grad_scaler;
+        grad clip and unscale;
+        state_dict and load_state_dict
+    """
+
+    def __init__(
+        self,
+        optimizer: Optimizer,
+        grad_scal_cfg: Config = None,
+        zero_cfg: Config = None,
+    ):
+        super().__init__(optim=optimizer)
+
+        # gradient scaler
+        self.grad_scaler = DynamicGradScaler(
+            initial_scale=grad_scal_cfg.fp16.initial_scale,
+            min_scale=grad_scal_cfg.fp16.min_scale,
+            growth_factor=grad_scal_cfg.growth_factor,
+            backoff_factor=grad_scal_cfg.backoff_factor,
+            growth_interval=grad_scal_cfg.fp16.growth_interval,
+            hysteresis=grad_scal_cfg.hysteresis,
+            max_scale=grad_scal_cfg.max_scale,
+        )
+
+        # gradient clipping
+        self._clip_grad_norm = zero_cfg.clip_grad_norm
+        self.use_fsdp = gpc.config.parallel.use_fsdp
+
+    @property
+    def loss_scale(self):
+        return self.grad_scaler.scale
+
+    def backward(self, loss, retain_graph=False):
+        loss = self.loss_scale * loss
+        loss.backward(retain_graph=retain_graph)
+
+    def step(self):
+        # In case the fsdp-zero3 size is not equal to the dp size, the FSDP module only
+        # reduces gradients within its own FSDP process group, so a manual gradient
+        # reduction between parallel FSDP process groups is essential.
+        for group_idx in range(len(self.param_groups)):
+            params = self.param_groups[group_idx]["params"]
+            for param in params:
+                if param.requires_grad:
+                    reduce_tensor(tensor=param.grad, parallel_mode=ParallelMode.ZERO3_DP)
+
+        # compute norm
+        found_inf = False
+        norm_groups = []
+        for group_idx in range(len(self.param_groups)):
+            params = self.param_groups[group_idx]["params"]
+            gradients = [p.grad for p in params]
+            norm_group = compute_norm(
+                gradients=gradients,
+                parameters=params,
+                last_stage=True,
+            )
+            if norm_group == -1:
+                found_inf = True
+                break
+            norm_groups.append(norm_group)
+
+        loss_scale = float(self.loss_scale.item())  # backup
+        self.grad_scaler.update(found_inf)
+        if found_inf:
+            if gpc.is_rank_for_log():
+                logger.warning("Overflow occurs, please check it.")
+            self.zero_grad()
+            return False, None
+
+        global_norm = 0
+        if self._clip_grad_norm > 0:
+            global_norm = sum(norm_groups) ** 0.5
+
+        # unscale
+        for group_idx in range(len(self.param_groups)):
+            params = self.param_groups[group_idx]["params"]
+            for p in params:
+                self._unscale_and_clip_grads(p.grad, global_norm, loss_scale)
+
+        self.optim.step()
+        self.zero_grad()
+
+        return True, [global_norm / loss_scale]
+
+    def clip_grad_norm(self, model, max_norm):
+        # handled in step()
+        pass
+
+    #########################
+    # utils from HybridZero #
+    #########################
+
+    def _unscale_and_clip_grads(self, grad, total_norm, loss_scale):
+        # compute combined scale factor for this group
+        combined_scale = loss_scale
+
+        if self._clip_grad_norm > 0.0:
+            # norm is in fact norm * scale
+            clip = ((total_norm / loss_scale) + 1e-6) / self._clip_grad_norm
+            if clip > 1.0:
+                combined_scale = clip * loss_scale
+
+        grad.data.mul_(1.0 / combined_scale)
+
+    def state_dict(self):
+        states = {}
+        grad_scaler = self.grad_scaler.state_dict()
+        states["grad_scaler"] = grad_scaler
+        optim_states = self.optim.state_dict()
+        states["base_optim_states"] = optim_states
+
+        return states
+
+    def load_state_dict(self, states):
+        assert "grad_scaler" in states, "Not found grad_scaler state!"
+        grad_scaler = states["grad_scaler"]
+        self.grad_scaler.load_state_dict(grad_scaler)
+        optim_states = states["base_optim_states"]
+        self.optim.load_state_dict(optim_states)
+
+
 class HybridZeroOptimizer(BaseOptimizer):
     """
     Hybrid Zero Optimizer.
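
Compared with HybridZeroOptimizer, FSDPadaptOptimizer keeps only what FSDP does not already handle: loss scaling, the extra all-reduce across ZERO3_DP groups, overflow detection, and unscale/clip before the inner step. A rough sketch of the calling contract this implies for one training step; `model`, `batch` and `optimizer` are placeholders, not the project's actual trainer objects:

# Sketch only: the call pattern FSDPadaptOptimizer expects.
loss = model(batch)                      # forward through the FSDP-wrapped model
optimizer.backward(loss)                 # multiplies by the current loss scale, then calls loss.backward()
success, grad_norms = optimizer.step()   # inter-group reduce, overflow check, unscale/clip, inner optim.step()
if not success:
    # overflow: gradients were zeroed and the grad scaler backed off, so this step is skipped
    pass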

@@ -6,6 +6,7 @@ from .training_internlm import (
     initialize_optimizer,
     load_new_batch,
     record_current_batch_training_metrics,
+    warp_FSDP_model,
 )

 __all__ = [
@@ -16,4 +17,5 @@ __all__ = [
     "initialize_optimizer",
     "load_new_batch",
     "record_current_batch_training_metrics",
+    "warp_FSDP_model",
 ]

@@ -28,7 +28,7 @@ from internlm.monitor import set_env_var
 from internlm.monitor.monitor import monitor_manager as mm
 from internlm.solver.beta2_scheduler import Beta2Scheduler
 from internlm.solver.lr_scheduler import FineTuneCosineAnnealingWarmupLR
-from internlm.solver.optimizer import HybridZeroOptimizer
+from internlm.solver.optimizer import HybridZeroOptimizer, FSDPadaptOptimizer
 from internlm.solver.optimizer.utils import ParamBcastSyncHandler
 from internlm.utils.common import DummyProfile
 from internlm.utils.logger import get_logger
@@ -40,6 +40,24 @@ from internlm.utils.parallel import (
 )
 from internlm.utils.registry import MODEL_INITIALIZER

+from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
+from torch.distributed.fsdp.fully_sharded_data_parallel import (
+    CPUOffload,
+    BackwardPrefetch,
+    ShardingStrategy,
+    MixedPrecision,
+)
+from torch.distributed.fsdp.wrap import (
+    size_based_auto_wrap_policy,
+    transformer_auto_wrap_policy,
+    enable_wrap,
+    wrap,
+)
+import functools
+from internlm.model.modeling_internlm import PackedFlashBaseLayer1D, PackedFlashInternLm1D
+
+
 logger = get_logger(__file__)


@@ -83,6 +101,30 @@ def initialize_model():
     return model


+def warp_FSDP_model(model: Union[nn.Module, nn.ModuleList]):
+    if gpc.config.parallel.use_fsdp:
+        transformer_wrap_policy = functools.partial(
+            transformer_auto_wrap_policy,
+            transformer_layer_cls={PackedFlashBaseLayer1D, PackedFlashInternLm1D},
+        )
+        mx = MixedPrecision(
+            param_dtype=gpc.config.model.dtype,
+            reduce_dtype=gpc.config.model.dtype,
+            buffer_dtype=gpc.config.model.dtype,
+            keep_low_precision_grads=True,
+        )
+        grp = gpc.get_group(ParallelMode.ZERO1)
+        model = FSDP(
+            module=model,
+            process_group=grp,
+            sharding_strategy=ShardingStrategy.FULL_SHARD,
+            auto_wrap_policy=transformer_wrap_policy,
+            forward_prefetch=True,
+            backward_prefetch=BackwardPrefetch.BACKWARD_PRE,
+            # cpu_offload=CPUOffload(offload_params=True)
+            # mixed_precision=mx,
+            # device_id=torch.cuda.current_device()
+        )
+
+    return model
+
+
 def initialize_optimizer(model: Union[nn.Module, nn.ModuleList]):
     """
     Initialize optimizer.
@@ -105,12 +147,19 @@ def initialize_optimizer(model: Union[nn.Module, nn.ModuleList]):
         eps=adam_cfg.adam_eps,
     )

-    optimizer = HybridZeroOptimizer(
-        naive_optimizer,
-        grad_scal_cfg=gpc.config.grad_scaler,
-        zero_cfg=gpc.config.hybrid_zero_optimizer,
-        param_bcast_sync_handler=param_bcast_sync_handler,
-    )
+    if not gpc.config.parallel.use_fsdp:
+        optimizer = HybridZeroOptimizer(
+            naive_optimizer,
+            grad_scal_cfg=gpc.config.grad_scaler,
+            zero_cfg=gpc.config.hybrid_zero_optimizer,
+            param_bcast_sync_handler=param_bcast_sync_handler,
+        )
+    else:
+        optimizer = FSDPadaptOptimizer(
+            naive_optimizer,
+            grad_scal_cfg=gpc.config.grad_scaler,
+            zero_cfg=gpc.config.hybrid_zero_optimizer,
+        )

     beta2_scheduler = Beta2Scheduler(optimizer=naive_optimizer, **gpc.config.beta2_scheduler)

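
warp_FSDP_model shards each transformer block (PackedFlashBaseLayer1D / PackedFlashInternLm1D) as its own FSDP unit over the ZERO1 group. The same wrapping pattern, reduced to a self-contained sketch with a toy block class; this assumes torch.distributed is already initialized and a GPU is available, and it is not the project's code:

import functools

import torch.nn as nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.fully_sharded_data_parallel import BackwardPrefetch, ShardingStrategy
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy


class ToyBlock(nn.Module):  # stand-in for PackedFlashBaseLayer1D
    def __init__(self, dim: int = 128):
        super().__init__()
        self.proj = nn.Linear(dim, dim)

    def forward(self, x):
        return self.proj(x)


toy_model = nn.Sequential(*[ToyBlock() for _ in range(4)]).cuda()

wrap_policy = functools.partial(
    transformer_auto_wrap_policy,
    transformer_layer_cls={ToyBlock},  # every ToyBlock becomes one FSDP shard unit
)

toy_model = FSDP(
    module=toy_model,
    sharding_strategy=ShardingStrategy.FULL_SHARD,    # ZeRO-3 style parameter sharding
    auto_wrap_policy=wrap_policy,
    forward_prefetch=True,
    backward_prefetch=BackwardPrefetch.BACKWARD_PRE,
)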

@@ -55,6 +55,26 @@ def get_model_topology(model):
     return topos


+def get_state_dict(model):
+    """
+    Only used for saving an FSDP module.
+    A wrapper around model.state_dict(): within the 'FSDP.state_dict_type' context, the sharded parameters
+    (stored as model.flat_param_xx in a sharded FSDP module) are gathered on every gpu.
+    'offload_to_cpu' means the model states are offloaded to cpu chunk by chunk, avoiding OOM on the gpu.
+    """
+    from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
+    from torch.distributed.fsdp import FullStateDictConfig, StateDictType  # , FullOptimStateDictConfig
+
+    # TODO: rank0_only can save memory on non-rank0 gpus, but when tp is enabled, model saving will leave out some parameters
+    save_policy = FullStateDictConfig(offload_to_cpu=True, rank0_only=False)
+    with FSDP.state_dict_type(model, StateDictType.FULL_STATE_DICT, save_policy):
+        states = model.state_dict()
+
+    return states
+
+
 def save_model_checkpoint(folder, model):
     """
     Save the model according to the relationship between tp and dp. The principle is that the data of each tp
@@ -69,7 +89,11 @@ def save_model_checkpoint(folder, model):
         model: The model to be saved
     """

-    states = model.state_dict()
+    if gpc.config.parallel.use_fsdp:
+        states = get_state_dict(model)
+    else:
+        states = model.state_dict()
+
     topo = get_model_topology(model)

     if folder is not None:
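
get_state_dict gathers the flattened FSDP shards into an ordinary full state dict on every rank, so the file written by save_model_checkpoint stays loadable without FSDP. Loading such a checkpoint back into an FSDP-wrapped model can use the same context manager; a hedged sketch (the helper name and arguments are illustrative, not part of the commit):

import torch
from torch.distributed.fsdp import FullStateDictConfig, StateDictType
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP


def load_full_checkpoint(fsdp_model: FSDP, path: str) -> None:
    # Every rank reads the full (unsharded) state dict; FSDP re-shards it
    # when load_state_dict runs under the FULL_STATE_DICT context.
    load_policy = FullStateDictConfig(offload_to_cpu=True, rank0_only=False)
    states = torch.load(path, map_location="cpu")
    with FSDP.state_dict_type(fsdp_model, StateDictType.FULL_STATE_DICT, load_policy):
        fsdp_model.load_state_dict(states)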

train.py | 4 ++++
@@ -28,6 +28,7 @@ from internlm.train import (
     initialize_optimizer,
     load_new_batch,
     record_current_batch_training_metrics,
+    warp_FSDP_model,
 )
 from internlm.utils.common import (
     BatchSkipper,
@@ -123,6 +124,9 @@ def main(args):
     # Loading model weights must be done before zero is initialized.
     ckpt_manager.try_load_model(current_time)

+    # if fsdp is enabled, wrap the model
+    model = warp_FSDP_model(model)
+
     optimizer, beta2_scheduler, lr_scheduler = initialize_optimizer(model=model)

     # Loading other persistent training states.
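
The call order in main() matters: weights are loaded into the plain model first, the model is then wrapped (a no-op unless `parallel.use_fsdp` is set), and only afterwards is the optimizer built, so FSDPadaptOptimizer sees the sharded parameters. Condensed from the hunk above, with the surrounding setup omitted:

# Simplified flow of main(); the names come from the diff above.
model = initialize_model()
ckpt_manager.try_load_model(current_time)   # load weights before any sharding
model = warp_FSDP_model(model)              # wraps with FSDP only when parallel.use_fsdp is True
optimizer, beta2_scheduler, lr_scheduler = initialize_optimizer(model=model)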