mirror of https://github.com/InternLM/InternLM
				
				
				
			
		
			
				
	
	
		
			126 lines
		
	
	
		
			4.9 KiB
		
	
	
	
		
			Python
		
	
	
			
		
		
	
	
			126 lines
		
	
	
		
			4.9 KiB
		
	
	
	
		
			Python
		
	
	
| # adopted from https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/communication
 | |
| 
 | |
| from typing import List, Tuple, Union
 | |
| 
 | |
| import torch
 | |
| import torch.distributed as dist
 | |
| 
 | |
| from internlm.core.context import ParallelMode
 | |
| from internlm.core.context import global_context as gpc
 | |
| from internlm.utils.common import get_current_device
 | |
| 
 | |
| TensorShape = Union[torch.Size, List[int], Tuple[int]]
 | |
| 
 | |
| 
 | |
def send_meta_helper(obj, next_rank, tensor_kwargs):
    """Send the dimension count and then the shape of ``obj`` to ``next_rank``.

    Two point-to-point messages are issued, in this order: a scalar tensor
    holding ``len(obj.size())`` and a 1-D tensor holding ``obj.size()``.

    Args:
        obj: Object exposing ``size()`` (e.g. a :class:`torch.Tensor`) whose shape is sent.
        next_rank (int): Destination rank passed to ``dist.send``.
        tensor_kwargs (dict): Keyword arguments forwarded to ``torch.tensor`` when
            building the two messages (the callers in this file pass dtype and device).
    """
    shape = obj.size()
    shape_msg = torch.tensor(shape, **tensor_kwargs)
    ndims_msg = torch.tensor(len(shape), **tensor_kwargs)
    # The dimension count goes first so the peer can size its shape buffer.
    dist.send(ndims_msg, next_rank)
    dist.send(shape_msg, next_rank)
 | |
| 
 | |
| 
 | |
def send_obj_meta(obj, next_rank=None):
    """Send the shape metadata of ``obj`` ahead of the actual payload.

    The receiver of a p2p transfer has to know the shape of what is coming, so
    this sends, in order: the number of tensors, and then for every tensor its
    dimension count followed by its shape. It is the sending counterpart of
    :func:`recv_obj_meta`.

    Args:
        obj (Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]): A single
            tensor, or a sized collection of tensors, whose shape(s) are sent.
        next_rank (int, optional): Destination rank. When ``None``, the next global
            rank of the pipeline parallel group is used.

    Returns:
        None
    """
    if next_rank is None:
        next_rank = gpc.get_next_global_rank(ParallelMode.PIPELINE)

    meta_kwargs = dict(dtype=torch.long, device=get_current_device())

    # A bare tensor is announced with a count of 1, exactly like a
    # one-element collection would be.
    tensors = [obj] if isinstance(obj, torch.Tensor) else obj

    count_msg = torch.tensor(len(tensors), **meta_kwargs)
    dist.send(count_msg, next_rank)
    for item in tensors:
        send_meta_helper(item, next_rank, meta_kwargs)
 | |
| 
 | |
| 
 | |
def recv_meta_helper(prev_rank, tensor_kwargs):
    """Receive one shape description from ``prev_rank``.

    Mirrors :func:`send_meta_helper`: first a scalar with the number of
    dimensions is received, then a 1-D tensor of that length with the shape.

    Args:
        prev_rank (int): Source rank passed to ``dist.recv``.
        tensor_kwargs (dict): Keyword arguments forwarded to ``torch.empty`` when
            allocating the receive buffers.

    Returns:
        :class:`torch.Tensor`: 1-D tensor containing the received shape.
    """
    ndims_buf = torch.empty((), **tensor_kwargs)
    dist.recv(ndims_buf, prev_rank)

    # The shape buffer can only be allocated once the dimension count is known.
    shape_buf = torch.empty(ndims_buf.item(), **tensor_kwargs)
    dist.recv(shape_buf, prev_rank)
    return shape_buf
 | |
| 
 | |
| 
 | |
def recv_obj_meta(prev_rank=None) -> Union[torch.Size, List[torch.Size]]:
    """Receive shape metadata before receiving the actual payload.

    The receiver of a p2p transfer has to know the shape of what is coming, so
    this reads, in order: the number of tensors, and then for every tensor its
    dimension count followed by its shape. It is the receiving counterpart of
    :func:`send_obj_meta`.

    Args:
        prev_rank (int, optional): Source rank. When ``None``, the previous global
            rank of the pipeline parallel group is used.

    Returns:
        Union[:class:`torch.Size`, List[:class:`torch.Size`]]: A single
        ``torch.Size`` when the announced count is 1, otherwise a list with one
        ``torch.Size`` per announced tensor (empty when the count is 0).
    """
    if prev_rank is None:
        prev_rank = gpc.get_prev_global_rank(ParallelMode.PIPELINE)

    tensor_kwargs = {"dtype": torch.long, "device": get_current_device()}
    recv_obj_nums = torch.empty((), **tensor_kwargs)
    dist.recv(recv_obj_nums, prev_rank)

    # Convert the received count to a Python number once and reuse it, rather
    # than calling .item() again for the loop bound.
    num_objs = recv_obj_nums.item()

    if num_objs == 1:
        return torch.Size(recv_meta_helper(prev_rank, tensor_kwargs))

    return [torch.Size(recv_meta_helper(prev_rank, tensor_kwargs)) for _ in range(num_objs)]
 | |
| 
 | |
| 
 | |
def split_tensor_into_1d_equal_chunks(tensor: torch.Tensor, new_buffer=False) -> torch.Tensor:
    """Return this rank's slice of ``tensor`` viewed as a flat 1-D tensor.

    The flattened tensor is divided into ``world_size`` pieces of
    ``numel // world_size`` elements each (tensor-parallel world size), and the
    piece indexed by the local tensor-parallel rank is returned. Because the
    piece length uses floor division, any remainder elements at the end of the
    flattened tensor belong to no piece.

    Args:
        tensor (:class:`torch.Tensor`): Tensor to be split before communication.
        new_buffer (bool, optional): If True, copy the slice into a freshly
            allocated tensor on the current CUDA device; otherwise return a view
            into ``tensor``.

    Returns:
        :class:`torch.Tensor`: The 1-D slice owned by this rank.
    """
    chunk_len = tensor.numel() // gpc.get_world_size(ParallelMode.TENSOR)
    begin = chunk_len * gpc.get_local_rank(ParallelMode.TENSOR)
    stop = begin + chunk_len

    if not new_buffer:
        # No copy: the result shares storage with the input.
        return tensor.view(-1)[begin:stop]

    fresh = torch.empty(
        chunk_len,
        dtype=tensor.dtype,
        device=torch.cuda.current_device(),
        requires_grad=False,
    )
    fresh.copy_(tensor.view(-1)[begin:stop])
    return fresh
 | |
| 
 | |
| 
 | |
def gather_split_1d_tensor(tensor: torch.Tensor) -> torch.Tensor:
    """Inverse of :func:`split_tensor_into_1d_equal_chunks`: all-gather the slices.

    Allocates a flat buffer of ``world_size * numel(tensor)`` elements on the
    current CUDA device and all-gathers ``tensor`` from every rank of the tensor
    parallel group into consecutive, equally sized windows of that buffer.

    Args:
        tensor (:class:`torch.Tensor`): Tensor to be gathered after communication.

    Returns:
        :class:`torch.Tensor`: The gathered 1-D tensor.
    """
    tp_size = gpc.get_world_size(ParallelMode.TENSOR)
    local_numel = tensor.numel()

    output = torch.empty(
        tp_size * local_numel,
        dtype=tensor.dtype,
        device=torch.cuda.current_device(),
        requires_grad=False,
    )

    # Build one view per rank so the collective writes straight into `output`.
    slots = []
    for rank_idx in range(tp_size):
        offset = rank_idx * local_numel
        slots.append(output[offset : offset + local_numel])

    dist.all_gather(slots, tensor, group=gpc.get_group(ParallelMode.TENSOR))
    return output
 |