mirror of https://github.com/hpcaitech/ColossalAI
[zero] reorganize zero/gemini folder structure (#3424)
* [zero] refactor low-level zero folder structure
* [zero] fix legacy zero import path
* [zero] fix legacy zero import path
* [zero] remove useless import
* [zero] refactor gemini folder structure
* [zero] refactor gemini folder structure
* [zero] refactor legacy zero import path
* [zero] refactor gemini folder structure
* [zero] refactor gemini folder structure
* [zero] refactor gemini folder structure
* [zero] refactor legacy zero import path
* [zero] fix test import path
* [zero] fix test
* [zero] fix circular import
* [zero] update import

pull/3436/head
parent
b09adff724
commit
26b7aac0be
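For orientation, a hedged Python sketch of the import-path migration performed by the hunks below; the mappings are read off the changed import lines in this diff, and the commented lines are the pre-commit paths:

# before this commit                                  ->  after this commit
# from colossalai.nn.parallel import ZeroDDP, GeminiDDP, zero_model_wrapper, zero_optim_wrapper
# from colossalai.utils.model.colo_init_context import ColoInitContext
from colossalai.zero import ColoInitContext, GeminiDDP, ZeroDDP, zero_model_wrapper, zero_optim_wrapper

# from colossalai.gemini.chunk import ChunkManager
# from colossalai.gemini.gemini_mgr import GeminiManager
# from colossalai.gemini.memory_tracer import MemStats
from colossalai.zero.gemini import ChunkManager, GeminiManager
from colossalai.zero.gemini.memory_tracer import MemStats

# from colossalai.zero.sharded_model.sharded_model_v2 import ShardedModelV2
# from colossalai.zero.sharded_optim.sharded_optim_v2 import ShardedOptimizerV2
from colossalai.zero.legacy import ShardedModelV2, ShardedOptimizerV2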
@@ -14,17 +14,16 @@ from transformers.tokenization_utils_base import PreTrainedTokenizerBase
import colossalai
from colossalai.logging import get_dist_logger
from colossalai.nn.optimizer import CPUAdam, HybridAdam
from colossalai.nn.parallel import ZeroDDP, zero_model_wrapper, zero_optim_wrapper
from colossalai.nn.parallel.utils import get_static_torch_model
from colossalai.tensor import ProcessGroup, ShardSpec
from colossalai.utils import get_current_device
from colossalai.utils.model.colo_init_context import ColoInitContext

logger = get_dist_logger(__name__)
from colossalai.zero import ColoInitContext, ZeroDDP, zero_model_wrapper, zero_optim_wrapper
from colossalai.zero.gemini.utils import get_static_torch_model

from .base import Strategy
from .ddp import DDPStrategy

logger = get_dist_logger(__name__)


class ColossalAIStrategy(DDPStrategy):
    """
@@ -4,8 +4,8 @@ from typing import Optional, Set
import torch
import torch.nn as nn

from colossalai.gemini.tensor_utils import free_storage
from colossalai.nn.parallel.data_parallel import _cast_float
from colossalai.zero.legacy.gemini.tensor_utils import free_storage

from .region_manager import RegionManager
from .util import GlobalRuntimeInfo
@@ -1,7 +1,10 @@
from typing import List, Dict, Tuple
from typing import Dict, List, Tuple

import torch
from torch.fx import Node
from colossalai.gemini.tensor_utils import alloc_storage, free_storage

from colossalai.zero.legacy.gemini.tensor_utils import alloc_storage, free_storage


class Region:
    """
@@ -52,15 +55,13 @@
        Map the parameters in the region to a contiguous memory space.
        """

        self.fp16_data = torch.zeros(
            self.param_num, dtype=torch.half, device='cuda')
        self.fp16_data = torch.zeros(self.param_num, dtype=torch.half, device='cuda')
        offset = 0
        for param in self.fp16_params:
            param.data = param.data.cuda()
            p_num = param.data.numel()
            self.fp16_data[offset:offset + p_num].copy_(param.data.flatten())
            param.data = self.fp16_data[offset:offset +
                                        p_num].view(param.data.shape)
            param.data = self.fp16_data[offset:offset + p_num].view(param.data.shape)
            self.param_to_range[param] = (offset, offset + p_num)
            offset += p_num
@@ -141,4 +142,4 @@
    def __update_params_ptr(self) -> None:
        for param in self.fp16_params:
            begin, end = self.param_to_range[param]
            param.data = self.fp16_data[begin:end].view(param.data.shape)
            param.data = self.fp16_data[begin:end].view(param.data.shape)
@@ -14,12 +14,12 @@ from torch.utils.data.distributed import DistributedSampler

from colossalai.checkpoint_io import CheckpointIO, GeneralCheckpointIO
from colossalai.cluster import DistCoordinator
from colossalai.gemini.memory_tracer import MemStats
from colossalai.interface import ModelWrapper, OptimizerWrapper
from colossalai.nn.parallel import GeminiDDP, zero_model_wrapper, zero_optim_wrapper
from colossalai.tensor.colo_parameter import ColoParameter
from colossalai.utils import get_current_device
from colossalai.utils.model.colo_init_context import _convert_to_coloparam
from colossalai.zero import GeminiDDP, zero_model_wrapper, zero_optim_wrapper
from colossalai.zero.gemini.colo_init_context import _convert_to_coloparam
from colossalai.zero.gemini.memory_tracer import MemStats

from .plugin_base import Plugin
@@ -10,8 +10,8 @@ from torch.nn.modules.loss import _Loss

from colossalai.engine.gradient_handler import BaseGradientHandler
from colossalai.engine.schedule import BaseSchedule, InterleavedPipelineSchedule, NonPipelineSchedule, PipelineSchedule
from colossalai.gemini.ophooks import BaseOpHook, register_ophooks_recursively
from colossalai.logging import get_dist_logger
from colossalai.zero.legacy.gemini import BaseOpHook, register_ophooks_recursively


class Engine:
@@ -157,7 +157,7 @@ class PipelineSchedule(BaseSchedule):
        return self._move_to_device(mciro_batch_data)

    def pre_processing(self, engine):
        from colossalai.zero.sharded_model.sharded_model_v2 import ShardedModelV2
        from colossalai.zero.legacy import ShardedModelV2

        # TODO: remove this after testing new zero with pipeline parallelism
        model = engine.model
@@ -1,9 +0,0 @@
from .chunk import ChunkManager, TensorInfo, TensorState, search_chunk_configuration
from .gemini_mgr import GeminiManager
from .stateful_tensor_mgr import StatefulTensorMgr
from .tensor_placement_policy import TensorPlacementPolicyFactory

__all__ = [
    'StatefulTensorMgr', 'TensorPlacementPolicyFactory', 'GeminiManager', 'TensorInfo', 'TensorState', 'ChunkManager',
    'search_chunk_configuration'
]
@@ -29,13 +29,12 @@ from colossalai.engine.schedule import (
    PipelineSchedule,
    get_tensor_shape,
)
from colossalai.gemini.ophooks import BaseOpHook
from colossalai.logging import get_dist_logger
from colossalai.nn.optimizer.colossalai_optimizer import ColossalaiOptimizer
from colossalai.utils import get_current_device, is_using_ddp, is_using_pp, is_using_sequence, sync_model_param
from colossalai.utils.moe import sync_moe_model_param
from colossalai.zero import convert_to_zero_v2
from colossalai.zero.sharded_optim.sharded_optim_v2 import ShardedOptimizerV2
from colossalai.zero.legacy import ShardedOptimizerV2, convert_to_zero_v2
from colossalai.zero.legacy.gemini.ophooks import BaseOpHook


def get_default_parser():
@@ -9,7 +9,7 @@ import torch.nn as nn
from colossalai.context import ParallelMode, seed
from colossalai.context.moe_context import MOE_CONTEXT
from colossalai.utils import get_current_device
from colossalai.zero.init_ctx import no_shard_zero_decrator
from colossalai.zero.legacy.init_ctx import no_shard_zero_decrator


class MoeExperts(nn.Module):
@@ -18,7 +18,7 @@ from colossalai.nn.layer.moe.experts import Experts, MoeExperts
from colossalai.nn.layer.moe.routers import MoeRouter, Top1Router, Top2Router
from colossalai.nn.layer.moe.utils import NormalNoiseGenerator, UniformNoiseGenerator
from colossalai.utils import get_current_device
from colossalai.zero.init_ctx import no_shard_zero_context, no_shard_zero_decrator
from colossalai.zero.legacy.init_ctx import no_shard_zero_context, no_shard_zero_decrator


@no_shard_zero_decrator(is_replicated=True)
@@ -1,15 +0,0 @@
from typing import Any

import torch

from colossalai.nn.optimizer import HybridAdam
from colossalai.nn.optimizer.zero_optimizer import ZeroOptimizer

__all__ = ['GeminiAdamOptimizer']


class GeminiAdamOptimizer(ZeroOptimizer):

    def __init__(self, model: torch.nn.Module, **defaults: Any) -> None:
        optimizer = HybridAdam(model.parameters(), **defaults)
        super().__init__(optimizer, model, **defaults)
@@ -1,5 +1,5 @@
from .data_parallel import ColoDDP, ZeroDDP
from .gemini_parallel import GeminiDDP
from .zero_wrapper import zero_model_wrapper, zero_optim_wrapper
from .data_parallel import ColoDDP

__all__ = ['ColoDDP', 'ZeroDDP', 'GeminiDDP', 'zero_model_wrapper', 'zero_optim_wrapper']
__all__ = [
    'ColoDDP',
]
@@ -1,31 +1,14 @@
import itertools
from collections import OrderedDict
from functools import partial
from typing import Dict, Iterable, List, Optional, Set
from typing import Iterable, Optional, Set

import torch
import torch.distributed as dist
import torch.nn as nn

from colossalai.gemini.chunk import Chunk, ChunkManager, TensorState
from colossalai.gemini.gemini_mgr import GeminiManager
from colossalai.gemini.memory_tracer import OrderedParamGenerator
from colossalai.logging import get_dist_logger
from colossalai.nn.parallel.utils import get_temp_total_chunk_on_cuda
from colossalai.tensor import ProcessGroup as ColoProcessGroup
from colossalai.tensor import ReplicaSpec
from colossalai.tensor.colo_parameter import ColoParameter, ColoTensor, ColoTensorSpec
from colossalai.tensor.param_op_hook import ColoParamOpHookManager
from colossalai.utils import get_current_device, is_ddp_ignored
from colossalai.zero.utils.gemini_hook import GeminiZeROHook
from colossalai.utils import is_ddp_ignored

from .reducer import Reducer
from .utils import get_static_torch_model

try:
    from torch.nn.modules.module import _EXTRA_STATE_KEY_SUFFIX, _IncompatibleKeys
except ImportError:
    _EXTRA_STATE_KEY_SUFFIX = '_extra_state'


def free_storage(data: torch.Tensor) -> None:
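The hunk above stops at the new free_storage signature. A hedged sketch of what such a storage-freeing helper typically looks like in this codebase (assumed for orientation, not copied from this revision):

def free_storage(data: torch.Tensor) -> None:
    """Release a tensor's underlying storage while keeping the tensor object alive."""
    if data.storage().size() > 0:
        # chunk-managed tensors own their whole storage, so the offset is expected to be 0
        assert data.storage_offset() == 0
        data.storage().resize_(0)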
@ -189,507 +172,3 @@ class ColoDDP(torch.nn.Module):
|
|||
|
||||
def load_state_dict(self, state_dict: 'OrderedDict[str, torch.Tensor]', strict: bool = True):
|
||||
return self.module.load_state_dict(state_dict, strict)
|
||||
|
||||
|
||||
class ZeroDDP(ColoDDP):
|
||||
"""ZeRO DDP for ColoTensor.
|
||||
Warning: Nested ZeroDDP is not supported now.
|
||||
It is designed to be used with ChunkManager and GeminiManager.
|
||||
For more details, see the API reference of ``ChunkManager`` and ``GeminiManager``.
|
||||
|
||||
Args:
|
||||
module (torch.nn.Module): Module to apply ZeRO-DP.
|
||||
gemini_manager (GeminiManager): Manages the chunk manager and heterogeneous momery space.
|
||||
For more details, see the API reference of ``GeminiManager``.
|
||||
pin_memory (bool): Chunks on CPU Memory use pin-memory.
|
||||
force_outputs_fp32 (bool): If set to True, outputs will be fp32. Otherwise, outputs will be fp16.
|
||||
Defaults to False.
|
||||
strict_ddp_mode (bool): If set to True, there is no tensor sharding, each tensor is replicated.
|
||||
Defaults to False. Users can set it to True, when they clearly know that they only need DDP.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
module: torch.nn.Module,
|
||||
gemini_manager: GeminiManager,
|
||||
pin_memory: bool = False,
|
||||
force_outputs_fp32: bool = False,
|
||||
strict_ddp_mode: bool = False) -> None:
|
||||
super().__init__(module, process_group=ColoProcessGroup())
|
||||
self.gemini_manager = gemini_manager
|
||||
self.chunk_manager: ChunkManager = gemini_manager.chunk_manager
|
||||
self.force_outputs_fp32 = force_outputs_fp32
|
||||
self.param_op_hook = GeminiZeROHook(gemini_manager)
|
||||
self.fp32_params: List[ColoTensor] = list()
|
||||
self.fp16_params: List[ColoParameter] = list()
|
||||
self.overflow_counter = 0
|
||||
self.grads_device: Dict[torch.Tensor, torch.device] = dict()
|
||||
self.param2name: Dict[nn.Parameter, str] = dict()
|
||||
self.name2param: Dict[str, nn.Parameter] = dict()
|
||||
|
||||
self._cast_buffers()
|
||||
self._logger = get_dist_logger()
|
||||
|
||||
if self.gemini_manager._premade_memstats_:
|
||||
# build chunk in param runtime visited order.
|
||||
param_order = self.gemini_manager.memstats()._param_runtime_order
|
||||
else:
|
||||
# build chunk in param initialized order.
|
||||
# Note: in this way, it can not get filter unused params during runtime.
|
||||
param_order = OrderedParamGenerator()
|
||||
for p in module.parameters():
|
||||
param_order.append(p)
|
||||
|
||||
self._init_chunks(param_order=param_order,
|
||||
strict_ddp_mode=strict_ddp_mode,
|
||||
cpu_offload=self.gemini_manager.policy_name != 'cuda',
|
||||
pin_memory=pin_memory)
|
||||
|
||||
for name, param in module.named_parameters():
|
||||
self.param2name[param] = name
|
||||
for m_name, m_var in module.named_modules():
|
||||
for p_name, p_var in m_var.named_parameters(recurse=False):
|
||||
param_name = m_name + '.' + p_name if m_name else p_name
|
||||
self.name2param[param_name] = p_var
|
||||
|
||||
def _post_forward(self):
|
||||
"""This function is only triggered for inference.
|
||||
"""
|
||||
access_list = list(self.chunk_manager.accessed_chunks)
|
||||
# we need to scatter all accessed chunks and move them to their original places
|
||||
for chunk in access_list:
|
||||
if chunk.keep_gathered:
|
||||
self.chunk_manager.fake_release_chunk(chunk)
|
||||
else:
|
||||
assert chunk.can_release
|
||||
self.chunk_manager.release_chunk(chunk)
|
||||
first_param = next(iter(chunk.tensors_info))
|
||||
self.chunk_manager.move_chunk(chunk, self.grads_device[first_param])
|
||||
assert self.chunk_manager.accessed_mem == 0
|
||||
# reset all recorded attributes
|
||||
self.gemini_manager.reset_attributes()
|
||||
|
||||
def forward(self, *args, **kwargs):
|
||||
# check whether we are in a inference mode
|
||||
grad_flag = torch.is_grad_enabled()
|
||||
if not grad_flag:
|
||||
assert not self.gemini_manager.need_warmup or not self.gemini_manager.is_warmup(
|
||||
), "You should run a completed iteration as your warmup iter"
|
||||
|
||||
args, kwargs = _cast_float(args, torch.half), _cast_float(kwargs, torch.half)
|
||||
self.module.zero_grad(set_to_none=True)
|
||||
self.gemini_manager.pre_iter(*args)
|
||||
with ColoParamOpHookManager.use_hooks(self.param_op_hook):
|
||||
outputs = self.module(*args, **kwargs)
|
||||
# scatter chunks in the inference mode
|
||||
if not grad_flag:
|
||||
self._post_forward()
|
||||
|
||||
if self.force_outputs_fp32:
|
||||
return _cast_float(outputs, torch.float)
|
||||
return outputs
|
||||
|
||||
def _setup_grads_ptr(self):
|
||||
for p in self.module.parameters():
|
||||
if is_ddp_ignored(p):
|
||||
continue
|
||||
p.grad = None
|
||||
|
||||
def _pre_backward(self):
|
||||
# set a visit label for all parameters
|
||||
# the label is used to check whether the parameter is correctly reduced
|
||||
for param in self.param2name:
|
||||
if not is_ddp_ignored(param):
|
||||
setattr(param, "_gemini_reduced", False)
|
||||
|
||||
def _post_backward(self):
|
||||
if self.chunk_manager.accessed_mem != 0:
|
||||
error_params = ["Reduction failed at followed parameters:"]
|
||||
for param in self.param2name:
|
||||
if not is_ddp_ignored(param) and not getattr(param, "_gemini_reduced"):
|
||||
error_params.append(self.param2name[param])
|
||||
error_str = "\n\t".join(error_params)
|
||||
raise RuntimeError("ZERO DDP error: the synchronization of gradients doesn't exit properly.",
|
||||
"The most possible reason is that the model is not compatible with ZeroDDP.\n",
|
||||
f"{error_str}")
|
||||
self._setup_grads_ptr()
|
||||
self._logger.debug(
|
||||
f'comp cuda demand time: {self.gemini_manager._comp_cuda_demand_time}, layout time: {self.gemini_manager._layout_time}, evict time: {self.gemini_manager._evict_time}, CPU->CUDA vol: {self.gemini_manager._h2d_volume}B, CUDA->CPU vol: {self.gemini_manager._d2h_volume}'
|
||||
)
|
||||
self.gemini_manager.post_iter()
|
||||
|
||||
def backward(self, loss: torch.Tensor):
|
||||
self._pre_backward()
|
||||
with self.param_op_hook.switch_to_backward(), ColoParamOpHookManager.use_hooks(self.param_op_hook):
|
||||
loss.backward()
|
||||
self._post_backward()
|
||||
|
||||
def backward_by_grad(self, tensor, grad):
|
||||
with self.param_op_hook.switch_to_backward(), ColoParamOpHookManager.use_hooks(self.param_op_hook):
|
||||
torch.autograd.backward(tensor, grad)
|
||||
self._post_backward()
|
||||
|
||||
def grad_handle(self, p, grad):
|
||||
empty_grad = torch.empty_like(grad)
|
||||
free_storage(empty_grad)
|
||||
with torch._C.DisableTorchFunction():
|
||||
chunk = self.chunk_manager.get_chunk(p)
|
||||
if chunk.tensors_info[p].state != TensorState.HOLD_AFTER_BWD:
|
||||
raise RuntimeError(f"Parameter `{self.param2name[p]}` failed at the gradient reduction. "
|
||||
"Some unsupported torch function is operated upon this parameter.")
|
||||
self.chunk_manager.trans_tensor_state(p, TensorState.READY_FOR_REDUCE)
|
||||
chunk.copy_tensor_to_chunk_slice(p, grad)
|
||||
reduced = self.chunk_manager.reduce_chunk(chunk)
|
||||
if reduced:
|
||||
if chunk.is_gathered:
|
||||
chunk.cuda_global_chunk.div_(chunk.pg_size)
|
||||
else:
|
||||
chunk.cuda_shard.div_(chunk.pg_size)
|
||||
# check overflow elements
|
||||
self.overflow_counter += chunk.has_inf_or_nan
|
||||
# record l2 norm for gradient clipping
|
||||
if chunk.l2_norm_flag:
|
||||
chunk.set_l2_norm()
|
||||
self.chunk_manager.move_chunk(chunk, self.grads_device[p], force_copy=True)
|
||||
return empty_grad
|
||||
|
||||
def zero_grad(self, set_to_none: bool = False) -> None:
|
||||
self.module.zero_grad(set_to_none=True)
|
||||
|
||||
def set_chunk_grad_device(self, chunk: Chunk, device: torch.device) -> None:
|
||||
for tensor in chunk.get_tensors():
|
||||
self.grads_device[tensor] = device
|
||||
|
||||
def state_dict(self, destination=None, prefix='', keep_vars=False, only_rank_0: bool = True):
|
||||
"""Returns a dictionary containing a whole state of the module.
|
||||
|
||||
Both parameters and persistent buffers (e.g. running averages) are included.
|
||||
Keys are corresponding parameter and buffer names.
|
||||
Parameters and buffers set to ``None`` are not included.
|
||||
|
||||
Warning: The non strict state dict would ignore the parameters if the tensors of the parameters
|
||||
are shared with other parameters which have been included in the dictionary.
|
||||
When you need to load the state dict, you should set the argument `strict` to False.
|
||||
|
||||
Returns:
|
||||
dict:
|
||||
a dictionary containing a whole state of the module
|
||||
"""
|
||||
if destination is None:
|
||||
destination = OrderedDict()
|
||||
destination._metadata = OrderedDict()
|
||||
destination._metadata[prefix[:-1]] = local_metadata = dict(version=self._version)
|
||||
self._save_to_state_dict(destination, prefix, keep_vars, only_rank_0)
|
||||
|
||||
for hook in self._state_dict_hooks.values():
|
||||
hook_result = hook(self, destination, prefix, local_metadata)
|
||||
if hook_result is not None:
|
||||
destination = hook_result
|
||||
return destination
|
||||
|
||||
def _get_param_to_save_data(self, param_list: List[torch.nn.Parameter], only_rank_0: bool) -> Dict:
|
||||
"""
|
||||
get param content from chunks.
|
||||
|
||||
Args:
|
||||
param_list (_type_): a list of torch.nn.Parameters
|
||||
only_rank_0 (_type_): _description_
|
||||
|
||||
Returns:
|
||||
Dict: a dict whose key is param name and value is param with correct payload
|
||||
"""
|
||||
# save parameters
|
||||
param_to_save_data = dict()
|
||||
chunk_list = self.chunk_manager.get_chunks(param_list)
|
||||
for chunk in chunk_list:
|
||||
temp_chunk = get_temp_total_chunk_on_cuda(chunk)
|
||||
|
||||
for tensor, tensor_info in chunk.tensors_info.items():
|
||||
record_tensor = torch.empty([0])
|
||||
record_flag = (not only_rank_0) | (dist.get_rank(chunk.torch_pg) == 0)
|
||||
if record_flag:
|
||||
record_tensor = temp_chunk[tensor_info.offset:tensor_info.end].view(tensor.shape).cpu()
|
||||
|
||||
assert tensor not in param_to_save_data
|
||||
param_to_save_data[tensor] = record_tensor
|
||||
|
||||
del temp_chunk
|
||||
return param_to_save_data
|
||||
|
||||
def _save_to_state_dict(self, destination, prefix, keep_vars, only_rank_0=True):
|
||||
r"""Saves module state to `destination` dictionary, containing a state
|
||||
of the module, but not its descendants. This is called on every
|
||||
submodule in :meth:`~torch.nn.Module.state_dict`.
|
||||
|
||||
In rare cases, subclasses can achieve class-specific behavior by
|
||||
overriding this method with custom logic.
|
||||
|
||||
Args:
|
||||
destination (dict): a dict where state will be stored
|
||||
prefix (str): the prefix for parameters and buffers used in this
|
||||
module
|
||||
"""
|
||||
assert keep_vars is False, "`state_dict` with parameter, `keep_vars=True`, is not supported now."
|
||||
|
||||
# get copies of fp32 parameters in CPU
|
||||
param_to_save_data = self._get_param_to_save_data(self.fp32_params, only_rank_0)
|
||||
# get the mapping between copies and fp16 parameters
|
||||
p_mapping = dict()
|
||||
for p, fp32_p in zip(self.fp16_params, self.fp32_params):
|
||||
name = self.param2name[p]
|
||||
assert fp32_p in param_to_save_data, "Parameter '{}' is neglected in the chunk list".format(name)
|
||||
record_parameter = param_to_save_data[fp32_p]
|
||||
p_mapping[p] = record_parameter
|
||||
for name, param in self.name2param.items():
|
||||
if param is not None:
|
||||
if is_ddp_ignored(param):
|
||||
# deal with ddp ignored parameters
|
||||
destination[prefix + name] = param if keep_vars else param.detach()
|
||||
else:
|
||||
destination[prefix + name] = p_mapping[param]
|
||||
del p_mapping
|
||||
del param_to_save_data
|
||||
|
||||
# save all buffers
|
||||
for name, buf in self.named_buffers():
|
||||
if buf is not None and name not in self._non_persistent_buffers_set:
|
||||
destination[prefix + name] = buf if keep_vars else buf.detach()
|
||||
# save extra states
|
||||
extra_state_key = prefix + _EXTRA_STATE_KEY_SUFFIX
|
||||
if getattr(self.__class__, "get_extra_state",
|
||||
torch.nn.Module.get_extra_state) is not torch.nn.Module.get_extra_state:
|
||||
destination[extra_state_key] = self.get_extra_state()
|
||||
|
||||
def load_state_dict(self, state_dict: 'OrderedDict[str, torch.Tensor]', strict: bool = True):
|
||||
r"""Copies parameters and buffers from :attr:`state_dict` into
|
||||
this module and its descendants. If :attr:`strict` is ``True``, then
|
||||
the keys of :attr:`state_dict` must exactly match the keys returned
|
||||
by this module's :meth:`~torch.nn.Module.state_dict` function.
|
||||
|
||||
Args:
|
||||
state_dict (dict): a dict containing parameters and
|
||||
persistent buffers.
|
||||
strict (bool, optional): whether to strictly enforce that the keys
|
||||
in :attr:`state_dict` match the keys returned by this module's
|
||||
:meth:`~torch.nn.Module.state_dict` function. Default: ``True``
|
||||
|
||||
Returns:
|
||||
``NamedTuple`` with ``missing_keys`` and ``unexpected_keys`` fields:
|
||||
* **missing_keys** is a list of str containing the missing keys
|
||||
* **unexpected_keys** is a list of str containing the unexpected keys
|
||||
|
||||
Note:
|
||||
If a parameter or buffer is registered as ``None`` and its corresponding key
|
||||
exists in :attr:`state_dict`, :meth:`load_state_dict` will raise a
|
||||
``RuntimeError``.
|
||||
"""
|
||||
missing_keys: List[str] = []
|
||||
unexpected_keys: List[str] = []
|
||||
error_msgs: List[str] = []
|
||||
|
||||
# copy state_dict so _load_from_state_dict can modify it
|
||||
metadata = getattr(state_dict, '_metadata', None)
|
||||
state_dict = state_dict.copy()
|
||||
if metadata is not None:
|
||||
# mypy isn't aware that "_metadata" exists in state_dict
|
||||
state_dict._metadata = metadata # type: ignore[attr-defined]
|
||||
|
||||
prefix = ''
|
||||
local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
|
||||
self._load_from_state_dict(state_dict, prefix, local_metadata, True, missing_keys, unexpected_keys, error_msgs)
|
||||
|
||||
if strict:
|
||||
if len(unexpected_keys) > 0:
|
||||
error_msgs.insert(
|
||||
0, 'Unexpected key(s) in state_dict: {}. '.format(', '.join(
|
||||
'"{}"'.format(k) for k in unexpected_keys)))
|
||||
if len(missing_keys) > 0:
|
||||
error_msgs.insert(
|
||||
0, 'Missing key(s) in state_dict: {}. '.format(', '.join('"{}"'.format(k) for k in missing_keys)))
|
||||
|
||||
if len(error_msgs) > 0:
|
||||
raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
|
||||
self.__class__.__name__, "\n\t".join(error_msgs)))
|
||||
return _IncompatibleKeys(missing_keys, unexpected_keys)
|
||||
|
||||
def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys,
|
||||
error_msgs):
|
||||
r"""Copies parameters and buffers from :attr:`state_dict` into only
|
||||
this module, but not its descendants. This is called on every submodule
|
||||
in :meth:`~torch.nn.Module.load_state_dict`. Metadata saved for this
|
||||
module in input :attr:`state_dict` is provided as :attr:`local_metadata`.
|
||||
For state dicts without metadata, :attr:`local_metadata` is empty.
|
||||
Subclasses can achieve class-specific backward compatible loading using
|
||||
the version number at `local_metadata.get("version", None)`.
|
||||
|
||||
.. note::
|
||||
:attr:`state_dict` is not the same object as the input
|
||||
:attr:`state_dict` to :meth:`~torch.nn.Module.load_state_dict`. So
|
||||
it can be modified.
|
||||
|
||||
Args:
|
||||
state_dict (dict): a dict containing parameters and
|
||||
persistent buffers.
|
||||
prefix (str): the prefix for parameters and buffers used in this
|
||||
module
|
||||
local_metadata (dict): a dict containing the metadata for this module.
|
||||
See
|
||||
strict (bool): whether to strictly enforce that the keys in
|
||||
:attr:`state_dict` with :attr:`prefix` match the names of
|
||||
parameters and buffers in this module
|
||||
missing_keys (list of str): if ``strict=True``, add missing keys to
|
||||
this list
|
||||
unexpected_keys (list of str): if ``strict=True``, add unexpected
|
||||
keys to this list
|
||||
error_msgs (list of str): error messages should be added to this
|
||||
list, and will be reported together in
|
||||
:meth:`~torch.nn.Module.load_state_dict`
|
||||
"""
|
||||
for hook in self._load_state_dict_pre_hooks.values():
|
||||
hook(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)
|
||||
|
||||
persistent_buffers = {k: v for k, v in self.named_buffers() if k not in self._non_persistent_buffers_set}
|
||||
local_name_params = itertools.chain(self.named_parameters(), persistent_buffers.items())
|
||||
local_state = {k: v for k, v in local_name_params if v is not None}
|
||||
|
||||
def load(param_name, dest_tensor, copy_func):
|
||||
state_key = prefix + param_name
|
||||
if state_key in state_dict:
|
||||
input_param = state_dict[state_key]
|
||||
# Backward compatibility: loading 1-dim tensor from 0.3.* to version 0.4+
|
||||
if len(dest_tensor.shape) == 0 and len(input_param.shape) == 1:
|
||||
input_param = input_param[0]
|
||||
if input_param.shape != dest_tensor.shape:
|
||||
# local shape should match the one in checkpoint
|
||||
error_msgs.append('size mismatch for {}: copying a param with shape {} from checkpoint, '
|
||||
'the shape in current model is {}.'.format(state_key, input_param.shape,
|
||||
dest_tensor.shape))
|
||||
return
|
||||
try:
|
||||
with torch.no_grad():
|
||||
copy_func(input_param)
|
||||
except Exception as ex:
|
||||
error_msgs.append('While copying the parameter named "{}", '
|
||||
'whose dimensions in the model are {} and '
|
||||
'whose dimensions in the checkpoint are {}, '
|
||||
'an exception occurred : {}.'.format(state_key, dest_tensor.size(),
|
||||
input_param.size(), ex.args))
|
||||
elif strict:
|
||||
missing_keys.append(state_key)
|
||||
|
||||
def load_fp32_parameter(chunk_slice, data):
|
||||
chunk_slice.copy_(data.flatten())
|
||||
|
||||
for name, param in self.named_parameters():
|
||||
if is_ddp_ignored(param):
|
||||
# deal with ddp ignored parameters
|
||||
load(name, param, param.copy_)
|
||||
|
||||
fp32_to_name = dict()
|
||||
for p, fp32_p in zip(self.fp16_params, self.fp32_params):
|
||||
if p is not None:
|
||||
name = self.param2name[p]
|
||||
fp32_to_name[fp32_p] = name
|
||||
|
||||
chunk_list = self.chunk_manager.get_chunks(self.fp32_params)
|
||||
for chunk in chunk_list:
|
||||
temp_chunk = get_temp_total_chunk_on_cuda(chunk)
|
||||
|
||||
for tensor, tensor_info in chunk.tensors_info.items():
|
||||
parameter_name = fp32_to_name[tensor]
|
||||
parameter_slice = temp_chunk[tensor_info.offset:tensor_info.end]
|
||||
load(parameter_name, tensor, partial(load_fp32_parameter, parameter_slice))
|
||||
|
||||
if chunk.is_gathered:
|
||||
chunk.cuda_global_chunk.copy_(temp_chunk)
|
||||
elif chunk.cuda_shard is not None:
|
||||
chunk.cuda_shard.copy_(temp_chunk[chunk.shard_begin:chunk.shard_end])
|
||||
else:
|
||||
chunk.cpu_shard.copy_(temp_chunk[chunk.shard_begin:chunk.shard_end])
|
||||
|
||||
del temp_chunk
|
||||
|
||||
for chunk_32 in chunk_list:
|
||||
chunk_16 = chunk_32.paired_chunk
|
||||
assert chunk_16 is not None
|
||||
chunk_16.optim_update()
|
||||
|
||||
for name, buf in persistent_buffers.items():
|
||||
if buf is not None:
|
||||
load(name, buf, buf.copy_)
|
||||
|
||||
extra_state_key = prefix + _EXTRA_STATE_KEY_SUFFIX
|
||||
if getattr(self.__class__, "set_extra_state",
|
||||
torch.nn.Module.set_extra_state) is not torch.nn.Module.set_extra_state:
|
||||
if extra_state_key in state_dict:
|
||||
self.set_extra_state(state_dict[extra_state_key])
|
||||
elif strict:
|
||||
missing_keys.append(extra_state_key)
|
||||
elif strict and (extra_state_key in state_dict):
|
||||
unexpected_keys.append(extra_state_key)
|
||||
|
||||
if strict:
|
||||
for key in state_dict.keys():
|
||||
if key.startswith(prefix) and key != extra_state_key:
|
||||
input_name = key[len(prefix):]
|
||||
if input_name not in local_state:
|
||||
unexpected_keys.append(key)
|
||||
|
||||
def _init_chunks(self, param_order, strict_ddp_mode: bool, cpu_offload: bool, pin_memory: bool):
|
||||
ddp_pg = ColoProcessGroup()
|
||||
for p in param_order.generate():
|
||||
assert isinstance(p, ColoParameter)
|
||||
|
||||
# gather sharded parameters in the strict ddp mode
|
||||
if strict_ddp_mode:
|
||||
if not p.is_replicate():
|
||||
p.set_dist_spec(ReplicaSpec())
|
||||
p.set_process_group(pg=ddp_pg)
|
||||
|
||||
# ignore the parameters with no gradient
|
||||
if not p.requires_grad:
|
||||
self.set_params_to_ignore([p])
|
||||
|
||||
# move ignored parameters to CUDA
|
||||
if is_ddp_ignored(p):
|
||||
p.data = p.data.to(device=get_current_device(), dtype=torch.float16)
|
||||
continue
|
||||
|
||||
# create a fp32 parameter
|
||||
fp32_data = p.data.float()
|
||||
fp32_p = ColoTensor(fp32_data, spec=ColoTensorSpec(p.process_group))
|
||||
# create a fp16 parameter
|
||||
p.data = p.data.half()
|
||||
|
||||
# register the fp16 parameter and fp32 parameter in the chunk manager
|
||||
dp_world_size = p.process_group.dp_world_size()
|
||||
self.chunk_manager.register_tensor(tensor=p,
|
||||
group_type='fp16_param',
|
||||
config_key=dp_world_size,
|
||||
cpu_offload=cpu_offload,
|
||||
pin_memory=pin_memory)
|
||||
self.chunk_manager.register_tensor(tensor=fp32_p,
|
||||
group_type='fp32_param',
|
||||
config_key=dp_world_size,
|
||||
cpu_offload=cpu_offload,
|
||||
pin_memory=pin_memory)
|
||||
|
||||
self.fp16_params.append(p)
|
||||
self.fp32_params.append(fp32_p)
|
||||
self.grads_device[p] = self.gemini_manager.default_device
|
||||
|
||||
self.chunk_manager.close_all_groups()
|
||||
|
||||
for p, fp32_p in zip(self.fp16_params, self.fp32_params):
|
||||
chunk_16 = self.chunk_manager.get_chunk(p)
|
||||
chunk_32 = self.chunk_manager.get_chunk(fp32_p)
|
||||
chunk_32.init_pair(chunk_16)
|
||||
|
||||
# keep gathered chunks are in CUDA
|
||||
if chunk_16.keep_gathered:
|
||||
self.grads_device[p] = get_current_device()
|
||||
|
||||
def _cast_buffers(self):
|
||||
for buffer in self.module.buffers():
|
||||
buffer.data = buffer.cuda()
|
||||
if torch.is_floating_point(buffer):
|
||||
buffer.data = buffer.half()
|
||||
|
|
|
@@ -1,63 +0,0 @@
from typing import Optional

import torch

from colossalai.gemini.chunk import init_chunk_manager
from colossalai.gemini.gemini_mgr import GeminiManager
from colossalai.gemini.memory_tracer import MemStats

from .data_parallel import ZeroDDP


class GeminiDDP(ZeroDDP):

    def __init__(self,
                 module: torch.nn.Module,
                 device: torch.device,
                 placement_policy: str = "cpu",
                 pin_memory: bool = False,
                 force_outputs_fp32: bool = False,
                 strict_ddp_mode: bool = False,
                 search_range_mb: int = 32,
                 hidden_dim: Optional[int] = None,
                 min_chunk_size_mb: float = 32,
                 memstats: Optional[MemStats] = None) -> None:
        """
        A torch.Module warpper using ZeRO-DP and Genimi.
        ZeRO is for parallel. Gemini is for memory management.
        WARNING: The class will modify the module inline!

        Example:
            model is initialized under the context of ColoInitContext
            >>> model = GeminiDDP(model, torch.cuda.current_device(), "cuda")
            >>> logits = model(x)
            >>> loss = criterion(logits, labels)
            >>> model.backward(loss)

        Args:
            module (torch.nn.Module): the model to be wrapped.
            device (torch.device): device to place the model.
            placement_policy (str, optional): "cpu", "cuda", "auto". Defaults to "cpu".
            pin_memory (bool, optional): use pin memory on CPU. Defaults to False.
            force_outputs_fp32 (bool, optional): force outputs are fp32. Defaults to False.
            search_range_mb (int, optional): chunk size searching range in MegaByte. Defaults to 32.
            hidden_dim (int, optional): the hidden dimension of DNN.
                Users can provide this argument to speed up searching.
                If users do not know this argument before training, it is ok. We will use a default value 1024.
            min_chunk_size_mb (float, optional): the minimum chunk size in MegaByte.
                If the aggregate size of parameters is still samller than the minimum chunk size,
                all parameters will be compacted into one small chunk.
            memstats (MemStats, optional) the memory statistics collector by a runtime memory tracer.
        """
        # some ugly hotfix for the compatibility with Lightning
        if search_range_mb is None:
            search_range_mb = 32

        chunk_manager = init_chunk_manager(model=module,
                                           init_device=device,
                                           hidden_dim=hidden_dim,
                                           search_range_mb=search_range_mb,
                                           min_chunk_size_mb=min_chunk_size_mb,
                                           strict_ddp_flag=strict_ddp_mode)
        gemini_manager = GeminiManager(placement_policy, chunk_manager, memstats)
        super().__init__(module, gemini_manager, pin_memory, force_outputs_fp32, strict_ddp_mode)
@@ -1,41 +1,16 @@
from typing import Tuple
from .gemini import (
    ColoInitContext,
    GeminiAdamOptimizer,
    GeminiDDP,
    ZeroDDP,
    ZeroOptimizer,
    get_static_torch_model,
    post_process_colo_init_ctx,
)
from .low_level import LowLevelZeroOptimizer
from .wrapper import zero_model_wrapper, zero_optim_wrapper

import torch
import torch.nn as nn

from colossalai.logging import get_dist_logger
from colossalai.zero.sharded_model.sharded_model_v2 import ShardedModelV2
from colossalai.zero.sharded_optim import LowLevelZeroOptimizer, ShardedOptimizerV2

from ..nn.optimizer.zero_optimizer import ZeroOptimizer


def convert_to_zero_v2(model: nn.Module, optimizer: torch.optim.Optimizer, model_config,
                       optimizer_config) -> Tuple[ShardedModelV2, ShardedOptimizerV2]:
    """
    A helper function to integrate the model and optimizer with ZeRO optimizer and off-loading

    :param model: Your model object
    :type model: :class:`torch.nn.Module`
    :param optimizer_config: Your optimizer object
    :type optimizer_config: :class:`dict`

    :return: (model, optimizer)
    :rtype: Tuple
    """

    logger = get_dist_logger('convert_to_zero_v2')

    logger.info(f'optimizer_config is {optimizer_config}', ranks=[0])
    if optimizer_config is None:
        optimizer_config = dict()
    logger.info(f'model_config is {model_config}', ranks=[0])
    if model_config is None:
        model_config = dict()

    zero_model = ShardedModelV2(model, **model_config)
    zero_optimizer = ShardedOptimizerV2(zero_model, optimizer, **optimizer_config)
    return zero_model, zero_optimizer


__all__ = ['convert_to_zero_v2', 'LowLevelZeroOptimizer', 'ShardedModelV2', 'ShardedOptimizerV2', 'ZeroOptimizer']
__all__ = [
    'ZeroDDP', 'GeminiDDP', 'ZeroOptimizer', 'GeminiAdamOptimizer', 'zero_model_wrapper', 'zero_optim_wrapper',
    'LowLevelZeroOptimizer', 'ColoInitContext', 'post_process_colo_init_ctx', 'get_static_torch_model'
]
@@ -0,0 +1,11 @@
from .chunk import ChunkManager, TensorInfo, TensorState, search_chunk_configuration
from .colo_init_context import ColoInitContext, post_process_colo_init_ctx
from .gemini_ddp import GeminiDDP, ZeroDDP
from .gemini_mgr import GeminiManager
from .gemini_optimizer import GeminiAdamOptimizer, ZeroOptimizer
from .utils import get_static_torch_model

__all__ = [
    'GeminiManager', 'TensorInfo', 'TensorState', 'ChunkManager', 'search_chunk_configuration', 'ZeroDDP', 'GeminiDDP',
    'get_static_torch_model', 'GeminiAdamOptimizer', 'ZeroOptimizer', 'ColoInitContext', 'post_process_colo_init_ctx'
]
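Combined with the re-exports in colossalai/zero/__init__.py above, a minimal hedged usage sketch of the reorganized namespace, modeled on the GeminiDDP docstring example in the deleted gemini_parallel.py; the launch call, toy model, and learning rate are illustrative assumptions, not part of the commit:

import torch
import torch.nn as nn

import colossalai
from colossalai.zero import ColoInitContext, GeminiAdamOptimizer, GeminiDDP

colossalai.launch_from_torch(config={})    # assumes a torchrun-style distributed environment

# parameters must be created as ColoTensors, hence the init context
with ColoInitContext(device=torch.device('cuda')):
    model = nn.Sequential(nn.Linear(64, 64), nn.GELU(), nn.Linear(64, 8))

# chunk-based ZeRO-DP wrapper with Gemini heterogeneous memory management
model = GeminiDDP(model, torch.cuda.current_device(), placement_policy="cuda")
optimizer = GeminiAdamOptimizer(model, lr=1e-3)    # HybridAdam wrapped by ZeroOptimizer

logits = model(torch.randn(4, 64, device='cuda'))  # ZeroDDP casts inputs to fp16 itself
loss = logits.mean()
model.backward(loss)                               # backward goes through the wrapper
optimizer.step()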
@@ -3,10 +3,11 @@ from typing import Deque, Dict, Iterable, List, Optional, Set, Tuple

import torch

from colossalai.gemini.chunk import Chunk, ChunkFullError, TensorState
from colossalai.tensor import ColoTensor
from colossalai.utils import get_current_device

from .chunk import Chunk, ChunkFullError, TensorState


class ChunkManager:
    """
@@ -5,9 +5,9 @@ import numpy as np
import torch.distributed as dist
import torch.nn as nn

from colossalai.gemini.memory_tracer import MemStats, OrderedParamGenerator
from colossalai.tensor import ColoParameter
from colossalai.utils import is_ddp_ignored
from colossalai.zero.gemini.memory_tracer import MemStats, OrderedParamGenerator


def _filter_exlarge_params(model: nn.Module, size_dict: Dict[int, List[int]]) -> None:
@@ -5,10 +5,11 @@ import torch
import torch.distributed as dist
import torch.nn as nn

from colossalai.gemini.chunk import ChunkManager
from colossalai.gemini.chunk.search_utils import search_chunk_configuration
from colossalai.utils import is_ddp_ignored

from .manager import ChunkManager
from .search_utils import search_chunk_configuration


def safe_div(a, b):
    if a == 0:
@@ -3,10 +3,8 @@ from typing import Any, Dict, Iterator, Optional, Tuple, Union
import torch
from torch import nn

from colossalai.nn.parallel.layers import ColoEmbedding, ColoLinear, register_colo_module
from colossalai.tensor import ColoParameter, ColoTensor, ProcessGroup

from .utils import InsertPostInitMethodToModuleSubClasses
from colossalai.utils.model.utils import InsertPostInitMethodToModuleSubClasses

# find named_params includes replica
@@ -89,6 +87,7 @@ class ColoInitContext(InsertPostInitMethodToModuleSubClasses):
        self._default_dist_spec = default_dist_spec

    def _register_colo_modules(self):
        from colossalai.nn.parallel.layers import ColoEmbedding, ColoLinear, register_colo_module
        register_colo_module(torch.nn.Linear, ColoLinear())
        register_colo_module(torch.nn.Embedding, ColoEmbedding())
@ -0,0 +1,590 @@
|
|||
import itertools
|
||||
from collections import OrderedDict
|
||||
from functools import partial
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
import torch.nn as nn
|
||||
|
||||
from colossalai.logging import get_dist_logger
|
||||
from colossalai.nn.parallel.data_parallel import ColoDDP, _cast_float, free_storage
|
||||
from colossalai.tensor import ProcessGroup as ColoProcessGroup
|
||||
from colossalai.tensor import ReplicaSpec
|
||||
from colossalai.tensor.colo_parameter import ColoParameter, ColoTensor, ColoTensorSpec
|
||||
from colossalai.tensor.param_op_hook import ColoParamOpHookManager
|
||||
from colossalai.utils import get_current_device, is_ddp_ignored
|
||||
|
||||
from .chunk import Chunk, ChunkManager, TensorState, init_chunk_manager
|
||||
from .gemini_hook import GeminiZeROHook
|
||||
from .gemini_mgr import GeminiManager
|
||||
from .memory_tracer import MemStats, OrderedParamGenerator
|
||||
from .utils import get_temp_total_chunk_on_cuda
|
||||
|
||||
try:
|
||||
from torch.nn.modules.module import _EXTRA_STATE_KEY_SUFFIX, _IncompatibleKeys
|
||||
except ImportError:
|
||||
_EXTRA_STATE_KEY_SUFFIX = '_extra_state'
|
||||
|
||||
__all__ = [
|
||||
'ZeroDDP',
|
||||
'GeminiDDP',
|
||||
]
|
||||
|
||||
|
||||
class ZeroDDP(ColoDDP):
|
||||
"""ZeRO DDP for ColoTensor.
|
||||
Warning: Nested ZeroDDP is not supported now.
|
||||
It is designed to be used with ChunkManager and GeminiManager.
|
||||
For more details, see the API reference of ``ChunkManager`` and ``GeminiManager``.
|
||||
|
||||
Args:
|
||||
module (torch.nn.Module): Module to apply ZeRO-DP.
|
||||
gemini_manager (GeminiManager): Manages the chunk manager and heterogeneous momery space.
|
||||
For more details, see the API reference of ``GeminiManager``.
|
||||
pin_memory (bool): Chunks on CPU Memory use pin-memory.
|
||||
force_outputs_fp32 (bool): If set to True, outputs will be fp32. Otherwise, outputs will be fp16.
|
||||
Defaults to False.
|
||||
strict_ddp_mode (bool): If set to True, there is no tensor sharding, each tensor is replicated.
|
||||
Defaults to False. Users can set it to True, when they clearly know that they only need DDP.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
module: torch.nn.Module,
|
||||
gemini_manager: GeminiManager,
|
||||
pin_memory: bool = False,
|
||||
force_outputs_fp32: bool = False,
|
||||
strict_ddp_mode: bool = False) -> None:
|
||||
super().__init__(module, process_group=ColoProcessGroup())
|
||||
self.gemini_manager = gemini_manager
|
||||
self.chunk_manager: ChunkManager = gemini_manager.chunk_manager
|
||||
self.force_outputs_fp32 = force_outputs_fp32
|
||||
self.param_op_hook = GeminiZeROHook(gemini_manager)
|
||||
self.fp32_params: List[ColoTensor] = list()
|
||||
self.fp16_params: List[ColoParameter] = list()
|
||||
self.overflow_counter = 0
|
||||
self.grads_device: Dict[torch.Tensor, torch.device] = dict()
|
||||
self.param2name: Dict[nn.Parameter, str] = dict()
|
||||
self.name2param: Dict[str, nn.Parameter] = dict()
|
||||
|
||||
self._cast_buffers()
|
||||
self._logger = get_dist_logger()
|
||||
|
||||
if self.gemini_manager._premade_memstats_:
|
||||
# build chunk in param runtime visited order.
|
||||
param_order = self.gemini_manager.memstats()._param_runtime_order
|
||||
else:
|
||||
# build chunk in param initialized order.
|
||||
# Note: in this way, it can not get filter unused params during runtime.
|
||||
param_order = OrderedParamGenerator()
|
||||
for p in module.parameters():
|
||||
param_order.append(p)
|
||||
|
||||
self._init_chunks(param_order=param_order,
|
||||
strict_ddp_mode=strict_ddp_mode,
|
||||
cpu_offload=self.gemini_manager.policy_name != 'cuda',
|
||||
pin_memory=pin_memory)
|
||||
|
||||
for name, param in module.named_parameters():
|
||||
self.param2name[param] = name
|
||||
for m_name, m_var in module.named_modules():
|
||||
for p_name, p_var in m_var.named_parameters(recurse=False):
|
||||
param_name = m_name + '.' + p_name if m_name else p_name
|
||||
self.name2param[param_name] = p_var
|
||||
|
||||
def _post_forward(self):
|
||||
"""This function is only triggered for inference.
|
||||
"""
|
||||
access_list = list(self.chunk_manager.accessed_chunks)
|
||||
# we need to scatter all accessed chunks and move them to their original places
|
||||
for chunk in access_list:
|
||||
if chunk.keep_gathered:
|
||||
self.chunk_manager.fake_release_chunk(chunk)
|
||||
else:
|
||||
assert chunk.can_release
|
||||
self.chunk_manager.release_chunk(chunk)
|
||||
first_param = next(iter(chunk.tensors_info))
|
||||
self.chunk_manager.move_chunk(chunk, self.grads_device[first_param])
|
||||
assert self.chunk_manager.accessed_mem == 0
|
||||
# reset all recorded attributes
|
||||
self.gemini_manager.reset_attributes()
|
||||
|
||||
def forward(self, *args, **kwargs):
|
||||
# check whether we are in a inference mode
|
||||
grad_flag = torch.is_grad_enabled()
|
||||
if not grad_flag:
|
||||
assert not self.gemini_manager.need_warmup or not self.gemini_manager.is_warmup(
|
||||
), "You should run a completed iteration as your warmup iter"
|
||||
|
||||
args, kwargs = _cast_float(args, torch.half), _cast_float(kwargs, torch.half)
|
||||
self.module.zero_grad(set_to_none=True)
|
||||
self.gemini_manager.pre_iter(*args)
|
||||
with ColoParamOpHookManager.use_hooks(self.param_op_hook):
|
||||
outputs = self.module(*args, **kwargs)
|
||||
# scatter chunks in the inference mode
|
||||
if not grad_flag:
|
||||
self._post_forward()
|
||||
|
||||
if self.force_outputs_fp32:
|
||||
return _cast_float(outputs, torch.float)
|
||||
return outputs
|
||||
|
||||
def _setup_grads_ptr(self):
|
||||
for p in self.module.parameters():
|
||||
if is_ddp_ignored(p):
|
||||
continue
|
||||
p.grad = None
|
||||
|
||||
def _pre_backward(self):
|
||||
# set a visit label for all parameters
|
||||
# the label is used to check whether the parameter is correctly reduced
|
||||
for param in self.param2name:
|
||||
if not is_ddp_ignored(param):
|
||||
setattr(param, "_gemini_reduced", False)
|
||||
|
||||
def _post_backward(self):
|
||||
if self.chunk_manager.accessed_mem != 0:
|
||||
error_params = ["Reduction failed at followed parameters:"]
|
||||
for param in self.param2name:
|
||||
if not is_ddp_ignored(param) and not getattr(param, "_gemini_reduced"):
|
||||
error_params.append(self.param2name[param])
|
||||
error_str = "\n\t".join(error_params)
|
||||
raise RuntimeError("ZERO DDP error: the synchronization of gradients doesn't exit properly.",
|
||||
"The most possible reason is that the model is not compatible with ZeroDDP.\n",
|
||||
f"{error_str}")
|
||||
self._setup_grads_ptr()
|
||||
self._logger.debug(
|
||||
f'comp cuda demand time: {self.gemini_manager._comp_cuda_demand_time}, layout time: {self.gemini_manager._layout_time}, evict time: {self.gemini_manager._evict_time}, CPU->CUDA vol: {self.gemini_manager._h2d_volume}B, CUDA->CPU vol: {self.gemini_manager._d2h_volume}'
|
||||
)
|
||||
self.gemini_manager.post_iter()
|
||||
|
||||
def backward(self, loss: torch.Tensor):
|
||||
self._pre_backward()
|
||||
with self.param_op_hook.switch_to_backward(), ColoParamOpHookManager.use_hooks(self.param_op_hook):
|
||||
loss.backward()
|
||||
self._post_backward()
|
||||
|
||||
def backward_by_grad(self, tensor, grad):
|
||||
with self.param_op_hook.switch_to_backward(), ColoParamOpHookManager.use_hooks(self.param_op_hook):
|
||||
torch.autograd.backward(tensor, grad)
|
||||
self._post_backward()
|
||||
|
||||
def grad_handle(self, p, grad):
|
||||
empty_grad = torch.empty_like(grad)
|
||||
free_storage(empty_grad)
|
||||
with torch._C.DisableTorchFunction():
|
||||
chunk = self.chunk_manager.get_chunk(p)
|
||||
if chunk.tensors_info[p].state != TensorState.HOLD_AFTER_BWD:
|
||||
raise RuntimeError(f"Parameter `{self.param2name[p]}` failed at the gradient reduction. "
|
||||
"Some unsupported torch function is operated upon this parameter.")
|
||||
self.chunk_manager.trans_tensor_state(p, TensorState.READY_FOR_REDUCE)
|
||||
chunk.copy_tensor_to_chunk_slice(p, grad)
|
||||
reduced = self.chunk_manager.reduce_chunk(chunk)
|
||||
if reduced:
|
||||
if chunk.is_gathered:
|
||||
chunk.cuda_global_chunk.div_(chunk.pg_size)
|
||||
else:
|
||||
chunk.cuda_shard.div_(chunk.pg_size)
|
||||
# check overflow elements
|
||||
self.overflow_counter += chunk.has_inf_or_nan
|
||||
# record l2 norm for gradient clipping
|
||||
if chunk.l2_norm_flag:
|
||||
chunk.set_l2_norm()
|
||||
self.chunk_manager.move_chunk(chunk, self.grads_device[p], force_copy=True)
|
||||
return empty_grad
|
||||
|
||||
def zero_grad(self, set_to_none: bool = False) -> None:
|
||||
self.module.zero_grad(set_to_none=True)
|
||||
|
||||
def set_chunk_grad_device(self, chunk: Chunk, device: torch.device) -> None:
|
||||
for tensor in chunk.get_tensors():
|
||||
self.grads_device[tensor] = device
|
||||
|
||||
def state_dict(self, destination=None, prefix='', keep_vars=False, only_rank_0: bool = True):
|
||||
"""Returns a dictionary containing a whole state of the module.
|
||||
|
||||
Both parameters and persistent buffers (e.g. running averages) are included.
|
||||
Keys are corresponding parameter and buffer names.
|
||||
Parameters and buffers set to ``None`` are not included.
|
||||
|
||||
Warning: The non strict state dict would ignore the parameters if the tensors of the parameters
|
||||
are shared with other parameters which have been included in the dictionary.
|
||||
When you need to load the state dict, you should set the argument `strict` to False.
|
||||
|
||||
Returns:
|
||||
dict:
|
||||
a dictionary containing a whole state of the module
|
||||
"""
|
||||
if destination is None:
|
||||
destination = OrderedDict()
|
||||
destination._metadata = OrderedDict()
|
||||
destination._metadata[prefix[:-1]] = local_metadata = dict(version=self._version)
|
||||
self._save_to_state_dict(destination, prefix, keep_vars, only_rank_0)
|
||||
|
||||
for hook in self._state_dict_hooks.values():
|
||||
hook_result = hook(self, destination, prefix, local_metadata)
|
||||
if hook_result is not None:
|
||||
destination = hook_result
|
||||
return destination
|
||||
|
||||
def _get_param_to_save_data(self, param_list: List[torch.nn.Parameter], only_rank_0: bool) -> Dict:
|
||||
"""
|
||||
get param content from chunks.
|
||||
|
||||
Args:
|
||||
param_list (_type_): a list of torch.nn.Parameters
|
||||
only_rank_0 (_type_): _description_
|
||||
|
||||
Returns:
|
||||
Dict: a dict whose key is param name and value is param with correct payload
|
||||
"""
|
||||
# save parameters
|
||||
param_to_save_data = dict()
|
||||
chunk_list = self.chunk_manager.get_chunks(param_list)
|
||||
for chunk in chunk_list:
|
||||
temp_chunk = get_temp_total_chunk_on_cuda(chunk)
|
||||
|
||||
for tensor, tensor_info in chunk.tensors_info.items():
|
||||
record_tensor = torch.empty([0])
|
||||
record_flag = (not only_rank_0) | (dist.get_rank(chunk.torch_pg) == 0)
|
||||
if record_flag:
|
||||
record_tensor = temp_chunk[tensor_info.offset:tensor_info.end].view(tensor.shape).cpu()
|
||||
|
||||
assert tensor not in param_to_save_data
|
||||
param_to_save_data[tensor] = record_tensor
|
||||
|
||||
del temp_chunk
|
||||
return param_to_save_data
|
||||
|
||||
def _save_to_state_dict(self, destination, prefix, keep_vars, only_rank_0=True):
|
||||
r"""Saves module state to `destination` dictionary, containing a state
|
||||
of the module, but not its descendants. This is called on every
|
||||
submodule in :meth:`~torch.nn.Module.state_dict`.
|
||||
|
||||
In rare cases, subclasses can achieve class-specific behavior by
|
||||
overriding this method with custom logic.
|
||||
|
||||
Args:
|
||||
destination (dict): a dict where state will be stored
|
||||
prefix (str): the prefix for parameters and buffers used in this
|
||||
module
|
||||
"""
|
||||
assert keep_vars is False, "`state_dict` with parameter, `keep_vars=True`, is not supported now."
|
||||
|
||||
# get copies of fp32 parameters in CPU
|
||||
param_to_save_data = self._get_param_to_save_data(self.fp32_params, only_rank_0)
|
||||
# get the mapping between copies and fp16 parameters
|
||||
p_mapping = dict()
|
||||
for p, fp32_p in zip(self.fp16_params, self.fp32_params):
|
||||
name = self.param2name[p]
|
||||
assert fp32_p in param_to_save_data, "Parameter '{}' is neglected in the chunk list".format(name)
|
||||
record_parameter = param_to_save_data[fp32_p]
|
||||
p_mapping[p] = record_parameter
|
||||
for name, param in self.name2param.items():
|
||||
if param is not None:
|
||||
if is_ddp_ignored(param):
|
||||
# deal with ddp ignored parameters
|
||||
destination[prefix + name] = param if keep_vars else param.detach()
|
||||
else:
|
||||
destination[prefix + name] = p_mapping[param]
|
||||
del p_mapping
|
||||
del param_to_save_data
|
||||
|
||||
# save all buffers
|
||||
for name, buf in self.named_buffers():
|
||||
if buf is not None and name not in self._non_persistent_buffers_set:
|
||||
destination[prefix + name] = buf if keep_vars else buf.detach()
|
||||
# save extra states
|
||||
extra_state_key = prefix + _EXTRA_STATE_KEY_SUFFIX
|
||||
if getattr(self.__class__, "get_extra_state",
|
||||
torch.nn.Module.get_extra_state) is not torch.nn.Module.get_extra_state:
|
||||
destination[extra_state_key] = self.get_extra_state()
|
||||
|
||||
def load_state_dict(self, state_dict: 'OrderedDict[str, torch.Tensor]', strict: bool = True):
|
||||
r"""Copies parameters and buffers from :attr:`state_dict` into
|
||||
this module and its descendants. If :attr:`strict` is ``True``, then
|
||||
the keys of :attr:`state_dict` must exactly match the keys returned
|
||||
by this module's :meth:`~torch.nn.Module.state_dict` function.
|
||||
|
||||
Args:
|
||||
state_dict (dict): a dict containing parameters and
|
||||
persistent buffers.
|
||||
strict (bool, optional): whether to strictly enforce that the keys
|
||||
in :attr:`state_dict` match the keys returned by this module's
|
||||
:meth:`~torch.nn.Module.state_dict` function. Default: ``True``
|
||||
|
||||
Returns:
|
||||
``NamedTuple`` with ``missing_keys`` and ``unexpected_keys`` fields:
|
||||
* **missing_keys** is a list of str containing the missing keys
|
||||
* **unexpected_keys** is a list of str containing the unexpected keys
|
||||
|
||||
Note:
|
||||
If a parameter or buffer is registered as ``None`` and its corresponding key
|
||||
exists in :attr:`state_dict`, :meth:`load_state_dict` will raise a
|
||||
``RuntimeError``.
|
||||
"""
|
||||
missing_keys: List[str] = []
|
||||
unexpected_keys: List[str] = []
|
||||
error_msgs: List[str] = []
|
||||
|
||||
# copy state_dict so _load_from_state_dict can modify it
|
||||
metadata = getattr(state_dict, '_metadata', None)
|
||||
state_dict = state_dict.copy()
|
||||
if metadata is not None:
|
||||
# mypy isn't aware that "_metadata" exists in state_dict
|
||||
state_dict._metadata = metadata # type: ignore[attr-defined]
|
||||
|
||||
prefix = ''
|
||||
local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
|
||||
self._load_from_state_dict(state_dict, prefix, local_metadata, True, missing_keys, unexpected_keys, error_msgs)
|
||||
|
||||
if strict:
|
||||
if len(unexpected_keys) > 0:
|
||||
error_msgs.insert(
|
||||
0, 'Unexpected key(s) in state_dict: {}. '.format(', '.join(
|
||||
'"{}"'.format(k) for k in unexpected_keys)))
|
||||
if len(missing_keys) > 0:
|
||||
error_msgs.insert(
|
||||
0, 'Missing key(s) in state_dict: {}. '.format(', '.join('"{}"'.format(k) for k in missing_keys)))
|
||||
|
||||
if len(error_msgs) > 0:
|
||||
raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
|
||||
self.__class__.__name__, "\n\t".join(error_msgs)))
|
||||
return _IncompatibleKeys(missing_keys, unexpected_keys)
|
||||
|
||||
def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys,
|
||||
error_msgs):
|
||||
r"""Copies parameters and buffers from :attr:`state_dict` into only
|
||||
this module, but not its descendants. This is called on every submodule
|
||||
in :meth:`~torch.nn.Module.load_state_dict`. Metadata saved for this
|
||||
module in input :attr:`state_dict` is provided as :attr:`local_metadata`.
|
||||
For state dicts without metadata, :attr:`local_metadata` is empty.
|
||||
Subclasses can achieve class-specific backward compatible loading using
|
||||
the version number at `local_metadata.get("version", None)`.
|
||||
|
||||
.. note::
|
||||
:attr:`state_dict` is not the same object as the input
|
||||
:attr:`state_dict` to :meth:`~torch.nn.Module.load_state_dict`. So
|
||||
it can be modified.
|
||||
|
||||
Args:
|
||||
state_dict (dict): a dict containing parameters and
|
||||
persistent buffers.
|
||||
prefix (str): the prefix for parameters and buffers used in this
|
||||
module
|
||||
local_metadata (dict): a dict containing the metadata for this module.
|
||||
See
|
||||
strict (bool): whether to strictly enforce that the keys in
|
||||
:attr:`state_dict` with :attr:`prefix` match the names of
|
||||
parameters and buffers in this module
|
||||
missing_keys (list of str): if ``strict=True``, add missing keys to
|
||||
this list
|
||||
unexpected_keys (list of str): if ``strict=True``, add unexpected
|
||||
keys to this list
|
||||
error_msgs (list of str): error messages should be added to this
|
||||
list, and will be reported together in
|
||||
:meth:`~torch.nn.Module.load_state_dict`
|
||||
"""
|
||||
for hook in self._load_state_dict_pre_hooks.values():
|
||||
hook(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)
|
||||
|
||||
persistent_buffers = {k: v for k, v in self.named_buffers() if k not in self._non_persistent_buffers_set}
|
||||
local_name_params = itertools.chain(self.named_parameters(), persistent_buffers.items())
|
||||
local_state = {k: v for k, v in local_name_params if v is not None}
|
||||
|
||||
def load(param_name, dest_tensor, copy_func):
|
||||
state_key = prefix + param_name
|
||||
if state_key in state_dict:
|
||||
input_param = state_dict[state_key]
|
||||
# Backward compatibility: loading 1-dim tensor from 0.3.* to version 0.4+
|
||||
if len(dest_tensor.shape) == 0 and len(input_param.shape) == 1:
|
||||
input_param = input_param[0]
|
||||
if input_param.shape != dest_tensor.shape:
|
||||
# local shape should match the one in checkpoint
|
||||
error_msgs.append('size mismatch for {}: copying a param with shape {} from checkpoint, '
|
||||
'the shape in current model is {}.'.format(state_key, input_param.shape,
|
||||
dest_tensor.shape))
|
||||
return
|
||||
try:
|
||||
with torch.no_grad():
|
||||
copy_func(input_param)
|
||||
except Exception as ex:
|
||||
error_msgs.append('While copying the parameter named "{}", '
|
||||
'whose dimensions in the model are {} and '
|
||||
'whose dimensions in the checkpoint are {}, '
|
||||
'an exception occurred : {}.'.format(state_key, dest_tensor.size(),
|
||||
input_param.size(), ex.args))
|
||||
elif strict:
|
||||
missing_keys.append(state_key)
|
||||
|
||||
def load_fp32_parameter(chunk_slice, data):
|
||||
chunk_slice.copy_(data.flatten())
|
||||
|
||||
for name, param in self.named_parameters():
|
||||
if is_ddp_ignored(param):
|
||||
# deal with ddp ignored parameters
|
||||
load(name, param, param.copy_)
|
||||
|
||||
fp32_to_name = dict()
|
||||
for p, fp32_p in zip(self.fp16_params, self.fp32_params):
|
||||
if p is not None:
|
||||
name = self.param2name[p]
|
||||
fp32_to_name[fp32_p] = name
|
||||
|
||||
chunk_list = self.chunk_manager.get_chunks(self.fp32_params)
|
||||
for chunk in chunk_list:
|
||||
temp_chunk = get_temp_total_chunk_on_cuda(chunk)
|
||||
|
||||
for tensor, tensor_info in chunk.tensors_info.items():
|
||||
parameter_name = fp32_to_name[tensor]
|
||||
parameter_slice = temp_chunk[tensor_info.offset:tensor_info.end]
|
||||
load(parameter_name, tensor, partial(load_fp32_parameter, parameter_slice))
|
||||
|
||||
if chunk.is_gathered:
|
||||
chunk.cuda_global_chunk.copy_(temp_chunk)
|
||||
elif chunk.cuda_shard is not None:
|
||||
chunk.cuda_shard.copy_(temp_chunk[chunk.shard_begin:chunk.shard_end])
|
||||
else:
|
||||
chunk.cpu_shard.copy_(temp_chunk[chunk.shard_begin:chunk.shard_end])
|
||||
|
||||
del temp_chunk
|
||||
|
||||
for chunk_32 in chunk_list:
|
||||
chunk_16 = chunk_32.paired_chunk
|
||||
assert chunk_16 is not None
|
||||
chunk_16.optim_update()
|
||||
|
||||
for name, buf in persistent_buffers.items():
|
||||
if buf is not None:
|
||||
load(name, buf, buf.copy_)
|
||||
|
||||
extra_state_key = prefix + _EXTRA_STATE_KEY_SUFFIX
|
||||
if getattr(self.__class__, "set_extra_state",
|
||||
torch.nn.Module.set_extra_state) is not torch.nn.Module.set_extra_state:
|
||||
if extra_state_key in state_dict:
|
||||
self.set_extra_state(state_dict[extra_state_key])
|
||||
elif strict:
|
||||
missing_keys.append(extra_state_key)
|
||||
elif strict and (extra_state_key in state_dict):
|
||||
unexpected_keys.append(extra_state_key)
|
||||
|
||||
if strict:
|
||||
for key in state_dict.keys():
|
||||
if key.startswith(prefix) and key != extra_state_key:
|
||||
input_name = key[len(prefix):]
|
||||
if input_name not in local_state:
|
||||
unexpected_keys.append(key)
|
||||
|
||||
def _init_chunks(self, param_order, strict_ddp_mode: bool, cpu_offload: bool, pin_memory: bool):
|
||||
ddp_pg = ColoProcessGroup()
|
||||
for p in param_order.generate():
|
||||
assert isinstance(p, ColoParameter)
|
||||
|
||||
# gather sharded parameters in the strict ddp mode
|
||||
if strict_ddp_mode:
|
||||
if not p.is_replicate():
|
||||
p.set_dist_spec(ReplicaSpec())
|
||||
p.set_process_group(pg=ddp_pg)
|
||||
|
||||
# ignore the parameters with no gradient
|
||||
if not p.requires_grad:
|
||||
self.set_params_to_ignore([p])
|
||||
|
||||
# move ignored parameters to CUDA
|
||||
if is_ddp_ignored(p):
|
||||
p.data = p.data.to(device=get_current_device(), dtype=torch.float16)
|
||||
continue
|
||||
|
||||
# create an fp32 parameter
|
||||
fp32_data = p.data.float()
|
||||
fp32_p = ColoTensor(fp32_data, spec=ColoTensorSpec(p.process_group))
|
||||
# create an fp16 parameter
|
||||
p.data = p.data.half()
|
||||
|
||||
# register the fp16 parameter and fp32 parameter in the chunk manager
|
||||
dp_world_size = p.process_group.dp_world_size()
|
||||
self.chunk_manager.register_tensor(tensor=p,
|
||||
group_type='fp16_param',
|
||||
config_key=dp_world_size,
|
||||
cpu_offload=cpu_offload,
|
||||
pin_memory=pin_memory)
|
||||
self.chunk_manager.register_tensor(tensor=fp32_p,
|
||||
group_type='fp32_param',
|
||||
config_key=dp_world_size,
|
||||
cpu_offload=cpu_offload,
|
||||
pin_memory=pin_memory)
|
||||
|
||||
self.fp16_params.append(p)
|
||||
self.fp32_params.append(fp32_p)
|
||||
self.grads_device[p] = self.gemini_manager.default_device
|
||||
|
||||
self.chunk_manager.close_all_groups()
|
||||
|
||||
for p, fp32_p in zip(self.fp16_params, self.fp32_params):
|
||||
chunk_16 = self.chunk_manager.get_chunk(p)
|
||||
chunk_32 = self.chunk_manager.get_chunk(fp32_p)
|
||||
chunk_32.init_pair(chunk_16)
|
||||
|
||||
# keep gathered chunks in CUDA
|
||||
if chunk_16.keep_gathered:
|
||||
self.grads_device[p] = get_current_device()
|
||||
|
||||
def _cast_buffers(self):
|
||||
for buffer in self.module.buffers():
|
||||
buffer.data = buffer.cuda()
|
||||
if torch.is_floating_point(buffer):
|
||||
buffer.data = buffer.half()
|
||||
|
||||
|
||||
class GeminiDDP(ZeroDDP):
|
||||
|
||||
def __init__(self,
|
||||
module: torch.nn.Module,
|
||||
device: torch.device,
|
||||
placement_policy: str = "cpu",
|
||||
pin_memory: bool = False,
|
||||
force_outputs_fp32: bool = False,
|
||||
strict_ddp_mode: bool = False,
|
||||
search_range_mb: int = 32,
|
||||
hidden_dim: Optional[int] = None,
|
||||
min_chunk_size_mb: float = 32,
|
||||
memstats: Optional[MemStats] = None) -> None:
|
||||
"""
|
||||
A torch.nn.Module wrapper using ZeRO-DP and Gemini.
|
||||
ZeRO handles the parallelism; Gemini handles the memory management.
|
||||
WARNING: The class will modify the module in place!
|
||||
|
||||
Example:
|
||||
The model is initialized under the context of ColoInitContext.
|
||||
>>> model = GeminiDDP(model, torch.cuda.current_device(), "cuda")
|
||||
>>> logits = model(x)
|
||||
>>> loss = criterion(logits, labels)
|
||||
>>> model.backward(loss)
|
||||
|
||||
Args:
|
||||
module (torch.nn.Module): the model to be wrapped.
|
||||
device (torch.device): device to place the model.
|
||||
placement_policy (str, optional): "cpu", "cuda", "auto". Defaults to "cpu".
|
||||
pin_memory (bool, optional): use pin memory on CPU. Defaults to False.
|
||||
force_outputs_fp32 (bool, optional): force the outputs to be fp32. Defaults to False.
|
||||
search_range_mb (int, optional): chunk size searching range in MegaByte. Defaults to 32.
|
||||
hidden_dim (int, optional): the hidden dimension of DNN.
|
||||
Users can provide this argument to speed up searching.
|
||||
It is fine if users do not know this argument before training; a default value of 1024 will be used.
|
||||
min_chunk_size_mb (float, optional): the minimum chunk size in MegaByte.
|
||||
If the aggregate size of parameters is still smaller than the minimum chunk size,
|
||||
all parameters will be compacted into one small chunk.
|
||||
memstats (MemStats, optional): the memory statistics collected by a runtime memory tracer.
|
||||
"""
|
||||
# an ugly hotfix for compatibility with Lightning
|
||||
if search_range_mb is None:
|
||||
search_range_mb = 32
|
||||
|
||||
chunk_manager = init_chunk_manager(model=module,
|
||||
init_device=device,
|
||||
hidden_dim=hidden_dim,
|
||||
search_range_mb=search_range_mb,
|
||||
min_chunk_size_mb=min_chunk_size_mb,
|
||||
strict_ddp_flag=strict_ddp_mode)
|
||||
gemini_manager = GeminiManager(placement_policy, chunk_manager, memstats)
|
||||
super().__init__(module, gemini_manager, pin_memory, force_outputs_fp32, strict_ddp_mode)
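For orientation, here is a minimal usage sketch of the relocated `GeminiDDP` wrapper; it is not part of this commit. It assumes the distributed environment has already been set up via `colossalai.launch_from_torch`, and the toy model, data, and hyperparameters are placeholders.

```
import torch
import torch.nn as nn

from colossalai.utils import get_current_device
from colossalai.zero import ColoInitContext, GeminiAdamOptimizer, GeminiDDP

# Assumes colossalai.launch_from_torch(config={}) has already initialized
# the process group. The toy model and data below are placeholders.
with ColoInitContext(device=get_current_device()):
    model = nn.Sequential(nn.Linear(32, 32), nn.ReLU(), nn.Linear(32, 2))

# Wrap with GeminiDDP: ZeRO-DP for parallelism, Gemini for memory management.
model = GeminiDDP(model,
                  device=get_current_device(),
                  placement_policy='cpu',
                  pin_memory=True)
optimizer = GeminiAdamOptimizer(model, lr=1e-3, initial_scale=2**14)

criterion = nn.CrossEntropyLoss()
data = torch.randn(8, 32, device=get_current_device())
labels = torch.randint(0, 2, (8,), device=get_current_device())

logits = model(data)
loss = criterion(logits, labels)
model.backward(loss)    # Gemini requires model.backward(loss) instead of loss.backward()
optimizer.step()
```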
|
|
@ -5,10 +5,10 @@ from typing import List
|
|||
|
||||
import torch
|
||||
|
||||
from colossalai.gemini import TensorState
|
||||
from colossalai.gemini.gemini_mgr import GeminiManager
|
||||
from colossalai.tensor.param_op_hook import ColoParamOpHook
|
||||
from colossalai.utils import is_ddp_ignored
|
||||
from colossalai.zero.gemini import TensorState
|
||||
from colossalai.zero.gemini.gemini_mgr import GeminiManager
|
||||
|
||||
|
||||
class TrainingPhase(Enum):
|
|
@ -4,10 +4,8 @@ from typing import List, Optional, Tuple
|
|||
|
||||
import torch
|
||||
|
||||
from colossalai.gemini.chunk import Chunk, ChunkManager
|
||||
from colossalai.gemini.memory_tracer import MemStats
|
||||
|
||||
from .memory_tracer import ChunkMemStatsCollector
|
||||
from .chunk import Chunk, ChunkManager
|
||||
from .memory_tracer import ChunkMemStatsCollector, MemStats
|
||||
from .placement_policy import PlacementPolicyFactory
|
||||
|
||||
|
|
@ -10,12 +10,15 @@ from torch.nn import Parameter
|
|||
from torch.optim import Optimizer
|
||||
|
||||
from colossalai.amp.naive_amp.grad_scaler import DynamicGradScaler
|
||||
from colossalai.gemini.chunk import Chunk, ChunkManager
|
||||
from colossalai.logging import get_dist_logger
|
||||
from colossalai.nn.optimizer import ColossalaiOptimizer, CPUAdam, FusedAdam, HybridAdam
|
||||
from colossalai.nn.parallel.data_parallel import ZeroDDP
|
||||
from colossalai.utils import disposable, get_current_device, is_ddp_ignored
|
||||
|
||||
from .chunk import Chunk, ChunkManager
|
||||
from .gemini_ddp import ZeroDDP
|
||||
|
||||
__all__ = ['ZeroOptimizer', 'GeminiAdamOptimizer']
|
||||
|
||||
_AVAIL_OPTIM_LIST = {FusedAdam, CPUAdam, HybridAdam}
|
||||
|
||||
|
||||
|
@ -316,3 +319,10 @@ class ZeroOptimizer(ColossalaiOptimizer):
|
|||
fake_params_list.append(fake_param)
|
||||
|
||||
group['params'] = fake_params_list
|
||||
|
||||
|
||||
class GeminiAdamOptimizer(ZeroOptimizer):
|
||||
|
||||
def __init__(self, model: torch.nn.Module, **defaults: Any) -> None:
|
||||
optimizer = HybridAdam(model.parameters(), **defaults)
|
||||
super().__init__(optimizer, model, **defaults)
|
|
@ -1,10 +1,10 @@
|
|||
from typing import Optional
|
||||
|
||||
from colossalai.gemini.chunk import ChunkManager
|
||||
from colossalai.gemini.memory_tracer import MemStats
|
||||
from colossalai.utils import get_current_device
|
||||
from colossalai.utils.memory import colo_device_memory_capacity
|
||||
from colossalai.zero.gemini.chunk import ChunkManager
|
||||
|
||||
from .memory_stats import MemStats
|
||||
from .memstats_collector import MemStatsCollector
|
||||
|
||||
|
|
@ -2,7 +2,7 @@ from typing import Any, Dict, List, Optional
|
|||
|
||||
import torch
|
||||
|
||||
from colossalai.gemini.memory_tracer import OrderedParamGenerator
|
||||
from .param_runtime_order import OrderedParamGenerator
|
||||
|
||||
|
||||
class MemStats(object):
|
|
@ -1,12 +1,7 @@
|
|||
import time
|
||||
from typing import List, Optional
|
||||
|
||||
import torch
|
||||
|
||||
from colossalai.gemini.memory_tracer import SyncCudaMemoryMonitor
|
||||
from colossalai.gemini.stateful_tensor import StatefulTensor
|
||||
from colossalai.utils.memory import colo_device_memory_used
|
||||
from typing import Optional
|
||||
|
||||
from .memory_monitor import SyncCudaMemoryMonitor
|
||||
from .memory_stats import MemStats
|
||||
|
||||
|
||||
|
@ -49,7 +44,7 @@ class MemStatsCollector:
|
|||
assert self._step_total > 0, 'Cannot get mem stats info before collection phase.'
|
||||
assert len(self._memstats.non_model_data_list(device_type)) > self._step_idx, \
|
||||
f"{len(self._memstats.non_model_data_list(device_type))} should be > than step idx {self._step_idx}, "\
|
||||
f"step total {self._step_total}"
|
||||
f"step total {self._step_total}"
|
||||
next_non_model_data = self._memstats.non_model_data_list(device_type)[self._step_idx]
|
||||
self._step_idx = (self._step_idx + 1) % self._step_total
|
||||
return next_non_model_data
|
||||
|
@ -75,6 +70,8 @@ class MemStatsCollector:
|
|||
Sampling model data statistics.
|
||||
"""
|
||||
if self._start_flag and not self.use_outside_memstats:
|
||||
from colossalai.zero.legacy.gemini import StatefulTensor
|
||||
|
||||
# The following code works for ZeroInitContext, which is deprecated in v0.1.12
|
||||
cuda_mem = StatefulTensor.GST_MGR.total_mem['cuda']
|
||||
self._memstats.record_max_cuda_model_data(cuda_mem)
|
|
@ -1,9 +1,14 @@
|
|||
import torch.nn
|
||||
|
||||
from colossalai.gemini.memory_tracer import MemStats
|
||||
from colossalai.gemini.ophooks.runtime_mem_tracer_hook import GradMemStats, GradMemTracerHook, ParamMemTracerHook
|
||||
from colossalai.nn.parallel.data_parallel import _cast_float
|
||||
from colossalai.tensor.param_op_hook import ColoParamOpHookManager
|
||||
from colossalai.zero.legacy.gemini.ophooks.runtime_mem_tracer_hook import (
|
||||
GradMemStats,
|
||||
GradMemTracerHook,
|
||||
ParamMemTracerHook,
|
||||
)
|
||||
|
||||
from .memory_stats import MemStats
|
||||
|
||||
__all__ = ['RuntimeMemTracer']
|
||||
|
|
@ -6,7 +6,7 @@ from torch.fx import symbolic_trace
|
|||
|
||||
from colossalai.fx.passes.meta_info_prop import MetaInfoProp
|
||||
from colossalai.fx.profiler import calculate_fwd_out, calculate_fwd_tmp, is_compatible_with_meta
|
||||
from colossalai.gemini.chunk import ChunkManager
|
||||
from colossalai.zero.gemini.chunk import ChunkManager
|
||||
|
||||
if is_compatible_with_meta():
|
||||
from colossalai.fx.profiler import MetaTensor
|
|
@ -5,11 +5,12 @@ from typing import Dict, List, Optional, Tuple, Type
|
|||
|
||||
import torch
|
||||
|
||||
from colossalai.gemini.chunk import Chunk, ChunkManager
|
||||
from colossalai.gemini.memory_tracer import ChunkMemStatsCollector
|
||||
from colossalai.utils import get_current_device
|
||||
from colossalai.utils.memory import colo_device_memory_capacity
|
||||
|
||||
from .chunk import Chunk, ChunkManager
|
||||
from .memory_tracer import ChunkMemStatsCollector
|
||||
|
||||
|
||||
class PlacementPolicy(ABC):
|
||||
need_mem_stats: bool = False
|
|
@ -6,9 +6,10 @@ import torch
|
|||
import torch.distributed as dist
|
||||
import torch.nn as nn
|
||||
|
||||
from colossalai.gemini.chunk import Chunk
|
||||
from colossalai.utils import get_current_device
|
||||
|
||||
from .chunk import Chunk
|
||||
|
||||
|
||||
def get_temp_total_chunk_on_cuda(chunk: Chunk):
|
||||
if chunk.is_gathered:
|
||||
|
@ -77,7 +78,7 @@ def get_static_torch_model(zero_ddp_model,
|
|||
Returns:
|
||||
torch.nn.Module: a static torch model used for saving checkpoints or numeric checks
|
||||
"""
|
||||
from colossalai.nn.parallel import ZeroDDP
|
||||
from colossalai.zero.gemini.gemini_ddp import ZeroDDP
|
||||
assert isinstance(zero_ddp_model, ZeroDDP)
|
||||
|
||||
state_dict = zero_ddp_model.state_dict(only_rank_0=only_rank_0)
|
|
@ -0,0 +1,44 @@
|
|||
from typing import Tuple
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from colossalai.logging import get_dist_logger
|
||||
|
||||
from .init_ctx import ZeroInitContext, no_shard_zero_context, no_shard_zero_decrator
|
||||
from .sharded_model import ShardedModelV2
|
||||
from .sharded_optim import ShardedOptimizerV2
|
||||
|
||||
|
||||
def convert_to_zero_v2(model: nn.Module, optimizer: torch.optim.Optimizer, model_config,
|
||||
optimizer_config) -> Tuple[ShardedModelV2, ShardedOptimizerV2]:
|
||||
"""
|
||||
A helper function to integrate the model and optimizer with the ZeRO optimizer and offloading.
|
||||
|
||||
:param model: Your model object
|
||||
:type model: :class:`torch.nn.Module`
|
||||
:param optimizer_config: Your optimizer config
|
||||
:type optimizer_config: :class:`dict`
|
||||
|
||||
:return: (model, optimizer)
|
||||
:rtype: Tuple
|
||||
"""
|
||||
|
||||
logger = get_dist_logger('convert_to_zero_v2')
|
||||
|
||||
logger.info(f'optimizer_config is {optimizer_config}', ranks=[0])
|
||||
if optimizer_config is None:
|
||||
optimizer_config = dict()
|
||||
logger.info(f'model_config is {model_config}', ranks=[0])
|
||||
if model_config is None:
|
||||
model_config = dict()
|
||||
|
||||
zero_model = ShardedModelV2(model, **model_config)
|
||||
zero_optimizer = ShardedOptimizerV2(zero_model, optimizer, **optimizer_config)
|
||||
return zero_model, zero_optimizer
|
||||
|
||||
|
||||
__all__ = [
|
||||
'convert_to_zero_v2', 'ShardedModelV2', 'ShardedOptimizerV2', 'ZeroInitContext', 'no_shard_zero_context',
|
||||
'no_shard_zero_decrator'
|
||||
]
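As a hedged illustration only (not part of this diff), the legacy helper above could be exercised roughly as follows. The config keys are assumptions drawn from the surrounding hunks, and a distributed environment initialized by `colossalai.launch` is assumed.

```
import torch
import torch.nn as nn

from colossalai.zero.legacy import ZeroInitContext, convert_to_zero_v2
from colossalai.zero.legacy.shard_utils import TensorShardStrategy

# Hypothetical configs: the accepted keys are those of ShardedModelV2 /
# ShardedOptimizerV2, so treat these as placeholders.
shard_strategy = TensorShardStrategy()
model_config = dict(shard_strategy=shard_strategy,
                    tensor_placement_policy='cpu',
                    reuse_fp16_shard=True)
optimizer_config = dict(initial_scale=2**5)

# Shard parameters at construction time with the legacy ZeroInitContext.
with ZeroInitContext(target_device=torch.cuda.current_device(),
                     shard_strategy=shard_strategy,
                     shard_param=True):
    model = nn.Linear(32, 32)    # placeholder model

optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
zero_model, zero_optimizer = convert_to_zero_v2(model, optimizer, model_config, optimizer_config)
```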
|
|
@ -0,0 +1,9 @@
|
|||
from .ophooks import BaseOpHook, register_ophooks_recursively
|
||||
from .stateful_tensor import StatefulTensor
|
||||
from .stateful_tensor_mgr import StatefulTensorMgr
|
||||
from .tensor_placement_policy import AutoTensorPlacementPolicy, CPUTensorPlacementPolicy, CUDATensorPlacementPolicy
|
||||
|
||||
__all__ = [
|
||||
'StatefulTensorMgr', 'StatefulTensor', 'CPUTensorPlacementPolicy', 'CUDATensorPlacementPolicy',
|
||||
'AutoTensorPlacementPolicy', 'register_ophooks_recursively', 'BaseOpHook'
|
||||
]
|
|
@ -1,4 +1,5 @@
|
|||
import torch
|
||||
|
||||
from colossalai.registry import OPHOOKS
|
||||
|
||||
from . import BaseOpHook
|
|
@ -5,9 +5,9 @@ from typing import List
|
|||
|
||||
import torch
|
||||
|
||||
from colossalai.gemini.memory_tracer import MemStats, SyncCudaMemoryMonitor
|
||||
from colossalai.gemini.tensor_utils import alloc_storage, free_storage
|
||||
from colossalai.tensor.param_op_hook import ColoParamOpHook
|
||||
from colossalai.zero.gemini.memory_tracer import MemStats, SyncCudaMemoryMonitor
|
||||
from colossalai.zero.legacy.gemini.tensor_utils import alloc_storage, free_storage
|
||||
|
||||
|
||||
class TrainingPhase(Enum):
|
|
@ -1,9 +1,9 @@
|
|||
from enum import Enum
|
||||
from typing import Optional
|
||||
import torch
|
||||
from typing import Union
|
||||
from typing import Optional, Union
|
||||
|
||||
from colossalai.gemini.gemini_context import GeminiMemoryManager
|
||||
import torch
|
||||
|
||||
from .gemini_context import GeminiMemoryManager
|
||||
|
||||
|
||||
def sizeof_tensor(tensor: torch.Tensor):
|
||||
|
@ -19,7 +19,7 @@ class TensorState(Enum):
|
|||
|
||||
|
||||
class StatefulTensor(object):
|
||||
"""A Structure stores a Torch Tensor and labeled states.
|
||||
"""A Structure stores a Torch Tensor and labeled states.
|
||||
Inspired from the paper:
|
||||
PatrickStar: Parallel Training of Pre-trained Models via Chunk-based Memory Management
|
||||
|
|
@ -1,13 +1,16 @@
|
|||
import functools
|
||||
import torch
|
||||
import types
|
||||
from colossalai.utils.cuda import get_current_device
|
||||
from colossalai.gemini.tensor_utils import colo_model_data_tensor_move_inline, colo_tensor_mem_usage
|
||||
from colossalai.gemini.stateful_tensor import StatefulTensor, TensorState
|
||||
from colossalai.gemini.tensor_placement_policy import TensorPlacementPolicy
|
||||
from typing import List
|
||||
from colossalai.logging import get_dist_logger
|
||||
from time import time
|
||||
from typing import List
|
||||
|
||||
import torch
|
||||
|
||||
from colossalai.logging import get_dist_logger
|
||||
from colossalai.utils.cuda import get_current_device
|
||||
|
||||
from .stateful_tensor import StatefulTensor, TensorState
|
||||
from .tensor_placement_policy import TensorPlacementPolicy
|
||||
from .tensor_utils import colo_model_data_tensor_move_inline, colo_tensor_mem_usage
|
||||
|
||||
|
||||
class StatefulTensorMgr(object):
|
|
@ -5,11 +5,12 @@ from typing import List, Optional, Type
|
|||
|
||||
import torch
|
||||
|
||||
from colossalai.gemini.memory_tracer import MemStatsCollector
|
||||
from colossalai.gemini.stateful_tensor import StatefulTensor
|
||||
from colossalai.gemini.tensor_utils import colo_model_data_tensor_move_inline, colo_tensor_mem_usage
|
||||
from colossalai.utils import get_current_device
|
||||
from colossalai.utils.memory import colo_device_memory_capacity
|
||||
from colossalai.zero.gemini.memory_tracer import MemStatsCollector
|
||||
|
||||
from .stateful_tensor import StatefulTensor
|
||||
from .tensor_utils import colo_model_data_tensor_move_inline, colo_tensor_mem_usage
|
||||
|
||||
|
||||
class TensorPlacementPolicy(ABC):
|
|
@ -1,6 +1,8 @@
|
|||
from typing import Tuple, Union
|
||||
|
||||
import torch
|
||||
from colossalai.gemini.stateful_tensor import StatefulTensor
|
||||
from typing import Union, Tuple
|
||||
|
||||
from .stateful_tensor import StatefulTensor
|
||||
|
||||
|
||||
def is_storage_empty(tensor: torch.Tensor) -> bool:
|
|
@ -13,10 +13,10 @@ from colossalai.context.singleton_meta import SingletonMeta
|
|||
from colossalai.core import global_context as gpc
|
||||
from colossalai.logging import get_dist_logger
|
||||
from colossalai.utils.model.utils import InsertPostInitMethodToModuleSubClasses
|
||||
from colossalai.zero.shard_utils import BaseShardStrategy
|
||||
from colossalai.zero.sharded_model._utils import cast_tensor_to_fp16
|
||||
from colossalai.zero.sharded_model.sharded_model_v2 import ShardedModelV2
|
||||
from colossalai.zero.sharded_param import ShardedParamV2
|
||||
from colossalai.zero.legacy.shard_utils import BaseShardStrategy
|
||||
from colossalai.zero.legacy.sharded_model._utils import cast_tensor_to_fp16
|
||||
from colossalai.zero.legacy.sharded_model.sharded_model_v2 import ShardedModelV2
|
||||
from colossalai.zero.legacy.sharded_param import ShardedParamV2
|
||||
|
||||
|
||||
@dataclass
|
|
@ -2,7 +2,8 @@ from abc import ABC, abstractmethod
|
|||
from typing import List, Optional
|
||||
|
||||
import torch.distributed as dist
|
||||
from colossalai.zero.sharded_param.sharded_tensor import ShardedTensor
|
||||
|
||||
from colossalai.zero.legacy.sharded_param.sharded_tensor import ShardedTensor
|
||||
|
||||
|
||||
class BaseShardStrategy(ABC):
|
|
@ -2,17 +2,18 @@ from typing import List, Optional
|
|||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from colossalai.utils import get_current_device
|
||||
from colossalai.zero.sharded_param.sharded_tensor import ShardedTensor
|
||||
from torch._utils import _flatten_dense_tensors as flatten
|
||||
|
||||
from colossalai.utils import get_current_device
|
||||
from colossalai.zero.legacy.sharded_param.sharded_tensor import ShardedTensor
|
||||
|
||||
from .tensor_shard_strategy import TensorShardStrategy
|
||||
|
||||
|
||||
class BucketTensorShardStrategy(TensorShardStrategy):
|
||||
"""Use the same shard scheme as `TensorShardStrategy`'s, but it gathers tensors of a sub-module together,
|
||||
which will fully utilize network bandwidth.
|
||||
It is especially useful when sub-module contains bias,
|
||||
"""Use the same shard scheme as `TensorShardStrategy`'s, but it gathers tensors of a sub-module together,
|
||||
which will fully utilize network bandwidth.
|
||||
It is especially useful when a sub-module contains a bias,
|
||||
since we cannot utilize network bandwidth well if we only gather a bias tensor (bias is usually small).
|
||||
"""
|
||||
|
|
@ -1,7 +1,7 @@
|
|||
import torch
|
||||
import torch.nn.functional as F
|
||||
from typing import Tuple
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
def get_shard(tensor: torch.Tensor, rank: int, world_size: int) -> Tuple[torch.Tensor, int]:
|
||||
"""Return the local shard of a full tensor."""
|
|
@ -2,11 +2,12 @@ from typing import List, Optional
|
|||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
|
||||
from colossalai.utils import get_current_device
|
||||
from colossalai.zero.shard_utils import BaseShardStrategy
|
||||
from colossalai.zero.shard_utils.commons import get_shard
|
||||
from colossalai.zero.sharded_param.sharded_tensor import ShardedTensor
|
||||
from colossalai.gemini.tensor_utils import colo_model_data_tensor_move_inline
|
||||
from colossalai.zero.legacy.gemini.tensor_utils import colo_model_data_tensor_move_inline
|
||||
from colossalai.zero.legacy.shard_utils import BaseShardStrategy
|
||||
from colossalai.zero.legacy.shard_utils.commons import get_shard
|
||||
from colossalai.zero.legacy.sharded_param.sharded_tensor import ShardedTensor
|
||||
|
||||
|
||||
class TensorShardStrategy(BaseShardStrategy):
|
||||
|
@ -27,7 +28,7 @@ class TensorShardStrategy(BaseShardStrategy):
|
|||
|
||||
Args:
|
||||
t (ShardedTensor): a tensor to be sharded.
|
||||
process_group (Optional[dist.ProcessGroup], optional): the process group among which tensor shards.
|
||||
process_group (Optional[dist.ProcessGroup], optional): the process group across which the tensor is sharded.
|
||||
Defaults to None.
|
||||
"""
|
||||
if t.is_sharded:
|
|
@ -1,3 +1,3 @@
|
|||
from .sharded_model_v2 import ShardedModelV2
|
||||
|
||||
__all__ = ['ShardedModelV2']
|
||||
__all__ = ['ShardedModelV2']
|
|
@ -1,9 +1,9 @@
|
|||
from typing import Any, Callable, List, Tuple
|
||||
from typing import Any, Callable, List, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from typing import Union
|
||||
from colossalai.gemini.stateful_tensor import StatefulTensor
|
||||
|
||||
from colossalai.zero.legacy.gemini.stateful_tensor import StatefulTensor
|
||||
|
||||
|
||||
def get_gradient_predivide_factor(world_size: int) -> float:
|
|
@ -13,19 +13,18 @@ from torch.nn.parameter import Parameter
|
|||
|
||||
from colossalai.context.parallel_mode import ParallelMode
|
||||
from colossalai.core import global_context as gpc
|
||||
from colossalai.gemini.memory_tracer import MemStatsCollector, StaticMemStatsCollector
|
||||
from colossalai.gemini.ophooks import register_ophooks_recursively
|
||||
from colossalai.gemini.paramhooks import BaseParamHookMgr
|
||||
from colossalai.gemini.stateful_tensor import TensorState
|
||||
from colossalai.gemini.stateful_tensor_mgr import StatefulTensorMgr
|
||||
from colossalai.gemini.tensor_placement_policy import TensorPlacementPolicy, TensorPlacementPolicyFactory
|
||||
from colossalai.gemini.tensor_utils import colo_model_data_move_to_cpu
|
||||
from colossalai.logging import get_dist_logger
|
||||
from colossalai.utils import disposable, get_current_device
|
||||
from colossalai.utils.memory import colo_device_memory_capacity
|
||||
from colossalai.zero.shard_utils import BaseShardStrategy
|
||||
from colossalai.zero.sharded_model.reduce_scatter import ReduceScatterBucketer
|
||||
from colossalai.zero.utils import ZeroHook
|
||||
from colossalai.zero.gemini.memory_tracer import MemStatsCollector, StaticMemStatsCollector
|
||||
from colossalai.zero.legacy.gemini.ophooks import register_ophooks_recursively
|
||||
from colossalai.zero.legacy.gemini.paramhooks import BaseParamHookMgr
|
||||
from colossalai.zero.legacy.gemini.stateful_tensor import TensorState
|
||||
from colossalai.zero.legacy.gemini.stateful_tensor_mgr import StatefulTensorMgr
|
||||
from colossalai.zero.legacy.gemini.tensor_placement_policy import TensorPlacementPolicy, TensorPlacementPolicyFactory
|
||||
from colossalai.zero.legacy.gemini.tensor_utils import colo_model_data_move_to_cpu
|
||||
from colossalai.zero.legacy.shard_utils import BaseShardStrategy
|
||||
from colossalai.zero.legacy.sharded_model.reduce_scatter import ReduceScatterBucketer
|
||||
|
||||
from ._utils import (
|
||||
cast_float_arguments,
|
||||
|
@ -35,6 +34,7 @@ from ._utils import (
|
|||
free_storage,
|
||||
get_gradient_predivide_factor,
|
||||
)
|
||||
from .zero_hook import ZeroHook
|
||||
|
||||
try:
|
||||
from torch.nn.modules.module import _EXTRA_STATE_KEY_SUFFIX
|
|
@ -1,8 +1,9 @@
|
|||
import torch
|
||||
from colossalai.zero.sharded_model import ShardedModelV2
|
||||
|
||||
import copy
|
||||
|
||||
import torch
|
||||
|
||||
from colossalai.zero.legacy.sharded_model import ShardedModelV2
|
||||
|
||||
|
||||
def col_model_deepcopy(sharded_model: ShardedModelV2, other_model: torch.nn.Module):
|
||||
"""
|
|
@ -3,14 +3,14 @@ from typing import Optional
|
|||
import torch
|
||||
import torch.distributed as dist
|
||||
|
||||
from colossalai.gemini.memory_tracer import MemStatsCollector
|
||||
from colossalai.gemini.ophooks import BaseOpHook
|
||||
from colossalai.gemini.stateful_tensor import TensorState
|
||||
from colossalai.gemini.stateful_tensor_mgr import StatefulTensorMgr
|
||||
from colossalai.logging import get_dist_logger
|
||||
from colossalai.registry import OPHOOKS
|
||||
from colossalai.utils import get_current_device
|
||||
from colossalai.zero.shard_utils import BaseShardStrategy
|
||||
from colossalai.zero.gemini.memory_tracer import MemStatsCollector
|
||||
from colossalai.zero.legacy.gemini.ophooks import BaseOpHook
|
||||
from colossalai.zero.legacy.gemini.stateful_tensor import TensorState
|
||||
from colossalai.zero.legacy.gemini.stateful_tensor_mgr import StatefulTensorMgr
|
||||
from colossalai.zero.legacy.shard_utils import BaseShardStrategy
|
||||
|
||||
|
||||
@OPHOOKS.register_module
|
|
@ -0,0 +1,3 @@
|
|||
from .sharded_optim_v2 import ShardedOptimizerV2
|
||||
|
||||
__all__ = ['ShardedOptimizerV2']
|
|
@ -14,13 +14,13 @@ from torch.optim import Optimizer
|
|||
from colossalai.amp.naive_amp.grad_scaler import DynamicGradScaler
|
||||
from colossalai.context.parallel_mode import ParallelMode
|
||||
from colossalai.core import global_context as gpc
|
||||
from colossalai.gemini.stateful_tensor import StatefulTensor, TensorState
|
||||
from colossalai.gemini.tensor_placement_policy import AutoTensorPlacementPolicy
|
||||
from colossalai.gemini.tensor_utils import colo_model_data_tensor_move_inline, colo_tensor_mem_usage
|
||||
from colossalai.logging import get_dist_logger
|
||||
from colossalai.nn.optimizer import ColossalaiOptimizer
|
||||
from colossalai.zero.sharded_model import ShardedModelV2
|
||||
from colossalai.zero.sharded_model._utils import cast_tensor_to_fp32
|
||||
from colossalai.zero.legacy.gemini.stateful_tensor import StatefulTensor, TensorState
|
||||
from colossalai.zero.legacy.gemini.tensor_placement_policy import AutoTensorPlacementPolicy
|
||||
from colossalai.zero.legacy.gemini.tensor_utils import colo_model_data_tensor_move_inline, colo_tensor_mem_usage
|
||||
from colossalai.zero.legacy.sharded_model import ShardedModelV2
|
||||
from colossalai.zero.legacy.sharded_model._utils import cast_tensor_to_fp32
|
||||
|
||||
|
||||
class OptimState(Enum):
|
|
@ -0,0 +1,4 @@
|
|||
from .sharded_param import ShardedParamV2
|
||||
from .sharded_tensor import ShardedTensor
|
||||
|
||||
__all__ = ['ShardedTensor', 'ShardedParamV2']
|
|
@ -1,9 +1,11 @@
|
|||
from typing import List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
from typing import Optional, Tuple
|
||||
from colossalai.zero.sharded_param.sharded_tensor import ShardedTensor
|
||||
from colossalai.gemini.tensor_utils import colo_tensor_mem_usage
|
||||
from colossalai.gemini.stateful_tensor import StatefulTensor, TensorState
|
||||
from typing import List
|
||||
|
||||
from colossalai.zero.legacy.gemini.stateful_tensor import StatefulTensor, TensorState
|
||||
from colossalai.zero.legacy.gemini.tensor_utils import colo_tensor_mem_usage
|
||||
|
||||
from .sharded_tensor import ShardedTensor
|
||||
|
||||
EMPTY_TENSOR_DICT = {}
|
||||
|
|
@ -1,5 +1,6 @@
|
|||
import torch
|
||||
from colossalai.gemini.stateful_tensor import StatefulTensor, TensorState
|
||||
|
||||
from colossalai.zero.legacy.gemini.stateful_tensor import StatefulTensor, TensorState
|
||||
|
||||
|
||||
class ShardedTensor(StatefulTensor):
|
|
@ -0,0 +1,3 @@
|
|||
from .low_level_optim import LowLevelZeroOptimizer
|
||||
|
||||
__all__ = ['LowLevelZeroOptimizer']
|
|
@ -1,4 +0,0 @@
|
|||
from .low_level_optim import LowLevelZeroOptimizer
|
||||
from .sharded_optim_v2 import ShardedOptimizerV2
|
||||
|
||||
__all__ = ['ShardedOptimizerV2', 'LowLevelZeroOptimizer']
|
|
@ -1,4 +0,0 @@
|
|||
from colossalai.zero.sharded_param.sharded_tensor import ShardedTensor
|
||||
from colossalai.zero.sharded_param.sharded_param import ShardedParamV2
|
||||
|
||||
__all__ = ['ShardedTensor', 'ShardedParamV2']
|
|
@ -1,3 +0,0 @@
|
|||
from .zero_hook import ZeroHook
|
||||
|
||||
__all__ = ['ZeroHook']
|
|
@ -4,7 +4,7 @@ from typing import Dict, Optional
|
|||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from .gemini_parallel import GeminiDDP
|
||||
from .gemini import GeminiDDP
|
||||
|
||||
|
||||
def zero_model_wrapper(model: nn.Module, zero_stage: int = 1, gemini_config: Optional[Dict] = None):
|
||||
|
@ -99,11 +99,11 @@ def zero_optim_wrapper(model: nn.Module,
|
|||
config_dict['max_scale'] = max_scale
|
||||
|
||||
if zero_stage in [1, 2]:
|
||||
from colossalai.zero.sharded_optim.low_level_optim import LowLevelZeroOptimizer
|
||||
from colossalai.zero.low_level import LowLevelZeroOptimizer
|
||||
config_dict['partition_grad'] = zero_stage == 2
|
||||
config_dict['clip_grad_norm'] = max_norm
|
||||
return LowLevelZeroOptimizer(optimizer, **config_dict)
|
||||
else:
|
||||
from colossalai.nn.optimizer.zero_optimizer import ZeroOptimizer
|
||||
from colossalai.zero.gemini.gemini_optimizer import ZeroOptimizer
|
||||
config_dict['clipping_norm'] = max_norm
|
||||
return ZeroOptimizer(optimizer, model, **config_dict)
|
|
@ -78,7 +78,7 @@ from transformers.models.gpt2.modeling_gpt2 import GPT2LMHeadModel
|
|||
|
||||
import colossalai
|
||||
from colossalai.nn.optimizer import HybridAdam
|
||||
from colossalai.nn.parallel import zero_model_wrapper, zero_optim_wrapper
|
||||
from colossalai.zero import zero_model_wrapper, zero_optim_wrapper
|
||||
from colossalai.utils.model.colo_init_context import ColoInitContext
|
||||
```
|
||||
|
||||
|
|
|
@ -77,7 +77,7 @@ from transformers.models.gpt2.configuration_gpt2 import GPT2Config
|
|||
from transformers.models.gpt2.modeling_gpt2 import GPT2LMHeadModel
|
||||
import colossalai
|
||||
from colossalai.nn.optimizer import HybridAdam
|
||||
from colossalai.nn.parallel import zero_model_wrapper, zero_optim_wrapper
|
||||
from colossalai.zero import zero_model_wrapper, zero_optim_wrapper
|
||||
from colossalai.utils.model.colo_init_context import ColoInitContext
|
||||
```
|
||||
|
||||
|
|
|
@ -5,7 +5,7 @@ torchrun --standalone --nproc_per_node=1 debug.py
|
|||
from diffusers import AutoencoderKL
|
||||
|
||||
import colossalai
|
||||
from colossalai.utils.model.colo_init_context import ColoInitContext, post_process_colo_init_ctx
|
||||
from colossalai.zero import ColoInitContext, post_process_colo_init_ctx
|
||||
|
||||
path = "/data/scratch/diffuser/stable-diffusion-v1-4"
|
||||
|
||||
|
|
|
@ -21,10 +21,9 @@ import colossalai
|
|||
from colossalai.context.parallel_mode import ParallelMode
|
||||
from colossalai.core import global_context as gpc
|
||||
from colossalai.logging import disable_existing_loggers, get_dist_logger
|
||||
from colossalai.nn.optimizer.gemini_optimizer import GeminiAdamOptimizer
|
||||
from colossalai.nn.parallel.utils import get_static_torch_model
|
||||
from colossalai.utils import get_current_device
|
||||
from colossalai.utils.model.colo_init_context import ColoInitContext
|
||||
from colossalai.zero import ColoInitContext, GeminiAdamOptimizer
|
||||
from colossalai.zero.gemini import get_static_torch_model
|
||||
|
||||
disable_existing_loggers()
|
||||
logger = get_dist_logger()
|
||||
|
|
|
@ -23,10 +23,9 @@ import colossalai
|
|||
from colossalai.context.parallel_mode import ParallelMode
|
||||
from colossalai.core import global_context as gpc
|
||||
from colossalai.logging import disable_existing_loggers, get_dist_logger
|
||||
from colossalai.nn.optimizer.gemini_optimizer import GeminiAdamOptimizer
|
||||
from colossalai.nn.parallel.utils import get_static_torch_model
|
||||
from colossalai.utils import get_current_device
|
||||
from colossalai.utils.model.colo_init_context import ColoInitContext
|
||||
from colossalai.zero import ColoInitContext, GeminiAdamOptimizer
|
||||
from colossalai.zero.gemini import get_static_torch_model
|
||||
|
||||
disable_existing_loggers()
|
||||
logger = get_dist_logger()
|
||||
|
|
|
@ -18,7 +18,7 @@ from colossalai.tensor import ComputePattern, ComputeSpec, DistSpecManager, Proc
|
|||
from colossalai.testing import rerun_if_address_is_in_use
|
||||
from colossalai.utils import free_port
|
||||
from colossalai.utils.cuda import get_current_device
|
||||
from colossalai.utils.model.colo_init_context import ColoInitContext
|
||||
from colossalai.zero import ColoInitContext
|
||||
|
||||
|
||||
def set_seed(seed):
|
||||
|
|
|
@ -19,7 +19,7 @@ from colossalai.nn.optimizer import HybridAdam
|
|||
from colossalai.nn.parallel.data_parallel import ColoDDP
|
||||
from colossalai.tensor import ComputePattern, ComputeSpec, DistSpecManager, ProcessGroup, ShardSpec
|
||||
from colossalai.utils import get_current_device
|
||||
from colossalai.utils.model.colo_init_context import ColoInitContext
|
||||
from colossalai.zero import ColoInitContext
|
||||
|
||||
|
||||
def init_1d_row_for_linear_weight_spec(model, world_size: int):
|
||||
|
|
|
@ -12,10 +12,9 @@ from transformers import AlbertConfig, AlbertForSequenceClassification, BertConf
|
|||
import colossalai
|
||||
from colossalai.logging import disable_existing_loggers, get_dist_logger
|
||||
from colossalai.nn.optimizer import HybridAdam
|
||||
from colossalai.nn.parallel import zero_model_wrapper, zero_optim_wrapper
|
||||
from colossalai.tensor import ColoParameter, ComputePattern, ComputeSpec, ProcessGroup, ReplicaSpec, ShardSpec
|
||||
from colossalai.utils import get_current_device
|
||||
from colossalai.utils.model.colo_init_context import ColoInitContext
|
||||
from colossalai.zero import ColoInitContext, zero_model_wrapper, zero_optim_wrapper
|
||||
|
||||
CAI_VERSION = colossalai.__version__
|
||||
|
||||
|
|
|
@ -13,10 +13,9 @@ from torch.nn.parallel import DistributedDataParallel as DDP
|
|||
import colossalai
|
||||
from colossalai.logging import disable_existing_loggers, get_dist_logger
|
||||
from colossalai.nn.optimizer import HybridAdam
|
||||
from colossalai.nn.parallel import zero_model_wrapper, zero_optim_wrapper
|
||||
from colossalai.tensor import ColoParameter, ComputePattern, ComputeSpec, ProcessGroup, ReplicaSpec, ShardSpec
|
||||
from colossalai.utils import get_current_device
|
||||
from colossalai.utils.model.colo_init_context import ColoInitContext
|
||||
from colossalai.zero import ColoInitContext, zero_model_wrapper, zero_optim_wrapper
|
||||
|
||||
CAI_VERSION = colossalai.__version__
|
||||
|
||||
|
|
|
@ -34,12 +34,9 @@ from transformers.utils.versions import require_version
|
|||
|
||||
import colossalai
|
||||
from colossalai.logging import disable_existing_loggers, get_dist_logger
|
||||
from colossalai.nn.optimizer.gemini_optimizer import GeminiAdamOptimizer
|
||||
from colossalai.nn.parallel import GeminiDDP
|
||||
from colossalai.utils import get_current_device
|
||||
from colossalai.utils.model.colo_init_context import ColoInitContext
|
||||
|
||||
from colossalai.tensor import ProcessGroup, ShardSpec
|
||||
from colossalai.utils import get_current_device
|
||||
from colossalai.zero import ColoInitContext, GeminiAdamOptimizer, GeminiDDP
|
||||
|
||||
|
||||
def get_data(batch_size, seq_len, vocab_size):
|
||||
|
@ -179,13 +176,15 @@ def main():
|
|||
# build model
|
||||
if args.model_name_or_path is None:
|
||||
logger.info("Train a new model from scratch", ranks=[0])
|
||||
with ColoInitContext(device=init_dev, dtype=torch.half,
|
||||
with ColoInitContext(device=init_dev,
|
||||
dtype=torch.half,
|
||||
default_dist_spec=default_dist_spec,
|
||||
default_pg=shard_pg):
|
||||
model = OPTForCausalLM(config)
|
||||
else:
|
||||
logger.info("Finetune a pre-trained model", ranks=[0])
|
||||
with ColoInitContext(device=init_dev, dtype=torch.half,
|
||||
with ColoInitContext(device=init_dev,
|
||||
dtype=torch.half,
|
||||
default_dist_spec=default_dist_spec,
|
||||
default_pg=shard_pg):
|
||||
model = OPTForCausalLM.from_pretrained(args.model_name_or_path,
|
||||
|
@ -198,8 +197,11 @@ def main():
|
|||
|
||||
numel = sum([p.numel() for p in model.parameters()])
|
||||
PLACEMENT_POLICY = 'cpu'
|
||||
model = GeminiDDP(model, device=get_current_device(), placement_policy=PLACEMENT_POLICY,
|
||||
pin_memory=True, strict_ddp_mode=args.shardinit)
|
||||
model = GeminiDDP(model,
|
||||
device=get_current_device(),
|
||||
placement_policy=PLACEMENT_POLICY,
|
||||
pin_memory=True,
|
||||
strict_ddp_mode=args.shardinit)
|
||||
optimizer = GeminiAdamOptimizer(model, lr=args.learning_rate, initial_scale=2**14, gpu_margin_mem_ratio=0.0)
|
||||
|
||||
SEQ_LEN = 1024
|
||||
|
|
|
@ -15,11 +15,9 @@ from torch.utils.data import DataLoader, Dataset
|
|||
|
||||
import colossalai
|
||||
from colossalai.logging import disable_existing_loggers, get_dist_logger
|
||||
from colossalai.nn.optimizer.gemini_optimizer import GeminiAdamOptimizer
|
||||
from colossalai.nn.parallel import ZeroDDP
|
||||
from colossalai.tensor import ColoParameter, ComputePattern, ComputeSpec, ProcessGroup, ReplicaSpec, ShardSpec
|
||||
from colossalai.utils import MultiTimer, get_current_device
|
||||
from colossalai.utils.model.colo_init_context import ColoInitContext
|
||||
from colossalai.zero import ColoInitContext, GeminiAdamOptimizer, ZeroDDP
|
||||
|
||||
# constants
|
||||
|
||||
|
@ -127,7 +125,7 @@ def gemini_zero_dpp(model: torch.nn.Module, pg: ProcessGroup, placememt_policy:
|
|||
return model
|
||||
|
||||
|
||||
## Parameter Sharding Strategies for Tensor Parallelism
|
||||
# Parameter Sharding Strategies for Tensor Parallelism
|
||||
def split_param_single_dim_tp1d(dim: int, param: ColoParameter, pg: ProcessGroup):
|
||||
spec = (ShardSpec([dim], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
|
||||
param.set_tensor_spec(*spec)
|
||||
|
@ -232,7 +230,7 @@ if args.distplan == "colossalai":
|
|||
tensor_parallelize(model, pg)
|
||||
model = gemini_zero_dpp(model, pg, args.placement)
|
||||
|
||||
#optimizer
|
||||
# optimizer
|
||||
|
||||
#optimizer = GeminiAdamOptimizer(model, lr=1e-7, initial_scale=2**5)
|
||||
optimizer = GeminiAdamOptimizer(model, lr=LEARNING_RATE, initial_scale=2**5)
|
||||
|
|
|
@ -1,69 +1,67 @@
|
|||
import colossalai
|
||||
import math
|
||||
import torch
|
||||
from colossalai.context import ParallelMode
|
||||
from colossalai.core import global_context as gpc
|
||||
import colossalai.nn as col_nn
|
||||
from arguments import parse_args
|
||||
from pretrain_utils import get_model, get_optimizer, get_lr_scheduler, save_ckpt
|
||||
from utils.exp_util import get_tflops, get_mem_info, throughput_calculator, log_args
|
||||
from utils.global_vars import set_global_variables, get_timers, get_tensorboard_writer
|
||||
from utils.logger import Logger
|
||||
from evaluation import evaluate
|
||||
from loss import LossForPretraining
|
||||
|
||||
from colossalai.zero.init_ctx import ZeroInitContext
|
||||
from colossalai.zero.shard_utils import TensorShardStrategy
|
||||
from colossalai.zero.sharded_model import ShardedModelV2
|
||||
from colossalai.zero.sharded_optim import ShardedOptimizerV2
|
||||
from nvidia_bert_dataset_provider import NvidiaBertDatasetProvider
|
||||
from tqdm import tqdm
|
||||
import os
|
||||
import time
|
||||
from functools import partial
|
||||
|
||||
import torch
|
||||
from arguments import parse_args
|
||||
from evaluation import evaluate
|
||||
from loss import LossForPretraining
|
||||
from nvidia_bert_dataset_provider import NvidiaBertDatasetProvider
|
||||
from pretrain_utils import get_lr_scheduler, get_model, get_optimizer, save_ckpt
|
||||
from tqdm import tqdm
|
||||
from transformers import AutoTokenizer
|
||||
from utils.exp_util import get_mem_info, get_tflops, log_args, throughput_calculator
|
||||
from utils.global_vars import get_tensorboard_writer, get_timers, set_global_variables
|
||||
from utils.logger import Logger
|
||||
|
||||
from colossalai.gemini import ChunkManager, GeminiManager
|
||||
from colossalai.utils.model.colo_init_context import ColoInitContext
|
||||
from colossalai.utils import get_current_device
|
||||
from colossalai.nn.parallel import ZeroDDP
|
||||
from colossalai.zero import ZeroOptimizer
|
||||
from colossalai.tensor import ProcessGroup
|
||||
import colossalai
|
||||
import colossalai.nn as col_nn
|
||||
from colossalai.context import ParallelMode
|
||||
from colossalai.core import global_context as gpc
|
||||
from colossalai.nn.optimizer import HybridAdam
|
||||
from colossalai.nn.parallel import ZeroDDP
|
||||
from colossalai.tensor import ProcessGroup
|
||||
from colossalai.utils import get_current_device
|
||||
from colossalai.zero import ZeroOptimizer
|
||||
from colossalai.zero.gemini import ChunkManager, ColoInitContext, GeminiManager
|
||||
from colossalai.zero.legacy import ShardedModelV2, ShardedOptimizerV2, ZeroInitContext
|
||||
from colossalai.zero.legacy.shard_utils import TensorShardStrategy
|
||||
|
||||
|
||||
def main():
|
||||
|
||||
args = parse_args()
|
||||
launch_time = time.strftime("%Y-%m-%d-%H:%M:%S", time.localtime())
|
||||
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_path)
|
||||
|
||||
os.environ['CUDA_LAUNCH_BLOCKING'] = '1'
|
||||
|
||||
|
||||
logger = Logger(os.path.join(args.log_path, launch_time), cuda=torch.cuda.is_available(), debug=args.vscode_debug)
|
||||
|
||||
|
||||
if args.vscode_debug:
|
||||
colossalai.launch(config={},
|
||||
rank=args.rank,
|
||||
world_size=args.world_size,
|
||||
host=args.host,
|
||||
port=args.port,
|
||||
backend=args.backend)
|
||||
rank=args.rank,
|
||||
world_size=args.world_size,
|
||||
host=args.host,
|
||||
port=args.port,
|
||||
backend=args.backend)
|
||||
args.local_rank = -1
|
||||
args.log_interval = 1
|
||||
else:
|
||||
colossalai.launch_from_torch(args.colossal_config) #args.colossal_config
|
||||
colossalai.launch_from_torch(args.colossal_config) # args.colossal_config
|
||||
args.local_rank = int(os.environ["LOCAL_RANK"])
|
||||
logger.info(f'launch_from_torch, world size: {torch.distributed.get_world_size()} | ' +
|
||||
f'ParallelMode.MODEL: {ParallelMode.MODEL} | ParallelMode.DATA: {ParallelMode.DATA} | ParallelMode.TENSOR: {ParallelMode.TENSOR}')
|
||||
logger.info(
|
||||
f'launch_from_torch, world size: {torch.distributed.get_world_size()} | ' +
|
||||
f'ParallelMode.MODEL: {ParallelMode.MODEL} | ParallelMode.DATA: {ParallelMode.DATA} | ParallelMode.TENSOR: {ParallelMode.TENSOR}'
|
||||
)
|
||||
|
||||
log_args(logger, args)
|
||||
args.tokenizer = tokenizer
|
||||
args.logger = logger
|
||||
set_global_variables(launch_time, args.tensorboard_path)
|
||||
|
||||
|
||||
use_zero = hasattr(gpc.config, 'zero')
|
||||
world_size = torch.distributed.get_world_size()
|
||||
|
||||
|
@ -71,8 +69,8 @@ def main():
|
|||
if use_zero:
|
||||
shard_strategy = TensorShardStrategy()
|
||||
with ZeroInitContext(target_device=torch.cuda.current_device(), shard_strategy=shard_strategy,
|
||||
shard_param=True):
|
||||
|
||||
shard_param=True):
|
||||
|
||||
config, model, numel = get_model(args, logger)
|
||||
# model = ShardedModelV2(model, shard_strategy, tensor_placement_policy='cpu', reuse_fp16_shard=True)
|
||||
else:
|
||||
|
@ -82,9 +80,10 @@ def main():
|
|||
os.mkdir(os.path.join(args.ckpt_path, launch_time))
|
||||
|
||||
logger.info(f'Model numel: {numel}')
|
||||
|
||||
|
||||
get_tflops_func = partial(get_tflops, numel, args.train_micro_batch_size_per_gpu, args.max_seq_length)
|
||||
steps_per_epoch = 144003367 // world_size // args.train_micro_batch_size_per_gpu // args.gradient_accumulation_steps // args.refresh_bucket_size #len(dataloader)
|
||||
# len(dataloader)
|
||||
steps_per_epoch = 144003367 // world_size // args.train_micro_batch_size_per_gpu // args.gradient_accumulation_steps // args.refresh_bucket_size
|
||||
total_steps = steps_per_epoch * args.epoch
|
||||
|
||||
# build optimizer and lr_scheduler
|
||||
|
@ -98,18 +97,23 @@ def main():
|
|||
o_l_state_dict['lr_scheduler']['last_epoch'] = o_l_state_dict['lr_scheduler']['last_epoch'] - 1
|
||||
optimizer = get_optimizer(model, lr=args.lr)
|
||||
optimizer.load_state_dict(o_l_state_dict['optimizer'])
|
||||
lr_scheduler = get_lr_scheduler(optimizer, total_steps=total_steps, last_epoch=o_l_state_dict['lr_scheduler']['last_epoch']) #o_l_state_dict['lr_scheduler']['last_epoch']
|
||||
# o_l_state_dict['lr_scheduler']['last_epoch']
|
||||
lr_scheduler = get_lr_scheduler(optimizer,
|
||||
total_steps=total_steps,
|
||||
last_epoch=o_l_state_dict['lr_scheduler']['last_epoch'])
|
||||
for state in optimizer.state.values():
|
||||
for k, v in state.items():
|
||||
if isinstance(v, torch.Tensor):
|
||||
state[k] = v.cuda(f"cuda:{torch.cuda.current_device()}")
|
||||
# if you want to delete the three lines of code above, you have to move the model to GPU instead, because optimizer.step() runs there
|
||||
lr_scheduler.load_state_dict(o_l_state_dict['lr_scheduler'])
|
||||
|
||||
|
||||
start_epoch = o_l_state_dict['epoch']
|
||||
start_shard = o_l_state_dict['shard'] + 1
|
||||
# global_step = o_l_state_dict['global_step'] + 1
|
||||
logger.info(f'resume from epoch {start_epoch} shard {start_shard} step {lr_scheduler.last_epoch} lr {lr_scheduler.get_last_lr()[0]}')
|
||||
logger.info(
|
||||
f'resume from epoch {start_epoch} shard {start_shard} step {lr_scheduler.last_epoch} lr {lr_scheduler.get_last_lr()[0]}'
|
||||
)
|
||||
else:
|
||||
optimizer = get_optimizer(model, lr=args.lr)
|
||||
lr_scheduler = get_lr_scheduler(optimizer, total_steps=total_steps, last_epoch=-1)
|
||||
|
@ -124,12 +128,11 @@ def main():
|
|||
|
||||
# initialize with colossalai
|
||||
engine, _, _, lr_scheduelr = colossalai.initialize(model=model,
|
||||
optimizer=optimizer,
|
||||
criterion=criterion,
|
||||
lr_scheduler=lr_scheduler)
|
||||
|
||||
optimizer=optimizer,
|
||||
criterion=criterion,
|
||||
lr_scheduler=lr_scheduler)
|
||||
|
||||
logger.info(get_mem_info(prefix='After init model, '))
|
||||
|
||||
|
||||
best_loss = None
|
||||
eval_loss = 0
|
||||
|
@ -146,13 +149,16 @@ def main():
|
|||
dataset_iterator, total_length = pretrain_dataset_provider.get_shard(shard)
|
||||
# pretrain_dataset_provider.prefetch_shard(shard + 1) # may cause cpu memory overload
|
||||
if torch.distributed.get_rank() == 0:
|
||||
iterator_data = tqdm(enumerate(dataset_iterator), total=(total_length // args.train_micro_batch_size_per_gpu // world_size), colour='cyan', smoothing=1)
|
||||
iterator_data = tqdm(enumerate(dataset_iterator),
|
||||
total=(total_length // args.train_micro_batch_size_per_gpu // world_size),
|
||||
colour='cyan',
|
||||
smoothing=1)
|
||||
else:
|
||||
iterator_data = enumerate(dataset_iterator)
|
||||
|
||||
engine.train()
|
||||
|
||||
for step, batch_data in iterator_data:
|
||||
|
||||
for step, batch_data in iterator_data:
|
||||
|
||||
# batch_data = pretrain_dataset_provider.get_batch(batch_index)
|
||||
input_ids = batch_data[0].cuda(f"cuda:{torch.cuda.current_device()}")
|
||||
|
@ -162,7 +168,7 @@ def main():
|
|||
# nsp_label = batch_data[5].cuda()
|
||||
|
||||
output = engine(input_ids=input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask)
|
||||
|
||||
|
||||
loss = engine.criterion(output.logits, mlm_label)
|
||||
pretrain_dataset_provider.prefetch_batch()
|
||||
|
||||
|
@ -172,14 +178,15 @@ def main():
|
|||
engine.step()
|
||||
lr_scheduelr.step()
|
||||
engine.zero_grad()
|
||||
|
||||
|
||||
global_step += 1
|
||||
|
||||
if global_step % args.log_interval == 0 and global_step != 0 \
|
||||
and torch.distributed.get_rank() == 0:
|
||||
and torch.distributed.get_rank() == 0:
|
||||
elapsed_time = timers('interval_time').elapsed(reset=False)
|
||||
elapsed_time_per_iteration = elapsed_time / global_step
|
||||
samples_per_sec, tflops, approx_parameters_in_billions = throughput_calculator(numel, args, config, elapsed_time, global_step, world_size)
|
||||
samples_per_sec, tflops, approx_parameters_in_billions = throughput_calculator(
|
||||
numel, args, config, elapsed_time, global_step, world_size)
|
||||
|
||||
cur_loss = train_loss / args.log_interval
|
||||
current_lr = lr_scheduelr.get_last_lr()[0]
|
||||
|
@ -189,12 +196,13 @@ def main():
|
|||
|
||||
if args.wandb:
|
||||
tensorboard_log = get_tensorboard_writer()
|
||||
tensorboard_log.log_train({
|
||||
'lr': current_lr,
|
||||
'loss': cur_loss,
|
||||
'ppl': math.exp(cur_loss),
|
||||
'mins_batch': elapsed_time_per_iteration
|
||||
}, global_step)
|
||||
tensorboard_log.log_train(
|
||||
{
|
||||
'lr': current_lr,
|
||||
'loss': cur_loss,
|
||||
'ppl': math.exp(cur_loss),
|
||||
'mins_batch': elapsed_time_per_iteration
|
||||
}, global_step)
|
||||
|
||||
train_loss = 0
|
||||
|
||||
|
@ -202,12 +210,14 @@ def main():
|
|||
logger.info('*' * 100)
|
||||
|
||||
eval_loss += evaluate(engine, args, logger, global_step)
|
||||
save_ckpt(engine.model, optimizer, lr_scheduelr, os.path.join(args.ckpt_path, launch_time, f'epoch-{epoch}_shard-{shard}_' + launch_time), epoch, shard, global_step)
|
||||
|
||||
|
||||
save_ckpt(engine.model, optimizer, lr_scheduelr,
|
||||
os.path.join(args.ckpt_path, launch_time, f'epoch-{epoch}_shard-{shard}_' + launch_time), epoch,
|
||||
shard, global_step)
|
||||
|
||||
eval_loss /= len(os.listdir(args.data_path_prefix))
|
||||
logger.info(f'epoch {epoch} | shard_length {len(os.listdir(args.data_path_prefix))} | elapsed_time: {timers("epoch_time").elapsed() / 60 :.3f} mins' + \
|
||||
f'eval_loss: {eval_loss} | ppl: {math.exp(eval_loss)}')
|
||||
logger.info(
|
||||
f'epoch {epoch} | shard_length {len(os.listdir(args.data_path_prefix))} | elapsed_time: {timers("epoch_time").elapsed() / 60 :.3f} mins'
|
||||
+ f' | eval_loss: {eval_loss} | ppl: {math.exp(eval_loss)}')
|
||||
logger.info('-' * 100)
|
||||
if args.wandb and torch.distributed.get_rank() == 0:
|
||||
tensorboard_log = get_tensorboard_writer()
|
||||
|
|
|
@ -30,24 +30,13 @@ from itertools import chain
|
|||
import datasets
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
import transformers
|
||||
from accelerate.utils import set_seed
|
||||
from context import barrier_context
|
||||
from datasets import load_dataset
|
||||
from packaging import version
|
||||
from torch.utils.data import DataLoader
|
||||
from tqdm.auto import tqdm
|
||||
|
||||
import colossalai
|
||||
import transformers
|
||||
from colossalai.context import ParallelMode
|
||||
from colossalai.core import global_context as gpc
|
||||
from colossalai.logging import disable_existing_loggers, get_dist_logger
|
||||
from colossalai.nn.optimizer import HybridAdam
|
||||
from colossalai.nn.optimizer.zero_optimizer import ZeroOptimizer
|
||||
from colossalai.nn.parallel import ZeroDDP
|
||||
from colossalai.tensor import ProcessGroup
|
||||
from colossalai.utils import get_current_device, get_dataloader
|
||||
from colossalai.utils.model.colo_init_context import ColoInitContext
|
||||
from transformers import (
|
||||
CONFIG_MAPPING,
|
||||
MODEL_MAPPING,
|
||||
|
@ -61,6 +50,15 @@ from transformers import (
)
from transformers.utils.versions import require_version

import colossalai
from colossalai.context import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.logging import disable_existing_loggers, get_dist_logger
from colossalai.nn.optimizer import HybridAdam
from colossalai.tensor import ProcessGroup
from colossalai.utils import get_current_device, get_dataloader
from colossalai.zero import ColoInitContext, ZeroDDP, ZeroOptimizer

require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/language-modeling/requirements.txt")

MODEL_CONFIG_CLASSES = list(MODEL_MAPPING.keys())
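The net effect of the two hunks above is that this example's scattered ZeRO-related imports collapse onto the new colossalai.zero package. A minimal before/after sketch of just the import change, taken from the lines in this diff (not additional code in the commit):

# Legacy paths removed by this commit:
#   from colossalai.utils.model.colo_init_context import ColoInitContext
#   from colossalai.nn.parallel import ZeroDDP
#   from colossalai.nn.optimizer.zero_optimizer import ZeroOptimizer
# Consolidated path introduced by this commit:
from colossalai.zero import ColoInitContext, ZeroDDP, ZeroOptimizer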
@ -12,10 +12,9 @@ from colossalai.auto_parallel.offload.mem_optimize import memory_optimize
from colossalai.auto_parallel.offload.solver import NOT_NVML
from colossalai.fx.profiler import parameter_size
from colossalai.nn.optimizer import HybridAdam
from colossalai.nn.parallel import zero_model_wrapper, zero_optim_wrapper
from colossalai.testing import parameterize
from colossalai.utils import free_port, get_current_device
from colossalai.utils.model.colo_init_context import ColoInitContext
from colossalai.zero import ColoInitContext, zero_model_wrapper, zero_optim_wrapper
from tests.test_auto_parallel.test_offload.model_utils import *
from tests.test_tensor.common_utils import set_seed
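For context, zero_model_wrapper and zero_optim_wrapper (now re-exported from colossalai.zero) are typically wired together roughly as sketched below. This is not part of the commit; the gemini_config keys, numeric values, and the assumption that colossalai.launch has already initialized the process group are all illustrative.

import torch
import colossalai
from colossalai.nn.optimizer import HybridAdam
from colossalai.utils import get_current_device
from colossalai.zero import ColoInitContext, zero_model_wrapper, zero_optim_wrapper

# assumes colossalai.launch(...) / launch_from_torch(...) has been called beforehand
with ColoInitContext(device=get_current_device()):
    model = torch.nn.Linear(1024, 1024)    # any module built under the init context

# zero_stage=3 with a gemini_config selects the chunk-based Gemini wrapper;
# the config values below are assumptions for illustration only.
model = zero_model_wrapper(model,
                           zero_stage=3,
                           gemini_config=dict(device=get_current_device(),
                                              placement_policy='cuda',
                                              pin_memory=True))

optimizer = HybridAdam(model.parameters(), lr=1e-3)
optimizer = zero_optim_wrapper(model, optimizer, initial_scale=2**5)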
@ -11,12 +11,11 @@ from colossalai.device.device_mesh import DeviceMesh
from colossalai.initialize import launch
from colossalai.logging import disable_existing_loggers
from colossalai.nn.optimizer import HybridAdam
from colossalai.nn.parallel import zero_model_wrapper, zero_optim_wrapper
from colossalai.tensor.process_group import ProcessGroup
from colossalai.testing import assert_close, rerun_if_address_is_in_use
from colossalai.testing.pytest_wrapper import run_on_environment_flag
from colossalai.utils import free_port, get_current_device
from colossalai.utils.model.colo_init_context import ColoInitContext, post_process_colo_init_ctx
from colossalai.zero import ColoInitContext, post_process_colo_init_ctx, zero_model_wrapper, zero_optim_wrapper


class MLP(torch.nn.Module):
@ -10,14 +10,14 @@ import torch.distributed as dist
import torch.multiprocessing as mp

import colossalai
from colossalai.gemini.chunk import ChunkManager, search_chunk_configuration
from colossalai.gemini.gemini_mgr import GeminiManager
from colossalai.nn.parallel import ColoDDP, ZeroDDP
from colossalai.nn.parallel import ColoDDP
from colossalai.tensor import ProcessGroup
from colossalai.testing import rerun_if_address_is_in_use
from colossalai.utils import free_port
from colossalai.utils.cuda import get_current_device
from colossalai.utils.model.colo_init_context import ColoInitContext
from colossalai.zero import ColoInitContext, ZeroDDP
from colossalai.zero.gemini.chunk import ChunkManager, search_chunk_configuration
from colossalai.zero.gemini.gemini_mgr import GeminiManager


def set_seed(seed):
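Under the new layout, the Gemini building blocks imported above are usually composed along the lines of the sketch below. This is not code from the commit: the model, the search parameters, the 'cuda' placement policy, and the assumption that colossalai.launch has already been called are all illustrative.

import torch
import colossalai
from colossalai.utils.cuda import get_current_device
from colossalai.zero import ColoInitContext, ZeroDDP
from colossalai.zero.gemini.chunk import ChunkManager, search_chunk_configuration
from colossalai.zero.gemini.gemini_mgr import GeminiManager

# assumes colossalai.launch(...) has initialized the default process group
with ColoInitContext(device=get_current_device()):
    model = torch.nn.Sequential(torch.nn.Linear(512, 512), torch.nn.Linear(512, 512))

# derive a chunk configuration for the model's parameters (search values are illustrative)
config_dict, *_ = search_chunk_configuration(model, search_range_mb=1, search_interval_byte=100)
chunk_manager = ChunkManager(config_dict)
gemini_manager = GeminiManager('cuda', chunk_manager)    # placement policy is an assumption
model = ZeroDDP(model, gemini_manager, pin_memory=True)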
@ -1,18 +1,19 @@
import copy
from collections import OrderedDict
from functools import partial

import pytest
import colossalai
import torch
import torch.multiprocessing as mp
from colossalai.testing import rerun_if_address_is_in_use
from colossalai.utils.cuda import get_current_device
from colossalai.utils import free_port
from colossalai.utils.model.colo_init_context import ColoInitContext
from functools import partial
from tests.components_to_test.registry import non_distributed_component_funcs

import colossalai
from colossalai.nn.parallel import ColoDDP
from collections import OrderedDict
from colossalai.tensor import ProcessGroup, ColoParameter
from colossalai.tensor import ColoParameter, ProcessGroup
from colossalai.testing import rerun_if_address_is_in_use
from colossalai.utils import free_port
from colossalai.utils.cuda import get_current_device
from colossalai.zero import ColoInitContext
from tests.components_to_test.registry import non_distributed_component_funcs


def check_state_dict_equal(state_dict: OrderedDict, other_state_dict: OrderedDict):
@ -1,73 +1,73 @@
import pytest
import torch

from colossalai.gemini.stateful_tensor import TensorState, StatefulTensor
from colossalai.zero.legacy.gemini.stateful_tensor import StatefulTensor, TensorState


@pytest.mark.dist
def test_gemini_manager():
# reset the manager, in case that there exists memory information left
manager = StatefulTensor.GST_MGR
manager.reset()

# occupation 8
st1 = StatefulTensor(torch.empty(2, 2, dtype=torch.float16, device='cuda'))
# occupation 60
st2 = StatefulTensor(torch.empty(3, 5, dtype=torch.float32, device='cpu'))

# occupation 28
t1 = torch.empty(7, device='cuda')
# occupation 12
t2 = torch.empty(3, device='cpu')
st3 = StatefulTensor(t1, TensorState.HOLD_AFTER_FWD)
st4 = StatefulTensor(None, TensorState.FREE)

assert manager.total_number == 4
assert manager.total_mem['cpu'] == 60
assert manager.total_mem['cuda'] == 36
assert manager.state_mem['cpu'][TensorState.HOLD] == 60
assert manager.state_mem['cuda'][TensorState.HOLD] == 8
assert manager.state_mem['cuda'][TensorState.HOLD_AFTER_FWD] == 28

st4.payload_reset(t2)
st3.payload_reset(t2)

assert manager.total_number == 4
assert manager.total_mem['cpu'] == 84
assert manager.total_mem['cuda'] == 8
assert manager.state_mem['cpu'][TensorState.HOLD] == 72
assert manager.state_mem['cuda'][TensorState.HOLD] == 8
assert manager.state_mem['cpu'][TensorState.HOLD_AFTER_FWD] == 12
assert manager.state_mem['cuda'][TensorState.HOLD_AFTER_FWD] == 0

st1.move_to(torch.device('cpu'))
st2.move_to(torch.device('cpu'))
st3.move_to(torch.device('cuda', 0))

assert manager.total_number == 4
assert manager.total_mem['cpu'] == 80
assert manager.total_mem['cuda'] == 12
assert manager.state_mem['cpu'][TensorState.HOLD] == 80
assert manager.state_mem['cuda'][TensorState.HOLD] == 0
assert manager.state_mem['cpu'][TensorState.HOLD_AFTER_FWD] == 0
assert manager.state_mem['cuda'][TensorState.HOLD_AFTER_FWD] == 12

st1.trans_state(TensorState.COMPUTE)
st2.trans_state(TensorState.COMPUTE)
st2.trans_state(TensorState.HOLD_AFTER_BWD)

assert manager.total_number == 4
assert manager.total_mem['cpu'] == 80
assert manager.total_mem['cuda'] == 12
assert manager.state_mem['cpu'][TensorState.HOLD] == 12
assert manager.state_mem['cuda'][TensorState.HOLD] == 0
assert manager.state_mem['cpu'][TensorState.HOLD_AFTER_FWD] == 0
assert manager.state_mem['cuda'][TensorState.HOLD_AFTER_FWD] == 12
assert manager.state_mem['cpu'][TensorState.HOLD_AFTER_BWD] == 60
assert manager.state_mem['cuda'][TensorState.HOLD_AFTER_BWD] == 0
assert manager.state_mem['cpu'][TensorState.COMPUTE] == 8
assert manager.state_mem['cuda'][TensorState.COMPUTE] == 0


if __name__ == '__main__':
test_gemini_manager()
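The "occupation" comments in the test above follow directly from element counts and dtypes: fp16 is 2 bytes per element, fp32 is 4, and torch.empty defaults to fp32. A quick sanity check, written here for illustration and not part of the commit:

import torch

# 2x2 fp16 -> 4 * 2 = 8 bytes  (st1)
# 3x5 fp32 -> 15 * 4 = 60 bytes (st2)
# 7   fp32 -> 7 * 4 = 28 bytes  (t1)
# 3   fp32 -> 3 * 4 = 12 bytes  (t2)
for shape, dtype in [((2, 2), torch.float16), ((3, 5), torch.float32), ((7,), torch.float32), ((3,), torch.float32)]:
    t = torch.empty(*shape, dtype=dtype)
    print(shape, dtype, t.numel() * t.element_size())    # prints 8, 60, 28, 12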
Some files were not shown because too many files have changed in this diff.