mirror of https://github.com/hpcaitech/ColossalAI
[Tensor] initialize the ColoOptimizer (#898)
* [Tensor] activation is an attr of ColoTensor * [Tensor] add optimizer * only detach parameters in context * polish codepull/900/head
parent
676f191532
commit
d16671da75
|
@ -4,8 +4,9 @@ from .op_wrapper import (
|
|||
from .colo_tensor import ColoTensor
|
||||
from .utils import convert_parameter, named_params_with_colotensor
|
||||
from ._ops import *
|
||||
from .optim.colo_optimizer import ColoOptimizer
|
||||
|
||||
__all__ = [
|
||||
'ColoTensor', 'convert_parameter', 'colo_op_impl', 'ComputePattern', 'TensorSpec', 'ParallelAction',
|
||||
'named_params_with_colotensor', 'ShardPattern'
|
||||
'named_params_with_colotensor', 'ShardPattern', 'ColoOptimizer'
|
||||
]
|
||||
|
|
|
@ -0,0 +1,88 @@
|
|||
from typing import List, Union, Mapping, Dict, Any
|
||||
|
||||
import torch.optim as optim
|
||||
from torch import Tensor
|
||||
from colossalai.tensor.colo_tensor import ColoTensor
|
||||
|
||||
|
||||
class ColoOptimizer(optim.Optimizer):
|
||||
|
||||
def __init__(self, named_params: Mapping[str, Union[Tensor, ColoTensor]], optimizer_class, *optimizer_args,
|
||||
**optimizer_kwargs):
|
||||
"""
|
||||
ColoOptimizer collects all tensors in type of ColoTensor and torch.Tensor,
|
||||
then use these tensors as ``params`` for optimizers
|
||||
|
||||
Args:
|
||||
named_params (Dict[str, Union[Tensor, ShardedTensor]]) : a Dict
|
||||
of parameters, where key is the parameter key, value is either
|
||||
Tensor or ColoTensor. This usually used in
|
||||
conjunction with model.named_parameters(), the same as PyTorch.
|
||||
optimizer_class (torch.optim.Optimizer): the Optimizer to use
|
||||
locally, i.e. torch.optim.SGD, torch.optim.Adagrad, etc.
|
||||
*optimizer_args: the arguments to initialize the optimizer.
|
||||
**optimizer_kwargs: the key-word arguments to initialize the optimizer.
|
||||
|
||||
"""
|
||||
tensors: List[Tensor] = []
|
||||
for value in named_params.values():
|
||||
tensors.append(value)
|
||||
|
||||
self.named_params = named_params
|
||||
self._optim = optimizer_class(tensors, *optimizer_args, **optimizer_kwargs)
|
||||
self.param_groups = self._optim.param_groups
|
||||
self.state = self._optim.state
|
||||
|
||||
def zero_grad(self, set_to_none: bool = False): # type: ignore[override]
|
||||
r"""Sets the gradients of all optimized :class:`torch.Tensor` s to zero.
|
||||
|
||||
Args:
|
||||
set_to_none (bool): instead of setting to zero, set the grads to None.
|
||||
This will in general have lower memory footprint, and can modestly improve performance.
|
||||
However, it changes certain behaviors. For example:
|
||||
1. When the user tries to access a gradient and perform manual ops on it,
|
||||
a None attribute or a Tensor full of 0s will behave differently.
|
||||
2. If the user requests ``zero_grad(set_to_none=True)`` followed by a backward pass, ``.grad``\ s
|
||||
are guaranteed to be None for params that did not receive a gradient.
|
||||
3. ``torch.optim`` optimizers have a different behavior if the gradient is 0 or None
|
||||
(in one case it does the step with a gradient of 0 and in the other it skips
|
||||
the step altogether).
|
||||
"""
|
||||
self._optim.zero_grad(set_to_none)
|
||||
|
||||
def step(self, closure=None):
|
||||
r"""Performs a single optimization step (parameter update).
|
||||
|
||||
Args:
|
||||
closure (callable): A closure that reevaluates the model and
|
||||
returns the loss. Optional for most optimizers.
|
||||
|
||||
.. note::
|
||||
Unless otherwise specified, this function should not modify the
|
||||
``.grad`` field of the parameters.
|
||||
"""
|
||||
self._optim.step(closure)
|
||||
|
||||
def state_dict(self) -> Dict[str, Any]:
|
||||
"""
|
||||
Returned state and param_groups will contain parameter keys
|
||||
instead of parameter indices like torch.optim.Optimizer.
|
||||
"""
|
||||
# TODO: implement state_dict
|
||||
raise NotImplementedError("ColoOptimizer state_dict not implemented yet!")
|
||||
|
||||
def load_state_dict(self, state_dict: Mapping[str, Any]):
|
||||
r"""Loads the ColoOptimizer state.
|
||||
|
||||
Args:
|
||||
state_dict (dict): ColoOptimizer state. Should be an object returned
|
||||
from a call to :meth:`state_dict`.
|
||||
"""
|
||||
# TODO: implement load_state_dict
|
||||
raise NotImplementedError("ColoOptimizer load_state_dict not implemented yet!")
|
||||
|
||||
def add_param_group(self, param_group: Any):
|
||||
r"""Add a new param group
|
||||
"""
|
||||
# TODO: implement add_param_group
|
||||
raise NotImplementedError("ColoOptimizer add_param_group not implemented yet!")
|
|
@ -94,9 +94,15 @@ class ColoInitContext(InsertPostInitMethodToModuleSubClasses):
|
|||
save_torch_payload = True if not self._lazy_memory_allocate else False
|
||||
for name, param in name_list:
|
||||
delattr(module, name)
|
||||
|
||||
# detaching tensor is necessary for optimizers.
|
||||
requires_grad = param.requires_grad
|
||||
tensor_detached = param.to(self._device).detach()
|
||||
tensor_detached.requires_grad = requires_grad
|
||||
|
||||
setattr(
|
||||
module, name,
|
||||
ColoTensor.init_from_torch_tensor(tensor=param.to(self._device),
|
||||
ColoTensor.init_from_torch_tensor(tensor=tensor_detached,
|
||||
save_payload=save_torch_payload,
|
||||
is_model_data=True))
|
||||
|
||||
|
|
|
@ -8,7 +8,7 @@ from colossalai.testing import parameterize, rerun_if_address_is_in_use
|
|||
from colossalai.utils.cuda import get_current_device
|
||||
from colossalai.utils import free_port
|
||||
from colossalai.utils import ColoInitContext
|
||||
from colossalai.tensor import named_params_with_colotensor, TensorSpec, ComputePattern, ParallelAction, ColoTensor
|
||||
from colossalai.tensor import named_params_with_colotensor, TensorSpec, ComputePattern, ParallelAction, ColoTensor, ColoOptimizer
|
||||
from colossalai.context import ParallelMode
|
||||
from colossalai.core import global_context as gpc
|
||||
|
||||
|
@ -26,6 +26,7 @@ def set_seed(seed):
|
|||
torch.cuda.manual_seed(seed)
|
||||
torch.backends.cudnn.deterministic = True
|
||||
|
||||
|
||||
def run_1d_col_tp():
|
||||
# A simple net with two stacked nn.Linear
|
||||
get_components_func = non_distributed_component_funcs.get_callable('simple_net')
|
||||
|
@ -98,6 +99,7 @@ def run_1d_col_tp():
|
|||
if i > 5:
|
||||
break
|
||||
|
||||
|
||||
# Test the overrided parameters() and named_parameters() member functions
|
||||
def test_model_parameters():
|
||||
# build a module with 2 Linear, 4 parameters in total.
|
||||
|
@ -127,6 +129,34 @@ def test_model_parameters():
|
|||
assert param_cnt == 2
|
||||
|
||||
|
||||
def test_colo_optimizer():
|
||||
get_components_func = non_distributed_component_funcs.get_callable('simple_net')
|
||||
model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func()
|
||||
set_seed(1)
|
||||
with ColoInitContext(lazy_memory_allocate=False, device=get_current_device()):
|
||||
model = model_builder(checkpoint=True)
|
||||
|
||||
colo_optimizer = ColoOptimizer(dict(model.named_parameters()), torch.optim.SGD, lr=0.1)
|
||||
for i, (data, label) in enumerate(train_dataloader):
|
||||
colo_optimizer.zero_grad()
|
||||
data = data.to(get_current_device())
|
||||
label = label.to(get_current_device())
|
||||
|
||||
# Bcast rank0 data to all processes
|
||||
if criterion:
|
||||
output = model(data)
|
||||
loss = criterion(output, label)
|
||||
else:
|
||||
output = model(data, label)
|
||||
loss = output
|
||||
|
||||
loss.backward()
|
||||
colo_optimizer.step()
|
||||
|
||||
if i > 5:
|
||||
break
|
||||
|
||||
|
||||
def run_1d_row_tp():
|
||||
# A simple net with two stacked nn.Linear
|
||||
get_components_func = non_distributed_component_funcs.get_callable('simple_net')
|
||||
|
@ -209,4 +239,5 @@ def test_simple_net(world_size):
|
|||
|
||||
if __name__ == '__main__':
|
||||
# test_simple_net()
|
||||
test_model_parameters()
|
||||
# test_model_parameters()
|
||||
test_colo_optimizer()
|
||||
|
|
Loading…
Reference in New Issue