mirror of https://github.com/InternLM/InternLM

fix reduce scatter async bug

parent 229cc5c68c
commit 6682f5d92a
@@ -371,12 +371,12 @@ class FSTPFusedDenseFunc(torch.autograd.Function):
                 grad_weight_async, handle_grad_weight = reduce_scatter_raw(grad_weight, process_group, async_op=True)
                 assert hasattr(weight, "_fstp_reduce_scatter_str")
                 all_gather_handler.reduce_scatter_handlers[weight._fstp_reduce_scatter_str] = (handle_grad_weight, grad_weight_async)
-                grad_weight = torch.empty(grad_weight.shape[0]//torch.distributed.get_world_size(process_group), *grad_weight.shape[1:], dtype=grad_weight.dtype, device=grad_weight.device)
+                grad_weight = torch.zeros(grad_weight.shape[0]//torch.distributed.get_world_size(process_group), *grad_weight.shape[1:], dtype=grad_weight.dtype, device=grad_weight.device)
                 if grad_bias is not None:
                     grad_bias_async, handle_grad_bias = reduce_scatter_raw(grad_bias, process_group, async_op=True)
                     assert hasattr(bias, "_fstp_reduce_scatter_str")
                     all_gather_handler.reduce_scatter_handlers[bias._fstp_reduce_scatter_str] = (handle_grad_bias, grad_bias_async)
-                    grad_bias = torch.empty(grad_bias.shape[0]//torch.distributed.get_world_size(process_group), *grad_bias.shape[1:], dtype=grad_bias.dtype, device=grad_bias.device)
+                    grad_bias = torch.zeros(grad_bias.shape[0]//torch.distributed.get_world_size(process_group), *grad_bias.shape[1:], dtype=grad_bias.dtype, device=grad_bias.device)
         else:
             grad_weight = None
             grad_bias = grad_output if ctx.needs_input_grad[2] else None
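The hunk above swaps torch.empty for torch.zeros when allocating the sharded gradient tensor that backward returns while the real reduce-scatter result is still in flight. Below is a minimal sketch of that pattern, using torch.distributed.reduce_scatter_tensor directly instead of the repository's reduce_scatter_raw helper; the registry dict and function name are illustrative, not InternLM's API.

import torch
import torch.distributed as dist

# Illustrative registry: key -> (async work handle, reduce-scattered gradient shard).
reduce_scatter_handlers = {}

def launch_async_reduce_scatter(key: str, full_grad: torch.Tensor, process_group):
    world_size = dist.get_world_size(process_group)
    # Output buffer for the collective itself; reduce_scatter overwrites it entirely,
    # so torch.empty is fine for this tensor.
    shard = torch.empty(
        full_grad.shape[0] // world_size, *full_grad.shape[1:],
        dtype=full_grad.dtype, device=full_grad.device,
    )
    handle = dist.reduce_scatter_tensor(shard, full_grad, group=process_group, async_op=True)
    reduce_scatter_handlers[key] = (handle, shard)
    # What backward hands back to autograd is only a placeholder of the sharded shape.
    # It must be zeros (the change above): autograd may accumulate it into .grad, and
    # the optimizer later adds the real shard with `+=`, so uninitialized memory would
    # corrupt the gradient.
    return torch.zeros_like(shard)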
@@ -333,7 +333,7 @@ class HybridZeroOptimizer(BaseOptimizer):
                     key = getattr(_param, "_fstp_reduce_scatter_str")
                     comm_handle, _grad = self._fstp_handler.reduce_scatter_handlers[key]
                     comm_handle.wait()
-                    _param.grad = _grad
+                    _param.grad += _grad

                 bucket.reset_by_rank(rank)

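On the consuming side, shown in this and the next HybridZeroOptimizer hunk, the stored handle is waited on and the shard is now accumulated into the parameter's gradient instead of overwriting it. A hedged sketch of that step, assuming the same illustrative registry as in the sketch above:

import torch

def accumulate_reduce_scatter_result(param, reduce_scatter_handlers):
    # Same lookup pattern as the diff: the parameter carries a string tag that keys
    # the (handle, shard) pair recorded when backward launched the reduce-scatter.
    key = getattr(param, "_fstp_reduce_scatter_str")
    comm_handle, grad_shard = reduce_scatter_handlers[key]
    comm_handle.wait()  # block until the asynchronous reduce-scatter has finished
    if param.grad is None:
        param.grad = torch.zeros_like(grad_shard)
    # Accumulate rather than assign: `param.grad = grad_shard` would discard whatever
    # autograd (or an earlier micro-step) has already put into .grad.
    param.grad += grad_shard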
@@ -356,7 +356,7 @@ class HybridZeroOptimizer(BaseOptimizer):
                 key = getattr(_param, "_fstp_reduce_scatter_str")
                 comm_handle, _grad = self._fstp_handler.reduce_scatter_handlers[key]
                 comm_handle.wait()
-                _param.grad = _grad
+                _param.grad += _grad

             # reduce grad
             if self.skip_grad_reduce is False:
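This second optimizer hunk applies the same `=` to `+=` change on another code path. A single-process illustration (no distributed setup, made-up values) of why the pair of fixes matters under gradient accumulation:

import torch

param_grad = torch.zeros(4)          # stands in for _param.grad after the zero
                                     # placeholder from backward was accumulated into it

shard_step1 = torch.full((4,), 1.0)  # pretend reduce-scatter result of micro-step 1
param_grad += shard_step1            # fixed behaviour: accumulate

shard_step2 = torch.full((4,), 2.0)  # micro-step 2 arrives later
param_grad += shard_step2

assert torch.equal(param_grad, torch.full((4,), 3.0))  # both contributions kept
# With `param_grad = shard_step2` the first micro-step would be silently dropped, and
# with a torch.empty placeholder uninitialized values would leak into the sum.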