|
|
@ -1,4 +1,5 @@ |
|
|
|
import math |
|
|
|
import math |
|
|
|
|
|
|
|
import warnings |
|
|
|
from enum import Enum |
|
|
|
from enum import Enum |
|
|
|
from typing import Any, Dict, Set, Tuple |
|
|
|
from typing import Any, Dict, Set, Tuple |
|
|
|
|
|
|
|
|
|
|
@ -78,8 +79,16 @@ class ZeroOptimizer(ColossalaiOptimizer): |
|
|
|
if self.clipping_flag: |
|
|
|
if self.clipping_flag: |
|
|
|
assert norm_type == 2.0, "ZeroOptimizer only supports L2 norm now" |
|
|
|
assert norm_type == 2.0, "ZeroOptimizer only supports L2 norm now" |
|
|
|
|
|
|
|
|
|
|
|
params_list = [p for p in module.parameters() if not is_ddp_ignored(p)] |
|
|
|
ddp_param_list = [] |
|
|
|
for p, fp32_p in zip(params_list, module.fp32_params): |
|
|
|
for name, param in module.named_parameters(): |
|
|
|
|
|
|
|
if is_ddp_ignored(param): |
|
|
|
|
|
|
|
if param.requires_grad: |
|
|
|
|
|
|
|
warnings.warn(f"Parameter `{name}` is ignored by DDP but requires gradient! " |
|
|
|
|
|
|
|
"You should handle its optimizer update by yourself!") |
|
|
|
|
|
|
|
else: |
|
|
|
|
|
|
|
ddp_param_list.append(param) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
for p, fp32_p in zip(ddp_param_list, module.fp32_params): |
|
|
|
chunk_16 = self.chunk_manager.get_chunk(p) |
|
|
|
chunk_16 = self.chunk_manager.get_chunk(p) |
|
|
|
if chunk_16 not in self.chunk16_set: |
|
|
|
if chunk_16 not in self.chunk16_set: |
|
|
|
chunk_16.l2_norm_flag = self.clipping_flag |
|
|
|
chunk_16.l2_norm_flag = self.clipping_flag |
|
|
@ -290,6 +299,8 @@ class ZeroOptimizer(ColossalaiOptimizer): |
|
|
|
fake_params_list = list() |
|
|
|
fake_params_list = list() |
|
|
|
|
|
|
|
|
|
|
|
for param in group['params']: |
|
|
|
for param in group['params']: |
|
|
|
|
|
|
|
if is_ddp_ignored(param): |
|
|
|
|
|
|
|
continue |
|
|
|
chunk16 = self.chunk_manager.get_chunk(param) |
|
|
|
chunk16 = self.chunk_manager.get_chunk(param) |
|
|
|
range_pair = get_range_pair(chunk16, param) |
|
|
|
range_pair = get_range_pair(chunk16, param) |
|
|
|
if range_pair[0] >= range_pair[1]: |
|
|
|
if range_pair[0] >= range_pair[1]: |
|
|
|