mirror of https://github.com/InternLM/InternLM
				
				
				
			
		
			
				
	
	
		
			285 lines
		
	
	
		
			8.6 KiB
		
	
	
	
		
			Python
		
	
	
			
		
		
	
	
			285 lines
		
	
	
		
			8.6 KiB
		
	
	
	
		
			Python
		
	
	
#!/usr/bin/env python
 | 
						|
# -*- encoding: utf-8 -*-
 | 
						|
 | 
						|
from typing import List
 | 
						|
 | 
						|
from torch import Tensor
 | 
						|
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
 | 
						|
 | 
						|
from internlm.core.context import ParallelMode
 | 
						|
from internlm.core.context import global_context as gpc
 | 
						|
 | 
						|
 | 
						|
class BaseStore:
    """
    Common base class for the stores defined in this module.

    On construction it asks the global context for the world size and the local
    rank of the given parallel mode, caches both, and exposes them through
    read-only properties.

    :param dp_parallel_mode: Parallel mode handed to the global context queries
    """

    def __init__(self, dp_parallel_mode=ParallelMode.DATA):
        # Both values are looked up once and cached on the instance.
        mode = dp_parallel_mode
        self._world_size = gpc.get_world_size(mode)
        self._local_rank = gpc.get_local_rank(mode)

    @property
    def world_size(self):
        """World size cached at construction time."""
        return self._world_size

    @property
    def local_rank(self):
        """Local rank cached at construction time."""
        return self._local_rank
 | 
						|
 | 
						|
 | 
						|
class BucketStore(BaseStore):
    """
    Per-rank bookkeeping of bucketed gradients, parameters and element counts.

    Every container is a dict keyed by ``reduce_rank``. The valid keys are
    ``None`` plus every integer in ``range(world_size)``; they are created by
    :meth:`reset`, so any other key raises :class:`KeyError`.

    :param dp_parallel_mode: Parallel mode forwarded to :class:`BaseStore`
    """

    def __init__(self, dp_parallel_mode):
        super().__init__(dp_parallel_mode)
        # Declared here so the attributes exist before reset() populates them.
        self._grads = {}
        self._params = {}
        self._num_elements_in_bucket = {}

        self.reset()

    def num_elements_in_bucket(self, reduce_rank: int = None):
        """Return the element count recorded for ``reduce_rank``."""
        counts = self._num_elements_in_bucket
        return counts[reduce_rank]

    def add_num_elements_in_bucket(self, num_elements, reduce_rank: int = None):
        """Increase the element count recorded for ``reduce_rank`` by ``num_elements``."""
        counts = self._num_elements_in_bucket
        counts[reduce_rank] += num_elements

    def add_grad(self, tensor, reduce_rank: int = None):
        """Append ``tensor`` to the gradient list of ``reduce_rank``."""
        grads_for_rank = self._grads[reduce_rank]
        grads_for_rank.append(tensor)

    def add_param(self, tensor, reduce_rank: int = None):
        """Append ``tensor`` to the parameter list of ``reduce_rank``."""
        params_for_rank = self._params[reduce_rank]
        params_for_rank.append(tensor)

    def reset(self):
        """Replace all three containers with fresh, empty per-rank entries."""
        grads, params, counts = {}, {}, {}
        # ``None`` is a valid key alongside each rank index.
        for rank in (None, *range(self._world_size)):
            grads[rank] = []
            params[rank] = []
            counts[rank] = 0

        self._grads = grads
        self._params = params
        self._num_elements_in_bucket = counts

    def reset_by_rank(self, reduce_rank=None):
        """Clear the gradients, parameters and element count of one ``reduce_rank``."""
        self._grads[reduce_rank] = []
        self._params[reduce_rank] = []
        self._num_elements_in_bucket[reduce_rank] = 0

    def get_grad(self, reduce_rank: int = None):
        """Return the (live) gradient list stored for ``reduce_rank``."""
        grads_for_rank = self._grads[reduce_rank]
        return grads_for_rank

    def get_param(self, reduce_rank: int = None):
        """Return the (live) parameter list stored for ``reduce_rank``."""
        params_for_rank = self._params[reduce_rank]
        return params_for_rank
 | 
						|
 | 
						|
 | 
						|
class GradientStore(BaseStore):
    """
    Store for averaged gradients, grouped by parameter-group index, and for the
    objects that must stay referenced while backward reduction hooks are in use.

    All positional arguments are forwarded unchanged to :class:`BaseStore`.
    """

    def __init__(self, *args):
        super().__init__(*args)
        # group_id -> list of averaged gradient tensors
        self._averaged_gradients = {}

        # strong references kept for the backward reduction hooks
        self._grad_acc_objs = []

    def add_accumulate_grad_object(self, obj):
        """
        Hold a reference to an :class:`AccumulateGrad` object. Without this reference the
        reduction hooks may not be attached successfully.

        :param obj: An object of :class:`AccumulateGrad` class
        :type obj: :class:`AccumulateGrad`
        """

        kept_objects = self._grad_acc_objs
        kept_objects.append(obj)

    def get_averaged_gradients_by_group(self, group_id: int) -> List[Tensor]:
        """
        Look up the averaged gradients recorded for one parameter group

        :param group_id: The index of parameter group
        :type group_id: int

        :return: The list of averaged gradients of the parameter group. Each element is a
            gradient, not a parameter.
        :rtype: List[torch.Tensor]
        :raises KeyError: If nothing has been recorded or reset for ``group_id`` yet
        """

        gradients = self._averaged_gradients[group_id]
        return gradients

    def add_average_gradient_by_group(self, group_id: int, tensor: Tensor) -> None:
        """
        Record one more averaged gradient for a parameter group, creating the group's
        list on first use

        :param group_id: The index of a parameter group
        :param tensor: A :class:`torch.Tensor` object
        :type group_id: int
        :type tensor: torch.Tensor

        """

        self._averaged_gradients.setdefault(group_id, []).append(tensor)

    def reset_average_gradients_by_group(self, group_id: int) -> None:
        """
        Replace the averaged-gradient list of a parameter group with a new empty list

        :param group_id: The index of a parameter group
        :type group_id: int
        """

        self._averaged_gradients[group_id] = []
 | 
						|
 | 
						|
 | 
						|
class ParameterStore(BaseStore):
    """
    Store for parameter partitioning and parameter reduction bookkeeping.

    It tracks which rank owns each parameter, the per-(rank, group) parameter lists and
    flat tensors, a per-parameter reduction flag, and the list of parameters whose
    gradients are cleared by :meth:`clear_grads_of_previous_reduced_params`.

    :param dp_paralle_mode: Parallel mode forwarded to :class:`BaseStore`
    """

    def __init__(self, dp_paralle_mode):
        super().__init__(dp_paralle_mode)
        # partitioning: parameter -> rank, and rank -> group_id -> list / flat tensor
        self._fp16_param_to_rank = {}
        self._rank_groupid_to_fp16_param_list = {}
        self._rank_group_id_to_flat_fp16_param = {}

        # reduction: parameter -> state, plus the previously reduced parameters
        self._is_param_reduced = {}
        self._reduced_param = []

    def set_param_to_rank(self, tensor: Tensor, rank: int) -> None:
        """
        Record which rank owns a parameter; each parameter should be owned by a rank.

        :param tensor: A :class:`torch.Tensor` object
        :type tensor: torch.Tensor
        :param rank: The rank of which the process is responsible for updating the parameter
        :type rank: int
        """

        ownership = self._fp16_param_to_rank
        ownership[tensor] = rank

    def get_param_rank(self, tensor: Tensor) -> int:
        """
        Return the rank recorded as owner of the parameter

        :param tensor: A :class:`torch.Tensor` object
        :type tensor: torch.Tensor
        :raises KeyError: If no rank was recorded for ``tensor``
        """
        owner = self._fp16_param_to_rank[tensor]
        return owner

    def belongs_to_current_rank(self, tensor) -> bool:
        """
        Tell whether the rank recorded for a parameter equals this store's local rank

        :param tensor: A :class:`torch.Tensor` object
        :type tensor: torch.Tensor

        :return: True if the parameter should be updated by the current rank. Otherwise false.
        :rtype: bool
        """

        return self._fp16_param_to_rank[tensor] == self._local_rank

    def add_fp16_param_list_by_rank_group(self, rank, group_id, tensor_list) -> None:
        """Extend the parameter list of ``(rank, group_id)``, creating missing levels first."""
        per_group = self._rank_groupid_to_fp16_param_list.setdefault(rank, {})
        per_group.setdefault(group_id, []).extend(tensor_list)

    def get_fp16_params_by_rank_group(self, rank, group_id) -> List[Tensor]:
        """Return the parameter list stored for ``(rank, group_id)``."""
        per_group = self._rank_groupid_to_fp16_param_list[rank]
        return per_group[group_id]

    def add_flat_fp16_param_by_rank_group(self, rank, group_id, tensor) -> None:
        """Store ``tensor`` as the flat parameter of ``(rank, group_id)``, overwriting any previous one."""
        per_group = self._rank_group_id_to_flat_fp16_param.setdefault(rank, {})
        per_group[group_id] = tensor

    def get_flat_fp16_param_by_rank_group(self, rank, group_id) -> Tensor:
        """Return the flat parameter stored for ``(rank, group_id)``."""
        per_group = self._rank_group_id_to_flat_fp16_param[rank]
        return per_group[group_id]

    def is_param_reduced(self, tensor):
        """Return the reduction state recorded for ``tensor``."""
        states = self._is_param_reduced
        return states[tensor]

    def set_param_reduction_state(self, tensor, state):
        """Record ``state`` as the reduction state of ``tensor``."""
        states = self._is_param_reduced
        states[tensor] = state

    def get_param_reduction_states(self):
        """Return the live parameter -> reduction-state mapping."""
        return self._is_param_reduced

    def reset_previous_reduced_params(self):
        """Replace the list of previously reduced parameters with a new empty list."""
        self._reduced_param = []

    def add_previous_reduced_param(self, tensor):
        """Append ``tensor`` to the list of previously reduced parameters."""
        pending = self._reduced_param
        pending.append(tensor)

    def clear_grads_of_previous_reduced_params(self):
        """Set ``grad`` to ``None`` on every previously reduced parameter, then forget them."""
        if not self._reduced_param:
            return

        for reduced in self._reduced_param:
            reduced.grad = None
        self.reset_previous_reduced_params()
 | 
						|
 | 
						|
 | 
						|
class TensorBucket:
    """
    A size-limited collection of tensors that can be flattened into one tensor
    and later refreshed from a flat tensor.

    The size is measured in elements (``tensor.numel()``), summed over all
    tensors currently in the bucket.

    :param size: Maximum number of elements the bucket is meant to hold
    """

    def __init__(self, size):
        self._max_size = size
        self._current_size = 0
        self._bucket = []

    @property
    def max_size(self):
        """Maximum number of elements given at construction."""
        return self._max_size

    @property
    def current_size(self):
        """Total number of elements of the tensors currently in the bucket."""
        return self._current_size

    def is_full_or_oversized(self):
        """Return True when the current size has reached or passed the maximum."""
        return self._current_size >= self._max_size

    def is_empty(self):
        """Return True when the bucket holds no tensors."""
        return not self._bucket

    def add_to_bucket(self, tensor, allow_oversize=False):
        """
        Append a tensor to the bucket and account for its element count.

        :param tensor: Tensor to add; its ``numel()`` is added to the current size
        :param allow_oversize: When True, skip the capacity check
        :type allow_oversize: bool
        :raises RuntimeError: If the tensor would push the bucket past its maximum
            size and ``allow_oversize`` is False
        """
        tensor_size = tensor.numel()

        if not allow_oversize and self.will_exceed_max_size(tensor_size):
            msg = f"The param bucket max size {self._max_size} is exceeded by tensor (size {tensor_size})"
            raise RuntimeError(msg)

        self._bucket.append(tensor)
        self._current_size += tensor_size

    def will_exceed_max_size(self, tensor_size):
        """Return True when adding ``tensor_size`` elements would exceed the maximum."""
        expected_size = self._current_size + tensor_size
        return expected_size > self._max_size

    def get_bucket(self):
        """Return the (live) list of tensors in the bucket."""
        return self._bucket

    def empty(self):
        """Drop all tensors and reset the element counter to zero."""
        self._bucket = []
        # The counter read by current_size / will_exceed_max_size must be reset
        # together with the list, otherwise a reused bucket reports stale size.
        self._current_size = 0

    def flatten(self):
        """Return the bucket's tensors flattened into one tensor via ``_flatten_dense_tensors``."""
        return _flatten_dense_tensors(self._bucket)

    def unflatten_and_copy(self, flat_tensor):
        """
        Split ``flat_tensor`` back into pieces matching the bucket's tensors and copy
        each piece into the corresponding bucket tensor in place.

        :param flat_tensor: Flat tensor to split using the bucket's tensors as templates
        """
        unflattened_tensor_list = _unflatten_dense_tensors(flat_tensor, self._bucket)
        for old, new in zip(self._bucket, unflattened_tensor_list):
            old.copy_(new)
 |