[Tensor] TP Linear 1D row (#843)

pull/847/head
Ziyue Jiang 3 years ago committed by GitHub
parent cf6d1c9284
commit 05023ecfee
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -1,9 +1,10 @@
import torch
from colossalai.tensor.op_wrapper import colo_op_impl
from colossalai.tensor.colo_tensor import ColoTensor
from colossalai.context import ParallelMode
from colossalai.nn.layer.parallel_1d._utils import split_forward_gather_backward, reduce_input
from packaging import version
@colo_op_impl(torch.nn.functional.linear)
def colo_linear(types, args, kwargs, pg):
"""Handles ``__torch_function__`` dispatch for ``torch.nn.functional.linear``.
@ -19,12 +20,31 @@ def colo_linear(types, args, kwargs, pg):
bias = None
else:
bias = kwargs.get('bias', None)
if isinstance(bias, ColoTensor):
bias = bias.torch_tensor()
# Add communication logic before and after linear call.
if isinstance(weight, ColoTensor):
return torch.nn.functional.linear(input_tensor, weight.torch_tensor(), bias)
if weight.shard_spec == None:
return torch.nn.functional.linear(input_tensor, weight.torch_tensor(), bias)
elif weight.shard_spec == '1Drow':
"""
Input:S[1] x Weight:S[0] = Output:P
All-Reduce(Output) + bias = res
"""
# Input:S[1]
input_per_partition = split_forward_gather_backward(input_tensor, ParallelMode.PARALLEL_1D, dim=-1)
# Output:P
partial_output = torch.nn.functional.linear(input_per_partition, weight.torch_tensor())
# Reduce(Output)
output = reduce_input(partial_output, ParallelMode.PARALLEL_1D)
# Bias
if bias is not None:
output = output + bias
return output
else:
raise NotImplementedError
else:
return torch.nn.functional.linear(input_tensor, weight, bias)

@ -4,7 +4,6 @@ from typing import Tuple
import numpy
from .op_wrapper import _COLOSSAL_OPS
class ColoTensor(object):
""" Data Structure for Tensor in Colossal-AI
1. It contains a torch.Tensor as an attribute.
@ -24,6 +23,7 @@ class ColoTensor(object):
pin_memory=False,
device=None,
torch_tensor=torch.empty(0),
shard_spec: str = None,
):
self._size = size
self._dtype = dtype
@ -31,11 +31,29 @@ class ColoTensor(object):
self._pin_memory = pin_memory
self._device = device
self._torch_tensor = torch_tensor
self._shard_spec = shard_spec
@property
def shard_spec(self) -> Optional[str]:
return self._shard_spec
@property
def data(self):
return self._torch_tensor.data
@property
def grad(self):
return self._torch_tensor.grad
@property
def size(self):
return self._size
def numel(self):
return product(self._size)
@staticmethod
def init_from_torch_tensor(tensor: torch.Tensor, save_payload=True) -> 'ColoTensor':
colo_t = ColoTensor(*tensor.size(),
dtype=tensor.dtype,

@ -0,0 +1,91 @@
import torch
from colossalai.context.parallel_mode import ParallelMode
from colossalai.tensor import ColoTensor
from functools import partial
import colossalai
import pytest
import torch
import torch.multiprocessing as mp
from colossalai.testing import parameterize, rerun_if_address_is_in_use
from colossalai.utils.cuda import get_current_device
from colossalai.utils import free_port
from colossalai.core import global_context as gpc
import torch.distributed as dist
from test_tensor_utils import check_equal, replace_parameter_add_grad, broadcast_tensor_chunk
def run_linear_tp1d_row_test():
device = get_current_device()
dtype = torch.float32
DEPTH = gpc.get_world_size(ParallelMode.PARALLEL_1D)
in_features = 4
out_features = 5
local_rank = gpc.get_local_rank(ParallelMode.PARALLEL_1D)
layer_master = torch.nn.Linear(in_features, out_features)
layer = torch.nn.Linear(in_features, out_features)
A_shape = (2, in_features)
A_master = torch.randn(A_shape, dtype=dtype, device=device)
A = broadcast_tensor_chunk(A_master, chunk_size=1)
A.requires_grad = True
W_shape = (out_features, in_features)
W_master = torch.randn(W_shape, dtype=dtype, device=device)
W = broadcast_tensor_chunk(W_master, chunk_size=DEPTH, local_rank=local_rank)
W.requires_grad = True
B_shape = (out_features)
B_master = torch.randn(B_shape, dtype=dtype, device=device)
B = broadcast_tensor_chunk(B_master, chunk_size=1)
B.requires_grad = True
# replace the torch nn.Parameters with ColoTensor
sharded_weight = ColoTensor.init_from_torch_tensor(W)
sharded_weight._shard_spec = "1Drow"
sharded_bias = ColoTensor.init_from_torch_tensor(B)
replace_parameter_add_grad(layer, sharded_weight, sharded_bias)
out = layer(A)
replace_parameter_add_grad(layer_master, W_master, B_master)
A_master.requires_grad = True
#C_master = torch.matmul(A_master, W_master.transpose(0, 1)) + B_master
C_master = layer_master(A_master)
C = C_master.clone()
check_equal(out, C)
grad_shape = C_master.shape
grad_master = torch.randn(grad_shape, dtype=dtype, device=get_current_device())
grad = broadcast_tensor_chunk(grad_master, chunk_size=1)
out.backward(grad)
grad_master = grad_master.clone()
C_master.backward(grad_master)
W_grad = W_master.grad
W_grad = torch.chunk(W_grad, DEPTH, dim=-1)[local_rank]
check_equal(W_grad, layer.weight.grad)
B_grad = B_master.grad
check_equal(B_grad, layer.bias.grad)
def run_dist(rank, world_size, port):
config = dict(parallel=dict(tensor=dict(mode="1d", size=world_size),))
colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
run_linear_tp1d_row_test()
@pytest.mark.dist
@parameterize('world_size', [1, 4])
@rerun_if_address_is_in_use()
def test_linear_1d(world_size):
run_func = partial(run_dist, world_size=world_size, port=free_port())
mp.spawn(run_func, nprocs=world_size)
if __name__ == '__main__':
test_linear_1d()

@ -0,0 +1,20 @@
import torch
import torch.distributed as dist
def check_equal(A, B):
assert torch.allclose(A, B, rtol=1e-3, atol=1e-1) == True
def replace_parameter_add_grad(layer, weight=None, bias=None):
if weight is not None:
delattr(layer, 'weight')
setattr(layer, 'weight', weight)
layer.weight.requires_grad = True
if bias is not None:
delattr(layer, 'bias')
setattr(layer, 'bias', bias)
layer.bias.requires_grad = True
def broadcast_tensor_chunk(tensor, chunk_size=1, local_rank=0):
dist.broadcast(tensor, src=0)
tensor_chunk = torch.chunk(tensor, chunk_size, dim=-1)[local_rank]
return tensor_chunk.clone()
Loading…
Cancel
Save