
[plugin] support get_grad_norm (#6115)

pull/5294/merge
Hongxin Liu authored 2 weeks ago, committed by GitHub
commit a15ab139ad
8 changed files (changed line counts in parentheses):

  1. colossalai/amp/naive_amp/mixed_precision_optimizer.py (7)
  2. colossalai/booster/plugin/hybrid_parallel_plugin.py (5)
  3. colossalai/interface/optimizer.py (12)
  4. colossalai/zero/gemini/gemini_optimizer.py (7)
  5. colossalai/zero/low_level/low_level_optim.py (5)
  6. tests/test_booster/test_plugin/test_3d_plugin.py (2)
  7. tests/test_booster/test_plugin/test_gemini_plugin.py (2)
  8. tests/test_booster/test_plugin/test_low_level_zero_plugin.py (2)
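For orientation, here is a minimal usage sketch of the new API. It is a hedged illustration, not code from this commit: the plugin choice, model, loss, and device are placeholders, and it assumes the distributed environment has already been initialized (e.g. via colossalai.launch_from_torch). The call pattern after optimizer.step() mirrors the test changes below.

# Hypothetical usage sketch (not part of this commit): query the gradient norm
# right after optimizer.step(), as the updated plugin tests do.
import torch

from colossalai.booster import Booster
from colossalai.booster.plugin import LowLevelZeroPlugin

# Assumption: colossalai.launch_from_torch (or an equivalent launcher) has
# already set up the distributed environment before this point.
plugin = LowLevelZeroPlugin(stage=2)  # any of the plugins patched by this commit
booster = Booster(plugin=plugin)

model = torch.nn.Linear(16, 4).cuda()
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
model, optimizer, *_ = booster.boost(model, optimizer)

loss = model(torch.randn(8, 16, device="cuda")).sum()
booster.backward(loss, optimizer)
optimizer.step()

grad_norm = optimizer.get_grad_norm()  # None, or a float once a norm has been computed
print(grad_norm)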

colossalai/amp/naive_amp/mixed_precision_optimizer.py (7)

@@ -1,4 +1,4 @@
-from typing import Dict, List, Tuple
+from typing import Dict, List, Optional, Tuple

 import torch
 from torch import Tensor, inf
@@ -84,6 +84,7 @@ class MixedPrecisionOptimizer(OptimizerWrapper):
                 self.master_to_working_map[master_p] = p
                 master_params.append(master_p)
             group["params"] = master_params
+        self._current_grad_norm: Optional[float] = None

     def backward(self, loss: Tensor, *args, **kwargs):
         loss = self.mixed_precision.pre_backward(loss)
@@ -187,6 +188,7 @@ class MixedPrecisionOptimizer(OptimizerWrapper):
                 if p.grad is not None
             ]
             total_norm = self._compute_grad_norm(param_gradient_pairs)
+            self._current_grad_norm = total_norm

         self._unscale_and_clip_grads(total_norm)
         self.optim.step(*args, **kwargs)
@@ -212,3 +214,6 @@ class MixedPrecisionOptimizer(OptimizerWrapper):
     def get_master_to_working_map(self) -> Dict[int, torch.Tensor]:
         return {id(master_p): working_p for master_p, working_p in self.master_to_working_map.items()}
+
+    def get_grad_norm(self, norm_type=2, **kwargs):
+        return self._current_grad_norm

colossalai/booster/plugin/hybrid_parallel_plugin.py (5)

@@ -293,6 +293,7 @@ class HybridParallelNaiveOptimizer(OptimizerWrapper):
         self.pp_pg = pp_process_group
         self.tp_size = get_world_size(self.tp_pg) if self.tp_pg is not None else 1
         self.pp_size = get_world_size(self.pp_pg) if self.pp_pg is not None else 1
+        self._current_grad_norm: Optional[float] = None
         super().__init__(optim)

     def backward(self, loss: Tensor, *args, **kwargs):
@@ -364,6 +365,7 @@ class HybridParallelNaiveOptimizer(OptimizerWrapper):
                 (p, p.grad) for group in self.optim.param_groups for p in group["params"] if p.grad is not None
             ]
             total_norm = self._compute_grad_norm(param_gradient_pairs)
+            self._current_grad_norm = total_norm

             # Clip the gradients to prevent exploding gradients.
             self._clip_grad_norm(total_norm)
@@ -477,6 +479,9 @@ class HybridParallelNaiveOptimizer(OptimizerWrapper):
     def get_master_to_working_map(self):
         return None

+    def get_grad_norm(self, norm_type=2, **kwargs):
+        return self._current_grad_norm
+

 class HybridParallelAMPOptimizer(MixedPrecisionOptimizer):
     def __init__(

colossalai/interface/optimizer.py (12)

@@ -135,6 +135,18 @@ class OptimizerWrapper:
         """
         return self.optim

+    def get_grad_norm(self, norm_type: Union[float, int] = 2.0, **kwargs) -> Optional[float]:
+        """
+        Returns the gradient norm of an iterable of parameters. This method should be called after optimizer.step().
+
+        Args:
+            norm_type (float or int): type of the used p-norm. Can be ``'inf'`` for infinity norm.
+
+        Returns:
+            Optional[float]: Total norm of the gradients (viewed as a single vector). If there are no valid gradients, returns None.
+        """
+        raise NotImplementedError("The method get_grad_norm is not implemented yet.")
+

 class DistributedOptim(Optimizer):
     def setup_distributed(
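All four optimizer wrappers touched by this commit satisfy this interface the same way: they compute the global gradient norm once inside step(), cache it on self._current_grad_norm, and return the cached value from get_grad_norm(). The block below is a self-contained, hypothetical plain-PyTorch sketch of that caching pattern; the class name NormCachingSGD and all of its details are illustrative assumptions, not Colossal-AI code.

# Hypothetical sketch of the norm-caching pattern used by the wrappers in this commit.
from typing import Optional

import torch


class NormCachingSGD:
    """Illustrative wrapper only: compute the grad norm once in step(), cache it."""

    def __init__(self, params, lr: float = 0.1, max_norm: float = 1.0):
        self.params = list(params)
        self.optim = torch.optim.SGD(self.params, lr=lr)
        self.max_norm = max_norm
        self._current_grad_norm: Optional[float] = None  # None until the first step()

    def zero_grad(self):
        self.optim.zero_grad()

    def step(self):
        # clip_grad_norm_ returns the total norm computed before clipping,
        # so the norm is obtained as a by-product of the existing clipping step.
        total_norm = torch.nn.utils.clip_grad_norm_(self.params, self.max_norm)
        self._current_grad_norm = float(total_norm)
        self.optim.step()

    def get_grad_norm(self, norm_type=2, **kwargs) -> Optional[float]:
        # Mirrors the new OptimizerWrapper API: cached norm, or None before any step().
        return self._current_grad_norm


if __name__ == "__main__":
    model = torch.nn.Linear(4, 2)
    opt = NormCachingSGD(model.parameters())
    loss = model(torch.randn(8, 4)).sum()
    loss.backward()
    opt.step()
    print(opt.get_grad_norm())  # a float after step(); None before any step()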

colossalai/zero/gemini/gemini_optimizer.py (7)

@@ -1,7 +1,7 @@
 # this code is inspired by the DeepSpeed library and implemented with our own design from scratch
 import copy
 import math
-from typing import Any, Dict, Iterator, OrderedDict, Set, Tuple, Union
+from typing import Any, Dict, Iterator, Optional, OrderedDict, Set, Tuple, Union

 import torch
 import torch.distributed as dist
@@ -195,6 +195,7 @@ class GeminiOptimizer(OptimizerWrapper):
             self._logger.warning(f'gpu_margin_mem_ratio is meaningless when placement_policy is not "auto"', ranks=[0])

         self._register_states = disposable(self._register_states_)
+        self._current_grad_norm: Optional[float] = None

     def _set_grad_ptr(self):
         for group in self.param_groups:
@@ -255,6 +256,7 @@ class GeminiOptimizer(OptimizerWrapper):
         if self.clipping_flag:
             total_norm = self._calc_global_norm()
+            self._current_grad_norm = total_norm
             clip = ((total_norm / div_scale) + 1e-6) / self.max_norm
             if clip > 1:
                 div_scale = clip * div_scale
@@ -846,6 +848,9 @@ class GeminiOptimizer(OptimizerWrapper):
             f"Gemini controls grad clipping by itself, so you should not use clip_grad_by_norm", ranks=[0]
         )

+    def get_grad_norm(self, norm_type=2, **kwargs):
+        return self._current_grad_norm
+

 class GeminiAdamOptimizer(GeminiOptimizer):
     def __init__(self, model: torch.nn.Module, **defaults: Any) -> None:

colossalai/zero/low_level/low_level_optim.py (5)

@@ -218,6 +218,7 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
             )
         elif self._dtype is torch.bfloat16:
             self.mixed_precision_mixin = BF16MixedPrecisionMixin()
+        self._current_grad_norm: Optional[float] = None

     def __del__(self):
         for hook in self.grad_handles:
@@ -551,6 +552,7 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
         # unscale and clip grads
         global_norm = calculate_global_norm_from_list(norm_list=norm_groups)
+        self._current_grad_norm = global_norm
         self._unscale_and_clip_grads(grad_partition_groups, global_norm)

         # update the parameters
@@ -934,3 +936,6 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
     def _force_wait_all_gather(self):
         for param in self._working_param_to_padded_working_param.keys():
             wait_all_gather_handle(param)
+
+    def get_grad_norm(self, norm_type=2, **kwargs):
+        return self._current_grad_norm

tests/test_booster/test_plugin/test_3d_plugin.py (2)

@@ -76,6 +76,8 @@ def run_fn(init_method, model_fn, data_gen_fn, output_transform_fn) -> Optional[
             booster.execute_pipeline(data_iter, model, _criterion, optimizer, return_loss=True)

         optimizer.step()
+        grad_norm = optimizer.get_grad_norm()
+        assert grad_norm is None or isinstance(grad_norm, float)

     except Exception as e:
         return repr(e)

tests/test_booster/test_plugin/test_gemini_plugin.py (2)

@@ -54,6 +54,8 @@ def run_fn(init_method, model_fn, data_gen_fn, output_transform_fn, zero_size, t
         booster.backward(loss, optimizer)
         optimizer.step()
+        grad_norm = optimizer.get_grad_norm()
+        assert grad_norm is None or isinstance(grad_norm, float)

     except NotImplementedError:
         print(f"Tensor Parallelism policy for {model.__class__} is not implemented yet\n.")

tests/test_booster/test_plugin/test_low_level_zero_plugin.py (2)

@@ -50,6 +50,8 @@ def run_fn(stage, model_fn, data_gen_fn, output_transform_fn, lora_config=None)
         booster.backward(loss, optimizer)
         optimizer.step()
+        grad_norm = optimizer.get_grad_norm()
+        assert grad_norm is None or isinstance(grad_norm, float)

     except Exception as e:
         return repr(e)
