mirror of https://github.com/hpcaitech/ColossalAI
45 lines
1.4 KiB
Python
45 lines
1.4 KiB
Python
from colossalai.context import ParallelMode
|
|
from colossalai.core import global_context as gpc
|
|
|
|
from .base_store import BaseStore
|
|
|
|
|
|
class BucketStore(BaseStore):
|
|
|
|
def __init__(self, dp_parallel_mode):
|
|
super().__init__(dp_parallel_mode)
|
|
self._grads = dict()
|
|
self._params = dict()
|
|
self._num_elements_in_bucket = dict()
|
|
|
|
self.reset()
|
|
|
|
def num_elements_in_bucket(self, reduce_rank: int = None):
|
|
return self._num_elements_in_bucket[reduce_rank]
|
|
|
|
def add_num_elements_in_bucket(self, num_elements, reduce_rank: int = None):
|
|
self._num_elements_in_bucket[reduce_rank] += num_elements
|
|
|
|
def add_grad(self, tensor, reduce_rank: int = None):
|
|
self._grads[reduce_rank].append(tensor)
|
|
|
|
def add_param(self, tensor, reduce_rank: int = None):
|
|
self._params[reduce_rank].append(tensor)
|
|
|
|
def reset(self):
|
|
keys = [None] + list(range(self._world_size))
|
|
self._grads = {rank: [] for rank in keys}
|
|
self._params = {rank: [] for rank in keys}
|
|
self._num_elements_in_bucket = {rank: 0 for rank in keys}
|
|
|
|
def reset_by_rank(self, reduce_rank=None):
|
|
self._grads[reduce_rank] = []
|
|
self._params[reduce_rank] = []
|
|
self._num_elements_in_bucket[reduce_rank] = 0
|
|
|
|
def get_grad(self, reduce_rank: int = None):
|
|
return self._grads[reduce_rank]
|
|
|
|
def get_param(self, reduce_rank: int = None):
|
|
return self._params[reduce_rank]
|