from typing import Dict

import torch
import torch.distributed as dist

from colossalai.interface.optimizer import DistributedOptim
from colossalai.shardformer.layer._operation import _gather, _split
from colossalai.tensor.d_tensor import get_sharding_spec, is_distributed_tensor


class DistributedCAME(DistributedOptim):
    """Implements the CAME algorithm.

    This implementation is based on:
    `CAME: Confidence-guided Adaptive Memory Efficient Optimization`

    Args:
        params (iterable): iterable of parameters to optimize or dicts defining
            parameter groups
        lr (float, optional): external learning rate (default: None)
        eps (tuple[float, float]): regularization constants for the squared gradient
            and the instability, respectively (default: (1e-30, 1e-16))
        clip_threshold (float): threshold of the root-mean-square of the
            final gradient update (default: 1.0)
        betas (tuple[float, float, float]): coefficients used for computing running averages of
            the update, squared gradient and instability (default: (0.9, 0.999, 0.9999))
        weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
    """

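    # Minimal usage sketch (illustrative only): the optimizer is built like a regular
    # torch optimizer and then receives its parallel context via `setup_distributed`.
    # The process groups and mapping below are assumptions about the caller, not part
    # of this file; in ColossalAI they are typically supplied by the plugin/booster
    # that wraps the optimizer rather than by hand.
    #
    #     optimizer = DistributedCAME(model.parameters(), lr=2e-4, weight_decay=1e-2)
    #     optimizer.setup_distributed(tp_group=tp_group, dp_group=dp_group,
    #                                 shard_to_working_param=shard_map, use_zero=True)
    #     loss.backward()
    #     optimizer.step()
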
    def __init__(
        self,
        params,
        lr=None,
        eps=(1e-30, 1e-16),
        clip_threshold=1.0,
        betas=(0.9, 0.999, 0.9999),
        weight_decay=0.0,
    ):
        defaults = dict(
            lr=lr,
            eps=eps,
            clip_threshold=clip_threshold,
            betas=betas,
            weight_decay=weight_decay,
        )

        self.tp_size = 1
        self.tp_group = None
        self.dp_size = 1
        self.dp_group = None
        self.shard_to_working_param = None  # {id(shard param): working param (torch.Tensor)}
        self.use_zero = True

        self.param_is_dtensor_dict = {}  # {id(p): True/False}
        self.grad_shape_dict = {}  # {id(p): master param shape}
        self.factored_dict = {}  # {id(p): True/False}
        self.use_first_moment_dict = {}  # {id(p): True/False}
        self.shard_spec_dict = {}  # {id(p): ShardSpec}

        super(DistributedCAME, self).__init__(params, defaults)

    @property
    def supports_memory_efficient_fp16(self):
        return True

    @property
    def supports_flat_params(self):
        return False

    def setup_distributed(
        self,
        tp_group: dist.ProcessGroup = None,
        dp_group: dist.ProcessGroup = None,
        shard_to_working_param: Dict = {},
        padding_map=None,
        use_zero: bool = True,
    ) -> None:
        """
        Inject the distributed-training context into the optimizer.

        Args:
            tp_group: The process group for tensor parallelism.
            dp_group: The process group for data parallelism.
            shard_to_working_param (Dict): ZeRO 2 feeds the optimizer a sharded param view, since grads are sharded.
                This maps from id(view) to the working params used in forward & backward.
            padding_map: Interface placeholder.
            use_zero: Whether or not ZeRO is used.
        """
        self.tp_group = tp_group  # expected to be the row process group
        self.dp_group = dp_group
        if self.tp_group is not None:
            self.tp_size = dist.get_world_size(self.tp_group)
        if self.dp_group is not None:
            self.dp_size = dist.get_world_size(self.dp_group)
        self.use_zero = use_zero

        self.shard_to_working_param = shard_to_working_param if shard_to_working_param is not None else {}
        # Gradients are still None at this point; this loop only records per-parameter metadata.
        for group in self.param_groups:
            for p in group["params"]:
                # Without ZeRO, the master param is the working param.
                self.shard_to_working_param[id(p)] = self.shard_to_working_param.get(id(p), p)
                self.param_is_dtensor_dict[id(p)] = is_distributed_tensor(self.shard_to_working_param[id(p)])
                self.grad_shape_dict[id(p)] = self.shard_to_working_param[id(p)].shape
                # Avoid the case where row parallelism leaves H == 1 and the param would then be
                # treated as non-factored; force factoring for row/column-parallel params.
                if self.param_is_dtensor_dict[id(p)]:
                    self.shard_spec_dict[id(p)] = get_sharding_spec(self.shard_to_working_param[id(p)])
                    if self.shard_spec_dict[id(p)].sharding_sequence[0] == "R":
                        self.factored_dict[id(p)] = True
                    elif self.shard_spec_dict[id(p)].sharding_sequence[-1] == "R":
                        self.factored_dict[id(p)] = True
                    else:
                        self.factored_dict[id(p)] = self._get_options(self.grad_shape_dict[id(p)])
                else:
                    self.shard_spec_dict[id(p)] = None
                    self.factored_dict[id(p)] = self._get_options(self.grad_shape_dict[id(p)])

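    # Sketch of how the mapping above is typically produced (assumed caller-side code,
    # not part of this file): the ZeRO wrapper keeps a flattened shard view of each
    # working parameter and hands the optimizer a dict from the shard view's id to the
    # working tensor. All names below are hypothetical.
    #
    #     shard_map = {id(shard): working for shard, working in zip(optim_shards, model_params)}
    #     optimizer.setup_distributed(tp_group=tp_pg, dp_group=dp_pg,
    #                                 shard_to_working_param=shard_map, use_zero=True)
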
    @staticmethod
    def _get_options(param_shape):
        factored = len(param_shape) >= 2
        return factored

    @staticmethod
    def _rms(tensor, param_is_dtensor, use_zero, tp_size, dp_size, tp_group, dp_group):
        tensor_sum = tensor.pow(2).sum()
        num_of_element = tensor.numel()

        if param_is_dtensor:
            # reduce tensor_sum over the tp group
            dist.all_reduce(tensor_sum, group=tp_group)
            num_of_element = num_of_element * tp_size
        if use_zero:
            dist.all_reduce(tensor_sum, group=dp_group)
            num_of_element = num_of_element * dp_size
        rms = (tensor_sum / num_of_element).sqrt()
        return rms

    @staticmethod
    def _approx_sq_grad(exp_avg_sq_row, exp_avg_sq_col):
        r_factor = (exp_avg_sq_row / exp_avg_sq_row.mean(dim=-1, keepdim=True)).rsqrt_().unsqueeze(-1)
        c_factor = exp_avg_sq_col.unsqueeze(-2).rsqrt()
        return torch.mul(r_factor, c_factor)

    # approx_sq_grad for row-parallel weights, where the row mean is taken over the gathered rows
    @staticmethod
    def _approx_sq_grad_row_parallel(exp_avg_sq_row, exp_avg_sq_col, sq_row_mean):
        r_factor = (exp_avg_sq_row / sq_row_mean).rsqrt_().unsqueeze(-1)
        c_factor = exp_avg_sq_col.unsqueeze(-2).rsqrt()
        return torch.mul(r_factor, c_factor)

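    # The two helpers above rebuild the inverse square root of the factored second moment,
    # in the Adafactor/CAME style. With row statistics R (shape [H]) and column statistics
    # C (shape [W]), the rank-1 reconstruction of the second moment is
    #
    #     V[i, j] ~= R[i] * C[j] / mean(R)
    #
    # and the helpers return V^(-1/2) elementwise, i.e. (R[i] / mean(R))^(-1/2) * C[j]^(-1/2).
    # The row-parallel variant only differs in that mean(R) is computed over the full rows
    # gathered across the tensor-parallel group instead of over the local shard.
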
    def _col_parallel_factor(self, update, grad, state_row, state_col, grad_shape, beta2t):
        if grad_shape[0] % self.dp_size != 0:
            # gather update [flattened] along the dp group, then reshape to [H, W/tp]
            update = _gather(input_=update, dim=-1, process_group=self.dp_group)
            update_reshape = update.view(-1, grad_shape[1])
            # gather grad [flattened] along the dp group, then reshape to [H, W/tp]
            grad = _gather(input_=grad, dim=-1, process_group=self.dp_group)
            grad_reshape = grad.view(-1, grad_shape[1])
            exp_avg_sq_row = state_row  # [H]
            exp_avg_sq_col = state_col  # [W/tp]
            exp_avg_sq_row.mul_(beta2t).add_(update_reshape.mean(dim=-1), alpha=(1.0 - beta2t))
            exp_avg_sq_col.mul_(beta2t).add_(update_reshape.mean(dim=-2), alpha=(1.0 - beta2t))
            update_reshape = self._approx_sq_grad(exp_avg_sq_row, exp_avg_sq_col)
            update_reshape.mul_(grad_reshape)
        else:
            update_reshape = update.view(-1, grad_shape[1])
            grad_reshape = grad.view(-1, grad_shape[1])
            exp_avg_sq_row = state_row  # [H]
            exp_avg_sq_col = state_col  # [W/tp]
            exp_avg_sq_row.mul_(beta2t).add_(update_reshape.mean(dim=-1), alpha=(1.0 - beta2t))
            exp_avg_sq_col.mul_(beta2t).add_(update_reshape.mean(dim=-2), alpha=(1.0 - beta2t))
            dist.all_reduce(exp_avg_sq_row, group=self.tp_group)
            exp_avg_sq_row.div_(self.tp_size)
            update_reshape = self._approx_sq_grad(exp_avg_sq_row, exp_avg_sq_col)
            update_reshape.mul_(grad_reshape)

        if self.use_zero:
            update = update_reshape.view(-1)
        else:
            update = update_reshape
        return update

    def _row_parallel_factor(self, update, grad, state_row, state_col, grad_shape, beta2t):
        if grad_shape[0] % self.dp_size != 0:
            # gather update [flattened] along the dp group, then reshape to [H/tp, W]
            update = _gather(input_=update, dim=-1, process_group=self.dp_group)
            # view update in the original [tp] shape
            update_reshape = update.view(-1, grad_shape[1])
            # gather grad [flattened] along the dp group, then reshape to [H/tp, W]
            grad = _gather(input_=grad, dim=-1, process_group=self.dp_group)
            grad_reshape = grad.view(-1, grad_shape[1])
            exp_avg_sq_row = state_row  # [H/tp]
            exp_avg_sq_col = state_col  # [W]
            exp_avg_sq_row.mul_(beta2t).add_(update_reshape.mean(dim=-1), alpha=(1.0 - beta2t))
            exp_avg_sq_col.mul_(beta2t).add_(update_reshape.mean(dim=-2), alpha=(1.0 - beta2t))
            # reduce col
            dist.all_reduce(exp_avg_sq_col, group=self.tp_group)
            exp_avg_sq_col.div_(self.tp_size)
            update_reshape = self._approx_sq_grad(exp_avg_sq_row, exp_avg_sq_col)
            update_reshape.mul_(grad_reshape)
            if self.use_zero:
                update = _split(input_=update_reshape.view(-1), dim=-1, process_group=self.dp_group)
            else:
                update = update_reshape
        else:
            update_reshape = update.view(-1, grad_shape[1])
            grad_reshape = grad.view(-1, grad_shape[1])
            exp_avg_sq_row = state_row  # [H/tp]
            exp_avg_sq_col = state_col  # [W]
            exp_avg_sq_row.mul_(beta2t).add_(update_reshape.mean(dim=-1), alpha=(1.0 - beta2t))
            exp_avg_sq_col.mul_(beta2t).add_(update_reshape.mean(dim=-2), alpha=(1.0 - beta2t))
            # reduce col
            dist.all_reduce(exp_avg_sq_col, group=self.tp_group)
            exp_avg_sq_col.div_(self.tp_size)
            # gather row
            exp_avg_sq_row_gather = _gather(input_=exp_avg_sq_row, dim=-1, process_group=self.tp_group)
            sq_row_mean = exp_avg_sq_row_gather.mean(dim=-1, keepdim=True)
            update_reshape = self._approx_sq_grad_row_parallel(exp_avg_sq_row, exp_avg_sq_col, sq_row_mean)
            update_reshape.mul_(grad_reshape)
            if self.use_zero:
                update = update_reshape.view(-1)
            else:
                update = update_reshape
        return update

    def _base_factor(self, update, grad, state_row, state_col, grad_shape, beta2t):
        if self.use_zero:
            # ZeRO only, no tensor parallelism
            # e.g. [30522, 128], [2, 128]
            if grad_shape[0] % self.dp_size != 0:
                # view update in its original shape: update.view(grad_shape[0] // self.dp_size, grad_shape[1])
                # row mean: no change needed
                # col mean: needs a reduce and divide
                # gather update [flattened] along the dp group, then reshape to [H, W]
                update = _gather(input_=update, dim=-1, process_group=self.dp_group)
                # view update in the original [tp] shape
                update_reshape = update.view(-1, grad_shape[1])
                # gather grad [flattened] along the dp group, then reshape to [H, W]
                grad = _gather(input_=grad, dim=-1, process_group=self.dp_group)
                grad_reshape = grad.view(-1, grad_shape[1])
                exp_avg_sq_row = state_row  # [H]
                exp_avg_sq_col = state_col  # [W]
                exp_avg_sq_row.mul_(beta2t).add_(update_reshape.mean(dim=-1), alpha=(1.0 - beta2t))
                exp_avg_sq_col.mul_(beta2t).add_(update_reshape.mean(dim=-2), alpha=(1.0 - beta2t))
                # reduce col
                dist.all_reduce(exp_avg_sq_col, group=self.tp_group)
                exp_avg_sq_col.div_(self.tp_size)
                update_reshape = self._approx_sq_grad(exp_avg_sq_row, exp_avg_sq_col)
                update_reshape.mul_(grad_reshape)
                update = _split(input_=update_reshape.view(-1), dim=-1, process_group=self.dp_group)
            else:
                # no residual rows
                # view update in the original [tp] shape
                update_reshape = update.view(-1, grad_shape[1])  # [H/dp, W]
                grad_reshape = grad.view(-1, grad_shape[1])  # [H/dp, W]
                exp_avg_sq_row = state_row  # [H/dp]
                exp_avg_sq_col = state_col  # [W]
                exp_avg_sq_row.mul_(beta2t).add_(update_reshape.mean(dim=-1), alpha=(1.0 - beta2t))
                exp_avg_sq_col.mul_(beta2t).add_(update_reshape.mean(dim=-2), alpha=(1.0 - beta2t))
                # reduce col
                dist.all_reduce(exp_avg_sq_col, group=self.tp_group)
                exp_avg_sq_col.div_(self.tp_size)
                update_reshape = self._approx_sq_grad(exp_avg_sq_row, exp_avg_sq_col)
                update_reshape.mul_(grad_reshape)
                update = update_reshape.view(-1)
        else:
            # base factor: no tp, no dp
            exp_avg_sq_row = state_row  # [H]
            exp_avg_sq_col = state_col  # [W]
            # Exponential moving average of row means
            exp_avg_sq_row.mul_(beta2t).add_(update.mean(dim=-1), alpha=(1.0 - beta2t))
            # Exponential moving average of column means
            exp_avg_sq_col.mul_(beta2t).add_(update.mean(dim=-2), alpha=(1.0 - beta2t))
            # Approximation of the exponential moving average of the squared gradient
            update = self._approx_sq_grad(exp_avg_sq_row, exp_avg_sq_col)
            update.mul_(grad)
        return update

    # factor the residual in the same way
    def _base_res_factor(self, res, exp_avg, state_row, state_col, grad_shape, beta2t):
        if self.use_zero:
            # ZeRO only, no tensor parallelism
            if grad_shape[0] % self.dp_size != 0:
                # view res in its original shape: res.view(grad_shape[0] // self.dp_size, grad_shape[1])
                # row mean: no change needed
                # col mean: needs a reduce and divide
                # gather res [flattened] along the dp group, then reshape to [H, W]
                res = _gather(input_=res, dim=-1, process_group=self.dp_group)
                # view res in the original [tp] shape
                res_reshape = res.view(-1, grad_shape[1])
                # gather exp_avg [flattened] along the dp group, then reshape to [H, W]
                exp_avg = _gather(input_=exp_avg, dim=-1, process_group=self.dp_group)
                exp_avg_reshape = exp_avg.view(-1, grad_shape[1])
                exp_avg_sq_row = state_row  # [H]
                exp_avg_sq_col = state_col  # [W]
                exp_avg_sq_row.mul_(beta2t).add_(res_reshape.mean(dim=-1), alpha=(1.0 - beta2t))
                exp_avg_sq_col.mul_(beta2t).add_(res_reshape.mean(dim=-2), alpha=(1.0 - beta2t))
                # reduce col
                dist.all_reduce(exp_avg_sq_col, group=self.tp_group)
                exp_avg_sq_col.div_(self.tp_size)
                res_reshape = self._approx_sq_grad(exp_avg_sq_row, exp_avg_sq_col)
                res_reshape.mul_(exp_avg_reshape)
                res = _split(input_=res_reshape.view(-1), dim=-1, process_group=self.dp_group)
            else:
                # no residual rows
                # view res in the original [tp] shape
                res_reshape = res.view(-1, grad_shape[1])  # [H/dp, W]
                exp_avg_reshape = exp_avg.view(-1, grad_shape[1])  # [H/dp, W]
                exp_avg_sq_row = state_row  # [H/dp]
                exp_avg_sq_col = state_col  # [W]
                exp_avg_sq_row.mul_(beta2t).add_(res_reshape.mean(dim=-1), alpha=(1.0 - beta2t))
                exp_avg_sq_col.mul_(beta2t).add_(res_reshape.mean(dim=-2), alpha=(1.0 - beta2t))
                # reduce col
                dist.all_reduce(exp_avg_sq_col, group=self.tp_group)
                exp_avg_sq_col.div_(self.tp_size)
                res_reshape = self._approx_sq_grad(exp_avg_sq_row, exp_avg_sq_col)
                res_reshape.mul_(exp_avg_reshape)
                res = res_reshape.view(-1)
        else:
            # base factor: no tp, no dp
            exp_avg_sq_row = state_row  # [H]
            exp_avg_sq_col = state_col  # [W]
            # Exponential moving average of row means
            exp_avg_sq_row.mul_(beta2t).add_(res.mean(dim=-1), alpha=(1.0 - beta2t))
            # Exponential moving average of column means
            exp_avg_sq_col.mul_(beta2t).add_(res.mean(dim=-2), alpha=(1.0 - beta2t))
            # Approximation of the exponential moving average of the squared residual
            res = self._approx_sq_grad(exp_avg_sq_row, exp_avg_sq_col)
            res.mul_(exp_avg)
        return res

    @torch.no_grad()
    def step(self, closure=None):
        """Performs a single optimization step.

        Args:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.
        """

        loss = None
        if closure is not None:
            loss = closure()

        for group in self.param_groups:
            for p in group["params"]:
                if p.grad is None:
                    continue
                grad = p.grad
                if grad.is_sparse:
                    raise RuntimeError("CAME does not support sparse gradients.")

                state = self.state[p]
                # Under ZeRO the incoming grad is the original grad flattened and then sliced
                # (one-dimensional), so take the master-param shape recorded in setup_distributed.
                grad_shape = self.grad_shape_dict[id(p)]
                param_is_dtensor = self.param_is_dtensor_dict[id(p)]
                if param_is_dtensor:
                    grad_shape = self.shard_to_working_param.get(id(p)).shape  # tp shape (2-dim)
                factored = self.factored_dict[id(p)]
                shard_spec = self.shard_spec_dict[id(p)]

                # State Initialization
                if len(state) == 0:
                    state["step"] = 0
                    state["exp_avg"] = torch.zeros_like(p)
                    if factored:
                        if param_is_dtensor:
                            if shard_spec.sharding_sequence[0] == "R":  # Col Parallel
                                if grad_shape[0] % self.dp_size != 0:
                                    state["exp_avg_sq_row"] = torch.zeros(
                                        grad_shape[0], device=p.device, dtype=p.dtype
                                    )  # [H]
                                    state["exp_avg_res_row"] = torch.zeros(
                                        grad_shape[0], device=p.device, dtype=p.dtype
                                    )  # [H]
                                else:
                                    state["exp_avg_sq_row"] = torch.zeros(
                                        grad_shape[0] // self.dp_size, device=p.device, dtype=p.dtype
                                    )  # [H/dp]
                                    state["exp_avg_res_row"] = torch.zeros(
                                        grad_shape[0] // self.dp_size, device=p.device, dtype=p.dtype
                                    )  # [H/dp]
                                state["exp_avg_sq_col"] = torch.zeros(
                                    grad_shape[1], device=p.device, dtype=p.dtype
                                )  # [W/tp]
                                state["exp_avg_res_col"] = torch.zeros(
                                    grad_shape[1], device=p.device, dtype=p.dtype
                                )  # [W/tp]

                            if shard_spec.sharding_sequence[-1] == "R":  # Row Parallel
                                # indivisible-row shape situation
                                if grad_shape[0] % self.dp_size != 0:
                                    state["exp_avg_sq_row"] = torch.zeros(
                                        grad_shape[0], device=p.device, dtype=p.dtype
                                    )  # [H/tp]
                                    state["exp_avg_res_row"] = torch.zeros(
                                        grad_shape[0], device=p.device, dtype=p.dtype
                                    )  # [H/tp]
                                else:
                                    state["exp_avg_sq_row"] = torch.zeros(
                                        grad_shape[0] // self.dp_size, device=p.device, dtype=p.dtype
                                    )  # [H/dp/tp]
                                    state["exp_avg_res_row"] = torch.zeros(
                                        grad_shape[0] // self.dp_size, device=p.device, dtype=p.dtype
                                    )  # [H/dp/tp]

                                state["exp_avg_sq_col"] = torch.zeros(
                                    grad_shape[1], device=p.device, dtype=p.dtype
                                )  # [W]
                                state["exp_avg_res_col"] = torch.zeros(
                                    grad_shape[1], device=p.device, dtype=p.dtype
                                )  # [W]
                        else:
                            if self.use_zero:
                                if grad_shape[0] % self.dp_size != 0:
                                    # save the full exp_avg_sq_row [H]
                                    state["exp_avg_sq_row"] = torch.zeros(
                                        grad_shape[0], device=grad.device, dtype=p.dtype
                                    )
                                    state["exp_avg_res_row"] = torch.zeros(
                                        grad_shape[0], device=grad.device, dtype=p.dtype
                                    )
                                else:
                                    # exp_avg_sq_row [H // dp]
                                    state["exp_avg_sq_row"] = torch.zeros(
                                        grad_shape[0] // self.dp_size, device=grad.device, dtype=p.dtype
                                    )
                                    state["exp_avg_res_row"] = torch.zeros(
                                        grad_shape[0] // self.dp_size, device=grad.device, dtype=p.dtype
                                    )
                            else:
                                # exp_avg_sq_row [H]
                                state["exp_avg_sq_row"] = torch.zeros(grad_shape[0], device=grad.device, dtype=p.dtype)
                                state["exp_avg_res_row"] = torch.zeros(grad_shape[0], device=grad.device, dtype=p.dtype)
                            # exp_avg_sq_col is always [W]
                            state["exp_avg_sq_col"] = torch.zeros(grad_shape[1], device=grad.device, dtype=p.dtype)
                            state["exp_avg_res_col"] = torch.zeros(grad_shape[1], device=grad.device, dtype=p.dtype)
                    else:
                        state["exp_avg_sq"] = torch.zeros_like(p)
                    state["RMS"] = 0
                else:
                    if factored:
                        state["exp_avg_sq_row"] = state["exp_avg_sq_row"]
                        state["exp_avg_sq_col"] = state["exp_avg_sq_col"]
                        state["exp_avg_res_row"] = state["exp_avg_res_row"]
                        state["exp_avg_res_col"] = state["exp_avg_res_col"]
                    else:
                        state["exp_avg_sq"] = state["exp_avg_sq"]

                state["step"] += 1

                update = (grad**2) + group["eps"][0]
                if factored:
                    if param_is_dtensor:
                        # ==============================
                        # First dim is R, last dim is S{} (split along dim -1) --->
                        # Column Parallel ---> sq_row needs a (col) reduce
                        # ==============================
                        if shard_spec.sharding_sequence[0] == "R":
                            update = self._col_parallel_factor(
                                update,
                                grad,
                                state["exp_avg_sq_row"],
                                state["exp_avg_sq_col"],
                                grad_shape,
                                group["betas"][1],
                            )
                        # ==============================
                        # Last dim is R, first dim is S{} (split along dim 0) --->
                        # Row Parallel ---> sq_col needs a (row) reduce
                        # ==============================
                        elif shard_spec.sharding_sequence[-1] == "R":
                            update = self._row_parallel_factor(
                                update,
                                grad,
                                state["exp_avg_sq_row"],
                                state["exp_avg_sq_col"],
                                grad_shape,
                                group["betas"][1],
                            )
                    else:
                        update = self._base_factor(
                            update,
                            grad,
                            state["exp_avg_sq_row"],
                            state["exp_avg_sq_col"],
                            grad_shape,
                            group["betas"][1],
                        )
                else:
                    exp_avg_sq = state["exp_avg_sq"]
                    exp_avg_sq.mul_(group["betas"][1]).add_(update, alpha=(1.0 - group["betas"][1]))
                    update = exp_avg_sq.rsqrt().mul_(grad)
                rms = self._rms(
                    update,
                    param_is_dtensor,
                    self.use_zero,
                    self.tp_size,
                    self.dp_size,
                    self.tp_group,
                    self.dp_group,
                )

                update.div_((rms / group["clip_threshold"]).clamp_(min=1.0))

                exp_avg = state["exp_avg"]
                exp_avg.mul_(group["betas"][0]).add_(update, alpha=1 - group["betas"][0])
                # Confidence-guided strategy
                # Calculation of instability
                res = (update - exp_avg) ** 2 + group["eps"][1]
                if factored:
                    if param_is_dtensor:
                        # ==============================
                        # First dim is R, last dim is S{} (split along dim -1) --->
                        # Column Parallel ---> res_row needs a (col) reduce
                        # ==============================
                        if shard_spec.sharding_sequence[0] == "R":
                            update = self._col_parallel_factor(
                                res,
                                exp_avg,
                                state["exp_avg_res_row"],
                                state["exp_avg_res_col"],
                                grad_shape,
                                group["betas"][2],
                            )
                        # ==============================
                        # Last dim is R, first dim is S{} (split along dim 0) --->
                        # Row Parallel ---> res_col needs a (row) reduce
                        # ==============================
                        elif shard_spec.sharding_sequence[-1] == "R":
                            update = self._row_parallel_factor(
                                res,
                                exp_avg,
                                state["exp_avg_res_row"],
                                state["exp_avg_res_col"],
                                grad_shape,
                                group["betas"][2],
                            )
                    else:
                        update = self._base_res_factor(
                            res,
                            exp_avg,
                            state["exp_avg_res_row"],
                            state["exp_avg_res_col"],
                            grad_shape,
                            group["betas"][2],
                        )
                else:
                    update = exp_avg

                if group["weight_decay"] != 0:
                    p.add_(p, alpha=-group["weight_decay"] * group["lr"])
                update.mul_(group["lr"])
                p.add_(-update)
        return loss