mirror of https://github.com/hpcaitech/ColossalAI
Fixed compatibility bugs between torch AMP and tensor parallelism, and made some minor fixes
parent c8cb9f9e34
commit af88570f4b
@@ -15,7 +15,7 @@ from colossalai.core import global_context as gpc
 from colossalai.nn import (ZeroRedundancyOptimizer_Level_2,
                            ZeroRedundancyOptimizer_Level_3)
 from colossalai.nn.optimizer._utils import clip_grad_norm_fp32
-from ._utils import convert_to_fp16
+from ._utils import convert_to_fp16, convert_to_fp32
 from ._base_schedule import BaseSchedule
 from ..amp import AMP_TYPE, GradScaler
@@ -43,12 +43,6 @@ class NoPipelineSchedule(BaseSchedule):
         assert amp_type is None or isinstance(amp_type, AMP_TYPE), \
             'unrecognised value for argument fp16, it can only be None, torch or apex'

-        # LSG: check compatibility
-        # LSG: torch.cuda.amp and apex.amp cannot be used for tensor parallel
-        if gpc.is_initialized(ParallelMode.TENSOR) and gpc.get_world_size(
-                ParallelMode.TENSOR) > 1:
-            assert amp_type != AMP_TYPE.TORCH and amp_type != AMP_TYPE.APEX, \
-                'You can only AMP_TYPE.PARALLEL for tensor parallel training'
         self.use_zero_level_2_3 = False

         if amp_type is not None:
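Note: the guard that forbade AMP_TYPE.TORCH and AMP_TYPE.APEX under tensor parallel training is removed here; the 2D parallel ops further down are made autocast-aware via custom_fwd/custom_bwd, which is what makes torch AMP usable together with tensor parallelism.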
@@ -121,18 +115,6 @@ class NoPipelineSchedule(BaseSchedule):
         data, label = self.load_batch()
         loss = None

-        # LSG: leave for debug, make sure dataloader is deterministic
-        # if forward_only:
-        #     img = data[0]
-        #     rank = gpc.get_local_rank(ParallelMode.DATA)
-        #     world_size = gpc.get_world_size(ParallelMode.DATA)
-        #     group = gpc.get_group(ParallelMode.DATA)
-        #     input_list = [img.clone() for _ in range(world_size)]
-        #     output_list = [torch.empty_like(img) for _ in range(world_size)]
-        #     output_list[rank] = img.clone()
-        #     dist.all_to_all(output_tensor_list=output_list, input_tensor_list=input_list, group=group)
-        #     assert torch.equal(output_list[0], output_list[1])  # and torch.equal(output_list[1], output_list[2])
-
         # forward
         if self.fp16 and self.amp_type == AMP_TYPE.TORCH:
             with torch_amp.autocast():
@@ -146,6 +128,10 @@ class NoPipelineSchedule(BaseSchedule):
                 data = convert_to_fp16(data)

         output = self.model(*data)

+        if self.use_zero_level_2_3 or self.amp_type == AMP_TYPE.PARALLEL:
+            output = convert_to_fp32(output)
+
         if not isinstance(output, (tuple, list)):
             output = (output,)
         if return_loss:
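The added branch casts the model output back to fp32 before the loss is computed: under ZeRO level 2/3 or AMP_TYPE.PARALLEL the forward pass produces fp16 tensors, and the new convert_to_fp32 helper (defined in the next hunk) restores full precision for the criterion.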
@@ -14,3 +14,14 @@ def convert_to_fp16(data: Union[Tensor, List[Tensor]]):
     else:
         raise TypeError(f"Expected argument 'data' to be a Tensor or a list/tuple of Tensor, but got {type(data)}")
     return ret
+
+
+def convert_to_fp32(data: Union[Tensor, List[Tensor]]):
+    if isinstance(data, Tensor):
+        ret = data.float()
+    elif isinstance(data, (list, tuple)):
+        ret = [val.float() for val in data]
+    else:
+        raise TypeError(f"Expected argument 'data' to be a Tensor or a list/tuple of Tensor, but got {type(data)}")
+    return ret
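For reference, a minimal sketch of what the new helper does at runtime; the function body mirrors the hunk above, and the tensors are made up for illustration:

    import torch
    from torch import Tensor
    from typing import List, Union

    def convert_to_fp32(data: Union[Tensor, List[Tensor]]):
        # same logic as the helper added in this commit
        if isinstance(data, Tensor):
            return data.float()
        if isinstance(data, (list, tuple)):
            return [val.float() for val in data]
        raise TypeError(f"Expected argument 'data' to be a Tensor or a list/tuple of Tensor, but got {type(data)}")

    half_logits = torch.randn(4, 10, dtype=torch.float16)   # stand-in for an fp16 model output
    full_logits = convert_to_fp32(half_logits)
    assert full_logits.dtype == torch.float32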
@@ -7,6 +7,7 @@ from torch import Tensor
 from colossalai.context.parallel_mode import ParallelMode
 from colossalai.core import global_context as gpc
 from colossalai.utils import get_current_device
+from torch.cuda.amp import custom_bwd, custom_fwd


 def matmul_2d(a,
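The two decorators imported here come from torch.cuda.amp and are what let the hand-written autograd Functions below cooperate with torch AMP: custom_fwd(cast_inputs=...) casts floating-point CUDA tensor arguments to the requested dtype whenever autocast is active, and custom_bwd runs the matching backward under the same autocast state as the forward. A toy Function (not part of this commit) sketches the pattern:

    import torch
    from torch.cuda.amp import custom_bwd, custom_fwd

    class ScaleByTwo(torch.autograd.Function):
        """Illustrative only; the real 2D ops below follow the same decorator layout."""

        @staticmethod
        @custom_fwd(cast_inputs=torch.float16)
        def forward(ctx, x):
            # under autocast, x arrives here already cast to fp16
            return x * 2

        @staticmethod
        @custom_bwd
        def backward(ctx, grad_output):
            # executed with the same autocast state that the forward saw
            return grad_output * 2

    if torch.cuda.is_available():
        x = torch.randn(4, device='cuda')
        with torch.cuda.amp.autocast():
            y = ScaleByTwo.apply(x)
        print(y.dtype)  # torch.float16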
@@ -60,6 +61,7 @@ class Matmul_AB_2D(torch.autograd.Function):
     """Matrix multiplication for :math:`C = AB`
     """
     @staticmethod
+    @custom_fwd(cast_inputs=torch.float16)
     def forward(ctx: Any,
                 A: Tensor,
                 B: Tensor,
@@ -120,10 +122,11 @@ class Matmul_AB_2D(torch.autograd.Function):
         return out

     @staticmethod
+    @custom_bwd
     def backward(ctx: Any, output_grad: Tensor) -> Tuple[Tensor, ...]:
         A, B = ctx.saved_tensors
-        A_grad = Matmul_ABT_2D.forward(
-            None,
-            output_grad, B,
-            ctx.summa_dim, ctx.A_shape,
-            ctx.row_rank, ctx.col_rank,
+        with torch.no_grad():
+            A_grad = Matmul_ABT_2D.apply(
+                output_grad, B,
+                ctx.summa_dim, ctx.A_shape,
+                ctx.row_rank, ctx.col_rank,
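Note on the backward rewrite: the nested matmuls are now invoked through .apply(...) instead of calling forward(None, ...) directly, so they go through the regular autograd Function entry point (a real ctx is supplied, which plays well with the new @custom_fwd decorator), and they are wrapped in torch.no_grad() so that these backward-pass computations are not recorded into a second autograd graph. The same pattern is repeated for the remaining 2D ops below.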
@@ -134,8 +137,7 @@ class Matmul_AB_2D(torch.autograd.Function):
                 ctx.pipeline_parallel_size,
                 ctx.tensor_parallel_size
             )
-        B_grad = Matmul_ATB_2D.forward(
-            None,
-            A, output_grad,
-            ctx.summa_dim, ctx.B_shape,
-            ctx.row_rank, ctx.col_rank,
+            B_grad = Matmul_ATB_2D.apply(
+                A, output_grad,
+                ctx.summa_dim, ctx.B_shape,
+                ctx.row_rank, ctx.col_rank,
@@ -153,6 +155,7 @@ class Matmul_ABT_2D(torch.autograd.Function):
     """Matrix multiplication for :math:`C = AB^T`
     """
     @staticmethod
+    @custom_fwd(cast_inputs=torch.float16)
     def forward(ctx: Any,
                 A: Tensor,
                 B: Tensor,
@@ -214,10 +217,12 @@ class Matmul_ABT_2D(torch.autograd.Function):
         return out

     @staticmethod
+    @custom_bwd
     def backward(ctx: Any, output_grad: Tensor) -> Tuple[Tensor, ...]:
         A, B = ctx.saved_tensors
-        A_grad = Matmul_AB_2D.forward(
-            None,
-            output_grad, B,
-            ctx.summa_dim, ctx.A_shape,
-            ctx.row_rank, ctx.col_rank,
+        with torch.no_grad():
+            A_grad = Matmul_AB_2D.apply(
+                output_grad, B,
+                ctx.summa_dim, ctx.A_shape,
+                ctx.row_rank, ctx.col_rank,
@@ -228,8 +233,7 @@ class Matmul_ABT_2D(torch.autograd.Function):
                 ctx.pipeline_parallel_size,
                 ctx.tensor_parallel_size
             )
-        B_grad = Matmul_ATB_2D.forward(
-            None,
-            output_grad, A,
-            ctx.summa_dim, ctx.B_shape,
-            ctx.row_rank, ctx.col_rank,
+            B_grad = Matmul_ATB_2D.apply(
+                output_grad, A,
+                ctx.summa_dim, ctx.B_shape,
+                ctx.row_rank, ctx.col_rank,
@@ -247,6 +251,7 @@ class Matmul_ATB_2D(torch.autograd.Function):
     """Matrix multiplication for :math:`C = A^TB`
     """
     @staticmethod
+    @custom_fwd(cast_inputs=torch.float16)
     def forward(ctx: Any,
                 A: Tensor,
                 B: Tensor,
@@ -308,10 +313,12 @@ class Matmul_ATB_2D(torch.autograd.Function):
         return out

     @staticmethod
+    @custom_bwd
     def backward(ctx: Any, output_grad: Tensor) -> Tuple[Tensor, ...]:
         A, B = ctx.saved_tensors
-        A_grad = Matmul_ABT_2D.forward(
-            None,
-            B, output_grad,
-            ctx.summa_dim, ctx.A_shape,
-            ctx.row_rank, ctx.col_rank,
+        with torch.no_grad():
+            A_grad = Matmul_ABT_2D.apply(
+                B, output_grad,
+                ctx.summa_dim, ctx.A_shape,
+                ctx.row_rank, ctx.col_rank,
@@ -322,8 +329,7 @@ class Matmul_ATB_2D(torch.autograd.Function):
                 ctx.pipeline_parallel_size,
                 ctx.tensor_parallel_size
             )
-        B_grad = Matmul_AB_2D.forward(
-            None,
-            A, output_grad,
-            ctx.summa_dim, ctx.B_shape,
-            ctx.row_rank, ctx.col_rank,
+            B_grad = Matmul_AB_2D.apply(
+                A, output_grad,
+                ctx.summa_dim, ctx.B_shape,
+                ctx.row_rank, ctx.col_rank,
@@ -341,6 +347,7 @@ class Add_Bias_2D(torch.autograd.Function):
     """Matrix add bias: :math:`C = A + b`
     """
     @staticmethod
+    @custom_fwd(cast_inputs=torch.float16)
     def forward(ctx: Any,
                 input: Tensor,
                 bias: Tensor,
@@ -384,6 +391,7 @@ class Add_Bias_2D(torch.autograd.Function):
         return output

     @staticmethod
+    @custom_bwd
     def backward(ctx: Any, output_grad: Tensor) -> Tuple[Tensor, ...]:
         row_rank = ctx.row_rank
         col_rank = ctx.col_rank
@@ -423,6 +431,7 @@ class _LayerNorm_2D(torch.autograd.Function):
 class _LayerNorm_2D(torch.autograd.Function):

     @staticmethod
+    @custom_fwd(cast_inputs=torch.float32)
     def forward(ctx: Any,
                 input: Tensor,
                 E_x: Tensor,
@@ -440,6 +449,7 @@ class _LayerNorm_2D(torch.autograd.Function):
         return output

     @staticmethod
+    @custom_bwd
     def backward(ctx: Any, output_grad: Tensor) -> Tuple[Tensor, ...]:
         row_parallel_mode = ctx.row_parallel_mode
         col_parallel_mode = ctx.col_parallel_mode
@@ -492,6 +502,7 @@ class _ViT_Split_Input_2D(torch.autograd.Function):
 class _ViT_Split_Input_2D(torch.autograd.Function):

     @staticmethod
+    @custom_fwd(cast_inputs=torch.float16)
     def forward(ctx: Any,
                 inputs: Tensor,
                 batch_size: int,
@@ -509,6 +520,7 @@ class _ViT_Split_Input_2D(torch.autograd.Function):
         return output

     @staticmethod
+    @custom_bwd
     def backward(ctx: Any, output_grad: Tensor) -> Tuple[Tensor, ...]:
         # output_grad: [b/q, s, h/q]
         # grads: [b, s, h/q]
@@ -66,11 +66,10 @@ class CosineAnnealingWarmupLR(WarmupScheduler):
     :type last_epoch: int, optional
     """

-    def __init__(self, optimizer, total_steps: int, warmup_steps: int = 0, eta_min: int = 0, last_epoch: int = -1,
-                 **kwargs):
+    def __init__(self, optimizer, total_steps: int, warmup_steps: int = 0, eta_min: int = 0, last_epoch: int = -1):
         base_scheduler = _CosineAnnealingLR(
-            optimizer, total_steps - warmup_steps, eta_min=eta_min)
-        super().__init__(optimizer, warmup_steps, base_scheduler, last_epoch=last_epoch)
+            optimizer, total_steps - warmup_steps, eta_min=eta_min, last_epoch=last_epoch)
+        super().__init__(optimizer, warmup_steps, base_scheduler)


 @LR_SCHEDULERS.register_module
@@ -66,11 +66,8 @@ class WarmupScheduler(_LRScheduler):
     :param last_epoch: The index of last epoch, defaults to -1
     :type last_epoch: int, optional
     """

     def __init__(self, optimizer, warmup_epochs, after_scheduler, last_epoch=-1):
-        if warmup_epochs < 0:
-            raise ValueError(f'warmup_epochs must >= 0, got {warmup_epochs}')
-        self.warmup_epochs = warmup_epochs
+        self.warmup_epochs = int(warmup_epochs)
         self.after_scheduler = after_scheduler
         self.finished = False
         super().__init__(optimizer, last_epoch)
@@ -79,14 +76,10 @@ class WarmupScheduler(_LRScheduler):
         if self.last_epoch >= self.warmup_epochs:
             if not self.finished:
                 self.after_scheduler.base_lrs = self.base_lrs
-                # reset lr to base_lr
-                for group, base_lr in zip(self.optimizer.param_groups, self.base_lrs):
-                    group['lr'] = base_lr
                 self.finished = True
-            with _enable_get_lr_call(self.after_scheduler):
-                return self.after_scheduler.get_lr()
+            return self.after_scheduler.get_lr()

-        return [(self.last_epoch + 1) / (self.warmup_epochs + 1) * lr for lr in self.base_lrs]
+        return [(self.last_epoch + 1) / self.warmup_epochs * lr for lr in self.base_lrs]

     def step(self, epoch=None):
         if self.finished:
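A behavioural note on this hunk: the old ramp peaked at warmup_epochs / (warmup_epochs + 1) of the base LR, which is why the deleted block reset the LR to base_lr by hand once warmup finished; the new formula reaches the base LR exactly on the last warmup epoch, so the manual reset is no longer needed. A small numeric sketch with made-up values:

    base_lr, warmup_epochs = 0.1, 4
    old_ramp = [(epoch + 1) / (warmup_epochs + 1) * base_lr for epoch in range(warmup_epochs)]
    new_ramp = [(epoch + 1) / warmup_epochs * base_lr for epoch in range(warmup_epochs)]
    print(old_ramp)  # ~[0.02, 0.04, 0.06, 0.08]: tops out below base_lr
    print(new_ramp)  # ~[0.025, 0.05, 0.075, 0.1]: hits base_lr on the last warmup epoch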
@@ -136,14 +136,14 @@ class Trainer:
             self.call_hooks('before_train_epoch')
             self._timer.start('train-epoch')
             for _ in progress:
-                self._cur_step += 1
-
                 self.call_hooks('before_train_iter')
                 self._timer.start('train-step')
                 logits, label, loss = self._engine.step()
                 self._timer.stop('train-step', keep_in_history=True)
                 self.call_hooks('after_train_iter', output=(logits, label, loss))

+                self._cur_step += 1
+
                 if self.exceed_max_step():
                     # stop when max iter is reached
                     break
@@ -235,8 +235,6 @@ class Trainer:
         last_epoch = self._cur_epoch

         for epoch in range(last_epoch, self._max_epochs):
-            self._cur_epoch += 1
-
             # train for one epoch
             self._train_epoch(epoch)

@@ -244,6 +242,8 @@ class Trainer:
             if should_test and epoch % test_interval == 0:
                 self._eval(epoch, return_loss=True)

+            self._cur_epoch += 1
+
             # check for termination
             if self.exceed_max_step():
                 self._logger.info(
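In the trainer, the step and epoch counters now advance after an iteration or epoch has completed rather than before it starts, so exceed_max_step() and anything else reading the counters sees the number of finished steps and epochs.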
@@ -27,7 +27,7 @@ def sync_model_param_in_dp(model):
     :param model: A pyTorch nn.model on whose parameters you check the consistency
     '''

-    if gpc.is_initialized(ParallelMode.DATA) and gpc.get_world_size(ParallelMode.DATA) > 2:
+    if gpc.is_initialized(ParallelMode.DATA) and gpc.get_world_size(ParallelMode.DATA) > 1:
         for param in model.parameters():
             ranks = gpc.get_ranks_in_group(ParallelMode.DATA)
             dist.broadcast(param, src=ranks[0], group=gpc.get_group(ParallelMode.DATA))
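The data-parallel guard changes from world size > 2 to > 1, so parameters are broadcast from the first rank whenever more than one data-parallel rank exists; the old condition skipped synchronisation entirely for the common two-GPU data-parallel setup.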