[format] Run lint on colossalai.engine (#3367)

pull/3219/head
Hakjin Lee 2023-04-06 00:24:43 +09:00 committed by GitHub
parent b92313903f
commit 46c009dba4
9 changed files with 32 additions and 20 deletions


@@ -1,10 +1,17 @@
+from typing import Iterable, List
+
 import torch.nn as nn
-from typing import List
-from colossalai.engine import BaseGradientHandler
-from typing import Iterable
 from torch.optim import Optimizer
 from torch.optim.lr_scheduler import _LRScheduler
-from ._gradient_accumulation import GradAccumDataloader, GradAccumOptimizer, GradAccumLrSchedulerByStep, GradAccumGradientHandler
+
+from colossalai.engine import BaseGradientHandler
+
+from ._gradient_accumulation import (
+    GradAccumDataloader,
+    GradAccumGradientHandler,
+    GradAccumLrSchedulerByStep,
+    GradAccumOptimizer,
+)
 
 __all__ = [
     'accumulate_gradient', 'GradAccumDataloader', 'GradAccumOptimizer', 'GradAccumLrSchedulerByStep',
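
The hunk above only re-sorts imports, but the names it touches (GradAccumDataloader, GradAccumOptimizer, GradAccumLrSchedulerByStep, GradAccumGradientHandler) are the wrappers behind accumulate_gradient. As a rough, hedged illustration of the pattern those wrappers automate, here is a minimal plain-PyTorch sketch of accumulating gradients over several micro-batches; the model, data and accum_size below are placeholders, not the ColossalAI API.

import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset

# Minimal sketch of gradient accumulation in plain PyTorch.
# accum_size, model and the random dataset are placeholders, not ColossalAI objects.
accum_size = 4
model = nn.Linear(16, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
criterion = nn.CrossEntropyLoss()
loader = DataLoader(TensorDataset(torch.randn(64, 16), torch.randint(0, 2, (64,))), batch_size=8)

optimizer.zero_grad()
for step, (data, label) in enumerate(loader):
    loss = criterion(model(data), label)
    # Scale the loss so the accumulated gradient matches one large-batch update.
    (loss / accum_size).backward()
    if (step + 1) % accum_size == 0:
        optimizer.step()       # apply the accumulated gradients
        optimizer.zero_grad()  # start the next accumulation window

Roughly speaking, the wrappers re-exported here fold this scaling and step/zero_grad bookkeeping into the dataloader, optimizer and LR-scheduler objects so the training loop itself stays unchanged.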


@@ -5,7 +5,7 @@ from abc import ABC, abstractmethod
 class BaseGradientHandler(ABC):
     """A basic helper class to handle all-reduce operations of gradients across different parallel groups
     before optimization.
 
     Args:
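
The abstract base above is what every concrete handler in this directory subclasses: it keeps references to the model and optimizer and exposes a handle_gradient() hook that runs the all-reduce after backward. Below is a hedged sketch of a custom handler; the NaiveDataParallelGradientHandler name is made up, and the self._model attribute and the gpc group lookups are assumptions about the surrounding API rather than anything shown in this diff.

import torch.distributed as dist

from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.engine import BaseGradientHandler
from colossalai.registry import GRADIENT_HANDLER


@GRADIENT_HANDLER.register_module
class NaiveDataParallelGradientHandler(BaseGradientHandler):
    """Hypothetical handler that all-reduces every gradient tensor one by one
    (the built-in handlers bucket gradients by dtype instead)."""

    def handle_gradient(self):
        group = gpc.get_group(ParallelMode.DATA)            # assumed lookup of the data-parallel group
        world_size = gpc.get_world_size(ParallelMode.DATA)  # assumed rank count of that group
        for param in self._model.parameters():              # assumes the base class stores the model as self._model
            if param.grad is not None:
                dist.all_reduce(param.grad, group=group)
                param.grad.div_(world_size)                 # average rather than sum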


@@ -1,16 +1,17 @@
 from colossalai.core import global_context as gpc
 from colossalai.registry import GRADIENT_HANDLER
-from ._base_gradient_handler import BaseGradientHandler
+
 from ...context.parallel_mode import ParallelMode
+from ._base_gradient_handler import BaseGradientHandler
 from .utils import bucket_allreduce
 
 
 @GRADIENT_HANDLER.register_module
 class DataParallelGradientHandler(BaseGradientHandler):
     """A helper class to handle all-reduce operations in a data parallel group.
     A all-reduce collective communication will be operated in
     :func:`handle_gradient` among a data parallel group.
     For better performance, it bucketizes the gradients of all parameters that are
     the same type to improve the efficiency of communication.
 
     Args:


@@ -4,9 +4,10 @@ from collections import defaultdict
 
 import torch
 import torch.distributed as dist
+from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
+
 from colossalai.core import global_context as gpc
 from colossalai.registry import GRADIENT_HANDLER
-from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
 
 from ._base_gradient_handler import BaseGradientHandler
 
@@ -14,9 +15,9 @@ from ._base_gradient_handler import BaseGradientHandler
 @GRADIENT_HANDLER.register_module
 class PipelineSharedModuleGradientHandler(BaseGradientHandler):
     """A helper class to handle all-reduce operations in sub parallel groups.
     A all-reduce collective communication will be operated in
     :func:`handle_gradient` among all sub pipeline parallel groups.
     For better performance, it bucketizes the gradients of all parameters that are
     the same type to improve the efficiency of communication.
 
     Args:
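
The docstrings in this file and the previous one describe bucketizing same-dtype gradients so communication happens in a few large all-reduces, and the _flatten_dense_tensors/_unflatten_dense_tensors imports moved above are the usual tools for that. The sketch below illustrates the bucketed all-reduce technique in isolation; it is not the body of bucket_allreduce or of this handler, and the process group is left for the caller to supply.

from collections import defaultdict

import torch.distributed as dist
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors


def bucketed_allreduce(params, group=None):
    """Illustrative bucketed all-reduce: group gradients by dtype, flatten each
    bucket into one tensor, all-reduce once per bucket, then copy results back."""
    buckets = defaultdict(list)
    for p in params:
        if p.requires_grad and p.grad is not None:
            buckets[p.grad.dtype].append(p.grad.data)

    world_size = dist.get_world_size(group=group)
    for dtype, grads in buckets.items():
        coalesced = _flatten_dense_tensors(grads)  # one contiguous buffer per dtype
        dist.all_reduce(coalesced, group=group)
        coalesced.div_(world_size)                 # average across ranks
        for grad, synced in zip(grads, _unflatten_dense_tensors(coalesced, grads)):
            grad.copy_(synced)                     # write the reduced values back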


@@ -1,16 +1,17 @@
 from colossalai.core import global_context as gpc
 from colossalai.registry import GRADIENT_HANDLER
-from ._base_gradient_handler import BaseGradientHandler
+
 from ...context.parallel_mode import ParallelMode
+from ._base_gradient_handler import BaseGradientHandler
 from .utils import bucket_allreduce
 
 
 @GRADIENT_HANDLER.register_module
 class SequenceParallelGradientHandler(BaseGradientHandler):
     """A helper class to handle all-reduce operations in a data parallel group.
     A all-reduce collective communication will be operated in
     :func:`handle_gradient` among a data parallel group.
     For better performance, it bucketizes the gradients of all parameters that are
     the same type to improve the efficiency of communication.
 
     Args:


@@ -1,4 +1,5 @@
 from colossalai.registry import GRADIENT_HANDLER
+
 from ._base_gradient_handler import BaseGradientHandler


@@ -1,5 +1,5 @@
 from ._base_schedule import BaseSchedule
-from ._pipeline_schedule import PipelineSchedule, InterleavedPipelineSchedule, get_tensor_shape
 from ._non_pipeline_schedule import NonPipelineSchedule
+from ._pipeline_schedule import InterleavedPipelineSchedule, PipelineSchedule, get_tensor_shape
 
 __all__ = ['BaseSchedule', 'NonPipelineSchedule', 'PipelineSchedule', 'InterleavedPipelineSchedule', 'get_tensor_shape']


@@ -2,10 +2,10 @@
 # -*- encoding: utf-8 -*-
 
 from abc import ABC, abstractmethod
+from typing import Callable, Iterable
 
 import torch
-from typing import Iterable, Callable
 
 from colossalai.logging import get_dist_logger
 from colossalai.utils import get_current_device


@@ -1,13 +1,14 @@
 #!/usr/bin/env python
 # -*- encoding: utf-8 -*-
 
-from typing import Iterable
+import inspect
+from typing import Callable, Iterable
 
 import torch
-import inspect
-from ._base_schedule import BaseSchedule
+
 from colossalai.utils import conditional_context
-from typing import Callable
+
+from ._base_schedule import BaseSchedule
 
 
 class NonPipelineSchedule(BaseSchedule):
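
NonPipelineSchedule, whose imports are tidied above, is the schedule used when no pipeline parallelism is involved: each step runs a single forward pass and, during training, a single backward pass over one batch. The sketch below is only a conceptual illustration of that control flow; the real logic lives on the class and goes through the engine, so the function name and arguments here are hypothetical.

def naive_forward_backward_step(model, criterion, data_iter, forward_only=False):
    """Illustrative non-pipeline schedule step: fetch one batch, run forward,
    optionally run backward, and hand the results back to the caller."""
    data, label = next(data_iter)  # assumes the iterator yields (data, label) pairs
    output = model(data)
    loss = criterion(output, label)
    if not forward_only:
        loss.backward()            # the optimizer step is applied elsewhere
    return output, label, loss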