mirror of https://github.com/hpcaitech/ColossalAI
[bug] fix early return (#5740)
* [bug] fix silly bug * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * [chore] add test for prefetch * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>pull/5727/head
parent
83716e9feb
commit
13c06d36a3
|
@ -361,10 +361,11 @@ class Chunk:
|
||||||
"""Make the chunk usable for the parameters inside it. It's an operation done in CUDA."""
|
"""Make the chunk usable for the parameters inside it. It's an operation done in CUDA."""
|
||||||
# sanity check
|
# sanity check
|
||||||
assert self.chunk_temp is None
|
assert self.chunk_temp is None
|
||||||
|
maybe_work = None
|
||||||
if not self.is_gathered:
|
if not self.is_gathered:
|
||||||
return self.__gather(async_op=async_access)
|
maybe_work = self.__gather(async_op=async_access)
|
||||||
self.__update_tensors_ptr()
|
self.__update_tensors_ptr()
|
||||||
return None
|
return maybe_work
|
||||||
|
|
||||||
def release_chunk(self):
|
def release_chunk(self):
|
||||||
"""Release the usable chunk. It's an operation done in CUDA."""
|
"""Release the usable chunk. It's an operation done in CUDA."""
|
||||||
|
|
|
@ -5,7 +5,6 @@ from typing import List
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from colossalai.logging import DistributedLogger
|
|
||||||
from colossalai.tensor.param_op_hook import ColoParamOpHook
|
from colossalai.tensor.param_op_hook import ColoParamOpHook
|
||||||
from colossalai.utils import is_ddp_ignored
|
from colossalai.utils import is_ddp_ignored
|
||||||
from colossalai.zero.gemini import TensorState
|
from colossalai.zero.gemini import TensorState
|
||||||
|
@ -17,9 +16,6 @@ class TrainingPhase(Enum):
|
||||||
BACKWARD = 1
|
BACKWARD = 1
|
||||||
|
|
||||||
|
|
||||||
logger = DistributedLogger("gemini_hook")
|
|
||||||
|
|
||||||
|
|
||||||
class GeminiZeROHook(ColoParamOpHook):
|
class GeminiZeROHook(ColoParamOpHook):
|
||||||
def __init__(self, gemini_manager: GeminiManager) -> None:
|
def __init__(self, gemini_manager: GeminiManager) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
|
@ -177,6 +177,10 @@ class GeminiManager:
|
||||||
return self._mem_stats_collector.cuda_margin_mem
|
return self._mem_stats_collector.cuda_margin_mem
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
@property
|
||||||
|
def placement_policy(self) -> PlacementPolicy:
|
||||||
|
return self._placement_policy
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def compute_list(self) -> List[Tuple[Chunk, ...]]:
|
def compute_list(self) -> List[Tuple[Chunk, ...]]:
|
||||||
return self._compute_list
|
return self._compute_list
|
||||||
|
@ -189,10 +193,6 @@ class GeminiManager:
|
||||||
def async_works(self) -> Dict[Chunk, dist.Work]:
|
def async_works(self) -> Dict[Chunk, dist.Work]:
|
||||||
return self._async_works
|
return self._async_works
|
||||||
|
|
||||||
@property
|
|
||||||
def placement_policy(self) -> PlacementPolicy:
|
|
||||||
return self._placement_policy
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def is_cuda_margin_mem_avail(self) -> bool:
|
def is_cuda_margin_mem_avail(self) -> bool:
|
||||||
return self._placement_policy.need_mem_stats
|
return self._placement_policy.need_mem_stats
|
||||||
|
|
|
@ -40,12 +40,14 @@ def check_grad(model: GeminiDDP, torch_model: torch.nn.Module):
|
||||||
@parameterize("model_name", ["transformers_gpt_lm"])
|
@parameterize("model_name", ["transformers_gpt_lm"])
|
||||||
@parameterize("use_grad_checkpoint", [False, True])
|
@parameterize("use_grad_checkpoint", [False, True])
|
||||||
@parameterize("master_weights", [False, True])
|
@parameterize("master_weights", [False, True])
|
||||||
|
@parameterize("max_prefetch", [0, 1, 4])
|
||||||
def exam_gpt_fwd_bwd(
|
def exam_gpt_fwd_bwd(
|
||||||
placement_config,
|
placement_config,
|
||||||
keep_gather,
|
keep_gather,
|
||||||
model_name: str,
|
model_name: str,
|
||||||
use_grad_checkpoint: bool = False,
|
use_grad_checkpoint: bool = False,
|
||||||
master_weights: bool = True,
|
master_weights: bool = True,
|
||||||
|
max_prefetch: int = 0,
|
||||||
):
|
):
|
||||||
init_device = get_accelerator().get_current_device()
|
init_device = get_accelerator().get_current_device()
|
||||||
model_builder, data_gen_fn, output_transform_fn, loss_fn, *_ = next(
|
model_builder, data_gen_fn, output_transform_fn, loss_fn, *_ = next(
|
||||||
|
@ -69,7 +71,13 @@ def exam_gpt_fwd_bwd(
|
||||||
config_dict[world_size]["chunk_size"] = 5000
|
config_dict[world_size]["chunk_size"] = 5000
|
||||||
config_dict[world_size]["keep_gathered"] = keep_gather
|
config_dict[world_size]["keep_gathered"] = keep_gather
|
||||||
model = GeminiDDP(
|
model = GeminiDDP(
|
||||||
model, config_dict, init_device, pin_memory=True, **placement_config, master_weights=master_weights
|
model,
|
||||||
|
config_dict,
|
||||||
|
init_device,
|
||||||
|
pin_memory=True,
|
||||||
|
**placement_config,
|
||||||
|
master_weights=master_weights,
|
||||||
|
max_prefetch=max_prefetch,
|
||||||
)
|
)
|
||||||
optimizer = HybridAdam(model.parameters(), lr=1e-3)
|
optimizer = HybridAdam(model.parameters(), lr=1e-3)
|
||||||
zero_optim = GeminiOptimizer(optimizer, model, initial_scale=1)
|
zero_optim = GeminiOptimizer(optimizer, model, initial_scale=1)
|
||||||
|
|
|
@ -50,8 +50,14 @@ def check_grad(model: GeminiDDP, torch_model: torch.nn.Module):
|
||||||
@parameterize("model_name", ["transformers_gpt_lm"])
|
@parameterize("model_name", ["transformers_gpt_lm"])
|
||||||
@parameterize("master_weights", [False, True])
|
@parameterize("master_weights", [False, True])
|
||||||
@parameterize("use_grad_checkpoint", [False, True])
|
@parameterize("use_grad_checkpoint", [False, True])
|
||||||
|
@parameterize("max_prefetch", [0, 1, 4])
|
||||||
def exam_gemini_grad_acc(
|
def exam_gemini_grad_acc(
|
||||||
placement_config, keep_gathered: bool, model_name: str, master_weights: bool, use_grad_checkpoint: bool
|
placement_config,
|
||||||
|
keep_gathered: bool,
|
||||||
|
model_name: str,
|
||||||
|
master_weights: bool,
|
||||||
|
use_grad_checkpoint: bool,
|
||||||
|
max_prefetch: int,
|
||||||
):
|
):
|
||||||
init_device = get_accelerator().get_current_device()
|
init_device = get_accelerator().get_current_device()
|
||||||
model_builder, data_gen_fn, output_transform_fn, loss_fn, *_ = next(
|
model_builder, data_gen_fn, output_transform_fn, loss_fn, *_ = next(
|
||||||
|
@ -81,6 +87,7 @@ def exam_gemini_grad_acc(
|
||||||
pin_memory=True,
|
pin_memory=True,
|
||||||
enable_gradient_accumulation=True,
|
enable_gradient_accumulation=True,
|
||||||
master_weights=master_weights,
|
master_weights=master_weights,
|
||||||
|
max_prefetch=max_prefetch,
|
||||||
**placement_config,
|
**placement_config,
|
||||||
)
|
)
|
||||||
optimizer = HybridAdam(gemini_model.parameters(), lr=1e-3)
|
optimizer = HybridAdam(gemini_model.parameters(), lr=1e-3)
|
||||||
|
|
|
@ -52,7 +52,8 @@ def check_param(model: GeminiDDP, torch_model: torch.nn.Module):
|
||||||
@parameterize("placement_config", PLACEMENT_CONFIGS)
|
@parameterize("placement_config", PLACEMENT_CONFIGS)
|
||||||
@parameterize("model_name", ["transformers_gpt_lm"])
|
@parameterize("model_name", ["transformers_gpt_lm"])
|
||||||
@parameterize("master_weights", [True, False])
|
@parameterize("master_weights", [True, False])
|
||||||
def exam_grad_clipping(placement_config, model_name: str, master_weights: bool):
|
@parameterize("max_prefetch", [0, 1, 4])
|
||||||
|
def exam_grad_clipping(placement_config, model_name: str, master_weights: bool, max_prefetch: int):
|
||||||
set_seed(1912)
|
set_seed(1912)
|
||||||
model_builder, data_gen_fn, output_transform_fn, loss_fn, *_ = next(
|
model_builder, data_gen_fn, output_transform_fn, loss_fn, *_ = next(
|
||||||
iter(model_zoo.get_sub_registry(model_name).values())
|
iter(model_zoo.get_sub_registry(model_name).values())
|
||||||
|
@ -84,6 +85,7 @@ def exam_grad_clipping(placement_config, model_name: str, master_weights: bool):
|
||||||
chunk_init_device=init_device,
|
chunk_init_device=init_device,
|
||||||
pin_memory=True,
|
pin_memory=True,
|
||||||
master_weights=master_weights,
|
master_weights=master_weights,
|
||||||
|
max_prefetch=max_prefetch,
|
||||||
**placement_config,
|
**placement_config,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
@ -71,7 +71,10 @@ def check_param(model: GeminiDDP, torch_model: torch.nn.Module, dtype: torch.dty
|
||||||
@parameterize("model_name", TEST_MODELS)
|
@parameterize("model_name", TEST_MODELS)
|
||||||
@parameterize("mixed_precision", [torch.half, torch.bfloat16])
|
@parameterize("mixed_precision", [torch.half, torch.bfloat16])
|
||||||
@parameterize("master_weights", [True, False])
|
@parameterize("master_weights", [True, False])
|
||||||
def exam_model_step(placement_config, model_name: str, mixed_precision: torch.dtype, master_weights: bool):
|
@parameterize("max_prefetch", [0, 1, 4])
|
||||||
|
def exam_model_step(
|
||||||
|
placement_config, model_name: str, mixed_precision: torch.dtype, master_weights: bool, max_prefetch: int
|
||||||
|
):
|
||||||
set_seed(42)
|
set_seed(42)
|
||||||
model_builder, data_gen_fn, output_transform_fn, loss_fn, *_ = next(
|
model_builder, data_gen_fn, output_transform_fn, loss_fn, *_ = next(
|
||||||
iter(model_zoo.get_sub_registry(model_name).values())
|
iter(model_zoo.get_sub_registry(model_name).values())
|
||||||
|
@ -94,7 +97,12 @@ def exam_model_step(placement_config, model_name: str, mixed_precision: torch.dt
|
||||||
config_dict[world_size]["chunk_size"] = 5000
|
config_dict[world_size]["chunk_size"] = 5000
|
||||||
config_dict[world_size]["keep_gathered"] = False
|
config_dict[world_size]["keep_gathered"] = False
|
||||||
model = GeminiDDP(
|
model = GeminiDDP(
|
||||||
model, config_dict, **placement_config, mixed_precision=mixed_precision, master_weights=master_weights
|
model,
|
||||||
|
config_dict,
|
||||||
|
**placement_config,
|
||||||
|
mixed_precision=mixed_precision,
|
||||||
|
master_weights=master_weights,
|
||||||
|
max_prefetch=max_prefetch,
|
||||||
)
|
)
|
||||||
|
|
||||||
optimizer = HybridAdam(model.parameters(), lr=1e-3)
|
optimizer = HybridAdam(model.parameters(), lr=1e-3)
|
||||||
|
|
|
@ -28,7 +28,8 @@ def ignore_the_first_parameter(model: torch.nn.Module):
|
||||||
@parameterize("keep_gathered", [True, False])
|
@parameterize("keep_gathered", [True, False])
|
||||||
@parameterize("model_name", ["transformers_gpt_lm", "transformers_bert_for_sequence_classification"])
|
@parameterize("model_name", ["transformers_gpt_lm", "transformers_bert_for_sequence_classification"])
|
||||||
@parameterize("master_weights", [False, True])
|
@parameterize("master_weights", [False, True])
|
||||||
def exam_state_dict(placement_config, keep_gathered, model_name: str, master_weights: bool):
|
@parameterize("max_prefetch", [0, 1, 4])
|
||||||
|
def exam_state_dict(placement_config, keep_gathered, model_name: str, master_weights: bool, max_prefetch: int):
|
||||||
set_seed(431)
|
set_seed(431)
|
||||||
model_builder, data_gen_fn, output_transform_fn, *_ = next(iter(model_zoo.get_sub_registry(model_name).values()))
|
model_builder, data_gen_fn, output_transform_fn, *_ = next(iter(model_zoo.get_sub_registry(model_name).values()))
|
||||||
|
|
||||||
|
@ -44,7 +45,14 @@ def exam_state_dict(placement_config, keep_gathered, model_name: str, master_wei
|
||||||
config_dict, *_ = search_chunk_configuration(model, search_range_m=1, search_interval=100)
|
config_dict, *_ = search_chunk_configuration(model, search_range_m=1, search_interval=100)
|
||||||
config_dict[world_size]["chunk_size"] = 5000
|
config_dict[world_size]["chunk_size"] = 5000
|
||||||
config_dict[world_size]["keep_gathered"] = keep_gathered
|
config_dict[world_size]["keep_gathered"] = keep_gathered
|
||||||
model = GeminiDDP(model, config_dict, **placement_config, pin_memory=True, master_weights=master_weights)
|
model = GeminiDDP(
|
||||||
|
model,
|
||||||
|
config_dict,
|
||||||
|
**placement_config,
|
||||||
|
pin_memory=True,
|
||||||
|
master_weights=master_weights,
|
||||||
|
max_prefetch=max_prefetch,
|
||||||
|
)
|
||||||
model.train()
|
model.train()
|
||||||
|
|
||||||
zero_dict = model.state_dict(only_rank_0=False)
|
zero_dict = model.state_dict(only_rank_0=False)
|
||||||
|
|
|
@ -20,7 +20,8 @@ PLACEMENT_CONFIGS = [
|
||||||
|
|
||||||
@parameterize("placement_config", PLACEMENT_CONFIGS)
|
@parameterize("placement_config", PLACEMENT_CONFIGS)
|
||||||
@parameterize("keep_gathered", [True, False])
|
@parameterize("keep_gathered", [True, False])
|
||||||
def exam_zero_optim_state_dict(placement_config, keep_gathered):
|
@parameterize("max_prefetch", [0, 1, 4])
|
||||||
|
def exam_zero_optim_state_dict(placement_config, keep_gathered, max_prefetch):
|
||||||
set_seed(431)
|
set_seed(431)
|
||||||
model_builder, data_gen_fn, output_transform_fn, *_ = next(
|
model_builder, data_gen_fn, output_transform_fn, *_ = next(
|
||||||
iter(model_zoo.get_sub_registry("transformers_gpt_lm").values())
|
iter(model_zoo.get_sub_registry("transformers_gpt_lm").values())
|
||||||
|
@ -35,7 +36,7 @@ def exam_zero_optim_state_dict(placement_config, keep_gathered):
|
||||||
config_dict[world_size]["chunk_size"] = 5000
|
config_dict[world_size]["chunk_size"] = 5000
|
||||||
config_dict[world_size]["keep_gathered"] = keep_gathered
|
config_dict[world_size]["keep_gathered"] = keep_gathered
|
||||||
|
|
||||||
model = GeminiDDP(model, config_dict, **placement_config, pin_memory=True)
|
model = GeminiDDP(model, config_dict, **placement_config, pin_memory=True, max_prefetch=max_prefetch)
|
||||||
|
|
||||||
optimizer = HybridAdam(model.parameters())
|
optimizer = HybridAdam(model.parameters())
|
||||||
optim = GeminiOptimizer(optimizer, model, initial_scale=32) # initialize the link between chunk16 and chunk32
|
optim = GeminiOptimizer(optimizer, model, initial_scale=32) # initialize the link between chunk16 and chunk32
|
||||||
|
|
Loading…
Reference in New Issue