mirror of https://github.com/InternLM/InternLM
support model activation checkpoint
parent 0996c47e49
commit 97dcefc389
@@ -1,4 +1,4 @@
-from .base_scheduler import BaseScheduler, SchedulerHook, SchedulerMetricHook
+from .base_scheduler import BaseScheduler, SchedulerHook
 from .no_pipeline_scheduler import NonPipelineScheduler
 from .pipeline_scheduler import InterleavedPipelineScheduler, PipelineScheduler

@@ -8,5 +8,4 @@ __all__ = [
     "InterleavedPipelineScheduler",
     "PipelineScheduler",
     "SchedulerHook",
-    "SchedulerMetricHook",
 ]
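The hunks above are evidently from the scheduler package's `__init__.py`: `SchedulerMetricHook` is dropped from the re-exports and from `__all__`, while `SchedulerHook` stays. The class itself is relocated to the metrics module later in this commit, so downstream code is expected to import the two hooks from different places. A minimal sketch of the import locations after this change, assuming the layout implied by the remaining hunks:

# sketch: import locations after this commit
from internlm.core.scheduler import SchedulerHook  # base interface, still exported here
from internlm.model.metrics import SchedulerMetricHook  # timing/metric hook, relocated below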
@@ -4,12 +4,11 @@
 # adopted from https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/engine

 from abc import ABC, abstractmethod
-from typing import Any, Callable, Iterable, Optional
+from typing import Any, Callable, Iterable

 import torch

 from internlm.core.engine import Engine
-from internlm.utils.megatron_timers import megatron_timer as timer


 class BaseScheduler(ABC):
@@ -147,41 +146,3 @@ class SchedulerHook(ABC):
     @abstractmethod
     def post_helper_func(self, scheduler, outputs, label) -> None:
         """A post helper function"""
-
-
-class SchedulerMetricHook(SchedulerHook):
-    """
-    Scheduler Metric Hook.
-    """
-
-    def __init__(self, metric: Optional[Callable] = None, skip: bool = False) -> None:
-        self._post_func = metric
-        self._skip = skip
-
-    def before_forward(self, scheduler, inputs) -> None:
-        if not self._skip:
-            timer("fwd").start()
-
-    def after_forward(self, scheduler, outputs) -> None:
-        if not self._skip:
-            timer("fwd").stop()
-
-    def before_criterion(self, scheduler, outputs, label) -> None:
-        if not self._skip:
-            timer("cal_loss").start()
-
-    def after_criterion(self, scheduler, loss) -> None:
-        if not self._skip:
-            timer("cal_loss").stop()
-
-    def before_backward(self, scheduler, outputs, outputs_grad) -> None:
-        if not self._skip:
-            timer("bwd").start()
-
-    def after_backward(self, scheduler, inputs_grad) -> None:
-        if not self._skip:
-            timer("bwd").stop()
-
-    def post_helper_func(self, scheduler, outputs, label) -> None:
-        if self._post_func is not None:
-            self._post_func(outputs, label)
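With the second hunk, `base_scheduler` keeps only the abstract `SchedulerHook` interface; the concrete `SchedulerMetricHook`, together with the `timer` import and the `Optional`/`Callable` typing imports it needed, moves out wholesale. Any hook now subclasses the interface and provides the seven callbacks seen above. A minimal sketch of a custom hook, assuming all seven methods are implemented (only `post_helper_func` is shown with `@abstractmethod`, so the other no-op bodies may be optional); `StepCountingHook` is a hypothetical example, not part of the commit:

from internlm.core.scheduler import SchedulerHook


class StepCountingHook(SchedulerHook):
    """Hypothetical hook that only counts forward/backward invocations (illustration)."""

    def __init__(self) -> None:
        self.forward_calls = 0
        self.backward_calls = 0

    def before_forward(self, scheduler, inputs) -> None:
        self.forward_calls += 1

    def after_forward(self, scheduler, outputs) -> None:
        pass

    def before_criterion(self, scheduler, outputs, label) -> None:
        pass

    def after_criterion(self, scheduler, loss) -> None:
        pass

    def before_backward(self, scheduler, outputs, outputs_grad) -> None:
        self.backward_calls += 1

    def after_backward(self, scheduler, inputs_grad) -> None:
        pass

    def post_helper_func(self, scheduler, outputs, label) -> None:
        pass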
@@ -1,4 +1,4 @@
-from typing import List
+from typing import Callable, List, Optional

 import torch
 from flash_attn.losses.cross_entropy import CrossEntropyLoss as FlashCrossEntropyLoss
@@ -6,6 +6,8 @@ from torch_scatter import scatter

 from internlm.core.context import ParallelMode
 from internlm.core.context import global_context as gpc
+from internlm.core.scheduler import SchedulerHook
+from internlm.utils.megatron_timers import megatron_timer as timer


 class AccPerplex:
@@ -260,3 +262,41 @@ class LossWithTypeId:
         self.ds_token_num.fill_(0.0)

         return res
+
+
+class SchedulerMetricHook(SchedulerHook):
+    """
+    Scheduler Metric Hook.
+    """
+
+    def __init__(self, metric: Optional[Callable] = None, skip: bool = False) -> None:
+        self._post_func = metric
+        self._skip = skip
+
+    def before_forward(self, scheduler, inputs) -> None:
+        if not self._skip:
+            timer("fwd").start()
+
+    def after_forward(self, scheduler, outputs) -> None:
+        if not self._skip:
+            timer("fwd").stop()
+
+    def before_criterion(self, scheduler, outputs, label) -> None:
+        if not self._skip:
+            timer("cal_loss").start()
+
+    def after_criterion(self, scheduler, loss) -> None:
+        if not self._skip:
+            timer("cal_loss").stop()
+
+    def before_backward(self, scheduler, outputs, outputs_grad) -> None:
+        if not self._skip:
+            timer("bwd").start()
+
+    def after_backward(self, scheduler, inputs_grad) -> None:
+        if not self._skip:
+            timer("bwd").stop()
+
+    def post_helper_func(self, scheduler, outputs, label) -> None:
+        if self._post_func is not None:
+            self._post_func(outputs, label)
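The class added here is the same `SchedulerMetricHook` that was deleted from `base_scheduler`, now living in the metrics module next to `AccPerplex`; the `megatron_timer` import moves with it, presumably so the core scheduler package no longer depends on the timing and metric utilities. A small sketch of how it is meant to be constructed, mirroring `get_scheduler_hooks` in `train.py` below; `build_metric_hook` is a hypothetical helper, not part of the commit:

from internlm.model.metrics import AccPerplex, SchedulerMetricHook


def build_metric_hook(metric: AccPerplex, skip_timers: bool = False) -> SchedulerMetricHook:
    # `metric` is invoked as metric(outputs, label) from post_helper_func;
    # skip_timers=True bypasses the "fwd"/"bwd"/"cal_loss" megatron timers.
    return SchedulerMetricHook(metric=metric, skip=skip_timers)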
@@ -8,6 +8,7 @@ from torch import nn

 from internlm.core.context import global_context as gpc
 from internlm.core.naive_amp import NaiveAMPModel
+from internlm.core.scheduler import SchedulerHook
 from internlm.model.embedding import Embedding1D
 from internlm.model.linear import FSTPLinear, ScaleColumnParallelLinear
 from internlm.model.utils import (
@@ -33,6 +34,8 @@ class FSTPOverlapHandler:
         self.index_to_fstp_modules = dict()  # key: transformer block index; value: fsdp modules
         self.head = []
         self.embedding = []
+        self.model_checkpoint = gpc.config.model.checkpoint
+        self.is_forward = True

         self.reduce_scatter_handlers = {}
         self.zero_const_pool = {}
@@ -81,6 +84,9 @@ class FSTPOverlapHandler:

         return self.zero_const_pool[size]

+    def set_forward_mode(self, flag):
+        self.is_forward = flag
+
     def _initialize_module_shape(self):
         hidden_size = gpc.config.HIDDEN_SIZE
         mlp_ratio = gpc.config.MLP_RATIO
@@ -121,7 +127,6 @@ class FSTPOverlapHandler:
     def get_bias_memory(self, module: nn.Module):
         block_index = self.module_to_index[module]
         # if the bias memory pool is empty or module has been not allocated memory
-        # import pdb; pdb.set_trace()
         if len(self.all_gather_bias_memory_pool) == 0:
             for _ in range(2):
                 weight = {}
@@ -209,9 +214,13 @@ class FSTPOverlapHandler:

         def _pre_forward_hook_for_out_proj(module: nn.Module, inputs: Any):  # pylint: disable=W0613
             block_index = self.module_to_index[module]
-            # start the all-gather for next block
-            if block_index + 1 < gpc.config.NUM_LAYER:
-                self._all_gather_block_weight_memory_pool(block_index + 1)
+            if self.model_checkpoint and self.is_forward is False:
+                if block_index - 1 >= 0:
+                    self._all_gather_block_weight_memory_pool(block_index - 1)
+            else:
+                # start the all-gather for next block
+                if block_index + 1 < gpc.config.NUM_LAYER:
+                    self._all_gather_block_weight_memory_pool(block_index + 1)

         def _pre_forward_hook_for_module(module: nn.Module, inputs: Any):  # pylint: disable=W0613
             handle = self.fstp_global_handle[module]
@@ -234,6 +243,9 @@ class FSTPOverlapHandler:
             )
             self.fstp_global_handle[first_backward_module] = weight_handle

+        def _pre_backward_hook_for_head(module: nn.Module, grad_output):
+            self._all_gather_block_weight_memory_pool(gpc.config.NUM_LAYER - 1)
+
         def _pre_backward_hook_for_module(module: nn.Module, grad_output):  # pylint: disable=W0613
             # wait handle for current module
             weight_handle = self.fstp_global_handle[module]
@@ -264,6 +276,10 @@ class FSTPOverlapHandler:
         for embedding in self.embedding:
             embedding.register_forward_hook(_post_forward_hook_for_embedding)

+        if self.model_checkpoint and self.is_forward is False:
+            for head in self.head:
+                head.register_full_backward_pre_hook(_pre_backward_hook_for_head)
+
         for out_proj in self.fstp_outs:
             out_proj.register_forward_pre_hook(_pre_forward_hook_for_out_proj)

@@ -275,9 +291,42 @@ class FSTPOverlapHandler:
         # 1. register post_backward_hook @head module to prefetch for the last block's last module
         # 2. register pre_backward_hook @fstp_module to wait handle for current module and to prefetch for next module
         # 3. register post_backward_hook @fstp_module to release resource
-        for head in self.head:
-            head.register_full_backward_hook(_post_backward_hook_for_head)
+        if gpc.config.model.checkpoint is False:
+            for head in self.head:
+                head.register_full_backward_hook(_post_backward_hook_for_head)

         for module in self.fstp_modules:
             module.register_full_backward_pre_hook(_pre_backward_hook_for_module)
             module.register_full_backward_hook(_post_backward_hook_for_module)
+
+
+class FSTPOverlapSchedulerHook(SchedulerHook):
+    """
+    SchedulerHook for fstp overlap handler
+    """
+
+    def __init__(self, overlap_handler: FSTPOverlapHandler) -> None:
+        super().__init__()
+
+        self._overlap_handler = overlap_handler
+
+    def before_forward(self, scheduler, inputs) -> None:
+        self._overlap_handler.set_forward_mode(True)
+
+    def after_forward(self, scheduler, outputs) -> None:
+        pass
+
+    def before_criterion(self, scheduler, outputs, label) -> None:
+        pass
+
+    def after_criterion(self, scheduler, loss) -> None:
+        pass
+
+    def before_backward(self, scheduler, outputs, outputs_grad) -> None:
+        self._overlap_handler.set_forward_mode(False)
+
+    def after_backward(self, scheduler, inputs_grad) -> None:
+        pass
+
+    def post_helper_func(self, scheduler, outputs, label) -> None:
+        pass
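These hunks, from `internlm/model/overlap_handler.py`, make the FSTP all-gather prefetch aware of activation checkpointing. The handler now records `model_checkpoint` from the config and an `is_forward` flag that the new `FSTPOverlapSchedulerHook` flips at the start of the forward and backward phases. During a normal forward pass the `out_proj` pre-forward hook still prefetches block `block_index + 1`; during the backward pass with checkpointing enabled, each block's forward is re-run in reverse block order, so the same hook prefetches `block_index - 1` instead, and a new pre-backward hook on the head kicks off the all-gather for the last block (`NUM_LAYER - 1`) to start that chain. The old post-backward hook on the head is registered only when checkpointing is off. A standalone sketch of the prefetch decision (an illustrative helper, not the repo's API):

from typing import Optional


def next_prefetch_index(
    block_index: int, num_layers: int, model_checkpoint: bool, is_forward: bool
) -> Optional[int]:
    """Which block's weights to all-gather when block `block_index` starts its forward."""
    if model_checkpoint and not is_forward:
        # Backward pass with activation checkpointing: blocks are recomputed in
        # reverse order, so prefetch the previous block's weights.
        return block_index - 1 if block_index - 1 >= 0 else None
    # Plain forward pass: prefetch the next block's weights.
    return block_index + 1 if block_index + 1 < num_layers else None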
@@ -6,8 +6,7 @@ from tqdm import tqdm

 from internlm.core.context import ParallelMode
 from internlm.core.context import global_context as gpc
-from internlm.core.scheduler import SchedulerMetricHook
-from internlm.model.metrics import AccPerplex
+from internlm.model.metrics import AccPerplex, SchedulerMetricHook


 @contextmanager
train.py (43 changed lines)
@@ -5,6 +5,7 @@ import socket
 import time
 import traceback
 from functools import partial
+from typing import List, Optional

 import torch
 import torch.distributed as dist
@@ -12,11 +13,12 @@ import torch.distributed as dist
 import internlm
 from internlm.core.context import ParallelMode
 from internlm.core.context import global_context as gpc
-from internlm.core.scheduler import SchedulerMetricHook
+from internlm.core.scheduler import SchedulerHook
 from internlm.core.trainer import TrainState
 from internlm.initialize import initialize_distributed_env
 from internlm.model.loss import FlashGPTLMLoss
-from internlm.model.metrics import AccPerplex
+from internlm.model.metrics import AccPerplex, SchedulerMetricHook
+from internlm.model.overlap_handler import FSTPOverlapSchedulerHook
 from internlm.monitor import initialize_monitor_manager, send_alert_message
 from internlm.monitor.monitor import monitor_manager as mm
 from internlm.train import (
@@ -67,6 +69,30 @@ def initialize_llm_logger(start_time: str):
     return uniscale_logger


+def get_scheduler_hooks(
+    metric: Optional[AccPerplex] = None, activation_checkpoint: bool = False
+) -> List[SchedulerHook]:
+    scheduler_hooks: List[SchedulerHook] = []
+
+    if metric is not None:
+        scheduler_hooks.append(
+            SchedulerMetricHook(
+                metric=metric,
+                skip=(
+                    gpc.is_using_pp()
+                    and hasattr(gpc.config.model, "num_chunks")
+                    and gpc.config.model.num_chunks > 1
+                    and gpc.config.parallel["pipeline"].get("interleaved_overlap", False)
+                ),
+            ),
+        )
+
+    if activation_checkpoint:
+        scheduler_hooks.append(FSTPOverlapSchedulerHook(gpc.fstp_handler))
+
+    return scheduler_hooks
+
+
 def main(args):
     # init setting
     skip_batches = gpc.config.data.skip_batches
@@ -149,17 +175,6 @@ def main(args):
     )

     # initialize trainer
-    scheduler_hooks = [
-        SchedulerMetricHook(
-            metric=metric,
-            skip=(
-                gpc.is_using_pp()
-                and hasattr(gpc.config.model, "num_chunks")
-                and gpc.config.model.num_chunks > 1
-                and gpc.config.parallel["pipeline"].get("interleaved_overlap", False)
-            ),
-        ),
-    ]

     trainer, train_dl, _, _ = internlm.initialize_trainer(
         model=model,
@@ -168,7 +183,7 @@ def main(args):
         train_dataloader=train_dl,
         lr_scheduler=lr_scheduler,
         beta2_scheduler=beta2_scheduler,
-        scheduler_hooks=scheduler_hooks,
+        scheduler_hooks=get_scheduler_hooks(metric, gpc.config.model.checkpoint),
     )

     # initialize simple memory profiler
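`get_scheduler_hooks` replaces the inline `scheduler_hooks` list in `main`: the metric hook is added whenever a metric is passed (its timers are skipped under interleaved pipeline parallelism, as before), and `FSTPOverlapSchedulerHook` is appended only when `gpc.config.model.checkpoint` is truthy, so the overlap handler learns about forward/backward phase changes only when activation checkpointing is on. A sketch of the two configurations, assuming `metric` and `gpc.fstp_handler` are set up as elsewhere in `train.py`:

# activation checkpointing disabled: only the metric hook is registered
hooks = get_scheduler_hooks(metric=metric, activation_checkpoint=False)
# -> [SchedulerMetricHook(...)]

# activation checkpointing enabled: the FSTP overlap hook is added as well,
# wrapping the handler stored on the global context during model initialization
hooks = get_scheduler_hooks(metric=metric, activation_checkpoint=True)
# -> [SchedulerMetricHook(...), FSTPOverlapSchedulerHook(gpc.fstp_handler)]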