diff --git a/.gitignore b/.gitignore index 162eb26a9..63ee85da4 100644 --- a/.gitignore +++ b/.gitignore @@ -137,8 +137,4 @@ dmypy.json .DS_Store #data/ -# launcher setting -tests/launcher/log -tests/launcher/personal - docs/.build diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index a4193fafe..208bdb7c5 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -5,7 +5,7 @@ repos: - id: yapf args: ['--style=google', '--parallel', '--in-place'] - repo: https://github.com/pycqa/flake8 - rev: '' + rev: '4.0.1' hooks: - id: flake8 - repo: https://github.com/pre-commit/mirrors-clang-format diff --git a/colossalai/amp/apex_amp/apex_amp.py b/colossalai/amp/apex_amp/apex_amp.py index 21390dc7d..b43228410 100644 --- a/colossalai/amp/apex_amp/apex_amp.py +++ b/colossalai/amp/apex_amp/apex_amp.py @@ -4,8 +4,9 @@ import torch.nn as nn try: import apex.amp as apex_amp -except: - pass +except ImportError: + raise ImportError('Cannot import apex.amp correctly.') + from torch import Tensor from colossalai.nn.optimizer import ColossalaiOptimizer diff --git a/colossalai/communication/collective.py b/colossalai/communication/collective.py index 9b948418f..5b4e5eeba 100644 --- a/colossalai/communication/collective.py +++ b/colossalai/communication/collective.py @@ -30,7 +30,7 @@ def all_gather(tensor: Tensor, dim: int, parallel_mode: ParallelMode, async_op: """ depth = gpc.get_world_size(parallel_mode) if depth == 1: - out = [tensor] + out = tensor work = None else: shape = list(tensor.shape) @@ -96,34 +96,40 @@ def all_reduce(tensor: Tensor, async_op: bool = False) -> Tensor: depth = gpc.get_world_size(parallel_mode) if depth == 1: + out = tensor work = None else: - work = dist.all_reduce(tensor.contiguous(), op=op, group=gpc.get_group(parallel_mode), async_op=async_op) + out = tensor.contiguous() + work = dist.all_reduce(out, op=op, group=gpc.get_group(parallel_mode), async_op=async_op) if async_op: - return tensor, work + return out, work else: - return tensor + return out def broadcast(tensor: Tensor, src: int, parallel_mode: ParallelMode, async_op: bool = False): depth = gpc.get_world_size(parallel_mode) if depth == 1: + out = tensor work = None else: - work = dist.broadcast(tensor.contiguous(), src=src, group=gpc.get_group(parallel_mode), async_op=async_op) + out = tensor.contiguous() + work = dist.broadcast(out, src=src, group=gpc.get_group(parallel_mode), async_op=async_op) if async_op: - return tensor, work + return out, work else: - return tensor + return out def reduce(tensor: Tensor, dst: int, parallel_mode: ParallelMode, op: ReduceOp = ReduceOp.SUM, async_op: bool = False): depth = gpc.get_world_size(parallel_mode) if depth == 1: + out = tensor work = None else: - work = dist.reduce(tensor.contiguous(), dst=dst, op=op, group=gpc.get_group(parallel_mode), async_op=async_op) + out = tensor.contiguous() + work = dist.reduce(out, dst=dst, op=op, group=gpc.get_group(parallel_mode), async_op=async_op) if async_op: - return tensor, work + return out, work else: - return tensor + return out diff --git a/colossalai/constants.py b/colossalai/constants.py index 0fb8ed77e..33babff96 100644 --- a/colossalai/constants.py +++ b/colossalai/constants.py @@ -19,23 +19,12 @@ INITIALIZER_MAPPING = { 'moe': 'Initializer_Moe' } -# 1D parallel -PARALLEL_INPUT_1D = 'parallel_input_1d' +# 3D parallelism groups +INPUT_GROUP_3D = 'input_group_3d' +WEIGHT_GROUP_3D = 'weight_group_3d' +OUTPUT_GROUP_3D = 'output_group_3d' -# 2D paralllel -SUMMA_DIM = 'SUMMA_DIM' - -# 2.5D paralllel -TESSERACT_DIM = 
'TESSERACT_DIM' -TESSERACT_DEP = 'TESSERACT_DEP' - -# 3D parallel -DEPTH_3D = 'DEPTH_3D' -INPUT_GROUP_3D = 'PARALLEL_3D_INPUT' -WEIGHT_GROUP_3D = 'PARALLEL_3D_WEIGHT' -OUTPUT_GROUP_3D = 'PARALLEL_3D_OUTPUT' - -# Tensor parallel attributes +# Attributes of tensor parallel parameters IS_TENSOR_PARALLEL = 'is_tensor_parallel' NUM_PARTITIONS = 'num_partitions' TENSOR_PARALLEL_ATTRIBUTES = [IS_TENSOR_PARALLEL, NUM_PARTITIONS] diff --git a/colossalai/context/parallel_context.py b/colossalai/context/parallel_context.py index deee76f7f..b81c0b452 100644 --- a/colossalai/context/parallel_context.py +++ b/colossalai/context/parallel_context.py @@ -8,14 +8,15 @@ from typing import Union import numpy as np import torch import torch.distributed as dist -from colossalai.constants import ALLOWED_MODES, INITIALIZER_MAPPING, TENSOR_PARALLEL_MODE +from colossalai.constants import ALLOWED_MODES, INITIALIZER_MAPPING from colossalai.context.config import Config +from colossalai.global_variables import moe_env +from colossalai.global_variables import tensor_parallel_env as env from colossalai.logging import get_dist_logger from colossalai.registry import DIST_GROUP_INITIALIZER from .parallel_mode import ParallelMode from .random import add_seed, get_seeds, set_mode -from colossalai.global_variables import moe_env class ParallelContext: @@ -307,7 +308,6 @@ class ParallelContext: port: int ): """Initializes the global distributed environment - :param rank: rank for the default process group :type rank: int :param world_size: world size of the default process group @@ -389,7 +389,8 @@ class ParallelContext: if parallel_config is not None and 'tensor' in parallel_config and 'mode' in parallel_config['tensor']: tensor_parallel_mode = parallel_config['tensor']['mode'] assert tensor_parallel_mode in ALLOWED_MODES, f"mode in the parallel config must be set to one of {ALLOWED_MODES}" - os.environ[TENSOR_PARALLEL_MODE] = str(tensor_parallel_mode) + env.mode = tensor_parallel_mode + self.check_sanity() pg_init = [] diff --git a/colossalai/context/process_group_initializer/initializer_1d.py b/colossalai/context/process_group_initializer/initializer_1d.py index e99068828..4d454f2a6 100644 --- a/colossalai/context/process_group_initializer/initializer_1d.py +++ b/colossalai/context/process_group_initializer/initializer_1d.py @@ -1,22 +1,18 @@ #!/usr/bin/env python # -*- encoding: utf-8 -*- -import os -import torch.distributed as dist -from colossalai.context import Config +import torch.distributed as dist +from colossalai.global_variables import tensor_parallel_env as env from colossalai.registry import DIST_GROUP_INITIALIZER -from .process_group_initializer import ProcessGroupInitializer + from ..parallel_mode import ParallelMode -from colossalai.constants import PARALLEL_INPUT_1D +from .process_group_initializer import ProcessGroupInitializer @DIST_GROUP_INITIALIZER.register_module class Initializer_1D(ProcessGroupInitializer): - """A ProcessGroupInitializer for 1d tensor parallelism. - - :param args: Args used to initialize ProcessGroupInitializer - :param kwargs: Kwargs used to initialize ProcessGroupInitializer - """ + '''A ProcessGroupInitializer for 1d tensor parallelism. + ''' def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) @@ -24,7 +20,7 @@ class Initializer_1D(ProcessGroupInitializer): def init_dist_group(self): """Initialize 1D tensor parallel groups, and assign local_ranks and groups to each gpu. 
- + :return: (local_rank, group_world_size, process_group, ranks_in_group, mode) :rtype: Tuple """ @@ -33,7 +29,7 @@ class Initializer_1D(ProcessGroupInitializer): process_group = None group_world_size = None mode = ParallelMode.PARALLEL_1D - os.environ[PARALLEL_INPUT_1D] = '' + env.parallel_input_1d = False for i in range(self.num_group): ranks = [i * self.tensor_parallel_size + j for j in range(self.tensor_parallel_size)] diff --git a/colossalai/context/process_group_initializer/initializer_2d.py b/colossalai/context/process_group_initializer/initializer_2d.py index 7ce230fa2..b48ce60f9 100644 --- a/colossalai/context/process_group_initializer/initializer_2d.py +++ b/colossalai/context/process_group_initializer/initializer_2d.py @@ -1,34 +1,31 @@ import math -import os import torch.distributed as dist -from colossalai.constants import SUMMA_DIM from colossalai.registry import DIST_GROUP_INITIALIZER from .process_group_initializer import ProcessGroupInitializer from ..parallel_mode import ParallelMode +from colossalai.global_variables import tensor_parallel_env as env def _check_summa_env_var(summa_dim): # check environment variable for SUMMA - env_summa_dim = os.environ.get(SUMMA_DIM, None) + env_summa_dim = env.summa_dim if env_summa_dim: assert int(env_summa_dim) == summa_dim, \ 'SUMMA_DIM has been set in the current environment and ' \ 'does not match with the value passed to this initialized' else: - os.environ[SUMMA_DIM] = str(summa_dim) + env.summa_dim = summa_dim class Initializer_2D_Row(ProcessGroupInitializer): """2d tensor parallel initialization among rows. - :param num_group: The number of all tensor groups :param summa_dim: The dimension of SUMMA :param args: Args used to initialize base class :param kwargs: Kwargs used to initialize base class - :type num_group: int :type summa_dim: int """ @@ -132,7 +129,7 @@ class Initializer_2D(ProcessGroupInitializer): def init_dist_group(self): """Initialize 2D tensor row and col parallel groups, and assign local_ranks and groups to each gpu. 
- + :return: 2D tensor parallelism's information :rtype: list of Tuples (local_rank, group_world_size, process_group, ranks_in_group, mode) """ diff --git a/colossalai/context/process_group_initializer/initializer_2p5d.py b/colossalai/context/process_group_initializer/initializer_2p5d.py index f05d730b5..3c3e1b978 100644 --- a/colossalai/context/process_group_initializer/initializer_2p5d.py +++ b/colossalai/context/process_group_initializer/initializer_2p5d.py @@ -2,22 +2,21 @@ # -*- encoding: utf-8 -*- import math -import os import torch.distributed as dist - -from colossalai.constants import TESSERACT_DIM, TESSERACT_DEP from colossalai.context import Config +from colossalai.global_variables import tensor_parallel_env as env from colossalai.registry import DIST_GROUP_INITIALIZER -from .process_group_initializer import ProcessGroupInitializer + from ..parallel_mode import ParallelMode +from .process_group_initializer import ProcessGroupInitializer def _check_tesseract_env_var(tesseract_dim: int, tesseract_dep: int): - # check environment variable for TESSERACT - env_tesseract_dim = os.environ.get(TESSERACT_DIM, None) - env_tesseract_dep = os.environ.get(TESSERACT_DEP, None) + # check global variable for TESSERACT + env_tesseract_dim = env.tesseract_dim + env_tesseract_dep = env.tesseract_dep if env_tesseract_dim and env_tesseract_dep: assert int(env_tesseract_dim) == tesseract_dim, \ @@ -27,8 +26,8 @@ def _check_tesseract_env_var(tesseract_dim: int, 'TESSERACT_DEP has been set in the current environment and ' \ 'does not match with the value passed to this initialized' else: - os.environ[TESSERACT_DIM] = str(tesseract_dim) - os.environ[TESSERACT_DEP] = str(tesseract_dep) + env.tesseract_dim = tesseract_dim + env.tesseract_dep = tesseract_dep # i row j col k dep @@ -245,7 +244,6 @@ class Initializer_2p5D(ProcessGroupInitializer): :param pipeline_parallel_size: Size of pipeline parallel :param tensor_parallel_size: Size of tensor parallel :param depth: The depth of 2p5d parallel - :type rank: int :type world_size: int :type config: Config @@ -281,7 +279,7 @@ class Initializer_2p5D(ProcessGroupInitializer): def init_dist_group(self): """Initialize 2p5D tensor row, col, depth, and colXdepth parallel groups, and assign local_ranks and groups to each gpu. 
- + :return: Whole 2p5D tensor parallelism's information :rtype: list of Tuples (local_rank, group_world_size, process_group, ranks_in_group, mode) """ diff --git a/colossalai/context/process_group_initializer/initializer_3d.py b/colossalai/context/process_group_initializer/initializer_3d.py index b17ed1962..edd8b4694 100644 --- a/colossalai/context/process_group_initializer/initializer_3d.py +++ b/colossalai/context/process_group_initializer/initializer_3d.py @@ -2,10 +2,9 @@ # -*- encoding: utf-8 -*- import math -import os import torch.distributed as dist -from colossalai.constants import DEPTH_3D, INPUT_GROUP_3D, WEIGHT_GROUP_3D, OUTPUT_GROUP_3D +from colossalai.global_variables import tensor_parallel_env as env from colossalai.registry import DIST_GROUP_INITIALIZER from ..parallel_mode import ParallelMode @@ -13,15 +12,15 @@ from .process_group_initializer import ProcessGroupInitializer def _check_depth_env_var(depth): - # check environment variable for SUMMA - env_depth = os.environ.get(DEPTH_3D, None) + # check global variable + env_depth = env.depth_3d if env_depth: assert int(env_depth) == depth, \ 'DEPTH_3D has been set in the current environment and ' \ 'does not match with the value passed to this initialized' else: - os.environ[DEPTH_3D] = str(depth) + env.depth_3d = depth class Initializer_3D_Input(ProcessGroupInitializer): @@ -34,6 +33,7 @@ class Initializer_3D_Input(ProcessGroupInitializer): :type num_group: int :type depth: int """ + def __init__(self, num_group: int, depth: int, *args): super().__init__(*args) self.num_group = num_group @@ -50,15 +50,12 @@ class Initializer_3D_Input(ProcessGroupInitializer): process_group = None group_world_size = None mode = ParallelMode.PARALLEL_3D_INPUT - os.environ[INPUT_GROUP_3D] = INPUT_GROUP_3D + env.input_group_3d = mode for h in range(self.num_group): for i in range(self.depth): for k in range(self.depth): - ranks = [ - h * self.depth**3 + i + self.depth * - (j + self.depth * k) for j in range(self.depth) - ] + ranks = [h * self.depth**3 + i + self.depth * (j + self.depth * k) for j in range(self.depth)] group = dist.new_group(ranks) if self.rank in ranks: @@ -97,15 +94,12 @@ class Initializer_3D_Weight(ProcessGroupInitializer): process_group = None group_world_size = None mode = ParallelMode.PARALLEL_3D_WEIGHT - os.environ[WEIGHT_GROUP_3D] = WEIGHT_GROUP_3D + env.weight_group_3d = mode for h in range(self.num_group): for k in range(self.depth): for j in range(self.depth): - ranks = [ - h * self.depth**3 + i + self.depth * - (j + self.depth * k) for i in range(self.depth) - ] + ranks = [h * self.depth**3 + i + self.depth * (j + self.depth * k) for i in range(self.depth)] group = dist.new_group(ranks) if self.rank in ranks: @@ -118,7 +112,7 @@ class Initializer_3D_Weight(ProcessGroupInitializer): class Initializer_3D_Output(ProcessGroupInitializer): - """3D tensor parallel initialization among weight. + """3D tensor parallel initialization among output. 
:param num_group: The number of all tensor groups :param depth: Depth of 3D parallelism @@ -144,15 +138,12 @@ class Initializer_3D_Output(ProcessGroupInitializer): process_group = None group_world_size = None mode = ParallelMode.PARALLEL_3D_OUTPUT - os.environ[OUTPUT_GROUP_3D] = OUTPUT_GROUP_3D + env.output_group_3d = mode for h in range(self.num_group): for i in range(self.depth): for j in range(self.depth): - ranks = [ - h * self.depth**3 + i + self.depth * - (j + self.depth * k) for k in range(self.depth) - ] + ranks = [h * self.depth**3 + i + self.depth * (j + self.depth * k) for k in range(self.depth)] group = dist.new_group(ranks) if self.rank in ranks: @@ -170,6 +161,7 @@ class Initializer_3D(ProcessGroupInitializer): :param args: Args used to initialize ProcessGroupInitializer """ + def __init__(self, *args): super().__init__(*args) self.num_group = self.world_size // self.tensor_parallel_size @@ -178,16 +170,13 @@ class Initializer_3D(ProcessGroupInitializer): f'3D depth ({self.depth}) if not cube root of tensor parallel size ({self.tensor_parallel_size})' _check_depth_env_var(self.depth) - self.input_initializer = Initializer_3D_Input(self.num_group, - self.depth, *args) - self.weight_initializer = Initializer_3D_Weight( - self.num_group, self.depth, *args) - self.output_initializer = Initializer_3D_Output( - self.num_group, self.depth, *args) + self.input_initializer = Initializer_3D_Input(self.num_group, self.depth, *args) + self.weight_initializer = Initializer_3D_Weight(self.num_group, self.depth, *args) + self.output_initializer = Initializer_3D_Output(self.num_group, self.depth, *args) def init_dist_group(self): """Initialize 3D tensor parallel groups, and assign local_ranks and groups to each gpu. - + :return: 3D tensor parallelism's information :rtype: list of Tuples (local_rank, group_world_size, process_group, ranks_in_group, mode) """ diff --git a/colossalai/engine/gradient_handler/__init__.py b/colossalai/engine/gradient_handler/__init__.py index 836f1f72b..b6503b778 100644 --- a/colossalai/engine/gradient_handler/__init__.py +++ b/colossalai/engine/gradient_handler/__init__.py @@ -9,4 +9,4 @@ from ._sequence_parallel_gradient_handler import SequenceParallelGradientHandler __all__ = ['BaseGradientHandler', 'DataParallelGradientHandler', 'ZeROGradientHandler', 'PipelineSharedModuleGradientHandler', - 'MoeGradientHandler', 'SequenceParallelGradientHandler'] + 'MoeGradientHandler', 'SequenceParallelGradientHandler'] \ No newline at end of file diff --git a/colossalai/engine/schedule/_base_schedule.py b/colossalai/engine/schedule/_base_schedule.py index daaf1a1c0..d3c781b13 100644 --- a/colossalai/engine/schedule/_base_schedule.py +++ b/colossalai/engine/schedule/_base_schedule.py @@ -9,7 +9,6 @@ from typing import Iterable, Callable from .._base_engine import Engine from colossalai.logging import get_dist_logger from colossalai.utils import get_current_device -from colossalai.nn.layer import split_batch class BaseSchedule(ABC): @@ -69,7 +68,6 @@ class BaseSchedule(ABC): self.batch_size = data.size(0) else: self.batch_size = next(iter(data.values())).size(0) - data, label = split_batch(data), split_batch(label) if to_gpu: return self._move_to_device(data), self._move_to_device(label) return data, label diff --git a/colossalai/global_variables.py b/colossalai/global_variables.py index 48da054c0..04f6e891e 100644 --- a/colossalai/global_variables.py +++ b/colossalai/global_variables.py @@ -1,3 +1,51 @@ +from typing import Optional + + +class TensorParallelEnv(object): + + 
_instance = None + + def __new__(cls, *args, **kwargs): + if cls._instance is None: + cls._instance = object.__new__(cls, *args, **kwargs) + return cls._instance + + def __init__(self, *args, **kwargs): + self.load(*args, **kwargs) + + def load(self, + mode: Optional[str] = None, + vocab_parallel: bool = False, + parallel_input_1d: bool = False, + summa_dim: int = None, + tesseract_dim: int = None, + tesseract_dep: int = None, + depth_3d: int = None, + input_group_3d=None, + weight_group_3d=None, + output_group_3d=None): + self.mode = mode + self.vocab_parallel = vocab_parallel + self.parallel_input_1d = parallel_input_1d + self.summa_dim = summa_dim + self.tesseract_dim = tesseract_dim + self.tesseract_dep = tesseract_dep + self.depth_3d = depth_3d + self.input_group_3d = input_group_3d + self.weight_group_3d = weight_group_3d + self.output_group_3d = output_group_3d + + def save(self): + return dict(mode=self.mode, + vocab_parallel=self.vocab_parallel, + parallel_input_1d=self.parallel_input_1d, + summa_dim=self.summa_dim, + tesseract_dim=self.tesseract_dim, + tesseract_dep=self.tesseract_dep, + depth_3d=self.depth_3d, + input_group_3d=self.input_group_3d, + weight_group_3d=self.weight_group_3d, + output_group_3d=self.output_group_3d) class MoeEnv: @@ -33,4 +81,6 @@ class MoeEnv: return self.aux_loss +tensor_parallel_env = TensorParallelEnv() + moe_env = MoeEnv() diff --git a/colossalai/kernel/cuda_native/layer_norm.py b/colossalai/kernel/cuda_native/layer_norm.py index a45a3e7ae..b2ecd9ff9 100644 --- a/colossalai/kernel/cuda_native/layer_norm.py +++ b/colossalai/kernel/cuda_native/layer_norm.py @@ -37,17 +37,17 @@ class FusedLayerNormAffineFunction(torch.autograd.Function): input_, weight_, bias_, mean, invvar = ctx.saved_tensors grad_input = grad_weight = grad_bias = None grad_input, grad_weight, grad_bias \ - = colossal_layer_norm_cuda.backward_affine( - grad_output.contiguous(), mean, invvar, - input_, ctx.normalized_shape, - weight_, bias_, ctx.eps) + = colossal_layer_norm_cuda.backward_affine( + grad_output.contiguous(), mean, invvar, + input_, ctx.normalized_shape, + weight_, bias_, ctx.eps) return grad_input, grad_weight, grad_bias, None, None class MixedFusedLayerNorm(torch.nn.Module): - def __init__(self, normalized_shape, eps=1e-5): + def __init__(self, normalized_shape, eps=1e-5, device=None, dtype=None): super(MixedFusedLayerNorm, self).__init__() global colossal_layer_norm_cuda @@ -61,8 +61,8 @@ class MixedFusedLayerNorm(torch.nn.Module): normalized_shape = (normalized_shape,) self.normalized_shape = torch.Size(normalized_shape) self.eps = eps - self.weight = Parameter(torch.Tensor(*normalized_shape)) - self.bias = Parameter(torch.Tensor(*normalized_shape)) + self.weight = Parameter(torch.empty(*normalized_shape, device=device, dtype=dtype)) + self.bias = Parameter(torch.empty(*normalized_shape, device=device, dtype=dtype)) self.reset_parameters() def reset_parameters(self): diff --git a/colossalai/nn/layer/colossalai_layer/__init__.py b/colossalai/nn/layer/colossalai_layer/__init__.py index 54ed567eb..2ae1b07a7 100644 --- a/colossalai/nn/layer/colossalai_layer/__init__.py +++ b/colossalai/nn/layer/colossalai_layer/__init__.py @@ -1,7 +1,7 @@ -from ._utils import split_batch +from ._utils import partition_batch from .dropout import Dropout from .embedding import Embedding, PatchEmbedding from .linear import Classifier, Linear from .normalization import LayerNorm -__all__ = ['Linear', 'Classifier', 'Embedding', 'PatchEmbedding', 'LayerNorm', 'Dropout', 'split_batch'] 
+__all__ = ['Linear', 'Classifier', 'Embedding', 'PatchEmbedding', 'LayerNorm', 'Dropout', 'partition_batch'] diff --git a/colossalai/nn/layer/colossalai_layer/_utils.py b/colossalai/nn/layer/colossalai_layer/_utils.py index 0eb8e39e2..6271667cc 100644 --- a/colossalai/nn/layer/colossalai_layer/_utils.py +++ b/colossalai/nn/layer/colossalai_layer/_utils.py @@ -2,13 +2,13 @@ from torch import Tensor from ..parallel_2d._operation import split_tensor_2d from ..parallel_2p5d._operation import split_tensor_2p5d -from ..parallel_3d._operation import split_tensor_3d +from ..parallel_3d._operation import split_batch_3d from ..utils import get_tensor_parallel_mode -_parallel_split_batch = {'2d': split_tensor_2d, '2.5d': split_tensor_2p5d, '3d': split_tensor_3d} +_parallel_split_batch = {'2d': split_tensor_2d, '2.5d': split_tensor_2p5d, '3d': split_batch_3d} -def split_batch(input_) -> Tensor: +def partition_batch(input_) -> Tensor: tensor_parallel_mode = get_tensor_parallel_mode() if tensor_parallel_mode in _parallel_split_batch: if isinstance(input_, dict): diff --git a/colossalai/nn/layer/colossalai_layer/dropout.py b/colossalai/nn/layer/colossalai_layer/dropout.py index f1dc297a1..8921b0884 100644 --- a/colossalai/nn/layer/colossalai_layer/dropout.py +++ b/colossalai/nn/layer/colossalai_layer/dropout.py @@ -1,8 +1,5 @@ -from contextlib import nullcontext - import torch.nn as nn from colossalai.context import ParallelMode, seed -from colossalai.utils import conditional_context from ..parallel_1d import * from ..utils import get_tensor_parallel_mode @@ -26,6 +23,8 @@ class Dropout(nn.Module): self.drop = nn.Dropout(p, inplace) def forward(self, *args): - cm = nullcontext() if self.tensor_parallel in ['None', '1d'] else seed(ParallelMode.TENSOR) - with cm: + if self.tensor_parallel in [None, '1d']: return self.drop(*args) + else: + with seed(ParallelMode.TENSOR): + return self.drop(*args) diff --git a/colossalai/nn/layer/colossalai_layer/embedding.py b/colossalai/nn/layer/colossalai_layer/embedding.py index b4c852c7e..daa74e8ae 100644 --- a/colossalai/nn/layer/colossalai_layer/embedding.py +++ b/colossalai/nn/layer/colossalai_layer/embedding.py @@ -1,5 +1,5 @@ import math -from typing import Callable, Optional +from typing import Callable from colossalai.utils import get_current_device from torch import dtype, nn @@ -12,10 +12,21 @@ from ..parallel_3d import * from ..utils import get_tensor_parallel_mode from ..vanilla import * -_parallel_embedding = {'1d': Embedding1D, '2d': Embedding2D, '2.5d': Embedding2p5D, '3d': Embedding3D} +_parallel_embedding = { + '2d': Embedding2D, + '2.5d': Embedding2p5D, + '3d': Embedding3D, +} + +_vocab_parallel_embedding = { + '1d': VocabParallelEmbedding1D, + '2d': VocabParallelEmbedding2D, + '2.5d': VocabParallelEmbedding2p5D, + '3d': VocabParallelEmbedding3D +} _parallel_patchembedding = { - 'None': VanillaPatchEmbedding, + None: VanillaPatchEmbedding, '1d': VanillaPatchEmbedding, '2d': PatchEmbedding2D, '2.5d': PatchEmbedding2p5D, @@ -40,26 +51,23 @@ class Embedding(nn.Module): :param args: Args used in F.embedding :param kwargs: Kwargs used in F.embedding """ + def __init__(self, num_embeddings: int, embedding_dim: int, padding_idx: int = None, dtype: dtype = None, weight_initializer: Callable = init.normal_(), + vocab_parallel_limit: int = 2048, *args, **kwargs) -> None: super().__init__() tensor_parallel = get_tensor_parallel_mode() - if tensor_parallel == 'None': - self.embed = nn.Embedding(num_embeddings, - embedding_dim, - padding_idx=padding_idx, - 
device=get_current_device(), - dtype=dtype, - *args, - **kwargs) + if tensor_parallel is None or (tensor_parallel == '1d' and num_embeddings <= vocab_parallel_limit): + self.embed = nn.Embedding(num_embeddings, embedding_dim, padding_idx=padding_idx, *args, + **kwargs).to(dtype).to(get_current_device()) weight_initializer(self.embed.weight, fan_in=num_embeddings, fan_out=embedding_dim) - else: + elif num_embeddings <= vocab_parallel_limit: self.embed = _parallel_embedding[tensor_parallel]( num_embeddings, embedding_dim, @@ -69,6 +77,16 @@ class Embedding(nn.Module): *args, **kwargs, ) + else: + self.embed = _vocab_parallel_embedding[tensor_parallel]( + num_embeddings, + embedding_dim, + padding_idx=padding_idx, + dtype=dtype, + weight_initializer=weight_initializer, + *args, + **kwargs, + ) @property def weight(self): @@ -101,16 +119,19 @@ class PatchEmbedding(nn.Module): :param position_embed_initializer: The intializer of position embedding, defaults to zero :type position_embed_initializer: typing.Callable, optional """ - def __init__(self, - img_size: int, - patch_size: int, - in_chans: int, - embed_size: int, - dtype: dtype = None, - flatten: bool = True, - weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)), - bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1), - position_embed_initializer: Callable = init.zeros_()) -> None: + + def __init__( + self, + img_size: int, + patch_size: int, + in_chans: int, + embed_size: int, + dtype: dtype = None, + flatten: bool = True, + weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)), + bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1), + position_embed_initializer: Callable = init.zeros_() + ) -> None: super().__init__() tensor_parallel = get_tensor_parallel_mode() self.embed = _parallel_patchembedding[tensor_parallel]( diff --git a/colossalai/nn/layer/colossalai_layer/linear.py b/colossalai/nn/layer/colossalai_layer/linear.py index 69d458f09..baa2abf7c 100644 --- a/colossalai/nn/layer/colossalai_layer/linear.py +++ b/colossalai/nn/layer/colossalai_layer/linear.py @@ -1,7 +1,6 @@ import math -from typing import Callable, Optional +from typing import Callable -from colossalai.nn.layer.parallel_1d.layers import Classifier1D from colossalai.utils import get_current_device from torch import dtype, nn @@ -16,13 +15,20 @@ from ..vanilla import * _parallel_linear = {'1d': Linear1D, '2d': Linear2D, '2.5d': Linear2p5D, '3d': Linear3D} _parallel_classifier = { - 'None': VanillaClassifier, + None: VanillaClassifier, '1d': Classifier1D, '2d': Classifier2D, '2.5d': Classifier2p5D, '3d': Classifier3D } +_vocab_parallel_classifier = { + '1d': VocabParallelClassifier1D, + '2d': VocabParallelClassifier2D, + '2.5d': VocabParallelClassifier2p5D, + '3d': VocabParallelClassifier3D +} + class Linear(nn.Module): """ @@ -40,8 +46,9 @@ class Linear(nn.Module): :type weight_initializer: typing.Callable, optional :param bias_initializer: The intializer of bias, defaults to xavier uniform initializer :type bias_initializer: typing.Callable, optional - :param kwargs: Kwargs used for initialization + :param kwargs: Kwargs used for particular parallelisms """ + def __init__(self, in_features: int, out_features: int, @@ -52,10 +59,10 @@ class Linear(nn.Module): **kwargs) -> None: super().__init__() tensor_parallel = get_tensor_parallel_mode() - if tensor_parallel == 'None': - self.layer = nn.Linear(in_features, out_features, bias=bias, device=get_current_device(), dtype=dtype) + if tensor_parallel is None: + 
self.layer = nn.Linear(in_features, out_features, bias=bias).to(dtype).to(get_current_device()) weight_initializer(self.layer.weight, fan_in=in_features, fan_out=out_features) - if bias: + if self.layer.bias is not None: bias_initializer(self.layer.bias, fan_in=in_features) else: self.layer = _parallel_linear[tensor_parallel]( @@ -97,26 +104,38 @@ class Classifier(nn.Module): :param bias_initializer: The intializer of bias, defaults to xavier uniform initializer :type bias_initializer: typing.Callable, optional """ - def __init__( - self, - in_features: int, - num_classes: int, - weight: nn.Parameter = None, - bias: bool = True, - dtype: dtype = None, - weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)), - bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1) - ) -> None: + + def __init__(self, + in_features: int, + num_classes: int, + weight: nn.Parameter = None, + bias: bool = True, + dtype: dtype = None, + weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)), + bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1), + vocab_parallel_limit: int = 2048) -> None: super().__init__() - self.layer = _parallel_classifier[get_tensor_parallel_mode()]( - in_features, - num_classes, - weight=weight, - bias=bias, - dtype=dtype, - weight_initializer=weight_initializer, - bias_initializer=bias_initializer, - ) + tensor_parallel = get_tensor_parallel_mode() + if num_classes <= vocab_parallel_limit or tensor_parallel is None: + self.layer = _parallel_classifier[tensor_parallel]( + in_features, + num_classes, + weight=weight, + bias=bias, + dtype=dtype, + weight_initializer=weight_initializer, + bias_initializer=bias_initializer, + ) + else: + self.layer = _vocab_parallel_classifier[tensor_parallel]( + in_features, + num_classes, + weight=weight, + bias=bias, + dtype=dtype, + weight_initializer=weight_initializer, + bias_initializer=bias_initializer, + ) @property def weight(self): diff --git a/colossalai/nn/layer/colossalai_layer/normalization.py b/colossalai/nn/layer/colossalai_layer/normalization.py index b29e1fbab..1f9277214 100644 --- a/colossalai/nn/layer/colossalai_layer/normalization.py +++ b/colossalai/nn/layer/colossalai_layer/normalization.py @@ -1,7 +1,6 @@ -from typing import Optional - from colossalai.utils import get_current_device from torch import nn +from colossalai import kernel from ... 
import init as init from ..parallel_1d import * @@ -11,7 +10,12 @@ from ..parallel_3d import * from ..utils import get_tensor_parallel_mode from ..vanilla import * -_parallel_layernorm = {'2d': LayerNorm2D, '2.5d': LayerNorm2p5D, '3d': LayerNorm3D} +_parallel_layernorm = { + '1d': kernel.LayerNorm, + '2d': LayerNorm2D, + '2.5d': LayerNorm2p5D, + '3d': LayerNorm3D +} class LayerNorm(nn.Module): @@ -28,11 +32,12 @@ class LayerNorm(nn.Module): :param dtype: The dtype of parameters, defaults to None :type dtype: torch.dtype, optional """ + def __init__(self, normalized_shape: int, eps=1e-05, dtype=None) -> None: super().__init__() tensor_parallel = get_tensor_parallel_mode() - if tensor_parallel in ['None', '1d']: - self.norm = nn.LayerNorm(normalized_shape, eps=eps, device=get_current_device(), dtype=dtype) + if tensor_parallel is None: + self.norm = nn.LayerNorm(normalized_shape, eps=eps).to(dtype).to(get_current_device()) else: self.norm = _parallel_layernorm[tensor_parallel](normalized_shape, eps=eps, dtype=dtype) diff --git a/colossalai/nn/layer/parallel_1d/__init__.py b/colossalai/nn/layer/parallel_1d/__init__.py index 6f2093a11..fddeedd7d 100644 --- a/colossalai/nn/layer/parallel_1d/__init__.py +++ b/colossalai/nn/layer/parallel_1d/__init__.py @@ -1,4 +1,7 @@ -from .layers import Dropout1D, Embedding1D, Linear1D, Linear1D_Col, Linear1D_Row -from .layers import MixedFusedLayerNorm1D as LayerNorm1D +from .layers import (Classifier1D, Dropout1D, Embedding1D, Linear1D, Linear1D_Col, Linear1D_Row, + VocabParallelClassifier1D, VocabParallelEmbedding1D) -__all__ = ['Linear1D', 'Linear1D_Col', 'Linear1D_Row', 'LayerNorm1D', 'Embedding1D', 'Dropout1D'] +__all__ = [ + 'Linear1D', 'Linear1D_Col', 'Linear1D_Row', 'Embedding1D', 'Dropout1D', 'Classifier1D', 'VocabParallelClassifier1D', + 'VocabParallelEmbedding1D' +] diff --git a/colossalai/nn/layer/parallel_1d/_utils.py b/colossalai/nn/layer/parallel_1d/_utils.py index 602bd6c3f..cc1967f11 100644 --- a/colossalai/nn/layer/parallel_1d/_utils.py +++ b/colossalai/nn/layer/parallel_1d/_utils.py @@ -1,21 +1,20 @@ #!/usr/bin/env python # -*- encoding: utf-8 -*- -import os import torch import torch.distributed as dist -from colossalai.constants import PARALLEL_INPUT_1D from colossalai.core import global_context as gpc +from colossalai.global_variables import tensor_parallel_env as env from ..utils import divide def set_parallel_input(input_parallel: bool): - os.environ[PARALLEL_INPUT_1D] = 'true' if input_parallel else '' + env.parallel_input_1d = input_parallel def get_parallel_input(): - return bool(os.environ[PARALLEL_INPUT_1D]) + return env.parallel_input_1d def vocab_range_from_per_partition_vocab_size(per_partition_vocab_size, rank): diff --git a/colossalai/nn/layer/parallel_1d/layers.py b/colossalai/nn/layer/parallel_1d/layers.py index 15e8fb834..daf54c126 100644 --- a/colossalai/nn/layer/parallel_1d/layers.py +++ b/colossalai/nn/layer/parallel_1d/layers.py @@ -2,8 +2,6 @@ # -*- encoding: utf-8 -*- import math -import numbers -from contextlib import nullcontext from typing import Callable, Tuple import torch @@ -11,17 +9,17 @@ import torch.nn.functional as F from colossalai.communication import broadcast from colossalai.context import ParallelMode, seed from colossalai.core import global_context as gpc +from colossalai.global_variables import tensor_parallel_env as env from colossalai.nn import init as init from colossalai.registry import LAYERS -from colossalai.utils import get_current_device -from torch import Tensor, dtype +from 
colossalai.utils.cuda import get_current_device +from torch import Tensor from torch.nn.parameter import Parameter from ..base_layer import ParallelLayer from ..utils import divide, set_tensor_parallel_attribute_by_partition -from ._operation import FusedLayerNormAffineFunction1D -from ._utils import (gather_forward_split_backward, get_parallel_input, reduce_grad, reduce_input, set_parallel_input, - split_forward_gather_backward) +from ._utils import (gather_forward_split_backward, get_parallel_input, reduce_grad, + reduce_input, set_parallel_input, split_forward_gather_backward) @LAYERS.register_module @@ -44,6 +42,7 @@ class Linear1D(torch.nn.Module): :param bias_initializer: The intializer of bias, defaults to xavier uniform initializer :type bias_initializer: typing.Callable, optional """ + def __init__(self, in_features: int, out_features: int, @@ -106,12 +105,13 @@ class Classifier1D(ParallelLayer): :param bias_initializer: The intializer of bias, defaults to xavier uniform initializer :type bias_initializer: typing.Callable, optional """ + def __init__(self, in_features: int, num_classes: int, weight: Parameter = None, bias: bool = True, - dtype: dtype = None, + dtype: torch.dtype = None, weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)), bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1)): super().__init__() @@ -139,6 +139,7 @@ class Classifier1D(ParallelLayer): self.reset_parameters(weight_initializer, bias_initializer) self._set_tensor_parallel_attributes() set_parallel_input(False) + env.vocab_parallel = False def reset_parameters(self, weight_initializer, bias_initializer) -> None: fan_in, fan_out = self.in_features, self.num_classes @@ -167,6 +168,84 @@ class Classifier1D(ParallelLayer): return output +@LAYERS.register_module +class VocabParallelClassifier1D(ParallelLayer): + """ColLinear with given weight + Classifier of 1D parallelism + + :param in_features: size of input features + :type in_features: int + :param num_classes: number of classes in the dataset + :type num_classes: int + :param weight: weight of the classifier, defaults to True + :type weight: torch.nn.Parameter, optional + :param bias: If set to ``False``, the layer will not learn an additive bias, defaults to ``True`` + :type bias: bool, optional + :param dtype: The dtype of parameters, defaults to None + :type dtype: torch.dtype, optional + :param weight_initializer: The intializer of weight, defaults to kaiming uniform initializer + :type weight_initializer: typing.Callable, optional + :param bias_initializer: The intializer of bias, defaults to xavier uniform initializer + :type bias_initializer: typing.Callable, optional + """ + + def __init__(self, + in_features: int, + num_classes: int, + weight: Parameter = None, + bias: bool = True, + dtype: torch.dtype = None, + weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)), + bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1)): + super().__init__() + self.in_features = in_features + self.num_classes = num_classes + self.parallel_input = get_parallel_input() + + # Divide the weight matrix along the last dimension. + self.num_classes_per_partition = divide(num_classes, gpc.tensor_parallel_size) + + # Parameters. + # Initialize weight. 
+ factory_kwargs = {'device': get_current_device(), 'dtype': dtype} + if weight is not None: + self.weight = weight + self.has_weight = False + else: + self.weight = Parameter(torch.empty(self.num_classes_per_partition, self.in_features, **factory_kwargs)) + self.has_weight = True + if bias: + self.bias = Parameter(torch.empty(self.num_classes_per_partition, **factory_kwargs)) + else: + self.bias = None + with seed(ParallelMode.TENSOR): + self.reset_parameters(weight_initializer, bias_initializer) + self._set_tensor_parallel_attributes() + set_parallel_input(False) + env.vocab_parallel = True + + def reset_parameters(self, weight_initializer, bias_initializer) -> None: + fan_in, fan_out = self.in_features, self.num_classes + if self.has_weight: + weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out) + if self.bias is not None: + bias_initializer(self.bias, fan_in=fan_in) + + def _set_tensor_parallel_attributes(self): + num_partition = gpc.get_world_size(ParallelMode.TENSOR) + if self.has_weight: + set_tensor_parallel_attribute_by_partition(self.weight, num_partition) + if self.bias is not None: + set_tensor_parallel_attribute_by_partition(self.bias, num_partition) + + def forward(self, input_: Tensor) -> Tensor: + # Set up backprop all-reduce. + input_parallel = reduce_grad(input_, ParallelMode.PARALLEL_1D) + # Matrix multiply. + output = F.linear(input_parallel, self.weight, self.bias) + return output + + @LAYERS.register_module class Linear1D_Col(ParallelLayer): """Linear layer with column parallelism. @@ -324,7 +403,7 @@ class Linear1D_Row(ParallelLayer): weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out) if self.bias is not None: bias_initializer(self.bias, fan_in=fan_in) - broadcast(self.bias, gpc.get_ranks_in_group(ParallelMode.PARALLEL_1D)[0], ParallelMode.PARALLEL_1D) + broadcast(self.bias, gpc.get_ranks_in_group(ParallelMode.PARALLEL_1D)[0], ParallelMode.PARALLEL_1D) def _set_tensor_parallel_attributes(self): num_partition = gpc.get_world_size(ParallelMode.TENSOR) @@ -341,45 +420,13 @@ class Linear1D_Row(ParallelLayer): output = reduce_input(output_parallel, ParallelMode.PARALLEL_1D) if not self.skip_bias_add: - output = output + self.bias + if self.bias is not None: + output = output + self.bias return output else: return output, self.bias -@LAYERS.register_module -class MixedFusedLayerNorm1D(torch.nn.Module): - r""" - Layer Normalization for 1D parallelism - - :param normalized_shape: input shape from an expected input - of size. :math:`[* \times \text{normalized_shape}[0] \times \text{normalized_shape}[1] \times \ldots \times \text{normalized_shape}[-1]]` - If a single integer is used, it is treated as a singleton list, and this module will - normalize over the last dimension which is expected to be of that specific size. 
- :type normalized_shape: int - :param eps: a value added to the denominator for numerical stability, defaults to 1e-05 - :type eps: float, optional - """ - - def __init__(self, normalized_shape, eps=1e-5): - super(MixedFusedLayerNorm1D, self).__init__() - - if isinstance(normalized_shape, numbers.Integral): - normalized_shape = (normalized_shape, ) - self.normalized_shape = torch.Size(normalized_shape) - self.eps = eps - self.weight = Parameter(torch.Tensor(*normalized_shape)) - self.bias = Parameter(torch.Tensor(*normalized_shape)) - self.reset_parameters() - - def reset_parameters(self): - init.ones_(self.weight) - init.zeros_(self.bias) - - def forward(self, input): - return FusedLayerNormAffineFunction1D.apply(input, self.weight, self.bias, self.normalized_shape, self.eps) - - @LAYERS.register_module class Embedding1D(ParallelLayer): """ @@ -398,11 +445,12 @@ class Embedding1D(ParallelLayer): :param args: Args used in F.embedding :param kwargs: Kwargs used in F.embedding """ + def __init__(self, num_embeddings: int, embedding_dim: int, padding_idx: int = None, - dtype: dtype = None, + dtype: torch.dtype = None, weight_initializer: Callable = init.normal_(), *args, **kwargs): @@ -446,6 +494,84 @@ class Embedding1D(ParallelLayer): return output +@LAYERS.register_module +class VocabParallelEmbedding1D(torch.nn.Module): + """Embedding parallelized in the vocabulary dimension. + + :param num_embeddings: number of embeddings + :type num_embeddings: int + :param embedding_dim: dimension of embedding + :type embedding_dim: int + :param padding_idx: index of padding, defaults to None + :type padding_idx: int, optional + :param dtype: The dtype of parameters, defaults to None + :type dtype: torch.dtype, optional + :param weight_initializer: The intializer of weight, defaults to normal initializer + :type weight_initializer: typing.Callable, optional + :param args: Args used in F.embedding + :param kwargs: Kwargs used in F.embedding + """ + + def __init__(self, + num_embeddings: int, + embedding_dim: int, + padding_idx: int = None, + dtype: torch.dtype = None, + weight_initializer: Callable = init.normal_(), + *args, + **kwargs): + super().__init__() + self.num_embeddings = num_embeddings + self.embed_dim = embedding_dim + self.padding_idx = padding_idx + self.embed_args = args + self.embed_kwargs = kwargs + + tensor_parallel_size = gpc.get_world_size(ParallelMode.PARALLEL_1D) + tensor_parallel_rank = gpc.get_local_rank(ParallelMode.PARALLEL_1D) + self.num_embeddings_per_partition = divide(num_embeddings, tensor_parallel_size) + self.vocab_start_index = tensor_parallel_rank * self.num_embeddings_per_partition + self.vocab_end_index = self.vocab_start_index + self.num_embeddings_per_partition + + self.weight = Parameter( + torch.empty((self.num_embeddings_per_partition, self.embed_dim), device=get_current_device(), dtype=dtype)) + + self.reset_parameters(weight_initializer) + self._set_tensor_parallel_attributes() + set_parallel_input(False) + env.vocab_parallel = True + + def _set_tensor_parallel_attributes(self): + set_tensor_parallel_attribute_by_partition(self.weight, gpc.tensor_parallel_size) + + def reset_parameters(self, weight_initializer) -> None: + with seed(ParallelMode.TENSOR): + fan_in, fan_out = self.num_embeddings, self.embed_dim + weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out) + self._fill_padding_idx_with_zero() + + def _fill_padding_idx_with_zero(self) -> None: + if self.padding_idx is not None: + with torch.no_grad(): + self.weight[self.padding_idx].fill_(0) + 
+ def forward(self, input_: Tensor) -> Tensor: + # Build the mask. + input_mask = (input_ < self.vocab_start_index) | (input_ >= self.vocab_end_index) + # Mask the input. + masked_input = input_.clone() - self.vocab_start_index + masked_input[input_mask] = 0 + + output_parallel = F.embedding(masked_input, self.weight, self.padding_idx, *self.embed_args, + **self.embed_kwargs) + + # Mask the output embedding. + output_parallel[input_mask, :] = 0. + # Reduce across all the model parallel GPUs. + output = reduce_input(output_parallel, ParallelMode.PARALLEL_1D) + return output + + @LAYERS.register_module class Dropout1D(ParallelLayer): """ @@ -456,6 +582,7 @@ class Dropout1D(ParallelLayer): :param inplace: If set to ``True``, will do this operation in-place, defaults tp ``False`` :type inplace: bool, optional """ + def __init__(self, p: float = 0.5, inplace: bool = False): super().__init__() self.parallel_input = get_parallel_input() @@ -463,7 +590,9 @@ class Dropout1D(ParallelLayer): self.inplace = inplace def forward(self, input_: Tensor) -> Tensor: - cm = nullcontext() if not self.parallel_input else seed(ParallelMode.TENSOR) - with cm: + if self.parallel_input: + with seed(ParallelMode.TENSOR): + output = F.dropout(input_, self.p, self.training, self.inplace) + else: output = F.dropout(input_, self.p, self.training, self.inplace) return output diff --git a/colossalai/nn/layer/parallel_2d/__init__.py b/colossalai/nn/layer/parallel_2d/__init__.py index 2122a1bfe..9bb62b456 100644 --- a/colossalai/nn/layer/parallel_2d/__init__.py +++ b/colossalai/nn/layer/parallel_2d/__init__.py @@ -1,6 +1,8 @@ from ._operation import reduce_by_batch_2d, split_tensor_2d -from .layers import Classifier2D, Embedding2D, LayerNorm2D, Linear2D, PatchEmbedding2D +from .layers import (Classifier2D, Embedding2D, LayerNorm2D, Linear2D, PatchEmbedding2D, VocabParallelClassifier2D, + VocabParallelEmbedding2D) __all__ = [ - 'split_tensor_2d', 'reduce_by_batch_2d', 'Linear2D', 'LayerNorm2D', 'Classifier2D', 'PatchEmbedding2D', 'Embedding2D' + 'split_tensor_2d', 'reduce_by_batch_2d', 'Linear2D', 'LayerNorm2D', 'Classifier2D', 'PatchEmbedding2D', + 'Embedding2D', 'VocabParallelEmbedding2D', 'VocabParallelClassifier2D' ] diff --git a/colossalai/nn/layer/parallel_2d/_operation.py b/colossalai/nn/layer/parallel_2d/_operation.py index f186188a7..f5c16671a 100644 --- a/colossalai/nn/layer/parallel_2d/_operation.py +++ b/colossalai/nn/layer/parallel_2d/_operation.py @@ -8,6 +8,7 @@ from colossalai.core import global_context as gpc from colossalai.utils import get_current_device from torch import Tensor from torch.cuda.amp import custom_bwd, custom_fwd +from colossalai.global_variables import tensor_parallel_env as env def matmul_2d( @@ -22,6 +23,7 @@ def matmul_2d( ): """ Matrix multiplication for 2D parallelism + :param a: matrix :math:`A` :type a: torch.tensor :param b: matrix :math:`B` @@ -56,37 +58,7 @@ def matmul_2d( data_parallel_rank, pipeline_parallel_rank, pipeline_parallel_size, tensor_parallel_size) -class classifier_2d(torch.autograd.Function): - """ - Classifier - - :param a: matrix :math:`A` - :type a: torch.tensor - :param b: matrix :math:`B` - :type b: torch.tensor - :param bias: matrix of bias - :type bias: torch.tensor, optional - :param summa_dim: dimension of SUMMA fo 2D parallelism - :type summa_dim: int - :param out_shape: shape of output tensor - :type out_shape: tuple - :param row_rank: the rank of row - :type row_rank: int - :param col_rank: the rank of column - :type col_rank: int - :param 
row_parallel_mode: row parallel mode - :type row_parallel_mode: colossalai.context.parallel_mode.ParallelMode - :param col_parallel_mode: column parallel mode - :type col_parallel_mode: colossalai.context.parallel_mode.ParallelMode - :param data_parallel_rank: data parallel rank - :type data_parallel_rank: int - :param pipeline_parallel_rank: pipeline parallel rank - :type pipeline_parallel_rank: int - :param pipeline_parallel_size: pipeline parallel size - :type pipeline_parallel_size: int - :param tensor_parallel_size: tensor parallel size - :type tensor_parallel_size: int - """ +class _Classifier2D(torch.autograd.Function): @staticmethod @custom_fwd(cast_inputs=torch.float16) def forward( @@ -150,14 +122,54 @@ class classifier_2d(torch.autograd.Function): B_grad = torch.matmul(output_grad.reshape(-1, output_grad.shape[-1]).transpose(0, 1), A) B_grad = reduce_scatter(B_grad, -1, ctx.col_parallel_mode) B_grad = B_grad.reshape(ctx.B_shape) - bias_grad = None if ctx.use_bias: bias_grad = torch.sum(output_grad, dim=tuple(range(output_grad.ndim - 1))) bias_grad = all_reduce(bias_grad, ctx.col_parallel_mode) + else: + bias_grad = None return A_grad, B_grad, bias_grad, None, None, None, None, None, None, None, None, None, None +def classifier_2d(A: Tensor, B: Tensor, bias: Optional[Tensor], summa_dim: int, out_shape: Tuple[int, ...], + row_rank: int, col_rank: int, row_parallel_mode: ParallelMode, col_parallel_mode: ParallelMode, + data_parallel_rank: int, pipeline_parallel_rank: int, pipeline_parallel_size: int, + tensor_parallel_size: int) -> Tensor: + """ + 2D parallel classifier + + :param a: matrix :math:`A` + :type a: torch.tensor + :param b: matrix :math:`B` + :type b: torch.tensor + :param bias: matrix of bias + :type bias: torch.tensor, optional + :param summa_dim: dimension of SUMMA fo 2D parallelism + :type summa_dim: int + :param out_shape: shape of output tensor + :type out_shape: tuple + :param row_rank: the rank of row + :type row_rank: int + :param col_rank: the rank of column + :type col_rank: int + :param row_parallel_mode: row parallel mode + :type row_parallel_mode: colossalai.context.parallel_mode.ParallelMode + :param col_parallel_mode: column parallel mode + :type col_parallel_mode: colossalai.context.parallel_mode.ParallelMode + :param data_parallel_rank: data parallel rank + :type data_parallel_rank: int + :param pipeline_parallel_rank: pipeline parallel rank + :type pipeline_parallel_rank: int + :param pipeline_parallel_size: pipeline parallel size + :type pipeline_parallel_size: int + :param tensor_parallel_size: tensor parallel size + :type tensor_parallel_size: int + """ + return _Classifier2D.apply(A, B, bias, summa_dim, out_shape, row_rank, col_rank, row_parallel_mode, + col_parallel_mode, data_parallel_rank, pipeline_parallel_rank, pipeline_parallel_size, + tensor_parallel_size) + + class Matmul_AB_2D(torch.autograd.Function): """ Matrix multiplication for :math:`C = AB` @@ -230,9 +242,9 @@ class Matmul_AB_2D(torch.autograd.Function): col_group = gpc.get_group(col_parallel_mode) src_a = summa_dim * row_rank + data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + \ - pipeline_parallel_rank * tensor_parallel_size + pipeline_parallel_rank * tensor_parallel_size src_b = col_rank + data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + \ - pipeline_parallel_rank * tensor_parallel_size + pipeline_parallel_rank * tensor_parallel_size opa = [None] * 2 opb = [None] * 2 @@ -361,9 +373,9 @@ class Matmul_ABT_2D(torch.autograd.Function): col_group 
= gpc.get_group(col_parallel_mode) src_b = col_rank + data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + \ - pipeline_parallel_rank * tensor_parallel_size + pipeline_parallel_rank * tensor_parallel_size src_c = summa_dim * row_rank + data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + \ - pipeline_parallel_rank * tensor_parallel_size + pipeline_parallel_rank * tensor_parallel_size opb = [None] * 2 opr = [None] * 2 @@ -501,9 +513,9 @@ class Matmul_ATB_2D(torch.autograd.Function): col_group = gpc.get_group(col_parallel_mode) src_a = summa_dim * row_rank + data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + \ - pipeline_parallel_rank * tensor_parallel_size + pipeline_parallel_rank * tensor_parallel_size src_c = col_rank + data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + \ - pipeline_parallel_rank * tensor_parallel_size + pipeline_parallel_rank * tensor_parallel_size opa = [None] * 2 opr = [None] * 2 @@ -572,35 +584,7 @@ class Matmul_ATB_2D(torch.autograd.Function): return A_grad, B_grad, None, None, None, None, None, None, None, None, None, None -class add_bias_2d(torch.autograd.Function): - """ - Matrix add bias: :math:`C = A + b` - - :param input_: matrix :math:`A` - :type input_: torch.tensor - :param bias: matrix :math:`b` - :type bias: torch.tensor - :param output_size_per_partition: size of ouput per partition - :type output_size_per_partition: int - :param row_rank: the rank of row - :type row_rank: int - :param col_rank: the rank of column - :type col_rank: int - :param row_parallel_mode: row parallel mode - :type row_parallel_mode: colossalai.context.parallel_mode.ParallelMode - :param col_parallel_mode: column parallel mode - :type col_parallel_mode: colossalai.context.parallel_mode.ParallelMode - :param skip_bias_add: If set to ``True``, it will skip bias add for linear layer, which is preserved for kernel fusion - :type skip_bias_add: bool - :param data_parallel_rank: data parallel rank - :type data_parallel_rank: int - :param pipeline_parallel_rank: pipeline parallel rank - :type pipeline_parallel_rank: int - :param pipeline_parallel_size: pipeline parallel size - :type pipeline_parallel_size: int - :param tensor_parallel_size: tensor parallel size - :type tensor_parallel_size: int - """ +class _Add_Bias_2D(torch.autograd.Function): @staticmethod @custom_fwd(cast_inputs=torch.float16) def forward( @@ -651,31 +635,47 @@ class add_bias_2d(torch.autograd.Function): return output_grad, grad, None, None, None, None, None, None, None, None, None, None -class layernorm_2d(torch.autograd.Function): +def add_bias_2d(input_: Tensor, bias: Tensor, output_size_per_partition: int, row_rank: int, col_rank: int, + row_parallel_mode: ParallelMode, col_parallel_mode: ParallelMode, skip_bias_add: bool, + data_parallel_rank: int, pipeline_parallel_rank: int, pipeline_parallel_size: int, + tensor_parallel_size: int) -> Tensor: """ - Layernorm + Matrix add bias: :math:`C = A + b` - :param input_: input maxtrix + :param input_: matrix :math:`A` :type input_: torch.tensor - :param E_x: mean - :type E_x: torch.tensor - :param Var_x: variance - :type Var_x: torch.tensor - :param hidden_size: hidden size - :type hidden_size: int + :param bias: matrix :math:`b` + :type bias: torch.tensor + :param output_size_per_partition: size of ouput per partition + :type output_size_per_partition: int + :param row_rank: the rank of row + :type row_rank: int + :param col_rank: the rank of column + :type col_rank: int :param row_parallel_mode: row 
parallel mode :type row_parallel_mode: colossalai.context.parallel_mode.ParallelMode :param col_parallel_mode: column parallel mode :type col_parallel_mode: colossalai.context.parallel_mode.ParallelMode + :param skip_bias_add: If set to ``True``, it will skip bias add for linear layer, which is preserved for kernel fusion + :type skip_bias_add: bool + :param data_parallel_rank: data parallel rank + :type data_parallel_rank: int + :param pipeline_parallel_rank: pipeline parallel rank + :type pipeline_parallel_rank: int + :param pipeline_parallel_size: pipeline parallel size + :type pipeline_parallel_size: int + :param tensor_parallel_size: tensor parallel size + :type tensor_parallel_size: int """ + return _Add_Bias_2D.apply(input_, bias, output_size_per_partition, row_rank, col_rank, row_parallel_mode, + col_parallel_mode, skip_bias_add, data_parallel_rank, pipeline_parallel_rank, + pipeline_parallel_size, tensor_parallel_size) + + +class _Layernorm_2D(torch.autograd.Function): @staticmethod @custom_fwd(cast_inputs=torch.float32) - def forward(ctx: Any, - input_: Tensor, - E_x: Tensor, - Var_x: Tensor, - hidden_size: int, - row_parallel_mode: ParallelMode, + def forward(ctx: Any, input_: Tensor, E_x: Tensor, Var_x: Tensor, hidden_size: int, row_parallel_mode: ParallelMode, col_parallel_mode: ParallelMode) -> Tensor: input_ = input_ - E_x # in here, input = x - E[x], Var_x = 1 / sqrt(Var[x] + eps) @@ -709,76 +709,64 @@ class layernorm_2d(torch.autograd.Function): return input_grad, None, None, None, None, None -class all_gather_weight_2d(torch.autograd.Function): +def layernorm_2d(input_: Tensor, E_x: Tensor, Var_x: Tensor, hidden_size: int, row_parallel_mode: ParallelMode, + col_parallel_mode: ParallelMode) -> Tensor: """ - all gather the weight of 2D parallelism + Layernorm - :param inputs: input maxtrix - :type inputs: torch.tensor - :param dim: dimension of all gather - :type dim: int - :param summa_dim: dimension of SUMMA fo 2D parallelism - :type summa_dim: int + :param input_: input maxtrix + :type input_: torch.tensor + :param E_x: mean + :type E_x: torch.tensor + :param Var_x: variance + :type Var_x: torch.tensor + :param hidden_size: hidden size + :type hidden_size: int + :param row_parallel_mode: row parallel mode + :type row_parallel_mode: colossalai.context.parallel_mode.ParallelMode :param col_parallel_mode: column parallel mode :type col_parallel_mode: colossalai.context.parallel_mode.ParallelMode """ + return _Layernorm_2D.apply(input_, E_x, Var_x, hidden_size, row_parallel_mode, col_parallel_mode) + + +class _AllGatherTensor2D(torch.autograd.Function): @staticmethod @custom_fwd(cast_inputs=torch.float16) - def forward(ctx: Any, inputs: Tensor, dim: int, summa_dim: int, col_parallel_mode: ParallelMode) -> Tensor: + def forward(ctx: Any, inputs: Tensor, dim: int, parallel_mode: ParallelMode) -> Tensor: ctx.dim = dim - ctx.summa_dim = summa_dim - ctx.row_rank = gpc.get_local_rank(col_parallel_mode) + ctx.parallel_mode = parallel_mode - outputs = all_gather(inputs, dim, col_parallel_mode) + outputs = all_gather(inputs, dim, parallel_mode) return outputs @staticmethod @custom_bwd def backward(ctx: Any, output_grad: Tensor) -> Tuple[Tensor, ...]: - grad = output_grad.chunk(ctx.summa_dim, dim=ctx.dim)[ctx.row_rank] - return grad.contiguous(), None, None, None + grad = reduce_scatter(output_grad, ctx.dim, ctx.parallel_mode) + return grad.contiguous(), None, None -class SplitFirst(torch.autograd.Function): +def all_gather_tensor_2d(tensor: Tensor, dim: int, parallel_mode: 
ParallelMode) -> Tensor: """ + All gather the tensor of 2D parallelism + :param inputs: input maxtrix :type inputs: torch.tensor - :param summa_dim: dimension of SUMMA fo 2D parallelism - :type summa_dim: int - :param col_parallel_mode: column parallel mode - :type col_parallel_mode: colossalai.context.parallel_mode.ParallelMode + :param dim: dimension to gather + :type dim: int + :param parallel_mode: parallel mode + :type parallel_mode: colossalai.context.parallel_mode.ParallelMode """ - @staticmethod - @custom_fwd(cast_inputs=torch.float16) - def forward(ctx: Any, inputs: Tensor, summa_dim: int, col_parallel_mode: ParallelMode) -> Tensor: - ctx.summa_dim = summa_dim - ctx.batch_size = inputs.size(0) - ctx.para_mode = col_parallel_mode - row_rank = gpc.get_local_rank(col_parallel_mode) - - outputs = inputs.chunk(summa_dim, dim=0)[row_rank] - return outputs - - @staticmethod - @custom_bwd - def backward(ctx: Any, output_grad: Tensor) -> Tuple[Tensor, ...]: - grad_shape = (ctx.batch_size, ) + output_grad.shape[1:] - grad = torch.empty(grad_shape, dtype=output_grad.dtype, device=get_current_device()) - dist.all_gather(list(grad.chunk(ctx.summa_dim, dim=0)), - output_grad.contiguous(), - group=gpc.get_group(ctx.para_mode)) - return grad, None, None + return _AllGatherTensor2D.apply(tensor, dim, parallel_mode) def split_tensor_2d(input_: Tensor, dim: int = 0) -> Tensor: """Splits 2D tensor in specified dimension across cols - :param input_: Input tensor :param dim: Specified dimension in which to split - :type input_: torch.Tensor :type dim: int, optional - :return output: Splitted tensor :rtype output: torch.Tensor """ @@ -788,9 +776,50 @@ def split_tensor_2d(input_: Tensor, dim: int = 0) -> Tensor: dim=dim)[gpc.get_local_rank(ParallelMode.PARALLEL_2D_COL)].contiguous() -class reduce_by_batch_2d(torch.autograd.Function): - """All-reduce the input from the model parallel region. +class _ReduceTensor2D(torch.autograd.Function): + @staticmethod + def forward(ctx, input_, parallel_mode): + return all_reduce(input_, parallel_mode) + + @staticmethod + def backward(ctx, output_grad): + return output_grad, None + + +def reduce_tensor_2d(input_: Tensor, parallel_mode: ParallelMode) -> Tensor: """ + All-reduce the input. + + :param input_: input tensor + :param parallel_mode: parallel mode + """ + return _ReduceTensor2D.apply(input_, parallel_mode) + + +class _ReduceScatterTensor2D(torch.autograd.Function): + @staticmethod + def forward(ctx, input_, dim, parallel_mode): + ctx.dim = dim + ctx.parallel_mode = parallel_mode + return reduce_scatter(input_, dim, parallel_mode) + + @staticmethod + def backward(ctx, output_grad): + return all_gather(output_grad, ctx.dim, ctx.parallel_mode), None, None + + +def reduce_scatter_tensor_2d(tensor: Tensor, dim: int, parallel_mode: ParallelMode) -> Tensor: + """ + Reduce-scatter the input. 
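# A minimal sketch of how the new functional 2D wrappers compose (illustrative only; the
# tensors below are placeholders and a 2D tensor-parallel context is assumed to have been
# initialized via colossalai.launch). Each wrapper is a thin autograd.Function, so backward
# uses the dual collective automatically: all_gather_tensor_2d reduce-scatters its gradient,
# reduce_scatter_tensor_2d all-gathers it, and reduce_tensor_2d passes the gradient through.
#
#     full_weight = all_gather_tensor_2d(weight_shard, -1, ParallelMode.PARALLEL_2D_COL)
#     summed      = reduce_tensor_2d(partial_out, ParallelMode.PARALLEL_2D_ROW)
#     out_shard   = reduce_scatter_tensor_2d(partial_out, 0, ParallelMode.PARALLEL_2D_COL)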
+ + :param tensor: Input tensor + :param dim: Dimension to scatter + :param parallel_mode: Parallel mode + """ + return _ReduceScatterTensor2D.apply(tensor, dim, parallel_mode) + + +class _ReduceByBatch2D(torch.autograd.Function): @staticmethod def symbolic(graph, input_, reduce_mean: bool = False): output = all_reduce(input_, ParallelMode.PARALLEL_2D_COL) @@ -802,12 +831,6 @@ class reduce_by_batch_2d(torch.autograd.Function): @staticmethod @custom_fwd(cast_inputs=torch.float32) def forward(ctx, input_, reduce_mean: bool = False): - """ - :param input_: input maxtrix - :type input_: torch.tensor - :param reduce_mean: If set to ``True``, it will divide the output by column parallel size, default to False - :type reduce_mean: int, optional - """ output = all_reduce(input_, ParallelMode.PARALLEL_2D_COL) ctx.reduce_mean = reduce_mean if reduce_mean: @@ -823,3 +846,14 @@ class reduce_by_batch_2d(torch.autograd.Function): return output_grad / ctx.reduce_size, None else: return output_grad, None + + +def reduce_by_batch_2d(input_, reduce_mean: bool = False) -> Tensor: + """All-reduce the input from the model parallel region. + + :param input_: input maxtrix + :type input_: torch.tensor + :param reduce_mean: If set to ``True``, it will divide the output by column parallel size, default to False + :type reduce_mean: bool, optional + """ + return _ReduceByBatch2D.apply(input_, reduce_mean) \ No newline at end of file diff --git a/colossalai/nn/layer/parallel_2d/_utils.py b/colossalai/nn/layer/parallel_2d/_utils.py index 65d3af2b0..012fec41c 100644 --- a/colossalai/nn/layer/parallel_2d/_utils.py +++ b/colossalai/nn/layer/parallel_2d/_utils.py @@ -1,14 +1,11 @@ -import os - from colossalai.context.parallel_mode import ParallelMode -from colossalai.context.process_group_initializer.initializer_2d import SUMMA_DIM from colossalai.core import global_context as gpc +from colossalai.global_variables import tensor_parallel_env as env def get_summa_dim_from_env() -> int: try: - summa_dim = os.environ[SUMMA_DIM] - summa_dim = int(summa_dim) + summa_dim = env.summa_dim assert summa_dim > 0, 'SUMMA_DIM must be larger than zero' return summa_dim diff --git a/colossalai/nn/layer/parallel_2d/layers.py b/colossalai/nn/layer/parallel_2d/layers.py index 51c62c75c..b6adbcecd 100644 --- a/colossalai/nn/layer/parallel_2d/layers.py +++ b/colossalai/nn/layer/parallel_2d/layers.py @@ -7,15 +7,16 @@ import torch.nn.functional as F from colossalai.communication import broadcast from colossalai.context import ParallelMode, seed from colossalai.core import global_context as gpc +from colossalai.global_variables import tensor_parallel_env as env from colossalai.nn import init as init from colossalai.registry import LAYERS -from colossalai.utils import get_current_device -from torch import Tensor, dtype +from colossalai.utils.cuda import get_current_device +from torch import Tensor from torch.nn import Parameter -from ..utils import divide, set_tensor_parallel_attribute_by_partition, to_2tuple from ..base_layer import ParallelLayer -from ._operation import Matmul_AB_2D, add_bias_2d, all_gather_weight_2d, classifier_2d, layernorm_2d +from ..utils import divide, set_tensor_parallel_attribute_by_partition, to_2tuple +from ._operation import * from ._utils import assert_summa_initialization, get_summa_dim_from_env @@ -43,7 +44,7 @@ class Linear2D(ParallelLayer): in_features: int, out_features: int, bias: bool = True, - dtype=None, + dtype: torch.dtype = None, skip_bias_add: bool = False, weight_initializer: Callable = 
init.kaiming_uniform_(a=math.sqrt(5)), bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1)): @@ -101,16 +102,16 @@ class Linear2D(ParallelLayer): if self.bias is not None: if self.skip_bias_add: - bias = add_bias_2d.apply(None, self.bias, self.hidden_size_per_partition, self.row_rank, self.col_rank, - ParallelMode.PARALLEL_2D_ROW, ParallelMode.PARALLEL_2D_COL, True, - self.data_parallel_rank, self.pipeline_parallel_rank, - self.pipeline_parallel_size, self.tensor_parallel_size) + bias = add_bias_2d(None, self.bias, self.hidden_size_per_partition, self.row_rank, self.col_rank, + ParallelMode.PARALLEL_2D_ROW, ParallelMode.PARALLEL_2D_COL, True, + self.data_parallel_rank, self.pipeline_parallel_rank, self.pipeline_parallel_size, + self.tensor_parallel_size) return output, bias else: - output = add_bias_2d.apply(output, self.bias, self.hidden_size_per_partition, self.row_rank, - self.col_rank, ParallelMode.PARALLEL_2D_ROW, ParallelMode.PARALLEL_2D_COL, - False, self.data_parallel_rank, self.pipeline_parallel_rank, - self.pipeline_parallel_size, self.tensor_parallel_size) + output = add_bias_2d(output, self.bias, self.hidden_size_per_partition, self.row_rank, self.col_rank, + ParallelMode.PARALLEL_2D_ROW, ParallelMode.PARALLEL_2D_COL, False, + self.data_parallel_rank, self.pipeline_parallel_rank, self.pipeline_parallel_size, + self.tensor_parallel_size) return output else: return output @@ -174,16 +175,14 @@ class LayerNorm2D(ParallelLayer): # this time 1/sqrt(Var_x + epsilon) Var_x = 1.0 / torch.sqrt(Var_x + self.variance_epsilon) - output = layernorm_2d.apply(x, E_x, Var_x, self.normalized_shape, ParallelMode.PARALLEL_2D_ROW, - ParallelMode.PARALLEL_2D_COL) - bias = add_bias_2d.apply(None, self.beta, self.partitioned_partition, self.row_rank, self.col_rank, - ParallelMode.PARALLEL_2D_ROW, ParallelMode.PARALLEL_2D_COL, True, - self.data_parallel_rank, self.pipeline_parallel_rank, self.pipeline_parallel_size, - self.tensor_parallel_size) - scale = add_bias_2d.apply(None, self.gamma, self.partitioned_partition, self.row_rank, self.col_rank, - ParallelMode.PARALLEL_2D_ROW, ParallelMode.PARALLEL_2D_COL, True, - self.data_parallel_rank, self.pipeline_parallel_rank, self.pipeline_parallel_size, - self.tensor_parallel_size) + output = layernorm_2d(x, E_x, Var_x, self.normalized_shape, ParallelMode.PARALLEL_2D_ROW, + ParallelMode.PARALLEL_2D_COL) + bias = add_bias_2d(None, self.beta, self.partitioned_partition, self.row_rank, self.col_rank, + ParallelMode.PARALLEL_2D_ROW, ParallelMode.PARALLEL_2D_COL, True, self.data_parallel_rank, + self.pipeline_parallel_rank, self.pipeline_parallel_size, self.tensor_parallel_size) + scale = add_bias_2d(None, self.gamma, self.partitioned_partition, self.row_rank, self.col_rank, + ParallelMode.PARALLEL_2D_ROW, ParallelMode.PARALLEL_2D_COL, True, self.data_parallel_rank, + self.pipeline_parallel_rank, self.pipeline_parallel_size, self.tensor_parallel_size) output = torch.addcmul(bias, scale, output) return output @@ -217,8 +216,8 @@ class PatchEmbedding2D(ParallelLayer): patch_size: int, in_chans: int, embed_size: int, - dtype: dtype = None, flatten: bool = True, + dtype: torch.dtype = None, weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)), bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1), position_embed_initializer: Callable = init.zeros_()): @@ -268,19 +267,21 @@ class PatchEmbedding2D(ParallelLayer): position_embed_initializer(self.pos_embed) def forward(self, input_: Tensor) -> Tensor: + input_ = 
split_tensor_2d(input_) + B, C, H, W = input_.shape assert H == self.img_size[0] and W == self.img_size[1], \ f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})." - weight = all_gather_weight_2d.apply(self.weight, 0, self.summa_dim, ParallelMode.PARALLEL_2D_COL) - bias = all_gather_weight_2d.apply(self.bias, 0, self.summa_dim, ParallelMode.PARALLEL_2D_COL) + weight = all_gather_tensor_2d(self.weight, 0, ParallelMode.PARALLEL_2D_COL) + bias = all_gather_tensor_2d(self.bias, 0, ParallelMode.PARALLEL_2D_COL) output = F.conv2d(input_, weight, bias, stride=self.patch_size) if self.flatten: output = output.flatten(2).transpose(1, 2) # BCHW -> BNC - cls_token = all_gather_weight_2d.apply(self.cls_token, -1, self.summa_dim, ParallelMode.PARALLEL_2D_COL) - pos_embed = all_gather_weight_2d.apply(self.pos_embed, -1, self.summa_dim, ParallelMode.PARALLEL_2D_COL) + cls_token = all_gather_tensor_2d(self.cls_token, -1, ParallelMode.PARALLEL_2D_COL) + pos_embed = all_gather_tensor_2d(self.pos_embed, -1, ParallelMode.PARALLEL_2D_COL) cls_token = cls_token.expand(output.shape[0], -1, -1) output = torch.cat((cls_token, output), dim=1) output = output + pos_embed @@ -310,7 +311,7 @@ class Embedding2D(ParallelLayer): num_embeddings: int, embedding_dim: int, padding_idx: int = None, - dtype: dtype = None, + dtype: torch.dtype = None, weight_initializer: Callable = init.normal_(), *args, **kwargs): @@ -347,13 +348,90 @@ class Embedding2D(ParallelLayer): self.weight[self.padding_idx].fill_(0) def forward(self, input_: Tensor) -> Tensor: - weight = all_gather_weight_2d.apply(self.weight, -1, self.summa_dim, ParallelMode.PARALLEL_2D_COL) + input_ = split_tensor_2d(input_) + weight = all_gather_tensor_2d(self.weight, -1, ParallelMode.PARALLEL_2D_COL) output = F.embedding(input_, weight, self.padding_idx, *self.embed_args, **self.embed_kwargs) return output +@LAYERS.register_module +class VocabParallelEmbedding2D(torch.nn.Module): + """Embedding parallelized in the vocabulary dimension. 
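# Note on the forward changes above (a sketch with placeholder names, not part of the patch):
# PatchEmbedding2D and Embedding2D now call split_tensor_2d on their input themselves, so
# callers pass the full, replicated batch and the layer takes its local shard along the
# PARALLEL_2D_COL group:
#
#     x = torch.randn(batch_size, 3, 224, 224)   # same full batch on every rank
#     tokens = patch_embed_2d(x)                  # the layer splits x internally
#
# Call sites that previously pre-split the batch before these layers would now split it twice,
# so they presumably drop their own split alongside this change.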
+ + :param num_embeddings: number of embeddings + :type num_embeddings: int + :param embedding_dim: dimension of embedding + :type embedding_dim: int + :param padding_idx: index of padding, defaults to None + :type padding_idx: int, optional + :param dtype: The dtype of parameters, defaults to None + :type dtype: torch.dtype, optional + :param weight_initializer: The intializer of weight, defaults to normal initializer + :type weight_initializer: typing.Callable, optional + :param args: Args used in F.embedding + :param kwargs: Kwargs used in F.embedding + """ + def __init__(self, + num_embeddings: int, + embedding_dim: int, + padding_idx: int = None, + dtype: torch.dtype = None, + weight_initializer: Callable = init.normal_(), + *args, + **kwargs): + super().__init__() + self.num_embeddings = num_embeddings + self.embed_dim = embedding_dim + self.padding_idx = padding_idx + self.embed_args = args + self.embed_kwargs = kwargs + + assert_summa_initialization() + self.summa_dim = get_summa_dim_from_env() + self.num_embeddings_per_partition = divide(self.num_embeddings, self.summa_dim) + self.embed_dim_per_partition = divide(self.embed_dim, self.summa_dim) + tensor_parallel_rank = gpc.get_local_rank(ParallelMode.PARALLEL_2D_COL) + self.vocab_start_index = tensor_parallel_rank * self.num_embeddings_per_partition + self.vocab_end_index = self.vocab_start_index + self.num_embeddings_per_partition + + self.weight = Parameter( + torch.empty((self.num_embeddings_per_partition, self.embed_dim_per_partition), + device=get_current_device(), + dtype=dtype)) + + self.reset_parameters(weight_initializer) + self._set_tensor_parallel_attributes() + env.vocab_parallel = True + + def _set_tensor_parallel_attributes(self): + set_tensor_parallel_attribute_by_partition(self.weight, self.summa_dim**2) + + def reset_parameters(self, weight_initializer) -> None: + with seed(ParallelMode.TENSOR): + fan_in, fan_out = self.num_embeddings, self.embed_dim + weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out) + self._fill_padding_idx_with_zero() + + def _fill_padding_idx_with_zero(self) -> None: + if self.padding_idx is not None: + with torch.no_grad(): + self.weight[self.padding_idx].fill_(0) + + def forward(self, input_: Tensor) -> Tensor: + input_mask = (input_ < self.vocab_start_index) | (input_ >= self.vocab_end_index) + masked_input = input_.clone() - self.vocab_start_index + masked_input[input_mask] = 0 + + output_parallel = F.embedding(masked_input, self.weight, self.padding_idx, *self.embed_args, + **self.embed_kwargs) + + output_parallel[input_mask, :] = 0. 
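        # Token ids outside [vocab_start_index, vocab_end_index) were shifted into range,
        # embedded, and then zeroed above, so each rank's output_parallel is nonzero only for
        # the tokens it owns. The reduce-scatter below sums these partial embeddings across the
        # 2D column group and returns the local shard along dim 0, matching the 2D input layout.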
+ output = reduce_scatter_tensor_2d(output_parallel, 0, ParallelMode.PARALLEL_2D_COL) + return output + + @LAYERS.register_module class Classifier2D(ParallelLayer): """ @@ -379,7 +457,7 @@ class Classifier2D(ParallelLayer): num_classes: int, weight: Parameter = None, bias: bool = True, - dtype: dtype = None, + dtype: torch.dtype = None, weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)), bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1)): super().__init__() @@ -429,7 +507,101 @@ class Classifier2D(ParallelLayer): def forward(self, input_: Tensor) -> Tensor: out_shape = input_.shape[:-1] + (self.num_classes, ) - return classifier_2d.apply(input_, self.weight, self.bias, self.summa_dim, out_shape, self.row_rank, - self.col_rank, ParallelMode.PARALLEL_2D_ROW, ParallelMode.PARALLEL_2D_COL, - self.data_parallel_rank, self.pipeline_parallel_rank, self.pipeline_parallel_size, - self.tensor_parallel_size) + return classifier_2d(input_, self.weight, self.bias, self.summa_dim, out_shape, self.row_rank, self.col_rank, + ParallelMode.PARALLEL_2D_ROW, ParallelMode.PARALLEL_2D_COL, self.data_parallel_rank, + self.pipeline_parallel_rank, self.pipeline_parallel_size, self.tensor_parallel_size) + + +@LAYERS.register_module +class VocabParallelClassifier2D(ParallelLayer): + """ + Vocab parallel classifier layer for 2D parallelism + + :param in_features: size of each input sample + :type in_features: int + :param num_classes: number of classes + :type num_classes: int + :param weight: weight of the classifier, defaults to True + :type weight: torch.nn.Parameter, optional + :param bias: If set to ``False``, the layer will not learn an additive bias, defaults to ``True`` + :type bias: bool, optional + :param dtype: The dtype of parameters, defaults to None + :type dtype: torch.dtype, optional + :param weight_initializer: The intializer of weight, defaults to kaiming uniform initializer + :type weight_initializer: typing.Callable, optional + :param bias_initializer: The intializer of bias, defaults to xavier uniform initializer + :type bias_initializer: typing.Callable, optional + """ + def __init__(self, + in_features: int, + num_classes: int, + weight: Parameter = None, + bias: bool = True, + dtype: torch.dtype = None, + weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)), + bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1)): + super().__init__() + + self.in_features = in_features + self.num_classes = num_classes + + # parallel setting + assert_summa_initialization() + self.row_rank = gpc.get_local_rank(ParallelMode.PARALLEL_2D_COL) + self.col_rank = gpc.get_local_rank(ParallelMode.PARALLEL_2D_ROW) + self.summa_dim = get_summa_dim_from_env() + + # partitioning dimension + self.input_size_per_partition = divide(in_features, self.summa_dim) + self.output_size_per_partition = divide(num_classes, self.summa_dim) + + # create weight, shape: [k/q, h/q] + factory_kwargs = {'device': get_current_device(), 'dtype': dtype} + if weight is not None: + self.weight = weight + self.has_weight = False + else: + self.weight = Parameter( + torch.empty(self.output_size_per_partition, self.input_size_per_partition, **factory_kwargs)) + self.has_weight = True + # create bias, shape: [h/q] + if bias: + self.bias = Parameter(torch.empty(divide(self.num_classes, self.summa_dim**2), **factory_kwargs)) + else: + self.bias = None + + # initialize parameters + with seed(ParallelMode.TENSOR): + self.reset_parameters(weight_initializer, bias_initializer) + 
self._set_tensor_parallel_attributes() + env.vocab_parallel = True + + def _set_tensor_parallel_attributes(self): + if self.has_weight: + set_tensor_parallel_attribute_by_partition(self.weight, self.summa_dim**2) + if self.bias is not None: + set_tensor_parallel_attribute_by_partition(self.bias, self.summa_dim**2) + + def reset_parameters(self, weight_initializer, bias_initializer) -> None: + fan_in, fan_out = self.in_features, self.num_classes + if self.has_weight: + weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out) + if self.bias is not None: + bias_initializer(self.bias, fan_in=fan_in) + + def forward(self, x: Tensor) -> Tensor: + # input: [m/q, n/q, k/q] + # output: [m/q, n/q, h/q] + out_shape = x.shape[:-1] + (self.output_size_per_partition, ) + + output = Matmul_ABT_2D.apply(x, self.weight, self.summa_dim, out_shape, self.row_rank, self.col_rank, + ParallelMode.PARALLEL_2D_ROW, ParallelMode.PARALLEL_2D_COL, + self.data_parallel_rank, self.pipeline_parallel_rank, self.pipeline_parallel_size, + self.tensor_parallel_size) + + if self.bias is not None: + output = add_bias_2d(output, self.bias, self.output_size_per_partition, self.row_rank, self.col_rank, + ParallelMode.PARALLEL_2D_ROW, ParallelMode.PARALLEL_2D_COL, False, + self.data_parallel_rank, self.pipeline_parallel_rank, self.pipeline_parallel_size, + self.tensor_parallel_size) + return output diff --git a/colossalai/nn/layer/parallel_2p5d/__init__.py b/colossalai/nn/layer/parallel_2p5d/__init__.py index 202c948c5..5ca351605 100644 --- a/colossalai/nn/layer/parallel_2p5d/__init__.py +++ b/colossalai/nn/layer/parallel_2p5d/__init__.py @@ -1,7 +1,8 @@ from ._operation import reduce_by_batch_2p5d, split_tensor_2p5d -from .layers import Classifier2p5D, Embedding2p5D, LayerNorm2p5D, Linear2p5D, PatchEmbedding2p5D +from .layers import (Classifier2p5D, Embedding2p5D, LayerNorm2p5D, Linear2p5D, PatchEmbedding2p5D, + VocabParallelClassifier2p5D, VocabParallelEmbedding2p5D) __all__ = [ 'split_tensor_2p5d', 'reduce_by_batch_2p5d', 'Linear2p5D', 'LayerNorm2p5D', 'Classifier2p5D', 'PatchEmbedding2p5D', - 'Embedding2p5D' + 'Embedding2p5D', 'VocabParallelClassifier2p5D', 'VocabParallelEmbedding2p5D' ] diff --git a/colossalai/nn/layer/parallel_2p5d/_operation.py b/colossalai/nn/layer/parallel_2p5d/_operation.py index 177919cf7..8974ff377 100644 --- a/colossalai/nn/layer/parallel_2p5d/_operation.py +++ b/colossalai/nn/layer/parallel_2p5d/_operation.py @@ -22,42 +22,7 @@ def get_parallel_rank(parallel_mode: ParallelMode): return gpc.get_local_rank(parallel_mode) -def split_tensor_2p5d(input_: Tensor, dim: int = 0) -> Tensor: - return torch.chunk(input_, gpc.get_world_size(ParallelMode.PARALLEL_2P5D_COL), - dim=dim)[gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_COL)].contiguous() - - -class classifier_2p5d(torch.autograd.Function): - """ - Classifier - - :param a: matrix :math:`A` - :type a: torch.tensor - :param b: matrix :math:`B` - :type b: torch.tensor - :param bias: matrix of bias - :type bias: torch.tensor, optional - :param tesseract_dim: dimension of TESSERACT fo 2.5D parallelism - :type tesseract_dim: int - :param out_shape: shape of output tensor - :type out_shape: tuple - :param row_rank: the rank of row - :type row_rank: int - :param col_rank: the rank of column - :type col_rank: int - :param row_parallel_mode: row parallel mode - :type row_parallel_mode: colossalai.context.parallel_mode.ParallelMode - :param col_parallel_mode: column parallel mode - :type col_parallel_mode: colossalai.context.parallel_mode.ParallelMode - 
:param data_parallel_rank: data parallel rank - :type data_parallel_rank: int - :param pipeline_parallel_rank: pipeline parallel rank - :type pipeline_parallel_rank: int - :param pipeline_parallel_size: pipeline parallel size - :type pipeline_parallel_size: int - :param tensor_parallel_size: tensor parallel size - :type tensor_parallel_size: int - """ +class _Classifier2p5D(torch.autograd.Function): @staticmethod @custom_fwd(cast_inputs=torch.float16) def forward( @@ -122,12 +87,54 @@ class classifier_2p5d(torch.autograd.Function): B_grad = reduce_scatter(B_grad, -1, ctx.col_parallel_mode) B_grad = B_grad.reshape(ctx.B_shape) - bias_grad = torch.sum(output_grad, dim=tuple(range(output_grad.ndim - 1))) - bias_grad = all_reduce(bias_grad, ctx.col_parallel_mode) + if ctx.use_bias: + bias_grad = torch.sum(output_grad, dim=tuple(range(output_grad.ndim - 1))) + bias_grad = all_reduce(bias_grad, ctx.col_parallel_mode) + else: + bias_grad = None return A_grad, B_grad, bias_grad, None, None, None, None, None, None, None, None, None, None +def classifier_2p5d(A: Tensor, B: Tensor, bias, tesseract_dim: int, out_shape: Tuple[int, + ...], row_rank: int, col_rank: int, + row_parallel_mode: ParallelMode, col_parallel_mode: ParallelMode, data_parallel_rank: int, + pipeline_parallel_rank: int, pipeline_parallel_size: int, tensor_parallel_size: int) -> Tensor: + """ + Classifier + + :param a: matrix :math:`A` + :type a: torch.tensor + :param b: matrix :math:`B` + :type b: torch.tensor + :param bias: matrix of bias + :type bias: torch.tensor, optional + :param tesseract_dim: dimension of TESSERACT fo 2.5D parallelism + :type tesseract_dim: int + :param out_shape: shape of output tensor + :type out_shape: tuple + :param row_rank: the rank of row + :type row_rank: int + :param col_rank: the rank of column + :type col_rank: int + :param row_parallel_mode: row parallel mode + :type row_parallel_mode: colossalai.context.parallel_mode.ParallelMode + :param col_parallel_mode: column parallel mode + :type col_parallel_mode: colossalai.context.parallel_mode.ParallelMode + :param data_parallel_rank: data parallel rank + :type data_parallel_rank: int + :param pipeline_parallel_rank: pipeline parallel rank + :type pipeline_parallel_rank: int + :param pipeline_parallel_size: pipeline parallel size + :type pipeline_parallel_size: int + :param tensor_parallel_size: tensor parallel size + :type tensor_parallel_size: int + """ + return _Classifier2p5D.apply(A, B, bias, tesseract_dim, out_shape, row_rank, col_rank, row_parallel_mode, + col_parallel_mode, data_parallel_rank, pipeline_parallel_rank, pipeline_parallel_size, + tensor_parallel_size) + + class Matmul_AB_2p5D(torch.autograd.Function): """ Matrix multiplication for :math:`C = AB` @@ -522,37 +529,7 @@ class Matmul_ATB_2p5D(torch.autograd.Function): return A_grad, B_grad, None, None, None, None, None, None, None, None, None, None, None, None, None -class Add_Bias_2p5D(torch.autograd.Function): - """ - Matrix add bias: :math:`C = A + b` - - :param input: matrix :math:`A` - :type input: torch.tensor - :param bias: matrix :math:`b` - :type bias: torch.tensor - :param output_size_per_partition: output size in each partition - :type output_size_per_partition: int - :param tesseract_dim: dimension of TESSERACT fo 2.5D parallelism - :type tesseract_dim: int - :param row_rank: the rank of row - :type row_rank: int - :param col_rank: the rank of column - :type col_rank: int - :param row_parallel_mode: row parallel mode - :type row_parallel_mode: 
colossalai.context.parallel_mode.ParallelMode - :param col_parallel_mode: column parallel mode - :type col_parallel_mode: colossalai.context.parallel_mode.ParallelMode - :param skip_bias_add: If set to ``True``, it will skip bias add for linear layer, which is preserved for kernel fusion - :type skip_bias_add: bool - :param data_parallel_rank: data parallel rank - :type data_parallel_rank: int - :param pipeline_parallel_rank: pipeline parallel rank - :type pipeline_parallel_rank: int - :param pipeline_parallel_size: pipeline parallel size - :type pipeline_parallel_size: int - :param tensor_parallel_size: tensor parallel size - :type tensor_parallel_size: int - """ +class _Add_Bias_2p5D(torch.autograd.Function): @staticmethod @custom_fwd(cast_inputs=torch.float16) def forward(ctx: Any, input: Tensor, bias: Tensor, output_size_per_partition: int, tesseract_dim: int, @@ -621,7 +598,46 @@ class Add_Bias_2p5D(torch.autograd.Function): return output_grad, reduce_tmp, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None -class layernorm_2p5d(torch.autograd.Function): +def add_bias_2p5d(input: Tensor, bias: Tensor, output_size_per_partition: int, tesseract_dim: int, row_rank: int, + col_rank: int, dep_rank: int, col_parallel_mode: ParallelMode, skip_bias_add: bool, + data_parallel_rank: int, pipeline_parallel_rank: int, pipeline_parallel_size: int, + tensor_parallel_size: int) -> Tensor: + """ + Matrix add bias: :math:`C = A + b` + + :param input: matrix :math:`A` + :type input: torch.tensor + :param bias: matrix :math:`b` + :type bias: torch.tensor + :param output_size_per_partition: output size in each partition + :type output_size_per_partition: int + :param tesseract_dim: dimension of TESSERACT fo 2.5D parallelism + :type tesseract_dim: int + :param row_rank: the rank of row + :type row_rank: int + :param col_rank: the rank of column + :type col_rank: int + :param row_parallel_mode: row parallel mode + :type row_parallel_mode: colossalai.context.parallel_mode.ParallelMode + :param col_parallel_mode: column parallel mode + :type col_parallel_mode: colossalai.context.parallel_mode.ParallelMode + :param skip_bias_add: If set to ``True``, it will skip bias add for linear layer, which is preserved for kernel fusion + :type skip_bias_add: bool + :param data_parallel_rank: data parallel rank + :type data_parallel_rank: int + :param pipeline_parallel_rank: pipeline parallel rank + :type pipeline_parallel_rank: int + :param pipeline_parallel_size: pipeline parallel size + :type pipeline_parallel_size: int + :param tensor_parallel_size: tensor parallel size + :type tensor_parallel_size: int + """ + return _Add_Bias_2p5D.apply(input, bias, output_size_per_partition, tesseract_dim, row_rank, col_rank, dep_rank, + col_parallel_mode, skip_bias_add, data_parallel_rank, pipeline_parallel_rank, + pipeline_parallel_size, tensor_parallel_size) + + +class _Layernorm2p5D(torch.autograd.Function): """ Layernorm @@ -671,7 +687,43 @@ class layernorm_2p5d(torch.autograd.Function): return input_grad, None, None, None, None, None, None -class all_gather_weight_2p5d(torch.autograd.Function): +def layernorm_2p5d(input: Tensor, E_x: Tensor, Var_x: Tensor, hidden_size: int, + row_parallel_mode: ParallelMode) -> Tensor: + """ + Layernorm + + :param input: input maxtrix + :type input: torch.tensor + :param E_x: mean + :type E_x: torch.tensor + :param Var_x: variance + :type Var_x: torch.tensor + :param hidden_size: hidden size + :type hidden_size: int + :param row_parallel_mode: row 
parallel mode + :type row_parallel_mode: colossalai.context.parallel_mode.ParallelMode + """ + return _Layernorm2p5D.apply(input, E_x, Var_x, hidden_size, row_parallel_mode) + + +class _AllGatherTensor2p5D(torch.autograd.Function): + @staticmethod + @custom_fwd(cast_inputs=torch.float16) + def forward(ctx: Any, inputs: Tensor, dim: int, col_parallel_mode: ParallelMode) -> Tensor: + ctx.dim = dim + ctx.col_parallel_mode = col_parallel_mode + + outputs = all_gather(inputs, dim, col_parallel_mode) + return outputs + + @staticmethod + @custom_bwd + def backward(ctx: Any, output_grad: Tensor) -> Tuple[Tensor, ...]: + grad = reduce_scatter(output_grad, ctx.dim, ctx.col_parallel_mode) + return grad.contiguous(), None, None + + +def all_gather_tensor_2p5d(inputs: Tensor, dim: int, col_parallel_mode: ParallelMode) -> Tensor: """ all gather the weight of 2.5D parallelism @@ -684,21 +736,7 @@ class all_gather_weight_2p5d(torch.autograd.Function): :param col_parallel_mode: column parallel mode :type col_parallel_mode: colossalai.context.parallel_mode.ParallelMode """ - @staticmethod - @custom_fwd(cast_inputs=torch.float16) - def forward(ctx: Any, inputs: Tensor, dim: int, tesseract_dim: int, col_parallel_mode: ParallelMode) -> Tensor: - ctx.dim = dim - ctx.tesseract_dim = tesseract_dim - ctx.row_rank = gpc.get_local_rank(col_parallel_mode) - - outputs = all_gather(inputs, dim, col_parallel_mode) - return outputs - - @staticmethod - @custom_bwd - def backward(ctx: Any, output_grad: Tensor) -> Tuple[Tensor, ...]: - grad = output_grad.chunk(ctx.tesseract_dim, dim=ctx.dim)[ctx.row_rank] - return grad.contiguous(), None, None, None + return _AllGatherTensor2p5D.apply(inputs, dim, col_parallel_mode) class SplitFirst(torch.autograd.Function): @@ -737,10 +775,10 @@ def split_tensor_2p5d(input_: Tensor, dim: int = 0) -> Tensor: :param input_: Input tensor :param dim: Specified dimension in which to split - + :type input_: torch.Tensor :type dim: int, optional - + :return output: Splitted tensor :rtype output: torch.Tensor """ @@ -750,9 +788,49 @@ def split_tensor_2p5d(input_: Tensor, dim: int = 0) -> Tensor: dim=dim)[gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_COL)].contiguous() -class reduce_by_batch_2p5d(torch.autograd.Function): - """All-reduce the input from the model parallel region. +class _ReduceTensor2p5D(torch.autograd.Function): + @staticmethod + def forward(ctx, input_, parallel_mode): + return all_reduce(input_, parallel_mode) + + @staticmethod + def backward(ctx, output_grad): + return output_grad, None + + +def reduce_tensor_2p5d(input_: Tensor, parallel_mode: ParallelMode) -> Tensor: """ + All-reduce the input. + + :param input_: input tensor + :param parallel_mode: parallel mode + """ + return _ReduceTensor2p5D.apply(input_, parallel_mode) + + +class _ReduceScatterTensor2p5D(torch.autograd.Function): + @staticmethod + def forward(ctx, input_, dim, parallel_mode): + ctx.dim = dim + ctx.parallel_mode = parallel_mode + return reduce_scatter(input_, dim, parallel_mode) + + @staticmethod + def backward(ctx, output_grad): + return all_gather(output_grad, ctx.dim, ctx.parallel_mode), None, None + + +def reduce_scatter_tensor_2p5d(input_: Tensor, dim: int, parallel_mode: ParallelMode) -> Tensor: + """ + Reduce-scatter the input. 
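# Note on the backward change above: the old all_gather_weight_2p5d backward simply took this
# rank's chunk of the output gradient, which is only correct when every rank in the column
# group sees an identical gradient. _AllGatherTensor2p5D (like _AllGatherTensor2D) instead
# reduce-scatters the gradient, summing the contributions from all ranks before returning the
# local shard, which matters now that the batch is split across the column group and the
# ranks' gradients differ.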
+ + :param input_: input tensor + :param parallel_mode: parallel mode + """ + return _ReduceScatterTensor2p5D.apply(input_, dim, parallel_mode) + + +class _RreduceByBatch2p5D(torch.autograd.Function): @staticmethod def symbolic(graph, input_, reduce_mean: bool = False): output = all_reduce(input_, ParallelMode.PARALLEL_2P5D_COL) @@ -764,12 +842,6 @@ class reduce_by_batch_2p5d(torch.autograd.Function): @staticmethod @custom_fwd(cast_inputs=torch.float32) def forward(ctx, input_, reduce_mean: bool = False): - """ - :param input_: input maxtrix - :type input_: torch.tensor - :param reduce_mean: If set to ``True``, it will divide the output by column parallel size, default to False - :type reduce_mean: int, optional - """ output = all_reduce(input_, ParallelMode.PARALLEL_2P5D_COL) ctx.reduce_mean = reduce_mean if reduce_mean: @@ -785,3 +857,15 @@ class reduce_by_batch_2p5d(torch.autograd.Function): return output_grad / ctx.reduce_size, None else: return output_grad, None + + +def reduce_by_batch_2p5d(input_, reduce_mean: bool = False) -> Tensor: + """ + All-reduce the input from the model parallel region. + + :param input_: input maxtrix + :type input_: torch.tensor + :param reduce_mean: If set to ``True``, it will divide the output by column parallel size, default to False + :type reduce_mean: bool, optional + """ + return _RreduceByBatch2p5D.apply(input_, reduce_mean) \ No newline at end of file diff --git a/colossalai/nn/layer/parallel_2p5d/_utils.py b/colossalai/nn/layer/parallel_2p5d/_utils.py index c9c6b194f..bcab619ca 100644 --- a/colossalai/nn/layer/parallel_2p5d/_utils.py +++ b/colossalai/nn/layer/parallel_2p5d/_utils.py @@ -1,13 +1,12 @@ -import os - from colossalai.context.parallel_mode import ParallelMode from colossalai.core import global_context as gpc +from colossalai.global_variables import tensor_parallel_env as env def get_tesseract_dim_dep_from_env(): try: - tesseract_dim = int(os.environ['TESSERACT_DIM']) - tesseract_dep = int(os.environ['TESSERACT_DEP']) + tesseract_dim = env.tesseract_dim + tesseract_dep = env.tesseract_dep assert tesseract_dim > 0, 'TESSERACT_DIM must be larger than zero' assert tesseract_dep > 0, 'TESSERACT_DEP must be larger than zero' return tesseract_dim, tesseract_dep diff --git a/colossalai/nn/layer/parallel_2p5d/layers.py b/colossalai/nn/layer/parallel_2p5d/layers.py index 170cc2a34..7dd17f21b 100644 --- a/colossalai/nn/layer/parallel_2p5d/layers.py +++ b/colossalai/nn/layer/parallel_2p5d/layers.py @@ -7,16 +7,18 @@ import torch.nn.functional as F from colossalai.communication import broadcast from colossalai.context import ParallelMode, seed from colossalai.core import global_context as gpc +from colossalai.global_variables import tensor_parallel_env as env from colossalai.nn import init as init from colossalai.registry import LAYERS -from colossalai.utils import get_current_device -from torch import Tensor, dtype +from colossalai.utils.cuda import get_current_device +from torch import Tensor from torch.nn import Parameter from ..base_layer import ParallelLayer -from ..utils import (divide, set_tensor_parallel_attribute_by_partition, to_2tuple) -from ._operation import (Add_Bias_2p5D, Matmul_AB_2p5D, all_gather_weight_2p5d, classifier_2p5d, layernorm_2p5d) -from ._utils import (assert_tesseract_initialization, get_tesseract_dim_dep_from_env) +from ..utils import divide, set_tensor_parallel_attribute_by_partition, to_2tuple +from ._operation import (add_bias_2p5d, Matmul_AB_2p5D, Matmul_ABT_2p5D, all_gather_tensor_2p5d, classifier_2p5d, + 
layernorm_2p5d, reduce_scatter_tensor_2p5d, split_tensor_2p5d) +from ._utils import assert_tesseract_initialization, get_tesseract_dim_dep_from_env @LAYERS.register_module @@ -41,7 +43,7 @@ class Linear2p5D(ParallelLayer): in_features: int, out_features: int, bias: bool = True, - dtype: dtype = None, + dtype: torch.dtype = None, skip_bias_add: bool = False, weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)), bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1)): @@ -112,17 +114,16 @@ class Linear2p5D(ParallelLayer): if self.bias is not None: if self.skip_bias_add: - bias = Add_Bias_2p5D.apply(None, self.bias, self.hidden_size_per_partition, self.tesseract_dim, - self.row_rank, self.col_rank, self.dep_rank, ParallelMode.PARALLEL_2P5D_COL, - True, self.data_parallel_rank, self.pipeline_parallel_rank, - self.pipeline_parallel_size, self.tensor_parallel_size) + bias = add_bias_2p5d(None, self.bias, self.hidden_size_per_partition, self.tesseract_dim, self.row_rank, + self.col_rank, self.dep_rank, ParallelMode.PARALLEL_2P5D_COL, True, + self.data_parallel_rank, self.pipeline_parallel_rank, self.pipeline_parallel_size, + self.tensor_parallel_size) return output, bias else: - output = Add_Bias_2p5D.apply(output, self.bias, self.hidden_size_per_partition, self.tesseract_dim, - self.row_rank, self.col_rank, self.dep_rank, - ParallelMode.PARALLEL_2P5D_COL, False, self.data_parallel_rank, - self.pipeline_parallel_rank, self.pipeline_parallel_size, - self.tensor_parallel_size) + output = add_bias_2p5d(output, self.bias, self.hidden_size_per_partition, self.tesseract_dim, + self.row_rank, self.col_rank, self.dep_rank, ParallelMode.PARALLEL_2P5D_COL, + False, self.data_parallel_rank, self.pipeline_parallel_rank, + self.pipeline_parallel_size, self.tensor_parallel_size) return output else: return output @@ -187,15 +188,15 @@ class LayerNorm2p5D(ParallelLayer): # this time 1/sqrt(Var_x + epsilon) Var_x = 1.0 / torch.sqrt(Var_x + self.variance_epsilon) - output = layernorm_2p5d.apply(x, E_x, Var_x, self.normalized_shape, ParallelMode.PARALLEL_2P5D_ROW) - bias = Add_Bias_2p5D.apply(None, self.beta, self.partitioned_partition, self.tesseract_dim, self.row_rank, - self.col_rank, self.dep_rank, ParallelMode.PARALLEL_2P5D_COL, True, - self.data_parallel_rank, self.pipeline_parallel_rank, self.pipeline_parallel_size, - self.tensor_parallel_size) - scale = Add_Bias_2p5D.apply(None, self.gamma, self.partitioned_partition, self.tesseract_dim, self.row_rank, - self.col_rank, self.dep_rank, ParallelMode.PARALLEL_2P5D_COL, True, - self.data_parallel_rank, self.pipeline_parallel_rank, self.pipeline_parallel_size, - self.tensor_parallel_size) + output = layernorm_2p5d(x, E_x, Var_x, self.normalized_shape, ParallelMode.PARALLEL_2P5D_ROW) + bias = add_bias_2p5d(None, self.beta, self.partitioned_partition, self.tesseract_dim, self.row_rank, + self.col_rank, self.dep_rank, ParallelMode.PARALLEL_2P5D_COL, True, + self.data_parallel_rank, self.pipeline_parallel_rank, self.pipeline_parallel_size, + self.tensor_parallel_size) + scale = add_bias_2p5d(None, self.gamma, self.partitioned_partition, self.tesseract_dim, self.row_rank, + self.col_rank, self.dep_rank, ParallelMode.PARALLEL_2P5D_COL, True, + self.data_parallel_rank, self.pipeline_parallel_rank, self.pipeline_parallel_size, + self.tensor_parallel_size) output = torch.addcmul(bias, scale, output) return output @@ -229,8 +230,8 @@ class PatchEmbedding2p5D(ParallelLayer): patch_size: int, in_chans: int, embed_size: int, - dtype: dtype = None, 
flatten: bool = True, + dtype: torch.dtype = None, weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)), bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1), position_embed_initializer: Callable = init.zeros_()): @@ -280,19 +281,21 @@ class PatchEmbedding2p5D(ParallelLayer): position_embed_initializer(self.pos_embed) def forward(self, input_: Tensor) -> Tensor: + input_ = split_tensor_2p5d(input_, 0) + B, C, H, W = input_.shape assert H == self.img_size[0] and W == self.img_size[1], \ f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})." - weight = all_gather_weight_2p5d.apply(self.weight, 0, self.tesseract_dim, ParallelMode.PARALLEL_2P5D_COL) - bias = all_gather_weight_2p5d.apply(self.bias, 0, self.tesseract_dim, ParallelMode.PARALLEL_2P5D_COL) + weight = all_gather_tensor_2p5d(self.weight, 0, ParallelMode.PARALLEL_2P5D_COL) + bias = all_gather_tensor_2p5d(self.bias, 0, ParallelMode.PARALLEL_2P5D_COL) output = F.conv2d(input_, weight, bias, stride=self.patch_size) if self.flatten: output = output.flatten(2).transpose(1, 2) # BCHW -> BNC - cls_token = all_gather_weight_2p5d.apply(self.cls_token, -1, self.tesseract_dim, ParallelMode.PARALLEL_2P5D_COL) - pos_embed = all_gather_weight_2p5d.apply(self.pos_embed, -1, self.tesseract_dim, ParallelMode.PARALLEL_2P5D_COL) + cls_token = all_gather_tensor_2p5d(self.cls_token, -1, ParallelMode.PARALLEL_2P5D_COL) + pos_embed = all_gather_tensor_2p5d(self.pos_embed, -1, ParallelMode.PARALLEL_2P5D_COL) cls_token = cls_token.expand(output.shape[0], -1, -1) output = torch.cat((cls_token, output), dim=1) output = output + pos_embed @@ -322,7 +325,7 @@ class Embedding2p5D(ParallelLayer): num_embeddings: int, embedding_dim: int, padding_idx: int = None, - dtype: dtype = None, + dtype: torch.dtype = None, weight_initializer: Callable = init.normal_(), *args, **kwargs): @@ -359,13 +362,95 @@ class Embedding2p5D(ParallelLayer): self.weight[self.padding_idx].fill_(0) def forward(self, input_: Tensor) -> Tensor: - weight = all_gather_weight_2p5d.apply(self.weight, -1, self.tesseract_dim, ParallelMode.PARALLEL_2P5D_COL) + input_ = split_tensor_2p5d(input_, 0) + + weight = all_gather_tensor_2p5d(self.weight, -1, ParallelMode.PARALLEL_2P5D_COL) output = F.embedding(input_, weight, self.padding_idx, *self.embed_args, **self.embed_kwargs) return output +@LAYERS.register_module +class VocabParallelEmbedding2p5D(torch.nn.Module): + """Embedding parallelized in the vocabulary dimension. 
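# How the vocab-parallel embedding shards its table (an illustrative sketch with made-up
# sizes): with num_embeddings=50304, embedding_dim=1024 and tesseract_dim=2, each rank in the
# 2.5D column group holds a (25152, 512) weight shard, i.e. both the vocabulary and the hidden
# dimension are partitioned, whereas Embedding2p5D keeps the vocabulary whole and all-gathers
# its weight along the embedding dimension in forward.
#
#     embed = VocabParallelEmbedding2p5D(50304, 1024)
#     out = embed(token_ids)   # token_ids: full batch of ids; out: batch shard along dim 0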
+ + :param num_embeddings: number of embeddings + :type num_embeddings: int + :param embedding_dim: dimension of embedding + :type embedding_dim: int + :param padding_idx: index of padding, defaults to None + :type padding_idx: int, optional + :param dtype: The dtype of parameters, defaults to None + :type dtype: torch.dtype, optional + :param weight_initializer: The intializer of weight, defaults to normal initializer + :type weight_initializer: typing.Callable, optional + :param args: Args used in F.embedding + :param kwargs: Kwargs used in F.embedding + """ + def __init__(self, + num_embeddings: int, + embedding_dim: int, + padding_idx: int = None, + dtype: torch.dtype = None, + weight_initializer: Callable = init.normal_(), + *args, + **kwargs): + super().__init__() + self.num_embeddings = num_embeddings + self.embed_dim = embedding_dim + self.padding_idx = padding_idx + self.embed_args = args + self.embed_kwargs = kwargs + + assert_tesseract_initialization() + self.tesseract_dim, self.tesseract_dep = get_tesseract_dim_dep_from_env() + self.num_embeddings_per_partition = divide(self.num_embeddings, self.tesseract_dim) + self.embed_dim_per_partition = divide(self.embed_dim, self.tesseract_dim) + tensor_parallel_rank = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_COL) + self.vocab_start_index = tensor_parallel_rank * self.num_embeddings_per_partition + self.vocab_end_index = self.vocab_start_index + self.num_embeddings_per_partition + + self.weight = Parameter( + torch.empty((self.num_embeddings_per_partition, self.embed_dim_per_partition), + device=get_current_device(), + dtype=dtype)) + + self.reset_parameters(weight_initializer) + self._set_tensor_parallel_attributes() + env.vocab_parallel = True + + def _set_tensor_parallel_attributes(self): + set_tensor_parallel_attribute_by_partition(self.weight, self.tesseract_dim**2) + + def reset_parameters(self, weight_initializer) -> None: + with seed(ParallelMode.TENSOR): + fan_in, fan_out = self.num_embeddings, self.embed_dim + weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out) + self._fill_padding_idx_with_zero() + + def _fill_padding_idx_with_zero(self) -> None: + if self.padding_idx is not None: + with torch.no_grad(): + self.weight[self.padding_idx].fill_(0) + + def forward(self, input_: Tensor) -> Tensor: + # Build the mask. + input_mask = (input_ < self.vocab_start_index) | (input_ >= self.vocab_end_index) + # Mask the input. + masked_input = input_.clone() - self.vocab_start_index + masked_input[input_mask] = 0 + + output_parallel = F.embedding(masked_input, self.weight, self.padding_idx, *self.embed_args, + **self.embed_kwargs) + + # Mask the output embedding. + output_parallel[input_mask, :] = 0. + # Reduce across all the model parallel GPUs. 
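        # Concretely: for any given token, every rank except the one owning it contributed a
        # zero row above, so the reduce-scatter along dim 0 of the 2.5D column group both
        # recovers the full embeddings and hands back the local batch shard, mirroring the 2D
        # vocab-parallel path.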
+ output = reduce_scatter_tensor_2p5d(output_parallel, 0, ParallelMode.PARALLEL_2P5D_COL) + return output + + @LAYERS.register_module class Classifier2p5D(ParallelLayer): """ @@ -391,7 +476,7 @@ class Classifier2p5D(ParallelLayer): num_classes: int, weight: Parameter = None, bias: bool = True, - dtype: dtype = None, + dtype: torch.dtype = None, weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)), bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1)): super().__init__() @@ -442,7 +527,114 @@ class Classifier2p5D(ParallelLayer): def forward(self, input_: Tensor) -> Tensor: out_shape = input_.shape[:-1] + (self.num_classes, ) - return classifier_2p5d.apply(input_, self.weight, self.bias, self.tesseract_dim, out_shape, self.row_rank, - self.col_rank, ParallelMode.PARALLEL_2P5D_ROW, ParallelMode.PARALLEL_2P5D_COL, - self.data_parallel_rank, self.pipeline_parallel_rank, self.pipeline_parallel_size, - self.tensor_parallel_size) + return classifier_2p5d(input_, self.weight, self.bias, self.tesseract_dim, out_shape, self.row_rank, + self.col_rank, ParallelMode.PARALLEL_2P5D_ROW, ParallelMode.PARALLEL_2P5D_COL, + self.data_parallel_rank, self.pipeline_parallel_rank, self.pipeline_parallel_size, + self.tensor_parallel_size) + + +@LAYERS.register_module +class VocabParallelClassifier2p5D(ParallelLayer): + """ + Vocab parallel classifier layer for 2.5D parallelism + + :param in_features: size of each input sample + :type in_features: int + :param num_classes: number of classes + :type num_classes: int + :param weight: weight of the classifier, defaults to True + :type weight: torch.nn.Parameter, optional + :param bias: If set to ``False``, the layer will not learn an additive bias, defaults to ``True`` + :type bias: bool, optional + :param dtype: The dtype of parameters, defaults to None + :type dtype: torch.dtype, optional + :param weight_initializer: The intializer of weight, defaults to kaiming uniform initializer + :type weight_initializer: typing.Callable, optional + :param bias_initializer: The intializer of bias, defaults to xavier uniform initializer + :type bias_initializer: typing.Callable, optional + """ + def __init__(self, + in_features: int, + num_classes: int, + weight: Parameter = None, + bias: bool = True, + dtype: torch.dtype = None, + weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)), + bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1)): + super().__init__() + + self.in_features = in_features + self.num_classes = num_classes + + # parallel setting + assert_tesseract_initialization() + self.row_rank = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_COL) + self.col_rank = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_ROW) + self.dep_rank = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_DEP) + self.tesseract_dim, _ = get_tesseract_dim_dep_from_env() + + # partitioning dimension + self.input_size_per_partition = divide(in_features, self.tesseract_dim) + self.hidden_size_per_partition = divide(num_classes, self.tesseract_dim) + + # create weight, shape: [k/q, h/q] + factory_kwargs = {'device': get_current_device(), 'dtype': dtype} + if weight is not None: + self.weight = weight + self.has_weight = False + else: + self.weight = Parameter( + torch.empty(self.hidden_size_per_partition, self.input_size_per_partition, **factory_kwargs)) + self.has_weight = True + # create bias, shape: [h/q] + if bias: + self.bias = Parameter(torch.empty(self.hidden_size_per_partition, **factory_kwargs)) + else: + self.bias = None + + # initialize 
parameters + with seed(ParallelMode.TENSOR): + self.reset_parameters(weight_initializer, bias_initializer) + self._set_tensor_parallel_attributes() + env.vocab_parallel = True + + def _set_tensor_parallel_attributes(self): + if self.has_weight: + set_tensor_parallel_attribute_by_partition(self.weight, self.tesseract_dim**2) + if self.bias is not None: + set_tensor_parallel_attribute_by_partition(self.bias, self.tesseract_dim) + + def reset_parameters(self, weight_initializer, bias_initializer) -> None: + fan_in, fan_out = self.in_features, self.num_classes + if self.has_weight: + weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out) + if self.bias is not None: + bias_initializer(self.bias, fan_in=fan_in) + + def forward(self, x: Tensor) -> Tensor: + # input: [m/dq, n/q, k/q] + # output: [m/dq, n/q, h/q] + out_shape = x.shape[:-1] + (self.hidden_size_per_partition, ) + + output = Matmul_ABT_2p5D.apply( + x, + self.weight, + self.tesseract_dim, + out_shape, + self.row_rank, + self.col_rank, + self.dep_rank, + ParallelMode.PARALLEL_2P5D_ROW, + ParallelMode.PARALLEL_2P5D_COL, + self.data_parallel_rank, + self.pipeline_parallel_rank, + self.pipeline_parallel_size, + self.tensor_parallel_size, + ) + + if self.bias is not None: + output = add_bias_2p5d(output, self.bias, self.hidden_size_per_partition, self.tesseract_dim, self.row_rank, + self.col_rank, self.dep_rank, ParallelMode.PARALLEL_2P5D_COL, False, + self.data_parallel_rank, self.pipeline_parallel_rank, self.pipeline_parallel_size, + self.tensor_parallel_size) + return output diff --git a/colossalai/nn/layer/parallel_3d/__init__.py b/colossalai/nn/layer/parallel_3d/__init__.py index 46eeacda1..9ae255b44 100644 --- a/colossalai/nn/layer/parallel_3d/__init__.py +++ b/colossalai/nn/layer/parallel_3d/__init__.py @@ -1,6 +1,8 @@ -from ._operation import reduce_by_batch_3d, split_tensor_3d -from .layers import Classifier3D, Embedding3D, LayerNorm3D, Linear3D, PatchEmbedding3D +from ._operation import reduce_by_batch_3d, split_batch_3d, split_tensor_3d +from .layers import (Classifier3D, Embedding3D, LayerNorm3D, Linear3D, PatchEmbedding3D, VocabParallelClassifier3D, + VocabParallelEmbedding3D) __all__ = [ - 'reduce_by_batch_3d', 'split_tensor_3d', 'Linear3D', 'LayerNorm3D', 'PatchEmbedding3D', 'Classifier3D', 'Embedding3D' + 'reduce_by_batch_3d', 'split_tensor_3d', 'split_batch_3d', 'Linear3D', 'LayerNorm3D', 'PatchEmbedding3D', + 'Classifier3D', 'Embedding3D', 'VocabParallelEmbedding3D', 'VocabParallelClassifier3D' ] diff --git a/colossalai/nn/layer/parallel_3d/_operation.py b/colossalai/nn/layer/parallel_3d/_operation.py index 5c93dcb45..26e30d8cf 100644 --- a/colossalai/nn/layer/parallel_3d/_operation.py +++ b/colossalai/nn/layer/parallel_3d/_operation.py @@ -4,36 +4,20 @@ from typing import Optional, Tuple import torch -from colossalai.communication import all_gather, all_reduce, reduce_scatter, broadcast, reduce +from colossalai.communication import (all_gather, all_reduce, broadcast, reduce, reduce_scatter) +from colossalai.context import parallel_mode from colossalai.context.parallel_mode import ParallelMode from colossalai.core import global_context as gpc from torch import Tensor from torch.cuda.amp import custom_bwd, custom_fwd +from ._utils import get_parallel_mode_from_env +from colossalai.constants import INPUT_GROUP_3D, WEIGHT_GROUP_3D + +from colossalai.nn.layer.base_layer import ParallelLayer -class linear_3d(torch.autograd.Function): - """ - Linear layer for 3D parallelism +class _Linear3D(torch.autograd.Function): - 
:param input_: matrix of input - :type input_: torch.tensor - :param weight: matrix of weight - :type weight: torch.tensor - :param bias: matrix of bias - :type bias: torch.tensor, optional - :param input_parallel_mode: input parallel mode - :type input_parallel_mode: colossalai.context.parallel_mode.ParallelMode - :param weight_parallel_mode: weight parallel mode - :type weight_parallel_mode: colossalai.context.parallel_mode.ParallelMode - :param output_parallel_mode: output parallel mode - :type output_parallel_mode: colossalai.context.parallel_mode.ParallelMode - :param input_dim: dimension of input, defaults to 0 - :type input_dim: int, optional - :param weight_dim: dimension of weight, defaults to -1 - :type weight_dim: int, optional - :param output_dim: dimension of output, defaults to 0 - :type output_dim: int, optional - """ @staticmethod @custom_fwd(cast_inputs=torch.float16) def forward(ctx, @@ -87,6 +71,8 @@ class linear_3d(torch.autograd.Function): bias_grad = torch.sum(output_grad, dim=tuple(range(len(output_grad.shape))[:-1])) bias_grad, op = all_reduce(bias_grad, ctx.weight_parallel_mode, async_op=True) async_ops.append(op) + else: + bias_grad = None for op in async_ops: if op is not None: @@ -95,9 +81,17 @@ class linear_3d(torch.autograd.Function): return input_grad, weight_grad, bias_grad, None, None, None, None, None, None -class classifier_3d(torch.autograd.Function): +def linear_3d(input_: Tensor, + weight: Tensor, + bias: Optional[Tensor], + input_parallel_mode: ParallelMode, + weight_parallel_mode: ParallelMode, + output_parallel_mode: ParallelMode, + input_dim: int = 0, + weight_dim: int = -1, + output_dim: int = 0) -> Tensor: """ - Classifier + Linear layer for 3D parallelism :param input_: matrix of input :type input_: torch.tensor @@ -111,7 +105,19 @@ class classifier_3d(torch.autograd.Function): :type weight_parallel_mode: colossalai.context.parallel_mode.ParallelMode :param output_parallel_mode: output parallel mode :type output_parallel_mode: colossalai.context.parallel_mode.ParallelMode + :param input_dim: dimension of input, defaults to 0 + :type input_dim: int, optional + :param weight_dim: dimension of weight, defaults to -1 + :type weight_dim: int, optional + :param output_dim: dimension of output, defaults to 0 + :type output_dim: int, optional """ + return _Linear3D.apply(input_, weight, bias, input_parallel_mode, weight_parallel_mode, output_parallel_mode, + input_dim, weight_dim, output_dim) + + +class _Classifier3D(torch.autograd.Function): + @staticmethod @custom_fwd(cast_inputs=torch.float16) def forward(ctx, input_: Tensor, weight: Tensor, bias: Optional[Tensor], input_parallel_mode: ParallelMode, @@ -156,6 +162,8 @@ class classifier_3d(torch.autograd.Function): bias_grad = all_reduce(bias_grad, ctx.input_parallel_mode) bias_grad, op = all_reduce(bias_grad, ctx.weight_parallel_mode, async_op=True) async_ops.append(op) + else: + bias_grad = None input_grad = torch.matmul(output_grad, weight) @@ -166,23 +174,17 @@ class classifier_3d(torch.autograd.Function): return input_grad, weight_grad, bias_grad, None, None, None, None, None, None -class layernorm_3d(torch.autograd.Function): +def classifier_3d(input_: Tensor, weight: Tensor, bias: Optional[Tensor], input_parallel_mode: ParallelMode, + weight_parallel_mode: ParallelMode, output_parallel_mode: ParallelMode) -> Tensor: """ - Layernorm + 3D parallel classifier - :param input_: input maxtrix + :param input_: matrix of input :type input_: torch.tensor :param weight: matrix of weight :type weight: 
torch.tensor :param bias: matrix of bias - :type bias: torch.tensor - :param normalized_shape: input shape from an expected input - of size. :math:`[* \times \text{normalized_shape}[0] \times \text{normalized_shape}[1] \times \ldots \times \text{normalized_shape}[-1]]` - If a single integer is used, it is treated as a singleton list, and this module will - normalize over the last dimension which is expected to be of that specific size. - :type normalized_shape: int - :param eps: a value added to the denominator for numerical stability - :type eps: float + :type bias: torch.tensor, optional :param input_parallel_mode: input parallel mode :type input_parallel_mode: colossalai.context.parallel_mode.ParallelMode :param weight_parallel_mode: weight parallel mode @@ -190,6 +192,11 @@ class layernorm_3d(torch.autograd.Function): :param output_parallel_mode: output parallel mode :type output_parallel_mode: colossalai.context.parallel_mode.ParallelMode """ + return _Classifier3D.apply(input_, weight, bias, input_parallel_mode, weight_parallel_mode, output_parallel_mode) + + +class _Layernorm3D(torch.autograd.Function): + @staticmethod @custom_fwd(cast_inputs=torch.float32) def forward(ctx, input_: Tensor, weight: Tensor, bias: Tensor, normalized_shape: int, eps: float, @@ -236,27 +243,78 @@ class layernorm_3d(torch.autograd.Function): return input_grad, weight_grad, bias_grad, None, None, None, None, None -def split_tensor_3d(input_: Tensor, - dim: int = 0, - input_parallel_mode: ParallelMode = ParallelMode.PARALLEL_3D_INPUT, - weight_parallel_mode: ParallelMode = ParallelMode.PARALLEL_3D_WEIGHT) -> Tensor: - """Splits 3D tensor in specified dimension +def layernorm_3d(input_: Tensor, weight: Tensor, bias: Tensor, normalized_shape: int, eps: float, + input_parallel_mode: ParallelMode, weight_parallel_mode: ParallelMode, + output_parallel_mode: ParallelMode) -> Tensor: + """ + 3D parallel Layernorm + :param input_: input maxtrix + :type input_: torch.tensor + :param weight: matrix of weight + :type weight: torch.tensor + :param bias: matrix of bias + :type bias: torch.tensor + :param normalized_shape: input shape from an expected input + of size. :math:`[* \times \text{normalized_shape}[0] \times \text{normalized_shape}[1] \times \ldots \times \text{normalized_shape}[-1]]` + If a single integer is used, it is treated as a singleton list, and this module will + normalize over the last dimension which is expected to be of that specific size. 
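# A rough usage sketch for the functional 3D wrappers above (placeholders only; a 3D
# tensor-parallel context is assumed, the parallel modes come from the
# get_parallel_mode_from_env helper imported at the top of this file, and OUTPUT_GROUP_3D is
# assumed to be available from colossalai.constants alongside the other 3D group constants):
#
#     input_mode  = get_parallel_mode_from_env(INPUT_GROUP_3D)
#     weight_mode = get_parallel_mode_from_env(WEIGHT_GROUP_3D)
#     output_mode = get_parallel_mode_from_env(OUTPUT_GROUP_3D)
#
#     out    = linear_3d(x, weight, bias, input_mode, weight_mode, output_mode)
#     logits = classifier_3d(out, cls_weight, cls_bias, input_mode, weight_mode, output_mode)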
+ :type normalized_shape: int + :param eps: a value added to the denominator for numerical stability + :type eps: float + :param input_parallel_mode: input parallel mode + :type input_parallel_mode: colossalai.context.parallel_mode.ParallelMode + :param weight_parallel_mode: weight parallel mode + :type weight_parallel_mode: colossalai.context.parallel_mode.ParallelMode + :param output_parallel_mode: output parallel mode + :type output_parallel_mode: colossalai.context.parallel_mode.ParallelMode + """ + return _Layernorm3D.apply(input_, weight, bias, normalized_shape, eps, input_parallel_mode, weight_parallel_mode, + output_parallel_mode) + + +def split_tensor_3d(tensor: Tensor, dim: int, parallel_mode: ParallelMode) -> Tensor: + """Splits 3D parallel tensor in specified dimension + + :param tensor: Input tensor + :param dim: Specified dimension in which to split + :param parallel_mode: Parallel mode + :param weight_parallel_mode: Weight parallel mode + + :type tensor: torch.Tensor + :type dim: int + :type parallel_mode: colossalai.context.parallel_mode.ParallelMode + + :return output: Splitted tensor + :rtype output: torch.Tensor + """ + if tensor.size(dim) <= 1: + return tensor + output = torch.chunk(tensor, gpc.get_world_size(parallel_mode), + dim=dim)[gpc.get_local_rank(parallel_mode)].contiguous() + return output + + +def split_batch_3d(input_: Tensor, + dim: int = 0, + input_parallel_mode: ParallelMode = ParallelMode.PARALLEL_3D_INPUT, + weight_parallel_mode: ParallelMode = ParallelMode.PARALLEL_3D_WEIGHT) -> Tensor: + """Splits 3D tensor in batch :param input_: Input tensor :param dim: Specified dimension in which to split :param input_parallel_mode: Input parallel mode :param weight_parallel_mode: Weight parallel mode - :type input_: torch.Tensor :type dim: int, optional :type input_parallel_mode: colossalai.context.parallel_mode.ParallelMode, optional :type weight_parallel_mode: colossalai.context.parallel_mode.ParallelMode, optional - :return output: Splitted tensor :rtype output: torch.Tensor """ if input_.size(dim) <= 1: return input_ + weight_parallel_mode = get_parallel_mode_from_env(WEIGHT_GROUP_3D) + input_parallel_mode = get_parallel_mode_from_env(INPUT_GROUP_3D) output = torch.chunk(input_, gpc.get_world_size(weight_parallel_mode), dim=dim)[gpc.get_local_rank(weight_parallel_mode)].contiguous() output = torch.chunk(output, gpc.get_world_size(input_parallel_mode), @@ -264,9 +322,77 @@ def split_tensor_3d(input_: Tensor, return output -class reduce_by_batch_3d(torch.autograd.Function): - """All-reduce the input from the model parallel region. +class _ReduceTensor3D(torch.autograd.Function): + + @staticmethod + def forward(ctx, input_, parallel_mode): + return all_reduce(input_, parallel_mode) + + @staticmethod + def backward(ctx, output_grad): + return output_grad, None + + +def reduce_tensor_3d(tensor: Tensor, parallel_mode: ParallelMode) -> Tensor: """ + All-reduce the input. + + :param tensor: Input tensor + :param parallel_mode: Parallel mode + """ + return _ReduceTensor3D.apply(tensor, parallel_mode) + + +class _ReduceGrad3D(torch.autograd.Function): + + @staticmethod + def forward(ctx, input_, parallel_mode): + ctx.parallel_mode = parallel_mode + return input_ + + @staticmethod + def backward(ctx, output_grad): + input_grad = all_reduce(output_grad, ctx.parallel_mode) + return input_grad, None + + +def reduce_grad_3d(tensor: Tensor, parallel_mode: ParallelMode) -> Tensor: + """ + All-reduce the gradient in backward pass. 
+ + :param tensor: Input tensor + :param parallel_mode: Parallel mode + """ + return _ReduceGrad3D.apply(tensor, parallel_mode) + + +class _ReduceScatterTensor3D(torch.autograd.Function): + + @staticmethod + def forward(ctx, input_, dim, parallel_mode): + ctx.dim = dim + ctx.parallel_mode = parallel_mode + return reduce_scatter(input_, dim, parallel_mode) + + @staticmethod + def backward(ctx, output_grad): + input_grad = all_gather(output_grad, ctx.dim, ctx.parallel_mode) + return input_grad, None, None + + +def reduce_scatter_tensor_3d(tensor: Tensor, dim: int, parallel_mode: ParallelMode) -> Tensor: + """ + Reduce-scatter the input. + + :param tensor: Input tensor + :param dim: Dimension to scatter + :param parallel_mode: Parallel mode + """ + return _ReduceScatterTensor3D.apply(tensor, dim, parallel_mode) + + +class _ReduceByBatch3D(torch.autograd.Function): + @staticmethod @custom_fwd(cast_inputs=torch.float32) def forward(ctx, @@ -274,16 +400,6 @@ class reduce_by_batch_3d(torch.autograd.Function): input_parallel_mode: ParallelMode, weight_parallel_mode: ParallelMode, reduce_mean: bool = False) -> Tensor: - """ - :param input_: input maxtrix - :type input_: torch.tensor - :param input_parallel_mode: input parallel mode - :type input_parallel_mode: colossalai.context.parallel_mode.ParallelMode - :param weight_parallel_mode: weight parallel mode - :type weight_parallel_mode: colossalai.context.parallel_mode.ParallelMode - :param reduce_mean: If set to ``True``, it will divide the output by (input parallel size * weight parallel size), default to False - :type reduce_mean: int, optional - """ output = all_reduce(input_, input_parallel_mode) output = all_reduce(output, weight_parallel_mode) ctx.reduce_mean = reduce_mean @@ -302,7 +418,26 @@ class reduce_by_batch_3d(torch.autograd.Function): return output_grad, None, None, None -class broadcast_weight_3d_from_diagonal(torch.autograd.Function): +def reduce_by_batch_3d(tensor: Tensor, + input_parallel_mode: ParallelMode, + weight_parallel_mode: ParallelMode, + reduce_mean: bool = False) -> Tensor: + """ + All-reduce the input from the model parallel region. 
+ + :param input_: input maxtrix + :type input_: torch.tensor + :param input_parallel_mode: input parallel mode + :type input_parallel_mode: colossalai.context.parallel_mode.ParallelMode + :param weight_parallel_mode: weight parallel mode + :type weight_parallel_mode: colossalai.context.parallel_mode.ParallelMode + :param reduce_mean: If set to ``True``, it will divide the output by (input parallel size * weight parallel size), default to False + :type reduce_mean: int, optional + """ + return _ReduceByBatch3D.apply(tensor, input_parallel_mode, weight_parallel_mode, reduce_mean) + + +class _BroadcastWeight3D_FromDiagonal(torch.autograd.Function): """ broadcast weight from diagonal @@ -315,6 +450,7 @@ class broadcast_weight_3d_from_diagonal(torch.autograd.Function): :param weight_parallel_mode: output parallel mode :type weight_parallel_mode: colossalai.context.parallel_mode.ParallelMode """ + @staticmethod @custom_fwd(cast_inputs=torch.float16) def forward(ctx, input_: Tensor, input_parallel_mode: ParallelMode, weight_parallel_mode: ParallelMode, @@ -337,3 +473,9 @@ class broadcast_weight_3d_from_diagonal(torch.autograd.Function): else: input_grad = None return input_grad, None, None, None + + +def broadcast_weight_3d_from_diagonal(tensor: Tensor, input_parallel_mode: ParallelMode, + weight_parallel_mode: ParallelMode, output_parallel_mode: ParallelMode) -> Tensor: + return _BroadcastWeight3D_FromDiagonal.apply(tensor, input_parallel_mode, weight_parallel_mode, + output_parallel_mode) diff --git a/colossalai/nn/layer/parallel_3d/_utils.py b/colossalai/nn/layer/parallel_3d/_utils.py index ca3b405ea..0622164cd 100644 --- a/colossalai/nn/layer/parallel_3d/_utils.py +++ b/colossalai/nn/layer/parallel_3d/_utils.py @@ -1,31 +1,25 @@ -#!/usr/bin/env python -# -*- encoding: utf-8 -*- - -import os - -from colossalai.constants import (DEPTH_3D, INPUT_GROUP_3D, OUTPUT_GROUP_3D, - WEIGHT_GROUP_3D) +from colossalai.constants import INPUT_GROUP_3D, WEIGHT_GROUP_3D, OUTPUT_GROUP_3D from colossalai.context.parallel_mode import ParallelMode from colossalai.core import global_context as gpc +from colossalai.global_variables import tensor_parallel_env as env from torch import Tensor def get_depth_from_env() -> int: try: - depth = os.environ[DEPTH_3D] - depth = int(depth) + depth = env.depth_3d assert depth > 0, 'DEPTH must be greater than zero' return depth except KeyError as e: - raise EnvironmentError( - 'DEPTH is not found in the current environment, ' - 'please make sure that you have used the correct process group initializer' - ) + raise EnvironmentError('DEPTH is not found in the current environment, ' + 'please make sure that you have used the correct process group initializer') def get_parallel_mode_from_env(group): - return getattr(ParallelMode, os.environ[group]) + assert group in [INPUT_GROUP_3D, WEIGHT_GROUP_3D, OUTPUT_GROUP_3D], \ + f'{group} is not valid for 3D tensor parallelism.' 
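+    # the tensor-parallel env is expected to map each 3D group name (e.g. input_group_3d) to its ParallelMode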
+ return getattr(env, group) def get_last_group(a, b): @@ -35,8 +29,7 @@ def get_last_group(a, b): ParallelMode.PARALLEL_3D_OUTPUT: 'C', } - res = chr( - ord('A') + ord('B') + ord('C') - ord(mapping[a]) - ord(mapping[b])) + res = chr(ord('A') + ord('B') + ord('C') - ord(mapping[a]) - ord(mapping[b])) if res == 'A': return ParallelMode.PARALLEL_3D_INPUT @@ -47,8 +40,7 @@ def get_last_group(a, b): def swap_in_out_group(): - os.environ[INPUT_GROUP_3D], os.environ[OUTPUT_GROUP_3D] = \ - os.environ[OUTPUT_GROUP_3D], os.environ[INPUT_GROUP_3D] + env.input_group_3d, env.output_group_3d = env.output_group_3d, env.input_group_3d def dbg_check_shape(tensor: Tensor, shape: tuple): diff --git a/colossalai/nn/layer/parallel_3d/layers.py b/colossalai/nn/layer/parallel_3d/layers.py index 048b158fa..da8a50995 100644 --- a/colossalai/nn/layer/parallel_3d/layers.py +++ b/colossalai/nn/layer/parallel_3d/layers.py @@ -1,5 +1,3 @@ -#!/usr/bin/env python -# -*- encoding: utf-8 -*- import math from typing import Callable @@ -10,11 +8,12 @@ from colossalai.communication import all_reduce, broadcast from colossalai.constants import INPUT_GROUP_3D, WEIGHT_GROUP_3D from colossalai.context import ParallelMode, seed from colossalai.core import global_context as gpc +from colossalai.global_variables import tensor_parallel_env as env from colossalai.nn import init as init from colossalai.nn.layer.base_layer import ParallelLayer from colossalai.registry import LAYERS -from colossalai.utils import get_current_device -from torch import Tensor, dtype +from colossalai.utils.cuda import get_current_device +from torch import Tensor from torch.nn import Parameter from ..utils import divide, set_tensor_parallel_attribute_by_partition, to_2tuple @@ -37,7 +36,8 @@ class LayerNorm3D(ParallelLayer): :param dtype: The dtype of parameters, defaults to None :type dtype: torch.dtype, optional """ - def __init__(self, normalized_shape: int, eps: float = 1e-12, dtype: dtype = None): + + def __init__(self, normalized_shape: int, eps: float = 1e-12, dtype=None): super().__init__() self.input_parallel_mode = get_parallel_mode_from_env(INPUT_GROUP_3D) self.weight_parallel_mode = get_parallel_mode_from_env(WEIGHT_GROUP_3D) @@ -62,8 +62,8 @@ class LayerNorm3D(ParallelLayer): init.ones_()(self.weight) def forward(self, input_: Tensor) -> Tensor: - return layernorm_3d.apply(input_, self.weight, self.bias, self.normalized_shape, self.variance_epsilon, - self.input_parallel_mode, self.weight_parallel_mode, self.output_parallel_mode) + return layernorm_3d(input_, self.weight, self.bias, self.normalized_shape, self.variance_epsilon, + self.input_parallel_mode, self.weight_parallel_mode, self.output_parallel_mode) @LAYERS.register_module @@ -84,11 +84,12 @@ class Linear3D(ParallelLayer): :param bias_initializer: The intializer of bias, defaults to xavier uniform initializer :type bias_initializer: typing.Callable, optional """ + def __init__(self, in_features: int, out_features: int, bias: bool = True, - dtype: dtype = None, + dtype: torch.dtype = None, weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)), bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1)): super().__init__() @@ -136,8 +137,8 @@ class Linear3D(ParallelLayer): broadcast(self.bias, output_src_rank, self.output_parallel_mode) def forward(self, input_: Tensor) -> Tensor: - return linear_3d.apply(input_, self.weight, self.bias, self.input_parallel_mode, self.weight_parallel_mode, - self.output_parallel_mode) + return linear_3d(input_, self.weight, self.bias, 
self.input_parallel_mode, self.weight_parallel_mode, + self.output_parallel_mode) @LAYERS.register_module @@ -160,12 +161,13 @@ class Classifier3D(ParallelLayer): :param bias_initializer: The intializer of bias, defaults to xavier uniform initializer :type bias_initializer: typing.Callable, optional """ + def __init__(self, in_features: int, num_classes: int, weight: Parameter = None, bias: bool = True, - dtype: dtype = None, + dtype: torch.dtype = None, weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)), bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1)): super().__init__() @@ -214,8 +216,94 @@ class Classifier3D(ParallelLayer): broadcast(self.bias, input_src_rank, self.input_parallel_mode) def forward(self, input_: Tensor) -> Tensor: - return classifier_3d.apply(input_, self.weight, self.bias, self.input_parallel_mode, self.weight_parallel_mode, - self.output_parallel_mode) + return classifier_3d(input_, self.weight, self.bias, self.input_parallel_mode, self.weight_parallel_mode, + self.output_parallel_mode) + + +@LAYERS.register_module +class VocabParallelClassifier3D(ParallelLayer): + """ + Vocab parallel classifier layer for 2D parallelism + + :param in_features: size of each input sample + :type in_features: int + :param num_classes: number of classes + :type num_classes: int + :param weight: weight of the classifier, defaults to True + :type weight: torch.nn.Parameter, optional + :param bias: If set to ``False``, the layer will not learn an additive bias, defaults to ``True`` + :type bias: bool, optional + :param dtype: The dtype of parameters, defaults to None + :type dtype: torch.dtype, optional + :param weight_initializer: The intializer of weight, defaults to kaiming uniform initializer + :type weight_initializer: typing.Callable, optional + :param bias_initializer: The intializer of bias, defaults to xavier uniform initializer + :type bias_initializer: typing.Callable, optional + """ + + def __init__(self, + in_features: int, + num_classes: int, + weight: Parameter = None, + bias: bool = True, + dtype: torch.dtype = None, + weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)), + bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1)): + super().__init__() + self.in_features = in_features + self.num_classes = num_classes + self.input_parallel_mode = get_parallel_mode_from_env(INPUT_GROUP_3D) + self.weight_parallel_mode = get_parallel_mode_from_env(WEIGHT_GROUP_3D) + self.output_parallel_mode = get_last_group(self.input_parallel_mode, self.weight_parallel_mode) + self.depth = get_depth_from_env() + self.in_features_per_partition = divide(in_features, self.depth) + self.out_features_per_partition = divide(num_classes, self.depth) + + if weight is not None: + self.weight = weight + self.has_weight = False + else: + self.weight = Parameter( + torch.empty(self.out_features_per_partition, + self.in_features_per_partition, + device=get_current_device(), + dtype=dtype)) + self.has_weight = True + if bias: + self.bias = Parameter(torch.zeros(self.out_features_per_partition, device=get_current_device(), + dtype=dtype)) + else: + self.bias = None + + self.reset_parameters(weight_initializer, bias_initializer) + self._set_tensor_parallel_attributes() + swap_in_out_group() + env.vocab_parallel = True + + def _set_tensor_parallel_attributes(self) -> None: + if self.has_weight: + set_tensor_parallel_attribute_by_partition(self.weight, self.depth**2) + if self.bias is not None: + set_tensor_parallel_attribute_by_partition(self.bias, 
self.depth) + + def reset_parameters(self, weight_initializer, bias_initializer) -> None: + with seed(ParallelMode.TENSOR): + fan_in, fan_out = self.in_features, self.num_classes + weight_src_rank = gpc.get_ranks_in_group(self.weight_parallel_mode)[0] + output_src_rank = gpc.get_ranks_in_group(self.output_parallel_mode)[0] + + if self.has_weight: + weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out) + broadcast(self.weight, weight_src_rank, self.weight_parallel_mode) + + if self.bias is not None: + bias_initializer(self.bias, fan_in=fan_in) + broadcast(self.bias, weight_src_rank, self.weight_parallel_mode) + broadcast(self.bias, output_src_rank, self.output_parallel_mode) + + def forward(self, input_: Tensor) -> Tensor: + return linear_3d(input_, self.weight.transpose(0, 1), self.bias, self.input_parallel_mode, + self.weight_parallel_mode, self.output_parallel_mode) @LAYERS.register_module @@ -242,13 +330,14 @@ class PatchEmbedding3D(ParallelLayer): :param position_embed_initializer: The intializer of position embedding, defaults to zero :type position_embed_initializer: typing.Callable, optional """ + def __init__(self, img_size: int, patch_size: int, in_chans: int, embed_size: int, - dtype: dtype = None, flatten: bool = True, + dtype: torch.dtype = None, weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)), bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1), position_embed_initializer: Callable = init.zeros_()): @@ -284,8 +373,8 @@ class PatchEmbedding3D(ParallelLayer): set_tensor_parallel_attribute_by_partition(self.cls_token, self.depth) set_tensor_parallel_attribute_by_partition(self.pos_embed, self.depth) - def _sync_grad_hook(self, grad) -> None: - grad = all_reduce(grad, self.input_parallel_mode) + def _sync_grad_hook(self, grad) -> Tensor: + grad = all_reduce(grad.clone(), self.input_parallel_mode) grad = all_reduce(grad, self.weight_parallel_mode) return grad @@ -302,17 +391,19 @@ class PatchEmbedding3D(ParallelLayer): broadcast(self.weight, weight_src_rank, self.weight_parallel_mode) broadcast(self.bias, weight_src_rank, self.weight_parallel_mode) broadcast(self.pos_embed, weight_src_rank, self.weight_parallel_mode) + broadcast(self.weight, input_src_rank, self.input_parallel_mode) broadcast(self.bias, input_src_rank, self.input_parallel_mode) broadcast(self.pos_embed, input_src_rank, self.input_parallel_mode) + self.weight.register_hook(self._sync_grad_hook) self.bias.register_hook(self._sync_grad_hook) self.cls_token.register_hook(self._sync_grad_hook) self.pos_embed.register_hook(self._sync_grad_hook) def forward(self, input_: Tensor) -> Tensor: - weight = broadcast_weight_3d_from_diagonal.apply(self.weight, self.input_parallel_mode, - self.weight_parallel_mode, self.output_parallel_mode) - output = F.conv2d(input_, weight, self.bias, stride=self.patch_size) + input_ = split_tensor_3d(input_, 0, self.weight_parallel_mode) + input_ = split_tensor_3d(input_, 0, self.input_parallel_mode) + output = F.conv2d(input_, self.weight, self.bias, stride=self.patch_size) if self.flatten: output = output.flatten(2).transpose(1, 2) # BCHW -> BNC @@ -341,11 +432,12 @@ class Embedding3D(ParallelLayer): :param args: Args used in F.embedding :param kwargs: Kwargs used in F.embedding """ + def __init__(self, num_embeddings: int, embedding_dim: int, padding_idx: int = None, - dtype: dtype = None, + dtype: torch.dtype = None, weight_initializer: Callable = init.normal_(), *args, **kwargs): @@ -385,8 +477,95 @@ class Embedding3D(ParallelLayer): 
self.weight[self.padding_idx].fill_(0) def forward(self, input_: Tensor) -> Tensor: - weight = broadcast_weight_3d_from_diagonal.apply(self.weight, self.input_parallel_mode, - self.weight_parallel_mode, self.output_parallel_mode) + input_ = split_tensor_3d(input_, 0, self.weight_parallel_mode) + input_ = split_tensor_3d(input_, 0, self.input_parallel_mode) + weight = broadcast_weight_3d_from_diagonal(self.weight, self.input_parallel_mode, self.weight_parallel_mode, + self.output_parallel_mode) output = F.embedding(input_, weight, self.padding_idx, *self.embed_args, **self.embed_kwargs) return output + + +@LAYERS.register_module +class VocabParallelEmbedding3D(torch.nn.Module): + """Embedding parallelized in the vocabulary dimension. + + :param num_embeddings: number of embeddings + :type num_embeddings: int + :param embedding_dim: dimension of embedding + :type embedding_dim: int + :param padding_idx: index of padding, defaults to None + :type padding_idx: int, optional + :param dtype: The dtype of parameters, defaults to None + :type dtype: torch.dtype, optional + :param weight_initializer: The intializer of weight, defaults to normal initializer + :type weight_initializer: typing.Callable, optional + :param args: Args used in F.embedding + :param kwargs: Kwargs used in F.embedding + """ + + def __init__(self, + num_embeddings: int, + embedding_dim: int, + padding_idx: int = None, + dtype: torch.dtype = None, + weight_initializer: Callable = init.normal_(), + *args, + **kwargs): + super().__init__() + self.num_embeddings = num_embeddings + self.embed_dim = embedding_dim + self.padding_idx = padding_idx + self.embed_args = args + self.embed_kwargs = kwargs + + self.depth = get_depth_from_env() + self.input_parallel_mode = get_parallel_mode_from_env(INPUT_GROUP_3D) + self.weight_parallel_mode = get_parallel_mode_from_env(WEIGHT_GROUP_3D) + self.output_parallel_mode = get_last_group(self.input_parallel_mode, self.weight_parallel_mode) + self.num_embeddings_per_partition = divide(self.num_embeddings, self.depth) + self.embed_dim_per_partition = divide(self.embed_dim, self.depth) + vocab_parallel_rank = gpc.get_local_rank(self.input_parallel_mode) + self.vocab_start_index = vocab_parallel_rank * self.num_embeddings_per_partition + self.vocab_end_index = self.vocab_start_index + self.num_embeddings_per_partition + + self.weight = Parameter( + torch.empty((self.num_embeddings_per_partition, self.embed_dim_per_partition), + device=get_current_device(), + dtype=dtype)) + + self.reset_parameters(weight_initializer) + self._set_tensor_parallel_attributes() + env.vocab_parallel = True + + def _set_tensor_parallel_attributes(self): + set_tensor_parallel_attribute_by_partition(self.weight, self.depth**2) + + def reset_parameters(self, weight_initializer) -> None: + with seed(ParallelMode.TENSOR): + fan_in, fan_out = self.num_embeddings, self.embed_dim + weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out) + self._fill_padding_idx_with_zero() + weight_src_rank = gpc.get_ranks_in_group(self.weight_parallel_mode)[0] + broadcast(self.weight, weight_src_rank, self.weight_parallel_mode) + + def _fill_padding_idx_with_zero(self) -> None: + if self.padding_idx is not None: + with torch.no_grad(): + self.weight[self.padding_idx].fill_(0) + + def forward(self, input_: Tensor) -> Tensor: + input_ = split_tensor_3d(input_, 0, self.weight_parallel_mode) + + input_mask = (input_ < self.vocab_start_index) | (input_ >= self.vocab_end_index) + masked_input = input_.clone() - self.vocab_start_index + 
masked_input[input_mask] = 0 + + weight = reduce_grad_3d(self.weight, self.weight_parallel_mode) + + output_parallel = F.embedding(masked_input, weight, self.padding_idx, *self.embed_args, **self.embed_kwargs) + + output_parallel[input_mask, :] = 0. + output = reduce_scatter_tensor_3d(output_parallel, 0, self.input_parallel_mode) + + return output diff --git a/colossalai/nn/layer/utils/common.py b/colossalai/nn/layer/utils/common.py index 3f1626a0e..c1d88d2fc 100644 --- a/colossalai/nn/layer/utils/common.py +++ b/colossalai/nn/layer/utils/common.py @@ -2,12 +2,12 @@ # -*- encoding: utf-8 -*- import collections.abc -import os from itertools import repeat import numpy as np import torch -from colossalai.constants import (IS_TENSOR_PARALLEL, NUM_PARTITIONS, TENSOR_PARALLEL_MODE) +from colossalai.constants import IS_TENSOR_PARALLEL, NUM_PARTITIONS +from colossalai.global_variables import tensor_parallel_env as env from colossalai.utils import checkpoint from torch import Tensor, nn @@ -38,7 +38,7 @@ class CheckpointModule(nn.Module): def divide(numerator, denominator): """Only allow exact division - + :param numerator: Numerator of the division :param denominator: Denominator of the division """ @@ -65,7 +65,7 @@ def set_tensor_parallel_attribute_by_partition(param, num_partitions): def get_tensor_parallel_mode(): - return os.environ[TENSOR_PARALLEL_MODE] + return env.mode # From PyTorch internals diff --git a/colossalai/nn/layer/vanilla/layers.py b/colossalai/nn/layer/vanilla/layers.py index e707162fb..e5c9fd074 100644 --- a/colossalai/nn/layer/vanilla/layers.py +++ b/colossalai/nn/layer/vanilla/layers.py @@ -3,14 +3,14 @@ from typing import Callable import torch import torch.nn.functional as F +from colossalai.context import seed from colossalai.nn import init as init from colossalai.registry import LAYERS -from colossalai.utils import get_current_device -from torch import Tensor, dtype +from colossalai.utils.cuda import get_current_device +from torch import Tensor from torch import nn as nn from ..utils import to_2tuple -from colossalai.context import seed def drop_path(x, drop_prob: float = 0., training: bool = False): @@ -36,6 +36,7 @@ class DropPath(nn.Module): Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). Adapted from https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/layers/drop.py """ + def __init__(self, drop_prob=None): super(DropPath, self).__init__() self.drop_prob = drop_prob @@ -47,6 +48,7 @@ class DropPath(nn.Module): class WrappedDropout(nn.Module): """Same as torch.nn.Dropout. But it is wrapped with the context of seed manager. """ + def __init__(self, p: float = 0.5, inplace: bool = False, mode=None): super().__init__() if p < 0 or p > 1: @@ -75,6 +77,7 @@ class WrappedDropPath(nn.Module): """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). Here, it is wrapped with the context of seed manager. 
""" + def __init__(self, p: float = 0., mode=None): super().__init__() self.p = p @@ -120,13 +123,14 @@ class VanillaPatchEmbedding(nn.Module): :param position_embed_initializer: The intializer of position embedding, defaults to zero :type position_embed_initializer: typing.Callable, optional """ + def __init__(self, img_size: int, patch_size: int, in_chans: int, embed_size: int, - dtype: dtype = None, flatten: bool = True, + dtype: torch.dtype = None, weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)), bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1), position_embed_initializer: Callable = init.zeros_()): @@ -142,8 +146,9 @@ class VanillaPatchEmbedding(nn.Module): self.weight = nn.Parameter( torch.empty((embed_size, in_chans, *self.patch_size), device=get_current_device(), dtype=dtype)) self.bias = nn.Parameter(torch.empty(embed_size, device=get_current_device(), dtype=dtype)) - self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_size)) - self.pos_embed = nn.Parameter(torch.zeros(1, self.num_patches + 1, embed_size)) + self.cls_token = nn.Parameter(torch.zeros((1, 1, embed_size), device=get_current_device(), dtype=dtype)) + self.pos_embed = nn.Parameter( + torch.zeros((1, self.num_patches + 1, embed_size), device=get_current_device(), dtype=dtype)) self.reset_parameters(weight_initializer, bias_initializer, position_embed_initializer) @@ -170,7 +175,7 @@ class VanillaPatchEmbedding(nn.Module): @LAYERS.register_module class VanillaClassifier(nn.Module): """ - Classifier for ViT + Dense linear classifier :param in_features: size of each input sample :type in_features: int @@ -187,12 +192,13 @@ class VanillaClassifier(nn.Module): :param bias_initializer: The intializer of bias, defaults to xavier uniform initializer :type bias_initializer: typing.Callable, optional """ + def __init__(self, in_features: int, num_classes: int, weight: nn.Parameter = None, bias: bool = True, - dtype: dtype = None, + dtype: torch.dtype = None, weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)), bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1)): super().__init__() diff --git a/colossalai/nn/loss/__init__.py b/colossalai/nn/loss/__init__.py index 87f43bb84..373e4ec94 100644 --- a/colossalai/nn/loss/__init__.py +++ b/colossalai/nn/loss/__init__.py @@ -1,25 +1,37 @@ +from colossalai.global_variables import tensor_parallel_env as env +from colossalai.nn.layer.utils import get_tensor_parallel_mode from torch import nn from torch.nn.modules.loss import * from torch.nn.modules.loss import _Loss -from colossalai.nn.layer.utils import get_tensor_parallel_mode -from .loss_2d import CrossEntropyLoss2D -from .loss_2p5d import CrossEntropyLoss2p5D -from .loss_3d import CrossEntropyLoss3D +from .loss_1d import VocabParallelCrossEntropyLoss1D +from .loss_2d import CrossEntropyLoss2D, VocabParallelCrossEntropyLoss2D +from .loss_2p5d import CrossEntropyLoss2p5D, VocabParallelCrossEntropyLoss2p5D +from .loss_3d import CrossEntropyLoss3D, VocabParallelCrossEntropyLoss3D from .loss_moe import MoeCrossEntropyLoss, MoeLoss _parallel_cross_entropy = { '2d': CrossEntropyLoss2D, '2.5d': CrossEntropyLoss2p5D, - '3d': CrossEntropyLoss3D + '3d': CrossEntropyLoss3D, +} + +_vocab_parallel_cross_entropy = { + '1d': VocabParallelCrossEntropyLoss1D, + '2d': VocabParallelCrossEntropyLoss2D, + '2.5d': VocabParallelCrossEntropyLoss2p5D, + '3d': VocabParallelCrossEntropyLoss3D, } class CrossEntropyLoss(_Loss): + def __init__(self, reduction: bool = True, *args, **kwargs): 
super().__init__() tensor_parallel = get_tensor_parallel_mode() - if tensor_parallel in ['None', '1d']: + if tensor_parallel is not None and env.vocab_parallel: + self.loss = _vocab_parallel_cross_entropy[tensor_parallel](reduction=reduction, *args, **kwargs) + elif tensor_parallel is None or tensor_parallel == '1d': reduction = 'mean' if reduction else 'none' self.loss = nn.CrossEntropyLoss(reduction=reduction, *args, **kwargs) else: diff --git a/colossalai/nn/loss/loss_1d.py b/colossalai/nn/loss/loss_1d.py new file mode 100644 index 000000000..d0e1ec2a4 --- /dev/null +++ b/colossalai/nn/loss/loss_1d.py @@ -0,0 +1,110 @@ +import torch +from colossalai.context import ParallelMode +from colossalai.core import global_context as gpc +from colossalai.registry import LOSSES +from torch.cuda.amp import custom_bwd, custom_fwd +from torch.nn.modules.loss import _Loss + + +class _VocabParallelCrossEntropy1D(torch.autograd.Function): + + @staticmethod + @custom_fwd(cast_inputs=torch.float32) + def forward(ctx, vocab_parallel_logits, targets): + + # Maximum value along vocab dimension across all GPUs. + logits_max = torch.max(vocab_parallel_logits, dim=-1)[0] + torch.distributed.all_reduce(logits_max, + op=torch.distributed.ReduceOp.MAX, + group=gpc.get_group(ParallelMode.PARALLEL_1D)) + # Subtract the maximum value. + vocab_parallel_logits.sub_(logits_max.unsqueeze(dim=-1)) + + # Get the partition's vocab indecies + partition_vocab_size = vocab_parallel_logits.size()[-1] + rank = gpc.get_local_rank(ParallelMode.PARALLEL_1D) + vocab_start_index = partition_vocab_size * rank + vocab_end_index = vocab_start_index + partition_vocab_size + + # Create a mask of valid vocab ids (1 means it needs to be masked). + target_mask = (targets < vocab_start_index) | (targets >= vocab_end_index) + masked_target = targets.clone() - vocab_start_index + masked_target[target_mask] = 0 + + # Get predicted-logits = logits[target]. + # For Simplicity, we convert logits to a 2-D tensor with size + # [*, partition-vocab-size] and target to a 1-D tensor of size [*]. + logits_2d = vocab_parallel_logits.view(-1, partition_vocab_size) + masked_target_1d = masked_target.view(-1) + arange_1d = torch.arange(start=0, end=logits_2d.size()[0], device=logits_2d.device) + predicted_logits_1d = logits_2d[arange_1d, masked_target_1d] + predicted_logits_1d = predicted_logits_1d.clone().contiguous() + predicted_logits = predicted_logits_1d.view_as(targets) + predicted_logits[target_mask] = 0.0 + # All reduce is needed to get the chunks from other GPUs. + torch.distributed.all_reduce(predicted_logits, + op=torch.distributed.ReduceOp.SUM, + group=gpc.get_group(ParallelMode.PARALLEL_1D)) + + # Sum of exponential of logits along vocab dimension across all GPUs. + exp_logits = vocab_parallel_logits + torch.exp(vocab_parallel_logits, out=exp_logits) + sum_exp_logits = exp_logits.sum(dim=-1) + torch.distributed.all_reduce(sum_exp_logits, + op=torch.distributed.ReduceOp.SUM, + group=gpc.get_group(ParallelMode.PARALLEL_1D)) + + # Loss = log(sum(exp(logits))) - predicted-logit. + loss = torch.log(sum_exp_logits) - predicted_logits + # Store softmax, target-mask and masked-target for backward pass. + exp_logits.div_(sum_exp_logits.unsqueeze(dim=-1)) + ctx.save_for_backward(exp_logits, target_mask, masked_target_1d) + return loss + + @staticmethod + @custom_bwd + def backward(ctx, grad_output): + + # Retreive tensors from the forward path. + softmax, target_mask, masked_target_1d = ctx.saved_tensors + + # All the inputs have softmax as thier gradient. 
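+        # i.e. d(loss_i)/d(logit_ij) = softmax_ij - 1[j == target_i]; the one-hot term is
+        # subtracted below only for the vocab slice owned by this rank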
+ grad_input = softmax + # For simplicity, work with the 2D gradient. + partition_vocab_size = softmax.size()[-1] + grad_2d = grad_input.view(-1, partition_vocab_size) + + # Add the gradient from matching classes. + arange_1d = torch.arange(start=0, end=grad_2d.size()[0], device=grad_2d.device) + grad_2d[arange_1d, masked_target_1d] -= (1.0 - target_mask.view(-1).float()) + + # Finally elementwise multiplication with the output gradients. + grad_input.mul_(grad_output.unsqueeze(dim=-1)) + + return grad_input, None + + +@LOSSES.register_module +class VocabParallelCrossEntropyLoss1D(_Loss): + """ + Vocab parallel cross entropy loss for 1D parallelism + + :param reduction: whether to average the loss, defaults to True + + :type reduction: bool, optional + """ + + def __init__(self, reduction=True): + super().__init__() + self.reduction_mean = reduction + + def forward(self, logits, targets): + """Calculate loss between logits and targets + + :param logits: Output logits of model + :param targets: True targets from data + """ + loss = _VocabParallelCrossEntropy1D.apply(logits, targets) + if self.reduction_mean: + loss = loss.mean() + return loss diff --git a/colossalai/nn/loss/loss_2d.py b/colossalai/nn/loss/loss_2d.py index b5438b887..a2ad8f435 100644 --- a/colossalai/nn/loss/loss_2d.py +++ b/colossalai/nn/loss/loss_2d.py @@ -1,6 +1,12 @@ -from colossalai.nn.layer.parallel_2d import reduce_by_batch_2d +import torch +import torch.distributed as dist +from colossalai.context import ParallelMode +from colossalai.core import global_context as gpc +from colossalai.nn.layer.parallel_2d import reduce_by_batch_2d, split_tensor_2d from colossalai.nn.layer.parallel_2d._utils import assert_summa_initialization from colossalai.registry import LOSSES +from colossalai.utils import get_current_device +from torch.cuda.amp import custom_bwd, custom_fwd from torch.nn.functional import cross_entropy from torch.nn.modules.loss import _Loss @@ -16,6 +22,7 @@ class CrossEntropyLoss2D(_Loss): :type reduction: bool, optional """ + def __init__(self, reduction=True, *args, **kwargs): super().__init__() assert_summa_initialization() @@ -29,8 +36,110 @@ class CrossEntropyLoss2D(_Loss): :param logits: Output logits of model :param targets: True targets from data """ + targets = split_tensor_2d(targets) loss = cross_entropy(logits, targets, reduction='none', *self.loss_args, **self.loss_kwargs) if self.reduction_mean: loss = loss.mean() - loss = reduce_by_batch_2d.apply(loss, True) + loss = reduce_by_batch_2d(loss, True) + return loss + + +class _VocabParallelCrossEntropy2D(torch.autograd.Function): + ### Modified based on megatron.mpu.cross_entropy ### + + @staticmethod + @custom_fwd(cast_inputs=torch.float32) + def forward(ctx, logits, targets): + # logits: [b/q, h/q] + # labels: [b/q] + # loss: [b/q] + # vocab_parallel_logits: [b/q, s, v/q] + # target: [b/q, s] + logits_max = torch.max(logits, dim=-1)[0] + torch.distributed.all_reduce(logits_max, + op=torch.distributed.ReduceOp.MAX, + group=gpc.get_group(ParallelMode.PARALLEL_2D_ROW)) + # Subtract the maximum value. 
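+        # (subtracting the all-reduced max is purely for numerical stability; it cancels in
+        # log-sum-exp minus the target logit)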
+ # vocab_parallel_logits.sub_(logits_max.unsqueeze(dim=-1)) + logits = logits - logits_max.unsqueeze(dim=-1) + + vocab_size = logits.size(-1) + rank = gpc.get_local_rank(ParallelMode.PARALLEL_2D_ROW) + vocab_start = rank * (vocab_size) + vocab_end = (rank + 1) * (vocab_size) - 1 + + target_mask = (targets < vocab_start) | (targets > vocab_end) + + masked_target = targets.clone() - vocab_start + masked_target[target_mask] = 0 + arange_1d = torch.arange( + start=0, + end=logits.size()[0], + ) + predicted_logits = logits[arange_1d, masked_target] + predicted_logits[target_mask] = 0. + dist.all_reduce(predicted_logits, group=gpc.get_group(ParallelMode.PARALLEL_2D_ROW)) + + exp_logits = torch.exp(logits) + sum_exp_logits = exp_logits.sum(dim=1) + dist.all_reduce(sum_exp_logits, group=gpc.get_group(ParallelMode.PARALLEL_2D_ROW)) + + loss = torch.log(sum_exp_logits) - predicted_logits + + exp_logits.div_(sum_exp_logits.unsqueeze(dim=-1)) + ctx.save_for_backward(exp_logits, target_mask, masked_target) + + return loss + + @staticmethod + @custom_bwd + def backward(ctx, output_grad): + # Retreive tensors from the forward path. + softmax, target_mask, masked_target = ctx.saved_tensors + + # All the inputs have softmax as their gradient. + grad_input = softmax + + # For simplicity, work with the 2D gradient. + partition_vocab_size = softmax.size()[-1] + grad_2d = grad_input.view(-1, partition_vocab_size) + + # Add the gradient from matching classes. + arange_1d = torch.arange(start=0, end=grad_2d.size()[0], device=get_current_device()) + grad_2d[arange_1d, masked_target] -= (1.0 - target_mask.view(-1).float()) + + # Finally elementwise multiplication with the output gradients. + grad_input.mul_(output_grad.unsqueeze(dim=-1)) + + return grad_input, None + + +@LOSSES.register_module +class VocabParallelCrossEntropyLoss2D(_Loss): + """ + Vocab parallel cross entropy loss for 2D parallelism + + :param reduction: whether to average the loss, defaults to True + + :type reduction: bool, optional + """ + + def __init__(self, reduction=True): + super().__init__() + self.reduction_mean = reduction + + def forward(self, logits, targets): + """Calculate loss between logits and targets + + :param logits: Output logits of model + :param targets: True targets from data + """ + targets = split_tensor_2d(targets) + loss = _VocabParallelCrossEntropy2D.apply( + logits, + targets, + ) + if self.reduction_mean: + loss = loss.mean() + loss = reduce_by_batch_2d(loss, True) return loss diff --git a/colossalai/nn/loss/loss_2p5d.py b/colossalai/nn/loss/loss_2p5d.py index f66f98c3e..b5379776b 100644 --- a/colossalai/nn/loss/loss_2p5d.py +++ b/colossalai/nn/loss/loss_2p5d.py @@ -1,6 +1,12 @@ -from colossalai.nn.layer.parallel_2p5d import reduce_by_batch_2p5d +import torch +import torch.distributed as dist +from colossalai.context import ParallelMode +from colossalai.core import global_context as gpc +from colossalai.nn.layer.parallel_2p5d import reduce_by_batch_2p5d, split_tensor_2p5d from colossalai.nn.layer.parallel_2p5d._utils import assert_tesseract_initialization from colossalai.registry import LOSSES +from colossalai.utils import get_current_device +from torch.cuda.amp import custom_bwd, custom_fwd from torch.nn.functional import cross_entropy from torch.nn.modules.loss import _Loss @@ -9,7 +15,7 @@ from torch.nn.modules.loss import _Loss class CrossEntropyLoss2p5D(_Loss): """ Cross entropy loss for 2.5D parallelism - + :param reduction: whether to average the loss, defaults to True :param args: Args for loss function 
:param kwargs: Kwargs for loss function @@ -29,8 +35,104 @@ class CrossEntropyLoss2p5D(_Loss): :param logits: Output logits of model :param targets: True targets from data """ + targets = split_tensor_2p5d(targets) loss = cross_entropy(logits, targets, reduction='none', *self.loss_args, **self.loss_kwargs) if self.reduction_mean: loss = loss.mean() - loss = reduce_by_batch_2p5d.apply(loss, True) + loss = reduce_by_batch_2p5d(loss, True) + return loss + + +class _VocabParallelCrossEntropy2p5D(torch.autograd.Function): + ### Modified based on megatron.mpu.cross_entropy ### + + @staticmethod + @custom_fwd(cast_inputs=torch.float32) + def forward(ctx, logits, targets): + # logits: [b/dq, h/q] + # loss: [b/dq] + # targets: [b/dq, h/q] + logits_max = torch.max(logits, dim=-1)[0] + torch.distributed.all_reduce(logits_max, + op=torch.distributed.ReduceOp.MAX, + group=gpc.get_group(ParallelMode.PARALLEL_2P5D_ROW)) + # Subtract the maximum value. + logits = logits - logits_max.unsqueeze(dim=-1) + + vocab_size = logits.size(-1) + rank = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_ROW) + vocab_start = rank * (vocab_size) + vocab_end = (rank + 1) * (vocab_size) - 1 + + target_mask = (targets < vocab_start) | (targets > vocab_end) + + masked_target = targets.clone() - vocab_start + masked_target[target_mask] = 0 + arange_1d = torch.arange( + start=0, + end=logits.size()[0], + ) + predicted_logits = logits[arange_1d, masked_target] + predicted_logits[target_mask] = 0. + dist.all_reduce(predicted_logits, group=gpc.get_group(ParallelMode.PARALLEL_2P5D_ROW)) + + exp_logits = torch.exp(logits) + sum_exp_logits = exp_logits.sum(dim=1) + dist.all_reduce(sum_exp_logits, group=gpc.get_group(ParallelMode.PARALLEL_2P5D_ROW)) + + loss = torch.log(sum_exp_logits) - predicted_logits + + exp_logits.div_(sum_exp_logits.unsqueeze(dim=-1)) + ctx.save_for_backward(exp_logits, target_mask, masked_target) + + return loss + + @staticmethod + @custom_bwd + def backward(ctx, output_grad): + # Retreive tensors from the forward path. + softmax, target_mask, masked_target = ctx.saved_tensors + + # All the inputs have softmax as their gradient. + grad_input = softmax + + # For simplicity, work with the 2D gradient. + partition_vocab_size = softmax.size()[-1] + grad_2d = grad_input.view(-1, partition_vocab_size) + + # Add the gradient from matching classes. + arange_1d = torch.arange(start=0, end=grad_2d.size()[0], device=get_current_device()) + grad_2d[arange_1d, masked_target] -= (1.0 - target_mask.view(-1).float()) + + # Finally elementwise multiplication with the output gradients. 
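+        # (chain rule: scale d(loss)/d(logits) by each sample's incoming gradient d(L)/d(loss))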
+ grad_input.mul_(output_grad.unsqueeze(dim=-1)) + + return grad_input, None + + +@LOSSES.register_module +class VocabParallelCrossEntropyLoss2p5D(_Loss): + """ + Vocab parallel cross entropy loss for 2.5D parallelism + + :param reduction: whether to average the loss, defaults to True + + :type reduction: bool, optional + """ + def __init__(self, reduction=True): + super().__init__() + self.reduction_mean = reduction + + def forward(self, logits, targets): + """Calculate loss between logits and targets + + :param logits: Output logits of model + :param targets: True targets from data + """ + targets = split_tensor_2p5d(targets) + loss = _VocabParallelCrossEntropy2p5D.apply(logits, targets) + if self.reduction_mean: + loss = loss.mean() + loss = reduce_by_batch_2p5d(loss, True) + return loss diff --git a/colossalai/nn/loss/loss_3d.py b/colossalai/nn/loss/loss_3d.py index 02ac06b37..0835d2770 100644 --- a/colossalai/nn/loss/loss_3d.py +++ b/colossalai/nn/loss/loss_3d.py @@ -1,23 +1,28 @@ -from colossalai.constants import INPUT_GROUP_3D, WEIGHT_GROUP_3D -from colossalai.nn.layer.parallel_3d import reduce_by_batch_3d +import torch +import torch.distributed as dist +from colossalai.constants import INPUT_GROUP_3D, WEIGHT_GROUP_3D, OUTPUT_GROUP_3D +from colossalai.core import global_context as gpc +from colossalai.nn.layer.parallel_3d import reduce_by_batch_3d, split_tensor_3d from colossalai.nn.layer.parallel_3d._utils import get_parallel_mode_from_env from colossalai.registry import LOSSES +from colossalai.utils import get_current_device +from torch.cuda.amp import custom_bwd, custom_fwd from torch.nn.functional import cross_entropy from torch.nn.modules.loss import _Loss + @LOSSES.register_module class CrossEntropyLoss3D(_Loss): """ Cross entropy loss for 3D parallelism - :param depth: depth for 3D parallelism - :type depth: int :param reduction: whether to average the loss, defaults to True - :type reduction: bool, optional - :param args: Args for loss function :param kwargs: Kwargs for loss function + + :type reduction: bool, optional """ + def __init__(self, reduction=True, *args, **kwargs): super().__init__() self.input_parallel_mode = get_parallel_mode_from_env(INPUT_GROUP_3D) @@ -32,8 +37,103 @@ class CrossEntropyLoss3D(_Loss): :param logits: Output logits of model :param targets: True targets from data """ + targets = split_tensor_3d(targets, 0, self.weight_parallel_mode) + targets = split_tensor_3d(targets, 0, self.input_parallel_mode) loss = cross_entropy(logits, targets, reduction='none', *self.loss_args, **self.loss_kwargs) if self.reduction_mean: loss = loss.mean() - loss = reduce_by_batch_3d.apply(loss, self.input_parallel_mode, self.weight_parallel_mode, True) + loss = reduce_by_batch_3d(loss, self.input_parallel_mode, self.weight_parallel_mode, True) + return loss + + +class _VocabParallelCrossEntropy3D(torch.autograd.Function): + # Adapted from megatron.mpu.cross_entropy + # loss[i] = -logits[i][targets] + log(sum(exp(logits[i]))) + + @staticmethod + @custom_fwd(cast_inputs=torch.float32) + def forward(ctx, logits, targets, output_parallel_mode): + # logits: [b/q^2, c/q] + # labels: [b/q^2] + # loss: [b/q^2] + logits_max = torch.max(logits, dim=-1)[0] + dist.all_reduce(logits_max, op=torch.distributed.ReduceOp.MAX, group=gpc.get_group(output_parallel_mode)) + # Subtract the maximum value. 
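+        # logits_max was already all-reduced (MAX) over the output-parallel group above, so
+        # every rank subtracts the same per-sample value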
+ logits = logits - logits_max.unsqueeze(dim=-1) + + vocab_size_per_partition = logits.size()[-1] + rank = gpc.get_local_rank(output_parallel_mode) + vocab_start = rank * vocab_size_per_partition + vocab_end = (rank + 1) * vocab_size_per_partition - 1 + + # loss[i] = 0 if targets[i] < vocab_start or targets[i] > vocab_end + target_mask = (targets < vocab_start) | (targets > vocab_end) + masked_target = targets.clone() - vocab_start + masked_target[target_mask] = 0 + arange_1d = torch.arange(start=0, end=logits.size()[0], device=get_current_device()) + predicted_logits = logits[arange_1d, masked_target] + predicted_logits = predicted_logits.clone().contiguous().view_as(targets) + predicted_logits[target_mask] = 0. + dist.all_reduce(predicted_logits, group=gpc.get_group(output_parallel_mode)) + + # Loss = log(sum(exp(logits))) - predicted-logit. + exp_logits = torch.exp(logits) + sum_exp_logits = exp_logits.sum(dim=-1) + dist.all_reduce(sum_exp_logits, group=gpc.get_group(output_parallel_mode)) + loss = torch.log(sum_exp_logits) - predicted_logits + + exp_logits.div_(sum_exp_logits.unsqueeze(dim=-1)) + ctx.save_for_backward(exp_logits, target_mask, masked_target) + + return loss + + @staticmethod + @custom_bwd + def backward(ctx, output_grad): + # Retreive tensors from the forward path. + softmax, target_mask, masked_target = ctx.saved_tensors + + # All the inputs have softmax as thier gradient. + input_grad = softmax + # For simplicity, work with the 2D gradient. + partition_vocab_size = softmax.size()[-1] + grad_2d = input_grad.view(-1, partition_vocab_size) + + # Add the gradient from matching classes. + arange_1d = torch.arange(start=0, end=grad_2d.size()[0], device=get_current_device()) + grad_2d[arange_1d, masked_target] -= (1.0 - target_mask.view(-1).float()) + input_grad.mul_(output_grad.unsqueeze(dim=-1)) + + return input_grad, None, None, None + + +@LOSSES.register_module +class VocabParallelCrossEntropyLoss3D(_Loss): + """ + Vocab parallel cross entropy loss for 2D parallelism + + :param reduction: whether to average the loss, defaults to True + + :type reduction: bool, optional + """ + + def __init__(self, reduction=True): + super().__init__() + self.input_parallel_mode = get_parallel_mode_from_env(INPUT_GROUP_3D) + self.weight_parallel_mode = get_parallel_mode_from_env(WEIGHT_GROUP_3D) + self.output_parallel_mode = get_parallel_mode_from_env(OUTPUT_GROUP_3D) + self.reduction_mean = reduction + + def forward(self, logits, targets): + """Calculate loss between logits and targets + + :param logits: Output logits of model + :param targets: True targets from data + """ + targets = split_tensor_3d(targets, 0, self.weight_parallel_mode) + targets = split_tensor_3d(targets, 0, self.input_parallel_mode) + loss = _VocabParallelCrossEntropy3D.apply(logits, targets, self.output_parallel_mode) + if self.reduction_mean: + loss = loss.mean() + loss = reduce_by_batch_3d(loss, self.input_parallel_mode, self.weight_parallel_mode, True) return loss diff --git a/colossalai/nn/metric/__init__.py b/colossalai/nn/metric/__init__.py index 7ce17b08b..00833b611 100644 --- a/colossalai/nn/metric/__init__.py +++ b/colossalai/nn/metric/__init__.py @@ -17,7 +17,7 @@ class Accuracy(nn.Module): def __init__(self): super().__init__() tensor_parallel = get_tensor_parallel_mode() - if tensor_parallel in ['None', '1d']: + if tensor_parallel not in _parallel_accuracy: self.acc = calc_acc else: self.acc = _parallel_accuracy[tensor_parallel]() diff --git a/colossalai/nn/metric/accuracy_2d.py 
b/colossalai/nn/metric/accuracy_2d.py index 4706e1700..4a3eb6f7a 100644 --- a/colossalai/nn/metric/accuracy_2d.py +++ b/colossalai/nn/metric/accuracy_2d.py @@ -1,5 +1,5 @@ import torch -from colossalai.nn.layer.parallel_2d import reduce_by_batch_2d +from colossalai.nn.layer.parallel_2d import reduce_by_batch_2d, split_tensor_2d from torch import nn from ._utils import calc_acc @@ -18,6 +18,7 @@ class Accuracy2D(nn.Module): :param targets: True labels from data """ with torch.no_grad(): + targets = split_tensor_2d(targets) correct = calc_acc(logits, targets) - correct = reduce_by_batch_2d.apply(correct) + correct = reduce_by_batch_2d(correct) return correct diff --git a/colossalai/nn/metric/accuracy_2p5d.py b/colossalai/nn/metric/accuracy_2p5d.py index 1bf34ae22..0eeedd46f 100644 --- a/colossalai/nn/metric/accuracy_2p5d.py +++ b/colossalai/nn/metric/accuracy_2p5d.py @@ -18,6 +18,7 @@ class Accuracy2p5D(nn.Module): :param targets: True labels from data """ with torch.no_grad(): + targets = split_tensor_2p5d(targets) correct = calc_acc(logits, targets) - correct = reduce_by_batch_2p5d.apply(correct) + correct = reduce_by_batch_2p5d(correct) return correct diff --git a/colossalai/nn/metric/accuracy_3d.py b/colossalai/nn/metric/accuracy_3d.py index e7612dde2..e24219e64 100644 --- a/colossalai/nn/metric/accuracy_3d.py +++ b/colossalai/nn/metric/accuracy_3d.py @@ -1,6 +1,6 @@ import torch from colossalai.constants import INPUT_GROUP_3D, WEIGHT_GROUP_3D -from colossalai.nn.layer.parallel_3d import reduce_by_batch_3d +from colossalai.nn.layer.parallel_3d import reduce_by_batch_3d, split_tensor_3d from colossalai.nn.layer.parallel_3d._utils import get_parallel_mode_from_env from torch import nn @@ -22,6 +22,8 @@ class Accuracy3D(nn.Module): :param targets: True labels from data """ with torch.no_grad(): + targets = split_tensor_3d(targets, 0, self.weight_parallel_mode) + targets = split_tensor_3d(targets, 0, self.input_parallel_mode) correct = calc_acc(logits, targets) - correct = reduce_by_batch_3d.apply(correct, self.input_parallel_mode, self.weight_parallel_mode) + correct = reduce_by_batch_3d(correct, self.input_parallel_mode, self.weight_parallel_mode) return correct diff --git a/colossalai/trainer/hooks/_log_hook.py b/colossalai/trainer/hooks/_log_hook.py index 29ef4efa3..2a081e088 100644 --- a/colossalai/trainer/hooks/_log_hook.py +++ b/colossalai/trainer/hooks/_log_hook.py @@ -224,7 +224,7 @@ class LogTimingByEpochHook(LogByEpochHook): super().__init__(logger=logger, interval=interval, priority=priority) self._timer = timer self._log_eval = log_eval - self._is_rank_to_log = is_dp_rank_0() and is_tp_rank_0() + self._is_rank_to_log = is_dp_rank_0() and is_tp_rank_0() and is_no_pp_or_last_stage() # extra handling to avoid the unstable readings of the first # few training steps to affect the history mean time @@ -256,7 +256,7 @@ class LogTimingByEpochHook(LogByEpochHook): """ if self._is_epoch_to_log(trainer) and self._is_rank_to_log: msg = self._get_message('Train') - self.logger.info(f'[Epoch {trainer.cur_epoch} / Train]: {msg}, #steps/epoch = {trainer.steps_per_epoch}') + self.logger.info(f'[Epoch {trainer.cur_epoch} / Train]: {msg} | #steps/epoch = {trainer.steps_per_epoch}') def after_test_epoch(self, trainer): """Writes log after finishing a testing epoch. 
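For reference, the vocab-parallel cross entropy kernels added above (_VocabParallelCrossEntropy1D/2D/2p5D/3D) all evaluate, per sample, loss = log(sum(exp(logits))) - logits[target], with the max, the target logit and the exp-sum each all-reduced across the vocab partitions. Below is a minimal single-process sketch of that arithmetic, assuming toy tensors and treating column slices of one logits matrix as stand-ins for the per-rank shards:

import torch
import torch.nn.functional as F

# 2 samples, vocab of 6 split into two "partitions" of 3 columns each.
logits = torch.randn(2, 6)
targets = torch.tensor([1, 4])
shards = logits.split(3, dim=-1)

# Quantities the real kernels all-reduce across ranks.
logits_max = torch.stack([s.max(dim=-1)[0] for s in shards]).max(dim=0)[0]
sum_exp = sum(torch.exp(s - logits_max.unsqueeze(-1)).sum(dim=-1) for s in shards)

predicted = torch.zeros(2)
for rank, s in enumerate(shards):
    start, end = 3 * rank, 3 * (rank + 1)
    mask = (targets >= start) & (targets < end)   # targets owned by this shard
    local = (targets - start).clamp(0, 2)         # masked-out entries are dummies
    vals = (s - logits_max.unsqueeze(-1))[torch.arange(2), local]
    predicted += torch.where(mask, vals, torch.zeros_like(vals))

loss = torch.log(sum_exp) - predicted
assert torch.allclose(loss, F.cross_entropy(logits, targets, reduction='none'), atol=1e-5)

The masking mirrors target_mask/masked_target in the kernels: each shard contributes the target logit only for labels that fall inside its own vocab range, and the contributions are summed across shards.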
diff --git a/colossalai/trainer/hooks/_metric_hook.py b/colossalai/trainer/hooks/_metric_hook.py index 69aa23eb9..3faf8b438 100644 --- a/colossalai/trainer/hooks/_metric_hook.py +++ b/colossalai/trainer/hooks/_metric_hook.py @@ -317,24 +317,29 @@ class ThroughputMetric(Metric): :param epoch_only: epoch only :type epoch_only: bool """ - def __init__(self, epoch_only: bool): + def __init__(self, epoch_only: bool, ignored_steps: int = 0): super().__init__(epoch_only=epoch_only) + self.ignored_steps = ignored_steps + self.cur_steps = 0 self.accumulated_num_samples = torch.zeros(1, device=get_current_device()) self.accumulated_used_time = torch.zeros(1, device=get_current_device()) self.last_step_num_samples = torch.zeros(1, device=get_current_device()) self.last_step_used_time = torch.zeros(1, device=get_current_device()) def reset(self) -> None: + # self.cur_steps = 0 self.accumulated_num_samples.zero_() self.accumulated_used_time.zero_() self.last_step_num_samples.zero_() self.last_step_used_time.zero_() def update(self, num_samples, time) -> None: + self.cur_steps += 1 self.last_step_num_samples.fill_(num_samples) self.last_step_used_time.fill_(time) - self.accumulated_num_samples += self.last_step_num_samples - self.accumulated_used_time += self.last_step_used_time + if self.cur_steps >= self.ignored_steps: + self.accumulated_num_samples += self.last_step_num_samples + self.accumulated_used_time += self.last_step_used_time def get_last_step_value(self): self.last_step_used_time = all_reduce(self.last_step_used_time, ParallelMode.DATA) / \ @@ -360,13 +365,14 @@ class ThroughputHook(MetricHook): :param priority: priority of throughput hook, defaults to 10 :type priority: int, optional """ - def __init__(self, priority: int = 10): + def __init__(self, ignored_steps: int = 0, priority: int = 10): super().__init__(priority) + self.ignored_steps = ignored_steps def after_hook_is_attached(self, trainer): self._check_metric_states_initialization(trainer) if self._is_stage_to_compute: - self.metric = ThroughputMetric(epoch_only=True) + self.metric = ThroughputMetric(epoch_only=True, ignored_steps=self.ignored_steps) # register the metric trainer.states['metrics']['train']['Throughput'] = self.metric diff --git a/colossalai/utils/__init__.py b/colossalai/utils/__init__.py index 2ce181954..c769022a5 100644 --- a/colossalai/utils/__init__.py +++ b/colossalai/utils/__init__.py @@ -1,8 +1,9 @@ from .activation_checkpoint import checkpoint from .common import (clip_grad_norm_fp32, conditional_context, copy_tensor_parallel_attributes, count_zeros_fp32, free_port, is_dp_rank_0, is_model_parallel_parameter, is_no_pp_or_last_stage, is_tp_rank_0, - is_using_ddp, is_using_pp, is_using_sequence, multi_tensor_applier, param_is_not_tensor_parallel_duplicate, - print_rank_0, switch_virtual_pipeline_parallel_rank, sync_model_param) + is_using_ddp, is_using_pp, is_using_sequence, model_branch_context, multi_tensor_applier, + param_is_not_tensor_parallel_duplicate, print_rank_0, switch_virtual_pipeline_parallel_rank, + sync_model_param) from .cuda import empty_cache, get_current_device, set_to_cuda, synchronize from .data_sampler import DataParallelSampler, get_dataloader from .gradient_accumulation import accumulate_gradient @@ -11,9 +12,9 @@ from .timer import MultiTimer, Timer __all__ = [ 'checkpoint', 'free_port', 'print_rank_0', 'sync_model_param', 'is_dp_rank_0', 'is_tp_rank_0', - 'is_no_pp_or_last_stage', 'is_using_ddp', 'is_using_pp', 'is_using_sequence', 'conditional_context', - 
'is_model_parallel_parameter', 'clip_grad_norm_fp32', 'count_zeros_fp32', 'copy_tensor_parallel_attributes', - 'param_is_not_tensor_parallel_duplicate', 'get_current_device', 'synchronize', 'empty_cache', 'set_to_cuda', - 'report_memory_usage', 'Timer', 'MultiTimer', 'multi_tensor_applier', 'accumulate_gradient', 'DataParallelSampler', - 'get_dataloader', 'switch_virtual_pipeline_parallel_rank' + 'is_no_pp_or_last_stage', 'is_using_ddp', 'is_using_pp', 'is_using_sequence', 'model_branch_context', + 'conditional_context', 'is_model_parallel_parameter', 'clip_grad_norm_fp32', 'count_zeros_fp32', + 'copy_tensor_parallel_attributes', 'param_is_not_tensor_parallel_duplicate', 'get_current_device', 'synchronize', + 'empty_cache', 'set_to_cuda', 'report_memory_usage', 'Timer', 'MultiTimer', 'multi_tensor_applier', + 'accumulate_gradient', 'DataParallelSampler', 'get_dataloader', 'switch_virtual_pipeline_parallel_rank' ] diff --git a/colossalai/utils/common.py b/colossalai/utils/common.py index 2ae8b754e..942801018 100644 --- a/colossalai/utils/common.py +++ b/colossalai/utils/common.py @@ -6,8 +6,6 @@ import socket import torch from torch._six import inf -import colossalai.context.parallel_mode - try: import colossal_C except: @@ -20,6 +18,7 @@ from colossalai.constants import IS_TENSOR_PARALLEL, NUM_PARTITIONS, TENSOR_PARA from colossalai.context.parallel_mode import ParallelMode from colossalai.core import global_context as gpc from colossalai.global_variables import moe_env +from colossalai.global_variables import tensor_parallel_env as env from .multi_tensor_apply import multi_tensor_applier @@ -62,8 +61,7 @@ def sync_model_param(model, parallel_mode): if gpc.is_initialized(parallel_mode) and gpc.get_world_size(parallel_mode) > 1: for param in model.parameters(): ranks = gpc.get_ranks_in_group(parallel_mode) - dist.broadcast( - param, src=ranks[0], group=gpc.get_group(parallel_mode)) + dist.broadcast(param, src=ranks[0], group=gpc.get_group(parallel_mode)) def is_dp_rank_0(): @@ -99,6 +97,15 @@ def conditional_context(context_manager, enable=True): yield +class model_branch_context(object): + + def __enter__(self): + self.env_status = env.save() + + def __exit__(self, *exc_info): + env.load(**self.env_status) + + def is_model_parallel_parameter(p): return hasattr(p, IS_TENSOR_PARALLEL) and getattr(p, IS_TENSOR_PARALLEL) @@ -124,9 +131,10 @@ def _calc_lp(grads, norm_type): norm = 0.0 for grad in grads: grad_norm = torch.norm(grad, norm_type) - norm += grad_norm ** norm_type + norm += grad_norm**norm_type return norm + # ======== Gradient Clipping ========= @@ -183,7 +191,7 @@ def clip_grad_norm_fp32(parameters, max_norm, norm_type=2): moe_parallel_grads = [] # used to collect moe tensor parallel gradients for p in params: if is_model_parallel_parameter(p): - reductor = (gpc.get_world_size(ParallelMode.TENSOR) / getattr(p, NUM_PARTITIONS)) ** (1 / norm_type) + reductor = (gpc.get_world_size(ParallelMode.TENSOR) / getattr(p, NUM_PARTITIONS))**(1 / norm_type) tensor_parallel_grads.append(p.grad.data / reductor) elif is_moe_parallel_parameter(p): moe_parallel_grads.append(p.grad.data) @@ -191,32 +199,24 @@ def clip_grad_norm_fp32(parameters, max_norm, norm_type=2): no_tensor_parallel_grads.append(p.grad.data) if norm_type == 2.0: - tensor_parallel_norm = _calc_l2_norm( - tensor_parallel_grads) ** norm_type - no_tensor_parallel_norm = _calc_l2_norm( - no_tensor_parallel_grads) ** norm_type - moe_parallel_norm = _calc_l2_norm( - moe_parallel_grads) ** norm_type + tensor_parallel_norm = 
_calc_l2_norm(tensor_parallel_grads)**norm_type + no_tensor_parallel_norm = _calc_l2_norm(no_tensor_parallel_grads)**norm_type + moe_parallel_norm = _calc_l2_norm(moe_parallel_grads)**norm_type else: tensor_parallel_norm = _calc_lp(tensor_parallel_grads, norm_type) - no_tensor_parallel_norm = _calc_lp( - no_tensor_parallel_grads, norm_type) + no_tensor_parallel_norm = _calc_lp(no_tensor_parallel_grads, norm_type) moe_parallel_norm = _calc_lp(moe_parallel_grads, norm_type) # Sum across all model-parallel GPUs. if gpc.is_initialized(ParallelMode.TENSOR) and len(tensor_parallel_grads) > 0: - dist.all_reduce(tensor_parallel_norm, - op=dist.ReduceOp.SUM, - group=gpc.get_group(ParallelMode.TENSOR)) + dist.all_reduce(tensor_parallel_norm, op=dist.ReduceOp.SUM, group=gpc.get_group(ParallelMode.TENSOR)) # Sum across all moe-tensor-parallel GPUs if len(moe_parallel_grads) > 0: dist.all_reduce(moe_parallel_norm, group=gpc.get_group(ParallelMode.MOE_MODEL)) no_tensor_parallel_norm += moe_parallel_norm total_norm = tensor_parallel_norm + no_tensor_parallel_norm if gpc.is_initialized(ParallelMode.PIPELINE) and gpc.get_world_size(ParallelMode.PIPELINE) > 1: - dist.all_reduce(total_norm, - op=dist.ReduceOp.SUM, - group=gpc.get_group(ParallelMode.PIPELINE)) - total_norm = total_norm ** (1.0 / norm_type) + dist.all_reduce(total_norm, op=dist.ReduceOp.SUM, group=gpc.get_group(ParallelMode.PIPELINE)) + total_norm = total_norm**(1.0 / norm_type) if type(total_norm) == 'torch.cuda.FloatTensor': total_norm = total_norm.item() @@ -225,10 +225,7 @@ def clip_grad_norm_fp32(parameters, max_norm, norm_type=2): if clip_coeff < 1.0: grads = [p.grad.detach() for p in params] dummy_overflow_buf = torch.cuda.IntTensor([0]) - multi_tensor_applier(colossal_C.multi_tensor_scale, - dummy_overflow_buf, - [grads, grads], - clip_coeff) + multi_tensor_applier(colossal_C.multi_tensor_scale, dummy_overflow_buf, [grads, grads], clip_coeff) return total_norm @@ -254,15 +251,14 @@ def count_zeros_fp32(parameters): # Sum across all model-parallel GPUs. 
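The reflowed clip_grad_norm_fp32 above keeps the usual recipe for clipping sharded gradients: raise each shard's local norm to the norm_type-th power, all-reduce the partial sums across the model-parallel group, take the 1/norm_type root, and apply the same clip coefficient on every rank once the global norm exceeds max_norm. A hedged sketch with plain torch.distributed, assuming every gradient is a genuine shard (no replicated parameters, so the reductor correction is omitted) and that the tensor-parallel process group is passed in:

import torch
import torch.distributed as dist

def clip_grad_norm_sharded(parameters, max_norm, norm_type=2.0, tp_group=None):
    """Clip gradients that are sharded across the ranks of `tp_group`."""
    grads = [p.grad for p in parameters if p.grad is not None]
    local = torch.zeros(1, device=grads[0].device)
    for g in grads:
        local += torch.norm(g, norm_type) ** norm_type        # local partial sum
    if tp_group is not None and dist.is_initialized():
        dist.all_reduce(local, op=dist.ReduceOp.SUM, group=tp_group)
    total_norm = local.item() ** (1.0 / norm_type)
    clip_coeff = max_norm / (total_norm + 1e-6)
    if clip_coeff < 1.0:                                       # same scale on every rank
        for g in grads:
            g.mul_(clip_coeff)
    return total_norm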
ops = [] - ops.append(dist.all_reduce(total_num_zeros, - op=dist.ReduceOp.SUM, - group=gpc.get_group(ParallelMode.TENSOR), - async_op=True)) + ops.append( + dist.all_reduce(total_num_zeros, op=dist.ReduceOp.SUM, group=gpc.get_group(ParallelMode.TENSOR), async_op=True)) if gpc.is_initialized(ParallelMode.PIPELINE): - ops.append(dist.all_reduce(total_num_zeros, - op=dist.ReduceOp.SUM, - group=gpc.get_group(ParallelMode.PIPELINE), - async_op=True)) + ops.append( + dist.all_reduce(total_num_zeros, + op=dist.ReduceOp.SUM, + group=gpc.get_group(ParallelMode.PIPELINE), + async_op=True)) for req in ops: req.wait() @@ -279,9 +275,8 @@ def copy_tensor_parallel_attributes(src_tensor, dst_tensor): def param_is_not_tensor_parallel_duplicate(param): - return (hasattr(param, IS_TENSOR_PARALLEL) and - getattr(param, IS_TENSOR_PARALLEL)) or ( - gpc.get_local_rank(ParallelMode.TENSOR) == 0) + return (hasattr(param, IS_TENSOR_PARALLEL) and getattr(param, IS_TENSOR_PARALLEL)) or (gpc.get_local_rank( + ParallelMode.TENSOR) == 0) @contextmanager diff --git a/model_zoo/gpt/gpt.py b/model_zoo/gpt/gpt.py index bfa85813f..b5413f6b8 100644 --- a/model_zoo/gpt/gpt.py +++ b/model_zoo/gpt/gpt.py @@ -3,12 +3,20 @@ from typing import Callable import torch from colossalai import nn as col_nn -from colossalai.nn.layer.utils import CheckpointModule -from colossalai.registry import LAYERS, MODELS, LOSSES +from colossalai.builder.pipeline import partition_uniform +from colossalai.context import ParallelMode +from colossalai.core import global_context as gpc +from colossalai.logging import get_dist_logger +from colossalai.nn.layer.utils import CheckpointModule, divide +from colossalai.nn.layer.wrapper import PipelineSharedModuleWrapper +from colossalai.registry import LAYERS, LOSSES, MODELS from colossalai.utils import get_current_device from torch import dtype, nn -__all__ = ['GPT', 'GPTLMLoss', 'gpt2_small', 'gpt2_medium', 'gpt2_large', 'gpt2_xl', 'gpt3'] +__all__ = [ + 'GPT', 'GPTLMLoss', 'gpt2_small', 'gpt2_medium', 'gpt2_large', 'gpt2_xl', 'gpt2_8B', 'gpt2_xl_pipeline', + 'gpt2_8B_pipeline', 'gpt3', 'gpt3_pipeline' +] @LAYERS.register_module @@ -18,7 +26,7 @@ class GPTEmbedding(nn.Module): vocab_size: int, max_position_embeddings: int, num_tokentypes: int = 0, - padding_idx: int = 0, + padding_idx: int = None, dropout: float = 0., dtype: dtype = None) -> None: super().__init__() @@ -34,7 +42,7 @@ class GPTEmbedding(nn.Module): def word_embedding_weight(self): return self.word_embeddings.weight - def forward(self, input_ids, position_ids=None, tokentype_ids=None): + def forward(self, input_ids, attention_mask=None, position_ids=None, tokentype_ids=None): seq_length = input_ids.size(1) if position_ids is None: position_ids = torch.arange(seq_length, dtype=torch.long, device=get_current_device()).unsqueeze(0) @@ -42,7 +50,20 @@ class GPTEmbedding(nn.Module): if self.tokentype_embeddings is not None and tokentype_ids is not None: x = x + self.tokentype_embeddings(tokentype_ids) x = self.dropout(x) - return x + + # We create a 3D attention mask from a 2D tensor mask. 
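The comment being added here (and continued just below) describes the standard additive-mask trick now performed inside GPTEmbedding: a 2D padding mask of shape [batch_size, seq_len] is reshaped to [batch_size, 1, 1, seq_len] and mapped to 0 for visible positions and -10000 for padded ones, so it broadcasts against the attention scores and is simply added before the softmax. A small self-contained illustration in plain PyTorch, without the partition_batch tensor-parallel split:

import torch

def make_additive_mask(attention_mask, dtype=torch.float32):
    # attention_mask: [batch, seq_len], 1 = keep, 0 = pad
    mask = attention_mask[:, None, None, :].to(dtype)    # [batch, 1, 1, seq_len]
    return (1.0 - mask) * -10000.0                        # 0 where kept, -1e4 where padded

scores = torch.randn(2, 12, 8, 8)                         # [batch, heads, q_len, k_len]
mask = torch.tensor([[1] * 8, [1] * 5 + [0] * 3])         # second sequence has 3 pad tokens
probs = torch.softmax(scores + make_additive_mask(mask), dim=-1)
assert probs[1, :, :, 5:].sum().item() < 1e-3             # padded keys receive ~zero attention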
+ # Sizes are [batch_size, 1, 1, to_seq_length] + # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length] + # Adapted from huggingface + if attention_mask is not None: + batch_size = input_ids.shape[0] + attention_mask = attention_mask.view(batch_size, -1) + attention_mask = col_nn.partition_batch(attention_mask) + attention_mask = attention_mask.unsqueeze(1).unsqueeze(2) + attention_mask = attention_mask.to(dtype=x.dtype) # fp16 compatibility + attention_mask = (1.0 - attention_mask) * -10000.0 + + return x, attention_mask @LAYERS.register_module @@ -53,20 +74,32 @@ class GPTSelfAttention(nn.Module): attention_dropout: float, dropout: float, bias: bool = True, + fuse_scale_mask_softmax: bool = False, dtype: dtype = None) -> None: super().__init__() - - self.attention_head_size = dim // num_heads + self.fuse_scale_mask_softmax = fuse_scale_mask_softmax + self.attention_head_size = divide(dim, num_heads) self.query_key_value = col_nn.Linear(dim, 3 * dim, dtype=dtype, bias=bias) + if fuse_scale_mask_softmax: + from colossalai.kernel import FusedScaleMaskSoftmax + from colossalai.kernel.cuda_native.scaled_softmax import AttnMaskType + self.softmax = FusedScaleMaskSoftmax(input_in_fp16=True, + input_in_bf16=False, + attn_mask_type=AttnMaskType.causal, + scaled_masked_softmax_fusion=True, + mask_func=None, + softmax_in_fp32=True, + scale=math.sqrt(self.attention_head_size)) + else: + self.softmax = nn.Softmax(dim=-1) self.attention_dropout = col_nn.Dropout(attention_dropout) self.dense = col_nn.Linear(dim, dim, dtype=dtype, bias=True) self.dropout = col_nn.Dropout(dropout) - self.softmax = nn.Softmax(dim=-1) def forward(self, x, attention_mask=None): qkv = self.query_key_value(x) all_head_size = qkv.shape[-1] // 3 - num_attention_heads = all_head_size // self.attention_head_size + num_attention_heads = divide(all_head_size, self.attention_head_size) new_qkv_shape = qkv.shape[:-1] + \ (num_attention_heads, 3 * self.attention_head_size) qkv = qkv.view(new_qkv_shape) @@ -74,17 +107,20 @@ class GPTSelfAttention(nn.Module): q, k, v = torch.chunk(qkv, 3, dim=-1) x = torch.matmul(q, k.transpose(-1, -2)) - x = x / math.sqrt(self.attention_head_size) - # causal mask - q_len, k_len = q.size(-2), k.size(-2) - causal_mask = torch.tril(torch.ones((q_len, k_len), dtype=torch.uint8, - device=get_current_device())).view(1, 1, q_len, k_len).bool() - x = torch.where(causal_mask, x, torch.tensor(-1e4, dtype=x.dtype, device=get_current_device())) + if self.fuse_scale_mask_softmax: + x = self.softmax(x, attention_mask) + else: + x = x / math.sqrt(self.attention_head_size) + # causal mask + q_len, k_len = q.size(-2), k.size(-2) + causal_mask = torch.tril(torch.ones((q_len, k_len), dtype=torch.uint8, + device=get_current_device())).view(1, 1, q_len, k_len).bool() + x = torch.where(causal_mask, x, torch.tensor(-1e4, dtype=x.dtype, device=get_current_device())) + if attention_mask is not None: + x = x + attention_mask + x = self.softmax(x) - if attention_mask is not None: - x = x + attention_mask - x = self.softmax(x) x = self.attention_dropout(x) x = torch.matmul(x, v) @@ -102,15 +138,16 @@ class GPTSelfAttention(nn.Module): class GPTMLP(nn.Module): def __init__(self, dim: int, - mlp_ratio: int, + mlp_ratio: float, activation: Callable, dropout: float, dtype: dtype = None, bias: bool = True): super().__init__() - self.dense_1 = col_nn.Linear(dim, mlp_ratio * dim, dtype=dtype, bias=bias) + intermediate_dim = int(dim * mlp_ratio) + self.dense_1 = col_nn.Linear(dim, intermediate_dim, dtype=dtype, 
bias=bias) self.activation = activation - self.dense_2 = col_nn.Linear(mlp_ratio * dim, dim, dtype=dtype, bias=bias) + self.dense_2 = col_nn.Linear(intermediate_dim, dim, dtype=dtype, bias=bias) self.dropout = col_nn.Dropout(dropout) def forward(self, x): @@ -126,27 +163,44 @@ class GPTBlock(CheckpointModule): def __init__(self, dim: int, num_heads: int, - mlp_ratio: int, + mlp_ratio: float, activation: Callable, attention_dropout: float = 0., dropout: float = 0., + layernorm_epsilon: float = 1e-5, dtype: dtype = None, bias: bool = True, + apply_post_layernorm: bool = False, + fuse_scale_mask_softmax: bool = False, checkpoint: bool = False): - super().__init__(checkpoint=checkpoint) - self.norm1 = col_nn.LayerNorm(normalized_shape=dim, eps=1e-6, dtype=dtype) + super().__init__(checkpoint) + self.apply_post_layernorm = apply_post_layernorm + self.norm1 = col_nn.LayerNorm(normalized_shape=dim, eps=layernorm_epsilon, dtype=dtype) self.attn = GPTSelfAttention(dim=dim, num_heads=num_heads, attention_dropout=attention_dropout, dropout=dropout, bias=bias, + fuse_scale_mask_softmax=fuse_scale_mask_softmax, dtype=dtype) - self.norm2 = col_nn.LayerNorm(normalized_shape=dim, eps=1e-6, dtype=dtype) + self.norm2 = col_nn.LayerNorm(normalized_shape=dim, eps=layernorm_epsilon, dtype=dtype) self.mlp = GPTMLP(dim=dim, mlp_ratio=mlp_ratio, activation=activation, dropout=dropout, dtype=dtype, bias=bias) def _forward(self, x, attention_mask=None): - x = x + self.attn(self.norm1(x), attention_mask) - x = x + self.mlp(self.norm2(x)) + if not self.apply_post_layernorm: + residual = x + x = self.norm1(x) + if self.apply_post_layernorm: + residual = x + x = residual + self.attn(x, attention_mask) + + if not self.apply_post_layernorm: + residual = x + x = self.norm2(x) + if self.apply_post_layernorm: + residual = x + x = residual + self.mlp(x) + return x, attention_mask @@ -161,6 +215,10 @@ class GPTLMHead(nn.Module): super().__init__() self.dense = col_nn.Classifier(dim, vocab_size, word_embeeding_weight, bias=bias, dtype=dtype) + @property + def weight(self): + return self.dense.weight + def forward(self, x): x = self.dense(x) return x @@ -187,18 +245,19 @@ class GPT(nn.Module): dim: int = 768, num_heads: int = 12, depth: int = 12, - mlp_ratio: int = 4, + mlp_ratio: float = 4.0, dropout: float = 0.1, embedding_dropout: float = 0.1, attention_dropout: float = 0.1, layernorm_epsilon: float = 1e-5, activation: Callable = nn.functional.gelu, - checkpoint: bool = False, + padding_idx: int = None, dtype: dtype = None, bias: bool = True, - padding_idx: int = 0) -> None: + apply_post_layernorm: bool = False, + fuse_scale_mask_softmax: bool = False, + checkpoint: bool = False) -> None: super().__init__() - self.dtype = dtype self.embed = GPTEmbedding(embedding_dim=dim, vocab_size=vocab_size, max_position_embeddings=max_position_embeddings, @@ -213,8 +272,11 @@ class GPT(nn.Module): activation=activation, attention_dropout=attention_dropout, dropout=dropout, + layernorm_epsilon=layernorm_epsilon, dtype=dtype, bias=bias, + apply_post_layernorm=apply_post_layernorm, + fuse_scale_mask_softmax=fuse_scale_mask_softmax, checkpoint=checkpoint, ) for _ in range(depth) ]) @@ -224,22 +286,10 @@ class GPT(nn.Module): self.head = GPTLMHead(dim=dim, vocab_size=vocab_size, word_embeeding_weight=self.embed.word_embedding_weight, - bias=bias, dtype=dtype) def forward(self, input_ids, attention_mask=None): - # We create a 3D attention mask from a 2D tensor mask. 
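The GPTBlock rewrite above adds an apply_post_layernorm switch that only changes where the residual branch is taken: from the block input (the default, GPT-2 style pre-LN) or from the LayerNorm output. Stripped of the Colossal-AI layers, the control flow is equivalent to the sketch below (a simplification; the real _forward also passes the attention mask through):

def gpt_block_forward(x, attn, mlp, norm1, norm2, apply_post_layernorm=False):
    h = norm1(x)
    residual = h if apply_post_layernorm else x   # residual from normalized or raw input
    x = residual + attn(h)

    h = norm2(x)
    residual = h if apply_post_layernorm else x
    return residual + mlp(h)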
- # Sizes are [batch_size, 1, 1, to_seq_length] - # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length] - # Adapted from huggingface - if attention_mask is not None: - batch_size = input_ids.shape[0] - attention_mask = attention_mask.view(batch_size, -1) - attention_mask = attention_mask.unsqueeze(1).unsqueeze(2) - attention_mask = attention_mask.to(dtype=self.dtype) # fp16 compatibility - attention_mask = (1.0 - attention_mask) * -10000.0 - - x = self.embed(input_ids) + x, attention_mask = self.embed(input_ids, attention_mask) for block in self.blocks: x, attention_mask = block(x, attention_mask) @@ -249,11 +299,103 @@ class GPT(nn.Module): return x +class PipelineGPT(nn.Module): + def __init__(self, + vocab_size: int = 50304, + max_position_embeddings: int = 1024, + dim: int = 768, + num_heads: int = 12, + depth: int = 12, + mlp_ratio: float = 4.0, + dropout: float = 0.1, + embedding_dropout: float = 0.1, + attention_dropout: float = 0.1, + layernorm_epsilon: float = 1e-5, + activation: Callable = nn.functional.gelu, + padding_idx: int = None, + dtype: dtype = None, + bias: bool = True, + apply_post_layernorm: bool = False, + fuse_scale_mask_softmax: bool = False, + checkpoint: bool = False, + first: bool = False, + last: bool = False): + super().__init__() + self.checkpoint = checkpoint + self.first = first + self.last = last + if first: + self.embed = GPTEmbedding(embedding_dim=dim, + vocab_size=vocab_size, + max_position_embeddings=max_position_embeddings, + padding_idx=padding_idx, + dropout=embedding_dropout, + dtype=dtype) + self.blocks = nn.ModuleList([ + GPTBlock( + dim=dim, + num_heads=num_heads, + mlp_ratio=mlp_ratio, + activation=activation, + attention_dropout=attention_dropout, + dropout=dropout, + layernorm_epsilon=layernorm_epsilon, + dtype=dtype, + bias=bias, + apply_post_layernorm=apply_post_layernorm, + fuse_scale_mask_softmax=fuse_scale_mask_softmax, + checkpoint=checkpoint, + ) for _ in range(depth) + ]) + if self.last: + self.norm = col_nn.LayerNorm(normalized_shape=dim, eps=layernorm_epsilon, dtype=dtype) + self.head = GPTLMHead(dim=dim, vocab_size=vocab_size, dtype=dtype) + + def forward(self, x=None, input_ids=None, attention_mask=None): + if self.first: + x, attention_mask = self.embed(input_ids, attention_mask) + + for block in self.blocks: + x, attention_mask = block(x, attention_mask) + + if self.last: + x = self.head(self.norm(x)) + + return x + + def _create_gpt_model(**model_kwargs): model = GPT(**model_kwargs) return model +def _create_gpt_pipeline_model(depth=48, num_chunks=1, layer_partitions=None, **model_kwargs): + logger = get_dist_logger() + pipeline_size = gpc.get_world_size(ParallelMode.PIPELINE) + pipeline_rank = gpc.get_local_rank(ParallelMode.PIPELINE) + rank = gpc.get_global_rank() + wrapper = PipelineSharedModuleWrapper([0, pipeline_size - 1]) + parts = partition_uniform(depth, pipeline_size, + num_chunks)[pipeline_rank] if layer_partitions is None else layer_partitions + models = [] + for start, end in parts: + model_kwargs['first'] = start == 0 + model_kwargs['last'] = end == depth + model_kwargs['depth'] = end - start + chunk = PipelineGPT(**model_kwargs).to(get_current_device()) + if start == 0: + wrapper.register_parameter(chunk.embed.word_embedding_weight) + elif end == depth: + wrapper.register_parameter(chunk.head.weight) + models.append(chunk) + logger.info(f'==> Rank {rank} built layer {start}-{end} / total {depth}') + if len(models) == 1: + model = models[0] + else: + model = nn.ModuleList(models) + return 
model + + @MODELS.register_module def gpt2_small(**kwargs): model_kwargs = dict(dim=768, depth=12, num_heads=12, **kwargs) @@ -262,23 +404,47 @@ def gpt2_small(**kwargs): @MODELS.register_module def gpt2_medium(**kwargs): - model_kwargs = dict(dim=1024, depth=24, num_heads=16, **kwargs) + model_kwargs = dict(dim=1024, depth=24, num_heads=8, **kwargs) return _create_gpt_model(**model_kwargs) @MODELS.register_module def gpt2_large(**kwargs): - model_kwargs = dict(dim=1280, depth=36, num_heads=20, **kwargs) + model_kwargs = dict(dim=1536, depth=36, num_heads=12, **kwargs) return _create_gpt_model(**model_kwargs) @MODELS.register_module def gpt2_xl(**kwargs): - model_kwargs = dict(dim=1600, depth=48, num_heads=25, **kwargs) + model_kwargs = dict(dim=1600, depth=48, num_heads=16, **kwargs) return _create_gpt_model(**model_kwargs) @MODELS.register_module -def gpt3(**kwargs): - model_kwargs = dict(dim=12288, max_position_embeddings=2048, depth=96, num_heads=96, **kwargs) +def gpt2_8B(**kwargs): + model_kwargs = dict(dim=3072, depth=72, num_heads=24, **kwargs) return _create_gpt_model(**model_kwargs) + + +@MODELS.register_module +def gpt2_xl_pipeline(**kwargs): + model_kwargs = dict(dim=1600, depth=48, num_heads=20, **kwargs) + return _create_gpt_pipeline_model(**model_kwargs) + + +@MODELS.register_module +def gpt2_8B_pipeline(**kwargs): + model_kwargs = dict(dim=3072, depth=72, num_heads=24, **kwargs) + return _create_gpt_pipeline_model(**model_kwargs) + + +@MODELS.register_module +def gpt3(**kwargs): + model_kwargs = dict(dim=12288, depth=96, num_heads=96, **kwargs) + return _create_gpt_model(**model_kwargs) + + +@MODELS.register_module +def gpt3_pipeline(**kwargs): + model_kwargs = dict(dim=12288, depth=96, num_heads=96, **kwargs) + return _create_gpt_pipeline_model(**model_kwargs) diff --git a/tests/test_layers/test_1d/checks_1d/check_layer_1d.py b/tests/test_layers/test_1d/checks_1d/check_layer_1d.py index ec4ceb2c1..5e1681da9 100644 --- a/tests/test_layers/test_1d/checks_1d/check_layer_1d.py +++ b/tests/test_layers/test_1d/checks_1d/check_layer_1d.py @@ -1,12 +1,14 @@ import torch import torch.distributed as dist -from torch.nn import Parameter -import time from colossalai.context.parallel_mode import ParallelMode from colossalai.core import global_context as gpc -from colossalai.nn import Linear1D_Col, Linear1D_Row +from colossalai.global_variables import tensor_parallel_env as env +from colossalai.nn import (Classifier1D, Embedding1D, Linear1D_Col, Linear1D_Row, VanillaClassifier, + VocabParallelClassifier1D, VocabParallelCrossEntropyLoss1D, VocabParallelEmbedding1D) from colossalai.utils import get_current_device, print_rank_0 -from .common import HIDDEN_SIZE, DEPTH, BATCH_SIZE, SEQ_LENGTH, NUM_CLASSES, check_equal, IMG_SIZE +from torch.nn import Parameter + +from .common import BATCH_SIZE, DEPTH, HIDDEN_SIZE, NUM_CLASSES, SEQ_LENGTH, VOCAB_SIZE, check_equal def check_linear_col(): @@ -144,3 +146,351 @@ def check_linear_row(): check_equal(B_grad, layer.bias.grad) print_rank_0('linear_row backward: pass') + + +def check_embed(): + device = get_current_device() + dtype = torch.float32 + + i = gpc.get_local_rank(ParallelMode.PARALLEL_1D) + + embed = Embedding1D(VOCAB_SIZE, HIDDEN_SIZE) + embed = embed.to(dtype).to(device) + embed_master = torch.nn.Embedding(VOCAB_SIZE, HIDDEN_SIZE) + embed_master = embed_master.to(dtype).to(device) + + weight_master = embed_master.weight.data + torch.distributed.broadcast(weight_master, src=0) + weight = torch.chunk(weight_master, DEPTH, dim=-1)[i] + 
embed.weight.data.copy_(weight) + + A_shape = (BATCH_SIZE, SEQ_LENGTH) + A_master = torch.randint(VOCAB_SIZE, A_shape, device=device) + torch.distributed.broadcast(A_master, src=0) + A = A_master.clone() + out = embed(A) + + A_master = A_master.clone() + C_master = embed_master(A_master) + C = C_master.clone() + check_equal(out, C) + print_rank_0('embed forward: pass') + + grad_shape = C_master.shape + grad_master = torch.randn(grad_shape, dtype=dtype, device=device) + torch.distributed.broadcast(grad_master, src=0) + grad = grad_master.clone() + out.backward(grad) + grad_master = grad_master.clone() + C_master.backward(grad_master) + + B_grad = embed_master.weight.grad + B_grad = torch.chunk(B_grad, DEPTH, dim=-1)[i] + check_equal(B_grad, embed.weight.grad) + print_rank_0('embed backward: pass') + + +def check_vocab_parallel_embed(): + device = get_current_device() + dtype = torch.float32 + + i = gpc.get_local_rank(ParallelMode.PARALLEL_1D) + + embed = VocabParallelEmbedding1D(VOCAB_SIZE, HIDDEN_SIZE) + embed = embed.to(dtype).to(device) + embed_master = torch.nn.Embedding(VOCAB_SIZE, HIDDEN_SIZE) + embed_master = embed_master.to(dtype).to(device) + + weight_master = embed_master.weight.data + torch.distributed.broadcast(weight_master, src=0) + weight = torch.chunk(weight_master, DEPTH, dim=0)[i] + embed.weight.data.copy_(weight) + + A_shape = (BATCH_SIZE, SEQ_LENGTH) + A_master = torch.randint(VOCAB_SIZE, A_shape, device=device) + torch.distributed.broadcast(A_master, src=0) + A = A_master.clone() + out = embed(A) + + A_master = A_master.clone() + C_master = embed_master(A_master) + C = C_master.clone() + check_equal(out, C) + print_rank_0('vocab parallel embed forward: pass') + + grad_shape = C_master.shape + grad_master = torch.randn(grad_shape, dtype=dtype, device=device) + torch.distributed.broadcast(grad_master, src=0) + grad = grad_master.clone() + out.backward(grad) + grad_master = grad_master.clone() + C_master.backward(grad_master) + + B_grad = embed_master.weight.grad + B_grad = torch.chunk(B_grad, DEPTH, dim=0)[i] + check_equal(B_grad, embed.weight.grad) + print_rank_0('vocab parallel embed backward: pass') + + +def check_classifier_no_given_weight(): + device = get_current_device() + dtype = torch.float32 + + i = gpc.get_local_rank(ParallelMode.PARALLEL_1D) + + env.parallel_input_1d = False + parallel_input_1d = env.parallel_input_1d + layer = Classifier1D(HIDDEN_SIZE, NUM_CLASSES, bias=True) + layer.to(dtype).to(device) + + layer_master = VanillaClassifier(HIDDEN_SIZE, NUM_CLASSES, bias=True) + layer_master = layer_master.to(dtype).to(device) + + W_master = layer_master.weight.data + dist.broadcast(W_master, src=0) + W = torch.chunk(W_master, DEPTH, dim=-1)[i] + layer.weight.data.copy_(W) + B_master = layer_master.bias.data + dist.broadcast(B_master, src=0) + B = B_master.clone() + layer.bias.data.copy_(B) + + A_shape = (BATCH_SIZE, SEQ_LENGTH, HIDDEN_SIZE) + A_master = torch.randn(A_shape, dtype=dtype, device=device) + dist.broadcast(A_master, src=0) + if parallel_input_1d: + A = torch.chunk(A_master, DEPTH, dim=-1)[i] + A = A.clone() + else: + A = A_master.clone() + A.requires_grad = True + + out = layer(A) + + A_master = A_master.clone() + A_master.requires_grad = True + C_master = layer_master(A_master) + C = C_master.clone() + + check_equal(out, C) + print_rank_0('classifier (no given weight) forward: pass') + + grad_shape = C_master.shape + grad_master = torch.randn(grad_shape, dtype=dtype, device=device) + dist.broadcast(grad_master, src=0) + grad = 
grad_master.clone() + out.backward(grad) + + grad_master = grad_master.clone() + C_master.backward(grad_master) + A_grad = A_master.grad + if parallel_input_1d: + A_grad = torch.chunk(A_grad, DEPTH, dim=-1)[i] + check_equal(A_grad, A.grad) + + W_grad = layer_master.weight.grad + W_grad = torch.chunk(W_grad, DEPTH, dim=-1)[i] + check_equal(W_grad, layer.weight.grad) + + B_grad = layer_master.bias.grad + check_equal(B_grad, layer.bias.grad) + + print_rank_0('classifier (no given weight) backward: pass') + + +def check_vocab_parallel_classifier_no_given_weight(): + device = get_current_device() + dtype = torch.float32 + + i = gpc.get_local_rank(ParallelMode.PARALLEL_1D) + + layer = VocabParallelClassifier1D(HIDDEN_SIZE, VOCAB_SIZE, bias=True) + layer.to(dtype).to(device) + + layer_master = VanillaClassifier(HIDDEN_SIZE, VOCAB_SIZE, bias=True) + layer_master = layer_master.to(dtype).to(device) + + W_master = layer_master.weight.data + dist.broadcast(W_master, src=0) + W = torch.chunk(W_master, DEPTH, dim=0)[i] + layer.weight.data.copy_(W) + B_master = layer_master.bias.data + dist.broadcast(B_master, src=0) + B = torch.chunk(B_master, DEPTH, dim=0)[i] + layer.bias.data.copy_(B) + + A_shape = (BATCH_SIZE, SEQ_LENGTH, HIDDEN_SIZE) + A_master = torch.randn(A_shape, dtype=dtype, device=device) + dist.broadcast(A_master, src=0) + A = A_master.clone() + A.requires_grad = True + + out = layer(A) + + A_master = A_master.clone() + A_master.requires_grad = True + C_master = layer_master(A_master) + C = torch.chunk(C_master, DEPTH, dim=-1)[i] + + check_equal(out, C) + print_rank_0('vocab parallel classifier (no given weight) forward: pass') + + grad_shape = C_master.shape + grad_master = torch.randn(grad_shape, dtype=dtype, device=device) + dist.broadcast(grad_master, src=0) + grad = torch.chunk(grad_master, DEPTH, dim=-1)[i] + grad = grad.clone() + out.backward(grad) + + grad_master = grad_master.clone() + C_master.backward(grad_master) + A_grad = A_master.grad + check_equal(A_grad, A.grad) + + W_grad = layer_master.weight.grad + W_grad = torch.chunk(W_grad, DEPTH, dim=0)[i] + check_equal(W_grad, layer.weight.grad) + + B_grad = layer_master.bias.grad + B_grad = torch.chunk(B_grad, DEPTH, dim=0)[i] + check_equal(B_grad, layer.bias.grad) + + print_rank_0('vocab parallel classifier (no given weight) backward: pass') + + +def check_classifier_given_embed_weight(): + device = get_current_device() + dtype = torch.float32 + + i = gpc.get_local_rank(ParallelMode.PARALLEL_1D) + + embed = Embedding1D(VOCAB_SIZE, HIDDEN_SIZE) + embed = embed.to(dtype).to(device) + embed_master = torch.nn.Embedding(VOCAB_SIZE, HIDDEN_SIZE) + embed_master = embed_master.to(dtype).to(device) + + weight_master = embed_master.weight.data + torch.distributed.broadcast(weight_master, src=0) + weight = torch.chunk(weight_master, DEPTH, dim=-1)[i] + embed.weight.data.copy_(weight) + + env.parallel_input_1d = False + layer = Classifier1D(HIDDEN_SIZE, NUM_CLASSES, weight=embed.weight, bias=False) + layer.to(dtype).to(device) + + layer_master = VanillaClassifier(HIDDEN_SIZE, NUM_CLASSES, weight=embed_master.weight, bias=False) + layer_master = layer_master.to(dtype).to(device) + + A_shape = (BATCH_SIZE, SEQ_LENGTH) + A_master = torch.randint(VOCAB_SIZE, A_shape, device=device) + torch.distributed.broadcast(A_master, src=0) + A = A_master.clone() + out = layer(embed(A)) + + A_master = A_master.clone() + C_master = layer_master(embed_master(A_master)) + C = C_master.clone() + check_equal(out, C) + print_rank_0('classifier (given embed weight) 
forward: pass') + + grad_shape = C_master.shape + grad_master = torch.randn(grad_shape, dtype=dtype, device=device) + dist.broadcast(grad_master, src=0) + grad = grad_master.clone() + out.backward(grad) + + grad_master = grad_master.clone() + C_master.backward(grad_master) + + W_grad = embed_master.weight.grad + W_grad = torch.chunk(W_grad, DEPTH, dim=-1)[i] + check_equal(W_grad, embed.weight.grad) + + print_rank_0('classifier (given embed weight) backward: pass') + + +def check_vocab_parallel_classifier_given_embed_weight(): + device = get_current_device() + dtype = torch.float32 + + i = gpc.get_local_rank(ParallelMode.PARALLEL_1D) + + embed = VocabParallelEmbedding1D(VOCAB_SIZE, HIDDEN_SIZE) + embed = embed.to(dtype).to(device) + embed_master = torch.nn.Embedding(VOCAB_SIZE, HIDDEN_SIZE) + embed_master = embed_master.to(dtype).to(device) + + weight_master = embed_master.weight.data + torch.distributed.broadcast(weight_master, src=0) + weight = torch.chunk(weight_master, DEPTH, dim=0)[i] + embed.weight.data.copy_(weight) + + env.parallel_input_1d = False + layer = VocabParallelClassifier1D(HIDDEN_SIZE, NUM_CLASSES, weight=embed.weight, bias=False) + layer.to(dtype).to(device) + + layer_master = VanillaClassifier(HIDDEN_SIZE, NUM_CLASSES, weight=embed_master.weight, bias=False) + layer_master = layer_master.to(dtype).to(device) + + A_shape = (BATCH_SIZE, SEQ_LENGTH) + A_master = torch.randint(VOCAB_SIZE, A_shape, device=device) + torch.distributed.broadcast(A_master, src=0) + A = A_master.clone() + out = layer(embed(A)) + + A_master = A_master.clone() + C_master = layer_master(embed_master(A_master)) + C = torch.chunk(C_master, DEPTH, dim=-1)[i] + check_equal(out, C) + print_rank_0('vocab parallel classifier (given embed weight) forward: pass') + + grad_shape = C_master.shape + grad_master = torch.randn(grad_shape, dtype=dtype, device=device) + dist.broadcast(grad_master, src=0) + grad = torch.chunk(grad_master, DEPTH, dim=-1)[i] + grad = grad.clone() + out.backward(grad) + + grad_master = grad_master.clone() + C_master.backward(grad_master) + + W_grad = embed_master.weight.grad + W_grad = torch.chunk(W_grad, DEPTH, dim=0)[i] + check_equal(W_grad, embed.weight.grad) + + print_rank_0('vocab parallel classifier (given embed weight) backward: pass') + + +def check_vocab_parallel_loss(): + device = get_current_device() + dtype = torch.float32 + + i = gpc.get_local_rank(ParallelMode.PARALLEL_1D) + + criterion = VocabParallelCrossEntropyLoss1D() + criterion_master = torch.nn.CrossEntropyLoss() + + out_shape = (BATCH_SIZE, SEQ_LENGTH, NUM_CLASSES) + out_master = torch.randn(out_shape, dtype=dtype, device=device) + target_master = torch.randint(NUM_CLASSES, (BATCH_SIZE, SEQ_LENGTH), dtype=torch.long, device=device) + torch.distributed.broadcast(out_master, src=0) + torch.distributed.broadcast(target_master, src=0) + out = torch.chunk(out_master, DEPTH, dim=-1)[i] + out = out.clone() + out.requires_grad = True + + loss = criterion(out, target_master) + + out_master = out_master.clone() + out_master.requires_grad = True + loss_master = criterion_master(out_master, target_master) + check_equal(loss, loss_master) + print_rank_0('vocab parallel loss forward: pass') + + loss.backward() + loss_master.backward() + + out_grad = out_master.grad + out_grad = torch.chunk(out_grad, DEPTH, dim=-1)[i] + check_equal(out_grad, out.grad) + print_rank_0('vocab parallel loss backward: pass') diff --git a/tests/test_layers/test_1d/checks_1d/common.py b/tests/test_layers/test_1d/checks_1d/common.py index 
4489d8233..8b7b28613 100644 --- a/tests/test_layers/test_1d/checks_1d/common.py +++ b/tests/test_layers/test_1d/checks_1d/common.py @@ -9,6 +9,7 @@ SEQ_LENGTH = 8 IMG_SIZE = 16 HIDDEN_SIZE = 8 NUM_CLASSES = 8 +VOCAB_SIZE = 16 def check_equal(A, B): assert torch.allclose(A, B, rtol=1e-3, atol=1e-1) == True diff --git a/tests/test_layers/test_1d/test_1d.py b/tests/test_layers/test_1d/test_1d.py index f4120dc53..58b914b90 100644 --- a/tests/test_layers/test_1d/test_1d.py +++ b/tests/test_layers/test_1d/test_1d.py @@ -7,6 +7,7 @@ import pytest import torch import torch.multiprocessing as mp from colossalai.core import global_context as gpc +from colossalai.logging import disable_existing_loggers from colossalai.initialize import launch from colossalai.utils import free_port @@ -24,6 +25,7 @@ CONFIG = dict( def check_layer(rank, world_size, port): + disable_existing_loggers() launch(config=CONFIG, rank=rank, world_size=world_size, @@ -33,6 +35,13 @@ def check_layer(rank, world_size, port): check_linear_col() check_linear_row() + check_embed() + check_vocab_parallel_embed() + check_classifier_no_given_weight() + check_vocab_parallel_classifier_no_given_weight() + check_classifier_given_embed_weight() + check_vocab_parallel_classifier_given_embed_weight() + check_vocab_parallel_loss() gpc.destroy() torch.cuda.empty_cache() diff --git a/tests/test_layers/test_2d/checks_2d/check_layer_2d.py b/tests/test_layers/test_2d/checks_2d/check_layer_2d.py index a300a196c..e030e473a 100644 --- a/tests/test_layers/test_2d/checks_2d/check_layer_2d.py +++ b/tests/test_layers/test_2d/checks_2d/check_layer_2d.py @@ -1,11 +1,12 @@ import torch -from torch.nn import Parameter - from colossalai.context.parallel_mode import ParallelMode from colossalai.core import global_context as gpc -from colossalai.nn import Linear2D, LayerNorm2D, Classifier2D +from colossalai.nn import (Classifier2D, CrossEntropyLoss2D, Embedding2D, LayerNorm2D, Linear2D, PatchEmbedding2D, + VanillaClassifier, VanillaPatchEmbedding, VocabParallelClassifier2D, + VocabParallelCrossEntropyLoss2D, VocabParallelEmbedding2D) from colossalai.utils import get_current_device, print_rank_0 -from .common import HIDDEN_SIZE, DEPTH, BATCH_SIZE, SEQ_LENGTH, check_equal, NUM_CLASSES + +from .common import (BATCH_SIZE, DEPTH, HIDDEN_SIZE, IMG_SIZE, NUM_CLASSES, SEQ_LENGTH, VOCAB_SIZE, check_equal) def check_linear(): @@ -57,7 +58,6 @@ def check_linear(): C = torch.chunk(C_master, DEPTH, dim=0)[i] C = torch.chunk(C, DEPTH, dim=-1)[j] - # print(f'Rank {gpc.get_global_rank()} A:\n{A}\nRank {gpc.get_global_rank()} W:\n{W}\nRank {gpc.get_global_rank()} b:\n{B}\nRank {gpc.get_global_rank()} C:\n{C}\nRank {gpc.get_global_rank()} out:\n{out}') check_equal(out, C) print_rank_0('linear forward: pass') @@ -90,84 +90,6 @@ def check_linear(): print_rank_0('linear backward: pass') -def check_classifier(): - device = get_current_device() - dtype = torch.float32 - INPUT_SIZE = HIDDEN_SIZE - OUTPUT_SIZE = NUM_CLASSES - - j = gpc.get_local_rank(ParallelMode.PARALLEL_2D_ROW) - i = gpc.get_local_rank(ParallelMode.PARALLEL_2D_COL) - - layer = Classifier2D(INPUT_SIZE, OUTPUT_SIZE) - - A_shape = (BATCH_SIZE, SEQ_LENGTH, INPUT_SIZE) - A_master = torch.randint(5, A_shape, dtype=dtype, device=device) - torch.distributed.broadcast(A_master, src=0) - A = torch.chunk(A_master, DEPTH, dim=0)[i] - A = torch.chunk(A, DEPTH, dim=-1)[j] - A = A.clone() - A.requires_grad = True - - W_shape = (OUTPUT_SIZE, INPUT_SIZE) - W_master = torch.randint(5, W_shape, dtype=dtype, device=device) - 
torch.distributed.broadcast(W_master, src=0) - W = torch.chunk(W_master, DEPTH, dim=-1)[j] - W = torch.chunk(W, DEPTH, dim=-1)[i] - W = W.clone() - layer.weight.data.copy_(W) - # W.requires_grad = True - - B_shape = (OUTPUT_SIZE,) - B_master = torch.randint(5, B_shape, dtype=dtype, device=device) - torch.distributed.broadcast(B_master, src=0) - # B = torch.chunk(B_master, DEPTH, dim=0)[j] - B = B_master.clone() - layer.bias.data.copy_(B) - - out = layer(A) - - A_master = A_master.clone() - A_master.requires_grad = True - W_master = W_master.clone() - W_master.requires_grad = True - B_master = B_master.clone() - B_master.requires_grad = True - C_master = torch.matmul(A_master, W_master.transpose(0, 1)) + B_master - C = torch.chunk(C_master, DEPTH, dim=0)[i] - # C = torch.chunk(C, DEPTH, dim=-1)[j] - - check_equal(out, C) - print_rank_0('classifier forward: pass') - - grad_shape = C_master.shape - grad_master = torch.randn(grad_shape, dtype=dtype, device=get_current_device()) - torch.distributed.broadcast(grad_master, src=0) - grad = torch.chunk(grad_master, DEPTH, dim=0)[i] - # grad = torch.chunk(grad, DEPTH, dim=-1)[j] - grad = grad.clone() - out.backward(grad) - - grad_master = grad_master.clone() - C_master.backward(grad_master) - A_grad = A_master.grad - A_grad = torch.chunk(A_grad, DEPTH, dim=0)[i] - A_grad = torch.chunk(A_grad, DEPTH, dim=-1)[j] - check_equal(A_grad, A.grad) - - W_grad = W_master.grad - W_grad = torch.chunk(W_grad, DEPTH, dim=-1)[j] - W_grad = torch.chunk(W_grad, DEPTH, dim=-1)[i] - check_equal(W_grad, layer.weight.grad) - - B_grad = B_master.grad - # B_grad = torch.chunk(B_grad, DEPTH, dim=0)[j] - # if i == 0: - check_equal(B_grad, layer.bias.grad) - - print_rank_0('classifier backward: pass') - - def check_layernorm(): device = get_current_device() dtype = torch.float32 @@ -219,6 +141,497 @@ def check_layernorm(): print_rank_0('layer norm backward: pass') +def check_embed(): + device = get_current_device() + dtype = torch.float32 + j = gpc.get_local_rank(ParallelMode.PARALLEL_2D_ROW) + i = gpc.get_local_rank(ParallelMode.PARALLEL_2D_COL) + + embed = Embedding2D(VOCAB_SIZE, HIDDEN_SIZE) + embed = embed.to(dtype).to(device) + embed_master = torch.nn.Embedding(VOCAB_SIZE, HIDDEN_SIZE) + embed_master = embed_master.to(dtype).to(device) + + weight_master = embed_master.weight.data + torch.distributed.broadcast(weight_master, src=0) + weight = torch.chunk(weight_master, DEPTH, dim=-1)[j] + weight = torch.chunk(weight, DEPTH, dim=-1)[i] + embed.weight.data.copy_(weight) + + A_shape = (BATCH_SIZE, SEQ_LENGTH) + A_master = torch.randint(VOCAB_SIZE, A_shape, device=device) + torch.distributed.broadcast(A_master, src=0) + A = A_master.clone() + out = embed(A) + + A_master = A_master.clone() + C_master = embed_master(A_master) + C = torch.chunk(C_master, DEPTH, dim=0)[i] + C = torch.chunk(C, DEPTH, dim=-1)[j] + check_equal(out, C) + print_rank_0('embed forward: pass') + + grad_shape = C_master.shape + grad_master = torch.randn(grad_shape, dtype=dtype, device=device) + torch.distributed.broadcast(grad_master, src=0) + grad = torch.chunk(grad_master, DEPTH, dim=0)[i] + grad = torch.chunk(grad, DEPTH, dim=-1)[j] + grad = grad.clone() + out.backward(grad) + grad_master = grad_master.clone() + C_master.backward(grad_master) + + B_grad = embed_master.weight.grad + B_grad = torch.chunk(B_grad, DEPTH, dim=-1)[j] + B_grad = torch.chunk(B_grad, DEPTH, dim=-1)[i] + check_equal(B_grad, embed.weight.grad) + print_rank_0('embed backward: pass') + + +def check_patch_embed(): + device = 
get_current_device() + dtype = torch.float32 + j = gpc.get_local_rank(ParallelMode.PARALLEL_2D_ROW) + i = gpc.get_local_rank(ParallelMode.PARALLEL_2D_COL) + + layer = PatchEmbedding2D(IMG_SIZE, 4, 3, HIDDEN_SIZE, dtype=dtype) + torch.nn.init.ones_(layer.cls_token) + torch.nn.init.ones_(layer.pos_embed) + layer = layer.to(device) + + layer_master = VanillaPatchEmbedding(IMG_SIZE, 4, 3, HIDDEN_SIZE, dtype=dtype) + torch.nn.init.ones_(layer_master.cls_token) + torch.nn.init.ones_(layer_master.pos_embed) + layer_master = layer_master.to(device) + + proj_weight_master = layer_master.weight.data + torch.distributed.broadcast(proj_weight_master, src=0) + proj_weight = torch.chunk(proj_weight_master, DEPTH, dim=0)[j] + proj_weight = torch.chunk(proj_weight, DEPTH, dim=0)[i] + layer.weight.data.copy_(proj_weight) + proj_bias_master = layer_master.bias.data + torch.distributed.broadcast(proj_bias_master, src=0) + proj_bias = torch.chunk(proj_bias_master, DEPTH, dim=0)[j] + proj_bias = torch.chunk(proj_bias, DEPTH, dim=0)[i] + layer.bias.data.copy_(proj_bias) + + A_shape = (BATCH_SIZE, 3, IMG_SIZE, IMG_SIZE) + A_master = torch.randn(A_shape, dtype=dtype, device=device) + torch.distributed.broadcast(A_master, src=0) + A = A_master.clone() + out = layer(A) + + A_master = A_master.clone() + C_master = layer_master(A_master) + C = torch.chunk(C_master, DEPTH, dim=0)[i] + C = torch.chunk(C, DEPTH, dim=-1)[j] + check_equal(out, C) + print_rank_0('patch embed forward: pass') + + grad_shape = C_master.shape + grad_master = torch.randn(grad_shape, dtype=dtype, device=device) + torch.distributed.broadcast(grad_master, src=0) + grad = torch.chunk(grad_master, DEPTH, dim=0)[i] + grad = torch.chunk(grad, DEPTH, dim=-1)[j] + grad = grad.clone() + out.backward(grad) + + grad_master = grad_master.clone() + C_master.backward(grad_master) + + cls_grad_master = layer_master.cls_token.grad + cls_grad = torch.chunk(cls_grad_master, DEPTH, dim=-1)[j] + cls_grad = torch.chunk(cls_grad, DEPTH, dim=-1)[i] + check_equal(cls_grad, layer.cls_token.grad) + + pos_grad_master = layer_master.pos_embed.grad + pos_grad = torch.chunk(pos_grad_master, DEPTH, dim=-1)[j] + pos_grad = torch.chunk(pos_grad, DEPTH, dim=-1)[i] + check_equal(pos_grad, layer.pos_embed.grad) + + B_grad = layer_master.weight.grad + B_grad = torch.chunk(B_grad, DEPTH, dim=0)[j] + B_grad = torch.chunk(B_grad, DEPTH, dim=0)[i] + check_equal(B_grad, layer.weight.grad) + + bias_grad = layer_master.bias.grad + bias_grad = torch.chunk(bias_grad, DEPTH)[j] + bias_grad = torch.chunk(bias_grad, DEPTH)[i] + check_equal(bias_grad, layer.bias.grad) + print_rank_0('patch embed backward: pass') + + +def check_vocab_parallel_embed(): + device = get_current_device() + dtype = torch.float32 + j = gpc.get_local_rank(ParallelMode.PARALLEL_2D_ROW) + i = gpc.get_local_rank(ParallelMode.PARALLEL_2D_COL) + + embed = VocabParallelEmbedding2D(VOCAB_SIZE, HIDDEN_SIZE) + embed = embed.to(dtype).to(device) + embed_master = torch.nn.Embedding(VOCAB_SIZE, HIDDEN_SIZE) + embed_master = embed_master.to(dtype).to(device) + + weight_master = embed_master.weight.data + torch.distributed.broadcast(weight_master, src=0) + weight = torch.chunk(weight_master, DEPTH, dim=-1)[j] + weight = torch.chunk(weight, DEPTH, dim=0)[i] + embed.weight.data.copy_(weight) + + A_shape = (BATCH_SIZE, SEQ_LENGTH) + A_master = torch.randint(VOCAB_SIZE, A_shape, device=device) + torch.distributed.broadcast(A_master, src=0) + A = A_master.clone() + out = embed(A) + + A_master = A_master.clone() + C_master = 
embed_master(A_master) + C = torch.chunk(C_master, DEPTH, dim=0)[i] + C = torch.chunk(C, DEPTH, dim=-1)[j] + check_equal(out, C) + print_rank_0('vocab parallel embed forward: pass') + + grad_shape = C_master.shape + grad_master = torch.randn(grad_shape, dtype=dtype, device=device) + torch.distributed.broadcast(grad_master, src=0) + grad = torch.chunk(grad_master, DEPTH, dim=0)[i] + grad = torch.chunk(grad, DEPTH, dim=-1)[j] + grad = grad.clone() + out.backward(grad) + grad_master = grad_master.clone() + C_master.backward(grad_master) + + B_grad = embed_master.weight.grad + B_grad = torch.chunk(B_grad, DEPTH, dim=-1)[j] + B_grad = torch.chunk(B_grad, DEPTH, dim=0)[i] + check_equal(B_grad, embed.weight.grad) + print_rank_0('vocab parallel embed backward: pass') + + +def check_classifier_no_given_weight(): + device = get_current_device() + dtype = torch.float32 + INPUT_SIZE = HIDDEN_SIZE + OUTPUT_SIZE = NUM_CLASSES + + j = gpc.get_local_rank(ParallelMode.PARALLEL_2D_ROW) + i = gpc.get_local_rank(ParallelMode.PARALLEL_2D_COL) + + layer = Classifier2D(INPUT_SIZE, OUTPUT_SIZE) + + A_shape = (BATCH_SIZE, SEQ_LENGTH, INPUT_SIZE) + A_master = torch.randint(5, A_shape, dtype=dtype, device=device) + torch.distributed.broadcast(A_master, src=0) + A = torch.chunk(A_master, DEPTH, dim=0)[i] + A = torch.chunk(A, DEPTH, dim=-1)[j] + A = A.clone() + A.requires_grad = True + + W_shape = (OUTPUT_SIZE, INPUT_SIZE) + W_master = torch.randint(5, W_shape, dtype=dtype, device=device) + torch.distributed.broadcast(W_master, src=0) + W = torch.chunk(W_master, DEPTH, dim=-1)[j] + W = torch.chunk(W, DEPTH, dim=-1)[i] + W = W.clone() + layer.weight.data.copy_(W) + # W.requires_grad = True + + B_shape = (OUTPUT_SIZE, ) + B_master = torch.randint(5, B_shape, dtype=dtype, device=device) + torch.distributed.broadcast(B_master, src=0) + # B = torch.chunk(B_master, DEPTH, dim=0)[j] + B = B_master.clone() + layer.bias.data.copy_(B) + + out = layer(A) + + A_master = A_master.clone() + A_master.requires_grad = True + W_master = W_master.clone() + W_master.requires_grad = True + B_master = B_master.clone() + B_master.requires_grad = True + C_master = torch.matmul(A_master, W_master.transpose(0, 1)) + B_master + C = torch.chunk(C_master, DEPTH, dim=0)[i] + # C = torch.chunk(C, DEPTH, dim=-1)[j] + + check_equal(out, C) + print_rank_0('classifier (no given weight) forward: pass') + + grad_shape = C_master.shape + grad_master = torch.randn(grad_shape, dtype=dtype, device=get_current_device()) + torch.distributed.broadcast(grad_master, src=0) + grad = torch.chunk(grad_master, DEPTH, dim=0)[i] + # grad = torch.chunk(grad, DEPTH, dim=-1)[j] + grad = grad.clone() + out.backward(grad) + + grad_master = grad_master.clone() + C_master.backward(grad_master) + A_grad = A_master.grad + A_grad = torch.chunk(A_grad, DEPTH, dim=0)[i] + A_grad = torch.chunk(A_grad, DEPTH, dim=-1)[j] + check_equal(A_grad, A.grad) + + W_grad = W_master.grad + W_grad = torch.chunk(W_grad, DEPTH, dim=-1)[j] + W_grad = torch.chunk(W_grad, DEPTH, dim=-1)[i] + check_equal(W_grad, layer.weight.grad) + + B_grad = B_master.grad + # B_grad = torch.chunk(B_grad, DEPTH, dim=0)[j] + # if i == 0: + check_equal(B_grad, layer.bias.grad) + + print_rank_0('classifier (no given weight) backward: pass') + + +def check_vocab_parallel_classifier_no_given_weight(): + device = get_current_device() + dtype = torch.float32 + + j = gpc.get_local_rank(ParallelMode.PARALLEL_2D_ROW) + i = gpc.get_local_rank(ParallelMode.PARALLEL_2D_COL) + + layer = VocabParallelClassifier2D(HIDDEN_SIZE, 
VOCAB_SIZE, bias=True) + layer = layer.to(dtype).to(device) + + layer_master = VanillaClassifier(HIDDEN_SIZE, VOCAB_SIZE, bias=True) + layer_master = layer_master.to(dtype).to(device) + + weight_master = layer_master.weight.data + torch.distributed.broadcast(weight_master, src=0) + weight = torch.chunk(weight_master, DEPTH, dim=0)[i] + weight = torch.chunk(weight, DEPTH, dim=-1)[j] + layer.weight.data.copy_(weight) + bias_master = layer_master.bias.data + torch.distributed.broadcast(bias_master, src=0) + bias = torch.chunk(bias_master, DEPTH)[j] + bias = torch.chunk(bias, DEPTH)[i] + layer.bias.data.copy_(bias) + + A_shape = (BATCH_SIZE, SEQ_LENGTH, HIDDEN_SIZE) + A_master = torch.randn(A_shape, dtype=dtype, device=device) + torch.distributed.broadcast(A_master, src=0) + A = torch.chunk(A_master, DEPTH, dim=0)[i] + A = torch.chunk(A, DEPTH, dim=-1)[j] + A = A.clone() + A.requires_grad = True + out = layer(A) + + A_master = A_master.clone() + A_master.requires_grad = True + C_master = layer_master(A_master) + C = torch.chunk(C_master, DEPTH, dim=0)[i] + C = torch.chunk(C, DEPTH, dim=-1)[j] + check_equal(out, C) + print_rank_0('vocab parallel classifier (no given weight) forward: pass') + + grad_shape = C_master.shape + grad_master = torch.randn(grad_shape, dtype=dtype, device=device) + torch.distributed.broadcast(grad_master, src=0) + grad = torch.chunk(grad_master, DEPTH, dim=0)[i] + grad = torch.chunk(grad, DEPTH, dim=-1)[j] + grad = grad.clone() + out.backward(grad) + + grad_master = grad_master.clone() + C_master.backward(grad_master) + A_grad = A_master.grad + A_grad = torch.chunk(A_grad, DEPTH, dim=0)[i] + A_grad = torch.chunk(A_grad, DEPTH, dim=-1)[j] + check_equal(A_grad, A.grad) + + W_grad = layer_master.weight.grad + W_grad = torch.chunk(W_grad, DEPTH, dim=0)[i] + W_grad = torch.chunk(W_grad, DEPTH, dim=-1)[j] + check_equal(W_grad, layer.weight.grad) + + B_grad = layer_master.bias.grad + B_grad = torch.chunk(B_grad, DEPTH)[j] + B_grad = torch.chunk(B_grad, DEPTH)[i] + check_equal(B_grad, layer.bias.grad) + print_rank_0('vocab parallel classifier (no given weight) backward: pass') + + +def check_classifier_given_embed_weight(): + device = get_current_device() + dtype = torch.float32 + + j = gpc.get_local_rank(ParallelMode.PARALLEL_2D_ROW) + i = gpc.get_local_rank(ParallelMode.PARALLEL_2D_COL) + + embed = Embedding2D(VOCAB_SIZE, HIDDEN_SIZE) + embed = embed.to(dtype).to(device) + embed_master = torch.nn.Embedding(VOCAB_SIZE, HIDDEN_SIZE) + embed_master = embed_master.to(dtype).to(device) + + weight_master = embed_master.weight.data + torch.distributed.broadcast(weight_master, src=0) + weight = torch.chunk(weight_master, DEPTH, dim=-1)[j] + weight = torch.chunk(weight, DEPTH, dim=-1)[i] + embed.weight.data.copy_(weight) + + layer = Classifier2D(HIDDEN_SIZE, VOCAB_SIZE, weight=embed.weight, bias=False) + layer = layer.to(dtype).to(device) + layer_master = VanillaClassifier(HIDDEN_SIZE, VOCAB_SIZE, weight=embed_master.weight, bias=False) + layer_master = layer_master.to(dtype).to(device) + + A_shape = (BATCH_SIZE, SEQ_LENGTH) + A_master = torch.randint(VOCAB_SIZE, A_shape, device=device) + torch.distributed.broadcast(A_master, src=0) + A = A_master.clone() + out = layer(embed(A)) + + A_master = A_master.clone() + C_master = layer_master(embed_master(A_master)) + C = torch.chunk(C_master, DEPTH, dim=0)[i] + check_equal(out, C) + print_rank_0('classifier (given embed weight) forward: pass') + + grad_shape = C_master.shape + grad_master = torch.randn(grad_shape, dtype=dtype, 
device=device) + torch.distributed.broadcast(grad_master, src=0) + grad = torch.chunk(grad_master, DEPTH, dim=0)[i] + grad = grad.clone() + out.backward(grad) + + grad_master = grad_master.clone() + C_master.backward(grad_master) + + W_grad = embed_master.weight.grad + W_grad = torch.chunk(W_grad, DEPTH, dim=-1)[j] + W_grad = torch.chunk(W_grad, DEPTH, dim=-1)[i] + check_equal(W_grad, embed.weight.grad) + print_rank_0('classifier (given embed weight) backward: pass') + + +def check_vocab_parallel_classifier_given_embed_weight(): + device = get_current_device() + dtype = torch.float32 + + j = gpc.get_local_rank(ParallelMode.PARALLEL_2D_ROW) + i = gpc.get_local_rank(ParallelMode.PARALLEL_2D_COL) + + embed = VocabParallelEmbedding2D(VOCAB_SIZE, HIDDEN_SIZE) + embed = embed.to(dtype).to(device) + embed_master = torch.nn.Embedding(VOCAB_SIZE, HIDDEN_SIZE) + embed_master = embed_master.to(dtype).to(device) + + weight_master = embed_master.weight.data + torch.distributed.broadcast(weight_master, src=0) + weight = torch.chunk(weight_master, DEPTH, dim=-1)[j] + weight = torch.chunk(weight, DEPTH, dim=0)[i] + embed.weight.data.copy_(weight) + + layer = VocabParallelClassifier2D(HIDDEN_SIZE, VOCAB_SIZE, weight=embed.weight, bias=False) + layer = layer.to(dtype).to(device) + layer_master = VanillaClassifier(HIDDEN_SIZE, VOCAB_SIZE, weight=embed_master.weight, bias=False) + layer_master = layer_master.to(dtype).to(device) + + A_shape = (BATCH_SIZE, SEQ_LENGTH) + A_master = torch.randint(VOCAB_SIZE, A_shape, device=device) + torch.distributed.broadcast(A_master, src=0) + A = A_master.clone() + out = layer(embed(A)) + + A_master = A_master.clone() + C_master = layer_master(embed_master(A_master)) + C = torch.chunk(C_master, DEPTH, dim=0)[i] + C = torch.chunk(C, DEPTH, dim=-1)[j] + check_equal(out, C) + print_rank_0('vocab parallel classifier (given embed weight) forward: pass') + + grad_shape = C_master.shape + grad_master = torch.randn(grad_shape, dtype=dtype, device=device) + torch.distributed.broadcast(grad_master, src=0) + grad = torch.chunk(grad_master, DEPTH, dim=0)[i] + grad = torch.chunk(grad, DEPTH, dim=-1)[j] + grad = grad.clone() + out.backward(grad) + + grad_master = grad_master.clone() + C_master.backward(grad_master) + + W_grad = embed_master.weight.grad + W_grad = torch.chunk(W_grad, DEPTH, dim=-1)[j] + W_grad = torch.chunk(W_grad, DEPTH, dim=0)[i] + check_equal(W_grad, embed.weight.grad) + print_rank_0('vocab parallel classifier (given embed weight) backward: pass') + + +def check_loss(): + device = get_current_device() + dtype = torch.float32 + + j = gpc.get_local_rank(ParallelMode.PARALLEL_2D_ROW) + i = gpc.get_local_rank(ParallelMode.PARALLEL_2D_COL) + + criterion = CrossEntropyLoss2D() + criterion_master = torch.nn.CrossEntropyLoss() + + out_shape = (BATCH_SIZE, NUM_CLASSES) + out_master = torch.randn(out_shape, dtype=dtype, device=device) + target_master = torch.randint(NUM_CLASSES, (BATCH_SIZE, ), dtype=torch.long, device=device) + torch.distributed.broadcast(out_master, src=0) + torch.distributed.broadcast(target_master, src=0) + out = torch.chunk(out_master, DEPTH, dim=0)[i] + out = out.clone() + out.requires_grad = True + loss = criterion(out, target_master) + + out_master = out_master.clone() + out_master.requires_grad = True + loss_master = criterion_master(out_master, target_master) + check_equal(loss, loss_master) + print_rank_0('cross entropy loss forward: pass') + + loss.backward() + loss_master.backward() + + out_grad = out_master.grad + out_grad = torch.chunk(out_grad, 
DEPTH, dim=0)[i] + check_equal(out_grad, out.grad) + print_rank_0('cross entropy loss backward: pass') + + +def check_vocab_parallel_loss(): + device = get_current_device() + dtype = torch.float32 + + j = gpc.get_local_rank(ParallelMode.PARALLEL_2D_ROW) + i = gpc.get_local_rank(ParallelMode.PARALLEL_2D_COL) + + criterion = VocabParallelCrossEntropyLoss2D() + criterion_master = torch.nn.CrossEntropyLoss() + + out_shape = (BATCH_SIZE, NUM_CLASSES) + out_master = torch.randn(out_shape, dtype=dtype, device=device) + target_master = torch.randint(NUM_CLASSES, (BATCH_SIZE, ), dtype=torch.long, device=device) + torch.distributed.broadcast(out_master, src=0) + torch.distributed.broadcast(target_master, src=0) + out = torch.chunk(out_master, DEPTH, dim=0)[i] + out = torch.chunk(out, DEPTH, dim=-1)[j] + out = out.clone() + out.requires_grad = True + loss = criterion(out, target_master) + + out_master = out_master.clone() + out_master.requires_grad = True + loss_master = criterion_master(out_master, target_master) + check_equal(loss, loss_master) + print_rank_0('vocab parallel cross entropy loss forward: pass') + + loss.backward() + loss_master.backward() + + out_grad = out_master.grad + out_grad = torch.chunk(out_grad, DEPTH, dim=0)[i] + out_grad = torch.chunk(out_grad, DEPTH, dim=-1)[j] + check_equal(out_grad, out.grad) + print_rank_0('vocab parallel cross entropy loss backward: pass') + + # def check_attention(): # device = get_current_device() # dtype = torch.float32 @@ -257,7 +670,6 @@ def check_layernorm(): # assert A.grad.shape == A.shape # print_rank_0('self attention backward: pass') - # def check_mlp(): # device = get_current_device() # dtype = torch.float32 @@ -291,7 +703,6 @@ def check_layernorm(): # assert A.grad.shape == A.shape # print_rank_0('mlp backward: pass') - # def check_transformerlayer(): # device = get_current_device() # dtype = torch.float32 diff --git a/tests/test_layers/test_2d/checks_2d/common.py b/tests/test_layers/test_2d/checks_2d/common.py index 9eb7f7454..8c855c18b 100644 --- a/tests/test_layers/test_2d/checks_2d/common.py +++ b/tests/test_layers/test_2d/checks_2d/common.py @@ -8,6 +8,9 @@ BATCH_SIZE = 8 SEQ_LENGTH = 8 HIDDEN_SIZE = 8 NUM_CLASSES = 8 +VOCAB_SIZE = 16 +IMG_SIZE = 16 + def check_equal(A, B): - assert torch.allclose(A, B, rtol=1e-3, atol=1e-2) == True + assert torch.allclose(A, B, rtol=1e-3, atol=1e-2) diff --git a/tests/test_layers/test_2d/test_2d.py b/tests/test_layers/test_2d/test_2d.py index 83dc80a95..540151010 100644 --- a/tests/test_layers/test_2d/test_2d.py +++ b/tests/test_layers/test_2d/test_2d.py @@ -8,20 +8,17 @@ import torch import torch.multiprocessing as mp from colossalai.core import global_context as gpc from colossalai.initialize import launch +from colossalai.logging import disable_existing_loggers from colossalai.utils import free_port -from checks_2d.check_layer_2d import * -from checks_2d.check_operation_2d import * +from checks_2d.check_layer_2d import (check_classifier_given_embed_weight, check_classifier_no_given_weight, + check_embed, check_layernorm, check_linear, check_loss, check_patch_embed, + check_vocab_parallel_classifier_given_embed_weight, + check_vocab_parallel_classifier_no_given_weight, check_vocab_parallel_embed, + check_vocab_parallel_loss) +from checks_2d.check_operation_2d import check_AB, check_ABT, check_ATB -CONFIG = dict( - parallel=dict( - pipeline=dict(size=1), - tensor=dict( - size=4, - mode='2d' - ) - ), -) +CONFIG = dict(parallel=dict(pipeline=dict(size=1), tensor=dict(size=4, mode='2d')), ) def 
check_operations(): @@ -33,16 +30,24 @@ def check_operations(): def check_layer(): check_linear() check_layernorm() - check_classifier() + check_embed() + check_patch_embed() + check_vocab_parallel_embed() + check_classifier_no_given_weight() + check_vocab_parallel_classifier_no_given_weight() + check_classifier_given_embed_weight() + check_vocab_parallel_classifier_given_embed_weight() + check_loss() + check_vocab_parallel_loss() + def check_layer_and_operation(rank, world_size, port): - launch(config=CONFIG, - rank=rank, - world_size=world_size, - host='localhost', - port=port, - backend='nccl') + disable_existing_loggers() + launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + torch.backends.cuda.matmul.allow_tf32 = False + torch.backends.cudnn.allow_tf32 = False + torch.backends.cudnn.deterministic = True # check_operations() check_layer() gpc.destroy() diff --git a/tests/test_layers/test_2p5d/checks_2p5d/check_layer_2p5d.py b/tests/test_layers/test_2p5d/checks_2p5d/check_layer_2p5d.py index 256d8dc59..a8f551093 100644 --- a/tests/test_layers/test_2p5d/checks_2p5d/check_layer_2p5d.py +++ b/tests/test_layers/test_2p5d/checks_2p5d/check_layer_2p5d.py @@ -1,11 +1,12 @@ import torch -from torch.nn import Parameter - from colossalai.context.parallel_mode import ParallelMode from colossalai.core import global_context as gpc -from colossalai.nn import Linear2p5D, LayerNorm2p5D, Classifier2p5D -from colossalai.utils import get_current_device -from colossalai.utils import print_rank_0 +from colossalai.nn import (Classifier2p5D, CrossEntropyLoss2p5D, Embedding2p5D, LayerNorm2p5D, Linear2p5D, + PatchEmbedding2p5D, VanillaClassifier, VanillaPatchEmbedding, VocabParallelClassifier2p5D, + VocabParallelCrossEntropyLoss2p5D, VocabParallelEmbedding2p5D) +from colossalai.utils import get_current_device, print_rank_0 +from torch.nn import Parameter + from .common import * @@ -19,11 +20,7 @@ def check_linear(): j = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_ROW) k = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_DEP) - layer = Linear2p5D( - INPUT_SIZE, - OUTPUT_SIZE, - dtype=dtype, - skip_bias_add=False) + layer = Linear2p5D(INPUT_SIZE, OUTPUT_SIZE, dtype=dtype, skip_bias_add=False) A_shape = (BATCH_SIZE, SEQ_LENGTH, INPUT_SIZE) A_master = torch.randn(A_shape, dtype=dtype, device=device) @@ -94,86 +91,6 @@ def check_linear(): print_rank_0('linear backward: pass') -def check_classifier(): - device = get_current_device() - dtype = torch.float32 - INPUT_SIZE = HIDDEN_SIZE - OUTPUT_SIZE = NUM_CLASSES - - j = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_ROW) - i = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_COL) - - layer = Classifier2p5D(INPUT_SIZE, OUTPUT_SIZE) - - A_shape = (BATCH_SIZE, SEQ_LENGTH, INPUT_SIZE) - A_master = torch.randint(5, A_shape, dtype=dtype, device=device) - torch.distributed.broadcast(A_master, src=0) - A = torch.chunk(A_master, TESSERACT_DIM, dim=0)[i] - A = torch.chunk(A, TESSERACT_DIM, dim=-1)[j] - A = A.clone() - A.requires_grad = True - - W_shape = (OUTPUT_SIZE, INPUT_SIZE) - W_master = torch.randint(5, W_shape, dtype=dtype, device=device) - torch.distributed.broadcast(W_master, src=0) - # W = torch.chunk(W_master, TESSERACT_DIM, dim=-1)[j] - W = torch.chunk(W_master, TESSERACT_DIM, dim=-1)[j] - W = torch.chunk(W, TESSERACT_DIM, dim=-1)[i] - W = W.clone() - layer.weight.data.copy_(W) - # W.requires_grad = True - - B_shape = (OUTPUT_SIZE,) - B_master = torch.randint(5, B_shape, dtype=dtype, device=device) - 
torch.distributed.broadcast(B_master, src=0) - # B = torch.chunk(B_master, TESSERACT_DIM, dim=0)[j] - B = B_master.clone() - layer.bias.data.copy_(B) - - - out = layer(A) - - A_master = A_master.clone() - A_master.requires_grad = True - W_master = W_master.clone() - W_master.requires_grad = True - B_master = B_master.clone() - B_master.requires_grad = True - C_master = torch.matmul(A_master, W_master.transpose(0, 1)) + B_master - C = torch.chunk(C_master, TESSERACT_DIM, dim=0)[i] - # C = torch.chunk(C, TESSERACT_DIM, dim=-1)[j] - - check_equal(out, C) - print_rank_0('classifier forward: pass') - - grad_shape = C_master.shape - grad_master = torch.randn(grad_shape, dtype=dtype, device=get_current_device()) - torch.distributed.broadcast(grad_master, src=0) - grad = torch.chunk(grad_master, TESSERACT_DIM, dim=0)[i] - # grad = torch.chunk(grad, TESSERACT_DIM, dim=-1)[j] - grad = grad.clone() - out.backward(grad) - - grad_master = grad_master.clone() - C_master.backward(grad_master) - A_grad = A_master.grad - A_grad = torch.chunk(A_grad, TESSERACT_DIM, dim=0)[i] - A_grad = torch.chunk(A_grad, TESSERACT_DIM, dim=-1)[j] - check_equal(A_grad, A.grad) - - W_grad = W_master.grad - W_grad = torch.chunk(W_grad, TESSERACT_DIM, dim=-1)[j] - W_grad = torch.chunk(W_grad, TESSERACT_DIM, dim=-1)[i] - check_equal(W_grad, layer.weight.grad) - - B_grad = B_master.grad - # B_grad = torch.chunk(B_grad, TESSERACT_DIM, dim=0)[j] - # if i == 0: - check_equal(B_grad, layer.bias.grad) - - print_rank_0('classifier backward: pass') - - def check_layernorm(): device = get_current_device() dtype = torch.float32 @@ -184,9 +101,7 @@ def check_layernorm(): j = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_ROW) k = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_DEP) - layernorm = LayerNorm2p5D( - INPUT_SIZE, - dtype=dtype) + layernorm = LayerNorm2p5D(INPUT_SIZE, dtype=dtype) A_shape = (BATCH_SIZE, SEQ_LENGTH, INPUT_SIZE) A_master = torch.randn(A_shape, dtype=dtype, device=device) @@ -228,6 +143,500 @@ def check_layernorm(): print_rank_0('layer norm backward: pass') +def check_embed(): + device = get_current_device() + dtype = torch.float32 + i = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_COL) + j = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_ROW) + k = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_DEP) + + embed = Embedding2p5D(VOCAB_SIZE, HIDDEN_SIZE) + embed = embed.to(dtype).to(device) + embed_master = torch.nn.Embedding(VOCAB_SIZE, HIDDEN_SIZE) + embed_master = embed_master.to(dtype).to(device) + + weight_master = embed_master.weight.data + torch.distributed.broadcast(weight_master, src=0) + weight = torch.chunk(weight_master, TESSERACT_DIM, dim=-1)[j] + weight = torch.chunk(weight, TESSERACT_DIM, dim=-1)[i] + embed.weight.data.copy_(weight) + + A_shape = (BATCH_SIZE, SEQ_LENGTH) + A_master = torch.randint(VOCAB_SIZE, A_shape, device=device) + torch.distributed.broadcast(A_master, src=0) + A = A_master.clone() + out = embed(A) + + A_master = A_master.clone() + C_master = embed_master(A_master) + C = torch.chunk(C_master, TESSERACT_DIM, dim=0)[i] + C = torch.chunk(C, TESSERACT_DIM, dim=-1)[j] + check_equal(out, C) + print_rank_0('embed forward: pass') + + grad_shape = C_master.shape + grad_master = torch.randn(grad_shape, dtype=dtype, device=device) + torch.distributed.broadcast(grad_master, src=0) + grad = torch.chunk(grad_master, TESSERACT_DIM, dim=0)[i] + grad = torch.chunk(grad, TESSERACT_DIM, dim=-1)[j] + grad = grad.clone() + out.backward(grad) + grad_master = grad_master.clone() + C_master.backward(grad_master) 
+ + B_grad = embed_master.weight.grad + B_grad = torch.chunk(B_grad, TESSERACT_DIM, dim=-1)[j] + B_grad = torch.chunk(B_grad, TESSERACT_DIM, dim=-1)[i] + check_equal(B_grad, embed.weight.grad) + print_rank_0('embed backward: pass') + + +def check_patch_embed(): + device = get_current_device() + dtype = torch.float32 + i = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_COL) + j = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_ROW) + k = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_DEP) + + layer = PatchEmbedding2p5D(IMG_SIZE, 4, 3, HIDDEN_SIZE, dtype=dtype) + torch.nn.init.ones_(layer.cls_token) + torch.nn.init.ones_(layer.pos_embed) + layer = layer.to(device) + + layer_master = VanillaPatchEmbedding(IMG_SIZE, 4, 3, HIDDEN_SIZE, dtype=dtype) + torch.nn.init.ones_(layer_master.cls_token) + torch.nn.init.ones_(layer_master.pos_embed) + layer_master = layer_master.to(device) + + proj_weight_master = layer_master.weight.data + torch.distributed.broadcast(proj_weight_master, src=0) + proj_weight = torch.chunk(proj_weight_master, TESSERACT_DIM, dim=0)[j] + proj_weight = torch.chunk(proj_weight, TESSERACT_DIM, dim=0)[i] + layer.weight.data.copy_(proj_weight) + proj_bias_master = layer_master.bias.data + torch.distributed.broadcast(proj_bias_master, src=0) + proj_bias = torch.chunk(proj_bias_master, TESSERACT_DIM, dim=0)[j] + proj_bias = torch.chunk(proj_bias, TESSERACT_DIM, dim=0)[i] + layer.bias.data.copy_(proj_bias) + + A_shape = (BATCH_SIZE, 3, IMG_SIZE, IMG_SIZE) + A_master = torch.randn(A_shape, dtype=dtype, device=device) + torch.distributed.broadcast(A_master, src=0) + A = A_master.clone() + out = layer(A) + + A_master = A_master.clone() + C_master = layer_master(A_master) + C = torch.chunk(C_master, TESSERACT_DIM, dim=0)[i] + C = torch.chunk(C, TESSERACT_DIM, dim=-1)[j] + check_equal(out, C) + print_rank_0('patch embed forward: pass') + + grad_shape = C_master.shape + grad_master = torch.randn(grad_shape, dtype=dtype, device=device) + torch.distributed.broadcast(grad_master, src=0) + grad = torch.chunk(grad_master, TESSERACT_DIM, dim=0)[i] + grad = torch.chunk(grad, TESSERACT_DIM, dim=-1)[j] + grad = grad.clone() + out.backward(grad) + + grad_master = grad_master.clone() + C_master.backward(grad_master) + + cls_grad_master = layer_master.cls_token.grad + cls_grad = torch.chunk(cls_grad_master, TESSERACT_DIM, dim=-1)[j] + cls_grad = torch.chunk(cls_grad, TESSERACT_DIM, dim=-1)[i] + check_equal(cls_grad, layer.cls_token.grad) + + pos_grad_master = layer_master.pos_embed.grad + pos_grad = torch.chunk(pos_grad_master, TESSERACT_DIM, dim=-1)[j] + pos_grad = torch.chunk(pos_grad, TESSERACT_DIM, dim=-1)[i] + check_equal(pos_grad, layer.pos_embed.grad) + + B_grad = layer_master.weight.grad + B_grad = torch.chunk(B_grad, TESSERACT_DIM, dim=0)[j] + B_grad = torch.chunk(B_grad, TESSERACT_DIM, dim=0)[i] + check_equal(B_grad, layer.weight.grad) + + bias_grad = layer_master.bias.grad + bias_grad = torch.chunk(bias_grad, TESSERACT_DIM)[j] + bias_grad = torch.chunk(bias_grad, TESSERACT_DIM)[i] + check_equal(bias_grad, layer.bias.grad) + print_rank_0('patch embed backward: pass') + + +def check_vocab_parallel_embed(): + device = get_current_device() + dtype = torch.float32 + i = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_COL) + j = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_ROW) + k = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_DEP) + + embed = VocabParallelEmbedding2p5D(VOCAB_SIZE, HIDDEN_SIZE) + embed = embed.to(dtype).to(device) + embed_master = torch.nn.Embedding(VOCAB_SIZE, HIDDEN_SIZE) + 
embed_master = embed_master.to(dtype).to(device) + + weight_master = embed_master.weight.data + torch.distributed.broadcast(weight_master, src=0) + weight = torch.chunk(weight_master, TESSERACT_DIM, dim=-1)[j] + weight = torch.chunk(weight, TESSERACT_DIM, dim=0)[i] + embed.weight.data.copy_(weight) + + A_shape = (BATCH_SIZE, SEQ_LENGTH) + A_master = torch.randint(VOCAB_SIZE, A_shape, device=device) + torch.distributed.broadcast(A_master, src=0) + A = A_master.clone() + out = embed(A) + + A_master = A_master.clone() + C_master = embed_master(A_master) + C = torch.chunk(C_master, TESSERACT_DIM, dim=0)[i] + C = torch.chunk(C, TESSERACT_DIM, dim=-1)[j] + check_equal(out, C) + print_rank_0('vocab parallel embed forward: pass') + + grad_shape = C_master.shape + grad_master = torch.randn(grad_shape, dtype=dtype, device=device) + torch.distributed.broadcast(grad_master, src=0) + grad = torch.chunk(grad_master, TESSERACT_DIM, dim=0)[i] + grad = torch.chunk(grad, TESSERACT_DIM, dim=-1)[j] + grad = grad.clone() + out.backward(grad) + grad_master = grad_master.clone() + C_master.backward(grad_master) + + B_grad = embed_master.weight.grad + B_grad = torch.chunk(B_grad, TESSERACT_DIM, dim=-1)[j] + B_grad = torch.chunk(B_grad, TESSERACT_DIM, dim=0)[i] + check_equal(B_grad, embed.weight.grad) + print_rank_0('vocab parallel embed backward: pass') + + +def check_classifier_no_given_weight(): + device = get_current_device() + dtype = torch.float32 + INPUT_SIZE = HIDDEN_SIZE + OUTPUT_SIZE = NUM_CLASSES + + j = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_ROW) + i = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_COL) + + layer = Classifier2p5D(INPUT_SIZE, OUTPUT_SIZE) + + A_shape = (BATCH_SIZE, SEQ_LENGTH, INPUT_SIZE) + A_master = torch.randint(5, A_shape, dtype=dtype, device=device) + torch.distributed.broadcast(A_master, src=0) + A = torch.chunk(A_master, TESSERACT_DIM, dim=0)[i] + A = torch.chunk(A, TESSERACT_DIM, dim=-1)[j] + A = A.clone() + A.requires_grad = True + + W_shape = (OUTPUT_SIZE, INPUT_SIZE) + W_master = torch.randint(5, W_shape, dtype=dtype, device=device) + torch.distributed.broadcast(W_master, src=0) + # W = torch.chunk(W_master, TESSERACT_DIM, dim=-1)[j] + W = torch.chunk(W_master, TESSERACT_DIM, dim=-1)[j] + W = torch.chunk(W, TESSERACT_DIM, dim=-1)[i] + W = W.clone() + layer.weight.data.copy_(W) + # W.requires_grad = True + + B_shape = (OUTPUT_SIZE, ) + B_master = torch.randint(5, B_shape, dtype=dtype, device=device) + torch.distributed.broadcast(B_master, src=0) + # B = torch.chunk(B_master, TESSERACT_DIM, dim=0)[j] + B = B_master.clone() + layer.bias.data.copy_(B) + + out = layer(A) + + A_master = A_master.clone() + A_master.requires_grad = True + W_master = W_master.clone() + W_master.requires_grad = True + B_master = B_master.clone() + B_master.requires_grad = True + C_master = torch.matmul(A_master, W_master.transpose(0, 1)) + B_master + C = torch.chunk(C_master, TESSERACT_DIM, dim=0)[i] + # C = torch.chunk(C, TESSERACT_DIM, dim=-1)[j] + + check_equal(out, C) + print_rank_0('classifier (no given weight) forward: pass') + + grad_shape = C_master.shape + grad_master = torch.randn(grad_shape, dtype=dtype, device=get_current_device()) + torch.distributed.broadcast(grad_master, src=0) + grad = torch.chunk(grad_master, TESSERACT_DIM, dim=0)[i] + # grad = torch.chunk(grad, TESSERACT_DIM, dim=-1)[j] + grad = grad.clone() + out.backward(grad) + + grad_master = grad_master.clone() + C_master.backward(grad_master) + A_grad = A_master.grad + A_grad = torch.chunk(A_grad, TESSERACT_DIM, dim=0)[i] 
+ A_grad = torch.chunk(A_grad, TESSERACT_DIM, dim=-1)[j] + check_equal(A_grad, A.grad) + + W_grad = W_master.grad + W_grad = torch.chunk(W_grad, TESSERACT_DIM, dim=-1)[j] + W_grad = torch.chunk(W_grad, TESSERACT_DIM, dim=-1)[i] + check_equal(W_grad, layer.weight.grad) + + B_grad = B_master.grad + # B_grad = torch.chunk(B_grad, TESSERACT_DIM, dim=0)[j] + # if i == 0: + check_equal(B_grad, layer.bias.grad) + + print_rank_0('classifier (no given weight) backward: pass') + + +def check_vocab_parallel_classifier_no_given_weight(): + device = get_current_device() + dtype = torch.float32 + i = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_COL) + j = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_ROW) + k = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_DEP) + + layer = VocabParallelClassifier2p5D(HIDDEN_SIZE, VOCAB_SIZE, bias=True) + layer = layer.to(dtype).to(device) + + layer_master = VanillaClassifier(HIDDEN_SIZE, VOCAB_SIZE, bias=True) + layer_master = layer_master.to(dtype).to(device) + + weight_master = layer_master.weight.data + torch.distributed.broadcast(weight_master, src=0) + weight = torch.chunk(weight_master, TESSERACT_DIM, dim=0)[i] + weight = torch.chunk(weight, TESSERACT_DIM, dim=-1)[j] + layer.weight.data.copy_(weight) + bias_master = layer_master.bias.data + torch.distributed.broadcast(bias_master, src=0) + bias = torch.chunk(bias_master, TESSERACT_DIM)[j] + layer.bias.data.copy_(bias) + + A_shape = (BATCH_SIZE, SEQ_LENGTH, HIDDEN_SIZE) + A_master = torch.randn(A_shape, dtype=dtype, device=device) + torch.distributed.broadcast(A_master, src=0) + A = torch.chunk(A_master, TESSERACT_DIM, dim=0)[i] + A = torch.chunk(A, TESSERACT_DIM, dim=-1)[j] + A = A.clone() + A.requires_grad = True + out = layer(A) + + A_master = A_master.clone() + A_master.requires_grad = True + C_master = layer_master(A_master) + C = torch.chunk(C_master, TESSERACT_DIM, dim=0)[i] + C = torch.chunk(C, TESSERACT_DIM, dim=-1)[j] + check_equal(out, C) + print_rank_0('vocab parallel classifier (no given weight) forward: pass') + + grad_shape = C_master.shape + grad_master = torch.randn(grad_shape, dtype=dtype, device=device) + torch.distributed.broadcast(grad_master, src=0) + grad = torch.chunk(grad_master, TESSERACT_DIM, dim=0)[i] + grad = torch.chunk(grad, TESSERACT_DIM, dim=-1)[j] + grad = grad.clone() + out.backward(grad) + + grad_master = grad_master.clone() + C_master.backward(grad_master) + A_grad = A_master.grad + A_grad = torch.chunk(A_grad, TESSERACT_DIM, dim=0)[i] + A_grad = torch.chunk(A_grad, TESSERACT_DIM, dim=-1)[j] + check_equal(A_grad, A.grad) + + W_grad = layer_master.weight.grad + W_grad = torch.chunk(W_grad, TESSERACT_DIM, dim=0)[i] + W_grad = torch.chunk(W_grad, TESSERACT_DIM, dim=-1)[j] + check_equal(W_grad, layer.weight.grad) + + B_grad = layer_master.bias.grad + B_grad = torch.chunk(B_grad, TESSERACT_DIM)[j] + if i == 0: + check_equal(B_grad, layer.bias.grad) + print_rank_0('vocab parallel classifier (no given weight) backward: pass') + + +def check_classifier_given_embed_weight(): + device = get_current_device() + dtype = torch.float32 + i = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_COL) + j = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_ROW) + k = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_DEP) + + embed = Embedding2p5D(VOCAB_SIZE, HIDDEN_SIZE) + embed = embed.to(dtype).to(device) + embed_master = torch.nn.Embedding(VOCAB_SIZE, HIDDEN_SIZE) + embed_master = embed_master.to(dtype).to(device) + + weight_master = embed_master.weight.data + torch.distributed.broadcast(weight_master, 
src=0) + weight = torch.chunk(weight_master, TESSERACT_DIM, dim=-1)[j] + weight = torch.chunk(weight, TESSERACT_DIM, dim=-1)[i] + embed.weight.data.copy_(weight) + + layer = Classifier2p5D(HIDDEN_SIZE, VOCAB_SIZE, weight=embed.weight, bias=False) + layer = layer.to(dtype).to(device) + layer_master = VanillaClassifier(HIDDEN_SIZE, VOCAB_SIZE, weight=embed_master.weight, bias=False) + layer_master = layer_master.to(dtype).to(device) + + A_shape = (BATCH_SIZE, SEQ_LENGTH) + A_master = torch.randint(VOCAB_SIZE, A_shape, device=device) + torch.distributed.broadcast(A_master, src=0) + A = A_master.clone() + out = layer(embed(A)) + + A_master = A_master.clone() + C_master = layer_master(embed_master(A_master)) + C = torch.chunk(C_master, TESSERACT_DIM, dim=0)[i] + check_equal(out, C) + print_rank_0('classifier (given embed weight) forward: pass') + + grad_shape = C_master.shape + grad_master = torch.randn(grad_shape, dtype=dtype, device=device) + torch.distributed.broadcast(grad_master, src=0) + grad = torch.chunk(grad_master, TESSERACT_DIM, dim=0)[i] + grad = grad.clone() + out.backward(grad) + + grad_master = grad_master.clone() + C_master.backward(grad_master) + + W_grad = embed_master.weight.grad + W_grad = torch.chunk(W_grad, TESSERACT_DIM, dim=-1)[j] + W_grad = torch.chunk(W_grad, TESSERACT_DIM, dim=-1)[i] + check_equal(W_grad, embed.weight.grad) + print_rank_0('classifier (given embed weight) backward: pass') + + +def check_vocab_parallel_classifier_given_embed_weight(): + device = get_current_device() + dtype = torch.float32 + i = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_COL) + j = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_ROW) + k = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_DEP) + + embed = VocabParallelEmbedding2p5D(VOCAB_SIZE, HIDDEN_SIZE) + embed = embed.to(dtype).to(device) + embed_master = torch.nn.Embedding(VOCAB_SIZE, HIDDEN_SIZE) + embed_master = embed_master.to(dtype).to(device) + + weight_master = embed_master.weight.data + torch.distributed.broadcast(weight_master, src=0) + weight = torch.chunk(weight_master, TESSERACT_DIM, dim=-1)[j] + weight = torch.chunk(weight, TESSERACT_DIM, dim=0)[i] + embed.weight.data.copy_(weight) + + layer = VocabParallelClassifier2p5D(HIDDEN_SIZE, VOCAB_SIZE, weight=embed.weight, bias=False) + layer = layer.to(dtype).to(device) + layer_master = VanillaClassifier(HIDDEN_SIZE, VOCAB_SIZE, weight=embed_master.weight, bias=False) + layer_master = layer_master.to(dtype).to(device) + + A_shape = (BATCH_SIZE, SEQ_LENGTH) + A_master = torch.randint(VOCAB_SIZE, A_shape, device=device) + torch.distributed.broadcast(A_master, src=0) + A = A_master.clone() + out = layer(embed(A)) + + A_master = A_master.clone() + C_master = layer_master(embed_master(A_master)) + C = torch.chunk(C_master, TESSERACT_DIM, dim=0)[i] + C = torch.chunk(C, TESSERACT_DIM, dim=-1)[j] + check_equal(out, C) + print_rank_0('vocab parallel classifier (given embed weight) forward: pass') + + grad_shape = C_master.shape + grad_master = torch.randn(grad_shape, dtype=dtype, device=device) + torch.distributed.broadcast(grad_master, src=0) + grad = torch.chunk(grad_master, TESSERACT_DIM, dim=0)[i] + grad = torch.chunk(grad, TESSERACT_DIM, dim=-1)[j] + grad = grad.clone() + out.backward(grad) + + grad_master = grad_master.clone() + C_master.backward(grad_master) + + W_grad = embed_master.weight.grad + W_grad = torch.chunk(W_grad, TESSERACT_DIM, dim=-1)[j] + W_grad = torch.chunk(W_grad, TESSERACT_DIM, dim=0)[i] + check_equal(W_grad, embed.weight.grad) + print_rank_0('vocab parallel 
classifier (given embed weight) backward: pass') + + +def check_loss(): + device = get_current_device() + dtype = torch.float32 + i = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_COL) + j = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_ROW) + k = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_DEP) + + criterion = CrossEntropyLoss2p5D() + criterion_master = torch.nn.CrossEntropyLoss() + + out_shape = (BATCH_SIZE, NUM_CLASSES) + out_master = torch.randn(out_shape, dtype=dtype, device=device) + target_master = torch.randint(NUM_CLASSES, (BATCH_SIZE, ), dtype=torch.long, device=device) + torch.distributed.broadcast(out_master, src=0) + torch.distributed.broadcast(target_master, src=0) + out = torch.chunk(out_master, TESSERACT_DIM, dim=0)[i] + out = out.clone() + out.requires_grad = True + loss = criterion(out, target_master) + + out_master = out_master.clone() + out_master.requires_grad = True + loss_master = criterion_master(out_master, target_master) + check_equal(loss, loss_master) + print_rank_0('cross entropy loss forward: pass') + + loss.backward() + loss_master.backward() + + out_grad = out_master.grad + out_grad = torch.chunk(out_grad, TESSERACT_DIM, dim=0)[i] + check_equal(out_grad, out.grad) + print_rank_0('cross entropy loss backward: pass') + + +def check_vocab_parallel_loss(): + device = get_current_device() + dtype = torch.float32 + i = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_COL) + j = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_ROW) + k = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_DEP) + + criterion = VocabParallelCrossEntropyLoss2p5D() + criterion_master = torch.nn.CrossEntropyLoss() + + out_shape = (BATCH_SIZE, NUM_CLASSES) + out_master = torch.randn(out_shape, dtype=dtype, device=device) + target_master = torch.randint(NUM_CLASSES, (BATCH_SIZE, ), dtype=torch.long, device=device) + torch.distributed.broadcast(out_master, src=0) + torch.distributed.broadcast(target_master, src=0) + out = torch.chunk(out_master, TESSERACT_DIM, dim=0)[i] + out = torch.chunk(out, TESSERACT_DIM, dim=-1)[j] + out = out.clone() + out.requires_grad = True + loss = criterion(out, target_master) + + out_master = out_master.clone() + out_master.requires_grad = True + loss_master = criterion_master(out_master, target_master) + check_equal(loss, loss_master) + print_rank_0('vocab parallel cross entropy loss forward: pass') + + loss.backward() + loss_master.backward() + + out_grad = out_master.grad + out_grad = torch.chunk(out_grad, TESSERACT_DIM, dim=0)[i] + out_grad = torch.chunk(out_grad, TESSERACT_DIM, dim=-1)[j] + check_equal(out_grad, out.grad) + print_rank_0('vocab parallel cross entropy loss backward: pass') + + # def check_attention(): # device = get_current_device() # dtype = torch.float32 @@ -267,7 +676,6 @@ def check_layernorm(): # assert A.grad.shape == A.shape # print_rank_0('self attention backward: pass') - # def check_mlp(): # device = get_current_device() # dtype = torch.float32 @@ -304,7 +712,6 @@ def check_layernorm(): # assert A.grad.shape == A.shape # print_rank_0('mlp backward: pass') - # def check_transformerlayer(): # device = get_current_device() # dtype = torch.float32 @@ -344,4 +751,4 @@ def check_layernorm(): # out.backward(grad) # assert A.grad.shape == A.shape -# print_rank_0('transformerlayer backward: pass') \ No newline at end of file +# print_rank_0('transformerlayer backward: pass') diff --git a/tests/test_layers/test_2p5d/checks_2p5d/common.py b/tests/test_layers/test_2p5d/checks_2p5d/common.py index 23ff24b7c..aff85f109 100644 --- 
a/tests/test_layers/test_2p5d/checks_2p5d/common.py +++ b/tests/test_layers/test_2p5d/checks_2p5d/common.py @@ -5,8 +5,10 @@ TESSERACT_DEP = 2 BATCH_SIZE = 8 SEQ_LENGTH = 8 HIDDEN_SIZE = 8 -NUM_CLASSES = 3 +NUM_CLASSES = 8 +VOCAB_SIZE = 16 +IMG_SIZE = 16 def check_equal(A, B): - assert torch.allclose(A, B, rtol=1e-5, atol=1e-2) == True \ No newline at end of file + assert torch.allclose(A, B, rtol=1e-5, atol=1e-2) \ No newline at end of file diff --git a/tests/test_layers/test_2p5d/test_2p5d.py b/tests/test_layers/test_2p5d/test_2p5d.py index 4de4015bf..da0848d06 100644 --- a/tests/test_layers/test_2p5d/test_2p5d.py +++ b/tests/test_layers/test_2p5d/test_2p5d.py @@ -5,10 +5,10 @@ import torch import torch.multiprocessing as mp from colossalai.core import global_context as gpc from colossalai.initialize import launch +from colossalai.logging import disable_existing_loggers from colossalai.utils import free_port -from checks_2p5d.check_layer_2p5d import (check_classifier, check_layernorm, - check_linear) +from checks_2p5d.check_layer_2p5d import * from checks_2p5d.check_operation_2p5d import check_AB, check_ABT, check_ATB CONFIG = dict( @@ -28,10 +28,19 @@ def check_operations(): def check_layer(): check_linear() check_layernorm() - check_classifier() + check_embed() + check_patch_embed() + check_vocab_parallel_embed() + check_classifier_no_given_weight() + check_vocab_parallel_classifier_no_given_weight() + check_classifier_given_embed_weight() + check_vocab_parallel_classifier_given_embed_weight() + check_loss() + check_vocab_parallel_loss() def check_layer_and_operation(rank, world_size, port): + disable_existing_loggers() launch(config=CONFIG, rank=rank, world_size=world_size, @@ -39,6 +48,9 @@ def check_layer_and_operation(rank, world_size, port): port=port, backend='nccl') + torch.backends.cuda.matmul.allow_tf32 = False + torch.backends.cudnn.allow_tf32 = False + torch.backends.cudnn.deterministic = True check_operations() check_layer() gpc.destroy() diff --git a/tests/test_layers/test_3d/checks_3d/check_layer_3d.py b/tests/test_layers/test_3d/checks_3d/check_layer_3d.py index c05960acc..087bb0781 100644 --- a/tests/test_layers/test_3d/checks_3d/check_layer_3d.py +++ b/tests/test_layers/test_3d/checks_3d/check_layer_3d.py @@ -3,16 +3,17 @@ import time -from colossalai.constants import (INPUT_GROUP_3D, OUTPUT_GROUP_3D, WEIGHT_GROUP_3D) +import torch +from colossalai.constants import INPUT_GROUP_3D, OUTPUT_GROUP_3D, WEIGHT_GROUP_3D from colossalai.core import global_context from colossalai.logging import get_dist_logger -from colossalai.nn import (Classifier3D, CrossEntropyLoss3D, LayerNorm3D, Linear3D, PatchEmbedding3D, VanillaClassifier, - VanillaPatchEmbedding) +from colossalai.nn import (Classifier3D, CrossEntropyLoss3D, Embedding3D, LayerNorm3D, Linear3D, PatchEmbedding3D, + VanillaClassifier, VanillaPatchEmbedding, VocabParallelClassifier3D, + VocabParallelCrossEntropyLoss3D, VocabParallelEmbedding3D) from colossalai.nn.layer.parallel_3d._utils import get_parallel_mode_from_env from colossalai.utils import get_current_device, print_rank_0 -from .common import * -import torch +from .common import BATCH_SIZE, DEPTH, HIDDEN_SIZE, IMG_SIZE, NUM_CLASSES, SEQ_LENGTH, VOCAB_SIZE, check_equal def check_linear(): @@ -27,9 +28,9 @@ def check_linear(): weight_parallel_mode = get_parallel_mode_from_env(WEIGHT_GROUP_3D) output_parallel_mode = get_parallel_mode_from_env(OUTPUT_GROUP_3D) - j = A_rank = global_context.get_local_rank(input_parallel_mode) - i = B_rank = 
global_context.get_local_rank(weight_parallel_mode) - k = C_rank = global_context.get_local_rank(output_parallel_mode) + j = global_context.get_local_rank(input_parallel_mode) + i = global_context.get_local_rank(weight_parallel_mode) + k = global_context.get_local_rank(output_parallel_mode) layer = Linear3D(INPUT_SIZE, OUTPUT_SIZE, dtype=dtype, bias=True) layer = layer.to(device) @@ -112,9 +113,9 @@ def check_layernorm(): weight_parallel_mode = get_parallel_mode_from_env(WEIGHT_GROUP_3D) output_parallel_mode = get_parallel_mode_from_env(OUTPUT_GROUP_3D) - j = A_rank = global_context.get_local_rank(input_parallel_mode) - i = B_rank = global_context.get_local_rank(weight_parallel_mode) - k = C_rank = global_context.get_local_rank(output_parallel_mode) + j = global_context.get_local_rank(input_parallel_mode) + i = global_context.get_local_rank(weight_parallel_mode) + k = global_context.get_local_rank(output_parallel_mode) norm = LayerNorm3D(INPUT_SIZE, eps=1e-6, dtype=dtype) norm = norm.to(device) @@ -186,7 +187,7 @@ def check_layernorm(): return fwd_end - fwd_start, bwd_end - bwd_start -def check_classifier(): +def check_classifier_no_given_weight(): rank = torch.distributed.get_rank() logger = get_dist_logger() device = get_current_device() @@ -197,9 +198,9 @@ def check_classifier(): weight_parallel_mode = get_parallel_mode_from_env(WEIGHT_GROUP_3D) output_parallel_mode = get_parallel_mode_from_env(OUTPUT_GROUP_3D) - j = A_rank = global_context.get_local_rank(input_parallel_mode) - i = B_rank = global_context.get_local_rank(weight_parallel_mode) - k = C_rank = global_context.get_local_rank(output_parallel_mode) + j = global_context.get_local_rank(input_parallel_mode) + i = global_context.get_local_rank(weight_parallel_mode) + k = global_context.get_local_rank(output_parallel_mode) layer = Classifier3D(INPUT_SIZE, NUM_CLASSES, dtype=dtype, bias=True) layer = layer.to(device) @@ -229,14 +230,14 @@ def check_classifier(): torch.cuda.synchronize() fwd_end = time.time() print_rank_0( - 'head forward: pass | {0} --> {1} | {2:.3f} s'.format(tuple(A.shape), tuple(out.shape), fwd_end - fwd_start), - logger) + 'classifier (no given weight) forward: pass | {0} --> {1} | {2:.3f} s'.format( + tuple(A.shape), tuple(out.shape), fwd_end - fwd_start), logger) A_master = A_master.clone() A_master.requires_grad = True C_master = layer_master(A_master) C = torch.chunk(C_master, DEPTH, dim=0)[i] C = torch.chunk(C, DEPTH, dim=0)[j] - logger.info('Rank {} head forward: {}'.format(rank, check_equal(out, C))) + logger.info('Rank {} classifier (no given weight) forward: {}'.format(rank, check_equal(out, C))) grad_shape = C_master.shape grad_master = torch.randn(grad_shape, dtype=dtype, device=get_current_device()) @@ -249,7 +250,7 @@ def check_classifier(): out.backward(grad) torch.cuda.synchronize() bwd_end = time.time() - print_rank_0('head backward: pass | {:.3f} s'.format(bwd_end - bwd_start), logger) + print_rank_0('classifier (no given weight) backward: pass | {:.3f} s'.format(bwd_end - bwd_start), logger) grad_master = grad_master.clone() C_master.backward(grad_master) @@ -257,23 +258,275 @@ def check_classifier(): A_grad = torch.chunk(A_grad, DEPTH, dim=0)[i] A_grad = torch.chunk(A_grad, DEPTH, dim=-1)[k] A_grad = torch.chunk(A_grad, DEPTH, dim=0)[j] - logger.info('Rank {} head backward (input_grad): {}'.format(rank, check_equal(A_grad, A.grad))) + logger.info('Rank {} classifier (no given weight) backward (input_grad): {}'.format( + rank, check_equal(A_grad, A.grad))) B_grad = layer_master.weight.grad 
B_grad = torch.chunk(B_grad, DEPTH, dim=-1)[k] if j == k: - logger.info('Rank {} head backward (weight_grad): {}'.format(rank, - check_equal(B_grad, layer.weight.grad))) + logger.info('Rank {} classifier (no given weight) backward (weight_grad): {}'.format( + rank, check_equal(B_grad, layer.weight.grad))) else: - logger.info('Rank {} head backward (weight_grad): {}'.format(rank, layer.weight.grad is None)) + logger.info('Rank {} classifier (no given weight) backward (weight_grad): {}'.format( + rank, layer.weight.grad is None)) bias_grad = layer_master.bias.grad - logger.info('Rank {} head backward (bias_grad): {}'.format(rank, check_equal(bias_grad, layer.bias.grad))) + logger.info('Rank {} classifier (no given weight) backward (bias_grad): {}'.format( + rank, check_equal(bias_grad, layer.bias.grad))) return fwd_end - fwd_start, bwd_end - bwd_start -def check_embed(): +def check_vocab_parallel_classifier_no_given_weight(): + rank = torch.distributed.get_rank() + logger = get_dist_logger() + device = get_current_device() + dtype = torch.float32 + INPUT_SIZE = HIDDEN_SIZE + + input_parallel_mode = get_parallel_mode_from_env(INPUT_GROUP_3D) + weight_parallel_mode = get_parallel_mode_from_env(WEIGHT_GROUP_3D) + output_parallel_mode = get_parallel_mode_from_env(OUTPUT_GROUP_3D) + + j = global_context.get_local_rank(input_parallel_mode) + i = global_context.get_local_rank(weight_parallel_mode) + k = global_context.get_local_rank(output_parallel_mode) + + layer = VocabParallelClassifier3D(INPUT_SIZE, VOCAB_SIZE, bias=True) + layer = layer.to(dtype).to(device) + + layer_master = VanillaClassifier(INPUT_SIZE, VOCAB_SIZE, bias=True) + layer_master = layer_master.to(dtype).to(device) + + weight_master = layer_master.weight.data + torch.distributed.broadcast(weight_master, src=0) + weight = torch.chunk(weight_master, DEPTH, dim=0)[j] + weight = torch.chunk(weight, DEPTH, dim=-1)[k] + layer.weight.data.copy_(weight) + bias_master = layer_master.bias.data + torch.distributed.broadcast(bias_master, src=0) + bias = torch.chunk(bias_master, DEPTH)[j] + layer.bias.data.copy_(bias) + + A_shape = (BATCH_SIZE, SEQ_LENGTH, INPUT_SIZE) + A_master = torch.randn(A_shape, dtype=dtype, device=device) + torch.distributed.broadcast(A_master, src=0) + A = torch.chunk(A_master, DEPTH, dim=0)[i] + A = torch.chunk(A, DEPTH, dim=-1)[k] + A = torch.chunk(A, DEPTH, dim=0)[j] + A = A.clone() + A.requires_grad = True + + fwd_start = time.time() + out = layer(A) + torch.cuda.synchronize() + fwd_end = time.time() + print_rank_0( + 'vocab parallel classifier (no given weight) forward: pass | {0} --> {1} | {2:.3f} s'.format( + tuple(A.shape), tuple(out.shape), fwd_end - fwd_start), logger) + A_master = A_master.clone() + A_master.requires_grad = True + C_master = layer_master(A_master) + C = torch.chunk(C_master, DEPTH, dim=0)[i] + C = torch.chunk(C, DEPTH, dim=-1)[j] + C = torch.chunk(C, DEPTH, dim=0)[k] + logger.info('Rank {} vocab parallel classifier (no given weight) forward: {}'.format(rank, check_equal(out, C))) + + grad_shape = C_master.shape + grad_master = torch.randn(grad_shape, dtype=dtype, device=device) + torch.distributed.broadcast(grad_master, src=0) + grad = torch.chunk(grad_master, DEPTH, dim=0)[i] + grad = torch.chunk(grad, DEPTH, dim=-1)[j] + grad = torch.chunk(grad, DEPTH, dim=0)[k] + grad = grad.clone() + + bwd_start = time.time() + out.backward(grad) + torch.cuda.synchronize() + bwd_end = time.time() + print_rank_0('vocab parallel classifier (no given weight) backward: pass | {:.3f} s'.format(bwd_end - 
bwd_start), + logger) + + grad_master = grad_master.clone() + C_master.backward(grad_master) + A_grad = A_master.grad + A_grad = torch.chunk(A_grad, DEPTH, dim=0)[i] + A_grad = torch.chunk(A_grad, DEPTH, dim=-1)[k] + A_grad = torch.chunk(A_grad, DEPTH, dim=0)[j] + logger.info('Rank {} vocab parallel classifier (no given weight) backward (input_grad): {}'.format( + rank, check_equal(A_grad, A.grad))) + + B_grad = layer_master.weight.grad + B_grad = torch.chunk(B_grad, DEPTH, dim=0)[j] + B_grad = torch.chunk(B_grad, DEPTH, dim=-1)[k] + logger.info('Rank {} vocab parallel classifier (no given weight) backward (weight_grad): {}'.format( + rank, check_equal(B_grad, layer.weight.grad))) + + bias_grad = layer_master.bias.grad + bias_grad = torch.chunk(bias_grad, DEPTH)[j] + logger.info('Rank {} vocab parallel classifier (no given weight) backward (bias_grad): {}'.format( + rank, check_equal(bias_grad, layer.bias.grad))) + + return fwd_end - fwd_start, bwd_end - bwd_start + + +def check_classifier_given_embed_weight(): + rank = torch.distributed.get_rank() + logger = get_dist_logger() + device = get_current_device() + dtype = torch.float32 + + input_parallel_mode = get_parallel_mode_from_env(INPUT_GROUP_3D) + weight_parallel_mode = get_parallel_mode_from_env(WEIGHT_GROUP_3D) + output_parallel_mode = get_parallel_mode_from_env(OUTPUT_GROUP_3D) + + j = global_context.get_local_rank(input_parallel_mode) + i = global_context.get_local_rank(weight_parallel_mode) + k = global_context.get_local_rank(output_parallel_mode) + + embed = Embedding3D(VOCAB_SIZE, HIDDEN_SIZE) + embed = embed.to(dtype).to(device) + + embed_master = torch.nn.Embedding(VOCAB_SIZE, HIDDEN_SIZE) + embed_master = embed_master.to(dtype).to(device) + + weight_master = embed_master.weight.data + torch.distributed.broadcast(weight_master, src=0) + weight = torch.chunk(weight_master, DEPTH, dim=-1)[k] + embed.weight.data.copy_(weight) + + layer = Classifier3D(HIDDEN_SIZE, VOCAB_SIZE, weight=embed.weight, bias=False) + layer = layer.to(dtype).to(device) + + layer_master = VanillaClassifier(HIDDEN_SIZE, VOCAB_SIZE, weight=embed_master.weight, bias=False) + layer_master = layer_master.to(dtype).to(device) + + A_shape = (BATCH_SIZE, SEQ_LENGTH) + A_master = torch.randint(VOCAB_SIZE, A_shape, device=device) + torch.distributed.broadcast(A_master, src=0) + A = A_master.clone() + + fwd_start = time.time() + out = layer(embed(A)) + torch.cuda.synchronize() + fwd_end = time.time() + print_rank_0( + 'classifier (given embed weight) forward: pass | {0} --> {1} | {2:.3f} s'.format( + tuple(A.shape), tuple(out.shape), fwd_end - fwd_start), logger) + A_master = A_master.clone() + C_master = layer_master(embed_master(A_master)) + C = torch.chunk(C_master, DEPTH, dim=0)[i] + C = torch.chunk(C, DEPTH, dim=0)[j] + logger.info('Rank {} classifier (given embed weight) forward: {}'.format(rank, check_equal(out, C))) + + grad_shape = C_master.shape + grad_master = torch.randn(grad_shape, dtype=dtype, device=get_current_device()) + torch.distributed.broadcast(grad_master, src=0) + grad = torch.chunk(grad_master, DEPTH, dim=0)[i] + grad = torch.chunk(grad, DEPTH, dim=0)[j] + grad = grad.clone() + + bwd_start = time.time() + out.backward(grad) + torch.cuda.synchronize() + bwd_end = time.time() + print_rank_0('classifier (given embed weight) backward: pass | {:.3f} s'.format(bwd_end - bwd_start), logger) + + grad_master = grad_master.clone() + C_master.backward(grad_master) + + B_grad = embed_master.weight.grad + B_grad = torch.chunk(B_grad, DEPTH, dim=-1)[k] + if 
j == k: + logger.info('Rank {} classifier (given embed weight) backward (weight_grad): {}'.format( + rank, check_equal(B_grad, embed.weight.grad))) + else: + logger.info('Rank {} classifier (given embed weight) backward (weight_grad): {}'.format( + rank, embed.weight.grad is None)) + + return fwd_end - fwd_start, bwd_end - bwd_start + + +def check_vocab_parallel_classifier_given_embed_weight(): + rank = torch.distributed.get_rank() + logger = get_dist_logger() + device = get_current_device() + dtype = torch.float32 + + input_parallel_mode = get_parallel_mode_from_env(INPUT_GROUP_3D) + weight_parallel_mode = get_parallel_mode_from_env(WEIGHT_GROUP_3D) + output_parallel_mode = get_parallel_mode_from_env(OUTPUT_GROUP_3D) + + j = global_context.get_local_rank(input_parallel_mode) + i = global_context.get_local_rank(weight_parallel_mode) + k = global_context.get_local_rank(output_parallel_mode) + + embed = VocabParallelEmbedding3D(VOCAB_SIZE, HIDDEN_SIZE) + embed = embed.to(dtype).to(device) + + embed_master = torch.nn.Embedding(VOCAB_SIZE, HIDDEN_SIZE) + embed_master = embed_master.to(dtype).to(device) + + weight_master = embed_master.weight.data + torch.distributed.broadcast(weight_master, src=0) + weight = torch.chunk(weight_master, DEPTH, dim=0)[j] + weight = torch.chunk(weight, DEPTH, dim=-1)[k] + embed.weight.data.copy_(weight) + + layer = VocabParallelClassifier3D(HIDDEN_SIZE, VOCAB_SIZE, weight=embed.weight, bias=False) + layer = layer.to(dtype).to(device) + + layer_master = VanillaClassifier(HIDDEN_SIZE, VOCAB_SIZE, weight=embed_master.weight, bias=False) + layer_master = layer_master.to(dtype).to(device) + + A_shape = (BATCH_SIZE, SEQ_LENGTH) + A_master = torch.randint(VOCAB_SIZE, A_shape, device=device) + torch.distributed.broadcast(A_master, src=0) + A = A_master.clone() + + fwd_start = time.time() + out = layer(embed(A)) + torch.cuda.synchronize() + fwd_end = time.time() + print_rank_0( + 'vocab parallel classifier (given embed weight) forward: pass | {0} --> {1} | {2:.3f} s'.format( + tuple(A.shape), tuple(out.shape), fwd_end - fwd_start), logger) + A_master = A_master.clone() + C_master = layer_master(embed_master(A_master)) + C = torch.chunk(C_master, DEPTH, dim=0)[i] + C = torch.chunk(C, DEPTH, dim=-1)[j] + C = torch.chunk(C, DEPTH, dim=0)[k] + logger.info('Rank {} vocab parallel classifier (given embed weight) forward: {}'.format(rank, check_equal(out, C))) + + grad_shape = C_master.shape + grad_master = torch.randn(grad_shape, dtype=dtype, device=device) + torch.distributed.broadcast(grad_master, src=0) + grad = torch.chunk(grad_master, DEPTH, dim=0)[i] + grad = torch.chunk(grad, DEPTH, dim=-1)[j] + grad = torch.chunk(grad, DEPTH, dim=0)[k] + grad = grad.clone() + + bwd_start = time.time() + out.backward(grad) + torch.cuda.synchronize() + bwd_end = time.time() + print_rank_0('vocab parallel classifier (given embed weight) backward: pass | {:.3f} s'.format(bwd_end - bwd_start), + logger) + + grad_master = grad_master.clone() + C_master.backward(grad_master) + + B_grad = embed_master.weight.grad + B_grad = torch.chunk(B_grad, DEPTH, dim=0)[j] + B_grad = torch.chunk(B_grad, DEPTH, dim=-1)[k] + logger.info('Rank {} vocab parallel embed backward (weight_grad): {}'.format(rank, + check_equal(B_grad, + embed.weight.grad))) + + return fwd_end - fwd_start, bwd_end - bwd_start + + +def check_patch_embed(): rank = torch.distributed.get_rank() device = get_current_device() logger = get_dist_logger() @@ -283,9 +536,9 @@ def check_embed(): weight_parallel_mode = 
get_parallel_mode_from_env(WEIGHT_GROUP_3D) output_parallel_mode = get_parallel_mode_from_env(OUTPUT_GROUP_3D) - j = A_rank = global_context.get_local_rank(input_parallel_mode) - i = B_rank = global_context.get_local_rank(weight_parallel_mode) - k = C_rank = global_context.get_local_rank(output_parallel_mode) + j = global_context.get_local_rank(input_parallel_mode) + i = global_context.get_local_rank(weight_parallel_mode) + k = global_context.get_local_rank(output_parallel_mode) layer = PatchEmbedding3D(IMG_SIZE, 4, 3, HIDDEN_SIZE, dtype=dtype) torch.nn.init.ones_(layer.cls_token) @@ -310,18 +563,99 @@ def check_embed(): A_master = torch.randn(A_shape, dtype=dtype, device=device) torch.distributed.broadcast(A_master, src=0) A = A_master.clone() - A.requires_grad = True fwd_start = time.time() out = layer(A) torch.cuda.synchronize() fwd_end = time.time() print_rank_0( - 'embedding forward: pass | {0} --> {1} | {2:.3f} s'.format(tuple(A.shape), tuple(out.shape), - fwd_end - fwd_start), logger) + 'patch embed forward: pass | {0} --> {1} | {2:.3f} s'.format(tuple(A.shape), tuple(out.shape), + fwd_end - fwd_start), logger) + + A_master = A_master.clone() + C_master = layer_master(A_master) + C = torch.chunk(C_master, DEPTH, dim=0)[i] + C = torch.chunk(C, DEPTH, dim=-1)[k] + C = torch.chunk(C, DEPTH, dim=0)[j] + logger.info('Rank {} patch embed forward: {}'.format(rank, check_equal(out, C))) + + grad_shape = C_master.shape + grad_master = torch.randn(grad_shape, dtype=dtype, device=device) + torch.distributed.broadcast(grad_master, src=0) + grad = torch.chunk(grad_master, DEPTH, dim=0)[i] + grad = torch.chunk(grad, DEPTH, dim=-1)[k] + grad = torch.chunk(grad, DEPTH, dim=0)[j] + grad = grad.clone() + + bwd_start = time.time() + out.backward(grad) + torch.cuda.synchronize() + bwd_end = time.time() + print_rank_0('patch embed backward: pass | {:.3f} s'.format(bwd_end - bwd_start), logger) + + grad_master = grad_master.clone() + C_master.backward(grad_master) + + cls_grad_master = layer_master.cls_token.grad + cls_grad = torch.chunk(cls_grad_master, DEPTH, dim=-1)[k] + logger.info('Rank {} patch embed backward (cls_grad): {}'.format(rank, check_equal(cls_grad, layer.cls_token.grad))) + + pos_grad_master = layer_master.pos_embed.grad + pos_grad = torch.chunk(pos_grad_master, DEPTH, dim=-1)[k] + logger.info('Rank {} patch embed backward (pos_embed_grad): {}'.format(rank, + check_equal(pos_grad, layer.pos_embed.grad))) + + B_grad = layer_master.weight.grad + B_grad = torch.chunk(B_grad, DEPTH, dim=0)[k] + logger.info('Rank {} patch embed backward (proj_weight_grad): {}'.format(rank, + check_equal(B_grad, layer.weight.grad))) + + bias_grad = layer_master.bias.grad + bias_grad = torch.chunk(bias_grad, DEPTH)[k] + logger.info('Rank {} patch embed backward (proj_bias_grad): {}'.format(rank, + check_equal(bias_grad, layer.bias.grad))) + + return fwd_end - fwd_start, bwd_end - bwd_start + + +def check_embed(): + rank = torch.distributed.get_rank() + device = get_current_device() + logger = get_dist_logger() + dtype = torch.float32 + + input_parallel_mode = get_parallel_mode_from_env(INPUT_GROUP_3D) + weight_parallel_mode = get_parallel_mode_from_env(WEIGHT_GROUP_3D) + output_parallel_mode = get_parallel_mode_from_env(OUTPUT_GROUP_3D) + + j = global_context.get_local_rank(input_parallel_mode) + i = global_context.get_local_rank(weight_parallel_mode) + k = global_context.get_local_rank(output_parallel_mode) + + layer = Embedding3D(VOCAB_SIZE, HIDDEN_SIZE) + layer = layer.to(dtype).to(device) + layer_master = 
torch.nn.Embedding(VOCAB_SIZE, HIDDEN_SIZE) + layer_master = layer_master.to(dtype).to(device) + + weight_master = layer_master.weight.data + torch.distributed.broadcast(weight_master, src=0) + weight = torch.chunk(weight_master, DEPTH, dim=-1)[k] + layer.weight.data.copy_(weight) + + A_shape = (BATCH_SIZE, SEQ_LENGTH) + A_master = torch.randint(VOCAB_SIZE, A_shape, device=device) + torch.distributed.broadcast(A_master, src=0) + A = A_master.clone() + + fwd_start = time.time() + out = layer(A) + torch.cuda.synchronize() + fwd_end = time.time() + logger.info('embed forward: pass | {0} --> {1} | {2:.3f} s'.format(tuple(A.shape), tuple(out.shape), + fwd_end - fwd_start), + ranks=[0]) A_master = A_master.clone() - A_master.requires_grad = True C_master = layer_master(A_master) C = torch.chunk(C_master, DEPTH, dim=0)[i] C = torch.chunk(C, DEPTH, dim=-1)[k] @@ -329,7 +663,7 @@ def check_embed(): logger.info('Rank {} embed forward: {}'.format(rank, check_equal(out, C))) grad_shape = C_master.shape - grad_master = torch.randn(grad_shape, dtype=dtype, device=get_current_device()) + grad_master = torch.randn(grad_shape, dtype=dtype, device=device) torch.distributed.broadcast(grad_master, src=0) grad = torch.chunk(grad_master, DEPTH, dim=0)[i] grad = torch.chunk(grad, DEPTH, dim=-1)[k] @@ -339,30 +673,88 @@ def check_embed(): out.backward(grad) torch.cuda.synchronize() bwd_end = time.time() - print_rank_0('embedding backward: pass | {:.3f} s'.format(bwd_end - bwd_start), logger) + logger.info('embed backward: pass | {:.3f} s'.format(bwd_end - bwd_start), ranks=[0]) grad_master = grad_master.clone() C_master.backward(grad_master) - cls_grad_master = layer_master.cls_token.grad - cls_grad = torch.chunk(cls_grad_master, DEPTH, dim=-1)[k] - logger.info('Rank {} embed backward (cls_grad): {}'.format(rank, check_equal(cls_grad, layer.cls_token.grad))) + B_grad = layer_master.weight.grad + B_grad = torch.chunk(B_grad, DEPTH, dim=-1)[k] + if j == k: + logger.info('Rank {} embed backward (weight_grad): {}'.format(rank, check_equal(B_grad, layer.weight.grad))) + else: + logger.info('Rank {} embed backward (weight_grad): {}'.format(rank, layer.weight.grad is None)) - pos_grad_master = layer_master.pos_embed.grad - pos_grad = torch.chunk(pos_grad_master, DEPTH, dim=-1)[k] - logger.info('Rank {} embed backward (pos_embed_grad): {}'.format(rank, check_equal(pos_grad, layer.pos_embed.grad))) + return fwd_end - fwd_start, bwd_end - bwd_start + + +def check_vocab_parallel_embed(): + rank = torch.distributed.get_rank() + device = get_current_device() + logger = get_dist_logger() + dtype = torch.float32 + + input_parallel_mode = get_parallel_mode_from_env(INPUT_GROUP_3D) + weight_parallel_mode = get_parallel_mode_from_env(WEIGHT_GROUP_3D) + output_parallel_mode = get_parallel_mode_from_env(OUTPUT_GROUP_3D) + + j = global_context.get_local_rank(input_parallel_mode) + i = global_context.get_local_rank(weight_parallel_mode) + k = global_context.get_local_rank(output_parallel_mode) + + layer = VocabParallelEmbedding3D(VOCAB_SIZE, HIDDEN_SIZE) + layer = layer.to(dtype).to(device) + layer_master = torch.nn.Embedding(VOCAB_SIZE, HIDDEN_SIZE) + layer_master = layer_master.to(dtype).to(device) + + weight_master = layer_master.weight.data + torch.distributed.broadcast(weight_master, src=0) + weight = torch.chunk(weight_master, DEPTH, dim=0)[j] + weight = torch.chunk(weight, DEPTH, dim=-1)[k] + layer.weight.data.copy_(weight) + + A_shape = (BATCH_SIZE, SEQ_LENGTH) + A_master = torch.randint(VOCAB_SIZE, A_shape, device=device) + 
torch.distributed.broadcast(A_master, src=0) + A = A_master.clone() + + fwd_start = time.time() + out = layer(A) + torch.cuda.synchronize() + fwd_end = time.time() + logger.info('vocab parallel embed forward: pass | {0} --> {1} | {2:.3f} s'.format( + tuple(A.shape), tuple(out.shape), fwd_end - fwd_start), + ranks=[0]) + + A_master = A_master.clone() + C_master = layer_master(A_master) + C = torch.chunk(C_master, DEPTH, dim=0)[i] + C = torch.chunk(C, DEPTH, dim=-1)[k] + C = torch.chunk(C, DEPTH, dim=0)[j] + logger.info('Rank {} vocab parallel embed forward: {}'.format(rank, check_equal(out, C))) + + grad_shape = C_master.shape + grad_master = torch.randn(grad_shape, dtype=dtype, device=device) + torch.distributed.broadcast(grad_master, src=0) + grad = torch.chunk(grad_master, DEPTH, dim=0)[i] + grad = torch.chunk(grad, DEPTH, dim=-1)[k] + grad = torch.chunk(grad, DEPTH, dim=0)[j] + grad = grad.clone() + bwd_start = time.time() + out.backward(grad) + torch.cuda.synchronize() + bwd_end = time.time() + logger.info('vocab parallel embed backward: pass | {:.3f} s'.format(bwd_end - bwd_start), ranks=[0]) + + grad_master = grad_master.clone() + C_master.backward(grad_master) B_grad = layer_master.weight.grad - B_grad = torch.chunk(B_grad, DEPTH, dim=0)[k] - if j == k: - logger.info('Rank {} embed backward (proj_weight_grad): {}'.format(rank, check_equal(B_grad, + B_grad = torch.chunk(B_grad, DEPTH, dim=0)[j] + B_grad = torch.chunk(B_grad, DEPTH, dim=-1)[k] + logger.info('Rank {} vocab parallel embed backward (weight_grad): {}'.format(rank, + check_equal(B_grad, layer.weight.grad))) - else: - logger.info('Rank {} embed backward (proj_weight_grad): {}'.format(rank, layer.weight.grad is None)) - - bias_grad = layer_master.bias.grad - bias_grad = torch.chunk(bias_grad, DEPTH)[k] - logger.info('Rank {} embed backward (proj_bias_grad): {}'.format(rank, check_equal(bias_grad, layer.bias.grad))) return fwd_end - fwd_start, bwd_end - bwd_start @@ -375,11 +767,9 @@ def check_loss(): input_parallel_mode = get_parallel_mode_from_env(INPUT_GROUP_3D) weight_parallel_mode = get_parallel_mode_from_env(WEIGHT_GROUP_3D) - output_parallel_mode = get_parallel_mode_from_env(OUTPUT_GROUP_3D) - j = A_rank = global_context.get_local_rank(input_parallel_mode) - i = B_rank = global_context.get_local_rank(weight_parallel_mode) - k = C_rank = global_context.get_local_rank(output_parallel_mode) + j = global_context.get_local_rank(input_parallel_mode) + i = global_context.get_local_rank(weight_parallel_mode) criterion = CrossEntropyLoss3D() criterion_master = torch.nn.CrossEntropyLoss() @@ -397,24 +787,79 @@ def check_loss(): fwd_start = time.time() loss = criterion(out, target_master) fwd_end = time.time() - print_rank_0( - 'loss forward: pass | {0} --> {1} | {2:.3f} s'.format(tuple(out.shape), tuple(loss.shape), fwd_end - fwd_start), - logger) + logger.info('cross entropy loss forward: pass | {0} --> {1} | {2:.3f} s'.format(tuple(out.shape), tuple(loss.shape), + fwd_end - fwd_start), + ranks=[0]) out_master = out_master.clone() out_master.requires_grad = True loss_master = criterion_master(out_master, target_master) - logger.info('Rank {} CrossEntropyLoss forward: {}'.format(rank, check_equal(loss, loss_master))) + logger.info('Rank {} cross entropy loss forward: {}'.format(rank, check_equal(loss, loss_master))) bwd_start = time.time() loss.backward() bwd_end = time.time() - print_rank_0('loss backward: pass | {:.3f} s'.format(bwd_end - bwd_start), logger) + logger.info('cross entropy loss backward: pass | {:.3f} 
s'.format(bwd_end - bwd_start), ranks=[0]) loss_master.backward() out_grad = out_master.grad out_grad = torch.chunk(out_grad, DEPTH, dim=0)[i] out_grad = torch.chunk(out_grad, DEPTH, dim=0)[j] - logger.info('Rank {} CrossEntropyLoss backward: {}'.format(rank, check_equal(out_grad, out.grad))) + logger.info('Rank {} cross entropy loss backward: {}'.format(rank, check_equal(out_grad, out.grad))) + + return fwd_end - fwd_start, bwd_end - bwd_start + + +def check_vocab_parallel_loss(): + rank = torch.distributed.get_rank() + logger = get_dist_logger() + device = get_current_device() + dtype = torch.float32 + + input_parallel_mode = get_parallel_mode_from_env(INPUT_GROUP_3D) + weight_parallel_mode = get_parallel_mode_from_env(WEIGHT_GROUP_3D) + output_parallel_mode = get_parallel_mode_from_env(OUTPUT_GROUP_3D) + + j = global_context.get_local_rank(input_parallel_mode) + i = global_context.get_local_rank(weight_parallel_mode) + k = global_context.get_local_rank(output_parallel_mode) + + criterion = VocabParallelCrossEntropyLoss3D() + criterion_master = torch.nn.CrossEntropyLoss() + + out_shape = (BATCH_SIZE, NUM_CLASSES) + out_master = torch.randn(out_shape, dtype=dtype, device=device) + target_master = torch.randint(NUM_CLASSES, (BATCH_SIZE, ), dtype=torch.long, device=device) + torch.distributed.broadcast(out_master, src=0) + torch.distributed.broadcast(target_master, src=0) + out = torch.chunk(out_master, DEPTH, dim=0)[i] + out = torch.chunk(out, DEPTH, dim=-1)[k] + out = torch.chunk(out, DEPTH, dim=0)[j] + out = out.clone() + out.requires_grad = True + + fwd_start = time.time() + loss = criterion(out, target_master) + fwd_end = time.time() + logger.info('vocab parallel cross entropy loss forward: pass | {0} --> {1} | {2:.3f} s'.format( + tuple(out.shape), tuple(loss.shape), fwd_end - fwd_start), + ranks=[0]) + + out_master = out_master.clone() + out_master.requires_grad = True + loss_master = criterion_master(out_master, target_master) + logger.info('Rank {} vocab parallel cross entropy loss forward: {}'.format(rank, check_equal(loss, loss_master))) + + bwd_start = time.time() + loss.backward() + bwd_end = time.time() + logger.info('vocab parallel cross entropy loss backward: pass | {:.3f} s'.format(bwd_end - bwd_start), ranks=[0]) + + loss_master.backward() + out_grad = out_master.grad + out_grad = torch.chunk(out_grad, DEPTH, dim=0)[i] + out_grad = torch.chunk(out_grad, DEPTH, dim=-1)[k] + out_grad = torch.chunk(out_grad, DEPTH, dim=0)[j] + logger.info('Rank {} vocab parallel cross entropy loss backward: {}'.format(rank, check_equal(out_grad, out.grad))) return fwd_end - fwd_start, bwd_end - bwd_start diff --git a/tests/test_layers/test_3d/checks_3d/common.py b/tests/test_layers/test_3d/checks_3d/common.py index a7c6b8678..43a04f649 100644 --- a/tests/test_layers/test_3d/checks_3d/common.py +++ b/tests/test_layers/test_3d/checks_3d/common.py @@ -10,6 +10,7 @@ HIDDEN_SIZE = 8 NUM_CLASSES = 8 NUM_BLOCKS = 2 IMG_SIZE = 16 +VOCAB_SIZE = 16 def check_equal(A, B): eq = torch.allclose(A, B, rtol=1e-3, atol=1e-2) diff --git a/tests/test_layers/test_3d/test_3d.py b/tests/test_layers/test_3d/test_3d.py index 73bdbb5bd..c6803ab2b 100644 --- a/tests/test_layers/test_3d/test_3d.py +++ b/tests/test_layers/test_3d/test_3d.py @@ -7,9 +7,14 @@ import torch import torch.multiprocessing as mp from colossalai.core import global_context as gpc from colossalai.initialize import launch +from colossalai.logging import disable_existing_loggers from colossalai.utils import free_port -from checks_3d.check_layer_3d 
import * +from checks_3d.check_layer_3d import (check_classifier_given_embed_weight, check_classifier_no_given_weight, + check_embed, check_layernorm, check_linear, check_loss, check_patch_embed, + check_vocab_parallel_classifier_given_embed_weight, + check_vocab_parallel_classifier_no_given_weight, check_vocab_parallel_embed, + check_vocab_parallel_loss) CONFIG = dict( parallel=dict( @@ -23,13 +28,23 @@ CONFIG = dict( def check_layer(): check_linear() check_layernorm() - check_classifier() - # check_embed() - # check_loss() + check_classifier_no_given_weight() + check_vocab_parallel_classifier_no_given_weight() + check_classifier_given_embed_weight() + check_vocab_parallel_classifier_given_embed_weight() + check_embed() + check_patch_embed() + check_vocab_parallel_embed() + check_loss() + check_vocab_parallel_loss() def check_layer_and_operation(rank, world_size, port): + disable_existing_loggers() launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + torch.backends.cuda.matmul.allow_tf32 = False + torch.backends.cudnn.allow_tf32 = False + torch.backends.cudnn.deterministic = True check_layer() gpc.destroy() torch.cuda.empty_cache()
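
Note: the new checks across the 2D, 2.5D and 3D suites all follow the same verification pattern: broadcast a "master" tensor from rank 0, shard it per local rank with torch.chunk, run the parallel layer on the shard, and compare the forward output and backward gradients against the unsharded master computation with torch.allclose (see check_equal in checks_*/common.py). The snippet below is a minimal, single-process sketch of that pattern, assuming a hypothetical row-parallel linear split across WORLD_SIZE simulated ranks; it uses plain PyTorch only, and the names and shapes are illustrative rather than taken from the Colossal-AI API.

# Standalone sketch: shard a master matmul, recombine the partial results, and
# compare against the unsharded reference with the tolerances the checks use.
import torch

BATCH_SIZE, IN_FEATURES, OUT_FEATURES, WORLD_SIZE = 8, 16, 16, 4  # illustrative sizes

torch.manual_seed(0)
x = torch.randn(BATCH_SIZE, IN_FEATURES)
weight = torch.randn(OUT_FEATURES, IN_FEATURES)

# Reference ("master") result computed without any parallelism.
out_master = x @ weight.t()

# Row-parallel style sharding: each simulated rank holds a slice of the input
# features and the matching slice of the weight; summing the partial products
# plays the role that an all-reduce plays in a real distributed layer.
x_shards = torch.chunk(x, WORLD_SIZE, dim=-1)
w_shards = torch.chunk(weight, WORLD_SIZE, dim=-1)
out = sum(xs @ ws.t() for xs, ws in zip(x_shards, w_shards))

# Same comparison the checks rely on.
assert torch.allclose(out, out_master, rtol=1e-3, atol=1e-2)
print('sharded linear matches master: pass')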