mirror of https://github.com/InternLM/InternLM
				
				
				
			Merge branch 'feat/fstp_refactor' of https://github.com/yingtongxiong/InternLM into feat/fstp_refactor
merge originpull/456/head
						commit
						0d3592a53f
					
				| 
						 | 
				
			
			@ -244,7 +244,8 @@ class FSTPOverlapHandler:
 | 
			
		|||
            self.fstp_global_handle[first_backward_module] = weight_handle
 | 
			
		||||
 | 
			
		||||
        def _pre_backward_hook_for_head(module: nn.Module, grad_output):
 | 
			
		||||
            self._all_gather_block_weight_memory_pool(gpc.config.NUM_LAYER - 1)
 | 
			
		||||
            if self.is_forward is False:
 | 
			
		||||
                self._all_gather_block_weight_memory_pool(gpc.config.NUM_LAYER - 1)
 | 
			
		||||
 | 
			
		||||
        def _pre_backward_hook_for_module(module: nn.Module, grad_output):  # pylint: disable=W0613
 | 
			
		||||
            # wait handle for current module
 | 
			
		||||
| 
						 | 
				
			
			@ -276,7 +277,7 @@ class FSTPOverlapHandler:
 | 
			
		|||
        for embedding in self.embedding:
 | 
			
		||||
            embedding.register_forward_hook(_post_forward_hook_for_embedding)
 | 
			
		||||
 | 
			
		||||
        if self.model_checkpoint and self.is_forward is False:
 | 
			
		||||
        if self.model_checkpoint:
 | 
			
		||||
            for head in self.head:
 | 
			
		||||
                head.register_full_backward_pre_hook(_pre_backward_hook_for_head)
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -291,7 +292,7 @@ class FSTPOverlapHandler:
 | 
			
		|||
        # 1. register post_backward_hook @head module to prefetch for the last block's last module
 | 
			
		||||
        # 2. register pre_backward_hook @fstp_module to wait handle for current module and to prefetch for next module
 | 
			
		||||
        # 3. register post_backward_hook @fstp_module to release resource
 | 
			
		||||
        if gpc.config.model.checkpoint is False:
 | 
			
		||||
        if self.model_checkpoint is False:
 | 
			
		||||
            for head in self.head:
 | 
			
		||||
                head.register_full_backward_hook(_post_backward_hook_for_head)
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
		Loading…
	
		Reference in New Issue