2022-11-11 01:26:40 +00:00
|
|
|
from torch import Tensor
|
2023-01-13 06:56:17 +00:00
|
|
|
from torch.distributed import ProcessGroup
|
2023-01-13 02:05:58 +00:00
|
|
|
|
2022-11-11 01:26:40 +00:00
|
|
|
from .base_store import BaseStore
|
|
|
|
|
|
|
|
|
|
|
|
class ParameterStore(BaseStore):
    """Track per-parameter padding sizes and the master/working parameter pairing.

    Every lookup is keyed by ``id()`` of the tensor object rather than by its
    contents, so callers must pass the very same tensor object they recorded.

    Args:
        torch_pg (ProcessGroup): Process group forwarded to ``BaseStore``.
    """

    def __init__(self, torch_pg: ProcessGroup):
        super().__init__(torch_pg)

        # id(param) -> padding size recorded for that param.
        # NOTE(review): only the id is stored, so this map does not keep the
        # tensor alive; if a param is freed its id may be reused — confirm
        # callers keep recorded params alive.
        self._padding_map = {}

        # Bidirectional pairing between optimizer-side and model-side params:
        # id(master param) -> working param, and id(working param) -> master param.
        self.master_to_working_param = {}
        self.working_to_master_param = {}

    def record_param_padding_size(self, param: Tensor, padding_size: int):
        """Store ``padding_size`` for ``param``, overwriting any earlier value.

        Args:
            param (Tensor): The parameter whose padding is being recorded.
            padding_size (int): The padding size of the parameter.
        """
        param_key = id(param)
        self._padding_map[param_key] = padding_size

    def get_param_padding_size(self, param: Tensor) -> int:
        """Look up the padding size previously recorded for ``param``.

        Args:
            param (Tensor): The parameter to look up.

        Returns:
            int: The value stored by ``record_param_padding_size``.

        Raises:
            KeyError: If no padding size was recorded for this tensor object.
        """
        param_key = id(param)
        return self._padding_map[param_key]

    def link_master_and_working_param(self, master_param: Tensor, working_param: Tensor):
        """Register ``master_param`` and ``working_param`` as a pair in both directions.

        Args:
            master_param (Tensor): The parameter copy held by the optimizer.
            working_param (Tensor): The parameter of the model.
        """
        master_key = id(master_param)
        working_key = id(working_param)
        self.master_to_working_param[master_key] = working_param
        self.working_to_master_param[working_key] = master_param
|