mirror of https://github.com/hpcaitech/ColossalAI
[gemini] gemini support extra-dp (#5043)
* support ddp
* fix
* fix
* fix fix
* support ddp
* fix
* fix
* fix fix
* simplify tests
* fix
* fix
* fix fix fix
* fix

pull/5060/head
parent b2ad0d9e8f
commit 3e02154710
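For context, a minimal launch sketch using the new plugin arguments; the 8-rank layout, precision, and scale values below are illustrative assumptions, not part of this commit:

    # Hypothetical 8-rank job: zero_size (2) x extra_dp_size (2) x tp_size (2) = 8
    import colossalai
    from colossalai.booster import Booster
    from colossalai.booster.plugin import GeminiPlugin

    colossalai.launch_from_torch(config={})
    plugin = GeminiPlugin(precision="fp16", initial_scale=2**14, tp_size=2, extra_dp_size=2)
    booster = Booster(plugin=plugin)
    # model, optimizer, criterion, dataloader, lr_scheduler = booster.boost(...)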
@@ -10,6 +10,7 @@ import torch.nn as nn
 from torch.optim import Optimizer
 from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
 from torch.utils.data import DataLoader
+from torch.distributed.distributed_c10d import _get_default_group

 from colossalai.checkpoint_io import CheckpointIndexFile, CheckpointIO, GeneralCheckpointIO
 from colossalai.checkpoint_io.utils import (

@@ -34,8 +35,7 @@ __all__ = ["GeminiPlugin"]
 SUPPORTED_PRECISION = ["fp16", "bf16"]
 PRECISION_STR_TO_DTYPE = {"fp16": torch.half, "bf16": torch.bfloat16}

-DP_AXIS = 0
-TP_AXIS = 1
+ZERO_AXIS, DP_AXIS, TP_AXIS = 0, 1, 2

 def get_param_info(optim: Optimizer):
     # Get a backup of necessary information of parameters for future use, which includes:

@@ -304,8 +304,8 @@ class GeminiPlugin(DPPluginBase):
         max_norm (float, optional): max_norm used for `clip_grad_norm`. You should notice that you shall not do
             clip_grad_norm by yourself when using ZeRO DDP. The ZeRO optimizer will take care of clip_grad_norm.
         norm_type (float, optional): norm_type used for `clip_grad_norm`.
-        enable_tensor_parallelism (bool, optional): Whether to use tensor parallelism strategy, which is implemented in Shardformer. Default to False.
-        tp_size (int, optional): If 'enable_tensor_parallelism' is set to true, please configure 'tp_size' which determines the size of the tensor parallel process group. Default to 1.
+        tp_size (int, optional): If 'tp_size' is set to be greater than 1, it means using tensor parallelism strategy, which is implemented in Shardformer, 'tp_size' determines the size of the tensor parallel process group. Default to 1.
+        extra_dp_size (int, optional): If 'extra_dp_size' is set to be greater than 1, it means creating another group to run with a ddp-like strategy. Default to 1.
         enable_all_optimization (bool, optional): Whether to switch on all the optimizations supported by Shardformer.
             Currently all the optimization methods include fused normalization, flash attention and JIT.
             Defaults to False.

@@ -347,8 +347,8 @@ class GeminiPlugin(DPPluginBase):
         max_scale: float = 2**32,
         max_norm: float = 0.0,
         norm_type: float = 2.0,
-        enable_tensor_parallelism: bool = False,
         tp_size: int = 1,
+        extra_dp_size: int = 1,
         enable_all_optimization: bool = False,
         enable_fused_normalization: bool = False,
         enable_flash_attention: bool = False,

@@ -393,7 +393,7 @@ class GeminiPlugin(DPPluginBase):
             max_norm=max_norm,
             norm_type=norm_type,
         )
-        self.enable_tensor_parallelism = enable_tensor_parallelism
+        self.enable_tensor_parallelism = tp_size > 1
         self.enable_all_optimization = enable_all_optimization
         self.enable_fused_normalization = enable_fused_normalization
         self.enable_flash_attention = enable_flash_attention

@@ -402,12 +402,17 @@ class GeminiPlugin(DPPluginBase):
         self.enable_sequence_overlap = enable_sequence_overlap
         self.verbose = verbose

-        self.tp_size = tp_size if self.enable_tensor_parallelism else 1
-        self.dp_size = dist.get_world_size() // self.tp_size
-        assert self.dp_size > 1, f"The size of the DP group should be greater than 1. Please reduce the TP group size."
-        self.pg_mesh = ProcessGroupMesh(self.dp_size, self.tp_size)
-        self.dp_group = self.pg_mesh.get_group_along_axis(DP_AXIS)
-        self.tp_group = self.pg_mesh.get_group_along_axis(TP_AXIS)
+        self.tp_size = tp_size
+        self.extra_dp_size = extra_dp_size
+        world_size = dist.get_world_size()
+        self.zero_size = world_size // (self.tp_size * self.extra_dp_size)
+        assert world_size == (self.tp_size * self.extra_dp_size) * self.zero_size, f"The global group size can't be evenly divided by the subgroup size."
+
+        self.pg_mesh = ProcessGroupMesh(self.zero_size, self.extra_dp_size, self.tp_size)
+        self.zero_group = self.pg_mesh.get_group_along_axis(ZERO_AXIS) if self.zero_size < world_size else _get_default_group()
+        self.extra_dp_group = self.pg_mesh.get_group_along_axis(DP_AXIS) if self.extra_dp_size > 1 else None
+        self.tp_group = self.pg_mesh.get_group_along_axis(TP_AXIS) if self.tp_size > 1 else None
+
         self.shard_config = ShardConfig(
             tensor_parallel_process_group=self.tp_group,
             enable_tensor_parallelism=self.enable_tensor_parallelism,

@@ -458,7 +463,7 @@ class GeminiPlugin(DPPluginBase):
                 shardformer = ShardFormer(self.shard_config)
                 model, _ = shardformer.optimize(model)

-            model = GeminiDDP(model, **self.gemini_config, process_group=self.dp_group, verbose=self.verbose)
+            model = GeminiDDP(model, **self.gemini_config, zero_group=self.zero_group, extra_dp_group=self.extra_dp_group, verbose=self.verbose)

         if optimizer is not None and not isinstance(optimizer, OptimizerWrapper):
             optimizer = GeminiOptimizer(
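The group arithmetic introduced above works out as follows; a sketch with the variable names from the diff and hypothetical sizes:

    world_size = 8                                             # e.g. dist.get_world_size()
    tp_size, extra_dp_size = 2, 2
    zero_size = world_size // (tp_size * extra_dp_size)        # -> 2
    assert world_size == tp_size * extra_dp_size * zero_size   # must divide evenly
    # the 3D ProcessGroupMesh(zero_size, extra_dp_size, tp_size) is then indexed by
    ZERO_AXIS, DP_AXIS, TP_AXIS = 0, 1, 2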
@@ -61,12 +61,13 @@ class Chunk:
     def __init__(
         self,
         chunk_size: int,
-        process_group: ProcessGroup,
+        zero_group: ProcessGroup,
         dtype: torch.dtype,
         init_device: Optional[torch.device] = None,
         cpu_shard_init: bool = False,
         keep_gathered: bool = False,
         pin_memory: bool = False,
+        extra_dp_group: ProcessGroup = None,
     ) -> None:
         """
         Chunk: A container owning a piece of contiguous memory space for tensors

@@ -76,7 +77,7 @@ class Chunk:

         Args:
             chunk_size (int): the number of elements in the chunk
-            process_group (ProcessGroup): the process group of this chunk
+            zero_group (ProcessGroup): the process group of this chunk
             dtype (torch.dtype): the data type of the chunk
             init_device (torch.device): optional, During the chunk construction process, where the tensor is stored.
                 The default value is None, which is the current GPU

@@ -90,9 +91,11 @@ class Chunk:
         self.chunk_size = chunk_size
         self.utilized_size = 0

-        self.torch_pg = process_group
+        self.torch_pg = zero_group
         self.pg_size = dist.get_world_size(self.torch_pg)
         self.pg_rank = dist.get_rank(self.torch_pg)
+        self.extra_dp_group = extra_dp_group
+        self.extra_dp_size = dist.get_world_size(self.extra_dp_group) if self.extra_dp_group is not None else 1

         # the chunk size should be divisible by the dp degree
         if not keep_gathered:

@@ -384,14 +387,20 @@ class Chunk:
             # just move cuda_global_chunk to cuda_shard
             # the communication is not necessary
             self.__scatter()
+            if self.extra_dp_group is not None:
+                dist.all_reduce(self.cuda_shard, group=self.extra_dp_group)
         elif self.keep_gathered:
             # we use all-reduce here
             dist.all_reduce(self.cuda_global_chunk, group=self.torch_pg)
+            if self.extra_dp_group is not None:
+                dist.all_reduce(self.cuda_global_chunk, group=self.extra_dp_group)
         else:
             self.cuda_shard = torch.empty(self.shard_size, dtype=self.dtype, device=get_current_device())

             input_list = list(torch.chunk(self.cuda_global_chunk, chunks=self.pg_size, dim=0))
             dist.reduce_scatter(self.cuda_shard, input_list, group=self.torch_pg)
+            if self.extra_dp_group is not None:
+                dist.all_reduce(self.cuda_shard, group=self.extra_dp_group)

             free_storage(self.cuda_global_chunk)
             self.is_gathered = False

@@ -608,10 +617,11 @@ class Chunk:
             # grad chunk is not initialized
             grad_chunk = Chunk(
                 chunk_size=self.chunk_size,
-                process_group=self.torch_pg,
+                zero_group=self.torch_pg,
                 dtype=self.dtype,
                 keep_gathered=self.keep_gathered,
                 pin_memory=self.pin_memory,
+                extra_dp_group=self.extra_dp_group,
             )
             grad_chunk.num_tensors = self.num_tensors
             grad_chunk.utilized_size = self.utilized_size
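The reduction added to Chunk boils down to a two-level collective: reduce the gradient chunk inside the ZeRO group (all-reduce or reduce-scatter), then all-reduce the result across the extra-dp replicas. A standalone sketch of the sharded branch, assuming the chunk length is divisible by the ZeRO group size; the function and variable names here are placeholders, not the Chunk API:

    import torch
    import torch.distributed as dist

    def reduce_chunk_with_extra_dp(cuda_global_chunk, zero_pg, extra_dp_pg):
        pg_size = dist.get_world_size(zero_pg)
        # each rank keeps one shard of the chunk summed over the ZeRO group
        cuda_shard = torch.empty(cuda_global_chunk.numel() // pg_size,
                                 dtype=cuda_global_chunk.dtype,
                                 device=cuda_global_chunk.device)
        input_list = list(torch.chunk(cuda_global_chunk, chunks=pg_size, dim=0))
        dist.reduce_scatter(cuda_shard, input_list, group=zero_pg)
        # second-level sum across the ddp-like replicas
        if extra_dp_pg is not None:
            dist.all_reduce(cuda_shard, group=extra_dp_pg)
        return cuda_shard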
@@ -38,7 +38,8 @@ class ChunkManager:
         tensor: torch.Tensor,
         group_type: str,
         config_key: int,
-        process_group: ProcessGroup,
+        zero_group: ProcessGroup,
+        extra_dp_group: ProcessGroup = None,
         cpu_offload: bool = False,
         pin_memory: bool = False,
     ) -> None:

@@ -76,15 +77,16 @@ class ChunkManager:

         if tensor.numel() > chunk_size:
             chunk_size = tensor.numel()
-            dp_size = dist.get_world_size(process_group)
+            dp_size = dist.get_world_size(zero_group)
             chunk_size = chunk_size + (-chunk_size % dp_size)

             chunk = Chunk(
                 chunk_size=chunk_size,
-                process_group=process_group,
+                zero_group=zero_group,
                 dtype=tensor.dtype,
                 cpu_shard_init=cpu_offload,
                 pin_memory=pin_memory,
+                extra_dp_group=extra_dp_group,
                 **chunk_kwargs,
             )
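The padding expression kept above, chunk_size + (-chunk_size % dp_size), rounds the chunk size up to the next multiple of the (now ZeRO) group size; a quick check with made-up numbers:

    dp_size = 4
    assert 1000 + (-1000 % dp_size) == 1000   # already divisible, unchanged
    assert 1001 + (-1001 % dp_size) == 1004   # rounded up to a multiple of 4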
@@ -86,9 +86,10 @@ class GeminiDDP(ModelWrapper):
         strict_ddp_mode: bool = False,
         scatter_after_inference: bool = True,
         mixed_precision: torch.dtype = torch.float16,
-        process_group: Optional[ProcessGroup] = None,
+        zero_group: Optional[ProcessGroup] = None,
         memstats: Optional[MemStats] = None,  # genimi memory stats
         master_weights: bool = True,
+        extra_dp_group: Optional[ProcessGroup] = None,
         verbose: bool = False,
     ) -> None:
         assert mixed_precision in (torch.float16, torch.bfloat16)

@@ -105,7 +106,7 @@ class GeminiDDP(ModelWrapper):
             search_range_m=search_range_m,
             min_chunk_size_m=min_chunk_size_m,
             strict_ddp_flag=strict_ddp_mode,
-            process_group=process_group,
+            process_group=zero_group,
             verbose=verbose,
         )
         self.gemini_manager = GeminiManager(

@@ -128,7 +129,8 @@ class GeminiDDP(ModelWrapper):
         self.name2param: Dict[str, nn.Parameter] = dict()
         self.scatter_after_inference = scatter_after_inference
         self.mixed_precision = mixed_precision
-        self.dp_process_group = process_group or _get_default_group()
+        self.zero_group = zero_group or _get_default_group()
+        self.extra_dp_group = extra_dp_group

         self.reuse_fp16_chunk = master_weights
         self.master_weights = master_weights

@@ -377,8 +379,12 @@ class GeminiDDP(ModelWrapper):
                 self.chunk_manager.release_chunk(chunk)
             if grad_chunk.is_gathered:
                 grad_chunk.cuda_global_chunk.div_(chunk.pg_size)
+                if self.extra_dp_group is not None:
+                    grad_chunk.cuda_global_chunk.div_(chunk.extra_dp_size)
             else:
                 grad_chunk.cuda_shard.div_(chunk.pg_size)
+                if self.extra_dp_group is not None:
+                    grad_chunk.cuda_shard.div_(chunk.extra_dp_size)
             # check overflow elements
             self.overflow_counter += grad_chunk.has_inf_or_nan
             # record l2 norm for gradient clipping. flag is bound to fp16 chunk

@@ -733,7 +739,7 @@ class GeminiDDP(ModelWrapper):
                 unexpected_keys.append(key)

     def _init_chunks(self, param_order, strict_ddp_mode: bool, cpu_offload: bool, pin_memory: bool):
-        dp_world_size = dist.get_world_size(self.dp_process_group)
+        zero_world_size = dist.get_world_size(self.zero_group)
         for p in param_order.generate():
             self._preprocess_param(p)
             assert type(p) is ColoParameter

@@ -753,8 +759,9 @@ class GeminiDDP(ModelWrapper):
             self.chunk_manager.register_tensor(
                 tensor=p,
                 group_type="fp16_param",
-                config_key=dp_world_size,
-                process_group=self.dp_process_group,
+                config_key=zero_world_size,
+                zero_group=self.zero_group,
+                extra_dp_group=self.extra_dp_group,
                 cpu_offload=cpu_offload,
                 pin_memory=pin_memory,
             )

@@ -767,8 +774,9 @@ class GeminiDDP(ModelWrapper):
             self.chunk_manager.register_tensor(
                 tensor=fp32_p,
                 group_type="fp32_param",
-                config_key=dp_world_size,
-                process_group=self.dp_process_group,
+                config_key=zero_world_size,
+                zero_group=self.zero_group,
+                extra_dp_group=self.extra_dp_group,
                 cpu_offload=cpu_offload,
                 pin_memory=pin_memory,
             )
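A note on the two div_ calls added above: the first reduction sums gradients over the ZeRO group and the second over the extra-dp group, so dividing by both sizes yields a mean over all replicas. With illustrative sizes:

    pg_size, extra_dp_size = 2, 2            # ZeRO group and extra-dp group sizes (hypothetical)
    replicas_summed = pg_size * extra_dp_size
    # grad.div_(pg_size); grad.div_(extra_dp_size)  is equivalent to  grad / replicas_summed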
@@ -1,5 +1,6 @@
 from contextlib import nullcontext
 from typing import Optional

 import pytest
 import torch
+import torch.distributed as dist

@@ -17,14 +18,15 @@ from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
 from tests.kit.model_zoo import model_zoo


-def run_fn(init_method, model_fn, data_gen_fn, output_transform_fn, enable_tensor_parallelism) -> Optional[str]:
+def run_fn(init_method, model_fn, data_gen_fn, output_transform_fn, zero_size, tp_size) -> Optional[str]:
     try:
         if init_method == "lazy":
             ctx = LazyInitContext()
         else:
             ctx = nullcontext()
-        enable_all_optimization = True if enable_tensor_parallelism else False
-        plugin = GeminiPlugin(max_norm=1.0, initial_scale=2**5, enable_tensor_parallelism=enable_tensor_parallelism, enable_all_optimization=enable_all_optimization)
+        extra_dp_size = dist.get_world_size() // (zero_size * tp_size)
+        enable_all_optimization = True if tp_size > 1 else False
+        plugin = GeminiPlugin(max_norm=1.0, initial_scale=2**5, tp_size=tp_size, extra_dp_size=extra_dp_size, enable_all_optimization=enable_all_optimization)
         booster = Booster(plugin=plugin)
         with ctx:
             model = model_fn()

@@ -62,8 +64,9 @@ def run_fn(init_method, model_fn, data_gen_fn, output_transform_fn, enable_tenso

 @parameterize("subset", ["torchvision", "transformers", "diffusers"])
 @parameterize("init_method", ["none"])
-@parameterize("enable_tensor_parallelism", [True, False])
-def check_gemini_plugin(subset: str, init_method: str = "none", enable_tensor_parallelism: bool = True, early_stop: bool = True):
+@parameterize("zero_size", [2])
+@parameterize("tp_size", [2])
+def check_gemini_plugin(subset: str, init_method: str = "none", early_stop: bool = True, zero_size: int = 1, tp_size: int = 1):
     """check gemini plugin over model zoo

     Args:

@@ -125,9 +128,9 @@ def check_gemini_plugin(subset: str, init_method: str = "none", enable_tensor_pa

         # TODO debug blip2 when using tp, something wrong with shift_logits's shape
         if "transformers_blip2" in name:
-            enable_tensor_parallelism = False
+            tp_size = 1

-        err = run_fn(init_method, model_fn, data_gen_fn, output_transform_fn, enable_tensor_parallelism)
+        err = run_fn(init_method, model_fn, data_gen_fn, output_transform_fn, zero_size, tp_size)
         torch.cuda.empty_cache()
         if err is None:
             passed_models.append(name)

@@ -153,6 +156,11 @@ def run_dist(rank, world_size, port, early_stop: bool = True):
 def test_gemini_plugin(early_stop: bool = True):
     spawn(run_dist, 4, early_stop=early_stop)

+@pytest.mark.largedist
+@rerun_if_address_is_in_use()
+def test_gemini_plugin_3d(early_stop: bool = True):
+    spawn(run_dist, 8, early_stop=early_stop)
+

 if __name__ == "__main__":
     test_gemini_plugin(early_stop=False)
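The new test derives the extra-dp degree from the spawned world size; with the 8-rank test_gemini_plugin_3d above and the parameterized zero_size=2, tp_size=2 this gives:

    # dist.get_world_size() == 8 inside run_fn when spawned with 8 processes
    extra_dp_size = 8 // (2 * 2)   # zero_size * tp_size -> extra_dp_size == 2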
@@ -37,20 +37,21 @@ OPTIM_PLACEMENT_CONFIGS = [
 @parameterize("placement_config", MODEL_PLACEMENT_CONFIGS)
 @parameterize("model_name", ["transformers_bert_for_sequence_classification"])
 @parameterize("use_safetensors", [False, True])
-@parameterize("enable_tensor_parallelism", [True, False])
-@parameterize("tp_size", [2])
-def exam_state_dict_with_origin(placement_config, model_name, use_safetensors: bool, enable_tensor_parallelism: bool, tp_size: int):
+@parameterize("tp_size", [1, 2])
+@parameterize("zero_size", [2])
+def exam_state_dict_with_origin(placement_config, model_name, use_safetensors: bool, tp_size: int, zero_size: int):
     from transformers import BertForSequenceClassification

     (model_fn, data_gen_fn, output_transform_fn, _, _) = next(iter(model_zoo.get_sub_registry(model_name).values()))
     bert_model = model_fn()
-    enable_all_optimization = True if enable_tensor_parallelism else False
+    enable_all_optimization = True if tp_size > 1 else False

     with shared_tempdir() as tempdir:
         pretrained_path = os.path.join(tempdir, "pretrained")
         bert_model.config.save_pretrained(save_directory=pretrained_path)

-        plugin = GeminiPlugin(**placement_config, enable_tensor_parallelism=enable_tensor_parallelism, tp_size=tp_size, enable_all_optimization=enable_all_optimization)
+        extra_dp_size = dist.get_world_size() // (zero_size * tp_size)
+        plugin = GeminiPlugin(**placement_config, tp_size=tp_size, enable_all_optimization=enable_all_optimization, extra_dp_size=extra_dp_size)
         booster = Booster(plugin=plugin)
         bert_model, _, _, _, _ = booster.boost(bert_model)
         model_size = sum(p.numel() * p.element_size() for p in bert_model.parameters()) / 1024**2

@@ -69,13 +70,14 @@ def exam_state_dict_with_origin(placement_config, model_name, use_safetensors: b
 @parameterize("shard", [True, False])
 @parameterize("model_name", ["transformers_gpt"])
 @parameterize("size_per_shard", [32])
-@parameterize("enable_tensor_parallelism", [True, False])
-@parameterize("tp_size", [2])
-def exam_state_dict(placement_config, shard: bool, model_name: str, size_per_shard: int, enable_tensor_parallelism: bool, tp_size: int):
+@parameterize("tp_size", [1, 2])
+@parameterize("zero_size", [2])
+def exam_state_dict(placement_config, shard: bool, model_name: str, size_per_shard: int, tp_size: int, zero_size: int):
     (model_fn, data_gen_fn, output_transform_fn, _, _) = next(iter(model_zoo.get_sub_registry(model_name).values()))
     criterion = lambda x: x.mean()
-    enable_all_optimization = True if enable_tensor_parallelism else False
-    plugin = GeminiPlugin(**placement_config, precision="fp16", initial_scale=(2**14), enable_tensor_parallelism=enable_tensor_parallelism, tp_size=tp_size, enable_all_optimization=enable_all_optimization)
+    enable_all_optimization = True if tp_size > 1 else False
+    extra_dp_size = dist.get_world_size() // (zero_size * tp_size)
+    plugin = GeminiPlugin(**placement_config, precision="fp16", initial_scale=(2**14), tp_size=tp_size, extra_dp_size=extra_dp_size, enable_all_optimization=enable_all_optimization)
     booster = Booster(plugin=plugin)

     model = model_fn()

@@ -158,3 +160,9 @@ def run_dist(rank, world_size, port):
 @rerun_if_address_is_in_use()
 def test_gemini_ckpIO(world_size):
     spawn(run_dist, world_size)
+
+@pytest.mark.largedist
+@pytest.mark.parametrize("world_size", [8])
+@rerun_if_address_is_in_use()
+def test_gemini_ckpIO_3d(world_size):
+    spawn(run_dist, world_size)
@@ -124,25 +124,6 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
 @parameterize(
     "test_config",
     [
-        {
-            "tp_size": 1,
-            "pp_size": 2,
-            "num_microbatches": 4,
-            "enable_all_optimization": True,
-            "use_lazy_init": True,
-            "precision": "fp16",
-            "max_norm": 5,
-            "initial_scale": 1,
-        },
-        {
-            "tp_size": 2,
-            "pp_size": 1,
-            "enable_all_optimization": True,
-            "use_lazy_init": False,
-            "precision": "fp16",
-            "max_norm": 5,
-            "initial_scale": 1,
-        },
         {
             "tp_size": 2,
             "pp_size": 2,

@@ -153,23 +134,6 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
             "max_norm": 5,
             "initial_scale": 1,
         },
-        {
-            "tp_size": 1,
-            "pp_size": 2,
-            "num_microbatches": 4,
-            "enable_all_optimization": True,
-            "use_lazy_init": True,
-            "precision": "bf16",
-            "max_norm": 5,
-        },
-        {
-            "tp_size": 2,
-            "pp_size": 1,
-            "enable_all_optimization": True,
-            "use_lazy_init": False,
-            "precision": "bf16",
-            "max_norm": 5,
-        },
         {
             "tp_size": 2,
             "pp_size": 2,
@@ -102,28 +102,11 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
 @parameterize(
     "test_config",
     [
-        {
-            "tp_size": 1,
-            "pp_size": 2,
-            "num_microbatches": 4,
-            "enable_all_optimization": True,
-            "use_lazy_init": True,
-            "precision": "fp32",
-            "max_norm": 5,
-        },
-        {
-            "tp_size": 2,
-            "pp_size": 1,
-            "enable_all_optimization": True,
-            "use_lazy_init": False,
-            "precision": "fp32",
-            "max_norm": 5,
-        },
         {
             "tp_size": 2,
             "pp_size": 2,
             "num_microbatches": 4,
-            "enable_all_optimization": True,
+            "enable_all_optimization": False,
             "use_lazy_init": False,
             "precision": "fp32",
             "max_norm": 5,

@@ -148,7 +131,7 @@ def run_test(test_config):
             "tp_size": 2,
             "pp_size": 2,
             "num_microbatches": 4,
-            "enable_all_optimization": True,
+            "enable_all_optimization": False,
             "use_lazy_init": False,
             "precision": "fp32",
             "max_norm": 5,
@@ -106,17 +106,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
             "pp_size": 2,
             "num_microbatches": 4,
             "zero_stage": 1,
-            "enable_all_optimization": True,
-            "use_lazy_init": True,
-            "precision": "fp16",
-            "max_norm": 5,
-            "initial_scale": 1,
-        },
-        {
-            "tp_size": 2,
-            "pp_size": 1,
-            "zero_stage": 1,
-            "enable_all_optimization": True,
+            "enable_all_optimization": False,
             "use_lazy_init": False,
             "precision": "fp16",
             "max_norm": 5,

@@ -126,36 +116,17 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
             "tp_size": 2,
             "pp_size": 1,
             "zero_stage": 2,
-            "enable_all_optimization": True,
+            "enable_all_optimization": False,
             "use_lazy_init": False,
             "precision": "fp16",
             "max_norm": 5,
             "initial_scale": 1,
         },
-        {
-            "tp_size": 1,
-            "pp_size": 2,
-            "num_microbatches": 4,
-            "zero_stage": 1,
-            "enable_all_optimization": True,
-            "use_lazy_init": True,
-            "precision": "bf16",
-            "max_norm": 5,
-        },
-        {
-            "tp_size": 2,
-            "pp_size": 1,
-            "zero_stage": 1,
-            "enable_all_optimization": True,
-            "use_lazy_init": False,
-            "precision": "bf16",
-            "max_norm": 5,
-        },
         {
             "tp_size": 2,
             "pp_size": 1,
             "zero_stage": 2,
-            "enable_all_optimization": True,
+            "enable_all_optimization": False,
             "use_lazy_init": False,
             "precision": "bf16",
             "max_norm": 5,
@@ -39,7 +39,7 @@ def exam_chunk_basic(init_device, keep_gathered, pin_memory):
     pg = _get_default_group()
     my_chunk = Chunk(
         chunk_size=1024,
-        process_group=pg,
+        zero_group=pg,
         dtype=torch.float32,
         init_device=init_device,
         cpu_shard_init=True,