mirror of https://github.com/InternLM/InternLM
#!/usr/bin/env python
# -*- encoding: utf-8 -*-

import torch

from internlm.core.context import global_context as gpc


def _split(input_, parallel_mode, dim=-1):
    """Split ``input_`` along ``dim`` across the ranks of ``parallel_mode`` and
    return the slice owned by the local rank."""
    # skip if only one rank involved
    world_size = gpc.get_world_size(parallel_mode)
    if world_size == 1:
        return input_

    # Split along the given dimension (defaults to the last one).
    dim_size = input_.size(dim)
    assert dim_size % world_size == 0, (
        f"The dimension to split ({dim_size}) is not a multiple of world size ({world_size}), "
        f"cannot split tensor evenly"
    )

    tensor_list = torch.split(input_, dim_size // world_size, dim=dim)
    rank = gpc.get_local_rank(parallel_mode)
    output = tensor_list[rank].contiguous()

    return output


def _gather(input_, parallel_mode, dim=-1):
    """All-gather the per-rank shards of ``input_`` over ``parallel_mode`` and
    concatenate them along ``dim``."""
    # skip if only one rank involved
    world_size = gpc.get_world_size(parallel_mode)
    if world_size == 1:
        return input_

    # all gather
    rank = gpc.get_local_rank(parallel_mode)
    tensor_list = [torch.empty_like(input_) for _ in range(world_size)]
    tensor_list[rank] = input_
    group = gpc.get_cpu_group(parallel_mode) if input_.device.type == "cpu" else gpc.get_group(parallel_mode)
    torch.distributed.all_gather(tensor_list, input_, group=group)

    # concat
    output = torch.cat(tensor_list, dim=dim).contiguous()

    return output
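
# Shape contract of the two helpers above (illustrative only, assuming a
# tensor-parallel group of 4 ranks): if every rank holds a local shard of shape
# (batch, 256), _gather(shard, mode, dim=-1) returns the full (batch, 1024)
# tensor on each rank, while _split of that full tensor hands back only this
# rank's (batch, 256) slice, so the two helpers act as inverses along ``dim``.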


class _GatherForwardSplitBackward(torch.autograd.Function):
    """Gather the input from the model parallel region and concatenate.

    Args:
        input_: input matrix.
        parallel_mode: parallel mode.
        dim: dimension along which to gather in the forward pass (and split in the backward pass).
    """

    @staticmethod
    def symbolic(input_):
        return _gather(input_, parallel_mode=None)

    @staticmethod
    def forward(ctx, input_, parallel_mode, dim):
        ctx.mode = parallel_mode
        ctx.dim = dim
        return _gather(input_, parallel_mode, dim)

    @staticmethod
    def backward(ctx, grad_output):
        # The two trailing Nones match the non-tensor inputs (parallel_mode, dim), which need no gradient.
        return _split(grad_output, ctx.mode, ctx.dim), None, None


def gather_forward_split_backward(input_, parallel_mode, dim):
    """Gather ``input_`` across ``parallel_mode`` along ``dim`` in the forward pass,
    and split the incoming gradient back to the local shard in the backward pass."""
    return _GatherForwardSplitBackward.apply(input_, parallel_mode, dim)
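

if __name__ == "__main__":
    # Minimal usage sketch, not part of the library proper: it assumes the InternLM
    # distributed context has already been initialised elsewhere (e.g. via the
    # internlm launcher) and that ParallelMode.TENSOR names the tensor-parallel group.
    from internlm.core.context import ParallelMode

    tp_size = gpc.get_world_size(ParallelMode.TENSOR)
    # Each rank holds one shard of the hidden dimension.
    local_shard = torch.randn(4, 1024 // tp_size)
    # Forward pass sees the full (4, 1024) activation on every rank; in the
    # backward pass each rank receives only the gradient slice for its own shard.
    full_activation = gather_forward_split_backward(local_shard, ParallelMode.TENSOR, dim=-1)
    print(full_activation.shape)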