mirror of https://github.com/InternLM/InternLM
				
				
				
feat(model/overlap_handler.py): fix lint error

parent f6a5086fe4
commit 0d693cf3a1
@@ -53,7 +53,6 @@ class MoE(torch.nn.Module):
         device=None,
         dtype=None,
     ):
-
         super().__init__()

         assert (
				
			
internlm/model/overlap_handler.py

@@ -10,7 +10,10 @@ from internlm.core.context import global_context as gpc
 from internlm.core.naive_amp import NaiveAMPModel
 from internlm.model.embedding import Embedding1D
 from internlm.model.linear import FSTPLinear, ScaleColumnParallelLinear
-from internlm.model.utils import all_gather_raw_memory_pool, all_gather_raw_bias_memory_pool
+from internlm.model.utils import (
+    all_gather_raw_bias_memory_pool,
+    all_gather_raw_memory_pool,
+)
 from internlm.utils.common import get_current_device
				
			
@@ -25,7 +28,7 @@ class FSTPOverlapHandler:
         self.fstp_modules = []
         self.module_name = ["Wqkv", "out_proj", "w1", "w2", "w3"]
         self.fstp_global_handle = dict()  # key: fstp module; value: module global all-gather op handle
-        self.bias_global_handle = dict() # key: fstp module; value: module bias global all-gather op handle
+        self.bias_global_handle = dict()  # key: fstp module; value: module bias global all-gather op handle
         self.module_to_index = dict()  # key: fstp module; value: transformer block index
         self.index_to_fstp_modules = dict()  # key: transformer block index; value: fsdp modules
         self.head = []
				
			
@@ -77,13 +80,13 @@ class FSTPOverlapHandler:
             self.zero_const_pool[size] = torch.zeros(*size, dtype=dtype, device=device).contiguous()

         return self.zero_const_pool[size]
-    
+
     def _initialize_module_shape(self):
         hidden_size = gpc.config.HIDDEN_SIZE
         mlp_ratio = gpc.config.MLP_RATIO
         mlp_hidden_size = int(hidden_size * mlp_ratio)
         mlp_hidden_size = 256 * ((mlp_hidden_size + 256 - 1) // 256)
-        
+
         self.module_shape["Wqkv"] = (3 * hidden_size, hidden_size)
         self.module_shape["out_proj"] = (hidden_size, hidden_size)
         self.module_shape["w1"] = (mlp_hidden_size, hidden_size)
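Note: _initialize_module_shape ceil-rounds the MLP width up to the next multiple of 256 before sizing the w1/w2/w3 buffers. A minimal sketch of that arithmetic with illustrative values (hidden_size=4096 and mlp_ratio=8/3 are assumptions here, not values taken from this diff):

# Illustrative config values; the real ones come from gpc.config.
hidden_size = 4096
mlp_ratio = 8 / 3

mlp_hidden_size = int(hidden_size * mlp_ratio)                # 10922
mlp_hidden_size = 256 * ((mlp_hidden_size + 256 - 1) // 256)  # ceil(10922 / 256) * 256

assert mlp_hidden_size % 256 == 0
print(mlp_hidden_size)  # 11008, the padded width used for the w1/w2/w3 pool shapes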
				
			
@@ -96,7 +99,7 @@ class FSTPOverlapHandler:
         self.all_gather_bias_memory_pool = []
         self.reduce_scatter_memory_pool = {}
         self.module_shape = {}
-        
+
         self._initialize_module_shape()
         dtype = gpc.config.model.get("dtype", torch.half)
         device = get_current_device()
				
			
@@ -107,10 +110,14 @@ class FSTPOverlapHandler:
                 weight[name] = torch.zeros(self.module_shape[name], dtype=dtype, device=device).contiguous()
             self.all_gather_memory_pool.append(weight)  # containing two groups of block weight

+    def clear_memory_pool(self) -> None:
+        self.zero_const_pool = {}
+        self.reduce_scatter_memory_pool = {}
+
     def get_all_gather_memory(self, module):
         block_index = self.module_to_index[module]
         return self.all_gather_memory_pool[block_index % 2][module._fstp_name]
-    
+
     def get_bias_memory(self, module: nn.Module):
         block_index = self.module_to_index[module]
         # if the bias memory pool is empty or the module has not been allocated memory
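Note: the all-gather memory pool holds two groups of per-module buffers, and transformer blocks alternate between them via block_index % 2, so one block's gathered weights can still be consumed while the next block's all-gather lands in the other group. A self-contained sketch of that double-buffered indexing (module names and shapes below are illustrative, not the handler's real state):

import torch

# Two buffer groups, each mapping an fstp module name to a preallocated tensor.
module_shape = {"Wqkv": (12, 4), "out_proj": (4, 4)}  # toy shapes
all_gather_memory_pool = [
    {name: torch.zeros(shape) for name, shape in module_shape.items()} for _ in range(2)
]

def get_all_gather_memory(block_index: int, name: str) -> torch.Tensor:
    # Even-indexed blocks reuse group 0, odd-indexed blocks reuse group 1.
    return all_gather_memory_pool[block_index % 2][name]

assert get_all_gather_memory(0, "Wqkv") is get_all_gather_memory(2, "Wqkv")
assert get_all_gather_memory(0, "Wqkv") is not get_all_gather_memory(1, "Wqkv")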
				
			
@@ -119,19 +126,20 @@ class FSTPOverlapHandler:
             for _ in range(2):
                 weight = {}
                 weight[module._fstp_name] = torch.zeros(
-                                                    self.module_shape[module._fstp_name][0], 
-                                                    dtype=gpc.config.model.get("dtype", torch.half),
-                                                    device=get_current_device()).contiguous()
+                    self.module_shape[module._fstp_name][0],
+                    dtype=gpc.config.model.get("dtype", torch.half),
+                    device=get_current_device(),
+                ).contiguous()
                 self.all_gather_bias_memory_pool.append(weight)
         elif module._fstp_name not in self.all_gather_bias_memory_pool[0]:
             for i in range(2):
                 self.all_gather_bias_memory_pool[i][module._fstp_name] = torch.zeros(
-                                                    self.module_shape[module._fstp_name][0], 
-                                                    dtype=gpc.config.model.get("dtype", torch.half),
-                                                    device=get_current_device()).contiguous()
-        
+                    self.module_shape[module._fstp_name][0],
+                    dtype=gpc.config.model.get("dtype", torch.half),
+                    device=get_current_device(),
+                ).contiguous()
+
         return self.all_gather_bias_memory_pool[block_index % 2][module._fstp_name]
-                
+
     def get_reduce_scatter_memory(self, key):
         return_idx = 0
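Note: unlike the weight pool, bias buffers are allocated lazily inside get_bias_memory, since only some fstp modules carry a bias; the first request for a module creates its buffer in both groups and later calls just reuse it. A reduced sketch of that lazy, double-buffered allocation (names and shapes are illustrative):

import torch

module_shape = {"out_proj": (4, 4)}   # toy shape; the bias length is shape[0]
all_gather_bias_memory_pool = []      # two dicts, created on first use

def get_bias_memory(name: str, block_index: int) -> torch.Tensor:
    if len(all_gather_bias_memory_pool) == 0:
        # First bias request ever: create both buffer groups.
        for _ in range(2):
            all_gather_bias_memory_pool.append({name: torch.zeros(module_shape[name][0])})
    elif name not in all_gather_bias_memory_pool[0]:
        # First request for this particular module: add it to both groups.
        for i in range(2):
            all_gather_bias_memory_pool[i][name] = torch.zeros(module_shape[name][0])
    return all_gather_bias_memory_pool[block_index % 2][name]

buf = get_bias_memory("out_proj", block_index=0)
assert buf is get_bias_memory("out_proj", block_index=2)  # reused, not reallocated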
				
			
@@ -170,7 +178,7 @@ class FSTPOverlapHandler:

     def release_reduce_scatter_memory(self, key, index):
         self.reduce_scatter_memory_pool[key][index].idle = True
-   
+
     def _all_gather_block_weight_memory_pool(self, block_index: int):
         fstp_modules = self.index_to_fstp_modules[block_index]
         for module in fstp_modules:
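Note: release_reduce_scatter_memory does not free anything; it only flips an idle flag so the buffer can be handed out again for the same output shape. The getter (whose body is mostly outside this hunk) presumably scans the per-key list for an idle buffer and grows the list when none is free. A minimal sketch of that assumed pattern:

import torch

class Buffer:
    def __init__(self, shape):
        self.data = torch.zeros(shape)
        self.idle = True

reduce_scatter_memory_pool = {}  # key: output shape; value: list of Buffer

def get_reduce_scatter_memory(key):
    pool = reduce_scatter_memory_pool.setdefault(key, [])
    for index, buf in enumerate(pool):   # reuse the first idle buffer of this shape
        if buf.idle:
            buf.idle = False
            return index, buf
    pool.append(Buffer(key))             # none idle: grow the pool
    pool[-1].idle = False
    return len(pool) - 1, pool[-1]

def release_reduce_scatter_memory(key, index):
    # Mark the buffer reusable; the memory itself stays allocated for later use.
    reduce_scatter_memory_pool[key][index].idle = True

idx, buf = get_reduce_scatter_memory((1024,))
release_reduce_scatter_memory((1024,), idx)
_, buf2 = get_reduce_scatter_memory((1024,))
assert buf2 is buf  # the released buffer is reused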
				
			
@@ -182,7 +190,7 @@ class FSTPOverlapHandler:
                     module=module,
                 )
                 self.bias_global_handle[module] = bias_handle
-                
+
             weight_handle = all_gather_raw_memory_pool(
                 module.weight,
                 self.process_group,
				
			
internlm/model/utils.py

@@ -140,6 +140,7 @@ def all_gather_raw_memory_pool(
     )
     return handle

+
 def all_gather_raw_bias_memory_pool(
     input_: Tensor,
     process_group: ProcessGroup,
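Note: all_gather_raw_memory_pool and all_gather_raw_bias_memory_pool (bodies not shown in this diff) gather a sharded parameter into the handler's preallocated pool buffer and hand back the communication handle, so the caller can overlap the gather with compute and wait on it later. A rough standalone sketch of that pattern built directly on torch.distributed; this is an assumption about the mechanics, not the helpers' actual implementation:

import os
import torch
import torch.distributed as dist

# Single-process bootstrap so the sketch runs as-is; in real training the group already exists.
os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29501")
dist.init_process_group("gloo", rank=0, world_size=1)

world_size = dist.get_world_size()
shard = torch.randn(4, 8)                     # this rank's slice of a weight
pool_buffer = torch.empty(world_size * 4, 8)  # preallocated slot in the all-gather pool

# Gather every rank's shard into consecutive chunks of the pooled buffer, asynchronously.
chunks = list(pool_buffer.chunk(world_size, dim=0))
handle = dist.all_gather(chunks, shard, async_op=True)

# ... compute on previously gathered weights could run here ...
handle.wait()
assert torch.equal(pool_buffer[:4], shard)
dist.destroy_process_group()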
				
			
			
train.py
@@ -298,8 +298,7 @@ def main(args):
             prof.step()

             if gpc.fstp_handler is not None:
-                gpc.fstp_handler.zero_const_pool = {}
-                gpc.fstp_handler.reduce_scatter_memory_pool = {}
+                gpc.fstp_handler.clear_memory_pool()
             # torch.cuda.memory._dump_snapshot(f"my_snapshot_{gpc.get_global_rank()}.pickle")
             torch.cuda.reset_peak_memory_stats()

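Note: clear_memory_pool only resets the transient pools (zero constants and reduce-scatter buffers), so whatever they cached during a step is released instead of being carried over; the double-buffered all-gather pool is kept. A toy illustration of the per-step reset (the handler below is a stand-in, not the real FSTPOverlapHandler):

import torch

class ToyHandler:
    def __init__(self):
        self.zero_const_pool = {}
        self.reduce_scatter_memory_pool = {}

    def clear_memory_pool(self) -> None:
        self.zero_const_pool = {}
        self.reduce_scatter_memory_pool = {}

handler = ToyHandler()
for step in range(3):
    # Pretend this step cached a reduce-scatter buffer of some shape.
    handler.reduce_scatter_memory_pool[(1024, step + 1)] = [torch.zeros(1024, step + 1)]
    handler.clear_memory_pool()  # what train.py now calls once per training step
    assert not handler.reduce_scatter_memory_pool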
				
			
			