ColossalAI/colossalai/engine/gradient_handler/_zero_gradient_handler.py

22 lines
751 B
Python
Raw Normal View History

2021-10-28 16:21:23 +00:00
from colossalai.registry import GRADIENT_HANDLER
2021-10-28 16:21:23 +00:00
from ._base_gradient_handler import BaseGradientHandler
@GRADIENT_HANDLER.register_module
class ZeROGradientHandler(BaseGradientHandler):
    """Gradient handler specialized for ZeRO optimization.

    Gradients are synchronized across the data parallel group (an all-reduce
    collective) when :func:`handle_gradient` is invoked. The handler itself
    performs no communication; it delegates to the optimizer it holds.

    Args:
        model (Module): Model where the gradients accumulate.
        optimizer (Optimizer): Optimizer for updating the parameters.
    """

    def handle_gradient(self):
        """Trigger gradient synchronization via the optimizer's ``sync_grad()``."""
        # The optimizer owns the synchronization logic; this handler only
        # forwards the request to it.
        zero_optimizer = self._optimizer
        zero_optimizer.sync_grad()