mirror of https://github.com/hpcaitech/ColossalAI
[refactor] zero directory (#724)
parent 20ab1f5520
commit 4d90a7b513
@@ -10,6 +10,8 @@ from .torch_amp import convert_to_torch_amp
 from .apex_amp import convert_to_apex_amp
 from .naive_amp import convert_to_naive_amp
 
+__all__ = ['convert_to_amp', 'convert_to_naive_amp', 'convert_to_apex_amp', 'convert_to_torch_amp', 'AMP_TYPE']
+
 
 def convert_to_amp(model: nn.Module, optimizer: Optimizer, criterion: _Loss, mode: AMP_TYPE, amp_config: Config = None):
     """A helper function to wrap training components with Torch AMP modules.
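For reference, a minimal usage sketch of the exported helper (illustrative only; it assumes a CUDA device and that convert_to_amp returns the wrapped (model, optimizer, criterion) triple as its docstring implies):

    import torch
    import torch.nn as nn
    from colossalai.amp import convert_to_amp, AMP_TYPE

    model = nn.Linear(16, 4).cuda()          # toy model, name is illustrative
    optimizer = torch.optim.SGD(model.parameters(), lr=1e-2)
    criterion = nn.CrossEntropyLoss()

    # Wrap the components with torch.cuda.amp; AMP_TYPE.APEX and AMP_TYPE.NAIVE
    # select the other two backends exported above.
    model, optimizer, criterion = convert_to_amp(model, optimizer, criterion, AMP_TYPE.TORCH)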
@@ -1,119 +1,3 @@
+from .utils import register_ophooks_recursively, BaseOpHook
+
+__all__ = ["BaseOpHook", "MemTracerOpHook", "register_ophooks_recursively"]
-from typing import List, Callable, Optional
-
-import torch
-
-from ._base_ophook import BaseOpHook
-from ._memtracer_ophook import MemTracerOpHook
-from ._shard_grad_ophook import ShardGradHook
-from ._shard_param_ophook import ShardParamHook
-
-all = ["BaseOpHook", "MemTracerOpHook", "register_ophooks_recursively", "ShardParamHook", "ShardGradHook"]
-
-
-# apply torch.autograd.Function that calls a backward_function to tensors in output
-def _apply_to_tensors_only(module, functional, backward_function, outputs):
-    if type(outputs) is tuple:
-        touched_outputs = []
-        for output in outputs:
-            touched_output = _apply_to_tensors_only(module, functional, backward_function, output)
-            touched_outputs.append(touched_output)
-        return tuple(touched_outputs)
-    elif type(outputs) is torch.Tensor:
-        return functional.apply(module, backward_function, outputs)
-    else:
-        return outputs
-
-
-class PreBackwardFunction(torch.autograd.Function):
-
-    @staticmethod
-    def forward(ctx, module, pre_backward_function, outputs):
-        ctx.module = module
-        ctx.pre_backward_function = pre_backward_function
-        module.applied_pre_backward = False
-        outputs = outputs.detach()
-        return outputs
-
-    @staticmethod
-    def backward(ctx, *args):
-        ctx.pre_backward_function(ctx.module)
-        return (None, None) + args
-
-
-class PostBackwardFunction(torch.autograd.Function):
-
-    @staticmethod
-    def forward(ctx, module, pre_backward_function, output):
-        ctx.module = module
-        output = output.detach()
-        ctx.pre_backward_function = pre_backward_function
-        return output
-
-    @staticmethod
-    def backward(ctx, *args):
-        """
-        Args:
-            activation_grad of the next layer.
-        Returns:
-            grad of the input activation.
-        """
-        ctx.pre_backward_function(ctx.module)
-        return (None, None) + args
-
-
-def register_ophooks_recursively(module: torch.nn.Module,
-                                 ophook_list: List[BaseOpHook] = None,
-                                 name: str = "",
-                                 filter_fn: Optional[Callable] = None):
-    r"""Recursilvely register pre/post hooks for all submodules in the module in FWD and BWD."""
-    assert isinstance(module, torch.nn.Module)
-
-    # Add hooks for submodules
-    for child_name, child in module.named_children():
-        register_ophooks_recursively(child, ophook_list, name + child_name, filter_fn)
-
-    # Early return on modules with no parameters.
-    if len(list(module.parameters(recurse=False))) == 0:
-        return
-
-    # return from flitered module
-    if filter_fn is not None and filter_fn(module):
-        return
-
-    if ophook_list is not None:
-        for hook in ophook_list:
-            assert (isinstance(hook, BaseOpHook))
-
-    def _pre_forward_module_hook(submodule, *args):
-        for hook in ophook_list:
-            assert isinstance(submodule, torch.nn.Module)
-            hook.pre_fwd_exec(submodule, *args)
-
-    def _post_forward_module_hook(submodule, *args):
-        for hook in ophook_list:
-            assert isinstance(submodule, torch.nn.Module)
-            hook.post_fwd_exec(submodule, *args)
-
-    def _pre_backward_module_hook(submodule, inputs, output):
-
-        def _run_before_backward_function(submodule):
-            for hook in ophook_list:
-                assert isinstance(submodule, torch.nn.Module)
-                hook.pre_bwd_exec(submodule, inputs, output)
-
-        return _apply_to_tensors_only(submodule, PreBackwardFunction, _run_before_backward_function, output)
-
-    def _post_backward_module_hook(submodule, inputs):
-
-        def _run_after_backward_function(submodule):
-            for hook in ophook_list:
-                assert isinstance(submodule, torch.nn.Module)
-                hook.post_bwd_exec(submodule, inputs)
-
-        return _apply_to_tensors_only(submodule, PostBackwardFunction, _run_after_backward_function, inputs)
-
-    module.register_forward_pre_hook(_pre_forward_module_hook)
-    module.register_forward_hook(_post_forward_module_hook)
-
-    module.register_forward_hook(_pre_backward_module_hook)
-    module.register_forward_pre_hook(_post_backward_module_hook)
@@ -1,30 +0,0 @@
-from abc import ABC, abstractmethod
-
-import torch
-
-
-class BaseOpHook(ABC):
-    """This class allows users to add customized operations
-    before and after the execution of a PyTorch submodule"""
-
-    def __init__(self):
-        pass
-
-    @abstractmethod
-    def pre_fwd_exec(self, module: torch.nn.Module, *args):
-        pass
-
-    @abstractmethod
-    def post_fwd_exec(self, module: torch.nn.Module, *args):
-        pass
-
-    @abstractmethod
-    def pre_bwd_exec(self, module: torch.nn.Module, input, output):
-        pass
-
-    @abstractmethod
-    def post_bwd_exec(self, module: torch.nn.Module, input):
-        pass
-
-    @abstractmethod
-    def post_iter(self):
-        pass
@@ -0,0 +1,142 @@
+import torch
+from typing import List, Callable, Optional
+
+from abc import ABC, abstractmethod
+import torch
+
+
+class BaseOpHook(ABC):
+    """This class allows users to add customized operations
+    before and after the execution of a PyTorch submodule"""
+
+    def __init__(self):
+        pass
+
+    @abstractmethod
+    def pre_fwd_exec(self, module: torch.nn.Module, *args):
+        pass
+
+    @abstractmethod
+    def post_fwd_exec(self, module: torch.nn.Module, *args):
+        pass
+
+    @abstractmethod
+    def pre_bwd_exec(self, module: torch.nn.Module, input, output):
+        pass
+
+    @abstractmethod
+    def post_bwd_exec(self, module: torch.nn.Module, input):
+        pass
+
+    @abstractmethod
+    def post_iter(self):
+        pass
+
+
+# apply torch.autograd.Function that calls a backward_function to tensors in output
+def _apply_to_tensors_only(module, functional, backward_function, outputs):
+    if type(outputs) is tuple:
+        touched_outputs = []
+        for output in outputs:
+            touched_output = _apply_to_tensors_only(module, functional, backward_function, output)
+            touched_outputs.append(touched_output)
+        return tuple(touched_outputs)
+    elif type(outputs) is torch.Tensor:
+        return functional.apply(module, backward_function, outputs)
+    else:
+        return outputs
+
+
+class PreBackwardFunction(torch.autograd.Function):
+
+    @staticmethod
+    def forward(ctx, module, pre_backward_function, outputs):
+        ctx.module = module
+        ctx.pre_backward_function = pre_backward_function
+        module.applied_pre_backward = False
+        outputs = outputs.detach()
+        return outputs
+
+    @staticmethod
+    def backward(ctx, *args):
+        ctx.pre_backward_function(ctx.module)
+        return (None, None) + args
+
+
+class PostBackwardFunction(torch.autograd.Function):
+
+    @staticmethod
+    def forward(ctx, module, pre_backward_function, output):
+        ctx.module = module
+        output = output.detach()
+        ctx.pre_backward_function = pre_backward_function
+        return output
+
+    @staticmethod
+    def backward(ctx, *args):
+        """
+        Args:
+            activation_grad of the next layer.
+        Returns:
+            grad of the input activation.
+        """
+        ctx.pre_backward_function(ctx.module)
+        return (None, None) + args
+
+
+def register_ophooks_recursively(module: torch.nn.Module,
+                                 ophook_list: List[BaseOpHook] = None,
+                                 name: str = "",
+                                 filter_fn: Optional[Callable] = None):
+    r"""Recursilvely register pre/post hooks for all submodules in the module in FWD and BWD."""
+    assert isinstance(module, torch.nn.Module)
+
+    # Add hooks for submodules
+    for child_name, child in module.named_children():
+        register_ophooks_recursively(child, ophook_list, name + child_name, filter_fn)
+
+    # Early return on modules with no parameters.
+    if len(list(module.parameters(recurse=False))) == 0:
+        return
+
+    # return from flitered module
+    if filter_fn is not None and filter_fn(module):
+        return
+
+    if ophook_list is not None:
+        for hook in ophook_list:
+            assert (isinstance(hook, BaseOpHook))
+
+    def _pre_forward_module_hook(submodule, *args):
+        for hook in ophook_list:
+            assert isinstance(submodule, torch.nn.Module)
+            hook.pre_fwd_exec(submodule, *args)
+
+    def _post_forward_module_hook(submodule, *args):
+        for hook in ophook_list:
+            assert isinstance(submodule, torch.nn.Module)
+            hook.post_fwd_exec(submodule, *args)
+
+    def _pre_backward_module_hook(submodule, inputs, output):
+
+        def _run_before_backward_function(submodule):
+            for hook in ophook_list:
+                assert isinstance(submodule, torch.nn.Module)
+                hook.pre_bwd_exec(submodule, inputs, output)
+
+        return _apply_to_tensors_only(submodule, PreBackwardFunction, _run_before_backward_function, output)
+
+    def _post_backward_module_hook(submodule, inputs):
+
+        def _run_after_backward_function(submodule):
+            for hook in ophook_list:
+                assert isinstance(submodule, torch.nn.Module)
+                hook.post_bwd_exec(submodule, inputs)
+
+        return _apply_to_tensors_only(submodule, PostBackwardFunction, _run_after_backward_function, inputs)
+
+    module.register_forward_pre_hook(_pre_forward_module_hook)
+    module.register_forward_hook(_post_forward_module_hook)
+
+    module.register_forward_hook(_pre_backward_module_hook)
+    module.register_forward_pre_hook(_post_backward_module_hook)
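To illustrate the interface that moves into the new utils module above, here is a minimal, hypothetical hook (PrintHook is not part of the library) registered through the package's public imports:

    import torch
    import torch.nn as nn
    from colossalai.engine.ophooks import BaseOpHook, register_ophooks_recursively

    class PrintHook(BaseOpHook):
        # Hypothetical hook: report which parametrized submodule runs in FWD/BWD.
        def pre_fwd_exec(self, module, *args):
            print(f'FWD -> {module.__class__.__name__}')
        def post_fwd_exec(self, module, *args):
            pass
        def pre_bwd_exec(self, module, input, output):
            print(f'BWD -> {module.__class__.__name__}')
        def post_bwd_exec(self, module, input):
            pass
        def post_iter(self):
            pass

    model = nn.Sequential(nn.Linear(8, 8), nn.ReLU(), nn.Linear(8, 1))
    register_ophooks_recursively(model, [PrintHook()])
    model(torch.randn(2, 8)).sum().backward()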
@@ -1,6 +1,5 @@
 from .base_shard_strategy import BaseShardStrategy
 from .bucket_tensor_shard_strategy import BucketTensorShardStrategy
 from .tensor_shard_strategy import TensorShardStrategy
-from .stateful_tensor_mgr import StatefulTensorMgr
 
-__all__ = ['BaseShardStrategy', 'TensorShardStrategy', 'BucketTensorShardStrategy', 'StatefulTensorMgr']
+__all__ = ['BaseShardStrategy', 'TensorShardStrategy', 'BucketTensorShardStrategy']
@@ -3,7 +3,7 @@ from typing import List, Optional
 import torch
 import torch.distributed as dist
 from colossalai.utils import get_current_device
-from colossalai.zero.shard_utils.tensor_utils import colo_model_data_tensor_move_inline
+from colossalai.zero.sharded_param.tensor_utils import colo_model_data_tensor_move_inline
 from colossalai.zero.shard_utils import BaseShardStrategy
 from colossalai.zero.shard_utils.commons import get_shard
 from colossalai.zero.sharded_param.sharded_tensor import ShardedTensor
@@ -8,9 +8,8 @@ import torch.nn as nn
 from colossalai.context.parallel_mode import ParallelMode
 from colossalai.core import global_context as gpc
 from colossalai.engine.ophooks import register_ophooks_recursively
-from colossalai.engine.ophooks.zero_hook import ZeroHook
+from colossalai.zero.utils import ZeroHook
 from colossalai.engine.paramhooks import BaseParamHookMgr
-from colossalai.engine.gradient_handler.utils import bucket_allreduce
 from colossalai.logging import get_dist_logger
 from colossalai.utils import get_current_device, disposable
 from colossalai.utils.memory_tracer.memstats_collector import MemStatsCollector
@@ -18,12 +17,12 @@ from colossalai.utils.memory_tracer.model_data_memtracer import \
     GLOBAL_MODEL_DATA_TRACER
 from colossalai.utils.memory import colo_device_memory_capacity
 from colossalai.zero.shard_utils import BaseShardStrategy
-from colossalai.zero.shard_utils.tensor_utils import colo_model_data_move_to_cpu
+from colossalai.zero.sharded_param.tensor_utils import colo_model_data_move_to_cpu
 from colossalai.zero.sharded_model.reduce_scatter import ReduceScatterBucketer
 from colossalai.zero.sharded_param.tensorful_state import TensorState
 from torch.distributed import ProcessGroup
 from torch.nn.parameter import Parameter
-from colossalai.zero.shard_utils.stateful_tensor_mgr import StatefulTensorMgr
+from colossalai.zero.utils.stateful_tensor_mgr import StatefulTensorMgr
 
 from ._utils import (cast_float_arguments, cast_tensor_to_fp16, cast_tensor_to_fp32, chunk_and_pad, free_storage,
                      get_gradient_predivide_factor)
@@ -8,22 +8,6 @@ from colossalai.utils import is_model_parallel_parameter
 import torch.distributed as dist
 
 
-def move_tensor(input_, device):
-    assert device in ['cpu', 'gpu']
-
-    if isinstance(input_, (list, tuple)):
-        for tensor in input_:
-            tensor.data = tensor.data.cpu(
-            ) if device == 'cpu' else tensor.data.cuda()
-    elif torch.is_tensor(input_):
-        input_.data = input_.data.cpu(
-        ) if device == 'cpu' else tensor.data.cuda()
-    else:
-        raise TypeError(
-            f"Expected argument 'input_' to be torch.Tensor, list or tuple, but got {type(input_)} "
-        )
-
-
 def flatten(input_):
     return _flatten_dense_tensors(input_)
 
@@ -51,8 +35,7 @@ def shuffle_by_round_robin(tensor_list, num_partitions):
         partition_to_go = tensor_idx % num_partitions
         if partition_to_go not in partitions:
             partitions[partition_to_go] = []
-        partitions[partition_to_go].append(dict(tensor=tensor,
-                                                index=tensor_idx))
+        partitions[partition_to_go].append(dict(tensor=tensor, index=tensor_idx))
 
     partitions_count = len(partitions)
     new_tensor_list = []
@@ -73,9 +56,7 @@ def flatten_dense_tensors_with_padding(tensor_list, unit_size):
     padding = calculate_padding(num_elements, unit_size=unit_size)
 
     if padding > 0:
-        pad_tensor = torch.zeros(padding,
-                                 device=tensor_list[0].device,
-                                 dtype=tensor_list[0].dtype)
+        pad_tensor = torch.zeros(padding, device=tensor_list[0].device, dtype=tensor_list[0].dtype)
         padded_tensor_list = tensor_list + [pad_tensor]
     else:
         padded_tensor_list = tensor_list
@@ -86,6 +67,7 @@ def flatten_dense_tensors_with_padding(tensor_list, unit_size):
 def is_nccl_aligned(tensor):
     return tensor.data_ptr() % 4 == 0
 
+
 def get_grad_accumulate_object(tensor):
     """
     Return the AccumulateGrad of the input tensor
@@ -108,10 +90,7 @@ def get_grad_accumulate_object(tensor):
 
 
 def split_half_float_double(tensor_list):
-    dtypes = [
-        "torch.cuda.HalfTensor", "torch.cuda.FloatTensor",
-        "torch.cuda.DoubleTensor", "torch.cuda.BFloat16Tensor"
-    ]
+    dtypes = ["torch.cuda.HalfTensor", "torch.cuda.FloatTensor", "torch.cuda.DoubleTensor", "torch.cuda.BFloat16Tensor"]
     buckets = []
     for i, dtype in enumerate(dtypes):
         bucket = [t for t in tensor_list if t.type() == dtype]
@@ -120,10 +99,7 @@ def split_half_float_double(tensor_list):
     return buckets
 
 
-def reduce_tensor(tensor,
-                  dtype,
-                  dst_rank=None,
-                  parallel_mode=ParallelMode.DATA):
+def reduce_tensor(tensor, dtype, dst_rank=None, parallel_mode=ParallelMode.DATA):
     """
     Reduce the tensor in the data parallel process group
@@ -165,6 +141,7 @@ def reduce_tensor(tensor,
         tensor.copy_(tensor_to_reduce)
     return tensor
 
+
 def has_inf_or_nan(tensor):
     try:
         # if tensor is half, the .float() incurs an additional deep copy, but it's necessary if
@@ -181,8 +158,7 @@ def has_inf_or_nan(tensor):
             raise
         return True
     else:
-        if tensor_sum == float('inf') or tensor_sum == -float(
-                'inf') or tensor_sum != tensor_sum:
+        if tensor_sum == float('inf') or tensor_sum == -float('inf') or tensor_sum != tensor_sum:
             return True
         return False
@@ -201,11 +177,7 @@ def calculate_global_norm_from_list(norm_list):
     return math.sqrt(total_norm)
 
 
-def compute_norm(gradients,
-                 params,
-                 dp_group,
-                 mp_group,
-                 norm_type=2):
+def compute_norm(gradients, params, dp_group, mp_group, norm_type=2):
     """Clips gradient norm of an iterable of parameters.
     This is adapted from torch.nn.utils.clip_grad.clip_grad_norm_ and
     added functionality to handle model parallel parameters. Note that
@@ -229,14 +201,11 @@ def compute_norm(gradients,
     if norm_type == inf:
         total_norm = max(g.data.abs().max() for g in gradients)
         total_norm_cuda = torch.cuda.FloatTensor([float(total_norm)])
-        dist.all_reduce(total_norm_cuda,
-                        op=torch.distributed.ReduceOp.MAX,
-                        group=dp_group)
+        dist.all_reduce(total_norm_cuda, op=torch.distributed.ReduceOp.MAX, group=dp_group)
 
         # Take max across all GPUs.
         if mp_group is not None:
-            dist.all_reduce(tensor=total_norm_cuda,
-                            op=torch.distributed.ReduceOp.MAX)
+            dist.all_reduce(tensor=total_norm_cuda, op=torch.distributed.ReduceOp.MAX)
         total_norm = total_norm_cuda[0].item()
     else:
         total_norm = 0.0
@@ -251,18 +220,14 @@ def compute_norm(gradients,
 
     # Sum across all model parallel GPUs.
     total_norm_cuda = torch.cuda.FloatTensor([float(total_norm)])
-    torch.distributed.all_reduce(total_norm_cuda,
-                                 op=torch.distributed.ReduceOp.SUM,
-                                 group=dp_group)
+    torch.distributed.all_reduce(total_norm_cuda, op=torch.distributed.ReduceOp.SUM, group=dp_group)
 
     if mp_group is not None:
-        dist.all_reduce(tensor=total_norm_cuda,
-                        op=torch.distributed.ReduceOp.SUM)
+        dist.all_reduce(tensor=total_norm_cuda, op=torch.distributed.ReduceOp.SUM)
 
     total_norm = total_norm_cuda[0].item()**(1. / norm_type)
 
-    if total_norm == float(
-            'inf') or total_norm == -float('inf') or total_norm != total_norm:
+    if total_norm == float('inf') or total_norm == -float('inf') or total_norm != total_norm:
         total_norm = -1
 
     return total_norm
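A side note on the overflow test kept in has_inf_or_nan above: `tensor_sum != tensor_sum` is the standard NaN check, since NaN is the only floating-point value that compares unequal to itself, while the two float('inf') comparisons catch overflow. A tiny sketch of that behaviour in plain PyTorch:

    import torch

    s = float(torch.tensor([1.0, float('nan')]).float().sum())
    print(s != s)             # True: the sum is NaN
    s = float(torch.tensor([1.0, float('inf')]).float().sum())
    print(s == float('inf'))  # True: the sum overflowed to inf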
@@ -12,7 +12,7 @@ from colossalai.logging import get_dist_logger
 from colossalai.nn.optimizer import ColossalaiOptimizer
 from colossalai.utils.memory_tracer.model_data_memtracer import \
     GLOBAL_MODEL_DATA_TRACER
-from colossalai.zero.shard_utils.tensor_utils import (colo_model_data_tensor_move_inline, colo_model_tensor_clone,
-                                                      colo_tensor_mem_usage)
+from colossalai.zero.sharded_param.tensor_utils import (colo_model_data_tensor_move_inline, colo_model_tensor_clone,
+                                                        colo_tensor_mem_usage)
 from colossalai.zero.sharded_model import ShardedModelV2
 from colossalai.zero.sharded_model._utils import cast_tensor_to_fp32
@@ -1,4 +1,11 @@
 from colossalai.zero.sharded_param.sharded_tensor import ShardedTensor
 from colossalai.zero.sharded_param.sharded_param import ShardedParamV2
+from colossalai.zero.sharded_param.tensor_utils import (colo_model_data_tensor_move, colo_model_data_tensor_move_inline,
+                                                        colo_model_data_move_to_cpu, colo_model_tensor_clone,
+                                                        colo_tensor_mem_usage)
+from colossalai.zero.sharded_param.tensorful_state import TensorState, StatefulTensor
 
-__all__ = ['ShardedTensor', 'ShardedParamV2']
+__all__ = [
+    'ShardedTensor', 'ShardedParamV2', 'colo_model_data_tensor_move', 'colo_model_data_tensor_move_inline',
+    'colo_model_data_move_to_cpu', 'colo_model_tensor_clone', 'colo_tensor_mem_usage', 'TensorState', 'StatefulTensor'
+]
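With the widened __all__ above, call sites no longer need to reach into shard_utils.tensor_utils or tensorful_state; a sketch of the resulting import migration:

    # before this commit
    from colossalai.zero.shard_utils.tensor_utils import colo_model_data_tensor_move_inline
    from colossalai.zero.sharded_param.tensorful_state import StatefulTensor

    # after this commit
    from colossalai.zero.sharded_param import colo_model_data_tensor_move_inline, StatefulTensor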
@@ -1,7 +1,7 @@
 import torch
 from colossalai.zero.sharded_param import ShardedTensor
 from typing import Optional, Tuple
-from colossalai.zero.shard_utils.tensor_utils import colo_tensor_mem_usage
+from colossalai.zero.sharded_param.tensor_utils import colo_tensor_mem_usage
 from .tensorful_state import StatefulTensor, TensorState
 from typing import List
 
@@ -0,0 +1,4 @@
+from .stateful_tensor_mgr import StatefulTensorMgr
+from .zero_hook import ZeroHook
+
+__all__ = ['StatefulTensorMgr', 'ZeroHook']
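The new colossalai.zero.utils package becomes the import location for the two relocated classes; downstream files in this commit switch over like this:

    # before this commit
    from colossalai.engine.ophooks.zero_hook import ZeroHook
    from colossalai.zero.shard_utils import StatefulTensorMgr

    # after this commit
    from colossalai.zero.utils import ZeroHook, StatefulTensorMgr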
@@ -4,7 +4,7 @@ import types
 from colossalai.utils.cuda import get_current_device
 from colossalai.zero.sharded_param.sharded_param import ShardedParamV2
 from colossalai.zero.sharded_param.tensorful_state import StatefulTensor, TensorState
-from colossalai.zero.shard_utils.tensor_utils import colo_model_data_tensor_move_inline, colo_tensor_mem_usage
+from colossalai.zero.sharded_param.tensor_utils import colo_model_data_tensor_move_inline, colo_tensor_mem_usage
 from colossalai.utils.memory import colo_device_memory_capacity
 from colossalai.utils.memory_tracer.model_data_memtracer import GLOBAL_MODEL_DATA_TRACER
 from typing import Dict, List
@@ -3,15 +3,16 @@ from typing import Optional
 import torch
 import torch.distributed as dist
 from colossalai.registry import OPHOOKS
 
 from colossalai.utils import get_current_device
 from colossalai.utils.memory_tracer.memstats_collector import MemStatsCollector
 
 from colossalai.zero.shard_utils import BaseShardStrategy
 from colossalai.zero.sharded_param.tensorful_state import TensorState
-from colossalai.zero.shard_utils.stateful_tensor_mgr import StatefulTensorMgr
+from colossalai.zero.utils.stateful_tensor_mgr import StatefulTensorMgr
+from colossalai.zero.sharded_param.tensor_utils import colo_model_data_tensor_move_inline
 
-from ._base_ophook import BaseOpHook
-from colossalai.zero.shard_utils.tensor_utils import colo_model_data_tensor_move_inline
+from colossalai.engine.ophooks import BaseOpHook
 
 @OPHOOKS.register_module
@@ -1,4 +1,4 @@
-from colossalai.zero.shard_utils.tensor_utils import colo_model_data_tensor_move, colo_model_data_tensor_move_inline
+from colossalai.zero.sharded_param.tensor_utils import colo_model_data_tensor_move, colo_model_data_tensor_move_inline
 from colossalai.utils import free_port
 from colossalai.testing import rerun_on_exception
 from colossalai.zero.sharded_param import ShardedTensor
@@ -1,11 +1,12 @@
 import pytest
 
+import colossalai
 from colossalai.utils.cuda import get_current_device
-from colossalai.zero.shard_utils.tensor_utils import colo_tensor_mem_usage, colo_model_data_tensor_move, colo_model_data_tensor_move_inline, colo_model_data_move_to_cpu, colo_model_tensor_clone
+from colossalai.zero.sharded_param import (StatefulTensor, colo_tensor_mem_usage, colo_model_data_tensor_move,
+                                           colo_model_data_tensor_move_inline, colo_model_data_move_to_cpu,
+                                           colo_model_tensor_clone)
 from colossalai.utils.memory import colo_set_process_memory_fraction, colo_device_memory_capacity
 from colossalai.utils import free_port
-from colossalai.zero.sharded_param.tensorful_state import StatefulTensor
-import colossalai
 
 import torch
 
@@ -30,8 +30,7 @@ def _run_test_found_inf(cpu_offload, shard_strategy_class, gpu_margin_mem_ratio)
     get_components_func = non_distributed_component_funcs.get_callable(model_name)
     model_builder, train_dataloader, _, optimizer_class, criterion = get_components_func()
 
-    with ZeroInitContext(
-            target_device=torch.device(f'cpu:0') if cpu_offload else torch.device(get_current_device()),
+    with ZeroInitContext(target_device=torch.device(f'cpu:0') if cpu_offload else get_current_device(),
                          shard_strategy=shard_strategy,
                          shard_param=True):
         zero_model = model_builder(checkpoint=True)
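For context, the updated test builds its model under ZeroInitContext so parameters are sharded at construction time; a hedged sketch of that pattern (the colossalai.zero.init_ctx import path and MyModel are assumptions, not shown in this diff):

    import torch
    from colossalai.zero.init_ctx import ZeroInitContext   # assumed import path
    from colossalai.zero.shard_utils import TensorShardStrategy

    with ZeroInitContext(target_device=torch.device('cpu:0'),
                         shard_strategy=TensorShardStrategy(),
                         shard_param=True):
        model = MyModel()   # hypothetical model; parameters are sharded as they are created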
@@ -6,7 +6,7 @@ from colossalai.utils.cuda import get_current_device
 from colossalai.utils.memory_tracer import MemStatsCollector
 from colossalai.utils.memory_tracer.model_data_memtracer import GLOBAL_MODEL_DATA_TRACER
 from colossalai.utils.memory import colo_device_memory_capacity, colo_set_process_memory_fraction
-from colossalai.zero.shard_utils import StatefulTensorMgr
+from colossalai.zero.utils import StatefulTensorMgr
 from colossalai.zero.sharded_param.sharded_param import ShardedParamV2
 from colossalai.zero.sharded_param.tensorful_state import TensorState
 from colossalai.utils import free_port