mirror of https://github.com/InternLM/InternLM
support model activation checkpoint
parent 0996c47e49
commit 97dcefc389
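This commit teaches the FSTP overlap handler and the scheduler hooks to cooperate with activation checkpointing: weight prefetching reverses direction while checkpointed blocks are recomputed during backward, and SchedulerMetricHook moves from internlm.core.scheduler to internlm.model.metrics. The feature keys off gpc.config.model.checkpoint; a minimal config sketch, purely illustrative (only the checkpoint flag and num_chunks are actually read by the code in this diff):

# Hypothetical excerpt of a training config; field layout is an assumption, not part of this commit.
model = dict(
    checkpoint=True,  # enable activation checkpointing; drives the new prefetch and hook paths
    num_chunks=1,     # read by SchedulerMetricHook's skip condition when pipeline parallelism is used
)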
@@ -1,4 +1,4 @@
-from .base_scheduler import BaseScheduler, SchedulerHook, SchedulerMetricHook
+from .base_scheduler import BaseScheduler, SchedulerHook
 from .no_pipeline_scheduler import NonPipelineScheduler
 from .pipeline_scheduler import InterleavedPipelineScheduler, PipelineScheduler
 
@@ -8,5 +8,4 @@ __all__ = [
     "InterleavedPipelineScheduler",
     "PipelineScheduler",
     "SchedulerHook",
-    "SchedulerMetricHook",
 ]
@@ -4,12 +4,11 @@
 # adopted from https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/engine
 
 from abc import ABC, abstractmethod
-from typing import Any, Callable, Iterable, Optional
+from typing import Any, Callable, Iterable
 
 import torch
 
 from internlm.core.engine import Engine
-from internlm.utils.megatron_timers import megatron_timer as timer
 
 
 class BaseScheduler(ABC):
@@ -147,41 +146,3 @@ class SchedulerHook(ABC):
     @abstractmethod
     def post_helper_func(self, scheduler, outputs, label) -> None:
         """A post helper function"""
-
-
-class SchedulerMetricHook(SchedulerHook):
-    """
-    Scheduler Metric Hook.
-    """
-
-    def __init__(self, metric: Optional[Callable] = None, skip: bool = False) -> None:
-        self._post_func = metric
-        self._skip = skip
-
-    def before_forward(self, scheduler, inputs) -> None:
-        if not self._skip:
-            timer("fwd").start()
-
-    def after_forward(self, scheduler, outputs) -> None:
-        if not self._skip:
-            timer("fwd").stop()
-
-    def before_criterion(self, scheduler, outputs, label) -> None:
-        if not self._skip:
-            timer("cal_loss").start()
-
-    def after_criterion(self, scheduler, loss) -> None:
-        if not self._skip:
-            timer("cal_loss").stop()
-
-    def before_backward(self, scheduler, outputs, outputs_grad) -> None:
-        if not self._skip:
-            timer("bwd").start()
-
-    def after_backward(self, scheduler, inputs_grad) -> None:
-        if not self._skip:
-            timer("bwd").stop()
-
-    def post_helper_func(self, scheduler, outputs, label) -> None:
-        if self._post_func is not None:
-            self._post_func(outputs, label)
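After this removal, internlm.core.scheduler keeps only the abstract SchedulerHook interface; the metric-specific hook lives with the metrics. A minimal sketch of a custom hook built on that interface, assuming all of the phase methods below are abstract (the class name and print statements are illustrative, not part of the codebase):

from internlm.core.scheduler import SchedulerHook


class NoopTraceHook(SchedulerHook):
    """Illustrative hook: traces phase boundaries; the scheduler calls each method around a step."""

    def before_forward(self, scheduler, inputs) -> None:
        print("entering forward")

    def after_forward(self, scheduler, outputs) -> None:
        print("leaving forward")

    def before_criterion(self, scheduler, outputs, label) -> None:
        pass

    def after_criterion(self, scheduler, loss) -> None:
        pass

    def before_backward(self, scheduler, outputs, outputs_grad) -> None:
        print("entering backward")

    def after_backward(self, scheduler, inputs_grad) -> None:
        print("leaving backward")

    def post_helper_func(self, scheduler, outputs, label) -> None:
        pass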
@@ -1,4 +1,4 @@
-from typing import List
+from typing import Callable, List, Optional
 
 import torch
 from flash_attn.losses.cross_entropy import CrossEntropyLoss as FlashCrossEntropyLoss
@@ -6,6 +6,8 @@ from torch_scatter import scatter
 
 from internlm.core.context import ParallelMode
 from internlm.core.context import global_context as gpc
+from internlm.core.scheduler import SchedulerHook
+from internlm.utils.megatron_timers import megatron_timer as timer
 
 
 class AccPerplex:
@@ -260,3 +262,41 @@ class LossWithTypeId:
         self.ds_token_num.fill_(0.0)
 
         return res
+
+
+class SchedulerMetricHook(SchedulerHook):
+    """
+    Scheduler Metric Hook.
+    """
+
+    def __init__(self, metric: Optional[Callable] = None, skip: bool = False) -> None:
+        self._post_func = metric
+        self._skip = skip
+
+    def before_forward(self, scheduler, inputs) -> None:
+        if not self._skip:
+            timer("fwd").start()
+
+    def after_forward(self, scheduler, outputs) -> None:
+        if not self._skip:
+            timer("fwd").stop()
+
+    def before_criterion(self, scheduler, outputs, label) -> None:
+        if not self._skip:
+            timer("cal_loss").start()
+
+    def after_criterion(self, scheduler, loss) -> None:
+        if not self._skip:
+            timer("cal_loss").stop()
+
+    def before_backward(self, scheduler, outputs, outputs_grad) -> None:
+        if not self._skip:
+            timer("bwd").start()
+
+    def after_backward(self, scheduler, inputs_grad) -> None:
+        if not self._skip:
+            timer("bwd").stop()
+
+    def post_helper_func(self, scheduler, outputs, label) -> None:
+        if self._post_func is not None:
+            self._post_func(outputs, label)
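With this move, SchedulerMetricHook lives next to the metric it wraps, and callers import it from internlm.model.metrics rather than internlm.core.scheduler (the later hunks update train.py and the evaluation path accordingly). A minimal usage sketch, assuming only what the signature above shows:

from internlm.model.metrics import SchedulerMetricHook

# Timing-only hook: with metric=None, post_helper_func is a no-op, but while skip=False
# the "fwd", "cal_loss" and "bwd" megatron timers still wrap each phase of the step.
timing_hook = SchedulerMetricHook(metric=None, skip=False)

# In the trainer, an AccPerplex instance is normally passed as the metric so that
# post_helper_func(outputs, label) updates accuracy/perplexity after each micro-batch.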
@@ -8,6 +8,7 @@ from torch import nn
 
 from internlm.core.context import global_context as gpc
 from internlm.core.naive_amp import NaiveAMPModel
+from internlm.core.scheduler import SchedulerHook
 from internlm.model.embedding import Embedding1D
 from internlm.model.linear import FSTPLinear, ScaleColumnParallelLinear
 from internlm.model.utils import (
@@ -33,6 +34,8 @@ class FSTPOverlapHandler:
         self.index_to_fstp_modules = dict()  # key: transformer block index; value: fsdp modules
         self.head = []
         self.embedding = []
+        self.model_checkpoint = gpc.config.model.checkpoint
+        self.is_forward = True
 
         self.reduce_scatter_handlers = {}
         self.zero_const_pool = {}
@@ -81,6 +84,9 @@ class FSTPOverlapHandler:
 
         return self.zero_const_pool[size]
 
+    def set_forward_mode(self, flag):
+        self.is_forward = flag
+
     def _initialize_module_shape(self):
         hidden_size = gpc.config.HIDDEN_SIZE
         mlp_ratio = gpc.config.MLP_RATIO
@@ -121,7 +127,6 @@ class FSTPOverlapHandler:
     def get_bias_memory(self, module: nn.Module):
         block_index = self.module_to_index[module]
         # if the bias memory pool is empty or module has been not allocated memory
-        # import pdb; pdb.set_trace()
         if len(self.all_gather_bias_memory_pool) == 0:
             for _ in range(2):
                 weight = {}
@@ -209,9 +214,13 @@ class FSTPOverlapHandler:
 
         def _pre_forward_hook_for_out_proj(module: nn.Module, inputs: Any):  # pylint: disable=W0613
             block_index = self.module_to_index[module]
-            # start the all-gather for next block
-            if block_index + 1 < gpc.config.NUM_LAYER:
-                self._all_gather_block_weight_memory_pool(block_index + 1)
+            if self.model_checkpoint and self.is_forward is False:
+                if block_index - 1 >= 0:
+                    self._all_gather_block_weight_memory_pool(block_index - 1)
+            else:
+                # start the all-gather for next block
+                if block_index + 1 < gpc.config.NUM_LAYER:
+                    self._all_gather_block_weight_memory_pool(block_index + 1)
 
         def _pre_forward_hook_for_module(module: nn.Module, inputs: Any):  # pylint: disable=W0613
             handle = self.fstp_global_handle[module]
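The key change for activation checkpointing is in _pre_forward_hook_for_out_proj: when a checkpointed block is recomputed during the backward pass (is_forward is False), blocks are revisited from last to first, so the all-gather should prefetch the weights of block_index - 1 rather than block_index + 1. A standalone sketch of that direction choice (the function and variable names are illustrative, not the handler's API):

def next_block_to_prefetch(block_index: int, num_layers: int,
                           checkpointing: bool, in_forward: bool):
    """Return the block whose weights should be all-gathered next, or None at the boundary."""
    if checkpointing and not in_forward:
        # Recomputation inside backward walks the blocks in descending order.
        return block_index - 1 if block_index - 1 >= 0 else None
    # A plain forward pass (or no checkpointing) walks the blocks in ascending order.
    return block_index + 1 if block_index + 1 < num_layers else None


assert next_block_to_prefetch(3, 8, checkpointing=False, in_forward=True) == 4
assert next_block_to_prefetch(3, 8, checkpointing=True, in_forward=False) == 2
assert next_block_to_prefetch(0, 8, checkpointing=True, in_forward=False) is None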
@@ -234,6 +243,9 @@ class FSTPOverlapHandler:
             )
             self.fstp_global_handle[first_backward_module] = weight_handle
 
+        def _pre_backward_hook_for_head(module: nn.Module, grad_output):
+            self._all_gather_block_weight_memory_pool(gpc.config.NUM_LAYER - 1)
+
         def _pre_backward_hook_for_module(module: nn.Module, grad_output):  # pylint: disable=W0613
             # wait handle for current module
             weight_handle = self.fstp_global_handle[module]
@@ -264,6 +276,10 @@ class FSTPOverlapHandler:
         for embedding in self.embedding:
            embedding.register_forward_hook(_post_forward_hook_for_embedding)
 
+        if self.model_checkpoint and self.is_forward is False:
+            for head in self.head:
+                head.register_full_backward_pre_hook(_pre_backward_hook_for_head)
+
         for out_proj in self.fstp_outs:
             out_proj.register_forward_pre_hook(_pre_forward_hook_for_out_proj)
 
@@ -275,9 +291,42 @@ class FSTPOverlapHandler:
         # 1. register post_backward_hook @head module to prefetch for the last block's last module
         # 2. register pre_backward_hook @fstp_module to wait handle for current module and to prefetch for next module
         # 3. register post_backward_hook @fstp_module to release resource
-        for head in self.head:
-            head.register_full_backward_hook(_post_backward_hook_for_head)
+        if gpc.config.model.checkpoint is False:
+            for head in self.head:
+                head.register_full_backward_hook(_post_backward_hook_for_head)
 
-        for module in self.fstp_modules:
-            module.register_full_backward_pre_hook(_pre_backward_hook_for_module)
-            module.register_full_backward_hook(_post_backward_hook_for_module)
+        for module in self.fstp_modules:
+            module.register_full_backward_pre_hook(_pre_backward_hook_for_module)
+            module.register_full_backward_hook(_post_backward_hook_for_module)
+
+
+class FSTPOverlapSchedulerHook(SchedulerHook):
+    """
+    SchedulerHook for fstp overlap handler
+    """
+
+    def __init__(self, overlap_handler: FSTPOverlapHandler) -> None:
+        super().__init__()
+
+        self._overlap_handler = overlap_handler
+
+    def before_forward(self, scheduler, inputs) -> None:
+        self._overlap_handler.set_forward_mode(True)
+
+    def after_forward(self, scheduler, outputs) -> None:
+        pass
+
+    def before_criterion(self, scheduler, outputs, label) -> None:
+        pass
+
+    def after_criterion(self, scheduler, loss) -> None:
+        pass
+
+    def before_backward(self, scheduler, outputs, outputs_grad) -> None:
+        self._overlap_handler.set_forward_mode(False)
+
+    def after_backward(self, scheduler, inputs_grad) -> None:
+        pass
+
+    def post_helper_func(self, scheduler, outputs, label) -> None:
+        pass
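FSTPOverlapSchedulerHook is how the scheduler tells the overlap handler which phase it is in: before_forward sets forward mode, before_backward clears it, so the pre-forward hooks fired during checkpoint recomputation see is_forward is False and prefetch in reverse. A rough sketch of the call order for one step, using a stand-in handler so the snippet is self-contained (the stub class is illustrative, not the real FSTPOverlapHandler):

from internlm.model.overlap_handler import FSTPOverlapSchedulerHook


class _StubOverlapHandler:
    """Stand-in exposing only set_forward_mode, just to show the mode toggling."""

    def __init__(self):
        self.is_forward = True

    def set_forward_mode(self, flag):
        self.is_forward = flag


handler = _StubOverlapHandler()
hook = FSTPOverlapSchedulerHook(handler)  # only set_forward_mode is needed here

hook.before_forward(None, None)        # handler.is_forward -> True: prefetch block_index + 1
hook.before_backward(None, None, None)  # handler.is_forward -> False: prefetch block_index - 1
assert handler.is_forward is False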
@@ -6,8 +6,7 @@ from tqdm import tqdm
 
 from internlm.core.context import ParallelMode
 from internlm.core.context import global_context as gpc
-from internlm.core.scheduler import SchedulerMetricHook
-from internlm.model.metrics import AccPerplex
+from internlm.model.metrics import AccPerplex, SchedulerMetricHook
 
 
 @contextmanager
train.py
@@ -5,6 +5,7 @@ import socket
 import time
 import traceback
 from functools import partial
+from typing import List, Optional
 
 import torch
 import torch.distributed as dist
@@ -12,11 +13,12 @@ import torch.distributed as dist
 import internlm
 from internlm.core.context import ParallelMode
 from internlm.core.context import global_context as gpc
-from internlm.core.scheduler import SchedulerMetricHook
+from internlm.core.scheduler import SchedulerHook
 from internlm.core.trainer import TrainState
 from internlm.initialize import initialize_distributed_env
 from internlm.model.loss import FlashGPTLMLoss
-from internlm.model.metrics import AccPerplex
+from internlm.model.metrics import AccPerplex, SchedulerMetricHook
+from internlm.model.overlap_handler import FSTPOverlapSchedulerHook
 from internlm.monitor import initialize_monitor_manager, send_alert_message
 from internlm.monitor.monitor import monitor_manager as mm
 from internlm.train import (
@@ -67,6 +69,30 @@ def initialize_llm_logger(start_time: str):
     return uniscale_logger
 
 
+def get_scheduler_hooks(
+    metric: Optional[AccPerplex] = None, activation_checkpoint: bool = False
+) -> List[SchedulerHook]:
+    scheduler_hooks: List[SchedulerHook] = []
+
+    if metric is not None:
+        scheduler_hooks.append(
+            SchedulerMetricHook(
+                metric=metric,
+                skip=(
+                    gpc.is_using_pp()
+                    and hasattr(gpc.config.model, "num_chunks")
+                    and gpc.config.model.num_chunks > 1
+                    and gpc.config.parallel["pipeline"].get("interleaved_overlap", False)
+                ),
+            ),
+        )
+
+    if activation_checkpoint:
+        scheduler_hooks.append(FSTPOverlapSchedulerHook(gpc.fstp_handler))
+
+    return scheduler_hooks
+
+
 def main(args):
     # init setting
     skip_batches = gpc.config.data.skip_batches
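get_scheduler_hooks centralizes hook construction: it always returns a list, adds a SchedulerMetricHook when a metric is supplied, and appends an FSTPOverlapSchedulerHook bound to gpc.fstp_handler when activation checkpointing is on. The following hunks replace the inline scheduler_hooks list in main with a call to it; a condensed sketch, assuming metric is the AccPerplex instance built earlier in main:

# Inside main(), after the AccPerplex metric is constructed:
hooks = get_scheduler_hooks(metric, gpc.config.model.checkpoint)
# checkpoint off -> [SchedulerMetricHook(...)]
# checkpoint on  -> [SchedulerMetricHook(...), FSTPOverlapSchedulerHook(gpc.fstp_handler)]
# The list is then passed to internlm.initialize_trainer(..., scheduler_hooks=hooks).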
@@ -149,17 +175,6 @@ def main(args):
     )
 
     # initialize trainer
-    scheduler_hooks = [
-        SchedulerMetricHook(
-            metric=metric,
-            skip=(
-                gpc.is_using_pp()
-                and hasattr(gpc.config.model, "num_chunks")
-                and gpc.config.model.num_chunks > 1
-                and gpc.config.parallel["pipeline"].get("interleaved_overlap", False)
-            ),
-        ),
-    ]
 
     trainer, train_dl, _, _ = internlm.initialize_trainer(
         model=model,
@@ -168,7 +183,7 @@
         train_dataloader=train_dl,
         lr_scheduler=lr_scheduler,
         beta2_scheduler=beta2_scheduler,
-        scheduler_hooks=scheduler_hooks,
+        scheduler_hooks=get_scheduler_hooks(metric, gpc.config.model.checkpoint),
     )
 
     # initialize simple memory profiler