2022-10-18 08:31:22 +00:00
|
|
|
from functools import partial
|
|
|
|
|
2022-06-02 04:13:15 +00:00
|
|
|
import pytest
|
|
|
|
import torch
|
|
|
|
import torch.multiprocessing as mp
|
|
|
|
from torch.nn.parallel import DistributedDataParallel as DDP
|
2022-10-18 08:31:22 +00:00
|
|
|
|
|
|
|
import colossalai
|
2022-06-02 04:13:15 +00:00
|
|
|
from colossalai.amp import convert_to_apex_amp
|
2022-11-16 06:44:28 +00:00
|
|
|
from colossalai.gemini.chunk import search_chunk_configuration
|
|
|
|
from colossalai.nn.optimizer.gemini_optimizer import GeminiAdamOptimizer
|
|
|
|
from colossalai.nn.parallel import GeminiDDP, ZeroDDP
|
2022-10-18 08:31:22 +00:00
|
|
|
from colossalai.tensor import ColoTensor, ColoTensorSpec, ComputePattern, ComputeSpec, ProcessGroup, ShardSpec
|
|
|
|
from colossalai.testing import parameterize, rerun_if_address_is_in_use
|
|
|
|
from colossalai.utils import free_port
|
|
|
|
from colossalai.utils.cuda import get_current_device
|
|
|
|
from colossalai.utils.model.colo_init_context import ColoInitContext
|
|
|
|
from tests.components_to_test.registry import non_distributed_component_funcs
|
2022-11-14 08:05:09 +00:00
|
|
|
from tests.test_tensor.common_utils import set_seed, tensor_shard_equal
|
2022-07-25 03:18:08 +00:00
|
|
|
from tests.test_tensor.model.test_gpt2 import init_megatron_spec
|
2022-06-02 04:13:15 +00:00
|
|
|
|
|
|
|
|
2022-10-09 01:18:51 +00:00
|
|
|
def check_param(model: ZeroDDP, torch_model: torch.nn.Module, pg: ProcessGroup):
    """Assert that every parameter of ``torch_model`` matches the ZeRO model.

    ``torch_model`` is expected to be wrapped in DistributedDataParallel, so its
    state-dict keys carry a ``'module.'`` prefix that the ZeRO state dict does
    not. The prefix is checked and stripped before lookup. Each reference tensor
    is moved/cast to match, then compared with ``tensor_shard_equal``. That
    comparison is given this rank's tensor-parallel rank and world size from
    ``pg`` (presumably so a TP-sharded ZeRO tensor can be compared with the full
    reference tensor).

    Raises:
        AssertionError: if a reference key lacks the DDP prefix, is missing from
            the ZeRO state dict, or its value does not match.
    """
    ddp_prefix = 'module.'
    zero_dict = model.state_dict(only_rank_0=False)
    torch_dict = torch_model.state_dict()

    for key, value in torch_dict.items():
        # Keys look like 'module.model.PARAMETER'. Strip the DDP wrapper prefix
        # explicitly rather than slicing a magic 7 characters, so a reference
        # model that is not DDP-wrapped fails with a clear message.
        assert key.startswith(ddp_prefix), "{} does not start with '{}'.".format(key, ddp_prefix)
        key = key[len(ddp_prefix):]
        assert key in zero_dict, "{} not in ZeRO dictionary.".format(key)
        temp_zero_value = zero_dict[key].to(device=value.device, dtype=value.dtype)
        assert tensor_shard_equal(value, temp_zero_value, pg.tp_local_rank(), pg.tp_world_size()), \
            "parameter '{}' has problem.".format(key)
|
2022-06-15 07:05:19 +00:00
|
|
|
|
|
|
|
|
2022-11-24 08:51:45 +00:00
|
|
|
def run_fwd_bwd(model, criterion, optimizer, input_ids):
    """Run one forward/backward step and return the logits cast to float.

    Gradients are cleared via ``optimizer.zero_grad()`` first. ``input_ids``
    doubles as the target passed to ``criterion``. The backward pass goes
    through ``optimizer.backward(loss)`` rather than ``loss.backward()``.
    """
    optimizer.zero_grad()
    fp32_logits = model(input_ids).float()
    optimizer.backward(criterion(fp32_logits, input_ids))
    return fp32_logits
|
|
|
|
|
|
|
|
|
2022-07-04 10:54:37 +00:00
|
|
|
def init_1d_row_spec(model, pg: ProcessGroup):
    """Attach ``pg`` to every parameter and give weights a 1D-TP shard spec on dim 0.

    All parameters get ``pg`` as their process group. Parameters whose name
    contains 'weight' but not 'ln' (presumably excluding LayerNorm) receive
    ``ShardSpec([0], [tp_world_size])`` with the ``TP1D`` compute pattern.
    Everything else keeps its existing spec.
    """
    shard_spec = ShardSpec([0], [pg.tp_world_size()])
    compute_spec = ComputeSpec(ComputePattern.TP1D)
    for name, param in model.named_parameters():
        param.set_process_group(pg)
        if 'ln' in name or 'weight' not in name:
            continue
        param.set_tensor_spec(shard_spec, compute_spec)
|
2022-06-17 08:12:05 +00:00
|
|
|
|
|
|
|
|
2022-07-04 10:54:37 +00:00
|
|
|
def init_1d_col_spec(model, pg: ProcessGroup):
    """Attach ``pg`` to every parameter and give weights/biases a 1D-TP shard spec on the last dim.

    All parameters get ``pg`` as their process group. Parameters whose name
    contains 'weight' or 'bias' and does not contain 'ln' (presumably
    excluding LayerNorm) receive ``ShardSpec([-1], [tp_world_size])`` with
    the ``TP1D`` compute pattern.
    """
    shard_spec = ShardSpec([-1], [pg.tp_world_size()])
    compute_spec = ComputeSpec(ComputePattern.TP1D)
    for name, param in model.named_parameters():
        param.set_process_group(pg)
        is_sharded = 'ln' not in name and any(tag in name for tag in ('weight', 'bias'))
        if is_sharded:
            param.set_tensor_spec(shard_spec, compute_spec)
|
2022-06-17 08:12:05 +00:00
|
|
|
|
|
|
|
|
2022-07-18 06:14:52 +00:00
|
|
|
@parameterize('placement_policy', ['cuda', 'cpu'])
def run_gpt(placement_policy, tp_init_spec_func=None):
    """Train GPT-2 under GeminiDDP and check it against an apex-AMP + DDP reference.

    Both models start from identical weights. For up to three batches they run
    forward/backward, their logits are compared with ``torch.allclose``, both
    optimizers step, and ``check_param`` verifies that the parameters still
    agree.

    Args:
        placement_policy: Gemini placement policy ('cuda' or 'cpu' via the
            ``parameterize`` decorator).
        tp_init_spec_func: optional ``f(model, pg)`` that assigns
            tensor-parallel specs to the Colo model before wrapping.
    """
    # Statement order below is deliberate: seeding precedes model construction.
    set_seed(42)
    model_builder, train_dataloader, _test_dataloader, _optimizer_class, criterion = \
        non_distributed_component_funcs.get_callable('gpt2')()

    # Colo model built under ColoInitContext; the reference copy receives its weights.
    with ColoInitContext(device=get_current_device()):
        model = model_builder()
    model = model.cuda()
    torch_model = model_builder().cuda()
    for ref_param, colo_param in zip(torch_model.parameters(), model.parameters()):
        ref_param.data.copy_(colo_param.data)

    # With 4 ranks build a hybrid layout (tp=2, dp=2); otherwise use TP across all ranks.
    world_size = torch.distributed.get_world_size()
    pg = ProcessGroup(tp_degree=2 if world_size == 4 else world_size)

    if tp_init_spec_func:
        tp_init_spec_func(model, pg)

    # NOTE: config_dict is not passed to GeminiDDP below; it only feeds the
    # commented-out manual ChunkManager construction.
    dp_world_size = pg.dp_world_size()
    config_dict, _ = search_chunk_configuration(model, search_range_mb=1, search_interval_byte=100)
    dp_chunk_config = config_dict[dp_world_size]
    dp_chunk_config['chunk_size'] = 5000
    dp_chunk_config['keep_gathered'] = False
    init_device = None if placement_policy == 'cuda' else torch.device('cpu')

    model = GeminiDDP(model, init_device, placement_policy, True, False, 32)
    # The same as the following 3 lines
    # chunk_manager = ChunkManager(config_dict, init_device=init_device)
    # gemini_manager = GeminiManager(placement_policy, chunk_manager)
    # model = ZeroDDP(model, gemini_manager, pin_memory=True)

    zero_optim = GeminiAdamOptimizer(model, lr=1e-3, initial_scale=1)
    # The same as the following 2 lines
    # optimizer = HybridAdam(model.parameters(), lr=1e-3)
    # zero_optim = ZeroOptimizer(optimizer, model, initial_scale=1)

    # Reference: Adam + apex AMP O2, data-parallel over the DP sub-group.
    amp_config = {'opt_level': 'O2', 'keep_batchnorm_fp32': False, 'loss_scale': 1}
    torch_optim = torch.optim.Adam(torch_model.parameters(), lr=1e-3)
    torch_model, torch_optim = convert_to_apex_amp(torch_model, torch_optim, amp_config)
    torch_model = DDP(torch_model, device_ids=[pg.rank()], process_group=pg.dp_process_group())

    check_param(model, torch_model, pg)

    model.eval()
    torch_model.eval()

    # Seed per DP rank before iterating the dataloader.
    set_seed(pg.dp_local_rank())
    for step, (input_ids, _label) in enumerate(train_dataloader):
        if step > 2:
            break
        colo_input_ids = ColoTensor.from_torch_tensor(input_ids, ColoTensorSpec(pg))
        zero_logits = run_fwd_bwd(model, criterion, zero_optim, colo_input_ids)
        torch_logits = run_fwd_bwd(torch_model, criterion, torch_optim, input_ids)
        assert torch.allclose(zero_logits, torch_logits, rtol=1e-3, atol=1e-2)

        zero_optim.step()
        torch_optim.step()
        check_param(model, torch_model, pg)
|
2022-06-02 04:13:15 +00:00
|
|
|
|
|
|
|
|
|
|
|
def run_dist(rank, world_size, port):
    """Per-process entry point: launch ColossalAI over NCCL, then run ``run_gpt``.

    With 4 ranks, ``init_megatron_spec`` is used. Otherwise ``run_gpt`` runs
    twice, first with column-sharded and then with row-sharded 1D specs.
    """
    colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
    if world_size == 4:
        spec_funcs = (init_megatron_spec,)
    else:
        spec_funcs = (init_1d_col_spec, init_1d_row_spec)
    for spec_func in spec_funcs:
        run_gpt(tp_init_spec_func=spec_func)
|
2022-06-02 04:13:15 +00:00
|
|
|
|
|
|
|
|
|
|
|
@pytest.mark.dist
@pytest.mark.parametrize('world_size', [1, 4])
@rerun_if_address_is_in_use()
def test_gpt(world_size):
    """Spawn ``world_size`` processes, each running ``run_dist`` on a shared free port."""
    port = free_port()
    worker = partial(run_dist, world_size=world_size, port=port)
    mp.spawn(worker, nprocs=world_size)
|
|
|
|
|
|
|
|
|
|
|
|
if __name__ == '__main__':
    # Run the 4-process configuration directly, outside of pytest.
    test_gpt(4)
|