Browse Source

[zero] add warning for ignored parameters (#2446)

pull/2451/head
HELSON 2 years ago committed by GitHub
parent
commit
2bfeb24308
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
  1. 9
      colossalai/gemini/chunk/utils.py
  2. 15
      colossalai/nn/optimizer/zero_optimizer.py

9
colossalai/gemini/chunk/utils.py

@ -10,13 +10,18 @@ from colossalai.gemini.chunk.search_utils import search_chunk_configuration
from colossalai.utils import is_ddp_ignored from colossalai.utils import is_ddp_ignored
def safe_div(a, b):
    """Divide ``a`` by ``b``, returning 0 when the numerator is zero.

    The numerator is checked before dividing, so ``safe_div(0, 0)`` yields 0
    instead of raising.

    Args:
        a: Numerator.
        b: Denominator.

    Returns:
        ``0`` if ``a == 0``, otherwise the true-division result ``a / b``.

    Raises:
        ZeroDivisionError: For plain numbers, when ``a`` is non-zero and
            ``b`` is zero; that case is deliberately not masked.
    """
    # Short-circuit on a zero numerator so that 0 / 0 is reported as 0.
    if a == 0:
        return 0
    return a / b
def init_chunk_manager(model: nn.Module, def init_chunk_manager(model: nn.Module,
init_device: Optional[torch.device] = None, init_device: Optional[torch.device] = None,
hidden_dim: Optional[int] = None, hidden_dim: Optional[int] = None,
search_range_mb: Optional[float] = None, search_range_mb: Optional[float] = None,
min_chunk_size_mb: Optional[float] = None, min_chunk_size_mb: Optional[float] = None,
filter_exlarge_params: Optional[bool] = None) -> ChunkManager: filter_exlarge_params: Optional[bool] = None) -> ChunkManager:
kwargs_dict = dict() kwargs_dict = dict()
if hidden_dim: if hidden_dim:
@ -50,7 +55,7 @@ def init_chunk_manager(model: nn.Module,
if dist.get_rank() == 0: if dist.get_rank() == 0:
print("searching chunk configuration is completed in {:.2f} s.\n".format(span_s), print("searching chunk configuration is completed in {:.2f} s.\n".format(span_s),
"used number: {:.2f} MB, wasted number: {:.2f} MB\n".format(total_size, wasted_size), "used number: {:.2f} MB, wasted number: {:.2f} MB\n".format(total_size, wasted_size),
"total wasted percentage is {:.2f}%".format(100 * wasted_size / (total_size + wasted_size)), "total wasted percentage is {:.2f}%".format(100 * safe_div(wasted_size, total_size + wasted_size)),
sep='', sep='',
flush=True) flush=True)
dist.barrier() dist.barrier()

15
colossalai/nn/optimizer/zero_optimizer.py

@ -1,4 +1,5 @@
import math import math
import warnings
from enum import Enum from enum import Enum
from typing import Any, Dict, Set, Tuple from typing import Any, Dict, Set, Tuple
@ -78,8 +79,16 @@ class ZeroOptimizer(ColossalaiOptimizer):
if self.clipping_flag: if self.clipping_flag:
assert norm_type == 2.0, "ZeroOptimizer only supports L2 norm now" assert norm_type == 2.0, "ZeroOptimizer only supports L2 norm now"
params_list = [p for p in module.parameters() if not is_ddp_ignored(p)] ddp_param_list = []
for p, fp32_p in zip(params_list, module.fp32_params): for name, param in module.named_parameters():
if is_ddp_ignored(param):
if param.requires_grad:
warnings.warn(f"Parameter `{name}` is ignored by DDP but requires gradient! "
"You should handle its optimizer update by yourself!")
else:
ddp_param_list.append(param)
for p, fp32_p in zip(ddp_param_list, module.fp32_params):
chunk_16 = self.chunk_manager.get_chunk(p) chunk_16 = self.chunk_manager.get_chunk(p)
if chunk_16 not in self.chunk16_set: if chunk_16 not in self.chunk16_set:
chunk_16.l2_norm_flag = self.clipping_flag chunk_16.l2_norm_flag = self.clipping_flag
@ -290,6 +299,8 @@ class ZeroOptimizer(ColossalaiOptimizer):
fake_params_list = list() fake_params_list = list()
for param in group['params']: for param in group['params']:
if is_ddp_ignored(param):
continue
chunk16 = self.chunk_manager.get_chunk(param) chunk16 = self.chunk_manager.get_chunk(param)
range_pair = get_range_pair(chunk16, param) range_pair = get_range_pair(chunk16, param)
if range_pair[0] >= range_pair[1]: if range_pair[0] >= range_pair[1]:

Loading…
Cancel
Save