mirror of https://github.com/hpcaitech/ColossalAI
[gemini] accelerate adjust_layout() (#878)
* add lru cache
* polish code
* update unit test
* fix sharded optim

parent 909211453b
commit c4d903e64a
@@ -6,6 +6,8 @@ from colossalai.gemini.tensor_utils import colo_model_data_tensor_move_inline, c
 from colossalai.gemini.stateful_tensor import StatefulTensor, TensorState
 from colossalai.gemini.tensor_placement_policy import TensorPlacementPolicy
 from typing import List
+from colossalai.logging import get_dist_logger
+from time import time


 class StatefulTensorMgr(object):
@@ -24,6 +26,8 @@ class StatefulTensorMgr(object):
         self._compute_idx: int = -1

         self._cpu_gpu_move_volume = 0
+        self._layout_time = 0
+        self._evict_time = 0
         self._warmup = True

     def register_stateful_tensor_list(self, tensor_list: List[StatefulTensor]) -> None:
@@ -42,6 +46,8 @@ class StatefulTensorMgr(object):
         self._warmup = False
         self._compute_idx = -1
         self._cpu_gpu_move_volume = 0
+        self._layout_time = 0
+        self._evict_time = 0

     def adjust_layout(self) -> None:
         """ Adjust the layout of statefuil tensor according to the information provided
@@ -49,25 +55,16 @@ class StatefulTensorMgr(object):
         """
         # find stateful tensor in state COMPUTE
         cuda_demand = StatefulTensor.GST_MGR.state_mem['cpu'][TensorState.COMPUTE]
-        move_to_cuda_tensor_list = []
-        hold_cuda_tensor_list = []
-        for tensor in self._stateful_tensor_list:
-            if tensor.state == TensorState.FREE:
-                continue
-
-            if tensor.device.type == 'cuda':
-                if tensor.state in [TensorState.HOLD, TensorState.HOLD_AFTER_BWD, TensorState.HOLD_AFTER_FWD]:
-                    hold_cuda_tensor_list.append(tensor)
-            elif tensor.device.type == 'cpu':
-                if tensor.state == TensorState.COMPUTE:
-                    move_to_cuda_tensor_list.append(tensor)
-            else:
-                raise RuntimeError
-        self._cpu_gpu_move_volume += self._tensor_placement_policy.evict_tensors(hold_cuda_tensor_list,
-                                                                                 cuda_demand=cuda_demand,
-                                                                                 warmup=self._warmup,
-                                                                                 compute_list=self._compute_list,
-                                                                                 compute_idx=self._compute_idx)
+        start = time()
+        move_to_cuda_tensor_list, hold_cuda_tensor_list = self._get_layout_info(self._compute_idx, self._warmup)
+        self._layout_time += time() - start
+        vol, evict_time = self._tensor_placement_policy.evict_tensors(hold_cuda_tensor_list,
+                                                                      cuda_demand=cuda_demand,
+                                                                      warmup=self._warmup,
+                                                                      compute_list=self._compute_list,
+                                                                      compute_idx=self._compute_idx)
+        self._cpu_gpu_move_volume += vol
+        self._evict_time += evict_time
         # move COMPUTE tensors to CUDA
         self._cpu_gpu_move_volume += cuda_demand
         for t in move_to_cuda_tensor_list:
@@ -83,3 +80,21 @@
         self._compute_idx += 1
         if self._warmup:
             self._compute_list.append(stateful_tensor)
+
+    @functools.lru_cache(maxsize=None)
+    def _get_layout_info(self, compute_idx: int, warmup: bool):
+        move_to_cuda_tensor_list = []
+        hold_cuda_tensor_list = []
+        for tensor in self._stateful_tensor_list:
+            if tensor.state == TensorState.FREE:
+                continue
+
+            if tensor.device.type == 'cuda':
+                if tensor.state in [TensorState.HOLD, TensorState.HOLD_AFTER_BWD, TensorState.HOLD_AFTER_FWD]:
+                    hold_cuda_tensor_list.append(tensor)
+            elif tensor.device.type == 'cpu':
+                if tensor.state == TensorState.COMPUTE:
+                    move_to_cuda_tensor_list.append(tensor)
+            else:
+                raise RuntimeError
+        return move_to_cuda_tensor_list, hold_cuda_tensor_list
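The _get_layout_info helper added above is wrapped in functools.lru_cache, so the full scan over every registered stateful tensor runs only once per distinct (compute_idx, warmup) pair; after the warmup iteration Gemini replays the same compute order, so later iterations reuse the cached result instead of rescanning. Below is a minimal, self-contained sketch of that memoization pattern using toy classes rather than the ColossalAI API; note that lru_cache on an instance method also keys on self, so each manager gets its own cache entries (and is kept alive by the cache).

import functools
from typing import List, Tuple


class ToyTensor:
    # Stand-in for a stateful tensor: just a device tag and a state tag.
    def __init__(self, device: str, state: str):
        self.device = device
        self.state = state


class ToyLayoutMgr:
    # Sketch of caching the per-step layout scan, keyed on (compute_idx, warmup).
    def __init__(self, tensors: List[ToyTensor]):
        self._tensors = tensors

    @functools.lru_cache(maxsize=None)
    def _get_layout_info(self, compute_idx: int, warmup: bool) -> Tuple[tuple, tuple]:
        move_to_cuda, hold_cuda = [], []
        for t in self._tensors:
            if t.device == 'cuda' and t.state == 'hold':
                hold_cuda.append(t)
            elif t.device == 'cpu' and t.state == 'compute':
                move_to_cuda.append(t)
        return tuple(move_to_cuda), tuple(hold_cuda)


mgr = ToyLayoutMgr([ToyTensor('cpu', 'compute'), ToyTensor('cuda', 'hold')])
mgr._get_layout_info(0, True)
mgr._get_layout_info(0, True)                       # replayed step: served from the cache
print(ToyLayoutMgr._get_layout_info.cache_info())   # hits=1, misses=1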
@@ -1,4 +1,5 @@
 from abc import ABC, abstractmethod
+from time import time
 from typing import List, Optional
 import torch
 from colossalai.utils import get_current_device
@@ -8,6 +9,7 @@ from colossalai.gemini.tensor_utils import colo_model_data_tensor_move_inline, c
 from colossalai.gemini.stateful_tensor import StatefulTensor
 from colossalai.gemini.memory_tracer import MemStatsCollector
 from typing import Type
+import functools


 class TensorPlacementPolicy(ABC):
@@ -31,7 +33,7 @@ class CPUTensorPlacementPolicy(TensorPlacementPolicy):
         for t in hold_cuda_tensor_list:
             colo_model_data_tensor_move_inline(t, self.device)
             volume += t.payload.numel() * t.payload.element_size()
-        return volume
+        return volume, 0


 class CUDATensorPlacementPolicy(TensorPlacementPolicy):
@@ -41,7 +43,7 @@ class CUDATensorPlacementPolicy(TensorPlacementPolicy):
         super().__init__(get_current_device(), mem_stats_collector=mem_stats_collector)

     def evict_tensors(self, hold_cuda_tensor_list: List[StatefulTensor], **kwargs) -> int:
-        return 0
+        return 0, 0


 class AutoTensorPlacementPolicy(TensorPlacementPolicy):
@@ -51,7 +53,7 @@ class AutoTensorPlacementPolicy(TensorPlacementPolicy):
         # model data will use 1-self._warmup_non_model_data_ratio CUDA memory in warmup phase
         # TODO(ver217): make these args configurable
         self._warmup_non_model_data_ratio: float = 0.8
-        self._steady_cuda_cap_ratio: float = 0.8
+        self._steady_cuda_cap_ratio: float = 0.9

     def evict_tensors(self,
                       hold_cuda_tensor_list: List[StatefulTensor],
|
|||
Returns:
|
||||
int: the volume of memory that is evicted
|
||||
"""
|
||||
start = time()
|
||||
cuda_capacity = colo_device_memory_capacity(get_current_device())
|
||||
used_cuda_model_data = StatefulTensor.GST_MGR.total_mem['cuda']
|
||||
if warmup:
|
||||
|
@@ -87,20 +90,18 @@ class AutoTensorPlacementPolicy(TensorPlacementPolicy):
             cuda_capacity *= self._steady_cuda_cap_ratio
         total_cuda_model_data = cuda_capacity - max_cuda_non_model_data_per_period
         avail_cuda_model_data = total_cuda_model_data - used_cuda_model_data

         freed_cuda_model_data = 0
+        end = time()
         if avail_cuda_model_data < cuda_demand:
             # Move cuda_demand - avail_cuda_model_data volume of tensors
             # to_free_cuda_model_data = cuda_demand - avail_cuda_model_data
             to_free_cuda_model_data = cuda_demand - avail_cuda_model_data
             to_free_tensor_list = hold_cuda_tensor_list
             if not warmup:
-                next_compute_idx = {t: len(compute_list) for t in hold_cuda_tensor_list}
-                for i in range(len(compute_list) - 1, compute_idx, -1):
-                    if compute_list[i] in next_compute_idx:
-                        next_compute_idx[compute_list[i]] = i
-                next_compute_idx = sorted(next_compute_idx.items(), key=lambda pair: pair[1], reverse=True)
-                to_free_tensor_list = [t for (t, idx) in next_compute_idx]
+                to_free_tensor_list = self._sort_hold_cuda_tensors(tuple(hold_cuda_tensor_list), compute_idx,
+                                                                   tuple(compute_list))
+                # print(self._sort_hold_cuda_tensors.cache_info())
+            end = time()
             for t in to_free_tensor_list:
                 if freed_cuda_model_data >= to_free_cuda_model_data:
                     break
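The eviction budget computed in the hunk above is plain arithmetic: scale the CUDA capacity by _steady_cuda_cap_ratio (during warmup a fixed fraction of capacity is reserved for non-model data instead), subtract the expected non-model-data usage of the coming period to get the model-data budget, subtract what is already resident, and whatever the upcoming COMPUTE demand cannot fit must be evicted. A small standalone sketch of that calculation, with illustrative names and numbers rather than the library's API:

def cuda_to_free_volume(cuda_capacity: float,
                        next_period_non_model_data: float,
                        used_cuda_model_data: float,
                        cuda_demand: float,
                        steady_cuda_cap_ratio: float = 0.9) -> float:
    # All quantities in bytes.  Mirrors the steady-state budget arithmetic above:
    # usable capacity, minus expected non-model data, minus what is already on CUDA,
    # is the room left for the tensors about to enter COMPUTE.
    total_cuda_model_data = cuda_capacity * steady_cuda_cap_ratio - next_period_non_model_data
    avail_cuda_model_data = total_cuda_model_data - used_cuda_model_data
    return max(0.0, cuda_demand - avail_cuda_model_data)


# e.g. a 16 GB device, 2 GB of non-model data expected next period, 11 GB of model
# data already resident, and 3 GB of parameters needed on CUDA for the next op:
print(cuda_to_free_volume(16e9, 2e9, 11e9, 3e9))    # 1.6e9 -> evict roughly 1.6 GB of HOLD tensors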
@@ -110,8 +111,17 @@ class AutoTensorPlacementPolicy(TensorPlacementPolicy):
             raise RuntimeError(
                 f"Adjust layout failed! No enough CUDA memory! Need {to_free_cuda_model_data}, freed {freed_cuda_model_data}"
             )
-        return freed_cuda_model_data
+        return freed_cuda_model_data, end - start
+
+    @staticmethod
+    @functools.lru_cache(maxsize=None)
+    def _sort_hold_cuda_tensors(hold_cuda_tensors: tuple, compute_idx: int, compute_list: tuple) -> list:
+        next_compute_idx = {t: len(compute_list) for t in hold_cuda_tensors}
+        for i in range(len(compute_list) - 1, compute_idx, -1):
+            if compute_list[i] in next_compute_idx:
+                next_compute_idx[compute_list[i]] = i
+        next_compute_idx = sorted(next_compute_idx.items(), key=lambda pair: pair[1], reverse=True)
+        return [t for (t, idx) in next_compute_idx]


 class TensorPlacementPolicyFactory:
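_sort_hold_cuda_tensors orders eviction candidates by the step at which each one is next needed, so the tensor used furthest in the future is evicted first (a Belady-style heuristic). It is a @staticmethod taking tuples because functools.lru_cache requires hashable arguments, which is also why the caller wraps its lists in tuple(...). A self-contained sketch of the same ordering on plain strings (toy data, not ColossalAI objects):

import functools


@functools.lru_cache(maxsize=None)
def sort_by_next_use(held: tuple, compute_idx: int, compute_list: tuple) -> tuple:
    # Default "never used again": a position past the end of the schedule.
    next_use = {t: len(compute_list) for t in held}
    # Walk the remaining schedule backwards; the last write per item is its
    # earliest upcoming use, exactly as in the snippet above.
    for i in range(len(compute_list) - 1, compute_idx, -1):
        if compute_list[i] in next_use:
            next_use[compute_list[i]] = i
    ranked = sorted(next_use.items(), key=lambda pair: pair[1], reverse=True)
    return tuple(t for (t, _) in ranked)


schedule = ('p0', 'p1', 'p2', 'p0', 'p1')    # mirrors the "0 1 2 0 1" compute order in the unit test below
print(sort_by_next_use(('p0', 'p1', 'p2'), 0, schedule))
# ('p0', 'p2', 'p1'): p0 is next needed at step 3, p2 at step 2, p1 at step 1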
@@ -285,7 +285,7 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
                     shard_mem = self.master_params[p].payload.numel() * self.master_params[p].payload.element_size()
                     if fp32_shards_used_cuda_margin_mem + shard_mem < fp32_shards_available_cuda_margin_mem:
                         colo_model_data_tensor_move_inline(self.master_params[p], torch.cuda.current_device())
-                        p.grad.data = p.grad.data.to(torch.cuda.current_device())
+                        colo_model_data_tensor_move_inline(p.colo_attr.saved_grad, torch.cuda.current_device())
                         p.colo_attr.offload_grad = False
                         fp32_shards_used_cuda_margin_mem += shard_mem
@@ -297,7 +297,7 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
                 p.colo_attr.saved_grad.trans_state(TensorState.COMPUTE)
                 # If reuse_fp16_shard, grad fp16 which wasn't be offloaded may be evicted to CPU
                 if not p.colo_attr.offload_grad:
-                    colo_model_data_tensor_move_inline(p.colo_attr.grad_payload, torch.cuda.current_device())
+                    colo_model_data_tensor_move_inline(p.colo_attr.saved_grad, torch.cuda.current_device())
                 # FIXME(ver217): p.data here is an empty tensor on CUDA and has no useful infomation
                 # If we change p.grad directly
                 # it may raise error because of different shape/dtype/device of p.data and p.grad
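Both ShardedOptimizerV2 hunks replace moves of the raw gradient storage (p.grad.data = p.grad.data.to(...) and the grad payload) with colo_model_data_tensor_move_inline on p.colo_attr.saved_grad: the move now goes through the stateful wrapper, so the device and volume bookkeeping that Gemini's placement policy relies on stays in sync with where the data actually lives. A toy illustration of why that matters, using a hypothetical wrapper class rather than the ColossalAI types:

import torch


class ToyStatefulTensor:
    # Toy stand-in: holds a payload plus the device the memory manager believes it is on.
    def __init__(self, payload: torch.Tensor):
        self.payload = payload
        self.device = payload.device


def toy_move_inline(t: ToyStatefulTensor, device: torch.device) -> None:
    # Analogue of colo_model_data_tensor_move_inline: payload and tracked device move together.
    t.payload = t.payload.to(device)
    t.device = device


grad = ToyStatefulTensor(torch.zeros(4))

# Reassigning only the payload (as in the removed p.grad.data = ... line) would leave
# grad.device stale, so later capacity accounting would count the tensor on the wrong side.
# Moving through the wrapper keeps both views consistent:
toy_move_inline(grad, torch.device('cpu'))
assert grad.device == grad.payload.device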
@@ -114,5 +114,6 @@ class ZeroHook(BaseOpHook):
     def post_iter(self):
         if self._stateful_tensor_mgr:
             self.logger.info(
-                f"CPU-GPU data moving this iteration {self._stateful_tensor_mgr.cpu_gpu_move_volume/1e9} GB", ranks=[0])
+                f"CPU-GPU data moving this iteration {self._stateful_tensor_mgr.cpu_gpu_move_volume/1e9} GB, get layout info time: {self._stateful_tensor_mgr._layout_time}, evict cpu time: {self._stateful_tensor_mgr._evict_time}",
+                ranks=[0])
             self._stateful_tensor_mgr.finish_iter()
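The extended log line above reports the two new counters, _layout_time and _evict_time, that adjust_layout() accumulates around the cached layout scan and the now-timed evict_tensors() call. A minimal sketch of that accumulate-and-report pattern, with illustrative placeholders rather than the ZeroHook API:

from time import time

layout_time = 0.0     # analogue of StatefulTensorMgr._layout_time
evict_time = 0.0      # analogue of StatefulTensorMgr._evict_time
move_volume = 0       # bytes moved between CPU and GPU this iteration

for _ in range(3):    # e.g. three ops in one iteration
    start = time()
    # ... gather layout info for this op ...
    layout_time += time() - start

    # evict_tensors() now returns (freed_volume, elapsed) instead of a bare volume.
    freed_volume, elapsed = 0, 0.0    # placeholder values for the sketch
    move_volume += freed_volume
    evict_time += elapsed

print(f"CPU-GPU data moving this iteration {move_volume / 1e9} GB, "
      f"get layout info time: {layout_time}, evict cpu time: {evict_time}")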
@@ -45,8 +45,8 @@ def run_stm():
     mem_collector = MemStatsCollector()
     tensor_placement_policy = AutoTensorPlacementPolicy(mem_stats_collector=mem_collector)
     stateful_tensor_mgr = StatefulTensorMgr(tensor_placement_policy)
-    for p in model.parameters():
-        stateful_tensor_mgr.register_stateful_param(p.colo_attr)
+    stateful_tensors = [p.colo_attr.sharded_data_tensor for p in model.parameters()]
+    stateful_tensor_mgr.register_stateful_tensor_list(stateful_tensors)

     mem_collector.start_collection()
     # Compute order: 0 1 2 0 1
@@ -67,7 +67,7 @@ def run_stm():
     apply_adjust(model, model.p1, [model.p1, model.p2], stateful_tensor_mgr)
     mem_collector.sample_model_data()
     mem_collector.finish_collection()
-    stateful_tensor_mgr.reset()
+    stateful_tensor_mgr.finish_iter()

     # warmup done
     # only 2 params can be on CUDA