mirror of https://github.com/hpcaitech/ColossalAI
[refactor] zero directory (#724)
parent
20ab1f5520
commit
4d90a7b513
|
@ -10,6 +10,8 @@ from .torch_amp import convert_to_torch_amp
|
|||
from .apex_amp import convert_to_apex_amp
|
||||
from .naive_amp import convert_to_naive_amp
|
||||
|
||||
__all__ = ['convert_to_amp', 'convert_to_naive_amp', 'convert_to_apex_amp', 'convert_to_torch_amp', 'AMP_TYPE']
|
||||
|
||||
|
||||
def convert_to_amp(model: nn.Module, optimizer: Optimizer, criterion: _Loss, mode: AMP_TYPE, amp_config: Config = None):
|
||||
"""A helper function to wrap training components with Torch AMP modules.
|
||||
|
|
|
@ -1,119 +1,3 @@
|
|||
from typing import List, Callable, Optional
|
||||
from .utils import register_ophooks_recursively, BaseOpHook
|
||||
|
||||
import torch
|
||||
|
||||
from ._base_ophook import BaseOpHook
|
||||
from ._memtracer_ophook import MemTracerOpHook
|
||||
from ._shard_grad_ophook import ShardGradHook
|
||||
from ._shard_param_ophook import ShardParamHook
|
||||
|
||||
all = ["BaseOpHook", "MemTracerOpHook", "register_ophooks_recursively", "ShardParamHook", "ShardGradHook"]
|
||||
|
||||
|
||||
# apply torch.autograd.Function that calls a backward_function to tensors in output
|
||||
def _apply_to_tensors_only(module, functional, backward_function, outputs):
|
||||
if type(outputs) is tuple:
|
||||
touched_outputs = []
|
||||
for output in outputs:
|
||||
touched_output = _apply_to_tensors_only(module, functional, backward_function, output)
|
||||
touched_outputs.append(touched_output)
|
||||
return tuple(touched_outputs)
|
||||
elif type(outputs) is torch.Tensor:
|
||||
return functional.apply(module, backward_function, outputs)
|
||||
else:
|
||||
return outputs
|
||||
|
||||
|
||||
class PreBackwardFunction(torch.autograd.Function):
|
||||
|
||||
@staticmethod
|
||||
def forward(ctx, module, pre_backward_function, outputs):
|
||||
ctx.module = module
|
||||
ctx.pre_backward_function = pre_backward_function
|
||||
module.applied_pre_backward = False
|
||||
outputs = outputs.detach()
|
||||
return outputs
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, *args):
|
||||
ctx.pre_backward_function(ctx.module)
|
||||
return (None, None) + args
|
||||
|
||||
|
||||
class PostBackwardFunction(torch.autograd.Function):
|
||||
|
||||
@staticmethod
|
||||
def forward(ctx, module, pre_backward_function, output):
|
||||
ctx.module = module
|
||||
output = output.detach()
|
||||
ctx.pre_backward_function = pre_backward_function
|
||||
return output
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, *args):
|
||||
"""
|
||||
Args:
|
||||
activation_grad of the next layer.
|
||||
Returns:
|
||||
grad of the input activation.
|
||||
"""
|
||||
ctx.pre_backward_function(ctx.module)
|
||||
return (None, None) + args
|
||||
|
||||
|
||||
def register_ophooks_recursively(module: torch.nn.Module,
|
||||
ophook_list: List[BaseOpHook] = None,
|
||||
name: str = "",
|
||||
filter_fn: Optional[Callable] = None):
|
||||
r"""Recursilvely register pre/post hooks for all submodules in the module in FWD and BWD."""
|
||||
assert isinstance(module, torch.nn.Module)
|
||||
|
||||
# Add hooks for submodules
|
||||
for child_name, child in module.named_children():
|
||||
register_ophooks_recursively(child, ophook_list, name + child_name, filter_fn)
|
||||
|
||||
# Early return on modules with no parameters.
|
||||
if len(list(module.parameters(recurse=False))) == 0:
|
||||
return
|
||||
|
||||
# return from flitered module
|
||||
if filter_fn is not None and filter_fn(module):
|
||||
return
|
||||
|
||||
if ophook_list is not None:
|
||||
for hook in ophook_list:
|
||||
assert (isinstance(hook, BaseOpHook))
|
||||
|
||||
def _pre_forward_module_hook(submodule, *args):
|
||||
for hook in ophook_list:
|
||||
assert isinstance(submodule, torch.nn.Module)
|
||||
hook.pre_fwd_exec(submodule, *args)
|
||||
|
||||
def _post_forward_module_hook(submodule, *args):
|
||||
for hook in ophook_list:
|
||||
assert isinstance(submodule, torch.nn.Module)
|
||||
hook.post_fwd_exec(submodule, *args)
|
||||
|
||||
def _pre_backward_module_hook(submodule, inputs, output):
|
||||
|
||||
def _run_before_backward_function(submodule):
|
||||
for hook in ophook_list:
|
||||
assert isinstance(submodule, torch.nn.Module)
|
||||
hook.pre_bwd_exec(submodule, inputs, output)
|
||||
|
||||
return _apply_to_tensors_only(submodule, PreBackwardFunction, _run_before_backward_function, output)
|
||||
|
||||
def _post_backward_module_hook(submodule, inputs):
|
||||
|
||||
def _run_after_backward_function(submodule):
|
||||
for hook in ophook_list:
|
||||
assert isinstance(submodule, torch.nn.Module)
|
||||
hook.post_bwd_exec(submodule, inputs)
|
||||
|
||||
return _apply_to_tensors_only(submodule, PostBackwardFunction, _run_after_backward_function, inputs)
|
||||
|
||||
module.register_forward_pre_hook(_pre_forward_module_hook)
|
||||
module.register_forward_hook(_post_forward_module_hook)
|
||||
|
||||
module.register_forward_hook(_pre_backward_module_hook)
|
||||
module.register_forward_pre_hook(_post_backward_module_hook)
|
||||
__all__ = ["BaseOpHook", "MemTracerOpHook", "register_ophooks_recursively"]
|
||||
|
|
|
@ -1,30 +0,0 @@
|
|||
from abc import ABC, abstractmethod
|
||||
import torch
|
||||
|
||||
|
||||
class BaseOpHook(ABC):
|
||||
"""This class allows users to add customized operations
|
||||
before and after the execution of a PyTorch submodule"""
|
||||
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def pre_fwd_exec(self, module: torch.nn.Module, *args):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def post_fwd_exec(self, module: torch.nn.Module, *args):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def pre_bwd_exec(self, module: torch.nn.Module, input, output):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def post_bwd_exec(self, module: torch.nn.Module, input):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def post_iter(self):
|
||||
pass
|
|
@ -0,0 +1,142 @@
|
|||
import torch
|
||||
from typing import List, Callable, Optional
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
import torch
|
||||
|
||||
|
||||
class BaseOpHook(ABC):
|
||||
"""This class allows users to add customized operations
|
||||
before and after the execution of a PyTorch submodule"""
|
||||
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def pre_fwd_exec(self, module: torch.nn.Module, *args):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def post_fwd_exec(self, module: torch.nn.Module, *args):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def pre_bwd_exec(self, module: torch.nn.Module, input, output):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def post_bwd_exec(self, module: torch.nn.Module, input):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def post_iter(self):
|
||||
pass
|
||||
|
||||
|
||||
# apply torch.autograd.Function that calls a backward_function to tensors in output
|
||||
def _apply_to_tensors_only(module, functional, backward_function, outputs):
|
||||
if type(outputs) is tuple:
|
||||
touched_outputs = []
|
||||
for output in outputs:
|
||||
touched_output = _apply_to_tensors_only(module, functional, backward_function, output)
|
||||
touched_outputs.append(touched_output)
|
||||
return tuple(touched_outputs)
|
||||
elif type(outputs) is torch.Tensor:
|
||||
return functional.apply(module, backward_function, outputs)
|
||||
else:
|
||||
return outputs
|
||||
|
||||
|
||||
class PreBackwardFunction(torch.autograd.Function):
|
||||
|
||||
@staticmethod
|
||||
def forward(ctx, module, pre_backward_function, outputs):
|
||||
ctx.module = module
|
||||
ctx.pre_backward_function = pre_backward_function
|
||||
module.applied_pre_backward = False
|
||||
outputs = outputs.detach()
|
||||
return outputs
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, *args):
|
||||
ctx.pre_backward_function(ctx.module)
|
||||
return (None, None) + args
|
||||
|
||||
|
||||
class PostBackwardFunction(torch.autograd.Function):
|
||||
|
||||
@staticmethod
|
||||
def forward(ctx, module, pre_backward_function, output):
|
||||
ctx.module = module
|
||||
output = output.detach()
|
||||
ctx.pre_backward_function = pre_backward_function
|
||||
return output
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, *args):
|
||||
"""
|
||||
Args:
|
||||
activation_grad of the next layer.
|
||||
Returns:
|
||||
grad of the input activation.
|
||||
"""
|
||||
ctx.pre_backward_function(ctx.module)
|
||||
return (None, None) + args
|
||||
|
||||
|
||||
def register_ophooks_recursively(module: torch.nn.Module,
|
||||
ophook_list: List[BaseOpHook] = None,
|
||||
name: str = "",
|
||||
filter_fn: Optional[Callable] = None):
|
||||
r"""Recursilvely register pre/post hooks for all submodules in the module in FWD and BWD."""
|
||||
assert isinstance(module, torch.nn.Module)
|
||||
|
||||
# Add hooks for submodules
|
||||
for child_name, child in module.named_children():
|
||||
register_ophooks_recursively(child, ophook_list, name + child_name, filter_fn)
|
||||
|
||||
# Early return on modules with no parameters.
|
||||
if len(list(module.parameters(recurse=False))) == 0:
|
||||
return
|
||||
|
||||
# return from flitered module
|
||||
if filter_fn is not None and filter_fn(module):
|
||||
return
|
||||
|
||||
if ophook_list is not None:
|
||||
for hook in ophook_list:
|
||||
assert (isinstance(hook, BaseOpHook))
|
||||
|
||||
def _pre_forward_module_hook(submodule, *args):
|
||||
for hook in ophook_list:
|
||||
assert isinstance(submodule, torch.nn.Module)
|
||||
hook.pre_fwd_exec(submodule, *args)
|
||||
|
||||
def _post_forward_module_hook(submodule, *args):
|
||||
for hook in ophook_list:
|
||||
assert isinstance(submodule, torch.nn.Module)
|
||||
hook.post_fwd_exec(submodule, *args)
|
||||
|
||||
def _pre_backward_module_hook(submodule, inputs, output):
|
||||
|
||||
def _run_before_backward_function(submodule):
|
||||
for hook in ophook_list:
|
||||
assert isinstance(submodule, torch.nn.Module)
|
||||
hook.pre_bwd_exec(submodule, inputs, output)
|
||||
|
||||
return _apply_to_tensors_only(submodule, PreBackwardFunction, _run_before_backward_function, output)
|
||||
|
||||
def _post_backward_module_hook(submodule, inputs):
|
||||
|
||||
def _run_after_backward_function(submodule):
|
||||
for hook in ophook_list:
|
||||
assert isinstance(submodule, torch.nn.Module)
|
||||
hook.post_bwd_exec(submodule, inputs)
|
||||
|
||||
return _apply_to_tensors_only(submodule, PostBackwardFunction, _run_after_backward_function, inputs)
|
||||
|
||||
module.register_forward_pre_hook(_pre_forward_module_hook)
|
||||
module.register_forward_hook(_post_forward_module_hook)
|
||||
|
||||
module.register_forward_hook(_pre_backward_module_hook)
|
||||
module.register_forward_pre_hook(_post_backward_module_hook)
|
|
@ -1,6 +1,5 @@
|
|||
from .base_shard_strategy import BaseShardStrategy
|
||||
from .bucket_tensor_shard_strategy import BucketTensorShardStrategy
|
||||
from .tensor_shard_strategy import TensorShardStrategy
|
||||
from .stateful_tensor_mgr import StatefulTensorMgr
|
||||
|
||||
__all__ = ['BaseShardStrategy', 'TensorShardStrategy', 'BucketTensorShardStrategy', 'StatefulTensorMgr']
|
||||
__all__ = ['BaseShardStrategy', 'TensorShardStrategy', 'BucketTensorShardStrategy']
|
||||
|
|
|
@ -3,7 +3,7 @@ from typing import List, Optional
|
|||
import torch
|
||||
import torch.distributed as dist
|
||||
from colossalai.utils import get_current_device
|
||||
from colossalai.zero.shard_utils.tensor_utils import colo_model_data_tensor_move_inline
|
||||
from colossalai.zero.sharded_param.tensor_utils import colo_model_data_tensor_move_inline
|
||||
from colossalai.zero.shard_utils import BaseShardStrategy
|
||||
from colossalai.zero.shard_utils.commons import get_shard
|
||||
from colossalai.zero.sharded_param.sharded_tensor import ShardedTensor
|
||||
|
|
|
@ -8,9 +8,8 @@ import torch.nn as nn
|
|||
from colossalai.context.parallel_mode import ParallelMode
|
||||
from colossalai.core import global_context as gpc
|
||||
from colossalai.engine.ophooks import register_ophooks_recursively
|
||||
from colossalai.engine.ophooks.zero_hook import ZeroHook
|
||||
from colossalai.zero.utils import ZeroHook
|
||||
from colossalai.engine.paramhooks import BaseParamHookMgr
|
||||
from colossalai.engine.gradient_handler.utils import bucket_allreduce
|
||||
from colossalai.logging import get_dist_logger
|
||||
from colossalai.utils import get_current_device, disposable
|
||||
from colossalai.utils.memory_tracer.memstats_collector import MemStatsCollector
|
||||
|
@ -18,12 +17,12 @@ from colossalai.utils.memory_tracer.model_data_memtracer import \
|
|||
GLOBAL_MODEL_DATA_TRACER
|
||||
from colossalai.utils.memory import colo_device_memory_capacity
|
||||
from colossalai.zero.shard_utils import BaseShardStrategy
|
||||
from colossalai.zero.shard_utils.tensor_utils import colo_model_data_move_to_cpu
|
||||
from colossalai.zero.sharded_param.tensor_utils import colo_model_data_move_to_cpu
|
||||
from colossalai.zero.sharded_model.reduce_scatter import ReduceScatterBucketer
|
||||
from colossalai.zero.sharded_param.tensorful_state import TensorState
|
||||
from torch.distributed import ProcessGroup
|
||||
from torch.nn.parameter import Parameter
|
||||
from colossalai.zero.shard_utils.stateful_tensor_mgr import StatefulTensorMgr
|
||||
from colossalai.zero.utils.stateful_tensor_mgr import StatefulTensorMgr
|
||||
|
||||
from ._utils import (cast_float_arguments, cast_tensor_to_fp16, cast_tensor_to_fp32, chunk_and_pad, free_storage,
|
||||
get_gradient_predivide_factor)
|
||||
|
|
|
@ -8,22 +8,6 @@ from colossalai.utils import is_model_parallel_parameter
|
|||
import torch.distributed as dist
|
||||
|
||||
|
||||
def move_tensor(input_, device):
|
||||
assert device in ['cpu', 'gpu']
|
||||
|
||||
if isinstance(input_, (list, tuple)):
|
||||
for tensor in input_:
|
||||
tensor.data = tensor.data.cpu(
|
||||
) if device == 'cpu' else tensor.data.cuda()
|
||||
elif torch.is_tensor(input_):
|
||||
input_.data = input_.data.cpu(
|
||||
) if device == 'cpu' else tensor.data.cuda()
|
||||
else:
|
||||
raise TypeError(
|
||||
f"Expected argument 'input_' to be torch.Tensor, list or tuple, but got {type(input_)} "
|
||||
)
|
||||
|
||||
|
||||
def flatten(input_):
|
||||
return _flatten_dense_tensors(input_)
|
||||
|
||||
|
@ -51,8 +35,7 @@ def shuffle_by_round_robin(tensor_list, num_partitions):
|
|||
partition_to_go = tensor_idx % num_partitions
|
||||
if partition_to_go not in partitions:
|
||||
partitions[partition_to_go] = []
|
||||
partitions[partition_to_go].append(dict(tensor=tensor,
|
||||
index=tensor_idx))
|
||||
partitions[partition_to_go].append(dict(tensor=tensor, index=tensor_idx))
|
||||
|
||||
partitions_count = len(partitions)
|
||||
new_tensor_list = []
|
||||
|
@ -73,9 +56,7 @@ def flatten_dense_tensors_with_padding(tensor_list, unit_size):
|
|||
padding = calculate_padding(num_elements, unit_size=unit_size)
|
||||
|
||||
if padding > 0:
|
||||
pad_tensor = torch.zeros(padding,
|
||||
device=tensor_list[0].device,
|
||||
dtype=tensor_list[0].dtype)
|
||||
pad_tensor = torch.zeros(padding, device=tensor_list[0].device, dtype=tensor_list[0].dtype)
|
||||
padded_tensor_list = tensor_list + [pad_tensor]
|
||||
else:
|
||||
padded_tensor_list = tensor_list
|
||||
|
@ -86,6 +67,7 @@ def flatten_dense_tensors_with_padding(tensor_list, unit_size):
|
|||
def is_nccl_aligned(tensor):
|
||||
return tensor.data_ptr() % 4 == 0
|
||||
|
||||
|
||||
def get_grad_accumulate_object(tensor):
|
||||
"""
|
||||
Return the AccumulateGrad of the input tensor
|
||||
|
@ -108,10 +90,7 @@ def get_grad_accumulate_object(tensor):
|
|||
|
||||
|
||||
def split_half_float_double(tensor_list):
|
||||
dtypes = [
|
||||
"torch.cuda.HalfTensor", "torch.cuda.FloatTensor",
|
||||
"torch.cuda.DoubleTensor", "torch.cuda.BFloat16Tensor"
|
||||
]
|
||||
dtypes = ["torch.cuda.HalfTensor", "torch.cuda.FloatTensor", "torch.cuda.DoubleTensor", "torch.cuda.BFloat16Tensor"]
|
||||
buckets = []
|
||||
for i, dtype in enumerate(dtypes):
|
||||
bucket = [t for t in tensor_list if t.type() == dtype]
|
||||
|
@ -120,10 +99,7 @@ def split_half_float_double(tensor_list):
|
|||
return buckets
|
||||
|
||||
|
||||
def reduce_tensor(tensor,
|
||||
dtype,
|
||||
dst_rank=None,
|
||||
parallel_mode=ParallelMode.DATA):
|
||||
def reduce_tensor(tensor, dtype, dst_rank=None, parallel_mode=ParallelMode.DATA):
|
||||
"""
|
||||
Reduce the tensor in the data parallel process group
|
||||
|
||||
|
@ -165,6 +141,7 @@ def reduce_tensor(tensor,
|
|||
tensor.copy_(tensor_to_reduce)
|
||||
return tensor
|
||||
|
||||
|
||||
def has_inf_or_nan(tensor):
|
||||
try:
|
||||
# if tensor is half, the .float() incurs an additional deep copy, but it's necessary if
|
||||
|
@ -181,8 +158,7 @@ def has_inf_or_nan(tensor):
|
|||
raise
|
||||
return True
|
||||
else:
|
||||
if tensor_sum == float('inf') or tensor_sum == -float(
|
||||
'inf') or tensor_sum != tensor_sum:
|
||||
if tensor_sum == float('inf') or tensor_sum == -float('inf') or tensor_sum != tensor_sum:
|
||||
return True
|
||||
return False
|
||||
|
||||
|
@ -201,11 +177,7 @@ def calculate_global_norm_from_list(norm_list):
|
|||
return math.sqrt(total_norm)
|
||||
|
||||
|
||||
def compute_norm(gradients,
|
||||
params,
|
||||
dp_group,
|
||||
mp_group,
|
||||
norm_type=2):
|
||||
def compute_norm(gradients, params, dp_group, mp_group, norm_type=2):
|
||||
"""Clips gradient norm of an iterable of parameters.
|
||||
This is adapted from torch.nn.utils.clip_grad.clip_grad_norm_ and
|
||||
added functionality to handle model parallel parameters. Note that
|
||||
|
@ -229,14 +201,11 @@ def compute_norm(gradients,
|
|||
if norm_type == inf:
|
||||
total_norm = max(g.data.abs().max() for g in gradients)
|
||||
total_norm_cuda = torch.cuda.FloatTensor([float(total_norm)])
|
||||
dist.all_reduce(total_norm_cuda,
|
||||
op=torch.distributed.ReduceOp.MAX,
|
||||
group=dp_group)
|
||||
dist.all_reduce(total_norm_cuda, op=torch.distributed.ReduceOp.MAX, group=dp_group)
|
||||
|
||||
# Take max across all GPUs.
|
||||
if mp_group is not None:
|
||||
dist.all_reduce(tensor=total_norm_cuda,
|
||||
op=torch.distributed.ReduceOp.MAX)
|
||||
dist.all_reduce(tensor=total_norm_cuda, op=torch.distributed.ReduceOp.MAX)
|
||||
total_norm = total_norm_cuda[0].item()
|
||||
else:
|
||||
total_norm = 0.0
|
||||
|
@ -248,21 +217,17 @@ def compute_norm(gradients,
|
|||
if is_model_parallel_parameter(p) or mp_rank == 0:
|
||||
param_norm = g.data.double().norm(2)
|
||||
total_norm += param_norm.item()**2
|
||||
|
||||
|
||||
# Sum across all model parallel GPUs.
|
||||
total_norm_cuda = torch.cuda.FloatTensor([float(total_norm)])
|
||||
torch.distributed.all_reduce(total_norm_cuda,
|
||||
op=torch.distributed.ReduceOp.SUM,
|
||||
group=dp_group)
|
||||
|
||||
torch.distributed.all_reduce(total_norm_cuda, op=torch.distributed.ReduceOp.SUM, group=dp_group)
|
||||
|
||||
if mp_group is not None:
|
||||
dist.all_reduce(tensor=total_norm_cuda,
|
||||
op=torch.distributed.ReduceOp.SUM)
|
||||
dist.all_reduce(tensor=total_norm_cuda, op=torch.distributed.ReduceOp.SUM)
|
||||
|
||||
total_norm = total_norm_cuda[0].item()**(1. / norm_type)
|
||||
|
||||
if total_norm == float(
|
||||
'inf') or total_norm == -float('inf') or total_norm != total_norm:
|
||||
if total_norm == float('inf') or total_norm == -float('inf') or total_norm != total_norm:
|
||||
total_norm = -1
|
||||
|
||||
return total_norm
|
||||
|
|
|
@ -12,8 +12,8 @@ from colossalai.logging import get_dist_logger
|
|||
from colossalai.nn.optimizer import ColossalaiOptimizer
|
||||
from colossalai.utils.memory_tracer.model_data_memtracer import \
|
||||
GLOBAL_MODEL_DATA_TRACER
|
||||
from colossalai.zero.shard_utils.tensor_utils import (colo_model_data_tensor_move_inline, colo_model_tensor_clone,
|
||||
colo_tensor_mem_usage)
|
||||
from colossalai.zero.sharded_param.tensor_utils import (colo_model_data_tensor_move_inline, colo_model_tensor_clone,
|
||||
colo_tensor_mem_usage)
|
||||
from colossalai.zero.sharded_model import ShardedModelV2
|
||||
from colossalai.zero.sharded_model._utils import cast_tensor_to_fp32
|
||||
from colossalai.zero.sharded_optim._utils import has_inf_or_nan
|
||||
|
|
|
@ -1,4 +1,11 @@
|
|||
from colossalai.zero.sharded_param.sharded_tensor import ShardedTensor
|
||||
from colossalai.zero.sharded_param.sharded_param import ShardedParamV2
|
||||
from colossalai.zero.sharded_param.tensor_utils import (colo_model_data_tensor_move, colo_model_data_tensor_move_inline,
|
||||
colo_model_data_move_to_cpu, colo_model_tensor_clone,
|
||||
colo_tensor_mem_usage)
|
||||
from colossalai.zero.sharded_param.tensorful_state import TensorState, StatefulTensor
|
||||
|
||||
__all__ = ['ShardedTensor', 'ShardedParamV2']
|
||||
__all__ = [
|
||||
'ShardedTensor', 'ShardedParamV2', 'colo_model_data_tensor_move', 'colo_model_data_tensor_move_inline',
|
||||
'colo_model_data_move_to_cpu', 'colo_model_tensor_clone', 'colo_tensor_mem_usage', 'TensorState', 'StatefulTensor'
|
||||
]
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
import torch
|
||||
from colossalai.zero.sharded_param import ShardedTensor
|
||||
from typing import Optional, Tuple
|
||||
from colossalai.zero.shard_utils.tensor_utils import colo_tensor_mem_usage
|
||||
from colossalai.zero.sharded_param.tensor_utils import colo_tensor_mem_usage
|
||||
from .tensorful_state import StatefulTensor, TensorState
|
||||
from typing import List
|
||||
|
||||
|
|
|
@ -0,0 +1,4 @@
|
|||
from .stateful_tensor_mgr import StatefulTensorMgr
|
||||
from .zero_hook import ZeroHook
|
||||
|
||||
__all__ = ['StatefulTensorMgr', 'ZeroHook']
|
|
@ -4,7 +4,7 @@ import types
|
|||
from colossalai.utils.cuda import get_current_device
|
||||
from colossalai.zero.sharded_param.sharded_param import ShardedParamV2
|
||||
from colossalai.zero.sharded_param.tensorful_state import StatefulTensor, TensorState
|
||||
from colossalai.zero.shard_utils.tensor_utils import colo_model_data_tensor_move_inline, colo_tensor_mem_usage
|
||||
from colossalai.zero.sharded_param.tensor_utils import colo_model_data_tensor_move_inline, colo_tensor_mem_usage
|
||||
from colossalai.utils.memory import colo_device_memory_capacity
|
||||
from colossalai.utils.memory_tracer.model_data_memtracer import GLOBAL_MODEL_DATA_TRACER
|
||||
from typing import Dict, List
|
|
@ -3,15 +3,16 @@ from typing import Optional
|
|||
import torch
|
||||
import torch.distributed as dist
|
||||
from colossalai.registry import OPHOOKS
|
||||
|
||||
from colossalai.utils import get_current_device
|
||||
from colossalai.utils.memory_tracer.memstats_collector import MemStatsCollector
|
||||
|
||||
from colossalai.zero.shard_utils import BaseShardStrategy
|
||||
from colossalai.zero.sharded_param.tensorful_state import TensorState
|
||||
from colossalai.zero.shard_utils.stateful_tensor_mgr import StatefulTensorMgr
|
||||
from colossalai.zero.utils.stateful_tensor_mgr import StatefulTensorMgr
|
||||
from colossalai.zero.sharded_param.tensor_utils import colo_model_data_tensor_move_inline
|
||||
|
||||
from ._base_ophook import BaseOpHook
|
||||
|
||||
from colossalai.zero.shard_utils.tensor_utils import colo_model_data_tensor_move_inline
|
||||
from colossalai.engine.ophooks import BaseOpHook
|
||||
|
||||
|
||||
@OPHOOKS.register_module
|
|
@ -1,4 +1,4 @@
|
|||
from colossalai.zero.shard_utils.tensor_utils import colo_model_data_tensor_move, colo_model_data_tensor_move_inline
|
||||
from colossalai.zero.sharded_param.tensor_utils import colo_model_data_tensor_move, colo_model_data_tensor_move_inline
|
||||
from colossalai.utils import free_port
|
||||
from colossalai.testing import rerun_on_exception
|
||||
from colossalai.zero.sharded_param import ShardedTensor
|
||||
|
|
|
@ -1,11 +1,12 @@
|
|||
import pytest
|
||||
|
||||
import colossalai
|
||||
from colossalai.utils.cuda import get_current_device
|
||||
from colossalai.zero.shard_utils.tensor_utils import colo_tensor_mem_usage, colo_model_data_tensor_move, colo_model_data_tensor_move_inline, colo_model_data_move_to_cpu, colo_model_tensor_clone
|
||||
from colossalai.zero.sharded_param import (StatefulTensor, colo_tensor_mem_usage, colo_model_data_tensor_move,
|
||||
colo_model_data_tensor_move_inline, colo_model_data_move_to_cpu,
|
||||
colo_model_tensor_clone)
|
||||
from colossalai.utils.memory import colo_set_process_memory_fraction, colo_device_memory_capacity
|
||||
from colossalai.utils import free_port
|
||||
from colossalai.zero.sharded_param.tensorful_state import StatefulTensor
|
||||
import colossalai
|
||||
|
||||
import torch
|
||||
|
||||
|
|
|
@ -30,10 +30,9 @@ def _run_test_found_inf(cpu_offload, shard_strategy_class, gpu_margin_mem_ratio)
|
|||
get_components_func = non_distributed_component_funcs.get_callable(model_name)
|
||||
model_builder, train_dataloader, _, optimizer_class, criterion = get_components_func()
|
||||
|
||||
with ZeroInitContext(
|
||||
target_device=torch.device(f'cpu:0') if cpu_offload else torch.device(get_current_device()),
|
||||
shard_strategy=shard_strategy,
|
||||
shard_param=True):
|
||||
with ZeroInitContext(target_device=torch.device(f'cpu:0') if cpu_offload else get_current_device(),
|
||||
shard_strategy=shard_strategy,
|
||||
shard_param=True):
|
||||
zero_model = model_builder(checkpoint=True)
|
||||
zero_model = ShardedModelV2(
|
||||
zero_model,
|
||||
|
|
|
@ -6,7 +6,7 @@ from colossalai.utils.cuda import get_current_device
|
|||
from colossalai.utils.memory_tracer import MemStatsCollector
|
||||
from colossalai.utils.memory_tracer.model_data_memtracer import GLOBAL_MODEL_DATA_TRACER
|
||||
from colossalai.utils.memory import colo_device_memory_capacity, colo_set_process_memory_fraction
|
||||
from colossalai.zero.shard_utils import StatefulTensorMgr
|
||||
from colossalai.zero.utils import StatefulTensorMgr
|
||||
from colossalai.zero.sharded_param.sharded_param import ShardedParamV2
|
||||
from colossalai.zero.sharded_param.tensorful_state import TensorState
|
||||
from colossalai.utils import free_port
|
||||
|
|
Loading…
Reference in New Issue