feat(model/overlap_handler.py): fix lint error

pull/456/head
huangting4201 2023-10-23 15:22:03 +08:00
parent f6a5086fe4
commit 0d693cf3a1
4 changed files with 26 additions and 19 deletions

View File

@@ -53,7 +53,6 @@ class MoE(torch.nn.Module):
         device=None,
         dtype=None,
     ):
         super().__init__()
         assert (

View File

@@ -10,7 +10,10 @@ from internlm.core.context import global_context as gpc
 from internlm.core.naive_amp import NaiveAMPModel
 from internlm.model.embedding import Embedding1D
 from internlm.model.linear import FSTPLinear, ScaleColumnParallelLinear
-from internlm.model.utils import all_gather_raw_memory_pool, all_gather_raw_bias_memory_pool
+from internlm.model.utils import (
+    all_gather_raw_bias_memory_pool,
+    all_gather_raw_memory_pool,
+)
 from internlm.utils.common import get_current_device
@@ -25,7 +28,7 @@ class FSTPOverlapHandler:
         self.fstp_modules = []
         self.module_name = ["Wqkv", "out_proj", "w1", "w2", "w3"]
         self.fstp_global_handle = dict()  # key: fstp module; value: module global all-gather op handle
         self.bias_global_handle = dict()  # key: fstp module; value: module bias global all-gather op handle
         self.module_to_index = dict()  # key: fstp module; value: transformer block index
         self.index_to_fstp_modules = dict()  # key: transformer block index; value: fsdp modules
         self.head = []
@@ -107,6 +110,10 @@ class FSTPOverlapHandler:
                 weight[name] = torch.zeros(self.module_shape[name], dtype=dtype, device=device).contiguous()
             self.all_gather_memory_pool.append(weight)  # containing two groups of block weight

+    def clear_memory_pool(self) -> None:
+        self.zero_const_pool = {}
+        self.reduce_scatter_memory_pool = {}
+
     def get_all_gather_memory(self, module):
         block_index = self.module_to_index[module]
         return self.all_gather_memory_pool[block_index % 2][module._fstp_name]
@@ -119,20 +126,21 @@
             for _ in range(2):
                 weight = {}
                 weight[module._fstp_name] = torch.zeros(
                     self.module_shape[module._fstp_name][0],
                     dtype=gpc.config.model.get("dtype", torch.half),
-                    device=get_current_device()).contiguous()
+                    device=get_current_device(),
+                ).contiguous()
                 self.all_gather_bias_memory_pool.append(weight)
         elif module._fstp_name not in self.all_gather_bias_memory_pool[0]:
             for i in range(2):
                 self.all_gather_bias_memory_pool[i][module._fstp_name] = torch.zeros(
                     self.module_shape[module._fstp_name][0],
                     dtype=gpc.config.model.get("dtype", torch.half),
-                    device=get_current_device()).contiguous()
+                    device=get_current_device(),
+                ).contiguous()
         return self.all_gather_bias_memory_pool[block_index % 2][module._fstp_name]

     def get_reduce_scatter_memory(self, key):
         return_idx = 0
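Note: as a rough illustration of the pattern this file implements, the sketch below shows a stripped-down double-buffered all-gather pool together with the clear_memory_pool() helper added in this commit. The class name SimpleOverlapPool, the register_module helper, and keying buffers by module name instead of module object are illustrative simplifications, not part of the actual FSTPOverlapHandler API.

# Minimal sketch, assuming a two-slot (ping-pong) pool indexed by block_index % 2.
import torch


class SimpleOverlapPool:
    def __init__(self, device=None, dtype=torch.half):
        self.device = device or torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.dtype = dtype
        # Two slots so block i can reuse the buffer of block i - 2 while the
        # gathered weights of block i - 1 are still being consumed.
        self.all_gather_memory_pool = [{}, {}]
        self.zero_const_pool = {}
        self.reduce_scatter_memory_pool = {}
        self.module_to_index = {}
        self.module_shape = {}

    def register_module(self, name, shape, block_index):
        # Remember which transformer block this module belongs to and make sure
        # both pool slots hold a persistent, contiguous buffer for it.
        self.module_to_index[name] = block_index
        self.module_shape[name] = shape
        for slot in self.all_gather_memory_pool:
            if name not in slot:
                slot[name] = torch.zeros(shape, dtype=self.dtype, device=self.device).contiguous()

    def get_all_gather_memory(self, name):
        # Alternate between the two slots based on the owning block's index.
        block_index = self.module_to_index[name]
        return self.all_gather_memory_pool[block_index % 2][name]

    def clear_memory_pool(self) -> None:
        # Same idea as the helper added in the diff: drop the per-step caches
        # so they are rebuilt lazily on the next training iteration.
        self.zero_const_pool = {}
        self.reduce_scatter_memory_pool = {}


# Usage (illustrative):
# pool = SimpleOverlapPool(dtype=torch.float32)
# pool.register_module("Wqkv", (4096, 1024), block_index=3)
# buf = pool.get_all_gather_memory("Wqkv")   # slot 3 % 2 == 1
# pool.clear_memory_pool()                   # once per training step

The two-slot pool is what keeps memory usage flat while still letting all-gather communication for the next block overlap with the computation of the current one, without allocating fresh tensors every iteration.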

View File

@@ -140,6 +140,7 @@ def all_gather_raw_memory_pool(
     )
     return handle

 def all_gather_raw_bias_memory_pool(
     input_: Tensor,
     process_group: ProcessGroup,
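For context, the sketch below shows one way an all-gather-into-pool helper in the spirit of all_gather_raw_bias_memory_pool can be structured: gather directly into a preallocated buffer and hand back the async work handle. Only input_ and process_group appear in the diff above; the output_buffer and async_op parameters and the use of torch.distributed.all_gather_into_tensor are assumptions for illustration, not the function's actual signature.

# Hedged sketch: async all-gather into a preallocated pool buffer.
# output_buffer must already be sized to hold world_size * input_.numel()
# elements (e.g. a tensor handed out by the handler's memory pool).
import torch
import torch.distributed as dist


def all_gather_into_pool_sketch(input_: torch.Tensor, output_buffer: torch.Tensor, process_group, async_op: bool = True):
    # Writing into output_buffer avoids allocating a fresh gathered tensor
    # every step; with async_op=True the collective returns a work handle the
    # caller can wait() on right before the gathered parameter is consumed,
    # which is what lets the communication overlap with computation.
    handle = dist.all_gather_into_tensor(output_buffer, input_.contiguous(), group=process_group, async_op=async_op)
    return output_buffer, handle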

View File

@@ -298,8 +298,7 @@ def main(args):
            prof.step()
        if gpc.fstp_handler is not None:
-            gpc.fstp_handler.zero_const_pool = {}
-            gpc.fstp_handler.reduce_scatter_memory_pool = {}
+            gpc.fstp_handler.clear_memory_pool()
        # torch.cuda.memory._dump_snapshot(f"my_snapshot_{gpc.get_global_rank()}.pickle")
        torch.cuda.reset_peak_memory_stats()
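Finally, a simplified view of where the consolidated cleanup call sits in a training loop like the one patched here; the loop body is stubbed out and run_one_step, num_steps, and the prof argument are placeholders, so this is a sketch of the call placement rather than the actual train.py.

import torch


def train_loop_sketch(run_one_step, num_steps, fstp_handler=None, prof=None):
    # run_one_step is a stand-in for the real forward/backward/optimizer step.
    for _ in range(num_steps):
        run_one_step()
        if prof is not None:
            prof.step()
        if fstp_handler is not None:
            # One call now replaces the two inline dict resets
            # (zero_const_pool and reduce_scatter_memory_pool).
            fstp_handler.clear_memory_pool()
        if torch.cuda.is_available():
            torch.cuda.reset_peak_memory_stats()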