diff --git a/colossalai/amp/__init__.py b/colossalai/amp/__init__.py index 790a83767..16da81f23 100644 --- a/colossalai/amp/__init__.py +++ b/colossalai/amp/__init__.py @@ -10,6 +10,8 @@ from .torch_amp import convert_to_torch_amp from .apex_amp import convert_to_apex_amp from .naive_amp import convert_to_naive_amp +__all__ = ['convert_to_amp', 'convert_to_naive_amp', 'convert_to_apex_amp', 'convert_to_torch_amp', 'AMP_TYPE'] + def convert_to_amp(model: nn.Module, optimizer: Optimizer, criterion: _Loss, mode: AMP_TYPE, amp_config: Config = None): """A helper function to wrap training components with Torch AMP modules. diff --git a/colossalai/engine/ophooks/__init__.py b/colossalai/engine/ophooks/__init__.py index 412df33c3..c2664b2af 100644 --- a/colossalai/engine/ophooks/__init__.py +++ b/colossalai/engine/ophooks/__init__.py @@ -1,119 +1,3 @@ -from typing import List, Callable, Optional +from .utils import register_ophooks_recursively, BaseOpHook -import torch - -from ._base_ophook import BaseOpHook -from ._memtracer_ophook import MemTracerOpHook -from ._shard_grad_ophook import ShardGradHook -from ._shard_param_ophook import ShardParamHook - -all = ["BaseOpHook", "MemTracerOpHook", "register_ophooks_recursively", "ShardParamHook", "ShardGradHook"] - - -# apply torch.autograd.Function that calls a backward_function to tensors in output -def _apply_to_tensors_only(module, functional, backward_function, outputs): - if type(outputs) is tuple: - touched_outputs = [] - for output in outputs: - touched_output = _apply_to_tensors_only(module, functional, backward_function, output) - touched_outputs.append(touched_output) - return tuple(touched_outputs) - elif type(outputs) is torch.Tensor: - return functional.apply(module, backward_function, outputs) - else: - return outputs - - -class PreBackwardFunction(torch.autograd.Function): - - @staticmethod - def forward(ctx, module, pre_backward_function, outputs): - ctx.module = module - ctx.pre_backward_function = pre_backward_function - module.applied_pre_backward = False - outputs = outputs.detach() - return outputs - - @staticmethod - def backward(ctx, *args): - ctx.pre_backward_function(ctx.module) - return (None, None) + args - - -class PostBackwardFunction(torch.autograd.Function): - - @staticmethod - def forward(ctx, module, pre_backward_function, output): - ctx.module = module - output = output.detach() - ctx.pre_backward_function = pre_backward_function - return output - - @staticmethod - def backward(ctx, *args): - """ - Args: - activation_grad of the next layer. - Returns: - grad of the input activation. - """ - ctx.pre_backward_function(ctx.module) - return (None, None) + args - - -def register_ophooks_recursively(module: torch.nn.Module, - ophook_list: List[BaseOpHook] = None, - name: str = "", - filter_fn: Optional[Callable] = None): - r"""Recursilvely register pre/post hooks for all submodules in the module in FWD and BWD.""" - assert isinstance(module, torch.nn.Module) - - # Add hooks for submodules - for child_name, child in module.named_children(): - register_ophooks_recursively(child, ophook_list, name + child_name, filter_fn) - - # Early return on modules with no parameters. 
- if len(list(module.parameters(recurse=False))) == 0: - return - - # return from flitered module - if filter_fn is not None and filter_fn(module): - return - - if ophook_list is not None: - for hook in ophook_list: - assert (isinstance(hook, BaseOpHook)) - - def _pre_forward_module_hook(submodule, *args): - for hook in ophook_list: - assert isinstance(submodule, torch.nn.Module) - hook.pre_fwd_exec(submodule, *args) - - def _post_forward_module_hook(submodule, *args): - for hook in ophook_list: - assert isinstance(submodule, torch.nn.Module) - hook.post_fwd_exec(submodule, *args) - - def _pre_backward_module_hook(submodule, inputs, output): - - def _run_before_backward_function(submodule): - for hook in ophook_list: - assert isinstance(submodule, torch.nn.Module) - hook.pre_bwd_exec(submodule, inputs, output) - - return _apply_to_tensors_only(submodule, PreBackwardFunction, _run_before_backward_function, output) - - def _post_backward_module_hook(submodule, inputs): - - def _run_after_backward_function(submodule): - for hook in ophook_list: - assert isinstance(submodule, torch.nn.Module) - hook.post_bwd_exec(submodule, inputs) - - return _apply_to_tensors_only(submodule, PostBackwardFunction, _run_after_backward_function, inputs) - - module.register_forward_pre_hook(_pre_forward_module_hook) - module.register_forward_hook(_post_forward_module_hook) - - module.register_forward_hook(_pre_backward_module_hook) - module.register_forward_pre_hook(_post_backward_module_hook) +__all__ = ["BaseOpHook", "MemTracerOpHook", "register_ophooks_recursively"] diff --git a/colossalai/engine/ophooks/_base_ophook.py b/colossalai/engine/ophooks/_base_ophook.py deleted file mode 100644 index 24251141d..000000000 --- a/colossalai/engine/ophooks/_base_ophook.py +++ /dev/null @@ -1,30 +0,0 @@ -from abc import ABC, abstractmethod -import torch - - -class BaseOpHook(ABC): - """This class allows users to add customized operations - before and after the execution of a PyTorch submodule""" - - def __init__(self): - pass - - @abstractmethod - def pre_fwd_exec(self, module: torch.nn.Module, *args): - pass - - @abstractmethod - def post_fwd_exec(self, module: torch.nn.Module, *args): - pass - - @abstractmethod - def pre_bwd_exec(self, module: torch.nn.Module, input, output): - pass - - @abstractmethod - def post_bwd_exec(self, module: torch.nn.Module, input): - pass - - @abstractmethod - def post_iter(self): - pass diff --git a/colossalai/engine/ophooks/utils.py b/colossalai/engine/ophooks/utils.py new file mode 100644 index 000000000..26d485657 --- /dev/null +++ b/colossalai/engine/ophooks/utils.py @@ -0,0 +1,142 @@ +import torch +from typing import List, Callable, Optional + +from abc import ABC, abstractmethod +import torch + + +class BaseOpHook(ABC): + """This class allows users to add customized operations + before and after the execution of a PyTorch submodule""" + + def __init__(self): + pass + + @abstractmethod + def pre_fwd_exec(self, module: torch.nn.Module, *args): + pass + + @abstractmethod + def post_fwd_exec(self, module: torch.nn.Module, *args): + pass + + @abstractmethod + def pre_bwd_exec(self, module: torch.nn.Module, input, output): + pass + + @abstractmethod + def post_bwd_exec(self, module: torch.nn.Module, input): + pass + + @abstractmethod + def post_iter(self): + pass + + +# apply torch.autograd.Function that calls a backward_function to tensors in output +def _apply_to_tensors_only(module, functional, backward_function, outputs): + if type(outputs) is tuple: + touched_outputs = [] + for output 
in outputs: + touched_output = _apply_to_tensors_only(module, functional, backward_function, output) + touched_outputs.append(touched_output) + return tuple(touched_outputs) + elif type(outputs) is torch.Tensor: + return functional.apply(module, backward_function, outputs) + else: + return outputs + + +class PreBackwardFunction(torch.autograd.Function): + + @staticmethod + def forward(ctx, module, pre_backward_function, outputs): + ctx.module = module + ctx.pre_backward_function = pre_backward_function + module.applied_pre_backward = False + outputs = outputs.detach() + return outputs + + @staticmethod + def backward(ctx, *args): + ctx.pre_backward_function(ctx.module) + return (None, None) + args + + +class PostBackwardFunction(torch.autograd.Function): + + @staticmethod + def forward(ctx, module, pre_backward_function, output): + ctx.module = module + output = output.detach() + ctx.pre_backward_function = pre_backward_function + return output + + @staticmethod + def backward(ctx, *args): + """ + Args: + activation_grad of the next layer. + Returns: + grad of the input activation. + """ + ctx.pre_backward_function(ctx.module) + return (None, None) + args + + +def register_ophooks_recursively(module: torch.nn.Module, + ophook_list: List[BaseOpHook] = None, + name: str = "", + filter_fn: Optional[Callable] = None): + r"""Recursilvely register pre/post hooks for all submodules in the module in FWD and BWD.""" + assert isinstance(module, torch.nn.Module) + + # Add hooks for submodules + for child_name, child in module.named_children(): + register_ophooks_recursively(child, ophook_list, name + child_name, filter_fn) + + # Early return on modules with no parameters. + if len(list(module.parameters(recurse=False))) == 0: + return + + # return from flitered module + if filter_fn is not None and filter_fn(module): + return + + if ophook_list is not None: + for hook in ophook_list: + assert (isinstance(hook, BaseOpHook)) + + def _pre_forward_module_hook(submodule, *args): + for hook in ophook_list: + assert isinstance(submodule, torch.nn.Module) + hook.pre_fwd_exec(submodule, *args) + + def _post_forward_module_hook(submodule, *args): + for hook in ophook_list: + assert isinstance(submodule, torch.nn.Module) + hook.post_fwd_exec(submodule, *args) + + def _pre_backward_module_hook(submodule, inputs, output): + + def _run_before_backward_function(submodule): + for hook in ophook_list: + assert isinstance(submodule, torch.nn.Module) + hook.pre_bwd_exec(submodule, inputs, output) + + return _apply_to_tensors_only(submodule, PreBackwardFunction, _run_before_backward_function, output) + + def _post_backward_module_hook(submodule, inputs): + + def _run_after_backward_function(submodule): + for hook in ophook_list: + assert isinstance(submodule, torch.nn.Module) + hook.post_bwd_exec(submodule, inputs) + + return _apply_to_tensors_only(submodule, PostBackwardFunction, _run_after_backward_function, inputs) + + module.register_forward_pre_hook(_pre_forward_module_hook) + module.register_forward_hook(_post_forward_module_hook) + + module.register_forward_hook(_pre_backward_module_hook) + module.register_forward_pre_hook(_post_backward_module_hook) diff --git a/colossalai/zero/shard_utils/__init__.py b/colossalai/zero/shard_utils/__init__.py index 9a7917c63..5e5d63a7e 100644 --- a/colossalai/zero/shard_utils/__init__.py +++ b/colossalai/zero/shard_utils/__init__.py @@ -1,6 +1,5 @@ from .base_shard_strategy import BaseShardStrategy from .bucket_tensor_shard_strategy import BucketTensorShardStrategy from 
.tensor_shard_strategy import TensorShardStrategy -from .stateful_tensor_mgr import StatefulTensorMgr -__all__ = ['BaseShardStrategy', 'TensorShardStrategy', 'BucketTensorShardStrategy', 'StatefulTensorMgr'] +__all__ = ['BaseShardStrategy', 'TensorShardStrategy', 'BucketTensorShardStrategy'] diff --git a/colossalai/zero/shard_utils/tensor_shard_strategy.py b/colossalai/zero/shard_utils/tensor_shard_strategy.py index 761143cf3..8fba5f73a 100644 --- a/colossalai/zero/shard_utils/tensor_shard_strategy.py +++ b/colossalai/zero/shard_utils/tensor_shard_strategy.py @@ -3,7 +3,7 @@ from typing import List, Optional import torch import torch.distributed as dist from colossalai.utils import get_current_device -from colossalai.zero.shard_utils.tensor_utils import colo_model_data_tensor_move_inline +from colossalai.zero.sharded_param.tensor_utils import colo_model_data_tensor_move_inline from colossalai.zero.shard_utils import BaseShardStrategy from colossalai.zero.shard_utils.commons import get_shard from colossalai.zero.sharded_param.sharded_tensor import ShardedTensor diff --git a/colossalai/zero/sharded_model/sharded_model_v2.py b/colossalai/zero/sharded_model/sharded_model_v2.py index 2be5bac43..a324be8c5 100644 --- a/colossalai/zero/sharded_model/sharded_model_v2.py +++ b/colossalai/zero/sharded_model/sharded_model_v2.py @@ -8,9 +8,8 @@ import torch.nn as nn from colossalai.context.parallel_mode import ParallelMode from colossalai.core import global_context as gpc from colossalai.engine.ophooks import register_ophooks_recursively -from colossalai.engine.ophooks.zero_hook import ZeroHook +from colossalai.zero.utils import ZeroHook from colossalai.engine.paramhooks import BaseParamHookMgr -from colossalai.engine.gradient_handler.utils import bucket_allreduce from colossalai.logging import get_dist_logger from colossalai.utils import get_current_device, disposable from colossalai.utils.memory_tracer.memstats_collector import MemStatsCollector @@ -18,12 +17,12 @@ from colossalai.utils.memory_tracer.model_data_memtracer import \ GLOBAL_MODEL_DATA_TRACER from colossalai.utils.memory import colo_device_memory_capacity from colossalai.zero.shard_utils import BaseShardStrategy -from colossalai.zero.shard_utils.tensor_utils import colo_model_data_move_to_cpu +from colossalai.zero.sharded_param.tensor_utils import colo_model_data_move_to_cpu from colossalai.zero.sharded_model.reduce_scatter import ReduceScatterBucketer from colossalai.zero.sharded_param.tensorful_state import TensorState from torch.distributed import ProcessGroup from torch.nn.parameter import Parameter -from colossalai.zero.shard_utils.stateful_tensor_mgr import StatefulTensorMgr +from colossalai.zero.utils.stateful_tensor_mgr import StatefulTensorMgr from ._utils import (cast_float_arguments, cast_tensor_to_fp16, cast_tensor_to_fp32, chunk_and_pad, free_storage, get_gradient_predivide_factor) diff --git a/colossalai/zero/sharded_optim/_utils.py b/colossalai/zero/sharded_optim/_utils.py index 18dca5231..49cf21969 100644 --- a/colossalai/zero/sharded_optim/_utils.py +++ b/colossalai/zero/sharded_optim/_utils.py @@ -8,22 +8,6 @@ from colossalai.utils import is_model_parallel_parameter import torch.distributed as dist -def move_tensor(input_, device): - assert device in ['cpu', 'gpu'] - - if isinstance(input_, (list, tuple)): - for tensor in input_: - tensor.data = tensor.data.cpu( - ) if device == 'cpu' else tensor.data.cuda() - elif torch.is_tensor(input_): - input_.data = input_.data.cpu( - ) if device == 'cpu' else tensor.data.cuda() - 
else: - raise TypeError( - f"Expected argument 'input_' to be torch.Tensor, list or tuple, but got {type(input_)} " - ) - - def flatten(input_): return _flatten_dense_tensors(input_) @@ -51,8 +35,7 @@ def shuffle_by_round_robin(tensor_list, num_partitions): partition_to_go = tensor_idx % num_partitions if partition_to_go not in partitions: partitions[partition_to_go] = [] - partitions[partition_to_go].append(dict(tensor=tensor, - index=tensor_idx)) + partitions[partition_to_go].append(dict(tensor=tensor, index=tensor_idx)) partitions_count = len(partitions) new_tensor_list = [] @@ -73,9 +56,7 @@ def flatten_dense_tensors_with_padding(tensor_list, unit_size): padding = calculate_padding(num_elements, unit_size=unit_size) if padding > 0: - pad_tensor = torch.zeros(padding, - device=tensor_list[0].device, - dtype=tensor_list[0].dtype) + pad_tensor = torch.zeros(padding, device=tensor_list[0].device, dtype=tensor_list[0].dtype) padded_tensor_list = tensor_list + [pad_tensor] else: padded_tensor_list = tensor_list @@ -86,6 +67,7 @@ def flatten_dense_tensors_with_padding(tensor_list, unit_size): def is_nccl_aligned(tensor): return tensor.data_ptr() % 4 == 0 + def get_grad_accumulate_object(tensor): """ Return the AccumulateGrad of the input tensor @@ -108,10 +90,7 @@ def get_grad_accumulate_object(tensor): def split_half_float_double(tensor_list): - dtypes = [ - "torch.cuda.HalfTensor", "torch.cuda.FloatTensor", - "torch.cuda.DoubleTensor", "torch.cuda.BFloat16Tensor" - ] + dtypes = ["torch.cuda.HalfTensor", "torch.cuda.FloatTensor", "torch.cuda.DoubleTensor", "torch.cuda.BFloat16Tensor"] buckets = [] for i, dtype in enumerate(dtypes): bucket = [t for t in tensor_list if t.type() == dtype] @@ -120,10 +99,7 @@ def split_half_float_double(tensor_list): return buckets -def reduce_tensor(tensor, - dtype, - dst_rank=None, - parallel_mode=ParallelMode.DATA): +def reduce_tensor(tensor, dtype, dst_rank=None, parallel_mode=ParallelMode.DATA): """ Reduce the tensor in the data parallel process group @@ -165,6 +141,7 @@ def reduce_tensor(tensor, tensor.copy_(tensor_to_reduce) return tensor + def has_inf_or_nan(tensor): try: # if tensor is half, the .float() incurs an additional deep copy, but it's necessary if @@ -181,8 +158,7 @@ def has_inf_or_nan(tensor): raise return True else: - if tensor_sum == float('inf') or tensor_sum == -float( - 'inf') or tensor_sum != tensor_sum: + if tensor_sum == float('inf') or tensor_sum == -float('inf') or tensor_sum != tensor_sum: return True return False @@ -201,11 +177,7 @@ def calculate_global_norm_from_list(norm_list): return math.sqrt(total_norm) -def compute_norm(gradients, - params, - dp_group, - mp_group, - norm_type=2): +def compute_norm(gradients, params, dp_group, mp_group, norm_type=2): """Clips gradient norm of an iterable of parameters. This is adapted from torch.nn.utils.clip_grad.clip_grad_norm_ and added functionality to handle model parallel parameters. Note that @@ -229,14 +201,11 @@ def compute_norm(gradients, if norm_type == inf: total_norm = max(g.data.abs().max() for g in gradients) total_norm_cuda = torch.cuda.FloatTensor([float(total_norm)]) - dist.all_reduce(total_norm_cuda, - op=torch.distributed.ReduceOp.MAX, - group=dp_group) + dist.all_reduce(total_norm_cuda, op=torch.distributed.ReduceOp.MAX, group=dp_group) # Take max across all GPUs. 
if mp_group is not None: - dist.all_reduce(tensor=total_norm_cuda, - op=torch.distributed.ReduceOp.MAX) + dist.all_reduce(tensor=total_norm_cuda, op=torch.distributed.ReduceOp.MAX) total_norm = total_norm_cuda[0].item() else: total_norm = 0.0 @@ -248,21 +217,17 @@ def compute_norm(gradients, if is_model_parallel_parameter(p) or mp_rank == 0: param_norm = g.data.double().norm(2) total_norm += param_norm.item()**2 - + # Sum across all model parallel GPUs. total_norm_cuda = torch.cuda.FloatTensor([float(total_norm)]) - torch.distributed.all_reduce(total_norm_cuda, - op=torch.distributed.ReduceOp.SUM, - group=dp_group) - + torch.distributed.all_reduce(total_norm_cuda, op=torch.distributed.ReduceOp.SUM, group=dp_group) + if mp_group is not None: - dist.all_reduce(tensor=total_norm_cuda, - op=torch.distributed.ReduceOp.SUM) + dist.all_reduce(tensor=total_norm_cuda, op=torch.distributed.ReduceOp.SUM) total_norm = total_norm_cuda[0].item()**(1. / norm_type) - if total_norm == float( - 'inf') or total_norm == -float('inf') or total_norm != total_norm: + if total_norm == float('inf') or total_norm == -float('inf') or total_norm != total_norm: total_norm = -1 return total_norm diff --git a/colossalai/zero/sharded_optim/sharded_optim_v2.py b/colossalai/zero/sharded_optim/sharded_optim_v2.py index d4b0a8c46..88464b0e1 100644 --- a/colossalai/zero/sharded_optim/sharded_optim_v2.py +++ b/colossalai/zero/sharded_optim/sharded_optim_v2.py @@ -12,8 +12,8 @@ from colossalai.logging import get_dist_logger from colossalai.nn.optimizer import ColossalaiOptimizer from colossalai.utils.memory_tracer.model_data_memtracer import \ GLOBAL_MODEL_DATA_TRACER -from colossalai.zero.shard_utils.tensor_utils import (colo_model_data_tensor_move_inline, colo_model_tensor_clone, - colo_tensor_mem_usage) +from colossalai.zero.sharded_param.tensor_utils import (colo_model_data_tensor_move_inline, colo_model_tensor_clone, + colo_tensor_mem_usage) from colossalai.zero.sharded_model import ShardedModelV2 from colossalai.zero.sharded_model._utils import cast_tensor_to_fp32 from colossalai.zero.sharded_optim._utils import has_inf_or_nan diff --git a/colossalai/zero/sharded_param/__init__.py b/colossalai/zero/sharded_param/__init__.py index 5642a504a..f6f46db8e 100644 --- a/colossalai/zero/sharded_param/__init__.py +++ b/colossalai/zero/sharded_param/__init__.py @@ -1,4 +1,11 @@ from colossalai.zero.sharded_param.sharded_tensor import ShardedTensor from colossalai.zero.sharded_param.sharded_param import ShardedParamV2 +from colossalai.zero.sharded_param.tensor_utils import (colo_model_data_tensor_move, colo_model_data_tensor_move_inline, + colo_model_data_move_to_cpu, colo_model_tensor_clone, + colo_tensor_mem_usage) +from colossalai.zero.sharded_param.tensorful_state import TensorState, StatefulTensor -__all__ = ['ShardedTensor', 'ShardedParamV2'] +__all__ = [ + 'ShardedTensor', 'ShardedParamV2', 'colo_model_data_tensor_move', 'colo_model_data_tensor_move_inline', + 'colo_model_data_move_to_cpu', 'colo_model_tensor_clone', 'colo_tensor_mem_usage', 'TensorState', 'StatefulTensor' +] diff --git a/colossalai/zero/sharded_param/sharded_param.py b/colossalai/zero/sharded_param/sharded_param.py index dff933a83..51c3d8556 100644 --- a/colossalai/zero/sharded_param/sharded_param.py +++ b/colossalai/zero/sharded_param/sharded_param.py @@ -1,7 +1,7 @@ import torch from colossalai.zero.sharded_param import ShardedTensor from typing import Optional, Tuple -from colossalai.zero.shard_utils.tensor_utils import colo_tensor_mem_usage +from 
colossalai.zero.sharded_param.tensor_utils import colo_tensor_mem_usage from .tensorful_state import StatefulTensor, TensorState from typing import List diff --git a/colossalai/zero/shard_utils/tensor_utils.py b/colossalai/zero/sharded_param/tensor_utils.py similarity index 100% rename from colossalai/zero/shard_utils/tensor_utils.py rename to colossalai/zero/sharded_param/tensor_utils.py diff --git a/colossalai/zero/utils/__init__.py b/colossalai/zero/utils/__init__.py new file mode 100644 index 000000000..2153ebe34 --- /dev/null +++ b/colossalai/zero/utils/__init__.py @@ -0,0 +1,4 @@ +from .stateful_tensor_mgr import StatefulTensorMgr +from .zero_hook import ZeroHook + +__all__ = ['StatefulTensorMgr', 'ZeroHook'] \ No newline at end of file diff --git a/colossalai/zero/shard_utils/stateful_tensor_mgr.py b/colossalai/zero/utils/stateful_tensor_mgr.py similarity index 97% rename from colossalai/zero/shard_utils/stateful_tensor_mgr.py rename to colossalai/zero/utils/stateful_tensor_mgr.py index 817a383d8..1674775a2 100644 --- a/colossalai/zero/shard_utils/stateful_tensor_mgr.py +++ b/colossalai/zero/utils/stateful_tensor_mgr.py @@ -4,7 +4,7 @@ import types from colossalai.utils.cuda import get_current_device from colossalai.zero.sharded_param.sharded_param import ShardedParamV2 from colossalai.zero.sharded_param.tensorful_state import StatefulTensor, TensorState -from colossalai.zero.shard_utils.tensor_utils import colo_model_data_tensor_move_inline, colo_tensor_mem_usage +from colossalai.zero.sharded_param.tensor_utils import colo_model_data_tensor_move_inline, colo_tensor_mem_usage from colossalai.utils.memory import colo_device_memory_capacity from colossalai.utils.memory_tracer.model_data_memtracer import GLOBAL_MODEL_DATA_TRACER from typing import Dict, List diff --git a/colossalai/engine/ophooks/zero_hook.py b/colossalai/zero/utils/zero_hook.py similarity index 96% rename from colossalai/engine/ophooks/zero_hook.py rename to colossalai/zero/utils/zero_hook.py index 4dfe924dd..14e502530 100644 --- a/colossalai/engine/ophooks/zero_hook.py +++ b/colossalai/zero/utils/zero_hook.py @@ -3,15 +3,16 @@ from typing import Optional import torch import torch.distributed as dist from colossalai.registry import OPHOOKS + from colossalai.utils import get_current_device from colossalai.utils.memory_tracer.memstats_collector import MemStatsCollector + from colossalai.zero.shard_utils import BaseShardStrategy from colossalai.zero.sharded_param.tensorful_state import TensorState -from colossalai.zero.shard_utils.stateful_tensor_mgr import StatefulTensorMgr +from colossalai.zero.utils.stateful_tensor_mgr import StatefulTensorMgr +from colossalai.zero.sharded_param.tensor_utils import colo_model_data_tensor_move_inline -from ._base_ophook import BaseOpHook - -from colossalai.zero.shard_utils.tensor_utils import colo_model_data_tensor_move_inline +from colossalai.engine.ophooks import BaseOpHook @OPHOOKS.register_module diff --git a/tests/test_utils/test_commons.py b/tests/test_utils/test_commons.py index 6580ae37f..d588f42d9 100644 --- a/tests/test_utils/test_commons.py +++ b/tests/test_utils/test_commons.py @@ -1,4 +1,4 @@ -from colossalai.zero.shard_utils.tensor_utils import colo_model_data_tensor_move, colo_model_data_tensor_move_inline +from colossalai.zero.sharded_param.tensor_utils import colo_model_data_tensor_move, colo_model_data_tensor_move_inline from colossalai.utils import free_port from colossalai.testing import rerun_on_exception from colossalai.zero.sharded_param import ShardedTensor diff 
--git a/tests/test_utils/test_tensor_move.py b/tests/test_utils/test_tensor_move.py index 62874d652..cf1677d4b 100644 --- a/tests/test_utils/test_tensor_move.py +++ b/tests/test_utils/test_tensor_move.py @@ -1,11 +1,12 @@ import pytest +import colossalai from colossalai.utils.cuda import get_current_device -from colossalai.zero.shard_utils.tensor_utils import colo_tensor_mem_usage, colo_model_data_tensor_move, colo_model_data_tensor_move_inline, colo_model_data_move_to_cpu, colo_model_tensor_clone +from colossalai.zero.sharded_param import (StatefulTensor, colo_tensor_mem_usage, colo_model_data_tensor_move, + colo_model_data_tensor_move_inline, colo_model_data_move_to_cpu, + colo_model_tensor_clone) from colossalai.utils.memory import colo_set_process_memory_fraction, colo_device_memory_capacity from colossalai.utils import free_port -from colossalai.zero.sharded_param.tensorful_state import StatefulTensor -import colossalai import torch diff --git a/tests/test_zero_data_parallel/test_found_inf.py b/tests/test_zero_data_parallel/test_found_inf.py index 22c3a80fd..6bdd667e4 100644 --- a/tests/test_zero_data_parallel/test_found_inf.py +++ b/tests/test_zero_data_parallel/test_found_inf.py @@ -30,10 +30,9 @@ def _run_test_found_inf(cpu_offload, shard_strategy_class, gpu_margin_mem_ratio) get_components_func = non_distributed_component_funcs.get_callable(model_name) model_builder, train_dataloader, _, optimizer_class, criterion = get_components_func() - with ZeroInitContext( - target_device=torch.device(f'cpu:0') if cpu_offload else torch.device(get_current_device()), - shard_strategy=shard_strategy, - shard_param=True): + with ZeroInitContext(target_device=torch.device(f'cpu:0') if cpu_offload else get_current_device(), + shard_strategy=shard_strategy, + shard_param=True): zero_model = model_builder(checkpoint=True) zero_model = ShardedModelV2( zero_model, diff --git a/tests/test_zero_data_parallel/test_stateful_tensor_mgr.py b/tests/test_zero_data_parallel/test_stateful_tensor_mgr.py index af8165de2..ed76d27a7 100644 --- a/tests/test_zero_data_parallel/test_stateful_tensor_mgr.py +++ b/tests/test_zero_data_parallel/test_stateful_tensor_mgr.py @@ -6,7 +6,7 @@ from colossalai.utils.cuda import get_current_device from colossalai.utils.memory_tracer import MemStatsCollector from colossalai.utils.memory_tracer.model_data_memtracer import GLOBAL_MODEL_DATA_TRACER from colossalai.utils.memory import colo_device_memory_capacity, colo_set_process_memory_fraction -from colossalai.zero.shard_utils import StatefulTensorMgr +from colossalai.zero.utils import StatefulTensorMgr from colossalai.zero.sharded_param.sharded_param import ShardedParamV2 from colossalai.zero.sharded_param.tensorful_state import TensorState from colossalai.utils import free_port
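Usage note: the op-hook refactor above moves the implementation of BaseOpHook and register_ophooks_recursively into colossalai/engine/ophooks/utils.py but keeps re-exporting both names from colossalai.engine.ophooks, so existing call sites should not need changes. Below is a minimal sketch of a custom hook written against that API; ParamCountHook and the toy model are illustrative names, not part of the library.

import torch
import torch.nn as nn

# Re-exported by the new colossalai/engine/ophooks/__init__.py in this patch.
from colossalai.engine.ophooks import BaseOpHook, register_ophooks_recursively


class ParamCountHook(BaseOpHook):
    """Illustrative hook: prints each parameterized submodule around FWD/BWD."""

    def pre_fwd_exec(self, module: torch.nn.Module, *args):
        n_params = sum(p.numel() for p in module.parameters(recurse=False))
        print(f"[pre-fwd]  {module.__class__.__name__}: {n_params} params")

    def post_fwd_exec(self, module: torch.nn.Module, *args):
        print(f"[post-fwd] {module.__class__.__name__}")

    def pre_bwd_exec(self, module: torch.nn.Module, input, output):
        print(f"[pre-bwd]  {module.__class__.__name__}")

    def post_bwd_exec(self, module: torch.nn.Module, input):
        print(f"[post-bwd] {module.__class__.__name__}")

    def post_iter(self):
        pass


model = nn.Sequential(nn.Linear(8, 8), nn.ReLU(), nn.Linear(8, 1))
# Hooks are attached only to submodules that own parameters directly
# (register_ophooks_recursively returns early for parameter-less modules).
register_ophooks_recursively(model, [ParamCountHook()])

out = model(torch.randn(4, 8))
out.sum().backward()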
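Import-path note: for downstream code, the practical effect of the file moves above is a new set of import paths. The sketch below is based only on the renames and the updated __all__ lists in this patch.

# tensor_utils moved from colossalai.zero.shard_utils to colossalai.zero.sharded_param,
# and its helpers (plus TensorState / StatefulTensor) are re-exported from that package.
from colossalai.zero.sharded_param import (
    StatefulTensor,
    TensorState,
    colo_model_data_move_to_cpu,
    colo_model_data_tensor_move,
    colo_model_data_tensor_move_inline,
    colo_model_tensor_clone,
    colo_tensor_mem_usage,
)

# StatefulTensorMgr and ZeroHook now live in the new colossalai.zero.utils package
# (previously colossalai.zero.shard_utils and colossalai.engine.ophooks.zero_hook).
from colossalai.zero.utils import StatefulTensorMgr, ZeroHook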
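AMP note: the change to colossalai/amp/__init__.py only adds an explicit __all__, so existing call sites are unaffected. A hedged sketch of the exported helper follows; that convert_to_amp returns the wrapped (model, optimizer, criterion) triple is an assumption on my part, not something stated in this hunk.

import torch.nn as nn
from torch.optim import SGD

from colossalai.amp import AMP_TYPE, convert_to_amp

model = nn.Linear(16, 4)
optimizer = SGD(model.parameters(), lr=1e-3)
criterion = nn.CrossEntropyLoss()

# Assumed return convention: the AMP-wrapped training components.
model, optimizer, criterion = convert_to_amp(model, optimizer, criterion, AMP_TYPE.TORCH)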