[Gemini] chunk init using runtime visited param order (#2115)

pull/2123/head
Jiarui Fang 2022-12-12 18:06:16 +08:00 committed by GitHub
parent e7d3afc9cc
commit 9214d1fe28
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
10 changed files with 77 additions and 29 deletions

View File

@ -1,10 +1,10 @@
import math import math
from typing import Dict, List, Tuple from typing import Dict, List, Optional, Tuple
import numpy as np import numpy as np
import torch.nn as nn import torch.nn as nn
from colossalai.gemini.memory_tracer import OrderedParamGenerator from colossalai.gemini.memory_tracer import MemStats, OrderedParamGenerator
from colossalai.tensor import ColoParameter from colossalai.tensor import ColoParameter
@ -73,7 +73,8 @@ def search_chunk_configuration(
search_range_mb: float, search_range_mb: float,
search_interval_byte: int, # hidden size is the best value for the interval search_interval_byte: int, # hidden size is the best value for the interval
min_chunk_size_mb: float = 32, min_chunk_size_mb: float = 32,
filter_exlarge_params: bool = True) -> Tuple[Dict, int]: filter_exlarge_params: bool = True,
memstas: Optional[MemStats] = None) -> Tuple[Dict, int]:
"""search_chunk_configuration """search_chunk_configuration
Args: Args:
@ -86,9 +87,13 @@ def search_chunk_configuration(
Tuple[Dict, int]: chunk config (a dict of dp_degree -> chunk init args) and its memory chunk waste in byte. Tuple[Dict, int]: chunk config (a dict of dp_degree -> chunk init args) and its memory chunk waste in byte.
""" """
param_order = OrderedParamGenerator() if memstas is not None:
for p in model.parameters(): param_order = memstas.param_order()
param_order.append(p) else:
# build the param visited order right now
param_order = OrderedParamGenerator()
for p in model.parameters():
param_order.append(p)
search_range_byte = round(search_range_mb * 1024**2) search_range_byte = round(search_range_mb * 1024**2)
min_chunk_size_byte = round(min_chunk_size_mb * 1024**2) min_chunk_size_byte = round(min_chunk_size_mb * 1024**2)

View File

@ -7,6 +7,7 @@ import torch.nn as nn
from colossalai.gemini.chunk import ChunkManager from colossalai.gemini.chunk import ChunkManager
from colossalai.gemini.chunk.search_utils import in_ddp, search_chunk_configuration from colossalai.gemini.chunk.search_utils import in_ddp, search_chunk_configuration
from colossalai.gemini.memory_tracer import MemStats
def init_chunk_manager(model: nn.Module, def init_chunk_manager(model: nn.Module,
@ -37,13 +38,13 @@ def init_chunk_manager(model: nn.Module,
total_size = sum(params_sizes) / 1024**2 total_size = sum(params_sizes) / 1024**2
dist.barrier() dist.barrier()
begine = time() begin = time()
config_dict, wasted_size = search_chunk_configuration(model, **kwargs_dict) config_dict, wasted_size = search_chunk_configuration(model, **kwargs_dict)
dist.barrier() dist.barrier()
end = time() end = time()
span_s = end - begine span_s = end - begin
wasted_size /= 1024**2 wasted_size /= 1024**2
if dist.get_rank() == 0: if dist.get_rank() == 0:

View File

@ -25,6 +25,7 @@ class GeminiManager:
If it's 'auto', they are moving dynamically based on CPU and CUDA memory usage. It will utilize heterogeneous memory space evenly and well. If it's 'auto', they are moving dynamically based on CPU and CUDA memory usage. It will utilize heterogeneous memory space evenly and well.
Note that 'auto' policy can only work well when no other processes use CUDA during your training. Note that 'auto' policy can only work well when no other processes use CUDA during your training.
chunk_manager (ChunkManager): A ``ChunkManager`` instance. chunk_manager (ChunkManager): A ``ChunkManager`` instance.
memstats (MemStats, optional): a mem stats collected by a runtime mem tracer. if None then GeminiManager will collect it during a warmup iteration.
""" """
def __init__(self, placement_policy: str, chunk_manager: ChunkManager, memstats: Optional[MemStats] = None) -> None: def __init__(self, placement_policy: str, chunk_manager: ChunkManager, memstats: Optional[MemStats] = None) -> None:
@ -33,8 +34,11 @@ class GeminiManager:
self.policy_name = placement_policy self.policy_name = placement_policy
policy_cls = PlacementPolicyFactory.create(placement_policy) policy_cls = PlacementPolicyFactory.create(placement_policy)
self._chunk_manager = chunk_manager self._chunk_manager = chunk_manager
self._premade_memstats_ = memstats is not None
self._memstats = memstats
self._mem_stats_collector = ChunkMemStatsCollector(chunk_manager, self._mem_stats_collector = ChunkMemStatsCollector(chunk_manager,
memstats) if policy_cls.need_mem_stats else None self._memstats) if policy_cls.need_mem_stats else None
self._placement_policy = policy_cls(chunk_manager, self._mem_stats_collector) self._placement_policy = policy_cls(chunk_manager, self._mem_stats_collector)
self._compute_list: List[Tuple[Chunk, ...]] = [] self._compute_list: List[Tuple[Chunk, ...]] = []
self._compute_idx: int = -1 self._compute_idx: int = -1
@ -46,6 +50,19 @@ class GeminiManager:
self._warmup = True self._warmup = True
self._comp_cuda_demand_time = 0 self._comp_cuda_demand_time = 0
def memstats(self):
"""memstats
get the memory statistics during training.
The stats could be collected by a runtime memory tracer, or collected by the GeminiManager.
Note, for the latter you can not access the memstats before warmup iteration finishes.
"""
if self._premade_memstats_:
return self._memstats
else:
assert not self._warmup, "Gemini Manager has memstats after warm up! Now is during warmup."
return self._mem_stats_collector._memstats
def pre_iter(self, *args): def pre_iter(self, *args):
if self._mem_stats_collector and self._warmup: if self._mem_stats_collector and self._warmup:
self._mem_stats_collector.start_collection() self._mem_stats_collector.start_collection()

View File

@ -23,6 +23,12 @@ class MemStats(object):
self._param_runtime_order = OrderedParamGenerator() self._param_runtime_order = OrderedParamGenerator()
def param_order(self):
if self._param_runtime_order.is_empty():
raise RuntimeError
else:
return self._param_runtime_order
def append_overall_data(self, device_type: str, val: float): def append_overall_data(self, device_type: str, val: float):
if device_type == 'cuda': if device_type == 'cuda':
self._overall_cuda_list.append(val) self._overall_cuda_list.append(val)

View File

@ -37,7 +37,7 @@ class MemStatsCollector:
self._memstats = MemStats() self._memstats = MemStats()
def next_period_non_model_data_usage(self, device_type: str) -> int: def next_period_non_model_data_usage(self, device_type: str) -> int:
"""Get max non model data memory usage of current sampling period """Maximum non model data memory usage during the next Op run
Args: Args:
device_type (str): device type, can be 'cpu' or 'cuda'. device_type (str): device type, can be 'cpu' or 'cuda'.
@ -47,6 +47,9 @@ class MemStatsCollector:
""" """
assert not self._start_flag, 'Cannot get mem stats info during collection phase.' assert not self._start_flag, 'Cannot get mem stats info during collection phase.'
assert self._step_total > 0, 'Cannot get mem stats info before collection phase.' assert self._step_total > 0, 'Cannot get mem stats info before collection phase.'
assert len(self._memstats.non_model_data_list(device_type)) > self._step_idx, \
f"{len(self._memstats.non_model_data_list(device_type))} should be > than step idx {self._step_idx}, "\
f"step total {self._step_total}"
next_non_model_data = self._memstats.non_model_data_list(device_type)[self._step_idx] next_non_model_data = self._memstats.non_model_data_list(device_type)[self._step_idx]
self._step_idx = (self._step_idx + 1) % self._step_total self._step_idx = (self._step_idx + 1) % self._step_total
return next_non_model_data return next_non_model_data
@ -61,7 +64,8 @@ class MemStatsCollector:
def finish_collection(self): def finish_collection(self):
self.sample_overall_data() self.sample_overall_data()
self._step_total = len(self._sampling_time) # self._step_total = len(self._sampling_time)
self._step_total = len(self._memstats.non_model_data_list('cuda'))
self._start_flag = False self._start_flag = False
self._mem_monitor.finish() self._mem_monitor.finish()

View File

@ -35,5 +35,8 @@ class OrderedParamGenerator(ParamGenerator):
visited_set.add(p) visited_set.add(p)
del visited_set del visited_set
def is_empty(self):
return len(self.param_visited_order) > 0
def clear(self): def clear(self):
self.param_visited_order = [] self.param_visited_order = []

View File

@ -8,6 +8,7 @@ import torch.distributed as dist
from colossalai.gemini.chunk import Chunk, ChunkManager, TensorState from colossalai.gemini.chunk import Chunk, ChunkManager, TensorState
from colossalai.gemini.gemini_mgr import GeminiManager from colossalai.gemini.gemini_mgr import GeminiManager
from colossalai.gemini.memory_tracer import OrderedParamGenerator
from colossalai.logging import get_dist_logger from colossalai.logging import get_dist_logger
from colossalai.nn.parallel.utils import get_temp_total_chunk_on_cuda from colossalai.nn.parallel.utils import get_temp_total_chunk_on_cuda
from colossalai.tensor import ProcessGroup as ColoProcessGroup from colossalai.tensor import ProcessGroup as ColoProcessGroup
@ -216,8 +217,18 @@ class ZeroDDP(ColoDDP):
self.grads_device: Dict[torch.Tensor, torch.device] = {} self.grads_device: Dict[torch.Tensor, torch.device] = {}
cpu_offload = self.gemini_manager.policy_name != 'cuda' cpu_offload = self.gemini_manager.policy_name != 'cuda'
# TODO: get param order and filter unused params
for p in module.parameters(): if self.gemini_manager._premade_memstats_:
# build chunk in param runtime visited order.
param_order = self.gemini_manager.memstats()._param_runtime_order
else:
# build chunk in param initialized order.
# Note: in this way, it can not get filter unused params during runtime.
param_order = OrderedParamGenerator()
for p in module.parameters():
param_order.append(p)
for p in param_order.generate():
assert isinstance(p, ColoParameter) assert isinstance(p, ColoParameter)
if getattr(p, '_ddp_to_ignore', False): if getattr(p, '_ddp_to_ignore', False):
@ -243,7 +254,7 @@ class ZeroDDP(ColoDDP):
self.chunk_manager.close_all_groups() self.chunk_manager.close_all_groups()
self._cast_buffers() self._cast_buffers()
params_list = [p for p in module.parameters() if not getattr(p, '_ddp_to_ignore', False)] params_list = [p for p in param_order.generate() if not getattr(p, '_ddp_to_ignore', False)]
for p, fp32_p in zip(params_list, self.fp32_params): for p, fp32_p in zip(params_list, self.fp32_params):
chunk_16 = self.chunk_manager.get_chunk(p) chunk_16 = self.chunk_manager.get_chunk(p)
chunk_32 = self.chunk_manager.get_chunk(fp32_p) chunk_32 = self.chunk_manager.get_chunk(fp32_p)

View File

@ -4,6 +4,7 @@ import torch
from colossalai.gemini.chunk import init_chunk_manager from colossalai.gemini.chunk import init_chunk_manager
from colossalai.gemini.gemini_mgr import GeminiManager from colossalai.gemini.gemini_mgr import GeminiManager
from colossalai.gemini.memory_tracer import MemStats
from .data_parallel import ZeroDDP from .data_parallel import ZeroDDP
@ -18,7 +19,8 @@ class GeminiDDP(ZeroDDP):
force_outputs_fp32: bool = False, force_outputs_fp32: bool = False,
search_range_mb: int = 32, search_range_mb: int = 32,
hidden_dim: Optional[int] = None, hidden_dim: Optional[int] = None,
min_chunk_size_mb: Optional[float] = None) -> None: min_chunk_size_mb: Optional[float] = None,
memstats: Optional[MemStats] = None) -> None:
""" """
A torch.Module warpper using ZeRO-DP and Genimi. A torch.Module warpper using ZeRO-DP and Genimi.
ZeRO is for parallel. Gemini is for memory management. ZeRO is for parallel. Gemini is for memory management.
@ -44,11 +46,12 @@ class GeminiDDP(ZeroDDP):
min_chunk_size_mb (float, optional): the minimum chunk size in MegaByte. min_chunk_size_mb (float, optional): the minimum chunk size in MegaByte.
If the aggregate size of parameters is still samller than the minimum chunk size, If the aggregate size of parameters is still samller than the minimum chunk size,
all parameters will be compacted into one small chunk. all parameters will be compacted into one small chunk.
memstats (MemStats, optional) the memory statistics collector by a runtime memory tracer.
""" """
chunk_manager = init_chunk_manager(model=module, chunk_manager = init_chunk_manager(model=module,
init_device=device, init_device=device,
hidden_dim=hidden_dim, hidden_dim=hidden_dim,
search_range_mb=search_range_mb, search_range_mb=search_range_mb,
min_chunk_size_mb=min_chunk_size_mb) min_chunk_size_mb=min_chunk_size_mb)
gemini_manager = GeminiManager(placement_policy, chunk_manager) gemini_manager = GeminiManager(placement_policy, chunk_manager, memstats)
super().__init__(module, gemini_manager, pin_memory, force_outputs_fp32) super().__init__(module, gemini_manager, pin_memory, force_outputs_fp32)

View File

@ -8,7 +8,8 @@ import colossalai
from colossalai.gemini.chunk import ChunkManager, search_chunk_configuration from colossalai.gemini.chunk import ChunkManager, search_chunk_configuration
from colossalai.gemini.gemini_mgr import GeminiManager from colossalai.gemini.gemini_mgr import GeminiManager
from colossalai.gemini.memory_tracer.runtime_mem_tracer import RuntimeMemTracer from colossalai.gemini.memory_tracer.runtime_mem_tracer import RuntimeMemTracer
from colossalai.nn.parallel import ZeroDDP from colossalai.nn.optimizer.gemini_optimizer import GeminiAdamOptimizer
from colossalai.nn.parallel import GeminiDDP, ZeroDDP
from colossalai.tensor import ProcessGroup from colossalai.tensor import ProcessGroup
from colossalai.testing import parameterize, rerun_if_address_is_in_use from colossalai.testing import parameterize, rerun_if_address_is_in_use
from colossalai.utils import free_port from colossalai.utils import free_port
@ -44,29 +45,27 @@ def run_gemini_use_rmt(placement_policy, keep_gather, model_name: str, use_grad_
run_fwd_bwd(runtime_mem_tracer, input_ids, label, criterion, runtime_mem_tracer) run_fwd_bwd(runtime_mem_tracer, input_ids, label, criterion, runtime_mem_tracer)
memstats = runtime_mem_tracer.memstats() memstats = runtime_mem_tracer.memstats()
runtime_tracer_non_model_data = runtime_mem_tracer._memstats._non_model_data_cuda_list runtime_tracer_non_model_data = runtime_mem_tracer._memstats._non_model_data_cuda_list
print('runtime tracer: ', runtime_tracer_non_model_data) print('runtime tracer non model data points: ', len(runtime_tracer_non_model_data))
world_size = torch.distributed.get_world_size() model = GeminiDDP(model, device='cuda', placement_policy=placement_policy, search_range_mb=1, memstats=memstats)
config_dict, _ = search_chunk_configuration(model, search_range_mb=1, search_interval_byte=100) zero_optim = GeminiAdamOptimizer(model, lr=1e-3, initial_scale=1)
config_dict[world_size]['chunk_size'] = 5000
config_dict[world_size]['keep_gathered'] = keep_gather
chunk_manager = ChunkManager(config_dict)
gemini_manager = GeminiManager(placement_policy, chunk_manager, memstats)
model = ZeroDDP(model, gemini_manager, pin_memory=True)
pg = ProcessGroup() pg = ProcessGroup()
set_seed(pg.dp_local_rank()) set_seed(pg.dp_local_rank())
for i, (input_ids, label) in enumerate(train_dataloader): for i, (input_ids, label) in enumerate(train_dataloader):
# you can only test a single fwd + bwd. # you can only test a single fwd + bwd.
# after bwd param is grad for Gemini, due to the chunk reuse optimization. # after bwd param is grad for Gemini, due to the chunk reuse optimization.
if i > 1: # print(f'iteration {i}')
if i > 4:
break break
input_ids, label = input_ids.cuda(), label.cuda() input_ids, label = input_ids.cuda(), label.cuda()
zero_optim.zero_grad()
set_seed(42) set_seed(42)
loss = run_fwd_bwd(model, input_ids, label, criterion, model) loss = run_fwd_bwd(model, input_ids, label, criterion, zero_optim)
zero_optim.step()
gemini_non_model_data = gemini_manager._mem_stats_collector._memstats.non_model_data_list('cuda') gemini_non_model_data = model.gemini_manager._mem_stats_collector._memstats.non_model_data_list('cuda')
# print('gemini non model data:', gemini_non_model_data) # print('gemini non model data:', gemini_non_model_data)

View File

@ -1,5 +1,4 @@
from functools import partial from functools import partial
from time import time
import pytest import pytest
import torch import torch