mirror of https://github.com/hpcaitech/ColossalAI
[Gemini] chunk init using runtime visited param order (#2115)
parent e7d3afc9cc
commit 9214d1fe28
@@ -1,10 +1,10 @@
 import math
-from typing import Dict, List, Tuple
+from typing import Dict, List, Optional, Tuple

 import numpy as np
 import torch.nn as nn

-from colossalai.gemini.memory_tracer import OrderedParamGenerator
+from colossalai.gemini.memory_tracer import MemStats, OrderedParamGenerator
 from colossalai.tensor import ColoParameter


@@ -73,7 +73,8 @@ def search_chunk_configuration(
         search_range_mb: float,
         search_interval_byte: int,    # hidden size is the best value for the interval
         min_chunk_size_mb: float = 32,
-        filter_exlarge_params: bool = True) -> Tuple[Dict, int]:
+        filter_exlarge_params: bool = True,
+        memstas: Optional[MemStats] = None) -> Tuple[Dict, int]:
     """search_chunk_configuration

     Args:
@@ -86,9 +87,13 @@ def search_chunk_configuration(
         Tuple[Dict, int]: chunk config (a dict of dp_degree -> chunk init args) and its memory chunk waste in byte.
     """

-    param_order = OrderedParamGenerator()
-    for p in model.parameters():
-        param_order.append(p)
+    if memstas is not None:
+        param_order = memstas.param_order()
+    else:
+        # build the param visited order right now
+        param_order = OrderedParamGenerator()
+        for p in model.parameters():
+            param_order.append(p)

     search_range_byte = round(search_range_mb * 1024**2)
     min_chunk_size_byte = round(min_chunk_size_mb * 1024**2)
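With this change, search_chunk_configuration can reuse a parameter order that was already recorded by a runtime memory tracer instead of walking model.parameters() itself. A minimal sketch of both call shapes, assuming `model` is an nn.Module and `memstats` is a MemStats produced by such a tracer (the keyword values mirror the test further down; note the parameter is spelled `memstas` in the signature above):

    # chunks searched over the runtime visited order recorded in `memstats`
    config_dict, wasted_size = search_chunk_configuration(model,
                                                          search_range_mb=1,
                                                          search_interval_byte=100,
                                                          memstas=memstats)

    # without a pre-collected MemStats, the search falls back to the
    # model.parameters() definition order
    config_dict, wasted_size = search_chunk_configuration(model,
                                                          search_range_mb=1,
                                                          search_interval_byte=100)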
@@ -7,6 +7,7 @@ import torch.nn as nn

 from colossalai.gemini.chunk import ChunkManager
 from colossalai.gemini.chunk.search_utils import in_ddp, search_chunk_configuration
+from colossalai.gemini.memory_tracer import MemStats


 def init_chunk_manager(model: nn.Module,
@@ -37,13 +38,13 @@ def init_chunk_manager(model: nn.Module,
     total_size = sum(params_sizes) / 1024**2

     dist.barrier()
-    begine = time()
+    begin = time()

     config_dict, wasted_size = search_chunk_configuration(model, **kwargs_dict)

     dist.barrier()
     end = time()
-    span_s = end - begine
+    span_s = end - begin
     wasted_size /= 1024**2

     if dist.get_rank() == 0:
@@ -25,6 +25,7 @@ class GeminiManager:
         If it's 'auto', they are moving dynamically based on CPU and CUDA memory usage. It will utilize heterogeneous memory space evenly and well.
         Note that 'auto' policy can only work well when no other processes use CUDA during your training.
         chunk_manager (ChunkManager): A ``ChunkManager`` instance.
+        memstats (MemStats, optional): a mem stats collected by a runtime mem tracer. if None then GeminiManager will collect it during a warmup iteration.
     """

     def __init__(self, placement_policy: str, chunk_manager: ChunkManager, memstats: Optional[MemStats] = None) -> None:
@@ -33,8 +34,11 @@ class GeminiManager:
         self.policy_name = placement_policy
         policy_cls = PlacementPolicyFactory.create(placement_policy)
         self._chunk_manager = chunk_manager

+        self._premade_memstats_ = memstats is not None
+        self._memstats = memstats
         self._mem_stats_collector = ChunkMemStatsCollector(chunk_manager,
-                                                           memstats) if policy_cls.need_mem_stats else None
+                                                           self._memstats) if policy_cls.need_mem_stats else None
         self._placement_policy = policy_cls(chunk_manager, self._mem_stats_collector)
         self._compute_list: List[Tuple[Chunk, ...]] = []
         self._compute_idx: int = -1
@@ -46,6 +50,19 @@ class GeminiManager:
         self._warmup = True
         self._comp_cuda_demand_time = 0

+    def memstats(self):
+        """memstats
+
+        get the memory statistics during training.
+        The stats could be collected by a runtime memory tracer, or collected by the GeminiManager.
+        Note, for the latter, you can not access the memstats before warmup iteration finishes.
+        """
+        if self._premade_memstats_:
+            return self._memstats
+        else:
+            assert not self._warmup, "Gemini Manager has memstats after warm up! Now is during warmup."
+            return self._mem_stats_collector._memstats
+
     def pre_iter(self, *args):
         if self._mem_stats_collector and self._warmup:
             self._mem_stats_collector.start_collection()
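The new memstats() accessor has two modes. A minimal sketch, assuming `chunk_manager` is an already-built ChunkManager and `memstats` was collected beforehand by a runtime memory tracer (both names are placeholders, not defined here):

    # premade stats: GeminiManager simply hands them back, available at once
    gemini_manager = GeminiManager('auto', chunk_manager, memstats)
    stats = gemini_manager.memstats()

    # no premade stats: GeminiManager collects them itself during the warmup
    # iteration, so memstats() only becomes valid after warmup has finished
    gemini_manager = GeminiManager('auto', chunk_manager)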
@@ -23,6 +23,12 @@ class MemStats(object):

         self._param_runtime_order = OrderedParamGenerator()

+    def param_order(self):
+        if self._param_runtime_order.is_empty():
+            raise RuntimeError
+        else:
+            return self._param_runtime_order
+
     def append_overall_data(self, device_type: str, val: float):
         if device_type == 'cuda':
             self._overall_cuda_list.append(val)
@@ -37,7 +37,7 @@ class MemStatsCollector:
         self._memstats = MemStats()

     def next_period_non_model_data_usage(self, device_type: str) -> int:
-        """Get max non model data memory usage of current sampling period
+        """Maximum non model data memory usage during the next Op run

         Args:
             device_type (str): device type, can be 'cpu' or 'cuda'.
@@ -47,6 +47,9 @@ class MemStatsCollector:
         """
         assert not self._start_flag, 'Cannot get mem stats info during collection phase.'
         assert self._step_total > 0, 'Cannot get mem stats info before collection phase.'
+        assert len(self._memstats.non_model_data_list(device_type)) > self._step_idx, \
+            f"{len(self._memstats.non_model_data_list(device_type))} should be > than step idx {self._step_idx}, "\
+            f"step total {self._step_total}"
         next_non_model_data = self._memstats.non_model_data_list(device_type)[self._step_idx]
         self._step_idx = (self._step_idx + 1) % self._step_total
         return next_non_model_data
@@ -61,7 +64,8 @@ class MemStatsCollector:

     def finish_collection(self):
         self.sample_overall_data()
-        self._step_total = len(self._sampling_time)
+        # self._step_total = len(self._sampling_time)
+        self._step_total = len(self._memstats.non_model_data_list('cuda'))
         self._start_flag = False
         self._mem_monitor.finish()

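After finish_collection(), _step_total now equals the number of recorded non-model-data samples, so the modulo stepping in next_period_non_model_data_usage cycles cleanly through exactly those samples. A small sketch, assuming `collector` is a MemStatsCollector whose warmup collection has already finished:

    n_steps = collector._step_total
    demands = [collector.next_period_non_model_data_usage('cuda') for _ in range(n_steps)]
    # the internal _step_idx has wrapped back to 0, so the next call repeats the first sample
    assert collector.next_period_non_model_data_usage('cuda') == demands[0]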
@@ -35,5 +35,8 @@ class OrderedParamGenerator(ParamGenerator):
                 visited_set.add(p)
         del visited_set

+    def is_empty(self):
+        return len(self.param_visited_order) > 0
+
     def clear(self):
         self.param_visited_order = []
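OrderedParamGenerator is the small container both code paths share: parameters are append()-ed in whatever order the caller sees them and read back with generate(). A minimal sketch, assuming two plain parameters stand in for a model's weights (the visited_set bookkeeping above suggests repeated visits are deduplicated on read-back; that is an inference, not something this diff states):

    import torch

    from colossalai.gemini.memory_tracer import OrderedParamGenerator

    a = torch.nn.Parameter(torch.ones(1))
    b = torch.nn.Parameter(torch.zeros(1))

    param_order = OrderedParamGenerator()
    for p in (a, b, a):    # `a` is visited twice, e.g. a shared or reused parameter
        param_order.append(p)

    visited = list(param_order.generate())    # expected: [a, b], in first-visit order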
@@ -8,6 +8,7 @@ import torch.distributed as dist

 from colossalai.gemini.chunk import Chunk, ChunkManager, TensorState
 from colossalai.gemini.gemini_mgr import GeminiManager
+from colossalai.gemini.memory_tracer import OrderedParamGenerator
 from colossalai.logging import get_dist_logger
 from colossalai.nn.parallel.utils import get_temp_total_chunk_on_cuda
 from colossalai.tensor import ProcessGroup as ColoProcessGroup
@@ -216,8 +217,18 @@ class ZeroDDP(ColoDDP):
         self.grads_device: Dict[torch.Tensor, torch.device] = {}

         cpu_offload = self.gemini_manager.policy_name != 'cuda'
-        # TODO: get param order and filter unused params
-        for p in module.parameters():
+        if self.gemini_manager._premade_memstats_:
+            # build chunk in param runtime visited order.
+            param_order = self.gemini_manager.memstats()._param_runtime_order
+        else:
+            # build chunk in param initialized order.
+            # Note: in this way, it can not get filter unused params during runtime.
+            param_order = OrderedParamGenerator()
+            for p in module.parameters():
+                param_order.append(p)
+
+        for p in param_order.generate():
             assert isinstance(p, ColoParameter)

             if getattr(p, '_ddp_to_ignore', False):
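This is the core of the commit: when a premade MemStats is available, chunks are created in the order parameters are actually visited at runtime rather than in module definition order. A toy illustration of why the two orders can differ, using plain PyTorch and hypothetical layer names:

    import torch.nn as nn


    class OutOfOrder(nn.Module):

        def __init__(self):
            super().__init__()
            self.first = nn.Linear(8, 8)     # defined first
            self.second = nn.Linear(8, 8)    # defined second

        def forward(self, x):
            x = self.second(x)               # but used first at runtime
            return self.first(x)


    # module.parameters() yields first.weight, first.bias, second.weight, second.bias,
    # while a runtime tracer visits second.* before first.*, so chunks built from the
    # runtime order keep parameters that are used together in the same chunk.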
@@ -243,7 +254,7 @@ class ZeroDDP(ColoDDP):
         self.chunk_manager.close_all_groups()
         self._cast_buffers()

-        params_list = [p for p in module.parameters() if not getattr(p, '_ddp_to_ignore', False)]
+        params_list = [p for p in param_order.generate() if not getattr(p, '_ddp_to_ignore', False)]
         for p, fp32_p in zip(params_list, self.fp32_params):
             chunk_16 = self.chunk_manager.get_chunk(p)
             chunk_32 = self.chunk_manager.get_chunk(fp32_p)
@@ -4,6 +4,7 @@ import torch

 from colossalai.gemini.chunk import init_chunk_manager
 from colossalai.gemini.gemini_mgr import GeminiManager
+from colossalai.gemini.memory_tracer import MemStats

 from .data_parallel import ZeroDDP

@@ -18,7 +19,8 @@ class GeminiDDP(ZeroDDP):
                  force_outputs_fp32: bool = False,
                  search_range_mb: int = 32,
                  hidden_dim: Optional[int] = None,
-                 min_chunk_size_mb: Optional[float] = None) -> None:
+                 min_chunk_size_mb: Optional[float] = None,
+                 memstats: Optional[MemStats] = None) -> None:
         """
         A torch.Module warpper using ZeRO-DP and Genimi.
         ZeRO is for parallel. Gemini is for memory management.
@@ -44,11 +46,12 @@ class GeminiDDP(ZeroDDP):
             min_chunk_size_mb (float, optional): the minimum chunk size in MegaByte.
                 If the aggregate size of parameters is still samller than the minimum chunk size,
                 all parameters will be compacted into one small chunk.
+            memstats (MemStats, optional) the memory statistics collector by a runtime memory tracer.
         """
         chunk_manager = init_chunk_manager(model=module,
                                            init_device=device,
                                            hidden_dim=hidden_dim,
                                            search_range_mb=search_range_mb,
                                            min_chunk_size_mb=min_chunk_size_mb)
-        gemini_manager = GeminiManager(placement_policy, chunk_manager)
+        gemini_manager = GeminiManager(placement_policy, chunk_manager, memstats)
         super().__init__(module, gemini_manager, pin_memory, force_outputs_fp32)
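Taken together, the intended flow is: trace one iteration with the runtime memory tracer, then hand the collected MemStats to GeminiDDP so chunk initialization follows the runtime visited param order. A minimal end-to-end sketch, assuming `model` is a ColoParameter-based module already wrapped by `runtime_mem_tracer` (its construction is not part of this diff) and one fwd+bwd has been run under the tracer, as in the test below:

    memstats = runtime_mem_tracer.memstats()

    model = GeminiDDP(model,
                      device='cuda',
                      placement_policy='auto',
                      search_range_mb=1,
                      memstats=memstats)
    optimizer = GeminiAdamOptimizer(model, lr=1e-3, initial_scale=1)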
@@ -8,7 +8,8 @@ import colossalai
 from colossalai.gemini.chunk import ChunkManager, search_chunk_configuration
 from colossalai.gemini.gemini_mgr import GeminiManager
 from colossalai.gemini.memory_tracer.runtime_mem_tracer import RuntimeMemTracer
-from colossalai.nn.parallel import ZeroDDP
+from colossalai.nn.optimizer.gemini_optimizer import GeminiAdamOptimizer
+from colossalai.nn.parallel import GeminiDDP, ZeroDDP
 from colossalai.tensor import ProcessGroup
 from colossalai.testing import parameterize, rerun_if_address_is_in_use
 from colossalai.utils import free_port
@@ -44,29 +45,27 @@ def run_gemini_use_rmt(placement_policy, keep_gather, model_name: str, use_grad_
     run_fwd_bwd(runtime_mem_tracer, input_ids, label, criterion, runtime_mem_tracer)
     memstats = runtime_mem_tracer.memstats()
     runtime_tracer_non_model_data = runtime_mem_tracer._memstats._non_model_data_cuda_list
-    print('runtime tracer: ', runtime_tracer_non_model_data)
+    print('runtime tracer non model data points: ', len(runtime_tracer_non_model_data))

-    world_size = torch.distributed.get_world_size()
-    config_dict, _ = search_chunk_configuration(model, search_range_mb=1, search_interval_byte=100)
-    config_dict[world_size]['chunk_size'] = 5000
-    config_dict[world_size]['keep_gathered'] = keep_gather
-    chunk_manager = ChunkManager(config_dict)
-    gemini_manager = GeminiManager(placement_policy, chunk_manager, memstats)
-    model = ZeroDDP(model, gemini_manager, pin_memory=True)
+    model = GeminiDDP(model, device='cuda', placement_policy=placement_policy, search_range_mb=1, memstats=memstats)
+    zero_optim = GeminiAdamOptimizer(model, lr=1e-3, initial_scale=1)

     pg = ProcessGroup()
     set_seed(pg.dp_local_rank())
     for i, (input_ids, label) in enumerate(train_dataloader):
         # you can only test a single fwd + bwd.
         # after bwd param is grad for Gemini, due to the chunk reuse optimization.
-        if i > 1:
+        # print(f'iteration {i}')
+        if i > 4:
             break
         input_ids, label = input_ids.cuda(), label.cuda()

+        zero_optim.zero_grad()
         set_seed(42)
-        loss = run_fwd_bwd(model, input_ids, label, criterion, model)
+        loss = run_fwd_bwd(model, input_ids, label, criterion, zero_optim)
+        zero_optim.step()

-    gemini_non_model_data = gemini_manager._mem_stats_collector._memstats.non_model_data_list('cuda')
+    gemini_non_model_data = model.gemini_manager._mem_stats_collector._memstats.non_model_data_list('cuda')

     # print('gemini non model data:', gemini_non_model_data)

@@ -1,5 +1,4 @@
 from functools import partial
-from time import time

 import pytest
 import torch