diff --git a/colossalai/tensor/__init__.py b/colossalai/tensor/__init__.py index 3541e6a0a..7d3d168bf 100644 --- a/colossalai/tensor/__init__.py +++ b/colossalai/tensor/__init__.py @@ -4,8 +4,9 @@ from .op_wrapper import ( from .colo_tensor import ColoTensor from .utils import convert_parameter, named_params_with_colotensor from ._ops import * +from .optim.colo_optimizer import ColoOptimizer __all__ = [ 'ColoTensor', 'convert_parameter', 'colo_op_impl', 'ComputePattern', 'TensorSpec', 'ParallelAction', - 'named_params_with_colotensor', 'ShardPattern' + 'named_params_with_colotensor', 'ShardPattern', 'ColoOptimizer' ] diff --git a/colossalai/tensor/optim/__init__.py b/colossalai/tensor/optim/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/colossalai/tensor/optim/colo_optimizer.py b/colossalai/tensor/optim/colo_optimizer.py new file mode 100644 index 000000000..52c641594 --- /dev/null +++ b/colossalai/tensor/optim/colo_optimizer.py @@ -0,0 +1,88 @@ +from typing import List, Union, Mapping, Dict, Any + +import torch.optim as optim +from torch import Tensor +from colossalai.tensor.colo_tensor import ColoTensor + + +class ColoOptimizer(optim.Optimizer): + + def __init__(self, named_params: Mapping[str, Union[Tensor, ColoTensor]], optimizer_class, *optimizer_args, + **optimizer_kwargs): + """ + ColoOptimizer collects all tensors in type of ColoTensor and torch.Tensor, + then use these tensors as ``params`` for optimizers + + Args: + named_params (Dict[str, Union[Tensor, ShardedTensor]]) : a Dict + of parameters, where key is the parameter key, value is either + Tensor or ColoTensor. This usually used in + conjunction with model.named_parameters(), the same as PyTorch. + optimizer_class (torch.optim.Optimizer): the Optimizer to use + locally, i.e. torch.optim.SGD, torch.optim.Adagrad, etc. + *optimizer_args: the arguments to initialize the optimizer. + **optimizer_kwargs: the key-word arguments to initialize the optimizer. + + """ + tensors: List[Tensor] = [] + for value in named_params.values(): + tensors.append(value) + + self.named_params = named_params + self._optim = optimizer_class(tensors, *optimizer_args, **optimizer_kwargs) + self.param_groups = self._optim.param_groups + self.state = self._optim.state + + def zero_grad(self, set_to_none: bool = False): # type: ignore[override] + r"""Sets the gradients of all optimized :class:`torch.Tensor` s to zero. + + Args: + set_to_none (bool): instead of setting to zero, set the grads to None. + This will in general have lower memory footprint, and can modestly improve performance. + However, it changes certain behaviors. For example: + 1. When the user tries to access a gradient and perform manual ops on it, + a None attribute or a Tensor full of 0s will behave differently. + 2. If the user requests ``zero_grad(set_to_none=True)`` followed by a backward pass, ``.grad``\ s + are guaranteed to be None for params that did not receive a gradient. + 3. ``torch.optim`` optimizers have a different behavior if the gradient is 0 or None + (in one case it does the step with a gradient of 0 and in the other it skips + the step altogether). + """ + self._optim.zero_grad(set_to_none) + + def step(self, closure=None): + r"""Performs a single optimization step (parameter update). + + Args: + closure (callable): A closure that reevaluates the model and + returns the loss. Optional for most optimizers. + + .. note:: + Unless otherwise specified, this function should not modify the + ``.grad`` field of the parameters. + """ + self._optim.step(closure) + + def state_dict(self) -> Dict[str, Any]: + """ + Returned state and param_groups will contain parameter keys + instead of parameter indices like torch.optim.Optimizer. + """ + # TODO: implement state_dict + raise NotImplementedError("ColoOptimizer state_dict not implemented yet!") + + def load_state_dict(self, state_dict: Mapping[str, Any]): + r"""Loads the ColoOptimizer state. + + Args: + state_dict (dict): ColoOptimizer state. Should be an object returned + from a call to :meth:`state_dict`. + """ + # TODO: implement load_state_dict + raise NotImplementedError("ColoOptimizer load_state_dict not implemented yet!") + + def add_param_group(self, param_group: Any): + r"""Add a new param group + """ + # TODO: implement add_param_group + raise NotImplementedError("ColoOptimizer add_param_group not implemented yet!") diff --git a/colossalai/utils/model/colo_init_context.py b/colossalai/utils/model/colo_init_context.py index ab266b8b1..7efa0c338 100644 --- a/colossalai/utils/model/colo_init_context.py +++ b/colossalai/utils/model/colo_init_context.py @@ -94,9 +94,15 @@ class ColoInitContext(InsertPostInitMethodToModuleSubClasses): save_torch_payload = True if not self._lazy_memory_allocate else False for name, param in name_list: delattr(module, name) + + # detaching tensor is necessary for optimizers. + requires_grad = param.requires_grad + tensor_detached = param.to(self._device).detach() + tensor_detached.requires_grad = requires_grad + setattr( module, name, - ColoTensor.init_from_torch_tensor(tensor=param.to(self._device), + ColoTensor.init_from_torch_tensor(tensor=tensor_detached, save_payload=save_torch_payload, is_model_data=True)) diff --git a/tests/test_tensor/test_model.py b/tests/test_tensor/test_model.py index 1561cc177..ba66f1715 100644 --- a/tests/test_tensor/test_model.py +++ b/tests/test_tensor/test_model.py @@ -8,7 +8,7 @@ from colossalai.testing import parameterize, rerun_if_address_is_in_use from colossalai.utils.cuda import get_current_device from colossalai.utils import free_port from colossalai.utils import ColoInitContext -from colossalai.tensor import named_params_with_colotensor, TensorSpec, ComputePattern, ParallelAction, ColoTensor +from colossalai.tensor import named_params_with_colotensor, TensorSpec, ComputePattern, ParallelAction, ColoTensor, ColoOptimizer from colossalai.context import ParallelMode from colossalai.core import global_context as gpc @@ -26,6 +26,7 @@ def set_seed(seed): torch.cuda.manual_seed(seed) torch.backends.cudnn.deterministic = True + def run_1d_col_tp(): # A simple net with two stacked nn.Linear get_components_func = non_distributed_component_funcs.get_callable('simple_net') @@ -98,6 +99,7 @@ def run_1d_col_tp(): if i > 5: break + # Test the overrided parameters() and named_parameters() member functions def test_model_parameters(): # build a module with 2 Linear, 4 parameters in total. @@ -127,6 +129,34 @@ def test_model_parameters(): assert param_cnt == 2 +def test_colo_optimizer(): + get_components_func = non_distributed_component_funcs.get_callable('simple_net') + model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func() + set_seed(1) + with ColoInitContext(lazy_memory_allocate=False, device=get_current_device()): + model = model_builder(checkpoint=True) + + colo_optimizer = ColoOptimizer(dict(model.named_parameters()), torch.optim.SGD, lr=0.1) + for i, (data, label) in enumerate(train_dataloader): + colo_optimizer.zero_grad() + data = data.to(get_current_device()) + label = label.to(get_current_device()) + + # Bcast rank0 data to all processes + if criterion: + output = model(data) + loss = criterion(output, label) + else: + output = model(data, label) + loss = output + + loss.backward() + colo_optimizer.step() + + if i > 5: + break + + def run_1d_row_tp(): # A simple net with two stacked nn.Linear get_components_func = non_distributed_component_funcs.get_callable('simple_net') @@ -209,4 +239,5 @@ def test_simple_net(world_size): if __name__ == '__main__': # test_simple_net() - test_model_parameters() + # test_model_parameters() + test_colo_optimizer()