pull/5922/head
YeAnbang 4 months ago
commit 845ea7214e

@ -1 +1,3 @@
2.1.0-12.1.0
2.2.2-12.1.0
2.3.0-12.1.0

@ -55,41 +55,27 @@ jobs:
steps:
- name: Install dependencies
run: |
pip install -U pip setuptools==68.2.2 wheel --user
- uses: actions/checkout@v2
with:
repository: hpcaitech/TensorNVMe
ssh-key: ${{ secrets.SSH_KEY_FOR_CI }}
path: TensorNVMe
- name: Install tensornvme
run: |
cd TensorNVMe
apt update && apt install -y cmake
pip install -r requirements.txt
DISABLE_URING=1 pip install -v .
pip install -U pip setuptools==68.2.2 wheel --user
- uses: actions/checkout@v2
with:
ssh-key: ${{ secrets.SSH_KEY_FOR_CI }}
- name: Download cub for CUDA 10.2
run: |
CUDA_VERSION=$(nvcc -V | awk -F ',| ' '/release/{print $6}')
# check if it is CUDA 10.2
# download cub
if [ "$CUDA_VERSION" = "10.2" ]; then
wget https://github.com/NVIDIA/cub/archive/refs/tags/1.8.0.zip
unzip 1.8.0.zip
cp -r cub-1.8.0/cub/ colossalai/kernel/cuda_native/csrc/kernels/include/
fi
- name: Install Colossal-AI
run: |
BUILD_EXT=1 pip install -v .
pip install -r requirements/requirements-test.txt
pip install --no-cache-dir -r requirements/requirements-test.txt
- name: Install tensornvme
run: |
DISABLE_URING=1 pip install -v git+https://github.com/hpcaitech/TensorNVMe.git
- name: Unit Testing
run: |
PYTHONPATH=$PWD pytest --durations=0 tests
env:
DATA: /data/scratch/cifar-10
LD_LIBRARY_PATH: /github/home/.tensornvme/lib:/usr/local/nvidia/lib:/usr/local/nvidia/lib64
LD_LIBRARY_PATH: /github/home/.tensornvme/lib
LLAMA_PATH: /data/scratch/llama-tiny
MOE_TENSOR_PATH: /data/scratch/moe_tensors

@ -49,42 +49,27 @@ jobs:
steps:
- name: Install dependencies
run: |
pip install -U pip setuptools==68.2.2 wheel --user
- uses: actions/checkout@v2
with:
repository: hpcaitech/TensorNVMe
ssh-key: ${{ secrets.SSH_KEY_FOR_CI }}
path: TensorNVMe
- name: Install tensornvme
run: |
cd TensorNVMe
apt update && apt install -y cmake
pip install -r requirements.txt
DISABLE_URING=1 pip install -v .
pip install -U pip setuptools==68.2.2 wheel --user
- uses: actions/checkout@v2
with:
ssh-key: ${{ secrets.SSH_KEY_FOR_CI }}
- name: Download cub for CUDA 10.2
run: |
CUDA_VERSION=$(nvcc -V | awk -F ',| ' '/release/{print $6}')
# check if it is CUDA 10.2
# download cub
if [ "$CUDA_VERSION" = "10.2" ]; then
wget https://github.com/NVIDIA/cub/archive/refs/tags/1.8.0.zip
unzip 1.8.0.zip
cp -r cub-1.8.0/cub/ colossalai/kernel/cuda_native/csrc/kernels/include/
fi
- name: Install Colossal-AI
run: |
BUILD_EXT=1 pip install -v .
pip install -r requirements/requirements-test.txt
pip install --no-cache-dir -r requirements/requirements-test.txt
- name: Install tensornvme
run: |
DISABLE_URING=1 pip install -v git+https://github.com/hpcaitech/TensorNVMe.git
- name: Unit Testing
run: |
PYTHONPATH=$PWD pytest --durations=0 tests
env:
DATA: /data/scratch/cifar-10
LD_LIBRARY_PATH: /github/home/.tensornvme/lib:/usr/local/nvidia/lib:/usr/local/nvidia/lib64
LD_LIBRARY_PATH: /github/home/.tensornvme/lib
LLAMA_PATH: /data/scratch/llama-tiny
MOE_TENSOR_PATH: /data/scratch/moe_tensors

@ -43,47 +43,28 @@ jobs:
steps:
- name: Install dependencies
run: |
apt update && apt install -y cmake
pip install -U pip setuptools==68.2.2 wheel --user
- uses: actions/checkout@v2
with:
repository: hpcaitech/TensorNVMe
ssh-key: ${{ secrets.SSH_KEY_FOR_CI }}
path: TensorNVMe
- name: Install tensornvme
run: |
cd TensorNVMe
apt update && apt install -y cmake
pip install -r requirements.txt
DISABLE_URING=1 pip install -v .
- uses: actions/checkout@v2
with:
ssh-key: ${{ secrets.SSH_KEY_FOR_CI }}
- name: Download cub for CUDA 10.2
run: |
CUDA_VERSION=$(nvcc -V | awk -F ',| ' '/release/{print $6}')
# check if it is CUDA 10.2
# download cub
if [ "$CUDA_VERSION" = "10.2" ]; then
wget https://github.com/NVIDIA/cub/archive/refs/tags/1.8.0.zip
unzip 1.8.0.zip
cp -r cub-1.8.0/cub/ colossalai/kernel/cuda_native/csrc/kernels/include/
fi
- name: Install Colossal-AI
run: |
BUILD_EXT=1 pip install -v .
pip install -r requirements/requirements-test.txt
pip install --no-cache-dir -r requirements/requirements-test.txt
- name: Install tensornvme
run: |
DISABLE_URING=1 pip install -v git+https://github.com/hpcaitech/TensorNVMe.git
- name: Unit Testing
run: |
PYTHONPATH=$PWD pytest --durations=0 tests
env:
DATA: /data/scratch/cifar-10
LD_LIBRARY_PATH: /github/home/.tensornvme/lib:/usr/local/nvidia/lib:/usr/local/nvidia/lib64
LD_LIBRARY_PATH: /github/home/.tensornvme/lib
LLAMA_PATH: /data/scratch/llama-tiny
MOE_TENSOR_PATH: /data/scratch/moe_tensors

@ -2,7 +2,7 @@ import ctypes
import random
import warnings
from collections import defaultdict
from contextlib import contextmanager
from contextlib import contextmanager, nullcontext
from copy import deepcopy
from functools import partial
from types import MethodType
@ -33,8 +33,11 @@ from colossalai.pipeline.stage_manager import PipelineStageManager
from colossalai.shardformer import GradientCheckpointConfig, ShardConfig, ShardFormer
from colossalai.shardformer.layer.utils import SeqParallelUtils
from colossalai.shardformer.policies.base_policy import Policy
from colossalai.tensor.colo_parameter import ColoParameter
from colossalai.tensor.d_tensor.api import is_distributed_tensor
from colossalai.tensor.param_op_hook import ColoParamOpHookManager
from colossalai.zero.low_level import LowLevelZeroOptimizer
from colossalai.zero.low_level.zero_hook import ZeroOpHook, wait_all_gather_handle
from .pp_plugin_base import PipelinePluginBase
@ -61,6 +64,7 @@ class HybridParallelModule(ModelWrapper, AMPModelMixin):
use_ddp: bool,
ddp_config: dict,
custom_policy: Policy,
overlap_allgather: bool = False,
) -> None:
self.stage_manager = shard_config.pipeline_stage_manager
self.shard_config = shard_config
@ -69,6 +73,7 @@ class HybridParallelModule(ModelWrapper, AMPModelMixin):
self.sp_group = sp_group
self.use_dpp = use_ddp
self.require_grad_sync = True
self.overlap_allgather = overlap_allgather
shardformer = ShardFormer(shard_config)
if custom_policy is not None:
@ -106,6 +111,12 @@ class HybridParallelModule(ModelWrapper, AMPModelMixin):
module = DDP(module, process_group=dp_group, **ddp_config)
super().__init__(module)
if overlap_allgather:
self.op_hook = ZeroOpHook()
for p in module.parameters():
if p.requires_grad and type(p) is not ColoParameter:
p.__class__ = ColoParameter
p.__init__(p, requires_grad=True)
def sync_shared_params(self):
for shared_param, group in zip(self.shared_params, self.shared_param_process_groups):
@ -197,6 +208,7 @@ class HybridParallelModule(ModelWrapper, AMPModelMixin):
if self.convert_fn is not None:
args = tree_map(self.convert_fn, args)
kwargs = tree_map(self.convert_fn, kwargs)
with self._wait_all_gather():
return super().forward(*args, **kwargs)
def unwrap(self):
@ -205,6 +217,13 @@ class HybridParallelModule(ModelWrapper, AMPModelMixin):
module = module.module
return module
def _force_wait_all_gather(self):
for p in self.module.parameters():
wait_all_gather_handle(p)
def _wait_all_gather(self):
return ColoParamOpHookManager.use_hooks(self.op_hook) if self.overlap_allgather else nullcontext()
def get_param_info(optim: Optimizer):
# Get a backup of necessary information of parameters for future use, which includes:
@ -650,6 +669,7 @@ class HybridParallelZeroOptimizer(LowLevelZeroOptimizer):
tp_process_group: Optional[ProcessGroup] = None, # if using tp
pp_process_group: Optional[ProcessGroup] = None, # if using pp
forced_dtype: Optional[torch.dtype] = None,
overlap_allgather: bool = False,
):
self.model = model
self.param_info = param_info
@ -677,6 +697,7 @@ class HybridParallelZeroOptimizer(LowLevelZeroOptimizer):
cpu_offload=cpu_offload,
dp_process_group=dp_process_group,
forced_dtype=forced_dtype,
overlap_allgather=overlap_allgather,
)
def sync_dp_grads(self):
@ -992,6 +1013,7 @@ class HybridParallelPlugin(PipelinePluginBase):
make_vocab_size_divisible_by: int = 64,
dp_outside: bool = True,
overlap_p2p: bool = True,
overlap_allgather: bool = False,
) -> None:
super().__init__()
assert (
@ -1143,6 +1165,7 @@ class HybridParallelPlugin(PipelinePluginBase):
cpu_offload=cpu_offload,
partition_grad=(self.zero_stage == 2),
forced_dtype=PRECISION_TORCH_TYPE[precision],
overlap_allgather=overlap_allgather,
)
self.max_norm = max_norm
@ -1220,6 +1243,7 @@ class HybridParallelPlugin(PipelinePluginBase):
use_ddp=use_ddp,
ddp_config=self.ddp_config,
custom_policy=self.custom_policy,
overlap_allgather=(self.zero_stage > 0 and self.zero_config["overlap_allgather"]),
)
if optimizer is not None and not isinstance(optimizer, OptimizerWrapper):
if zero_stage == 0:
@ -1302,7 +1326,7 @@ class HybridParallelPlugin(PipelinePluginBase):
# so we disable it, performing manual reduction instead.
ctx = optimizer.no_sync() if isinstance(optimizer, HybridParallelZeroOptimizer) else model.no_sync()
with ctx:
with ctx, model._wait_all_gather():
outputs = self.schedule.forward_backward_step(
model, data_iter, criterion, optimizer, return_loss, return_outputs
)

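For reference, turning the new overlap_allgather path on from user code only means passing the flag to the plugin; the ColoParameter conversion and the ZeroOpHook wait in forward happen inside HybridParallelModule. A minimal sketch assuming the usual Booster workflow under a distributed launch, with a toy model and illustrative argument values (not taken from this PR):

import colossalai
import torch
import torch.nn as nn
from colossalai.booster import Booster
from colossalai.booster.plugin import HybridParallelPlugin

colossalai.launch_from_torch()
model = nn.Linear(8, 8)                              # stand-in for a real model
optimizer = torch.optim.AdamW(model.parameters())

# overlap_allgather only takes effect together with ZeRO (zero_stage > 0), matching
# the `self.zero_stage > 0 and self.zero_config["overlap_allgather"]` guard above.
plugin = HybridParallelPlugin(
    tp_size=1,
    pp_size=1,
    zero_stage=1,
    precision="bf16",
    overlap_allgather=True,
)
booster = Booster(plugin=plugin)
model, optimizer, *_ = booster.boost(model, optimizer)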
@ -2,6 +2,7 @@ import enum
import logging
import os
import warnings
from contextlib import nullcontext
from functools import partial
from pathlib import Path
from types import MethodType
@ -34,7 +35,10 @@ from colossalai.interface import AMPModelMixin, ModelWrapper, OptimizerWrapper
from colossalai.interface.optimizer import DistributedOptim
from colossalai.nn.optimizer import DistGaloreAwamW, cast_to_distributed
from colossalai.quantization import BnbQuantizationConfig, quantize_model
from colossalai.tensor.colo_parameter import ColoParameter
from colossalai.tensor.param_op_hook import ColoParamOpHookManager
from colossalai.zero import LowLevelZeroOptimizer
from colossalai.zero.low_level.zero_hook import ZeroOpHook, wait_all_gather_handle
from .dp_plugin_base import DPPluginBase
from .torch_ddp_plugin import TorchDDPCheckpointIO
@ -58,7 +62,7 @@ class OptimizerParamCheckState(enum.Enum):
class LowLevelZeroModel(ModelWrapper, AMPModelMixin):
def __init__(self, module: nn.Module, precision: str) -> None:
def __init__(self, module: nn.Module, precision: str, overlap_allgather: bool = False) -> None:
super().__init__(module)
self.dtype = None
if precision == "fp16":
@ -72,13 +76,26 @@ class LowLevelZeroModel(ModelWrapper, AMPModelMixin):
self.convert_fn = None
if self.dtype is not None:
self.convert_fn = partial(_convert_floating_point, dtype=self.dtype)
self.overlap_allgather = overlap_allgather
if overlap_allgather:
self.op_hook = ZeroOpHook()
for p in module.parameters():
if p.requires_grad and type(p) is not ColoParameter:
p.__class__ = ColoParameter
p.__init__(p, requires_grad=True)
def forward(self, *args, **kwargs):
if self.convert_fn is not None:
args = tree_map(self.convert_fn, args)
kwargs = tree_map(self.convert_fn, kwargs)
ctx = ColoParamOpHookManager.use_hooks(self.op_hook) if self.overlap_allgather else nullcontext()
with ctx:
return super().forward(*args, **kwargs)
def _force_wait_all_gather(self):
for p in self.module.parameters():
wait_all_gather_handle(p)
class LowLevelZeroCheckpointIO(TorchDDPCheckpointIO):
def save_unsharded_optimizer(self, optimizer: OptimizerWrapper, checkpoint: str, gather_dtensor: bool = False):
@ -209,6 +226,7 @@ class LowLevelZeroCheckpointIO(TorchDDPCheckpointIO):
def load_unsharded_model(self, model: ModelWrapper, checkpoint: str, strict: bool = True):
assert isinstance(model, LowLevelZeroModel), "Please boost the model before loading!"
model._force_wait_all_gather()
super().load_unsharded_model(model, checkpoint, strict)
model.update_master_params()
@ -221,9 +239,30 @@ class LowLevelZeroCheckpointIO(TorchDDPCheckpointIO):
load_sub_module: bool = True,
):
assert isinstance(model, LowLevelZeroModel), "Please boost the model before loading!"
model._force_wait_all_gather()
super().load_sharded_model(model, checkpoint_index_file, strict, use_safetensors, load_sub_module)
model.update_master_params()
def save_unsharded_model(self, model: ModelWrapper, checkpoint: str, gather_dtensor: bool, use_safetensors: bool):
assert isinstance(model, LowLevelZeroModel), "Please boost the model before loading!"
model._force_wait_all_gather()
return super().save_unsharded_model(model, checkpoint, gather_dtensor, use_safetensors)
def save_sharded_model(
self,
model: ModelWrapper,
checkpoint_path: str,
gather_dtensor: bool = True,
prefix: Optional[str] = None,
max_shard_size: int = 1024,
use_safetensors: bool = False,
):
assert isinstance(model, LowLevelZeroModel), "Please boost the model before loading!"
model._force_wait_all_gather()
return super().save_sharded_model(
model, checkpoint_path, gather_dtensor, prefix, max_shard_size, use_safetensors
)
def save_lora_as_pretrained(self, model, checkpoint, use_safetensors):
if os.path.isfile(checkpoint):
logging.error(f"Provided path ({checkpoint}) should be a directory, not a file")
@ -231,6 +270,7 @@ class LowLevelZeroCheckpointIO(TorchDDPCheckpointIO):
from peft import PeftModel
assert isinstance(model, ModelWrapper), "Please boost the model before saving!"
model._force_wait_all_gather()
peft_model = model.unwrap()
assert isinstance(
peft_model, PeftModel
@ -290,6 +330,7 @@ class LowLevelZeroPlugin(DPPluginBase):
reduce_bucket_size_in_m: int = 12,
communication_dtype: Optional[torch.dtype] = None,
overlap_communication: bool = True,
overlap_allgather: bool = False,
cpu_offload: bool = False,
master_weights: bool = True,
verbose: bool = False,
@ -315,6 +356,7 @@ class LowLevelZeroPlugin(DPPluginBase):
partition_grad=(stage == 2),
cpu_offload=cpu_offload,
master_weights=master_weights,
overlap_allgather=overlap_allgather,
)
self.lora_enabled = False
self.verbose = verbose
@ -431,7 +473,9 @@ class LowLevelZeroPlugin(DPPluginBase):
self.add_lora_params_to_optimizer(model, optimizer)
if not isinstance(model, ModelWrapper):
model = LowLevelZeroModel(model, self.precision)
model = LowLevelZeroModel(
model, self.precision, overlap_allgather=self.zero_optim_kwargs["overlap_allgather"]
)
# TODO: Support Galore + ZeRO
zero_stage = self.stage

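The same flag is exposed on LowLevelZeroPlugin and forwarded through zero_optim_kwargs into both LowLevelZeroModel and LowLevelZeroOptimizer, as the hunk above shows. A correspondingly minimal sketch (argument values illustrative, boosting identical to the previous example):

from colossalai.booster import Booster
from colossalai.booster.plugin import LowLevelZeroPlugin

# stage=1 or 2 selects the ZeRO stage; overlap_allgather defers the post-step
# all-gather wait into the next forward pass via ZeroOpHook.
plugin = LowLevelZeroPlugin(stage=1, precision="bf16", overlap_allgather=True)
booster = Booster(plugin=plugin)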
@ -195,6 +195,7 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
"""
assert isinstance(model, ModelWrapper), "Please boost the model before saving!"
model._force_wait_all_gather()
model = model.unwrap()
if os.path.isfile(checkpoint):
@ -303,6 +304,7 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
This argument should be manually set to False since params on same device might be stored in different files.
"""
assert isinstance(model, ModelWrapper), "Please boost the model before loading!"
model._force_wait_all_gather()
model_before_wrapping = model # backup for model before wrapping
model = model.unwrap()
@ -639,6 +641,7 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
logging.warning("Please avoid using unsharded checkpointing methods when dealing with large models!")
assert isinstance(model, ModelWrapper), "Please boost the model before saving!"
model._force_wait_all_gather()
model = model.unwrap()
if self.dp_rank != 0:
@ -679,6 +682,7 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
logging.warning("Please avoid using unsharded checkpointing methods when dealing with large models!")
assert isinstance(model, ModelWrapper), "Please boost the model before loading!"
model._force_wait_all_gather()
strict = False
model_before_wrapping = model
model = model.unwrap()

@ -91,7 +91,11 @@ def _broadcast_object_list(
my_rank = dist.get_rank()
# Serialize object_list elements to tensors on src rank.
if my_rank == src:
if Version(torch.__version__) >= Version("1.13.0"):
if Version(torch.__version__) >= Version("2.3.0"):
tensor_list, size_list = zip(
*[c10d._object_to_tensor(obj, device=current_device, group=group) for obj in object_list]
)
elif Version(torch.__version__) >= Version("1.13.0"):
tensor_list, size_list = zip(*[c10d._object_to_tensor(obj, device=current_device) for obj in object_list])
else:
tensor_list, size_list = zip(*[c10d._object_to_tensor(obj) for obj in object_list])
@ -276,7 +280,11 @@ def _send_recv_serialization_object(
send_object_tensor = None
send_object_size_tensor = None
if object is not None and send_dst is not None:
if Version(torch.__version__) >= Version("1.13.0"):
if Version(torch.__version__) >= Version("2.3.0"):
send_object_tensor, send_object_size_tensor = c10d._object_to_tensor(
object, device=current_device, group=send_group
)
elif Version(torch.__version__) >= Version("1.13.0"):
send_object_tensor, send_object_size_tensor = c10d._object_to_tensor(object, device=current_device)
else:
send_object_tensor, send_object_size_tensor = c10d._object_to_tensor(object)

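Both hunks above apply the same torch-version gate around c10d._object_to_tensor, whose signature gained a device argument in 1.13 and a group argument in 2.3. Factored out as a standalone helper, the branching looks roughly like this (a sketch, not code from this PR):

import torch
from packaging.version import Version
from torch.distributed import distributed_c10d as c10d

def _object_to_tensor_compat(obj, device, group=None):
    # torch >= 2.3.0 expects the process group to be passed explicitly
    if Version(torch.__version__) >= Version("2.3.0"):
        return c10d._object_to_tensor(obj, device=device, group=group)
    # torch >= 1.13.0 accepts a device but no group
    if Version(torch.__version__) >= Version("1.13.0"):
        return c10d._object_to_tensor(obj, device=device)
    return c10d._object_to_tensor(obj)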
@ -1,3 +1,4 @@
import math
from typing import List, Optional, Tuple, Union
import torch
@ -513,7 +514,6 @@ def get_qwen2_flash_attention_forward(shard_config: ShardConfig, sp_mode=None, s
query_states = self.q_proj(hidden_states)
key_states = self.k_proj(hidden_states)
value_states = self.v_proj(hidden_states)
# sp: all-to-all communication when introducing sequence parallel
if sp_mode == "all_to_all":
query_states = all_to_all_comm(query_states, sp_group)
@ -698,9 +698,9 @@ def get_qwen2_model_forward_for_flash_attn(shard_config: ShardConfig, sp_mode=No
next_decoder_cache = None
if sp_mode in ["ring", "split_gather"]:
inputs_embeds = split_forward_gather_backward(inputs_embeds, 1, sp_group)
hidden_states = split_forward_gather_backward(hidden_states, 1, sp_group)
elif sp_mode == "all_to_all":
inputs_embeds = split_forward_gather_backward(inputs_embeds, 1, sp_group, 1 / sp_size)
hidden_states = split_forward_gather_backward(hidden_states, 1, sp_group, 1 / sp_size)
for decoder_layer in self.layers:
if output_hidden_states:

@ -473,7 +473,7 @@ class LayoutConverter(metaclass=SingletonMeta):
for process_group in used_process_groups:
try:
dist.get_rank(process_group)
except RuntimeError as e:
except (ValueError, RuntimeError) as e:
# If the group is not registered, it means it has been deleted
if str(e) == (
f"Group {process_group} is not registered, please create group with torch.distributed.new_group API"

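The widened except clause above suggests that, depending on the torch version, dist.get_rank raises ValueError rather than RuntimeError for a group that was never registered or has been destroyed. A minimal sketch of the same liveness check, with an illustrative helper name:

import torch.distributed as dist

def process_group_is_alive(process_group) -> bool:
    try:
        dist.get_rank(process_group)
        return True
    except (ValueError, RuntimeError):
        # raised when the group was never registered or has been deleted
        return False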
@ -1,4 +1,3 @@
from copy import deepcopy
from typing import Dict, List
from ..utils import merge_same_dim_mesh_list
@ -23,10 +22,11 @@ class DimSpec:
Otherwise, the element in shard_list means the data will be sharded in that dimension.
"""
_DIFFERENCE_DICT = None
def __init__(self, shard_list):
self.is_replica = len(shard_list) == 0
self.shard_list = shard_list
self.build_difference_2d_dict()
def __eq__(self, other):
return str(self) == str(other)
@ -39,24 +39,43 @@ class DimSpec:
target += str(dim)
return target
def _convert_str_to_shard_list(self, str_spec):
@property
def difference_dict(self):
"""
Convert str_spec into shard_list.
Returns the difference dict, and lazily initializes it when needed
Argument:
str_spec(str): dim spec in str type.
Return:
difference_dict(Dict[Tuple[int, int], Union[int, float, str]]):
difference dict
"""
if self._DIFFERENCE_DICT is None:
self._DIFFERENCE_DICT = self._build_difference_2d_dict()
if str_spec == "R":
return []
if str_spec == "S0":
return [0]
if str_spec == "S1":
return [1]
if str_spec == "S01":
return [0, 1]
return self._DIFFERENCE_DICT
def build_difference_2d_dict(self):
def dim_diff(self, other):
"""
The difference between two DimSpec.
Argument:
other(DimSpec): the dim spec to compare with.
Return:
difference(int): the difference between two DimSpec.
Example:
dim_spec = DimSpec([0])
other_dim_spec = DimSpec([0, 1])
print(dim_spec.dim_diff(other_dim_spec))
Output:
5
"""
difference = self.difference_dict[(str(self), str(other))]
return difference
@classmethod
def _build_difference_2d_dict(cls):
"""
Build a difference mapping for 2D device mesh case. It will be used to
compute the difference between DimSpec pairs.
@ -67,9 +86,8 @@ class DimSpec:
difference_dict = {}
for source_spec in source_spec_list:
for target_spec in target_spec_list:
spec_pair = (deepcopy(source_spec), deepcopy(target_spec))
source_shard_list = self._convert_str_to_shard_list(source_spec)
target_shard_list = self._convert_str_to_shard_list(target_spec)
source_shard_list = cls._convert_str_to_shard_list(source_spec)
target_shard_list = cls._convert_str_to_shard_list(target_spec)
# source same as target
if source_shard_list == target_shard_list:
@ -112,30 +130,27 @@ class DimSpec:
else:
difference = NAN
difference_dict[spec_pair] = difference
difference_dict[(source_spec, target_spec)] = difference
self.difference_dict = difference_dict
return difference_dict
def dim_diff(self, other):
@staticmethod
def _convert_str_to_shard_list(str_spec):
"""
The difference between two _DimSpec.
Convert str_spec into shard_list.
Argument:
other(_DimSpec): the dim spec to compare with.
Return:
difference(int): the difference between two _DimSpec.
Example:
dim_spec = _DimSpec([0])
other_dim_spec = _DimSpec([0, 1])
print(dim_spec.difference(other_dim_spec))
Output:
5
str_spec(str): dim spec in str type.
"""
difference = self.difference_dict[(str(self), str(other))]
return difference
if str_spec == "R":
return []
if str_spec == "S0":
return [0]
if str_spec == "S1":
return [1]
if str_spec == "S01":
return [0, 1]
class ShardingSpec:

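Both DimSpec variants in this PR drop the per-instance build_difference_2d_dict call in __init__ and instead build the pairwise difference table once, caching it at class level behind a lazy property. A generic sketch of that pattern with illustrative names (not the repository code):

class LazyTable:
    _TABLE = None  # shared by all instances, built on first access

    @property
    def table(self):
        if type(self)._TABLE is None:
            type(self)._TABLE = self._build_table()
        return type(self)._TABLE

    @classmethod
    def _build_table(cls):
        # stands in for the expensive construction, e.g. the 2D difference dict
        return {("R", "S01"): 5}

Because the cache lives on the class, constructing many DimSpec objects no longer rebuilds the mapping.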
@ -1,5 +1,4 @@
import operator
from copy import deepcopy
from functools import reduce
import torch
@ -27,10 +26,11 @@ class _DimSpec:
Otherwise, the element in shard_list means the data will be sharded in that dimension.
"""
_DIFFERENCE_DICT = None
def __init__(self, shard_list):
self.is_replica = len(shard_list) == 0
self.shard_list = shard_list
self.build_difference_2d_dict()
def __eq__(self, other):
return str(self) == str(other)
@ -43,27 +43,46 @@ class _DimSpec:
target += str(dim)
return target
def _convert_str_to_shard_list(self, str_spec):
@property
def difference_dict(self):
"""
Convert str_spec into shard_list.
Returns the difference dict, and lazily initializes it when needed
Argument:
str_spec(str): dim spec in str type.
Return:
difference_dict(Dict[Tuple[int, int], Union[int, float, str]]):
difference dict
"""
if self._DIFFERENCE_DICT is None:
self._DIFFERENCE_DICT = self._build_difference_2d_dict()
if str_spec == "R":
return []
if str_spec == "S0":
return [0]
if str_spec == "S1":
return [1]
if str_spec == "S01":
return [0, 1]
return self._DIFFERENCE_DICT
def build_difference_2d_dict(self):
def difference(self, other):
"""
The difference between two _DimSpec.
Argument:
other(_DimSpec): the dim spec to compare with.
Return:
difference(int): the difference between two _DimSpec.
Example:
dim_spec = _DimSpec([0])
other_dim_spec = _DimSpec([0, 1])
print(dim_spec.difference(other_dim_spec))
Output:
5
"""
difference = self.difference_dict[(str(self), str(other))]
return difference
@classmethod
def _build_difference_2d_dict(cls):
"""
Build a difference mapping for 2D device mesh case. It will be used to
compute the difference between DimSpec pairs.
compute the difference between _DimSpec pairs.
"""
source_spec_list = ["R", "S0", "S1", "S01"]
@ -71,9 +90,8 @@ class _DimSpec:
difference_dict = {}
for source_spec in source_spec_list:
for target_spec in target_spec_list:
spec_pair = (deepcopy(source_spec), deepcopy(target_spec))
source_shard_list = self._convert_str_to_shard_list(source_spec)
target_shard_list = self._convert_str_to_shard_list(target_spec)
source_shard_list = cls._convert_str_to_shard_list(source_spec)
target_shard_list = cls._convert_str_to_shard_list(target_spec)
# source same as target
if source_shard_list == target_shard_list:
@ -116,30 +134,27 @@ class _DimSpec:
else:
difference = NAN
difference_dict[spec_pair] = difference
difference_dict[(source_spec, target_spec)] = difference
self.difference_dict = difference_dict
return difference_dict
def difference(self, other):
@staticmethod
def _convert_str_to_shard_list(str_spec):
"""
The difference between two _DimSpec.
Convert str_spec into shard_list.
Argument:
other(_DimSpec): the dim spec to compare with.
Return:
difference(int): the difference between two _DimSpec.
Example:
dim_spec = _DimSpec([0])
other_dim_spec = _DimSpec([0, 1])
print(dim_spec.difference(other_dim_spec))
Output:
5
str_spec(str): dim spec in str type.
"""
difference = self.difference_dict[(str(self), str(other))]
return difference
if str_spec == "R":
return []
if str_spec == "S0":
return [0]
if str_spec == "S1":
return [1]
if str_spec == "S01":
return [0, 1]
class ShardingSpecException(Exception):

@ -23,6 +23,7 @@ from colossalai.logging import get_dist_logger
from ._utils import calculate_global_norm_from_list, has_inf_or_nan, release_param_grad, sync_tensor
from .bookkeeping import BucketStore, GradientStore, TensorBucket
from .zero_hook import set_all_gather_handle, wait_all_gather_handle
class LowLevelZeroFP16MixedPrecisionMixin(FP16MixedPrecisionMixin):
@ -83,6 +84,7 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
dp_process_group: Optional[ProcessGroup] = None,
forced_dtype: Optional[torch.dtype] = None,
master_weights: bool = True, # master weights
overlap_allgather: bool = False,
):
super(LowLevelZeroOptimizer, self).__init__(optim=optimizer)
@ -121,6 +123,7 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
# communication params
self._overlap_communication = overlap_communication
self._overlap_allgather = overlap_allgather
self._reduce_bucket_size = reduce_bucket_size
self._communication_dtype = communication_dtype
@ -145,6 +148,8 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
# record the padding size of each param
self._padding_map = dict()
# the padded working param is the all-gather buffer and shares memory with the working param
self._working_param_to_padded_working_param = dict()
# mapping working param and master param
self.master_to_working_param = dict()
@ -245,11 +250,12 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
with torch.no_grad():
if padding_size > 0:
padding_param = torch.nn.functional.pad(param.data.view(-1), [0, padding_size])
# reset working params' ptr when no master weights
if self._master_weights == False:
# # reset working params' ptr when no master weights
# if self._master_weights == False:
param.data = padding_param[: param.numel()].view(param.shape)
else:
padding_param = param.data.view(-1)
self._working_param_to_padded_working_param[param] = padding_param
splited_params = padding_param.split(
padding_param.numel() // self.pid_to_bucket_store[id(param)].world_size
@ -258,7 +264,7 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
# use fp32 when master_weights is True
if self._master_weights is True:
splited_param_current_rank = splited_params.detach().float().to(device)
splited_param_current_rank = splited_params.detach().clone().float().to(device)
else:
splited_param_current_rank = splited_params
@ -549,12 +555,13 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
working_param = real_working_params[group_id][idx]
param_to_gather = master_param.to(device).to(self._dtype)
pg = self.param_to_pg[working_param]
padded_working_param = self._working_param_to_padded_working_param[working_param]
if self._overlap_allgather:
handle = dist.all_gather_into_tensor(padded_working_param, param_to_gather, pg, async_op=True)
set_all_gather_handle(working_param, handle)
else:
if param_to_gather.numel() > self.pg_to_tensor_bucket[pg].max_size:
buffer_tensor = torch.empty_like(
torch.cat([param_to_gather for _ in range(dist.get_world_size(pg))])
)
dist.all_gather_into_tensor(buffer_tensor, param_to_gather, pg)
working_param.data.copy_(buffer_tensor[: working_param.numel()].reshape_as(working_param))
dist.all_gather_into_tensor(padded_working_param, param_to_gather, pg)
continue
try:
self.pg_to_tensor_bucket[pg].add_to_bucket(param_to_gather, write_back_tensor=working_param)
@ -562,6 +569,7 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
self.pg_to_tensor_bucket[pg].all_gather(pg)
self.pg_to_tensor_bucket[pg].add_to_bucket(param_to_gather, write_back_tensor=working_param)
self.optim.param_groups[group_id]["params"] = self._master_param_groups_of_current_rank[group_id]
if not self._overlap_allgather:
for pg, tensor_bucket in self.pg_to_tensor_bucket.items():
if not tensor_bucket.is_empty():
tensor_bucket.all_gather(pg)
@ -892,3 +900,7 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
def get_partitioned_gradients_by_param_id(self, group_id: int, param_id: int) -> List:
grad_store = self.pid_to_grad_store[param_id]
return grad_store.get_partitioned_gradients_by_param_id(group_id, param_id)
def _force_wait_all_gather(self):
for param in self._working_param_to_padded_working_param.keys():
wait_all_gather_handle(param)

@ -0,0 +1,33 @@
from typing import List
from torch._tensor import Tensor
from colossalai.tensor.param_op_hook import ColoParamOpHook
_ALL_GATHER_HANDLE = "_all_gather_handle"
def wait_all_gather_handle(p):
if hasattr(p, _ALL_GATHER_HANDLE):
handle = getattr(p, _ALL_GATHER_HANDLE)
handle.wait()
delattr(p, _ALL_GATHER_HANDLE)
def set_all_gather_handle(p, handle):
setattr(p, _ALL_GATHER_HANDLE, handle)
class ZeroOpHook(ColoParamOpHook):
def pre_forward(self, params: List[Tensor]) -> None:
for p in params:
wait_all_gather_handle(p)
def post_forward(self, params: List[Tensor]) -> None:
pass
def pre_backward(self, params: List[Tensor]) -> None:
pass
def post_backward(self, params: List[Tensor]) -> None:
pass

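Putting the new hook together with the optimizer-side change: step() launches an asynchronous all-gather into the padded working-parameter buffer and stashes the handle on the parameter, and ZeroOpHook.pre_forward waits on that handle only when the parameter is first used (the plugins convert parameters to ColoParameter so their ops trigger the hook). A hedged sketch of the two sides; the gather helper is illustrative, the imported names follow the hunks above:

import torch.distributed as dist
from colossalai.tensor.param_op_hook import ColoParamOpHookManager
from colossalai.zero.low_level.zero_hook import ZeroOpHook, set_all_gather_handle

def gather_param_async(padded_working_param, shard, pg, working_param):
    # non-blocking all-gather; the wait is deferred to the next forward pass
    handle = dist.all_gather_into_tensor(padded_working_param, shard, pg, async_op=True)
    set_all_gather_handle(working_param, handle)

def forward_with_lazy_wait(model, *inputs):
    # ZeroOpHook.pre_forward calls wait_all_gather_handle on each parameter it sees
    with ColoParamOpHookManager.use_hooks(ZeroOpHook()):
        return model(*inputs)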
@ -98,6 +98,7 @@ def main():
parser.add_argument("--disable-async-reduce", action="store_true", help="Disable the asynchronous reduce operation")
parser.add_argument("--prefetch_num", type=int, default=0, help="chunk prefetch max number")
parser.add_argument("--no_cache", action="store_true")
parser.add_argument("--overlap_allgather", action="store_true")
args = parser.parse_args()
colossalai.launch_from_torch()
@ -199,9 +200,9 @@ def main():
enable_flash_attention=args.xformers,
microbatch_size=args.mbs,
precision="bf16",
dp_outside=False,
overlap_p2p=args.overlap,
enable_metadata_cache=not args.no_cache,
overlap_allgather=args.overlap_allgather,
**hybrid_kwargs,
)
elif args.plugin == "3d_cpu":

@ -113,13 +113,13 @@ class PerformanceEvaluator:
self.disable = self.ignore_steps > 0 and step < self.ignore_steps
if self.disable:
return
get_accelerator().synchronize()
# get_accelerator().synchronize()
self.timer.start()
def on_step_end(self, input_ids: Tensor, **kwargs) -> None:
if self.disable:
return
get_accelerator().synchronize()
# get_accelerator().synchronize()
self.timer.end()
batch_size, seq_len = input_ids.shape

@ -8,7 +8,7 @@ click
fabric
contexttimer
ninja
torch>=2.1.0,<2.3.0
torch>=2.1.0,<=2.3.0
safetensors
einops
pydantic

@ -135,51 +135,6 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
"precision": "fp16",
"initial_scale": 1,
},
{
"tp_size": 1,
"pp_size": 2,
"num_microbatches": 2,
"enable_all_optimization": True,
"use_lazy_init": True,
"zero_stage": 1,
"precision": "fp16",
"initial_scale": 1,
},
],
)
def run_qwen2_test(test_config):
sub_model_zoo = model_zoo.get_sub_registry("transformers_qwen2")
for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config)
clear_layout_converter()
Randomizer.reset_index()
torch.cuda.empty_cache()
@parameterize(
"test_config",
[
{
"tp_size": 2,
"pp_size": 2,
"num_microbatches": 4,
"enable_all_optimization": False,
"use_lazy_init": False,
"precision": "fp32",
"initial_scale": 1,
},
{
"tp_size": 2,
"pp_size": 2,
"num_microbatches": 4,
"enable_all_optimization": False,
"use_lazy_init": False,
"precision": "fp16",
"zero_stage": 1,
"initial_scale": 1,
},
{ # Ulysess + Flash attention
"tp_size": 1,
"pp_size": 2,
@ -242,6 +197,54 @@ def run_qwen2_test(test_config):
"precision": "fp16",
"initial_scale": 1,
},
{
"tp_size": 1,
"pp_size": 2,
"num_microbatches": 2,
"enable_all_optimization": True,
"use_lazy_init": True,
"zero_stage": 1,
"precision": "fp16",
"initial_scale": 1,
},
],
)
def run_qwen2_test(test_config):
sub_model_zoo = model_zoo.get_sub_registry("transformers_qwen2")
for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
try:
check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config)
except Exception as e:
print(f"Failed config: {test_config}")
raise e
clear_layout_converter()
Randomizer.reset_index()
torch.cuda.empty_cache()
@parameterize(
"test_config",
[
{
"tp_size": 2,
"pp_size": 2,
"num_microbatches": 4,
"enable_all_optimization": False,
"use_lazy_init": False,
"precision": "fp32",
"initial_scale": 1,
},
{
"tp_size": 2,
"pp_size": 2,
"num_microbatches": 4,
"enable_all_optimization": False,
"use_lazy_init": False,
"precision": "fp16",
"zero_stage": 1,
"initial_scale": 1,
},
{
"tp_size": 2,
"pp_size": 2,
@ -259,7 +262,11 @@ def run_qwen2_3d_test(test_config):
sub_model_zoo = model_zoo.get_sub_registry("transformers_qwen2")
for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
try:
check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config)
except Exception as e:
print(f"Failed config: {test_config}")
raise e
clear_layout_converter()
Randomizer.reset_index()

@ -64,8 +64,12 @@ def exam_zero_1_2_grad_acc():
zero1_optimizer.step()
zero2_optimizer.step()
zero1_optimizer._force_wait_all_gather()
zero2_optimizer._force_wait_all_gather()
# check updated param
for z1p, z2p in zip(zero1_model.parameters(), zero2_model.parameters()):
assert not hasattr(z1p, "_all_gather_handle")
assert torch.equal(z1p.data, z2p.data)

@ -177,6 +177,8 @@ def exam_zero_1_torch_ddp(world_size, dtype: torch.dtype, master_weights: bool):
# torch ddp step
torch_optimizer.step()
zero_optimizer._force_wait_all_gather()
# check updated param
for (n, p), z1p in zip(torch_model.named_parameters(), zero_model.parameters()):
loose_close(p, z1p, dtype=dtype)

@ -1 +1 @@
0.4.0
0.4.1
