fixed compatibility bugs between torch.cuda.amp and tensor parallelism and made some minor fixes

pull/21/head^2
Frank Lee 2021-11-10 10:47:58 +08:00
parent c8cb9f9e34
commit af88570f4b
7 changed files with 112 additions and 111 deletions
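
A note on the headline fix: custom torch.autograd.Function subclasses (like the 2D tensor-parallel ops changed below) are not handled by torch.cuda.amp autocasting -- their inputs are not cast -- unless forward/backward are wrapped with torch.cuda.amp.custom_fwd / custom_bwd, which is what this commit adds. A minimal sketch of that pattern, with an illustrative ScaleOp that is not part of the diff:

import torch
from torch.cuda.amp import custom_bwd, custom_fwd

class ScaleOp(torch.autograd.Function):
    """Toy op used only to illustrate the AMP decorators added in this commit."""

    @staticmethod
    @custom_fwd(cast_inputs=torch.float16)  # under autocast, inputs are cast to fp16
    def forward(ctx, x, factor):
        ctx.factor = factor
        return x * factor

    @staticmethod
    @custom_bwd  # backward runs under the same autocast state as the forward
    def backward(ctx, grad_output):
        return grad_output * ctx.factor, None

# usage sketch (needs a CUDA device for autocast to take effect)
x = torch.randn(4, 4, device='cuda', requires_grad=True)
with torch.cuda.amp.autocast():
    y = ScaleOp.apply(x, 2.0)
y.float().sum().backward()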

View File

@@ -15,7 +15,7 @@ from colossalai.core import global_context as gpc
from colossalai.nn import (ZeroRedundancyOptimizer_Level_2,
ZeroRedundancyOptimizer_Level_3)
from colossalai.nn.optimizer._utils import clip_grad_norm_fp32
from ._utils import convert_to_fp16
from ._utils import convert_to_fp16, convert_to_fp32
from ._base_schedule import BaseSchedule
from ..amp import AMP_TYPE, GradScaler
@@ -43,12 +43,6 @@ class NoPipelineSchedule(BaseSchedule):
assert amp_type is None or isinstance(amp_type, AMP_TYPE), \
'unrecognised value for argument fp16, it can only be None, torch or apex'
# LSG: check compatibility
# LSG: torch.cuda.amp and apex.amp cannot be used for tensor parallel
if gpc.is_initialized(ParallelMode.TENSOR) and gpc.get_world_size(
ParallelMode.TENSOR) > 1:
assert amp_type != AMP_TYPE.TORCH and amp_type != AMP_TYPE.APEX, \
'You can only AMP_TYPE.PARALLEL for tensor parallel training'
self.use_zero_level_2_3 = False
if amp_type is not None:
@@ -121,18 +115,6 @@ class NoPipelineSchedule(BaseSchedule):
data, label = self.load_batch()
loss = None
# LSG: leave for debug, make sure dataloader is deterministic
# if forward_only:
# img = data[0]
# rank = gpc.get_local_rank(ParallelMode.DATA)
# world_size = gpc.get_world_size(ParallelMode.DATA)
# group = gpc.get_group(ParallelMode.DATA)
# input_list = [img.clone() for _ in range(world_size)]
# output_list = [torch.empty_like(img) for _ in range(world_size)]
# output_list[rank] = img.clone()
# dist.all_to_all(output_tensor_list=output_list, input_tensor_list=input_list, group=group)
# assert torch.equal(output_list[0], output_list[1]) # and torch.equal(output_list[1], output_list[2])
# forward
if self.fp16 and self.amp_type == AMP_TYPE.TORCH:
with torch_amp.autocast():
@@ -146,6 +128,10 @@ class NoPipelineSchedule(BaseSchedule):
data = convert_to_fp16(data)
output = self.model(*data)
if self.use_zero_level_2_3 or self.amp_type == AMP_TYPE.PARALLEL:
output = convert_to_fp32(output)
if not isinstance(output, (tuple, list)):
output = (output,)
if return_loss:
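
Pulling the hunks above together: on the non-torch-AMP path the schedule now casts inputs to fp16 for the half-precision model and, newly, casts the output back to fp32 before the loss. A condensed sketch (names follow the diff; this is not the full method):

# condensed sketch of the non-torch-AMP forward path after this change
if self.use_zero_level_2_3 or self.amp_type == AMP_TYPE.PARALLEL:
    data = convert_to_fp16(data)       # inputs cast to match the fp16 model
output = self.model(*data)
if self.use_zero_level_2_3 or self.amp_type == AMP_TYPE.PARALLEL:
    output = convert_to_fp32(output)   # new: loss is then computed in full precision
if not isinstance(output, (tuple, list)):
    output = (output,)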

View File

@@ -14,3 +14,14 @@ def convert_to_fp16(data: Union[Tensor, List[Tensor]]):
else:
raise TypeError(f"Expected argument 'data' to be a Tensor or a list/tuple of Tensor, but got {type(data)}")
return ret
def convert_to_fp32(data: Union[Tensor, List[Tensor]]):
if isinstance(data, Tensor):
ret = data.float()
elif isinstance(data, (list, tuple)):
ret = [val.float() for val in data]
else:
raise TypeError(f"Expected argument 'data' to be a Tensor or a list/tuple of Tensor, but got {type(data)}")
return ret

View File

@@ -7,6 +7,7 @@ from torch import Tensor
from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.utils import get_current_device
from torch.cuda.amp import custom_bwd, custom_fwd
def matmul_2d(a,
@@ -60,6 +61,7 @@ class Matmul_AB_2D(torch.autograd.Function):
"""Matrix multiplication for :math:`C = AB`
"""
@staticmethod
@custom_fwd(cast_inputs=torch.float16)
def forward(ctx: Any,
A: Tensor,
B: Tensor,
@@ -120,32 +122,32 @@ class Matmul_AB_2D(torch.autograd.Function):
return out
@staticmethod
@custom_bwd
def backward(ctx: Any, output_grad: Tensor) -> Tuple[Tensor, ...]:
A, B = ctx.saved_tensors
A_grad = Matmul_ABT_2D.forward(
None,
output_grad, B,
ctx.summa_dim, ctx.A_shape,
ctx.row_rank, ctx.col_rank,
ctx.row_parallel_mode,
ctx.col_parallel_mode,
ctx.data_parallel_rank,
ctx.pipeline_parallel_rank,
ctx.pipeline_parallel_size,
ctx.tensor_parallel_size
)
B_grad = Matmul_ATB_2D.forward(
None,
A, output_grad,
ctx.summa_dim, ctx.B_shape,
ctx.row_rank, ctx.col_rank,
ctx.row_parallel_mode,
ctx.col_parallel_mode,
ctx.data_parallel_rank,
ctx.pipeline_parallel_rank,
ctx.pipeline_parallel_size,
ctx.tensor_parallel_size
)
with torch.no_grad():
A_grad = Matmul_ABT_2D.apply(
output_grad, B,
ctx.summa_dim, ctx.A_shape,
ctx.row_rank, ctx.col_rank,
ctx.row_parallel_mode,
ctx.col_parallel_mode,
ctx.data_parallel_rank,
ctx.pipeline_parallel_rank,
ctx.pipeline_parallel_size,
ctx.tensor_parallel_size
)
B_grad = Matmul_ATB_2D.apply(
A, output_grad,
ctx.summa_dim, ctx.B_shape,
ctx.row_rank, ctx.col_rank,
ctx.row_parallel_mode,
ctx.col_parallel_mode,
ctx.data_parallel_rank,
ctx.pipeline_parallel_rank,
ctx.pipeline_parallel_size,
ctx.tensor_parallel_size
)
return A_grad, B_grad, None, None, None, None, None, None, None, None, None, None
@@ -153,6 +155,7 @@ class Matmul_ABT_2D(torch.autograd.Function):
"""Matrix multiplication for :math:`C = AB^T`
"""
@staticmethod
@custom_fwd(cast_inputs=torch.float16)
def forward(ctx: Any,
A: Tensor,
B: Tensor,
@@ -214,32 +217,33 @@ class Matmul_ABT_2D(torch.autograd.Function):
return out
@staticmethod
@custom_bwd
def backward(ctx: Any, output_grad: Tensor) -> Tuple[Tensor, ...]:
A, B = ctx.saved_tensors
A_grad = Matmul_AB_2D.forward(
None,
output_grad, B,
ctx.summa_dim, ctx.A_shape,
ctx.row_rank, ctx.col_rank,
ctx.row_parallel_mode,
ctx.col_parallel_mode,
ctx.data_parallel_rank,
ctx.pipeline_parallel_rank,
ctx.pipeline_parallel_size,
ctx.tensor_parallel_size
)
B_grad = Matmul_ATB_2D.forward(
None,
output_grad, A,
ctx.summa_dim, ctx.B_shape,
ctx.row_rank, ctx.col_rank,
ctx.row_parallel_mode,
ctx.col_parallel_mode,
ctx.data_parallel_rank,
ctx.pipeline_parallel_rank,
ctx.pipeline_parallel_size,
ctx.tensor_parallel_size
)
with torch.no_grad():
A_grad = Matmul_AB_2D.apply(
output_grad, B,
ctx.summa_dim, ctx.A_shape,
ctx.row_rank, ctx.col_rank,
ctx.row_parallel_mode,
ctx.col_parallel_mode,
ctx.data_parallel_rank,
ctx.pipeline_parallel_rank,
ctx.pipeline_parallel_size,
ctx.tensor_parallel_size
)
B_grad = Matmul_ATB_2D.apply(
output_grad, A,
ctx.summa_dim, ctx.B_shape,
ctx.row_rank, ctx.col_rank,
ctx.row_parallel_mode,
ctx.col_parallel_mode,
ctx.data_parallel_rank,
ctx.pipeline_parallel_rank,
ctx.pipeline_parallel_size,
ctx.tensor_parallel_size
)
return A_grad, B_grad, None, None, None, None, None, None, None, None, None, None
@@ -247,6 +251,7 @@ class Matmul_ATB_2D(torch.autograd.Function):
"""Matrix multiplication for :math:`C = A^TB`
"""
@staticmethod
@custom_fwd(cast_inputs=torch.float16)
def forward(ctx: Any,
A: Tensor,
B: Tensor,
@@ -308,32 +313,33 @@ class Matmul_ATB_2D(torch.autograd.Function):
return out
@staticmethod
@custom_bwd
def backward(ctx: Any, output_grad: Tensor) -> Tuple[Tensor, ...]:
A, B = ctx.saved_tensors
A_grad = Matmul_ABT_2D.forward(
None,
B, output_grad,
ctx.summa_dim, ctx.A_shape,
ctx.row_rank, ctx.col_rank,
ctx.row_parallel_mode,
ctx.col_parallel_mode,
ctx.data_parallel_rank,
ctx.pipeline_parallel_rank,
ctx.pipeline_parallel_size,
ctx.tensor_parallel_size
)
B_grad = Matmul_AB_2D.forward(
None,
A, output_grad,
ctx.summa_dim, ctx.B_shape,
ctx.row_rank, ctx.col_rank,
ctx.row_parallel_mode,
ctx.col_parallel_mode,
ctx.data_parallel_rank,
ctx.pipeline_parallel_rank,
ctx.pipeline_parallel_size,
ctx.tensor_parallel_size
)
with torch.no_grad():
A_grad = Matmul_ABT_2D.apply(
B, output_grad,
ctx.summa_dim, ctx.A_shape,
ctx.row_rank, ctx.col_rank,
ctx.row_parallel_mode,
ctx.col_parallel_mode,
ctx.data_parallel_rank,
ctx.pipeline_parallel_rank,
ctx.pipeline_parallel_size,
ctx.tensor_parallel_size
)
B_grad = Matmul_AB_2D.apply(
A, output_grad,
ctx.summa_dim, ctx.B_shape,
ctx.row_rank, ctx.col_rank,
ctx.row_parallel_mode,
ctx.col_parallel_mode,
ctx.data_parallel_rank,
ctx.pipeline_parallel_rank,
ctx.pipeline_parallel_size,
ctx.tensor_parallel_size
)
return A_grad, B_grad, None, None, None, None, None, None, None, None, None, None
@@ -341,6 +347,7 @@ class Add_Bias_2D(torch.autograd.Function):
"""Matrix add bias: :math:`C = A + b`
"""
@staticmethod
@custom_fwd(cast_inputs=torch.float16)
def forward(ctx: Any,
input: Tensor,
bias: Tensor,
@@ -384,6 +391,7 @@ class Add_Bias_2D(torch.autograd.Function):
return output
@staticmethod
@custom_bwd
def backward(ctx: Any, output_grad: Tensor) -> Tuple[Tensor, ...]:
row_rank = ctx.row_rank
col_rank = ctx.col_rank
@@ -423,6 +431,7 @@ class Add_Bias_2D(torch.autograd.Function):
class _LayerNorm_2D(torch.autograd.Function):
@staticmethod
@custom_fwd(cast_inputs=torch.float32)
def forward(ctx: Any,
input: Tensor,
E_x: Tensor,
@@ -440,6 +449,7 @@ class _LayerNorm_2D(torch.autograd.Function):
return output
@staticmethod
@custom_bwd
def backward(ctx: Any, output_grad: Tensor) -> Tuple[Tensor, ...]:
row_parallel_mode = ctx.row_parallel_mode
col_parallel_mode = ctx.col_parallel_mode
@@ -492,6 +502,7 @@ class _LayerNorm_2D(torch.autograd.Function):
class _ViT_Split_Input_2D(torch.autograd.Function):
@staticmethod
@custom_fwd(cast_inputs=torch.float16)
def forward(ctx: Any,
inputs: Tensor,
batch_size: int,
@@ -509,6 +520,7 @@ class _ViT_Split_Input_2D(torch.autograd.Function):
return output
@staticmethod
@custom_bwd
def backward(ctx: Any, output_grad: Tensor) -> Tuple[Tensor, ...]:
# output_grad: [b/q, s, h/q]
# grads: [b, s, h/q]
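
Two things change in the 2D ops above: the forwards gain @custom_fwd(cast_inputs=...), the backwards gain @custom_bwd, and the backwards now compose the sibling ops via .apply inside torch.no_grad() instead of calling forward(None, ...) directly (custom_fwd records the autocast state on ctx, so a None ctx would no longer work, and no_grad keeps the recomputation out of the autograd graph). A toy sketch of the same pattern, with an illustrative Double op that is not from the diff:

import torch
from torch.cuda.amp import custom_bwd, custom_fwd

class Double(torch.autograd.Function):
    @staticmethod
    @custom_fwd(cast_inputs=torch.float16)
    def forward(ctx, x):
        return 2.0 * x

    @staticmethod
    @custom_bwd
    def backward(ctx, grad_out):
        # d(2x)/dx = 2, so the op can reuse itself on the incoming gradient;
        # .apply gives the decorated forward a real ctx, no_grad avoids a second graph
        with torch.no_grad():
            grad_in = Double.apply(grad_out)
        return grad_in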

View File

@@ -66,11 +66,10 @@ class CosineAnnealingWarmupLR(WarmupScheduler):
:type last_epoch: int, optional
"""
def __init__(self, optimizer, total_steps: int, warmup_steps: int = 0, eta_min: int = 0, last_epoch: int = -1,
**kwargs):
def __init__(self, optimizer, total_steps: int, warmup_steps: int = 0, eta_min: int = 0, last_epoch: int = -1):
base_scheduler = _CosineAnnealingLR(
optimizer, total_steps - warmup_steps, eta_min=eta_min)
super().__init__(optimizer, warmup_steps, base_scheduler, last_epoch=last_epoch)
optimizer, total_steps - warmup_steps, eta_min=eta_min, last_epoch=last_epoch)
super().__init__(optimizer, warmup_steps, base_scheduler)
@LR_SCHEDULERS.register_module
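
For reference, a usage sketch of the scheduler after this change; the import path and step counts are assumptions for illustration, not taken from the diff:

import torch
from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR  # assumed import path

model = torch.nn.Linear(8, 8)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
scheduler = CosineAnnealingWarmupLR(optimizer, total_steps=100, warmup_steps=10, eta_min=0.0)

for step in range(100):
    optimizer.step()
    scheduler.step()  # 10 steps of linear warmup, then cosine decay over the remaining 90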

View File

@@ -55,7 +55,7 @@ class DelayerScheduler(_LRScheduler):
class WarmupScheduler(_LRScheduler):
""" Starts with a linear warmup lr schedule until it reaches N epochs the applies a scheduler
""" Starts with a linear warmup lr schedule until it reaches N epochs the applies a scheduler
:param optimizer: Wrapped optimizer.
:type optimizer: torch.optim.Optimizer
@@ -66,11 +66,8 @@ class WarmupScheduler(_LRScheduler):
:param last_epoch: The index of last epoch, defaults to -1
:type last_epoch: int, optional
"""
def __init__(self, optimizer, warmup_epochs, after_scheduler, last_epoch=-1):
if warmup_epochs < 0:
raise ValueError(f'warmup_epochs must >= 0, got {warmup_epochs}')
self.warmup_epochs = warmup_epochs
self.warmup_epochs = int(warmup_epochs)
self.after_scheduler = after_scheduler
self.finished = False
super().__init__(optimizer, last_epoch)
@@ -79,14 +76,10 @@ class WarmupScheduler(_LRScheduler):
if self.last_epoch >= self.warmup_epochs:
if not self.finished:
self.after_scheduler.base_lrs = self.base_lrs
# reset lr to base_lr
for group, base_lr in zip(self.optimizer.param_groups, self.base_lrs):
group['lr'] = base_lr
self.finished = True
with _enable_get_lr_call(self.after_scheduler):
return self.after_scheduler.get_lr()
return self.after_scheduler.get_lr()
return [(self.last_epoch + 1) / (self.warmup_epochs + 1) * lr for lr in self.base_lrs]
return [(self.last_epoch + 1) / self.warmup_epochs * lr for lr in self.base_lrs]
def step(self, epoch=None):
if self.finished:
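
To make the warmup change above concrete: the old formula (last_epoch + 1) / (warmup_epochs + 1) * lr stops short of the base lr, while the new (last_epoch + 1) / warmup_epochs * lr reaches it exactly on the last warmup epoch. With illustrative values:

base_lr, warmup_epochs = 0.1, 5
old = [(e + 1) / (warmup_epochs + 1) * base_lr for e in range(warmup_epochs)]
new = [(e + 1) / warmup_epochs * base_lr for e in range(warmup_epochs)]
print(old)  # approx [0.017, 0.033, 0.05, 0.067, 0.083] -- never reaches 0.1
print(new)  # [0.02, 0.04, 0.06, 0.08, 0.1] -- hits the base lr, then after_scheduler takes over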

View File

@@ -136,14 +136,14 @@ class Trainer:
self.call_hooks('before_train_epoch')
self._timer.start('train-epoch')
for _ in progress:
self._cur_step += 1
self.call_hooks('before_train_iter')
self._timer.start('train-step')
logits, label, loss = self._engine.step()
self._timer.stop('train-step', keep_in_history=True)
self.call_hooks('after_train_iter', output=(logits, label, loss))
self._cur_step += 1
if self.exceed_max_step():
# stop when max iter is reached
break
@@ -235,8 +235,6 @@ class Trainer:
last_epoch = self._cur_epoch
for epoch in range(last_epoch, self._max_epochs):
self._cur_epoch += 1
# train for one epoch
self._train_epoch(epoch)
@@ -244,6 +242,8 @@ class Trainer:
if should_test and epoch % test_interval == 0:
self._eval(epoch, return_loss=True)
self._cur_epoch += 1
# check for termination
if self.exceed_max_step():
self._logger.info(

View File

@@ -27,7 +27,7 @@ def sync_model_param_in_dp(model):
:param model: A pyTorch nn.model on whose parameters you check the consistency
'''
if gpc.is_initialized(ParallelMode.DATA) and gpc.get_world_size(ParallelMode.DATA) > 2:
if gpc.is_initialized(ParallelMode.DATA) and gpc.get_world_size(ParallelMode.DATA) > 1:
for param in model.parameters():
ranks = gpc.get_ranks_in_group(ParallelMode.DATA)
dist.broadcast(param, src=ranks[0], group=gpc.get_group(ParallelMode.DATA))