[format] Run lint on colossalai.engine (#3367)

pull/3219/head
Hakjin Lee 2023-04-06 00:24:43 +09:00 committed by GitHub
parent b92313903f
commit 46c009dba4
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
9 changed files with 32 additions and 20 deletions

View File

@ -1,10 +1,17 @@
from typing import Iterable, List
import torch.nn as nn
from typing import List
from colossalai.engine import BaseGradientHandler
from typing import Iterable
from torch.optim import Optimizer
from torch.optim.lr_scheduler import _LRScheduler
from ._gradient_accumulation import GradAccumDataloader, GradAccumOptimizer, GradAccumLrSchedulerByStep, GradAccumGradientHandler
from colossalai.engine import BaseGradientHandler
from ._gradient_accumulation import (
GradAccumDataloader,
GradAccumGradientHandler,
GradAccumLrSchedulerByStep,
GradAccumOptimizer,
)
__all__ = [
'accumulate_gradient', 'GradAccumDataloader', 'GradAccumOptimizer', 'GradAccumLrSchedulerByStep',

View File

@ -1,7 +1,8 @@
from colossalai.core import global_context as gpc
from colossalai.registry import GRADIENT_HANDLER
from ._base_gradient_handler import BaseGradientHandler
from ...context.parallel_mode import ParallelMode
from ._base_gradient_handler import BaseGradientHandler
from .utils import bucket_allreduce

View File

@ -4,9 +4,10 @@ from collections import defaultdict
import torch
import torch.distributed as dist
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
from colossalai.core import global_context as gpc
from colossalai.registry import GRADIENT_HANDLER
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
from ._base_gradient_handler import BaseGradientHandler

View File

@ -1,7 +1,8 @@
from colossalai.core import global_context as gpc
from colossalai.registry import GRADIENT_HANDLER
from ._base_gradient_handler import BaseGradientHandler
from ...context.parallel_mode import ParallelMode
from ._base_gradient_handler import BaseGradientHandler
from .utils import bucket_allreduce

View File

@ -1,4 +1,5 @@
from colossalai.registry import GRADIENT_HANDLER
from ._base_gradient_handler import BaseGradientHandler

View File

@ -1,5 +1,5 @@
from ._base_schedule import BaseSchedule
from ._pipeline_schedule import PipelineSchedule, InterleavedPipelineSchedule, get_tensor_shape
from ._non_pipeline_schedule import NonPipelineSchedule
from ._pipeline_schedule import InterleavedPipelineSchedule, PipelineSchedule, get_tensor_shape
__all__ = ['BaseSchedule', 'NonPipelineSchedule', 'PipelineSchedule', 'InterleavedPipelineSchedule', 'get_tensor_shape']

View File

@ -2,10 +2,10 @@
# -*- encoding: utf-8 -*-
from abc import ABC, abstractmethod
from typing import Callable, Iterable
import torch
from typing import Iterable, Callable
from colossalai.logging import get_dist_logger
from colossalai.utils import get_current_device

View File

@ -1,13 +1,14 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
from typing import Iterable
import inspect
from typing import Callable, Iterable
import torch
import inspect
from ._base_schedule import BaseSchedule
from colossalai.utils import conditional_context
from typing import Callable
from ._base_schedule import BaseSchedule
class NonPipelineSchedule(BaseSchedule):