mirror of https://github.com/InternLM/InternLM
feat(model/overlap_handler.py): move handler to gpc
parent 85ad917ae4
commit b20f47a1fe
@@ -352,16 +352,13 @@ class MegatronFeedForward(BaseFeedForward):
 class FSTPLinear(ColumnParallelLinear):
     def forward(self, x):
-        block_index = gpc.config.fstp_handler.module_to_index[self]
         return fstp_fused_dense_func(
             x,
             self.weight,
             self.bias,
             process_group=self.process_group,
             module=self,
-            handler=gpc.config.fstp_handler,
-            block_index=block_index,
-            module_name=self._fstp_name,
+            handler=gpc.fstp_handler,
         )

@@ -116,8 +116,9 @@ class FSTPOverlapHandler:

             self.all_gather_memory_pool.append(weight)  # containing two groups of block weight

-    def get_all_gather_memory(self, index, module_name):
-        return self.all_gather_memory_pool[index % 2][module_name]
+    def get_all_gather_memory(self, module):
+        block_index = self.module_to_index[module]
+        return self.all_gather_memory_pool[block_index % 2][module._fstp_name]

     def get_reduce_scatter_memory(self, key):
         return_idx = 0
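For readers skimming the diff: the new `get_all_gather_memory(module)` works because the handler keeps two groups of pre-allocated all-gather buffers and alternates between them by block index, so the module object alone (via `module_to_index` and `_fstp_name`) is enough to locate a buffer. Below is a minimal, self-contained sketch of that double-buffered lookup; `ToyOverlapHandler` and its constructor arguments are illustrative stand-ins, not code from this patch.

import torch
from torch import nn


class ToyOverlapHandler:
    """Toy double-buffered all-gather pool addressable by module object."""

    def __init__(self, module_names, num_blocks, hidden=4):
        self.module_to_index = {}
        self.modules = {}
        # Two buffer groups: the block currently computing and the block being
        # prefetched never share a group, because adjacent block indices differ
        # by one and are therefore mapped to different groups by `% 2`.
        self.all_gather_memory_pool = [
            {name: torch.zeros(hidden, hidden) for name in module_names} for _ in range(2)
        ]
        for block_idx in range(num_blocks):
            for name in module_names:
                m = nn.Linear(hidden, hidden, bias=False)
                m._fstp_name = name
                self.module_to_index[m] = block_idx
                self.modules[(block_idx, name)] = m

    def get_all_gather_memory(self, module):
        block_index = self.module_to_index[module]
        return self.all_gather_memory_pool[block_index % 2][module._fstp_name]


handler = ToyOverlapHandler(["wqkv", "out_proj"], num_blocks=4)
m_block0 = handler.modules[(0, "wqkv")]
m_block1 = handler.modules[(1, "wqkv")]
# Consecutive blocks resolve to buffers in different groups, so prefetching
# block 1 cannot overwrite the buffer block 0 is still reading from.
assert handler.get_all_gather_memory(m_block0) is not handler.get_all_gather_memory(m_block1)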
@@ -163,8 +164,7 @@ class FSTPOverlapHandler:
                 module.weight,
                 self.process_group,
                 async_op=True,
-                block_index=block_index,
-                module_name=getattr(module, "_fstp_name"),
+                module=module,
             )
             self.fstp_global_handle[module] = weight_handle

@@ -192,13 +192,11 @@ class FSTPOverlapHandler:

         def _post_backward_hook_for_head(module: nn.Module, grad_input, grad_output):
             first_backward_module = self.fstp_modules[-1]
-            block_index = self.module_to_index[first_backward_module]
             weight_handle = all_gather_raw_memory_pool(
                 first_backward_module.weight,
                 self.process_group,
                 async_op=True,
-                block_index=block_index,
-                module_name=getattr(first_backward_module, "_fstp_name"),
+                module=first_backward_module,
             )
             self.fstp_global_handle[first_backward_module] = weight_handle

@@ -211,13 +209,11 @@ class FSTPOverlapHandler:
             module_index = self.fstp_modules.index(module)
             if module_index - 1 >= 0:
                 next_module = self.fstp_modules[module_index - 1]
-                block_index = self.module_to_index[next_module]
                 weight_handle = all_gather_raw_memory_pool(
                     next_module.weight,
                     self.process_group,
                     async_op=True,
-                    block_index=block_index,
-                    module_name=getattr(next_module, "_fstp_name"),
+                    module=next_module,
                 )
                 self.fstp_global_handle[next_module] = weight_handle

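The two hook hunks above share one prefetch pattern: when backward reaches a module (or the head module finishes its backward), the all-gather for the module that backward will visit next is issued, now keyed by the module object alone. A rough CPU-only sketch of that scheduling idea follows; `fake_all_gather` stands in for `all_gather_raw_memory_pool`, and the toy modules and hook wiring are assumptions for illustration, not the handler's real registration code.

import torch
from torch import nn

# Modules in forward order; backward visits them in reverse.
fstp_modules = [nn.Linear(4, 4) for _ in range(3)]
prefetched = []


def fake_all_gather(module):
    # Stand-in for all_gather_raw_memory_pool(module.weight, ..., module=module).
    prefetched.append(module)


def _post_backward_hook(module, grad_input, grad_output):
    module_index = fstp_modules.index(module)
    if module_index - 1 >= 0:
        # Prefetch the weights of the module backward will reach next.
        fake_all_gather(fstp_modules[module_index - 1])


for m in fstp_modules:
    m.register_full_backward_hook(_post_backward_hook)

x = torch.randn(2, 4, requires_grad=True)
out = x
for m in fstp_modules:
    out = m(out)
out.sum().backward()

# Backward ran through modules 2, 1, 0, so modules 1 and 0 were prefetched in turn.
print([fstp_modules.index(m) for m in prefetched])  # [1, 0]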
@@ -7,13 +7,12 @@ import fused_dense_lib as fused_dense_cuda
 import torch
 import torch.nn.functional as F
 from flash_attn.utils.distributed import all_reduce_raw
-from torch import Tensor
+from torch import Tensor, nn
 from torch.cuda.amp import custom_bwd, custom_fwd
 from torch.distributed import ProcessGroup

 from internlm.core.context import ParallelMode
 from internlm.core.context import global_context as gpc
-from internlm.utils.common import get_current_device
 from internlm.utils.logger import get_logger

 logger = get_logger(__file__)
@@ -131,11 +130,10 @@ def all_gather_raw_memory_pool(
     process_group: ProcessGroup,
     async_op: bool = False,
     gather_dim: int = 0,
-    block_index: int = None,
-    module_name: str = None,
+    module: nn.Module = None,
 ):
     handle = torch.distributed.all_gather_into_tensor(
-        gpc.config.fstp_handler.get_all_gather_memory(block_index, module_name),
+        gpc.fstp_handler.get_all_gather_memory(module=module),
         input_.contiguous(),
         group=process_group,
         async_op=async_op,
@@ -166,8 +164,8 @@ def reduce_scatter_raw_memory_pool(input_: Tensor, process_group: ProcessGroup,
     world_size = torch.distributed.get_world_size(process_group)
     assert input_.shape[0] % world_size == 0
     size = (input_.shape[0] // world_size, *input_.shape[1:])
-    index = gpc.config.fstp_handler.get_reduce_scatter_memory(size)
-    output = gpc.config.fstp_handler.reduce_scatter_memory_pool[size]["data"][index]
+    index = gpc.fstp_handler.get_reduce_scatter_memory(size)
+    output = gpc.fstp_handler.reduce_scatter_memory_pool[size]["data"][index]
     setattr(output, "index", index)
     handle = torch.distributed.reduce_scatter_tensor(
         output, input_.contiguous(), group=process_group, async_op=async_op
@@ -469,16 +467,12 @@ class FSTPFusedDenseFunc(torch.autograd.Function):
         process_group=None,
         module=None,
         overlap_handler=None,
-        block_index=None,
-        module_name=None,
     ):
         ctx.compute_weight_gradient = weight.requires_grad
         ctx.return_residual = return_residual
         ctx.process_group = process_group
         ctx.overlap_handler = overlap_handler
         ctx.module = module
-        ctx.block_index = block_index
-        ctx.module_name = module_name

         if torch.is_autocast_enabled():
             x = x.to(dtype=torch.get_autocast_gpu_dtype())
@@ -488,7 +482,7 @@ class FSTPFusedDenseFunc(torch.autograd.Function):
         if world_size > 1:
             # do all_gather for weight and bias before actual computation
             if overlap_handler is not None:
-                total_weight = gpc.config.fstp_handler.get_all_gather_memory(block_index, module_name)
+                total_weight = gpc.fstp_handler.get_all_gather_memory(module=module)
             else:
                 total_weight, handle_weight = all_gather_raw(weight, process_group, async_op=True)
                 handle_weight.wait()
@@ -531,8 +525,7 @@ class FSTPFusedDenseFunc(torch.autograd.Function):
         grad_input = grad_input.contiguous()
         process_group = ctx.process_group
         overlap_handler = ctx.overlap_handler
-        block_index = ctx.block_index
-        module_name = ctx.module_name
+        module = ctx.module

         if ctx.compute_weight_gradient:
             x, weight, bias = ctx.saved_tensors
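Instead of threading `block_index` and `module_name` through `ctx`, the autograd function now stashes the module reference once in forward and re-derives everything from it in backward. A minimal sketch of that ctx-stashing pattern with a toy `torch.autograd.Function`; the `pool` dict stands in for the handler's pre-gathered weights, and none of these names come from the patch itself.

import torch


class StashModuleFn(torch.autograd.Function):
    """Toy Function that keeps a module reference on ctx, like FSTPFusedDenseFunc."""

    @staticmethod
    def forward(ctx, x, weight, module, pool):
        ctx.module = module  # non-tensor state can live on ctx attributes
        ctx.pool = pool
        ctx.save_for_backward(x, weight)
        return x @ weight.t()

    @staticmethod
    def backward(ctx, grad_out):
        x, weight = ctx.saved_tensors
        module = ctx.module
        # Backward locates per-module buffers with only the module handle,
        # mirroring get_all_gather_memory(module=module).
        total_weight = ctx.pool[module]
        grad_x = grad_out @ total_weight
        grad_w = grad_out.t() @ x
        return grad_x, grad_w, None, None


lin = torch.nn.Linear(4, 4, bias=False)
pool = {lin: lin.weight.detach().clone()}
x = torch.randn(2, 4, requires_grad=True)
y = StashModuleFn.apply(x, lin.weight, lin, pool)
y.sum().backward()
print(x.grad.shape, lin.weight.grad.shape)  # torch.Size([2, 4]) torch.Size([4, 4])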
@@ -547,7 +540,7 @@ class FSTPFusedDenseFunc(torch.autograd.Function):
         world_size = gpc.get_world_size(ParallelMode.TENSOR)
         if world_size > 1:
             if overlap_handler is not None:
-                total_weight = gpc.config.fstp_handler.get_all_gather_memory(block_index, module_name)
+                total_weight = gpc.fstp_handler.get_all_gather_memory(module=module)
             else:
                 total_weight, handle_weight = all_gather_raw(weight, process_group, async_op=True)
                 handle_weight.wait()
@@ -669,16 +662,12 @@ def fstp_fused_dense_func(
     process_group=None,
     module=None,
     handler=None,
-    block_index=None,
-    module_name=None,
 ):
     dtype_eligible = x.dtype in [torch.float16, torch.bfloat16] or (
         x.dtype == torch.float32 and torch.is_autocast_enabled()
     )
     if x.is_cuda and weight.is_cuda and (bias is None or bias.is_cuda) and dtype_eligible:
-        return FSTPFusedDenseFunc.apply(
-            x, weight, bias, return_residual, process_group, module, handler, block_index, module_name
-        )
+        return FSTPFusedDenseFunc.apply(x, weight, bias, return_residual, process_group, module, handler)
     else:
         assert process_group is None
         out = F.linear(x, weight, bias)
@@ -68,7 +68,7 @@ class HybridZeroOptimizer(BaseOptimizer):

         self._fstp_handler = None
         if gpc.config.parallel["tensor"]["sp"] == "intern" and gpc.config.parallel["tensor"]["intern_overlap"] is True:
-            self._fstp_handler = gpc.config.fstp_handler
+            self._fstp_handler = gpc.fstp_handler

         # Zero related args
         reduce_bucket_size = zero_cfg.reduce_bucket_size
@@ -350,7 +350,7 @@ class HybridZeroOptimizer(BaseOptimizer):
                 _param.grad.add_(_grad)

                 # release cuda memory.
-                gpc.config.fstp_handler.release_reduce_scatter_memory(size=tuple(_grad.size()), index=_grad.index)
+                gpc.fstp_handler.release_reduce_scatter_memory(size=tuple(_grad.size()), index=_grad.index)
                 self._fstp_handler.reduce_scatter_handlers[_key] = None

         bucket.reset_by_rank(reduce_rank)
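Read together with the `reduce_scatter_raw_memory_pool` hunk above, the pool protocol is: ask the handler for a free slot of the right size, tag the output tensor with that slot index, and let the optimizer release the slot once the gradient has been consumed. A simplified sketch of that bookkeeping with no collectives; the method names follow the diff, but the internals here are assumptions rather than the real `FSTPOverlapHandler` code.

import torch


class ToyReduceScatterPool:
    """Size-keyed pool: each size maps to a list of buffers plus a used mask."""

    def __init__(self):
        self.reduce_scatter_memory_pool = {}

    def get_reduce_scatter_memory(self, size):
        pool = self.reduce_scatter_memory_pool.setdefault(size, {"data": [], "used": []})
        for idx, used in enumerate(pool["used"]):
            if not used:  # reuse a previously released slot
                pool["used"][idx] = True
                return idx
        # No free slot: grow the pool by one buffer of this size.
        pool["data"].append(torch.empty(*size))
        pool["used"].append(True)
        return len(pool["data"]) - 1

    def release_reduce_scatter_memory(self, size, index):
        self.reduce_scatter_memory_pool[size]["used"][index] = False


handler = ToyReduceScatterPool()
size = (2, 4)
index = handler.get_reduce_scatter_memory(size)
output = handler.reduce_scatter_memory_pool[size]["data"][index]
setattr(output, "index", index)  # the diff tags the output tensor the same way
# ... an async reduce-scatter would write into `output` here ...
handler.release_reduce_scatter_memory(size=size, index=output.index)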
@@ -108,9 +108,9 @@ def initialize_model():
     # if fsdp enabled, wrap the model
     model = wrap_FSDP_model(model)

-    gpc.config.fstp_handler = None
+    gpc.fstp_handler = None
     if gpc.config.parallel["tensor"]["sp"] == "intern" and gpc.config.parallel["tensor"]["intern_overlap"] is True:
-        gpc.config.fstp_handler = FSTPOverlapHandler(model, gpc.get_group(ParallelMode.TENSOR))
+        gpc.fstp_handler = FSTPOverlapHandler(model, gpc.get_group(ParallelMode.TENSOR))

     return model

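The net effect of the commit is visible here: the overlap handler becomes an attribute of the global context object itself instead of an entry in `gpc.config`, keeping runtime-only state out of the user configuration. A rough sketch of the idea, assuming `gpc` behaves like a plain singleton; the real `global_context` in InternLM is considerably richer than this toy, and `ToyFSTPHandler` is a placeholder.

class ToyGlobalContext:
    """Stand-in for internlm.core.context.global_context (gpc)."""

    def __init__(self, config):
        self.config = config  # static, user-provided configuration
        self.fstp_handler = None  # runtime-only state lives on the context itself


class ToyFSTPHandler:
    def __init__(self, model):
        self.model = model


gpc = ToyGlobalContext(config={"parallel": {"tensor": {"sp": "intern", "intern_overlap": True}}})
model = object()

# Mirrors initialize_model(): attach the handler to gpc, not to gpc.config.
if gpc.config["parallel"]["tensor"]["sp"] == "intern" and gpc.config["parallel"]["tensor"]["intern_overlap"]:
    gpc.fstp_handler = ToyFSTPHandler(model)

assert "fstp_handler" not in gpc.config  # the config stays free of runtime objects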
train.py (6 changed lines)
@@ -297,9 +297,9 @@ def main(args):

         prof.step()

-        if gpc.config.fstp_handler is not None:
-            gpc.config.fstp_handler.zero_const_pool = {}
-            gpc.config.fstp_handler.reduce_scatter_memory_pool = {}
+        if gpc.fstp_handler is not None:
+            gpc.fstp_handler.zero_const_pool = {}
+            gpc.fstp_handler.reduce_scatter_memory_pool = {}
         # torch.cuda.memory._dump_snapshot(f"my_snapshot_{gpc.get_global_rank()}.pickle")
         torch.cuda.reset_peak_memory_stats()
