mirror of https://github.com/hpcaitech/ColossalAI

commit 6a20f07b80 ("remove all to all")
parent 5a310b9ee1
@@ -55,7 +55,7 @@ def cast_from_fp8(inp: torch.Tensor, scale_inv: torch.Tensor, ret_type: torch.dt
     return ret.to(ret_type)


-def all_reduce_fp8(tensor: torch.Tensor, fp8_format="e4m3", group=None) -> None:
+def all_reduce_fp8(tensor: torch.Tensor, fp8_format="e5m2", group=None) -> None:
     r"""
     This is an in-place operation for compressed all_reduce using fp8.
     It works like dist.all_reduce but during communication the data is cast to fp8 format.
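For context, the fp8 compression the docstring above refers to is a scale-then-cast round trip: values are scaled into the representable fp8 range, sent in fp8, and rescaled on receipt. A minimal single-process sketch of that round trip, assuming PyTorch >= 2.1 for the float8 dtypes (the helper names below are illustrative, not the module's actual cast_to_fp8 / cast_from_fp8):

import torch

def cast_to_fp8_sketch(t: torch.Tensor, fp8_format: str = "e5m2"):
    # pick the fp8 dtype and scale the tensor into its representable range
    fp8_dtype = torch.float8_e4m3fn if fp8_format == "e4m3" else torch.float8_e5m2
    amax = t.abs().max().clamp(min=1e-12)
    scale = torch.finfo(fp8_dtype).max / amax
    return (t * scale).to(fp8_dtype), 1.0 / scale  # fp8 payload + inverse scale

def cast_from_fp8_sketch(t_fp8: torch.Tensor, scale_inv: torch.Tensor, ret_type: torch.dtype):
    # undo the scaling in a wider dtype, then cast to the requested return type
    return (t_fp8.to(torch.float32) * scale_inv).to(ret_type)

x = torch.randn(4, 8)
y, s_inv = cast_to_fp8_sketch(x, "e5m2")
print((cast_from_fp8_sketch(y, s_inv, x.dtype) - x).abs().max())  # small, lossy error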
@@ -167,7 +167,7 @@ def cast_from_fp8_pipeline(inp: Any, del_metadata=True) -> None:
         del inp["fp8_scale"]


-def reduce_scatter_fp8(output: torch.Tensor, input_list, group, fp8_format="e4m3") -> None:
+def reduce_scatter_fp8(output: torch.Tensor, input_list, group, fp8_format="e5m2") -> None:
     r"""
     This is an in-place operation for compressed reduce_scatter using fp8.
     It works like dist.reduce_scatter but during communication the data is cast to fp8 format.
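Both docstrings above describe the same pattern: compress locally to fp8, move the compressed bytes plus their scales, decompress, then reduce in the original dtype. A naive all_gather-based sketch of those semantics, not the library's implementation (which these hunks only show in part); it assumes an initialized process group, e.g. launched via torchrun:

import torch
import torch.distributed as dist

def all_reduce_fp8_sketch(tensor: torch.Tensor, fp8_format: str = "e4m3", group=None) -> None:
    # in-place, like dist.all_reduce: every rank ends up with the sum over all ranks
    fp8_dtype = torch.float8_e4m3fn if fp8_format == "e4m3" else torch.float8_e5m2
    world_size = dist.get_world_size(group)
    amax = tensor.abs().max().clamp(min=1e-12)
    scale = (torch.finfo(fp8_dtype).max / amax).reshape(1)
    payload = (tensor * scale).to(fp8_dtype).view(torch.uint8)  # fp8 bits travel as uint8
    payloads = [torch.empty_like(payload) for _ in range(world_size)]
    scales = [torch.empty_like(scale) for _ in range(world_size)]
    dist.all_gather(payloads, payload, group=group)
    dist.all_gather(scales, scale, group=group)
    total = torch.zeros_like(tensor)
    for p, s in zip(payloads, scales):
        total += p.view(fp8_dtype).to(tensor.dtype) / s  # decompress, then reduce locally
    tensor.copy_(total)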
@@ -170,7 +170,7 @@ class LinearWithAsyncCommunication(torch.autograd.Function):
         if ctx.async_grad_allreduce:
             handle.wait()

-        return grad_input, grad_weight, grad_bias, None, None, None, None
+        return grad_input, grad_weight, grad_bias, None, None, None


 def _ring_as_gather(func, input_to_gather=None, input_local=None, process_group=None, gather_dim=1, keep_item=False):
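The dropped trailing None above follows the usual custom-Function contract: backward returns exactly one gradient (or None) per argument of forward, so removing the fp8_communication argument from forward removes one None here. A toy illustration of the rule, not the library code:

import torch

class Scale(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, factor):  # two forward arguments ...
        ctx.factor = factor
        return x * factor

    @staticmethod
    def backward(ctx, grad_out):  # ... so two return values: grad for x, None for factor
        return grad_out * ctx.factor, None

x = torch.randn(3, requires_grad=True)
Scale.apply(x, 2.0).sum().backward()
print(x.grad)  # tensor([2., 2., 2.])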
@@ -261,7 +261,7 @@ class _GatherForwardReduceScatterBackward(torch.autograd.Function):

         dist.reduce_scatter(output, grad_list, group=process_group)

-        return output, None, None, None
+        return output, None, None


 class _LinearWithGatherForwardReduceScatterBackward(torch.autograd.Function):
@@ -729,7 +729,7 @@ class _SplitForwardGatherBackward(torch.autograd.Function):
             grad_output = grad_output * ctx.grad_scale

         # to_cast.append(grad_output.cpu().detach().numpy())
-        return _gather(grad_output, ctx.dim, ctx.process_group, ctx.fp8_communication, "e4m3"), None, None, None, None
+        return _gather(grad_output, ctx.dim, ctx.process_group, ctx.fp8_communication), None, None, None, None


 class _ReduceForward(torch.autograd.Function):
@@ -786,7 +786,7 @@ class _GatherForwardSplitBackward(torch.autograd.Function):
         ctx.dim = dim
         ctx.grad_scale = grad_scale

-        return _gather(input_, dim, process_group, fp8_communication=fp8_communication, fp8_format="e4m3")
+        return _gather(input_, dim, process_group, fp8_communication=fp8_communication)

     @staticmethod
     def backward(ctx, grad_output):
@@ -806,67 +806,26 @@ class _AllToAll(torch.autograd.Function):
     """

     @staticmethod
-    def forward(ctx, input_, process_group, scatter_dim, gather_dim, fp8_communication):
+    def forward(ctx, input_, process_group, scatter_dim, gather_dim):
         ctx.process_group = process_group
         ctx.scatter_dim = scatter_dim
         ctx.gather_dim = gather_dim
-        ctx.fp8_communication = fp8_communication
         world_size = dist.get_world_size(process_group)
         bsz, _, _ = input_.shape

         # using all_to_all_single when batch size is 1
         if bsz == 1:
-            return _all_to_all_single(
-                input_,
-                world_size,
-                process_group,
-                scatter_dim,
-                gather_dim,
-                fp8_communication=fp8_communication,
-                fp8_format="e5m2",
-            )
+            return _all_to_all_single(input_, world_size, process_group, scatter_dim, gather_dim)
         else:
-            return _all_to_all(
-                input_,
-                world_size,
-                process_group,
-                scatter_dim,
-                gather_dim,
-                fp8_communication=fp8_communication,
-                fp8_format="e5m2",
-            )
+            return _all_to_all(input_, world_size, process_group, scatter_dim, gather_dim)

     @staticmethod
-    def backward(ctx, grad_output):
+    def backward(ctx, *grad_output):
         process_group = ctx.process_group
         scatter_dim = ctx.gather_dim
         gather_dim = ctx.scatter_dim
-        fp8_communication = ctx.fp8_communication
-        world_size = dist.get_world_size(process_group)
-        bsz, _, _ = grad_output.shape
-
-        if bsz == 1:
-            return_grad = _all_to_all_single(
-                grad_output,
-                world_size,
-                process_group,
-                scatter_dim,
-                gather_dim,
-                fp8_communication=fp8_communication,
-                fp8_format="e5m2",
-            )
-        else:
-            return_grad = _all_to_all(
-                grad_output,
-                world_size,
-                process_group,
-                scatter_dim,
-                gather_dim,
-                fp8_communication=fp8_communication,
-                fp8_format="e5m2",
-            )
-
-        return (return_grad, None, None, None, None)
+        return_grad = _AllToAll.apply(*grad_output, process_group, scatter_dim, gather_dim)
+        return (return_grad, None, None, None)


 class HookParameter(torch.autograd.Function):
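The rewritten backward above leans on a property of all-to-all: re-applying it with scatter_dim and gather_dim swapped undoes the exchange, so the gradient can be routed by calling _AllToAll.apply recursively. A single-process sketch of that split/concat bookkeeping, with a plain Python list standing in for dist.all_to_all (no process group involved):

import torch

def fake_all_to_all(xs, world_size, scatter_dim, gather_dim):
    # xs[r] plays the role of rank r's local tensor; returns each rank's output
    splits = [torch.tensor_split(x, world_size, scatter_dim) for x in xs]
    return [
        torch.cat([splits[j][r] for j in range(world_size)], dim=gather_dim)
        for r in range(world_size)
    ]

world_size = 4
xs = [torch.randn(2, 4, 8 * world_size) for _ in range(world_size)]
ys = fake_all_to_all(xs, world_size, scatter_dim=2, gather_dim=1)  # each [2, 16, 8]
zs = fake_all_to_all(ys, world_size, scatter_dim=1, gather_dim=2)  # back to [2, 4, 32]
print(ys[0].shape, all(torch.equal(z, x) for z, x in zip(zs, xs)))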
@@ -924,41 +883,20 @@ def _split(input_, dim=-1, process_group=None):
     return output


-def _gather(input_, dim=-1, process_group=None, fp8_communication=False, fp8_format="e4m3"):
+def _gather(input_, dim=-1, process_group=None, fp8_communication=False, fp8_format="e5m2"):
     # skip if only one rank involved
     world_size = dist.get_world_size(process_group)
     if world_size == 1:
         return input_

-    # all gather
-    import torch.distributed as dista
-
-    from colossalai.zero.low_level._utils import has_inf_or_nan
-
     if fp8_communication:
-        # if False:
-        if has_inf_or_nan(input_):
-            print("input has nan")
-            exit(0)
         input_type = input_.dtype
-        ret, scale = cast_to_fp8(input_, fp8_format="e5m2")
-        if has_inf_or_nan(ret):
-            import pdb
-
-            pdb.set_trace()
-            print("cast has nan")
-            # exit(0)
-        dista.barrier()
+        ret, scale = cast_to_fp8(input_, fp8_format=fp8_format)
         fp8_type = ret.dtype
         input_ = ret.view(torch.uint8)
         input_ = input_.contiguous()
         tensor_list = [torch.empty_like(input_) for _ in range(world_size)]
         scale = torch.tensor(scale, dtype=torch.float32).to(input_.device)
-        # import torch.distributed as dista
-        # if dista.get_rank()==0:
-        #     import pdb
-        #     pdb.set_trace()
-        # dista.barrier()
         scale_list = [torch.ones(1, dtype=torch.float32, device=input_.device) for _ in range(world_size)]

         scale = torch.tensor(scale).to(input_.device)
@@ -969,24 +907,10 @@ def _gather(input_, dim=-1, process_group=None, fp8_communication=False, fp8_for
         for output, scale in zip(tensor_list, scale_list):
             output = output.view(fp8_type)
             output = cast_from_fp8(output, scale, input_type)
-            if has_inf_or_nan(output) and dista.get_rank() == 0:
-                print("casted_output has nan")
-                import pdb
-
-                pdb.set_trace()
-            dista.barrier()
-
             cast_tensor_list.append(output)

         output = torch.cat(cast_tensor_list, dim=dim).contiguous()

-        if has_inf_or_nan(output):
-            print("output has nan")
-            exit(0)
-            # import pdb
-            # pdb.set_trace()
-        dista.barrier()
-
     else:
         input_ = input_.contiguous()
         tensor_list = [torch.empty_like(input_) for _ in range(world_size)]
@@ -1020,33 +944,14 @@ def _reduce_scatter(input_, dim=1, process_group=None):
     return output


-def _all_to_all(input_, world_size, group, scatter_dim, gather_dim, fp8_communication=False, fp8_format="e5m2"):
-    if fp8_communication:
-        input_type = input_.dtype
-        ret, scale = cast_to_fp8(input_, fp8_format=fp8_format)
-        fp8_type = ret.dtype
-        input_ = ret.view(torch.uint8)
-        input_list = [t.contiguous() for t in torch.tensor_split(input_, world_size, scatter_dim)]
-        output_list = [torch.empty_like(input_list[0]) for _ in range(world_size)]
-        scale_list = [torch.ones(1, dtype=scale.dtype, device=input_.device) for _ in range(world_size)]
-        dist.all_to_all(output_list, input_list, group=group)
-        dist.all_gather(scale_list, scale, group=group)
-        cast_tensor_list = []
-        for output, scale in zip(output_list, scale_list):
-            output = output.view(fp8_type)
-            output = cast_from_fp8(output, scale, input_type)
-            cast_tensor_list.append(output)
-        output_list = cast_tensor_list
-    else:
-        input_list = [t.contiguous() for t in torch.tensor_split(input_, world_size, scatter_dim)]
-        output_list = [torch.empty_like(input_list[0]) for _ in range(world_size)]
-        dist.all_to_all(output_list, input_list, group=group)
+def _all_to_all(input_, world_size, group, scatter_dim, gather_dim):
+    input_list = [t.contiguous() for t in torch.tensor_split(input_, world_size, scatter_dim)]
+    output_list = [torch.empty_like(input_list[0]) for _ in range(world_size)]
+    dist.all_to_all(output_list, input_list, group=group)
     return torch.cat(output_list, dim=gather_dim).contiguous()


-def _all_to_all_single(
-    input_, seq_world_size, group, scatter_dim, gather_dim, fp8_communication=False, fp8_format="e5m2"
-):
+def _all_to_all_single(input_, seq_world_size, group, scatter_dim, gather_dim):
    inp_shape = list(input_.shape)
    inp_shape[scatter_dim] = inp_shape[scatter_dim] // seq_world_size
    if scatter_dim < 2:
@@ -1058,22 +963,6 @@ def _all_to_all_single(
         .contiguous()
     )

-    if fp8_communication:
-        input_type = input_t.dtype
-        ret, scale = cast_to_fp8(input_t, fp8_format=fp8_format)
-        fp8_type = ret.dtype
-        input_t = ret.view(torch.uint8)
-        output = torch.empty_like(input_t)
-        scale_list = [torch.ones(1, dtype=scale.dtype, device=input_.device) for _ in range(seq_world_size)]
-        dist.all_to_all_single(output, input_t, group=group)
-        dist.all_gather(scale_list, scale, group=group)
-        cast_tensor_list = []
-        for output_part, scale in zip(output, scale_list):
-            output_part = output_part.view(fp8_type)
-            output_part = cast_from_fp8(output_part, scale, input_type)
-            cast_tensor_list.append(output_part)
-        output = torch.stack(cast_tensor_list, dim=0)
-    else:
-        output = torch.empty_like(input_t)
-        dist.all_to_all_single(output, input_t, group=group)
+    output = torch.empty_like(input_t)
+    dist.all_to_all_single(output, input_t, group=group)

@@ -1143,5 +1032,5 @@ def reduce_backward(input_, process_group, fp8_communication=False):
     return _ReduceBackward.apply(input_, process_group, fp8_communication)


-def all_to_all_comm(input_, process_group=None, scatter_dim=2, gather_dim=1, fp8_communication=False):
-    return _AllToAll.apply(input_, process_group, scatter_dim, gather_dim, fp8_communication)
+def all_to_all_comm(input_, process_group=None, scatter_dim=2, gather_dim=1):
+    return _AllToAll.apply(input_, process_group, scatter_dim, gather_dim)
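With the fp8 arguments gone, all_to_all_comm keeps its shape contract, assumed here from the defaults scatter_dim=2, gather_dim=1: in "all_to_all" sequence parallelism each rank trades its shard of the last dimension for the full sequence. A pure-shape illustration with illustrative sizes, no communication performed:

bsz, seq_len, hidden, sp_size = 2, 4096, 4096, 4
local_in = (bsz, seq_len // sp_size, hidden)   # per-rank shape before all_to_all_comm
local_out = (bsz, seq_len, hidden // sp_size)  # per-rank shape after (scatter_dim=2, gather_dim=1)
assert local_in[0] * local_in[1] * local_in[2] == local_out[0] * local_out[1] * local_out[2]
print(local_in, "->", local_out)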
@@ -84,7 +84,6 @@ class Linear1D_Col(ParallelModule):
         bias_: Optional[Parameter] = None,
         weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)),
         bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1),
-        fp8_communication: bool = False,
         **kwargs,
     ):
         super().__init__(weight=weight, bias_=bias_, **kwargs)
@@ -99,7 +98,6 @@ class Linear1D_Col(ParallelModule):
         self.skip_bias_add = skip_bias_add
         self.device = device
         self.process_group = process_group
-        self.fp8_communication = fp8_communication

         if skip_bias_add and not bias:
             raise ValueError("cannot skip bias addition if bias is None")
@@ -203,12 +201,10 @@ class Linear1D_Col(ParallelModule):
         bias = self.bias if not self.skip_bias_add else None

         if self.seq_parallel_mode is None:
-            output_parallel = linear_with_async_comm(
-                input_parallel, self.weight, bias, self.process_group, True, fp8_communication=self.fp8_communication
-            )
+            output_parallel = linear_with_async_comm(input_parallel, self.weight, bias, self.process_group, True)
         elif self.seq_parallel_mode == "split_gather":
             input_parallel = gather_forward_reducescatter_backward(
-                input_parallel, self.process_group, self.seq_parallel_dim, fp8_communication=self.fp8_communication
+                input_parallel, self.process_group, self.seq_parallel_dim
             )
             output_parallel = linear_with_async_comm(input_parallel, self.weight, bias, self.process_group, False)
         elif self.seq_parallel_mode == "ring":
@@ -268,7 +264,6 @@ class Linear1D_Row(ParallelModule):
         weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)),
         bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1),
         stream_chunk_num: int = 1,
-        fp8_communication: bool = False,
     ):
         super().__init__()

@@ -283,7 +278,6 @@ class Linear1D_Row(ParallelModule):
         self.seq_parallel_mode = seq_parallel_mode
         self.seq_parallel_dim = seq_parallel_dim
         self.num_partitions = dist.get_world_size(self.process_group)
-        self.fp8_communication = fp8_communication

         if skip_bias_add and not bias:
             raise ValueError("cannot skip bias addition if bias is None")
@@ -404,9 +398,7 @@ class Linear1D_Row(ParallelModule):
         ), "Invalid shapes in Linear1D_Row forward: input={}, weight={}. Expected last dim of input {}.".format(
             input_.shape, self.weight.shape, self.weight.shape[-1] * self.num_partitions
         )
-        input_ = split_forward_gather_backward(
-            input_, dim=-1, process_group=self.process_group, fp8_comm=self.fp8_communication
-        )
+        input_ = split_forward_gather_backward(input_, dim=-1, process_group=self.process_group)

         if self.stream_chunk_num > 1:
             if self.training:
@@ -426,11 +418,11 @@ class Linear1D_Row(ParallelModule):
         else:
             if self.seq_parallel_mode is None:
                 output_parallel = linear_with_async_comm(input_, self.weight, None, self.process_group, False)
-                output = reduce_forward(output_parallel, self.process_group, fp8_communication=self.fp8_communication)
+                output = reduce_forward(output_parallel, self.process_group)
             elif self.seq_parallel_mode == "split_gather":
                 output_parallel = linear_with_async_comm(input_, self.weight, None, self.process_group, False)
                 output = reducescatter_forward_gather_backward(
-                    output_parallel, self.process_group, self.seq_parallel_dim, fp8_communication=self.fp8_communication
+                    output_parallel, self.process_group, self.seq_parallel_dim
                 )
             elif self.seq_parallel_mode == "ring":
                 output = linear_reducescatter_forward_gather_backward(
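Linear1D_Col and Linear1D_Row above are the usual column/row tensor-parallel pair: the column shard produces a partial-feature output, the matching row shard consumes it, and the per-rank partial products are summed (that sum is what reduce_forward performs across the process group). A single-process sanity sketch of the underlying matmul identity; the weight layout here is illustrative, not the classes' actual parameter layout:

import torch

tp = 2
x = torch.randn(3, 16)
w1 = torch.randn(16, 32)   # stands in for the column-parallel weight
w2 = torch.randn(32, 16)   # stands in for the row-parallel weight
ref = (x @ w1) @ w2

w1_cols = torch.chunk(w1, tp, dim=1)   # each rank holds a column shard of w1
w2_rows = torch.chunk(w2, tp, dim=0)   # ... and the matching row shard of w2
partials = [(x @ c) @ r for c, r in zip(w1_cols, w2_rows)]
print(torch.allclose(sum(partials), ref, atol=1e-4))  # True: the all-reduce just sums these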
@@ -460,7 +460,7 @@ class LlamaPipelineForwards:
         return {"hidden_states": hidden_states}


-def get_llama_flash_attention_forward(shard_config: ShardConfig, sp_mode=None, sp_size=None, sp_group=None):
+def get_llama_flash_attention_forward(shard_config, sp_mode=None, sp_size=None, sp_group=None):
     def forward(
         self,
         hidden_states: torch.Tensor,
@@ -592,7 +592,7 @@ def get_llama_flash_attention_forward(shard_config: ShardConfig, sp_mode=None, s
     return forward


-def get_llama_flash_attention_model_forward(shard_config: ShardConfig, sp_mode=None, sp_size=None, sp_group=None):
+def get_llama_flash_attention_model_forward(shard_config, sp_mode=None, sp_size=None, sp_group=None):
     logger = logging.get_logger(__name__)

     def forward(
@@ -659,18 +659,9 @@ def get_llama_flash_attention_model_forward(shard_config: ShardConfig, sp_mode=N
         attention_mask = self._update_causal_mask(attention_mask, inputs_embeds, cache_position)

         if sp_mode in ["ring", "split_gather"]:
-            inputs_embeds = split_forward_gather_backward(
-                inputs_embeds,
-                1,
-                sp_group,
-            )
+            inputs_embeds = split_forward_gather_backward(inputs_embeds, 1, sp_group)
         elif sp_mode == "all_to_all":
-            inputs_embeds = split_forward_gather_backward(
-                inputs_embeds,
-                1,
-                sp_group,
-                1 / sp_size,
-            )
+            inputs_embeds = split_forward_gather_backward(inputs_embeds, 1, sp_group, 1 / sp_size)
         hidden_states = inputs_embeds

         # decoder layers
@@ -715,18 +706,9 @@ def get_llama_flash_attention_model_forward(shard_config: ShardConfig, sp_mode=N
         hidden_states = self.norm(hidden_states)

         if sp_mode == "ring" or sp_mode == "split_gather":
-            hidden_states = gather_forward_split_backward(
-                hidden_states,
-                1,
-                sp_group,
-            )
+            hidden_states = gather_forward_split_backward(hidden_states, 1, sp_group)
         elif sp_mode == "all_to_all":
-            hidden_states = gather_forward_split_backward(
-                hidden_states,
-                1,
-                sp_group,
-                grad_scale=sp_size,
-            )
+            hidden_states = gather_forward_split_backward(hidden_states, 1, sp_group, grad_scale=sp_size)

         # add hidden states from the last decoder layer
         if output_hidden_states:
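The split/gather calls above implement sequence parallelism around the decoder stack: the sequence dimension (dim=1) is split across sp_group before the layers and re-assembled afterwards, with a grad_scale factor applied to the backward pass in the all_to_all mode. A local, shapes-only illustration; the autograd and grad_scale machinery is omitted:

import torch

sp_size, rank = 4, 1
hidden_states = torch.randn(2, 32, 64)                    # [bsz, seq_len, hidden]
chunks = torch.chunk(hidden_states, sp_size, dim=1)
local = chunks[rank]                                      # what split_forward_gather_backward keeps per rank
restored = torch.cat(chunks, dim=1)                       # what gather_forward_split_backward rebuilds
print(local.shape, torch.equal(restored, hidden_states))  # torch.Size([2, 8, 64]) True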
@@ -218,11 +218,8 @@ def main():
     elif args.plugin == "hybrid_parallel":
         # modify the param accordingly for finetuning test cases
         plugin = HybridParallelPlugin(
-            tp_size=2,
-            pp_size=1,
-            sp_size=1,
-            # sequence_parallelism_mode="split_gather",
-            # enable_sequence_parallelism=True,
+            tp_size=1,
+            pp_size=2,
             num_microbatches=None,
             microbatch_size=1,
             enable_all_optimization=True,