mirror of https://github.com/hpcaitech/ColossalAI

[gemini] gemini support extra-dp (#5043)

* support ddp
* fix
* fix
* fix fix
* support ddp
* fix
* fix
* fix fix
* simplify tests
* fix
* fix
* fix fix fix
* fix

pull/5060/head
parent b2ad0d9e8f
commit 3e02154710
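This commit replaces GeminiPlugin's boolean enable_tensor_parallelism flag with a tp_size argument and adds an extra_dp_size argument that creates an additional DDP-like replication group on top of the ZeRO group. A minimal usage sketch based on the updated tests in this commit (the placeholder model, the HybridAdam optimizer, the launch call, and the 8-rank topology below are illustrative assumptions, not part of the diff):

import colossalai
import torch.nn as nn
from colossalai.booster import Booster
from colossalai.booster.plugin import GeminiPlugin
from colossalai.nn.optimizer import HybridAdam

colossalai.launch_from_torch(config={})  # assumes torchrun started 8 processes

model = nn.Linear(1024, 1024).cuda()     # placeholder model
optimizer = HybridAdam(model.parameters(), lr=1e-3)

# 8 ranks = zero_size(2) x extra_dp_size(2) x tp_size(2)
plugin = GeminiPlugin(
    precision="fp16",
    initial_scale=2**5,
    max_norm=1.0,
    tp_size=2,         # tensor parallelism via Shardformer (replaces enable_tensor_parallelism)
    extra_dp_size=2,   # new DDP-like group introduced by this commit
)
booster = Booster(plugin=plugin)
model, optimizer, *_ = booster.boost(model, optimizer)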
@@ -10,6 +10,7 @@ import torch.nn as nn
 from torch.optim import Optimizer
 from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
 from torch.utils.data import DataLoader
+from torch.distributed.distributed_c10d import _get_default_group
 
 from colossalai.checkpoint_io import CheckpointIndexFile, CheckpointIO, GeneralCheckpointIO
 from colossalai.checkpoint_io.utils import (
@@ -34,8 +35,7 @@ __all__ = ["GeminiPlugin"]
 SUPPORTED_PRECISION = ["fp16", "bf16"]
 PRECISION_STR_TO_DTYPE = {"fp16": torch.half, "bf16": torch.bfloat16}
 
-DP_AXIS = 0
-TP_AXIS = 1
+ZERO_AXIS, DP_AXIS, TP_AXIS = 0, 1, 2
 
 def get_param_info(optim: Optimizer):
     # Get a backup of necessary information of parameters for future use, which includes:
@@ -304,8 +304,8 @@ class GeminiPlugin(DPPluginBase):
         max_norm (float, optional): max_norm used for `clip_grad_norm`. You should notice that you shall not do
             clip_grad_norm by yourself when using ZeRO DDP. The ZeRO optimizer will take care of clip_grad_norm.
         norm_type (float, optional): norm_type used for `clip_grad_norm`.
-        enable_tensor_parallelism (bool, optional): Whether to use tensor parallelism strategy, which is implemented in Shardformer. Default to False.
-        tp_size (int, optional): If 'enable_tensor_parallelism' is set to true, please configure 'tp_size' which determines the size of the tensor parallel process group. Default to 1.
+        tp_size (int, optional): If 'tp_size' is set to be greater than 1, it means using tensor parallelism strategy, which is implemented in Shardformer, 'tp_size' determines the size of the tensor parallel process group. Default to 1.
+        extra_dp_size (int, optional): If 'extra_dp_size' is set to be greater than 1, it means creating another group to run with a ddp-like strategy. Default to 1.
         enable_all_optimization (bool, optional): Whether to switch on all the optimizations supported by Shardformer.
             Currently all the optimization methods include fused normalization, flash attention and JIT.
             Defaults to False.
@@ -347,8 +347,8 @@ class GeminiPlugin(DPPluginBase):
         max_scale: float = 2**32,
         max_norm: float = 0.0,
         norm_type: float = 2.0,
-        enable_tensor_parallelism: bool = False,
         tp_size: int = 1,
+        extra_dp_size: int = 1,
         enable_all_optimization: bool = False,
         enable_fused_normalization: bool = False,
         enable_flash_attention: bool = False,
@@ -393,7 +393,7 @@ class GeminiPlugin(DPPluginBase):
             max_norm=max_norm,
             norm_type=norm_type,
         )
-        self.enable_tensor_parallelism = enable_tensor_parallelism
+        self.enable_tensor_parallelism = tp_size > 1
         self.enable_all_optimization = enable_all_optimization
         self.enable_fused_normalization = enable_fused_normalization
         self.enable_flash_attention = enable_flash_attention
@@ -402,12 +402,17 @@ class GeminiPlugin(DPPluginBase):
         self.enable_sequence_overlap = enable_sequence_overlap
         self.verbose = verbose
 
-        self.tp_size = tp_size if self.enable_tensor_parallelism else 1
-        self.dp_size = dist.get_world_size() // self.tp_size
-        assert self.dp_size > 1, f"The size of the DP group should be greater than 1. Please reduce the TP group size."
-        self.pg_mesh = ProcessGroupMesh(self.dp_size, self.tp_size)
-        self.dp_group = self.pg_mesh.get_group_along_axis(DP_AXIS)
-        self.tp_group = self.pg_mesh.get_group_along_axis(TP_AXIS)
+        self.tp_size = tp_size
+        self.extra_dp_size = extra_dp_size
+        world_size = dist.get_world_size()
+        self.zero_size = world_size // (self.tp_size * self.extra_dp_size)
+        assert world_size == (self.tp_size * self.extra_dp_size) * self.zero_size, f"The global group size can't be evenly divided by the subgroup size."
+        self.pg_mesh = ProcessGroupMesh(self.zero_size, self.extra_dp_size, self.tp_size)
+        self.zero_group = self.pg_mesh.get_group_along_axis(ZERO_AXIS) if self.zero_size < world_size else _get_default_group()
+        self.extra_dp_group = self.pg_mesh.get_group_along_axis(DP_AXIS) if self.extra_dp_size > 1 else None
+        self.tp_group = self.pg_mesh.get_group_along_axis(TP_AXIS) if self.tp_size > 1 else None
 
         self.shard_config = ShardConfig(
             tensor_parallel_process_group=self.tp_group,
             enable_tensor_parallelism=self.enable_tensor_parallelism,
@@ -458,7 +463,7 @@ class GeminiPlugin(DPPluginBase):
             shardformer = ShardFormer(self.shard_config)
             model, _ = shardformer.optimize(model)
 
-            model = GeminiDDP(model, **self.gemini_config, process_group=self.dp_group, verbose=self.verbose)
+            model = GeminiDDP(model, **self.gemini_config, zero_group=self.zero_group, extra_dp_group=self.extra_dp_group, verbose=self.verbose)
 
         if optimizer is not None and not isinstance(optimizer, OptimizerWrapper):
             optimizer = GeminiOptimizer(
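For reference, the constructor hunk above now derives a three-axis mesh (ZeRO x extra-DP x TP) from the global world size. A small standalone sketch of that arithmetic, assuming an 8-rank job (the concrete numbers are illustrative, not from the commit):

# ZERO_AXIS, DP_AXIS, TP_AXIS = 0, 1, 2, matching the constants changed above
world_size = 8                                        # assumed total rank count
tp_size = 2                                           # tensor-parallel degree
extra_dp_size = 2                                     # new DDP-like replication degree
zero_size = world_size // (tp_size * extra_dp_size)   # -> 2
assert world_size == zero_size * extra_dp_size * tp_size, (
    "the global group size must be evenly divisible by the subgroup sizes"
)
# The plugin then builds ProcessGroupMesh(zero_size, extra_dp_size, tp_size) and takes
# the ZeRO, extra-DP and TP groups along axes 0, 1 and 2 respectively.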
@@ -61,12 +61,13 @@ class Chunk:
     def __init__(
         self,
         chunk_size: int,
-        process_group: ProcessGroup,
+        zero_group: ProcessGroup,
         dtype: torch.dtype,
         init_device: Optional[torch.device] = None,
         cpu_shard_init: bool = False,
         keep_gathered: bool = False,
         pin_memory: bool = False,
+        extra_dp_group: ProcessGroup = None,
     ) -> None:
         """
         Chunk: A container owning a piece of contiguous memory space for tensors
@@ -76,7 +77,7 @@ class Chunk:
 
         Args:
             chunk_size (int): the number of elements in the chunk
-            process_group (ProcessGroup): the process group of this chunk
+            zero_group (ProcessGroup): the process group of this chunk
             dtype (torch.dtype): the data type of the chunk
             init_device (torch.device): optional, During the chunk construction process, where the tensor is stored.
                 The default value is None, which is the current GPU
@@ -90,9 +91,11 @@ class Chunk:
         self.chunk_size = chunk_size
         self.utilized_size = 0
 
-        self.torch_pg = process_group
+        self.torch_pg = zero_group
         self.pg_size = dist.get_world_size(self.torch_pg)
         self.pg_rank = dist.get_rank(self.torch_pg)
+        self.extra_dp_group = extra_dp_group
+        self.extra_dp_size = dist.get_world_size(self.extra_dp_group) if self.extra_dp_group is not None else 1
 
         # the chunk size should be divisible by the dp degree
         if not keep_gathered:
@@ -384,14 +387,20 @@ class Chunk:
             # just move cuda_global_chunk to cuda_shard
             # the communication is not necessary
             self.__scatter()
+            if self.extra_dp_group is not None:
+                dist.all_reduce(self.cuda_shard, group=self.extra_dp_group)
         elif self.keep_gathered:
             # we use all-reduce here
             dist.all_reduce(self.cuda_global_chunk, group=self.torch_pg)
+            if self.extra_dp_group is not None:
+                dist.all_reduce(self.cuda_global_chunk, group=self.extra_dp_group)
         else:
             self.cuda_shard = torch.empty(self.shard_size, dtype=self.dtype, device=get_current_device())
 
             input_list = list(torch.chunk(self.cuda_global_chunk, chunks=self.pg_size, dim=0))
             dist.reduce_scatter(self.cuda_shard, input_list, group=self.torch_pg)
+            if self.extra_dp_group is not None:
+                dist.all_reduce(self.cuda_shard, group=self.extra_dp_group)
 
         free_storage(self.cuda_global_chunk)
         self.is_gathered = False
@@ -608,10 +617,11 @@
             # grad chunk is not initialized
             grad_chunk = Chunk(
                 chunk_size=self.chunk_size,
-                process_group=self.torch_pg,
+                zero_group=self.torch_pg,
                 dtype=self.dtype,
                 keep_gathered=self.keep_gathered,
                 pin_memory=self.pin_memory,
+                extra_dp_group=self.extra_dp_group,
             )
             grad_chunk.num_tensors = self.num_tensors
             grad_chunk.utilized_size = self.utilized_size
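The Chunk.reduce hunk above adds a second reduction level: gradients are first reduce-scattered (or all-reduced, if the chunk is kept gathered) inside the ZeRO group, then all-reduced once more across the extra DP group. A standalone torch.distributed sketch of the same pattern, assuming an initialized process group and caller-provided group handles (none of the names below come from the commit itself):

import torch
import torch.distributed as dist

def two_level_reduce(full_grad: torch.Tensor, zero_group, extra_dp_group=None) -> torch.Tensor:
    # Sum-reduce a flattened gradient: shard it across the ZeRO group first,
    # then sum the local shard across the extra DP replicas, mirroring Chunk.reduce.
    zero_size = dist.get_world_size(zero_group)
    flat = full_grad.flatten()
    assert flat.numel() % zero_size == 0, "chunk size must be divisible by the ZeRO group size"
    shard = torch.empty(flat.numel() // zero_size, dtype=flat.dtype, device=flat.device)
    input_list = list(torch.chunk(flat, chunks=zero_size, dim=0))
    dist.reduce_scatter(shard, input_list, group=zero_group)  # sum within the ZeRO group
    if extra_dp_group is not None:
        dist.all_reduce(shard, group=extra_dp_group)          # sum across extra DP replicas
    return shard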
@@ -38,7 +38,8 @@ class ChunkManager:
         tensor: torch.Tensor,
         group_type: str,
         config_key: int,
-        process_group: ProcessGroup,
+        zero_group: ProcessGroup,
+        extra_dp_group: ProcessGroup = None,
         cpu_offload: bool = False,
         pin_memory: bool = False,
     ) -> None:
@@ -76,15 +77,16 @@ class ChunkManager:
 
             if tensor.numel() > chunk_size:
                 chunk_size = tensor.numel()
-                dp_size = dist.get_world_size(process_group)
+                dp_size = dist.get_world_size(zero_group)
                 chunk_size = chunk_size + (-chunk_size % dp_size)
 
             chunk = Chunk(
                 chunk_size=chunk_size,
-                process_group=process_group,
+                zero_group=zero_group,
                 dtype=tensor.dtype,
                 cpu_shard_init=cpu_offload,
                 pin_memory=pin_memory,
+                extra_dp_group=extra_dp_group,
                 **chunk_kwargs,
             )
 
@@ -86,9 +86,10 @@ class GeminiDDP(ModelWrapper):
         strict_ddp_mode: bool = False,
         scatter_after_inference: bool = True,
         mixed_precision: torch.dtype = torch.float16,
-        process_group: Optional[ProcessGroup] = None,
+        zero_group: Optional[ProcessGroup] = None,
         memstats: Optional[MemStats] = None,  # genimi memory stats
         master_weights: bool = True,
+        extra_dp_group: Optional[ProcessGroup] = None,
         verbose: bool = False,
     ) -> None:
         assert mixed_precision in (torch.float16, torch.bfloat16)
@@ -105,7 +106,7 @@ class GeminiDDP(ModelWrapper):
             search_range_m=search_range_m,
             min_chunk_size_m=min_chunk_size_m,
             strict_ddp_flag=strict_ddp_mode,
-            process_group=process_group,
+            process_group=zero_group,
             verbose=verbose,
         )
         self.gemini_manager = GeminiManager(
@@ -128,7 +129,8 @@ class GeminiDDP(ModelWrapper):
         self.name2param: Dict[str, nn.Parameter] = dict()
         self.scatter_after_inference = scatter_after_inference
         self.mixed_precision = mixed_precision
-        self.dp_process_group = process_group or _get_default_group()
+        self.zero_group = zero_group or _get_default_group()
+        self.extra_dp_group = extra_dp_group
 
         self.reuse_fp16_chunk = master_weights
         self.master_weights = master_weights
@@ -377,8 +379,12 @@ class GeminiDDP(ModelWrapper):
                     self.chunk_manager.release_chunk(chunk)
             if grad_chunk.is_gathered:
                 grad_chunk.cuda_global_chunk.div_(chunk.pg_size)
+                if self.extra_dp_group is not None:
+                    grad_chunk.cuda_global_chunk.div_(chunk.extra_dp_size)
             else:
                 grad_chunk.cuda_shard.div_(chunk.pg_size)
+                if self.extra_dp_group is not None:
+                    grad_chunk.cuda_shard.div_(chunk.extra_dp_size)
             # check overflow elements
             self.overflow_counter += grad_chunk.has_inf_or_nan
             # record l2 norm for gradient clipping. flag is bound to fp16 chunk
@@ -733,7 +739,7 @@ class GeminiDDP(ModelWrapper):
                 unexpected_keys.append(key)
 
     def _init_chunks(self, param_order, strict_ddp_mode: bool, cpu_offload: bool, pin_memory: bool):
-        dp_world_size = dist.get_world_size(self.dp_process_group)
+        zero_world_size = dist.get_world_size(self.zero_group)
         for p in param_order.generate():
             self._preprocess_param(p)
             assert type(p) is ColoParameter
@@ -753,8 +759,9 @@ class GeminiDDP(ModelWrapper):
             self.chunk_manager.register_tensor(
                 tensor=p,
                 group_type="fp16_param",
-                config_key=dp_world_size,
-                process_group=self.dp_process_group,
+                config_key=zero_world_size,
+                zero_group=self.zero_group,
+                extra_dp_group=self.extra_dp_group,
                 cpu_offload=cpu_offload,
                 pin_memory=pin_memory,
             )
@@ -767,8 +774,9 @@ class GeminiDDP(ModelWrapper):
                 self.chunk_manager.register_tensor(
                     tensor=fp32_p,
                     group_type="fp32_param",
-                    config_key=dp_world_size,
-                    process_group=self.dp_process_group,
+                    config_key=zero_world_size,
+                    zero_group=self.zero_group,
+                    extra_dp_group=self.extra_dp_group,
                     cpu_offload=cpu_offload,
                     pin_memory=pin_memory,
                 )
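Because of the extra reduction group, GeminiDDP.grad_handle above now divides each gradient chunk by both chunk.pg_size and chunk.extra_dp_size, so the combined factor equals the total number of data-parallel replicas. A small worked example with assumed sizes (not values from the commit):

# Illustrative arithmetic only.
pg_size = 2          # ZeRO group size (chunk.pg_size)
extra_dp_size = 2    # extra DDP-like group size (chunk.extra_dp_size)

grad_sum = 8.0                        # a gradient entry after both sum-reductions
grad = grad_sum / pg_size             # first div_ in grad_handle
grad = grad / extra_dp_size           # second div_, applied only when extra_dp_group is set
assert grad == grad_sum / (pg_size * extra_dp_size)  # mean over all 2 * 2 data-parallel ranks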
@@ -1,5 +1,6 @@
 from contextlib import nullcontext
 from typing import Optional
+import pytest
 
 import torch
 import torch.distributed as dist
@@ -17,14 +18,15 @@ from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
 from tests.kit.model_zoo import model_zoo
 
 
-def run_fn(init_method, model_fn, data_gen_fn, output_transform_fn, enable_tensor_parallelism) -> Optional[str]:
+def run_fn(init_method, model_fn, data_gen_fn, output_transform_fn, zero_size, tp_size) -> Optional[str]:
     try:
         if init_method == "lazy":
            ctx = LazyInitContext()
         else:
             ctx = nullcontext()
-        enable_all_optimization = True if enable_tensor_parallelism else False
-        plugin = GeminiPlugin(max_norm=1.0, initial_scale=2**5, enable_tensor_parallelism=enable_tensor_parallelism, enable_all_optimization=enable_all_optimization)
+        extra_dp_size = dist.get_world_size() // (zero_size * tp_size)
+        enable_all_optimization = True if tp_size > 1 else False
+        plugin = GeminiPlugin(max_norm=1.0, initial_scale=2**5, tp_size=tp_size, extra_dp_size=extra_dp_size, enable_all_optimization=enable_all_optimization)
         booster = Booster(plugin=plugin)
         with ctx:
             model = model_fn()
@@ -62,8 +64,9 @@ def run_fn(init_method, model_fn, data_gen_fn, output_transform_fn, enable_tensor_parallelism) -> Optional[str]:
 
 @parameterize("subset", ["torchvision", "transformers", "diffusers"])
 @parameterize("init_method", ["none"])
-@parameterize("enable_tensor_parallelism", [True, False])
-def check_gemini_plugin(subset: str, init_method: str = "none", enable_tensor_parallelism: bool = True, early_stop: bool = True):
+@parameterize("zero_size", [2])
+@parameterize("tp_size", [2])
+def check_gemini_plugin(subset: str, init_method: str = "none", early_stop: bool = True, zero_size: int = 1, tp_size: int = 1):
     """check gemini plugin over model zoo
 
     Args:
@@ -125,9 +128,9 @@ def check_gemini_plugin(subset: str, init_method: str = "none", enable_tensor_parallelism: bool = True, early_stop: bool = True):
 
         # TODO debug blip2 when using tp, something wrong with shift_logits's shape
         if "transformers_blip2" in name:
-            enable_tensor_parallelism = False
+            tp_size = 1
 
-        err = run_fn(init_method, model_fn, data_gen_fn, output_transform_fn, enable_tensor_parallelism)
+        err = run_fn(init_method, model_fn, data_gen_fn, output_transform_fn, zero_size, tp_size)
         torch.cuda.empty_cache()
         if err is None:
             passed_models.append(name)
@@ -153,6 +156,11 @@ def run_dist(rank, world_size, port, early_stop: bool = True):
 def test_gemini_plugin(early_stop: bool = True):
     spawn(run_dist, 4, early_stop=early_stop)
 
+@pytest.mark.largedist
+@rerun_if_address_is_in_use()
+def test_gemini_plugin_3d(early_stop: bool = True):
+    spawn(run_dist, 8, early_stop=early_stop)
+
 
 if __name__ == "__main__":
     test_gemini_plugin(early_stop=False)
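The new test_gemini_plugin_3d entry point above spawns 8 ranks with zero_size=2 and tp_size=2, and run_fn derives the remaining axis from the world size. The derivation, with the values taken from those parameterize decorators:

world_size = 8                                        # spawn(run_dist, 8, ...)
zero_size, tp_size = 2, 2                             # from the @parameterize decorators above
extra_dp_size = world_size // (zero_size * tp_size)   # -> 2, as computed in run_fn
assert zero_size * extra_dp_size * tp_size == world_size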
@@ -37,20 +37,21 @@ OPTIM_PLACEMENT_CONFIGS = [
 @parameterize("placement_config", MODEL_PLACEMENT_CONFIGS)
 @parameterize("model_name", ["transformers_bert_for_sequence_classification"])
 @parameterize("use_safetensors", [False, True])
-@parameterize("enable_tensor_parallelism", [True, False])
-@parameterize("tp_size", [2])
-def exam_state_dict_with_origin(placement_config, model_name, use_safetensors: bool, enable_tensor_parallelism: bool, tp_size: int):
+@parameterize("tp_size", [1, 2])
+@parameterize("zero_size", [2])
+def exam_state_dict_with_origin(placement_config, model_name, use_safetensors: bool, tp_size: int, zero_size: int):
     from transformers import BertForSequenceClassification
 
     (model_fn, data_gen_fn, output_transform_fn, _, _) = next(iter(model_zoo.get_sub_registry(model_name).values()))
     bert_model = model_fn()
-    enable_all_optimization = True if enable_tensor_parallelism else False
+    enable_all_optimization = True if tp_size > 1 else False
 
     with shared_tempdir() as tempdir:
         pretrained_path = os.path.join(tempdir, "pretrained")
         bert_model.config.save_pretrained(save_directory=pretrained_path)
 
-        plugin = GeminiPlugin(**placement_config, enable_tensor_parallelism=enable_tensor_parallelism, tp_size=tp_size, enable_all_optimization=enable_all_optimization)
+        extra_dp_size = dist.get_world_size() // (zero_size * tp_size)
+        plugin = GeminiPlugin(**placement_config, tp_size=tp_size, enable_all_optimization=enable_all_optimization, extra_dp_size=extra_dp_size)
         booster = Booster(plugin=plugin)
         bert_model, _, _, _, _ = booster.boost(bert_model)
         model_size = sum(p.numel() * p.element_size() for p in bert_model.parameters()) / 1024**2
@@ -69,13 +70,14 @@ def exam_state_dict_with_origin(placement_config, model_name, use_safetensors: bool, enable_tensor_parallelism: bool, tp_size: int):
 @parameterize("shard", [True, False])
 @parameterize("model_name", ["transformers_gpt"])
 @parameterize("size_per_shard", [32])
-@parameterize("enable_tensor_parallelism", [True, False])
-@parameterize("tp_size", [2])
-def exam_state_dict(placement_config, shard: bool, model_name: str, size_per_shard: int, enable_tensor_parallelism: bool, tp_size: int):
+@parameterize("tp_size", [1, 2])
+@parameterize("zero_size", [2])
+def exam_state_dict(placement_config, shard: bool, model_name: str, size_per_shard: int, tp_size: int, zero_size: int):
     (model_fn, data_gen_fn, output_transform_fn, _, _) = next(iter(model_zoo.get_sub_registry(model_name).values()))
     criterion = lambda x: x.mean()
-    enable_all_optimization = True if enable_tensor_parallelism else False
-    plugin = GeminiPlugin(**placement_config, precision="fp16", initial_scale=(2**14), enable_tensor_parallelism=enable_tensor_parallelism, tp_size=tp_size, enable_all_optimization=enable_all_optimization)
+    enable_all_optimization = True if tp_size > 1 else False
+    extra_dp_size = dist.get_world_size() // (zero_size * tp_size)
+    plugin = GeminiPlugin(**placement_config, precision="fp16", initial_scale=(2**14), tp_size=tp_size, extra_dp_size=extra_dp_size, enable_all_optimization=enable_all_optimization)
     booster = Booster(plugin=plugin)
 
     model = model_fn()
@@ -158,3 +160,9 @@ def run_dist(rank, world_size, port):
 @rerun_if_address_is_in_use()
 def test_gemini_ckpIO(world_size):
     spawn(run_dist, world_size)
+
+@pytest.mark.largedist
+@pytest.mark.parametrize("world_size", [8])
+@rerun_if_address_is_in_use()
+def test_gemini_ckpIO_3d(world_size):
+    spawn(run_dist, world_size)
@@ -124,25 +124,6 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
 @parameterize(
     "test_config",
     [
-        {
-            "tp_size": 1,
-            "pp_size": 2,
-            "num_microbatches": 4,
-            "enable_all_optimization": True,
-            "use_lazy_init": True,
-            "precision": "fp16",
-            "max_norm": 5,
-            "initial_scale": 1,
-        },
-        {
-            "tp_size": 2,
-            "pp_size": 1,
-            "enable_all_optimization": True,
-            "use_lazy_init": False,
-            "precision": "fp16",
-            "max_norm": 5,
-            "initial_scale": 1,
-        },
         {
             "tp_size": 2,
             "pp_size": 2,
@@ -153,23 +134,6 @@
             "max_norm": 5,
             "initial_scale": 1,
         },
-        {
-            "tp_size": 1,
-            "pp_size": 2,
-            "num_microbatches": 4,
-            "enable_all_optimization": True,
-            "use_lazy_init": True,
-            "precision": "bf16",
-            "max_norm": 5,
-        },
-        {
-            "tp_size": 2,
-            "pp_size": 1,
-            "enable_all_optimization": True,
-            "use_lazy_init": False,
-            "precision": "bf16",
-            "max_norm": 5,
-        },
         {
             "tp_size": 2,
             "pp_size": 2,
@@ -102,28 +102,11 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
 @parameterize(
     "test_config",
     [
-        {
-            "tp_size": 1,
-            "pp_size": 2,
-            "num_microbatches": 4,
-            "enable_all_optimization": True,
-            "use_lazy_init": True,
-            "precision": "fp32",
-            "max_norm": 5,
-        },
-        {
-            "tp_size": 2,
-            "pp_size": 1,
-            "enable_all_optimization": True,
-            "use_lazy_init": False,
-            "precision": "fp32",
-            "max_norm": 5,
-        },
         {
             "tp_size": 2,
             "pp_size": 2,
             "num_microbatches": 4,
-            "enable_all_optimization": True,
+            "enable_all_optimization": False,
             "use_lazy_init": False,
             "precision": "fp32",
             "max_norm": 5,
@@ -148,7 +131,7 @@ def run_test(test_config):
             "tp_size": 2,
             "pp_size": 2,
             "num_microbatches": 4,
-            "enable_all_optimization": True,
+            "enable_all_optimization": False,
             "use_lazy_init": False,
             "precision": "fp32",
             "max_norm": 5,
@@ -106,17 +106,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
             "pp_size": 2,
             "num_microbatches": 4,
             "zero_stage": 1,
-            "enable_all_optimization": True,
-            "use_lazy_init": True,
-            "precision": "fp16",
-            "max_norm": 5,
-            "initial_scale": 1,
-        },
-        {
-            "tp_size": 2,
-            "pp_size": 1,
-            "zero_stage": 1,
-            "enable_all_optimization": True,
+            "enable_all_optimization": False,
             "use_lazy_init": False,
             "precision": "fp16",
             "max_norm": 5,
@@ -126,36 +116,17 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
             "tp_size": 2,
             "pp_size": 1,
             "zero_stage": 2,
-            "enable_all_optimization": True,
+            "enable_all_optimization": False,
             "use_lazy_init": False,
             "precision": "fp16",
             "max_norm": 5,
             "initial_scale": 1,
         },
-        {
-            "tp_size": 1,
-            "pp_size": 2,
-            "num_microbatches": 4,
-            "zero_stage": 1,
-            "enable_all_optimization": True,
-            "use_lazy_init": True,
-            "precision": "bf16",
-            "max_norm": 5,
-        },
         {
             "tp_size": 2,
             "pp_size": 1,
             "zero_stage": 1,
-            "enable_all_optimization": True,
-            "use_lazy_init": False,
-            "precision": "bf16",
-            "max_norm": 5,
-        },
-        {
-            "tp_size": 2,
-            "pp_size": 1,
-            "zero_stage": 2,
-            "enable_all_optimization": True,
+            "enable_all_optimization": False,
             "use_lazy_init": False,
             "precision": "bf16",
             "max_norm": 5,
@@ -39,7 +39,7 @@ def exam_chunk_basic(init_device, keep_gathered, pin_memory):
     pg = _get_default_group()
     my_chunk = Chunk(
         chunk_size=1024,
-        process_group=pg,
+        zero_group=pg,
         dtype=torch.float32,
         init_device=init_device,
         cpu_shard_init=True,
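After the rename in the unit test above, a Chunk is constructed with zero_group (plus an optional extra_dp_group) instead of process_group. A minimal construction sketch, assuming an initialized default process group and that Chunk is importable from the Gemini chunk module (the import path and the keep_gathered/pin_memory values are assumptions, not part of the diff):

import torch
from torch.distributed.distributed_c10d import _get_default_group
from colossalai.zero.gemini.chunk import Chunk  # import path assumed

pg = _get_default_group()
my_chunk = Chunk(
    chunk_size=1024,
    zero_group=pg,         # renamed from process_group in this commit
    dtype=torch.float32,
    cpu_shard_init=True,
    keep_gathered=False,
    pin_memory=False,
    extra_dp_group=None,   # new optional argument
)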