[refactor] moving grad acc logic to engine (#804)

pull/806/head
Jiarui Fang 2022-04-19 14:03:21 +08:00 committed by GitHub
parent 05d9ae5999
commit 681addb512
8 changed files with 26 additions and 20 deletions

View File

@@ -6,6 +6,11 @@ from torch.optim import Optimizer
 from torch.optim.lr_scheduler import _LRScheduler
 
 from ._gradient_accumulation import GradAccumDataloader, GradAccumOptimizer, GradAccumLrSchedulerByStep, GradAccumGradientHandler
 
+__all__ = [
+    'accumulate_gradient', 'GradAccumDataloader', 'GradAccumOptimizer', 'GradAccumLrSchedulerByStep',
+    'GradAccumGradientHandler'
+]
+
 def accumulate_gradient(model: nn.Module,
                         optimizer: Optimizer,
@@ -43,7 +48,3 @@ def accumulate_gradient(model: nn.Module,
         lr_scheduler = GradAccumLrSchedulerByStep(lr_scheduler, accumulate_size=accumulate_size)
     return optimizer, dataloader, gradient_handlers, lr_scheduler
-
-__all__ = ['accumulate_gradient', 'GradAccumDataloader', 'GradAccumOptimizer',
-           'GradAccumLrSchedulerByStep', 'GradAccumGradientHandler']
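
For orientation, a minimal usage sketch of the relocated helper follows. It assumes accumulate_gradient keeps the interface implied by the hunk above (model, optimizer, dataloader and accumulate_size, with optional gradient_handlers and lr_scheduler, returning the four wrapped objects); the toy model and dataloader are placeholders, not part of this commit.

import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset

# New import path introduced by this refactor (previously under colossalai.utils).
from colossalai.engine.gradient_accumulation import accumulate_gradient

model = nn.Linear(16, 4)
optimizer = torch.optim.SGD(model.parameters(), lr=1e-2)
dataloader = DataLoader(TensorDataset(torch.randn(64, 16), torch.randn(64, 4)), batch_size=8)

# Wrap the training objects so the real optimizer.step() only fires once
# every `accumulate_size` iterations (argument names assumed from the hunk above).
optimizer, dataloader, gradient_handlers, lr_scheduler = accumulate_gradient(
    model=model, optimizer=optimizer, dataloader=dataloader, accumulate_size=4)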

View File

@@ -12,6 +12,7 @@ from colossalai.core import global_context as gpc
 from colossalai.logging import get_dist_logger
 from colossalai.utils import switch_virtual_pipeline_parallel_rank
 from colossalai.utils.cuda import get_current_device
+from colossalai.zero.sharded_model.sharded_model_v2 import ShardedModelV2
 
 from ._base_schedule import BaseSchedule
@@ -115,7 +116,7 @@ class PipelineSchedule(BaseSchedule):
     def pre_processing(self, engine):
         # TODO: remove this after testing new zero with pipeline parallelism
         model = engine.model
-        if isinstance(model, (NaiveAMPModel)) or hasattr(model, 'colo_attr'):
+        if isinstance(model, (NaiveAMPModel, ShardedModelV2)):
             self.dtype = torch.half
             model = model.model
         sig = inspect.signature(model.forward)
@@ -386,8 +387,7 @@ class InterleavedPipelineSchedule(PipelineSchedule):
         self.num_model_chunks = num_model_chunks
 
     def pre_processing(self, engine):
-        # FIXME(jiaruifang) we shall not use ShardedModelV2 in pipeline mode, due to circular dependency.
-        if hasattr(engine.model, 'colo_attr'):
+        if isinstance(engine.model, ShardedModelV2):
             self.dtype = torch.half
         elif isinstance(engine.model[0], NaiveAMPModel):
             self.dtype = torch.half
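
The two pre_processing hunks replace the hasattr(model, 'colo_attr') duck-typing, which existed only to avoid importing ShardedModelV2 (the removed FIXME about a circular dependency), with an explicit isinstance check now that the import is possible. A condensed restatement of the new branch is sketched below; the resolve_dtype helper is illustrative only, while the wrapper classes are the ones imported in the first hunk.

import torch

from colossalai.amp.naive_amp import NaiveAMPModel
from colossalai.zero.sharded_model.sharded_model_v2 import ShardedModelV2


def resolve_dtype(model):
    # fp16 when the engine's model is wrapped by the AMP or ZeRO wrapper,
    # fp32 otherwise (condensed from PipelineSchedule.pre_processing above).
    if isinstance(model, (NaiveAMPModel, ShardedModelV2)):
        return torch.half
    return torch.float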

View File

@@ -2,8 +2,9 @@ from abc import ABC, abstractmethod
 from typing import List, Optional
 import torch
 from colossalai.utils import get_current_device
-from colossalai.zero.sharded_param.tensor_utils import colo_model_data_tensor_move_inline, colo_tensor_mem_usage
 from colossalai.utils.memory import colo_device_memory_capacity
+from colossalai.zero.sharded_param.tensor_utils import colo_model_data_tensor_move_inline, colo_tensor_mem_usage
 from colossalai.zero.sharded_param.tensorful_state import StatefulTensor
 from colossalai.gemini.memory_tracer import MemStatsCollector
 from colossalai.gemini.memory_tracer import GLOBAL_MODEL_DATA_TRACER

View File

@@ -15,21 +15,26 @@ from torch.optim.lr_scheduler import _LRScheduler
 from torch.optim.optimizer import Optimizer
 from torch.utils.data import DataLoader
 
+from colossalai.core import global_context as gpc
+from colossalai.context.moe_context import MOE_CONTEXT
+from colossalai.logging import get_dist_logger
+from colossalai.engine.schedule import NonPipelineSchedule, PipelineSchedule, InterleavedPipelineSchedule, get_tensor_shape
+from colossalai.engine import Engine
+from colossalai.engine.ophooks import BaseOpHook
+from colossalai.utils import (get_current_device, is_using_ddp, is_using_pp, is_using_sequence, sync_model_param)
+from colossalai.utils.moe import sync_moe_model_param
 from colossalai.amp import AMP_TYPE, convert_to_amp
 from colossalai.amp.naive_amp import NaiveAMPModel
 from colossalai.builder.builder import build_gradient_handler
 from colossalai.context import Config, ConfigException, ParallelMode
-from colossalai.core import global_context as gpc
-from colossalai.engine.schedule import NonPipelineSchedule, PipelineSchedule, InterleavedPipelineSchedule, get_tensor_shape
-from colossalai.context.moe_context import MOE_CONTEXT
-from colossalai.engine import Engine
-from colossalai.engine.ophooks import BaseOpHook
-from colossalai.logging import get_dist_logger
+from colossalai.engine.gradient_accumulation import accumulate_gradient
 from colossalai.nn.optimizer.colossalai_optimizer import ColossalaiOptimizer
-from colossalai.utils import (accumulate_gradient, get_current_device, is_using_ddp, is_using_pp, is_using_sequence,
-                              sync_model_param)
-from colossalai.utils.moe import sync_moe_model_param
 from colossalai.zero import convert_to_zero_v2
 from colossalai.zero.sharded_optim.sharded_optim_v2 import ShardedOptimizerV2
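
Besides the import reordering, the substantive change in this block is that accumulate_gradient is now imported from the engine package instead of from colossalai.utils. Downstream code needs the same one-line migration, sketched below with both paths taken from this diff:

# Old location (removed in this commit):
#   from colossalai.utils import accumulate_gradient

# New location:
from colossalai.engine.gradient_accumulation import accumulate_gradient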

View File

@@ -7,7 +7,6 @@ from .common import (clip_grad_norm_fp32, conditional_context, copy_tensor_paral
                      param_is_not_tensor_parallel_duplicate, print_rank_0, switch_virtual_pipeline_parallel_rank,
                      sync_model_param, disposable)
 from .data_sampler import DataParallelSampler, get_dataloader
-from .gradient_accumulation import accumulate_gradient
 from .memory import report_memory_usage, colo_device_memory_used, colo_set_process_memory_fraction, colo_device_memory_capacity
 from .timer import MultiTimer, Timer
 from .tensor_detector import TensorDetector
@@ -18,7 +17,7 @@ __all__ = [
     'is_model_parallel_parameter', 'clip_grad_norm_fp32', 'count_zeros_fp32', 'copy_tensor_parallel_attributes',
     'param_is_not_tensor_parallel_duplicate', 'get_current_device', 'synchronize', 'empty_cache', 'set_to_cuda',
     'report_memory_usage', 'colo_device_memory_capacity', 'colo_device_memory_used', 'colo_set_process_memory_fraction',
-    'Timer', 'MultiTimer', 'multi_tensor_applier', 'accumulate_gradient', 'DataParallelSampler', 'get_dataloader',
+    'Timer', 'MultiTimer', 'multi_tensor_applier', 'DataParallelSampler', 'get_dataloader',
     'switch_virtual_pipeline_parallel_rank', 'TensorDetector', 'load_checkpoint', 'save_checkpoint',
     'ensure_path_exists', 'disposable'
 ]

View File

@@ -14,7 +14,7 @@ from colossalai.nn import LinearWarmupLR
 from colossalai.nn.loss import CrossEntropyLoss
 from colossalai.trainer import Trainer, hooks
 from colossalai.utils import free_port, get_dataloader
-from colossalai.utils.gradient_accumulation import GradAccumLrSchedulerByStep
+from colossalai.engine.gradient_accumulation import GradAccumLrSchedulerByStep
 from colossalai.testing import rerun_if_address_is_in_use
 from model_zoo.vit import vit_tiny_patch4_32
 from torchvision import transforms
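
The test itself only needed the new import path; the wrapper is used as before. A minimal sketch of that usage follows, assuming the constructor shown in the first file of this commit (inner scheduler plus accumulate_size) and the usual step() interface, with a toy optimizer and StepLR substituted for the test's real ones:

import torch
from torch.optim.lr_scheduler import StepLR

from colossalai.engine.gradient_accumulation import GradAccumLrSchedulerByStep

param = torch.nn.Parameter(torch.zeros(1))
optimizer = torch.optim.SGD([param], lr=0.1)

# Step the wrapper once per training iteration; the inner StepLR is only
# advanced every `accumulate_size` calls, matching the accumulated optimizer updates.
lr_scheduler = GradAccumLrSchedulerByStep(StepLR(optimizer, step_size=10), accumulate_size=4)
for _ in range(8):
    lr_scheduler.step()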