Mirror of https://github.com/hpcaitech/ColossalAI
Commit by Hongxin Liu, committed via GitHub (1 year ago)
113 changed files with 627 additions and 631 deletions
@@ -1,4 +1,4 @@
-from .model import ModelWrapper
+from .model import AMPModelMixin, ModelWrapper
 from .optimizer import OptimizerWrapper
 
-__all__ = ['OptimizerWrapper', 'ModelWrapper']
+__all__ = ['OptimizerWrapper', 'ModelWrapper', 'AMPModelMixin']

@@ -1,7 +1,7 @@
 from colossalai.context.parallel_mode import ParallelMode
 from colossalai.core import global_context as gpc
-from colossalai.registry import GRADIENT_HANDLER
+from colossalai.legacy.registry import GRADIENT_HANDLER
 
 from ...context.parallel_mode import ParallelMode
 from ._base_gradient_handler import BaseGradientHandler
 from .utils import bucket_allreduce

@@ -1,9 +1,9 @@
 from colossalai.context.moe_context import MOE_CONTEXT
 from colossalai.context.parallel_mode import ParallelMode
 from colossalai.core import global_context as gpc
-from colossalai.registry import GRADIENT_HANDLER
+from colossalai.legacy.registry import GRADIENT_HANDLER
 from colossalai.utils.moe import get_moe_epsize_param_dict
 
 from ...context.parallel_mode import ParallelMode
 from ._base_gradient_handler import BaseGradientHandler
 from .utils import bucket_allreduce

@@ -1,7 +1,7 @@
 from colossalai.context.parallel_mode import ParallelMode
 from colossalai.core import global_context as gpc
-from colossalai.registry import GRADIENT_HANDLER
+from colossalai.legacy.registry import GRADIENT_HANDLER
 
 from ...context.parallel_mode import ParallelMode
 from ._base_gradient_handler import BaseGradientHandler
 from .utils import bucket_allreduce

@@ -1,4 +1,4 @@
-from colossalai.registry import GRADIENT_HANDLER
+from colossalai.legacy.registry import GRADIENT_HANDLER
 
 from ._base_gradient_handler import BaseGradientHandler
 

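The recurring change in the gradient-handler hunks above is a one-line import move: the GRADIENT_HANDLER registry now comes from colossalai.legacy.registry instead of colossalai.registry. Downstream code that registers a custom handler needs the same one-line update. A minimal sketch, assuming the usual register_module decorator and a handle_gradient hook on BaseGradientHandler (the full import path of BaseGradientHandler and the hook name are assumptions, not shown in this diff):

# Hypothetical downstream module: only the registry import path comes from this diff.
from colossalai.legacy.registry import GRADIENT_HANDLER  # was: from colossalai.registry import GRADIENT_HANDLER
from colossalai.legacy.engine.gradient_handler._base_gradient_handler import BaseGradientHandler  # assumed path


@GRADIENT_HANDLER.register_module
class NoOpGradientHandler(BaseGradientHandler):
    """Do-nothing handler, shown only to illustrate registration under the new path."""

    def handle_gradient(self):
        # A real handler would all-reduce gradients here (e.g. with bucket_allreduce).
        pass
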
@@ -1,14 +1,13 @@
-from typing import Union, List, Any
+from typing import Any, List, Union
 
 import torch
 from torch.utils.data import DataLoader
 from tqdm import tqdm
 
-from colossalai.engine import Engine
+from colossalai.legacy.engine import Engine
+from colossalai.legacy.trainer.hooks import BaseHook
 from colossalai.logging import DistributedLogger
-from colossalai.utils import MultiTimer
-from colossalai.utils import is_dp_rank_0, is_tp_rank_0, is_no_pp_or_last_stage
-from colossalai.trainer.hooks import BaseHook
+from colossalai.utils import MultiTimer, is_dp_rank_0, is_no_pp_or_last_stage, is_tp_rank_0
 
 
 class Trainer:

@@ -1,7 +1,12 @@
 from ._base_hook import BaseHook
 from ._checkpoint_hook import SaveCheckpointHook
-from ._log_hook import (LogMemoryByEpochHook, LogMetricByEpochHook, LogMetricByStepHook, LogTimingByEpochHook,
-                        TensorboardHook)
+from ._log_hook import (
+    LogMemoryByEpochHook,
+    LogMetricByEpochHook,
+    LogMetricByStepHook,
+    LogTimingByEpochHook,
+    TensorboardHook,
+)
 from ._lr_scheduler_hook import LRSchedulerHook
 from ._metric_hook import AccuracyHook, LossHook, MetricHook, ThroughputHook
 

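As the updated imports show, the Trainer and its hooks now live under colossalai.legacy.trainer. User code that imported them from colossalai.trainer would change along these lines (a sketch; the hook names are taken from the hunk above, while Trainer being re-exported from the package __init__ is an assumption):

# Before this commit (per the removed lines above):
# from colossalai.trainer import Trainer
# from colossalai.trainer.hooks import LossHook, LRSchedulerHook

# After this commit:
from colossalai.legacy.trainer import Trainer  # assumed re-export of the Trainer class defined in this package
from colossalai.legacy.trainer.hooks import LossHook, LRSchedulerHook
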
@@ -1,11 +1,12 @@
 #!/usr/bin/env python
 # -*- encoding: utf-8 -*-
 import torch
-from colossalai.logging import get_dist_logger
 
-from colossalai.registry import HOOKS
-from colossalai.trainer.hooks import BaseHook
+from colossalai.legacy.registry import HOOKS
+from colossalai.legacy.trainer.hooks import BaseHook
+from colossalai.logging import get_dist_logger
 from colossalai.utils.checkpointing import save_checkpoint
 
 from ._lr_scheduler_hook import LRSchedulerHook
 
+

@@ -1,6 +1,7 @@
-from colossalai.registry import HOOKS
 from torch import Tensor
 
+from colossalai.legacy.registry import HOOKS
+
 from ._metric_hook import LearningRateMetric, MetricHook
 
 

@@ -1,105 +1,106 @@
 import torch
 import torch.distributed as dist
-from colossalai.context import ParallelMode
-from colossalai.core import global_context as gpc
-from colossalai.registry import LOSSES
 from torch.cuda.amp import custom_bwd, custom_fwd
 from torch.nn.modules.loss import _Loss
 
+from colossalai.context import ParallelMode
+from colossalai.core import global_context as gpc
+from colossalai.legacy.registry import LOSSES
+
 
 class _VocabParallelCrossEntropy1D(torch.autograd.Function):
 
     @staticmethod
     @custom_fwd(cast_inputs=torch.float32)
     def forward(ctx, vocab_parallel_logits, targets, process_group):
         if process_group is None:
             process_group = gpc.get_group(ParallelMode.PARALLEL_1D)
 
         # Maximum value along vocab dimension across all GPUs.
         logits_max = torch.max(vocab_parallel_logits, dim=-1)[0]
         torch.distributed.all_reduce(logits_max, op=torch.distributed.ReduceOp.MAX, group=process_group)
         # Subtract the maximum value.
         vocab_parallel_logits.sub_(logits_max.unsqueeze(dim=-1))
 
         # Get the partition's vocab indices
         partition_vocab_size = vocab_parallel_logits.size()[-1]
         rank = dist.get_rank(process_group)
         vocab_start_index = partition_vocab_size * rank
         vocab_end_index = vocab_start_index + partition_vocab_size
 
         # Create a mask of valid vocab ids (1 means it needs to be masked).
         target_mask = (targets < vocab_start_index) | (targets >= vocab_end_index)
         masked_target = targets.clone() - vocab_start_index
         masked_target[target_mask] = 0
 
         # Get predicted-logits = logits[target].
         # For Simplicity, we convert logits to a 2-D tensor with size
         # [*, partition-vocab-size] and target to a 1-D tensor of size [*].
         logits_2d = vocab_parallel_logits.view(-1, partition_vocab_size)
         masked_target_1d = masked_target.view(-1)
         arange_1d = torch.arange(start=0, end=logits_2d.size()[0], device=logits_2d.device)
         predicted_logits_1d = logits_2d[arange_1d, masked_target_1d]
         predicted_logits_1d = predicted_logits_1d.clone().contiguous()
         predicted_logits = predicted_logits_1d.view_as(targets)
         predicted_logits[target_mask] = 0.0
         # All reduce is needed to get the chunks from other GPUs.
         torch.distributed.all_reduce(predicted_logits, op=torch.distributed.ReduceOp.SUM, group=process_group)
 
         # Sum of exponential of logits along vocab dimension across all GPUs.
         exp_logits = torch.exp(vocab_parallel_logits)
         sum_exp_logits = exp_logits.sum(dim=-1)
         torch.distributed.all_reduce(sum_exp_logits, op=torch.distributed.ReduceOp.SUM, group=process_group)
 
         # Loss = log(sum(exp(logits))) - predicted-logit.
         loss = torch.log(sum_exp_logits) - predicted_logits
         # Store softmax, target-mask and masked-target for backward pass.
         exp_logits.div_(sum_exp_logits.unsqueeze(dim=-1))
         ctx.save_for_backward(exp_logits, target_mask, masked_target_1d)
         return loss
 
     @staticmethod
     @custom_bwd
     def backward(ctx, grad_output):
 
         # Retrieve tensors from the forward path.
         softmax, target_mask, masked_target_1d = ctx.saved_tensors
 
         # All the inputs have softmax as their gradient.
         grad_input = softmax
         # For simplicity, work with the 2D gradient.
         partition_vocab_size = softmax.size()[-1]
         grad_2d = grad_input.view(-1, partition_vocab_size)
 
         # Add the gradient from matching classes.
         arange_1d = torch.arange(start=0, end=grad_2d.size()[0], device=grad_2d.device)
         grad_2d[arange_1d, masked_target_1d] -= (1.0 - target_mask.view(-1).float())
 
         # Finally elementwise multiplication with the output gradients.
         grad_input.mul_(grad_output.unsqueeze(dim=-1))
 
         return grad_input, None, None
 
 
 @LOSSES.register_module
 class VocabParallelCrossEntropyLoss1D(_Loss):
     """Vocab parallel cross entropy loss for 1D parallelism.
 
     Args:
         reduction (bool, optional): whether to average the loss, defaults to True.
     """
 
     def __init__(self, reduction=True):
         super().__init__()
         self.reduction_mean = reduction
 
     def forward(self, logits, targets, process_group=None):
         """Calculate loss between logits and targets.
 
         Args:
             logits (:class:`torch.tensor`): Predicted unnormalized scores (often referred to as logits).
             targets (:class:`torch.tensor`): Ground truth class indices or class probabilities.
         """
         loss = _VocabParallelCrossEntropy1D.apply(logits, targets, process_group)
         if self.reduction_mean:
             loss = loss.mean()
         return loss

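This hunk only reorders the imports and switches the LOSSES registry to colossalai.legacy.registry; the numerics of _VocabParallelCrossEntropy1D are untouched. For readers less familiar with the trick, here is a self-contained, single-process sketch (toy sizes, no torch.distributed, nothing imported from the file above) that recomputes the three quantities the kernel all-reduces -- the global max logit, the target logit, and the global sum of exponentials -- from vocab shards and checks the result against F.cross_entropy:

# Pure-PyTorch illustration; shard sizes and tensors are invented toy values.
import torch
import torch.nn.functional as F

torch.manual_seed(0)
batch, vocab, world_size = 4, 12, 3              # assumed toy dimensions
logits = torch.randn(batch, vocab)
targets = torch.randint(0, vocab, (batch,))

shards = logits.chunk(world_size, dim=-1)        # each "rank" would hold one shard
shard_size = vocab // world_size

# 1) all_reduce(MAX): global max over the vocab dimension, for numerical stability.
global_max = torch.stack([s.max(dim=-1)[0] for s in shards]).max(dim=0)[0]

# 2) all_reduce(SUM) of the target logit: only the shard owning the target contributes,
#    mirroring the target_mask / masked_target trick in the kernel above.
predicted = torch.zeros(batch)
for rank, shard in enumerate(shards):
    start, end = rank * shard_size, (rank + 1) * shard_size
    owns_target = (targets >= start) & (targets < end)
    local = shard - global_max.unsqueeze(-1)
    local_idx = (targets - start).clamp(0, shard_size - 1)   # clamped like masked_target[...] = 0
    predicted += torch.where(owns_target,
                             local[torch.arange(batch), local_idx],
                             torch.zeros(batch))

# 3) all_reduce(SUM) of exp(logits): the global partition function.
sum_exp = sum(torch.exp(s - global_max.unsqueeze(-1)).sum(dim=-1) for s in shards)

# Loss = log(sum(exp(logits))) - predicted-logit, then mean over the batch.
loss_sharded = (torch.log(sum_exp) - predicted).mean()
loss_reference = F.cross_entropy(logits, targets)
assert torch.allclose(loss_sharded, loss_reference, atol=1e-5)
print(loss_sharded.item(), loss_reference.item())
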
@@ -1,80 +1,81 @@
 import torch.nn as nn
-from colossalai.registry import LOSSES
 from torch.nn.modules.loss import _Loss
-from colossalai.context.moe_context import MOE_CONTEXT
+
+from colossalai.context.moe_context import MOE_CONTEXT
+from colossalai.legacy.registry import LOSSES
 
 
 @LOSSES.register_module
 class MoeCrossEntropyLoss(_Loss):
     r"""torch.nn.CrossEntropyLoss added with auxiliary loss.
 
     Args:
         input (:class:`torch.tensor`): Predicted unnormalized scores (often referred to as logits).
         target (:class:`torch.tensor`): Ground truth class indices or class probabilities.
         aux_weight (float, optional): Weight of auxiliary loss in total loss.Defaults 0.01.
 
     The ``args`` and ``kwargs`` should include parameters below:
     ::
 
         weight (Tensor, optional)
         size_average (bool, optional)
         ignore_index (int, optional)
         reduce (bool, optional)
         reduction (str, optional)
         label_smoothing (float, optional)
 
     More details about ``args``, ``kwargs`` and ``torch.nn.functional.cross_entropy`` could be found in
     `Cross_entropy <https://pytorch.org/docs/stable/generated/torch.nn.functional.cross_entropy.html#torch.nn.functional.cross_entropy>`_.
     """
 
     def __init__(self, aux_weight: float = 0.01, *args, **kwargs):
         super().__init__()
         self.loss = nn.CrossEntropyLoss(*args, **kwargs)
         self.aux_weight = aux_weight
 
     def forward(self, *args):
         """
         The ``args`` should at least include parameters below:
         ::
 
             input (:class:`torch.tensor`): Predicted unnormalized scores (often referred to as logits).
             target (:class:`torch.tensor`): Ground truth class indices or class probabilities.
 
         More details about ``args``, ``kwargs`` and ``torch.nn.functional.cross_entropy`` could be found in
         `Cross_entropy <https://pytorch.org/docs/stable/generated/torch.nn.functional.cross_entropy.html#torch.nn.functional.cross_entropy>`_.
         """
         main_loss = self.loss(*args)
         aux_loss = MOE_CONTEXT.get_loss()
         return main_loss + self.aux_weight * aux_loss
 
 
 @LOSSES.register_module
 class MoeLoss(_Loss):
     """A wrapper class for any loss module to add with auxiliary loss.
 
     Args:
         aux_weight (float): Weight of auxiliary loss in total loss.
         loss_fn (``Callable``): Loss function.
         args (list): Args in loss function.
         kwargs (dict): Kwargs in loss function
     """
 
     def __init__(self, aux_weight: float, loss_fn, *args, **kwargs):
         super().__init__()
         self.loss_fn = loss_fn(*args, **kwargs)
         self.aux_weight = aux_weight
 
     def forward(self, *args, **kwargs):
         """
         The ``args`` and ``kwargs`` should at least include parameters below:
         ::
 
             input (:class:`torch.tensor`): Predicted unnormalized scores (often referred to as logits).
             target (:class:`torch.tensor`): Ground truth class indices or class probabilities.
 
         Note:
             The ``args`` and ``kwargs`` may include different parameters varying with different loss function.
         """
         main_loss = self.loss_fn(*args, **kwargs)
         aux_loss = MOE_CONTEXT.get_loss()
         return main_loss + self.aux_weight * aux_loss

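Here too only the imports change; MoeCrossEntropyLoss and MoeLoss still return main_loss + aux_weight * aux_loss. A stand-alone sketch of that wrapper pattern, with the MoE context replaced by a plain stub so it runs without ColossalAI (the stub, ToyMoeLoss, and the toy tensors are invented for illustration; only the combination formula comes from the file above):

# Minimal sketch of the "main loss + weighted auxiliary loss" wrapper.
import torch
import torch.nn as nn


class _StubMoeContext:
    """Stands in for MOE_CONTEXT; a real MoE router would accumulate this loss."""

    def __init__(self):
        self.aux_loss = torch.tensor(0.0)

    def get_loss(self):
        return self.aux_loss


MOE_CONTEXT = _StubMoeContext()


class ToyMoeLoss(nn.Module):
    """Mirrors MoeLoss.forward: main_loss + aux_weight * aux_loss."""

    def __init__(self, aux_weight: float, loss_fn, *args, **kwargs):
        super().__init__()
        self.loss_fn = loss_fn(*args, **kwargs)
        self.aux_weight = aux_weight

    def forward(self, *args, **kwargs):
        main_loss = self.loss_fn(*args, **kwargs)
        aux_loss = MOE_CONTEXT.get_loss()
        return main_loss + self.aux_weight * aux_loss


criterion = ToyMoeLoss(aux_weight=0.01, loss_fn=nn.CrossEntropyLoss)
logits = torch.randn(8, 10)
targets = torch.randint(0, 10, (8,))
MOE_CONTEXT.aux_loss = torch.tensor(0.5)   # pretend a router produced this balancing loss
print(criterion(logits, targets))          # = cross entropy + 0.01 * 0.5
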
@@ -1,2 +1,3 @@
 colossalai
 torch
+six

Some files were not shown because too many files have changed in this diff.