support model activation checkpoint

pull/456/head
yingtongxiong 2023-10-24 16:13:52 +08:00
parent 0996c47e49
commit 97dcefc389
6 changed files with 131 additions and 68 deletions

View File

@@ -1,4 +1,4 @@
-from .base_scheduler import BaseScheduler, SchedulerHook, SchedulerMetricHook
+from .base_scheduler import BaseScheduler, SchedulerHook
 from .no_pipeline_scheduler import NonPipelineScheduler
 from .pipeline_scheduler import InterleavedPipelineScheduler, PipelineScheduler
@@ -8,5 +8,4 @@ __all__ = [
     "InterleavedPipelineScheduler",
     "PipelineScheduler",
     "SchedulerHook",
-    "SchedulerMetricHook",
 ]

View File

@@ -4,12 +4,11 @@
 # adopted from https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/engine
 from abc import ABC, abstractmethod
-from typing import Any, Callable, Iterable, Optional
+from typing import Any, Callable, Iterable

 import torch

 from internlm.core.engine import Engine
-from internlm.utils.megatron_timers import megatron_timer as timer


 class BaseScheduler(ABC):
@@ -147,41 +146,3 @@ class SchedulerHook(ABC):
     @abstractmethod
     def post_helper_func(self, scheduler, outputs, label) -> None:
         """A post helper function"""
-
-
-class SchedulerMetricHook(SchedulerHook):
-    """
-    Scheduler Metric Hook.
-    """
-
-    def __init__(self, metric: Optional[Callable] = None, skip: bool = False) -> None:
-        self._post_func = metric
-        self._skip = skip
-
-    def before_forward(self, scheduler, inputs) -> None:
-        if not self._skip:
-            timer("fwd").start()
-
-    def after_forward(self, scheduler, outputs) -> None:
-        if not self._skip:
-            timer("fwd").stop()
-
-    def before_criterion(self, scheduler, outputs, label) -> None:
-        if not self._skip:
-            timer("cal_loss").start()
-
-    def after_criterion(self, scheduler, loss) -> None:
-        if not self._skip:
-            timer("cal_loss").stop()
-
-    def before_backward(self, scheduler, outputs, outputs_grad) -> None:
-        if not self._skip:
-            timer("bwd").start()
-
-    def after_backward(self, scheduler, inputs_grad) -> None:
-        if not self._skip:
-            timer("bwd").stop()
-
-    def post_helper_func(self, scheduler, outputs, label) -> None:
-        if self._post_func is not None:
-            self._post_func(outputs, label)

View File

@@ -1,4 +1,4 @@
-from typing import List
+from typing import Callable, List, Optional

 import torch
 from flash_attn.losses.cross_entropy import CrossEntropyLoss as FlashCrossEntropyLoss
@@ -6,6 +6,8 @@ from torch_scatter import scatter
 from internlm.core.context import ParallelMode
 from internlm.core.context import global_context as gpc
+from internlm.core.scheduler import SchedulerHook
+from internlm.utils.megatron_timers import megatron_timer as timer


 class AccPerplex:
@@ -260,3 +262,41 @@ class LossWithTypeId:
         self.ds_token_num.fill_(0.0)
         return res
+
+
+class SchedulerMetricHook(SchedulerHook):
+    """
+    Scheduler Metric Hook.
+    """
+
+    def __init__(self, metric: Optional[Callable] = None, skip: bool = False) -> None:
+        self._post_func = metric
+        self._skip = skip
+
+    def before_forward(self, scheduler, inputs) -> None:
+        if not self._skip:
+            timer("fwd").start()
+
+    def after_forward(self, scheduler, outputs) -> None:
+        if not self._skip:
+            timer("fwd").stop()
+
+    def before_criterion(self, scheduler, outputs, label) -> None:
+        if not self._skip:
+            timer("cal_loss").start()
+
+    def after_criterion(self, scheduler, loss) -> None:
+        if not self._skip:
+            timer("cal_loss").stop()
+
+    def before_backward(self, scheduler, outputs, outputs_grad) -> None:
+        if not self._skip:
+            timer("bwd").start()
+
+    def after_backward(self, scheduler, inputs_grad) -> None:
+        if not self._skip:
+            timer("bwd").stop()
+
+    def post_helper_func(self, scheduler, outputs, label) -> None:
+        if self._post_func is not None:
+            self._post_func(outputs, label)
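Note: the hook methods above are pure callbacks; the timers only bracket anything once a scheduler drives them around a training step. Below is an illustrative sketch of such a driver loop (the run_step name and the model/criterion/data arguments are hypothetical; only the callback names and timer labels come from the class above).

def run_step(scheduler, hook, model, criterion, data, label):
    # forward, bracketed by the "fwd" timer
    hook.before_forward(scheduler, data)
    output = model(data)
    hook.after_forward(scheduler, output)

    # loss computation, bracketed by the "cal_loss" timer
    hook.before_criterion(scheduler, output, label)
    loss = criterion(output, label)
    hook.after_criterion(scheduler, loss)

    # backward, bracketed by the "bwd" timer
    hook.before_backward(scheduler, output, None)
    loss.backward()
    hook.after_backward(scheduler, None)

    # feeds the AccPerplex metric when one was passed to the hook
    hook.post_helper_func(scheduler, output, label)
    return loss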

View File

@@ -8,6 +8,7 @@ from torch import nn
 from internlm.core.context import global_context as gpc
 from internlm.core.naive_amp import NaiveAMPModel
+from internlm.core.scheduler import SchedulerHook
 from internlm.model.embedding import Embedding1D
 from internlm.model.linear import FSTPLinear, ScaleColumnParallelLinear
 from internlm.model.utils import (
@@ -33,6 +34,8 @@ class FSTPOverlapHandler:
         self.index_to_fstp_modules = dict()  # key: transformer block index; value: fsdp modules
         self.head = []
         self.embedding = []
+        self.model_checkpoint = gpc.config.model.checkpoint
+        self.is_forward = True

         self.reduce_scatter_handlers = {}
         self.zero_const_pool = {}
@@ -81,6 +84,9 @@ class FSTPOverlapHandler:
         return self.zero_const_pool[size]

+    def set_forward_mode(self, flag):
+        self.is_forward = flag
+
     def _initialize_module_shape(self):
         hidden_size = gpc.config.HIDDEN_SIZE
         mlp_ratio = gpc.config.MLP_RATIO
@@ -121,7 +127,6 @@
     def get_bias_memory(self, module: nn.Module):
         block_index = self.module_to_index[module]
         # if the bias memory pool is empty or module has been not allocated memory
-        # import pdb; pdb.set_trace()
         if len(self.all_gather_bias_memory_pool) == 0:
             for _ in range(2):
                 weight = {}
@@ -209,9 +214,13 @@
         def _pre_forward_hook_for_out_proj(module: nn.Module, inputs: Any):  # pylint: disable=W0613
            block_index = self.module_to_index[module]
-            # start the all-gather for next block
-            if block_index + 1 < gpc.config.NUM_LAYER:
-                self._all_gather_block_weight_memory_pool(block_index + 1)
+            if self.model_checkpoint and self.is_forward is False:
+                if block_index - 1 >= 0:
+                    self._all_gather_block_weight_memory_pool(block_index - 1)
+            else:
+                # start the all-gather for next block
+                if block_index + 1 < gpc.config.NUM_LAYER:
+                    self._all_gather_block_weight_memory_pool(block_index + 1)

         def _pre_forward_hook_for_module(module: nn.Module, inputs: Any):  # pylint: disable=W0613
             handle = self.fstp_global_handle[module]
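The branch added above reverses the prefetch direction under activation checkpointing: during backward, each block's forward is recomputed in descending block order, so the out_proj pre-forward hook should all-gather the weights of block_index - 1 rather than block_index + 1. A small standalone helper (hypothetical, not part of this commit) that mirrors the index choice:

def next_prefetch_block(block_index: int, num_layer: int, checkpointing: bool, is_forward: bool):
    """Return the block whose weights should be all-gathered next, or None at the boundary."""
    if checkpointing and not is_forward:
        # recomputation during backward walks blocks from num_layer - 1 down to 0
        return block_index - 1 if block_index - 1 >= 0 else None
    # the normal forward pass walks blocks from 0 up to num_layer - 1
    return block_index + 1 if block_index + 1 < num_layer else None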
@@ -234,6 +243,9 @@
             )
             self.fstp_global_handle[first_backward_module] = weight_handle

+        def _pre_backward_hook_for_head(module: nn.Module, grad_output):
+            self._all_gather_block_weight_memory_pool(gpc.config.NUM_LAYER - 1)
+
         def _pre_backward_hook_for_module(module: nn.Module, grad_output):  # pylint: disable=W0613
             # wait handle for current module
             weight_handle = self.fstp_global_handle[module]
@@ -264,6 +276,10 @@
         for embedding in self.embedding:
             embedding.register_forward_hook(_post_forward_hook_for_embedding)

+        if self.model_checkpoint and self.is_forward is False:
+            for head in self.head:
+                head.register_full_backward_pre_hook(_pre_backward_hook_for_head)
+
         for out_proj in self.fstp_outs:
             out_proj.register_forward_pre_hook(_pre_forward_hook_for_out_proj)
@@ -275,9 +291,42 @@
         # 1. register post_backward_hook @head module to prefetch for the last block's last module
         # 2. register pre_backward_hook @fstp_module to wait handle for current module and to prefetch for next module
         # 3. register post_backward_hook @fstp_module to release resource
-        for head in self.head:
-            head.register_full_backward_hook(_post_backward_hook_for_head)
+        if gpc.config.model.checkpoint is False:
+            for head in self.head:
+                head.register_full_backward_hook(_post_backward_hook_for_head)

         for module in self.fstp_modules:
             module.register_full_backward_pre_hook(_pre_backward_hook_for_module)
             module.register_full_backward_hook(_post_backward_hook_for_module)
+
+
+class FSTPOverlapSchedulerHook(SchedulerHook):
+    """
+    SchedulerHook for fstp overlap handler
+    """
+
+    def __init__(self, overlap_handler: FSTPOverlapHandler) -> None:
+        super().__init__()
+        self._overlap_handler = overlap_handler
+
+    def before_forward(self, scheduler, inputs) -> None:
+        self._overlap_handler.set_forward_mode(True)
+
+    def after_forward(self, scheduler, outputs) -> None:
+        pass
+
+    def before_criterion(self, scheduler, outputs, label) -> None:
+        pass
+
+    def after_criterion(self, scheduler, loss) -> None:
+        pass
+
+    def before_backward(self, scheduler, outputs, outputs_grad) -> None:
+        self._overlap_handler.set_forward_mode(False)
+
+    def after_backward(self, scheduler, inputs_grad) -> None:
+        pass
+
+    def post_helper_func(self, scheduler, outputs, label) -> None:
+        pass
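FSTPOverlapSchedulerHook only flips the handler's forward/backward mode; the prefetch hooks registered above read is_forward to pick the prefetch direction. A rough usage sketch (toggle_overlap_mode is hypothetical; the None scheduler arguments are placeholders this hook ignores):

def toggle_overlap_mode(handler: FSTPOverlapHandler) -> None:
    hook = FSTPOverlapSchedulerHook(handler)

    hook.before_forward(None, None)         # handler.is_forward = True  -> prefetch block_index + 1
    # ... forward pass runs here ...

    hook.before_backward(None, None, None)  # handler.is_forward = False -> prefetch block_index - 1
    # ... backward pass (with recomputation) runs here ...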

View File

@@ -6,8 +6,7 @@ from tqdm import tqdm
 from internlm.core.context import ParallelMode
 from internlm.core.context import global_context as gpc
-from internlm.core.scheduler import SchedulerMetricHook
-from internlm.model.metrics import AccPerplex
+from internlm.model.metrics import AccPerplex, SchedulerMetricHook

 @contextmanager

View File

@ -5,6 +5,7 @@ import socket
import time import time
import traceback import traceback
from functools import partial from functools import partial
from typing import List, Optional
import torch import torch
import torch.distributed as dist import torch.distributed as dist
@@ -12,11 +13,12 @@ import torch.distributed as dist
 import internlm
 from internlm.core.context import ParallelMode
 from internlm.core.context import global_context as gpc
-from internlm.core.scheduler import SchedulerMetricHook
+from internlm.core.scheduler import SchedulerHook
 from internlm.core.trainer import TrainState
 from internlm.initialize import initialize_distributed_env
 from internlm.model.loss import FlashGPTLMLoss
-from internlm.model.metrics import AccPerplex
+from internlm.model.metrics import AccPerplex, SchedulerMetricHook
+from internlm.model.overlap_handler import FSTPOverlapSchedulerHook
 from internlm.monitor import initialize_monitor_manager, send_alert_message
 from internlm.monitor.monitor import monitor_manager as mm
 from internlm.train import (
@@ -67,6 +69,30 @@ def initialize_llm_logger(start_time: str):
     return uniscale_logger


+def get_scheduler_hooks(
+    metric: Optional[AccPerplex] = None, activation_checkpoint: bool = False
+) -> List[SchedulerHook]:
+    scheduler_hooks: List[SchedulerHook] = []
+
+    if metric is not None:
+        scheduler_hooks.append(
+            SchedulerMetricHook(
+                metric=metric,
+                skip=(
+                    gpc.is_using_pp()
+                    and hasattr(gpc.config.model, "num_chunks")
+                    and gpc.config.model.num_chunks > 1
+                    and gpc.config.parallel["pipeline"].get("interleaved_overlap", False)
+                ),
+            ),
+        )
+
+    if activation_checkpoint:
+        scheduler_hooks.append(FSTPOverlapSchedulerHook(gpc.fstp_handler))
+
+    return scheduler_hooks
+
+
 def main(args):
     # init setting
     skip_batches = gpc.config.data.skip_batches
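For reference, a sketch of how get_scheduler_hooks is meant to be called (it matches the initialize_trainer change below; metric is assumed to be the AccPerplex instance built in main):

hooks = get_scheduler_hooks(metric=metric, activation_checkpoint=gpc.config.model.checkpoint)
# activation_checkpoint=False -> [SchedulerMetricHook(...)]
# activation_checkpoint=True  -> [SchedulerMetricHook(...), FSTPOverlapSchedulerHook(gpc.fstp_handler)]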
@@ -149,17 +175,6 @@ def main(args):
     )

     # initialize trainer
-    scheduler_hooks = [
-        SchedulerMetricHook(
-            metric=metric,
-            skip=(
-                gpc.is_using_pp()
-                and hasattr(gpc.config.model, "num_chunks")
-                and gpc.config.model.num_chunks > 1
-                and gpc.config.parallel["pipeline"].get("interleaved_overlap", False)
-            ),
-        ),
-    ]
-
     trainer, train_dl, _, _ = internlm.initialize_trainer(
         model=model,
@@ -168,7 +183,7 @@
         train_dataloader=train_dl,
         lr_scheduler=lr_scheduler,
         beta2_scheduler=beta2_scheduler,
-        scheduler_hooks=scheduler_hooks,
+        scheduler_hooks=get_scheduler_hooks(metric, gpc.config.model.checkpoint),
     )

     # initialize simple memory profiler