mirror of https://github.com/hpcaitech/ColossalAI
You cannot select more than 25 topics.
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
241 lines
8.5 KiB
241 lines
8.5 KiB
import pytest
|
|
import colossalai
|
|
import os
|
|
import random
|
|
import numpy as np
|
|
import torch
|
|
import torch.nn as nn
|
|
from colossalai.context.parallel_mode import ParallelMode
|
|
from transformers import GPT2Config, GPT2LMHeadModel
|
|
import torch.multiprocessing as mp
|
|
from colossalai.testing import rerun_if_address_is_in_use
|
|
from colossalai.utils.cuda import get_current_device
|
|
from colossalai.utils import free_port
|
|
from colossalai.utils import ColoInitContext
|
|
from colossalai.tensor import TensorSpec, ComputePattern, ParallelAction, ColoTensor, ColoOptimizer, DistSpecManager, distspec
|
|
from colossalai.core import global_context as gpc
|
|
from functools import partial
|
|
# Monkey-patch Hugging Face's ModelOutput.__post_init__ (shared by every model output,
# including GPT-2's) so that ColoTensor values are recognized as tensors.
|
|
from transformers.file_utils import ModelOutput
|
|
from dataclasses import fields
|
|
from tests.test_tensor._utils import tensor_equal
|
|
|
|
|
|
def _is_tensor_with_colo(x):
    """Return True if ``x`` is a ``torch.Tensor`` or a ``ColoTensor``.

    ``ColoTensor`` is checked explicitly because it may not be a ``torch.Tensor``
    subclass in this ColossalAI version (NOTE(review): confirm).
    """
    if isinstance(x, torch.Tensor):
        return True
    return isinstance(x, ColoTensor)


def _post_init_colotensor(self):
    """Replacement for ``ModelOutput.__post_init__`` that treats ``ColoTensor`` as a tensor.

    Behaviour:
      * If only the first dataclass field is set and it is not a (Colo)tensor, it is
        inspected as a dict or an iterable of ``(str, value)`` pairs. Each pair is set
        as an attribute and, when the value is not None, as a mapping item.
      * If that iterable does not yield a key/value pair as its very first element,
        the first field itself is stored under its field name instead of being dropped.
        A later non-pair element simply stops the scan.
      * Otherwise every non-None field is stored in the mapping under its field name.

    Raises:
        ValueError: if the dataclass has no fields, or if any field after the first
            has a default other than None.
    """
    class_fields = fields(self)

    # Safety and consistency checks
    if len(class_fields) == 0:
        raise ValueError(f"{self.__class__.__name__} has no fields.")
    if not all(field.default is None for field in class_fields[1:]):
        raise ValueError(f"{self.__class__.__name__} should not have more than one required field.")

    first_field = getattr(self, class_fields[0].name)
    other_fields_are_none = all(getattr(self, field.name) is None for field in class_fields[1:])

    if other_fields_are_none and not _is_tensor_with_colo(first_field):
        if isinstance(first_field, dict):
            iterator = first_field.items()
            first_field_iterator = True
        else:
            try:
                iterator = iter(first_field)
                first_field_iterator = True
            except TypeError:
                first_field_iterator = False

        # If the first field is a (key, value) iterator, set the associated fields.
        if first_field_iterator:
            for idx, element in enumerate(iterator):
                if not isinstance(element, (list, tuple)) or len(element) != 2 or not isinstance(element[0], str):
                    if idx == 0:
                        # Not a key/value iterator at all (e.g. a tuple of tensors):
                        # keep the first field as a plain value instead of dropping it.
                        self[class_fields[0].name] = first_field
                    break
                setattr(self, element[0], element[1])
                if element[1] is not None:
                    self[element[0]] = element[1]
        elif first_field is not None:
            self[class_fields[0].name] = first_field
    else:
        for field in class_fields:
            v = getattr(self, field.name)
            if v is not None:
                self[field.name] = v
|
|
|
|
|
|
# Install the ColoTensor-aware __post_init__ on Hugging Face's ModelOutput so that
# model outputs built from ColoTensor values are populated like ordinary tensors.
setattr(ModelOutput, "__post_init__", _post_init_colotensor)
|
|
|
|
|
|
class GPTLMModel(nn.Module):
    """Wrapper around Hugging Face ``GPT2LMHeadModel`` whose forward returns only the LM logits.

    Args:
        hidden_size: embedding / hidden width, passed as ``n_embd``.
        num_layers: number of transformer blocks, passed as ``n_layer``.
        num_attention_heads: attention heads per block, passed as ``n_head``.
        max_seq_len: maximum sequence length, passed as both ``n_positions`` and ``n_ctx``.
        vocab_size: vocabulary size.
        checkpoint: if True, enable gradient checkpointing on the wrapped model and
            call it with ``use_cache=False``.

    Residual, embedding and attention dropout are all set to 0.0.
    """

    def __init__(self,
                 hidden_size=768,
                 num_layers=12,
                 num_attention_heads=12,
                 max_seq_len=1024,
                 vocab_size=50304,
                 checkpoint=False):
        super().__init__()
        self.checkpoint = checkpoint
        config = GPT2Config(
            n_embd=hidden_size,
            n_layer=num_layers,
            n_head=num_attention_heads,
            n_positions=max_seq_len,
            n_ctx=max_seq_len,
            vocab_size=vocab_size,
            resid_pdrop=0.0,
            embd_pdrop=0.0,
            attn_pdrop=0.0,
        )
        self.model = GPT2LMHeadModel(config)
        if checkpoint:
            self.model.gradient_checkpointing_enable()

    def forward(self, input_ids, attention_mask):
        """Run the wrapped model and return its first output (the LM logits)."""
        # Caching is only requested when gradient checkpointing is off.
        use_cache = not self.checkpoint
        outputs = self.model(input_ids=input_ids, attention_mask=attention_mask, use_cache=use_cache)
        return outputs[0]
|
|
|
|
|
|
def gpt2_s(checkpoint=True):
    """Build a ``GPTLMModel`` with its default (small) configuration.

    Args:
        checkpoint: forwarded to ``GPTLMModel``; enables gradient checkpointing.
    """
    small_model = GPTLMModel(checkpoint=checkpoint)
    return small_model
|
|
|
|
|
|
def gpt2_m(checkpoint=True):
    """Build a medium-sized ``GPTLMModel``: hidden size 1024, 24 layers, 16 heads.

    Args:
        checkpoint: forwarded to ``GPTLMModel``; enables gradient checkpointing.
    """
    medium_sizes = dict(hidden_size=1024, num_layers=24, num_attention_heads=16)
    return GPTLMModel(checkpoint=checkpoint, **medium_sizes)
|
|
|
|
|
|
class GPTLMLoss(nn.Module):
    """Next-token cross-entropy loss: logits at position t are scored against the label at t + 1."""

    def __init__(self):
        super().__init__()
        self.loss_fn = nn.CrossEntropyLoss()

    def forward(self, logits, labels):
        """Compute the mean shifted cross-entropy.

        Args:
            logits: scores whose last dim is the vocabulary and second-to-last is the
                sequence position.
            labels: token ids aligned with ``logits`` along the sequence dim.

        Returns:
            Scalar loss from ``nn.CrossEntropyLoss`` over all flattened positions.
        """
        # Drop the final prediction (nothing to predict) and the first label (never predicted).
        predictions = logits[..., :-1, :].contiguous()
        targets = labels[..., 1:].contiguous()
        vocab = predictions.size(-1)
        flat_predictions = predictions.view(-1, vocab)
        flat_targets = targets.view(-1)
        return self.loss_fn(flat_predictions, flat_targets)
|
|
|
|
|
|
def set_seed(seed):
    """Seed every RNG used by these tests and request reproducible cuDNN behaviour.

    Seeds Python's ``random``, NumPy and PyTorch (CPU and current CUDA device),
    exports ``PYTHONHASHSEED`` and configures cuDNN for determinism.

    Args:
        seed: integer seed applied to every generator.
    """
    random.seed(seed)
    # Only affects processes started after this point; the running interpreter's
    # hash seed is fixed at startup.
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True
    # deterministic=True alone is not sufficient: benchmark mode may select
    # different convolution algorithms from run to run.
    torch.backends.cudnn.benchmark = False
|
|
|
|
|
|
def get_data(batch_size, seq_len, vocab_size, device=None):
    """Create a random batch of token ids with an all-ones attention mask.

    Args:
        batch_size: number of sequences.
        seq_len: tokens per sequence.
        vocab_size: ids are drawn uniformly from ``[0, vocab_size)``.
        device: where to allocate the tensors. Defaults to the current CUDA device,
            matching the previous behaviour.

    Returns:
        ``(input_ids, attention_mask)``, both shaped ``(batch_size, seq_len)`` with the
        same dtype; the mask is all ones (no padding).
    """
    if device is None:
        device = torch.cuda.current_device()
    input_ids = torch.randint(0, vocab_size, (batch_size, seq_len), device=device)
    attention_mask = torch.ones_like(input_ids)
    return input_ids, attention_mask
|
|
|
|
|
|
def init_1d_row_spec(model):
    """Shard every non-LayerNorm weight of ``model`` on dim 0 across the PARALLEL_1D group.

    Parameters whose name lacks ``'weight'`` (e.g. biases) or contains ``'ln'`` are
    left untouched. All selected parameters share a single ``TensorSpec`` object.
    """
    tp_mode = ParallelMode.PARALLEL_1D
    dim0_shard = distspec.shard(gpc.get_group(tp_mode), [0], [gpc.get_world_size(tp_mode)])
    tp_action = ParallelAction(priority=1, compute_pattern=ComputePattern.TP1D, parallel_mode=tp_mode)
    spec = TensorSpec(dim0_shard, [tp_action])
    with DistSpecManager.no_grad():
        for name, param in model.named_parameters():
            if 'weight' not in name or 'ln' in name:
                continue
            param.set_spec(spec)
|
|
|
|
|
|
def init_1d_col_spec(model):
    """Shard every non-LayerNorm weight and bias of ``model`` on its last dim across PARALLEL_1D.

    Parameters whose name contains ``'ln'``, or contains neither ``'weight'`` nor
    ``'bias'``, are left untouched. All selected parameters share one ``TensorSpec``.
    """
    tp_mode = ParallelMode.PARALLEL_1D
    last_dim_shard = distspec.shard(gpc.get_group(tp_mode), [-1], [gpc.get_world_size(tp_mode)])
    tp_action = ParallelAction(priority=1, compute_pattern=ComputePattern.TP1D, parallel_mode=tp_mode)
    spec = TensorSpec(last_dim_shard, [tp_action])
    with DistSpecManager.no_grad():
        for name, param in model.named_parameters():
            if 'ln' in name:
                continue
            if 'weight' in name or 'bias' in name:
                param.set_spec(spec)
|
|
|
|
|
|
def check_tensor_equal_1d(tensor: torch.Tensor, shard: ColoTensor):
    """Assert that ``shard`` holds exactly this rank's chunk of the full ``tensor``.

    The shard's dist spec must split along a single dimension. The expected local
    piece is ``tensor.chunk(world_size, dim)[rank]`` within the PARALLEL_1D group.
    """
    tp_world_size = gpc.get_world_size(ParallelMode.PARALLEL_1D)
    tp_rank = gpc.get_local_rank(ParallelMode.PARALLEL_1D)
    shard_dims = shard.spec.dist_spec.dims
    assert len(shard_dims) == 1
    split_dim = shard_dims[0]
    expected_local = tensor.chunk(tp_world_size, split_dim)[tp_rank]
    assert torch.equal(expected_local, shard.torch_tensor())
|
|
|
|
|
|
def tensor_shard_equal(tensor: torch.Tensor, shard: torch.Tensor):
    """Return whether ``shard`` equals ``tensor`` or this rank's 1D slice of it.

    * Same shape: compare directly with ``tensor_equal``.
    * Shapes differ in exactly one dimension: treat ``shard`` as a PARALLEL_1D shard
      along that dimension and compare it with ``tensor.chunk(world_size, dim)[rank]``.
    * Any other mismatch raises ``NotImplementedError``.
    """
    assert tensor.ndim == shard.ndim
    if tensor.shape == shard.shape:
        return tensor_equal(tensor, shard)

    differing_dims = [d for d, (full, part) in enumerate(zip(tensor.shape, shard.shape)) if full != part]
    if len(differing_dims) != 1:
        raise NotImplementedError

    # 1D shard: exactly one dimension was split.
    split_dim = differing_dims[0]
    tp_world_size = gpc.get_world_size(ParallelMode.PARALLEL_1D)
    tp_rank = gpc.get_local_rank(ParallelMode.PARALLEL_1D)
    return tensor_equal(tensor.chunk(tp_world_size, split_dim)[tp_rank], shard)
|
|
|
|
|
|
def check_param_equal(model, torch_model):
    """Assert that every parameter of ``model`` matches the corresponding ``torch_model`` one.

    Parameters are paired in ``.parameters()`` order and compared with
    ``tensor_shard_equal`` (full or 1D-sharded equality).

    Raises:
        AssertionError: if the models have different parameter counts (a plain ``zip``
            would otherwise silently skip the extras) or any pair differs.
    """
    params = list(model.parameters())
    torch_params = list(torch_model.parameters())
    assert len(params) == len(torch_params), \
        f"parameter count mismatch: {len(params)} vs {len(torch_params)}"
    for p, torch_p in zip(params, torch_params):
        assert tensor_shard_equal(torch_p, p)
|
|
|
|
|
|
def check_grad_equal(model, torch_model):
    """Assert that every gradient of ``model`` matches the corresponding ``torch_model`` one.

    Parameters are paired in ``.parameters()`` order and their ``.grad`` tensors
    compared with ``tensor_shard_equal`` (full or 1D-sharded equality).

    Raises:
        AssertionError: if the models have different parameter counts (a plain ``zip``
            would otherwise silently skip the extras) or any gradient pair differs.
    """
    params = list(model.parameters())
    torch_params = list(torch_model.parameters())
    assert len(params) == len(torch_params), \
        f"parameter count mismatch: {len(params)} vs {len(torch_params)}"
    for p, torch_p in zip(params, torch_params):
        assert tensor_shard_equal(torch_p.grad, p.grad)
|
|
|
|
|
|
def run_gpt(init_spec_func):
    """Compare a ColoTensor GPT-2 against a plain PyTorch replica with identical weights.

    Builds ``gpt2_s`` under ``ColoInitContext`` and a regular ``gpt2_s`` copy, copies
    the Colo weights into the copy, and applies ``init_spec_func`` to the Colo model.
    Then it checks parameter equality and, for each step, asserts that logits and
    gradients match.

    Args:
        init_spec_func: callable that assigns distributed specs to the Colo model's
            parameters (e.g. ``init_1d_row_spec`` / ``init_1d_col_spec``).
    """
    batch_size, seq_len, vocab_size = 4, 1024, 50304
    num_steps = 1
    criterion = GPTLMLoss()

    with ColoInitContext(device=get_current_device()):
        model = gpt2_s()
    model = model.cuda()

    # Reference model: overwrite its random init with the Colo model's weights.
    torch_model = gpt2_s().cuda()
    for ref_param, colo_param in zip(torch_model.parameters(), model.parameters()):
        ref_param.data.copy_(colo_param)

    init_spec_func(model)
    check_param_equal(model, torch_model)

    model.train()
    torch_model.train()

    for _ in range(num_steps):
        input_ids, attn_mask = get_data(batch_size, seq_len, vocab_size)

        colo_logits = model(input_ids, attn_mask)
        ref_logits = torch_model(input_ids, attn_mask)
        assert tensor_equal(ref_logits, colo_logits)

        colo_loss = criterion(colo_logits, input_ids)
        ref_loss = criterion(ref_logits, input_ids)
        colo_loss.backward()
        ref_loss.backward()

        check_grad_equal(model, torch_model)
|
|
|
|
|
|
def run_dist(rank, world_size, port):
    """Per-process worker: launch ColossalAI with 1D tensor parallelism, then run the GPT check.

    Args:
        rank: this process's global rank.
        world_size: total number of processes; also the 1D tensor-parallel size.
        port: TCP port on localhost used for the NCCL rendezvous.
    """
    tensor_cfg = dict(mode="1d", size=world_size)
    config = dict(parallel=dict(tensor=tensor_cfg))
    colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
    for spec_fn in (init_1d_row_spec, init_1d_col_spec):
        run_gpt(spec_fn)
|
|
|
|
|
|
@pytest.mark.dist
@pytest.mark.parametrize('world_size', [1, 4])
@rerun_if_address_is_in_use()
def test_gpt(world_size):
    """Spawn ``world_size`` processes running ``run_dist`` on a freshly chosen free port."""
    port = free_port()
    worker = partial(run_dist, world_size=world_size, port=port)
    mp.spawn(worker, nprocs=world_size)
|
|
|
|
|
|
# Allow running this test directly (outside pytest) as a single-process run.
if __name__ == '__main__':
    test_gpt(1)
|