[Tensor] initialize the ColoOptimizer (#898)

* [Tensor] activation is an attr of ColoTensor

* [Tensor] add optimizer

* only detach parameters in context

* polish code
pull/900/head
Jiarui Fang 2022-04-28 15:23:40 +08:00 committed by GitHub
parent 676f191532
commit d16671da75
5 changed files with 130 additions and 4 deletions

View File

@@ -4,8 +4,9 @@ from .op_wrapper import (
from .colo_tensor import ColoTensor
from .utils import convert_parameter, named_params_with_colotensor
from ._ops import *
from .optim.colo_optimizer import ColoOptimizer

__all__ = [
    'ColoTensor', 'convert_parameter', 'colo_op_impl', 'ComputePattern', 'TensorSpec', 'ParallelAction',
-    'named_params_with_colotensor', 'ShardPattern'
+    'named_params_with_colotensor', 'ShardPattern', 'ColoOptimizer'
]

View File

View File

@@ -0,0 +1,88 @@
from typing import List, Union, Mapping, Dict, Any

import torch.optim as optim
from torch import Tensor

from colossalai.tensor.colo_tensor import ColoTensor


class ColoOptimizer(optim.Optimizer):

    def __init__(self, named_params: Mapping[str, Union[Tensor, ColoTensor]], optimizer_class, *optimizer_args,
                 **optimizer_kwargs):
        """
        ColoOptimizer collects all tensors of type ColoTensor or torch.Tensor and
        passes them as ``params`` to the wrapped optimizer.

        Args:
            named_params (Mapping[str, Union[Tensor, ColoTensor]]): a mapping
                of parameters, where the key is the parameter name and the value
                is either a Tensor or a ColoTensor. This is usually used in
                conjunction with ``model.named_parameters()``, as in plain PyTorch.
            optimizer_class (torch.optim.Optimizer): the optimizer class to instantiate
                locally, e.g. torch.optim.SGD, torch.optim.Adagrad, etc.
            *optimizer_args: the positional arguments used to initialize the optimizer.
            **optimizer_kwargs: the keyword arguments used to initialize the optimizer.
        """
        tensors: List[Tensor] = []
        for value in named_params.values():
            tensors.append(value)

        self.named_params = named_params
        self._optim = optimizer_class(tensors, *optimizer_args, **optimizer_kwargs)
        self.param_groups = self._optim.param_groups
        self.state = self._optim.state

    def zero_grad(self, set_to_none: bool = False):    # type: ignore[override]
        r"""Sets the gradients of all optimized :class:`torch.Tensor` s to zero.

        Args:
            set_to_none (bool): instead of setting the grads to zero, set them to None.
                This generally has a lower memory footprint and can modestly improve
                performance. However, it changes certain behaviors. For example:
                1. When the user tries to access a gradient and perform manual ops on it,
                a None attribute and a Tensor full of 0s behave differently.
                2. If the user requests ``zero_grad(set_to_none=True)`` followed by a
                backward pass, ``.grad``\ s are guaranteed to be None for params that
                did not receive a gradient.
                3. ``torch.optim`` optimizers behave differently depending on whether the
                gradient is 0 or None (in one case the step is performed with a gradient
                of 0, and in the other the step is skipped altogether).
        """
        self._optim.zero_grad(set_to_none)

    def step(self, closure=None):
        r"""Performs a single optimization step (parameter update).

        Args:
            closure (callable): A closure that reevaluates the model and
                returns the loss. Optional for most optimizers.

        .. note::
            Unless otherwise specified, this function should not modify the
            ``.grad`` field of the parameters.
        """
        self._optim.step(closure)

    def state_dict(self) -> Dict[str, Any]:
        """
        The returned state and param_groups will contain parameter keys
        instead of parameter indices as in torch.optim.Optimizer.
        """
        # TODO: implement state_dict
        raise NotImplementedError("ColoOptimizer state_dict not implemented yet!")

    def load_state_dict(self, state_dict: Mapping[str, Any]):
        r"""Loads the ColoOptimizer state.

        Args:
            state_dict (dict): ColoOptimizer state. Should be an object returned
                from a call to :meth:`state_dict`.
        """
        # TODO: implement load_state_dict
        raise NotImplementedError("ColoOptimizer load_state_dict not implemented yet!")

    def add_param_group(self, param_group: Any):
        r"""Adds a new param group to the wrapped optimizer.
        """
        # TODO: implement add_param_group
        raise NotImplementedError("ColoOptimizer add_param_group not implemented yet!")

View File

@@ -94,9 +94,15 @@ class ColoInitContext(InsertPostInitMethodToModuleSubClasses):
        save_torch_payload = True if not self._lazy_memory_allocate else False
        for name, param in name_list:
            delattr(module, name)

            # Detaching the tensor is necessary for optimizers: they require leaf tensors.
            requires_grad = param.requires_grad
            tensor_detached = param.to(self._device).detach()
            tensor_detached.requires_grad = requires_grad

            setattr(
                module, name,
-               ColoTensor.init_from_torch_tensor(tensor=param.to(self._device),
+               ColoTensor.init_from_torch_tensor(tensor=tensor_detached,
                                                  save_payload=save_torch_payload,
                                                  is_model_data=True))
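
The detach above is needed because param.to(self._device) returns a non-leaf tensor whenever it actually copies data, and torch optimizers refuse non-leaf tensors that require grad ("can't optimize a non-leaf Tensor"). A minimal sketch of the pattern in isolation, using plain torch only:

import torch

param = torch.nn.Parameter(torch.randn(2, 2))
moved = param.to('cpu')                               # non-leaf if a copy was made
tensor_detached = moved.detach()                      # detach() always yields a leaf tensor
tensor_detached.requires_grad = param.requires_grad   # detach() clears the flag, so restore it
assert tensor_detached.is_leaf and tensor_detached.requires_grad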

View File

@@ -8,7 +8,7 @@ from colossalai.testing import parameterize, rerun_if_address_is_in_use
from colossalai.utils.cuda import get_current_device
from colossalai.utils import free_port
from colossalai.utils import ColoInitContext
-from colossalai.tensor import named_params_with_colotensor, TensorSpec, ComputePattern, ParallelAction, ColoTensor
+from colossalai.tensor import named_params_with_colotensor, TensorSpec, ComputePattern, ParallelAction, ColoTensor, ColoOptimizer
from colossalai.context import ParallelMode
from colossalai.core import global_context as gpc
@@ -26,6 +26,7 @@ def set_seed(seed):
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True


def run_1d_col_tp():
    # A simple net with two stacked nn.Linear
    get_components_func = non_distributed_component_funcs.get_callable('simple_net')
@@ -98,6 +99,7 @@ def run_1d_col_tp():
        if i > 5:
            break


# Test the overridden parameters() and named_parameters() member functions
def test_model_parameters():
    # build a module with 2 Linear layers, 4 parameters in total.
@@ -127,6 +129,34 @@ def test_model_parameters():
    assert param_cnt == 2


def test_colo_optimizer():
    get_components_func = non_distributed_component_funcs.get_callable('simple_net')
    model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func()
    set_seed(1)
    with ColoInitContext(lazy_memory_allocate=False, device=get_current_device()):
        model = model_builder(checkpoint=True)

    colo_optimizer = ColoOptimizer(dict(model.named_parameters()), torch.optim.SGD, lr=0.1)
    for i, (data, label) in enumerate(train_dataloader):
        colo_optimizer.zero_grad()
        data = data.to(get_current_device())
        label = label.to(get_current_device())

        # Broadcast rank-0 data to all processes
        if criterion:
            output = model(data)
            loss = criterion(output, label)
        else:
            output = model(data, label)
            loss = output

        loss.backward()
        colo_optimizer.step()

        if i > 5:
            break
def run_1d_row_tp():
    # A simple net with two stacked nn.Linear
    get_components_func = non_distributed_component_funcs.get_callable('simple_net')

@@ -209,4 +239,5 @@ def test_simple_net(world_size):
if __name__ == '__main__':
    # test_simple_net()
-    test_model_parameters()
+    # test_model_parameters()
+    test_colo_optimizer()