improved allgather & reducescatter for 3d

pull/2295/head
zbian 2023-01-03 15:26:47 +08:00 committed by アマデウス
parent c719798abe
commit e94c79f15b
4 changed files with 43 additions and 29 deletions

View File

@ -3,12 +3,17 @@
import torch
import torch.distributed as dist
from torch.distributed import ReduceOp
from torch import Tensor
from torch.distributed import ReduceOp
from colossalai.context import ParallelMode
from colossalai.core import global_context as gpc
_all_gather_func = dist._all_gather_base \
if "all_gather_into_tensor" not in dir(dist) else dist.all_gather_into_tensor
_reduce_scatter_func = dist._reduce_scatter_base \
if "reduce_scatter_tensor" not in dir(dist) else dist.reduce_scatter_tensor
def all_gather(tensor: Tensor, dim: int, parallel_mode: ParallelMode, async_op: bool = False) -> Tensor:
r"""Gathers all tensors from the parallel group and concatenates them in a
@ -33,17 +38,12 @@ def all_gather(tensor: Tensor, dim: int, parallel_mode: ParallelMode, async_op:
out = tensor
work = None
else:
shape = list(tensor.shape)
shape[0], shape[dim] = shape[dim], shape[0]
shape[0] *= depth
out = torch.empty(shape, dtype=tensor.dtype, device=tensor.device)
temp = list(torch.chunk(out, depth, dim=0))
tensor_in = tensor.contiguous() if dim == 0 else tensor.transpose(0, dim).contiguous()
out_shape = (tensor_in.shape[0] * depth,) + tensor_in.shape[1:]
tensor_out = torch.empty(out_shape, dtype=tensor.dtype, device=tensor.device)
group = gpc.get_cpu_group(parallel_mode) if tensor.device.type == "cpu" else gpc.get_group(parallel_mode)
work = dist.all_gather(tensor_list=temp,
tensor=tensor.transpose(0, dim).contiguous(),
group=group,
async_op=async_op)
out = torch.transpose(out, 0, dim)
work = _all_gather_func(tensor_out, tensor_in, group=group, async_op=async_op)
out = tensor_out if dim == 0 else tensor_out.transpose(0, dim)
if async_op:
return out, work
else:
@ -81,10 +81,12 @@ def reduce_scatter(tensor: Tensor,
out = tensor
work = None
else:
temp = list(map(lambda x: x.contiguous(), torch.chunk(tensor, depth, dim=dim)))
out = torch.empty(temp[0].shape, dtype=tensor.dtype, device=tensor.device)
tensor_in = tensor.contiguous() if dim == 0 else tensor.transpose(0, dim).contiguous()
out_shape = (tensor_in.shape[0] // depth,) + tensor_in.shape[1:]
tensor_out = torch.empty(out_shape, dtype=tensor.dtype, device=tensor.device)
group = gpc.get_cpu_group(parallel_mode) if tensor.device.type == "cpu" else gpc.get_group(parallel_mode)
work = dist.reduce_scatter(output=out, input_list=temp, op=op, group=group, async_op=async_op)
work = _reduce_scatter_func(tensor_out, tensor_in, op=op, group=group, async_op=async_op)
out = tensor_out if dim == 0 else tensor_out.transpose(0, dim)
if async_op:
return out, work
else:
@ -193,7 +195,8 @@ def reduce(tensor: Tensor, dst: int, parallel_mode: ParallelMode, op: ReduceOp =
def scatter_object_list(scatter_object_output_list, scatter_object_input_list, src=0, group=None) -> None:
r"""Modified from `torch.distributed.scatter_object_list <https://pytorch.org/docs/stable/_modules/torch/distributed/distributed_c10d.html#scatter_object_list>` to fix issues
r"""Modified from `torch.distributed.scatter_object_list
<https://pytorch.org/docs/stable/_modules/torch/distributed/distributed_c10d.html#scatter_object_list>` to fix issues
"""
if dist.distributed_c10d._rank_not_in_group(group):
return

View File

@ -34,7 +34,7 @@ class _Linear3D(torch.autograd.Function):
ctx.output_parallel_mode = output_parallel_mode
input_ = all_gather(input_, 0, input_parallel_mode)
weight = all_gather(weight, -1, weight_parallel_mode)
weight = all_gather(weight, 0, weight_parallel_mode)
ctx.save_for_backward(input_, weight)
output = torch.matmul(input_, weight)
@ -53,7 +53,7 @@ class _Linear3D(torch.autograd.Function):
weight_grad = torch.matmul(
input_.reshape(-1, input_.shape[-1]).transpose(0, 1), output_grad.reshape(-1, output_grad.shape[-1]))
weight_grad, op = reduce_scatter(weight_grad, -1, ctx.weight_parallel_mode, async_op=True)
weight_grad, op = reduce_scatter(weight_grad, 0, ctx.weight_parallel_mode, async_op=True)
weight_grad = push_async_grad(op, weight_grad, ctx.weight_id)
input_op.wait()
@ -205,7 +205,7 @@ class _VocabParallelClassifier3D(torch.autograd.Function):
ctx.weight_id = weight_id
input_ = all_gather(input_, 0, input_parallel_mode)
weight = all_gather(weight.transpose(0, 1), -1, weight_parallel_mode)
weight = all_gather(weight, 0, weight_parallel_mode).transpose(0, 1)
ctx.save_for_backward(input_, weight)
output = torch.matmul(input_, weight)

View File

@ -196,8 +196,8 @@ class Linear3D(ParallelLayer):
self.output_x_weight_parallel_mode = get_parallel_mode_from_env(OUTPUT_X_WEIGHT_3D)
self.depth = get_depth_from_env()
self.skip_bias_add = skip_bias_add
self.in_features_per_partition = divide(in_features, self.depth)
self.out_features_per_partition = divide(out_features, self.depth**2)
self.in_features_per_partition = divide(in_features, self.depth**2)
self.out_features_per_partition = divide(out_features, self.depth)
self.bias_features_per_partition = divide(out_features, self.depth)
self.weight = Parameter(
@ -287,7 +287,7 @@ class Linear3D(ParallelLayer):
local_state,
self.weight_parallel_mode,
dims={
weight_key: -1,
weight_key: 0,
bias_key: 0
},
partition_states={
@ -310,7 +310,7 @@ class Linear3D(ParallelLayer):
local_state,
self.weight_parallel_mode,
dims={
weight_key: -1,
weight_key: 0,
bias_key: 0
},
partition_states={

View File

@ -4,12 +4,23 @@
import time
import torch
from colossalai.constants import INPUT_GROUP_3D, OUTPUT_GROUP_3D, WEIGHT_GROUP_3D
from colossalai.core import global_context
from colossalai.logging import get_dist_logger
from colossalai.nn import (Classifier3D, CrossEntropyLoss3D, Embedding3D, LayerNorm3D, Linear3D, PatchEmbedding3D,
VanillaClassifier, VanillaPatchEmbedding, VocabParallelClassifier3D,
VocabParallelCrossEntropyLoss3D, VocabParallelEmbedding3D)
from colossalai.nn import (
Classifier3D,
CrossEntropyLoss3D,
Embedding3D,
LayerNorm3D,
Linear3D,
PatchEmbedding3D,
VanillaClassifier,
VanillaPatchEmbedding,
VocabParallelClassifier3D,
VocabParallelCrossEntropyLoss3D,
VocabParallelEmbedding3D,
)
from colossalai.nn.layer.parallel_3d._utils import get_parallel_mode_from_env
from colossalai.utils import get_current_device, print_rank_0
@ -40,7 +51,7 @@ def check_linear():
torch.distributed.broadcast(weight_master, src=0)
weight = torch.chunk(weight_master, DEPTH, dim=0)[k]
weight = torch.chunk(weight, DEPTH, dim=-1)[j]
weight = torch.chunk(weight, DEPTH, dim=-1)[i]
weight = torch.chunk(weight, DEPTH, dim=0)[i]
layer.weight.data.copy_(weight)
bias_master = layer_master.bias.data
torch.distributed.broadcast(bias_master, src=0)
@ -93,7 +104,7 @@ def check_linear():
B_grad = layer_master.weight.grad.transpose(0, 1)
B_grad = torch.chunk(B_grad, DEPTH, dim=0)[k]
B_grad = torch.chunk(B_grad, DEPTH, dim=-1)[j]
B_grad = torch.chunk(B_grad, DEPTH, dim=-1)[i]
B_grad = torch.chunk(B_grad, DEPTH, dim=0)[i]
logger.info('Rank {} linear backward (weight_grad): {}'.format(rank, check_equal(B_grad, layer.weight.grad)))
bias_grad = layer_master.bias.grad
@ -775,7 +786,7 @@ def check_loss():
out_shape = (BATCH_SIZE, NUM_CLASSES)
out_master = torch.randn(out_shape, device=device)
target_master = torch.randint(NUM_CLASSES, (BATCH_SIZE, ), dtype=torch.long, device=device)
target_master = torch.randint(NUM_CLASSES, (BATCH_SIZE,), dtype=torch.long, device=device)
torch.distributed.broadcast(out_master, src=0)
torch.distributed.broadcast(target_master, src=0)
out = torch.chunk(out_master, DEPTH, dim=0)[i]
@ -828,7 +839,7 @@ def check_vocab_parallel_loss():
out_shape = (BATCH_SIZE, NUM_CLASSES)
out_master = torch.randn(out_shape, device=device)
target_master = torch.randint(NUM_CLASSES, (BATCH_SIZE, ), dtype=torch.long, device=device)
target_master = torch.randint(NUM_CLASSES, (BATCH_SIZE,), dtype=torch.long, device=device)
torch.distributed.broadcast(out_master, src=0)
torch.distributed.broadcast(target_master, src=0)
out = torch.chunk(out_master, DEPTH, dim=0)[i]