# mirror of https://github.com/hpcaitech/ColossalAI
from colossalai.registry import GRADIENT_HANDLER

from ._base_gradient_handler import BaseGradientHandler


@GRADIENT_HANDLER.register_module
class ZeROGradientHandler(BaseGradientHandler):
    """A helper class to handle all-reduce operations in a data parallel group.

    An all-reduce collective communication is performed in
    :func:`handle_gradient` among a data parallel group.
    This class is specialized for the ZeRO optimization.
    """

    def handle_gradient(self):
        """Runs an all-reduce operation in a data parallel group."""
        # Delegate gradient synchronization to the ZeRO optimizer, which
        # all-reduces the gradients across the data parallel group.
        self._optimizer.allreduce_gradients()
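

if __name__ == "__main__":
    # Usage sketch (not part of the original ColossalAI file): shows how a
    # training step would call handle_gradient() after backward(). The stub
    # optimizer below is hypothetical; the real ZeRO optimizer exposes
    # allreduce_gradients() and requires a torch.distributed setup.
    # BaseGradientHandler is assumed here to accept (model, optimizer).
    import torch
    import torch.nn as nn

    class _StubZeROOptimizer:
        def allreduce_gradients(self):
            # The real optimizer all-reduces gradients across the data parallel group.
            print("all-reducing gradients across the data parallel group")

    model = nn.Linear(4, 2)
    handler = ZeROGradientHandler(model, _StubZeROOptimizer())

    model(torch.randn(1, 4)).sum().backward()
    handler.handle_gradient()  # synchronize gradients before the optimizer step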