"""Compare a GPT-2 model boosted with HybridParallelPlugin against the plain model."""
import copy
from contextlib import nullcontext

import pytest
import torch
from torch import distributed as dist
from torch.optim import Adam

import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import HybridParallelPlugin
from colossalai.lazy.lazy_init import LazyInitContext
from colossalai.logging import disable_existing_loggers
from colossalai.tensor.d_tensor.api import (
    clear_layout_converter,
    is_customized_distributed_tensor,
    is_distributed_tensor,
)
from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn
from tests.kit.model_zoo import model_zoo
from tests.test_shardformer.test_model._utils import build_model, check_grad, check_state_dict, run_forward


def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config):
    """Run forward/backward on an original and a boosted model and compare them.

    The original model and a deep copy boosted through ``HybridParallelPlugin``
    are fed the same data. The function asserts that the loss (and, for
    ``GPT2Model``, the last hidden state), the gradients of the first block's
    MLP layers, and the ``c_fc`` weight after one optimizer step agree.

    Args:
        model_fn: Zero-argument callable returning the model to test.
        data_gen_fn: Zero-argument callable returning a dict of model inputs.
        output_transform_fn: Callable applied to the pipeline outputs before
            the loss is computed.
        loss_fn: Callable mapping model outputs to a loss tensor.
        test_config: Keyword arguments for ``HybridParallelPlugin``, optionally
            with a ``'use_lazy_init'`` flag. The dict is not modified.

    Raises:
        AssertionError: If any compared quantity differs beyond tolerance.
    """
    # Work on a private copy: the caller reuses the same dict for every model,
    # so popping from it directly would drop 'use_lazy_init' after the first call.
    plugin_config = dict(test_config)
    use_lazy_init = plugin_config.pop('use_lazy_init', False)
    ctx = LazyInitContext() if use_lazy_init else nullcontext()

    # prepare booster
    plugin = HybridParallelPlugin(**plugin_config)
    booster = Booster(plugin=plugin)
    stage_manager = plugin.stage_manager

    # prepare models and optimizers
    with ctx:
        org_model = model_fn().cuda()
        sharded_model = copy.deepcopy(org_model)

    if use_lazy_init:
        # Only the reference model is materialized here; the copy is handed to
        # the booster as-is.
        org_model = ctx.materialize(org_model)

    org_optimizer = Adam(org_model.parameters(), lr=1e-3)
    sharded_optimizer = Adam(sharded_model.parameters(), lr=1e-3)
    criterion = loss_fn

    sharded_model, sharded_optimizer, criterion, _, _ = booster.boost(sharded_model, sharded_optimizer, criterion)

    def _criterion(outputs, inputs):
        # `inputs` is unused but kept in the signature passed to execute_pipeline.
        outputs = output_transform_fn(outputs)
        loss = criterion(outputs)
        return loss

    # do forward and backward
    data = data_gen_fn()
    sharded_model.train()
    if stage_manager:
        # Tensor-like inputs are moved to GPU and tiled 4x along dim 0
        # (presumably to match num_microbatches=4 in the pipeline configs).
        data = {
            k: v.to('cuda').repeat(4, 1) if torch.is_tensor(v) or 'Tensor' in v.__class__.__name__ else v
            for k, v in data.items()
        }
        data_iter = iter([data])
        sharded_output = booster.execute_pipeline(data_iter,
                                                  sharded_model,
                                                  _criterion,
                                                  sharded_optimizer,
                                                  return_loss=True,
                                                  return_outputs=True)
        sharded_loss = sharded_output['loss']
    else:
        data = {k: v.cuda() for k, v in data.items()}
        sharded_output = sharded_model(**data)
        sharded_loss = criterion(sharded_output)
        sharded_loss.backward()

    org_model.train()
    org_output = org_model(**data)
    org_loss = criterion(org_output)
    org_loss.backward()

    if stage_manager is None or stage_manager.is_last_stage():
        # check last hidden state
        if org_model.__class__.__name__ == 'GPT2Model':
            org_hidden_state = org_output.last_hidden_state

            if stage_manager is None:
                sharded_hidden_state = sharded_output.last_hidden_state
            else:
                # Pipeline run on the last stage: stitch the per-output hidden
                # states back together along dim 0.
                sharded_hidden_state = torch.cat([output.last_hidden_state for output in sharded_output['outputs']],
                                                 dim=0)

            assert torch.allclose(org_hidden_state, sharded_hidden_state, atol=1e-5, rtol=1e-3), \
                f"shard model's output hidden state is not equal to origin model's last hidden state\n{org_hidden_state}\n{sharded_hidden_state}"

        # check loss
        assert torch.allclose(org_loss, sharded_loss, atol=1e-5, rtol=1e-3), \
            f"shard model loss is not equal to origin model loss\n{org_loss}\n{sharded_loss}"

    # unwrap model
    if org_model.__class__.__name__ == 'GPT2Model':
        gpt2 = org_model
        sharded_gpt2 = sharded_model.unwrap()
    else:
        gpt2 = org_model.transformer
        sharded_gpt2 = sharded_model.unwrap().transformer

    # check grad
    col_layer_for_check = ['h[0].mlp.c_fc']
    row_layer_for_check = ['h[0].mlp.c_proj']
    check_grad(gpt2, sharded_gpt2, col_layer_for_check, atol=1e-6, rtol=1e-3, dim=1, verbose=False)
    check_grad(gpt2, sharded_gpt2, row_layer_for_check, atol=1e-6, rtol=1e-3, dim=0, verbose=False)

    # check weights after optimizer.step()
    org_optimizer.step()
    sharded_optimizer.step()
    if stage_manager is None or stage_manager.is_first_stage():
        # Go through the resolved GPT-2 backbones: `h` does not live on the
        # top-level object when the model is not a bare GPT2Model, nor on the
        # boosted wrapper.
        org_weight = gpt2.h[0].mlp.c_fc.weight
        shard_weight = sharded_gpt2.h[0].mlp.c_fc.weight

        if is_distributed_tensor(shard_weight) or is_customized_distributed_tensor(shard_weight):
            # Gather every tensor-parallel shard and rebuild the full weight
            # by concatenating along dim 1.
            shard_weight_list = [torch.zeros([*shard_weight.shape]).to('cuda') for _ in range(plugin.tp_size)]
            dist.all_gather(shard_weight_list, shard_weight, plugin.tp_group)
            shard_weight = torch.cat(shard_weight_list, dim=1)

        assert torch.allclose(org_weight, shard_weight, atol=5e-3, rtol=1e-3), \
            f"shard model weight is not equal to origin model weight\n{org_weight}\n{shard_weight}"

    torch.cuda.empty_cache()


@parameterize('test_config', [{
    'tp_size': 1,
    'pp_size': 2,
    'num_microbatches': 4,
    'use_lazy_init': True
}, {
    'tp_size': 2,
    'pp_size': 2,
    'num_microbatches': 4,
    'enable_fused_normalization': False,
    'use_lazy_init': False
}, {
    'tp_size': 4,
    'pp_size': 1,
    'enable_fused_normalization': True,
    'use_lazy_init': False
}])
@clear_cache_before_run()
def run_gpt2_test(test_config):
    """Run ``check_forward_backward`` on every 'transformers_gpt' model for one config.

    Args:
        test_config: Plugin configuration supplied by ``parameterize``.
    """
    # TODO: add plugin_config for TP+DP after supporting & debugging it
    # {'tp_size': 2, 'pp_size': 1, 'enable_fused_normalization': True}

    sub_model_zoo = model_zoo.get_sub_registry('transformers_gpt')

    # Do not use fp16/bf16 in testing. A fresh dict is built so the
    # parameterized config object itself is left untouched.
    config = dict(test_config, precision='float')

    for _name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
        check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, config)

    clear_layout_converter()
    torch.cuda.empty_cache()


def check_gpt2(rank, world_size, port):
    """Launch colossalai on localhost with the nccl backend and run the GPT-2 test.

    Args:
        rank: Rank forwarded to ``colossalai.launch``.
        world_size: World size forwarded to ``colossalai.launch``.
        port: Port forwarded to ``colossalai.launch``.
    """
    disable_existing_loggers()
    colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
    run_gpt2_test()


@pytest.mark.skip('Have some bug caused by merge')
@pytest.mark.dist
@rerun_if_address_is_in_use()
@clear_cache_before_run()
def test_gpt2():
    """Pytest entry point: hand ``check_gpt2`` to ``spawn`` with 4 (presumably the process count)."""
    spawn(check_gpt2, 4)


if __name__ == "__main__":
    test_gpt2()