@@ -1,7 +1,7 @@
 # this code is inspired by the DeepSpeed library and implemented with our own design from scratch
 import copy
 import math
-from typing import Any, Dict, Iterator, OrderedDict, Set, Tuple, Union
+from typing import Any, Dict, Iterator, Optional, OrderedDict, Set, Tuple, Union
 
 import torch
 import torch.distributed as dist
@@ -195,6 +195,7 @@ class GeminiOptimizer(OptimizerWrapper):
             self._logger.warning(f'gpu_margin_mem_ratio is meaningless when placement_policy is not "auto"', ranks=[0])
 
         self._register_states = disposable(self._register_states_)
+        self._current_grad_norm: Optional[float] = None
 
     def _set_grad_ptr(self):
         for group in self.param_groups:
@@ -255,6 +256,7 @@ class GeminiOptimizer(OptimizerWrapper):
 
         if self.clipping_flag:
             total_norm = self._calc_global_norm()
+            self._current_grad_norm = total_norm
             clip = ((total_norm / div_scale) + 1e-6) / self.max_norm
             if clip > 1:
                 div_scale = clip * div_scale
@@ -846,6 +848,9 @@ class GeminiOptimizer(OptimizerWrapper):
             f"Gemini controls grad clipping by itself, so you should not use clip_grad_by_norm", ranks=[0]
         )
 
+    def get_grad_norm(self, norm_type=2, **kwargs):
+        return self._current_grad_norm
+
 
 class GeminiAdamOptimizer(GeminiOptimizer):
     def __init__(self, model: torch.nn.Module, **defaults: Any) -> None:
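
The change above follows a simple pattern: the global gradient norm that GeminiOptimizer already computes during its clipping pass is cached in `_current_grad_norm`, and the new `get_grad_norm()` accessor returns that cached value (or `None` if no clipping pass has run yet). The standalone sketch below illustrates the same caching pattern in plain PyTorch; `ToyClippingOptimizer` and its SGD-style update are hypothetical names invented for this illustration and are not part of ColossalAI or of this diff.

```python
from typing import Optional

import torch


class ToyClippingOptimizer:
    """Illustrative stand-in (not ColossalAI code): caches the gradient norm
    computed during clipping and exposes it via get_grad_norm(), mirroring
    the pattern this diff adds to GeminiOptimizer."""

    def __init__(self, params, max_norm: float = 1.0, lr: float = 1e-2):
        self.params = list(params)
        self.max_norm = max_norm
        self.lr = lr
        # Counterpart of `self._current_grad_norm: Optional[float] = None` above.
        self._current_grad_norm: Optional[float] = None

    def _clip_grads(self) -> None:
        grads = [p.grad for p in self.params if p.grad is not None]
        if not grads:
            return
        # Global L2 norm over all parameter gradients.
        total_norm = torch.linalg.vector_norm(
            torch.stack([torch.linalg.vector_norm(g) for g in grads])
        ).item()
        # Cache the norm so it can be queried after step(), as the diff does.
        self._current_grad_norm = total_norm
        clip = (total_norm + 1e-6) / self.max_norm
        if clip > 1:
            for g in grads:
                g.div_(clip)

    @torch.no_grad()
    def step(self) -> None:
        self._clip_grads()
        for p in self.params:
            if p.grad is not None:
                p.add_(p.grad, alpha=-self.lr)  # plain SGD update, for the sketch only

    def get_grad_norm(self, norm_type=2, **kwargs) -> Optional[float]:
        # Same accessor shape as the new GeminiOptimizer method: norm_type is
        # accepted but unused, and None is returned until a clipping pass has
        # computed and cached a norm.
        return self._current_grad_norm


if __name__ == "__main__":
    w = torch.nn.Parameter(torch.randn(4))
    opt = ToyClippingOptimizer([w], max_norm=1.0)
    print(opt.get_grad_norm())  # None: no clipping pass has run yet
    (w.sum() * 10).backward()
    opt.step()
    print(opt.get_grad_norm())  # cached global grad norm from the clipping pass
```

As in the diff, the cached value reflects the most recent clipping pass, so `get_grad_norm()` returns `None` before the first step and, in GeminiOptimizer, whenever clipping is disabled (`clipping_flag` unset).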