Browse Source

[zero] support extra dp (#6123)

* [zero] support extra dp

* [zero] update checkpoint

* fix bugs

* fix bugs
feature/async-io
Hongxin Liu 1 week ago committed by GitHub
parent
commit
a2596519fd
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
  1. 10
      colossalai/booster/plugin/low_level_zero_plugin.py
  2. 42
      colossalai/zero/low_level/_utils.py
  3. 15
      colossalai/zero/low_level/bookkeeping/base_store.py
  4. 14
      colossalai/zero/low_level/bookkeeping/tensor_bucket.py
  5. 129
      colossalai/zero/low_level/low_level_optim.py
  6. 42
      tests/test_zero/test_low_level/test_coll_nd.py
  7. 20
      tests/test_zero/test_low_level/test_zero1_2.py
  8. 23
      tests/test_zero/test_low_level/test_zero_ckpt.py

10
colossalai/booster/plugin/low_level_zero_plugin.py

@ -29,6 +29,7 @@ from colossalai.checkpoint_io.utils import (
save_state_dict, save_state_dict,
sharded_optimizer_loading_epilogue, sharded_optimizer_loading_epilogue,
) )
from colossalai.cluster import ProcessGroupMesh
from colossalai.interface import AMPModelMixin, ModelWrapper, OptimizerWrapper from colossalai.interface import AMPModelMixin, ModelWrapper, OptimizerWrapper
from colossalai.interface.optimizer import DistributedOptim from colossalai.interface.optimizer import DistributedOptim
from colossalai.logging import get_dist_logger from colossalai.logging import get_dist_logger
@ -333,6 +334,7 @@ class LowLevelZeroPlugin(DPPluginBase):
verbose (bool, optional): verbose mode. Debug info including grad overflow will be printed. Defaults to False. verbose (bool, optional): verbose mode. Debug info including grad overflow will be printed. Defaults to False.
use_fp8 (bool, optional): Whether to enable fp8 mixed precision training. Defaults to False. use_fp8 (bool, optional): Whether to enable fp8 mixed precision training. Defaults to False.
fp8_communication (bool, optional): Whether to enable fp8 communication. Defaults to False. fp8_communication (bool, optional): Whether to enable fp8 communication. Defaults to False.
extra_dp_size (int, optional): The number of extra data parallel groups. Defaults to 1.
""" """
def __init__( def __init__(
@ -358,11 +360,16 @@ class LowLevelZeroPlugin(DPPluginBase):
cast_inputs: bool = True, cast_inputs: bool = True,
fp8_communication: bool = False, fp8_communication: bool = False,
use_fp8: bool = False, use_fp8: bool = False,
extra_dp_size: int = 1,
) -> None: ) -> None:
super().__init__() super().__init__()
assert stage in (1, 2), f"LowLevelZeroPlugin only supports stage 1/2 training" assert stage in (1, 2), f"LowLevelZeroPlugin only supports stage 1/2 training"
assert precision in SUPPORTED_PRECISION, f"LowLevelZeroPlugin only supports {SUPPORTED_PRECISION} training" assert precision in SUPPORTED_PRECISION, f"LowLevelZeroPlugin only supports {SUPPORTED_PRECISION} training"
assert norm_type == 2.0, f"LowLevelZeroPlugin only supports norm_type=2.0 now" assert norm_type == 2.0, f"LowLevelZeroPlugin only supports norm_type=2.0 now"
if extra_dp_size > 1:
assert dist.get_world_size() % extra_dp_size == 0, "extra_dp_size should be a factor of world_size"
inner_dp_size = dist.get_world_size() // extra_dp_size
self.pg_mesh = ProcessGroupMesh(extra_dp_size, inner_dp_size)
self.stage = stage self.stage = stage
self.precision = precision self.precision = precision
self.zero_optim_kwargs = dict( self.zero_optim_kwargs = dict(
@ -383,6 +390,9 @@ class LowLevelZeroPlugin(DPPluginBase):
overlap_allgather=overlap_allgather, overlap_allgather=overlap_allgather,
fp8_communication=fp8_communication, fp8_communication=fp8_communication,
) )
if extra_dp_size > 1:
self.zero_optim_kwargs["extra_dp_group"] = self.pg_mesh.get_group_along_axis(0)
self.zero_optim_kwargs["dp_process_group"] = self.pg_mesh.get_group_along_axis(1)
self.lora_enabled = False self.lora_enabled = False
self.verbose = verbose self.verbose = verbose
self.logger = get_dist_logger() self.logger = get_dist_logger()

42
colossalai/zero/low_level/_utils.py

@ -1,6 +1,7 @@
import math import math
from typing import Optional from typing import Optional, Tuple, Union
import numpy as np
import torch import torch
import torch.distributed as dist import torch.distributed as dist
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
@ -209,3 +210,42 @@ def sync_tensor(flat_tensor, tensor_list):
# update the tensor data # update the tensor data
for p, q in zip(tensor_list, updated_params): for p, q in zip(tensor_list, updated_params):
p.data = q.data p.data = q.data
def all_gather_into_flat_tensor_nd(
output_tensor: torch.Tensor,
input_tensor: torch.Tensor,
group: Union[dist.ProcessGroup, Tuple[dist.ProcessGroup, ...]],
async_op: bool = False,
):
if isinstance(group, dist.ProcessGroup):
group = (group,)
sizes = [dist.get_world_size(pg) for pg in group]
ranks = [dist.get_rank(pg) for pg in group]
for i, pg in list(enumerate(group))[::-1]:
if i == 0:
out = output_tensor
else:
prev_sizes = sizes[:i]
prev_ranks = ranks[:i]
chunks = output_tensor.chunk(np.prod(prev_sizes))
out = chunks[np.ravel_multi_index(prev_ranks, prev_sizes)]
handle = dist.all_gather_into_tensor(out, input_tensor, group=pg, async_op=async_op)
input_tensor = out
return handle
def get_nd_world_size(group) -> int:
if isinstance(group, tuple):
return int(np.prod([dist.get_world_size(pg) for pg in group]))
else:
return dist.get_world_size(group)
def get_nd_rank(group) -> int:
if isinstance(group, tuple):
return np.ravel_multi_index(
tuple(dist.get_rank(group=pg) for pg in group), [dist.get_world_size(pg) for pg in group]
)
else:
return dist.get_rank(group)

15
colossalai/zero/low_level/bookkeeping/base_store.py

@ -1,11 +1,20 @@
from typing import Tuple, Union
import numpy as np
import torch.distributed as dist import torch.distributed as dist
from torch.distributed import ProcessGroup from torch.distributed import ProcessGroup
class BaseStore: class BaseStore:
def __init__(self, torch_pg: ProcessGroup): def __init__(self, torch_pg: Union[ProcessGroup, Tuple[ProcessGroup, ...]]):
self._world_size = dist.get_world_size(group=torch_pg) if isinstance(torch_pg, tuple):
self._local_rank = dist.get_rank(group=torch_pg) self.sizes = [dist.get_world_size(group=pg) for pg in torch_pg]
self._world_size = int(np.prod(self.sizes))
self._local_rank = np.ravel_multi_index(tuple(dist.get_rank(group=pg) for pg in torch_pg), self.sizes)
else:
self._world_size = dist.get_world_size(group=torch_pg)
self._local_rank = dist.get_rank(group=torch_pg)
self.sizes = [self._world_size]
self.torch_pg = torch_pg self.torch_pg = torch_pg
@property @property

14
colossalai/zero/low_level/bookkeeping/tensor_bucket.py

@ -1,10 +1,12 @@
from typing import Optional from typing import Optional
import numpy as np
import torch import torch
import torch.distributed as dist import torch.distributed as dist
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
from colossalai.quantization.fp8 import all_gather_fp8 from colossalai.quantization.fp8 import all_gather_fp8
from colossalai.zero.low_level._utils import all_gather_into_flat_tensor_nd
class TensorBucket: class TensorBucket:
@ -65,12 +67,18 @@ class TensorBucket:
def all_gather(self, group=None, fp8_communication: bool = False): def all_gather(self, group=None, fp8_communication: bool = False):
flat = self.flatten() flat = self.flatten()
buffer = torch.empty(flat.numel() * dist.get_world_size(group), device=flat.device, dtype=flat.dtype) if isinstance(group, tuple):
world_size = np.prod([dist.get_world_size(pg) for pg in group])
else:
world_size = dist.get_world_size(group)
buffer = torch.empty(flat.numel() * world_size, device=flat.device, dtype=flat.dtype)
if fp8_communication: if fp8_communication:
# TODO: fit fp8
all_gather_fp8(list(buffer.chunk(dist.get_world_size(group))), flat, group=group, fp8_format="e4m3") all_gather_fp8(list(buffer.chunk(dist.get_world_size(group))), flat, group=group, fp8_format="e4m3")
else: else:
dist.all_gather_into_tensor(buffer, flat, group=group) # dist.all_gather_into_tensor(buffer, flat, group=group)
unflat_buffers = [self.unflatten(buffer) for buffer in buffer.chunk(dist.get_world_size(group))] all_gather_into_flat_tensor_nd(buffer, flat, group=group)
unflat_buffers = [self.unflatten(buffer) for buffer in buffer.chunk(world_size)]
# transpose the list of list # transpose the list of list
unflat_buffers = list(map(list, zip(*unflat_buffers))) unflat_buffers = list(map(list, zip(*unflat_buffers)))
for unflat_shards, tensor in zip(unflat_buffers, self._bucket): for unflat_shards, tensor in zip(unflat_buffers, self._bucket):

129
colossalai/zero/low_level/low_level_optim.py

@ -2,7 +2,7 @@
import copy import copy
from contextlib import contextmanager, nullcontext from contextlib import contextmanager, nullcontext
from functools import partial from functools import partial
from typing import Dict, Iterator, List, Optional, Tuple from typing import Dict, Iterator, List, Optional, Tuple, Union
from weakref import proxy from weakref import proxy
import torch import torch
@ -23,7 +23,15 @@ from colossalai.logging import get_dist_logger
from colossalai.quantization.fp8 import all_gather_fp8, all_reduce_fp8, reduce_scatter_fp8 from colossalai.quantization.fp8 import all_gather_fp8, all_reduce_fp8, reduce_scatter_fp8
from colossalai.tensor.moe_tensor.api import is_moe_tensor from colossalai.tensor.moe_tensor.api import is_moe_tensor
from ._utils import calculate_global_norm_from_list, has_inf_or_nan, release_param_grad, sync_tensor from ._utils import (
all_gather_into_flat_tensor_nd,
calculate_global_norm_from_list,
get_nd_rank,
get_nd_world_size,
has_inf_or_nan,
release_param_grad,
sync_tensor,
)
from .bookkeeping import BucketStore, GradientStore, TensorBucket from .bookkeeping import BucketStore, GradientStore, TensorBucket
from .zero_hook import set_all_gather_handle, wait_all_gather_handle from .zero_hook import set_all_gather_handle, wait_all_gather_handle
@ -68,7 +76,7 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
def __init__( def __init__(
self, self,
optimizer: Optimizer, optimizer: Optimizer,
pg_to_param_list: Optional[Dict[ProcessGroup, List[nn.Parameter]]] = None, pg_to_param_list: Optional[Dict[Union[ProcessGroup, Tuple[ProcessGroup, ...]], List[nn.Parameter]]] = None,
initial_scale: int = 2**16, # grad scaler config initial_scale: int = 2**16, # grad scaler config
min_scale: int = 1, min_scale: int = 1,
growth_factor: float = 2.0, growth_factor: float = 2.0,
@ -84,6 +92,7 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
partition_grad: bool = False, # stage 2 flag partition_grad: bool = False, # stage 2 flag
cpu_offload: bool = False, # cpu offload cpu_offload: bool = False, # cpu offload
dp_process_group: Optional[ProcessGroup] = None, dp_process_group: Optional[ProcessGroup] = None,
extra_dp_group: Optional[ProcessGroup] = None,
forced_dtype: Optional[torch.dtype] = None, forced_dtype: Optional[torch.dtype] = None,
master_weights: bool = True, # master weights master_weights: bool = True, # master weights
overlap_allgather: bool = False, overlap_allgather: bool = False,
@ -98,9 +107,17 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
if (dp_process_group is not None) and (pg_to_param_list is not None): if (dp_process_group is not None) and (pg_to_param_list is not None):
raise ValueError("dp_process_group and pg_to_param_list should not be provided at the same time.") raise ValueError("dp_process_group and pg_to_param_list should not be provided at the same time.")
if pg_to_param_list is None and extra_dp_group is not None and dp_process_group is None:
raise ValueError("dp_process_group should be provided when extra_dp_group is provided.")
if pg_to_param_list is None and extra_dp_group is not None and fp8_communication:
raise ValueError(
"fp8_communication is not supported when pg_to_param_list is None and extra_dp_group is provided."
)
if pg_to_param_list is None: if pg_to_param_list is None:
unique_dp_group = dist.group.WORLD if dp_process_group is None else dp_process_group unique_dp_group = dist.group.WORLD if dp_process_group is None else dp_process_group
if extra_dp_group is not None:
unique_dp_group = (extra_dp_group, unique_dp_group)
pg_to_param_list = {unique_dp_group: []} pg_to_param_list = {unique_dp_group: []}
for group in self.optim.param_groups: for group in self.optim.param_groups:
pg_to_param_list[unique_dp_group].extend(group["params"]) pg_to_param_list[unique_dp_group].extend(group["params"])
@ -336,10 +353,12 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
flat_grads = flat_grads.to(self._communication_dtype) flat_grads = flat_grads.to(self._communication_dtype)
if not self._partition_grads: if not self._partition_grads:
if self._fp8_communication: for i, sz in enumerate(bucket_store.sizes):
all_reduce_fp8(flat_grads, group=bucket_store.torch_pg) grp = bucket_store.torch_pg if len(bucket_store.sizes) == 1 else bucket_store.torch_pg[i]
else: if self._fp8_communication:
dist.all_reduce(flat_grads, group=bucket_store.torch_pg) all_reduce_fp8(flat_grads, group=grp)
else:
dist.all_reduce(flat_grads, group=grp)
if flat_grads.dtype != grad_dtype: if flat_grads.dtype != grad_dtype:
flat_grads = flat_grads.to(grad_dtype) flat_grads = flat_grads.to(grad_dtype)
@ -347,16 +366,20 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
grad_in_bucket = bucket_store.get_grad() grad_in_bucket = bucket_store.get_grad()
self._update_unpartitoned_grad(bucket_store, grad_in_bucket.values(), flat_grads_per_rank, group_id) self._update_unpartitoned_grad(bucket_store, grad_in_bucket.values(), flat_grads_per_rank, group_id)
else: else:
flat_grads_list = list(flat_grads.split(len(flat_grads) // bucket_store.world_size)) cur_flat_grads = flat_grads
received_grad = torch.zeros_like(flat_grads_list[0]) for i, sz in enumerate(bucket_store.sizes):
if self._fp8_communication: grp = bucket_store.torch_pg if len(bucket_store.sizes) == 1 else bucket_store.torch_pg[i]
reduce_scatter_fp8( flat_grads_list = list(cur_flat_grads.split(len(cur_flat_grads) // sz))
received_grad, received_grad = torch.zeros_like(flat_grads_list[0])
flat_grads_list, if self._fp8_communication:
group=bucket_store.torch_pg, reduce_scatter_fp8(
) received_grad,
else: flat_grads_list,
dist.reduce_scatter(received_grad, flat_grads_list, group=bucket_store.torch_pg) group=grp,
)
else:
dist.reduce_scatter_tensor(received_grad, cur_flat_grads, group=grp)
cur_flat_grads = received_grad
if received_grad.dtype != grad_dtype: if received_grad.dtype != grad_dtype:
received_grad = received_grad.to(grad_dtype) received_grad = received_grad.to(grad_dtype)
@ -577,11 +600,13 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
pg = self.param_to_pg[working_param] pg = self.param_to_pg[working_param]
padded_working_param = self._working_param_to_padded_working_param[working_param] padded_working_param = self._working_param_to_padded_working_param[working_param]
if self._overlap_allgather: if self._overlap_allgather:
handle = dist.all_gather_into_tensor(padded_working_param, param_to_gather, pg, async_op=True) # handle = dist.all_gather_into_tensor(padded_working_param, param_to_gather, pg, async_op=True)
handle = all_gather_into_flat_tensor_nd(padded_working_param, param_to_gather, pg, async_op=True)
set_all_gather_handle(working_param, handle) set_all_gather_handle(working_param, handle)
else: else:
if param_to_gather.numel() > self.pg_to_tensor_bucket[pg].max_size: if param_to_gather.numel() > self.pg_to_tensor_bucket[pg].max_size:
if self._fp8_communication: if self._fp8_communication:
# TODO: fit fp8 communication
all_gather_fp8( all_gather_fp8(
list(padded_working_param.chunk(dist.get_world_size(pg))), list(padded_working_param.chunk(dist.get_world_size(pg))),
param_to_gather, param_to_gather,
@ -589,7 +614,8 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
fp8_format="e4m3", fp8_format="e4m3",
) )
else: else:
dist.all_gather_into_tensor(padded_working_param, param_to_gather, pg) # dist.all_gather_into_tensor(padded_working_param, param_to_gather, pg)
all_gather_into_flat_tensor_nd(padded_working_param, param_to_gather, pg)
continue continue
try: try:
self.pg_to_tensor_bucket[pg].add_to_bucket(param_to_gather, write_back_tensor=working_param) self.pg_to_tensor_bucket[pg].add_to_bucket(param_to_gather, write_back_tensor=working_param)
@ -602,7 +628,9 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
if not tensor_bucket.is_empty(): if not tensor_bucket.is_empty():
tensor_bucket.all_gather(pg, fp8_communication=self._fp8_communication) tensor_bucket.all_gather(pg, fp8_communication=self._fp8_communication)
def _compute_grad_norm(self, dp_pg: ProcessGroup, gradients: List[Tensor], norm_type: int = 2) -> float: def _compute_grad_norm(
self, dp_pg: Union[ProcessGroup, Tuple[ProcessGroup, ...]], gradients: List[Tensor], norm_type: int = 2
) -> float:
r""" r"""
Compute and return the gradient norm for gradient clipping. Compute and return the gradient norm for gradient clipping.
@ -625,7 +653,11 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
device=get_accelerator().get_current_device(), device=get_accelerator().get_current_device(),
dtype=torch.float, dtype=torch.float,
) )
dist.all_reduce(total_norm_cuda, op=torch.distributed.ReduceOp.MAX, group=dp_pg) if isinstance(dp_pg, tuple):
for grp in dp_pg:
dist.all_reduce(total_norm_cuda, op=torch.distributed.ReduceOp.MAX, group=grp)
else:
dist.all_reduce(total_norm_cuda, op=torch.distributed.ReduceOp.MAX, group=dp_pg)
total_norm = total_norm_cuda.item() total_norm = total_norm_cuda.item()
else: else:
@ -640,11 +672,19 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
device=get_accelerator().get_current_device(), device=get_accelerator().get_current_device(),
dtype=torch.float, dtype=torch.float,
) )
torch.distributed.all_reduce( if isinstance(dp_pg, tuple):
total_norm_exponentiated_cuda, for grp in dp_pg:
op=torch.distributed.ReduceOp.SUM, dist.all_reduce(
group=dp_pg, total_norm_exponentiated_cuda,
) op=torch.distributed.ReduceOp.SUM,
group=grp,
)
else:
torch.distributed.all_reduce(
total_norm_exponentiated_cuda,
op=torch.distributed.ReduceOp.SUM,
group=dp_pg,
)
total_norm = total_norm_exponentiated_cuda.item() ** (1.0 / norm_type) total_norm = total_norm_exponentiated_cuda.item() ** (1.0 / norm_type)
return total_norm return total_norm
@ -744,11 +784,9 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
if isinstance(v, torch.Tensor) and k != "step": if isinstance(v, torch.Tensor) and k != "step":
working_param = self.master_to_working_param[id(param)] working_param = self.master_to_working_param[id(param)]
pg = self.param_to_pg[working_param] pg = self.param_to_pg[working_param]
gather_tensor = [torch.zeros(v.shape, device=device, dtype=v.dtype) for _ in range(pg.size())] gathered_tensor = torch.empty(v.numel() * get_nd_world_size(pg), device=device, dtype=v.dtype)
dist.all_gather(gather_tensor, v.to(device), group=pg) all_gather_into_flat_tensor_nd(gathered_tensor, v.to(device).flatten(), pg)
param_state = ( param_state = gathered_tensor[: working_param.numel()].reshape_as(working_param).cpu()
torch.stack(gather_tensor).view(-1)[: working_param.numel()].reshape_as(working_param).cpu()
)
zero_state[param][k] = param_state zero_state[param][k] = param_state
states_dict = self._pack_state(zero_state) states_dict = self._pack_state(zero_state)
@ -770,15 +808,17 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
cnt += 1 cnt += 1
for param_idx, state in zero_state_dict["state"].items(): for param_idx, state in zero_state_dict["state"].items():
pg = self.param_to_pg[self.master_to_working_param[id(idx2master[param_idx])]] pg = self.param_to_pg[self.master_to_working_param[id(idx2master[param_idx])]]
world_size = get_nd_world_size(pg)
rank = get_nd_rank(pg)
for k, v in state.items(): for k, v in state.items():
if isinstance(v, torch.Tensor) and k != "step": if isinstance(v, torch.Tensor) and k != "step":
padding_size = (pg.size() - v.numel() % pg.size()) % pg.size() padding_size = (world_size - v.numel() % world_size) % world_size
with torch.no_grad(): with torch.no_grad():
v = v.flatten() v = v.flatten()
if padding_size > 0: if padding_size > 0:
v = torch.nn.functional.pad(v, [0, padding_size]) v = torch.nn.functional.pad(v, [0, padding_size])
v_list = v.split(v.numel() // pg.size()) v_list = v.split(v.numel() // world_size)
zero_state_dict["state"][param_idx][k] = v_list[pg.rank()].detach().clone() zero_state_dict["state"][param_idx][k] = v_list[rank].detach().clone()
self.optim.load_state_dict(zero_state_dict) self.optim.load_state_dict(zero_state_dict)
@ -814,11 +854,9 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
for k, v in states.items(): for k, v in states.items():
if isinstance(v, torch.Tensor) and k != "step": if isinstance(v, torch.Tensor) and k != "step":
state_tensor = [torch.zeros(v.shape, device=device, dtype=v.dtype) for _ in range(pg.size())] state_tensor = torch.empty(v.numel() * get_nd_world_size(pg), device=device, dtype=v.dtype)
dist.all_gather(state_tensor, v.to(device), group=pg) all_gather_into_flat_tensor_nd(state_tensor, v.to(device).flatten(), pg)
state_tensor = ( state_tensor = state_tensor[: working_param.numel()].reshape_as(working_param).cpu()
torch.stack(state_tensor).view(-1)[: working_param.numel()].reshape_as(working_param).cpu()
)
current_block_size += state_tensor.numel() current_block_size += state_tensor.numel()
current_block[k] = state_tensor current_block[k] = state_tensor
@ -842,12 +880,14 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
p_id = id(p) p_id = id(p)
if p_id in self.working_to_master_param: if p_id in self.working_to_master_param:
pg = self.param_to_pg[p] pg = self.param_to_pg[p]
world_size = get_nd_world_size(pg)
rank = get_nd_rank(pg)
master_param = self.working_to_master_param[p_id] master_param = self.working_to_master_param[p_id]
padding_size = self.get_param_padding_size(p) padding_size = self.get_param_padding_size(p)
working_param = p.data.view(-1) working_param = p.data.view(-1)
if padding_size > 0: if padding_size > 0:
working_param = torch.nn.functional.pad(working_param, [0, padding_size]) working_param = torch.nn.functional.pad(working_param, [0, padding_size])
master_param.copy_(working_param.chunk(pg.size())[pg.rank()]) master_param.copy_(working_param.chunk(world_size)[rank])
def get_working_to_master_map(self) -> Dict[int, torch.Tensor]: def get_working_to_master_map(self) -> Dict[int, torch.Tensor]:
return self.working_to_master_param return self.working_to_master_param
@ -905,9 +945,12 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
grad = grad_store.get_working_grad_by_param_id(id(working_param)) grad = grad_store.get_working_grad_by_param_id(id(working_param))
if grad is None: if grad is None:
return None return None
grad_flat = torch.empty((grad_store.world_size, *grad.shape), dtype=grad.dtype, device=grad.device) grad_flat = grad.flatten()
dist.all_gather_into_tensor(grad_flat, grad, group=grad_store.torch_pg) output_grad = torch.empty(
return grad_flat.view(-1)[: working_param.numel()].view_as(working_param) grad_flat.numel() * grad_store.world_size, device=grad_flat.device, dtype=grad_flat.dtype
)
all_gather_into_flat_tensor_nd(output_grad, grad_flat, grad_store.torch_pg)
return output_grad.view(-1)[: working_param.numel()].view_as(working_param)
def get_working_grads_by_group_id(self, group_id: int) -> List[Tensor]: def get_working_grads_by_group_id(self, group_id: int) -> List[Tensor]:
working_grads = [] working_grads = []

42
tests/test_zero/test_low_level/test_coll_nd.py

@ -0,0 +1,42 @@
import numpy as np
import pytest
import torch
import torch.distributed as dist
import colossalai
from colossalai.cluster import ProcessGroupMesh
from colossalai.testing import rerun_if_address_is_in_use, spawn
from colossalai.testing.random import seed_all
from colossalai.utils import get_current_device
from colossalai.zero.low_level._utils import all_gather_into_flat_tensor_nd
def check_all_gather_2d():
seed_all(1024)
tensor = torch.rand(128, device=get_current_device())
extra_dp_size, inner_dp_size = 2, 2
pg_mesh = ProcessGroupMesh(extra_dp_size, inner_dp_size)
extra_dp_group = pg_mesh.get_group_along_axis(0)
inner_dp_group = pg_mesh.get_group_along_axis(1)
ranks = [dist.get_rank(extra_dp_group), dist.get_rank(inner_dp_group)]
sizes = [dist.get_world_size(extra_dp_group), dist.get_world_size(inner_dp_group)]
chunk = tensor.chunk(dist.get_world_size())[np.ravel_multi_index(ranks, sizes)].clone()
out = torch.zeros_like(tensor)
all_gather_into_flat_tensor_nd(out, chunk, group=(extra_dp_group, inner_dp_group))
assert torch.equal(out, tensor)
def run_dist(rank, world_size, port):
colossalai.launch(rank=rank, world_size=world_size, port=port, host="localhost")
check_all_gather_2d()
@pytest.mark.dist
@rerun_if_address_is_in_use()
def test_comm_nd():
spawn(run_dist, 4)
if __name__ == "__main__":
test_comm_nd()

20
tests/test_zero/test_low_level/test_zero1_2.py

@ -2,11 +2,13 @@ import copy
import pytest import pytest
import torch import torch
import torch.distributed as dist
import torch.nn as nn import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP from torch.nn.parallel import DistributedDataParallel as DDP
from torch.testing import assert_close from torch.testing import assert_close
import colossalai import colossalai
from colossalai.cluster import ProcessGroupMesh
from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
from colossalai.testing.random import seed_all from colossalai.testing.random import seed_all
from colossalai.zero import LowLevelZeroOptimizer from colossalai.zero import LowLevelZeroOptimizer
@ -123,7 +125,8 @@ def exam_zero_1_2(fp8_communication: bool):
@parameterize("dtype", [torch.float16, torch.bfloat16]) @parameterize("dtype", [torch.float16, torch.bfloat16])
@parameterize("master_weights", [True, False]) @parameterize("master_weights", [True, False])
def exam_zero_1_torch_ddp(world_size, dtype: torch.dtype, master_weights: bool): @parameterize("extra_dp_size", [1, 2])
def exam_zero_1_torch_ddp(dtype: torch.dtype, master_weights: bool, extra_dp_size: int):
""" """
In this test, two pairs of model and optimizers are created. In this test, two pairs of model and optimizers are created.
1. zero: use sharded optimizer and fp16 parameters 1. zero: use sharded optimizer and fp16 parameters
@ -132,6 +135,15 @@ def exam_zero_1_torch_ddp(world_size, dtype: torch.dtype, master_weights: bool):
We feed these two sets of models with the same input and check if the We feed these two sets of models with the same input and check if the
differences in model output and updated parameters are within tolerance. differences in model output and updated parameters are within tolerance.
""" """
if extra_dp_size > 1 and dtype != torch.bfloat16:
return
if extra_dp_size > 1:
pg_mesh = ProcessGroupMesh(extra_dp_size, dist.get_world_size() // extra_dp_size)
extra_dp_group = pg_mesh.get_group_along_axis(0)
dp_group = pg_mesh.get_group_along_axis(1)
else:
extra_dp_group = None
dp_group = None
local_rank = torch.distributed.get_rank() local_rank = torch.distributed.get_rank()
seed_all(1453) seed_all(1453)
@ -153,6 +165,8 @@ def exam_zero_1_torch_ddp(world_size, dtype: torch.dtype, master_weights: bool):
initial_scale=1, initial_scale=1,
reduce_bucket_size=1024 * 1024, reduce_bucket_size=1024 * 1024,
master_weights=master_weights, master_weights=master_weights,
dp_process_group=dp_group,
extra_dp_group=extra_dp_group,
) )
torch_optimizer = torch.optim.SGD(torch_model.parameters(), lr=1) torch_optimizer = torch.optim.SGD(torch_model.parameters(), lr=1)
@ -200,14 +214,14 @@ def exam_zero_1_torch_ddp(world_size, dtype: torch.dtype, master_weights: bool):
def run_dist(rank, world_size, port): def run_dist(rank, world_size, port):
colossalai.launch(rank=rank, world_size=world_size, port=port, host="localhost") colossalai.launch(rank=rank, world_size=world_size, port=port, host="localhost")
exam_zero_1_torch_ddp(world_size=world_size) exam_zero_1_torch_ddp()
exam_zero_1_2() exam_zero_1_2()
@pytest.mark.dist @pytest.mark.dist
@rerun_if_address_is_in_use() @rerun_if_address_is_in_use()
def test_zero_1_2(): def test_zero_1_2():
spawn(run_dist, 2) spawn(run_dist, 4)
if __name__ == "__main__": if __name__ == "__main__":

23
tests/test_zero/test_low_level/test_zero_ckpt.py

@ -2,12 +2,14 @@ import copy
import pytest import pytest
import torch import torch
import torch.distributed as dist
import torch.nn as nn import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP from torch.nn.parallel import DistributedDataParallel as DDP
from torch.testing import assert_close from torch.testing import assert_close
import colossalai import colossalai
from colossalai.testing import rerun_if_address_is_in_use, spawn from colossalai.cluster import ProcessGroupMesh
from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
from colossalai.testing.random import seed_all from colossalai.testing.random import seed_all
from colossalai.zero import LowLevelZeroOptimizer from colossalai.zero import LowLevelZeroOptimizer
@ -40,11 +42,19 @@ def loose_close(a, b, dtype: torch.dtype = torch.float32):
assert_close(a, b, rtol=rtol, atol=atol) assert_close(a, b, rtol=rtol, atol=atol)
def exam_zero_1_torch_ddp_ckpt(): @parameterize("extra_dp_size", [1, 2])
def exam_zero_1_torch_ddp_ckpt(extra_dp_size: int):
""" """
We examine the state_dict of zero and DDP. We examine the state_dict of zero and DDP.
Moreover, we examine the zero's loading checkpoint of a torch ckpt. Moreover, we examine the zero's loading checkpoint of a torch ckpt.
""" """
if extra_dp_size > 1:
pg_mesh = ProcessGroupMesh(extra_dp_size, dist.get_world_size() // extra_dp_size)
extra_dp_group = pg_mesh.get_group_along_axis(0)
dp_group = pg_mesh.get_group_along_axis(1)
else:
dp_group = None
extra_dp_group = None
local_rank = torch.distributed.get_rank() local_rank = torch.distributed.get_rank()
seed_all(1453) seed_all(1453)
@ -60,7 +70,12 @@ def exam_zero_1_torch_ddp_ckpt():
# we only test stage 1 here # we only test stage 1 here
# the state dicts of stage 1 and stage 2 are the same # the state dicts of stage 1 and stage 2 are the same
zero_optimizer = LowLevelZeroOptimizer( zero_optimizer = LowLevelZeroOptimizer(
zero_optimizer, overlap_communication=True, initial_scale=1, reduce_bucket_size=262144 zero_optimizer,
overlap_communication=True,
initial_scale=1,
reduce_bucket_size=262144,
dp_process_group=dp_group,
extra_dp_group=extra_dp_group,
) )
torch_optimizer = torch.optim.Adam(torch_model.parameters(), lr=1) torch_optimizer = torch.optim.Adam(torch_model.parameters(), lr=1)
@ -111,7 +126,7 @@ def run_dist(rank, world_size, port):
@pytest.mark.dist @pytest.mark.dist
@rerun_if_address_is_in_use() @rerun_if_address_is_in_use()
def test_zero_ckpt(): def test_zero_ckpt():
spawn(run_dist, 2) spawn(run_dist, 4)
if __name__ == "__main__": if __name__ == "__main__":

Loading…
Cancel
Save