[unit test] refactor test tensor (#1005)

* polish test_gpt

* update op unit tests

* update test model
pull/1003/head
ver217 2022-05-19 18:57:56 +08:00 committed by GitHub
parent ad536e308e
commit 8e3d0ad8f1
8 changed files with 143 additions and 290 deletions

View File

@@ -1 +1 @@
-from . import repeated_computed_layer, resnet, nested_model, bert, no_leaf_module, simple_net
+from . import repeated_computed_layer, resnet, nested_model, bert, no_leaf_module, simple_net, gpt

View File

@@ -0,0 +1,79 @@
+import torch
+import torch.nn as nn
+from .registry import non_distributed_component_funcs
+from transformers import GPT2Config, GPT2LMHeadModel
+from .utils.dummy_data_generator import DummyDataGenerator
+from colossalai.utils.cuda import get_current_device
+
+
+class DummyDataLoader(DummyDataGenerator):
+    vocab_size = 50304
+    batch_size = 4
+    seq_len = 1024
+
+    def generate(self):
+        input_ids = torch.randint(0,
+                                  DummyDataLoader.vocab_size, (DummyDataLoader.batch_size, DummyDataLoader.seq_len),
+                                  device=get_current_device())
+        attention_mask = torch.ones_like(input_ids)
+        return input_ids, attention_mask
+
+
+class GPTLMModel(nn.Module):
+
+    def __init__(self,
+                 hidden_size=768,
+                 num_layers=12,
+                 num_attention_heads=12,
+                 max_seq_len=1024,
+                 vocab_size=50304,
+                 checkpoint=False):
+        super().__init__()
+        self.checkpoint = checkpoint
+        self.model = GPT2LMHeadModel(
+            GPT2Config(n_embd=hidden_size,
+                       n_layer=num_layers,
+                       n_head=num_attention_heads,
+                       n_positions=max_seq_len,
+                       n_ctx=max_seq_len,
+                       vocab_size=vocab_size,
+                       resid_pdrop=0.0,
+                       embd_pdrop=0.0,
+                       attn_pdrop=0.0))
+        if checkpoint:
+            self.model.gradient_checkpointing_enable()
+
+    def forward(self, input_ids, attention_mask):
+        # Only return lm_logits
+        return self.model(input_ids=input_ids, attention_mask=attention_mask, use_cache=not self.checkpoint)[0]
+
+
+def gpt2_s(checkpoint=True):
+    return GPTLMModel(checkpoint=checkpoint)
+
+
+def gpt2_m(checkpoint=True):
+    return GPTLMModel(hidden_size=1024, num_layers=24, num_attention_heads=16, checkpoint=checkpoint)
+
+
+class GPTLMLoss(nn.Module):
+
+    def __init__(self):
+        super().__init__()
+        self.loss_fn = nn.CrossEntropyLoss()
+
+    def forward(self, logits, labels):
+        shift_logits = logits[..., :-1, :].contiguous()
+        shift_labels = labels[..., 1:].contiguous()
+        # Flatten the tokens
+        return self.loss_fn(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
+
+
+@non_distributed_component_funcs.register(name='gpt2')
+def get_training_components():
+    trainloader = DummyDataLoader()
+    testloader = DummyDataLoader()
+    criterion = GPTLMLoss()
+    return gpt2_s, trainloader, testloader, torch.optim.Adam, criterion
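
Note: the component registered above is looked up by name in the refactored test_gpt.py later in this commit. A minimal usage sketch (assuming the tests package is importable and a CUDA device is available; the optimizer learning rate and the criterion call are illustrative, mirroring how the test consumes the components):

    from tests.components_to_test.registry import non_distributed_component_funcs

    get_components_func = non_distributed_component_funcs.get_callable('gpt2')
    model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func()

    model = model_builder().cuda()
    optimizer = optimizer_class(model.parameters(), lr=1e-3)    # illustrative lr
    model.train()
    for i, (input_ids, attn_mask) in enumerate(train_dataloader):
        logits = model(input_ids, attn_mask)
        loss = criterion(logits, input_ids)    # LM loss: the input ids double as labels
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
        if i > 0:    # a couple of dummy batches are enough for a smoke test
            break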

View File

@@ -1,5 +1,19 @@
+import os
+import random
+import numpy as np
 import torch
 import torch.distributed as dist
+from colossalai.core import global_context as gpc
+from colossalai.context import ParallelMode
+
+
+def set_seed(seed):
+    random.seed(seed)
+    os.environ['PYTHONHASHSEED'] = str(seed)
+    np.random.seed(seed)
+    torch.manual_seed(seed)
+    torch.cuda.manual_seed(seed)
+    torch.backends.cudnn.deterministic = True
 
 
 def check_equal(A, B):
@@ -25,3 +39,19 @@ def broadcast_tensor_chunk(tensor, chunk_size=1, local_rank=0):
 
 def tensor_equal(A, B):
     return torch.allclose(A, B, rtol=1e-3, atol=1e-1)
+
+
+def tensor_shard_equal(tensor: torch.Tensor, shard: torch.Tensor):
+    assert tensor.ndim == shard.ndim
+    if tensor.shape == shard.shape:
+        return tensor_equal(tensor, shard)
+    else:
+        dims_not_eq = torch.nonzero(torch.tensor(tensor.shape) != torch.tensor(shard.shape))
+        if dims_not_eq.numel() == 1:
+            # 1D shard
+            dim = dims_not_eq.item()
+            world_size = gpc.get_world_size(ParallelMode.PARALLEL_1D)
+            rank = gpc.get_local_rank(ParallelMode.PARALLEL_1D)
+            return tensor_equal(tensor.chunk(world_size, dim)[rank], shard)
+        else:
+            raise NotImplementedError
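
For reference, a small self-contained sketch of what tensor_shard_equal checks, with the 1-D process-group size and rank hard-coded instead of taken from gpc (illustrative values only):

    import torch

    world_size, rank = 2, 0    # stand-ins for gpc.get_world_size / gpc.get_local_rank on PARALLEL_1D
    full = torch.arange(8.0).reshape(4, 2)     # the full, unsharded tensor
    shard = full.chunk(world_size, 0)[rank]    # this rank's shard along dim 0

    # Same logic as tensor_shard_equal: locate the mismatching dim, chunk the
    # full tensor along it, and compare this rank's chunk with the local shard.
    dim = torch.nonzero(torch.tensor(full.shape) != torch.tensor(shard.shape)).item()
    assert torch.allclose(full.chunk(world_size, dim)[rank], shard, rtol=1e-3, atol=1e-1)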

View File

@@ -11,6 +11,7 @@ from colossalai.testing import rerun_if_address_is_in_use
 from colossalai.utils import free_port
 from functools import partial
 from colossalai.core import global_context as gpc
+from _utils import tensor_shard_equal, tensor_equal
 
 
 class Conv1D(nn.Module):
@@ -45,13 +46,6 @@ def init_1d_row(weight, bias):
         weight.set_spec(spec)
 
 
-def check_grad_1d_row(model: torch.nn.Module, weight, bias):
-    rank = gpc.get_local_rank(ParallelMode.PARALLEL_1D)
-    size = gpc.get_world_size(ParallelMode.PARALLEL_1D)
-    assert torch.allclose(model.weight.grad.chunk(size, 0)[rank], weight.grad)
-    assert torch.allclose(model.bias.grad, bias.grad)
-
-
 def init_1d_col(weight, bias):
     spec = TensorSpec(
         distspec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [-1], [gpc.get_world_size(ParallelMode.PARALLEL_1D)]),
@@ -61,14 +55,7 @@ def init_1d_col(weight, bias):
         bias.set_spec(spec)
 
 
-def check_grad_1d_col(model: torch.nn.Module, weight, bias):
-    rank = gpc.get_local_rank(ParallelMode.PARALLEL_1D)
-    size = gpc.get_world_size(ParallelMode.PARALLEL_1D)
-    assert torch.allclose(model.weight.grad.chunk(size, -1)[rank], weight.grad)
-    assert torch.allclose(model.bias.grad.chunk(size, -1)[rank], bias.grad)
-
-
-def run_with_spec(spec_init_func, check_grad_func):
+def run_with_spec(spec_init_func):
     model = Conv1D(4, 16).cuda()
     weight = ColoTensor(torch.nn.Parameter(model.weight.detach()))
     bias = ColoTensor(torch.nn.Parameter(model.bias.detach()))
@@ -76,18 +63,19 @@ def run_with_spec(spec_init_func, check_grad_func):
     x = torch.rand(2, 16).cuda()
     out = model(x)
     colo_out = torch.addmm(bias, x, weight)
-    assert torch.allclose(out, colo_out)
+    assert tensor_equal(out, colo_out)
     grad = torch.rand_like(out)
     out.backward(grad)
     colo_out.backward(grad)
-    check_grad_func(model, weight, bias)
+    assert tensor_shard_equal(model.weight.grad, weight.grad)
+    assert tensor_shard_equal(model.bias.grad, bias.grad)
 
 
 def run_dist(rank, world_size, port):
     config = dict(parallel=dict(tensor=dict(mode="1d", size=world_size),))
     colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
-    run_with_spec(init_1d_row, check_grad_1d_row)
-    run_with_spec(init_1d_col, check_grad_1d_col)
+    run_with_spec(init_1d_row)
+    run_with_spec(init_1d_col)
 
 
 @pytest.mark.dist

View File

@@ -12,6 +12,7 @@ from colossalai.testing import rerun_if_address_is_in_use
 from colossalai.utils import free_port
 from colossalai.core import global_context as gpc
 from colossalai.tensor import TensorSpec, ComputePattern, ParallelAction, DistSpecManager
+from _utils import tensor_equal, tensor_shard_equal
 
 
 def init_1d_row(weight):
@@ -22,12 +23,6 @@ def init_1d_row(weight):
         weight.set_spec(spec)
 
 
-def check_grad_1d_row(model: torch.nn.Module, weight):
-    rank = gpc.get_local_rank(ParallelMode.PARALLEL_1D)
-    size = gpc.get_world_size(ParallelMode.PARALLEL_1D)
-    assert torch.allclose(model.weight.grad.chunk(size, 0)[rank], weight.grad)
-
-
 def init_1d_col(weight):
     spec = TensorSpec(
         distspec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [-1], [gpc.get_world_size(ParallelMode.PARALLEL_1D)]),
@@ -36,31 +31,25 @@ def init_1d_col(weight):
         weight.set_spec(spec)
 
 
-def check_grad_1d_col(model: torch.nn.Module, weight):
-    rank = gpc.get_local_rank(ParallelMode.PARALLEL_1D)
-    size = gpc.get_world_size(ParallelMode.PARALLEL_1D)
-    assert torch.allclose(model.weight.grad.chunk(size, -1)[rank], weight.grad)
-
-
-def run_with_spec(spec_init_func, check_grad_func):
+def run_with_spec(spec_init_func):
     model = torch.nn.Embedding(12, 32).cuda()
     weight = ColoTensor(torch.nn.Parameter(model.weight.detach()))
     spec_init_func(weight)
     x = torch.tensor((0, 3, 6, 9)).cuda()
     out = model(x)
     colo_out = F.embedding(x, weight)
-    assert torch.allclose(out, colo_out)
+    assert tensor_equal(out, colo_out)
     grad = torch.rand_like(out)
     out.backward(grad)
     colo_out.backward(grad)
-    check_grad_func(model, weight)
+    assert tensor_shard_equal(model.weight.grad, weight.grad)
 
 
 def run_dist(rank, world_size, port):
     config = dict(parallel=dict(tensor=dict(mode="1d", size=world_size),))
     colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
-    run_with_spec(init_1d_row, check_grad_1d_row)
-    run_with_spec(init_1d_col, check_grad_1d_col)
+    run_with_spec(init_1d_row)
+    run_with_spec(init_1d_col)
 
 
 @pytest.mark.dist

View File

@@ -1,142 +1,16 @@
 import pytest
 import colossalai
-import os
-import random
-import numpy as np
-import torch
-import torch.nn as nn
 from colossalai.context.parallel_mode import ParallelMode
-from transformers import GPT2Config, GPT2LMHeadModel
 import torch.multiprocessing as mp
 from colossalai.testing import rerun_if_address_is_in_use
 from colossalai.utils.cuda import get_current_device
 from colossalai.utils import free_port
 from colossalai.utils import ColoInitContext
-from colossalai.tensor import TensorSpec, ComputePattern, ParallelAction, ColoTensor, ColoOptimizer, DistSpecManager, distspec
+from colossalai.tensor import TensorSpec, ComputePattern, ParallelAction, DistSpecManager, distspec
 from colossalai.core import global_context as gpc
 from functools import partial
-# Hack huggingface Bert ModelOutput
-# Make it available to our ColoTensor
-from transformers.file_utils import ModelOutput
-from dataclasses import fields
-from tests.test_tensor._utils import tensor_equal
-
-
-def _post_init_colotensor(self):
-    class_fields = fields(self)
-    # Safety and consistency checks
-    if len(class_fields) == 0:
-        raise ValueError(f"{self.__class__.__name__} has no fields.")
-    if not all(field.default is None for field in class_fields[1:]):
-        raise ValueError(f"{self.__class__.__name__} should not have more than one required field.")
-
-    first_field = getattr(self, class_fields[0].name)
-    other_fields_are_none = all(getattr(self, field.name) is None for field in class_fields[1:])
-
-    def is_tensor_with_colo(x):
-        """
-        Tests if `x` is a `ColoTensor` or `torch.Tensor`.
-        """
-        if isinstance(x, torch.Tensor):
-            return True
-        return isinstance(x, ColoTensor)
-
-    if other_fields_are_none and not is_tensor_with_colo(first_field):
-        if isinstance(first_field, dict):
-            iterator = first_field.items()
-            first_field_iterator = True
-        else:
-            try:
-                iterator = iter(first_field)
-                first_field_iterator = True
-            except TypeError:
-                first_field_iterator = False
-
-        # if we provided an iterator as first field and the iterator is a (key, value) iterator
-        # set the associated fields
-        if first_field_iterator:
-            for element in iterator:
-                if (not isinstance(element, (list, tuple)) or not len(element) == 2 or not isinstance(element[0], str)):
-                    break
-                setattr(self, element[0], element[1])
-                if element[1] is not None:
-                    self[element[0]] = element[1]
-        elif first_field is not None:
-            self[class_fields[0].name] = first_field
-    else:
-        for field in class_fields:
-            v = getattr(self, field.name)
-            if v is not None:
-                self[field.name] = v
-
-
-ModelOutput.__post_init__ = _post_init_colotensor
-
-
-class GPTLMModel(nn.Module):
-
-    def __init__(self,
-                 hidden_size=768,
-                 num_layers=12,
-                 num_attention_heads=12,
-                 max_seq_len=1024,
-                 vocab_size=50304,
-                 checkpoint=False):
-        super().__init__()
-        self.checkpoint = checkpoint
-        self.model = GPT2LMHeadModel(
-            GPT2Config(n_embd=hidden_size,
-                       n_layer=num_layers,
-                       n_head=num_attention_heads,
-                       n_positions=max_seq_len,
-                       n_ctx=max_seq_len,
-                       vocab_size=vocab_size,
-                       resid_pdrop=0.0,
-                       embd_pdrop=0.0,
-                       attn_pdrop=0.0))
-        if checkpoint:
-            self.model.gradient_checkpointing_enable()
-
-    def forward(self, input_ids, attention_mask):
-        # Only return lm_logits
-        return self.model(input_ids=input_ids, attention_mask=attention_mask, use_cache=not self.checkpoint)[0]
-
-
-def gpt2_s(checkpoint=True):
-    return GPTLMModel(checkpoint=checkpoint)
-
-
-def gpt2_m(checkpoint=True):
-    return GPTLMModel(hidden_size=1024, num_layers=24, num_attention_heads=16, checkpoint=checkpoint)
-
-
-class GPTLMLoss(nn.Module):
-
-    def __init__(self):
-        super().__init__()
-        self.loss_fn = nn.CrossEntropyLoss()
-
-    def forward(self, logits, labels):
-        shift_logits = logits[..., :-1, :].contiguous()
-        shift_labels = labels[..., 1:].contiguous()
-        # Flatten the tokens
-        return self.loss_fn(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
-
-
-def set_seed(seed):
-    random.seed(seed)
-    os.environ['PYTHONHASHSEED'] = str(seed)
-    np.random.seed(seed)
-    torch.manual_seed(seed)
-    torch.cuda.manual_seed(seed)
-    torch.backends.cudnn.deterministic = True
-
-
-def get_data(batch_size, seq_len, vocab_size):
-    input_ids = torch.randint(0, vocab_size, (batch_size, seq_len), device=torch.cuda.current_device())
-    attention_mask = torch.ones_like(input_ids)
-    return input_ids, attention_mask
+from _utils import tensor_equal, tensor_shard_equal
+from tests.components_to_test.registry import non_distributed_component_funcs
 
 
 def init_1d_row_spec(model):
@@ -159,30 +33,6 @@ def init_1d_col_spec(model):
                 p.set_spec(spec)
 
 
-def check_tensor_equal_1d(tensor: torch.Tensor, shard: ColoTensor):
-    world_size = gpc.get_world_size(ParallelMode.PARALLEL_1D)
-    rank = gpc.get_local_rank(ParallelMode.PARALLEL_1D)
-    assert len(shard.spec.dist_spec.dims) == 1
-    dim = shard.spec.dist_spec.dims[0]
-    assert torch.equal(tensor.chunk(world_size, dim)[rank], shard.torch_tensor())
-
-
-def tensor_shard_equal(tensor: torch.Tensor, shard: torch.Tensor):
-    assert tensor.ndim == shard.ndim
-    if tensor.shape == shard.shape:
-        return tensor_equal(tensor, shard)
-    else:
-        dims_not_eq = torch.nonzero(torch.tensor(tensor.shape) != torch.tensor(shard.shape))
-        if dims_not_eq.numel() == 1:
-            # 1D shard
-            dim = dims_not_eq.item()
-            world_size = gpc.get_world_size(ParallelMode.PARALLEL_1D)
-            rank = gpc.get_local_rank(ParallelMode.PARALLEL_1D)
-            return tensor_equal(tensor.chunk(world_size, dim)[rank], shard)
-        else:
-            raise NotImplementedError
-
-
 def check_param_equal(model, torch_model):
     for p, torch_p in zip(model.parameters(), torch_model.parameters()):
         assert tensor_shard_equal(torch_p, p)
@@ -194,23 +44,20 @@ def check_grad_equal(model, torch_model):
 
 def run_gpt(init_spec_func):
-    BATCH_SIZE = 4
-    SEQ_LEN = 1024
-    VOCAB_SIZE = 50304
-    NUM_STEPS = 1
-    criterion = GPTLMLoss()
+    get_components_func = non_distributed_component_funcs.get_callable('gpt2')
+    model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func()
     with ColoInitContext(device=get_current_device()):
-        model = gpt2_s()
+        model = model_builder()
     model = model.cuda()
-    torch_model = gpt2_s().cuda()
+    torch_model = model_builder().cuda()
     for torch_p, p in zip(torch_model.parameters(), model.parameters()):
         torch_p.data.copy_(p)
     init_spec_func(model)
     check_param_equal(model, torch_model)
     model.train()
     torch_model.train()
-    for i in range(NUM_STEPS):
-        input_ids, attn_mask = get_data(BATCH_SIZE, SEQ_LEN, VOCAB_SIZE)
+    for i, (input_ids, attn_mask) in enumerate(train_dataloader):
         logits = model(input_ids, attn_mask)
         torch_logits = torch_model(input_ids, attn_mask)
         assert tensor_equal(torch_logits, logits)
@@ -219,6 +66,8 @@ def run_gpt(init_spec_func):
         loss.backward()
         torch_loss.backward()
         check_grad_equal(model, torch_model)
+        if i > 0:
+            break
 
 
 def run_dist(rank, world_size, port):
@@ -237,4 +86,4 @@ def test_gpt(world_size):
 
 
 if __name__ == '__main__':
-    test_gpt(1)
+    test_gpt(4)
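
For context, the parts of test_gpt.py outside these hunks follow the usual launcher pattern in this test suite; a rough sketch, with the parametrized world sizes and decorator arguments assumed from similar files rather than taken from this diff:

    @pytest.mark.dist
    @pytest.mark.parametrize('world_size', [1, 4])    # assumed values, not shown in the diff
    @rerun_if_address_is_in_use()
    def test_gpt(world_size):
        run_func = partial(run_dist, world_size=world_size, port=free_port())
        mp.spawn(run_func, nprocs=world_size)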

View File

@@ -13,6 +13,7 @@ from colossalai.testing import rerun_if_address_is_in_use
 from colossalai.utils import free_port
 from colossalai.core import global_context as gpc
 from colossalai.tensor import TensorSpec, ComputePattern, ParallelAction, DistSpecManager
+from _utils import tensor_equal, tensor_shard_equal
 
 
 def init_1d_row(weight, bias):
@@ -23,13 +24,6 @@ def init_1d_row(weight, bias):
         weight.set_spec(spec)
 
 
-def check_grad_1d_row(model: torch.nn.Module, weight, bias):
-    rank = gpc.get_local_rank(ParallelMode.PARALLEL_1D)
-    size = gpc.get_world_size(ParallelMode.PARALLEL_1D)
-    assert torch.allclose(model.weight.grad.chunk(size, -1)[rank], weight.grad)
-    assert torch.allclose(model.bias.grad, bias.grad)
-
-
 def init_1d_col(weight, bias):
     spec = TensorSpec(
         distspec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [0], [gpc.get_world_size(ParallelMode.PARALLEL_1D)]),
@@ -39,14 +33,7 @@ def init_1d_col(weight, bias):
         bias.set_spec(spec)
 
 
-def check_grad_1d_col(model: torch.nn.Module, weight, bias):
-    rank = gpc.get_local_rank(ParallelMode.PARALLEL_1D)
-    size = gpc.get_world_size(ParallelMode.PARALLEL_1D)
-    assert torch.allclose(model.weight.grad.chunk(size, 0)[rank], weight.grad)
-    assert torch.allclose(model.bias.grad.chunk(size, 0)[rank], bias.grad)
-
-
-def run_with_spec(spec_init_func, check_grad_func):
+def run_with_spec(spec_init_func):
     model = torch.nn.Linear(4, 8).cuda()
     weight = ColoTensor(torch.nn.Parameter(model.weight.detach()))
     bias = ColoTensor(torch.nn.Parameter(model.bias.detach()))
@@ -54,18 +41,19 @@ def run_with_spec(spec_init_func, check_grad_func):
     x = torch.rand(2, 4).cuda()
     out = model(x)
     colo_out = F.linear(x, weight, bias)
-    assert torch.allclose(out, colo_out)
+    assert tensor_equal(out, colo_out)
    grad = torch.rand_like(out)
     out.backward(grad)
     colo_out.backward(grad)
-    check_grad_func(model, weight, bias)
+    assert tensor_shard_equal(model.weight.grad, weight.grad)
+    assert tensor_shard_equal(model.bias.grad, bias.grad)
 
 
 def run_dist(rank, world_size, port):
     config = dict(parallel=dict(tensor=dict(mode="1d", size=world_size),))
     colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
-    run_with_spec(init_1d_row, check_grad_1d_row)
-    run_with_spec(init_1d_col, check_grad_1d_col)
+    run_with_spec(init_1d_row)
+    run_with_spec(init_1d_col)
 
 
 @pytest.mark.dist

View File

@@ -13,78 +13,8 @@ from colossalai.tensor import distspec, named_params_with_colotensor, TensorSpec
     ParallelAction, ColoTensor, ColoOptimizer, DistSpecManager
 from colossalai.context import ParallelMode
 from colossalai.core import global_context as gpc
 from functools import partial
-import random
-import os
-import numpy as np
-
-# Hack huggingface Bert ModelOutput
-# Make it available to our ColoTensor
-from transformers.file_utils import ModelOutput
-from dataclasses import fields
-
-
-def _post_init_colotensor(self):
-    class_fields = fields(self)
-    # Safety and consistency checks
-    if len(class_fields) == 0:
-        raise ValueError(f"{self.__class__.__name__} has no fields.")
-    if not all(field.default is None for field in class_fields[1:]):
-        raise ValueError(f"{self.__class__.__name__} should not have more than one required field.")
-
-    first_field = getattr(self, class_fields[0].name)
-    other_fields_are_none = all(getattr(self, field.name) is None for field in class_fields[1:])
-
-    def is_tensor_with_colo(x):
-        """
-        Tests if `x` is a `ColoTensor` or `torch.Tensor`.
-        """
-        if isinstance(x, torch.Tensor):
-            return True
-        return isinstance(x, ColoTensor)
-
-    if other_fields_are_none and not is_tensor_with_colo(first_field):
-        if isinstance(first_field, dict):
-            iterator = first_field.items()
-            first_field_iterator = True
-        else:
-            try:
-                iterator = iter(first_field)
-                first_field_iterator = True
-            except TypeError:
-                first_field_iterator = False
-
-        # if we provided an iterator as first field and the iterator is a (key, value) iterator
-        # set the associated fields
-        if first_field_iterator:
-            for element in iterator:
-                if (not isinstance(element, (list, tuple)) or not len(element) == 2 or not isinstance(element[0], str)):
-                    break
-                setattr(self, element[0], element[1])
-                if element[1] is not None:
-                    self[element[0]] = element[1]
-        elif first_field is not None:
-            self[class_fields[0].name] = first_field
-    else:
-        for field in class_fields:
-            v = getattr(self, field.name)
-            if v is not None:
-                self[field.name] = v
-
-
-ModelOutput.__post_init__ = _post_init_colotensor
-# complete the hack
-
-
-def set_seed(seed):
-    random.seed(seed)
-    os.environ['PYTHONHASHSEED'] = str(seed)
-    np.random.seed(seed)
-    torch.manual_seed(seed)
-    torch.cuda.manual_seed(seed)
-    torch.backends.cudnn.deterministic = True
+from _utils import set_seed
 
 
 def init_1d_row_linear(weight):