#!/usr/bin/env python
# -*- encoding: utf-8 -*-

from abc import ABC, abstractmethod


class BaseGradientHandler(ABC):
    """A basic helper class to handle all-reduce operations of gradients
    across different parallel groups before optimization.

    This is an abstract base class: it cannot be instantiated directly, and
    concrete subclasses must implement :meth:`handle_gradient`.

    Args:
        model (Module): Model where the gradients accumulate.
        optimizer (Optimizer): Optimizer for updating the parameters.
    """

    def __init__(self, model, optimizer):
        # Keep references only; nothing is validated or modified here.
        self._model = model
        self._optimizer = optimizer

    @abstractmethod
    def handle_gradient(self):
        """A method to accumulate gradients across different parallel groups.

        Users should write their own functions or just use the functions in
        pre-defined subclasses.
        """