[zero] reorganize zero/gemini folder structure (#3424)

* [zero] refactor low-level zero folder structure

* [zero] fix legacy zero import path

* [zero] fix legacy zero import path

* [zero] remove useless import

* [zero] refactor gemini folder structure

* [zero] refactor gemini folder structure

* [zero] refactor legacy zero import path

* [zero] refactor gemini folder structure

* [zero] refactor gemini folder structure

* [zero] refactor gemini folder structure

* [zero] refactor legacy zero import path

* [zero] fix test import path

* [zero] fix test

* [zero] fix circular import

* [zero] update import
pull/3436/head
ver217 2023-04-04 13:48:16 +08:00 committed by GitHub
parent b09adff724
commit 26b7aac0be
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
142 changed files with 1435 additions and 1404 deletions

View File

@ -14,17 +14,16 @@ from transformers.tokenization_utils_base import PreTrainedTokenizerBase
import colossalai
from colossalai.logging import get_dist_logger
from colossalai.nn.optimizer import CPUAdam, HybridAdam
from colossalai.nn.parallel import ZeroDDP, zero_model_wrapper, zero_optim_wrapper
from colossalai.nn.parallel.utils import get_static_torch_model
from colossalai.tensor import ProcessGroup, ShardSpec
from colossalai.utils import get_current_device
from colossalai.utils.model.colo_init_context import ColoInitContext
logger = get_dist_logger(__name__)
from colossalai.zero import ColoInitContext, ZeroDDP, zero_model_wrapper, zero_optim_wrapper
from colossalai.zero.gemini.utils import get_static_torch_model
from .base import Strategy
from .ddp import DDPStrategy
logger = get_dist_logger(__name__)
class ColossalAIStrategy(DDPStrategy):
"""

View File

@ -4,8 +4,8 @@ from typing import Optional, Set
import torch
import torch.nn as nn
from colossalai.gemini.tensor_utils import free_storage
from colossalai.nn.parallel.data_parallel import _cast_float
from colossalai.zero.legacy.gemini.tensor_utils import free_storage
from .region_manager import RegionManager
from .util import GlobalRuntimeInfo

View File

@ -1,7 +1,10 @@
from typing import List, Dict, Tuple
from typing import Dict, List, Tuple
import torch
from torch.fx import Node
from colossalai.gemini.tensor_utils import alloc_storage, free_storage
from colossalai.zero.legacy.gemini.tensor_utils import alloc_storage, free_storage
class Region:
"""
@ -52,15 +55,13 @@ class Region:
Map the parameters in the region to a contiguous memory space.
"""
self.fp16_data = torch.zeros(
self.param_num, dtype=torch.half, device='cuda')
self.fp16_data = torch.zeros(self.param_num, dtype=torch.half, device='cuda')
offset = 0
for param in self.fp16_params:
param.data = param.data.cuda()
p_num = param.data.numel()
self.fp16_data[offset:offset + p_num].copy_(param.data.flatten())
param.data = self.fp16_data[offset:offset +
p_num].view(param.data.shape)
param.data = self.fp16_data[offset:offset + p_num].view(param.data.shape)
self.param_to_range[param] = (offset, offset + p_num)
offset += p_num
@ -141,4 +142,4 @@ class Region:
def __update_params_ptr(self) -> None:
for param in self.fp16_params:
begin, end = self.param_to_range[param]
param.data = self.fp16_data[begin:end].view(param.data.shape)
param.data = self.fp16_data[begin:end].view(param.data.shape)

View File

@ -14,12 +14,12 @@ from torch.utils.data.distributed import DistributedSampler
from colossalai.checkpoint_io import CheckpointIO, GeneralCheckpointIO
from colossalai.cluster import DistCoordinator
from colossalai.gemini.memory_tracer import MemStats
from colossalai.interface import ModelWrapper, OptimizerWrapper
from colossalai.nn.parallel import GeminiDDP, zero_model_wrapper, zero_optim_wrapper
from colossalai.tensor.colo_parameter import ColoParameter
from colossalai.utils import get_current_device
from colossalai.utils.model.colo_init_context import _convert_to_coloparam
from colossalai.zero import GeminiDDP, zero_model_wrapper, zero_optim_wrapper
from colossalai.zero.gemini.colo_init_context import _convert_to_coloparam
from colossalai.zero.gemini.memory_tracer import MemStats
from .plugin_base import Plugin

View File

@ -10,8 +10,8 @@ from torch.nn.modules.loss import _Loss
from colossalai.engine.gradient_handler import BaseGradientHandler
from colossalai.engine.schedule import BaseSchedule, InterleavedPipelineSchedule, NonPipelineSchedule, PipelineSchedule
from colossalai.gemini.ophooks import BaseOpHook, register_ophooks_recursively
from colossalai.logging import get_dist_logger
from colossalai.zero.legacy.gemini import BaseOpHook, register_ophooks_recursively
class Engine:

View File

@ -157,7 +157,7 @@ class PipelineSchedule(BaseSchedule):
return self._move_to_device(mciro_batch_data)
def pre_processing(self, engine):
from colossalai.zero.sharded_model.sharded_model_v2 import ShardedModelV2
from colossalai.zero.legacy import ShardedModelV2
# TODO: remove this after testing new zero with pipeline parallelism
model = engine.model

View File

@ -1,9 +0,0 @@
from .chunk import ChunkManager, TensorInfo, TensorState, search_chunk_configuration
from .gemini_mgr import GeminiManager
from .stateful_tensor_mgr import StatefulTensorMgr
from .tensor_placement_policy import TensorPlacementPolicyFactory
__all__ = [
'StatefulTensorMgr', 'TensorPlacementPolicyFactory', 'GeminiManager', 'TensorInfo', 'TensorState', 'ChunkManager',
'search_chunk_configuration'
]

View File

@ -29,13 +29,12 @@ from colossalai.engine.schedule import (
PipelineSchedule,
get_tensor_shape,
)
from colossalai.gemini.ophooks import BaseOpHook
from colossalai.logging import get_dist_logger
from colossalai.nn.optimizer.colossalai_optimizer import ColossalaiOptimizer
from colossalai.utils import get_current_device, is_using_ddp, is_using_pp, is_using_sequence, sync_model_param
from colossalai.utils.moe import sync_moe_model_param
from colossalai.zero import convert_to_zero_v2
from colossalai.zero.sharded_optim.sharded_optim_v2 import ShardedOptimizerV2
from colossalai.zero.legacy import ShardedOptimizerV2, convert_to_zero_v2
from colossalai.zero.legacy.gemini.ophooks import BaseOpHook
def get_default_parser():

View File

@ -9,7 +9,7 @@ import torch.nn as nn
from colossalai.context import ParallelMode, seed
from colossalai.context.moe_context import MOE_CONTEXT
from colossalai.utils import get_current_device
from colossalai.zero.init_ctx import no_shard_zero_decrator
from colossalai.zero.legacy.init_ctx import no_shard_zero_decrator
class MoeExperts(nn.Module):

View File

@ -18,7 +18,7 @@ from colossalai.nn.layer.moe.experts import Experts, MoeExperts
from colossalai.nn.layer.moe.routers import MoeRouter, Top1Router, Top2Router
from colossalai.nn.layer.moe.utils import NormalNoiseGenerator, UniformNoiseGenerator
from colossalai.utils import get_current_device
from colossalai.zero.init_ctx import no_shard_zero_context, no_shard_zero_decrator
from colossalai.zero.legacy.init_ctx import no_shard_zero_context, no_shard_zero_decrator
@no_shard_zero_decrator(is_replicated=True)

View File

@ -1,15 +0,0 @@
from typing import Any
import torch
from colossalai.nn.optimizer import HybridAdam
from colossalai.nn.optimizer.zero_optimizer import ZeroOptimizer
__all__ = ['GeminiAdamOptimizer']
class GeminiAdamOptimizer(ZeroOptimizer):
def __init__(self, model: torch.nn.Module, **defaults: Any) -> None:
optimizer = HybridAdam(model.parameters(), **defaults)
super().__init__(optimizer, model, **defaults)

View File

@ -1,5 +1,5 @@
from .data_parallel import ColoDDP, ZeroDDP
from .gemini_parallel import GeminiDDP
from .zero_wrapper import zero_model_wrapper, zero_optim_wrapper
from .data_parallel import ColoDDP
__all__ = ['ColoDDP', 'ZeroDDP', 'GeminiDDP', 'zero_model_wrapper', 'zero_optim_wrapper']
__all__ = [
'ColoDDP',
]

View File

@ -1,31 +1,14 @@
import itertools
from collections import OrderedDict
from functools import partial
from typing import Dict, Iterable, List, Optional, Set
from typing import Iterable, Optional, Set
import torch
import torch.distributed as dist
import torch.nn as nn
from colossalai.gemini.chunk import Chunk, ChunkManager, TensorState
from colossalai.gemini.gemini_mgr import GeminiManager
from colossalai.gemini.memory_tracer import OrderedParamGenerator
from colossalai.logging import get_dist_logger
from colossalai.nn.parallel.utils import get_temp_total_chunk_on_cuda
from colossalai.tensor import ProcessGroup as ColoProcessGroup
from colossalai.tensor import ReplicaSpec
from colossalai.tensor.colo_parameter import ColoParameter, ColoTensor, ColoTensorSpec
from colossalai.tensor.param_op_hook import ColoParamOpHookManager
from colossalai.utils import get_current_device, is_ddp_ignored
from colossalai.zero.utils.gemini_hook import GeminiZeROHook
from colossalai.utils import is_ddp_ignored
from .reducer import Reducer
from .utils import get_static_torch_model
try:
from torch.nn.modules.module import _EXTRA_STATE_KEY_SUFFIX, _IncompatibleKeys
except ImportError:
_EXTRA_STATE_KEY_SUFFIX = '_extra_state'
def free_storage(data: torch.Tensor) -> None:
@ -189,507 +172,3 @@ class ColoDDP(torch.nn.Module):
def load_state_dict(self, state_dict: 'OrderedDict[str, torch.Tensor]', strict: bool = True):
return self.module.load_state_dict(state_dict, strict)
class ZeroDDP(ColoDDP):
"""ZeRO DDP for ColoTensor.
Warning: Nested ZeroDDP is not supported now.
It is designed to be used with ChunkManager and GeminiManager.
For more details, see the API reference of ``ChunkManager`` and ``GeminiManager``.
Args:
module (torch.nn.Module): Module to apply ZeRO-DP.
gemini_manager (GeminiManager): Manages the chunk manager and heterogeneous momery space.
For more details, see the API reference of ``GeminiManager``.
pin_memory (bool): Chunks on CPU Memory use pin-memory.
force_outputs_fp32 (bool): If set to True, outputs will be fp32. Otherwise, outputs will be fp16.
Defaults to False.
strict_ddp_mode (bool): If set to True, there is no tensor sharding, each tensor is replicated.
Defaults to False. Users can set it to True, when they clearly know that they only need DDP.
"""
def __init__(self,
module: torch.nn.Module,
gemini_manager: GeminiManager,
pin_memory: bool = False,
force_outputs_fp32: bool = False,
strict_ddp_mode: bool = False) -> None:
super().__init__(module, process_group=ColoProcessGroup())
self.gemini_manager = gemini_manager
self.chunk_manager: ChunkManager = gemini_manager.chunk_manager
self.force_outputs_fp32 = force_outputs_fp32
self.param_op_hook = GeminiZeROHook(gemini_manager)
self.fp32_params: List[ColoTensor] = list()
self.fp16_params: List[ColoParameter] = list()
self.overflow_counter = 0
self.grads_device: Dict[torch.Tensor, torch.device] = dict()
self.param2name: Dict[nn.Parameter, str] = dict()
self.name2param: Dict[str, nn.Parameter] = dict()
self._cast_buffers()
self._logger = get_dist_logger()
if self.gemini_manager._premade_memstats_:
# build chunk in param runtime visited order.
param_order = self.gemini_manager.memstats()._param_runtime_order
else:
# build chunk in param initialized order.
# Note: in this way, it can not get filter unused params during runtime.
param_order = OrderedParamGenerator()
for p in module.parameters():
param_order.append(p)
self._init_chunks(param_order=param_order,
strict_ddp_mode=strict_ddp_mode,
cpu_offload=self.gemini_manager.policy_name != 'cuda',
pin_memory=pin_memory)
for name, param in module.named_parameters():
self.param2name[param] = name
for m_name, m_var in module.named_modules():
for p_name, p_var in m_var.named_parameters(recurse=False):
param_name = m_name + '.' + p_name if m_name else p_name
self.name2param[param_name] = p_var
def _post_forward(self):
"""This function is only triggered for inference.
"""
access_list = list(self.chunk_manager.accessed_chunks)
# we need to scatter all accessed chunks and move them to their original places
for chunk in access_list:
if chunk.keep_gathered:
self.chunk_manager.fake_release_chunk(chunk)
else:
assert chunk.can_release
self.chunk_manager.release_chunk(chunk)
first_param = next(iter(chunk.tensors_info))
self.chunk_manager.move_chunk(chunk, self.grads_device[first_param])
assert self.chunk_manager.accessed_mem == 0
# reset all recorded attributes
self.gemini_manager.reset_attributes()
def forward(self, *args, **kwargs):
# check whether we are in a inference mode
grad_flag = torch.is_grad_enabled()
if not grad_flag:
assert not self.gemini_manager.need_warmup or not self.gemini_manager.is_warmup(
), "You should run a completed iteration as your warmup iter"
args, kwargs = _cast_float(args, torch.half), _cast_float(kwargs, torch.half)
self.module.zero_grad(set_to_none=True)
self.gemini_manager.pre_iter(*args)
with ColoParamOpHookManager.use_hooks(self.param_op_hook):
outputs = self.module(*args, **kwargs)
# scatter chunks in the inference mode
if not grad_flag:
self._post_forward()
if self.force_outputs_fp32:
return _cast_float(outputs, torch.float)
return outputs
def _setup_grads_ptr(self):
for p in self.module.parameters():
if is_ddp_ignored(p):
continue
p.grad = None
def _pre_backward(self):
# set a visit label for all parameters
# the label is used to check whether the parameter is correctly reduced
for param in self.param2name:
if not is_ddp_ignored(param):
setattr(param, "_gemini_reduced", False)
def _post_backward(self):
if self.chunk_manager.accessed_mem != 0:
error_params = ["Reduction failed at followed parameters:"]
for param in self.param2name:
if not is_ddp_ignored(param) and not getattr(param, "_gemini_reduced"):
error_params.append(self.param2name[param])
error_str = "\n\t".join(error_params)
raise RuntimeError("ZERO DDP error: the synchronization of gradients doesn't exit properly.",
"The most possible reason is that the model is not compatible with ZeroDDP.\n",
f"{error_str}")
self._setup_grads_ptr()
self._logger.debug(
f'comp cuda demand time: {self.gemini_manager._comp_cuda_demand_time}, layout time: {self.gemini_manager._layout_time}, evict time: {self.gemini_manager._evict_time}, CPU->CUDA vol: {self.gemini_manager._h2d_volume}B, CUDA->CPU vol: {self.gemini_manager._d2h_volume}'
)
self.gemini_manager.post_iter()
def backward(self, loss: torch.Tensor):
self._pre_backward()
with self.param_op_hook.switch_to_backward(), ColoParamOpHookManager.use_hooks(self.param_op_hook):
loss.backward()
self._post_backward()
def backward_by_grad(self, tensor, grad):
with self.param_op_hook.switch_to_backward(), ColoParamOpHookManager.use_hooks(self.param_op_hook):
torch.autograd.backward(tensor, grad)
self._post_backward()
def grad_handle(self, p, grad):
empty_grad = torch.empty_like(grad)
free_storage(empty_grad)
with torch._C.DisableTorchFunction():
chunk = self.chunk_manager.get_chunk(p)
if chunk.tensors_info[p].state != TensorState.HOLD_AFTER_BWD:
raise RuntimeError(f"Parameter `{self.param2name[p]}` failed at the gradient reduction. "
"Some unsupported torch function is operated upon this parameter.")
self.chunk_manager.trans_tensor_state(p, TensorState.READY_FOR_REDUCE)
chunk.copy_tensor_to_chunk_slice(p, grad)
reduced = self.chunk_manager.reduce_chunk(chunk)
if reduced:
if chunk.is_gathered:
chunk.cuda_global_chunk.div_(chunk.pg_size)
else:
chunk.cuda_shard.div_(chunk.pg_size)
# check overflow elements
self.overflow_counter += chunk.has_inf_or_nan
# record l2 norm for gradient clipping
if chunk.l2_norm_flag:
chunk.set_l2_norm()
self.chunk_manager.move_chunk(chunk, self.grads_device[p], force_copy=True)
return empty_grad
def zero_grad(self, set_to_none: bool = False) -> None:
self.module.zero_grad(set_to_none=True)
def set_chunk_grad_device(self, chunk: Chunk, device: torch.device) -> None:
for tensor in chunk.get_tensors():
self.grads_device[tensor] = device
def state_dict(self, destination=None, prefix='', keep_vars=False, only_rank_0: bool = True):
"""Returns a dictionary containing a whole state of the module.
Both parameters and persistent buffers (e.g. running averages) are included.
Keys are corresponding parameter and buffer names.
Parameters and buffers set to ``None`` are not included.
Warning: The non strict state dict would ignore the parameters if the tensors of the parameters
are shared with other parameters which have been included in the dictionary.
When you need to load the state dict, you should set the argument `strict` to False.
Returns:
dict:
a dictionary containing a whole state of the module
"""
if destination is None:
destination = OrderedDict()
destination._metadata = OrderedDict()
destination._metadata[prefix[:-1]] = local_metadata = dict(version=self._version)
self._save_to_state_dict(destination, prefix, keep_vars, only_rank_0)
for hook in self._state_dict_hooks.values():
hook_result = hook(self, destination, prefix, local_metadata)
if hook_result is not None:
destination = hook_result
return destination
def _get_param_to_save_data(self, param_list: List[torch.nn.Parameter], only_rank_0: bool) -> Dict:
"""
get param content from chunks.
Args:
param_list (_type_): a list of torch.nn.Parameters
only_rank_0 (_type_): _description_
Returns:
Dict: a dict whose key is param name and value is param with correct payload
"""
# save parameters
param_to_save_data = dict()
chunk_list = self.chunk_manager.get_chunks(param_list)
for chunk in chunk_list:
temp_chunk = get_temp_total_chunk_on_cuda(chunk)
for tensor, tensor_info in chunk.tensors_info.items():
record_tensor = torch.empty([0])
record_flag = (not only_rank_0) | (dist.get_rank(chunk.torch_pg) == 0)
if record_flag:
record_tensor = temp_chunk[tensor_info.offset:tensor_info.end].view(tensor.shape).cpu()
assert tensor not in param_to_save_data
param_to_save_data[tensor] = record_tensor
del temp_chunk
return param_to_save_data
def _save_to_state_dict(self, destination, prefix, keep_vars, only_rank_0=True):
r"""Saves module state to `destination` dictionary, containing a state
of the module, but not its descendants. This is called on every
submodule in :meth:`~torch.nn.Module.state_dict`.
In rare cases, subclasses can achieve class-specific behavior by
overriding this method with custom logic.
Args:
destination (dict): a dict where state will be stored
prefix (str): the prefix for parameters and buffers used in this
module
"""
assert keep_vars is False, "`state_dict` with parameter, `keep_vars=True`, is not supported now."
# get copies of fp32 parameters in CPU
param_to_save_data = self._get_param_to_save_data(self.fp32_params, only_rank_0)
# get the mapping between copies and fp16 parameters
p_mapping = dict()
for p, fp32_p in zip(self.fp16_params, self.fp32_params):
name = self.param2name[p]
assert fp32_p in param_to_save_data, "Parameter '{}' is neglected in the chunk list".format(name)
record_parameter = param_to_save_data[fp32_p]
p_mapping[p] = record_parameter
for name, param in self.name2param.items():
if param is not None:
if is_ddp_ignored(param):
# deal with ddp ignored parameters
destination[prefix + name] = param if keep_vars else param.detach()
else:
destination[prefix + name] = p_mapping[param]
del p_mapping
del param_to_save_data
# save all buffers
for name, buf in self.named_buffers():
if buf is not None and name not in self._non_persistent_buffers_set:
destination[prefix + name] = buf if keep_vars else buf.detach()
# save extra states
extra_state_key = prefix + _EXTRA_STATE_KEY_SUFFIX
if getattr(self.__class__, "get_extra_state",
torch.nn.Module.get_extra_state) is not torch.nn.Module.get_extra_state:
destination[extra_state_key] = self.get_extra_state()
def load_state_dict(self, state_dict: 'OrderedDict[str, torch.Tensor]', strict: bool = True):
r"""Copies parameters and buffers from :attr:`state_dict` into
this module and its descendants. If :attr:`strict` is ``True``, then
the keys of :attr:`state_dict` must exactly match the keys returned
by this module's :meth:`~torch.nn.Module.state_dict` function.
Args:
state_dict (dict): a dict containing parameters and
persistent buffers.
strict (bool, optional): whether to strictly enforce that the keys
in :attr:`state_dict` match the keys returned by this module's
:meth:`~torch.nn.Module.state_dict` function. Default: ``True``
Returns:
``NamedTuple`` with ``missing_keys`` and ``unexpected_keys`` fields:
* **missing_keys** is a list of str containing the missing keys
* **unexpected_keys** is a list of str containing the unexpected keys
Note:
If a parameter or buffer is registered as ``None`` and its corresponding key
exists in :attr:`state_dict`, :meth:`load_state_dict` will raise a
``RuntimeError``.
"""
missing_keys: List[str] = []
unexpected_keys: List[str] = []
error_msgs: List[str] = []
# copy state_dict so _load_from_state_dict can modify it
metadata = getattr(state_dict, '_metadata', None)
state_dict = state_dict.copy()
if metadata is not None:
# mypy isn't aware that "_metadata" exists in state_dict
state_dict._metadata = metadata # type: ignore[attr-defined]
prefix = ''
local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
self._load_from_state_dict(state_dict, prefix, local_metadata, True, missing_keys, unexpected_keys, error_msgs)
if strict:
if len(unexpected_keys) > 0:
error_msgs.insert(
0, 'Unexpected key(s) in state_dict: {}. '.format(', '.join(
'"{}"'.format(k) for k in unexpected_keys)))
if len(missing_keys) > 0:
error_msgs.insert(
0, 'Missing key(s) in state_dict: {}. '.format(', '.join('"{}"'.format(k) for k in missing_keys)))
if len(error_msgs) > 0:
raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
self.__class__.__name__, "\n\t".join(error_msgs)))
return _IncompatibleKeys(missing_keys, unexpected_keys)
def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys,
error_msgs):
r"""Copies parameters and buffers from :attr:`state_dict` into only
this module, but not its descendants. This is called on every submodule
in :meth:`~torch.nn.Module.load_state_dict`. Metadata saved for this
module in input :attr:`state_dict` is provided as :attr:`local_metadata`.
For state dicts without metadata, :attr:`local_metadata` is empty.
Subclasses can achieve class-specific backward compatible loading using
the version number at `local_metadata.get("version", None)`.
.. note::
:attr:`state_dict` is not the same object as the input
:attr:`state_dict` to :meth:`~torch.nn.Module.load_state_dict`. So
it can be modified.
Args:
state_dict (dict): a dict containing parameters and
persistent buffers.
prefix (str): the prefix for parameters and buffers used in this
module
local_metadata (dict): a dict containing the metadata for this module.
See
strict (bool): whether to strictly enforce that the keys in
:attr:`state_dict` with :attr:`prefix` match the names of
parameters and buffers in this module
missing_keys (list of str): if ``strict=True``, add missing keys to
this list
unexpected_keys (list of str): if ``strict=True``, add unexpected
keys to this list
error_msgs (list of str): error messages should be added to this
list, and will be reported together in
:meth:`~torch.nn.Module.load_state_dict`
"""
for hook in self._load_state_dict_pre_hooks.values():
hook(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)
persistent_buffers = {k: v for k, v in self.named_buffers() if k not in self._non_persistent_buffers_set}
local_name_params = itertools.chain(self.named_parameters(), persistent_buffers.items())
local_state = {k: v for k, v in local_name_params if v is not None}
def load(param_name, dest_tensor, copy_func):
state_key = prefix + param_name
if state_key in state_dict:
input_param = state_dict[state_key]
# Backward compatibility: loading 1-dim tensor from 0.3.* to version 0.4+
if len(dest_tensor.shape) == 0 and len(input_param.shape) == 1:
input_param = input_param[0]
if input_param.shape != dest_tensor.shape:
# local shape should match the one in checkpoint
error_msgs.append('size mismatch for {}: copying a param with shape {} from checkpoint, '
'the shape in current model is {}.'.format(state_key, input_param.shape,
dest_tensor.shape))
return
try:
with torch.no_grad():
copy_func(input_param)
except Exception as ex:
error_msgs.append('While copying the parameter named "{}", '
'whose dimensions in the model are {} and '
'whose dimensions in the checkpoint are {}, '
'an exception occurred : {}.'.format(state_key, dest_tensor.size(),
input_param.size(), ex.args))
elif strict:
missing_keys.append(state_key)
def load_fp32_parameter(chunk_slice, data):
chunk_slice.copy_(data.flatten())
for name, param in self.named_parameters():
if is_ddp_ignored(param):
# deal with ddp ignored parameters
load(name, param, param.copy_)
fp32_to_name = dict()
for p, fp32_p in zip(self.fp16_params, self.fp32_params):
if p is not None:
name = self.param2name[p]
fp32_to_name[fp32_p] = name
chunk_list = self.chunk_manager.get_chunks(self.fp32_params)
for chunk in chunk_list:
temp_chunk = get_temp_total_chunk_on_cuda(chunk)
for tensor, tensor_info in chunk.tensors_info.items():
parameter_name = fp32_to_name[tensor]
parameter_slice = temp_chunk[tensor_info.offset:tensor_info.end]
load(parameter_name, tensor, partial(load_fp32_parameter, parameter_slice))
if chunk.is_gathered:
chunk.cuda_global_chunk.copy_(temp_chunk)
elif chunk.cuda_shard is not None:
chunk.cuda_shard.copy_(temp_chunk[chunk.shard_begin:chunk.shard_end])
else:
chunk.cpu_shard.copy_(temp_chunk[chunk.shard_begin:chunk.shard_end])
del temp_chunk
for chunk_32 in chunk_list:
chunk_16 = chunk_32.paired_chunk
assert chunk_16 is not None
chunk_16.optim_update()
for name, buf in persistent_buffers.items():
if buf is not None:
load(name, buf, buf.copy_)
extra_state_key = prefix + _EXTRA_STATE_KEY_SUFFIX
if getattr(self.__class__, "set_extra_state",
torch.nn.Module.set_extra_state) is not torch.nn.Module.set_extra_state:
if extra_state_key in state_dict:
self.set_extra_state(state_dict[extra_state_key])
elif strict:
missing_keys.append(extra_state_key)
elif strict and (extra_state_key in state_dict):
unexpected_keys.append(extra_state_key)
if strict:
for key in state_dict.keys():
if key.startswith(prefix) and key != extra_state_key:
input_name = key[len(prefix):]
if input_name not in local_state:
unexpected_keys.append(key)
def _init_chunks(self, param_order, strict_ddp_mode: bool, cpu_offload: bool, pin_memory: bool):
ddp_pg = ColoProcessGroup()
for p in param_order.generate():
assert isinstance(p, ColoParameter)
# gather sharded parameters in the strict ddp mode
if strict_ddp_mode:
if not p.is_replicate():
p.set_dist_spec(ReplicaSpec())
p.set_process_group(pg=ddp_pg)
# ignore the parameters with no gradient
if not p.requires_grad:
self.set_params_to_ignore([p])
# move ignored parameters to CUDA
if is_ddp_ignored(p):
p.data = p.data.to(device=get_current_device(), dtype=torch.float16)
continue
# create a fp32 parameter
fp32_data = p.data.float()
fp32_p = ColoTensor(fp32_data, spec=ColoTensorSpec(p.process_group))
# create a fp16 parameter
p.data = p.data.half()
# register the fp16 parameter and fp32 parameter in the chunk manager
dp_world_size = p.process_group.dp_world_size()
self.chunk_manager.register_tensor(tensor=p,
group_type='fp16_param',
config_key=dp_world_size,
cpu_offload=cpu_offload,
pin_memory=pin_memory)
self.chunk_manager.register_tensor(tensor=fp32_p,
group_type='fp32_param',
config_key=dp_world_size,
cpu_offload=cpu_offload,
pin_memory=pin_memory)
self.fp16_params.append(p)
self.fp32_params.append(fp32_p)
self.grads_device[p] = self.gemini_manager.default_device
self.chunk_manager.close_all_groups()
for p, fp32_p in zip(self.fp16_params, self.fp32_params):
chunk_16 = self.chunk_manager.get_chunk(p)
chunk_32 = self.chunk_manager.get_chunk(fp32_p)
chunk_32.init_pair(chunk_16)
# keep gathered chunks are in CUDA
if chunk_16.keep_gathered:
self.grads_device[p] = get_current_device()
def _cast_buffers(self):
for buffer in self.module.buffers():
buffer.data = buffer.cuda()
if torch.is_floating_point(buffer):
buffer.data = buffer.half()

View File

@ -1,63 +0,0 @@
from typing import Optional
import torch
from colossalai.gemini.chunk import init_chunk_manager
from colossalai.gemini.gemini_mgr import GeminiManager
from colossalai.gemini.memory_tracer import MemStats
from .data_parallel import ZeroDDP
class GeminiDDP(ZeroDDP):
def __init__(self,
module: torch.nn.Module,
device: torch.device,
placement_policy: str = "cpu",
pin_memory: bool = False,
force_outputs_fp32: bool = False,
strict_ddp_mode: bool = False,
search_range_mb: int = 32,
hidden_dim: Optional[int] = None,
min_chunk_size_mb: float = 32,
memstats: Optional[MemStats] = None) -> None:
"""
A torch.Module warpper using ZeRO-DP and Genimi.
ZeRO is for parallel. Gemini is for memory management.
WARNING: The class will modify the module inline!
Example:
model is initialized under the context of ColoInitContext
>>> model = GeminiDDP(model, torch.cuda.current_device(), "cuda")
>>> logits = model(x)
>>> loss = criterion(logits, labels)
>>> model.backward(loss)
Args:
module (torch.nn.Module): the model to be wrapped.
device (torch.device): device to place the model.
placement_policy (str, optional): "cpu", "cuda", "auto". Defaults to "cpu".
pin_memory (bool, optional): use pin memory on CPU. Defaults to False.
force_outputs_fp32 (bool, optional): force outputs are fp32. Defaults to False.
search_range_mb (int, optional): chunk size searching range in MegaByte. Defaults to 32.
hidden_dim (int, optional): the hidden dimension of DNN.
Users can provide this argument to speed up searching.
If users do not know this argument before training, it is ok. We will use a default value 1024.
min_chunk_size_mb (float, optional): the minimum chunk size in MegaByte.
If the aggregate size of parameters is still samller than the minimum chunk size,
all parameters will be compacted into one small chunk.
memstats (MemStats, optional) the memory statistics collector by a runtime memory tracer.
"""
# some ugly hotfix for the compatibility with Lightning
if search_range_mb is None:
search_range_mb = 32
chunk_manager = init_chunk_manager(model=module,
init_device=device,
hidden_dim=hidden_dim,
search_range_mb=search_range_mb,
min_chunk_size_mb=min_chunk_size_mb,
strict_ddp_flag=strict_ddp_mode)
gemini_manager = GeminiManager(placement_policy, chunk_manager, memstats)
super().__init__(module, gemini_manager, pin_memory, force_outputs_fp32, strict_ddp_mode)

View File

@ -1,41 +1,16 @@
from typing import Tuple
from .gemini import (
ColoInitContext,
GeminiAdamOptimizer,
GeminiDDP,
ZeroDDP,
ZeroOptimizer,
get_static_torch_model,
post_process_colo_init_ctx,
)
from .low_level import LowLevelZeroOptimizer
from .wrapper import zero_model_wrapper, zero_optim_wrapper
import torch
import torch.nn as nn
from colossalai.logging import get_dist_logger
from colossalai.zero.sharded_model.sharded_model_v2 import ShardedModelV2
from colossalai.zero.sharded_optim import LowLevelZeroOptimizer, ShardedOptimizerV2
from ..nn.optimizer.zero_optimizer import ZeroOptimizer
def convert_to_zero_v2(model: nn.Module, optimizer: torch.optim.Optimizer, model_config,
optimizer_config) -> Tuple[ShardedModelV2, ShardedOptimizerV2]:
"""
A helper function to integrate the model and optimizer with ZeRO optimizer and off-loading
:param model: Your model object
:type model: :class:`torch.nn.Module`
:param optimizer_config: Your optimizer object
:type optimizer_config: :class:`dict`
:return: (model, optimizer)
:rtype: Tuple
"""
logger = get_dist_logger('convert_to_zero_v2')
logger.info(f'optimizer_config is {optimizer_config}', ranks=[0])
if optimizer_config is None:
optimizer_config = dict()
logger.info(f'model_config is {model_config}', ranks=[0])
if model_config is None:
model_config = dict()
zero_model = ShardedModelV2(model, **model_config)
zero_optimizer = ShardedOptimizerV2(zero_model, optimizer, **optimizer_config)
return zero_model, zero_optimizer
__all__ = ['convert_to_zero_v2', 'LowLevelZeroOptimizer', 'ShardedModelV2', 'ShardedOptimizerV2', 'ZeroOptimizer']
__all__ = [
'ZeroDDP', 'GeminiDDP', 'ZeroOptimizer', 'GeminiAdamOptimizer', 'zero_model_wrapper', 'zero_optim_wrapper',
'LowLevelZeroOptimizer', 'ColoInitContext', 'post_process_colo_init_ctx', 'get_static_torch_model'
]

View File

@ -0,0 +1,11 @@
from .chunk import ChunkManager, TensorInfo, TensorState, search_chunk_configuration
from .colo_init_context import ColoInitContext, post_process_colo_init_ctx
from .gemini_ddp import GeminiDDP, ZeroDDP
from .gemini_mgr import GeminiManager
from .gemini_optimizer import GeminiAdamOptimizer, ZeroOptimizer
from .utils import get_static_torch_model
__all__ = [
'GeminiManager', 'TensorInfo', 'TensorState', 'ChunkManager', 'search_chunk_configuration', 'ZeroDDP', 'GeminiDDP',
'get_static_torch_model', 'GeminiAdamOptimizer', 'ZeroOptimizer', 'ColoInitContext', 'post_process_colo_init_ctx'
]

View File

@ -3,10 +3,11 @@ from typing import Deque, Dict, Iterable, List, Optional, Set, Tuple
import torch
from colossalai.gemini.chunk import Chunk, ChunkFullError, TensorState
from colossalai.tensor import ColoTensor
from colossalai.utils import get_current_device
from .chunk import Chunk, ChunkFullError, TensorState
class ChunkManager:
"""

View File

@ -5,9 +5,9 @@ import numpy as np
import torch.distributed as dist
import torch.nn as nn
from colossalai.gemini.memory_tracer import MemStats, OrderedParamGenerator
from colossalai.tensor import ColoParameter
from colossalai.utils import is_ddp_ignored
from colossalai.zero.gemini.memory_tracer import MemStats, OrderedParamGenerator
def _filter_exlarge_params(model: nn.Module, size_dict: Dict[int, List[int]]) -> None:

View File

@ -5,10 +5,11 @@ import torch
import torch.distributed as dist
import torch.nn as nn
from colossalai.gemini.chunk import ChunkManager
from colossalai.gemini.chunk.search_utils import search_chunk_configuration
from colossalai.utils import is_ddp_ignored
from .manager import ChunkManager
from .search_utils import search_chunk_configuration
def safe_div(a, b):
if a == 0:

View File

@ -3,10 +3,8 @@ from typing import Any, Dict, Iterator, Optional, Tuple, Union
import torch
from torch import nn
from colossalai.nn.parallel.layers import ColoEmbedding, ColoLinear, register_colo_module
from colossalai.tensor import ColoParameter, ColoTensor, ProcessGroup
from .utils import InsertPostInitMethodToModuleSubClasses
from colossalai.utils.model.utils import InsertPostInitMethodToModuleSubClasses
# find named_params includes replica
@ -89,6 +87,7 @@ class ColoInitContext(InsertPostInitMethodToModuleSubClasses):
self._default_dist_spec = default_dist_spec
def _register_colo_modules(self):
from colossalai.nn.parallel.layers import ColoEmbedding, ColoLinear, register_colo_module
register_colo_module(torch.nn.Linear, ColoLinear())
register_colo_module(torch.nn.Embedding, ColoEmbedding())

View File

@ -0,0 +1,590 @@
import itertools
from collections import OrderedDict
from functools import partial
from typing import Dict, List, Optional
import torch
import torch.distributed as dist
import torch.nn as nn
from colossalai.logging import get_dist_logger
from colossalai.nn.parallel.data_parallel import ColoDDP, _cast_float, free_storage
from colossalai.tensor import ProcessGroup as ColoProcessGroup
from colossalai.tensor import ReplicaSpec
from colossalai.tensor.colo_parameter import ColoParameter, ColoTensor, ColoTensorSpec
from colossalai.tensor.param_op_hook import ColoParamOpHookManager
from colossalai.utils import get_current_device, is_ddp_ignored
from .chunk import Chunk, ChunkManager, TensorState, init_chunk_manager
from .gemini_hook import GeminiZeROHook
from .gemini_mgr import GeminiManager
from .memory_tracer import MemStats, OrderedParamGenerator
from .utils import get_temp_total_chunk_on_cuda
try:
from torch.nn.modules.module import _EXTRA_STATE_KEY_SUFFIX, _IncompatibleKeys
except ImportError:
_EXTRA_STATE_KEY_SUFFIX = '_extra_state'
__all__ = [
'ZeroDDP',
'GeminiDDP',
]
class ZeroDDP(ColoDDP):
"""ZeRO DDP for ColoTensor.
Warning: Nested ZeroDDP is not supported now.
It is designed to be used with ChunkManager and GeminiManager.
For more details, see the API reference of ``ChunkManager`` and ``GeminiManager``.
Args:
module (torch.nn.Module): Module to apply ZeRO-DP.
gemini_manager (GeminiManager): Manages the chunk manager and heterogeneous momery space.
For more details, see the API reference of ``GeminiManager``.
pin_memory (bool): Chunks on CPU Memory use pin-memory.
force_outputs_fp32 (bool): If set to True, outputs will be fp32. Otherwise, outputs will be fp16.
Defaults to False.
strict_ddp_mode (bool): If set to True, there is no tensor sharding, each tensor is replicated.
Defaults to False. Users can set it to True, when they clearly know that they only need DDP.
"""
def __init__(self,
module: torch.nn.Module,
gemini_manager: GeminiManager,
pin_memory: bool = False,
force_outputs_fp32: bool = False,
strict_ddp_mode: bool = False) -> None:
super().__init__(module, process_group=ColoProcessGroup())
self.gemini_manager = gemini_manager
self.chunk_manager: ChunkManager = gemini_manager.chunk_manager
self.force_outputs_fp32 = force_outputs_fp32
self.param_op_hook = GeminiZeROHook(gemini_manager)
self.fp32_params: List[ColoTensor] = list()
self.fp16_params: List[ColoParameter] = list()
self.overflow_counter = 0
self.grads_device: Dict[torch.Tensor, torch.device] = dict()
self.param2name: Dict[nn.Parameter, str] = dict()
self.name2param: Dict[str, nn.Parameter] = dict()
self._cast_buffers()
self._logger = get_dist_logger()
if self.gemini_manager._premade_memstats_:
# build chunk in param runtime visited order.
param_order = self.gemini_manager.memstats()._param_runtime_order
else:
# build chunk in param initialized order.
# Note: in this way, it can not get filter unused params during runtime.
param_order = OrderedParamGenerator()
for p in module.parameters():
param_order.append(p)
self._init_chunks(param_order=param_order,
strict_ddp_mode=strict_ddp_mode,
cpu_offload=self.gemini_manager.policy_name != 'cuda',
pin_memory=pin_memory)
for name, param in module.named_parameters():
self.param2name[param] = name
for m_name, m_var in module.named_modules():
for p_name, p_var in m_var.named_parameters(recurse=False):
param_name = m_name + '.' + p_name if m_name else p_name
self.name2param[param_name] = p_var
def _post_forward(self):
"""This function is only triggered for inference.
"""
access_list = list(self.chunk_manager.accessed_chunks)
# we need to scatter all accessed chunks and move them to their original places
for chunk in access_list:
if chunk.keep_gathered:
self.chunk_manager.fake_release_chunk(chunk)
else:
assert chunk.can_release
self.chunk_manager.release_chunk(chunk)
first_param = next(iter(chunk.tensors_info))
self.chunk_manager.move_chunk(chunk, self.grads_device[first_param])
assert self.chunk_manager.accessed_mem == 0
# reset all recorded attributes
self.gemini_manager.reset_attributes()
def forward(self, *args, **kwargs):
# check whether we are in a inference mode
grad_flag = torch.is_grad_enabled()
if not grad_flag:
assert not self.gemini_manager.need_warmup or not self.gemini_manager.is_warmup(
), "You should run a completed iteration as your warmup iter"
args, kwargs = _cast_float(args, torch.half), _cast_float(kwargs, torch.half)
self.module.zero_grad(set_to_none=True)
self.gemini_manager.pre_iter(*args)
with ColoParamOpHookManager.use_hooks(self.param_op_hook):
outputs = self.module(*args, **kwargs)
# scatter chunks in the inference mode
if not grad_flag:
self._post_forward()
if self.force_outputs_fp32:
return _cast_float(outputs, torch.float)
return outputs
def _setup_grads_ptr(self):
for p in self.module.parameters():
if is_ddp_ignored(p):
continue
p.grad = None
def _pre_backward(self):
# set a visit label for all parameters
# the label is used to check whether the parameter is correctly reduced
for param in self.param2name:
if not is_ddp_ignored(param):
setattr(param, "_gemini_reduced", False)
def _post_backward(self):
if self.chunk_manager.accessed_mem != 0:
error_params = ["Reduction failed at followed parameters:"]
for param in self.param2name:
if not is_ddp_ignored(param) and not getattr(param, "_gemini_reduced"):
error_params.append(self.param2name[param])
error_str = "\n\t".join(error_params)
raise RuntimeError("ZERO DDP error: the synchronization of gradients doesn't exit properly.",
"The most possible reason is that the model is not compatible with ZeroDDP.\n",
f"{error_str}")
self._setup_grads_ptr()
self._logger.debug(
f'comp cuda demand time: {self.gemini_manager._comp_cuda_demand_time}, layout time: {self.gemini_manager._layout_time}, evict time: {self.gemini_manager._evict_time}, CPU->CUDA vol: {self.gemini_manager._h2d_volume}B, CUDA->CPU vol: {self.gemini_manager._d2h_volume}'
)
self.gemini_manager.post_iter()
def backward(self, loss: torch.Tensor):
self._pre_backward()
with self.param_op_hook.switch_to_backward(), ColoParamOpHookManager.use_hooks(self.param_op_hook):
loss.backward()
self._post_backward()
def backward_by_grad(self, tensor, grad):
with self.param_op_hook.switch_to_backward(), ColoParamOpHookManager.use_hooks(self.param_op_hook):
torch.autograd.backward(tensor, grad)
self._post_backward()
def grad_handle(self, p, grad):
empty_grad = torch.empty_like(grad)
free_storage(empty_grad)
with torch._C.DisableTorchFunction():
chunk = self.chunk_manager.get_chunk(p)
if chunk.tensors_info[p].state != TensorState.HOLD_AFTER_BWD:
raise RuntimeError(f"Parameter `{self.param2name[p]}` failed at the gradient reduction. "
"Some unsupported torch function is operated upon this parameter.")
self.chunk_manager.trans_tensor_state(p, TensorState.READY_FOR_REDUCE)
chunk.copy_tensor_to_chunk_slice(p, grad)
reduced = self.chunk_manager.reduce_chunk(chunk)
if reduced:
if chunk.is_gathered:
chunk.cuda_global_chunk.div_(chunk.pg_size)
else:
chunk.cuda_shard.div_(chunk.pg_size)
# check overflow elements
self.overflow_counter += chunk.has_inf_or_nan
# record l2 norm for gradient clipping
if chunk.l2_norm_flag:
chunk.set_l2_norm()
self.chunk_manager.move_chunk(chunk, self.grads_device[p], force_copy=True)
return empty_grad
def zero_grad(self, set_to_none: bool = False) -> None:
self.module.zero_grad(set_to_none=True)
def set_chunk_grad_device(self, chunk: Chunk, device: torch.device) -> None:
for tensor in chunk.get_tensors():
self.grads_device[tensor] = device
def state_dict(self, destination=None, prefix='', keep_vars=False, only_rank_0: bool = True):
"""Returns a dictionary containing a whole state of the module.
Both parameters and persistent buffers (e.g. running averages) are included.
Keys are corresponding parameter and buffer names.
Parameters and buffers set to ``None`` are not included.
Warning: The non strict state dict would ignore the parameters if the tensors of the parameters
are shared with other parameters which have been included in the dictionary.
When you need to load the state dict, you should set the argument `strict` to False.
Returns:
dict:
a dictionary containing a whole state of the module
"""
if destination is None:
destination = OrderedDict()
destination._metadata = OrderedDict()
destination._metadata[prefix[:-1]] = local_metadata = dict(version=self._version)
self._save_to_state_dict(destination, prefix, keep_vars, only_rank_0)
for hook in self._state_dict_hooks.values():
hook_result = hook(self, destination, prefix, local_metadata)
if hook_result is not None:
destination = hook_result
return destination
def _get_param_to_save_data(self, param_list: List[torch.nn.Parameter], only_rank_0: bool) -> Dict:
"""
get param content from chunks.
Args:
param_list (_type_): a list of torch.nn.Parameters
only_rank_0 (_type_): _description_
Returns:
Dict: a dict whose key is param name and value is param with correct payload
"""
# save parameters
param_to_save_data = dict()
chunk_list = self.chunk_manager.get_chunks(param_list)
for chunk in chunk_list:
temp_chunk = get_temp_total_chunk_on_cuda(chunk)
for tensor, tensor_info in chunk.tensors_info.items():
record_tensor = torch.empty([0])
record_flag = (not only_rank_0) | (dist.get_rank(chunk.torch_pg) == 0)
if record_flag:
record_tensor = temp_chunk[tensor_info.offset:tensor_info.end].view(tensor.shape).cpu()
assert tensor not in param_to_save_data
param_to_save_data[tensor] = record_tensor
del temp_chunk
return param_to_save_data
def _save_to_state_dict(self, destination, prefix, keep_vars, only_rank_0=True):
r"""Saves module state to `destination` dictionary, containing a state
of the module, but not its descendants. This is called on every
submodule in :meth:`~torch.nn.Module.state_dict`.
In rare cases, subclasses can achieve class-specific behavior by
overriding this method with custom logic.
Args:
destination (dict): a dict where state will be stored
prefix (str): the prefix for parameters and buffers used in this
module
"""
assert keep_vars is False, "`state_dict` with parameter, `keep_vars=True`, is not supported now."
# get copies of fp32 parameters in CPU
param_to_save_data = self._get_param_to_save_data(self.fp32_params, only_rank_0)
# get the mapping between copies and fp16 parameters
p_mapping = dict()
for p, fp32_p in zip(self.fp16_params, self.fp32_params):
name = self.param2name[p]
assert fp32_p in param_to_save_data, "Parameter '{}' is neglected in the chunk list".format(name)
record_parameter = param_to_save_data[fp32_p]
p_mapping[p] = record_parameter
for name, param in self.name2param.items():
if param is not None:
if is_ddp_ignored(param):
# deal with ddp ignored parameters
destination[prefix + name] = param if keep_vars else param.detach()
else:
destination[prefix + name] = p_mapping[param]
del p_mapping
del param_to_save_data
# save all buffers
for name, buf in self.named_buffers():
if buf is not None and name not in self._non_persistent_buffers_set:
destination[prefix + name] = buf if keep_vars else buf.detach()
# save extra states
extra_state_key = prefix + _EXTRA_STATE_KEY_SUFFIX
if getattr(self.__class__, "get_extra_state",
torch.nn.Module.get_extra_state) is not torch.nn.Module.get_extra_state:
destination[extra_state_key] = self.get_extra_state()
def load_state_dict(self, state_dict: 'OrderedDict[str, torch.Tensor]', strict: bool = True):
r"""Copies parameters and buffers from :attr:`state_dict` into
this module and its descendants. If :attr:`strict` is ``True``, then
the keys of :attr:`state_dict` must exactly match the keys returned
by this module's :meth:`~torch.nn.Module.state_dict` function.
Args:
state_dict (dict): a dict containing parameters and
persistent buffers.
strict (bool, optional): whether to strictly enforce that the keys
in :attr:`state_dict` match the keys returned by this module's
:meth:`~torch.nn.Module.state_dict` function. Default: ``True``
Returns:
``NamedTuple`` with ``missing_keys`` and ``unexpected_keys`` fields:
* **missing_keys** is a list of str containing the missing keys
* **unexpected_keys** is a list of str containing the unexpected keys
Note:
If a parameter or buffer is registered as ``None`` and its corresponding key
exists in :attr:`state_dict`, :meth:`load_state_dict` will raise a
``RuntimeError``.
"""
missing_keys: List[str] = []
unexpected_keys: List[str] = []
error_msgs: List[str] = []
# copy state_dict so _load_from_state_dict can modify it
metadata = getattr(state_dict, '_metadata', None)
state_dict = state_dict.copy()
if metadata is not None:
# mypy isn't aware that "_metadata" exists in state_dict
state_dict._metadata = metadata # type: ignore[attr-defined]
prefix = ''
local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
self._load_from_state_dict(state_dict, prefix, local_metadata, True, missing_keys, unexpected_keys, error_msgs)
if strict:
if len(unexpected_keys) > 0:
error_msgs.insert(
0, 'Unexpected key(s) in state_dict: {}. '.format(', '.join(
'"{}"'.format(k) for k in unexpected_keys)))
if len(missing_keys) > 0:
error_msgs.insert(
0, 'Missing key(s) in state_dict: {}. '.format(', '.join('"{}"'.format(k) for k in missing_keys)))
if len(error_msgs) > 0:
raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
self.__class__.__name__, "\n\t".join(error_msgs)))
return _IncompatibleKeys(missing_keys, unexpected_keys)
def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys,
error_msgs):
r"""Copies parameters and buffers from :attr:`state_dict` into only
this module, but not its descendants. This is called on every submodule
in :meth:`~torch.nn.Module.load_state_dict`. Metadata saved for this
module in input :attr:`state_dict` is provided as :attr:`local_metadata`.
For state dicts without metadata, :attr:`local_metadata` is empty.
Subclasses can achieve class-specific backward compatible loading using
the version number at `local_metadata.get("version", None)`.
.. note::
:attr:`state_dict` is not the same object as the input
:attr:`state_dict` to :meth:`~torch.nn.Module.load_state_dict`. So
it can be modified.
Args:
state_dict (dict): a dict containing parameters and
persistent buffers.
prefix (str): the prefix for parameters and buffers used in this
module
local_metadata (dict): a dict containing the metadata for this module.
See
strict (bool): whether to strictly enforce that the keys in
:attr:`state_dict` with :attr:`prefix` match the names of
parameters and buffers in this module
missing_keys (list of str): if ``strict=True``, add missing keys to
this list
unexpected_keys (list of str): if ``strict=True``, add unexpected
keys to this list
error_msgs (list of str): error messages should be added to this
list, and will be reported together in
:meth:`~torch.nn.Module.load_state_dict`
"""
for hook in self._load_state_dict_pre_hooks.values():
hook(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)
persistent_buffers = {k: v for k, v in self.named_buffers() if k not in self._non_persistent_buffers_set}
local_name_params = itertools.chain(self.named_parameters(), persistent_buffers.items())
local_state = {k: v for k, v in local_name_params if v is not None}
def load(param_name, dest_tensor, copy_func):
state_key = prefix + param_name
if state_key in state_dict:
input_param = state_dict[state_key]
# Backward compatibility: loading 1-dim tensor from 0.3.* to version 0.4+
if len(dest_tensor.shape) == 0 and len(input_param.shape) == 1:
input_param = input_param[0]
if input_param.shape != dest_tensor.shape:
# local shape should match the one in checkpoint
error_msgs.append('size mismatch for {}: copying a param with shape {} from checkpoint, '
'the shape in current model is {}.'.format(state_key, input_param.shape,
dest_tensor.shape))
return
try:
with torch.no_grad():
copy_func(input_param)
except Exception as ex:
error_msgs.append('While copying the parameter named "{}", '
'whose dimensions in the model are {} and '
'whose dimensions in the checkpoint are {}, '
'an exception occurred : {}.'.format(state_key, dest_tensor.size(),
input_param.size(), ex.args))
elif strict:
missing_keys.append(state_key)
def load_fp32_parameter(chunk_slice, data):
chunk_slice.copy_(data.flatten())
for name, param in self.named_parameters():
if is_ddp_ignored(param):
# deal with ddp ignored parameters
load(name, param, param.copy_)
fp32_to_name = dict()
for p, fp32_p in zip(self.fp16_params, self.fp32_params):
if p is not None:
name = self.param2name[p]
fp32_to_name[fp32_p] = name
chunk_list = self.chunk_manager.get_chunks(self.fp32_params)
for chunk in chunk_list:
temp_chunk = get_temp_total_chunk_on_cuda(chunk)
for tensor, tensor_info in chunk.tensors_info.items():
parameter_name = fp32_to_name[tensor]
parameter_slice = temp_chunk[tensor_info.offset:tensor_info.end]
load(parameter_name, tensor, partial(load_fp32_parameter, parameter_slice))
if chunk.is_gathered:
chunk.cuda_global_chunk.copy_(temp_chunk)
elif chunk.cuda_shard is not None:
chunk.cuda_shard.copy_(temp_chunk[chunk.shard_begin:chunk.shard_end])
else:
chunk.cpu_shard.copy_(temp_chunk[chunk.shard_begin:chunk.shard_end])
del temp_chunk
for chunk_32 in chunk_list:
chunk_16 = chunk_32.paired_chunk
assert chunk_16 is not None
chunk_16.optim_update()
for name, buf in persistent_buffers.items():
if buf is not None:
load(name, buf, buf.copy_)
extra_state_key = prefix + _EXTRA_STATE_KEY_SUFFIX
if getattr(self.__class__, "set_extra_state",
torch.nn.Module.set_extra_state) is not torch.nn.Module.set_extra_state:
if extra_state_key in state_dict:
self.set_extra_state(state_dict[extra_state_key])
elif strict:
missing_keys.append(extra_state_key)
elif strict and (extra_state_key in state_dict):
unexpected_keys.append(extra_state_key)
if strict:
for key in state_dict.keys():
if key.startswith(prefix) and key != extra_state_key:
input_name = key[len(prefix):]
if input_name not in local_state:
unexpected_keys.append(key)
def _init_chunks(self, param_order, strict_ddp_mode: bool, cpu_offload: bool, pin_memory: bool):
ddp_pg = ColoProcessGroup()
for p in param_order.generate():
assert isinstance(p, ColoParameter)
# gather sharded parameters in the strict ddp mode
if strict_ddp_mode:
if not p.is_replicate():
p.set_dist_spec(ReplicaSpec())
p.set_process_group(pg=ddp_pg)
# ignore the parameters with no gradient
if not p.requires_grad:
self.set_params_to_ignore([p])
# move ignored parameters to CUDA
if is_ddp_ignored(p):
p.data = p.data.to(device=get_current_device(), dtype=torch.float16)
continue
# create a fp32 parameter
fp32_data = p.data.float()
fp32_p = ColoTensor(fp32_data, spec=ColoTensorSpec(p.process_group))
# create a fp16 parameter
p.data = p.data.half()
# register the fp16 parameter and fp32 parameter in the chunk manager
dp_world_size = p.process_group.dp_world_size()
self.chunk_manager.register_tensor(tensor=p,
group_type='fp16_param',
config_key=dp_world_size,
cpu_offload=cpu_offload,
pin_memory=pin_memory)
self.chunk_manager.register_tensor(tensor=fp32_p,
group_type='fp32_param',
config_key=dp_world_size,
cpu_offload=cpu_offload,
pin_memory=pin_memory)
self.fp16_params.append(p)
self.fp32_params.append(fp32_p)
self.grads_device[p] = self.gemini_manager.default_device
self.chunk_manager.close_all_groups()
for p, fp32_p in zip(self.fp16_params, self.fp32_params):
chunk_16 = self.chunk_manager.get_chunk(p)
chunk_32 = self.chunk_manager.get_chunk(fp32_p)
chunk_32.init_pair(chunk_16)
# keep gathered chunks are in CUDA
if chunk_16.keep_gathered:
self.grads_device[p] = get_current_device()
def _cast_buffers(self):
for buffer in self.module.buffers():
buffer.data = buffer.cuda()
if torch.is_floating_point(buffer):
buffer.data = buffer.half()
class GeminiDDP(ZeroDDP):
def __init__(self,
module: torch.nn.Module,
device: torch.device,
placement_policy: str = "cpu",
pin_memory: bool = False,
force_outputs_fp32: bool = False,
strict_ddp_mode: bool = False,
search_range_mb: int = 32,
hidden_dim: Optional[int] = None,
min_chunk_size_mb: float = 32,
memstats: Optional[MemStats] = None) -> None:
"""
A torch.Module warpper using ZeRO-DP and Genimi.
ZeRO is for parallel. Gemini is for memory management.
WARNING: The class will modify the module inline!
Example:
model is initialized under the context of ColoInitContext
>>> model = GeminiDDP(model, torch.cuda.current_device(), "cuda")
>>> logits = model(x)
>>> loss = criterion(logits, labels)
>>> model.backward(loss)
Args:
module (torch.nn.Module): the model to be wrapped.
device (torch.device): device to place the model.
placement_policy (str, optional): "cpu", "cuda", "auto". Defaults to "cpu".
pin_memory (bool, optional): use pin memory on CPU. Defaults to False.
force_outputs_fp32 (bool, optional): force outputs are fp32. Defaults to False.
search_range_mb (int, optional): chunk size searching range in MegaByte. Defaults to 32.
hidden_dim (int, optional): the hidden dimension of DNN.
Users can provide this argument to speed up searching.
If users do not know this argument before training, it is ok. We will use a default value 1024.
min_chunk_size_mb (float, optional): the minimum chunk size in MegaByte.
If the aggregate size of parameters is still samller than the minimum chunk size,
all parameters will be compacted into one small chunk.
memstats (MemStats, optional) the memory statistics collector by a runtime memory tracer.
"""
# some ugly hotfix for the compatibility with Lightning
if search_range_mb is None:
search_range_mb = 32
chunk_manager = init_chunk_manager(model=module,
init_device=device,
hidden_dim=hidden_dim,
search_range_mb=search_range_mb,
min_chunk_size_mb=min_chunk_size_mb,
strict_ddp_flag=strict_ddp_mode)
gemini_manager = GeminiManager(placement_policy, chunk_manager, memstats)
super().__init__(module, gemini_manager, pin_memory, force_outputs_fp32, strict_ddp_mode)

View File

@ -5,10 +5,10 @@ from typing import List
import torch
from colossalai.gemini import TensorState
from colossalai.gemini.gemini_mgr import GeminiManager
from colossalai.tensor.param_op_hook import ColoParamOpHook
from colossalai.utils import is_ddp_ignored
from colossalai.zero.gemini import TensorState
from colossalai.zero.gemini.gemini_mgr import GeminiManager
class TrainingPhase(Enum):

View File

@ -4,10 +4,8 @@ from typing import List, Optional, Tuple
import torch
from colossalai.gemini.chunk import Chunk, ChunkManager
from colossalai.gemini.memory_tracer import MemStats
from .memory_tracer import ChunkMemStatsCollector
from .chunk import Chunk, ChunkManager
from .memory_tracer import ChunkMemStatsCollector, MemStats
from .placement_policy import PlacementPolicyFactory

View File

@ -10,12 +10,15 @@ from torch.nn import Parameter
from torch.optim import Optimizer
from colossalai.amp.naive_amp.grad_scaler import DynamicGradScaler
from colossalai.gemini.chunk import Chunk, ChunkManager
from colossalai.logging import get_dist_logger
from colossalai.nn.optimizer import ColossalaiOptimizer, CPUAdam, FusedAdam, HybridAdam
from colossalai.nn.parallel.data_parallel import ZeroDDP
from colossalai.utils import disposable, get_current_device, is_ddp_ignored
from .chunk import Chunk, ChunkManager
from .gemini_ddp import ZeroDDP
__all__ = ['ZeroOptimizer', 'GeminiAdamOptimizer']
_AVAIL_OPTIM_LIST = {FusedAdam, CPUAdam, HybridAdam}
@ -316,3 +319,10 @@ class ZeroOptimizer(ColossalaiOptimizer):
fake_params_list.append(fake_param)
group['params'] = fake_params_list
class GeminiAdamOptimizer(ZeroOptimizer):
def __init__(self, model: torch.nn.Module, **defaults: Any) -> None:
optimizer = HybridAdam(model.parameters(), **defaults)
super().__init__(optimizer, model, **defaults)

View File

@ -1,10 +1,10 @@
from typing import Optional
from colossalai.gemini.chunk import ChunkManager
from colossalai.gemini.memory_tracer import MemStats
from colossalai.utils import get_current_device
from colossalai.utils.memory import colo_device_memory_capacity
from colossalai.zero.gemini.chunk import ChunkManager
from .memory_stats import MemStats
from .memstats_collector import MemStatsCollector

View File

@ -2,7 +2,7 @@ from typing import Any, Dict, List, Optional
import torch
from colossalai.gemini.memory_tracer import OrderedParamGenerator
from .param_runtime_order import OrderedParamGenerator
class MemStats(object):

View File

@ -1,12 +1,7 @@
import time
from typing import List, Optional
import torch
from colossalai.gemini.memory_tracer import SyncCudaMemoryMonitor
from colossalai.gemini.stateful_tensor import StatefulTensor
from colossalai.utils.memory import colo_device_memory_used
from typing import Optional
from .memory_monitor import SyncCudaMemoryMonitor
from .memory_stats import MemStats
@ -49,7 +44,7 @@ class MemStatsCollector:
assert self._step_total > 0, 'Cannot get mem stats info before collection phase.'
assert len(self._memstats.non_model_data_list(device_type)) > self._step_idx, \
f"{len(self._memstats.non_model_data_list(device_type))} should be > than step idx {self._step_idx}, "\
f"step total {self._step_total}"
f"step total {self._step_total}"
next_non_model_data = self._memstats.non_model_data_list(device_type)[self._step_idx]
self._step_idx = (self._step_idx + 1) % self._step_total
return next_non_model_data
@ -75,6 +70,8 @@ class MemStatsCollector:
Sampling model data statistics.
"""
if self._start_flag and not self.use_outside_memstats:
from colossalai.zero.legacy.gemini import StatefulTensor
# The following code work for ZeroInitContext, which is deprecated in v0.1.12
cuda_mem = StatefulTensor.GST_MGR.total_mem['cuda']
self._memstats.record_max_cuda_model_data(cuda_mem)

View File

@ -1,9 +1,14 @@
import torch.nn
from colossalai.gemini.memory_tracer import MemStats
from colossalai.gemini.ophooks.runtime_mem_tracer_hook import GradMemStats, GradMemTracerHook, ParamMemTracerHook
from colossalai.nn.parallel.data_parallel import _cast_float
from colossalai.tensor.param_op_hook import ColoParamOpHookManager
from colossalai.zero.legacy.gemini.ophooks.runtime_mem_tracer_hook import (
GradMemStats,
GradMemTracerHook,
ParamMemTracerHook,
)
from .memory_stats import MemStats
__all__ = ['RuntimeMemTracer']

View File

@ -6,7 +6,7 @@ from torch.fx import symbolic_trace
from colossalai.fx.passes.meta_info_prop import MetaInfoProp
from colossalai.fx.profiler import calculate_fwd_out, calculate_fwd_tmp, is_compatible_with_meta
from colossalai.gemini.chunk import ChunkManager
from colossalai.zero.gemini.chunk import ChunkManager
if is_compatible_with_meta():
from colossalai.fx.profiler import MetaTensor

View File

@ -5,11 +5,12 @@ from typing import Dict, List, Optional, Tuple, Type
import torch
from colossalai.gemini.chunk import Chunk, ChunkManager
from colossalai.gemini.memory_tracer import ChunkMemStatsCollector
from colossalai.utils import get_current_device
from colossalai.utils.memory import colo_device_memory_capacity
from .chunk import Chunk, ChunkManager
from .memory_tracer import ChunkMemStatsCollector
class PlacementPolicy(ABC):
need_mem_stats: bool = False

View File

@ -6,9 +6,10 @@ import torch
import torch.distributed as dist
import torch.nn as nn
from colossalai.gemini.chunk import Chunk
from colossalai.utils import get_current_device
from .chunk import Chunk
def get_temp_total_chunk_on_cuda(chunk: Chunk):
if chunk.is_gathered:
@ -77,7 +78,7 @@ def get_static_torch_model(zero_ddp_model,
Returns:
torch.nn.Module: a static torch model used for saving checkpoints or numeric checks
"""
from colossalai.nn.parallel import ZeroDDP
from colossalai.zero.gemini.gemini_ddp import ZeroDDP
assert isinstance(zero_ddp_model, ZeroDDP)
state_dict = zero_ddp_model.state_dict(only_rank_0=only_rank_0)

View File

@ -0,0 +1,44 @@
from typing import Tuple
import torch
import torch.nn as nn
from colossalai.logging import get_dist_logger
from .init_ctx import ZeroInitContext, no_shard_zero_context, no_shard_zero_decrator
from .sharded_model import ShardedModelV2
from .sharded_optim import ShardedOptimizerV2
def convert_to_zero_v2(model: nn.Module, optimizer: torch.optim.Optimizer, model_config,
optimizer_config) -> Tuple[ShardedModelV2, ShardedOptimizerV2]:
"""
A helper function to integrate the model and optimizer with ZeRO optimizer and off-loading
:param model: Your model object
:type model: :class:`torch.nn.Module`
:param optimizer_config: Your optimizer object
:type optimizer_config: :class:`dict`
:return: (model, optimizer)
:rtype: Tuple
"""
logger = get_dist_logger('convert_to_zero_v2')
logger.info(f'optimizer_config is {optimizer_config}', ranks=[0])
if optimizer_config is None:
optimizer_config = dict()
logger.info(f'model_config is {model_config}', ranks=[0])
if model_config is None:
model_config = dict()
zero_model = ShardedModelV2(model, **model_config)
zero_optimizer = ShardedOptimizerV2(zero_model, optimizer, **optimizer_config)
return zero_model, zero_optimizer
__all__ = [
'convert_to_zero_v2', 'ShardedModelV2', 'ShardedOptimizerV2', 'ZeroInitContext', 'no_shard_zero_context',
'no_shard_zero_decrator'
]

View File

@ -0,0 +1,9 @@
from .ophooks import BaseOpHook, register_ophooks_recursively
from .stateful_tensor import StatefulTensor
from .stateful_tensor_mgr import StatefulTensorMgr
from .tensor_placement_policy import AutoTensorPlacementPolicy, CPUTensorPlacementPolicy, CUDATensorPlacementPolicy
__all__ = [
'StatefulTensorMgr', 'StatefulTensor', 'CPUTensorPlacementPolicy', 'CUDATensorPlacementPolicy',
'AutoTensorPlacementPolicy', 'register_ophooks_recursively', 'BaseOpHook'
]

View File

@ -1,4 +1,5 @@
import torch
from colossalai.registry import OPHOOKS
from . import BaseOpHook

View File

@ -5,9 +5,9 @@ from typing import List
import torch
from colossalai.gemini.memory_tracer import MemStats, SyncCudaMemoryMonitor
from colossalai.gemini.tensor_utils import alloc_storage, free_storage
from colossalai.tensor.param_op_hook import ColoParamOpHook
from colossalai.zero.gemini.memory_tracer import MemStats, SyncCudaMemoryMonitor
from colossalai.zero.legacy.gemini.tensor_utils import alloc_storage, free_storage
class TrainingPhase(Enum):

View File

@ -1,9 +1,9 @@
from enum import Enum
from typing import Optional
import torch
from typing import Union
from typing import Optional, Union
from colossalai.gemini.gemini_context import GeminiMemoryManager
import torch
from .gemini_context import GeminiMemoryManager
def sizeof_tensor(tensor: torch.Tensor):
@ -19,7 +19,7 @@ class TensorState(Enum):
class StatefulTensor(object):
"""A Structure stores a Torch Tensor and labeled states.
"""A Structure stores a Torch Tensor and labeled states.
Inspired from the paper:
PatrickStar: Parallel Training of Pre-trained Models via Chunk-based Memory Management

View File

@ -1,13 +1,16 @@
import functools
import torch
import types
from colossalai.utils.cuda import get_current_device
from colossalai.gemini.tensor_utils import colo_model_data_tensor_move_inline, colo_tensor_mem_usage
from colossalai.gemini.stateful_tensor import StatefulTensor, TensorState
from colossalai.gemini.tensor_placement_policy import TensorPlacementPolicy
from typing import List
from colossalai.logging import get_dist_logger
from time import time
from typing import List
import torch
from colossalai.logging import get_dist_logger
from colossalai.utils.cuda import get_current_device
from .stateful_tensor import StatefulTensor, TensorState
from .tensor_placement_policy import TensorPlacementPolicy
from .tensor_utils import colo_model_data_tensor_move_inline, colo_tensor_mem_usage
class StatefulTensorMgr(object):

View File

@ -5,11 +5,12 @@ from typing import List, Optional, Type
import torch
from colossalai.gemini.memory_tracer import MemStatsCollector
from colossalai.gemini.stateful_tensor import StatefulTensor
from colossalai.gemini.tensor_utils import colo_model_data_tensor_move_inline, colo_tensor_mem_usage
from colossalai.utils import get_current_device
from colossalai.utils.memory import colo_device_memory_capacity
from colossalai.zero.gemini.memory_tracer import MemStatsCollector
from .stateful_tensor import StatefulTensor
from .tensor_utils import colo_model_data_tensor_move_inline, colo_tensor_mem_usage
class TensorPlacementPolicy(ABC):

View File

@ -1,6 +1,8 @@
from typing import Tuple, Union
import torch
from colossalai.gemini.stateful_tensor import StatefulTensor
from typing import Union, Tuple
from .stateful_tensor import StatefulTensor
def is_storage_empty(tensor: torch.Tensor) -> bool:

View File

@ -13,10 +13,10 @@ from colossalai.context.singleton_meta import SingletonMeta
from colossalai.core import global_context as gpc
from colossalai.logging import get_dist_logger
from colossalai.utils.model.utils import InsertPostInitMethodToModuleSubClasses
from colossalai.zero.shard_utils import BaseShardStrategy
from colossalai.zero.sharded_model._utils import cast_tensor_to_fp16
from colossalai.zero.sharded_model.sharded_model_v2 import ShardedModelV2
from colossalai.zero.sharded_param import ShardedParamV2
from colossalai.zero.legacy.shard_utils import BaseShardStrategy
from colossalai.zero.legacy.sharded_model._utils import cast_tensor_to_fp16
from colossalai.zero.legacy.sharded_model.sharded_model_v2 import ShardedModelV2
from colossalai.zero.legacy.sharded_param import ShardedParamV2
@dataclass

View File

@ -2,7 +2,8 @@ from abc import ABC, abstractmethod
from typing import List, Optional
import torch.distributed as dist
from colossalai.zero.sharded_param.sharded_tensor import ShardedTensor
from colossalai.zero.legacy.sharded_param.sharded_tensor import ShardedTensor
class BaseShardStrategy(ABC):

View File

@ -2,17 +2,18 @@ from typing import List, Optional
import torch
import torch.distributed as dist
from colossalai.utils import get_current_device
from colossalai.zero.sharded_param.sharded_tensor import ShardedTensor
from torch._utils import _flatten_dense_tensors as flatten
from colossalai.utils import get_current_device
from colossalai.zero.legacy.sharded_param.sharded_tensor import ShardedTensor
from .tensor_shard_strategy import TensorShardStrategy
class BucketTensorShardStrategy(TensorShardStrategy):
"""Use the same shard scheme as `TensorShardStrategy`'s, but it gathers tensors of a sub-module together,
which will fully utilize network bandwidth.
It is especially useful when sub-module contains bias,
"""Use the same shard scheme as `TensorShardStrategy`'s, but it gathers tensors of a sub-module together,
which will fully utilize network bandwidth.
It is especially useful when sub-module contains bias,
since we cannot utilize network bandwidth well if we only gather a bias tensor (bias is usaully small).
"""

View File

@ -1,7 +1,7 @@
import torch
import torch.nn.functional as F
from typing import Tuple
import torch
def get_shard(tensor: torch.Tensor, rank: int, world_size: int) -> Tuple[torch.Tensor, int]:
"""Return the local shard of a full tensor."""

View File

@ -2,11 +2,12 @@ from typing import List, Optional
import torch
import torch.distributed as dist
from colossalai.utils import get_current_device
from colossalai.zero.shard_utils import BaseShardStrategy
from colossalai.zero.shard_utils.commons import get_shard
from colossalai.zero.sharded_param.sharded_tensor import ShardedTensor
from colossalai.gemini.tensor_utils import colo_model_data_tensor_move_inline
from colossalai.zero.legacy.gemini.tensor_utils import colo_model_data_tensor_move_inline
from colossalai.zero.legacy.shard_utils import BaseShardStrategy
from colossalai.zero.legacy.shard_utils.commons import get_shard
from colossalai.zero.legacy.sharded_param.sharded_tensor import ShardedTensor
class TensorShardStrategy(BaseShardStrategy):
@ -27,7 +28,7 @@ class TensorShardStrategy(BaseShardStrategy):
Args:
t (ShardedTensor): a tensor to be sharded.
process_group (Optional[dist.ProcessGroup], optional): the process group among which tensor shards.
process_group (Optional[dist.ProcessGroup], optional): the process group among which tensor shards.
Defaults to None.
"""
if t.is_sharded:

View File

@ -1,3 +1,3 @@
from .sharded_model_v2 import ShardedModelV2
__all__ = ['ShardedModelV2']
__all__ = ['ShardedModelV2']

View File

@ -1,9 +1,9 @@
from typing import Any, Callable, List, Tuple
from typing import Any, Callable, List, Tuple, Union
import torch
import torch.nn.functional as F
from typing import Union
from colossalai.gemini.stateful_tensor import StatefulTensor
from colossalai.zero.legacy.gemini.stateful_tensor import StatefulTensor
def get_gradient_predivide_factor(world_size: int) -> float:

View File

@ -13,19 +13,18 @@ from torch.nn.parameter import Parameter
from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.gemini.memory_tracer import MemStatsCollector, StaticMemStatsCollector
from colossalai.gemini.ophooks import register_ophooks_recursively
from colossalai.gemini.paramhooks import BaseParamHookMgr
from colossalai.gemini.stateful_tensor import TensorState
from colossalai.gemini.stateful_tensor_mgr import StatefulTensorMgr
from colossalai.gemini.tensor_placement_policy import TensorPlacementPolicy, TensorPlacementPolicyFactory
from colossalai.gemini.tensor_utils import colo_model_data_move_to_cpu
from colossalai.logging import get_dist_logger
from colossalai.utils import disposable, get_current_device
from colossalai.utils.memory import colo_device_memory_capacity
from colossalai.zero.shard_utils import BaseShardStrategy
from colossalai.zero.sharded_model.reduce_scatter import ReduceScatterBucketer
from colossalai.zero.utils import ZeroHook
from colossalai.zero.gemini.memory_tracer import MemStatsCollector, StaticMemStatsCollector
from colossalai.zero.legacy.gemini.ophooks import register_ophooks_recursively
from colossalai.zero.legacy.gemini.paramhooks import BaseParamHookMgr
from colossalai.zero.legacy.gemini.stateful_tensor import TensorState
from colossalai.zero.legacy.gemini.stateful_tensor_mgr import StatefulTensorMgr
from colossalai.zero.legacy.gemini.tensor_placement_policy import TensorPlacementPolicy, TensorPlacementPolicyFactory
from colossalai.zero.legacy.gemini.tensor_utils import colo_model_data_move_to_cpu
from colossalai.zero.legacy.shard_utils import BaseShardStrategy
from colossalai.zero.legacy.sharded_model.reduce_scatter import ReduceScatterBucketer
from ._utils import (
cast_float_arguments,
@ -35,6 +34,7 @@ from ._utils import (
free_storage,
get_gradient_predivide_factor,
)
from .zero_hook import ZeroHook
try:
from torch.nn.modules.module import _EXTRA_STATE_KEY_SUFFIX

View File

@ -1,8 +1,9 @@
import torch
from colossalai.zero.sharded_model import ShardedModelV2
import copy
import torch
from colossalai.zero.legacy.sharded_model import ShardedModelV2
def col_model_deepcopy(sharded_model: ShardedModelV2, other_model: torch.nn.Module):
"""

View File

@ -3,14 +3,14 @@ from typing import Optional
import torch
import torch.distributed as dist
from colossalai.gemini.memory_tracer import MemStatsCollector
from colossalai.gemini.ophooks import BaseOpHook
from colossalai.gemini.stateful_tensor import TensorState
from colossalai.gemini.stateful_tensor_mgr import StatefulTensorMgr
from colossalai.logging import get_dist_logger
from colossalai.registry import OPHOOKS
from colossalai.utils import get_current_device
from colossalai.zero.shard_utils import BaseShardStrategy
from colossalai.zero.gemini.memory_tracer import MemStatsCollector
from colossalai.zero.legacy.gemini.ophooks import BaseOpHook
from colossalai.zero.legacy.gemini.stateful_tensor import TensorState
from colossalai.zero.legacy.gemini.stateful_tensor_mgr import StatefulTensorMgr
from colossalai.zero.legacy.shard_utils import BaseShardStrategy
@OPHOOKS.register_module

View File

@ -0,0 +1,3 @@
from .sharded_optim_v2 import ShardedOptimizerV2
__all__ = ['ShardedOptimizerV2']

View File

@ -14,13 +14,13 @@ from torch.optim import Optimizer
from colossalai.amp.naive_amp.grad_scaler import DynamicGradScaler
from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.gemini.stateful_tensor import StatefulTensor, TensorState
from colossalai.gemini.tensor_placement_policy import AutoTensorPlacementPolicy
from colossalai.gemini.tensor_utils import colo_model_data_tensor_move_inline, colo_tensor_mem_usage
from colossalai.logging import get_dist_logger
from colossalai.nn.optimizer import ColossalaiOptimizer
from colossalai.zero.sharded_model import ShardedModelV2
from colossalai.zero.sharded_model._utils import cast_tensor_to_fp32
from colossalai.zero.legacy.gemini.stateful_tensor import StatefulTensor, TensorState
from colossalai.zero.legacy.gemini.tensor_placement_policy import AutoTensorPlacementPolicy
from colossalai.zero.legacy.gemini.tensor_utils import colo_model_data_tensor_move_inline, colo_tensor_mem_usage
from colossalai.zero.legacy.sharded_model import ShardedModelV2
from colossalai.zero.legacy.sharded_model._utils import cast_tensor_to_fp32
class OptimState(Enum):

View File

@ -0,0 +1,4 @@
from .sharded_param import ShardedParamV2
from .sharded_tensor import ShardedTensor
__all__ = ['ShardedTensor', 'ShardedParamV2']

View File

@ -1,9 +1,11 @@
from typing import List, Optional, Tuple
import torch
from typing import Optional, Tuple
from colossalai.zero.sharded_param.sharded_tensor import ShardedTensor
from colossalai.gemini.tensor_utils import colo_tensor_mem_usage
from colossalai.gemini.stateful_tensor import StatefulTensor, TensorState
from typing import List
from colossalai.zero.legacy.gemini.stateful_tensor import StatefulTensor, TensorState
from colossalai.zero.legacy.gemini.tensor_utils import colo_tensor_mem_usage
from .sharded_tensor import ShardedTensor
EMPTY_TENSOR_DICT = {}

View File

@ -1,5 +1,6 @@
import torch
from colossalai.gemini.stateful_tensor import StatefulTensor, TensorState
from colossalai.zero.legacy.gemini.stateful_tensor import StatefulTensor, TensorState
class ShardedTensor(StatefulTensor):

View File

@ -0,0 +1,3 @@
from .low_level_optim import LowLevelZeroOptimizer
__all__ = ['LowLevelZeroOptimizer']

View File

@ -1,4 +0,0 @@
from .low_level_optim import LowLevelZeroOptimizer
from .sharded_optim_v2 import ShardedOptimizerV2
__all__ = ['ShardedOptimizerV2', 'LowLevelZeroOptimizer']

View File

@ -1,4 +0,0 @@
from colossalai.zero.sharded_param.sharded_tensor import ShardedTensor
from colossalai.zero.sharded_param.sharded_param import ShardedParamV2
__all__ = ['ShardedTensor', 'ShardedParamV2']

View File

@ -1,3 +0,0 @@
from .zero_hook import ZeroHook
__all__ = ['ZeroHook']

View File

@ -4,7 +4,7 @@ from typing import Dict, Optional
import torch
import torch.nn as nn
from .gemini_parallel import GeminiDDP
from .gemini import GeminiDDP
def zero_model_wrapper(model: nn.Module, zero_stage: int = 1, gemini_config: Optional[Dict] = None):
@ -99,11 +99,11 @@ def zero_optim_wrapper(model: nn.Module,
config_dict['max_scale'] = max_scale
if zero_stage in [1, 2]:
from colossalai.zero.sharded_optim.low_level_optim import LowLevelZeroOptimizer
from colossalai.zero.low_level import LowLevelZeroOptimizer
config_dict['partition_grad'] = zero_stage == 2
config_dict['clip_grad_norm'] = max_norm
return LowLevelZeroOptimizer(optimizer, **config_dict)
else:
from colossalai.nn.optimizer.zero_optimizer import ZeroOptimizer
from colossalai.zero.gemini.gemini_optimizer import ZeroOptimizer
config_dict['clipping_norm'] = max_norm
return ZeroOptimizer(optimizer, model, **config_dict)

View File

@ -78,7 +78,7 @@ from transformers.models.gpt2.modeling_gpt2 import GPT2LMHeadModel
import colossalai
from colossalai.nn.optimizer import HybridAdam
from colossalai.nn.parallel import zero_model_wrapper, zero_optim_wrapper
from colossalai.zero import zero_model_wrapper, zero_optim_wrapper
from colossalai.utils.model.colo_init_context import ColoInitContext
```

View File

@ -77,7 +77,7 @@ from transformers.models.gpt2.configuration_gpt2 import GPT2Config
from transformers.models.gpt2.modeling_gpt2 import GPT2LMHeadModel
import colossalai
from colossalai.nn.optimizer import HybridAdam
from colossalai.nn.parallel import zero_model_wrapper, zero_optim_wrapper
from colossalai.zero import zero_model_wrapper, zero_optim_wrapper
from colossalai.utils.model.colo_init_context import ColoInitContext
```

View File

@ -5,7 +5,7 @@ torchrun --standalone --nproc_per_node=1 debug.py
from diffusers import AutoencoderKL
import colossalai
from colossalai.utils.model.colo_init_context import ColoInitContext, post_process_colo_init_ctx
from colossalai.zero import ColoInitContext, post_process_colo_init_ctx
path = "/data/scratch/diffuser/stable-diffusion-v1-4"

View File

@ -21,10 +21,9 @@ import colossalai
from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.logging import disable_existing_loggers, get_dist_logger
from colossalai.nn.optimizer.gemini_optimizer import GeminiAdamOptimizer
from colossalai.nn.parallel.utils import get_static_torch_model
from colossalai.utils import get_current_device
from colossalai.utils.model.colo_init_context import ColoInitContext
from colossalai.zero import ColoInitContext, GeminiAdamOptimizer
from colossalai.zero.gemini import get_static_torch_model
disable_existing_loggers()
logger = get_dist_logger()

View File

@ -23,10 +23,9 @@ import colossalai
from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.logging import disable_existing_loggers, get_dist_logger
from colossalai.nn.optimizer.gemini_optimizer import GeminiAdamOptimizer
from colossalai.nn.parallel.utils import get_static_torch_model
from colossalai.utils import get_current_device
from colossalai.utils.model.colo_init_context import ColoInitContext
from colossalai.zero import ColoInitContext, GeminiAdamOptimizer
from colossalai.zero.gemini import get_static_torch_model
disable_existing_loggers()
logger = get_dist_logger()

View File

@ -18,7 +18,7 @@ from colossalai.tensor import ComputePattern, ComputeSpec, DistSpecManager, Proc
from colossalai.testing import rerun_if_address_is_in_use
from colossalai.utils import free_port
from colossalai.utils.cuda import get_current_device
from colossalai.utils.model.colo_init_context import ColoInitContext
from colossalai.zero import ColoInitContext
def set_seed(seed):

View File

@ -19,7 +19,7 @@ from colossalai.nn.optimizer import HybridAdam
from colossalai.nn.parallel.data_parallel import ColoDDP
from colossalai.tensor import ComputePattern, ComputeSpec, DistSpecManager, ProcessGroup, ShardSpec
from colossalai.utils import get_current_device
from colossalai.utils.model.colo_init_context import ColoInitContext
from colossalai.zero import ColoInitContext
def init_1d_row_for_linear_weight_spec(model, world_size: int):

View File

@ -12,10 +12,9 @@ from transformers import AlbertConfig, AlbertForSequenceClassification, BertConf
import colossalai
from colossalai.logging import disable_existing_loggers, get_dist_logger
from colossalai.nn.optimizer import HybridAdam
from colossalai.nn.parallel import zero_model_wrapper, zero_optim_wrapper
from colossalai.tensor import ColoParameter, ComputePattern, ComputeSpec, ProcessGroup, ReplicaSpec, ShardSpec
from colossalai.utils import get_current_device
from colossalai.utils.model.colo_init_context import ColoInitContext
from colossalai.zero import ColoInitContext, zero_model_wrapper, zero_optim_wrapper
CAI_VERSION = colossalai.__version__

View File

@ -13,10 +13,9 @@ from torch.nn.parallel import DistributedDataParallel as DDP
import colossalai
from colossalai.logging import disable_existing_loggers, get_dist_logger
from colossalai.nn.optimizer import HybridAdam
from colossalai.nn.parallel import zero_model_wrapper, zero_optim_wrapper
from colossalai.tensor import ColoParameter, ComputePattern, ComputeSpec, ProcessGroup, ReplicaSpec, ShardSpec
from colossalai.utils import get_current_device
from colossalai.utils.model.colo_init_context import ColoInitContext
from colossalai.zero import ColoInitContext, zero_model_wrapper, zero_optim_wrapper
CAI_VERSION = colossalai.__version__

View File

@ -34,12 +34,9 @@ from transformers.utils.versions import require_version
import colossalai
from colossalai.logging import disable_existing_loggers, get_dist_logger
from colossalai.nn.optimizer.gemini_optimizer import GeminiAdamOptimizer
from colossalai.nn.parallel import GeminiDDP
from colossalai.utils import get_current_device
from colossalai.utils.model.colo_init_context import ColoInitContext
from colossalai.tensor import ProcessGroup, ShardSpec
from colossalai.utils import get_current_device
from colossalai.zero import ColoInitContext, GeminiAdamOptimizer, GeminiDDP
def get_data(batch_size, seq_len, vocab_size):
@ -179,13 +176,15 @@ def main():
# build model
if args.model_name_or_path is None:
logger.info("Train a new model from scratch", ranks=[0])
with ColoInitContext(device=init_dev, dtype=torch.half,
with ColoInitContext(device=init_dev,
dtype=torch.half,
default_dist_spec=default_dist_spec,
default_pg=shard_pg):
model = OPTForCausalLM(config)
else:
logger.info("Finetune a pre-trained model", ranks=[0])
with ColoInitContext(device=init_dev, dtype=torch.half,
with ColoInitContext(device=init_dev,
dtype=torch.half,
default_dist_spec=default_dist_spec,
default_pg=shard_pg):
model = OPTForCausalLM.from_pretrained(args.model_name_or_path,
@ -198,8 +197,11 @@ def main():
numel = sum([p.numel() for p in model.parameters()])
PLACEMENT_POLICY = 'cpu'
model = GeminiDDP(model, device=get_current_device(), placement_policy=PLACEMENT_POLICY,
pin_memory=True, strict_ddp_mode=args.shardinit)
model = GeminiDDP(model,
device=get_current_device(),
placement_policy=PLACEMENT_POLICY,
pin_memory=True,
strict_ddp_mode=args.shardinit)
optimizer = GeminiAdamOptimizer(model, lr=args.learning_rate, initial_scale=2**14, gpu_margin_mem_ratio=0.0)
SEQ_LEN = 1024

View File

@ -15,11 +15,9 @@ from torch.utils.data import DataLoader, Dataset
import colossalai
from colossalai.logging import disable_existing_loggers, get_dist_logger
from colossalai.nn.optimizer.gemini_optimizer import GeminiAdamOptimizer
from colossalai.nn.parallel import ZeroDDP
from colossalai.tensor import ColoParameter, ComputePattern, ComputeSpec, ProcessGroup, ReplicaSpec, ShardSpec
from colossalai.utils import MultiTimer, get_current_device
from colossalai.utils.model.colo_init_context import ColoInitContext
from colossalai.zero import ColoInitContext, GeminiAdamOptimizer, ZeroDDP
# constants
@ -127,7 +125,7 @@ def gemini_zero_dpp(model: torch.nn.Module, pg: ProcessGroup, placememt_policy:
return model
## Parameter Sharding Strategies for Tensor Parallelism
# Parameter Sharding Strategies for Tensor Parallelism
def split_param_single_dim_tp1d(dim: int, param: ColoParameter, pg: ProcessGroup):
spec = (ShardSpec([dim], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
param.set_tensor_spec(*spec)
@ -232,7 +230,7 @@ if args.distplan == "colossalai":
tensor_parallelize(model, pg)
model = gemini_zero_dpp(model, pg, args.placement)
#optimizer
# optimizer
#optimizer = GeminiAdamOptimizer(model, lr=1e-7, initial_scale=2**5)
optimizer = GeminiAdamOptimizer(model, lr=LEARNING_RATE, initial_scale=2**5)

View File

@ -1,69 +1,67 @@
import colossalai
import math
import torch
from colossalai.context import ParallelMode
from colossalai.core import global_context as gpc
import colossalai.nn as col_nn
from arguments import parse_args
from pretrain_utils import get_model, get_optimizer, get_lr_scheduler, save_ckpt
from utils.exp_util import get_tflops, get_mem_info, throughput_calculator, log_args
from utils.global_vars import set_global_variables, get_timers, get_tensorboard_writer
from utils.logger import Logger
from evaluation import evaluate
from loss import LossForPretraining
from colossalai.zero.init_ctx import ZeroInitContext
from colossalai.zero.shard_utils import TensorShardStrategy
from colossalai.zero.sharded_model import ShardedModelV2
from colossalai.zero.sharded_optim import ShardedOptimizerV2
from nvidia_bert_dataset_provider import NvidiaBertDatasetProvider
from tqdm import tqdm
import os
import time
from functools import partial
import torch
from arguments import parse_args
from evaluation import evaluate
from loss import LossForPretraining
from nvidia_bert_dataset_provider import NvidiaBertDatasetProvider
from pretrain_utils import get_lr_scheduler, get_model, get_optimizer, save_ckpt
from tqdm import tqdm
from transformers import AutoTokenizer
from utils.exp_util import get_mem_info, get_tflops, log_args, throughput_calculator
from utils.global_vars import get_tensorboard_writer, get_timers, set_global_variables
from utils.logger import Logger
from colossalai.gemini import ChunkManager, GeminiManager
from colossalai.utils.model.colo_init_context import ColoInitContext
from colossalai.utils import get_current_device
from colossalai.nn.parallel import ZeroDDP
from colossalai.zero import ZeroOptimizer
from colossalai.tensor import ProcessGroup
import colossalai
import colossalai.nn as col_nn
from colossalai.context import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.nn.optimizer import HybridAdam
from colossalai.nn.parallel import ZeroDDP
from colossalai.tensor import ProcessGroup
from colossalai.utils import get_current_device
from colossalai.zero import ZeroOptimizer
from colossalai.zero.gemini import ChunkManager, ColoInitContext, GeminiManager
from colossalai.zero.legacy import ShardedModelV2, ShardedOptimizerV2, ZeroInitContext
from colossalai.zero.legacy.shard_utils import TensorShardStrategy
def main():
args = parse_args()
launch_time = time.strftime("%Y-%m-%d-%H:%M:%S", time.localtime())
tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_path)
os.environ['CUDA_LAUNCH_BLOCKING'] = '1'
logger = Logger(os.path.join(args.log_path, launch_time), cuda=torch.cuda.is_available(), debug=args.vscode_debug)
if args.vscode_debug:
colossalai.launch(config={},
rank=args.rank,
world_size=args.world_size,
host=args.host,
port=args.port,
backend=args.backend)
rank=args.rank,
world_size=args.world_size,
host=args.host,
port=args.port,
backend=args.backend)
args.local_rank = -1
args.log_interval = 1
else:
colossalai.launch_from_torch(args.colossal_config) #args.colossal_config
colossalai.launch_from_torch(args.colossal_config) # args.colossal_config
args.local_rank = int(os.environ["LOCAL_RANK"])
logger.info(f'launch_from_torch, world size: {torch.distributed.get_world_size()} | ' +
f'ParallelMode.MODEL: {ParallelMode.MODEL} | ParallelMode.DATA: {ParallelMode.DATA} | ParallelMode.TENSOR: {ParallelMode.TENSOR}')
logger.info(
f'launch_from_torch, world size: {torch.distributed.get_world_size()} | ' +
f'ParallelMode.MODEL: {ParallelMode.MODEL} | ParallelMode.DATA: {ParallelMode.DATA} | ParallelMode.TENSOR: {ParallelMode.TENSOR}'
)
log_args(logger, args)
args.tokenizer = tokenizer
args.logger = logger
set_global_variables(launch_time, args.tensorboard_path)
use_zero = hasattr(gpc.config, 'zero')
world_size = torch.distributed.get_world_size()
@ -71,8 +69,8 @@ def main():
if use_zero:
shard_strategy = TensorShardStrategy()
with ZeroInitContext(target_device=torch.cuda.current_device(), shard_strategy=shard_strategy,
shard_param=True):
shard_param=True):
config, model, numel = get_model(args, logger)
# model = ShardedModelV2(model, shard_strategy, tensor_placement_policy='cpu', reuse_fp16_shard=True)
else:
@ -82,9 +80,10 @@ def main():
os.mkdir(os.path.join(args.ckpt_path, launch_time))
logger.info(f'Model numel: {numel}')
get_tflops_func = partial(get_tflops, numel, args.train_micro_batch_size_per_gpu, args.max_seq_length)
steps_per_epoch = 144003367 // world_size // args.train_micro_batch_size_per_gpu // args.gradient_accumulation_steps // args.refresh_bucket_size #len(dataloader)
# len(dataloader)
steps_per_epoch = 144003367 // world_size // args.train_micro_batch_size_per_gpu // args.gradient_accumulation_steps // args.refresh_bucket_size
total_steps = steps_per_epoch * args.epoch
# build optimizer and lr_scheduler
@ -98,18 +97,23 @@ def main():
o_l_state_dict['lr_scheduler']['last_epoch'] = o_l_state_dict['lr_scheduler']['last_epoch'] - 1
optimizer = get_optimizer(model, lr=args.lr)
optimizer.load_state_dict(o_l_state_dict['optimizer'])
lr_scheduler = get_lr_scheduler(optimizer, total_steps=total_steps, last_epoch=o_l_state_dict['lr_scheduler']['last_epoch']) #o_l_state_dict['lr_scheduler']['last_epoch']
# o_l_state_dict['lr_scheduler']['last_epoch']
lr_scheduler = get_lr_scheduler(optimizer,
total_steps=total_steps,
last_epoch=o_l_state_dict['lr_scheduler']['last_epoch'])
for state in optimizer.state.values():
for k, v in state.items():
if isinstance(v, torch.Tensor):
state[k] = v.cuda(f"cuda:{torch.cuda.current_device()}")
# if you want delete the above three code, have to move the model to gpu, because in optimizer.step()
lr_scheduler.load_state_dict(o_l_state_dict['lr_scheduler'])
start_epoch = o_l_state_dict['epoch']
start_shard = o_l_state_dict['shard'] + 1
# global_step = o_l_state_dict['global_step'] + 1
logger.info(f'resume from epoch {start_epoch} shard {start_shard} step {lr_scheduler.last_epoch} lr {lr_scheduler.get_last_lr()[0]}')
logger.info(
f'resume from epoch {start_epoch} shard {start_shard} step {lr_scheduler.last_epoch} lr {lr_scheduler.get_last_lr()[0]}'
)
else:
optimizer = get_optimizer(model, lr=args.lr)
lr_scheduler = get_lr_scheduler(optimizer, total_steps=total_steps, last_epoch=-1)
@ -124,12 +128,11 @@ def main():
# initialize with colossalai
engine, _, _, lr_scheduelr = colossalai.initialize(model=model,
optimizer=optimizer,
criterion=criterion,
lr_scheduler=lr_scheduler)
optimizer=optimizer,
criterion=criterion,
lr_scheduler=lr_scheduler)
logger.info(get_mem_info(prefix='After init model, '))
best_loss = None
eval_loss = 0
@ -146,13 +149,16 @@ def main():
dataset_iterator, total_length = pretrain_dataset_provider.get_shard(shard)
# pretrain_dataset_provider.prefetch_shard(shard + 1) # may cause cpu memory overload
if torch.distributed.get_rank() == 0:
iterator_data = tqdm(enumerate(dataset_iterator), total=(total_length // args.train_micro_batch_size_per_gpu // world_size), colour='cyan', smoothing=1)
iterator_data = tqdm(enumerate(dataset_iterator),
total=(total_length // args.train_micro_batch_size_per_gpu // world_size),
colour='cyan',
smoothing=1)
else:
iterator_data = enumerate(dataset_iterator)
engine.train()
for step, batch_data in iterator_data:
for step, batch_data in iterator_data:
# batch_data = pretrain_dataset_provider.get_batch(batch_index)
input_ids = batch_data[0].cuda(f"cuda:{torch.cuda.current_device()}")
@ -162,7 +168,7 @@ def main():
# nsp_label = batch_data[5].cuda()
output = engine(input_ids=input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask)
loss = engine.criterion(output.logits, mlm_label)
pretrain_dataset_provider.prefetch_batch()
@ -172,14 +178,15 @@ def main():
engine.step()
lr_scheduelr.step()
engine.zero_grad()
global_step += 1
if global_step % args.log_interval == 0 and global_step != 0 \
and torch.distributed.get_rank() == 0:
and torch.distributed.get_rank() == 0:
elapsed_time = timers('interval_time').elapsed(reset=False)
elapsed_time_per_iteration = elapsed_time / global_step
samples_per_sec, tflops, approx_parameters_in_billions = throughput_calculator(numel, args, config, elapsed_time, global_step, world_size)
samples_per_sec, tflops, approx_parameters_in_billions = throughput_calculator(
numel, args, config, elapsed_time, global_step, world_size)
cur_loss = train_loss / args.log_interval
current_lr = lr_scheduelr.get_last_lr()[0]
@ -189,12 +196,13 @@ def main():
if args.wandb:
tensorboard_log = get_tensorboard_writer()
tensorboard_log.log_train({
'lr': current_lr,
'loss': cur_loss,
'ppl': math.exp(cur_loss),
'mins_batch': elapsed_time_per_iteration
}, global_step)
tensorboard_log.log_train(
{
'lr': current_lr,
'loss': cur_loss,
'ppl': math.exp(cur_loss),
'mins_batch': elapsed_time_per_iteration
}, global_step)
train_loss = 0
@ -202,12 +210,14 @@ def main():
logger.info('*' * 100)
eval_loss += evaluate(engine, args, logger, global_step)
save_ckpt(engine.model, optimizer, lr_scheduelr, os.path.join(args.ckpt_path, launch_time, f'epoch-{epoch}_shard-{shard}_' + launch_time), epoch, shard, global_step)
save_ckpt(engine.model, optimizer, lr_scheduelr,
os.path.join(args.ckpt_path, launch_time, f'epoch-{epoch}_shard-{shard}_' + launch_time), epoch,
shard, global_step)
eval_loss /= len(os.listdir(args.data_path_prefix))
logger.info(f'epoch {epoch} | shard_length {len(os.listdir(args.data_path_prefix))} | elapsed_time: {timers("epoch_time").elapsed() / 60 :.3f} mins' + \
f'eval_loss: {eval_loss} | ppl: {math.exp(eval_loss)}')
logger.info(
f'epoch {epoch} | shard_length {len(os.listdir(args.data_path_prefix))} | elapsed_time: {timers("epoch_time").elapsed() / 60 :.3f} mins'
+ f'eval_loss: {eval_loss} | ppl: {math.exp(eval_loss)}')
logger.info('-' * 100)
if args.wandb and torch.distributed.get_rank() == 0:
tensorboard_log = get_tensorboard_writer()

View File

@ -30,24 +30,13 @@ from itertools import chain
import datasets
import torch
import torch.distributed as dist
import transformers
from accelerate.utils import set_seed
from context import barrier_context
from datasets import load_dataset
from packaging import version
from torch.utils.data import DataLoader
from tqdm.auto import tqdm
import colossalai
import transformers
from colossalai.context import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.logging import disable_existing_loggers, get_dist_logger
from colossalai.nn.optimizer import HybridAdam
from colossalai.nn.optimizer.zero_optimizer import ZeroOptimizer
from colossalai.nn.parallel import ZeroDDP
from colossalai.tensor import ProcessGroup
from colossalai.utils import get_current_device, get_dataloader
from colossalai.utils.model.colo_init_context import ColoInitContext
from transformers import (
CONFIG_MAPPING,
MODEL_MAPPING,
@ -61,6 +50,15 @@ from transformers import (
)
from transformers.utils.versions import require_version
import colossalai
from colossalai.context import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.logging import disable_existing_loggers, get_dist_logger
from colossalai.nn.optimizer import HybridAdam
from colossalai.tensor import ProcessGroup
from colossalai.utils import get_current_device, get_dataloader
from colossalai.zero import ColoInitContext, ZeroDDP, ZeroOptimizer
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/language-modeling/requirements.txt")
MODEL_CONFIG_CLASSES = list(MODEL_MAPPING.keys())

View File

@ -12,10 +12,9 @@ from colossalai.auto_parallel.offload.mem_optimize import memory_optimize
from colossalai.auto_parallel.offload.solver import NOT_NVML
from colossalai.fx.profiler import parameter_size
from colossalai.nn.optimizer import HybridAdam
from colossalai.nn.parallel import zero_model_wrapper, zero_optim_wrapper
from colossalai.testing import parameterize
from colossalai.utils import free_port, get_current_device
from colossalai.utils.model.colo_init_context import ColoInitContext
from colossalai.zero import ColoInitContext, zero_model_wrapper, zero_optim_wrapper
from tests.test_auto_parallel.test_offload.model_utils import *
from tests.test_tensor.common_utils import set_seed

View File

@ -11,12 +11,11 @@ from colossalai.device.device_mesh import DeviceMesh
from colossalai.initialize import launch
from colossalai.logging import disable_existing_loggers
from colossalai.nn.optimizer import HybridAdam
from colossalai.nn.parallel import zero_model_wrapper, zero_optim_wrapper
from colossalai.tensor.process_group import ProcessGroup
from colossalai.testing import assert_close, rerun_if_address_is_in_use
from colossalai.testing.pytest_wrapper import run_on_environment_flag
from colossalai.utils import free_port, get_current_device
from colossalai.utils.model.colo_init_context import ColoInitContext, post_process_colo_init_ctx
from colossalai.zero import ColoInitContext, post_process_colo_init_ctx, zero_model_wrapper, zero_optim_wrapper
class MLP(torch.nn.Module):

View File

@ -10,14 +10,14 @@ import torch.distributed as dist
import torch.multiprocessing as mp
import colossalai
from colossalai.gemini.chunk import ChunkManager, search_chunk_configuration
from colossalai.gemini.gemini_mgr import GeminiManager
from colossalai.nn.parallel import ColoDDP, ZeroDDP
from colossalai.nn.parallel import ColoDDP
from colossalai.tensor import ProcessGroup
from colossalai.testing import rerun_if_address_is_in_use
from colossalai.utils import free_port
from colossalai.utils.cuda import get_current_device
from colossalai.utils.model.colo_init_context import ColoInitContext
from colossalai.zero import ColoInitContext, ZeroDDP
from colossalai.zero.gemini.chunk import ChunkManager, search_chunk_configuration
from colossalai.zero.gemini.gemini_mgr import GeminiManager
def set_seed(seed):

View File

@ -1,18 +1,19 @@
import copy
from collections import OrderedDict
from functools import partial
import pytest
import colossalai
import torch
import torch.multiprocessing as mp
from colossalai.testing import rerun_if_address_is_in_use
from colossalai.utils.cuda import get_current_device
from colossalai.utils import free_port
from colossalai.utils.model.colo_init_context import ColoInitContext
from functools import partial
from tests.components_to_test.registry import non_distributed_component_funcs
import colossalai
from colossalai.nn.parallel import ColoDDP
from collections import OrderedDict
from colossalai.tensor import ProcessGroup, ColoParameter
from colossalai.tensor import ColoParameter, ProcessGroup
from colossalai.testing import rerun_if_address_is_in_use
from colossalai.utils import free_port
from colossalai.utils.cuda import get_current_device
from colossalai.zero import ColoInitContext
from tests.components_to_test.registry import non_distributed_component_funcs
def check_state_dict_equal(state_dict: OrderedDict, other_state_dict: OrderedDict):

View File

@ -1,73 +1,73 @@
import pytest
import torch
from colossalai.gemini.stateful_tensor import TensorState, StatefulTensor
@pytest.mark.dist
def test_gemini_manager():
# reset the manager, in case that there exists memory information left
manager = StatefulTensor.GST_MGR
manager.reset()
# occupation 8
st1 = StatefulTensor(torch.empty(2, 2, dtype=torch.float16, device='cuda'))
# occupation 60
st2 = StatefulTensor(torch.empty(3, 5, dtype=torch.float32, device='cpu'))
# occupation 28
t1 = torch.empty(7, device='cuda')
# occupation 12
t2 = torch.empty(3, device='cpu')
st3 = StatefulTensor(t1, TensorState.HOLD_AFTER_FWD)
st4 = StatefulTensor(None, TensorState.FREE)
assert manager.total_number == 4
assert manager.total_mem['cpu'] == 60
assert manager.total_mem['cuda'] == 36
assert manager.state_mem['cpu'][TensorState.HOLD] == 60
assert manager.state_mem['cuda'][TensorState.HOLD] == 8
assert manager.state_mem['cuda'][TensorState.HOLD_AFTER_FWD] == 28
st4.payload_reset(t2)
st3.payload_reset(t2)
assert manager.total_number == 4
assert manager.total_mem['cpu'] == 84
assert manager.total_mem['cuda'] == 8
assert manager.state_mem['cpu'][TensorState.HOLD] == 72
assert manager.state_mem['cuda'][TensorState.HOLD] == 8
assert manager.state_mem['cpu'][TensorState.HOLD_AFTER_FWD] == 12
assert manager.state_mem['cuda'][TensorState.HOLD_AFTER_FWD] == 0
st1.move_to(torch.device('cpu'))
st2.move_to(torch.device('cpu'))
st3.move_to(torch.device('cuda', 0))
assert manager.total_number == 4
assert manager.total_mem['cpu'] == 80
assert manager.total_mem['cuda'] == 12
assert manager.state_mem['cpu'][TensorState.HOLD] == 80
assert manager.state_mem['cuda'][TensorState.HOLD] == 0
assert manager.state_mem['cpu'][TensorState.HOLD_AFTER_FWD] == 0
assert manager.state_mem['cuda'][TensorState.HOLD_AFTER_FWD] == 12
st1.trans_state(TensorState.COMPUTE)
st2.trans_state(TensorState.COMPUTE)
st2.trans_state(TensorState.HOLD_AFTER_BWD)
assert manager.total_number == 4
assert manager.total_mem['cpu'] == 80
assert manager.total_mem['cuda'] == 12
assert manager.state_mem['cpu'][TensorState.HOLD] == 12
assert manager.state_mem['cuda'][TensorState.HOLD] == 0
assert manager.state_mem['cpu'][TensorState.HOLD_AFTER_FWD] == 0
assert manager.state_mem['cuda'][TensorState.HOLD_AFTER_FWD] == 12
assert manager.state_mem['cpu'][TensorState.HOLD_AFTER_BWD] == 60
assert manager.state_mem['cuda'][TensorState.HOLD_AFTER_BWD] == 0
assert manager.state_mem['cpu'][TensorState.COMPUTE] == 8
assert manager.state_mem['cuda'][TensorState.COMPUTE] == 0
if __name__ == '__main__':
test_gemini_manager()
import pytest
import torch
from colossalai.zero.legacy.gemini.stateful_tensor import StatefulTensor, TensorState
@pytest.mark.dist
def test_gemini_manager():
# reset the manager, in case that there exists memory information left
manager = StatefulTensor.GST_MGR
manager.reset()
# occupation 8
st1 = StatefulTensor(torch.empty(2, 2, dtype=torch.float16, device='cuda'))
# occupation 60
st2 = StatefulTensor(torch.empty(3, 5, dtype=torch.float32, device='cpu'))
# occupation 28
t1 = torch.empty(7, device='cuda')
# occupation 12
t2 = torch.empty(3, device='cpu')
st3 = StatefulTensor(t1, TensorState.HOLD_AFTER_FWD)
st4 = StatefulTensor(None, TensorState.FREE)
assert manager.total_number == 4
assert manager.total_mem['cpu'] == 60
assert manager.total_mem['cuda'] == 36
assert manager.state_mem['cpu'][TensorState.HOLD] == 60
assert manager.state_mem['cuda'][TensorState.HOLD] == 8
assert manager.state_mem['cuda'][TensorState.HOLD_AFTER_FWD] == 28
st4.payload_reset(t2)
st3.payload_reset(t2)
assert manager.total_number == 4
assert manager.total_mem['cpu'] == 84
assert manager.total_mem['cuda'] == 8
assert manager.state_mem['cpu'][TensorState.HOLD] == 72
assert manager.state_mem['cuda'][TensorState.HOLD] == 8
assert manager.state_mem['cpu'][TensorState.HOLD_AFTER_FWD] == 12
assert manager.state_mem['cuda'][TensorState.HOLD_AFTER_FWD] == 0
st1.move_to(torch.device('cpu'))
st2.move_to(torch.device('cpu'))
st3.move_to(torch.device('cuda', 0))
assert manager.total_number == 4
assert manager.total_mem['cpu'] == 80
assert manager.total_mem['cuda'] == 12
assert manager.state_mem['cpu'][TensorState.HOLD] == 80
assert manager.state_mem['cuda'][TensorState.HOLD] == 0
assert manager.state_mem['cpu'][TensorState.HOLD_AFTER_FWD] == 0
assert manager.state_mem['cuda'][TensorState.HOLD_AFTER_FWD] == 12
st1.trans_state(TensorState.COMPUTE)
st2.trans_state(TensorState.COMPUTE)
st2.trans_state(TensorState.HOLD_AFTER_BWD)
assert manager.total_number == 4
assert manager.total_mem['cpu'] == 80
assert manager.total_mem['cuda'] == 12
assert manager.state_mem['cpu'][TensorState.HOLD] == 12
assert manager.state_mem['cuda'][TensorState.HOLD] == 0
assert manager.state_mem['cpu'][TensorState.HOLD_AFTER_FWD] == 0
assert manager.state_mem['cuda'][TensorState.HOLD_AFTER_FWD] == 12
assert manager.state_mem['cpu'][TensorState.HOLD_AFTER_BWD] == 60
assert manager.state_mem['cuda'][TensorState.HOLD_AFTER_BWD] == 0
assert manager.state_mem['cpu'][TensorState.COMPUTE] == 8
assert manager.state_mem['cuda'][TensorState.COMPUTE] == 0
if __name__ == '__main__':
test_gemini_manager()

Some files were not shown because too many files have changed in this diff Show More