mirror of https://github.com/hpcaitech/ColossalAI
[zero] add warning for ignored parameters (#2446)
parent
39163417a1
commit
2bfeb24308
|
@ -10,13 +10,18 @@ from colossalai.gemini.chunk.search_utils import search_chunk_configuration
|
|||
from colossalai.utils import is_ddp_ignored
|
||||
|
||||
|
||||
def safe_div(a, b):
|
||||
if a == 0:
|
||||
return 0
|
||||
return a / b
|
||||
|
||||
|
||||
def init_chunk_manager(model: nn.Module,
|
||||
init_device: Optional[torch.device] = None,
|
||||
hidden_dim: Optional[int] = None,
|
||||
search_range_mb: Optional[float] = None,
|
||||
min_chunk_size_mb: Optional[float] = None,
|
||||
filter_exlarge_params: Optional[bool] = None) -> ChunkManager:
|
||||
|
||||
kwargs_dict = dict()
|
||||
|
||||
if hidden_dim:
|
||||
|
@ -50,7 +55,7 @@ def init_chunk_manager(model: nn.Module,
|
|||
if dist.get_rank() == 0:
|
||||
print("searching chunk configuration is completed in {:.2f} s.\n".format(span_s),
|
||||
"used number: {:.2f} MB, wasted number: {:.2f} MB\n".format(total_size, wasted_size),
|
||||
"total wasted percentage is {:.2f}%".format(100 * wasted_size / (total_size + wasted_size)),
|
||||
"total wasted percentage is {:.2f}%".format(100 * safe_div(wasted_size, total_size + wasted_size)),
|
||||
sep='',
|
||||
flush=True)
|
||||
dist.barrier()
|
||||
|
|
|
@ -1,4 +1,5 @@
|
|||
import math
|
||||
import warnings
|
||||
from enum import Enum
|
||||
from typing import Any, Dict, Set, Tuple
|
||||
|
||||
|
@ -78,8 +79,16 @@ class ZeroOptimizer(ColossalaiOptimizer):
|
|||
if self.clipping_flag:
|
||||
assert norm_type == 2.0, "ZeroOptimizer only supports L2 norm now"
|
||||
|
||||
params_list = [p for p in module.parameters() if not is_ddp_ignored(p)]
|
||||
for p, fp32_p in zip(params_list, module.fp32_params):
|
||||
ddp_param_list = []
|
||||
for name, param in module.named_parameters():
|
||||
if is_ddp_ignored(param):
|
||||
if param.requires_grad:
|
||||
warnings.warn(f"Parameter `{name}` is ignored by DDP but requires gradient! "
|
||||
"You should handle its optimizer update by yourself!")
|
||||
else:
|
||||
ddp_param_list.append(param)
|
||||
|
||||
for p, fp32_p in zip(ddp_param_list, module.fp32_params):
|
||||
chunk_16 = self.chunk_manager.get_chunk(p)
|
||||
if chunk_16 not in self.chunk16_set:
|
||||
chunk_16.l2_norm_flag = self.clipping_flag
|
||||
|
@ -290,6 +299,8 @@ class ZeroOptimizer(ColossalaiOptimizer):
|
|||
fake_params_list = list()
|
||||
|
||||
for param in group['params']:
|
||||
if is_ddp_ignored(param):
|
||||
continue
|
||||
chunk16 = self.chunk_manager.get_chunk(param)
|
||||
range_pair = get_range_pair(chunk16, param)
|
||||
if range_pair[0] >= range_pair[1]:
|
||||
|
|
Loading…
Reference in New Issue