[gemini] gemini support extra-dp (#5043)

* support ddp

* fix

* fix

* fix

fix

* support ddp

* fix

* fix

* fix

fix

* simplify tests

* fix

* fix

* fix

fix

fix

* fix
pull/5060/head
flybird11111 2023-11-16 21:03:04 +08:00 committed by GitHub
parent b2ad0d9e8f
commit 3e02154710
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
10 changed files with 96 additions and 137 deletions

View File

@ -10,6 +10,7 @@ import torch.nn as nn
from torch.optim import Optimizer
from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
from torch.utils.data import DataLoader
from torch.distributed.distributed_c10d import _get_default_group
from colossalai.checkpoint_io import CheckpointIndexFile, CheckpointIO, GeneralCheckpointIO
from colossalai.checkpoint_io.utils import (
@ -34,8 +35,7 @@ __all__ = ["GeminiPlugin"]
SUPPORTED_PRECISION = ["fp16", "bf16"]
PRECISION_STR_TO_DTYPE = {"fp16": torch.half, "bf16": torch.bfloat16}
DP_AXIS = 0
TP_AXIS = 1
ZERO_AXIS, DP_AXIS, TP_AXIS = 0, 1, 2
def get_param_info(optim: Optimizer):
# Get a backup of necessary information of parameters for future use, which includes:
@ -304,8 +304,8 @@ class GeminiPlugin(DPPluginBase):
max_norm (float, optional): max_norm used for `clip_grad_norm`. You should notice that you shall not do
clip_grad_norm by yourself when using ZeRO DDP. The ZeRO optimizer will take care of clip_grad_norm.
norm_type (float, optional): norm_type used for `clip_grad_norm`.
enable_tensor_parallelism (bool, optional): Whether to use tensor parallelism strategy, which is implemented in Shardformer. Default to False.
tp_size (int, optional): If 'enable_tensor_parallelism' is set to true, please configure 'tp_size' which determines the size of the tensor parallel process group. Default to 1.
tp_size (int, optional): If 'tp_size' is set to be greater than 1, it means using tensor parallelism strategy, which is implemented in Shardformer, 'tp_size' determines the size of the tensor parallel process group. Default to 1.
extra_dp_size (int, optional): If 'extra_dp_size' is set to be greater than 1, it means creating another group to run with a ddp-like strategy. Default to 1.
enable_all_optimization (bool, optional): Whether to switch on all the optimizations supported by Shardformer.
Currently all the optimization methods include fused normalization, flash attention and JIT.
Defaults to False.
@ -347,8 +347,8 @@ class GeminiPlugin(DPPluginBase):
max_scale: float = 2**32,
max_norm: float = 0.0,
norm_type: float = 2.0,
enable_tensor_parallelism: bool = False,
tp_size: int = 1,
extra_dp_size:int = 1,
enable_all_optimization: bool = False,
enable_fused_normalization: bool = False,
enable_flash_attention: bool = False,
@ -393,7 +393,7 @@ class GeminiPlugin(DPPluginBase):
max_norm=max_norm,
norm_type=norm_type,
)
self.enable_tensor_parallelism = enable_tensor_parallelism
self.enable_tensor_parallelism = tp_size > 1
self.enable_all_optimization = enable_all_optimization
self.enable_fused_normalization = enable_fused_normalization
self.enable_flash_attention = enable_flash_attention
@ -402,12 +402,17 @@ class GeminiPlugin(DPPluginBase):
self.enable_sequence_overlap = enable_sequence_overlap
self.verbose = verbose
self.tp_size = tp_size if self.enable_tensor_parallelism else 1
self.dp_size = dist.get_world_size() // self.tp_size
assert self.dp_size > 1, f"The size of the DP group should be greater than 1. Please reduce the TP group size."
self.pg_mesh = ProcessGroupMesh(self.dp_size, self.tp_size)
self.dp_group = self.pg_mesh.get_group_along_axis(DP_AXIS)
self.tp_group = self.pg_mesh.get_group_along_axis(TP_AXIS)
self.tp_size = tp_size
self.extra_dp_size = extra_dp_size
world_size = dist.get_world_size()
self.zero_size = world_size // (self.tp_size * self.extra_dp_size)
assert world_size == (self.tp_size * self.extra_dp_size) * self.zero_size, f"The global group size can't be evenly divided by the subgroup size."
self.pg_mesh = ProcessGroupMesh(self.zero_size, self.extra_dp_size, self.tp_size)
self.zero_group = self.pg_mesh.get_group_along_axis(ZERO_AXIS) if self.zero_size < world_size else _get_default_group()
self.extra_dp_group = self.pg_mesh.get_group_along_axis(DP_AXIS) if self.extra_dp_size > 1 else None
self.tp_group = self.pg_mesh.get_group_along_axis(TP_AXIS) if self.tp_size > 1 else None
self.shard_config = ShardConfig(
tensor_parallel_process_group=self.tp_group,
enable_tensor_parallelism=self.enable_tensor_parallelism,
@ -458,7 +463,7 @@ class GeminiPlugin(DPPluginBase):
shardformer = ShardFormer(self.shard_config)
model, _ = shardformer.optimize(model)
model = GeminiDDP(model, **self.gemini_config, process_group=self.dp_group, verbose=self.verbose)
model = GeminiDDP(model, **self.gemini_config, zero_group=self.zero_group, extra_dp_group=self.extra_dp_group, verbose=self.verbose)
if optimizer is not None and not isinstance(optimizer, OptimizerWrapper):
optimizer = GeminiOptimizer(

View File

@ -61,12 +61,13 @@ class Chunk:
def __init__(
self,
chunk_size: int,
process_group: ProcessGroup,
zero_group: ProcessGroup,
dtype: torch.dtype,
init_device: Optional[torch.device] = None,
cpu_shard_init: bool = False,
keep_gathered: bool = False,
pin_memory: bool = False,
extra_dp_group: ProcessGroup = None,
) -> None:
"""
Chunk: A container owning a piece of contiguous memory space for tensors
@ -76,7 +77,7 @@ class Chunk:
Args:
chunk_size (int): the number of elements in the chunk
process_group (ProcessGroup): the process group of this chunk
zero_group (ProcessGroup): the process group of this chunk
dtype (torch.dtype): the data type of the chunk
init_device (torch.device): optional, During the chunk construction process, where the tensor is stored.
The default value is None, which is the current GPU
@ -90,9 +91,11 @@ class Chunk:
self.chunk_size = chunk_size
self.utilized_size = 0
self.torch_pg = process_group
self.torch_pg = zero_group
self.pg_size = dist.get_world_size(self.torch_pg)
self.pg_rank = dist.get_rank(self.torch_pg)
self.extra_dp_group = extra_dp_group
self.extra_dp_size = dist.get_world_size(self.extra_dp_group) if self.extra_dp_group is not None else 1
# the chunk size should be divisible by the dp degree
if not keep_gathered:
@ -384,14 +387,20 @@ class Chunk:
# just move cuda_global_chunk to cuda_shard
# the communication is not necessary
self.__scatter()
if self.extra_dp_group is not None:
dist.all_reduce(self.cuda_shard, group=self.extra_dp_group)
elif self.keep_gathered:
# we use all-reduce here
dist.all_reduce(self.cuda_global_chunk, group=self.torch_pg)
if self.extra_dp_group is not None:
dist.all_reduce(self.cuda_global_chunk, group=self.extra_dp_group)
else:
self.cuda_shard = torch.empty(self.shard_size, dtype=self.dtype, device=get_current_device())
input_list = list(torch.chunk(self.cuda_global_chunk, chunks=self.pg_size, dim=0))
dist.reduce_scatter(self.cuda_shard, input_list, group=self.torch_pg)
if self.extra_dp_group is not None:
dist.all_reduce(self.cuda_shard, group=self.extra_dp_group)
free_storage(self.cuda_global_chunk)
self.is_gathered = False
@ -608,10 +617,11 @@ class Chunk:
# grad chunk is not initialized
grad_chunk = Chunk(
chunk_size=self.chunk_size,
process_group=self.torch_pg,
zero_group=self.torch_pg,
dtype=self.dtype,
keep_gathered=self.keep_gathered,
pin_memory=self.pin_memory,
extra_dp_group=self.extra_dp_group,
)
grad_chunk.num_tensors = self.num_tensors
grad_chunk.utilized_size = self.utilized_size

View File

@ -38,7 +38,8 @@ class ChunkManager:
tensor: torch.Tensor,
group_type: str,
config_key: int,
process_group: ProcessGroup,
zero_group: ProcessGroup,
extra_dp_group: ProcessGroup = None,
cpu_offload: bool = False,
pin_memory: bool = False,
) -> None:
@ -76,15 +77,16 @@ class ChunkManager:
if tensor.numel() > chunk_size:
chunk_size = tensor.numel()
dp_size = dist.get_world_size(process_group)
dp_size = dist.get_world_size(zero_group)
chunk_size = chunk_size + (-chunk_size % dp_size)
chunk = Chunk(
chunk_size=chunk_size,
process_group=process_group,
zero_group=zero_group,
dtype=tensor.dtype,
cpu_shard_init=cpu_offload,
pin_memory=pin_memory,
extra_dp_group=extra_dp_group,
**chunk_kwargs,
)

View File

@ -86,9 +86,10 @@ class GeminiDDP(ModelWrapper):
strict_ddp_mode: bool = False,
scatter_after_inference: bool = True,
mixed_precision: torch.dtype = torch.float16,
process_group: Optional[ProcessGroup] = None,
zero_group: Optional[ProcessGroup] = None,
memstats: Optional[MemStats] = None, # genimi memory stats
master_weights: bool = True,
extra_dp_group: Optional[ProcessGroup] = None,
verbose: bool = False,
) -> None:
assert mixed_precision in (torch.float16, torch.bfloat16)
@ -105,7 +106,7 @@ class GeminiDDP(ModelWrapper):
search_range_m=search_range_m,
min_chunk_size_m=min_chunk_size_m,
strict_ddp_flag=strict_ddp_mode,
process_group=process_group,
process_group=zero_group,
verbose=verbose,
)
self.gemini_manager = GeminiManager(
@ -128,7 +129,8 @@ class GeminiDDP(ModelWrapper):
self.name2param: Dict[str, nn.Parameter] = dict()
self.scatter_after_inference = scatter_after_inference
self.mixed_precision = mixed_precision
self.dp_process_group = process_group or _get_default_group()
self.zero_group = zero_group or _get_default_group()
self.extra_dp_group = extra_dp_group
self.reuse_fp16_chunk = master_weights
self.master_weights = master_weights
@ -377,8 +379,12 @@ class GeminiDDP(ModelWrapper):
self.chunk_manager.release_chunk(chunk)
if grad_chunk.is_gathered:
grad_chunk.cuda_global_chunk.div_(chunk.pg_size)
if self.extra_dp_group is not None:
grad_chunk.cuda_global_chunk.div_(chunk.extra_dp_size)
else:
grad_chunk.cuda_shard.div_(chunk.pg_size)
if self.extra_dp_group is not None:
grad_chunk.cuda_shard.div_(chunk.extra_dp_size)
# check overflow elements
self.overflow_counter += grad_chunk.has_inf_or_nan
# record l2 norm for gradient clipping. flag is bound to fp16 chunk
@ -733,7 +739,7 @@ class GeminiDDP(ModelWrapper):
unexpected_keys.append(key)
def _init_chunks(self, param_order, strict_ddp_mode: bool, cpu_offload: bool, pin_memory: bool):
dp_world_size = dist.get_world_size(self.dp_process_group)
zero_world_size = dist.get_world_size(self.zero_group)
for p in param_order.generate():
self._preprocess_param(p)
assert type(p) is ColoParameter
@ -753,8 +759,9 @@ class GeminiDDP(ModelWrapper):
self.chunk_manager.register_tensor(
tensor=p,
group_type="fp16_param",
config_key=dp_world_size,
process_group=self.dp_process_group,
config_key=zero_world_size,
zero_group=self.zero_group,
extra_dp_group=self.extra_dp_group,
cpu_offload=cpu_offload,
pin_memory=pin_memory,
)
@ -767,8 +774,9 @@ class GeminiDDP(ModelWrapper):
self.chunk_manager.register_tensor(
tensor=fp32_p,
group_type="fp32_param",
config_key=dp_world_size,
process_group=self.dp_process_group,
config_key=zero_world_size,
zero_group=self.zero_group,
extra_dp_group=self.extra_dp_group,
cpu_offload=cpu_offload,
pin_memory=pin_memory,
)

View File

@ -1,5 +1,6 @@
from contextlib import nullcontext
from typing import Optional
import pytest
import torch
import torch.distributed as dist
@ -17,14 +18,15 @@ from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
from tests.kit.model_zoo import model_zoo
def run_fn(init_method, model_fn, data_gen_fn, output_transform_fn, enable_tensor_parallelism) -> Optional[str]:
def run_fn(init_method, model_fn, data_gen_fn, output_transform_fn, zero_size, tp_size) -> Optional[str]:
try:
if init_method == "lazy":
ctx = LazyInitContext()
else:
ctx = nullcontext()
enable_all_optimization = True if enable_tensor_parallelism else False
plugin = GeminiPlugin(max_norm=1.0, initial_scale=2**5, enable_tensor_parallelism=enable_tensor_parallelism, enable_all_optimization=enable_all_optimization)
extra_dp_size = dist.get_world_size() // (zero_size * tp_size)
enable_all_optimization = True if tp_size > 1 else False
plugin = GeminiPlugin(max_norm=1.0, initial_scale=2**5, tp_size=tp_size, extra_dp_size=extra_dp_size, enable_all_optimization=enable_all_optimization)
booster = Booster(plugin=plugin)
with ctx:
model = model_fn()
@ -62,8 +64,9 @@ def run_fn(init_method, model_fn, data_gen_fn, output_transform_fn, enable_tenso
@parameterize("subset", ["torchvision", "transformers", "diffusers"])
@parameterize("init_method", ["none"])
@parameterize("enable_tensor_parallelism", [True, False])
def check_gemini_plugin(subset: str, init_method: str = "none", enable_tensor_parallelism: bool = True, early_stop: bool = True):
@parameterize("zero_size", [2])
@parameterize("tp_size", [2])
def check_gemini_plugin(subset: str, init_method: str = "none", early_stop: bool = True, zero_size: int = 1, tp_size: int = 1):
"""check gemini plugin over model zoo
Args:
@ -125,9 +128,9 @@ def check_gemini_plugin(subset: str, init_method: str = "none", enable_tensor_pa
# TODO debug blip2 when using tp, something wrong with shift_logits's shape
if "transformers_blip2" in name:
enable_tensor_parallelism = False
tp_size = 1
err = run_fn(init_method, model_fn, data_gen_fn, output_transform_fn, enable_tensor_parallelism)
err = run_fn(init_method, model_fn, data_gen_fn, output_transform_fn, zero_size, tp_size)
torch.cuda.empty_cache()
if err is None:
passed_models.append(name)
@ -153,6 +156,11 @@ def run_dist(rank, world_size, port, early_stop: bool = True):
def test_gemini_plugin(early_stop: bool = True):
spawn(run_dist, 4, early_stop=early_stop)
@pytest.mark.largedist
@rerun_if_address_is_in_use()
def test_gemini_plugin_3d(early_stop: bool = True):
spawn(run_dist, 8, early_stop=early_stop)
if __name__ == "__main__":
test_gemini_plugin(early_stop=False)

View File

@ -37,20 +37,21 @@ OPTIM_PLACEMENT_CONFIGS = [
@parameterize("placement_config", MODEL_PLACEMENT_CONFIGS)
@parameterize("model_name", ["transformers_bert_for_sequence_classification"])
@parameterize("use_safetensors", [False, True])
@parameterize("enable_tensor_parallelism", [True, False])
@parameterize("tp_size", [2])
def exam_state_dict_with_origin(placement_config, model_name, use_safetensors: bool, enable_tensor_parallelism: bool, tp_size: int):
@parameterize("tp_size", [1, 2])
@parameterize("zero_size", [2])
def exam_state_dict_with_origin(placement_config, model_name, use_safetensors: bool, tp_size: int, zero_size: int):
from transformers import BertForSequenceClassification
(model_fn, data_gen_fn, output_transform_fn, _, _) = next(iter(model_zoo.get_sub_registry(model_name).values()))
bert_model = model_fn()
enable_all_optimization = True if enable_tensor_parallelism else False
enable_all_optimization = True if tp_size > 1 else False
with shared_tempdir() as tempdir:
pretrained_path = os.path.join(tempdir, "pretrained")
bert_model.config.save_pretrained(save_directory=pretrained_path)
plugin = GeminiPlugin(**placement_config, enable_tensor_parallelism=enable_tensor_parallelism, tp_size=tp_size, enable_all_optimization=enable_all_optimization)
extra_dp_size = dist.get_world_size() // (zero_size * tp_size)
plugin = GeminiPlugin(**placement_config, tp_size=tp_size, enable_all_optimization=enable_all_optimization, extra_dp_size=extra_dp_size)
booster = Booster(plugin=plugin)
bert_model, _, _, _, _ = booster.boost(bert_model)
model_size = sum(p.numel() * p.element_size() for p in bert_model.parameters()) / 1024**2
@ -69,13 +70,14 @@ def exam_state_dict_with_origin(placement_config, model_name, use_safetensors: b
@parameterize("shard", [True, False])
@parameterize("model_name", ["transformers_gpt"])
@parameterize("size_per_shard", [32])
@parameterize("enable_tensor_parallelism", [True, False])
@parameterize("tp_size", [2])
def exam_state_dict(placement_config, shard: bool, model_name: str, size_per_shard: int, enable_tensor_parallelism: bool, tp_size: int):
@parameterize("tp_size", [1, 2])
@parameterize("zero_size", [2])
def exam_state_dict(placement_config, shard: bool, model_name: str, size_per_shard: int, tp_size: int, zero_size: int):
(model_fn, data_gen_fn, output_transform_fn, _, _) = next(iter(model_zoo.get_sub_registry(model_name).values()))
criterion = lambda x: x.mean()
enable_all_optimization = True if enable_tensor_parallelism else False
plugin = GeminiPlugin(**placement_config, precision="fp16", initial_scale=(2**14), enable_tensor_parallelism=enable_tensor_parallelism, tp_size=tp_size, enable_all_optimization=enable_all_optimization)
enable_all_optimization = True if tp_size > 1 else False
extra_dp_size = dist.get_world_size() // (zero_size * tp_size)
plugin = GeminiPlugin(**placement_config, precision="fp16", initial_scale=(2**14), tp_size=tp_size, extra_dp_size=extra_dp_size, enable_all_optimization=enable_all_optimization)
booster = Booster(plugin=plugin)
model = model_fn()
@ -158,3 +160,9 @@ def run_dist(rank, world_size, port):
@rerun_if_address_is_in_use()
def test_gemini_ckpIO(world_size):
spawn(run_dist, world_size)
@pytest.mark.largedist
@pytest.mark.parametrize("world_size", [8])
@rerun_if_address_is_in_use()
def test_gemini_ckpIO_3d(world_size):
spawn(run_dist, world_size)

View File

@ -124,25 +124,6 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
@parameterize(
"test_config",
[
{
"tp_size": 1,
"pp_size": 2,
"num_microbatches": 4,
"enable_all_optimization": True,
"use_lazy_init": True,
"precision": "fp16",
"max_norm": 5,
"initial_scale": 1,
},
{
"tp_size": 2,
"pp_size": 1,
"enable_all_optimization": True,
"use_lazy_init": False,
"precision": "fp16",
"max_norm": 5,
"initial_scale": 1,
},
{
"tp_size": 2,
"pp_size": 2,
@ -153,23 +134,6 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
"max_norm": 5,
"initial_scale": 1,
},
{
"tp_size": 1,
"pp_size": 2,
"num_microbatches": 4,
"enable_all_optimization": True,
"use_lazy_init": True,
"precision": "bf16",
"max_norm": 5,
},
{
"tp_size": 2,
"pp_size": 1,
"enable_all_optimization": True,
"use_lazy_init": False,
"precision": "bf16",
"max_norm": 5,
},
{
"tp_size": 2,
"pp_size": 2,

View File

@ -102,28 +102,11 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
@parameterize(
"test_config",
[
{
"tp_size": 1,
"pp_size": 2,
"num_microbatches": 4,
"enable_all_optimization": True,
"use_lazy_init": True,
"precision": "fp32",
"max_norm": 5,
},
{
"tp_size": 2,
"pp_size": 1,
"enable_all_optimization": True,
"use_lazy_init": False,
"precision": "fp32",
"max_norm": 5,
},
{
"tp_size": 2,
"pp_size": 2,
"num_microbatches": 4,
"enable_all_optimization": True,
"enable_all_optimization": False,
"use_lazy_init": False,
"precision": "fp32",
"max_norm": 5,
@ -148,7 +131,7 @@ def run_test(test_config):
"tp_size": 2,
"pp_size": 2,
"num_microbatches": 4,
"enable_all_optimization": True,
"enable_all_optimization": False,
"use_lazy_init": False,
"precision": "fp32",
"max_norm": 5,

View File

@ -106,17 +106,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
"pp_size": 2,
"num_microbatches": 4,
"zero_stage": 1,
"enable_all_optimization": True,
"use_lazy_init": True,
"precision": "fp16",
"max_norm": 5,
"initial_scale": 1,
},
{
"tp_size": 2,
"pp_size": 1,
"zero_stage": 1,
"enable_all_optimization": True,
"enable_all_optimization": False,
"use_lazy_init": False,
"precision": "fp16",
"max_norm": 5,
@ -126,36 +116,17 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
"tp_size": 2,
"pp_size": 1,
"zero_stage": 2,
"enable_all_optimization": True,
"enable_all_optimization": False,
"use_lazy_init": False,
"precision": "fp16",
"max_norm": 5,
"initial_scale": 1,
},
{
"tp_size": 1,
"pp_size": 2,
"num_microbatches": 4,
"zero_stage": 1,
"enable_all_optimization": True,
"use_lazy_init": True,
"precision": "bf16",
"max_norm": 5,
},
{
"tp_size": 2,
"pp_size": 1,
"zero_stage": 1,
"enable_all_optimization": True,
"use_lazy_init": False,
"precision": "bf16",
"max_norm": 5,
},
{
"tp_size": 2,
"pp_size": 1,
"zero_stage": 2,
"enable_all_optimization": True,
"enable_all_optimization": False,
"use_lazy_init": False,
"precision": "bf16",
"max_norm": 5,

View File

@ -39,7 +39,7 @@ def exam_chunk_basic(init_device, keep_gathered, pin_memory):
pg = _get_default_group()
my_chunk = Chunk(
chunk_size=1024,
process_group=pg,
zero_group=pg,
dtype=torch.float32,
init_device=init_device,
cpu_shard_init=True,