[WIP] Applying ColoTensor on TP-1D-row Linear. (#831)

* revert zero tensors back

* [tensor] init row 1d linear
Jiarui Fang 2022-04-22 14:03:26 +08:00 committed by GitHub
parent 595bedf767
commit ac88de6dfc
3 changed files with 101 additions and 8 deletions


@@ -19,12 +19,18 @@ def colo_linear(types, args, kwargs, pg):
         bias = None
     else:
         bias = kwargs.get('bias', None)
         if isinstance(bias, ColoTensor):
             bias = bias.torch_tensor()
 
     # Add communication logic before and after linear call.
     if isinstance(weight, ColoTensor):
-        return torch.nn.functional.linear(input_tensor, weight.torch_tensor(), bias)
+        if weight.shard_spec == None:
+            return torch.nn.functional.linear(input_tensor, weight.torch_tensor(), bias)
+        elif weight.shard_spec == '1Drow':
+            # TODO(jzy): implement 1Drow TP linear here.
+            raise NotImplementedError
+        else:
+            raise NotImplementedError
     else:
         return torch.nn.functional.linear(input_tensor, weight, bias)
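The '1Drow' branch above is still a TODO in this WIP commit. For reference, a minimal sketch of what a row-parallel (1Drow) linear computes, written against plain torch.distributed rather than Colossal-AI's communication layer; the helper name row_parallel_linear and the explicit chunking of the input are illustrative assumptions, not this PR's API:

import torch
import torch.distributed as dist

def row_parallel_linear(input_tensor, weight_shard, bias=None):
    """Row-parallel (1Drow) linear sketch.

    The full weight of shape (out_dim, in_dim) is split along in_dim, so each
    rank holds weight_shard of shape (out_dim, in_dim // world_size). Each rank
    multiplies its slice of the input by its weight shard, and the partial
    results are summed with an all-reduce.
    """
    world_size = dist.get_world_size()
    rank = dist.get_rank()

    # Slice of the input that matches this rank's weight shard.
    input_shard = input_tensor.chunk(world_size, dim=-1)[rank]

    # Partial (batch, out_dim) result covering only this rank's slice of in_dim.
    partial_out = torch.nn.functional.linear(input_shard, weight_shard)

    # Sum the partial results across ranks to recover the full output.
    dist.all_reduce(partial_out, op=dist.ReduceOp.SUM)

    # The bias is added once, after the reduction.
    if bias is not None:
        partial_out = partial_out + bias
    return partial_out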


@@ -1,6 +1,6 @@
 import torch
 from .op_wrapper import _COLOSSAL_OPS
-from typing import Tuple
+from typing import Tuple, Optional
 
 
 class ColoTensor(object):
@@ -21,20 +21,35 @@ class ColoTensor(object):
                  requires_grad=False,
                  pin_memory=False,
                  torch_tensor=torch.empty(0),
+                 shard_spec: str = None,
                  ):
         self._size = size
         self._dtype = dtype
         self._requires_grad = requires_grad
         self._pin_memory = pin_memory
         self._torch_tensor = torch_tensor
+        self._shard_spec = shard_spec
+
+    @property
+    def shard_spec(self) -> Optional[str]:
+        return self._shard_spec
+
+    @property
+    def data(self):
+        return self._torch_tensor.data
+
+    @property
+    def grad(self):
+        return self._torch_tensor.grad
 
     @staticmethod
-    def init_from_torch_tensor(tensor: torch.Tensor):
+    def init_from_torch_tensor(tensor: torch.Tensor, shard_spec: str = None) -> 'ColoTensor':
         colo_t = ColoTensor(*tensor.size(),
                             dtype=tensor.dtype,
                             requires_grad=tensor.requires_grad,
                             pin_memory=tensor.pin_memory,
-                            torch_tensor=tensor)
+                            torch_tensor=tensor,
+                            shard_spec=shard_spec)
         return colo_t
 
     def del_torch_tensor(self) -> None:
@@ -67,7 +82,5 @@ class ColoTensor(object):
         if kwargs is None:
             kwargs = {}
 
-        kwargs = {
-            k: v.torch_tensor() if isinstance(v, ColoTensor) else v for k,v in kwargs.items()
-        }
+        kwargs = {k: v.torch_tensor() if isinstance(v, ColoTensor) else v for k, v in kwargs.items()}
         return func(*args, **kwargs)
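For context, this is roughly how the new shard_spec field is meant to be used, based only on the diff above; with this WIP commit the '1Drow' path still raises NotImplementedError, so only the dense fallback actually runs:

import torch
from colossalai.tensor import ColoTensor

# Wrap dense weights and tag one of them with a shard spec. With shard_spec left
# as None, the dispatched colo_linear falls back to a plain dense linear call.
dense_weight = ColoTensor.init_from_torch_tensor(torch.randn(5, 4))
row_weight = ColoTensor.init_from_torch_tensor(torch.randn(5, 4), shard_spec='1Drow')

assert dense_weight.shard_spec is None
assert row_weight.shard_spec == '1Drow'

x = torch.randn(2, 4)
out = torch.nn.functional.linear(x, dense_weight)    # dense path through colo_linear
# torch.nn.functional.linear(x, row_weight)          # would hit the '1Drow' branch -> NotImplementedError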


@@ -0,0 +1,74 @@
from joblib import Parallel
from numpy import allclose, require
import torch
from colossalai.context.parallel_mode import ParallelMode
from colossalai.tensor import ColoTensor
from copy import deepcopy
from functools import partial
import colossalai
import pytest
import torch
import torch.multiprocessing as mp
from colossalai.logging import get_dist_logger
from colossalai.testing import parameterize, rerun_if_address_is_in_use
from colossalai.utils import free_port
from colossalai.core import global_context as gpc


def run_linear_tp1d_row_test():
    in_dim = 4
    out_dim = 5

    fc = torch.nn.Linear(in_dim, out_dim, bias=True)
    fc_ref = deepcopy(fc)

    input_ref = torch.randn(1, in_dim)
    input_tensor = input_ref.clone()

    # sharded_weight = ColoTensor.init_from_torch_tensor(fc_ref.weight, "1Drow")
    # shard the weight at the beginning
    world_size = gpc.get_world_size(ParallelMode.PARALLEL_1D)
    sharded_weight = ColoTensor(in_dim // world_size, out_dim, shard_spec="1Drow")
    sharded_bias = ColoTensor.init_from_torch_tensor(fc_ref.bias)

    # replace the torch nn.Parameters with ColoTensor
    delattr(fc, 'weight')
    setattr(fc, 'weight', sharded_weight)
    delattr(fc, 'bias')
    setattr(fc, 'bias', sharded_bias)

    fc.weight.requires_grad = True
    fc.bias.requires_grad = True

    # torch.nn.functional.linear(torch.randn(1, in_dim), sharded_weight, sharded_bias)
    out = fc(input_tensor)
    loss = out.sum()
    loss.backward()

    out_ref = fc_ref(input_ref)
    loss_ref = out_ref.sum()
    loss_ref.backward()

    assert (loss_ref == loss)
    assert allclose(fc_ref.weight.grad, fc.weight.torch_tensor().grad)


def run_dist(rank, world_size, port):
    config = dict(parallel=dict(tensor=dict(mode="1d", size=world_size),))
    colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
    run_linear_tp1d_row_test()


@pytest.mark.dist
@pytest.mark.parametrize("world_size", [4])
@rerun_if_address_is_in_use()
def test_linear_1d(world_size):
    run_func = partial(run_dist, world_size=world_size, port=free_port())
    mp.spawn(run_func, nprocs=world_size)


if __name__ == '__main__':
    test_linear_1d(4)
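The test above allocates an empty sharded weight rather than actually splitting fc_ref.weight; the commented-out init_from_torch_tensor(fc_ref.weight, "1Drow") line hints at the intended flow. A minimal sketch of how each rank's shard could be carved out of the reference layer; the helper name and the chunking along in_features are assumptions, not something this commit implements:

import torch
from colossalai.tensor import ColoTensor

def shard_weight_1d_row(weight: torch.Tensor, rank: int, world_size: int) -> ColoTensor:
    # torch.nn.Linear stores its weight as (out_features, in_features); a 1Drow
    # shard splits in_features, leaving (out_features, in_features // world_size).
    local_shard = weight.detach().chunk(world_size, dim=-1)[rank].contiguous()
    return ColoTensor.init_from_torch_tensor(local_shard, shard_spec='1Drow')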