From 74d6c71ad95f01bb65ac9f069d884a111b9017ce Mon Sep 17 00:00:00 2001
From: Qu Wenwen
Date: Tue, 17 Oct 2023 11:26:29 +0800
Subject: [PATCH] fix moe compatibility for fsdp and memory profiling

---
 internlm/initialize/launch.py            | 2 +-
 internlm/utils/simple_memory_profiler.py | 5 +++--
 2 files changed, 4 insertions(+), 3 deletions(-)

diff --git a/internlm/initialize/launch.py b/internlm/initialize/launch.py
index fead575..2087ae4 100644
--- a/internlm/initialize/launch.py
+++ b/internlm/initialize/launch.py
@@ -349,7 +349,7 @@ def args_sanity_check():
         assert (
             not optim_ckpt.overlap_sync_grad & optim_ckpt.overlap_sync_param
         ), "not support overlap and moe at the same time"
-        assert gpc.config.parallel.zero1 == -1, "moe only support zero1, set zero1=-1 can fix this"
+        assert gpc.config.parallel.zero1.size == -1, "moe only support zero1, set zero1=-1 can fix this"
 
 
 def launch(
diff --git a/internlm/utils/simple_memory_profiler.py b/internlm/utils/simple_memory_profiler.py
index 9caf0a2..8a688ed 100644
--- a/internlm/utils/simple_memory_profiler.py
+++ b/internlm/utils/simple_memory_profiler.py
@@ -424,7 +424,9 @@ class SimpleMemoryProfiler:
             layer_name, output.element_size() * output.nelement(), flush=False
         )
 
-    def _activation_trace_hook_forward(self, chunk_id: int, model: Any, inputs: Any, output: torch.Tensor) -> None:
+    def _activation_trace_hook_forward(
+        self, chunk_id: int, model: Any, inputs: Any, output: Any  # pylint: disable=W0613
+    ) -> None:
         """
         Hook function to trace the activation memory usage for a forward pass.
 
@@ -437,7 +439,6 @@ class SimpleMemoryProfiler:
             None
         """
         del model, inputs
-        assert isinstance(output, torch.Tensor), f"invalid output type: {type(output)}"
 
         if self._stoped:
             return
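
Note (not part of the patch): the second hunk drops the isinstance(output, torch.Tensor) assertion and retypes the hook's output as Any because an MoE block's forward() may return a tuple (for example, hidden states together with gating or auxiliary-loss values) rather than a bare tensor. Below is a minimal sketch of how a forward hook could size such nested outputs; the helper name _output_nbytes is hypothetical and is not the function used in simple_memory_profiler.py.

    from typing import Any

    import torch


    def _output_nbytes(output: Any) -> int:
        """Sum the byte size of every tensor in a possibly nested forward output."""
        if isinstance(output, torch.Tensor):
            return output.element_size() * output.nelement()
        if isinstance(output, (tuple, list)):
            return sum(_output_nbytes(item) for item in output)
        if isinstance(output, dict):
            return sum(_output_nbytes(item) for item in output.values())
        return 0  # None, scalars and other non-tensor values contribute nothing

Recursing over tuples, lists and dicts covers both plain transformer blocks and MoE blocks without reintroducing a hard type assertion in the hook.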