mirror of https://github.com/InternLM/InternLM
feat(model/overlap_handler.py): fix lint error
parent f6a5086fe4
commit 0d693cf3a1
@@ -53,7 +53,6 @@ class MoE(torch.nn.Module):
         device=None,
         dtype=None,
     ):
 
         super().__init__()
 
         assert (
internlm/model/overlap_handler.py

@@ -10,7 +10,10 @@ from internlm.core.context import global_context as gpc
 from internlm.core.naive_amp import NaiveAMPModel
 from internlm.model.embedding import Embedding1D
 from internlm.model.linear import FSTPLinear, ScaleColumnParallelLinear
-from internlm.model.utils import all_gather_raw_memory_pool, all_gather_raw_bias_memory_pool
+from internlm.model.utils import (
+    all_gather_raw_bias_memory_pool,
+    all_gather_raw_memory_pool,
+)
 from internlm.utils.common import get_current_device
 
 
@@ -25,7 +28,7 @@ class FSTPOverlapHandler:
         self.fstp_modules = []
         self.module_name = ["Wqkv", "out_proj", "w1", "w2", "w3"]
         self.fstp_global_handle = dict()  # key: fstp module; value: module global all-gather op handle
         self.bias_global_handle = dict()  # key: fstp module; value: module bias global all-gather op handle
         self.module_to_index = dict()  # key: fstp module; value: transformer block index
         self.index_to_fstp_modules = dict()  # key: transformer block index; value: fstp modules
         self.head = []
@@ -77,13 +80,13 @@ class FSTPOverlapHandler:
             self.zero_const_pool[size] = torch.zeros(*size, dtype=dtype, device=device).contiguous()
 
         return self.zero_const_pool[size]
 
     def _initialize_module_shape(self):
         hidden_size = gpc.config.HIDDEN_SIZE
         mlp_ratio = gpc.config.MLP_RATIO
         mlp_hidden_size = int(hidden_size * mlp_ratio)
         mlp_hidden_size = 256 * ((mlp_hidden_size + 256 - 1) // 256)
 
         self.module_shape["Wqkv"] = (3 * hidden_size, hidden_size)
         self.module_shape["out_proj"] = (hidden_size, hidden_size)
         self.module_shape["w1"] = (mlp_hidden_size, hidden_size)
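As a quick sanity check of the rounding in _initialize_module_shape above, here is a small sketch with assumed example values (hidden_size=4096 and mlp_ratio=8/3 are illustrative, not taken from this commit):

# Sketch only: hidden_size and mlp_ratio below are assumed example values.
hidden_size = 4096
mlp_ratio = 8 / 3

mlp_hidden_size = int(hidden_size * mlp_ratio)                # 10922
mlp_hidden_size = 256 * ((mlp_hidden_size + 256 - 1) // 256)  # round up to a multiple of 256 -> 11008

# The resulting weight shapes mirror the ones registered above.
print({"Wqkv": (3 * hidden_size, hidden_size), "w1": (mlp_hidden_size, hidden_size)})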
@@ -96,7 +99,7 @@ class FSTPOverlapHandler:
         self.all_gather_bias_memory_pool = []
         self.reduce_scatter_memory_pool = {}
         self.module_shape = {}
 
         self._initialize_module_shape()
         dtype = gpc.config.model.get("dtype", torch.half)
         device = get_current_device()
@@ -107,10 +110,14 @@ class FSTPOverlapHandler:
                 weight[name] = torch.zeros(self.module_shape[name], dtype=dtype, device=device).contiguous()
             self.all_gather_memory_pool.append(weight)  # containing two groups of block weight
 
+    def clear_memory_pool(self) -> None:
+        self.zero_const_pool = {}
+        self.reduce_scatter_memory_pool = {}
+
     def get_all_gather_memory(self, module):
         block_index = self.module_to_index[module]
         return self.all_gather_memory_pool[block_index % 2][module._fstp_name]
 
     def get_bias_memory(self, module: nn.Module):
         block_index = self.module_to_index[module]
         # if the bias memory pool is empty or the module has not been allocated memory
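The pool lookups above index with block_index % 2, i.e. two pre-allocated buffer groups are reused by alternating transformer blocks. A minimal standalone sketch of that double-buffering idea (plain dicts instead of the handler's real bookkeeping; names and shapes are illustrative):

import torch

# Sketch: two reusable buffer groups, indexed by block_index % 2.
module_shape = {"out_proj": (8, 8)}  # assumed toy shape
pool = [
    {name: torch.zeros(shape) for name, shape in module_shape.items()}
    for _ in range(2)
]

def get_all_gather_buffer(block_index: int, name: str) -> torch.Tensor:
    # Even blocks share slot 0, odd blocks share slot 1, so memory stays bounded.
    return pool[block_index % 2][name]

assert get_all_gather_buffer(0, "out_proj") is get_all_gather_buffer(2, "out_proj")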
@@ -119,19 +126,20 @@ class FSTPOverlapHandler:
             for _ in range(2):
                 weight = {}
                 weight[module._fstp_name] = torch.zeros(
                     self.module_shape[module._fstp_name][0],
                     dtype=gpc.config.model.get("dtype", torch.half),
-                    device=get_current_device()).contiguous()
+                    device=get_current_device(),
+                ).contiguous()
                 self.all_gather_bias_memory_pool.append(weight)
         elif module._fstp_name not in self.all_gather_bias_memory_pool[0]:
             for i in range(2):
                 self.all_gather_bias_memory_pool[i][module._fstp_name] = torch.zeros(
                     self.module_shape[module._fstp_name][0],
                     dtype=gpc.config.model.get("dtype", torch.half),
-                    device=get_current_device()).contiguous()
+                    device=get_current_device(),
+                ).contiguous()
 
         return self.all_gather_bias_memory_pool[block_index % 2][module._fstp_name]
 
-
     def get_reduce_scatter_memory(self, key):
         return_idx = 0
@@ -170,7 +178,7 @@ class FSTPOverlapHandler:
 
     def release_reduce_scatter_memory(self, key, index):
         self.reduce_scatter_memory_pool[key][index].idle = True
 
     def _all_gather_block_weight_memory_pool(self, block_index: int):
         fstp_modules = self.index_to_fstp_modules[block_index]
         for module in fstp_modules:

@@ -182,7 +190,7 @@ class FSTPOverlapHandler:
                     module=module,
                 )
                 self.bias_global_handle[module] = bias_handle
 
             weight_handle = all_gather_raw_memory_pool(
                 module.weight,
                 self.process_group,
internlm/model/utils.py

@@ -140,6 +140,7 @@ def all_gather_raw_memory_pool(
     )
     return handle
 
+
 def all_gather_raw_bias_memory_pool(
     input_: Tensor,
     process_group: ProcessGroup,
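For context, all_gather_raw_memory_pool and all_gather_raw_bias_memory_pool gather a sharded parameter into a pre-allocated pool buffer and return the communication handle so the caller can overlap the gather with compute. A rough sketch of that pattern using plain torch.distributed (not the repo's actual implementation; buffer sizing and management are simplified here):

import torch
import torch.distributed as dist

def all_gather_into_pool(input_: torch.Tensor, output_buf: torch.Tensor, process_group, async_op: bool = True):
    # Gather the local shard from every rank into the pre-allocated output buffer.
    handle = dist.all_gather_into_tensor(output_buf, input_.contiguous(), group=process_group, async_op=async_op)
    return handle  # caller waits on the handle right before the full weight is needed

# Typical overlap usage (sketch):
#   handle = all_gather_into_pool(module.weight, pool_buf, pg)
#   ... launch other compute ...
#   handle.wait()  # full weight is now in pool_buf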
train.py
@@ -298,8 +298,7 @@ def main(args):
             prof.step()
 
         if gpc.fstp_handler is not None:
-            gpc.fstp_handler.zero_const_pool = {}
-            gpc.fstp_handler.reduce_scatter_memory_pool = {}
+            gpc.fstp_handler.clear_memory_pool()
         # torch.cuda.memory._dump_snapshot(f"my_snapshot_{gpc.get_global_rank()}.pickle")
         torch.cuda.reset_peak_memory_stats()
 
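The train.py change swaps the two inline dict resets for the new handler method. A toy, self-contained illustration of that consolidation (the _ToyHandler class is invented for the example, not part of the repo):

class _ToyHandler:
    def __init__(self):
        self.zero_const_pool = {}
        self.reduce_scatter_memory_pool = {}

    def clear_memory_pool(self) -> None:
        # Equivalent to the two manual resets train.py used to do inline each step.
        self.zero_const_pool = {}
        self.reduce_scatter_memory_pool = {}

handler = _ToyHandler()
for _step in range(3):
    handler.zero_const_pool[(1, 1)] = "dummy"
    handler.clear_memory_pool()  # single call replaces the two inline resets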