diff --git a/colossalai/booster/plugin/hybrid_parallel_plugin.py b/colossalai/booster/plugin/hybrid_parallel_plugin.py index 8ee1e97c6..e1593cf6b 100644 --- a/colossalai/booster/plugin/hybrid_parallel_plugin.py +++ b/colossalai/booster/plugin/hybrid_parallel_plugin.py @@ -165,7 +165,6 @@ class HybridParallelModule(ModelWrapper, AMPModelMixin): Returns: None """ - if self.tp_group.size() > 1 and self.shard_config.enable_sequence_parallelism: if grads is not None: # Synchronize provided gradient tensors across the tensor parallelism group. @@ -487,7 +486,6 @@ class HybridParallelAMPOptimizer(MixedPrecisionOptimizer): Returns: None """ - # Call the superclass backward method to compute gradients. super().backward(loss, *args, **kwargs) @@ -513,7 +511,6 @@ class HybridParallelAMPOptimizer(MixedPrecisionOptimizer): Returns: None """ - # Call the superclass backward method to compute gradients. super().backward_by_grad(tensor, grad) @@ -674,7 +671,6 @@ class HybridParallelZeroOptimizer(LowLevelZeroOptimizer): Returns: None """ - # Call the superclass `_sync_grad` method to synchronize gradients. super()._sync_grad() @@ -1081,7 +1077,7 @@ class HybridParallelPlugin(PipelinePluginBase): return True def support_no_sync(self) -> bool: - return False + return True def control_checkpoint_io(self) -> bool: return True @@ -1175,9 +1171,14 @@ class HybridParallelPlugin(PipelinePluginBase): model, data_iter, criterion, optimizer, return_loss, return_outputs ) + # run with gradients accumulation + if model.require_grad_sync == False or ( + isinstance(optimizer, HybridParallelZeroOptimizer) and optimizer.require_grad_sync == False + ): + return outputs + # Synchronize the grads of shared parameters of the model. model.sync_shared_params() - # Synchronize sequence parallelism gradients of the model. model.sync_sp_grads() @@ -1241,5 +1242,8 @@ class HybridParallelPlugin(PipelinePluginBase): def get_checkpoint_io(self) -> CheckpointIO: return HybridParallelCheckpointIO(self.dp_group, self.pp_group, self.tp_group, self.zero_stage) - def no_sync(self, model: Module) -> Iterator[None]: - raise NotImplementedError + def no_sync(self, model: Module, optimizer: OptimizerWrapper) -> Iterator[None]: + assert ( + self.zero_stage != 2 + ), "ZERO2 is not compatible with no_sync function, please run gradient accumulation with gradient synchronization allowed." + return optimizer.no_sync() if isinstance(optimizer, HybridParallelZeroOptimizer) else model.no_sync() diff --git a/tests/test_booster/test_plugin/test_3d_plugin.py b/tests/test_booster/test_plugin/test_3d_plugin.py index e724d7359..6f2fc104f 100644 --- a/tests/test_booster/test_plugin/test_3d_plugin.py +++ b/tests/test_booster/test_plugin/test_3d_plugin.py @@ -1,8 +1,11 @@ +import copy from contextlib import nullcontext from typing import Optional import torch import torch.distributed as dist +from torch.testing import assert_close +from torch.utils.data import Dataset import colossalai from colossalai.booster import Booster @@ -11,9 +14,33 @@ from colossalai.fx import is_compatible_with_meta from colossalai.lazy.lazy_init import LazyInitContext from colossalai.nn.optimizer import HybridAdam from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn +from colossalai.utils import get_current_device, set_seed from tests.kit.model_zoo import model_zoo +class RandomDataset(Dataset): + def __init__(self, num_samples: int = 100, max_length: int = 512, vocab_size: int = 32000): + self.num_samples = num_samples + self.max_length = max_length + set_seed(42) + self.input_ids = torch.randint(0, vocab_size, (num_samples, max_length), device=get_current_device()) + self.attention_mask = torch.ones_like(self.input_ids) + + def __len__(self): + return self.num_samples + + def __getitem__(self, idx): + return { + "input_ids": self.input_ids[idx], + "attention_mask": self.attention_mask[idx], + "labels": self.input_ids[idx], + } + + +def move_to_cuda(batch): + return {k: v.cuda() for k, v in batch.items()} + + @clear_cache_before_run() def run_fn(init_method, model_fn, data_gen_fn, output_transform_fn) -> Optional[str]: try: @@ -85,10 +112,145 @@ def check_3d_plugin(init_method: str = "none", early_stop: bool = True): assert len(failed_info) == 0, "\n".join([f"{k}: {v}" for k, v in failed_info.items()]) +@parameterize( + "test_args", + [ + { + "batch_size": 8, + "num_steps": 4, + "tp": 2, + "pp": 2, + "pp_style": "1f1b", + "num_model_chunks": 1, + "num_microbatches": 4, + "zero": 0, + "precision": "fp16", + "initial_scale": 1, + "max_length": 512, + "gradient_accumulation_step": 2, + }, + { + "batch_size": 8, + "num_steps": 4, + "tp": 1, + "pp": 2, + "pp_style": "1f1b", + "num_model_chunks": 1, + "num_microbatches": 4, + "zero": 1, + "precision": "fp16", + "initial_scale": 1, + "max_length": 512, + "gradient_accumulation_step": 2, + }, + { + "batch_size": 1, + "num_steps": 4, + "tp": 2, + "pp": 1, + "pp_style": "1f1b", + "num_model_chunks": 1, + "num_microbatches": 1, + "zero": 2, + "precision": "fp16", + "initial_scale": 1, + "max_length": 512, + "gradient_accumulation_step": 2, + }, + { + "batch_size": 1, + "num_steps": 4, + "tp": 2, + "pp": 1, + "pp_style": "1f1b", + "num_model_chunks": 1, + "num_microbatches": 1, + "zero": 0, + "precision": "fp16", + "initial_scale": 1, + "max_length": 512, + "gradient_accumulation_step": 2, + }, + ], +) +def run_grad_acc_test(test_args): + model_fn, *_ = next(iter(model_zoo.get_sub_registry("transformers_gpt_lm").values())) + model = model_fn() + optimizer = HybridAdam(model.parameters()) + origin_model = copy.deepcopy(model).cuda() + origin_optimizer = HybridAdam(origin_model.parameters()) + + plugin = HybridParallelPlugin( + tp_size=test_args["tp"], + pp_size=test_args["pp"], + pp_style=test_args["pp_style"], + zero_stage=test_args["zero"], + num_model_chunks=test_args["num_model_chunks"], + enable_fused_normalization=True, + num_microbatches=test_args["num_microbatches"], + precision=test_args["precision"], + ) + booster = Booster(plugin=plugin) + + dataset = RandomDataset( + num_samples=test_args["batch_size"] * test_args["num_steps"] * plugin.dp_size, + max_length=test_args["max_length"], + vocab_size=model.config.vocab_size, + ) + dataloader = plugin.prepare_dataloader(dataset, batch_size=test_args["batch_size"], shuffle=True, drop_last=True) + + model, optimizer, _, dataloader, _ = booster.boost(model, optimizer, dataloader=dataloader) + + grad_accu_step = test_args["gradient_accumulation_step"] + for step, batch in enumerate(dataloader): + batch = move_to_cuda(batch) + # train origin model + origin_output = origin_model(**batch) + origin_loss = origin_output[0] / grad_accu_step + origin_loss.backward() + + if (step + 1) % grad_accu_step != 0 and test_args["zero"] != 2: + ctx = booster.no_sync(model, optimizer) + else: + ctx = nullcontext() + + with ctx: + if plugin.stage_manager is not None: + batch = iter([batch]) + booster.execute_pipeline( + batch, + model, + criterion=lambda outputs, inputs: outputs[0] / grad_accu_step, + optimizer=optimizer, + return_loss=False, + ) + else: + outputs = model(**batch) + loss = outputs[0] / grad_accu_step + booster.backward(loss, optimizer) + + if (step + 1) % grad_accu_step == 0: + # update origin model weight + origin_optimizer.step() + origin_optimizer.zero_grad() + + # update sharded model + optimizer.step() + optimizer.zero_grad() + + # tricky code here, shard the origin model inorder to check the parameters in the same stage. + origin_model, origin_optimizer, _, dataloader, _ = booster.boost( + origin_model, origin_optimizer, dataloader=dataloader + ) + for p1, p2 in zip(model.unwrap().parameters(), origin_model.unwrap().parameters()): + assert_close(p1.to(p2.dtype), p2, atol=1e-2, rtol=1e-2) + + def run_dist(rank, world_size, port, early_stop: bool = True): # init dist env colossalai.launch(config=dict(), rank=rank, world_size=world_size, port=port, host="localhost") check_3d_plugin(early_stop=early_stop) + run_grad_acc_test() @rerun_if_address_is_in_use()