mirror of https://github.com/hpcaitech/ColossalAI
[Gemini] chunk init using runtime visited param order (#2115)
parent
e7d3afc9cc
commit
9214d1fe28
|
@ -1,10 +1,10 @@
|
|||
import math
|
||||
from typing import Dict, List, Tuple
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
|
||||
import numpy as np
|
||||
import torch.nn as nn
|
||||
|
||||
from colossalai.gemini.memory_tracer import OrderedParamGenerator
|
||||
from colossalai.gemini.memory_tracer import MemStats, OrderedParamGenerator
|
||||
from colossalai.tensor import ColoParameter
|
||||
|
||||
|
||||
|
@ -73,7 +73,8 @@ def search_chunk_configuration(
|
|||
search_range_mb: float,
|
||||
search_interval_byte: int, # hidden size is the best value for the interval
|
||||
min_chunk_size_mb: float = 32,
|
||||
filter_exlarge_params: bool = True) -> Tuple[Dict, int]:
|
||||
filter_exlarge_params: bool = True,
|
||||
memstas: Optional[MemStats] = None) -> Tuple[Dict, int]:
|
||||
"""search_chunk_configuration
|
||||
|
||||
Args:
|
||||
|
@ -86,6 +87,10 @@ def search_chunk_configuration(
|
|||
Tuple[Dict, int]: chunk config (a dict of dp_degree -> chunk init args) and its memory chunk waste in byte.
|
||||
"""
|
||||
|
||||
if memstas is not None:
|
||||
param_order = memstas.param_order()
|
||||
else:
|
||||
# build the param visited order right now
|
||||
param_order = OrderedParamGenerator()
|
||||
for p in model.parameters():
|
||||
param_order.append(p)
|
||||
|
|
|
@ -7,6 +7,7 @@ import torch.nn as nn
|
|||
|
||||
from colossalai.gemini.chunk import ChunkManager
|
||||
from colossalai.gemini.chunk.search_utils import in_ddp, search_chunk_configuration
|
||||
from colossalai.gemini.memory_tracer import MemStats
|
||||
|
||||
|
||||
def init_chunk_manager(model: nn.Module,
|
||||
|
@ -37,13 +38,13 @@ def init_chunk_manager(model: nn.Module,
|
|||
total_size = sum(params_sizes) / 1024**2
|
||||
|
||||
dist.barrier()
|
||||
begine = time()
|
||||
begin = time()
|
||||
|
||||
config_dict, wasted_size = search_chunk_configuration(model, **kwargs_dict)
|
||||
|
||||
dist.barrier()
|
||||
end = time()
|
||||
span_s = end - begine
|
||||
span_s = end - begin
|
||||
wasted_size /= 1024**2
|
||||
|
||||
if dist.get_rank() == 0:
|
||||
|
|
|
@ -25,6 +25,7 @@ class GeminiManager:
|
|||
If it's 'auto', they are moving dynamically based on CPU and CUDA memory usage. It will utilize heterogeneous memory space evenly and well.
|
||||
Note that 'auto' policy can only work well when no other processes use CUDA during your training.
|
||||
chunk_manager (ChunkManager): A ``ChunkManager`` instance.
|
||||
memstats (MemStats, optional): a mem stats collected by a runtime mem tracer. if None then GeminiManager will collect it during a warmup iteration.
|
||||
"""
|
||||
|
||||
def __init__(self, placement_policy: str, chunk_manager: ChunkManager, memstats: Optional[MemStats] = None) -> None:
|
||||
|
@ -33,8 +34,11 @@ class GeminiManager:
|
|||
self.policy_name = placement_policy
|
||||
policy_cls = PlacementPolicyFactory.create(placement_policy)
|
||||
self._chunk_manager = chunk_manager
|
||||
|
||||
self._premade_memstats_ = memstats is not None
|
||||
self._memstats = memstats
|
||||
self._mem_stats_collector = ChunkMemStatsCollector(chunk_manager,
|
||||
memstats) if policy_cls.need_mem_stats else None
|
||||
self._memstats) if policy_cls.need_mem_stats else None
|
||||
self._placement_policy = policy_cls(chunk_manager, self._mem_stats_collector)
|
||||
self._compute_list: List[Tuple[Chunk, ...]] = []
|
||||
self._compute_idx: int = -1
|
||||
|
@ -46,6 +50,19 @@ class GeminiManager:
|
|||
self._warmup = True
|
||||
self._comp_cuda_demand_time = 0
|
||||
|
||||
def memstats(self):
|
||||
"""memstats
|
||||
|
||||
get the memory statistics during training.
|
||||
The stats could be collected by a runtime memory tracer, or collected by the GeminiManager.
|
||||
Note, for the latter, you can not access the memstats before warmup iteration finishes.
|
||||
"""
|
||||
if self._premade_memstats_:
|
||||
return self._memstats
|
||||
else:
|
||||
assert not self._warmup, "Gemini Manager has memstats after warm up! Now is during warmup."
|
||||
return self._mem_stats_collector._memstats
|
||||
|
||||
def pre_iter(self, *args):
|
||||
if self._mem_stats_collector and self._warmup:
|
||||
self._mem_stats_collector.start_collection()
|
||||
|
|
|
@ -23,6 +23,12 @@ class MemStats(object):
|
|||
|
||||
self._param_runtime_order = OrderedParamGenerator()
|
||||
|
||||
def param_order(self):
|
||||
if self._param_runtime_order.is_empty():
|
||||
raise RuntimeError
|
||||
else:
|
||||
return self._param_runtime_order
|
||||
|
||||
def append_overall_data(self, device_type: str, val: float):
|
||||
if device_type == 'cuda':
|
||||
self._overall_cuda_list.append(val)
|
||||
|
|
|
@ -37,7 +37,7 @@ class MemStatsCollector:
|
|||
self._memstats = MemStats()
|
||||
|
||||
def next_period_non_model_data_usage(self, device_type: str) -> int:
|
||||
"""Get max non model data memory usage of current sampling period
|
||||
"""Maximum non model data memory usage during the next Op run
|
||||
|
||||
Args:
|
||||
device_type (str): device type, can be 'cpu' or 'cuda'.
|
||||
|
@ -47,6 +47,9 @@ class MemStatsCollector:
|
|||
"""
|
||||
assert not self._start_flag, 'Cannot get mem stats info during collection phase.'
|
||||
assert self._step_total > 0, 'Cannot get mem stats info before collection phase.'
|
||||
assert len(self._memstats.non_model_data_list(device_type)) > self._step_idx, \
|
||||
f"{len(self._memstats.non_model_data_list(device_type))} should be > than step idx {self._step_idx}, "\
|
||||
f"step total {self._step_total}"
|
||||
next_non_model_data = self._memstats.non_model_data_list(device_type)[self._step_idx]
|
||||
self._step_idx = (self._step_idx + 1) % self._step_total
|
||||
return next_non_model_data
|
||||
|
@ -61,7 +64,8 @@ class MemStatsCollector:
|
|||
|
||||
def finish_collection(self):
|
||||
self.sample_overall_data()
|
||||
self._step_total = len(self._sampling_time)
|
||||
# self._step_total = len(self._sampling_time)
|
||||
self._step_total = len(self._memstats.non_model_data_list('cuda'))
|
||||
self._start_flag = False
|
||||
self._mem_monitor.finish()
|
||||
|
||||
|
|
|
@ -35,5 +35,8 @@ class OrderedParamGenerator(ParamGenerator):
|
|||
visited_set.add(p)
|
||||
del visited_set
|
||||
|
||||
def is_empty(self):
|
||||
return len(self.param_visited_order) > 0
|
||||
|
||||
def clear(self):
|
||||
self.param_visited_order = []
|
||||
|
|
|
@ -8,6 +8,7 @@ import torch.distributed as dist
|
|||
|
||||
from colossalai.gemini.chunk import Chunk, ChunkManager, TensorState
|
||||
from colossalai.gemini.gemini_mgr import GeminiManager
|
||||
from colossalai.gemini.memory_tracer import OrderedParamGenerator
|
||||
from colossalai.logging import get_dist_logger
|
||||
from colossalai.nn.parallel.utils import get_temp_total_chunk_on_cuda
|
||||
from colossalai.tensor import ProcessGroup as ColoProcessGroup
|
||||
|
@ -216,8 +217,18 @@ class ZeroDDP(ColoDDP):
|
|||
self.grads_device: Dict[torch.Tensor, torch.device] = {}
|
||||
|
||||
cpu_offload = self.gemini_manager.policy_name != 'cuda'
|
||||
# TODO: get param order and filter unused params
|
||||
|
||||
if self.gemini_manager._premade_memstats_:
|
||||
# build chunk in param runtime visited order.
|
||||
param_order = self.gemini_manager.memstats()._param_runtime_order
|
||||
else:
|
||||
# build chunk in param initialized order.
|
||||
# Note: in this way, it can not get filter unused params during runtime.
|
||||
param_order = OrderedParamGenerator()
|
||||
for p in module.parameters():
|
||||
param_order.append(p)
|
||||
|
||||
for p in param_order.generate():
|
||||
assert isinstance(p, ColoParameter)
|
||||
|
||||
if getattr(p, '_ddp_to_ignore', False):
|
||||
|
@ -243,7 +254,7 @@ class ZeroDDP(ColoDDP):
|
|||
self.chunk_manager.close_all_groups()
|
||||
self._cast_buffers()
|
||||
|
||||
params_list = [p for p in module.parameters() if not getattr(p, '_ddp_to_ignore', False)]
|
||||
params_list = [p for p in param_order.generate() if not getattr(p, '_ddp_to_ignore', False)]
|
||||
for p, fp32_p in zip(params_list, self.fp32_params):
|
||||
chunk_16 = self.chunk_manager.get_chunk(p)
|
||||
chunk_32 = self.chunk_manager.get_chunk(fp32_p)
|
||||
|
|
|
@ -4,6 +4,7 @@ import torch
|
|||
|
||||
from colossalai.gemini.chunk import init_chunk_manager
|
||||
from colossalai.gemini.gemini_mgr import GeminiManager
|
||||
from colossalai.gemini.memory_tracer import MemStats
|
||||
|
||||
from .data_parallel import ZeroDDP
|
||||
|
||||
|
@ -18,7 +19,8 @@ class GeminiDDP(ZeroDDP):
|
|||
force_outputs_fp32: bool = False,
|
||||
search_range_mb: int = 32,
|
||||
hidden_dim: Optional[int] = None,
|
||||
min_chunk_size_mb: Optional[float] = None) -> None:
|
||||
min_chunk_size_mb: Optional[float] = None,
|
||||
memstats: Optional[MemStats] = None) -> None:
|
||||
"""
|
||||
A torch.Module warpper using ZeRO-DP and Genimi.
|
||||
ZeRO is for parallel. Gemini is for memory management.
|
||||
|
@ -44,11 +46,12 @@ class GeminiDDP(ZeroDDP):
|
|||
min_chunk_size_mb (float, optional): the minimum chunk size in MegaByte.
|
||||
If the aggregate size of parameters is still samller than the minimum chunk size,
|
||||
all parameters will be compacted into one small chunk.
|
||||
memstats (MemStats, optional) the memory statistics collector by a runtime memory tracer.
|
||||
"""
|
||||
chunk_manager = init_chunk_manager(model=module,
|
||||
init_device=device,
|
||||
hidden_dim=hidden_dim,
|
||||
search_range_mb=search_range_mb,
|
||||
min_chunk_size_mb=min_chunk_size_mb)
|
||||
gemini_manager = GeminiManager(placement_policy, chunk_manager)
|
||||
gemini_manager = GeminiManager(placement_policy, chunk_manager, memstats)
|
||||
super().__init__(module, gemini_manager, pin_memory, force_outputs_fp32)
|
||||
|
|
|
@ -8,7 +8,8 @@ import colossalai
|
|||
from colossalai.gemini.chunk import ChunkManager, search_chunk_configuration
|
||||
from colossalai.gemini.gemini_mgr import GeminiManager
|
||||
from colossalai.gemini.memory_tracer.runtime_mem_tracer import RuntimeMemTracer
|
||||
from colossalai.nn.parallel import ZeroDDP
|
||||
from colossalai.nn.optimizer.gemini_optimizer import GeminiAdamOptimizer
|
||||
from colossalai.nn.parallel import GeminiDDP, ZeroDDP
|
||||
from colossalai.tensor import ProcessGroup
|
||||
from colossalai.testing import parameterize, rerun_if_address_is_in_use
|
||||
from colossalai.utils import free_port
|
||||
|
@ -44,29 +45,27 @@ def run_gemini_use_rmt(placement_policy, keep_gather, model_name: str, use_grad_
|
|||
run_fwd_bwd(runtime_mem_tracer, input_ids, label, criterion, runtime_mem_tracer)
|
||||
memstats = runtime_mem_tracer.memstats()
|
||||
runtime_tracer_non_model_data = runtime_mem_tracer._memstats._non_model_data_cuda_list
|
||||
print('runtime tracer: ', runtime_tracer_non_model_data)
|
||||
print('runtime tracer non model data points: ', len(runtime_tracer_non_model_data))
|
||||
|
||||
world_size = torch.distributed.get_world_size()
|
||||
config_dict, _ = search_chunk_configuration(model, search_range_mb=1, search_interval_byte=100)
|
||||
config_dict[world_size]['chunk_size'] = 5000
|
||||
config_dict[world_size]['keep_gathered'] = keep_gather
|
||||
chunk_manager = ChunkManager(config_dict)
|
||||
gemini_manager = GeminiManager(placement_policy, chunk_manager, memstats)
|
||||
model = ZeroDDP(model, gemini_manager, pin_memory=True)
|
||||
model = GeminiDDP(model, device='cuda', placement_policy=placement_policy, search_range_mb=1, memstats=memstats)
|
||||
zero_optim = GeminiAdamOptimizer(model, lr=1e-3, initial_scale=1)
|
||||
|
||||
pg = ProcessGroup()
|
||||
set_seed(pg.dp_local_rank())
|
||||
for i, (input_ids, label) in enumerate(train_dataloader):
|
||||
# you can only test a single fwd + bwd.
|
||||
# after bwd param is grad for Gemini, due to the chunk reuse optimization.
|
||||
if i > 1:
|
||||
# print(f'iteration {i}')
|
||||
if i > 4:
|
||||
break
|
||||
input_ids, label = input_ids.cuda(), label.cuda()
|
||||
|
||||
zero_optim.zero_grad()
|
||||
set_seed(42)
|
||||
loss = run_fwd_bwd(model, input_ids, label, criterion, model)
|
||||
loss = run_fwd_bwd(model, input_ids, label, criterion, zero_optim)
|
||||
zero_optim.step()
|
||||
|
||||
gemini_non_model_data = gemini_manager._mem_stats_collector._memstats.non_model_data_list('cuda')
|
||||
gemini_non_model_data = model.gemini_manager._mem_stats_collector._memstats.non_model_data_list('cuda')
|
||||
|
||||
# print('gemini non model data:', gemini_non_model_data)
|
||||
|
||||
|
|
|
@ -1,5 +1,4 @@
|
|||
from functools import partial
|
||||
from time import time
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
|
Loading…
Reference in New Issue