mirror of https://github.com/InternLM/InternLM
feat(memory_profiler): improve memory profiler (#217)
parent 29779c75f0
commit a48210f1f3
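This commit reworks activation tracking in the simple memory profiler. The hand-maintained activation_config list and the build_activation_config helper are removed; instead the profiler walks named_modules() and hooks every leaf module automatically. Models wrapped in NaiveAMPModel, including torch.nn.ModuleList pipeline chunks, are unwrapped by the new _unpack_naive_wrapper helper, per-chunk activation statistics are kept in the new ActivationMemState, and the summary chart now reports the observed peak activation memory (_activation_mem_max) instead of a multiple of the base footprint.

The core of the new registration strategy, shown as a standalone sketch (register_leaf_hooks is a name used here for illustration; in the diff the logic lives in _register_activation_trace_hooks):

from functools import partial

import torch


def register_leaf_hooks(model_chunk: torch.nn.Module, hook, chunk_id: int) -> None:
    # Hook only leaf modules (no children); hooking parents as well would
    # double-count the same activations.
    for layer_name, sub_model in model_chunk.named_modules():
        if len(sub_model._modules) != 0:
            continue
        # partial() prepends chunk_id and layer_name to the standard
        # (module, inputs, output) arguments of a forward hook.
        sub_model.register_forward_hook(partial(hook, chunk_id, layer_name))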
internlm/utils/simple_memory_profiler.py

@@ -1,15 +1,13 @@
 import os
 import time
 from collections import OrderedDict
-from functools import partial
+from functools import partial, reduce
 from typing import Any, Dict, List, Tuple
 
 import pyecharts
 import torch
 
-from internlm.core.context import ParallelMode
-from internlm.core.context import global_context as gpc
-from internlm.solver.pipeline_utils import partition_uniform
+from internlm.core.naive_amp import NaiveAMPModel
 
 mb = 1024 * 1024
 
@@ -107,6 +105,8 @@ class SimpleMemState:
         """
         Update the total memory usage of the model and sub-models.
         """
+        self._total_mem = self._layer_mem
+
         for stat in self.sub_model_stats.values():
             # Update sub-model status first.
             stat.update_total_memory()
@@ -169,6 +169,39 @@ class SimpleMemState:
         return {"name": self.layer_name, "children": children}
 
 
+class ActivationMemState:
+    """
+    Activation Memory State
+    """
+
+    def __init__(self, num_chunks: int) -> None:
+        self._num_chunks = num_chunks
+
+        self.inited: List[bool] = [False for _ in range(num_chunks)]
+        self.states: List[SimpleMemState] = [SimpleMemState(f"activations_{idx}") for idx in range(num_chunks)]
+
+    @property
+    def total_mem(self) -> int:
+        return sum(state.total_mem for state in self.states)
+
+    def dump(self, prefix: str = "") -> str:
+        return reduce(lambda x, y: x + y, [state.dump(prefix) for state in self.states])
+
+    def to_json(self, base: int = 1024 * 1024) -> List:
+        return [state.to_json(base) for state in self.states]
+
+
+def _unpack_naive_wrapper(model: torch.nn.Module) -> Tuple[torch.nn.Module, int]:
+    num_chunks = len(model) if isinstance(model, torch.nn.ModuleList) else 1
+
+    if num_chunks > 1:
+        model = torch.nn.ModuleList([_model.model if isinstance(_model, NaiveAMPModel) else _model for _model in model])
+    else:
+        model = model.model if isinstance(model, NaiveAMPModel) else model
+
+    return model, num_chunks
+
+
 class SimpleMemoryProfiler:
     """
     A memory profiler for a llm model.
@@ -177,7 +210,7 @@ class SimpleMemoryProfiler:
         model (torch.nn.Module): The model to profile.
         optimizer (torch.optim.Optimizer): The optimizer used for training the model.
         log_file (str): The file to write the memory state information to.
-        activation_config (List[str], optional): The list of activation layers to track. Defaults to None.
+        total_steps: number of steps to trace.
     """
 
     def __init__(
@@ -186,9 +219,8 @@ class SimpleMemoryProfiler:
         optimizer: torch.optim.Optimizer,
         log_folder: str,
         total_steps: int = 5,
-        activation_config: List[str] = None,
     ):
-        self._model = model
+        self._model, self._num_model_chunks = _unpack_naive_wrapper(model)
         self._optimizer = optimizer
         self._log_folder = log_folder
         self._remaining_steps = total_steps
@@ -197,17 +229,20 @@ class SimpleMemoryProfiler:
         self._record_start_time = time.time()
 
         # For activation memory state.
-        self._activation_config = activation_config
-        self._activation_mem_inited: bool = False
         self._activation_mem: int = 0
-        self._activation_max_count = 0
-        self._activation_base_mem: SimpleMemState = SimpleMemState("activations")
+        self._activation_mem_max: int = 0
+        self._activation_base_mems = ActivationMemState(self._num_model_chunks)
 
         # Check or create log folder
         os.makedirs(self._log_folder, exist_ok=True)
 
         # Register activation memory tracking hooks
-        self._register_activation_trace_hooks()
+        if self._num_model_chunks > 1:
+            for chunk_id in range(self._num_model_chunks):
+                self._register_activation_trace_hooks(chunk_id, self._model[chunk_id])
+        else:
+            self._register_activation_trace_hooks(0, self._model)
 
         # Calculate static parameter cuda memory
         self._param_mem_state = SimpleMemState("param_mem")
@@ -221,7 +256,7 @@ class SimpleMemoryProfiler:
         self._calc_tensor_group_memory(self._os_params_mem_state, list(enumerate(self._optimizer.param_groups)))
 
         # Generate the first memory record
-        self.point(create=True)
+        self.point(with_options="params,grads,os_params", create=True)
 
     def point(self, with_options: str = "", create: bool = False) -> None:
         """
@@ -272,7 +307,7 @@ class SimpleMemoryProfiler:
         if "os_state" in options:
             layout_info += "os_state_layout:\n" + self._os_state_mem_state.dump()
         if "activation_base" in options:
-            layout_info += "activation_base_layout:\n" + self._activation_base_mem.dump()
+            layout_info += "activation_base_layout:\n" + self._activation_base_mems.dump()
 
         # Write memory state information to log file
         file_mode = "w" if create else "a"
@@ -315,14 +350,14 @@ class SimpleMemoryProfiler:
             [self._os_params_mem_state.to_json(), self._os_state_mem_state.to_json()],
             "os_memory_sunburst",
         )
-        self._render_sunburst_chart(self._activation_base_mem.to_json()["children"], "activation_memory_sunburst")
+        self._render_sunburst_chart(self._activation_base_mems.to_json(), "activation_memory_sunburst")
         # Generate summary sunburst chart
         summary_sunburst_data = [
             {"name": "params", "value": self._param_mem_state.total_mem // mb},
             {"name": "grads", "value": self._grad_mem_state.total_mem // mb},
             {"name": "os_params", "value": self._os_params_mem_state.total_mem // mb},
             {"name": "os_state", "value": self._os_state_mem_state.total_mem // mb},
-            {"name": "activation", "value": self._activation_base_mem.total_mem // mb},
+            {"name": "activation", "value": self._activation_mem_max // mb},
         ]
 
         self._render_sunburst_chart(summary_sunburst_data, "summary_sunburst")
@@ -337,12 +372,13 @@ class SimpleMemoryProfiler:
                 {},
                 {
                     "r0": "10%",
-                    "r": "40%",
+                    "r": "35%",
                     "itemStyle": {"borderWidth": 3},
                     "label": {"align": "left"},
                 },
-                {"r0": "40%", "r": "65%", "label": {"align": "left"}},
-                {"r0": "65%", "r": "80%", "label": {"align": "left"}},
+                {"r0": "35%", "r": "55%", "label": {"align": "left"}},
+                {"r0": "55%", "r": "70%", "label": {"align": "left"}},
+                {"r0": "70%", "r": "80%", "label": {"align": "left"}},
                 {"r0": "80%", "r": "90%", "label": {"align": "left"}},
                 {
                     "r0": "90%",
@@ -357,7 +393,14 @@ class SimpleMemoryProfiler:
             f"{self._log_folder}/{name}.html"
         )
 
-    def _inner_activation_trace_hook(self, layer_name: str, model: Any, inputs: Any, output: torch.Tensor) -> None:
+    def _inner_activation_trace_hook(
+        self,
+        chunk_id: int,
+        layer_name: str,
+        model: Any,
+        inputs: Any,
+        output: torch.Tensor,
+    ) -> None:
         """
         Hook function to trace the activation memory usage for a inner layer.
@@ -373,13 +416,15 @@ class SimpleMemoryProfiler:
         del model, inputs
         assert isinstance(output, torch.Tensor), f"Invalid output type: {type(output)}"
 
-        if self._stoped or self._activation_mem_inited:
+        if self._stoped or self._activation_base_mems.inited[chunk_id]:
             return
 
         # Delay updating the total_mem of activation_base_mem here, it will be handled in the forward ending hook.
-        self._activation_base_mem.add(layer_name, output.element_size() * output.nelement(), flush=False)
+        self._activation_base_mems.states[chunk_id].add(
+            layer_name, output.element_size() * output.nelement(), flush=False
+        )
 
-    def _activation_trace_hook_forward(self, model: Any, inputs: Any, output: torch.Tensor) -> None:
+    def _activation_trace_hook_forward(self, chunk_id: int, model: Any, inputs: Any, output: torch.Tensor) -> None:
         """
         Hook function to trace the activation memory usage for a forward pass.
@@ -398,23 +443,24 @@ class SimpleMemoryProfiler:
             return
 
         # Check if the activation memory has been initialized
-        if self._activation_mem_inited is False:
+        if self._activation_base_mems.inited[chunk_id] is False:
+            self._activation_base_mems.inited[chunk_id] = True
             # Update the total memory of the activation base memory state
-            self._activation_base_mem.update_total_memory()
+            self._activation_base_mems.states[chunk_id].update_total_memory()
             # Set with_options to "activation_base" to include activation_base_layout in the memory dump
-            self._activation_mem_inited = True
+            with_options = "activation_base"
+        else:
+            with_options = ""
 
         # Accumulate activation memory usage for each forward pass
-        self._activation_mem += self._activation_base_mem.total_mem
-        # Update activation max count
-        if self._activation_mem // self._activation_base_mem.total_mem > self._activation_max_count:
-            self._activation_max_count = self._activation_mem // self._activation_base_mem.total_mem
+        self._activation_mem += self._activation_base_mems.states[chunk_id].total_mem
+        if self._activation_mem > self._activation_mem_max:
+            self._activation_mem_max = self._activation_mem
 
         # Trigger a memory record
-        self.point()
+        self.point(with_options)
 
-    def _activation_tarce_hook_backward(self, model: Any, inputs: Any, grad_outputs: Any) -> None:
+    def _activation_tarce_hook_backward(self, chunk_id: int, model: Any, inputs: Any, grad_outputs: Any) -> None:
         """
         Hook function to trace the activation memory usage for a backward pass.
@@ -432,37 +478,28 @@ class SimpleMemoryProfiler:
             return
 
         # Release activation memory usage for each backward pass
-        self._activation_mem -= self._activation_base_mem.total_mem
+        self._activation_mem -= self._activation_base_mems.states[chunk_id].total_mem
 
         # Trigger a memory record
         self.point()
 
-    def _register_activation_trace_hooks(self) -> None:
+    def _register_activation_trace_hooks(self, chunk_id: int, model_chunk: torch.nn.Module) -> None:
         """
         Register activation trace hooks for the model and each submodule in the model.
         """
 
         # Register inner activation trace hooks for each submodule in the model
-        for layer_name in self._activation_config:
-            # Register a hook for every activation
-            model = self._model
-            sub_models = layer_name.split(".")
-            # Get the target sub-model
-            for sub_model_name in sub_models:
-                try:
-                    model = model.get_submodule(sub_model_name)
-                except AttributeError:
-                    model = None
-                    break
-
+        for layer_name, sub_model in model_chunk.named_modules():
             # Register the hook
-            if model is not None:
-                model.register_forward_hook(partial(self._inner_activation_trace_hook, layer_name))
+            if len(sub_model._modules) != 0:
+                continue  # TODO: in some special cases, we may need some additional configuration to correct
+
+            sub_model.register_forward_hook(partial(self._inner_activation_trace_hook, chunk_id, layer_name))
 
         # Register a forward hook for the main model to track activation memory usage
-        self._model.register_forward_hook(self._activation_trace_hook_forward)
+        model_chunk.register_forward_hook(partial(self._activation_trace_hook_forward, chunk_id))
         # Register a backward hook for the main model to release activation memory usage
-        self._model.register_full_backward_hook(self._activation_tarce_hook_backward)
+        model_chunk.register_full_backward_hook(partial(self._activation_tarce_hook_backward, chunk_id))
 
     def _calc_tensor_memory(
         self, root_stat: SimpleMemState, named_tensors: Dict[str, torch.Tensor], require_grad: bool = False
@@ -554,48 +591,6 @@ class SimpleMemoryProfiler:
         self._calc_tensor_memory(root_stat, named_tensors)
 
 
-def build_activation_config(num_layers: int, num_chunks: int = 1) -> List[str]:
-    # TODO: support interleaved pipeline scheduling.
-    assert num_chunks == 1, "Only support num_chunks == 1"
-
-    if gpc.is_initialized(ParallelMode.PIPELINE):
-        pipeline_size = gpc.get_world_size(ParallelMode.PIPELINE)
-        pipeline_rank = gpc.get_local_rank(ParallelMode.PIPELINE)
-    else:
-        pipeline_size = 1
-        pipeline_rank = 0
-
-    all_parts = partition_uniform(num_layers, pipeline_size, num_chunks)
-    parts = all_parts[pipeline_rank]
-    start, end = parts[0]
-    num_blocks = end - start
-
-    block_conf_tmpl = [
-        "mixer.rotary_emb",
-        "mixer.Wqkv",
-        "mixer.inner_attn",
-        "mixer.inner_cross_attn",
-        "mixer.out_proj",
-        # "dropout1", # skip when dropout_selective_checkpoint is True
-        # "dropout2", # skip when dropout_selective_checkpoint is True
-        "norm1",
-        "norm2",
-        "mlp.w1",
-        "mlp.w2",
-        "mlp.w3",
-    ]
-
-    block_conf = []
-    for block_id in range(num_blocks):
-        block_conf += [f"blocks.{block_id}.{layer}" for layer in block_conf_tmpl]
-
-    # We don't need to care about whether the embedding, norm, and head layers exist in the model after partitioning.
-    # If they don't exist, they will be automatically ignored when registering activation trace hooks.
-    activation_conf = ["embedding", "norm", "head"] + block_conf
-
-    return activation_conf
-
-
 if __name__ == "__main__":
 
     class SimpleModel(torch.nn.Module):
@@ -635,32 +630,39 @@ if __name__ == "__main__":
 
             return output
 
+    def _simple_schedule(_num_chunks, _model_chunks, _input) -> torch.Tensor:
+        if _num_chunks > 1:
+            _output = _input
+            for _model_chunk in _model_chunks:
+                _output = _model_chunk(_output)
+        else:
+            _output = _model_chunks(_input)
+
+        return _output
+
+    # num_chunks config
+    _num_chunks = 1
+
     # init model and optimizer
-    _model: torch.nn.Module = SimpleModel()
+    if _num_chunks > 1:
+        _chunks = [SimpleModel(skip_layer2=idx % 2 == 0) for idx in range(_num_chunks)]
+        _model = torch.nn.ModuleList(_chunks).cuda()
+    else:
+        _model: torch.nn.Module = SimpleModel().cuda()
     _optimizer = torch.optim.Adam(_model.parameters())
 
-    # create activation config for simple model layer by layer.
-    activation_configs = [
-        # model level 0
-        "layer1",
-        "layer2",
-        "layer3",
-        # model level 1
-        "layer2.layer1",
-        "layer2.layer3",
-    ]
-
-    _model.modules()
-
     # init profiler
-    profiler = SimpleMemoryProfiler(_model, _optimizer, "./test_simple_memory_profiler.log", activation_configs)
+    profiler = SimpleMemoryProfiler(_model, _optimizer, "./test_simple_memory_profiler", total_steps=1)
 
     _optimizer.zero_grad()
 
-    x1 = torch.randn((128, 5120))
-    x2 = torch.randn((128, 5120))
-    out1 = _model(x1)
-    out2 = _model(x2)
+    # inputs
+    x1 = torch.randn((128, 5120)).cuda()
+    x2 = torch.randn((128, 5120)).cuda()
+    # forward
+    out1 = _simple_schedule(_num_chunks, _model, x1)
+    out2 = _simple_schedule(_num_chunks, _model, x2)
+    # backward
    out1.mean().backward()
    out2.mean().backward()
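The peak-activation accounting that replaces the old _activation_max_count heuristic is a running total with a high-water mark: each chunk's forward pass adds that chunk's base activation footprint, each backward pass releases it, and _activation_mem_max keeps the maximum, which the summary sunburst now reports. A small illustration with made-up byte counts (not code from this commit):

# Made-up numbers; mirrors the accounting in the forward/backward hooks above.
activation_mem = 0
activation_mem_max = 0
chunk_base_mem = {0: 512 * 1024, 1: 384 * 1024}  # base activation bytes per chunk

for kind, chunk_id in [("forward", 0), ("forward", 1), ("backward", 1), ("backward", 0)]:
    if kind == "forward":
        activation_mem += chunk_base_mem[chunk_id]
        activation_mem_max = max(activation_mem_max, activation_mem)
    else:
        activation_mem -= chunk_base_mem[chunk_id]

print(activation_mem_max)  # 917504 bytes: both chunks' activations alive at once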
train.py (8 changed lines)
@@ -55,10 +55,7 @@ from internlm.utils.parallel import (
     sync_model_param_within_tp,
 )
 from internlm.utils.registry import MODEL_INITIALIZER
-from internlm.utils.simple_memory_profiler import (
-    SimpleMemoryProfiler,
-    build_activation_config,
-)
+from internlm.utils.simple_memory_profiler import SimpleMemoryProfiler
 from internlm.utils.writer import Writer
 
 # global llm logger
@@ -556,12 +553,11 @@ def main(args):
     # initialize simple memory profiler
     if args.profiling:
         memory_profiler = SimpleMemoryProfiler(
-            model.model,
+            model,
             optimizer.optim,
             log_folder=f"memory_trace/rank{gpc.get_global_rank()}_"
             + f"dp{gpc.get_local_rank(ParallelMode.DATA)}_"
             + f"tp{gpc.get_local_rank(ParallelMode.TENSOR)}",
-            activation_config=build_activation_config(gpc.config.model.num_layers),
         )
     else:
         memory_profiler = None
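With these changes a caller passes the wrapped model straight through and drops the activation_config argument. A minimal sketch of the new call site, using a throwaway model and an illustrative log folder (train.py itself passes model and optimizer.optim from its training setup):

import torch

from internlm.utils.simple_memory_profiler import SimpleMemoryProfiler

_model = torch.nn.Sequential(torch.nn.Linear(8, 8), torch.nn.ReLU(), torch.nn.Linear(8, 8))
_optimizer = torch.optim.Adam(_model.parameters())

profiler = SimpleMemoryProfiler(
    _model,  # may also be a NaiveAMPModel or a ModuleList of pipeline chunks
    _optimizer,
    log_folder="./memory_trace_demo",  # illustrative path
    total_steps=5,
)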