mirror of https://github.com/hpcaitech/ColossalAI
[zero] support extra dp (#6123)
* [zero] support extra dp * [zero] update checkpoint * fix bugs * fix bugsfeature/async-io
parent
30a9443132
commit
a2596519fd
|
@ -29,6 +29,7 @@ from colossalai.checkpoint_io.utils import (
|
|||
save_state_dict,
|
||||
sharded_optimizer_loading_epilogue,
|
||||
)
|
||||
from colossalai.cluster import ProcessGroupMesh
|
||||
from colossalai.interface import AMPModelMixin, ModelWrapper, OptimizerWrapper
|
||||
from colossalai.interface.optimizer import DistributedOptim
|
||||
from colossalai.logging import get_dist_logger
|
||||
|
@ -333,6 +334,7 @@ class LowLevelZeroPlugin(DPPluginBase):
|
|||
verbose (bool, optional): verbose mode. Debug info including grad overflow will be printed. Defaults to False.
|
||||
use_fp8 (bool, optional): Whether to enable fp8 mixed precision training. Defaults to False.
|
||||
fp8_communication (bool, optional): Whether to enable fp8 communication. Defaults to False.
|
||||
extra_dp_size (int, optional): The number of extra data parallel groups. Defaults to 1.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
|
@ -358,11 +360,16 @@ class LowLevelZeroPlugin(DPPluginBase):
|
|||
cast_inputs: bool = True,
|
||||
fp8_communication: bool = False,
|
||||
use_fp8: bool = False,
|
||||
extra_dp_size: int = 1,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
assert stage in (1, 2), f"LowLevelZeroPlugin only supports stage 1/2 training"
|
||||
assert precision in SUPPORTED_PRECISION, f"LowLevelZeroPlugin only supports {SUPPORTED_PRECISION} training"
|
||||
assert norm_type == 2.0, f"LowLevelZeroPlugin only supports norm_type=2.0 now"
|
||||
if extra_dp_size > 1:
|
||||
assert dist.get_world_size() % extra_dp_size == 0, "extra_dp_size should be a factor of world_size"
|
||||
inner_dp_size = dist.get_world_size() // extra_dp_size
|
||||
self.pg_mesh = ProcessGroupMesh(extra_dp_size, inner_dp_size)
|
||||
self.stage = stage
|
||||
self.precision = precision
|
||||
self.zero_optim_kwargs = dict(
|
||||
|
@ -383,6 +390,9 @@ class LowLevelZeroPlugin(DPPluginBase):
|
|||
overlap_allgather=overlap_allgather,
|
||||
fp8_communication=fp8_communication,
|
||||
)
|
||||
if extra_dp_size > 1:
|
||||
self.zero_optim_kwargs["extra_dp_group"] = self.pg_mesh.get_group_along_axis(0)
|
||||
self.zero_optim_kwargs["dp_process_group"] = self.pg_mesh.get_group_along_axis(1)
|
||||
self.lora_enabled = False
|
||||
self.verbose = verbose
|
||||
self.logger = get_dist_logger()
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
import math
|
||||
from typing import Optional
|
||||
from typing import Optional, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
|
||||
|
@ -209,3 +210,42 @@ def sync_tensor(flat_tensor, tensor_list):
|
|||
# update the tensor data
|
||||
for p, q in zip(tensor_list, updated_params):
|
||||
p.data = q.data
|
||||
|
||||
|
||||
def all_gather_into_flat_tensor_nd(
|
||||
output_tensor: torch.Tensor,
|
||||
input_tensor: torch.Tensor,
|
||||
group: Union[dist.ProcessGroup, Tuple[dist.ProcessGroup, ...]],
|
||||
async_op: bool = False,
|
||||
):
|
||||
if isinstance(group, dist.ProcessGroup):
|
||||
group = (group,)
|
||||
sizes = [dist.get_world_size(pg) for pg in group]
|
||||
ranks = [dist.get_rank(pg) for pg in group]
|
||||
for i, pg in list(enumerate(group))[::-1]:
|
||||
if i == 0:
|
||||
out = output_tensor
|
||||
else:
|
||||
prev_sizes = sizes[:i]
|
||||
prev_ranks = ranks[:i]
|
||||
chunks = output_tensor.chunk(np.prod(prev_sizes))
|
||||
out = chunks[np.ravel_multi_index(prev_ranks, prev_sizes)]
|
||||
handle = dist.all_gather_into_tensor(out, input_tensor, group=pg, async_op=async_op)
|
||||
input_tensor = out
|
||||
return handle
|
||||
|
||||
|
||||
def get_nd_world_size(group) -> int:
|
||||
if isinstance(group, tuple):
|
||||
return int(np.prod([dist.get_world_size(pg) for pg in group]))
|
||||
else:
|
||||
return dist.get_world_size(group)
|
||||
|
||||
|
||||
def get_nd_rank(group) -> int:
|
||||
if isinstance(group, tuple):
|
||||
return np.ravel_multi_index(
|
||||
tuple(dist.get_rank(group=pg) for pg in group), [dist.get_world_size(pg) for pg in group]
|
||||
)
|
||||
else:
|
||||
return dist.get_rank(group)
|
||||
|
|
|
@ -1,11 +1,20 @@
|
|||
from typing import Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
import torch.distributed as dist
|
||||
from torch.distributed import ProcessGroup
|
||||
|
||||
|
||||
class BaseStore:
|
||||
def __init__(self, torch_pg: ProcessGroup):
|
||||
self._world_size = dist.get_world_size(group=torch_pg)
|
||||
self._local_rank = dist.get_rank(group=torch_pg)
|
||||
def __init__(self, torch_pg: Union[ProcessGroup, Tuple[ProcessGroup, ...]]):
|
||||
if isinstance(torch_pg, tuple):
|
||||
self.sizes = [dist.get_world_size(group=pg) for pg in torch_pg]
|
||||
self._world_size = int(np.prod(self.sizes))
|
||||
self._local_rank = np.ravel_multi_index(tuple(dist.get_rank(group=pg) for pg in torch_pg), self.sizes)
|
||||
else:
|
||||
self._world_size = dist.get_world_size(group=torch_pg)
|
||||
self._local_rank = dist.get_rank(group=torch_pg)
|
||||
self.sizes = [self._world_size]
|
||||
self.torch_pg = torch_pg
|
||||
|
||||
@property
|
||||
|
|
|
@ -1,10 +1,12 @@
|
|||
from typing import Optional
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
|
||||
|
||||
from colossalai.quantization.fp8 import all_gather_fp8
|
||||
from colossalai.zero.low_level._utils import all_gather_into_flat_tensor_nd
|
||||
|
||||
|
||||
class TensorBucket:
|
||||
|
@ -65,12 +67,18 @@ class TensorBucket:
|
|||
|
||||
def all_gather(self, group=None, fp8_communication: bool = False):
|
||||
flat = self.flatten()
|
||||
buffer = torch.empty(flat.numel() * dist.get_world_size(group), device=flat.device, dtype=flat.dtype)
|
||||
if isinstance(group, tuple):
|
||||
world_size = np.prod([dist.get_world_size(pg) for pg in group])
|
||||
else:
|
||||
world_size = dist.get_world_size(group)
|
||||
buffer = torch.empty(flat.numel() * world_size, device=flat.device, dtype=flat.dtype)
|
||||
if fp8_communication:
|
||||
# TODO: fit fp8
|
||||
all_gather_fp8(list(buffer.chunk(dist.get_world_size(group))), flat, group=group, fp8_format="e4m3")
|
||||
else:
|
||||
dist.all_gather_into_tensor(buffer, flat, group=group)
|
||||
unflat_buffers = [self.unflatten(buffer) for buffer in buffer.chunk(dist.get_world_size(group))]
|
||||
# dist.all_gather_into_tensor(buffer, flat, group=group)
|
||||
all_gather_into_flat_tensor_nd(buffer, flat, group=group)
|
||||
unflat_buffers = [self.unflatten(buffer) for buffer in buffer.chunk(world_size)]
|
||||
# transpose the list of list
|
||||
unflat_buffers = list(map(list, zip(*unflat_buffers)))
|
||||
for unflat_shards, tensor in zip(unflat_buffers, self._bucket):
|
||||
|
|
|
@ -2,7 +2,7 @@
|
|||
import copy
|
||||
from contextlib import contextmanager, nullcontext
|
||||
from functools import partial
|
||||
from typing import Dict, Iterator, List, Optional, Tuple
|
||||
from typing import Dict, Iterator, List, Optional, Tuple, Union
|
||||
from weakref import proxy
|
||||
|
||||
import torch
|
||||
|
@ -23,7 +23,15 @@ from colossalai.logging import get_dist_logger
|
|||
from colossalai.quantization.fp8 import all_gather_fp8, all_reduce_fp8, reduce_scatter_fp8
|
||||
from colossalai.tensor.moe_tensor.api import is_moe_tensor
|
||||
|
||||
from ._utils import calculate_global_norm_from_list, has_inf_or_nan, release_param_grad, sync_tensor
|
||||
from ._utils import (
|
||||
all_gather_into_flat_tensor_nd,
|
||||
calculate_global_norm_from_list,
|
||||
get_nd_rank,
|
||||
get_nd_world_size,
|
||||
has_inf_or_nan,
|
||||
release_param_grad,
|
||||
sync_tensor,
|
||||
)
|
||||
from .bookkeeping import BucketStore, GradientStore, TensorBucket
|
||||
from .zero_hook import set_all_gather_handle, wait_all_gather_handle
|
||||
|
||||
|
@ -68,7 +76,7 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
|
|||
def __init__(
|
||||
self,
|
||||
optimizer: Optimizer,
|
||||
pg_to_param_list: Optional[Dict[ProcessGroup, List[nn.Parameter]]] = None,
|
||||
pg_to_param_list: Optional[Dict[Union[ProcessGroup, Tuple[ProcessGroup, ...]], List[nn.Parameter]]] = None,
|
||||
initial_scale: int = 2**16, # grad scaler config
|
||||
min_scale: int = 1,
|
||||
growth_factor: float = 2.0,
|
||||
|
@ -84,6 +92,7 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
|
|||
partition_grad: bool = False, # stage 2 flag
|
||||
cpu_offload: bool = False, # cpu offload
|
||||
dp_process_group: Optional[ProcessGroup] = None,
|
||||
extra_dp_group: Optional[ProcessGroup] = None,
|
||||
forced_dtype: Optional[torch.dtype] = None,
|
||||
master_weights: bool = True, # master weights
|
||||
overlap_allgather: bool = False,
|
||||
|
@ -98,9 +107,17 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
|
|||
|
||||
if (dp_process_group is not None) and (pg_to_param_list is not None):
|
||||
raise ValueError("dp_process_group and pg_to_param_list should not be provided at the same time.")
|
||||
if pg_to_param_list is None and extra_dp_group is not None and dp_process_group is None:
|
||||
raise ValueError("dp_process_group should be provided when extra_dp_group is provided.")
|
||||
if pg_to_param_list is None and extra_dp_group is not None and fp8_communication:
|
||||
raise ValueError(
|
||||
"fp8_communication is not supported when pg_to_param_list is None and extra_dp_group is provided."
|
||||
)
|
||||
|
||||
if pg_to_param_list is None:
|
||||
unique_dp_group = dist.group.WORLD if dp_process_group is None else dp_process_group
|
||||
if extra_dp_group is not None:
|
||||
unique_dp_group = (extra_dp_group, unique_dp_group)
|
||||
pg_to_param_list = {unique_dp_group: []}
|
||||
for group in self.optim.param_groups:
|
||||
pg_to_param_list[unique_dp_group].extend(group["params"])
|
||||
|
@ -336,10 +353,12 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
|
|||
flat_grads = flat_grads.to(self._communication_dtype)
|
||||
|
||||
if not self._partition_grads:
|
||||
if self._fp8_communication:
|
||||
all_reduce_fp8(flat_grads, group=bucket_store.torch_pg)
|
||||
else:
|
||||
dist.all_reduce(flat_grads, group=bucket_store.torch_pg)
|
||||
for i, sz in enumerate(bucket_store.sizes):
|
||||
grp = bucket_store.torch_pg if len(bucket_store.sizes) == 1 else bucket_store.torch_pg[i]
|
||||
if self._fp8_communication:
|
||||
all_reduce_fp8(flat_grads, group=grp)
|
||||
else:
|
||||
dist.all_reduce(flat_grads, group=grp)
|
||||
if flat_grads.dtype != grad_dtype:
|
||||
flat_grads = flat_grads.to(grad_dtype)
|
||||
|
||||
|
@ -347,16 +366,20 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
|
|||
grad_in_bucket = bucket_store.get_grad()
|
||||
self._update_unpartitoned_grad(bucket_store, grad_in_bucket.values(), flat_grads_per_rank, group_id)
|
||||
else:
|
||||
flat_grads_list = list(flat_grads.split(len(flat_grads) // bucket_store.world_size))
|
||||
received_grad = torch.zeros_like(flat_grads_list[0])
|
||||
if self._fp8_communication:
|
||||
reduce_scatter_fp8(
|
||||
received_grad,
|
||||
flat_grads_list,
|
||||
group=bucket_store.torch_pg,
|
||||
)
|
||||
else:
|
||||
dist.reduce_scatter(received_grad, flat_grads_list, group=bucket_store.torch_pg)
|
||||
cur_flat_grads = flat_grads
|
||||
for i, sz in enumerate(bucket_store.sizes):
|
||||
grp = bucket_store.torch_pg if len(bucket_store.sizes) == 1 else bucket_store.torch_pg[i]
|
||||
flat_grads_list = list(cur_flat_grads.split(len(cur_flat_grads) // sz))
|
||||
received_grad = torch.zeros_like(flat_grads_list[0])
|
||||
if self._fp8_communication:
|
||||
reduce_scatter_fp8(
|
||||
received_grad,
|
||||
flat_grads_list,
|
||||
group=grp,
|
||||
)
|
||||
else:
|
||||
dist.reduce_scatter_tensor(received_grad, cur_flat_grads, group=grp)
|
||||
cur_flat_grads = received_grad
|
||||
|
||||
if received_grad.dtype != grad_dtype:
|
||||
received_grad = received_grad.to(grad_dtype)
|
||||
|
@ -577,11 +600,13 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
|
|||
pg = self.param_to_pg[working_param]
|
||||
padded_working_param = self._working_param_to_padded_working_param[working_param]
|
||||
if self._overlap_allgather:
|
||||
handle = dist.all_gather_into_tensor(padded_working_param, param_to_gather, pg, async_op=True)
|
||||
# handle = dist.all_gather_into_tensor(padded_working_param, param_to_gather, pg, async_op=True)
|
||||
handle = all_gather_into_flat_tensor_nd(padded_working_param, param_to_gather, pg, async_op=True)
|
||||
set_all_gather_handle(working_param, handle)
|
||||
else:
|
||||
if param_to_gather.numel() > self.pg_to_tensor_bucket[pg].max_size:
|
||||
if self._fp8_communication:
|
||||
# TODO: fit fp8 communication
|
||||
all_gather_fp8(
|
||||
list(padded_working_param.chunk(dist.get_world_size(pg))),
|
||||
param_to_gather,
|
||||
|
@ -589,7 +614,8 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
|
|||
fp8_format="e4m3",
|
||||
)
|
||||
else:
|
||||
dist.all_gather_into_tensor(padded_working_param, param_to_gather, pg)
|
||||
# dist.all_gather_into_tensor(padded_working_param, param_to_gather, pg)
|
||||
all_gather_into_flat_tensor_nd(padded_working_param, param_to_gather, pg)
|
||||
continue
|
||||
try:
|
||||
self.pg_to_tensor_bucket[pg].add_to_bucket(param_to_gather, write_back_tensor=working_param)
|
||||
|
@ -602,7 +628,9 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
|
|||
if not tensor_bucket.is_empty():
|
||||
tensor_bucket.all_gather(pg, fp8_communication=self._fp8_communication)
|
||||
|
||||
def _compute_grad_norm(self, dp_pg: ProcessGroup, gradients: List[Tensor], norm_type: int = 2) -> float:
|
||||
def _compute_grad_norm(
|
||||
self, dp_pg: Union[ProcessGroup, Tuple[ProcessGroup, ...]], gradients: List[Tensor], norm_type: int = 2
|
||||
) -> float:
|
||||
r"""
|
||||
Compute and return the gradient norm for gradient clipping.
|
||||
|
||||
|
@ -625,7 +653,11 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
|
|||
device=get_accelerator().get_current_device(),
|
||||
dtype=torch.float,
|
||||
)
|
||||
dist.all_reduce(total_norm_cuda, op=torch.distributed.ReduceOp.MAX, group=dp_pg)
|
||||
if isinstance(dp_pg, tuple):
|
||||
for grp in dp_pg:
|
||||
dist.all_reduce(total_norm_cuda, op=torch.distributed.ReduceOp.MAX, group=grp)
|
||||
else:
|
||||
dist.all_reduce(total_norm_cuda, op=torch.distributed.ReduceOp.MAX, group=dp_pg)
|
||||
total_norm = total_norm_cuda.item()
|
||||
|
||||
else:
|
||||
|
@ -640,11 +672,19 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
|
|||
device=get_accelerator().get_current_device(),
|
||||
dtype=torch.float,
|
||||
)
|
||||
torch.distributed.all_reduce(
|
||||
total_norm_exponentiated_cuda,
|
||||
op=torch.distributed.ReduceOp.SUM,
|
||||
group=dp_pg,
|
||||
)
|
||||
if isinstance(dp_pg, tuple):
|
||||
for grp in dp_pg:
|
||||
dist.all_reduce(
|
||||
total_norm_exponentiated_cuda,
|
||||
op=torch.distributed.ReduceOp.SUM,
|
||||
group=grp,
|
||||
)
|
||||
else:
|
||||
torch.distributed.all_reduce(
|
||||
total_norm_exponentiated_cuda,
|
||||
op=torch.distributed.ReduceOp.SUM,
|
||||
group=dp_pg,
|
||||
)
|
||||
total_norm = total_norm_exponentiated_cuda.item() ** (1.0 / norm_type)
|
||||
|
||||
return total_norm
|
||||
|
@ -744,11 +784,9 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
|
|||
if isinstance(v, torch.Tensor) and k != "step":
|
||||
working_param = self.master_to_working_param[id(param)]
|
||||
pg = self.param_to_pg[working_param]
|
||||
gather_tensor = [torch.zeros(v.shape, device=device, dtype=v.dtype) for _ in range(pg.size())]
|
||||
dist.all_gather(gather_tensor, v.to(device), group=pg)
|
||||
param_state = (
|
||||
torch.stack(gather_tensor).view(-1)[: working_param.numel()].reshape_as(working_param).cpu()
|
||||
)
|
||||
gathered_tensor = torch.empty(v.numel() * get_nd_world_size(pg), device=device, dtype=v.dtype)
|
||||
all_gather_into_flat_tensor_nd(gathered_tensor, v.to(device).flatten(), pg)
|
||||
param_state = gathered_tensor[: working_param.numel()].reshape_as(working_param).cpu()
|
||||
zero_state[param][k] = param_state
|
||||
|
||||
states_dict = self._pack_state(zero_state)
|
||||
|
@ -770,15 +808,17 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
|
|||
cnt += 1
|
||||
for param_idx, state in zero_state_dict["state"].items():
|
||||
pg = self.param_to_pg[self.master_to_working_param[id(idx2master[param_idx])]]
|
||||
world_size = get_nd_world_size(pg)
|
||||
rank = get_nd_rank(pg)
|
||||
for k, v in state.items():
|
||||
if isinstance(v, torch.Tensor) and k != "step":
|
||||
padding_size = (pg.size() - v.numel() % pg.size()) % pg.size()
|
||||
padding_size = (world_size - v.numel() % world_size) % world_size
|
||||
with torch.no_grad():
|
||||
v = v.flatten()
|
||||
if padding_size > 0:
|
||||
v = torch.nn.functional.pad(v, [0, padding_size])
|
||||
v_list = v.split(v.numel() // pg.size())
|
||||
zero_state_dict["state"][param_idx][k] = v_list[pg.rank()].detach().clone()
|
||||
v_list = v.split(v.numel() // world_size)
|
||||
zero_state_dict["state"][param_idx][k] = v_list[rank].detach().clone()
|
||||
|
||||
self.optim.load_state_dict(zero_state_dict)
|
||||
|
||||
|
@ -814,11 +854,9 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
|
|||
|
||||
for k, v in states.items():
|
||||
if isinstance(v, torch.Tensor) and k != "step":
|
||||
state_tensor = [torch.zeros(v.shape, device=device, dtype=v.dtype) for _ in range(pg.size())]
|
||||
dist.all_gather(state_tensor, v.to(device), group=pg)
|
||||
state_tensor = (
|
||||
torch.stack(state_tensor).view(-1)[: working_param.numel()].reshape_as(working_param).cpu()
|
||||
)
|
||||
state_tensor = torch.empty(v.numel() * get_nd_world_size(pg), device=device, dtype=v.dtype)
|
||||
all_gather_into_flat_tensor_nd(state_tensor, v.to(device).flatten(), pg)
|
||||
state_tensor = state_tensor[: working_param.numel()].reshape_as(working_param).cpu()
|
||||
current_block_size += state_tensor.numel()
|
||||
current_block[k] = state_tensor
|
||||
|
||||
|
@ -842,12 +880,14 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
|
|||
p_id = id(p)
|
||||
if p_id in self.working_to_master_param:
|
||||
pg = self.param_to_pg[p]
|
||||
world_size = get_nd_world_size(pg)
|
||||
rank = get_nd_rank(pg)
|
||||
master_param = self.working_to_master_param[p_id]
|
||||
padding_size = self.get_param_padding_size(p)
|
||||
working_param = p.data.view(-1)
|
||||
if padding_size > 0:
|
||||
working_param = torch.nn.functional.pad(working_param, [0, padding_size])
|
||||
master_param.copy_(working_param.chunk(pg.size())[pg.rank()])
|
||||
master_param.copy_(working_param.chunk(world_size)[rank])
|
||||
|
||||
def get_working_to_master_map(self) -> Dict[int, torch.Tensor]:
|
||||
return self.working_to_master_param
|
||||
|
@ -905,9 +945,12 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
|
|||
grad = grad_store.get_working_grad_by_param_id(id(working_param))
|
||||
if grad is None:
|
||||
return None
|
||||
grad_flat = torch.empty((grad_store.world_size, *grad.shape), dtype=grad.dtype, device=grad.device)
|
||||
dist.all_gather_into_tensor(grad_flat, grad, group=grad_store.torch_pg)
|
||||
return grad_flat.view(-1)[: working_param.numel()].view_as(working_param)
|
||||
grad_flat = grad.flatten()
|
||||
output_grad = torch.empty(
|
||||
grad_flat.numel() * grad_store.world_size, device=grad_flat.device, dtype=grad_flat.dtype
|
||||
)
|
||||
all_gather_into_flat_tensor_nd(output_grad, grad_flat, grad_store.torch_pg)
|
||||
return output_grad.view(-1)[: working_param.numel()].view_as(working_param)
|
||||
|
||||
def get_working_grads_by_group_id(self, group_id: int) -> List[Tensor]:
|
||||
working_grads = []
|
||||
|
|
|
@ -0,0 +1,42 @@
|
|||
import numpy as np
|
||||
import pytest
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
|
||||
import colossalai
|
||||
from colossalai.cluster import ProcessGroupMesh
|
||||
from colossalai.testing import rerun_if_address_is_in_use, spawn
|
||||
from colossalai.testing.random import seed_all
|
||||
from colossalai.utils import get_current_device
|
||||
from colossalai.zero.low_level._utils import all_gather_into_flat_tensor_nd
|
||||
|
||||
|
||||
def check_all_gather_2d():
|
||||
seed_all(1024)
|
||||
tensor = torch.rand(128, device=get_current_device())
|
||||
extra_dp_size, inner_dp_size = 2, 2
|
||||
pg_mesh = ProcessGroupMesh(extra_dp_size, inner_dp_size)
|
||||
extra_dp_group = pg_mesh.get_group_along_axis(0)
|
||||
inner_dp_group = pg_mesh.get_group_along_axis(1)
|
||||
ranks = [dist.get_rank(extra_dp_group), dist.get_rank(inner_dp_group)]
|
||||
sizes = [dist.get_world_size(extra_dp_group), dist.get_world_size(inner_dp_group)]
|
||||
chunk = tensor.chunk(dist.get_world_size())[np.ravel_multi_index(ranks, sizes)].clone()
|
||||
out = torch.zeros_like(tensor)
|
||||
all_gather_into_flat_tensor_nd(out, chunk, group=(extra_dp_group, inner_dp_group))
|
||||
assert torch.equal(out, tensor)
|
||||
|
||||
|
||||
def run_dist(rank, world_size, port):
|
||||
colossalai.launch(rank=rank, world_size=world_size, port=port, host="localhost")
|
||||
|
||||
check_all_gather_2d()
|
||||
|
||||
|
||||
@pytest.mark.dist
|
||||
@rerun_if_address_is_in_use()
|
||||
def test_comm_nd():
|
||||
spawn(run_dist, 4)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_comm_nd()
|
|
@ -2,11 +2,13 @@ import copy
|
|||
|
||||
import pytest
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
import torch.nn as nn
|
||||
from torch.nn.parallel import DistributedDataParallel as DDP
|
||||
from torch.testing import assert_close
|
||||
|
||||
import colossalai
|
||||
from colossalai.cluster import ProcessGroupMesh
|
||||
from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
|
||||
from colossalai.testing.random import seed_all
|
||||
from colossalai.zero import LowLevelZeroOptimizer
|
||||
|
@ -123,7 +125,8 @@ def exam_zero_1_2(fp8_communication: bool):
|
|||
|
||||
@parameterize("dtype", [torch.float16, torch.bfloat16])
|
||||
@parameterize("master_weights", [True, False])
|
||||
def exam_zero_1_torch_ddp(world_size, dtype: torch.dtype, master_weights: bool):
|
||||
@parameterize("extra_dp_size", [1, 2])
|
||||
def exam_zero_1_torch_ddp(dtype: torch.dtype, master_weights: bool, extra_dp_size: int):
|
||||
"""
|
||||
In this test, two pairs of model and optimizers are created.
|
||||
1. zero: use sharded optimizer and fp16 parameters
|
||||
|
@ -132,6 +135,15 @@ def exam_zero_1_torch_ddp(world_size, dtype: torch.dtype, master_weights: bool):
|
|||
We feed these two sets of models with the same input and check if the
|
||||
differences in model output and updated parameters are within tolerance.
|
||||
"""
|
||||
if extra_dp_size > 1 and dtype != torch.bfloat16:
|
||||
return
|
||||
if extra_dp_size > 1:
|
||||
pg_mesh = ProcessGroupMesh(extra_dp_size, dist.get_world_size() // extra_dp_size)
|
||||
extra_dp_group = pg_mesh.get_group_along_axis(0)
|
||||
dp_group = pg_mesh.get_group_along_axis(1)
|
||||
else:
|
||||
extra_dp_group = None
|
||||
dp_group = None
|
||||
local_rank = torch.distributed.get_rank()
|
||||
seed_all(1453)
|
||||
|
||||
|
@ -153,6 +165,8 @@ def exam_zero_1_torch_ddp(world_size, dtype: torch.dtype, master_weights: bool):
|
|||
initial_scale=1,
|
||||
reduce_bucket_size=1024 * 1024,
|
||||
master_weights=master_weights,
|
||||
dp_process_group=dp_group,
|
||||
extra_dp_group=extra_dp_group,
|
||||
)
|
||||
|
||||
torch_optimizer = torch.optim.SGD(torch_model.parameters(), lr=1)
|
||||
|
@ -200,14 +214,14 @@ def exam_zero_1_torch_ddp(world_size, dtype: torch.dtype, master_weights: bool):
|
|||
def run_dist(rank, world_size, port):
|
||||
colossalai.launch(rank=rank, world_size=world_size, port=port, host="localhost")
|
||||
|
||||
exam_zero_1_torch_ddp(world_size=world_size)
|
||||
exam_zero_1_torch_ddp()
|
||||
exam_zero_1_2()
|
||||
|
||||
|
||||
@pytest.mark.dist
|
||||
@rerun_if_address_is_in_use()
|
||||
def test_zero_1_2():
|
||||
spawn(run_dist, 2)
|
||||
spawn(run_dist, 4)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
|
|
@ -2,12 +2,14 @@ import copy
|
|||
|
||||
import pytest
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
import torch.nn as nn
|
||||
from torch.nn.parallel import DistributedDataParallel as DDP
|
||||
from torch.testing import assert_close
|
||||
|
||||
import colossalai
|
||||
from colossalai.testing import rerun_if_address_is_in_use, spawn
|
||||
from colossalai.cluster import ProcessGroupMesh
|
||||
from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
|
||||
from colossalai.testing.random import seed_all
|
||||
from colossalai.zero import LowLevelZeroOptimizer
|
||||
|
||||
|
@ -40,11 +42,19 @@ def loose_close(a, b, dtype: torch.dtype = torch.float32):
|
|||
assert_close(a, b, rtol=rtol, atol=atol)
|
||||
|
||||
|
||||
def exam_zero_1_torch_ddp_ckpt():
|
||||
@parameterize("extra_dp_size", [1, 2])
|
||||
def exam_zero_1_torch_ddp_ckpt(extra_dp_size: int):
|
||||
"""
|
||||
We examine the state_dict of zero and DDP.
|
||||
Moreover, we examine the zero's loading checkpoint of a torch ckpt.
|
||||
"""
|
||||
if extra_dp_size > 1:
|
||||
pg_mesh = ProcessGroupMesh(extra_dp_size, dist.get_world_size() // extra_dp_size)
|
||||
extra_dp_group = pg_mesh.get_group_along_axis(0)
|
||||
dp_group = pg_mesh.get_group_along_axis(1)
|
||||
else:
|
||||
dp_group = None
|
||||
extra_dp_group = None
|
||||
local_rank = torch.distributed.get_rank()
|
||||
seed_all(1453)
|
||||
|
||||
|
@ -60,7 +70,12 @@ def exam_zero_1_torch_ddp_ckpt():
|
|||
# we only test stage 1 here
|
||||
# the state dicts of stage 1 and stage 2 are the same
|
||||
zero_optimizer = LowLevelZeroOptimizer(
|
||||
zero_optimizer, overlap_communication=True, initial_scale=1, reduce_bucket_size=262144
|
||||
zero_optimizer,
|
||||
overlap_communication=True,
|
||||
initial_scale=1,
|
||||
reduce_bucket_size=262144,
|
||||
dp_process_group=dp_group,
|
||||
extra_dp_group=extra_dp_group,
|
||||
)
|
||||
|
||||
torch_optimizer = torch.optim.Adam(torch_model.parameters(), lr=1)
|
||||
|
@ -111,7 +126,7 @@ def run_dist(rank, world_size, port):
|
|||
@pytest.mark.dist
|
||||
@rerun_if_address_is_in_use()
|
||||
def test_zero_ckpt():
|
||||
spawn(run_dist, 2)
|
||||
spawn(run_dist, 4)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
|
Loading…
Reference in New Issue