diff --git a/colossalai/booster/plugin/gemini_plugin.py b/colossalai/booster/plugin/gemini_plugin.py
index eb8db6212..ab554d21d 100644
--- a/colossalai/booster/plugin/gemini_plugin.py
+++ b/colossalai/booster/plugin/gemini_plugin.py
@@ -329,6 +329,7 @@ class GeminiPlugin(DPPluginBase):
         chunk_init_device: Optional[torch.device] = None,
         placement_policy: str = "static",
         enable_gradient_accumulation: bool = False,
+        max_prefetch: int = 0,
         shard_param_frac: float = 1.0,  # only for static placement
         offload_optim_frac: float = 0.0,  # only for static placement
         offload_param_frac: float = 0.0,  # only for static placement
@@ -387,6 +388,7 @@ class GeminiPlugin(DPPluginBase):
             memstats=memstats,
             mixed_precision=PRECISION_STR_TO_DTYPE[precision],
             master_weights=master_weights,
+            max_prefetch=max_prefetch,
             enable_async_reduce=enable_async_reduce,
         )
         self.zero_optim_config = dict(
diff --git a/colossalai/zero/gemini/chunk/chunk.py b/colossalai/zero/gemini/chunk/chunk.py
index 8f048f0b7..ed5b96519 100644
--- a/colossalai/zero/gemini/chunk/chunk.py
+++ b/colossalai/zero/gemini/chunk/chunk.py
@@ -359,14 +359,15 @@ class Chunk:
         else:
             raise NotImplementedError
 
-    def access_chunk(self):
+    def access_chunk(self, async_access: bool = False) -> Optional[dist.Work]:
         """Make the chunk usable for the parameters inside it. It's an operation done in CUDA."""
         # sanity check
         assert self.chunk_temp is None
-
+        maybe_work = None
         if not self.is_gathered:
-            self.__gather()
+            maybe_work = self.__gather(async_op=async_access)
         self.__update_tensors_ptr()
+        return maybe_work
 
     def release_chunk(self):
         """Release the usable chunk. It's an operation done in CUDA."""
@@ -512,17 +513,19 @@ class Chunk:
     def get_tensors(self) -> List[torch.Tensor]:
         return list(self.tensors_info.keys())
 
-    def __gather(self):
+    def __gather(self, async_op: bool = False) -> Optional[dist.Work]:
         if not self.is_gathered:
             # sanity check
             assert self.cuda_shard is not None
 
             alloc_storage(self.cuda_global_chunk)
             gather_list = list(torch.chunk(input=self.cuda_global_chunk, chunks=self.pg_size, dim=0))
-            dist.all_gather(gather_list, self.cuda_shard, self.torch_pg)
+            work = dist.all_gather(gather_list, self.cuda_shard, self.torch_pg, async_op=async_op)
 
             self.cuda_shard = None
             self.is_gathered = True
+            return work
+        return None
 
     def __scatter(self):
         if self.keep_gathered:
diff --git a/colossalai/zero/gemini/chunk/manager.py b/colossalai/zero/gemini/chunk/manager.py
index 6ec595914..36e7ee57b 100644
--- a/colossalai/zero/gemini/chunk/manager.py
+++ b/colossalai/zero/gemini/chunk/manager.py
@@ -111,15 +111,16 @@ class ChunkManager:
         for group_name in self.chunk_groups:
             self.__close_one_chunk(self.chunk_groups[group_name][-1])
 
-    def access_chunk(self, chunk: Chunk) -> None:
+    def access_chunk(self, chunk: Chunk, async_access: bool = False) -> Optional[dist.Work]:
         """Make the chunk can be used for calculation."""
         if chunk in self.accessed_chunks:
-            return
+            return None
         self.__sub_memory_usage(chunk.memory_usage)
         if chunk.device_type == "cpu":
             chunk.shard_move(get_accelerator().get_current_device())
-        self.__add_accessed_chunk(chunk)
+        maybe_work = self.__add_accessed_chunk(chunk, async_access=async_access)
         self.__add_memory_usage(chunk.memory_usage)
+        return maybe_work
 
     def release_chunk(self, chunk: Chunk) -> None:
         """Scatter the chunk in CUDA."""
@@ -251,10 +252,11 @@ class ChunkManager:
         for k, v in usage.items():
             self.total_mem[k] += v
 
-    def __add_accessed_chunk(self, chunk: Chunk):
-        chunk.access_chunk()
+    def __add_accessed_chunk(self, chunk: Chunk, async_access: bool = False) -> Optional[dist.Work]:
+        maybe_work = chunk.access_chunk(async_access=async_access)
         self.accessed_chunks.add(chunk)
         self.accessed_mem += chunk.chunk_mem
+        return maybe_work
 
     def __sub_accessed_chunk(self, chunk: Chunk):
         chunk.release_chunk()
diff --git a/colossalai/zero/gemini/gemini_ddp.py b/colossalai/zero/gemini/gemini_ddp.py
index 23f6ee683..050643dfa 100644
--- a/colossalai/zero/gemini/gemini_ddp.py
+++ b/colossalai/zero/gemini/gemini_ddp.py
@@ -78,6 +78,7 @@ class GeminiDDP(ModelWrapper):
         chunk_init_device: torch.device = torch.device("cpu"),
         placement_policy: str = "static",
         enable_gradient_accumulation: bool = False,
+        max_prefetch: int = 0,
         shard_param_frac: float = 1.0,  # only for static placement
         offload_optim_frac: float = 0.0,  # only for static placement
         offload_param_frac: float = 0.0,  # only for static placement
@@ -131,6 +132,7 @@ class GeminiDDP(ModelWrapper):
             offload_param_frac=offload_param_frac,
             warmup_non_model_data_ratio=warmup_non_model_data_ratio,
             steady_cuda_cap_ratio=steady_cuda_cap_ratio,
+            max_prefetch=max_prefetch,
         )
         self.force_outputs_fp32 = force_outputs_fp32
         self.param_op_hook = GeminiZeROHook(self.gemini_manager)
diff --git a/colossalai/zero/gemini/gemini_hook.py b/colossalai/zero/gemini/gemini_hook.py
index 480a14511..736238a09 100644
--- a/colossalai/zero/gemini/gemini_hook.py
+++ b/colossalai/zero/gemini/gemini_hook.py
@@ -24,16 +24,42 @@ class GeminiZeROHook(ColoParamOpHook):
         self._training_phase = TrainingPhase.FORWARD
 
     def pre_op(self, params):
+        # map params to chunks
         params = [p for p in params if not is_ddp_ignored(p)]
-        chunks = self._chunk_manager.get_chunks(params)
+        all_chunks = self._chunk_manager.get_chunks(params)
+
+        # wait for prefetched chunks, and filter out those that were not prefetched
+        chunks_fetch_sync = self._gemini_manager.wait_chunks(all_chunks)
+
+        # transfer state
         for p in params:
             self._chunk_manager.trans_tensor_state(p, TensorState.COMPUTE)
         self._gemini_manager.sample_overall_data()
-        self._gemini_manager.adjust_layout(chunks)
-        for chunk in chunks:
+
+        # evict chunks, taking asynchronously fetched chunks into account
+        self._gemini_manager.adjust_layout(
+            all_chunks, record_anyway=self._gemini_manager.placement_policy.max_prefetch > 0
+        )
+
+        # fetch the rest synchronously
+        for chunk in chunks_fetch_sync:
             self._chunk_manager.access_chunk(chunk)
 
-        # record cuda model data of the current OP
+        # get possible chunks to prefetch
+        chunks_fetch_async = self._gemini_manager.placement_policy.get_prefetch_chunks(
+            is_warmup=self._gemini_manager.is_warmup(),
+            compute_list=self._gemini_manager.compute_list,
+            compute_idx=self._gemini_manager.compute_idx,
+            async_works=self._gemini_manager.async_works,
+        )
+
+        # prefetch
+        for chunk in chunks_fetch_async:
+            maybe_work = self._chunk_manager.access_chunk(chunk, async_access=True)
+            if maybe_work is not None:
+                self._gemini_manager.add_work(chunk, maybe_work)
+
+        # record cuda model data of the current OP, including memory for prefetched chunks
         self._gemini_manager.record_model_data_volume()
 
     def post_op(self, params):
diff --git a/colossalai/zero/gemini/gemini_mgr.py b/colossalai/zero/gemini/gemini_mgr.py
index 150932e3d..83e475575 100644
--- a/colossalai/zero/gemini/gemini_mgr.py
+++ b/colossalai/zero/gemini/gemini_mgr.py
@@ -1,12 +1,13 @@
 import functools
 from time import time
-from typing import Dict, List, Optional, Tuple
+from typing import Dict, Iterable, List, Optional, Tuple
 
 import torch
+import torch.distributed as dist
 
 from .chunk import Chunk, ChunkManager
 from .memory_tracer import ChunkMemStatsCollector, MemStats
-from .placement_policy import PlacementPolicyFactory
+from .placement_policy import PlacementPolicy, PlacementPolicyFactory
 
 
 class GeminiManager:
@@ -41,9 +42,12 @@ class GeminiManager:
         self._mem_stats_collector = (
             ChunkMemStatsCollector(chunk_manager, self._memstats) if policy_cls.need_mem_stats else None
         )
-        self._placement_policy = policy_cls(chunk_manager, self._mem_stats_collector, **placement_kwargs)
+        self._placement_policy = policy_cls(
+            chunk_manager=chunk_manager, mem_stats_collector=self._mem_stats_collector, **placement_kwargs
+        )
         self._compute_list: List[Tuple[Chunk, ...]] = []
         self._compute_idx: int = -1
+        self._async_works: Dict[Chunk, dist.Work] = {}
 
         self._h2d_volume = 0
         self._d2h_volume = 0
@@ -91,18 +95,20 @@ class GeminiManager:
             self._warmup = False
             self.reset_attributes()
 
-    def adjust_layout(self, chunks: Tuple[Chunk, ...]) -> None:
+    def adjust_layout(self, chunks: Tuple[Chunk, ...], record_anyway: bool = False) -> None:
         """Adjust the layout of stateful tensors according to the information provided
         by mem_stats_collector, which should belongs to a Sharded Model.
         """
         # find stateful tensor in state COMPUTE
         start = time()
-        self._record_chunks_order(chunks)
-        cuda_demand, hold_cuda_tensor_list = self._get_layout_info(self._compute_idx, self._warmup, chunks)
+        self._record_warmup_chunks_order(chunks, record_anyway=record_anyway)
+        cuda_demand, can_evict_chunks = self._get_layout_info(self._compute_idx, self._warmup, chunks)
+        # don't evict chunks that are asynchronously fetched
+        can_evict_chunks = [chunk for chunk in can_evict_chunks if chunk not in self._async_works]
         self._layout_time += time() - start
 
         vol, evict_time = self._placement_policy.evict_tensors(
-            can_evict_chunks=hold_cuda_tensor_list,
+            can_evict_chunks=can_evict_chunks,
             cuda_demand=cuda_demand,
             warmup=self._warmup,
             compute_list=self._compute_list,
@@ -114,6 +120,21 @@ class GeminiManager:
         # move COMPUTE tensors to CUDA
         self._h2d_volume += cuda_demand
 
+    def wait_chunks(self, chunks: Iterable[Chunk]) -> Tuple[Chunk]:
+        non_prefetched_chunks = []
+        for chunk in chunks:
+            if chunk in self._async_works:
+                self._async_works[chunk].wait()
+                del self._async_works[chunk]
+            else:
+                non_prefetched_chunks.append(chunk)
+        return tuple(non_prefetched_chunks)
+
+    def add_work(self, chunk: Chunk, work: dist.Work):
+        assert work is not None
+        assert chunk not in self._async_works
+        self._async_works[chunk] = work
+
     @functools.lru_cache(maxsize=None)
     def _get_layout_info(self, compute_idx: int, warmup: bool, chunks: Tuple[Chunk, ...]):
         start = time()
@@ -133,9 +154,9 @@ class GeminiManager:
             can_evict_chunks = self._chunk_manager.get_cuda_movable_chunks()
         return cuda_demand, can_evict_chunks
 
-    def _record_chunks_order(self, chunks: Tuple[Chunk, ...]) -> None:
+    def _record_warmup_chunks_order(self, chunks: Tuple[Chunk, ...], record_anyway: bool = False) -> None:
         self._compute_idx += 1
-        if self._warmup and self._placement_policy.need_mem_stats:
+        if self._warmup and (self._placement_policy.need_mem_stats or record_anyway):
             self._compute_list.append(chunks)
 
     def sample_overall_data(self):
@@ -156,6 +177,22 @@ class GeminiManager:
             return self._mem_stats_collector.cuda_margin_mem
         return None
 
+    @property
+    def placement_policy(self) -> PlacementPolicy:
+        return self._placement_policy
+
+    @property
+    def compute_list(self) -> List[Tuple[Chunk, ...]]:
+        return self._compute_list
+
+    @property
+    def compute_idx(self) -> int:
+        return self._compute_idx
+
+    @property
+    def async_works(self) -> Dict[Chunk, dist.Work]:
+        return self._async_works
+
     @property
     def is_cuda_margin_mem_avail(self) -> bool:
         return self._placement_policy.need_mem_stats
diff --git a/colossalai/zero/gemini/placement_policy.py b/colossalai/zero/gemini/placement_policy.py
index 388999549..178755d03 100644
--- a/colossalai/zero/gemini/placement_policy.py
+++ b/colossalai/zero/gemini/placement_policy.py
@@ -5,6 +5,7 @@ from time import time
 from typing import Dict, List, Optional, Tuple, Type
 
 import torch
+import torch.distributed as dist
 
 from colossalai.accelerator import get_accelerator
 from colossalai.legacy.utils.memory import colo_device_memory_capacity
@@ -18,10 +19,15 @@ class PlacementPolicy(ABC):
     need_mem_stats: bool = False
 
     def __init__(
-        self, chunk_manager: ChunkManager, mem_stats_collector: Optional[ChunkMemStatsCollector] = None, **kwargs
+        self,
+        chunk_manager: ChunkManager,
+        mem_stats_collector: Optional[ChunkMemStatsCollector] = None,
+        max_prefetch: int = 0,
+        **kwargs,
     ) -> None:
         self.chunk_manager = chunk_manager
         self.mem_stats_collector: Optional[ChunkMemStatsCollector] = mem_stats_collector
+        self.max_prefetch = max_prefetch
 
     @abstractmethod
     def evict_tensors(self, can_evict_chunks: List[Chunk], **kwargs) -> Tuple[int, float]:
@@ -33,18 +39,24 @@ class PlacementPolicy(ABC):
     ) -> None:
         raise NotImplementedError
 
+    def get_prefetch_chunks(
+        self, is_warmup, compute_list: tuple, compute_idx: int, async_works: Dict[Chunk, dist.Work]
+    ) -> List[Chunk]:
+        return []  # no prefetch by default
+
 
 class StaticPlacementPolicy(PlacementPolicy):
     def __init__(
         self,
         chunk_manager: ChunkManager,
         mem_stats_collector: Optional[ChunkMemStatsCollector] = None,
+        max_prefetch: int = 0,
         shard_param_frac: float = 1.0,
         offload_optim_frac: float = 0.0,
         offload_param_frac: float = 0.0,
         **kwargs,
     ) -> None:
-        super().__init__(chunk_manager, mem_stats_collector=mem_stats_collector)
+        super().__init__(chunk_manager, mem_stats_collector=mem_stats_collector, max_prefetch=max_prefetch)
         if offload_param_frac > 0.0 and (shard_param_frac != 1.0 or offload_optim_frac != 1.0):
             warnings.warn("offload_param_frac is ignored when shard_param_frac != 1.0 or offload_optim_frac != 1.0")
             offload_param_frac = 0.0
@@ -95,6 +107,24 @@ class StaticPlacementPolicy(PlacementPolicy):
         self.keep_gathered_chunk_mem = total_chunk_mem * (1 - self.shard_param_frac)
         self.keep_cuda_chunk_mem = total_chunk_mem * (1 - self.offload_param_frac)
 
+    def get_prefetch_chunks(
+        self, is_warmup: bool, compute_list: tuple, compute_idx: int, async_works: Dict[Chunk, dist.Work]
+    ) -> List[Chunk]:
+        if is_warmup:  # no prefetch during warmup since we need compute_list
+            return []
+        can_prefetch = self.max_prefetch - len(async_works)
+        prefetch = []
+        for i in range(compute_idx + 1, len(compute_list)):
+            for chunk in compute_list[i]:
+                if len(prefetch) >= can_prefetch:
+                    break
+                if chunk not in prefetch and chunk not in self.chunk_manager.accessed_chunks:
+                    prefetch.append(chunk)
+            else:
+                continue
+            break
+        return prefetch
+
 
 class AutoPlacementPolicy(PlacementPolicy):
     need_mem_stats: bool = True
@@ -103,17 +133,20 @@ class AutoPlacementPolicy(PlacementPolicy):
         self,
         chunk_manager: ChunkManager,
         mem_stats_collector: Optional[ChunkMemStatsCollector] = None,
+        max_prefetch: int = 0,
         warmup_non_model_data_ratio: float = 0.8,
         steady_cuda_cap_ratio: float = 0.9,
         **kwargs,
     ) -> None:
-        super().__init__(chunk_manager, mem_stats_collector=mem_stats_collector)
+        super().__init__(chunk_manager, mem_stats_collector=mem_stats_collector, max_prefetch=max_prefetch)
        # model data will use 1-_warmup_non_model_data_ratio CUDA memory in warmup phase
        # you can set them by AutoPlacementPolicy.set_warmup_non_model_data_ratio()
        # and AutoPlacementPolicy.set_steady_cuda_cap_ratio()
         self._warmup_non_model_data_ratio = warmup_non_model_data_ratio
         self._steady_cuda_cap_ratio = steady_cuda_cap_ratio
 
+        self.__avail_cuda_model_data_for_prefetch = None
+
     def evict_tensors(
         self,
         can_evict_chunks: List[Chunk],
@@ -173,6 +206,7 @@ class AutoPlacementPolicy(PlacementPolicy):
                 f"Adjust layout failed! No enough CUDA memory! "
                 f"Need {to_free_cuda_model_data}, freed {freed_cuda_model_data}"
             )
+        self.__avail_cuda_model_data_for_prefetch = avail_cuda_model_data + freed_cuda_model_data
         return freed_cuda_model_data, time() - start
 
     @staticmethod
@@ -198,6 +232,30 @@ class AutoPlacementPolicy(PlacementPolicy):
             else:
                 grads_device_map[p] = torch.device("cpu")
 
+    def get_prefetch_chunks(
+        self, is_warmup: bool, compute_list: tuple, compute_idx: int, async_works: Dict[Chunk, dist.Work]
+    ) -> List[Chunk]:
+        if is_warmup:  # no prefetch during warmup since we need compute_list
+            return []
+
+        avail_cuda_model_data = self.__avail_cuda_model_data_for_prefetch
+        self.__avail_cuda_model_data_for_prefetch = None  # reset to guard against double use
+
+        prefetch_chunk_memory = 0
+        can_prefetch = self.max_prefetch - len(async_works)
+        prefetch = []
+        for i in range(compute_idx + 1, len(compute_list)):
+            for chunk in compute_list[i]:
+                if len(prefetch) >= can_prefetch or prefetch_chunk_memory + chunk.chunk_mem > avail_cuda_model_data:
+                    break
+                if chunk not in prefetch and chunk not in self.chunk_manager.accessed_chunks:
+                    prefetch_chunk_memory += chunk.chunk_mem
+                    prefetch.append(chunk)
+            else:
+                continue
+            break
+        return prefetch
+
 
 class PlacementPolicyFactory:
     policies: Dict[str, Type[PlacementPolicy]] = {
diff --git a/examples/language/gpt/gemini/commons/performance_evaluator.py b/examples/language/gpt/gemini/commons/performance_evaluator.py
new file mode 120000
index 000000000..152602774
--- /dev/null
+++ b/examples/language/gpt/gemini/commons/performance_evaluator.py
@@ -0,0 +1 @@
+../../../performance_evaluator.py
\ No newline at end of file
diff --git a/examples/language/gpt/gemini/commons/utils.py b/examples/language/gpt/gemini/commons/utils.py
index 7ed5fdb92..ba80cc4a6 100644
--- a/examples/language/gpt/gemini/commons/utils.py
+++ b/examples/language/gpt/gemini/commons/utils.py
@@ -1,8 +1,6 @@
 import time
-from contextlib import nullcontext
 
 import torch
-from torch.profiler import ProfilerActivity, profile, schedule, tensorboard_trace_handler
 
 
 class DummyProfiler:
@@ -24,19 +22,6 @@ def get_tflops(model_numel, batch_size, seq_len, step_time):
     return model_numel * batch_size * seq_len * 8 / 1e12 / (step_time + 1e-12)
 
 
-def get_profile_context(enable_flag, warmup_steps, active_steps, save_dir):
-    if enable_flag:
-        return profile(
-            activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
-            schedule=schedule(wait=0, warmup=warmup_steps, active=active_steps),
-            on_trace_ready=tensorboard_trace_handler(save_dir),
-            record_shapes=True,
-            profile_memory=True,
-        )
-    else:
-        return nullcontext(DummyProfiler())
-
-
 def get_time_stamp():
     cur_time = time.strftime("%d-%H:%M", time.localtime())
     return cur_time
diff --git a/examples/language/gpt/gemini/train_gpt_demo.py b/examples/language/gpt/gemini/train_gpt_demo.py
index 4911ff124..cb5d2c32c 100644
--- a/examples/language/gpt/gemini/train_gpt_demo.py
+++ b/examples/language/gpt/gemini/train_gpt_demo.py
@@ -8,7 +8,8 @@ import psutil
 import torch
 import torch.nn as nn
 from commons.model_zoo import model_builder
-from commons.utils import get_data, get_profile_context, get_tflops, get_time_stamp
+from commons.performance_evaluator import get_profile_context
+from commons.utils import get_data, get_tflops, get_time_stamp
 from packaging import version
 
 import colossalai
diff --git a/examples/language/llama/benchmark.py b/examples/language/llama/benchmark.py
index 6f91ff7b7..8d4dae314 100644
--- a/examples/language/llama/benchmark.py
+++ b/examples/language/llama/benchmark.py
@@ -1,11 +1,12 @@
 import argparse
 import resource
+import time
 from contextlib import nullcontext
 
 import torch
 from data_utils import RandomDataset
 from model_utils import format_numel_str, get_model_numel
-from performance_evaluator import PerformanceEvaluator
+from performance_evaluator import PerformanceEvaluator, get_profile_context
 from torch.distributed.fsdp.fully_sharded_data_parallel import CPUOffload, MixedPrecision
 from tqdm import tqdm
 from transformers import AutoConfig, AutoModelForCausalLM
@@ -76,8 +77,11 @@ def main():
     parser.add_argument("--mbs", type=int, default=1, help="Micro batch size of pipeline parallel")
     parser.add_argument("--zero", type=int, default=0, help="Zero Stage when hybrid plugin is enabled")
     parser.add_argument("--custom-ckpt", action="store_true", help="Customize checkpoint", default=False)
-    parser.add_argument("--disable-async-reduce", action="store_true", help="Customize checkpoint", default=False)
-
+    parser.add_argument("--profile", action="store_true", help="Enable profiling", default=False)
+    parser.add_argument(
+        "--disable-async-reduce", action="store_true", help="Disable the asynchronous reduce operation", default=False
+    )
+    parser.add_argument("--prefetch_num", type=int, default=0, help="Maximum number of chunks to prefetch")
     args = parser.parse_args()
 
     colossalai.launch_from_torch()
@@ -112,6 +116,7 @@ def main():
             extra_dp_size=args.extra_dp,
             enable_fused_normalization=torch.cuda.is_available(),
             enable_flash_attention=args.xformers,
+            max_prefetch=args.prefetch_num,
             enable_async_reduce=not args.disable_async_reduce,
         )
     elif args.plugin == "gemini_auto":
@@ -122,6 +127,8 @@ def main():
             tp_size=args.tp,
             extra_dp_size=args.extra_dp,
             enable_fused_normalization=torch.cuda.is_available(),
+            max_prefetch=args.prefetch_num,
+            enable_async_reduce=not args.disable_async_reduce,
             enable_flash_attention=args.xformers,
         )
     elif args.plugin == "fsdp":
@@ -249,25 +256,37 @@ def main():
         f"Booster init max CPU memory: {resource.getrusage(resource.RUSAGE_SELF).ru_maxrss/1024:.2f} MB"
     )
 
-    if isinstance(plugin, HybridParallelPlugin) and args.pp > 1:
-        data_iter = iter(dataloader)
-        for step in tqdm(range(len(dataloader)), desc="Step", disable=not coordinator.is_master()):
-            performance_evaluator.on_step_start(step)
-            booster.execute_pipeline(
-                data_iter, model, criterion=lambda outputs, inputs: outputs[0], optimizer=optimizer, return_loss=False
-            )
-            optimizer.step()
-            optimizer.zero_grad()
-            performance_evaluator.on_step_end(input_ids=torch.empty(args.batch_size, args.max_length))
-    else:
-        for step, batch in enumerate(tqdm(dataloader, desc="Step", disable=not coordinator.is_master())):
-            performance_evaluator.on_step_start(step)
-            outputs = model(**batch)
-            loss = outputs[0]
-            booster.backward(loss, optimizer)
-            optimizer.step()
-            optimizer.zero_grad()
-            performance_evaluator.on_step_end(**batch)
+    with get_profile_context(
+        args.profile,
+        1,
+        len(dataloader) - 1,
+        save_dir=f"profile/{time.strftime('%H:%M', time.localtime())}-{args.plugin}-llama-{args.config}",
+    ) as prof:
+        if isinstance(plugin, HybridParallelPlugin) and args.pp > 1:
+            data_iter = iter(dataloader)
+            for step in tqdm(range(len(dataloader)), desc="Step", disable=not coordinator.is_master()):
+                performance_evaluator.on_step_start(step)
+                booster.execute_pipeline(
+                    data_iter,
+                    model,
+                    criterion=lambda outputs, inputs: outputs[0],
+                    optimizer=optimizer,
+                    return_loss=False,
+                )
+                optimizer.step()
+                optimizer.zero_grad()
+                performance_evaluator.on_step_end(input_ids=torch.empty(args.batch_size, args.max_length))
+                prof.step()
+        else:
+            for step, batch in enumerate(tqdm(dataloader, desc="Step", disable=not coordinator.is_master())):
+                performance_evaluator.on_step_start(step)
+                outputs = model(**batch)
+                loss = outputs[0]
+                booster.backward(loss, optimizer)
+                optimizer.step()
+                optimizer.zero_grad()
+                performance_evaluator.on_step_end(**batch)
+                prof.step()
 
     performance_evaluator.on_fit_end()
     coordinator.print_on_master(f"Max CUDA memory usage: {get_accelerator().max_memory_allocated()/1024**2:.2f} MB")
diff --git a/examples/language/performance_evaluator.py b/examples/language/performance_evaluator.py
index c2169a730..6b8daf37d 100644
--- a/examples/language/performance_evaluator.py
+++ b/examples/language/performance_evaluator.py
@@ -4,6 +4,7 @@ from typing import Optional
 import torch
 import torch.distributed as dist
 from torch import Tensor
+from torch.profiler import ProfilerActivity, profile, schedule, tensorboard_trace_handler
 
 from colossalai.accelerator import get_accelerator
 from colossalai.cluster import DistCoordinator
@@ -27,6 +28,33 @@ def all_reduce_mean(x: float, world_size: int) -> float:
     return tensor.item()
 
 
+def get_profile_context(enable_flag, warmup_steps, active_steps, save_dir):
+    class DummyProfiler:
+        def __init__(self):
+            self.step_number = 0
+
+        def step(self):
+            self.step_number += 1
+
+        def __enter__(self):
+            return self
+
+        def __exit__(self, exc_type, exc_value, traceback):
+            pass
+
+    if enable_flag:
+        return profile(
+            activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
+            schedule=schedule(wait=0, warmup=warmup_steps, active=active_steps),
+            on_trace_ready=tensorboard_trace_handler(save_dir),
+            record_shapes=True,
+            profile_memory=True,
+            with_stack=True,
+        )
+    else:
+        return DummyProfiler()
+
+
 class Timer:
     def __init__(self) -> None:
         self.start_time: Optional[float] = None
diff --git a/tests/test_zero/test_gemini/test_fwd_bwd.py b/tests/test_zero/test_gemini/test_fwd_bwd.py
index 4279793d7..4d3981329 100644
--- a/tests/test_zero/test_gemini/test_fwd_bwd.py
+++ b/tests/test_zero/test_gemini/test_fwd_bwd.py
@@ -40,6 +40,7 @@ def check_grad(model: GeminiDDP, torch_model: torch.nn.Module):
 @parameterize("model_name", ["transformers_gpt_lm"])
 @parameterize("use_grad_checkpoint", [False, True])
 @parameterize("master_weights", [False, True])
+@parameterize("max_prefetch", [0, 4])
 @parameterize("enable_async_reduce", [False, True])
 def exam_gpt_fwd_bwd(
     placement_config,
@@ -47,6 +48,7 @@ def exam_gpt_fwd_bwd(
     model_name: str,
     use_grad_checkpoint: bool = False,
     master_weights: bool = True,
+    max_prefetch: int = 0,
     enable_async_reduce=True,
 ):
     init_device = get_accelerator().get_current_device()
@@ -77,6 +79,7 @@ def exam_gpt_fwd_bwd(
         pin_memory=True,
         **placement_config,
         master_weights=master_weights,
+        max_prefetch=max_prefetch,
         enable_async_reduce=enable_async_reduce,
     )
     optimizer = HybridAdam(model.parameters(), lr=1e-3)
diff --git a/tests/test_zero/test_gemini/test_grad_accum.py b/tests/test_zero/test_gemini/test_grad_accum.py
index 6e6c27e3f..002741389 100644
--- a/tests/test_zero/test_gemini/test_grad_accum.py
+++ b/tests/test_zero/test_gemini/test_grad_accum.py
@@ -50,6 +50,7 @@ def check_grad(model: GeminiDDP, torch_model: torch.nn.Module):
 @parameterize("model_name", ["transformers_gpt_lm"])
 @parameterize("master_weights", [False, True])
 @parameterize("use_grad_checkpoint", [False, True])
+@parameterize("max_prefetch", [0, 4])
 @parameterize("enable_async_reduce", [False, True])
 def exam_gemini_grad_acc(
     placement_config,
@@ -57,6 +58,7 @@ def exam_gemini_grad_acc(
     model_name: str,
     master_weights: bool,
     use_grad_checkpoint: bool,
+    max_prefetch: int,
     enable_async_reduce: bool,
 ):
     init_device = get_accelerator().get_current_device()
@@ -87,6 +89,7 @@ def exam_gemini_grad_acc(
         pin_memory=True,
         enable_gradient_accumulation=True,
         master_weights=master_weights,
+        max_prefetch=max_prefetch,
         enable_async_reduce=enable_async_reduce,
         **placement_config,
     )
diff --git a/tests/test_zero/test_gemini/test_grad_clip.py b/tests/test_zero/test_gemini/test_grad_clip.py
index 7a1609ca5..41a66e91e 100644
--- a/tests/test_zero/test_gemini/test_grad_clip.py
+++ b/tests/test_zero/test_gemini/test_grad_clip.py
@@ -52,8 +52,11 @@ def check_param(model: GeminiDDP, torch_model: torch.nn.Module):
 @parameterize("placement_config", PLACEMENT_CONFIGS)
 @parameterize("model_name", ["transformers_gpt_lm"])
 @parameterize("master_weights", [True, False])
+@parameterize("max_prefetch", [0, 1, 4])
 @parameterize("enable_async_reduce", [False, True])
-def exam_grad_clipping(placement_config, model_name: str, master_weights: bool, enable_async_reduce: bool):
+def exam_grad_clipping(
+    placement_config, model_name: str, master_weights: bool, max_prefetch: int, enable_async_reduce: bool
+):
     set_seed(1912)
     model_builder, data_gen_fn, output_transform_fn, loss_fn, *_ = next(
         iter(model_zoo.get_sub_registry(model_name).values())
     )
@@ -85,6 +88,7 @@ def exam_grad_clipping(placement_config, model_name: str, master_weights: bool,
         chunk_init_device=init_device,
         pin_memory=True,
         master_weights=master_weights,
+        max_prefetch=max_prefetch,
         enable_async_reduce=enable_async_reduce,
         **placement_config,
     )
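For reference, a minimal usage sketch of the new `max_prefetch` option through the booster API. This is not part of the patch; the model and optimizer below are placeholders, and `max_prefetch=2` is an arbitrary example value.

import torch

import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import GeminiPlugin
from colossalai.nn.optimizer import HybridAdam

colossalai.launch_from_torch()

# placeholder model; any torch.nn.Module works here
model = torch.nn.Linear(1024, 1024)
optimizer = HybridAdam(model.parameters(), lr=1e-3)

# max_prefetch > 0 lets Gemini all-gather up to that many upcoming chunks
# asynchronously while the current op computes; 0 keeps the previous behavior
plugin = GeminiPlugin(
    placement_policy="static",
    max_prefetch=2,
    enable_async_reduce=True,
)
booster = Booster(plugin=plugin)
model, optimizer, *_ = booster.boost(model, optimizer)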