diff --git a/colossalai/utils/model/colo_init_context.py b/colossalai/utils/model/colo_init_context.py index e3861c84f..b7fef99b4 100644 --- a/colossalai/utils/model/colo_init_context.py +++ b/colossalai/utils/model/colo_init_context.py @@ -1,10 +1,10 @@ -from typing import Dict, Iterator, Optional, Tuple, Union +from typing import Any, Dict, Iterator, Optional, Tuple, Union import torch from torch import nn from colossalai.nn.parallel.layers import ColoEmbedding, ColoLinear, register_colo_module -from colossalai.tensor import ColoParameter, ColoTensor, ProcessGroup, ShardSpec +from colossalai.tensor import ColoParameter, ColoTensor, ProcessGroup from .utils import InsertPostInitMethodToModuleSubClasses @@ -26,6 +26,34 @@ def _named_params_with_replica( yield name, val +def _convert_to_coloparam(param: torch.nn.Parameter, + device: torch.device, + dtype=torch.float, + default_pg: Optional[ProcessGroup] = None, + default_dist_spec: Optional[Any] = None) -> ColoParameter: + + if isinstance(param, ColoParameter): + return param + # detaching tensor is necessary for optimizers. + requires_grad = param.requires_grad + # param is the global tensor. + colo_param = ColoParameter(param.to(device=device, dtype=dtype), requires_grad=requires_grad) + + # if default_shard_plan exists, shard the param during initialization. + # This can reduce the model size after initialization. + # NOTE() embedding usually can not be correctly sharded. So I use except to handle + # the param that can not be sharded by the default plan + if default_pg is not None: + colo_param.set_process_group(default_pg) + + if default_dist_spec is not None: + try: + colo_param.set_dist_spec(default_dist_spec) + except: + pass + return colo_param + + def ColoModulize(module): """ Replacing the parameters() and named_parameters() with our customized ones @@ -94,26 +122,8 @@ class ColoInitContext(InsertPostInitMethodToModuleSubClasses): if param in replaced_tensors: colo_param = replaced_tensors[param] else: - # detaching tensor is necessary for optimizers. - requires_grad = param.requires_grad - - # param is the global tensor. - colo_param = ColoParameter(param.to(device=self._device, dtype=self._dtype), - requires_grad=requires_grad) - - # if default_shard_plan exists, shard the param during initialization. - # This can reduce the model size after initialization. - # NOTE() embedding usually can not be correctly sharded. So I use except to handle - # the param that can not be sharded by the default plan - if self._default_pg is not None: - colo_param.set_process_group(self._default_pg) - - if self._default_dist_spec is not None: - try: - colo_param.set_dist_spec(self._default_dist_spec) - except: - pass - + colo_param = _convert_to_coloparam(param, self._device, self._dtype, self._default_pg, + self._default_dist_spec) replaced_tensors[param] = colo_param delattr(submodule, param_name) setattr(submodule, param_name, colo_param) @@ -121,3 +131,39 @@ class ColoInitContext(InsertPostInitMethodToModuleSubClasses): module.to(self._device) ColoModulize(module) + + +def post_process_colo_init_ctx(model: torch.nn.Module, + device: torch.device = torch.device('cpu'), + dtype: torch.dtype = torch.float, + default_pg: Optional[ProcessGroup] = None, + default_dist_spec=None): + """post_process_colo_init_ctx + + This function is called after `ColoInitContext`. + + Args: + model (torch.nn.module): the model + device (torch.device, optional): device type of the model params. Defaults to torch.device('cpu'). + dtype (torch.dtype, optional): dtype of the model params. Defaults to torch.float. + default_pg (Optional[ProcessGroup], optional): default process group. Defaults to None. Inidicates a DP-only process group. + default_dist_spec (Any, optional): default dist spec of params. Defaults to None. + + Raises: + RuntimeError: raise error if + """ + + torch_params = [] + for n, p in model.named_parameters(): + if not isinstance(p, ColoParameter): + print(f"{n} is not a ColoParameter. We are going to converting it to ColoParameter") + torch_params.append((n, p)) + + for (n, param) in torch_params: + delattr(model, n) + setattr(model, n, _convert_to_coloparam(param, device, dtype, default_pg, default_dist_spec)) + + del torch_params + for n, p in model.named_parameters(): + if not isinstance(p, ColoTensor): + raise RuntimeError diff --git a/tests/test_gemini/update/test_optim.py b/tests/test_gemini/update/test_optim.py index cd2d7155f..5789d2991 100644 --- a/tests/test_gemini/update/test_optim.py +++ b/tests/test_gemini/update/test_optim.py @@ -15,10 +15,11 @@ from colossalai.gemini.gemini_mgr import GeminiManager from colossalai.nn.optimizer import HybridAdam from colossalai.nn.optimizer.zero_optimizer import ZeroOptimizer from colossalai.nn.parallel import ZeroDDP +from colossalai.tensor import ColoParameter, ColoTensor from colossalai.testing import parameterize, rerun_if_address_is_in_use from colossalai.utils import free_port from colossalai.utils.cuda import get_current_device -from colossalai.utils.model.colo_init_context import ColoInitContext +from colossalai.utils.model.colo_init_context import ColoInitContext, post_process_colo_init_ctx from tests.components_to_test import run_fwd_bwd from tests.components_to_test.registry import non_distributed_component_funcs from tests.test_tensor.common_utils import debug_print, set_seed @@ -40,8 +41,7 @@ def check_param(model: ZeroDDP, torch_model: torch.nn.Module): # 'gpt2', 'bert', -TEST_MODELS = ['gpt2', 'bert'] -EXAMPLE_MODELS = ['simple_net'] +TEST_MODELS = ['no_leaf_module', 'gpt2', 'bert', 'simple_net', 'nested_model', 'repeated_computed_layers'] @parameterize('placement_policy', ['cuda', 'cpu', 'auto', 'const']) @@ -57,8 +57,12 @@ def exam_model_step(placement_policy, model_name: str): torch_model, torch_optim = convert_to_apex_amp(torch_model, torch_optim, amp_config) torch_model = DDP(torch_model, device_ids=[dist.get_rank()]) - with ColoInitContext(device=get_current_device()): + init_dev = get_current_device() + with ColoInitContext(device=init_dev): model = model_builder() + + post_process_colo_init_ctx(model, device=init_dev) + for torch_p, p in zip(torch_model.parameters(), model.parameters()): p.data.copy_(torch_p.data) @@ -99,7 +103,7 @@ def exam_model_step(placement_policy, model_name: str): @parameterize('placement_policy', ['cuda', 'cpu']) -@parameterize('model_name', EXAMPLE_MODELS) +@parameterize('model_name', TEST_MODELS) def exam_tiny_example(placement_policy, model_name: str): set_seed(2008) get_components_func = non_distributed_component_funcs.get_callable(model_name) @@ -111,8 +115,12 @@ def exam_tiny_example(placement_policy, model_name: str): torch_model, torch_optim = convert_to_apex_amp(torch_model, torch_optim, amp_config) torch_model = DDP(torch_model, device_ids=[dist.get_rank()]) - with ColoInitContext(device=get_current_device()): + init_dev = get_current_device() + with ColoInitContext(device=init_dev): model = model_builder() + + post_process_colo_init_ctx(model, device=init_dev) + for torch_p, p in zip(torch_model.parameters(), model.parameters()): p.data.copy_(torch_p.data)