diff --git a/internlm/core/scheduler/__init__.py b/internlm/core/scheduler/__init__.py
index a9bf013..ea6afcd 100644
--- a/internlm/core/scheduler/__init__.py
+++ b/internlm/core/scheduler/__init__.py
@@ -1,4 +1,4 @@
-from .base_scheduler import BaseScheduler, SchedulerHook, SchedulerMetricHook
+from .base_scheduler import BaseScheduler, SchedulerHook
 from .no_pipeline_scheduler import NonPipelineScheduler
 from .pipeline_scheduler import InterleavedPipelineScheduler, PipelineScheduler
 
@@ -8,5 +8,4 @@ __all__ = [
     "InterleavedPipelineScheduler",
     "PipelineScheduler",
     "SchedulerHook",
-    "SchedulerMetricHook",
 ]
diff --git a/internlm/core/scheduler/base_scheduler.py b/internlm/core/scheduler/base_scheduler.py
index 20b4460..fbd878c 100644
--- a/internlm/core/scheduler/base_scheduler.py
+++ b/internlm/core/scheduler/base_scheduler.py
@@ -4,12 +4,11 @@
 # adopted from https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/engine
 
 from abc import ABC, abstractmethod
-from typing import Any, Callable, Iterable, Optional
+from typing import Any, Callable, Iterable
 
 import torch
 
 from internlm.core.engine import Engine
-from internlm.utils.megatron_timers import megatron_timer as timer
 
 
 class BaseScheduler(ABC):
@@ -147,41 +146,3 @@ class SchedulerHook(ABC):
     @abstractmethod
     def post_helper_func(self, scheduler, outputs, label) -> None:
         """A post helper function"""
-
-
-class SchedulerMetricHook(SchedulerHook):
-    """
-    Scheduler Metric Hook.
-    """
-
-    def __init__(self, metric: Optional[Callable] = None, skip: bool = False) -> None:
-        self._post_func = metric
-        self._skip = skip
-
-    def before_forward(self, scheduler, inputs) -> None:
-        if not self._skip:
-            timer("fwd").start()
-
-    def after_forward(self, scheduler, outputs) -> None:
-        if not self._skip:
-            timer("fwd").stop()
-
-    def before_criterion(self, scheduler, outputs, label) -> None:
-        if not self._skip:
-            timer("cal_loss").start()
-
-    def after_criterion(self, scheduler, loss) -> None:
-        if not self._skip:
-            timer("cal_loss").stop()
-
-    def before_backward(self, scheduler, outputs, outputs_grad) -> None:
-        if not self._skip:
-            timer("bwd").start()
-
-    def after_backward(self, scheduler, inputs_grad) -> None:
-        if not self._skip:
-            timer("bwd").stop()
-
-    def post_helper_func(self, scheduler, outputs, label) -> None:
-        if self._post_func is not None:
-            self._post_func(outputs, label)
diff --git a/internlm/model/metrics.py b/internlm/model/metrics.py
index 3a77f8b..c32d829 100644
--- a/internlm/model/metrics.py
+++ b/internlm/model/metrics.py
@@ -1,4 +1,4 @@
-from typing import List
+from typing import Callable, List, Optional
 
 import torch
 from flash_attn.losses.cross_entropy import CrossEntropyLoss as FlashCrossEntropyLoss
@@ -6,6 +6,8 @@ from torch_scatter import scatter
 
 from internlm.core.context import ParallelMode
 from internlm.core.context import global_context as gpc
+from internlm.core.scheduler import SchedulerHook
+from internlm.utils.megatron_timers import megatron_timer as timer
 
 
 class AccPerplex:
@@ -260,3 +262,41 @@ class LossWithTypeId:
             self.ds_token_num.fill_(0.0)
 
         return res
+
+
+class SchedulerMetricHook(SchedulerHook):
+    """
+    Scheduler Metric Hook.
+ """ + + def __init__(self, metric: Optional[Callable] = None, skip: bool = False) -> None: + self._post_func = metric + self._skip = skip + + def before_forward(self, scheduler, inputs) -> None: + if not self._skip: + timer("fwd").start() + + def after_forward(self, scheduler, outputs) -> None: + if not self._skip: + timer("fwd").stop() + + def before_criterion(self, scheduler, outputs, label) -> None: + if not self._skip: + timer("cal_loss").start() + + def after_criterion(self, scheduler, loss) -> None: + if not self._skip: + timer("cal_loss").stop() + + def before_backward(self, scheduler, outputs, outputs_grad) -> None: + if not self._skip: + timer("bwd").start() + + def after_backward(self, scheduler, inputs_grad) -> None: + if not self._skip: + timer("bwd").stop() + + def post_helper_func(self, scheduler, outputs, label) -> None: + if self._post_func is not None: + self._post_func(outputs, label) diff --git a/internlm/model/overlap_handler.py b/internlm/model/overlap_handler.py index 6870fe6..098fc8c 100644 --- a/internlm/model/overlap_handler.py +++ b/internlm/model/overlap_handler.py @@ -8,6 +8,7 @@ from torch import nn from internlm.core.context import global_context as gpc from internlm.core.naive_amp import NaiveAMPModel +from internlm.core.scheduler import SchedulerHook from internlm.model.embedding import Embedding1D from internlm.model.linear import FSTPLinear, ScaleColumnParallelLinear from internlm.model.utils import ( @@ -33,6 +34,8 @@ class FSTPOverlapHandler: self.index_to_fstp_modules = dict() # key: transformer block index; value: fsdp modules self.head = [] self.embedding = [] + self.model_checkpoint = gpc.config.model.checkpoint + self.is_forward = True self.reduce_scatter_handlers = {} self.zero_const_pool = {} @@ -81,6 +84,9 @@ class FSTPOverlapHandler: return self.zero_const_pool[size] + def set_forward_mode(self, flag): + self.is_forward = flag + def _initialize_module_shape(self): hidden_size = gpc.config.HIDDEN_SIZE mlp_ratio = gpc.config.MLP_RATIO @@ -121,7 +127,6 @@ class FSTPOverlapHandler: def get_bias_memory(self, module: nn.Module): block_index = self.module_to_index[module] # if the bias memory pool is empty or module has been not allocated memory - # import pdb; pdb.set_trace() if len(self.all_gather_bias_memory_pool) == 0: for _ in range(2): weight = {} @@ -209,9 +214,13 @@ class FSTPOverlapHandler: def _pre_forward_hook_for_out_proj(module: nn.Module, inputs: Any): # pylint: disable=W0613 block_index = self.module_to_index[module] - # start the all-gather for next block - if block_index + 1 < gpc.config.NUM_LAYER: - self._all_gather_block_weight_memory_pool(block_index + 1) + if self.model_checkpoint and self.is_forward is False: + if block_index - 1 >= 0: + self._all_gather_block_weight_memory_pool(block_index - 1) + else: + # start the all-gather for next block + if block_index + 1 < gpc.config.NUM_LAYER: + self._all_gather_block_weight_memory_pool(block_index + 1) def _pre_forward_hook_for_module(module: nn.Module, inputs: Any): # pylint: disable=W0613 handle = self.fstp_global_handle[module] @@ -234,6 +243,9 @@ class FSTPOverlapHandler: ) self.fstp_global_handle[first_backward_module] = weight_handle + def _pre_backward_hook_for_head(module: nn.Module, grad_output): + self._all_gather_block_weight_memory_pool(gpc.config.NUM_LAYER - 1) + def _pre_backward_hook_for_module(module: nn.Module, grad_output): # pylint: disable=W0613 # wait handle for current module weight_handle = self.fstp_global_handle[module] @@ -264,6 +276,10 @@ class 
FSTPOverlapHandler: for embedding in self.embedding: embedding.register_forward_hook(_post_forward_hook_for_embedding) + if self.model_checkpoint and self.is_forward is False: + for head in self.head: + head.register_full_backward_pre_hook(_pre_backward_hook_for_head) + for out_proj in self.fstp_outs: out_proj.register_forward_pre_hook(_pre_forward_hook_for_out_proj) @@ -275,9 +291,42 @@ class FSTPOverlapHandler: # 1. register post_backward_hook @head module to prefetch for the last block's last module # 2. register pre_backward_hook @fstp_module to wait handle for current module and to prefetch for next module # 3. register post_backward_hook @fstp_module to release resource - for head in self.head: - head.register_full_backward_hook(_post_backward_hook_for_head) + if gpc.config.model.checkpoint is False: + for head in self.head: + head.register_full_backward_hook(_post_backward_hook_for_head) - for module in self.fstp_modules: - module.register_full_backward_pre_hook(_pre_backward_hook_for_module) - module.register_full_backward_hook(_post_backward_hook_for_module) + for module in self.fstp_modules: + module.register_full_backward_pre_hook(_pre_backward_hook_for_module) + module.register_full_backward_hook(_post_backward_hook_for_module) + + +class FSTPOverlapSchedulerHook(SchedulerHook): + """ + SchedulerHook for fstp overlap handler + """ + + def __init__(self, overlap_handler: FSTPOverlapHandler) -> None: + super().__init__() + + self._overlap_handler = overlap_handler + + def before_forward(self, scheduler, inputs) -> None: + self._overlap_handler.set_forward_mode(True) + + def after_forward(self, scheduler, outputs) -> None: + pass + + def before_criterion(self, scheduler, outputs, label) -> None: + pass + + def after_criterion(self, scheduler, loss) -> None: + pass + + def before_backward(self, scheduler, outputs, outputs_grad) -> None: + self._overlap_handler.set_forward_mode(False) + + def after_backward(self, scheduler, inputs_grad) -> None: + pass + + def post_helper_func(self, scheduler, outputs, label) -> None: + pass diff --git a/internlm/utils/evaluation.py b/internlm/utils/evaluation.py index f708fa7..c6e27a6 100644 --- a/internlm/utils/evaluation.py +++ b/internlm/utils/evaluation.py @@ -6,8 +6,7 @@ from tqdm import tqdm from internlm.core.context import ParallelMode from internlm.core.context import global_context as gpc -from internlm.core.scheduler import SchedulerMetricHook -from internlm.model.metrics import AccPerplex +from internlm.model.metrics import AccPerplex, SchedulerMetricHook @contextmanager diff --git a/train.py b/train.py index b4f2a6d..ae86728 100644 --- a/train.py +++ b/train.py @@ -5,6 +5,7 @@ import socket import time import traceback from functools import partial +from typing import List, Optional import torch import torch.distributed as dist @@ -12,11 +13,12 @@ import torch.distributed as dist import internlm from internlm.core.context import ParallelMode from internlm.core.context import global_context as gpc -from internlm.core.scheduler import SchedulerMetricHook +from internlm.core.scheduler import SchedulerHook from internlm.core.trainer import TrainState from internlm.initialize import initialize_distributed_env from internlm.model.loss import FlashGPTLMLoss -from internlm.model.metrics import AccPerplex +from internlm.model.metrics import AccPerplex, SchedulerMetricHook +from internlm.model.overlap_handler import FSTPOverlapSchedulerHook from internlm.monitor import initialize_monitor_manager, send_alert_message from internlm.monitor.monitor 
import monitor_manager as mm from internlm.train import ( @@ -67,6 +69,30 @@ def initialize_llm_logger(start_time: str): return uniscale_logger +def get_scheduler_hooks( + metric: Optional[AccPerplex] = None, activation_checkpoint: bool = False +) -> List[SchedulerHook]: + scheduler_hooks: List[SchedulerHook] = [] + + if metric is not None: + scheduler_hooks.append( + SchedulerMetricHook( + metric=metric, + skip=( + gpc.is_using_pp() + and hasattr(gpc.config.model, "num_chunks") + and gpc.config.model.num_chunks > 1 + and gpc.config.parallel["pipeline"].get("interleaved_overlap", False) + ), + ), + ) + + if activation_checkpoint: + scheduler_hooks.append(FSTPOverlapSchedulerHook(gpc.fstp_handler)) + + return scheduler_hooks + + def main(args): # init setting skip_batches = gpc.config.data.skip_batches @@ -149,17 +175,6 @@ def main(args): ) # initialize trainer - scheduler_hooks = [ - SchedulerMetricHook( - metric=metric, - skip=( - gpc.is_using_pp() - and hasattr(gpc.config.model, "num_chunks") - and gpc.config.model.num_chunks > 1 - and gpc.config.parallel["pipeline"].get("interleaved_overlap", False) - ), - ), - ] trainer, train_dl, _, _ = internlm.initialize_trainer( model=model, @@ -168,7 +183,7 @@ def main(args): train_dataloader=train_dl, lr_scheduler=lr_scheduler, beta2_scheduler=beta2_scheduler, - scheduler_hooks=scheduler_hooks, + scheduler_hooks=get_scheduler_hooks(metric, gpc.config.model.checkpoint), ) # initialize simple memory profiler
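
For reference, a minimal sketch (not part of the patch) of how a scheduler is expected to drive the SchedulerHook interface relocated above. The _DemoScheduler driver, model, criterion, data, and label names are illustrative stand-ins, not InternLM's actual NonPipelineScheduler; the call order is inferred from the hook method names, and the hook list would come from get_scheduler_hooks(metric, gpc.config.model.checkpoint) in train.py.

from typing import List

import torch
from torch import nn

from internlm.core.scheduler import SchedulerHook


class _DemoScheduler:
    """Illustrative driver that fires SchedulerHook callbacks around one training step."""

    def __init__(self, hooks: List[SchedulerHook]):
        self._hooks = hooks

    def step(self, model: nn.Module, criterion, data: torch.Tensor, label: torch.Tensor):
        # forward pass, bracketed by before/after_forward
        # (SchedulerMetricHook starts/stops the "fwd" timer,
        #  FSTPOverlapSchedulerHook sets is_forward to True)
        for hook in self._hooks:
            hook.before_forward(self, data)
        output = model(data)
        for hook in self._hooks:
            hook.after_forward(self, output)

        # loss computation, bracketed by before/after_criterion ("cal_loss" timer)
        for hook in self._hooks:
            hook.before_criterion(self, output, label)
        loss = criterion(output, label)
        for hook in self._hooks:
            hook.after_criterion(self, loss)

        # backward pass, bracketed by before/after_backward ("bwd" timer,
        # FSTPOverlapSchedulerHook flips is_forward to False for backward prefetch)
        for hook in self._hooks:
            hook.before_backward(self, output, None)
        loss.backward()
        for hook in self._hooks:
            hook.after_backward(self, None)

        # metric update, e.g. AccPerplex via SchedulerMetricHook.post_helper_func
        for hook in self._hooks:
            hook.post_helper_func(self, output, label)

        return output, loss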