moved env variables to global variables (#215)

added branch context;
added vocab parallel layers;
moved split_batch from load_batch to tensor parallel embedding layers;
updated gpt model;
updated unit test cases;
fixed a few collective communicator bugs
pull/232/head
アマデウス 2022-02-14 11:15:02 +08:00 committed by Frank Lee
parent b82d60be02
commit 9ee197d0e9
63 changed files with 4304 additions and 1040 deletions

.gitignore
View File

@@ -137,8 +137,4 @@ dmypy.json
 .DS_Store
 #data/
-# launcher setting
-tests/launcher/log
-tests/launcher/personal
 docs/.build

View File

@@ -5,7 +5,7 @@ repos:
   - id: yapf
     args: ['--style=google', '--parallel', '--in-place']
 - repo: https://github.com/pycqa/flake8
-  rev: ''
+  rev: '4.0.1'
   hooks:
   - id: flake8
 - repo: https://github.com/pre-commit/mirrors-clang-format

View File

@@ -4,8 +4,9 @@
 import torch.nn as nn

 try:
     import apex.amp as apex_amp
-except:
-    pass
+except ImportError:
+    raise ImportError('Cannot import apex.amp correctly.')

 from torch import Tensor
 from colossalai.nn.optimizer import ColossalaiOptimizer

View File

@@ -30,7 +30,7 @@ def all_gather(tensor: Tensor, dim: int, parallel_mode: ParallelMode, async_op:
     """
     depth = gpc.get_world_size(parallel_mode)
     if depth == 1:
-        out = [tensor]
+        out = tensor
         work = None
     else:
         shape = list(tensor.shape)
@@ -96,34 +96,40 @@ def all_reduce(tensor: Tensor,
                async_op: bool = False) -> Tensor:
     depth = gpc.get_world_size(parallel_mode)
     if depth == 1:
+        out = tensor
         work = None
     else:
-        work = dist.all_reduce(tensor.contiguous(), op=op, group=gpc.get_group(parallel_mode), async_op=async_op)
+        out = tensor.contiguous()
+        work = dist.all_reduce(out, op=op, group=gpc.get_group(parallel_mode), async_op=async_op)
     if async_op:
-        return tensor, work
+        return out, work
     else:
-        return tensor
+        return out


 def broadcast(tensor: Tensor, src: int, parallel_mode: ParallelMode, async_op: bool = False):
     depth = gpc.get_world_size(parallel_mode)
     if depth == 1:
+        out = tensor
         work = None
     else:
-        work = dist.broadcast(tensor.contiguous(), src=src, group=gpc.get_group(parallel_mode), async_op=async_op)
+        out = tensor.contiguous()
+        work = dist.broadcast(out, src=src, group=gpc.get_group(parallel_mode), async_op=async_op)
     if async_op:
-        return tensor, work
+        return out, work
     else:
-        return tensor
+        return out


 def reduce(tensor: Tensor, dst: int, parallel_mode: ParallelMode, op: ReduceOp = ReduceOp.SUM, async_op: bool = False):
     depth = gpc.get_world_size(parallel_mode)
     if depth == 1:
+        out = tensor
         work = None
     else:
-        work = dist.reduce(tensor.contiguous(), dst=dst, op=op, group=gpc.get_group(parallel_mode), async_op=async_op)
+        out = tensor.contiguous()
+        work = dist.reduce(out, dst=dst, op=op, group=gpc.get_group(parallel_mode), async_op=async_op)
     if async_op:
-        return tensor, work
+        return out, work
     else:
-        return tensor
+        return out
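
Usage note (not part of the diff): after this change the wrappers assign the contiguous communication buffer to out and return it in every branch, so callers always receive the reduced/broadcast result, plus a (tensor, work) pair when async_op=True. A minimal sketch, assuming colossalai.launch has already set up the process groups and that all_reduce is exported from colossalai.communication the same way broadcast is imported elsewhere in this commit:

import torch
from torch.distributed import ReduceOp
from colossalai.communication import all_reduce
from colossalai.context import ParallelMode
from colossalai.utils import get_current_device

x = torch.ones(4, device=get_current_device())

# Synchronous: the reduced tensor is returned even when the group world size is 1.
x = all_reduce(x, parallel_mode=ParallelMode.TENSOR, op=ReduceOp.SUM)

# Asynchronous: a (tensor, work) pair is returned; work is None for a size-1 group.
x, work = all_reduce(x, parallel_mode=ParallelMode.TENSOR, op=ReduceOp.SUM, async_op=True)
if work is not None:
    work.wait()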

View File

@@ -19,23 +19,12 @@ INITIALIZER_MAPPING = {
     'moe': 'Initializer_Moe'
 }

-# 1D parallel
-PARALLEL_INPUT_1D = 'parallel_input_1d'
-
-# 2D paralllel
-SUMMA_DIM = 'SUMMA_DIM'
-
-# 2.5D paralllel
-TESSERACT_DIM = 'TESSERACT_DIM'
-TESSERACT_DEP = 'TESSERACT_DEP'
-
-# 3D parallel
-DEPTH_3D = 'DEPTH_3D'
-INPUT_GROUP_3D = 'PARALLEL_3D_INPUT'
-WEIGHT_GROUP_3D = 'PARALLEL_3D_WEIGHT'
-OUTPUT_GROUP_3D = 'PARALLEL_3D_OUTPUT'
-
-# Tensor parallel attributes
+# 3D parallelism groups
+INPUT_GROUP_3D = 'input_group_3d'
+WEIGHT_GROUP_3D = 'weight_group_3d'
+OUTPUT_GROUP_3D = 'output_group_3d'
+
+# Attributes of tensor parallel parameters
 IS_TENSOR_PARALLEL = 'is_tensor_parallel'
 NUM_PARTITIONS = 'num_partitions'
 TENSOR_PARALLEL_ATTRIBUTES = [IS_TENSOR_PARALLEL, NUM_PARTITIONS]

View File

@@ -8,14 +8,15 @@ from typing import Union
 import numpy as np
 import torch
 import torch.distributed as dist
-from colossalai.constants import ALLOWED_MODES, INITIALIZER_MAPPING, TENSOR_PARALLEL_MODE
+from colossalai.constants import ALLOWED_MODES, INITIALIZER_MAPPING
 from colossalai.context.config import Config
+from colossalai.global_variables import moe_env
+from colossalai.global_variables import tensor_parallel_env as env
 from colossalai.logging import get_dist_logger
 from colossalai.registry import DIST_GROUP_INITIALIZER

 from .parallel_mode import ParallelMode
 from .random import add_seed, get_seeds, set_mode
-from colossalai.global_variables import moe_env


 class ParallelContext:
@@ -307,7 +308,6 @@ class ParallelContext:
             port: int
     ):
         """Initializes the global distributed environment
-
         :param rank: rank for the default process group
         :type rank: int
         :param world_size: world size of the default process group
@@ -389,7 +389,8 @@ class ParallelContext:
         if parallel_config is not None and 'tensor' in parallel_config and 'mode' in parallel_config['tensor']:
             tensor_parallel_mode = parallel_config['tensor']['mode']
         assert tensor_parallel_mode in ALLOWED_MODES, f"mode in the parallel config must be set to one of {ALLOWED_MODES}"
-        os.environ[TENSOR_PARALLEL_MODE] = str(tensor_parallel_mode)
+        env.mode = tensor_parallel_mode
+
         self.check_sanity()

         pg_init = []

View File

@@ -1,22 +1,18 @@
 #!/usr/bin/env python
 # -*- encoding: utf-8 -*-

-import os
 import torch.distributed as dist
-from colossalai.context import Config
+from colossalai.global_variables import tensor_parallel_env as env
 from colossalai.registry import DIST_GROUP_INITIALIZER
-from .process_group_initializer import ProcessGroupInitializer
 from ..parallel_mode import ParallelMode
-from colossalai.constants import PARALLEL_INPUT_1D
+from .process_group_initializer import ProcessGroupInitializer


 @DIST_GROUP_INITIALIZER.register_module
 class Initializer_1D(ProcessGroupInitializer):
-    """A ProcessGroupInitializer for 1d tensor parallelism.
-
-    :param args: Args used to initialize ProcessGroupInitializer
-    :param kwargs: Kwargs used to initialize ProcessGroupInitializer
-    """
+    '''A ProcessGroupInitializer for 1d tensor parallelism.
+    '''

     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)

@@ -24,7 +20,7 @@ class Initializer_1D(ProcessGroupInitializer):
     def init_dist_group(self):
         """Initialize 1D tensor parallel groups, and assign local_ranks and groups to each gpu.
         :return: (local_rank, group_world_size, process_group, ranks_in_group, mode)
         :rtype: Tuple
         """

@@ -33,7 +29,7 @@ class Initializer_1D(ProcessGroupInitializer):
         process_group = None
         group_world_size = None
         mode = ParallelMode.PARALLEL_1D
-        os.environ[PARALLEL_INPUT_1D] = ''
+        env.parallel_input_1d = False

         for i in range(self.num_group):
             ranks = [i * self.tensor_parallel_size + j for j in range(self.tensor_parallel_size)]

View File

@ -1,34 +1,31 @@
import math import math
import os
import torch.distributed as dist import torch.distributed as dist
from colossalai.constants import SUMMA_DIM
from colossalai.registry import DIST_GROUP_INITIALIZER from colossalai.registry import DIST_GROUP_INITIALIZER
from .process_group_initializer import ProcessGroupInitializer from .process_group_initializer import ProcessGroupInitializer
from ..parallel_mode import ParallelMode from ..parallel_mode import ParallelMode
from colossalai.global_variables import tensor_parallel_env as env
def _check_summa_env_var(summa_dim): def _check_summa_env_var(summa_dim):
# check environment variable for SUMMA # check environment variable for SUMMA
env_summa_dim = os.environ.get(SUMMA_DIM, None) env_summa_dim = env.summa_dim
if env_summa_dim: if env_summa_dim:
assert int(env_summa_dim) == summa_dim, \ assert int(env_summa_dim) == summa_dim, \
'SUMMA_DIM has been set in the current environment and ' \ 'SUMMA_DIM has been set in the current environment and ' \
'does not match with the value passed to this initialized' 'does not match with the value passed to this initialized'
else: else:
os.environ[SUMMA_DIM] = str(summa_dim) env.summa_dim = summa_dim
class Initializer_2D_Row(ProcessGroupInitializer): class Initializer_2D_Row(ProcessGroupInitializer):
"""2d tensor parallel initialization among rows. """2d tensor parallel initialization among rows.
:param num_group: The number of all tensor groups :param num_group: The number of all tensor groups
:param summa_dim: The dimension of SUMMA :param summa_dim: The dimension of SUMMA
:param args: Args used to initialize base class :param args: Args used to initialize base class
:param kwargs: Kwargs used to initialize base class :param kwargs: Kwargs used to initialize base class
:type num_group: int :type num_group: int
:type summa_dim: int :type summa_dim: int
""" """
@ -132,7 +129,7 @@ class Initializer_2D(ProcessGroupInitializer):
def init_dist_group(self): def init_dist_group(self):
"""Initialize 2D tensor row and col parallel groups, and assign local_ranks and groups to each gpu. """Initialize 2D tensor row and col parallel groups, and assign local_ranks and groups to each gpu.
:return: 2D tensor parallelism's information :return: 2D tensor parallelism's information
:rtype: list of Tuples (local_rank, group_world_size, process_group, ranks_in_group, mode) :rtype: list of Tuples (local_rank, group_world_size, process_group, ranks_in_group, mode)
""" """

View File

@ -2,22 +2,21 @@
# -*- encoding: utf-8 -*- # -*- encoding: utf-8 -*-
import math import math
import os
import torch.distributed as dist import torch.distributed as dist
from colossalai.constants import TESSERACT_DIM, TESSERACT_DEP
from colossalai.context import Config from colossalai.context import Config
from colossalai.global_variables import tensor_parallel_env as env
from colossalai.registry import DIST_GROUP_INITIALIZER from colossalai.registry import DIST_GROUP_INITIALIZER
from .process_group_initializer import ProcessGroupInitializer
from ..parallel_mode import ParallelMode from ..parallel_mode import ParallelMode
from .process_group_initializer import ProcessGroupInitializer
def _check_tesseract_env_var(tesseract_dim: int, def _check_tesseract_env_var(tesseract_dim: int,
tesseract_dep: int): tesseract_dep: int):
# check environment variable for TESSERACT # check global variable for TESSERACT
env_tesseract_dim = os.environ.get(TESSERACT_DIM, None) env_tesseract_dim = env.tesseract_dim
env_tesseract_dep = os.environ.get(TESSERACT_DEP, None) env_tesseract_dep = env.tesseract_dep
if env_tesseract_dim and env_tesseract_dep: if env_tesseract_dim and env_tesseract_dep:
assert int(env_tesseract_dim) == tesseract_dim, \ assert int(env_tesseract_dim) == tesseract_dim, \
@ -27,8 +26,8 @@ def _check_tesseract_env_var(tesseract_dim: int,
'TESSERACT_DEP has been set in the current environment and ' \ 'TESSERACT_DEP has been set in the current environment and ' \
'does not match with the value passed to this initialized' 'does not match with the value passed to this initialized'
else: else:
os.environ[TESSERACT_DIM] = str(tesseract_dim) env.tesseract_dim = tesseract_dim
os.environ[TESSERACT_DEP] = str(tesseract_dep) env.tesseract_dep = tesseract_dep
# i row j col k dep # i row j col k dep
@ -245,7 +244,6 @@ class Initializer_2p5D(ProcessGroupInitializer):
:param pipeline_parallel_size: Size of pipeline parallel :param pipeline_parallel_size: Size of pipeline parallel
:param tensor_parallel_size: Size of tensor parallel :param tensor_parallel_size: Size of tensor parallel
:param depth: The depth of 2p5d parallel :param depth: The depth of 2p5d parallel
:type rank: int :type rank: int
:type world_size: int :type world_size: int
:type config: Config :type config: Config
@ -281,7 +279,7 @@ class Initializer_2p5D(ProcessGroupInitializer):
def init_dist_group(self): def init_dist_group(self):
"""Initialize 2p5D tensor row, col, depth, and colXdepth parallel groups, and assign local_ranks and groups to each gpu. """Initialize 2p5D tensor row, col, depth, and colXdepth parallel groups, and assign local_ranks and groups to each gpu.
:return: Whole 2p5D tensor parallelism's information :return: Whole 2p5D tensor parallelism's information
:rtype: list of Tuples (local_rank, group_world_size, process_group, ranks_in_group, mode) :rtype: list of Tuples (local_rank, group_world_size, process_group, ranks_in_group, mode)
""" """

View File

@ -2,10 +2,9 @@
# -*- encoding: utf-8 -*- # -*- encoding: utf-8 -*-
import math import math
import os
import torch.distributed as dist import torch.distributed as dist
from colossalai.constants import DEPTH_3D, INPUT_GROUP_3D, WEIGHT_GROUP_3D, OUTPUT_GROUP_3D from colossalai.global_variables import tensor_parallel_env as env
from colossalai.registry import DIST_GROUP_INITIALIZER from colossalai.registry import DIST_GROUP_INITIALIZER
from ..parallel_mode import ParallelMode from ..parallel_mode import ParallelMode
@ -13,15 +12,15 @@ from .process_group_initializer import ProcessGroupInitializer
def _check_depth_env_var(depth): def _check_depth_env_var(depth):
# check environment variable for SUMMA # check global variable
env_depth = os.environ.get(DEPTH_3D, None) env_depth = env.depth_3d
if env_depth: if env_depth:
assert int(env_depth) == depth, \ assert int(env_depth) == depth, \
'DEPTH_3D has been set in the current environment and ' \ 'DEPTH_3D has been set in the current environment and ' \
'does not match with the value passed to this initialized' 'does not match with the value passed to this initialized'
else: else:
os.environ[DEPTH_3D] = str(depth) env.depth_3d = depth
class Initializer_3D_Input(ProcessGroupInitializer): class Initializer_3D_Input(ProcessGroupInitializer):
@ -34,6 +33,7 @@ class Initializer_3D_Input(ProcessGroupInitializer):
:type num_group: int :type num_group: int
:type depth: int :type depth: int
""" """
def __init__(self, num_group: int, depth: int, *args): def __init__(self, num_group: int, depth: int, *args):
super().__init__(*args) super().__init__(*args)
self.num_group = num_group self.num_group = num_group
@ -50,15 +50,12 @@ class Initializer_3D_Input(ProcessGroupInitializer):
process_group = None process_group = None
group_world_size = None group_world_size = None
mode = ParallelMode.PARALLEL_3D_INPUT mode = ParallelMode.PARALLEL_3D_INPUT
os.environ[INPUT_GROUP_3D] = INPUT_GROUP_3D env.input_group_3d = mode
for h in range(self.num_group): for h in range(self.num_group):
for i in range(self.depth): for i in range(self.depth):
for k in range(self.depth): for k in range(self.depth):
ranks = [ ranks = [h * self.depth**3 + i + self.depth * (j + self.depth * k) for j in range(self.depth)]
h * self.depth**3 + i + self.depth *
(j + self.depth * k) for j in range(self.depth)
]
group = dist.new_group(ranks) group = dist.new_group(ranks)
if self.rank in ranks: if self.rank in ranks:
@ -97,15 +94,12 @@ class Initializer_3D_Weight(ProcessGroupInitializer):
process_group = None process_group = None
group_world_size = None group_world_size = None
mode = ParallelMode.PARALLEL_3D_WEIGHT mode = ParallelMode.PARALLEL_3D_WEIGHT
os.environ[WEIGHT_GROUP_3D] = WEIGHT_GROUP_3D env.weight_group_3d = mode
for h in range(self.num_group): for h in range(self.num_group):
for k in range(self.depth): for k in range(self.depth):
for j in range(self.depth): for j in range(self.depth):
ranks = [ ranks = [h * self.depth**3 + i + self.depth * (j + self.depth * k) for i in range(self.depth)]
h * self.depth**3 + i + self.depth *
(j + self.depth * k) for i in range(self.depth)
]
group = dist.new_group(ranks) group = dist.new_group(ranks)
if self.rank in ranks: if self.rank in ranks:
@ -118,7 +112,7 @@ class Initializer_3D_Weight(ProcessGroupInitializer):
class Initializer_3D_Output(ProcessGroupInitializer): class Initializer_3D_Output(ProcessGroupInitializer):
"""3D tensor parallel initialization among weight. """3D tensor parallel initialization among output.
:param num_group: The number of all tensor groups :param num_group: The number of all tensor groups
:param depth: Depth of 3D parallelism :param depth: Depth of 3D parallelism
@ -144,15 +138,12 @@ class Initializer_3D_Output(ProcessGroupInitializer):
process_group = None process_group = None
group_world_size = None group_world_size = None
mode = ParallelMode.PARALLEL_3D_OUTPUT mode = ParallelMode.PARALLEL_3D_OUTPUT
os.environ[OUTPUT_GROUP_3D] = OUTPUT_GROUP_3D env.output_group_3d = mode
for h in range(self.num_group): for h in range(self.num_group):
for i in range(self.depth): for i in range(self.depth):
for j in range(self.depth): for j in range(self.depth):
ranks = [ ranks = [h * self.depth**3 + i + self.depth * (j + self.depth * k) for k in range(self.depth)]
h * self.depth**3 + i + self.depth *
(j + self.depth * k) for k in range(self.depth)
]
group = dist.new_group(ranks) group = dist.new_group(ranks)
if self.rank in ranks: if self.rank in ranks:
@ -170,6 +161,7 @@ class Initializer_3D(ProcessGroupInitializer):
:param args: Args used to initialize ProcessGroupInitializer :param args: Args used to initialize ProcessGroupInitializer
""" """
def __init__(self, *args): def __init__(self, *args):
super().__init__(*args) super().__init__(*args)
self.num_group = self.world_size // self.tensor_parallel_size self.num_group = self.world_size // self.tensor_parallel_size
@ -178,16 +170,13 @@ class Initializer_3D(ProcessGroupInitializer):
f'3D depth ({self.depth}) if not cube root of tensor parallel size ({self.tensor_parallel_size})' f'3D depth ({self.depth}) if not cube root of tensor parallel size ({self.tensor_parallel_size})'
_check_depth_env_var(self.depth) _check_depth_env_var(self.depth)
self.input_initializer = Initializer_3D_Input(self.num_group, self.input_initializer = Initializer_3D_Input(self.num_group, self.depth, *args)
self.depth, *args) self.weight_initializer = Initializer_3D_Weight(self.num_group, self.depth, *args)
self.weight_initializer = Initializer_3D_Weight( self.output_initializer = Initializer_3D_Output(self.num_group, self.depth, *args)
self.num_group, self.depth, *args)
self.output_initializer = Initializer_3D_Output(
self.num_group, self.depth, *args)
def init_dist_group(self): def init_dist_group(self):
"""Initialize 3D tensor parallel groups, and assign local_ranks and groups to each gpu. """Initialize 3D tensor parallel groups, and assign local_ranks and groups to each gpu.
:return: 3D tensor parallelism's information :return: 3D tensor parallelism's information
:rtype: list of Tuples (local_rank, group_world_size, process_group, ranks_in_group, mode) :rtype: list of Tuples (local_rank, group_world_size, process_group, ranks_in_group, mode)
""" """

View File

@@ -9,4 +9,4 @@ from ._sequence_parallel_gradient_handler import SequenceParallelGradientHandler
 __all__ = ['BaseGradientHandler', 'DataParallelGradientHandler',
            'ZeROGradientHandler', 'PipelineSharedModuleGradientHandler',
            'MoeGradientHandler', 'SequenceParallelGradientHandler']

View File

@@ -9,7 +9,6 @@ from typing import Iterable, Callable
 from .._base_engine import Engine
 from colossalai.logging import get_dist_logger
 from colossalai.utils import get_current_device
-from colossalai.nn.layer import split_batch


 class BaseSchedule(ABC):
@@ -69,7 +68,6 @@ class BaseSchedule(ABC):
             self.batch_size = data.size(0)
         else:
             self.batch_size = next(iter(data.values())).size(0)
-        data, label = split_batch(data), split_batch(label)
         if to_gpu:
             return self._move_to_device(data), self._move_to_device(label)
         return data, label

View File

@@ -1,3 +1,51 @@
+from typing import Optional
+
+
+class TensorParallelEnv(object):
+    _instance = None
+
+    def __new__(cls, *args, **kwargs):
+        if cls._instance is None:
+            cls._instance = object.__new__(cls, *args, **kwargs)
+        return cls._instance
+
+    def __init__(self, *args, **kwargs):
+        self.load(*args, **kwargs)
+
+    def load(self,
+             mode: Optional[str] = None,
+             vocab_parallel: bool = False,
+             parallel_input_1d: bool = False,
+             summa_dim: int = None,
+             tesseract_dim: int = None,
+             tesseract_dep: int = None,
+             depth_3d: int = None,
+             input_group_3d=None,
+             weight_group_3d=None,
+             output_group_3d=None):
+        self.mode = mode
+        self.vocab_parallel = vocab_parallel
+        self.parallel_input_1d = parallel_input_1d
+        self.summa_dim = summa_dim
+        self.tesseract_dim = tesseract_dim
+        self.tesseract_dep = tesseract_dep
+        self.depth_3d = depth_3d
+        self.input_group_3d = input_group_3d
+        self.weight_group_3d = weight_group_3d
+        self.output_group_3d = output_group_3d
+
+    def save(self):
+        return dict(mode=self.mode,
+                    vocab_parallel=self.vocab_parallel,
+                    parallel_input_1d=self.parallel_input_1d,
+                    summa_dim=self.summa_dim,
+                    tesseract_dim=self.tesseract_dim,
+                    tesseract_dep=self.tesseract_dep,
+                    depth_3d=self.depth_3d,
+                    input_group_3d=self.input_group_3d,
+                    weight_group_3d=self.weight_group_3d,
+                    output_group_3d=self.output_group_3d)
+
+
 class MoeEnv:

@@ -33,4 +81,6 @@ class MoeEnv:
         return self.aux_loss


+tensor_parallel_env = TensorParallelEnv()
+
 moe_env = MoeEnv()
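
Usage note (not part of the diff): the new TensorParallelEnv singleton replaces the string-typed os.environ keys that this commit removes from colossalai.constants. A short sketch of the intended usage, based on how the process group initializers above write to it:

from colossalai.global_variables import tensor_parallel_env as env

# Initializers store settings here instead of in os.environ:
env.mode = '2.5d'          # formerly os.environ[TENSOR_PARALLEL_MODE]
env.tesseract_dim = 2      # formerly os.environ[TESSERACT_DIM], kept as an int rather than a string
env.tesseract_dep = 2      # formerly os.environ[TESSERACT_DEP]

# The whole tensor-parallel state can be captured as a plain dict.
snapshot = env.save()
assert snapshot['tesseract_dim'] == 2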

View File

@ -37,17 +37,17 @@ class FusedLayerNormAffineFunction(torch.autograd.Function):
input_, weight_, bias_, mean, invvar = ctx.saved_tensors input_, weight_, bias_, mean, invvar = ctx.saved_tensors
grad_input = grad_weight = grad_bias = None grad_input = grad_weight = grad_bias = None
grad_input, grad_weight, grad_bias \ grad_input, grad_weight, grad_bias \
= colossal_layer_norm_cuda.backward_affine( = colossal_layer_norm_cuda.backward_affine(
grad_output.contiguous(), mean, invvar, grad_output.contiguous(), mean, invvar,
input_, ctx.normalized_shape, input_, ctx.normalized_shape,
weight_, bias_, ctx.eps) weight_, bias_, ctx.eps)
return grad_input, grad_weight, grad_bias, None, None return grad_input, grad_weight, grad_bias, None, None
class MixedFusedLayerNorm(torch.nn.Module): class MixedFusedLayerNorm(torch.nn.Module):
def __init__(self, normalized_shape, eps=1e-5): def __init__(self, normalized_shape, eps=1e-5, device=None, dtype=None):
super(MixedFusedLayerNorm, self).__init__() super(MixedFusedLayerNorm, self).__init__()
global colossal_layer_norm_cuda global colossal_layer_norm_cuda
@ -61,8 +61,8 @@ class MixedFusedLayerNorm(torch.nn.Module):
normalized_shape = (normalized_shape,) normalized_shape = (normalized_shape,)
self.normalized_shape = torch.Size(normalized_shape) self.normalized_shape = torch.Size(normalized_shape)
self.eps = eps self.eps = eps
self.weight = Parameter(torch.Tensor(*normalized_shape)) self.weight = Parameter(torch.empty(*normalized_shape, device=device, dtype=dtype))
self.bias = Parameter(torch.Tensor(*normalized_shape)) self.bias = Parameter(torch.empty(*normalized_shape, device=device, dtype=dtype))
self.reset_parameters() self.reset_parameters()
def reset_parameters(self): def reset_parameters(self):
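
Side note on the parameter change above: torch.Tensor(*shape) always allocates an uninitialized fp32 tensor on the CPU, while torch.empty accepts device and dtype, so the weight and bias can be created directly where they will live. A tiny illustration (not part of the diff):

import torch

shape = (1024,)
device = 'cuda' if torch.cuda.is_available() else 'cpu'
w_old = torch.Tensor(*shape)                                     # always fp32 on CPU, moved/cast later
w_new = torch.empty(*shape, device=device, dtype=torch.float16)  # allocated with the requested device/dtype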

View File

@ -1,7 +1,7 @@
from ._utils import split_batch from ._utils import partition_batch
from .dropout import Dropout from .dropout import Dropout
from .embedding import Embedding, PatchEmbedding from .embedding import Embedding, PatchEmbedding
from .linear import Classifier, Linear from .linear import Classifier, Linear
from .normalization import LayerNorm from .normalization import LayerNorm
__all__ = ['Linear', 'Classifier', 'Embedding', 'PatchEmbedding', 'LayerNorm', 'Dropout', 'split_batch'] __all__ = ['Linear', 'Classifier', 'Embedding', 'PatchEmbedding', 'LayerNorm', 'Dropout', 'partition_batch']

View File

@@ -2,13 +2,13 @@ from torch import Tensor

 from ..parallel_2d._operation import split_tensor_2d
 from ..parallel_2p5d._operation import split_tensor_2p5d
-from ..parallel_3d._operation import split_tensor_3d
+from ..parallel_3d._operation import split_batch_3d
 from ..utils import get_tensor_parallel_mode

-_parallel_split_batch = {'2d': split_tensor_2d, '2.5d': split_tensor_2p5d, '3d': split_tensor_3d}
+_parallel_split_batch = {'2d': split_tensor_2d, '2.5d': split_tensor_2p5d, '3d': split_batch_3d}


-def split_batch(input_) -> Tensor:
+def partition_batch(input_) -> Tensor:
     tensor_parallel_mode = get_tensor_parallel_mode()
     if tensor_parallel_mode in _parallel_split_batch:
         if isinstance(input_, dict):
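
Per the commit message, batch splitting now happens inside the tensor parallel embedding layers rather than in the schedule's load_batch, and this helper was renamed from split_batch to partition_batch. A hedged usage sketch, assuming colossalai has been launched and that partition_batch is re-exported from colossalai.nn.layer the way split_batch used to be:

import torch
from colossalai.nn.layer import partition_batch

# Hypothetical global batch; under '2d', '2.5d' or '3d' tensor parallelism each rank keeps
# only its slice, while under no or '1d' parallelism the tensors pass through unchanged.
batch = {
    'input_ids': torch.randint(0, 50304, (16, 512)),
    'attention_mask': torch.ones(16, 512, dtype=torch.long),
}
local_batch = partition_batch(batch)    # dict inputs are presumably split entry by entry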

View File

@@ -1,8 +1,5 @@
-from contextlib import nullcontext
-
 import torch.nn as nn

 from colossalai.context import ParallelMode, seed
-from colossalai.utils import conditional_context

 from ..parallel_1d import *
 from ..utils import get_tensor_parallel_mode
@@ -26,6 +23,8 @@ class Dropout(nn.Module):
         self.drop = nn.Dropout(p, inplace)

     def forward(self, *args):
-        cm = nullcontext() if self.tensor_parallel in ['None', '1d'] else seed(ParallelMode.TENSOR)
-        with cm:
+        if self.tensor_parallel in [None, '1d']:
             return self.drop(*args)
+        else:
+            with seed(ParallelMode.TENSOR):
+                return self.drop(*args)
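
The rewritten forward makes the control flow explicit: dropout runs under the RNG state registered for ParallelMode.TENSOR only when a 2D/2.5D/3D tensor parallel mode is active. A minimal sketch of the same pattern outside the class, assuming the seed states were added during colossalai initialization:

import torch.nn.functional as F
from colossalai.context import ParallelMode, seed

def dropout_under_tensor_seed(x, p, training, use_tensor_seed):
    # use_tensor_seed mirrors self.tensor_parallel not in [None, '1d'] above
    if use_tensor_seed:
        with seed(ParallelMode.TENSOR):   # draw the mask from the tensor-parallel RNG state
            return F.dropout(x, p, training)
    return F.dropout(x, p, training)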

View File

@ -1,5 +1,5 @@
import math import math
from typing import Callable, Optional from typing import Callable
from colossalai.utils import get_current_device from colossalai.utils import get_current_device
from torch import dtype, nn from torch import dtype, nn
@ -12,10 +12,21 @@ from ..parallel_3d import *
from ..utils import get_tensor_parallel_mode from ..utils import get_tensor_parallel_mode
from ..vanilla import * from ..vanilla import *
_parallel_embedding = {'1d': Embedding1D, '2d': Embedding2D, '2.5d': Embedding2p5D, '3d': Embedding3D} _parallel_embedding = {
'2d': Embedding2D,
'2.5d': Embedding2p5D,
'3d': Embedding3D,
}
_vocab_parallel_embedding = {
'1d': VocabParallelEmbedding1D,
'2d': VocabParallelEmbedding2D,
'2.5d': VocabParallelEmbedding2p5D,
'3d': VocabParallelEmbedding3D
}
_parallel_patchembedding = { _parallel_patchembedding = {
'None': VanillaPatchEmbedding, None: VanillaPatchEmbedding,
'1d': VanillaPatchEmbedding, '1d': VanillaPatchEmbedding,
'2d': PatchEmbedding2D, '2d': PatchEmbedding2D,
'2.5d': PatchEmbedding2p5D, '2.5d': PatchEmbedding2p5D,
@ -40,26 +51,23 @@ class Embedding(nn.Module):
:param args: Args used in F.embedding :param args: Args used in F.embedding
:param kwargs: Kwargs used in F.embedding :param kwargs: Kwargs used in F.embedding
""" """
def __init__(self, def __init__(self,
num_embeddings: int, num_embeddings: int,
embedding_dim: int, embedding_dim: int,
padding_idx: int = None, padding_idx: int = None,
dtype: dtype = None, dtype: dtype = None,
weight_initializer: Callable = init.normal_(), weight_initializer: Callable = init.normal_(),
vocab_parallel_limit: int = 2048,
*args, *args,
**kwargs) -> None: **kwargs) -> None:
super().__init__() super().__init__()
tensor_parallel = get_tensor_parallel_mode() tensor_parallel = get_tensor_parallel_mode()
if tensor_parallel == 'None': if tensor_parallel is None or (tensor_parallel == '1d' and num_embeddings <= vocab_parallel_limit):
self.embed = nn.Embedding(num_embeddings, self.embed = nn.Embedding(num_embeddings, embedding_dim, padding_idx=padding_idx, *args,
embedding_dim, **kwargs).to(dtype).to(get_current_device())
padding_idx=padding_idx,
device=get_current_device(),
dtype=dtype,
*args,
**kwargs)
weight_initializer(self.embed.weight, fan_in=num_embeddings, fan_out=embedding_dim) weight_initializer(self.embed.weight, fan_in=num_embeddings, fan_out=embedding_dim)
else: elif num_embeddings <= vocab_parallel_limit:
self.embed = _parallel_embedding[tensor_parallel]( self.embed = _parallel_embedding[tensor_parallel](
num_embeddings, num_embeddings,
embedding_dim, embedding_dim,
@ -69,6 +77,16 @@ class Embedding(nn.Module):
*args, *args,
**kwargs, **kwargs,
) )
else:
self.embed = _vocab_parallel_embedding[tensor_parallel](
num_embeddings,
embedding_dim,
padding_idx=padding_idx,
dtype=dtype,
weight_initializer=weight_initializer,
*args,
**kwargs,
)
@property @property
def weight(self): def weight(self):
@ -101,16 +119,19 @@ class PatchEmbedding(nn.Module):
:param position_embed_initializer: The intializer of position embedding, defaults to zero :param position_embed_initializer: The intializer of position embedding, defaults to zero
:type position_embed_initializer: typing.Callable, optional :type position_embed_initializer: typing.Callable, optional
""" """
def __init__(self,
img_size: int, def __init__(
patch_size: int, self,
in_chans: int, img_size: int,
embed_size: int, patch_size: int,
dtype: dtype = None, in_chans: int,
flatten: bool = True, embed_size: int,
weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)), dtype: dtype = None,
bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1), flatten: bool = True,
position_embed_initializer: Callable = init.zeros_()) -> None: weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)),
bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1),
position_embed_initializer: Callable = init.zeros_()
) -> None:
super().__init__() super().__init__()
tensor_parallel = get_tensor_parallel_mode() tensor_parallel = get_tensor_parallel_mode()
self.embed = _parallel_patchembedding[tensor_parallel]( self.embed = _parallel_patchembedding[tensor_parallel](
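
To summarize the branching added above (the side-by-side rendering makes it hard to follow): with no tensor parallelism, or with 1D parallelism and a vocabulary no larger than vocab_parallel_limit, a plain torch.nn.Embedding is used; with 2D/2.5D/3D and a small vocabulary the column-parallel Embedding classes are used; once num_embeddings exceeds vocab_parallel_limit the new VocabParallel variants are selected. A condensed, runnable sketch of that decision (class names taken from the diff and returned as strings for clarity; not the verbatim source):

_parallel_embedding = {'2d': 'Embedding2D', '2.5d': 'Embedding2p5D', '3d': 'Embedding3D'}
_vocab_parallel_embedding = {
    '1d': 'VocabParallelEmbedding1D',
    '2d': 'VocabParallelEmbedding2D',
    '2.5d': 'VocabParallelEmbedding2p5D',
    '3d': 'VocabParallelEmbedding3D',
}

def pick_embedding(tensor_parallel, num_embeddings, vocab_parallel_limit=2048):
    # Mirrors the branching in Embedding.__init__ above.
    if tensor_parallel is None or (tensor_parallel == '1d' and num_embeddings <= vocab_parallel_limit):
        return 'torch.nn.Embedding'
    if num_embeddings <= vocab_parallel_limit:
        return _parallel_embedding[tensor_parallel]
    return _vocab_parallel_embedding[tensor_parallel]

assert pick_embedding('1d', 1000) == 'torch.nn.Embedding'
assert pick_embedding('2d', 1000) == 'Embedding2D'
assert pick_embedding('1d', 50304) == 'VocabParallelEmbedding1D'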

View File

@ -1,7 +1,6 @@
import math import math
from typing import Callable, Optional from typing import Callable
from colossalai.nn.layer.parallel_1d.layers import Classifier1D
from colossalai.utils import get_current_device from colossalai.utils import get_current_device
from torch import dtype, nn from torch import dtype, nn
@ -16,13 +15,20 @@ from ..vanilla import *
_parallel_linear = {'1d': Linear1D, '2d': Linear2D, '2.5d': Linear2p5D, '3d': Linear3D} _parallel_linear = {'1d': Linear1D, '2d': Linear2D, '2.5d': Linear2p5D, '3d': Linear3D}
_parallel_classifier = { _parallel_classifier = {
'None': VanillaClassifier, None: VanillaClassifier,
'1d': Classifier1D, '1d': Classifier1D,
'2d': Classifier2D, '2d': Classifier2D,
'2.5d': Classifier2p5D, '2.5d': Classifier2p5D,
'3d': Classifier3D '3d': Classifier3D
} }
_vocab_parallel_classifier = {
'1d': VocabParallelClassifier1D,
'2d': VocabParallelClassifier2D,
'2.5d': VocabParallelClassifier2p5D,
'3d': VocabParallelClassifier3D
}
class Linear(nn.Module): class Linear(nn.Module):
""" """
@ -40,8 +46,9 @@ class Linear(nn.Module):
:type weight_initializer: typing.Callable, optional :type weight_initializer: typing.Callable, optional
:param bias_initializer: The intializer of bias, defaults to xavier uniform initializer :param bias_initializer: The intializer of bias, defaults to xavier uniform initializer
:type bias_initializer: typing.Callable, optional :type bias_initializer: typing.Callable, optional
:param kwargs: Kwargs used for initialization :param kwargs: Kwargs used for particular parallelisms
""" """
def __init__(self, def __init__(self,
in_features: int, in_features: int,
out_features: int, out_features: int,
@ -52,10 +59,10 @@ class Linear(nn.Module):
**kwargs) -> None: **kwargs) -> None:
super().__init__() super().__init__()
tensor_parallel = get_tensor_parallel_mode() tensor_parallel = get_tensor_parallel_mode()
if tensor_parallel == 'None': if tensor_parallel is None:
self.layer = nn.Linear(in_features, out_features, bias=bias, device=get_current_device(), dtype=dtype) self.layer = nn.Linear(in_features, out_features, bias=bias).to(dtype).to(get_current_device())
weight_initializer(self.layer.weight, fan_in=in_features, fan_out=out_features) weight_initializer(self.layer.weight, fan_in=in_features, fan_out=out_features)
if bias: if self.layer.bias is not None:
bias_initializer(self.layer.bias, fan_in=in_features) bias_initializer(self.layer.bias, fan_in=in_features)
else: else:
self.layer = _parallel_linear[tensor_parallel]( self.layer = _parallel_linear[tensor_parallel](
@ -97,26 +104,38 @@ class Classifier(nn.Module):
:param bias_initializer: The intializer of bias, defaults to xavier uniform initializer :param bias_initializer: The intializer of bias, defaults to xavier uniform initializer
:type bias_initializer: typing.Callable, optional :type bias_initializer: typing.Callable, optional
""" """
def __init__(
self, def __init__(self,
in_features: int, in_features: int,
num_classes: int, num_classes: int,
weight: nn.Parameter = None, weight: nn.Parameter = None,
bias: bool = True, bias: bool = True,
dtype: dtype = None, dtype: dtype = None,
weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)), weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)),
bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1) bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1),
) -> None: vocab_parallel_limit: int = 2048) -> None:
super().__init__() super().__init__()
self.layer = _parallel_classifier[get_tensor_parallel_mode()]( tensor_parallel = get_tensor_parallel_mode()
in_features, if num_classes <= vocab_parallel_limit or tensor_parallel is None:
num_classes, self.layer = _parallel_classifier[tensor_parallel](
weight=weight, in_features,
bias=bias, num_classes,
dtype=dtype, weight=weight,
weight_initializer=weight_initializer, bias=bias,
bias_initializer=bias_initializer, dtype=dtype,
) weight_initializer=weight_initializer,
bias_initializer=bias_initializer,
)
else:
self.layer = _vocab_parallel_classifier[tensor_parallel](
in_features,
num_classes,
weight=weight,
bias=bias,
dtype=dtype,
weight_initializer=weight_initializer,
bias_initializer=bias_initializer,
)
@property @property
def weight(self): def weight(self):
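
The Classifier wrapper applies the same vocabulary threshold: the existing per-mode classifiers are kept for small class counts or when no tensor parallel mode is set, and the new VocabParallel classifiers take over once num_classes exceeds vocab_parallel_limit. A compact sketch of the condition (names from the diff):

def use_vocab_parallel_classifier(tensor_parallel, num_classes, vocab_parallel_limit=2048):
    # Mirrors Classifier.__init__ above: the class dimension is split across ranks only
    # when some tensor parallel mode is active and the class count is large.
    return tensor_parallel is not None and num_classes > vocab_parallel_limit

assert not use_vocab_parallel_classifier(None, 100000)   # VanillaClassifier
assert not use_vocab_parallel_classifier('1d', 1000)     # Classifier1D
assert use_vocab_parallel_classifier('1d', 50304)        # VocabParallelClassifier1D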

View File

@ -1,7 +1,6 @@
from typing import Optional
from colossalai.utils import get_current_device from colossalai.utils import get_current_device
from torch import nn from torch import nn
from colossalai import kernel
from ... import init as init from ... import init as init
from ..parallel_1d import * from ..parallel_1d import *
@ -11,7 +10,12 @@ from ..parallel_3d import *
from ..utils import get_tensor_parallel_mode from ..utils import get_tensor_parallel_mode
from ..vanilla import * from ..vanilla import *
_parallel_layernorm = {'2d': LayerNorm2D, '2.5d': LayerNorm2p5D, '3d': LayerNorm3D} _parallel_layernorm = {
'1d': kernel.LayerNorm,
'2d': LayerNorm2D,
'2.5d': LayerNorm2p5D,
'3d': LayerNorm3D
}
class LayerNorm(nn.Module): class LayerNorm(nn.Module):
@ -28,11 +32,12 @@ class LayerNorm(nn.Module):
:param dtype: The dtype of parameters, defaults to None :param dtype: The dtype of parameters, defaults to None
:type dtype: torch.dtype, optional :type dtype: torch.dtype, optional
""" """
def __init__(self, normalized_shape: int, eps=1e-05, dtype=None) -> None: def __init__(self, normalized_shape: int, eps=1e-05, dtype=None) -> None:
super().__init__() super().__init__()
tensor_parallel = get_tensor_parallel_mode() tensor_parallel = get_tensor_parallel_mode()
if tensor_parallel in ['None', '1d']: if tensor_parallel is None:
self.norm = nn.LayerNorm(normalized_shape, eps=eps, device=get_current_device(), dtype=dtype) self.norm = nn.LayerNorm(normalized_shape, eps=eps).to(dtype).to(get_current_device())
else: else:
self.norm = _parallel_layernorm[tensor_parallel](normalized_shape, eps=eps, dtype=dtype) self.norm = _parallel_layernorm[tensor_parallel](normalized_shape, eps=eps, dtype=dtype)

View File

@@ -1,4 +1,7 @@
-from .layers import Dropout1D, Embedding1D, Linear1D, Linear1D_Col, Linear1D_Row
-from .layers import MixedFusedLayerNorm1D as LayerNorm1D
+from .layers import (Classifier1D, Dropout1D, Embedding1D, Linear1D, Linear1D_Col, Linear1D_Row,
+                     VocabParallelClassifier1D, VocabParallelEmbedding1D)

-__all__ = ['Linear1D', 'Linear1D_Col', 'Linear1D_Row', 'LayerNorm1D', 'Embedding1D', 'Dropout1D']
+__all__ = [
+    'Linear1D', 'Linear1D_Col', 'Linear1D_Row', 'Embedding1D', 'Dropout1D', 'Classifier1D', 'VocabParallelClassifier1D',
+    'VocabParallelEmbedding1D'
+]

View File

@@ -1,21 +1,20 @@
 #!/usr/bin/env python
 # -*- encoding: utf-8 -*-

-import os
 import torch
 import torch.distributed as dist
-from colossalai.constants import PARALLEL_INPUT_1D
 from colossalai.core import global_context as gpc
+from colossalai.global_variables import tensor_parallel_env as env

 from ..utils import divide


 def set_parallel_input(input_parallel: bool):
-    os.environ[PARALLEL_INPUT_1D] = 'true' if input_parallel else ''
+    env.parallel_input_1d = input_parallel


 def get_parallel_input():
-    return bool(os.environ[PARALLEL_INPUT_1D])
+    return env.parallel_input_1d


 def vocab_range_from_per_partition_vocab_size(per_partition_vocab_size, rank):
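
With the env object, the 1D input-parallel flag now round-trips as a real boolean instead of being encoded as an empty or non-empty string in os.environ. A small sketch (import paths assumed from this package's file layout):

from colossalai.global_variables import tensor_parallel_env as env
from colossalai.nn.layer.parallel_1d._utils import get_parallel_input, set_parallel_input

set_parallel_input(True)
assert get_parallel_input() is True
assert env.parallel_input_1d is True      # same underlying global state

set_parallel_input(False)
assert get_parallel_input() is False      # previously bool(os.environ[PARALLEL_INPUT_1D])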

View File

@ -2,8 +2,6 @@
# -*- encoding: utf-8 -*- # -*- encoding: utf-8 -*-
import math import math
import numbers
from contextlib import nullcontext
from typing import Callable, Tuple from typing import Callable, Tuple
import torch import torch
@ -11,17 +9,17 @@ import torch.nn.functional as F
from colossalai.communication import broadcast from colossalai.communication import broadcast
from colossalai.context import ParallelMode, seed from colossalai.context import ParallelMode, seed
from colossalai.core import global_context as gpc from colossalai.core import global_context as gpc
from colossalai.global_variables import tensor_parallel_env as env
from colossalai.nn import init as init from colossalai.nn import init as init
from colossalai.registry import LAYERS from colossalai.registry import LAYERS
from colossalai.utils import get_current_device from colossalai.utils.cuda import get_current_device
from torch import Tensor, dtype from torch import Tensor
from torch.nn.parameter import Parameter from torch.nn.parameter import Parameter
from ..base_layer import ParallelLayer from ..base_layer import ParallelLayer
from ..utils import divide, set_tensor_parallel_attribute_by_partition from ..utils import divide, set_tensor_parallel_attribute_by_partition
from ._operation import FusedLayerNormAffineFunction1D from ._utils import (gather_forward_split_backward, get_parallel_input, reduce_grad,
from ._utils import (gather_forward_split_backward, get_parallel_input, reduce_grad, reduce_input, set_parallel_input, reduce_input, set_parallel_input, split_forward_gather_backward)
split_forward_gather_backward)
@LAYERS.register_module @LAYERS.register_module
@ -44,6 +42,7 @@ class Linear1D(torch.nn.Module):
:param bias_initializer: The intializer of bias, defaults to xavier uniform initializer :param bias_initializer: The intializer of bias, defaults to xavier uniform initializer
:type bias_initializer: typing.Callable, optional :type bias_initializer: typing.Callable, optional
""" """
def __init__(self, def __init__(self,
in_features: int, in_features: int,
out_features: int, out_features: int,
@ -106,12 +105,13 @@ class Classifier1D(ParallelLayer):
:param bias_initializer: The intializer of bias, defaults to xavier uniform initializer :param bias_initializer: The intializer of bias, defaults to xavier uniform initializer
:type bias_initializer: typing.Callable, optional :type bias_initializer: typing.Callable, optional
""" """
def __init__(self, def __init__(self,
in_features: int, in_features: int,
num_classes: int, num_classes: int,
weight: Parameter = None, weight: Parameter = None,
bias: bool = True, bias: bool = True,
dtype: dtype = None, dtype: torch.dtype = None,
weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)), weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)),
bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1)): bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1)):
super().__init__() super().__init__()
@ -139,6 +139,7 @@ class Classifier1D(ParallelLayer):
self.reset_parameters(weight_initializer, bias_initializer) self.reset_parameters(weight_initializer, bias_initializer)
self._set_tensor_parallel_attributes() self._set_tensor_parallel_attributes()
set_parallel_input(False) set_parallel_input(False)
env.vocab_parallel = False
def reset_parameters(self, weight_initializer, bias_initializer) -> None: def reset_parameters(self, weight_initializer, bias_initializer) -> None:
fan_in, fan_out = self.in_features, self.num_classes fan_in, fan_out = self.in_features, self.num_classes
@ -167,6 +168,84 @@ class Classifier1D(ParallelLayer):
return output return output
@LAYERS.register_module
class VocabParallelClassifier1D(ParallelLayer):
"""ColLinear with given weight
Classifier of 1D parallelism
:param in_features: size of input features
:type in_features: int
:param num_classes: number of classes in the dataset
:type num_classes: int
:param weight: weight of the classifier, defaults to True
:type weight: torch.nn.Parameter, optional
:param bias: If set to ``False``, the layer will not learn an additive bias, defaults to ``True``
:type bias: bool, optional
:param dtype: The dtype of parameters, defaults to None
:type dtype: torch.dtype, optional
:param weight_initializer: The intializer of weight, defaults to kaiming uniform initializer
:type weight_initializer: typing.Callable, optional
:param bias_initializer: The intializer of bias, defaults to xavier uniform initializer
:type bias_initializer: typing.Callable, optional
"""
def __init__(self,
in_features: int,
num_classes: int,
weight: Parameter = None,
bias: bool = True,
dtype: torch.dtype = None,
weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)),
bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1)):
super().__init__()
self.in_features = in_features
self.num_classes = num_classes
self.parallel_input = get_parallel_input()
# Divide the weight matrix along the last dimension.
self.num_classes_per_partition = divide(num_classes, gpc.tensor_parallel_size)
# Parameters.
# Initialize weight.
factory_kwargs = {'device': get_current_device(), 'dtype': dtype}
if weight is not None:
self.weight = weight
self.has_weight = False
else:
self.weight = Parameter(torch.empty(self.num_classes_per_partition, self.in_features, **factory_kwargs))
self.has_weight = True
if bias:
self.bias = Parameter(torch.empty(self.num_classes_per_partition, **factory_kwargs))
else:
self.bias = None
with seed(ParallelMode.TENSOR):
self.reset_parameters(weight_initializer, bias_initializer)
self._set_tensor_parallel_attributes()
set_parallel_input(False)
env.vocab_parallel = True
def reset_parameters(self, weight_initializer, bias_initializer) -> None:
fan_in, fan_out = self.in_features, self.num_classes
if self.has_weight:
weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out)
if self.bias is not None:
bias_initializer(self.bias, fan_in=fan_in)
def _set_tensor_parallel_attributes(self):
num_partition = gpc.get_world_size(ParallelMode.TENSOR)
if self.has_weight:
set_tensor_parallel_attribute_by_partition(self.weight, num_partition)
if self.bias is not None:
set_tensor_parallel_attribute_by_partition(self.bias, num_partition)
def forward(self, input_: Tensor) -> Tensor:
# Set up backprop all-reduce.
input_parallel = reduce_grad(input_, ParallelMode.PARALLEL_1D)
# Matrix multiply.
output = F.linear(input_parallel, self.weight, self.bias)
return output
@LAYERS.register_module @LAYERS.register_module
class Linear1D_Col(ParallelLayer): class Linear1D_Col(ParallelLayer):
"""Linear layer with column parallelism. """Linear layer with column parallelism.
@ -324,7 +403,7 @@ class Linear1D_Row(ParallelLayer):
weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out) weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out)
if self.bias is not None: if self.bias is not None:
bias_initializer(self.bias, fan_in=fan_in) bias_initializer(self.bias, fan_in=fan_in)
broadcast(self.bias, gpc.get_ranks_in_group(ParallelMode.PARALLEL_1D)[0], ParallelMode.PARALLEL_1D) broadcast(self.bias, gpc.get_ranks_in_group(ParallelMode.PARALLEL_1D)[0], ParallelMode.PARALLEL_1D)
def _set_tensor_parallel_attributes(self): def _set_tensor_parallel_attributes(self):
num_partition = gpc.get_world_size(ParallelMode.TENSOR) num_partition = gpc.get_world_size(ParallelMode.TENSOR)
@ -341,45 +420,13 @@ class Linear1D_Row(ParallelLayer):
output = reduce_input(output_parallel, ParallelMode.PARALLEL_1D) output = reduce_input(output_parallel, ParallelMode.PARALLEL_1D)
if not self.skip_bias_add: if not self.skip_bias_add:
output = output + self.bias if self.bias is not None:
output = output + self.bias
return output return output
else: else:
return output, self.bias return output, self.bias
@LAYERS.register_module
class MixedFusedLayerNorm1D(torch.nn.Module):
r"""
Layer Normalization for 1D parallelism
:param normalized_shape: input shape from an expected input
of size. :math:`[* \times \text{normalized_shape}[0] \times \text{normalized_shape}[1] \times \ldots \times \text{normalized_shape}[-1]]`
If a single integer is used, it is treated as a singleton list, and this module will
normalize over the last dimension which is expected to be of that specific size.
:type normalized_shape: int
:param eps: a value added to the denominator for numerical stability, defaults to 1e-05
:type eps: float, optional
"""
def __init__(self, normalized_shape, eps=1e-5):
super(MixedFusedLayerNorm1D, self).__init__()
if isinstance(normalized_shape, numbers.Integral):
normalized_shape = (normalized_shape, )
self.normalized_shape = torch.Size(normalized_shape)
self.eps = eps
self.weight = Parameter(torch.Tensor(*normalized_shape))
self.bias = Parameter(torch.Tensor(*normalized_shape))
self.reset_parameters()
def reset_parameters(self):
init.ones_(self.weight)
init.zeros_(self.bias)
def forward(self, input):
return FusedLayerNormAffineFunction1D.apply(input, self.weight, self.bias, self.normalized_shape, self.eps)
@LAYERS.register_module @LAYERS.register_module
class Embedding1D(ParallelLayer): class Embedding1D(ParallelLayer):
""" """
@ -398,11 +445,12 @@ class Embedding1D(ParallelLayer):
:param args: Args used in F.embedding :param args: Args used in F.embedding
:param kwargs: Kwargs used in F.embedding :param kwargs: Kwargs used in F.embedding
""" """
def __init__(self, def __init__(self,
num_embeddings: int, num_embeddings: int,
embedding_dim: int, embedding_dim: int,
padding_idx: int = None, padding_idx: int = None,
dtype: dtype = None, dtype: torch.dtype = None,
weight_initializer: Callable = init.normal_(), weight_initializer: Callable = init.normal_(),
*args, *args,
**kwargs): **kwargs):
@ -446,6 +494,84 @@ class Embedding1D(ParallelLayer):
return output return output
@LAYERS.register_module
class VocabParallelEmbedding1D(torch.nn.Module):
"""Embedding parallelized in the vocabulary dimension.
:param num_embeddings: number of embeddings
:type num_embeddings: int
:param embedding_dim: dimension of embedding
:type embedding_dim: int
:param padding_idx: index of padding, defaults to None
:type padding_idx: int, optional
:param dtype: The dtype of parameters, defaults to None
:type dtype: torch.dtype, optional
:param weight_initializer: The intializer of weight, defaults to normal initializer
:type weight_initializer: typing.Callable, optional
:param args: Args used in F.embedding
:param kwargs: Kwargs used in F.embedding
"""
def __init__(self,
num_embeddings: int,
embedding_dim: int,
padding_idx: int = None,
dtype: torch.dtype = None,
weight_initializer: Callable = init.normal_(),
*args,
**kwargs):
super().__init__()
self.num_embeddings = num_embeddings
self.embed_dim = embedding_dim
self.padding_idx = padding_idx
self.embed_args = args
self.embed_kwargs = kwargs
tensor_parallel_size = gpc.get_world_size(ParallelMode.PARALLEL_1D)
tensor_parallel_rank = gpc.get_local_rank(ParallelMode.PARALLEL_1D)
self.num_embeddings_per_partition = divide(num_embeddings, tensor_parallel_size)
self.vocab_start_index = tensor_parallel_rank * self.num_embeddings_per_partition
self.vocab_end_index = self.vocab_start_index + self.num_embeddings_per_partition
self.weight = Parameter(
torch.empty((self.num_embeddings_per_partition, self.embed_dim), device=get_current_device(), dtype=dtype))
self.reset_parameters(weight_initializer)
self._set_tensor_parallel_attributes()
set_parallel_input(False)
env.vocab_parallel = True
def _set_tensor_parallel_attributes(self):
set_tensor_parallel_attribute_by_partition(self.weight, gpc.tensor_parallel_size)
def reset_parameters(self, weight_initializer) -> None:
with seed(ParallelMode.TENSOR):
fan_in, fan_out = self.num_embeddings, self.embed_dim
weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out)
self._fill_padding_idx_with_zero()
def _fill_padding_idx_with_zero(self) -> None:
if self.padding_idx is not None:
with torch.no_grad():
self.weight[self.padding_idx].fill_(0)
def forward(self, input_: Tensor) -> Tensor:
# Build the mask.
input_mask = (input_ < self.vocab_start_index) | (input_ >= self.vocab_end_index)
# Mask the input.
masked_input = input_.clone() - self.vocab_start_index
masked_input[input_mask] = 0
output_parallel = F.embedding(masked_input, self.weight, self.padding_idx, *self.embed_args,
**self.embed_kwargs)
# Mask the output embedding.
output_parallel[input_mask, :] = 0.
# Reduce across all the model parallel GPUs.
output = reduce_input(output_parallel, ParallelMode.PARALLEL_1D)
return output
@LAYERS.register_module @LAYERS.register_module
class Dropout1D(ParallelLayer): class Dropout1D(ParallelLayer):
""" """
@ -456,6 +582,7 @@ class Dropout1D(ParallelLayer):
:param inplace: If set to ``True``, will do this operation in-place, defaults tp ``False`` :param inplace: If set to ``True``, will do this operation in-place, defaults tp ``False``
:type inplace: bool, optional :type inplace: bool, optional
""" """
def __init__(self, p: float = 0.5, inplace: bool = False): def __init__(self, p: float = 0.5, inplace: bool = False):
super().__init__() super().__init__()
self.parallel_input = get_parallel_input() self.parallel_input = get_parallel_input()
@ -463,7 +590,9 @@ class Dropout1D(ParallelLayer):
self.inplace = inplace self.inplace = inplace
def forward(self, input_: Tensor) -> Tensor: def forward(self, input_: Tensor) -> Tensor:
cm = nullcontext() if not self.parallel_input else seed(ParallelMode.TENSOR) if self.parallel_input:
with cm: with seed(ParallelMode.TENSOR):
output = F.dropout(input_, self.p, self.training, self.inplace)
else:
output = F.dropout(input_, self.p, self.training, self.inplace) output = F.dropout(input_, self.p, self.training, self.inplace)
return output return output
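
For the new VocabParallelEmbedding1D, the forward pass above is the usual vocabulary-split lookup: each rank only stores rows [vocab_start_index, vocab_end_index), token ids outside that range are mapped to 0 locally, their embeddings are zeroed, and the cross-rank reduction restores the full result. A toy single-rank illustration of the masking (plain PyTorch, no process groups):

import torch
import torch.nn.functional as F

# Pretend this is rank 1 of 2, owning rows [4, 8) of an 8-token vocabulary.
vocab_start_index, vocab_end_index = 4, 8
weight = torch.arange(4 * 3, dtype=torch.float32).view(4, 3)   # this rank's 4 embedding rows

input_ = torch.tensor([[1, 5, 7, 2]])                          # global token ids
input_mask = (input_ < vocab_start_index) | (input_ >= vocab_end_index)
masked_input = input_.clone() - vocab_start_index
masked_input[input_mask] = 0                                   # ids 1 and 2 belong to rank 0

output_parallel = F.embedding(masked_input, weight)
output_parallel[input_mask, :] = 0.                            # zero the rows another rank owns
# reduce_input(output_parallel, ParallelMode.PARALLEL_1D) would then sum this with the
# other rank's partial result to recover the complete embedding lookup.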

View File

@@ -1,6 +1,8 @@
 from ._operation import reduce_by_batch_2d, split_tensor_2d
-from .layers import Classifier2D, Embedding2D, LayerNorm2D, Linear2D, PatchEmbedding2D
+from .layers import (Classifier2D, Embedding2D, LayerNorm2D, Linear2D, PatchEmbedding2D, VocabParallelClassifier2D,
+                     VocabParallelEmbedding2D)

 __all__ = [
-    'split_tensor_2d', 'reduce_by_batch_2d', 'Linear2D', 'LayerNorm2D', 'Classifier2D', 'PatchEmbedding2D', 'Embedding2D'
+    'split_tensor_2d', 'reduce_by_batch_2d', 'Linear2D', 'LayerNorm2D', 'Classifier2D', 'PatchEmbedding2D',
+    'Embedding2D', 'VocabParallelEmbedding2D', 'VocabParallelClassifier2D'
 ]

View File

@ -8,6 +8,7 @@ from colossalai.core import global_context as gpc
from colossalai.utils import get_current_device from colossalai.utils import get_current_device
from torch import Tensor from torch import Tensor
from torch.cuda.amp import custom_bwd, custom_fwd from torch.cuda.amp import custom_bwd, custom_fwd
from colossalai.global_variables import tensor_parallel_env as env
def matmul_2d( def matmul_2d(
@ -22,6 +23,7 @@ def matmul_2d(
): ):
""" """
Matrix multiplication for 2D parallelism Matrix multiplication for 2D parallelism
:param a: matrix :math:`A` :param a: matrix :math:`A`
:type a: torch.tensor :type a: torch.tensor
:param b: matrix :math:`B` :param b: matrix :math:`B`
@ -56,37 +58,7 @@ def matmul_2d(
data_parallel_rank, pipeline_parallel_rank, pipeline_parallel_size, tensor_parallel_size) data_parallel_rank, pipeline_parallel_rank, pipeline_parallel_size, tensor_parallel_size)
class classifier_2d(torch.autograd.Function): class _Classifier2D(torch.autograd.Function):
"""
Classifier
:param a: matrix :math:`A`
:type a: torch.tensor
:param b: matrix :math:`B`
:type b: torch.tensor
:param bias: matrix of bias
:type bias: torch.tensor, optional
:param summa_dim: dimension of SUMMA fo 2D parallelism
:type summa_dim: int
:param out_shape: shape of output tensor
:type out_shape: tuple
:param row_rank: the rank of row
:type row_rank: int
:param col_rank: the rank of column
:type col_rank: int
:param row_parallel_mode: row parallel mode
:type row_parallel_mode: colossalai.context.parallel_mode.ParallelMode
:param col_parallel_mode: column parallel mode
:type col_parallel_mode: colossalai.context.parallel_mode.ParallelMode
:param data_parallel_rank: data parallel rank
:type data_parallel_rank: int
:param pipeline_parallel_rank: pipeline parallel rank
:type pipeline_parallel_rank: int
:param pipeline_parallel_size: pipeline parallel size
:type pipeline_parallel_size: int
:param tensor_parallel_size: tensor parallel size
:type tensor_parallel_size: int
"""
@staticmethod @staticmethod
@custom_fwd(cast_inputs=torch.float16) @custom_fwd(cast_inputs=torch.float16)
def forward( def forward(
@ -150,14 +122,54 @@ class classifier_2d(torch.autograd.Function):
B_grad = torch.matmul(output_grad.reshape(-1, output_grad.shape[-1]).transpose(0, 1), A) B_grad = torch.matmul(output_grad.reshape(-1, output_grad.shape[-1]).transpose(0, 1), A)
B_grad = reduce_scatter(B_grad, -1, ctx.col_parallel_mode) B_grad = reduce_scatter(B_grad, -1, ctx.col_parallel_mode)
B_grad = B_grad.reshape(ctx.B_shape) B_grad = B_grad.reshape(ctx.B_shape)
bias_grad = None
if ctx.use_bias: if ctx.use_bias:
bias_grad = torch.sum(output_grad, dim=tuple(range(output_grad.ndim - 1))) bias_grad = torch.sum(output_grad, dim=tuple(range(output_grad.ndim - 1)))
bias_grad = all_reduce(bias_grad, ctx.col_parallel_mode) bias_grad = all_reduce(bias_grad, ctx.col_parallel_mode)
else:
bias_grad = None
return A_grad, B_grad, bias_grad, None, None, None, None, None, None, None, None, None, None return A_grad, B_grad, bias_grad, None, None, None, None, None, None, None, None, None, None
def classifier_2d(A: Tensor, B: Tensor, bias: Optional[Tensor], summa_dim: int, out_shape: Tuple[int, ...],
row_rank: int, col_rank: int, row_parallel_mode: ParallelMode, col_parallel_mode: ParallelMode,
data_parallel_rank: int, pipeline_parallel_rank: int, pipeline_parallel_size: int,
tensor_parallel_size: int) -> Tensor:
"""
2D parallel classifier
:param a: matrix :math:`A`
:type a: torch.tensor
:param b: matrix :math:`B`
:type b: torch.tensor
:param bias: matrix of bias
:type bias: torch.tensor, optional
:param summa_dim: dimension of SUMMA for 2D parallelism
:type summa_dim: int
:param out_shape: shape of output tensor
:type out_shape: tuple
:param row_rank: the rank of row
:type row_rank: int
:param col_rank: the rank of column
:type col_rank: int
:param row_parallel_mode: row parallel mode
:type row_parallel_mode: colossalai.context.parallel_mode.ParallelMode
:param col_parallel_mode: column parallel mode
:type col_parallel_mode: colossalai.context.parallel_mode.ParallelMode
:param data_parallel_rank: data parallel rank
:type data_parallel_rank: int
:param pipeline_parallel_rank: pipeline parallel rank
:type pipeline_parallel_rank: int
:param pipeline_parallel_size: pipeline parallel size
:type pipeline_parallel_size: int
:param tensor_parallel_size: tensor parallel size
:type tensor_parallel_size: int
"""
return _Classifier2D.apply(A, B, bias, summa_dim, out_shape, row_rank, col_rank, row_parallel_mode,
col_parallel_mode, data_parallel_rank, pipeline_parallel_rank, pipeline_parallel_size,
tensor_parallel_size)
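A minimal, self-contained sketch of the refactoring pattern applied throughout this commit: the torch.autograd.Function is made private and a documented functional wrapper calls its apply(). The names _ScaleBy and scale_by below are made up for illustration and are not part of the colossalai API.

import torch
from torch import Tensor


class _ScaleBy(torch.autograd.Function):
    # private autograd op: forward scales by `factor`, backward scales the gradient the same way
    @staticmethod
    def forward(ctx, input_: Tensor, factor: float) -> Tensor:
        ctx.factor = factor
        return input_ * factor

    @staticmethod
    def backward(ctx, output_grad: Tensor):
        # one gradient per forward argument; the non-tensor argument gets None
        return output_grad * ctx.factor, None


def scale_by(input_: Tensor, factor: float) -> Tensor:
    """Public wrapper: keeps the docstring and type hints outside the autograd class."""
    return _ScaleBy.apply(input_, factor)


x = torch.randn(4, requires_grad=True)
scale_by(x, 3.0).sum().backward()
assert torch.allclose(x.grad, torch.full_like(x, 3.0))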
class Matmul_AB_2D(torch.autograd.Function): class Matmul_AB_2D(torch.autograd.Function):
""" """
Matrix multiplication for :math:`C = AB` Matrix multiplication for :math:`C = AB`
@ -230,9 +242,9 @@ class Matmul_AB_2D(torch.autograd.Function):
col_group = gpc.get_group(col_parallel_mode) col_group = gpc.get_group(col_parallel_mode)
src_a = summa_dim * row_rank + data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + \ src_a = summa_dim * row_rank + data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + \
pipeline_parallel_rank * tensor_parallel_size pipeline_parallel_rank * tensor_parallel_size
src_b = col_rank + data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + \ src_b = col_rank + data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + \
pipeline_parallel_rank * tensor_parallel_size pipeline_parallel_rank * tensor_parallel_size
opa = [None] * 2 opa = [None] * 2
opb = [None] * 2 opb = [None] * 2
@ -361,9 +373,9 @@ class Matmul_ABT_2D(torch.autograd.Function):
col_group = gpc.get_group(col_parallel_mode) col_group = gpc.get_group(col_parallel_mode)
src_b = col_rank + data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + \ src_b = col_rank + data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + \
pipeline_parallel_rank * tensor_parallel_size pipeline_parallel_rank * tensor_parallel_size
src_c = summa_dim * row_rank + data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + \ src_c = summa_dim * row_rank + data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + \
pipeline_parallel_rank * tensor_parallel_size pipeline_parallel_rank * tensor_parallel_size
opb = [None] * 2 opb = [None] * 2
opr = [None] * 2 opr = [None] * 2
@ -501,9 +513,9 @@ class Matmul_ATB_2D(torch.autograd.Function):
col_group = gpc.get_group(col_parallel_mode) col_group = gpc.get_group(col_parallel_mode)
src_a = summa_dim * row_rank + data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + \ src_a = summa_dim * row_rank + data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + \
pipeline_parallel_rank * tensor_parallel_size pipeline_parallel_rank * tensor_parallel_size
src_c = col_rank + data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + \ src_c = col_rank + data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + \
pipeline_parallel_rank * tensor_parallel_size pipeline_parallel_rank * tensor_parallel_size
opa = [None] * 2 opa = [None] * 2
opr = [None] * 2 opr = [None] * 2
@ -572,35 +584,7 @@ class Matmul_ATB_2D(torch.autograd.Function):
return A_grad, B_grad, None, None, None, None, None, None, None, None, None, None return A_grad, B_grad, None, None, None, None, None, None, None, None, None, None
class add_bias_2d(torch.autograd.Function): class _Add_Bias_2D(torch.autograd.Function):
"""
Matrix add bias: :math:`C = A + b`
:param input_: matrix :math:`A`
:type input_: torch.tensor
:param bias: matrix :math:`b`
:type bias: torch.tensor
:param output_size_per_partition: size of ouput per partition
:type output_size_per_partition: int
:param row_rank: the rank of row
:type row_rank: int
:param col_rank: the rank of column
:type col_rank: int
:param row_parallel_mode: row parallel mode
:type row_parallel_mode: colossalai.context.parallel_mode.ParallelMode
:param col_parallel_mode: column parallel mode
:type col_parallel_mode: colossalai.context.parallel_mode.ParallelMode
:param skip_bias_add: If set to ``True``, it will skip bias add for linear layer, which is preserved for kernel fusion
:type skip_bias_add: bool
:param data_parallel_rank: data parallel rank
:type data_parallel_rank: int
:param pipeline_parallel_rank: pipeline parallel rank
:type pipeline_parallel_rank: int
:param pipeline_parallel_size: pipeline parallel size
:type pipeline_parallel_size: int
:param tensor_parallel_size: tensor parallel size
:type tensor_parallel_size: int
"""
@staticmethod @staticmethod
@custom_fwd(cast_inputs=torch.float16) @custom_fwd(cast_inputs=torch.float16)
def forward( def forward(
@ -651,31 +635,47 @@ class add_bias_2d(torch.autograd.Function):
return output_grad, grad, None, None, None, None, None, None, None, None, None, None return output_grad, grad, None, None, None, None, None, None, None, None, None, None
class layernorm_2d(torch.autograd.Function): def add_bias_2d(input_: Tensor, bias: Tensor, output_size_per_partition: int, row_rank: int, col_rank: int,
row_parallel_mode: ParallelMode, col_parallel_mode: ParallelMode, skip_bias_add: bool,
data_parallel_rank: int, pipeline_parallel_rank: int, pipeline_parallel_size: int,
tensor_parallel_size: int) -> Tensor:
""" """
Layernorm Matrix add bias: :math:`C = A + b`
:param input_: input maxtrix :param input_: matrix :math:`A`
:type input_: torch.tensor :type input_: torch.tensor
:param E_x: mean :param bias: matrix :math:`b`
:type E_x: torch.tensor :type bias: torch.tensor
:param Var_x: variance :param output_size_per_partition: size of output per partition
:type Var_x: torch.tensor :type output_size_per_partition: int
:param hidden_size: hidden size :param row_rank: the rank of row
:type hidden_size: int :type row_rank: int
:param col_rank: the rank of column
:type col_rank: int
:param row_parallel_mode: row parallel mode :param row_parallel_mode: row parallel mode
:type row_parallel_mode: colossalai.context.parallel_mode.ParallelMode :type row_parallel_mode: colossalai.context.parallel_mode.ParallelMode
:param col_parallel_mode: column parallel mode :param col_parallel_mode: column parallel mode
:type col_parallel_mode: colossalai.context.parallel_mode.ParallelMode :type col_parallel_mode: colossalai.context.parallel_mode.ParallelMode
:param skip_bias_add: If set to ``True``, it will skip bias add for linear layer, which is preserved for kernel fusion
:type skip_bias_add: bool
:param data_parallel_rank: data parallel rank
:type data_parallel_rank: int
:param pipeline_parallel_rank: pipeline parallel rank
:type pipeline_parallel_rank: int
:param pipeline_parallel_size: pipeline parallel size
:type pipeline_parallel_size: int
:param tensor_parallel_size: tensor parallel size
:type tensor_parallel_size: int
""" """
return _Add_Bias_2D.apply(input_, bias, output_size_per_partition, row_rank, col_rank, row_parallel_mode,
col_parallel_mode, skip_bias_add, data_parallel_rank, pipeline_parallel_rank,
pipeline_parallel_size, tensor_parallel_size)
class _Layernorm_2D(torch.autograd.Function):
@staticmethod @staticmethod
@custom_fwd(cast_inputs=torch.float32) @custom_fwd(cast_inputs=torch.float32)
def forward(ctx: Any, def forward(ctx: Any, input_: Tensor, E_x: Tensor, Var_x: Tensor, hidden_size: int, row_parallel_mode: ParallelMode,
input_: Tensor,
E_x: Tensor,
Var_x: Tensor,
hidden_size: int,
row_parallel_mode: ParallelMode,
col_parallel_mode: ParallelMode) -> Tensor: col_parallel_mode: ParallelMode) -> Tensor:
input_ = input_ - E_x input_ = input_ - E_x
# in here, input = x - E[x], Var_x = 1 / sqrt(Var[x] + eps) # in here, input = x - E[x], Var_x = 1 / sqrt(Var[x] + eps)
@ -709,76 +709,64 @@ class layernorm_2d(torch.autograd.Function):
return input_grad, None, None, None, None, None return input_grad, None, None, None, None, None
class all_gather_weight_2d(torch.autograd.Function): def layernorm_2d(input_: Tensor, E_x: Tensor, Var_x: Tensor, hidden_size: int, row_parallel_mode: ParallelMode,
col_parallel_mode: ParallelMode) -> Tensor:
""" """
all gather the weight of 2D parallelism Layernorm
:param inputs: input matrix :param input_: input matrix
:type inputs: torch.tensor :type input_: torch.tensor
:param dim: dimension of all gather :param E_x: mean
:type dim: int :type E_x: torch.tensor
:param summa_dim: dimension of SUMMA fo 2D parallelism :param Var_x: variance
:type summa_dim: int :type Var_x: torch.tensor
:param hidden_size: hidden size
:type hidden_size: int
:param row_parallel_mode: row parallel mode
:type row_parallel_mode: colossalai.context.parallel_mode.ParallelMode
:param col_parallel_mode: column parallel mode :param col_parallel_mode: column parallel mode
:type col_parallel_mode: colossalai.context.parallel_mode.ParallelMode :type col_parallel_mode: colossalai.context.parallel_mode.ParallelMode
""" """
return _Layernorm_2D.apply(input_, E_x, Var_x, hidden_size, row_parallel_mode, col_parallel_mode)
class _AllGatherTensor2D(torch.autograd.Function):
@staticmethod @staticmethod
@custom_fwd(cast_inputs=torch.float16) @custom_fwd(cast_inputs=torch.float16)
def forward(ctx: Any, inputs: Tensor, dim: int, summa_dim: int, col_parallel_mode: ParallelMode) -> Tensor: def forward(ctx: Any, inputs: Tensor, dim: int, parallel_mode: ParallelMode) -> Tensor:
ctx.dim = dim ctx.dim = dim
ctx.summa_dim = summa_dim ctx.parallel_mode = parallel_mode
ctx.row_rank = gpc.get_local_rank(col_parallel_mode)
outputs = all_gather(inputs, dim, col_parallel_mode) outputs = all_gather(inputs, dim, parallel_mode)
return outputs return outputs
@staticmethod @staticmethod
@custom_bwd @custom_bwd
def backward(ctx: Any, output_grad: Tensor) -> Tuple[Tensor, ...]: def backward(ctx: Any, output_grad: Tensor) -> Tuple[Tensor, ...]:
grad = output_grad.chunk(ctx.summa_dim, dim=ctx.dim)[ctx.row_rank] grad = reduce_scatter(output_grad, ctx.dim, ctx.parallel_mode)
return grad.contiguous(), None, None, None return grad.contiguous(), None, None
class SplitFirst(torch.autograd.Function): def all_gather_tensor_2d(tensor: Tensor, dim: int, parallel_mode: ParallelMode) -> Tensor:
""" """
All gather the tensor of 2D parallelism
:param inputs: input matrix :param inputs: input matrix
:type inputs: torch.tensor :type inputs: torch.tensor
:param summa_dim: dimension of SUMMA fo 2D parallelism :param dim: dimension to gather
:type summa_dim: int :type dim: int
:param col_parallel_mode: column parallel mode :param parallel_mode: parallel mode
:type col_parallel_mode: colossalai.context.parallel_mode.ParallelMode :type parallel_mode: colossalai.context.parallel_mode.ParallelMode
""" """
@staticmethod return _AllGatherTensor2D.apply(tensor, dim, parallel_mode)
@custom_fwd(cast_inputs=torch.float16)
def forward(ctx: Any, inputs: Tensor, summa_dim: int, col_parallel_mode: ParallelMode) -> Tensor:
ctx.summa_dim = summa_dim
ctx.batch_size = inputs.size(0)
ctx.para_mode = col_parallel_mode
row_rank = gpc.get_local_rank(col_parallel_mode)
outputs = inputs.chunk(summa_dim, dim=0)[row_rank]
return outputs
@staticmethod
@custom_bwd
def backward(ctx: Any, output_grad: Tensor) -> Tuple[Tensor, ...]:
grad_shape = (ctx.batch_size, ) + output_grad.shape[1:]
grad = torch.empty(grad_shape, dtype=output_grad.dtype, device=get_current_device())
dist.all_gather(list(grad.chunk(ctx.summa_dim, dim=0)),
output_grad.contiguous(),
group=gpc.get_group(ctx.para_mode))
return grad, None, None
def split_tensor_2d(input_: Tensor, dim: int = 0) -> Tensor: def split_tensor_2d(input_: Tensor, dim: int = 0) -> Tensor:
"""Splits 2D tensor in specified dimension across cols """Splits 2D tensor in specified dimension across cols
:param input_: Input tensor :param input_: Input tensor
:param dim: Specified dimension in which to split :param dim: Specified dimension in which to split
:type input_: torch.Tensor :type input_: torch.Tensor
:type dim: int, optional :type dim: int, optional
:return output: Split tensor :return output: Split tensor
:rtype output: torch.Tensor :rtype output: torch.Tensor
""" """
@ -788,9 +776,50 @@ def split_tensor_2d(input_: Tensor, dim: int = 0) -> Tensor:
dim=dim)[gpc.get_local_rank(ParallelMode.PARALLEL_2D_COL)].contiguous() dim=dim)[gpc.get_local_rank(ParallelMode.PARALLEL_2D_COL)].contiguous()
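A single-process sketch of what split_tensor_2d computes: chunk the chosen dimension by the column-parallel world size and keep only the local rank's contiguous slice. world_size and rank are hard-coded here instead of being read from gpc.

import torch

def split_tensor(input_: torch.Tensor, world_size: int, rank: int, dim: int = 0) -> torch.Tensor:
    # each rank keeps a contiguous 1/world_size slice along `dim`
    return torch.chunk(input_, world_size, dim=dim)[rank].contiguous()

x = torch.arange(8).reshape(4, 2)
print(split_tensor(x, world_size=2, rank=1))   # rows 2..3, i.e. the second rank's slice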
class reduce_by_batch_2d(torch.autograd.Function): class _ReduceTensor2D(torch.autograd.Function):
"""All-reduce the input from the model parallel region. @staticmethod
def forward(ctx, input_, parallel_mode):
return all_reduce(input_, parallel_mode)
@staticmethod
def backward(ctx, output_grad):
return output_grad, None
def reduce_tensor_2d(input_: Tensor, parallel_mode: ParallelMode) -> Tensor:
""" """
All-reduce the input.
:param input_: input tensor
:param parallel_mode: parallel mode
"""
return _ReduceTensor2D.apply(input_, parallel_mode)
class _ReduceScatterTensor2D(torch.autograd.Function):
@staticmethod
def forward(ctx, input_, dim, parallel_mode):
ctx.dim = dim
ctx.parallel_mode = parallel_mode
return reduce_scatter(input_, dim, parallel_mode)
@staticmethod
def backward(ctx, output_grad):
return all_gather(output_grad, ctx.dim, ctx.parallel_mode), None, None
def reduce_scatter_tensor_2d(tensor: Tensor, dim: int, parallel_mode: ParallelMode) -> Tensor:
"""
Reduce-scatter the input.
:param tensor: Input tensor
:param dim: Dimension to scatter
:param parallel_mode: Parallel mode
"""
return _ReduceScatterTensor2D.apply(tensor, dim, parallel_mode)
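A rough single-process illustration of why _ReduceScatterTensor2D pairs reduce-scatter in the forward pass with all-gather in the backward pass; the collectives are emulated with plain tensor ops instead of torch.distributed calls.

import torch

def emulated_reduce_scatter(inputs_per_rank, dim=0):
    # sum the per-rank inputs, then hand each rank its own slice of the result
    total = torch.stack(inputs_per_rank).sum(0)
    return list(torch.chunk(total, len(inputs_per_rank), dim=dim))

def emulated_all_gather(slices_per_rank, dim=0):
    # every rank ends up with the concatenation of all slices
    full = torch.cat(slices_per_rank, dim=dim)
    return [full for _ in slices_per_rank]

inputs = [torch.randn(4, 3) for _ in range(2)]      # one tensor per rank
scattered = emulated_reduce_scatter(inputs)         # forward
regathered = emulated_all_gather(scattered)         # what backward would rebuild
assert torch.allclose(regathered[0], torch.stack(inputs).sum(0))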
class _ReduceByBatch2D(torch.autograd.Function):
@staticmethod @staticmethod
def symbolic(graph, input_, reduce_mean: bool = False): def symbolic(graph, input_, reduce_mean: bool = False):
output = all_reduce(input_, ParallelMode.PARALLEL_2D_COL) output = all_reduce(input_, ParallelMode.PARALLEL_2D_COL)
@ -802,12 +831,6 @@ class reduce_by_batch_2d(torch.autograd.Function):
@staticmethod @staticmethod
@custom_fwd(cast_inputs=torch.float32) @custom_fwd(cast_inputs=torch.float32)
def forward(ctx, input_, reduce_mean: bool = False): def forward(ctx, input_, reduce_mean: bool = False):
"""
:param input_: input maxtrix
:type input_: torch.tensor
:param reduce_mean: If set to ``True``, it will divide the output by column parallel size, default to False
:type reduce_mean: int, optional
"""
output = all_reduce(input_, ParallelMode.PARALLEL_2D_COL) output = all_reduce(input_, ParallelMode.PARALLEL_2D_COL)
ctx.reduce_mean = reduce_mean ctx.reduce_mean = reduce_mean
if reduce_mean: if reduce_mean:
@ -823,3 +846,14 @@ class reduce_by_batch_2d(torch.autograd.Function):
return output_grad / ctx.reduce_size, None return output_grad / ctx.reduce_size, None
else: else:
return output_grad, None return output_grad, None
def reduce_by_batch_2d(input_, reduce_mean: bool = False) -> Tensor:
"""All-reduce the input from the model parallel region.
:param input_: input matrix
:type input_: torch.tensor
:param reduce_mean: If set to ``True``, it will divide the output by column parallel size, default to False
:type reduce_mean: bool, optional
"""
return _ReduceByBatch2D.apply(input_, reduce_mean)
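A single-process sketch of the reduce_by_batch_2d semantics: sum the partial results held by each column rank and, when reduce_mean is set, divide by the column-parallel size (the all-reduce is again emulated with plain tensor ops).

import torch

def emulated_reduce_by_batch(partials, reduce_mean: bool = False) -> torch.Tensor:
    output = torch.stack(partials).sum(0)     # what all_reduce(SUM) would leave on every rank
    if reduce_mean:
        output = output / len(partials)       # divide by the column-parallel size
    return output

losses = [torch.tensor(1.0), torch.tensor(3.0)]             # partial losses from two column ranks
print(emulated_reduce_by_batch(losses, reduce_mean=True))   # tensor(2.)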


@ -1,14 +1,11 @@
import os
from colossalai.context.parallel_mode import ParallelMode from colossalai.context.parallel_mode import ParallelMode
from colossalai.context.process_group_initializer.initializer_2d import SUMMA_DIM
from colossalai.core import global_context as gpc from colossalai.core import global_context as gpc
from colossalai.global_variables import tensor_parallel_env as env
def get_summa_dim_from_env() -> int: def get_summa_dim_from_env() -> int:
try: try:
summa_dim = os.environ[SUMMA_DIM] summa_dim = env.summa_dim
summa_dim = int(summa_dim)
assert summa_dim > 0, 'SUMMA_DIM must be larger than zero' assert summa_dim > 0, 'SUMMA_DIM must be larger than zero'
return summa_dim return summa_dim
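A toy sketch of the idea behind moving from os.environ strings to a shared Python object: settings become typed attributes written once at initialization and read everywhere else. Only summa_dim appears in the diff above; the class and function names here are illustrative, not the actual colossalai.global_variables implementation.

class _TensorParallelEnv:
    # process-local holder for tensor-parallel settings, replacing int(os.environ[...]) lookups
    def __init__(self):
        self.summa_dim = None      # written once by the 2D process-group initializer

tensor_parallel_env = _TensorParallelEnv()
tensor_parallel_env.summa_dim = 2  # whoever sets up 2D parallelism records the dimension

def get_summa_dim():
    summa_dim = tensor_parallel_env.summa_dim
    assert summa_dim is not None and summa_dim > 0, 'summa_dim must be set and positive'
    return summa_dim

print(get_summa_dim())   # 2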


@ -7,15 +7,16 @@ import torch.nn.functional as F
from colossalai.communication import broadcast from colossalai.communication import broadcast
from colossalai.context import ParallelMode, seed from colossalai.context import ParallelMode, seed
from colossalai.core import global_context as gpc from colossalai.core import global_context as gpc
from colossalai.global_variables import tensor_parallel_env as env
from colossalai.nn import init as init from colossalai.nn import init as init
from colossalai.registry import LAYERS from colossalai.registry import LAYERS
from colossalai.utils import get_current_device from colossalai.utils.cuda import get_current_device
from torch import Tensor, dtype from torch import Tensor
from torch.nn import Parameter from torch.nn import Parameter
from ..utils import divide, set_tensor_parallel_attribute_by_partition, to_2tuple
from ..base_layer import ParallelLayer from ..base_layer import ParallelLayer
from ._operation import Matmul_AB_2D, add_bias_2d, all_gather_weight_2d, classifier_2d, layernorm_2d from ..utils import divide, set_tensor_parallel_attribute_by_partition, to_2tuple
from ._operation import *
from ._utils import assert_summa_initialization, get_summa_dim_from_env from ._utils import assert_summa_initialization, get_summa_dim_from_env
@ -43,7 +44,7 @@ class Linear2D(ParallelLayer):
in_features: int, in_features: int,
out_features: int, out_features: int,
bias: bool = True, bias: bool = True,
dtype=None, dtype: torch.dtype = None,
skip_bias_add: bool = False, skip_bias_add: bool = False,
weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)), weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)),
bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1)): bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1)):
@ -101,16 +102,16 @@ class Linear2D(ParallelLayer):
if self.bias is not None: if self.bias is not None:
if self.skip_bias_add: if self.skip_bias_add:
bias = add_bias_2d.apply(None, self.bias, self.hidden_size_per_partition, self.row_rank, self.col_rank, bias = add_bias_2d(None, self.bias, self.hidden_size_per_partition, self.row_rank, self.col_rank,
ParallelMode.PARALLEL_2D_ROW, ParallelMode.PARALLEL_2D_COL, True, ParallelMode.PARALLEL_2D_ROW, ParallelMode.PARALLEL_2D_COL, True,
self.data_parallel_rank, self.pipeline_parallel_rank, self.data_parallel_rank, self.pipeline_parallel_rank, self.pipeline_parallel_size,
self.pipeline_parallel_size, self.tensor_parallel_size) self.tensor_parallel_size)
return output, bias return output, bias
else: else:
output = add_bias_2d.apply(output, self.bias, self.hidden_size_per_partition, self.row_rank, output = add_bias_2d(output, self.bias, self.hidden_size_per_partition, self.row_rank, self.col_rank,
self.col_rank, ParallelMode.PARALLEL_2D_ROW, ParallelMode.PARALLEL_2D_COL, ParallelMode.PARALLEL_2D_ROW, ParallelMode.PARALLEL_2D_COL, False,
False, self.data_parallel_rank, self.pipeline_parallel_rank, self.data_parallel_rank, self.pipeline_parallel_rank, self.pipeline_parallel_size,
self.pipeline_parallel_size, self.tensor_parallel_size) self.tensor_parallel_size)
return output return output
else: else:
return output return output
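A small sketch of the skip_bias_add contract seen in the branch above: when skip_bias_add is True the layer returns the un-added (output, bias) pair so a downstream fused kernel (e.g. bias + activation) can apply the bias itself. The helper below is illustrative, not the Linear2D implementation.

import torch

def linear_forward(x, weight, bias, skip_bias_add=False):
    output = torch.matmul(x, weight.t())
    if skip_bias_add:
        return output, bias            # caller fuses the bias add later
    return output + bias

x, w, b = torch.randn(2, 4), torch.randn(3, 4), torch.randn(3)
out, deferred_bias = linear_forward(x, w, b, skip_bias_add=True)
assert torch.allclose(out + deferred_bias, linear_forward(x, w, b))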
@ -174,16 +175,14 @@ class LayerNorm2D(ParallelLayer):
# this time 1/sqrt(Var_x + epsilon) # this time 1/sqrt(Var_x + epsilon)
Var_x = 1.0 / torch.sqrt(Var_x + self.variance_epsilon) Var_x = 1.0 / torch.sqrt(Var_x + self.variance_epsilon)
output = layernorm_2d.apply(x, E_x, Var_x, self.normalized_shape, ParallelMode.PARALLEL_2D_ROW, output = layernorm_2d(x, E_x, Var_x, self.normalized_shape, ParallelMode.PARALLEL_2D_ROW,
ParallelMode.PARALLEL_2D_COL) ParallelMode.PARALLEL_2D_COL)
bias = add_bias_2d.apply(None, self.beta, self.partitioned_partition, self.row_rank, self.col_rank, bias = add_bias_2d(None, self.beta, self.partitioned_partition, self.row_rank, self.col_rank,
ParallelMode.PARALLEL_2D_ROW, ParallelMode.PARALLEL_2D_COL, True, ParallelMode.PARALLEL_2D_ROW, ParallelMode.PARALLEL_2D_COL, True, self.data_parallel_rank,
self.data_parallel_rank, self.pipeline_parallel_rank, self.pipeline_parallel_size, self.pipeline_parallel_rank, self.pipeline_parallel_size, self.tensor_parallel_size)
self.tensor_parallel_size) scale = add_bias_2d(None, self.gamma, self.partitioned_partition, self.row_rank, self.col_rank,
scale = add_bias_2d.apply(None, self.gamma, self.partitioned_partition, self.row_rank, self.col_rank, ParallelMode.PARALLEL_2D_ROW, ParallelMode.PARALLEL_2D_COL, True, self.data_parallel_rank,
ParallelMode.PARALLEL_2D_ROW, ParallelMode.PARALLEL_2D_COL, True, self.pipeline_parallel_rank, self.pipeline_parallel_size, self.tensor_parallel_size)
self.data_parallel_rank, self.pipeline_parallel_rank, self.pipeline_parallel_size,
self.tensor_parallel_size)
output = torch.addcmul(bias, scale, output) output = torch.addcmul(bias, scale, output)
return output return output
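A single-device sketch of the layernorm decomposition this forward pass implements: mean and reciprocal-root variance are computed first, the normalized activation is formed, and gamma/beta are folded in with torch.addcmul. The distributed mean/variance reduction across the row group is omitted.

import torch
import torch.nn.functional as F

x = torch.randn(2, 8)
gamma, beta, eps = torch.ones(8), torch.zeros(8), 1e-5

E_x = x.mean(dim=-1, keepdim=True)
Var_x = x.var(dim=-1, unbiased=False, keepdim=True)
inv_std = 1.0 / torch.sqrt(Var_x + eps)          # the "Var_x" actually passed to layernorm_2d above
normalized = (x - E_x) * inv_std
output = torch.addcmul(beta, gamma, normalized)  # beta + gamma * normalized

reference = F.layer_norm(x, (8,), weight=gamma, bias=beta, eps=eps)
assert torch.allclose(output, reference, atol=1e-6)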
@ -217,8 +216,8 @@ class PatchEmbedding2D(ParallelLayer):
patch_size: int, patch_size: int,
in_chans: int, in_chans: int,
embed_size: int, embed_size: int,
dtype: dtype = None,
flatten: bool = True, flatten: bool = True,
dtype: torch.dtype = None,
weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)), weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)),
bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1), bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1),
position_embed_initializer: Callable = init.zeros_()): position_embed_initializer: Callable = init.zeros_()):
@ -268,19 +267,21 @@ class PatchEmbedding2D(ParallelLayer):
position_embed_initializer(self.pos_embed) position_embed_initializer(self.pos_embed)
def forward(self, input_: Tensor) -> Tensor: def forward(self, input_: Tensor) -> Tensor:
input_ = split_tensor_2d(input_)
B, C, H, W = input_.shape B, C, H, W = input_.shape
assert H == self.img_size[0] and W == self.img_size[1], \ assert H == self.img_size[0] and W == self.img_size[1], \
f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})." f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
weight = all_gather_weight_2d.apply(self.weight, 0, self.summa_dim, ParallelMode.PARALLEL_2D_COL) weight = all_gather_tensor_2d(self.weight, 0, ParallelMode.PARALLEL_2D_COL)
bias = all_gather_weight_2d.apply(self.bias, 0, self.summa_dim, ParallelMode.PARALLEL_2D_COL) bias = all_gather_tensor_2d(self.bias, 0, ParallelMode.PARALLEL_2D_COL)
output = F.conv2d(input_, weight, bias, stride=self.patch_size) output = F.conv2d(input_, weight, bias, stride=self.patch_size)
if self.flatten: if self.flatten:
output = output.flatten(2).transpose(1, 2) # BCHW -> BNC output = output.flatten(2).transpose(1, 2) # BCHW -> BNC
cls_token = all_gather_weight_2d.apply(self.cls_token, -1, self.summa_dim, ParallelMode.PARALLEL_2D_COL) cls_token = all_gather_tensor_2d(self.cls_token, -1, ParallelMode.PARALLEL_2D_COL)
pos_embed = all_gather_weight_2d.apply(self.pos_embed, -1, self.summa_dim, ParallelMode.PARALLEL_2D_COL) pos_embed = all_gather_tensor_2d(self.pos_embed, -1, ParallelMode.PARALLEL_2D_COL)
cls_token = cls_token.expand(output.shape[0], -1, -1) cls_token = cls_token.expand(output.shape[0], -1, -1)
output = torch.cat((cls_token, output), dim=1) output = torch.cat((cls_token, output), dim=1)
output = output + pos_embed output = output + pos_embed
@ -310,7 +311,7 @@ class Embedding2D(ParallelLayer):
num_embeddings: int, num_embeddings: int,
embedding_dim: int, embedding_dim: int,
padding_idx: int = None, padding_idx: int = None,
dtype: dtype = None, dtype: torch.dtype = None,
weight_initializer: Callable = init.normal_(), weight_initializer: Callable = init.normal_(),
*args, *args,
**kwargs): **kwargs):
@ -347,13 +348,90 @@ class Embedding2D(ParallelLayer):
self.weight[self.padding_idx].fill_(0) self.weight[self.padding_idx].fill_(0)
def forward(self, input_: Tensor) -> Tensor: def forward(self, input_: Tensor) -> Tensor:
weight = all_gather_weight_2d.apply(self.weight, -1, self.summa_dim, ParallelMode.PARALLEL_2D_COL) input_ = split_tensor_2d(input_)
weight = all_gather_tensor_2d(self.weight, -1, ParallelMode.PARALLEL_2D_COL)
output = F.embedding(input_, weight, self.padding_idx, *self.embed_args, **self.embed_kwargs) output = F.embedding(input_, weight, self.padding_idx, *self.embed_args, **self.embed_kwargs)
return output return output
@LAYERS.register_module
class VocabParallelEmbedding2D(torch.nn.Module):
"""Embedding parallelized in the vocabulary dimension.
:param num_embeddings: number of embeddings
:type num_embeddings: int
:param embedding_dim: dimension of embedding
:type embedding_dim: int
:param padding_idx: index of padding, defaults to None
:type padding_idx: int, optional
:param dtype: The dtype of parameters, defaults to None
:type dtype: torch.dtype, optional
:param weight_initializer: The initializer of weight, defaults to normal initializer
:type weight_initializer: typing.Callable, optional
:param args: Args used in F.embedding
:param kwargs: Kwargs used in F.embedding
"""
def __init__(self,
num_embeddings: int,
embedding_dim: int,
padding_idx: int = None,
dtype: torch.dtype = None,
weight_initializer: Callable = init.normal_(),
*args,
**kwargs):
super().__init__()
self.num_embeddings = num_embeddings
self.embed_dim = embedding_dim
self.padding_idx = padding_idx
self.embed_args = args
self.embed_kwargs = kwargs
assert_summa_initialization()
self.summa_dim = get_summa_dim_from_env()
self.num_embeddings_per_partition = divide(self.num_embeddings, self.summa_dim)
self.embed_dim_per_partition = divide(self.embed_dim, self.summa_dim)
tensor_parallel_rank = gpc.get_local_rank(ParallelMode.PARALLEL_2D_COL)
self.vocab_start_index = tensor_parallel_rank * self.num_embeddings_per_partition
self.vocab_end_index = self.vocab_start_index + self.num_embeddings_per_partition
self.weight = Parameter(
torch.empty((self.num_embeddings_per_partition, self.embed_dim_per_partition),
device=get_current_device(),
dtype=dtype))
self.reset_parameters(weight_initializer)
self._set_tensor_parallel_attributes()
env.vocab_parallel = True
def _set_tensor_parallel_attributes(self):
set_tensor_parallel_attribute_by_partition(self.weight, self.summa_dim**2)
def reset_parameters(self, weight_initializer) -> None:
with seed(ParallelMode.TENSOR):
fan_in, fan_out = self.num_embeddings, self.embed_dim
weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out)
self._fill_padding_idx_with_zero()
def _fill_padding_idx_with_zero(self) -> None:
if self.padding_idx is not None:
with torch.no_grad():
self.weight[self.padding_idx].fill_(0)
def forward(self, input_: Tensor) -> Tensor:
input_mask = (input_ < self.vocab_start_index) | (input_ >= self.vocab_end_index)
masked_input = input_.clone() - self.vocab_start_index
masked_input[input_mask] = 0
output_parallel = F.embedding(masked_input, self.weight, self.padding_idx, *self.embed_args,
**self.embed_kwargs)
output_parallel[input_mask, :] = 0.
output = reduce_scatter_tensor_2d(output_parallel, 0, ParallelMode.PARALLEL_2D_COL)
return output
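A single-process sketch of the vocab-masking trick used in the forward pass above: token ids outside this rank's vocabulary slice are shifted into local index space, clamped to 0, looked up anyway, and their rows zeroed so that the later cross-rank reduction restores the correct embeddings. The reduce-scatter itself is omitted; sizes are toy values.

import torch
import torch.nn.functional as F

vocab_size, embed_dim, world_size, rank = 8, 4, 2, 1
per_rank = vocab_size // world_size                 # vocabulary rows held by each rank
vocab_start, vocab_end = rank * per_rank, (rank + 1) * per_rank

weight = torch.randn(per_rank, embed_dim)           # this rank's slice of the embedding table
input_ = torch.tensor([[1, 5], [6, 2]])             # global token ids

input_mask = (input_ < vocab_start) | (input_ >= vocab_end)
masked_input = input_.clone() - vocab_start         # shift into local index space
masked_input[input_mask] = 0                        # any valid local index works; rows are zeroed below

output_parallel = F.embedding(masked_input, weight)
output_parallel[input_mask, :] = 0.                 # out-of-range tokens contribute nothing from this rank
print(output_parallel.shape)                        # torch.Size([2, 2, 4])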
@LAYERS.register_module @LAYERS.register_module
class Classifier2D(ParallelLayer): class Classifier2D(ParallelLayer):
""" """
@ -379,7 +457,7 @@ class Classifier2D(ParallelLayer):
num_classes: int, num_classes: int,
weight: Parameter = None, weight: Parameter = None,
bias: bool = True, bias: bool = True,
dtype: dtype = None, dtype: torch.dtype = None,
weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)), weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)),
bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1)): bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1)):
super().__init__() super().__init__()
@ -429,7 +507,101 @@ class Classifier2D(ParallelLayer):
def forward(self, input_: Tensor) -> Tensor: def forward(self, input_: Tensor) -> Tensor:
out_shape = input_.shape[:-1] + (self.num_classes, ) out_shape = input_.shape[:-1] + (self.num_classes, )
return classifier_2d.apply(input_, self.weight, self.bias, self.summa_dim, out_shape, self.row_rank, return classifier_2d(input_, self.weight, self.bias, self.summa_dim, out_shape, self.row_rank, self.col_rank,
self.col_rank, ParallelMode.PARALLEL_2D_ROW, ParallelMode.PARALLEL_2D_COL, ParallelMode.PARALLEL_2D_ROW, ParallelMode.PARALLEL_2D_COL, self.data_parallel_rank,
self.data_parallel_rank, self.pipeline_parallel_rank, self.pipeline_parallel_size, self.pipeline_parallel_rank, self.pipeline_parallel_size, self.tensor_parallel_size)
self.tensor_parallel_size)
@LAYERS.register_module
class VocabParallelClassifier2D(ParallelLayer):
"""
Vocab parallel classifier layer for 2D parallelism
:param in_features: size of each input sample
:type in_features: int
:param num_classes: number of classes
:type num_classes: int
:param weight: weight of the classifier, defaults to None
:type weight: torch.nn.Parameter, optional
:param bias: If set to ``False``, the layer will not learn an additive bias, defaults to ``True``
:type bias: bool, optional
:param dtype: The dtype of parameters, defaults to None
:type dtype: torch.dtype, optional
:param weight_initializer: The initializer of weight, defaults to kaiming uniform initializer
:type weight_initializer: typing.Callable, optional
:param bias_initializer: The initializer of bias, defaults to xavier uniform initializer
:type bias_initializer: typing.Callable, optional
"""
def __init__(self,
in_features: int,
num_classes: int,
weight: Parameter = None,
bias: bool = True,
dtype: torch.dtype = None,
weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)),
bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1)):
super().__init__()
self.in_features = in_features
self.num_classes = num_classes
# parallel setting
assert_summa_initialization()
self.row_rank = gpc.get_local_rank(ParallelMode.PARALLEL_2D_COL)
self.col_rank = gpc.get_local_rank(ParallelMode.PARALLEL_2D_ROW)
self.summa_dim = get_summa_dim_from_env()
# partitioning dimension
self.input_size_per_partition = divide(in_features, self.summa_dim)
self.output_size_per_partition = divide(num_classes, self.summa_dim)
# create weight, shape: [k/q, h/q]
factory_kwargs = {'device': get_current_device(), 'dtype': dtype}
if weight is not None:
self.weight = weight
self.has_weight = False
else:
self.weight = Parameter(
torch.empty(self.output_size_per_partition, self.input_size_per_partition, **factory_kwargs))
self.has_weight = True
# create bias, shape: [h/q]
if bias:
self.bias = Parameter(torch.empty(divide(self.num_classes, self.summa_dim**2), **factory_kwargs))
else:
self.bias = None
# initialize parameters
with seed(ParallelMode.TENSOR):
self.reset_parameters(weight_initializer, bias_initializer)
self._set_tensor_parallel_attributes()
env.vocab_parallel = True
def _set_tensor_parallel_attributes(self):
if self.has_weight:
set_tensor_parallel_attribute_by_partition(self.weight, self.summa_dim**2)
if self.bias is not None:
set_tensor_parallel_attribute_by_partition(self.bias, self.summa_dim**2)
def reset_parameters(self, weight_initializer, bias_initializer) -> None:
fan_in, fan_out = self.in_features, self.num_classes
if self.has_weight:
weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out)
if self.bias is not None:
bias_initializer(self.bias, fan_in=fan_in)
def forward(self, x: Tensor) -> Tensor:
# input: [m/q, n/q, k/q]
# output: [m/q, n/q, h/q]
out_shape = x.shape[:-1] + (self.output_size_per_partition, )
output = Matmul_ABT_2D.apply(x, self.weight, self.summa_dim, out_shape, self.row_rank, self.col_rank,
ParallelMode.PARALLEL_2D_ROW, ParallelMode.PARALLEL_2D_COL,
self.data_parallel_rank, self.pipeline_parallel_rank, self.pipeline_parallel_size,
self.tensor_parallel_size)
if self.bias is not None:
output = add_bias_2d(output, self.bias, self.output_size_per_partition, self.row_rank, self.col_rank,
ParallelMode.PARALLEL_2D_ROW, ParallelMode.PARALLEL_2D_COL, False,
self.data_parallel_rank, self.pipeline_parallel_rank, self.pipeline_parallel_size,
self.tensor_parallel_size)
return output
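A shape-only, single-device sketch of what the forward pass above computes per partition: the classifier weight is stored as [num_classes/q, in_features/q], so the local work is x @ weight.T plus a partitioned bias. The SUMMA communication inside Matmul_ABT_2D and the finer q**2 bias partitioning are omitted; the sizes are toy values.

import torch

in_per_partition, classes_per_partition = 16, 8
x = torch.randn(4, 10, in_per_partition)                       # [m/q, n/q, k/q]
weight = torch.randn(classes_per_partition, in_per_partition)  # [h/q, k/q]
bias = torch.randn(classes_per_partition)

output = torch.matmul(x, weight.t()) + bias                    # [m/q, n/q, h/q]
print(output.shape)                                            # torch.Size([4, 10, 8])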


@ -1,7 +1,8 @@
from ._operation import reduce_by_batch_2p5d, split_tensor_2p5d from ._operation import reduce_by_batch_2p5d, split_tensor_2p5d
from .layers import Classifier2p5D, Embedding2p5D, LayerNorm2p5D, Linear2p5D, PatchEmbedding2p5D from .layers import (Classifier2p5D, Embedding2p5D, LayerNorm2p5D, Linear2p5D, PatchEmbedding2p5D,
VocabParallelClassifier2p5D, VocabParallelEmbedding2p5D)
__all__ = [ __all__ = [
'split_tensor_2p5d', 'reduce_by_batch_2p5d', 'Linear2p5D', 'LayerNorm2p5D', 'Classifier2p5D', 'PatchEmbedding2p5D', 'split_tensor_2p5d', 'reduce_by_batch_2p5d', 'Linear2p5D', 'LayerNorm2p5D', 'Classifier2p5D', 'PatchEmbedding2p5D',
'Embedding2p5D' 'Embedding2p5D', 'VocabParallelClassifier2p5D', 'VocabParallelEmbedding2p5D'
] ]


@ -22,42 +22,7 @@ def get_parallel_rank(parallel_mode: ParallelMode):
return gpc.get_local_rank(parallel_mode) return gpc.get_local_rank(parallel_mode)
def split_tensor_2p5d(input_: Tensor, dim: int = 0) -> Tensor: class _Classifier2p5D(torch.autograd.Function):
return torch.chunk(input_, gpc.get_world_size(ParallelMode.PARALLEL_2P5D_COL),
dim=dim)[gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_COL)].contiguous()
class classifier_2p5d(torch.autograd.Function):
"""
Classifier
:param a: matrix :math:`A`
:type a: torch.tensor
:param b: matrix :math:`B`
:type b: torch.tensor
:param bias: matrix of bias
:type bias: torch.tensor, optional
:param tesseract_dim: dimension of TESSERACT fo 2.5D parallelism
:type tesseract_dim: int
:param out_shape: shape of output tensor
:type out_shape: tuple
:param row_rank: the rank of row
:type row_rank: int
:param col_rank: the rank of column
:type col_rank: int
:param row_parallel_mode: row parallel mode
:type row_parallel_mode: colossalai.context.parallel_mode.ParallelMode
:param col_parallel_mode: column parallel mode
:type col_parallel_mode: colossalai.context.parallel_mode.ParallelMode
:param data_parallel_rank: data parallel rank
:type data_parallel_rank: int
:param pipeline_parallel_rank: pipeline parallel rank
:type pipeline_parallel_rank: int
:param pipeline_parallel_size: pipeline parallel size
:type pipeline_parallel_size: int
:param tensor_parallel_size: tensor parallel size
:type tensor_parallel_size: int
"""
@staticmethod @staticmethod
@custom_fwd(cast_inputs=torch.float16) @custom_fwd(cast_inputs=torch.float16)
def forward( def forward(
@ -122,12 +87,54 @@ class classifier_2p5d(torch.autograd.Function):
B_grad = reduce_scatter(B_grad, -1, ctx.col_parallel_mode) B_grad = reduce_scatter(B_grad, -1, ctx.col_parallel_mode)
B_grad = B_grad.reshape(ctx.B_shape) B_grad = B_grad.reshape(ctx.B_shape)
bias_grad = torch.sum(output_grad, dim=tuple(range(output_grad.ndim - 1))) if ctx.use_bias:
bias_grad = all_reduce(bias_grad, ctx.col_parallel_mode) bias_grad = torch.sum(output_grad, dim=tuple(range(output_grad.ndim - 1)))
bias_grad = all_reduce(bias_grad, ctx.col_parallel_mode)
else:
bias_grad = None
return A_grad, B_grad, bias_grad, None, None, None, None, None, None, None, None, None, None return A_grad, B_grad, bias_grad, None, None, None, None, None, None, None, None, None, None
def classifier_2p5d(A: Tensor, B: Tensor, bias, tesseract_dim: int, out_shape: Tuple[int,
...], row_rank: int, col_rank: int,
row_parallel_mode: ParallelMode, col_parallel_mode: ParallelMode, data_parallel_rank: int,
pipeline_parallel_rank: int, pipeline_parallel_size: int, tensor_parallel_size: int) -> Tensor:
"""
Classifier
:param a: matrix :math:`A`
:type a: torch.tensor
:param b: matrix :math:`B`
:type b: torch.tensor
:param bias: matrix of bias
:type bias: torch.tensor, optional
:param tesseract_dim: dimension of TESSERACT for 2.5D parallelism
:type tesseract_dim: int
:param out_shape: shape of output tensor
:type out_shape: tuple
:param row_rank: the rank of row
:type row_rank: int
:param col_rank: the rank of column
:type col_rank: int
:param row_parallel_mode: row parallel mode
:type row_parallel_mode: colossalai.context.parallel_mode.ParallelMode
:param col_parallel_mode: column parallel mode
:type col_parallel_mode: colossalai.context.parallel_mode.ParallelMode
:param data_parallel_rank: data parallel rank
:type data_parallel_rank: int
:param pipeline_parallel_rank: pipeline parallel rank
:type pipeline_parallel_rank: int
:param pipeline_parallel_size: pipeline parallel size
:type pipeline_parallel_size: int
:param tensor_parallel_size: tensor parallel size
:type tensor_parallel_size: int
"""
return _Classifier2p5D.apply(A, B, bias, tesseract_dim, out_shape, row_rank, col_rank, row_parallel_mode,
col_parallel_mode, data_parallel_rank, pipeline_parallel_rank, pipeline_parallel_size,
tensor_parallel_size)
class Matmul_AB_2p5D(torch.autograd.Function): class Matmul_AB_2p5D(torch.autograd.Function):
""" """
Matrix multiplication for :math:`C = AB` Matrix multiplication for :math:`C = AB`
@ -522,37 +529,7 @@ class Matmul_ATB_2p5D(torch.autograd.Function):
return A_grad, B_grad, None, None, None, None, None, None, None, None, None, None, None, None, None return A_grad, B_grad, None, None, None, None, None, None, None, None, None, None, None, None, None
class Add_Bias_2p5D(torch.autograd.Function): class _Add_Bias_2p5D(torch.autograd.Function):
"""
Matrix add bias: :math:`C = A + b`
:param input: matrix :math:`A`
:type input: torch.tensor
:param bias: matrix :math:`b`
:type bias: torch.tensor
:param output_size_per_partition: output size in each partition
:type output_size_per_partition: int
:param tesseract_dim: dimension of TESSERACT fo 2.5D parallelism
:type tesseract_dim: int
:param row_rank: the rank of row
:type row_rank: int
:param col_rank: the rank of column
:type col_rank: int
:param row_parallel_mode: row parallel mode
:type row_parallel_mode: colossalai.context.parallel_mode.ParallelMode
:param col_parallel_mode: column parallel mode
:type col_parallel_mode: colossalai.context.parallel_mode.ParallelMode
:param skip_bias_add: If set to ``True``, it will skip bias add for linear layer, which is preserved for kernel fusion
:type skip_bias_add: bool
:param data_parallel_rank: data parallel rank
:type data_parallel_rank: int
:param pipeline_parallel_rank: pipeline parallel rank
:type pipeline_parallel_rank: int
:param pipeline_parallel_size: pipeline parallel size
:type pipeline_parallel_size: int
:param tensor_parallel_size: tensor parallel size
:type tensor_parallel_size: int
"""
@staticmethod @staticmethod
@custom_fwd(cast_inputs=torch.float16) @custom_fwd(cast_inputs=torch.float16)
def forward(ctx: Any, input: Tensor, bias: Tensor, output_size_per_partition: int, tesseract_dim: int, def forward(ctx: Any, input: Tensor, bias: Tensor, output_size_per_partition: int, tesseract_dim: int,
@ -621,7 +598,46 @@ class Add_Bias_2p5D(torch.autograd.Function):
return output_grad, reduce_tmp, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None return output_grad, reduce_tmp, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None
class layernorm_2p5d(torch.autograd.Function): def add_bias_2p5d(input: Tensor, bias: Tensor, output_size_per_partition: int, tesseract_dim: int, row_rank: int,
col_rank: int, dep_rank: int, col_parallel_mode: ParallelMode, skip_bias_add: bool,
data_parallel_rank: int, pipeline_parallel_rank: int, pipeline_parallel_size: int,
tensor_parallel_size: int) -> Tensor:
"""
Matrix add bias: :math:`C = A + b`
:param input: matrix :math:`A`
:type input: torch.tensor
:param bias: matrix :math:`b`
:type bias: torch.tensor
:param output_size_per_partition: output size in each partition
:type output_size_per_partition: int
:param tesseract_dim: dimension of TESSERACT for 2.5D parallelism
:type tesseract_dim: int
:param row_rank: the rank of row
:type row_rank: int
:param col_rank: the rank of column
:type col_rank: int
:param row_parallel_mode: row parallel mode
:type row_parallel_mode: colossalai.context.parallel_mode.ParallelMode
:param col_parallel_mode: column parallel mode
:type col_parallel_mode: colossalai.context.parallel_mode.ParallelMode
:param skip_bias_add: If set to ``True``, it will skip bias add for linear layer, which is preserved for kernel fusion
:type skip_bias_add: bool
:param data_parallel_rank: data parallel rank
:type data_parallel_rank: int
:param pipeline_parallel_rank: pipeline parallel rank
:type pipeline_parallel_rank: int
:param pipeline_parallel_size: pipeline parallel size
:type pipeline_parallel_size: int
:param tensor_parallel_size: tensor parallel size
:type tensor_parallel_size: int
"""
return _Add_Bias_2p5D.apply(input, bias, output_size_per_partition, tesseract_dim, row_rank, col_rank, dep_rank,
col_parallel_mode, skip_bias_add, data_parallel_rank, pipeline_parallel_rank,
pipeline_parallel_size, tensor_parallel_size)
class _Layernorm2p5D(torch.autograd.Function):
""" """
Layernorm Layernorm
@ -671,7 +687,43 @@ class layernorm_2p5d(torch.autograd.Function):
return input_grad, None, None, None, None, None, None return input_grad, None, None, None, None, None, None
class all_gather_weight_2p5d(torch.autograd.Function): def layernorm_2p5d(input: Tensor, E_x: Tensor, Var_x: Tensor, hidden_size: int,
row_parallel_mode: ParallelMode) -> Tensor:
"""
Layernorm
:param input: input matrix
:type input: torch.tensor
:param E_x: mean
:type E_x: torch.tensor
:param Var_x: variance
:type Var_x: torch.tensor
:param hidden_size: hidden size
:type hidden_size: int
:param row_parallel_mode: row parallel mode
:type row_parallel_mode: colossalai.context.parallel_mode.ParallelMode
"""
return _Layernorm2p5D.apply(input, E_x, Var_x, hidden_size, row_parallel_mode)
class _AllGatherTensor2p5D(torch.autograd.Function):
@staticmethod
@custom_fwd(cast_inputs=torch.float16)
def forward(ctx: Any, inputs: Tensor, dim: int, col_parallel_mode: ParallelMode) -> Tensor:
ctx.dim = dim
ctx.col_parallel_mode = col_parallel_mode
outputs = all_gather(inputs, dim, col_parallel_mode)
return outputs
@staticmethod
@custom_bwd
def backward(ctx: Any, output_grad: Tensor) -> Tuple[Tensor, ...]:
grad = reduce_scatter(output_grad, ctx.dim, ctx.col_parallel_mode)
return grad.contiguous(), None, None
def all_gather_tensor_2p5d(inputs: Tensor, dim: int, col_parallel_mode: ParallelMode) -> Tensor:
""" """
all gather the weight of 2.5D parallelism all gather the weight of 2.5D parallelism
@ -684,21 +736,7 @@ class all_gather_weight_2p5d(torch.autograd.Function):
:param col_parallel_mode: column parallel mode :param col_parallel_mode: column parallel mode
:type col_parallel_mode: colossalai.context.parallel_mode.ParallelMode :type col_parallel_mode: colossalai.context.parallel_mode.ParallelMode
""" """
@staticmethod return _AllGatherTensor2p5D.apply(inputs, dim, col_parallel_mode)
@custom_fwd(cast_inputs=torch.float16)
def forward(ctx: Any, inputs: Tensor, dim: int, tesseract_dim: int, col_parallel_mode: ParallelMode) -> Tensor:
ctx.dim = dim
ctx.tesseract_dim = tesseract_dim
ctx.row_rank = gpc.get_local_rank(col_parallel_mode)
outputs = all_gather(inputs, dim, col_parallel_mode)
return outputs
@staticmethod
@custom_bwd
def backward(ctx: Any, output_grad: Tensor) -> Tuple[Tensor, ...]:
grad = output_grad.chunk(ctx.tesseract_dim, dim=ctx.dim)[ctx.row_rank]
return grad.contiguous(), None, None, None
class SplitFirst(torch.autograd.Function): class SplitFirst(torch.autograd.Function):
@ -737,10 +775,10 @@ def split_tensor_2p5d(input_: Tensor, dim: int = 0) -> Tensor:
:param input_: Input tensor :param input_: Input tensor
:param dim: Specified dimension in which to split :param dim: Specified dimension in which to split
:type input_: torch.Tensor :type input_: torch.Tensor
:type dim: int, optional :type dim: int, optional
:return output: Split tensor :return output: Split tensor
:rtype output: torch.Tensor :rtype output: torch.Tensor
""" """
@ -750,9 +788,49 @@ def split_tensor_2p5d(input_: Tensor, dim: int = 0) -> Tensor:
dim=dim)[gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_COL)].contiguous() dim=dim)[gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_COL)].contiguous()
class reduce_by_batch_2p5d(torch.autograd.Function): class _ReduceTensor2p5D(torch.autograd.Function):
"""All-reduce the input from the model parallel region. @staticmethod
def forward(ctx, input_, parallel_mode):
return all_reduce(input_, parallel_mode)
@staticmethod
def backward(ctx, output_grad):
return output_grad, None
def reduce_tensor_2p5d(input_: Tensor, parallel_mode: ParallelMode) -> Tensor:
""" """
All-reduce the input.
:param input_: input tensor
:param parallel_mode: parallel mode
"""
return _ReduceTensor2p5D.apply(input_, parallel_mode)
class _ReduceScatterTensor2p5D(torch.autograd.Function):
@staticmethod
def forward(ctx, input_, dim, parallel_mode):
ctx.dim = dim
ctx.parallel_mode = parallel_mode
return reduce_scatter(input_, dim, parallel_mode)
@staticmethod
def backward(ctx, output_grad):
return all_gather(output_grad, ctx.dim, ctx.parallel_mode), None, None
def reduce_scatter_tensor_2p5d(input_: Tensor, dim: int, parallel_mode: ParallelMode) -> Tensor:
"""
Reduce-scatter the input.
:param input_: input tensor
:param dim: dimension along which to scatter
:param parallel_mode: parallel mode
"""
return _ReduceScatterTensor2p5D.apply(input_, dim, parallel_mode)
class _ReduceByBatch2p5D(torch.autograd.Function):
@staticmethod @staticmethod
def symbolic(graph, input_, reduce_mean: bool = False): def symbolic(graph, input_, reduce_mean: bool = False):
output = all_reduce(input_, ParallelMode.PARALLEL_2P5D_COL) output = all_reduce(input_, ParallelMode.PARALLEL_2P5D_COL)
@ -764,12 +842,6 @@ class reduce_by_batch_2p5d(torch.autograd.Function):
@staticmethod @staticmethod
@custom_fwd(cast_inputs=torch.float32) @custom_fwd(cast_inputs=torch.float32)
def forward(ctx, input_, reduce_mean: bool = False): def forward(ctx, input_, reduce_mean: bool = False):
"""
:param input_: input maxtrix
:type input_: torch.tensor
:param reduce_mean: If set to ``True``, it will divide the output by column parallel size, default to False
:type reduce_mean: int, optional
"""
output = all_reduce(input_, ParallelMode.PARALLEL_2P5D_COL) output = all_reduce(input_, ParallelMode.PARALLEL_2P5D_COL)
ctx.reduce_mean = reduce_mean ctx.reduce_mean = reduce_mean
if reduce_mean: if reduce_mean:
@ -785,3 +857,15 @@ class reduce_by_batch_2p5d(torch.autograd.Function):
return output_grad / ctx.reduce_size, None return output_grad / ctx.reduce_size, None
else: else:
return output_grad, None return output_grad, None
def reduce_by_batch_2p5d(input_, reduce_mean: bool = False) -> Tensor:
"""
All-reduce the input from the model parallel region.
:param input_: input matrix
:type input_: torch.tensor
:param reduce_mean: If set to ``True``, it will divide the output by column parallel size, default to False
:type reduce_mean: bool, optional
"""
return _ReduceByBatch2p5D.apply(input_, reduce_mean)


@ -1,13 +1,12 @@
import os
from colossalai.context.parallel_mode import ParallelMode from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc from colossalai.core import global_context as gpc
from colossalai.global_variables import tensor_parallel_env as env
def get_tesseract_dim_dep_from_env(): def get_tesseract_dim_dep_from_env():
try: try:
tesseract_dim = int(os.environ['TESSERACT_DIM']) tesseract_dim = env.tesseract_dim
tesseract_dep = int(os.environ['TESSERACT_DEP']) tesseract_dep = env.tesseract_dep
assert tesseract_dim > 0, 'TESSERACT_DIM must be larger than zero' assert tesseract_dim > 0, 'TESSERACT_DIM must be larger than zero'
assert tesseract_dep > 0, 'TESSERACT_DEP must be larger than zero' assert tesseract_dep > 0, 'TESSERACT_DEP must be larger than zero'
return tesseract_dim, tesseract_dep return tesseract_dim, tesseract_dep


@ -7,16 +7,18 @@ import torch.nn.functional as F
from colossalai.communication import broadcast from colossalai.communication import broadcast
from colossalai.context import ParallelMode, seed from colossalai.context import ParallelMode, seed
from colossalai.core import global_context as gpc from colossalai.core import global_context as gpc
from colossalai.global_variables import tensor_parallel_env as env
from colossalai.nn import init as init from colossalai.nn import init as init
from colossalai.registry import LAYERS from colossalai.registry import LAYERS
from colossalai.utils import get_current_device from colossalai.utils.cuda import get_current_device
from torch import Tensor, dtype from torch import Tensor
from torch.nn import Parameter from torch.nn import Parameter
from ..base_layer import ParallelLayer from ..base_layer import ParallelLayer
from ..utils import (divide, set_tensor_parallel_attribute_by_partition, to_2tuple) from ..utils import divide, set_tensor_parallel_attribute_by_partition, to_2tuple
from ._operation import (Add_Bias_2p5D, Matmul_AB_2p5D, all_gather_weight_2p5d, classifier_2p5d, layernorm_2p5d) from ._operation import (add_bias_2p5d, Matmul_AB_2p5D, Matmul_ABT_2p5D, all_gather_tensor_2p5d, classifier_2p5d,
from ._utils import (assert_tesseract_initialization, get_tesseract_dim_dep_from_env) layernorm_2p5d, reduce_scatter_tensor_2p5d, split_tensor_2p5d)
from ._utils import assert_tesseract_initialization, get_tesseract_dim_dep_from_env
@LAYERS.register_module @LAYERS.register_module
@ -41,7 +43,7 @@ class Linear2p5D(ParallelLayer):
in_features: int, in_features: int,
out_features: int, out_features: int,
bias: bool = True, bias: bool = True,
dtype: dtype = None, dtype: torch.dtype = None,
skip_bias_add: bool = False, skip_bias_add: bool = False,
weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)), weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)),
bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1)): bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1)):
@ -112,17 +114,16 @@ class Linear2p5D(ParallelLayer):
if self.bias is not None: if self.bias is not None:
if self.skip_bias_add: if self.skip_bias_add:
bias = Add_Bias_2p5D.apply(None, self.bias, self.hidden_size_per_partition, self.tesseract_dim, bias = add_bias_2p5d(None, self.bias, self.hidden_size_per_partition, self.tesseract_dim, self.row_rank,
self.row_rank, self.col_rank, self.dep_rank, ParallelMode.PARALLEL_2P5D_COL, self.col_rank, self.dep_rank, ParallelMode.PARALLEL_2P5D_COL, True,
True, self.data_parallel_rank, self.pipeline_parallel_rank, self.data_parallel_rank, self.pipeline_parallel_rank, self.pipeline_parallel_size,
self.pipeline_parallel_size, self.tensor_parallel_size) self.tensor_parallel_size)
return output, bias return output, bias
else: else:
output = Add_Bias_2p5D.apply(output, self.bias, self.hidden_size_per_partition, self.tesseract_dim, output = add_bias_2p5d(output, self.bias, self.hidden_size_per_partition, self.tesseract_dim,
self.row_rank, self.col_rank, self.dep_rank, self.row_rank, self.col_rank, self.dep_rank, ParallelMode.PARALLEL_2P5D_COL,
ParallelMode.PARALLEL_2P5D_COL, False, self.data_parallel_rank, False, self.data_parallel_rank, self.pipeline_parallel_rank,
self.pipeline_parallel_rank, self.pipeline_parallel_size, self.pipeline_parallel_size, self.tensor_parallel_size)
self.tensor_parallel_size)
return output return output
else: else:
return output return output
@ -187,15 +188,15 @@ class LayerNorm2p5D(ParallelLayer):
# this time 1/sqrt(Var_x + epsilon) # this time 1/sqrt(Var_x + epsilon)
Var_x = 1.0 / torch.sqrt(Var_x + self.variance_epsilon) Var_x = 1.0 / torch.sqrt(Var_x + self.variance_epsilon)
output = layernorm_2p5d.apply(x, E_x, Var_x, self.normalized_shape, ParallelMode.PARALLEL_2P5D_ROW) output = layernorm_2p5d(x, E_x, Var_x, self.normalized_shape, ParallelMode.PARALLEL_2P5D_ROW)
bias = Add_Bias_2p5D.apply(None, self.beta, self.partitioned_partition, self.tesseract_dim, self.row_rank, bias = add_bias_2p5d(None, self.beta, self.partitioned_partition, self.tesseract_dim, self.row_rank,
self.col_rank, self.dep_rank, ParallelMode.PARALLEL_2P5D_COL, True, self.col_rank, self.dep_rank, ParallelMode.PARALLEL_2P5D_COL, True,
self.data_parallel_rank, self.pipeline_parallel_rank, self.pipeline_parallel_size, self.data_parallel_rank, self.pipeline_parallel_rank, self.pipeline_parallel_size,
self.tensor_parallel_size) self.tensor_parallel_size)
scale = Add_Bias_2p5D.apply(None, self.gamma, self.partitioned_partition, self.tesseract_dim, self.row_rank, scale = add_bias_2p5d(None, self.gamma, self.partitioned_partition, self.tesseract_dim, self.row_rank,
self.col_rank, self.dep_rank, ParallelMode.PARALLEL_2P5D_COL, True, self.col_rank, self.dep_rank, ParallelMode.PARALLEL_2P5D_COL, True,
self.data_parallel_rank, self.pipeline_parallel_rank, self.pipeline_parallel_size, self.data_parallel_rank, self.pipeline_parallel_rank, self.pipeline_parallel_size,
self.tensor_parallel_size) self.tensor_parallel_size)
output = torch.addcmul(bias, scale, output) output = torch.addcmul(bias, scale, output)
return output return output
@ -229,8 +230,8 @@ class PatchEmbedding2p5D(ParallelLayer):
patch_size: int, patch_size: int,
in_chans: int, in_chans: int,
embed_size: int, embed_size: int,
dtype: dtype = None,
flatten: bool = True, flatten: bool = True,
dtype: torch.dtype = None,
weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)), weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)),
bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1), bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1),
position_embed_initializer: Callable = init.zeros_()): position_embed_initializer: Callable = init.zeros_()):
@ -280,19 +281,21 @@ class PatchEmbedding2p5D(ParallelLayer):
position_embed_initializer(self.pos_embed) position_embed_initializer(self.pos_embed)
def forward(self, input_: Tensor) -> Tensor: def forward(self, input_: Tensor) -> Tensor:
input_ = split_tensor_2p5d(input_, 0)
B, C, H, W = input_.shape B, C, H, W = input_.shape
assert H == self.img_size[0] and W == self.img_size[1], \ assert H == self.img_size[0] and W == self.img_size[1], \
f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})." f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
weight = all_gather_weight_2p5d.apply(self.weight, 0, self.tesseract_dim, ParallelMode.PARALLEL_2P5D_COL) weight = all_gather_tensor_2p5d(self.weight, 0, ParallelMode.PARALLEL_2P5D_COL)
bias = all_gather_weight_2p5d.apply(self.bias, 0, self.tesseract_dim, ParallelMode.PARALLEL_2P5D_COL) bias = all_gather_tensor_2p5d(self.bias, 0, ParallelMode.PARALLEL_2P5D_COL)
output = F.conv2d(input_, weight, bias, stride=self.patch_size) output = F.conv2d(input_, weight, bias, stride=self.patch_size)
if self.flatten: if self.flatten:
output = output.flatten(2).transpose(1, 2) # BCHW -> BNC output = output.flatten(2).transpose(1, 2) # BCHW -> BNC
cls_token = all_gather_weight_2p5d.apply(self.cls_token, -1, self.tesseract_dim, ParallelMode.PARALLEL_2P5D_COL) cls_token = all_gather_tensor_2p5d(self.cls_token, -1, ParallelMode.PARALLEL_2P5D_COL)
pos_embed = all_gather_weight_2p5d.apply(self.pos_embed, -1, self.tesseract_dim, ParallelMode.PARALLEL_2P5D_COL) pos_embed = all_gather_tensor_2p5d(self.pos_embed, -1, ParallelMode.PARALLEL_2P5D_COL)
cls_token = cls_token.expand(output.shape[0], -1, -1) cls_token = cls_token.expand(output.shape[0], -1, -1)
output = torch.cat((cls_token, output), dim=1) output = torch.cat((cls_token, output), dim=1)
output = output + pos_embed output = output + pos_embed
@ -322,7 +325,7 @@ class Embedding2p5D(ParallelLayer):
num_embeddings: int, num_embeddings: int,
embedding_dim: int, embedding_dim: int,
padding_idx: int = None, padding_idx: int = None,
dtype: dtype = None, dtype: torch.dtype = None,
weight_initializer: Callable = init.normal_(), weight_initializer: Callable = init.normal_(),
*args, *args,
**kwargs): **kwargs):
@ -359,13 +362,95 @@ class Embedding2p5D(ParallelLayer):
self.weight[self.padding_idx].fill_(0) self.weight[self.padding_idx].fill_(0)
def forward(self, input_: Tensor) -> Tensor: def forward(self, input_: Tensor) -> Tensor:
weight = all_gather_weight_2p5d.apply(self.weight, -1, self.tesseract_dim, ParallelMode.PARALLEL_2P5D_COL) input_ = split_tensor_2p5d(input_, 0)
weight = all_gather_tensor_2p5d(self.weight, -1, ParallelMode.PARALLEL_2P5D_COL)
output = F.embedding(input_, weight, self.padding_idx, *self.embed_args, **self.embed_kwargs) output = F.embedding(input_, weight, self.padding_idx, *self.embed_args, **self.embed_kwargs)
return output return output
@LAYERS.register_module
class VocabParallelEmbedding2p5D(torch.nn.Module):
"""Embedding parallelized in the vocabulary dimension.
:param num_embeddings: number of embeddings
:type num_embeddings: int
:param embedding_dim: dimension of embedding
:type embedding_dim: int
:param padding_idx: index of padding, defaults to None
:type padding_idx: int, optional
:param dtype: The dtype of parameters, defaults to None
:type dtype: torch.dtype, optional
:param weight_initializer: The initializer of weight, defaults to normal initializer
:type weight_initializer: typing.Callable, optional
:param args: Args used in F.embedding
:param kwargs: Kwargs used in F.embedding
"""
def __init__(self,
num_embeddings: int,
embedding_dim: int,
padding_idx: int = None,
dtype: torch.dtype = None,
weight_initializer: Callable = init.normal_(),
*args,
**kwargs):
super().__init__()
self.num_embeddings = num_embeddings
self.embed_dim = embedding_dim
self.padding_idx = padding_idx
self.embed_args = args
self.embed_kwargs = kwargs
assert_tesseract_initialization()
self.tesseract_dim, self.tesseract_dep = get_tesseract_dim_dep_from_env()
self.num_embeddings_per_partition = divide(self.num_embeddings, self.tesseract_dim)
self.embed_dim_per_partition = divide(self.embed_dim, self.tesseract_dim)
tensor_parallel_rank = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_COL)
self.vocab_start_index = tensor_parallel_rank * self.num_embeddings_per_partition
self.vocab_end_index = self.vocab_start_index + self.num_embeddings_per_partition
self.weight = Parameter(
torch.empty((self.num_embeddings_per_partition, self.embed_dim_per_partition),
device=get_current_device(),
dtype=dtype))
self.reset_parameters(weight_initializer)
self._set_tensor_parallel_attributes()
env.vocab_parallel = True
def _set_tensor_parallel_attributes(self):
set_tensor_parallel_attribute_by_partition(self.weight, self.tesseract_dim**2)
def reset_parameters(self, weight_initializer) -> None:
with seed(ParallelMode.TENSOR):
fan_in, fan_out = self.num_embeddings, self.embed_dim
weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out)
self._fill_padding_idx_with_zero()
def _fill_padding_idx_with_zero(self) -> None:
if self.padding_idx is not None:
with torch.no_grad():
self.weight[self.padding_idx].fill_(0)
def forward(self, input_: Tensor) -> Tensor:
# Build the mask.
input_mask = (input_ < self.vocab_start_index) | (input_ >= self.vocab_end_index)
# Mask the input.
masked_input = input_.clone() - self.vocab_start_index
masked_input[input_mask] = 0
output_parallel = F.embedding(masked_input, self.weight, self.padding_idx, *self.embed_args,
**self.embed_kwargs)
# Mask the output embedding.
output_parallel[input_mask, :] = 0.
# Reduce across all the model parallel GPUs.
output = reduce_scatter_tensor_2p5d(output_parallel, 0, ParallelMode.PARALLEL_2P5D_COL)
return output
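A minimal single-process sketch of the masking trick used in the forward pass above: vocabulary shards are simulated with slices of one weight matrix, and summing the partial lookups stands in for the reduce-scatter/all-reduce over the column group. Plain PyTorch only; all names and sizes here are illustrative, not Colossal-AI API.

# Single-process sketch of vocabulary-parallel embedding lookup.
# Assumption: equal-sized vocab shards; a sum over shards stands in for the collective.
import torch
import torch.nn.functional as F

torch.manual_seed(0)
vocab_size, embed_dim, world_size = 8, 4, 2
full_weight = torch.randn(vocab_size, embed_dim)
tokens = torch.tensor([[1, 5, 7], [0, 3, 6]])

part = vocab_size // world_size
partials = []
for rank in range(world_size):
    start, end = rank * part, (rank + 1) * part
    local_weight = full_weight[start:end]          # this rank's vocab shard
    mask = (tokens < start) | (tokens >= end)      # ids owned by other ranks
    masked = tokens.clone() - start
    masked[mask] = 0                               # any valid local index
    out = F.embedding(masked, local_weight)
    out[mask, :] = 0.                              # zero rows for foreign ids
    partials.append(out)

# Summing the partial outputs plays the role of the reduce across the group.
combined = torch.stack(partials).sum(dim=0)
assert torch.allclose(combined, F.embedding(tokens, full_weight))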
@LAYERS.register_module @LAYERS.register_module
class Classifier2p5D(ParallelLayer): class Classifier2p5D(ParallelLayer):
""" """
@ -391,7 +476,7 @@ class Classifier2p5D(ParallelLayer):
num_classes: int, num_classes: int,
weight: Parameter = None, weight: Parameter = None,
bias: bool = True, bias: bool = True,
dtype: dtype = None, dtype: torch.dtype = None,
weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)), weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)),
bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1)): bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1)):
super().__init__() super().__init__()
@ -442,7 +527,114 @@ class Classifier2p5D(ParallelLayer):
def forward(self, input_: Tensor) -> Tensor: def forward(self, input_: Tensor) -> Tensor:
out_shape = input_.shape[:-1] + (self.num_classes, ) out_shape = input_.shape[:-1] + (self.num_classes, )
return classifier_2p5d.apply(input_, self.weight, self.bias, self.tesseract_dim, out_shape, self.row_rank, return classifier_2p5d(input_, self.weight, self.bias, self.tesseract_dim, out_shape, self.row_rank,
self.col_rank, ParallelMode.PARALLEL_2P5D_ROW, ParallelMode.PARALLEL_2P5D_COL, self.col_rank, ParallelMode.PARALLEL_2P5D_ROW, ParallelMode.PARALLEL_2P5D_COL,
self.data_parallel_rank, self.pipeline_parallel_rank, self.pipeline_parallel_size, self.data_parallel_rank, self.pipeline_parallel_rank, self.pipeline_parallel_size,
self.tensor_parallel_size) self.tensor_parallel_size)
@LAYERS.register_module
class VocabParallelClassifier2p5D(ParallelLayer):
"""
Vocab parallel classifier layer for 2.5D parallelism
:param in_features: size of each input sample
:type in_features: int
:param num_classes: number of classes
:type num_classes: int
:param weight: weight of the classifier, defaults to None
:type weight: torch.nn.Parameter, optional
:param bias: If set to ``False``, the layer will not learn an additive bias, defaults to ``True``
:type bias: bool, optional
:param dtype: The dtype of parameters, defaults to None
:type dtype: torch.dtype, optional
:param weight_initializer: The initializer of weight, defaults to kaiming uniform initializer
:type weight_initializer: typing.Callable, optional
:param bias_initializer: The initializer of bias, defaults to xavier uniform initializer
:type bias_initializer: typing.Callable, optional
"""
def __init__(self,
in_features: int,
num_classes: int,
weight: Parameter = None,
bias: bool = True,
dtype: torch.dtype = None,
weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)),
bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1)):
super().__init__()
self.in_features = in_features
self.num_classes = num_classes
# parallel setting
assert_tesseract_initialization()
self.row_rank = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_COL)
self.col_rank = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_ROW)
self.dep_rank = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_DEP)
self.tesseract_dim, _ = get_tesseract_dim_dep_from_env()
# partitioning dimension
self.input_size_per_partition = divide(in_features, self.tesseract_dim)
self.hidden_size_per_partition = divide(num_classes, self.tesseract_dim)
# create weight, shape: [k/q, h/q]
factory_kwargs = {'device': get_current_device(), 'dtype': dtype}
if weight is not None:
self.weight = weight
self.has_weight = False
else:
self.weight = Parameter(
torch.empty(self.hidden_size_per_partition, self.input_size_per_partition, **factory_kwargs))
self.has_weight = True
# create bias, shape: [h/q]
if bias:
self.bias = Parameter(torch.empty(self.hidden_size_per_partition, **factory_kwargs))
else:
self.bias = None
# initialize parameters
with seed(ParallelMode.TENSOR):
self.reset_parameters(weight_initializer, bias_initializer)
self._set_tensor_parallel_attributes()
env.vocab_parallel = True
def _set_tensor_parallel_attributes(self):
if self.has_weight:
set_tensor_parallel_attribute_by_partition(self.weight, self.tesseract_dim**2)
if self.bias is not None:
set_tensor_parallel_attribute_by_partition(self.bias, self.tesseract_dim)
def reset_parameters(self, weight_initializer, bias_initializer) -> None:
fan_in, fan_out = self.in_features, self.num_classes
if self.has_weight:
weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out)
if self.bias is not None:
bias_initializer(self.bias, fan_in=fan_in)
def forward(self, x: Tensor) -> Tensor:
# input: [m/dq, n/q, k/q]
# output: [m/dq, n/q, h/q]
out_shape = x.shape[:-1] + (self.hidden_size_per_partition, )
output = Matmul_ABT_2p5D.apply(
x,
self.weight,
self.tesseract_dim,
out_shape,
self.row_rank,
self.col_rank,
self.dep_rank,
ParallelMode.PARALLEL_2P5D_ROW,
ParallelMode.PARALLEL_2P5D_COL,
self.data_parallel_rank,
self.pipeline_parallel_rank,
self.pipeline_parallel_size,
self.tensor_parallel_size,
)
if self.bias is not None:
output = add_bias_2p5d(output, self.bias, self.hidden_size_per_partition, self.tesseract_dim, self.row_rank,
self.col_rank, self.dep_rank, ParallelMode.PARALLEL_2P5D_COL, False,
self.data_parallel_rank, self.pipeline_parallel_rank, self.pipeline_parallel_size,
self.tensor_parallel_size)
return output
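As a sanity check of the vocab-dimension split used by this classifier (ignoring the additional 2.5D partitioning of the hidden dimension), the plain-PyTorch sketch below shows that concatenating per-shard logits x @ W_shard^T + b_shard reproduces the unpartitioned projection; sizes and names are made up for illustration.

# Each simulated rank holds a row shard of W and a shard of b.
import torch

torch.manual_seed(0)
batch, hidden, num_classes, world_size = 2, 6, 8, 2
x = torch.randn(batch, hidden)
W = torch.randn(num_classes, hidden)
b = torch.randn(num_classes)

shard = num_classes // world_size
pieces = [
    x @ W[r * shard:(r + 1) * shard].t() + b[r * shard:(r + 1) * shard]
    for r in range(world_size)
]
# Gathering the per-rank logits along the class dimension matches the full projection.
assert torch.allclose(torch.cat(pieces, dim=-1), x @ W.t() + b)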

View File

@ -1,6 +1,8 @@
from ._operation import reduce_by_batch_3d, split_tensor_3d from ._operation import reduce_by_batch_3d, split_batch_3d, split_tensor_3d
from .layers import Classifier3D, Embedding3D, LayerNorm3D, Linear3D, PatchEmbedding3D from .layers import (Classifier3D, Embedding3D, LayerNorm3D, Linear3D, PatchEmbedding3D, VocabParallelClassifier3D,
VocabParallelEmbedding3D)
__all__ = [ __all__ = [
'reduce_by_batch_3d', 'split_tensor_3d', 'Linear3D', 'LayerNorm3D', 'PatchEmbedding3D', 'Classifier3D', 'Embedding3D' 'reduce_by_batch_3d', 'split_tensor_3d', 'split_batch_3d', 'Linear3D', 'LayerNorm3D', 'PatchEmbedding3D',
'Classifier3D', 'Embedding3D', 'VocabParallelEmbedding3D', 'VocabParallelClassifier3D'
] ]

View File

@ -4,36 +4,20 @@
from typing import Optional, Tuple from typing import Optional, Tuple
import torch import torch
from colossalai.communication import all_gather, all_reduce, reduce_scatter, broadcast, reduce from colossalai.communication import (all_gather, all_reduce, broadcast, reduce, reduce_scatter)
from colossalai.context import parallel_mode
from colossalai.context.parallel_mode import ParallelMode from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc from colossalai.core import global_context as gpc
from torch import Tensor from torch import Tensor
from torch.cuda.amp import custom_bwd, custom_fwd from torch.cuda.amp import custom_bwd, custom_fwd
from ._utils import get_parallel_mode_from_env
from colossalai.constants import INPUT_GROUP_3D, WEIGHT_GROUP_3D
from colossalai.nn.layer.base_layer import ParallelLayer
class linear_3d(torch.autograd.Function): class _Linear3D(torch.autograd.Function):
"""
Linear layer for 3D parallelism
:param input_: matrix of input
:type input_: torch.tensor
:param weight: matrix of weight
:type weight: torch.tensor
:param bias: matrix of bias
:type bias: torch.tensor, optional
:param input_parallel_mode: input parallel mode
:type input_parallel_mode: colossalai.context.parallel_mode.ParallelMode
:param weight_parallel_mode: weight parallel mode
:type weight_parallel_mode: colossalai.context.parallel_mode.ParallelMode
:param output_parallel_mode: output parallel mode
:type output_parallel_mode: colossalai.context.parallel_mode.ParallelMode
:param input_dim: dimension of input, defaults to 0
:type input_dim: int, optional
:param weight_dim: dimension of weight, defaults to -1
:type weight_dim: int, optional
:param output_dim: dimension of output, defaults to 0
:type output_dim: int, optional
"""
@staticmethod @staticmethod
@custom_fwd(cast_inputs=torch.float16) @custom_fwd(cast_inputs=torch.float16)
def forward(ctx, def forward(ctx,
@ -87,6 +71,8 @@ class linear_3d(torch.autograd.Function):
bias_grad = torch.sum(output_grad, dim=tuple(range(len(output_grad.shape))[:-1])) bias_grad = torch.sum(output_grad, dim=tuple(range(len(output_grad.shape))[:-1]))
bias_grad, op = all_reduce(bias_grad, ctx.weight_parallel_mode, async_op=True) bias_grad, op = all_reduce(bias_grad, ctx.weight_parallel_mode, async_op=True)
async_ops.append(op) async_ops.append(op)
else:
bias_grad = None
for op in async_ops: for op in async_ops:
if op is not None: if op is not None:
@ -95,9 +81,17 @@ class linear_3d(torch.autograd.Function):
return input_grad, weight_grad, bias_grad, None, None, None, None, None, None return input_grad, weight_grad, bias_grad, None, None, None, None, None, None
class classifier_3d(torch.autograd.Function): def linear_3d(input_: Tensor,
weight: Tensor,
bias: Optional[Tensor],
input_parallel_mode: ParallelMode,
weight_parallel_mode: ParallelMode,
output_parallel_mode: ParallelMode,
input_dim: int = 0,
weight_dim: int = -1,
output_dim: int = 0) -> Tensor:
""" """
Classifier Linear layer for 3D parallelism
:param input_: matrix of input :param input_: matrix of input
:type input_: torch.tensor :type input_: torch.tensor
@ -111,7 +105,19 @@ class classifier_3d(torch.autograd.Function):
:type weight_parallel_mode: colossalai.context.parallel_mode.ParallelMode :type weight_parallel_mode: colossalai.context.parallel_mode.ParallelMode
:param output_parallel_mode: output parallel mode :param output_parallel_mode: output parallel mode
:type output_parallel_mode: colossalai.context.parallel_mode.ParallelMode :type output_parallel_mode: colossalai.context.parallel_mode.ParallelMode
:param input_dim: dimension of input, defaults to 0
:type input_dim: int, optional
:param weight_dim: dimension of weight, defaults to -1
:type weight_dim: int, optional
:param output_dim: dimension of output, defaults to 0
:type output_dim: int, optional
""" """
return _Linear3D.apply(input_, weight, bias, input_parallel_mode, weight_parallel_mode, output_parallel_mode,
input_dim, weight_dim, output_dim)
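The pattern above, a private torch.autograd.Function exposed through a small functional wrapper that carries the docstring, recurs for every op in this file. A toy, self-contained version of the same pattern (names invented for illustration):

import torch


class _ScaleBy(torch.autograd.Function):

    @staticmethod
    def forward(ctx, input_, factor):
        ctx.factor = factor
        return input_ * factor

    @staticmethod
    def backward(ctx, grad_output):
        # one gradient per forward argument; None for the non-tensor factor
        return grad_output * ctx.factor, None


def scale_by(input_, factor):
    """Functional wrapper that hides the .apply call and hosts the docstring."""
    return _ScaleBy.apply(input_, factor)


x = torch.ones(3, requires_grad=True)
scale_by(x, 3.0).sum().backward()
assert torch.equal(x.grad, torch.full((3,), 3.0))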
class _Classifier3D(torch.autograd.Function):
@staticmethod @staticmethod
@custom_fwd(cast_inputs=torch.float16) @custom_fwd(cast_inputs=torch.float16)
def forward(ctx, input_: Tensor, weight: Tensor, bias: Optional[Tensor], input_parallel_mode: ParallelMode, def forward(ctx, input_: Tensor, weight: Tensor, bias: Optional[Tensor], input_parallel_mode: ParallelMode,
@ -156,6 +162,8 @@ class classifier_3d(torch.autograd.Function):
bias_grad = all_reduce(bias_grad, ctx.input_parallel_mode) bias_grad = all_reduce(bias_grad, ctx.input_parallel_mode)
bias_grad, op = all_reduce(bias_grad, ctx.weight_parallel_mode, async_op=True) bias_grad, op = all_reduce(bias_grad, ctx.weight_parallel_mode, async_op=True)
async_ops.append(op) async_ops.append(op)
else:
bias_grad = None
input_grad = torch.matmul(output_grad, weight) input_grad = torch.matmul(output_grad, weight)
@ -166,23 +174,17 @@ class classifier_3d(torch.autograd.Function):
return input_grad, weight_grad, bias_grad, None, None, None, None, None, None return input_grad, weight_grad, bias_grad, None, None, None, None, None, None
class layernorm_3d(torch.autograd.Function): def classifier_3d(input_: Tensor, weight: Tensor, bias: Optional[Tensor], input_parallel_mode: ParallelMode,
weight_parallel_mode: ParallelMode, output_parallel_mode: ParallelMode) -> Tensor:
""" """
Layernorm 3D parallel classifier
:param input_: input maxtrix :param input_: matrix of input
:type input_: torch.tensor :type input_: torch.tensor
:param weight: matrix of weight :param weight: matrix of weight
:type weight: torch.tensor :type weight: torch.tensor
:param bias: matrix of bias :param bias: matrix of bias
:type bias: torch.tensor :type bias: torch.tensor, optional
:param normalized_shape: input shape from an expected input
of size. :math:`[* \times \text{normalized_shape}[0] \times \text{normalized_shape}[1] \times \ldots \times \text{normalized_shape}[-1]]`
If a single integer is used, it is treated as a singleton list, and this module will
normalize over the last dimension which is expected to be of that specific size.
:type normalized_shape: int
:param eps: a value added to the denominator for numerical stability
:type eps: float
:param input_parallel_mode: input parallel mode :param input_parallel_mode: input parallel mode
:type input_parallel_mode: colossalai.context.parallel_mode.ParallelMode :type input_parallel_mode: colossalai.context.parallel_mode.ParallelMode
:param weight_parallel_mode: weight parallel mode :param weight_parallel_mode: weight parallel mode
@ -190,6 +192,11 @@ class layernorm_3d(torch.autograd.Function):
:param output_parallel_mode: output parallel mode :param output_parallel_mode: output parallel mode
:type output_parallel_mode: colossalai.context.parallel_mode.ParallelMode :type output_parallel_mode: colossalai.context.parallel_mode.ParallelMode
""" """
return _Classifier3D.apply(input_, weight, bias, input_parallel_mode, weight_parallel_mode, output_parallel_mode)
class _Layernorm3D(torch.autograd.Function):
@staticmethod @staticmethod
@custom_fwd(cast_inputs=torch.float32) @custom_fwd(cast_inputs=torch.float32)
def forward(ctx, input_: Tensor, weight: Tensor, bias: Tensor, normalized_shape: int, eps: float, def forward(ctx, input_: Tensor, weight: Tensor, bias: Tensor, normalized_shape: int, eps: float,
@ -236,27 +243,78 @@ class layernorm_3d(torch.autograd.Function):
return input_grad, weight_grad, bias_grad, None, None, None, None, None return input_grad, weight_grad, bias_grad, None, None, None, None, None
def split_tensor_3d(input_: Tensor, def layernorm_3d(input_: Tensor, weight: Tensor, bias: Tensor, normalized_shape: int, eps: float,
dim: int = 0, input_parallel_mode: ParallelMode, weight_parallel_mode: ParallelMode,
input_parallel_mode: ParallelMode = ParallelMode.PARALLEL_3D_INPUT, output_parallel_mode: ParallelMode) -> Tensor:
weight_parallel_mode: ParallelMode = ParallelMode.PARALLEL_3D_WEIGHT) -> Tensor: """
"""Splits 3D tensor in specified dimension 3D parallel Layernorm
:param input_: input matrix
:type input_: torch.tensor
:param weight: matrix of weight
:type weight: torch.tensor
:param bias: matrix of bias
:type bias: torch.tensor
:param normalized_shape: input shape from an expected input
of size. :math:`[* \times \text{normalized_shape}[0] \times \text{normalized_shape}[1] \times \ldots \times \text{normalized_shape}[-1]]`
If a single integer is used, it is treated as a singleton list, and this module will
normalize over the last dimension which is expected to be of that specific size.
:type normalized_shape: int
:param eps: a value added to the denominator for numerical stability
:type eps: float
:param input_parallel_mode: input parallel mode
:type input_parallel_mode: colossalai.context.parallel_mode.ParallelMode
:param weight_parallel_mode: weight parallel mode
:type weight_parallel_mode: colossalai.context.parallel_mode.ParallelMode
:param output_parallel_mode: output parallel mode
:type output_parallel_mode: colossalai.context.parallel_mode.ParallelMode
"""
return _Layernorm3D.apply(input_, weight, bias, normalized_shape, eps, input_parallel_mode, weight_parallel_mode,
output_parallel_mode)
def split_tensor_3d(tensor: Tensor, dim: int, parallel_mode: ParallelMode) -> Tensor:
"""Splits 3D parallel tensor in specified dimension
:param tensor: Input tensor
:param dim: Specified dimension in which to split
:param parallel_mode: Parallel mode
:type tensor: torch.Tensor
:type dim: int
:type parallel_mode: colossalai.context.parallel_mode.ParallelMode
:return output: Split tensor
:rtype output: torch.Tensor
"""
if tensor.size(dim) <= 1:
return tensor
output = torch.chunk(tensor, gpc.get_world_size(parallel_mode),
dim=dim)[gpc.get_local_rank(parallel_mode)].contiguous()
return output
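In plain PyTorch the split amounts to torch.chunk along the chosen dimension followed by selecting the local rank's chunk; a tiny single-process sketch with the world size and ranks simulated:

import torch

tensor = torch.arange(12).reshape(4, 3)
world_size, dim = 2, 0
chunks = torch.chunk(tensor, world_size, dim=dim)
for rank in range(world_size):
    local = chunks[rank].contiguous()   # what gpc.get_local_rank(...) would select
    print(rank, local.shape)            # each simulated rank keeps torch.Size([2, 3])
# Concatenating the chunks recovers the original tensor.
assert torch.equal(torch.cat(chunks, dim=dim), tensor)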
def split_batch_3d(input_: Tensor,
dim: int = 0,
input_parallel_mode: ParallelMode = ParallelMode.PARALLEL_3D_INPUT,
weight_parallel_mode: ParallelMode = ParallelMode.PARALLEL_3D_WEIGHT) -> Tensor:
"""Splits 3D tensor in batch
:param input_: Input tensor :param input_: Input tensor
:param dim: Specified dimension in which to split :param dim: Specified dimension in which to split
:param input_parallel_mode: Input parallel mode :param input_parallel_mode: Input parallel mode
:param weight_parallel_mode: Weight parallel mode :param weight_parallel_mode: Weight parallel mode
:type input_: torch.Tensor :type input_: torch.Tensor
:type dim: int, optional :type dim: int, optional
:type input_parallel_mode: colossalai.context.parallel_mode.ParallelMode, optional :type input_parallel_mode: colossalai.context.parallel_mode.ParallelMode, optional
:type weight_parallel_mode: colossalai.context.parallel_mode.ParallelMode, optional :type weight_parallel_mode: colossalai.context.parallel_mode.ParallelMode, optional
:return output: Split tensor :return output: Split tensor
:rtype output: torch.Tensor :rtype output: torch.Tensor
""" """
if input_.size(dim) <= 1: if input_.size(dim) <= 1:
return input_ return input_
weight_parallel_mode = get_parallel_mode_from_env(WEIGHT_GROUP_3D)
input_parallel_mode = get_parallel_mode_from_env(INPUT_GROUP_3D)
output = torch.chunk(input_, gpc.get_world_size(weight_parallel_mode), output = torch.chunk(input_, gpc.get_world_size(weight_parallel_mode),
dim=dim)[gpc.get_local_rank(weight_parallel_mode)].contiguous() dim=dim)[gpc.get_local_rank(weight_parallel_mode)].contiguous()
output = torch.chunk(output, gpc.get_world_size(input_parallel_mode), output = torch.chunk(output, gpc.get_world_size(input_parallel_mode),
@ -264,9 +322,77 @@ def split_tensor_3d(input_: Tensor,
return output return output
class reduce_by_batch_3d(torch.autograd.Function): class _ReduceTensor3D(torch.autograd.Function):
"""All-reduce the input from the model parallel region.
@staticmethod
def forward(ctx, input_, parallel_mode):
return all_reduce(input_, parallel_mode)
@staticmethod
def backward(ctx, output_grad):
return output_grad, None
def reduce_tensor_3d(tensor: Tensor, parallel_mode: ParallelMode) -> Tensor:
""" """
All-reduce the input.
:param tensor: Input tensor
:param parallel_mode: Parallel mode
"""
return _ReduceTensor3D.apply(tensor, parallel_mode)
class _ReduceGrad3D(torch.autograd.Function):
@staticmethod
def forward(ctx, input_, parallel_mode):
ctx.parallel_mode = parallel_mode
return input_
@staticmethod
def backward(ctx, output_grad):
input_grad = all_reduce(output_grad, ctx.parallel_mode)
return input_grad, None
def reduce_grad_3d(tensor: Tensor, parallel_mode: ParallelMode) -> Tensor:
"""
All-reduce the gradient in backward pass.
:param tensor: Input tensor
:param parallel_mode: Parallel mode
"""
return _ReduceGrad3D.apply(tensor, parallel_mode)
class _ReduceScatterTensor3D(torch.autograd.Function):
@staticmethod
def forward(ctx, input_, dim, parallel_mode):
ctx.dim = dim
ctx.parallel_mode = parallel_mode
return reduce_scatter(input_, dim, parallel_mode)
@staticmethod
def backward(ctx, output_grad):
input_grad = all_gather(output_grad, ctx.dim, ctx.parallel_mode)
return input_grad, None, None
def reduce_scatter_tensor_3d(tensor: Tensor, dim: int, parallel_mode: ParallelMode) -> Tensor:
"""
Reduce-scatter the input.
:param tensor: Input tensor
:param dim: Dimension to scatter
:param parallel_mode: Parallel mode
"""
return _ReduceScatterTensor3D.apply(tensor, dim, parallel_mode)
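A single-process sketch of the relationship that makes _ReduceScatterTensor3D pair a reduce-scatter forward with an all-gather backward: reduce-scatter is "all-reduce, then keep the local chunk", so gathering the per-rank chunks recovers the all-reduced tensor. Ranks are simulated by a Python list; no process groups involved.

import torch

world_size, dim = 2, 0
per_rank = [torch.randn(4, 3) for _ in range(world_size)]   # one tensor per simulated rank

# reduce_scatter: rank r receives the sum of every rank's r-th chunk
rs_out = [
    sum(torch.chunk(t, world_size, dim=dim)[r] for t in per_rank)
    for r in range(world_size)
]
# equivalently: all_reduce first, then keep the local chunk
all_reduced = torch.stack(per_rank).sum(dim=0)
for r in range(world_size):
    assert torch.allclose(rs_out[r], torch.chunk(all_reduced, world_size, dim=dim)[r])
# gathering the reduce-scatter outputs reproduces the all-reduce result,
# which is the identity the backward pass relies on
assert torch.allclose(torch.cat(rs_out, dim=dim), all_reduced)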
class _ReduceByBatch3D(torch.autograd.Function):
@staticmethod @staticmethod
@custom_fwd(cast_inputs=torch.float32) @custom_fwd(cast_inputs=torch.float32)
def forward(ctx, def forward(ctx,
@ -274,16 +400,6 @@ class reduce_by_batch_3d(torch.autograd.Function):
input_parallel_mode: ParallelMode, input_parallel_mode: ParallelMode,
weight_parallel_mode: ParallelMode, weight_parallel_mode: ParallelMode,
reduce_mean: bool = False) -> Tensor: reduce_mean: bool = False) -> Tensor:
"""
:param input_: input maxtrix
:type input_: torch.tensor
:param input_parallel_mode: input parallel mode
:type input_parallel_mode: colossalai.context.parallel_mode.ParallelMode
:param weight_parallel_mode: weight parallel mode
:type weight_parallel_mode: colossalai.context.parallel_mode.ParallelMode
:param reduce_mean: If set to ``True``, it will divide the output by (input parallel size * weight parallel size), default to False
:type reduce_mean: int, optional
"""
output = all_reduce(input_, input_parallel_mode) output = all_reduce(input_, input_parallel_mode)
output = all_reduce(output, weight_parallel_mode) output = all_reduce(output, weight_parallel_mode)
ctx.reduce_mean = reduce_mean ctx.reduce_mean = reduce_mean
@ -302,7 +418,26 @@ class reduce_by_batch_3d(torch.autograd.Function):
return output_grad, None, None, None return output_grad, None, None, None
class broadcast_weight_3d_from_diagonal(torch.autograd.Function): def reduce_by_batch_3d(tensor: Tensor,
input_parallel_mode: ParallelMode,
weight_parallel_mode: ParallelMode,
reduce_mean: bool = False) -> Tensor:
"""
All-reduce the input from the model parallel region.
:param tensor: input matrix
:type tensor: torch.tensor
:param input_parallel_mode: input parallel mode
:type input_parallel_mode: colossalai.context.parallel_mode.ParallelMode
:param weight_parallel_mode: weight parallel mode
:type weight_parallel_mode: colossalai.context.parallel_mode.ParallelMode
:param reduce_mean: If set to ``True``, it will divide the output by (input parallel size * weight parallel size), defaults to ``False``
:type reduce_mean: bool, optional
"""
return _ReduceByBatch3D.apply(tensor, input_parallel_mode, weight_parallel_mode, reduce_mean)
class _BroadcastWeight3D_FromDiagonal(torch.autograd.Function):
""" """
broadcast weight from diagonal broadcast weight from diagonal
@ -315,6 +450,7 @@ class broadcast_weight_3d_from_diagonal(torch.autograd.Function):
:param weight_parallel_mode: weight parallel mode :param weight_parallel_mode: weight parallel mode
:type weight_parallel_mode: colossalai.context.parallel_mode.ParallelMode :type weight_parallel_mode: colossalai.context.parallel_mode.ParallelMode
""" """
@staticmethod @staticmethod
@custom_fwd(cast_inputs=torch.float16) @custom_fwd(cast_inputs=torch.float16)
def forward(ctx, input_: Tensor, input_parallel_mode: ParallelMode, weight_parallel_mode: ParallelMode, def forward(ctx, input_: Tensor, input_parallel_mode: ParallelMode, weight_parallel_mode: ParallelMode,
@ -337,3 +473,9 @@ class broadcast_weight_3d_from_diagonal(torch.autograd.Function):
else: else:
input_grad = None input_grad = None
return input_grad, None, None, None return input_grad, None, None, None
def broadcast_weight_3d_from_diagonal(tensor: Tensor, input_parallel_mode: ParallelMode,
weight_parallel_mode: ParallelMode, output_parallel_mode: ParallelMode) -> Tensor:
return _BroadcastWeight3D_FromDiagonal.apply(tensor, input_parallel_mode, weight_parallel_mode,
output_parallel_mode)

View File

@ -1,31 +1,25 @@
#!/usr/bin/env python from colossalai.constants import INPUT_GROUP_3D, WEIGHT_GROUP_3D, OUTPUT_GROUP_3D
# -*- encoding: utf-8 -*-
import os
from colossalai.constants import (DEPTH_3D, INPUT_GROUP_3D, OUTPUT_GROUP_3D,
WEIGHT_GROUP_3D)
from colossalai.context.parallel_mode import ParallelMode from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc from colossalai.core import global_context as gpc
from colossalai.global_variables import tensor_parallel_env as env
from torch import Tensor from torch import Tensor
def get_depth_from_env() -> int: def get_depth_from_env() -> int:
try: try:
depth = os.environ[DEPTH_3D] depth = env.depth_3d
depth = int(depth)
assert depth > 0, 'DEPTH must be greater than zero' assert depth > 0, 'DEPTH must be greater than zero'
return depth return depth
except KeyError as e: except KeyError as e:
raise EnvironmentError( raise EnvironmentError('DEPTH is not found in the current environment, '
'DEPTH is not found in the current environment, ' 'please make sure that you have used the correct process group initializer')
'please make sure that you have used the correct process group initializer'
)
def get_parallel_mode_from_env(group): def get_parallel_mode_from_env(group):
return getattr(ParallelMode, os.environ[group]) assert group in [INPUT_GROUP_3D, WEIGHT_GROUP_3D, OUTPUT_GROUP_3D], \
f'{group} is not valid for 3D tensor parallelism.'
return getattr(env, group)
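A rough stand-in for the os.environ-to-global-object lookup above, with types.SimpleNamespace playing the role of colossalai.global_variables.tensor_parallel_env; the attribute names and string values below are assumptions for illustration only.

from types import SimpleNamespace

# assumed constants: the group keys are the attribute names on the env object
INPUT_GROUP_3D, WEIGHT_GROUP_3D, OUTPUT_GROUP_3D = (
    'input_group_3d', 'weight_group_3d', 'output_group_3d')
env = SimpleNamespace(depth_3d=2,
                      input_group_3d='PARALLEL_3D_INPUT',
                      weight_group_3d='PARALLEL_3D_WEIGHT',
                      output_group_3d='PARALLEL_3D_OUTPUT')

def get_parallel_mode_from_env(group):
    assert group in [INPUT_GROUP_3D, WEIGHT_GROUP_3D, OUTPUT_GROUP_3D]
    return getattr(env, group)      # attribute access instead of os.environ[group]

print(get_parallel_mode_from_env(INPUT_GROUP_3D))   # PARALLEL_3D_INPUT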
def get_last_group(a, b): def get_last_group(a, b):
@ -35,8 +29,7 @@ def get_last_group(a, b):
ParallelMode.PARALLEL_3D_OUTPUT: 'C', ParallelMode.PARALLEL_3D_OUTPUT: 'C',
} }
res = chr( res = chr(ord('A') + ord('B') + ord('C') - ord(mapping[a]) - ord(mapping[b]))
ord('A') + ord('B') + ord('C') - ord(mapping[a]) - ord(mapping[b]))
if res == 'A': if res == 'A':
return ParallelMode.PARALLEL_3D_INPUT return ParallelMode.PARALLEL_3D_INPUT
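The chr/ord arithmetic in get_last_group simply recovers the third letter of {A, B, C} given the other two, and that letter maps back to the remaining parallel mode; for example:

# Given two of the letters A/B/C, the code-point sum identifies the third one.
mapping = {'input': 'A', 'weight': 'B', 'output': 'C'}   # illustrative stand-ins
a, b = mapping['input'], mapping['output']
third = chr(ord('A') + ord('B') + ord('C') - ord(a) - ord(b))
assert third == 'B'   # i.e. the weight group is the remaining one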
@ -47,8 +40,7 @@ def get_last_group(a, b):
def swap_in_out_group(): def swap_in_out_group():
os.environ[INPUT_GROUP_3D], os.environ[OUTPUT_GROUP_3D] = \ env.input_group_3d, env.output_group_3d = env.output_group_3d, env.input_group_3d
os.environ[OUTPUT_GROUP_3D], os.environ[INPUT_GROUP_3D]
def dbg_check_shape(tensor: Tensor, shape: tuple): def dbg_check_shape(tensor: Tensor, shape: tuple):

View File

@ -1,5 +1,3 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
import math import math
from typing import Callable from typing import Callable
@ -10,11 +8,12 @@ from colossalai.communication import all_reduce, broadcast
from colossalai.constants import INPUT_GROUP_3D, WEIGHT_GROUP_3D from colossalai.constants import INPUT_GROUP_3D, WEIGHT_GROUP_3D
from colossalai.context import ParallelMode, seed from colossalai.context import ParallelMode, seed
from colossalai.core import global_context as gpc from colossalai.core import global_context as gpc
from colossalai.global_variables import tensor_parallel_env as env
from colossalai.nn import init as init from colossalai.nn import init as init
from colossalai.nn.layer.base_layer import ParallelLayer from colossalai.nn.layer.base_layer import ParallelLayer
from colossalai.registry import LAYERS from colossalai.registry import LAYERS
from colossalai.utils import get_current_device from colossalai.utils.cuda import get_current_device
from torch import Tensor, dtype from torch import Tensor
from torch.nn import Parameter from torch.nn import Parameter
from ..utils import divide, set_tensor_parallel_attribute_by_partition, to_2tuple from ..utils import divide, set_tensor_parallel_attribute_by_partition, to_2tuple
@ -37,7 +36,8 @@ class LayerNorm3D(ParallelLayer):
:param dtype: The dtype of parameters, defaults to None :param dtype: The dtype of parameters, defaults to None
:type dtype: torch.dtype, optional :type dtype: torch.dtype, optional
""" """
def __init__(self, normalized_shape: int, eps: float = 1e-12, dtype: dtype = None):
def __init__(self, normalized_shape: int, eps: float = 1e-12, dtype=None):
super().__init__() super().__init__()
self.input_parallel_mode = get_parallel_mode_from_env(INPUT_GROUP_3D) self.input_parallel_mode = get_parallel_mode_from_env(INPUT_GROUP_3D)
self.weight_parallel_mode = get_parallel_mode_from_env(WEIGHT_GROUP_3D) self.weight_parallel_mode = get_parallel_mode_from_env(WEIGHT_GROUP_3D)
@ -62,8 +62,8 @@ class LayerNorm3D(ParallelLayer):
init.ones_()(self.weight) init.ones_()(self.weight)
def forward(self, input_: Tensor) -> Tensor: def forward(self, input_: Tensor) -> Tensor:
return layernorm_3d.apply(input_, self.weight, self.bias, self.normalized_shape, self.variance_epsilon, return layernorm_3d(input_, self.weight, self.bias, self.normalized_shape, self.variance_epsilon,
self.input_parallel_mode, self.weight_parallel_mode, self.output_parallel_mode) self.input_parallel_mode, self.weight_parallel_mode, self.output_parallel_mode)
@LAYERS.register_module @LAYERS.register_module
@ -84,11 +84,12 @@ class Linear3D(ParallelLayer):
:param bias_initializer: The initializer of bias, defaults to xavier uniform initializer :param bias_initializer: The initializer of bias, defaults to xavier uniform initializer
:type bias_initializer: typing.Callable, optional :type bias_initializer: typing.Callable, optional
""" """
def __init__(self, def __init__(self,
in_features: int, in_features: int,
out_features: int, out_features: int,
bias: bool = True, bias: bool = True,
dtype: dtype = None, dtype: torch.dtype = None,
weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)), weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)),
bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1)): bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1)):
super().__init__() super().__init__()
@ -136,8 +137,8 @@ class Linear3D(ParallelLayer):
broadcast(self.bias, output_src_rank, self.output_parallel_mode) broadcast(self.bias, output_src_rank, self.output_parallel_mode)
def forward(self, input_: Tensor) -> Tensor: def forward(self, input_: Tensor) -> Tensor:
return linear_3d.apply(input_, self.weight, self.bias, self.input_parallel_mode, self.weight_parallel_mode, return linear_3d(input_, self.weight, self.bias, self.input_parallel_mode, self.weight_parallel_mode,
self.output_parallel_mode) self.output_parallel_mode)
@LAYERS.register_module @LAYERS.register_module
@ -160,12 +161,13 @@ class Classifier3D(ParallelLayer):
:param bias_initializer: The initializer of bias, defaults to xavier uniform initializer :param bias_initializer: The initializer of bias, defaults to xavier uniform initializer
:type bias_initializer: typing.Callable, optional :type bias_initializer: typing.Callable, optional
""" """
def __init__(self, def __init__(self,
in_features: int, in_features: int,
num_classes: int, num_classes: int,
weight: Parameter = None, weight: Parameter = None,
bias: bool = True, bias: bool = True,
dtype: dtype = None, dtype: torch.dtype = None,
weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)), weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)),
bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1)): bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1)):
super().__init__() super().__init__()
@ -214,8 +216,94 @@ class Classifier3D(ParallelLayer):
broadcast(self.bias, input_src_rank, self.input_parallel_mode) broadcast(self.bias, input_src_rank, self.input_parallel_mode)
def forward(self, input_: Tensor) -> Tensor: def forward(self, input_: Tensor) -> Tensor:
return classifier_3d.apply(input_, self.weight, self.bias, self.input_parallel_mode, self.weight_parallel_mode, return classifier_3d(input_, self.weight, self.bias, self.input_parallel_mode, self.weight_parallel_mode,
self.output_parallel_mode) self.output_parallel_mode)
@LAYERS.register_module
class VocabParallelClassifier3D(ParallelLayer):
"""
Vocab parallel classifier layer for 3D parallelism
:param in_features: size of each input sample
:type in_features: int
:param num_classes: number of classes
:type num_classes: int
:param weight: weight of the classifier, defaults to None
:type weight: torch.nn.Parameter, optional
:param bias: If set to ``False``, the layer will not learn an additive bias, defaults to ``True``
:type bias: bool, optional
:param dtype: The dtype of parameters, defaults to None
:type dtype: torch.dtype, optional
:param weight_initializer: The initializer of weight, defaults to kaiming uniform initializer
:type weight_initializer: typing.Callable, optional
:param bias_initializer: The initializer of bias, defaults to xavier uniform initializer
:type bias_initializer: typing.Callable, optional
"""
def __init__(self,
in_features: int,
num_classes: int,
weight: Parameter = None,
bias: bool = True,
dtype: torch.dtype = None,
weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)),
bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1)):
super().__init__()
self.in_features = in_features
self.num_classes = num_classes
self.input_parallel_mode = get_parallel_mode_from_env(INPUT_GROUP_3D)
self.weight_parallel_mode = get_parallel_mode_from_env(WEIGHT_GROUP_3D)
self.output_parallel_mode = get_last_group(self.input_parallel_mode, self.weight_parallel_mode)
self.depth = get_depth_from_env()
self.in_features_per_partition = divide(in_features, self.depth)
self.out_features_per_partition = divide(num_classes, self.depth)
if weight is not None:
self.weight = weight
self.has_weight = False
else:
self.weight = Parameter(
torch.empty(self.out_features_per_partition,
self.in_features_per_partition,
device=get_current_device(),
dtype=dtype))
self.has_weight = True
if bias:
self.bias = Parameter(torch.zeros(self.out_features_per_partition, device=get_current_device(),
dtype=dtype))
else:
self.bias = None
self.reset_parameters(weight_initializer, bias_initializer)
self._set_tensor_parallel_attributes()
swap_in_out_group()
env.vocab_parallel = True
def _set_tensor_parallel_attributes(self) -> None:
if self.has_weight:
set_tensor_parallel_attribute_by_partition(self.weight, self.depth**2)
if self.bias is not None:
set_tensor_parallel_attribute_by_partition(self.bias, self.depth)
def reset_parameters(self, weight_initializer, bias_initializer) -> None:
with seed(ParallelMode.TENSOR):
fan_in, fan_out = self.in_features, self.num_classes
weight_src_rank = gpc.get_ranks_in_group(self.weight_parallel_mode)[0]
output_src_rank = gpc.get_ranks_in_group(self.output_parallel_mode)[0]
if self.has_weight:
weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out)
broadcast(self.weight, weight_src_rank, self.weight_parallel_mode)
if self.bias is not None:
bias_initializer(self.bias, fan_in=fan_in)
broadcast(self.bias, weight_src_rank, self.weight_parallel_mode)
broadcast(self.bias, output_src_rank, self.output_parallel_mode)
def forward(self, input_: Tensor) -> Tensor:
return linear_3d(input_, self.weight.transpose(0, 1), self.bias, self.input_parallel_mode,
self.weight_parallel_mode, self.output_parallel_mode)
@LAYERS.register_module @LAYERS.register_module
@ -242,13 +330,14 @@ class PatchEmbedding3D(ParallelLayer):
:param position_embed_initializer: The initializer of position embedding, defaults to zero :param position_embed_initializer: The initializer of position embedding, defaults to zero
:type position_embed_initializer: typing.Callable, optional :type position_embed_initializer: typing.Callable, optional
""" """
def __init__(self, def __init__(self,
img_size: int, img_size: int,
patch_size: int, patch_size: int,
in_chans: int, in_chans: int,
embed_size: int, embed_size: int,
dtype: dtype = None,
flatten: bool = True, flatten: bool = True,
dtype: torch.dtype = None,
weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)), weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)),
bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1), bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1),
position_embed_initializer: Callable = init.zeros_()): position_embed_initializer: Callable = init.zeros_()):
@ -284,8 +373,8 @@ class PatchEmbedding3D(ParallelLayer):
set_tensor_parallel_attribute_by_partition(self.cls_token, self.depth) set_tensor_parallel_attribute_by_partition(self.cls_token, self.depth)
set_tensor_parallel_attribute_by_partition(self.pos_embed, self.depth) set_tensor_parallel_attribute_by_partition(self.pos_embed, self.depth)
def _sync_grad_hook(self, grad) -> None: def _sync_grad_hook(self, grad) -> Tensor:
grad = all_reduce(grad, self.input_parallel_mode) grad = all_reduce(grad.clone(), self.input_parallel_mode)
grad = all_reduce(grad, self.weight_parallel_mode) grad = all_reduce(grad, self.weight_parallel_mode)
return grad return grad
@ -302,17 +391,19 @@ class PatchEmbedding3D(ParallelLayer):
broadcast(self.weight, weight_src_rank, self.weight_parallel_mode) broadcast(self.weight, weight_src_rank, self.weight_parallel_mode)
broadcast(self.bias, weight_src_rank, self.weight_parallel_mode) broadcast(self.bias, weight_src_rank, self.weight_parallel_mode)
broadcast(self.pos_embed, weight_src_rank, self.weight_parallel_mode) broadcast(self.pos_embed, weight_src_rank, self.weight_parallel_mode)
broadcast(self.weight, input_src_rank, self.input_parallel_mode)
broadcast(self.bias, input_src_rank, self.input_parallel_mode) broadcast(self.bias, input_src_rank, self.input_parallel_mode)
broadcast(self.pos_embed, input_src_rank, self.input_parallel_mode) broadcast(self.pos_embed, input_src_rank, self.input_parallel_mode)
self.weight.register_hook(self._sync_grad_hook)
self.bias.register_hook(self._sync_grad_hook) self.bias.register_hook(self._sync_grad_hook)
self.cls_token.register_hook(self._sync_grad_hook) self.cls_token.register_hook(self._sync_grad_hook)
self.pos_embed.register_hook(self._sync_grad_hook) self.pos_embed.register_hook(self._sync_grad_hook)
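The hooks registered above rely on Tensor.register_hook replacing a parameter's gradient with whatever the hook returns, which is where the all_reduce over the input and weight groups gets injected. A plain-PyTorch sketch of that mechanism, with a doubling standing in for the collective:

import torch

p = torch.nn.Parameter(torch.ones(3))

def sync_grad_hook(grad):
    # stand-in for all_reduce over the input/weight parallel groups
    return grad.clone() * 2

p.register_hook(sync_grad_hook)
(p * torch.tensor([1., 2., 3.])).sum().backward()
print(p.grad)   # tensor([2., 4., 6.]): the hook's return value became the gradient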
def forward(self, input_: Tensor) -> Tensor: def forward(self, input_: Tensor) -> Tensor:
weight = broadcast_weight_3d_from_diagonal.apply(self.weight, self.input_parallel_mode, input_ = split_tensor_3d(input_, 0, self.weight_parallel_mode)
self.weight_parallel_mode, self.output_parallel_mode) input_ = split_tensor_3d(input_, 0, self.input_parallel_mode)
output = F.conv2d(input_, weight, self.bias, stride=self.patch_size) output = F.conv2d(input_, self.weight, self.bias, stride=self.patch_size)
if self.flatten: if self.flatten:
output = output.flatten(2).transpose(1, 2) # BCHW -> BNC output = output.flatten(2).transpose(1, 2) # BCHW -> BNC
@ -341,11 +432,12 @@ class Embedding3D(ParallelLayer):
:param args: Args used in F.embedding :param args: Args used in F.embedding
:param kwargs: Kwargs used in F.embedding :param kwargs: Kwargs used in F.embedding
""" """
def __init__(self, def __init__(self,
num_embeddings: int, num_embeddings: int,
embedding_dim: int, embedding_dim: int,
padding_idx: int = None, padding_idx: int = None,
dtype: dtype = None, dtype: torch.dtype = None,
weight_initializer: Callable = init.normal_(), weight_initializer: Callable = init.normal_(),
*args, *args,
**kwargs): **kwargs):
@ -385,8 +477,95 @@ class Embedding3D(ParallelLayer):
self.weight[self.padding_idx].fill_(0) self.weight[self.padding_idx].fill_(0)
def forward(self, input_: Tensor) -> Tensor: def forward(self, input_: Tensor) -> Tensor:
weight = broadcast_weight_3d_from_diagonal.apply(self.weight, self.input_parallel_mode, input_ = split_tensor_3d(input_, 0, self.weight_parallel_mode)
self.weight_parallel_mode, self.output_parallel_mode) input_ = split_tensor_3d(input_, 0, self.input_parallel_mode)
weight = broadcast_weight_3d_from_diagonal(self.weight, self.input_parallel_mode, self.weight_parallel_mode,
self.output_parallel_mode)
output = F.embedding(input_, weight, self.padding_idx, *self.embed_args, **self.embed_kwargs) output = F.embedding(input_, weight, self.padding_idx, *self.embed_args, **self.embed_kwargs)
return output return output
@LAYERS.register_module
class VocabParallelEmbedding3D(torch.nn.Module):
"""Embedding parallelized in the vocabulary dimension.
:param num_embeddings: number of embeddings
:type num_embeddings: int
:param embedding_dim: dimension of embedding
:type embedding_dim: int
:param padding_idx: index of padding, defaults to None
:type padding_idx: int, optional
:param dtype: The dtype of parameters, defaults to None
:type dtype: torch.dtype, optional
:param weight_initializer: The initializer of weight, defaults to normal initializer
:type weight_initializer: typing.Callable, optional
:param args: Args used in F.embedding
:param kwargs: Kwargs used in F.embedding
"""
def __init__(self,
num_embeddings: int,
embedding_dim: int,
padding_idx: int = None,
dtype: torch.dtype = None,
weight_initializer: Callable = init.normal_(),
*args,
**kwargs):
super().__init__()
self.num_embeddings = num_embeddings
self.embed_dim = embedding_dim
self.padding_idx = padding_idx
self.embed_args = args
self.embed_kwargs = kwargs
self.depth = get_depth_from_env()
self.input_parallel_mode = get_parallel_mode_from_env(INPUT_GROUP_3D)
self.weight_parallel_mode = get_parallel_mode_from_env(WEIGHT_GROUP_3D)
self.output_parallel_mode = get_last_group(self.input_parallel_mode, self.weight_parallel_mode)
self.num_embeddings_per_partition = divide(self.num_embeddings, self.depth)
self.embed_dim_per_partition = divide(self.embed_dim, self.depth)
vocab_parallel_rank = gpc.get_local_rank(self.input_parallel_mode)
self.vocab_start_index = vocab_parallel_rank * self.num_embeddings_per_partition
self.vocab_end_index = self.vocab_start_index + self.num_embeddings_per_partition
self.weight = Parameter(
torch.empty((self.num_embeddings_per_partition, self.embed_dim_per_partition),
device=get_current_device(),
dtype=dtype))
self.reset_parameters(weight_initializer)
self._set_tensor_parallel_attributes()
env.vocab_parallel = True
def _set_tensor_parallel_attributes(self):
set_tensor_parallel_attribute_by_partition(self.weight, self.depth**2)
def reset_parameters(self, weight_initializer) -> None:
with seed(ParallelMode.TENSOR):
fan_in, fan_out = self.num_embeddings, self.embed_dim
weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out)
self._fill_padding_idx_with_zero()
weight_src_rank = gpc.get_ranks_in_group(self.weight_parallel_mode)[0]
broadcast(self.weight, weight_src_rank, self.weight_parallel_mode)
def _fill_padding_idx_with_zero(self) -> None:
if self.padding_idx is not None:
with torch.no_grad():
self.weight[self.padding_idx].fill_(0)
def forward(self, input_: Tensor) -> Tensor:
input_ = split_tensor_3d(input_, 0, self.weight_parallel_mode)
input_mask = (input_ < self.vocab_start_index) | (input_ >= self.vocab_end_index)
masked_input = input_.clone() - self.vocab_start_index
masked_input[input_mask] = 0
weight = reduce_grad_3d(self.weight, self.weight_parallel_mode)
output_parallel = F.embedding(masked_input, weight, self.padding_idx, *self.embed_args, **self.embed_kwargs)
output_parallel[input_mask, :] = 0.
output = reduce_scatter_tensor_3d(output_parallel, 0, self.input_parallel_mode)
return output

View File

@ -2,12 +2,12 @@
# -*- encoding: utf-8 -*- # -*- encoding: utf-8 -*-
import collections.abc import collections.abc
import os
from itertools import repeat from itertools import repeat
import numpy as np import numpy as np
import torch import torch
from colossalai.constants import (IS_TENSOR_PARALLEL, NUM_PARTITIONS, TENSOR_PARALLEL_MODE) from colossalai.constants import IS_TENSOR_PARALLEL, NUM_PARTITIONS
from colossalai.global_variables import tensor_parallel_env as env
from colossalai.utils import checkpoint from colossalai.utils import checkpoint
from torch import Tensor, nn from torch import Tensor, nn
@ -38,7 +38,7 @@ class CheckpointModule(nn.Module):
def divide(numerator, denominator): def divide(numerator, denominator):
"""Only allow exact division """Only allow exact division
:param numerator: Numerator of the division :param numerator: Numerator of the division
:param denominator: Denominator of the division :param denominator: Denominator of the division
""" """
@ -65,7 +65,7 @@ def set_tensor_parallel_attribute_by_partition(param, num_partitions):
def get_tensor_parallel_mode(): def get_tensor_parallel_mode():
return os.environ[TENSOR_PARALLEL_MODE] return env.mode
# From PyTorch internals # From PyTorch internals

View File

@ -3,14 +3,14 @@ from typing import Callable
import torch import torch
import torch.nn.functional as F import torch.nn.functional as F
from colossalai.context import seed
from colossalai.nn import init as init from colossalai.nn import init as init
from colossalai.registry import LAYERS from colossalai.registry import LAYERS
from colossalai.utils import get_current_device from colossalai.utils.cuda import get_current_device
from torch import Tensor, dtype from torch import Tensor
from torch import nn as nn from torch import nn as nn
from ..utils import to_2tuple from ..utils import to_2tuple
from colossalai.context import seed
def drop_path(x, drop_prob: float = 0., training: bool = False): def drop_path(x, drop_prob: float = 0., training: bool = False):
@ -36,6 +36,7 @@ class DropPath(nn.Module):
Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
Adapted from https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/layers/drop.py Adapted from https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/layers/drop.py
""" """
def __init__(self, drop_prob=None): def __init__(self, drop_prob=None):
super(DropPath, self).__init__() super(DropPath, self).__init__()
self.drop_prob = drop_prob self.drop_prob = drop_prob
@ -47,6 +48,7 @@ class DropPath(nn.Module):
class WrappedDropout(nn.Module): class WrappedDropout(nn.Module):
"""Same as torch.nn.Dropout. But it is wrapped with the context of seed manager. """Same as torch.nn.Dropout. But it is wrapped with the context of seed manager.
""" """
def __init__(self, p: float = 0.5, inplace: bool = False, mode=None): def __init__(self, p: float = 0.5, inplace: bool = False, mode=None):
super().__init__() super().__init__()
if p < 0 or p > 1: if p < 0 or p > 1:
@ -75,6 +77,7 @@ class WrappedDropPath(nn.Module):
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
Here, it is wrapped with the context of seed manager. Here, it is wrapped with the context of seed manager.
""" """
def __init__(self, p: float = 0., mode=None): def __init__(self, p: float = 0., mode=None):
super().__init__() super().__init__()
self.p = p self.p = p
@ -120,13 +123,14 @@ class VanillaPatchEmbedding(nn.Module):
:param position_embed_initializer: The initializer of position embedding, defaults to zero :param position_embed_initializer: The initializer of position embedding, defaults to zero
:type position_embed_initializer: typing.Callable, optional :type position_embed_initializer: typing.Callable, optional
""" """
def __init__(self, def __init__(self,
img_size: int, img_size: int,
patch_size: int, patch_size: int,
in_chans: int, in_chans: int,
embed_size: int, embed_size: int,
dtype: dtype = None,
flatten: bool = True, flatten: bool = True,
dtype: torch.dtype = None,
weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)), weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)),
bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1), bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1),
position_embed_initializer: Callable = init.zeros_()): position_embed_initializer: Callable = init.zeros_()):
@ -142,8 +146,9 @@ class VanillaPatchEmbedding(nn.Module):
self.weight = nn.Parameter( self.weight = nn.Parameter(
torch.empty((embed_size, in_chans, *self.patch_size), device=get_current_device(), dtype=dtype)) torch.empty((embed_size, in_chans, *self.patch_size), device=get_current_device(), dtype=dtype))
self.bias = nn.Parameter(torch.empty(embed_size, device=get_current_device(), dtype=dtype)) self.bias = nn.Parameter(torch.empty(embed_size, device=get_current_device(), dtype=dtype))
self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_size)) self.cls_token = nn.Parameter(torch.zeros((1, 1, embed_size), device=get_current_device(), dtype=dtype))
self.pos_embed = nn.Parameter(torch.zeros(1, self.num_patches + 1, embed_size)) self.pos_embed = nn.Parameter(
torch.zeros((1, self.num_patches + 1, embed_size), device=get_current_device(), dtype=dtype))
self.reset_parameters(weight_initializer, bias_initializer, position_embed_initializer) self.reset_parameters(weight_initializer, bias_initializer, position_embed_initializer)
@ -170,7 +175,7 @@ class VanillaPatchEmbedding(nn.Module):
@LAYERS.register_module @LAYERS.register_module
class VanillaClassifier(nn.Module): class VanillaClassifier(nn.Module):
""" """
Classifier for ViT Dense linear classifier
:param in_features: size of each input sample :param in_features: size of each input sample
:type in_features: int :type in_features: int
@ -187,12 +192,13 @@ class VanillaClassifier(nn.Module):
:param bias_initializer: The initializer of bias, defaults to xavier uniform initializer :param bias_initializer: The initializer of bias, defaults to xavier uniform initializer
:type bias_initializer: typing.Callable, optional :type bias_initializer: typing.Callable, optional
""" """
def __init__(self, def __init__(self,
in_features: int, in_features: int,
num_classes: int, num_classes: int,
weight: nn.Parameter = None, weight: nn.Parameter = None,
bias: bool = True, bias: bool = True,
dtype: dtype = None, dtype: torch.dtype = None,
weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)), weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)),
bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1)): bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1)):
super().__init__() super().__init__()

View File

@ -1,25 +1,37 @@
from colossalai.global_variables import tensor_parallel_env as env
from colossalai.nn.layer.utils import get_tensor_parallel_mode
from torch import nn from torch import nn
from torch.nn.modules.loss import * from torch.nn.modules.loss import *
from torch.nn.modules.loss import _Loss from torch.nn.modules.loss import _Loss
from colossalai.nn.layer.utils import get_tensor_parallel_mode from .loss_1d import VocabParallelCrossEntropyLoss1D
from .loss_2d import CrossEntropyLoss2D from .loss_2d import CrossEntropyLoss2D, VocabParallelCrossEntropyLoss2D
from .loss_2p5d import CrossEntropyLoss2p5D from .loss_2p5d import CrossEntropyLoss2p5D, VocabParallelCrossEntropyLoss2p5D
from .loss_3d import CrossEntropyLoss3D from .loss_3d import CrossEntropyLoss3D, VocabParallelCrossEntropyLoss3D
from .loss_moe import MoeCrossEntropyLoss, MoeLoss from .loss_moe import MoeCrossEntropyLoss, MoeLoss
_parallel_cross_entropy = { _parallel_cross_entropy = {
'2d': CrossEntropyLoss2D, '2d': CrossEntropyLoss2D,
'2.5d': CrossEntropyLoss2p5D, '2.5d': CrossEntropyLoss2p5D,
'3d': CrossEntropyLoss3D '3d': CrossEntropyLoss3D,
}
_vocab_parallel_cross_entropy = {
'1d': VocabParallelCrossEntropyLoss1D,
'2d': VocabParallelCrossEntropyLoss2D,
'2.5d': VocabParallelCrossEntropyLoss2p5D,
'3d': VocabParallelCrossEntropyLoss3D,
} }
class CrossEntropyLoss(_Loss): class CrossEntropyLoss(_Loss):
def __init__(self, reduction: bool = True, *args, **kwargs): def __init__(self, reduction: bool = True, *args, **kwargs):
super().__init__() super().__init__()
tensor_parallel = get_tensor_parallel_mode() tensor_parallel = get_tensor_parallel_mode()
if tensor_parallel in ['None', '1d']: if tensor_parallel is not None and env.vocab_parallel:
self.loss = _vocab_parallel_cross_entropy[tensor_parallel](reduction=reduction, *args, **kwargs)
elif tensor_parallel is None or tensor_parallel == '1d':
reduction = 'mean' if reduction else 'none' reduction = 'mean' if reduction else 'none'
self.loss = nn.CrossEntropyLoss(reduction=reduction, *args, **kwargs) self.loss = nn.CrossEntropyLoss(reduction=reduction, *args, **kwargs)
else: else:

View File

@ -0,0 +1,110 @@
import torch
from colossalai.context import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.registry import LOSSES
from torch.cuda.amp import custom_bwd, custom_fwd
from torch.nn.modules.loss import _Loss
class _VocabParallelCrossEntropy1D(torch.autograd.Function):
@staticmethod
@custom_fwd(cast_inputs=torch.float32)
def forward(ctx, vocab_parallel_logits, targets):
# Maximum value along vocab dimension across all GPUs.
logits_max = torch.max(vocab_parallel_logits, dim=-1)[0]
torch.distributed.all_reduce(logits_max,
op=torch.distributed.ReduceOp.MAX,
group=gpc.get_group(ParallelMode.PARALLEL_1D))
# Subtract the maximum value.
vocab_parallel_logits.sub_(logits_max.unsqueeze(dim=-1))
# Get the partition's vocab indices
partition_vocab_size = vocab_parallel_logits.size()[-1]
rank = gpc.get_local_rank(ParallelMode.PARALLEL_1D)
vocab_start_index = partition_vocab_size * rank
vocab_end_index = vocab_start_index + partition_vocab_size
# Create a mask of valid vocab ids (1 means it needs to be masked).
target_mask = (targets < vocab_start_index) | (targets >= vocab_end_index)
masked_target = targets.clone() - vocab_start_index
masked_target[target_mask] = 0
# Get predicted-logits = logits[target].
# For Simplicity, we convert logits to a 2-D tensor with size
# [*, partition-vocab-size] and target to a 1-D tensor of size [*].
logits_2d = vocab_parallel_logits.view(-1, partition_vocab_size)
masked_target_1d = masked_target.view(-1)
arange_1d = torch.arange(start=0, end=logits_2d.size()[0], device=logits_2d.device)
predicted_logits_1d = logits_2d[arange_1d, masked_target_1d]
predicted_logits_1d = predicted_logits_1d.clone().contiguous()
predicted_logits = predicted_logits_1d.view_as(targets)
predicted_logits[target_mask] = 0.0
# All reduce is needed to get the chunks from other GPUs.
torch.distributed.all_reduce(predicted_logits,
op=torch.distributed.ReduceOp.SUM,
group=gpc.get_group(ParallelMode.PARALLEL_1D))
# Sum of exponential of logits along vocab dimension across all GPUs.
exp_logits = vocab_parallel_logits
torch.exp(vocab_parallel_logits, out=exp_logits)
sum_exp_logits = exp_logits.sum(dim=-1)
torch.distributed.all_reduce(sum_exp_logits,
op=torch.distributed.ReduceOp.SUM,
group=gpc.get_group(ParallelMode.PARALLEL_1D))
# Loss = log(sum(exp(logits))) - predicted-logit.
loss = torch.log(sum_exp_logits) - predicted_logits
# Store softmax, target-mask and masked-target for backward pass.
exp_logits.div_(sum_exp_logits.unsqueeze(dim=-1))
ctx.save_for_backward(exp_logits, target_mask, masked_target_1d)
return loss
@staticmethod
@custom_bwd
def backward(ctx, grad_output):
# Retrieve tensors from the forward path.
softmax, target_mask, masked_target_1d = ctx.saved_tensors
# All the inputs have softmax as their gradient.
grad_input = softmax
# For simplicity, work with the 2D gradient.
partition_vocab_size = softmax.size()[-1]
grad_2d = grad_input.view(-1, partition_vocab_size)
# Add the gradient from matching classes.
arange_1d = torch.arange(start=0, end=grad_2d.size()[0], device=grad_2d.device)
grad_2d[arange_1d, masked_target_1d] -= (1.0 - target_mask.view(-1).float())
# Finally elementwise multiplication with the output gradients.
grad_input.mul_(grad_output.unsqueeze(dim=-1))
return grad_input, None
@LOSSES.register_module
class VocabParallelCrossEntropyLoss1D(_Loss):
"""
Vocab parallel cross entropy loss for 1D parallelism
:param reduction: whether to average the loss, defaults to True
:type reduction: bool, optional
"""
def __init__(self, reduction=True):
super().__init__()
self.reduction_mean = reduction
def forward(self, logits, targets):
"""Calculate loss between logits and targets
:param logits: Output logits of model
:param targets: True targets from data
"""
loss = _VocabParallelCrossEntropy1D.apply(logits, targets)
if self.reduction_mean:
loss = loss.mean()
return loss
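For reference, the quantity assembled by the three all-reduces above is the ordinary cross entropy over the full (unsharded) vocabulary V; writing z_{i,v} for the logits of sample i and t_i for its target, the forward and backward passes compute (a restatement of the code, not an additional definition):
m_i = \max_{v \in V} z_{i,v}, \qquad
\mathcal{L}_i = \log \sum_{v \in V} e^{\,z_{i,v} - m_i} - \bigl(z_{i,t_i} - m_i\bigr), \qquad
\frac{\partial \mathcal{L}_i}{\partial z_{i,v}} = \mathrm{softmax}(z_i)_v - \mathbf{1}[v = t_i]
Each rank holds only a contiguous slice of V, so the maximum, the target logit and the exp-sum are each all-reduced over ParallelMode.PARALLEL_1D before the loss is formed.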

View File

@ -1,6 +1,12 @@
from colossalai.nn.layer.parallel_2d import reduce_by_batch_2d import torch
import torch.distributed as dist
from colossalai.context import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.nn.layer.parallel_2d import reduce_by_batch_2d, split_tensor_2d
from colossalai.nn.layer.parallel_2d._utils import assert_summa_initialization from colossalai.nn.layer.parallel_2d._utils import assert_summa_initialization
from colossalai.registry import LOSSES from colossalai.registry import LOSSES
from colossalai.utils import get_current_device
from torch.cuda.amp import custom_bwd, custom_fwd
from torch.nn.functional import cross_entropy from torch.nn.functional import cross_entropy
from torch.nn.modules.loss import _Loss from torch.nn.modules.loss import _Loss
@ -16,6 +22,7 @@ class CrossEntropyLoss2D(_Loss):
:type reduction: bool, optional :type reduction: bool, optional
""" """
def __init__(self, reduction=True, *args, **kwargs): def __init__(self, reduction=True, *args, **kwargs):
super().__init__() super().__init__()
assert_summa_initialization() assert_summa_initialization()
@ -29,8 +36,110 @@ class CrossEntropyLoss2D(_Loss):
:param logits: Output logits of model :param logits: Output logits of model
:param targets: True targets from data :param targets: True targets from data
""" """
targets = split_tensor_2d(targets)
loss = cross_entropy(logits, targets, reduction='none', *self.loss_args, **self.loss_kwargs) loss = cross_entropy(logits, targets, reduction='none', *self.loss_args, **self.loss_kwargs)
if self.reduction_mean: if self.reduction_mean:
loss = loss.mean() loss = loss.mean()
loss = reduce_by_batch_2d.apply(loss, True) loss = reduce_by_batch_2d(loss, True)
return loss
class _VocabParallelCrossEntropy2D(torch.autograd.Function):
### Modified based on megatron.mpu.cross_entropy ###
@staticmethod
@custom_fwd(cast_inputs=torch.float32)
def forward(ctx, logits, targets):
# logits: [b/q, h/q]
# labels: [b/q]
# loss: [b/q]
# vocab_parallel_logits: [b/q, s, v/q]
# target: [b/q, s]
logits_max = torch.max(logits, dim=-1)[0]
torch.distributed.all_reduce(logits_max,
op=torch.distributed.ReduceOp.MAX,
group=gpc.get_group(ParallelMode.PARALLEL_2D_ROW))
# Subtract the maximum value.
# vocab_parallel_logits.sub_(logits_max.unsqueeze(dim=-1))
logits = logits - logits_max.unsqueeze(dim=-1)
vocab_size = logits.size(-1)
rank = gpc.get_local_rank(ParallelMode.PARALLEL_2D_ROW)
vocab_start = rank * (vocab_size)
vocab_end = (rank + 1) * (vocab_size) - 1
target_mask = (targets < vocab_start) | (targets > vocab_end)
masked_target = targets.clone() - vocab_start
masked_target[target_mask] = 0
arange_1d = torch.arange(
start=0,
end=logits.size()[0],
)
predicted_logits = logits[arange_1d, masked_target]
predicted_logits[target_mask] = 0.
dist.all_reduce(predicted_logits, group=gpc.get_group(ParallelMode.PARALLEL_2D_ROW))
exp_logits = torch.exp(logits)
sum_exp_logits = exp_logits.sum(dim=1)
dist.all_reduce(sum_exp_logits, group=gpc.get_group(ParallelMode.PARALLEL_2D_ROW))
loss = torch.log(sum_exp_logits) - predicted_logits
exp_logits.div_(sum_exp_logits.unsqueeze(dim=-1))
ctx.save_for_backward(exp_logits, target_mask, masked_target)
return loss
@staticmethod
@custom_bwd
def backward(ctx, output_grad):
# Retrieve tensors from the forward path.
softmax, target_mask, masked_target = ctx.saved_tensors
# All the inputs have softmax as their gradient.
grad_input = softmax
# For simplicity, work with the 2D gradient.
partition_vocab_size = softmax.size()[-1]
grad_2d = grad_input.view(-1, partition_vocab_size)
# Add the gradient from matching classes.
arange_1d = torch.arange(start=0, end=grad_2d.size()[0], device=get_current_device())
grad_2d[arange_1d, masked_target] -= (1.0 - target_mask.view(-1).float())
# Finally elementwise multiplication with the output gradients.
grad_input.mul_(output_grad.unsqueeze(dim=-1))
return grad_input, None
@LOSSES.register_module
class VocabParallelCrossEntropyLoss2D(_Loss):
"""
Vocab parallel cross entropy loss for 2D parallelism
:param reduction: whether to average the loss, defaults to True
:type reduction: bool, optional
"""
def __init__(self, reduction=True):
super().__init__()
self.reduction_mean = reduction
def forward(self, logits, targets):
"""Calculate loss between logits and targets
:param logits: Output logits of model
:param targets: True targets from data
"""
targets = split_tensor_2d(targets)
loss = _VocabParallelCrossEntropy2D.apply(
logits,
targets,
)
if self.reduction_mean:
loss = loss.mean()
loss = reduce_by_batch_2d(loss, True)
return loss return loss

View File

@ -1,6 +1,12 @@
from colossalai.nn.layer.parallel_2p5d import reduce_by_batch_2p5d import torch
import torch.distributed as dist
from colossalai.context import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.nn.layer.parallel_2p5d import reduce_by_batch_2p5d, split_tensor_2p5d
from colossalai.nn.layer.parallel_2p5d._utils import assert_tesseract_initialization from colossalai.nn.layer.parallel_2p5d._utils import assert_tesseract_initialization
from colossalai.registry import LOSSES from colossalai.registry import LOSSES
from colossalai.utils import get_current_device
from torch.cuda.amp import custom_bwd, custom_fwd
from torch.nn.functional import cross_entropy from torch.nn.functional import cross_entropy
from torch.nn.modules.loss import _Loss from torch.nn.modules.loss import _Loss
@ -9,7 +15,7 @@ from torch.nn.modules.loss import _Loss
class CrossEntropyLoss2p5D(_Loss): class CrossEntropyLoss2p5D(_Loss):
""" """
Cross entropy loss for 2.5D parallelism Cross entropy loss for 2.5D parallelism
:param reduction: whether to average the loss, defaults to True :param reduction: whether to average the loss, defaults to True
:param args: Args for loss function :param args: Args for loss function
:param kwargs: Kwargs for loss function :param kwargs: Kwargs for loss function
@ -29,8 +35,104 @@ class CrossEntropyLoss2p5D(_Loss):
:param logits: Output logits of model :param logits: Output logits of model
:param targets: True targets from data :param targets: True targets from data
""" """
targets = split_tensor_2p5d(targets)
loss = cross_entropy(logits, targets, reduction='none', *self.loss_args, **self.loss_kwargs) loss = cross_entropy(logits, targets, reduction='none', *self.loss_args, **self.loss_kwargs)
if self.reduction_mean: if self.reduction_mean:
loss = loss.mean() loss = loss.mean()
loss = reduce_by_batch_2p5d.apply(loss, True) loss = reduce_by_batch_2p5d(loss, True)
return loss
class _VocabParallelCrossEntropy2p5D(torch.autograd.Function):
### Modified based on megatron.mpu.cross_entropy ###
@staticmethod
@custom_fwd(cast_inputs=torch.float32)
def forward(ctx, logits, targets):
# logits: [b/dq, h/q]
# loss: [b/dq]
# targets: [b/dq]
logits_max = torch.max(logits, dim=-1)[0]
torch.distributed.all_reduce(logits_max,
op=torch.distributed.ReduceOp.MAX,
group=gpc.get_group(ParallelMode.PARALLEL_2P5D_ROW))
# Subtract the maximum value.
logits = logits - logits_max.unsqueeze(dim=-1)
vocab_size = logits.size(-1)
rank = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_ROW)
vocab_start = rank * (vocab_size)
vocab_end = (rank + 1) * (vocab_size) - 1
target_mask = (targets < vocab_start) | (targets > vocab_end)
masked_target = targets.clone() - vocab_start
masked_target[target_mask] = 0
arange_1d = torch.arange(
start=0,
end=logits.size()[0],
)
predicted_logits = logits[arange_1d, masked_target]
predicted_logits[target_mask] = 0.
dist.all_reduce(predicted_logits, group=gpc.get_group(ParallelMode.PARALLEL_2P5D_ROW))
exp_logits = torch.exp(logits)
sum_exp_logits = exp_logits.sum(dim=1)
dist.all_reduce(sum_exp_logits, group=gpc.get_group(ParallelMode.PARALLEL_2P5D_ROW))
loss = torch.log(sum_exp_logits) - predicted_logits
exp_logits.div_(sum_exp_logits.unsqueeze(dim=-1))
ctx.save_for_backward(exp_logits, target_mask, masked_target)
return loss
@staticmethod
@custom_bwd
def backward(ctx, output_grad):
# Retrieve tensors from the forward path.
softmax, target_mask, masked_target = ctx.saved_tensors
# All the inputs have softmax as their gradient.
grad_input = softmax
# For simplicity, work with the 2D gradient.
partition_vocab_size = softmax.size()[-1]
grad_2d = grad_input.view(-1, partition_vocab_size)
# Add the gradient from matching classes.
arange_1d = torch.arange(start=0, end=grad_2d.size()[0], device=get_current_device())
grad_2d[arange_1d, masked_target] -= (1.0 - target_mask.view(-1).float())
# Finally elementwise multiplication with the output gradients.
grad_input.mul_(output_grad.unsqueeze(dim=-1))
return grad_input, None
@LOSSES.register_module
class VocabParallelCrossEntropyLoss2p5D(_Loss):
"""
Vocab parallel cross entropy loss for 2.5D parallelism
:param reduction: whether to average the loss, defaults to True
:type reduction: bool, optional
"""
def __init__(self, reduction=True):
super().__init__()
self.reduction_mean = reduction
def forward(self, logits, targets):
"""Calculate loss between logits and targets
:param logits: Output logits of model
:param targets: True targets from data
"""
targets = split_tensor_2p5d(targets)
loss = _VocabParallelCrossEntropy2p5D.apply(logits, targets)
if self.reduction_mean:
loss = loss.mean()
loss = reduce_by_batch_2p5d(loss, True)
return loss return loss

View File

@ -1,23 +1,28 @@
from colossalai.constants import INPUT_GROUP_3D, WEIGHT_GROUP_3D import torch
from colossalai.nn.layer.parallel_3d import reduce_by_batch_3d import torch.distributed as dist
from colossalai.constants import INPUT_GROUP_3D, WEIGHT_GROUP_3D, OUTPUT_GROUP_3D
from colossalai.core import global_context as gpc
from colossalai.nn.layer.parallel_3d import reduce_by_batch_3d, split_tensor_3d
from colossalai.nn.layer.parallel_3d._utils import get_parallel_mode_from_env from colossalai.nn.layer.parallel_3d._utils import get_parallel_mode_from_env
from colossalai.registry import LOSSES from colossalai.registry import LOSSES
from colossalai.utils import get_current_device
from torch.cuda.amp import custom_bwd, custom_fwd
from torch.nn.functional import cross_entropy from torch.nn.functional import cross_entropy
from torch.nn.modules.loss import _Loss from torch.nn.modules.loss import _Loss
@LOSSES.register_module @LOSSES.register_module
class CrossEntropyLoss3D(_Loss): class CrossEntropyLoss3D(_Loss):
""" """
Cross entropy loss for 3D parallelism Cross entropy loss for 3D parallelism
:param depth: depth for 3D parallelism
:type depth: int
:param reduction: whether to average the loss, defaults to True :param reduction: whether to average the loss, defaults to True
:type reduction: bool, optional
:param args: Args for loss function :param args: Args for loss function
:param kwargs: Kwargs for loss function :param kwargs: Kwargs for loss function
:type reduction: bool, optional
""" """
def __init__(self, reduction=True, *args, **kwargs): def __init__(self, reduction=True, *args, **kwargs):
super().__init__() super().__init__()
self.input_parallel_mode = get_parallel_mode_from_env(INPUT_GROUP_3D) self.input_parallel_mode = get_parallel_mode_from_env(INPUT_GROUP_3D)
@ -32,8 +37,103 @@ class CrossEntropyLoss3D(_Loss):
:param logits: Output logits of model :param logits: Output logits of model
:param targets: True targets from data :param targets: True targets from data
""" """
targets = split_tensor_3d(targets, 0, self.weight_parallel_mode)
targets = split_tensor_3d(targets, 0, self.input_parallel_mode)
loss = cross_entropy(logits, targets, reduction='none', *self.loss_args, **self.loss_kwargs) loss = cross_entropy(logits, targets, reduction='none', *self.loss_args, **self.loss_kwargs)
if self.reduction_mean: if self.reduction_mean:
loss = loss.mean() loss = loss.mean()
loss = reduce_by_batch_3d.apply(loss, self.input_parallel_mode, self.weight_parallel_mode, True) loss = reduce_by_batch_3d(loss, self.input_parallel_mode, self.weight_parallel_mode, True)
return loss
class _VocabParallelCrossEntropy3D(torch.autograd.Function):
# Adapted from megatron.mpu.cross_entropy
# loss[i] = -logits[i][targets] + log(sum(exp(logits[i])))
@staticmethod
@custom_fwd(cast_inputs=torch.float32)
def forward(ctx, logits, targets, output_parallel_mode):
# logits: [b/q^2, c/q]
# labels: [b/q^2]
# loss: [b/q^2]
logits_max = torch.max(logits, dim=-1)[0]
dist.all_reduce(logits_max, op=torch.distributed.ReduceOp.MAX, group=gpc.get_group(output_parallel_mode))
# Subtract the maximum value.
logits = logits - logits_max.unsqueeze(dim=-1)
vocab_size_per_partition = logits.size()[-1]
rank = gpc.get_local_rank(output_parallel_mode)
vocab_start = rank * vocab_size_per_partition
vocab_end = (rank + 1) * vocab_size_per_partition - 1
# loss[i] = 0 if targets[i] < vocab_start or targets[i] > vocab_end
target_mask = (targets < vocab_start) | (targets > vocab_end)
masked_target = targets.clone() - vocab_start
masked_target[target_mask] = 0
arange_1d = torch.arange(start=0, end=logits.size()[0], device=get_current_device())
predicted_logits = logits[arange_1d, masked_target]
predicted_logits = predicted_logits.clone().contiguous().view_as(targets)
predicted_logits[target_mask] = 0.
dist.all_reduce(predicted_logits, group=gpc.get_group(output_parallel_mode))
# Loss = log(sum(exp(logits))) - predicted-logit.
exp_logits = torch.exp(logits)
sum_exp_logits = exp_logits.sum(dim=-1)
dist.all_reduce(sum_exp_logits, group=gpc.get_group(output_parallel_mode))
loss = torch.log(sum_exp_logits) - predicted_logits
exp_logits.div_(sum_exp_logits.unsqueeze(dim=-1))
ctx.save_for_backward(exp_logits, target_mask, masked_target)
return loss
@staticmethod
@custom_bwd
def backward(ctx, output_grad):
# Retrieve tensors from the forward path.
softmax, target_mask, masked_target = ctx.saved_tensors
# All the inputs have softmax as their gradient.
input_grad = softmax
# For simplicity, work with the 2D gradient.
partition_vocab_size = softmax.size()[-1]
grad_2d = input_grad.view(-1, partition_vocab_size)
# Add the gradient from matching classes.
arange_1d = torch.arange(start=0, end=grad_2d.size()[0], device=get_current_device())
grad_2d[arange_1d, masked_target] -= (1.0 - target_mask.view(-1).float())
input_grad.mul_(output_grad.unsqueeze(dim=-1))
return input_grad, None, None, None
@LOSSES.register_module
class VocabParallelCrossEntropyLoss3D(_Loss):
"""
Vocab parallel cross entropy loss for 3D parallelism
:param reduction: whether to average the loss, defaults to True
:type reduction: bool, optional
"""
def __init__(self, reduction=True):
super().__init__()
self.input_parallel_mode = get_parallel_mode_from_env(INPUT_GROUP_3D)
self.weight_parallel_mode = get_parallel_mode_from_env(WEIGHT_GROUP_3D)
self.output_parallel_mode = get_parallel_mode_from_env(OUTPUT_GROUP_3D)
self.reduction_mean = reduction
def forward(self, logits, targets):
"""Calculate loss between logits and targets
:param logits: Output logits of model
:param targets: True targets from data
"""
targets = split_tensor_3d(targets, 0, self.weight_parallel_mode)
targets = split_tensor_3d(targets, 0, self.input_parallel_mode)
loss = _VocabParallelCrossEntropy3D.apply(logits, targets, self.output_parallel_mode)
if self.reduction_mean:
loss = loss.mean()
loss = reduce_by_batch_3d(loss, self.input_parallel_mode, self.weight_parallel_mode, True)
return loss return loss

View File

@ -17,7 +17,7 @@ class Accuracy(nn.Module):
def __init__(self): def __init__(self):
super().__init__() super().__init__()
tensor_parallel = get_tensor_parallel_mode() tensor_parallel = get_tensor_parallel_mode()
if tensor_parallel in ['None', '1d']: if tensor_parallel not in _parallel_accuracy:
self.acc = calc_acc self.acc = calc_acc
else: else:
self.acc = _parallel_accuracy[tensor_parallel]() self.acc = _parallel_accuracy[tensor_parallel]()

View File

@ -1,5 +1,5 @@
import torch import torch
from colossalai.nn.layer.parallel_2d import reduce_by_batch_2d from colossalai.nn.layer.parallel_2d import reduce_by_batch_2d, split_tensor_2d
from torch import nn from torch import nn
from ._utils import calc_acc from ._utils import calc_acc
@ -18,6 +18,7 @@ class Accuracy2D(nn.Module):
:param targets: True labels from data :param targets: True labels from data
""" """
with torch.no_grad(): with torch.no_grad():
targets = split_tensor_2d(targets)
correct = calc_acc(logits, targets) correct = calc_acc(logits, targets)
correct = reduce_by_batch_2d.apply(correct) correct = reduce_by_batch_2d(correct)
return correct return correct

View File

@ -18,6 +18,7 @@ class Accuracy2p5D(nn.Module):
:param targets: True labels from data :param targets: True labels from data
""" """
with torch.no_grad(): with torch.no_grad():
targets = split_tensor_2p5d(targets)
correct = calc_acc(logits, targets) correct = calc_acc(logits, targets)
correct = reduce_by_batch_2p5d.apply(correct) correct = reduce_by_batch_2p5d(correct)
return correct return correct

View File

@ -1,6 +1,6 @@
import torch import torch
from colossalai.constants import INPUT_GROUP_3D, WEIGHT_GROUP_3D from colossalai.constants import INPUT_GROUP_3D, WEIGHT_GROUP_3D
from colossalai.nn.layer.parallel_3d import reduce_by_batch_3d from colossalai.nn.layer.parallel_3d import reduce_by_batch_3d, split_tensor_3d
from colossalai.nn.layer.parallel_3d._utils import get_parallel_mode_from_env from colossalai.nn.layer.parallel_3d._utils import get_parallel_mode_from_env
from torch import nn from torch import nn
@ -22,6 +22,8 @@ class Accuracy3D(nn.Module):
:param targets: True labels from data :param targets: True labels from data
""" """
with torch.no_grad(): with torch.no_grad():
targets = split_tensor_3d(targets, 0, self.weight_parallel_mode)
targets = split_tensor_3d(targets, 0, self.input_parallel_mode)
correct = calc_acc(logits, targets) correct = calc_acc(logits, targets)
correct = reduce_by_batch_3d.apply(correct, self.input_parallel_mode, self.weight_parallel_mode) correct = reduce_by_batch_3d(correct, self.input_parallel_mode, self.weight_parallel_mode)
return correct return correct

View File

@ -224,7 +224,7 @@ class LogTimingByEpochHook(LogByEpochHook):
super().__init__(logger=logger, interval=interval, priority=priority) super().__init__(logger=logger, interval=interval, priority=priority)
self._timer = timer self._timer = timer
self._log_eval = log_eval self._log_eval = log_eval
self._is_rank_to_log = is_dp_rank_0() and is_tp_rank_0() self._is_rank_to_log = is_dp_rank_0() and is_tp_rank_0() and is_no_pp_or_last_stage()
# extra handling to avoid the unstable readings of the first # extra handling to avoid the unstable readings of the first
# few training steps to affect the history mean time # few training steps to affect the history mean time
@ -256,7 +256,7 @@ class LogTimingByEpochHook(LogByEpochHook):
""" """
if self._is_epoch_to_log(trainer) and self._is_rank_to_log: if self._is_epoch_to_log(trainer) and self._is_rank_to_log:
msg = self._get_message('Train') msg = self._get_message('Train')
self.logger.info(f'[Epoch {trainer.cur_epoch} / Train]: {msg}, #steps/epoch = {trainer.steps_per_epoch}') self.logger.info(f'[Epoch {trainer.cur_epoch} / Train]: {msg} | #steps/epoch = {trainer.steps_per_epoch}')
def after_test_epoch(self, trainer): def after_test_epoch(self, trainer):
"""Writes log after finishing a testing epoch. """Writes log after finishing a testing epoch.

View File

@ -317,24 +317,29 @@ class ThroughputMetric(Metric):
:param epoch_only: epoch only :param epoch_only: epoch only
:type epoch_only: bool :type epoch_only: bool
""" """
def __init__(self, epoch_only: bool): def __init__(self, epoch_only: bool, ignored_steps: int = 0):
super().__init__(epoch_only=epoch_only) super().__init__(epoch_only=epoch_only)
self.ignored_steps = ignored_steps
self.cur_steps = 0
self.accumulated_num_samples = torch.zeros(1, device=get_current_device()) self.accumulated_num_samples = torch.zeros(1, device=get_current_device())
self.accumulated_used_time = torch.zeros(1, device=get_current_device()) self.accumulated_used_time = torch.zeros(1, device=get_current_device())
self.last_step_num_samples = torch.zeros(1, device=get_current_device()) self.last_step_num_samples = torch.zeros(1, device=get_current_device())
self.last_step_used_time = torch.zeros(1, device=get_current_device()) self.last_step_used_time = torch.zeros(1, device=get_current_device())
def reset(self) -> None: def reset(self) -> None:
# self.cur_steps = 0
self.accumulated_num_samples.zero_() self.accumulated_num_samples.zero_()
self.accumulated_used_time.zero_() self.accumulated_used_time.zero_()
self.last_step_num_samples.zero_() self.last_step_num_samples.zero_()
self.last_step_used_time.zero_() self.last_step_used_time.zero_()
def update(self, num_samples, time) -> None: def update(self, num_samples, time) -> None:
self.cur_steps += 1
self.last_step_num_samples.fill_(num_samples) self.last_step_num_samples.fill_(num_samples)
self.last_step_used_time.fill_(time) self.last_step_used_time.fill_(time)
self.accumulated_num_samples += self.last_step_num_samples if self.cur_steps >= self.ignored_steps:
self.accumulated_used_time += self.last_step_used_time self.accumulated_num_samples += self.last_step_num_samples
self.accumulated_used_time += self.last_step_used_time
def get_last_step_value(self): def get_last_step_value(self):
self.last_step_used_time = all_reduce(self.last_step_used_time, ParallelMode.DATA) / \ self.last_step_used_time = all_reduce(self.last_step_used_time, ParallelMode.DATA) / \
@ -360,13 +365,14 @@ class ThroughputHook(MetricHook):
:param priority: priority of throughput hook, defaults to 10 :param priority: priority of throughput hook, defaults to 10
:type priority: int, optional :type priority: int, optional
""" """
def __init__(self, priority: int = 10): def __init__(self, ignored_steps: int = 0, priority: int = 10):
super().__init__(priority) super().__init__(priority)
self.ignored_steps = ignored_steps
def after_hook_is_attached(self, trainer): def after_hook_is_attached(self, trainer):
self._check_metric_states_initialization(trainer) self._check_metric_states_initialization(trainer)
if self._is_stage_to_compute: if self._is_stage_to_compute:
self.metric = ThroughputMetric(epoch_only=True) self.metric = ThroughputMetric(epoch_only=True, ignored_steps=self.ignored_steps)
# register the metric # register the metric
trainer.states['metrics']['train']['Throughput'] = self.metric trainer.states['metrics']['train']['Throughput'] = self.metric
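A hedged usage sketch of the new ignored_steps argument (assuming ThroughputHook is re-exported from colossalai.trainer.hooks and that the trainer itself is wired up elsewhere as usual):
from colossalai.trainer.hooks import ThroughputHook

# Skip the first two (typically unstable) steps when accumulating throughput;
# priority keeps its default of 10.
throughput_hook = ThroughputHook(ignored_steps=2)
# ... include throughput_hook in the hook list passed to trainer.fit(...).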

View File

@ -1,8 +1,9 @@
from .activation_checkpoint import checkpoint from .activation_checkpoint import checkpoint
from .common import (clip_grad_norm_fp32, conditional_context, copy_tensor_parallel_attributes, count_zeros_fp32, from .common import (clip_grad_norm_fp32, conditional_context, copy_tensor_parallel_attributes, count_zeros_fp32,
free_port, is_dp_rank_0, is_model_parallel_parameter, is_no_pp_or_last_stage, is_tp_rank_0, free_port, is_dp_rank_0, is_model_parallel_parameter, is_no_pp_or_last_stage, is_tp_rank_0,
is_using_ddp, is_using_pp, is_using_sequence, multi_tensor_applier, param_is_not_tensor_parallel_duplicate, is_using_ddp, is_using_pp, is_using_sequence, model_branch_context, multi_tensor_applier,
print_rank_0, switch_virtual_pipeline_parallel_rank, sync_model_param) param_is_not_tensor_parallel_duplicate, print_rank_0, switch_virtual_pipeline_parallel_rank,
sync_model_param)
from .cuda import empty_cache, get_current_device, set_to_cuda, synchronize from .cuda import empty_cache, get_current_device, set_to_cuda, synchronize
from .data_sampler import DataParallelSampler, get_dataloader from .data_sampler import DataParallelSampler, get_dataloader
from .gradient_accumulation import accumulate_gradient from .gradient_accumulation import accumulate_gradient
@ -11,9 +12,9 @@ from .timer import MultiTimer, Timer
__all__ = [ __all__ = [
'checkpoint', 'free_port', 'print_rank_0', 'sync_model_param', 'is_dp_rank_0', 'is_tp_rank_0', 'checkpoint', 'free_port', 'print_rank_0', 'sync_model_param', 'is_dp_rank_0', 'is_tp_rank_0',
'is_no_pp_or_last_stage', 'is_using_ddp', 'is_using_pp', 'is_using_sequence', 'conditional_context', 'is_no_pp_or_last_stage', 'is_using_ddp', 'is_using_pp', 'is_using_sequence', 'model_branch_context',
'is_model_parallel_parameter', 'clip_grad_norm_fp32', 'count_zeros_fp32', 'copy_tensor_parallel_attributes', 'conditional_context', 'is_model_parallel_parameter', 'clip_grad_norm_fp32', 'count_zeros_fp32',
'param_is_not_tensor_parallel_duplicate', 'get_current_device', 'synchronize', 'empty_cache', 'set_to_cuda', 'copy_tensor_parallel_attributes', 'param_is_not_tensor_parallel_duplicate', 'get_current_device', 'synchronize',
'report_memory_usage', 'Timer', 'MultiTimer', 'multi_tensor_applier', 'accumulate_gradient', 'DataParallelSampler', 'empty_cache', 'set_to_cuda', 'report_memory_usage', 'Timer', 'MultiTimer', 'multi_tensor_applier',
'get_dataloader', 'switch_virtual_pipeline_parallel_rank' 'accumulate_gradient', 'DataParallelSampler', 'get_dataloader', 'switch_virtual_pipeline_parallel_rank'
] ]

View File

@ -6,8 +6,6 @@ import socket
import torch import torch
from torch._six import inf from torch._six import inf
import colossalai.context.parallel_mode
try: try:
import colossal_C import colossal_C
except: except:
@ -20,6 +18,7 @@ from colossalai.constants import IS_TENSOR_PARALLEL, NUM_PARTITIONS, TENSOR_PARA
from colossalai.context.parallel_mode import ParallelMode from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc from colossalai.core import global_context as gpc
from colossalai.global_variables import moe_env from colossalai.global_variables import moe_env
from colossalai.global_variables import tensor_parallel_env as env
from .multi_tensor_apply import multi_tensor_applier from .multi_tensor_apply import multi_tensor_applier
@ -62,8 +61,7 @@ def sync_model_param(model, parallel_mode):
if gpc.is_initialized(parallel_mode) and gpc.get_world_size(parallel_mode) > 1: if gpc.is_initialized(parallel_mode) and gpc.get_world_size(parallel_mode) > 1:
for param in model.parameters(): for param in model.parameters():
ranks = gpc.get_ranks_in_group(parallel_mode) ranks = gpc.get_ranks_in_group(parallel_mode)
dist.broadcast( dist.broadcast(param, src=ranks[0], group=gpc.get_group(parallel_mode))
param, src=ranks[0], group=gpc.get_group(parallel_mode))
def is_dp_rank_0(): def is_dp_rank_0():
@ -99,6 +97,15 @@ def conditional_context(context_manager, enable=True):
yield yield
class model_branch_context(object):
def __enter__(self):
self.env_status = env.save()
def __exit__(self, *exc_info):
env.load(**self.env_status)
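model_branch_context snapshots the tensor-parallel environment on entry and restores it on exit, so a model branch can tweak env without leaking state into the trunk; a minimal usage sketch (the flag changed inside the block is illustrative only):
from colossalai.global_variables import tensor_parallel_env as env
from colossalai.utils import model_branch_context

with model_branch_context():
    # Temporarily change tensor-parallel settings while building a side branch.
    env.parallel_input_1d = False
    # ... construct branch-specific layers here ...
# On exit, env is restored from the snapshot taken in __enter__.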
def is_model_parallel_parameter(p): def is_model_parallel_parameter(p):
return hasattr(p, IS_TENSOR_PARALLEL) and getattr(p, IS_TENSOR_PARALLEL) return hasattr(p, IS_TENSOR_PARALLEL) and getattr(p, IS_TENSOR_PARALLEL)
@ -124,9 +131,10 @@ def _calc_lp(grads, norm_type):
norm = 0.0 norm = 0.0
for grad in grads: for grad in grads:
grad_norm = torch.norm(grad, norm_type) grad_norm = torch.norm(grad, norm_type)
norm += grad_norm ** norm_type norm += grad_norm**norm_type
return norm return norm
# ======== Gradient Clipping ========= # ======== Gradient Clipping =========
@ -183,7 +191,7 @@ def clip_grad_norm_fp32(parameters, max_norm, norm_type=2):
moe_parallel_grads = [] # used to collect moe tensor parallel gradients moe_parallel_grads = [] # used to collect moe tensor parallel gradients
for p in params: for p in params:
if is_model_parallel_parameter(p): if is_model_parallel_parameter(p):
reductor = (gpc.get_world_size(ParallelMode.TENSOR) / getattr(p, NUM_PARTITIONS)) ** (1 / norm_type) reductor = (gpc.get_world_size(ParallelMode.TENSOR) / getattr(p, NUM_PARTITIONS))**(1 / norm_type)
tensor_parallel_grads.append(p.grad.data / reductor) tensor_parallel_grads.append(p.grad.data / reductor)
elif is_moe_parallel_parameter(p): elif is_moe_parallel_parameter(p):
moe_parallel_grads.append(p.grad.data) moe_parallel_grads.append(p.grad.data)
@ -191,32 +199,24 @@ def clip_grad_norm_fp32(parameters, max_norm, norm_type=2):
no_tensor_parallel_grads.append(p.grad.data) no_tensor_parallel_grads.append(p.grad.data)
if norm_type == 2.0: if norm_type == 2.0:
tensor_parallel_norm = _calc_l2_norm( tensor_parallel_norm = _calc_l2_norm(tensor_parallel_grads)**norm_type
tensor_parallel_grads) ** norm_type no_tensor_parallel_norm = _calc_l2_norm(no_tensor_parallel_grads)**norm_type
no_tensor_parallel_norm = _calc_l2_norm( moe_parallel_norm = _calc_l2_norm(moe_parallel_grads)**norm_type
no_tensor_parallel_grads) ** norm_type
moe_parallel_norm = _calc_l2_norm(
moe_parallel_grads) ** norm_type
else: else:
tensor_parallel_norm = _calc_lp(tensor_parallel_grads, norm_type) tensor_parallel_norm = _calc_lp(tensor_parallel_grads, norm_type)
no_tensor_parallel_norm = _calc_lp( no_tensor_parallel_norm = _calc_lp(no_tensor_parallel_grads, norm_type)
no_tensor_parallel_grads, norm_type)
moe_parallel_norm = _calc_lp(moe_parallel_grads, norm_type) moe_parallel_norm = _calc_lp(moe_parallel_grads, norm_type)
# Sum across all model-parallel GPUs. # Sum across all model-parallel GPUs.
if gpc.is_initialized(ParallelMode.TENSOR) and len(tensor_parallel_grads) > 0: if gpc.is_initialized(ParallelMode.TENSOR) and len(tensor_parallel_grads) > 0:
dist.all_reduce(tensor_parallel_norm, dist.all_reduce(tensor_parallel_norm, op=dist.ReduceOp.SUM, group=gpc.get_group(ParallelMode.TENSOR))
op=dist.ReduceOp.SUM,
group=gpc.get_group(ParallelMode.TENSOR))
# Sum across all moe-tensor-parallel GPUs # Sum across all moe-tensor-parallel GPUs
if len(moe_parallel_grads) > 0: if len(moe_parallel_grads) > 0:
dist.all_reduce(moe_parallel_norm, group=gpc.get_group(ParallelMode.MOE_MODEL)) dist.all_reduce(moe_parallel_norm, group=gpc.get_group(ParallelMode.MOE_MODEL))
no_tensor_parallel_norm += moe_parallel_norm no_tensor_parallel_norm += moe_parallel_norm
total_norm = tensor_parallel_norm + no_tensor_parallel_norm total_norm = tensor_parallel_norm + no_tensor_parallel_norm
if gpc.is_initialized(ParallelMode.PIPELINE) and gpc.get_world_size(ParallelMode.PIPELINE) > 1: if gpc.is_initialized(ParallelMode.PIPELINE) and gpc.get_world_size(ParallelMode.PIPELINE) > 1:
dist.all_reduce(total_norm, dist.all_reduce(total_norm, op=dist.ReduceOp.SUM, group=gpc.get_group(ParallelMode.PIPELINE))
op=dist.ReduceOp.SUM, total_norm = total_norm**(1.0 / norm_type)
group=gpc.get_group(ParallelMode.PIPELINE))
total_norm = total_norm ** (1.0 / norm_type)
if type(total_norm) == 'torch.cuda.FloatTensor': if type(total_norm) == 'torch.cuda.FloatTensor':
total_norm = total_norm.item() total_norm = total_norm.item()
@ -225,10 +225,7 @@ def clip_grad_norm_fp32(parameters, max_norm, norm_type=2):
if clip_coeff < 1.0: if clip_coeff < 1.0:
grads = [p.grad.detach() for p in params] grads = [p.grad.detach() for p in params]
dummy_overflow_buf = torch.cuda.IntTensor([0]) dummy_overflow_buf = torch.cuda.IntTensor([0])
multi_tensor_applier(colossal_C.multi_tensor_scale, multi_tensor_applier(colossal_C.multi_tensor_scale, dummy_overflow_buf, [grads, grads], clip_coeff)
dummy_overflow_buf,
[grads, grads],
clip_coeff)
return total_norm return total_norm
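For clarity, the reformatted lines above still compute the same global norm; with p = norm_type and r the per-parameter reductor from a few lines earlier, the returned value is (a restatement of the code):
r(g) = \Bigl(\tfrac{\mathrm{world\_size}(\mathrm{TENSOR})}{\mathrm{NUM\_PARTITIONS}(g)}\Bigr)^{1/p}, \qquad
\|g\|_{\mathrm{total}} = \Bigl(\sum_{g \in \mathrm{tp}} \bigl\|g / r(g)\bigr\|_p^{\,p} + \sum_{g \in \mathrm{moe}} \|g\|_p^{\,p} + \sum_{g \in \mathrm{other}} \|g\|_p^{\,p}\Bigr)^{1/p}
where the tensor-parallel sum is all-reduced over the TENSOR group, the MoE sum over MOE_MODEL, and the grand total over the PIPELINE group, so every rank ends up with the same norm before clipping.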
@ -254,15 +251,14 @@ def count_zeros_fp32(parameters):
# Sum across all model-parallel GPUs. # Sum across all model-parallel GPUs.
ops = [] ops = []
ops.append(dist.all_reduce(total_num_zeros, ops.append(
op=dist.ReduceOp.SUM, dist.all_reduce(total_num_zeros, op=dist.ReduceOp.SUM, group=gpc.get_group(ParallelMode.TENSOR), async_op=True))
group=gpc.get_group(ParallelMode.TENSOR),
async_op=True))
if gpc.is_initialized(ParallelMode.PIPELINE): if gpc.is_initialized(ParallelMode.PIPELINE):
ops.append(dist.all_reduce(total_num_zeros, ops.append(
op=dist.ReduceOp.SUM, dist.all_reduce(total_num_zeros,
group=gpc.get_group(ParallelMode.PIPELINE), op=dist.ReduceOp.SUM,
async_op=True)) group=gpc.get_group(ParallelMode.PIPELINE),
async_op=True))
for req in ops: for req in ops:
req.wait() req.wait()
@ -279,9 +275,8 @@ def copy_tensor_parallel_attributes(src_tensor, dst_tensor):
def param_is_not_tensor_parallel_duplicate(param): def param_is_not_tensor_parallel_duplicate(param):
return (hasattr(param, IS_TENSOR_PARALLEL) and return (hasattr(param, IS_TENSOR_PARALLEL) and getattr(param, IS_TENSOR_PARALLEL)) or (gpc.get_local_rank(
getattr(param, IS_TENSOR_PARALLEL)) or ( ParallelMode.TENSOR) == 0)
gpc.get_local_rank(ParallelMode.TENSOR) == 0)
@contextmanager @contextmanager

View File

@ -3,12 +3,20 @@ from typing import Callable
import torch import torch
from colossalai import nn as col_nn from colossalai import nn as col_nn
from colossalai.nn.layer.utils import CheckpointModule from colossalai.builder.pipeline import partition_uniform
from colossalai.registry import LAYERS, MODELS, LOSSES from colossalai.context import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.logging import get_dist_logger
from colossalai.nn.layer.utils import CheckpointModule, divide
from colossalai.nn.layer.wrapper import PipelineSharedModuleWrapper
from colossalai.registry import LAYERS, LOSSES, MODELS
from colossalai.utils import get_current_device from colossalai.utils import get_current_device
from torch import dtype, nn from torch import dtype, nn
__all__ = ['GPT', 'GPTLMLoss', 'gpt2_small', 'gpt2_medium', 'gpt2_large', 'gpt2_xl', 'gpt3'] __all__ = [
'GPT', 'GPTLMLoss', 'gpt2_small', 'gpt2_medium', 'gpt2_large', 'gpt2_xl', 'gpt2_8B', 'gpt2_xl_pipeline',
'gpt2_8B_pipeline', 'gpt3', 'gpt3_pipeline'
]
@LAYERS.register_module @LAYERS.register_module
@ -18,7 +26,7 @@ class GPTEmbedding(nn.Module):
vocab_size: int, vocab_size: int,
max_position_embeddings: int, max_position_embeddings: int,
num_tokentypes: int = 0, num_tokentypes: int = 0,
padding_idx: int = 0, padding_idx: int = None,
dropout: float = 0., dropout: float = 0.,
dtype: dtype = None) -> None: dtype: dtype = None) -> None:
super().__init__() super().__init__()
@ -34,7 +42,7 @@ class GPTEmbedding(nn.Module):
def word_embedding_weight(self): def word_embedding_weight(self):
return self.word_embeddings.weight return self.word_embeddings.weight
def forward(self, input_ids, position_ids=None, tokentype_ids=None): def forward(self, input_ids, attention_mask=None, position_ids=None, tokentype_ids=None):
seq_length = input_ids.size(1) seq_length = input_ids.size(1)
if position_ids is None: if position_ids is None:
position_ids = torch.arange(seq_length, dtype=torch.long, device=get_current_device()).unsqueeze(0) position_ids = torch.arange(seq_length, dtype=torch.long, device=get_current_device()).unsqueeze(0)
@ -42,7 +50,20 @@ class GPTEmbedding(nn.Module):
if self.tokentype_embeddings is not None and tokentype_ids is not None: if self.tokentype_embeddings is not None and tokentype_ids is not None:
x = x + self.tokentype_embeddings(tokentype_ids) x = x + self.tokentype_embeddings(tokentype_ids)
x = self.dropout(x) x = self.dropout(x)
return x
# We create a 4D attention mask from a 2D tensor mask.
# Sizes are [batch_size, 1, 1, to_seq_length]
# So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
# Adapted from huggingface
if attention_mask is not None:
batch_size = input_ids.shape[0]
attention_mask = attention_mask.view(batch_size, -1)
attention_mask = col_nn.partition_batch(attention_mask)
attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
attention_mask = attention_mask.to(dtype=x.dtype) # fp16 compatibility
attention_mask = (1.0 - attention_mask) * -10000.0
return x, attention_mask
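The mask handling moved into GPTEmbedding.forward turns a [batch, seq] padding mask into an additive bias that broadcasts over heads and query positions; a standalone sketch of the same transform (col_nn.partition_batch is omitted here, since it only splits the batch across tensor-parallel ranks):
import torch

attention_mask = torch.tensor([[1, 1, 0]])                    # [batch=1, seq=3], 0 = padding
mask = attention_mask.view(1, -1).unsqueeze(1).unsqueeze(2)   # [1, 1, 1, 3]
mask = mask.to(torch.float32)                                 # matches x.dtype (fp16 under mixed precision)
mask = (1.0 - mask) * -10000.0
# Broadcasts to [batch, num_heads, from_seq, to_seq]; padded keys receive -1e4
# and effectively vanish after softmax when added to the attention scores.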
@LAYERS.register_module @LAYERS.register_module
@ -53,20 +74,32 @@ class GPTSelfAttention(nn.Module):
attention_dropout: float, attention_dropout: float,
dropout: float, dropout: float,
bias: bool = True, bias: bool = True,
fuse_scale_mask_softmax: bool = False,
dtype: dtype = None) -> None: dtype: dtype = None) -> None:
super().__init__() super().__init__()
self.fuse_scale_mask_softmax = fuse_scale_mask_softmax
self.attention_head_size = dim // num_heads self.attention_head_size = divide(dim, num_heads)
self.query_key_value = col_nn.Linear(dim, 3 * dim, dtype=dtype, bias=bias) self.query_key_value = col_nn.Linear(dim, 3 * dim, dtype=dtype, bias=bias)
if fuse_scale_mask_softmax:
from colossalai.kernel import FusedScaleMaskSoftmax
from colossalai.kernel.cuda_native.scaled_softmax import AttnMaskType
self.softmax = FusedScaleMaskSoftmax(input_in_fp16=True,
input_in_bf16=False,
attn_mask_type=AttnMaskType.causal,
scaled_masked_softmax_fusion=True,
mask_func=None,
softmax_in_fp32=True,
scale=math.sqrt(self.attention_head_size))
else:
self.softmax = nn.Softmax(dim=-1)
self.attention_dropout = col_nn.Dropout(attention_dropout) self.attention_dropout = col_nn.Dropout(attention_dropout)
self.dense = col_nn.Linear(dim, dim, dtype=dtype, bias=True) self.dense = col_nn.Linear(dim, dim, dtype=dtype, bias=True)
self.dropout = col_nn.Dropout(dropout) self.dropout = col_nn.Dropout(dropout)
self.softmax = nn.Softmax(dim=-1)
def forward(self, x, attention_mask=None): def forward(self, x, attention_mask=None):
qkv = self.query_key_value(x) qkv = self.query_key_value(x)
all_head_size = qkv.shape[-1] // 3 all_head_size = qkv.shape[-1] // 3
num_attention_heads = all_head_size // self.attention_head_size num_attention_heads = divide(all_head_size, self.attention_head_size)
new_qkv_shape = qkv.shape[:-1] + \ new_qkv_shape = qkv.shape[:-1] + \
(num_attention_heads, 3 * self.attention_head_size) (num_attention_heads, 3 * self.attention_head_size)
qkv = qkv.view(new_qkv_shape) qkv = qkv.view(new_qkv_shape)
@ -74,17 +107,20 @@ class GPTSelfAttention(nn.Module):
q, k, v = torch.chunk(qkv, 3, dim=-1) q, k, v = torch.chunk(qkv, 3, dim=-1)
x = torch.matmul(q, k.transpose(-1, -2)) x = torch.matmul(q, k.transpose(-1, -2))
x = x / math.sqrt(self.attention_head_size)
# causal mask if self.fuse_scale_mask_softmax:
q_len, k_len = q.size(-2), k.size(-2) x = self.softmax(x, attention_mask)
causal_mask = torch.tril(torch.ones((q_len, k_len), dtype=torch.uint8, else:
device=get_current_device())).view(1, 1, q_len, k_len).bool() x = x / math.sqrt(self.attention_head_size)
x = torch.where(causal_mask, x, torch.tensor(-1e4, dtype=x.dtype, device=get_current_device())) # causal mask
q_len, k_len = q.size(-2), k.size(-2)
causal_mask = torch.tril(torch.ones((q_len, k_len), dtype=torch.uint8,
device=get_current_device())).view(1, 1, q_len, k_len).bool()
x = torch.where(causal_mask, x, torch.tensor(-1e4, dtype=x.dtype, device=get_current_device()))
if attention_mask is not None:
x = x + attention_mask
x = self.softmax(x)
if attention_mask is not None:
x = x + attention_mask
x = self.softmax(x)
x = self.attention_dropout(x) x = self.attention_dropout(x)
x = torch.matmul(x, v) x = torch.matmul(x, v)
@ -102,15 +138,16 @@ class GPTSelfAttention(nn.Module):
class GPTMLP(nn.Module): class GPTMLP(nn.Module):
def __init__(self, def __init__(self,
dim: int, dim: int,
mlp_ratio: int, mlp_ratio: float,
activation: Callable, activation: Callable,
dropout: float, dropout: float,
dtype: dtype = None, dtype: dtype = None,
bias: bool = True): bias: bool = True):
super().__init__() super().__init__()
self.dense_1 = col_nn.Linear(dim, mlp_ratio * dim, dtype=dtype, bias=bias) intermediate_dim = int(dim * mlp_ratio)
self.dense_1 = col_nn.Linear(dim, intermediate_dim, dtype=dtype, bias=bias)
self.activation = activation self.activation = activation
self.dense_2 = col_nn.Linear(mlp_ratio * dim, dim, dtype=dtype, bias=bias) self.dense_2 = col_nn.Linear(intermediate_dim, dim, dtype=dtype, bias=bias)
self.dropout = col_nn.Dropout(dropout) self.dropout = col_nn.Dropout(dropout)
def forward(self, x): def forward(self, x):
@ -126,27 +163,44 @@ class GPTBlock(CheckpointModule):
def __init__(self, def __init__(self,
dim: int, dim: int,
num_heads: int, num_heads: int,
mlp_ratio: int, mlp_ratio: float,
activation: Callable, activation: Callable,
attention_dropout: float = 0., attention_dropout: float = 0.,
dropout: float = 0., dropout: float = 0.,
layernorm_epsilon: float = 1e-5,
dtype: dtype = None, dtype: dtype = None,
bias: bool = True, bias: bool = True,
apply_post_layernorm: bool = False,
fuse_scale_mask_softmax: bool = False,
checkpoint: bool = False): checkpoint: bool = False):
super().__init__(checkpoint=checkpoint) super().__init__(checkpoint)
self.norm1 = col_nn.LayerNorm(normalized_shape=dim, eps=1e-6, dtype=dtype) self.apply_post_layernorm = apply_post_layernorm
self.norm1 = col_nn.LayerNorm(normalized_shape=dim, eps=layernorm_epsilon, dtype=dtype)
self.attn = GPTSelfAttention(dim=dim, self.attn = GPTSelfAttention(dim=dim,
num_heads=num_heads, num_heads=num_heads,
attention_dropout=attention_dropout, attention_dropout=attention_dropout,
dropout=dropout, dropout=dropout,
bias=bias, bias=bias,
fuse_scale_mask_softmax=fuse_scale_mask_softmax,
dtype=dtype) dtype=dtype)
self.norm2 = col_nn.LayerNorm(normalized_shape=dim, eps=1e-6, dtype=dtype) self.norm2 = col_nn.LayerNorm(normalized_shape=dim, eps=layernorm_epsilon, dtype=dtype)
self.mlp = GPTMLP(dim=dim, mlp_ratio=mlp_ratio, activation=activation, dropout=dropout, dtype=dtype, bias=bias) self.mlp = GPTMLP(dim=dim, mlp_ratio=mlp_ratio, activation=activation, dropout=dropout, dtype=dtype, bias=bias)
def _forward(self, x, attention_mask=None): def _forward(self, x, attention_mask=None):
x = x + self.attn(self.norm1(x), attention_mask) if not self.apply_post_layernorm:
x = x + self.mlp(self.norm2(x)) residual = x
x = self.norm1(x)
if self.apply_post_layernorm:
residual = x
x = residual + self.attn(x, attention_mask)
if not self.apply_post_layernorm:
residual = x
x = self.norm2(x)
if self.apply_post_layernorm:
residual = x
x = residual + self.mlp(x)
return x, attention_mask return x, attention_mask
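Reading the branches above, apply_post_layernorm only changes which tensor feeds the residual path; with LN1/LN2 the two layer norms, the block computes:
\text{pre-LN (default):}\quad y = x + \mathrm{Attn}(\mathrm{LN}_1(x)), \qquad z = y + \mathrm{MLP}(\mathrm{LN}_2(y))
\text{apply\_post\_layernorm:}\quad y = \mathrm{LN}_1(x) + \mathrm{Attn}(\mathrm{LN}_1(x)), \qquad z = \mathrm{LN}_2(y) + \mathrm{MLP}(\mathrm{LN}_2(y))
As implemented here, the "post" variant takes the residual from the normalised activation rather than wrapping the residual sum in a norm, so it is this code's convention rather than the classic post-LN Transformer.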
@ -161,6 +215,10 @@ class GPTLMHead(nn.Module):
super().__init__() super().__init__()
self.dense = col_nn.Classifier(dim, vocab_size, word_embeeding_weight, bias=bias, dtype=dtype) self.dense = col_nn.Classifier(dim, vocab_size, word_embeeding_weight, bias=bias, dtype=dtype)
@property
def weight(self):
return self.dense.weight
def forward(self, x): def forward(self, x):
x = self.dense(x) x = self.dense(x)
return x return x
@ -187,18 +245,19 @@ class GPT(nn.Module):
dim: int = 768, dim: int = 768,
num_heads: int = 12, num_heads: int = 12,
depth: int = 12, depth: int = 12,
mlp_ratio: int = 4, mlp_ratio: float = 4.0,
dropout: float = 0.1, dropout: float = 0.1,
embedding_dropout: float = 0.1, embedding_dropout: float = 0.1,
attention_dropout: float = 0.1, attention_dropout: float = 0.1,
layernorm_epsilon: float = 1e-5, layernorm_epsilon: float = 1e-5,
activation: Callable = nn.functional.gelu, activation: Callable = nn.functional.gelu,
checkpoint: bool = False, padding_idx: int = None,
dtype: dtype = None, dtype: dtype = None,
bias: bool = True, bias: bool = True,
padding_idx: int = 0) -> None: apply_post_layernorm: bool = False,
fuse_scale_mask_softmax: bool = False,
checkpoint: bool = False) -> None:
super().__init__() super().__init__()
self.dtype = dtype
self.embed = GPTEmbedding(embedding_dim=dim, self.embed = GPTEmbedding(embedding_dim=dim,
vocab_size=vocab_size, vocab_size=vocab_size,
max_position_embeddings=max_position_embeddings, max_position_embeddings=max_position_embeddings,
@ -213,8 +272,11 @@ class GPT(nn.Module):
activation=activation, activation=activation,
attention_dropout=attention_dropout, attention_dropout=attention_dropout,
dropout=dropout, dropout=dropout,
layernorm_epsilon=layernorm_epsilon,
dtype=dtype, dtype=dtype,
bias=bias, bias=bias,
apply_post_layernorm=apply_post_layernorm,
fuse_scale_mask_softmax=fuse_scale_mask_softmax,
checkpoint=checkpoint, checkpoint=checkpoint,
) for _ in range(depth) ) for _ in range(depth)
]) ])
@ -224,22 +286,10 @@ class GPT(nn.Module):
self.head = GPTLMHead(dim=dim, self.head = GPTLMHead(dim=dim,
vocab_size=vocab_size, vocab_size=vocab_size,
word_embeeding_weight=self.embed.word_embedding_weight, word_embeeding_weight=self.embed.word_embedding_weight,
bias=bias,
dtype=dtype) dtype=dtype)
def forward(self, input_ids, attention_mask=None): def forward(self, input_ids, attention_mask=None):
# We create a 3D attention mask from a 2D tensor mask. x, attention_mask = self.embed(input_ids, attention_mask)
# Sizes are [batch_size, 1, 1, to_seq_length]
# So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
# Adapted from huggingface
if attention_mask is not None:
batch_size = input_ids.shape[0]
attention_mask = attention_mask.view(batch_size, -1)
attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
attention_mask = attention_mask.to(dtype=self.dtype) # fp16 compatibility
attention_mask = (1.0 - attention_mask) * -10000.0
x = self.embed(input_ids)
for block in self.blocks: for block in self.blocks:
x, attention_mask = block(x, attention_mask) x, attention_mask = block(x, attention_mask)
@ -249,11 +299,103 @@ class GPT(nn.Module):
return x return x
class PipelineGPT(nn.Module):
def __init__(self,
vocab_size: int = 50304,
max_position_embeddings: int = 1024,
dim: int = 768,
num_heads: int = 12,
depth: int = 12,
mlp_ratio: float = 4.0,
dropout: float = 0.1,
embedding_dropout: float = 0.1,
attention_dropout: float = 0.1,
layernorm_epsilon: float = 1e-5,
activation: Callable = nn.functional.gelu,
padding_idx: int = None,
dtype: dtype = None,
bias: bool = True,
apply_post_layernorm: bool = False,
fuse_scale_mask_softmax: bool = False,
checkpoint: bool = False,
first: bool = False,
last: bool = False):
super().__init__()
self.checkpoint = checkpoint
self.first = first
self.last = last
if first:
self.embed = GPTEmbedding(embedding_dim=dim,
vocab_size=vocab_size,
max_position_embeddings=max_position_embeddings,
padding_idx=padding_idx,
dropout=embedding_dropout,
dtype=dtype)
self.blocks = nn.ModuleList([
GPTBlock(
dim=dim,
num_heads=num_heads,
mlp_ratio=mlp_ratio,
activation=activation,
attention_dropout=attention_dropout,
dropout=dropout,
layernorm_epsilon=layernorm_epsilon,
dtype=dtype,
bias=bias,
apply_post_layernorm=apply_post_layernorm,
fuse_scale_mask_softmax=fuse_scale_mask_softmax,
checkpoint=checkpoint,
) for _ in range(depth)
])
if self.last:
self.norm = col_nn.LayerNorm(normalized_shape=dim, eps=layernorm_epsilon, dtype=dtype)
self.head = GPTLMHead(dim=dim, vocab_size=vocab_size, dtype=dtype)
def forward(self, x=None, input_ids=None, attention_mask=None):
if self.first:
x, attention_mask = self.embed(input_ids, attention_mask)
for block in self.blocks:
x, attention_mask = block(x, attention_mask)
if self.last:
x = self.head(self.norm(x))
return x
def _create_gpt_model(**model_kwargs): def _create_gpt_model(**model_kwargs):
model = GPT(**model_kwargs) model = GPT(**model_kwargs)
return model return model
def _create_gpt_pipeline_model(depth=48, num_chunks=1, layer_partitions=None, **model_kwargs):
logger = get_dist_logger()
pipeline_size = gpc.get_world_size(ParallelMode.PIPELINE)
pipeline_rank = gpc.get_local_rank(ParallelMode.PIPELINE)
rank = gpc.get_global_rank()
wrapper = PipelineSharedModuleWrapper([0, pipeline_size - 1])
parts = partition_uniform(depth, pipeline_size,
num_chunks)[pipeline_rank] if layer_partitions is None else layer_partitions
models = []
for start, end in parts:
model_kwargs['first'] = start == 0
model_kwargs['last'] = end == depth
model_kwargs['depth'] = end - start
chunk = PipelineGPT(**model_kwargs).to(get_current_device())
if start == 0:
wrapper.register_parameter(chunk.embed.word_embedding_weight)
elif end == depth:
wrapper.register_parameter(chunk.head.weight)
models.append(chunk)
logger.info(f'==> Rank {rank} built layer {start}-{end} / total {depth}')
if len(models) == 1:
model = models[0]
else:
model = nn.ModuleList(models)
return model
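A hedged usage sketch of the pipeline factory above; the import path is hypothetical (the file lives in the model zoo) and a PIPELINE parallel group is assumed to have been initialised by colossalai.launch beforehand:
import torch
from model_zoo.gpt.gpt import gpt2_xl_pipeline   # hypothetical import path

model = gpt2_xl_pipeline(checkpoint=True, dtype=torch.half)
# Each pipeline rank only materialises its own [start, end) slice of the 48
# blocks; the first and last stages also own the embedding and LM head, whose
# (tied) weights are registered with PipelineSharedModuleWrapper for syncing.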
@MODELS.register_module @MODELS.register_module
def gpt2_small(**kwargs): def gpt2_small(**kwargs):
model_kwargs = dict(dim=768, depth=12, num_heads=12, **kwargs) model_kwargs = dict(dim=768, depth=12, num_heads=12, **kwargs)
@ -262,23 +404,47 @@ def gpt2_small(**kwargs):
@MODELS.register_module @MODELS.register_module
def gpt2_medium(**kwargs): def gpt2_medium(**kwargs):
model_kwargs = dict(dim=1024, depth=24, num_heads=16, **kwargs) model_kwargs = dict(dim=1024, depth=24, num_heads=8, **kwargs)
return _create_gpt_model(**model_kwargs) return _create_gpt_model(**model_kwargs)
@MODELS.register_module @MODELS.register_module
def gpt2_large(**kwargs): def gpt2_large(**kwargs):
model_kwargs = dict(dim=1280, depth=36, num_heads=20, **kwargs) model_kwargs = dict(dim=1536, depth=36, num_heads=12, **kwargs)
return _create_gpt_model(**model_kwargs) return _create_gpt_model(**model_kwargs)
@MODELS.register_module @MODELS.register_module
def gpt2_xl(**kwargs): def gpt2_xl(**kwargs):
model_kwargs = dict(dim=1600, depth=48, num_heads=25, **kwargs) model_kwargs = dict(dim=1600, depth=48, num_heads=16, **kwargs)
return _create_gpt_model(**model_kwargs) return _create_gpt_model(**model_kwargs)
@MODELS.register_module @MODELS.register_module
def gpt3(**kwargs): def gpt2_8B(**kwargs):
model_kwargs = dict(dim=12288, max_position_embeddings=2048, depth=96, num_heads=96, **kwargs) model_kwargs = dict(dim=3072, depth=72, num_heads=24, **kwargs)
return _create_gpt_model(**model_kwargs) return _create_gpt_model(**model_kwargs)
@MODELS.register_module
def gpt2_xl_pipeline(**kwargs):
model_kwargs = dict(dim=1600, depth=48, num_heads=20, **kwargs)
return _create_gpt_pipeline_model(**model_kwargs)
@MODELS.register_module
def gpt2_8B_pipeline(**kwargs):
model_kwargs = dict(dim=3072, depth=72, num_heads=24, **kwargs)
return _create_gpt_pipeline_model(**model_kwargs)
@MODELS.register_module
def gpt3(**kwargs):
model_kwargs = dict(dim=12288, depth=96, num_heads=96, **kwargs)
return _create_gpt_model(**model_kwargs)
@MODELS.register_module
def gpt3_pipeline(**kwargs):
model_kwargs = dict(dim=12288, depth=96, num_heads=96, **kwargs)
return _create_gpt_pipeline_model(**model_kwargs)

View File

@ -1,12 +1,14 @@
import torch import torch
import torch.distributed as dist import torch.distributed as dist
from torch.nn import Parameter
import time
from colossalai.context.parallel_mode import ParallelMode from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc from colossalai.core import global_context as gpc
from colossalai.nn import Linear1D_Col, Linear1D_Row from colossalai.global_variables import tensor_parallel_env as env
from colossalai.nn import (Classifier1D, Embedding1D, Linear1D_Col, Linear1D_Row, VanillaClassifier,
VocabParallelClassifier1D, VocabParallelCrossEntropyLoss1D, VocabParallelEmbedding1D)
from colossalai.utils import get_current_device, print_rank_0 from colossalai.utils import get_current_device, print_rank_0
from .common import HIDDEN_SIZE, DEPTH, BATCH_SIZE, SEQ_LENGTH, NUM_CLASSES, check_equal, IMG_SIZE from torch.nn import Parameter
from .common import BATCH_SIZE, DEPTH, HIDDEN_SIZE, NUM_CLASSES, SEQ_LENGTH, VOCAB_SIZE, check_equal
def check_linear_col(): def check_linear_col():
@ -144,3 +146,351 @@ def check_linear_row():
check_equal(B_grad, layer.bias.grad) check_equal(B_grad, layer.bias.grad)
print_rank_0('linear_row backward: pass') print_rank_0('linear_row backward: pass')
def check_embed():
device = get_current_device()
dtype = torch.float32
i = gpc.get_local_rank(ParallelMode.PARALLEL_1D)
embed = Embedding1D(VOCAB_SIZE, HIDDEN_SIZE)
embed = embed.to(dtype).to(device)
embed_master = torch.nn.Embedding(VOCAB_SIZE, HIDDEN_SIZE)
embed_master = embed_master.to(dtype).to(device)
weight_master = embed_master.weight.data
torch.distributed.broadcast(weight_master, src=0)
weight = torch.chunk(weight_master, DEPTH, dim=-1)[i]
embed.weight.data.copy_(weight)
A_shape = (BATCH_SIZE, SEQ_LENGTH)
A_master = torch.randint(VOCAB_SIZE, A_shape, device=device)
torch.distributed.broadcast(A_master, src=0)
A = A_master.clone()
out = embed(A)
A_master = A_master.clone()
C_master = embed_master(A_master)
C = C_master.clone()
check_equal(out, C)
print_rank_0('embed forward: pass')
grad_shape = C_master.shape
grad_master = torch.randn(grad_shape, dtype=dtype, device=device)
torch.distributed.broadcast(grad_master, src=0)
grad = grad_master.clone()
out.backward(grad)
grad_master = grad_master.clone()
C_master.backward(grad_master)
B_grad = embed_master.weight.grad
B_grad = torch.chunk(B_grad, DEPTH, dim=-1)[i]
check_equal(B_grad, embed.weight.grad)
print_rank_0('embed backward: pass')
def check_vocab_parallel_embed():
device = get_current_device()
dtype = torch.float32
i = gpc.get_local_rank(ParallelMode.PARALLEL_1D)
embed = VocabParallelEmbedding1D(VOCAB_SIZE, HIDDEN_SIZE)
embed = embed.to(dtype).to(device)
embed_master = torch.nn.Embedding(VOCAB_SIZE, HIDDEN_SIZE)
embed_master = embed_master.to(dtype).to(device)
weight_master = embed_master.weight.data
torch.distributed.broadcast(weight_master, src=0)
weight = torch.chunk(weight_master, DEPTH, dim=0)[i]
embed.weight.data.copy_(weight)
A_shape = (BATCH_SIZE, SEQ_LENGTH)
A_master = torch.randint(VOCAB_SIZE, A_shape, device=device)
torch.distributed.broadcast(A_master, src=0)
A = A_master.clone()
out = embed(A)
A_master = A_master.clone()
C_master = embed_master(A_master)
C = C_master.clone()
check_equal(out, C)
print_rank_0('vocab parallel embed forward: pass')
grad_shape = C_master.shape
grad_master = torch.randn(grad_shape, dtype=dtype, device=device)
torch.distributed.broadcast(grad_master, src=0)
grad = grad_master.clone()
out.backward(grad)
grad_master = grad_master.clone()
C_master.backward(grad_master)
B_grad = embed_master.weight.grad
B_grad = torch.chunk(B_grad, DEPTH, dim=0)[i]
check_equal(B_grad, embed.weight.grad)
print_rank_0('vocab parallel embed backward: pass')
def check_classifier_no_given_weight():
device = get_current_device()
dtype = torch.float32
i = gpc.get_local_rank(ParallelMode.PARALLEL_1D)
env.parallel_input_1d = False
parallel_input_1d = env.parallel_input_1d
layer = Classifier1D(HIDDEN_SIZE, NUM_CLASSES, bias=True)
layer.to(dtype).to(device)
layer_master = VanillaClassifier(HIDDEN_SIZE, NUM_CLASSES, bias=True)
layer_master = layer_master.to(dtype).to(device)
W_master = layer_master.weight.data
dist.broadcast(W_master, src=0)
W = torch.chunk(W_master, DEPTH, dim=-1)[i]
layer.weight.data.copy_(W)
B_master = layer_master.bias.data
dist.broadcast(B_master, src=0)
B = B_master.clone()
layer.bias.data.copy_(B)
A_shape = (BATCH_SIZE, SEQ_LENGTH, HIDDEN_SIZE)
A_master = torch.randn(A_shape, dtype=dtype, device=device)
dist.broadcast(A_master, src=0)
if parallel_input_1d:
A = torch.chunk(A_master, DEPTH, dim=-1)[i]
A = A.clone()
else:
A = A_master.clone()
A.requires_grad = True
out = layer(A)
A_master = A_master.clone()
A_master.requires_grad = True
C_master = layer_master(A_master)
C = C_master.clone()
check_equal(out, C)
print_rank_0('classifier (no given weight) forward: pass')
grad_shape = C_master.shape
grad_master = torch.randn(grad_shape, dtype=dtype, device=device)
dist.broadcast(grad_master, src=0)
grad = grad_master.clone()
out.backward(grad)
grad_master = grad_master.clone()
C_master.backward(grad_master)
A_grad = A_master.grad
if parallel_input_1d:
A_grad = torch.chunk(A_grad, DEPTH, dim=-1)[i]
check_equal(A_grad, A.grad)
W_grad = layer_master.weight.grad
W_grad = torch.chunk(W_grad, DEPTH, dim=-1)[i]
check_equal(W_grad, layer.weight.grad)
B_grad = layer_master.bias.grad
check_equal(B_grad, layer.bias.grad)
print_rank_0('classifier (no given weight) backward: pass')
def check_vocab_parallel_classifier_no_given_weight():
device = get_current_device()
dtype = torch.float32
i = gpc.get_local_rank(ParallelMode.PARALLEL_1D)
layer = VocabParallelClassifier1D(HIDDEN_SIZE, VOCAB_SIZE, bias=True)
layer.to(dtype).to(device)
layer_master = VanillaClassifier(HIDDEN_SIZE, VOCAB_SIZE, bias=True)
layer_master = layer_master.to(dtype).to(device)
W_master = layer_master.weight.data
dist.broadcast(W_master, src=0)
W = torch.chunk(W_master, DEPTH, dim=0)[i]
layer.weight.data.copy_(W)
B_master = layer_master.bias.data
dist.broadcast(B_master, src=0)
B = torch.chunk(B_master, DEPTH, dim=0)[i]
layer.bias.data.copy_(B)
A_shape = (BATCH_SIZE, SEQ_LENGTH, HIDDEN_SIZE)
A_master = torch.randn(A_shape, dtype=dtype, device=device)
dist.broadcast(A_master, src=0)
A = A_master.clone()
A.requires_grad = True
out = layer(A)
A_master = A_master.clone()
A_master.requires_grad = True
C_master = layer_master(A_master)
C = torch.chunk(C_master, DEPTH, dim=-1)[i]
check_equal(out, C)
print_rank_0('vocab parallel classifier (no given weight) forward: pass')
grad_shape = C_master.shape
grad_master = torch.randn(grad_shape, dtype=dtype, device=device)
dist.broadcast(grad_master, src=0)
grad = torch.chunk(grad_master, DEPTH, dim=-1)[i]
grad = grad.clone()
out.backward(grad)
grad_master = grad_master.clone()
C_master.backward(grad_master)
A_grad = A_master.grad
check_equal(A_grad, A.grad)
W_grad = layer_master.weight.grad
W_grad = torch.chunk(W_grad, DEPTH, dim=0)[i]
check_equal(W_grad, layer.weight.grad)
B_grad = layer_master.bias.grad
B_grad = torch.chunk(B_grad, DEPTH, dim=0)[i]
check_equal(B_grad, layer.bias.grad)
print_rank_0('vocab parallel classifier (no given weight) backward: pass')
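By contrast, the vocab-parallel classifier splits the weight along the class (vocabulary) dimension (dim=0 above), so each rank emits its own slice of the logits and no reduction is needed in the forward pass; that is why the reference output C and the incoming gradient are chunked along dim=-1 before comparison. A minimal sketch, with illustrative names:
# Illustrative sketch of a column-parallel (vocab-parallel) classifier forward.
import torch
def vocab_parallel_classifier(x, weight_shard, bias_shard):
    # weight_shard: (vocab // world_size, hidden); the result is this rank's slice of the logits.
    return torch.matmul(x, weight_shard.t()) + bias_shard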
def check_classifier_given_embed_weight():
device = get_current_device()
dtype = torch.float32
i = gpc.get_local_rank(ParallelMode.PARALLEL_1D)
embed = Embedding1D(VOCAB_SIZE, HIDDEN_SIZE)
embed = embed.to(dtype).to(device)
embed_master = torch.nn.Embedding(VOCAB_SIZE, HIDDEN_SIZE)
embed_master = embed_master.to(dtype).to(device)
weight_master = embed_master.weight.data
torch.distributed.broadcast(weight_master, src=0)
weight = torch.chunk(weight_master, DEPTH, dim=-1)[i]
embed.weight.data.copy_(weight)
env.parallel_input_1d = False
layer = Classifier1D(HIDDEN_SIZE, NUM_CLASSES, weight=embed.weight, bias=False)
layer.to(dtype).to(device)
layer_master = VanillaClassifier(HIDDEN_SIZE, NUM_CLASSES, weight=embed_master.weight, bias=False)
layer_master = layer_master.to(dtype).to(device)
A_shape = (BATCH_SIZE, SEQ_LENGTH)
A_master = torch.randint(VOCAB_SIZE, A_shape, device=device)
torch.distributed.broadcast(A_master, src=0)
A = A_master.clone()
out = layer(embed(A))
A_master = A_master.clone()
C_master = layer_master(embed_master(A_master))
C = C_master.clone()
check_equal(out, C)
print_rank_0('classifier (given embed weight) forward: pass')
grad_shape = C_master.shape
grad_master = torch.randn(grad_shape, dtype=dtype, device=device)
dist.broadcast(grad_master, src=0)
grad = grad_master.clone()
out.backward(grad)
grad_master = grad_master.clone()
C_master.backward(grad_master)
W_grad = embed_master.weight.grad
W_grad = torch.chunk(W_grad, DEPTH, dim=-1)[i]
check_equal(W_grad, embed.weight.grad)
print_rank_0('classifier (given embed weight) backward: pass')
def check_vocab_parallel_classifier_given_embed_weight():
device = get_current_device()
dtype = torch.float32
i = gpc.get_local_rank(ParallelMode.PARALLEL_1D)
embed = VocabParallelEmbedding1D(VOCAB_SIZE, HIDDEN_SIZE)
embed = embed.to(dtype).to(device)
embed_master = torch.nn.Embedding(VOCAB_SIZE, HIDDEN_SIZE)
embed_master = embed_master.to(dtype).to(device)
weight_master = embed_master.weight.data
torch.distributed.broadcast(weight_master, src=0)
weight = torch.chunk(weight_master, DEPTH, dim=0)[i]
embed.weight.data.copy_(weight)
env.parallel_input_1d = False
layer = VocabParallelClassifier1D(HIDDEN_SIZE, NUM_CLASSES, weight=embed.weight, bias=False)
layer.to(dtype).to(device)
layer_master = VanillaClassifier(HIDDEN_SIZE, NUM_CLASSES, weight=embed_master.weight, bias=False)
layer_master = layer_master.to(dtype).to(device)
A_shape = (BATCH_SIZE, SEQ_LENGTH)
A_master = torch.randint(VOCAB_SIZE, A_shape, device=device)
torch.distributed.broadcast(A_master, src=0)
A = A_master.clone()
out = layer(embed(A))
A_master = A_master.clone()
C_master = layer_master(embed_master(A_master))
C = torch.chunk(C_master, DEPTH, dim=-1)[i]
check_equal(out, C)
print_rank_0('vocab parallel classifier (given embed weight) forward: pass')
grad_shape = C_master.shape
grad_master = torch.randn(grad_shape, dtype=dtype, device=device)
dist.broadcast(grad_master, src=0)
grad = torch.chunk(grad_master, DEPTH, dim=-1)[i]
grad = grad.clone()
out.backward(grad)
grad_master = grad_master.clone()
C_master.backward(grad_master)
W_grad = embed_master.weight.grad
W_grad = torch.chunk(W_grad, DEPTH, dim=0)[i]
check_equal(W_grad, embed.weight.grad)
print_rank_0('vocab parallel classifier (given embed weight) backward: pass')
def check_vocab_parallel_loss():
device = get_current_device()
dtype = torch.float32
i = gpc.get_local_rank(ParallelMode.PARALLEL_1D)
criterion = VocabParallelCrossEntropyLoss1D()
criterion_master = torch.nn.CrossEntropyLoss()
out_shape = (BATCH_SIZE, SEQ_LENGTH, NUM_CLASSES)
out_master = torch.randn(out_shape, dtype=dtype, device=device)
target_master = torch.randint(NUM_CLASSES, (BATCH_SIZE, SEQ_LENGTH), dtype=torch.long, device=device)
torch.distributed.broadcast(out_master, src=0)
torch.distributed.broadcast(target_master, src=0)
out = torch.chunk(out_master, DEPTH, dim=-1)[i]
out = out.clone()
out.requires_grad = True
loss = criterion(out, target_master)
out_master = out_master.clone()
out_master.requires_grad = True
loss_master = criterion_master(out_master, target_master)
check_equal(loss, loss_master)
print_rank_0('vocab parallel loss forward: pass')
loss.backward()
loss_master.backward()
out_grad = out_master.grad
out_grad = torch.chunk(out_grad, DEPTH, dim=-1)[i]
check_equal(out_grad, out.grad)
print_rank_0('vocab parallel loss backward: pass')
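Because each rank only sees a slice of the logits, a vocab-parallel cross entropy has to reconstruct the global softmax with collectives. A hedged sketch of the usual scheme (illustrative names, not the VocabParallelCrossEntropyLoss1D internals): take a global max for numerical stability, sum the exponentials across shards, and fetch the target logit from whichever rank owns that class.
# Hypothetical sketch of a vocab-parallel cross entropy (assumes an initialized process group).
import torch
import torch.distributed as dist
def vocab_parallel_cross_entropy(logits_shard, targets, vocab_start, vocab_end, group=None):
    logits_max = logits_shard.max(dim=-1, keepdim=True).values
    if dist.is_initialized():
        dist.all_reduce(logits_max, op=dist.ReduceOp.MAX, group=group)
    shifted = logits_shard - logits_max                      # stabilize before exponentiating
    sum_exp = shifted.exp().sum(dim=-1)
    # Each rank contributes the target logit only if it owns the target class.
    owned = (targets >= vocab_start) & (targets < vocab_end)
    local_targets = (targets - vocab_start).clamp(min=0, max=logits_shard.size(-1) - 1)
    target_logit = shifted.gather(-1, local_targets.unsqueeze(-1)).squeeze(-1)
    target_logit = torch.where(owned, target_logit, torch.zeros_like(target_logit))
    if dist.is_initialized():
        dist.all_reduce(sum_exp, group=group)
        dist.all_reduce(target_logit, group=group)
    return (sum_exp.log() - target_logit).mean()             # mean of -log softmax at the target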

View File

@ -9,6 +9,7 @@ SEQ_LENGTH = 8
IMG_SIZE = 16 IMG_SIZE = 16
HIDDEN_SIZE = 8 HIDDEN_SIZE = 8
NUM_CLASSES = 8 NUM_CLASSES = 8
VOCAB_SIZE = 16
def check_equal(A, B): def check_equal(A, B):
assert torch.allclose(A, B, rtol=1e-3, atol=1e-1) == True assert torch.allclose(A, B, rtol=1e-3, atol=1e-1) == True

View File

@ -7,6 +7,7 @@ import pytest
import torch import torch
import torch.multiprocessing as mp import torch.multiprocessing as mp
from colossalai.core import global_context as gpc from colossalai.core import global_context as gpc
from colossalai.logging import disable_existing_loggers
from colossalai.initialize import launch from colossalai.initialize import launch
from colossalai.utils import free_port from colossalai.utils import free_port
@ -24,6 +25,7 @@ CONFIG = dict(
def check_layer(rank, world_size, port): def check_layer(rank, world_size, port):
disable_existing_loggers()
launch(config=CONFIG, launch(config=CONFIG,
rank=rank, rank=rank,
world_size=world_size, world_size=world_size,
@ -33,6 +35,13 @@ def check_layer(rank, world_size, port):
check_linear_col() check_linear_col()
check_linear_row() check_linear_row()
check_embed()
check_vocab_parallel_embed()
check_classifier_no_given_weight()
check_vocab_parallel_classifier_no_given_weight()
check_classifier_given_embed_weight()
check_vocab_parallel_classifier_given_embed_weight()
check_vocab_parallel_loss()
gpc.destroy() gpc.destroy()
torch.cuda.empty_cache() torch.cuda.empty_cache()

View File

@ -1,11 +1,12 @@
import torch import torch
from torch.nn import Parameter
from colossalai.context.parallel_mode import ParallelMode from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc from colossalai.core import global_context as gpc
from colossalai.nn import Linear2D, LayerNorm2D, Classifier2D from colossalai.nn import (Classifier2D, CrossEntropyLoss2D, Embedding2D, LayerNorm2D, Linear2D, PatchEmbedding2D,
VanillaClassifier, VanillaPatchEmbedding, VocabParallelClassifier2D,
VocabParallelCrossEntropyLoss2D, VocabParallelEmbedding2D)
from colossalai.utils import get_current_device, print_rank_0 from colossalai.utils import get_current_device, print_rank_0
from .common import HIDDEN_SIZE, DEPTH, BATCH_SIZE, SEQ_LENGTH, check_equal, NUM_CLASSES
from .common import (BATCH_SIZE, DEPTH, HIDDEN_SIZE, IMG_SIZE, NUM_CLASSES, SEQ_LENGTH, VOCAB_SIZE, check_equal)
def check_linear(): def check_linear():
@ -57,7 +58,6 @@ def check_linear():
C = torch.chunk(C_master, DEPTH, dim=0)[i] C = torch.chunk(C_master, DEPTH, dim=0)[i]
C = torch.chunk(C, DEPTH, dim=-1)[j] C = torch.chunk(C, DEPTH, dim=-1)[j]
# print(f'Rank {gpc.get_global_rank()} A:\n{A}\nRank {gpc.get_global_rank()} W:\n{W}\nRank {gpc.get_global_rank()} b:\n{B}\nRank {gpc.get_global_rank()} C:\n{C}\nRank {gpc.get_global_rank()} out:\n{out}')
check_equal(out, C) check_equal(out, C)
print_rank_0('linear forward: pass') print_rank_0('linear forward: pass')
@ -90,84 +90,6 @@ def check_linear():
print_rank_0('linear backward: pass') print_rank_0('linear backward: pass')
def check_classifier():
device = get_current_device()
dtype = torch.float32
INPUT_SIZE = HIDDEN_SIZE
OUTPUT_SIZE = NUM_CLASSES
j = gpc.get_local_rank(ParallelMode.PARALLEL_2D_ROW)
i = gpc.get_local_rank(ParallelMode.PARALLEL_2D_COL)
layer = Classifier2D(INPUT_SIZE, OUTPUT_SIZE)
A_shape = (BATCH_SIZE, SEQ_LENGTH, INPUT_SIZE)
A_master = torch.randint(5, A_shape, dtype=dtype, device=device)
torch.distributed.broadcast(A_master, src=0)
A = torch.chunk(A_master, DEPTH, dim=0)[i]
A = torch.chunk(A, DEPTH, dim=-1)[j]
A = A.clone()
A.requires_grad = True
W_shape = (OUTPUT_SIZE, INPUT_SIZE)
W_master = torch.randint(5, W_shape, dtype=dtype, device=device)
torch.distributed.broadcast(W_master, src=0)
W = torch.chunk(W_master, DEPTH, dim=-1)[j]
W = torch.chunk(W, DEPTH, dim=-1)[i]
W = W.clone()
layer.weight.data.copy_(W)
# W.requires_grad = True
B_shape = (OUTPUT_SIZE,)
B_master = torch.randint(5, B_shape, dtype=dtype, device=device)
torch.distributed.broadcast(B_master, src=0)
# B = torch.chunk(B_master, DEPTH, dim=0)[j]
B = B_master.clone()
layer.bias.data.copy_(B)
out = layer(A)
A_master = A_master.clone()
A_master.requires_grad = True
W_master = W_master.clone()
W_master.requires_grad = True
B_master = B_master.clone()
B_master.requires_grad = True
C_master = torch.matmul(A_master, W_master.transpose(0, 1)) + B_master
C = torch.chunk(C_master, DEPTH, dim=0)[i]
# C = torch.chunk(C, DEPTH, dim=-1)[j]
check_equal(out, C)
print_rank_0('classifier forward: pass')
grad_shape = C_master.shape
grad_master = torch.randn(grad_shape, dtype=dtype, device=get_current_device())
torch.distributed.broadcast(grad_master, src=0)
grad = torch.chunk(grad_master, DEPTH, dim=0)[i]
# grad = torch.chunk(grad, DEPTH, dim=-1)[j]
grad = grad.clone()
out.backward(grad)
grad_master = grad_master.clone()
C_master.backward(grad_master)
A_grad = A_master.grad
A_grad = torch.chunk(A_grad, DEPTH, dim=0)[i]
A_grad = torch.chunk(A_grad, DEPTH, dim=-1)[j]
check_equal(A_grad, A.grad)
W_grad = W_master.grad
W_grad = torch.chunk(W_grad, DEPTH, dim=-1)[j]
W_grad = torch.chunk(W_grad, DEPTH, dim=-1)[i]
check_equal(W_grad, layer.weight.grad)
B_grad = B_master.grad
# B_grad = torch.chunk(B_grad, DEPTH, dim=0)[j]
# if i == 0:
check_equal(B_grad, layer.bias.grad)
print_rank_0('classifier backward: pass')
def check_layernorm(): def check_layernorm():
device = get_current_device() device = get_current_device()
dtype = torch.float32 dtype = torch.float32
@ -219,6 +141,497 @@ def check_layernorm():
print_rank_0('layer norm backward: pass') print_rank_0('layer norm backward: pass')
def check_embed():
device = get_current_device()
dtype = torch.float32
j = gpc.get_local_rank(ParallelMode.PARALLEL_2D_ROW)
i = gpc.get_local_rank(ParallelMode.PARALLEL_2D_COL)
embed = Embedding2D(VOCAB_SIZE, HIDDEN_SIZE)
embed = embed.to(dtype).to(device)
embed_master = torch.nn.Embedding(VOCAB_SIZE, HIDDEN_SIZE)
embed_master = embed_master.to(dtype).to(device)
weight_master = embed_master.weight.data
torch.distributed.broadcast(weight_master, src=0)
weight = torch.chunk(weight_master, DEPTH, dim=-1)[j]
weight = torch.chunk(weight, DEPTH, dim=-1)[i]
embed.weight.data.copy_(weight)
A_shape = (BATCH_SIZE, SEQ_LENGTH)
A_master = torch.randint(VOCAB_SIZE, A_shape, device=device)
torch.distributed.broadcast(A_master, src=0)
A = A_master.clone()
out = embed(A)
A_master = A_master.clone()
C_master = embed_master(A_master)
C = torch.chunk(C_master, DEPTH, dim=0)[i]
C = torch.chunk(C, DEPTH, dim=-1)[j]
check_equal(out, C)
print_rank_0('embed forward: pass')
grad_shape = C_master.shape
grad_master = torch.randn(grad_shape, dtype=dtype, device=device)
torch.distributed.broadcast(grad_master, src=0)
grad = torch.chunk(grad_master, DEPTH, dim=0)[i]
grad = torch.chunk(grad, DEPTH, dim=-1)[j]
grad = grad.clone()
out.backward(grad)
grad_master = grad_master.clone()
C_master.backward(grad_master)
B_grad = embed_master.weight.grad
B_grad = torch.chunk(B_grad, DEPTH, dim=-1)[j]
B_grad = torch.chunk(B_grad, DEPTH, dim=-1)[i]
check_equal(B_grad, embed.weight.grad)
print_rank_0('embed backward: pass')
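The 2D checks above and below all shard a master tensor in two steps, once per axis of the 2D device mesh. A small hypothetical helper (not part of the test file) makes the recurring pattern explicit:
# Hypothetical helper mirroring the repeated two-step slicing in these 2D checks.
import torch
def shard_2d(master, first_rank, second_rank, depth, dims=(0, -1)):
    # Equivalent to torch.chunk(master, depth, dim=dims[0])[first_rank]
    # followed by torch.chunk(..., depth, dim=dims[1])[second_rank].
    part = torch.chunk(master, depth, dim=dims[0])[first_rank]
    return torch.chunk(part, depth, dim=dims[1])[second_rank]
# e.g. the expected output above is shard_2d(C_master, i, j, DEPTH) and the expected
# embedding-weight gradient is shard_2d(B_grad, j, i, DEPTH, dims=(-1, -1)).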
def check_patch_embed():
device = get_current_device()
dtype = torch.float32
j = gpc.get_local_rank(ParallelMode.PARALLEL_2D_ROW)
i = gpc.get_local_rank(ParallelMode.PARALLEL_2D_COL)
layer = PatchEmbedding2D(IMG_SIZE, 4, 3, HIDDEN_SIZE, dtype=dtype)
torch.nn.init.ones_(layer.cls_token)
torch.nn.init.ones_(layer.pos_embed)
layer = layer.to(device)
layer_master = VanillaPatchEmbedding(IMG_SIZE, 4, 3, HIDDEN_SIZE, dtype=dtype)
torch.nn.init.ones_(layer_master.cls_token)
torch.nn.init.ones_(layer_master.pos_embed)
layer_master = layer_master.to(device)
proj_weight_master = layer_master.weight.data
torch.distributed.broadcast(proj_weight_master, src=0)
proj_weight = torch.chunk(proj_weight_master, DEPTH, dim=0)[j]
proj_weight = torch.chunk(proj_weight, DEPTH, dim=0)[i]
layer.weight.data.copy_(proj_weight)
proj_bias_master = layer_master.bias.data
torch.distributed.broadcast(proj_bias_master, src=0)
proj_bias = torch.chunk(proj_bias_master, DEPTH, dim=0)[j]
proj_bias = torch.chunk(proj_bias, DEPTH, dim=0)[i]
layer.bias.data.copy_(proj_bias)
A_shape = (BATCH_SIZE, 3, IMG_SIZE, IMG_SIZE)
A_master = torch.randn(A_shape, dtype=dtype, device=device)
torch.distributed.broadcast(A_master, src=0)
A = A_master.clone()
out = layer(A)
A_master = A_master.clone()
C_master = layer_master(A_master)
C = torch.chunk(C_master, DEPTH, dim=0)[i]
C = torch.chunk(C, DEPTH, dim=-1)[j]
check_equal(out, C)
print_rank_0('patch embed forward: pass')
grad_shape = C_master.shape
grad_master = torch.randn(grad_shape, dtype=dtype, device=device)
torch.distributed.broadcast(grad_master, src=0)
grad = torch.chunk(grad_master, DEPTH, dim=0)[i]
grad = torch.chunk(grad, DEPTH, dim=-1)[j]
grad = grad.clone()
out.backward(grad)
grad_master = grad_master.clone()
C_master.backward(grad_master)
cls_grad_master = layer_master.cls_token.grad
cls_grad = torch.chunk(cls_grad_master, DEPTH, dim=-1)[j]
cls_grad = torch.chunk(cls_grad, DEPTH, dim=-1)[i]
check_equal(cls_grad, layer.cls_token.grad)
pos_grad_master = layer_master.pos_embed.grad
pos_grad = torch.chunk(pos_grad_master, DEPTH, dim=-1)[j]
pos_grad = torch.chunk(pos_grad, DEPTH, dim=-1)[i]
check_equal(pos_grad, layer.pos_embed.grad)
B_grad = layer_master.weight.grad
B_grad = torch.chunk(B_grad, DEPTH, dim=0)[j]
B_grad = torch.chunk(B_grad, DEPTH, dim=0)[i]
check_equal(B_grad, layer.weight.grad)
bias_grad = layer_master.bias.grad
bias_grad = torch.chunk(bias_grad, DEPTH)[j]
bias_grad = torch.chunk(bias_grad, DEPTH)[i]
check_equal(bias_grad, layer.bias.grad)
print_rank_0('patch embed backward: pass')
def check_vocab_parallel_embed():
device = get_current_device()
dtype = torch.float32
j = gpc.get_local_rank(ParallelMode.PARALLEL_2D_ROW)
i = gpc.get_local_rank(ParallelMode.PARALLEL_2D_COL)
embed = VocabParallelEmbedding2D(VOCAB_SIZE, HIDDEN_SIZE)
embed = embed.to(dtype).to(device)
embed_master = torch.nn.Embedding(VOCAB_SIZE, HIDDEN_SIZE)
embed_master = embed_master.to(dtype).to(device)
weight_master = embed_master.weight.data
torch.distributed.broadcast(weight_master, src=0)
weight = torch.chunk(weight_master, DEPTH, dim=-1)[j]
weight = torch.chunk(weight, DEPTH, dim=0)[i]
embed.weight.data.copy_(weight)
A_shape = (BATCH_SIZE, SEQ_LENGTH)
A_master = torch.randint(VOCAB_SIZE, A_shape, device=device)
torch.distributed.broadcast(A_master, src=0)
A = A_master.clone()
out = embed(A)
A_master = A_master.clone()
C_master = embed_master(A_master)
C = torch.chunk(C_master, DEPTH, dim=0)[i]
C = torch.chunk(C, DEPTH, dim=-1)[j]
check_equal(out, C)
print_rank_0('vocab parallel embed forward: pass')
grad_shape = C_master.shape
grad_master = torch.randn(grad_shape, dtype=dtype, device=device)
torch.distributed.broadcast(grad_master, src=0)
grad = torch.chunk(grad_master, DEPTH, dim=0)[i]
grad = torch.chunk(grad, DEPTH, dim=-1)[j]
grad = grad.clone()
out.backward(grad)
grad_master = grad_master.clone()
C_master.backward(grad_master)
B_grad = embed_master.weight.grad
B_grad = torch.chunk(B_grad, DEPTH, dim=-1)[j]
B_grad = torch.chunk(B_grad, DEPTH, dim=0)[i]
check_equal(B_grad, embed.weight.grad)
print_rank_0('vocab parallel embed backward: pass')
def check_classifier_no_given_weight():
device = get_current_device()
dtype = torch.float32
INPUT_SIZE = HIDDEN_SIZE
OUTPUT_SIZE = NUM_CLASSES
j = gpc.get_local_rank(ParallelMode.PARALLEL_2D_ROW)
i = gpc.get_local_rank(ParallelMode.PARALLEL_2D_COL)
layer = Classifier2D(INPUT_SIZE, OUTPUT_SIZE)
A_shape = (BATCH_SIZE, SEQ_LENGTH, INPUT_SIZE)
A_master = torch.randint(5, A_shape, dtype=dtype, device=device)
torch.distributed.broadcast(A_master, src=0)
A = torch.chunk(A_master, DEPTH, dim=0)[i]
A = torch.chunk(A, DEPTH, dim=-1)[j]
A = A.clone()
A.requires_grad = True
W_shape = (OUTPUT_SIZE, INPUT_SIZE)
W_master = torch.randint(5, W_shape, dtype=dtype, device=device)
torch.distributed.broadcast(W_master, src=0)
W = torch.chunk(W_master, DEPTH, dim=-1)[j]
W = torch.chunk(W, DEPTH, dim=-1)[i]
W = W.clone()
layer.weight.data.copy_(W)
# W.requires_grad = True
B_shape = (OUTPUT_SIZE, )
B_master = torch.randint(5, B_shape, dtype=dtype, device=device)
torch.distributed.broadcast(B_master, src=0)
# B = torch.chunk(B_master, DEPTH, dim=0)[j]
B = B_master.clone()
layer.bias.data.copy_(B)
out = layer(A)
A_master = A_master.clone()
A_master.requires_grad = True
W_master = W_master.clone()
W_master.requires_grad = True
B_master = B_master.clone()
B_master.requires_grad = True
C_master = torch.matmul(A_master, W_master.transpose(0, 1)) + B_master
C = torch.chunk(C_master, DEPTH, dim=0)[i]
# C = torch.chunk(C, DEPTH, dim=-1)[j]
check_equal(out, C)
print_rank_0('classifier (no given weight) forward: pass')
grad_shape = C_master.shape
grad_master = torch.randn(grad_shape, dtype=dtype, device=get_current_device())
torch.distributed.broadcast(grad_master, src=0)
grad = torch.chunk(grad_master, DEPTH, dim=0)[i]
# grad = torch.chunk(grad, DEPTH, dim=-1)[j]
grad = grad.clone()
out.backward(grad)
grad_master = grad_master.clone()
C_master.backward(grad_master)
A_grad = A_master.grad
A_grad = torch.chunk(A_grad, DEPTH, dim=0)[i]
A_grad = torch.chunk(A_grad, DEPTH, dim=-1)[j]
check_equal(A_grad, A.grad)
W_grad = W_master.grad
W_grad = torch.chunk(W_grad, DEPTH, dim=-1)[j]
W_grad = torch.chunk(W_grad, DEPTH, dim=-1)[i]
check_equal(W_grad, layer.weight.grad)
B_grad = B_master.grad
# B_grad = torch.chunk(B_grad, DEPTH, dim=0)[j]
# if i == 0:
check_equal(B_grad, layer.bias.grad)
print_rank_0('classifier (no given weight) backward: pass')
def check_vocab_parallel_classifier_no_given_weight():
device = get_current_device()
dtype = torch.float32
j = gpc.get_local_rank(ParallelMode.PARALLEL_2D_ROW)
i = gpc.get_local_rank(ParallelMode.PARALLEL_2D_COL)
layer = VocabParallelClassifier2D(HIDDEN_SIZE, VOCAB_SIZE, bias=True)
layer = layer.to(dtype).to(device)
layer_master = VanillaClassifier(HIDDEN_SIZE, VOCAB_SIZE, bias=True)
layer_master = layer_master.to(dtype).to(device)
weight_master = layer_master.weight.data
torch.distributed.broadcast(weight_master, src=0)
weight = torch.chunk(weight_master, DEPTH, dim=0)[i]
weight = torch.chunk(weight, DEPTH, dim=-1)[j]
layer.weight.data.copy_(weight)
bias_master = layer_master.bias.data
torch.distributed.broadcast(bias_master, src=0)
bias = torch.chunk(bias_master, DEPTH)[j]
bias = torch.chunk(bias, DEPTH)[i]
layer.bias.data.copy_(bias)
A_shape = (BATCH_SIZE, SEQ_LENGTH, HIDDEN_SIZE)
A_master = torch.randn(A_shape, dtype=dtype, device=device)
torch.distributed.broadcast(A_master, src=0)
A = torch.chunk(A_master, DEPTH, dim=0)[i]
A = torch.chunk(A, DEPTH, dim=-1)[j]
A = A.clone()
A.requires_grad = True
out = layer(A)
A_master = A_master.clone()
A_master.requires_grad = True
C_master = layer_master(A_master)
C = torch.chunk(C_master, DEPTH, dim=0)[i]
C = torch.chunk(C, DEPTH, dim=-1)[j]
check_equal(out, C)
print_rank_0('vocab parallel classifier (no given weight) forward: pass')
grad_shape = C_master.shape
grad_master = torch.randn(grad_shape, dtype=dtype, device=device)
torch.distributed.broadcast(grad_master, src=0)
grad = torch.chunk(grad_master, DEPTH, dim=0)[i]
grad = torch.chunk(grad, DEPTH, dim=-1)[j]
grad = grad.clone()
out.backward(grad)
grad_master = grad_master.clone()
C_master.backward(grad_master)
A_grad = A_master.grad
A_grad = torch.chunk(A_grad, DEPTH, dim=0)[i]
A_grad = torch.chunk(A_grad, DEPTH, dim=-1)[j]
check_equal(A_grad, A.grad)
W_grad = layer_master.weight.grad
W_grad = torch.chunk(W_grad, DEPTH, dim=0)[i]
W_grad = torch.chunk(W_grad, DEPTH, dim=-1)[j]
check_equal(W_grad, layer.weight.grad)
B_grad = layer_master.bias.grad
B_grad = torch.chunk(B_grad, DEPTH)[j]
B_grad = torch.chunk(B_grad, DEPTH)[i]
check_equal(B_grad, layer.bias.grad)
print_rank_0('vocab parallel classifier (no given weight) backward: pass')
def check_classifier_given_embed_weight():
device = get_current_device()
dtype = torch.float32
j = gpc.get_local_rank(ParallelMode.PARALLEL_2D_ROW)
i = gpc.get_local_rank(ParallelMode.PARALLEL_2D_COL)
embed = Embedding2D(VOCAB_SIZE, HIDDEN_SIZE)
embed = embed.to(dtype).to(device)
embed_master = torch.nn.Embedding(VOCAB_SIZE, HIDDEN_SIZE)
embed_master = embed_master.to(dtype).to(device)
weight_master = embed_master.weight.data
torch.distributed.broadcast(weight_master, src=0)
weight = torch.chunk(weight_master, DEPTH, dim=-1)[j]
weight = torch.chunk(weight, DEPTH, dim=-1)[i]
embed.weight.data.copy_(weight)
layer = Classifier2D(HIDDEN_SIZE, VOCAB_SIZE, weight=embed.weight, bias=False)
layer = layer.to(dtype).to(device)
layer_master = VanillaClassifier(HIDDEN_SIZE, VOCAB_SIZE, weight=embed_master.weight, bias=False)
layer_master = layer_master.to(dtype).to(device)
A_shape = (BATCH_SIZE, SEQ_LENGTH)
A_master = torch.randint(VOCAB_SIZE, A_shape, device=device)
torch.distributed.broadcast(A_master, src=0)
A = A_master.clone()
out = layer(embed(A))
A_master = A_master.clone()
C_master = layer_master(embed_master(A_master))
C = torch.chunk(C_master, DEPTH, dim=0)[i]
check_equal(out, C)
print_rank_0('classifier (given embed weight) forward: pass')
grad_shape = C_master.shape
grad_master = torch.randn(grad_shape, dtype=dtype, device=device)
torch.distributed.broadcast(grad_master, src=0)
grad = torch.chunk(grad_master, DEPTH, dim=0)[i]
grad = grad.clone()
out.backward(grad)
grad_master = grad_master.clone()
C_master.backward(grad_master)
W_grad = embed_master.weight.grad
W_grad = torch.chunk(W_grad, DEPTH, dim=-1)[j]
W_grad = torch.chunk(W_grad, DEPTH, dim=-1)[i]
check_equal(W_grad, embed.weight.grad)
print_rank_0('classifier (given embed weight) backward: pass')
def check_vocab_parallel_classifier_given_embed_weight():
device = get_current_device()
dtype = torch.float32
j = gpc.get_local_rank(ParallelMode.PARALLEL_2D_ROW)
i = gpc.get_local_rank(ParallelMode.PARALLEL_2D_COL)
embed = VocabParallelEmbedding2D(VOCAB_SIZE, HIDDEN_SIZE)
embed = embed.to(dtype).to(device)
embed_master = torch.nn.Embedding(VOCAB_SIZE, HIDDEN_SIZE)
embed_master = embed_master.to(dtype).to(device)
weight_master = embed_master.weight.data
torch.distributed.broadcast(weight_master, src=0)
weight = torch.chunk(weight_master, DEPTH, dim=-1)[j]
weight = torch.chunk(weight, DEPTH, dim=0)[i]
embed.weight.data.copy_(weight)
layer = VocabParallelClassifier2D(HIDDEN_SIZE, VOCAB_SIZE, weight=embed.weight, bias=False)
layer = layer.to(dtype).to(device)
layer_master = VanillaClassifier(HIDDEN_SIZE, VOCAB_SIZE, weight=embed_master.weight, bias=False)
layer_master = layer_master.to(dtype).to(device)
A_shape = (BATCH_SIZE, SEQ_LENGTH)
A_master = torch.randint(VOCAB_SIZE, A_shape, device=device)
torch.distributed.broadcast(A_master, src=0)
A = A_master.clone()
out = layer(embed(A))
A_master = A_master.clone()
C_master = layer_master(embed_master(A_master))
C = torch.chunk(C_master, DEPTH, dim=0)[i]
C = torch.chunk(C, DEPTH, dim=-1)[j]
check_equal(out, C)
print_rank_0('vocab parallel classifier (given embed weight) forward: pass')
grad_shape = C_master.shape
grad_master = torch.randn(grad_shape, dtype=dtype, device=device)
torch.distributed.broadcast(grad_master, src=0)
grad = torch.chunk(grad_master, DEPTH, dim=0)[i]
grad = torch.chunk(grad, DEPTH, dim=-1)[j]
grad = grad.clone()
out.backward(grad)
grad_master = grad_master.clone()
C_master.backward(grad_master)
W_grad = embed_master.weight.grad
W_grad = torch.chunk(W_grad, DEPTH, dim=-1)[j]
W_grad = torch.chunk(W_grad, DEPTH, dim=0)[i]
check_equal(W_grad, embed.weight.grad)
print_rank_0('vocab parallel classifier (given embed weight) backward: pass')
def check_loss():
device = get_current_device()
dtype = torch.float32
j = gpc.get_local_rank(ParallelMode.PARALLEL_2D_ROW)
i = gpc.get_local_rank(ParallelMode.PARALLEL_2D_COL)
criterion = CrossEntropyLoss2D()
criterion_master = torch.nn.CrossEntropyLoss()
out_shape = (BATCH_SIZE, NUM_CLASSES)
out_master = torch.randn(out_shape, dtype=dtype, device=device)
target_master = torch.randint(NUM_CLASSES, (BATCH_SIZE, ), dtype=torch.long, device=device)
torch.distributed.broadcast(out_master, src=0)
torch.distributed.broadcast(target_master, src=0)
out = torch.chunk(out_master, DEPTH, dim=0)[i]
out = out.clone()
out.requires_grad = True
loss = criterion(out, target_master)
out_master = out_master.clone()
out_master.requires_grad = True
loss_master = criterion_master(out_master, target_master)
check_equal(loss, loss_master)
print_rank_0('cross entropy loss forward: pass')
loss.backward()
loss_master.backward()
out_grad = out_master.grad
out_grad = torch.chunk(out_grad, DEPTH, dim=0)[i]
check_equal(out_grad, out.grad)
print_rank_0('cross entropy loss backward: pass')
def check_vocab_parallel_loss():
device = get_current_device()
dtype = torch.float32
j = gpc.get_local_rank(ParallelMode.PARALLEL_2D_ROW)
i = gpc.get_local_rank(ParallelMode.PARALLEL_2D_COL)
criterion = VocabParallelCrossEntropyLoss2D()
criterion_master = torch.nn.CrossEntropyLoss()
out_shape = (BATCH_SIZE, NUM_CLASSES)
out_master = torch.randn(out_shape, dtype=dtype, device=device)
target_master = torch.randint(NUM_CLASSES, (BATCH_SIZE, ), dtype=torch.long, device=device)
torch.distributed.broadcast(out_master, src=0)
torch.distributed.broadcast(target_master, src=0)
out = torch.chunk(out_master, DEPTH, dim=0)[i]
out = torch.chunk(out, DEPTH, dim=-1)[j]
out = out.clone()
out.requires_grad = True
loss = criterion(out, target_master)
out_master = out_master.clone()
out_master.requires_grad = True
loss_master = criterion_master(out_master, target_master)
check_equal(loss, loss_master)
print_rank_0('vocab parallel cross entropy loss forward: pass')
loss.backward()
loss_master.backward()
out_grad = out_master.grad
out_grad = torch.chunk(out_grad, DEPTH, dim=0)[i]
out_grad = torch.chunk(out_grad, DEPTH, dim=-1)[j]
check_equal(out_grad, out.grad)
print_rank_0('vocab parallel cross entropy loss backward: pass')
# def check_attention(): # def check_attention():
# device = get_current_device() # device = get_current_device()
# dtype = torch.float32 # dtype = torch.float32
@ -257,7 +670,6 @@ def check_layernorm():
# assert A.grad.shape == A.shape # assert A.grad.shape == A.shape
# print_rank_0('self attention backward: pass') # print_rank_0('self attention backward: pass')
# def check_mlp(): # def check_mlp():
# device = get_current_device() # device = get_current_device()
# dtype = torch.float32 # dtype = torch.float32
@ -291,7 +703,6 @@ def check_layernorm():
# assert A.grad.shape == A.shape # assert A.grad.shape == A.shape
# print_rank_0('mlp backward: pass') # print_rank_0('mlp backward: pass')
# def check_transformerlayer(): # def check_transformerlayer():
# device = get_current_device() # device = get_current_device()
# dtype = torch.float32 # dtype = torch.float32

View File

@ -8,6 +8,9 @@ BATCH_SIZE = 8
SEQ_LENGTH = 8 SEQ_LENGTH = 8
HIDDEN_SIZE = 8 HIDDEN_SIZE = 8
NUM_CLASSES = 8 NUM_CLASSES = 8
VOCAB_SIZE = 16
IMG_SIZE = 16
def check_equal(A, B): def check_equal(A, B):
assert torch.allclose(A, B, rtol=1e-3, atol=1e-2) == True assert torch.allclose(A, B, rtol=1e-3, atol=1e-2)

View File

@ -8,20 +8,17 @@ import torch
import torch.multiprocessing as mp import torch.multiprocessing as mp
from colossalai.core import global_context as gpc from colossalai.core import global_context as gpc
from colossalai.initialize import launch from colossalai.initialize import launch
from colossalai.logging import disable_existing_loggers
from colossalai.utils import free_port from colossalai.utils import free_port
from checks_2d.check_layer_2d import * from checks_2d.check_layer_2d import (check_classifier_given_embed_weight, check_classifier_no_given_weight,
from checks_2d.check_operation_2d import * check_embed, check_layernorm, check_linear, check_loss, check_patch_embed,
check_vocab_parallel_classifier_given_embed_weight,
check_vocab_parallel_classifier_no_given_weight, check_vocab_parallel_embed,
check_vocab_parallel_loss)
from checks_2d.check_operation_2d import check_AB, check_ABT, check_ATB
CONFIG = dict( CONFIG = dict(parallel=dict(pipeline=dict(size=1), tensor=dict(size=4, mode='2d')), )
parallel=dict(
pipeline=dict(size=1),
tensor=dict(
size=4,
mode='2d'
)
),
)
def check_operations(): def check_operations():
@ -33,16 +30,24 @@ def check_operations():
def check_layer(): def check_layer():
check_linear() check_linear()
check_layernorm() check_layernorm()
check_classifier() check_embed()
check_patch_embed()
check_vocab_parallel_embed()
check_classifier_no_given_weight()
check_vocab_parallel_classifier_no_given_weight()
check_classifier_given_embed_weight()
check_vocab_parallel_classifier_given_embed_weight()
check_loss()
check_vocab_parallel_loss()
def check_layer_and_operation(rank, world_size, port): def check_layer_and_operation(rank, world_size, port):
launch(config=CONFIG, disable_existing_loggers()
rank=rank, launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
world_size=world_size,
host='localhost',
port=port,
backend='nccl')
torch.backends.cuda.matmul.allow_tf32 = False
torch.backends.cudnn.allow_tf32 = False
torch.backends.cudnn.deterministic = True
# check_operations() # check_operations()
check_layer() check_layer()
gpc.destroy() gpc.destroy()
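The per-rank entry point above is typically driven by a pytest wrapper that spawns one process per device; that wrapper is not shown in this excerpt, but it usually looks roughly like the following (the world size, marker, and test name are assumptions for illustration):
# Hypothetical pytest driver for the per-rank check above (not part of this diff).
from functools import partial
import pytest
import torch.multiprocessing as mp
from colossalai.utils import free_port
@pytest.mark.dist
def test_2d():
    world_size = 4
    run_func = partial(check_layer_and_operation, world_size=world_size, port=free_port())
    mp.spawn(run_func, nprocs=world_size)    # spawns one rank per process; rank is passed as the first argument
if __name__ == '__main__':
    test_2d()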

View File

@ -1,11 +1,12 @@
import torch import torch
from torch.nn import Parameter
from colossalai.context.parallel_mode import ParallelMode from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc from colossalai.core import global_context as gpc
from colossalai.nn import Linear2p5D, LayerNorm2p5D, Classifier2p5D from colossalai.nn import (Classifier2p5D, CrossEntropyLoss2p5D, Embedding2p5D, LayerNorm2p5D, Linear2p5D,
from colossalai.utils import get_current_device PatchEmbedding2p5D, VanillaClassifier, VanillaPatchEmbedding, VocabParallelClassifier2p5D,
from colossalai.utils import print_rank_0 VocabParallelCrossEntropyLoss2p5D, VocabParallelEmbedding2p5D)
from colossalai.utils import get_current_device, print_rank_0
from torch.nn import Parameter
from .common import * from .common import *
@ -19,11 +20,7 @@ def check_linear():
j = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_ROW) j = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_ROW)
k = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_DEP) k = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_DEP)
layer = Linear2p5D( layer = Linear2p5D(INPUT_SIZE, OUTPUT_SIZE, dtype=dtype, skip_bias_add=False)
INPUT_SIZE,
OUTPUT_SIZE,
dtype=dtype,
skip_bias_add=False)
A_shape = (BATCH_SIZE, SEQ_LENGTH, INPUT_SIZE) A_shape = (BATCH_SIZE, SEQ_LENGTH, INPUT_SIZE)
A_master = torch.randn(A_shape, dtype=dtype, device=device) A_master = torch.randn(A_shape, dtype=dtype, device=device)
@ -94,86 +91,6 @@ def check_linear():
print_rank_0('linear backward: pass') print_rank_0('linear backward: pass')
def check_classifier():
device = get_current_device()
dtype = torch.float32
INPUT_SIZE = HIDDEN_SIZE
OUTPUT_SIZE = NUM_CLASSES
j = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_ROW)
i = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_COL)
layer = Classifier2p5D(INPUT_SIZE, OUTPUT_SIZE)
A_shape = (BATCH_SIZE, SEQ_LENGTH, INPUT_SIZE)
A_master = torch.randint(5, A_shape, dtype=dtype, device=device)
torch.distributed.broadcast(A_master, src=0)
A = torch.chunk(A_master, TESSERACT_DIM, dim=0)[i]
A = torch.chunk(A, TESSERACT_DIM, dim=-1)[j]
A = A.clone()
A.requires_grad = True
W_shape = (OUTPUT_SIZE, INPUT_SIZE)
W_master = torch.randint(5, W_shape, dtype=dtype, device=device)
torch.distributed.broadcast(W_master, src=0)
# W = torch.chunk(W_master, TESSERACT_DIM, dim=-1)[j]
W = torch.chunk(W_master, TESSERACT_DIM, dim=-1)[j]
W = torch.chunk(W, TESSERACT_DIM, dim=-1)[i]
W = W.clone()
layer.weight.data.copy_(W)
# W.requires_grad = True
B_shape = (OUTPUT_SIZE,)
B_master = torch.randint(5, B_shape, dtype=dtype, device=device)
torch.distributed.broadcast(B_master, src=0)
# B = torch.chunk(B_master, TESSERACT_DIM, dim=0)[j]
B = B_master.clone()
layer.bias.data.copy_(B)
out = layer(A)
A_master = A_master.clone()
A_master.requires_grad = True
W_master = W_master.clone()
W_master.requires_grad = True
B_master = B_master.clone()
B_master.requires_grad = True
C_master = torch.matmul(A_master, W_master.transpose(0, 1)) + B_master
C = torch.chunk(C_master, TESSERACT_DIM, dim=0)[i]
# C = torch.chunk(C, TESSERACT_DIM, dim=-1)[j]
check_equal(out, C)
print_rank_0('classifier forward: pass')
grad_shape = C_master.shape
grad_master = torch.randn(grad_shape, dtype=dtype, device=get_current_device())
torch.distributed.broadcast(grad_master, src=0)
grad = torch.chunk(grad_master, TESSERACT_DIM, dim=0)[i]
# grad = torch.chunk(grad, TESSERACT_DIM, dim=-1)[j]
grad = grad.clone()
out.backward(grad)
grad_master = grad_master.clone()
C_master.backward(grad_master)
A_grad = A_master.grad
A_grad = torch.chunk(A_grad, TESSERACT_DIM, dim=0)[i]
A_grad = torch.chunk(A_grad, TESSERACT_DIM, dim=-1)[j]
check_equal(A_grad, A.grad)
W_grad = W_master.grad
W_grad = torch.chunk(W_grad, TESSERACT_DIM, dim=-1)[j]
W_grad = torch.chunk(W_grad, TESSERACT_DIM, dim=-1)[i]
check_equal(W_grad, layer.weight.grad)
B_grad = B_master.grad
# B_grad = torch.chunk(B_grad, TESSERACT_DIM, dim=0)[j]
# if i == 0:
check_equal(B_grad, layer.bias.grad)
print_rank_0('classifier backward: pass')
def check_layernorm(): def check_layernorm():
device = get_current_device() device = get_current_device()
dtype = torch.float32 dtype = torch.float32
@ -184,9 +101,7 @@ def check_layernorm():
j = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_ROW) j = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_ROW)
k = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_DEP) k = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_DEP)
layernorm = LayerNorm2p5D( layernorm = LayerNorm2p5D(INPUT_SIZE, dtype=dtype)
INPUT_SIZE,
dtype=dtype)
A_shape = (BATCH_SIZE, SEQ_LENGTH, INPUT_SIZE) A_shape = (BATCH_SIZE, SEQ_LENGTH, INPUT_SIZE)
A_master = torch.randn(A_shape, dtype=dtype, device=device) A_master = torch.randn(A_shape, dtype=dtype, device=device)
@ -228,6 +143,500 @@ def check_layernorm():
print_rank_0('layer norm backward: pass') print_rank_0('layer norm backward: pass')
def check_embed():
device = get_current_device()
dtype = torch.float32
i = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_COL)
j = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_ROW)
k = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_DEP)
embed = Embedding2p5D(VOCAB_SIZE, HIDDEN_SIZE)
embed = embed.to(dtype).to(device)
embed_master = torch.nn.Embedding(VOCAB_SIZE, HIDDEN_SIZE)
embed_master = embed_master.to(dtype).to(device)
weight_master = embed_master.weight.data
torch.distributed.broadcast(weight_master, src=0)
weight = torch.chunk(weight_master, TESSERACT_DIM, dim=-1)[j]
weight = torch.chunk(weight, TESSERACT_DIM, dim=-1)[i]
embed.weight.data.copy_(weight)
A_shape = (BATCH_SIZE, SEQ_LENGTH)
A_master = torch.randint(VOCAB_SIZE, A_shape, device=device)
torch.distributed.broadcast(A_master, src=0)
A = A_master.clone()
out = embed(A)
A_master = A_master.clone()
C_master = embed_master(A_master)
C = torch.chunk(C_master, TESSERACT_DIM, dim=0)[i]
C = torch.chunk(C, TESSERACT_DIM, dim=-1)[j]
check_equal(out, C)
print_rank_0('embed forward: pass')
grad_shape = C_master.shape
grad_master = torch.randn(grad_shape, dtype=dtype, device=device)
torch.distributed.broadcast(grad_master, src=0)
grad = torch.chunk(grad_master, TESSERACT_DIM, dim=0)[i]
grad = torch.chunk(grad, TESSERACT_DIM, dim=-1)[j]
grad = grad.clone()
out.backward(grad)
grad_master = grad_master.clone()
C_master.backward(grad_master)
B_grad = embed_master.weight.grad
B_grad = torch.chunk(B_grad, TESSERACT_DIM, dim=-1)[j]
B_grad = torch.chunk(B_grad, TESSERACT_DIM, dim=-1)[i]
check_equal(B_grad, embed.weight.grad)
print_rank_0('embed backward: pass')
def check_patch_embed():
device = get_current_device()
dtype = torch.float32
i = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_COL)
j = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_ROW)
k = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_DEP)
layer = PatchEmbedding2p5D(IMG_SIZE, 4, 3, HIDDEN_SIZE, dtype=dtype)
torch.nn.init.ones_(layer.cls_token)
torch.nn.init.ones_(layer.pos_embed)
layer = layer.to(device)
layer_master = VanillaPatchEmbedding(IMG_SIZE, 4, 3, HIDDEN_SIZE, dtype=dtype)
torch.nn.init.ones_(layer_master.cls_token)
torch.nn.init.ones_(layer_master.pos_embed)
layer_master = layer_master.to(device)
proj_weight_master = layer_master.weight.data
torch.distributed.broadcast(proj_weight_master, src=0)
proj_weight = torch.chunk(proj_weight_master, TESSERACT_DIM, dim=0)[j]
proj_weight = torch.chunk(proj_weight, TESSERACT_DIM, dim=0)[i]
layer.weight.data.copy_(proj_weight)
proj_bias_master = layer_master.bias.data
torch.distributed.broadcast(proj_bias_master, src=0)
proj_bias = torch.chunk(proj_bias_master, TESSERACT_DIM, dim=0)[j]
proj_bias = torch.chunk(proj_bias, TESSERACT_DIM, dim=0)[i]
layer.bias.data.copy_(proj_bias)
A_shape = (BATCH_SIZE, 3, IMG_SIZE, IMG_SIZE)
A_master = torch.randn(A_shape, dtype=dtype, device=device)
torch.distributed.broadcast(A_master, src=0)
A = A_master.clone()
out = layer(A)
A_master = A_master.clone()
C_master = layer_master(A_master)
C = torch.chunk(C_master, TESSERACT_DIM, dim=0)[i]
C = torch.chunk(C, TESSERACT_DIM, dim=-1)[j]
check_equal(out, C)
print_rank_0('patch embed forward: pass')
grad_shape = C_master.shape
grad_master = torch.randn(grad_shape, dtype=dtype, device=device)
torch.distributed.broadcast(grad_master, src=0)
grad = torch.chunk(grad_master, TESSERACT_DIM, dim=0)[i]
grad = torch.chunk(grad, TESSERACT_DIM, dim=-1)[j]
grad = grad.clone()
out.backward(grad)
grad_master = grad_master.clone()
C_master.backward(grad_master)
cls_grad_master = layer_master.cls_token.grad
cls_grad = torch.chunk(cls_grad_master, TESSERACT_DIM, dim=-1)[j]
cls_grad = torch.chunk(cls_grad, TESSERACT_DIM, dim=-1)[i]
check_equal(cls_grad, layer.cls_token.grad)
pos_grad_master = layer_master.pos_embed.grad
pos_grad = torch.chunk(pos_grad_master, TESSERACT_DIM, dim=-1)[j]
pos_grad = torch.chunk(pos_grad, TESSERACT_DIM, dim=-1)[i]
check_equal(pos_grad, layer.pos_embed.grad)
B_grad = layer_master.weight.grad
B_grad = torch.chunk(B_grad, TESSERACT_DIM, dim=0)[j]
B_grad = torch.chunk(B_grad, TESSERACT_DIM, dim=0)[i]
check_equal(B_grad, layer.weight.grad)
bias_grad = layer_master.bias.grad
bias_grad = torch.chunk(bias_grad, TESSERACT_DIM)[j]
bias_grad = torch.chunk(bias_grad, TESSERACT_DIM)[i]
check_equal(bias_grad, layer.bias.grad)
print_rank_0('patch embed backward: pass')
def check_vocab_parallel_embed():
device = get_current_device()
dtype = torch.float32
i = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_COL)
j = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_ROW)
k = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_DEP)
embed = VocabParallelEmbedding2p5D(VOCAB_SIZE, HIDDEN_SIZE)
embed = embed.to(dtype).to(device)
embed_master = torch.nn.Embedding(VOCAB_SIZE, HIDDEN_SIZE)
embed_master = embed_master.to(dtype).to(device)
weight_master = embed_master.weight.data
torch.distributed.broadcast(weight_master, src=0)
weight = torch.chunk(weight_master, TESSERACT_DIM, dim=-1)[j]
weight = torch.chunk(weight, TESSERACT_DIM, dim=0)[i]
embed.weight.data.copy_(weight)
A_shape = (BATCH_SIZE, SEQ_LENGTH)
A_master = torch.randint(VOCAB_SIZE, A_shape, device=device)
torch.distributed.broadcast(A_master, src=0)
A = A_master.clone()
out = embed(A)
A_master = A_master.clone()
C_master = embed_master(A_master)
C = torch.chunk(C_master, TESSERACT_DIM, dim=0)[i]
C = torch.chunk(C, TESSERACT_DIM, dim=-1)[j]
check_equal(out, C)
print_rank_0('vocab parallel embed forward: pass')
grad_shape = C_master.shape
grad_master = torch.randn(grad_shape, dtype=dtype, device=device)
torch.distributed.broadcast(grad_master, src=0)
grad = torch.chunk(grad_master, TESSERACT_DIM, dim=0)[i]
grad = torch.chunk(grad, TESSERACT_DIM, dim=-1)[j]
grad = grad.clone()
out.backward(grad)
grad_master = grad_master.clone()
C_master.backward(grad_master)
B_grad = embed_master.weight.grad
B_grad = torch.chunk(B_grad, TESSERACT_DIM, dim=-1)[j]
B_grad = torch.chunk(B_grad, TESSERACT_DIM, dim=0)[i]
check_equal(B_grad, embed.weight.grad)
print_rank_0('vocab parallel embed backward: pass')
def check_classifier_no_given_weight():
device = get_current_device()
dtype = torch.float32
INPUT_SIZE = HIDDEN_SIZE
OUTPUT_SIZE = NUM_CLASSES
j = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_ROW)
i = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_COL)
layer = Classifier2p5D(INPUT_SIZE, OUTPUT_SIZE)
A_shape = (BATCH_SIZE, SEQ_LENGTH, INPUT_SIZE)
A_master = torch.randint(5, A_shape, dtype=dtype, device=device)
torch.distributed.broadcast(A_master, src=0)
A = torch.chunk(A_master, TESSERACT_DIM, dim=0)[i]
A = torch.chunk(A, TESSERACT_DIM, dim=-1)[j]
A = A.clone()
A.requires_grad = True
W_shape = (OUTPUT_SIZE, INPUT_SIZE)
W_master = torch.randint(5, W_shape, dtype=dtype, device=device)
torch.distributed.broadcast(W_master, src=0)
# W = torch.chunk(W_master, TESSERACT_DIM, dim=-1)[j]
W = torch.chunk(W_master, TESSERACT_DIM, dim=-1)[j]
W = torch.chunk(W, TESSERACT_DIM, dim=-1)[i]
W = W.clone()
layer.weight.data.copy_(W)
# W.requires_grad = True
B_shape = (OUTPUT_SIZE, )
B_master = torch.randint(5, B_shape, dtype=dtype, device=device)
torch.distributed.broadcast(B_master, src=0)
# B = torch.chunk(B_master, TESSERACT_DIM, dim=0)[j]
B = B_master.clone()
layer.bias.data.copy_(B)
out = layer(A)
A_master = A_master.clone()
A_master.requires_grad = True
W_master = W_master.clone()
W_master.requires_grad = True
B_master = B_master.clone()
B_master.requires_grad = True
C_master = torch.matmul(A_master, W_master.transpose(0, 1)) + B_master
C = torch.chunk(C_master, TESSERACT_DIM, dim=0)[i]
# C = torch.chunk(C, TESSERACT_DIM, dim=-1)[j]
check_equal(out, C)
print_rank_0('classifier (no given weight) forward: pass')
grad_shape = C_master.shape
grad_master = torch.randn(grad_shape, dtype=dtype, device=get_current_device())
torch.distributed.broadcast(grad_master, src=0)
grad = torch.chunk(grad_master, TESSERACT_DIM, dim=0)[i]
# grad = torch.chunk(grad, TESSERACT_DIM, dim=-1)[j]
grad = grad.clone()
out.backward(grad)
grad_master = grad_master.clone()
C_master.backward(grad_master)
A_grad = A_master.grad
A_grad = torch.chunk(A_grad, TESSERACT_DIM, dim=0)[i]
A_grad = torch.chunk(A_grad, TESSERACT_DIM, dim=-1)[j]
check_equal(A_grad, A.grad)
W_grad = W_master.grad
W_grad = torch.chunk(W_grad, TESSERACT_DIM, dim=-1)[j]
W_grad = torch.chunk(W_grad, TESSERACT_DIM, dim=-1)[i]
check_equal(W_grad, layer.weight.grad)
B_grad = B_master.grad
# B_grad = torch.chunk(B_grad, TESSERACT_DIM, dim=0)[j]
# if i == 0:
check_equal(B_grad, layer.bias.grad)
print_rank_0('classifier (no given weight) backward: pass')
def check_vocab_parallel_classifier_no_given_weight():
device = get_current_device()
dtype = torch.float32
i = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_COL)
j = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_ROW)
k = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_DEP)
layer = VocabParallelClassifier2p5D(HIDDEN_SIZE, VOCAB_SIZE, bias=True)
layer = layer.to(dtype).to(device)
layer_master = VanillaClassifier(HIDDEN_SIZE, VOCAB_SIZE, bias=True)
layer_master = layer_master.to(dtype).to(device)
weight_master = layer_master.weight.data
torch.distributed.broadcast(weight_master, src=0)
weight = torch.chunk(weight_master, TESSERACT_DIM, dim=0)[i]
weight = torch.chunk(weight, TESSERACT_DIM, dim=-1)[j]
layer.weight.data.copy_(weight)
bias_master = layer_master.bias.data
torch.distributed.broadcast(bias_master, src=0)
bias = torch.chunk(bias_master, TESSERACT_DIM)[j]
layer.bias.data.copy_(bias)
A_shape = (BATCH_SIZE, SEQ_LENGTH, HIDDEN_SIZE)
A_master = torch.randn(A_shape, dtype=dtype, device=device)
torch.distributed.broadcast(A_master, src=0)
A = torch.chunk(A_master, TESSERACT_DIM, dim=0)[i]
A = torch.chunk(A, TESSERACT_DIM, dim=-1)[j]
A = A.clone()
A.requires_grad = True
out = layer(A)
A_master = A_master.clone()
A_master.requires_grad = True
C_master = layer_master(A_master)
C = torch.chunk(C_master, TESSERACT_DIM, dim=0)[i]
C = torch.chunk(C, TESSERACT_DIM, dim=-1)[j]
check_equal(out, C)
print_rank_0('vocab parallel classifier (no given weight) forward: pass')
grad_shape = C_master.shape
grad_master = torch.randn(grad_shape, dtype=dtype, device=device)
torch.distributed.broadcast(grad_master, src=0)
grad = torch.chunk(grad_master, TESSERACT_DIM, dim=0)[i]
grad = torch.chunk(grad, TESSERACT_DIM, dim=-1)[j]
grad = grad.clone()
out.backward(grad)
grad_master = grad_master.clone()
C_master.backward(grad_master)
A_grad = A_master.grad
A_grad = torch.chunk(A_grad, TESSERACT_DIM, dim=0)[i]
A_grad = torch.chunk(A_grad, TESSERACT_DIM, dim=-1)[j]
check_equal(A_grad, A.grad)
W_grad = layer_master.weight.grad
W_grad = torch.chunk(W_grad, TESSERACT_DIM, dim=0)[i]
W_grad = torch.chunk(W_grad, TESSERACT_DIM, dim=-1)[j]
check_equal(W_grad, layer.weight.grad)
B_grad = layer_master.bias.grad
B_grad = torch.chunk(B_grad, TESSERACT_DIM)[j]
if i == 0:
check_equal(B_grad, layer.bias.grad)
print_rank_0('vocab parallel classifier (no given weight) backward: pass')
def check_classifier_given_embed_weight():
device = get_current_device()
dtype = torch.float32
i = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_COL)
j = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_ROW)
k = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_DEP)
embed = Embedding2p5D(VOCAB_SIZE, HIDDEN_SIZE)
embed = embed.to(dtype).to(device)
embed_master = torch.nn.Embedding(VOCAB_SIZE, HIDDEN_SIZE)
embed_master = embed_master.to(dtype).to(device)
weight_master = embed_master.weight.data
torch.distributed.broadcast(weight_master, src=0)
weight = torch.chunk(weight_master, TESSERACT_DIM, dim=-1)[j]
weight = torch.chunk(weight, TESSERACT_DIM, dim=-1)[i]
embed.weight.data.copy_(weight)
layer = Classifier2p5D(HIDDEN_SIZE, VOCAB_SIZE, weight=embed.weight, bias=False)
layer = layer.to(dtype).to(device)
layer_master = VanillaClassifier(HIDDEN_SIZE, VOCAB_SIZE, weight=embed_master.weight, bias=False)
layer_master = layer_master.to(dtype).to(device)
A_shape = (BATCH_SIZE, SEQ_LENGTH)
A_master = torch.randint(VOCAB_SIZE, A_shape, device=device)
torch.distributed.broadcast(A_master, src=0)
A = A_master.clone()
out = layer(embed(A))
A_master = A_master.clone()
C_master = layer_master(embed_master(A_master))
C = torch.chunk(C_master, TESSERACT_DIM, dim=0)[i]
check_equal(out, C)
print_rank_0('classifier (given embed weight) forward: pass')
grad_shape = C_master.shape
grad_master = torch.randn(grad_shape, dtype=dtype, device=device)
torch.distributed.broadcast(grad_master, src=0)
grad = torch.chunk(grad_master, TESSERACT_DIM, dim=0)[i]
grad = grad.clone()
out.backward(grad)
grad_master = grad_master.clone()
C_master.backward(grad_master)
W_grad = embed_master.weight.grad
W_grad = torch.chunk(W_grad, TESSERACT_DIM, dim=-1)[j]
W_grad = torch.chunk(W_grad, TESSERACT_DIM, dim=-1)[i]
check_equal(W_grad, embed.weight.grad)
print_rank_0('classifier (given embed weight) backward: pass')
def check_vocab_parallel_classifier_given_embed_weight():
device = get_current_device()
dtype = torch.float32
i = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_COL)
j = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_ROW)
k = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_DEP)
embed = VocabParallelEmbedding2p5D(VOCAB_SIZE, HIDDEN_SIZE)
embed = embed.to(dtype).to(device)
embed_master = torch.nn.Embedding(VOCAB_SIZE, HIDDEN_SIZE)
embed_master = embed_master.to(dtype).to(device)
weight_master = embed_master.weight.data
torch.distributed.broadcast(weight_master, src=0)
weight = torch.chunk(weight_master, TESSERACT_DIM, dim=-1)[j]
weight = torch.chunk(weight, TESSERACT_DIM, dim=0)[i]
embed.weight.data.copy_(weight)
layer = VocabParallelClassifier2p5D(HIDDEN_SIZE, VOCAB_SIZE, weight=embed.weight, bias=False)
layer = layer.to(dtype).to(device)
layer_master = VanillaClassifier(HIDDEN_SIZE, VOCAB_SIZE, weight=embed_master.weight, bias=False)
layer_master = layer_master.to(dtype).to(device)
A_shape = (BATCH_SIZE, SEQ_LENGTH)
A_master = torch.randint(VOCAB_SIZE, A_shape, device=device)
torch.distributed.broadcast(A_master, src=0)
A = A_master.clone()
out = layer(embed(A))
A_master = A_master.clone()
C_master = layer_master(embed_master(A_master))
C = torch.chunk(C_master, TESSERACT_DIM, dim=0)[i]
C = torch.chunk(C, TESSERACT_DIM, dim=-1)[j]
check_equal(out, C)
print_rank_0('vocab parallel classifier (given embed weight) forward: pass')
grad_shape = C_master.shape
grad_master = torch.randn(grad_shape, dtype=dtype, device=device)
torch.distributed.broadcast(grad_master, src=0)
grad = torch.chunk(grad_master, TESSERACT_DIM, dim=0)[i]
grad = torch.chunk(grad, TESSERACT_DIM, dim=-1)[j]
grad = grad.clone()
out.backward(grad)
grad_master = grad_master.clone()
C_master.backward(grad_master)
W_grad = embed_master.weight.grad
W_grad = torch.chunk(W_grad, TESSERACT_DIM, dim=-1)[j]
W_grad = torch.chunk(W_grad, TESSERACT_DIM, dim=0)[i]
check_equal(W_grad, embed.weight.grad)
print_rank_0('vocab parallel classifier (given embed weight) backward: pass')
def check_loss():
device = get_current_device()
dtype = torch.float32
i = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_COL)
j = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_ROW)
k = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_DEP)
criterion = CrossEntropyLoss2p5D()
criterion_master = torch.nn.CrossEntropyLoss()
out_shape = (BATCH_SIZE, NUM_CLASSES)
out_master = torch.randn(out_shape, dtype=dtype, device=device)
target_master = torch.randint(NUM_CLASSES, (BATCH_SIZE, ), dtype=torch.long, device=device)
torch.distributed.broadcast(out_master, src=0)
torch.distributed.broadcast(target_master, src=0)
out = torch.chunk(out_master, TESSERACT_DIM, dim=0)[i]
out = out.clone()
out.requires_grad = True
loss = criterion(out, target_master)
out_master = out_master.clone()
out_master.requires_grad = True
loss_master = criterion_master(out_master, target_master)
check_equal(loss, loss_master)
print_rank_0('cross entropy loss forward: pass')
loss.backward()
loss_master.backward()
out_grad = out_master.grad
out_grad = torch.chunk(out_grad, TESSERACT_DIM, dim=0)[i]
check_equal(out_grad, out.grad)
print_rank_0('cross entropy loss backward: pass')
def check_vocab_parallel_loss():
device = get_current_device()
dtype = torch.float32
i = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_COL)
j = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_ROW)
k = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_DEP)
criterion = VocabParallelCrossEntropyLoss2p5D()
criterion_master = torch.nn.CrossEntropyLoss()
out_shape = (BATCH_SIZE, NUM_CLASSES)
out_master = torch.randn(out_shape, dtype=dtype, device=device)
target_master = torch.randint(NUM_CLASSES, (BATCH_SIZE, ), dtype=torch.long, device=device)
torch.distributed.broadcast(out_master, src=0)
torch.distributed.broadcast(target_master, src=0)
out = torch.chunk(out_master, TESSERACT_DIM, dim=0)[i]
out = torch.chunk(out, TESSERACT_DIM, dim=-1)[j]
out = out.clone()
out.requires_grad = True
loss = criterion(out, target_master)
out_master = out_master.clone()
out_master.requires_grad = True
loss_master = criterion_master(out_master, target_master)
check_equal(loss, loss_master)
print_rank_0('vocab parallel cross entropy loss forward: pass')
loss.backward()
loss_master.backward()
out_grad = out_master.grad
out_grad = torch.chunk(out_grad, TESSERACT_DIM, dim=0)[i]
out_grad = torch.chunk(out_grad, TESSERACT_DIM, dim=-1)[j]
check_equal(out_grad, out.grad)
print_rank_0('vocab parallel cross entropy loss backward: pass')
# def check_attention(): # def check_attention():
# device = get_current_device() # device = get_current_device()
# dtype = torch.float32 # dtype = torch.float32
@ -267,7 +676,6 @@ def check_layernorm():
# assert A.grad.shape == A.shape # assert A.grad.shape == A.shape
# print_rank_0('self attention backward: pass') # print_rank_0('self attention backward: pass')
# def check_mlp(): # def check_mlp():
# device = get_current_device() # device = get_current_device()
# dtype = torch.float32 # dtype = torch.float32
@ -304,7 +712,6 @@ def check_layernorm():
# assert A.grad.shape == A.shape # assert A.grad.shape == A.shape
# print_rank_0('mlp backward: pass') # print_rank_0('mlp backward: pass')
# def check_transformerlayer(): # def check_transformerlayer():
# device = get_current_device() # device = get_current_device()
# dtype = torch.float32 # dtype = torch.float32
@ -344,4 +751,4 @@ def check_layernorm():
# out.backward(grad) # out.backward(grad)
# assert A.grad.shape == A.shape # assert A.grad.shape == A.shape
# print_rank_0('transformerlayer backward: pass') # print_rank_0('transformerlayer backward: pass')

View File

@ -5,8 +5,10 @@ TESSERACT_DEP = 2
BATCH_SIZE = 8 BATCH_SIZE = 8
SEQ_LENGTH = 8 SEQ_LENGTH = 8
HIDDEN_SIZE = 8 HIDDEN_SIZE = 8
NUM_CLASSES = 3 NUM_CLASSES = 8
VOCAB_SIZE = 16
IMG_SIZE = 16
def check_equal(A, B): def check_equal(A, B):
assert torch.allclose(A, B, rtol=1e-5, atol=1e-2) == True assert torch.allclose(A, B, rtol=1e-5, atol=1e-2)

View File

@ -5,10 +5,10 @@ import torch
import torch.multiprocessing as mp import torch.multiprocessing as mp
from colossalai.core import global_context as gpc from colossalai.core import global_context as gpc
from colossalai.initialize import launch from colossalai.initialize import launch
from colossalai.logging import disable_existing_loggers
from colossalai.utils import free_port from colossalai.utils import free_port
from checks_2p5d.check_layer_2p5d import (check_classifier, check_layernorm, from checks_2p5d.check_layer_2p5d import *
check_linear)
from checks_2p5d.check_operation_2p5d import check_AB, check_ABT, check_ATB from checks_2p5d.check_operation_2p5d import check_AB, check_ABT, check_ATB
CONFIG = dict( CONFIG = dict(
@ -28,10 +28,19 @@ def check_operations():
def check_layer(): def check_layer():
check_linear() check_linear()
check_layernorm() check_layernorm()
check_classifier() check_embed()
check_patch_embed()
check_vocab_parallel_embed()
check_classifier_no_given_weight()
check_vocab_parallel_classifier_no_given_weight()
check_classifier_given_embed_weight()
check_vocab_parallel_classifier_given_embed_weight()
check_loss()
check_vocab_parallel_loss()
def check_layer_and_operation(rank, world_size, port): def check_layer_and_operation(rank, world_size, port):
disable_existing_loggers()
launch(config=CONFIG, launch(config=CONFIG,
rank=rank, rank=rank,
world_size=world_size, world_size=world_size,
@ -39,6 +48,9 @@ def check_layer_and_operation(rank, world_size, port):
port=port, port=port,
backend='nccl') backend='nccl')
torch.backends.cuda.matmul.allow_tf32 = False
torch.backends.cudnn.allow_tf32 = False
torch.backends.cudnn.deterministic = True
check_operations() check_operations()
check_layer() check_layer()
gpc.destroy() gpc.destroy()

View File

@ -3,16 +3,17 @@
import time import time
from colossalai.constants import (INPUT_GROUP_3D, OUTPUT_GROUP_3D, WEIGHT_GROUP_3D) import torch
from colossalai.constants import INPUT_GROUP_3D, OUTPUT_GROUP_3D, WEIGHT_GROUP_3D
from colossalai.core import global_context from colossalai.core import global_context
from colossalai.logging import get_dist_logger from colossalai.logging import get_dist_logger
from colossalai.nn import (Classifier3D, CrossEntropyLoss3D, LayerNorm3D, Linear3D, PatchEmbedding3D, VanillaClassifier, from colossalai.nn import (Classifier3D, CrossEntropyLoss3D, Embedding3D, LayerNorm3D, Linear3D, PatchEmbedding3D,
VanillaPatchEmbedding) VanillaClassifier, VanillaPatchEmbedding, VocabParallelClassifier3D,
VocabParallelCrossEntropyLoss3D, VocabParallelEmbedding3D)
from colossalai.nn.layer.parallel_3d._utils import get_parallel_mode_from_env from colossalai.nn.layer.parallel_3d._utils import get_parallel_mode_from_env
from colossalai.utils import get_current_device, print_rank_0 from colossalai.utils import get_current_device, print_rank_0
from .common import * from .common import BATCH_SIZE, DEPTH, HIDDEN_SIZE, IMG_SIZE, NUM_CLASSES, SEQ_LENGTH, VOCAB_SIZE, check_equal
import torch
def check_linear(): def check_linear():
@@ -27,9 +28,9 @@ def check_linear():
weight_parallel_mode = get_parallel_mode_from_env(WEIGHT_GROUP_3D) weight_parallel_mode = get_parallel_mode_from_env(WEIGHT_GROUP_3D)
output_parallel_mode = get_parallel_mode_from_env(OUTPUT_GROUP_3D) output_parallel_mode = get_parallel_mode_from_env(OUTPUT_GROUP_3D)
j = A_rank = global_context.get_local_rank(input_parallel_mode) j = global_context.get_local_rank(input_parallel_mode)
i = B_rank = global_context.get_local_rank(weight_parallel_mode) i = global_context.get_local_rank(weight_parallel_mode)
k = C_rank = global_context.get_local_rank(output_parallel_mode) k = global_context.get_local_rank(output_parallel_mode)
layer = Linear3D(INPUT_SIZE, OUTPUT_SIZE, dtype=dtype, bias=True) layer = Linear3D(INPUT_SIZE, OUTPUT_SIZE, dtype=dtype, bias=True)
layer = layer.to(device) layer = layer.to(device)
@@ -112,9 +113,9 @@ def check_layernorm():
weight_parallel_mode = get_parallel_mode_from_env(WEIGHT_GROUP_3D) weight_parallel_mode = get_parallel_mode_from_env(WEIGHT_GROUP_3D)
output_parallel_mode = get_parallel_mode_from_env(OUTPUT_GROUP_3D) output_parallel_mode = get_parallel_mode_from_env(OUTPUT_GROUP_3D)
j = A_rank = global_context.get_local_rank(input_parallel_mode) j = global_context.get_local_rank(input_parallel_mode)
i = B_rank = global_context.get_local_rank(weight_parallel_mode) i = global_context.get_local_rank(weight_parallel_mode)
k = C_rank = global_context.get_local_rank(output_parallel_mode) k = global_context.get_local_rank(output_parallel_mode)
norm = LayerNorm3D(INPUT_SIZE, eps=1e-6, dtype=dtype) norm = LayerNorm3D(INPUT_SIZE, eps=1e-6, dtype=dtype)
norm = norm.to(device) norm = norm.to(device)
@@ -186,7 +187,7 @@ def check_layernorm():
return fwd_end - fwd_start, bwd_end - bwd_start return fwd_end - fwd_start, bwd_end - bwd_start
def check_classifier(): def check_classifier_no_given_weight():
rank = torch.distributed.get_rank() rank = torch.distributed.get_rank()
logger = get_dist_logger() logger = get_dist_logger()
device = get_current_device() device = get_current_device()
@@ -197,9 +198,9 @@ def check_classifier():
weight_parallel_mode = get_parallel_mode_from_env(WEIGHT_GROUP_3D) weight_parallel_mode = get_parallel_mode_from_env(WEIGHT_GROUP_3D)
output_parallel_mode = get_parallel_mode_from_env(OUTPUT_GROUP_3D) output_parallel_mode = get_parallel_mode_from_env(OUTPUT_GROUP_3D)
j = A_rank = global_context.get_local_rank(input_parallel_mode) j = global_context.get_local_rank(input_parallel_mode)
i = B_rank = global_context.get_local_rank(weight_parallel_mode) i = global_context.get_local_rank(weight_parallel_mode)
k = C_rank = global_context.get_local_rank(output_parallel_mode) k = global_context.get_local_rank(output_parallel_mode)
layer = Classifier3D(INPUT_SIZE, NUM_CLASSES, dtype=dtype, bias=True) layer = Classifier3D(INPUT_SIZE, NUM_CLASSES, dtype=dtype, bias=True)
layer = layer.to(device) layer = layer.to(device)
@@ -229,14 +230,14 @@ def check_classifier():
torch.cuda.synchronize() torch.cuda.synchronize()
fwd_end = time.time() fwd_end = time.time()
print_rank_0( print_rank_0(
'head forward: pass | {0} --> {1} | {2:.3f} s'.format(tuple(A.shape), tuple(out.shape), fwd_end - fwd_start), 'classifier (no given weight) forward: pass | {0} --> {1} | {2:.3f} s'.format(
logger) tuple(A.shape), tuple(out.shape), fwd_end - fwd_start), logger)
A_master = A_master.clone() A_master = A_master.clone()
A_master.requires_grad = True A_master.requires_grad = True
C_master = layer_master(A_master) C_master = layer_master(A_master)
C = torch.chunk(C_master, DEPTH, dim=0)[i] C = torch.chunk(C_master, DEPTH, dim=0)[i]
C = torch.chunk(C, DEPTH, dim=0)[j] C = torch.chunk(C, DEPTH, dim=0)[j]
logger.info('Rank {} head forward: {}'.format(rank, check_equal(out, C))) logger.info('Rank {} classifier (no given weight) forward: {}'.format(rank, check_equal(out, C)))
grad_shape = C_master.shape grad_shape = C_master.shape
grad_master = torch.randn(grad_shape, dtype=dtype, device=get_current_device()) grad_master = torch.randn(grad_shape, dtype=dtype, device=get_current_device())
@@ -249,7 +250,7 @@ def check_classifier():
out.backward(grad) out.backward(grad)
torch.cuda.synchronize() torch.cuda.synchronize()
bwd_end = time.time() bwd_end = time.time()
print_rank_0('head backward: pass | {:.3f} s'.format(bwd_end - bwd_start), logger) print_rank_0('classifier (no given weight) backward: pass | {:.3f} s'.format(bwd_end - bwd_start), logger)
grad_master = grad_master.clone() grad_master = grad_master.clone()
C_master.backward(grad_master) C_master.backward(grad_master)
@@ -257,23 +258,275 @@ def check_classifier():
A_grad = torch.chunk(A_grad, DEPTH, dim=0)[i] A_grad = torch.chunk(A_grad, DEPTH, dim=0)[i]
A_grad = torch.chunk(A_grad, DEPTH, dim=-1)[k] A_grad = torch.chunk(A_grad, DEPTH, dim=-1)[k]
A_grad = torch.chunk(A_grad, DEPTH, dim=0)[j] A_grad = torch.chunk(A_grad, DEPTH, dim=0)[j]
logger.info('Rank {} head backward (input_grad): {}'.format(rank, check_equal(A_grad, A.grad))) logger.info('Rank {} classifier (no given weight) backward (input_grad): {}'.format(
rank, check_equal(A_grad, A.grad)))
B_grad = layer_master.weight.grad B_grad = layer_master.weight.grad
B_grad = torch.chunk(B_grad, DEPTH, dim=-1)[k] B_grad = torch.chunk(B_grad, DEPTH, dim=-1)[k]
if j == k: if j == k:
logger.info('Rank {} head backward (weight_grad): {}'.format(rank, logger.info('Rank {} classifier (no given weight) backward (weight_grad): {}'.format(
check_equal(B_grad, layer.weight.grad))) rank, check_equal(B_grad, layer.weight.grad)))
else: else:
logger.info('Rank {} head backward (weight_grad): {}'.format(rank, layer.weight.grad is None)) logger.info('Rank {} classifier (no given weight) backward (weight_grad): {}'.format(
rank, layer.weight.grad is None))
bias_grad = layer_master.bias.grad bias_grad = layer_master.bias.grad
logger.info('Rank {} head backward (bias_grad): {}'.format(rank, check_equal(bias_grad, layer.bias.grad))) logger.info('Rank {} classifier (no given weight) backward (bias_grad): {}'.format(
rank, check_equal(bias_grad, layer.bias.grad)))
return fwd_end - fwd_start, bwd_end - bwd_start return fwd_end - fwd_start, bwd_end - bwd_start
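A note on the pattern used throughout these 3D checks: the reference output from the single-process layer is reduced to the piece a given rank should hold by chaining torch.chunk over the mesh axes, using the local ranks j (input group), i (weight group) and k (output group) obtained above. A small sketch of that arithmetic with DEPTH = 2, as in the accompanying common.py (the rank values are hypothetical):

import torch

DEPTH = 2
i, j = 1, 0                                  # hypothetical local ranks in the weight and input groups

C_master = torch.randn(8, 8, 8)              # (BATCH_SIZE, SEQ_LENGTH, NUM_CLASSES) reference output
C = torch.chunk(C_master, DEPTH, dim=0)[i]   # the batch axis is split once for the weight group ...
C = torch.chunk(C, DEPTH, dim=0)[j]          # ... and once more for the input group
assert C.shape == (2, 8, 8)                  # each rank compares its local `out` against this slice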
def check_embed(): def check_vocab_parallel_classifier_no_given_weight():
rank = torch.distributed.get_rank()
logger = get_dist_logger()
device = get_current_device()
dtype = torch.float32
INPUT_SIZE = HIDDEN_SIZE
input_parallel_mode = get_parallel_mode_from_env(INPUT_GROUP_3D)
weight_parallel_mode = get_parallel_mode_from_env(WEIGHT_GROUP_3D)
output_parallel_mode = get_parallel_mode_from_env(OUTPUT_GROUP_3D)
j = global_context.get_local_rank(input_parallel_mode)
i = global_context.get_local_rank(weight_parallel_mode)
k = global_context.get_local_rank(output_parallel_mode)
layer = VocabParallelClassifier3D(INPUT_SIZE, VOCAB_SIZE, bias=True)
layer = layer.to(dtype).to(device)
layer_master = VanillaClassifier(INPUT_SIZE, VOCAB_SIZE, bias=True)
layer_master = layer_master.to(dtype).to(device)
weight_master = layer_master.weight.data
torch.distributed.broadcast(weight_master, src=0)
weight = torch.chunk(weight_master, DEPTH, dim=0)[j]
weight = torch.chunk(weight, DEPTH, dim=-1)[k]
layer.weight.data.copy_(weight)
bias_master = layer_master.bias.data
torch.distributed.broadcast(bias_master, src=0)
bias = torch.chunk(bias_master, DEPTH)[j]
layer.bias.data.copy_(bias)
A_shape = (BATCH_SIZE, SEQ_LENGTH, INPUT_SIZE)
A_master = torch.randn(A_shape, dtype=dtype, device=device)
torch.distributed.broadcast(A_master, src=0)
A = torch.chunk(A_master, DEPTH, dim=0)[i]
A = torch.chunk(A, DEPTH, dim=-1)[k]
A = torch.chunk(A, DEPTH, dim=0)[j]
A = A.clone()
A.requires_grad = True
fwd_start = time.time()
out = layer(A)
torch.cuda.synchronize()
fwd_end = time.time()
print_rank_0(
'vocab parallel classifier (no given weight) forward: pass | {0} --> {1} | {2:.3f} s'.format(
tuple(A.shape), tuple(out.shape), fwd_end - fwd_start), logger)
A_master = A_master.clone()
A_master.requires_grad = True
C_master = layer_master(A_master)
C = torch.chunk(C_master, DEPTH, dim=0)[i]
C = torch.chunk(C, DEPTH, dim=-1)[j]
C = torch.chunk(C, DEPTH, dim=0)[k]
logger.info('Rank {} vocab parallel classifier (no given weight) forward: {}'.format(rank, check_equal(out, C)))
grad_shape = C_master.shape
grad_master = torch.randn(grad_shape, dtype=dtype, device=device)
torch.distributed.broadcast(grad_master, src=0)
grad = torch.chunk(grad_master, DEPTH, dim=0)[i]
grad = torch.chunk(grad, DEPTH, dim=-1)[j]
grad = torch.chunk(grad, DEPTH, dim=0)[k]
grad = grad.clone()
bwd_start = time.time()
out.backward(grad)
torch.cuda.synchronize()
bwd_end = time.time()
print_rank_0('vocab parallel classifier (no given weight) backward: pass | {:.3f} s'.format(bwd_end - bwd_start),
logger)
grad_master = grad_master.clone()
C_master.backward(grad_master)
A_grad = A_master.grad
A_grad = torch.chunk(A_grad, DEPTH, dim=0)[i]
A_grad = torch.chunk(A_grad, DEPTH, dim=-1)[k]
A_grad = torch.chunk(A_grad, DEPTH, dim=0)[j]
logger.info('Rank {} vocab parallel classifier (no given weight) backward (input_grad): {}'.format(
rank, check_equal(A_grad, A.grad)))
B_grad = layer_master.weight.grad
B_grad = torch.chunk(B_grad, DEPTH, dim=0)[j]
B_grad = torch.chunk(B_grad, DEPTH, dim=-1)[k]
logger.info('Rank {} vocab parallel classifier (no given weight) backward (weight_grad): {}'.format(
rank, check_equal(B_grad, layer.weight.grad)))
bias_grad = layer_master.bias.grad
bias_grad = torch.chunk(bias_grad, DEPTH)[j]
logger.info('Rank {} vocab parallel classifier (no given weight) backward (bias_grad): {}'.format(
rank, check_equal(bias_grad, layer.bias.grad)))
return fwd_end - fwd_start, bwd_end - bwd_start
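In the vocab-parallel variant just above, the reference logits are additionally chunked along the class dimension by j, because VocabParallelClassifier3D behaves like a column-parallel projection: each j-rank owns a row-chunk of the (VOCAB_SIZE, HIDDEN_SIZE) weight and therefore produces a slice of the logits. A sketch of that correspondence on a single axis, ignoring the other two mesh dimensions (plain PyTorch, not the 3D layer):

import torch

depth, hidden, vocab = 2, 8, 16
x = torch.randn(8, 8, hidden)                          # (batch, seq, hidden) activations
w = torch.randn(vocab, hidden)                         # full classifier weight

# each j-rank multiplies by its row-chunk of w and yields a class-dimension slice;
# concatenating the slices reproduces the full projection
partial = [x @ torch.chunk(w, depth, dim=0)[j].t() for j in range(depth)]
assert torch.allclose(torch.cat(partial, dim=-1), x @ w.t(), atol=1e-4)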
def check_classifier_given_embed_weight():
rank = torch.distributed.get_rank()
logger = get_dist_logger()
device = get_current_device()
dtype = torch.float32
input_parallel_mode = get_parallel_mode_from_env(INPUT_GROUP_3D)
weight_parallel_mode = get_parallel_mode_from_env(WEIGHT_GROUP_3D)
output_parallel_mode = get_parallel_mode_from_env(OUTPUT_GROUP_3D)
j = global_context.get_local_rank(input_parallel_mode)
i = global_context.get_local_rank(weight_parallel_mode)
k = global_context.get_local_rank(output_parallel_mode)
embed = Embedding3D(VOCAB_SIZE, HIDDEN_SIZE)
embed = embed.to(dtype).to(device)
embed_master = torch.nn.Embedding(VOCAB_SIZE, HIDDEN_SIZE)
embed_master = embed_master.to(dtype).to(device)
weight_master = embed_master.weight.data
torch.distributed.broadcast(weight_master, src=0)
weight = torch.chunk(weight_master, DEPTH, dim=-1)[k]
embed.weight.data.copy_(weight)
layer = Classifier3D(HIDDEN_SIZE, VOCAB_SIZE, weight=embed.weight, bias=False)
layer = layer.to(dtype).to(device)
layer_master = VanillaClassifier(HIDDEN_SIZE, VOCAB_SIZE, weight=embed_master.weight, bias=False)
layer_master = layer_master.to(dtype).to(device)
A_shape = (BATCH_SIZE, SEQ_LENGTH)
A_master = torch.randint(VOCAB_SIZE, A_shape, device=device)
torch.distributed.broadcast(A_master, src=0)
A = A_master.clone()
fwd_start = time.time()
out = layer(embed(A))
torch.cuda.synchronize()
fwd_end = time.time()
print_rank_0(
'classifier (given embed weight) forward: pass | {0} --> {1} | {2:.3f} s'.format(
tuple(A.shape), tuple(out.shape), fwd_end - fwd_start), logger)
A_master = A_master.clone()
C_master = layer_master(embed_master(A_master))
C = torch.chunk(C_master, DEPTH, dim=0)[i]
C = torch.chunk(C, DEPTH, dim=0)[j]
logger.info('Rank {} classifier (given embed weight) forward: {}'.format(rank, check_equal(out, C)))
grad_shape = C_master.shape
grad_master = torch.randn(grad_shape, dtype=dtype, device=get_current_device())
torch.distributed.broadcast(grad_master, src=0)
grad = torch.chunk(grad_master, DEPTH, dim=0)[i]
grad = torch.chunk(grad, DEPTH, dim=0)[j]
grad = grad.clone()
bwd_start = time.time()
out.backward(grad)
torch.cuda.synchronize()
bwd_end = time.time()
print_rank_0('classifier (given embed weight) backward: pass | {:.3f} s'.format(bwd_end - bwd_start), logger)
grad_master = grad_master.clone()
C_master.backward(grad_master)
B_grad = embed_master.weight.grad
B_grad = torch.chunk(B_grad, DEPTH, dim=-1)[k]
if j == k:
logger.info('Rank {} classifier (given embed weight) backward (weight_grad): {}'.format(
rank, check_equal(B_grad, embed.weight.grad)))
else:
logger.info('Rank {} classifier (given embed weight) backward (weight_grad): {}'.format(
rank, embed.weight.grad is None))
return fwd_end - fwd_start, bwd_end - bwd_start
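The "given embed weight" checks exercise weight tying: the classifier is constructed with weight=embed.weight, so it owns no parameter of its own and the gradient lands on the embedding table (which is why the test only inspects embed.weight.grad). The plain PyTorch pattern being mirrored is roughly this sketch (not the Classifier3D implementation):

import torch
import torch.nn.functional as F

embed = torch.nn.Embedding(16, 8)            # (VOCAB_SIZE, HIDDEN_SIZE)
tokens = torch.randint(16, (8, 8))           # (BATCH_SIZE, SEQ_LENGTH)

hidden = embed(tokens)                       # (8, 8, 8)
logits = F.linear(hidden, embed.weight)      # tied projection back onto the vocabulary, no bias
logits.sum().backward()
assert embed.weight.grad is not None         # the only parameter that accumulates a gradient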
def check_vocab_parallel_classifier_given_embed_weight():
rank = torch.distributed.get_rank()
logger = get_dist_logger()
device = get_current_device()
dtype = torch.float32
input_parallel_mode = get_parallel_mode_from_env(INPUT_GROUP_3D)
weight_parallel_mode = get_parallel_mode_from_env(WEIGHT_GROUP_3D)
output_parallel_mode = get_parallel_mode_from_env(OUTPUT_GROUP_3D)
j = global_context.get_local_rank(input_parallel_mode)
i = global_context.get_local_rank(weight_parallel_mode)
k = global_context.get_local_rank(output_parallel_mode)
embed = VocabParallelEmbedding3D(VOCAB_SIZE, HIDDEN_SIZE)
embed = embed.to(dtype).to(device)
embed_master = torch.nn.Embedding(VOCAB_SIZE, HIDDEN_SIZE)
embed_master = embed_master.to(dtype).to(device)
weight_master = embed_master.weight.data
torch.distributed.broadcast(weight_master, src=0)
weight = torch.chunk(weight_master, DEPTH, dim=0)[j]
weight = torch.chunk(weight, DEPTH, dim=-1)[k]
embed.weight.data.copy_(weight)
layer = VocabParallelClassifier3D(HIDDEN_SIZE, VOCAB_SIZE, weight=embed.weight, bias=False)
layer = layer.to(dtype).to(device)
layer_master = VanillaClassifier(HIDDEN_SIZE, VOCAB_SIZE, weight=embed_master.weight, bias=False)
layer_master = layer_master.to(dtype).to(device)
A_shape = (BATCH_SIZE, SEQ_LENGTH)
A_master = torch.randint(VOCAB_SIZE, A_shape, device=device)
torch.distributed.broadcast(A_master, src=0)
A = A_master.clone()
fwd_start = time.time()
out = layer(embed(A))
torch.cuda.synchronize()
fwd_end = time.time()
print_rank_0(
'vocab parallel classifier (given embed weight) forward: pass | {0} --> {1} | {2:.3f} s'.format(
tuple(A.shape), tuple(out.shape), fwd_end - fwd_start), logger)
A_master = A_master.clone()
C_master = layer_master(embed_master(A_master))
C = torch.chunk(C_master, DEPTH, dim=0)[i]
C = torch.chunk(C, DEPTH, dim=-1)[j]
C = torch.chunk(C, DEPTH, dim=0)[k]
logger.info('Rank {} vocab parallel classifier (given embed weight) forward: {}'.format(rank, check_equal(out, C)))
grad_shape = C_master.shape
grad_master = torch.randn(grad_shape, dtype=dtype, device=device)
torch.distributed.broadcast(grad_master, src=0)
grad = torch.chunk(grad_master, DEPTH, dim=0)[i]
grad = torch.chunk(grad, DEPTH, dim=-1)[j]
grad = torch.chunk(grad, DEPTH, dim=0)[k]
grad = grad.clone()
bwd_start = time.time()
out.backward(grad)
torch.cuda.synchronize()
bwd_end = time.time()
print_rank_0('vocab parallel classifier (given embed weight) backward: pass | {:.3f} s'.format(bwd_end - bwd_start),
logger)
grad_master = grad_master.clone()
C_master.backward(grad_master)
B_grad = embed_master.weight.grad
B_grad = torch.chunk(B_grad, DEPTH, dim=0)[j]
B_grad = torch.chunk(B_grad, DEPTH, dim=-1)[k]
logger.info('Rank {} vocab parallel embed backward (weight_grad): {}'.format(rank,
check_equal(B_grad,
embed.weight.grad)))
return fwd_end - fwd_start, bwd_end - bwd_start
def check_patch_embed():
rank = torch.distributed.get_rank() rank = torch.distributed.get_rank()
device = get_current_device() device = get_current_device()
logger = get_dist_logger() logger = get_dist_logger()
@@ -283,9 +536,9 @@ def check_embed():
weight_parallel_mode = get_parallel_mode_from_env(WEIGHT_GROUP_3D) weight_parallel_mode = get_parallel_mode_from_env(WEIGHT_GROUP_3D)
output_parallel_mode = get_parallel_mode_from_env(OUTPUT_GROUP_3D) output_parallel_mode = get_parallel_mode_from_env(OUTPUT_GROUP_3D)
j = A_rank = global_context.get_local_rank(input_parallel_mode) j = global_context.get_local_rank(input_parallel_mode)
i = B_rank = global_context.get_local_rank(weight_parallel_mode) i = global_context.get_local_rank(weight_parallel_mode)
k = C_rank = global_context.get_local_rank(output_parallel_mode) k = global_context.get_local_rank(output_parallel_mode)
layer = PatchEmbedding3D(IMG_SIZE, 4, 3, HIDDEN_SIZE, dtype=dtype) layer = PatchEmbedding3D(IMG_SIZE, 4, 3, HIDDEN_SIZE, dtype=dtype)
torch.nn.init.ones_(layer.cls_token) torch.nn.init.ones_(layer.cls_token)
@@ -310,18 +563,99 @@ def check_embed():
A_master = torch.randn(A_shape, dtype=dtype, device=device) A_master = torch.randn(A_shape, dtype=dtype, device=device)
torch.distributed.broadcast(A_master, src=0) torch.distributed.broadcast(A_master, src=0)
A = A_master.clone() A = A_master.clone()
A.requires_grad = True
fwd_start = time.time() fwd_start = time.time()
out = layer(A) out = layer(A)
torch.cuda.synchronize() torch.cuda.synchronize()
fwd_end = time.time() fwd_end = time.time()
print_rank_0( print_rank_0(
'embedding forward: pass | {0} --> {1} | {2:.3f} s'.format(tuple(A.shape), tuple(out.shape), 'patch embed forward: pass | {0} --> {1} | {2:.3f} s'.format(tuple(A.shape), tuple(out.shape),
fwd_end - fwd_start), logger) fwd_end - fwd_start), logger)
A_master = A_master.clone()
C_master = layer_master(A_master)
C = torch.chunk(C_master, DEPTH, dim=0)[i]
C = torch.chunk(C, DEPTH, dim=-1)[k]
C = torch.chunk(C, DEPTH, dim=0)[j]
logger.info('Rank {} patch embed forward: {}'.format(rank, check_equal(out, C)))
grad_shape = C_master.shape
grad_master = torch.randn(grad_shape, dtype=dtype, device=device)
torch.distributed.broadcast(grad_master, src=0)
grad = torch.chunk(grad_master, DEPTH, dim=0)[i]
grad = torch.chunk(grad, DEPTH, dim=-1)[k]
grad = torch.chunk(grad, DEPTH, dim=0)[j]
grad = grad.clone()
bwd_start = time.time()
out.backward(grad)
torch.cuda.synchronize()
bwd_end = time.time()
print_rank_0('patch embed backward: pass | {:.3f} s'.format(bwd_end - bwd_start), logger)
grad_master = grad_master.clone()
C_master.backward(grad_master)
cls_grad_master = layer_master.cls_token.grad
cls_grad = torch.chunk(cls_grad_master, DEPTH, dim=-1)[k]
logger.info('Rank {} patch embed backward (cls_grad): {}'.format(rank, check_equal(cls_grad, layer.cls_token.grad)))
pos_grad_master = layer_master.pos_embed.grad
pos_grad = torch.chunk(pos_grad_master, DEPTH, dim=-1)[k]
logger.info('Rank {} patch embed backward (pos_embed_grad): {}'.format(rank,
check_equal(pos_grad, layer.pos_embed.grad)))
B_grad = layer_master.weight.grad
B_grad = torch.chunk(B_grad, DEPTH, dim=0)[k]
logger.info('Rank {} patch embed backward (proj_weight_grad): {}'.format(rank,
check_equal(B_grad, layer.weight.grad)))
bias_grad = layer_master.bias.grad
bias_grad = torch.chunk(bias_grad, DEPTH)[k]
logger.info('Rank {} patch embed backward (proj_bias_grad): {}'.format(rank,
check_equal(bias_grad, layer.bias.grad)))
return fwd_end - fwd_start, bwd_end - bwd_start
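For the patch-embedding check the shapes follow the usual ViT layout that PatchEmbedding3D and its VanillaPatchEmbedding reference mirror: IMG_SIZE = 16 with patch size 4 gives (16 / 4) ** 2 = 16 patches, a cls token is prepended, and everything is projected to HIDDEN_SIZE. A quick shape-only sketch of that arithmetic (plain PyTorch, not the tested layers):

import torch

img = torch.randn(8, 3, 16, 16)                        # (BATCH_SIZE, channels, IMG_SIZE, IMG_SIZE)
proj = torch.nn.Conv2d(3, 8, kernel_size=4, stride=4)  # patch size 4, HIDDEN_SIZE = 8 output channels

patches = proj(img).flatten(2).transpose(1, 2)         # (8, 16, 8): one token per 4x4 patch
cls_token = torch.zeros(8, 1, 8)
tokens = torch.cat([cls_token, patches], dim=1)        # (8, 17, 8) before pos_embed is added
assert tokens.shape == (8, 17, 8)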
def check_embed():
rank = torch.distributed.get_rank()
device = get_current_device()
logger = get_dist_logger()
dtype = torch.float32
input_parallel_mode = get_parallel_mode_from_env(INPUT_GROUP_3D)
weight_parallel_mode = get_parallel_mode_from_env(WEIGHT_GROUP_3D)
output_parallel_mode = get_parallel_mode_from_env(OUTPUT_GROUP_3D)
j = global_context.get_local_rank(input_parallel_mode)
i = global_context.get_local_rank(weight_parallel_mode)
k = global_context.get_local_rank(output_parallel_mode)
layer = Embedding3D(VOCAB_SIZE, HIDDEN_SIZE)
layer = layer.to(dtype).to(device)
layer_master = torch.nn.Embedding(VOCAB_SIZE, HIDDEN_SIZE)
layer_master = layer_master.to(dtype).to(device)
weight_master = layer_master.weight.data
torch.distributed.broadcast(weight_master, src=0)
weight = torch.chunk(weight_master, DEPTH, dim=-1)[k]
layer.weight.data.copy_(weight)
A_shape = (BATCH_SIZE, SEQ_LENGTH)
A_master = torch.randint(VOCAB_SIZE, A_shape, device=device)
torch.distributed.broadcast(A_master, src=0)
A = A_master.clone()
fwd_start = time.time()
out = layer(A)
torch.cuda.synchronize()
fwd_end = time.time()
logger.info('embed forward: pass | {0} --> {1} | {2:.3f} s'.format(tuple(A.shape), tuple(out.shape),
fwd_end - fwd_start),
ranks=[0])
A_master = A_master.clone() A_master = A_master.clone()
A_master.requires_grad = True
C_master = layer_master(A_master) C_master = layer_master(A_master)
C = torch.chunk(C_master, DEPTH, dim=0)[i] C = torch.chunk(C_master, DEPTH, dim=0)[i]
C = torch.chunk(C, DEPTH, dim=-1)[k] C = torch.chunk(C, DEPTH, dim=-1)[k]
@@ -329,7 +663,7 @@ def check_embed():
logger.info('Rank {} embed forward: {}'.format(rank, check_equal(out, C))) logger.info('Rank {} embed forward: {}'.format(rank, check_equal(out, C)))
grad_shape = C_master.shape grad_shape = C_master.shape
grad_master = torch.randn(grad_shape, dtype=dtype, device=get_current_device()) grad_master = torch.randn(grad_shape, dtype=dtype, device=device)
torch.distributed.broadcast(grad_master, src=0) torch.distributed.broadcast(grad_master, src=0)
grad = torch.chunk(grad_master, DEPTH, dim=0)[i] grad = torch.chunk(grad_master, DEPTH, dim=0)[i]
grad = torch.chunk(grad, DEPTH, dim=-1)[k] grad = torch.chunk(grad, DEPTH, dim=-1)[k]
@@ -339,30 +673,88 @@ def check_embed():
out.backward(grad) out.backward(grad)
torch.cuda.synchronize() torch.cuda.synchronize()
bwd_end = time.time() bwd_end = time.time()
print_rank_0('embedding backward: pass | {:.3f} s'.format(bwd_end - bwd_start), logger) logger.info('embed backward: pass | {:.3f} s'.format(bwd_end - bwd_start), ranks=[0])
grad_master = grad_master.clone() grad_master = grad_master.clone()
C_master.backward(grad_master) C_master.backward(grad_master)
cls_grad_master = layer_master.cls_token.grad B_grad = layer_master.weight.grad
cls_grad = torch.chunk(cls_grad_master, DEPTH, dim=-1)[k] B_grad = torch.chunk(B_grad, DEPTH, dim=-1)[k]
logger.info('Rank {} embed backward (cls_grad): {}'.format(rank, check_equal(cls_grad, layer.cls_token.grad))) if j == k:
logger.info('Rank {} embed backward (weight_grad): {}'.format(rank, check_equal(B_grad, layer.weight.grad)))
else:
logger.info('Rank {} embed backward (weight_grad): {}'.format(rank, layer.weight.grad is None))
pos_grad_master = layer_master.pos_embed.grad return fwd_end - fwd_start, bwd_end - bwd_start
pos_grad = torch.chunk(pos_grad_master, DEPTH, dim=-1)[k]
logger.info('Rank {} embed backward (pos_embed_grad): {}'.format(rank, check_equal(pos_grad, layer.pos_embed.grad)))
def check_vocab_parallel_embed():
rank = torch.distributed.get_rank()
device = get_current_device()
logger = get_dist_logger()
dtype = torch.float32
input_parallel_mode = get_parallel_mode_from_env(INPUT_GROUP_3D)
weight_parallel_mode = get_parallel_mode_from_env(WEIGHT_GROUP_3D)
output_parallel_mode = get_parallel_mode_from_env(OUTPUT_GROUP_3D)
j = global_context.get_local_rank(input_parallel_mode)
i = global_context.get_local_rank(weight_parallel_mode)
k = global_context.get_local_rank(output_parallel_mode)
layer = VocabParallelEmbedding3D(VOCAB_SIZE, HIDDEN_SIZE)
layer = layer.to(dtype).to(device)
layer_master = torch.nn.Embedding(VOCAB_SIZE, HIDDEN_SIZE)
layer_master = layer_master.to(dtype).to(device)
weight_master = layer_master.weight.data
torch.distributed.broadcast(weight_master, src=0)
weight = torch.chunk(weight_master, DEPTH, dim=0)[j]
weight = torch.chunk(weight, DEPTH, dim=-1)[k]
layer.weight.data.copy_(weight)
A_shape = (BATCH_SIZE, SEQ_LENGTH)
A_master = torch.randint(VOCAB_SIZE, A_shape, device=device)
torch.distributed.broadcast(A_master, src=0)
A = A_master.clone()
fwd_start = time.time()
out = layer(A)
torch.cuda.synchronize()
fwd_end = time.time()
logger.info('vocab parallel embed forward: pass | {0} --> {1} | {2:.3f} s'.format(
tuple(A.shape), tuple(out.shape), fwd_end - fwd_start),
ranks=[0])
A_master = A_master.clone()
C_master = layer_master(A_master)
C = torch.chunk(C_master, DEPTH, dim=0)[i]
C = torch.chunk(C, DEPTH, dim=-1)[k]
C = torch.chunk(C, DEPTH, dim=0)[j]
logger.info('Rank {} vocab parallel embed forward: {}'.format(rank, check_equal(out, C)))
grad_shape = C_master.shape
grad_master = torch.randn(grad_shape, dtype=dtype, device=device)
torch.distributed.broadcast(grad_master, src=0)
grad = torch.chunk(grad_master, DEPTH, dim=0)[i]
grad = torch.chunk(grad, DEPTH, dim=-1)[k]
grad = torch.chunk(grad, DEPTH, dim=0)[j]
grad = grad.clone()
bwd_start = time.time()
out.backward(grad)
torch.cuda.synchronize()
bwd_end = time.time()
logger.info('vocab parallel embed backward: pass | {:.3f} s'.format(bwd_end - bwd_start), ranks=[0])
grad_master = grad_master.clone()
C_master.backward(grad_master)
B_grad = layer_master.weight.grad B_grad = layer_master.weight.grad
B_grad = torch.chunk(B_grad, DEPTH, dim=0)[k] B_grad = torch.chunk(B_grad, DEPTH, dim=0)[j]
if j == k: B_grad = torch.chunk(B_grad, DEPTH, dim=-1)[k]
logger.info('Rank {} embed backward (proj_weight_grad): {}'.format(rank, check_equal(B_grad, logger.info('Rank {} vocab parallel embed backward (weight_grad): {}'.format(rank,
check_equal(B_grad,
layer.weight.grad))) layer.weight.grad)))
else:
logger.info('Rank {} embed backward (proj_weight_grad): {}'.format(rank, layer.weight.grad is None))
bias_grad = layer_master.bias.grad
bias_grad = torch.chunk(bias_grad, DEPTH)[k]
logger.info('Rank {} embed backward (proj_bias_grad): {}'.format(rank, check_equal(bias_grad, layer.bias.grad)))
return fwd_end - fwd_start, bwd_end - bwd_start return fwd_end - fwd_start, bwd_end - bwd_start
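The chunking above shows that VocabParallelEmbedding3D splits the embedding table along the vocabulary rows (by j) in addition to the hidden dimension (by k). The standard recipe for a row-partitioned lookup, sketched here for a single split along the vocab axis only (an illustration of the technique, not the 3D implementation), masks token ids owned by other ranks, does a local lookup, zeroes the misses, and would recover the full result with an all-reduce:

import torch

vocab, hidden, world, rank = 16, 8, 2, 0               # hypothetical single-axis split
rows = vocab // world
start, end = rank * rows, (rank + 1) * rows

local_table = torch.randn(rows, hidden)                # this rank's rows of the full table
tokens = torch.randint(vocab, (8, 8))                  # (BATCH_SIZE, SEQ_LENGTH)

mask = (tokens < start) | (tokens >= end)              # ids that live on other ranks
local_ids = (tokens - start).clamp(0, rows - 1)
out = local_table[local_ids]
out[mask] = 0.0                                        # all_reduce(out) across ranks restores the full lookup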
@@ -375,11 +767,9 @@ def check_loss():
input_parallel_mode = get_parallel_mode_from_env(INPUT_GROUP_3D) input_parallel_mode = get_parallel_mode_from_env(INPUT_GROUP_3D)
weight_parallel_mode = get_parallel_mode_from_env(WEIGHT_GROUP_3D) weight_parallel_mode = get_parallel_mode_from_env(WEIGHT_GROUP_3D)
output_parallel_mode = get_parallel_mode_from_env(OUTPUT_GROUP_3D)
j = A_rank = global_context.get_local_rank(input_parallel_mode) j = global_context.get_local_rank(input_parallel_mode)
i = B_rank = global_context.get_local_rank(weight_parallel_mode) i = global_context.get_local_rank(weight_parallel_mode)
k = C_rank = global_context.get_local_rank(output_parallel_mode)
criterion = CrossEntropyLoss3D() criterion = CrossEntropyLoss3D()
criterion_master = torch.nn.CrossEntropyLoss() criterion_master = torch.nn.CrossEntropyLoss()
@@ -397,24 +787,79 @@ def check_loss():
fwd_start = time.time() fwd_start = time.time()
loss = criterion(out, target_master) loss = criterion(out, target_master)
fwd_end = time.time() fwd_end = time.time()
print_rank_0( logger.info('cross entropy loss forward: pass | {0} --> {1} | {2:.3f} s'.format(tuple(out.shape), tuple(loss.shape),
'loss forward: pass | {0} --> {1} | {2:.3f} s'.format(tuple(out.shape), tuple(loss.shape), fwd_end - fwd_start), fwd_end - fwd_start),
logger) ranks=[0])
out_master = out_master.clone() out_master = out_master.clone()
out_master.requires_grad = True out_master.requires_grad = True
loss_master = criterion_master(out_master, target_master) loss_master = criterion_master(out_master, target_master)
logger.info('Rank {} CrossEntropyLoss forward: {}'.format(rank, check_equal(loss, loss_master))) logger.info('Rank {} cross entropy loss forward: {}'.format(rank, check_equal(loss, loss_master)))
bwd_start = time.time() bwd_start = time.time()
loss.backward() loss.backward()
bwd_end = time.time() bwd_end = time.time()
print_rank_0('loss backward: pass | {:.3f} s'.format(bwd_end - bwd_start), logger) logger.info('cross entropy loss backward: pass | {:.3f} s'.format(bwd_end - bwd_start), ranks=[0])
loss_master.backward() loss_master.backward()
out_grad = out_master.grad out_grad = out_master.grad
out_grad = torch.chunk(out_grad, DEPTH, dim=0)[i] out_grad = torch.chunk(out_grad, DEPTH, dim=0)[i]
out_grad = torch.chunk(out_grad, DEPTH, dim=0)[j] out_grad = torch.chunk(out_grad, DEPTH, dim=0)[j]
logger.info('Rank {} CrossEntropyLoss backward: {}'.format(rank, check_equal(out_grad, out.grad))) logger.info('Rank {} cross entropy loss backward: {}'.format(rank, check_equal(out_grad, out.grad)))
return fwd_end - fwd_start, bwd_end - bwd_start
def check_vocab_parallel_loss():
rank = torch.distributed.get_rank()
logger = get_dist_logger()
device = get_current_device()
dtype = torch.float32
input_parallel_mode = get_parallel_mode_from_env(INPUT_GROUP_3D)
weight_parallel_mode = get_parallel_mode_from_env(WEIGHT_GROUP_3D)
output_parallel_mode = get_parallel_mode_from_env(OUTPUT_GROUP_3D)
j = global_context.get_local_rank(input_parallel_mode)
i = global_context.get_local_rank(weight_parallel_mode)
k = global_context.get_local_rank(output_parallel_mode)
criterion = VocabParallelCrossEntropyLoss3D()
criterion_master = torch.nn.CrossEntropyLoss()
out_shape = (BATCH_SIZE, NUM_CLASSES)
out_master = torch.randn(out_shape, dtype=dtype, device=device)
target_master = torch.randint(NUM_CLASSES, (BATCH_SIZE, ), dtype=torch.long, device=device)
torch.distributed.broadcast(out_master, src=0)
torch.distributed.broadcast(target_master, src=0)
out = torch.chunk(out_master, DEPTH, dim=0)[i]
out = torch.chunk(out, DEPTH, dim=-1)[k]
out = torch.chunk(out, DEPTH, dim=0)[j]
out = out.clone()
out.requires_grad = True
fwd_start = time.time()
loss = criterion(out, target_master)
fwd_end = time.time()
logger.info('vocab parallel cross entropy loss forward: pass | {0} --> {1} | {2:.3f} s'.format(
tuple(out.shape), tuple(loss.shape), fwd_end - fwd_start),
ranks=[0])
out_master = out_master.clone()
out_master.requires_grad = True
loss_master = criterion_master(out_master, target_master)
logger.info('Rank {} vocab parallel cross entropy loss forward: {}'.format(rank, check_equal(loss, loss_master)))
bwd_start = time.time()
loss.backward()
bwd_end = time.time()
logger.info('vocab parallel cross entropy loss backward: pass | {:.3f} s'.format(bwd_end - bwd_start), ranks=[0])
loss_master.backward()
out_grad = out_master.grad
out_grad = torch.chunk(out_grad, DEPTH, dim=0)[i]
out_grad = torch.chunk(out_grad, DEPTH, dim=-1)[k]
out_grad = torch.chunk(out_grad, DEPTH, dim=0)[j]
logger.info('Rank {} vocab parallel cross entropy loss backward: {}'.format(rank, check_equal(out_grad, out.grad)))
return fwd_end - fwd_start, bwd_end - bwd_start return fwd_end - fwd_start, bwd_end - bwd_start
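VocabParallelCrossEntropyLoss3D has to match torch.nn.CrossEntropyLoss even though each rank only sees a slice of the logits along the class dimension, which is why the reference gradient above is chunked along dim -1 by k. The usual decomposition behind such a loss, written here without any communication for a single process (the quantities that would be all-reduced are noted in comments; this is a sketch of the technique, not the 3D kernel):

import torch
import torch.nn.functional as F

logits = torch.randn(8, 8)                             # (BATCH_SIZE, NUM_CLASSES), full reference logits
target = torch.randint(8, (8,))

m = logits.max(dim=-1, keepdim=True).values            # would be all-reduced with MAX over class shards
z = (logits - m).exp().sum(dim=-1)                     # would be all-reduced with SUM over class shards
target_logit = logits.gather(-1, target.unsqueeze(-1)).squeeze(-1)   # only the owning shard contributes
loss = (z.log() + m.squeeze(-1) - target_logit).mean()

assert torch.allclose(loss, F.cross_entropy(logits, target), atol=1e-5)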


@@ -10,6 +10,7 @@ HIDDEN_SIZE = 8
NUM_CLASSES = 8 NUM_CLASSES = 8
NUM_BLOCKS = 2 NUM_BLOCKS = 2
IMG_SIZE = 16 IMG_SIZE = 16
VOCAB_SIZE = 16
def check_equal(A, B): def check_equal(A, B):
eq = torch.allclose(A, B, rtol=1e-3, atol=1e-2) eq = torch.allclose(A, B, rtol=1e-3, atol=1e-2)


@@ -7,9 +7,14 @@ import torch
import torch.multiprocessing as mp import torch.multiprocessing as mp
from colossalai.core import global_context as gpc from colossalai.core import global_context as gpc
from colossalai.initialize import launch from colossalai.initialize import launch
from colossalai.logging import disable_existing_loggers
from colossalai.utils import free_port from colossalai.utils import free_port
from checks_3d.check_layer_3d import * from checks_3d.check_layer_3d import (check_classifier_given_embed_weight, check_classifier_no_given_weight,
check_embed, check_layernorm, check_linear, check_loss, check_patch_embed,
check_vocab_parallel_classifier_given_embed_weight,
check_vocab_parallel_classifier_no_given_weight, check_vocab_parallel_embed,
check_vocab_parallel_loss)
CONFIG = dict( CONFIG = dict(
parallel=dict( parallel=dict(
@@ -23,13 +28,23 @@ CONFIG = dict(
def check_layer(): def check_layer():
check_linear() check_linear()
check_layernorm() check_layernorm()
check_classifier() check_classifier_no_given_weight()
# check_embed() check_vocab_parallel_classifier_no_given_weight()
# check_loss() check_classifier_given_embed_weight()
check_vocab_parallel_classifier_given_embed_weight()
check_embed()
check_patch_embed()
check_vocab_parallel_embed()
check_loss()
check_vocab_parallel_loss()
def check_layer_and_operation(rank, world_size, port): def check_layer_and_operation(rank, world_size, port):
disable_existing_loggers()
launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
torch.backends.cuda.matmul.allow_tf32 = False
torch.backends.cudnn.allow_tf32 = False
torch.backends.cudnn.deterministic = True
check_layer() check_layer()
gpc.destroy() gpc.destroy()
torch.cuda.empty_cache() torch.cuda.empty_cache()
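A depth-2 3D mesh needs world_size = 2 ** 3 = 8 processes. The pytest entry point that spawns these workers sits outside this hunk, so the driver below is only a sketch of the usual pattern (the function name run_3d_checks is made up here):

from functools import partial

def run_3d_checks():
    world_size = 8                                     # DEPTH ** 3 for the assumed 2 x 2 x 2 mesh
    runner = partial(check_layer_and_operation, world_size=world_size, port=free_port())
    mp.spawn(runner, nprocs=world_size)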