
[zero] support extra dp (#6123)

* [zero] support extra dp

* [zero] update checkpoint

* fix bugs

* fix bugs
Hongxin Liu, committed via GitHub (commit a2596519fd)
  1. colossalai/booster/plugin/low_level_zero_plugin.py (10 changed lines)
  2. colossalai/zero/low_level/_utils.py (42 changed lines)
  3. colossalai/zero/low_level/bookkeeping/base_store.py (11 changed lines)
  4. colossalai/zero/low_level/bookkeeping/tensor_bucket.py (14 changed lines)
  5. colossalai/zero/low_level/low_level_optim.py (99 changed lines)
  6. tests/test_zero/test_low_level/test_coll_nd.py (42 changed lines)
  7. tests/test_zero/test_low_level/test_zero1_2.py (20 changed lines)
  8. tests/test_zero/test_low_level/test_zero_ckpt.py (23 changed lines)

colossalai/booster/plugin/low_level_zero_plugin.py (10 changed lines)

@@ -29,6 +29,7 @@ from colossalai.checkpoint_io.utils import (
save_state_dict,
sharded_optimizer_loading_epilogue,
)
from colossalai.cluster import ProcessGroupMesh
from colossalai.interface import AMPModelMixin, ModelWrapper, OptimizerWrapper
from colossalai.interface.optimizer import DistributedOptim
from colossalai.logging import get_dist_logger
@@ -333,6 +334,7 @@ class LowLevelZeroPlugin(DPPluginBase):
verbose (bool, optional): verbose mode. Debug info including grad overflow will be printed. Defaults to False.
use_fp8 (bool, optional): Whether to enable fp8 mixed precision training. Defaults to False.
fp8_communication (bool, optional): Whether to enable fp8 communication. Defaults to False.
extra_dp_size (int, optional): The number of extra data parallel groups. Defaults to 1.
"""
def __init__(
@@ -358,11 +360,16 @@ class LowLevelZeroPlugin(DPPluginBase):
cast_inputs: bool = True,
fp8_communication: bool = False,
use_fp8: bool = False,
extra_dp_size: int = 1,
) -> None:
super().__init__()
assert stage in (1, 2), f"LowLevelZeroPlugin only supports stage 1/2 training"
assert precision in SUPPORTED_PRECISION, f"LowLevelZeroPlugin only supports {SUPPORTED_PRECISION} training"
assert norm_type == 2.0, f"LowLevelZeroPlugin only supports norm_type=2.0 now"
if extra_dp_size > 1:
assert dist.get_world_size() % extra_dp_size == 0, "extra_dp_size should be a factor of world_size"
inner_dp_size = dist.get_world_size() // extra_dp_size
self.pg_mesh = ProcessGroupMesh(extra_dp_size, inner_dp_size)
self.stage = stage
self.precision = precision
self.zero_optim_kwargs = dict(
@@ -383,6 +390,9 @@ class LowLevelZeroPlugin(DPPluginBase):
overlap_allgather=overlap_allgather,
fp8_communication=fp8_communication,
)
if extra_dp_size > 1:
self.zero_optim_kwargs["extra_dp_group"] = self.pg_mesh.get_group_along_axis(0)
self.zero_optim_kwargs["dp_process_group"] = self.pg_mesh.get_group_along_axis(1)
self.lora_enabled = False
self.verbose = verbose
self.logger = get_dist_logger()

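Note: a minimal usage sketch of the new extra_dp_size knob, assuming 4 ranks already initialized by colossalai (e.g. via colossalai.launch_from_torch); the 2 x 2 split mirrors what the plugin builds internally with ProcessGroupMesh, axis 0 being the extra DP group and axis 1 the inner DP group.

import torch.distributed as dist
from colossalai.booster import Booster
from colossalai.booster.plugin import LowLevelZeroPlugin

# world_size must be divisible by extra_dp_size (the plugin asserts this)
assert dist.get_world_size() % 2 == 0
plugin = LowLevelZeroPlugin(stage=1, precision="bf16", extra_dp_size=2)
booster = Booster(plugin=plugin)
# model, optimizer, *_ = booster.boost(model, optimizer)  # optimizer states are then sharded over both DP axes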
colossalai/zero/low_level/_utils.py (42 changed lines)

@@ -1,6 +1,7 @@
import math
from typing import Optional
from typing import Optional, Tuple, Union
import numpy as np
import torch
import torch.distributed as dist
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
@@ -209,3 +210,42 @@ def sync_tensor(flat_tensor, tensor_list):
# update the tensor data
for p, q in zip(tensor_list, updated_params):
p.data = q.data
def all_gather_into_flat_tensor_nd(
output_tensor: torch.Tensor,
input_tensor: torch.Tensor,
group: Union[dist.ProcessGroup, Tuple[dist.ProcessGroup, ...]],
async_op: bool = False,
):
if isinstance(group, dist.ProcessGroup):
group = (group,)
sizes = [dist.get_world_size(pg) for pg in group]
ranks = [dist.get_rank(pg) for pg in group]
for i, pg in list(enumerate(group))[::-1]:
if i == 0:
out = output_tensor
else:
prev_sizes = sizes[:i]
prev_ranks = ranks[:i]
chunks = output_tensor.chunk(np.prod(prev_sizes))
out = chunks[np.ravel_multi_index(prev_ranks, prev_sizes)]
handle = dist.all_gather_into_tensor(out, input_tensor, group=pg, async_op=async_op)
input_tensor = out
return handle
def get_nd_world_size(group) -> int:
if isinstance(group, tuple):
return int(np.prod([dist.get_world_size(pg) for pg in group]))
else:
return dist.get_world_size(group)
def get_nd_rank(group) -> int:
if isinstance(group, tuple):
return np.ravel_multi_index(
tuple(dist.get_rank(group=pg) for pg in group), [dist.get_world_size(pg) for pg in group]
)
else:
return dist.get_rank(group)

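The new helpers reduce the n-D bookkeeping to index arithmetic: all_gather_into_flat_tensor_nd gathers over the innermost group first and then widens the result over each outer group, while get_nd_rank / get_nd_world_size flatten a rank coordinate row-major and multiply the per-group sizes. A single-process illustration of that index math (no process group needed; the sizes and ranks below are made-up values):

import numpy as np

sizes = [2, 2]    # (extra_dp_size, inner_dp_size)
ranks = [1, 0]    # this rank's coordinate along (extra_dp_group, inner_dp_group)
flat_rank = int(np.ravel_multi_index(ranks, sizes))  # row-major: 1 * 2 + 0 == 2
world_size = int(np.prod(sizes))                     # 2 * 2 == 4
assert (flat_rank, world_size) == (2, 4)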
colossalai/zero/low_level/bookkeeping/base_store.py (11 changed lines)

@@ -1,11 +1,20 @@
from typing import Tuple, Union
import numpy as np
import torch.distributed as dist
from torch.distributed import ProcessGroup
class BaseStore:
def __init__(self, torch_pg: ProcessGroup):
def __init__(self, torch_pg: Union[ProcessGroup, Tuple[ProcessGroup, ...]]):
if isinstance(torch_pg, tuple):
self.sizes = [dist.get_world_size(group=pg) for pg in torch_pg]
self._world_size = int(np.prod(self.sizes))
self._local_rank = np.ravel_multi_index(tuple(dist.get_rank(group=pg) for pg in torch_pg), self.sizes)
else:
self._world_size = dist.get_world_size(group=torch_pg)
self._local_rank = dist.get_rank(group=torch_pg)
self.sizes = [self._world_size]
self.torch_pg = torch_pg
@property

colossalai/zero/low_level/bookkeeping/tensor_bucket.py (14 changed lines)

@@ -1,10 +1,12 @@
from typing import Optional
import numpy as np
import torch
import torch.distributed as dist
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
from colossalai.quantization.fp8 import all_gather_fp8
from colossalai.zero.low_level._utils import all_gather_into_flat_tensor_nd
class TensorBucket:
@@ -65,12 +67,18 @@ class TensorBucket:
def all_gather(self, group=None, fp8_communication: bool = False):
flat = self.flatten()
buffer = torch.empty(flat.numel() * dist.get_world_size(group), device=flat.device, dtype=flat.dtype)
if isinstance(group, tuple):
world_size = np.prod([dist.get_world_size(pg) for pg in group])
else:
world_size = dist.get_world_size(group)
buffer = torch.empty(flat.numel() * world_size, device=flat.device, dtype=flat.dtype)
if fp8_communication:
# TODO: fit fp8
all_gather_fp8(list(buffer.chunk(dist.get_world_size(group))), flat, group=group, fp8_format="e4m3")
else:
dist.all_gather_into_tensor(buffer, flat, group=group)
unflat_buffers = [self.unflatten(buffer) for buffer in buffer.chunk(dist.get_world_size(group))]
# dist.all_gather_into_tensor(buffer, flat, group=group)
all_gather_into_flat_tensor_nd(buffer, flat, group=group)
unflat_buffers = [self.unflatten(buffer) for buffer in buffer.chunk(world_size)]
# transpose the list of list
unflat_buffers = list(map(list, zip(*unflat_buffers)))
for unflat_shards, tensor in zip(unflat_buffers, self._bucket):

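TensorBucket.all_gather keeps its flatten/gather/unflatten structure; only the collective is swapped for the n-D variant, and the chunk count now comes from the product of the group sizes. A single-process sketch of the flatten/unflatten round-trip the bucket relies on:

import torch
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors

tensors = [torch.ones(3), torch.arange(4.0)]
flat = _flatten_dense_tensors(tensors)               # one contiguous 1-D buffer (7 elements)
restored = _unflatten_dense_tensors(flat, tensors)   # tensors with the original shapes restored
assert all(torch.equal(a, b) for a, b in zip(restored, tensors))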
colossalai/zero/low_level/low_level_optim.py (99 changed lines)

@@ -2,7 +2,7 @@
import copy
from contextlib import contextmanager, nullcontext
from functools import partial
from typing import Dict, Iterator, List, Optional, Tuple
from typing import Dict, Iterator, List, Optional, Tuple, Union
from weakref import proxy
import torch
@@ -23,7 +23,15 @@ from colossalai.logging import get_dist_logger
from colossalai.quantization.fp8 import all_gather_fp8, all_reduce_fp8, reduce_scatter_fp8
from colossalai.tensor.moe_tensor.api import is_moe_tensor
from ._utils import calculate_global_norm_from_list, has_inf_or_nan, release_param_grad, sync_tensor
from ._utils import (
all_gather_into_flat_tensor_nd,
calculate_global_norm_from_list,
get_nd_rank,
get_nd_world_size,
has_inf_or_nan,
release_param_grad,
sync_tensor,
)
from .bookkeeping import BucketStore, GradientStore, TensorBucket
from .zero_hook import set_all_gather_handle, wait_all_gather_handle
@@ -68,7 +76,7 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
def __init__(
self,
optimizer: Optimizer,
pg_to_param_list: Optional[Dict[ProcessGroup, List[nn.Parameter]]] = None,
pg_to_param_list: Optional[Dict[Union[ProcessGroup, Tuple[ProcessGroup, ...]], List[nn.Parameter]]] = None,
initial_scale: int = 2**16, # grad scaler config
min_scale: int = 1,
growth_factor: float = 2.0,
@@ -84,6 +92,7 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
partition_grad: bool = False, # stage 2 flag
cpu_offload: bool = False, # cpu offload
dp_process_group: Optional[ProcessGroup] = None,
extra_dp_group: Optional[ProcessGroup] = None,
forced_dtype: Optional[torch.dtype] = None,
master_weights: bool = True, # master weights
overlap_allgather: bool = False,
@@ -98,9 +107,17 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
if (dp_process_group is not None) and (pg_to_param_list is not None):
raise ValueError("dp_process_group and pg_to_param_list should not be provided at the same time.")
if pg_to_param_list is None and extra_dp_group is not None and dp_process_group is None:
raise ValueError("dp_process_group should be provided when extra_dp_group is provided.")
if pg_to_param_list is None and extra_dp_group is not None and fp8_communication:
raise ValueError(
"fp8_communication is not supported when pg_to_param_list is None and extra_dp_group is provided."
)
if pg_to_param_list is None:
unique_dp_group = dist.group.WORLD if dp_process_group is None else dp_process_group
if extra_dp_group is not None:
unique_dp_group = (extra_dp_group, unique_dp_group)
pg_to_param_list = {unique_dp_group: []}
for group in self.optim.param_groups:
pg_to_param_list[unique_dp_group].extend(group["params"])
@@ -336,10 +353,12 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
flat_grads = flat_grads.to(self._communication_dtype)
if not self._partition_grads:
for i, sz in enumerate(bucket_store.sizes):
grp = bucket_store.torch_pg if len(bucket_store.sizes) == 1 else bucket_store.torch_pg[i]
if self._fp8_communication:
all_reduce_fp8(flat_grads, group=bucket_store.torch_pg)
all_reduce_fp8(flat_grads, group=grp)
else:
dist.all_reduce(flat_grads, group=bucket_store.torch_pg)
dist.all_reduce(flat_grads, group=grp)
if flat_grads.dtype != grad_dtype:
flat_grads = flat_grads.to(grad_dtype)
@@ -347,16 +366,20 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
grad_in_bucket = bucket_store.get_grad()
self._update_unpartitoned_grad(bucket_store, grad_in_bucket.values(), flat_grads_per_rank, group_id)
else:
flat_grads_list = list(flat_grads.split(len(flat_grads) // bucket_store.world_size))
cur_flat_grads = flat_grads
for i, sz in enumerate(bucket_store.sizes):
grp = bucket_store.torch_pg if len(bucket_store.sizes) == 1 else bucket_store.torch_pg[i]
flat_grads_list = list(cur_flat_grads.split(len(cur_flat_grads) // sz))
received_grad = torch.zeros_like(flat_grads_list[0])
if self._fp8_communication:
reduce_scatter_fp8(
received_grad,
flat_grads_list,
group=bucket_store.torch_pg,
group=grp,
)
else:
dist.reduce_scatter(received_grad, flat_grads_list, group=bucket_store.torch_pg)
dist.reduce_scatter_tensor(received_grad, cur_flat_grads, group=grp)
cur_flat_grads = received_grad
if received_grad.dtype != grad_dtype:
received_grad = received_grad.to(grad_dtype)
@@ -577,11 +600,13 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
pg = self.param_to_pg[working_param]
padded_working_param = self._working_param_to_padded_working_param[working_param]
if self._overlap_allgather:
handle = dist.all_gather_into_tensor(padded_working_param, param_to_gather, pg, async_op=True)
# handle = dist.all_gather_into_tensor(padded_working_param, param_to_gather, pg, async_op=True)
handle = all_gather_into_flat_tensor_nd(padded_working_param, param_to_gather, pg, async_op=True)
set_all_gather_handle(working_param, handle)
else:
if param_to_gather.numel() > self.pg_to_tensor_bucket[pg].max_size:
if self._fp8_communication:
# TODO: fit fp8 communication
all_gather_fp8(
list(padded_working_param.chunk(dist.get_world_size(pg))),
param_to_gather,
@@ -589,7 +614,8 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
fp8_format="e4m3",
)
else:
dist.all_gather_into_tensor(padded_working_param, param_to_gather, pg)
# dist.all_gather_into_tensor(padded_working_param, param_to_gather, pg)
all_gather_into_flat_tensor_nd(padded_working_param, param_to_gather, pg)
continue
try:
self.pg_to_tensor_bucket[pg].add_to_bucket(param_to_gather, write_back_tensor=working_param)
@@ -602,7 +628,9 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
if not tensor_bucket.is_empty():
tensor_bucket.all_gather(pg, fp8_communication=self._fp8_communication)
def _compute_grad_norm(self, dp_pg: ProcessGroup, gradients: List[Tensor], norm_type: int = 2) -> float:
def _compute_grad_norm(
self, dp_pg: Union[ProcessGroup, Tuple[ProcessGroup, ...]], gradients: List[Tensor], norm_type: int = 2
) -> float:
r"""
Compute and return the gradient norm for gradient clipping.
@@ -625,6 +653,10 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
device=get_accelerator().get_current_device(),
dtype=torch.float,
)
if isinstance(dp_pg, tuple):
for grp in dp_pg:
dist.all_reduce(total_norm_cuda, op=torch.distributed.ReduceOp.MAX, group=grp)
else:
dist.all_reduce(total_norm_cuda, op=torch.distributed.ReduceOp.MAX, group=dp_pg)
total_norm = total_norm_cuda.item()
@@ -640,6 +672,14 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
device=get_accelerator().get_current_device(),
dtype=torch.float,
)
if isinstance(dp_pg, tuple):
for grp in dp_pg:
dist.all_reduce(
total_norm_exponentiated_cuda,
op=torch.distributed.ReduceOp.SUM,
group=grp,
)
else:
torch.distributed.all_reduce(
total_norm_exponentiated_cuda,
op=torch.distributed.ReduceOp.SUM,
@@ -744,11 +784,9 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
if isinstance(v, torch.Tensor) and k != "step":
working_param = self.master_to_working_param[id(param)]
pg = self.param_to_pg[working_param]
gather_tensor = [torch.zeros(v.shape, device=device, dtype=v.dtype) for _ in range(pg.size())]
dist.all_gather(gather_tensor, v.to(device), group=pg)
param_state = (
torch.stack(gather_tensor).view(-1)[: working_param.numel()].reshape_as(working_param).cpu()
)
gathered_tensor = torch.empty(v.numel() * get_nd_world_size(pg), device=device, dtype=v.dtype)
all_gather_into_flat_tensor_nd(gathered_tensor, v.to(device).flatten(), pg)
param_state = gathered_tensor[: working_param.numel()].reshape_as(working_param).cpu()
zero_state[param][k] = param_state
states_dict = self._pack_state(zero_state)
@@ -770,15 +808,17 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
cnt += 1
for param_idx, state in zero_state_dict["state"].items():
pg = self.param_to_pg[self.master_to_working_param[id(idx2master[param_idx])]]
world_size = get_nd_world_size(pg)
rank = get_nd_rank(pg)
for k, v in state.items():
if isinstance(v, torch.Tensor) and k != "step":
padding_size = (pg.size() - v.numel() % pg.size()) % pg.size()
padding_size = (world_size - v.numel() % world_size) % world_size
with torch.no_grad():
v = v.flatten()
if padding_size > 0:
v = torch.nn.functional.pad(v, [0, padding_size])
v_list = v.split(v.numel() // pg.size())
zero_state_dict["state"][param_idx][k] = v_list[pg.rank()].detach().clone()
v_list = v.split(v.numel() // world_size)
zero_state_dict["state"][param_idx][k] = v_list[rank].detach().clone()
self.optim.load_state_dict(zero_state_dict)
@@ -814,11 +854,9 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
for k, v in states.items():
if isinstance(v, torch.Tensor) and k != "step":
state_tensor = [torch.zeros(v.shape, device=device, dtype=v.dtype) for _ in range(pg.size())]
dist.all_gather(state_tensor, v.to(device), group=pg)
state_tensor = (
torch.stack(state_tensor).view(-1)[: working_param.numel()].reshape_as(working_param).cpu()
)
state_tensor = torch.empty(v.numel() * get_nd_world_size(pg), device=device, dtype=v.dtype)
all_gather_into_flat_tensor_nd(state_tensor, v.to(device).flatten(), pg)
state_tensor = state_tensor[: working_param.numel()].reshape_as(working_param).cpu()
current_block_size += state_tensor.numel()
current_block[k] = state_tensor
@@ -842,12 +880,14 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
p_id = id(p)
if p_id in self.working_to_master_param:
pg = self.param_to_pg[p]
world_size = get_nd_world_size(pg)
rank = get_nd_rank(pg)
master_param = self.working_to_master_param[p_id]
padding_size = self.get_param_padding_size(p)
working_param = p.data.view(-1)
if padding_size > 0:
working_param = torch.nn.functional.pad(working_param, [0, padding_size])
master_param.copy_(working_param.chunk(pg.size())[pg.rank()])
master_param.copy_(working_param.chunk(world_size)[rank])
def get_working_to_master_map(self) -> Dict[int, torch.Tensor]:
return self.working_to_master_param
@@ -905,9 +945,12 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
grad = grad_store.get_working_grad_by_param_id(id(working_param))
if grad is None:
return None
grad_flat = torch.empty((grad_store.world_size, *grad.shape), dtype=grad.dtype, device=grad.device)
dist.all_gather_into_tensor(grad_flat, grad, group=grad_store.torch_pg)
return grad_flat.view(-1)[: working_param.numel()].view_as(working_param)
grad_flat = grad.flatten()
output_grad = torch.empty(
grad_flat.numel() * grad_store.world_size, device=grad_flat.device, dtype=grad_flat.dtype
)
all_gather_into_flat_tensor_nd(output_grad, grad_flat, grad_store.torch_pg)
return output_grad.view(-1)[: working_param.numel()].view_as(working_param)
def get_working_grads_by_group_id(self, group_id: int) -> List[Tensor]:
working_grads = []

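In the stage-2 branch above, the flat gradient buffer is now reduce-scattered group by group, so after both passes each rank holds 1 / (extra_dp_size * inner_dp_size) of the reduced gradients. A shape-only sketch of that successive shrinking, assuming sizes = [2, 2] and no actual communication:

import torch

sizes = [2, 2]            # per-group sizes of the 2-D mesh (bucket_store.sizes)
cur = torch.randn(128)    # flat gradient buffer on one rank
for sz in sizes:
    # reduce_scatter_tensor over a group of size sz leaves each rank with numel // sz elements;
    # keeping the first split is a stand-in for "this rank's" shard
    cur = cur.split(cur.numel() // sz)[0]
assert cur.numel() == 128 // (2 * 2)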
tests/test_zero/test_low_level/test_coll_nd.py (42 changed lines)

@@ -0,0 +1,42 @@
import numpy as np
import pytest
import torch
import torch.distributed as dist
import colossalai
from colossalai.cluster import ProcessGroupMesh
from colossalai.testing import rerun_if_address_is_in_use, spawn
from colossalai.testing.random import seed_all
from colossalai.utils import get_current_device
from colossalai.zero.low_level._utils import all_gather_into_flat_tensor_nd
def check_all_gather_2d():
seed_all(1024)
tensor = torch.rand(128, device=get_current_device())
extra_dp_size, inner_dp_size = 2, 2
pg_mesh = ProcessGroupMesh(extra_dp_size, inner_dp_size)
extra_dp_group = pg_mesh.get_group_along_axis(0)
inner_dp_group = pg_mesh.get_group_along_axis(1)
ranks = [dist.get_rank(extra_dp_group), dist.get_rank(inner_dp_group)]
sizes = [dist.get_world_size(extra_dp_group), dist.get_world_size(inner_dp_group)]
chunk = tensor.chunk(dist.get_world_size())[np.ravel_multi_index(ranks, sizes)].clone()
out = torch.zeros_like(tensor)
all_gather_into_flat_tensor_nd(out, chunk, group=(extra_dp_group, inner_dp_group))
assert torch.equal(out, tensor)
def run_dist(rank, world_size, port):
colossalai.launch(rank=rank, world_size=world_size, port=port, host="localhost")
check_all_gather_2d()
@pytest.mark.dist
@rerun_if_address_is_in_use()
def test_comm_nd():
spawn(run_dist, 4)
if __name__ == "__main__":
test_comm_nd()

tests/test_zero/test_low_level/test_zero1_2.py (20 changed lines)

@@ -2,11 +2,13 @@ import copy
import pytest
import torch
import torch.distributed as dist
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.testing import assert_close
import colossalai
from colossalai.cluster import ProcessGroupMesh
from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
from colossalai.testing.random import seed_all
from colossalai.zero import LowLevelZeroOptimizer
@@ -123,7 +125,8 @@ def exam_zero_1_2(fp8_communication: bool):
@parameterize("dtype", [torch.float16, torch.bfloat16])
@parameterize("master_weights", [True, False])
def exam_zero_1_torch_ddp(world_size, dtype: torch.dtype, master_weights: bool):
@parameterize("extra_dp_size", [1, 2])
def exam_zero_1_torch_ddp(dtype: torch.dtype, master_weights: bool, extra_dp_size: int):
"""
In this test, two pairs of model and optimizers are created.
1. zero: use sharded optimizer and fp16 parameters
@@ -132,6 +135,15 @@ def exam_zero_1_torch_ddp(world_size, dtype: torch.dtype, master_weights: bool):
We feed these two sets of models with the same input and check if the
differences in model output and updated parameters are within tolerance.
"""
if extra_dp_size > 1 and dtype != torch.bfloat16:
return
if extra_dp_size > 1:
pg_mesh = ProcessGroupMesh(extra_dp_size, dist.get_world_size() // extra_dp_size)
extra_dp_group = pg_mesh.get_group_along_axis(0)
dp_group = pg_mesh.get_group_along_axis(1)
else:
extra_dp_group = None
dp_group = None
local_rank = torch.distributed.get_rank()
seed_all(1453)
@@ -153,6 +165,8 @@ def exam_zero_1_torch_ddp(world_size, dtype: torch.dtype, master_weights: bool):
initial_scale=1,
reduce_bucket_size=1024 * 1024,
master_weights=master_weights,
dp_process_group=dp_group,
extra_dp_group=extra_dp_group,
)
torch_optimizer = torch.optim.SGD(torch_model.parameters(), lr=1)
@@ -200,14 +214,14 @@ def exam_zero_1_torch_ddp(world_size, dtype: torch.dtype, master_weights: bool):
def run_dist(rank, world_size, port):
colossalai.launch(rank=rank, world_size=world_size, port=port, host="localhost")
exam_zero_1_torch_ddp(world_size=world_size)
exam_zero_1_torch_ddp()
exam_zero_1_2()
@pytest.mark.dist
@rerun_if_address_is_in_use()
def test_zero_1_2():
spawn(run_dist, 2)
spawn(run_dist, 4)
if __name__ == "__main__":

tests/test_zero/test_low_level/test_zero_ckpt.py (23 changed lines)

@@ -2,12 +2,14 @@ import copy
import pytest
import torch
import torch.distributed as dist
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.testing import assert_close
import colossalai
from colossalai.testing import rerun_if_address_is_in_use, spawn
from colossalai.cluster import ProcessGroupMesh
from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
from colossalai.testing.random import seed_all
from colossalai.zero import LowLevelZeroOptimizer
@@ -40,11 +42,19 @@ def loose_close(a, b, dtype: torch.dtype = torch.float32):
assert_close(a, b, rtol=rtol, atol=atol)
def exam_zero_1_torch_ddp_ckpt():
@parameterize("extra_dp_size", [1, 2])
def exam_zero_1_torch_ddp_ckpt(extra_dp_size: int):
"""
We examine the state_dict of zero and DDP.
Moreover, we examine the zero's loading checkpoint of a torch ckpt.
"""
if extra_dp_size > 1:
pg_mesh = ProcessGroupMesh(extra_dp_size, dist.get_world_size() // extra_dp_size)
extra_dp_group = pg_mesh.get_group_along_axis(0)
dp_group = pg_mesh.get_group_along_axis(1)
else:
dp_group = None
extra_dp_group = None
local_rank = torch.distributed.get_rank()
seed_all(1453)
@@ -60,7 +70,12 @@ def exam_zero_1_torch_ddp_ckpt():
# we only test stage 1 here
# the state dicts of stage 1 and stage 2 are the same
zero_optimizer = LowLevelZeroOptimizer(
zero_optimizer, overlap_communication=True, initial_scale=1, reduce_bucket_size=262144
zero_optimizer,
overlap_communication=True,
initial_scale=1,
reduce_bucket_size=262144,
dp_process_group=dp_group,
extra_dp_group=extra_dp_group,
)
torch_optimizer = torch.optim.Adam(torch_model.parameters(), lr=1)
@@ -111,7 +126,7 @@ def run_dist(rank, world_size, port):
@pytest.mark.dist
@rerun_if_address_is_in_use()
def test_zero_ckpt():
spawn(run_dist, 2)
spawn(run_dist, 4)
if __name__ == "__main__":

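For reference, the per-rank sharding that load_state_dict applies (and that the checkpoint test above round-trips) is pad-to-a-multiple-then-split, with the rank and world size now taken from get_nd_rank / get_nd_world_size. A single-process sketch with hypothetical values world_size = 4 and rank = 1:

import torch

world_size, rank = 4, 1                    # stand-ins for get_nd_world_size(pg), get_nd_rank(pg)
v = torch.arange(10.0)                     # optimizer state tensor for one parameter
padding = (world_size - v.numel() % world_size) % world_size   # 2 elements of padding
v = torch.nn.functional.pad(v.flatten(), [0, padding])
shard = v.split(v.numel() // world_size)[rank]                 # this rank keeps 3 elements
assert shard.numel() == 3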