mirror of https://github.com/hpcaitech/ColossalAI
[Feature] llama shardformer fp8 support (#5938)
* add llama shardformer fp8 * Llama Shardformer Parity * fix typo * fix all reduce * fix pytest failure * fix reduce op and move function to fp8.py * fix typopull/5963/head
parent
c297e21bea
commit
53cb9606bd
|
@ -3,9 +3,11 @@ from typing import Any
|
|||
import numpy as np
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
import torch.nn.functional as F
|
||||
from torch.distributed import ReduceOp
|
||||
|
||||
|
||||
def cast_to_fp8(inp: torch.Tensor, fp8_format="e4m3") -> (torch.Tensor, torch.Tensor):
|
||||
def cast_to_fp8(inp: torch.Tensor, fp8_format="e4m3", per_channel_scale=False) -> (torch.Tensor, torch.Tensor):
|
||||
r"""
|
||||
casting torch Tensor into specified fp8 tensor with per-channel scaling or per-tensor scaling.
|
||||
Args:
|
||||
|
@ -23,7 +25,7 @@ def cast_to_fp8(inp: torch.Tensor, fp8_format="e4m3") -> (torch.Tensor, torch.Te
|
|||
fp8_type = torch.float8_e4m3fn if fp8_format == "e4m3" else torch.float8_e5m2
|
||||
fp8_max = torch.finfo(fp8_type).max
|
||||
|
||||
if inp.dim() == 2:
|
||||
if per_channel_scale:
|
||||
per_channel_max = inp.abs().max(dim=-1).values.float()
|
||||
per_channel_max = torch.where(per_channel_max > 0, per_channel_max, 1.0)
|
||||
scale = fp8_max / per_channel_max[:, None]
|
||||
|
@ -37,7 +39,9 @@ def cast_to_fp8(inp: torch.Tensor, fp8_format="e4m3") -> (torch.Tensor, torch.Te
|
|||
return ret, scale_inv
|
||||
|
||||
|
||||
def cast_from_fp8(inp: torch.Tensor, scale_inv: torch.Tensor, ret_type: torch.dtype) -> torch.Tensor:
|
||||
def cast_from_fp8(
|
||||
inp: torch.Tensor, scale_inv: torch.Tensor, ret_type: torch.dtype, per_channel_scale=False
|
||||
) -> torch.Tensor:
|
||||
r"""
|
||||
Args:
|
||||
inp: should be a fp8 torch tensor in one of the types: [torch.float8_e4m3fn, torch.float8_e5m2].
|
||||
|
@ -49,20 +53,23 @@ def cast_from_fp8(inp: torch.Tensor, scale_inv: torch.Tensor, ret_type: torch.dt
|
|||
if inp.dtype not in [torch.float8_e4m3fn, torch.float8_e5m2]:
|
||||
raise TypeError("Only float8_e4m3fn and float8_e5m2 are allowed.")
|
||||
|
||||
if inp.dim() == 2:
|
||||
if per_channel_scale:
|
||||
ret = scale_inv[:, None] * inp.float()
|
||||
else:
|
||||
ret = scale_inv * inp.float()
|
||||
return ret.to(ret_type)
|
||||
|
||||
|
||||
def all_reduce_fp8(tensor: torch.Tensor, fp8_format="e5m2", group=None) -> None:
|
||||
def all_reduce_fp8(tensor: torch.Tensor, fp8_format="e4m3", op=ReduceOp.SUM, group=None) -> None:
|
||||
r"""
|
||||
This is an in-place operation for compressed all_reduce using fp8.
|
||||
It works like dist.all_reduce but during communication the data is cast to fp8 format.
|
||||
|
||||
Args:
|
||||
tensor: torch.Tensor in fp32, fp16, bf16 datatype.
|
||||
fp8_format: e4m3 or e5m2
|
||||
op: ReduceOp.SUM or ReduceOp.AVG
|
||||
|
||||
Returns:
|
||||
None
|
||||
"""
|
||||
|
@ -72,18 +79,20 @@ def all_reduce_fp8(tensor: torch.Tensor, fp8_format="e5m2", group=None) -> None:
|
|||
input_shape = tensor.shape
|
||||
input_device = tensor.device
|
||||
input_size = tensor.numel()
|
||||
tensor = tensor.flatten()
|
||||
flat_padded_x = tensor.flatten()
|
||||
|
||||
assert op in [ReduceOp.SUM, ReduceOp.AVG], "op can only be ReduceOp.SUM or ReduceOp.AVG"
|
||||
|
||||
if flat_padded_x.size(0) % world_size != 0:
|
||||
pad_size = world_size - flat_padded_x.size(0) % world_size
|
||||
flat_padded_x = F.pad(flat_padded_x, (0, pad_size))
|
||||
|
||||
fp8_type = torch.float8_e4m3fn if fp8_format == "e4m3" else torch.float8_e5m2
|
||||
|
||||
ret, scale = cast_to_fp8(tensor, fp8_format=fp8_format)
|
||||
ret, scale = cast_to_fp8(flat_padded_x, fp8_format=fp8_format)
|
||||
|
||||
inp = ret.view(torch.uint8)
|
||||
input_chunks = list(torch.chunk(inp, world_size, dim=0))
|
||||
if dist.get_rank() == world_size - 1:
|
||||
output_chunks = [torch.empty_like(input_chunks[-1]) for _ in range(world_size)]
|
||||
else:
|
||||
output_chunks = [torch.empty_like(input_chunks[0]) for _ in range(world_size)]
|
||||
output_chunks = list(torch.chunk(torch.empty_like(inp), world_size, dim=0))
|
||||
dist.all_to_all(output_chunks, input_chunks, group=group)
|
||||
scale_list = [torch.ones(1, dtype=scale.dtype, device=input_device) for _ in range(world_size)]
|
||||
dist.all_gather(scale_list, scale, group=group)
|
||||
|
@ -92,15 +101,18 @@ def all_reduce_fp8(tensor: torch.Tensor, fp8_format="e5m2", group=None) -> None:
|
|||
out = out.view(fp8_type)
|
||||
summed_out += cast_from_fp8(out, scale, input_type)
|
||||
|
||||
if op == ReduceOp.AVG:
|
||||
summed_out.div_(world_size)
|
||||
|
||||
summed_out_fp8, scale = cast_to_fp8(summed_out, fp8_format=fp8_format)
|
||||
dist.all_gather(scale_list, scale, group=group)
|
||||
|
||||
tensor_list = list(torch.chunk(torch.empty(input_size, device=input_device, dtype=torch.uint8), world_size, dim=0))
|
||||
tensor_list = [torch.empty_like(summed_out_fp8.view(torch.uint8)) for _ in range(world_size)]
|
||||
dist.all_gather(tensor_list, summed_out_fp8.view(torch.uint8), group=group)
|
||||
for i in range(world_size):
|
||||
tensor_list[i] = tensor_list[i].view(fp8_type).to(input_type) * scale_list[i]
|
||||
tensor_out = torch.cat(tensor_list, dim=0)
|
||||
tensor.data = tensor_out.view(input_shape).to(input_type)
|
||||
out = torch.cat(tensor_list, dim=0)
|
||||
tensor.copy_(out[:input_size].view(input_shape).to(input_type))
|
||||
|
||||
|
||||
def cast_to_fp8_pipeline(inp: Any) -> None:
|
||||
|
@ -276,5 +288,74 @@ def all_gather_into_tensor_flat_fp8(
|
|||
dist.all_gather_into_tensor(buffer.view(torch.uint8), fp8_input.view(torch.uint8), group=group)
|
||||
numel = np.prod(output_shape)
|
||||
valid_buffer = buffer[:numel].reshape(output_shape)
|
||||
valid_buffer = cast_from_fp8(valid_buffer, scale_inv, input_type)
|
||||
valid_buffer = cast_from_fp8(valid_buffer, scale_inv, input_type, per_channel_scale=(len(output_shape) == 2))
|
||||
output_tensor[:numel].copy_(valid_buffer.view(-1))
|
||||
|
||||
|
||||
def all_to_all_fp8(output_list, input_list, group=None, fp8_format="e5m2"):
|
||||
|
||||
world_size = dist.get_world_size(group)
|
||||
|
||||
input_type = input_list[0].dtype
|
||||
fp8_type = torch.float8_e4m3fn if fp8_format == "e4m3" else torch.float8_e5m2
|
||||
scale_list = []
|
||||
tensor_list = []
|
||||
|
||||
for i in range(world_size):
|
||||
input_tensor = input_list[i]
|
||||
ret, scale = cast_to_fp8(input_tensor, fp8_format=fp8_format)
|
||||
scale_list.append(scale)
|
||||
ret = ret.view(torch.uint8)
|
||||
tensor_list.append(ret)
|
||||
|
||||
output_scale_list = [torch.empty_like(x) for x in scale_list]
|
||||
output_tensor_list = [torch.empty_like(x) for x in tensor_list]
|
||||
dist.all_to_all(output_tensor_list, tensor_list, group=group)
|
||||
dist.all_to_all(output_scale_list, scale_list, group=group)
|
||||
|
||||
for i in range(world_size):
|
||||
scale = output_scale_list[i]
|
||||
tensor = output_tensor_list[i]
|
||||
tensor = tensor.view(fp8_type)
|
||||
output_list[i].copy_(cast_from_fp8(tensor, scale, input_type))
|
||||
|
||||
|
||||
def all_to_all_single_fp8(output_tensor, input_tensor, group=None, fp8_format="e5m2"):
|
||||
|
||||
world_size = dist.get_world_size(group)
|
||||
|
||||
per_slice_len = input_tensor.size(0) // world_size
|
||||
input_type = input_tensor.dtype
|
||||
ret, scale = cast_to_fp8(input_tensor, fp8_format=fp8_format)
|
||||
fp8_type = ret.dtype
|
||||
input_tensor = ret.view(torch.uint8)
|
||||
tensor = torch.empty_like(input_tensor)
|
||||
scale_list = [torch.empty_like(scale) for _ in range(world_size)]
|
||||
dist.all_to_all_single(tensor, input_tensor, group=group)
|
||||
dist.all_gather(scale_list, scale, group=group)
|
||||
cast_tensor_list = []
|
||||
|
||||
for i in range(world_size):
|
||||
output_part = tensor[per_slice_len * i : per_slice_len * (i + 1)].view(fp8_type)
|
||||
output_part = cast_from_fp8(output_part, scale_list[i], input_type)
|
||||
cast_tensor_list.append(output_part)
|
||||
output_tensor.copy_(torch.concatenate(cast_tensor_list, dim=0))
|
||||
|
||||
|
||||
def gather_fp8(output_list, input_, group=None, fp8_format="e5m2"):
|
||||
|
||||
world_size = dist.get_world_size(group)
|
||||
|
||||
input_type = input_.dtype
|
||||
ret, scale = cast_to_fp8(input_, fp8_format=fp8_format)
|
||||
fp8_type = ret.dtype
|
||||
input_ = ret.view(torch.uint8)
|
||||
tensor_list = [torch.empty_like(input_) for _ in range(world_size)]
|
||||
scale_list = [torch.ones(1, dtype=scale.dtype, device=input_.device) for _ in range(world_size)]
|
||||
dist.all_gather(tensor_list, input_, group=group)
|
||||
dist.all_gather(scale_list, scale, group=group)
|
||||
|
||||
for i in range(world_size):
|
||||
output = tensor_list[i].view(fp8_type)
|
||||
scale = scale_list[i]
|
||||
output_list[i].copy_(cast_from_fp8(output, scale, input_type))
|
||||
|
|
|
@ -14,7 +14,13 @@ try:
|
|||
except ImportError:
|
||||
_grad_accum_fusion_available = False
|
||||
|
||||
from colossalai.quantization.fp8 import all_reduce_fp8, cast_from_fp8, cast_to_fp8, reduce_scatter_fp8
|
||||
from colossalai.quantization.fp8 import (
|
||||
all_reduce_fp8,
|
||||
all_to_all_fp8,
|
||||
all_to_all_single_fp8,
|
||||
gather_fp8,
|
||||
reduce_scatter_fp8,
|
||||
)
|
||||
|
||||
|
||||
class FusedLayerNormAffineFunction1D(torch.autograd.Function):
|
||||
|
@ -117,11 +123,12 @@ class LinearWithAsyncCommunication(torch.autograd.Function):
|
|||
"""
|
||||
|
||||
@staticmethod
|
||||
def forward(ctx, input_, weight, bias, process_group, async_grad_allreduce):
|
||||
def forward(ctx, input_, weight, bias, process_group, async_grad_allreduce, fp8_communication=False):
|
||||
ctx.save_for_backward(input_, weight, bias)
|
||||
ctx.use_bias = bias is not None
|
||||
ctx.process_group = process_group
|
||||
ctx.async_grad_allreduce = async_grad_allreduce
|
||||
ctx.fp8_communication = fp8_communication
|
||||
if bias is not None:
|
||||
output = F.linear(input_, weight, bias)
|
||||
else:
|
||||
|
@ -133,6 +140,7 @@ class LinearWithAsyncCommunication(torch.autograd.Function):
|
|||
def backward(ctx, grad_output):
|
||||
input, weight, bias = ctx.saved_tensors
|
||||
use_bias = ctx.use_bias
|
||||
fp8_communication = ctx.fp8_communication
|
||||
|
||||
# In order to be hooked into Gemini's '__torch_function__', adding a view operation to bias.
|
||||
if use_bias:
|
||||
|
@ -148,7 +156,10 @@ class LinearWithAsyncCommunication(torch.autograd.Function):
|
|||
|
||||
if ctx.async_grad_allreduce:
|
||||
# Asynchronous all-reduce
|
||||
handle = dist.all_reduce(grad_input, group=ctx.process_group, async_op=True)
|
||||
if fp8_communication:
|
||||
all_reduce_fp8(grad_input, group=ctx.process_group)
|
||||
else:
|
||||
handle = dist.all_reduce(grad_input, group=ctx.process_group, async_op=True)
|
||||
# Relay on CUDA_DEVICE_MAX_CONNECTIONS=1 to have
|
||||
# all-reduce scheduled first and have GPU resources allocated, CUDA_DEVICE_MAX_CONNECTIONS=1 is set in shardformer.py
|
||||
|
||||
|
@ -167,10 +178,10 @@ class LinearWithAsyncCommunication(torch.autograd.Function):
|
|||
|
||||
grad_bias = grad_output.sum(dim=0) if use_bias else None
|
||||
|
||||
if ctx.async_grad_allreduce:
|
||||
if ctx.async_grad_allreduce and not fp8_communication:
|
||||
handle.wait()
|
||||
|
||||
return grad_input, grad_weight, grad_bias, None, None, None
|
||||
return grad_input, grad_weight, grad_bias, None, None, None, None
|
||||
|
||||
|
||||
def _ring_as_gather(func, input_to_gather=None, input_local=None, process_group=None, gather_dim=1, keep_item=False):
|
||||
|
@ -238,16 +249,18 @@ class _GatherForwardReduceScatterBackward(torch.autograd.Function):
|
|||
"""
|
||||
|
||||
@staticmethod
|
||||
def forward(ctx, input_, process_group, dim):
|
||||
def forward(ctx, input_, process_group, dim, fp8_communication=False):
|
||||
ctx.process_group = process_group
|
||||
ctx.dim = dim
|
||||
ctx.fp8_communication = fp8_communication
|
||||
|
||||
return _gather(input_, dim, process_group)
|
||||
return _gather(input_, dim, process_group, fp8_communication, fp8_format="e4m3")
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, grad_output):
|
||||
dim = ctx.dim
|
||||
process_group = ctx.process_group
|
||||
fp8_communication = ctx.fp8_communication
|
||||
# do reduce-scatter
|
||||
new_shape = list(grad_output.shape)
|
||||
assert (
|
||||
|
@ -259,9 +272,12 @@ class _GatherForwardReduceScatterBackward(torch.autograd.Function):
|
|||
]
|
||||
output = torch.empty(new_shape, dtype=grad_output.dtype, device=grad_output.device)
|
||||
|
||||
dist.reduce_scatter(output, grad_list, group=process_group)
|
||||
if fp8_communication:
|
||||
reduce_scatter_fp8(output, grad_list, group=process_group, fp8_format="e5m2")
|
||||
else:
|
||||
dist.reduce_scatter(output, grad_list, group=process_group)
|
||||
|
||||
return output, None, None
|
||||
return output, None, None, None
|
||||
|
||||
|
||||
class _LinearWithGatherForwardReduceScatterBackward(torch.autograd.Function):
|
||||
|
@ -577,12 +593,8 @@ class _ReduceScatterForwardGatherBackward(torch.autograd.Function):
|
|||
dim = ctx.dim
|
||||
process_group = ctx.process_group
|
||||
fp8_communication = ctx.fp8_communication
|
||||
return (
|
||||
_gather(grad_output, dim, process_group, fp8_communication=fp8_communication, fp8_format="e5m2"),
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
)
|
||||
|
||||
return _gather(grad_output, dim, process_group, fp8_communication, fp8_format="e5m2"), None, None, None
|
||||
|
||||
|
||||
class _MatmulWithGatherForwardReduceScatterBackward(torch.autograd.Function):
|
||||
|
@ -816,26 +828,67 @@ class _AllToAll(torch.autograd.Function):
|
|||
"""
|
||||
|
||||
@staticmethod
|
||||
def forward(ctx, input_, process_group, scatter_dim, gather_dim):
|
||||
def forward(ctx, input_, process_group, scatter_dim, gather_dim, fp8_communication=False):
|
||||
ctx.process_group = process_group
|
||||
ctx.scatter_dim = scatter_dim
|
||||
ctx.gather_dim = gather_dim
|
||||
ctx.fp8_communication = fp8_communication
|
||||
world_size = dist.get_world_size(process_group)
|
||||
bsz, _, _ = input_.shape
|
||||
|
||||
# using all_to_all_single when batch size is 1
|
||||
if bsz == 1:
|
||||
return _all_to_all_single(input_, world_size, process_group, scatter_dim, gather_dim)
|
||||
return _all_to_all_single(
|
||||
input_,
|
||||
world_size,
|
||||
process_group,
|
||||
scatter_dim,
|
||||
gather_dim,
|
||||
fp8_communication=fp8_communication,
|
||||
fp8_format="e4m3",
|
||||
)
|
||||
else:
|
||||
return _all_to_all(input_, world_size, process_group, scatter_dim, gather_dim)
|
||||
return _all_to_all(
|
||||
input_,
|
||||
world_size,
|
||||
process_group,
|
||||
scatter_dim,
|
||||
gather_dim,
|
||||
fp8_communication=fp8_communication,
|
||||
fp8_format="e4m3",
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, *grad_output):
|
||||
def backward(ctx, grad_output):
|
||||
process_group = ctx.process_group
|
||||
scatter_dim = ctx.gather_dim
|
||||
gather_dim = ctx.scatter_dim
|
||||
return_grad = _AllToAll.apply(*grad_output, process_group, scatter_dim, gather_dim)
|
||||
return (return_grad, None, None, None)
|
||||
fp8_communication = ctx.fp8_communication
|
||||
world_size = dist.get_world_size(process_group)
|
||||
bsz, _, _ = grad_output.shape
|
||||
|
||||
if bsz == 1:
|
||||
return_grad = _all_to_all_single(
|
||||
grad_output,
|
||||
world_size,
|
||||
process_group,
|
||||
scatter_dim,
|
||||
gather_dim,
|
||||
fp8_communication=fp8_communication,
|
||||
fp8_format="e5m2",
|
||||
)
|
||||
else:
|
||||
return_grad = _all_to_all(
|
||||
grad_output,
|
||||
world_size,
|
||||
process_group,
|
||||
scatter_dim,
|
||||
gather_dim,
|
||||
fp8_communication=fp8_communication,
|
||||
fp8_format="e5m2",
|
||||
)
|
||||
|
||||
return (return_grad, None, None, None, None)
|
||||
|
||||
|
||||
class HookParameter(torch.autograd.Function):
|
||||
|
@ -899,33 +952,14 @@ def _gather(input_, dim=-1, process_group=None, fp8_communication=False, fp8_for
|
|||
if world_size == 1:
|
||||
return input_
|
||||
|
||||
input_ = input_.contiguous()
|
||||
tensor_list = [torch.empty_like(input_) for _ in range(world_size)]
|
||||
if fp8_communication:
|
||||
input_type = input_.dtype
|
||||
ret, scale = cast_to_fp8(input_, fp8_format=fp8_format)
|
||||
fp8_type = ret.dtype
|
||||
input_ = ret.view(torch.uint8)
|
||||
input_ = input_.contiguous()
|
||||
tensor_list = [torch.empty_like(input_) for _ in range(world_size)]
|
||||
scale = torch.tensor(scale, dtype=torch.float32).to(input_.device)
|
||||
scale_list = [torch.ones(1, dtype=torch.float32, device=input_.device) for _ in range(world_size)]
|
||||
|
||||
scale = torch.tensor(scale).to(input_.device)
|
||||
torch.distributed.all_gather(tensor_list, input_, group=process_group)
|
||||
torch.distributed.all_gather(scale_list, scale, group=process_group)
|
||||
|
||||
cast_tensor_list = []
|
||||
for output, scale in zip(tensor_list, scale_list):
|
||||
output = output.view(fp8_type)
|
||||
output = cast_from_fp8(output, scale, input_type)
|
||||
cast_tensor_list.append(output)
|
||||
|
||||
output = torch.cat(cast_tensor_list, dim=dim).contiguous()
|
||||
|
||||
gather_fp8(tensor_list, input_, fp8_format=fp8_format, group=process_group)
|
||||
else:
|
||||
input_ = input_.contiguous()
|
||||
tensor_list = [torch.empty_like(input_) for _ in range(world_size)]
|
||||
torch.distributed.all_gather(tensor_list, input_, group=process_group)
|
||||
output = torch.cat(tensor_list, dim=dim).contiguous()
|
||||
dist.all_gather(tensor_list, input_, group=process_group)
|
||||
|
||||
output = torch.cat(tensor_list, dim=dim).contiguous()
|
||||
|
||||
return output
|
||||
|
||||
|
@ -954,14 +988,19 @@ def _reduce_scatter(input_, dim=1, process_group=None):
|
|||
return output
|
||||
|
||||
|
||||
def _all_to_all(input_, world_size, group, scatter_dim, gather_dim):
|
||||
def _all_to_all(input_, world_size, group, scatter_dim, gather_dim, fp8_communication=False, fp8_format="e5m2"):
|
||||
input_list = [t.contiguous() for t in torch.tensor_split(input_, world_size, scatter_dim)]
|
||||
output_list = [torch.empty_like(input_list[0]) for _ in range(world_size)]
|
||||
dist.all_to_all(output_list, input_list, group=group)
|
||||
if fp8_communication:
|
||||
all_to_all_fp8(output_list, input_list, group=group, fp8_format=fp8_format)
|
||||
else:
|
||||
dist.all_to_all(output_list, input_list, group=group)
|
||||
return torch.cat(output_list, dim=gather_dim).contiguous()
|
||||
|
||||
|
||||
def _all_to_all_single(input_, seq_world_size, group, scatter_dim, gather_dim):
|
||||
def _all_to_all_single(
|
||||
input_, seq_world_size, group, scatter_dim, gather_dim, fp8_communication=False, fp8_format="e5m2"
|
||||
):
|
||||
inp_shape = list(input_.shape)
|
||||
inp_shape[scatter_dim] = inp_shape[scatter_dim] // seq_world_size
|
||||
if scatter_dim < 2:
|
||||
|
@ -974,7 +1013,11 @@ def _all_to_all_single(input_, seq_world_size, group, scatter_dim, gather_dim):
|
|||
)
|
||||
|
||||
output = torch.empty_like(input_t)
|
||||
dist.all_to_all_single(output, input_t, group=group)
|
||||
if fp8_communication:
|
||||
all_to_all_single_fp8(output, input_t, group=group, fp8_format=fp8_format)
|
||||
else:
|
||||
|
||||
dist.all_to_all_single(output, input_t, group=group)
|
||||
|
||||
if scatter_dim < 2:
|
||||
output = output.transpose(0, 1).contiguous()
|
||||
|
@ -994,8 +1037,10 @@ def matmul_with_async_comm(input_, weight, bias, process_group, async_grad_allre
|
|||
)
|
||||
|
||||
|
||||
def linear_with_async_comm(input_, weight, bias, process_group, async_grad_allreduce):
|
||||
return LinearWithAsyncCommunication.apply(input_, weight, bias, process_group, async_grad_allreduce)
|
||||
def linear_with_async_comm(input_, weight, bias, process_group, async_grad_allreduce, fp8_communication=False):
|
||||
return LinearWithAsyncCommunication.apply(
|
||||
input_, weight, bias, process_group, async_grad_allreduce, fp8_communication
|
||||
)
|
||||
|
||||
|
||||
def linear_gather_forward_reducescatter_backward(
|
||||
|
@ -1006,8 +1051,8 @@ def linear_gather_forward_reducescatter_backward(
|
|||
)
|
||||
|
||||
|
||||
def gather_forward_reducescatter_backward(input_, process_group, dim):
|
||||
return _GatherForwardReduceScatterBackward.apply(input_, process_group, dim)
|
||||
def gather_forward_reducescatter_backward(input_, process_group, dim, fp8_communication=False):
|
||||
return _GatherForwardReduceScatterBackward.apply(input_, process_group, dim, fp8_communication)
|
||||
|
||||
|
||||
def reducescatter_forward_gather_backward(input_, process_group, dim, fp8_communication=False):
|
||||
|
@ -1042,5 +1087,5 @@ def reduce_backward(input_, process_group, fp8_communication=False):
|
|||
return _ReduceBackward.apply(input_, process_group, fp8_communication)
|
||||
|
||||
|
||||
def all_to_all_comm(input_, process_group=None, scatter_dim=2, gather_dim=1):
|
||||
return _AllToAll.apply(input_, process_group, scatter_dim, gather_dim)
|
||||
def all_to_all_comm(input_, process_group=None, scatter_dim=2, gather_dim=1, fp8_communication=False):
|
||||
return _AllToAll.apply(input_, process_group, scatter_dim, gather_dim, fp8_communication)
|
||||
|
|
|
@ -84,6 +84,7 @@ class Linear1D_Col(ParallelModule):
|
|||
bias_: Optional[Parameter] = None,
|
||||
weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)),
|
||||
bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1),
|
||||
fp8_communication: bool = False,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(weight=weight, bias_=bias_, **kwargs)
|
||||
|
@ -98,6 +99,7 @@ class Linear1D_Col(ParallelModule):
|
|||
self.skip_bias_add = skip_bias_add
|
||||
self.device = device
|
||||
self.process_group = process_group
|
||||
self.fp8_communication = fp8_communication
|
||||
|
||||
if skip_bias_add and not bias:
|
||||
raise ValueError("cannot skip bias addition if bias is None")
|
||||
|
@ -201,10 +203,12 @@ class Linear1D_Col(ParallelModule):
|
|||
bias = self.bias if not self.skip_bias_add else None
|
||||
|
||||
if self.seq_parallel_mode is None:
|
||||
output_parallel = linear_with_async_comm(input_parallel, self.weight, bias, self.process_group, True)
|
||||
output_parallel = linear_with_async_comm(
|
||||
input_parallel, self.weight, bias, self.process_group, True, fp8_communication=self.fp8_communication
|
||||
)
|
||||
elif self.seq_parallel_mode == "split_gather":
|
||||
input_parallel = gather_forward_reducescatter_backward(
|
||||
input_parallel, self.process_group, self.seq_parallel_dim
|
||||
input_parallel, self.process_group, self.seq_parallel_dim, fp8_communication=self.fp8_communication
|
||||
)
|
||||
output_parallel = linear_with_async_comm(input_parallel, self.weight, bias, self.process_group, False)
|
||||
elif self.seq_parallel_mode == "ring":
|
||||
|
@ -214,7 +218,9 @@ class Linear1D_Col(ParallelModule):
|
|||
|
||||
if self.gather_output:
|
||||
# All-gather across the partitions.
|
||||
output = gather_forward_split_backward(output_parallel, dim=-1, process_group=self.process_group)
|
||||
output = gather_forward_split_backward(
|
||||
output_parallel, dim=-1, process_group=self.process_group, fp8_communication=self.fp8_communication
|
||||
)
|
||||
else:
|
||||
output = output_parallel
|
||||
|
||||
|
@ -264,6 +270,7 @@ class Linear1D_Row(ParallelModule):
|
|||
weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)),
|
||||
bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1),
|
||||
stream_chunk_num: int = 1,
|
||||
fp8_communication: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
|
@ -278,6 +285,7 @@ class Linear1D_Row(ParallelModule):
|
|||
self.seq_parallel_mode = seq_parallel_mode
|
||||
self.seq_parallel_dim = seq_parallel_dim
|
||||
self.num_partitions = dist.get_world_size(self.process_group)
|
||||
self.fp8_communication = fp8_communication
|
||||
|
||||
if skip_bias_add and not bias:
|
||||
raise ValueError("cannot skip bias addition if bias is None")
|
||||
|
@ -398,7 +406,9 @@ class Linear1D_Row(ParallelModule):
|
|||
), "Invalid shapes in Linear1D_Row forward: input={}, weight={}. Expected last dim of input {}.".format(
|
||||
input_.shape, self.weight.shape, self.weight.shape[-1] * self.num_partitions
|
||||
)
|
||||
input_ = split_forward_gather_backward(input_, dim=-1, process_group=self.process_group)
|
||||
input_ = split_forward_gather_backward(
|
||||
input_, dim=-1, process_group=self.process_group, fp8_communication=self.fp8_communication
|
||||
)
|
||||
|
||||
if self.stream_chunk_num > 1:
|
||||
if self.training:
|
||||
|
@ -418,11 +428,11 @@ class Linear1D_Row(ParallelModule):
|
|||
else:
|
||||
if self.seq_parallel_mode is None:
|
||||
output_parallel = linear_with_async_comm(input_, self.weight, None, self.process_group, False)
|
||||
output = reduce_forward(output_parallel, self.process_group)
|
||||
output = reduce_forward(output_parallel, self.process_group, fp8_communication=self.fp8_communication)
|
||||
elif self.seq_parallel_mode == "split_gather":
|
||||
output_parallel = linear_with_async_comm(input_, self.weight, None, self.process_group, False)
|
||||
output = reducescatter_forward_gather_backward(
|
||||
output_parallel, self.process_group, self.seq_parallel_dim
|
||||
output_parallel, self.process_group, self.seq_parallel_dim, fp8_communication=self.fp8_communication
|
||||
)
|
||||
elif self.seq_parallel_mode == "ring":
|
||||
output = linear_reducescatter_forward_gather_backward(
|
||||
|
|
|
@ -460,7 +460,7 @@ class LlamaPipelineForwards:
|
|||
return {"hidden_states": hidden_states}
|
||||
|
||||
|
||||
def get_llama_flash_attention_forward(shard_config, sp_mode=None, sp_size=None, sp_group=None):
|
||||
def get_llama_flash_attention_forward(shard_config: ShardConfig, sp_mode=None, sp_size=None, sp_group=None):
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
|
@ -510,9 +510,9 @@ def get_llama_flash_attention_forward(shard_config, sp_mode=None, sp_size=None,
|
|||
|
||||
# sp: all-to-all comminucation when introducing sequence parallel
|
||||
if sp_mode == "all_to_all":
|
||||
query_states = all_to_all_comm(query_states, sp_group)
|
||||
key_states = all_to_all_comm(key_states, sp_group)
|
||||
value_states = all_to_all_comm(value_states, sp_group)
|
||||
query_states = all_to_all_comm(query_states, sp_group, fp8_communication=shard_config.fp8_communication)
|
||||
key_states = all_to_all_comm(key_states, sp_group, fp8_communication=shard_config.fp8_communication)
|
||||
value_states = all_to_all_comm(value_states, sp_group, fp8_communication=shard_config.fp8_communication)
|
||||
bsz, q_len, _ = query_states.size()
|
||||
|
||||
query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
|
||||
|
@ -574,7 +574,9 @@ def get_llama_flash_attention_forward(shard_config, sp_mode=None, sp_size=None,
|
|||
# sp: all-to-all comminucation when introducing sequence parallel
|
||||
if sp_mode == "all_to_all":
|
||||
attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.head_dim)
|
||||
attn_output = all_to_all_comm(attn_output, sp_group, scatter_dim=1, gather_dim=2)
|
||||
attn_output = all_to_all_comm(
|
||||
attn_output, sp_group, scatter_dim=1, gather_dim=2, fp8_communication=shard_config.fp8_communication
|
||||
)
|
||||
else:
|
||||
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
|
||||
|
||||
|
@ -592,7 +594,7 @@ def get_llama_flash_attention_forward(shard_config, sp_mode=None, sp_size=None,
|
|||
return forward
|
||||
|
||||
|
||||
def get_llama_flash_attention_model_forward(shard_config, sp_mode=None, sp_size=None, sp_group=None):
|
||||
def get_llama_flash_attention_model_forward(shard_config: ShardConfig, sp_mode=None, sp_size=None, sp_group=None):
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
def forward(
|
||||
|
@ -659,9 +661,13 @@ def get_llama_flash_attention_model_forward(shard_config, sp_mode=None, sp_size=
|
|||
attention_mask = self._update_causal_mask(attention_mask, inputs_embeds, cache_position)
|
||||
|
||||
if sp_mode in ["ring", "split_gather"]:
|
||||
inputs_embeds = split_forward_gather_backward(inputs_embeds, 1, sp_group)
|
||||
inputs_embeds = split_forward_gather_backward(
|
||||
inputs_embeds, 1, sp_group, fp8_communication=shard_config.fp8_communication
|
||||
)
|
||||
elif sp_mode == "all_to_all":
|
||||
inputs_embeds = split_forward_gather_backward(inputs_embeds, 1, sp_group, 1 / sp_size)
|
||||
inputs_embeds = split_forward_gather_backward(
|
||||
inputs_embeds, 1, sp_group, 1 / sp_size, fp8_communication=shard_config.fp8_communication
|
||||
)
|
||||
hidden_states = inputs_embeds
|
||||
|
||||
# decoder layers
|
||||
|
@ -706,9 +712,13 @@ def get_llama_flash_attention_model_forward(shard_config, sp_mode=None, sp_size=
|
|||
hidden_states = self.norm(hidden_states)
|
||||
|
||||
if sp_mode == "ring" or sp_mode == "split_gather":
|
||||
hidden_states = gather_forward_split_backward(hidden_states, 1, sp_group)
|
||||
hidden_states = gather_forward_split_backward(
|
||||
hidden_states, 1, sp_group, fp8_communication=shard_config.fp8_communication
|
||||
)
|
||||
elif sp_mode == "all_to_all":
|
||||
hidden_states = gather_forward_split_backward(hidden_states, 1, sp_group, grad_scale=sp_size)
|
||||
hidden_states = gather_forward_split_backward(
|
||||
hidden_states, 1, sp_group, grad_scale=sp_size, fp8_communication=shard_config.fp8_communication
|
||||
)
|
||||
|
||||
# add hidden states from the last decoder layer
|
||||
if output_hidden_states:
|
||||
|
|
|
@ -65,7 +65,6 @@ class LlamaPolicy(Policy):
|
|||
norm_cls = FusedRMSNorm
|
||||
else:
|
||||
norm_cls = RMSNorm
|
||||
|
||||
if self.pipeline_stage_manager is not None:
|
||||
self.shard_config.enable_sequence_parallelism = False
|
||||
self.shard_config.enable_sequence_overlap = False
|
||||
|
@ -134,37 +133,37 @@ class LlamaPolicy(Policy):
|
|||
SubModuleReplacementDescription(
|
||||
suffix="self_attn.q_proj",
|
||||
target_module=Linear1D_Col,
|
||||
kwargs=dict(seq_parallel_mode=sp_mode),
|
||||
kwargs=dict(seq_parallel_mode=sp_mode, fp8_communication=self.shard_config.fp8_communication),
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="self_attn.k_proj",
|
||||
target_module=Linear1D_Col,
|
||||
kwargs=dict(seq_parallel_mode=sp_mode),
|
||||
kwargs=dict(seq_parallel_mode=sp_mode, fp8_communication=self.shard_config.fp8_communication),
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="self_attn.v_proj",
|
||||
target_module=Linear1D_Col,
|
||||
kwargs=dict(seq_parallel_mode=sp_mode),
|
||||
kwargs=dict(seq_parallel_mode=sp_mode, fp8_communication=self.shard_config.fp8_communication),
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="self_attn.o_proj",
|
||||
target_module=Linear1D_Row,
|
||||
kwargs=dict(seq_parallel_mode=sp_mode),
|
||||
kwargs=dict(seq_parallel_mode=sp_mode, fp8_communication=self.shard_config.fp8_communication),
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="mlp.gate_proj",
|
||||
target_module=Linear1D_Col,
|
||||
kwargs=dict(seq_parallel_mode=sp_mode),
|
||||
kwargs=dict(seq_parallel_mode=sp_mode, fp8_communication=self.shard_config.fp8_communication),
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="mlp.up_proj",
|
||||
target_module=Linear1D_Col,
|
||||
kwargs=dict(seq_parallel_mode=sp_mode),
|
||||
kwargs=dict(seq_parallel_mode=sp_mode, fp8_communication=self.shard_config.fp8_communication),
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="mlp.down_proj",
|
||||
target_module=Linear1D_Row,
|
||||
kwargs=dict(seq_parallel_mode=sp_mode),
|
||||
kwargs=dict(seq_parallel_mode=sp_mode, fp8_communication=self.shard_config.fp8_communication),
|
||||
),
|
||||
],
|
||||
)
|
||||
|
|
|
@ -0,0 +1,39 @@
|
|||
import torch
|
||||
import torch.distributed as dist
|
||||
from torch.distributed.distributed_c10d import _get_default_group
|
||||
from torch.testing import assert_close
|
||||
|
||||
from colossalai import launch
|
||||
from colossalai.accelerator import get_accelerator
|
||||
from colossalai.quantization.fp8 import all_to_all_fp8
|
||||
from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
|
||||
|
||||
|
||||
@parameterize("shape", [(16, 8, 4)])
|
||||
@parameterize("scatter_dim", [0, 1, 2])
|
||||
@parameterize("dtype", [torch.bfloat16, torch.float16])
|
||||
@parameterize("fp8_format", ["e4m3", "e5m2"])
|
||||
def check_4gpu(shape, scatter_dim, dtype, fp8_format):
|
||||
world_size = dist.get_world_size()
|
||||
input_tensor = torch.rand(shape, dtype=dtype, device=get_accelerator().get_current_device())
|
||||
input_tensor_list = list(torch.chunk(input_tensor, world_size, scatter_dim))
|
||||
input_tensor_list = [x.contiguous() for x in input_tensor_list]
|
||||
output_tensor_list_fp8 = [torch.empty_like(x) for x in input_tensor_list]
|
||||
output_tensor_list = [torch.empty_like(x) for x in input_tensor_list]
|
||||
all_to_all_fp8(output_tensor_list_fp8, input_tensor_list, group=_get_default_group(), fp8_format=fp8_format)
|
||||
dist.all_to_all(output_tensor_list, input_tensor_list, group=_get_default_group())
|
||||
assert_close(output_tensor_list_fp8, output_tensor_list, rtol=0.1, atol=0.1)
|
||||
|
||||
|
||||
def run_dist(rank, world_size, port):
|
||||
launch(rank=rank, world_size=world_size, port=port, host="localhost")
|
||||
check_4gpu()
|
||||
|
||||
|
||||
@rerun_if_address_is_in_use()
|
||||
def test_all_to_all():
|
||||
spawn(run_dist, 4)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_all_to_all()
|
|
@ -0,0 +1,37 @@
|
|||
import torch
|
||||
import torch.distributed as dist
|
||||
from torch.distributed.distributed_c10d import _get_default_group
|
||||
from torch.testing import assert_close
|
||||
|
||||
from colossalai import launch
|
||||
from colossalai.accelerator import get_accelerator
|
||||
from colossalai.quantization.fp8 import all_to_all_single_fp8
|
||||
from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
|
||||
|
||||
dist.all_to_all_single
|
||||
|
||||
|
||||
@parameterize("shape", [(4), (8, 7), (4, 8, 16)])
|
||||
@parameterize("dtype", [torch.bfloat16, torch.float16])
|
||||
@parameterize("fp8_format", ["e4m3", "e5m2"])
|
||||
def check_4gpu(shape, dtype, fp8_format):
|
||||
x = torch.rand(shape, dtype=dtype, device=get_accelerator().get_current_device())
|
||||
output = torch.empty_like(x)
|
||||
output_fp8 = torch.empty_like(x)
|
||||
all_to_all_single_fp8(output_fp8, x, group=_get_default_group(), fp8_format=fp8_format)
|
||||
dist.all_to_all_single(output, x, group=_get_default_group())
|
||||
assert_close(output, output_fp8, rtol=0.1, atol=0.1)
|
||||
|
||||
|
||||
def run_dist(rank, world_size, port):
|
||||
launch(rank=rank, world_size=world_size, port=port, host="localhost")
|
||||
check_4gpu()
|
||||
|
||||
|
||||
@rerun_if_address_is_in_use()
|
||||
def test_all_to_all_single():
|
||||
spawn(run_dist, 4)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_all_to_all_single()
|
|
@ -32,9 +32,9 @@ def run_dist(rank, world_size, port):
|
|||
|
||||
|
||||
@rerun_if_address_is_in_use()
|
||||
def test_all_gather():
|
||||
def test_all_gather_flat():
|
||||
spawn(run_dist, 4)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_all_gather()
|
||||
test_all_gather_flat()
|
|
@ -0,0 +1,48 @@
|
|||
import torch
|
||||
import torch.distributed as dist
|
||||
from torch.testing import assert_close
|
||||
|
||||
from colossalai import launch
|
||||
from colossalai.accelerator import get_accelerator
|
||||
from colossalai.quantization.fp8 import all_reduce_fp8
|
||||
from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
|
||||
|
||||
|
||||
@parameterize(
|
||||
"shape",
|
||||
[
|
||||
(3, 7),
|
||||
(4, 7),
|
||||
(7, 4),
|
||||
(8, 9),
|
||||
(3),
|
||||
(7,),
|
||||
(8,),
|
||||
],
|
||||
)
|
||||
@parameterize("dtype", [torch.float16, torch.bfloat16])
|
||||
@parameterize("fp8_format", ["e4m3", "e5m2"])
|
||||
def check_4gpu(shape, dtype, fp8_format):
|
||||
x = torch.rand(shape, dtype=dtype, device=get_accelerator().get_current_device())
|
||||
x_fp8 = x.clone()
|
||||
dist.all_reduce(x)
|
||||
all_reduce_fp8(x_fp8, fp8_format=fp8_format)
|
||||
assert_close(x, x_fp8, rtol=0.1, atol=0.1)
|
||||
|
||||
dist.all_reduce(x, op=dist.ReduceOp.AVG)
|
||||
all_reduce_fp8(x_fp8, op=dist.ReduceOp.AVG, fp8_format=fp8_format)
|
||||
assert_close(x, x_fp8, rtol=0.1, atol=0.1)
|
||||
|
||||
|
||||
def run_dist(rank, world_size, port):
|
||||
launch(rank=rank, world_size=world_size, port=port, host="localhost")
|
||||
check_4gpu()
|
||||
|
||||
|
||||
@rerun_if_address_is_in_use()
|
||||
def test_all_reduce():
|
||||
spawn(run_dist, 4)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_all_reduce()
|
|
@ -0,0 +1,48 @@
|
|||
import torch
|
||||
import torch.distributed as dist
|
||||
from torch.distributed.distributed_c10d import _get_default_group
|
||||
from torch.testing import assert_close
|
||||
|
||||
from colossalai import launch
|
||||
from colossalai.accelerator import get_accelerator
|
||||
from colossalai.quantization.fp8 import gather_fp8
|
||||
from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
|
||||
|
||||
|
||||
@parameterize(
|
||||
"shape",
|
||||
[
|
||||
(3, 7),
|
||||
(2, 1),
|
||||
(1, 2),
|
||||
(2, 2),
|
||||
(4, 2),
|
||||
(5,),
|
||||
(4,),
|
||||
(2,),
|
||||
],
|
||||
)
|
||||
@parameterize("dtype", [torch.bfloat16, torch.float16])
|
||||
@parameterize("fp8_format", ["e4m3", "e5m2"])
|
||||
def check_4gpu(shape, dtype, fp8_format):
|
||||
world_size = dist.get_world_size()
|
||||
x = torch.rand(shape, dtype=dtype, device=get_accelerator().get_current_device())
|
||||
output_list = [torch.empty_like(x) for _ in range(world_size)]
|
||||
output_list_fp8 = [torch.empty_like(x) for _ in range(world_size)]
|
||||
gather_fp8(output_list_fp8, x, group=_get_default_group(), fp8_format=fp8_format)
|
||||
dist.all_gather(output_list, x, group=_get_default_group())
|
||||
assert_close(output_list, output_list_fp8, rtol=0.1, atol=0.1)
|
||||
|
||||
|
||||
def run_dist(rank, world_size, port):
|
||||
launch(rank=rank, world_size=world_size, port=port, host="localhost")
|
||||
check_4gpu()
|
||||
|
||||
|
||||
@rerun_if_address_is_in_use()
|
||||
def test_all_gather():
|
||||
spawn(run_dist, 4)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_all_gather()
|
|
@ -0,0 +1,38 @@
|
|||
import torch
|
||||
from torch.distributed import reduce_scatter
|
||||
from torch.distributed.distributed_c10d import _get_default_group
|
||||
from torch.testing import assert_close
|
||||
|
||||
from colossalai import launch
|
||||
from colossalai.accelerator import get_accelerator
|
||||
from colossalai.quantization.fp8 import reduce_scatter_fp8
|
||||
from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
|
||||
|
||||
|
||||
@parameterize("shape", [(16, 8, 4)])
|
||||
@parameterize("scatter_dim", [0, 1, 2])
|
||||
@parameterize("dtype", [torch.bfloat16, torch.float16])
|
||||
@parameterize("fp8_format", ["e4m3", "e5m2"])
|
||||
def check_4gpu(shape, scatter_dim, dtype, fp8_format):
|
||||
x = torch.rand(shape, dtype=dtype, device=get_accelerator().get_current_device())
|
||||
input_list = list(torch.chunk(x, dim=scatter_dim, chunks=4))
|
||||
input_list = [t.contiguous() for t in input_list]
|
||||
output_origin = torch.empty_like(input_list[0])
|
||||
output_fp8 = torch.empty_like(input_list[0])
|
||||
reduce_scatter(output_origin, input_list, group=_get_default_group())
|
||||
reduce_scatter_fp8(output_fp8, input_list, group=_get_default_group(), fp8_format=fp8_format)
|
||||
assert_close(output_origin, output_fp8, rtol=0.1, atol=0.1)
|
||||
|
||||
|
||||
def run_dist(rank, world_size, port):
|
||||
launch(rank=rank, world_size=world_size, port=port, host="localhost")
|
||||
check_4gpu()
|
||||
|
||||
|
||||
@rerun_if_address_is_in_use()
|
||||
def test_reduce_scatter():
|
||||
spawn(run_dist, 4)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_reduce_scatter()
|
Loading…
Reference in New Issue