mirror of https://github.com/hpcaitech/ColossalAI
[legacy] clean up legacy code (#4743)
* [legacy] remove outdated codes of pipeline (#4692)
* [legacy] remove cli of benchmark and update optim (#4690)
* [legacy] remove cli of benchmark and update optim
* [doc] fix cli doc test
* [legacy] fix engine clip grad norm
* [legacy] remove outdated colo tensor (#4694)
* [legacy] remove outdated colo tensor
* [test] fix test import
* [legacy] move outdated zero to legacy (#4696)
* [legacy] clean up utils (#4700)
* [legacy] clean up utils
* [example] update examples
* [legacy] clean up amp
* [legacy] fix amp module
* [legacy] clean up gpc (#4742)
* [legacy] clean up context
* [legacy] clean core, constants and global vars
* [legacy] refactor initialize
* [example] fix examples ci
* [example] fix examples ci
* [legacy] fix tests
* [example] fix gpt example
* [example] fix examples ci
* [devops] fix ci installation
* [example] fix examples ci
parent 32e7f99416
commit b5f9e37c70
@@ -1,54 +0,0 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-

import torch.nn as nn
from torch.nn.modules.loss import _Loss
from torch.optim import Optimizer

from colossalai.context import Config

from .amp_type import AMP_TYPE
from .apex_amp import convert_to_apex_amp
from .naive_amp import convert_to_naive_amp
from .torch_amp import convert_to_torch_amp

__all__ = ['convert_to_amp', 'convert_to_naive_amp', 'convert_to_apex_amp', 'convert_to_torch_amp', 'AMP_TYPE']


def convert_to_amp(model: nn.Module, optimizer: Optimizer, criterion: _Loss, mode: AMP_TYPE, amp_config: Config = None):
    """A helper function to wrap training components with Torch AMP modules.

    Args:
        param model (:class:`torch.nn.Module`): your model object.
        optimizer (:class:`torch.optim.Optimizer`): your optimizer object.
        criterion (:class:`torch.nn.modules.loss._Loss`): your loss function object.
        mode (:class:`colossalai.amp.AMP_TYPE`): amp mode.
        amp_config (Union[:class:`colossalai.context.Config`, dict]): configuration for different amp modes.

    Returns:
        A tuple (model, optimizer, criterion).

    Note:
        ``amp_config`` may vary from different mode you choose. You should check the corresponding amp mode
        for more details about ``amp_config``.
        For ``apex_amp``, please check
        `apex_amp config <https://nvidia.github.io/apex/amp.html?highlight=apex%20amp>`_.
        For ``naive_amp``, please check
        `naive_amp config <https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/amp/naive_amp/_fp16_optimizer.py#L42>`_.
        For ``torch_amp``, please check
        `torch_amp config <https://github.com/pytorch/pytorch/blob/master/torch/cuda/amp/grad_scaler.py#L97>`_.
    """
    assert isinstance(mode, AMP_TYPE), \
        f'expected the argument mode be AMP_TYPE, but got {type(mode)}'

    if amp_config is None:
        amp_config = Config()

    if mode == AMP_TYPE.TORCH:
        model, optimizer, criterion = convert_to_torch_amp(model, optimizer, criterion, amp_config)
    elif mode == AMP_TYPE.APEX:
        model, optimizer = convert_to_apex_amp(model, optimizer, amp_config)
    elif mode == AMP_TYPE.NAIVE:
        model, optimizer = convert_to_naive_amp(model, optimizer, amp_config)

    return model, optimizer, criterion
@@ -1,60 +0,0 @@
import inspect

import torch.nn as nn
from torch.optim import Optimizer

from colossalai.utils import is_no_pp_or_last_stage

from ._fp16_optimizer import FP16Optimizer
from .grad_scaler import ConstantGradScaler, DynamicGradScaler
from .naive_amp import NaiveAMPModel, NaiveAMPOptimizer


def convert_to_naive_amp(model: nn.Module, optimizer: Optimizer, amp_config):
    """A helper function to wrap training components with naive AMP modules. In this mode,
    we forcibly cast the model weights and inputs to FP16, and cast the model outputs to FP32 to calculate loss,
    which is equivalent to Apex O3.

    Args:
        model (:class:`torch.nn.Module`): your model object
        optimizer (:class:`torch.optim.Optimizer`): your optimizer object
        amp_config (:class:`colossalai.context.Config` or dict): configuration for naive mode amp.

    Returns:
        Tuple: A tuple (model, optimizer)

    The ``amp_config`` should contain parameters below::

        verbose (bool, optional): if set to `True`, will print debug info (Default: False).
        clip_grad_norm (float, optional): clip gradients with this global L2 norm (Default 0).
                                          Note that clipping is ignored if clip_grad == 0.
        dynamic_grad_scale (bool): whether to use dynamic grad scaler.
    """
    if isinstance(model, nn.ModuleList):
        # interleaved pipeline
        module_list = []
        for chunk, m in enumerate(model):
            output_to_fp32 = is_no_pp_or_last_stage() and chunk == len(model) - 1
            module_list.append(NaiveAMPModel(m, output_to_fp32=output_to_fp32))
        model = nn.ModuleList(module_list)
    else:
        output_to_fp32 = is_no_pp_or_last_stage()
        model = NaiveAMPModel(model, output_to_fp32=output_to_fp32)

    use_dynamic_grad_scaler = amp_config.pop('dynamic_grad_scale', True)
    if use_dynamic_grad_scaler:
        scaler_class = DynamicGradScaler
    else:
        scaler_class = ConstantGradScaler

    sig = inspect.signature(scaler_class.__init__)
    kwargs = dict()
    for param in sig.parameters.values():
        if param.name in amp_config:
            kwargs[param.name] = amp_config.pop(param.name)
    grad_scaler = scaler_class(**kwargs)
    optimizer = NaiveAMPOptimizer(optimizer, grad_scaler, **amp_config)
    return model, optimizer


__all__ = ['convert_to_naive_amp', 'NaiveAMPOptimizer', 'FP16Optimizer']
@@ -1,28 +0,0 @@
import click

from colossalai.context import Config

from .benchmark import run_benchmark
from .utils import *

__all__ = ['benchmark']


@click.command()
@click.option("-g", "--gpus", type=int, default=None, help="Total number of devices to use.")
@click.option("-b", "--batch_size", type=int, default=8, help="Batch size of the input tensor.")
@click.option("-s", "--seq_len", type=int, default=512, help="Sequence length of the input tensor.")
@click.option("-d", "--dimension", type=int, default=1024, help="Hidden dimension of the input tensor.")
@click.option("-w", "--warmup_steps", type=int, default=10, help="The number of warmup steps.")
@click.option("-p", "--profile_steps", type=int, default=50, help="The number of profiling steps.")
@click.option("-l", "--layers", type=int, default=2)
@click.option("-m",
              "--model",
              type=click.Choice(['mlp'], case_sensitive=False),
              default='mlp',
              help="Select the model to benchmark, currently only supports MLP")
def benchmark(gpus: int, batch_size: int, seq_len: int, dimension: int, warmup_steps: int, profile_steps: int,
              layers: int, model: str):
    args_dict = locals()
    args = Config(args_dict)
    run_benchmark(args)
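For reference, the deleted `benchmark` command only packs its click options into a `Config` and hands it to `run_benchmark`. A minimal sketch of driving the same path programmatically, assuming the pre-cleanup module layout that this commit removes:

```python
# Hedged sketch: how the removed CLI wired its options into run_benchmark.
# The import path below is the pre-#4743 layout and no longer exists after this commit.
from colossalai.context import Config
from colossalai.cli.benchmark.benchmark import run_benchmark

args = Config(dict(gpus=4, batch_size=8, seq_len=512, dimension=1024,
                   warmup_steps=10, profile_steps=50, layers=2, model='mlp'))
run_benchmark(args)    # spawns one process per GPU and profiles each parallel config
```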
@@ -1,105 +0,0 @@
from functools import partial
from typing import Dict, List

import click
import torch.multiprocessing as mp

import colossalai
from colossalai.cli.benchmark.utils import find_all_configs, get_batch_data, profile_model
from colossalai.context import Config
from colossalai.context.random import reset_seeds
from colossalai.core import global_context as gpc
from colossalai.logging import disable_existing_loggers, get_dist_logger
from colossalai.testing import free_port
from colossalai.utils import MultiTimer

from .models import MLP


def run_benchmark(args: Config) -> None:
    """
    Run benchmarking with torch.multiprocessing.
    """

    # sanity checks
    if args.gpus is None:
        click.echo("Error: --num_gpus is not given")
        exit()
    if args.gpus <= 1:
        click.echo("Warning: tensor parallel will be activated with at least 2 devices.")

    click.echo("=== Benchmarking Parameters ===")
    for k, v in args.items():
        click.echo(f'{k}: {v}')
    click.echo('')

    config_list = find_all_configs(args.gpus)

    avail_ports = [free_port() for _ in range(len(config_list))]
    run_func = partial(run_dist_profiling,
                       world_size=args.gpus,
                       port_list=avail_ports,
                       config_list=config_list,
                       hyperparams=args)
    mp.spawn(run_func, nprocs=args.gpus)


def run_dist_profiling(rank: int, world_size: int, port_list: List[int], config_list: List[Dict],
                       hyperparams: Config) -> None:
    """
    A function executed for profiling, this function should be spawn by torch.multiprocessing.

    Args:
        rank (int): rank of the process
        world_size (int): the number of processes
        port_list (List[int]): a list of free ports for initializing distributed networks
        config_list (List[Dict]): a list of configuration
        hyperparams (Config): the hyperparameters given by the user

    """

    # disable logging for clean output
    disable_existing_loggers()
    logger = get_dist_logger()
    logger.set_level('WARNING')

    for config, port in zip(config_list, port_list):
        colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
        timer = MultiTimer()

        # 1D parallel should be skipped if in_features or out_features is not able to be divided exactly by 1D parallel size.
        if config.parallel.tensor.mode == '1d' and hyperparams.dimension % config.parallel.tensor.size != 0:
            click.echo(
                "1D parallel will be skipped because in_features or out_features is not able to be divided exactly by 1D parallel size."
            )
            continue

        if hyperparams.model == 'mlp':
            model = MLP(dim=hyperparams.dimension, layers=hyperparams.layers)
        else:
            if gpc.get_global_rank() == 0:
                click.echo("Error: Invalid argument for --model")
            exit()

        data_func = partial(get_batch_data,
                            dim=hyperparams.dimension,
                            batch_size=hyperparams.batch_size,
                            seq_length=hyperparams.seq_len,
                            mode=config.parallel.tensor.mode)

        fwd_time, bwd_time, max_allocated, max_cached = profile_model(model=model,
                                                                      warmup_steps=hyperparams.warmup_steps,
                                                                      profile_steps=hyperparams.profile_steps,
                                                                      data_func=data_func,
                                                                      timer=timer)

        gpc.destroy()
        reset_seeds()

        if gpc.get_global_rank() == 0:
            config_str = ', '.join([f'{k}: {v}' for k, v in config.parallel.tensor.items()])
            click.echo(f"=== {config_str} ===")
            click.echo(f"Average forward time: {fwd_time}")
            click.echo(f"Average backward time: {bwd_time}")
            click.echo(f"Max allocated GPU memory: {max_allocated}")
            click.echo(f"Max cached GPU memory: {max_cached}\n")
@@ -1,18 +0,0 @@
import torch

import colossalai.legacy.nn as col_nn


class MLP(torch.nn.Module):

    def __init__(self, dim: int, layers: int):
        super().__init__()
        self.layers = torch.nn.ModuleList()

        for _ in range(layers):
            self.layers.append(col_nn.Linear(dim, dim))

    def forward(self, x):
        for layer in self.layers:
            x = layer(x)
        return x
@@ -1,159 +0,0 @@
import math
import time
from typing import Callable, Dict, List, Tuple

import torch

from colossalai.context import Config, ParallelMode
from colossalai.utils import MultiTimer


def get_time_stamp() -> int:
    """
    Return the time stamp for profiling.

    Returns:
        time_stamp (int): the time given by time.time()
    """

    torch.cuda.synchronize()
    time_stamp = time.time()
    return time_stamp


def get_memory_states() -> Tuple[float]:
    """
    Return the memory statistics.

    Returns:
        max_allocated (float): the allocated CUDA memory
        max_cached (float): the cached CUDA memory
    """

    max_allocated = torch.cuda.max_memory_allocated() / (1024**3)
    max_cached = torch.cuda.max_memory_reserved() / (1024**3)
    torch.cuda.reset_peak_memory_stats()
    torch.cuda.empty_cache()
    return max_allocated, max_cached


def find_all_configs(device_cnt: int) -> List[Dict]:
    """
    Find all possible configurations for tensor parallelism

    Args:
        device_cnt (int): the number of devices

    Returns:
        config_list (List[Dict]): a list of configurations
    """

    def _is_square(num):
        # 2D parallel should be implemented with at least 2 devices.
        if num <= 1:
            return False
        return math.floor(math.sqrt(num))**2 == num

    def _is_cube(num):
        # 3D parallel should be implemented with at least 2 devices.
        if num <= 1:
            return False
        return math.floor(num**(1. / 3.))**3 == num

    config_list = []

    # add non-parallel config
    config = dict(parallel=dict(tensor=dict(size=device_cnt, mode=None)))
    config_list.append(config)

    # add 1D config
    config = dict(parallel=dict(tensor=dict(size=device_cnt, mode='1d')))
    config_list.append(config)

    # add 2D config only if device_cnt is a square
    if _is_square(device_cnt):
        config = dict(parallel=dict(tensor=dict(size=device_cnt, mode='2d')))
        config_list.append(config)

    # check for 2.5D
    # iterate over depth
    for depth in range(1, device_cnt):
        if device_cnt % depth == 0 and _is_square(device_cnt // depth):
            config = dict(parallel=dict(tensor=dict(size=device_cnt, mode='2.5d', depth=depth)))
            config_list.append(config)

    # check for 3D if device_cnt is a cube
    if _is_cube(device_cnt):
        config = dict(parallel=dict(tensor=dict(size=device_cnt, mode='3d')))
        config_list.append(config)

    config_list = [Config(cfg) for cfg in config_list]
    return config_list


def profile_model(model: torch.nn.Module, warmup_steps: int, profile_steps: int, data_func: Callable,
                  timer: MultiTimer) -> Tuple[float]:
    """
    Profile the forward and backward of a model

    Args:
        model (torch.nn.Module): a PyTorch model
        warmup_steps (int): the number of steps for warmup
        profile_steps (int): the number of steps for profiling
        data_func (Callable): a function to generate random data
        timer (colossalai.utils.Multitimer): a timer instance for time recording

    Returns:
        fwd_time (float): the average forward time taken by forward pass in second
        bwd_time (float): the average backward time taken by forward pass in second
        max_allocated (float): the maximum GPU memory allocated in GB
        max_cached (float): the maximum GPU memory cached in GB
    """

    def _run_step(data):
        timer.start('forward')
        out = model(data)
        timer.stop('forward', keep_in_history=True)
        timer.start('backward')
        out.mean().backward()
        timer.stop('backward', keep_in_history=True)

    data_list = [data_func() for _ in range(warmup_steps)]
    for data in data_list:
        _run_step(data)
    timer.reset('forward')
    timer.reset('backward')

    for _ in range(profile_steps):
        data = data_func()
        _run_step(data)

    max_allocated, max_cached = get_memory_states()
    fwd_time = timer.get_timer('forward').get_history_mean()
    bwd_time = timer.get_timer('backward').get_history_mean()
    return fwd_time, bwd_time, max_allocated, max_cached


def get_batch_data(dim: int, batch_size: int, seq_length: int, mode: ParallelMode) -> torch.Tensor:
    """
    Return a random data of shape (batch_size, seq_length, dim) for profiling.

    Args:
        dim (int): hidden size
        batch_size (int): the number of data samples
        seq_length (int): the number of tokens
        mode (ParallelMode): Colossal-AI ParallelMode enum

    Returns:
        data (torch.Tensor): random data
    """

    if mode in ['2d', '2.5d']:
        batch_size = batch_size // 2
        dim = dim // 2
    elif mode == '3d':
        batch_size = batch_size // 4
        dim = dim // 2

    data = torch.rand(batch_size, seq_length, dim).cuda()
    return data
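The helpers above compose in a straightforward way: `find_all_configs` enumerates candidate tensor-parallel layouts for a device count, `get_batch_data` builds a random batch shaped for the chosen mode, and `profile_model` times forward and backward passes with a `MultiTimer`. A rough single-GPU sketch, assuming the deleted module is still importable (for example from a checkout prior to this commit) and CUDA is available:

```python
import torch

from colossalai.cli.benchmark.utils import find_all_configs, profile_model    # pre-#4743 path
from colossalai.utils import MultiTimer

configs = find_all_configs(device_cnt=4)
print([cfg.parallel.tensor.mode for cfg in configs])    # e.g. [None, '1d', '2d', '2.5d']

model = torch.nn.Linear(1024, 1024).cuda()
data_func = lambda: torch.rand(8, 512, 1024).cuda()
fwd, bwd, alloc, cached = profile_model(model, warmup_steps=2, profile_steps=5,
                                        data_func=data_func, timer=MultiTimer())
print(f"fwd {fwd:.4f}s, bwd {bwd:.4f}s, peak allocated {alloc:.2f} GB")
```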
@@ -1,6 +1,8 @@
 from .config import Config, ConfigException
-from .parallel_context import ParallelContext
-from .parallel_mode import ParallelMode
-from .moe_context import MOE_CONTEXT
-from .process_group_initializer import *
-from .random import *
+
+# from .moe_context import MOE_CONTEXT
+
+__all__ = [
+    'Config',
+    'ConfigException',
+]
@@ -1,6 +0,0 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-

from colossalai.context.parallel_context import global_context

__all__ = ['global_context']
@@ -0,0 +1,9 @@
from .initialize import initialize, launch, launch_from_openmpi, launch_from_slurm, launch_from_torch

__all__ = [
    'launch',
    'launch_from_openmpi',
    'launch_from_slurm',
    'launch_from_torch',
    'initialize',
]
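This new `__init__.py` re-exports the old launcher entry points under the legacy namespace, so code that still relies on the engine-based flow imports them from there. A small hedged sketch, assuming a `torchrun`-style environment (RANK, WORLD_SIZE, MASTER_ADDR and MASTER_PORT set):

```python
# Hedged sketch: the pre-existing entry points, now re-exported under colossalai.legacy.
import colossalai.legacy as legacy

legacy.launch_from_torch(config={})    # reads rank/world size from the torchrun environment
```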
@@ -0,0 +1,54 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-

import torch.nn as nn
from torch.nn.modules.loss import _Loss
from torch.optim import Optimizer

from colossalai.context import Config

from .amp_type import AMP_TYPE
from .apex_amp import convert_to_apex_amp
from .naive_amp import convert_to_naive_amp
from .torch_amp import convert_to_torch_amp

__all__ = ['convert_to_amp', 'convert_to_naive_amp', 'convert_to_apex_amp', 'convert_to_torch_amp', 'AMP_TYPE']


def convert_to_amp(model: nn.Module, optimizer: Optimizer, criterion: _Loss, mode: AMP_TYPE, amp_config: Config = None):
    """A helper function to wrap training components with Torch AMP modules.

    Args:
        param model (:class:`torch.nn.Module`): your model object.
        optimizer (:class:`torch.optim.Optimizer`): your optimizer object.
        criterion (:class:`torch.nn.modules.loss._Loss`): your loss function object.
        mode (:class:`colossalai.legacy.amp.AMP_TYPE`): amp mode.
        amp_config (Union[:class:`colossalai.context.Config`, dict]): configuration for different amp modes.

    Returns:
        A tuple (model, optimizer, criterion).

    Note:
        ``amp_config`` may vary from different mode you choose. You should check the corresponding amp mode
        for more details about ``amp_config``.
        For ``apex_amp``, please check
        `apex_amp config <https://nvidia.github.io/apex/amp.html?highlight=apex%20amp>`_.
        For ``naive_amp``, please check
        `naive_amp config <https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/amp/naive_amp/_fp16_optimizer.py#L42>`_.
        For ``torch_amp``, please check
        `torch_amp config <https://github.com/pytorch/pytorch/blob/master/torch/cuda/amp/grad_scaler.py#L97>`_.
    """
    assert isinstance(mode, AMP_TYPE), \
        f'expected the argument mode be AMP_TYPE, but got {type(mode)}'

    if amp_config is None:
        amp_config = Config()

    if mode == AMP_TYPE.TORCH:
        model, optimizer, criterion = convert_to_torch_amp(model, optimizer, criterion, amp_config)
    elif mode == AMP_TYPE.APEX:
        model, optimizer = convert_to_apex_amp(model, optimizer, amp_config)
    elif mode == AMP_TYPE.NAIVE:
        model, optimizer = convert_to_naive_amp(model, optimizer, amp_config)

    return model, optimizer, criterion
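`convert_to_amp` is a thin dispatcher over the three AMP back ends, and the keys accepted in `amp_config` depend on the mode, as the docstring notes. A minimal hedged sketch for `AMP_TYPE.TORCH` (the model, optimizer and criterion here are placeholders; for the Torch mode the config is forwarded to `torch.cuda.amp.GradScaler`):

```python
import torch
import torch.nn as nn

from colossalai.context import Config
from colossalai.legacy.amp import AMP_TYPE, convert_to_amp

model = nn.Linear(16, 16).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-2)
criterion = nn.MSELoss()

# init_scale is a torch.cuda.amp.GradScaler argument; other GradScaler kwargs work the same way.
model, optimizer, criterion = convert_to_amp(model, optimizer, criterion,
                                             mode=AMP_TYPE.TORCH,
                                             amp_config=Config(dict(init_scale=2.**16)))
```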
@@ -0,0 +1,60 @@
import inspect

import torch.nn as nn
from torch.optim import Optimizer

from colossalai.amp.naive_amp.grad_scaler import ConstantGradScaler, DynamicGradScaler
from colossalai.legacy.utils import is_no_pp_or_last_stage

from ._fp16_optimizer import FP16Optimizer
from .naive_amp import NaiveAMPModel, NaiveAMPOptimizer


def convert_to_naive_amp(model: nn.Module, optimizer: Optimizer, amp_config):
    """A helper function to wrap training components with naive AMP modules. In this mode,
    we forcibly cast the model weights and inputs to FP16, and cast the model outputs to FP32 to calculate loss,
    which is equivalent to Apex O3.

    Args:
        model (:class:`torch.nn.Module`): your model object
        optimizer (:class:`torch.optim.Optimizer`): your optimizer object
        amp_config (:class:`colossalai.context.Config` or dict): configuration for naive mode amp.

    Returns:
        Tuple: A tuple (model, optimizer)

    The ``amp_config`` should contain parameters below::

        verbose (bool, optional): if set to `True`, will print debug info (Default: False).
        clip_grad_norm (float, optional): clip gradients with this global L2 norm (Default 0).
                                          Note that clipping is ignored if clip_grad == 0.
        dynamic_grad_scale (bool): whether to use dynamic grad scaler.
    """
    if isinstance(model, nn.ModuleList):
        # interleaved pipeline
        module_list = []
        for chunk, m in enumerate(model):
            output_to_fp32 = is_no_pp_or_last_stage() and chunk == len(model) - 1
            module_list.append(NaiveAMPModel(m, output_to_fp32=output_to_fp32))
        model = nn.ModuleList(module_list)
    else:
        output_to_fp32 = is_no_pp_or_last_stage()
        model = NaiveAMPModel(model, output_to_fp32=output_to_fp32)

    use_dynamic_grad_scaler = amp_config.pop('dynamic_grad_scale', True)
    if use_dynamic_grad_scaler:
        scaler_class = DynamicGradScaler
    else:
        scaler_class = ConstantGradScaler

    sig = inspect.signature(scaler_class.__init__)
    kwargs = dict()
    for param in sig.parameters.values():
        if param.name in amp_config:
            kwargs[param.name] = amp_config.pop(param.name)
    grad_scaler = scaler_class(**kwargs)
    optimizer = NaiveAMPOptimizer(optimizer, grad_scaler, **amp_config)
    return model, optimizer


__all__ = ['convert_to_naive_amp', 'NaiveAMPOptimizer', 'FP16Optimizer']
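The naive mode reads its knobs directly from `amp_config`: `dynamic_grad_scale` chooses between `DynamicGradScaler` and `ConstantGradScaler`, any keys matching the scaler constructor are popped off by signature inspection, and the remainder is forwarded to `NaiveAMPOptimizer`. A hedged sketch of the corresponding `fp16` section of a legacy config, using only the keys documented in the docstring above:

```python
from colossalai.legacy.amp import AMP_TYPE

# Consumed by convert_to_naive_amp (typically via colossalai.legacy.initialize).
fp16 = dict(
    mode=AMP_TYPE.NAIVE,
    dynamic_grad_scale=True,    # use DynamicGradScaler instead of ConstantGradScaler
    clip_grad_norm=1.0,         # forwarded to NaiveAMPOptimizer; 0 disables clipping
    verbose=False,
)
```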
@@ -0,0 +1,4 @@
from .parallel_context import ParallelContext
from .parallel_mode import ParallelMode
from .process_group_initializer import *
from .random import *
@@ -0,0 +1,6 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-

from colossalai.legacy.context.parallel_context import global_context

__all__ = ['global_context']
@@ -0,0 +1,472 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-

import argparse
import os
import pprint
from pathlib import Path
from typing import Callable, Dict, Iterable, List, Optional, Tuple, Union

import torch
import torch.nn as nn
from torch.nn.modules.loss import _Loss
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.optim.lr_scheduler import _LRScheduler
from torch.optim.optimizer import Optimizer
from torch.utils.data import DataLoader

from colossalai.context import Config, ConfigException
from colossalai.context.moe_context import MOE_CONTEXT
from colossalai.interface import OptimizerWrapper
from colossalai.legacy.amp import AMP_TYPE, convert_to_amp
from colossalai.legacy.amp.naive_amp import NaiveAMPModel
from colossalai.legacy.builder.builder import build_gradient_handler
from colossalai.legacy.context import ParallelMode
from colossalai.legacy.core import global_context as gpc
from colossalai.legacy.engine import Engine
from colossalai.legacy.engine.gradient_accumulation import accumulate_gradient
from colossalai.legacy.engine.schedule import (
    InterleavedPipelineSchedule,
    NonPipelineSchedule,
    PipelineSchedule,
    get_tensor_shape,
)
from colossalai.legacy.utils import is_using_ddp, is_using_pp, is_using_sequence, sync_model_param
from colossalai.legacy.zero import ShardedOptimizerV2, convert_to_zero_v2
from colossalai.legacy.zero.gemini.ophooks import BaseOpHook
from colossalai.logging import get_dist_logger
from colossalai.utils import get_current_device
from colossalai.utils.moe import sync_moe_model_param


def get_default_parser():
    """Reads user command line and uses an argument parser to parse the input arguments.
    Input arguments include configuration, host, port, world size, local rank, backend for torch.distributed.

    Returns:
        Namespace: Returns the parser with the default arguments, the user may add customized arguments into this parser.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--config', type=str, help='path to the config file')
    parser.add_argument('--host', type=str, help='the master address for distributed training')
    parser.add_argument('--port', type=int, help='the master port for distributed training')
    parser.add_argument('--world_size', type=int, help='world size for distributed training')
    parser.add_argument('--rank', type=int, help='rank for the default process group')
    parser.add_argument('--local_rank', type=int, help='local rank on the node')
    parser.add_argument('--backend', type=str, default='nccl', help='backend for distributed communication')
    return parser


def launch(config: Union[str, Path, Config, Dict],
           rank: int,
           world_size: int,
           host: str,
           port: int,
           backend: str = 'nccl',
           local_rank: int = None,
           seed: int = 1024,
           verbose: bool = True):
    """This function first parses the configuration arguments, using :func:`parse_args()` in case one of the input
    arguments are not given. Then initialize and set distributed environment by calling global_context's functions.

    Args:
        config (Union[str, dict, Config]): Config file or config file path are both acceptable
        rank (int): Rank for the default process group
        world_size (int): World size of the default process group
        host (str): The master address for distributed training
        port (str): The master port for distributed training
        backend (str, optional): Backend for ``torch.distributed``, defaults to ``nccl``
        local_rank (int, optional):
            Rank for the process on the node and is used to set the default CUDA device,
            defaults to None. If local_rank = None, the default device ordinal will be calculated automatically.
        seed (int, optional): Specified random seed for every process. Defaults to 1024.
        verbose (bool, optional): Whether to print logs. Defaults to True.

    Raises:
        Exception: Raise exception when config type is wrong
    """
    gpc.verbose = verbose

    # set config
    assert isinstance(config, (Config, str, Path, dict)), \
        f'expected argument config to be Config, str or Path, but got {type(config)}'
    if not isinstance(config, Config) and isinstance(config, dict):
        config = Config(config)
    if isinstance(config, (str, Path)):
        config = Config.from_file(config)
    gpc.load_config(config)

    # init default process group
    gpc.init_global_dist(rank, world_size, backend, host, port)

    # init process groups for different parallel modes from config
    gpc.init_parallel_groups()

    # set cuda device
    if torch.cuda.is_available():
        # if local rank is not given, calculate automatically
        gpc.set_device(local_rank)

    # set the number of processes running on the same node
    gpc.detect_num_processes_on_current_node()

    gpc.set_seed(seed)

    if verbose:
        logger = get_dist_logger()
        logger.info(
            f'Distributed environment is initialized, '
            f'data parallel size: {gpc.data_parallel_size}, pipeline parallel size: {gpc.pipeline_parallel_size}, '
            f'tensor parallel size: {gpc.tensor_parallel_size}',
            ranks=[0])


def launch_from_slurm(config: Union[str, Path, Config, Dict],
                      host: str,
                      port: int,
                      backend: str = 'nccl',
                      seed: int = 1024,
                      verbose: bool = True):
    """A wrapper for colossalai.launch for SLURM launcher by reading rank and world size from the environment variables
    set by SLURM

    Args:
        config (Union[str, dict, Config]): Config file or config file path are both acceptable
        host (str): The master address for distributed training
        port (str): The master port for distributed training
        backend (str, optional): Backend for ``torch.distributed``, defaults to ``nccl``
        seed (int, optional): Specified random seed for every process. Defaults to 1024.
        verbose (bool, optional): Whether to print logs. Defaults to True.
    """
    try:
        rank = int(os.environ['SLURM_PROCID'])
        world_size = int(os.environ['SLURM_NPROCS'])
    except KeyError as e:
        raise RuntimeError(
            f"Could not find {e} in the SLURM environment, visit https://www.colossalai.org/ for more information on launching with SLURM"
        )

    launch(config=config,
           rank=rank,
           world_size=world_size,
           host=host,
           port=port,
           backend=backend,
           seed=seed,
           verbose=verbose)


def launch_from_openmpi(config: Union[str, Path, Config, Dict],
                        host: str,
                        port: int,
                        backend: str = 'nccl',
                        seed: int = 1024,
                        verbose: bool = True):
    """A wrapper for colossalai.launch for OpenMPI launcher by reading rank and world size from the environment variables
    set by OpenMPI

    Args:
        config (Union[str, dict, Config]): Config file or config file path are both acceptable
        host (str): The master address for distributed training
        port (str): The master port for distributed training
        backend (str, optional): Backend for ``torch.distributed``, defaults to ``nccl``
        seed (int, optional): Specified random seed for every process. Defaults to 1024.
        verbose (bool, optional): Whether to print logs. Defaults to True.
    """
    try:
        rank = int(os.environ['OMPI_COMM_WORLD_RANK'])
        local_rank = int(os.environ['OMPI_COMM_WORLD_LOCAL_RANK'])
        world_size = int(os.environ['OMPI_COMM_WORLD_SIZE'])
    except KeyError as e:
        raise RuntimeError(
            f"Could not find {e} in the OpenMPI environment, visit https://www.colossalai.org/ for more information on launching with OpenMPI"
        )

    launch(config=config,
           local_rank=local_rank,
           rank=rank,
           world_size=world_size,
           host=host,
           port=port,
           backend=backend,
           seed=seed,
           verbose=verbose)


def launch_from_torch(config: Union[str, Path, Config, Dict],
                      backend: str = 'nccl',
                      seed: int = 1024,
                      verbose: bool = True):
    """A wrapper for colossalai.launch for torchrun or torch.distributed.launch by reading rank and world size
    from the environment variables set by PyTorch

    Args:
        config (Union[str, dict, Config]): Config file or config file path are both acceptable
        backend (str, optional): Backend for ``torch.distributed``, defaults to ``nccl``
        seed (int, optional): Specified random seed for every process. Defaults to 1024.
        verbose (bool, optional): Whether to print logs. Defaults to True.
    """
    try:
        rank = int(os.environ['RANK'])
        local_rank = int(os.environ['LOCAL_RANK'])
        world_size = int(os.environ['WORLD_SIZE'])
        host = os.environ['MASTER_ADDR']
        port = int(os.environ['MASTER_PORT'])
    except KeyError as e:
        raise RuntimeError(
            f"Could not find {e} in the torch environment, visit https://www.colossalai.org/ for more information on launching with torch"
        )

    launch(config=config,
           local_rank=local_rank,
           rank=rank,
           world_size=world_size,
           host=host,
           port=port,
           backend=backend,
           seed=seed,
           verbose=verbose)


def initialize(model: nn.Module,
               optimizer: Optimizer,
               criterion: Optional[_Loss] = None,
               train_dataloader: Optional[Iterable] = None,
               test_dataloader: Optional[Iterable] = None,
               lr_scheduler: Optional[_LRScheduler] = None,
               ophooks: Optional[List[BaseOpHook]] = None,
               verbose: bool = True) -> Tuple[Engine, DataLoader, DataLoader, _LRScheduler]:
    """Core function to wrap the essential training components with our functionality based on the config which is
    loaded into gpc.config.

    Args:
        model (:class:`torch.nn.Module` or Callable): Your model instance or a function to build the model.
        optimizer (:class:`torch.optim.optimizer.Optimizer` or :class:`Type[torch.optim.optimizer]`):
            Your optimizer instance.
        criterion (:class:`torch.nn.modules.loss._Loss`, optional): Your criterion instance.
        train_dataloader (:class:`torch.utils.data.DataLoader`, optional): Dataloader for training.
        test_dataloader (:class:`torch.utils.data.DataLoader`, optional): Dataloader for testing.
        lr_scheduler (:class:`torch.nn.lr_scheduler._LRScheduler`, optional): Your lr scheduler instance, optional.
        verbose (bool, optional): Whether to print logs.

    Returns:
        Tuple (engine, train_dataloader, test_dataloader, lr_scheduler):
            A tuple of ``(engine, train_dataloader, test_dataloader, lr_scheduler)``
            where only ``engine`` could not be None.
    """
    # get logger
    logger = get_dist_logger()
    gpc.verbose = verbose

    # get config from gpc
    config = gpc.config

    # print config
    if verbose:
        logger.info(
            f"\n========== Your Config ========\n"
            f"{pprint.pformat(gpc.config)}\n"
            f"================================\n",
            ranks=[0])

    # cudnn
    cudnn_benchmark = config.get('cudnn_benchmark', False)
    cudnn_deterministic = config.get('cudnn_deterministic', False)
    torch.backends.cudnn.benchmark = cudnn_benchmark
    torch.backends.cudnn.deterministic = cudnn_deterministic
    if verbose:
        logger.info(f"cuDNN benchmark = {cudnn_benchmark}, deterministic = {cudnn_deterministic}", ranks=[0])

    # zero
    use_zero = hasattr(gpc.config, 'zero')
    if use_zero:
        zero_cfg = gpc.config.get('zero', None)
        if zero_cfg is not None:
            cfg_ = zero_cfg.copy()
        else:
            cfg_ = {}
        optimizer_config = zero_cfg.get('optimizer_config', None)
        model_config = zero_cfg.get('model_config', None)
        model, optimizer = convert_to_zero_v2(model,
                                              optimizer,
                                              model_config=model_config,
                                              optimizer_config=optimizer_config)

        logger.info("Initializing ZeRO model and optimizer finished!", ranks=[0])
    else:
        if isinstance(model, nn.Module):
            # first sync model across dp ranks
            model.to(get_current_device())
        elif isinstance(model, Callable):
            model = model().to(get_current_device())

        # optimizer maybe a optimizer_cls
        if isinstance(optimizer, Callable):
            optimizer = optimizer(model.parameters())
            logger.warning("Initializing an non ZeRO model with optimizer class")

    if not use_zero:
        if is_using_sequence():
            sync_model_param(model, ParallelMode.SEQUENCE_DP)
        elif MOE_CONTEXT.is_initialized:
            sync_moe_model_param(model)
        elif is_using_ddp():
            sync_model_param(model, ParallelMode.DATA)
        else:
            logger.warning(
                "The parameters of models is not automatically synchronized.\n"
                "Please make sure that all parameters are the same in data parallel group.",
                ranks=[0])

    # check amp and zero
    fp16_cfg = gpc.config.get('fp16', None)

    if fp16_cfg is not None and fp16_cfg.mode is not None and use_zero:
        raise ConfigException(
            "It is not allowed to set fp16 and zero configuration in your config file at the same time")

    # clip grad norm
    clip_grad_norm = gpc.config.get('clip_grad_norm', 0.0)

    # initialize amp
    amp_mode = None
    if fp16_cfg is not None and fp16_cfg.mode is not None:
        cfg_ = fp16_cfg.copy()
        amp_mode = cfg_.pop('mode')
        if is_using_pp():
            assert amp_mode == AMP_TYPE.NAIVE, 'Pipeline only support NaiveAMP currently'
        if amp_mode == AMP_TYPE.NAIVE:
            cfg_['clip_grad_norm'] = clip_grad_norm
        model, optimizer, criterion = convert_to_amp(model=model,
                                                     optimizer=optimizer,
                                                     criterion=criterion,
                                                     mode=amp_mode,
                                                     amp_config=cfg_)

    # get torch ddp config
    torch_ddp_cfg = gpc.config.get('torch_ddp', dict())

    # gradient handler
    gradient_handler_cfg = gpc.config.get('gradient_handler', None)
    if gradient_handler_cfg is None:
        # if gradient handler is not specified in the configuration file,
        # check in the following order
        # 1. if optimizer is ZERO, then use zero grad handler
        # 2. if dp size is larger than 1 and pipeline is not used, use pytorch ddp
        # 3. if using pipeline and dp size larger than 1, use data parallel grad handler
        if isinstance(optimizer, ShardedOptimizerV2):
            gradient_handler_cfg = [dict(type='ZeROGradientHandler')]
            if verbose:
                logger.info(
                    "Training with zero is detected, ZeROGradientHandler is automatically "
                    "added even though not specified in the configuration",
                    ranks=[0])
        elif is_using_ddp() and MOE_CONTEXT.is_initialized:
            gradient_handler_cfg = [dict(type='MoeGradientHandler')]
            if verbose:
                logger.info(
                    "Data parallel training is detected with moe parallel, MoeGradientHandler is automatically "
                    "added even though not specified in the configuration",
                    ranks=[0])
        elif is_using_sequence():
            model = DDP(model,
                        process_group=gpc.get_group(ParallelMode.SEQUENCE_DP),
                        device_ids=[torch.cuda.current_device()],
                        **torch_ddp_cfg)
            if verbose:
                logger.info('Model is using torch.nn.parallel.DistributedDataParallel for Sequence Parallelism',
                            ranks=[0])
        elif is_using_ddp() and not is_using_pp() and amp_mode != AMP_TYPE.NAIVE:
            model = DDP(model,
                        process_group=gpc.get_group(ParallelMode.DATA),
                        device_ids=[torch.cuda.current_device()],
                        **torch_ddp_cfg)
            if verbose:
                logger.info('Model is using torch.nn.parallel.DistributedDataParallel for Data Parallelism', ranks=[0])
        elif is_using_ddp():
            gradient_handler_cfg = [dict(type='DataParallelGradientHandler')]
            if verbose:
                logger.info(
                    "Data parallel training is detected when using pipeline parallel, "
                    "DataParallelGradientHandler is automatically "
                    "added even though not specified in the configuration",
                    ranks=[0])
        # add pipeline parallel gradient handler, if pipeline shared module is detected
        for param in model.parameters():
            if getattr(param, 'pipeline_shared_module_pg', None) is not None:
                if gradient_handler_cfg is None:
                    gradient_handler_cfg = [dict(type='PipelineSharedModuleGradientHandler')]
                else:
                    gradient_handler_cfg.append(dict(type='PipelineSharedModuleGradientHandler'))
                if verbose:
                    logger.info(
                        "pipeline_shared_module is detected, PipelineSharedModuleGradientHandler is automatically "
                        "added even though not specified in the configuration",
                        ranks=[0])
                break
    else:
        if not isinstance(gradient_handler_cfg, list):
            raise ConfigException(
                f"expected gradient_handler in the configuration file to be a list but got {type(gradient_handler_cfg)}"
            )

    # turn off sync buffer for NaiveAMPModel if using torch DDP and NaiveAMPModel at the same time
    # to avoid duplicated buffer synchronization
    if isinstance(model, DDP) and isinstance(model.module, NaiveAMPModel):
        model.module.sync_buffer = False

    # initialize schedule for engine
    if is_using_pp():
        tensor_shape = get_tensor_shape()
        use_interleaved = hasattr(gpc.config, 'model') and hasattr(gpc.config.model, 'num_chunks')
        if gpc.is_initialized(ParallelMode.PARALLEL_1D):
            scatter_gather = True
        else:
            scatter_gather = False
        if use_interleaved:
            if isinstance(model, nn.Sequential):
                model = nn.ModuleList([model])
            schedule = InterleavedPipelineSchedule(gpc.config.NUM_MICRO_BATCHES,
                                                   gpc.config.model.num_chunks,
                                                   tensor_shape=tensor_shape,
                                                   scatter_gather_tensors=scatter_gather)
        else:
            schedule = PipelineSchedule(gpc.config.NUM_MICRO_BATCHES,
                                        tensor_shape=tensor_shape,
                                        scatter_gather_tensors=scatter_gather)
    else:
        schedule = NonPipelineSchedule()

    if gradient_handler_cfg is None:
        gradient_handlers = None
        if verbose and not isinstance(model, DDP):
            logger.warning(
                "No PyTorch DDP or gradient handler is set up, please make sure you do not need "
                "to all-reduce the gradients after a training step.",
                ranks=[0])
    else:
        gradient_handlers = [build_gradient_handler(cfg, model, optimizer) for cfg in gradient_handler_cfg]

    # check if optimizer is OptimizerWrapper
    if not isinstance(optimizer, (OptimizerWrapper, ShardedOptimizerV2)):
        optimizer = OptimizerWrapper(optim=optimizer)

    # gradient accumulation
    grad_accum_size = gpc.config.get('gradient_accumulation', None)
    if grad_accum_size is not None:
        optimizer, train_dataloader, gradient_handlers, lr_scheduler = accumulate_gradient(
            model=model,
            optimizer=optimizer,
            dataloader=train_dataloader,
            accumulate_size=grad_accum_size,
            gradient_handlers=gradient_handlers,
            lr_scheduler=lr_scheduler)
    engine = Engine(model=model,
                    optimizer=optimizer,
                    criterion=criterion,
                    gradient_handlers=gradient_handlers,
                    clip_grad_norm=clip_grad_norm,
                    ophook_list=ophooks,
                    schedule=schedule)

    return engine, train_dataloader, test_dataloader, lr_scheduler
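Taken together, `launch_from_torch` builds the process groups described by the legacy config and `initialize` wraps the model, optimizer and criterion into an `Engine` with the schedule, AMP and gradient handlers that config implies. A hedged end-to-end sketch of that flow (the tiny model, data and config values are placeholders; the training-loop calls assume the legacy `Engine`'s usual `train`/`zero_grad`/`backward`/`step` interface):

```python
import torch
import torch.nn as nn

import colossalai.legacy as legacy

# no pipeline, no tensor parallelism: a single-process sanity-check config
legacy.launch_from_torch(config=dict(parallel=dict(pipeline=1, tensor=dict(size=1, mode=None))))

model = nn.Linear(32, 32).cuda()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
criterion = nn.MSELoss()

engine, *_ = legacy.initialize(model, optimizer, criterion)

engine.train()
for _ in range(10):
    data = torch.rand(8, 32).cuda()
    target = torch.rand(8, 32).cuda()
    engine.zero_grad()
    output = engine(data)
    loss = engine.criterion(output, target)
    engine.backward(loss)
    engine.step()
```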
@@ -1,4 +1,3 @@
-from ._ops import *
 from .layer import *
 from .loss import *
 from .metric import *
@@ -1,9 +1 @@
-from .addmm import colo_addmm
-from .batch_norm import colo_batch_norm
-from .element_wise import *
-from .embedding import colo_embedding
-from .embedding_bag import colo_embedding_bag
-from .layernorm import colo_layernorm
-from .linear import colo_linear
-from .loss import colo_cross_entropy
-from .view import colo_view
+from ._utils import *
@@ -1,90 +0,0 @@
import torch

from colossalai.tensor import ColoTensor, ColoTensorSpec, ComputePattern, ComputeSpec, ReplicaSpec, ShardSpec, distspec
from colossalai.tensor.op_wrapper import colo_op_impl

from ._utils import GeneralTensor, Number, convert_to_colo_tensor, reduce_grad, reduce_input


def colo_addmm_1Drow(input_tensor: ColoTensor, mat1: ColoTensor, mat2: ColoTensor, beta: Number,
                     alpha: Number) -> ColoTensor:
    # mat1:S[1] x mat2:S[0] = Output:P
    # beta * input + alpha * All-Reduce(Output) = res

    mat1 = mat1.redistribute(ShardSpec([-1], [mat2.get_tp_world_size()]), mat2.get_process_group())

    # Output:P
    partial_output = torch.mm(mat1, mat2)
    # Reduce(Output)
    output = reduce_input(partial_output, mat2.get_process_group())
    # input
    assert not input_tensor.has_compute_spec(), 'Invalid input spec for 1Drow addmm op'
    output = beta * input_tensor + alpha * output
    output = ColoTensor.from_torch_tensor(output, spec=ColoTensorSpec(input_tensor.get_process_group()))
    return output


def colo_addmm_1Dcol(input_tensor: ColoTensor, mat1: ColoTensor, mat2: ColoTensor, beta: Number,
                     alpha: Number) -> ColoTensor:
    # mat1:B x mat2:S[1] + input:S[1] = Output:S[1]
    compute_spec = mat2.compute_spec
    mat1 = mat1.redistribute(ReplicaSpec())
    mat1 = reduce_grad(mat1, mat1.get_process_group())

    output_parallel = torch.addmm(input_tensor, mat1, mat2, beta=beta, alpha=alpha)
    output_spec = ColoTensorSpec(input_tensor.get_process_group(), ShardSpec([-1], [mat2.get_tp_world_size()]),
                                 ComputeSpec(ComputePattern.TP1D))
    output = ColoTensor.from_torch_tensor(output_parallel, spec=output_spec)

    if compute_spec.output_replicate:
        return output.to_replicate()
    else:
        return output


def colo_addmm_1d(mode: str, input_tensor: ColoTensor, mat1: ColoTensor, mat2: ColoTensor, beta: Number,
                  alpha: Number) -> ColoTensor:
    assert mode in ('row', 'col')
    funcs = {'row': colo_addmm_1Drow, 'col': colo_addmm_1Dcol}
    return funcs[mode](input_tensor, mat1, mat2, beta, alpha)


@colo_op_impl(torch.addmm)
def colo_addmm(input_tensor: GeneralTensor,
               mat1: ColoTensor,
               mat2: ColoTensor,
               beta: Number = 1,
               alpha: Number = 1,
               **kargs) -> ColoTensor:
    """Handles ``__torch_function__`` dispatch for ``torch.nn.functional.linear``.
    This method computes a linear.
    """
    # At least one of the tensor should be ColoTensor
    assert isinstance(mat2, ColoTensor)
    input_tensor = convert_to_colo_tensor(input_tensor, mat2.get_process_group())
    mat1 = convert_to_colo_tensor(mat1, mat2.get_process_group())

    # Add communication logic before and after linear call.
    ret_tensor = None
    if not mat2.has_compute_spec():    # No Model Parallel Applied
        assert mat2.is_replicate(), 'Invalid mat2 spec for native addmm op'
        assert input_tensor.is_replicate(), 'Invalid input spec for native addmm op'
        ret_tensor = ColoTensor.from_torch_tensor(tensor=torch.addmm(input_tensor,
                                                                     mat1,
                                                                     mat2,
                                                                     beta=beta,
                                                                     alpha=alpha,
                                                                     **kargs),
                                                  spec=ColoTensorSpec(mat2.get_process_group()))
    elif mat2.has_compute_pattern(ComputePattern.TP1D):    # Single Model Parallel Applied
        if mat2.is_shard_1drow() and input_tensor.is_replicate():
            mode = 'row'
        elif mat2.is_shard_1dcol() and (input_tensor.is_shard_1dcol() or input_tensor.is_shard_1drow()):
            mode = 'col'
        else:
            raise NotImplementedError
        ret_tensor = colo_addmm_1d(mode, input_tensor, mat1, mat2, beta, alpha)
    else:
        raise NotImplementedError

    return ret_tensor
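The deleted file above is one instance of the ColoTensor dispatch pattern: `@colo_op_impl(torch.addmm)` registers a `__torch_function__` handler that redistributes sharded operands, runs the local kernel, and wraps the result back into a `ColoTensor`. A minimal hedged sketch of the same registration mechanism for another op, mirroring the structure of `colo_addmm` above rather than documenting a supported API:

```python
import torch

from colossalai.tensor import ColoTensor, ColoTensorSpec
from colossalai.tensor.op_wrapper import colo_op_impl


@colo_op_impl(torch.neg)    # intercept torch.neg when it receives ColoTensor arguments
def colo_neg(input_tensor: ColoTensor) -> ColoTensor:
    # run the local kernel, then re-attach the process-group metadata,
    # just as colo_addmm above does for torch.addmm
    output = torch.neg(input_tensor)
    return ColoTensor.from_torch_tensor(output,
                                        spec=ColoTensorSpec(input_tensor.get_process_group()))
```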
@@ -1,33 +0,0 @@
from typing import Optional

import torch.nn.functional as F

from colossalai.tensor import ColoTensor, ColoTensorSpec, ReplicaSpec
from colossalai.tensor.op_wrapper import colo_op_impl

from ._utils import GeneralTensor, convert_to_colo_tensor


@colo_op_impl(F.batch_norm)
def colo_batch_norm(
    input: GeneralTensor,
    running_mean: Optional[GeneralTensor],
    running_var: Optional[GeneralTensor],
    weight: Optional[GeneralTensor] = None,
    bias: Optional[GeneralTensor] = None,
    training: bool = False,
    momentum: float = 0.1,
    eps: float = 1e-5,
):
    assert isinstance(weight, ColoTensor)
    running_mean = running_mean.detach()
    running_var = running_var.detach()

    input = convert_to_colo_tensor(input, weight.get_process_group())
    bias = convert_to_colo_tensor(bias, weight.get_process_group())
    input = input.redistribute(ReplicaSpec())
    bias = bias.redistribute(ReplicaSpec())

    output = F.batch_norm(input, running_mean, running_var, weight, bias, training, momentum, eps)
    output = ColoTensor.from_torch_tensor(tensor=output, spec=ColoTensorSpec(pg=weight.get_process_group()))
    return output
@ -1,250 +0,0 @@
import torch
import torch.nn.functional as F
from torch import Tensor

from colossalai.tensor import ColoTensor, ColoTensorSpec
from colossalai.tensor.op_wrapper import colo_op_impl

from ._utils import GeneralTensor, convert_to_colo_tensor


def register_elementwise_op(op):

    @colo_op_impl(op)
    def elementwise_op(input_tensor: GeneralTensor, *args, **kwargs):
        """
        Handles ``__torch_function__`` dispatch for the elementwise op such
        as ``torch.nn.functional.gelu`` or ``torch.nn.functional.relu``.
        This method computes on either a normal tensor or a sharded tensor.
        """
        if 'inplace' in kwargs:
            # TODO(jiaruifang) inplace will cause bugs
            input_tensor = input_tensor.clone()
            return op(input_tensor, *args, **kwargs)
        else:
            output = op(input_tensor, *args, **kwargs)
            # return output
            if isinstance(input_tensor, ColoTensor):
                if isinstance(output, str):
                    return output
                if not isinstance(output, torch.Tensor):
                    raise NotImplementedError
                return ColoTensor.from_torch_tensor(output,
                                                    spec=ColoTensorSpec(input_tensor.get_process_group(),
                                                                        dist_attr=input_tensor.dist_spec))


# @colo_op_impl(torch.relu_)
# def elementwise_op(input_tensor):
#     torch.relu_(input_tensor.data)
#     return input_tensor

# @colo_op_impl(Tensor.add_)
# def elementwise_op(input_tensor: ColoTensor, *args, **kwargs):
#     input_tensor = input_tensor.data.add_(*args, **kwargs)
#     return input_tensor

# Tensor op
register_elementwise_op(Tensor.abs)
register_elementwise_op(Tensor.absolute)
register_elementwise_op(Tensor.acos)
register_elementwise_op(Tensor.arccos)
register_elementwise_op(Tensor.angle)
register_elementwise_op(Tensor.asin)
register_elementwise_op(Tensor.arcsin)
register_elementwise_op(Tensor.atan)
register_elementwise_op(Tensor.arctan)
register_elementwise_op(Tensor.all)
register_elementwise_op(Tensor.any)
register_elementwise_op(Tensor.bernoulli)
register_elementwise_op(Tensor.bfloat16)
register_elementwise_op(Tensor.bitwise_not)
register_elementwise_op(Tensor.bool)
register_elementwise_op(Tensor.byte)
register_elementwise_op(Tensor.ceil)
register_elementwise_op(Tensor.char)
register_elementwise_op(Tensor.clamp)
register_elementwise_op(Tensor.clamp_max)
register_elementwise_op(Tensor.clamp_min)
register_elementwise_op(Tensor.clip)
register_elementwise_op(Tensor.clone)
register_elementwise_op(Tensor.contiguous)
register_elementwise_op(Tensor.copysign)
register_elementwise_op(Tensor.cos)
register_elementwise_op(Tensor.cosh)
register_elementwise_op(Tensor.acosh)
register_elementwise_op(Tensor.arccosh)
register_elementwise_op(Tensor.cpu)
register_elementwise_op(Tensor.cuda)
register_elementwise_op(Tensor.deg2rad)
register_elementwise_op(Tensor.detach)
register_elementwise_op(Tensor.digamma)
register_elementwise_op(Tensor.double)
register_elementwise_op(Tensor.erf)
register_elementwise_op(Tensor.erfc)
register_elementwise_op(Tensor.erfinv)
register_elementwise_op(Tensor.exp)
register_elementwise_op(Tensor.expm1)
register_elementwise_op(Tensor.fix)
register_elementwise_op(Tensor.trunc)
register_elementwise_op(Tensor.float)
register_elementwise_op(Tensor.float_power)
register_elementwise_op(Tensor.floor)
register_elementwise_op(Tensor.frac)
register_elementwise_op(Tensor.half)
register_elementwise_op(Tensor.hardshrink)
register_elementwise_op(Tensor.heaviside)
register_elementwise_op(Tensor.i0)
register_elementwise_op(Tensor.int)
register_elementwise_op(Tensor.isfinite)
register_elementwise_op(Tensor.isinf)
register_elementwise_op(Tensor.isposinf)
register_elementwise_op(Tensor.isneginf)
register_elementwise_op(Tensor.isnan)
register_elementwise_op(Tensor.lgamma)
register_elementwise_op(Tensor.log)
register_elementwise_op(Tensor.log10)
register_elementwise_op(Tensor.log1p)
register_elementwise_op(Tensor.log2)
register_elementwise_op(Tensor.logical_not)
register_elementwise_op(Tensor.logit)
register_elementwise_op(Tensor.long)
register_elementwise_op(Tensor.nan_to_num)
register_elementwise_op(Tensor.neg)
register_elementwise_op(Tensor.negative)
register_elementwise_op(Tensor.positive)
register_elementwise_op(Tensor.pow)
register_elementwise_op(Tensor.rad2deg)
register_elementwise_op(Tensor.reciprocal)
register_elementwise_op(Tensor.round)
register_elementwise_op(Tensor.rsqrt)
register_elementwise_op(Tensor.short)
register_elementwise_op(Tensor.sigmoid)
register_elementwise_op(Tensor.sign)
register_elementwise_op(Tensor.signbit)
register_elementwise_op(Tensor.sgn)
register_elementwise_op(Tensor.sin)
register_elementwise_op(Tensor.sinc)
register_elementwise_op(Tensor.sinh)
register_elementwise_op(Tensor.asinh)
register_elementwise_op(Tensor.arcsinh)
register_elementwise_op(Tensor.sqrt)
register_elementwise_op(Tensor.square)
register_elementwise_op(Tensor.to)
register_elementwise_op(Tensor.tan)
register_elementwise_op(Tensor.tanh)
register_elementwise_op(Tensor.atanh)
register_elementwise_op(Tensor.arctanh)
register_elementwise_op(Tensor.type)
register_elementwise_op(Tensor.type_as)

# torch OP
register_elementwise_op(torch.abs)
register_elementwise_op(torch.absolute)
register_elementwise_op(torch.acos)
register_elementwise_op(torch.arccos)
register_elementwise_op(torch.angle)
register_elementwise_op(torch.asin)
register_elementwise_op(torch.arcsin)
register_elementwise_op(torch.atan)
register_elementwise_op(torch.arctan)
register_elementwise_op(torch.all)
register_elementwise_op(torch.any)
register_elementwise_op(torch.bernoulli)
register_elementwise_op(torch.bitwise_not)
register_elementwise_op(torch.ceil)
register_elementwise_op(torch.clamp)
register_elementwise_op(torch.clamp_max)
register_elementwise_op(torch.clamp_min)
register_elementwise_op(torch.clip)
register_elementwise_op(torch.clone)
register_elementwise_op(torch.copysign)
register_elementwise_op(torch.cos)
register_elementwise_op(torch.cosh)
register_elementwise_op(torch.acosh)
register_elementwise_op(torch.arccosh)
register_elementwise_op(torch.deg2rad)
register_elementwise_op(torch.digamma)
register_elementwise_op(torch.erf)
register_elementwise_op(torch.erfc)
register_elementwise_op(torch.erfinv)
register_elementwise_op(torch.exp)
register_elementwise_op(torch.expm1)
register_elementwise_op(torch.fix)
register_elementwise_op(torch.trunc)
register_elementwise_op(torch.float_power)
register_elementwise_op(torch.floor)
register_elementwise_op(torch.frac)
register_elementwise_op(torch.hardshrink)
register_elementwise_op(torch.heaviside)
register_elementwise_op(torch.i0)
register_elementwise_op(torch.isfinite)
register_elementwise_op(torch.isinf)
register_elementwise_op(torch.isposinf)
register_elementwise_op(torch.isneginf)
register_elementwise_op(torch.isnan)
register_elementwise_op(torch.lgamma)
register_elementwise_op(torch.log)
register_elementwise_op(torch.log10)
register_elementwise_op(torch.log1p)
register_elementwise_op(torch.log2)
register_elementwise_op(torch.logical_not)
register_elementwise_op(torch.logit)
register_elementwise_op(torch.nan_to_num)
register_elementwise_op(torch.neg)
register_elementwise_op(torch.negative)
register_elementwise_op(torch.positive)
register_elementwise_op(torch.pow)
register_elementwise_op(torch.rad2deg)
register_elementwise_op(torch.reciprocal)
register_elementwise_op(torch.round)
register_elementwise_op(torch.rsqrt)
register_elementwise_op(torch.sigmoid)
register_elementwise_op(torch.sign)
register_elementwise_op(torch.signbit)
register_elementwise_op(torch.sgn)
register_elementwise_op(torch.sin)
register_elementwise_op(torch.sinc)
register_elementwise_op(torch.sinh)
register_elementwise_op(torch.asinh)
register_elementwise_op(torch.arcsinh)
register_elementwise_op(torch.sqrt)
register_elementwise_op(torch.square)
register_elementwise_op(torch.tan)
register_elementwise_op(torch.tanh)
register_elementwise_op(torch.atanh)
register_elementwise_op(torch.arctanh)
register_elementwise_op(torch.zeros_like)

# nn.functional OP
register_elementwise_op(F.threshold)
register_elementwise_op(F.relu)
register_elementwise_op(F.hardtanh)
register_elementwise_op(F.hardswish)
register_elementwise_op(F.relu6)
register_elementwise_op(F.elu)
register_elementwise_op(F.selu)
register_elementwise_op(F.celu)
register_elementwise_op(F.leaky_relu)
register_elementwise_op(F.prelu)
register_elementwise_op(F.rrelu)
register_elementwise_op(F.gelu)
register_elementwise_op(F.logsigmoid)
register_elementwise_op(F.hardshrink)
register_elementwise_op(F.tanhshrink)
register_elementwise_op(F.softsign)
register_elementwise_op(F.softplus)
register_elementwise_op(F.softmin)
register_elementwise_op(F.softmax)
register_elementwise_op(F.softshrink)
register_elementwise_op(F.gumbel_softmax)
register_elementwise_op(F.log_softmax)
register_elementwise_op(F.tanh)
register_elementwise_op(F.sigmoid)
register_elementwise_op(F.hardsigmoid)
register_elementwise_op(F.silu)
register_elementwise_op(F.mish)
# TODO(ver217): dropout handles seed
register_elementwise_op(F.dropout)
register_elementwise_op(F.alpha_dropout)
register_elementwise_op(F.feature_alpha_dropout)
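All of the registrations above lean on the same mechanism: `colo_op_impl` hooks the op into `__torch_function__` dispatch, the computation runs on the underlying data, and tensor results are re-wrapped so the ColoTensor's process group and distribution spec are not lost. Below is a minimal, self-contained sketch of that dispatch pattern in plain PyTorch (no ColossalAI, no distributed setup); `TaggedTensor` is a hypothetical stand-in for ColoTensor.

import torch
import torch.nn.functional as F


class TaggedTensor(torch.Tensor):
    """Hypothetical stand-in for ColoTensor, used only to illustrate dispatch."""

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        if kwargs is None:
            kwargs = {}
        # The base implementation runs `func` on the underlying data and wraps
        # tensor results back into `cls`; register_elementwise_op does the same
        # job explicitly via ColoTensor.from_torch_tensor(..., spec=...).
        return super().__torch_function__(func, types, args, kwargs)


x = torch.randn(4).as_subclass(TaggedTensor)
y = F.relu(x)                       # dispatches through TaggedTensor.__torch_function__
assert isinstance(y, TaggedTensor)  # the wrapper type survives the elementwise op
assert isinstance(x.exp(), TaggedTensor)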
@ -1,142 +0,0 @@
from typing import Optional

import torch.nn.functional as F

from colossalai.tensor import ColoTensor, ColoTensorSpec, ComputePattern, ComputeSpec, ReplicaSpec, ShardSpec
from colossalai.tensor.op_wrapper import colo_op_impl

from ._utils import GeneralTensor, convert_to_colo_tensor, reduce_input


def colo_embedding_1Dcol(input_tensor: ColoTensor,
                         weight: ColoTensor,
                         padding_idx: Optional[int] = None,
                         max_norm: Optional[float] = None,
                         norm_type: float = 2.0,
                         scale_grad_by_freq: bool = False,
                         sparse: bool = False) -> ColoTensor:
    # embedding_1Dcol splits the weight (lookup table) into shape (num_embeddings, embedding_dim/P)
    # the input is replicated so every rank can query its own shard of the table
    input_tensor = input_tensor.redistribute(ReplicaSpec())

    output_parallel = F.embedding(input_tensor,
                                  weight,
                                  padding_idx=padding_idx,
                                  max_norm=max_norm,
                                  norm_type=norm_type,
                                  scale_grad_by_freq=scale_grad_by_freq,
                                  sparse=sparse)
    output_spec = ColoTensorSpec(weight.get_process_group(), ShardSpec([-1], [weight.get_tp_world_size()]),
                                 ComputeSpec(ComputePattern.TP1D))
    output = ColoTensor.from_torch_tensor(output_parallel, spec=output_spec)

    compute_spec = weight.compute_spec

    if compute_spec.output_replicate:
        return output.to_replicate()
    else:
        return output


def colo_embedding_1Drow(input_tensor: ColoTensor,
                         weight: ColoTensor,
                         padding_idx: Optional[int] = None,
                         max_norm: Optional[float] = None,
                         norm_type: float = 2.0,
                         scale_grad_by_freq: bool = False,
                         sparse: bool = False) -> ColoTensor:
    # embedding_1Drow splits the weight (lookup table) into shape [num_embeddings/P, embedding_dim]
    # get the index range of the current segment and mask indices owned by other segments with 0

    # get the complete input tensor through all-gather
    input_tensor = input_tensor.redistribute(ReplicaSpec())

    # tensor_parallel_rank = gpc.get_local_rank(ParallelMode.PARALLEL_1D)
    tensor_parallel_rank = weight.get_process_group().tp_local_rank()
    num_embeddings_per_partition = weight.size_local(0)
    vocab_start_index = tensor_parallel_rank * num_embeddings_per_partition
    vocab_end_index = vocab_start_index + num_embeddings_per_partition

    # build the mask.
    input_mask = (input_tensor < vocab_start_index) | (input_tensor >= vocab_end_index)
    # mask the input.
    # TODO(jzy) masked_input may be an activation managed by ColoTensor.
    masked_input = input_tensor - vocab_start_index
    masked_input[input_mask] = 0

    partial_output = F.embedding(masked_input,
                                 weight,
                                 padding_idx=padding_idx,
                                 max_norm=max_norm,
                                 norm_type=norm_type,
                                 scale_grad_by_freq=scale_grad_by_freq,
                                 sparse=sparse)

    # Mask the output embedding.
    partial_output[input_mask, :] = 0.
    # Reduce across all the model parallel GPUs.
    output = reduce_input(partial_output, weight.get_process_group())
    output = ColoTensor.from_torch_tensor(output, spec=ColoTensorSpec(weight.get_process_group(), ReplicaSpec()))
    return output


def colo_embedding_1d(mode: str,
                      input_tensor: ColoTensor,
                      weight: ColoTensor,
                      padding_idx: Optional[int] = None,
                      max_norm: Optional[float] = None,
                      norm_type: float = 2.0,
                      scale_grad_by_freq: bool = False,
                      sparse: bool = False) -> ColoTensor:
    assert mode in ('row', 'col')
    funcs = {'row': colo_embedding_1Drow, 'col': colo_embedding_1Dcol}
    return funcs[mode](input_tensor,
                       weight,
                       padding_idx=padding_idx,
                       max_norm=max_norm,
                       norm_type=norm_type,
                       scale_grad_by_freq=scale_grad_by_freq,
                       sparse=sparse)


@colo_op_impl(F.embedding)
def colo_embedding(input_tensor: GeneralTensor,
                   weight: GeneralTensor,
                   padding_idx: Optional[int] = None,
                   max_norm: Optional[float] = None,
                   norm_type: float = 2.0,
                   scale_grad_by_freq: bool = False,
                   sparse: bool = False):
    """Handles ``__torch_function__`` dispatch for ``torch.nn.functional.embedding``.
    This method looks up an embedding table.
    """
    assert isinstance(weight, ColoTensor)
    input_tensor = convert_to_colo_tensor(input_tensor, weight.get_process_group())

    if not weight.has_compute_spec():    # No Model Parallel Applied
        assert weight.is_replicate(), 'Invalid weight spec for native embedding op'
        return ColoTensor.from_torch_tensor(tensor=F.embedding(input_tensor,
                                                               weight,
                                                               padding_idx=padding_idx,
                                                               max_norm=max_norm,
                                                               norm_type=norm_type,
                                                               scale_grad_by_freq=scale_grad_by_freq,
                                                               sparse=sparse),
                                            spec=ColoTensorSpec(weight.get_process_group()))
    elif weight.has_compute_pattern(ComputePattern.TP1D):    # Single Model Parallel Applied
        if weight.is_shard_1drow():
            mode = 'row'
        elif weight.is_shard_1dcol():
            mode = 'col'
        else:
            raise NotImplementedError
        return colo_embedding_1d(mode,
                                 input_tensor,
                                 weight,
                                 padding_idx=padding_idx,
                                 max_norm=max_norm,
                                 norm_type=norm_type,
                                 scale_grad_by_freq=scale_grad_by_freq,
                                 sparse=sparse)
    else:
        raise NotImplementedError
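To make the row-parallel (`1Drow`) path concrete: every rank shifts the global indices into its local range, zeroes out lookups it does not own, and the partial results are summed by `reduce_input` (an all-reduce). Below is a single-process simulation of that scheme in plain PyTorch; the loop over `rank` stands in for the tensor-parallel workers and the running sum stands in for the all-reduce.

import torch
import torch.nn.functional as F

torch.manual_seed(0)
full_weight = torch.randn(8, 4)              # 8 vocab rows, 4 features
ids = torch.tensor([[1, 5], [7, 2]])
world_size = 2
rows_per_rank = full_weight.size(0) // world_size

partial_sum = torch.zeros(*ids.shape, full_weight.size(1))
for rank in range(world_size):
    shard = full_weight[rank * rows_per_rank:(rank + 1) * rows_per_rank]
    vocab_start = rank * rows_per_rank
    vocab_end = vocab_start + rows_per_rank
    # mask indices owned by other ranks and shift the rest into the local range
    mask = (ids < vocab_start) | (ids >= vocab_end)
    local_ids = (ids - vocab_start).masked_fill(mask, 0)
    partial = F.embedding(local_ids, shard)
    partial[mask, :] = 0.0    # zero out rows this rank does not own
    partial_sum += partial    # stands in for the all-reduce

assert torch.allclose(partial_sum, F.embedding(ids, full_weight))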
@ -1,127 +0,0 @@
from typing import Optional

import torch.nn.functional as F
from torch import Tensor

from colossalai.tensor import ColoTensor, ColoTensorSpec, ComputePattern, ComputeSpec, ReplicaSpec, ShardSpec, distspec
from colossalai.tensor.op_wrapper import colo_op_impl

from ._utils import GeneralTensor, convert_to_colo_tensor


def colo_embedding_bag_1Dcol(input_tensor: ColoTensor,
                             weight: ColoTensor,
                             offsets: Optional[Tensor] = None,
                             max_norm: Optional[float] = None,
                             norm_type: float = 2,
                             scale_grad_by_freq: bool = False,
                             mode: str = "mean",
                             sparse: bool = False,
                             per_sample_weights: Optional[Tensor] = None,
                             include_last_offset: bool = False,
                             padding_idx: Optional[int] = None) -> ColoTensor:
    # embedding_bag_1Dcol splits the weight (lookup table) into shape (num_embeddings, embedding_dim/P)
    # the input is replicated so every rank can query its own shard of the table
    pg = weight.get_process_group()
    input_tensor = input_tensor.redistribute(ReplicaSpec())

    output_parallel = F.embedding_bag(input_tensor,
                                      weight,
                                      offsets=offsets,
                                      max_norm=max_norm,
                                      norm_type=norm_type,
                                      scale_grad_by_freq=scale_grad_by_freq,
                                      mode=mode,
                                      sparse=sparse,
                                      per_sample_weights=per_sample_weights,
                                      include_last_offset=include_last_offset,
                                      padding_idx=padding_idx)
    output_spec = ColoTensorSpec(pg, ShardSpec([-1], [weight.get_tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
    output = ColoTensor.from_torch_tensor(output_parallel, spec=output_spec)

    if weight.compute_spec.output_replicate:
        return output.to_replicate()
    else:
        return output


def colo_embedding_bag_1d(tp_mode: str,
                          input_tensor: ColoTensor,
                          weight: ColoTensor,
                          offsets: Optional[Tensor] = None,
                          max_norm: Optional[float] = None,
                          norm_type: float = 2,
                          scale_grad_by_freq: bool = False,
                          mode: str = "mean",
                          sparse: bool = False,
                          per_sample_weights: Optional[Tensor] = None,
                          include_last_offset: bool = False,
                          padding_idx: Optional[int] = None) -> ColoTensor:
    assert tp_mode in ('col',)
    funcs = {'col': colo_embedding_bag_1Dcol}
    return funcs[tp_mode](input_tensor,
                          weight,
                          offsets=offsets,
                          max_norm=max_norm,
                          norm_type=norm_type,
                          scale_grad_by_freq=scale_grad_by_freq,
                          mode=mode,
                          sparse=sparse,
                          per_sample_weights=per_sample_weights,
                          include_last_offset=include_last_offset,
                          padding_idx=padding_idx)


@colo_op_impl(F.embedding_bag)
def colo_embedding_bag(input_tensor: GeneralTensor,
                       weight: GeneralTensor,
                       offsets: Optional[Tensor] = None,
                       max_norm: Optional[float] = None,
                       norm_type: float = 2,
                       scale_grad_by_freq: bool = False,
                       mode: str = "mean",
                       sparse: bool = False,
                       per_sample_weights: Optional[Tensor] = None,
                       include_last_offset: bool = False,
                       padding_idx: Optional[int] = None):
    """Handles ``__torch_function__`` dispatch for ``torch.nn.functional.embedding_bag``.
    This method looks up an embedding table.
    """
    assert isinstance(weight, ColoTensor)
    input_tensor = convert_to_colo_tensor(input_tensor, weight.get_process_group())

    # Handle different parallel actions.

    if not weight.has_compute_spec():    # No Model Parallel Applied
        assert weight.is_replicate(), 'Invalid weight spec for native embedding op'
        return ColoTensor.from_torch_tensor(tensor=F.embedding_bag(input_tensor,
                                                                   weight,
                                                                   offsets=offsets,
                                                                   max_norm=max_norm,
                                                                   norm_type=norm_type,
                                                                   scale_grad_by_freq=scale_grad_by_freq,
                                                                   mode=mode,
                                                                   sparse=sparse,
                                                                   per_sample_weights=per_sample_weights,
                                                                   include_last_offset=include_last_offset,
                                                                   padding_idx=padding_idx),
                                            spec=ColoTensorSpec(weight.get_process_group()))
    elif weight.has_compute_pattern(ComputePattern.TP1D):    # Single Model Parallel Applied
        if weight.is_shard_1dcol():
            tp_mode = 'col'
        else:
            raise NotImplementedError
        return colo_embedding_bag_1d(tp_mode,
                                     input_tensor,
                                     weight,
                                     offsets=offsets,
                                     max_norm=max_norm,
                                     norm_type=norm_type,
                                     scale_grad_by_freq=scale_grad_by_freq,
                                     mode=mode,
                                     sparse=sparse,
                                     per_sample_weights=per_sample_weights,
                                     include_last_offset=include_last_offset,
                                     padding_idx=padding_idx)
    else:
        raise NotImplementedError
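The column-parallel (`1Dcol`) sharding above splits the lookup table along the embedding dimension, so each worker produces a slice of every bag's output and the slices are stitched together along the last dimension (hence the `ShardSpec([-1], ...)` output spec). A quick single-process sanity check of that identity in plain PyTorch, with the concatenation standing in for the all-gather performed by `to_replicate()`:

import torch
import torch.nn.functional as F

torch.manual_seed(0)
full_weight = torch.randn(6, 8)              # 6 rows, 8 features
ids = torch.tensor([[0, 2, 4], [1, 3, 5]])   # two bags of three indices each
world_size = 2
cols_per_rank = full_weight.size(1) // world_size

partials = []
for rank in range(world_size):
    # each "rank" holds a slice of the feature columns
    shard = full_weight[:, rank * cols_per_rank:(rank + 1) * cols_per_rank].contiguous()
    partials.append(F.embedding_bag(ids, shard, mode="mean"))

assert torch.allclose(torch.cat(partials, dim=-1),
                      F.embedding_bag(ids, full_weight, mode="mean"))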
@ -1,28 +0,0 @@
from typing import List, Optional

import torch.nn.functional as F

from colossalai.tensor import ColoTensor, ColoTensorSpec, ReplicaSpec, distspec
from colossalai.tensor.op_wrapper import colo_op_impl

from ._utils import GeneralTensor, convert_to_colo_tensor


@colo_op_impl(F.layer_norm)
def colo_layernorm(
    input_tensor: GeneralTensor,
    normalized_shape: List[int],
    weight: Optional[GeneralTensor] = None,
    bias: Optional[GeneralTensor] = None,
    eps: float = 1e-5,
):
    assert isinstance(weight, ColoTensor)
    input_tensor = convert_to_colo_tensor(input_tensor, weight.get_process_group())
    bias = convert_to_colo_tensor(bias, weight.get_process_group())
    input_tensor = input_tensor.redistribute(ReplicaSpec())

    output = F.layer_norm(input_tensor, normalized_shape, weight=weight, bias=bias, eps=eps)
    output = ColoTensor.from_torch_tensor(tensor=output,
                                          spec=ColoTensorSpec(pg=input_tensor.get_process_group(),
                                                              dist_attr=input_tensor.dist_spec))
    return output
@ -1,171 +0,0 @@
from copy import deepcopy
from typing import Optional

import torch.nn.functional as F

from colossalai.tensor import ColoTensor, ColoTensorSpec, ComputePattern, ComputeSpec, ReplicaSpec, ShardSpec
from colossalai.tensor.op_wrapper import colo_op_impl
from colossalai.tensor.sharding_spec import ShardingSpec

from ._utils import GeneralTensor, convert_to_colo_tensor, reduce_grad, reduce_input


def colo_linear_1drow(input_tensor: ColoTensor, weight: ColoTensor, bias: Optional[ColoTensor]) -> 'ColoTensor':
    # Input:S[1] x Weight:S[0] = Output:P
    # All-Reduce(Output) + bias = res
    # Input:S[1]
    pg = weight.get_process_group()
    input_tensor = input_tensor.redistribute(ShardSpec([-1], [weight.get_tp_world_size()]), pg)

    # Output:P
    partial_output = F.linear(input_tensor, weight)
    # Reduce(Output)
    output = reduce_input(partial_output, pg)
    # Bias
    if bias is not None:
        assert not bias.has_compute_spec(), 'Invalid bias spec for 1Drow Linear op'
        output = output + bias

    output = ColoTensor.from_torch_tensor(output, spec=ColoTensorSpec(pg, ReplicaSpec()))
    return output


def colo_linear_1dcol(input_tensor: ColoTensor, weight: ColoTensor, bias: Optional[ColoTensor]) -> 'ColoTensor':
    # Input:B x Weight:S[1] + Bias:S[1] = Output:S[1]
    # All-Gather(Output)
    # Input:B
    compute_spec = weight.compute_spec
    input_tensor = input_tensor.redistribute(ReplicaSpec())
    input_parallel = reduce_grad(input_tensor, weight.get_process_group())

    output_parallel = F.linear(input_parallel, weight, bias)
    output = ColoTensor.from_torch_tensor(output_parallel,
                                          spec=ColoTensorSpec(weight.get_process_group(),
                                                              ShardSpec([-1], [weight.get_tp_world_size()]),
                                                              ComputeSpec(ComputePattern.TP1D)))
    if compute_spec.output_replicate:
        return output.to_replicate()
    else:
        return output


def colo_linear_1d(mode: str, input_tensor: ColoTensor, weight: ColoTensor, bias: Optional[ColoTensor]) -> 'ColoTensor':
    assert mode in ('row', 'col')
    funcs = {'row': colo_linear_1drow, 'col': colo_linear_1dcol}
    return funcs[mode](input_tensor, weight, bias)


# @register_colo_graph(input_pos=[1], param_pos=[2, 3])
def colo_linear_imp(input_tensor: GeneralTensor,
                    weight: GeneralTensor,
                    bias: Optional[GeneralTensor] = None) -> 'ColoTensor':
    """Handles ``__torch_function__`` dispatch for ``torch.nn.functional.linear``.
    This method computes a linear transformation.
    """
    assert isinstance(weight, ColoTensor)
    pg = weight.get_process_group()
    assert pg
    input_tensor = convert_to_colo_tensor(input_tensor, pg)
    bias = convert_to_colo_tensor(bias, pg)
    # input_tensor, weight, bias = tuple(map(convert_to_colo_tensor, (input_tensor, weight, bias)))

    # Add communication logic before and after linear call.
    ret_tensor = None
    if not weight.has_compute_spec():    # No Model Parallel Applied
        assert weight.is_replicate(), 'Invalid weight spec for native Linear op'
        assert bias is None or bias.is_replicate(), 'Invalid bias spec for native Linear op'
        ret_tensor = ColoTensor.from_torch_tensor(F.linear(input_tensor, weight, bias), spec=ColoTensorSpec(pg))
    elif weight.has_compute_pattern(ComputePattern.TP1D):    # Single Model Parallel Applied
        if weight.is_shard_1dcol() and (bias is None or bias.is_replicate()):
            mode = 'row'
        elif weight.is_shard_1drow() and (bias is None or bias.is_shard_1drow() or bias.is_shard_1dcol()):
            mode = 'col'
        else:
            raise RuntimeError(f"the weight or bias tensor spec is not valid, weight {weight}, bias {bias}")
        ret_tensor = colo_linear_1d(mode, input_tensor, weight, bias)
    else:
        raise NotImplementedError

    return ret_tensor


def _new_colo_linear_imp(input_tensor: GeneralTensor,
                         weight: GeneralTensor,
                         bias: Optional[GeneralTensor] = None) -> 'ColoTensor':
    """
    A tentative function to compute the distributed linear layer with the latest sharding spec.
    This function is subject to future change as the current sharding API is not stable.
    """
    # get mesh info
    input_sharding_seq = input_tensor.sharding_spec.sharding_sequence
    weight_sharding_seq = weight.sharding_spec.sharding_sequence
    if bias is not None:
        bias_sharding_seq = bias.sharding_spec.sharding_sequence
    device_mesh = weight.sharding_spec.device_mesh
    pg_axis0 = weight.pg_axis0
    pg_axis1 = weight.pg_axis1

    # the last dim of the input must match the in_features dim of the weight;
    # the weight is stored transposed, so we check its second dimension
    assert input_sharding_seq[-1] == weight_sharding_seq[1]

    if bias is not None:
        assert bias_sharding_seq[0] == weight_sharding_seq[0]

    # compute the output sharding sequence; the weight is transposed,
    # so its first dimension gives the output dim
    output_shard_seq = input_sharding_seq[:-1] + weight_sharding_seq[:1]
    output_shard_seq = deepcopy(output_shard_seq)

    # TODO: add reduce grad logic

    # handle column and row parallel linear
    # by reusing the implementation above
    out = F.linear(input_tensor, weight)

    # run all reduce if necessary
    last_dim_spec = input_sharding_seq[-1]
    if last_dim_spec.is_replica:
        pass
    elif last_dim_spec.shard_list is not None:
        for dim in last_dim_spec.shard_list:
            if dim == 0:
                reduce_input(out, pg_axis0)
            elif dim == 1:
                reduce_input(out, pg_axis1)
            else:
                raise RuntimeError(f"Found invalid sharding axis {dim}, only 0 or 1 is expected")
    # add bias
    if bias is not None:
        out += bias

    # convert shard seq to partition dict
    output_partition_dict = {}
    for index, dim_spec in enumerate(output_shard_seq):
        if not dim_spec.is_replica:
            if index not in output_partition_dict:
                output_partition_dict[index] = []
            output_partition_dict[index].extend(dim_spec.shard_list)

    entire_shape = out.shape
    output_sharding_spec = ShardingSpec(device_mesh, entire_shape, output_partition_dict)
    ret_tensor = ColoTensor.from_torch_tensor(out)
    setattr(ret_tensor, 'sharding_spec', output_sharding_spec)
    return ret_tensor


def _has_sharding_spec(tensor):
    """
    A tentative function to check whether the tensor is using the new sharding spec API.
    We assume that the sharding spec object is set as the attribute `sharding_spec` on a tensor.
    """
    return hasattr(tensor, 'sharding_spec')


@colo_op_impl(F.linear)
def colo_linear(input: GeneralTensor, weight: GeneralTensor, bias: Optional[GeneralTensor] = None) -> 'ColoTensor':
    if _has_sharding_spec(weight):
        return _new_colo_linear_imp(input, weight, bias)
    else:
        return colo_linear_imp(input, weight, bias)
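The two 1D modes handled by `colo_linear_imp` reduce to familiar matrix identities: in 'row' mode the inner dimension is sharded, so partial products are summed across ranks (the `reduce_input` all-reduce); in 'col' mode the output dimension is sharded, so output slices are concatenated (the all-gather triggered when `output_replicate` is set). A single-process sketch of both identities in plain PyTorch, with slicing and Python loops standing in for the distributed communication:

import torch
import torch.nn.functional as F

torch.manual_seed(0)
x = torch.randn(3, 8)    # (batch, in_features)
w = torch.randn(6, 8)    # (out_features, in_features), as F.linear expects
b = torch.randn(6)
world_size = 2
reference = F.linear(x, w, b)

# 'row' mode: weight sharded along in_features; partial outputs are summed,
# the bias is added once after the reduction.
in_per_rank = w.size(1) // world_size
row_out = sum(
    F.linear(x[:, r * in_per_rank:(r + 1) * in_per_rank],
             w[:, r * in_per_rank:(r + 1) * in_per_rank])
    for r in range(world_size)) + b
assert torch.allclose(row_out, reference, atol=1e-5)

# 'col' mode: weight (and bias) sharded along out_features; output slices are
# concatenated along the last dimension.
out_per_rank = w.size(0) // world_size
col_out = torch.cat([
    F.linear(x, w[r * out_per_rank:(r + 1) * out_per_rank],
             b[r * out_per_rank:(r + 1) * out_per_rank])
    for r in range(world_size)], dim=-1)
assert torch.allclose(col_out, reference, atol=1e-5)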
@ -1,51 +0,0 @@
from typing import Optional

import torch
import torch.nn.functional as F

from colossalai.legacy.nn.loss.loss_1d import VocabParallelCrossEntropyLoss1D
from colossalai.tensor import ColoTensor, ColoTensorSpec
from colossalai.tensor.op_wrapper import colo_op_impl

from ._utils import GeneralTensor, convert_to_colo_tensor


@colo_op_impl(F.cross_entropy)
def colo_cross_entropy(input_tensor: GeneralTensor,
                       target: GeneralTensor,
                       weight: Optional[GeneralTensor] = None,
                       size_average: Optional[bool] = None,
                       ignore_index: int = -100,
                       reduce: Optional[bool] = None,
                       reduction: str = "mean",
                       label_smoothing: float = 0.0):
    assert isinstance(weight, ColoTensor) or isinstance(target, ColoTensor) or isinstance(input_tensor, ColoTensor)
    # derive the process group from whichever argument is already a ColoTensor
    pg = input_tensor.get_process_group() if isinstance(input_tensor, ColoTensor) else target.get_process_group()
    weight = convert_to_colo_tensor(weight, pg)
    target = convert_to_colo_tensor(target, pg)
    input_tensor = convert_to_colo_tensor(input_tensor, pg)

    if input_tensor.is_replicate():    # Input is gathered
        assert target.is_replicate() and (weight is None or weight.is_replicate()), \
            "Target tensor and weight tensor both should be complete"
        output = F.cross_entropy(input_tensor,
                                 target,
                                 weight=weight,
                                 size_average=size_average,
                                 ignore_index=ignore_index,
                                 reduce=reduce,
                                 reduction=reduction,
                                 label_smoothing=label_smoothing)
        return ColoTensor.from_torch_tensor(output, ColoTensorSpec(pg))
    elif input_tensor.has_compute_spec():    # Single Model Parallel Applied
        if input_tensor.is_shard_1dcol():
            assert weight is None, "Current TP cross entropy loss function doesn't support passing weight tensor in"
            assert target.is_replicate(), "Target tensor should be complete in TP cross entropy loss function"
            output = VocabParallelCrossEntropyLoss1D()(input_tensor,
                                                       target,
                                                       process_group=input_tensor.process_group.tp_process_group())
            return ColoTensor.from_torch_tensor(output, ColoTensorSpec(pg))
        else:
            raise NotImplementedError
    else:
        raise NotImplementedError
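When the logits are column-sharded (the `is_shard_1dcol` branch), the loss can be computed without ever materializing the full vocabulary on one rank: each worker contributes its local max, its local sum of exponentials, and the target logit it happens to own. The sketch below checks that reduction against `F.cross_entropy` on a single process; it illustrates the idea behind `VocabParallelCrossEntropyLoss1D`, not its exact implementation.

import torch
import torch.nn.functional as F

torch.manual_seed(0)
logits = torch.randn(4, 10)                  # (batch, vocab)
target = torch.randint(0, 10, (4,))
world_size = 2
vocab_per_rank = logits.size(1) // world_size
shards = [logits[:, r * vocab_per_rank:(r + 1) * vocab_per_rank] for r in range(world_size)]

# all-reduce(max) and all-reduce(sum of exp) over the vocabulary shards
global_max = torch.stack([s.max(dim=1).values for s in shards]).max(dim=0).values
sum_exp = sum((s - global_max.unsqueeze(1)).exp().sum(dim=1) for s in shards)

# each target logit lives on exactly one rank
target_logit = torch.zeros(logits.size(0))
for r, s in enumerate(shards):
    lo, hi = r * vocab_per_rank, (r + 1) * vocab_per_rank
    owned = (target >= lo) & (target < hi)
    target_logit[owned] = s[owned, (target - lo)[owned]]

loss = (sum_exp.log() + global_max - target_logit).mean()
assert torch.allclose(loss, F.cross_entropy(logits, target), atol=1e-6)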
@ -1,96 +0,0 @@
import operator
from functools import reduce
from typing import Optional, Union

import torch

from colossalai.tensor import ColoTensor, ColoTensorSpec, ReplicaSpec
from colossalai.tensor.op_wrapper import colo_op_impl


def _all_int(my_iter):
    return all(isinstance(i, int) for i in my_iter)


def _get_valid_shape(shape):
    if isinstance(shape, list):
        if _all_int(shape):
            return tuple(shape)
        else:
            raise RuntimeError("expects type(int) but finds another type")
    elif isinstance(shape, tuple):
        if _all_int(shape):
            return shape
        else:
            return _get_valid_shape(shape[0])
    else:
        raise RuntimeError("expects an iterable array but finds '{}'".format(type(shape)))


def _shape_infer(org_sp, tgt_sp):
    cnt = 0
    pos = 0
    for idx, dim in enumerate(tgt_sp):
        if dim < -1:
            raise RuntimeError("invalid shape dimension {}".format(dim))
        elif dim == -1:
            cnt += 1
            pos = idx

    if cnt > 1:
        raise RuntimeError("only one dimension can be inferred")

    org_prod = reduce(operator.mul, org_sp, 1)
    tgt_prod = reduce(operator.mul, tgt_sp, 1)

    if cnt == 0:
        if org_prod != tgt_prod:
            raise RuntimeError("shape '{}' is invalid for input of size {}".format(tgt_sp, org_prod))
        else:
            return tgt_sp
    elif org_prod % tgt_prod != 0:
        raise RuntimeError("shape '{}' is invalid for input of size {}".format(tgt_sp, org_prod))

    infer_dim = -(org_prod // tgt_prod)
    return tgt_sp[:pos] + (infer_dim,) + tgt_sp[pos + 1:]


@colo_op_impl(torch.Tensor.view)
def colo_view(self: ColoTensor, *shape) -> 'ColoTensor':
    """Handles ``__torch_function__`` dispatch for ``torch.Tensor.view``.
    Changes the shape of the current tensor.
    """
    assert isinstance(self, ColoTensor)
    # apply the original `view` function for replicated colo tensors
    if self.is_replicate():
        return self.view(*shape)

    cur_sp = self.size()
    org_sp = self.size_global()
    # parse the passed arguments
    tgt_sp = _get_valid_shape(shape)
    # get the correct shape from inference
    inf_sp = _shape_infer(org_sp, tgt_sp)

    if self.is_shard_1drow() and org_sp[0] == inf_sp[0]:
        new_shape = (cur_sp[0],) + tgt_sp[1:]
        res = self.view(*new_shape)
    elif self.is_shard_1dcol() and org_sp[-1] == inf_sp[-1]:
        new_shape = tgt_sp[:-1] + (cur_sp[-1],)
        res = self.view(*new_shape)
    else:
        replicated_t = self.redistribute(dist_spec=ReplicaSpec())
        return ColoTensor.from_torch_tensor(tensor=replicated_t.view(*shape),
                                            spec=ColoTensorSpec(self.get_process_group()))

    return ColoTensor.from_torch_tensor(tensor=res,
                                        spec=ColoTensorSpec(pg=self.get_process_group(), dist_attr=self.dist_spec))


@colo_op_impl(torch.Tensor.size)
def colo_size(self: ColoTensor, dim: Optional[int] = None) -> Union[torch.Size, int]:
    size = self.size_global()
    if dim is None:
        return size
    else:
        return size[dim]
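As a quick sanity check of the `-1` inference rule used by `colo_view`: the product of the known target dims must divide the global element count, and the single `-1` dim absorbs the remaining factor, exactly as `torch.Tensor.view` does. The second assertion assumes `_shape_infer` from the file above is in scope.

import torch

t = torch.arange(24).reshape(4, 6)
assert t.view(2, -1, 3).shape == (2, 4, 3)              # 24 / (2 * 3) == 4
assert _shape_infer((4, 6), (2, -1, 3)) == (2, 4, 3)    # same answer from the helper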