feat(model/overlap_handler.py): move handler to gpc

pull/456/head
huangting4201 2023-10-23 12:02:32 +08:00
parent 85ad917ae4
commit b20f47a1fe
6 changed files with 23 additions and 41 deletions

View File

@ -352,16 +352,13 @@ class MegatronFeedForward(BaseFeedForward):
class FSTPLinear(ColumnParallelLinear): class FSTPLinear(ColumnParallelLinear):
def forward(self, x): def forward(self, x):
block_index = gpc.config.fstp_handler.module_to_index[self]
return fstp_fused_dense_func( return fstp_fused_dense_func(
x, x,
self.weight, self.weight,
self.bias, self.bias,
process_group=self.process_group, process_group=self.process_group,
module=self, module=self,
handler=gpc.config.fstp_handler, handler=gpc.fstp_handler,
block_index=block_index,
module_name=self._fstp_name,
) )

View File

@ -116,8 +116,9 @@ class FSTPOverlapHandler:
self.all_gather_memory_pool.append(weight) # containing two groups of block weight self.all_gather_memory_pool.append(weight) # containing two groups of block weight
def get_all_gather_memory(self, index, module_name): def get_all_gather_memory(self, module):
return self.all_gather_memory_pool[index % 2][module_name] block_index = self.module_to_index[module]
return self.all_gather_memory_pool[block_index % 2][module._fstp_name]
def get_reduce_scatter_memory(self, key): def get_reduce_scatter_memory(self, key):
return_idx = 0 return_idx = 0
@ -163,8 +164,7 @@ class FSTPOverlapHandler:
module.weight, module.weight,
self.process_group, self.process_group,
async_op=True, async_op=True,
block_index=block_index, module=module,
module_name=getattr(module, "_fstp_name"),
) )
self.fstp_global_handle[module] = weight_handle self.fstp_global_handle[module] = weight_handle
@ -192,13 +192,11 @@ class FSTPOverlapHandler:
def _post_backward_hook_for_head(module: nn.Module, grad_input, grad_output): def _post_backward_hook_for_head(module: nn.Module, grad_input, grad_output):
first_backward_module = self.fstp_modules[-1] first_backward_module = self.fstp_modules[-1]
block_index = self.module_to_index[first_backward_module]
weight_handle = all_gather_raw_memory_pool( weight_handle = all_gather_raw_memory_pool(
first_backward_module.weight, first_backward_module.weight,
self.process_group, self.process_group,
async_op=True, async_op=True,
block_index=block_index, module=first_backward_module,
module_name=getattr(first_backward_module, "_fstp_name"),
) )
self.fstp_global_handle[first_backward_module] = weight_handle self.fstp_global_handle[first_backward_module] = weight_handle
@ -211,13 +209,11 @@ class FSTPOverlapHandler:
module_index = self.fstp_modules.index(module) module_index = self.fstp_modules.index(module)
if module_index - 1 >= 0: if module_index - 1 >= 0:
next_module = self.fstp_modules[module_index - 1] next_module = self.fstp_modules[module_index - 1]
block_index = self.module_to_index[next_module]
weight_handle = all_gather_raw_memory_pool( weight_handle = all_gather_raw_memory_pool(
next_module.weight, next_module.weight,
self.process_group, self.process_group,
async_op=True, async_op=True,
block_index=block_index, module=next_module,
module_name=getattr(next_module, "_fstp_name"),
) )
self.fstp_global_handle[next_module] = weight_handle self.fstp_global_handle[next_module] = weight_handle

View File

@ -7,13 +7,12 @@ import fused_dense_lib as fused_dense_cuda
import torch import torch
import torch.nn.functional as F import torch.nn.functional as F
from flash_attn.utils.distributed import all_reduce_raw from flash_attn.utils.distributed import all_reduce_raw
from torch import Tensor from torch import Tensor, nn
from torch.cuda.amp import custom_bwd, custom_fwd from torch.cuda.amp import custom_bwd, custom_fwd
from torch.distributed import ProcessGroup from torch.distributed import ProcessGroup
from internlm.core.context import ParallelMode from internlm.core.context import ParallelMode
from internlm.core.context import global_context as gpc from internlm.core.context import global_context as gpc
from internlm.utils.common import get_current_device
from internlm.utils.logger import get_logger from internlm.utils.logger import get_logger
logger = get_logger(__file__) logger = get_logger(__file__)
@ -131,11 +130,10 @@ def all_gather_raw_memory_pool(
process_group: ProcessGroup, process_group: ProcessGroup,
async_op: bool = False, async_op: bool = False,
gather_dim: int = 0, gather_dim: int = 0,
block_index: int = None, module: nn.Module = None,
module_name: str = None,
): ):
handle = torch.distributed.all_gather_into_tensor( handle = torch.distributed.all_gather_into_tensor(
gpc.config.fstp_handler.get_all_gather_memory(block_index, module_name), gpc.fstp_handler.get_all_gather_memory(module=module),
input_.contiguous(), input_.contiguous(),
group=process_group, group=process_group,
async_op=async_op, async_op=async_op,
@ -166,8 +164,8 @@ def reduce_scatter_raw_memory_pool(input_: Tensor, process_group: ProcessGroup,
world_size = torch.distributed.get_world_size(process_group) world_size = torch.distributed.get_world_size(process_group)
assert input_.shape[0] % world_size == 0 assert input_.shape[0] % world_size == 0
size = (input_.shape[0] // world_size, *input_.shape[1:]) size = (input_.shape[0] // world_size, *input_.shape[1:])
index = gpc.config.fstp_handler.get_reduce_scatter_memory(size) index = gpc.fstp_handler.get_reduce_scatter_memory(size)
output = gpc.config.fstp_handler.reduce_scatter_memory_pool[size]["data"][index] output = gpc.fstp_handler.reduce_scatter_memory_pool[size]["data"][index]
setattr(output, "index", index) setattr(output, "index", index)
handle = torch.distributed.reduce_scatter_tensor( handle = torch.distributed.reduce_scatter_tensor(
output, input_.contiguous(), group=process_group, async_op=async_op output, input_.contiguous(), group=process_group, async_op=async_op
@ -469,16 +467,12 @@ class FSTPFusedDenseFunc(torch.autograd.Function):
process_group=None, process_group=None,
module=None, module=None,
overlap_handler=None, overlap_handler=None,
block_index=None,
module_name=None,
): ):
ctx.compute_weight_gradient = weight.requires_grad ctx.compute_weight_gradient = weight.requires_grad
ctx.return_residual = return_residual ctx.return_residual = return_residual
ctx.process_group = process_group ctx.process_group = process_group
ctx.overlap_handler = overlap_handler ctx.overlap_handler = overlap_handler
ctx.module = module ctx.module = module
ctx.block_index = block_index
ctx.module_name = module_name
if torch.is_autocast_enabled(): if torch.is_autocast_enabled():
x = x.to(dtype=torch.get_autocast_gpu_dtype()) x = x.to(dtype=torch.get_autocast_gpu_dtype())
@ -488,7 +482,7 @@ class FSTPFusedDenseFunc(torch.autograd.Function):
if world_size > 1: if world_size > 1:
# do all_gather for weight and bias before actual computation # do all_gather for weight and bias before actual computation
if overlap_handler is not None: if overlap_handler is not None:
total_weight = gpc.config.fstp_handler.get_all_gather_memory(block_index, module_name) total_weight = gpc.fstp_handler.get_all_gather_memory(module=module)
else: else:
total_weight, handle_weight = all_gather_raw(weight, process_group, async_op=True) total_weight, handle_weight = all_gather_raw(weight, process_group, async_op=True)
handle_weight.wait() handle_weight.wait()
@ -531,8 +525,7 @@ class FSTPFusedDenseFunc(torch.autograd.Function):
grad_input = grad_input.contiguous() grad_input = grad_input.contiguous()
process_group = ctx.process_group process_group = ctx.process_group
overlap_handler = ctx.overlap_handler overlap_handler = ctx.overlap_handler
block_index = ctx.block_index module = ctx.module
module_name = ctx.module_name
if ctx.compute_weight_gradient: if ctx.compute_weight_gradient:
x, weight, bias = ctx.saved_tensors x, weight, bias = ctx.saved_tensors
@ -547,7 +540,7 @@ class FSTPFusedDenseFunc(torch.autograd.Function):
world_size = gpc.get_world_size(ParallelMode.TENSOR) world_size = gpc.get_world_size(ParallelMode.TENSOR)
if world_size > 1: if world_size > 1:
if overlap_handler is not None: if overlap_handler is not None:
total_weight = gpc.config.fstp_handler.get_all_gather_memory(block_index, module_name) total_weight = gpc.fstp_handler.get_all_gather_memory(module=module)
else: else:
total_weight, handle_weight = all_gather_raw(weight, process_group, async_op=True) total_weight, handle_weight = all_gather_raw(weight, process_group, async_op=True)
handle_weight.wait() handle_weight.wait()
@ -669,16 +662,12 @@ def fstp_fused_dense_func(
process_group=None, process_group=None,
module=None, module=None,
handler=None, handler=None,
block_index=None,
module_name=None,
): ):
dtype_eligible = x.dtype in [torch.float16, torch.bfloat16] or ( dtype_eligible = x.dtype in [torch.float16, torch.bfloat16] or (
x.dtype == torch.float32 and torch.is_autocast_enabled() x.dtype == torch.float32 and torch.is_autocast_enabled()
) )
if x.is_cuda and weight.is_cuda and (bias is None or bias.is_cuda) and dtype_eligible: if x.is_cuda and weight.is_cuda and (bias is None or bias.is_cuda) and dtype_eligible:
return FSTPFusedDenseFunc.apply( return FSTPFusedDenseFunc.apply(x, weight, bias, return_residual, process_group, module, handler)
x, weight, bias, return_residual, process_group, module, handler, block_index, module_name
)
else: else:
assert process_group is None assert process_group is None
out = F.linear(x, weight, bias) out = F.linear(x, weight, bias)

View File

@ -68,7 +68,7 @@ class HybridZeroOptimizer(BaseOptimizer):
self._fstp_handler = None self._fstp_handler = None
if gpc.config.parallel["tensor"]["sp"] == "intern" and gpc.config.parallel["tensor"]["intern_overlap"] is True: if gpc.config.parallel["tensor"]["sp"] == "intern" and gpc.config.parallel["tensor"]["intern_overlap"] is True:
self._fstp_handler = gpc.config.fstp_handler self._fstp_handler = gpc.fstp_handler
# Zero related args # Zero related args
reduce_bucket_size = zero_cfg.reduce_bucket_size reduce_bucket_size = zero_cfg.reduce_bucket_size
@ -350,7 +350,7 @@ class HybridZeroOptimizer(BaseOptimizer):
_param.grad.add_(_grad) _param.grad.add_(_grad)
# release cuda memory. # release cuda memory.
gpc.config.fstp_handler.release_reduce_scatter_memory(size=tuple(_grad.size()), index=_grad.index) gpc.fstp_handler.release_reduce_scatter_memory(size=tuple(_grad.size()), index=_grad.index)
self._fstp_handler.reduce_scatter_handlers[_key] = None self._fstp_handler.reduce_scatter_handlers[_key] = None
bucket.reset_by_rank(reduce_rank) bucket.reset_by_rank(reduce_rank)

View File

@ -108,9 +108,9 @@ def initialize_model():
# if fsdp enabled, wrap the model # if fsdp enabled, wrap the model
model = wrap_FSDP_model(model) model = wrap_FSDP_model(model)
gpc.config.fstp_handler = None gpc.fstp_handler = None
if gpc.config.parallel["tensor"]["sp"] == "intern" and gpc.config.parallel["tensor"]["intern_overlap"] is True: if gpc.config.parallel["tensor"]["sp"] == "intern" and gpc.config.parallel["tensor"]["intern_overlap"] is True:
gpc.config.fstp_handler = FSTPOverlapHandler(model, gpc.get_group(ParallelMode.TENSOR)) gpc.fstp_handler = FSTPOverlapHandler(model, gpc.get_group(ParallelMode.TENSOR))
return model return model

View File

@ -297,9 +297,9 @@ def main(args):
prof.step() prof.step()
if gpc.config.fstp_handler is not None: if gpc.fstp_handler is not None:
gpc.config.fstp_handler.zero_const_pool = {} gpc.fstp_handler.zero_const_pool = {}
gpc.config.fstp_handler.reduce_scatter_memory_pool = {} gpc.fstp_handler.reduce_scatter_memory_pool = {}
# torch.cuda.memory._dump_snapshot(f"my_snapshot_{gpc.get_global_rank()}.pickle") # torch.cuda.memory._dump_snapshot(f"my_snapshot_{gpc.get_global_rank()}.pickle")
torch.cuda.reset_peak_memory_stats() torch.cuda.reset_peak_memory_stats()