diff --git a/colossalai/gemini/__init__.py b/colossalai/gemini/__init__.py index b3ea38935..746b3e02a 100644 --- a/colossalai/gemini/__init__.py +++ b/colossalai/gemini/__init__.py @@ -1,5 +1,10 @@ +from .chunk import TensorInfo, Chunk, TensorState +from .chunk_mgr import ChunkManager from .stateful_tensor_mgr import StatefulTensorMgr from .tensor_placement_policy import TensorPlacementPolicyFactory from .gemini_mgr import GeminiManager -__all__ = ['StatefulTensorMgr', 'TensorPlacementPolicyFactory', 'GeminiManager'] +__all__ = [ + 'StatefulTensorMgr', 'TensorPlacementPolicyFactory', 'GeminiManager', 'ChunkManager', 'TensorInfo', 'Chunk', + 'TensorState' +] diff --git a/colossalai/gemini/chunk.py b/colossalai/gemini/chunk.py new file mode 100644 index 000000000..a5a7ae027 --- /dev/null +++ b/colossalai/gemini/chunk.py @@ -0,0 +1,315 @@ +import torch +import torch.distributed as dist +from dataclasses import dataclass +from enum import Enum +from typing import Optional, Dict, List + +from colossalai.core import global_context as gpc +from colossalai.context import ParallelMode +from colossalai.utils import get_current_device + + +class TensorState(Enum): + FREE = 0 + COMPUTE = 1 + HOLD = 2 + HOLD_AFTER_BWD = 3 + READY_FOR_REDUCE = 4 + + +STATE_TRANS = ((TensorState.FREE, TensorState.HOLD), (TensorState.FREE, TensorState.COMPUTE), + (TensorState.HOLD, TensorState.FREE), (TensorState.HOLD, TensorState.COMPUTE), + (TensorState.COMPUTE, TensorState.HOLD), (TensorState.COMPUTE, TensorState.HOLD_AFTER_BWD), + (TensorState.COMPUTE, TensorState.READY_FOR_REDUCE), (TensorState.HOLD_AFTER_BWD, TensorState.COMPUTE), + (TensorState.HOLD_AFTER_BWD, TensorState.READY_FOR_REDUCE), (TensorState.READY_FOR_REDUCE, + TensorState.HOLD)) + + +@dataclass +class TensorInfo: + state: TensorState + offset: int + end: int + + +class ChunkFullError(Exception): + pass + + +def is_storage_empty(tensor: torch.Tensor) -> bool: + return tensor.storage().size() == 0 + + +def free_storage(tensor: torch.Tensor) -> None: + if not is_storage_empty(tensor): + tensor.storage().resize_(0) + + +def alloc_storage(tensor: torch.Tensor) -> None: + if is_storage_empty(tensor): + tensor.storage().resize_(tensor.numel()) + + +class Chunk: + """ + A chunk is a contiguous memory space which contains multiple tensors. + + Args: + chunk_size (int): the number of elements in a chunk + src_rank (int): the process which owns the chunk + dtype (torch.dtype): the data type of the chunk + init_device (torch.device): optional, the device where the tensor is initialized. The default value is None, which is the current GPU. + force_data_on_cuda (bool): optional, if True, chunk.data is always on cuda. Defaults to False. 
+ """ + + def __init__(self, + chunk_size: int, + src_rank: int, + dtype: torch.dtype, + init_device: Optional[torch.device] = None, + force_data_on_cuda: bool = False) -> None: + self.size = chunk_size + self.utilized_size = 0 + self.src_rank = src_rank + self.is_src_rank = gpc.get_local_rank(ParallelMode.DATA) == src_rank + self.global_src_rank = gpc.get_ranks_in_group(ParallelMode.DATA)[src_rank] + self.dtype = dtype + device = init_device or get_current_device() + if force_data_on_cuda: + self.data = torch.empty(chunk_size, dtype=dtype, device=get_current_device()) + self._cpu_data = torch.empty(chunk_size, dtype=dtype) + if device.type == 'cuda': + free_storage(self._cpu_data) + else: + free_storage(self.data) + else: + self.data = torch.empty(chunk_size, dtype=dtype, device=device) + self._cpu_data = None + + # we only keep the chunk in full in the process by which the tensor is owned + if not self.is_src_rank: + free_storage(self._payload) + + # each tensor is associated with a TensorInfo to track meta info + self.tensors_info: Dict[torch.Tensor, TensorInfo] = {} + self.mem = self.size * self.data.element_size() + + def append(self, tensor: torch.Tensor) -> None: + """ + Add a tensor to the chunk. + + Args: + tensor (torch.Tensor): a tensor to be added to the chunk + """ + assert tensor.dtype == self.dtype + new_utilized_size = self.utilized_size + tensor.numel() + + # raise exception when the chunk size is exceeded + if new_utilized_size > self.size: + raise ChunkFullError + + # set tensor state + tensor_state = TensorState.FREE + + # if the process owns the rank, then copy the tensor to its chunk buffer + # otherwise set its storage size to 0 to reduce memory consumption + if self.is_src_rank: + self._payload[self.utilized_size:new_utilized_size].copy_(tensor.flatten()) + tensor_state = TensorState.HOLD + assert type(self._payload) == torch.Tensor, "copy_tensor_to_chunk_slice must use a torch tensor" + tensor.data = self._payload[self.utilized_size:new_utilized_size].view(tensor.shape) + else: + tensor.storage().resize_(0) + self.tensors_info[tensor] = TensorInfo(tensor_state, self.utilized_size, new_utilized_size) + self.utilized_size = new_utilized_size + + def release(self) -> None: + """ + Release the memory space on processes which do not own the chunk. + """ + if not self.is_src_rank: + free_storage(self._payload) + self._update_tensors_state(TensorState.FREE) + + def _update_tensors_ptr(self) -> None: + assert type(self._payload) == torch.Tensor + for tensor, tensor_info in self.tensors_info.items(): + tensor.data = self._payload[tensor_info.offset:tensor_info.end].view(tensor.shape) + + def _update_tensors_state(self, next_state: TensorState, prev_state: Optional[TensorState] = None): + for tensor_info in self.tensors_info.values(): + if prev_state is None or tensor_info.state == prev_state: + tensor_info.state = next_state + + def access(self) -> None: + """ + Broadcast the chunk to synchronize the tensors across data parallel processes. 
+ """ + # recover the chunk on non-owner processes + # and broadcast the chunk from the source to all processes + if not self.is_src_rank: + alloc_storage(self._payload) + self.move_device(get_current_device(), update_ptr=False) + dist.broadcast(self.data, self.global_src_rank, group=gpc.get_group(ParallelMode.DATA)) + + # update tensor meta info + self._update_tensors_ptr() + if not self.is_src_rank: + self._update_tensors_state(TensorState.HOLD, prev_state=TensorState.FREE) + + def move_device(self, device: torch.device, update_ptr: bool = True) -> None: + """ + Move the chunk to a target device. + + Args: + device (torch.device): the target device for data movement. + """ + if self._payload.device == device: + return + if self._cpu_data is None: + self.data.data = self.data.to(device) + else: + if device.type == 'cuda': + # cpu -> cuda + src = self._cpu_data + dest = self.data + else: + # cuda -> cpu + src = self.data + dest = self._cpu_data + alloc_storage(dest) + dest.copy_(src) + free_storage(src) + + if update_ptr: + self._update_tensors_ptr() + + def reduce(self, is_all_reduce: bool = False) -> None: + """ + Reduce or all-reduce the chunk. + + Args: + is_all_reduce (bool): optional, whether to all-reduce the chunk. The default is false. + """ + self.move_device(get_current_device(), update_ptr=False) + if is_all_reduce: + dist.all_reduce(self.data, group=gpc.get_group(ParallelMode.DATA)) + else: + dist.reduce(self.data, self.global_src_rank, group=gpc.get_group(ParallelMode.DATA)) + self._update_tensors_ptr() + self._update_tensors_state(TensorState.HOLD) + + def tensor_trans_state(self, tensor: torch.Tensor, tensor_state: TensorState) -> None: + """ + Make a transition of the tensor into the next state. + + Args: + tensor (torch.Tensor): a torch Tensor object. + tensor_state (TensorState): the target state for transition. + """ + assert tensor != TensorState.FREE, 'Can only set a chunk of tensors to FREE' + # As the gradient hook can be triggered either before or after post-backward + # tensor's state can be compute -> hold_after_bwd -> ready_for_reduce + # or compute -> ready_for_reduce -> hold_after_bwd + # the second one is invalid, we just ignore ready_for_reduce -> hold_after_bwd + # this function only apply valid state transformation + # invalid calls will be ignored and nothing changes + if (self.tensors_info[tensor].state, tensor_state) not in STATE_TRANS: + # print( + # f'WARNING: Rank{gpc.get_global_rank()} apply invalid state trans: {self.tensors_info[tensor].state} to {tensor_state}' + # ) + return + self.tensors_info[tensor].state = tensor_state + + def copy_tensor_to_chunk_slice(self, tensor: torch.Tensor, data_slice: torch.Tensor) -> None: + """ + Copy data slice to the memory space indexed by the input tensor in the chunk. + + Args: + tensor (torch.Tensor): the tensor used to retrive meta information + data_slice (torch.Tensor): the tensor to be copied to the chunk + """ + tensor_info = self.tensors_info[tensor] + self._payload[tensor_info.offset:tensor_info.end].copy_(data_slice.flatten()) + tensor.data = self._payload[tensor_info.offset:tensor_info.end].view(tensor.shape) + + @property + def can_release(self) -> bool: + """ + Check whether the chunk can be released. + """ + for tensor_info in self.tensors_info.values(): + if tensor_info.state != TensorState.HOLD: + return False + return True + + @property + def can_move_device(self) -> bool: + """ + Check whether the chunk can be moved across devices. 
+ """ + for tensor_info in self.tensors_info.values(): + if tensor_info.state in (TensorState.COMPUTE, TensorState.READY_FOR_REDUCE): + return False + return True + + @property + def can_reduce(self) -> bool: + """ + Check whether the chunk can be reduced. + """ + for tensor_info in self.tensors_info.values(): + if tensor_info.state != TensorState.READY_FOR_REDUCE: + return False + return True + + @property + def is_empty(self) -> bool: + """ + Check whether the chunk is empty. + """ + return is_storage_empty(self._payload) + + def __repr__(self) -> str: + return f'Chunk: src rank={self.src_rank} ,size={self.size}, utilization={self.utilized_size/self.size*100:.2f}%, freed={self.is_empty}, tensor states={[info.state.name for info in self.tensors_info.values()]}' + + @property + def has_inf_or_nan(self) -> bool: + """ + Check if the chunk has inf or nan values. + """ + return torch.isinf(self._payload[:self.utilized_size]).any().item() or \ + torch.isnan(self._payload[:self.utilized_size]).any().item() + + def copy_(self, dest_chunk: 'Chunk'): + """ + Copy the data of this chunk to a destination chunk. + """ + assert not self.is_empty + assert not dest_chunk.is_empty + assert self.size == dest_chunk.size + assert self.utilized_size == dest_chunk.utilized_size + self._payload.copy_(dest_chunk._payload) + self._update_tensors_ptr() + + @property + def device_type(self) -> str: + """ + Get the device type of the chunk. + """ + return self._payload.device.type + + def __hash__(self) -> int: + return hash(id(self)) + + def __eq__(self, __o: object) -> bool: + return self is __o + + def get_tensors(self) -> List[torch.Tensor]: + return list(self.tensors_info.keys()) + + @property + def _payload(self) -> torch.Tensor: + if self._cpu_data is None or is_storage_empty(self._cpu_data): + return self.data + return self._cpu_data diff --git a/colossalai/tensor/chunk.py b/colossalai/gemini/chunk_mgr.py similarity index 53% rename from colossalai/tensor/chunk.py rename to colossalai/gemini/chunk_mgr.py index b66612088..6d6c8b47f 100644 --- a/colossalai/tensor/chunk.py +++ b/colossalai/gemini/chunk_mgr.py @@ -1,318 +1,11 @@ import torch -import torch.distributed as dist -from dataclasses import dataclass -from enum import Enum from typing import Optional, Dict, Deque, Set, List, Tuple, Iterable from collections import deque -from colossalai.core import global_context as gpc + from colossalai.context import ParallelMode +from colossalai.core import global_context as gpc from colossalai.utils import get_current_device - - -class TensorState(Enum): - FREE = 0 - COMPUTE = 1 - HOLD = 2 - HOLD_AFTER_BWD = 3 - READY_FOR_REDUCE = 4 - - -STATE_TRANS = ((TensorState.FREE, TensorState.HOLD), (TensorState.FREE, TensorState.COMPUTE), - (TensorState.HOLD, TensorState.FREE), (TensorState.HOLD, TensorState.COMPUTE), - (TensorState.COMPUTE, TensorState.HOLD), (TensorState.COMPUTE, TensorState.HOLD_AFTER_BWD), - (TensorState.COMPUTE, TensorState.READY_FOR_REDUCE), (TensorState.HOLD_AFTER_BWD, TensorState.COMPUTE), - (TensorState.HOLD_AFTER_BWD, TensorState.READY_FOR_REDUCE), (TensorState.READY_FOR_REDUCE, - TensorState.HOLD)) - - -@dataclass -class TensorInfo: - state: TensorState - offset: int - end: int - - -class ChunkFullError(Exception): - pass - - -def is_storage_empty(tensor: torch.Tensor) -> bool: - return tensor.storage().size() == 0 - - -def free_storage(tensor: torch.Tensor) -> None: - if not is_storage_empty(tensor): - tensor.storage().resize_(0) - - -def alloc_storage(tensor: torch.Tensor) -> None: - if 
is_storage_empty(tensor): - tensor.storage().resize_(tensor.numel()) - - -class Chunk: - """ - A chunk is a contiguous memory space which contains multiple tensors. - - Args: - chunk_size (int): the number of elements in a chunk - src_rank (int): the process which owns the chunk - dtype (torch.dtype): the data type of the chunk - init_device (torch.device): optional, the device where the tensor is initialized. The default value is None, which is the current GPU. - force_data_on_cuda (bool): optional, if True, chunk.data is always on cuda. Defaults to False. - """ - - def __init__(self, - chunk_size: int, - src_rank: int, - dtype: torch.dtype, - init_device: Optional[torch.device] = None, - force_data_on_cuda: bool = False) -> None: - self.size = chunk_size - self.utilized_size = 0 - self.src_rank = src_rank - self.is_src_rank = gpc.get_local_rank(ParallelMode.DATA) == src_rank - self.global_src_rank = gpc.get_ranks_in_group(ParallelMode.DATA)[src_rank] - self.dtype = dtype - device = init_device or get_current_device() - if force_data_on_cuda: - self.data = torch.empty(chunk_size, dtype=dtype, device=get_current_device()) - self._cpu_data = torch.empty(chunk_size, dtype=dtype) - if device.type == 'cuda': - free_storage(self._cpu_data) - else: - free_storage(self.data) - else: - self.data = torch.empty(chunk_size, dtype=dtype, device=device) - self._cpu_data = None - - # we only keep the chunk in full in the process by which the tensor is owned - if not self.is_src_rank: - free_storage(self._payload) - - # each tensor is associated with a TensorInfo to track meta info - self.tensors_info: Dict[torch.Tensor, TensorInfo] = {} - self.mem = self.size * self.data.element_size() - - def append(self, tensor: torch.Tensor) -> None: - """ - Add a tensor to the chunk. - - Args: - tensor (torch.Tensor): a tensor to be added to the chunk - """ - assert tensor.dtype == self.dtype - new_utilized_size = self.utilized_size + tensor.numel() - - # raise exception when the chunk size is exceeded - if new_utilized_size > self.size: - raise ChunkFullError - - # set tensor state - tensor_state = TensorState.FREE - - # if the process owns the rank, then copy the tensor to its chunk buffer - # otherwise set its storage size to 0 to reduce memory consumption - if self.is_src_rank: - self._payload[self.utilized_size:new_utilized_size].copy_(tensor.flatten()) - tensor_state = TensorState.HOLD - assert type(self._payload) == torch.Tensor, "copy_tensor_to_chunk_slice must use a torch tensor" - tensor.data = self._payload[self.utilized_size:new_utilized_size].view(tensor.shape) - else: - tensor.storage().resize_(0) - self.tensors_info[tensor] = TensorInfo(tensor_state, self.utilized_size, new_utilized_size) - self.utilized_size = new_utilized_size - - def release(self) -> None: - """ - Release the memory space on processes which do not own the chunk. 
- """ - if not self.is_src_rank: - free_storage(self._payload) - self._update_tensors_state(TensorState.FREE) - - def _update_tensors_ptr(self) -> None: - assert type(self._payload) == torch.Tensor - for tensor, tensor_info in self.tensors_info.items(): - tensor.data = self._payload[tensor_info.offset:tensor_info.end].view(tensor.shape) - - def _update_tensors_state(self, next_state: TensorState, prev_state: Optional[TensorState] = None): - for tensor_info in self.tensors_info.values(): - if prev_state is None or tensor_info.state == prev_state: - tensor_info.state = next_state - - def access(self) -> None: - """ - Broadcast the chunk to synchronize the tensors across data parallel processes. - """ - # recover the chunk on non-owner processes - # and broadcast the chunk from the source to all processes - if not self.is_src_rank: - alloc_storage(self._payload) - self.move_device(get_current_device(), update_ptr=False) - dist.broadcast(self.data, self.global_src_rank, group=gpc.get_group(ParallelMode.DATA)) - - # update tensor meta info - self._update_tensors_ptr() - if not self.is_src_rank: - self._update_tensors_state(TensorState.HOLD, prev_state=TensorState.FREE) - - def move_device(self, device: torch.device, update_ptr: bool = True) -> None: - """ - Move the chunk to a target device. - - Args: - device (torch.device): the target device for data movement. - """ - if self._payload.device == device: - return - if self._cpu_data is None: - self.data.data = self.data.to(device) - else: - if device.type == 'cuda': - # cpu -> cuda - src = self._cpu_data - dest = self.data - else: - # cuda -> cpu - src = self.data - dest = self._cpu_data - alloc_storage(dest) - dest.copy_(src) - free_storage(src) - - if update_ptr: - self._update_tensors_ptr() - - def reduce(self, is_all_reduce: bool = False) -> None: - """ - Reduce or all-reduce the chunk. - - Args: - is_all_reduce (bool): optional, whether to all-reduce the chunk. The default is false. - """ - self.move_device(get_current_device(), update_ptr=False) - if is_all_reduce: - dist.all_reduce(self.data, group=gpc.get_group(ParallelMode.DATA)) - else: - dist.reduce(self.data, self.global_src_rank, group=gpc.get_group(ParallelMode.DATA)) - self._update_tensors_ptr() - self._update_tensors_state(TensorState.HOLD) - - def tensor_trans_state(self, tensor: torch.Tensor, tensor_state: TensorState) -> None: - """ - Make a transition of the tensor into the next state. - - Args: - tensor (torch.Tensor): a torch Tensor object. - tensor_state (TensorState): the target state for transition. - """ - assert tensor != TensorState.FREE, 'Can only set a chunk of tensors to FREE' - # As the gradient hook can be triggered either before or after post-backward - # tensor's state can be compute -> hold_after_bwd -> ready_for_reduce - # or compute -> ready_for_reduce -> hold_after_bwd - # the second one is invalid, we just ignore ready_for_reduce -> hold_after_bwd - # this function only apply valid state transformation - # invalid calls will be ignored and nothing changes - if (self.tensors_info[tensor].state, tensor_state) not in STATE_TRANS: - # print( - # f'WARNING: Rank{gpc.get_global_rank()} apply invalid state trans: {self.tensors_info[tensor].state} to {tensor_state}' - # ) - return - self.tensors_info[tensor].state = tensor_state - - def copy_tensor_to_chunk_slice(self, tensor: torch.Tensor, data_slice: torch.Tensor) -> None: - """ - Copy data slice to the memory space indexed by the input tensor in the chunk. 
- - Args: - tensor (torch.Tensor): the tensor used to retrive meta information - data_slice (torch.Tensor): the tensor to be copied to the chunk - """ - tensor_info = self.tensors_info[tensor] - self._payload[tensor_info.offset:tensor_info.end].copy_(data_slice.flatten()) - tensor.data = self._payload[tensor_info.offset:tensor_info.end].view(tensor.shape) - - @property - def can_release(self) -> bool: - """ - Check whether the chunk can be released. - """ - for tensor_info in self.tensors_info.values(): - if tensor_info.state != TensorState.HOLD: - return False - return True - - @property - def can_move_device(self) -> bool: - """ - Check whether the chunk can be moved across devices. - """ - for tensor_info in self.tensors_info.values(): - if tensor_info.state in (TensorState.COMPUTE, TensorState.READY_FOR_REDUCE): - return False - return True - - @property - def can_reduce(self) -> bool: - """ - Check whether the chunk can be reduced. - """ - for tensor_info in self.tensors_info.values(): - if tensor_info.state != TensorState.READY_FOR_REDUCE: - return False - return True - - @property - def is_empty(self) -> bool: - """ - Check whether the chunk is empty. - """ - return is_storage_empty(self._payload) - - def __repr__(self) -> str: - return f'Chunk: src rank={self.src_rank} ,size={self.size}, utilization={self.utilized_size/self.size*100:.2f}%, freed={self.is_empty}, tensor states={[info.state.name for info in self.tensors_info.values()]}' - - @property - def has_inf_or_nan(self) -> bool: - """ - Check if the chunk has inf or nan values. - """ - return torch.isinf(self._payload[:self.utilized_size]).any().item() or \ - torch.isnan(self._payload[:self.utilized_size]).any().item() - - def copy_(self, dest_chunk: 'Chunk'): - """ - Copy the data of this chunk to a destination chunk. - """ - assert not self.is_empty - assert not dest_chunk.is_empty - assert self.size == dest_chunk.size - assert self.utilized_size == dest_chunk.utilized_size - self._payload.copy_(dest_chunk._payload) - self._update_tensors_ptr() - - @property - def device_type(self) -> str: - """ - Get the device type of the chunk. 
- """ - return self._payload.device.type - - def __hash__(self) -> int: - return hash(id(self)) - - def __eq__(self, __o: object) -> bool: - return self is __o - - def get_tensors(self) -> List[torch.Tensor]: - return list(self.tensors_info.keys()) - - @property - def _payload(self) -> torch.Tensor: - if self._cpu_data is None or is_storage_empty(self._cpu_data): - return self.data - return self._cpu_data +from .chunk import Chunk, ChunkFullError, TensorState class ChunkManager: diff --git a/colossalai/gemini/gemini_mgr.py b/colossalai/gemini/gemini_mgr.py index 42ae598db..fc35f4c33 100644 --- a/colossalai/gemini/gemini_mgr.py +++ b/colossalai/gemini/gemini_mgr.py @@ -3,8 +3,8 @@ import functools from .memory_tracer.memstats_collector import MemStatsCollectorV2 from typing import List, Optional, Tuple from time import time -from colossalai.tensor.chunk import Chunk, ChunkManager -from .placement_policy import PlacementPolicy, PlacementPolicyFactory +from colossalai.gemini import Chunk, ChunkManager +from .placement_policy import PlacementPolicyFactory class GeminiManager: diff --git a/colossalai/gemini/memory_tracer/memstats_collector.py b/colossalai/gemini/memory_tracer/memstats_collector.py index e0ded421f..1ff88bd3f 100644 --- a/colossalai/gemini/memory_tracer/memstats_collector.py +++ b/colossalai/gemini/memory_tracer/memstats_collector.py @@ -2,7 +2,7 @@ from colossalai.gemini.memory_tracer import SyncCudaMemoryMonitor from colossalai.utils.memory import colo_device_memory_used, colo_device_memory_capacity from colossalai.utils import get_current_device from colossalai.gemini.stateful_tensor import StatefulTensor -from colossalai.tensor import ChunkManager +from colossalai.gemini import ChunkManager import torch import time diff --git a/colossalai/gemini/placement_policy.py b/colossalai/gemini/placement_policy.py index 7e8a0fc61..5ae1dfaa1 100644 --- a/colossalai/gemini/placement_policy.py +++ b/colossalai/gemini/placement_policy.py @@ -8,7 +8,7 @@ from colossalai.utils.memory import colo_device_memory_capacity from colossalai.gemini.memory_tracer.memstats_collector import MemStatsCollectorV2 from typing import Type import functools -from colossalai.tensor.chunk import Chunk, ChunkManager +from colossalai.gemini import Chunk, ChunkManager class PlacementPolicy(ABC): diff --git a/colossalai/nn/parallel/data_parallel.py b/colossalai/nn/parallel/data_parallel.py index fb98b62f3..dac68bb51 100644 --- a/colossalai/nn/parallel/data_parallel.py +++ b/colossalai/nn/parallel/data_parallel.py @@ -5,7 +5,7 @@ from colossalai.core import global_context as gpc from colossalai.context import ParallelMode from functools import partial from colossalai.zero.utils.zero_hook_v2 import ZeROHookV2 -from colossalai.tensor.chunk import TensorState, Chunk +from colossalai.gemini.chunk import TensorState, Chunk from colossalai.tensor.param_op_hook import ParamOpHookManager from colossalai.gemini.gemini_mgr import GeminiManager from typing import Dict, Iterable, List, Optional diff --git a/colossalai/tensor/__init__.py b/colossalai/tensor/__init__.py index b71db453d..591848e42 100644 --- a/colossalai/tensor/__init__.py +++ b/colossalai/tensor/__init__.py @@ -5,7 +5,6 @@ from .colo_parameter import ColoParameter from .utils import convert_parameter, named_params_with_colotensor from .dist_spec_mgr import DistSpecManager from .param_op_hook import ParamOpHook, ParamOpHookManager -from .chunk import ChunkManager, TensorState from . 
import distspec from .process_group import ProcessGroup diff --git a/colossalai/zero/utils/zero_hook_v2.py b/colossalai/zero/utils/zero_hook_v2.py index 7898adcd3..af0187b4f 100644 --- a/colossalai/zero/utils/zero_hook_v2.py +++ b/colossalai/zero/utils/zero_hook_v2.py @@ -1,6 +1,6 @@ import torch from colossalai.tensor.param_op_hook import ParamOpHook -from colossalai.tensor.chunk import ChunkManager, TensorState +from colossalai.gemini import TensorState from enum import Enum from typing import List from contextlib import contextmanager diff --git a/tests/test_ddp/test_ddp_ignore_params.py b/tests/test_ddp/test_ddp_ignore_params.py index dba47b052..d9f4ff48a 100644 --- a/tests/test_ddp/test_ddp_ignore_params.py +++ b/tests/test_ddp/test_ddp_ignore_params.py @@ -6,7 +6,7 @@ from colossalai.testing import rerun_if_address_is_in_use from colossalai.utils.cuda import get_current_device from colossalai.utils import free_port from colossalai.utils.model.colo_init_context import ColoInitContext -from colossalai.tensor import ChunkManager +from colossalai.gemini import ChunkManager from functools import partial from colossalai.nn.parallel import ColoDDP, ZeroDDP from colossalai.gemini.gemini_mgr import GeminiManager diff --git a/tests/test_ddp/test_ddp_state_dict.py b/tests/test_ddp/test_ddp_state_dict.py index 782ff673a..9201ff661 100644 --- a/tests/test_ddp/test_ddp_state_dict.py +++ b/tests/test_ddp/test_ddp_state_dict.py @@ -6,7 +6,7 @@ from colossalai.testing import rerun_if_address_is_in_use from colossalai.utils.cuda import get_current_device from colossalai.utils import free_port from colossalai.utils.model.colo_init_context import ColoInitContext -from colossalai.tensor import ChunkManager +from colossalai.gemini import ChunkManager from functools import partial from tests.components_to_test.registry import non_distributed_component_funcs from colossalai.nn.parallel import ZeroDDP, ColoDDP diff --git a/tests/test_ddp/test_reducer.py b/tests/test_ddp/test_reducer.py index 64faffd3c..5b302d99f 100644 --- a/tests/test_ddp/test_reducer.py +++ b/tests/test_ddp/test_reducer.py @@ -5,14 +5,7 @@ import torch.multiprocessing as mp from colossalai.testing import rerun_if_address_is_in_use from colossalai.utils.cuda import get_current_device from colossalai.utils import free_port -from colossalai.utils.model.colo_init_context import ColoInitContext -from colossalai.tensor import ChunkManager from functools import partial -from tests.components_to_test.registry import non_distributed_component_funcs -from colossalai.nn.parallel import ZeroDDP, ColoDDP -from colossalai.gemini.gemini_mgr import GeminiManager -from typing import Callable -from collections import OrderedDict from colossalai.nn.parallel.reducer import Reducer import torch.distributed as dist from torch.distributed.distributed_c10d import _get_default_group diff --git a/tests/test_tensor/test_chunk.py b/tests/test_tensor/test_chunk.py index 243c03941..0f5d75c82 100644 --- a/tests/test_tensor/test_chunk.py +++ b/tests/test_tensor/test_chunk.py @@ -4,7 +4,7 @@ import pytest import torch.multiprocessing as mp from typing import List from functools import partial -from colossalai.tensor import ChunkManager +from colossalai.gemini import ChunkManager from colossalai.testing import rerun_if_address_is_in_use, parameterize from colossalai.utils import free_port from colossalai.core import global_context as gpc diff --git a/tests/test_tensor/test_zero_optim.py b/tests/test_tensor/test_zero_optim.py index bd756fefd..08a2a4bfb 100644 --- 
a/tests/test_tensor/test_zero_optim.py +++ b/tests/test_tensor/test_zero_optim.py @@ -7,7 +7,7 @@ from colossalai.testing import rerun_if_address_is_in_use from colossalai.utils.cuda import get_current_device from colossalai.utils import free_port from colossalai.utils.model.colo_init_context import ColoInitContext -from colossalai.tensor import ChunkManager +from colossalai.gemini import ChunkManager from colossalai.core import global_context as gpc from functools import partial from _utils import tensor_equal, set_seed, tensor_shard_equal diff --git a/tests/test_zero/test_zero_optim_state_dict.py b/tests/test_zero/test_zero_optim_state_dict.py index 258b32a8e..e7b5a64fb 100644 --- a/tests/test_zero/test_zero_optim_state_dict.py +++ b/tests/test_zero/test_zero_optim_state_dict.py @@ -7,13 +7,12 @@ from colossalai.testing import rerun_if_address_is_in_use from colossalai.utils.cuda import get_current_device from colossalai.utils import free_port from colossalai.utils.model.colo_init_context import ColoInitContext -from colossalai.tensor import ChunkManager from colossalai.core import global_context as gpc from functools import partial from tests.test_tensor._utils import set_seed from tests.components_to_test.registry import non_distributed_component_funcs from colossalai.nn.parallel.data_parallel import ZeroDDP -from colossalai.gemini import GeminiManager +from colossalai.gemini import ChunkManager, GeminiManager from colossalai.testing import parameterize from colossalai.nn.optimizer import HybridAdam from colossalai.zero import ZeroOptimizer
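
Note: after this rename, the chunk primitives are imported from colossalai.gemini rather than colossalai.tensor. Below is a minimal usage sketch of the relocated API, assuming a single-process data-parallel group launched on one CUDA device; the port, chunk size, and tensor shape are illustrative and not part of this patch.

    import torch
    import colossalai
    from colossalai.gemini import Chunk, TensorState

    # illustrative single-process launch; real jobs receive rank/world_size from the launcher
    colossalai.launch(config={}, rank=0, world_size=1, host='localhost', port=29500)

    # a 1024-element fp16 chunk owned by data-parallel rank 0
    chunk = Chunk(chunk_size=1024, src_rank=0, dtype=torch.float16)

    t = torch.randn(4, 64, dtype=torch.float16, device='cuda')
    chunk.append(t)    # t.data now aliases a slice of the chunk's payload
    chunk.access()     # broadcast from the owner so every rank holds the data

    chunk.tensor_trans_state(t, TensorState.COMPUTE)    # HOLD -> COMPUTE
    chunk.tensor_trans_state(t, TensorState.HOLD)       # COMPUTE -> HOLD
    print(chunk.can_release)    # True once every tensor is back in HOLD

With more than one rank, only src_rank keeps the payload materialized between accesses; the other ranks re-allocate storage inside access() to receive the broadcast and free it again via release() once all of the chunk's tensors return to the HOLD state.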