From 9214d1fe28f5ef18e44304ebd0542a4a66b90844 Mon Sep 17 00:00:00 2001 From: Jiarui Fang Date: Mon, 12 Dec 2022 18:06:16 +0800 Subject: [PATCH] [Gemini] chunk init using runtime visited param order (#2115) --- colossalai/gemini/chunk/search_utils.py | 17 +++++++++----- colossalai/gemini/chunk/utils.py | 5 ++-- colossalai/gemini/gemini_mgr.py | 19 ++++++++++++++- .../gemini/memory_tracer/memory_stats.py | 6 +++++ .../memory_tracer/memstats_collector.py | 8 +++++-- .../memory_tracer/param_runtime_order.py | 3 +++ colossalai/nn/parallel/data_parallel.py | 17 +++++++++++--- colossalai/nn/parallel/gemini_parallel.py | 7 ++++-- .../test_gemini/update/test_gemini_use_rmt.py | 23 +++++++++---------- tests/test_gemini/update/test_optim.py | 1 - 10 files changed, 77 insertions(+), 29 deletions(-) diff --git a/colossalai/gemini/chunk/search_utils.py b/colossalai/gemini/chunk/search_utils.py index b92a8b158..312d77f18 100644 --- a/colossalai/gemini/chunk/search_utils.py +++ b/colossalai/gemini/chunk/search_utils.py @@ -1,10 +1,10 @@ import math -from typing import Dict, List, Tuple +from typing import Dict, List, Optional, Tuple import numpy as np import torch.nn as nn -from colossalai.gemini.memory_tracer import OrderedParamGenerator +from colossalai.gemini.memory_tracer import MemStats, OrderedParamGenerator from colossalai.tensor import ColoParameter @@ -73,7 +73,8 @@ def search_chunk_configuration( search_range_mb: float, search_interval_byte: int, # hidden size is the best value for the interval min_chunk_size_mb: float = 32, - filter_exlarge_params: bool = True) -> Tuple[Dict, int]: + filter_exlarge_params: bool = True, + memstas: Optional[MemStats] = None) -> Tuple[Dict, int]: """search_chunk_configuration Args: @@ -86,9 +87,13 @@ def search_chunk_configuration( Tuple[Dict, int]: chunk config (a dict of dp_degree -> chunk init args) and its memory chunk waste in byte. """ - param_order = OrderedParamGenerator() - for p in model.parameters(): - param_order.append(p) + if memstas is not None: + param_order = memstas.param_order() + else: + # build the param visited order right now + param_order = OrderedParamGenerator() + for p in model.parameters(): + param_order.append(p) search_range_byte = round(search_range_mb * 1024**2) min_chunk_size_byte = round(min_chunk_size_mb * 1024**2) diff --git a/colossalai/gemini/chunk/utils.py b/colossalai/gemini/chunk/utils.py index 9d87129db..e9a9f84e7 100644 --- a/colossalai/gemini/chunk/utils.py +++ b/colossalai/gemini/chunk/utils.py @@ -7,6 +7,7 @@ import torch.nn as nn from colossalai.gemini.chunk import ChunkManager from colossalai.gemini.chunk.search_utils import in_ddp, search_chunk_configuration +from colossalai.gemini.memory_tracer import MemStats def init_chunk_manager(model: nn.Module, @@ -37,13 +38,13 @@ def init_chunk_manager(model: nn.Module, total_size = sum(params_sizes) / 1024**2 dist.barrier() - begine = time() + begin = time() config_dict, wasted_size = search_chunk_configuration(model, **kwargs_dict) dist.barrier() end = time() - span_s = end - begine + span_s = end - begin wasted_size /= 1024**2 if dist.get_rank() == 0: diff --git a/colossalai/gemini/gemini_mgr.py b/colossalai/gemini/gemini_mgr.py index c3a813367..ca3165a71 100644 --- a/colossalai/gemini/gemini_mgr.py +++ b/colossalai/gemini/gemini_mgr.py @@ -25,6 +25,7 @@ class GeminiManager: If it's 'auto', they are moving dynamically based on CPU and CUDA memory usage. It will utilize heterogeneous memory space evenly and well. 
             Note that 'auto' policy can only work well when no other processes use CUDA during your training.
         chunk_manager (ChunkManager): A ``ChunkManager`` instance.
+        memstats (MemStats, optional): memory statistics collected by a runtime memory tracer. If None, the GeminiManager will collect the stats itself during a warmup iteration.
     """
 
     def __init__(self, placement_policy: str, chunk_manager: ChunkManager, memstats: Optional[MemStats] = None) -> None:
@@ -33,8 +34,11 @@ class GeminiManager:
         self.policy_name = placement_policy
         policy_cls = PlacementPolicyFactory.create(placement_policy)
         self._chunk_manager = chunk_manager
+
+        self._premade_memstats_ = memstats is not None
+        self._memstats = memstats
         self._mem_stats_collector = ChunkMemStatsCollector(chunk_manager,
-                                                           memstats) if policy_cls.need_mem_stats else None
+                                                           self._memstats) if policy_cls.need_mem_stats else None
         self._placement_policy = policy_cls(chunk_manager, self._mem_stats_collector)
         self._compute_list: List[Tuple[Chunk, ...]] = []
         self._compute_idx: int = -1
@@ -46,6 +50,19 @@ class GeminiManager:
         self._warmup = True
         self._comp_cuda_demand_time = 0
 
+    def memstats(self):
+        """memstats
+
+        Get the memory statistics collected during training.
+        The stats are either pre-collected by a runtime memory tracer or collected by the GeminiManager itself.
+        In the latter case, the memstats are not accessible until the warmup iteration finishes.
+        """
+        if self._premade_memstats_:
+            return self._memstats
+        else:
+            assert not self._warmup, "GeminiManager collects memstats during the warmup iteration; they are not available until it finishes."
+            return self._mem_stats_collector._memstats
+
     def pre_iter(self, *args):
         if self._mem_stats_collector and self._warmup:
             self._mem_stats_collector.start_collection()
diff --git a/colossalai/gemini/memory_tracer/memory_stats.py b/colossalai/gemini/memory_tracer/memory_stats.py
index a374ab408..a66829863 100644
--- a/colossalai/gemini/memory_tracer/memory_stats.py
+++ b/colossalai/gemini/memory_tracer/memory_stats.py
@@ -23,6 +23,12 @@ class MemStats(object):
 
         self._param_runtime_order = OrderedParamGenerator()
 
+    def param_order(self):
+        if self._param_runtime_order.is_empty():
+            raise RuntimeError("the parameter visited order has not been recorded by a runtime memory tracer yet")
+        else:
+            return self._param_runtime_order
+
     def append_overall_data(self, device_type: str, val: float):
         if device_type == 'cuda':
             self._overall_cuda_list.append(val)
diff --git a/colossalai/gemini/memory_tracer/memstats_collector.py b/colossalai/gemini/memory_tracer/memstats_collector.py
index 7d034dd8f..a81961227 100644
--- a/colossalai/gemini/memory_tracer/memstats_collector.py
+++ b/colossalai/gemini/memory_tracer/memstats_collector.py
@@ -37,7 +37,7 @@ class MemStatsCollector:
         self._memstats = MemStats()
 
     def next_period_non_model_data_usage(self, device_type: str) -> int:
-        """Get max non model data memory usage of current sampling period
+        """Get the maximum non-model data memory usage expected during the next operator run
 
         Args:
             device_type (str): device type, can be 'cpu' or 'cuda'.
@@ -47,6 +47,9 @@ class MemStatsCollector:
         """
         assert not self._start_flag, 'Cannot get mem stats info during collection phase.'
         assert self._step_total > 0, 'Cannot get mem stats info before collection phase.'
+        assert len(self._memstats.non_model_data_list(device_type)) > self._step_idx, \
+            f"only {len(self._memstats.non_model_data_list(device_type))} non-model data records were collected, which must be greater than "\
+            f"the current step index {self._step_idx} (step total {self._step_total})"
         next_non_model_data = self._memstats.non_model_data_list(device_type)[self._step_idx]
         self._step_idx = (self._step_idx + 1) % self._step_total
         return next_non_model_data
@@ -61,7 +64,8 @@ class MemStatsCollector:
 
     def finish_collection(self):
         self.sample_overall_data()
-        self._step_total = len(self._sampling_time)
+        # use the number of recorded non-model data samples, not the sampling timestamps, as the step total
+        self._step_total = len(self._memstats.non_model_data_list('cuda'))
         self._start_flag = False
         self._mem_monitor.finish()
 
diff --git a/colossalai/gemini/memory_tracer/param_runtime_order.py b/colossalai/gemini/memory_tracer/param_runtime_order.py
index b65251373..dc9226a53 100644
--- a/colossalai/gemini/memory_tracer/param_runtime_order.py
+++ b/colossalai/gemini/memory_tracer/param_runtime_order.py
@@ -35,5 +35,8 @@ class OrderedParamGenerator(ParamGenerator):
             visited_set.add(p)
         del visited_set
 
+    def is_empty(self):
+        return len(self.param_visited_order) == 0
+
     def clear(self):
         self.param_visited_order = []
diff --git a/colossalai/nn/parallel/data_parallel.py b/colossalai/nn/parallel/data_parallel.py
index 14d85489a..54f6eb9b7 100644
--- a/colossalai/nn/parallel/data_parallel.py
+++ b/colossalai/nn/parallel/data_parallel.py
@@ -8,6 +8,7 @@ import torch.distributed as dist
 
 from colossalai.gemini.chunk import Chunk, ChunkManager, TensorState
 from colossalai.gemini.gemini_mgr import GeminiManager
+from colossalai.gemini.memory_tracer import OrderedParamGenerator
 from colossalai.logging import get_dist_logger
 from colossalai.nn.parallel.utils import get_temp_total_chunk_on_cuda
 from colossalai.tensor import ProcessGroup as ColoProcessGroup
@@ -216,8 +217,18 @@ class ZeroDDP(ColoDDP):
 
         self.grads_device: Dict[torch.Tensor, torch.device] = {}
         cpu_offload = self.gemini_manager.policy_name != 'cuda'
-        # TODO: get param order and filter unused params
-        for p in module.parameters():
+
+        if self.gemini_manager._premade_memstats_:
+            # build chunks following the parameter order recorded at runtime
+            param_order = self.gemini_manager.memstats()._param_runtime_order
+        else:
+            # build chunks following the parameter initialization order
+            # Note: in this way, unused parameters cannot be filtered out at runtime.
+            param_order = OrderedParamGenerator()
+            for p in module.parameters():
+                param_order.append(p)
+
+        for p in param_order.generate():
             assert isinstance(p, ColoParameter)
 
             if getattr(p, '_ddp_to_ignore', False):
@@ -243,7 +254,7 @@ class ZeroDDP(ColoDDP):
         self.chunk_manager.close_all_groups()
         self._cast_buffers()
 
-        params_list = [p for p in module.parameters() if not getattr(p, '_ddp_to_ignore', False)]
+        params_list = [p for p in param_order.generate() if not getattr(p, '_ddp_to_ignore', False)]
         for p, fp32_p in zip(params_list, self.fp32_params):
             chunk_16 = self.chunk_manager.get_chunk(p)
             chunk_32 = self.chunk_manager.get_chunk(fp32_p)
diff --git a/colossalai/nn/parallel/gemini_parallel.py b/colossalai/nn/parallel/gemini_parallel.py
index bf11631f9..cd5ef424a 100644
--- a/colossalai/nn/parallel/gemini_parallel.py
+++ b/colossalai/nn/parallel/gemini_parallel.py
@@ -4,6 +4,7 @@ import torch
 
 from colossalai.gemini.chunk import init_chunk_manager
 from colossalai.gemini.gemini_mgr import GeminiManager
+from colossalai.gemini.memory_tracer import MemStats
 
 from .data_parallel import ZeroDDP
 
@@ -18,7 +19,8 @@ class GeminiDDP(ZeroDDP):
                  force_outputs_fp32: bool = False,
                  search_range_mb: int = 32,
                  hidden_dim: Optional[int] = None,
-                 min_chunk_size_mb: Optional[float] = None) -> None:
+                 min_chunk_size_mb: Optional[float] = None,
+                 memstats: Optional[MemStats] = None) -> None:
         """
         A torch.Module warpper using ZeRO-DP and Genimi.
         ZeRO is for parallel. Gemini is for memory management.
@@ -44,11 +46,12 @@ class GeminiDDP(ZeroDDP):
             min_chunk_size_mb (float, optional): the minimum chunk size in MegaByte.
                 If the aggregate size of parameters is still samller than the minimum chunk size,
                 all parameters will be compacted into one small chunk.
+            memstats (MemStats, optional): memory statistics collected by a runtime memory tracer.
""" chunk_manager = init_chunk_manager(model=module, init_device=device, hidden_dim=hidden_dim, search_range_mb=search_range_mb, min_chunk_size_mb=min_chunk_size_mb) - gemini_manager = GeminiManager(placement_policy, chunk_manager) + gemini_manager = GeminiManager(placement_policy, chunk_manager, memstats) super().__init__(module, gemini_manager, pin_memory, force_outputs_fp32) diff --git a/tests/test_gemini/update/test_gemini_use_rmt.py b/tests/test_gemini/update/test_gemini_use_rmt.py index 564dee005..5a8f066ac 100644 --- a/tests/test_gemini/update/test_gemini_use_rmt.py +++ b/tests/test_gemini/update/test_gemini_use_rmt.py @@ -8,7 +8,8 @@ import colossalai from colossalai.gemini.chunk import ChunkManager, search_chunk_configuration from colossalai.gemini.gemini_mgr import GeminiManager from colossalai.gemini.memory_tracer.runtime_mem_tracer import RuntimeMemTracer -from colossalai.nn.parallel import ZeroDDP +from colossalai.nn.optimizer.gemini_optimizer import GeminiAdamOptimizer +from colossalai.nn.parallel import GeminiDDP, ZeroDDP from colossalai.tensor import ProcessGroup from colossalai.testing import parameterize, rerun_if_address_is_in_use from colossalai.utils import free_port @@ -44,29 +45,27 @@ def run_gemini_use_rmt(placement_policy, keep_gather, model_name: str, use_grad_ run_fwd_bwd(runtime_mem_tracer, input_ids, label, criterion, runtime_mem_tracer) memstats = runtime_mem_tracer.memstats() runtime_tracer_non_model_data = runtime_mem_tracer._memstats._non_model_data_cuda_list - print('runtime tracer: ', runtime_tracer_non_model_data) + print('runtime tracer non model data points: ', len(runtime_tracer_non_model_data)) - world_size = torch.distributed.get_world_size() - config_dict, _ = search_chunk_configuration(model, search_range_mb=1, search_interval_byte=100) - config_dict[world_size]['chunk_size'] = 5000 - config_dict[world_size]['keep_gathered'] = keep_gather - chunk_manager = ChunkManager(config_dict) - gemini_manager = GeminiManager(placement_policy, chunk_manager, memstats) - model = ZeroDDP(model, gemini_manager, pin_memory=True) + model = GeminiDDP(model, device='cuda', placement_policy=placement_policy, search_range_mb=1, memstats=memstats) + zero_optim = GeminiAdamOptimizer(model, lr=1e-3, initial_scale=1) pg = ProcessGroup() set_seed(pg.dp_local_rank()) for i, (input_ids, label) in enumerate(train_dataloader): # you can only test a single fwd + bwd. # after bwd param is grad for Gemini, due to the chunk reuse optimization. - if i > 1: + # print(f'iteration {i}') + if i > 4: break input_ids, label = input_ids.cuda(), label.cuda() + zero_optim.zero_grad() set_seed(42) - loss = run_fwd_bwd(model, input_ids, label, criterion, model) + loss = run_fwd_bwd(model, input_ids, label, criterion, zero_optim) + zero_optim.step() - gemini_non_model_data = gemini_manager._mem_stats_collector._memstats.non_model_data_list('cuda') + gemini_non_model_data = model.gemini_manager._mem_stats_collector._memstats.non_model_data_list('cuda') # print('gemini non model data:', gemini_non_model_data) diff --git a/tests/test_gemini/update/test_optim.py b/tests/test_gemini/update/test_optim.py index f9333f3d1..1f1d488a0 100644 --- a/tests/test_gemini/update/test_optim.py +++ b/tests/test_gemini/update/test_optim.py @@ -1,5 +1,4 @@ from functools import partial -from time import time import pytest import torch