Merge branch 'main' into feature/shardformer

pull/4612/head
Hongxin Liu 2023-09-05 21:54:08 +08:00 committed by GitHub
commit fae6c92ead
113 changed files with 629 additions and 633 deletions

View File

@@ -1,5 +1,5 @@
 class Registry:
-    # TODO: refactor the registry classes used in colossalai.registry, colossalai.fx and here
+    # TODO: refactor the registry classes used in colossalai.legacy.registry, colossalai.fx and here

     def __init__(self, name):
         self.name = name

View File

@@ -3,6 +3,7 @@ import os
 import warnings
 from functools import partial
 from pathlib import Path
+from types import MethodType
 from typing import Callable, Iterator, List, Optional, Tuple, Union

 import torch
@@ -25,9 +26,9 @@ from colossalai.checkpoint_io.utils import (
     sharded_optimizer_loading_epilogue,
     unwrap_optimizer,
 )
-from colossalai.interface import ModelWrapper, OptimizerWrapper
+from colossalai.interface import AMPModelMixin, ModelWrapper, OptimizerWrapper
 from colossalai.utils import get_current_device
-from colossalai.zero import LowLevelZeroOptimizer, zero_model_wrapper, zero_optim_wrapper
+from colossalai.zero import LowLevelZeroOptimizer

 from .dp_plugin_base import DPPluginBase
 from .torch_ddp_plugin import TorchDDPCheckpointIO
@@ -44,6 +45,34 @@ def _convert_floating_point(x, dtype: torch.dtype = torch.float16):

 SUPPORTED_PRECISION = ['fp16', 'bf16', 'fp32']


+class LowLevelZeroModel(ModelWrapper, AMPModelMixin):
+
+    def __init__(self, module: nn.Module, precision: str) -> None:
+        super().__init__(module)
+        self.dtype = None
+        if precision == 'fp16':
+            self.dtype = torch.float16
+        elif precision == 'bf16':
+            self.dtype = torch.bfloat16
+        if self.dtype is not None:
+            module = module.to(self.dtype)
+        module = module.to(get_current_device())
+        self.module = module
+        self.convert_fn = None
+        if self.dtype is not None:
+            self.convert_fn = partial(_convert_floating_point, dtype=self.dtype)
+
+    def forward(self, *args, **kwargs):
+        if self.convert_fn is not None:
+            args = tree_map(self.convert_fn, args)
+            kwargs = tree_map(self.convert_fn, kwargs)
+        return super().forward(*args, **kwargs)
+
+    def unwrap(self):
+        # TODO(ver217): this is a workaround for loading model
+        return self
+
+
 class LowLevelZeroCheckpointIO(TorchDDPCheckpointIO):

     def save_unsharded_optimizer(self, optimizer: OptimizerWrapper, checkpoint: str, gather_dtensor: bool = False):
@@ -165,30 +194,36 @@ class LowLevelZeroCheckpointIO(TorchDDPCheckpointIO):
         sharded_optimizer_loading_epilogue(optimizer)

+    def save_unsharded_model(self, model: LowLevelZeroModel, checkpoint: str, gather_dtensor: bool,
+                             use_safetensors: bool):
+        assert isinstance(model, LowLevelZeroModel)
+        super().save_unsharded_model(model.module, checkpoint, gather_dtensor, use_safetensors)

-class LowLevelZeroModel(ModelWrapper):
+    def save_sharded_model(self,
+                           model: nn.Module,
+                           checkpoint_path: str,
+                           gather_dtensor: bool = True,
+                           prefix: Optional[str] = None,
+                           max_shard_size: int = 1024,
+                           use_safetensors: bool = False):
+        assert isinstance(model, LowLevelZeroModel)
+        super().save_sharded_model(model.module, checkpoint_path, gather_dtensor, prefix, max_shard_size,
+                                   use_safetensors)

-    def __init__(self, module: nn.Module, stage: int, precision: str) -> None:
-        super().__init__(module)
-        self.dtype = None
-        if precision == 'fp16':
-            self.dtype = torch.float16
-        elif precision == 'bf16':
-            self.dtype = torch.bfloat16
-        module = zero_model_wrapper(module, zero_stage=stage)
-        if self.dtype is not None:
-            module = module.to(self.dtype)
-        module = module.to(get_current_device())
-        self.module = module
-        self.convert_fn = None
-        if self.dtype is not None:
-            self.convert_fn = partial(_convert_floating_point, dtype=self.dtype)
+    def load_unsharded_model(self, model: LowLevelZeroModel, checkpoint: str, strict: bool = True):
+        assert isinstance(model, LowLevelZeroModel)
+        super().load_unsharded_model(model.module, checkpoint, strict)
+        model.update_master_params()

-    def forward(self, *args, **kwargs):
-        if self.convert_fn is not None:
-            args = tree_map(self.convert_fn, args)
-            kwargs = tree_map(self.convert_fn, kwargs)
-        return super().forward(*args, **kwargs)
+    def load_sharded_model(self,
+                           model: LowLevelZeroModel,
+                           checkpoint_index_file: Path,
+                           strict: bool = False,
+                           use_safetensors: bool = False,
+                           load_sub_module: bool = True):
+        assert isinstance(model, LowLevelZeroModel)
+        super().load_sharded_model(model.module, checkpoint_index_file, strict, use_safetensors, load_sub_module)
+        model.update_master_params()


 class LowLevelZeroPlugin(DPPluginBase):
@@ -248,22 +283,24 @@ class LowLevelZeroPlugin(DPPluginBase):
         super().__init__()
         assert stage in (1, 2), f'LowLevelZeroPlugin only supports stage 1/2 training'
         assert precision in SUPPORTED_PRECISION, f'LowLevelZeroPlugin only supports {SUPPORTED_PRECISION} training'
+        assert norm_type == 2.0, f'LowLevelZeroPlugin only supports norm_type=2.0 now'

         self.stage = stage
         self.precision = precision
-        self.zero_optim_config = dict(reduce_bucket_size=reduce_bucket_size_in_m * 1024 * 1024,
-                                      communication_dtype=communication_dtype,
-                                      overlap_communication=overlap_communication,
-                                      cpu_offload=cpu_offload)
-        self.optim_kwargs = dict(initial_scale=initial_scale,
-                                 growth_factor=growth_factor,
-                                 backoff_factor=backoff_factor,
-                                 growth_interval=growth_interval,
-                                 hysteresis=hysteresis,
-                                 min_scale=min_scale,
-                                 max_scale=max_scale,
-                                 max_norm=max_norm,
-                                 norm_type=norm_type)
+        self.zero_optim_kwargs = dict(
+            initial_scale=initial_scale,
+            growth_factor=growth_factor,
+            backoff_factor=backoff_factor,
+            growth_interval=growth_interval,
+            hysteresis=hysteresis,
+            min_scale=min_scale,
+            max_scale=max_scale,
+            clip_grad_norm=max_norm,
+            reduce_bucket_size=reduce_bucket_size_in_m * 1024 * 1024,
+            communication_dtype=communication_dtype,
+            overlap_communication=overlap_communication,
+            cpu_offload=cpu_offload,
+            partition_grad=(stage == 2),
+        )
         self.verbose = verbose

         # set class name with stage, for better error message
@@ -294,15 +331,15 @@ class LowLevelZeroPlugin(DPPluginBase):
     ) -> Tuple[nn.Module, OptimizerWrapper, Callable, DataLoader, LRScheduler]:

         if not isinstance(model, ModelWrapper):
-            model = LowLevelZeroModel(model, self.stage, self.precision)
+            model = LowLevelZeroModel(model, self.precision)

         if optimizer is not None and \
                 not isinstance(optimizer, OptimizerWrapper):
-            optimizer = zero_optim_wrapper(model.unwrap(),
-                                           optimizer,
-                                           optim_config=self.zero_optim_config,
-                                           **self.optim_kwargs,
-                                           verbose=self.verbose)
+            optimizer: LowLevelZeroOptimizer = LowLevelZeroOptimizer(optimizer,
+                                                                     **self.zero_optim_kwargs,
+                                                                     verbose=self.verbose)
+            # inject update_master_params
+            model.update_master_params = MethodType(optimizer.update_master_params, model)

         return model, optimizer, criterion, dataloader, lr_scheduler
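
The net effect of this hunk, sketched below for orientation. The plugin now builds a LowLevelZeroOptimizer directly instead of going through zero_optim_wrapper, and binds the optimizer's update_master_params onto the wrapped model via MethodType so that checkpoint loading can resynchronize the fp32 master weights. The sketch is illustrative only: it assumes a distributed context has already been launched (e.g. via colossalai.launch_from_torch) and uses the Booster API with a made-up toy model and optimizer.

# Illustrative sketch, not part of the diff.
import torch
from colossalai.booster import Booster
from colossalai.booster.plugin import LowLevelZeroPlugin

model = torch.nn.Linear(8, 2)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)

booster = Booster(plugin=LowLevelZeroPlugin(stage=2, precision='fp16'))
model, optimizer, *_ = booster.boost(model, optimizer)

# After new weights are loaded into model.module, the injected method refreshes the
# optimizer's fp32 master copies so they match the (possibly fp16) model parameters.
model.update_master_params()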

View File

@@ -15,8 +15,8 @@ from colossalai.constants import ALLOWED_MODES, INITIALIZER_MAPPING
 from colossalai.context.config import Config
 from colossalai.context.singleton_meta import SingletonMeta
 from colossalai.global_variables import tensor_parallel_env as env
+from colossalai.legacy.registry import DIST_GROUP_INITIALIZER
 from colossalai.logging import get_dist_logger
-from colossalai.registry import DIST_GROUP_INITIALIZER

 from .parallel_mode import ParallelMode
 from .random import add_seed, get_seeds, set_mode

View File

@@ -2,8 +2,9 @@
 # -*- encoding: utf-8 -*-

 import torch.distributed as dist

 from colossalai.global_variables import tensor_parallel_env as env
-from colossalai.registry import DIST_GROUP_INITIALIZER
+from colossalai.legacy.registry import DIST_GROUP_INITIALIZER

 from ..parallel_mode import ParallelMode
 from .process_group_initializer import ProcessGroupInitializer

View File

@@ -3,7 +3,7 @@ import math

 import torch.distributed as dist

 from colossalai.global_variables import tensor_parallel_env as env
-from colossalai.registry import DIST_GROUP_INITIALIZER
+from colossalai.legacy.registry import DIST_GROUP_INITIALIZER

 from ..parallel_mode import ParallelMode
 from .process_group_initializer import ProcessGroupInitializer

View File

@@ -4,9 +4,10 @@
 import math

 import torch.distributed as dist

 from colossalai.context import Config
 from colossalai.global_variables import tensor_parallel_env as env
-from colossalai.registry import DIST_GROUP_INITIALIZER
+from colossalai.legacy.registry import DIST_GROUP_INITIALIZER

 from ..parallel_mode import ParallelMode
 from .process_group_initializer import ProcessGroupInitializer

View File

@@ -6,7 +6,7 @@ import math
 import torch.distributed as dist

 from colossalai.global_variables import tensor_parallel_env as env
-from colossalai.registry import DIST_GROUP_INITIALIZER
+from colossalai.legacy.registry import DIST_GROUP_INITIALIZER

 from ..parallel_mode import ParallelMode
 from .process_group_initializer import ProcessGroupInitializer

View File

@@ -3,7 +3,7 @@
 from torch import distributed as dist

-from colossalai.registry import DIST_GROUP_INITIALIZER
+from colossalai.legacy.registry import DIST_GROUP_INITIALIZER

 from ..parallel_mode import ParallelMode
 from .process_group_initializer import ProcessGroupInitializer

View File

@@ -2,9 +2,11 @@
 # -*- encoding: utf-8 -*-

 import torch.distributed as dist

-from colossalai.registry import DIST_GROUP_INITIALIZER
-from .process_group_initializer import ProcessGroupInitializer
+from colossalai.legacy.registry import DIST_GROUP_INITIALIZER
+
 from ..parallel_mode import ParallelMode
+from .process_group_initializer import ProcessGroupInitializer


 @DIST_GROUP_INITIALIZER.register_module

View File

@@ -3,7 +3,7 @@
 from torch import distributed as dist

-from colossalai.registry import DIST_GROUP_INITIALIZER
+from colossalai.legacy.registry import DIST_GROUP_INITIALIZER

 from ..parallel_mode import ParallelMode
 from .process_group_initializer import ProcessGroupInitializer

View File

@@ -2,7 +2,7 @@
 # -*- encoding: utf-8 -*-

 import torch.distributed as dist

-from colossalai.registry import DIST_GROUP_INITIALIZER
+from colossalai.legacy.registry import DIST_GROUP_INITIALIZER

 from ..parallel_mode import ParallelMode
 from .initializer_tensor import Initializer_Tensor

View File

@@ -3,9 +3,10 @@
 import torch.distributed as dist

-from colossalai.registry import DIST_GROUP_INITIALIZER
-from .process_group_initializer import ProcessGroupInitializer
+from colossalai.legacy.registry import DIST_GROUP_INITIALIZER
+
 from ..parallel_mode import ParallelMode
+from .process_group_initializer import ProcessGroupInitializer


 @DIST_GROUP_INITIALIZER.register_module

View File

@@ -17,13 +17,13 @@ from torch.utils.data import DataLoader

 from colossalai.amp import AMP_TYPE, convert_to_amp
 from colossalai.amp.naive_amp import NaiveAMPModel
-from colossalai.builder.builder import build_gradient_handler
 from colossalai.context import Config, ConfigException, ParallelMode
 from colossalai.context.moe_context import MOE_CONTEXT
 from colossalai.core import global_context as gpc
-from colossalai.engine import Engine
-from colossalai.engine.gradient_accumulation import accumulate_gradient
-from colossalai.engine.schedule import (
+from colossalai.legacy.builder.builder import build_gradient_handler
+from colossalai.legacy.engine import Engine
+from colossalai.legacy.engine.gradient_accumulation import accumulate_gradient
+from colossalai.legacy.engine.schedule import (
     InterleavedPipelineSchedule,
     NonPipelineSchedule,
     PipelineSchedule,

View File

@@ -1,4 +1,4 @@
-from .model import ModelWrapper
+from .model import AMPModelMixin, ModelWrapper
 from .optimizer import OptimizerWrapper

-__all__ = ['OptimizerWrapper', 'ModelWrapper']
+__all__ = ['OptimizerWrapper', 'ModelWrapper', 'AMPModelMixin']

View File

@@ -23,3 +23,14 @@ class ModelWrapper(nn.Module):

     def forward(self, *args, **kwargs):
         return self.module(*args, **kwargs)
+
+
+class AMPModelMixin:
+    """This mixin class defines the interface for AMP training.
+    """
+
+    def update_master_params(self):
+        """
+        Update the master parameters for AMP training.
+        """
+        pass
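
For context, a minimal hypothetical implementation of the new mixin contract. The class below is not part of the diff; it only illustrates what update_master_params is expected to do for a wrapper that keeps fp32 master copies of half-precision weights.

# Illustrative sketch, assuming colossalai is installed.
import torch.nn as nn
from colossalai.interface import AMPModelMixin, ModelWrapper


class MyAMPWrapper(ModelWrapper, AMPModelMixin):
    """Toy wrapper that keeps an fp32 master copy of each parameter."""

    def __init__(self, module: nn.Module):
        super().__init__(module)
        self.master_params = [p.detach().clone().float() for p in module.parameters()]

    def update_master_params(self):
        # Called e.g. after a checkpoint has been loaded into self.module, so the
        # master copies track the freshly loaded (possibly half-precision) weights.
        for master, param in zip(self.master_params, self.module.parameters()):
            master.data.copy_(param.data.float())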

View File

View File

@@ -3,7 +3,7 @@

 import inspect

-from colossalai.registry import *
+from colossalai.legacy.registry import *


 def build_from_config(module, config: dict):
@@ -71,7 +71,7 @@ def build_gradient_handler(config, model, optimizer):
         optimizer (:class:`torch.optim.Optimizer`): An optimizer object containing parameters for the gradient handler

     Returns:
-        An object of :class:`colossalai.engine.BaseGradientHandler`
+        An object of :class:`colossalai.legacy.engine.BaseGradientHandler`
     """
     config_ = config.copy()
     config_['model'] = model
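
The builder keeps working the same way after the move; only the import path now lives under colossalai.legacy. A hedged sketch of the registry pattern it relies on follows; the TinyHandler class and the config dict are made up for illustration.

# Illustrative sketch of register_module / get_module, not part of the diff.
from colossalai.legacy.registry import GRADIENT_HANDLER


@GRADIENT_HANDLER.register_module
class TinyHandler:
    """Made-up handler used only to show how registered classes are looked up."""

    def __init__(self, model, optimizer):
        self.model, self.optimizer = model, optimizer

    def handle_gradient(self):
        pass  # e.g. all-reduce gradients across data-parallel ranks


# build_gradient_handler(config, model, optimizer) boils down to roughly this:
config = dict(type='TinyHandler')
handler_cls = GRADIENT_HANDLER.get_module(config.pop('type'))
handler = handler_cls(model=None, optimizer=None, **config)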

View File

@@ -8,11 +8,17 @@ from torch import Tensor
 from torch.nn import Module
 from torch.nn.modules.loss import _Loss

-from colossalai.engine.gradient_handler import BaseGradientHandler
-from colossalai.engine.schedule import BaseSchedule, InterleavedPipelineSchedule, NonPipelineSchedule, PipelineSchedule
+from colossalai.legacy.engine.gradient_handler import BaseGradientHandler
+from colossalai.legacy.engine.schedule import (
+    BaseSchedule,
+    InterleavedPipelineSchedule,
+    NonPipelineSchedule,
+    PipelineSchedule,
+)
 from colossalai.logging import get_dist_logger
-from colossalai.zero.legacy.gemini import BaseOpHook, register_ophooks_recursively
 from colossalai.nn.optimizer import ColossalaiOptimizer
+from colossalai.zero.legacy.gemini import BaseOpHook, register_ophooks_recursively


 class Engine:
     """Basic engine class for training and evaluation. It runs a specific process method

View File

@@ -4,7 +4,7 @@ import torch.nn as nn
 from torch.optim import Optimizer
 from torch.optim.lr_scheduler import _LRScheduler

-from colossalai.engine import BaseGradientHandler
+from colossalai.legacy.engine import BaseGradientHandler

 from ._gradient_accumulation import (
     GradAccumDataloader,
@@ -33,7 +33,7 @@ def accumulate_gradient(model: nn.Module,
         dataloader (:class:`torch.utils.data.DataLoader` or iterable objects):
             your dataloader object, would be called like iter(dataloader)
         accumulate_size (int): the number of steps to accumulate gradients
-        gradient_handlers (List[:class:`colossalai.engine.BaseGradientHandler`]):
+        gradient_handlers (List[:class:`colossalai.legacy.engine.BaseGradientHandler`]):
             list of gradient handler objects. Default is None.
         lr_scheduler (`torch.optim.lr_scheduler` or `colossalai.nn.lr_scheduler`):
             your ``lr_scheduler`` object for gradient accumulation. Defaults to None.

View File

@@ -10,7 +10,7 @@ from torch.optim import Optimizer
 from torch.optim.lr_scheduler import _LRScheduler
 from torch.utils.data import DataLoader

-from colossalai.engine import BaseGradientHandler
+from colossalai.legacy.engine import BaseGradientHandler
 from colossalai.nn.optimizer import ColossalaiOptimizer
 from colossalai.utils import conditional_context

@@ -262,7 +262,7 @@ class GradAccumGradientHandler:
         before accumulation size is reached.

     Args:
-        grad_handler (:class:`colossalai.engine.BaseGradientHandler`):
+        grad_handler (:class:`colossalai.legacy.engine.BaseGradientHandler`):
            Your ``gradient_handler`` object for gradient accumulation, would be called when achieving `accumulate_size`.
         accumulate_size (int): The number of steps to accumulate gradients.

View File

@@ -1,7 +1,7 @@
-from colossalai.context.parallel_mode import ParallelMode
 from colossalai.core import global_context as gpc
-from colossalai.registry import GRADIENT_HANDLER
+from colossalai.legacy.registry import GRADIENT_HANDLER

+from ...context.parallel_mode import ParallelMode
 from ._base_gradient_handler import BaseGradientHandler
 from .utils import bucket_allreduce

View File

@@ -1,9 +1,9 @@
 from colossalai.context.moe_context import MOE_CONTEXT
-from colossalai.context.parallel_mode import ParallelMode
 from colossalai.core import global_context as gpc
-from colossalai.registry import GRADIENT_HANDLER
+from colossalai.legacy.registry import GRADIENT_HANDLER
 from colossalai.utils.moe import get_moe_epsize_param_dict

+from ...context.parallel_mode import ParallelMode
 from ._base_gradient_handler import BaseGradientHandler
 from .utils import bucket_allreduce

View File

@@ -7,7 +7,7 @@ import torch.distributed as dist
 from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors

 from colossalai.core import global_context as gpc
-from colossalai.registry import GRADIENT_HANDLER
+from colossalai.legacy.registry import GRADIENT_HANDLER

 from ._base_gradient_handler import BaseGradientHandler

View File

@@ -1,7 +1,7 @@
-from colossalai.context.parallel_mode import ParallelMode
 from colossalai.core import global_context as gpc
-from colossalai.registry import GRADIENT_HANDLER
+from colossalai.legacy.registry import GRADIENT_HANDLER

+from ...context.parallel_mode import ParallelMode
 from ._base_gradient_handler import BaseGradientHandler
 from .utils import bucket_allreduce

View File

@@ -1,4 +1,4 @@
-from colossalai.registry import GRADIENT_HANDLER
+from colossalai.legacy.registry import GRADIENT_HANDLER

 from ._base_gradient_handler import BaseGradientHandler

View File

@@ -95,7 +95,7 @@ class BaseSchedule(ABC):
         """The process function over a batch of dataset for training or evaluation.

         Args:
-            engine (colossalai.engine.Engine): Colossalai engine for training and inference.
+            engine (colossalai.legacy.engine.Engine): Colossalai engine for training and inference.
             data_iter (Iterable): Data iterator from which get a batch of data, obtained by calling iter(dataloader).
             forward_only (bool): If True, the process won't include backward.
             return_loss (bool, optional): If False, the loss won't be returned.

View File

@@ -54,7 +54,7 @@ class NonPipelineSchedule(BaseSchedule):
             The returned labels and loss will None if :attr:`return_loss` is False.

         Args:
-            engine (colossalai.engine.Engine): Colossalai engine for training and inference.
+            engine (colossalai.legacy.engine.Engine): Colossalai engine for training and inference.
             data_iter (Iterable): Dataloader as the form of an iterator, obtained by calling iter(dataloader).
             forward_only (bool, optional):
                 If True, the model is run for the forward pass, else back propagation will be executed.

View File

@@ -236,7 +236,7 @@ class PipelineSchedule(BaseSchedule):
         Returns output tensor. This is a helper function and can be ignored by users.

         Args:
-            engine (colossalai.engine.Engine): Colossalai engine for training and inference.
+            engine (colossalai.legacy.engine.Engine): Colossalai engine for training and inference.
             input_obj (Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]): Input tensor for this pipeline stage.
             return_tensors (List[:class:`torch.Tensor`]): A list of tensors to return.
             return_output_label (bool, optional): Whether returns output labels.
@@ -274,7 +274,7 @@ class PipelineSchedule(BaseSchedule):
         This is a helper function and can be ignored by users.

         Args:
-            engine (colossalai.engine.Engine): Colossalai engine for training and inference.
+            engine (colossalai.legacy.engine.Engine): Colossalai engine for training and inference.
             input_obj (Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]): input tensor for this pipeline stage.
             output_obj (Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]): output tensor for this pipeline stage.
             output_obj_grad (Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]): gradient of output tensor for this pipeline stage.
@@ -314,7 +314,7 @@ class PipelineSchedule(BaseSchedule):
         Returns a tuple with losses if the last stage, an empty tuple otherwise.

         Args:
-            engine (colossalai.engine.Engine): Colossalai engine for training and inference.
+            engine (colossalai.legacy.engine.Engine): Colossalai engine for training and inference.
             data_iter (Iterable): Dataloader as the form of an iterator, obtained by calling iter(dataloader).
             forward_only (bool, optional):
                 Whether run forward step only. Default is false. If true, no backward will be run.
@@ -518,7 +518,7 @@ class InterleavedPipelineSchedule(PipelineSchedule):
         Returns output tensor. This is a helper function and can be ignored by users.

         Args:
-            engine (colossalai.engine.Engine): Colossalai engine for training and inference.
+            engine (colossalai.legacy.engine.Engine): Colossalai engine for training and inference.
             model_chunk_id (int): The id of model chunks.
             input_obj (Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]): Input tensor for this pipeline stage.
             return_tensors (List[:class:`torch.Tensor`]): A list of tensors to return.
@@ -555,7 +555,7 @@ class InterleavedPipelineSchedule(PipelineSchedule):
         communication between pipeline stages as needed.

         Args:
-            engine (colossalai.engine.Engine): Colossalai engine for training and inference.
+            engine (colossalai.legacy.engine.Engine): Colossalai engine for training and inference.
             data_iter (Iterable): Dataloader as the form of an iterator, obtained by calling iter(dataloader).
             forward_only (bool, optional):
                 Whether run forward step only. Default is false. If true, no backward will be run.

View File

@@ -69,7 +69,7 @@ class PipelineScheduleV2(PipelineSchedule):
         Returns a tuple with losses if the last stage, an empty tuple otherwise.

         Args:
-            engine (colossalai.engine.Engine): Colossalai engine for training and inference.
+            engine (colossalai.legacy.engine.Engine): Colossalai engine for training and inference.
             data_iter (Iterable): Dataloader as the form of an iterator, obtained by calling iter(dataloader).
             forward_only (bool, optional):
                 Whether run forward step only. Default is false. If true, no backward will be run.
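
All of the docstring edits above are the same mechanical change: the engine these schedules drive now lives at colossalai.legacy.engine. As a reminder of how that object is typically obtained and used, here is a hedged sketch of the legacy initialize-based flow; the toy model, dataset, and omitted launch/config setup are assumptions, not part of the diff.

# Illustrative sketch; assumes colossalai.launch(...) has already set up the global context.
import torch
import torch.nn as nn
import colossalai

model = nn.Linear(8, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
criterion = nn.CrossEntropyLoss()
train_dataloader = torch.utils.data.DataLoader(
    torch.utils.data.TensorDataset(torch.randn(32, 8), torch.randint(0, 2, (32,))), batch_size=8)

engine, train_dataloader, _, _ = colossalai.initialize(model=model,
                                                       optimizer=optimizer,
                                                       criterion=criterion,
                                                       train_dataloader=train_dataloader)
engine.train()
for data, label in train_dataloader:
    engine.zero_grad()
    output = engine(data)                  # forward through the wrapped model
    loss = engine.criterion(output, label)
    engine.backward(loss)                  # schedule/AMP-aware backward
    engine.step()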

View File

@@ -6,7 +6,7 @@ from typing import List

 class Registry:
     """This is a registry class used to register classes and modules so that a universal
     object builder can be enabled.

     Args:
@@ -42,7 +42,7 @@ class Registry:
         return module_class

     def get_module(self, module_name: str):
         """Retrieves a module with name `module_name` and returns the module if it has
         already been registered before.

         Args:

View File

@@ -1,14 +1,13 @@
-from typing import Union, List, Any
+from typing import Any, List, Union

 import torch
 from torch.utils.data import DataLoader
 from tqdm import tqdm

-from colossalai.engine import Engine
+from colossalai.legacy.engine import Engine
+from colossalai.legacy.trainer.hooks import BaseHook
 from colossalai.logging import DistributedLogger
-from colossalai.utils import MultiTimer
-from colossalai.utils import is_dp_rank_0, is_tp_rank_0, is_no_pp_or_last_stage
-from colossalai.trainer.hooks import BaseHook
+from colossalai.utils import MultiTimer, is_dp_rank_0, is_no_pp_or_last_stage, is_tp_rank_0


 class Trainer:

View File

@@ -1,7 +1,12 @@
 from ._base_hook import BaseHook
 from ._checkpoint_hook import SaveCheckpointHook
-from ._log_hook import (LogMemoryByEpochHook, LogMetricByEpochHook, LogMetricByStepHook, LogTimingByEpochHook,
-                        TensorboardHook)
+from ._log_hook import (
+    LogMemoryByEpochHook,
+    LogMetricByEpochHook,
+    LogMetricByStepHook,
+    LogTimingByEpochHook,
+    TensorboardHook,
+)
 from ._lr_scheduler_hook import LRSchedulerHook
 from ._metric_hook import AccuracyHook, LossHook, MetricHook, ThroughputHook

View File

@@ -1,11 +1,12 @@
 #!/usr/bin/env python
 # -*- encoding: utf-8 -*-
 import torch

-from colossalai.logging import get_dist_logger
-from colossalai.registry import HOOKS
-from colossalai.trainer.hooks import BaseHook
+from colossalai.legacy.registry import HOOKS
+from colossalai.legacy.trainer.hooks import BaseHook
+from colossalai.logging import get_dist_logger
 from colossalai.utils.checkpointing import save_checkpoint

 from ._lr_scheduler_hook import LRSchedulerHook

View File

@@ -3,17 +3,17 @@
 import os
 import os.path as osp
 from typing import List

 from colossalai.context import ParallelMode
 from colossalai.core import global_context as gpc
-from colossalai.registry import HOOKS
+from colossalai.legacy.registry import HOOKS
+from colossalai.legacy.trainer.hooks._metric_hook import ThroughputMetric
 from colossalai.logging import DistributedLogger
-from colossalai.utils import report_memory_usage, is_dp_rank_0, \
-    is_tp_rank_0, is_no_pp_or_last_stage, MultiTimer
+from colossalai.utils import MultiTimer, is_dp_rank_0, is_no_pp_or_last_stage, is_tp_rank_0, report_memory_usage

 from ._base_hook import BaseHook
 from ._commons_ import _format_number
-from colossalai.trainer.hooks._metric_hook import ThroughputMetric


 class LogByEpochHook(BaseHook):

View File

@@ -1,6 +1,7 @@
-from colossalai.registry import HOOKS
 from torch import Tensor

+from colossalai.legacy.registry import HOOKS
+
 from ._metric_hook import LearningRateMetric, MetricHook

View File

@@ -6,10 +6,11 @@ from typing import Callable

 import torch
 import torch.distributed as dist

 from colossalai.communication import all_reduce
 from colossalai.context import ParallelMode
 from colossalai.core import global_context as gpc
-from colossalai.registry import HOOKS
+from colossalai.legacy.registry import HOOKS
 from colossalai.utils import get_current_device, is_no_pp_or_last_stage

 from ._base_hook import BaseHook
@@ -19,8 +20,8 @@ from ._commons_ import _format_number

 class Metric(ABC):
     """A basic class of metric collectors. It collects a specific
     metric during training or evaluation and would always be used with
     :class:`MetricHook` to help it update its states and show the
     metric. So please use corresponding hook class to make the metric
     collector works.

     Args:
@@ -220,9 +221,9 @@ class AccuracyMetric(Metric):

 class MetricHook(BaseHook):
     """Specialized hook classes for :class:`Metric`.
     Some help metric collectors initialize, reset and
     update their states. Others are used to display and
     record the metric.

     Args:
@@ -355,7 +356,7 @@ class ThroughputMetric(Metric):
             self.last_step_num_samples *= gpc.get_world_size(ParallelMode.DATA)
         else:
             self.last_step_used_time = all_reduce(self.last_step_used_time, ParallelMode.DATA) / \
                 gpc.get_world_size(ParallelMode.DATA)
             self.last_step_num_samples = all_reduce(self.last_step_num_samples, ParallelMode.DATA)

         sample_per_sec = _format_number(self.last_step_num_samples / (self.last_step_used_time + 1e-12).item())
@@ -366,7 +367,7 @@ class ThroughputMetric(Metric):
             self.last_step_num_samples *= gpc.get_world_size(ParallelMode.DATA)
         else:
             self.last_step_used_time = all_reduce(self.last_step_used_time, ParallelMode.DATA) / \
                 gpc.get_world_size(ParallelMode.DATA)
             self.last_step_num_samples = all_reduce(self.last_step_num_samples, ParallelMode.DATA)

         sample_per_sec = _format_number(self.last_step_num_samples / (self.last_step_used_time + 1e-12).item())

View File

@@ -15,8 +15,8 @@ from colossalai.context import ParallelMode, seed
 from colossalai.core import global_context as gpc
 from colossalai.global_variables import tensor_parallel_env as env
 from colossalai.kernel import LayerNorm
+from colossalai.legacy.registry import LAYERS
 from colossalai.nn import init as init
-from colossalai.registry import LAYERS
 from colossalai.utils.checkpointing import (
     broadcast_state_dict,
     gather_tensor_parallel_state_dict,

View File

@@ -5,21 +5,30 @@ from typing import Callable
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
+from torch import Tensor
+from torch.nn import Parameter
+
 from colossalai.communication import broadcast
 from colossalai.context import ParallelMode, seed
 from colossalai.core import global_context as gpc
 from colossalai.global_variables import tensor_parallel_env as env
+from colossalai.legacy.registry import LAYERS
 from colossalai.nn import init as init
-from colossalai.registry import LAYERS
 from colossalai.utils.checkpointing import gather_tensor_parallel_state_dict, partition_tensor_parallel_state_dict
 from colossalai.utils.cuda import get_current_device
-from torch import Tensor
-from torch.nn import Parameter

 from ..base_layer import ParallelLayer
 from ..utils import divide, set_tensor_parallel_attribute_by_partition, to_2tuple
-from ._operation import (Matmul_AB_2D, Matmul_ABT_2D, add_bias_2d, all_gather_tensor_2d, classifier_2d, layernorm_2d,
-                         reduce_scatter_tensor_2d, split_batch_2d)
+from ._operation import (
+    Matmul_AB_2D,
+    Matmul_ABT_2D,
+    add_bias_2d,
+    all_gather_tensor_2d,
+    classifier_2d,
+    layernorm_2d,
+    reduce_scatter_tensor_2d,
+    split_batch_2d,
+)
 from ._utils import assert_summa_initialization, get_summa_dim_from_env

View File

@@ -5,22 +5,34 @@ from typing import Callable
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
+from torch import Tensor
+from torch.nn import Parameter
+
 from colossalai.communication import broadcast
 from colossalai.context import ParallelMode, seed
 from colossalai.core import global_context as gpc
 from colossalai.global_variables import tensor_parallel_env as env
+from colossalai.legacy.registry import LAYERS
 from colossalai.nn import init as init
-from colossalai.registry import LAYERS
-from colossalai.utils.checkpointing import (broadcast_state_dict, gather_tensor_parallel_state_dict,
-                                            partition_tensor_parallel_state_dict)
+from colossalai.utils.checkpointing import (
+    broadcast_state_dict,
+    gather_tensor_parallel_state_dict,
+    partition_tensor_parallel_state_dict,
+)
 from colossalai.utils.cuda import get_current_device
-from torch import Tensor
-from torch.nn import Parameter

 from ..base_layer import ParallelLayer
 from ..utils import divide, set_tensor_parallel_attribute_by_partition, to_2tuple
-from ._operation import (Matmul_AB_2p5D, Matmul_ABT_2p5D, add_bias_2p5d, all_gather_tensor_2p5d, classifier_2p5d,
-                         layernorm_2p5d, reduce_scatter_tensor_2p5d, split_batch_2p5d)
+from ._operation import (
+    Matmul_AB_2p5D,
+    Matmul_ABT_2p5D,
+    add_bias_2p5d,
+    all_gather_tensor_2p5d,
+    classifier_2p5d,
+    layernorm_2p5d,
+    reduce_scatter_tensor_2p5d,
+    split_batch_2p5d,
+)
 from ._utils import assert_tesseract_initialization, get_tesseract_dim_dep_from_env

View File

@@ -13,9 +13,9 @@ from colossalai.constants import INPUT_GROUP_3D, INPUT_X_WEIGHT_3D, OUTPUT_GROUP
 from colossalai.context import ParallelMode, seed
 from colossalai.core import global_context as gpc
 from colossalai.global_variables import tensor_parallel_env as env
+from colossalai.legacy.registry import LAYERS
 from colossalai.nn import init as init
 from colossalai.nn.layer.base_layer import ParallelLayer
-from colossalai.registry import LAYERS
 from colossalai.utils.checkpointing import (
     broadcast_state_dict,
     gather_tensor_parallel_state_dict,

View File

@@ -2,20 +2,20 @@
 # -*- encoding: utf-8 -*-

 import math

-import colossalai
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
 from torch.nn import Parameter

+import colossalai
+from colossalai.context import seed
 from colossalai.context.parallel_mode import ParallelMode
 from colossalai.core import global_context as gpc
-from colossalai.nn.layer.parallel_sequence._operation import RingQK, RingAV
-from colossalai.registry import LAYERS
-from colossalai.kernel.cuda_native.scaled_softmax import AttnMaskType
 from colossalai.kernel import FusedScaleMaskSoftmax
-from colossalai.context import seed
+from colossalai.kernel.cuda_native.scaled_softmax import AttnMaskType
+from colossalai.legacy.registry import LAYERS
+from colossalai.nn.layer.parallel_sequence._operation import RingAV, RingQK


 @LAYERS.register_module

View File

@@ -8,8 +8,8 @@ from torch import nn as nn
 from torch.nn.parameter import Parameter

 from colossalai.context import seed
+from colossalai.legacy.registry import LAYERS
 from colossalai.nn import init as init
-from colossalai.registry import LAYERS
 from colossalai.utils.cuda import get_current_device

 from ..utils import to_2tuple

View File

@@ -1,105 +1,106 @@
 import torch
 import torch.distributed as dist
-from colossalai.context import ParallelMode
-from colossalai.core import global_context as gpc
-from colossalai.registry import LOSSES
 from torch.cuda.amp import custom_bwd, custom_fwd
 from torch.nn.modules.loss import _Loss

+from colossalai.context import ParallelMode
+from colossalai.core import global_context as gpc
+from colossalai.legacy.registry import LOSSES


 class _VocabParallelCrossEntropy1D(torch.autograd.Function):

     @staticmethod
     @custom_fwd(cast_inputs=torch.float32)
     def forward(ctx, vocab_parallel_logits, targets, process_group):
         if process_group is None:
             process_group = gpc.get_group(ParallelMode.PARALLEL_1D)

         # Maximum value along vocab dimension across all GPUs.
         logits_max = torch.max(vocab_parallel_logits, dim=-1)[0]
         torch.distributed.all_reduce(logits_max, op=torch.distributed.ReduceOp.MAX, group=process_group)
         # Subtract the maximum value.
         vocab_parallel_logits.sub_(logits_max.unsqueeze(dim=-1))

         # Get the partition's vocab indices
         partition_vocab_size = vocab_parallel_logits.size()[-1]
         rank = dist.get_rank(process_group)
         vocab_start_index = partition_vocab_size * rank
         vocab_end_index = vocab_start_index + partition_vocab_size

         # Create a mask of valid vocab ids (1 means it needs to be masked).
         target_mask = (targets < vocab_start_index) | (targets >= vocab_end_index)
         masked_target = targets.clone() - vocab_start_index
         masked_target[target_mask] = 0

         # Get predicted-logits = logits[target].
         # For Simplicity, we convert logits to a 2-D tensor with size
         # [*, partition-vocab-size] and target to a 1-D tensor of size [*].
         logits_2d = vocab_parallel_logits.view(-1, partition_vocab_size)
         masked_target_1d = masked_target.view(-1)
         arange_1d = torch.arange(start=0, end=logits_2d.size()[0], device=logits_2d.device)
         predicted_logits_1d = logits_2d[arange_1d, masked_target_1d]
         predicted_logits_1d = predicted_logits_1d.clone().contiguous()
         predicted_logits = predicted_logits_1d.view_as(targets)
         predicted_logits[target_mask] = 0.0
         # All reduce is needed to get the chunks from other GPUs.
         torch.distributed.all_reduce(predicted_logits, op=torch.distributed.ReduceOp.SUM, group=process_group)

         # Sum of exponential of logits along vocab dimension across all GPUs.
         exp_logits = torch.exp(vocab_parallel_logits)
         sum_exp_logits = exp_logits.sum(dim=-1)
         torch.distributed.all_reduce(sum_exp_logits, op=torch.distributed.ReduceOp.SUM, group=process_group)

         # Loss = log(sum(exp(logits))) - predicted-logit.
         loss = torch.log(sum_exp_logits) - predicted_logits
         # Store softmax, target-mask and masked-target for backward pass.
         exp_logits.div_(sum_exp_logits.unsqueeze(dim=-1))
         ctx.save_for_backward(exp_logits, target_mask, masked_target_1d)
         return loss

     @staticmethod
     @custom_bwd
     def backward(ctx, grad_output):

         # Retrieve tensors from the forward path.
         softmax, target_mask, masked_target_1d = ctx.saved_tensors

         # All the inputs have softmax as their gradient.
         grad_input = softmax
         # For simplicity, work with the 2D gradient.
         partition_vocab_size = softmax.size()[-1]
         grad_2d = grad_input.view(-1, partition_vocab_size)

         # Add the gradient from matching classes.
         arange_1d = torch.arange(start=0, end=grad_2d.size()[0], device=grad_2d.device)
         grad_2d[arange_1d, masked_target_1d] -= (1.0 - target_mask.view(-1).float())

         # Finally elementwise multiplication with the output gradients.
         grad_input.mul_(grad_output.unsqueeze(dim=-1))

         return grad_input, None, None


 @LOSSES.register_module
 class VocabParallelCrossEntropyLoss1D(_Loss):
     """Vocab parallel cross entropy loss for 1D parallelism.

     Args:
         reduction (bool, optional): whether to average the loss, defaults to True.
     """

     def __init__(self, reduction=True):
         super().__init__()
         self.reduction_mean = reduction

     def forward(self, logits, targets, process_group=None):
         """Calculate loss between logits and targets.

         Args:
             logits (:class:`torch.tensor`): Predicted unnormalized scores (often referred to as logits).
             targets (:class:`torch.tensor`): Ground truth class indices or class probabilities.
         """
         loss = _VocabParallelCrossEntropy1D.apply(logits, targets, process_group)
         if self.reduction_mean:
             loss = loss.mean()
         return loss
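
The change to this file is only the reshuffled imports; the kernel itself is untouched. For readers skimming the diff, what it computes per token, written out (notation mine, not from the source): with full-vocabulary logits z_1..z_V sharded across tensor-parallel ranks and target index t, one MAX all-reduce supplies the stabilizing constant and two SUM all-reduces assemble the target logit and the partition function.

% Numerically stabilized cross entropy assembled from vocab shards.
\mathcal{L} = \log \sum_{j=1}^{V} e^{\,z_j - m} - (z_t - m), \qquad m = \max_j z_j,
\qquad \frac{\partial \mathcal{L}}{\partial z_j} = \mathrm{softmax}(z)_j - \mathbf{1}[j = t].

The gradient form is exactly the backward pass above: start from the saved softmax and subtract one at the target position, which is the grad_2d[arange_1d, masked_target_1d] -= (1.0 - target_mask...) update.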

View File

@@ -1,15 +1,16 @@
 import torch
 import torch.distributed as dist
-from colossalai.context import ParallelMode
-from colossalai.core import global_context as gpc
-from colossalai.nn.layer.parallel_2d import reduce_by_batch_2d, split_batch_2d
-from colossalai.nn.layer.parallel_2d._utils import assert_summa_initialization
-from colossalai.registry import LOSSES
-from colossalai.utils import get_current_device
 from torch.cuda.amp import custom_bwd, custom_fwd
 from torch.nn.functional import cross_entropy
 from torch.nn.modules.loss import _Loss

+from colossalai.context import ParallelMode
+from colossalai.core import global_context as gpc
+from colossalai.legacy.registry import LOSSES
+from colossalai.nn.layer.parallel_2d import reduce_by_batch_2d, split_batch_2d
+from colossalai.nn.layer.parallel_2d._utils import assert_summa_initialization
+from colossalai.utils import get_current_device


 @LOSSES.register_module
 class CrossEntropyLoss2D(_Loss):

View File

@@ -1,15 +1,16 @@
 import torch
 import torch.distributed as dist
-from colossalai.context import ParallelMode
-from colossalai.core import global_context as gpc
-from colossalai.nn.layer.parallel_2p5d import reduce_by_batch_2p5d, split_batch_2p5d
-from colossalai.nn.layer.parallel_2p5d._utils import assert_tesseract_initialization
-from colossalai.registry import LOSSES
-from colossalai.utils import get_current_device
 from torch.cuda.amp import custom_bwd, custom_fwd
 from torch.nn.functional import cross_entropy
 from torch.nn.modules.loss import _Loss

+from colossalai.context import ParallelMode
+from colossalai.core import global_context as gpc
+from colossalai.legacy.registry import LOSSES
+from colossalai.nn.layer.parallel_2p5d import reduce_by_batch_2p5d, split_batch_2p5d
+from colossalai.nn.layer.parallel_2p5d._utils import assert_tesseract_initialization
+from colossalai.utils import get_current_device


 @LOSSES.register_module
 class CrossEntropyLoss2p5D(_Loss):

View File

@@ -1,15 +1,16 @@
 import torch
 import torch.distributed as dist
-from colossalai.constants import INPUT_GROUP_3D, WEIGHT_GROUP_3D, OUTPUT_GROUP_3D
-from colossalai.core import global_context as gpc
-from colossalai.nn.layer.parallel_3d import reduce_by_batch_3d, split_tensor_3d
-from colossalai.nn.layer.parallel_3d._utils import get_parallel_mode_from_env
-from colossalai.registry import LOSSES
-from colossalai.utils import get_current_device
 from torch.cuda.amp import custom_bwd, custom_fwd
 from torch.nn.functional import cross_entropy
 from torch.nn.modules.loss import _Loss

+from colossalai.constants import INPUT_GROUP_3D, OUTPUT_GROUP_3D, WEIGHT_GROUP_3D
+from colossalai.core import global_context as gpc
+from colossalai.legacy.registry import LOSSES
+from colossalai.nn.layer.parallel_3d import reduce_by_batch_3d, split_tensor_3d
+from colossalai.nn.layer.parallel_3d._utils import get_parallel_mode_from_env
+from colossalai.utils import get_current_device


 @LOSSES.register_module
 class CrossEntropyLoss3D(_Loss):

View File

@ -1,80 +1,81 @@
import torch.nn as nn import torch.nn as nn
from colossalai.registry import LOSSES from torch.nn.modules.loss import _Loss
from torch.nn.modules.loss import _Loss
from colossalai.context.moe_context import MOE_CONTEXT from colossalai.context.moe_context import MOE_CONTEXT
from colossalai.legacy.registry import LOSSES
@LOSSES.register_module
class MoeCrossEntropyLoss(_Loss): @LOSSES.register_module
r"""torch.nn.CrossEntropyLoss added with auxiliary loss. class MoeCrossEntropyLoss(_Loss):
r"""torch.nn.CrossEntropyLoss added with auxiliary loss.
Args:
input (:class:`torch.tensor`): Predicted unnormalized scores (often referred to as logits). Args:
target (:class:`torch.tensor`): Ground truth class indices or class probabilities. input (:class:`torch.tensor`): Predicted unnormalized scores (often referred to as logits).
aux_weight (float, optional): Weight of auxiliary loss in total loss. Defaults to 0.01. target (:class:`torch.tensor`): Ground truth class indices or class probabilities.
aux_weight (float, optional): Weight of auxiliary loss in total loss. Defaults to 0.01.
The ``args`` and ``kwargs`` should include parameters below:
:: The ``args`` and ``kwargs`` should include parameters below:
::
weight (Tensor, optional)
size_average (bool, optional) weight (Tensor, optional)
ignore_index (int, optional) size_average (bool, optional)
reduce (bool, optional) ignore_index (int, optional)
reduction (str, optional) reduce (bool, optional)
label_smoothing (float, optional) reduction (str, optional)
label_smoothing (float, optional)
More details about ``args``, ``kwargs`` and ``torch.nn.functional.cross_entropy`` could be found in
`Cross_entropy <https://pytorch.org/docs/stable/generated/torch.nn.functional.cross_entropy.html#torch.nn.functional.cross_entropy>`_. More details about ``args``, ``kwargs`` and ``torch.nn.functional.cross_entropy`` could be found in
""" `Cross_entropy <https://pytorch.org/docs/stable/generated/torch.nn.functional.cross_entropy.html#torch.nn.functional.cross_entropy>`_.
"""
def __init__(self, aux_weight: float = 0.01, *args, **kwargs):
super().__init__() def __init__(self, aux_weight: float = 0.01, *args, **kwargs):
self.loss = nn.CrossEntropyLoss(*args, **kwargs) super().__init__()
self.aux_weight = aux_weight self.loss = nn.CrossEntropyLoss(*args, **kwargs)
self.aux_weight = aux_weight
def forward(self, *args):
""" def forward(self, *args):
The ``args`` should at least include parameters below: """
:: The ``args`` should at least include parameters below:
::
input (:class:`torch.tensor`): Predicted unnormalized scores (often referred to as logits).
target (:class:`torch.tensor`): Ground truth class indices or class probabilities. input (:class:`torch.tensor`): Predicted unnormalized scores (often referred to as logits).
target (:class:`torch.tensor`): Ground truth class indices or class probabilities.
More details about ``args``, ``kwargs`` and ``torch.nn.functional.cross_entropy`` could be found in
`Cross_entropy <https://pytorch.org/docs/stable/generated/torch.nn.functional.cross_entropy.html#torch.nn.functional.cross_entropy>`_. More details about ``args``, ``kwargs`` and ``torch.nn.functional.cross_entropy`` could be found in
""" `Cross_entropy <https://pytorch.org/docs/stable/generated/torch.nn.functional.cross_entropy.html#torch.nn.functional.cross_entropy>`_.
main_loss = self.loss(*args) """
aux_loss = MOE_CONTEXT.get_loss() main_loss = self.loss(*args)
return main_loss + self.aux_weight * aux_loss aux_loss = MOE_CONTEXT.get_loss()
return main_loss + self.aux_weight * aux_loss
@LOSSES.register_module
class MoeLoss(_Loss): @LOSSES.register_module
"""A wrapper class for any loss module to add with auxiliary loss. class MoeLoss(_Loss):
"""A wrapper class for any loss module to add with auxiliary loss.
Args:
aux_weight (float): Weight of auxiliary loss in total loss. Args:
loss_fn (``Callable``): Loss function. aux_weight (float): Weight of auxiliary loss in total loss.
args (list): Args in loss function. loss_fn (``Callable``): Loss function.
kwargs (dict): Kwargs in loss function args (list): Args in loss function.
""" kwargs (dict): Kwargs in loss function
"""
def __init__(self, aux_weight: float, loss_fn, *args, **kwargs):
super().__init__() def __init__(self, aux_weight: float, loss_fn, *args, **kwargs):
self.loss_fn = loss_fn(*args, **kwargs) super().__init__()
self.aux_weight = aux_weight self.loss_fn = loss_fn(*args, **kwargs)
self.aux_weight = aux_weight
def forward(self, *args, **kwargs):
""" def forward(self, *args, **kwargs):
The ``args`` and ``kwargs`` should at least include parameters below: """
:: The ``args`` and ``kwargs`` should at least include parameters below:
::
input (:class:`torch.tensor`): Predicted unnormalized scores (often referred to as logits).
target (:class:`torch.tensor`): Ground truth class indices or class probabilities. input (:class:`torch.tensor`): Predicted unnormalized scores (often referred to as logits).
target (:class:`torch.tensor`): Ground truth class indices or class probabilities.
Note:
The ``args`` and ``kwargs`` may include different parameters varying with different loss function. Note:
""" The ``args`` and ``kwargs`` may include different parameters varying with different loss function.
main_loss = self.loss_fn(*args, **kwargs) """
aux_loss = MOE_CONTEXT.get_loss() main_loss = self.loss_fn(*args, **kwargs)
return main_loss + self.aux_weight * aux_loss aux_loss = MOE_CONTEXT.get_loss()
return main_loss + self.aux_weight * aux_loss
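As a side note for readers of this hunk, the combination done in `forward()` can be checked in isolation with plain PyTorch; `mock_aux_loss` below is a hypothetical stand-in for `MOE_CONTEXT.get_loss()`, not part of the actual MoE runtime:

```python
import torch
import torch.nn as nn

# Hypothetical stand-in for MOE_CONTEXT.get_loss(); the real value is the
# load-balancing loss accumulated by the MoE layers during the forward pass.
def mock_aux_loss() -> torch.Tensor:
    return torch.tensor(0.25)

aux_weight = 0.01
criterion = nn.CrossEntropyLoss()

logits = torch.randn(4, 10)           # predicted unnormalized scores (logits)
target = torch.randint(0, 10, (4,))   # ground-truth class indices

main_loss = criterion(logits, target)
total_loss = main_loss + aux_weight * mock_aux_loss()  # same combination as in forward() above
print(total_loss)
```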

View File

@ -1,6 +1,7 @@
from torch.optim.lr_scheduler import CosineAnnealingLR as _CosineAnnealingLR from torch.optim.lr_scheduler import CosineAnnealingLR as _CosineAnnealingLR
from colossalai.registry import LR_SCHEDULERS from colossalai.legacy.registry import LR_SCHEDULERS
from .delayed import DelayerScheduler, WarmupDelayerScheduler, WarmupScheduler from .delayed import DelayerScheduler, WarmupDelayerScheduler, WarmupScheduler

View File

@ -1,6 +1,6 @@
from torch.optim.lr_scheduler import _LRScheduler from torch.optim.lr_scheduler import _LRScheduler
from colossalai.registry import LR_SCHEDULERS from colossalai.legacy.registry import LR_SCHEDULERS
@LR_SCHEDULERS.register_module @LR_SCHEDULERS.register_module

View File

@ -2,7 +2,8 @@ from typing import List
from torch.optim.lr_scheduler import MultiStepLR as _MultiStepLR from torch.optim.lr_scheduler import MultiStepLR as _MultiStepLR
from colossalai.registry import LR_SCHEDULERS from colossalai.legacy.registry import LR_SCHEDULERS
from .delayed import WarmupScheduler from .delayed import WarmupScheduler

View File

@ -1,6 +1,6 @@
from torch.optim.lr_scheduler import OneCycleLR as _OneCycleLR from torch.optim.lr_scheduler import OneCycleLR as _OneCycleLR
from colossalai.registry import LR_SCHEDULERS from colossalai.legacy.registry import LR_SCHEDULERS
@LR_SCHEDULERS.register_module @LR_SCHEDULERS.register_module

View File

@ -1,6 +1,7 @@
from torch.optim.lr_scheduler import _LRScheduler from torch.optim.lr_scheduler import _LRScheduler
from colossalai.registry import LR_SCHEDULERS from colossalai.legacy.registry import LR_SCHEDULERS
from .delayed import WarmupScheduler from .delayed import WarmupScheduler

View File

@ -1,9 +1,9 @@
from torch.optim.lr_scheduler import ExponentialLR as _ExponentialLR
from torch.optim.lr_scheduler import LambdaLR as _LambdaLR from torch.optim.lr_scheduler import LambdaLR as _LambdaLR
from torch.optim.lr_scheduler import MultiplicativeLR as _MultiplicativeLR from torch.optim.lr_scheduler import MultiplicativeLR as _MultiplicativeLR
from torch.optim.lr_scheduler import StepLR as _StepLR from torch.optim.lr_scheduler import StepLR as _StepLR
from torch.optim.lr_scheduler import ExponentialLR as _ExponentialLR
from colossalai.registry import LR_SCHEDULERS from colossalai.legacy.registry import LR_SCHEDULERS
@LR_SCHEDULERS.register_module @LR_SCHEDULERS.register_module

View File

@ -4,7 +4,7 @@ from typing import Optional
import torch import torch
from colossalai.kernel.op_builder import CPUAdamBuilder from colossalai.kernel.op_builder import CPUAdamBuilder
from colossalai.registry import OPTIMIZERS from colossalai.legacy.registry import OPTIMIZERS
from .nvme_optimizer import NVMeOptimizer from .nvme_optimizer import NVMeOptimizer

View File

@ -8,7 +8,7 @@ Licensed under the MIT License.
''' '''
import torch import torch
from colossalai.registry import OPTIMIZERS from colossalai.legacy.registry import OPTIMIZERS
from colossalai.utils import multi_tensor_applier from colossalai.utils import multi_tensor_applier

View File

@ -1,7 +1,7 @@
# modified from https://github.com/NVIDIA/apex/blob/master/apex/optimizers/fused_lamb.py # modified from https://github.com/NVIDIA/apex/blob/master/apex/optimizers/fused_lamb.py
import torch import torch
from colossalai.registry import OPTIMIZERS from colossalai.legacy.registry import OPTIMIZERS
from colossalai.utils import multi_tensor_applier from colossalai.utils import multi_tensor_applier

View File

@ -2,7 +2,7 @@
import torch import torch
from torch.optim.optimizer import Optimizer, required from torch.optim.optimizer import Optimizer, required
from colossalai.registry import OPTIMIZERS from colossalai.legacy.registry import OPTIMIZERS
from colossalai.utils import multi_tensor_applier from colossalai.utils import multi_tensor_applier

View File

@ -4,7 +4,7 @@ import torch
from torch.optim import Adam from torch.optim import Adam
from colossalai.kernel.op_builder import FusedOptimBuilder from colossalai.kernel.op_builder import FusedOptimBuilder
from colossalai.registry import OPTIMIZERS from colossalai.legacy.registry import OPTIMIZERS
from colossalai.utils import multi_tensor_applier from colossalai.utils import multi_tensor_applier
from .cpu_adam import CPUAdam from .cpu_adam import CPUAdam

View File

@ -5,7 +5,7 @@ Adapted from the pytorch-lamb library at https://github.com/cybertronai/pytorch-
import torch import torch
from torch.optim import Optimizer from torch.optim import Optimizer
from colossalai.registry import OPTIMIZERS from colossalai.legacy.registry import OPTIMIZERS
@OPTIMIZERS.register_module @OPTIMIZERS.register_module

View File

@ -5,7 +5,7 @@ from typing import Iterable
import torch import torch
from torch.optim import Optimizer from torch.optim import Optimizer
from colossalai.registry import OPTIMIZERS from colossalai.legacy.registry import OPTIMIZERS
@OPTIMIZERS.register_module @OPTIMIZERS.register_module
@ -22,28 +22,24 @@ class Lars(Optimizer):
weight_decay (float, optional): weight decay (L2 penalty) (default: 0) weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
""" """
def __init__( def __init__(self,
self, params: Iterable[torch.nn.Parameter],
params: Iterable[torch.nn.Parameter], lr=1e-3,
lr=1e-3, momentum=0,
momentum=0, eeta=1e-3,
eeta=1e-3, weight_decay=0,
weight_decay=0, epsilon=0.0) -> None:
epsilon=0.0
) -> None:
if not isinstance(lr, float) or lr < 0.0: if not isinstance(lr, float) or lr < 0.0:
raise ValueError("Invalid learning rate: {}".format(lr)) raise ValueError("Invalid learning rate: {}".format(lr))
if momentum < 0.0: if momentum < 0.0:
raise ValueError("Invalid momentum value: {}".format(momentum)) raise ValueError("Invalid momentum value: {}".format(momentum))
if weight_decay < 0.0: if weight_decay < 0.0:
raise ValueError( raise ValueError("Invalid weight_decay value: {}".format(weight_decay))
"Invalid weight_decay value: {}".format(weight_decay))
if eeta <= 0 or eeta > 1: if eeta <= 0 or eeta > 1:
raise ValueError("Invalid eeta value: {}".format(eeta)) raise ValueError("Invalid eeta value: {}".format(eeta))
if epsilon < 0: if epsilon < 0:
raise ValueError("Invalid epsilon value: {}".format(epsilon)) raise ValueError("Invalid epsilon value: {}".format(epsilon))
defaults = dict(lr=lr, momentum=momentum, defaults = dict(lr=lr, momentum=momentum, weight_decay=weight_decay, eeta=eeta, epsilon=epsilon, lars=True)
weight_decay=weight_decay, eeta=eeta, epsilon=epsilon, lars=True)
super().__init__(params, defaults) super().__init__(params, defaults)
@ -76,11 +72,9 @@ class Lars(Optimizer):
if lars: if lars:
w_norm = torch.norm(p) w_norm = torch.norm(p)
g_norm = torch.norm(p.grad) g_norm = torch.norm(p.grad)
trust_ratio = torch.where( trust_ratio = torch.where(w_norm > 0 and g_norm > 0,
w_norm > 0 and g_norm > 0, eeta * w_norm / (g_norm + weight_decay * w_norm + eps),
eeta * w_norm / (g_norm + weight_decay * w_norm + eps), torch.ones_like(w_norm))
torch.ones_like(w_norm)
)
trust_ratio.clamp_(0.0, 50) trust_ratio.clamp_(0.0, 50)
scaled_lr *= trust_ratio.item() scaled_lr *= trust_ratio.item()
if weight_decay != 0: if weight_decay != 0:
@ -90,8 +84,7 @@ class Lars(Optimizer):
if momentum != 0: if momentum != 0:
param_state = self.state[p] param_state = self.state[p]
if 'momentum_buffer' not in param_state: if 'momentum_buffer' not in param_state:
buf = param_state['momentum_buffer'] = torch.clone( buf = param_state['momentum_buffer'] = torch.clone(decayed_grad).detach()
decayed_grad).detach()
else: else:
buf = param_state['momentum_buffer'] buf = param_state['momentum_buffer']
buf.mul_(momentum).add_(decayed_grad) buf.mul_(momentum).add_(decayed_grad)
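Restating the layer-wise scaling computed in the hunk above as a formula (no new behaviour, just the same expression in math):

$$
\text{trust\_ratio} =
\begin{cases}
\dfrac{\text{eeta}\,\lVert w\rVert}{\lVert g\rVert + \lambda\,\lVert w\rVert + \epsilon}, & \lVert w\rVert > 0 \ \text{and}\ \lVert g\rVert > 0,\\[4pt]
1, & \text{otherwise,}
\end{cases}
$$

after which it is clamped to $[0, 50]$ and multiplied into the scaled learning rate; here $w$ is the parameter, $g$ its gradient, and $\lambda$ the weight decay.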

View File

@ -4,15 +4,15 @@
import math import math
import random import random
import numpy as np from typing import Iterator, TypeVar
from typing import TypeVar, Iterator
import numpy as np
import torch import torch
from torch.utils.data import Sampler, Dataset, DataLoader from torch.utils.data import DataLoader, Dataset, Sampler
from colossalai.context.parallel_mode import ParallelMode from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc from colossalai.core import global_context as gpc
from colossalai.registry import DATA_SAMPLERS from colossalai.legacy.registry import DATA_SAMPLERS
T_co = TypeVar('T_co', covariant=True) T_co = TypeVar('T_co', covariant=True)
@ -30,11 +30,7 @@ class DataParallelSampler(Sampler):
the batch size, then the last batch will be smaller, defaults to False. the batch size, then the last batch will be smaller, defaults to False.
""" """
def __init__(self, def __init__(self, dataset: Dataset, shuffle: bool = False, seed: int = 0, drop_last: bool = False) -> None:
dataset: Dataset,
shuffle: bool = False,
seed: int = 0,
drop_last: bool = False) -> None:
self.dataset = dataset self.dataset = dataset
self.num_replicas = gpc.get_world_size(ParallelMode.DATA) self.num_replicas = gpc.get_world_size(ParallelMode.DATA)
self.rank = gpc.get_local_rank(ParallelMode.DATA) self.rank = gpc.get_local_rank(ParallelMode.DATA)
@ -54,8 +50,7 @@ class DataParallelSampler(Sampler):
self.num_replicas # type: ignore[arg-type] self.num_replicas # type: ignore[arg-type]
) )
else: else:
self.num_samples = math.ceil( self.num_samples = math.ceil(len(self.dataset) / self.num_replicas) # type: ignore[arg-type]
len(self.dataset) / self.num_replicas) # type: ignore[arg-type]
self.total_size = self.num_samples * self.num_replicas self.total_size = self.num_samples * self.num_replicas
self.shuffle = shuffle self.shuffle = shuffle
self.seed = seed self.seed = seed
@ -72,7 +67,7 @@ class DataParallelSampler(Sampler):
# set_epoch manually # set_epoch manually
self.epoch += 1 self.epoch += 1
else: else:
indices = list(range(len(self.dataset))) # type: ignore[arg-type] indices = list(range(len(self.dataset))) # type: ignore[arg-type]
if not self.drop_last: if not self.drop_last:
# add extra samples to make it evenly divisible # add extra samples to make it evenly divisible
@ -80,8 +75,7 @@ class DataParallelSampler(Sampler):
if padding_size <= len(indices): if padding_size <= len(indices):
indices += indices[:padding_size] indices += indices[:padding_size]
else: else:
indices += (indices * math.ceil(padding_size / indices += (indices * math.ceil(padding_size / len(indices)))[:padding_size]
len(indices)))[:padding_size]
else: else:
# remove tail of data to make it evenly divisible. # remove tail of data to make it evenly divisible.
indices = indices[:self.total_size] indices = indices[:self.total_size]
@ -109,8 +103,8 @@ class DataParallelSampler(Sampler):
def get_dataloader(dataset, def get_dataloader(dataset,
shuffle=False, shuffle=False,
seed=1024, seed=1024,
add_sampler=True, add_sampler=True,
drop_last=False, drop_last=False,
pin_memory=False, pin_memory=False,
num_workers=0, num_workers=0,
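For readers following the reformatted `DataParallelSampler`, here is a small self-contained sketch of the padding arithmetic above, with made-up sizes:

```python
import math

dataset_len, num_replicas = 10, 4                       # hypothetical sizes
num_samples = math.ceil(dataset_len / num_replicas)     # 3 samples per replica
total_size = num_samples * num_replicas                 # 12 after padding

indices = list(range(dataset_len))
padding_size = total_size - dataset_len                 # 2 extra indices needed
if padding_size <= len(indices):
    indices += indices[:padding_size]                   # re-use the first few indices
else:
    indices += (indices * math.ceil(padding_size / len(indices)))[:padding_size]

print(len(indices))   # 12, now evenly divisible across the 4 replicas
```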

View File

@ -1,17 +1,17 @@
import os import gzip
from typing import List
from colossalai.engine import Engine
from torch.profiler import profile as torch_profile
from torch.profiler.profiler import ProfilerAction
from typing import Any, Callable, Iterable, Optional
from torch.autograd import ProfilerActivity
import json import json
import os import os
import tempfile import tempfile
import gzip from typing import Any, Callable, Iterable, List, Optional
from torch.autograd import ProfilerActivity
from torch.profiler import profile as torch_profile
from torch.profiler.profiler import ProfilerAction
from colossalai.legacy.engine import Engine
from colossalai.logging import get_dist_logger
from colossalai.utils.profiler.extention import ProfilerExtension from colossalai.utils.profiler.extention import ProfilerExtension
from colossalai.utils.profiler.stateful_tensor_mem_extention import StatefulTensorMemoryProfilerExtention from colossalai.utils.profiler.stateful_tensor_mem_extention import StatefulTensorMemoryProfilerExtention
from colossalai.logging import get_dist_logger
class profile(torch_profile): class profile(torch_profile):

View File

@ -1,12 +1,14 @@
import os import os
import threading import threading
import time import time
import torch
from enum import Enum from enum import Enum
from typing import List from typing import List
from colossalai.gemini.stateful_tensor import StatefulTensor
import torch
from colossalai.gemini.ophooks import BaseOpHook from colossalai.gemini.ophooks import BaseOpHook
from colossalai.engine import Engine from colossalai.gemini.stateful_tensor import StatefulTensor
from colossalai.legacy.engine import Engine
from colossalai.utils.profiler.extention import ProfilerExtension from colossalai.utils.profiler.extention import ProfilerExtension

View File

@ -1,6 +1,6 @@
import torch import torch
from colossalai.registry import OPHOOKS from colossalai.legacy.registry import OPHOOKS
from . import BaseOpHook from . import BaseOpHook

View File

@ -1,6 +1,6 @@
import torch import torch
from colossalai.registry import OPHOOKS from colossalai.legacy.registry import OPHOOKS
from . import BaseOpHook from . import BaseOpHook

View File

@ -3,8 +3,8 @@ from typing import Optional
import torch import torch
import torch.distributed as dist import torch.distributed as dist
from colossalai.legacy.registry import OPHOOKS
from colossalai.logging import get_dist_logger from colossalai.logging import get_dist_logger
from colossalai.registry import OPHOOKS
from colossalai.utils import get_current_device from colossalai.utils import get_current_device
from colossalai.zero.gemini.memory_tracer import MemStatsCollector from colossalai.zero.gemini.memory_tracer import MemStatsCollector
from colossalai.zero.legacy.gemini.ophooks import BaseOpHook from colossalai.zero.legacy.gemini.ophooks import BaseOpHook

View File

@ -6,6 +6,7 @@ from typing import Dict, Iterator, Optional, Tuple
import torch import torch
import torch.distributed as dist import torch.distributed as dist
import torch.nn as nn
from torch.distributed import ProcessGroup from torch.distributed import ProcessGroup
from torch.optim import Optimizer from torch.optim import Optimizer
@ -617,3 +618,19 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
ret_block_size += current_block_size ret_block_size += current_block_size
yield ret_block, ret_block_size yield ret_block, ret_block_size
def update_master_params(self, model: nn.Module) -> None:
"""Update master params from working params
Args:
model (nn.Module): The model to update master params
"""
for p in model.parameters():
p_id = id(p)
if p_id in self._param_store.working_to_master_param:
master_param = self._param_store.working_to_master_param[p_id]
padding_size = self._param_store.get_param_padding_size(p)
working_param = p.data.view(-1)
if padding_size > 0:
working_param = torch.nn.functional.pad(working_param, [0, padding_size])
master_param.copy_(working_param.chunk(self._world_size)[self._local_rank])
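The pad-and-chunk step added in `update_master_params` can be illustrated standalone; the sizes below are made up, and in the real method the padding size comes from the param store:

```python
import torch
import torch.nn.functional as F

world_size, local_rank = 4, 1           # hypothetical ZeRO group of 4 ranks; this is rank 1
working_param = torch.arange(10.)       # a flattened working parameter with 10 elements
padding_size = 2                        # pad to 12 so it divides evenly by world_size

flat = working_param.view(-1)
if padding_size > 0:
    flat = F.pad(flat, [0, padding_size])      # append two zeros

shard = flat.chunk(world_size)[local_rank]     # this rank's 3-element shard
master_param = torch.empty_like(shard)
master_param.copy_(shard)                      # mirrors master_param.copy_(...) above
print(master_param)                            # tensor([3., 4., 5.])
```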

View File

@ -92,14 +92,14 @@ follow the steps below to create a new distributed initialization.
Gradient handlers are objects which execute the all-reduce operations on parameters' gradients. As different all-reduce Gradient handlers are objects which execute the all-reduce operations on parameters' gradients. As different all-reduce
strategies may be executed for different kinds of parallelism, users can strategies may be executed for different kinds of parallelism, users can
inherit `colossalai.engine.gradient_handler.BaseGradientHandler` to implement their strategies. Currently, the library inherit `colossalai.legacy.engine.gradient_handler.BaseGradientHandler` to implement their strategies. Currently, the library
uses the normal data parallel gradient handler which all-reduces the gradients across data parallel ranks. The data uses the normal data parallel gradient handler which all-reduces the gradients across data parallel ranks. The data
parallel gradient handler is added to the engine automatically if data parallel is detected. You can add your own parallel gradient handler is added to the engine automatically if data parallel is detected. You can add your own
gradient handler like below: gradient handler like below:
```python ```python
from colossalai.registry import GRADIENT_HANDLER from colossalai.legacy.registry import GRADIENT_HANDLER
from colossalai.engine import BaseGradientHandler from colossalai.legacy.engine import BaseGradientHandler
@GRADIENT_HANDLER.register_module @GRADIENT_HANDLER.register_module
class YourGradientHandler(BaseGradientHandler): class YourGradientHandler(BaseGradientHandler):
@ -121,4 +121,5 @@ gradient_handlers = [
Schedule entails how to execute a forward and backward pass. Currently, Colossal-AI provides pipeline and non-pipeline Schedule entails how to execute a forward and backward pass. Currently, Colossal-AI provides pipeline and non-pipeline
schedules. If you want to modify how the forward and backward passes are executed, you can schedules. If you want to modify how the forward and backward passes are executed, you can
inherit `colossalai.engine.schedule.BaseSchedule` and implement the `forward_back_step` function. inherit `colossalai.legacy.engine.schedule.BaseSchedule` and implement the `forward_back_step` function.
<!-- doc-test-command: echo -->
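For context on the renamed paths, a minimal sketch of a custom gradient handler under the new `colossalai.legacy` imports; it assumes `BaseGradientHandler` exposes the wrapped model as `self._model` and that `handle_gradient()` takes no arguments, as the built-in handlers do:

```python
import torch.distributed as dist

from colossalai.legacy.engine.gradient_handler import BaseGradientHandler
from colossalai.legacy.registry import GRADIENT_HANDLER


@GRADIENT_HANDLER.register_module
class MyAllReduceGradientHandler(BaseGradientHandler):
    """All-reduce every gradient over the default process group.

    Sketch only: assumes the base class stores the model as `self._model`.
    """

    def handle_gradient(self):
        for param in self._model.parameters():
            if param.grad is not None:
                dist.all_reduce(param.grad)
```

It could then be listed in the config, e.g. `gradient_handlers = [dict(type='MyAllReduceGradientHandler')]`, following the snippet above.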

View File

@ -36,14 +36,14 @@ import torch
import torch.nn as nn import torch.nn as nn
from colossalai import nn as col_nn from colossalai import nn as col_nn
from colossalai.amp import AMP_TYPE from colossalai.amp import AMP_TYPE
from colossalai.builder.pipeline import partition_uniform from colossalai.legacy.builder.pipeline import partition_uniform
from colossalai.context.parallel_mode import ParallelMode from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc from colossalai.core import global_context as gpc
from colossalai.engine.schedule import (InterleavedPipelineSchedule, from colossalai.legacy.engine.schedule import (InterleavedPipelineSchedule,
PipelineSchedule) PipelineSchedule)
from colossalai.logging import disable_existing_loggers, get_dist_logger from colossalai.logging import disable_existing_loggers, get_dist_logger
from colossalai.nn.layer.wrapper import PipelineSharedModuleWrapper from colossalai.nn.layer.wrapper import PipelineSharedModuleWrapper
from colossalai.trainer import Trainer, hooks from colossalai.legacy.trainer import Trainer, hooks
from colossalai.utils.timer import MultiTimer from colossalai.utils.timer import MultiTimer
from model_zoo.gpt import GPTLMLoss from model_zoo.gpt import GPTLMLoss
from torch.nn import functional as F from torch.nn import functional as F
@ -268,3 +268,4 @@ def train():
return_output_label=False, return_output_label=False,
) )
``` ```
<!-- doc-test-command: echo -->

View File

@ -34,11 +34,11 @@ import colossalai
import colossalai.nn as col_nn import colossalai.nn as col_nn
import torch import torch
import torch.nn as nn import torch.nn as nn
from colossalai.builder import build_pipeline_model from colossalai.legacy.builder import build_pipeline_model
from colossalai.engine.schedule import (InterleavedPipelineSchedule, from colossalai.legacy.engine.schedule import (InterleavedPipelineSchedule,
PipelineSchedule) PipelineSchedule)
from colossalai.logging import disable_existing_loggers, get_dist_logger from colossalai.logging import disable_existing_loggers, get_dist_logger
from colossalai.trainer import Trainer, hooks from colossalai.legacy.trainer import Trainer, hooks
from colossalai.utils import MultiTimer, get_dataloader from colossalai.utils import MultiTimer, get_dataloader
from timm.models import vision_transformer as vit from timm.models import vision_transformer as vit
from torchvision import transforms from torchvision import transforms
@ -51,17 +51,17 @@ from torchvision.datasets import CIFAR10
Generally, we provide 3 ways to build a pipelined model: Generally, we provide 3 ways to build a pipelined model:
1. `colossalai.builder.build_pipeline_model_from_cfg` 1. `colossalai.legacy.builder.build_pipeline_model_from_cfg`
2. `colossalai.builder.build_pipeline_model` 2. `colossalai.legacy.builder.build_pipeline_model`
3. Split the model by stages by yourself 3. Split the model by stages by yourself
When your memory can fit the model, you can use the first two methods to build your model, otherwise you must split the model by yourself. The first two methods first build the whole model on CPU, then split the model, and finally you can just move the corresponding part of model to GPU. When your memory can fit the model, you can use the first two methods to build your model, otherwise you must split the model by yourself. The first two methods first build the whole model on CPU, then split the model, and finally you can just move the corresponding part of model to GPU.
`colossalai.builder.build_pipeline_model_from_cfg()` receives a config file of model, and it can split the model uniformly (by layer) or balanced (by parameter size). `colossalai.legacy.builder.build_pipeline_model_from_cfg()` receives a config file of model, and it can split the model uniformly (by layer) or balanced (by parameter size).
If you are familiar with `PyTorch`, you can use `colossalai.builder.build_pipeline_model()` which receives a `torch.nn.Sequential` model and splits it by layer uniformly. If you are familiar with `PyTorch`, you can use `colossalai.legacy.builder.build_pipeline_model()` which receives a `torch.nn.Sequential` model and splits it by layer uniformly.
In this tutorial, we will convert [TIMM/ViT](https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py) into `torch.nn.Sequential` and then use `colossalai.builder.build_pipeline_model()` to build the pipelined model. In this tutorial, we will convert [TIMM/ViT](https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py) into `torch.nn.Sequential` and then use `colossalai.legacy.builder.build_pipeline_model()` to build the pipelined model.
When the data is **one** `Tensor`, you can use the positional argument in `forward()` of your model to get the data tensor. For the first stage of pipeline, the first positional argument of `forward()` is the data tensor loaded from data loader. For other stages, the first positional argument of `forward()` is the output tensor from the previous stage. Note that if the stage is not the last stage, the return of `forward()` must be a `Tensor`. When the data is **one** `Tensor`, you can use the positional argument in `forward()` of your model to get the data tensor. For the first stage of pipeline, the first positional argument of `forward()` is the data tensor loaded from data loader. For other stages, the first positional argument of `forward()` is the output tensor from the previous stage. Note that if the stage is not the last stage, the return of `forward()` must be a `Tensor`.
@ -245,3 +245,4 @@ def train():
hooks=hook_list, hooks=hook_list,
display_progress=True) display_progress=True)
``` ```
<!-- doc-test-command: echo -->
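As a companion to this tutorial hunk, a toy sketch of the `build_pipeline_model` path under the new import; the `Sequential` model is a stand-in for the flattened ViT, and the call is assumed to split it uniformly by layer and return only the layers owned by the current stage, so it must run after `colossalai.launch` with pipeline parallelism configured:

```python
import torch.nn as nn

from colossalai.legacy.builder import build_pipeline_model   # new import path in this PR

# Toy Sequential model standing in for the flattened TIMM ViT.
toy_model = nn.Sequential(
    nn.Linear(256, 512), nn.GELU(),
    nn.Linear(512, 512), nn.GELU(),
    nn.Linear(512, 10),
)
stage_module = build_pipeline_model(toy_model)   # layers belonging to this pipeline stage
```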

View File

@ -79,7 +79,7 @@ from colossalai.core import global_context as gpc
from colossalai.logging import disable_existing_loggers, get_dist_logger from colossalai.logging import disable_existing_loggers, get_dist_logger
from colossalai.nn.lr_scheduler import LinearWarmupLR from colossalai.nn.lr_scheduler import LinearWarmupLR
from colossalai.nn.metric import Accuracy from colossalai.nn.metric import Accuracy
from colossalai.trainer import Trainer, hooks from colossalai.legacy.trainer import Trainer, hooks
``` ```
- Other modules - Other modules
@ -273,8 +273,8 @@ SEQ_LENGTH = (IMG_SIZE // PATCH_SIZE) ** 2 + 1 # add 1 for cls token
### Build pipeline model (`/hybrid_parallel/model/vit.py`) ### Build pipeline model (`/hybrid_parallel/model/vit.py`)
Colossal-AI provides two methods to build a pipeline model from the existing model. Colossal-AI provides two methods to build a pipeline model from the existing model.
- `colossalai.builder.build_pipeline_model_from_cfg` - `colossalai.legacy.builder.build_pipeline_model_from_cfg`
- `colossalai.builder.build_pipeline_model` - `colossalai.legacy.builder.build_pipeline_model`
Besides, you can also build a pipeline model from scratch with Colossal-AI. Besides, you can also build a pipeline model from scratch with Colossal-AI.
```python ```python
@ -284,11 +284,11 @@ from typing import Callable
import inspect import inspect
import torch import torch
from colossalai import nn as col_nn from colossalai import nn as col_nn
from colossalai.registry import LAYERS, MODELS from colossalai.legacy.registry import LAYERS, MODELS
from colossalai.logging import get_dist_logger from colossalai.logging import get_dist_logger
from colossalai.core import global_context as gpc from colossalai.core import global_context as gpc
from colossalai.context import ParallelMode from colossalai.context import ParallelMode
from colossalai.builder.pipeline import partition_uniform from colossalai.legacy.builder.pipeline import partition_uniform
from torch import dtype, nn from torch import dtype, nn
from model_zoo.vit.vit import ViTBlock, ViTEmbedding, ViTHead from model_zoo.vit.vit import ViTBlock, ViTEmbedding, ViTHead
@ -415,7 +415,7 @@ def build_pipeline_vit(num_layers, num_chunks, device=torch.device('cuda'), **kw
#### Import modules #### Import modules
```python ```python
from colossalai.engine.schedule import (InterleavedPipelineSchedule, from colossalai.legacy.engine.schedule import (InterleavedPipelineSchedule,
PipelineSchedule) PipelineSchedule)
from colossalai.utils import MultiTimer from colossalai.utils import MultiTimer
import os import os
@ -644,3 +644,4 @@ torchrun --standalone --nproc_per_node <NUM_GPUs> train_hybrid.py --config ./co
# If your torch >= 1.9.0 # If your torch >= 1.9.0
# python -m torch.distributed.run --standalone --nproc_per_node= <NUM_GPUs> train_hybrid.py --config ./configs/config_hybrid_parallel.py # python -m torch.distributed.run --standalone --nproc_per_node= <NUM_GPUs> train_hybrid.py --config ./configs/config_hybrid_parallel.py
``` ```
<!-- doc-test-command: echo -->

View File

@ -64,7 +64,7 @@ Trainer is a higher-level wrapper for the user to execute training with fewer
```python ```python
from colossalai.logging import get_dist_logger from colossalai.logging import get_dist_logger
from colossalai.trainer import Trainer, hooks from colossalai.legacy.trainer import Trainer, hooks
# build components and initialize with colossalai.initialize # build components and initialize with colossalai.initialize
... ...
@ -107,7 +107,7 @@ If you want to customize your own hook class, you can inherit `hooks.BaseHook` a
```python ```python
from colossalai.logging import get_dist_logger from colossalai.logging import get_dist_logger
from colossalai.trainer import hooks from colossalai.legacy.trainer import hooks
class LogMessageHook(hooks.BaseHook): class LogMessageHook(hooks.BaseHook):
@ -345,7 +345,7 @@ If you wish to train with a trainer object, you can follow the code snippet belo
```python ```python
from colossalai.nn.metric import Accuracy from colossalai.nn.metric import Accuracy
from colossalai.trainer import Trainer, hooks from colossalai.legacy.trainer import Trainer, hooks
# create a trainer object # create a trainer object
@ -387,3 +387,4 @@ python -m torch.distributed.launch --nproc_per_node <num_gpus> --master_addr loc
# with trainer # with trainer
python -m torch.distributed.launch --nproc_per_node <num_gpus> --master_addr localhost --master_port 29500 run_resnet_cifar10_with_trainer.py python -m torch.distributed.launch --nproc_per_node <num_gpus> --master_addr localhost --master_port 29500 run_resnet_cifar10_with_trainer.py
``` ```
<!-- doc-test-command: echo -->
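To complement the truncated `LogMessageHook` snippet in this hunk, a hedged sketch of a full custom hook; the callback names (`before_train`, `after_train_epoch`), the `trainer` argument, and the `priority` constructor argument follow the built-in hooks and are assumptions here:

```python
from colossalai.legacy.trainer import hooks
from colossalai.logging import get_dist_logger


class EpochLogHook(hooks.BaseHook):
    """A sketch of a custom hook that logs at the start of training
    and at the end of every epoch (rank 0 only)."""

    def __init__(self, priority: int = 10) -> None:
        super().__init__(priority)
        self._logger = get_dist_logger()

    def before_train(self, trainer):
        self._logger.info('training starts', ranks=[0])

    def after_train_epoch(self, trainer):
        self._logger.info(f'epoch {trainer.cur_epoch} finished', ranks=[0])
```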

View File

@ -41,7 +41,7 @@ for epoch in range(num_epochs):
#### Save when using trainer #### Save when using trainer
```python ```python
from colossalai.trainer import Trainer, hooks from colossalai.legacy.trainer import Trainer, hooks
model = ... model = ...
engine, _, _, _ = colossalai.initialize(model=model, ...) engine, _, _, _ = colossalai.initialize(model=model, ...)
trainer = Trainer(engine, ...) trainer = Trainer(engine, ...)
@ -61,3 +61,4 @@ model = ...
load_checkpoint('xxx.pt', model) load_checkpoint('xxx.pt', model)
... # train or test ... # train or test
``` ```
<!-- doc-test-command: echo -->

View File

@ -28,8 +28,8 @@ To implement a customized gradient handler, you need to follow these steps.
3. implement `handle_gradient` method. 3. implement `handle_gradient` method.
```python ```python
from colossalai.registry import GRADIENT_HANDLER from colossalai.legacy.registry import GRADIENT_HANDLER
from colossalai.engine.gradient_handler import BaseGradientHandler from colossalai.legacy.engine.gradient_handler import BaseGradientHandler
@GRADIENT_HANDLER.register_module @GRADIENT_HANDLER.register_module
@ -61,3 +61,4 @@ to demonstrate the use of gradient handler. In this example, we used `DataParall
```shell ```shell
python -m torch.distributed.launch --nproc_per_node 4 --master_addr localhost --master_port 29500 train_with_engine.py python -m torch.distributed.launch --nproc_per_node 4 --master_addr localhost --master_port 29500 train_with_engine.py
``` ```
<!-- doc-test-command: echo -->

View File

@ -267,7 +267,7 @@ from pathlib import Path
from colossalai.core import global_context as gpc from colossalai.core import global_context as gpc
from colossalai.logging import get_dist_logger from colossalai.logging import get_dist_logger
from colossalai.utils import get_dataloader from colossalai.utils import get_dataloader
from colossalai.trainer import Trainer, hooks from colossalai.legacy.trainer import Trainer, hooks
from colossalai.nn.lr_scheduler import LinearWarmupLR from colossalai.nn.lr_scheduler import LinearWarmupLR
from timm.models import vit_base_patch16_224 from timm.models import vit_base_patch16_224
from torchvision import datasets, transforms from torchvision import datasets, transforms

View File

@ -79,7 +79,7 @@ import colossalai.nn as col_nn
from colossalai.core import global_context as gpc from colossalai.core import global_context as gpc
from colossalai.logging import disable_existing_loggers, get_dist_logger from colossalai.logging import disable_existing_loggers, get_dist_logger
from colossalai.trainer import Trainer, hooks from colossalai.legacy.trainer import Trainer, hooks
from colossalai.utils import MultiTimer, get_dataloader from colossalai.utils import MultiTimer, get_dataloader
from colossalai.context import ParallelMode from colossalai.context import ParallelMode
from colossalai.pipeline.pipelinable import PipelinableContext from colossalai.pipeline.pipelinable import PipelinableContext
@ -157,3 +157,4 @@ trainer.fit(train_dataloader=train_dataloader,
``` ```
We use `2` pipeline stages and the batch will be split into `4` micro batches. We use `2` pipeline stages and the batch will be split into `4` micro batches.
<!-- doc-test-command: echo -->
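A config sketch consistent with the sentence above; the key names follow the usual Colossal-AI config convention and are not taken from this tutorial's files:

```python
# config.py (sketch): 2 pipeline stages, each batch split into 4 micro batches.
BATCH_SIZE = 64
NUM_MICRO_BATCHES = 4          # each batch is split into 4 micro batches
parallel = dict(pipeline=2)    # 2 pipeline stages
```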

View File

@ -81,14 +81,14 @@ Colossal-AI provides users with a global context, enabling them to easily manage
## Gradient Handler ## Gradient Handler
A gradient handler is an object that performs all-reduce operations on parameter gradients. Since different all-reduce strategies may be executed under different forms of parallelism, users can inherit A gradient handler is an object that performs all-reduce operations on parameter gradients. Since different all-reduce strategies may be executed under different forms of parallelism, users can inherit
`colossalai.engine.gradient_handler.BaseGradientHandler` to implement their own strategies. Currently, Colossal-AI uses the plain data-parallel gradient handler to all-reduce gradients across data-parallel ranks. `colossalai.legacy.engine.gradient_handler.BaseGradientHandler` to implement their own strategies. Currently, Colossal-AI uses the plain data-parallel gradient handler to all-reduce gradients across data-parallel ranks.
If data parallelism is detected, the gradient handler will be added to the engine automatically. If data parallelism is detected, the gradient handler will be added to the engine automatically.
You can add your own gradient handler as shown below: You can add your own gradient handler as shown below:
```python ```python
from colossalai.registry import GRADIENT_HANDLER from colossalai.legacy.registry import GRADIENT_HANDLER
from colossalai.engine import BaseGradientHandler from colossalai.legacy.engine import BaseGradientHandler
@GRADIENT_HANDLER.register_module @GRADIENT_HANDLER.register_module
class YourGradientHandler(BaseGradientHandler): class YourGradientHandler(BaseGradientHandler):
@ -109,4 +109,5 @@ gradient_handlers = [
## Schedule ## Schedule
A schedule defines how the forward and backward passes are executed. Currently, Colossal-AI provides pipeline and non-pipeline schedules. A schedule defines how the forward and backward passes are executed. Currently, Colossal-AI provides pipeline and non-pipeline schedules.
If you want to modify how the forward and backward passes are executed, you can inherit `colossalai.engine.schedule.BaseSchedule` and implement the `forward_back_step` function. If you want to modify how the forward and backward passes are executed, you can inherit `colossalai.legacy.engine.schedule.BaseSchedule` and implement the `forward_back_step` function.
<!-- doc-test-command: echo -->

View File

@ -36,14 +36,14 @@ import torch
import torch.nn as nn import torch.nn as nn
from colossalai import nn as col_nn from colossalai import nn as col_nn
from colossalai.amp import AMP_TYPE from colossalai.amp import AMP_TYPE
from colossalai.builder.pipeline import partition_uniform from colossalai.legacy.builder.pipeline import partition_uniform
from colossalai.context.parallel_mode import ParallelMode from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc from colossalai.core import global_context as gpc
from colossalai.engine.schedule import (InterleavedPipelineSchedule, from colossalai.legacy.engine.schedule import (InterleavedPipelineSchedule,
PipelineSchedule) PipelineSchedule)
from colossalai.logging import disable_existing_loggers, get_dist_logger from colossalai.logging import disable_existing_loggers, get_dist_logger
from colossalai.nn.layer.wrapper import PipelineSharedModuleWrapper from colossalai.nn.layer.wrapper import PipelineSharedModuleWrapper
from colossalai.trainer import Trainer, hooks from colossalai.legacy.trainer import Trainer, hooks
from colossalai.utils.timer import MultiTimer from colossalai.utils.timer import MultiTimer
from model_zoo.gpt import GPTLMLoss from model_zoo.gpt import GPTLMLoss
from torch.nn import functional as F from torch.nn import functional as F
@ -273,3 +273,4 @@ def train():
return_output_label=False, return_output_label=False,
) )
``` ```
<!-- doc-test-command: echo -->

View File

@ -32,11 +32,11 @@ import colossalai
import colossalai.nn as col_nn import colossalai.nn as col_nn
import torch import torch
import torch.nn as nn import torch.nn as nn
from colossalai.builder import build_pipeline_model from colossalai.legacy.builder import build_pipeline_model
from colossalai.engine.schedule import (InterleavedPipelineSchedule, from colossalai.legacy.engine.schedule import (InterleavedPipelineSchedule,
PipelineSchedule) PipelineSchedule)
from colossalai.logging import disable_existing_loggers, get_dist_logger from colossalai.logging import disable_existing_loggers, get_dist_logger
from colossalai.trainer import Trainer, hooks from colossalai.legacy.trainer import Trainer, hooks
from colossalai.utils import MultiTimer, get_dataloader from colossalai.utils import MultiTimer, get_dataloader
from timm.models import vision_transformer as vit from timm.models import vision_transformer as vit
from torchvision import transforms from torchvision import transforms
@ -48,17 +48,17 @@ from torchvision.datasets import CIFAR10
In general, we provide 3 ways to build a pipeline-parallel model: In general, we provide 3 ways to build a pipeline-parallel model:
1. `colossalai.builder.build_pipeline_model_from_cfg` 1. `colossalai.legacy.builder.build_pipeline_model_from_cfg`
2. `colossalai.builder.build_pipeline_model` 2. `colossalai.legacy.builder.build_pipeline_model`
3. Split the model into stages yourself 3. Split the model into stages yourself
When the model fits in your memory, you can use the first two methods to build it; otherwise you must split the model yourself. The first two methods first build the whole model on the CPU, then split it, and finally you can simply move the corresponding parts of the model to the GPU. When the model fits in your memory, you can use the first two methods to build it; otherwise you must split the model yourself. The first two methods first build the whole model on the CPU, then split it, and finally you can simply move the corresponding parts of the model to the GPU.
`colossalai.builder.build_pipeline_model_from_cfg()` receives a model config file and can split the model uniformly (by layer) or in a balanced way (by parameter size). `colossalai.legacy.builder.build_pipeline_model_from_cfg()` receives a model config file and can split the model uniformly (by layer) or in a balanced way (by parameter size).
If you are familiar with `PyTorch`, you can use `colossalai.builder.build_pipeline_model()`, which receives a `torch.nn.Sequential` model and splits it uniformly by layer. If you are familiar with `PyTorch`, you can use `colossalai.legacy.builder.build_pipeline_model()`, which receives a `torch.nn.Sequential` model and splits it uniformly by layer.
In this tutorial, we will convert [TIMM/ViT](https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py) into `torch.nn.Sequential` and then use `colossalai.builder.build_pipeline_model()` to build the pipelined model. In this tutorial, we will convert [TIMM/ViT](https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py) into `torch.nn.Sequential` and then use `colossalai.legacy.builder.build_pipeline_model()` to build the pipelined model.
When the data is **one** `Tensor`, you can use the positional argument of your model's `forward()` to get the data tensor. For the first pipeline stage, the first positional argument of `forward()` is the data tensor loaded from the data loader. For the other stages, the first positional argument of `forward()` is the output tensor from the previous stage. Note that if the stage is not the last stage, the return value of `forward()` must be a `Tensor`. When the data is **one** `Tensor`, you can use the positional argument of your model's `forward()` to get the data tensor. For the first pipeline stage, the first positional argument of `forward()` is the data tensor loaded from the data loader. For the other stages, the first positional argument of `forward()` is the output tensor from the previous stage. Note that if the stage is not the last stage, the return value of `forward()` must be a `Tensor`.
@ -244,3 +244,4 @@ def train():
hooks=hook_list, hooks=hook_list,
display_progress=True) display_progress=True)
``` ```
<!-- doc-test-command: echo -->

View File

@ -74,7 +74,7 @@ from colossalai.core import global_context as gpc
from colossalai.logging import disable_existing_loggers, get_dist_logger from colossalai.logging import disable_existing_loggers, get_dist_logger
from colossalai.nn.lr_scheduler import LinearWarmupLR from colossalai.nn.lr_scheduler import LinearWarmupLR
from colossalai.nn.metric import Accuracy from colossalai.nn.metric import Accuracy
from colossalai.trainer import Trainer, hooks from colossalai.legacy.trainer import Trainer, hooks
``` ```
- Other modules - Other modules
@ -256,8 +256,8 @@ SEQ_LENGTH = (IMG_SIZE // PATCH_SIZE) ** 2 + 1 # add 1 for cls token
### Build the pipeline model (`/hybrid_parallel/model/vit.py`) ### Build the pipeline model (`/hybrid_parallel/model/vit.py`)
Colossal-AI provides two methods to build a pipeline model from an existing model. Colossal-AI provides two methods to build a pipeline model from an existing model.
- `colossalai.builder.build_pipeline_model_from_cfg` - `colossalai.legacy.builder.build_pipeline_model_from_cfg`
- `colossalai.builder.build_pipeline_model` - `colossalai.legacy.builder.build_pipeline_model`
In addition, you can also build a pipeline model from scratch with Colossal-AI. In addition, you can also build a pipeline model from scratch with Colossal-AI.
```python ```python
@ -266,11 +266,11 @@ from typing import Callable
import inspect import inspect
import torch import torch
from colossalai import nn as col_nn from colossalai import nn as col_nn
from colossalai.registry import LAYERS, MODELS from colossalai.legacy.registry import LAYERS, MODELS
from colossalai.logging import get_dist_logger from colossalai.logging import get_dist_logger
from colossalai.core import global_context as gpc from colossalai.core import global_context as gpc
from colossalai.context import ParallelMode from colossalai.context import ParallelMode
from colossalai.builder.pipeline import partition_uniform from colossalai.legacy.builder.pipeline import partition_uniform
from torch import dtype, nn from torch import dtype, nn
from model_zoo.vit.vit import ViTBlock, ViTEmbedding, ViTHead from model_zoo.vit.vit import ViTBlock, ViTEmbedding, ViTHead
@MODELS.register_module @MODELS.register_module
@ -380,7 +380,7 @@ def build_pipeline_vit(num_layers, num_chunks, device=torch.device('cuda'), **kw
#### Import modules #### Import modules
```python ```python
from colossalai.engine.schedule import (InterleavedPipelineSchedule, from colossalai.legacy.engine.schedule import (InterleavedPipelineSchedule,
PipelineSchedule) PipelineSchedule)
from colossalai.utils import MultiTimer from colossalai.utils import MultiTimer
import os import os
@ -589,3 +589,4 @@ torchrun --standalone --nproc_per_node <NUM_GPUs> train_hybrid.py --config ./co
# If your torch >= 1.9.0 # If your torch >= 1.9.0
# python -m torch.distributed.run --standalone --nproc_per_node= <NUM_GPUs> train_hybrid.py --config ./configs/config_hybrid_parallel.py # python -m torch.distributed.run --standalone --nproc_per_node= <NUM_GPUs> train_hybrid.py --config ./configs/config_hybrid_parallel.py
``` ```
<!-- doc-test-command: echo -->

View File

@ -61,7 +61,7 @@ The Trainer's `schedule` argument defaults to `None`. In most cases, unless
```python ```python
from colossalai.logging import get_dist_logger from colossalai.logging import get_dist_logger
from colossalai.trainer import Trainer, hooks from colossalai.legacy.trainer import Trainer, hooks
# build components and initialize with colossalai.initialize # build components and initialize with colossalai.initialize
... ...
@ -104,7 +104,7 @@ trainer.fit(
```python ```python
from colossalai.logging import get_dist_logger from colossalai.logging import get_dist_logger
from colossalai.trainer import hooks from colossalai.legacy.trainer import hooks
class LogMessageHook(hooks.BaseHook): class LogMessageHook(hooks.BaseHook):
@ -341,7 +341,7 @@ for epoch in range(gpc.config.NUM_EPOCHS):
```python ```python
from colossalai.nn.metric import Accuracy from colossalai.nn.metric import Accuracy
from colossalai.trainer import Trainer, hooks from colossalai.legacy.trainer import Trainer, hooks
# create a trainer object # create a trainer object
@ -384,3 +384,4 @@ python -m torch.distributed.launch --nproc_per_node <num_gpus> --master_addr loc
# with trainer # with trainer
python -m torch.distributed.launch --nproc_per_node <num_gpus> --master_addr localhost --master_port 29500 run_resnet_cifar10_with_trainer.py python -m torch.distributed.launch --nproc_per_node <num_gpus> --master_addr localhost --master_port 29500 run_resnet_cifar10_with_trainer.py
``` ```
<!-- doc-test-command: echo -->

View File

@ -41,7 +41,7 @@ for epoch in range(num_epochs):
#### Save with the trainer #### Save with the trainer
```python ```python
from colossalai.trainer import Trainer, hooks from colossalai.legacy.trainer import Trainer, hooks
model = ... model = ...
engine, _, _, _ = colossalai.initialize(model=model, ...) engine, _, _, _ = colossalai.initialize(model=model, ...)
trainer = Trainer(engine, ...) trainer = Trainer(engine, ...)
@ -61,3 +61,4 @@ model = ...
load_checkpoint('xxx.pt', model) load_checkpoint('xxx.pt', model)
... # train or test ... # train or test
``` ```
<!-- doc-test-command: echo -->

View File

@ -25,8 +25,8 @@
3. Implement `handle_gradient` 3. Implement `handle_gradient`
```python ```python
from colossalai.registry import GRADIENT_HANDLER from colossalai.legacy.registry import GRADIENT_HANDLER
from colossalai.engine.gradient_handler import BaseGradientHandler from colossalai.legacy.engine.gradient_handler import BaseGradientHandler
@GRADIENT_HANDLER.register_module @GRADIENT_HANDLER.register_module
@ -57,3 +57,4 @@ gradient_handler = [dict(type='MyGradientHandler')]
```shell ```shell
python -m torch.distributed.launch --nproc_per_node 4 --master_addr localhost --master_port 29500 train_with_engine.py python -m torch.distributed.launch --nproc_per_node 4 --master_addr localhost --master_port 29500 train_with_engine.py
``` ```
<!-- doc-test-command: echo -->

View File

@ -245,7 +245,7 @@ from pathlib import Path
from colossalai.core import global_context as gpc from colossalai.core import global_context as gpc
from colossalai.logging import get_dist_logger from colossalai.logging import get_dist_logger
from colossalai.utils import get_dataloader from colossalai.utils import get_dataloader
from colossalai.trainer import Trainer, hooks from colossalai.legacy.trainer import Trainer, hooks
from colossalai.nn.lr_scheduler import LinearWarmupLR from colossalai.nn.lr_scheduler import LinearWarmupLR
from timm.models import vit_base_patch16_224 from timm.models import vit_base_patch16_224
from torchvision import datasets, transforms from torchvision import datasets, transforms

View File

@ -78,7 +78,7 @@ import colossalai.nn as col_nn
from colossalai.core import global_context as gpc from colossalai.core import global_context as gpc
from colossalai.logging import disable_existing_loggers, get_dist_logger from colossalai.logging import disable_existing_loggers, get_dist_logger
from colossalai.trainer import Trainer, hooks from colossalai.legacy.trainer import Trainer, hooks
from colossalai.utils import MultiTimer, get_dataloader from colossalai.utils import MultiTimer, get_dataloader
from colossalai.context import ParallelMode from colossalai.context import ParallelMode
from colossalai.pipeline.pipelinable import PipelinableContext from colossalai.pipeline.pipelinable import PipelinableContext
@ -156,3 +156,4 @@ trainer.fit(train_dataloader=train_dataloader,
``` ```
We use `2` pipeline stages and the batch will be split into `4` micro batches. We use `2` pipeline stages and the batch will be split into `4` micro batches.
<!-- doc-test-command: echo -->

View File

@ -6,7 +6,7 @@ import torch
from torch.utils.data import Dataset from torch.utils.data import Dataset
from transformers import GPT2Tokenizer from transformers import GPT2Tokenizer
from colossalai.registry import DATASETS from colossalai.legacy.registry import DATASETS
@DATASETS.register_module @DATASETS.register_module

View File

@ -8,11 +8,11 @@ from torch.nn.parameter import Parameter
from colossalai.context import ParallelMode, seed from colossalai.context import ParallelMode, seed
from colossalai.core import global_context as gpc from colossalai.core import global_context as gpc
from colossalai.legacy.registry import LAYERS, LOSSES, MODELS
from colossalai.nn.layer.base_layer import ParallelLayer from colossalai.nn.layer.base_layer import ParallelLayer
from colossalai.nn.layer.parallel_1d._utils import gather_forward_split_backward, reduce_grad, reduce_input from colossalai.nn.layer.parallel_1d._utils import gather_forward_split_backward, reduce_grad, reduce_input
from colossalai.nn.layer.parallel_1d.layers import Linear1D_Row from colossalai.nn.layer.parallel_1d.layers import Linear1D_Row
from colossalai.nn.layer.utils import divide from colossalai.nn.layer.utils import divide
from colossalai.registry import LAYERS, LOSSES, MODELS
from colossalai.utils import get_current_device from colossalai.utils import get_current_device

View File

@ -10,9 +10,9 @@ import colossalai
import colossalai.utils as utils import colossalai.utils as utils
from colossalai.context.parallel_mode import ParallelMode from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc from colossalai.core import global_context as gpc
from colossalai.legacy.trainer import Trainer, hooks
from colossalai.logging import disable_existing_loggers, get_dist_logger from colossalai.logging import disable_existing_loggers, get_dist_logger
from colossalai.nn import LinearWarmupLR from colossalai.nn import LinearWarmupLR
from colossalai.trainer import Trainer, hooks
from colossalai.utils import colo_set_process_memory_fraction, is_using_pp from colossalai.utils import colo_set_process_memory_fraction, is_using_pp
from colossalai.utils.timer import MultiTimer from colossalai.utils.timer import MultiTimer
from colossalai.zero.legacy.init_ctx import ZeroInitContext from colossalai.zero.legacy.init_ctx import ZeroInitContext

View File

@ -3,17 +3,16 @@
# This source code is licensed under the MIT license found in the # This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree. # LICENSE file in the root directory of this source tree.
# copied from fairseq/fairseq/data/indexed_dataset.py # copied from fairseq/fairseq/data/indexed_dataset.py
# Removed IndexedRawTextDataset since it relied on Fairseq dictionary # Removed IndexedRawTextDataset since it relied on Fairseq dictionary
# other slight modifications to remove fairseq dependencies # other slight modifications to remove fairseq dependencies
# Added document index to index file and made it accessible. # Added document index to index file and made it accessible.
# An empty sentence no longer separates documents. # An empty sentence no longer separates documents.
from functools import lru_cache
import os import os
import shutil import shutil
import struct import struct
from functools import lru_cache
from itertools import accumulate from itertools import accumulate
import numpy as np import numpy as np
@ -88,16 +87,7 @@ def write_longs(f, a):
f.write(np.array(a, dtype=np.int64)) f.write(np.array(a, dtype=np.int64))
dtypes = { dtypes = {1: np.uint8, 2: np.int8, 3: np.int16, 4: np.int32, 5: np.int64, 6: float, 7: np.double, 8: np.uint16}
1: np.uint8,
2: np.int8,
3: np.int16,
4: np.int32,
5: np.int64,
6: np.float,
7: np.double,
8: np.uint16
}
def code(dtype): def code(dtype):
@ -136,10 +126,8 @@ class IndexedDataset(torch.utils.data.Dataset):
def read_index(self, path): def read_index(self, path):
with open(index_file_path(path), 'rb') as f: with open(index_file_path(path), 'rb') as f:
magic = f.read(8) magic = f.read(8)
assert magic == self._HDR_MAGIC, ( assert magic == self._HDR_MAGIC, ('Index file doesn\'t match expected format. '
'Index file doesn\'t match expected format. ' 'Make sure that --dataset-impl is configured properly.')
'Make sure that --dataset-impl is configured properly.'
)
version = f.read(8) version = f.read(8)
assert struct.unpack('<Q', version) == (1,) assert struct.unpack('<Q', version) == (1,)
code, self.element_size = struct.unpack('<QQ', f.read(16)) code, self.element_size = struct.unpack('<QQ', f.read(16))
@ -198,13 +186,11 @@ class IndexedDataset(torch.utils.data.Dataset):
@staticmethod @staticmethod
def exists(path): def exists(path):
return ( return (os.path.exists(index_file_path(path)) and os.path.exists(data_file_path(path)))
os.path.exists(index_file_path(path)) and os.path.exists(data_file_path(path))
)
@property @property
def supports_prefetch(self): def supports_prefetch(self):
return False # avoid prefetching to save memory return False # avoid prefetching to save memory
class IndexedCachedDataset(IndexedDataset): class IndexedCachedDataset(IndexedDataset):
@ -233,7 +219,7 @@ class IndexedCachedDataset(IndexedDataset):
for i in indices: for i in indices:
self.cache_index[i] = ptx self.cache_index[i] = ptx
size = self.data_offsets[i + 1] - self.data_offsets[i] size = self.data_offsets[i + 1] - self.data_offsets[i]
a = self.cache[ptx: ptx + size] a = self.cache[ptx:ptx + size]
self.data_file.seek(self.data_offsets[i] * self.element_size) self.data_file.seek(self.data_offsets[i] * self.element_size)
self.data_file.readinto(a) self.data_file.readinto(a)
ptx += size ptx += size
@ -250,7 +236,7 @@ class IndexedCachedDataset(IndexedDataset):
tensor_size = self.sizes[self.dim_offsets[i]:self.dim_offsets[i + 1]] tensor_size = self.sizes[self.dim_offsets[i]:self.dim_offsets[i + 1]]
a = np.empty(tensor_size, dtype=self.dtype) a = np.empty(tensor_size, dtype=self.dtype)
ptx = self.cache_index[i] ptx = self.cache_index[i]
np.copyto(a, self.cache[ptx: ptx + a.size]) np.copyto(a, self.cache[ptx:ptx + a.size])
return a return a
elif isinstance(idx, slice): elif isinstance(idx, slice):
# Hack just to make this work, can optimize later if necessary # Hack just to make this work, can optimize later if necessary
@ -261,15 +247,7 @@ class IndexedCachedDataset(IndexedDataset):
class IndexedDatasetBuilder(object): class IndexedDatasetBuilder(object):
element_sizes = { element_sizes = {np.uint8: 1, np.int8: 1, np.int16: 2, np.int32: 4, np.int64: 8, float: 4, np.double: 8}
np.uint8: 1,
np.int8: 1,
np.int16: 2,
np.int32: 4,
np.int64: 8,
np.float: 4,
np.double: 8
}
def __init__(self, out_file, dtype=np.int32): def __init__(self, out_file, dtype=np.int32):
self.out_file = open(out_file, 'wb') self.out_file = open(out_file, 'wb')
@ -332,12 +310,15 @@ def _warmup_mmap_file(path):
class MMapIndexedDataset(torch.utils.data.Dataset): class MMapIndexedDataset(torch.utils.data.Dataset):
class Index(object): class Index(object):
_HDR_MAGIC = b'MMIDIDX\x00\x00' _HDR_MAGIC = b'MMIDIDX\x00\x00'
@classmethod @classmethod
def writer(cls, path, dtype): def writer(cls, path, dtype):
class _Writer(object): class _Writer(object):
def __enter__(self): def __enter__(self):
self._file = open(path, 'wb') self._file = open(path, 'wb')
@@ -384,10 +365,8 @@ class MMapIndexedDataset(torch.utils.data.Dataset):
         def __init__(self, path, skip_warmup=False):
             with open(path, 'rb') as stream:
                 magic_test = stream.read(9)
-                assert self._HDR_MAGIC == magic_test, (
-                    'Index file doesn\'t match expected format. '
-                    'Make sure that --dataset-impl is configured properly.'
-                )
+                assert self._HDR_MAGIC == magic_test, ('Index file doesn\'t match expected format. '
+                                                       'Make sure that --dataset-impl is configured properly.')
                 version = struct.unpack('<Q', stream.read(8))
                 assert (1,) == version
@@ -406,16 +385,16 @@ class MMapIndexedDataset(torch.utils.data.Dataset):
             self._bin_buffer_mmap = np.memmap(path, mode='r', order='C')
             self._bin_buffer = memoryview(self._bin_buffer_mmap)
             print(" reading sizes...")
-            self._sizes = np.frombuffer(
-                self._bin_buffer,
-                dtype=np.int32,
-                count=self._len,
-                offset=offset)
+            self._sizes = np.frombuffer(self._bin_buffer, dtype=np.int32, count=self._len, offset=offset)
             print(" reading pointers...")
-            self._pointers = np.frombuffer(self._bin_buffer, dtype=np.int64, count=self._len,
-                                           offset=offset + self._sizes.nbytes)
+            self._pointers = np.frombuffer(self._bin_buffer,
+                                           dtype=np.int64,
+                                           count=self._len,
+                                           offset=offset + self._sizes.nbytes)
             print(" reading document index...")
-            self._doc_idx = np.frombuffer(self._bin_buffer, dtype=np.int64, count=self._doc_count,
-                                          offset=offset + self._sizes.nbytes + self._pointers.nbytes)
+            self._doc_idx = np.frombuffer(self._bin_buffer,
+                                          dtype=np.int64,
+                                          count=self._doc_count,
+                                          offset=offset + self._sizes.nbytes + self._pointers.nbytes)
 
         def __del__(self):
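Putting the pieces of Index.__init__ together, the following self-contained sketch shows how such an MMIDIDX index file can be parsed end to end. The header field order is assumed from the Megatron-style format (9-byte magic, uint64 version, 1-byte dtype code, uint64 sample count, uint64 document count); the dtype-code and count reads are not shown in the hunks above, and the function name is illustrative:

import struct
import numpy as np

def read_mmap_index(path):
    # parse the fixed-size header first, then memory-map the three index arrays
    with open(path, 'rb') as stream:
        assert stream.read(9) == b'MMIDIDX\x00\x00'        # magic
        version, = struct.unpack('<Q', stream.read(8))     # format version, expected to be 1
        dtype_code, = struct.unpack('<B', stream.read(1))  # code identifying the sample dtype
        length, = struct.unpack('<Q', stream.read(8))      # number of samples
        doc_count, = struct.unpack('<Q', stream.read(8))   # number of documents
        offset = stream.tell()
    buf = memoryview(np.memmap(path, mode='r', order='C'))
    sizes = np.frombuffer(buf, dtype=np.int32, count=length, offset=offset)
    pointers = np.frombuffer(buf, dtype=np.int64, count=length, offset=offset + sizes.nbytes)
    doc_idx = np.frombuffer(buf, dtype=np.int64, count=doc_count,
                            offset=offset + sizes.nbytes + pointers.nbytes)
    return sizes, pointers, doc_idx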
@@ -480,8 +459,7 @@ class MMapIndexedDataset(torch.utils.data.Dataset):
     def __getitem__(self, idx):
         if isinstance(idx, int):
             ptr, size = self._index[idx]
-            np_array = np.frombuffer(self._bin_buffer, dtype=self._index.dtype,
-                                     count=size, offset=ptr)
+            np_array = np.frombuffer(self._bin_buffer, dtype=self._index.dtype, count=size, offset=ptr)
             return np_array
         elif isinstance(idx, slice):
             start, stop, step = idx.indices(len(self))
@@ -491,8 +469,7 @@ class MMapIndexedDataset(torch.utils.data.Dataset):
             sizes = self._index._sizes[idx]
             offsets = list(accumulate(sizes))
             total_size = sum(sizes)
-            np_array = np.frombuffer(self._bin_buffer, dtype=self._index.dtype,
-                                     count=total_size, offset=ptr)
+            np_array = np.frombuffer(self._bin_buffer, dtype=self._index.dtype, count=total_size, offset=ptr)
             sents = np.split(np_array, offsets[:-1])
             return sents
@@ -506,8 +483,7 @@ class MMapIndexedDataset(torch.utils.data.Dataset):
         if length is None:
             length = size - offset
         ptr += offset * np.dtype(self._index.dtype).itemsize
-        np_array = np.frombuffer(self._bin_buffer, dtype=self._index.dtype,
-                                 count=length, offset=ptr)
+        np_array = np.frombuffer(self._bin_buffer, dtype=self._index.dtype, count=length, offset=ptr)
         return np_array
 
     @property
@@ -530,12 +506,11 @@ class MMapIndexedDataset(torch.utils.data.Dataset):
     @staticmethod
     def exists(path):
-        return (
-            os.path.exists(index_file_path(path)) and os.path.exists(data_file_path(path))
-        )
+        return (os.path.exists(index_file_path(path)) and os.path.exists(data_file_path(path)))
 
 
 class MMapIndexedDatasetBuilder(object):
 
     def __init__(self, out_file, dtype=np.int64):
         self._data_file = open(out_file, 'wb')
         self._dtype = dtype
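Finally, a hypothetical usage snippet for the reader dataset (the file prefix is invented for illustration; actual paths are resolved through index_file_path()/data_file_path(), and MMapIndexedDataset is assumed importable from the module shown in this diff):

dataset = MMapIndexedDataset('my_corpus', skip_warmup=True)
print(len(dataset))        # number of samples recorded in the index
print(dataset[0][:10])     # sample 0 is returned as a numpy array of token ids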

View File

@@ -1,2 +1,3 @@
 colossalai
 torch
+six

Some files were not shown because too many files have changed in this diff.