diff --git a/internlm/model/utils.py b/internlm/model/utils.py
index 8fddf1e..3c0dd9b 100644
--- a/internlm/model/utils.py
+++ b/internlm/model/utils.py
@@ -80,6 +80,7 @@ def gather_forward_split_backward(input_, parallel_mode, dim):
     return _GatherForwardSplitBackward.apply(input_, parallel_mode, dim)
 
 
+# the following communicators are adapted from https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/utils/distributed.py
 def all_gather_raw(input_: Tensor, process_group: ProcessGroup, async_op: bool = False):
     world_size = torch.distributed.get_world_size(process_group)
     output = torch.empty(world_size * input_.shape[0], *input_.shape[1:], dtype=input_.dtype, device=input_.device)
@@ -147,6 +148,7 @@ def linear_bias_wgrad_torch(input, grad_output, has_d_bias):
     return grad_weight, grad_bias
 
 
+# adapted from https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/ops/fused_dense.py
 class FusedDenseFuncTorch(torch.autograd.Function):
     @staticmethod
     @custom_fwd
@@ -223,6 +225,7 @@ class FusedDenseFuncTorch(torch.autograd.Function):
             assert ctx.compute_weight_gradient
             if process_group is not None and sequence_parallel:
                 handle_x.wait()
+            # unlike flash_attn, we drop the fused CUDA kernel and compute the weight gradient in pure PyTorch
             grad_weight, grad_bias = linear_bias_wgrad_torch(
                 total_x.reshape(batch_dim, total_x.shape[-1]), grad_output, ctx.needs_input_grad[2]
            )
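
For context on the comment added in the last hunk: the diff only shows the signature and return statement of linear_bias_wgrad_torch, not its body. The sketch below is an assumed reconstruction of what such a pure-PyTorch weight-gradient helper typically computes in place of flash_attn's fused CUDA kernel (fused_dense_cuda.linear_bias_wgrad); the exact implementation in internlm/model/utils.py may differ.

import torch


def linear_bias_wgrad_torch(input, grad_output, has_d_bias):
    # Assumed sketch: both tensors are expected to be 2D, shaped
    # (tokens, in_features) and (tokens, out_features) respectively,
    # as in the backward pass of FusedDenseFuncTorch above.
    assert input.dtype == grad_output.dtype
    # dW = dY^T @ X -> shape (out_features, in_features), matching W
    grad_weight = torch.matmul(grad_output.t(), input)
    # db = sum of dY over the token dimension, only if the layer has a bias
    grad_bias = grad_output.sum(dim=0) if has_d_bias else None
    return grad_weight, grad_bias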