[colotensor] add Tensor.view op and its unit test (#1343)

[colotensor] add megatron initialization for gpt2
pull/1352/head
HELSON 2022-07-21 10:53:15 +08:00 committed by GitHub
parent 6160a1d6a7
commit 7a8702c06d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
16 changed files with 309 additions and 79 deletions

View File

@ -5,3 +5,4 @@ from .loss import colo_cross_entropy
from .embedding import colo_embedding
from .addmm import colo_addmm
from .embedding_bag import colo_embedding_bag
from .view import colo_view

View File

@ -69,7 +69,9 @@ def colo_addmm(input_tensor: GeneralTensor,
if not mat2.has_compute_spec(): # No Model Parallel Applied
assert mat2.is_replicate(), 'Invalid mat2 spec for native addmm op'
assert input_tensor.is_replicate(), 'Invalid input spec for native addmm op'
ret_tensor = ColoTensor.from_torch_tensor(torch.addmm(input_tensor, mat1, mat2, beta=beta, alpha=alpha))
ret_tensor = ColoTensor.from_torch_tensor(
tensor=torch.addmm(input_tensor, mat1, mat2, beta=beta, alpha=alpha),
spec=ColoTensorSpec(mat2.get_process_group()))
elif mat2.has_compute_pattern(ComputePattern.TP1D): # Single Model Parallel Applied
if mat2.is_shard_1drow() and input_tensor.is_replicate():
mode = 'row'

View File

@ -1,7 +1,8 @@
import torch.nn.functional as F
from typing import Optional
from colossalai.tensor.op_wrapper import colo_op_impl
from colossalai.tensor import ComputePattern, ColoTensorSpec, ComputePattern, ComputeSpec, ColoTensor, ShardSpec, ReplicaSpec
from colossalai.tensor import ComputePattern, ColoTensorSpec, ComputePattern, ComputeSpec, ColoTensor, ShardSpec, \
ReplicaSpec
from ._utils import GeneralTensor, convert_to_colo_tensor, reduce_input
@ -110,17 +111,18 @@ def colo_embedding(input_tensor: GeneralTensor,
assert isinstance(weight, ColoTensor)
input_tensor = convert_to_colo_tensor(input_tensor, weight.get_process_group())
if not weight.has_compute_spec(): # No Model Parallel Applied
if not weight.has_compute_spec(): # No Model Parallel Applied
assert weight.is_replicate(), 'Invalid weight spec for native embedding op'
return ColoTensor.from_torch_tensor(
F.embedding(input_tensor,
weight,
padding_idx=padding_idx,
max_norm=max_norm,
norm_type=norm_type,
scale_grad_by_freq=scale_grad_by_freq,
sparse=sparse))
elif weight.has_compute_pattern(ComputePattern.TP1D): # Single Model Parallel Applied
tensor=F.embedding(input_tensor,
weight,
padding_idx=padding_idx,
max_norm=max_norm,
norm_type=norm_type,
scale_grad_by_freq=scale_grad_by_freq,
sparse=sparse),
spec=ColoTensorSpec(weight.get_process_group()))
elif weight.has_compute_pattern(ComputePattern.TP1D): # Single Model Parallel Applied
if weight.is_shard_1drow():
mode = 'row'
elif weight.is_shard_1dcol():

View File

@ -2,7 +2,8 @@ import torch.nn.functional as F
from typing import Optional
from torch import Tensor
from colossalai.tensor.op_wrapper import colo_op_impl
from colossalai.tensor import ComputePattern, ComputePattern, ComputeSpec, ColoTensor, distspec, ColoTensorSpec, ShardSpec, ReplicaSpec
from colossalai.tensor import ComputePattern, ComputePattern, ComputeSpec, ColoTensor, distspec, ColoTensorSpec, \
ShardSpec, ReplicaSpec
from ._utils import GeneralTensor, convert_to_colo_tensor
@ -89,21 +90,22 @@ def colo_embedding_bag(input_tensor: GeneralTensor,
# Handle differen parallel actions.
if not weight.has_compute_spec(): # No Model Parallel Applied
if not weight.has_compute_spec(): # No Model Parallel Applied
assert weight.is_replicate(), 'Invalid weight spec for native embedding op'
return ColoTensor.from_torch_tensor(
F.embedding_bag(input_tensor,
weight,
offsets=offsets,
max_norm=max_norm,
norm_type=norm_type,
scale_grad_by_freq=scale_grad_by_freq,
mode=mode,
sparse=sparse,
per_sample_weights=per_sample_weights,
include_last_offset=include_last_offset,
padding_idx=padding_idx))
elif weight.has_compute_pattern(ComputePattern.TP1D): # Single Model Parallel Applied
tensor=F.embedding_bag(input_tensor,
weight,
offsets=offsets,
max_norm=max_norm,
norm_type=norm_type,
scale_grad_by_freq=scale_grad_by_freq,
mode=mode,
sparse=sparse,
per_sample_weights=per_sample_weights,
include_last_offset=include_last_offset,
padding_idx=padding_idx),
spec=ColoTensorSpec(weight.get_process_group()))
elif weight.has_compute_pattern(ComputePattern.TP1D): # Single Model Parallel Applied
if weight.is_shard_1dcol():
tp_mode = 'col'
else:

View File

@ -19,5 +19,9 @@ def colo_layernorm(
input_tensor = input_tensor.redistribute(ReplicaSpec())
output = F.layer_norm(input_tensor, normalized_shape, weight=weight, bias=bias, eps=eps)
output = ColoTensor.from_torch_tensor(output, ColoTensorSpec(input_tensor.get_process_group()))
output = ColoTensor.from_torch_tensor(
tensor=output,
spec=ColoTensorSpec(
pg=input_tensor.get_process_group(),
dist_attr=input_tensor.dist_spec))
return output

View File

@ -0,0 +1,97 @@
import math
import torch
from colossalai.tensor.op_wrapper import colo_op_impl
from colossalai.tensor import ColoTensor, ColoTensorSpec, ReplicaSpec
from typing import Optional, Union
def _all_int(my_iter):
return all(isinstance(i, int) for i in my_iter)
def _get_valid_shape(shape):
if isinstance(shape, list):
if _all_int(shape):
return tuple(shape)
else:
raise RuntimeError("expects type(int) but finds an other type")
elif isinstance(shape, tuple):
if _all_int(shape):
return shape
else:
return _get_valid_shape(shape[0])
else:
raise RuntimeError("expects an iterable array but finds '{}'".format(type(shape)))
def _shape_infer(org_sp, tgt_sp):
cnt = 0
pos = 0
for idx, dim in enumerate(tgt_sp):
if dim < -1:
raise RuntimeError("invalid shape dimension {}".format(dim))
elif dim == -1:
cnt += 1
pos = idx
if cnt > 1:
raise RuntimeError("only one dimension can be inferred")
org_prod = math.prod(org_sp)
tgt_prod = math.prod(tgt_sp)
if cnt == 0:
if org_prod != tgt_prod:
raise RuntimeError("shape '{}' is invalid for input of size {}".format(tgt_sp, org_prod))
else:
return tgt_sp
elif org_prod % tgt_prod != 0:
raise RuntimeError("shape '{}' is invalid for input of size {}".format(tgt_sp, org_prod))
infer_dim = -(org_prod // tgt_prod)
return tgt_sp[: pos] + (infer_dim,) + tgt_sp[pos + 1:]
@colo_op_impl(torch.Tensor.view)
def colo_view(self: ColoTensor, *shape) -> 'ColoTensor':
"""Handles ``__torch_function__`` dispatch for ``torch.Tensor.view``.
Changes the shape of the current tensor.
"""
assert isinstance(self, ColoTensor)
# apply original `view` function for replicated colo tensors
if self.is_replicate():
return self.view(*shape)
cur_sp = self.size()
org_sp = self.size_global()
# parse the passed arguments
tgt_sp = _get_valid_shape(shape)
# get the correct shape from inference
inf_sp = _shape_infer(org_sp, tgt_sp)
if self.is_shard_1drow() and org_sp[0] == inf_sp[0]:
new_shape = (cur_sp[0],) + tgt_sp[1:]
res = self.view(*new_shape)
elif self.is_shard_1dcol() and org_sp[-1] == inf_sp[-1]:
new_shape = tgt_sp[:-1] + (cur_sp[-1],)
res = self.view(*new_shape)
else:
replicated_t = self.redistribute(dist_spec=ReplicaSpec())
return ColoTensor.from_torch_tensor(
tensor=replicated_t.view(*shape),
spec=ColoTensorSpec(self.get_process_group()))
return ColoTensor.from_torch_tensor(
tensor=res,
spec=ColoTensorSpec(
pg=self.get_process_group(),
dist_attr=self.dist_spec))
@colo_op_impl(torch.Tensor.size)
def colo_size(self: ColoTensor, dim: Optional[int] = None) -> Union[torch.Size, int]:
size = self.size_global()
if dim is None:
return size
else:
return size[dim]

View File

@ -22,28 +22,30 @@ def _get_my_nowrap_functions() -> Set[Callable]:
}
def _convert_output(output, pg: ProcessGroup):
def _convert_output(output, colo_spec: ColoTensorSpec):
if type(output) == torch.Tensor:
return ColoTensor.from_torch_tensor(output, ColoTensorSpec(pg))
return ColoTensor.from_torch_tensor(output, colo_spec)
elif isinstance(output, (list, tuple)):
return type(output)(_convert_output(o, pg) for o in output)
return type(output)(_convert_output(o, colo_spec) for o in output)
else:
return output
def _scan_for_pg_from_args(args, kwargs) -> ProcessGroup:
def _get_spec_from_args(args, kwargs) -> ColoTensorSpec:
for elem in args:
if isinstance(elem, ColoTensor):
pg = elem.get_process_group()
return pg
dp = elem.dist_spec
return ColoTensorSpec(pg, dp)
elif isinstance(elem, (list, tuple)):
pg = _scan_for_pg_from_args(elem, {})
if pg is not None:
return pg
spec = _get_spec_from_args(elem, {})
if spec is not None:
return spec
for k, v in kwargs.items():
if isinstance(v, ColoTensor):
pg = v.get_process_group()
return pg
dp = v.dist_spec
return ColoTensorSpec(pg, dp)
return None
@ -170,11 +172,11 @@ class ColoTensor(torch.Tensor):
if func in _get_my_nowrap_functions():
return ret
else:
pg = _scan_for_pg_from_args(args, kwargs)
return _convert_output(ret, pg)
colo_spec = _get_spec_from_args(args, kwargs)
return _convert_output(ret, colo_spec)
def __repr__(self):
return f'ColoTensor:\n{super().__repr__()}\n{self.dist_spec}\n{self.process_group}'
return f'ColoTensor:\n{super().__repr__()}\n{self.dist_spec}\n{self.process_group}\n{self.compute_spec}'
def _redistribute(self, dist_spec: _DistSpec) -> None:
"""_redistribute
@ -243,50 +245,32 @@ class ColoTensor(torch.Tensor):
memo[id(self)] = tensor
return tensor
##### override builtin functions which must use tensor in replicate placement ####
# override builtin functions which must use tensor in replicate placement #
def view_local(self, *args) -> 'ColoTensor':
return super().view(*args)
def size_local(self, *args) -> torch.Size:
with torch._C.DisableTorchFunction():
return super().size(*args)
def size_local(self, *args, **kwargs) -> torch.Size:
return super().size(*args, **kwargs)
def view_global(self, *args) -> 'ColoTensor':
"""override the torch buildin view()
the args passed in must be in a replicate placement.
Returns:
ColoTensor: a tensor after viewed.
"""
if self.is_replicate():
return super().view(*args)
replicated_t = self.redistribute(dist_spec=ReplicaSpec())
return replicated_t.view(*args)
def size_global(self, args: Optional[int] = None) -> torch.Size:
def size_global(self, *args) -> torch.Size:
"""override the torch buildin size()
the shape passed in must be in a replicate placement.
Returns:
ColoTensor: a tensor after viewed.
"""
if self.is_replicate():
if args is not None:
return super().size(args)
else:
return super().size()
return self.size_local(*args)
spec = self.dist_spec
dims = spec.dims
num_partitions = spec.num_partitions
# import inspect
# print(*['{:40}| {}:{}\n'.format(x.function, x.filename, x.lineno) for x in inspect.stack()])
size_list = list(super().size())
size_list = list(self.size_local())
for dim, num_partition in zip(dims, num_partitions):
size_list[dim] *= num_partition
if args is not None:
return size_list[args]
else:
if args == ():
return torch.Size(size_list)
else:
return size_list[args[0]]
# Some API for dist spec check

View File

@ -22,4 +22,7 @@ class ComputeSpec(object):
self.output_replicate = True
def __repr__(self):
return f'compute pattern: {self.compute_pattern}'
return f'Compute pattern: {self.compute_pattern}'
def set_output_replicate(self, flag: bool = True):
self.output_replicate = flag

View File

@ -1,7 +1,7 @@
import torch
import torch.distributed as dist
from colossalai.tensor import ColoTensor, ColoTensorSpec
from colossalai.tensor.distspec import _DistSpec
from colossalai.tensor.distspec import _DistSpec, DistPlacementPattern
def gather_tensor(colo_tensor: ColoTensor) -> None:
@ -26,7 +26,7 @@ def gather_tensor(colo_tensor: ColoTensor) -> None:
def scatter_tensor(colo_tensor: ColoTensor, dist_spec: _DistSpec) -> None:
"""Reversal operation of `gather_tensor`.
"""
if dist_spec.placement == 'r':
if dist_spec.placement == DistPlacementPattern.REPLICATE:
dist.broadcast(colo_tensor.data, 0)
else:
global_size = colo_tensor.size_global()

View File

@ -73,3 +73,9 @@ def split_param_row_tp1d(param, pg):
def split_param_col_tp1d(param, pg):
split_param_single_dim_tp1d(-1, param, pg)
def debug_print(ranks, *args):
if dist.get_rank() in ranks:
print(*args)
dist.barrier()

View File

@ -75,7 +75,7 @@ def _run_view(world_size):
assert t.size_global(1) == 5
assert t.size_global() == torch.Size([4 * world_size, 5])
t = t.view_global(4 * 5 * world_size)
t = t.view(4 * 5 * world_size)
assert t.shape == torch.Size([4 * 5 * world_size])
@ -129,9 +129,9 @@ def _run_set_tensor_spec(world_size):
spec1 = ColoTensorSpec(pg)
t1 = ColoTensor.from_torch_tensor(torch.randn(2, 3, 4), spec1)
dist_spec2 = (ShardSpec([-1], [pg.tp_world_size()]), None)
dist_spec2 = ShardSpec([-1], [pg.tp_world_size()])
assert t1.is_replicate()
t1.set_dist_spec(*dist_spec2)
t1.set_dist_spec(dist_spec2)
assert t1.is_shard_1dcol()

View File

@ -15,6 +15,7 @@ from colossalai.utils.model.colo_init_context import ColoInitContext
from colossalai.tensor import ShardSpec, ComputePattern, ComputeSpec, ProcessGroup, ColoTensor, ColoTensorSpec
from colossalai.nn.parallel.data_parallel import ColoDDP
from tests.components_to_test.registry import non_distributed_component_funcs
from tests.test_tensor.common_utils import split_param_col_tp1d, split_param_row_tp1d, debug_print
def init_1d_row_spec(model, pg: ProcessGroup):
@ -34,6 +35,32 @@ def init_1d_col_spec(model, pg: ProcessGroup):
p.set_tensor_spec(*spec)
def init_megatron_spec(model, pg: ProcessGroup):
for mn, module in model.named_modules():
# debug_print([0], mn)
for pn, param in module.named_parameters(recurse=False):
# debug_print([0], '\t', pn, param.compute_spec, param.shape)
param.set_process_group(pg)
if 'mlp.c_fc' in mn:
if 'weight' in pn or 'bias' in pn:
split_param_col_tp1d(param, pg)
param.compute_spec.set_output_replicate(False)
else:
raise RuntimeError
elif 'mlp.c_proj' in mn:
if 'weight' in pn:
split_param_row_tp1d(param, pg)
else:
assert 'bias' in pn
elif 'wte' in mn or 'wpe' in mn:
assert 'weight' in pn
split_param_col_tp1d(param, pg)
elif 'c_fc' in mn or 'c_proj' in mn:
split_param_col_tp1d(param, pg)
# debug_print([0], '\t', param.compute_spec, param.shape)
def check_param_equal(model, torch_model, pg: ProcessGroup):
for p, torch_p in zip(model.parameters(), torch_model.parameters()):
assert pg.tp_local_rank() is not None, f"{pg.rank()} {pg.tp_world_size()} {pg._tp_degree} {pg.tp_local_rank()}1"
@ -102,8 +129,10 @@ def run_dist(rank, world_size, port, use_ddp):
if use_ddp and world_size == 1:
return
colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
run_gpt(init_1d_row_spec, use_ddp)
run_gpt(init_1d_col_spec, use_ddp)
# Comments below tests for speed concern
# run_gpt(init_1d_row_spec, use_ddp)
# run_gpt(init_1d_col_spec, use_ddp)
run_gpt(init_megatron_spec, use_ddp)
@pytest.mark.dist
@ -116,4 +145,4 @@ def test_gpt(world_size, use_ddp):
if __name__ == '__main__':
test_gpt(4, use_ddp=True)
test_gpt(4, use_ddp=False)

View File

@ -0,0 +1,100 @@
from functools import partial
import colossalai
import pytest
import torch
import torch.multiprocessing as mp
import torch.distributed as dist
from colossalai.testing import rerun_if_address_is_in_use
from colossalai.utils import free_port, get_current_device
from colossalai.tensor import ColoTensorSpec, ProcessGroup, ColoTensor, ShardSpec
from colossalai.tensor.distspec import DistPlacementPattern
from tests.test_tensor.common_utils import split_param_row_tp1d, split_param_col_tp1d, debug_print
def exam_view_core(pg):
# the case of replicated ColoTensors
x = torch.randn(4, 4).cuda()
x_colo = ColoTensor(x, ColoTensorSpec(pg))
y = x.view(2, -1, 2)
y_colo = x_colo.view(2, -1, 2)
assert torch.all(y == y_colo)
assert y_colo.dist_spec.placement == DistPlacementPattern.REPLICATE
# the perfect case of col-sliced ColoTensors
split_param_col_tp1d(x_colo, pg)
z = x.view(torch.Size((2, 1, 2, -1)))
z_colo = x_colo.view(torch.Size((2, 1, 2, -1)))
if dist.get_rank() == 0:
z = z[:, :, :, 0:2]
else:
z = z[:, :, :, 2:]
assert torch.all(z == z_colo)
assert z_colo.dist_spec == x_colo.dist_spec
# the perfect case of row-sliced ColoTensors
split_param_row_tp1d(x_colo, pg)
z = x.view(torch.Size((-1, 2, 2)))
z_colo = x_colo.view(torch.Size((-1, 2, 2)))
if dist.get_rank() == 0:
z = z[0:2, :, :]
else:
z = z[2:, :, :]
assert torch.all(z == z_colo)
assert z_colo.dist_spec == x_colo.dist_spec
# the normal case of row-sliced ColoTensors
z = x.view(-1, 2, 2, 2)
z_colo = x_colo.view(-1, 2, 2, 2)
assert torch.all(z == z_colo)
assert y_colo.dist_spec.placement == DistPlacementPattern.REPLICATE
def exam_view_autograd(pg):
x = torch.randn(8, 2, device=get_current_device(), requires_grad=True)
y = torch.randn(8, 2, device=get_current_device(), requires_grad=True)
with torch.no_grad():
y.copy_(x)
y = ColoTensor(y, ColoTensorSpec(pg))
y_slice = y.redistribute(ShardSpec([-1], [pg.tp_world_size()]))
xx = x.view(2, 2, -1)
yy_slice = y_slice.view(2, 2, -1)
yy = yy_slice.to_replicate()
grad = torch.randn(2, 2, 4, device=get_current_device())
xx.backward(grad)
yy.backward(grad)
assert torch.all(x.grad == y.grad)
def exam_view_errors(pg):
x = torch.randn(8, 2, device=get_current_device())
x = ColoTensor(x, ColoTensorSpec(pg))
split_param_row_tp1d(x, pg)
x.view('a', 'b', 'c')
x.view(8, -1)
x.view([-2, -2, -2])
x.view((-1, -1, -1))
def run_dist(rank, world_size, port):
colossalai.launch(config=dict(), rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
pg = ProcessGroup(tp_degree=torch.distributed.get_world_size())
exam_view_core(pg)
exam_view_autograd(pg)
# exam_view_errors(pg)
@pytest.mark.dist
@pytest.mark.parametrize('world_size', [2])
@rerun_if_address_is_in_use()
def test_view(world_size):
run_func = partial(run_dist, world_size=world_size, port=free_port())
mp.spawn(run_func, nprocs=world_size)
if __name__ == '__main__':
test_view(2)

View File

@ -11,7 +11,7 @@ from colossalai.utils.cuda import get_current_device
from colossalai.utils import free_port
from colossalai.tensor import ComputePattern, ComputeSpec, ColoTensor, ShardSpec, ProcessGroup, ColoTensorSpec
from colossalai.utils.checkpoint.utils import gather_tensor, scatter_tensor
from tests.test_tensor._utils import tensor_shard_equal
from tests.test_tensor.common_utils import tensor_shard_equal
def run_dist(rank, world_size, port, dp_degree, tp_degree):
@ -24,7 +24,7 @@ def run_dist(rank, world_size, port, dp_degree, tp_degree):
gather_tensor(param)
if dist.get_rank() == 0:
assert torch.allclose(x, param.data, rtol=0, atol=0)
assert torch.all(x == param)
else:
assert tensor_shard_equal(x, param.data, pg.tp_local_rank(), pg.tp_world_size())
dist.barrier()

View File

@ -6,7 +6,7 @@ from colossalai.testing import rerun_if_address_is_in_use
from colossalai.utils.cuda import get_current_device
from colossalai.utils import free_port
from functools import partial
from tests.test_tensor._utils import set_seed
from tests.test_tensor.common_utils import set_seed
from tests.components_to_test.registry import non_distributed_component_funcs
from colossalai.testing import parameterize
from colossalai.nn.optimizer import HybridAdam

View File

@ -9,7 +9,7 @@ from colossalai.utils import free_port
from colossalai.utils.model.colo_init_context import ColoInitContext
from colossalai.core import global_context as gpc
from functools import partial
from tests.test_tensor._utils import set_seed
from tests.test_tensor.common_utils import set_seed
from tests.components_to_test.registry import non_distributed_component_funcs
from colossalai.nn.parallel.data_parallel import ZeroDDP
from colossalai.gemini import ChunkManager, GeminiManager