mirror of https://github.com/hpcaitech/ColossalAI
[unit test] refactor test tensor (#1005)
* polish test_gpt * update op unit tests * update test modelpull/1003/head
parent
ad536e308e
commit
8e3d0ad8f1
|
@ -1 +1 @@
|
|||
from . import repeated_computed_layer, resnet, nested_model, bert, no_leaf_module, simple_net
|
||||
from . import repeated_computed_layer, resnet, nested_model, bert, no_leaf_module, simple_net, gpt
|
||||
|
|
|
@ -0,0 +1,79 @@
|
|||
import torch
|
||||
import torch.nn as nn
|
||||
from .registry import non_distributed_component_funcs
|
||||
from transformers import GPT2Config, GPT2LMHeadModel
|
||||
from .utils.dummy_data_generator import DummyDataGenerator
|
||||
from colossalai.utils.cuda import get_current_device
|
||||
|
||||
|
||||
class DummyDataLoader(DummyDataGenerator):
|
||||
vocab_size = 50304
|
||||
batch_size = 4
|
||||
seq_len = 1024
|
||||
|
||||
def generate(self):
|
||||
input_ids = torch.randint(0,
|
||||
DummyDataLoader.vocab_size, (DummyDataLoader.batch_size, DummyDataLoader.seq_len),
|
||||
device=get_current_device())
|
||||
attention_mask = torch.ones_like(input_ids)
|
||||
return input_ids, attention_mask
|
||||
|
||||
|
||||
class GPTLMModel(nn.Module):
|
||||
|
||||
def __init__(self,
|
||||
hidden_size=768,
|
||||
num_layers=12,
|
||||
num_attention_heads=12,
|
||||
max_seq_len=1024,
|
||||
vocab_size=50304,
|
||||
checkpoint=False):
|
||||
super().__init__()
|
||||
self.checkpoint = checkpoint
|
||||
self.model = GPT2LMHeadModel(
|
||||
GPT2Config(n_embd=hidden_size,
|
||||
n_layer=num_layers,
|
||||
n_head=num_attention_heads,
|
||||
n_positions=max_seq_len,
|
||||
n_ctx=max_seq_len,
|
||||
vocab_size=vocab_size,
|
||||
resid_pdrop=0.0,
|
||||
embd_pdrop=0.0,
|
||||
attn_pdrop=0.0))
|
||||
if checkpoint:
|
||||
self.model.gradient_checkpointing_enable()
|
||||
|
||||
def forward(self, input_ids, attention_mask):
|
||||
# Only return lm_logits
|
||||
return self.model(input_ids=input_ids, attention_mask=attention_mask, use_cache=not self.checkpoint)[0]
|
||||
|
||||
|
||||
def gpt2_s(checkpoint=True):
|
||||
return GPTLMModel(checkpoint=checkpoint)
|
||||
|
||||
|
||||
def gpt2_m(checkpoint=True):
|
||||
return GPTLMModel(hidden_size=1024, num_layers=24, num_attention_heads=16, checkpoint=checkpoint)
|
||||
|
||||
|
||||
class GPTLMLoss(nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.loss_fn = nn.CrossEntropyLoss()
|
||||
|
||||
def forward(self, logits, labels):
|
||||
shift_logits = logits[..., :-1, :].contiguous()
|
||||
shift_labels = labels[..., 1:].contiguous()
|
||||
# Flatten the tokens
|
||||
return self.loss_fn(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
|
||||
|
||||
|
||||
@non_distributed_component_funcs.register(name='gpt2')
|
||||
def get_training_components():
|
||||
|
||||
trainloader = DummyDataLoader()
|
||||
testloader = DummyDataLoader()
|
||||
|
||||
criterion = GPTLMLoss()
|
||||
return gpt2_s, trainloader, testloader, torch.optim.Adam, criterion
|
|
@ -1,5 +1,19 @@
|
|||
import os
|
||||
import random
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from colossalai.core import global_context as gpc
|
||||
from colossalai.context import ParallelMode
|
||||
|
||||
|
||||
def set_seed(seed):
|
||||
random.seed(seed)
|
||||
os.environ['PYTHONHASHSEED'] = str(seed)
|
||||
np.random.seed(seed)
|
||||
torch.manual_seed(seed)
|
||||
torch.cuda.manual_seed(seed)
|
||||
torch.backends.cudnn.deterministic = True
|
||||
|
||||
|
||||
def check_equal(A, B):
|
||||
|
@ -25,3 +39,19 @@ def broadcast_tensor_chunk(tensor, chunk_size=1, local_rank=0):
|
|||
|
||||
def tensor_equal(A, B):
|
||||
return torch.allclose(A, B, rtol=1e-3, atol=1e-1)
|
||||
|
||||
|
||||
def tensor_shard_equal(tensor: torch.Tensor, shard: torch.Tensor):
|
||||
assert tensor.ndim == shard.ndim
|
||||
if tensor.shape == shard.shape:
|
||||
return tensor_equal(tensor, shard)
|
||||
else:
|
||||
dims_not_eq = torch.nonzero(torch.tensor(tensor.shape) != torch.tensor(shard.shape))
|
||||
if dims_not_eq.numel() == 1:
|
||||
# 1D shard
|
||||
dim = dims_not_eq.item()
|
||||
world_size = gpc.get_world_size(ParallelMode.PARALLEL_1D)
|
||||
rank = gpc.get_local_rank(ParallelMode.PARALLEL_1D)
|
||||
return tensor_equal(tensor.chunk(world_size, dim)[rank], shard)
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
|
|
@ -11,6 +11,7 @@ from colossalai.testing import rerun_if_address_is_in_use
|
|||
from colossalai.utils import free_port
|
||||
from functools import partial
|
||||
from colossalai.core import global_context as gpc
|
||||
from _utils import tensor_shard_equal, tensor_equal
|
||||
|
||||
|
||||
class Conv1D(nn.Module):
|
||||
|
@ -45,13 +46,6 @@ def init_1d_row(weight, bias):
|
|||
weight.set_spec(spec)
|
||||
|
||||
|
||||
def check_grad_1d_row(model: torch.nn.Module, weight, bias):
|
||||
rank = gpc.get_local_rank(ParallelMode.PARALLEL_1D)
|
||||
size = gpc.get_world_size(ParallelMode.PARALLEL_1D)
|
||||
assert torch.allclose(model.weight.grad.chunk(size, 0)[rank], weight.grad)
|
||||
assert torch.allclose(model.bias.grad, bias.grad)
|
||||
|
||||
|
||||
def init_1d_col(weight, bias):
|
||||
spec = TensorSpec(
|
||||
distspec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [-1], [gpc.get_world_size(ParallelMode.PARALLEL_1D)]),
|
||||
|
@ -61,14 +55,7 @@ def init_1d_col(weight, bias):
|
|||
bias.set_spec(spec)
|
||||
|
||||
|
||||
def check_grad_1d_col(model: torch.nn.Module, weight, bias):
|
||||
rank = gpc.get_local_rank(ParallelMode.PARALLEL_1D)
|
||||
size = gpc.get_world_size(ParallelMode.PARALLEL_1D)
|
||||
assert torch.allclose(model.weight.grad.chunk(size, -1)[rank], weight.grad)
|
||||
assert torch.allclose(model.bias.grad.chunk(size, -1)[rank], bias.grad)
|
||||
|
||||
|
||||
def run_with_spec(spec_init_func, check_grad_func):
|
||||
def run_with_spec(spec_init_func):
|
||||
model = Conv1D(4, 16).cuda()
|
||||
weight = ColoTensor(torch.nn.Parameter(model.weight.detach()))
|
||||
bias = ColoTensor(torch.nn.Parameter(model.bias.detach()))
|
||||
|
@ -76,18 +63,19 @@ def run_with_spec(spec_init_func, check_grad_func):
|
|||
x = torch.rand(2, 16).cuda()
|
||||
out = model(x)
|
||||
colo_out = torch.addmm(bias, x, weight)
|
||||
assert torch.allclose(out, colo_out)
|
||||
assert tensor_equal(out, colo_out)
|
||||
grad = torch.rand_like(out)
|
||||
out.backward(grad)
|
||||
colo_out.backward(grad)
|
||||
check_grad_func(model, weight, bias)
|
||||
tensor_shard_equal(model.weight.grad, weight.grad)
|
||||
tensor_shard_equal(model.bias.grad, bias.grad)
|
||||
|
||||
|
||||
def run_dist(rank, world_size, port):
|
||||
config = dict(parallel=dict(tensor=dict(mode="1d", size=world_size),))
|
||||
colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
|
||||
run_with_spec(init_1d_row, check_grad_1d_row)
|
||||
run_with_spec(init_1d_col, check_grad_1d_col)
|
||||
run_with_spec(init_1d_row)
|
||||
run_with_spec(init_1d_col)
|
||||
|
||||
|
||||
@pytest.mark.dist
|
||||
|
|
|
@ -12,6 +12,7 @@ from colossalai.testing import rerun_if_address_is_in_use
|
|||
from colossalai.utils import free_port
|
||||
from colossalai.core import global_context as gpc
|
||||
from colossalai.tensor import TensorSpec, ComputePattern, ParallelAction, DistSpecManager
|
||||
from _utils import tensor_equal, tensor_shard_equal
|
||||
|
||||
|
||||
def init_1d_row(weight):
|
||||
|
@ -22,12 +23,6 @@ def init_1d_row(weight):
|
|||
weight.set_spec(spec)
|
||||
|
||||
|
||||
def check_grad_1d_row(model: torch.nn.Module, weight):
|
||||
rank = gpc.get_local_rank(ParallelMode.PARALLEL_1D)
|
||||
size = gpc.get_world_size(ParallelMode.PARALLEL_1D)
|
||||
assert torch.allclose(model.weight.grad.chunk(size, 0)[rank], weight.grad)
|
||||
|
||||
|
||||
def init_1d_col(weight):
|
||||
spec = TensorSpec(
|
||||
distspec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [-1], [gpc.get_world_size(ParallelMode.PARALLEL_1D)]),
|
||||
|
@ -36,31 +31,25 @@ def init_1d_col(weight):
|
|||
weight.set_spec(spec)
|
||||
|
||||
|
||||
def check_grad_1d_col(model: torch.nn.Module, weight):
|
||||
rank = gpc.get_local_rank(ParallelMode.PARALLEL_1D)
|
||||
size = gpc.get_world_size(ParallelMode.PARALLEL_1D)
|
||||
assert torch.allclose(model.weight.grad.chunk(size, -1)[rank], weight.grad)
|
||||
|
||||
|
||||
def run_with_spec(spec_init_func, check_grad_func):
|
||||
def run_with_spec(spec_init_func):
|
||||
model = torch.nn.Embedding(12, 32).cuda()
|
||||
weight = ColoTensor(torch.nn.Parameter(model.weight.detach()))
|
||||
spec_init_func(weight)
|
||||
x = torch.tensor((0, 3, 6, 9)).cuda()
|
||||
out = model(x)
|
||||
colo_out = F.embedding(x, weight)
|
||||
assert torch.allclose(out, colo_out)
|
||||
assert tensor_equal(out, colo_out)
|
||||
grad = torch.rand_like(out)
|
||||
out.backward(grad)
|
||||
colo_out.backward(grad)
|
||||
check_grad_func(model, weight)
|
||||
assert tensor_shard_equal(model.weight.grad, weight.grad)
|
||||
|
||||
|
||||
def run_dist(rank, world_size, port):
|
||||
config = dict(parallel=dict(tensor=dict(mode="1d", size=world_size),))
|
||||
colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
|
||||
run_with_spec(init_1d_row, check_grad_1d_row)
|
||||
run_with_spec(init_1d_col, check_grad_1d_col)
|
||||
run_with_spec(init_1d_row)
|
||||
run_with_spec(init_1d_col)
|
||||
|
||||
|
||||
@pytest.mark.dist
|
||||
|
|
|
@ -1,142 +1,16 @@
|
|||
import pytest
|
||||
import colossalai
|
||||
import os
|
||||
import random
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from colossalai.context.parallel_mode import ParallelMode
|
||||
from transformers import GPT2Config, GPT2LMHeadModel
|
||||
import torch.multiprocessing as mp
|
||||
from colossalai.testing import rerun_if_address_is_in_use
|
||||
from colossalai.utils.cuda import get_current_device
|
||||
from colossalai.utils import free_port
|
||||
from colossalai.utils import ColoInitContext
|
||||
from colossalai.tensor import TensorSpec, ComputePattern, ParallelAction, ColoTensor, ColoOptimizer, DistSpecManager, distspec
|
||||
from colossalai.tensor import TensorSpec, ComputePattern, ParallelAction, DistSpecManager, distspec
|
||||
from colossalai.core import global_context as gpc
|
||||
from functools import partial
|
||||
# Hack huggingface Bert ModelOutput
|
||||
# Make it available to our ColoTensor
|
||||
from transformers.file_utils import ModelOutput
|
||||
from dataclasses import fields
|
||||
from tests.test_tensor._utils import tensor_equal
|
||||
|
||||
|
||||
def _post_init_colotensor(self):
|
||||
class_fields = fields(self)
|
||||
# Safety and consistency checks
|
||||
if len(class_fields) == 0:
|
||||
raise ValueError(f"{self.__class__.__name__} has no fields.")
|
||||
if not all(field.default is None for field in class_fields[1:]):
|
||||
raise ValueError(f"{self.__class__.__name__} should not have more than one required field.")
|
||||
|
||||
first_field = getattr(self, class_fields[0].name)
|
||||
other_fields_are_none = all(getattr(self, field.name) is None for field in class_fields[1:])
|
||||
|
||||
def is_tensor_with_colo(x):
|
||||
"""
|
||||
Tests if `x` is a `ColoTensor` or `torch.Tensor`.
|
||||
"""
|
||||
if isinstance(x, torch.Tensor):
|
||||
return True
|
||||
|
||||
return isinstance(x, ColoTensor)
|
||||
|
||||
if other_fields_are_none and not is_tensor_with_colo(first_field):
|
||||
if isinstance(first_field, dict):
|
||||
iterator = first_field.items()
|
||||
first_field_iterator = True
|
||||
else:
|
||||
try:
|
||||
iterator = iter(first_field)
|
||||
first_field_iterator = True
|
||||
except TypeError:
|
||||
first_field_iterator = False
|
||||
|
||||
# if we provided an iterator as first field and the iterator is a (key, value) iterator
|
||||
# set the associated fields
|
||||
if first_field_iterator:
|
||||
for element in iterator:
|
||||
if (not isinstance(element, (list, tuple)) or not len(element) == 2 or not isinstance(element[0], str)):
|
||||
break
|
||||
setattr(self, element[0], element[1])
|
||||
if element[1] is not None:
|
||||
self[element[0]] = element[1]
|
||||
elif first_field is not None:
|
||||
self[class_fields[0].name] = first_field
|
||||
else:
|
||||
for field in class_fields:
|
||||
v = getattr(self, field.name)
|
||||
if v is not None:
|
||||
self[field.name] = v
|
||||
|
||||
|
||||
ModelOutput.__post_init__ = _post_init_colotensor
|
||||
|
||||
|
||||
class GPTLMModel(nn.Module):
|
||||
|
||||
def __init__(self,
|
||||
hidden_size=768,
|
||||
num_layers=12,
|
||||
num_attention_heads=12,
|
||||
max_seq_len=1024,
|
||||
vocab_size=50304,
|
||||
checkpoint=False):
|
||||
super().__init__()
|
||||
self.checkpoint = checkpoint
|
||||
self.model = GPT2LMHeadModel(
|
||||
GPT2Config(n_embd=hidden_size,
|
||||
n_layer=num_layers,
|
||||
n_head=num_attention_heads,
|
||||
n_positions=max_seq_len,
|
||||
n_ctx=max_seq_len,
|
||||
vocab_size=vocab_size,
|
||||
resid_pdrop=0.0,
|
||||
embd_pdrop=0.0,
|
||||
attn_pdrop=0.0))
|
||||
if checkpoint:
|
||||
self.model.gradient_checkpointing_enable()
|
||||
|
||||
def forward(self, input_ids, attention_mask):
|
||||
# Only return lm_logits
|
||||
return self.model(input_ids=input_ids, attention_mask=attention_mask, use_cache=not self.checkpoint)[0]
|
||||
|
||||
|
||||
def gpt2_s(checkpoint=True):
|
||||
return GPTLMModel(checkpoint=checkpoint)
|
||||
|
||||
|
||||
def gpt2_m(checkpoint=True):
|
||||
return GPTLMModel(hidden_size=1024, num_layers=24, num_attention_heads=16, checkpoint=checkpoint)
|
||||
|
||||
|
||||
class GPTLMLoss(nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.loss_fn = nn.CrossEntropyLoss()
|
||||
|
||||
def forward(self, logits, labels):
|
||||
shift_logits = logits[..., :-1, :].contiguous()
|
||||
shift_labels = labels[..., 1:].contiguous()
|
||||
# Flatten the tokens
|
||||
return self.loss_fn(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
|
||||
|
||||
|
||||
def set_seed(seed):
|
||||
random.seed(seed)
|
||||
os.environ['PYTHONHASHSEED'] = str(seed)
|
||||
np.random.seed(seed)
|
||||
torch.manual_seed(seed)
|
||||
torch.cuda.manual_seed(seed)
|
||||
torch.backends.cudnn.deterministic = True
|
||||
|
||||
|
||||
def get_data(batch_size, seq_len, vocab_size):
|
||||
input_ids = torch.randint(0, vocab_size, (batch_size, seq_len), device=torch.cuda.current_device())
|
||||
attention_mask = torch.ones_like(input_ids)
|
||||
return input_ids, attention_mask
|
||||
from _utils import tensor_equal, tensor_shard_equal
|
||||
from tests.components_to_test.registry import non_distributed_component_funcs
|
||||
|
||||
|
||||
def init_1d_row_spec(model):
|
||||
|
@ -159,30 +33,6 @@ def init_1d_col_spec(model):
|
|||
p.set_spec(spec)
|
||||
|
||||
|
||||
def check_tensor_equal_1d(tensor: torch.Tensor, shard: ColoTensor):
|
||||
world_size = gpc.get_world_size(ParallelMode.PARALLEL_1D)
|
||||
rank = gpc.get_local_rank(ParallelMode.PARALLEL_1D)
|
||||
assert len(shard.spec.dist_spec.dims) == 1
|
||||
dim = shard.spec.dist_spec.dims[0]
|
||||
assert torch.equal(tensor.chunk(world_size, dim)[rank], shard.torch_tensor())
|
||||
|
||||
|
||||
def tensor_shard_equal(tensor: torch.Tensor, shard: torch.Tensor):
|
||||
assert tensor.ndim == shard.ndim
|
||||
if tensor.shape == shard.shape:
|
||||
return tensor_equal(tensor, shard)
|
||||
else:
|
||||
dims_not_eq = torch.nonzero(torch.tensor(tensor.shape) != torch.tensor(shard.shape))
|
||||
if dims_not_eq.numel() == 1:
|
||||
# 1D shard
|
||||
dim = dims_not_eq.item()
|
||||
world_size = gpc.get_world_size(ParallelMode.PARALLEL_1D)
|
||||
rank = gpc.get_local_rank(ParallelMode.PARALLEL_1D)
|
||||
return tensor_equal(tensor.chunk(world_size, dim)[rank], shard)
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
def check_param_equal(model, torch_model):
|
||||
for p, torch_p in zip(model.parameters(), torch_model.parameters()):
|
||||
assert tensor_shard_equal(torch_p, p)
|
||||
|
@ -194,23 +44,20 @@ def check_grad_equal(model, torch_model):
|
|||
|
||||
|
||||
def run_gpt(init_spec_func):
|
||||
BATCH_SIZE = 4
|
||||
SEQ_LEN = 1024
|
||||
VOCAB_SIZE = 50304
|
||||
NUM_STEPS = 1
|
||||
criterion = GPTLMLoss()
|
||||
get_components_func = non_distributed_component_funcs.get_callable('gpt2')
|
||||
model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func()
|
||||
|
||||
with ColoInitContext(device=get_current_device()):
|
||||
model = gpt2_s()
|
||||
model = model_builder()
|
||||
model = model.cuda()
|
||||
torch_model = gpt2_s().cuda()
|
||||
torch_model = model_builder().cuda()
|
||||
for torch_p, p in zip(torch_model.parameters(), model.parameters()):
|
||||
torch_p.data.copy_(p)
|
||||
init_spec_func(model)
|
||||
check_param_equal(model, torch_model)
|
||||
model.train()
|
||||
torch_model.train()
|
||||
for i in range(NUM_STEPS):
|
||||
input_ids, attn_mask = get_data(BATCH_SIZE, SEQ_LEN, VOCAB_SIZE)
|
||||
for i, (input_ids, attn_mask) in enumerate(train_dataloader):
|
||||
logits = model(input_ids, attn_mask)
|
||||
torch_logits = torch_model(input_ids, attn_mask)
|
||||
assert tensor_equal(torch_logits, logits)
|
||||
|
@ -219,6 +66,8 @@ def run_gpt(init_spec_func):
|
|||
loss.backward()
|
||||
torch_loss.backward()
|
||||
check_grad_equal(model, torch_model)
|
||||
if i > 0:
|
||||
break
|
||||
|
||||
|
||||
def run_dist(rank, world_size, port):
|
||||
|
@ -237,4 +86,4 @@ def test_gpt(world_size):
|
|||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test_gpt(1)
|
||||
test_gpt(4)
|
||||
|
|
|
@ -13,6 +13,7 @@ from colossalai.testing import rerun_if_address_is_in_use
|
|||
from colossalai.utils import free_port
|
||||
from colossalai.core import global_context as gpc
|
||||
from colossalai.tensor import TensorSpec, ComputePattern, ParallelAction, DistSpecManager
|
||||
from _utils import tensor_equal, tensor_shard_equal
|
||||
|
||||
|
||||
def init_1d_row(weight, bias):
|
||||
|
@ -23,13 +24,6 @@ def init_1d_row(weight, bias):
|
|||
weight.set_spec(spec)
|
||||
|
||||
|
||||
def check_grad_1d_row(model: torch.nn.Module, weight, bias):
|
||||
rank = gpc.get_local_rank(ParallelMode.PARALLEL_1D)
|
||||
size = gpc.get_world_size(ParallelMode.PARALLEL_1D)
|
||||
assert torch.allclose(model.weight.grad.chunk(size, -1)[rank], weight.grad)
|
||||
assert torch.allclose(model.bias.grad, bias.grad)
|
||||
|
||||
|
||||
def init_1d_col(weight, bias):
|
||||
spec = TensorSpec(
|
||||
distspec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [0], [gpc.get_world_size(ParallelMode.PARALLEL_1D)]),
|
||||
|
@ -39,14 +33,7 @@ def init_1d_col(weight, bias):
|
|||
bias.set_spec(spec)
|
||||
|
||||
|
||||
def check_grad_1d_col(model: torch.nn.Module, weight, bias):
|
||||
rank = gpc.get_local_rank(ParallelMode.PARALLEL_1D)
|
||||
size = gpc.get_world_size(ParallelMode.PARALLEL_1D)
|
||||
assert torch.allclose(model.weight.grad.chunk(size, 0)[rank], weight.grad)
|
||||
assert torch.allclose(model.bias.grad.chunk(size, 0)[rank], bias.grad)
|
||||
|
||||
|
||||
def run_with_spec(spec_init_func, check_grad_func):
|
||||
def run_with_spec(spec_init_func):
|
||||
model = torch.nn.Linear(4, 8).cuda()
|
||||
weight = ColoTensor(torch.nn.Parameter(model.weight.detach()))
|
||||
bias = ColoTensor(torch.nn.Parameter(model.bias.detach()))
|
||||
|
@ -54,18 +41,19 @@ def run_with_spec(spec_init_func, check_grad_func):
|
|||
x = torch.rand(2, 4).cuda()
|
||||
out = model(x)
|
||||
colo_out = F.linear(x, weight, bias)
|
||||
assert torch.allclose(out, colo_out)
|
||||
assert tensor_equal(out, colo_out)
|
||||
grad = torch.rand_like(out)
|
||||
out.backward(grad)
|
||||
colo_out.backward(grad)
|
||||
check_grad_func(model, weight, bias)
|
||||
assert tensor_shard_equal(model.weight.grad, weight.grad)
|
||||
assert tensor_shard_equal(model.bias.grad, bias.grad)
|
||||
|
||||
|
||||
def run_dist(rank, world_size, port):
|
||||
config = dict(parallel=dict(tensor=dict(mode="1d", size=world_size),))
|
||||
colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
|
||||
run_with_spec(init_1d_row, check_grad_1d_row)
|
||||
run_with_spec(init_1d_col, check_grad_1d_col)
|
||||
run_with_spec(init_1d_row)
|
||||
run_with_spec(init_1d_col)
|
||||
|
||||
|
||||
@pytest.mark.dist
|
||||
|
|
|
@ -13,78 +13,8 @@ from colossalai.tensor import distspec, named_params_with_colotensor, TensorSpec
|
|||
ParallelAction, ColoTensor, ColoOptimizer, DistSpecManager
|
||||
from colossalai.context import ParallelMode
|
||||
from colossalai.core import global_context as gpc
|
||||
|
||||
from functools import partial
|
||||
import random
|
||||
import os
|
||||
import numpy as np
|
||||
|
||||
# Hack huggingface Bert ModelOutput
|
||||
# Make it available to our ColoTensor
|
||||
from transformers.file_utils import ModelOutput
|
||||
from dataclasses import fields
|
||||
|
||||
|
||||
def _post_init_colotensor(self):
|
||||
class_fields = fields(self)
|
||||
# Safety and consistency checks
|
||||
if len(class_fields) == 0:
|
||||
raise ValueError(f"{self.__class__.__name__} has no fields.")
|
||||
if not all(field.default is None for field in class_fields[1:]):
|
||||
raise ValueError(f"{self.__class__.__name__} should not have more than one required field.")
|
||||
|
||||
first_field = getattr(self, class_fields[0].name)
|
||||
other_fields_are_none = all(getattr(self, field.name) is None for field in class_fields[1:])
|
||||
|
||||
def is_tensor_with_colo(x):
|
||||
"""
|
||||
Tests if `x` is a `ColoTensor` or `torch.Tensor`.
|
||||
"""
|
||||
if isinstance(x, torch.Tensor):
|
||||
return True
|
||||
|
||||
return isinstance(x, ColoTensor)
|
||||
|
||||
if other_fields_are_none and not is_tensor_with_colo(first_field):
|
||||
if isinstance(first_field, dict):
|
||||
iterator = first_field.items()
|
||||
first_field_iterator = True
|
||||
else:
|
||||
try:
|
||||
iterator = iter(first_field)
|
||||
first_field_iterator = True
|
||||
except TypeError:
|
||||
first_field_iterator = False
|
||||
|
||||
# if we provided an iterator as first field and the iterator is a (key, value) iterator
|
||||
# set the associated fields
|
||||
if first_field_iterator:
|
||||
for element in iterator:
|
||||
if (not isinstance(element, (list, tuple)) or not len(element) == 2 or not isinstance(element[0], str)):
|
||||
break
|
||||
setattr(self, element[0], element[1])
|
||||
if element[1] is not None:
|
||||
self[element[0]] = element[1]
|
||||
elif first_field is not None:
|
||||
self[class_fields[0].name] = first_field
|
||||
else:
|
||||
for field in class_fields:
|
||||
v = getattr(self, field.name)
|
||||
if v is not None:
|
||||
self[field.name] = v
|
||||
|
||||
|
||||
ModelOutput.__post_init__ = _post_init_colotensor
|
||||
# complete the hack
|
||||
|
||||
|
||||
def set_seed(seed):
|
||||
random.seed(seed)
|
||||
os.environ['PYTHONHASHSEED'] = str(seed)
|
||||
np.random.seed(seed)
|
||||
torch.manual_seed(seed)
|
||||
torch.cuda.manual_seed(seed)
|
||||
torch.backends.cudnn.deterministic = True
|
||||
from _utils import set_seed
|
||||
|
||||
|
||||
def init_1d_row_linear(weight):
|
||||
|
|
Loading…
Reference in New Issue