mirror of https://github.com/hpcaitech/ColossalAI
[legacy] move engine to legacy (#4560)
* [legacy] move engine to legacy
* [example] fix seq parallel example
* [example] fix seq parallel example
* [test] test gemini pluging hang
* [test] test gemini pluging hang
* [test] test gemini pluging hang
* [test] test gemini pluging hang
* [test] test gemini pluging hang
* [example] update seq parallel requirements
parent 89fe027787
commit 8accecd55b
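This commit moves the `engine` package to `colossalai.legacy.engine`. As an illustrative sketch only (not part of the diff below), downstream code would migrate its imports roughly as follows; the try/except fallback is an assumption for code that has to run against both the old and the new layout:

```python
# Illustrative sketch only: adapting user code to the colossalai.engine ->
# colossalai.legacy.engine move shown in the diff below. The fallback branch is
# an assumption for environments that still ship the old package layout.
try:
    from colossalai.legacy.engine import Engine
    from colossalai.legacy.engine.gradient_handler import BaseGradientHandler
    from colossalai.legacy.engine.schedule import PipelineSchedule
except ImportError:
    from colossalai.engine import Engine
    from colossalai.engine.gradient_handler import BaseGradientHandler
    from colossalai.engine.schedule import PipelineSchedule
```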
@@ -71,7 +71,7 @@ def build_gradient_handler(config, model, optimizer):
         optimizer (:class:`torch.optim.Optimizer`): An optimizer object containing parameters for the gradient handler

     Returns:
-        An object of :class:`colossalai.engine.BaseGradientHandler`
+        An object of :class:`colossalai.legacy.engine.BaseGradientHandler`
     """
     config_ = config.copy()
     config_['model'] = model

@@ -21,9 +21,9 @@ from colossalai.builder.builder import build_gradient_handler
 from colossalai.context import Config, ConfigException, ParallelMode
 from colossalai.context.moe_context import MOE_CONTEXT
 from colossalai.core import global_context as gpc
-from colossalai.engine import Engine
-from colossalai.engine.gradient_accumulation import accumulate_gradient
-from colossalai.engine.schedule import (
+from colossalai.legacy.engine import Engine
+from colossalai.legacy.engine.gradient_accumulation import accumulate_gradient
+from colossalai.legacy.engine.schedule import (
     InterleavedPipelineSchedule,
     NonPipelineSchedule,
     PipelineSchedule,

@@ -8,11 +8,17 @@ from torch import Tensor
 from torch.nn import Module
 from torch.nn.modules.loss import _Loss

-from colossalai.engine.gradient_handler import BaseGradientHandler
-from colossalai.engine.schedule import BaseSchedule, InterleavedPipelineSchedule, NonPipelineSchedule, PipelineSchedule
+from colossalai.legacy.engine.gradient_handler import BaseGradientHandler
+from colossalai.legacy.engine.schedule import (
+    BaseSchedule,
+    InterleavedPipelineSchedule,
+    NonPipelineSchedule,
+    PipelineSchedule,
+)
 from colossalai.logging import get_dist_logger
-from colossalai.zero.legacy.gemini import BaseOpHook, register_ophooks_recursively
 from colossalai.nn.optimizer import ColossalaiOptimizer
+from colossalai.zero.legacy.gemini import BaseOpHook, register_ophooks_recursively


 class Engine:
     """Basic engine class for training and evaluation. It runs a specific process method
@@ -4,7 +4,7 @@ import torch.nn as nn
 from torch.optim import Optimizer
 from torch.optim.lr_scheduler import _LRScheduler

-from colossalai.engine import BaseGradientHandler
+from colossalai.legacy.engine import BaseGradientHandler

 from ._gradient_accumulation import (
     GradAccumDataloader,

@@ -33,7 +33,7 @@ def accumulate_gradient(model: nn.Module,
         dataloader (:class:`torch.utils.data.DataLoader` or iterable objects):
             your dataloader object, would be called like iter(dataloader)
         accumulate_size (int): the number of steps to accumulate gradients
-        gradient_handlers (List[:class:`colossalai.engine.BaseGradientHandler`]):
+        gradient_handlers (List[:class:`colossalai.legacy.engine.BaseGradientHandler`]):
             list of gradient handler objects. Default is None.
         lr_scheduler (`torch.optim.lr_scheduler` or `colossalai.nn.lr_scheduler`):
             your ``lr_scheduler`` object for gradient accumulation. Defaults to None.

@@ -10,7 +10,7 @@ from torch.optim import Optimizer
 from torch.optim.lr_scheduler import _LRScheduler
 from torch.utils.data import DataLoader

-from colossalai.engine import BaseGradientHandler
+from colossalai.legacy.engine import BaseGradientHandler
 from colossalai.nn.optimizer import ColossalaiOptimizer
 from colossalai.utils import conditional_context

@@ -262,7 +262,7 @@ class GradAccumGradientHandler:
    before accumulation size is reached.

    Args:
-       grad_handler (:class:`colossalai.engine.BaseGradientHandler`):
+       grad_handler (:class:`colossalai.legacy.engine.BaseGradientHandler`):
           Your ``gradient_handler`` object for gradient accumulation, would be called when achieving `accumulate_size`.
       accumulate_size (int): The number of steps to accumulate gradients.

@@ -1,7 +1,7 @@
+from colossalai.context.parallel_mode import ParallelMode
 from colossalai.core import global_context as gpc
 from colossalai.registry import GRADIENT_HANDLER

-from ...context.parallel_mode import ParallelMode
 from ._base_gradient_handler import BaseGradientHandler
 from .utils import bucket_allreduce

@@ -1,9 +1,9 @@
 from colossalai.context.moe_context import MOE_CONTEXT
+from colossalai.context.parallel_mode import ParallelMode
 from colossalai.core import global_context as gpc
 from colossalai.registry import GRADIENT_HANDLER
 from colossalai.utils.moe import get_moe_epsize_param_dict

-from ...context.parallel_mode import ParallelMode
 from ._base_gradient_handler import BaseGradientHandler
 from .utils import bucket_allreduce

@@ -1,7 +1,7 @@
+from colossalai.context.parallel_mode import ParallelMode
 from colossalai.core import global_context as gpc
 from colossalai.registry import GRADIENT_HANDLER

-from ...context.parallel_mode import ParallelMode
 from ._base_gradient_handler import BaseGradientHandler
 from .utils import bucket_allreduce
@@ -95,7 +95,7 @@ class BaseSchedule(ABC):
        """The process function over a batch of dataset for training or evaluation.

        Args:
-           engine (colossalai.engine.Engine): Colossalai engine for training and inference.
+           engine (colossalai.legacy.engine.Engine): Colossalai engine for training and inference.
           data_iter (Iterable): Data iterator from which get a batch of data, obtained by calling iter(dataloader).
           forward_only (bool): If True, the process won't include backward.
           return_loss (bool, optional): If False, the loss won't be returned.

@@ -54,7 +54,7 @@ class NonPipelineSchedule(BaseSchedule):
        The returned labels and loss will None if :attr:`return_loss` is False.

        Args:
-           engine (colossalai.engine.Engine): Colossalai engine for training and inference.
+           engine (colossalai.legacy.engine.Engine): Colossalai engine for training and inference.
           data_iter (Iterable): Dataloader as the form of an iterator, obtained by calling iter(dataloader).
           forward_only (bool, optional):
               If True, the model is run for the forward pass, else back propagation will be executed.

@@ -236,7 +236,7 @@ class PipelineSchedule(BaseSchedule):
        Returns output tensor. This is a helper function and can be ignored by users.

        Args:
-           engine (colossalai.engine.Engine): Colossalai engine for training and inference.
+           engine (colossalai.legacy.engine.Engine): Colossalai engine for training and inference.
           input_obj (Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]): Input tensor for this pipeline stage.
           return_tensors (List[:class:`torch.Tensor`]): A list of tensors to return.
           return_output_label (bool, optional): Whether returns output labels.

@@ -274,7 +274,7 @@ class PipelineSchedule(BaseSchedule):
        This is a helper function and can be ignored by users.

        Args:
-           engine (colossalai.engine.Engine): Colossalai engine for training and inference.
+           engine (colossalai.legacy.engine.Engine): Colossalai engine for training and inference.
           input_obj (Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]): input tensor for this pipeline stage.
           output_obj (Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]): output tensor for this pipeline stage.
           output_obj_grad (Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]): gradient of output tensor for this pipeline stage.

@@ -314,7 +314,7 @@ class PipelineSchedule(BaseSchedule):
        Returns a tuple with losses if the last stage, an empty tuple otherwise.

        Args:
-           engine (colossalai.engine.Engine): Colossalai engine for training and inference.
+           engine (colossalai.legacy.engine.Engine): Colossalai engine for training and inference.
           data_iter (Iterable): Dataloader as the form of an iterator, obtained by calling iter(dataloader).
           forward_only (bool, optional):
               Whether run forward step only. Default is false. If true, no backward will be run.

@@ -518,7 +518,7 @@ class InterleavedPipelineSchedule(PipelineSchedule):
        Returns output tensor. This is a helper function and can be ignored by users.

        Args:
-           engine (colossalai.engine.Engine): Colossalai engine for training and inference.
+           engine (colossalai.legacy.engine.Engine): Colossalai engine for training and inference.
           model_chunk_id (int): The id of model chunks.
           input_obj (Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]): Input tensor for this pipeline stage.
           return_tensors (List[:class:`torch.Tensor`]): A list of tensors to return.

@@ -555,7 +555,7 @@ class InterleavedPipelineSchedule(PipelineSchedule):
        communication between pipeline stages as needed.

        Args:
-           engine (colossalai.engine.Engine): Colossalai engine for training and inference.
+           engine (colossalai.legacy.engine.Engine): Colossalai engine for training and inference.
           data_iter (Iterable): Dataloader as the form of an iterator, obtained by calling iter(dataloader).
           forward_only (bool, optional):
               Whether run forward step only. Default is false. If true, no backward will be run.

@@ -69,7 +69,7 @@ class PipelineScheduleV2(PipelineSchedule):
        Returns a tuple with losses if the last stage, an empty tuple otherwise.

        Args:
-           engine (colossalai.engine.Engine): Colossalai engine for training and inference.
+           engine (colossalai.legacy.engine.Engine): Colossalai engine for training and inference.
           data_iter (Iterable): Dataloader as the form of an iterator, obtained by calling iter(dataloader).
           forward_only (bool, optional):
               Whether run forward step only. Default is false. If true, no backward will be run.
@@ -4,7 +4,7 @@ import torch
 from torch.utils.data import DataLoader
 from tqdm import tqdm

-from colossalai.engine import Engine
+from colossalai.legacy.engine import Engine
 from colossalai.legacy.trainer.hooks import BaseHook
 from colossalai.logging import DistributedLogger
 from colossalai.utils import MultiTimer, is_dp_rank_0, is_no_pp_or_last_stage, is_tp_rank_0

@@ -1,17 +1,17 @@
-import os
-from typing import List
-from colossalai.engine import Engine
-from torch.profiler import profile as torch_profile
-from torch.profiler.profiler import ProfilerAction
-from typing import Any, Callable, Iterable, Optional
-from torch.autograd import ProfilerActivity
+import gzip
 import json
 import os
 import tempfile
-import gzip
+from typing import Any, Callable, Iterable, List, Optional
+
+from torch.autograd import ProfilerActivity
+from torch.profiler import profile as torch_profile
+from torch.profiler.profiler import ProfilerAction
+
+from colossalai.legacy.engine import Engine
+from colossalai.logging import get_dist_logger
 from colossalai.utils.profiler.extention import ProfilerExtension
 from colossalai.utils.profiler.stateful_tensor_mem_extention import StatefulTensorMemoryProfilerExtention
-from colossalai.logging import get_dist_logger


 class profile(torch_profile):

@@ -1,12 +1,14 @@
 import os
 import threading
 import time
-import torch
 from enum import Enum
 from typing import List
-from colossalai.gemini.stateful_tensor import StatefulTensor
+
+import torch

 from colossalai.gemini.ophooks import BaseOpHook
-from colossalai.engine import Engine
+from colossalai.gemini.stateful_tensor import StatefulTensor
+from colossalai.legacy.engine import Engine
 from colossalai.utils.profiler.extention import ProfilerExtension
@@ -92,14 +92,14 @@ follow the steps below to create a new distributed initialization.

 Gradient handlers are objects which execute the all-reduce operations on parameters' gradients. As different all-reduce
 strategies may be executed for different kinds of parallelism, users can
-inherit `colossalai.engine.gradient_handler.BaseGradientHandler` to implement their strategies. Currently, the library
+inherit `colossalai.legacy.engine.gradient_handler.BaseGradientHandler` to implement their strategies. Currently, the library
 uses the normal data parallel gradient handler which all-reduces the gradients across data parallel ranks. The data
 parallel gradient handler is added to the engine automatically if data parallel is detected. You can add your own
 gradient handler like below:

 ```python
 from colossalai.registry import GRADIENT_HANDLER
-from colossalai.engine import BaseGradientHandler
+from colossalai.legacy.engine import BaseGradientHandler

 @GRADIENT_HANDLER.register_module
 class YourGradientHandler(BaseGradientHandler):

@@ -121,4 +121,5 @@ gradient_handlers = [

 Schedule entails how to execute a forward and backward pass. Currently, Colossal-AI provides pipeline and non-pipeline
 schedules. If you want to modify how the forward and backward passes are executed, you can
-inherit `colossalai.engine.schedule.BaseSchedule` and implement the `forward_back_step` function.
+inherit `colossalai.legacy.engine.schedule.BaseSchedule` and implement the `forward_back_step` function.
+<!-- doc-test-command: echo -->
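The documentation hunk above tells users to inherit `colossalai.legacy.engine.schedule.BaseSchedule` and implement the step function (called `forward_back_step` in the doc text). A minimal sketch of such a subclass follows; the `forward_backward_step` name, its signature, and the helper calls (`load_batch`, `engine.criterion`, `engine.backward`) are assumptions drawn from the schedule and engine docstrings elsewhere in this diff and should be verified against the source, not read as the library's reference implementation.

```python
# Minimal sketch, assuming the forward_backward_step arguments documented in the
# schedule docstrings in this diff; load_batch, engine.criterion and engine.backward
# mirror the non-pipeline schedule and are assumptions to check against the code base.
from colossalai.legacy.engine.schedule import BaseSchedule


class MyNaiveSchedule(BaseSchedule):

    def forward_backward_step(self, engine, data_iter, forward_only=False, return_loss=True, return_output_label=True):
        # load one batch and run the forward pass through the engine
        data, label = self.load_batch(data_iter)
        output = engine(data)
        loss = engine.criterion(output, label) if return_loss else None
        # run the backward pass unless we are only doing inference
        if not forward_only:
            engine.backward(loss)
        return output, label if return_output_label else None, loss
```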
@@ -39,7 +39,7 @@ from colossalai.amp import AMP_TYPE
 from colossalai.builder.pipeline import partition_uniform
 from colossalai.context.parallel_mode import ParallelMode
 from colossalai.core import global_context as gpc
-from colossalai.engine.schedule import (InterleavedPipelineSchedule,
+from colossalai.legacy.engine.schedule import (InterleavedPipelineSchedule,
                                         PipelineSchedule)
 from colossalai.logging import disable_existing_loggers, get_dist_logger
 from colossalai.nn.layer.wrapper import PipelineSharedModuleWrapper

@@ -35,7 +35,7 @@ import colossalai.nn as col_nn
 import torch
 import torch.nn as nn
 from colossalai.builder import build_pipeline_model
-from colossalai.engine.schedule import (InterleavedPipelineSchedule,
+from colossalai.legacy.engine.schedule import (InterleavedPipelineSchedule,
                                         PipelineSchedule)
 from colossalai.logging import disable_existing_loggers, get_dist_logger
 from colossalai.legacy.trainer import Trainer, hooks

@@ -415,7 +415,7 @@ def build_pipeline_vit(num_layers, num_chunks, device=torch.device('cuda'), **kw

 #### Import modules
 ```python
-from colossalai.engine.schedule import (InterleavedPipelineSchedule,
+from colossalai.legacy.engine.schedule import (InterleavedPipelineSchedule,
                                         PipelineSchedule)
 from colossalai.utils import MultiTimer
 import os

@@ -29,7 +29,7 @@ To implement a customized gradient handler, you need to follow these steps.

 ```python
 from colossalai.registry import GRADIENT_HANDLER
-from colossalai.engine.gradient_handler import BaseGradientHandler
+from colossalai.legacy.engine.gradient_handler import BaseGradientHandler


 @GRADIENT_HANDLER.register_module

@@ -61,3 +61,4 @@ to demonstrate the use of gradient handler. In this example, we used `DataParall
 ```shell
 python -m torch.distributed.launch --nproc_per_node 4 --master_addr localhost --master_port 29500 train_with_engine.py
 ```
+<!-- doc-test-command: echo -->
@@ -81,14 +81,14 @@ Colossal-AI provides a global context for users to easily manage
 ## Gradient Handler

 Gradient handlers are objects that execute all-reduce operations on parameters' gradients. As different all-reduce strategies may be executed for different kinds of parallelism, users can inherit
-`colossalai.engine.gradient_handler.BaseGradientHandler` to implement their strategies. Currently, Colossal-AI uses the normal data parallel gradient handler, which all-reduces gradients across data parallel ranks.
+`colossalai.legacy.engine.gradient_handler.BaseGradientHandler` to implement their strategies. Currently, Colossal-AI uses the normal data parallel gradient handler, which all-reduces gradients across data parallel ranks.
 The data parallel gradient handler is added to the engine automatically if data parallel is detected.

 You can add your own gradient handler as shown below:

 ```python
 from colossalai.registry import GRADIENT_HANDLER
-from colossalai.engine import BaseGradientHandler
+from colossalai.legacy.engine import BaseGradientHandler

 @GRADIENT_HANDLER.register_module
 class YourGradientHandler(BaseGradientHandler):

@@ -109,4 +109,5 @@ gradient_handlers = [
 ## Schedule

 A schedule entails how to execute the forward and backward passes. Currently, Colossal-AI provides pipeline and non-pipeline schedules.
-If you want to modify how the forward and backward passes are executed, you can inherit `colossalai.engine.schedule.BaseSchedule` and implement the `forward_back_step` function.
+If you want to modify how the forward and backward passes are executed, you can inherit `colossalai.legacy.engine.schedule.BaseSchedule` and implement the `forward_back_step` function.
+<!-- doc-test-command: echo -->

@@ -39,7 +39,7 @@ from colossalai.amp import AMP_TYPE
 from colossalai.builder.pipeline import partition_uniform
 from colossalai.context.parallel_mode import ParallelMode
 from colossalai.core import global_context as gpc
-from colossalai.engine.schedule import (InterleavedPipelineSchedule,
+from colossalai.legacy.engine.schedule import (InterleavedPipelineSchedule,
                                         PipelineSchedule)
 from colossalai.logging import disable_existing_loggers, get_dist_logger
 from colossalai.nn.layer.wrapper import PipelineSharedModuleWrapper

@@ -33,7 +33,7 @@ import colossalai.nn as col_nn
 import torch
 import torch.nn as nn
 from colossalai.builder import build_pipeline_model
-from colossalai.engine.schedule import (InterleavedPipelineSchedule,
+from colossalai.legacy.engine.schedule import (InterleavedPipelineSchedule,
                                         PipelineSchedule)
 from colossalai.logging import disable_existing_loggers, get_dist_logger
 from colossalai.legacy.trainer import Trainer, hooks

@@ -380,7 +380,7 @@ def build_pipeline_vit(num_layers, num_chunks, device=torch.device('cuda'), **kw

 #### Import modules
 ```python
-from colossalai.engine.schedule import (InterleavedPipelineSchedule,
+from colossalai.legacy.engine.schedule import (InterleavedPipelineSchedule,
                                         PipelineSchedule)
 from colossalai.utils import MultiTimer
 import os

@@ -26,7 +26,7 @@

 ```python
 from colossalai.registry import GRADIENT_HANDLER
-from colossalai.engine.gradient_handler import BaseGradientHandler
+from colossalai.legacy.engine.gradient_handler import BaseGradientHandler


 @GRADIENT_HANDLER.register_module

@@ -57,3 +57,4 @@ gradient_handler = [dict(type='MyGradientHandler')]
 ```shell
 python -m torch.distributed.launch --nproc_per_node 4 --master_addr localhost --master_port 29500 train_with_engine.py
 ```
+<!-- doc-test-command: echo -->
@@ -3,17 +3,16 @@
 # This source code is licensed under the MIT license found in the
 # LICENSE file in the root directory of this source tree.

 # copied from fairseq/fairseq/data/indexed_dataset.py
 # Removed IndexedRawTextDataset since it relied on Fairseq dictionary
 # other slight modifications to remove fairseq dependencies
 # Added document index to index file and made it accessible.
 # An empty sentence no longer separates documents.

-from functools import lru_cache
-
 import os
 import shutil
 import struct
+from functools import lru_cache
 from itertools import accumulate

 import numpy as np

@@ -88,16 +87,7 @@ def write_longs(f, a):
     f.write(np.array(a, dtype=np.int64))


-dtypes = {
-    1: np.uint8,
-    2: np.int8,
-    3: np.int16,
-    4: np.int32,
-    5: np.int64,
-    6: np.float,
-    7: np.double,
-    8: np.uint16
-}
+dtypes = {1: np.uint8, 2: np.int8, 3: np.int16, 4: np.int32, 5: np.int64, 6: float, 7: np.double, 8: np.uint16}


 def code(dtype):

@@ -136,10 +126,8 @@ class IndexedDataset(torch.utils.data.Dataset):
     def read_index(self, path):
         with open(index_file_path(path), 'rb') as f:
             magic = f.read(8)
-            assert magic == self._HDR_MAGIC, (
-                'Index file doesn\'t match expected format. '
-                'Make sure that --dataset-impl is configured properly.'
-            )
+            assert magic == self._HDR_MAGIC, ('Index file doesn\'t match expected format. '
+                                              'Make sure that --dataset-impl is configured properly.')
             version = f.read(8)
             assert struct.unpack('<Q', version) == (1,)
             code, self.element_size = struct.unpack('<QQ', f.read(16))

@@ -198,9 +186,7 @@ class IndexedDataset(torch.utils.data.Dataset):

    @staticmethod
    def exists(path):
-       return (
-           os.path.exists(index_file_path(path)) and os.path.exists(data_file_path(path))
-       )
+       return (os.path.exists(index_file_path(path)) and os.path.exists(data_file_path(path)))

    @property
    def supports_prefetch(self):

@@ -233,7 +219,7 @@ class IndexedCachedDataset(IndexedDataset):
        for i in indices:
            self.cache_index[i] = ptx
            size = self.data_offsets[i + 1] - self.data_offsets[i]
-           a = self.cache[ptx: ptx + size]
+           a = self.cache[ptx:ptx + size]
            self.data_file.seek(self.data_offsets[i] * self.element_size)
            self.data_file.readinto(a)
            ptx += size

@@ -250,7 +236,7 @@ class IndexedCachedDataset(IndexedDataset):
            tensor_size = self.sizes[self.dim_offsets[i]:self.dim_offsets[i + 1]]
            a = np.empty(tensor_size, dtype=self.dtype)
            ptx = self.cache_index[i]
-           np.copyto(a, self.cache[ptx: ptx + a.size])
+           np.copyto(a, self.cache[ptx:ptx + a.size])
            return a
        elif isinstance(idx, slice):
            # Hack just to make this work, can optimizer later if necessary

@@ -261,15 +247,7 @@ class IndexedCachedDataset(IndexedDataset):


 class IndexedDatasetBuilder(object):
-    element_sizes = {
-        np.uint8: 1,
-        np.int8: 1,
-        np.int16: 2,
-        np.int32: 4,
-        np.int64: 8,
-        np.float: 4,
-        np.double: 8
-    }
+    element_sizes = {np.uint8: 1, np.int8: 1, np.int16: 2, np.int32: 4, np.int64: 8, float: 4, np.double: 8}

     def __init__(self, out_file, dtype=np.int32):
         self.out_file = open(out_file, 'wb')

@@ -332,12 +310,15 @@ def _warmup_mmap_file(path):


 class MMapIndexedDataset(torch.utils.data.Dataset):
+
     class Index(object):
         _HDR_MAGIC = b'MMIDIDX\x00\x00'

         @classmethod
         def writer(cls, path, dtype):
+
             class _Writer(object):
+
                 def __enter__(self):
                     self._file = open(path, 'wb')


@@ -384,10 +365,8 @@ class MMapIndexedDataset(torch.utils.data.Dataset):
        def __init__(self, path, skip_warmup=False):
            with open(path, 'rb') as stream:
                magic_test = stream.read(9)
-               assert self._HDR_MAGIC == magic_test, (
-                   'Index file doesn\'t match expected format. '
-                   'Make sure that --dataset-impl is configured properly.'
-               )
+               assert self._HDR_MAGIC == magic_test, ('Index file doesn\'t match expected format. '
+                                                      'Make sure that --dataset-impl is configured properly.')
                version = struct.unpack('<Q', stream.read(8))
                assert (1,) == version

@@ -406,16 +385,16 @@ class MMapIndexedDataset(torch.utils.data.Dataset):
            self._bin_buffer_mmap = np.memmap(path, mode='r', order='C')
            self._bin_buffer = memoryview(self._bin_buffer_mmap)
            print(" reading sizes...")
-           self._sizes = np.frombuffer(
-               self._bin_buffer,
-               dtype=np.int32,
-               count=self._len,
-               offset=offset)
+           self._sizes = np.frombuffer(self._bin_buffer, dtype=np.int32, count=self._len, offset=offset)
            print(" reading pointers...")
-           self._pointers = np.frombuffer(self._bin_buffer, dtype=np.int64, count=self._len,
+           self._pointers = np.frombuffer(self._bin_buffer,
+                                          dtype=np.int64,
+                                          count=self._len,
                                           offset=offset + self._sizes.nbytes)
            print(" reading document index...")
-           self._doc_idx = np.frombuffer(self._bin_buffer, dtype=np.int64, count=self._doc_count,
+           self._doc_idx = np.frombuffer(self._bin_buffer,
+                                         dtype=np.int64,
+                                         count=self._doc_count,
                                          offset=offset + self._sizes.nbytes + self._pointers.nbytes)

        def __del__(self):

@@ -480,8 +459,7 @@ class MMapIndexedDataset(torch.utils.data.Dataset):
    def __getitem__(self, idx):
        if isinstance(idx, int):
            ptr, size = self._index[idx]
-           np_array = np.frombuffer(self._bin_buffer, dtype=self._index.dtype,
-                                    count=size, offset=ptr)
+           np_array = np.frombuffer(self._bin_buffer, dtype=self._index.dtype, count=size, offset=ptr)
            return np_array
        elif isinstance(idx, slice):
            start, stop, step = idx.indices(len(self))

@@ -491,8 +469,7 @@ class MMapIndexedDataset(torch.utils.data.Dataset):
            sizes = self._index._sizes[idx]
            offsets = list(accumulate(sizes))
            total_size = sum(sizes)
-           np_array = np.frombuffer(self._bin_buffer, dtype=self._index.dtype,
-                                    count=total_size, offset=ptr)
+           np_array = np.frombuffer(self._bin_buffer, dtype=self._index.dtype, count=total_size, offset=ptr)
            sents = np.split(np_array, offsets[:-1])
            return sents

@@ -506,8 +483,7 @@ class MMapIndexedDataset(torch.utils.data.Dataset):
        if length is None:
            length = size - offset
        ptr += offset * np.dtype(self._index.dtype).itemsize
-       np_array = np.frombuffer(self._bin_buffer, dtype=self._index.dtype,
-                                count=length, offset=ptr)
+       np_array = np.frombuffer(self._bin_buffer, dtype=self._index.dtype, count=length, offset=ptr)
        return np_array

    @property

@@ -530,12 +506,11 @@ class MMapIndexedDataset(torch.utils.data.Dataset):

    @staticmethod
    def exists(path):
-       return (
-           os.path.exists(index_file_path(path)) and os.path.exists(data_file_path(path))
-       )
+       return (os.path.exists(index_file_path(path)) and os.path.exists(data_file_path(path)))


 class MMapIndexedDatasetBuilder(object):
+
     def __init__(self, out_file, dtype=np.int64):
         self._data_file = open(out_file, 'wb')
         self._dtype = dtype
@@ -1,2 +1,3 @@
 colossalai
 torch
+six

@@ -11,8 +11,8 @@ import colossalai
 from colossalai.amp import AMP_TYPE
 from colossalai.context.parallel_mode import ParallelMode
 from colossalai.core import global_context as gpc
-from colossalai.engine.schedule import PipelineSchedule
 from colossalai.kernel import LayerNorm
+from colossalai.legacy.engine.schedule import PipelineSchedule
 from colossalai.logging import get_dist_logger
 from colossalai.nn.optimizer import FusedAdam
 from colossalai.utils import MultiTimer, is_using_pp

@@ -98,7 +98,7 @@ def check_gemini_plugin(subset: str, init_method: str = 'none', early_stop: bool
        ]:
            continue
        err = run_fn(init_method, model_fn, data_gen_fn, output_transform_fn)
+       torch.cuda.empty_cache()
        if err is None:
            passed_models.append(name)
        else:

@@ -5,7 +5,7 @@ import torch.nn as nn

 import colossalai
 from colossalai.context.moe_context import MOE_CONTEXT
-from colossalai.engine.gradient_handler import MoeGradientHandler
+from colossalai.legacy.engine.gradient_handler import MoeGradientHandler
 from colossalai.nn.layer.moe import Experts, MoeLayer, Top1Router, UniformNoiseGenerator
 from colossalai.testing import assert_equal_in_group, rerun_if_address_is_in_use, spawn
 from colossalai.utils import get_current_device

@@ -3,7 +3,7 @@ import torch

 import colossalai
 from colossalai.context import MOE_CONTEXT
-from colossalai.engine.gradient_handler import MoeGradientHandler
+from colossalai.legacy.engine.gradient_handler import MoeGradientHandler
 from colossalai.nn import MoeLoss
 from colossalai.testing import assert_equal_in_group, parameterize, rerun_if_address_is_in_use, spawn
 from colossalai.zero.legacy.init_ctx import ZeroInitContext

@@ -4,7 +4,7 @@ import torch
 import colossalai
 from colossalai.amp import convert_to_apex_amp
 from colossalai.context import MOE_CONTEXT
-from colossalai.engine.gradient_handler import MoeGradientHandler
+from colossalai.legacy.engine.gradient_handler import MoeGradientHandler
 from colossalai.nn import MoeLoss
 from colossalai.nn.optimizer import CPUAdam
 from colossalai.testing import assert_equal_in_group, parameterize, rerun_if_address_is_in_use, spawn