|
|
|
@ -24,12 +24,7 @@ class ColoOptimizer(optim.Optimizer):
|
|
|
|
|
**optimizer_kwargs: the key-word arguments to initialize the optimizer. |
|
|
|
|
|
|
|
|
|
""" |
|
|
|
|
tensors: List[Tensor] = [] |
|
|
|
|
for value in named_params.values(): |
|
|
|
|
tensors.append(value) |
|
|
|
|
|
|
|
|
|
self.named_params = named_params |
|
|
|
|
self._optim = optimizer_class(tensors, *optimizer_args, **optimizer_kwargs) |
|
|
|
|
self._optim = optimizer_class([p for n, p in named_params], *optimizer_args, **optimizer_kwargs) |
|
|
|
|
self.param_groups = self._optim.param_groups |
|
|
|
|
self.state = self._optim.state |
|
|
|
|
|
|
|
|
@ -68,8 +63,7 @@ class ColoOptimizer(optim.Optimizer):
|
|
|
|
|
Returned state and param_groups will contain parameter keys |
|
|
|
|
instead of parameter indices like torch.optim.Optimizer. |
|
|
|
|
""" |
|
|
|
|
# TODO: implement state_dict |
|
|
|
|
raise NotImplementedError("ColoOptimizer state_dict not implemented yet!") |
|
|
|
|
return self._optim.state_dict() |
|
|
|
|
|
|
|
|
|
def load_state_dict(self, state_dict: Mapping[str, Any]):
    r"""Loads the ColoOptimizer state.

    The state is handed unchanged to the wrapped optimizer stored in
    ``self._optim``; this method performs no conversion of its own.

    Args:
        state_dict (dict): ColoOptimizer state. Should be an object returned
            from a call to :meth:`state_dict`.
    """
    # Delegate to the wrapped optimizer. The previous unconditional
    # NotImplementedError made this call unreachable.
    self._optim.load_state_dict(state_dict)
|
|
|
|
|
|
|
|
|
def add_param_group(self, param_group: Any):
    r"""Add a new param group

    The group is forwarded unchanged to the wrapped optimizer stored in
    ``self._optim``, which is responsible for validating and registering it.

    Args:
        param_group: the param group to add; passed as-is to
            ``self._optim.add_param_group``.
    """
    # Delegate to the wrapped optimizer. The previous unconditional
    # NotImplementedError made this call unreachable.
    self._optim.add_param_group(param_group)
|
|
|
|