mirror of https://github.com/InternLM/InternLM
add adaptation in model/utils.py
parent
e551b4dffc
commit
0666aa97e4
|
@ -80,6 +80,7 @@ def gather_forward_split_backward(input_, parallel_mode, dim):
|
|||
return _GatherForwardSplitBackward.apply(input_, parallel_mode, dim)
|
||||
|
||||
|
||||
# the following communicators are adapted from https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/utils/distributed.py
|
||||
def all_gather_raw(input_: Tensor, process_group: ProcessGroup, async_op: bool = False):
|
||||
world_size = torch.distributed.get_world_size(process_group)
|
||||
output = torch.empty(world_size * input_.shape[0], *input_.shape[1:], dtype=input_.dtype, device=input_.device)
|
||||
|
@ -147,6 +148,7 @@ def linear_bias_wgrad_torch(input, grad_output, has_d_bias):
|
|||
return grad_weight, grad_bias
|
||||
|
||||
|
||||
# adpated from https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/ops/fused_dense.py
|
||||
class FusedDenseFuncTorch(torch.autograd.Function):
|
||||
@staticmethod
|
||||
@custom_fwd
|
||||
|
@ -223,6 +225,7 @@ class FusedDenseFuncTorch(torch.autograd.Function):
|
|||
assert ctx.compute_weight_gradient
|
||||
if process_group is not None and sequence_parallel:
|
||||
handle_x.wait()
|
||||
# we remove the cuda independence, which is different from flash_attn.
|
||||
grad_weight, grad_bias = linear_bias_wgrad_torch(
|
||||
total_x.reshape(batch_dim, total_x.shape[-1]), grad_output, ctx.needs_input_grad[2]
|
||||
)
|
||||
|
|
Loading…
Reference in New Issue