mirror of https://github.com/InternLM/InternLM
support bias
parent: e7f9f1d208
commit: f6a5086fe4
@@ -10,7 +10,7 @@ from internlm.core.context import global_context as gpc
 from internlm.core.naive_amp import NaiveAMPModel
 from internlm.model.embedding import Embedding1D
 from internlm.model.linear import FSTPLinear, ScaleColumnParallelLinear
-from internlm.model.utils import all_gather_raw_memory_pool
+from internlm.model.utils import all_gather_raw_memory_pool, all_gather_raw_bias_memory_pool
 from internlm.utils.common import get_current_device


@@ -25,6 +25,7 @@ class FSTPOverlapHandler:
         self.fstp_modules = []
         self.module_name = ["Wqkv", "out_proj", "w1", "w2", "w3"]
         self.fstp_global_handle = dict()  # key: fstp module; value: module global all-gather op handle
+        self.bias_global_handle = dict()  # key: fstp module; value: module bias global all-gather op handle
         self.module_to_index = dict()  # key: fstp module; value: transformer block index
         self.index_to_fstp_modules = dict()  # key: transformer block index; value: fstp modules
         self.head = []
@@ -76,49 +77,61 @@ class FSTPOverlapHandler:
             self.zero_const_pool[size] = torch.zeros(*size, dtype=dtype, device=device).contiguous()

         return self.zero_const_pool[size]

-    def _initialize_memory_pool(self) -> None:
-        # allocate memory pool
+    def _initialize_module_shape(self):
         hidden_size = gpc.config.HIDDEN_SIZE
         mlp_ratio = gpc.config.MLP_RATIO
         mlp_hidden_size = int(hidden_size * mlp_ratio)
         mlp_hidden_size = 256 * ((mlp_hidden_size + 256 - 1) // 256)

+        self.module_shape["Wqkv"] = (3 * hidden_size, hidden_size)
+        self.module_shape["out_proj"] = (hidden_size, hidden_size)
+        self.module_shape["w1"] = (mlp_hidden_size, hidden_size)
+        self.module_shape["w2"] = (mlp_hidden_size, hidden_size)
+        self.module_shape["w3"] = (hidden_size, mlp_hidden_size)
+
+    def _initialize_memory_pool(self) -> None:
+        # allocate memory pool
         self.all_gather_memory_pool = []
+        self.all_gather_bias_memory_pool = []
         self.reduce_scatter_memory_pool = {}
+        self.module_shape = {}
+
+        self._initialize_module_shape()
+        dtype = gpc.config.model.get("dtype", torch.half)
+        device = get_current_device()

         for _ in range(2):
             weight = {}
             for name in self.module_name:
-                if name == "Wqkv":
-                    weight[name] = torch.zeros(
-                        (3 * hidden_size, hidden_size),
-                        dtype=gpc.config.model.get("dtype", torch.half),
-                        device=get_current_device(),
-                    ).contiguous()
-                elif name == "out_proj":
-                    weight[name] = torch.zeros(
-                        (hidden_size, hidden_size),
-                        dtype=gpc.config.model.get("dtype", torch.half),
-                        device=get_current_device(),
-                    ).contiguous()
-                elif name == "w1" or name == "w2":
-                    weight[name] = torch.zeros(
-                        (mlp_hidden_size, hidden_size),
-                        dtype=gpc.config.model.get("dtype", torch.half),
-                        device=get_current_device(),
-                    ).contiguous()
-                else:
-                    weight[name] = torch.zeros(
-                        (hidden_size, mlp_hidden_size),
-                        dtype=gpc.config.model.get("dtype", torch.half),
-                        device=get_current_device(),
-                    ).contiguous()
-
+                weight[name] = torch.zeros(self.module_shape[name], dtype=dtype, device=device).contiguous()
             self.all_gather_memory_pool.append(weight)  # containing two groups of block weight

     def get_all_gather_memory(self, module):
         block_index = self.module_to_index[module]
         return self.all_gather_memory_pool[block_index % 2][module._fstp_name]

+    def get_bias_memory(self, module: nn.Module):
+        block_index = self.module_to_index[module]
+        # if the bias memory pool is empty or the module has not been allocated memory yet
+        # import pdb; pdb.set_trace()
+        if len(self.all_gather_bias_memory_pool) == 0:
+            for _ in range(2):
+                weight = {}
+                weight[module._fstp_name] = torch.zeros(
+                    self.module_shape[module._fstp_name][0],
+                    dtype=gpc.config.model.get("dtype", torch.half),
+                    device=get_current_device()).contiguous()
+                self.all_gather_bias_memory_pool.append(weight)
+        elif module._fstp_name not in self.all_gather_bias_memory_pool[0]:
+            for i in range(2):
+                self.all_gather_bias_memory_pool[i][module._fstp_name] = torch.zeros(
+                    self.module_shape[module._fstp_name][0],
+                    dtype=gpc.config.model.get("dtype", torch.half),
+                    device=get_current_device()).contiguous()
+
+        return self.all_gather_bias_memory_pool[block_index % 2][module._fstp_name]
+
     def get_reduce_scatter_memory(self, key):
         return_idx = 0
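For readers skimming the hunk above: the commit replaces the per-name torch.zeros branches with a module_shape lookup and adds a lazily allocated bias pool next to the existing double-buffered weight pool. Below is a minimal, self-contained sketch of that double-buffering idea (two pre-allocated buffer groups indexed by block_index % 2, bias buffers created on first use). The names DoubleBufferPool, get_weight_buffer and get_bias_buffer are illustrative stand-ins, not InternLM APIs.

import torch

# Illustrative sketch (not InternLM code): two groups of pre-allocated buffers
# are reused by alternating transformer blocks, so block i writes into group
# i % 2 while group (i + 1) % 2 still holds the previous block's gathered data.
class DoubleBufferPool:
    def __init__(self, module_shape, dtype=torch.float16, device="cpu"):
        self.module_shape = module_shape
        self.dtype = dtype
        self.device = device
        # Eagerly allocate two groups of full (unsharded) weight buffers.
        self.weight_pool = [
            {name: torch.zeros(shape, dtype=dtype, device=device) for name, shape in module_shape.items()}
            for _ in range(2)
        ]
        # Bias buffers are allocated lazily, only for modules that actually have a bias.
        self.bias_pool = [{}, {}]

    def get_weight_buffer(self, name, block_index):
        return self.weight_pool[block_index % 2][name]

    def get_bias_buffer(self, name, block_index):
        for group in self.bias_pool:
            if name not in group:
                # The bias length is the first dim of the weight shape (out_features).
                group[name] = torch.zeros(self.module_shape[name][0], dtype=self.dtype, device=self.device)
        return self.bias_pool[block_index % 2][name]


if __name__ == "__main__":
    hidden = 1024
    mlp_hidden = 256 * ((int(hidden * 8 / 3) + 255) // 256)  # same rounding rule as in the hunk
    shapes = {"Wqkv": (3 * hidden, hidden), "out_proj": (hidden, hidden), "w1": (mlp_hidden, hidden)}
    pool = DoubleBufferPool(shapes)
    # Blocks 0 and 2 share one buffer group; block 1 uses the other.
    assert pool.get_weight_buffer("Wqkv", 0) is pool.get_weight_buffer("Wqkv", 2)
    assert pool.get_weight_buffer("Wqkv", 0) is not pool.get_weight_buffer("Wqkv", 1)
    print(pool.get_bias_buffer("w1", 1).shape)  # torch.Size([2816])

Two groups suffice because the handler prefetches at most one block ahead of the block currently computing, so only two blocks' gathered tensors need to be live at any time.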
@@ -157,10 +170,19 @@

     def release_reduce_scatter_memory(self, key, index):
         self.reduce_scatter_memory_pool[key][index].idle = True

     def _all_gather_block_weight_memory_pool(self, block_index: int):
         fstp_modules = self.index_to_fstp_modules[block_index]
         for module in fstp_modules:
+            if module.bias is not None:
+                bias_handle = all_gather_raw_bias_memory_pool(
+                    module.bias,
+                    self.process_group,
+                    async_op=True,
+                    module=module,
+                )
+                self.bias_global_handle[module] = bias_handle
+
             weight_handle = all_gather_raw_memory_pool(
                 module.weight,
                 self.process_group,
@@ -186,6 +208,9 @@
         def _pre_forward_hook_for_module(module: nn.Module, inputs: Any):
             handle = self.fstp_global_handle[module]
             handle.wait()
+            if module.bias is not None:
+                bias_handle = self.bias_global_handle[module]
+                bias_handle.wait()

         def _post_forward_hook_for_module(module: nn.Module, inputs: Any, output: Any):
             if module in self.fstp_global_handle:

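The two hunks above implement the overlap pattern: _all_gather_block_weight_memory_pool launches asynchronous all-gathers for a block's weights and, after this commit, its biases, stashing the returned handles; the pre-forward hook then waits on those handles right before the module runs. The sketch below reproduces that launch-then-wait flow in a single-process gloo group so it can run standalone. It uses the plain torch.distributed.all_gather list API instead of the pooled all_gather_into_tensor path, and everything except the PyTorch calls (the prefetch helper, the handles dict, the address and port) is illustrative.

import torch
import torch.distributed as dist
from torch import nn

# Single-process group so the sketch runs without a launcher; the address and
# port are arbitrary example values.
dist.init_process_group(backend="gloo", init_method="tcp://127.0.0.1:29501", rank=0, world_size=1)
world_size = dist.get_world_size()

module = nn.Linear(8, 8)
handles = {}   # module -> async work handle (plays the role of fstp_global_handle / bias_global_handle)
gathered = {}  # module -> list of gathered weight shards

def prefetch(mod):
    # Launch the all-gather asynchronously and remember the handle, as
    # _all_gather_block_weight_memory_pool does for a whole block.
    shards = [torch.empty_like(mod.weight) for _ in range(world_size)]
    work = dist.all_gather(shards, mod.weight.detach(), async_op=True)
    gathered[mod] = shards
    handles[mod] = work

def pre_forward_hook(mod, inputs):
    # Block only when the gathered tensor is actually needed, mirroring
    # _pre_forward_hook_for_module in the hunk above.
    handles.pop(mod).wait()

module.register_forward_pre_hook(pre_forward_hook)

prefetch(module)                 # communication starts here ...
out = module(torch.randn(2, 8))  # ... and is awaited just before the compute
print(out.shape, len(gathered[module]))
dist.destroy_process_group()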
@@ -140,6 +140,21 @@ def all_gather_raw_memory_pool(
     )
     return handle

+def all_gather_raw_bias_memory_pool(
+    input_: Tensor,
+    process_group: ProcessGroup,
+    async_op: bool = False,
+    gather_dim: int = 0,
+    module: nn.Module = None,
+):
+    handle = torch.distributed.all_gather_into_tensor(
+        gpc.fstp_handler.get_bias_memory(module=module),
+        input_.contiguous(),
+        group=process_group,
+        async_op=async_op,
+    )
+    return handle
+

 def linear_bias_wgrad_torch(my_input, grad_output, has_d_bias):
     assert my_input.dtype == grad_output.dtype
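all_gather_raw_bias_memory_pool mirrors all_gather_raw_memory_pool but gathers into the 1-D bias buffer returned by get_bias_memory, whose length is the first dimension of the module's weight shape, i.e. the full output-feature count. A quick arithmetic sketch of the sizes involved, under an assumed example config (hidden_size 4096, MLP_RATIO 8/3, tensor-parallel size 8; none of these values come from the diff) and assuming each bias is sharded along the output dimension like its weight:

# Assumed example values, chosen only for illustration.
hidden_size = 4096
mlp_ratio = 8 / 3
tp_size = 8

mlp_hidden_size = int(hidden_size * mlp_ratio)
mlp_hidden_size = 256 * ((mlp_hidden_size + 256 - 1) // 256)  # 10922 rounded up to 11008

module_shape = {
    "Wqkv": (3 * hidden_size, hidden_size),
    "out_proj": (hidden_size, hidden_size),
    "w1": (mlp_hidden_size, hidden_size),
    "w2": (mlp_hidden_size, hidden_size),
    "w3": (hidden_size, mlp_hidden_size),
}
for name, shape in module_shape.items():
    full_bias = shape[0]  # length of the pooled, gathered bias buffer
    print(f"{name}: local bias shard {full_bias // tp_size}, gathered bias {full_bias}")
# e.g. Wqkv: local bias shard 1536, gathered bias 12288
#      w1:   local bias shard 1376, gathered bias 11008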
@@ -486,8 +501,11 @@ class FSTPFusedDenseFunc(torch.autograd.Function):
                 handle_weight.wait()
             # TODO memory pool for bias
             if bias is not None:
-                total_bias, handle_bias = all_gather_raw(bias, process_group, async_op=True)
-                handle_bias.wait()
+                if overlap_handler is not None:
+                    total_bias = gpc.fstp_handler.get_bias_memory(module=module)
+                else:
+                    total_bias, handle_bias = all_gather_raw(bias, process_group, async_op=True)
+                    handle_bias.wait()
             else:
                 total_bias = bias
         else:
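The final hunk changes the forward path of FSTPFusedDenseFunc so that, when the overlap handler is active, the bias is taken directly from the pooled buffer that was gathered ahead of time, and only the non-overlapped path falls back to a blocking all_gather_raw. A condensed, dependency-free sketch of that decision follows; resolve_total_bias, pooled_bias_fn and gather_fn are made-up names standing in for the real code paths.

import torch

def resolve_total_bias(bias, world_size, overlap_handler, pooled_bias_fn, gather_fn):
    # Illustrative control flow only; gather_fn stands in for all_gather_raw and
    # pooled_bias_fn for gpc.fstp_handler.get_bias_memory.
    if world_size <= 1 or bias is None:
        return bias                       # nothing to gather
    if overlap_handler is not None:
        return pooled_bias_fn()           # overlap path: bias was prefetched into the pool
    total_bias, handle = gather_fn(bias)  # fallback: blocking all-gather, as before this commit
    handle.wait()
    return total_bias

class _DoneHandle:
    def wait(self):
        pass

# Toy usage: pretend each of 4 ranks holds a 2-element bias shard.
pooled = torch.ones(8)
out = resolve_total_bias(
    torch.ones(2), world_size=4, overlap_handler=object(),
    pooled_bias_fn=lambda: pooled,
    gather_fn=lambda b: (torch.ones(8), _DoneHandle()),
)
print(out.shape)  # torch.Size([8])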