diff --git a/colossalai/engine/schedule/_no_pipeline.py b/colossalai/engine/schedule/_no_pipeline.py
index bbb8a7589..899f03ab4 100644
--- a/colossalai/engine/schedule/_no_pipeline.py
+++ b/colossalai/engine/schedule/_no_pipeline.py
@@ -15,7 +15,7 @@ from colossalai.core import global_context as gpc
 from colossalai.nn import (ZeroRedundancyOptimizer_Level_2,
                            ZeroRedundancyOptimizer_Level_3)
 from colossalai.nn.optimizer._utils import clip_grad_norm_fp32
-from ._utils import convert_to_fp16
+from ._utils import convert_to_fp16, convert_to_fp32
 from ._base_schedule import BaseSchedule
 from ..amp import AMP_TYPE, GradScaler
@@ -43,12 +43,6 @@ class NoPipelineSchedule(BaseSchedule):
         assert amp_type is None or isinstance(amp_type, AMP_TYPE), \
             'unrecognised value for argument fp16, it can only be None, torch or apex'

-        # LSG: check compatibility
-        # LSG: torch.cuda.amp and apex.amp cannot be used for tensor parallel
-        if gpc.is_initialized(ParallelMode.TENSOR) and gpc.get_world_size(
-                ParallelMode.TENSOR) > 1:
-            assert amp_type != AMP_TYPE.TORCH and amp_type != AMP_TYPE.APEX, \
-                'You can only AMP_TYPE.PARALLEL for tensor parallel training'
         self.use_zero_level_2_3 = False

         if amp_type is not None:
@@ -121,18 +115,6 @@ class NoPipelineSchedule(BaseSchedule):
         data, label = self.load_batch()
         loss = None

-        # LSG: leave for debug, make sure dataloader is deterministic
-        # if forward_only:
-        #     img = data[0]
-        #     rank = gpc.get_local_rank(ParallelMode.DATA)
-        #     world_size = gpc.get_world_size(ParallelMode.DATA)
-        #     group = gpc.get_group(ParallelMode.DATA)
-        #     input_list = [img.clone() for _ in range(world_size)]
-        #     output_list = [torch.empty_like(img) for _ in range(world_size)]
-        #     output_list[rank] = img.clone()
-        #     dist.all_to_all(output_tensor_list=output_list, input_tensor_list=input_list, group=group)
-        #     assert torch.equal(output_list[0], output_list[1]) # and torch.equal(output_list[1], output_list[2])
-
         # forward
         if self.fp16 and self.amp_type == AMP_TYPE.TORCH:
             with torch_amp.autocast():
@@ -146,6 +128,10 @@ class NoPipelineSchedule(BaseSchedule):
                 data = convert_to_fp16(data)

             output = self.model(*data)
+
+            if self.use_zero_level_2_3 or self.amp_type == AMP_TYPE.PARALLEL:
+                output = convert_to_fp32(output)
+
             if not isinstance(output, (tuple, list)):
                 output = (output,)
             if return_loss:
diff --git a/colossalai/engine/schedule/_utils.py b/colossalai/engine/schedule/_utils.py
index 9c4a2a19b..cdfd0246c 100644
--- a/colossalai/engine/schedule/_utils.py
+++ b/colossalai/engine/schedule/_utils.py
@@ -14,3 +14,14 @@ def convert_to_fp16(data: Union[Tensor, List[Tensor]]):
     else:
         raise TypeError(f"Expected argument 'data' to be a Tensor or a list/tuple of Tensor, but got {type(data)}")
     return ret
+
+
+def convert_to_fp32(data: Union[Tensor, List[Tensor]]):
+    if isinstance(data, Tensor):
+        ret = data.float()
+    elif isinstance(data, (list, tuple)):
+        ret = [val.float() for val in data]
+    else:
+        raise TypeError(f"Expected argument 'data' to be a Tensor or a list/tuple of Tensor, but got {type(data)}")
+    return ret
+
diff --git a/colossalai/nn/layer/parallel_2d/_operation.py b/colossalai/nn/layer/parallel_2d/_operation.py
index 2c7eb8ac6..d9ecf2fad 100644
--- a/colossalai/nn/layer/parallel_2d/_operation.py
+++ b/colossalai/nn/layer/parallel_2d/_operation.py
@@ -7,6 +7,7 @@ from torch import Tensor
 from colossalai.context.parallel_mode import ParallelMode
 from colossalai.core import global_context as gpc
 from colossalai.utils import get_current_device
+from torch.cuda.amp import custom_bwd, custom_fwd


 def matmul_2d(a,
@@ -60,6 +61,7 @@ class Matmul_AB_2D(torch.autograd.Function):
     """Matrix multiplication for :math:`C = AB`
     """
     @staticmethod
+    @custom_fwd(cast_inputs=torch.float16)
     def forward(ctx: Any,
                 A: Tensor,
                 B: Tensor,
@@ -120,32 +122,32 @@ class Matmul_AB_2D(torch.autograd.Function):
         return out

     @staticmethod
+    @custom_bwd
     def backward(ctx: Any, output_grad: Tensor) -> Tuple[Tensor, ...]:
         A, B = ctx.saved_tensors
-        A_grad = Matmul_ABT_2D.forward(
-            None,
-            output_grad, B,
-            ctx.summa_dim, ctx.A_shape,
-            ctx.row_rank, ctx.col_rank,
-            ctx.row_parallel_mode,
-            ctx.col_parallel_mode,
-            ctx.data_parallel_rank,
-            ctx.pipeline_parallel_rank,
-            ctx.pipeline_parallel_size,
-            ctx.tensor_parallel_size
-        )
-        B_grad = Matmul_ATB_2D.forward(
-            None,
-            A, output_grad,
-            ctx.summa_dim, ctx.B_shape,
-            ctx.row_rank, ctx.col_rank,
-            ctx.row_parallel_mode,
-            ctx.col_parallel_mode,
-            ctx.data_parallel_rank,
-            ctx.pipeline_parallel_rank,
-            ctx.pipeline_parallel_size,
-            ctx.tensor_parallel_size
-        )
+        with torch.no_grad():
+            A_grad = Matmul_ABT_2D.apply(
+                output_grad, B,
+                ctx.summa_dim, ctx.A_shape,
+                ctx.row_rank, ctx.col_rank,
+                ctx.row_parallel_mode,
+                ctx.col_parallel_mode,
+                ctx.data_parallel_rank,
+                ctx.pipeline_parallel_rank,
+                ctx.pipeline_parallel_size,
+                ctx.tensor_parallel_size
+            )
+            B_grad = Matmul_ATB_2D.apply(
+                A, output_grad,
+                ctx.summa_dim, ctx.B_shape,
+                ctx.row_rank, ctx.col_rank,
+                ctx.row_parallel_mode,
+                ctx.col_parallel_mode,
+                ctx.data_parallel_rank,
+                ctx.pipeline_parallel_rank,
+                ctx.pipeline_parallel_size,
+                ctx.tensor_parallel_size
+            )
         return A_grad, B_grad, None, None, None, None, None, None, None, None, None, None
@@ -153,6 +155,7 @@ class Matmul_ABT_2D(torch.autograd.Function):
     """Matrix multiplication for :math:`C = AB^T`
     """
     @staticmethod
+    @custom_fwd(cast_inputs=torch.float16)
     def forward(ctx: Any,
                 A: Tensor,
                 B: Tensor,
@@ -214,32 +217,33 @@ class Matmul_ABT_2D(torch.autograd.Function):
         return out

     @staticmethod
+    @custom_bwd
     def backward(ctx: Any, output_grad: Tensor) -> Tuple[Tensor, ...]:
         A, B = ctx.saved_tensors
-        A_grad = Matmul_AB_2D.forward(
-            None,
-            output_grad, B,
-            ctx.summa_dim, ctx.A_shape,
-            ctx.row_rank, ctx.col_rank,
-            ctx.row_parallel_mode,
-            ctx.col_parallel_mode,
-            ctx.data_parallel_rank,
-            ctx.pipeline_parallel_rank,
-            ctx.pipeline_parallel_size,
-            ctx.tensor_parallel_size
-        )
-        B_grad = Matmul_ATB_2D.forward(
-            None,
-            output_grad, A,
-            ctx.summa_dim, ctx.B_shape,
-            ctx.row_rank, ctx.col_rank,
-            ctx.row_parallel_mode,
-            ctx.col_parallel_mode,
-            ctx.data_parallel_rank,
-            ctx.pipeline_parallel_rank,
-            ctx.pipeline_parallel_size,
-            ctx.tensor_parallel_size
-        )
+
+        with torch.no_grad():
+            A_grad = Matmul_AB_2D.apply(
+                output_grad, B,
+                ctx.summa_dim, ctx.A_shape,
+                ctx.row_rank, ctx.col_rank,
+                ctx.row_parallel_mode,
+                ctx.col_parallel_mode,
+                ctx.data_parallel_rank,
+                ctx.pipeline_parallel_rank,
+                ctx.pipeline_parallel_size,
+                ctx.tensor_parallel_size
+            )
+            B_grad = Matmul_ATB_2D.apply(
+                output_grad, A,
+                ctx.summa_dim, ctx.B_shape,
+                ctx.row_rank, ctx.col_rank,
+                ctx.row_parallel_mode,
+                ctx.col_parallel_mode,
+                ctx.data_parallel_rank,
+                ctx.pipeline_parallel_rank,
+                ctx.pipeline_parallel_size,
+                ctx.tensor_parallel_size
+            )
         return A_grad, B_grad, None, None, None, None, None, None, None, None, None, None
@@ -247,6 +251,7 @@ class Matmul_ATB_2D(torch.autograd.Function):
     """Matrix multiplication for :math:`C = A^TB`
     """
     @staticmethod
+    @custom_fwd(cast_inputs=torch.float16)
     def forward(ctx: Any,
                 A: Tensor,
                 B: Tensor,
@@ -308,32 +313,33 @@ class Matmul_ATB_2D(torch.autograd.Function):
         return out

     @staticmethod
+    @custom_bwd
     def backward(ctx: Any, output_grad: Tensor) -> Tuple[Tensor, ...]:
         A, B = ctx.saved_tensors
-        A_grad = Matmul_ABT_2D.forward(
-            None,
-            B, output_grad,
-            ctx.summa_dim, ctx.A_shape,
-            ctx.row_rank, ctx.col_rank,
-            ctx.row_parallel_mode,
-            ctx.col_parallel_mode,
-            ctx.data_parallel_rank,
-            ctx.pipeline_parallel_rank,
-            ctx.pipeline_parallel_size,
-            ctx.tensor_parallel_size
-        )
-        B_grad = Matmul_AB_2D.forward(
-            None,
-            A, output_grad,
-            ctx.summa_dim, ctx.B_shape,
-            ctx.row_rank, ctx.col_rank,
-            ctx.row_parallel_mode,
-            ctx.col_parallel_mode,
-            ctx.data_parallel_rank,
-            ctx.pipeline_parallel_rank,
-            ctx.pipeline_parallel_size,
-            ctx.tensor_parallel_size
-        )
+
+        with torch.no_grad():
+            A_grad = Matmul_ABT_2D.apply(
+                B, output_grad,
+                ctx.summa_dim, ctx.A_shape,
+                ctx.row_rank, ctx.col_rank,
+                ctx.row_parallel_mode,
+                ctx.col_parallel_mode,
+                ctx.data_parallel_rank,
+                ctx.pipeline_parallel_rank,
+                ctx.pipeline_parallel_size,
+                ctx.tensor_parallel_size
+            )
+            B_grad = Matmul_AB_2D.apply(
+                A, output_grad,
+                ctx.summa_dim, ctx.B_shape,
+                ctx.row_rank, ctx.col_rank,
+                ctx.row_parallel_mode,
+                ctx.col_parallel_mode,
+                ctx.data_parallel_rank,
+                ctx.pipeline_parallel_rank,
+                ctx.pipeline_parallel_size,
+                ctx.tensor_parallel_size
+            )
         return A_grad, B_grad, None, None, None, None, None, None, None, None, None, None
@@ -341,6 +347,7 @@ class Add_Bias_2D(torch.autograd.Function):
     """Matrix add bias: :math:`C = A + b`
     """
     @staticmethod
+    @custom_fwd(cast_inputs=torch.float16)
     def forward(ctx: Any,
                 input: Tensor,
                 bias: Tensor,
@@ -384,6 +391,7 @@ class Add_Bias_2D(torch.autograd.Function):
         return output

     @staticmethod
+    @custom_bwd
     def backward(ctx: Any, output_grad: Tensor) -> Tuple[Tensor, ...]:
         row_rank = ctx.row_rank
         col_rank = ctx.col_rank
@@ -423,6 +431,7 @@ class _LayerNorm_2D(torch.autograd.Function):

     @staticmethod
+    @custom_fwd(cast_inputs=torch.float32)
     def forward(ctx: Any,
                 input: Tensor,
                 E_x: Tensor,
@@ -440,6 +449,7 @@ class _LayerNorm_2D(torch.autograd.Function):
         return output

     @staticmethod
+    @custom_bwd
     def backward(ctx: Any, output_grad: Tensor) -> Tuple[Tensor, ...]:
         row_parallel_mode = ctx.row_parallel_mode
         col_parallel_mode = ctx.col_parallel_mode
@@ -492,6 +502,7 @@ class _ViT_Split_Input_2D(torch.autograd.Function):

     @staticmethod
+    @custom_fwd(cast_inputs=torch.float16)
     def forward(ctx: Any,
                 inputs: Tensor,
                 batch_size: int,
@@ -509,6 +520,7 @@ class _ViT_Split_Input_2D(torch.autograd.Function):
         return output

     @staticmethod
+    @custom_bwd
     def backward(ctx: Any, output_grad: Tensor) -> Tuple[Tensor, ...]:
         # output_grad: [b/q, s, h/q]
         # grads: [b, s, h/q]
diff --git a/colossalai/nn/lr_scheduler/cosine.py b/colossalai/nn/lr_scheduler/cosine.py
index 067636a3d..0df30baab 100644
--- a/colossalai/nn/lr_scheduler/cosine.py
+++ b/colossalai/nn/lr_scheduler/cosine.py
@@ -66,11 +66,10 @@ class CosineAnnealingWarmupLR(WarmupScheduler):
     :type last_epoch: int, optional
     """

-    def __init__(self, optimizer, total_steps: int, warmup_steps: int = 0, eta_min: int = 0, last_epoch: int = -1,
-                 **kwargs):
+    def __init__(self, optimizer, total_steps: int, warmup_steps: int = 0, eta_min: int = 0, last_epoch: int = -1):
         base_scheduler = _CosineAnnealingLR(
-            optimizer, total_steps - warmup_steps, eta_min=eta_min)
-        super().__init__(optimizer, warmup_steps, base_scheduler, last_epoch=last_epoch)
+            optimizer, total_steps - warmup_steps, eta_min=eta_min, last_epoch=last_epoch)
+        super().__init__(optimizer, warmup_steps, base_scheduler)


 @LR_SCHEDULERS.register_module
diff --git a/colossalai/nn/lr_scheduler/delayed.py b/colossalai/nn/lr_scheduler/delayed.py
index c8972c922..173d2f52c 100644
--- a/colossalai/nn/lr_scheduler/delayed.py
+++ b/colossalai/nn/lr_scheduler/delayed.py
@@ -55,7 +55,7 @@ class DelayerScheduler(_LRScheduler):


 class WarmupScheduler(_LRScheduler):
-    """ Starts with a linear warmup lr schedule until it reaches N epochs the applies a scheduler 
+    """ Starts with a linear warmup lr schedule until it reaches N epochs the applies a scheduler

     :param optimizer: Wrapped optimizer.
     :type optimizer: torch.optim.Optimizer
@@ -66,11 +66,8 @@ class WarmupScheduler(_LRScheduler):
     :param last_epoch: The index of last epoch, defaults to -1
     :type last_epoch: int, optional
     """
-
     def __init__(self, optimizer, warmup_epochs, after_scheduler, last_epoch=-1):
-        if warmup_epochs < 0:
-            raise ValueError(f'warmup_epochs must >= 0, got {warmup_epochs}')
-        self.warmup_epochs = warmup_epochs
+        self.warmup_epochs = int(warmup_epochs)
         self.after_scheduler = after_scheduler
         self.finished = False
         super().__init__(optimizer, last_epoch)
@@ -79,14 +76,10 @@ class WarmupScheduler(_LRScheduler):
         if self.last_epoch >= self.warmup_epochs:
             if not self.finished:
                 self.after_scheduler.base_lrs = self.base_lrs
-                # reset lr to base_lr
-                for group, base_lr in zip(self.optimizer.param_groups, self.base_lrs):
-                    group['lr'] = base_lr
                 self.finished = True
-            with _enable_get_lr_call(self.after_scheduler):
-                return self.after_scheduler.get_lr()
+            return self.after_scheduler.get_lr()

-        return [(self.last_epoch + 1) / (self.warmup_epochs + 1) * lr for lr in self.base_lrs]
+        return [(self.last_epoch + 1) / self.warmup_epochs * lr for lr in self.base_lrs]

     def step(self, epoch=None):
         if self.finished:
diff --git a/colossalai/trainer/_trainer.py b/colossalai/trainer/_trainer.py
index 5f06cbeee..a8706ceb1 100644
--- a/colossalai/trainer/_trainer.py
+++ b/colossalai/trainer/_trainer.py
@@ -136,14 +136,14 @@ class Trainer:
         self.call_hooks('before_train_epoch')
         self._timer.start('train-epoch')
         for _ in progress:
-            self._cur_step += 1
-
             self.call_hooks('before_train_iter')
             self._timer.start('train-step')
             logits, label, loss = self._engine.step()
             self._timer.stop('train-step', keep_in_history=True)
             self.call_hooks('after_train_iter', output=(logits, label, loss))

+            self._cur_step += 1
+
             if self.exceed_max_step():
                 # stop when max iter is reached
                 break
@@ -235,8 +235,6 @@ class Trainer:
         last_epoch = self._cur_epoch

         for epoch in range(last_epoch, self._max_epochs):
-            self._cur_epoch += 1
-
             # train for one epoch
             self._train_epoch(epoch)

@@ -244,6 +242,8 @@ class Trainer:
             if should_test and epoch % test_interval == 0:
                 self._eval(epoch, return_loss=True)

+            self._cur_epoch += 1
+
             # check for termination
             if self.exceed_max_step():
                 self._logger.info(
diff --git a/colossalai/utils/common.py b/colossalai/utils/common.py
index 1496e77ac..d8c6663ba 100644
--- a/colossalai/utils/common.py
+++ b/colossalai/utils/common.py
@@ -27,7 +27,7 @@ def sync_model_param_in_dp(model):
     :param model: A pyTorch nn.model on whose parameters you check the consistency
     '''
-    if gpc.is_initialized(ParallelMode.DATA) and gpc.get_world_size(ParallelMode.DATA) > 2:
+    if gpc.is_initialized(ParallelMode.DATA) and gpc.get_world_size(ParallelMode.DATA) > 1:
         for param in model.parameters():
             ranks = gpc.get_ranks_in_group(ParallelMode.DATA)
             dist.broadcast(param, src=ranks[0], group=gpc.get_group(ParallelMode.DATA))
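The new convert_to_fp32 helper mirrors convert_to_fp16: with ZeRO level 2/3 or AMP_TYPE.PARALLEL the model runs in and returns fp16 tensors, and the schedule now casts those outputs back to fp32 before they reach the criterion. A minimal standalone sketch of the same pattern, assuming a CUDA device and a toy model and criterion that are not part of the patch:

import torch

def convert_to_fp32(data):
    # same behaviour as the helper added in _utils.py: Tensor or list/tuple of Tensors
    if isinstance(data, torch.Tensor):
        return data.float()
    return [val.float() for val in data]

model = torch.nn.Linear(8, 4).cuda().half()   # fp16 weights, as under ZeRO 2/3 or AMP_TYPE.PARALLEL
criterion = torch.nn.MSELoss()                # loss kept in fp32 for numerical stability

data = torch.randn(2, 8, device='cuda').half()
target = torch.randn(2, 4, device='cuda')

output = model(data)                          # fp16 activations
output = convert_to_fp32(output)              # cast back before computing the loss
loss = criterion(output, target)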
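Two patterns recur in the _operation.py changes: custom_fwd(cast_inputs=...) and custom_bwd keep each 2D op's dtype and autocast state consistent between forward and backward, and the backward passes now call the sibling ops through .apply() inside torch.no_grad() instead of invoking their raw forward(None, ...). A reduced sketch of both patterns using a made-up _Double op (not from the patch):

import torch
from torch.cuda.amp import custom_bwd, custom_fwd


class _Double(torch.autograd.Function):
    """Toy op y = 2 * x, used only to illustrate the decorator pattern."""

    @staticmethod
    @custom_fwd(cast_inputs=torch.float16)   # run in fp16 when called under autocast
    def forward(ctx, x):
        return 2 * x

    @staticmethod
    @custom_bwd                              # restore the autocast state seen in forward
    def backward(ctx, grad_output):
        # Build the gradient by reusing another Function through .apply(),
        # wrapped in no_grad(), the way the Matmul_*_2D backward methods now do.
        with torch.no_grad():
            grad_input = _Double.apply(grad_output)   # d(2x)/dx = 2
        return grad_input


x = torch.randn(4, 4, device='cuda', dtype=torch.float16, requires_grad=True)
with torch.cuda.amp.autocast():
    y = _Double.apply(x)
y.sum().backward()                           # x.grad is 2 everywhere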
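For the scheduler changes: the warmup factor moves from (last_epoch + 1) / (warmup_epochs + 1) to (last_epoch + 1) / warmup_epochs, so the learning rate now reaches the full base_lr on the final warmup epoch instead of stopping just short of it, and CosineAnnealingWarmupLR now forwards last_epoch to the wrapped _CosineAnnealingLR rather than to WarmupScheduler. A quick check of the warmup multipliers in plain Python (not patch code):

warmup_epochs = 5

old_factors = [(epoch + 1) / (warmup_epochs + 1) for epoch in range(warmup_epochs)]
new_factors = [(epoch + 1) / warmup_epochs for epoch in range(warmup_epochs)]

print(old_factors)   # approx. [0.167, 0.333, 0.5, 0.667, 0.833]  -> never reaches 1.0
print(new_factors)   # [0.2, 0.4, 0.6, 0.8, 1.0]                  -> hits base_lr at the end of warmup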
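The guard in sync_model_param_in_dp previously required a data-parallel world size greater than 2, which silently skipped synchronization in the common 2-GPU case; with > 1, every multi-process data-parallel group broadcasts its parameters from the first rank. The same idea expressed with plain torch.distributed, assuming the default process group is already initialized (a sketch, not the colossalai API):

import torch.distributed as dist

def sync_model_param(model):
    # Broadcast every parameter from rank 0 so all workers start from identical weights.
    if dist.is_initialized() and dist.get_world_size() > 1:
        for param in model.parameters():
            dist.broadcast(param.data, src=0)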