
[plugin] support get_grad_norm (#6115)

Author: Hongxin Liu (committed by GitHub)
Commit: a15ab139ad
Branch: pull/5294/merge
8 changed files:

  1. colossalai/amp/naive_amp/mixed_precision_optimizer.py (7 lines changed)
  2. colossalai/booster/plugin/hybrid_parallel_plugin.py (5 lines changed)
  3. colossalai/interface/optimizer.py (12 lines changed)
  4. colossalai/zero/gemini/gemini_optimizer.py (7 lines changed)
  5. colossalai/zero/low_level/low_level_optim.py (5 lines changed)
  6. tests/test_booster/test_plugin/test_3d_plugin.py (2 lines changed)
  7. tests/test_booster/test_plugin/test_gemini_plugin.py (2 lines changed)
  8. tests/test_booster/test_plugin/test_low_level_zero_plugin.py (2 lines changed)

colossalai/amp/naive_amp/mixed_precision_optimizer.py (7 lines changed)

@@ -1,4 +1,4 @@
-from typing import Dict, List, Tuple
+from typing import Dict, List, Optional, Tuple

 import torch
 from torch import Tensor, inf
@@ -84,6 +84,7 @@ class MixedPrecisionOptimizer(OptimizerWrapper):
                     self.master_to_working_map[master_p] = p
                     master_params.append(master_p)
                 group["params"] = master_params
+        self._current_grad_norm: Optional[float] = None

     def backward(self, loss: Tensor, *args, **kwargs):
         loss = self.mixed_precision.pre_backward(loss)
@@ -187,6 +188,7 @@ class MixedPrecisionOptimizer(OptimizerWrapper):
                 if p.grad is not None
             ]
             total_norm = self._compute_grad_norm(param_gradient_pairs)
+        self._current_grad_norm = total_norm
         self._unscale_and_clip_grads(total_norm)

         self.optim.step(*args, **kwargs)
@@ -212,3 +214,6 @@ class MixedPrecisionOptimizer(OptimizerWrapper):

     def get_master_to_working_map(self) -> Dict[int, torch.Tensor]:
         return {id(master_p): working_p for master_p, working_p in self.master_to_working_map.items()}
+
+    def get_grad_norm(self, norm_type=2, **kwargs):
+        return self._current_grad_norm
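The change follows the same pattern in every wrapper touched below: the norm that step() already computes for clipping is cached in a new _current_grad_norm attribute, and get_grad_norm() simply returns that cached value. A minimal, self-contained sketch of the pattern (the CachedNormOptimizer class here is illustrative only, not part of the patch):

from typing import Optional

import torch


class CachedNormOptimizer:
    """Illustrative wrapper: cache the gradient norm computed during step()."""

    def __init__(self, optim: torch.optim.Optimizer, max_norm: float = 1.0):
        self.optim = optim
        self.max_norm = max_norm
        self._current_grad_norm: Optional[float] = None

    def step(self, *args, **kwargs):
        # Gather every parameter that has a gradient, clip to max_norm,
        # cache the pre-clipping total norm, then delegate to the inner optimizer.
        params = [p for group in self.optim.param_groups for p in group["params"] if p.grad is not None]
        total_norm = torch.nn.utils.clip_grad_norm_(params, self.max_norm)
        self._current_grad_norm = float(total_norm)
        self.optim.step(*args, **kwargs)

    def get_grad_norm(self, norm_type=2, **kwargs) -> Optional[float]:
        # None until step() has run at least once, mirroring the wrappers above.
        return self._current_grad_norm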

colossalai/booster/plugin/hybrid_parallel_plugin.py (5 lines changed)

@@ -293,6 +293,7 @@ class HybridParallelNaiveOptimizer(OptimizerWrapper):
         self.pp_pg = pp_process_group
         self.tp_size = get_world_size(self.tp_pg) if self.tp_pg is not None else 1
         self.pp_size = get_world_size(self.pp_pg) if self.pp_pg is not None else 1
+        self._current_grad_norm: Optional[float] = None
         super().__init__(optim)

     def backward(self, loss: Tensor, *args, **kwargs):
@@ -364,6 +365,7 @@ class HybridParallelNaiveOptimizer(OptimizerWrapper):
                 (p, p.grad) for group in self.optim.param_groups for p in group["params"] if p.grad is not None
             ]
             total_norm = self._compute_grad_norm(param_gradient_pairs)
+            self._current_grad_norm = total_norm

             # Clip the gradients to prevent exploding gradients.
             self._clip_grad_norm(total_norm)
@@ -477,6 +479,9 @@ class HybridParallelNaiveOptimizer(OptimizerWrapper):
     def get_master_to_working_map(self):
         return None

+    def get_grad_norm(self, norm_type=2, **kwargs):
+        return self._current_grad_norm
+

 class HybridParallelAMPOptimizer(MixedPrecisionOptimizer):
     def __init__(

colossalai/interface/optimizer.py (12 lines changed)

@@ -135,6 +135,18 @@ class OptimizerWrapper:
         """
         return self.optim

+    def get_grad_norm(self, norm_type: Union[float, int] = 2.0, **kwargs) -> Optional[float]:
+        """
+        Returns the gradient norm of an iterable of parameters. This method should be called after optimizer.step().
+
+        Args:
+            norm_type (float or int): type of the used p-norm. Can be ``'inf'`` for infinity norm.
+
+        Returns:
+            Optional[float]: Total norm of the gradients (viewed as a single vector). If there are no valid gradients, returns None.
+        """
+        raise NotImplementedError("The method get_grad_norm is not implemented yet.")
+

 class DistributedOptim(Optimizer):
     def setup_distributed(
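The base class raises NotImplementedError, and the concrete wrappers return None until a norm has actually been computed (for example before the first step(), or when clipping is disabled). A defensive calling pattern, sketched here with an assumed log_grad_norm helper that is not part of the patch:

def log_grad_norm(optimizer, step_idx: int) -> None:
    # Call after optimizer.step(); wrappers that do not support the new API
    # yet still raise NotImplementedError, and supported ones may return None.
    try:
        grad_norm = optimizer.get_grad_norm()
    except NotImplementedError:
        return
    if grad_norm is not None:
        print(f"step {step_idx}: grad norm = {grad_norm:.4f}")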

colossalai/zero/gemini/gemini_optimizer.py (7 lines changed)

@@ -1,7 +1,7 @@
 # this code is inspired by the DeepSpeed library and implemented with our own design from scratch
 import copy
 import math
-from typing import Any, Dict, Iterator, OrderedDict, Set, Tuple, Union
+from typing import Any, Dict, Iterator, Optional, OrderedDict, Set, Tuple, Union

 import torch
 import torch.distributed as dist
@@ -195,6 +195,7 @@ class GeminiOptimizer(OptimizerWrapper):
             self._logger.warning(f'gpu_margin_mem_ratio is meaningless when placement_policy is not "auto"', ranks=[0])

         self._register_states = disposable(self._register_states_)
+        self._current_grad_norm: Optional[float] = None

     def _set_grad_ptr(self):
         for group in self.param_groups:
@@ -255,6 +256,7 @@ class GeminiOptimizer(OptimizerWrapper):

         if self.clipping_flag:
             total_norm = self._calc_global_norm()
+            self._current_grad_norm = total_norm
             clip = ((total_norm / div_scale) + 1e-6) / self.max_norm
             if clip > 1:
                 div_scale = clip * div_scale
@@ -846,6 +848,9 @@ class GeminiOptimizer(OptimizerWrapper):
             f"Gemini controls grad clipping by itself, so you should not use clip_grad_by_norm", ranks=[0]
         )

+    def get_grad_norm(self, norm_type=2, **kwargs):
+        return self._current_grad_norm
+

 class GeminiAdamOptimizer(GeminiOptimizer):
     def __init__(self, model: torch.nn.Module, **defaults: Any) -> None:

colossalai/zero/low_level/low_level_optim.py (5 lines changed)

@@ -218,6 +218,7 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
             )
         elif self._dtype is torch.bfloat16:
             self.mixed_precision_mixin = BF16MixedPrecisionMixin()
+        self._current_grad_norm: Optional[float] = None

     def __del__(self):
         for hook in self.grad_handles:
@@ -551,6 +552,7 @@ class LowLevelZeroOptimizer(OptimizerWrapper):

         # unscale and clip grads
         global_norm = calculate_global_norm_from_list(norm_list=norm_groups)
+        self._current_grad_norm = global_norm
         self._unscale_and_clip_grads(grad_partition_groups, global_norm)

         # update the parameters
@@ -934,3 +936,6 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
     def _force_wait_all_gather(self):
         for param in self._working_param_to_padded_working_param.keys():
             wait_all_gather_handle(param)
+
+    def get_grad_norm(self, norm_type=2, **kwargs):
+        return self._current_grad_norm
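For the ZeRO optimizer the cached value is a global norm: each rank only holds a shard of the gradients, so calculate_global_norm_from_list reduces the per-rank contributions before the result is stored. A simplified sketch of that kind of reduction (this is not the library helper, just an illustration of the idea):

import torch
import torch.distributed as dist


def sharded_global_grad_norm(local_grads, process_group=None) -> float:
    # Sum of squared gradient entries held by this rank.
    local_sq_sum = sum(float(g.detach().float().pow(2).sum()) for g in local_grads)
    total = torch.tensor(local_sq_sum, dtype=torch.float32)
    if dist.is_initialized():
        # Each rank owns a disjoint shard, so summing the squared norms across
        # ranks and taking the square root gives the norm of the full gradient.
        # NB: with the NCCL backend this tensor would need to live on the GPU.
        dist.all_reduce(total, op=dist.ReduceOp.SUM, group=process_group)
    return total.item() ** 0.5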

tests/test_booster/test_plugin/test_3d_plugin.py (2 lines changed)

@@ -76,6 +76,8 @@ def run_fn(init_method, model_fn, data_gen_fn, output_transform_fn) -> Optional[
             booster.execute_pipeline(data_iter, model, _criterion, optimizer, return_loss=True)

         optimizer.step()
+        grad_norm = optimizer.get_grad_norm()
+        assert grad_norm is None or isinstance(grad_norm, float)

     except Exception as e:
         return repr(e)

tests/test_booster/test_plugin/test_gemini_plugin.py (2 lines changed)

@@ -54,6 +54,8 @@ def run_fn(init_method, model_fn, data_gen_fn, output_transform_fn, zero_size, t
         booster.backward(loss, optimizer)
         optimizer.step()
+        grad_norm = optimizer.get_grad_norm()
+        assert grad_norm is None or isinstance(grad_norm, float)

     except NotImplementedError:
         print(f"Tensor Parallelism policy for {model.__class__} is not implemented yet\n.")

tests/test_booster/test_plugin/test_low_level_zero_plugin.py (2 lines changed)

@@ -50,6 +50,8 @@ def run_fn(stage, model_fn, data_gen_fn, output_transform_fn, lora_config=None)
         booster.backward(loss, optimizer)
         optimizer.step()
+        grad_norm = optimizer.get_grad_norm()
+        assert grad_norm is None or isinstance(grad_norm, float)

     except Exception as e:
         return repr(e)
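Outside the test suite, the new accessor slots into an ordinary booster training loop in the same way. A rough usage sketch, assuming plugin, model, optimizer, criterion, and dataloader have already been created (these variable names are placeholders, not part of the patch):

from colossalai.booster import Booster

booster = Booster(plugin=plugin)
model, optimizer, criterion, dataloader, _ = booster.boost(model, optimizer, criterion, dataloader)

for batch, target in dataloader:
    loss = criterion(model(batch), target)
    booster.backward(loss, optimizer)
    optimizer.step()
    # New in this commit: the gradient norm from the last step, or None if the
    # plugin did not compute one.
    grad_norm = optimizer.get_grad_norm()
    optimizer.zero_grad()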
