Browse Source

[Optimizer] polish the init method of ColoOptimizer (#1310)

pull/1312/head
Jiarui Fang 2 years ago committed by GitHub
parent
commit
9f10524313
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
  1. 16
      colossalai/nn/optimizer/colo_optimizer.py
  2. 9
      colossalai/tensor/process_group.py
  3. 6
      tests/test_tensor/test_model.py
  4. 2
      tests/test_utils/test_colo_checkpoint.py

16
colossalai/nn/optimizer/colo_optimizer.py

@ -24,12 +24,7 @@ class ColoOptimizer(optim.Optimizer):
**optimizer_kwargs: the key-word arguments to initialize the optimizer.
"""
tensors: List[Tensor] = []
for value in named_params.values():
tensors.append(value)
self.named_params = named_params
self._optim = optimizer_class(tensors, *optimizer_args, **optimizer_kwargs)
self._optim = optimizer_class([p for n, p in named_params], *optimizer_args, **optimizer_kwargs)
self.param_groups = self._optim.param_groups
self.state = self._optim.state
@ -68,8 +63,7 @@ class ColoOptimizer(optim.Optimizer):
Returned state and param_groups will contain parameter keys
instead of parameter indices like torch.optim.Optimizer.
"""
# TODO: implement state_dict
raise NotImplementedError("ColoOptimizer state_dict not implemented yet!")
return self._optim.state_dict()
def load_state_dict(self, state_dict: Mapping[str, Any]):
r"""Loads the ColoOptimizer state.
@ -78,11 +72,9 @@ class ColoOptimizer(optim.Optimizer):
state_dict (dict): ColoOptimizer state. Should be an object returned
from a call to :meth:`state_dict`.
"""
# TODO: implement load_state_dict
raise NotImplementedError("ColoOptimizer load_state_dict not implemented yet!")
self._optim.load_state_dict(state_dict)
def add_param_group(self, param_group: Any):
r"""Add a new param group
"""
# TODO: implement add_param_group
raise NotImplementedError("ColoOptimizer add_param_group not implemented yet!")
self._optim.add_param_group(param_group)

9
colossalai/tensor/process_group.py

@ -48,6 +48,7 @@ class ProcessGroup:
tp_degree: Optional[int] = None,
dp_degree: Optional[int] = None) -> None:
if not torch.distributed.is_initialized():
self.is_init = False
return
assert torch.distributed.is_initialized(), f"ProcessGroup must be used after distributed initialized"
@ -96,6 +97,7 @@ class ProcessGroup:
self._has_cpu_groups = False
PYTORCHPGDICT_.get(self._tp_rank_list, 'nccl')
PYTORCHPGDICT_.get(self._dp_rank_list, 'nccl')
self.is_init = True
def set_cpu_groups(self):
if self.has_cpu_groups:
@ -110,8 +112,11 @@ class ProcessGroup:
return self._has_cpu_groups
def __repr__(self):
return "ProcessGroup:\n\tRank: {}, World size: {}, DP degree: {}, TP degree: {}\n\tRanks in group: {}".\
format(self._rank, self._world_size, self._dp_degree, self._tp_degree, self._rank_list)
if self.is_init:
return "ProcessGroup:\n\tRank: {}, World size: {}, DP degree: {}, TP degree: {}\n\tRanks in group: {}".\
format(self._rank, self._world_size, self._dp_degree, self._tp_degree, self._rank_list)
else:
return "ProcessGroup not initialized"
def __eq__(self, obj: 'ProcessGroup') -> bool:
if not isinstance(obj, ProcessGroup):

6
tests/test_tensor/test_model.py

@ -33,7 +33,7 @@ def run_1d_hybrid_tp(model_name):
if rank == 0:
model_torch = model_builder(checkpoint=True)
model_torch = model_torch.cuda()
optimizer_torch = ColoOptimizer(dict(model_torch.named_parameters()), torch.optim.SGD, lr=0.1)
optimizer_torch = ColoOptimizer(model_torch.named_parameters(), torch.optim.SGD, lr=0.1)
# Make two models have the same init params
for p1, p2 in zip(model.parameters(), model_torch.parameters()):
@ -80,7 +80,7 @@ def run_1d_hybrid_tp(model_name):
if rank == 0:
model_torch.train()
colo_optimizer = ColoOptimizer(dict(model.named_parameters()), torch.optim.SGD, lr=0.1)
colo_optimizer = ColoOptimizer(model.named_parameters(), torch.optim.SGD, lr=0.1)
for i, (data, label) in enumerate(train_dataloader):
@ -170,7 +170,7 @@ def test_colo_optimizer():
with ColoInitContext(lazy_memory_allocate=False, device=get_current_device()):
model = model_builder(checkpoint=True)
colo_optimizer = ColoOptimizer(dict(model.named_parameters()), torch.optim.SGD, lr=0.1)
colo_optimizer = ColoOptimizer(model.named_parameters(), torch.optim.SGD, lr=0.1)
for i, (data, label) in enumerate(train_dataloader):
colo_optimizer.zero_grad()
data = data.to(get_current_device())

2
tests/test_utils/test_colo_checkpoint.py

@ -117,7 +117,7 @@ def _run_checkpoint(model_name, init_spec_func, use_ddp, use_mp_reload, test_sch
model_reload = model_reload.cuda()
model_reload.train()
colo_optimizer = ColoOptimizer(dict(model.named_parameters()), torch.optim.SGD, lr=0.1)
colo_optimizer = ColoOptimizer(model.named_parameters(), torch.optim.SGD, lr=0.1)
for i, (data, label) in enumerate(train_dataloader):

Loading…
Cancel
Save