mirror of https://github.com/hpcaitech/ColossalAI
[legacy] clean up legacy code (#4743)
* [legacy] remove outdated codes of pipeline (#4692)
* [legacy] remove cli of benchmark and update optim (#4690)
* [legacy] remove cli of benchmark and update optim
* [doc] fix cli doc test
* [legacy] fix engine clip grad norm
* [legacy] remove outdated colo tensor (#4694)
* [legacy] remove outdated colo tensor
* [test] fix test import
* [legacy] move outdated zero to legacy (#4696)
* [legacy] clean up utils (#4700)
* [legacy] clean up utils
* [example] update examples
* [legacy] clean up amp
* [legacy] fix amp module
* [legacy] clean up gpc (#4742)
* [legacy] clean up context
* [legacy] clean core, constants and global vars
* [legacy] refactor initialize
* [example] fix examples ci
* [example] fix examples ci
* [legacy] fix tests
* [example] fix gpt example
* [example] fix examples ci
* [devops] fix ci installation
* [example] fix examples ci
parent 32e7f99416
commit b5f9e37c70
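The net effect of this commit, as the hunks below show, is that the AMP helpers, `global_context`, and the `launch*`/`initialize` entry points move under `colossalai.legacy`. A minimal import-migration sketch follows; it assumes a ColossalAI checkout at this commit and only uses names that appear in the added files below.

# Before this commit (old paths, removed below):
# from colossalai.amp import AMP_TYPE, convert_to_amp
# from colossalai.core import global_context as gpc

# After this commit (new paths, added below):
from colossalai.legacy.amp import AMP_TYPE, convert_to_amp    # re-homed AMP helpers
from colossalai.legacy.core import global_context as gpc      # re-homed global context
from colossalai.legacy import initialize, launch_from_torch   # re-exported by the new legacy package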
@@ -1,54 +0,0 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-

import torch.nn as nn
from torch.nn.modules.loss import _Loss
from torch.optim import Optimizer

from colossalai.context import Config

from .amp_type import AMP_TYPE
from .apex_amp import convert_to_apex_amp
from .naive_amp import convert_to_naive_amp
from .torch_amp import convert_to_torch_amp

__all__ = ['convert_to_amp', 'convert_to_naive_amp', 'convert_to_apex_amp', 'convert_to_torch_amp', 'AMP_TYPE']


def convert_to_amp(model: nn.Module, optimizer: Optimizer, criterion: _Loss, mode: AMP_TYPE, amp_config: Config = None):
    """A helper function to wrap training components with AMP modules.

    Args:
        model (:class:`torch.nn.Module`): your model object.
        optimizer (:class:`torch.optim.Optimizer`): your optimizer object.
        criterion (:class:`torch.nn.modules.loss._Loss`): your loss function object.
        mode (:class:`colossalai.amp.AMP_TYPE`): amp mode.
        amp_config (Union[:class:`colossalai.context.Config`, dict]): configuration for different amp modes.

    Returns:
        A tuple (model, optimizer, criterion).

    Note:
        ``amp_config`` may vary depending on the mode you choose. You should check the corresponding amp mode
        for more details about ``amp_config``.
        For ``apex_amp``, please check
        `apex_amp config <https://nvidia.github.io/apex/amp.html?highlight=apex%20amp>`_.
        For ``naive_amp``, please check
        `naive_amp config <https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/amp/naive_amp/_fp16_optimizer.py#L42>`_.
        For ``torch_amp``, please check
        `torch_amp config <https://github.com/pytorch/pytorch/blob/master/torch/cuda/amp/grad_scaler.py#L97>`_.
    """
    assert isinstance(mode, AMP_TYPE), \
        f'expected the argument mode to be AMP_TYPE, but got {type(mode)}'

    if amp_config is None:
        amp_config = Config()

    if mode == AMP_TYPE.TORCH:
        model, optimizer, criterion = convert_to_torch_amp(model, optimizer, criterion, amp_config)
    elif mode == AMP_TYPE.APEX:
        model, optimizer = convert_to_apex_amp(model, optimizer, amp_config)
    elif mode == AMP_TYPE.NAIVE:
        model, optimizer = convert_to_naive_amp(model, optimizer, amp_config)

    return model, optimizer, criterion
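A minimal usage sketch of the helper above, which this commit re-adds under `colossalai.legacy.amp` later in the diff. The model, optimizer, and criterion are placeholders; per the docstring, the ``amp_config`` for ``torch_amp`` is forwarded to ``torch.cuda.amp.GradScaler``.

# Usage sketch (hypothetical training components; assumes a CUDA device is available).
import torch
import torch.nn as nn

from colossalai.legacy.amp import AMP_TYPE, convert_to_amp

model = nn.Linear(16, 16).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
criterion = nn.MSELoss()

# AMP_TYPE.TORCH wraps the components with torch.cuda.amp; an empty Config is used
# when amp_config is omitted.
model, optimizer, criterion = convert_to_amp(model, optimizer, criterion, mode=AMP_TYPE.TORCH)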
@@ -1,60 +0,0 @@
import inspect

import torch.nn as nn
from torch.optim import Optimizer

from colossalai.utils import is_no_pp_or_last_stage

from ._fp16_optimizer import FP16Optimizer
from .grad_scaler import ConstantGradScaler, DynamicGradScaler
from .naive_amp import NaiveAMPModel, NaiveAMPOptimizer


def convert_to_naive_amp(model: nn.Module, optimizer: Optimizer, amp_config):
    """A helper function to wrap training components with naive AMP modules. In this mode,
    we forcibly cast the model weights and inputs to FP16, and cast the model outputs to FP32 to calculate loss,
    which is equivalent to Apex O3.

    Args:
        model (:class:`torch.nn.Module`): your model object
        optimizer (:class:`torch.optim.Optimizer`): your optimizer object
        amp_config (:class:`colossalai.context.Config` or dict): configuration for naive mode amp.

    Returns:
        Tuple: A tuple (model, optimizer)

    The ``amp_config`` should contain parameters below::

        verbose (bool, optional): if set to `True`, will print debug info (Default: False).
        clip_grad_norm (float, optional): clip gradients with this global L2 norm (Default 0).
                                          Note that clipping is ignored if clip_grad == 0.
        dynamic_grad_scale (bool): whether to use dynamic grad scaler.
    """
    if isinstance(model, nn.ModuleList):
        # interleaved pipeline
        module_list = []
        for chunk, m in enumerate(model):
            output_to_fp32 = is_no_pp_or_last_stage() and chunk == len(model) - 1
            module_list.append(NaiveAMPModel(m, output_to_fp32=output_to_fp32))
        model = nn.ModuleList(module_list)
    else:
        output_to_fp32 = is_no_pp_or_last_stage()
        model = NaiveAMPModel(model, output_to_fp32=output_to_fp32)

    use_dynamic_grad_scaler = amp_config.pop('dynamic_grad_scale', True)
    if use_dynamic_grad_scaler:
        scaler_class = DynamicGradScaler
    else:
        scaler_class = ConstantGradScaler

    sig = inspect.signature(scaler_class.__init__)
    kwargs = dict()
    for param in sig.parameters.values():
        if param.name in amp_config:
            kwargs[param.name] = amp_config.pop(param.name)
    grad_scaler = scaler_class(**kwargs)
    optimizer = NaiveAMPOptimizer(optimizer, grad_scaler, **amp_config)
    return model, optimizer


__all__ = ['convert_to_naive_amp', 'NaiveAMPOptimizer', 'FP16Optimizer']
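A sketch of the ``amp_config`` keys documented above being consumed by ``convert_to_naive_amp``. The model and optimizer are placeholders; the grad-scaler-specific keys are whatever ``DynamicGradScaler.__init__`` accepts and are picked up via ``inspect`` as shown in the function.

# Sketch of the documented amp_config keys (placeholder model/optimizer).
import torch
import torch.nn as nn

from colossalai.context import Config
from colossalai.legacy.amp.naive_amp import convert_to_naive_amp

model = nn.Linear(16, 16).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)

amp_config = Config(dict(
    dynamic_grad_scale=True,    # popped first to choose DynamicGradScaler over ConstantGradScaler
    clip_grad_norm=1.0,         # keys not taken by the scaler are forwarded to NaiveAMPOptimizer
    verbose=False,
))
model, optimizer = convert_to_naive_amp(model, optimizer, amp_config)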
@@ -1,28 +0,0 @@
import click

from colossalai.context import Config

from .benchmark import run_benchmark
from .utils import *

__all__ = ['benchmark']


@click.command()
@click.option("-g", "--gpus", type=int, default=None, help="Total number of devices to use.")
@click.option("-b", "--batch_size", type=int, default=8, help="Batch size of the input tensor.")
@click.option("-s", "--seq_len", type=int, default=512, help="Sequence length of the input tensor.")
@click.option("-d", "--dimension", type=int, default=1024, help="Hidden dimension of the input tensor.")
@click.option("-w", "--warmup_steps", type=int, default=10, help="The number of warmup steps.")
@click.option("-p", "--profile_steps", type=int, default=50, help="The number of profiling steps.")
@click.option("-l", "--layers", type=int, default=2)
@click.option("-m",
              "--model",
              type=click.Choice(['mlp'], case_sensitive=False),
              default='mlp',
              help="Select the model to benchmark, currently only supports MLP")
def benchmark(gpus: int, batch_size: int, seq_len: int, dimension: int, warmup_steps: int, profile_steps: int,
              layers: int, model: str):
    args_dict = locals()
    args = Config(args_dict)
    run_benchmark(args)
@ -1,105 +0,0 @@
|
||||
from functools import partial
|
||||
from typing import Dict, List
|
||||
|
||||
import click
|
||||
import torch.multiprocessing as mp
|
||||
|
||||
import colossalai
|
||||
from colossalai.cli.benchmark.utils import find_all_configs, get_batch_data, profile_model
|
||||
from colossalai.context import Config
|
||||
from colossalai.context.random import reset_seeds
|
||||
from colossalai.core import global_context as gpc
|
||||
from colossalai.logging import disable_existing_loggers, get_dist_logger
|
||||
from colossalai.testing import free_port
|
||||
from colossalai.utils import MultiTimer
|
||||
|
||||
from .models import MLP
|
||||
|
||||
|
||||
def run_benchmark(args: Config) -> None:
|
||||
"""
|
||||
Run benchmarking with torch.multiprocessing.
|
||||
"""
|
||||
|
||||
# sanity checks
|
||||
if args.gpus is None:
|
||||
click.echo("Error: --num_gpus is not given")
|
||||
exit()
|
||||
if args.gpus <= 1:
|
||||
click.echo("Warning: tensor parallel will be activated with at least 2 devices.")
|
||||
|
||||
click.echo("=== Benchmarking Parameters ===")
|
||||
for k, v in args.items():
|
||||
click.echo(f'{k}: {v}')
|
||||
click.echo('')
|
||||
|
||||
config_list = find_all_configs(args.gpus)
|
||||
|
||||
avail_ports = [free_port() for _ in range(len(config_list))]
|
||||
run_func = partial(run_dist_profiling,
|
||||
world_size=args.gpus,
|
||||
port_list=avail_ports,
|
||||
config_list=config_list,
|
||||
hyperparams=args)
|
||||
mp.spawn(run_func, nprocs=args.gpus)
|
||||
|
||||
|
||||
def run_dist_profiling(rank: int, world_size: int, port_list: List[int], config_list: List[Dict],
|
||||
hyperparams: Config) -> None:
|
||||
"""
|
||||
A function executed for profiling, this function should be spawn by torch.multiprocessing.
|
||||
|
||||
Args:
|
||||
rank (int): rank of the process
|
||||
world_size (int): the number of processes
|
||||
port_list (List[int]): a list of free ports for initializing distributed networks
|
||||
config_list (List[Dict]): a list of configuration
|
||||
hyperparams (Config): the hyperparameters given by the user
|
||||
|
||||
"""
|
||||
|
||||
# disable logging for clean output
|
||||
disable_existing_loggers()
|
||||
logger = get_dist_logger()
|
||||
logger.set_level('WARNING')
|
||||
|
||||
for config, port in zip(config_list, port_list):
|
||||
colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
|
||||
timer = MultiTimer()
|
||||
|
||||
# 1D parallel should be skipped if in_features or out_features is not able to be divided exactly by 1D parallel size.
|
||||
if config.parallel.tensor.mode == '1d' and hyperparams.dimension % config.parallel.tensor.size != 0:
|
||||
click.echo(
|
||||
"1D parallel will be skipped because in_features or out_features is not able to be divided exactly by 1D parallel size."
|
||||
)
|
||||
continue
|
||||
|
||||
if hyperparams.model == 'mlp':
|
||||
model = MLP(dim=hyperparams.dimension, layers=hyperparams.layers)
|
||||
else:
|
||||
if gpc.get_global_rank() == 0:
|
||||
click.echo("Error: Invalid argument for --model")
|
||||
exit()
|
||||
|
||||
data_func = partial(get_batch_data,
|
||||
dim=hyperparams.dimension,
|
||||
batch_size=hyperparams.batch_size,
|
||||
seq_length=hyperparams.seq_len,
|
||||
mode=config.parallel.tensor.mode)
|
||||
|
||||
fwd_time, bwd_time, max_allocated, max_cached = profile_model(model=model,
|
||||
warmup_steps=hyperparams.warmup_steps,
|
||||
profile_steps=hyperparams.profile_steps,
|
||||
data_func=data_func,
|
||||
timer=timer)
|
||||
|
||||
gpc.destroy()
|
||||
reset_seeds()
|
||||
|
||||
if gpc.get_global_rank() == 0:
|
||||
config_str = ', '.join([f'{k}: {v}' for k, v in config.parallel.tensor.items()])
|
||||
click.echo(f"=== {config_str} ===")
|
||||
click.echo(f"Average forward time: {fwd_time}")
|
||||
click.echo(f"Average backward time: {bwd_time}")
|
||||
click.echo(f"Max allocated GPU memory: {max_allocated}")
|
||||
click.echo(f"Max cached GPU memory: {max_cached}\n")
|
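A sketch of driving ``run_benchmark`` directly with the same ``Config`` the click command assembles from its options (values mirror the CLI defaults, with ``gpus`` set explicitly). This relies on the ``colossalai.cli.benchmark`` modules as they existed before this commit removed them.

# Pre-removal usage sketch: build the hyperparameter Config and call run_benchmark.
from colossalai.context import Config
from colossalai.cli.benchmark.benchmark import run_benchmark

args = Config(dict(
    gpus=4,              # total number of devices to use
    batch_size=8,
    seq_len=512,
    dimension=1024,
    warmup_steps=10,
    profile_steps=50,
    layers=2,
    model='mlp',
))
run_benchmark(args)      # spawns args.gpus processes and profiles each tensor-parallel config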
@@ -1,18 +0,0 @@
import torch

import colossalai.legacy.nn as col_nn


class MLP(torch.nn.Module):

    def __init__(self, dim: int, layers: int):
        super().__init__()
        self.layers = torch.nn.ModuleList()

        for _ in range(layers):
            self.layers.append(col_nn.Linear(dim, dim))

    def forward(self, x):
        for layer in self.layers:
            x = layer(x)
        return x
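The MLP above stacks square ``col_nn.Linear(dim, dim)`` layers, so the hidden dimension is preserved end to end. A shape-check sketch follows, with ``torch.nn.Linear`` standing in for ``col_nn.Linear`` so it runs without an initialized parallel context.

# Shape-check sketch: same structure as the MLP above, plain torch.nn.Linear instead of col_nn.Linear.
import torch

layers = torch.nn.ModuleList(torch.nn.Linear(1024, 1024) for _ in range(2))
x = torch.rand(8, 512, 1024)    # (batch_size, seq_len, dim), matching the CLI defaults
for layer in layers:
    x = layer(x)
print(x.shape)                  # torch.Size([8, 512, 1024]) -- dim is preserved layer to layer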
@@ -1,159 +0,0 @@
import math
import time
from typing import Callable, Dict, List, Tuple

import torch

from colossalai.context import Config, ParallelMode
from colossalai.utils import MultiTimer


def get_time_stamp() -> int:
    """
    Return the time stamp for profiling.

    Returns:
        time_stamp (int): the time given by time.time()
    """

    torch.cuda.synchronize()
    time_stamp = time.time()
    return time_stamp


def get_memory_states() -> Tuple[float]:
    """
    Return the memory statistics.

    Returns:
        max_allocated (float): the allocated CUDA memory
        max_cached (float): the cached CUDA memory
    """

    max_allocated = torch.cuda.max_memory_allocated() / (1024**3)
    max_cached = torch.cuda.max_memory_reserved() / (1024**3)
    torch.cuda.reset_peak_memory_stats()
    torch.cuda.empty_cache()
    return max_allocated, max_cached


def find_all_configs(device_cnt: int) -> List[Dict]:
    """
    Find all possible configurations for tensor parallelism.

    Args:
        device_cnt (int): the number of devices

    Returns:
        config_list (List[Dict]): a list of configurations
    """

    def _is_square(num):
        # 2D parallel should be implemented with at least 2 devices.
        if num <= 1:
            return False
        return math.floor(math.sqrt(num))**2 == num

    def _is_cube(num):
        # 3D parallel should be implemented with at least 2 devices.
        if num <= 1:
            return False
        return math.floor(num**(1. / 3.))**3 == num

    config_list = []

    # add non-parallel config
    config = dict(parallel=dict(tensor=dict(size=device_cnt, mode=None)))
    config_list.append(config)

    # add 1D config
    config = dict(parallel=dict(tensor=dict(size=device_cnt, mode='1d')))
    config_list.append(config)

    # add 2D config only if device_cnt is a square
    if _is_square(device_cnt):
        config = dict(parallel=dict(tensor=dict(size=device_cnt, mode='2d')))
        config_list.append(config)

    # check for 2.5D
    # iterate over depth
    for depth in range(1, device_cnt):
        if device_cnt % depth == 0 and _is_square(device_cnt // depth):
            config = dict(parallel=dict(tensor=dict(size=device_cnt, mode='2.5d', depth=depth)))
            config_list.append(config)

    # check for 3D if device_cnt is a cube
    if _is_cube(device_cnt):
        config = dict(parallel=dict(tensor=dict(size=device_cnt, mode='3d')))
        config_list.append(config)

    config_list = [Config(cfg) for cfg in config_list]
    return config_list


def profile_model(model: torch.nn.Module, warmup_steps: int, profile_steps: int, data_func: Callable,
                  timer: MultiTimer) -> Tuple[float]:
    """
    Profile the forward and backward of a model.

    Args:
        model (torch.nn.Module): a PyTorch model
        warmup_steps (int): the number of steps for warmup
        profile_steps (int): the number of steps for profiling
        data_func (Callable): a function to generate random data
        timer (colossalai.utils.MultiTimer): a timer instance for time recording

    Returns:
        fwd_time (float): the average time taken by the forward pass in seconds
        bwd_time (float): the average time taken by the backward pass in seconds
        max_allocated (float): the maximum GPU memory allocated in GB
        max_cached (float): the maximum GPU memory cached in GB
    """

    def _run_step(data):
        timer.start('forward')
        out = model(data)
        timer.stop('forward', keep_in_history=True)
        timer.start('backward')
        out.mean().backward()
        timer.stop('backward', keep_in_history=True)

    data_list = [data_func() for _ in range(warmup_steps)]
    for data in data_list:
        _run_step(data)
    timer.reset('forward')
    timer.reset('backward')

    for _ in range(profile_steps):
        data = data_func()
        _run_step(data)

    max_allocated, max_cached = get_memory_states()
    fwd_time = timer.get_timer('forward').get_history_mean()
    bwd_time = timer.get_timer('backward').get_history_mean()
    return fwd_time, bwd_time, max_allocated, max_cached


def get_batch_data(dim: int, batch_size: int, seq_length: int, mode: ParallelMode) -> torch.Tensor:
    """
    Return random data of shape (batch_size, seq_length, dim) for profiling.

    Args:
        dim (int): hidden size
        batch_size (int): the number of data samples
        seq_length (int): the number of tokens
        mode (ParallelMode): Colossal-AI ParallelMode enum

    Returns:
        data (torch.Tensor): random data
    """

    if mode in ['2d', '2.5d']:
        batch_size = batch_size // 2
        dim = dim // 2
    elif mode == '3d':
        batch_size = batch_size // 4
        dim = dim // 2

    data = torch.rand(batch_size, seq_length, dim).cuda()
    return data
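A worked example of the configuration search above for 4 devices, derived directly from the code: the square check admits a 2D layout, the depth loop only admits 2.5D with depth 1, and 4 is not a perfect cube. The import path is the one used by the removed benchmark module.

# For device_cnt = 4, find_all_configs yields these tensor-parallel configs:
#   {'size': 4, 'mode': None}                # non-parallel baseline
#   {'size': 4, 'mode': '1d'}
#   {'size': 4, 'mode': '2d'}                # 4 is a perfect square
#   {'size': 4, 'mode': '2.5d', 'depth': 1}  # only depth=1 gives 4 % depth == 0 with a square quotient
# 3D is skipped because 4 is not a perfect cube.
from colossalai.cli.benchmark.utils import find_all_configs

for cfg in find_all_configs(4):
    print(cfg.parallel.tensor)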
@@ -1,6 +1,8 @@
from .config import Config, ConfigException
from .parallel_context import ParallelContext
from .parallel_mode import ParallelMode
from .moe_context import MOE_CONTEXT
from .process_group_initializer import *
from .random import *

# from .moe_context import MOE_CONTEXT

__all__ = [
    'Config',
    'ConfigException',
]
@@ -1,6 +0,0 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-

from colossalai.context.parallel_context import global_context

__all__ = ['global_context']
@@ -0,0 +1,9 @@
from .initialize import initialize, launch, launch_from_openmpi, launch_from_slurm, launch_from_torch

__all__ = [
    'launch',
    'launch_from_openmpi',
    'launch_from_slurm',
    'launch_from_torch',
    'initialize',
]
@@ -0,0 +1,54 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-

import torch.nn as nn
from torch.nn.modules.loss import _Loss
from torch.optim import Optimizer

from colossalai.context import Config

from .amp_type import AMP_TYPE
from .apex_amp import convert_to_apex_amp
from .naive_amp import convert_to_naive_amp
from .torch_amp import convert_to_torch_amp

__all__ = ['convert_to_amp', 'convert_to_naive_amp', 'convert_to_apex_amp', 'convert_to_torch_amp', 'AMP_TYPE']


def convert_to_amp(model: nn.Module, optimizer: Optimizer, criterion: _Loss, mode: AMP_TYPE, amp_config: Config = None):
    """A helper function to wrap training components with AMP modules.

    Args:
        model (:class:`torch.nn.Module`): your model object.
        optimizer (:class:`torch.optim.Optimizer`): your optimizer object.
        criterion (:class:`torch.nn.modules.loss._Loss`): your loss function object.
        mode (:class:`colossalai.legacy.amp.AMP_TYPE`): amp mode.
        amp_config (Union[:class:`colossalai.context.Config`, dict]): configuration for different amp modes.

    Returns:
        A tuple (model, optimizer, criterion).

    Note:
        ``amp_config`` may vary depending on the mode you choose. You should check the corresponding amp mode
        for more details about ``amp_config``.
        For ``apex_amp``, please check
        `apex_amp config <https://nvidia.github.io/apex/amp.html?highlight=apex%20amp>`_.
        For ``naive_amp``, please check
        `naive_amp config <https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/amp/naive_amp/_fp16_optimizer.py#L42>`_.
        For ``torch_amp``, please check
        `torch_amp config <https://github.com/pytorch/pytorch/blob/master/torch/cuda/amp/grad_scaler.py#L97>`_.
    """
    assert isinstance(mode, AMP_TYPE), \
        f'expected the argument mode to be AMP_TYPE, but got {type(mode)}'

    if amp_config is None:
        amp_config = Config()

    if mode == AMP_TYPE.TORCH:
        model, optimizer, criterion = convert_to_torch_amp(model, optimizer, criterion, amp_config)
    elif mode == AMP_TYPE.APEX:
        model, optimizer = convert_to_apex_amp(model, optimizer, amp_config)
    elif mode == AMP_TYPE.NAIVE:
        model, optimizer = convert_to_naive_amp(model, optimizer, amp_config)

    return model, optimizer, criterion
@@ -0,0 +1,60 @@
import inspect

import torch.nn as nn
from torch.optim import Optimizer

from colossalai.amp.naive_amp.grad_scaler import ConstantGradScaler, DynamicGradScaler
from colossalai.legacy.utils import is_no_pp_or_last_stage

from ._fp16_optimizer import FP16Optimizer
from .naive_amp import NaiveAMPModel, NaiveAMPOptimizer


def convert_to_naive_amp(model: nn.Module, optimizer: Optimizer, amp_config):
    """A helper function to wrap training components with naive AMP modules. In this mode,
    we forcibly cast the model weights and inputs to FP16, and cast the model outputs to FP32 to calculate loss,
    which is equivalent to Apex O3.

    Args:
        model (:class:`torch.nn.Module`): your model object
        optimizer (:class:`torch.optim.Optimizer`): your optimizer object
        amp_config (:class:`colossalai.context.Config` or dict): configuration for naive mode amp.

    Returns:
        Tuple: A tuple (model, optimizer)

    The ``amp_config`` should contain parameters below::

        verbose (bool, optional): if set to `True`, will print debug info (Default: False).
        clip_grad_norm (float, optional): clip gradients with this global L2 norm (Default 0).
                                          Note that clipping is ignored if clip_grad == 0.
        dynamic_grad_scale (bool): whether to use dynamic grad scaler.
    """
    if isinstance(model, nn.ModuleList):
        # interleaved pipeline
        module_list = []
        for chunk, m in enumerate(model):
            output_to_fp32 = is_no_pp_or_last_stage() and chunk == len(model) - 1
            module_list.append(NaiveAMPModel(m, output_to_fp32=output_to_fp32))
        model = nn.ModuleList(module_list)
    else:
        output_to_fp32 = is_no_pp_or_last_stage()
        model = NaiveAMPModel(model, output_to_fp32=output_to_fp32)

    use_dynamic_grad_scaler = amp_config.pop('dynamic_grad_scale', True)
    if use_dynamic_grad_scaler:
        scaler_class = DynamicGradScaler
    else:
        scaler_class = ConstantGradScaler

    sig = inspect.signature(scaler_class.__init__)
    kwargs = dict()
    for param in sig.parameters.values():
        if param.name in amp_config:
            kwargs[param.name] = amp_config.pop(param.name)
    grad_scaler = scaler_class(**kwargs)
    optimizer = NaiveAMPOptimizer(optimizer, grad_scaler, **amp_config)
    return model, optimizer


__all__ = ['convert_to_naive_amp', 'NaiveAMPOptimizer', 'FP16Optimizer']
@@ -0,0 +1,4 @@
from .parallel_context import ParallelContext
from .parallel_mode import ParallelMode
from .process_group_initializer import *
from .random import *
@@ -0,0 +1,6 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-

from colossalai.legacy.context.parallel_context import global_context

__all__ = ['global_context']
@@ -0,0 +1,472 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-

import argparse
import os
import pprint
from pathlib import Path
from typing import Callable, Dict, Iterable, List, Optional, Tuple, Union

import torch
import torch.nn as nn
from torch.nn.modules.loss import _Loss
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.optim.lr_scheduler import _LRScheduler
from torch.optim.optimizer import Optimizer
from torch.utils.data import DataLoader

from colossalai.context import Config, ConfigException
from colossalai.context.moe_context import MOE_CONTEXT
from colossalai.interface import OptimizerWrapper
from colossalai.legacy.amp import AMP_TYPE, convert_to_amp
from colossalai.legacy.amp.naive_amp import NaiveAMPModel
from colossalai.legacy.builder.builder import build_gradient_handler
from colossalai.legacy.context import ParallelMode
from colossalai.legacy.core import global_context as gpc
from colossalai.legacy.engine import Engine
from colossalai.legacy.engine.gradient_accumulation import accumulate_gradient
from colossalai.legacy.engine.schedule import (
    InterleavedPipelineSchedule,
    NonPipelineSchedule,
    PipelineSchedule,
    get_tensor_shape,
)
from colossalai.legacy.utils import is_using_ddp, is_using_pp, is_using_sequence, sync_model_param
from colossalai.legacy.zero import ShardedOptimizerV2, convert_to_zero_v2
from colossalai.legacy.zero.gemini.ophooks import BaseOpHook
from colossalai.logging import get_dist_logger
from colossalai.utils import get_current_device
from colossalai.utils.moe import sync_moe_model_param


def get_default_parser():
    """Reads the user command line and uses an argument parser to parse the input arguments.
    Input arguments include configuration, host, port, world size, local rank, backend for torch.distributed.

    Returns:
        Namespace: Returns the parser with the default arguments; the user may add customized arguments into this parser.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--config', type=str, help='path to the config file')
    parser.add_argument('--host', type=str, help='the master address for distributed training')
    parser.add_argument('--port', type=int, help='the master port for distributed training')
    parser.add_argument('--world_size', type=int, help='world size for distributed training')
    parser.add_argument('--rank', type=int, help='rank for the default process group')
    parser.add_argument('--local_rank', type=int, help='local rank on the node')
    parser.add_argument('--backend', type=str, default='nccl', help='backend for distributed communication')
    return parser


def launch(config: Union[str, Path, Config, Dict],
           rank: int,
           world_size: int,
           host: str,
           port: int,
           backend: str = 'nccl',
           local_rank: int = None,
           seed: int = 1024,
           verbose: bool = True):
    """This function first parses the configuration arguments, using :func:`parse_args()` in case one of the input
    arguments is not given. It then initializes and sets up the distributed environment by calling global_context's functions.

    Args:
        config (Union[str, dict, Config]): Config file or config file path are both acceptable
        rank (int): Rank for the default process group
        world_size (int): World size of the default process group
        host (str): The master address for distributed training
        port (str): The master port for distributed training
        backend (str, optional): Backend for ``torch.distributed``, defaults to ``nccl``
        local_rank (int, optional):
            Rank for the process on the node and is used to set the default CUDA device,
            defaults to None. If local_rank = None, the default device ordinal will be calculated automatically.
        seed (int, optional): Specified random seed for every process. Defaults to 1024.
        verbose (bool, optional): Whether to print logs. Defaults to True.

    Raises:
        Exception: Raise exception when config type is wrong
    """
    gpc.verbose = verbose

    # set config
    assert isinstance(config, (Config, str, Path, dict)), \
        f'expected argument config to be Config, str or Path, but got {type(config)}'
    if not isinstance(config, Config) and isinstance(config, dict):
        config = Config(config)
    if isinstance(config, (str, Path)):
        config = Config.from_file(config)
    gpc.load_config(config)

    # init default process group
    gpc.init_global_dist(rank, world_size, backend, host, port)

    # init process groups for different parallel modes from config
    gpc.init_parallel_groups()

    # set cuda device
    if torch.cuda.is_available():
        # if local rank is not given, calculate automatically
        gpc.set_device(local_rank)

    # set the number of processes running on the same node
    gpc.detect_num_processes_on_current_node()

    gpc.set_seed(seed)

    if verbose:
        logger = get_dist_logger()
        logger.info(
            f'Distributed environment is initialized, '
            f'data parallel size: {gpc.data_parallel_size}, pipeline parallel size: {gpc.pipeline_parallel_size}, '
            f'tensor parallel size: {gpc.tensor_parallel_size}',
            ranks=[0])


def launch_from_slurm(config: Union[str, Path, Config, Dict],
                      host: str,
                      port: int,
                      backend: str = 'nccl',
                      seed: int = 1024,
                      verbose: bool = True):
    """A wrapper for colossalai.launch for the SLURM launcher, reading rank and world size from the environment variables
    set by SLURM.

    Args:
        config (Union[str, dict, Config]): Config file or config file path are both acceptable
        host (str): The master address for distributed training
        port (str): The master port for distributed training
        backend (str, optional): Backend for ``torch.distributed``, defaults to ``nccl``
        seed (int, optional): Specified random seed for every process. Defaults to 1024.
        verbose (bool, optional): Whether to print logs. Defaults to True.
    """
    try:
        rank = int(os.environ['SLURM_PROCID'])
        world_size = int(os.environ['SLURM_NPROCS'])
    except KeyError as e:
        raise RuntimeError(
            f"Could not find {e} in the SLURM environment, visit https://www.colossalai.org/ for more information on launching with SLURM"
        )

    launch(config=config,
           rank=rank,
           world_size=world_size,
           host=host,
           port=port,
           backend=backend,
           seed=seed,
           verbose=verbose)


def launch_from_openmpi(config: Union[str, Path, Config, Dict],
                        host: str,
                        port: int,
                        backend: str = 'nccl',
                        seed: int = 1024,
                        verbose: bool = True):
    """A wrapper for colossalai.launch for the OpenMPI launcher, reading rank and world size from the environment variables
    set by OpenMPI.

    Args:
        config (Union[str, dict, Config]): Config file or config file path are both acceptable
        host (str): The master address for distributed training
        port (str): The master port for distributed training
        backend (str, optional): Backend for ``torch.distributed``, defaults to ``nccl``
        seed (int, optional): Specified random seed for every process. Defaults to 1024.
        verbose (bool, optional): Whether to print logs. Defaults to True.
    """
    try:
        rank = int(os.environ['OMPI_COMM_WORLD_RANK'])
        local_rank = int(os.environ['OMPI_COMM_WORLD_LOCAL_RANK'])
        world_size = int(os.environ['OMPI_COMM_WORLD_SIZE'])
    except KeyError as e:
        raise RuntimeError(
            f"Could not find {e} in the OpenMPI environment, visit https://www.colossalai.org/ for more information on launching with OpenMPI"
        )

    launch(config=config,
           local_rank=local_rank,
           rank=rank,
           world_size=world_size,
           host=host,
           port=port,
           backend=backend,
           seed=seed,
           verbose=verbose)


def launch_from_torch(config: Union[str, Path, Config, Dict],
                      backend: str = 'nccl',
                      seed: int = 1024,
                      verbose: bool = True):
    """A wrapper for colossalai.launch for torchrun or torch.distributed.launch, reading rank and world size
    from the environment variables set by PyTorch.

    Args:
        config (Union[str, dict, Config]): Config file or config file path are both acceptable
        backend (str, optional): Backend for ``torch.distributed``, defaults to ``nccl``
        seed (int, optional): Specified random seed for every process. Defaults to 1024.
        verbose (bool, optional): Whether to print logs. Defaults to True.
    """
    try:
        rank = int(os.environ['RANK'])
        local_rank = int(os.environ['LOCAL_RANK'])
        world_size = int(os.environ['WORLD_SIZE'])
        host = os.environ['MASTER_ADDR']
        port = int(os.environ['MASTER_PORT'])
    except KeyError as e:
        raise RuntimeError(
            f"Could not find {e} in the torch environment, visit https://www.colossalai.org/ for more information on launching with torch"
        )

    launch(config=config,
           local_rank=local_rank,
           rank=rank,
           world_size=world_size,
           host=host,
           port=port,
           backend=backend,
           seed=seed,
           verbose=verbose)


def initialize(model: nn.Module,
               optimizer: Optimizer,
               criterion: Optional[_Loss] = None,
               train_dataloader: Optional[Iterable] = None,
               test_dataloader: Optional[Iterable] = None,
               lr_scheduler: Optional[_LRScheduler] = None,
               ophooks: Optional[List[BaseOpHook]] = None,
               verbose: bool = True) -> Tuple[Engine, DataLoader, DataLoader, _LRScheduler]:
    """Core function to wrap the essential training components with our functionality based on the config which is
    loaded into gpc.config.

    Args:
        model (:class:`torch.nn.Module` or Callable): Your model instance or a function to build the model.
        optimizer (:class:`torch.optim.optimizer.Optimizer` or :class:`Type[torch.optim.optimizer]`):
            Your optimizer instance.
        criterion (:class:`torch.nn.modules.loss._Loss`, optional): Your criterion instance.
        train_dataloader (:class:`torch.utils.data.DataLoader`, optional): Dataloader for training.
        test_dataloader (:class:`torch.utils.data.DataLoader`, optional): Dataloader for testing.
        lr_scheduler (:class:`torch.nn.lr_scheduler._LRScheduler`, optional): Your lr scheduler instance, optional.
        verbose (bool, optional): Whether to print logs.

    Returns:
        Tuple (engine, train_dataloader, test_dataloader, lr_scheduler):
            A tuple of ``(engine, train_dataloader, test_dataloader, lr_scheduler)``
            where only ``engine`` is guaranteed not to be None.
    """
    # get logger
    logger = get_dist_logger()
    gpc.verbose = verbose

    # get config from gpc
    config = gpc.config

    # print config
    if verbose:
        logger.info(
            f"\n========== Your Config ========\n"
            f"{pprint.pformat(gpc.config)}\n"
            f"================================\n",
            ranks=[0])

    # cudnn
    cudnn_benchmark = config.get('cudnn_benchmark', False)
    cudnn_deterministic = config.get('cudnn_deterministic', False)
    torch.backends.cudnn.benchmark = cudnn_benchmark
    torch.backends.cudnn.deterministic = cudnn_deterministic
    if verbose:
        logger.info(f"cuDNN benchmark = {cudnn_benchmark}, deterministic = {cudnn_deterministic}", ranks=[0])

    # zero
    use_zero = hasattr(gpc.config, 'zero')
    if use_zero:
        zero_cfg = gpc.config.get('zero', None)
        if zero_cfg is not None:
            cfg_ = zero_cfg.copy()
        else:
            cfg_ = {}
        optimizer_config = zero_cfg.get('optimizer_config', None)
        model_config = zero_cfg.get('model_config', None)
        model, optimizer = convert_to_zero_v2(model,
                                              optimizer,
                                              model_config=model_config,
                                              optimizer_config=optimizer_config)

        logger.info("Initializing ZeRO model and optimizer finished!", ranks=[0])
    else:
        if isinstance(model, nn.Module):
            # first sync model across dp ranks
            model.to(get_current_device())
        elif isinstance(model, Callable):
            model = model().to(get_current_device())

        # optimizer may be an optimizer class instead of an instance
        if isinstance(optimizer, Callable):
            optimizer = optimizer(model.parameters())
            logger.warning("Initializing a non-ZeRO model with an optimizer class")

    if not use_zero:
        if is_using_sequence():
            sync_model_param(model, ParallelMode.SEQUENCE_DP)
        elif MOE_CONTEXT.is_initialized:
            sync_moe_model_param(model)
        elif is_using_ddp():
            sync_model_param(model, ParallelMode.DATA)
        else:
            logger.warning(
                "The parameters of the model are not automatically synchronized.\n"
                "Please make sure that all parameters are the same in the data parallel group.",
                ranks=[0])

    # check amp and zero
    fp16_cfg = gpc.config.get('fp16', None)

    if fp16_cfg is not None and fp16_cfg.mode is not None and use_zero:
        raise ConfigException(
            "It is not allowed to set fp16 and zero configuration in your config file at the same time")

    # clip grad norm
    clip_grad_norm = gpc.config.get('clip_grad_norm', 0.0)

    # initialize amp
    amp_mode = None
    if fp16_cfg is not None and fp16_cfg.mode is not None:
        cfg_ = fp16_cfg.copy()
        amp_mode = cfg_.pop('mode')
        if is_using_pp():
            assert amp_mode == AMP_TYPE.NAIVE, 'Pipeline only supports NaiveAMP currently'
        if amp_mode == AMP_TYPE.NAIVE:
            cfg_['clip_grad_norm'] = clip_grad_norm
        model, optimizer, criterion = convert_to_amp(model=model,
                                                     optimizer=optimizer,
                                                     criterion=criterion,
                                                     mode=amp_mode,
                                                     amp_config=cfg_)

    # get torch ddp config
    torch_ddp_cfg = gpc.config.get('torch_ddp', dict())

    # gradient handler
    gradient_handler_cfg = gpc.config.get('gradient_handler', None)
    if gradient_handler_cfg is None:
        # if gradient handler is not specified in the configuration file,
        # check in the following order
        # 1. if optimizer is ZERO, then use zero grad handler
        # 2. if dp size is larger than 1 and pipeline is not used, use pytorch ddp
        # 3. if using pipeline and dp size larger than 1, use data parallel grad handler
        if isinstance(optimizer, ShardedOptimizerV2):
            gradient_handler_cfg = [dict(type='ZeROGradientHandler')]
            if verbose:
                logger.info(
                    "Training with zero is detected, ZeROGradientHandler is automatically "
                    "added even though not specified in the configuration",
                    ranks=[0])
        elif is_using_ddp() and MOE_CONTEXT.is_initialized:
            gradient_handler_cfg = [dict(type='MoeGradientHandler')]
            if verbose:
                logger.info(
                    "Data parallel training is detected with moe parallel, MoeGradientHandler is automatically "
                    "added even though not specified in the configuration",
                    ranks=[0])
        elif is_using_sequence():
            model = DDP(model,
                        process_group=gpc.get_group(ParallelMode.SEQUENCE_DP),
                        device_ids=[torch.cuda.current_device()],
                        **torch_ddp_cfg)
            if verbose:
                logger.info('Model is using torch.nn.parallel.DistributedDataParallel for Sequence Parallelism',
                            ranks=[0])
        elif is_using_ddp() and not is_using_pp() and amp_mode != AMP_TYPE.NAIVE:
            model = DDP(model,
                        process_group=gpc.get_group(ParallelMode.DATA),
                        device_ids=[torch.cuda.current_device()],
                        **torch_ddp_cfg)
            if verbose:
                logger.info('Model is using torch.nn.parallel.DistributedDataParallel for Data Parallelism', ranks=[0])
        elif is_using_ddp():
            gradient_handler_cfg = [dict(type='DataParallelGradientHandler')]
            if verbose:
                logger.info(
                    "Data parallel training is detected when using pipeline parallel, "
                    "DataParallelGradientHandler is automatically "
                    "added even though not specified in the configuration",
                    ranks=[0])
        # add pipeline parallel gradient handler, if pipeline shared module is detected
        for param in model.parameters():
            if getattr(param, 'pipeline_shared_module_pg', None) is not None:
                if gradient_handler_cfg is None:
                    gradient_handler_cfg = [dict(type='PipelineSharedModuleGradientHandler')]
                else:
                    gradient_handler_cfg.append(dict(type='PipelineSharedModuleGradientHandler'))
                if verbose:
                    logger.info(
                        "pipeline_shared_module is detected, PipelineSharedModuleGradientHandler is automatically "
                        "added even though not specified in the configuration",
                        ranks=[0])
                break
    else:
        if not isinstance(gradient_handler_cfg, list):
            raise ConfigException(
                f"expected gradient_handler in the configuration file to be a list but got {type(gradient_handler_cfg)}"
            )

    # turn off sync buffer for NaiveAMPModel if using torch DDP and NaiveAMPModel at the same time
    # to avoid duplicated buffer synchronization
    if isinstance(model, DDP) and isinstance(model.module, NaiveAMPModel):
        model.module.sync_buffer = False

    # initialize schedule for engine
    if is_using_pp():
        tensor_shape = get_tensor_shape()
        use_interleaved = hasattr(gpc.config, 'model') and hasattr(gpc.config.model, 'num_chunks')
        if gpc.is_initialized(ParallelMode.PARALLEL_1D):
            scatter_gather = True
        else:
            scatter_gather = False
        if use_interleaved:
            if isinstance(model, nn.Sequential):
                model = nn.ModuleList([model])
            schedule = InterleavedPipelineSchedule(gpc.config.NUM_MICRO_BATCHES,
                                                   gpc.config.model.num_chunks,
                                                   tensor_shape=tensor_shape,
                                                   scatter_gather_tensors=scatter_gather)
        else:
            schedule = PipelineSchedule(gpc.config.NUM_MICRO_BATCHES,
                                        tensor_shape=tensor_shape,
                                        scatter_gather_tensors=scatter_gather)
    else:
        schedule = NonPipelineSchedule()

    if gradient_handler_cfg is None:
        gradient_handlers = None
        if verbose and not isinstance(model, DDP):
            logger.warning(
                "No PyTorch DDP or gradient handler is set up, please make sure you do not need "
                "to all-reduce the gradients after a training step.",
                ranks=[0])
    else:
        gradient_handlers = [build_gradient_handler(cfg, model, optimizer) for cfg in gradient_handler_cfg]

    # check if optimizer is OptimizerWrapper
    if not isinstance(optimizer, (OptimizerWrapper, ShardedOptimizerV2)):
        optimizer = OptimizerWrapper(optim=optimizer)

    # gradient accumulation
    grad_accum_size = gpc.config.get('gradient_accumulation', None)
    if grad_accum_size is not None:
        optimizer, train_dataloader, gradient_handlers, lr_scheduler = accumulate_gradient(
            model=model,
            optimizer=optimizer,
            dataloader=train_dataloader,
            accumulate_size=grad_accum_size,
            gradient_handlers=gradient_handlers,
            lr_scheduler=lr_scheduler)
    engine = Engine(model=model,
                    optimizer=optimizer,
                    criterion=criterion,
                    gradient_handlers=gradient_handlers,
                    clip_grad_norm=clip_grad_norm,
                    ophook_list=ophooks,
                    schedule=schedule)

    return engine, train_dataloader, test_dataloader, lr_scheduler
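A hedged end-to-end sketch tying together the entry points added above: ``launch_from_torch`` followed by ``initialize``. The model, loss, and batch are placeholders, and the engine's ``train``/``zero_grad``/``backward``/``step`` calls are assumed from the legacy ``Engine`` imported above rather than shown in this diff.

# End-to-end sketch (run under torchrun so RANK/WORLD_SIZE/MASTER_ADDR/... are set).
import torch
import torch.nn as nn

import colossalai.legacy as legacy

legacy.launch_from_torch(config=dict())    # empty config: plain data-parallel defaults

model = nn.Linear(32, 32).cuda()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
criterion = nn.MSELoss()

engine, *_ = legacy.initialize(model, optimizer, criterion)

# Assumed legacy Engine training-step API (placeholder synthetic batch).
engine.train()
data = torch.rand(8, 32).cuda()
label = torch.rand(8, 32).cuda()
engine.zero_grad()
output = engine(data)
loss = engine.criterion(output, label)
engine.backward(loss)
engine.step()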
@@ -1,4 +1,3 @@
from ._ops import *
from .layer import *
from .loss import *
from .metric import *
@@ -1,9 +1 @@
from .addmm import colo_addmm
from .batch_norm import colo_batch_norm
from .element_wise import *
from .embedding import colo_embedding
from .embedding_bag import colo_embedding_bag
from .layernorm import colo_layernorm
from .linear import colo_linear
from .loss import colo_cross_entropy
from .view import colo_view
from ._utils import *
@@ -1,90 +0,0 @@
import torch

from colossalai.tensor import ColoTensor, ColoTensorSpec, ComputePattern, ComputeSpec, ReplicaSpec, ShardSpec, distspec
from colossalai.tensor.op_wrapper import colo_op_impl

from ._utils import GeneralTensor, Number, convert_to_colo_tensor, reduce_grad, reduce_input


def colo_addmm_1Drow(input_tensor: ColoTensor, mat1: ColoTensor, mat2: ColoTensor, beta: Number,
                     alpha: Number) -> ColoTensor:
    # mat1:S[1] x mat2:S[0] = Output:P
    # beta * input + alpha * All-Reduce(Output) = res

    mat1 = mat1.redistribute(ShardSpec([-1], [mat2.get_tp_world_size()]), mat2.get_process_group())

    # Output:P
    partial_output = torch.mm(mat1, mat2)
    # Reduce(Output)
    output = reduce_input(partial_output, mat2.get_process_group())
    # input
    assert not input_tensor.has_compute_spec(), 'Invalid input spec for 1Drow addmm op'
    output = beta * input_tensor + alpha * output
    output = ColoTensor.from_torch_tensor(output, spec=ColoTensorSpec(input_tensor.get_process_group()))
    return output


def colo_addmm_1Dcol(input_tensor: ColoTensor, mat1: ColoTensor, mat2: ColoTensor, beta: Number,
                     alpha: Number) -> ColoTensor:
    # mat1:B x mat2:S[1] + input:S[1] = Output:S[1]
    compute_spec = mat2.compute_spec
    mat1 = mat1.redistribute(ReplicaSpec())
    mat1 = reduce_grad(mat1, mat1.get_process_group())

    output_parallel = torch.addmm(input_tensor, mat1, mat2, beta=beta, alpha=alpha)
    output_spec = ColoTensorSpec(input_tensor.get_process_group(), ShardSpec([-1], [mat2.get_tp_world_size()]),
                                 ComputeSpec(ComputePattern.TP1D))
    output = ColoTensor.from_torch_tensor(output_parallel, spec=output_spec)

    if compute_spec.output_replicate:
        return output.to_replicate()
    else:
        return output


def colo_addmm_1d(mode: str, input_tensor: ColoTensor, mat1: ColoTensor, mat2: ColoTensor, beta: Number,
                  alpha: Number) -> ColoTensor:
    assert mode in ('row', 'col')
    funcs = {'row': colo_addmm_1Drow, 'col': colo_addmm_1Dcol}
    return funcs[mode](input_tensor, mat1, mat2, beta, alpha)


@colo_op_impl(torch.addmm)
def colo_addmm(input_tensor: GeneralTensor,
               mat1: ColoTensor,
               mat2: ColoTensor,
               beta: Number = 1,
               alpha: Number = 1,
               **kargs) -> ColoTensor:
    """Handles ``__torch_function__`` dispatch for ``torch.addmm``.
    This method computes a fused matrix multiply and add.
    """
    # At least one of the tensors should be a ColoTensor
    assert isinstance(mat2, ColoTensor)
    input_tensor = convert_to_colo_tensor(input_tensor, mat2.get_process_group())
    mat1 = convert_to_colo_tensor(mat1, mat2.get_process_group())

    # Add communication logic before and after the linear call.
    ret_tensor = None
    if not mat2.has_compute_spec():    # No Model Parallel Applied
        assert mat2.is_replicate(), 'Invalid mat2 spec for native addmm op'
        assert input_tensor.is_replicate(), 'Invalid input spec for native addmm op'
        ret_tensor = ColoTensor.from_torch_tensor(tensor=torch.addmm(input_tensor,
                                                                     mat1,
                                                                     mat2,
                                                                     beta=beta,
                                                                     alpha=alpha,
                                                                     **kargs),
                                                  spec=ColoTensorSpec(mat2.get_process_group()))
    elif mat2.has_compute_pattern(ComputePattern.TP1D):    # Single Model Parallel Applied
        if mat2.is_shard_1drow() and input_tensor.is_replicate():
            mode = 'row'
        elif mat2.is_shard_1dcol() and (input_tensor.is_shard_1dcol() or input_tensor.is_shard_1drow()):
            mode = 'col'
        else:
            raise NotImplementedError
        ret_tensor = colo_addmm_1d(mode, input_tensor, mat1, mat2, beta, alpha)
    else:
        raise NotImplementedError

    return ret_tensor
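A dispatch sketch for the removed module above: because of the ``@colo_op_impl(torch.addmm)`` registration, a plain ``torch.addmm`` call on ColoTensors is routed into ``colo_addmm`` via ``__torch_function__``. The shapes and the default process group are placeholders, and the sketch relies on the pre-removal ``colossalai.tensor`` API.

# Dispatch sketch (placeholder shapes; default process group, no tensor parallelism,
# so the native no-compute-spec branch above is taken).
import torch

from colossalai.tensor import ColoTensor, ColoTensorSpec, ProcessGroup

pg = ProcessGroup()
spec = ColoTensorSpec(pg)

bias = ColoTensor.from_torch_tensor(torch.rand(16), spec)
mat1 = ColoTensor.from_torch_tensor(torch.rand(8, 32), spec)
mat2 = ColoTensor.from_torch_tensor(torch.rand(32, 16), spec)

out = torch.addmm(bias, mat1, mat2)    # handled by colo_addmm through __torch_function__
print(type(out))                       # ColoTensor carrying the same process group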
@@ -1,33 +0,0 @@
from typing import Optional

import torch.nn.functional as F

from colossalai.tensor import ColoTensor, ColoTensorSpec, ReplicaSpec
from colossalai.tensor.op_wrapper import colo_op_impl

from ._utils import GeneralTensor, convert_to_colo_tensor


@colo_op_impl(F.batch_norm)
def colo_batch_norm(
    input: GeneralTensor,
    running_mean: Optional[GeneralTensor],
    running_var: Optional[GeneralTensor],
    weight: Optional[GeneralTensor] = None,
    bias: Optional[GeneralTensor] = None,
    training: bool = False,
    momentum: float = 0.1,
    eps: float = 1e-5,
):
    assert isinstance(weight, ColoTensor)
    running_mean = running_mean.detach()
    running_var = running_var.detach()

    input = convert_to_colo_tensor(input, weight.get_process_group())
    bias = convert_to_colo_tensor(bias, weight.get_process_group())
    input = input.redistribute(ReplicaSpec())
    bias = bias.redistribute(ReplicaSpec())

    output = F.batch_norm(input, running_mean, running_var, weight, bias, training, momentum, eps)
    output = ColoTensor.from_torch_tensor(tensor=output, spec=ColoTensorSpec(pg=weight.get_process_group()))
    return output
@ -1,250 +0,0 @@
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from torch import Tensor
|
||||
|
||||
from colossalai.tensor import ColoTensor, ColoTensorSpec
|
||||
from colossalai.tensor.op_wrapper import colo_op_impl
|
||||
|
||||
from ._utils import GeneralTensor, convert_to_colo_tensor
|
||||
|
||||
|
||||
def register_elementwise_op(op):
|
||||
|
||||
@colo_op_impl(op)
|
||||
def elementwise_op(input_tensor: GeneralTensor, *args, **kwargs):
|
||||
"""
|
||||
Handles ``__torch_function__`` dispatch for the elementwise op such
|
||||
as ``torch.nn.functional.gelu`` or ``torch.nn.functional.relu``.
|
||||
This method computes on either a normal tensor or a sharded tensor.
|
||||
"""
|
||||
if 'inplace' in kwargs:
|
||||
# TODO(jiaruifang) inplace will cause bugs
|
||||
input_tensor = input_tensor.clone()
|
||||
return op(input_tensor, *args, **kwargs)
|
||||
else:
|
||||
output = op(input_tensor, *args, **kwargs)
|
||||
# return output
|
||||
if isinstance(input_tensor, ColoTensor):
|
||||
if isinstance(output, str):
|
||||
return output
|
||||
if not isinstance(output, torch.Tensor):
|
||||
raise NotImplementedError
|
||||
return ColoTensor.from_torch_tensor(output,
|
||||
spec=ColoTensorSpec(input_tensor.get_process_group(),
|
||||
dist_attr=input_tensor.dist_spec))
|
||||
|
||||
|
||||
# @colo_op_impl(torch.relu_)
|
||||
# def elementwise_op(input_tensor):
|
||||
# torch.relu_(input_tensor.data)
|
||||
# return input_tensor
|
||||
|
||||
# @colo_op_impl(Tensor.add_)
|
||||
# def elementwise_op(input_tensor: ColoTensor, *args, **kwargs):
|
||||
# input_tensor = input_tensor.data.add_(*args, **kwargs)
|
||||
# return input_tensor
|
||||
|
||||
# Tensor op
|
||||
register_elementwise_op(Tensor.abs)
|
||||
register_elementwise_op(Tensor.absolute)
|
||||
register_elementwise_op(Tensor.acos)
|
||||
register_elementwise_op(Tensor.arccos)
|
||||
register_elementwise_op(Tensor.angle)
|
||||
register_elementwise_op(Tensor.asin)
|
||||
register_elementwise_op(Tensor.arcsin)
|
||||
register_elementwise_op(Tensor.atan)
|
||||
register_elementwise_op(Tensor.arctan)
|
||||
register_elementwise_op(Tensor.all)
|
||||
register_elementwise_op(Tensor.any)
|
||||
register_elementwise_op(Tensor.bernoulli)
|
||||
register_elementwise_op(Tensor.bfloat16)
|
||||
register_elementwise_op(Tensor.bitwise_not)
|
||||
register_elementwise_op(Tensor.bool)
|
||||
register_elementwise_op(Tensor.byte)
|
||||
register_elementwise_op(Tensor.ceil)
|
||||
register_elementwise_op(Tensor.char)
|
||||
register_elementwise_op(Tensor.clamp)
|
||||
register_elementwise_op(Tensor.clamp_max)
|
||||
register_elementwise_op(Tensor.clamp_min)
|
||||
register_elementwise_op(Tensor.clip)
|
||||
register_elementwise_op(Tensor.clone)
|
||||
register_elementwise_op(Tensor.contiguous)
|
||||
register_elementwise_op(Tensor.copysign)
|
||||
register_elementwise_op(Tensor.cos)
|
||||
register_elementwise_op(Tensor.cosh)
|
||||
register_elementwise_op(Tensor.acosh)
|
||||
register_elementwise_op(Tensor.arccosh)
|
||||
register_elementwise_op(Tensor.cpu)
|
||||
register_elementwise_op(Tensor.cuda)
|
||||
register_elementwise_op(Tensor.deg2rad)
|
||||
register_elementwise_op(Tensor.detach)
|
||||
register_elementwise_op(Tensor.digamma)
|
||||
register_elementwise_op(Tensor.double)
|
||||
register_elementwise_op(Tensor.erf)
|
||||
register_elementwise_op(Tensor.erfc)
|
||||
register_elementwise_op(Tensor.erfinv)
|
||||
register_elementwise_op(Tensor.exp)
|
||||
register_elementwise_op(Tensor.expm1)
|
||||
register_elementwise_op(Tensor.fix)
|
||||
register_elementwise_op(Tensor.trunc)
|
||||
register_elementwise_op(Tensor.float)
|
||||
register_elementwise_op(Tensor.float_power)
|
||||
register_elementwise_op(Tensor.floor)
|
||||
register_elementwise_op(Tensor.frac)
|
||||
register_elementwise_op(Tensor.half)
|
||||
register_elementwise_op(Tensor.hardshrink)
|
||||
register_elementwise_op(Tensor.heaviside)
|
||||
register_elementwise_op(Tensor.i0)
|
||||
register_elementwise_op(Tensor.int)
|
||||
register_elementwise_op(Tensor.isfinite)
|
||||
register_elementwise_op(Tensor.isinf)
|
||||
register_elementwise_op(Tensor.isposinf)
|
||||
register_elementwise_op(Tensor.isneginf)
|
||||
register_elementwise_op(Tensor.isnan)
|
||||
register_elementwise_op(Tensor.lgamma)
|
||||
register_elementwise_op(Tensor.log)
|
||||
register_elementwise_op(Tensor.log10)
|
||||
register_elementwise_op(Tensor.log1p)
|
||||
register_elementwise_op(Tensor.log2)
|
||||
register_elementwise_op(Tensor.logical_not)
|
||||
register_elementwise_op(Tensor.logit)
|
||||
register_elementwise_op(Tensor.long)
|
||||
register_elementwise_op(Tensor.nan_to_num)
|
||||
register_elementwise_op(Tensor.neg)
|
||||
register_elementwise_op(Tensor.negative)
|
||||
register_elementwise_op(Tensor.positive)
|
||||
register_elementwise_op(Tensor.pow)
|
||||
register_elementwise_op(Tensor.rad2deg)
|
||||
register_elementwise_op(Tensor.reciprocal)
|
||||
register_elementwise_op(Tensor.round)
|
||||
register_elementwise_op(Tensor.rsqrt)
|
||||
register_elementwise_op(Tensor.short)
|
||||
register_elementwise_op(Tensor.sigmoid)
|
||||
register_elementwise_op(Tensor.sign)
|
||||
register_elementwise_op(Tensor.signbit)
|
||||
register_elementwise_op(Tensor.sgn)
|
||||
register_elementwise_op(Tensor.sin)
|
||||
register_elementwise_op(Tensor.sinc)
|
||||
register_elementwise_op(Tensor.sinh)
|
||||
register_elementwise_op(Tensor.asinh)
|
||||
register_elementwise_op(Tensor.arcsinh)
|
||||
register_elementwise_op(Tensor.sqrt)
|
||||
register_elementwise_op(Tensor.square)
|
||||
register_elementwise_op(Tensor.to)
|
||||
register_elementwise_op(Tensor.tan)
|
||||
register_elementwise_op(Tensor.tanh)
|
||||
register_elementwise_op(Tensor.atanh)
|
||||
register_elementwise_op(Tensor.arctanh)
|
||||
register_elementwise_op(Tensor.type)
|
||||
register_elementwise_op(Tensor.type_as)
|
||||
|
||||
# torch OP
register_elementwise_op(torch.abs)
register_elementwise_op(torch.absolute)
register_elementwise_op(torch.acos)
register_elementwise_op(torch.arccos)
register_elementwise_op(torch.angle)
register_elementwise_op(torch.asin)
register_elementwise_op(torch.arcsin)
register_elementwise_op(torch.atan)
register_elementwise_op(torch.arctan)
register_elementwise_op(torch.all)
register_elementwise_op(torch.any)
register_elementwise_op(torch.bernoulli)
register_elementwise_op(torch.bitwise_not)
register_elementwise_op(torch.ceil)
register_elementwise_op(torch.clamp)
register_elementwise_op(torch.clamp_max)
register_elementwise_op(torch.clamp_min)
register_elementwise_op(torch.clip)
register_elementwise_op(torch.clone)
register_elementwise_op(torch.copysign)
register_elementwise_op(torch.cos)
register_elementwise_op(torch.cosh)
register_elementwise_op(torch.acosh)
register_elementwise_op(torch.arccosh)
register_elementwise_op(torch.deg2rad)
register_elementwise_op(torch.digamma)
register_elementwise_op(torch.erf)
register_elementwise_op(torch.erfc)
register_elementwise_op(torch.erfinv)
register_elementwise_op(torch.exp)
register_elementwise_op(torch.expm1)
register_elementwise_op(torch.fix)
register_elementwise_op(torch.trunc)
register_elementwise_op(torch.float_power)
register_elementwise_op(torch.floor)
register_elementwise_op(torch.frac)
register_elementwise_op(torch.hardshrink)
register_elementwise_op(torch.heaviside)
register_elementwise_op(torch.i0)
register_elementwise_op(torch.isfinite)
register_elementwise_op(torch.isinf)
register_elementwise_op(torch.isposinf)
register_elementwise_op(torch.isneginf)
register_elementwise_op(torch.isnan)
register_elementwise_op(torch.lgamma)
register_elementwise_op(torch.log)
register_elementwise_op(torch.log10)
register_elementwise_op(torch.log1p)
register_elementwise_op(torch.log2)
register_elementwise_op(torch.logical_not)
register_elementwise_op(torch.logit)
register_elementwise_op(torch.nan_to_num)
register_elementwise_op(torch.neg)
register_elementwise_op(torch.negative)
register_elementwise_op(torch.positive)
register_elementwise_op(torch.pow)
register_elementwise_op(torch.rad2deg)
register_elementwise_op(torch.reciprocal)
register_elementwise_op(torch.round)
register_elementwise_op(torch.rsqrt)
register_elementwise_op(torch.sigmoid)
register_elementwise_op(torch.sign)
register_elementwise_op(torch.signbit)
register_elementwise_op(torch.sgn)
register_elementwise_op(torch.sin)
register_elementwise_op(torch.sinc)
register_elementwise_op(torch.sinh)
register_elementwise_op(torch.asinh)
register_elementwise_op(torch.arcsinh)
register_elementwise_op(torch.sqrt)
register_elementwise_op(torch.square)
register_elementwise_op(torch.tan)
register_elementwise_op(torch.tanh)
register_elementwise_op(torch.atanh)
register_elementwise_op(torch.arctanh)
register_elementwise_op(torch.zeros_like)

# nn.functional OP
register_elementwise_op(F.threshold)
register_elementwise_op(F.relu)
register_elementwise_op(F.hardtanh)
register_elementwise_op(F.hardswish)
register_elementwise_op(F.relu6)
register_elementwise_op(F.elu)
register_elementwise_op(F.selu)
register_elementwise_op(F.celu)
register_elementwise_op(F.leaky_relu)
register_elementwise_op(F.prelu)
register_elementwise_op(F.rrelu)
register_elementwise_op(F.gelu)
register_elementwise_op(F.logsigmoid)
register_elementwise_op(F.hardshrink)
register_elementwise_op(F.tanhshrink)
register_elementwise_op(F.softsign)
register_elementwise_op(F.softplus)
register_elementwise_op(F.softmin)
register_elementwise_op(F.softmax)
register_elementwise_op(F.softshrink)
register_elementwise_op(F.gumbel_softmax)
register_elementwise_op(F.log_softmax)
register_elementwise_op(F.tanh)
register_elementwise_op(F.sigmoid)
register_elementwise_op(F.hardsigmoid)
register_elementwise_op(F.silu)
register_elementwise_op(F.mish)
# TODO(ver217): dropout handles seed
register_elementwise_op(F.dropout)
register_elementwise_op(F.alpha_dropout)
register_elementwise_op(F.feature_alpha_dropout)
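These registrations rely on torch's ``__torch_function__`` protocol: ops called on a ColoTensor are intercepted and the registered element-wise handler re-attaches the distribution spec to the result (the ``register_elementwise_op`` helper itself is not shown in this excerpt). A minimal self-contained illustration of that dispatch mechanism, using plain PyTorch only rather than ColossalAI's actual implementation:

import torch


class TaggedTensor(torch.Tensor):
    """Toy tensor subclass: every torch op is routed through __torch_function__."""

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}
        out = super().__torch_function__(func, types, args, kwargs)
        # an element-wise wrapper would re-attach the tensor's distribution spec here
        return out


x = torch.randn(3).as_subclass(TaggedTensor)
y = torch.abs(x)    # intercepted by TaggedTensor.__torch_function__
assert isinstance(y, TaggedTensor)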
@ -1,142 +0,0 @@
from typing import Optional

import torch.nn.functional as F

from colossalai.tensor import ColoTensor, ColoTensorSpec, ComputePattern, ComputeSpec, ReplicaSpec, ShardSpec
from colossalai.tensor.op_wrapper import colo_op_impl

from ._utils import GeneralTensor, convert_to_colo_tensor, reduce_input


def colo_embedding_1Dcol(input_tensor: ColoTensor,
                         weight: ColoTensor,
                         padding_idx: Optional[int] = None,
                         max_norm: Optional[float] = None,
                         norm_type: float = 2.0,
                         scale_grad_by_freq: bool = False,
                         sparse: bool = False) -> ColoTensor:
    # embedding_1Dcol splits the weight (lookup table) into shape (num_embeddings, embedding_dim/P)
    # and gathers the split lookup table afterwards
    input_tensor = input_tensor.redistribute(ReplicaSpec())

    output_parallel = F.embedding(input_tensor,
                                  weight,
                                  padding_idx=padding_idx,
                                  max_norm=max_norm,
                                  norm_type=norm_type,
                                  scale_grad_by_freq=scale_grad_by_freq,
                                  sparse=sparse)
    output_spec = ColoTensorSpec(weight.get_process_group(), ShardSpec([-1], [weight.get_tp_world_size()]),
                                 ComputeSpec(ComputePattern.TP1D))
    output = ColoTensor.from_torch_tensor(output_parallel, spec=output_spec)

    compute_spec = weight.compute_spec

    if compute_spec.output_replicate:
        return output.to_replicate()
    else:
        return output


def colo_embedding_1Drow(input_tensor: ColoTensor,
                         weight: ColoTensor,
                         padding_idx: Optional[int] = None,
                         max_norm: Optional[float] = None,
                         norm_type: float = 2.0,
                         scale_grad_by_freq: bool = False,
                         sparse: bool = False) -> ColoTensor:
    # embedding_1Drow splits the weight (lookup table) into shape [num_embeddings/P, embedding_dim],
    # looks up the indices owned by the current segment and masks the other segments with 0

    # get the complete input tensor through all-gather
    input_tensor = input_tensor.redistribute(ReplicaSpec())

    # tensor_parallel_rank = gpc.get_local_rank(ParallelMode.PARALLEL_1D)
    tensor_parallel_rank = weight.get_process_group().tp_local_rank()
    num_embeddings_per_partition = weight.size_local(0)
    vocab_start_index = tensor_parallel_rank * num_embeddings_per_partition
    vocab_end_index = vocab_start_index + num_embeddings_per_partition

    # build the mask.
    input_mask = (input_tensor < vocab_start_index) | (input_tensor >= vocab_end_index)
    # mask the input.
    # TODO(jzy) masked_input may be an activation managed by ColoTensor.
    masked_input = input_tensor - vocab_start_index
    masked_input[input_mask] = 0

    partial_output = F.embedding(masked_input,
                                 weight,
                                 padding_idx=padding_idx,
                                 max_norm=max_norm,
                                 norm_type=norm_type,
                                 scale_grad_by_freq=scale_grad_by_freq,
                                 sparse=sparse)

    # Mask the output embedding.
    partial_output[input_mask, :] = 0.
    # Reduce across all the model parallel GPUs.
    output = reduce_input(partial_output, weight.get_process_group())
    output = ColoTensor.from_torch_tensor(output, spec=ColoTensorSpec(weight.get_process_group(), ReplicaSpec()))
    return output


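The masking arithmetic in colo_embedding_1Drow can be checked in isolation. A small single-process sketch (plain torch, hypothetical sizes, no ColoTensor or process group involved) of what one rank owning vocab rows [4, 8) would compute before the all-reduce:

import torch
import torch.nn.functional as F

vocab_start_index, vocab_end_index = 4, 8    # this rank owns vocab rows 4..7
weight_shard = torch.randn(4, 16)            # (num_embeddings / P, embedding_dim)

input_ids = torch.tensor([[1, 5, 7], [9, 4, 2]])
input_mask = (input_ids < vocab_start_index) | (input_ids >= vocab_end_index)

masked_input = input_ids - vocab_start_index    # shift into local row indices
masked_input[input_mask] = 0                    # out-of-range ids temporarily hit row 0
partial = F.embedding(masked_input, weight_shard)
partial[input_mask, :] = 0.                     # ...and their rows are zeroed afterwards
# summing `partial` over all ranks (the reduce_input above) recovers the full lookup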
def colo_embedding_1d(mode: str,
                      input_tensor: ColoTensor,
                      weight: ColoTensor,
                      padding_idx: Optional[int] = None,
                      max_norm: Optional[float] = None,
                      norm_type: float = 2.0,
                      scale_grad_by_freq: bool = False,
                      sparse: bool = False) -> ColoTensor:
    assert mode in ('row', 'col')
    funcs = {'row': colo_embedding_1Drow, 'col': colo_embedding_1Dcol}
    return funcs[mode](input_tensor,
                       weight,
                       padding_idx=padding_idx,
                       max_norm=max_norm,
                       norm_type=norm_type,
                       scale_grad_by_freq=scale_grad_by_freq,
                       sparse=sparse)


@colo_op_impl(F.embedding)
def colo_embedding(input_tensor: GeneralTensor,
                   weight: GeneralTensor,
                   padding_idx: Optional[int] = None,
                   max_norm: Optional[float] = None,
                   norm_type: float = 2.0,
                   scale_grad_by_freq: bool = False,
                   sparse: bool = False):
    """Handles ``__torch_function__`` dispatch for ``torch.nn.functional.embedding``.
    This method looks up an embedding table.
    """
    assert isinstance(weight, ColoTensor)
    input_tensor = convert_to_colo_tensor(input_tensor, weight.get_process_group())

    if not weight.has_compute_spec():    # No Model Parallel Applied
        assert weight.is_replicate(), 'Invalid weight spec for native embedding op'
        return ColoTensor.from_torch_tensor(tensor=F.embedding(input_tensor,
                                                               weight,
                                                               padding_idx=padding_idx,
                                                               max_norm=max_norm,
                                                               norm_type=norm_type,
                                                               scale_grad_by_freq=scale_grad_by_freq,
                                                               sparse=sparse),
                                            spec=ColoTensorSpec(weight.get_process_group()))
    elif weight.has_compute_pattern(ComputePattern.TP1D):    # Single Model Parallel Applied
        if weight.is_shard_1drow():
            mode = 'row'
        elif weight.is_shard_1dcol():
            mode = 'col'
        else:
            raise NotImplementedError
        return colo_embedding_1d(mode,
                                 input_tensor,
                                 weight,
                                 padding_idx=padding_idx,
                                 max_norm=max_norm,
                                 norm_type=norm_type,
                                 scale_grad_by_freq=scale_grad_by_freq,
                                 sparse=sparse)
    else:
        raise NotImplementedError
@ -1,127 +0,0 @@
from typing import Optional

import torch.nn.functional as F
from torch import Tensor

from colossalai.tensor import ColoTensor, ColoTensorSpec, ComputePattern, ComputeSpec, ReplicaSpec, ShardSpec, distspec
from colossalai.tensor.op_wrapper import colo_op_impl

from ._utils import GeneralTensor, convert_to_colo_tensor


def colo_embedding_bag_1Dcol(input_tensor: ColoTensor,
                             weight: ColoTensor,
                             offsets: Optional[Tensor] = None,
                             max_norm: Optional[float] = None,
                             norm_type: float = 2,
                             scale_grad_by_freq: bool = False,
                             mode: str = "mean",
                             sparse: bool = False,
                             per_sample_weights: Optional[Tensor] = None,
                             include_last_offset: bool = False,
                             padding_idx: Optional[int] = None) -> ColoTensor:
    # embedding_bag_1Dcol splits the weight (lookup table) into shape (num_embeddings, embedding_dim/P)
    # and gathers the split lookup table afterwards
    pg = weight.get_process_group()
    input_tensor = input_tensor.redistribute(ReplicaSpec())

    output_parallel = F.embedding_bag(input_tensor,
                                      weight,
                                      offsets=offsets,
                                      max_norm=max_norm,
                                      norm_type=norm_type,
                                      scale_grad_by_freq=scale_grad_by_freq,
                                      mode=mode,
                                      sparse=sparse,
                                      per_sample_weights=per_sample_weights,
                                      include_last_offset=include_last_offset,
                                      padding_idx=padding_idx)
    output_spec = ColoTensorSpec(pg, ShardSpec([-1], [weight.get_tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
    output = ColoTensor.from_torch_tensor(output_parallel, spec=output_spec)

    if weight.compute_spec.output_replicate:
        return output.to_replicate()
    else:
        return output


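The reason the column shard works is that an embedding-bag reduction over a slice of the embedding dimension equals the same slice of the full reduction, so per-rank outputs only need to be concatenated (the gather performed by to_replicate()). A plain-torch sketch of that identity with hypothetical sizes:

import torch
import torch.nn.functional as F

weight = torch.randn(10, 8)                   # full (num_embeddings, embedding_dim) table
ids = torch.tensor([[1, 2, 4], [4, 3, 2]])    # two bags of three ids each

w0, w1 = weight[:, :4], weight[:, 4:]         # shard embedding_dim across two "ranks"
out = torch.cat([F.embedding_bag(ids, w0), F.embedding_bag(ids, w1)], dim=-1)

assert torch.allclose(out, F.embedding_bag(ids, weight), atol=1e-6)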
def colo_embedding_bag_1d(tp_mode: str,
                          input_tensor: ColoTensor,
                          weight: ColoTensor,
                          offsets: Optional[Tensor] = None,
                          max_norm: Optional[float] = None,
                          norm_type: float = 2,
                          scale_grad_by_freq: bool = False,
                          mode: str = "mean",
                          sparse: bool = False,
                          per_sample_weights: Optional[Tensor] = None,
                          include_last_offset: bool = False,
                          padding_idx: Optional[int] = None) -> ColoTensor:
    assert tp_mode in ('col',)
    funcs = {'col': colo_embedding_bag_1Dcol}
    return funcs[tp_mode](input_tensor,
                          weight,
                          offsets=offsets,
                          max_norm=max_norm,
                          norm_type=norm_type,
                          scale_grad_by_freq=scale_grad_by_freq,
                          mode=mode,
                          sparse=sparse,
                          per_sample_weights=per_sample_weights,
                          include_last_offset=include_last_offset,
                          padding_idx=padding_idx)


@colo_op_impl(F.embedding_bag)
def colo_embedding_bag(input_tensor: GeneralTensor,
                       weight: GeneralTensor,
                       offsets: Optional[Tensor] = None,
                       max_norm: Optional[float] = None,
                       norm_type: float = 2,
                       scale_grad_by_freq: bool = False,
                       mode: str = "mean",
                       sparse: bool = False,
                       per_sample_weights: Optional[Tensor] = None,
                       include_last_offset: bool = False,
                       padding_idx: Optional[int] = None):
    """Handles ``__torch_function__`` dispatch for ``torch.nn.functional.embedding_bag``.
    This method looks up an embedding table.
    """
    assert isinstance(weight, ColoTensor)
    input_tensor = convert_to_colo_tensor(input_tensor, weight.get_process_group())

    # Handle different parallel actions.

    if not weight.has_compute_spec():    # No Model Parallel Applied
        assert weight.is_replicate(), 'Invalid weight spec for native embedding op'
        return ColoTensor.from_torch_tensor(tensor=F.embedding_bag(input_tensor,
                                                                   weight,
                                                                   offsets=offsets,
                                                                   max_norm=max_norm,
                                                                   norm_type=norm_type,
                                                                   scale_grad_by_freq=scale_grad_by_freq,
                                                                   mode=mode,
                                                                   sparse=sparse,
                                                                   per_sample_weights=per_sample_weights,
                                                                   include_last_offset=include_last_offset,
                                                                   padding_idx=padding_idx),
                                            spec=ColoTensorSpec(weight.get_process_group()))
    elif weight.has_compute_pattern(ComputePattern.TP1D):    # Single Model Parallel Applied
        if weight.is_shard_1dcol():
            tp_mode = 'col'
        else:
            raise NotImplementedError
        return colo_embedding_bag_1d(tp_mode,
                                     input_tensor,
                                     weight,
                                     offsets=offsets,
                                     max_norm=max_norm,
                                     norm_type=norm_type,
                                     scale_grad_by_freq=scale_grad_by_freq,
                                     mode=mode,
                                     sparse=sparse,
                                     per_sample_weights=per_sample_weights,
                                     include_last_offset=include_last_offset,
                                     padding_idx=padding_idx)
    else:
        raise NotImplementedError
@ -1,28 +0,0 @@
from typing import List, Optional

import torch.nn.functional as F

from colossalai.tensor import ColoTensor, ColoTensorSpec, ReplicaSpec, distspec
from colossalai.tensor.op_wrapper import colo_op_impl

from ._utils import GeneralTensor, convert_to_colo_tensor


@colo_op_impl(F.layer_norm)
def colo_layernorm(
    input_tensor: GeneralTensor,
    normalized_shape: List[int],
    weight: Optional[GeneralTensor] = None,
    bias: Optional[GeneralTensor] = None,
    eps: float = 1e-5,
):
    assert isinstance(weight, ColoTensor)
    input_tensor = convert_to_colo_tensor(input_tensor, weight.get_process_group())
    bias = convert_to_colo_tensor(bias, weight.get_process_group())
    input_tensor = input_tensor.redistribute(ReplicaSpec())

    output = F.layer_norm(input_tensor, normalized_shape, weight=weight, bias=bias, eps=eps)
    output = ColoTensor.from_torch_tensor(tensor=output,
                                          spec=ColoTensorSpec(pg=input_tensor.get_process_group(),
                                                              dist_attr=input_tensor.dist_spec))
    return output
@ -1,171 +0,0 @@
from copy import deepcopy
from typing import Optional

import torch.nn.functional as F

from colossalai.tensor import ColoTensor, ColoTensorSpec, ComputePattern, ComputeSpec, ReplicaSpec, ShardSpec
from colossalai.tensor.op_wrapper import colo_op_impl
from colossalai.tensor.sharding_spec import ShardingSpec

from ._utils import GeneralTensor, convert_to_colo_tensor, reduce_grad, reduce_input


def colo_linear_1drow(input_tensor: ColoTensor, weight: ColoTensor, bias: Optional[ColoTensor]) -> 'ColoTensor':
    # Input:S[1] x Weight:S[0] = Output:P
    # All-Reduce(Output) + bias = res
    # Input:S[1]
    pg = weight.get_process_group()
    input_tensor = input_tensor.redistribute(ShardSpec([-1], [weight.get_tp_world_size()]), pg)

    # Output:P
    partial_output = F.linear(input_tensor, weight)
    # Reduce(Output)

    output = reduce_input(partial_output, pg)
    # Bias
    if bias is not None:
        assert not bias.has_compute_spec(), 'Invalid bias spec for 1Drow Linear op'
        output = output + bias

    output = ColoTensor.from_torch_tensor(output, spec=ColoTensorSpec(pg, ReplicaSpec()))
    return output


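The spec comments above (input sharded along its last dim, weight sharded along its input dim, partial outputs all-reduced) boil down to splitting the contraction dimension of a matmul. A plain-torch sketch of that identity with two hypothetical shards:

import torch
import torch.nn.functional as F

x = torch.randn(2, 8)          # (batch, in_features)
w = torch.randn(6, 8)          # (out_features, in_features), F.linear layout

x0, x1 = x[:, :4], x[:, 4:]    # shard the contraction dim across two "ranks"
w0, w1 = w[:, :4], w[:, 4:]

partial = F.linear(x0, w0) + F.linear(x1, w1)    # the sum is what the all-reduce performs
assert torch.allclose(partial, F.linear(x, w), atol=1e-5)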
def colo_linear_1dcol(input_tensor: ColoTensor, weight: ColoTensor, bias: Optional[ColoTensor]) -> 'ColoTensor':
    # Input:B x Weight:S[1] + Bias:S[1] = Output:S[1]
    # All-Gather(Output)
    # Input:B
    compute_spec = weight.compute_spec
    input_tensor = input_tensor.redistribute(ReplicaSpec())
    input_parallel = reduce_grad(input_tensor, weight.get_process_group())

    output_parallel = F.linear(input_parallel, weight, bias)
    output = ColoTensor.from_torch_tensor(output_parallel,
                                          spec=ColoTensorSpec(weight.get_process_group(),
                                                              ShardSpec([-1], [weight.get_tp_world_size()]),
                                                              ComputeSpec(ComputePattern.TP1D)))
    if compute_spec.output_replicate:
        return output.to_replicate()
    else:
        return output


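In the column-parallel case each rank holds a slice of the output features (and of the bias), so local results are concatenated rather than summed; the gather happens in to_replicate() when output_replicate is set. The analogous plain-torch identity:

import torch
import torch.nn.functional as F

x = torch.randn(2, 8)
w = torch.randn(6, 8)
b = torch.randn(6)

w0, w1 = w[:3], w[3:]          # shard out_features across two "ranks"
b0, b1 = b[:3], b[3:]

out = torch.cat([F.linear(x, w0, b0), F.linear(x, w1, b1)], dim=-1)
assert torch.allclose(out, F.linear(x, w, b), atol=1e-5)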
def colo_linear_1d(mode: str, input_tensor: ColoTensor, weight: ColoTensor, bias: Optional[ColoTensor]) -> 'ColoTensor':
    assert mode in ('row', 'col')
    funcs = {'row': colo_linear_1drow, 'col': colo_linear_1dcol}
    return funcs[mode](input_tensor, weight, bias)


# @register_colo_graph(input_pos=[1], param_pos=[2, 3])
def colo_linear_imp(input_tensor: GeneralTensor,
                    weight: GeneralTensor,
                    bias: Optional[GeneralTensor] = None) -> 'ColoTensor':
    """Handles ``__torch_function__`` dispatch for ``torch.nn.functional.linear``.
    This method computes a linear layer.
    """
    assert isinstance(weight, ColoTensor)
    pg = weight.get_process_group()
    assert pg
    input_tensor = convert_to_colo_tensor(input_tensor, pg)
    bias = convert_to_colo_tensor(bias, pg)
    # input_tensor, weight, bias = tuple(map(convert_to_colo_tensor, (input_tensor, weight, bias)))

    # Add communication logic before and after linear call.
    ret_tensor = None
    if not weight.has_compute_spec():    # No Model Parallel Applied
        assert weight.is_replicate(), 'Invalid weight spec for native Linear op'
        assert bias is None or bias.is_replicate(), 'Invalid bias spec for native Linear op'
        ret_tensor = ColoTensor.from_torch_tensor(F.linear(input_tensor, weight, bias), spec=ColoTensorSpec(pg))
    elif weight.has_compute_pattern(ComputePattern.TP1D):    # Single Model Parallel Applied
        # F.linear stores the weight as (out_features, in_features), so a 1D-col shard of the stored
        # weight shards the contraction dim, i.e. row-parallel linear, and vice versa.
        if weight.is_shard_1dcol() and (bias is None or bias.is_replicate()):
            mode = 'row'
        elif weight.is_shard_1drow() and (bias is None or bias.is_shard_1drow() or bias.is_shard_1dcol()):
            mode = 'col'
        else:
            raise RuntimeError(f"the weight or bias tensor spec is not valid, weight {weight}, bias {bias}")
        ret_tensor = colo_linear_1d(mode, input_tensor, weight, bias)
    else:
        raise NotImplementedError

    return ret_tensor


def _new_colo_linear_imp(input_tensor: GeneralTensor,
                         weight: GeneralTensor,
                         bias: Optional[GeneralTensor] = None) -> 'ColoTensor':
    """
    A tentative function to compute the distributed linear layer with the latest sharding spec.
    This function is subject to future change as the current sharding API is not stable.
    """
    # get mesh info
    input_sharding_seq = input_tensor.sharding_spec.sharding_sequence
    weight_sharding_seq = weight.sharding_spec.sharding_sequence
    if bias is not None:
        bias_sharding_seq = bias.sharding_spec.sharding_sequence
    device_mesh = weight.sharding_spec.device_mesh
    pg_axis0 = weight.pg_axis0
    pg_axis1 = weight.pg_axis1

    # the last dim of the input should have the same spec as the first dim of the weight;
    # the weight is transposed, so we look at its second dimension
    assert input_sharding_seq[-1] == weight_sharding_seq[1]

    if bias is not None:
        assert bias_sharding_seq[0] == weight_sharding_seq[0]

    # compute the output sharding sequence
    # as the weight is transposed, we look at its first dimension
    output_shard_seq = input_sharding_seq[:-1] + weight_sharding_seq[:1]
    output_shard_seq = deepcopy(output_shard_seq)

    # TODO: add reduce grad logic

    # handle column- and row-parallel linear by reusing the implementation above
    out = F.linear(input_tensor, weight)

    # run all-reduce if necessary
    last_dim_spec = input_sharding_seq[-1]
    if last_dim_spec.is_replica:
        pass
    elif last_dim_spec.shard_list is not None:
        for dim in last_dim_spec.shard_list:
            if dim == 0:
                reduce_input(out, pg_axis0)
            elif dim == 1:
                reduce_input(out, pg_axis1)
            else:
                raise RuntimeError(f"Found invalid sharding axis {dim}, only 0 or 1 is expected")
    # add bias
    if bias is not None:
        out += bias

    # convert the shard sequence to a partition dict
    output_partition_dict = {}
    for index, dim_spec in enumerate(output_shard_seq):
        if not dim_spec.is_replica:
            if index not in output_partition_dict:
                output_partition_dict[index] = []
            output_partition_dict[index].extend(dim_spec.shard_list)

    entire_shape = out.shape
    output_sharding_spec = ShardingSpec(device_mesh, entire_shape, output_partition_dict)
    ret_tensor = ColoTensor.from_torch_tensor(out)
    setattr(ret_tensor, 'sharding_spec', output_sharding_spec)
    return ret_tensor


def _has_sharding_spec(tensor):
    """
    A tentative function to check whether the tensor is using the new sharding spec API. We assume that the
    sharding spec object is set as the attribute `sharding_spec` on a tensor.
    """
    return hasattr(tensor, 'sharding_spec')


@colo_op_impl(F.linear)
def colo_linear(input: GeneralTensor, weight: GeneralTensor, bias: Optional[GeneralTensor] = None) -> 'ColoTensor':
    if _has_sharding_spec(weight):
        return _new_colo_linear_imp(input, weight, bias)
    else:
        return colo_linear_imp(input, weight, bias)
@ -1,51 +0,0 @@
from typing import Optional

import torch
import torch.nn.functional as F

from colossalai.legacy.nn.loss.loss_1d import VocabParallelCrossEntropyLoss1D
from colossalai.tensor import ColoTensor, ColoTensorSpec
from colossalai.tensor.op_wrapper import colo_op_impl

from ._utils import GeneralTensor, convert_to_colo_tensor


@colo_op_impl(F.cross_entropy)
def colo_cross_entropy(input_tensor: GeneralTensor,
                       target: GeneralTensor,
                       weight: Optional[GeneralTensor] = None,
                       size_average: Optional[bool] = None,
                       ignore_index: int = -100,
                       reduce: Optional[bool] = None,
                       reduction: str = "mean",
                       label_smoothing: float = 0.0):
    assert isinstance(weight, ColoTensor) or isinstance(target, ColoTensor) or isinstance(input_tensor, ColoTensor)
    pg = input_tensor.get_process_group() if isinstance(input_tensor, ColoTensor) else target.get_process_group()
    weight = convert_to_colo_tensor(weight, pg)
    target = convert_to_colo_tensor(target, pg)
    input_tensor = convert_to_colo_tensor(input_tensor, pg)

    if input_tensor.is_replicate():    # Input is gathered
        assert target.is_replicate() and (weight is None or weight.is_replicate()), \
            "Target tensor and weight tensor both should be complete"
        output = F.cross_entropy(input_tensor,
                                 target,
                                 weight=weight,
                                 size_average=size_average,
                                 ignore_index=ignore_index,
                                 reduce=reduce,
                                 reduction=reduction,
                                 label_smoothing=label_smoothing)
        return ColoTensor.from_torch_tensor(output, ColoTensorSpec(pg))
    elif input_tensor.has_compute_spec():    # Single Model Parallel Applied
        if input_tensor.is_shard_1dcol():
            assert weight is None, "Current TP cross entropy loss function doesn't support passing weight tensor in"
            assert target.is_replicate(), "Target tensor should be complete in TP cross entropy loss function"
            output = VocabParallelCrossEntropyLoss1D()(input_tensor,
                                                       target,
                                                       process_group=input_tensor.process_group.tp_process_group())
            return ColoTensor.from_torch_tensor(output, ColoTensorSpec(pg))
        else:
            raise NotImplementedError
    else:
        raise NotImplementedError
@ -1,96 +0,0 @@
import operator
from functools import reduce
from typing import Optional, Union

import torch

from colossalai.tensor import ColoTensor, ColoTensorSpec, ReplicaSpec
from colossalai.tensor.op_wrapper import colo_op_impl


def _all_int(my_iter):
    return all(isinstance(i, int) for i in my_iter)


def _get_valid_shape(shape):
    if isinstance(shape, list):
        if _all_int(shape):
            return tuple(shape)
        else:
            raise RuntimeError("expects type(int) but finds another type")
    elif isinstance(shape, tuple):
        if _all_int(shape):
            return shape
        else:
            return _get_valid_shape(shape[0])
    else:
        raise RuntimeError("expects an iterable array but finds '{}'".format(type(shape)))


def _shape_infer(org_sp, tgt_sp):
    cnt = 0
    pos = 0
    for idx, dim in enumerate(tgt_sp):
        if dim < -1:
            raise RuntimeError("invalid shape dimension {}".format(dim))
        elif dim == -1:
            cnt += 1
            pos = idx

    if cnt > 1:
        raise RuntimeError("only one dimension can be inferred")

    org_prod = reduce(operator.mul, org_sp, 1)
    tgt_prod = reduce(operator.mul, tgt_sp, 1)

    if cnt == 0:
        if org_prod != tgt_prod:
            raise RuntimeError("shape '{}' is invalid for input of size {}".format(tgt_sp, org_prod))
        else:
            return tgt_sp
    elif org_prod % tgt_prod != 0:
        raise RuntimeError("shape '{}' is invalid for input of size {}".format(tgt_sp, org_prod))

    infer_dim = -(org_prod // tgt_prod)
    return tgt_sp[:pos] + (infer_dim,) + tgt_sp[pos + 1:]


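A quick sanity check of the inference helper above (hypothetical shapes, assuming the functions in this file are in scope): a -1 entry is resolved so that the element count matches the global size.

# global size (4, 6) has 24 elements; the -1 below must resolve to 4
assert _shape_infer((4, 6), (2, -1, 3)) == (2, 4, 3)
# with no -1, the target shape only has to preserve the element count
assert _shape_infer((4, 6), (3, 8)) == (3, 8)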
@colo_op_impl(torch.Tensor.view)
def colo_view(self: ColoTensor, *shape) -> 'ColoTensor':
    """Handles ``__torch_function__`` dispatch for ``torch.Tensor.view``.
    Changes the shape of the current tensor.
    """
    assert isinstance(self, ColoTensor)
    # apply the original `view` function for replicated colo tensors
    if self.is_replicate():
        return self.view(*shape)

    cur_sp = self.size()
    org_sp = self.size_global()
    # parse the passed arguments
    tgt_sp = _get_valid_shape(shape)
    # get the correct shape from inference
    inf_sp = _shape_infer(org_sp, tgt_sp)

    if self.is_shard_1drow() and org_sp[0] == inf_sp[0]:
        new_shape = (cur_sp[0],) + tgt_sp[1:]
        res = self.view(*new_shape)
    elif self.is_shard_1dcol() and org_sp[-1] == inf_sp[-1]:
        new_shape = tgt_sp[:-1] + (cur_sp[-1],)
        res = self.view(*new_shape)
    else:
        replicated_t = self.redistribute(dist_spec=ReplicaSpec())
        return ColoTensor.from_torch_tensor(tensor=replicated_t.view(*shape),
                                            spec=ColoTensorSpec(self.get_process_group()))

    return ColoTensor.from_torch_tensor(tensor=res,
                                        spec=ColoTensorSpec(pg=self.get_process_group(), dist_attr=self.dist_spec))


@colo_op_impl(torch.Tensor.size)
def colo_size(self: ColoTensor, dim: Optional[int] = None) -> Union[torch.Size, int]:
    size = self.size_global()
    if dim is None:
        return size
    else:
        return size[dim]