|
|
|
@ -2,7 +2,7 @@
|
|
|
|
|
import copy
|
|
|
|
|
from contextlib import contextmanager, nullcontext
|
|
|
|
|
from functools import partial
|
|
|
|
|
from typing import Dict, Iterator, List, Optional, Tuple
|
|
|
|
|
from typing import Dict, Iterator, List, Optional, Tuple, Union
|
|
|
|
|
from weakref import proxy
|
|
|
|
|
|
|
|
|
|
import torch
|
|
|
|
@ -23,7 +23,15 @@ from colossalai.logging import get_dist_logger
|
|
|
|
|
from colossalai.quantization.fp8 import all_gather_fp8, all_reduce_fp8, reduce_scatter_fp8
|
|
|
|
|
from colossalai.tensor.moe_tensor.api import is_moe_tensor
|
|
|
|
|
|
|
|
|
|
from ._utils import calculate_global_norm_from_list, has_inf_or_nan, release_param_grad, sync_tensor
|
|
|
|
|
from ._utils import (
|
|
|
|
|
all_gather_into_flat_tensor_nd,
|
|
|
|
|
calculate_global_norm_from_list,
|
|
|
|
|
get_nd_rank,
|
|
|
|
|
get_nd_world_size,
|
|
|
|
|
has_inf_or_nan,
|
|
|
|
|
release_param_grad,
|
|
|
|
|
sync_tensor,
|
|
|
|
|
)
|
|
|
|
|
from .bookkeeping import BucketStore, GradientStore, TensorBucket
|
|
|
|
|
from .zero_hook import set_all_gather_handle, wait_all_gather_handle
|
|
|
|
|
|
|
|
|
@ -68,7 +76,7 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
|
|
|
|
|
def __init__(
|
|
|
|
|
self,
|
|
|
|
|
optimizer: Optimizer,
|
|
|
|
|
pg_to_param_list: Optional[Dict[ProcessGroup, List[nn.Parameter]]] = None,
|
|
|
|
|
pg_to_param_list: Optional[Dict[Union[ProcessGroup, Tuple[ProcessGroup, ...]], List[nn.Parameter]]] = None,
|
|
|
|
|
initial_scale: int = 2**16, # grad scaler config
|
|
|
|
|
min_scale: int = 1,
|
|
|
|
|
growth_factor: float = 2.0,
|
|
|
|
@ -84,6 +92,7 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
|
|
|
|
|
partition_grad: bool = False, # stage 2 flag
|
|
|
|
|
cpu_offload: bool = False, # cpu offload
|
|
|
|
|
dp_process_group: Optional[ProcessGroup] = None,
|
|
|
|
|
extra_dp_group: Optional[ProcessGroup] = None,
|
|
|
|
|
forced_dtype: Optional[torch.dtype] = None,
|
|
|
|
|
master_weights: bool = True, # master weights
|
|
|
|
|
overlap_allgather: bool = False,
|
|
|
|
@ -98,9 +107,17 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
|
|
|
|
|
|
|
|
|
|
if (dp_process_group is not None) and (pg_to_param_list is not None):
|
|
|
|
|
raise ValueError("dp_process_group and pg_to_param_list should not be provided at the same time.")
|
|
|
|
|
if pg_to_param_list is None and extra_dp_group is not None and dp_process_group is None:
|
|
|
|
|
raise ValueError("dp_process_group should be provided when extra_dp_group is provided.")
|
|
|
|
|
if pg_to_param_list is None and extra_dp_group is not None and fp8_communication:
|
|
|
|
|
raise ValueError(
|
|
|
|
|
"fp8_communication is not supported when pg_to_param_list is None and extra_dp_group is provided."
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
if pg_to_param_list is None:
|
|
|
|
|
unique_dp_group = dist.group.WORLD if dp_process_group is None else dp_process_group
|
|
|
|
|
if extra_dp_group is not None:
|
|
|
|
|
unique_dp_group = (extra_dp_group, unique_dp_group)
|
|
|
|
|
pg_to_param_list = {unique_dp_group: []}
|
|
|
|
|
for group in self.optim.param_groups:
|
|
|
|
|
pg_to_param_list[unique_dp_group].extend(group["params"])
|
|
|
|
@ -336,10 +353,12 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
|
|
|
|
|
flat_grads = flat_grads.to(self._communication_dtype)
|
|
|
|
|
|
|
|
|
|
if not self._partition_grads:
|
|
|
|
|
if self._fp8_communication:
|
|
|
|
|
all_reduce_fp8(flat_grads, group=bucket_store.torch_pg)
|
|
|
|
|
else:
|
|
|
|
|
dist.all_reduce(flat_grads, group=bucket_store.torch_pg)
|
|
|
|
|
for i, sz in enumerate(bucket_store.sizes):
|
|
|
|
|
grp = bucket_store.torch_pg if len(bucket_store.sizes) == 1 else bucket_store.torch_pg[i]
|
|
|
|
|
if self._fp8_communication:
|
|
|
|
|
all_reduce_fp8(flat_grads, group=grp)
|
|
|
|
|
else:
|
|
|
|
|
dist.all_reduce(flat_grads, group=grp)
|
|
|
|
|
if flat_grads.dtype != grad_dtype:
|
|
|
|
|
flat_grads = flat_grads.to(grad_dtype)
|
|
|
|
|
|
|
|
|
@ -347,16 +366,20 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
|
|
|
|
|
grad_in_bucket = bucket_store.get_grad()
|
|
|
|
|
self._update_unpartitoned_grad(bucket_store, grad_in_bucket.values(), flat_grads_per_rank, group_id)
|
|
|
|
|
else:
|
|
|
|
|
flat_grads_list = list(flat_grads.split(len(flat_grads) // bucket_store.world_size))
|
|
|
|
|
received_grad = torch.zeros_like(flat_grads_list[0])
|
|
|
|
|
if self._fp8_communication:
|
|
|
|
|
reduce_scatter_fp8(
|
|
|
|
|
received_grad,
|
|
|
|
|
flat_grads_list,
|
|
|
|
|
group=bucket_store.torch_pg,
|
|
|
|
|
)
|
|
|
|
|
else:
|
|
|
|
|
dist.reduce_scatter(received_grad, flat_grads_list, group=bucket_store.torch_pg)
|
|
|
|
|
cur_flat_grads = flat_grads
|
|
|
|
|
for i, sz in enumerate(bucket_store.sizes):
|
|
|
|
|
grp = bucket_store.torch_pg if len(bucket_store.sizes) == 1 else bucket_store.torch_pg[i]
|
|
|
|
|
flat_grads_list = list(cur_flat_grads.split(len(cur_flat_grads) // sz))
|
|
|
|
|
received_grad = torch.zeros_like(flat_grads_list[0])
|
|
|
|
|
if self._fp8_communication:
|
|
|
|
|
reduce_scatter_fp8(
|
|
|
|
|
received_grad,
|
|
|
|
|
flat_grads_list,
|
|
|
|
|
group=grp,
|
|
|
|
|
)
|
|
|
|
|
else:
|
|
|
|
|
dist.reduce_scatter_tensor(received_grad, cur_flat_grads, group=grp)
|
|
|
|
|
cur_flat_grads = received_grad
|
|
|
|
|
|
|
|
|
|
if received_grad.dtype != grad_dtype:
|
|
|
|
|
received_grad = received_grad.to(grad_dtype)
|
|
|
|
@ -577,11 +600,13 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
|
|
|
|
|
pg = self.param_to_pg[working_param]
|
|
|
|
|
padded_working_param = self._working_param_to_padded_working_param[working_param]
|
|
|
|
|
if self._overlap_allgather:
|
|
|
|
|
handle = dist.all_gather_into_tensor(padded_working_param, param_to_gather, pg, async_op=True)
|
|
|
|
|
# handle = dist.all_gather_into_tensor(padded_working_param, param_to_gather, pg, async_op=True)
|
|
|
|
|
handle = all_gather_into_flat_tensor_nd(padded_working_param, param_to_gather, pg, async_op=True)
|
|
|
|
|
set_all_gather_handle(working_param, handle)
|
|
|
|
|
else:
|
|
|
|
|
if param_to_gather.numel() > self.pg_to_tensor_bucket[pg].max_size:
|
|
|
|
|
if self._fp8_communication:
|
|
|
|
|
# TODO: fit fp8 communication
|
|
|
|
|
all_gather_fp8(
|
|
|
|
|
list(padded_working_param.chunk(dist.get_world_size(pg))),
|
|
|
|
|
param_to_gather,
|
|
|
|
@ -589,7 +614,8 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
|
|
|
|
|
fp8_format="e4m3",
|
|
|
|
|
)
|
|
|
|
|
else:
|
|
|
|
|
dist.all_gather_into_tensor(padded_working_param, param_to_gather, pg)
|
|
|
|
|
# dist.all_gather_into_tensor(padded_working_param, param_to_gather, pg)
|
|
|
|
|
all_gather_into_flat_tensor_nd(padded_working_param, param_to_gather, pg)
|
|
|
|
|
continue
|
|
|
|
|
try:
|
|
|
|
|
self.pg_to_tensor_bucket[pg].add_to_bucket(param_to_gather, write_back_tensor=working_param)
|
|
|
|
@ -602,7 +628,9 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
|
|
|
|
|
if not tensor_bucket.is_empty():
|
|
|
|
|
tensor_bucket.all_gather(pg, fp8_communication=self._fp8_communication)
|
|
|
|
|
|
|
|
|
|
def _compute_grad_norm(self, dp_pg: ProcessGroup, gradients: List[Tensor], norm_type: int = 2) -> float:
|
|
|
|
|
def _compute_grad_norm(
|
|
|
|
|
self, dp_pg: Union[ProcessGroup, Tuple[ProcessGroup, ...]], gradients: List[Tensor], norm_type: int = 2
|
|
|
|
|
) -> float:
|
|
|
|
|
r"""
|
|
|
|
|
Compute and return the gradient norm for gradient clipping.
|
|
|
|
|
|
|
|
|
@ -625,7 +653,11 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
|
|
|
|
|
device=get_accelerator().get_current_device(),
|
|
|
|
|
dtype=torch.float,
|
|
|
|
|
)
|
|
|
|
|
dist.all_reduce(total_norm_cuda, op=torch.distributed.ReduceOp.MAX, group=dp_pg)
|
|
|
|
|
if isinstance(dp_pg, tuple):
|
|
|
|
|
for grp in dp_pg:
|
|
|
|
|
dist.all_reduce(total_norm_cuda, op=torch.distributed.ReduceOp.MAX, group=grp)
|
|
|
|
|
else:
|
|
|
|
|
dist.all_reduce(total_norm_cuda, op=torch.distributed.ReduceOp.MAX, group=dp_pg)
|
|
|
|
|
total_norm = total_norm_cuda.item()
|
|
|
|
|
|
|
|
|
|
else:
|
|
|
|
@ -640,11 +672,19 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
|
|
|
|
|
device=get_accelerator().get_current_device(),
|
|
|
|
|
dtype=torch.float,
|
|
|
|
|
)
|
|
|
|
|
torch.distributed.all_reduce(
|
|
|
|
|
total_norm_exponentiated_cuda,
|
|
|
|
|
op=torch.distributed.ReduceOp.SUM,
|
|
|
|
|
group=dp_pg,
|
|
|
|
|
)
|
|
|
|
|
if isinstance(dp_pg, tuple):
|
|
|
|
|
for grp in dp_pg:
|
|
|
|
|
dist.all_reduce(
|
|
|
|
|
total_norm_exponentiated_cuda,
|
|
|
|
|
op=torch.distributed.ReduceOp.SUM,
|
|
|
|
|
group=grp,
|
|
|
|
|
)
|
|
|
|
|
else:
|
|
|
|
|
torch.distributed.all_reduce(
|
|
|
|
|
total_norm_exponentiated_cuda,
|
|
|
|
|
op=torch.distributed.ReduceOp.SUM,
|
|
|
|
|
group=dp_pg,
|
|
|
|
|
)
|
|
|
|
|
total_norm = total_norm_exponentiated_cuda.item() ** (1.0 / norm_type)
|
|
|
|
|
|
|
|
|
|
return total_norm
|
|
|
|
@ -744,11 +784,9 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
|
|
|
|
|
if isinstance(v, torch.Tensor) and k != "step":
|
|
|
|
|
working_param = self.master_to_working_param[id(param)]
|
|
|
|
|
pg = self.param_to_pg[working_param]
|
|
|
|
|
gather_tensor = [torch.zeros(v.shape, device=device, dtype=v.dtype) for _ in range(pg.size())]
|
|
|
|
|
dist.all_gather(gather_tensor, v.to(device), group=pg)
|
|
|
|
|
param_state = (
|
|
|
|
|
torch.stack(gather_tensor).view(-1)[: working_param.numel()].reshape_as(working_param).cpu()
|
|
|
|
|
)
|
|
|
|
|
gathered_tensor = torch.empty(v.numel() * get_nd_world_size(pg), device=device, dtype=v.dtype)
|
|
|
|
|
all_gather_into_flat_tensor_nd(gathered_tensor, v.to(device).flatten(), pg)
|
|
|
|
|
param_state = gathered_tensor[: working_param.numel()].reshape_as(working_param).cpu()
|
|
|
|
|
zero_state[param][k] = param_state
|
|
|
|
|
|
|
|
|
|
states_dict = self._pack_state(zero_state)
|
|
|
|
@ -770,15 +808,17 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
|
|
|
|
|
cnt += 1
|
|
|
|
|
for param_idx, state in zero_state_dict["state"].items():
|
|
|
|
|
pg = self.param_to_pg[self.master_to_working_param[id(idx2master[param_idx])]]
|
|
|
|
|
world_size = get_nd_world_size(pg)
|
|
|
|
|
rank = get_nd_rank(pg)
|
|
|
|
|
for k, v in state.items():
|
|
|
|
|
if isinstance(v, torch.Tensor) and k != "step":
|
|
|
|
|
padding_size = (pg.size() - v.numel() % pg.size()) % pg.size()
|
|
|
|
|
padding_size = (world_size - v.numel() % world_size) % world_size
|
|
|
|
|
with torch.no_grad():
|
|
|
|
|
v = v.flatten()
|
|
|
|
|
if padding_size > 0:
|
|
|
|
|
v = torch.nn.functional.pad(v, [0, padding_size])
|
|
|
|
|
v_list = v.split(v.numel() // pg.size())
|
|
|
|
|
zero_state_dict["state"][param_idx][k] = v_list[pg.rank()].detach().clone()
|
|
|
|
|
v_list = v.split(v.numel() // world_size)
|
|
|
|
|
zero_state_dict["state"][param_idx][k] = v_list[rank].detach().clone()
|
|
|
|
|
|
|
|
|
|
self.optim.load_state_dict(zero_state_dict)
|
|
|
|
|
|
|
|
|
@ -814,11 +854,9 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
|
|
|
|
|
|
|
|
|
|
for k, v in states.items():
|
|
|
|
|
if isinstance(v, torch.Tensor) and k != "step":
|
|
|
|
|
state_tensor = [torch.zeros(v.shape, device=device, dtype=v.dtype) for _ in range(pg.size())]
|
|
|
|
|
dist.all_gather(state_tensor, v.to(device), group=pg)
|
|
|
|
|
state_tensor = (
|
|
|
|
|
torch.stack(state_tensor).view(-1)[: working_param.numel()].reshape_as(working_param).cpu()
|
|
|
|
|
)
|
|
|
|
|
state_tensor = torch.empty(v.numel() * get_nd_world_size(pg), device=device, dtype=v.dtype)
|
|
|
|
|
all_gather_into_flat_tensor_nd(state_tensor, v.to(device).flatten(), pg)
|
|
|
|
|
state_tensor = state_tensor[: working_param.numel()].reshape_as(working_param).cpu()
|
|
|
|
|
current_block_size += state_tensor.numel()
|
|
|
|
|
current_block[k] = state_tensor
|
|
|
|
|
|
|
|
|
@ -842,12 +880,14 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
|
|
|
|
|
p_id = id(p)
|
|
|
|
|
if p_id in self.working_to_master_param:
|
|
|
|
|
pg = self.param_to_pg[p]
|
|
|
|
|
world_size = get_nd_world_size(pg)
|
|
|
|
|
rank = get_nd_rank(pg)
|
|
|
|
|
master_param = self.working_to_master_param[p_id]
|
|
|
|
|
padding_size = self.get_param_padding_size(p)
|
|
|
|
|
working_param = p.data.view(-1)
|
|
|
|
|
if padding_size > 0:
|
|
|
|
|
working_param = torch.nn.functional.pad(working_param, [0, padding_size])
|
|
|
|
|
master_param.copy_(working_param.chunk(pg.size())[pg.rank()])
|
|
|
|
|
master_param.copy_(working_param.chunk(world_size)[rank])
|
|
|
|
|
|
|
|
|
|
def get_working_to_master_map(self) -> Dict[int, torch.Tensor]:
|
|
|
|
|
return self.working_to_master_param
|
|
|
|
@ -905,9 +945,12 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
|
|
|
|
|
grad = grad_store.get_working_grad_by_param_id(id(working_param))
|
|
|
|
|
if grad is None:
|
|
|
|
|
return None
|
|
|
|
|
grad_flat = torch.empty((grad_store.world_size, *grad.shape), dtype=grad.dtype, device=grad.device)
|
|
|
|
|
dist.all_gather_into_tensor(grad_flat, grad, group=grad_store.torch_pg)
|
|
|
|
|
return grad_flat.view(-1)[: working_param.numel()].view_as(working_param)
|
|
|
|
|
grad_flat = grad.flatten()
|
|
|
|
|
output_grad = torch.empty(
|
|
|
|
|
grad_flat.numel() * grad_store.world_size, device=grad_flat.device, dtype=grad_flat.dtype
|
|
|
|
|
)
|
|
|
|
|
all_gather_into_flat_tensor_nd(output_grad, grad_flat, grad_store.torch_pg)
|
|
|
|
|
return output_grad.view(-1)[: working_param.numel()].view_as(working_param)
|
|
|
|
|
|
|
|
|
|
def get_working_grads_by_group_id(self, group_id: int) -> List[Tensor]:
|
|
|
|
|
working_grads = []
|
|
|
|
|