@@ -2,7 +2,7 @@
 import copy
 from contextlib import contextmanager, nullcontext
 from functools import partial
-from typing import Dict, Iterator, List, Optional, Tuple
+from typing import Dict, Iterator, List, Optional, Tuple, Union
 from weakref import proxy
 
 import torch
@@ -23,7 +23,15 @@ from colossalai.logging import get_dist_logger
 from colossalai.quantization.fp8 import all_gather_fp8, all_reduce_fp8, reduce_scatter_fp8
 from colossalai.tensor.moe_tensor.api import is_moe_tensor
 
-from ._utils import calculate_global_norm_from_list, has_inf_or_nan, release_param_grad, sync_tensor
+from ._utils import (
+    all_gather_into_flat_tensor_nd,
+    calculate_global_norm_from_list,
+    get_nd_rank,
+    get_nd_world_size,
+    has_inf_or_nan,
+    release_param_grad,
+    sync_tensor,
+)
 from .bookkeeping import BucketStore, GradientStore, TensorBucket
 from .zero_hook import set_all_gather_handle, wait_all_gather_handle
@@ -68,7 +76,7 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
     def __init__(
         self,
         optimizer: Optimizer,
-        pg_to_param_list: Optional[Dict[ProcessGroup, List[nn.Parameter]]] = None,
+        pg_to_param_list: Optional[Dict[Union[ProcessGroup, Tuple[ProcessGroup, ...]], List[nn.Parameter]]] = None,
         initial_scale: int = 2**16,  # grad scaler config
         min_scale: int = 1,
         growth_factor: float = 2.0,
@@ -84,6 +92,7 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
         partition_grad: bool = False,  # stage 2 flag
         cpu_offload: bool = False,  # cpu offload
         dp_process_group: Optional[ProcessGroup] = None,
+        extra_dp_group: Optional[ProcessGroup] = None,
         forced_dtype: Optional[torch.dtype] = None,
         master_weights: bool = True,  # master weights
         overlap_allgather: bool = False,
@@ -98,9 +107,17 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
         if (dp_process_group is not None) and (pg_to_param_list is not None):
             raise ValueError("dp_process_group and pg_to_param_list should not be provided at the same time.")
+        if pg_to_param_list is None and extra_dp_group is not None and dp_process_group is None:
+            raise ValueError("dp_process_group should be provided when extra_dp_group is provided.")
+        if pg_to_param_list is None and extra_dp_group is not None and fp8_communication:
+            raise ValueError(
+                "fp8_communication is not supported when pg_to_param_list is None and extra_dp_group is provided."
+            )
 
         if pg_to_param_list is None:
             unique_dp_group = dist.group.WORLD if dp_process_group is None else dp_process_group
+            if extra_dp_group is not None:
+                unique_dp_group = (extra_dp_group, unique_dp_group)
             pg_to_param_list = {unique_dp_group: []}
             for group in self.optim.param_groups:
                 pg_to_param_list[unique_dp_group].extend(group["params"])
 
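
Note: when extra_dp_group is set, the bucket key becomes the tuple (extra_dp_group, unique_dp_group), and the rest of the optimizer treats that tuple as a single logical n-dimensional process group. Below is a minimal sketch of how the get_nd_world_size / get_nd_rank helpers imported from ._utils can compose such a tuple; the function names here are hypothetical stand-ins and the bodies are illustrative assumptions, not the actual _utils implementation.

import torch.distributed as dist

def nd_world_size_sketch(pg) -> int:
    # A tuple of groups acts as the Cartesian product of its members,
    # so the effective world size is the product of per-group sizes.
    groups = pg if isinstance(pg, tuple) else (pg,)
    size = 1
    for grp in groups:
        size *= dist.get_world_size(grp)
    return size

def nd_rank_sketch(pg) -> int:
    # One plausible linearization (assumption): row-major over the
    # per-group ranks.
    groups = pg if isinstance(pg, tuple) else (pg,)
    rank = 0
    for grp in groups:
        rank = rank * dist.get_world_size(grp) + dist.get_rank(grp)
    return rank
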
@@ -336,10 +353,12 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
                 flat_grads = flat_grads.to(self._communication_dtype)
 
             if not self._partition_grads:
-                if self._fp8_communication:
-                    all_reduce_fp8(flat_grads, group=bucket_store.torch_pg)
-                else:
-                    dist.all_reduce(flat_grads, group=bucket_store.torch_pg)
+                for i, sz in enumerate(bucket_store.sizes):
+                    grp = bucket_store.torch_pg if len(bucket_store.sizes) == 1 else bucket_store.torch_pg[i]
+                    if self._fp8_communication:
+                        all_reduce_fp8(flat_grads, group=grp)
+                    else:
+                        dist.all_reduce(flat_grads, group=grp)
                 if flat_grads.dtype != grad_dtype:
                     flat_grads = flat_grads.to(grad_dtype)
 
@@ -347,16 +366,20 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
                 grad_in_bucket = bucket_store.get_grad()
                 self._update_unpartitoned_grad(bucket_store, grad_in_bucket.values(), flat_grads_per_rank, group_id)
             else:
-                flat_grads_list = list(flat_grads.split(len(flat_grads) // bucket_store.world_size))
-                received_grad = torch.zeros_like(flat_grads_list[0])
-                if self._fp8_communication:
-                    reduce_scatter_fp8(
-                        received_grad,
-                        flat_grads_list,
-                        group=bucket_store.torch_pg,
-                    )
-                else:
-                    dist.reduce_scatter(received_grad, flat_grads_list, group=bucket_store.torch_pg)
+                cur_flat_grads = flat_grads
+                for i, sz in enumerate(bucket_store.sizes):
+                    grp = bucket_store.torch_pg if len(bucket_store.sizes) == 1 else bucket_store.torch_pg[i]
+                    flat_grads_list = list(cur_flat_grads.split(len(cur_flat_grads) // sz))
+                    received_grad = torch.zeros_like(flat_grads_list[0])
+                    if self._fp8_communication:
+                        reduce_scatter_fp8(
+                            received_grad,
+                            flat_grads_list,
+                            group=grp,
+                        )
+                    else:
+                        dist.reduce_scatter_tensor(received_grad, cur_flat_grads, group=grp)
+                    cur_flat_grads = received_grad
 
                 if received_grad.dtype != grad_dtype:
                     received_grad = received_grad.to(grad_dtype)
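
Note: the stage-2 path above now chains one reduce-scatter per group in the tuple. Each stage reduces across one group and keeps a 1/size shard, so after the last stage every rank holds its shard of the fully reduced gradient. A self-contained sketch of that loop in isolation, assuming the flat tensor's length is divisible by every group's world size (the bucket padding guarantees this in the surrounding code):

import torch
import torch.distributed as dist

def chained_reduce_scatter(flat_grads: torch.Tensor, groups: tuple) -> torch.Tensor:
    # Reduce-scatter over each group in turn; the kept shard shrinks by
    # that group's world size at every stage.
    cur = flat_grads
    for grp in groups:
        out = torch.empty(cur.numel() // dist.get_world_size(grp), dtype=cur.dtype, device=cur.device)
        dist.reduce_scatter_tensor(out, cur, group=grp)
        cur = out
    return cur
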
@@ -577,11 +600,13 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
             pg = self.param_to_pg[working_param]
             padded_working_param = self._working_param_to_padded_working_param[working_param]
             if self._overlap_allgather:
-                handle = dist.all_gather_into_tensor(padded_working_param, param_to_gather, pg, async_op=True)
+                # handle = dist.all_gather_into_tensor(padded_working_param, param_to_gather, pg, async_op=True)
+                handle = all_gather_into_flat_tensor_nd(padded_working_param, param_to_gather, pg, async_op=True)
                 set_all_gather_handle(working_param, handle)
             else:
                 if param_to_gather.numel() > self.pg_to_tensor_bucket[pg].max_size:
                     if self._fp8_communication:
+                        # TODO: fit fp8 communication
                         all_gather_fp8(
                             list(padded_working_param.chunk(dist.get_world_size(pg))),
                             param_to_gather,
@@ -589,7 +614,8 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
                             fp8_format="e4m3",
                         )
                     else:
-                        dist.all_gather_into_tensor(padded_working_param, param_to_gather, pg)
+                        # dist.all_gather_into_tensor(padded_working_param, param_to_gather, pg)
+                        all_gather_into_flat_tensor_nd(padded_working_param, param_to_gather, pg)
                     continue
                 try:
                     self.pg_to_tensor_bucket[pg].add_to_bucket(param_to_gather, write_back_tensor=working_param)
@@ -602,7 +628,9 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
             if not tensor_bucket.is_empty():
                 tensor_bucket.all_gather(pg, fp8_communication=self._fp8_communication)
 
-    def _compute_grad_norm(self, dp_pg: ProcessGroup, gradients: List[Tensor], norm_type: int = 2) -> float:
+    def _compute_grad_norm(
+        self, dp_pg: Union[ProcessGroup, Tuple[ProcessGroup, ...]], gradients: List[Tensor], norm_type: int = 2
+    ) -> float:
         r"""
         Compute and return the gradient norm for gradient clipping.
 
@@ -625,7 +653,11 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
                 device=get_accelerator().get_current_device(),
                 dtype=torch.float,
             )
-            dist.all_reduce(total_norm_cuda, op=torch.distributed.ReduceOp.MAX, group=dp_pg)
+            if isinstance(dp_pg, tuple):
+                for grp in dp_pg:
+                    dist.all_reduce(total_norm_cuda, op=torch.distributed.ReduceOp.MAX, group=grp)
+            else:
+                dist.all_reduce(total_norm_cuda, op=torch.distributed.ReduceOp.MAX, group=dp_pg)
             total_norm = total_norm_cuda.item()
 
         else:
@@ -640,11 +672,19 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
                 device=get_accelerator().get_current_device(),
                 dtype=torch.float,
             )
-            torch.distributed.all_reduce(
-                total_norm_exponentiated_cuda,
-                op=torch.distributed.ReduceOp.SUM,
-                group=dp_pg,
-            )
+            if isinstance(dp_pg, tuple):
+                for grp in dp_pg:
+                    dist.all_reduce(
+                        total_norm_exponentiated_cuda,
+                        op=torch.distributed.ReduceOp.SUM,
+                        group=grp,
+                    )
+            else:
+                torch.distributed.all_reduce(
+                    total_norm_exponentiated_cuda,
+                    op=torch.distributed.ReduceOp.SUM,
+                    group=dp_pg,
+                )
             total_norm = total_norm_exponentiated_cuda.item() ** (1.0 / norm_type)
 
         return total_norm
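
Note: both branches above compute the same thing because SUM and MAX commute across orthogonal groups: reducing over group A and then over group B equals one reduction over the product group A x B. A hypothetical helper (not part of this patch) that would collapse the duplicated if/else:

import torch.distributed as dist

def all_reduce_nd(tensor, dp_pg, op) -> None:
    # Sequential all-reduce per orthogonal group == one all-reduce over
    # the product group, for associative/commutative ops (SUM, MAX).
    groups = dp_pg if isinstance(dp_pg, tuple) else (dp_pg,)
    for grp in groups:
        dist.all_reduce(tensor, op=op, group=grp)
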
@@ -744,11 +784,9 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
                 if isinstance(v, torch.Tensor) and k != "step":
                     working_param = self.master_to_working_param[id(param)]
                     pg = self.param_to_pg[working_param]
-                    gather_tensor = [torch.zeros(v.shape, device=device, dtype=v.dtype) for _ in range(pg.size())]
-                    dist.all_gather(gather_tensor, v.to(device), group=pg)
-                    param_state = (
-                        torch.stack(gather_tensor).view(-1)[: working_param.numel()].reshape_as(working_param).cpu()
-                    )
+                    gathered_tensor = torch.empty(v.numel() * get_nd_world_size(pg), device=device, dtype=v.dtype)
+                    all_gather_into_flat_tensor_nd(gathered_tensor, v.to(device).flatten(), pg)
+                    param_state = gathered_tensor[: working_param.numel()].reshape_as(working_param).cpu()
                     zero_state[param][k] = param_state
 
         states_dict = self._pack_state(zero_state)
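
Note: all_gather_into_flat_tensor_nd replaces the list-based dist.all_gather with a flat gather that also accepts tuple groups; the output must hold local.numel() * nd-world-size elements. A sketch of that contract, under the assumption that the gather mirrors the chained reduce-scatter in reverse group order (the real helper lives in ._utils and may differ):

import torch
import torch.distributed as dist

def all_gather_flat_nd_sketch(output: torch.Tensor, local: torch.Tensor, pg) -> None:
    # Grow the flat tensor by one group's world size per stage.
    groups = pg if isinstance(pg, tuple) else (pg,)
    cur = local
    for grp in reversed(groups):  # assumption: undo the scatter order
        out = torch.empty(cur.numel() * dist.get_world_size(grp), dtype=cur.dtype, device=cur.device)
        dist.all_gather_into_tensor(out, cur, group=grp)
        cur = out
    output.copy_(cur)
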
@@ -770,15 +808,17 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
             cnt += 1
 
         for param_idx, state in zero_state_dict["state"].items():
             pg = self.param_to_pg[self.master_to_working_param[id(idx2master[param_idx])]]
+            world_size = get_nd_world_size(pg)
+            rank = get_nd_rank(pg)
             for k, v in state.items():
                 if isinstance(v, torch.Tensor) and k != "step":
-                    padding_size = (pg.size() - v.numel() % pg.size()) % pg.size()
+                    padding_size = (world_size - v.numel() % world_size) % world_size
                     with torch.no_grad():
                         v = v.flatten()
                         if padding_size > 0:
                             v = torch.nn.functional.pad(v, [0, padding_size])
-                        v_list = v.split(v.numel() // pg.size())
-                        zero_state_dict["state"][param_idx][k] = v_list[pg.rank()].detach().clone()
+                        v_list = v.split(v.numel() // world_size)
+                        zero_state_dict["state"][param_idx][k] = v_list[rank].detach().clone()
 
         self.optim.load_state_dict(zero_state_dict)
@@ -814,11 +854,9 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
 
             for k, v in states.items():
                 if isinstance(v, torch.Tensor) and k != "step":
-                    state_tensor = [torch.zeros(v.shape, device=device, dtype=v.dtype) for _ in range(pg.size())]
-                    dist.all_gather(state_tensor, v.to(device), group=pg)
-                    state_tensor = (
-                        torch.stack(state_tensor).view(-1)[: working_param.numel()].reshape_as(working_param).cpu()
-                    )
+                    state_tensor = torch.empty(v.numel() * get_nd_world_size(pg), device=device, dtype=v.dtype)
+                    all_gather_into_flat_tensor_nd(state_tensor, v.to(device).flatten(), pg)
+                    state_tensor = state_tensor[: working_param.numel()].reshape_as(working_param).cpu()
                 current_block_size += state_tensor.numel()
                 current_block[k] = state_tensor
 
@@ -842,12 +880,14 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
             p_id = id(p)
             if p_id in self.working_to_master_param:
                 pg = self.param_to_pg[p]
+                world_size = get_nd_world_size(pg)
+                rank = get_nd_rank(pg)
                 master_param = self.working_to_master_param[p_id]
                 padding_size = self.get_param_padding_size(p)
                 working_param = p.data.view(-1)
                 if padding_size > 0:
                     working_param = torch.nn.functional.pad(working_param, [0, padding_size])
-                master_param.copy_(working_param.chunk(pg.size())[pg.rank()])
+                master_param.copy_(working_param.chunk(world_size)[rank])
 
     def get_working_to_master_map(self) -> Dict[int, torch.Tensor]:
         return self.working_to_master_param
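
Note: the shard arithmetic in update_master_params is easy to verify by hand. A worked example with a hypothetical 2 x 3 tuple group, i.e. a combined world size of 6, and a 10-element working param:

import torch

world_size, rank = 6, 4  # assumed nd world size and linearized rank
p = torch.arange(10, dtype=torch.float32)
padding_size = (world_size - p.numel() % world_size) % world_size  # -> 2
working_param = torch.nn.functional.pad(p.view(-1), [0, padding_size])  # 12 elements
shard = working_param.chunk(world_size)[rank]  # rank 4 keeps tensor([8., 9.])
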
@@ -905,9 +945,12 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
         grad = grad_store.get_working_grad_by_param_id(id(working_param))
         if grad is None:
             return None
-        grad_flat = torch.empty((grad_store.world_size, *grad.shape), dtype=grad.dtype, device=grad.device)
-        dist.all_gather_into_tensor(grad_flat, grad, group=grad_store.torch_pg)
-        return grad_flat.view(-1)[: working_param.numel()].view_as(working_param)
+        grad_flat = grad.flatten()
+        output_grad = torch.empty(
+            grad_flat.numel() * grad_store.world_size, device=grad_flat.device, dtype=grad_flat.dtype
+        )
+        all_gather_into_flat_tensor_nd(output_grad, grad_flat, grad_store.torch_pg)
+        return output_grad.view(-1)[: working_param.numel()].view_as(working_param)
 
     def get_working_grads_by_group_id(self, group_id: int) -> List[Tensor]:
         working_grads = []