feat(memory_profiler): improve memory profiler (#217)

pull/218/head^2
cx 2023-08-23 14:18:33 +08:00 committed by GitHub
parent 29779c75f0
commit a48210f1f3
2 changed files with 116 additions and 118 deletions

internlm/utils/simple_memory_profiler.py

@@ -1,15 +1,13 @@
 import os
 import time
 from collections import OrderedDict
-from functools import partial
+from functools import partial, reduce
 from typing import Any, Dict, List, Tuple

 import pyecharts
 import torch

-from internlm.core.context import ParallelMode
-from internlm.core.context import global_context as gpc
-from internlm.solver.pipeline_utils import partition_uniform
+from internlm.core.naive_amp import NaiveAMPModel

 mb = 1024 * 1024
@@ -107,6 +105,8 @@ class SimpleMemState:
         """
         Update the total memory usage of the model and sub-models.
         """
+        self._total_mem = self._layer_mem
+
         for stat in self.sub_model_stats.values():
             # Update sub-model status first.
             stat.update_total_memory()
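The added line resets the running total to the layer's own footprint before re-summing the children, which keeps repeated calls to update_total_memory() idempotent. A toy illustration of that accumulation pattern (not the library class; names are made up):

# Toy version of the fix: without resetting the running total, calling
# update_total_memory() twice would double count the children.
class Node:
    def __init__(self, layer_mem, children=()):
        self._layer_mem = layer_mem
        self._total_mem = layer_mem
        self.children = list(children)

    def update_total_memory(self):
        self._total_mem = self._layer_mem  # the added reset
        for child in self.children:
            child.update_total_memory()
            self._total_mem += child._total_mem

root = Node(10, [Node(5), Node(7)])
root.update_total_memory()
root.update_total_memory()
print(root._total_mem)  # 22 on both calls; without the reset the second call would report 34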
@@ -169,6 +169,39 @@ class SimpleMemState:
         return {"name": self.layer_name, "children": children}


+class ActivationMemState:
+    """
+    Activation Memory State
+    """
+
+    def __init__(self, num_chunks: int) -> None:
+        self._num_chunks = num_chunks
+
+        self.inited: List[bool] = [False for _ in range(num_chunks)]
+        self.states: List[SimpleMemState] = [SimpleMemState(f"activations_{idx}") for idx in range(num_chunks)]
+
+    @property
+    def total_mem(self) -> int:
+        return sum(state.total_mem for state in self.states)
+
+    def dump(self, prefix: str = "") -> str:
+        return reduce(lambda x, y: x + y, [state.dump(prefix) for state in self.states])
+
+    def to_json(self, base: int = 1024 * 1024) -> List:
+        return [state.to_json(base) for state in self.states]
+
+
+def _unpack_naive_wrapper(model: torch.nn.Module) -> Tuple[torch.nn.Module, int]:
+    num_chunks = len(model) if isinstance(model, torch.nn.ModuleList) else 1
+
+    if num_chunks > 1:
+        model = torch.nn.ModuleList([_model.model if isinstance(_model, NaiveAMPModel) else _model for _model in model])
+    else:
+        model = model.model if isinstance(model, NaiveAMPModel) else model
+
+    return model, num_chunks
+
+
 class SimpleMemoryProfiler:
     """
     A memory profiler for a llm model.
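For context, a minimal standalone sketch of the unwrapping idea above. DummyAMPWrapper is a hypothetical stand-in for internlm.core.naive_amp.NaiveAMPModel, assumed only to keep the wrapped module under a .model attribute; a torch.nn.ModuleList of wrappers is treated as one chunk per entry (interleaved pipeline scheduling), otherwise the model is a single chunk.

import torch


class DummyAMPWrapper(torch.nn.Module):
    """Hypothetical stand-in for NaiveAMPModel: keeps the real module at .model."""

    def __init__(self, model: torch.nn.Module) -> None:
        super().__init__()
        self.model = model


def unpack(model: torch.nn.Module):
    # One ModuleList entry per pipeline chunk; otherwise a single chunk.
    num_chunks = len(model) if isinstance(model, torch.nn.ModuleList) else 1
    if num_chunks > 1:
        model = torch.nn.ModuleList(
            [m.model if isinstance(m, DummyAMPWrapper) else m for m in model]
        )
    else:
        model = model.model if isinstance(model, DummyAMPWrapper) else model
    return model, num_chunks


chunks = torch.nn.ModuleList([DummyAMPWrapper(torch.nn.Linear(8, 8)) for _ in range(2)])
unwrapped, n = unpack(chunks)
print(n, type(unwrapped[0]).__name__)  # 2 Linear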
@@ -177,7 +210,7 @@ class SimpleMemoryProfiler:
         model (torch.nn.Module): The model to profile.
         optimizer (torch.optim.Optimizer): The optimizer used for training the model.
         log_file (str): The file to write the memory state information to.
-        activation_config (List[str], optional): The list of activation layers to track. Defaults to None.
+        total_steps: number of steps to trace.
     """

     def __init__(
@@ -186,9 +219,8 @@
         optimizer: torch.optim.Optimizer,
         log_folder: str,
         total_steps: int = 5,
-        activation_config: List[str] = None,
     ):
-        self._model = model
+        self._model, self._num_model_chunks = _unpack_naive_wrapper(model)
         self._optimizer = optimizer
         self._log_folder = log_folder
         self._remaining_steps = total_steps
@@ -197,17 +229,20 @@
         self._record_start_time = time.time()

         # For activation memory state.
-        self._activation_config = activation_config
-        self._activation_mem_inited: bool = False
         self._activation_mem: int = 0
-        self._activation_max_count = 0
-        self._activation_base_mem: SimpleMemState = SimpleMemState("activations")
+        self._activation_mem_max: int = 0
+        self._activation_base_mems = ActivationMemState(self._num_model_chunks)

         # Check or create log folder
         os.makedirs(self._log_folder, exist_ok=True)

         # Register activation memory tracking hooks
-        self._register_activation_trace_hooks()
+        if self._num_model_chunks > 1:
+            for chunk_id in range(self._num_model_chunks):
+                self._register_activation_trace_hooks(chunk_id, self._model[chunk_id])
+        else:
+            self._register_activation_trace_hooks(0, self._model)

         # Calculate static parameter cuda memory
         self._param_mem_state = SimpleMemState("param_mem")
@@ -221,7 +256,7 @@ class SimpleMemoryProfiler:
         self._calc_tensor_group_memory(self._os_params_mem_state, list(enumerate(self._optimizer.param_groups)))

         # Generate the first memory record
-        self.point(create=True)
+        self.point(with_options="params,grads,os_params", create=True)

     def point(self, with_options: str = "", create: bool = False) -> None:
         """
@@ -272,7 +307,7 @@ class SimpleMemoryProfiler:
         if "os_state" in options:
             layout_info += "os_state_layout:\n" + self._os_state_mem_state.dump()
         if "activation_base" in options:
-            layout_info += "activation_base_layout:\n" + self._activation_base_mem.dump()
+            layout_info += "activation_base_layout:\n" + self._activation_base_mems.dump()

         # Write memory state information to log file
         file_mode = "w" if create else "a"
@@ -315,14 +350,14 @@ class SimpleMemoryProfiler:
             [self._os_params_mem_state.to_json(), self._os_state_mem_state.to_json()],
             "os_memory_sunburst",
         )
-        self._render_sunburst_chart(self._activation_base_mem.to_json()["children"], "activation_memory_sunburst")
+        self._render_sunburst_chart(self._activation_base_mems.to_json(), "activation_memory_sunburst")

         # Generate summary sunburst chart
         summary_sunburst_data = [
             {"name": "params", "value": self._param_mem_state.total_mem // mb},
             {"name": "grads", "value": self._grad_mem_state.total_mem // mb},
             {"name": "os_params", "value": self._os_params_mem_state.total_mem // mb},
             {"name": "os_state", "value": self._os_state_mem_state.total_mem // mb},
-            {"name": "activation", "value": self._activation_base_mem.total_mem // mb},
+            {"name": "activation", "value": self._activation_mem_max // mb},
         ]

         self._render_sunburst_chart(summary_sunburst_data, "summary_sunburst")
@@ -337,12 +372,13 @@ class SimpleMemoryProfiler:
                {},
                {
                    "r0": "10%",
-                   "r": "40%",
+                   "r": "35%",
                    "itemStyle": {"borderWidth": 3},
                    "label": {"align": "left"},
                },
-               {"r0": "40%", "r": "65%", "label": {"align": "left"}},
-               {"r0": "65%", "r": "80%", "label": {"align": "left"}},
+               {"r0": "35%", "r": "55%", "label": {"align": "left"}},
+               {"r0": "55%", "r": "70%", "label": {"align": "left"}},
+               {"r0": "70%", "r": "80%", "label": {"align": "left"}},
                {"r0": "80%", "r": "90%", "label": {"align": "left"}},
                {
                    "r0": "90%",
@@ -357,7 +393,14 @@ class SimpleMemoryProfiler:
             f"{self._log_folder}/{name}.html"
         )

-    def _inner_activation_trace_hook(self, layer_name: str, model: Any, inputs: Any, output: torch.Tensor) -> None:
+    def _inner_activation_trace_hook(
+        self,
+        chunk_id: int,
+        layer_name: str,
+        model: Any,
+        inputs: Any,
+        output: torch.Tensor,
+    ) -> None:
         """
         Hook function to trace the activation memory usage for a inner layer.
@@ -373,13 +416,15 @@ class SimpleMemoryProfiler:
         del model, inputs
         assert isinstance(output, torch.Tensor), f"Invalid output type: {type(output)}"

-        if self._stoped or self._activation_mem_inited:
+        if self._stoped or self._activation_base_mems.inited[chunk_id]:
             return

         # Delay updating the total_mem of activation_base_mem here, it will be handled in the forward ending hook.
-        self._activation_base_mem.add(layer_name, output.element_size() * output.nelement(), flush=False)
+        self._activation_base_mems.states[chunk_id].add(
+            layer_name, output.element_size() * output.nelement(), flush=False
+        )

-    def _activation_trace_hook_forward(self, model: Any, inputs: Any, output: torch.Tensor) -> None:
+    def _activation_trace_hook_forward(self, chunk_id: int, model: Any, inputs: Any, output: torch.Tensor) -> None:
         """
         Hook function to trace the activation memory usage for a forward pass.
@@ -398,23 +443,24 @@ class SimpleMemoryProfiler:
             return

         # Check if the activation memory has been initialized
-        if self._activation_mem_inited is False:
+        if self._activation_base_mems.inited[chunk_id] is False:
+            self._activation_base_mems.inited[chunk_id] = True
             # Update the total memory of the activation base memory state
-            self._activation_base_mem.update_total_memory()
+            self._activation_base_mems.states[chunk_id].update_total_memory()
             # Set with_options to "activation_base" to include activation_base_layout in the memory dump
-            self._activation_mem_inited = True
+            with_options = "activation_base"
+        else:
+            with_options = ""

         # Accumulate activation memory usage for each forward pass
-        self._activation_mem += self._activation_base_mem.total_mem
-
-        # Update activation max count
-        if self._activation_mem // self._activation_base_mem.total_mem > self._activation_max_count:
-            self._activation_max_count = self._activation_mem // self._activation_base_mem.total_mem
+        self._activation_mem += self._activation_base_mems.states[chunk_id].total_mem
+        if self._activation_mem > self._activation_mem_max:
+            self._activation_mem_max = self._activation_mem

         # Trigger a memory record
-        self.point()
+        self.point(with_options)

-    def _activation_tarce_hook_backward(self, model: Any, inputs: Any, grad_outputs: Any) -> None:
+    def _activation_tarce_hook_backward(self, chunk_id: int, model: Any, inputs: Any, grad_outputs: Any) -> None:
         """
         Hook function to trace the activation memory usage for a backward pass.
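The forward hook adds the measured per-chunk activation footprint to a running total and tracks its maximum, while the backward hook (next hunk) subtracts it again; that maximum is what the summary chart reports as "activation". A toy, torch-free illustration of this water-level bookkeeping (the chunk sizes are made up):

# Toy bookkeeping: two pipeline chunks with activation footprints of 300 MB and
# 200 MB, run as forward, forward, backward, backward.
chunk_mem = {0: 300, 1: 200}
current = peak = 0

def forward(chunk_id):
    global current, peak
    current += chunk_mem[chunk_id]
    peak = max(peak, current)

def backward(chunk_id):
    global current
    current -= chunk_mem[chunk_id]

for event, chunk in [("fwd", 0), ("fwd", 1), ("bwd", 1), ("bwd", 0)]:
    forward(chunk) if event == "fwd" else backward(chunk)

print(current, peak)  # 0 500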
@@ -432,37 +478,28 @@ class SimpleMemoryProfiler:
             return

         # Release activation memory usage for each backward pass
-        self._activation_mem -= self._activation_base_mem.total_mem
+        self._activation_mem -= self._activation_base_mems.states[chunk_id].total_mem

         # Trigger a memory record
         self.point()

-    def _register_activation_trace_hooks(self) -> None:
+    def _register_activation_trace_hooks(self, chunk_id: int, model_chunk: torch.nn.Module) -> None:
         """
         Register activation trace hooks for the model and each submodule in the model.
         """
         # Register inner activation trace hooks for each submodule in the model
-        for layer_name in self._activation_config:
-            # Register a hook for every activation
-            model = self._model
-            sub_models = layer_name.split(".")
-            # Get the target sub-model
-            for sub_model_name in sub_models:
-                try:
-                    model = model.get_submodule(sub_model_name)
-                except AttributeError:
-                    model = None
-                    break
-
+        for layer_name, sub_model in model_chunk.named_modules():
             # Register the hook
-            if model is not None:
-                model.register_forward_hook(partial(self._inner_activation_trace_hook, layer_name))
+            if len(sub_model._modules) != 0:
+                continue  # TODO: in some special cases, we may need some additional configuration to correct
+            sub_model.register_forward_hook(partial(self._inner_activation_trace_hook, chunk_id, layer_name))

         # Register a forward hook for the main model to track activation memory usage
-        self._model.register_forward_hook(self._activation_trace_hook_forward)
+        model_chunk.register_forward_hook(partial(self._activation_trace_hook_forward, chunk_id))
         # Register a backward hook for the main model to release activation memory usage
-        self._model.register_full_backward_hook(self._activation_tarce_hook_backward)
+        model_chunk.register_full_backward_hook(partial(self._activation_tarce_hook_backward, chunk_id))

     def _calc_tensor_memory(
         self, root_stat: SimpleMemState, named_tensors: Dict[str, torch.Tensor], require_grad: bool = False
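A self-contained sketch of the registration pattern used above: iterate named_modules(), skip modules that still have children (only leaf modules get hooked), and bind the chunk id and layer name into the hook with functools.partial. The three-layer model here is made up purely for illustration.

from functools import partial

import torch


def leaf_hook(chunk_id: int, layer_name: str, module, inputs, output):
    # Called after each leaf module's forward; record the output tensor size.
    if isinstance(output, torch.Tensor):
        nbytes = output.element_size() * output.nelement()
        print(f"chunk {chunk_id} {layer_name}: {nbytes} bytes")


model = torch.nn.Sequential(torch.nn.Linear(16, 32), torch.nn.ReLU(), torch.nn.Linear(32, 8))

for name, sub in model.named_modules():
    if len(sub._modules) != 0:
        continue  # non-leaf modules are skipped; their children are hooked instead
    sub.register_forward_hook(partial(leaf_hook, 0, name))

model(torch.randn(4, 16))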
@@ -554,48 +591,6 @@ class SimpleMemoryProfiler:
         self._calc_tensor_memory(root_stat, named_tensors)


-def build_activation_config(num_layers: int, num_chunks: int = 1) -> List[str]:
-    # TODO: support interleaved pipeline scheduling.
-    assert num_chunks == 1, "Only support num_chunks == 1"
-
-    if gpc.is_initialized(ParallelMode.PIPELINE):
-        pipeline_size = gpc.get_world_size(ParallelMode.PIPELINE)
-        pipeline_rank = gpc.get_local_rank(ParallelMode.PIPELINE)
-    else:
-        pipeline_size = 1
-        pipeline_rank = 0
-
-    all_parts = partition_uniform(num_layers, pipeline_size, num_chunks)
-    parts = all_parts[pipeline_rank]
-    start, end = parts[0]
-    num_blocks = end - start
-
-    block_conf_tmpl = [
-        "mixer.rotary_emb",
-        "mixer.Wqkv",
-        "mixer.inner_attn",
-        "mixer.inner_cross_attn",
-        "mixer.out_proj",
-        # "dropout1", # skip when dropout_selective_checkpoint is True
-        # "dropout2", # skip when dropout_selective_checkpoint is True
-        "norm1",
-        "norm2",
-        "mlp.w1",
-        "mlp.w2",
-        "mlp.w3",
-    ]
-
-    block_conf = []
-    for block_id in range(num_blocks):
-        block_conf += [f"blocks.{block_id}.{layer}" for layer in block_conf_tmpl]
-
-    # We don't need to care about whether the embedding, norm, and head layers exist in the model after partitioning.
-    # If they don't exist, they will be automatically ignored when registering activation trace hooks.
-    activation_conf = ["embedding", "norm", "head"] + block_conf
-
-    return activation_conf
-
-
 if __name__ == "__main__":

     class SimpleModel(torch.nn.Module):
@@ -635,32 +630,39 @@ if __name__ == "__main__":
             return output

+    def _simple_schedule(_num_chunks, _model_chunks, _input) -> torch.Tensor:
+        if _num_chunks > 1:
+            _output = _input
+            for _model_chunk in _model_chunks:
+                _output = _model_chunk(_output)
+        else:
+            _output = _model_chunks(_input)
+
+        return _output
+
+    # num_chunks config
+    _num_chunks = 1
+
     # init model and optimizer
-    _model: torch.nn.Module = SimpleModel()
+    if _num_chunks > 1:
+        _chunks = [SimpleModel(skip_layer2=idx % 2 == 0) for idx in range(_num_chunks)]
+        _model = torch.nn.ModuleList(_chunks).cuda()
+    else:
+        _model: torch.nn.Module = SimpleModel().cuda()
+
     _optimizer = torch.optim.Adam(_model.parameters())

-    # create activation config for simple model layer by layer.
-    activation_configs = [
-        # model level 0
-        "layer1",
-        "layer2",
-        "layer3",
-        # model level 1
-        "layer2.layer1",
-        "layer2.layer3",
-    ]
-
-    _model.modules()
-
     # init profiler
-    profiler = SimpleMemoryProfiler(_model, _optimizer, "./test_simple_memory_profiler.log", activation_configs)
+    profiler = SimpleMemoryProfiler(_model, _optimizer, "./test_simple_memory_profiler", total_steps=1)

     _optimizer.zero_grad()

-    x1 = torch.randn((128, 5120))
-    x2 = torch.randn((128, 5120))
-    out1 = _model(x1)
-    out2 = _model(x2)
+    # inputs
+    x1 = torch.randn((128, 5120)).cuda()
+    x2 = torch.randn((128, 5120)).cuda()
+    # forward
+    out1 = _simple_schedule(_num_chunks, _model, x1)
+    out2 = _simple_schedule(_num_chunks, _model, x2)
+    # backward
     out1.mean().backward()
     out2.mean().backward()
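For comparison, a standalone sketch of the _num_chunks > 1 path, using plain torch.nn.Sequential stand-ins in place of SimpleModel so it runs on CPU without the profiler (the chunk contents are made up):

import torch

# Hypothetical chunks standing in for SimpleModel(skip_layer2=...) instances.
_chunks = torch.nn.ModuleList(
    [torch.nn.Sequential(torch.nn.Linear(512, 512), torch.nn.ReLU()) for _ in range(2)]
)

def _run_chunks(_model_chunks, _input):
    # Same idea as _simple_schedule with _num_chunks > 1: each chunk consumes
    # the previous chunk's output, mimicking a naive pipeline schedule.
    _output = _input
    for _model_chunk in _model_chunks:
        _output = _model_chunk(_output)
    return _output

x = torch.randn((4, 512))
out = _run_chunks(_chunks, x)
out.mean().backward()
print(out.shape)  # torch.Size([4, 512])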

train.py

@@ -55,10 +55,7 @@ from internlm.utils.parallel import (
     sync_model_param_within_tp,
 )
 from internlm.utils.registry import MODEL_INITIALIZER
-from internlm.utils.simple_memory_profiler import (
-    SimpleMemoryProfiler,
-    build_activation_config,
-)
+from internlm.utils.simple_memory_profiler import SimpleMemoryProfiler
 from internlm.utils.writer import Writer

 # global llm logger
@@ -556,12 +553,11 @@ def main(args):
     # initialize simple memory profiler
     if args.profiling:
         memory_profiler = SimpleMemoryProfiler(
-            model.model,
+            model,
             optimizer.optim,
             log_folder=f"memory_trace/rank{gpc.get_global_rank()}_"
            + f"dp{gpc.get_local_rank(ParallelMode.DATA)}_"
            + f"tp{gpc.get_local_rank(ParallelMode.TENSOR)}",
-            activation_config=build_activation_config(gpc.config.model.num_layers),
         )
     else:
         memory_profiler = None
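As a concrete illustration of the log_folder naming above, each rank writes its trace into its own folder; the rank values below are hypothetical.

# Hypothetical ranks; in the training script these come from gpc.get_global_rank()
# and gpc.get_local_rank(...) for the data- and tensor-parallel groups.
global_rank, dp_rank, tp_rank = 5, 1, 0
log_folder = f"memory_trace/rank{global_rank}_" + f"dp{dp_rank}_" + f"tp{tp_rank}"
print(log_folder)  # memory_trace/rank5_dp1_tp0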