[refactor] zero directory (#724)

pull/729/head
Jiarui Fang 2022-04-11 23:13:02 +08:00 committed by GitHub
parent 20ab1f5520
commit 4d90a7b513
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
19 changed files with 196 additions and 223 deletions

View File

@ -10,6 +10,8 @@ from .torch_amp import convert_to_torch_amp
from .apex_amp import convert_to_apex_amp
from .naive_amp import convert_to_naive_amp
__all__ = ['convert_to_amp', 'convert_to_naive_amp', 'convert_to_apex_amp', 'convert_to_torch_amp', 'AMP_TYPE']
def convert_to_amp(model: nn.Module, optimizer: Optimizer, criterion: _Loss, mode: AMP_TYPE, amp_config: Config = None):
"""A helper function to wrap training components with Torch AMP modules.

View File

@ -1,119 +1,3 @@
from typing import List, Callable, Optional
from .utils import register_ophooks_recursively, BaseOpHook
import torch
from ._base_ophook import BaseOpHook
from ._memtracer_ophook import MemTracerOpHook
from ._shard_grad_ophook import ShardGradHook
from ._shard_param_ophook import ShardParamHook
all = ["BaseOpHook", "MemTracerOpHook", "register_ophooks_recursively", "ShardParamHook", "ShardGradHook"]
# apply torch.autograd.Function that calls a backward_function to tensors in output
def _apply_to_tensors_only(module, functional, backward_function, outputs):
if type(outputs) is tuple:
touched_outputs = []
for output in outputs:
touched_output = _apply_to_tensors_only(module, functional, backward_function, output)
touched_outputs.append(touched_output)
return tuple(touched_outputs)
elif type(outputs) is torch.Tensor:
return functional.apply(module, backward_function, outputs)
else:
return outputs
class PreBackwardFunction(torch.autograd.Function):
@staticmethod
def forward(ctx, module, pre_backward_function, outputs):
ctx.module = module
ctx.pre_backward_function = pre_backward_function
module.applied_pre_backward = False
outputs = outputs.detach()
return outputs
@staticmethod
def backward(ctx, *args):
ctx.pre_backward_function(ctx.module)
return (None, None) + args
class PostBackwardFunction(torch.autograd.Function):
@staticmethod
def forward(ctx, module, pre_backward_function, output):
ctx.module = module
output = output.detach()
ctx.pre_backward_function = pre_backward_function
return output
@staticmethod
def backward(ctx, *args):
"""
Args:
activation_grad of the next layer.
Returns:
grad of the input activation.
"""
ctx.pre_backward_function(ctx.module)
return (None, None) + args
def register_ophooks_recursively(module: torch.nn.Module,
ophook_list: List[BaseOpHook] = None,
name: str = "",
filter_fn: Optional[Callable] = None):
r"""Recursilvely register pre/post hooks for all submodules in the module in FWD and BWD."""
assert isinstance(module, torch.nn.Module)
# Add hooks for submodules
for child_name, child in module.named_children():
register_ophooks_recursively(child, ophook_list, name + child_name, filter_fn)
# Early return on modules with no parameters.
if len(list(module.parameters(recurse=False))) == 0:
return
# return from flitered module
if filter_fn is not None and filter_fn(module):
return
if ophook_list is not None:
for hook in ophook_list:
assert (isinstance(hook, BaseOpHook))
def _pre_forward_module_hook(submodule, *args):
for hook in ophook_list:
assert isinstance(submodule, torch.nn.Module)
hook.pre_fwd_exec(submodule, *args)
def _post_forward_module_hook(submodule, *args):
for hook in ophook_list:
assert isinstance(submodule, torch.nn.Module)
hook.post_fwd_exec(submodule, *args)
def _pre_backward_module_hook(submodule, inputs, output):
def _run_before_backward_function(submodule):
for hook in ophook_list:
assert isinstance(submodule, torch.nn.Module)
hook.pre_bwd_exec(submodule, inputs, output)
return _apply_to_tensors_only(submodule, PreBackwardFunction, _run_before_backward_function, output)
def _post_backward_module_hook(submodule, inputs):
def _run_after_backward_function(submodule):
for hook in ophook_list:
assert isinstance(submodule, torch.nn.Module)
hook.post_bwd_exec(submodule, inputs)
return _apply_to_tensors_only(submodule, PostBackwardFunction, _run_after_backward_function, inputs)
module.register_forward_pre_hook(_pre_forward_module_hook)
module.register_forward_hook(_post_forward_module_hook)
module.register_forward_hook(_pre_backward_module_hook)
module.register_forward_pre_hook(_post_backward_module_hook)
__all__ = ["BaseOpHook", "MemTracerOpHook", "register_ophooks_recursively"]

View File

@ -1,30 +0,0 @@
from abc import ABC, abstractmethod
import torch
class BaseOpHook(ABC):
"""This class allows users to add customized operations
before and after the execution of a PyTorch submodule"""
def __init__(self):
pass
@abstractmethod
def pre_fwd_exec(self, module: torch.nn.Module, *args):
pass
@abstractmethod
def post_fwd_exec(self, module: torch.nn.Module, *args):
pass
@abstractmethod
def pre_bwd_exec(self, module: torch.nn.Module, input, output):
pass
@abstractmethod
def post_bwd_exec(self, module: torch.nn.Module, input):
pass
@abstractmethod
def post_iter(self):
pass

View File

@ -0,0 +1,142 @@
import torch
from typing import List, Callable, Optional
from abc import ABC, abstractmethod
import torch
class BaseOpHook(ABC):
"""This class allows users to add customized operations
before and after the execution of a PyTorch submodule"""
def __init__(self):
pass
@abstractmethod
def pre_fwd_exec(self, module: torch.nn.Module, *args):
pass
@abstractmethod
def post_fwd_exec(self, module: torch.nn.Module, *args):
pass
@abstractmethod
def pre_bwd_exec(self, module: torch.nn.Module, input, output):
pass
@abstractmethod
def post_bwd_exec(self, module: torch.nn.Module, input):
pass
@abstractmethod
def post_iter(self):
pass
# apply torch.autograd.Function that calls a backward_function to tensors in output
def _apply_to_tensors_only(module, functional, backward_function, outputs):
if type(outputs) is tuple:
touched_outputs = []
for output in outputs:
touched_output = _apply_to_tensors_only(module, functional, backward_function, output)
touched_outputs.append(touched_output)
return tuple(touched_outputs)
elif type(outputs) is torch.Tensor:
return functional.apply(module, backward_function, outputs)
else:
return outputs
class PreBackwardFunction(torch.autograd.Function):
@staticmethod
def forward(ctx, module, pre_backward_function, outputs):
ctx.module = module
ctx.pre_backward_function = pre_backward_function
module.applied_pre_backward = False
outputs = outputs.detach()
return outputs
@staticmethod
def backward(ctx, *args):
ctx.pre_backward_function(ctx.module)
return (None, None) + args
class PostBackwardFunction(torch.autograd.Function):
@staticmethod
def forward(ctx, module, pre_backward_function, output):
ctx.module = module
output = output.detach()
ctx.pre_backward_function = pre_backward_function
return output
@staticmethod
def backward(ctx, *args):
"""
Args:
activation_grad of the next layer.
Returns:
grad of the input activation.
"""
ctx.pre_backward_function(ctx.module)
return (None, None) + args
def register_ophooks_recursively(module: torch.nn.Module,
ophook_list: List[BaseOpHook] = None,
name: str = "",
filter_fn: Optional[Callable] = None):
r"""Recursilvely register pre/post hooks for all submodules in the module in FWD and BWD."""
assert isinstance(module, torch.nn.Module)
# Add hooks for submodules
for child_name, child in module.named_children():
register_ophooks_recursively(child, ophook_list, name + child_name, filter_fn)
# Early return on modules with no parameters.
if len(list(module.parameters(recurse=False))) == 0:
return
# return from flitered module
if filter_fn is not None and filter_fn(module):
return
if ophook_list is not None:
for hook in ophook_list:
assert (isinstance(hook, BaseOpHook))
def _pre_forward_module_hook(submodule, *args):
for hook in ophook_list:
assert isinstance(submodule, torch.nn.Module)
hook.pre_fwd_exec(submodule, *args)
def _post_forward_module_hook(submodule, *args):
for hook in ophook_list:
assert isinstance(submodule, torch.nn.Module)
hook.post_fwd_exec(submodule, *args)
def _pre_backward_module_hook(submodule, inputs, output):
def _run_before_backward_function(submodule):
for hook in ophook_list:
assert isinstance(submodule, torch.nn.Module)
hook.pre_bwd_exec(submodule, inputs, output)
return _apply_to_tensors_only(submodule, PreBackwardFunction, _run_before_backward_function, output)
def _post_backward_module_hook(submodule, inputs):
def _run_after_backward_function(submodule):
for hook in ophook_list:
assert isinstance(submodule, torch.nn.Module)
hook.post_bwd_exec(submodule, inputs)
return _apply_to_tensors_only(submodule, PostBackwardFunction, _run_after_backward_function, inputs)
module.register_forward_pre_hook(_pre_forward_module_hook)
module.register_forward_hook(_post_forward_module_hook)
module.register_forward_hook(_pre_backward_module_hook)
module.register_forward_pre_hook(_post_backward_module_hook)

View File

@ -1,6 +1,5 @@
from .base_shard_strategy import BaseShardStrategy
from .bucket_tensor_shard_strategy import BucketTensorShardStrategy
from .tensor_shard_strategy import TensorShardStrategy
from .stateful_tensor_mgr import StatefulTensorMgr
__all__ = ['BaseShardStrategy', 'TensorShardStrategy', 'BucketTensorShardStrategy', 'StatefulTensorMgr']
__all__ = ['BaseShardStrategy', 'TensorShardStrategy', 'BucketTensorShardStrategy']

View File

@ -3,7 +3,7 @@ from typing import List, Optional
import torch
import torch.distributed as dist
from colossalai.utils import get_current_device
from colossalai.zero.shard_utils.tensor_utils import colo_model_data_tensor_move_inline
from colossalai.zero.sharded_param.tensor_utils import colo_model_data_tensor_move_inline
from colossalai.zero.shard_utils import BaseShardStrategy
from colossalai.zero.shard_utils.commons import get_shard
from colossalai.zero.sharded_param.sharded_tensor import ShardedTensor

View File

@ -8,9 +8,8 @@ import torch.nn as nn
from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.engine.ophooks import register_ophooks_recursively
from colossalai.engine.ophooks.zero_hook import ZeroHook
from colossalai.zero.utils import ZeroHook
from colossalai.engine.paramhooks import BaseParamHookMgr
from colossalai.engine.gradient_handler.utils import bucket_allreduce
from colossalai.logging import get_dist_logger
from colossalai.utils import get_current_device, disposable
from colossalai.utils.memory_tracer.memstats_collector import MemStatsCollector
@ -18,12 +17,12 @@ from colossalai.utils.memory_tracer.model_data_memtracer import \
GLOBAL_MODEL_DATA_TRACER
from colossalai.utils.memory import colo_device_memory_capacity
from colossalai.zero.shard_utils import BaseShardStrategy
from colossalai.zero.shard_utils.tensor_utils import colo_model_data_move_to_cpu
from colossalai.zero.sharded_param.tensor_utils import colo_model_data_move_to_cpu
from colossalai.zero.sharded_model.reduce_scatter import ReduceScatterBucketer
from colossalai.zero.sharded_param.tensorful_state import TensorState
from torch.distributed import ProcessGroup
from torch.nn.parameter import Parameter
from colossalai.zero.shard_utils.stateful_tensor_mgr import StatefulTensorMgr
from colossalai.zero.utils.stateful_tensor_mgr import StatefulTensorMgr
from ._utils import (cast_float_arguments, cast_tensor_to_fp16, cast_tensor_to_fp32, chunk_and_pad, free_storage,
get_gradient_predivide_factor)

View File

@ -8,22 +8,6 @@ from colossalai.utils import is_model_parallel_parameter
import torch.distributed as dist
def move_tensor(input_, device):
assert device in ['cpu', 'gpu']
if isinstance(input_, (list, tuple)):
for tensor in input_:
tensor.data = tensor.data.cpu(
) if device == 'cpu' else tensor.data.cuda()
elif torch.is_tensor(input_):
input_.data = input_.data.cpu(
) if device == 'cpu' else tensor.data.cuda()
else:
raise TypeError(
f"Expected argument 'input_' to be torch.Tensor, list or tuple, but got {type(input_)} "
)
def flatten(input_):
return _flatten_dense_tensors(input_)
@ -51,8 +35,7 @@ def shuffle_by_round_robin(tensor_list, num_partitions):
partition_to_go = tensor_idx % num_partitions
if partition_to_go not in partitions:
partitions[partition_to_go] = []
partitions[partition_to_go].append(dict(tensor=tensor,
index=tensor_idx))
partitions[partition_to_go].append(dict(tensor=tensor, index=tensor_idx))
partitions_count = len(partitions)
new_tensor_list = []
@ -73,9 +56,7 @@ def flatten_dense_tensors_with_padding(tensor_list, unit_size):
padding = calculate_padding(num_elements, unit_size=unit_size)
if padding > 0:
pad_tensor = torch.zeros(padding,
device=tensor_list[0].device,
dtype=tensor_list[0].dtype)
pad_tensor = torch.zeros(padding, device=tensor_list[0].device, dtype=tensor_list[0].dtype)
padded_tensor_list = tensor_list + [pad_tensor]
else:
padded_tensor_list = tensor_list
@ -86,6 +67,7 @@ def flatten_dense_tensors_with_padding(tensor_list, unit_size):
def is_nccl_aligned(tensor):
return tensor.data_ptr() % 4 == 0
def get_grad_accumulate_object(tensor):
"""
Return the AccumulateGrad of the input tensor
@ -108,10 +90,7 @@ def get_grad_accumulate_object(tensor):
def split_half_float_double(tensor_list):
dtypes = [
"torch.cuda.HalfTensor", "torch.cuda.FloatTensor",
"torch.cuda.DoubleTensor", "torch.cuda.BFloat16Tensor"
]
dtypes = ["torch.cuda.HalfTensor", "torch.cuda.FloatTensor", "torch.cuda.DoubleTensor", "torch.cuda.BFloat16Tensor"]
buckets = []
for i, dtype in enumerate(dtypes):
bucket = [t for t in tensor_list if t.type() == dtype]
@ -120,10 +99,7 @@ def split_half_float_double(tensor_list):
return buckets
def reduce_tensor(tensor,
dtype,
dst_rank=None,
parallel_mode=ParallelMode.DATA):
def reduce_tensor(tensor, dtype, dst_rank=None, parallel_mode=ParallelMode.DATA):
"""
Reduce the tensor in the data parallel process group
@ -165,6 +141,7 @@ def reduce_tensor(tensor,
tensor.copy_(tensor_to_reduce)
return tensor
def has_inf_or_nan(tensor):
try:
# if tensor is half, the .float() incurs an additional deep copy, but it's necessary if
@ -181,8 +158,7 @@ def has_inf_or_nan(tensor):
raise
return True
else:
if tensor_sum == float('inf') or tensor_sum == -float(
'inf') or tensor_sum != tensor_sum:
if tensor_sum == float('inf') or tensor_sum == -float('inf') or tensor_sum != tensor_sum:
return True
return False
@ -201,11 +177,7 @@ def calculate_global_norm_from_list(norm_list):
return math.sqrt(total_norm)
def compute_norm(gradients,
params,
dp_group,
mp_group,
norm_type=2):
def compute_norm(gradients, params, dp_group, mp_group, norm_type=2):
"""Clips gradient norm of an iterable of parameters.
This is adapted from torch.nn.utils.clip_grad.clip_grad_norm_ and
added functionality to handle model parallel parameters. Note that
@ -229,14 +201,11 @@ def compute_norm(gradients,
if norm_type == inf:
total_norm = max(g.data.abs().max() for g in gradients)
total_norm_cuda = torch.cuda.FloatTensor([float(total_norm)])
dist.all_reduce(total_norm_cuda,
op=torch.distributed.ReduceOp.MAX,
group=dp_group)
dist.all_reduce(total_norm_cuda, op=torch.distributed.ReduceOp.MAX, group=dp_group)
# Take max across all GPUs.
if mp_group is not None:
dist.all_reduce(tensor=total_norm_cuda,
op=torch.distributed.ReduceOp.MAX)
dist.all_reduce(tensor=total_norm_cuda, op=torch.distributed.ReduceOp.MAX)
total_norm = total_norm_cuda[0].item()
else:
total_norm = 0.0
@ -248,21 +217,17 @@ def compute_norm(gradients,
if is_model_parallel_parameter(p) or mp_rank == 0:
param_norm = g.data.double().norm(2)
total_norm += param_norm.item()**2
# Sum across all model parallel GPUs.
total_norm_cuda = torch.cuda.FloatTensor([float(total_norm)])
torch.distributed.all_reduce(total_norm_cuda,
op=torch.distributed.ReduceOp.SUM,
group=dp_group)
torch.distributed.all_reduce(total_norm_cuda, op=torch.distributed.ReduceOp.SUM, group=dp_group)
if mp_group is not None:
dist.all_reduce(tensor=total_norm_cuda,
op=torch.distributed.ReduceOp.SUM)
dist.all_reduce(tensor=total_norm_cuda, op=torch.distributed.ReduceOp.SUM)
total_norm = total_norm_cuda[0].item()**(1. / norm_type)
if total_norm == float(
'inf') or total_norm == -float('inf') or total_norm != total_norm:
if total_norm == float('inf') or total_norm == -float('inf') or total_norm != total_norm:
total_norm = -1
return total_norm

View File

@ -12,8 +12,8 @@ from colossalai.logging import get_dist_logger
from colossalai.nn.optimizer import ColossalaiOptimizer
from colossalai.utils.memory_tracer.model_data_memtracer import \
GLOBAL_MODEL_DATA_TRACER
from colossalai.zero.shard_utils.tensor_utils import (colo_model_data_tensor_move_inline, colo_model_tensor_clone,
colo_tensor_mem_usage)
from colossalai.zero.sharded_param.tensor_utils import (colo_model_data_tensor_move_inline, colo_model_tensor_clone,
colo_tensor_mem_usage)
from colossalai.zero.sharded_model import ShardedModelV2
from colossalai.zero.sharded_model._utils import cast_tensor_to_fp32
from colossalai.zero.sharded_optim._utils import has_inf_or_nan

View File

@ -1,4 +1,11 @@
from colossalai.zero.sharded_param.sharded_tensor import ShardedTensor
from colossalai.zero.sharded_param.sharded_param import ShardedParamV2
from colossalai.zero.sharded_param.tensor_utils import (colo_model_data_tensor_move, colo_model_data_tensor_move_inline,
colo_model_data_move_to_cpu, colo_model_tensor_clone,
colo_tensor_mem_usage)
from colossalai.zero.sharded_param.tensorful_state import TensorState, StatefulTensor
__all__ = ['ShardedTensor', 'ShardedParamV2']
__all__ = [
'ShardedTensor', 'ShardedParamV2', 'colo_model_data_tensor_move', 'colo_model_data_tensor_move_inline',
'colo_model_data_move_to_cpu', 'colo_model_tensor_clone', 'colo_tensor_mem_usage', 'TensorState', 'StatefulTensor'
]

View File

@ -1,7 +1,7 @@
import torch
from colossalai.zero.sharded_param import ShardedTensor
from typing import Optional, Tuple
from colossalai.zero.shard_utils.tensor_utils import colo_tensor_mem_usage
from colossalai.zero.sharded_param.tensor_utils import colo_tensor_mem_usage
from .tensorful_state import StatefulTensor, TensorState
from typing import List

View File

@ -0,0 +1,4 @@
from .stateful_tensor_mgr import StatefulTensorMgr
from .zero_hook import ZeroHook
__all__ = ['StatefulTensorMgr', 'ZeroHook']

View File

@ -4,7 +4,7 @@ import types
from colossalai.utils.cuda import get_current_device
from colossalai.zero.sharded_param.sharded_param import ShardedParamV2
from colossalai.zero.sharded_param.tensorful_state import StatefulTensor, TensorState
from colossalai.zero.shard_utils.tensor_utils import colo_model_data_tensor_move_inline, colo_tensor_mem_usage
from colossalai.zero.sharded_param.tensor_utils import colo_model_data_tensor_move_inline, colo_tensor_mem_usage
from colossalai.utils.memory import colo_device_memory_capacity
from colossalai.utils.memory_tracer.model_data_memtracer import GLOBAL_MODEL_DATA_TRACER
from typing import Dict, List

View File

@ -3,15 +3,16 @@ from typing import Optional
import torch
import torch.distributed as dist
from colossalai.registry import OPHOOKS
from colossalai.utils import get_current_device
from colossalai.utils.memory_tracer.memstats_collector import MemStatsCollector
from colossalai.zero.shard_utils import BaseShardStrategy
from colossalai.zero.sharded_param.tensorful_state import TensorState
from colossalai.zero.shard_utils.stateful_tensor_mgr import StatefulTensorMgr
from colossalai.zero.utils.stateful_tensor_mgr import StatefulTensorMgr
from colossalai.zero.sharded_param.tensor_utils import colo_model_data_tensor_move_inline
from ._base_ophook import BaseOpHook
from colossalai.zero.shard_utils.tensor_utils import colo_model_data_tensor_move_inline
from colossalai.engine.ophooks import BaseOpHook
@OPHOOKS.register_module

View File

@ -1,4 +1,4 @@
from colossalai.zero.shard_utils.tensor_utils import colo_model_data_tensor_move, colo_model_data_tensor_move_inline
from colossalai.zero.sharded_param.tensor_utils import colo_model_data_tensor_move, colo_model_data_tensor_move_inline
from colossalai.utils import free_port
from colossalai.testing import rerun_on_exception
from colossalai.zero.sharded_param import ShardedTensor

View File

@ -1,11 +1,12 @@
import pytest
import colossalai
from colossalai.utils.cuda import get_current_device
from colossalai.zero.shard_utils.tensor_utils import colo_tensor_mem_usage, colo_model_data_tensor_move, colo_model_data_tensor_move_inline, colo_model_data_move_to_cpu, colo_model_tensor_clone
from colossalai.zero.sharded_param import (StatefulTensor, colo_tensor_mem_usage, colo_model_data_tensor_move,
colo_model_data_tensor_move_inline, colo_model_data_move_to_cpu,
colo_model_tensor_clone)
from colossalai.utils.memory import colo_set_process_memory_fraction, colo_device_memory_capacity
from colossalai.utils import free_port
from colossalai.zero.sharded_param.tensorful_state import StatefulTensor
import colossalai
import torch

View File

@ -30,10 +30,9 @@ def _run_test_found_inf(cpu_offload, shard_strategy_class, gpu_margin_mem_ratio)
get_components_func = non_distributed_component_funcs.get_callable(model_name)
model_builder, train_dataloader, _, optimizer_class, criterion = get_components_func()
with ZeroInitContext(
target_device=torch.device(f'cpu:0') if cpu_offload else torch.device(get_current_device()),
shard_strategy=shard_strategy,
shard_param=True):
with ZeroInitContext(target_device=torch.device(f'cpu:0') if cpu_offload else get_current_device(),
shard_strategy=shard_strategy,
shard_param=True):
zero_model = model_builder(checkpoint=True)
zero_model = ShardedModelV2(
zero_model,

View File

@ -6,7 +6,7 @@ from colossalai.utils.cuda import get_current_device
from colossalai.utils.memory_tracer import MemStatsCollector
from colossalai.utils.memory_tracer.model_data_memtracer import GLOBAL_MODEL_DATA_TRACER
from colossalai.utils.memory import colo_device_memory_capacity, colo_set_process_memory_fraction
from colossalai.zero.shard_utils import StatefulTensorMgr
from colossalai.zero.utils import StatefulTensorMgr
from colossalai.zero.sharded_param.sharded_param import ShardedParamV2
from colossalai.zero.sharded_param.tensorful_state import TensorState
from colossalai.utils import free_port