mirror of https://github.com/hpcaitech/ColossalAI
fixed bugs in compatibility between torch amp and tensor parallel and performed some minor fixes
parent
c8cb9f9e34
commit
af88570f4b
|
@ -15,7 +15,7 @@ from colossalai.core import global_context as gpc
|
|||
from colossalai.nn import (ZeroRedundancyOptimizer_Level_2,
|
||||
ZeroRedundancyOptimizer_Level_3)
|
||||
from colossalai.nn.optimizer._utils import clip_grad_norm_fp32
|
||||
from ._utils import convert_to_fp16
|
||||
from ._utils import convert_to_fp16, convert_to_fp32
|
||||
from ._base_schedule import BaseSchedule
|
||||
from ..amp import AMP_TYPE, GradScaler
|
||||
|
||||
|
@ -43,12 +43,6 @@ class NoPipelineSchedule(BaseSchedule):
|
|||
assert amp_type is None or isinstance(amp_type, AMP_TYPE), \
|
||||
'unrecognised value for argument fp16, it can only be None, torch or apex'
|
||||
|
||||
# LSG: check compatibility
|
||||
# LSG: torch.cuda.amp and apex.amp cannot be used for tensor parallel
|
||||
if gpc.is_initialized(ParallelMode.TENSOR) and gpc.get_world_size(
|
||||
ParallelMode.TENSOR) > 1:
|
||||
assert amp_type != AMP_TYPE.TORCH and amp_type != AMP_TYPE.APEX, \
|
||||
'You can only AMP_TYPE.PARALLEL for tensor parallel training'
|
||||
self.use_zero_level_2_3 = False
|
||||
|
||||
if amp_type is not None:
|
||||
|
@ -121,18 +115,6 @@ class NoPipelineSchedule(BaseSchedule):
|
|||
data, label = self.load_batch()
|
||||
loss = None
|
||||
|
||||
# LSG: leave for debug, make sure dataloader is deterministic
|
||||
# if forward_only:
|
||||
# img = data[0]
|
||||
# rank = gpc.get_local_rank(ParallelMode.DATA)
|
||||
# world_size = gpc.get_world_size(ParallelMode.DATA)
|
||||
# group = gpc.get_group(ParallelMode.DATA)
|
||||
# input_list = [img.clone() for _ in range(world_size)]
|
||||
# output_list = [torch.empty_like(img) for _ in range(world_size)]
|
||||
# output_list[rank] = img.clone()
|
||||
# dist.all_to_all(output_tensor_list=output_list, input_tensor_list=input_list, group=group)
|
||||
# assert torch.equal(output_list[0], output_list[1]) # and torch.equal(output_list[1], output_list[2])
|
||||
|
||||
# forward
|
||||
if self.fp16 and self.amp_type == AMP_TYPE.TORCH:
|
||||
with torch_amp.autocast():
|
||||
|
@ -146,6 +128,10 @@ class NoPipelineSchedule(BaseSchedule):
|
|||
data = convert_to_fp16(data)
|
||||
|
||||
output = self.model(*data)
|
||||
|
||||
if self.use_zero_level_2_3 or self.amp_type == AMP_TYPE.PARALLEL:
|
||||
output = convert_to_fp32(output)
|
||||
|
||||
if not isinstance(output, (tuple, list)):
|
||||
output = (output,)
|
||||
if return_loss:
|
||||
|
|
|
@ -14,3 +14,14 @@ def convert_to_fp16(data: Union[Tensor, List[Tensor]]):
|
|||
else:
|
||||
raise TypeError(f"Expected argument 'data' to be a Tensor or a list/tuple of Tensor, but got {type(data)}")
|
||||
return ret
|
||||
|
||||
|
||||
def convert_to_fp32(data: Union[Tensor, List[Tensor]]):
|
||||
if isinstance(data, Tensor):
|
||||
ret = data.float()
|
||||
elif isinstance(data, (list, tuple)):
|
||||
ret = [val.float() for val in data]
|
||||
else:
|
||||
raise TypeError(f"Expected argument 'data' to be a Tensor or a list/tuple of Tensor, but got {type(data)}")
|
||||
return ret
|
||||
|
||||
|
|
|
@ -7,6 +7,7 @@ from torch import Tensor
|
|||
from colossalai.context.parallel_mode import ParallelMode
|
||||
from colossalai.core import global_context as gpc
|
||||
from colossalai.utils import get_current_device
|
||||
from torch.cuda.amp import custom_bwd, custom_fwd
|
||||
|
||||
|
||||
def matmul_2d(a,
|
||||
|
@ -60,6 +61,7 @@ class Matmul_AB_2D(torch.autograd.Function):
|
|||
"""Matrix multiplication for :math:`C = AB`
|
||||
"""
|
||||
@staticmethod
|
||||
@custom_fwd(cast_inputs=torch.float16)
|
||||
def forward(ctx: Any,
|
||||
A: Tensor,
|
||||
B: Tensor,
|
||||
|
@ -120,32 +122,32 @@ class Matmul_AB_2D(torch.autograd.Function):
|
|||
return out
|
||||
|
||||
@staticmethod
|
||||
@custom_bwd
|
||||
def backward(ctx: Any, output_grad: Tensor) -> Tuple[Tensor, ...]:
|
||||
A, B = ctx.saved_tensors
|
||||
A_grad = Matmul_ABT_2D.forward(
|
||||
None,
|
||||
output_grad, B,
|
||||
ctx.summa_dim, ctx.A_shape,
|
||||
ctx.row_rank, ctx.col_rank,
|
||||
ctx.row_parallel_mode,
|
||||
ctx.col_parallel_mode,
|
||||
ctx.data_parallel_rank,
|
||||
ctx.pipeline_parallel_rank,
|
||||
ctx.pipeline_parallel_size,
|
||||
ctx.tensor_parallel_size
|
||||
)
|
||||
B_grad = Matmul_ATB_2D.forward(
|
||||
None,
|
||||
A, output_grad,
|
||||
ctx.summa_dim, ctx.B_shape,
|
||||
ctx.row_rank, ctx.col_rank,
|
||||
ctx.row_parallel_mode,
|
||||
ctx.col_parallel_mode,
|
||||
ctx.data_parallel_rank,
|
||||
ctx.pipeline_parallel_rank,
|
||||
ctx.pipeline_parallel_size,
|
||||
ctx.tensor_parallel_size
|
||||
)
|
||||
with torch.no_grad():
|
||||
A_grad = Matmul_ABT_2D.apply(
|
||||
output_grad, B,
|
||||
ctx.summa_dim, ctx.A_shape,
|
||||
ctx.row_rank, ctx.col_rank,
|
||||
ctx.row_parallel_mode,
|
||||
ctx.col_parallel_mode,
|
||||
ctx.data_parallel_rank,
|
||||
ctx.pipeline_parallel_rank,
|
||||
ctx.pipeline_parallel_size,
|
||||
ctx.tensor_parallel_size
|
||||
)
|
||||
B_grad = Matmul_ATB_2D.apply(
|
||||
A, output_grad,
|
||||
ctx.summa_dim, ctx.B_shape,
|
||||
ctx.row_rank, ctx.col_rank,
|
||||
ctx.row_parallel_mode,
|
||||
ctx.col_parallel_mode,
|
||||
ctx.data_parallel_rank,
|
||||
ctx.pipeline_parallel_rank,
|
||||
ctx.pipeline_parallel_size,
|
||||
ctx.tensor_parallel_size
|
||||
)
|
||||
return A_grad, B_grad, None, None, None, None, None, None, None, None, None, None
|
||||
|
||||
|
||||
|
@ -153,6 +155,7 @@ class Matmul_ABT_2D(torch.autograd.Function):
|
|||
"""Matrix multiplication for :math:`C = AB^T`
|
||||
"""
|
||||
@staticmethod
|
||||
@custom_fwd(cast_inputs=torch.float16)
|
||||
def forward(ctx: Any,
|
||||
A: Tensor,
|
||||
B: Tensor,
|
||||
|
@ -214,32 +217,33 @@ class Matmul_ABT_2D(torch.autograd.Function):
|
|||
return out
|
||||
|
||||
@staticmethod
|
||||
@custom_bwd
|
||||
def backward(ctx: Any, output_grad: Tensor) -> Tuple[Tensor, ...]:
|
||||
A, B = ctx.saved_tensors
|
||||
A_grad = Matmul_AB_2D.forward(
|
||||
None,
|
||||
output_grad, B,
|
||||
ctx.summa_dim, ctx.A_shape,
|
||||
ctx.row_rank, ctx.col_rank,
|
||||
ctx.row_parallel_mode,
|
||||
ctx.col_parallel_mode,
|
||||
ctx.data_parallel_rank,
|
||||
ctx.pipeline_parallel_rank,
|
||||
ctx.pipeline_parallel_size,
|
||||
ctx.tensor_parallel_size
|
||||
)
|
||||
B_grad = Matmul_ATB_2D.forward(
|
||||
None,
|
||||
output_grad, A,
|
||||
ctx.summa_dim, ctx.B_shape,
|
||||
ctx.row_rank, ctx.col_rank,
|
||||
ctx.row_parallel_mode,
|
||||
ctx.col_parallel_mode,
|
||||
ctx.data_parallel_rank,
|
||||
ctx.pipeline_parallel_rank,
|
||||
ctx.pipeline_parallel_size,
|
||||
ctx.tensor_parallel_size
|
||||
)
|
||||
|
||||
with torch.no_grad():
|
||||
A_grad = Matmul_AB_2D.apply(
|
||||
output_grad, B,
|
||||
ctx.summa_dim, ctx.A_shape,
|
||||
ctx.row_rank, ctx.col_rank,
|
||||
ctx.row_parallel_mode,
|
||||
ctx.col_parallel_mode,
|
||||
ctx.data_parallel_rank,
|
||||
ctx.pipeline_parallel_rank,
|
||||
ctx.pipeline_parallel_size,
|
||||
ctx.tensor_parallel_size
|
||||
)
|
||||
B_grad = Matmul_ATB_2D.apply(
|
||||
output_grad, A,
|
||||
ctx.summa_dim, ctx.B_shape,
|
||||
ctx.row_rank, ctx.col_rank,
|
||||
ctx.row_parallel_mode,
|
||||
ctx.col_parallel_mode,
|
||||
ctx.data_parallel_rank,
|
||||
ctx.pipeline_parallel_rank,
|
||||
ctx.pipeline_parallel_size,
|
||||
ctx.tensor_parallel_size
|
||||
)
|
||||
return A_grad, B_grad, None, None, None, None, None, None, None, None, None, None
|
||||
|
||||
|
||||
|
@ -247,6 +251,7 @@ class Matmul_ATB_2D(torch.autograd.Function):
|
|||
"""Matrix multiplication for :math:`C = A^TB`
|
||||
"""
|
||||
@staticmethod
|
||||
@custom_fwd(cast_inputs=torch.float16)
|
||||
def forward(ctx: Any,
|
||||
A: Tensor,
|
||||
B: Tensor,
|
||||
|
@ -308,32 +313,33 @@ class Matmul_ATB_2D(torch.autograd.Function):
|
|||
return out
|
||||
|
||||
@staticmethod
|
||||
@custom_bwd
|
||||
def backward(ctx: Any, output_grad: Tensor) -> Tuple[Tensor, ...]:
|
||||
A, B = ctx.saved_tensors
|
||||
A_grad = Matmul_ABT_2D.forward(
|
||||
None,
|
||||
B, output_grad,
|
||||
ctx.summa_dim, ctx.A_shape,
|
||||
ctx.row_rank, ctx.col_rank,
|
||||
ctx.row_parallel_mode,
|
||||
ctx.col_parallel_mode,
|
||||
ctx.data_parallel_rank,
|
||||
ctx.pipeline_parallel_rank,
|
||||
ctx.pipeline_parallel_size,
|
||||
ctx.tensor_parallel_size
|
||||
)
|
||||
B_grad = Matmul_AB_2D.forward(
|
||||
None,
|
||||
A, output_grad,
|
||||
ctx.summa_dim, ctx.B_shape,
|
||||
ctx.row_rank, ctx.col_rank,
|
||||
ctx.row_parallel_mode,
|
||||
ctx.col_parallel_mode,
|
||||
ctx.data_parallel_rank,
|
||||
ctx.pipeline_parallel_rank,
|
||||
ctx.pipeline_parallel_size,
|
||||
ctx.tensor_parallel_size
|
||||
)
|
||||
|
||||
with torch.no_grad():
|
||||
A_grad = Matmul_ABT_2D.apply(
|
||||
B, output_grad,
|
||||
ctx.summa_dim, ctx.A_shape,
|
||||
ctx.row_rank, ctx.col_rank,
|
||||
ctx.row_parallel_mode,
|
||||
ctx.col_parallel_mode,
|
||||
ctx.data_parallel_rank,
|
||||
ctx.pipeline_parallel_rank,
|
||||
ctx.pipeline_parallel_size,
|
||||
ctx.tensor_parallel_size
|
||||
)
|
||||
B_grad = Matmul_AB_2D.apply(
|
||||
A, output_grad,
|
||||
ctx.summa_dim, ctx.B_shape,
|
||||
ctx.row_rank, ctx.col_rank,
|
||||
ctx.row_parallel_mode,
|
||||
ctx.col_parallel_mode,
|
||||
ctx.data_parallel_rank,
|
||||
ctx.pipeline_parallel_rank,
|
||||
ctx.pipeline_parallel_size,
|
||||
ctx.tensor_parallel_size
|
||||
)
|
||||
return A_grad, B_grad, None, None, None, None, None, None, None, None, None, None
|
||||
|
||||
|
||||
|
@ -341,6 +347,7 @@ class Add_Bias_2D(torch.autograd.Function):
|
|||
"""Matrix add bias: :math:`C = A + b`
|
||||
"""
|
||||
@staticmethod
|
||||
@custom_fwd(cast_inputs=torch.float16)
|
||||
def forward(ctx: Any,
|
||||
input: Tensor,
|
||||
bias: Tensor,
|
||||
|
@ -384,6 +391,7 @@ class Add_Bias_2D(torch.autograd.Function):
|
|||
return output
|
||||
|
||||
@staticmethod
|
||||
@custom_bwd
|
||||
def backward(ctx: Any, output_grad: Tensor) -> Tuple[Tensor, ...]:
|
||||
row_rank = ctx.row_rank
|
||||
col_rank = ctx.col_rank
|
||||
|
@ -423,6 +431,7 @@ class Add_Bias_2D(torch.autograd.Function):
|
|||
class _LayerNorm_2D(torch.autograd.Function):
|
||||
|
||||
@staticmethod
|
||||
@custom_fwd(cast_inputs=torch.float32)
|
||||
def forward(ctx: Any,
|
||||
input: Tensor,
|
||||
E_x: Tensor,
|
||||
|
@ -440,6 +449,7 @@ class _LayerNorm_2D(torch.autograd.Function):
|
|||
return output
|
||||
|
||||
@staticmethod
|
||||
@custom_bwd
|
||||
def backward(ctx: Any, output_grad: Tensor) -> Tuple[Tensor, ...]:
|
||||
row_parallel_mode = ctx.row_parallel_mode
|
||||
col_parallel_mode = ctx.col_parallel_mode
|
||||
|
@ -492,6 +502,7 @@ class _LayerNorm_2D(torch.autograd.Function):
|
|||
class _ViT_Split_Input_2D(torch.autograd.Function):
|
||||
|
||||
@staticmethod
|
||||
@custom_fwd(cast_inputs=torch.float16)
|
||||
def forward(ctx: Any,
|
||||
inputs: Tensor,
|
||||
batch_size: int,
|
||||
|
@ -509,6 +520,7 @@ class _ViT_Split_Input_2D(torch.autograd.Function):
|
|||
return output
|
||||
|
||||
@staticmethod
|
||||
@custom_bwd
|
||||
def backward(ctx: Any, output_grad: Tensor) -> Tuple[Tensor, ...]:
|
||||
# output_grad: [b/q, s, h/q]
|
||||
# grads: [b, s, h/q]
|
||||
|
|
|
@ -66,11 +66,10 @@ class CosineAnnealingWarmupLR(WarmupScheduler):
|
|||
:type last_epoch: int, optional
|
||||
"""
|
||||
|
||||
def __init__(self, optimizer, total_steps: int, warmup_steps: int = 0, eta_min: int = 0, last_epoch: int = -1,
|
||||
**kwargs):
|
||||
def __init__(self, optimizer, total_steps: int, warmup_steps: int = 0, eta_min: int = 0, last_epoch: int = -1):
|
||||
base_scheduler = _CosineAnnealingLR(
|
||||
optimizer, total_steps - warmup_steps, eta_min=eta_min)
|
||||
super().__init__(optimizer, warmup_steps, base_scheduler, last_epoch=last_epoch)
|
||||
optimizer, total_steps - warmup_steps, eta_min=eta_min, last_epoch=last_epoch)
|
||||
super().__init__(optimizer, warmup_steps, base_scheduler)
|
||||
|
||||
|
||||
@LR_SCHEDULERS.register_module
|
||||
|
|
|
@ -55,7 +55,7 @@ class DelayerScheduler(_LRScheduler):
|
|||
|
||||
|
||||
class WarmupScheduler(_LRScheduler):
|
||||
""" Starts with a linear warmup lr schedule until it reaches N epochs the applies a scheduler
|
||||
""" Starts with a linear warmup lr schedule until it reaches N epochs the applies a scheduler
|
||||
|
||||
:param optimizer: Wrapped optimizer.
|
||||
:type optimizer: torch.optim.Optimizer
|
||||
|
@ -66,11 +66,8 @@ class WarmupScheduler(_LRScheduler):
|
|||
:param last_epoch: The index of last epoch, defaults to -1
|
||||
:type last_epoch: int, optional
|
||||
"""
|
||||
|
||||
def __init__(self, optimizer, warmup_epochs, after_scheduler, last_epoch=-1):
|
||||
if warmup_epochs < 0:
|
||||
raise ValueError(f'warmup_epochs must >= 0, got {warmup_epochs}')
|
||||
self.warmup_epochs = warmup_epochs
|
||||
self.warmup_epochs = int(warmup_epochs)
|
||||
self.after_scheduler = after_scheduler
|
||||
self.finished = False
|
||||
super().__init__(optimizer, last_epoch)
|
||||
|
@ -79,14 +76,10 @@ class WarmupScheduler(_LRScheduler):
|
|||
if self.last_epoch >= self.warmup_epochs:
|
||||
if not self.finished:
|
||||
self.after_scheduler.base_lrs = self.base_lrs
|
||||
# reset lr to base_lr
|
||||
for group, base_lr in zip(self.optimizer.param_groups, self.base_lrs):
|
||||
group['lr'] = base_lr
|
||||
self.finished = True
|
||||
with _enable_get_lr_call(self.after_scheduler):
|
||||
return self.after_scheduler.get_lr()
|
||||
return self.after_scheduler.get_lr()
|
||||
|
||||
return [(self.last_epoch + 1) / (self.warmup_epochs + 1) * lr for lr in self.base_lrs]
|
||||
return [(self.last_epoch + 1) / self.warmup_epochs * lr for lr in self.base_lrs]
|
||||
|
||||
def step(self, epoch=None):
|
||||
if self.finished:
|
||||
|
|
|
@ -136,14 +136,14 @@ class Trainer:
|
|||
self.call_hooks('before_train_epoch')
|
||||
self._timer.start('train-epoch')
|
||||
for _ in progress:
|
||||
self._cur_step += 1
|
||||
|
||||
self.call_hooks('before_train_iter')
|
||||
self._timer.start('train-step')
|
||||
logits, label, loss = self._engine.step()
|
||||
self._timer.stop('train-step', keep_in_history=True)
|
||||
self.call_hooks('after_train_iter', output=(logits, label, loss))
|
||||
|
||||
self._cur_step += 1
|
||||
|
||||
if self.exceed_max_step():
|
||||
# stop when max iter is reached
|
||||
break
|
||||
|
@ -235,8 +235,6 @@ class Trainer:
|
|||
last_epoch = self._cur_epoch
|
||||
|
||||
for epoch in range(last_epoch, self._max_epochs):
|
||||
self._cur_epoch += 1
|
||||
|
||||
# train for one epoch
|
||||
self._train_epoch(epoch)
|
||||
|
||||
|
@ -244,6 +242,8 @@ class Trainer:
|
|||
if should_test and epoch % test_interval == 0:
|
||||
self._eval(epoch, return_loss=True)
|
||||
|
||||
self._cur_epoch += 1
|
||||
|
||||
# check for termination
|
||||
if self.exceed_max_step():
|
||||
self._logger.info(
|
||||
|
|
|
@ -27,7 +27,7 @@ def sync_model_param_in_dp(model):
|
|||
:param model: A pyTorch nn.model on whose parameters you check the consistency
|
||||
'''
|
||||
|
||||
if gpc.is_initialized(ParallelMode.DATA) and gpc.get_world_size(ParallelMode.DATA) > 2:
|
||||
if gpc.is_initialized(ParallelMode.DATA) and gpc.get_world_size(ParallelMode.DATA) > 1:
|
||||
for param in model.parameters():
|
||||
ranks = gpc.get_ranks_in_group(ParallelMode.DATA)
|
||||
dist.broadcast(param, src=ranks[0], group=gpc.get_group(ParallelMode.DATA))
|
||||
|
|
Loading…
Reference in New Issue