mirror of https://github.com/InternLM/InternLM

feat(model/overlap_handler.py): move handler to gpc

parent 85ad917ae4
commit b20f47a1fe
@@ -352,16 +352,13 @@ class MegatronFeedForward(BaseFeedForward):
 class FSTPLinear(ColumnParallelLinear):
     def forward(self, x):
-        block_index = gpc.config.fstp_handler.module_to_index[self]
         return fstp_fused_dense_func(
             x,
             self.weight,
             self.bias,
             process_group=self.process_group,
             module=self,
-            handler=gpc.config.fstp_handler,
-            block_index=block_index,
-            module_name=self._fstp_name,
+            handler=gpc.fstp_handler,
         )
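For context on the call-site change above: the handler now hangs off the gpc context singleton instead of gpc.config, and callers stop threading block_index/module_name through every call, since the handler can recover both from the module object itself. A minimal, self-contained sketch of that lookup pattern (StubHandler, StubLinear, and "out_proj" are hypothetical names, not the repo's actual classes):

# Illustrative sketch only: StubHandler and StubLinear are hypothetical names.
class StubHandler:
    def __init__(self):
        self.module_to_index = {}  # module -> transformer block index

    def register(self, module, block_index, name):
        self.module_to_index[module] = block_index
        module._fstp_name = name  # mirrors the _fstp_name attribute in the diff

    def resolve(self, module):
        # The handler recovers (block_index, name) on its own, so call sites
        # pass only the module -- the point of this commit's API change.
        return self.module_to_index[module], module._fstp_name

class StubLinear:
    pass

handler = StubHandler()
linear = StubLinear()
handler.register(linear, block_index=3, name="out_proj")
print(handler.resolve(linear))  # (3, 'out_proj')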
@@ -116,8 +116,9 @@ class FSTPOverlapHandler:
            self.all_gather_memory_pool.append(weight)  # containing two groups of block weight

-    def get_all_gather_memory(self, index, module_name):
-        return self.all_gather_memory_pool[index % 2][module_name]
+    def get_all_gather_memory(self, module):
+        block_index = self.module_to_index[module]
+        return self.all_gather_memory_pool[block_index % 2][module._fstp_name]

     def get_reduce_scatter_memory(self, key):
         return_idx = 0
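The pool indexing above is a double buffer: all_gather_memory_pool holds two groups of block weights (per the inline comment), and block_index % 2 ping-pongs between them, presumably so the gather for one block can proceed while the previous block's buffer is still in use. A runnable sketch of just that indexing (TwoSlotPool and the module names are hypothetical; the real pool stores CUDA tensors):

# Hypothetical sketch of a two-slot (double-buffered) all-gather pool.
class TwoSlotPool:
    def __init__(self, module_names):
        # Two groups of per-module buffers; the real code stores tensors.
        self.pool = [{name: f"buffer[{slot}][{name}]" for name in module_names}
                     for slot in range(2)]

    def get(self, block_index, name):
        # Consecutive blocks alternate slots: block 0 -> slot 0, block 1 -> slot 1,
        # block 2 -> slot 0 again (free by then), and so on.
        return self.pool[block_index % 2][name]

pool = TwoSlotPool(["wqkv", "out_proj"])
print(pool.get(0, "wqkv"))  # buffer[0][wqkv]
print(pool.get(1, "wqkv"))  # buffer[1][wqkv]
print(pool.get(2, "wqkv"))  # buffer[0][wqkv] -- slot reused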
@@ -163,8 +164,7 @@ class FSTPOverlapHandler:
                module.weight,
                self.process_group,
                async_op=True,
-                block_index=block_index,
-                module_name=getattr(module, "_fstp_name"),
+                module=module,
            )
            self.fstp_global_handle[module] = weight_handle
@@ -192,13 +192,11 @@ class FSTPOverlapHandler:

        def _post_backward_hook_for_head(module: nn.Module, grad_input, grad_output):
            first_backward_module = self.fstp_modules[-1]
-            block_index = self.module_to_index[first_backward_module]
            weight_handle = all_gather_raw_memory_pool(
                first_backward_module.weight,
                self.process_group,
                async_op=True,
-                block_index=block_index,
-                module_name=getattr(first_backward_module, "_fstp_name"),
+                module=first_backward_module,
            )
            self.fstp_global_handle[first_backward_module] = weight_handle
@@ -211,13 +209,11 @@ class FSTPOverlapHandler:
            module_index = self.fstp_modules.index(module)
            if module_index - 1 >= 0:
                next_module = self.fstp_modules[module_index - 1]
-                block_index = self.module_to_index[next_module]
                weight_handle = all_gather_raw_memory_pool(
                    next_module.weight,
                    self.process_group,
                    async_op=True,
-                    block_index=block_index,
-                    module_name=getattr(next_module, "_fstp_name"),
+                    module=next_module,
                )
                self.fstp_global_handle[next_module] = weight_handle
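The two hooks above implement the backward prefetch chain: the head's post-backward hook gathers weights for the last FSTP module (the first one backward will reach), and each module's pre-backward hook gathers weights for fstp_modules[module_index - 1], the next module in reverse order. A pure-Python simulation of that scheduling (hypothetical module names; the real hooks launch async all-gathers instead of appending to a list):

# Hypothetical simulation of the reverse-order prefetch chain.
fstp_modules = ["block0.w1", "block0.w2", "block1.w1", "block1.w2"]
prefetched = []

def post_backward_hook_for_head():
    # Backward visits modules in reverse, so prefetch the last module first.
    prefetched.append(fstp_modules[-1])

def pre_backward_hook(module):
    # While `module`'s gradient is being computed, start gathering the
    # weights of the module that backward will reach next.
    i = fstp_modules.index(module)
    if i - 1 >= 0:
        prefetched.append(fstp_modules[i - 1])

post_backward_hook_for_head()
for m in reversed(fstp_modules):
    pre_backward_hook(m)

# Each gather is issued one step before backward needs it:
print(prefetched)  # ['block1.w2', 'block1.w1', 'block0.w2', 'block0.w1']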
@@ -7,13 +7,12 @@ import fused_dense_lib as fused_dense_cuda
 import torch
 import torch.nn.functional as F
 from flash_attn.utils.distributed import all_reduce_raw
-from torch import Tensor
+from torch import Tensor, nn
 from torch.cuda.amp import custom_bwd, custom_fwd
 from torch.distributed import ProcessGroup

 from internlm.core.context import ParallelMode
 from internlm.core.context import global_context as gpc
 from internlm.utils.common import get_current_device
 from internlm.utils.logger import get_logger

 logger = get_logger(__file__)
@@ -131,11 +130,10 @@ def all_gather_raw_memory_pool(
     process_group: ProcessGroup,
     async_op: bool = False,
     gather_dim: int = 0,
-    block_index: int = None,
-    module_name: str = None,
+    module: nn.Module = None,
 ):
     handle = torch.distributed.all_gather_into_tensor(
-        gpc.config.fstp_handler.get_all_gather_memory(block_index, module_name),
+        gpc.fstp_handler.get_all_gather_memory(module=module),
         input_.contiguous(),
         group=process_group,
         async_op=async_op,
@@ -166,8 +164,8 @@ def reduce_scatter_raw_memory_pool(input_: Tensor, process_group: ProcessGroup,
     world_size = torch.distributed.get_world_size(process_group)
     assert input_.shape[0] % world_size == 0
     size = (input_.shape[0] // world_size, *input_.shape[1:])
-    index = gpc.config.fstp_handler.get_reduce_scatter_memory(size)
-    output = gpc.config.fstp_handler.reduce_scatter_memory_pool[size]["data"][index]
+    index = gpc.fstp_handler.get_reduce_scatter_memory(size)
+    output = gpc.fstp_handler.reduce_scatter_memory_pool[size]["data"][index]
     setattr(output, "index", index)
     handle = torch.distributed.reduce_scatter_tensor(
         output, input_.contiguous(), group=process_group, async_op=async_op
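For context on the index bookkeeping above: get_reduce_scatter_memory(size) hands out a free slot in a size-keyed pool, the slot index is tagged onto the output tensor via setattr, and the optimizer later calls release_reduce_scatter_memory(size=..., index=...) to return the slot (see the HybridZeroOptimizer hunk below). A minimal sketch of such a pool (ReduceScatterPool and its internals are hypothetical; the repo's bookkeeping may differ):

# Hypothetical size-keyed slot pool; the real "data" list holds CUDA tensors.
class ReduceScatterPool:
    def __init__(self):
        self.pool = {}  # size -> {"data": [...], "free": set of slot indices}

    def acquire(self, size):
        entry = self.pool.setdefault(size, {"data": [], "free": set()})
        if entry["free"]:
            return entry["free"].pop()           # reuse a released slot
        entry["data"].append(f"tensor{size}")    # else grow the pool
        return len(entry["data"]) - 1

    def release(self, size, index):
        self.pool[size]["free"].add(index)

pool = ReduceScatterPool()
i0 = pool.acquire((128, 64))    # 0: new slot
i1 = pool.acquire((128, 64))    # 1: new slot
pool.release((128, 64), i0)
print(pool.acquire((128, 64)))  # 0: the released slot is reused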
@@ -469,16 +467,12 @@ class FSTPFusedDenseFunc(torch.autograd.Function):
         process_group=None,
         module=None,
         overlap_handler=None,
-        block_index=None,
-        module_name=None,
     ):
         ctx.compute_weight_gradient = weight.requires_grad
         ctx.return_residual = return_residual
         ctx.process_group = process_group
         ctx.overlap_handler = overlap_handler
         ctx.module = module
-        ctx.block_index = block_index
-        ctx.module_name = module_name

         if torch.is_autocast_enabled():
             x = x.to(dtype=torch.get_autocast_gpu_dtype())
@@ -488,7 +482,7 @@ class FSTPFusedDenseFunc(torch.autograd.Function):
         if world_size > 1:
             # do all_gather for weight and bias before actual computation
             if overlap_handler is not None:
-                total_weight = gpc.config.fstp_handler.get_all_gather_memory(block_index, module_name)
+                total_weight = gpc.fstp_handler.get_all_gather_memory(module=module)
             else:
                 total_weight, handle_weight = all_gather_raw(weight, process_group, async_op=True)
                 handle_weight.wait()
@@ -531,8 +525,7 @@ class FSTPFusedDenseFunc(torch.autograd.Function):
         grad_input = grad_input.contiguous()
         process_group = ctx.process_group
         overlap_handler = ctx.overlap_handler
-        block_index = ctx.block_index
-        module_name = ctx.module_name
+        module = ctx.module

         if ctx.compute_weight_gradient:
             x, weight, bias = ctx.saved_tensors
@@ -547,7 +540,7 @@ class FSTPFusedDenseFunc(torch.autograd.Function):
         world_size = gpc.get_world_size(ParallelMode.TENSOR)
         if world_size > 1:
             if overlap_handler is not None:
-                total_weight = gpc.config.fstp_handler.get_all_gather_memory(block_index, module_name)
+                total_weight = gpc.fstp_handler.get_all_gather_memory(module=module)
             else:
                 total_weight, handle_weight = all_gather_raw(weight, process_group, async_op=True)
                 handle_weight.wait()
@@ -669,16 +662,12 @@ def fstp_fused_dense_func(
     process_group=None,
     module=None,
     handler=None,
-    block_index=None,
-    module_name=None,
 ):
     dtype_eligible = x.dtype in [torch.float16, torch.bfloat16] or (
         x.dtype == torch.float32 and torch.is_autocast_enabled()
     )
     if x.is_cuda and weight.is_cuda and (bias is None or bias.is_cuda) and dtype_eligible:
-        return FSTPFusedDenseFunc.apply(
-            x, weight, bias, return_residual, process_group, module, handler, block_index, module_name
-        )
+        return FSTPFusedDenseFunc.apply(x, weight, bias, return_residual, process_group, module, handler)
     else:
         assert process_group is None
         out = F.linear(x, weight, bias)
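The dispatch in fstp_fused_dense_func is unchanged in spirit by this commit: the fused autograd path is taken only for CUDA tensors in fp16/bf16 (or fp32 under autocast); everything else falls back to plain F.linear with no process group. A runnable sketch of that eligibility check (the fused branch is stubbed out here, since it needs the real kernel and process group):

import torch
import torch.nn.functional as F

def dense_dispatch(x, weight, bias=None):
    # Same eligibility test as the diff: half/bfloat16 on CUDA, or fp32
    # while autocast is enabled, may use the fused kernel.
    dtype_eligible = x.dtype in [torch.float16, torch.bfloat16] or (
        x.dtype == torch.float32 and torch.is_autocast_enabled()
    )
    if x.is_cuda and weight.is_cuda and (bias is None or bias.is_cuda) and dtype_eligible:
        return "fused"  # stand-in for FSTPFusedDenseFunc.apply(...)
    return F.linear(x, weight, bias)  # plain fallback, e.g. CPU fp32

x = torch.randn(4, 8)
w = torch.randn(16, 8)
print(dense_dispatch(x, w).shape)  # torch.Size([4, 16]) -- CPU falls back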
@@ -68,7 +68,7 @@ class HybridZeroOptimizer(BaseOptimizer):

         self._fstp_handler = None
         if gpc.config.parallel["tensor"]["sp"] == "intern" and gpc.config.parallel["tensor"]["intern_overlap"] is True:
-            self._fstp_handler = gpc.config.fstp_handler
+            self._fstp_handler = gpc.fstp_handler

         # Zero related args
         reduce_bucket_size = zero_cfg.reduce_bucket_size
@@ -350,7 +350,7 @@ class HybridZeroOptimizer(BaseOptimizer):
                     _param.grad.add_(_grad)

                     # release cuda memory.
-                    gpc.config.fstp_handler.release_reduce_scatter_memory(size=tuple(_grad.size()), index=_grad.index)
+                    gpc.fstp_handler.release_reduce_scatter_memory(size=tuple(_grad.size()), index=_grad.index)
                     self._fstp_handler.reduce_scatter_handlers[_key] = None

             bucket.reset_by_rank(reduce_rank)
@@ -108,9 +108,9 @@ def initialize_model():
     # if fsdp enabled, wrap the model
     model = wrap_FSDP_model(model)

-    gpc.config.fstp_handler = None
+    gpc.fstp_handler = None
     if gpc.config.parallel["tensor"]["sp"] == "intern" and gpc.config.parallel["tensor"]["intern_overlap"] is True:
-        gpc.config.fstp_handler = FSTPOverlapHandler(model, gpc.get_group(ParallelMode.TENSOR))
+        gpc.fstp_handler = FSTPOverlapHandler(model, gpc.get_group(ParallelMode.TENSOR))

     return model
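This hunk is the heart of the commit: the live handler moves from gpc.config (which is built from the user's config file) onto the gpc context singleton itself. One plausible motivation, sketched below, is keeping the config a plain, dump-friendly mapping while runtime objects holding CUDA state live on the context (Context here is a hypothetical stand-in, not InternLM's actual class):

# Hypothetical sketch: config stays plain data; runtime handles live on the context.
class Context:
    def __init__(self, config):
        self.config = config      # plain dict: safe to log or serialize
        self.fstp_handler = None  # runtime-only attribute, set at init time

gpc = Context(config={"parallel": {"tensor": {"sp": "intern", "intern_overlap": True}}})

tensor_cfg = gpc.config["parallel"]["tensor"]
if tensor_cfg["sp"] == "intern" and tensor_cfg["intern_overlap"] is True:
    gpc.fstp_handler = object()  # stands in for FSTPOverlapHandler(...)

print(gpc.config)                    # still just configuration data
print(gpc.fstp_handler is not None)  # True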
train.py (6 changes)
@@ -297,9 +297,9 @@ def main(args):

             prof.step()

-        if gpc.config.fstp_handler is not None:
-            gpc.config.fstp_handler.zero_const_pool = {}
-            gpc.config.fstp_handler.reduce_scatter_memory_pool = {}
+        if gpc.fstp_handler is not None:
+            gpc.fstp_handler.zero_const_pool = {}
+            gpc.fstp_handler.reduce_scatter_memory_pool = {}
         # torch.cuda.memory._dump_snapshot(f"my_snapshot_{gpc.get_global_rank()}.pickle")
         torch.cuda.reset_peak_memory_stats()