InternLM/internlm/core/gradient_handler.py


#!/usr/bin/env python
# -*- encoding: utf-8 -*-

from abc import ABC, abstractmethod
from collections import defaultdict

import torch
import torch.distributed as dist
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors

from internlm.core.context import global_context as gpc
from internlm.core.context.process_group_initializer import ParallelMode

class BaseGradientHandler(ABC):
    """A basic helper class to handle all-reduce operations of gradients across different parallel groups
    before optimization.

    Args:
        model (Module): Model where the gradients accumulate.
        optimizer (Optimizer): Optimizer for updating the parameters.
    """

    def __init__(self, model, optimizer):
        self._model = model
        self._optimizer = optimizer

    @abstractmethod
    def handle_gradient(self):
        """A method to accumulate gradients across different parallel groups. Users should
        write their own functions or just use the functions in pre-defined subclasses.
        """
        pass
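
# A minimal sketch (not part of the original file) of how a custom handler could
# subclass BaseGradientHandler. It assumes a plain data-parallel setup where every
# gradient is summed across the ParallelMode.DATA group; the class name and the
# group choice are illustrative, not InternLM API.
class ExampleDataParallelGradientHandler(BaseGradientHandler):
    """Illustrative only: all-reduce every gradient across the data parallel group."""

    def handle_gradient(self):
        if gpc.get_world_size(ParallelMode.DATA) > 1:
            group = gpc.get_group(ParallelMode.DATA)
            for param in self._model.parameters():
                if param.requires_grad and param.grad is not None:
                    dist.all_reduce(param.grad, op=dist.ReduceOp.SUM, group=group)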

class PipelineSharedModuleGradientHandler(BaseGradientHandler):
    """A helper class to handle all-reduce operations in sub parallel groups.
    An all-reduce collective communication will be performed in
    :func:`handle_gradient` among all sub pipeline parallel groups.
    For better performance, it bucketizes the gradients of all parameters of the
    same type to improve the efficiency of communication.

    Args:
        model (Module): Model where the gradients accumulate.
        optimizer (Optimizer): Optimizer for updating the parameters.
    """
    def handle_gradient(self):
        """A method running an all-reduce operation in sub pipeline parallel groups."""
        if gpc.pipeline_parallel_size > 1:
            # Bucketize gradients by process group and tensor type, then all-reduce per bucket.
            buckets = defaultdict(lambda: defaultdict(list))

            # Pack the buckets: only parameters that are shared across pipeline stages
            # (tagged with `pipeline_shared_module_pg`) and actually hold a gradient qualify.
            for param in self._model.parameters():
                group = getattr(param, "pipeline_shared_module_pg", None)
                if (
                    param.requires_grad
                    and group is not None
                    and (
                        (hasattr(param, "colo_attr") and not param.colo_attr.saved_grad.is_null())
                        or param.grad is not None
                    )
                ):
                    tp = param.data.type()
                    buckets[group][tp].append(param)

            # For each bucket, flatten the grads into one tensor, all-reduce once, and
            # copy the all-reduced values back into the original gradient tensors.
            for group, group_buckets in buckets.items():
                for tp, bucket in group_buckets.items():
                    grads = [
                        param.colo_attr.grad_payload if hasattr(param, "colo_attr") else param.grad.data
                        for param in bucket
                    ]
                    coalesced = _flatten_dense_tensors(grads).to(torch.cuda.current_device())
                    dist.all_reduce(coalesced, op=dist.ReduceOp.SUM, group=group)
                    for buf, synced in zip(grads, _unflatten_dense_tensors(coalesced, grads)):
                        buf.copy_(synced)
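
# An illustrative helper (an assumption, not InternLM API) showing how parameters can
# be tagged so that PipelineSharedModuleGradientHandler picks them up: the handler
# only reads the `pipeline_shared_module_pg` attribute set here. The name
# `_tag_shared_parameters` and the choice of group are hypothetical.
def _tag_shared_parameters(module, group):
    """Illustrative only: mark every parameter of `module` as pipeline-shared."""
    for param in module.parameters():
        param.pipeline_shared_module_pg = group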

class EmbeddingSharedModuleGradientHandler(BaseGradientHandler):
    """A helper class to handle all-reduce operations in embedding share groups.
    An all-reduce collective communication will be performed in
    :func:`handle_gradient` between the first and the last pipeline stages, which
    hold the two copies of the shared embedding weight.

    Args:
        model (Module): Model where the gradients accumulate.
        optimizer (Optimizer): Optimizer for updating the parameters.
    """
    def handle_gradient(self):
        """A method running an all-reduce operation on the shared embedding weight's
        gradient between the first and the last pipeline stages."""
        if gpc.is_pipeline_first_stage() or gpc.is_pipeline_last_stage():
            weight = self._model.model.shared_embedding_weight()
            grad = weight.grad
            # With ZeRO enabled, the gradient may be None on this rank; contribute a
            # zero tensor so the collective still matches across the embedding group.
            if grad is None:
                grad = torch.zeros_like(weight)
            dist.all_reduce(grad, group=gpc.get_group(parallel_mode=ParallelMode.EMBEDDING))
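
# A hedged usage sketch (not part of the original file): how a training loop could
# drive these handlers. The function name and its arguments are assumptions for
# illustration; only `handle_gradient()` comes from this module.
def _example_training_step(optimizer, loss, gradient_handlers):
    """Illustrative only: synchronize shared gradients after backward, before step."""
    loss.backward()
    for handler in gradient_handlers:
        # e.g. handlers constructed as PipelineSharedModuleGradientHandler(model, optimizer)
        # and EmbeddingSharedModuleGradientHandler(model, optimizer)
        handler.handle_gradient()
    optimizer.step()
    optimizer.zero_grad()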