mirror of https://github.com/InternLM/InternLM
fix moe compatibility for fsdp and memory profiling
parent b3645b0244
commit 74d6c71ad9
@@ -349,7 +349,7 @@ def args_sanity_check():
         assert (
             not optim_ckpt.overlap_sync_grad & optim_ckpt.overlap_sync_param
         ), "not support overlap and moe at the same time"
-        assert gpc.config.parallel.zero1 == -1, "moe only support zero1, set zero1=-1 can fix this"
+        assert gpc.config.parallel.zero1.size == -1, "moe only support zero1, set zero1=-1 can fix this"


 def launch(
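The first hunk only changes how the zero1 setting is read: the MoE sanity check now inspects gpc.config.parallel.zero1.size instead of treating zero1 as a bare integer. A minimal sketch of the config shape this implies, using SimpleNamespace as a stand-in for gpc.config.parallel; the fsdp flag is an assumption suggested by the commit title, not shown in this diff:

from types import SimpleNamespace

# Hypothetical stand-in for gpc.config.parallel after this change: zero1 is no
# longer a bare integer but an object with a .size attribute (the fsdp flag is
# an assumption based on the commit title, not part of the diff).
parallel = SimpleNamespace(
    zero1=SimpleNamespace(size=-1, fsdp=False),
)

# The sanity check now reads the nested field instead of the raw value:
assert parallel.zero1.size == -1, "moe only support zero1, set zero1=-1 can fix this"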
@@ -424,7 +424,9 @@ class SimpleMemoryProfiler:
                 layer_name, output.element_size() * output.nelement(), flush=False
             )

-    def _activation_trace_hook_forward(self, chunk_id: int, model: Any, inputs: Any, output: torch.Tensor) -> None:
+    def _activation_trace_hook_forward(
+        self, chunk_id: int, model: Any, inputs: Any, output: Any  # pylint: disable=W0613
+    ) -> None:
         """
         Hook function to trace the activation memory usage for a forward pass.

@@ -437,7 +439,6 @@ class SimpleMemoryProfiler:
             None
         """
         del model, inputs
-        assert isinstance(output, torch.Tensor), f"invalid output type: {type(output)}"

         if self._stoped:
             return
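The profiler hunks loosen _activation_trace_hook_forward to accept any forward output and drop the tensor-only assertion, since an MoE block's forward may not return a single tensor. A hedged sketch of how such an output could be sized, assuming a tuple like (hidden_states, moe_loss); the helper name and recursion are illustrative and not the profiler's actual code:

import torch

def activation_nbytes(output) -> int:
    # Illustrative helper (not from the diff): size an output that may be a
    # tensor or a (possibly nested) tuple/list of tensors, as MoE layers can
    # return e.g. (hidden_states, moe_loss) instead of a single tensor.
    if isinstance(output, torch.Tensor):
        return output.element_size() * output.nelement()
    if isinstance(output, (tuple, list)):
        return sum(activation_nbytes(o) for o in output)
    return 0  # non-tensor parts contribute nothing

# Example: a tuple output such as an MoE block might produce.
hidden = torch.randn(2, 4, 8)
moe_loss = torch.tensor(0.1)
print(activation_nbytes((hidden, moe_loss)))  # 2*4*8*4 + 1*4 = 260 bytes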