mirror of https://github.com/hpcaitech/ColossalAI
[gemini] gemini mgr supports "cpu" placement policy (#1118)
* update gemini mgr * update chunk * add docstr * polish placement policy * update test chunk * update test zero * polish unit test * remove useless unit testpull/1120/head
parent
f99f56dff4
commit
7d14b473f0
|
@ -1,4 +1,4 @@
|
||||||
import functools
|
import torch
|
||||||
from .memory_tracer.memstats_collector import MemStatsCollectorV2
|
from .memory_tracer.memstats_collector import MemStatsCollectorV2
|
||||||
from typing import List, Optional, Tuple
|
from typing import List, Optional, Tuple
|
||||||
from time import time
|
from time import time
|
||||||
|
@ -15,8 +15,6 @@ class GeminiManager:
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, placement_policy: str, chunk_manager: ChunkManager) -> None:
|
def __init__(self, placement_policy: str, chunk_manager: ChunkManager) -> None:
|
||||||
# TODO: remove assert
|
|
||||||
assert placement_policy == 'cuda', 'placement_policy can only be "cuda" now'
|
|
||||||
assert placement_policy in PlacementPolicyFactory.get_polocy_names()
|
assert placement_policy in PlacementPolicyFactory.get_polocy_names()
|
||||||
policy_cls = PlacementPolicyFactory.create(placement_policy)
|
policy_cls = PlacementPolicyFactory.create(placement_policy)
|
||||||
self._chunk_manager = chunk_manager
|
self._chunk_manager = chunk_manager
|
||||||
|
@ -111,3 +109,7 @@ class GeminiManager:
|
||||||
@property
|
@property
|
||||||
def is_cuda_margin_mem_avail(self) -> bool:
|
def is_cuda_margin_mem_avail(self) -> bool:
|
||||||
return self._placement_policy.need_mem_stats
|
return self._placement_policy.need_mem_stats
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def get_default_device(policy_name: str) -> torch.device:
|
||||||
|
return PlacementPolicyFactory.get_default_device(policy_name)
|
||||||
|
|
|
@ -34,10 +34,11 @@ class CPUPlacementPolicy(PlacementPolicy):
|
||||||
|
|
||||||
def evict_tensors(self, can_evict_chunks: List[Chunk], **kwargs) -> int:
|
def evict_tensors(self, can_evict_chunks: List[Chunk], **kwargs) -> int:
|
||||||
volume = 0
|
volume = 0
|
||||||
|
start = time()
|
||||||
for chunk in can_evict_chunks:
|
for chunk in can_evict_chunks:
|
||||||
self.chunk_manager.move_chunk(chunk, torch.device('cpu'))
|
self.chunk_manager.move_chunk(chunk, torch.device('cpu'), update_ptr=False)
|
||||||
volume += chunk.mem
|
volume += chunk.mem
|
||||||
return volume, 0
|
return volume, time() - start
|
||||||
|
|
||||||
|
|
||||||
class CUDAPlacementPolicy(PlacementPolicy):
|
class CUDAPlacementPolicy(PlacementPolicy):
|
||||||
|
@ -115,7 +116,7 @@ class AutoPlacementPolicy(PlacementPolicy):
|
||||||
if freed_cuda_model_data >= to_free_cuda_model_data:
|
if freed_cuda_model_data >= to_free_cuda_model_data:
|
||||||
break
|
break
|
||||||
freed_cuda_model_data += chunk.mem
|
freed_cuda_model_data += chunk.mem
|
||||||
self.chunk_manager.move_chunk(chunk, torch.device('cpu'))
|
self.chunk_manager.move_chunk(chunk, torch.device('cpu'), update_ptr=False)
|
||||||
if freed_cuda_model_data < to_free_cuda_model_data:
|
if freed_cuda_model_data < to_free_cuda_model_data:
|
||||||
raise RuntimeError(
|
raise RuntimeError(
|
||||||
f"Adjust layout failed! No enough CUDA memory! Need {to_free_cuda_model_data}, freed {freed_cuda_model_data}"
|
f"Adjust layout failed! No enough CUDA memory! Need {to_free_cuda_model_data}, freed {freed_cuda_model_data}"
|
||||||
|
|
|
@ -100,6 +100,8 @@ class ColoDDPV2(ColoDDP):
|
||||||
self.fp32_params = []
|
self.fp32_params = []
|
||||||
self.overflow_counter = 0
|
self.overflow_counter = 0
|
||||||
self.grads_device: Dict[torch.Tensor, torch.device] = {}
|
self.grads_device: Dict[torch.Tensor, torch.device] = {}
|
||||||
|
self.chunk_manager.create_group('fp16_param', force_data_on_cuda=True)
|
||||||
|
self.chunk_manager.create_group('fp32_param')
|
||||||
# TODO: get param order and filter unused params
|
# TODO: get param order and filter unused params
|
||||||
for p in module.parameters():
|
for p in module.parameters():
|
||||||
assert p.dtype == torch.half
|
assert p.dtype == torch.half
|
||||||
|
|
|
@ -36,8 +36,21 @@ class ChunkFullError(Exception):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
class Chunk:
|
def is_storage_empty(tensor: torch.Tensor) -> bool:
|
||||||
|
return tensor.storage().size() == 0
|
||||||
|
|
||||||
|
|
||||||
|
def free_storage(tensor: torch.Tensor) -> None:
|
||||||
|
if not is_storage_empty(tensor):
|
||||||
|
tensor.storage().resize_(0)
|
||||||
|
|
||||||
|
|
||||||
|
def alloc_storage(tensor: torch.Tensor) -> None:
|
||||||
|
if is_storage_empty(tensor):
|
||||||
|
tensor.storage().resize_(tensor.numel())
|
||||||
|
|
||||||
|
|
||||||
|
class Chunk:
|
||||||
"""
|
"""
|
||||||
A chunk is a contiguous memory space which contains multiple tensors.
|
A chunk is a contiguous memory space which contains multiple tensors.
|
||||||
|
|
||||||
|
@ -46,26 +59,37 @@ class Chunk:
|
||||||
src_rank (int): the process which owns the chunk
|
src_rank (int): the process which owns the chunk
|
||||||
dtype (torch.dtype): the data type of the chunk
|
dtype (torch.dtype): the data type of the chunk
|
||||||
init_device (torch.device): optional, the device where the tensor is initialized. The default value is None, which is the current GPU.
|
init_device (torch.device): optional, the device where the tensor is initialized. The default value is None, which is the current GPU.
|
||||||
|
force_data_on_cuda (bool): optional, if True, chunk.data is always on cuda. Defaults to False.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self,
|
def __init__(self,
|
||||||
chunk_size: int,
|
chunk_size: int,
|
||||||
src_rank: int,
|
src_rank: int,
|
||||||
dtype: torch.dtype,
|
dtype: torch.dtype,
|
||||||
init_device: Optional[torch.device] = None) -> None:
|
init_device: Optional[torch.device] = None,
|
||||||
|
force_data_on_cuda: bool = False) -> None:
|
||||||
self.size = chunk_size
|
self.size = chunk_size
|
||||||
self.utilized_size = 0
|
self.utilized_size = 0
|
||||||
self.src_rank = src_rank
|
self.src_rank = src_rank
|
||||||
self.is_src_rank = gpc.get_local_rank(ParallelMode.DATA) == src_rank
|
self.is_src_rank = gpc.get_local_rank(ParallelMode.DATA) == src_rank
|
||||||
self.global_src_rank = gpc.get_ranks_in_group(ParallelMode.DATA)[src_rank]
|
self.global_src_rank = gpc.get_ranks_in_group(ParallelMode.DATA)[src_rank]
|
||||||
self.dtype = dtype
|
self.dtype = dtype
|
||||||
self.device = init_device or get_current_device()
|
device = init_device or get_current_device()
|
||||||
self.data = torch.empty(chunk_size, dtype=dtype, device=self.device)
|
if force_data_on_cuda:
|
||||||
|
self.data = torch.empty(chunk_size, dtype=dtype, device=get_current_device())
|
||||||
|
self._cpu_data = torch.empty(chunk_size, dtype=dtype)
|
||||||
|
if device.type == 'cuda':
|
||||||
|
free_storage(self._cpu_data)
|
||||||
|
else:
|
||||||
|
free_storage(self.data)
|
||||||
|
else:
|
||||||
|
self.data = torch.empty(chunk_size, dtype=dtype, device=device)
|
||||||
|
self._cpu_data = None
|
||||||
|
|
||||||
# we only keep the chunk in full in the process by which the tensor is owned
|
# we only keep the chunk in full in the process by which the tensor is owned
|
||||||
if not self.is_src_rank:
|
if not self.is_src_rank:
|
||||||
self.data.storage().resize_(0)
|
free_storage(self._payload)
|
||||||
|
|
||||||
# each tensor is associated with a TensorInfo to track meta info
|
# each tensor is associated with a TensorInfo to track meta info
|
||||||
self.tensors_info: Dict[torch.Tensor, TensorInfo] = {}
|
self.tensors_info: Dict[torch.Tensor, TensorInfo] = {}
|
||||||
self.mem = self.size * self.data.element_size()
|
self.mem = self.size * self.data.element_size()
|
||||||
|
@ -83,16 +107,16 @@ class Chunk:
|
||||||
# raise exception when the chunk size is exceeded
|
# raise exception when the chunk size is exceeded
|
||||||
if new_utilized_size > self.size:
|
if new_utilized_size > self.size:
|
||||||
raise ChunkFullError
|
raise ChunkFullError
|
||||||
|
|
||||||
# set tensor state
|
# set tensor state
|
||||||
tensor_state = TensorState.FREE
|
tensor_state = TensorState.FREE
|
||||||
|
|
||||||
# if the process owns the rank, then copy the tensor to its chunk buffer
|
# if the process owns the rank, then copy the tensor to its chunk buffer
|
||||||
# otherwise set its storage size to 0 to reduce memory consumption
|
# otherwise set its storage size to 0 to reduce memory consumption
|
||||||
if self.is_src_rank:
|
if self.is_src_rank:
|
||||||
self.data[self.utilized_size:new_utilized_size].copy_(tensor.view(-1))
|
self._payload[self.utilized_size:new_utilized_size].copy_(tensor.view(-1))
|
||||||
tensor_state = TensorState.HOLD
|
tensor_state = TensorState.HOLD
|
||||||
tensor.data = self.data[self.utilized_size:new_utilized_size].view(tensor.shape)
|
tensor.data = self._payload[self.utilized_size:new_utilized_size].view(tensor.shape)
|
||||||
else:
|
else:
|
||||||
tensor.storage().resize_(0)
|
tensor.storage().resize_(0)
|
||||||
self.tensors_info[tensor] = TensorInfo(tensor_state, self.utilized_size, new_utilized_size)
|
self.tensors_info[tensor] = TensorInfo(tensor_state, self.utilized_size, new_utilized_size)
|
||||||
|
@ -103,12 +127,12 @@ class Chunk:
|
||||||
Release the memory space on processes which do not own the chunk.
|
Release the memory space on processes which do not own the chunk.
|
||||||
"""
|
"""
|
||||||
if not self.is_src_rank:
|
if not self.is_src_rank:
|
||||||
self.data.storage().resize_(0)
|
free_storage(self._payload)
|
||||||
self._update_tensors_state(TensorState.FREE)
|
self._update_tensors_state(TensorState.FREE)
|
||||||
|
|
||||||
def _update_tensors_ptr(self) -> None:
|
def _update_tensors_ptr(self) -> None:
|
||||||
for tensor, tensor_info in self.tensors_info.items():
|
for tensor, tensor_info in self.tensors_info.items():
|
||||||
tensor.data = self.data[tensor_info.offset:tensor_info.end].view(tensor.shape)
|
tensor.data = self._payload[tensor_info.offset:tensor_info.end].view(tensor.shape)
|
||||||
|
|
||||||
def _update_tensors_state(self, next_state: TensorState, prev_state: Optional[TensorState] = None):
|
def _update_tensors_state(self, next_state: TensorState, prev_state: Optional[TensorState] = None):
|
||||||
for tensor_info in self.tensors_info.values():
|
for tensor_info in self.tensors_info.values():
|
||||||
|
@ -122,8 +146,8 @@ class Chunk:
|
||||||
# recover the chunk on non-owner processes
|
# recover the chunk on non-owner processes
|
||||||
# and broadcast the chunk from the source to all processes
|
# and broadcast the chunk from the source to all processes
|
||||||
if not self.is_src_rank:
|
if not self.is_src_rank:
|
||||||
self.data.storage().resize_(self.size)
|
alloc_storage(self._payload)
|
||||||
self.data.data = self.data.to(get_current_device())
|
self.move_device(get_current_device(), update_ptr=False)
|
||||||
dist.broadcast(self.data, self.global_src_rank, group=gpc.get_group(ParallelMode.DATA))
|
dist.broadcast(self.data, self.global_src_rank, group=gpc.get_group(ParallelMode.DATA))
|
||||||
|
|
||||||
# update tensor meta info
|
# update tensor meta info
|
||||||
|
@ -131,15 +155,32 @@ class Chunk:
|
||||||
if not self.is_src_rank:
|
if not self.is_src_rank:
|
||||||
self._update_tensors_state(TensorState.HOLD, prev_state=TensorState.FREE)
|
self._update_tensors_state(TensorState.HOLD, prev_state=TensorState.FREE)
|
||||||
|
|
||||||
def move_device(self, device: torch.device) -> None:
|
def move_device(self, device: torch.device, update_ptr: bool = True) -> None:
|
||||||
"""
|
"""
|
||||||
Move the chunk to a target device.
|
Move the chunk to a target device.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
device (torch.device): the target device for data movement.
|
device (torch.device): the target device for data movement.
|
||||||
"""
|
"""
|
||||||
self.data.data = self.data.to(device)
|
if self._payload.device == device:
|
||||||
self._update_tensors_ptr()
|
return
|
||||||
|
if self._cpu_data is None:
|
||||||
|
self.data.data = self.data.to(device)
|
||||||
|
else:
|
||||||
|
if device.type == 'cuda':
|
||||||
|
# cpu -> cuda
|
||||||
|
src = self._cpu_data
|
||||||
|
dest = self.data
|
||||||
|
else:
|
||||||
|
# cuda -> cpu
|
||||||
|
src = self.data
|
||||||
|
dest = self._cpu_data
|
||||||
|
alloc_storage(dest)
|
||||||
|
dest.copy_(src)
|
||||||
|
free_storage(src)
|
||||||
|
|
||||||
|
if update_ptr:
|
||||||
|
self._update_tensors_ptr()
|
||||||
|
|
||||||
def reduce(self, is_all_reduce: bool = False) -> None:
|
def reduce(self, is_all_reduce: bool = False) -> None:
|
||||||
"""
|
"""
|
||||||
|
@ -148,7 +189,7 @@ class Chunk:
|
||||||
Args:
|
Args:
|
||||||
is_all_reduce (bool): optional, whether to all-reduce the chunk. The default is false.
|
is_all_reduce (bool): optional, whether to all-reduce the chunk. The default is false.
|
||||||
"""
|
"""
|
||||||
self.data.data = self.data.to(get_current_device())
|
self.move_device(get_current_device(), update_ptr=False)
|
||||||
if is_all_reduce:
|
if is_all_reduce:
|
||||||
dist.all_reduce(self.data, group=gpc.get_group(ParallelMode.DATA))
|
dist.all_reduce(self.data, group=gpc.get_group(ParallelMode.DATA))
|
||||||
else:
|
else:
|
||||||
|
@ -187,8 +228,8 @@ class Chunk:
|
||||||
data_slice (torch.Tensor): the tensor to be copied to the chunk
|
data_slice (torch.Tensor): the tensor to be copied to the chunk
|
||||||
"""
|
"""
|
||||||
tensor_info = self.tensors_info[tensor]
|
tensor_info = self.tensors_info[tensor]
|
||||||
self.data[tensor_info.offset:tensor_info.end].copy_(data_slice.view(-1))
|
self._payload[tensor_info.offset:tensor_info.end].copy_(data_slice.view(-1))
|
||||||
tensor.data = self.data[tensor_info.offset:tensor_info.end].view(tensor.shape)
|
tensor.data = self._payload[tensor_info.offset:tensor_info.end].view(tensor.shape)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def can_release(self) -> bool:
|
def can_release(self) -> bool:
|
||||||
|
@ -225,7 +266,7 @@ class Chunk:
|
||||||
"""
|
"""
|
||||||
Check whether the chunk is empty.
|
Check whether the chunk is empty.
|
||||||
"""
|
"""
|
||||||
return self.data.storage().size() == 0
|
return is_storage_empty(self._payload)
|
||||||
|
|
||||||
def __repr__(self) -> str:
|
def __repr__(self) -> str:
|
||||||
return f'Chunk: src rank={self.src_rank} ,size={self.size}, utilization={self.utilized_size/self.size*100:.2f}%, freed={self.is_empty}, tensor states={[info.state.name for info in self.tensors_info.values()]}'
|
return f'Chunk: src rank={self.src_rank} ,size={self.size}, utilization={self.utilized_size/self.size*100:.2f}%, freed={self.is_empty}, tensor states={[info.state.name for info in self.tensors_info.values()]}'
|
||||||
|
@ -235,8 +276,8 @@ class Chunk:
|
||||||
"""
|
"""
|
||||||
Check if the chunk has inf or nan values.
|
Check if the chunk has inf or nan values.
|
||||||
"""
|
"""
|
||||||
return torch.isinf(self.data[:self.utilized_size]).any().item() or \
|
return torch.isinf(self._payload[:self.utilized_size]).any().item() or \
|
||||||
torch.isnan(self.data[:self.utilized_size]).any().item()
|
torch.isnan(self._payload[:self.utilized_size]).any().item()
|
||||||
|
|
||||||
def copy_(self, dest_chunk: 'Chunk'):
|
def copy_(self, dest_chunk: 'Chunk'):
|
||||||
"""
|
"""
|
||||||
|
@ -246,7 +287,7 @@ class Chunk:
|
||||||
assert not dest_chunk.is_empty
|
assert not dest_chunk.is_empty
|
||||||
assert self.size == dest_chunk.size
|
assert self.size == dest_chunk.size
|
||||||
assert self.utilized_size == dest_chunk.utilized_size
|
assert self.utilized_size == dest_chunk.utilized_size
|
||||||
self.data.copy_(dest_chunk.data)
|
self._payload.copy_(dest_chunk._payload)
|
||||||
self._update_tensors_ptr()
|
self._update_tensors_ptr()
|
||||||
|
|
||||||
@property
|
@property
|
||||||
|
@ -254,7 +295,7 @@ class Chunk:
|
||||||
"""
|
"""
|
||||||
Get the device type of the chunk.
|
Get the device type of the chunk.
|
||||||
"""
|
"""
|
||||||
return self.data.device.type
|
return self._payload.device.type
|
||||||
|
|
||||||
def __hash__(self) -> int:
|
def __hash__(self) -> int:
|
||||||
return hash(id(self))
|
return hash(id(self))
|
||||||
|
@ -265,6 +306,12 @@ class Chunk:
|
||||||
def get_tensors(self) -> List[torch.Tensor]:
|
def get_tensors(self) -> List[torch.Tensor]:
|
||||||
return list(self.tensors_info.keys())
|
return list(self.tensors_info.keys())
|
||||||
|
|
||||||
|
@property
|
||||||
|
def _payload(self) -> torch.Tensor:
|
||||||
|
if self._cpu_data is None or is_storage_empty(self._cpu_data):
|
||||||
|
return self.data
|
||||||
|
return self._cpu_data
|
||||||
|
|
||||||
|
|
||||||
class ChunkManager:
|
class ChunkManager:
|
||||||
"""
|
"""
|
||||||
|
@ -285,6 +332,7 @@ class ChunkManager:
|
||||||
self.enable_distributed_storage = enable_distributed_storage
|
self.enable_distributed_storage = enable_distributed_storage
|
||||||
self.device = init_device or get_current_device()
|
self.device = init_device or get_current_device()
|
||||||
self.chunk_groups: Dict[str, Deque[Chunk]] = {}
|
self.chunk_groups: Dict[str, Deque[Chunk]] = {}
|
||||||
|
self.groups_force_data_on_cuda: Dict[str, bool] = {}
|
||||||
self.tensor_chunk_map: Dict[torch.Tensor, Chunk] = {}
|
self.tensor_chunk_map: Dict[torch.Tensor, Chunk] = {}
|
||||||
self.accessed_chunks: Set[Chunk] = set()
|
self.accessed_chunks: Set[Chunk] = set()
|
||||||
self.lazy_release_tensors: List[torch.Tensor] = []
|
self.lazy_release_tensors: List[torch.Tensor] = []
|
||||||
|
@ -292,6 +340,17 @@ class ChunkManager:
|
||||||
self.rank_load: Dict[str, torch.Tensor] = {}
|
self.rank_load: Dict[str, torch.Tensor] = {}
|
||||||
self.total_mem: Dict[str, int] = {'cpu': 0, 'cuda': 0}
|
self.total_mem: Dict[str, int] = {'cpu': 0, 'cuda': 0}
|
||||||
|
|
||||||
|
def create_group(self, group_name: str, force_data_on_cuda: bool = False) -> None:
|
||||||
|
"""Create a chunk group.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
group_name (str): group name
|
||||||
|
force_data_on_cuda (bool, optional): If True, the data of chunks in this group is always on cuda.. Defaults to False.
|
||||||
|
"""
|
||||||
|
assert group_name not in self.chunk_groups
|
||||||
|
self.chunk_groups[group_name] = deque()
|
||||||
|
self.groups_force_data_on_cuda[group_name] = force_data_on_cuda
|
||||||
|
|
||||||
def append_tensor(self, tensor: torch.Tensor, group_name: str) -> None:
|
def append_tensor(self, tensor: torch.Tensor, group_name: str) -> None:
|
||||||
"""
|
"""
|
||||||
Append a tensor to a chunk.
|
Append a tensor to a chunk.
|
||||||
|
@ -304,19 +363,20 @@ class ChunkManager:
|
||||||
if self.chunk_size is not None and tensor.numel() > self.chunk_size:
|
if self.chunk_size is not None and tensor.numel() > self.chunk_size:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f'Cannot create chunk, got tensor numel ({tensor.numel()}) > chunk size ({self.chunk_size})')
|
f'Cannot create chunk, got tensor numel ({tensor.numel()}) > chunk size ({self.chunk_size})')
|
||||||
if group_name not in self.chunk_groups:
|
|
||||||
self.chunk_groups[group_name] = deque()
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# append the tensor to the last chunk
|
# append the tensor to the last chunk
|
||||||
self.chunk_groups[group_name][-1].append(tensor)
|
self.chunk_groups[group_name][-1].append(tensor)
|
||||||
except (IndexError, ChunkFullError):
|
except (IndexError, ChunkFullError):
|
||||||
# the except statement will be triggered when there is no chunk or
|
# the except statement will be triggered when there is no chunk or
|
||||||
# the last chunk in the chunk group is full
|
# the last chunk in the chunk group is full
|
||||||
# this will create a new chunk and allocate this chunk to its corresponding process
|
# this will create a new chunk and allocate this chunk to its corresponding process
|
||||||
chunk_size = self.chunk_size or tensor.numel()
|
chunk_size = self.chunk_size or tensor.numel()
|
||||||
src_rank = self._get_next_src_rank(group_name)
|
src_rank = self._get_next_src_rank(group_name)
|
||||||
chunk = Chunk(chunk_size, src_rank, tensor.dtype, self.device)
|
chunk = Chunk(chunk_size,
|
||||||
|
src_rank,
|
||||||
|
tensor.dtype,
|
||||||
|
self.device,
|
||||||
|
force_data_on_cuda=self.groups_force_data_on_cuda[group_name])
|
||||||
|
|
||||||
if self.enable_distributed_storage and self.chunk_size is None:
|
if self.enable_distributed_storage and self.chunk_size is None:
|
||||||
self.rank_load[group_name][src_rank] += chunk_size
|
self.rank_load[group_name][src_rank] += chunk_size
|
||||||
|
@ -387,7 +447,7 @@ class ChunkManager:
|
||||||
# update the memory consumption after releasing
|
# update the memory consumption after releasing
|
||||||
self.total_mem[chunk.device_type] -= chunk.mem
|
self.total_mem[chunk.device_type] -= chunk.mem
|
||||||
|
|
||||||
def move_chunk(self, chunk: Chunk, device: torch.device) -> None:
|
def move_chunk(self, chunk: Chunk, device: torch.device, update_ptr: bool = True) -> None:
|
||||||
"""
|
"""
|
||||||
Move the chunk to the target device.
|
Move the chunk to the target device.
|
||||||
|
|
||||||
|
@ -399,7 +459,7 @@ class ChunkManager:
|
||||||
return
|
return
|
||||||
if chunk.can_move_device and not chunk.is_empty:
|
if chunk.can_move_device and not chunk.is_empty:
|
||||||
self.total_mem[chunk.device_type] -= chunk.mem
|
self.total_mem[chunk.device_type] -= chunk.mem
|
||||||
chunk.move_device(device)
|
chunk.move_device(device, update_ptr=update_ptr)
|
||||||
self.total_mem[chunk.device_type] += chunk.mem
|
self.total_mem[chunk.device_type] += chunk.mem
|
||||||
|
|
||||||
def trans_tensor_state(self, tensor: torch.Tensor, state: TensorState) -> None:
|
def trans_tensor_state(self, tensor: torch.Tensor, state: TensorState) -> None:
|
||||||
|
|
|
@ -44,6 +44,7 @@ def run_chunk_zero(use_chunk, use_zero):
|
||||||
params = [torch.rand(8, 8) for _ in range(3)]
|
params = [torch.rand(8, 8) for _ in range(3)]
|
||||||
chunk_size = 128 if use_chunk else None
|
chunk_size = 128 if use_chunk else None
|
||||||
chunk_manager = ChunkManager(chunk_size, enable_distributed_storage=use_zero)
|
chunk_manager = ChunkManager(chunk_size, enable_distributed_storage=use_zero)
|
||||||
|
chunk_manager.create_group('param')
|
||||||
assert chunk_manager.total_mem['cpu'] == 0
|
assert chunk_manager.total_mem['cpu'] == 0
|
||||||
assert chunk_manager.total_mem['cuda'] == 0
|
assert chunk_manager.total_mem['cuda'] == 0
|
||||||
for p in params:
|
for p in params:
|
||||||
|
|
|
@ -1,82 +0,0 @@
|
||||||
import pytest
|
|
||||||
import colossalai
|
|
||||||
from colossalai.context.parallel_mode import ParallelMode
|
|
||||||
import torch.multiprocessing as mp
|
|
||||||
from colossalai.testing import rerun_if_address_is_in_use
|
|
||||||
from colossalai.utils.cuda import get_current_device
|
|
||||||
from colossalai.utils import free_port
|
|
||||||
from colossalai.utils.model.colo_init_context import ColoInitContext
|
|
||||||
from colossalai.tensor import ChunkManager
|
|
||||||
from colossalai.core import global_context as gpc
|
|
||||||
from functools import partial
|
|
||||||
from _utils import tensor_equal, set_seed
|
|
||||||
from tests.components_to_test.registry import non_distributed_component_funcs
|
|
||||||
from torch.nn.parallel import DistributedDataParallel as DDP
|
|
||||||
from colossalai.nn.parallel import ColoDDPV2
|
|
||||||
from colossalai.testing import parameterize
|
|
||||||
from colossalai.gemini.gemini_mgr import GeminiManager
|
|
||||||
|
|
||||||
|
|
||||||
def check_param_equal(model, torch_model):
|
|
||||||
for p, torch_p in zip(model.parameters(), torch_model.parameters()):
|
|
||||||
if p.storage().size() > 0:
|
|
||||||
assert tensor_equal(torch_p, p.float()), f'{torch_p} vs {p}'
|
|
||||||
|
|
||||||
|
|
||||||
def check_grad_equal(model, torch_model):
|
|
||||||
for p, torch_p in zip(model.parameters(), torch_model.parameters()):
|
|
||||||
if p.grad is not None:
|
|
||||||
assert tensor_equal(torch_p.grad, p.grad.float())
|
|
||||||
|
|
||||||
|
|
||||||
@parameterize('use_chunk', [False, True])
|
|
||||||
@parameterize('use_zero', [False, True])
|
|
||||||
def run_gpt(use_chunk, use_zero):
|
|
||||||
set_seed(42)
|
|
||||||
get_components_func = non_distributed_component_funcs.get_callable('gpt2')
|
|
||||||
model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func()
|
|
||||||
|
|
||||||
with ColoInitContext(device=get_current_device()):
|
|
||||||
model = model_builder(checkpoint=True)
|
|
||||||
model = model.cuda()
|
|
||||||
torch_model = model_builder().cuda()
|
|
||||||
for torch_p, p in zip(torch_model.parameters(), model.parameters()):
|
|
||||||
torch_p.data.copy_(p)
|
|
||||||
model = model.half()
|
|
||||||
chunk_size = 38 * 1024**2 if use_chunk else None
|
|
||||||
chunk_manager = ChunkManager(chunk_size, enable_distributed_storage=use_zero)
|
|
||||||
gemini_manager = GeminiManager('cuda', chunk_manager)
|
|
||||||
model = ColoDDPV2(model, gemini_manager)
|
|
||||||
torch_model = DDP(torch_model, device_ids=[gpc.get_global_rank()], process_group=gpc.get_group(ParallelMode.DATA))
|
|
||||||
print(chunk_manager)
|
|
||||||
check_param_equal(model, torch_model)
|
|
||||||
model.train()
|
|
||||||
torch_model.train()
|
|
||||||
set_seed(gpc.get_local_rank(ParallelMode.DATA))
|
|
||||||
for i, (input_ids, attn_mask) in enumerate(train_dataloader):
|
|
||||||
logits = model(input_ids, attn_mask)
|
|
||||||
torch_logits = torch_model(input_ids, attn_mask)
|
|
||||||
assert tensor_equal(torch_logits, logits.float())
|
|
||||||
loss = criterion(logits, input_ids)
|
|
||||||
torch_loss = criterion(torch_logits, input_ids)
|
|
||||||
model.backward(loss)
|
|
||||||
torch_loss.backward()
|
|
||||||
check_grad_equal(model, torch_model)
|
|
||||||
break
|
|
||||||
|
|
||||||
|
|
||||||
def run_dist(rank, world_size, port):
|
|
||||||
colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
|
|
||||||
run_gpt()
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.dist
|
|
||||||
@pytest.mark.parametrize('world_size', [1, 4])
|
|
||||||
@rerun_if_address_is_in_use()
|
|
||||||
def test_gpt(world_size):
|
|
||||||
run_func = partial(run_dist, world_size=world_size, port=free_port())
|
|
||||||
mp.spawn(run_func, nprocs=world_size)
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
|
||||||
test_gpt(4)
|
|
|
@ -25,22 +25,28 @@ def check_param_equal(model, torch_model):
|
||||||
for p, torch_p in zip(model.parameters(), torch_model.parameters()):
|
for p, torch_p in zip(model.parameters(), torch_model.parameters()):
|
||||||
if p.storage().size() > 0:
|
if p.storage().size() > 0:
|
||||||
assert p.dtype == torch.half
|
assert p.dtype == torch.half
|
||||||
assert tensor_equal(torch_p, p), f'{torch_p} vs {p}'
|
assert tensor_equal(torch_p.to(dtype=p.dtype, device=p.device), p), f'{torch_p} vs {p}'
|
||||||
|
|
||||||
|
|
||||||
def run_step(model, criterion, optimizer, input_ids, attn_mask):
|
def check_grad_equal(model, torch_model):
|
||||||
|
for p, torch_p in zip(model.parameters(), torch_model.parameters()):
|
||||||
|
if p.grad is not None:
|
||||||
|
assert tensor_equal(torch_p.grad.to(dtype=p.grad.dtype, device=p.grad.device), p.grad)
|
||||||
|
|
||||||
|
|
||||||
|
def run_fwd_bwd(model, criterion, optimizer, input_ids, attn_mask):
|
||||||
optimizer.zero_grad()
|
optimizer.zero_grad()
|
||||||
logits = model(input_ids, attn_mask)
|
logits = model(input_ids, attn_mask)
|
||||||
logits = logits.float()
|
logits = logits.float()
|
||||||
loss = criterion(logits, input_ids)
|
loss = criterion(logits, input_ids)
|
||||||
optimizer.backward(loss)
|
optimizer.backward(loss)
|
||||||
optimizer.step()
|
|
||||||
return logits
|
return logits
|
||||||
|
|
||||||
|
|
||||||
@parameterize('use_chunk', [False, True])
|
@parameterize('use_chunk', [False, True])
|
||||||
@parameterize('use_zero', [False, True])
|
@parameterize('use_zero', [False, True])
|
||||||
def run_gpt(use_chunk, use_zero):
|
@parameterize('placement_policy', ['cuda', 'cpu'])
|
||||||
|
def run_gpt(use_chunk, use_zero, placement_policy):
|
||||||
set_seed(42)
|
set_seed(42)
|
||||||
get_components_func = non_distributed_component_funcs.get_callable('gpt2')
|
get_components_func = non_distributed_component_funcs.get_callable('gpt2')
|
||||||
model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func()
|
model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func()
|
||||||
|
@ -52,9 +58,11 @@ def run_gpt(use_chunk, use_zero):
|
||||||
for torch_p, p in zip(torch_model.parameters(), model.parameters()):
|
for torch_p, p in zip(torch_model.parameters(), model.parameters()):
|
||||||
torch_p.data.copy_(p)
|
torch_p.data.copy_(p)
|
||||||
|
|
||||||
chunk_size = 38 * 1024**2 if use_chunk else None
|
chunk_size = ChunkManager.search_chunk_size(model, 8192, 8) if use_chunk else None
|
||||||
chunk_manager = ChunkManager(chunk_size, enable_distributed_storage=use_zero)
|
chunk_manager = ChunkManager(chunk_size,
|
||||||
gemini_manager = GeminiManager('cuda', chunk_manager)
|
enable_distributed_storage=use_zero,
|
||||||
|
init_device=GeminiManager.get_default_device(placement_policy))
|
||||||
|
gemini_manager = GeminiManager(placement_policy, chunk_manager)
|
||||||
model = ColoDDPV2(model, gemini_manager)
|
model = ColoDDPV2(model, gemini_manager)
|
||||||
optim = HybridAdam(model.parameters(), lr=1e-3)
|
optim = HybridAdam(model.parameters(), lr=1e-3)
|
||||||
optim = ZeroOptimizer(optim, model, initial_scale=32)
|
optim = ZeroOptimizer(optim, model, initial_scale=32)
|
||||||
|
@ -64,7 +72,7 @@ def run_gpt(use_chunk, use_zero):
|
||||||
torch_model, torch_optim = convert_to_apex_amp(torch_model, torch_optim, amp_config)
|
torch_model, torch_optim = convert_to_apex_amp(torch_model, torch_optim, amp_config)
|
||||||
torch_model = DDP(torch_model, device_ids=[gpc.get_global_rank()], process_group=gpc.get_group(ParallelMode.DATA))
|
torch_model = DDP(torch_model, device_ids=[gpc.get_global_rank()], process_group=gpc.get_group(ParallelMode.DATA))
|
||||||
|
|
||||||
# print(chunk_manager)
|
print(chunk_manager)
|
||||||
check_param_equal(model, torch_model)
|
check_param_equal(model, torch_model)
|
||||||
model.train()
|
model.train()
|
||||||
torch_model.train()
|
torch_model.train()
|
||||||
|
@ -72,9 +80,12 @@ def run_gpt(use_chunk, use_zero):
|
||||||
for i, (input_ids, attn_mask) in enumerate(train_dataloader):
|
for i, (input_ids, attn_mask) in enumerate(train_dataloader):
|
||||||
if i > 2:
|
if i > 2:
|
||||||
break
|
break
|
||||||
logits = run_step(model, criterion, optim, input_ids, attn_mask)
|
logits = run_fwd_bwd(model, criterion, optim, input_ids, attn_mask)
|
||||||
torch_logits = run_step(torch_model, criterion, torch_optim, input_ids, attn_mask)
|
torch_logits = run_fwd_bwd(torch_model, criterion, torch_optim, input_ids, attn_mask)
|
||||||
assert tensor_equal(logits, torch_logits)
|
assert tensor_equal(logits, torch_logits)
|
||||||
|
check_grad_equal(model, torch_model)
|
||||||
|
optim.step()
|
||||||
|
torch_optim.step()
|
||||||
check_param_equal(model, torch_model)
|
check_param_equal(model, torch_model)
|
||||||
|
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue