mirror of https://github.com/hpcaitech/ColossalAI
Feature/zero (#279)
* add zero1 (#209)
* add zero1
* add test zero1
* update zero stage 1 develop (#212)
* Implement naive zero3 (#240)
* naive zero3 works well
* add zero3 param manager
* add TODOs in comments
* add gather full param ctx
* fix sub module streams
* add offload
* fix bugs of hook and add unit tests
* fix bugs of hook and add unit tests (#252)
* add gather full param ctx
* fix sub module streams
* add offload
* fix bugs of hook and add unit tests
* polish code and add state dict hook
* fix bug
* update unit test
* refactor reconstructed zero code
* clip_grad support zero3 and add unit test
* add unit test for Zero3ParameterManager
* [WIP] initialize the shard param class
* [WIP] Yet another sharded model implementation (#274)
* [WIP] initialize the shard param class
* [WIP] Yet another implementation of ShardedModel, using a better hook method.
* torch.concat -> torch.cat
* fix test_zero_level_1.py::test_zero_level_1 unittest
* remove deepspeed implementation and refactor for the reconstructed zero module
* polish zero dp unittests

Co-authored-by: ver217 <lhx0217@gmail.com>
Co-authored-by: Frank Lee <somerlee.9@gmail.com>
parent 08eccfe681
commit 5a560a060a
@ -13,4 +13,4 @@ class ZeROGradientHandler(BaseGradientHandler):

    def handle_gradient(self):
        """A method running an all-reduce operation in a data parallel group.
        """
        self._optimizer.allreduce_gradients()
        self._optimizer.sync_grad()

@ -1,9 +1,10 @@
from ._base_ophook import BaseOpHook
from ._memtracer_ophook import MemTracerOpHook
from ._shard_param_ophook import ShardParamHook
import torch
from typing import List

all = ["BaseOpHook", "MemTracerOpHook", "register_ophooks_recursively"]
all = ["BaseOpHook", "MemTracerOpHook", "register_ophooks_recursively", "ShardParamHook"]


# apply torch.autograd.Function that calls a backward_function to tensors in output

@ -4,7 +4,6 @@ from concurrent.futures import ThreadPoolExecutor
from colossalai.registry import OPHOOKS
from colossalai.logging import get_dist_logger
from time import sleep, time
import psutil
import pickle

@ -0,0 +1,41 @@
import torch
from . import BaseOpHook
from colossalai.registry import OPHOOKS


@OPHOOKS.register_module
class ShardParamHook(BaseOpHook):
    """
    A hook to process sharded params before and after the FWD and BWD operators execute.
    """
    def __init__(self):
        super().__init__()

    def niter(self):
        return self._niter

    def pre_fwd_exec(self, module: torch.nn.Module, *args):
        for param in module.parameters():
            assert hasattr(param, 'ca_attr')
            param.ca_attr.gather()

    def post_fwd_exec(self, module: torch.nn.Module, *args):
        for param in module.parameters():
            assert hasattr(param, 'ca_attr')
            param.ca_attr.shard()

    def pre_bwd_exec(self, module: torch.nn.Module, input, output):
        for param in module.parameters():
            assert hasattr(param, 'ca_attr')
            param.ca_attr.gather()

    def post_bwd_exec(self, module: torch.nn.Module, input):
        for param in module.parameters():
            assert hasattr(param, 'ca_attr')
            param.ca_attr.shard()

    def pre_iter(self):
        pass

    def post_iter(self):
        pass

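The hook assumes every parameter already carries a `ca_attr` wrapper (the `ShardParam` class added later in this diff). A minimal wiring sketch, assuming `torch.distributed` / the ColossalAI context is already initialized and assuming the module paths and the `register_ophooks_recursively(module, ophook_list)` signature implied by the exports above:

```python
import torch
from colossalai.engine.ophooks import ShardParamHook, register_ophooks_recursively  # path assumed
from colossalai.zero.shard_param import ShardParam  # path assumed

model = torch.nn.Linear(1024, 1024).cuda()

# Give every parameter the `ca_attr` handle the hook asserts on, and keep
# only the local shard outside of the compute window.
for p in model.parameters():
    p.ca_attr = ShardParam(p)
    p.ca_attr.shard()

# Install the hook on every submodule: params are gathered right before each
# FWD/BWD op and re-sharded right after it.
register_ophooks_recursively(model, [ShardParamHook()])
```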
@ -12,8 +12,7 @@ from colossalai.core import global_context as gpc
from colossalai.logging import get_dist_logger
from colossalai.utils import switch_virtual_pipeline_parallel_rank
from colossalai.utils.cuda import get_current_device
from colossalai.zero import (ZeroRedundancyOptimizer_Level_2,
                             ZeroRedundancyOptimizer_Level_3)
from colossalai.zero import ShardedOptimizer, ShardedModel

from ._base_schedule import BaseSchedule

@ -91,9 +90,10 @@ class PipelineSchedule(BaseSchedule):
        return self._move_to_device(data), self._move_to_device(label)

    def pre_processing(self, engine):
        if isinstance(engine.optimizer, (ZeroRedundancyOptimizer_Level_2, ZeroRedundancyOptimizer_Level_3)):
        # TODO: remove this after testing new zero with pipeline parallelism
        if isinstance(engine.optimizer, ShardedOptimizer) or isinstance(engine.model, ShardedModel):
            raise TypeError(
                "Pipeline schedule is currently not compatible with ZeRO Level 2 and Level 3"
                "Pipeline schedule is currently not compatible with ZeRO"
            )
        model = engine.model
        if isinstance(model, NaiveAMPModel):

@ -2,30 +2,31 @@
|
|||
# -*- encoding: utf-8 -*-
|
||||
|
||||
import argparse
|
||||
import pprint
|
||||
import os
|
||||
from colossalai.nn.optimizer.colossalai_optimizer import ColossalaiOptimizer
|
||||
import pprint
|
||||
from pathlib import Path
|
||||
from typing import Dict, Iterable, List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from pathlib import Path
|
||||
from typing import Iterable, Union, Optional, Tuple, List, Dict
|
||||
|
||||
from colossalai.amp import convert_to_amp, AMP_TYPE
|
||||
from colossalai.context import Config, ParallelMode, ConfigException
|
||||
from colossalai.core import global_context as gpc
|
||||
from colossalai.engine import Engine
|
||||
from colossalai.logging import get_dist_logger
|
||||
from colossalai.utils import (accumulate_gradient, get_current_device,
|
||||
sync_model_param, is_using_ddp, is_using_pp, is_using_sequence)
|
||||
from colossalai.zero import convert_to_zero, ZeroRedundancyOptimizer_Level_2, ZeroRedundancyOptimizer_Level_3
|
||||
from colossalai.builder.builder import build_gradient_handler
|
||||
from torch.optim.optimizer import Optimizer
|
||||
from torch.optim.lr_scheduler import _LRScheduler
|
||||
from torch.utils.data import DataLoader
|
||||
from torch.nn.modules.loss import _Loss
|
||||
from torch.nn.parallel import DistributedDataParallel as DDP
|
||||
from torch.optim.lr_scheduler import _LRScheduler
|
||||
from torch.optim.optimizer import Optimizer
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
from colossalai.amp import AMP_TYPE, convert_to_amp
|
||||
from colossalai.builder.builder import build_gradient_handler
|
||||
from colossalai.context import Config, ConfigException, ParallelMode
|
||||
from colossalai.core import global_context as gpc
|
||||
from colossalai.engine import Engine
|
||||
from colossalai.global_variables import moe_env
|
||||
from colossalai.logging import get_dist_logger
|
||||
from colossalai.nn.optimizer.colossalai_optimizer import ColossalaiOptimizer
|
||||
from colossalai.utils import (accumulate_gradient, get_current_device,
|
||||
is_using_ddp, is_using_pp, is_using_sequence,
|
||||
sync_model_param)
|
||||
from colossalai.zero import convert_to_zero, ShardedOptimizer
|
||||
|
||||
|
||||
def get_default_parser():
|
||||
|
@ -332,8 +333,7 @@ def initialize(model: Union[nn.Module, List[nn.Module]],
        # 1. if optimizer is ZERO, then use zero grad handler
        # 2. if dp size is larger than 1 and pipeline is not used, use pytorch ddp
        # 3. if using pipeline and dp size larger than 1, use data parallel grad handler
        if isinstance(optimizer, (ZeroRedundancyOptimizer_Level_2,
                                  ZeroRedundancyOptimizer_Level_3)):
        if isinstance(optimizer, ShardedOptimizer):
            gradient_handler_cfg = [dict(type='ZeROGradientHandler')]
            if verbose:
                logger.info(

@ -348,7 +348,8 @@ def initialize(model: Union[nn.Module, List[nn.Module]],
                "added even though not specified in the configuration",
                ranks=[0])
    elif is_using_sequence():
        model = DDP(model, process_group=gpc.get_group(ParallelMode.SEQUENCE_DP), device_ids=[torch.cuda.current_device()])
        model = DDP(model, process_group=gpc.get_group(ParallelMode.SEQUENCE_DP),
                    device_ids=[torch.cuda.current_device()])
        if verbose:
            logger.info(
                'Model is using torch.nn.parallel.DistributedDataParallel for Sequence Parallelism', ranks=[0])

@ -393,7 +394,7 @@ def initialize(model: Union[nn.Module, List[nn.Module]],
    gradient_handlers = [build_gradient_handler(cfg, model, optimizer) for cfg in gradient_handler_cfg]

    # check if optimizer is ColossalaiOptimizer
    if not isinstance(optimizer, (ColossalaiOptimizer, ZeroRedundancyOptimizer_Level_2, ZeroRedundancyOptimizer_Level_3)):
    if not isinstance(optimizer, (ColossalaiOptimizer, ShardedOptimizer)):
        optimizer = ColossalaiOptimizer(optim=optimizer)

    # gradient accumulation

@ -1,9 +1,13 @@
|
|||
from .activation_checkpoint import checkpoint
|
||||
from .common import (clip_grad_norm_fp32, conditional_context, copy_tensor_parallel_attributes, count_zeros_fp32,
|
||||
free_port, is_dp_rank_0, is_model_parallel_parameter, is_no_pp_or_last_stage, is_tp_rank_0,
|
||||
is_using_ddp, is_using_pp, is_using_sequence, model_branch_context, multi_tensor_applier,
|
||||
param_is_not_tensor_parallel_duplicate, print_rank_0, switch_virtual_pipeline_parallel_rank,
|
||||
sync_model_param)
|
||||
|
||||
from .common import (clip_grad_norm_fp32, conditional_context,
|
||||
copy_tensor_parallel_attributes, count_zeros_fp32,
|
||||
free_port, is_dp_rank_0, is_model_parallel_parameter,
|
||||
is_moe_parallel_parameter, is_no_pp_or_last_stage,
|
||||
is_tp_rank_0, is_using_ddp, is_using_pp,
|
||||
is_using_sequence, multi_tensor_applier,
|
||||
param_is_not_tensor_parallel_duplicate, print_rank_0,
|
||||
switch_virtual_pipeline_parallel_rank, sync_model_param)
|
||||
from .cuda import empty_cache, get_current_device, set_to_cuda, synchronize
|
||||
from .data_sampler import DataParallelSampler, get_dataloader
|
||||
from .gradient_accumulation import accumulate_gradient
|
||||
|
@ -12,9 +16,9 @@ from .timer import MultiTimer, Timer
|
|||
|
||||
__all__ = [
|
||||
'checkpoint', 'free_port', 'print_rank_0', 'sync_model_param', 'is_dp_rank_0', 'is_tp_rank_0',
|
||||
'is_no_pp_or_last_stage', 'is_using_ddp', 'is_using_pp', 'is_using_sequence', 'model_branch_context',
|
||||
'conditional_context', 'is_model_parallel_parameter', 'clip_grad_norm_fp32', 'count_zeros_fp32',
|
||||
'copy_tensor_parallel_attributes', 'param_is_not_tensor_parallel_duplicate', 'get_current_device', 'synchronize',
|
||||
'empty_cache', 'set_to_cuda', 'report_memory_usage', 'Timer', 'MultiTimer', 'multi_tensor_applier',
|
||||
'accumulate_gradient', 'DataParallelSampler', 'get_dataloader', 'switch_virtual_pipeline_parallel_rank'
|
||||
'is_no_pp_or_last_stage', 'is_using_ddp', 'is_using_pp', 'is_using_sequence', 'conditional_context',
|
||||
'is_model_parallel_parameter', 'clip_grad_norm_fp32', 'count_zeros_fp32', 'copy_tensor_parallel_attributes',
|
||||
'param_is_not_tensor_parallel_duplicate', 'get_current_device', 'synchronize', 'empty_cache', 'set_to_cuda',
|
||||
'report_memory_usage', 'Timer', 'MultiTimer', 'multi_tensor_applier', 'accumulate_gradient', 'DataParallelSampler',
|
||||
'get_dataloader', 'switch_virtual_pipeline_parallel_rank', 'is_moe_parallel_parameter'
|
||||
]
|
||||
|
|
|
@ -2,9 +2,12 @@
|
|||
# -*- encoding: utf-8 -*-
|
||||
import random
|
||||
import socket
|
||||
from typing import List, Union
|
||||
|
||||
import torch
|
||||
from torch._six import inf
|
||||
from torch.nn.parameter import Parameter
|
||||
|
||||
|
||||
try:
|
||||
import colossal_C
|
||||
|
@ -14,7 +17,8 @@ except:
|
|||
from contextlib import contextmanager
|
||||
|
||||
import torch.distributed as dist
|
||||
from colossalai.constants import IS_TENSOR_PARALLEL, NUM_PARTITIONS, TENSOR_PARALLEL_ATTRIBUTES
|
||||
from colossalai.constants import (IS_TENSOR_PARALLEL, NUM_PARTITIONS,
|
||||
TENSOR_PARALLEL_ATTRIBUTES)
|
||||
from colossalai.context.parallel_mode import ParallelMode
|
||||
from colossalai.core import global_context as gpc
|
||||
from colossalai.global_variables import moe_env
|
||||
|
@ -134,6 +138,10 @@ def _calc_lp(grads, norm_type):
|
|||
norm += grad_norm**norm_type
|
||||
return norm
|
||||
|
||||
def _move_norm_to_cuda(norm: Union[float, torch.Tensor]) -> Union[float, torch.Tensor]:
|
||||
if torch.is_tensor(norm) and norm.device.type != 'cuda':
|
||||
norm = norm.to(torch.cuda.current_device())
|
||||
return norm
|
||||
|
||||
# ======== Gradient Clipping =========
|
||||
|
||||
|
@ -163,17 +171,27 @@ def clip_grad_norm_fp32(parameters, max_norm, norm_type=2):
|
|||
# - grad should not be none
|
||||
# - parameter should not be shared
|
||||
# - should not be a replica due to tensor model parallelism
|
||||
params = []
|
||||
params: List[Parameter] = []
|
||||
has_zero_shared_param: bool = False
|
||||
for param in parameters:
|
||||
if param.grad is not None:
|
||||
# Make sure the grads are in fp32
|
||||
assert param.grad.type() == 'torch.cuda.FloatTensor', \
|
||||
f'expected gradient to be dtype torch.cuda.FloatTensor, but got {param.grad.type()}'
|
||||
assert param.grad.dtype == torch.float, \
|
||||
f'expected gradient to be dtype torch.float, but got {param.grad.type()}'
|
||||
if hasattr(param, 'zero_is_sharded'):
|
||||
has_zero_shared_param = True
|
||||
params.append(param)
|
||||
|
||||
if len(params) == 0:
|
||||
return 0.0
|
||||
# Norm parameters.
|
||||
max_norm = float(max_norm)
|
||||
norm_type = float(norm_type)
|
||||
|
||||
# Parameters can be on CPU or CUDA
|
||||
# If parameters are on CPU, disable CUDA kernels
|
||||
enable_cuda_kernels = params[0].grad.device.type == 'cuda'
|
||||
|
||||
# Calculate norm.
|
||||
if norm_type == inf:
|
||||
total_norm = max(p.grad.data.abs().max() for p in params)
|
||||
|
@ -184,28 +202,49 @@ def clip_grad_norm_fp32(parameters, max_norm, norm_type=2):
|
|||
op=dist.ReduceOp.MAX,
|
||||
group=gpc.get_group(ParallelMode.MODEL),
|
||||
async_op=False)
|
||||
if has_zero_shared_param:
|
||||
dist.all_reduce(total_norm_cuda,
|
||||
op=dist.ReduceOp.MAX,
|
||||
group=gpc.get_group(ParallelMode.DATA),
|
||||
async_op=False)
|
||||
total_norm = total_norm_cuda[0].item()
|
||||
else:
|
||||
tensor_parallel_grads = []
|
||||
no_tensor_parallel_grads = []
|
||||
moe_parallel_grads = [] # used to collect moe tensor parallel gradients
|
||||
zero_sharded_grads = []
|
||||
for p in params:
|
||||
if is_model_parallel_parameter(p):
|
||||
reductor = (gpc.get_world_size(ParallelMode.TENSOR) / getattr(p, NUM_PARTITIONS))**(1 / norm_type)
|
||||
tensor_parallel_grads.append(p.grad.data / reductor)
|
||||
elif is_moe_parallel_parameter(p):
|
||||
moe_parallel_grads.append(p.grad.data)
|
||||
elif hasattr(p, 'zero_is_sharded'):
|
||||
zero_sharded_grads.append(p.grad.data)
|
||||
else:
|
||||
no_tensor_parallel_grads.append(p.grad.data)
|
||||
|
||||
if norm_type == 2.0:
|
||||
tensor_parallel_norm = _calc_l2_norm(tensor_parallel_grads)**norm_type
|
||||
no_tensor_parallel_norm = _calc_l2_norm(no_tensor_parallel_grads)**norm_type
|
||||
moe_parallel_norm = _calc_l2_norm(moe_parallel_grads)**norm_type
|
||||
if norm_type == 2.0 and enable_cuda_kernels:
|
||||
tensor_parallel_norm = _calc_l2_norm(
|
||||
tensor_parallel_grads) ** norm_type
|
||||
no_tensor_parallel_norm = _calc_l2_norm(
|
||||
no_tensor_parallel_grads) ** norm_type
|
||||
moe_parallel_norm = _calc_l2_norm(
|
||||
moe_parallel_grads) ** norm_type
|
||||
zero_sharded_norm = _calc_l2_norm(zero_sharded_grads) ** norm_type
|
||||
else:
|
||||
tensor_parallel_norm = _calc_lp(tensor_parallel_grads, norm_type)
|
||||
no_tensor_parallel_norm = _calc_lp(no_tensor_parallel_grads, norm_type)
|
||||
moe_parallel_norm = _calc_lp(moe_parallel_grads, norm_type)
|
||||
zero_sharded_norm = _calc_lp(zero_sharded_grads, norm_type)
|
||||
|
||||
# If grads are on CPU, the norms are also on CPU. Cast them to CUDA tensors
|
||||
if not enable_cuda_kernels:
|
||||
tensor_parallel_norm = _move_norm_to_cuda(tensor_parallel_norm)
|
||||
no_tensor_parallel_norm = _move_norm_to_cuda(no_tensor_parallel_norm)
|
||||
moe_parallel_norm = _move_norm_to_cuda(moe_parallel_norm)
|
||||
zero_sharded_norm = _move_norm_to_cuda(zero_sharded_norm)
|
||||
|
||||
# Sum across all model-parallel GPUs.
|
||||
if gpc.is_initialized(ParallelMode.TENSOR) and len(tensor_parallel_grads) > 0:
|
||||
dist.all_reduce(tensor_parallel_norm, op=dist.ReduceOp.SUM, group=gpc.get_group(ParallelMode.TENSOR))
|
||||
|
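For reference, the per-bucket contributions computed above (and all-reduced over their respective parallel groups) are combined as p-th powers, with a single p-th root taken at the end of the function. A single-process sketch of that reduction, with hypothetical bucket names:

```python
import torch

def total_grad_norm(grad_buckets, norm_type: float = 2.0) -> float:
    """Single-process reference for the reduction: sum of per-bucket p-th powers,
    then one p-th root at the end (no parallel-group all-reduces here)."""
    total = 0.0
    for grads in grad_buckets:   # e.g. [tensor_parallel, moe, zero_sharded, rest]
        total += sum(g.norm(norm_type).item() ** norm_type for g in grads)
    return total ** (1.0 / norm_type)

# total_grad_norm([[torch.randn(3)], [torch.randn(4), torch.randn(5)]])
```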
@ -213,20 +252,32 @@ def clip_grad_norm_fp32(parameters, max_norm, norm_type=2):
|
|||
if len(moe_parallel_grads) > 0:
|
||||
dist.all_reduce(moe_parallel_norm, group=gpc.get_group(ParallelMode.MOE_MODEL))
|
||||
no_tensor_parallel_norm += moe_parallel_norm
|
||||
# Sum across all zero sharded GPUs
|
||||
if len(zero_sharded_grads) > 0:
|
||||
dist.all_reduce(zero_sharded_norm, group=gpc.get_group(ParallelMode.DATA))
|
||||
no_tensor_parallel_norm += zero_sharded_norm
|
||||
total_norm = tensor_parallel_norm + no_tensor_parallel_norm
|
||||
if gpc.is_initialized(ParallelMode.PIPELINE) and gpc.get_world_size(ParallelMode.PIPELINE) > 1:
|
||||
dist.all_reduce(total_norm, op=dist.ReduceOp.SUM, group=gpc.get_group(ParallelMode.PIPELINE))
|
||||
total_norm = total_norm**(1.0 / norm_type)
|
||||
if type(total_norm) == 'torch.cuda.FloatTensor':
|
||||
dist.all_reduce(total_norm,
|
||||
op=dist.ReduceOp.SUM,
|
||||
group=gpc.get_group(ParallelMode.PIPELINE))
|
||||
total_norm = total_norm ** (1.0 / norm_type)
|
||||
if torch.is_tensor(total_norm):
|
||||
total_norm = total_norm.item()
|
||||
|
||||
# Scale.
|
||||
clip_coeff = max_norm / (total_norm + 1.0e-6)
|
||||
if clip_coeff < 1.0:
|
||||
grads = [p.grad.detach() for p in params]
|
||||
dummy_overflow_buf = torch.cuda.IntTensor([0])
|
||||
multi_tensor_applier(colossal_C.multi_tensor_scale, dummy_overflow_buf, [grads, grads], clip_coeff)
|
||||
|
||||
if enable_cuda_kernels:
|
||||
grads = [p.grad.detach() for p in params]
|
||||
dummy_overflow_buf = torch.cuda.IntTensor([0])
|
||||
multi_tensor_applier(colossal_C.multi_tensor_scale,
|
||||
dummy_overflow_buf,
|
||||
[grads, grads],
|
||||
clip_coeff)
|
||||
else:
|
||||
for p in params:
|
||||
p.grad.detach().mul_(clip_coeff)
|
||||
return total_norm
|
||||
|
||||
|
||||
|
|
|
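A hedged sketch of how the updated `clip_grad_norm_fp32` is typically called. This call site is not part of the diff, and it assumes the global context has already been set up via `colossalai.launch` so the parallel groups consulted inside the function exist:

```python
import torch
import torch.nn as nn
from colossalai.utils import clip_grad_norm_fp32

model = nn.Linear(16, 4).cuda()          # fp32 params -> fp32 grads, as the assert requires
optimizer = torch.optim.SGD(model.parameters(), lr=1e-2)

out = model(torch.randn(8, 16, device='cuda'))
loss = out.pow(2).mean()
loss.backward()

# Returns the pre-clipping global norm; grads are rescaled in place only when
# that norm exceeds max_norm (the fused CUDA kernel is now used only for CUDA grads).
grad_norm = clip_grad_norm_fp32(model.parameters(), max_norm=1.0, norm_type=2)
optimizer.step()
```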
@ -1,13 +1,12 @@
from distutils.command.config import config
import torch
import torch.nn as nn
from torch.optim import Optimizer
from colossalai.amp.naive_amp import NaiveAMPModel
from colossalai.utils import is_no_pp_or_last_stage
from colossalai.core import global_context as gpc
from colossalai.context.parallel_mode import ParallelMode

from .zero_redundancy_optimizer_level_2 import ZeroRedundancyOptimizer_Level_2
from .zero_redundancy_optimizer_level_3 import ZeroRedundancyOptimizer_Level_3
from colossalai.core import global_context as gpc
from torch.optim import Optimizer
from .sharded_model import ShardedModel
from .sharded_optim import ShardedOptimizer


def convert_to_zero(model: nn.Module,

@ -29,82 +28,14 @@ def convert_to_zero(model: nn.Module,
|
|||
:return: (model, optimizer)
|
||||
:rtype: Tuple
|
||||
"""
|
||||
import deepspeed
|
||||
assert level == 2 or level == 3, 'Only ZERO Optimizer Level 2 and 3 are provided'
|
||||
model = NaiveAMPModel(model, output_to_fp32=False)
|
||||
|
||||
if level == 2:
|
||||
optimizer = ZeroRedundancyOptimizer_Level_2(init_optimizer=optimizer, **zero_config)
|
||||
assert 1 <= level <= 3, 'Only ZERO Optimizer Level 1-3 are provided'
|
||||
if level in [1, 2]:
|
||||
if level == 2:
|
||||
assert config['partition_grad'], 'ZeRO Optimizer requires partition_grad to be True'
|
||||
model = NaiveAMPModel(model, output_to_fp32=True)
|
||||
optimizer = ShardedOptimizer(model.parameters(), *zero_config)
|
||||
else:
|
||||
optimizer = ZeroRedundancyOptimizer_Level_3(init_optimizer=optimizer, module=model, **zero_config)
|
||||
model = ShardedModel(module=model, **zero_config)
|
||||
return model, optimizer
|
||||
|
||||
|
||||
def zero3_model_context(dtype=torch.half):
|
||||
"""A context to enable massive model construction for training with
|
||||
ZeRO-3. Models are automatically partitioned (or, sharded) across the
|
||||
system and converted to half precision. Note that the config of ZeRO-3 will be loaded automatically from `gpc.config`.
|
||||
|
||||
Args:
|
||||
dtype (``dtype``, optional): Can be used to change the data type of the parameters.
|
||||
Supported options are ``torch.half`` and ``torch.float``. Defaults to ``torch.half``
|
||||
|
||||
This context accelerates model initialization and enables models that
|
||||
are too large to allocate in their entirety in CPU memory. It has the
|
||||
following effects:
|
||||
|
||||
#. allocates tensors to either GPU or CPU memory or NVMe
|
||||
#. converts floating point tensors to half precision
|
||||
#. immediately partitions tensors among the group of data-parallel devices
|
||||
#. (*optional*) replaces ``torch.nn.functional.linear`` with a more
|
||||
memory-efficient implementation
|
||||
|
||||
These modifications allow for models that exceed the size of local CPU/GPU
|
||||
memory/NVMe, but fit within the total NVMe capacity (*i.e.*, aggregate CPU
|
||||
or GPU memory or NVMe) across all nodes. Consider initializing a model with one
|
||||
trillion parameters, whose weights occupy two terabytes (TB) in half
|
||||
precision. The initial CPU allocation in full precision requires 4TB of
|
||||
memory *per process*, and so a system with 8 GPUs per node would need 32TB of
|
||||
CPU memory due to data-parallel redundancies. Instead, by immediately
|
||||
partitioning tensors we remove the redundancies. The result is that
|
||||
regardless of the number of GPUs, we still only require the original 4TB. This
|
||||
allows for a linear increase in model size with the aggregate system memory.
|
||||
For example, if a node has 1TB of memory and 8 GPUs, we could fit a trillion
|
||||
parameter model with 4 nodes and 32 GPUs.
|
||||
|
||||
Important: If the fp16 weights of the model can't fit onto a single GPU memory
|
||||
this feature must be used.
|
||||
|
||||
Examples
|
||||
--------
|
||||
|
||||
#. Allocate a model and partition it among all processes:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
with zero3_model_context():
|
||||
model = MyLargeModel()
|
||||
|
||||
"""
|
||||
assert dtype == torch.half or dtype == torch.float, f'Invalid dtype, except torch.half or torch.float, got {dtype}'
|
||||
import deepspeed
|
||||
ds_config = {
|
||||
"train_micro_batch_size_per_gpu": 1,
|
||||
"gradient_accumulation_steps": 1,
|
||||
"zero_optimization": {
|
||||
"offload_param": getattr(gpc.config.zero, 'offload_param_config', None),
|
||||
"offload_optimizer": getattr(gpc.config.zero, 'offload_optimizer_config'),
|
||||
},
|
||||
"aio": getattr(gpc.config.zero, 'aio_config', None)
|
||||
}
|
||||
remote_device = getattr(ds_config['zero_optimization']['offload_param'], 'device', None)
|
||||
pin_memory = getattr(ds_config['zero_optimization']['offload_param'], 'pin_memory', False)
|
||||
return deepspeed.zero.Init(data_parallel_group=gpc.get_group(ParallelMode.DATA),
|
||||
remote_device=remote_device,
|
||||
config_dict_or_path=ds_config,
|
||||
pin_memory=pin_memory,
|
||||
dtype=dtype)
|
||||
|
||||
|
||||
__all__ = ['convert_to_zero', 'ZeroRedundancyOptimizer_Level_2',
|
||||
'ZeroRedundancyOptimizer_Level_3', 'zero3_model_context']
|
||||
__all__ = ['convert_to_zero', 'ShardedModel', 'ShardedOptimizer']
|
||||
|
|
|
@ -1,169 +0,0 @@
|
|||
# Copyright 2019 The Microsoft DeepSpeed Team
|
||||
# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# Taken and modified for DeepSpeed from:
|
||||
# https://github.com/NVIDIA/Megatron-LM/blob/master/fp16/loss_scaler.py
|
||||
# Commit: 93ab4bea59dc5cbf97c079d313741866af4deac9
|
||||
|
||||
|
||||
INITIAL_LOSS_SCALE = 'init_scale'
|
||||
SCALE_WINDOW = 'scale_window'
|
||||
DELAYED_SHIFT = 'delayed_shift'
|
||||
MIN_LOSS_SCALE = 'min_scale'
|
||||
|
||||
|
||||
# item() is a recent addition, so this helps with backward compatibility.
|
||||
def to_python_float(t):
|
||||
if hasattr(t, 'item'):
|
||||
return t.item()
|
||||
return t[0]
|
||||
|
||||
|
||||
class LossScalerBase:
|
||||
"""LossScalarBase
|
||||
Base class for a loss scaler
|
||||
"""
|
||||
|
||||
def __init__(self, cur_scale):
|
||||
self.cur_scale = cur_scale
|
||||
|
||||
@property
|
||||
def loss_scale(self):
|
||||
return self.cur_scale
|
||||
|
||||
def scale_gradient(self, module, grad_in, grad_out):
|
||||
return tuple(self.loss_scale * g for g in grad_in)
|
||||
|
||||
def update_scale(self, overflow):
|
||||
pass
|
||||
|
||||
def backward(self, loss, retain_graph=False):
|
||||
scaled_loss = loss * self.loss_scale
|
||||
scaled_loss.backward(retain_graph=retain_graph)
|
||||
|
||||
|
||||
class LossScaler(LossScalerBase):
|
||||
"""
|
||||
Class that manages a static loss scale. This class is intended to interact with
|
||||
:class:`FP16_Optimizer`, and should not be directly manipulated by the user.
|
||||
Use of :class:`LossScaler` is enabled via the ``static_loss_scale`` argument to
|
||||
:class:`FP16_Optimizer`'s constructor.
|
||||
Args:
|
||||
scale (float, optional, default=1.0): The loss scale.
|
||||
"""
|
||||
|
||||
def __init__(self, scale=1):
|
||||
super(LossScaler, self).__init__(scale)
|
||||
|
||||
# `params` is a list / generator of torch.Variable
|
||||
def has_overflow(self, params):
|
||||
return False
|
||||
|
||||
# `x` is a torch.Tensor
|
||||
def _has_inf_or_nan(x):
|
||||
return False
|
||||
|
||||
|
||||
class DynamicLossScaler(LossScalerBase):
|
||||
"""
|
||||
Class that manages dynamic loss scaling. It is recommended to use :class:`DynamicLossScaler`
|
||||
indirectly, by supplying ``dynamic_loss_scale=True`` to the constructor of
|
||||
:class:`FP16_Optimizer`. However, it's important to understand how :class:`DynamicLossScaler`
|
||||
operates, because the default options can be changed using
|
||||
the ``dynamic_loss_args`` argument to :class:`FP16_Optimizer`'s constructor.
|
||||
Loss scaling is designed to combat the problem of underflowing gradients encountered at long
|
||||
times when training fp16 networks. Dynamic loss scaling begins by attempting a very high loss
|
||||
scale. Ironically, this may result in OVERflowing gradients. If overflowing gradients are
|
||||
encountered, :class:`DynamicLossScaler` informs :class:`FP16_Optimizer` that an overflow has
|
||||
occurred.
|
||||
:class:`FP16_Optimizer` then skips the update step for this particular iteration/minibatch,
|
||||
and :class:`DynamicLossScaler` adjusts the loss scale to a lower value.
|
||||
If a certain number of iterations occur without overflowing gradients detected,
|
||||
:class:`DynamicLossScaler` increases the loss scale once more.
|
||||
In this way :class:`DynamicLossScaler` attempts to "ride the edge" of
|
||||
always using the highest loss scale possible without incurring overflow.
|
||||
Args:
|
||||
init_scale (float, optional, default=2**32): Initial loss scale attempted by :class:`DynamicLossScaler.`
|
||||
scale_factor (float, optional, default=2.0): Factor used when adjusting the loss scale. If an overflow is
|
||||
encountered, the loss scale is readjusted to loss scale/``scale_factor``. If ``scale_window`` consecutive
|
||||
iterations take place without an overflow, the loss scale is readjusted to loss_scale*``scale_factor``.
|
||||
scale_window (int, optional, default=1000): Number of consecutive iterations without an overflow to wait before
|
||||
increasing the loss scale.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
init_scale=2 ** 32,
|
||||
scale_factor=2.,
|
||||
scale_window=1000,
|
||||
min_scale=1,
|
||||
delayed_shift=1,
|
||||
consecutive_hysteresis=False):
|
||||
super(DynamicLossScaler, self).__init__(init_scale)
|
||||
self.cur_iter = 0
|
||||
self.last_overflow_iter = -1
|
||||
self.scale_factor = scale_factor
|
||||
self.scale_window = scale_window
|
||||
self.min_scale = min_scale
|
||||
self.delayed_shift = delayed_shift
|
||||
self.cur_hysteresis = delayed_shift
|
||||
self.consecutive_hysteresis = consecutive_hysteresis
|
||||
|
||||
# `params` is a list / generator of torch.Variable
|
||||
def has_overflow_serial(self, params):
|
||||
for p in params:
|
||||
if p.grad is not None and self._has_inf_or_nan(p.grad.data):
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
# `x` is a torch.Tensor
|
||||
@staticmethod
|
||||
def _has_inf_or_nan(x):
|
||||
try:
|
||||
# if x is half, the .float() incurs an additional deep copy, but it's necessary if
|
||||
# Pytorch's .sum() creates a one-element tensor of the same type as x
|
||||
# (which is true for some recent version of pytorch).
|
||||
cpu_sum = float(x.float().sum())
|
||||
# More efficient version that can be used if .sum() returns a Python scalar
|
||||
# cpu_sum = float(x.sum())
|
||||
except RuntimeError as instance:
|
||||
# We want to check if inst is actually an overflow exception.
|
||||
# RuntimeError could come from a different error.
|
||||
# If so, we still want the exception to propagate.
|
||||
if "value cannot be converted" not in instance.args[0]:
|
||||
raise
|
||||
return True
|
||||
else:
|
||||
if cpu_sum in [float('inf'), -float('inf')] or cpu_sum != cpu_sum:
|
||||
return True
|
||||
return False
|
||||
|
||||
# `overflow` is boolean indicating whether the gradient overflowed
|
||||
def update_scale(self, overflow):
|
||||
if overflow:
|
||||
# self.cur_scale /= self.scale_factor
|
||||
if self.delayed_shift == 1 or self.cur_hysteresis == 1:
|
||||
self.cur_scale = max(
|
||||
self.cur_scale / self.scale_factor, self.min_scale)
|
||||
else:
|
||||
self.cur_hysteresis -= 1
|
||||
self.last_overflow_iter = self.cur_iter
|
||||
else:
|
||||
if self.consecutive_hysteresis:
|
||||
self.cur_hysteresis = self.delayed_shift
|
||||
if (self.cur_iter - self.last_overflow_iter) % self.scale_window == 0:
|
||||
if not self.consecutive_hysteresis:
|
||||
self.cur_hysteresis = self.delayed_shift
|
||||
self.cur_scale *= self.scale_factor
|
||||
self.cur_iter += 1
|
|
@ -0,0 +1,3 @@
from .shard_param import ShardParam

__all__ = ['ShardParam']

@ -0,0 +1,63 @@
from enum import Enum
from optparse import Option
import torch
from colossalai.zero.sharded_model._zero3_utils import get_shard
from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc
import torch.distributed as dist


class TensorType(Enum):
    GRAD = 1
    DATA = 2


class ShardParam(object):
    r"""
    A wrapper of torch.nn.Parameter. Shards a parameter
    across different processes.
    """
    def __init__(self,
                 param: torch.nn.Parameter,
                 tensor_type: TensorType = TensorType.DATA,
                 process_group=None,
                 ) -> None:
        self.process_group = process_group or gpc.get_group(ParallelMode.DATA)
        self.world_size = dist.get_world_size(self.process_group)
        self.local_rank = dist.get_rank(self.process_group)
        self._param_payload = param.data if tensor_type == TensorType.DATA else param.grad
        self._payload_numel = None
        self._origin_shape = param.shape
        self._origin_numel = param.numel()
        self.is_shared = False

    def payload(self, target_device: torch.device):
        return self._param_payload.to(target_device)

    def shard(self):
        r"""
        Distribute the payload of the param across all processes.
        """
        if self.is_shared:
            return
        self._param_payload, _ = get_shard(self._param_payload, self.local_rank, self.world_size)
        self.is_shared = True

    def gather(self):
        r"""
        Collect the payload of the param from the other processes onto the local rank.
        """
        if not self.is_shared:
            return

        buffer_list = []
        payload_numel = self._param_payload.numel()
        for i in range(self.world_size):
            if i == self.local_rank:
                buffer_list.append(self._param_payload.cuda())
            else:
                buffer_list.append(torch.zeros(payload_numel).cuda())

        torch.distributed.all_gather(buffer_list, buffer_list[self.local_rank], group=self.process_group, async_op=False)
        print(buffer_list)
        self._param_payload = torch.narrow(torch.cat(buffer_list), 0, 0, self._origin_numel).view(self._origin_shape)
        self.is_shared = False

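A minimal round-trip sketch for `ShardParam`, assuming `torch.distributed` has been initialized; the explicit `process_group` here just sidesteps the default lookup through `gpc`, and the import path is an assumption:

```python
import torch
import torch.distributed as dist
from colossalai.zero.shard_param import ShardParam  # module path assumed

param = torch.nn.Parameter(torch.randn(4, 1024, device='cuda'))
sp = ShardParam(param, process_group=dist.group.WORLD)

sp.shard()                                    # keep only ~1/world_size of the flattened payload
local = sp.payload(torch.device('cuda'))      # padded local shard

sp.gather()                                   # all-gather, trim padding, restore the shape
full = sp.payload(torch.device('cuda'))
assert full.shape == param.shape
```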
@ -0,0 +1,4 @@
from .sharded_model import ShardedModel
from .sharded_model_v2 import ShardedModelV2

__all__ = ['ShardedModel', 'ShardedModelV2']

@ -0,0 +1,124 @@
|
|||
|
||||
from collections import OrderedDict
|
||||
from typing import Any, Callable, Dict, List, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
def get_gradient_predivide_factor(world_size: int) -> float:
|
||||
factor: int = 1
|
||||
while world_size % factor == 0 and world_size / factor > factor:
|
||||
factor *= 2
|
||||
return float(factor)
|
||||
|
||||
|
||||
def get_shard(tensor: torch.Tensor, rank: int, world_size: int) -> Tuple[torch.Tensor, int]:
|
||||
"""Return the local shard of a full tensor."""
|
||||
# Shard using torch.chunk to match all-gather/reduce-scatter.
|
||||
chunks = list(torch.flatten(tensor).chunk(world_size))
|
||||
while len(chunks) < world_size:
|
||||
chunks.append(chunks[0].new_empty(0))
|
||||
|
||||
# Determine number of padding elements.
|
||||
num_to_pad = chunks[0].numel() - chunks[rank].numel()
|
||||
assert num_to_pad >= 0, num_to_pad
|
||||
|
||||
shard = chunks[rank].clone()
|
||||
if num_to_pad > 0:
|
||||
shard = F.pad(shard, [0, num_to_pad])
|
||||
return shard, num_to_pad
|
||||
|
||||
|
||||
def free_storage(data: torch.Tensor) -> None:
|
||||
"""Free underlying storage of a Tensor."""
|
||||
if data.storage().size() > 0:
|
||||
# Since we're modifying the Tensor's Storage directly, make sure the Tensor
|
||||
# is the sole occupant of the Storage.
|
||||
assert data.storage_offset() == 0
|
||||
data.storage().resize_(0)
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def alloc_storage(data: torch.Tensor, size: torch.Size) -> None:
|
||||
"""Allocate storage for a tensor."""
|
||||
if data.storage().size() == size.numel(): # no need to reallocate
|
||||
return
|
||||
assert data.storage().size() == 0
|
||||
data.storage().resize_(size.numel())
|
||||
|
||||
|
||||
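A small sanity sketch of the `free_storage` / `alloc_storage` pair above: the tensor object and its metadata survive, only the underlying Storage is resized, which is how the sharded params release and later re-materialize their memory.

```python
import torch
from colossalai.zero.sharded_model._zero3_utils import alloc_storage, free_storage

t = torch.empty(1024, device='cuda')
free_storage(t)                               # release the memory, keep the tensor object
assert t.storage().size() == 0

alloc_storage(t, size=torch.Size([1024]))     # re-materialize the same amount of storage
assert t.storage().size() == 1024
```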
def cast_trensor_to_fp16(tensor: torch.Tensor) -> torch.Tensor:
|
||||
if tensor.dtype is torch.float32:
|
||||
out = tensor.half()
|
||||
if tensor.is_leaf:
|
||||
out.requires_grad = tensor.requires_grad
|
||||
return out
|
||||
return tensor
|
||||
|
||||
|
||||
def cast_trensor_to_fp32(tensor: torch.Tensor) -> torch.Tensor:
|
||||
if tensor.dtype is torch.float16:
|
||||
out = tensor.float()
|
||||
if tensor.is_leaf:
|
||||
out.requires_grad = tensor.requires_grad
|
||||
return out
|
||||
return tensor
|
||||
|
||||
|
||||
def apply_to_tensors(x: Any, fn: Callable):
|
||||
if torch.is_tensor(x):
|
||||
return fn(x)
|
||||
elif isinstance(x, list):
|
||||
return [apply_to_tensors(t, fn) for t in x]
|
||||
elif isinstance(x, tuple):
|
||||
return tuple(apply_to_tensors(t, fn) for t in x)
|
||||
elif isinstance(x, dict):
|
||||
return {key: apply_to_tensors(val, fn) for key, val in x.items()}
|
||||
else:
|
||||
return x
|
||||
|
||||
|
||||
def cast_float_arguments(fn: Callable, *args: Any, **kwargs: Any) -> Tuple[Any, Any]:
|
||||
return apply_to_tensors(args, fn), apply_to_tensors(kwargs, fn)
|
||||
|
||||
|
||||
def chunk_and_pad(tensor: torch.Tensor, num_chunks: int) -> List[torch.Tensor]:
|
||||
"""Chunk a given Tensor into num_chunks parts and add any necessary padding."""
|
||||
chunks = list(torch.flatten(tensor).chunk(num_chunks))
|
||||
# torch.chunk may return fewer than num_chunks chunks, pad accordingly.
|
||||
num_pad_for_partial_chunk = chunks[0].numel() - chunks[-1].numel()
|
||||
if num_pad_for_partial_chunk > 0:
|
||||
chunks[-1] = F.pad(chunks[-1], [0, num_pad_for_partial_chunk])
|
||||
if len(chunks) < num_chunks:
|
||||
chunks.extend([torch.zeros_like(chunks[0]) for _ in range(num_chunks - len(chunks))])
|
||||
return chunks
|
||||
|
||||
|
||||
def assert_in_engine(cond: Any, s: Any) -> None:
|
||||
"""Used in backward context to make sure error is printed."""
|
||||
if not cond:
|
||||
print(s)
|
||||
raise AssertionError
|
||||
|
||||
|
||||
def replace_state_dict_prefix(
|
||||
state_dict: Union[Dict[str, torch.Tensor], "OrderedDict[str, torch.Tensor]"], old_prefix: str, new_prefix: str
|
||||
) -> None:
|
||||
"""
|
||||
Replace all keys that match a given old_prefix with a new_prefix (in-place).
|
||||
|
||||
Usage::
|
||||
|
||||
state_dict = {"layer.xyz": torch.tensor(1)}
|
||||
replace_state_dict_prefix(state_dict, "layer.", "module.layer.")
|
||||
assert state_dict == {"module.layer.xyz": torch.tensor(1)}
|
||||
"""
|
||||
if old_prefix == new_prefix:
|
||||
raise ValueError("old_prefix and new_prefix must be distinct")
|
||||
for key in list(state_dict.keys()):
|
||||
if not key.startswith(old_prefix):
|
||||
continue
|
||||
new_key = new_prefix + key[len(old_prefix):]
|
||||
state_dict[new_key] = state_dict[key]
|
||||
del state_dict[key]
|
|
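`get_shard` above needs no process group of its own, since rank and world size are plain integers, so its flatten/chunk/pad behaviour can be checked directly:

```python
import torch
from colossalai.zero.sharded_model._zero3_utils import get_shard

t = torch.arange(10, dtype=torch.float32)          # 10 elements, 4 shards -> chunks of 3
shard, pad = get_shard(t, rank=3, world_size=4)    # last rank holds 1 element + 2 pad zeros
assert shard.numel() == 3 and pad == 2
assert torch.equal(shard, torch.tensor([9., 0., 0.]))
```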
@ -0,0 +1,385 @@
|
|||
import os
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
import torch.nn as nn
|
||||
from torch.distributed import ProcessGroup
|
||||
from torch.nn.parameter import Parameter
|
||||
|
||||
from ._zero3_utils import alloc_storage, free_storage, get_shard
|
||||
|
||||
# TODO: Remove the toggle-enable_nccl_base_collectives in the future
|
||||
if os.getenv("ENABLE_NCCL_BASE_COLLECTIVES", "1") == "0":
|
||||
enable_nccl_base_collectives = False
|
||||
else:
|
||||
enable_nccl_base_collectives = True
|
||||
|
||||
# TODO: add flatten params
|
||||
|
||||
|
||||
class Zero3ParameterManager:
|
||||
def __init__(self,
|
||||
module: nn.Module,
|
||||
process_group: Optional[ProcessGroup],
|
||||
mixed_precision: bool = False,
|
||||
flatten_parameters: bool = True,
|
||||
compute_dtype: Optional[torch.dtype] = None,
|
||||
compute_device: Optional[torch.device] = None,
|
||||
offload_config: Optional[dict] = None
|
||||
) -> None:
|
||||
"""Manage parameter shards. We manage several attributes on each Parameter instance:
|
||||
``zero_is_sharded``: ``True`` if the Parameter is sharded or ``False``
|
||||
if the Parameter is intentionally not sharded (in which case we
|
||||
will all-reduce grads for this param).
|
||||
``zero_orig_size``: the size of the original Parameter (before sharding)
|
||||
``zero_shard_padding``: the padding size. All paddings are right padding.
|
||||
``zero_fp32_shard``: a single shard of the parameters in full precision
|
||||
(typically FP32, but this is dependent on the dtype of the model
|
||||
as it's passed in by the user). This can be on CPU or GPU
|
||||
depending on the value of *``offload_config``*.
|
||||
``zero_fp16_shard``: This will be a single shard of the parameters in FP16, used for all-gather.
|
||||
This can be in FP16 or FP32 depending on the value of *``compute_dtype``* and
|
||||
if params are offloaded to CPU.
|
||||
``zero_full_param_padded``: the full weight (padded to be evenly
|
||||
divisible by ``world_size``), used for computation in the
|
||||
forward and backward pass. This will be resized in place and
|
||||
only materialized (via all-gather) as needed.
|
||||
``zero_cpu_grad``: the gradient saved on CPU. It's set only when using CPU offload.
|
||||
|
||||
:param module: original module
|
||||
:type module: nn.Module
|
||||
:param process_group: typically data parallel process group, defaults to None
|
||||
:type process_group: Optional[ProcessGroup], optional
|
||||
:param mixed_precision: whether to use mixed precision mode, defaults to False
|
||||
:type mixed_precision: bool, optional
|
||||
:param flatten_parameters: whether to flatten parameters (currently unused), defaults to True
|
||||
:type flatten_parameters: bool, optional
|
||||
:param compute_dtype: the dtype of parameters when computing, defaults to None
|
||||
:type compute_dtype: Optional[torch.dtype], optional
|
||||
:param compute_device: the device of parameters when computing, defaults to None
|
||||
:type compute_device: Optional[torch.device], optional
|
||||
:param offload_config: offload config, defaults to None
|
||||
:type offload_config: Optional[dict], optional
|
||||
"""
|
||||
self.process_group = process_group
|
||||
self.shard_idx = process_group.rank()
|
||||
self.num_shards = process_group.size()
|
||||
self.mixed_precision = mixed_precision
|
||||
self.compute_dtype = compute_dtype
|
||||
self.compute_device = compute_device
|
||||
self.offload_config = offload_config
|
||||
|
||||
self._cpu_offload = offload_config.get('device', None) == 'cpu' if offload_config else False
|
||||
|
||||
self.params: List[Parameter] = []
|
||||
for param in module.parameters():
|
||||
if not hasattr(param, 'zero_is_sharded'):
|
||||
self.params.append(param)
|
||||
|
||||
self._has_params = len(self.params) > 0
|
||||
self._has_sharded_params = False
|
||||
# Flag to indicate if the full params are gathered.
|
||||
self.has_full_params: bool = False
|
||||
|
||||
self._shard_params()
|
||||
# Possibly unnecessary; kept around to prevent bugs
|
||||
# self.delete_fp32_shards()
|
||||
|
||||
self._streams: Dict[str, torch.cuda.Stream] = {}
|
||||
|
||||
def _shard_params(self) -> None:
|
||||
for p in self.params:
|
||||
assert not hasattr(p, "zero_is_sharded")
|
||||
assert p.is_floating_point()
|
||||
if self.mixed_precision:
|
||||
assert p.dtype == torch.float32
|
||||
|
||||
# If world_size is 1, then we all-reduce grads instead of sharding.
|
||||
p.zero_is_sharded = self.num_shards > 1
|
||||
p.zero_orig_size = p.data.size()
|
||||
|
||||
if not p.zero_is_sharded:
|
||||
p.zero_shard_padding = 0
|
||||
continue
|
||||
|
||||
# Replace p.data with the relevant shard.
|
||||
orig_data = p.data
|
||||
p.data, p.zero_shard_padding = get_shard(p.data, self.shard_idx, self.num_shards)
|
||||
free_storage(orig_data)
|
||||
|
||||
@torch.no_grad()
|
||||
def reset_param_attr(self, p: Parameter, training: bool) -> None:
|
||||
"""This should be called by ``ZeroRedundancyLevel3Model._lazy_init()``
|
||||
"""
|
||||
assert hasattr(p, 'zero_is_sharded') and hasattr(p, 'zero_orig_size')
|
||||
if hasattr(p, 'zero_fp32_shard'):
|
||||
return
|
||||
|
||||
# A single shard of the parameters in full precision.
|
||||
p.zero_fp32_shard = p.data
|
||||
|
||||
if self.mixed_precision:
|
||||
assert p.zero_fp32_shard.dtype == torch.float32
|
||||
|
||||
if self._cpu_offload:
|
||||
assert p.zero_fp32_shard.device == torch.device('cpu')
|
||||
# If we plan to keep the FP32 parameters on CPU, then pinning
|
||||
# memory allows us to later use non-blocking transfers when moving
|
||||
# the FP32 param shard to compute_device.
|
||||
p.zero_fp32_shard = p.zero_fp32_shard.pin_memory()
|
||||
p.data = p.zero_fp32_shard
|
||||
|
||||
if self.mixed_precision or self._cpu_offload:
|
||||
|
||||
# In mixed precision mode, we maintain a reduced precision
|
||||
# (typically FP16) parameter shard on compute_device for performing
|
||||
# the computation in the forward/backward pass. We resize the
|
||||
# storage to size 0 at init (here) and re-materialize (by copying
|
||||
# from _fp32_shard) as needed. If offloading params to CPU, the
|
||||
# dtype of the fp16 shard will depend on the *`compute_dtype`*.
|
||||
p.zero_fp16_shard = torch.zeros_like(
|
||||
p.zero_fp32_shard, device=self.compute_device, dtype=self.compute_dtype)
|
||||
free_storage(p.zero_fp16_shard)
|
||||
|
||||
if self.mixed_precision:
|
||||
assert p.zero_fp32_shard.dtype == torch.float32
|
||||
|
||||
if not self.mixed_precision and not self._cpu_offload:
|
||||
# use _fp32_shard if you are not using mixed precision or
|
||||
# offloading params and grads to CPU.
|
||||
p.zero_fp16_shard = None
|
||||
|
||||
# We also maintain a full-sized parameter of type self.compute_dtype
|
||||
# (FP16 for mixed_precision or FP32 otherwise). We resize the
|
||||
# storage to size 0 at init (here) and only materialize as needed. The
|
||||
# storage may contain padding elements so that it is evenly divisible by
|
||||
# world_size, although these padding elements will be removed before the
|
||||
# relevant computation.
|
||||
if p.zero_is_sharded:
|
||||
p.zero_full_param_padded = torch.zeros(
|
||||
p.data.numel() * self.num_shards, device=self.compute_device, dtype=self.compute_dtype
|
||||
)
|
||||
free_storage(p.zero_full_param_padded)
|
||||
|
||||
if self._cpu_offload and training:
|
||||
p.zero_cpu_grad = torch.zeros_like(p.data, device='cpu').pin_memory()
|
||||
|
||||
def setup_streams(self, streams):
|
||||
self._streams = streams
|
||||
|
||||
@torch.no_grad()
|
||||
def rebuild_full_params(self, force_full_precision: bool = False) -> Optional[List[Tuple[torch.Tensor, bool]]]:
|
||||
"""
|
||||
Gather all shards of params.
|
||||
|
||||
Note, this is idempotent if full params are already gathered. Callers
|
||||
assume the idempotency. So please keep it that way.
|
||||
|
||||
Args:
|
||||
force_full_precision (bool, Optional): by default params will be gathered
|
||||
in ``compute_dtype`` (e.g., FP16), unless *force_full_precision* is
|
||||
``True``, in which case they will be gathered in full precision
|
||||
(e.g., FP32), possibly in fresh storage. The parameter that's being
|
||||
rebuilt will end up in full precision as well.
|
||||
|
||||
Returns:
|
||||
A list of tuples, where the first element is the full-sized param
|
||||
and the second element is a bool indicating if it's safe for the
|
||||
caller to free the full-sized param. This will be ``None`` if
|
||||
``force_full_precision=False`` and the full params are already gathered.
|
||||
"""
|
||||
# Store tensor and free flag
|
||||
output_tensors: List[Tuple[torch.Tensor, bool]] = []
|
||||
|
||||
def update_p_data(custom_output_tensor: Optional[torch.Tensor] = None) -> None:
|
||||
"""
|
||||
Helper function to update p.data pointer.
|
||||
|
||||
Args:
|
||||
custom_output_tensor (torch.Tensor, Optional): if not None, this
|
||||
tensor contains the data we just gathered.
|
||||
"""
|
||||
if custom_output_tensor is not None:
|
||||
assert p.zero_is_sharded
|
||||
p.data = custom_output_tensor
|
||||
output_tensors.append((p.data, True))
|
||||
elif not p.zero_is_sharded:
|
||||
if (self.mixed_precision or self._cpu_offload) and not force_full_precision:
|
||||
assert p.zero_fp16_shard is not None
|
||||
p.data = p.zero_fp16_shard
|
||||
output_tensors.append((p.data, True))
|
||||
else:
|
||||
# Here p.data == p._fp32_shard, so it's not safe to free.
|
||||
output_tensors.append((p.data, False))
|
||||
else:
|
||||
p.data = p.zero_full_param_padded
|
||||
output_tensors.append((p.data, True))
|
||||
# Trim any padding and reshape to match original size.
|
||||
p.data = p.data[: p.zero_orig_size.numel()].view(p.zero_orig_size)
|
||||
|
||||
if self._has_sharded_params:
|
||||
# self.has_full_params flag can be out of sync if a shared param is
|
||||
# sharded by another ZeroRedundancyLevel3Model instance. An example is that in eval case
|
||||
# with reshard_after_forward=False but the sharing instance has
|
||||
# reshard_after_forward=True. Then, on the second forward, the
|
||||
# other instance can shard the shared param, but this instance
|
||||
# can mistakenly think the full param is already gathered from the
|
||||
# has_full_params flag.
|
||||
#
|
||||
# Therefore, we update the flag accordingly here.
|
||||
self.has_full_params = not any(p.zero_full_param_padded.storage().size() == 0 for p in self.params)
|
||||
|
||||
# Early exit if we already have full params and don't need full precision.
|
||||
if self.has_full_params and not force_full_precision:
|
||||
for p in self.params:
|
||||
update_p_data()
|
||||
return output_tensors
|
||||
|
||||
self.has_full_params = True
|
||||
|
||||
with torch.cuda.stream(self._streams["all_gather"]):
|
||||
if (self.mixed_precision or self._cpu_offload) and not force_full_precision:
|
||||
self.use_fp16_shards()
|
||||
|
||||
if self._cpu_offload and force_full_precision:
|
||||
# If the compute_dtype and storage dtype are the same,
|
||||
# use pinned memory. Otherwise move p.data to the compute
|
||||
# device.
|
||||
if self.params[0].dtype == self.compute_dtype:
|
||||
self.use_fp16_shards()
|
||||
else:
|
||||
for p in self.params:
|
||||
p.data = p.data.to(self.compute_device)
|
||||
|
||||
for p in self.params:
|
||||
if not p.zero_is_sharded: # e.g., when world_size == 1
|
||||
update_p_data()
|
||||
else:
|
||||
# Skip if already built. Only shared param can be rebuilt multiple times.
|
||||
# A corner case is p.zero_orig_size = (1,), which means the shape equality is
|
||||
# not a perfect check. But we assume we don't share a param with shape (1,).
|
||||
# if p.data.shape == p.zero_orig_size and hasattr(p, "zero_is_shared") and p.zero_is_shared:
|
||||
# continue
|
||||
# If self._cpu_offload and force_full_precision, we need to cast
|
||||
# the FP32 CPU param to CUDA for the all-gather.
|
||||
p_data = p.data.to(p.zero_full_param_padded.device, non_blocking=True)
|
||||
|
||||
p_size = p.zero_full_param_padded.size()
|
||||
assert p_size.numel() % self.num_shards == 0
|
||||
if self.mixed_precision and force_full_precision:
|
||||
# Allocate fresh tensor in full precision since we are in
|
||||
# mixed precision and full precision rebuild is asked.
|
||||
output_tensor = p_data.new_zeros(p_size)
|
||||
else:
|
||||
if p.zero_full_param_padded.storage().size() != p_size.numel():
|
||||
# Allocate based on full size from all shards.
|
||||
alloc_storage(p.zero_full_param_padded, size=p_size)
|
||||
output_tensor = p.zero_full_param_padded
|
||||
|
||||
# Fill output_tensor with (p.data for each shard in self.world_size)
|
||||
if hasattr(dist, "_all_gather_base") and enable_nccl_base_collectives:
|
||||
# New version of PyTorch has all_gather_base, which is faster than chunk and then all_gather.
|
||||
dist._all_gather_base(output_tensor, p_data, group=self.process_group)
|
||||
else:
|
||||
chunks = list(output_tensor.chunk(self.num_shards))
|
||||
dist.all_gather(chunks, p_data, group=self.process_group)
|
||||
|
||||
# Set p.data = output_tensor (with padding trimmed)
|
||||
update_p_data(output_tensor)
|
||||
|
||||
if (self.mixed_precision or self._cpu_offload) and not force_full_precision:
|
||||
self.free_fp16_shards([p])
|
||||
|
||||
if self._cpu_offload and (self.params[0].dtype == self.compute_dtype):
|
||||
self.free_fp16_shards([p])
|
||||
|
||||
torch.cuda.current_stream().wait_stream(self._streams["all_gather"])
|
||||
return output_tensors
|
||||
|
||||
@torch.no_grad()
|
||||
def use_full_params(self) -> None:
|
||||
"""
|
||||
Switch p.data pointers to use the full params.
|
||||
|
||||
Note: this assumes full params are already gathered.
|
||||
|
||||
Note: this might be called after full_params is already in use. So please
|
||||
make sure it is idempotent in that case.
|
||||
"""
|
||||
assert self.has_full_params
|
||||
for p in self.params:
|
||||
if not p.zero_is_sharded:
|
||||
if self.mixed_precision or self._cpu_offload:
|
||||
assert p.zero_fp16_shard is not None
|
||||
assert p.zero_fp16_shard.storage().size() != 0
|
||||
p.data = p.zero_fp16_shard
|
||||
else:
|
||||
assert p.zero_full_param_padded.storage().size() != 0, f"{p.zero_orig_size} {id(self)}"
|
||||
p.data = p.zero_full_param_padded[: p.zero_orig_size.numel()].view(p.zero_orig_size)
|
||||
|
||||
@torch.no_grad()
|
||||
def use_fp16_shards(self, params: Optional[List[Parameter]] = None) -> None:
|
||||
"""Cast FP32 param shard to FP16 for a list of params."""
|
||||
if params is None:
|
||||
params = self.params
|
||||
with torch.cuda.stream(self._streams["fp32_to_fp16"]):
|
||||
for p in params:
|
||||
assert p.zero_fp16_shard is not None
|
||||
alloc_storage(p.zero_fp16_shard, size=p.zero_fp32_shard.size())
|
||||
p.zero_fp16_shard.copy_(
|
||||
# If _cpu_offload is True, this will be non-blocking
|
||||
# because _fp32_shard is pinned, otherwise it's a no-op.
|
||||
p.zero_fp32_shard.to(p.zero_fp16_shard.device, non_blocking=True)
|
||||
)
|
||||
p.data = p.zero_fp16_shard
|
||||
torch.cuda.current_stream().wait_stream(self._streams["fp32_to_fp16"])
|
||||
|
||||
@torch.no_grad()
|
||||
def use_fp32_shards(self, params: Optional[List[Parameter]] = None) -> None:
|
||||
"""Use FP32 shard for a list of params."""
|
||||
if params is None:
|
||||
params = self.params
|
||||
for p in params:
|
||||
p.data = p.zero_fp32_shard
|
||||
|
||||
@torch.no_grad()
|
||||
def free_full_params(self, params: Optional[List[Parameter]] = None) -> None:
|
||||
"""Free up storage for full parameters."""
|
||||
if params is None:
|
||||
params = self.params
|
||||
self.has_full_params = False
|
||||
current_stream = torch.cuda.current_stream()
|
||||
for p in params:
|
||||
if not p.zero_is_sharded: # e.g., world_size == 1
|
||||
if self.mixed_precision or self._cpu_offload:
|
||||
self.free_fp16_shards([p])
|
||||
continue
|
||||
# Don't let PyTorch reuse this memory until all work in the current
|
||||
# stream is complete.
|
||||
p.zero_full_param_padded.record_stream(current_stream)
|
||||
# There may be external references to the Tensor Storage that we
|
||||
# can't modify, such as references that are created by
|
||||
# ctx.save_for_backward in the forward pass. Thus when we
|
||||
# unshard parameters, we should reuse the original Tensor
|
||||
# Storage object and unshard it in-place. For now, just resize
|
||||
# the Storage to 0 to save memory.
|
||||
free_storage(p.zero_full_param_padded)
|
||||
|
||||
@torch.no_grad()
|
||||
def free_fp16_shards(self, params: Optional[List[Parameter]] = None) -> None:
|
||||
"""Free storage for FP16 shards for a list of params."""
|
||||
if params is None:
|
||||
params = self.params
|
||||
current_stream = torch.cuda.current_stream()
|
||||
for p in params:
|
||||
if p.zero_fp16_shard is not None:
|
||||
# zero_fp16_shard is allocated in "fp32_to_fp16" stream, so we can't
|
||||
# free it until the work in the current stream completes.
|
||||
p.zero_fp16_shard.record_stream(current_stream)
|
||||
free_storage(p.zero_fp16_shard)
|
||||
|
||||
def delete_fp32_shards(self) -> None:
|
||||
for p in self.params:
|
||||
if hasattr(p, 'zero_fp32_shard'):
|
||||
del p.zero_fp32_shard # reset _init_param_attr
|
|
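A hedged lifecycle sketch for `Zero3ParameterManager`, pieced together from the methods above; the module path, the use of the default `torch.distributed` group, and calling `reset_param_attr` directly (normally done by `ZeroRedundancyLevel3Model._lazy_init()`) are assumptions:

```python
import torch
import torch.distributed as dist
import torch.nn as nn
from colossalai.zero.sharded_model.param_manager import Zero3ParameterManager  # path assumed

model = nn.Linear(1024, 1024).cuda().float()
manager = Zero3ParameterManager(model,
                                process_group=dist.group.WORLD,
                                mixed_precision=True,
                                compute_dtype=torch.float16,
                                compute_device=torch.device('cuda'))

# Streams named as the manager expects ("all_gather", "fp32_to_fp16").
manager.setup_streams({'all_gather': torch.cuda.Stream(),
                       'fp32_to_fp16': torch.cuda.Stream()})
for p in manager.params:
    manager.reset_param_attr(p, training=True)   # attach zero_fp32_shard / zero_fp16_shard, etc.

# One compute window: materialize full fp16 params, compute, then release them.
manager.rebuild_full_params()
manager.use_full_params()
# ... forward / backward would run here ...
manager.free_full_params()
manager.use_fp32_shards()    # point p.data back at the fp32 shards for the optimizer step
```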
@ -0,0 +1,204 @@
|
|||
# Copyright (c) Facebook, Inc. and its affiliates.
|
||||
#
|
||||
# This source code is licensed under the BSD license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import functools
|
||||
import os
|
||||
from typing import Callable, Dict, List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from torch import Tensor
|
||||
from torch.distributed import ProcessGroup
|
||||
|
||||
# TODO: Remove the toggle-enable_nccl_base_collectives when github open issue #801 is resolved.
|
||||
if os.getenv("ENABLE_NCCL_BASE_COLLECTIVES", "1") == "0":
|
||||
enable_nccl_base_collectives = False
|
||||
else:
|
||||
enable_nccl_base_collectives = True
|
||||
|
||||
|
||||
class Bucket:
|
||||
def __init__(self, shard_size: int, dtype: torch.dtype, device: torch.device, group: ProcessGroup):
|
||||
self.buffer = torch.zeros((group.size(), shard_size), dtype=dtype, device=device)
|
||||
self.group = group
|
||||
self.offset = 0
|
||||
self.callbacks: List[Callable] = []
|
||||
self.output_shard = torch.zeros_like(self.buffer[0])
|
||||
|
||||
def flush(self) -> None:
|
||||
"""Flush content of the bucket."""
|
||||
if self.offset == 0:
|
||||
assert len(self.callbacks) == 0
|
||||
return
|
||||
# reduce-scatter bucket
|
||||
if hasattr(dist, "_reduce_scatter_base") and enable_nccl_base_collectives:
|
||||
dist._reduce_scatter_base(
|
||||
self.output_shard[: self.offset], self.buffer[:, : self.offset].contiguous(), group=self.group
|
||||
)
|
||||
else:
|
||||
dist.reduce_scatter(
|
||||
self.output_shard[: self.offset], list(self.buffer[:, : self.offset].unbind(0)), group=self.group
|
||||
)
|
||||
# execute post-reduction callbacks
|
||||
for callback_fn in self.callbacks:
|
||||
callback_fn()
|
||||
# reuse input bucket but allocate a fresh output shard
|
||||
self.buffer[:, : self.offset].zero_()
|
||||
self.offset = 0
|
||||
self.callbacks.clear()
|
||||
self.output_shard = torch.zeros_like(self.buffer[0])
|
||||
|
||||
def alloc(self) -> None:
|
||||
"""Setup the buffers if they are not allocated.
|
||||
|
||||
Using ``alloc`` and ``free``, we can ensure that the bucket
|
||||
buffers are only allocated during the backward pass, hence saving more
|
||||
memory for other parts of the training process, such as the forward pass
|
||||
for activation memory.
|
||||
"""
|
||||
for tensor in [self.buffer, self.output_shard]:
|
||||
if tensor.storage().size() == 0:
|
||||
tensor.storage().resize_(tensor.size().numel())
|
||||
|
||||
def free(self) -> None:
|
||||
"""Tear down the bucket by freeing the memory"""
|
||||
assert self.offset == 0 and self.callbacks == [], "Incorrect call of free"
|
||||
for tensor in [self.buffer, self.output_shard]:
|
||||
tensor.storage().resize_(0)
|
||||
|
||||
def append(self, tensor_list: List[Tensor], callback_fn: Callable):
|
||||
# copy data from input_list into bucket
|
||||
tensor_size = tensor_list[0].numel()
|
||||
stacked_input = torch.stack(tensor_list).view(self.group.size(), tensor_size)
|
||||
offset = self.offset
|
||||
self.buffer[:, offset: offset + tensor_size].copy_(stacked_input)
|
||||
self.offset += tensor_size
|
||||
|
||||
# callback will be given the reduced result
|
||||
if callback_fn is not None:
|
||||
result_view = self.output_shard[offset: offset + tensor_size].view_as(tensor_list[0])
|
||||
self.callbacks.append(functools.partial(callback_fn, result_view))
|
||||
|
||||
|
||||
class ReduceScatterBucketer:
|
||||
"""
|
||||
Helper for bucketing multiple reduce-scatter operations on small tensors
|
||||
into larger reduce-scatter ops to improve communication efficiency.
|
||||
|
||||
Usage::
|
||||
|
||||
bucketer = ReduceScatterBucketer()
|
||||
bucketer.reduce_scatter_async(
|
||||
small_tensors, callback_fn=lambda result: print("small")
|
||||
)
|
||||
bucketer.reduce_scatter_async(
|
||||
big_tensors, callback_fn=lambda result: print("big")
|
||||
)
|
||||
bucketer.reduce_scatter_async(
|
||||
more_small_tensors, callback_fn=lambda result: print("small2")
|
||||
)
|
||||
bucketer.flush() # callbacks only guaranteed to be called after flush()
|
||||
# Example output (note that it is out of order, due to bucketing):
|
||||
# big
|
||||
# small
|
||||
# small2
|
||||
|
||||
Args:
|
||||
bucket_size_mb (int, Optional): bucket size for communicating. Buckets
|
||||
are sub-divided based on world_size. Values <= 0 disable bucketing.
|
||||
"""
|
||||
|
||||
def __init__(self, bucket_size_mb: int = 25):
|
||||
self.bucket_size_mb = bucket_size_mb
|
||||
self.buckets: Dict[Tuple[torch.dtype, torch.device, ProcessGroup], Bucket] = {}
|
||||
|
||||
@torch.no_grad()
|
||||
def reduce_scatter_async(
|
||||
self,
|
||||
input_list: List[Tensor],
|
||||
group: ProcessGroup,
|
||||
callback_fn: Optional[Callable] = None,
|
||||
) -> None:
|
||||
"""
|
||||
Reduce-scatter a list of tensors asynchronously, so smaller reductions
|
||||
can be bucketed together. The given callback (``callback_fn``) will be
|
||||
called with the reduced result at some later time. Call ``flush()`` to
|
||||
force all queued ops and callbacks to be executed.
|
||||
|
||||
Note that large inputs will be reduced immediately, and this function
|
||||
may also flush the relevant bucket to make room for ``input_list``.
|
||||
|
||||
Args:
|
||||
input_list (List[Tensor]): list of tensors to reduce-scatter. List
|
||||
should contain ``group.size()`` tensors and each tensor should
|
||||
have identical shape, dtype and device.
|
||||
group (ProcessGroup): process group for reduction
|
||||
callback_fn (Callable, Optional): callback function to call after
|
||||
the reduction executes. Function will be called with a single
|
||||
argument corresponding to the reduced result.
|
||||
"""
|
||||
world_size = group.size()
|
||||
|
||||
assert (
|
||||
len(input_list) == world_size
|
||||
), f"reduce_scatter received {len(input_list)} inputs, expected group.size() ({world_size})"
|
||||
|
||||
first_input = input_list[0]
|
||||
first_input_size = first_input.numel()
|
||||
|
||||
bucket_shard_size = self._get_shard_size(first_input.element_size(), world_size)
|
||||
if first_input_size > bucket_shard_size:
|
||||
# TODO: investigate how to avoid using torch.cat (because it seems to be slow for CPU tensors)
|
||||
# input is too big to fit in the bucket, reduce-scatter directly
|
||||
output = torch.zeros_like(input_list[0])
|
||||
if hasattr(dist, "_reduce_scatter_base") and enable_nccl_base_collectives:
|
||||
input_flattened = torch.cat(input_list)
|
||||
dist._reduce_scatter_base(output, input_flattened, group=group)
|
||||
else:
|
||||
# fallback
|
||||
dist.reduce_scatter(output, input_list, group=group)
|
||||
if callback_fn is not None:
|
||||
callback_fn(output)
|
||||
return
|
||||
|
||||
bucket = self._get_bucket(first_input, group)
|
||||
if first_input_size > bucket.buffer.size(1) - bucket.offset:
|
||||
# not enough space remaining in bucket, flush it now
|
||||
bucket.flush()
|
||||
bucket.append(input_list, callback_fn)
|
||||
|
||||
@torch.no_grad()
|
||||
def flush(self) -> None:
|
||||
"""Reduce-scatter any partial buckets."""
|
||||
for bucket in self.buckets.values():
|
||||
bucket.flush()
|
||||
|
||||
@torch.no_grad()
|
||||
def free(self) -> None:
|
||||
"""Free buffers from all buckets."""
|
||||
for bucket in self.buckets.values():
|
||||
bucket.free()
|
||||
|
||||
@functools.lru_cache()
|
||||
def _get_shard_size(self, element_size: int, num_shards: int) -> int:
|
||||
if self.bucket_size_mb <= 0: # Values <= 0 disable bucketing.
|
||||
return 0
|
||||
MB = 1024 * 1024
|
||||
bucket_size = self.bucket_size_mb * MB / element_size
|
||||
return int(bucket_size // num_shards)
|
||||
|
||||
def _get_bucket(self, tensor: Tensor, group: ProcessGroup) -> Bucket:
|
||||
# TODO (Min): the `group` used here in the key is the object hash, not the content
|
||||
# hash. That means if FSDP instances are initialized with different process groups,
|
||||
# even when the group members are in fact the same, we end up creating different
|
||||
# buckets here.
|
||||
key = (tensor.dtype, tensor.device, group)
|
||||
if key not in self.buckets:
|
||||
# buckets are divided into world_size pieces, bucket.data shaped (world_size, shard_size)
|
||||
world_size = group.size()
|
||||
shard_size = self._get_shard_size(tensor.element_size(), world_size)
|
||||
self.buckets[key] = Bucket(shard_size, tensor.dtype, tensor.device, group)
|
||||
self.buckets[key].alloc()
|
||||
return self.buckets[key]
|
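For intuition on the bucket sizing above, the per-rank shard length follows from ``bucket_size_mb``, the element size and the shard count; a small standalone sketch of the same arithmetic as ``_get_shard_size``, with made-up values (fp16 elements, 8 ranks):

# Sketch of the shard-size arithmetic; the numbers below are illustrative.
bucket_size_mb = 25
element_size = 2              # bytes per fp16 element
num_shards = 8                # data-parallel world size

MB = 1024 * 1024
bucket_numel = bucket_size_mb * MB / element_size   # elements the whole bucket may hold
shard_size = int(bucket_numel // num_shards)        # per-rank shard length
print(shard_size)   # 1638400; inputs larger than this are reduce-scattered directly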
File diff suppressed because it is too large
|
@ -0,0 +1,63 @@
|
|||
|
||||
import contextlib
|
||||
import copy
|
||||
import functools
|
||||
import os
|
||||
import traceback
|
||||
from collections import OrderedDict
|
||||
from enum import Enum, auto
|
||||
from typing import (Any, Callable, Dict, Generator, List, NamedTuple, Optional,
|
||||
Set, Union)
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
import torch.nn as nn
|
||||
from colossalai.context.parallel_mode import ParallelMode
|
||||
from colossalai.core import global_context as gpc
|
||||
from colossalai.logging import get_dist_logger
|
||||
from colossalai.utils import get_current_device
|
||||
from torch.distributed import ProcessGroup
|
||||
from colossalai.engine.ophooks import register_ophooks_recursively, BaseOpHook, ShardParamHook
|
||||
from colossalai.zero.shard_param import ShardParam
|
||||
|
||||
class ShardedModelV2(nn.Module):
|
||||
def __init__(self,
|
||||
module: nn.Module,
|
||||
process_group: Optional[ProcessGroup] = None,
|
||||
reduce_scatter_process_group: Optional[ProcessGroup] = None
|
||||
):
|
||||
r"""
|
||||
A demo implementation that reconstructs the ZeRO-1 sharded model.
|
||||
Optimizer states are not handled yet.
|
||||
"""
|
||||
super().__init__()
|
||||
self.logger = get_dist_logger()
|
||||
|
||||
self.process_group = process_group or gpc.get_group(ParallelMode.DATA)
|
||||
self.reduce_scatter_process_group = reduce_scatter_process_group or self.process_group
|
||||
self.world_size = dist.get_world_size(self.process_group)
|
||||
self.rank = dist.get_rank(self.process_group)
|
||||
|
||||
# The module has to be placed on GPU
|
||||
self.module = module.cuda()
|
||||
|
||||
# Shard the parameters at first
|
||||
for _, param in self.module.named_parameters():
|
||||
param.ca_attr = ShardParam(param)
|
||||
param.ca_attr.shard()
|
||||
|
||||
# Register hooks
|
||||
register_ophooks_recursively(self.module, [ShardParamHook()])
|
||||
|
||||
def forward(self, *args: Any, **kwargs: Any) -> torch.Tensor:
|
||||
outputs = self.module(*args, **kwargs)
|
||||
return outputs
|
||||
|
||||
|
||||
def backward(self, loss):
|
||||
if getattr(self, 'loss_scaler', None) is not None:
|
||||
self.loss_scaler.backward(loss)
|
||||
else:
|
||||
loss.backward()
|
||||
|
||||
|
|
@ -0,0 +1,3 @@
|
|||
from .sharded_optim import ShardedOptimizer
|
||||
|
||||
__all__ = ['ShardedOptimizer']
|
|
@ -0,0 +1,288 @@
|
|||
import math
|
||||
import torch
|
||||
from torch._six import inf
|
||||
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
|
||||
from colossalai.core import global_context as gpc
|
||||
from colossalai.context import ParallelMode
|
||||
from colossalai.utils import is_model_parallel_parameter
|
||||
import torch.distributed as dist
|
||||
|
||||
|
||||
def move_tensor(input_, device):
|
||||
assert device in ['cpu', 'gpu']
|
||||
|
||||
if isinstance(input_, (list, tuple)):
|
||||
for tensor in input_:
|
||||
tensor.data = tensor.data.cpu(
|
||||
) if device == 'cpu' else tensor.data.cuda()
|
||||
elif torch.is_tensor(input_):
|
||||
input_.data = input_.data.cpu(
|
||||
) if device == 'cpu' else input_.data.cuda()
|
||||
else:
|
||||
raise TypeError(
|
||||
f"Expected argument 'input_' to be torch.Tensor, list or tuple, but got {type(input_)} "
|
||||
)
|
||||
|
||||
|
||||
def flatten(input_):
|
||||
return _flatten_dense_tensors(input_)
|
||||
|
||||
|
||||
def unflatten(flat, tensors):
|
||||
return _unflatten_dense_tensors(flat, tensors)
|
||||
|
||||
|
||||
def count_numel(tensor_list):
|
||||
res = 0
|
||||
for tensor in tensor_list:
|
||||
res += tensor.numel()
|
||||
return res
|
||||
|
||||
|
||||
def calculate_padding(numel, unit_size):
|
||||
remainder = numel % unit_size
|
||||
return unit_size - remainder if remainder else 0
|
||||
|
||||
|
||||
def shuffle_by_round_robin(tensor_list, num_partitions):
|
||||
partitions = dict()
|
||||
|
||||
for tensor_idx, tensor in enumerate(tensor_list):
|
||||
partition_to_go = tensor_idx % num_partitions
|
||||
if partition_to_go not in partitions:
|
||||
partitions[partition_to_go] = []
|
||||
partitions[partition_to_go].append(dict(tensor=tensor,
|
||||
index=tensor_idx))
|
||||
|
||||
partitions_count = len(partitions)
|
||||
new_tensor_list = []
|
||||
tensor_index_mapping = dict()
|
||||
|
||||
for partition_id in range(partitions_count):
|
||||
partition_tensors = partitions[partition_id]
|
||||
for item in partition_tensors:
|
||||
tensor_index_mapping[item['index']] = len(new_tensor_list)
|
||||
new_tensor_list.append(item['tensor'])
|
||||
|
||||
return new_tensor_list, tensor_index_mapping
|
||||
|
||||
|
||||
# create a flat tensor aligned at the alignment boundary
|
||||
def flatten_dense_tensors_with_padding(tensor_list, unit_size):
|
||||
num_elements = count_numel(tensor_list)
|
||||
padding = calculate_padding(num_elements, unit_size=unit_size)
|
||||
|
||||
if padding > 0:
|
||||
pad_tensor = torch.zeros(padding,
|
||||
device=tensor_list[0].device,
|
||||
dtype=tensor_list[0].dtype)
|
||||
padded_tensor_list = tensor_list + [pad_tensor]
|
||||
else:
|
||||
padded_tensor_list = tensor_list
|
||||
|
||||
return flatten(padded_tensor_list)
|
||||
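To illustrate the alignment logic, a short sketch using the helpers above; the unit size of 8 is an arbitrary example value:

# Sketch: pad a tensor list so the flattened buffer is a multiple of `unit_size` long.
import torch

tensors = [torch.randn(5), torch.randn(6)]     # 11 elements in total
unit_size = 8                                  # example alignment boundary
flat = flatten_dense_tensors_with_padding(tensors, unit_size)
assert flat.numel() == 16                      # 11 + calculate_padding(11, 8) == 11 + 5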
|
||||
|
||||
def is_nccl_aligned(tensor):
|
||||
return tensor.data_ptr() % 4 == 0
|
||||
|
||||
def get_grad_accumulate_object(tensor):
|
||||
"""
|
||||
Return the AccumulateGrad of the input tensor
|
||||
"""
|
||||
|
||||
# grad_fn reference:
|
||||
# https://discuss.pytorch.org/t/in-the-grad-fn-i-find-a-next-functions-but-i-dont-understand-the-meaning-of-the-attribute/24463
|
||||
# expand_as reference: https://pytorch.org/docs/stable/generated/torch.Tensor.expand.html#torch.Tensor.expand
|
||||
#
|
||||
# `next_functions` will return the backward graph where
|
||||
# the first element is the AccumulateGrad of the leaf nodes.
|
||||
# we want to get the AccumulateGrad of the input tensor instead of the leaf
|
||||
# node in the whole computation graph.
|
||||
# Therefore, we call expand_as to create a dummy graph
|
||||
# where tensor_tmp and tensor indeed point to the same object.
|
||||
# You can check this by print(tensor.data_ptr() == tensor_tmp.data_ptr())
|
||||
tensor_tmp = tensor.expand_as(tensor)
|
||||
grad_acc_obj = tensor_tmp.grad_fn.next_functions[0][0]
|
||||
return grad_acc_obj
|
||||
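A small sketch of why the ``expand_as`` trick matters: the hook registered on the returned AccumulateGrad node fires when the gradient of that exact leaf tensor is accumulated (``w`` and the printed message are purely illustrative):

# Sketch: attach a hook to the AccumulateGrad node of a leaf tensor.
import torch

w = torch.randn(3, requires_grad=True)
acc = get_grad_accumulate_object(w)          # AccumulateGrad node of `w`
acc.register_hook(lambda *args: print("grad accumulated for w"))
(w * 2).sum().backward()                     # the hook fires once w.grad is written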
|
||||
|
||||
def split_half_float_double(tensor_list):
|
||||
dtypes = [
|
||||
"torch.cuda.HalfTensor", "torch.cuda.FloatTensor",
|
||||
"torch.cuda.DoubleTensor", "torch.cuda.BFloat16Tensor"
|
||||
]
|
||||
buckets = []
|
||||
for i, dtype in enumerate(dtypes):
|
||||
bucket = [t for t in tensor_list if t.type() == dtype]
|
||||
if bucket:
|
||||
buckets.append(bucket)
|
||||
return buckets
|
||||
|
||||
|
||||
def reduce_tensor(tensor,
|
||||
dtype,
|
||||
dst_rank=None,
|
||||
parallel_mode=ParallelMode.DATA):
|
||||
"""
|
||||
Reduce the tensor in the data parallel process group
|
||||
|
||||
:param tensor: A tensor object to reduce/all-reduce
|
||||
:param dtype: The data type used in communication
|
||||
:param dst_rank: The destination rank for reduce. If dst_rank is None,
|
||||
all-reduce will be used instead of reduce. Default is None.
|
||||
|
||||
:type tensor: torch.Tensor
|
||||
:type dtype: torch.dtype
|
||||
:type dst_rank: int, optional
|
||||
"""
|
||||
|
||||
# cast the data to specified dtype for reduce/all-reduce
|
||||
if tensor.dtype != dtype:
|
||||
tensor_to_reduce = tensor.to(dtype)
|
||||
else:
|
||||
tensor_to_reduce = tensor
|
||||
|
||||
world_size = gpc.get_world_size(parallel_mode)
|
||||
group = gpc.get_group(parallel_mode)
|
||||
tensor_to_reduce.div_(world_size)
|
||||
|
||||
# if rank is None, all reduce will be used
|
||||
# else, reduce is used
|
||||
use_all_reduce = dst_rank is None
|
||||
|
||||
if use_all_reduce:
|
||||
dist.all_reduce(tensor_to_reduce, group=group)
|
||||
else:
|
||||
ranks_in_group = gpc.get_ranks_in_group(parallel_mode)
|
||||
global_rank = ranks_in_group[dst_rank]
|
||||
dist.reduce(tensor=tensor_to_reduce, dst=global_rank, group=group)
|
||||
|
||||
# recover the original dtype
|
||||
if tensor.dtype != dtype and tensor is not tensor_to_reduce:
|
||||
local_rank = gpc.get_local_rank(parallel_mode)
|
||||
if use_all_reduce or dst_rank == local_rank:
|
||||
tensor.copy_(tensor_to_reduce)
|
||||
return tensor
|
||||
|
||||
def has_inf_or_nan(tensor):
|
||||
try:
|
||||
# if tensor is half, the .float() incurs an additional deep copy, but it's necessary if
|
||||
# Pytorch's .sum() creates a one-element tensor of the same type as tensor
|
||||
# (which is true for some recent version of pytorch).
|
||||
tensor_sum = float(tensor.float().sum())
|
||||
# More efficient version that can be used if .sum() returns a Python scalar
|
||||
# tensor_sum = float(tensor.sum())
|
||||
except RuntimeError as instance:
|
||||
# We want to check if inst is actually an overflow exception.
|
||||
# RuntimeError could come from a different error.
|
||||
# If so, we still want the exception to propagate.
|
||||
if "value cannot be converted" not in instance.args[0]:
|
||||
raise
|
||||
return True
|
||||
else:
|
||||
if tensor_sum == float('inf') or tensor_sum == -float(
|
||||
'inf') or tensor_sum != tensor_sum:
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def release_param_grad(tensor_list):
|
||||
for tensor in tensor_list:
|
||||
tensor.grad = None
|
||||
|
||||
|
||||
def calculate_global_norm_from_list(norm_list):
|
||||
""" Compute total from a list of norms
|
||||
"""
|
||||
total_norm = 0.0
|
||||
for norm in norm_list:
|
||||
total_norm += norm**2.0
|
||||
return math.sqrt(total_norm)
|
||||
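The per-group norms combine as the square root of the sum of squares, for example:

# Sketch: combining per-group gradient norms into one global norm.
norm_groups = [3.0, 4.0]
assert calculate_global_norm_from_list(norm_groups) == 5.0   # sqrt(3^2 + 4^2)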
|
||||
|
||||
def compute_norm(gradients,
|
||||
params,
|
||||
dp_group,
|
||||
mp_group,
|
||||
norm_type=2):
|
||||
"""Clips gradient norm of an iterable of parameters.
|
||||
This is adapted from torch.nn.utils.clip_grad.clip_grad_norm_ and
|
||||
added functionality to handle model parallel parameters. Note that
|
||||
this function only computes the norm; the gradients are not modified.
|
||||
Arguments:
|
||||
gradients (Iterable[Tensor]): gradients whose combined norm is computed
params (Iterable[Tensor]): parameters corresponding to the gradients, used
    to skip parameters replicated across model parallel ranks
norm_type (float or int): type of the used p-norm. Can be ``'inf'`` for
    infinity norm.
|
||||
Returns:
|
||||
Total norm of the gradients (viewed as a single vector).
|
||||
"""
|
||||
|
||||
if mp_group is None:
|
||||
mp_rank = 0
|
||||
else:
|
||||
mp_rank = dist.get_rank(mp_group)
|
||||
|
||||
norm_type = float(norm_type)
|
||||
if norm_type == inf:
|
||||
total_norm = max(g.data.abs().max() for g in gradients)
|
||||
total_norm_cuda = torch.cuda.FloatTensor([float(total_norm)])
|
||||
dist.all_reduce(total_norm_cuda,
|
||||
op=torch.distributed.ReduceOp.MAX,
|
||||
group=dp_group)
|
||||
|
||||
# Take max across all GPUs.
|
||||
if mp_group is not None:
|
||||
dist.all_reduce(tensor=total_norm_cuda,
|
||||
op=torch.distributed.ReduceOp.MAX)
|
||||
total_norm = total_norm_cuda[0].item()
|
||||
else:
|
||||
total_norm = 0.0
|
||||
# if dist.get_rank() == 0:
|
||||
# logger.info(f"Total Norm beginning {total_norm}")
|
||||
|
||||
for g, p in zip(gradients, params):
|
||||
# Pipeline parallelism may replicate parameters. Avoid multi-counting.
|
||||
if is_model_parallel_parameter(p) or mp_rank == 0:
|
||||
param_norm = g.data.double().norm(2)
|
||||
total_norm += param_norm.item()**2
|
||||
|
||||
# Sum across all model parallel GPUs.
|
||||
total_norm_cuda = torch.cuda.FloatTensor([float(total_norm)])
|
||||
torch.distributed.all_reduce(total_norm_cuda,
|
||||
op=torch.distributed.ReduceOp.SUM,
|
||||
group=dp_group)
|
||||
|
||||
if mp_group is not None:
|
||||
dist.all_reduce(tensor=total_norm_cuda,
|
||||
op=torch.distributed.ReduceOp.SUM)
|
||||
|
||||
total_norm = total_norm_cuda[0].item()**(1. / norm_type)
|
||||
|
||||
if total_norm == float(
|
||||
'inf') or total_norm == -float('inf') or total_norm != total_norm:
|
||||
total_norm = -1
|
||||
|
||||
return total_norm
|
||||
|
||||
|
||||
def sync_param(flat_tensor, tensor_list):
|
||||
"""
|
||||
Synchronize the flattened tensor and unflattened tensor list. When
|
||||
a list of tensors is flattened with `torch._utils._flatten_dense_tensors`,
|
||||
a new tensor is created. Thus, the flat tensor and original tensor list do not
|
||||
share the same memory space. This function will update the tensor list so that
|
||||
they point to the same value.
|
||||
|
||||
:param flat_tensor: A flat tensor obtained by calling `torch._utils._flatten_dense_tensors` on a tensor list
|
||||
:param tensor_list: A list of tensors corresponding to the flattened tensor
|
||||
:type flat_tensor: torch.Tensor
|
||||
:type tensor_list: List[torch.Tensor]
|
||||
"""
|
||||
updated_params = unflatten(flat_tensor, tensor_list)
|
||||
|
||||
# update the tensor data
|
||||
for p, q in zip(tensor_list, updated_params):
|
||||
p.data = q.data
|
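A short sketch of what ``sync_param`` guarantees: after the call, writes to the flat tensor are visible through the original tensor list (plain CPU tensors, no distributed setup needed; ``a`` and ``b`` are made-up names):

# Sketch: make a tensor list view into its flattened buffer.
import torch

a, b = torch.zeros(2), torch.zeros(3)
flat = flatten([a, b])                       # new memory, detached from a and b
sync_param(flat_tensor=flat, tensor_list=[a, b])
flat.fill_(1.0)                              # now visible through a and b
assert a.sum().item() == 2 and b.sum().item() == 3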
|
@ -0,0 +1,6 @@
|
|||
from .gradient_store import GradientStore
|
||||
from .parameter_store import ParameterStore
|
||||
from .bucket_store import BucketStore
|
||||
from .tensor_bucket import TensorBucket
|
||||
|
||||
__all__ = ['GradientStore', 'ParameterStore', 'BucketStore', 'TensorBucket']
|
|
@ -0,0 +1,17 @@
|
|||
from colossalai.core import global_context as gpc
|
||||
from colossalai.context import ParallelMode
|
||||
|
||||
|
||||
class BaseStore:
|
||||
|
||||
def __init__(self, dp_parallel_mode=ParallelMode.DATA):
|
||||
self._world_size = gpc.get_world_size(dp_parallel_mode)
|
||||
self._local_rank = gpc.get_local_rank(dp_parallel_mode)
|
||||
|
||||
@property
|
||||
def world_size(self):
|
||||
return self._world_size
|
||||
|
||||
@property
|
||||
def local_rank(self):
|
||||
return self._local_rank
|
|
@ -0,0 +1,43 @@
|
|||
from colossalai.core import global_context as gpc
|
||||
from colossalai.context import ParallelMode
|
||||
from .base_store import BaseStore
|
||||
|
||||
class BucketStore(BaseStore):
|
||||
|
||||
def __init__(self, dp_parallel_mode):
|
||||
super().__init__(dp_parallel_mode)
|
||||
self._grads = dict()
|
||||
self._params = dict()
|
||||
self._num_elements_in_bucket = dict()
|
||||
|
||||
self.reset()
|
||||
|
||||
def num_elements_in_bucket(self, reduce_rank: int = None):
|
||||
return self._num_elements_in_bucket[reduce_rank]
|
||||
|
||||
def add_num_elements_in_bucket(self, num_elements, reduce_rank: int = None):
|
||||
self._num_elements_in_bucket[reduce_rank] += num_elements
|
||||
|
||||
def add_grad(self, tensor, reduce_rank: int = None):
|
||||
self._grads[reduce_rank].append(tensor)
|
||||
|
||||
def add_param(self, tensor, reduce_rank: int = None):
|
||||
self._params[reduce_rank].append(tensor)
|
||||
|
||||
def reset(self):
|
||||
keys = [None] + list(range(self._world_size))
|
||||
self._grads = {rank: [] for rank in keys}
|
||||
self._params = {rank: [] for rank in keys}
|
||||
self._num_elements_in_bucket = {rank: 0 for rank in keys}
|
||||
|
||||
def reset_by_rank(self, reduce_rank=None):
|
||||
self._grads[reduce_rank] = []
|
||||
self._params[reduce_rank] = []
|
||||
self._num_elements_in_bucket[reduce_rank] = 0
|
||||
|
||||
|
||||
def get_grad(self, reduce_rank: int = None):
|
||||
return self._grads[reduce_rank]
|
||||
|
||||
def get_param(self, reduce_rank: int = None):
|
||||
return self._params[reduce_rank]
|
|
@ -0,0 +1,66 @@
|
|||
from typing import List
|
||||
from torch import Tensor
|
||||
from .base_store import BaseStore
|
||||
|
||||
|
||||
class GradientStore(BaseStore):
|
||||
|
||||
def __init__(self, *args):
|
||||
super().__init__(*args)
|
||||
# bookkeeping data structures
|
||||
self._averaged_gradients = dict()
|
||||
|
||||
# for backward reduction hooks
|
||||
self._grad_acc_objs = []
|
||||
|
||||
def add_accumulate_grad_object(self, obj):
|
||||
"""
|
||||
Keep :class:`AccumulateGrad` objects. If these objects are not kept, reduction hooks may not
|
||||
be attached successfully.
|
||||
|
||||
:param obj: An object of :class:`AccumulateGrad` class
|
||||
:type obj: :class:`AccumulateGrad`
|
||||
"""
|
||||
|
||||
self._grad_acc_objs.append(obj)
|
||||
|
||||
def get_averaged_gradients_by_group(self, group_id: int) -> List[Tensor]:
|
||||
"""
|
||||
Return average gradients of a parameter group
|
||||
|
||||
:param group_id: The index of parameter group
|
||||
:type group_id: int
|
||||
|
||||
:return: Return the list of averaged gradients of a parameter group. Each element is a gradient, not a parameter.
|
||||
:rtype: List[torch.Tensor]
|
||||
"""
|
||||
|
||||
return self._averaged_gradients[group_id]
|
||||
|
||||
def add_average_gradient_by_group(self, group_id: int, tensor: Tensor) -> None:
|
||||
"""
|
||||
Append an average gradient to the list of averaged gradients of a parameter group
|
||||
|
||||
:param group_id: The index of a parameter group
|
||||
:param tensor: A :class:`torch.Tensor` object
|
||||
:type group_id: int
|
||||
:type tensor: torch.Tensor
|
||||
|
||||
"""
|
||||
|
||||
if group_id in self._averaged_gradients:
|
||||
self._averaged_gradients[group_id].append(tensor)
|
||||
else:
|
||||
self._averaged_gradients[group_id] = [tensor]
|
||||
|
||||
def reset_average_gradients_by_group(self, group_id: int) -> None:
|
||||
"""
|
||||
Reset the bookkeeping data structure for averaged gradients to an empty list
|
||||
|
||||
:param group_id: The index of a parameter group
|
||||
:type group_id: int
|
||||
"""
|
||||
|
||||
self._averaged_gradients[group_id] = []
|
||||
|
||||
|
|
@ -0,0 +1,96 @@
|
|||
from .base_store import BaseStore
|
||||
from torch import Tensor
|
||||
from typing import List
|
||||
|
||||
|
||||
class ParameterStore(BaseStore):
|
||||
|
||||
def __init__(self, dp_parallel_mode):
|
||||
super().__init__(dp_parallel_mode)
|
||||
# param partitioning data structures
|
||||
self._fp16_param_to_rank = dict()
|
||||
self._rank_groupid_to_fp16_param_list = dict()
|
||||
self._rank_group_id_to_flat_fp16_param = dict()
|
||||
|
||||
# param reduction data structures
|
||||
self._is_param_reduced = dict()
|
||||
self._reduced_param = []
|
||||
|
||||
def set_param_to_rank(self, tensor: Tensor, rank: int) -> None:
|
||||
"""
|
||||
Set the mapping from parameter to rank; each parameter should be owned by exactly one rank.
|
||||
|
||||
:param tensor: A :class:`torch.Tensor` object
|
||||
:type tensor: torch.Tensor
|
||||
:param rank: The rank of which the process is responsible for updating the parameter
|
||||
:type rank: int
|
||||
"""
|
||||
|
||||
self._fp16_param_to_rank[tensor] = rank
|
||||
|
||||
def get_param_rank(self, tensor: Tensor) -> int:
|
||||
"""
|
||||
Return the rank to which the parameter belongs
|
||||
|
||||
:param tensor: A :class:`torch.Tensor` object
|
||||
:type tensor: torch.Tensor
|
||||
"""
|
||||
return self._fp16_param_to_rank[tensor]
|
||||
|
||||
def belongs_to_current_rank(self, tensor) -> bool:
|
||||
"""
|
||||
Check whether a parameter is supposed to be updated by the process of the current rank
|
||||
|
||||
:param tensor: A :class:`torch.Tensor` object
|
||||
:type tensor: torch.Tensor
|
||||
|
||||
:return: True if the parameter should be updated by the current rank. Otherwise false.
|
||||
:rtype: bool
|
||||
"""
|
||||
|
||||
tensor_rank = self._fp16_param_to_rank[tensor]
|
||||
return tensor_rank == self._local_rank
|
||||
|
||||
def add_fp16_param_list_by_rank_group(self, rank, group_id,
|
||||
tensor_list) -> None:
|
||||
if rank not in self._rank_groupid_to_fp16_param_list:
|
||||
self._rank_groupid_to_fp16_param_list[rank] = dict()
|
||||
|
||||
if group_id not in self._rank_groupid_to_fp16_param_list[rank]:
|
||||
self._rank_groupid_to_fp16_param_list[rank][group_id] = []
|
||||
|
||||
self._rank_groupid_to_fp16_param_list[rank][group_id].extend(
|
||||
tensor_list)
|
||||
|
||||
def get_fp16_params_by_rank_group(self, rank, group_id) -> List[Tensor]:
|
||||
return self._rank_groupid_to_fp16_param_list[rank][group_id]
|
||||
|
||||
def add_flat_fp16_param_by_rank_group(self, rank, group_id, tensor) -> None:
|
||||
if rank not in self._rank_group_id_to_flat_fp16_param:
|
||||
self._rank_group_id_to_flat_fp16_param[rank] = dict()
|
||||
|
||||
self._rank_group_id_to_flat_fp16_param[rank][group_id] = tensor
|
||||
|
||||
def get_flat_fp16_param_by_rank_group(self, rank, group_id) -> Tensor:
|
||||
return self._rank_group_id_to_flat_fp16_param[rank][group_id]
|
||||
|
||||
def is_param_reduced(self, tensor):
|
||||
return self._is_param_reduced[tensor]
|
||||
|
||||
def set_param_reduction_state(self, tensor, state):
|
||||
self._is_param_reduced[tensor] = state
|
||||
|
||||
def get_param_reduction_states(self):
|
||||
return self._is_param_reduced
|
||||
|
||||
def reset_previous_reduced_params(self):
|
||||
self._reduced_param = []
|
||||
|
||||
def add_previous_reduced_param(self, tensor):
|
||||
self._reduced_param.append(tensor)
|
||||
|
||||
def clear_grads_of_previous_reduced_params(self):
|
||||
if len(self._reduced_param) > 0:
|
||||
for param in self._reduced_param:
|
||||
param.grad = None
|
||||
self.reset_previous_reduced_params()
|
|
@ -0,0 +1,54 @@
|
|||
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
|
||||
|
||||
|
||||
class TensorBucket:
|
||||
|
||||
def __init__(self, size):
|
||||
self._max_size = size
|
||||
self._current_size = 0
|
||||
self._bucket = []
|
||||
|
||||
@property
|
||||
def max_size(self):
|
||||
return self._max_size
|
||||
|
||||
@property
|
||||
def current_size(self):
|
||||
return self._current_size
|
||||
|
||||
def is_full_or_oversized(self):
|
||||
return self._current_size >= self._max_size
|
||||
|
||||
def is_empty(self):
|
||||
return len(self._bucket) == 0
|
||||
|
||||
def add_to_bucket(self, tensor, allow_oversize=False):
|
||||
tensor_size = tensor.numel()
|
||||
|
||||
if not allow_oversize and self.will_exceed_max_size(tensor_size):
|
||||
msg = f"The param bucket max size {self._max_size} is exceeded" \
|
||||
+ f"by tensor (size {tensor_size})"
|
||||
raise RuntimeError(msg)
|
||||
|
||||
self._bucket.append(tensor)
|
||||
self._current_size += tensor_size
|
||||
|
||||
def will_exceed_max_size(self, tensor_size):
|
||||
expected_size = self._current_size + tensor_size
|
||||
return expected_size > self._max_size
|
||||
|
||||
def get_bucket(self):
|
||||
return self._bucket
|
||||
|
||||
def empty(self):
|
||||
self._bucket = []
|
||||
self._current_size = 0
|
||||
|
||||
def flatten(self):
|
||||
return _flatten_dense_tensors(self._bucket)
|
||||
|
||||
def unflatten_and_copy(self, flat_tensor):
|
||||
unflattened_tensor_list = _unflatten_dense_tensors(
|
||||
flat_tensor, self._bucket)
|
||||
for old, new in zip(self._bucket, unflattened_tensor_list):
|
||||
old.copy_(new)
|
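A minimal local sketch of the TensorBucket round trip used later by the optimizer: fill, flatten for one collective call, then copy the (reduced) flat result back; here the reduction is replaced by an in-place scale and ``g1``/``g2`` are made-up gradients:

# Sketch: TensorBucket round trip without any real communication.
import torch

g1, g2 = torch.ones(4), torch.ones(6)
bucket = TensorBucket(size=16)
bucket.add_to_bucket(g1)
bucket.add_to_bucket(g2)

flat = bucket.flatten()        # one contiguous buffer for a single collective
flat.mul_(0.5)                 # stand-in for the real reduce/all-reduce
bucket.unflatten_and_copy(flat)
assert g1[0].item() == 0.5 and g2[0].item() == 0.5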
|
@ -0,0 +1,568 @@
|
|||
from itertools import groupby
|
||||
from colossalai.utils.cuda import get_current_device
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from colossalai.logging import get_dist_logger
|
||||
from torch.optim import Optimizer
|
||||
from .bookkeeping import ParameterStore, GradientStore, BucketStore, TensorBucket
|
||||
from colossalai.context import ParallelMode
|
||||
from colossalai.core import global_context as gpc
|
||||
from colossalai.amp.naive_amp._fp16_optimizer import DynamicGradScaler
|
||||
from colossalai.nn.optimizer import ColossalaiOptimizer
|
||||
from ._utils import (move_tensor, flatten, get_grad_accumulate_object, split_half_float_double, reduce_tensor,
|
||||
release_param_grad, calculate_global_norm_from_list, compute_norm, sync_param, has_inf_or_nan)
|
||||
from functools import partial
|
||||
|
||||
|
||||
class ShardedOptimizer(ColossalaiOptimizer):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
optimizer: Optimizer,
|
||||
|
||||
# grad scaler config
|
||||
initial_scale=2**32,
|
||||
min_scale=1,
|
||||
growth_factor=2,
|
||||
backoff_factor=0.5,
|
||||
growth_interval=1000,
|
||||
hysteresis=2,
|
||||
max_scale: int = 2**32,
|
||||
|
||||
# grad clipping
|
||||
clip_grad_norm=2.0,
|
||||
verbose=False,
|
||||
|
||||
# communication
|
||||
reduce_bucket_size=500000000,
|
||||
communication_dtype=torch.float16,
|
||||
overlap_communication=False,
|
||||
|
||||
# stage 2
|
||||
partition_grad=False,
|
||||
|
||||
dp_parallel_mode=ParallelMode.DATA,
|
||||
mp_parallel_mode=ParallelMode.MODEL,
|
||||
|
||||
# cpu offload
|
||||
cpu_offload=False):
|
||||
|
||||
# TODO: add support for
|
||||
# 1. fp16 master weights
|
||||
# 2. contiguous gradients
|
||||
# 3. cpu offload
|
||||
# 4. support when some parameters requires_grad = False
|
||||
|
||||
self._optimizer = optimizer
|
||||
self._dtype = self._optimizer.param_groups[0]['params'][0].dtype
|
||||
self._logger = get_dist_logger()
|
||||
self._verbose = verbose
|
||||
|
||||
# stage 2
|
||||
self._partition_grads = partition_grad
|
||||
|
||||
# cpu_offload
|
||||
self._cpu_offload = cpu_offload
|
||||
|
||||
# get process groups
|
||||
self._dp_parallel_mode = dp_parallel_mode
|
||||
self._mp_parallel_mode = mp_parallel_mode
|
||||
self._local_rank = gpc.get_local_rank(dp_parallel_mode)
|
||||
self._world_size = gpc.get_world_size(dp_parallel_mode)
|
||||
|
||||
self._dp_group = gpc.get_group(dp_parallel_mode)
|
||||
if gpc.is_initialized(mp_parallel_mode) and gpc.get_world_size(mp_parallel_mode) > 1:
|
||||
self._mp_group = gpc.get_group(mp_parallel_mode)
|
||||
else:
|
||||
self._mp_group = None
|
||||
|
||||
# fp16 and fp32 params for mixed precision training
|
||||
self._fp16_param_groups = dict()
|
||||
self._fp32_flat_param_groups_of_current_rank = dict()
|
||||
|
||||
# communication params
|
||||
self._overlap_communication = overlap_communication
|
||||
self._reduce_bucket_size = reduce_bucket_size
|
||||
self._communication_dtype = communication_dtype
|
||||
|
||||
# gradient scaler
|
||||
self.grad_scaler = DynamicGradScaler(initial_scale=initial_scale,
|
||||
min_scale=min_scale,
|
||||
growth_factor=growth_factor,
|
||||
backoff_factor=backoff_factor,
|
||||
growth_interval=growth_interval,
|
||||
hysteresis=hysteresis,
|
||||
max_scale=max_scale,
|
||||
verbose=verbose)
|
||||
self._found_overflow = torch.FloatTensor([0]).to(get_current_device())
|
||||
|
||||
# gradient clipping
|
||||
self._clip_grad_norm = clip_grad_norm
|
||||
|
||||
# check argument conflict
|
||||
self._sanity_checks()
|
||||
|
||||
# ParameterStore will manage the tensor buffers used for zero
|
||||
# it will not manage the tensors used by mixed precision training
|
||||
self._param_store = ParameterStore(self._dp_parallel_mode)
|
||||
self._grad_store = GradientStore(self._dp_parallel_mode)
|
||||
self._bucket_store = BucketStore(self._dp_parallel_mode)
|
||||
|
||||
# iterate over the param group in the optimizer
|
||||
# partition these param groups for data parallel training
|
||||
# and add buffers to parameter store for future access
|
||||
for group_id, param_group in enumerate(self._optimizer.param_groups):
|
||||
params = param_group['params']
|
||||
|
||||
# add the fp16 params to fp16_param_groups for bookkeeping
|
||||
self._fp16_param_groups[group_id] = params
|
||||
|
||||
# assign parameters to ranks
|
||||
# the params in the list are sorted
|
||||
params_per_rank = self._partition_param_list(params)
|
||||
|
||||
# store the mapping between param to rank
|
||||
# each param should belong to only one rank
|
||||
for rank, params in enumerate(params_per_rank):
|
||||
self._param_store.add_fp16_param_list_by_rank_group(rank, group_id, params)
|
||||
for param in params:
|
||||
self._param_store.set_param_to_rank(param, rank)
|
||||
|
||||
# move to cpu to make room to create the flat tensor
|
||||
move_tensor(params, device='cpu')
|
||||
|
||||
# flatten the reordered tensors
|
||||
for rank in range(self._world_size):
|
||||
tensor_list = self._param_store.get_fp16_params_by_rank_group(rank, group_id)
|
||||
flat_tensor = flatten(tensor_list)
|
||||
flat_tensor = flat_tensor.cuda()
|
||||
self._param_store.add_flat_fp16_param_by_rank_group(rank, group_id, flat_tensor)
|
||||
|
||||
# sync parameters
|
||||
for rank in range(self._world_size):
|
||||
flat_tensor = self._param_store.get_flat_fp16_param_by_rank_group(rank, group_id)
|
||||
tensor_list = self._param_store.get_fp16_params_by_rank_group(rank, group_id)
|
||||
sync_param(flat_tensor=flat_tensor, tensor_list=tensor_list)
|
||||
|
||||
# create a copy of fp32 weights of the parameters for which this rank is responsible
|
||||
fp16_flat_current_rank = self._param_store.get_flat_fp16_param_by_rank_group(self._local_rank, group_id)
|
||||
fp32_flat_current_rank = fp16_flat_current_rank.clone().float().detach()
|
||||
device = 'cpu' if self._cpu_offload else get_current_device()
|
||||
fp32_flat_current_rank = fp32_flat_current_rank.to(device)
|
||||
fp32_flat_current_rank.requires_grad = True
|
||||
self._fp32_flat_param_groups_of_current_rank[group_id] = fp32_flat_current_rank
|
||||
|
||||
# need to replace the params in the `params` field in the optimizer
|
||||
# so that when the optimizer calls step(), it only updates the tensors
|
||||
# managed by this data parallel rank
|
||||
param_group['params'] = [fp32_flat_current_rank]
|
||||
|
||||
# set reduction state
|
||||
for param in self._fp16_param_groups[group_id]:
|
||||
self._param_store.set_param_reduction_state(param, False)
|
||||
|
||||
# initialize communication stream for
|
||||
# communication-computation overlapping
|
||||
if self._overlap_communication:
|
||||
self._comm_stream = torch.cuda.Stream()
|
||||
|
||||
# reduction hook is only used if overlapping communication
|
||||
# or stage 2 is used
|
||||
# if it is stage 1 without overlapping, no hook will be attached
|
||||
if self._overlap_communication or self._partition_grads:
|
||||
self._attach_reduction_hook()
|
||||
|
||||
self._initialize_optimizer_states()
|
||||
|
||||
@property
|
||||
def loss_scale(self):
|
||||
return self.grad_scaler.scale
|
||||
|
||||
@property
|
||||
def num_param_groups(self):
|
||||
return len(self._fp16_param_groups)
|
||||
|
||||
def _partition_param_list(self, param_list):
|
||||
params_per_rank = [[] for _ in range(self._world_size)]
|
||||
numel_per_rank = [0 for _ in range(self._world_size)]
|
||||
|
||||
# partition the parameters in a greedy fashion
|
||||
sorted_params = sorted(param_list, key=lambda x: x.numel(), reverse=True)
|
||||
for param in sorted_params:
|
||||
# allocate this parameter to the rank with
|
||||
# the smallest numel for load balancing purpose
|
||||
rank_to_go = numel_per_rank.index(min(numel_per_rank))
|
||||
params_per_rank[rank_to_go].append(param)
|
||||
numel_per_rank[rank_to_go] += param.numel()
|
||||
|
||||
if self._verbose:
|
||||
self._logger.info(f'Number of elements on ranks: {numel_per_rank}',
|
||||
ranks=[0],
|
||||
parallel_mode=self._dp_parallel_mode)
|
||||
return params_per_rank
|
||||
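The greedy rule above can be checked in isolation: sort by size in descending order and always give the next parameter to the least-loaded rank. A self-contained sketch with made-up sizes and a hypothetical ``greedy_partition`` helper:

# Sketch: the same greedy load-balancing rule, on plain integers instead of parameters.
def greedy_partition(sizes, world_size):
    buckets = [[] for _ in range(world_size)]
    load = [0] * world_size
    for s in sorted(sizes, reverse=True):
        rank_to_go = load.index(min(load))   # least-loaded rank so far
        buckets[rank_to_go].append(s)
        load[rank_to_go] += s
    return buckets, load

buckets, load = greedy_partition([100, 90, 40, 30, 10], world_size=2)
print(load)   # [140, 130] -> roughly balanced element counts per rank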
|
||||
def _initialize_optimizer_states(self):
|
||||
# create a dummy zero tensor which has the same shape as that of the param
|
||||
# set this dummy zero tensor as the grad
|
||||
for group_id in range(len(self._fp32_flat_param_groups_of_current_rank)):
|
||||
fp32_partition_param = self._fp32_flat_param_groups_of_current_rank[group_id]
|
||||
fp32_partition_grad = torch.zeros_like(fp32_partition_param)
|
||||
fp32_partition_param.grad = fp32_partition_grad
|
||||
|
||||
# update the parameter with zero gradients for initialization of optimizer states
|
||||
self._optimizer.step()
|
||||
|
||||
# remove the grad of the parameter to save memory
|
||||
for group_id, fp32_flat_tensor in self._fp32_flat_param_groups_of_current_rank.items():
|
||||
fp32_flat_tensor.grad = None
|
||||
|
||||
def _sanity_checks(self):
|
||||
assert torch.cuda.is_available(), 'CUDA is required'
|
||||
assert self._dtype == torch.float16, \
|
||||
f'Parameters are expected to be of type torch.float16, but got {self._dtype}'
|
||||
|
||||
###########################################################
|
||||
# Backward Reduction Hook
|
||||
###########################################################
|
||||
|
||||
def _attach_reduction_hook(self):
|
||||
# we iterate over the fp16 params
|
||||
# on each param, we register a hook to its AccumulateGrad object
|
||||
for group_id in range(self.num_param_groups):
|
||||
param_group = self._fp16_param_groups[group_id]
|
||||
for param in param_group:
|
||||
if param.requires_grad:
|
||||
# determine the reduction destination rank
|
||||
# this is only valid for stage 2
|
||||
# dst_rank = None means using all-reduce
|
||||
# else using reduce
|
||||
if self._partition_grads:
|
||||
reduce_rank = self._param_store.get_param_rank(param)
|
||||
else:
|
||||
reduce_rank = None
|
||||
|
||||
def _define_and_attach(param, reduce_rank):
|
||||
# get the AccumulateGrad object of the param itself
|
||||
accum_grad_obj = get_grad_accumulate_object(param)
|
||||
self._grad_store.add_accumulate_grad_object(accum_grad_obj)
|
||||
|
||||
reduction_func = partial(self._reduce_and_remove_grads_by_bucket,
|
||||
param=param,
|
||||
reduce_rank=reduce_rank)
|
||||
|
||||
# define hook
|
||||
# NOT IMPORTANT BUT GOOD TO KNOW:
|
||||
# the hook args are not the grad, but allow_unreachable and accumulate_grad
|
||||
def reduce_grad_hook(*args):
|
||||
reduction_func()
|
||||
accum_grad_obj.register_hook(reduce_grad_hook)
|
||||
|
||||
_define_and_attach(param, reduce_rank)
|
||||
|
||||
def _reduce_and_remove_grads_by_bucket(self, param, reduce_rank=None):
|
||||
param_size = param.numel()
|
||||
|
||||
# check if the bucket is full
|
||||
# if full, will reduce the grads already in the bucket
|
||||
# after reduction, the bucket will be empty
|
||||
if self._bucket_store.num_elements_in_bucket(reduce_rank) + param_size > self._reduce_bucket_size:
|
||||
self._reduce_grads_in_bucket(reduce_rank)
|
||||
|
||||
# the param must not have been reduced already, to ensure correctness
|
||||
is_param_reduced = self._param_store.is_param_reduced(param)
|
||||
if is_param_reduced:
|
||||
msg = f'Parameter of size ({param.size()}) has already been reduced, ' \
|
||||
+ 'duplicate reduction will lead to arithmetic incorrectness'
|
||||
raise RuntimeError(msg)
|
||||
|
||||
# the param must have grad for reduction
|
||||
assert param.grad is not None, f'Parameter of size ({param.size()}) has None grad, cannot be reduced'
|
||||
|
||||
self._bucket_store.add_num_elements_in_bucket(param_size, reduce_rank)
|
||||
self._bucket_store.add_grad(param.grad, reduce_rank)
|
||||
self._bucket_store.add_param(param, reduce_rank)
|
||||
|
||||
def _reduce_grads_in_bucket(self, reduce_rank=None):
|
||||
# reduce grads
|
||||
self._reduce_grads_by_rank(reduce_rank=reduce_rank,
|
||||
grads=self._bucket_store.get_grad(reduce_rank=reduce_rank),
|
||||
bucket_size=self._bucket_store.num_elements_in_bucket(reduce_rank))
|
||||
|
||||
# use communication stream if overlapping
|
||||
# communication with computation
|
||||
if self._overlap_communication:
|
||||
stream = self._comm_stream
|
||||
else:
|
||||
stream = torch.cuda.current_stream()
|
||||
|
||||
with torch.cuda.stream(stream):
|
||||
params_in_bucket = self._bucket_store.get_param(reduce_rank=reduce_rank)
|
||||
|
||||
for param in params_in_bucket:
|
||||
# the is_param_reduced flag should be False showing that
|
||||
# this param is not reduced before calling self._reduce_grads_by_rank
|
||||
is_param_reduced = self._param_store.is_param_reduced(param)
|
||||
|
||||
if is_param_reduced:
|
||||
msg = f'Parameter of size ({param.size()}) has been reduced, ' + \
|
||||
'duplicate reduction will lead to arithmetic incorrectness'
|
||||
raise RuntimeError(msg)
|
||||
|
||||
# update the flag
|
||||
self._param_store.set_param_reduction_state(param, True)
|
||||
|
||||
# if partition grads = True
|
||||
# we do not keep the gradient after reduction
|
||||
if self._partition_grads and not self._param_store.belongs_to_current_rank(param):
|
||||
if self._overlap_communication:
|
||||
# we need to keep this gradient for now as reduction may
|
||||
# not be completed yet since it is using a different cuda stream
|
||||
self._param_store.add_previous_reduced_param(param)
|
||||
else:
|
||||
param.grad = None
|
||||
|
||||
self._bucket_store.reset_by_rank(reduce_rank)
|
||||
|
||||
def _reduce_grads_by_rank(self, reduce_rank, grads, bucket_size):
|
||||
grad_buckets_by_dtype = split_half_float_double(grads)
|
||||
|
||||
for tensor_list in grad_buckets_by_dtype:
|
||||
self._reduce_no_retain(tensor_list=tensor_list, bucket_size=bucket_size, reduce_rank=reduce_rank)
|
||||
|
||||
##############################
|
||||
# Reduction Utility Function #
|
||||
##############################
|
||||
def _reduce_no_retain(self, tensor_list, bucket_size, reduce_rank):
|
||||
param_bucket = TensorBucket(size=bucket_size)
|
||||
|
||||
for tensor in tensor_list:
|
||||
param_bucket.add_to_bucket(tensor, allow_oversize=True)
|
||||
|
||||
if param_bucket.is_full_or_oversized():
|
||||
self._reduce_and_copy(bucket=param_bucket, reduce_rank=reduce_rank)
|
||||
param_bucket.empty()
|
||||
|
||||
if not param_bucket.is_empty():
|
||||
self._reduce_and_copy(bucket=param_bucket, reduce_rank=reduce_rank)
|
||||
|
||||
def _reduce_and_copy(self, bucket: TensorBucket, reduce_rank):
|
||||
if self._overlap_communication:
|
||||
torch.cuda.synchronize()
|
||||
self._param_store.clear_grads_of_previous_reduced_params()
|
||||
stream = self._comm_stream
|
||||
else:
|
||||
stream = torch.cuda.current_stream()
|
||||
|
||||
with torch.cuda.stream(stream):
|
||||
flat = bucket.flatten()
|
||||
reduced_flat = reduce_tensor(tensor=flat,
|
||||
dtype=self._communication_dtype,
|
||||
dst_rank=reduce_rank,
|
||||
parallel_mode=self._dp_parallel_mode)
|
||||
|
||||
# update the reduced tensor
|
||||
if reduce_rank is None or reduce_rank == self._local_rank:
|
||||
bucket.unflatten_and_copy(reduced_flat)
|
||||
|
||||
################################
|
||||
# torch.optim.Optimizer methods
|
||||
################################
|
||||
|
||||
def backward(self, loss, retain_graph=True):
|
||||
loss = self.loss_scale * loss
|
||||
loss.backward(retain_graph=retain_graph)
|
||||
|
||||
def zero_grad(self, set_to_none=True):
|
||||
"""
|
||||
Set parameter gradients to zero. If set_to_none = True, gradient
|
||||
will be set to None to save memory.
|
||||
|
||||
:param set_to_none: Whether set the gradient to None. Default value is True.
|
||||
:type set_to_none: bool
|
||||
"""
|
||||
for group_id, param_group in self._fp16_param_groups.items():
|
||||
for param in param_group:
|
||||
if set_to_none:
|
||||
param.grad = None
|
||||
else:
|
||||
if param.grad is not None:
|
||||
param.grad.detach_()
|
||||
param.grad.zero_()
|
||||
|
||||
####################
|
||||
# Update Parameter #
|
||||
####################
|
||||
|
||||
def step(self, closure=None):
|
||||
assert closure is None, 'closure is not supported by step()'
|
||||
|
||||
# check for overflow
|
||||
found_inf = self._check_overflow()
|
||||
self.grad_scaler.update(found_inf)
|
||||
|
||||
# if overflow occurs, skip this step and clear the gradients
|
||||
if found_inf:
|
||||
self._grad_store._averaged_gradients = dict()
|
||||
self.zero_grad()
|
||||
return
|
||||
|
||||
# copy the grad of fp16 param to fp32 param
|
||||
single_grad_partition_groups = []
|
||||
norm_groups = []
|
||||
|
||||
for group_id in range(self.num_param_groups):
|
||||
# compute norm
|
||||
norm_group = compute_norm(gradients=self._grad_store._averaged_gradients[group_id],
|
||||
params=self._param_store.get_fp16_params_by_rank_group(group_id=group_id,
|
||||
rank=self._local_rank),
|
||||
dp_group=self._dp_group,
|
||||
mp_group=self._mp_group)
|
||||
norm_groups.append(norm_group)
|
||||
|
||||
# create flat gradient for the flat fp32 params
|
||||
fp16_avg_grads = self._grad_store.get_averaged_gradients_by_group(group_id)
|
||||
flat_fp16_avg_grads = flatten(fp16_avg_grads)
|
||||
|
||||
dtype = self._fp32_flat_param_groups_of_current_rank[group_id].dtype
|
||||
flat_fp32_avg_grads = flat_fp16_avg_grads.to(dtype)
|
||||
|
||||
param_shape = self._fp32_flat_param_groups_of_current_rank[group_id].shape
|
||||
assert param_shape == flat_fp32_avg_grads.shape, \
|
||||
f'fp32 param and grad have different shape {param_shape} vs {flat_fp32_avg_grads.shape}'
|
||||
|
||||
single_grad_partition_groups.append(flat_fp32_avg_grads)
|
||||
device = self._fp32_flat_param_groups_of_current_rank[group_id].device
|
||||
self._fp32_flat_param_groups_of_current_rank[group_id].grad = flat_fp32_avg_grads.to(device)
|
||||
self._grad_store._averaged_gradients[group_id] = []
|
||||
|
||||
|
||||
# unscale and clip grads
|
||||
global_norm = calculate_global_norm_from_list(norm_list=norm_groups)
|
||||
self._unscale_and_clip_grads(single_grad_partition_groups, global_norm)
|
||||
|
||||
# update the parameters
|
||||
self._optimizer.step()
|
||||
# release the fp32 grad
|
||||
release_param_grad(self._fp32_flat_param_groups_of_current_rank.values())
|
||||
|
||||
# update the fp16 partition owned by the current rank
|
||||
for group_id in range(len(self._fp16_param_groups)):
|
||||
fp16_param = self._param_store.get_flat_fp16_param_by_rank_group(rank=self._local_rank, group_id=group_id)
|
||||
fp32_param = self._fp32_flat_param_groups_of_current_rank[group_id].to(fp16_param.device)
|
||||
fp16_param.data.copy_(fp32_param)
|
||||
|
||||
# broadcast the updated model weights
|
||||
handles = []
|
||||
for group_id in range(self.num_param_groups):
|
||||
for rank in range(self._world_size):
|
||||
fp16_param = self._param_store.get_flat_fp16_param_by_rank_group(rank=rank, group_id=group_id)
|
||||
handle = dist.broadcast(fp16_param, src=rank, group=self._dp_group, async_op=True)
|
||||
handles.append(handle)
|
||||
|
||||
for handle in handles:
|
||||
handle.wait()
|
||||
|
||||
##################
|
||||
# FP16 Utilities #
|
||||
##################
|
||||
|
||||
def _check_overflow(self):
|
||||
# clear previous overflow record
|
||||
self._found_overflow.fill_(0.0)
|
||||
|
||||
# check for overflow
|
||||
for group_id in range(len(self._fp16_param_groups)):
|
||||
for avg_grad in self._grad_store.get_averaged_gradients_by_group(group_id):
|
||||
if avg_grad is not None and has_inf_or_nan(avg_grad):
|
||||
self._found_overflow.fill_(1.0)
|
||||
break
|
||||
|
||||
# all-reduce across dp group
|
||||
dist.all_reduce(self._found_overflow, op=dist.ReduceOp.MAX, group=self._dp_group)
|
||||
|
||||
# all-reduce over model parallel group
|
||||
if self._mp_group:
|
||||
dist.all_reduce(self._found_overflow, op=dist.ReduceOp.MAX, group=self._mp_group)
|
||||
|
||||
if self._found_overflow.item() > 0:
|
||||
return True
|
||||
else:
|
||||
return False
|
||||
|
||||
def _unscale_and_clip_grads(self, grad_groups_flat, total_norm):
|
||||
# compute combined scale factor for this group
|
||||
combined_scale = self.loss_scale
|
||||
|
||||
if self._clip_grad_norm > 0.:
|
||||
# norm is in fact norm*scale
|
||||
clip = ((total_norm / self.loss_scale) + 1e-6) / self._clip_grad_norm
|
||||
if clip > 1:
|
||||
combined_scale = clip * self.loss_scale
|
||||
|
||||
for grad in grad_groups_flat:
|
||||
grad.data.mul_(1. / combined_scale)
|
||||
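For intuition, the combined scale reduces to plain loss-scale unscaling when the (already scaled) norm is within the clipping threshold, and grows proportionally otherwise; a small numeric sketch with made-up values:

# Sketch: the combined unscale-and-clip factor, with illustrative numbers.
loss_scale = 1024.0
clip_grad_norm = 2.0
total_norm = 4096.0                    # norm of the *scaled* grads, i.e. a true norm of 4.0

clip = (total_norm / loss_scale + 1e-6) / clip_grad_norm    # ~2.0 > 1, so clipping kicks in
combined_scale = clip * loss_scale if clip > 1 else loss_scale
# grads are multiplied by 1 / combined_scale: unscaled by 1024 and shrunk ~2x in one shot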
|
||||
############################
|
||||
# Gradient Synchronization #
|
||||
############################
|
||||
|
||||
def sync_grad(self):
|
||||
if not self._partition_grads:
|
||||
self._reduce_grad_stage1()
|
||||
else:
|
||||
# TODO: support async comm in reduce
|
||||
self._reduce_grad_stage2()
|
||||
|
||||
# update param already reduced flag
|
||||
reduction_states = self._param_store.get_param_reduction_states()
|
||||
for tensor, state in reduction_states.items():
|
||||
reduction_states[tensor] = False
|
||||
|
||||
# clear reduced grads
|
||||
if self._overlap_communication:
|
||||
torch.cuda.synchronize()
|
||||
self._param_store.clear_grads_of_previous_reduced_params()
|
||||
|
||||
# accumulate gradient
|
||||
avg_gradients = self._grad_store._averaged_gradients
|
||||
for group_id in range(self.num_param_groups):
|
||||
param_group = self._param_store.get_fp16_params_by_rank_group(self._local_rank, group_id)
|
||||
|
||||
if group_id not in avg_gradients:
|
||||
avg_gradients[group_id] = []
|
||||
|
||||
param_idx = 0
|
||||
for param in param_group:
|
||||
if param.grad is not None:
|
||||
if len(avg_gradients[group_id]) == param_idx:
|
||||
avg_gradients[group_id].append(param.grad)
|
||||
else:
|
||||
avg_gradients[group_id][param_idx].add_(param.grad)
|
||||
param_idx += 1
|
||||
|
||||
# the gradients needed are stored in the avg_gradients buffer
|
||||
# thus, the param grads can be cleared
|
||||
self.zero_grad()
|
||||
|
||||
def _reduce_grad_stage1(self):
|
||||
# if not overlapping communication (no reduction hook is attached)
|
||||
# we need to manually reduce these gradients
|
||||
if not self._overlap_communication:
|
||||
for group_id in range(len(self._fp16_param_groups)):
|
||||
param_group = self._fp16_param_groups[group_id]
|
||||
for param in param_group:
|
||||
if param.grad is not None:
|
||||
self._reduce_and_remove_grads_by_bucket(param)
|
||||
|
||||
# we need to reduce the gradients
|
||||
# left in the communication bucket
|
||||
self._reduce_grads_in_bucket()
|
||||
|
||||
def _reduce_grad_stage2(self):
|
||||
# when partition_grads is True, reduction hooks
|
||||
# are attached in the __init__ function, so we
|
||||
# only need to reduce the gradients
|
||||
# left in the communication bucket
|
||||
for reduce_rank in range(self._world_size):
|
||||
self._reduce_grads_in_bucket(reduce_rank)
|
File diff suppressed because it is too large
File diff suppressed because it is too large
|
@ -0,0 +1,187 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- encoding: utf-8 -*-
|
||||
|
||||
import copy
|
||||
import operator as op
|
||||
from functools import partial, reduce
|
||||
from typing import List
|
||||
|
||||
import colossalai
|
||||
import pytest
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
import torch.multiprocessing as mp
import torch.nn as nn
from colossalai.logging import disable_existing_loggers
from colossalai.utils import checkpoint, clip_grad_norm_fp32, free_port
from colossalai.zero.sharded_model import ShardedModel
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.nn.utils import clip_grad_norm_


class Enumerator:

    def __init__(self, arg_names: List[str], arg_values: List[tuple]) -> None:
        self.arg_names = arg_names
        self.enums = Enumerator.all_enumerate(arg_values)

    def __len__(self):
        return len(self.enums)

    def __getitem__(self, idx):
        return {name: self.enums[idx][i] for i, name in enumerate(self.arg_names)}

    @staticmethod
    def all_enumerate(args: List[tuple]):
        num_states = reduce(op.mul, map(lambda xs: len(xs), args))
        idxs = [0] * len(args)
        states = []
        for _ in range(num_states):
            states.append(tuple(args[j][idx] for j, idx in enumerate(idxs)))
            if len(states) == num_states:
                break
            i = 0
            while idxs[i] + 1 == len(args[i]):
                idxs[i] = 0
                i += 1
            idxs[i] += 1
        return states


def checkpoint_wrapper(module, enable=True):
    if enable:
        module.forward = partial(checkpoint, module.forward)
    return module


class Net(nn.Module):
    def __init__(self, checkpoint=False) -> None:
        super().__init__()
        self.fc1 = nn.Linear(5, 5)
        self.fc2 = nn.Linear(5, 5)
        self.fc3 = nn.Linear(5, 1)
        if checkpoint:
            self.fc1 = checkpoint_wrapper(self.fc1)
        self.layers = [
            self.fc1,
            self.fc2,
            self.fc1,
            self.fc2,
            self.fc3
        ]

    def forward(self, x):
        for layer in self.layers:
            x = layer(x)
        return x


def run_step(model, optimizer, x, enable_autocast=False, norm_type=2.0):
    model.train()
    optimizer.zero_grad()
    with torch.cuda.amp.autocast(enabled=enable_autocast):
        y = model(x)
        loss = y.sum()
    loss = loss.float()
    loss.backward()
    clip_grad(model, norm_type)
    optimizer.step()


def clip_grad(model, norm_type):
    if isinstance(model, DDP):
        clip_grad_norm_(model.parameters(), max_norm=1.0, norm_type=norm_type)
    else:
        clip_grad_norm_fp32(model.parameters(), max_norm=1.0, norm_type=norm_type)


def allclose(tensor_a: torch.Tensor, tensor_b: torch.Tensor, loose=False) -> bool:
    if loose:
        return torch.allclose(tensor_a, tensor_b, atol=1e-3, rtol=1e-3)
    return torch.allclose(tensor_a, tensor_b)


def check_grads(model, zero_model, loose=False):
    rank = dist.get_rank()
    for p, zero_p in zip(model.parameters(), zero_model.parameters()):
        zero_grad = zero_p.grad.clone().to(p.device)
        chunks = torch.flatten(p.grad).chunk(4)
        if rank >= len(chunks):
            continue
        grad = chunks[rank]
        if zero_p.zero_shard_padding > 0:
            zero_grad = zero_grad[:-zero_p.zero_shard_padding]
        assert grad.dtype == zero_grad.dtype
        assert allclose(grad, zero_grad, loose=loose)


def check_params(model, zero_model, loose=False):
    rank = dist.get_rank()
    for p, zero_p in zip(model.parameters(), zero_model.parameters()):
        zero_shard_padding = zero_p.zero_shard_padding
        zero_p = zero_p.clone().to(p.device)
        chunks = torch.flatten(p).chunk(4)
        if rank >= len(chunks):
            continue
        p = chunks[rank]
        if zero_shard_padding > 0:
            zero_p = zero_p[:-zero_shard_padding]
        assert p.dtype == zero_p.dtype
        assert allclose(p, zero_p, loose=loose)


def check_config(checkpoint=False, fp16=False, offload=False, norm_type=2.0):
    model = Net(checkpoint=checkpoint).cuda()
    zero_model = copy.deepcopy(model)
    ddp_model = DDP(model)

    offload_config = {}
    if offload:
        offload_config['device'] = 'cpu'
        zero_model = zero_model.cpu()
    zero_model = ShardedModel(zero_model, mixed_precision=fp16, offload_config=offload_config)

    optimizer = torch.optim.Adam(ddp_model.parameters(), lr=1e-3)
    zero_optimizer = torch.optim.Adam(zero_model.parameters(), lr=1e-3)
    for _ in range(5):
        x = torch.rand(2, 5).cuda()
        run_step(ddp_model, optimizer, x, enable_autocast=fp16, norm_type=norm_type)
        run_step(zero_model, zero_optimizer, x, enable_autocast=fp16, norm_type=norm_type)
        check_grads(ddp_model, zero_model)
        check_params(ddp_model, zero_model)
    for _ in range(5):
        x = torch.rand(2, 5).cuda()
        run_step(ddp_model, optimizer, x, enable_autocast=False, norm_type=norm_type)
        run_step(zero_model, zero_optimizer, x, enable_autocast=False, norm_type=norm_type)
        check_grads(ddp_model, zero_model, loose=True)
        check_params(ddp_model, zero_model, loose=True)


def run_dist(rank, world_size, port):
    disable_existing_loggers()
    colossalai.launch(config={},
                      rank=rank,
                      world_size=world_size,
                      host='localhost',
                      port=port,
                      backend='nccl')

    args = ['checkpoint', 'fp16', 'offload', 'norm_type']
    arg_values = [(False, True), (False, True), (False, True), (1.0, 2.0, float('inf'))]
    arg_enumerator = Enumerator(args, arg_values)

    for kwargs in arg_enumerator:
        if dist.get_rank() == 0:
            print(kwargs)
        check_config(**kwargs)
    check_config()


@pytest.mark.dist
def test_zero_clip_grad():
    world_size = 4
    run_func = partial(run_dist, world_size=world_size, port=free_port())
    mp.spawn(run_func, nprocs=world_size)


if __name__ == '__main__':
    test_zero_clip_grad()
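For reference, a minimal standalone sketch (not part of the diff above) of the configuration sweep that the Enumerator class performs, written with itertools.product; the name enumerate_configs is illustrative only. Enumerator varies the first argument fastest, so the iteration order differs, but the set of generated configurations is identical.

import itertools
from typing import Dict, Iterator, List


def enumerate_configs(arg_names: List[str], arg_values: List[tuple]) -> Iterator[Dict]:
    # Cartesian product over the per-argument value tuples, yielding kwargs
    # dicts analogous to Enumerator.__getitem__.
    for combo in itertools.product(*arg_values):
        yield dict(zip(arg_names, combo))


if __name__ == '__main__':
    for kwargs in enumerate_configs(['checkpoint', 'fp16', 'offload'],
                                    [(False, True), (False, True), (False, True)]):
        print(kwargs)   # e.g. {'checkpoint': False, 'fp16': False, 'offload': False}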
@ -0,0 +1,82 @@
from functools import partial

from colossalai.utils import checkpoint
import torch.nn as nn
import torch
from colossalai.logging import disable_existing_loggers, get_dist_logger

LOGGER = get_dist_logger()

CONFIG = dict(
    fp16=dict(
        mode=None,
    ),
    zero=dict(
        level=3,
        verbose=False,
        offload_optimizer_config=dict(
            device='cpu',
            pin_memory=True,
            buffer_count=5,
            fast_init=False
        ),
        offload_param_config=dict(
            device='cpu',
            pin_memory=True,
            buffer_count=5,
            buffer_size=1e8,
            max_in_cpu=1e9
        )
    ),
    parallel=dict(
        pipeline=dict(size=1),
        tensor=dict(size=1, mode=None)
    )
)


def checkpoint_wrapper(module, enable=True):
    if enable:
        module.forward = partial(checkpoint, module.forward)
    return module


class Net(nn.Module):
    def __init__(self, checkpoint=False) -> None:
        super().__init__()
        self.fc1 = nn.Linear(5, 5)
        self.fc2 = nn.Linear(5, 5)
        self.fc3 = nn.Linear(5, 1)
        if checkpoint:
            self.fc1 = checkpoint_wrapper(self.fc1)
        self.layers = [
            self.fc1,
            self.fc2,
            self.fc1,
            self.fc2,
            self.fc3
        ]

    def forward(self, x):
        for layer in self.layers:
            x = layer(x)
        return x


def allclose(tensor_a: torch.Tensor, tensor_b: torch.Tensor, loose=False) -> bool:
    if loose:
        return torch.allclose(tensor_a, tensor_b, atol=1e-3, rtol=1e-3)
    return torch.allclose(tensor_a, tensor_b)


def check_grads(model, zero_model, loose=False):
    for p, zero_p in zip(model.parameters(), zero_model.parameters()):
        zero_grad = zero_p.grad.clone().to(p.device)
        assert p.grad.dtype == zero_grad.dtype
        assert allclose(p.grad, zero_grad, loose=loose)
        LOGGER.info(torch.sum(p.grad - zero_grad))


def check_params(model, zero_model, loose=False):
    for p, zero_p in zip(model.parameters(), zero_model.parameters()):
        zero_p = zero_p.clone().to(p.device)
        assert p.dtype == zero_p.dtype
        assert allclose(p, zero_p, loose=loose)
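A short illustration of why the checks above carry a loose mode (an assumption about intent, not library code): half precision keeps roughly three decimal digits, so the autocast phases of these tests compare with atol=rtol=1e-3 while the fp32 phases use the torch.allclose defaults.

import torch

# Two values that differ by ~5e-4: close under the loose tolerance used after
# autocast steps, but not under torch.allclose's default atol=1e-8, rtol=1e-5.
a = torch.tensor([1.0000, 2.0000])
b = torch.tensor([1.0005, 1.9995])
assert torch.allclose(a, b, atol=1e-3, rtol=1e-3)
assert not torch.allclose(a, b)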
@ -0,0 +1,57 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-

import copy
from functools import partial

import colossalai
import pytest
import torch
import torch.multiprocessing as mp
from colossalai.logging import disable_existing_loggers
from colossalai.utils import free_port
from colossalai.zero.sharded_model import ShardedModelV2
from colossalai.core import global_context as gpc
from colossalai.context.parallel_mode import ParallelMode
from tests.test_zero_data_parallel.common import Net, CONFIG, check_grads


def run_fwd_bwd(model, x, enable_autocast=False):
    model.train()
    with torch.cuda.amp.autocast(enabled=enable_autocast):
        y = model(x)
        loss = y.sum()
    loss = loss.float()
    loss.backward()


def run_dist(rank, world_size, port):
    colossalai.launch(config=CONFIG,
                      rank=rank,
                      world_size=world_size,
                      host='localhost',
                      port=port,
                      backend='nccl')

    model = Net(checkpoint=True).cuda()
    zero_model = copy.deepcopy(model)
    zero_model = ShardedModelV2(zero_model, process_group=gpc.get_group(ParallelMode.DATA))

    for _ in range(2):
        x = torch.rand(2, 5).cuda()
        run_fwd_bwd(zero_model, x, False)
        run_fwd_bwd(model, x, False)
        check_grads(model, zero_model)


@pytest.mark.dist
def test_shard_model_v2():
    world_size = 2
    run_func = partial(run_dist, world_size=world_size, port=free_port())
    mp.spawn(run_func, nprocs=world_size)


if __name__ == '__main__':
    test_shard_model_v2()
@ -0,0 +1,50 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-

from functools import partial

import colossalai
import pytest
import torch
import torch.multiprocessing as mp
from colossalai.zero.shard_param import ShardParam
from colossalai.utils import free_port
from colossalai.logging import get_dist_logger, disable_existing_loggers
from tests.test_zero_data_parallel.common import Net, CONFIG


def run_shard_param_check(rank, world_size, port):
    colossalai.launch(config=CONFIG,
                      rank=rank,
                      world_size=world_size,
                      host='localhost',
                      port=port,
                      backend='nccl')

    logger = get_dist_logger()
    model = Net()

    # add an attribute as ca_attr to hijack the access to param.data
    for _, param in model.named_parameters():
        numel_ref = (param.numel() + world_size - 1) // world_size
        param.ca_attr = ShardParam(param)
        param.ca_attr.shard()
        param_data = param.ca_attr.payload(torch.device('cpu'))
        logger.info(f'shard {param_data.shape} {param_data}', ranks=[1])
        assert numel_ref == param_data.numel()

    for _, param in model.named_parameters():
        param.ca_attr.gather()
        param_data = param.ca_attr.payload(torch.device('cpu'))
        logger.info(f'gather {param_data.shape} {param_data}', ranks=[1])

    disable_existing_loggers([logger])


@pytest.mark.dist
def test_run_shard_shape():
    world_size = 2
    run_func = partial(run_shard_param_check, world_size=world_size, port=free_port())
    mp.spawn(run_func, nprocs=world_size)


if __name__ == '__main__':
    test_run_shard_shape()
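The numel assertion above reduces to simple shard-size arithmetic; a standalone sketch of that arithmetic (illustrative, not the ShardParam implementation): each rank keeps ceil(numel / world_size) elements, with the last shard zero-padded when numel is not divisible by world_size.

def shard_numel(numel: int, world_size: int) -> int:
    # Per-rank element count after sharding, matching the numel_ref computed above.
    return (numel + world_size - 1) // world_size


assert shard_numel(25, 2) == 13   # Net's 5x5 weight across 2 ranks (1 element padded)
assert shard_numel(5, 2) == 3     # its bias of 5 elements (1 element padded)
assert shard_numel(5, 1) == 5     # no padding when world_size divides numel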
@ -0,0 +1,119 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-

import copy
from functools import partial

import colossalai
import pytest
import torch
import torch.multiprocessing as mp
import torch.nn as nn
from colossalai.logging import disable_existing_loggers
from colossalai.utils import checkpoint, free_port
from colossalai.zero.sharded_model import ShardedModel
from common import Net, check_grads, check_params


def checkpoint_wrapper(module, enable=True):
    if enable:
        module.forward = partial(checkpoint, module.forward)
    return module


class Net(nn.Module):
    def __init__(self, checkpoint=False) -> None:
        super().__init__()
        self.fc1 = nn.Linear(5, 5)
        self.fc2 = nn.Linear(5, 5)
        self.fc3 = nn.Linear(5, 1)
        if checkpoint:
            self.fc1 = checkpoint_wrapper(self.fc1)
        self.layers = [
            self.fc1,
            self.fc2,
            self.fc1,
            self.fc2,
            self.fc3
        ]

    def forward(self, x):
        for layer in self.layers:
            x = layer(x)
        return x


def run_step(model, optimizer, x, enable_autocast=False):
    model.train()
    optimizer.zero_grad()
    with torch.cuda.amp.autocast(enabled=enable_autocast):
        y = model(x)
        loss = y.sum()
    loss = loss.float()
    loss.backward()
    optimizer.step()


def decode_booleans(intval, bits):
    res = []
    for bit in range(bits):
        mask = 1 << bit
        res.append((intval & mask) == mask)
    return res


def check_config(checkpoint=False, fp16=False, offload=False):
    model = Net(checkpoint=checkpoint).cuda()
    zero_model = copy.deepcopy(model)

    offload_config = {}
    if offload:
        offload_config['device'] = 'cpu'
        zero_model = zero_model.cpu()
    zero_model = ShardedModel(zero_model, mixed_precision=fp16, offload_config=offload_config)

    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
    zero_optimizer = torch.optim.Adam(zero_model.parameters(), lr=1e-3)
    for _ in range(5):
        x = torch.rand(2, 5).cuda()
        run_step(model, optimizer, x, enable_autocast=fp16)
        run_step(zero_model, zero_optimizer, x, enable_autocast=fp16)
        check_grads(model, zero_model)
        check_params(model, zero_model)
    for _ in range(5):
        x = torch.rand(2, 5).cuda()
        run_step(model, optimizer, x, enable_autocast=False)
        run_step(zero_model, zero_optimizer, x, enable_autocast=False)
        check_grads(model, zero_model, loose=True)
        check_params(model, zero_model, loose=True)


def run_dist(rank, world_size, port):
    disable_existing_loggers()
    colossalai.launch(config={},
                      rank=rank,
                      world_size=world_size,
                      host='localhost',
                      port=port,
                      backend='nccl')

    args = ['checkpoint', 'fp16', 'offload']

    def pack_args(i):
        booleans = decode_booleans(i, len(args))
        return {arg: booleans[idx] for idx, arg in enumerate(args)}

    for j in range(2 ** len(args)):
        kwargs = pack_args(j)
        print(kwargs)
        check_config(**kwargs)


@pytest.mark.dist
def test_zero_level_3():
    world_size = 1
    run_func = partial(run_dist, world_size=world_size, port=free_port())
    mp.spawn(run_func, nprocs=world_size)


if __name__ == '__main__':
    test_zero_level_3()
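decode_booleans above walks the bits of an integer counter to enumerate all 2 ** len(args) boolean combinations; a compact restatement with the same behaviour (an illustrative sketch, not a replacement):

def decode_booleans_sketch(intval, bits):
    # Bit i of intval becomes element i of the returned list.
    return [(intval >> bit) & 1 == 1 for bit in range(bits)]


assert decode_booleans_sketch(0, 3) == [False, False, False]
assert decode_booleans_sketch(5, 3) == [True, False, True]    # 0b101
assert decode_booleans_sketch(6, 3) == [False, True, True]    # 0b110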
@ -0,0 +1,123 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-

import copy
from functools import partial

import colossalai
import pytest
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
import torch.nn as nn
from colossalai.logging import disable_existing_loggers
from colossalai.utils import checkpoint, free_port
from colossalai.zero.sharded_model import ShardedModel
from torch.nn.parallel import DistributedDataParallel as DDP
from common import Net, allclose


def run_step(model, optimizer, x, enable_autocast=False):
    model.train()
    optimizer.zero_grad()
    with torch.cuda.amp.autocast(enabled=enable_autocast):
        y = model(x)
        loss = y.sum()
    loss = loss.float()
    loss.backward()
    optimizer.step()


def check_grads_padding(model, zero_model, loose=False):
    rank = dist.get_rank()
    for p, zero_p in zip(model.parameters(), zero_model.parameters()):
        zero_grad = zero_p.grad.clone().to(p.device)
        chunks = torch.flatten(p.grad).chunk(4)
        if rank >= len(chunks):
            continue
        grad = chunks[rank]
        if zero_p.zero_shard_padding > 0:
            zero_grad = zero_grad[:-zero_p.zero_shard_padding]
        assert grad.dtype == zero_grad.dtype
        assert allclose(grad, zero_grad, loose=loose)


def check_params_padding(model, zero_model, loose=False):
    rank = dist.get_rank()
    for p, zero_p in zip(model.parameters(), zero_model.parameters()):
        zero_shard_padding = zero_p.zero_shard_padding
        zero_p = zero_p.clone().to(p.device)
        chunks = torch.flatten(p).chunk(4)
        if rank >= len(chunks):
            continue
        p = chunks[rank]
        if zero_shard_padding > 0:
            zero_p = zero_p[:-zero_shard_padding]
        assert p.dtype == zero_p.dtype
        assert allclose(p, zero_p, loose=loose)


def decode_booleans(intval, bits):
    res = []
    for bit in range(bits):
        mask = 1 << bit
        res.append((intval & mask) == mask)
    return res


def check_config(checkpoint=False, fp16=False, offload=False):
    model = Net(checkpoint=checkpoint).cuda()
    zero_model = copy.deepcopy(model)
    ddp_model = DDP(model)

    offload_config = {}
    if offload:
        offload_config['device'] = 'cpu'
        zero_model = zero_model.cpu()
    zero_model = ShardedModel(zero_model, mixed_precision=fp16, offload_config=offload_config)

    optimizer = torch.optim.Adam(ddp_model.parameters(), lr=1e-3)
    zero_optimizer = torch.optim.Adam(zero_model.parameters(), lr=1e-3)
    for _ in range(5):
        x = torch.rand(2, 5).cuda()
        run_step(ddp_model, optimizer, x, enable_autocast=fp16)
        run_step(zero_model, zero_optimizer, x, enable_autocast=fp16)
        check_grads_padding(ddp_model, zero_model)
        check_params_padding(ddp_model, zero_model)
    for _ in range(5):
        x = torch.rand(2, 5).cuda()
        run_step(ddp_model, optimizer, x, enable_autocast=False)
        run_step(zero_model, zero_optimizer, x, enable_autocast=False)
        check_grads_padding(ddp_model, zero_model, loose=True)
        check_params_padding(ddp_model, zero_model, loose=True)


def run_dist(rank, world_size, port):
    disable_existing_loggers()
    colossalai.launch(config={},
                      rank=rank,
                      world_size=world_size,
                      host='localhost',
                      port=port,
                      backend='nccl')

    args = ['checkpoint', 'fp16', 'offload']

    def pack_args(i):
        booleans = decode_booleans(i, len(args))
        return {arg: booleans[idx] for idx, arg in enumerate(args)}

    for j in range(2 ** len(args)):
        kwargs = pack_args(j)
        if dist.get_rank() == 0:
            print(kwargs)
        check_config(**kwargs)


@pytest.mark.dist
def test_zero_level_3():
    world_size = 4
    run_func = partial(run_dist, world_size=world_size, port=free_port())
    mp.spawn(run_func, nprocs=world_size)


if __name__ == '__main__':
    test_zero_level_3()
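A standalone sketch of the padded-shard comparison that check_grads_padding and check_params_padding rely on (the helper name local_chunk is illustrative): flatten the unsharded tensor, take this rank's chunk, and strip any recorded zero padding from the sharded copy before comparing.

import torch


def local_chunk(full: torch.Tensor, rank: int, world_size: int) -> torch.Tensor:
    # torch.chunk may return fewer than world_size pieces for small tensors,
    # which is why the checks above skip ranks beyond len(chunks).
    chunks = torch.flatten(full).chunk(world_size)
    return chunks[rank] if rank < len(chunks) else torch.empty(0)


full = torch.arange(10, dtype=torch.float32)
print(local_chunk(full, rank=1, world_size=4))   # tensor([3., 4., 5.])
print(local_chunk(full, rank=3, world_size=4))   # tensor([9.])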
@ -1,102 +0,0 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-

import os
from functools import partial
from pathlib import Path

import colossalai
import pytest
import torch
import torch.multiprocessing as mp
from colossalai.core import global_context as gpc
from colossalai.utils import free_port, get_dataloader
from torchvision import transforms
from torchvision.datasets import CIFAR10
from torchvision.models import resnet18

BATCH_SIZE = 16
IMG_SIZE = 224

CONFIG = dict(
    fp16=dict(
        mode=None,
    ),
    zero=dict(
        level=2,
        cpu_offload=True,
        verbose=False,
    ),
    parallel=dict(
        pipeline=dict(size=1),
        tensor=dict(size=1, mode=None)
    )
)


def run_dist(rank, world_size, port):
    colossalai.launch(config=CONFIG,
                      rank=rank,
                      world_size=world_size,
                      host='localhost',
                      port=port,
                      backend='nccl')

    # build model
    model = resnet18(num_classes=10)

    # build dataloader
    train_dataset = CIFAR10(
        root=Path(os.environ['DATA']),
        download=True,
        transform=transforms.Compose(
            [
                transforms.Resize(size=(IMG_SIZE, IMG_SIZE)),
                transforms.ToTensor(),
                transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
            ]
        )
    )
    train_dataloader = get_dataloader(dataset=train_dataset,
                                      shuffle=True,
                                      batch_size=BATCH_SIZE,
                                      pin_memory=True,
                                      drop_last=True)

    # build optimizer and loss
    # optimizer = build_optimizer(global_context.config.optimizer, model)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
    criterion = torch.nn.CrossEntropyLoss()

    engine, train_dataloader, *args = colossalai.initialize(model=model,
                                                            optimizer=optimizer,
                                                            criterion=criterion,
                                                            train_dataloader=train_dataloader)

    # train
    model.train()
    for idx, (data, label) in enumerate(train_dataloader):
        engine.zero_grad()
        data = data.cuda()
        label = label.cuda()

        output = engine(data)
        loss = engine.criterion(output, label)

        engine.backward(loss)
        engine.step()
        break

    gpc.destroy()
    torch.cuda.empty_cache()


@pytest.mark.dist
def test_zero_level_2():
    world_size = 4
    run_func = partial(run_dist, world_size=world_size, port=free_port())
    mp.spawn(run_func, nprocs=world_size)


if __name__ == '__main__':
    test_zero_level_2()
@ -1,114 +0,0 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-

import os
from functools import partial
from pathlib import Path

import colossalai
import pytest
import torch
import torch.multiprocessing as mp
from colossalai.core import global_context as gpc
from colossalai.utils import free_port, get_dataloader
from torchvision import transforms
from torchvision.datasets import CIFAR10
from torchvision.models import resnet18

BATCH_SIZE = 16
IMG_SIZE = 224

CONFIG = dict(
    fp16=dict(
        mode=None,
    ),
    zero=dict(
        level=3,
        verbose=False,
        offload_optimizer_config=dict(
            device='cpu',
            pin_memory=True,
            buffer_count=5,
            fast_init=False
        ),
        offload_param_config=dict(
            device='cpu',
            pin_memory=True,
            buffer_count=5,
            buffer_size=1e8,
            max_in_cpu=1e9
        )
    ),
    parallel=dict(
        pipeline=dict(size=1),
        tensor=dict(size=1, mode=None)
    )
)


def run_dist(rank, world_size, port):
    colossalai.launch(config=CONFIG,
                      rank=rank,
                      world_size=world_size,
                      host='localhost',
                      port=port,
                      backend='nccl')

    # build model
    model = resnet18(num_classes=10)

    # build dataloader
    train_dataset = CIFAR10(
        root=Path(os.environ['DATA']),
        download=True,
        transform=transforms.Compose(
            [
                transforms.Resize(size=(IMG_SIZE, IMG_SIZE)),
                transforms.ToTensor(),
                transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
            ]
        )
    )
    train_dataloader = get_dataloader(dataset=train_dataset,
                                      shuffle=True,
                                      batch_size=BATCH_SIZE,
                                      pin_memory=True,
                                      drop_last=True)

    # build optimizer and loss
    # optimizer = build_optimizer(global_context.config.optimizer, model)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
    criterion = torch.nn.CrossEntropyLoss()

    engine, train_dataloader, *args = colossalai.initialize(model=model,
                                                            optimizer=optimizer,
                                                            criterion=criterion,
                                                            train_dataloader=train_dataloader)

    # train
    model.train()
    for idx, (data, label) in enumerate(train_dataloader):
        engine.zero_grad()
        data = data.cuda()
        label = label.cuda()

        output = engine(data)
        loss = engine.criterion(output, label)

        engine.backward(loss)
        engine.step()
        break

    gpc.destroy()
    torch.cuda.empty_cache()


@pytest.mark.dist
def test_zero_level_3():
    world_size = 4
    run_func = partial(run_dist, world_size=world_size, port=free_port())
    mp.spawn(run_func, nprocs=world_size)


if __name__ == '__main__':
    test_zero_level_3()
@ -0,0 +1,41 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-

import os
from functools import partial
from pathlib import Path

import colossalai
import pytest
import torch
import torch.multiprocessing as mp
from colossalai.zero.sharded_model.param_manager import Zero3ParameterManager
from colossalai.core import global_context as gpc
from colossalai.context.parallel_mode import ParallelMode
from colossalai.utils import free_port
from common import CONFIG


def run_shard_shape_check(rank, world_size, port):
    colossalai.launch(config=CONFIG,
                      rank=rank,
                      world_size=world_size,
                      host='localhost',
                      port=port,
                      backend='nccl')

    model = torch.nn.Linear(2, 4 * world_size)
    gpc.init_parallel_groups()
    Zero3ParameterManager(module=model,
                          process_group=gpc.get_group(ParallelMode.DATA),
                          offload_config=CONFIG.get('offload_param_config'))

    assert model.weight.numel() == 4 * 2
    assert model.bias.numel() == 4


@pytest.mark.dist
def test_run_shard_shape():
    world_size = 2
    run_func = partial(run_shard_shape_check, world_size=world_size, port=free_port())
    mp.spawn(run_func, nprocs=world_size)


if __name__ == '__main__':
    test_run_shard_shape()
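The two assertions above follow from the sharding arithmetic of the test's own sizes; a sketch of that arithmetic (illustrative only, not the Zero3ParameterManager implementation):

# nn.Linear(2, 4 * world_size) has a weight of 2 * 4 * world_size elements and
# a bias of 4 * world_size elements; split evenly across world_size
# data-parallel ranks, each rank keeps 4 * 2 weight elements and 4 bias elements.
world_size = 2
weight_numel = 2 * 4 * world_size
bias_numel = 4 * world_size
assert weight_numel // world_size == 4 * 2
assert bias_numel // world_size == 4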
@ -88,6 +88,7 @@ def run_2d_parallel_vision_transformer_level_2(rank, world_size, port):


@pytest.mark.dist
@pytest.mark.skip(reason="This test should be refactored for the reconstructed zero")
def test_2d_vit_zero_level_2():
    world_size = 8
    run_func = partial(run_2d_parallel_vision_transformer_level_2, world_size=world_size, port=free_port())
@ -88,7 +88,7 @@ def run_2d_parallel_vision_transformer_level_3(rank, world_size, port):


@pytest.mark.dist
@pytest.mark.skip("Level 3 has unknown bug so skip this test for now")
@pytest.mark.skip(reason="This test should be refactored for the reconstructed zero")
def test_3d_vit_zero_level_3():
    world_size = 8
    run_func = partial(run_2d_parallel_vision_transformer_level_3, world_size=world_size, port=free_port())