[gemini] gemini support extra-dp (#5043)

* support ddp

* fix

* fix

* fix

fix

* support ddp

* fix

* fix

* fix

fix

* simplify tests

* fix

* fix

* fix

fix

fix

* fix
Branch: pull/5060/head
Author: flybird11111 (committed by GitHub), 2023-11-16 21:03:04 +08:00
Parent: b2ad0d9e8f
Commit: 3e02154710
10 changed files with 96 additions and 137 deletions

View File

@@ -10,6 +10,7 @@ import torch.nn as nn
 from torch.optim import Optimizer
 from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
 from torch.utils.data import DataLoader
+from torch.distributed.distributed_c10d import _get_default_group

 from colossalai.checkpoint_io import CheckpointIndexFile, CheckpointIO, GeneralCheckpointIO
 from colossalai.checkpoint_io.utils import (
@@ -34,8 +35,7 @@ __all__ = ["GeminiPlugin"]
 SUPPORTED_PRECISION = ["fp16", "bf16"]
 PRECISION_STR_TO_DTYPE = {"fp16": torch.half, "bf16": torch.bfloat16}
-DP_AXIS = 0
-TP_AXIS = 1
+ZERO_AXIS, DP_AXIS, TP_AXIS = 0, 1, 2


 def get_param_info(optim: Optimizer):
     # Get a backup of necessary information of parameters for future use, which includes:
@@ -304,8 +304,8 @@ class GeminiPlugin(DPPluginBase):
         max_norm (float, optional): max_norm used for `clip_grad_norm`. You should notice that you shall not do
             clip_grad_norm by yourself when using ZeRO DDP. The ZeRO optimizer will take care of clip_grad_norm.
         norm_type (float, optional): norm_type used for `clip_grad_norm`.
-        enable_tensor_parallelism (bool, optional): Whether to use tensor parallelism strategy, which is implemented in Shardformer. Default to False.
-        tp_size (int, optional): If 'enable_tensor_parallelism' is set to true, please configure 'tp_size' which determines the size of the tensor parallel process group. Default to 1.
+        tp_size (int, optional): If 'tp_size' is set to be greater than 1, it means using tensor parallelism strategy, which is implemented in Shardformer, 'tp_size' determines the size of the tensor parallel process group. Default to 1.
+        extra_dp_size (int, optional): If 'extra_dp_size' is set to be greater than 1, it means creating another group to run with a ddp-like strategy. Default to 1.
         enable_all_optimization (bool, optional): Whether to switch on all the optimizations supported by Shardformer.
                                                   Currently all the optimization methods include fused normalization, flash attention and JIT.
                                                   Defaults to False.
@@ -347,8 +347,8 @@ class GeminiPlugin(DPPluginBase):
         max_scale: float = 2**32,
         max_norm: float = 0.0,
         norm_type: float = 2.0,
-        enable_tensor_parallelism: bool = False,
         tp_size: int = 1,
+        extra_dp_size: int = 1,
         enable_all_optimization: bool = False,
         enable_fused_normalization: bool = False,
         enable_flash_attention: bool = False,
@@ -393,7 +393,7 @@ class GeminiPlugin(DPPluginBase):
             max_norm=max_norm,
             norm_type=norm_type,
         )
-        self.enable_tensor_parallelism = enable_tensor_parallelism
+        self.enable_tensor_parallelism = tp_size > 1
        self.enable_all_optimization = enable_all_optimization
         self.enable_fused_normalization = enable_fused_normalization
         self.enable_flash_attention = enable_flash_attention
@@ -402,12 +402,17 @@ class GeminiPlugin(DPPluginBase):
         self.enable_sequence_overlap = enable_sequence_overlap
         self.verbose = verbose

-        self.tp_size = tp_size if self.enable_tensor_parallelism else 1
-        self.dp_size = dist.get_world_size() // self.tp_size
-        assert self.dp_size > 1, f"The size of the DP group should be greater than 1. Please reduce the TP group size."
-        self.pg_mesh = ProcessGroupMesh(self.dp_size, self.tp_size)
-        self.dp_group = self.pg_mesh.get_group_along_axis(DP_AXIS)
-        self.tp_group = self.pg_mesh.get_group_along_axis(TP_AXIS)
+        self.tp_size = tp_size
+        self.extra_dp_size = extra_dp_size
+        world_size = dist.get_world_size()
+        self.zero_size = world_size // (self.tp_size * self.extra_dp_size)
+        assert (
+            world_size == (self.tp_size * self.extra_dp_size) * self.zero_size
+        ), f"The global group size can't be evenly divided by the subgroup size."
+        self.pg_mesh = ProcessGroupMesh(self.zero_size, self.extra_dp_size, self.tp_size)
+        self.zero_group = self.pg_mesh.get_group_along_axis(ZERO_AXIS) if self.zero_size < world_size else _get_default_group()
+        self.extra_dp_group = self.pg_mesh.get_group_along_axis(DP_AXIS) if self.extra_dp_size > 1 else None
+        self.tp_group = self.pg_mesh.get_group_along_axis(TP_AXIS) if self.tp_size > 1 else None
         self.shard_config = ShardConfig(
             tensor_parallel_process_group=self.tp_group,
             enable_tensor_parallelism=self.enable_tensor_parallelism,
@@ -458,7 +463,7 @@ class GeminiPlugin(DPPluginBase):
             shardformer = ShardFormer(self.shard_config)
             model, _ = shardformer.optimize(model)

-            model = GeminiDDP(model, **self.gemini_config, process_group=self.dp_group, verbose=self.verbose)
+            model = GeminiDDP(model, **self.gemini_config, zero_group=self.zero_group, extra_dp_group=self.extra_dp_group, verbose=self.verbose)

         if optimizer is not None and not isinstance(optimizer, OptimizerWrapper):
             optimizer = GeminiOptimizer(
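
For orientation, a minimal usage sketch of the reworked plugin arguments on 8 ranks (2-way TP x 2-way extra DP x 2-way ZeRO). Only GeminiPlugin's tp_size/extra_dp_size and the Booster flow come from this diff; the launch call, toy model, and optimizer choice are not part of the change and are shown only as assumptions for illustration.

import torch
import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import GeminiPlugin
from colossalai.nn.optimizer import HybridAdam

# assumed to run under torchrun with 8 processes; world_size must equal
# zero_size * extra_dp_size * tp_size, here 2 * 2 * 2
colossalai.launch_from_torch(config={})

model = torch.nn.Linear(1024, 1024)            # toy model, illustration only
optimizer = HybridAdam(model.parameters(), lr=1e-3)

plugin = GeminiPlugin(precision="fp16", initial_scale=2**5, tp_size=2, extra_dp_size=2)
booster = Booster(plugin=plugin)
# boost returns (model, optimizer, criterion, dataloader, lr_scheduler)
model, optimizer, _, _, _ = booster.boost(model, optimizer)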

View File

@@ -61,12 +61,13 @@ class Chunk:
     def __init__(
         self,
         chunk_size: int,
-        process_group: ProcessGroup,
+        zero_group: ProcessGroup,
         dtype: torch.dtype,
         init_device: Optional[torch.device] = None,
         cpu_shard_init: bool = False,
         keep_gathered: bool = False,
         pin_memory: bool = False,
+        extra_dp_group: ProcessGroup = None,
     ) -> None:
         """
         Chunk: A container owning a piece of contiguous memory space for tensors
@@ -76,7 +77,7 @@ class Chunk:
         Args:
             chunk_size (int): the number of elements in the chunk
-            process_group (ProcessGroup): the process group of this chunk
+            zero_group (ProcessGroup): the process group of this chunk
             dtype (torch.dtype): the data type of the chunk
             init_device (torch.device): optional, During the chunk construction process, where the tensor is stored.
                 The default value is None, which is the current GPU
@@ -90,9 +91,11 @@ class Chunk:
         self.chunk_size = chunk_size
         self.utilized_size = 0

-        self.torch_pg = process_group
+        self.torch_pg = zero_group
         self.pg_size = dist.get_world_size(self.torch_pg)
         self.pg_rank = dist.get_rank(self.torch_pg)
+        self.extra_dp_group = extra_dp_group
+        self.extra_dp_size = dist.get_world_size(self.extra_dp_group) if self.extra_dp_group is not None else 1

         # the chunk size should be divisible by the dp degree
         if not keep_gathered:
@@ -384,14 +387,20 @@ class Chunk:
             # just move cuda_global_chunk to cuda_shard
             # the communication is not necessary
             self.__scatter()
+            if self.extra_dp_group is not None:
+                dist.all_reduce(self.cuda_shard, group=self.extra_dp_group)
         elif self.keep_gathered:
             # we use all-reduce here
             dist.all_reduce(self.cuda_global_chunk, group=self.torch_pg)
+            if self.extra_dp_group is not None:
+                dist.all_reduce(self.cuda_global_chunk, group=self.extra_dp_group)
         else:
             self.cuda_shard = torch.empty(self.shard_size, dtype=self.dtype, device=get_current_device())
             input_list = list(torch.chunk(self.cuda_global_chunk, chunks=self.pg_size, dim=0))
             dist.reduce_scatter(self.cuda_shard, input_list, group=self.torch_pg)
+            if self.extra_dp_group is not None:
+                dist.all_reduce(self.cuda_shard, group=self.extra_dp_group)
             free_storage(self.cuda_global_chunk)
         self.is_gathered = False
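
The reduction is now two-level: first a reduce-scatter (or all-reduce) inside the ZeRO group, then an all-reduce of the resulting shard across the ddp-like extra group. Below is a standalone sketch of that pattern using plain torch.distributed; the function name and the assumption that both groups already exist are illustrative, not part of the diff.

import torch
import torch.distributed as dist

def two_level_reduce(gathered_chunk: torch.Tensor, zero_group, extra_dp_group=None) -> torch.Tensor:
    # level 1: every rank keeps only its shard, summed over the ZeRO group
    zero_size = dist.get_world_size(zero_group)
    shard = torch.empty(gathered_chunk.numel() // zero_size,
                        dtype=gathered_chunk.dtype, device=gathered_chunk.device)
    input_list = list(torch.chunk(gathered_chunk, chunks=zero_size, dim=0))
    dist.reduce_scatter(shard, input_list, group=zero_group)

    # level 2: sum the same shard across the replicas in the extra DP group
    if extra_dp_group is not None:
        dist.all_reduce(shard, group=extra_dp_group)
    return shard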
@@ -608,10 +617,11 @@ class Chunk:
             # grad chunk is not initialized
             grad_chunk = Chunk(
                 chunk_size=self.chunk_size,
-                process_group=self.torch_pg,
+                zero_group=self.torch_pg,
                 dtype=self.dtype,
                 keep_gathered=self.keep_gathered,
                 pin_memory=self.pin_memory,
+                extra_dp_group=self.extra_dp_group,
             )
             grad_chunk.num_tensors = self.num_tensors
             grad_chunk.utilized_size = self.utilized_size

View File

@@ -38,7 +38,8 @@ class ChunkManager:
         tensor: torch.Tensor,
         group_type: str,
         config_key: int,
-        process_group: ProcessGroup,
+        zero_group: ProcessGroup,
+        extra_dp_group: ProcessGroup = None,
         cpu_offload: bool = False,
         pin_memory: bool = False,
     ) -> None:
@@ -76,15 +77,16 @@ class ChunkManager:
             if tensor.numel() > chunk_size:
                 chunk_size = tensor.numel()

-            dp_size = dist.get_world_size(process_group)
+            dp_size = dist.get_world_size(zero_group)
             chunk_size = chunk_size + (-chunk_size % dp_size)

             chunk = Chunk(
                 chunk_size=chunk_size,
-                process_group=process_group,
+                zero_group=zero_group,
                 dtype=tensor.dtype,
                 cpu_shard_init=cpu_offload,
                 pin_memory=pin_memory,
+                extra_dp_group=extra_dp_group,
                 **chunk_kwargs,
             )

View File

@@ -86,9 +86,10 @@ class GeminiDDP(ModelWrapper):
         strict_ddp_mode: bool = False,
         scatter_after_inference: bool = True,
         mixed_precision: torch.dtype = torch.float16,
-        process_group: Optional[ProcessGroup] = None,
+        zero_group: Optional[ProcessGroup] = None,
         memstats: Optional[MemStats] = None,  # genimi memory stats
         master_weights: bool = True,
+        extra_dp_group: Optional[ProcessGroup] = None,
         verbose: bool = False,
     ) -> None:
         assert mixed_precision in (torch.float16, torch.bfloat16)
@@ -105,7 +106,7 @@ class GeminiDDP(ModelWrapper):
             search_range_m=search_range_m,
             min_chunk_size_m=min_chunk_size_m,
             strict_ddp_flag=strict_ddp_mode,
-            process_group=process_group,
+            process_group=zero_group,
             verbose=verbose,
         )
         self.gemini_manager = GeminiManager(
@@ -128,7 +129,8 @@ class GeminiDDP(ModelWrapper):
         self.name2param: Dict[str, nn.Parameter] = dict()
         self.scatter_after_inference = scatter_after_inference
         self.mixed_precision = mixed_precision
-        self.dp_process_group = process_group or _get_default_group()
+        self.zero_group = zero_group or _get_default_group()
+        self.extra_dp_group = extra_dp_group

         self.reuse_fp16_chunk = master_weights
         self.master_weights = master_weights
@@ -377,8 +379,12 @@ class GeminiDDP(ModelWrapper):
             self.chunk_manager.release_chunk(chunk)
         if grad_chunk.is_gathered:
             grad_chunk.cuda_global_chunk.div_(chunk.pg_size)
+            if self.extra_dp_group is not None:
+                grad_chunk.cuda_global_chunk.div_(chunk.extra_dp_size)
         else:
             grad_chunk.cuda_shard.div_(chunk.pg_size)
+            if self.extra_dp_group is not None:
+                grad_chunk.cuda_shard.div_(chunk.extra_dp_size)
         # check overflow elements
         self.overflow_counter += grad_chunk.has_inf_or_nan
         # record l2 norm for gradient clipping. flag is bound to fp16 chunk
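
Dividing by both chunk.pg_size and chunk.extra_dp_size turns the summed reductions above into an average over every gradient replica, since the effective data-parallel degree is zero_size * extra_dp_size. A tiny arithmetic sketch with made-up group sizes:

zero_size = 4      # plays the role of chunk.pg_size
extra_dp_size = 2  # plays the role of chunk.extra_dp_size

# after the ZeRO reduce-scatter and the extra-DP all-reduce, a gradient value
# is the sum over all 4 * 2 = 8 replicas; assume each contributed 1.0
summed_grad = 8.0
mean_grad = summed_grad / zero_size / extra_dp_size
assert mean_grad == 1.0  # same as dividing once by zero_size * extra_dp_size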
@@ -733,7 +739,7 @@ class GeminiDDP(ModelWrapper):
                 unexpected_keys.append(key)

     def _init_chunks(self, param_order, strict_ddp_mode: bool, cpu_offload: bool, pin_memory: bool):
-        dp_world_size = dist.get_world_size(self.dp_process_group)
+        zero_world_size = dist.get_world_size(self.zero_group)
         for p in param_order.generate():
             self._preprocess_param(p)
             assert type(p) is ColoParameter
@@ -753,8 +759,9 @@ class GeminiDDP(ModelWrapper):
                 self.chunk_manager.register_tensor(
                     tensor=p,
                     group_type="fp16_param",
-                    config_key=dp_world_size,
-                    process_group=self.dp_process_group,
+                    config_key=zero_world_size,
+                    zero_group=self.zero_group,
+                    extra_dp_group=self.extra_dp_group,
                     cpu_offload=cpu_offload,
                     pin_memory=pin_memory,
                 )
@@ -767,8 +774,9 @@ class GeminiDDP(ModelWrapper):
             self.chunk_manager.register_tensor(
                 tensor=fp32_p,
                 group_type="fp32_param",
-                config_key=dp_world_size,
-                process_group=self.dp_process_group,
+                config_key=zero_world_size,
+                zero_group=self.zero_group,
+                extra_dp_group=self.extra_dp_group,
                 cpu_offload=cpu_offload,
                 pin_memory=pin_memory,
             )

View File

@@ -1,5 +1,6 @@
 from contextlib import nullcontext
 from typing import Optional

+import pytest
 import torch
 import torch.distributed as dist
@@ -17,14 +18,15 @@ from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
 from tests.kit.model_zoo import model_zoo


-def run_fn(init_method, model_fn, data_gen_fn, output_transform_fn, enable_tensor_parallelism) -> Optional[str]:
+def run_fn(init_method, model_fn, data_gen_fn, output_transform_fn, zero_size, tp_size) -> Optional[str]:
     try:
         if init_method == "lazy":
             ctx = LazyInitContext()
         else:
             ctx = nullcontext()
-        enable_all_optimization = True if enable_tensor_parallelism else False
-        plugin = GeminiPlugin(max_norm=1.0, initial_scale=2**5, enable_tensor_parallelism=enable_tensor_parallelism, enable_all_optimization=enable_all_optimization)
+        extra_dp_size = dist.get_world_size() // (zero_size * tp_size)
+        enable_all_optimization = True if tp_size > 1 else False
+        plugin = GeminiPlugin(max_norm=1.0, initial_scale=2**5, tp_size=tp_size, extra_dp_size=extra_dp_size, enable_all_optimization=enable_all_optimization)
         booster = Booster(plugin=plugin)
         with ctx:
             model = model_fn()
@@ -62,8 +64,9 @@ def run_fn(init_method, model_fn, data_gen_fn, output_transform_fn, enable_tenso
 @parameterize("subset", ["torchvision", "transformers", "diffusers"])
 @parameterize("init_method", ["none"])
-@parameterize("enable_tensor_parallelism", [True, False])
-def check_gemini_plugin(subset: str, init_method: str = "none", enable_tensor_parallelism: bool = True, early_stop: bool = True):
+@parameterize("zero_size", [2])
+@parameterize("tp_size", [2])
+def check_gemini_plugin(subset: str, init_method: str = "none", early_stop: bool = True, zero_size: int = 1, tp_size: int = 1):
     """check gemini plugin over model zoo

     Args:
@@ -125,9 +128,9 @@ def check_gemini_plugin(subset: str, init_method: str = "none", enable_tensor_pa
         # TODO debug blip2 when using tp, something wrong with shift_logits's shape
         if "transformers_blip2" in name:
-            enable_tensor_parallelism = False
+            tp_size = 1

-        err = run_fn(init_method, model_fn, data_gen_fn, output_transform_fn, enable_tensor_parallelism)
+        err = run_fn(init_method, model_fn, data_gen_fn, output_transform_fn, zero_size, tp_size)
         torch.cuda.empty_cache()
         if err is None:
             passed_models.append(name)
@@ -153,6 +156,11 @@ def run_dist(rank, world_size, port, early_stop: bool = True):
 def test_gemini_plugin(early_stop: bool = True):
     spawn(run_dist, 4, early_stop=early_stop)

+@pytest.mark.largedist
+@rerun_if_address_is_in_use()
+def test_gemini_plugin_3d(early_stop: bool = True):
+    spawn(run_dist, 8, early_stop=early_stop)
+

 if __name__ == "__main__":
     test_gemini_plugin(early_stop=False)
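
The new 3D test derives extra_dp_size from the spawned world size, so the three degrees always multiply back to the number of ranks. A small sketch of that bookkeeping; the helper name is made up, only the formula mirrors run_fn above.

def derive_extra_dp_size(world_size: int, zero_size: int, tp_size: int) -> int:
    assert world_size % (zero_size * tp_size) == 0, "world size must be divisible by zero_size * tp_size"
    return world_size // (zero_size * tp_size)

# spawn(run_dist, 8) with zero_size=2, tp_size=2 gives a 2-way ddp-like group
assert derive_extra_dp_size(8, 2, 2) == 2
# spawn(run_dist, 4) with the same parameterization degenerates to extra_dp_size=1
assert derive_extra_dp_size(4, 2, 2) == 1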

View File

@@ -37,20 +37,21 @@ OPTIM_PLACEMENT_CONFIGS = [
 @parameterize("placement_config", MODEL_PLACEMENT_CONFIGS)
 @parameterize("model_name", ["transformers_bert_for_sequence_classification"])
 @parameterize("use_safetensors", [False, True])
-@parameterize("enable_tensor_parallelism", [True, False])
-@parameterize("tp_size", [2])
-def exam_state_dict_with_origin(placement_config, model_name, use_safetensors: bool, enable_tensor_parallelism: bool, tp_size: int):
+@parameterize("tp_size", [1, 2])
+@parameterize("zero_size", [2])
+def exam_state_dict_with_origin(placement_config, model_name, use_safetensors: bool, tp_size: int, zero_size: int):
     from transformers import BertForSequenceClassification

     (model_fn, data_gen_fn, output_transform_fn, _, _) = next(iter(model_zoo.get_sub_registry(model_name).values()))
     bert_model = model_fn()
-    enable_all_optimization = True if enable_tensor_parallelism else False
+    enable_all_optimization = True if tp_size > 1 else False

     with shared_tempdir() as tempdir:
         pretrained_path = os.path.join(tempdir, "pretrained")
         bert_model.config.save_pretrained(save_directory=pretrained_path)

-        plugin = GeminiPlugin(**placement_config, enable_tensor_parallelism=enable_tensor_parallelism, tp_size=tp_size, enable_all_optimization=enable_all_optimization)
+        extra_dp_size = dist.get_world_size() // (zero_size * tp_size)
+        plugin = GeminiPlugin(**placement_config, tp_size=tp_size, enable_all_optimization=enable_all_optimization, extra_dp_size=extra_dp_size)
         booster = Booster(plugin=plugin)
         bert_model, _, _, _, _ = booster.boost(bert_model)
         model_size = sum(p.numel() * p.element_size() for p in bert_model.parameters()) / 1024**2
@@ -69,13 +70,14 @@ def exam_state_dict_with_origin(placement_config, model_name, use_safetensors: b
 @parameterize("shard", [True, False])
 @parameterize("model_name", ["transformers_gpt"])
 @parameterize("size_per_shard", [32])
-@parameterize("enable_tensor_parallelism", [True, False])
-@parameterize("tp_size", [2])
-def exam_state_dict(placement_config, shard: bool, model_name: str, size_per_shard: int, enable_tensor_parallelism: bool, tp_size: int):
+@parameterize("tp_size", [1, 2])
+@parameterize("zero_size", [2])
+def exam_state_dict(placement_config, shard: bool, model_name: str, size_per_shard: int, tp_size: int, zero_size: int):
     (model_fn, data_gen_fn, output_transform_fn, _, _) = next(iter(model_zoo.get_sub_registry(model_name).values()))
     criterion = lambda x: x.mean()
-    enable_all_optimization = True if enable_tensor_parallelism else False
-    plugin = GeminiPlugin(**placement_config, precision="fp16", initial_scale=(2**14), enable_tensor_parallelism=enable_tensor_parallelism, tp_size=tp_size, enable_all_optimization=enable_all_optimization)
+    enable_all_optimization = True if tp_size > 1 else False
+    extra_dp_size = dist.get_world_size() // (zero_size * tp_size)
+    plugin = GeminiPlugin(**placement_config, precision="fp16", initial_scale=(2**14), tp_size=tp_size, extra_dp_size=extra_dp_size, enable_all_optimization=enable_all_optimization)
     booster = Booster(plugin=plugin)

     model = model_fn()
@@ -158,3 +160,9 @@ def run_dist(rank, world_size, port):
 @rerun_if_address_is_in_use()
 def test_gemini_ckpIO(world_size):
     spawn(run_dist, world_size)
+
+@pytest.mark.largedist
+@pytest.mark.parametrize("world_size", [8])
+@rerun_if_address_is_in_use()
+def test_gemini_ckpIO_3d(world_size):
+    spawn(run_dist, world_size)

View File

@@ -124,25 +124,6 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
 @parameterize(
     "test_config",
     [
-        {
-            "tp_size": 1,
-            "pp_size": 2,
-            "num_microbatches": 4,
-            "enable_all_optimization": True,
-            "use_lazy_init": True,
-            "precision": "fp16",
-            "max_norm": 5,
-            "initial_scale": 1,
-        },
-        {
-            "tp_size": 2,
-            "pp_size": 1,
-            "enable_all_optimization": True,
-            "use_lazy_init": False,
-            "precision": "fp16",
-            "max_norm": 5,
-            "initial_scale": 1,
-        },
         {
             "tp_size": 2,
             "pp_size": 2,
@@ -153,23 +134,6 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
             "max_norm": 5,
             "initial_scale": 1,
         },
-        {
-            "tp_size": 1,
-            "pp_size": 2,
-            "num_microbatches": 4,
-            "enable_all_optimization": True,
-            "use_lazy_init": True,
-            "precision": "bf16",
-            "max_norm": 5,
-        },
-        {
-            "tp_size": 2,
-            "pp_size": 1,
-            "enable_all_optimization": True,
-            "use_lazy_init": False,
-            "precision": "bf16",
-            "max_norm": 5,
-        },
         {
             "tp_size": 2,
             "pp_size": 2,

View File

@@ -102,28 +102,11 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
 @parameterize(
     "test_config",
     [
-        {
-            "tp_size": 1,
-            "pp_size": 2,
-            "num_microbatches": 4,
-            "enable_all_optimization": True,
-            "use_lazy_init": True,
-            "precision": "fp32",
-            "max_norm": 5,
-        },
-        {
-            "tp_size": 2,
-            "pp_size": 1,
-            "enable_all_optimization": True,
-            "use_lazy_init": False,
-            "precision": "fp32",
-            "max_norm": 5,
-        },
         {
             "tp_size": 2,
             "pp_size": 2,
             "num_microbatches": 4,
-            "enable_all_optimization": True,
+            "enable_all_optimization": False,
             "use_lazy_init": False,
             "precision": "fp32",
             "max_norm": 5,
@@ -148,7 +131,7 @@ def run_test(test_config):
             "tp_size": 2,
             "pp_size": 2,
             "num_microbatches": 4,
-            "enable_all_optimization": True,
+            "enable_all_optimization": False,
             "use_lazy_init": False,
             "precision": "fp32",
             "max_norm": 5,

View File

@@ -106,17 +106,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
             "pp_size": 2,
             "num_microbatches": 4,
             "zero_stage": 1,
-            "enable_all_optimization": True,
-            "use_lazy_init": True,
-            "precision": "fp16",
-            "max_norm": 5,
-            "initial_scale": 1,
-        },
-        {
-            "tp_size": 2,
-            "pp_size": 1,
-            "zero_stage": 1,
-            "enable_all_optimization": True,
+            "enable_all_optimization": False,
             "use_lazy_init": False,
             "precision": "fp16",
             "max_norm": 5,
@@ -126,36 +116,17 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
             "tp_size": 2,
             "pp_size": 1,
             "zero_stage": 2,
-            "enable_all_optimization": True,
+            "enable_all_optimization": False,
             "use_lazy_init": False,
             "precision": "fp16",
             "max_norm": 5,
             "initial_scale": 1,
         },
-        {
-            "tp_size": 1,
-            "pp_size": 2,
-            "num_microbatches": 4,
-            "zero_stage": 1,
-            "enable_all_optimization": True,
-            "use_lazy_init": True,
-            "precision": "bf16",
-            "max_norm": 5,
-        },
         {
             "tp_size": 2,
             "pp_size": 1,
             "zero_stage": 1,
-            "enable_all_optimization": True,
-            "use_lazy_init": False,
-            "precision": "bf16",
-            "max_norm": 5,
-        },
-        {
-            "tp_size": 2,
-            "pp_size": 1,
-            "zero_stage": 2,
-            "enable_all_optimization": True,
+            "enable_all_optimization": False,
             "use_lazy_init": False,
             "precision": "bf16",
             "max_norm": 5,

View File

@@ -39,7 +39,7 @@ def exam_chunk_basic(init_device, keep_gathered, pin_memory):
     pg = _get_default_group()
     my_chunk = Chunk(
         chunk_size=1024,
-        process_group=pg,
+        zero_group=pg,
         dtype=torch.float32,
         init_device=init_device,
         cpu_shard_init=True,