mirror of https://github.com/hpcaitech/ColossalAI
[unit test] refactor test tensor (#1005)
* polish test_gpt
* update op unit tests
* update test model

pull/1003/head
parent ad536e308e
commit 8e3d0ad8f1
@@ -1 +1 @@
-from . import repeated_computed_layer, resnet, nested_model, bert, no_leaf_module, simple_net
+from . import repeated_computed_layer, resnet, nested_model, bert, no_leaf_module, simple_net, gpt
@@ -0,0 +1,79 @@
+import torch
+import torch.nn as nn
+from .registry import non_distributed_component_funcs
+from transformers import GPT2Config, GPT2LMHeadModel
+from .utils.dummy_data_generator import DummyDataGenerator
+from colossalai.utils.cuda import get_current_device
+
+
+class DummyDataLoader(DummyDataGenerator):
+    vocab_size = 50304
+    batch_size = 4
+    seq_len = 1024
+
+    def generate(self):
+        input_ids = torch.randint(0,
+                                  DummyDataLoader.vocab_size, (DummyDataLoader.batch_size, DummyDataLoader.seq_len),
+                                  device=get_current_device())
+        attention_mask = torch.ones_like(input_ids)
+        return input_ids, attention_mask
+
+
+class GPTLMModel(nn.Module):
+
+    def __init__(self,
+                 hidden_size=768,
+                 num_layers=12,
+                 num_attention_heads=12,
+                 max_seq_len=1024,
+                 vocab_size=50304,
+                 checkpoint=False):
+        super().__init__()
+        self.checkpoint = checkpoint
+        self.model = GPT2LMHeadModel(
+            GPT2Config(n_embd=hidden_size,
+                       n_layer=num_layers,
+                       n_head=num_attention_heads,
+                       n_positions=max_seq_len,
+                       n_ctx=max_seq_len,
+                       vocab_size=vocab_size,
+                       resid_pdrop=0.0,
+                       embd_pdrop=0.0,
+                       attn_pdrop=0.0))
+        if checkpoint:
+            self.model.gradient_checkpointing_enable()
+
+    def forward(self, input_ids, attention_mask):
+        # Only return lm_logits
+        return self.model(input_ids=input_ids, attention_mask=attention_mask, use_cache=not self.checkpoint)[0]
+
+
+def gpt2_s(checkpoint=True):
+    return GPTLMModel(checkpoint=checkpoint)
+
+
+def gpt2_m(checkpoint=True):
+    return GPTLMModel(hidden_size=1024, num_layers=24, num_attention_heads=16, checkpoint=checkpoint)
+
+
+class GPTLMLoss(nn.Module):
+
+    def __init__(self):
+        super().__init__()
+        self.loss_fn = nn.CrossEntropyLoss()
+
+    def forward(self, logits, labels):
+        shift_logits = logits[..., :-1, :].contiguous()
+        shift_labels = labels[..., 1:].contiguous()
+        # Flatten the tokens
+        return self.loss_fn(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
+
+
+@non_distributed_component_funcs.register(name='gpt2')
+def get_training_components():
+
+    trainloader = DummyDataLoader()
+    testloader = DummyDataLoader()
+
+    criterion = GPTLMLoss()
+    return gpt2_s, trainloader, testloader, torch.optim.Adam, criterion
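Note: the component registered under the name 'gpt2' above is looked up by name in the refactored test_gpt further down. A minimal usage sketch, assuming the tests/components_to_test package is importable and a CUDA device is available; the helper name and the learning rate are illustrative, not part of the diff:

from tests.components_to_test.registry import non_distributed_component_funcs

def build_gpt2_training_setup():
    # Fetch the model builder, dummy loaders, optimizer class and criterion registered above.
    get_components_func = non_distributed_component_funcs.get_callable('gpt2')
    model_builder, trainloader, testloader, optimizer_class, criterion = get_components_func()
    model = model_builder(checkpoint=True).cuda()
    optimizer = optimizer_class(model.parameters(), lr=1e-3)
    return model, trainloader, optimizer, criterion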
@@ -1,5 +1,19 @@
+import os
+import random
+import numpy as np
 import torch
 import torch.distributed as dist
+from colossalai.core import global_context as gpc
+from colossalai.context import ParallelMode
+
+
+def set_seed(seed):
+    random.seed(seed)
+    os.environ['PYTHONHASHSEED'] = str(seed)
+    np.random.seed(seed)
+    torch.manual_seed(seed)
+    torch.cuda.manual_seed(seed)
+    torch.backends.cudnn.deterministic = True
 
 
 def check_equal(A, B):
@@ -25,3 +39,19 @@ def broadcast_tensor_chunk(tensor, chunk_size=1, local_rank=0):
 
 def tensor_equal(A, B):
     return torch.allclose(A, B, rtol=1e-3, atol=1e-1)
+
+
+def tensor_shard_equal(tensor: torch.Tensor, shard: torch.Tensor):
+    assert tensor.ndim == shard.ndim
+    if tensor.shape == shard.shape:
+        return tensor_equal(tensor, shard)
+    else:
+        dims_not_eq = torch.nonzero(torch.tensor(tensor.shape) != torch.tensor(shard.shape))
+        if dims_not_eq.numel() == 1:
+            # 1D shard
+            dim = dims_not_eq.item()
+            world_size = gpc.get_world_size(ParallelMode.PARALLEL_1D)
+            rank = gpc.get_local_rank(ParallelMode.PARALLEL_1D)
+            return tensor_equal(tensor.chunk(world_size, dim)[rank], shard)
+        else:
+            raise NotImplementedError
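For intuition, the chunk-and-compare logic inside tensor_shard_equal can be reproduced with plain PyTorch. The sketch below hard-codes rank and world_size instead of reading them from gpc and is only meant to illustrate the check performed above, not to replace the helper:

import torch

def shard_matches(full: torch.Tensor, shard: torch.Tensor, rank: int, world_size: int) -> bool:
    # Find the single dimension where the shard is smaller than the full tensor,
    # then compare this rank's chunk of the full tensor against the shard.
    dim = torch.nonzero(torch.tensor(full.shape) != torch.tensor(shard.shape)).item()
    return torch.allclose(full.chunk(world_size, dim)[rank], shard, rtol=1e-3, atol=1e-1)

full = torch.arange(16.0).reshape(4, 4)
shard = full.chunk(2, 0)[1]    # the second of two row-wise shards
assert shard_matches(full, shard, rank=1, world_size=2)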
@@ -11,6 +11,7 @@ from colossalai.testing import rerun_if_address_is_in_use
 from colossalai.utils import free_port
 from functools import partial
 from colossalai.core import global_context as gpc
+from _utils import tensor_shard_equal, tensor_equal
 
 
 class Conv1D(nn.Module):
@@ -45,13 +46,6 @@ def init_1d_row(weight, bias):
     weight.set_spec(spec)
 
 
-def check_grad_1d_row(model: torch.nn.Module, weight, bias):
-    rank = gpc.get_local_rank(ParallelMode.PARALLEL_1D)
-    size = gpc.get_world_size(ParallelMode.PARALLEL_1D)
-    assert torch.allclose(model.weight.grad.chunk(size, 0)[rank], weight.grad)
-    assert torch.allclose(model.bias.grad, bias.grad)
-
-
 def init_1d_col(weight, bias):
     spec = TensorSpec(
         distspec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [-1], [gpc.get_world_size(ParallelMode.PARALLEL_1D)]),
@@ -61,14 +55,7 @@ def init_1d_col(weight, bias):
     bias.set_spec(spec)
 
 
-def check_grad_1d_col(model: torch.nn.Module, weight, bias):
-    rank = gpc.get_local_rank(ParallelMode.PARALLEL_1D)
-    size = gpc.get_world_size(ParallelMode.PARALLEL_1D)
-    assert torch.allclose(model.weight.grad.chunk(size, -1)[rank], weight.grad)
-    assert torch.allclose(model.bias.grad.chunk(size, -1)[rank], bias.grad)
-
-
-def run_with_spec(spec_init_func, check_grad_func):
+def run_with_spec(spec_init_func):
     model = Conv1D(4, 16).cuda()
     weight = ColoTensor(torch.nn.Parameter(model.weight.detach()))
     bias = ColoTensor(torch.nn.Parameter(model.bias.detach()))
@@ -76,18 +63,19 @@ def run_with_spec(spec_init_func, check_grad_func):
     x = torch.rand(2, 16).cuda()
     out = model(x)
     colo_out = torch.addmm(bias, x, weight)
-    assert torch.allclose(out, colo_out)
+    assert tensor_equal(out, colo_out)
     grad = torch.rand_like(out)
     out.backward(grad)
     colo_out.backward(grad)
-    check_grad_func(model, weight, bias)
+    tensor_shard_equal(model.weight.grad, weight.grad)
+    tensor_shard_equal(model.bias.grad, bias.grad)
 
 
 def run_dist(rank, world_size, port):
     config = dict(parallel=dict(tensor=dict(mode="1d", size=world_size),))
     colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
-    run_with_spec(init_1d_row, check_grad_1d_row)
-    run_with_spec(init_1d_col, check_grad_1d_col)
+    run_with_spec(init_1d_row)
+    run_with_spec(init_1d_col)
 
 
 @pytest.mark.dist
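As background, the forward equivalence this test asserts, a Conv1D-style affine transform against torch.addmm, also holds in plain PyTorch without any sharding. A small self-contained check; the shapes mirror the Conv1D(4, 16) case and the weight layout is assumed to be (in_features, out_features) as in GPT-2's Conv1D:

import torch

x = torch.rand(2, 16)
weight = torch.rand(16, 4)    # assumed (in_features, out_features) layout
bias = torch.rand(4)

out_addmm = torch.addmm(bias, x, weight)    # computes bias + x @ weight
out_manual = x @ weight + bias
assert torch.allclose(out_addmm, out_manual)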
@@ -12,6 +12,7 @@ from colossalai.testing import rerun_if_address_is_in_use
 from colossalai.utils import free_port
 from colossalai.core import global_context as gpc
 from colossalai.tensor import TensorSpec, ComputePattern, ParallelAction, DistSpecManager
+from _utils import tensor_equal, tensor_shard_equal
 
 
 def init_1d_row(weight):
@@ -22,12 +23,6 @@ def init_1d_row(weight):
     weight.set_spec(spec)
 
 
-def check_grad_1d_row(model: torch.nn.Module, weight):
-    rank = gpc.get_local_rank(ParallelMode.PARALLEL_1D)
-    size = gpc.get_world_size(ParallelMode.PARALLEL_1D)
-    assert torch.allclose(model.weight.grad.chunk(size, 0)[rank], weight.grad)
-
-
 def init_1d_col(weight):
     spec = TensorSpec(
         distspec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [-1], [gpc.get_world_size(ParallelMode.PARALLEL_1D)]),
@@ -36,31 +31,25 @@ def init_1d_col(weight):
     weight.set_spec(spec)
 
 
-def check_grad_1d_col(model: torch.nn.Module, weight):
-    rank = gpc.get_local_rank(ParallelMode.PARALLEL_1D)
-    size = gpc.get_world_size(ParallelMode.PARALLEL_1D)
-    assert torch.allclose(model.weight.grad.chunk(size, -1)[rank], weight.grad)
-
-
-def run_with_spec(spec_init_func, check_grad_func):
+def run_with_spec(spec_init_func):
     model = torch.nn.Embedding(12, 32).cuda()
     weight = ColoTensor(torch.nn.Parameter(model.weight.detach()))
     spec_init_func(weight)
     x = torch.tensor((0, 3, 6, 9)).cuda()
     out = model(x)
     colo_out = F.embedding(x, weight)
-    assert torch.allclose(out, colo_out)
+    assert tensor_equal(out, colo_out)
     grad = torch.rand_like(out)
     out.backward(grad)
     colo_out.backward(grad)
-    check_grad_func(model, weight)
+    assert tensor_shard_equal(model.weight.grad, weight.grad)
 
 
 def run_dist(rank, world_size, port):
     config = dict(parallel=dict(tensor=dict(mode="1d", size=world_size),))
     colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
-    run_with_spec(init_1d_row, check_grad_1d_row)
-    run_with_spec(init_1d_col, check_grad_1d_col)
+    run_with_spec(init_1d_row)
+    run_with_spec(init_1d_col)
 
 
 @pytest.mark.dist
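Similarly, the embedding test relies on nn.Embedding and F.embedding agreeing when given the same weight. A plain-PyTorch version of that check, without ColoTensor or the 1D spec, is just:

import torch
import torch.nn.functional as F

embed = torch.nn.Embedding(12, 32)
x = torch.tensor((0, 3, 6, 9))
assert torch.allclose(embed(x), F.embedding(x, embed.weight))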
@@ -1,142 +1,16 @@
 import pytest
 import colossalai
-import os
-import random
-import numpy as np
-import torch
-import torch.nn as nn
 from colossalai.context.parallel_mode import ParallelMode
-from transformers import GPT2Config, GPT2LMHeadModel
 import torch.multiprocessing as mp
 from colossalai.testing import rerun_if_address_is_in_use
 from colossalai.utils.cuda import get_current_device
 from colossalai.utils import free_port
 from colossalai.utils import ColoInitContext
-from colossalai.tensor import TensorSpec, ComputePattern, ParallelAction, ColoTensor, ColoOptimizer, DistSpecManager, distspec
+from colossalai.tensor import TensorSpec, ComputePattern, ParallelAction, DistSpecManager, distspec
 from colossalai.core import global_context as gpc
 from functools import partial
-# Hack huggingface Bert ModelOutput
-# Make it available to our ColoTensor
-from transformers.file_utils import ModelOutput
-from dataclasses import fields
-from tests.test_tensor._utils import tensor_equal
+from _utils import tensor_equal, tensor_shard_equal
+from tests.components_to_test.registry import non_distributed_component_funcs
-
-
-def _post_init_colotensor(self):
-    class_fields = fields(self)
-    # Safety and consistency checks
-    if len(class_fields) == 0:
-        raise ValueError(f"{self.__class__.__name__} has no fields.")
-    if not all(field.default is None for field in class_fields[1:]):
-        raise ValueError(f"{self.__class__.__name__} should not have more than one required field.")
-
-    first_field = getattr(self, class_fields[0].name)
-    other_fields_are_none = all(getattr(self, field.name) is None for field in class_fields[1:])
-
-    def is_tensor_with_colo(x):
-        """
-        Tests if `x` is a `ColoTensor` or `torch.Tensor`.
-        """
-        if isinstance(x, torch.Tensor):
-            return True
-
-        return isinstance(x, ColoTensor)
-
-    if other_fields_are_none and not is_tensor_with_colo(first_field):
-        if isinstance(first_field, dict):
-            iterator = first_field.items()
-            first_field_iterator = True
-        else:
-            try:
-                iterator = iter(first_field)
-                first_field_iterator = True
-            except TypeError:
-                first_field_iterator = False
-
-        # if we provided an iterator as first field and the iterator is a (key, value) iterator
-        # set the associated fields
-        if first_field_iterator:
-            for element in iterator:
-                if (not isinstance(element, (list, tuple)) or not len(element) == 2 or not isinstance(element[0], str)):
-                    break
-                setattr(self, element[0], element[1])
-                if element[1] is not None:
-                    self[element[0]] = element[1]
-        elif first_field is not None:
-            self[class_fields[0].name] = first_field
-    else:
-        for field in class_fields:
-            v = getattr(self, field.name)
-            if v is not None:
-                self[field.name] = v
-
-
-ModelOutput.__post_init__ = _post_init_colotensor
-
-
-class GPTLMModel(nn.Module):
-
-    def __init__(self,
-                 hidden_size=768,
-                 num_layers=12,
-                 num_attention_heads=12,
-                 max_seq_len=1024,
-                 vocab_size=50304,
-                 checkpoint=False):
-        super().__init__()
-        self.checkpoint = checkpoint
-        self.model = GPT2LMHeadModel(
-            GPT2Config(n_embd=hidden_size,
-                       n_layer=num_layers,
-                       n_head=num_attention_heads,
-                       n_positions=max_seq_len,
-                       n_ctx=max_seq_len,
-                       vocab_size=vocab_size,
-                       resid_pdrop=0.0,
-                       embd_pdrop=0.0,
-                       attn_pdrop=0.0))
-        if checkpoint:
-            self.model.gradient_checkpointing_enable()
-
-    def forward(self, input_ids, attention_mask):
-        # Only return lm_logits
-        return self.model(input_ids=input_ids, attention_mask=attention_mask, use_cache=not self.checkpoint)[0]
-
-
-def gpt2_s(checkpoint=True):
-    return GPTLMModel(checkpoint=checkpoint)
-
-
-def gpt2_m(checkpoint=True):
-    return GPTLMModel(hidden_size=1024, num_layers=24, num_attention_heads=16, checkpoint=checkpoint)
-
-
-class GPTLMLoss(nn.Module):
-
-    def __init__(self):
-        super().__init__()
-        self.loss_fn = nn.CrossEntropyLoss()
-
-    def forward(self, logits, labels):
-        shift_logits = logits[..., :-1, :].contiguous()
-        shift_labels = labels[..., 1:].contiguous()
-        # Flatten the tokens
-        return self.loss_fn(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
-
-
-def set_seed(seed):
-    random.seed(seed)
-    os.environ['PYTHONHASHSEED'] = str(seed)
-    np.random.seed(seed)
-    torch.manual_seed(seed)
-    torch.cuda.manual_seed(seed)
-    torch.backends.cudnn.deterministic = True
-
-
-def get_data(batch_size, seq_len, vocab_size):
-    input_ids = torch.randint(0, vocab_size, (batch_size, seq_len), device=torch.cuda.current_device())
-    attention_mask = torch.ones_like(input_ids)
-    return input_ids, attention_mask
 
 
 def init_1d_row_spec(model):
@@ -159,30 +33,6 @@ def init_1d_col_spec(model):
     p.set_spec(spec)
 
 
-def check_tensor_equal_1d(tensor: torch.Tensor, shard: ColoTensor):
-    world_size = gpc.get_world_size(ParallelMode.PARALLEL_1D)
-    rank = gpc.get_local_rank(ParallelMode.PARALLEL_1D)
-    assert len(shard.spec.dist_spec.dims) == 1
-    dim = shard.spec.dist_spec.dims[0]
-    assert torch.equal(tensor.chunk(world_size, dim)[rank], shard.torch_tensor())
-
-
-def tensor_shard_equal(tensor: torch.Tensor, shard: torch.Tensor):
-    assert tensor.ndim == shard.ndim
-    if tensor.shape == shard.shape:
-        return tensor_equal(tensor, shard)
-    else:
-        dims_not_eq = torch.nonzero(torch.tensor(tensor.shape) != torch.tensor(shard.shape))
-        if dims_not_eq.numel() == 1:
-            # 1D shard
-            dim = dims_not_eq.item()
-            world_size = gpc.get_world_size(ParallelMode.PARALLEL_1D)
-            rank = gpc.get_local_rank(ParallelMode.PARALLEL_1D)
-            return tensor_equal(tensor.chunk(world_size, dim)[rank], shard)
-        else:
-            raise NotImplementedError
-
-
 def check_param_equal(model, torch_model):
     for p, torch_p in zip(model.parameters(), torch_model.parameters()):
         assert tensor_shard_equal(torch_p, p)
@@ -194,23 +44,20 @@ def check_grad_equal(model, torch_model):
 
 
 def run_gpt(init_spec_func):
-    BATCH_SIZE = 4
-    SEQ_LEN = 1024
-    VOCAB_SIZE = 50304
-    NUM_STEPS = 1
-    criterion = GPTLMLoss()
+    get_components_func = non_distributed_component_funcs.get_callable('gpt2')
+    model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func()
     with ColoInitContext(device=get_current_device()):
-        model = gpt2_s()
+        model = model_builder()
     model = model.cuda()
-    torch_model = gpt2_s().cuda()
+    torch_model = model_builder().cuda()
     for torch_p, p in zip(torch_model.parameters(), model.parameters()):
         torch_p.data.copy_(p)
     init_spec_func(model)
     check_param_equal(model, torch_model)
     model.train()
    torch_model.train()
-    for i in range(NUM_STEPS):
-        input_ids, attn_mask = get_data(BATCH_SIZE, SEQ_LEN, VOCAB_SIZE)
+    for i, (input_ids, attn_mask) in enumerate(train_dataloader):
         logits = model(input_ids, attn_mask)
         torch_logits = torch_model(input_ids, attn_mask)
         assert tensor_equal(torch_logits, logits)
@@ -219,6 +66,8 @@ def run_gpt(init_spec_func):
         loss.backward()
         torch_loss.backward()
         check_grad_equal(model, torch_model)
+        if i > 0:
+            break
 
 
 def run_dist(rank, world_size, port):
@@ -237,4 +86,4 @@ def test_gpt(world_size):
 
 
 if __name__ == '__main__':
-    test_gpt(1)
+    test_gpt(4)
@@ -13,6 +13,7 @@ from colossalai.testing import rerun_if_address_is_in_use
 from colossalai.utils import free_port
 from colossalai.core import global_context as gpc
 from colossalai.tensor import TensorSpec, ComputePattern, ParallelAction, DistSpecManager
+from _utils import tensor_equal, tensor_shard_equal
 
 
 def init_1d_row(weight, bias):
@@ -23,13 +24,6 @@ def init_1d_row(weight, bias):
     weight.set_spec(spec)
 
 
-def check_grad_1d_row(model: torch.nn.Module, weight, bias):
-    rank = gpc.get_local_rank(ParallelMode.PARALLEL_1D)
-    size = gpc.get_world_size(ParallelMode.PARALLEL_1D)
-    assert torch.allclose(model.weight.grad.chunk(size, -1)[rank], weight.grad)
-    assert torch.allclose(model.bias.grad, bias.grad)
-
-
 def init_1d_col(weight, bias):
     spec = TensorSpec(
         distspec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [0], [gpc.get_world_size(ParallelMode.PARALLEL_1D)]),
@@ -39,14 +33,7 @@ def init_1d_col(weight, bias):
     bias.set_spec(spec)
 
 
-def check_grad_1d_col(model: torch.nn.Module, weight, bias):
-    rank = gpc.get_local_rank(ParallelMode.PARALLEL_1D)
-    size = gpc.get_world_size(ParallelMode.PARALLEL_1D)
-    assert torch.allclose(model.weight.grad.chunk(size, 0)[rank], weight.grad)
-    assert torch.allclose(model.bias.grad.chunk(size, 0)[rank], bias.grad)
-
-
-def run_with_spec(spec_init_func, check_grad_func):
+def run_with_spec(spec_init_func):
     model = torch.nn.Linear(4, 8).cuda()
     weight = ColoTensor(torch.nn.Parameter(model.weight.detach()))
     bias = ColoTensor(torch.nn.Parameter(model.bias.detach()))
@@ -54,18 +41,19 @@ def run_with_spec(spec_init_func, check_grad_func):
     x = torch.rand(2, 4).cuda()
     out = model(x)
     colo_out = F.linear(x, weight, bias)
-    assert torch.allclose(out, colo_out)
+    assert tensor_equal(out, colo_out)
     grad = torch.rand_like(out)
     out.backward(grad)
     colo_out.backward(grad)
-    check_grad_func(model, weight, bias)
+    assert tensor_shard_equal(model.weight.grad, weight.grad)
+    assert tensor_shard_equal(model.bias.grad, bias.grad)
 
 
 def run_dist(rank, world_size, port):
     config = dict(parallel=dict(tensor=dict(mode="1d", size=world_size),))
     colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
-    run_with_spec(init_1d_row, check_grad_1d_row)
-    run_with_spec(init_1d_col, check_grad_1d_col)
+    run_with_spec(init_1d_row)
+    run_with_spec(init_1d_col)
 
 
 @pytest.mark.dist
@@ -13,78 +13,8 @@ from colossalai.tensor import distspec, named_params_with_colotensor, TensorSpec
     ParallelAction, ColoTensor, ColoOptimizer, DistSpecManager
 from colossalai.context import ParallelMode
 from colossalai.core import global_context as gpc
 
 from functools import partial
-import random
-import os
-import numpy as np
+from _utils import set_seed
-
-# Hack huggingface Bert ModelOutput
-# Make it available to our ColoTensor
-from transformers.file_utils import ModelOutput
-from dataclasses import fields
-
-
-def _post_init_colotensor(self):
-    class_fields = fields(self)
-    # Safety and consistency checks
-    if len(class_fields) == 0:
-        raise ValueError(f"{self.__class__.__name__} has no fields.")
-    if not all(field.default is None for field in class_fields[1:]):
-        raise ValueError(f"{self.__class__.__name__} should not have more than one required field.")
-
-    first_field = getattr(self, class_fields[0].name)
-    other_fields_are_none = all(getattr(self, field.name) is None for field in class_fields[1:])
-
-    def is_tensor_with_colo(x):
-        """
-        Tests if `x` is a `ColoTensor` or `torch.Tensor`.
-        """
-        if isinstance(x, torch.Tensor):
-            return True
-
-        return isinstance(x, ColoTensor)
-
-    if other_fields_are_none and not is_tensor_with_colo(first_field):
-        if isinstance(first_field, dict):
-            iterator = first_field.items()
-            first_field_iterator = True
-        else:
-            try:
-                iterator = iter(first_field)
-                first_field_iterator = True
-            except TypeError:
-                first_field_iterator = False
-
-        # if we provided an iterator as first field and the iterator is a (key, value) iterator
-        # set the associated fields
-        if first_field_iterator:
-            for element in iterator:
-                if (not isinstance(element, (list, tuple)) or not len(element) == 2 or not isinstance(element[0], str)):
-                    break
-                setattr(self, element[0], element[1])
-                if element[1] is not None:
-                    self[element[0]] = element[1]
-        elif first_field is not None:
-            self[class_fields[0].name] = first_field
-    else:
-        for field in class_fields:
-            v = getattr(self, field.name)
-            if v is not None:
-                self[field.name] = v
-
-
-ModelOutput.__post_init__ = _post_init_colotensor
-# complete the hack
-
-
-def set_seed(seed):
-    random.seed(seed)
-    os.environ['PYTHONHASHSEED'] = str(seed)
-    np.random.seed(seed)
-    torch.manual_seed(seed)
-    torch.cuda.manual_seed(seed)
-    torch.backends.cudnn.deterministic = True
-
-
 
 def init_1d_row_linear(weight):