add adaptation in model/utils.py

pull/155/head
yingtongxiong 2023-08-02 16:43:49 +08:00
parent e551b4dffc
commit 0666aa97e4
1 changed file with 3 additions and 0 deletions

@@ -80,6 +80,7 @@ def gather_forward_split_backward(input_, parallel_mode, dim):
     return _GatherForwardSplitBackward.apply(input_, parallel_mode, dim)
 
 
+# the following communicators are adapted from https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/utils/distributed.py
 def all_gather_raw(input_: Tensor, process_group: ProcessGroup, async_op: bool = False):
     world_size = torch.distributed.get_world_size(process_group)
     output = torch.empty(world_size * input_.shape[0], *input_.shape[1:], dtype=input_.dtype, device=input_.device)
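
For readers who only see this hunk, a minimal sketch of how the truncated all_gather_raw body is typically completed, following the flash_attn utility the added comment points to; the all_gather_into_tensor call and the returned handle are assumptions based on that reference, not lines from this diff.

import torch
from torch import Tensor
from torch.distributed import ProcessGroup

def all_gather_raw(input_: Tensor, process_group: ProcessGroup, async_op: bool = False):
    # gather shards from all ranks along dim 0; with async_op=True the caller
    # overlaps communication with compute and waits on the returned handle later
    world_size = torch.distributed.get_world_size(process_group)
    output = torch.empty(
        world_size * input_.shape[0], *input_.shape[1:], dtype=input_.dtype, device=input_.device
    )
    handle = torch.distributed.all_gather_into_tensor(  # assumed collective, per the flash_attn reference
        output, input_.contiguous(), group=process_group, async_op=async_op
    )
    return output, handle
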
@@ -147,6 +148,7 @@ def linear_bias_wgrad_torch(input, grad_output, has_d_bias):
     return grad_weight, grad_bias
 
 
+# adapted from https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/ops/fused_dense.py
 class FusedDenseFuncTorch(torch.autograd.Function):
     @staticmethod
     @custom_fwd
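
Since only the return line of linear_bias_wgrad_torch appears in this hunk, here is a minimal sketch of what such a plain-torch weight/bias-gradient helper computes; the body is an assumption based on standard linear-layer gradients, and only the signature and return statement follow the diff.

import torch

def linear_bias_wgrad_torch(input, grad_output, has_d_bias):
    # gradients of y = x @ W^T + b w.r.t. W and b, computed with ordinary
    # torch ops instead of flash_attn's fused CUDA kernel
    grad_weight = torch.matmul(grad_output.t(), input)          # (out_features, in_features)
    grad_bias = grad_output.sum(dim=0) if has_d_bias else None  # (out_features,)
    return grad_weight, grad_bias
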
@@ -223,6 +225,7 @@ class FusedDenseFuncTorch(torch.autograd.Function):
             assert ctx.compute_weight_gradient
             if process_group is not None and sequence_parallel:
                 handle_x.wait()
+            # we remove the CUDA dependency and compute the weight gradient in plain torch, which is different from flash_attn.
             grad_weight, grad_bias = linear_bias_wgrad_torch(
                 total_x.reshape(batch_dim, total_x.shape[-1]), grad_output, ctx.needs_input_grad[2]
             )
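
As a quick sanity check on the "plain torch instead of the fused CUDA kernel" choice, a small hypothetical snippet comparing the helper sketched above against autograd through F.linear; the shapes are illustrative assumptions and none of this appears in the diff.

import torch
import torch.nn.functional as F

batch_dim, in_features, out_features = 8, 16, 32
x = torch.randn(batch_dim, in_features, requires_grad=True)
w = torch.randn(out_features, in_features, requires_grad=True)
b = torch.randn(out_features, requires_grad=True)

grad_output = torch.randn(batch_dim, out_features)
F.linear(x, w, b).backward(grad_output)

# the torch-only helper should reproduce autograd's weight/bias gradients
grad_weight, grad_bias = linear_bias_wgrad_torch(x.detach(), grad_output, has_d_bias=True)
assert torch.allclose(grad_weight, w.grad, atol=1e-5)
assert torch.allclose(grad_bias, b.grad, atol=1e-5)
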