support model activation checkpoint

pull/456/head
yingtongxiong 2023-10-24 16:13:52 +08:00
parent 0996c47e49
commit 97dcefc389
6 changed files with 131 additions and 68 deletions

View File

@ -1,4 +1,4 @@
from .base_scheduler import BaseScheduler, SchedulerHook, SchedulerMetricHook
from .base_scheduler import BaseScheduler, SchedulerHook
from .no_pipeline_scheduler import NonPipelineScheduler
from .pipeline_scheduler import InterleavedPipelineScheduler, PipelineScheduler
@ -8,5 +8,4 @@ __all__ = [
"InterleavedPipelineScheduler",
"PipelineScheduler",
"SchedulerHook",
"SchedulerMetricHook",
]

View File

@ -4,12 +4,11 @@
# adopted from https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/engine
from abc import ABC, abstractmethod
from typing import Any, Callable, Iterable, Optional
from typing import Any, Callable, Iterable
import torch
from internlm.core.engine import Engine
from internlm.utils.megatron_timers import megatron_timer as timer
class BaseScheduler(ABC):
@ -147,41 +146,3 @@ class SchedulerHook(ABC):
@abstractmethod
def post_helper_func(self, scheduler, outputs, label) -> None:
"""A post helper function"""
class SchedulerMetricHook(SchedulerHook):
"""
Scheduler Metric Hook.
"""
def __init__(self, metric: Optional[Callable] = None, skip: bool = False) -> None:
self._post_func = metric
self._skip = skip
def before_forward(self, scheduler, inputs) -> None:
if not self._skip:
timer("fwd").start()
def after_forward(self, scheduler, outputs) -> None:
if not self._skip:
timer("fwd").stop()
def before_criterion(self, scheduler, outputs, label) -> None:
if not self._skip:
timer("cal_loss").start()
def after_criterion(self, scheduler, loss) -> None:
if not self._skip:
timer("cal_loss").stop()
def before_backward(self, scheduler, outputs, outputs_grad) -> None:
if not self._skip:
timer("bwd").start()
def after_backward(self, scheduler, inputs_grad) -> None:
if not self._skip:
timer("bwd").stop()
def post_helper_func(self, scheduler, outputs, label) -> None:
if self._post_func is not None:
self._post_func(outputs, label)

View File

@ -1,4 +1,4 @@
from typing import List
from typing import Callable, List, Optional
import torch
from flash_attn.losses.cross_entropy import CrossEntropyLoss as FlashCrossEntropyLoss
@ -6,6 +6,8 @@ from torch_scatter import scatter
from internlm.core.context import ParallelMode
from internlm.core.context import global_context as gpc
from internlm.core.scheduler import SchedulerHook
from internlm.utils.megatron_timers import megatron_timer as timer
class AccPerplex:
@ -260,3 +262,41 @@ class LossWithTypeId:
self.ds_token_num.fill_(0.0)
return res
class SchedulerMetricHook(SchedulerHook):
"""
Scheduler Metric Hook.
"""
def __init__(self, metric: Optional[Callable] = None, skip: bool = False) -> None:
self._post_func = metric
self._skip = skip
def before_forward(self, scheduler, inputs) -> None:
if not self._skip:
timer("fwd").start()
def after_forward(self, scheduler, outputs) -> None:
if not self._skip:
timer("fwd").stop()
def before_criterion(self, scheduler, outputs, label) -> None:
if not self._skip:
timer("cal_loss").start()
def after_criterion(self, scheduler, loss) -> None:
if not self._skip:
timer("cal_loss").stop()
def before_backward(self, scheduler, outputs, outputs_grad) -> None:
if not self._skip:
timer("bwd").start()
def after_backward(self, scheduler, inputs_grad) -> None:
if not self._skip:
timer("bwd").stop()
def post_helper_func(self, scheduler, outputs, label) -> None:
if self._post_func is not None:
self._post_func(outputs, label)

View File

@ -8,6 +8,7 @@ from torch import nn
from internlm.core.context import global_context as gpc
from internlm.core.naive_amp import NaiveAMPModel
from internlm.core.scheduler import SchedulerHook
from internlm.model.embedding import Embedding1D
from internlm.model.linear import FSTPLinear, ScaleColumnParallelLinear
from internlm.model.utils import (
@ -33,6 +34,8 @@ class FSTPOverlapHandler:
self.index_to_fstp_modules = dict() # key: transformer block index; value: fsdp modules
self.head = []
self.embedding = []
self.model_checkpoint = gpc.config.model.checkpoint
self.is_forward = True
self.reduce_scatter_handlers = {}
self.zero_const_pool = {}
@ -81,6 +84,9 @@ class FSTPOverlapHandler:
return self.zero_const_pool[size]
def set_forward_mode(self, flag):
self.is_forward = flag
def _initialize_module_shape(self):
hidden_size = gpc.config.HIDDEN_SIZE
mlp_ratio = gpc.config.MLP_RATIO
@ -121,7 +127,6 @@ class FSTPOverlapHandler:
def get_bias_memory(self, module: nn.Module):
block_index = self.module_to_index[module]
# if the bias memory pool is empty or module has been not allocated memory
# import pdb; pdb.set_trace()
if len(self.all_gather_bias_memory_pool) == 0:
for _ in range(2):
weight = {}
@ -209,9 +214,13 @@ class FSTPOverlapHandler:
def _pre_forward_hook_for_out_proj(module: nn.Module, inputs: Any): # pylint: disable=W0613
block_index = self.module_to_index[module]
# start the all-gather for next block
if block_index + 1 < gpc.config.NUM_LAYER:
self._all_gather_block_weight_memory_pool(block_index + 1)
if self.model_checkpoint and self.is_forward is False:
if block_index - 1 >= 0:
self._all_gather_block_weight_memory_pool(block_index - 1)
else:
# start the all-gather for next block
if block_index + 1 < gpc.config.NUM_LAYER:
self._all_gather_block_weight_memory_pool(block_index + 1)
def _pre_forward_hook_for_module(module: nn.Module, inputs: Any): # pylint: disable=W0613
handle = self.fstp_global_handle[module]
@ -234,6 +243,9 @@ class FSTPOverlapHandler:
)
self.fstp_global_handle[first_backward_module] = weight_handle
def _pre_backward_hook_for_head(module: nn.Module, grad_output):
self._all_gather_block_weight_memory_pool(gpc.config.NUM_LAYER - 1)
def _pre_backward_hook_for_module(module: nn.Module, grad_output): # pylint: disable=W0613
# wait handle for current module
weight_handle = self.fstp_global_handle[module]
@ -264,6 +276,10 @@ class FSTPOverlapHandler:
for embedding in self.embedding:
embedding.register_forward_hook(_post_forward_hook_for_embedding)
if self.model_checkpoint and self.is_forward is False:
for head in self.head:
head.register_full_backward_pre_hook(_pre_backward_hook_for_head)
for out_proj in self.fstp_outs:
out_proj.register_forward_pre_hook(_pre_forward_hook_for_out_proj)
@ -275,9 +291,42 @@ class FSTPOverlapHandler:
# 1. register post_backward_hook @head module to prefetch for the last block's last module
# 2. register pre_backward_hook @fstp_module to wait handle for current module and to prefetch for next module
# 3. register post_backward_hook @fstp_module to release resource
for head in self.head:
head.register_full_backward_hook(_post_backward_hook_for_head)
if gpc.config.model.checkpoint is False:
for head in self.head:
head.register_full_backward_hook(_post_backward_hook_for_head)
for module in self.fstp_modules:
module.register_full_backward_pre_hook(_pre_backward_hook_for_module)
module.register_full_backward_hook(_post_backward_hook_for_module)
for module in self.fstp_modules:
module.register_full_backward_pre_hook(_pre_backward_hook_for_module)
module.register_full_backward_hook(_post_backward_hook_for_module)
class FSTPOverlapSchedulerHook(SchedulerHook):
"""
SchedulerHook for fstp overlap handler
"""
def __init__(self, overlap_handler: FSTPOverlapHandler) -> None:
super().__init__()
self._overlap_handler = overlap_handler
def before_forward(self, scheduler, inputs) -> None:
self._overlap_handler.set_forward_mode(True)
def after_forward(self, scheduler, outputs) -> None:
pass
def before_criterion(self, scheduler, outputs, label) -> None:
pass
def after_criterion(self, scheduler, loss) -> None:
pass
def before_backward(self, scheduler, outputs, outputs_grad) -> None:
self._overlap_handler.set_forward_mode(False)
def after_backward(self, scheduler, inputs_grad) -> None:
pass
def post_helper_func(self, scheduler, outputs, label) -> None:
pass

View File

@ -6,8 +6,7 @@ from tqdm import tqdm
from internlm.core.context import ParallelMode
from internlm.core.context import global_context as gpc
from internlm.core.scheduler import SchedulerMetricHook
from internlm.model.metrics import AccPerplex
from internlm.model.metrics import AccPerplex, SchedulerMetricHook
@contextmanager

View File

@ -5,6 +5,7 @@ import socket
import time
import traceback
from functools import partial
from typing import List, Optional
import torch
import torch.distributed as dist
@ -12,11 +13,12 @@ import torch.distributed as dist
import internlm
from internlm.core.context import ParallelMode
from internlm.core.context import global_context as gpc
from internlm.core.scheduler import SchedulerMetricHook
from internlm.core.scheduler import SchedulerHook
from internlm.core.trainer import TrainState
from internlm.initialize import initialize_distributed_env
from internlm.model.loss import FlashGPTLMLoss
from internlm.model.metrics import AccPerplex
from internlm.model.metrics import AccPerplex, SchedulerMetricHook
from internlm.model.overlap_handler import FSTPOverlapSchedulerHook
from internlm.monitor import initialize_monitor_manager, send_alert_message
from internlm.monitor.monitor import monitor_manager as mm
from internlm.train import (
@ -67,6 +69,30 @@ def initialize_llm_logger(start_time: str):
return uniscale_logger
def get_scheduler_hooks(
metric: Optional[AccPerplex] = None, activation_checkpoint: bool = False
) -> List[SchedulerHook]:
scheduler_hooks: List[SchedulerHook] = []
if metric is not None:
scheduler_hooks.append(
SchedulerMetricHook(
metric=metric,
skip=(
gpc.is_using_pp()
and hasattr(gpc.config.model, "num_chunks")
and gpc.config.model.num_chunks > 1
and gpc.config.parallel["pipeline"].get("interleaved_overlap", False)
),
),
)
if activation_checkpoint:
scheduler_hooks.append(FSTPOverlapSchedulerHook(gpc.fstp_handler))
return scheduler_hooks
def main(args):
# init setting
skip_batches = gpc.config.data.skip_batches
@ -149,17 +175,6 @@ def main(args):
)
# initialize trainer
scheduler_hooks = [
SchedulerMetricHook(
metric=metric,
skip=(
gpc.is_using_pp()
and hasattr(gpc.config.model, "num_chunks")
and gpc.config.model.num_chunks > 1
and gpc.config.parallel["pipeline"].get("interleaved_overlap", False)
),
),
]
trainer, train_dl, _, _ = internlm.initialize_trainer(
model=model,
@ -168,7 +183,7 @@ def main(args):
train_dataloader=train_dl,
lr_scheduler=lr_scheduler,
beta2_scheduler=beta2_scheduler,
scheduler_hooks=scheduler_hooks,
scheduler_hooks=get_scheduler_hooks(metric, gpc.config.model.checkpoint),
)
# initialize simple memory profiler