mirror of https://github.com/hpcaitech/ColossalAI
ai, big-model, data-parallelism, deep-learning, distributed-computing, foundation-models, heterogeneous-training, hpc, inference, large-scale, model-parallelism, pipeline-parallelism
You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
25 lines
791 B
25 lines
791 B
#!/usr/bin/env python |
|
# -*- encoding: utf-8 -*- |
|
|
|
from abc import ABC, abstractmethod |
|
|
|
|
|
class BaseGradientHandler(ABC):
    """Abstract base for helpers that all-reduce gradients across parallel groups.

    A handler is meant to run before the optimization step. This base class
    only stores the references it is given; the actual gradient handling is
    supplied by subclasses through :meth:`handle_gradient`.

    :param model: Model where the gradients accumulate
    :type model: Module
    :param optimizer: Optimizer for updating the parameters
    :type optimizer: Optimizer
    """

    def __init__(self, model, optimizer):
        # Keep plain references for subclasses to use; nothing is copied
        # or validated here.
        self._model, self._optimizer = model, optimizer

    @abstractmethod
    def handle_gradient(self):
        """Accumulate gradients across different parallel groups.

        Subclasses must override this method; users either write their own
        implementation or rely on one of the pre-defined subclasses. The
        base implementation does nothing and returns ``None``.
        """
|
|
|