mirror of https://github.com/InternLM/InternLM
feat(model/overlap_handler.py): fix lint error
parent: f6a5086fe4
commit: 0d693cf3a1
In the MoE constructor, a superfluous blank line after the signature is removed:

@@ -53,7 +53,6 @@ class MoE(torch.nn.Module):
         device=None,
         dtype=None,
     ):
-
        super().__init__()

        assert (
model/overlap_handler.py: the over-long one-line import from internlm.model.utils is split into a parenthesized, sorted multi-line import:

@@ -10,7 +10,10 @@ from internlm.core.context import global_context as gpc
 from internlm.core.naive_amp import NaiveAMPModel
 from internlm.model.embedding import Embedding1D
 from internlm.model.linear import FSTPLinear, ScaleColumnParallelLinear
-from internlm.model.utils import all_gather_raw_memory_pool, all_gather_raw_bias_memory_pool
+from internlm.model.utils import (
+    all_gather_raw_bias_memory_pool,
+    all_gather_raw_memory_pool,
+)
 from internlm.utils.common import get_current_device
A new clear_memory_pool helper is added to FSTPOverlapHandler, centralizing the per-step reset of the lazily grown caches:

@@ -107,6 +110,10 @@ class FSTPOverlapHandler:
             weight[name] = torch.zeros(self.module_shape[name], dtype=dtype, device=device).contiguous()
         self.all_gather_memory_pool.append(weight)  # containing two groups of block weight

+    def clear_memory_pool(self) -> None:
+        self.zero_const_pool = {}
+        self.reduce_scatter_memory_pool = {}
+
     def get_all_gather_memory(self, module):
         block_index = self.module_to_index[module]
         return self.all_gather_memory_pool[block_index % 2][module._fstp_name]
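For orientation, a much-simplified, hypothetical sketch of the double-buffering idea behind these pools follows (ToyOverlapHandler, its constructor arguments, and the string keys are invented for illustration; the real handler also tracks modules, process groups, and launch order):

import torch

class ToyOverlapHandler:
    """Simplified sketch: two alternating buffer groups let block i's
    all-gather overlap with block i+1's compute without clobbering
    a buffer that is still in use."""

    def __init__(self, module_shapes, dtype=torch.half, device="cpu"):
        self.module_shape = module_shapes        # name -> torch.Size
        self.zero_const_pool = {}                # lazily cached zero tensors
        self.reduce_scatter_memory_pool = {}     # lazily cached comm buffers
        self.all_gather_memory_pool = []         # two groups of block weights
        for _ in range(2):
            group = {
                name: torch.zeros(shape, dtype=dtype, device=device).contiguous()
                for name, shape in module_shapes.items()
            }
            self.all_gather_memory_pool.append(group)

    def get_all_gather_memory(self, name, block_index):
        # Even blocks read group 0, odd blocks group 1 (block_index % 2),
        # mirroring the accessor in the diff above.
        return self.all_gather_memory_pool[block_index % 2][name]

    def clear_memory_pool(self):
        # Same shape as the new helper: only the lazily grown pools are
        # dropped; the fixed all-gather double buffer is kept for reuse.
        self.zero_const_pool = {}
        self.reduce_scatter_memory_pool = {}

Two groups suffice if at most two blocks are ever in flight at once: the one currently computing and the one being prefetched.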
In the bias-pool accessor, the closing parenthesis of each torch.zeros(...) call moves onto its own line (with a trailing comma), and one of the two consecutive blank lines before get_reduce_scatter_memory is dropped:

@@ -121,18 +128,19 @@ class FSTPOverlapHandler:
                weight[module._fstp_name] = torch.zeros(
                    self.module_shape[module._fstp_name][0],
                    dtype=gpc.config.model.get("dtype", torch.half),
-                    device=get_current_device()).contiguous()
+                    device=get_current_device(),
+                ).contiguous()
                self.all_gather_bias_memory_pool.append(weight)
            elif module._fstp_name not in self.all_gather_bias_memory_pool[0]:
                for i in range(2):
                    self.all_gather_bias_memory_pool[i][module._fstp_name] = torch.zeros(
                        self.module_shape[module._fstp_name][0],
                        dtype=gpc.config.model.get("dtype", torch.half),
-                        device=get_current_device()).contiguous()
+                        device=get_current_device(),
+                    ).contiguous()

        return self.all_gather_bias_memory_pool[block_index % 2][module._fstp_name]

-
    def get_reduce_scatter_memory(self, key):
        return_idx = 0
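Stripped of diff noise, the allocation pattern here is lazy and double-buffered. A hypothetical sketch (get_bias_buffer and its parameters are invented; the real method pulls shapes, dtype, and device from the handler and the gpc config):

import torch

def get_bias_buffer(pool, shapes, name, block_index, dtype=torch.half, device="cpu"):
    # Hypothetical sketch: buffers appear only the first time a module's
    # bias is requested, and always in both groups so the % 2 rotation works.
    if len(pool) == 0:
        for _ in range(2):                     # first use ever: create both groups
            pool.append({name: torch.zeros(shapes[name][0], dtype=dtype,
                                           device=device).contiguous()})
    elif name not in pool[0]:
        for i in range(2):                     # first use of this module: grow both
            pool[i][name] = torch.zeros(shapes[name][0], dtype=dtype,
                                        device=device).contiguous()
    return pool[block_index % 2][name]

pool, shapes = [], {"mlp.w1": torch.Size([128])}
assert get_bias_buffer(pool, shapes, "mlp.w1", 0) is not get_bias_buffer(pool, shapes, "mlp.w1", 1)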
internlm/model/utils.py: a second blank line is inserted between the two top-level helpers, as the style checker requires:

@@ -140,6 +140,7 @@ def all_gather_raw_memory_pool(
    )
    return handle

+
def all_gather_raw_bias_memory_pool(
    input_: Tensor,
    process_group: ProcessGroup,
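These helpers are only partially visible here. As a rough, hypothetical sketch of the pattern their names suggest, assuming an initialized torch.distributed process group (all_gather_into_pool is an invented name):

import torch
import torch.distributed as dist

def all_gather_into_pool(input_, output, process_group, async_op=True):
    """Hypothetical sketch: gather every rank's shard into a preallocated
    pool buffer `output` and return the async handle so compute can
    overlap with communication."""
    handle = dist.all_gather_into_tensor(
        output, input_.contiguous(), group=process_group, async_op=async_op
    )
    return handle  # caller does handle.wait() right before reading `output`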
train.py: the two per-step pool resets in the training loop collapse into a single call to the new helper:

@@ -298,8 +298,7 @@ def main(args):
                prof.step()

        if gpc.fstp_handler is not None:
-            gpc.fstp_handler.zero_const_pool = {}
-            gpc.fstp_handler.reduce_scatter_memory_pool = {}
+            gpc.fstp_handler.clear_memory_pool()
        # torch.cuda.memory._dump_snapshot(f"my_snapshot_{gpc.get_global_rank()}.pickle")
        torch.cuda.reset_peak_memory_stats()
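Seen from the loop, the reset runs once per step. Continuing the ToyOverlapHandler sketch from above, with the loop body standing in for InternLM's real train step:

handler = ToyOverlapHandler({"mlp.w1": torch.Size([4, 4])})
for step in range(3):
    # ... forward/backward would run here, filling the per-step caches ...
    handler.zero_const_pool[(4, torch.half)] = torch.zeros(4, dtype=torch.half)
    handler.clear_memory_pool()             # end-of-step reset, as in train.py
    assert handler.zero_const_pool == {}    # caches start empty next step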