mirror of https://github.com/hpcaitech/ColossalAI
improved allgather & reducescatter for 3d
parent
c719798abe
commit
e94c79f15b
|
@ -3,12 +3,17 @@
|
|||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from torch.distributed import ReduceOp
|
||||
from torch import Tensor
|
||||
from torch.distributed import ReduceOp
|
||||
|
||||
from colossalai.context import ParallelMode
|
||||
from colossalai.core import global_context as gpc
|
||||
|
||||
_all_gather_func = dist._all_gather_base \
|
||||
if "all_gather_into_tensor" not in dir(dist) else dist.all_gather_into_tensor
|
||||
_reduce_scatter_func = dist._reduce_scatter_base \
|
||||
if "reduce_scatter_tensor" not in dir(dist) else dist.reduce_scatter_tensor
|
||||
|
||||
|
||||
def all_gather(tensor: Tensor, dim: int, parallel_mode: ParallelMode, async_op: bool = False) -> Tensor:
|
||||
r"""Gathers all tensors from the parallel group and concatenates them in a
|
||||
|
@ -33,17 +38,12 @@ def all_gather(tensor: Tensor, dim: int, parallel_mode: ParallelMode, async_op:
|
|||
out = tensor
|
||||
work = None
|
||||
else:
|
||||
shape = list(tensor.shape)
|
||||
shape[0], shape[dim] = shape[dim], shape[0]
|
||||
shape[0] *= depth
|
||||
out = torch.empty(shape, dtype=tensor.dtype, device=tensor.device)
|
||||
temp = list(torch.chunk(out, depth, dim=0))
|
||||
tensor_in = tensor.contiguous() if dim == 0 else tensor.transpose(0, dim).contiguous()
|
||||
out_shape = (tensor_in.shape[0] * depth,) + tensor_in.shape[1:]
|
||||
tensor_out = torch.empty(out_shape, dtype=tensor.dtype, device=tensor.device)
|
||||
group = gpc.get_cpu_group(parallel_mode) if tensor.device.type == "cpu" else gpc.get_group(parallel_mode)
|
||||
work = dist.all_gather(tensor_list=temp,
|
||||
tensor=tensor.transpose(0, dim).contiguous(),
|
||||
group=group,
|
||||
async_op=async_op)
|
||||
out = torch.transpose(out, 0, dim)
|
||||
work = _all_gather_func(tensor_out, tensor_in, group=group, async_op=async_op)
|
||||
out = tensor_out if dim == 0 else tensor_out.transpose(0, dim)
|
||||
if async_op:
|
||||
return out, work
|
||||
else:
|
||||
|
@ -81,10 +81,12 @@ def reduce_scatter(tensor: Tensor,
|
|||
out = tensor
|
||||
work = None
|
||||
else:
|
||||
temp = list(map(lambda x: x.contiguous(), torch.chunk(tensor, depth, dim=dim)))
|
||||
out = torch.empty(temp[0].shape, dtype=tensor.dtype, device=tensor.device)
|
||||
tensor_in = tensor.contiguous() if dim == 0 else tensor.transpose(0, dim).contiguous()
|
||||
out_shape = (tensor_in.shape[0] // depth,) + tensor_in.shape[1:]
|
||||
tensor_out = torch.empty(out_shape, dtype=tensor.dtype, device=tensor.device)
|
||||
group = gpc.get_cpu_group(parallel_mode) if tensor.device.type == "cpu" else gpc.get_group(parallel_mode)
|
||||
work = dist.reduce_scatter(output=out, input_list=temp, op=op, group=group, async_op=async_op)
|
||||
work = _reduce_scatter_func(tensor_out, tensor_in, op=op, group=group, async_op=async_op)
|
||||
out = tensor_out if dim == 0 else tensor_out.transpose(0, dim)
|
||||
if async_op:
|
||||
return out, work
|
||||
else:
|
||||
|
@ -193,7 +195,8 @@ def reduce(tensor: Tensor, dst: int, parallel_mode: ParallelMode, op: ReduceOp =
|
|||
|
||||
|
||||
def scatter_object_list(scatter_object_output_list, scatter_object_input_list, src=0, group=None) -> None:
|
||||
r"""Modified from `torch.distributed.scatter_object_list <https://pytorch.org/docs/stable/_modules/torch/distributed/distributed_c10d.html#scatter_object_list>` to fix issues
|
||||
r"""Modified from `torch.distributed.scatter_object_list
|
||||
<https://pytorch.org/docs/stable/_modules/torch/distributed/distributed_c10d.html#scatter_object_list>` to fix issues
|
||||
"""
|
||||
if dist.distributed_c10d._rank_not_in_group(group):
|
||||
return
|
||||
|
|
|
@ -34,7 +34,7 @@ class _Linear3D(torch.autograd.Function):
|
|||
ctx.output_parallel_mode = output_parallel_mode
|
||||
|
||||
input_ = all_gather(input_, 0, input_parallel_mode)
|
||||
weight = all_gather(weight, -1, weight_parallel_mode)
|
||||
weight = all_gather(weight, 0, weight_parallel_mode)
|
||||
ctx.save_for_backward(input_, weight)
|
||||
|
||||
output = torch.matmul(input_, weight)
|
||||
|
@ -53,7 +53,7 @@ class _Linear3D(torch.autograd.Function):
|
|||
|
||||
weight_grad = torch.matmul(
|
||||
input_.reshape(-1, input_.shape[-1]).transpose(0, 1), output_grad.reshape(-1, output_grad.shape[-1]))
|
||||
weight_grad, op = reduce_scatter(weight_grad, -1, ctx.weight_parallel_mode, async_op=True)
|
||||
weight_grad, op = reduce_scatter(weight_grad, 0, ctx.weight_parallel_mode, async_op=True)
|
||||
weight_grad = push_async_grad(op, weight_grad, ctx.weight_id)
|
||||
|
||||
input_op.wait()
|
||||
|
@ -205,7 +205,7 @@ class _VocabParallelClassifier3D(torch.autograd.Function):
|
|||
ctx.weight_id = weight_id
|
||||
|
||||
input_ = all_gather(input_, 0, input_parallel_mode)
|
||||
weight = all_gather(weight.transpose(0, 1), -1, weight_parallel_mode)
|
||||
weight = all_gather(weight, 0, weight_parallel_mode).transpose(0, 1)
|
||||
ctx.save_for_backward(input_, weight)
|
||||
|
||||
output = torch.matmul(input_, weight)
|
||||
|
|
|
@ -196,8 +196,8 @@ class Linear3D(ParallelLayer):
|
|||
self.output_x_weight_parallel_mode = get_parallel_mode_from_env(OUTPUT_X_WEIGHT_3D)
|
||||
self.depth = get_depth_from_env()
|
||||
self.skip_bias_add = skip_bias_add
|
||||
self.in_features_per_partition = divide(in_features, self.depth)
|
||||
self.out_features_per_partition = divide(out_features, self.depth**2)
|
||||
self.in_features_per_partition = divide(in_features, self.depth**2)
|
||||
self.out_features_per_partition = divide(out_features, self.depth)
|
||||
self.bias_features_per_partition = divide(out_features, self.depth)
|
||||
|
||||
self.weight = Parameter(
|
||||
|
@ -287,7 +287,7 @@ class Linear3D(ParallelLayer):
|
|||
local_state,
|
||||
self.weight_parallel_mode,
|
||||
dims={
|
||||
weight_key: -1,
|
||||
weight_key: 0,
|
||||
bias_key: 0
|
||||
},
|
||||
partition_states={
|
||||
|
@ -310,7 +310,7 @@ class Linear3D(ParallelLayer):
|
|||
local_state,
|
||||
self.weight_parallel_mode,
|
||||
dims={
|
||||
weight_key: -1,
|
||||
weight_key: 0,
|
||||
bias_key: 0
|
||||
},
|
||||
partition_states={
|
||||
|
|
|
@ -4,12 +4,23 @@
|
|||
import time
|
||||
|
||||
import torch
|
||||
|
||||
from colossalai.constants import INPUT_GROUP_3D, OUTPUT_GROUP_3D, WEIGHT_GROUP_3D
|
||||
from colossalai.core import global_context
|
||||
from colossalai.logging import get_dist_logger
|
||||
from colossalai.nn import (Classifier3D, CrossEntropyLoss3D, Embedding3D, LayerNorm3D, Linear3D, PatchEmbedding3D,
|
||||
VanillaClassifier, VanillaPatchEmbedding, VocabParallelClassifier3D,
|
||||
VocabParallelCrossEntropyLoss3D, VocabParallelEmbedding3D)
|
||||
from colossalai.nn import (
|
||||
Classifier3D,
|
||||
CrossEntropyLoss3D,
|
||||
Embedding3D,
|
||||
LayerNorm3D,
|
||||
Linear3D,
|
||||
PatchEmbedding3D,
|
||||
VanillaClassifier,
|
||||
VanillaPatchEmbedding,
|
||||
VocabParallelClassifier3D,
|
||||
VocabParallelCrossEntropyLoss3D,
|
||||
VocabParallelEmbedding3D,
|
||||
)
|
||||
from colossalai.nn.layer.parallel_3d._utils import get_parallel_mode_from_env
|
||||
from colossalai.utils import get_current_device, print_rank_0
|
||||
|
||||
|
@ -40,7 +51,7 @@ def check_linear():
|
|||
torch.distributed.broadcast(weight_master, src=0)
|
||||
weight = torch.chunk(weight_master, DEPTH, dim=0)[k]
|
||||
weight = torch.chunk(weight, DEPTH, dim=-1)[j]
|
||||
weight = torch.chunk(weight, DEPTH, dim=-1)[i]
|
||||
weight = torch.chunk(weight, DEPTH, dim=0)[i]
|
||||
layer.weight.data.copy_(weight)
|
||||
bias_master = layer_master.bias.data
|
||||
torch.distributed.broadcast(bias_master, src=0)
|
||||
|
@ -93,7 +104,7 @@ def check_linear():
|
|||
B_grad = layer_master.weight.grad.transpose(0, 1)
|
||||
B_grad = torch.chunk(B_grad, DEPTH, dim=0)[k]
|
||||
B_grad = torch.chunk(B_grad, DEPTH, dim=-1)[j]
|
||||
B_grad = torch.chunk(B_grad, DEPTH, dim=-1)[i]
|
||||
B_grad = torch.chunk(B_grad, DEPTH, dim=0)[i]
|
||||
logger.info('Rank {} linear backward (weight_grad): {}'.format(rank, check_equal(B_grad, layer.weight.grad)))
|
||||
|
||||
bias_grad = layer_master.bias.grad
|
||||
|
@ -775,7 +786,7 @@ def check_loss():
|
|||
|
||||
out_shape = (BATCH_SIZE, NUM_CLASSES)
|
||||
out_master = torch.randn(out_shape, device=device)
|
||||
target_master = torch.randint(NUM_CLASSES, (BATCH_SIZE, ), dtype=torch.long, device=device)
|
||||
target_master = torch.randint(NUM_CLASSES, (BATCH_SIZE,), dtype=torch.long, device=device)
|
||||
torch.distributed.broadcast(out_master, src=0)
|
||||
torch.distributed.broadcast(target_master, src=0)
|
||||
out = torch.chunk(out_master, DEPTH, dim=0)[i]
|
||||
|
@ -828,7 +839,7 @@ def check_vocab_parallel_loss():
|
|||
|
||||
out_shape = (BATCH_SIZE, NUM_CLASSES)
|
||||
out_master = torch.randn(out_shape, device=device)
|
||||
target_master = torch.randint(NUM_CLASSES, (BATCH_SIZE, ), dtype=torch.long, device=device)
|
||||
target_master = torch.randint(NUM_CLASSES, (BATCH_SIZE,), dtype=torch.long, device=device)
|
||||
torch.distributed.broadcast(out_master, src=0)
|
||||
torch.distributed.broadcast(target_master, src=0)
|
||||
out = torch.chunk(out_master, DEPTH, dim=0)[i]
|
||||
|
|
Loading…
Reference in New Issue