mirror of https://github.com/InternLM/InternLM
add adaptation in model/utils.py
parent e551b4dffc
commit 0666aa97e4
@@ -80,6 +80,7 @@ def gather_forward_split_backward(input_, parallel_mode, dim):
     return _GatherForwardSplitBackward.apply(input_, parallel_mode, dim)


+# the following communicators are adapted from https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/utils/distributed.py
 def all_gather_raw(input_: Tensor, process_group: ProcessGroup, async_op: bool = False):
     world_size = torch.distributed.get_world_size(process_group)
     output = torch.empty(world_size * input_.shape[0], *input_.shape[1:], dtype=input_.dtype, device=input_.device)
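The hunk above truncates all_gather_raw right after the output buffer is allocated. For context, here is a minimal sketch of how the helper plausibly continues, following the flash_attn all_gather_raw pattern it is adapted from; the all_gather_into_tensor call and the (output, handle) return are assumptions, not taken from this diff.

# Sketch (assumption): completion of all_gather_raw in the flash_attn style.
# Assumes torch.distributed has been initialized with a valid process group.
import torch
from torch import Tensor
from torch.distributed import ProcessGroup


def all_gather_raw(input_: Tensor, process_group: ProcessGroup, async_op: bool = False):
    world_size = torch.distributed.get_world_size(process_group)
    # gather along dim 0: every rank contributes input_ and receives the concatenation
    output = torch.empty(
        world_size * input_.shape[0], *input_.shape[1:], dtype=input_.dtype, device=input_.device
    )
    handle = torch.distributed.all_gather_into_tensor(
        output, input_.contiguous(), group=process_group, async_op=async_op
    )
    # with async_op=True the caller overlaps communication with compute and calls handle.wait() later
    return output, handle

When async_op=True, the returned handle is what lets the backward pass below overlap the gather with computation before calling handle_x.wait().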
@@ -147,6 +148,7 @@ def linear_bias_wgrad_torch(input, grad_output, has_d_bias):
     return grad_weight, grad_bias


+# adapted from https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/ops/fused_dense.py
 class FusedDenseFuncTorch(torch.autograd.Function):
     @staticmethod
     @custom_fwd
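The hunk header references linear_bias_wgrad_torch, the pure-PyTorch replacement for flash_attn's fused CUDA weight-gradient kernel. Its body is not shown in this diff, so the following is only a sketch of what such a helper typically looks like.

# Sketch (assumption): torch-only weight/bias gradients for y = x @ W^T + b,
# i.e. grad_weight = grad_output^T @ input and grad_bias = column sum of grad_output.
import torch


def linear_bias_wgrad_torch(input, grad_output, has_d_bias):
    assert input.dtype == grad_output.dtype
    grad_weight = torch.matmul(grad_output.t(), input)
    grad_bias = grad_output.sum(dim=0) if has_d_bias else None
    return grad_weight, grad_bias

Using plain matmul/sum here is what removes the dependence on flash_attn's fused_dense CUDA extension in the backward pass of the next hunk.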
@@ -223,6 +225,7 @@ class FusedDenseFuncTorch(torch.autograd.Function):
         assert ctx.compute_weight_gradient
         if process_group is not None and sequence_parallel:
             handle_x.wait()
+        # we remove the cuda dependence here, which is different from flash_attn.
         grad_weight, grad_bias = linear_bias_wgrad_torch(
             total_x.reshape(batch_dim, total_x.shape[-1]), grad_output, ctx.needs_input_grad[2]
         )
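To see why swapping in the torch-only helper is safe, here is a small, hypothetical check (not from the repository; the local linear_bias_wgrad_torch below is the sketch from above, not the committed code) that its weight and bias gradients match what autograd produces for a plain linear layer.

# Hypothetical verification sketch: compare the torch-only wgrad against autograd.
import torch


def linear_bias_wgrad_torch(input, grad_output, has_d_bias):
    grad_weight = torch.matmul(grad_output.t(), input)
    grad_bias = grad_output.sum(dim=0) if has_d_bias else None
    return grad_weight, grad_bias


x = torch.randn(8, 16, requires_grad=True)
weight = torch.randn(32, 16, requires_grad=True)
bias = torch.randn(32, requires_grad=True)

out = torch.nn.functional.linear(x, weight, bias)
grad_out = torch.randn_like(out)
out.backward(grad_out)

gw, gb = linear_bias_wgrad_torch(x.detach(), grad_out, has_d_bias=True)
assert torch.allclose(gw, weight.grad, atol=1e-5)
assert torch.allclose(gb, bias.grad, atol=1e-5)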