mirror of https://github.com/hpcaitech/ColossalAI
[format] Run lint on colossalai.engine (#3367)
parent b92313903f
commit 46c009dba4
@@ -1,10 +1,17 @@
+from typing import Iterable, List
+
 import torch.nn as nn
-from typing import List
-from colossalai.engine import BaseGradientHandler
-from typing import Iterable
 from torch.optim import Optimizer
 from torch.optim.lr_scheduler import _LRScheduler
-from ._gradient_accumulation import GradAccumDataloader, GradAccumOptimizer, GradAccumLrSchedulerByStep, GradAccumGradientHandler
+
+from colossalai.engine import BaseGradientHandler
+
+from ._gradient_accumulation import (
+    GradAccumDataloader,
+    GradAccumGradientHandler,
+    GradAccumLrSchedulerByStep,
+    GradAccumOptimizer,
+)
 
 __all__ = [
     'accumulate_gradient', 'GradAccumDataloader', 'GradAccumOptimizer', 'GradAccumLrSchedulerByStep',
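The re-exported wrappers above (GradAccumDataloader, GradAccumOptimizer, GradAccumLrSchedulerByStep, GradAccumGradientHandler) package gradient accumulation behind familiar training objects. Their exact signatures are not shown in this diff, so the sketch below only illustrates the underlying idea in plain PyTorch, which these wrappers automate: scale each micro-batch loss by the accumulation size and step/zero the optimizer once per window. The model, data, and accumulate_size below are placeholders, not part of the ColossalAI API.

import torch
import torch.nn as nn

model = nn.Linear(16, 4)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
criterion = nn.MSELoss()
accumulate_size = 4  # micro-batches per optimizer step

optimizer.zero_grad()
for step in range(16):
    data = torch.randn(8, 16)
    target = torch.randn(8, 4)
    loss = criterion(model(data), target)
    # Average the gradient contribution over the accumulation window.
    (loss / accumulate_size).backward()
    if (step + 1) % accumulate_size == 0:
        optimizer.step()       # apply the accumulated gradients
        optimizer.zero_grad()  # start the next accumulation window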
@@ -5,7 +5,7 @@ from abc import ABC, abstractmethod
 
 
 class BaseGradientHandler(ABC):
     """A basic helper class to handle all-reduce operations of gradients across different parallel groups
     before optimization.
 
     Args:
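The abstract base class above only fixes the interface: a handler is constructed from the objects it must synchronize (the Args: section is truncated in this hunk, but the other handlers suggest a model and an optimizer) and performs its collective communication in handle_gradient. The following is a self-contained sketch of that pattern using plain torch.distributed, not the ColossalAI class itself.

from abc import ABC, abstractmethod

import torch.distributed as dist
import torch.nn as nn


class MyBaseGradientHandler(ABC):
    """Illustrative re-creation of the gradient-handler interface:
    store what is needed for communication, defer the all-reduce to
    subclasses."""

    def __init__(self, model: nn.Module, optimizer):
        self._model = model
        self._optimizer = optimizer

    @abstractmethod
    def handle_gradient(self):
        """Run the collective communication over the gradients."""


class NaiveAllReduceHandler(MyBaseGradientHandler):
    """All-reduce every gradient one tensor at a time over the default
    process group; assumes torch.distributed is already initialized."""

    def handle_gradient(self):
        world_size = dist.get_world_size()
        for param in self._model.parameters():
            if param.grad is not None:
                dist.all_reduce(param.grad)
                param.grad.div_(world_size)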
@@ -1,16 +1,17 @@
 from colossalai.core import global_context as gpc
 from colossalai.registry import GRADIENT_HANDLER
-from ._base_gradient_handler import BaseGradientHandler
+
 from ...context.parallel_mode import ParallelMode
+from ._base_gradient_handler import BaseGradientHandler
 from .utils import bucket_allreduce
 
 
 @GRADIENT_HANDLER.register_module
 class DataParallelGradientHandler(BaseGradientHandler):
     """A helper class to handle all-reduce operations in a data parallel group.
     A all-reduce collective communication will be operated in
     :func:`handle_gradient` among a data parallel group.
     For better performance, it bucketizes the gradients of all parameters that are
     the same type to improve the efficiency of communication.
 
     Args:
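The docstring above explains that gradients are bucketized by tensor type before communication, and the reordered imports make the bucket_allreduce helper explicit. As a rough sketch of that idea (not the library's implementation): group gradients by dtype, flatten each bucket, all-reduce it once, and copy the averaged values back. The function name and arguments below are assumptions for illustration only.

from collections import defaultdict

import torch.distributed as dist
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors


def bucket_allreduce_sketch(params, group=None):
    """Hypothetical bucketed all-reduce: one collective per gradient dtype.
    Assumes the process group is already initialized."""
    buckets = defaultdict(list)
    for param in params:
        if param.requires_grad and param.grad is not None:
            buckets[param.grad.dtype].append(param.grad.data)

    world_size = dist.get_world_size(group)
    for grads in buckets.values():
        # Flatten the bucket so a single all-reduce covers every
        # gradient of this dtype.
        flat = _flatten_dense_tensors(grads)
        dist.all_reduce(flat, group=group)
        flat.div_(world_size)
        # Copy the averaged values back into the original tensors.
        for grad, synced in zip(grads, _unflatten_dense_tensors(flat, grads)):
            grad.copy_(synced)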
@@ -4,9 +4,10 @@ from collections import defaultdict
 
 import torch
 import torch.distributed as dist
+from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
+
 from colossalai.core import global_context as gpc
 from colossalai.registry import GRADIENT_HANDLER
-from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
 
 from ._base_gradient_handler import BaseGradientHandler
 
@@ -14,9 +15,9 @@ from ._base_gradient_handler import BaseGradientHandler
 @GRADIENT_HANDLER.register_module
 class PipelineSharedModuleGradientHandler(BaseGradientHandler):
     """A helper class to handle all-reduce operations in sub parallel groups.
     A all-reduce collective communication will be operated in
     :func:`handle_gradient` among all sub pipeline parallel groups.
     For better performance, it bucketizes the gradients of all parameters that are
     the same type to improve the efficiency of communication.
 
     Args:
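The pipeline handler above synchronizes gradients of modules that are shared between pipeline stages (for example, tied input and output embeddings), with one all-reduce per sub pipeline-parallel group. A minimal sketch of that idea for a single shared parameter and an explicit process group follows; it is an illustration, not the handler's code.

import torch.distributed as dist


def sync_shared_parameter(shared_param, group):
    """Average the gradient of a parameter owned by several pipeline
    stages (e.g. a tied embedding). `group` is assumed to be the process
    group containing exactly those stages."""
    if shared_param.grad is None:
        return
    dist.all_reduce(shared_param.grad, group=group)
    shared_param.grad.div_(dist.get_world_size(group))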
@@ -1,16 +1,17 @@
 from colossalai.core import global_context as gpc
 from colossalai.registry import GRADIENT_HANDLER
-from ._base_gradient_handler import BaseGradientHandler
+
 from ...context.parallel_mode import ParallelMode
+from ._base_gradient_handler import BaseGradientHandler
 from .utils import bucket_allreduce
 
 
 @GRADIENT_HANDLER.register_module
 class SequenceParallelGradientHandler(BaseGradientHandler):
     """A helper class to handle all-reduce operations in a data parallel group.
     A all-reduce collective communication will be operated in
     :func:`handle_gradient` among a data parallel group.
     For better performance, it bucketizes the gradients of all parameters that are
     the same type to improve the efficiency of communication.
 
     Args:
@@ -1,4 +1,5 @@
 from colossalai.registry import GRADIENT_HANDLER
+
 from ._base_gradient_handler import BaseGradientHandler
 
 
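Every handler in this diff is registered through the @GRADIENT_HANDLER.register_module decorator. The sketch below shows the general registry pattern such a decorator implies; the SimpleRegistry class and names are illustrative assumptions, not ColossalAI's Registry implementation.

class SimpleRegistry:
    """Toy name -> class registry in the style suggested by
    @GRADIENT_HANDLER.register_module."""

    def __init__(self, name):
        self.name = name
        self._modules = {}

    def register_module(self, cls):
        # Used as a decorator: store the class under its own name
        # and return it unchanged.
        self._modules[cls.__name__] = cls
        return cls

    def get(self, name):
        return self._modules[name]


GRADIENT_HANDLER_DEMO = SimpleRegistry('gradient_handler')


@GRADIENT_HANDLER_DEMO.register_module
class NoOpGradientHandler:
    def handle_gradient(self):
        pass  # nothing to synchronize


handler_cls = GRADIENT_HANDLER_DEMO.get('NoOpGradientHandler')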
@@ -1,5 +1,5 @@
 from ._base_schedule import BaseSchedule
-from ._pipeline_schedule import PipelineSchedule, InterleavedPipelineSchedule, get_tensor_shape
 from ._non_pipeline_schedule import NonPipelineSchedule
+from ._pipeline_schedule import InterleavedPipelineSchedule, PipelineSchedule, get_tensor_shape
 
 __all__ = ['BaseSchedule', 'NonPipelineSchedule', 'PipelineSchedule', 'InterleavedPipelineSchedule', 'get_tensor_shape']
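The schedule package re-exported above distinguishes a plain NonPipelineSchedule from PipelineSchedule and InterleavedPipelineSchedule, which drive training micro-batch by micro-batch. As a hedged illustration of the core idea only (not the ColossalAI scheduler), a full batch can be chunked into micro-batches that are streamed through the pipeline stages one after another.

import torch


def split_into_microbatches(batch: torch.Tensor, num_microbatches: int):
    """Split a batch along dim 0 into the micro-batches a pipeline
    schedule would stream through the stages (illustrative helper)."""
    return list(torch.chunk(batch, num_microbatches, dim=0))


# Example: a batch of 32 samples split into 4 micro-batches of 8.
batch = torch.randn(32, 16)
microbatches = split_into_microbatches(batch, num_microbatches=4)
assert len(microbatches) == 4 and microbatches[0].shape[0] == 8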
@@ -2,10 +2,10 @@
 # -*- encoding: utf-8 -*-
 
 from abc import ABC, abstractmethod
+from typing import Callable, Iterable
 
 import torch
 
-from typing import Iterable, Callable
 from colossalai.logging import get_dist_logger
 from colossalai.utils import get_current_device
 
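The imports above (torch, get_current_device, the Iterable/Callable type hints) point at what a base schedule does before any forward pass: pull a batch from an iterator and move it onto the current device. The helper below is a hedged sketch of that kind of utility, not the BaseSchedule implementation.

import torch


def move_to_device(batch, device):
    """Recursively move tensors in a (possibly nested) batch to `device`."""
    if torch.is_tensor(batch):
        return batch.to(device, non_blocking=True)
    if isinstance(batch, (list, tuple)):
        return type(batch)(move_to_device(item, device) for item in batch)
    if isinstance(batch, dict):
        return {key: move_to_device(value, device) for key, value in batch.items()}
    return batch  # leave non-tensor leaves (ints, strings, ...) untouched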
@@ -1,13 +1,14 @@
 #!/usr/bin/env python
 # -*- encoding: utf-8 -*-
 
-from typing import Iterable
+import inspect
+from typing import Callable, Iterable
 
 import torch
-import inspect
-from ._base_schedule import BaseSchedule
+
 from colossalai.utils import conditional_context
-from typing import Callable
+
+from ._base_schedule import BaseSchedule
 
 
 class NonPipelineSchedule(BaseSchedule):
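NonPipelineSchedule runs the ordinary single-stage loop: fetch a batch, run the forward pass, compute the loss, and backpropagate unless only a forward pass is requested. The sketch below shows that flow with hypothetical names (train_step, criterion); it is not the ColossalAI method signature.

import torch.nn as nn


def train_step(model: nn.Module, criterion, data_iter, forward_only: bool = False):
    """One non-pipelined step: forward, then backward unless forward_only.
    Hypothetical sketch; argument names are assumptions, not the library API."""
    data, label = next(data_iter)
    output = model(data)
    loss = criterion(output, label)
    if not forward_only:
        loss.backward()  # gradient handlers / optimizer step happen elsewhere
    return output, loss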