|
|
@ -24,12 +24,7 @@ class ColoOptimizer(optim.Optimizer):
|
|
|
|
**optimizer_kwargs: the key-word arguments to initialize the optimizer.
|
|
|
|
**optimizer_kwargs: the key-word arguments to initialize the optimizer.
|
|
|
|
|
|
|
|
|
|
|
|
"""
|
|
|
|
"""
|
|
|
|
tensors: List[Tensor] = []
|
|
|
|
self._optim = optimizer_class([p for n, p in named_params], *optimizer_args, **optimizer_kwargs)
|
|
|
|
for value in named_params.values():
|
|
|
|
|
|
|
|
tensors.append(value)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
self.named_params = named_params
|
|
|
|
|
|
|
|
self._optim = optimizer_class(tensors, *optimizer_args, **optimizer_kwargs)
|
|
|
|
|
|
|
|
self.param_groups = self._optim.param_groups
|
|
|
|
self.param_groups = self._optim.param_groups
|
|
|
|
self.state = self._optim.state
|
|
|
|
self.state = self._optim.state
|
|
|
|
|
|
|
|
|
|
|
@ -68,8 +63,7 @@ class ColoOptimizer(optim.Optimizer):
|
|
|
|
Returned state and param_groups will contain parameter keys
|
|
|
|
Returned state and param_groups will contain parameter keys
|
|
|
|
instead of parameter indices like torch.optim.Optimizer.
|
|
|
|
instead of parameter indices like torch.optim.Optimizer.
|
|
|
|
"""
|
|
|
|
"""
|
|
|
|
# TODO: implement state_dict
|
|
|
|
return self._optim.state_dict()
|
|
|
|
raise NotImplementedError("ColoOptimizer state_dict not implemented yet!")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def load_state_dict(self, state_dict: Mapping[str, Any]):
|
|
|
|
def load_state_dict(self, state_dict: Mapping[str, Any]):
|
|
|
|
r"""Loads the ColoOptimizer state.
|
|
|
|
r"""Loads the ColoOptimizer state.
|
|
|
@ -78,11 +72,9 @@ class ColoOptimizer(optim.Optimizer):
|
|
|
|
state_dict (dict): ColoOptimizer state. Should be an object returned
|
|
|
|
state_dict (dict): ColoOptimizer state. Should be an object returned
|
|
|
|
from a call to :meth:`state_dict`.
|
|
|
|
from a call to :meth:`state_dict`.
|
|
|
|
"""
|
|
|
|
"""
|
|
|
|
# TODO: implement load_state_dict
|
|
|
|
self._optim.load_state_dict(state_dict)
|
|
|
|
raise NotImplementedError("ColoOptimizer load_state_dict not implemented yet!")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def add_param_group(self, param_group: Any):
|
|
|
|
def add_param_group(self, param_group: Any):
|
|
|
|
r"""Add a new param group
|
|
|
|
r"""Add a new param group
|
|
|
|
"""
|
|
|
|
"""
|
|
|
|
# TODO: implement add_param_group
|
|
|
|
self._optim.add_param_group(param_group)
|
|
|
|
raise NotImplementedError("ColoOptimizer add_param_group not implemented yet!")
|
|
|
|
|
|
|
|