mirror of https://github.com/hpcaitech/ColossalAI
[refactor] moving grad acc logic to engine (#804)
parent 05d9ae5999
commit 681addb512
@@ -6,6 +6,11 @@ from torch.optim import Optimizer
 from torch.optim.lr_scheduler import _LRScheduler
 
 from ._gradient_accumulation import GradAccumDataloader, GradAccumOptimizer, GradAccumLrSchedulerByStep, GradAccumGradientHandler
 
+__all__ = [
+    'accumulate_gradient', 'GradAccumDataloader', 'GradAccumOptimizer', 'GradAccumLrSchedulerByStep',
+    'GradAccumGradientHandler'
+]
+
 def accumulate_gradient(model: nn.Module,
                         optimizer: Optimizer,
@@ -43,7 +48,3 @@ def accumulate_gradient(model: nn.Module,
         lr_scheduler = GradAccumLrSchedulerByStep(lr_scheduler, accumulate_size=accumulate_size)
 
     return optimizer, dataloader, gradient_handlers, lr_scheduler
-
-
-__all__ = ['accumulate_gradient', 'GradAccumDataloader', 'GradAccumOptimizer',
-           'GradAccumLrSchedulerByStep', 'GradAccumGradientHandler']
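Note on the relocated helper: the two hunks above keep the call pattern visible in the diff, where accumulate_gradient wraps the optimizer, dataloader, gradient handlers and LR scheduler and returns them in that order. A minimal usage sketch follows; only the signature fragments and the return order shown above are taken from the diff, while the toy model, data, hyper-parameters and the keyword names dataloader/lr_scheduler are assumptions for illustration.

    import torch
    import torch.nn as nn
    from torch.optim import SGD
    from torch.optim.lr_scheduler import CosineAnnealingLR
    from torch.utils.data import DataLoader, TensorDataset

    # new import path introduced by this commit
    from colossalai.engine.gradient_accumulation import accumulate_gradient

    model = nn.Linear(32, 4)
    optimizer = SGD(model.parameters(), lr=0.1)
    dataloader = DataLoader(TensorDataset(torch.randn(64, 32), torch.randint(0, 4, (64,))), batch_size=8)
    lr_scheduler = CosineAnnealingLR(optimizer, T_max=10)

    # Wrap the objects so gradients are accumulated over 4 micro-batches:
    # the wrapped optimizer only applies an update (and the wrapped scheduler
    # only advances) once every 4 steps, matching the return statement above.
    optimizer, dataloader, gradient_handlers, lr_scheduler = accumulate_gradient(model=model,
                                                                                 optimizer=optimizer,
                                                                                 dataloader=dataloader,
                                                                                 accumulate_size=4,
                                                                                 lr_scheduler=lr_scheduler)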
@@ -12,6 +12,7 @@ from colossalai.core import global_context as gpc
 from colossalai.logging import get_dist_logger
 from colossalai.utils import switch_virtual_pipeline_parallel_rank
 from colossalai.utils.cuda import get_current_device
+from colossalai.zero.sharded_model.sharded_model_v2 import ShardedModelV2
 
 from ._base_schedule import BaseSchedule
 
@@ -115,7 +116,7 @@ class PipelineSchedule(BaseSchedule):
     def pre_processing(self, engine):
         # TODO: remove this after testing new zero with pipeline parallelism
         model = engine.model
-        if isinstance(model, (NaiveAMPModel)) or hasattr(model, 'colo_attr'):
+        if isinstance(model, (NaiveAMPModel, ShardedModelV2)):
             self.dtype = torch.half
             model = model.model
         sig = inspect.signature(model.forward)
@@ -386,8 +387,7 @@ class InterleavedPipelineSchedule(PipelineSchedule):
         self.num_model_chunks = num_model_chunks
 
     def pre_processing(self, engine):
-        # FIXME(jiaruifang) we shall not use ShardedModelV2 in pipeline mode, due to circular dependency.
-        if hasattr(engine.model, 'colo_attr'):
+        if isinstance(engine.model, ShardedModelV2):
             self.dtype = torch.half
         elif isinstance(engine.model[0], NaiveAMPModel):
             self.dtype = torch.half
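Note on the schedule changes: the two pre_processing hunks replace the hasattr(model, 'colo_attr') probe with explicit isinstance checks against the newly imported ShardedModelV2. The sketch below restates the effect as a standalone function, not the schedule's actual method; it assumes only what the diff shows, namely that both NaiveAMPModel and ShardedModelV2 run the forward pass in fp16 and expose the user model via .model.

    import inspect
    import torch
    from colossalai.amp.naive_amp import NaiveAMPModel
    from colossalai.zero.sharded_model.sharded_model_v2 import ShardedModelV2

    def resolve_forward_signature(model):
        """Pick the schedule dtype and unwrap AMP/ZeRO wrappers before inspecting forward()."""
        dtype = torch.float
        if isinstance(model, (NaiveAMPModel, ShardedModelV2)):
            dtype = torch.half   # both wrappers execute the forward pass in fp16
            model = model.model  # unwrap to reach the user-defined forward signature
        return dtype, inspect.signature(model.forward)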
@@ -2,8 +2,9 @@ from abc import ABC, abstractmethod
 from typing import List, Optional
 import torch
 from colossalai.utils import get_current_device
-from colossalai.zero.sharded_param.tensor_utils import colo_model_data_tensor_move_inline, colo_tensor_mem_usage
 from colossalai.utils.memory import colo_device_memory_capacity
+
+from colossalai.zero.sharded_param.tensor_utils import colo_model_data_tensor_move_inline, colo_tensor_mem_usage
 from colossalai.zero.sharded_param.tensorful_state import StatefulTensor
 from colossalai.gemini.memory_tracer import MemStatsCollector
 from colossalai.gemini.memory_tracer import GLOBAL_MODEL_DATA_TRACER
@@ -15,21 +15,26 @@ from torch.optim.lr_scheduler import _LRScheduler
 from torch.optim.optimizer import Optimizer
 from torch.utils.data import DataLoader
 
+from colossalai.core import global_context as gpc
+from colossalai.context.moe_context import MOE_CONTEXT
+
+from colossalai.logging import get_dist_logger
+
+from colossalai.engine.schedule import NonPipelineSchedule, PipelineSchedule, InterleavedPipelineSchedule, get_tensor_shape
+from colossalai.engine import Engine
+from colossalai.engine.ophooks import BaseOpHook
+
+from colossalai.utils import (get_current_device, is_using_ddp, is_using_pp, is_using_sequence, sync_model_param)
+from colossalai.utils.moe import sync_moe_model_param
+
 from colossalai.amp import AMP_TYPE, convert_to_amp
 from colossalai.amp.naive_amp import NaiveAMPModel
 from colossalai.builder.builder import build_gradient_handler
 from colossalai.context import Config, ConfigException, ParallelMode
-from colossalai.core import global_context as gpc
-from colossalai.engine.schedule import NonPipelineSchedule, PipelineSchedule, InterleavedPipelineSchedule, get_tensor_shape
-
-from colossalai.context.moe_context import MOE_CONTEXT
-from colossalai.engine import Engine
-from colossalai.engine.ophooks import BaseOpHook
-from colossalai.logging import get_dist_logger
+from colossalai.engine.gradient_accumulation import accumulate_gradient
 from colossalai.nn.optimizer.colossalai_optimizer import ColossalaiOptimizer
-from colossalai.utils import (accumulate_gradient, get_current_device, is_using_ddp, is_using_pp, is_using_sequence,
-                              sync_model_param)
-from colossalai.utils.moe import sync_moe_model_param
 from colossalai.zero import convert_to_zero_v2
 from colossalai.zero.sharded_optim.sharded_optim_v2 import ShardedOptimizerV2
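Note on the initialize.py hunk: colossalai.initialize consumes accumulate_gradient internally, so most training scripts never call the helper directly and only the import source changes for them. A hedged sketch of that configuration path follows; the gradient_accumulation config field and the launch flow reflect the project's gradient-accumulation tutorial of that era, and the toy model, data and file names are illustrative assumptions.

    # config.py -- the presence of `gradient_accumulation` is what makes
    # colossalai.initialize() apply the accumulate_gradient wrappers internally
    gradient_accumulation = 4

    # train.py (assumed to be launched with torchrun so launch_from_torch works)
    import torch
    import torch.nn as nn
    from torch.utils.data import DataLoader, TensorDataset
    import colossalai

    colossalai.launch_from_torch(config='./config.py')

    model = nn.Linear(32, 4)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
    criterion = nn.CrossEntropyLoss()
    train_dataloader = DataLoader(TensorDataset(torch.randn(64, 32), torch.randint(0, 4, (64,))), batch_size=8)

    # The returned engine and dataloader already honour the accumulation size,
    # now sourced from colossalai.engine.gradient_accumulation.
    engine, train_dataloader, _, _ = colossalai.initialize(model=model,
                                                           optimizer=optimizer,
                                                           criterion=criterion,
                                                           train_dataloader=train_dataloader)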
@@ -7,7 +7,6 @@ from .common import (clip_grad_norm_fp32, conditional_context, copy_tensor_paral
                      param_is_not_tensor_parallel_duplicate, print_rank_0, switch_virtual_pipeline_parallel_rank,
                      sync_model_param, disposable)
 from .data_sampler import DataParallelSampler, get_dataloader
-from .gradient_accumulation import accumulate_gradient
 from .memory import report_memory_usage, colo_device_memory_used, colo_set_process_memory_fraction, colo_device_memory_capacity
 from .timer import MultiTimer, Timer
 from .tensor_detector import TensorDetector
@@ -18,7 +17,7 @@ __all__ = [
     'is_model_parallel_parameter', 'clip_grad_norm_fp32', 'count_zeros_fp32', 'copy_tensor_parallel_attributes',
     'param_is_not_tensor_parallel_duplicate', 'get_current_device', 'synchronize', 'empty_cache', 'set_to_cuda',
     'report_memory_usage', 'colo_device_memory_capacity', 'colo_device_memory_used', 'colo_set_process_memory_fraction',
-    'Timer', 'MultiTimer', 'multi_tensor_applier', 'accumulate_gradient', 'DataParallelSampler', 'get_dataloader',
+    'Timer', 'MultiTimer', 'multi_tensor_applier', 'DataParallelSampler', 'get_dataloader',
     'switch_virtual_pipeline_parallel_rank', 'TensorDetector', 'load_checkpoint', 'save_checkpoint',
     'ensure_path_exists', 'disposable'
 ]
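Note on the colossalai.utils hunks: the accumulate_gradient re-export is dropped, so the public import path changes. A before/after sketch, assuming no compatibility shim is left behind in colossalai.utils:

    # before this commit
    from colossalai.utils import accumulate_gradient
    from colossalai.utils.gradient_accumulation import GradAccumLrSchedulerByStep

    # after this commit
    from colossalai.engine.gradient_accumulation import accumulate_gradient, GradAccumLrSchedulerByStep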
@@ -14,7 +14,7 @@ from colossalai.nn import LinearWarmupLR
 from colossalai.nn.loss import CrossEntropyLoss
 from colossalai.trainer import Trainer, hooks
 from colossalai.utils import free_port, get_dataloader
-from colossalai.utils.gradient_accumulation import GradAccumLrSchedulerByStep
+from colossalai.engine.gradient_accumulation import GradAccumLrSchedulerByStep
 from colossalai.testing import rerun_if_address_is_in_use
 from model_zoo.vit import vit_tiny_patch4_32
 from torchvision import transforms
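Note on the test hunk: only the import is updated; the wrapper is used as before. A minimal sketch, assuming GradAccumLrSchedulerByStep keeps the constructor shown earlier in this diff (a scheduler plus accumulate_size) and steps the wrapped scheduler once per accumulation cycle; the StepLR setup here is an illustrative placeholder.

    import torch
    import torch.nn as nn
    from torch.optim.lr_scheduler import StepLR
    from colossalai.engine.gradient_accumulation import GradAccumLrSchedulerByStep

    model = nn.Linear(8, 2)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
    scheduler = GradAccumLrSchedulerByStep(StepLR(optimizer, step_size=1), accumulate_size=4)

    for step in range(8):
        optimizer.step()      # the parameter update itself is unchanged here
        scheduler.step()      # the wrapped StepLR advances once every 4 calls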