# mirror of https://github.com/hpcaitech/ColossalAI
from typing import Dict

from torch import Tensor
from torch.distributed import ProcessGroup

from .base_store import BaseStore


class ParameterStore(BaseStore):
    def __init__(self, torch_pg: ProcessGroup):
        super().__init__(torch_pg)

        # record the padding size of each param, keyed by id(param)
        self._padding_map = dict()

        # two-way mapping between working params and master params
        self.master_to_working_param = dict()
        self.working_to_master_param = dict()

    def record_param_padding_size(self, param: Tensor, padding_size: int):
        """Record the padding size of a param.

        Args:
            param (Tensor): The parameter
            padding_size (int): The padding size of the parameter
        """
        self._padding_map[id(param)] = padding_size

    def get_param_padding_size(self, param: Tensor) -> int:
        """Return the padding size of the parameter.

        Args:
            param (Tensor): The parameter

        Returns:
            int: The padding size of the parameter
        """
        return self._padding_map[id(param)]

    def link_master_and_working_param(self, master_param: Tensor, working_param: Tensor):
        """Link a master parameter with its working parameter.

        Args:
            master_param (Tensor): The parameter copy kept by the optimizer
            working_param (Tensor): The parameter of the model
        """
        self.master_to_working_param[id(master_param)] = working_param
        self.working_to_master_param[id(working_param)] = master_param

    def get_padding_map(self) -> Dict[int, int]:
        """Return the padding map.

        Returns:
            Dict[int, int]: The padding map, keyed by id(param)
        """
        return self._padding_map
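# --- Usage sketch (not part of the original file) ---
# A minimal illustration of how this store is typically driven, assuming a
# torch.distributed process group `pg` has already been initialized elsewhere
# (e.g. via torch.distributed.init_process_group) and that `working_param` and
# `master_param` are hypothetical tensors supplied by the caller. Note that the
# store only bookkeeps padding sizes and param mappings; it does not pad or
# copy anything itself.
#
#   store = ParameterStore(torch_pg=pg)
#
#   # the working param belongs to the model, the master param is the copy
#   # kept by the optimizer; the padding size is recorded so the caller can
#   # strip the padding again when reading back from a flat buffer
#   store.record_param_padding_size(working_param, padding_size=3)
#   store.link_master_and_working_param(master_param, working_param)
#
#   assert store.get_param_padding_size(working_param) == 3
#   assert store.master_to_working_param[id(master_param)] is working_param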