mirror of https://github.com/hpcaitech/ColossalAI
[format] Run lint on colossalai.engine (#3367)
parent b92313903f
commit 46c009dba4
@@ -1,10 +1,17 @@
+from typing import Iterable, List
+
 import torch.nn as nn
-from typing import List
-from colossalai.engine import BaseGradientHandler
-from typing import Iterable
 from torch.optim import Optimizer
 from torch.optim.lr_scheduler import _LRScheduler
-from ._gradient_accumulation import GradAccumDataloader, GradAccumOptimizer, GradAccumLrSchedulerByStep, GradAccumGradientHandler

+from colossalai.engine import BaseGradientHandler
+
+from ._gradient_accumulation import (
+    GradAccumDataloader,
+    GradAccumGradientHandler,
+    GradAccumLrSchedulerByStep,
+    GradAccumOptimizer,
+)
+
 __all__ = [
     'accumulate_gradient', 'GradAccumDataloader', 'GradAccumOptimizer', 'GradAccumLrSchedulerByStep',

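The re-exported names in this hunk (GradAccumDataloader, GradAccumOptimizer, GradAccumLrSchedulerByStep, GradAccumGradientHandler) are the usual gradient-accumulation wrappers. As a rough, hypothetical sketch of the pattern those names suggest, and not ColossalAI's actual classes or signatures, an accumulation-aware optimizer wrapper only steps the inner optimizer once every few backward passes:

from torch.optim import Optimizer


class SketchGradAccumOptimizer:
    """Hypothetical illustration: run the inner optimizer once per `accumulate_size` calls."""

    def __init__(self, optim: Optimizer, accumulate_size: int):
        self.optim = optim
        self.accumulate_size = accumulate_size
        self._counter = 0

    def step(self):
        # Gradients keep accumulating in .grad until the boundary step is reached.
        self._counter += 1
        if self._counter % self.accumulate_size == 0:
            self.optim.step()

    def zero_grad(self):
        # Only clear gradients after a real optimizer step has happened.
        if self._counter % self.accumulate_size == 0:
            self.optim.zero_grad()
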
@@ -5,7 +5,7 @@ from abc import ABC, abstractmethod


 class BaseGradientHandler(ABC):
-    """A basic helper class to handle all-reduce operations of gradients across different parallel groups
+    """A basic helper class to handle all-reduce operations of gradients across different parallel groups
     before optimization.

     Args:

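The docstring above describes the contract: a handler is built around a model and an optimizer and performs the gradient all-reduce in a single entry point before the optimizer updates. A minimal sketch of such an abstract base, assuming the constructor arguments and the handle_gradient hook named in the docstrings (the exact signature is an assumption, not the library API):

from abc import ABC, abstractmethod


class SketchGradientHandler(ABC):
    """Minimal sketch of a gradient handler: store model/optimizer, reduce in one hook."""

    def __init__(self, model, optimizer):
        self._model = model          # assumed: module whose gradients get reduced
        self._optimizer = optimizer  # assumed: optimizer that will consume them

    @abstractmethod
    def handle_gradient(self):
        """Run the all-reduce over the relevant parallel group before optimizer.step()."""
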
@@ -1,16 +1,17 @@
 from colossalai.core import global_context as gpc
 from colossalai.registry import GRADIENT_HANDLER
-from ._base_gradient_handler import BaseGradientHandler
+
 from ...context.parallel_mode import ParallelMode
+from ._base_gradient_handler import BaseGradientHandler
 from .utils import bucket_allreduce


 @GRADIENT_HANDLER.register_module
 class DataParallelGradientHandler(BaseGradientHandler):
     """A helper class to handle all-reduce operations in a data parallel group.
-    A all-reduce collective communication will be operated in
+    A all-reduce collective communication will be operated in
     :func:`handle_gradient` among a data parallel group.
-    For better performance, it bucketizes the gradients of all parameters that are
+    For better performance, it bucketizes the gradients of all parameters that are
     the same type to improve the efficiency of communication.

     Args:

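The docstring spells out the optimization: gradients are grouped ("bucketized") by tensor type and reduced with one collective per bucket instead of one per parameter. The real helper is bucket_allreduce in .utils; the function name and signature below are assumptions, a minimal sketch of the idea using the same torch helpers imported elsewhere in this diff (it assumes an initialized process group):

from collections import defaultdict

import torch.distributed as dist
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors


def sketch_bucket_allreduce(param_list, group=None):
    # Group gradients by dtype so each bucket only holds tensors of the same type.
    buckets = defaultdict(list)
    for param in param_list:
        if param.requires_grad and param.grad is not None:
            buckets[param.grad.dtype].append(param.grad.data)

    for dtype, grads in buckets.items():
        # Flatten the bucket into one contiguous tensor, average it with a single
        # collective call, then copy the reduced values back into each gradient.
        coalesced = _flatten_dense_tensors(grads)
        coalesced /= dist.get_world_size(group=group)
        dist.all_reduce(coalesced, group=group)
        for grad, reduced in zip(grads, _unflatten_dense_tensors(coalesced, grads)):
            grad.copy_(reduced)
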
@@ -4,9 +4,10 @@ from collections import defaultdict

 import torch
 import torch.distributed as dist
+from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
+
 from colossalai.core import global_context as gpc
 from colossalai.registry import GRADIENT_HANDLER
-from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors

 from ._base_gradient_handler import BaseGradientHandler

@@ -14,9 +15,9 @@ from ._base_gradient_handler import BaseGradientHandler
 @GRADIENT_HANDLER.register_module
 class PipelineSharedModuleGradientHandler(BaseGradientHandler):
     """A helper class to handle all-reduce operations in sub parallel groups.
-    A all-reduce collective communication will be operated in
+    A all-reduce collective communication will be operated in
     :func:`handle_gradient` among all sub pipeline parallel groups.
-    For better performance, it bucketizes the gradients of all parameters that are
+    For better performance, it bucketizes the gradients of all parameters that are
     the same type to improve the efficiency of communication.

     Args:

@@ -1,16 +1,17 @@
 from colossalai.core import global_context as gpc
 from colossalai.registry import GRADIENT_HANDLER
-from ._base_gradient_handler import BaseGradientHandler
+
 from ...context.parallel_mode import ParallelMode
+from ._base_gradient_handler import BaseGradientHandler
 from .utils import bucket_allreduce


 @GRADIENT_HANDLER.register_module
 class SequenceParallelGradientHandler(BaseGradientHandler):
     """A helper class to handle all-reduce operations in a data parallel group.
-    A all-reduce collective communication will be operated in
+    A all-reduce collective communication will be operated in
     :func:`handle_gradient` among a data parallel group.
-    For better performance, it bucketizes the gradients of all parameters that are
+    For better performance, it bucketizes the gradients of all parameters that are
     the same type to improve the efficiency of communication.

     Args:

@@ -1,4 +1,5 @@
 from colossalai.registry import GRADIENT_HANDLER
+
 from ._base_gradient_handler import BaseGradientHandler


@@ -1,5 +1,5 @@
 from ._base_schedule import BaseSchedule
-from ._pipeline_schedule import PipelineSchedule, InterleavedPipelineSchedule, get_tensor_shape
 from ._non_pipeline_schedule import NonPipelineSchedule
+from ._pipeline_schedule import InterleavedPipelineSchedule, PipelineSchedule, get_tensor_shape

 __all__ = ['BaseSchedule', 'NonPipelineSchedule', 'PipelineSchedule', 'InterleavedPipelineSchedule', 'get_tensor_shape']

@@ -2,10 +2,10 @@
 # -*- encoding: utf-8 -*-

 from abc import ABC, abstractmethod
+from typing import Callable, Iterable

 import torch

-from typing import Iterable, Callable
 from colossalai.logging import get_dist_logger
 from colossalai.utils import get_current_device

@@ -1,13 +1,14 @@
 #!/usr/bin/env python
 # -*- encoding: utf-8 -*-

-from typing import Iterable
+import inspect
+from typing import Callable, Iterable

 import torch
-import inspect
-from ._base_schedule import BaseSchedule
+
 from colossalai.utils import conditional_context
-from typing import Callable
+
+from ._base_schedule import BaseSchedule


 class NonPipelineSchedule(BaseSchedule):

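Among the regrouped imports above, colossalai.utils.conditional_context suggests a helper that enters a context manager only when a condition holds, e.g. wrapping a forward pass in torch.no_grad() only on evaluation-only paths. The name sketch_conditional_context and its signature below are assumptions, a generic sketch of that pattern rather than the library's API:

from contextlib import contextmanager

import torch


@contextmanager
def sketch_conditional_context(context_manager, enable: bool = True):
    # Enter the wrapped context only when `enable` is True; otherwise run the body as-is.
    if enable:
        with context_manager:
            yield
    else:
        yield


# Example: skip gradient tracking only when forward_only is set.
forward_only = True
x = torch.ones(2, requires_grad=True)
with sketch_conditional_context(torch.no_grad(), enable=forward_only):
    y = x * 3
assert y.requires_grad == (not forward_only)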