[ddp] add is_ddp_ignored (#2434)

[ddp] rename to is_ddp_ignored
pull/2443/head
HELSON 2023-01-11 12:22:45 +08:00 committed by GitHub
parent a3e5496156
commit 7829aa094e
7 changed files with 56 additions and 30 deletions
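
Summary (not part of the diff itself): the old in_ddp helper from colossalai.gemini.chunk.search_utils is replaced by is_ddp_ignored, which is now exported from colossalai.utils, and the predicate's polarity is inverted, so every call site flips its condition. A minimal before/after sketch, reconstructed from the hunks below:

import torch.nn as nn

# Before (removed from colossalai.gemini.chunk.search_utils):
def in_ddp(param: nn.Parameter) -> bool:
    return not getattr(param, '_ddp_to_ignore', False)

# After (added under colossalai.utils; note the inverted polarity):
def is_ddp_ignored(p):
    return getattr(p, '_ddp_to_ignore', False)

# Call sites therefore flip, e.g.:
#   if not in_ddp(param): continue            ->  if is_ddp_ignored(param): continue
#   [p for p in params if in_ddp(p)]          ->  [p for p in params if not is_ddp_ignored(p)]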

View File

@@ -6,17 +6,14 @@ import torch.nn as nn
 from colossalai.gemini.memory_tracer import MemStats, OrderedParamGenerator
 from colossalai.tensor import ColoParameter
+from colossalai.utils import is_ddp_ignored
-def in_ddp(param: nn.Parameter) -> bool:
-    return not getattr(param, '_ddp_to_ignore', False)
 def _filter_exlarge_params(model: nn.Module, size_dict: Dict[int, List[int]]) -> None:
     """
     Filter those parameters whose size is too large (more than 3x standard deviations) from others.
     """
-    params_size = [p.numel() for p in model.parameters() if in_ddp(p)]
+    params_size = [p.numel() for p in model.parameters() if not is_ddp_ignored(p)]
     params_size_arr = np.array(params_size)
     std = np.std(params_size_arr)
@@ -56,7 +53,7 @@ def classify_params_by_dp_degree(param_order: OrderedParamGenerator) -> Dict[int
     params_dict: Dict[int, List[ColoParameter]] = dict()
     for param in param_order.generate():
         assert isinstance(param, ColoParameter), "please init model in the ColoInitContext"
-        if not in_ddp(param):
+        if is_ddp_ignored(param):
             continue
         param_key = param.process_group.dp_world_size()

View File

@@ -6,8 +6,8 @@ import torch.distributed as dist
 import torch.nn as nn
 from colossalai.gemini.chunk import ChunkManager
-from colossalai.gemini.chunk.search_utils import in_ddp, search_chunk_configuration
-from colossalai.gemini.memory_tracer import MemStats
+from colossalai.gemini.chunk.search_utils import search_chunk_configuration
+from colossalai.utils import is_ddp_ignored
 def init_chunk_manager(model: nn.Module,
@@ -34,7 +34,7 @@ def init_chunk_manager(model: nn.Module,
     if filter_exlarge_params:
         kwargs_dict["filter_exlarge_params"] = filter_exlarge_params
-    params_sizes = [p.numel() for p in model.parameters() if in_ddp(p)]
+    params_sizes = [p.numel() for p in model.parameters() if not is_ddp_ignored(p)]
     total_size = sum(params_sizes) / 1024**2
     dist.barrier()

View File

@@ -12,7 +12,7 @@ from colossalai.gemini.chunk import Chunk, ChunkManager
 from colossalai.logging import get_dist_logger
 from colossalai.nn.optimizer import ColossalaiOptimizer, CPUAdam, FusedAdam, HybridAdam
 from colossalai.nn.parallel.data_parallel import ZeroDDP
-from colossalai.utils import disposable, get_current_device
+from colossalai.utils import disposable, get_current_device, is_ddp_ignored
 _AVAIL_OPTIM_LIST = {FusedAdam, CPUAdam, HybridAdam}
@@ -78,7 +78,7 @@ class ZeroOptimizer(ColossalaiOptimizer):
         if self.clipping_flag:
             assert norm_type == 2.0, "ZeroOptimizer only supports L2 norm now"
-        params_list = [p for p in module.parameters() if not getattr(p, '_ddp_to_ignore', False)]
+        params_list = [p for p in module.parameters() if not is_ddp_ignored(p)]
         for p, fp32_p in zip(params_list, module.fp32_params):
             chunk_16 = self.chunk_manager.get_chunk(p)
             if chunk_16 not in self.chunk16_set:

View File

@@ -14,7 +14,7 @@ from colossalai.nn.parallel.utils import get_temp_total_chunk_on_cuda
 from colossalai.tensor import ProcessGroup as ColoProcessGroup
 from colossalai.tensor.colo_parameter import ColoParameter, ColoTensor, ColoTensorSpec
 from colossalai.tensor.param_op_hook import ColoParamOpHookManager
-from colossalai.utils import get_current_device
+from colossalai.utils import get_current_device, is_ddp_ignored
 from colossalai.zero.utils.gemini_hook import GeminiZeROHook
 from .reducer import Reducer
@@ -81,7 +81,7 @@ class ColoDDP(torch.nn.Module):
         self.reducer = Reducer(bucket_cap_mb)
         self.rebuild_bucket = rebuild_bucket
         for p in module.parameters():
-            if getattr(p, '_ddp_to_ignore', False):
+            if is_ddp_ignored(p):
                 continue
             if p.requires_grad:
                 p.register_hook(partial(self.grad_handle, p))
@@ -116,7 +116,7 @@ class ColoDDP(torch.nn.Module):
         if self.rebuild_bucket:
             self.reducer.free()
         for p in self.module.parameters():
-            if getattr(p, '_ddp_to_ignore', False):
+            if is_ddp_ignored(p):
                 continue
             if p.grad.device.type != "cpu":
                 p.grad = p._saved_grad
@@ -232,7 +232,7 @@ class ZeroDDP(ColoDDP):
         for p in param_order.generate():
             assert isinstance(p, ColoParameter)
-            if getattr(p, '_ddp_to_ignore', False):
+            if is_ddp_ignored(p):
                 p.data = p.data.half()
                 continue
@@ -256,7 +256,7 @@ class ZeroDDP(ColoDDP):
         self.chunk_manager.close_all_groups()
         self._cast_buffers()
-        params_list = [p for p in param_order.generate() if not getattr(p, '_ddp_to_ignore', False)]
+        params_list = [p for p in param_order.generate() if not is_ddp_ignored(p)]
         for p, fp32_p in zip(params_list, self.fp32_params):
             chunk_16 = self.chunk_manager.get_chunk(p)
             chunk_32 = self.chunk_manager.get_chunk(fp32_p)
@@ -303,7 +303,7 @@ class ZeroDDP(ColoDDP):
     def _setup_grads_ptr(self):
         for p in self.module.parameters():
-            if getattr(p, '_ddp_to_ignore', False):
+            if is_ddp_ignored(p):
                 continue
             p.grad = None

View File

@@ -1,22 +1,46 @@
-from .cuda import empty_cache, get_current_device, set_to_cuda, synchronize
 from .activation_checkpoint import checkpoint
 from .checkpointing import load_checkpoint, save_checkpoint
-from .common import (clip_grad_norm_fp32, conditional_context, copy_tensor_parallel_attributes, count_zeros_fp32,
-                     ensure_path_exists, free_port, is_dp_rank_0, is_model_parallel_parameter, is_no_pp_or_last_stage,
-                     is_tp_rank_0, is_using_ddp, is_using_pp, is_using_sequence, multi_tensor_applier,
-                     param_is_not_tensor_parallel_duplicate, print_rank_0, switch_virtual_pipeline_parallel_rank,
-                     sync_model_param, disposable)
+from .common import (
+    clip_grad_norm_fp32,
+    conditional_context,
+    copy_tensor_parallel_attributes,
+    count_zeros_fp32,
+    disposable,
+    ensure_path_exists,
+    free_port,
+    is_ddp_ignored,
+    is_dp_rank_0,
+    is_model_parallel_parameter,
+    is_no_pp_or_last_stage,
+    is_tp_rank_0,
+    is_using_ddp,
+    is_using_pp,
+    is_using_sequence,
+    multi_tensor_applier,
+    param_is_not_tensor_parallel_duplicate,
+    print_rank_0,
+    switch_virtual_pipeline_parallel_rank,
+    sync_model_param,
+)
+from .cuda import empty_cache, get_current_device, set_to_cuda, synchronize
 from .data_sampler import DataParallelSampler, get_dataloader
-from .memory import (report_memory_usage, colo_device_memory_used, colo_set_process_memory_fraction,
-                     colo_device_memory_capacity, colo_set_cpu_memory_capacity, colo_get_cpu_memory_capacity)
-from .timer import MultiTimer, Timer
+from .memory import (
+    colo_device_memory_capacity,
+    colo_device_memory_used,
+    colo_get_cpu_memory_capacity,
+    colo_set_cpu_memory_capacity,
+    colo_set_process_memory_fraction,
+    report_memory_usage,
+)
 from .tensor_detector import TensorDetector
+from .timer import MultiTimer, Timer
 __all__ = [
     'checkpoint',
     'free_port',
     'print_rank_0',
     'sync_model_param',
+    'is_ddp_ignored',
     'is_dp_rank_0',
     'is_tp_rank_0',
     'is_no_pp_or_last_stage',

View File

@@ -126,14 +126,18 @@ def is_model_parallel_parameter(p):
     return hasattr(p, IS_TENSOR_PARALLEL) and getattr(p, IS_TENSOR_PARALLEL)
+def is_ddp_ignored(p):
+    return getattr(p, '_ddp_to_ignore', False)
 def _calc_l2_norm(grads):
     # we should not
     global fused_optim
     if fused_optim is None:
         from colossalai.kernel.op_builder import FusedOptimBuilder
         fused_optim = FusedOptimBuilder().load()
     norm = 0.0
     if len(grads) > 0:
         dummy_overflow_buf = torch.cuda.IntTensor([0])
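
For reference, a minimal usage sketch of the new helper (not part of the diff; it assumes colossalai is installed and uses a throwaway nn.Linear purely for illustration):

import torch.nn as nn

from colossalai.utils import is_ddp_ignored

model = nn.Linear(4, 4)

# The helper only reads the _ddp_to_ignore flag on a parameter; setting it
# makes ColoDDP/ZeroDDP and the Gemini hooks in this commit skip that parameter.
model.bias._ddp_to_ignore = True

assert is_ddp_ignored(model.bias)
assert not is_ddp_ignored(model.weight)

# Same filtering pattern used at the call sites in this commit:
managed = [p for p in model.parameters() if not is_ddp_ignored(p)]  # weight only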

View File

@@ -8,6 +8,7 @@ import torch
 from colossalai.gemini import TensorState
 from colossalai.gemini.gemini_mgr import GeminiManager
 from colossalai.tensor.param_op_hook import ColoParamOpHook
+from colossalai.utils import is_ddp_ignored
 class TrainingPhase(Enum):
@@ -24,7 +25,7 @@ class GeminiZeROHook(ColoParamOpHook):
         self._training_phase = TrainingPhase.FORWARD
     def pre_op(self, params):
-        params = [p for p in params if not getattr(p, '_ddp_to_ignore', False)]
+        params = [p for p in params if not is_ddp_ignored(p)]
         chunks = self._chunk_manager.get_chunks(params)
         for p in params:
             self._chunk_manager.trans_tensor_state(p, TensorState.COMPUTE)
@@ -37,7 +38,7 @@ class GeminiZeROHook(ColoParamOpHook):
         self._gemini_manager.record_model_data_volume()
     def post_op(self, params):
-        params = [p for p in params if not getattr(p, '_ddp_to_ignore', False)]
+        params = [p for p in params if not is_ddp_ignored(p)]
         for p in params:
             tensor_state = TensorState.HOLD if self._training_phase == TrainingPhase.FORWARD or not p.requires_grad else TensorState.HOLD_AFTER_BWD
             self._chunk_manager.trans_tensor_state(p, tensor_state)