2022-06-02 04:13:15 +00:00
|
|
|
import pytest
|
|
|
|
import colossalai
|
|
|
|
import torch
|
|
|
|
import torch.multiprocessing as mp
|
|
|
|
from colossalai.testing import rerun_if_address_is_in_use
|
|
|
|
from colossalai.utils.cuda import get_current_device
|
|
|
|
from colossalai.utils import free_port
|
2022-06-06 07:34:41 +00:00
|
|
|
from colossalai.utils.model.colo_init_context import ColoInitContext
|
2022-06-29 05:31:02 +00:00
|
|
|
from colossalai.gemini import ChunkManager
|
2022-06-02 04:13:15 +00:00
|
|
|
from functools import partial
|
2022-06-17 08:12:05 +00:00
|
|
|
from _utils import tensor_equal, set_seed, tensor_shard_equal
|
2022-06-02 04:13:15 +00:00
|
|
|
from tests.components_to_test.registry import non_distributed_component_funcs
|
|
|
|
from torch.nn.parallel import DistributedDataParallel as DDP
|
2022-06-21 08:35:23 +00:00
|
|
|
from colossalai.nn.parallel import ZeroDDP
|
2022-06-02 04:13:15 +00:00
|
|
|
from colossalai.nn.optimizer import HybridAdam
|
|
|
|
from colossalai.zero import ZeroOptimizer
|
|
|
|
from colossalai.testing import parameterize
|
|
|
|
from colossalai.amp import convert_to_apex_amp
|
2022-06-10 06:48:28 +00:00
|
|
|
from colossalai.gemini.gemini_mgr import GeminiManager
|
2022-07-11 07:51:48 +00:00
|
|
|
from colossalai.tensor import ColoTensorSpec, ShardSpec, ComputePattern, ComputeSpec, DistSpecManager, ProcessGroup
|
2022-06-02 04:13:15 +00:00
|
|
|
|
|
|
|
|
2022-07-04 10:54:37 +00:00
|
|
|
def check_param_equal(model, torch_model, pg: ProcessGroup):
    """Assert that each materialized parameter of ``model`` matches ``torch_model``.

    Parameters are paired positionally. A parameter of ``model`` whose storage
    is empty on this rank is skipped. Every other parameter must be fp16 and
    must equal the reference parameter's shard for this rank. The reference is
    cast to the same dtype/device first, and the shard is selected from
    ``pg.tp_local_rank()`` / ``pg.tp_world_size()``.

    Args:
        model: the model under test.
        torch_model: the reference model.
        pg: process group providing the tensor-parallel rank and world size.
    """
    for param, ref_param in zip(model.parameters(), torch_model.parameters()):
        # Nothing to compare when this rank holds no data for the parameter.
        if param.storage().size() <= 0:
            continue
        assert param.dtype == torch.half
        expected = ref_param.to(dtype=param.dtype, device=param.device)
        assert tensor_shard_equal(expected, param, pg.tp_local_rank(),
                                  pg.tp_world_size()), f'{ref_param} vs {param}'
|
2022-06-02 04:13:15 +00:00
|
|
|
|
|
|
|
|
2022-07-04 10:54:37 +00:00
|
|
|
def check_grad_equal(model, torch_model, pg: ProcessGroup):
    """Assert that the gradients of ``model`` match the reference gradients shard-wise.

    Parameters are paired positionally. A parameter of ``model`` without a
    ``.grad`` is skipped. Otherwise the reference gradient is cast to the same
    dtype/device and compared with ``tensor_shard_equal`` for this rank's
    tensor-parallel shard.

    Args:
        model: the model under test.
        torch_model: the reference model.
        pg: process group providing the tensor-parallel rank and world size.
    """
    for param, ref_param in zip(model.parameters(), torch_model.parameters()):
        grad = param.grad
        if grad is None:
            continue
        reference = ref_param.grad.to(dtype=grad.dtype, device=grad.device)
        assert tensor_shard_equal(reference, grad, pg.tp_local_rank(), pg.tp_world_size())
|
2022-06-15 07:05:19 +00:00
|
|
|
|
|
|
|
|
|
|
|
def run_fwd_bwd(model, criterion, optimizer, input_ids, attn_mask):
    """Run one forward/backward pass and return the fp32 logits.

    Steps, in order:
    1. Clear gradients with ``optimizer.zero_grad()``.
    2. Run ``model(input_ids, attn_mask)`` and cast the output to float32.
    3. Compute ``criterion(logits, input_ids)``.
    4. Call ``optimizer.backward(loss)``. The optimizer, not the loss, drives
       backward here, so the optimizer wrapper can apply its own loss scaling
       (presumably).

    No optimizer step is taken.
    """
    optimizer.zero_grad()
    float_logits = model(input_ids, attn_mask).float()
    optimizer.backward(criterion(float_logits, input_ids))
    return float_logits
|
|
|
|
|
|
|
|
|
2022-07-04 10:54:37 +00:00
|
|
|
def init_1d_row_spec(model, pg: ProcessGroup):
    """Install a 1D tensor-parallel spec sharding dimension 0 on weight parameters.

    Targets every parameter whose name contains ``'weight'`` but not ``'ln'``
    (presumably layer-norm parameters are excluded). Each target gets
    ``ShardSpec([0], [pg.tp_world_size()])`` together with
    ``ComputeSpec(ComputePattern.TP1D)``. Biases are left untouched. All spec
    changes happen inside ``DistSpecManager.no_grad()``.
    """
    shard_spec = ShardSpec([0], [pg.tp_world_size()])
    compute_spec = ComputeSpec(ComputePattern.TP1D)
    with DistSpecManager.no_grad():
        for name, param in model.named_parameters():
            if 'ln' in name or 'weight' not in name:
                continue
            param.set_tensor_spec(shard_spec, compute_spec)
|
2022-06-17 08:12:05 +00:00
|
|
|
|
|
|
|
|
2022-07-04 10:54:37 +00:00
|
|
|
def init_1d_col_spec(model, pg: ProcessGroup):
    """Install a 1D tensor-parallel spec sharding the last dimension of weights and biases.

    Targets every parameter whose name contains ``'weight'`` or ``'bias'`` and
    does not contain ``'ln'`` (presumably layer-norm parameters are excluded).
    Each target gets ``ShardSpec([-1], [pg.tp_world_size()])`` together with
    ``ComputeSpec(ComputePattern.TP1D)``. All spec changes happen inside
    ``DistSpecManager.no_grad()``.
    """
    shard_spec = ShardSpec([-1], [pg.tp_world_size()])
    compute_spec = ComputeSpec(ComputePattern.TP1D)
    with DistSpecManager.no_grad():
        for name, param in model.named_parameters():
            if 'ln' in name:
                continue
            if 'weight' not in name and 'bias' not in name:
                continue
            param.set_tensor_spec(shard_spec, compute_spec)
|
2022-06-17 08:12:05 +00:00
|
|
|
|
|
|
|
|
2022-06-02 04:13:15 +00:00
|
|
|
@parameterize('use_chunk', [False, True])
@parameterize('use_zero', [False, True])
@parameterize('placement_policy', ['cuda', 'cpu'])
def run_gpt(use_chunk, use_zero, placement_policy, tp_init_spec_func=None):
    """Train GPT-2 for a few steps under ZeroDDP and compare against a reference.

    Model under test: built inside ``ColoInitContext``, cast to fp16, wrapped
    in ``ZeroDDP`` and optimized by ``ZeroOptimizer(HybridAdam)``.

    Reference: a second instance of the same model whose weights are copied
    from the first. It is converted with apex AMP (``opt_level='O2'``) and
    wrapped in torch ``DistributedDataParallel`` over the data-parallel group.

    Comparisons: parameters are checked once before training. Then, for each
    of the first three batches, logits, gradients and post-step parameters are
    checked.

    The ``parameterize`` decorators presumably invoke this function once per
    combination of the listed values. Confirm against
    ``colossalai.testing.parameterize``.

    Args:
        use_chunk: if true, ``ChunkManager.search_chunk_size`` picks a chunk
            size; otherwise ``None`` is passed as the chunk size.
        use_zero: forwarded as ``enable_distributed_storage`` to ``ChunkManager``.
        placement_policy: Gemini placement policy. It is passed to
            ``GeminiManager`` and also selects the chunk manager's init device.
        tp_init_spec_func: optional callable ``(model, pg)`` applied to the
            model before it is wrapped, e.g. ``init_1d_row_spec``.
    """
    set_seed(42)
    get_components_func = non_distributed_component_funcs.get_callable('gpt2')
    model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func()

    # Only the model under test is created inside ColoInitContext; the
    # reference model below is built as a plain torch module.
    with ColoInitContext(device=get_current_device()):
        model = model_builder()
    model = model.cuda().half()
    torch_model = model_builder().cuda()
    # Start the reference from exactly the same weights as the model under test.
    for torch_p, p in zip(torch_model.parameters(), model.parameters()):
        torch_p.data.copy_(p)

    world_size = torch.distributed.get_world_size()

    # With 4 ranks use tp=2 (leaving dp=2) to build a hybrid-parallel setup;
    # otherwise every rank goes into the tensor-parallel group.
    if world_size == 4:
        pg = ProcessGroup(tp_degree=2)
    else:
        pg = ProcessGroup(tp_degree=world_size)

    if tp_init_spec_func:
        # Runs after the weight copy above, so the reference presumably keeps
        # the full tensors while `model` gets the tensor-parallel specs.
        tp_init_spec_func(model, pg)

    # NOTE(review): meaning of the (8192, 8) search arguments is not visible
    # here — confirm against ChunkManager.search_chunk_size.
    chunk_size = ChunkManager.search_chunk_size(model, 8192, 8) if use_chunk else None
    chunk_manager = ChunkManager(chunk_size,
                                 enable_distributed_storage=use_zero,
                                 init_device=GeminiManager.get_default_device(placement_policy))
    gemini_manager = GeminiManager(placement_policy, chunk_manager)
    model = ZeroDDP(model, gemini_manager, pg)
    optim = HybridAdam(model.parameters(), lr=1e-3)
    # initial_scale=32 matches the apex loss_scale=32 below, so both sides start
    # from the same loss scale.
    optim = ZeroOptimizer(optim, model, initial_scale=32)

    amp_config = dict(opt_level='O2', keep_batchnorm_fp32=False, loss_scale=32)
    torch_optim = torch.optim.Adam(torch_model.parameters(), lr=1e-3)
    torch_model, torch_optim = convert_to_apex_amp(torch_model, torch_optim, amp_config)
    # The reference synchronizes gradients only over the data-parallel sub-group.
    # NOTE(review): device_ids uses pg.rank(); this assumes that rank equals the
    # local CUDA device index (single-node launch) — confirm.
    torch_model = DDP(torch_model, device_ids=[pg.rank()], process_group=pg.dp_process_group())

    check_param_equal(model, torch_model, pg)
    model.train()
    torch_model.train()
    # Seed by data-parallel rank. Presumably this lets DP replicas draw different
    # batches while TP peers share one — confirm against the dataloader.
    set_seed(pg.dp_local_rank())
    for i, (input_ids, attn_mask) in enumerate(train_dataloader):
        # Only the first three batches are used.
        if i > 2:
            break

        logits = run_fwd_bwd(model, criterion, optim, input_ids, attn_mask)
        torch_logits = run_fwd_bwd(torch_model, criterion, torch_optim, input_ids, attn_mask)
        assert tensor_equal(logits, torch_logits)
        check_grad_equal(model, torch_model, pg)
        optim.step()
        torch_optim.step()
        check_param_equal(model, torch_model, pg)
|
2022-06-02 04:13:15 +00:00
|
|
|
|
|
|
|
|
|
|
|
def run_dist(rank, world_size, port):
    """Per-process entry point: launch ColossalAI on NCCL, then run the GPT checks.

    With ``world_size == 4``, ``run_gpt`` is executed twice: first with
    ``init_1d_col_spec``, then with ``init_1d_row_spec``. Any other world size
    runs ``run_gpt`` once without a tensor-parallel spec initializer.
    """
    colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
    if world_size != 4:
        run_gpt()
        return
    for spec_initializer in (init_1d_col_spec, init_1d_row_spec):
        run_gpt(tp_init_spec_func=spec_initializer)
|
2022-06-02 04:13:15 +00:00
|
|
|
|
|
|
|
|
|
|
|
@pytest.mark.dist
@pytest.mark.skip("under development")
@pytest.mark.parametrize('world_size', [1, 4])
@rerun_if_address_is_in_use()
def test_gpt(world_size):
    """Spawn ``world_size`` processes, each running ``run_dist`` on one shared free port.

    Currently skipped under pytest ("under development").
    """
    shared_port = free_port()
    worker = partial(run_dist, world_size=world_size, port=shared_port)
    mp.spawn(worker, nprocs=world_size)
|
|
|
|
|
|
|
|
|
|
|
|
if __name__ == '__main__':
    # Run the 4-process configuration directly; the pytest skip mark does not
    # apply when the file is executed as a script.
    test_gpt(world_size=4)
|