mirror of https://github.com/hpcaitech/ColossalAI
fix rebase
parent 457a0de79f
commit 5a310b9ee1
@@ -1 +0,0 @@
-to_cast = []
@@ -55,7 +55,7 @@ def cast_from_fp8(inp: torch.Tensor, scale_inv: torch.Tensor, ret_type: torch.dt
     return ret.to(ret_type)


-def all_reduce_fp8(tensor: torch.Tensor, fp8_format="e4m3") -> None:
+def all_reduce_fp8(tensor: torch.Tensor, fp8_format="e4m3", group=None) -> None:
     r"""
     This is an in-place operation for compressed all_reduce using fp8.
     It works like dist.all_reduce but during communication the data is cast to fp8 format.
@@ -66,7 +66,7 @@ def all_reduce_fp8(tensor: torch.Tensor, fp8_format="e4m3") -> None:
         None
     """

-    world_size = dist.get_world_size()
+    world_size = dist.get_world_size(group=group)
     input_type = tensor.dtype
     input_shape = tensor.shape
     input_device = tensor.device
@@ -83,19 +83,19 @@ def all_reduce_fp8(tensor: torch.Tensor, fp8_format="e4m3") -> None:
         output_chunks = [torch.empty_like(input_chunks[-1]) for _ in range(world_size)]
     else:
         output_chunks = [torch.empty_like(input_chunks[0]) for _ in range(world_size)]
-    dist.all_to_all(output_chunks, input_chunks)
+    dist.all_to_all(output_chunks, input_chunks, group=group)
     scale_list = [torch.ones(1, dtype=scale.dtype, device=input_device) for _ in range(world_size)]
-    dist.all_gather(scale_list, scale)
+    dist.all_gather(scale_list, scale, group=group)
     summed_out = torch.zeros_like(output_chunks[0]).to(input_type)
     for scale, out in zip(scale_list, output_chunks):
         out = out.view(fp8_type)
         summed_out += cast_from_fp8(out, scale, input_type)

     summed_out_fp8, scale = cast_to_fp8(summed_out, fp8_format=fp8_format)
-    dist.all_gather(scale_list, scale)
+    dist.all_gather(scale_list, scale, group=group)

     tensor_list = list(torch.chunk(torch.empty(input_size, device=input_device, dtype=torch.uint8), world_size, dim=0))
-    dist.all_gather(tensor_list, summed_out_fp8.view(torch.uint8))
+    dist.all_gather(tensor_list, summed_out_fp8.view(torch.uint8), group=group)
     for i in range(world_size):
         tensor_list[i] = tensor_list[i].view(fp8_type).to(input_type) * scale_list[i]
     tensor_out = torch.cat(tensor_list, dim=0)
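The hunks above thread a `group` argument through every collective inside `all_reduce_fp8`, which matters because the compressed payload and its per-rank scales must travel over the same process group to stay paired. The core trick is a cast round trip: scale into the fp8 range, ship the bytes, then undo the scale on the receiving side. A minimal single-process sketch of that round trip (illustrative only; it assumes a PyTorch build with float8 dtypes and does not reproduce the exact `cast_to_fp8`/`cast_from_fp8` signatures, which work with a `scale_inv` in the real helpers):

    import torch

    def cast_to_fp8_sketch(t: torch.Tensor, fp8_format: str = "e4m3"):
        # Pick the fp8 dtype and scale the tensor into its representable range.
        fp8_type = torch.float8_e4m3fn if fp8_format == "e4m3" else torch.float8_e5m2
        scale = (torch.finfo(fp8_type).max / t.abs().max().float().clamp(min=1e-12)).reshape(1)
        return (t * scale).to(fp8_type), scale

    def cast_from_fp8_sketch(payload: torch.Tensor, scale: torch.Tensor, ret_type: torch.dtype):
        # Inverse of the cast above: undo the scaling and restore the original dtype.
        return (payload.to(torch.float32) / scale).to(ret_type)

    x = torch.randn(8, dtype=torch.float16)
    payload, scale = cast_to_fp8_sketch(x)
    x_roundtrip = cast_from_fp8_sketch(payload, scale, x.dtype)  # lossy, but close to x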
@@ -169,8 +169,8 @@ def cast_from_fp8_pipeline(inp: Any, del_metadata=True) -> None:


 def reduce_scatter_fp8(output: torch.Tensor, input_list, group, fp8_format="e4m3") -> None:
     r"""
-    This is an in-place operation for compressed all_reduce using fp8.
-    It works like dist.all_reduce but during communication the data is cast to fp8 format.
+    This is an in-place operation for compressed reduce_scatter using fp8.
+    It works like dist.reduce_scatter but during communication the data is cast to fp8 format.

     Args:
         tensor: torch.Tensor in fp32, fp16, bf16 datatype.
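The docstring correction is not cosmetic: callers treat `reduce_scatter_fp8` as a drop-in for `dist.reduce_scatter`, taking the same `(output, input_list, group)` triple, and that parity is exactly how the shardformer layer further down switches between the two. A condensed sketch of that dispatch (the import path is an assumption; the helper itself is the one defined alongside `all_reduce_fp8` above):

    import torch
    import torch.distributed as dist

    # Assumed import path for the helper shown in the hunk above.
    from colossalai.quantization.fp8 import reduce_scatter_fp8

    def reduce_scatter_maybe_fp8(output: torch.Tensor, input_list, group, fp8_communication: bool = False):
        # Both branches share the same call shape, which is why the docstring should
        # describe reduce_scatter rather than all_reduce.
        if fp8_communication:
            reduce_scatter_fp8(output, input_list, group=group)
        else:
            dist.reduce_scatter(output, input_list, group=group)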
@@ -94,7 +94,7 @@ class MatmulWithAsyncCommunication(torch.autograd.Function):
         grad_output = grad_output.view(-1, grad_output.shape[-1])
         total_input = total_input.view(-1, total_input.shape[-1])

-        if fp8_communication and ctx.async_grad_allreduce:
+        if ctx.async_grad_allreduce and fp8_communication:
             _reduce(grad_input, group=ctx.process_group, fp8_communication=fp8_communication)
         elif ctx.async_grad_allreduce:
             # Asynchronous all-reduce
@@ -117,12 +117,11 @@ class LinearWithAsyncCommunication(torch.autograd.Function):
     """

     @staticmethod
-    def forward(ctx, input_, weight, bias, process_group, async_grad_allreduce, fp8_communication=False):
+    def forward(ctx, input_, weight, bias, process_group, async_grad_allreduce):
         ctx.save_for_backward(input_, weight, bias)
         ctx.use_bias = bias is not None
         ctx.process_group = process_group
         ctx.async_grad_allreduce = async_grad_allreduce
-        ctx.fp8_communication = fp8_communication
         if bias is not None:
             output = F.linear(input_, weight, bias)
         else:
@@ -134,7 +133,6 @@ class LinearWithAsyncCommunication(torch.autograd.Function):
     def backward(ctx, grad_output):
         input, weight, bias = ctx.saved_tensors
         use_bias = ctx.use_bias
-        fp8_communication = ctx.fp8_communication

         # In order to be hooked into Gemini's '__torch_function__', adding a view operation to bias.
         if use_bias:
@@ -150,10 +148,7 @@ class LinearWithAsyncCommunication(torch.autograd.Function):

         if ctx.async_grad_allreduce:
             # Asynchronous all-reduce
-            if fp8_communication:
-                all_reduce_fp8(grad_input, group=ctx.process_group)
-            else:
-                handle = dist.all_reduce(grad_input, group=ctx.process_group, async_op=True)
+            handle = dist.all_reduce(grad_input, group=ctx.process_group, async_op=True)
             # Relay on CUDA_DEVICE_MAX_CONNECTIONS=1 to have
             # all-reduce scheduled first and have GPU resources allocated, CUDA_DEVICE_MAX_CONNECTIONS=1 is set in shardformer.py

@@ -172,7 +167,7 @@ class LinearWithAsyncCommunication(torch.autograd.Function):

         grad_bias = grad_output.sum(dim=0) if use_bias else None

-        if ctx.async_grad_allreduce and not fp8_communication:
+        if ctx.async_grad_allreduce:
             handle.wait()

         return grad_input, grad_weight, grad_bias, None, None, None, None
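A note on why the `forward` and `backward` hunks of `LinearWithAsyncCommunication` have to change in lockstep: a `torch.autograd.Function` only sees in `backward` what `forward` stashed on `ctx`, so removing `ctx.fp8_communication = ...` in one method without removing the read in the other would raise an `AttributeError` on the first backward pass. A minimal, self-contained illustration of that contract (names are invented for the example):

    import torch

    class ScaleByFlag(torch.autograd.Function):
        @staticmethod
        def forward(ctx, x, use_big_scale: bool):
            ctx.use_big_scale = use_big_scale  # forward stores the flag on ctx ...
            return x * (4.0 if use_big_scale else 2.0)

        @staticmethod
        def backward(ctx, grad_output):
            # ... and backward reads it back; both methods must agree on what lives on ctx.
            scale = 4.0 if ctx.use_big_scale else 2.0
            return grad_output * scale, None  # one gradient slot per forward input

    x = torch.ones(3, requires_grad=True)
    ScaleByFlag.apply(x, True).sum().backward()
    assert torch.allclose(x.grad, torch.full((3,), 4.0))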
@@ -243,18 +238,16 @@ class _GatherForwardReduceScatterBackward(torch.autograd.Function):
     """

     @staticmethod
-    def forward(ctx, input_, process_group, dim, fp8_communication=False):
+    def forward(ctx, input_, process_group, dim):
         ctx.process_group = process_group
         ctx.dim = dim
-        ctx.fp8_communication = fp8_communication

-        return _gather(input_, dim, process_group, fp8_communication)
+        return _gather(input_, dim, process_group)

     @staticmethod
     def backward(ctx, grad_output):
         dim = ctx.dim
         process_group = ctx.process_group
-        fp8_communication = ctx.fp8_communication
         # do reduce-scatter
         new_shape = list(grad_output.shape)
         assert (
@@ -266,10 +259,7 @@ class _GatherForwardReduceScatterBackward(torch.autograd.Function):
         ]
         output = torch.empty(new_shape, dtype=grad_output.dtype, device=grad_output.device)

-        if fp8_communication:
-            reduce_scatter_fp8(output, grad_list, group=process_group)
-        else:
-            dist.reduce_scatter(output, grad_list, group=process_group)
+        dist.reduce_scatter(output, grad_list, group=process_group)

         return output, None, None, None

@@ -576,7 +566,6 @@ class _ReduceScatterForwardGatherBackward(torch.autograd.Function):
         input_list = [item.contiguous() for item in torch.chunk(input_, dist.get_world_size(process_group), dim=dim)]
         output = torch.empty(new_shape, dtype=input_.dtype, device=input_.device)
         if fp8_communication:
-        # if False:
             reduce_scatter_fp8(output, input_list, group=process_group)
         else:
             dist.reduce_scatter(output, input_list, group=process_group)
@@ -588,8 +577,7 @@ class _ReduceScatterForwardGatherBackward(torch.autograd.Function):
         dim = ctx.dim
         process_group = ctx.process_group
         fp8_communication = ctx.fp8_communication
-        return _gather(grad_output, dim, process_group, fp8_communication), None, None, None
+        return _gather(grad_output, dim, process_group, fp8_communication=fp8_communication), None, None, None


 class _MatmulWithGatherForwardReduceScatterBackward(torch.autograd.Function):
|
@ -793,12 +781,12 @@ class _GatherForwardSplitBackward(torch.autograd.Function):
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def forward(ctx, input_, dim, process_group, grad_scale=None, fp8_comm=False):
|
def forward(ctx, input_, dim, process_group, grad_scale=None, fp8_communication=False):
|
||||||
ctx.process_group = process_group
|
ctx.process_group = process_group
|
||||||
ctx.dim = dim
|
ctx.dim = dim
|
||||||
ctx.grad_scale = grad_scale
|
ctx.grad_scale = grad_scale
|
||||||
|
|
||||||
return _gather(input_, dim, process_group, fp8_comm=fp8_comm, fp8_format="e4m3")
|
return _gather(input_, dim, process_group, fp8_communication=fp8_communication, fp8_format="e4m3")
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def backward(ctx, grad_output):
|
def backward(ctx, grad_output):
|
||||||
|
@ -829,11 +817,23 @@ class _AllToAll(torch.autograd.Function):
|
||||||
# using all_to_all_single when batch size is 1
|
# using all_to_all_single when batch size is 1
|
||||||
if bsz == 1:
|
if bsz == 1:
|
||||||
return _all_to_all_single(
|
return _all_to_all_single(
|
||||||
input_, world_size, process_group, scatter_dim, gather_dim, fp8_comm=fp8_comm, fp8_format="e5m2"
|
input_,
|
||||||
|
world_size,
|
||||||
|
process_group,
|
||||||
|
scatter_dim,
|
||||||
|
gather_dim,
|
||||||
|
fp8_communication=fp8_communication,
|
||||||
|
fp8_format="e5m2",
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
return _all_to_all(
|
return _all_to_all(
|
||||||
input_, world_size, process_group, scatter_dim, gather_dim, fp8_comm=fp8_comm, fp8_format="e5m2"
|
input_,
|
||||||
|
world_size,
|
||||||
|
process_group,
|
||||||
|
scatter_dim,
|
||||||
|
gather_dim,
|
||||||
|
fp8_communication=fp8_communication,
|
||||||
|
fp8_format="e5m2",
|
||||||
)
|
)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
|
@ -841,17 +841,29 @@ class _AllToAll(torch.autograd.Function):
|
||||||
process_group = ctx.process_group
|
process_group = ctx.process_group
|
||||||
scatter_dim = ctx.gather_dim
|
scatter_dim = ctx.gather_dim
|
||||||
gather_dim = ctx.scatter_dim
|
gather_dim = ctx.scatter_dim
|
||||||
ctx.fp8_communication
|
fp8_communication = ctx.fp8_communication
|
||||||
world_size = dist.get_world_size(process_group)
|
world_size = dist.get_world_size(process_group)
|
||||||
bsz, _, _ = grad_output.shape
|
bsz, _, _ = grad_output.shape
|
||||||
|
|
||||||
if bsz == 1:
|
if bsz == 1:
|
||||||
return_grad = _all_to_all_single(
|
return_grad = _all_to_all_single(
|
||||||
grad_output, world_size, process_group, scatter_dim, gather_dim, fp8_comm=fp8_comm, fp8_format="e5m2"
|
grad_output,
|
||||||
|
world_size,
|
||||||
|
process_group,
|
||||||
|
scatter_dim,
|
||||||
|
gather_dim,
|
||||||
|
fp8_communication=fp8_communication,
|
||||||
|
fp8_format="e5m2",
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
return_grad = _all_to_all(
|
return_grad = _all_to_all(
|
||||||
grad_output, world_size, process_group, scatter_dim, gather_dim, fp8_comm=fp8_comm, fp8_format="e5m2"
|
grad_output,
|
||||||
|
world_size,
|
||||||
|
process_group,
|
||||||
|
scatter_dim,
|
||||||
|
gather_dim,
|
||||||
|
fp8_communication=fp8_communication,
|
||||||
|
fp8_format="e5m2",
|
||||||
)
|
)
|
||||||
|
|
||||||
return (return_grad, None, None, None, None)
|
return (return_grad, None, None, None, None)
|
||||||
|
@@ -912,10 +924,7 @@ def _split(input_, dim=-1, process_group=None):
     return output


-from colossalai.params import to_cast
-
-
-def _gather(input_, dim=-1, process_group=None, fp8_comm=False, fp8_format="e4m3"):
+def _gather(input_, dim=-1, process_group=None, fp8_communication=False, fp8_format="e4m3"):
     # skip if only one rank involved
     world_size = dist.get_world_size(process_group)
     if world_size == 1:
@@ -926,13 +935,12 @@ def _gather(input_, dim=-1, process_group=None, fp8_comm=False, fp8_format="e4m3

     from colossalai.zero.low_level._utils import has_inf_or_nan

-    if fp8_comm:
+    if fp8_communication:
     # if False:
         if has_inf_or_nan(input_):
             print("input has nan")
             exit(0)
         input_type = input_.dtype
-        to_cast.append(input_)
         ret, scale = cast_to_fp8(input_, fp8_format="e5m2")
         if has_inf_or_nan(ret):
             import pdb
@@ -1012,8 +1020,8 @@ def _reduce_scatter(input_, dim=1, process_group=None):
     return output


-def _all_to_all(input_, world_size, group, scatter_dim, gather_dim, fp8_comm=False, fp8_format="e5m2"):
-    if fp8_comm:
+def _all_to_all(input_, world_size, group, scatter_dim, gather_dim, fp8_communication=False, fp8_format="e5m2"):
+    if fp8_communication:
         input_type = input_.dtype
         ret, scale = cast_to_fp8(input_, fp8_format=fp8_format)
         fp8_type = ret.dtype
@@ -1036,7 +1044,9 @@ def _all_to_all(input_, world_size, group, scatter_dim, gather_dim, fp8_comm=Fal
     return torch.cat(output_list, dim=gather_dim).contiguous()


-def _all_to_all_single(input_, seq_world_size, group, scatter_dim, gather_dim, fp8_comm=False, fp8_format="e5m2"):
+def _all_to_all_single(
+    input_, seq_world_size, group, scatter_dim, gather_dim, fp8_communication=False, fp8_format="e5m2"
+):
     inp_shape = list(input_.shape)
     inp_shape[scatter_dim] = inp_shape[scatter_dim] // seq_world_size
     if scatter_dim < 2:
@@ -1048,7 +1058,7 @@ def _all_to_all_single(input_, seq_world_size, group, scatter_dim, gather_dim, f
             .contiguous()
         )

-    if fp8_comm:
+    if fp8_communication:
         input_type = input_t.dtype
         ret, scale = cast_to_fp8(input_t, fp8_format=fp8_format)
         fp8_type = ret.dtype
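The `_all_to_all` and `_all_to_all_single` hunks only rename `fp8_comm` to `fp8_communication` and reflow the signatures, but the mechanism they guard is worth spelling out: the payload crosses the wire as fp8 bytes while each sender's scale is gathered separately, and the receiver needs both to decompress. A rough sketch of that exchange under an already-initialized process group (illustrative, not the library's exact code; it assumes float8 dtypes and an input whose leading dimension divides evenly by the world size):

    import torch
    import torch.distributed as dist

    def all_to_all_fp8_sketch(input_: torch.Tensor, group=None, fp8_format: str = "e5m2") -> torch.Tensor:
        world_size = dist.get_world_size(group)
        fp8_type = torch.float8_e5m2 if fp8_format == "e5m2" else torch.float8_e4m3fn
        orig_dtype = input_.dtype

        # 1. Compress locally: one scale per rank, payload shipped as raw bytes.
        scale = (torch.finfo(fp8_type).max / input_.abs().max().float().clamp(min=1e-12)).reshape(1)
        payload = (input_ * scale).to(fp8_type).view(torch.uint8)

        # 2. Exchange payload chunks and gather every sender's scale.
        in_chunks = list(payload.chunk(world_size, dim=0))
        out_chunks = [torch.empty_like(c) for c in in_chunks]
        dist.all_to_all(out_chunks, in_chunks, group=group)
        scales = [torch.empty_like(scale) for _ in range(world_size)]
        dist.all_gather(scales, scale, group=group)

        # 3. Decompress each received chunk with its sender's scale.
        decoded = [c.view(fp8_type).to(torch.float32) / s for c, s in zip(out_chunks, scales)]
        return torch.cat(decoded, dim=0).to(orig_dtype)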
@@ -1085,10 +1095,8 @@ def matmul_with_async_comm(input_, weight, bias, process_group, async_grad_allre
     )


-def linear_with_async_comm(input_, weight, bias, process_group, async_grad_allreduce, fp8_communication=False):
-    return LinearWithAsyncCommunication.apply(
-        input_, weight, bias, process_group, async_grad_allreduce, fp8_communication
-    )
+def linear_with_async_comm(input_, weight, bias, process_group, async_grad_allreduce):
+    return LinearWithAsyncCommunication.apply(input_, weight, bias, process_group, async_grad_allreduce)


 def linear_gather_forward_reducescatter_backward(
@@ -1099,8 +1107,8 @@ def linear_gather_forward_reducescatter_backward(
     )


-def gather_forward_reducescatter_backward(input_, process_group, dim, fp8_communication=False):
-    return _GatherForwardReduceScatterBackward.apply(input_, process_group, dim, fp8_communication)
+def gather_forward_reducescatter_backward(input_, process_group, dim):
+    return _GatherForwardReduceScatterBackward.apply(input_, process_group, dim)


 def reducescatter_forward_gather_backward(input_, process_group, dim, fp8_communication=False):
@@ -1132,8 +1140,8 @@ def reduce_forward(input_, process_group, fp8_communication=False):


 def reduce_backward(input_, process_group, fp8_communication=False):
-    return _ReduceBackward.apply(input_, process_group, fp8_communication=fp8_communication)
+    return _ReduceBackward.apply(input_, process_group, fp8_communication)


-def all_to_all_comm(input_, process_group=None, scatter_dim=2, gather_dim=1, fp8_comm=False):
-    return _AllToAll.apply(input_, process_group, scatter_dim, gather_dim, fp8_comm)
+def all_to_all_comm(input_, process_group=None, scatter_dim=2, gather_dim=1, fp8_communication=False):
+    return _AllToAll.apply(input_, process_group, scatter_dim, gather_dim, fp8_communication)
@@ -510,9 +510,9 @@ def get_llama_flash_attention_forward(shard_config: ShardConfig, sp_mode=None, s

         # sp: all-to-all comminucation when introducing sequence parallel
         if sp_mode == "all_to_all":
-            query_states = all_to_all_comm(query_states, sp_group, fp8_comm=shard_config.fp8_communication)
-            key_states = all_to_all_comm(key_states, sp_group, fp8_comm=shard_config.fp8_communication)
-            value_states = all_to_all_comm(value_states, sp_group, fp8_comm=shard_config.fp8_communication)
+            query_states = all_to_all_comm(query_states, sp_group)
+            key_states = all_to_all_comm(key_states, sp_group)
+            value_states = all_to_all_comm(value_states, sp_group)
             bsz, q_len, _ = query_states.size()

         query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
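For readers skimming the `all_to_all` branch above: before the call each rank holds every attention head for only its slice of the sequence, and `all_to_all_comm` (with its defaults `scatter_dim=2, gather_dim=1`) swaps that layout so each rank sees the full sequence for a fraction of the heads, which is what the following `.view(bsz, q_len, ...)` relies on. A toy shape walk-through with made-up sizes:

    # Illustrative sizes only: sp_size = 4 ranks, batch = 2, full seq_len = 1024, hidden = 4096.
    # Before the call, one rank holds:   (2, 1024 // 4, 4096) == (2,  256, 4096)
    # all_to_all_comm(query_states, sp_group) scatters dim 2 and gathers dim 1, leaving:
    #                                     (2, 1024, 4096 // 4) == (2, 1024, 1024)
    # so bsz and q_len are re-read afterwards, and the later
    # .view(bsz, q_len, self.num_heads, self.head_dim) works with a per-rank head count.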
@@ -660,11 +660,16 @@ def get_llama_flash_attention_model_forward(shard_config: ShardConfig, sp_mode=N

         if sp_mode in ["ring", "split_gather"]:
             inputs_embeds = split_forward_gather_backward(
-                inputs_embeds, 1, sp_group, fp8_comm=shard_config.fp8_communication
+                inputs_embeds,
+                1,
+                sp_group,
             )
         elif sp_mode == "all_to_all":
             inputs_embeds = split_forward_gather_backward(
-                inputs_embeds, 1, sp_group, 1 / sp_size, fp8_comm=shard_config.fp8_communication
+                inputs_embeds,
+                1,
+                sp_group,
+                1 / sp_size,
             )
         hidden_states = inputs_embeds

@@ -711,11 +716,16 @@ def get_llama_flash_attention_model_forward(shard_config: ShardConfig, sp_mode=N

         if sp_mode == "ring" or sp_mode == "split_gather":
             hidden_states = gather_forward_split_backward(
-                hidden_states, 1, sp_group, fp8_comm=shard_config.fp8_communication
+                hidden_states,
+                1,
+                sp_group,
             )
         elif sp_mode == "all_to_all":
             hidden_states = gather_forward_split_backward(
-                hidden_states, 1, sp_group, grad_scale=sp_size, fp8_comm=shard_config.fp8_communication
+                hidden_states,
+                1,
+                sp_group,
+                grad_scale=sp_size,
             )

         # add hidden states from the last decoder layer
@@ -134,37 +134,37 @@ class LlamaPolicy(Policy):
                 SubModuleReplacementDescription(
                     suffix="self_attn.q_proj",
                     target_module=Linear1D_Col,
-                    kwargs=dict(seq_parallel_mode=sp_mode, fp8_communication=self.shard_config.fp8_communication),
+                    kwargs=dict(seq_parallel_mode=sp_mode),
                 ),
                 SubModuleReplacementDescription(
                     suffix="self_attn.k_proj",
                     target_module=Linear1D_Col,
-                    kwargs=dict(seq_parallel_mode=sp_mode, fp8_communication=self.shard_config.fp8_communication),
+                    kwargs=dict(seq_parallel_mode=sp_mode),
                 ),
                 SubModuleReplacementDescription(
                     suffix="self_attn.v_proj",
                     target_module=Linear1D_Col,
-                    kwargs=dict(seq_parallel_mode=sp_mode, fp8_communication=self.shard_config.fp8_communication),
+                    kwargs=dict(seq_parallel_mode=sp_mode),
                 ),
                 SubModuleReplacementDescription(
                     suffix="self_attn.o_proj",
                     target_module=Linear1D_Row,
-                    kwargs=dict(seq_parallel_mode=sp_mode, fp8_communication=self.shard_config.fp8_communication),
+                    kwargs=dict(seq_parallel_mode=sp_mode),
                 ),
                 SubModuleReplacementDescription(
                     suffix="mlp.gate_proj",
                     target_module=Linear1D_Col,
-                    kwargs=dict(seq_parallel_mode=sp_mode, fp8_communication=self.shard_config.fp8_communication),
+                    kwargs=dict(seq_parallel_mode=sp_mode),
                 ),
                 SubModuleReplacementDescription(
                     suffix="mlp.up_proj",
                     target_module=Linear1D_Col,
-                    kwargs=dict(seq_parallel_mode=sp_mode, fp8_communication=self.shard_config.fp8_communication),
+                    kwargs=dict(seq_parallel_mode=sp_mode),
                 ),
                 SubModuleReplacementDescription(
                     suffix="mlp.down_proj",
                     target_module=Linear1D_Row,
-                    kwargs=dict(seq_parallel_mode=sp_mode, fp8_communication=self.shard_config.fp8_communication),
+                    kwargs=dict(seq_parallel_mode=sp_mode),
                 ),
             ],
         )
@@ -224,10 +224,7 @@ def main():
         # modify the param accordingly for finetuning test cases
         plugin = HybridParallelPlugin(
             tp_size=1,
-            pp_size=1,
-            sp_size=2,
-            enable_sequence_parallelism=True,
-            sequence_parallelism_mode="all_to_all",
+            pp_size=2,
             num_microbatches=None,
             pp_style="interleaved",
             num_model_chunks=2,
@@ -5,7 +5,7 @@ pip install -r requirements.txt

 FAIL_LIMIT=3

-for plugin in "hybrid_parallel"; do
+for plugin in "torch_ddp" "torch_ddp_fp16" "gemini" "low_level_zero" "hybrid_parallel"; do
     for i in $(seq 1 $FAIL_LIMIT); do
         torchrun --standalone --nproc_per_node 4 finetune.py --target_f1 0.86 --plugin $plugin --model_type "bert" && break
         echo "Failed $i times"
@@ -220,9 +220,9 @@ def main():
         plugin = HybridParallelPlugin(
             tp_size=2,
             pp_size=1,
-            sp_size=2,
-            sequence_parallelism_mode="split_gather",
-            enable_sequence_parallelism=True,
+            sp_size=1,
+            # sequence_parallelism_mode="split_gather",
+            # enable_sequence_parallelism=True,
             num_microbatches=None,
             microbatch_size=1,
             enable_all_optimization=True,
@@ -321,7 +321,3 @@ def main():

 if __name__ == "__main__":
     main()
-    if dist.get_rank() == 0:
-        import pdb
-
-        pdb.set_trace()
@@ -51,7 +51,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
     if test_config["precision"] == "fp32":
         atol, rtol = 1e-4, 1e-3
     else:
-        atol, rtol = 5e-2, 5e-2
+        atol, rtol = 5e-3, 5e-3
     col_layer_grads = get_grad_tensors_for_check(
         gpt2,
         sharded_gpt2,
@@ -97,7 +97,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
     if test_config["precision"] == "fp32":
         atol, rtol = 1e-5, 1e-3
     else:
-        atol, rtol = 5e-2, 5e-2
+        atol, rtol = 5e-3, 5e-3

     if org_model.__class__.__name__ == "GPT2Model":
         check_output_hidden_state(org_output, sharded_output, stage_manager, atol=atol, rtol=rtol)
@@ -131,47 +131,17 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
 @parameterize(
     "test_config",
     [
-        # {
-        # "tp_size": 4,
-        # "pp_size": 1,
-        # "num_microbatches": 1,
-        # "enable_sequence_parallelism": True,
-        # "sequence_parallelism_mode": "ring",
-        # "enable_flash_attention": False,
-        # "use_lazy_init": True,
-        # "precision": "fp32",
-        # "initial_scale": 1,
-        # },
-        # {
-        # "tp_size": 4,
-        # "pp_size": 1,
-        # "num_microbatches": 1,
-        # "enable_sequence_parallelism": True,
-        # "sequence_parallelism_mode": "split_gather",
-        # "enable_flash_attention": False,
-        # "use_lazy_init": True,
-        # "precision": "fp16",
-        # "initial_scale": 1,
-        # },
-        # {
-        # "tp_size": 2,
-        # "pp_size": 2,
-        # "num_microbatches": 4,
-        # "enable_all_optimization": True,
-        # "use_lazy_init": True,
-        # "precision": "fp16",
-        # "initial_scale": 1,
-        # },
-        # {
-        # "tp_size": 1,
-        # "pp_size": 2,
-        # "num_microbatches": 2,
-        # "enable_all_optimization": True,
-        # "use_lazy_init": True,
-        # "zero_stage": 1,
-        # "precision": "fp16",
-        # "initial_scale": 1,
-        # },
+        {
+            "tp_size": 4,
+            "pp_size": 1,
+            "num_microbatches": 1,
+            "enable_sequence_parallelism": True,
+            "sequence_parallelism_mode": "ring",
+            "enable_flash_attention": False,
+            "use_lazy_init": True,
+            "precision": "fp32",
+            "initial_scale": 1,
+        },
         {
             "tp_size": 4,
             "pp_size": 1,
@@ -182,7 +152,25 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
             "use_lazy_init": True,
             "precision": "fp16",
             "initial_scale": 1,
-            "fp8_communication": True,
+        },
+        {
+            "tp_size": 2,
+            "pp_size": 2,
+            "num_microbatches": 4,
+            "enable_all_optimization": True,
+            "use_lazy_init": True,
+            "precision": "fp16",
+            "initial_scale": 1,
+        },
+        {
+            "tp_size": 1,
+            "pp_size": 2,
+            "num_microbatches": 2,
+            "enable_all_optimization": True,
+            "use_lazy_init": True,
+            "zero_stage": 1,
+            "precision": "fp16",
+            "initial_scale": 1,
         },
     ],
 )
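Each dictionary in the list above is one complete run of the check: ColossalAI's `@parameterize` decorator invokes the wrapped function once per entry, passing the dict through the named `test_config` argument, so un-commenting a block re-adds that configuration to the sweep and deleting `"fp8_communication": True` removes only that knob. A simplified stand-in for the mechanism (not the real decorator, which lives in `colossalai.testing`):

    def parameterize_sketch(arg_name, values):
        # Simplified: call the wrapped test once per configuration.
        def decorator(fn):
            def runner(*args, **kwargs):
                for value in values:
                    fn(*args, **{**kwargs, arg_name: value})
            return runner
        return decorator

    @parameterize_sketch("test_config", [{"tp_size": 4, "pp_size": 1}, {"tp_size": 2, "pp_size": 2}])
    def run_gpt2_test_sketch(test_config):
        print("would launch a trainer with", test_config)

    run_gpt2_test_sketch()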
@@ -284,4 +272,4 @@ def test_gpt2_3d():

 if __name__ == "__main__":
     test_gpt2()
-    # test_gpt2_3d()
+    test_gpt2_3d()
@@ -34,6 +34,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
     if enable_gradient_checkpointing:
         # org_model.gradient_checkpointing_enable()
         sharded_model.unwrap().gradient_checkpointing_enable()

     org_loss, org_output, sharded_loss, sharded_output = run_forward_backward_with_hybrid_plugin(
         org_model, sharded_model, sharded_optimizer, data_gen_fn, output_transform_fn, criterion, booster
     )
@@ -70,7 +71,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
             )
             grad = grads[grad_index]
             sharded_grad = p1.grad.view(-1).chunk(dist.get_world_size())[dist.get_rank()]
-            assert_close(sharded_grad, grad[: sharded_grad.shape[0]], atol=5e-2, rtol=5e-2, check_dtype=False)
+            assert_close(sharded_grad, grad[: sharded_grad.shape[0]], atol=5e-3, rtol=5e-3, check_dtype=False)

     # Save gradient tensors for comparison between the original model and the sharded model before optimizer step.
     grads_to_check = {}
@@ -108,7 +109,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
     if test_config["precision"] == "fp32":
         atol, rtol = 1e-5, 1e-3
     else:
-        atol, rtol = 5e-2, 5e-2
+        atol, rtol = 5e-3, 5e-3

     if org_model.__class__.__name__ == "LlamaModel":
         check_output_hidden_state(org_output, sharded_output, stage_manager, atol=atol, rtol=rtol)
@@ -120,7 +121,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
     if test_config["precision"] == "fp32":
         atol, rtol = 1e-4, 1e-3
     else:
-        atol, rtol = 5e-2, 5e-2
+        atol, rtol = 5e-3, 5e-3
     try:
         check_weight(
             llama_model,
@@ -145,117 +146,31 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
 @parameterize(
     "test_config",
     [
-        # { # Test ring + Flash attention
-        # "tp_size": 2,
-        # "pp_size": 1,
-        # "sp_size": 2,
-        # "num_microbatches": 1,
-        # "enable_sequence_parallelism": True,
-        # "sequence_parallelism_mode": "ring",
-        # "enable_flash_attention": True,
-        # "use_lazy_init": True,
-        # "zero_stage": 2,
-        # "precision": "fp16",
-        # "initial_scale": 1,
-        # },
-        # { # Ulysess + Flash attention
-        # "tp_size": 1,
-        # "pp_size": 2,
-        # "sp_size": 2,
-        # "num_microbatches": 2,
-        # "enable_sequence_parallelism": True,
-        # "sequence_parallelism_mode": "all_to_all",
-        # "enable_flash_attention": True,
-        # "use_lazy_init": True,
-        # "zero_stage": 1,
-        # "precision": "fp16",
-        # "initial_scale": 1,
-        # },
-        # {
-        # "tp_size": 1,
-        # "pp_size": 1,
-        # "sp_size": 2,
-        # "num_microbatches": 1,
-        # "enable_sequence_parallelism": True,
-        # "sequence_parallelism_mode": "all_to_all",
-        # "use_lazy_init": True,
-        # "zero_stage": 1,
-        # "precision": "fp16",
-        # "initial_scale": 1,
-        # },
-        # {
-        # "tp_size": 4,
-        # "pp_size": 1,
-        # "num_microbatches": 1,
-        # "enable_sequence_parallelism": True,
-        # "sequence_parallelism_mode": "split_gather",
-        # "enable_flash_attention": False,
-        # "use_lazy_init": True,
-        # "precision": "fp16",
-        # "initial_scale": 1,
-        # },
-        # {
-        # "tp_size": 2,
-        # "pp_size": 2,
-        # "num_microbatches": 2,
-        # "enable_all_optimization": True,
-        # "use_lazy_init": True,
-        # "precision": "fp16",
-        # "initial_scale": 1,
-        # "enable_gradient_checkpointing": True,
-        # "gradient_checkpoint_config": PipelineGradientCheckpointConfig(gradient_checkpointing_ratio=0.5),
-        # },
-        # {
-        # "tp_size": 1,
-        # "pp_size": 2,
-        # "num_microbatches": 4,
-        # "use_lazy_init": False,
-        # "precision": "fp32",
-        # "enable_gradient_checkpointing": True,
-        # "gradient_checkpoint_config": PipelineGradientCheckpointConfig(num_ckpt_layers_per_stage=[4, 0]),
-        # },
-        # {
-        # "tp_size": 2,
-        # "pp_size": 1,
-        # "enable_all_optimization": True,
-        # "use_lazy_init": True,
-        # "zero_stage": 2,
-        # "precision": "fp16",
-        # "initial_scale": 1,
-        # },
-        # {
-        # "tp_size": 1,
-        # "pp_size": 2,
-        # "num_microbatches": 2,
-        # "enable_all_optimization": True,
-        # "use_lazy_init": True,
-        # "zero_stage": 1,
-        # "precision": "fp16",
-        # "initial_scale": 1,
-        # },
-        {
+        { # Test ring + Flash attention
             "tp_size": 2,
             "pp_size": 1,
             "sp_size": 2,
             "num_microbatches": 1,
             "enable_sequence_parallelism": True,
-            "sequence_parallelism_mode": "split_gather",
+            "sequence_parallelism_mode": "ring",
+            "enable_flash_attention": True,
             "use_lazy_init": True,
-            "zero_stage": 1,
+            "zero_stage": 2,
             "precision": "fp16",
             "initial_scale": 1,
-            "fp8_communication": True,
         },
-        {
-            "tp_size": 2,
-            "pp_size": 1,
-            "num_microbatches": 1,
-            "enable_sequence_parallelism": False,
+        { # Ulysess + Flash attention
+            "tp_size": 1,
+            "pp_size": 2,
+            "sp_size": 2,
+            "num_microbatches": 2,
+            "enable_sequence_parallelism": True,
+            "sequence_parallelism_mode": "all_to_all",
+            "enable_flash_attention": True,
             "use_lazy_init": True,
             "zero_stage": 1,
             "precision": "fp16",
             "initial_scale": 1,
-            "fp8_communication": True,
         },
         {
             "tp_size": 1,
@@ -268,18 +183,67 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
             "zero_stage": 1,
             "precision": "fp16",
             "initial_scale": 1,
-            "fp8_communication": True,
+        },
+        {
+            "tp_size": 4,
+            "pp_size": 1,
+            "num_microbatches": 1,
+            "enable_sequence_parallelism": True,
+            "sequence_parallelism_mode": "split_gather",
+            "enable_flash_attention": False,
+            "use_lazy_init": True,
+            "precision": "fp16",
+            "initial_scale": 1,
+        },
+        {
+            "tp_size": 2,
+            "pp_size": 2,
+            "num_microbatches": 2,
+            "enable_all_optimization": True,
+            "use_lazy_init": True,
+            "precision": "fp16",
+            "initial_scale": 1,
+            "enable_gradient_checkpointing": True,
+            "gradient_checkpoint_config": PipelineGradientCheckpointConfig(gradient_checkpointing_ratio=0.5),
+        },
+        {
+            "tp_size": 1,
+            "pp_size": 2,
+            "num_microbatches": 4,
+            "use_lazy_init": False,
+            "precision": "fp32",
+            "enable_gradient_checkpointing": True,
+            "gradient_checkpoint_config": PipelineGradientCheckpointConfig(num_ckpt_layers_per_stage=[4, 0]),
+        },
+        {
+            "tp_size": 2,
+            "pp_size": 1,
+            "enable_all_optimization": True,
+            "use_lazy_init": True,
+            "zero_stage": 2,
+            "precision": "fp16",
+            "initial_scale": 1,
+        },
+        {
+            "tp_size": 1,
+            "pp_size": 2,
+            "num_microbatches": 2,
+            "enable_all_optimization": True,
+            "use_lazy_init": True,
+            "zero_stage": 1,
+            "precision": "fp16",
+            "initial_scale": 1,
         },
     ],
 )
 def run_llama_test(test_config):
-    sub_model_zoo = model_zoo.get_sub_registry("transformers_llama_for_sequence_classification")
+    sub_model_zoo = model_zoo.get_sub_registry("transformers_llama")

     for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
         try:
             check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config)
         except Exception as e:
-            print(f"Failed config out: {test_config}")
+            print(f"Failed config: {test_config}")
             raise e

     clear_layout_converter()
@@ -327,7 +291,7 @@ def run_llama_test(test_config):
     ],
 )
 def run_llama_3d_test(test_config):
-    sub_model_zoo = model_zoo.get_sub_registry("transformers_llama_for_sequence_classification")
+    sub_model_zoo = model_zoo.get_sub_registry("transformers_llama")

     for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
         try:
@@ -369,4 +333,4 @@ def test_llama_3d():

 if __name__ == "__main__":
     test_llama()
-    # test_llama_3d()
+    test_llama_3d()