mirror of https://github.com/hpcaitech/ColossalAI
[pipeline] refactor the pipeline module (#1087)
* [pipeline] refactor the pipeline module * polish codepull/1098/head
parent
bad5d4c0a1
commit
2b2dc1c86b
|
@ -1,12 +1,5 @@
|
||||||
from .builder import (build_schedule, build_lr_scheduler, build_model,
|
from .builder import build_from_config, build_from_registry, build_gradient_handler
|
||||||
build_optimizer, build_layer, build_loss, build_hooks,
|
|
||||||
build_dataset, build_transform, build_data_sampler,
|
|
||||||
build_gradient_handler, build_ophooks)
|
|
||||||
from .pipeline import build_pipeline_model, build_pipeline_model_from_cfg
|
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
'build_schedule', 'build_lr_scheduler', 'build_model', 'build_optimizer',
|
'build_gradient_handler', 'build_from_config', 'build_from_registry'
|
||||||
'build_layer', 'build_loss', 'build_hooks', 'build_dataset',
|
|
||||||
'build_transform', 'build_data_sampler', 'build_gradient_handler',
|
|
||||||
'build_pipeline_model', 'build_pipeline_model_from_cfg', 'build_ophooks'
|
|
||||||
]
|
]
|
||||||
|
|
|
@ -2,7 +2,6 @@
|
||||||
# -*- encoding: utf-8 -*-
|
# -*- encoding: utf-8 -*-
|
||||||
|
|
||||||
import inspect
|
import inspect
|
||||||
from collections.abc import Iterable
|
|
||||||
|
|
||||||
from colossalai.registry import *
|
from colossalai.registry import *
|
||||||
|
|
||||||
|
@ -64,84 +63,6 @@ def build_from_registry(config, registry: Registry):
|
||||||
|
|
||||||
return obj
|
return obj
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def build_layer(config):
|
|
||||||
"""Returns a layer object of :class:`nn.Module` constructed from `config`.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
config (dict or :class:`colossalai.context.Config`): A python dict or
|
|
||||||
a :class:`colossalai.context.Config` object containing information
|
|
||||||
used in the construction of the ``LAYERS``.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
An object of :class:`torch.nn.Module`
|
|
||||||
"""
|
|
||||||
return build_from_registry(config, LAYERS)
|
|
||||||
|
|
||||||
|
|
||||||
def build_loss(config):
|
|
||||||
"""Returns a loss function object of :class:`torch.autograd.Function` constructed
|
|
||||||
from `config`.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
config (dict or :class:`colossalai.context.Config`): A python dict or
|
|
||||||
a :class:`colossalai.context.Config` object containing information
|
|
||||||
used in the construction of the ``LOSSES``.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
An object of :class:`torch.nn.modules.loss._Loss`
|
|
||||||
"""
|
|
||||||
return build_from_registry(config, LOSSES)
|
|
||||||
|
|
||||||
|
|
||||||
def build_model(config):
|
|
||||||
"""Returns a model object of :class:`nn.Module` constructed from `config`.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
config (dict or :class:`colossalai.context.Config`): A python dict or
|
|
||||||
a :class:`colossalai.context.Config` object containing information
|
|
||||||
used in the construction of the ``MODELS``.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
An object of :class:`torch.nn.Module`
|
|
||||||
"""
|
|
||||||
return build_from_registry(config, MODELS)
|
|
||||||
|
|
||||||
|
|
||||||
def build_dataset(config):
|
|
||||||
"""Returns a dataset object of :class:`torch.utils.data.Dataset` constructed
|
|
||||||
from `config`.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
config (dict or :class:`colossalai.context.Config`): A python dict or
|
|
||||||
a :class:`colossalai.context.Config` object containing information
|
|
||||||
used in the construction of the ``DATASETS``.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
An object of :class:`torch.utils.data.Dataset`
|
|
||||||
"""
|
|
||||||
return build_from_registry(config, DATASETS)
|
|
||||||
|
|
||||||
|
|
||||||
def build_optimizer(config, model):
|
|
||||||
"""Returns an optimizer object of :class:`torch.optim.Optimizer` constructed from `config`,
|
|
||||||
'model' and 'params'.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
config (dict or :class:`colossalai.context.Config`): A python dict or
|
|
||||||
a :class:`colossalai.context.Config` object containing information
|
|
||||||
used in the construction of the ``OPTIMIZERS``.
|
|
||||||
model (:class:`nn.Module`): A model containing parameters for the optimizer
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
An object of :class:`torch.optim.Optimizer`
|
|
||||||
"""
|
|
||||||
config_ = config.copy()
|
|
||||||
config_['params'] = model.parameters()
|
|
||||||
return build_from_registry(config_, OPTIMIZERS)
|
|
||||||
|
|
||||||
|
|
||||||
def build_gradient_handler(config, model, optimizer):
|
def build_gradient_handler(config, model, optimizer):
|
||||||
"""Returns a gradient handler object of :class:`BaseGradientHandler` constructed from `config`,
|
"""Returns a gradient handler object of :class:`BaseGradientHandler` constructed from `config`,
|
||||||
`model` and `optimizer`.
|
`model` and `optimizer`.
|
||||||
|
@ -160,100 +81,3 @@ def build_gradient_handler(config, model, optimizer):
|
||||||
config_['model'] = model
|
config_['model'] = model
|
||||||
config_['optimizer'] = optimizer
|
config_['optimizer'] = optimizer
|
||||||
return build_from_registry(config_, GRADIENT_HANDLER)
|
return build_from_registry(config_, GRADIENT_HANDLER)
|
||||||
|
|
||||||
|
|
||||||
def build_hooks(config, trainer):
|
|
||||||
"""Returns a hook object of :class:`BaseHook` constructed from `config` and `trainer`.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
config (dict or :class:`colossalai.context.Config`): A python dict or
|
|
||||||
a :class:`colossalai.context.Config` object containing information
|
|
||||||
used in the construction of the ``HOOKS``.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
An object of :class:`colossalai.trainer.hooks.BaseHook`
|
|
||||||
"""
|
|
||||||
config_ = config.copy()
|
|
||||||
config_['trainer'] = trainer
|
|
||||||
return build_from_registry(config_, HOOKS)
|
|
||||||
|
|
||||||
|
|
||||||
def build_ophooks(config):
|
|
||||||
"""Returns a hook object of :class:`BaseOpHook` constructed from `config`.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
config (dict or :class:`colossalai.context.Config`): A python dict or
|
|
||||||
a :class:`colossalai.context.Config` object containing information
|
|
||||||
used in the construction of the ``OPHOOKS``.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
An object of :class:`colossalai.trainer.hooks.BaseOpHook`
|
|
||||||
"""
|
|
||||||
config_ = config.copy()
|
|
||||||
return build_from_registry(config_, OPHOOKS)
|
|
||||||
|
|
||||||
|
|
||||||
def build_transform(config):
|
|
||||||
"""Returns a transformation object of :class:`torchvision.transforms` constructed
|
|
||||||
from `config`.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
config (dict or :class:`colossalai.context.Config`): A python dict or
|
|
||||||
a :class:`colossalai.context.Config` object containing information
|
|
||||||
used in the construction of the ``TRANSFORMS``.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
An object of :class:`torchvision.transforms`
|
|
||||||
"""
|
|
||||||
return build_from_registry(config, TRANSFORMS)
|
|
||||||
|
|
||||||
|
|
||||||
def build_data_sampler(config, dataset):
|
|
||||||
"""Returns a data sampler object of :class:`colossalai.nn.data.sampler.BaseSampler`
|
|
||||||
constructed from `config`.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
config (dict or :class:`colossalai.context.Config`): A python dict or
|
|
||||||
a :class:`colossalai.context.Config` object containing information
|
|
||||||
used in the construction of the ``DATA_SAMPLERS``.
|
|
||||||
dataset (:class:`torch.utils.data.Dataset`): An object of
|
|
||||||
:class:`torch.utils.data.Dataset` containing information
|
|
||||||
used in the construction of the return object
|
|
||||||
Returns:
|
|
||||||
An object of :class:`colossalai.utils.data_sampler.BaseSampler`
|
|
||||||
"""
|
|
||||||
config_ = config.copy()
|
|
||||||
config_['dataset'] = dataset
|
|
||||||
return build_from_registry(config_, DATA_SAMPLERS)
|
|
||||||
|
|
||||||
|
|
||||||
def build_lr_scheduler(config, optimizer):
|
|
||||||
"""Returns a learning rate scheduler object of :class:`torch.optim.lr_scheduler`
|
|
||||||
constructed from `config`, `optimizer`, `total_steps` and `num_steps_per_epoch`.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
config (dict or :class:`colossalai.context.Config`): A python dict or
|
|
||||||
a :class:`colossalai.context.Config` object containing information
|
|
||||||
used in the construction of the ``lr_schedule``.
|
|
||||||
optimizer (:class:`torch.optim.Optimizer`): An optimizer object containing
|
|
||||||
parameters for the learning rate scheduler.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
An object of :class:`torch.optim.lr_scheduler`
|
|
||||||
"""
|
|
||||||
config_ = config.copy()
|
|
||||||
config_['optimizer'] = optimizer
|
|
||||||
return build_from_registry(config_, LR_SCHEDULERS)
|
|
||||||
|
|
||||||
def build_schedule(config):
|
|
||||||
"""Returns a schedule of :class:`colossalai.engine.schedule.BaseSchedule`.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
config (dict or :class:`colossalai.context.Config`): A python dict or
|
|
||||||
a :class:`colossalai.context.Config` object containing information
|
|
||||||
used in the construction of the ``Schedule``.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
An object of :class:`colossalai.engine.schedule.BaseSchedule`
|
|
||||||
"""
|
|
||||||
return build_from_registry(config, SCHEDULE)
|
|
||||||
|
|
|
@ -1,266 +0,0 @@
|
||||||
import copy
|
|
||||||
import heapq
|
|
||||||
|
|
||||||
from colossalai.builder import build_model, build_layer
|
|
||||||
from colossalai.context.parallel_mode import ParallelMode
|
|
||||||
from colossalai.core import global_context as gpc
|
|
||||||
from colossalai.logging import get_dist_logger
|
|
||||||
import torch.nn as nn
|
|
||||||
|
|
||||||
|
|
||||||
def _binary_partition(weights, st, ed):
|
|
||||||
"""Returns the binary partition position of `weights`, given the start
|
|
||||||
position `st` and the end position `ed`.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
weights (list): A python list to be binary partitioned
|
|
||||||
st (int): the start position of the binary partition
|
|
||||||
ed (int): the end position of the binary partition
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
int: the binary partition position of `weights`
|
|
||||||
"""
|
|
||||||
w_sum = weights[ed - 1]
|
|
||||||
prefix = 0
|
|
||||||
if st > 0:
|
|
||||||
w_sum -= weights[st - 1]
|
|
||||||
prefix = weights[st - 1]
|
|
||||||
minimum = float("inf")
|
|
||||||
for idx in range(st + 1, ed):
|
|
||||||
front = weights[idx - 1] - prefix
|
|
||||||
diff = abs(w_sum - 2 * front)
|
|
||||||
if diff < minimum:
|
|
||||||
pos = idx
|
|
||||||
minimum = diff
|
|
||||||
|
|
||||||
return st, pos, ed
|
|
||||||
|
|
||||||
|
|
||||||
def _heap_addition(weights, intervals, add_cnt):
|
|
||||||
"""
|
|
||||||
"""
|
|
||||||
|
|
||||||
def _heap_push(heap, st, ed):
|
|
||||||
value = weights[ed - 1]
|
|
||||||
if st > 0:
|
|
||||||
value -= weights[st - 1]
|
|
||||||
heapq.heappush(heap, (-value, st, ed))
|
|
||||||
|
|
||||||
ret_intervals = []
|
|
||||||
heap = []
|
|
||||||
|
|
||||||
for st, ed in intervals:
|
|
||||||
_heap_push(heap, st, ed)
|
|
||||||
|
|
||||||
while add_cnt > 0:
|
|
||||||
_, st, ed = heapq.heappop(heap)
|
|
||||||
if ed - st == 1:
|
|
||||||
ret_intervals.append((st, ed))
|
|
||||||
else:
|
|
||||||
l, m, r = _binary_partition(weights, st, ed)
|
|
||||||
_heap_push(heap, l, m)
|
|
||||||
_heap_push(heap, m, r)
|
|
||||||
add_cnt -= 1
|
|
||||||
|
|
||||||
while heap:
|
|
||||||
_, st, ed = heapq.heappop(heap)
|
|
||||||
ret_intervals.append((st, ed))
|
|
||||||
|
|
||||||
ret_intervals.sort()
|
|
||||||
return ret_intervals
|
|
||||||
|
|
||||||
|
|
||||||
def _calc_partitions(weights, value):
|
|
||||||
prev = 0
|
|
||||||
prefix = 0
|
|
||||||
num_block = 0
|
|
||||||
intervals = []
|
|
||||||
|
|
||||||
for idx, w in enumerate(weights):
|
|
||||||
if weights[idx] - prefix > value:
|
|
||||||
intervals.append((prev, idx))
|
|
||||||
prev = idx
|
|
||||||
prefix = weights[idx - 1]
|
|
||||||
num_block += 1
|
|
||||||
|
|
||||||
intervals.append((prev, len(weights)))
|
|
||||||
return num_block + 1, intervals
|
|
||||||
|
|
||||||
|
|
||||||
def _binary_search(weights, num):
|
|
||||||
length = len(weights)
|
|
||||||
prefix = [1 if w == 0 else w for w in weights]
|
|
||||||
for i in range(1, length):
|
|
||||||
prefix[i] += prefix[i - 1]
|
|
||||||
|
|
||||||
lower_bound = max(weights)
|
|
||||||
upper_bound = prefix[length - 1]
|
|
||||||
|
|
||||||
while upper_bound > lower_bound:
|
|
||||||
mid = (upper_bound + lower_bound) // 2
|
|
||||||
number, _ = _calc_partitions(prefix, mid)
|
|
||||||
if number <= num:
|
|
||||||
upper_bound = mid
|
|
||||||
else:
|
|
||||||
lower_bound = mid + 1
|
|
||||||
|
|
||||||
num_block, intervals = _calc_partitions(prefix, upper_bound)
|
|
||||||
if num_block < num:
|
|
||||||
intervals = _heap_addition(prefix, intervals, num - num_block)
|
|
||||||
|
|
||||||
return intervals
|
|
||||||
|
|
||||||
|
|
||||||
def partition_uniform(num_items, pipeline_parallel_size, num_chunks):
|
|
||||||
assert num_items % num_chunks == 0, \
|
|
||||||
"Layer length should be divided by the number of chunks, otherwise parameter method is recomended"
|
|
||||||
|
|
||||||
logger = get_dist_logger()
|
|
||||||
parts = [[] for _ in range(pipeline_parallel_size)]
|
|
||||||
partition_items = num_items // num_chunks
|
|
||||||
for idx in range(num_chunks):
|
|
||||||
base_idx = idx * partition_items
|
|
||||||
chunk_size = partition_items // pipeline_parallel_size
|
|
||||||
left = pipeline_parallel_size - partition_items % pipeline_parallel_size
|
|
||||||
if chunk_size == 0:
|
|
||||||
logger.warning("Some nodes in Pipeline have no requests")
|
|
||||||
|
|
||||||
for p in range(pipeline_parallel_size):
|
|
||||||
st = base_idx
|
|
||||||
base_idx += chunk_size + (p >= left)
|
|
||||||
parts[p].append((st, base_idx))
|
|
||||||
|
|
||||||
return parts
|
|
||||||
|
|
||||||
|
|
||||||
def partition_balanced(weights, pipeline_parallel_size, num_chunks):
|
|
||||||
num_total = pipeline_parallel_size * num_chunks
|
|
||||||
num_items = len(weights)
|
|
||||||
if num_items <= num_total:
|
|
||||||
return partition_uniform(num_items, pipeline_parallel_size, num_chunks)
|
|
||||||
|
|
||||||
intervals = _binary_search(weights, num_total)
|
|
||||||
|
|
||||||
current = 0
|
|
||||||
parts = [[] for _ in range(pipeline_parallel_size)]
|
|
||||||
for inter in intervals:
|
|
||||||
parts[current].append(inter)
|
|
||||||
current = (current + 1) % pipeline_parallel_size
|
|
||||||
|
|
||||||
return parts
|
|
||||||
|
|
||||||
|
|
||||||
def count_layer_params(layers):
|
|
||||||
"""Count the number of parameters in each layer
|
|
||||||
"""
|
|
||||||
param_counts = [0] * len(layers)
|
|
||||||
for idx, cfg in enumerate(layers):
|
|
||||||
layer = build_layer(cfg)
|
|
||||||
params = filter(lambda p: p.requires_grad, layer.parameters())
|
|
||||||
param_counts[idx] = sum(p.numel() for p in params)
|
|
||||||
|
|
||||||
return param_counts
|
|
||||||
|
|
||||||
|
|
||||||
def build_pipeline_model_from_cfg(config,
|
|
||||||
num_chunks: int = 1,
|
|
||||||
partition_method: str = 'parameter',
|
|
||||||
verbose: bool = False):
|
|
||||||
"""An initializer to split the model into different stages for pipeline parallelism.
|
|
||||||
|
|
||||||
An example for the model config is shown below. The class VisionTransformerFromConfig should
|
|
||||||
inherit colossalai.nn.model.ModelFromConfig to allow this initializer to build model from a sequence
|
|
||||||
of layer configurations.
|
|
||||||
|
|
||||||
::
|
|
||||||
|
|
||||||
model_config = dict(
|
|
||||||
type='VisionTransformerFromConfig',
|
|
||||||
embedding_cfg=dict(...),
|
|
||||||
...
|
|
||||||
)
|
|
||||||
|
|
||||||
Args:
|
|
||||||
config (dict): Configuration of the model.
|
|
||||||
num_chunks (int, optional): The number of chunks you want to have on the current stage.
|
|
||||||
This value should be 1 in most cases unless you are using virtual pipeline parallelism.
|
|
||||||
partition_method (str, optional): This parameter determines how you want to split your model
|
|
||||||
layers into stages, you can set it as 'layer' or 'parameter'.
|
|
||||||
verbose (bool, optional): Whether to print the logs.
|
|
||||||
"""
|
|
||||||
ori_model = build_model(config)
|
|
||||||
layers = ori_model.layers_cfg
|
|
||||||
layer_length = len(layers)
|
|
||||||
logger = get_dist_logger()
|
|
||||||
if verbose:
|
|
||||||
logger.info(f"The total length of layers is {layer_length}", ranks=[0])
|
|
||||||
|
|
||||||
pipeline_parallel_size = gpc.get_world_size(ParallelMode.PIPELINE)
|
|
||||||
pipeline_rank = gpc.get_local_rank(ParallelMode.PIPELINE)
|
|
||||||
|
|
||||||
method = partition_method.lower()
|
|
||||||
# Make a partition
|
|
||||||
if method == 'layer':
|
|
||||||
num_layers = len(layers)
|
|
||||||
parts = partition_uniform(num_layers, pipeline_parallel_size, num_chunks)
|
|
||||||
elif method == 'parameter':
|
|
||||||
param_counts = count_layer_params(layers)
|
|
||||||
# print_rank_0(param_counts)
|
|
||||||
parts = partition_balanced(param_counts, pipeline_parallel_size, num_chunks)
|
|
||||||
else:
|
|
||||||
raise ValueError("Method should be a pre-set string in [layer, parameter]")
|
|
||||||
|
|
||||||
# Display the partition
|
|
||||||
if verbose:
|
|
||||||
log_str = 'Layer allocation after partitioning: \n'
|
|
||||||
for stage in range(pipeline_parallel_size):
|
|
||||||
|
|
||||||
num_layers = 0
|
|
||||||
for st, ed in parts[stage]:
|
|
||||||
num_layers += ed - st
|
|
||||||
|
|
||||||
log_str += f'\n===== stage={stage}, layers={num_layers} =====\n'
|
|
||||||
for st, ed in parts[stage]:
|
|
||||||
for idx, layer in enumerate(layers[st:ed]):
|
|
||||||
log_str += f'\t{idx + st:2d}: {layer}\n'
|
|
||||||
logger.info(log_str, ranks=[0])
|
|
||||||
|
|
||||||
# Save the partition
|
|
||||||
interval = parts[pipeline_rank]
|
|
||||||
|
|
||||||
models = []
|
|
||||||
for st, ed in interval:
|
|
||||||
model = copy.deepcopy(ori_model)
|
|
||||||
model.build_from_cfg(st, ed)
|
|
||||||
models.append(model)
|
|
||||||
|
|
||||||
return nn.ModuleList(models) if len(models) > 1 else models[0]
|
|
||||||
|
|
||||||
|
|
||||||
def build_pipeline_model(layers: nn.Sequential, num_chunks: int = 1, verbose: bool = False):
|
|
||||||
"""An intializer to split the model into different stages for pipeline parallelism.
|
|
||||||
Note that `layer` must be `torch.nn.Sequential`.
|
|
||||||
Args:
|
|
||||||
layers (`torch.nn.Sequential`): Layers of model
|
|
||||||
num_chunks: The number of chunks you want to have on the current stage. This value should be 1
|
|
||||||
in most cases unless you are using virtual pipeline parallelism.
|
|
||||||
verbose (bool, optional): Whether to print the logs.
|
|
||||||
"""
|
|
||||||
pipeline_parallel_size = gpc.get_world_size(ParallelMode.PIPELINE)
|
|
||||||
pipeline_rank = gpc.get_local_rank(ParallelMode.PIPELINE)
|
|
||||||
partitions = partition_uniform(len(layers), pipeline_parallel_size, num_chunks)
|
|
||||||
module_list = []
|
|
||||||
for start, end in partitions[pipeline_rank]:
|
|
||||||
module_list.append(
|
|
||||||
nn.Sequential(*[nn.Identity() for _ in range(start)], *layers[start:end],
|
|
||||||
*[nn.Identity() for _ in range(len(layers) - end)]))
|
|
||||||
if verbose:
|
|
||||||
logger = get_dist_logger()
|
|
||||||
logger.info(f'Total {len(layers)} layers', ranks=[0])
|
|
||||||
for rank, part in enumerate(partitions):
|
|
||||||
log_str = f'===== stage={rank} =====\n'
|
|
||||||
for chunk, (start, end) in enumerate(part):
|
|
||||||
log_str += f'===== chunk={chunk}, layer=[{start}-{end}] =====\n'
|
|
||||||
log_str += '\n'.join([str(layer) for layer in layers[start:end]]) + '\n'
|
|
||||||
logger.info(log_str, ranks=[0])
|
|
||||||
return nn.ModuleList(module_list) if len(module_list) > 1 else module_list[0]
|
|
|
@ -2,6 +2,5 @@ from .layer import *
|
||||||
from .loss import *
|
from .loss import *
|
||||||
from .lr_scheduler import *
|
from .lr_scheduler import *
|
||||||
from .metric import *
|
from .metric import *
|
||||||
from .model import *
|
|
||||||
from .optimizer import *
|
from .optimizer import *
|
||||||
from ._ops import *
|
from ._ops import *
|
||||||
|
|
|
@ -1,4 +1,3 @@
|
||||||
from .lambda_wrapper import LambdaWrapper
|
|
||||||
from .pipeline_wrapper import PipelineSharedModuleWrapper
|
from .pipeline_wrapper import PipelineSharedModuleWrapper
|
||||||
|
|
||||||
__all__ = ['LambdaWrapper', 'PipelineSharedModuleWrapper']
|
__all__ = ['PipelineSharedModuleWrapper']
|
||||||
|
|
|
@ -1,36 +0,0 @@
|
||||||
#!/usr/bin/env python
|
|
||||||
# -*- encoding: utf-8 -*-
|
|
||||||
|
|
||||||
import torch.nn as nn
|
|
||||||
|
|
||||||
from colossalai.builder import build_layer
|
|
||||||
from colossalai.registry import LAYERS
|
|
||||||
|
|
||||||
|
|
||||||
@LAYERS.register_module
|
|
||||||
class LambdaWrapper(nn.Module):
|
|
||||||
"""Wrap a function to nn.Module, which takes a config of layers and can fully access them.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
func (``Callable``): User customed function.
|
|
||||||
layers_cfg (dict, optional): Config of layers, defaults to None.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, func, layers_cfg: dict = None):
|
|
||||||
super().__init__()
|
|
||||||
self.func = func
|
|
||||||
self.layers = self._build_layers(layers_cfg)
|
|
||||||
|
|
||||||
def _build_layers(self, layers_cfg: dict):
|
|
||||||
if layers_cfg is None:
|
|
||||||
return None
|
|
||||||
else:
|
|
||||||
layers = []
|
|
||||||
|
|
||||||
for cfg in layers_cfg:
|
|
||||||
layer = build_layer(cfg)
|
|
||||||
layers.append(layer)
|
|
||||||
return layers
|
|
||||||
|
|
||||||
def forward(self, *args, **kwargs):
|
|
||||||
return self.func(self, *args, **kwargs)
|
|
|
@ -1,3 +0,0 @@
|
||||||
from .model_from_config import ModelFromConfig
|
|
||||||
|
|
||||||
__all__ = ['ModelFromConfig']
|
|
|
@ -1,37 +0,0 @@
|
||||||
#!/usr/bin/env python
|
|
||||||
# -*- encoding: utf-8 -*-
|
|
||||||
|
|
||||||
from abc import ABC, abstractmethod
|
|
||||||
|
|
||||||
import torch.nn as nn
|
|
||||||
|
|
||||||
from colossalai.builder import build_layer
|
|
||||||
|
|
||||||
|
|
||||||
class ModelFromConfig(nn.Module, ABC):
|
|
||||||
|
|
||||||
def __init__(self):
|
|
||||||
super(ModelFromConfig, self).__init__()
|
|
||||||
self.layers = nn.ModuleList()
|
|
||||||
self.layers_cfg = []
|
|
||||||
|
|
||||||
def build_from_cfg(self, start=None, end=None):
|
|
||||||
assert hasattr(self, 'layers_cfg'), 'Cannot find attribute layers_cfg from the module, please check the ' \
|
|
||||||
'spelling and if you have initialized this variable'
|
|
||||||
if start is None:
|
|
||||||
start = 0
|
|
||||||
if end is None:
|
|
||||||
end = len(self.layers_cfg)
|
|
||||||
for cfg in self.layers_cfg[start: end]:
|
|
||||||
layer = build_layer(cfg)
|
|
||||||
self.layers.append(layer)
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def init_weights(self):
|
|
||||||
pass
|
|
||||||
|
|
||||||
def state_dict_for_save_checkpoint(self, destination=None, prefix='',
|
|
||||||
keep_vars=False):
|
|
||||||
"""Use this function to override the state dict for
|
|
||||||
saving checkpoints."""
|
|
||||||
return self.state_dict(destination, prefix, keep_vars)
|
|
|
@ -0,0 +1,4 @@
|
||||||
|
from .pipelinable import PipelinableContext, PipelinableModel
|
||||||
|
from .layer_sepc import LayerSpec
|
||||||
|
|
||||||
|
__all__ = ['PipelinableModel', 'PipelinableContext', 'LayerSpec']
|
|
@ -0,0 +1,55 @@
|
||||||
|
import torch
|
||||||
|
from colossalai.utils.model.utils import call_to_str
|
||||||
|
|
||||||
|
class LayerSpec:
|
||||||
|
"""
|
||||||
|
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, typename, *module_args, **module_kwargs):
|
||||||
|
self.typename = typename
|
||||||
|
self.module_args = module_args
|
||||||
|
self.module_kwargs = module_kwargs
|
||||||
|
self.children = None
|
||||||
|
self._param_count = 0
|
||||||
|
|
||||||
|
if not issubclass(typename, torch.nn.Module):
|
||||||
|
raise RuntimeError('LayerSpec only supports torch.nn.Module types.')
|
||||||
|
|
||||||
|
def __repr__(self):
|
||||||
|
return call_to_str(self.typename.__name__, self.module_args, self.module_kwargs)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def param_count(self):
|
||||||
|
return self._param_count
|
||||||
|
|
||||||
|
def build(self):
|
||||||
|
"""Build the stored specification."""
|
||||||
|
|
||||||
|
recovered_args = []
|
||||||
|
for obj in self.module_args:
|
||||||
|
if isinstance(obj, LayerSpec):
|
||||||
|
obj = obj.build()
|
||||||
|
recovered_args.append(obj)
|
||||||
|
recovered_args = tuple(recovered_args)
|
||||||
|
|
||||||
|
recovered_kwargs = {}
|
||||||
|
for k, v in self.module_kwargs.items():
|
||||||
|
if isinstance(v, LayerSpec):
|
||||||
|
v = v.build()
|
||||||
|
recovered_kwargs[k] = v
|
||||||
|
|
||||||
|
return self.typename(*recovered_args, **recovered_kwargs)
|
||||||
|
|
||||||
|
def set_children(self, children):
|
||||||
|
self.children = children
|
||||||
|
|
||||||
|
def count_params(self):
|
||||||
|
self._param_count = 0
|
||||||
|
layer = self.build()
|
||||||
|
for param in layer.parameters():
|
||||||
|
self._param_count += param.numel()
|
||||||
|
return self._param_count
|
||||||
|
|
||||||
|
def reset_param_count(self):
|
||||||
|
self._param_count = 0
|
|
@ -1,26 +1,34 @@
|
||||||
import torch
|
import torch
|
||||||
import inspect
|
import inspect
|
||||||
from colossalai.utils.model.utils import InsertPostInitMethodToModuleSubClasses, call_to_str
|
from colossalai.utils.model.utils import InsertPostInitMethodToModuleSubClasses
|
||||||
from colossalai.builder.pipeline import partition_uniform, partition_balanced
|
from .utils import partition_uniform, partition_balanced, build_kwargs_for_function, build_kwargs_for_module, exec_func_with_kwargs, exec_funcs_with_kwargs
|
||||||
from colossalai.nn.layer.utils import CheckpointModule
|
from colossalai.nn.layer.utils import CheckpointModule
|
||||||
from colossalai.tensor import ColoTensor
|
from colossalai.tensor import ColoParameter
|
||||||
|
from .layer_sepc import LayerSpec
|
||||||
|
|
||||||
|
|
||||||
class PipelinableContext(InsertPostInitMethodToModuleSubClasses):
|
class PipelinableContext(InsertPostInitMethodToModuleSubClasses):
|
||||||
|
"""
|
||||||
|
A context manager to split the model into pipeline stages.
|
||||||
|
"""
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self, policy: str="balanced"):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self._layer_spec_dict = {}
|
self._layer_spec_dict = {}
|
||||||
self._root_children = None
|
self._root_children = None
|
||||||
self._model = None
|
self._model = None
|
||||||
self._layer_spec_list = []
|
self._layer_spec_list = []
|
||||||
self._func_dict = {}
|
self._func_dict = {}
|
||||||
self._policy = "balanced"
|
self._policy = policy
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def policy(self):
|
def policy(self):
|
||||||
return self._policy
|
return self._policy
|
||||||
|
|
||||||
|
@policy.setter
|
||||||
|
def policy(self, policy: str):
|
||||||
|
self._policy = policy
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def layers_count(self):
|
def layers_count(self):
|
||||||
return len(self._layer_spec_list)
|
return len(self._layer_spec_list)
|
||||||
|
@ -30,10 +38,9 @@ class PipelinableContext(InsertPostInitMethodToModuleSubClasses):
|
||||||
return len(self._func_dict)
|
return len(self._func_dict)
|
||||||
|
|
||||||
def _pre_context_exec(self):
|
def _pre_context_exec(self):
|
||||||
"""
|
"""
|
||||||
The Callback function when entering the context
|
The Callback function when entering the context
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# reserve rng states
|
# reserve rng states
|
||||||
self.cpu_rng_state = torch.get_rng_state()
|
self.cpu_rng_state = torch.get_rng_state()
|
||||||
self.cuda_rng_state = torch.cuda.get_rng_state()
|
self.cuda_rng_state = torch.cuda.get_rng_state()
|
||||||
|
@ -52,35 +59,50 @@ class PipelinableContext(InsertPostInitMethodToModuleSubClasses):
|
||||||
The function to call at the end of the constructor of each module.
|
The function to call at the end of the constructor of each module.
|
||||||
NOTE() The module may be passed to this function multiple times.
|
NOTE() The module may be passed to this function multiple times.
|
||||||
"""
|
"""
|
||||||
module_id = id(module)
|
# iterate over the positional arguments
|
||||||
|
# to check if an argument is a torch Module
|
||||||
|
# if found any torch Module, replace it with its layer spec
|
||||||
|
# for storage purpose
|
||||||
modified_args = []
|
modified_args = []
|
||||||
for obj in args:
|
for arg in args:
|
||||||
if issubclass(obj.__class__, torch.nn.modules.module.Module):
|
if isinstance(arg, torch.nn.Module):
|
||||||
obj = self._layer_spec_dict[id(obj)]
|
arg = self._layer_spec_dict[id(arg)]
|
||||||
modified_args.append(obj)
|
modified_args.append(arg)
|
||||||
|
|
||||||
|
# to the same for the keyword arguments
|
||||||
modified_kwargs = {}
|
modified_kwargs = {}
|
||||||
for k, v in kwargs.items():
|
for k, v in kwargs.items():
|
||||||
if issubclass(v.__class__, torch.nn.modules.module.Module):
|
if isinstance(v, torch.nn.Module):
|
||||||
v = self._layer_spec_dict[id(v)]
|
v = self._layer_spec_dict[id(v)]
|
||||||
# (lyl)TODO: analyse ColoTensor as well
|
# (lyl)TODO: analyse ColoTensor as well
|
||||||
modified_kwargs[k] = v
|
modified_kwargs[k] = v
|
||||||
|
|
||||||
modified_args = tuple(modified_args)
|
# keep track of the module children
|
||||||
|
# as torch.nn.Module.__init__ is called from inner module to outer module,
|
||||||
|
# the final value of self._model will be the outermost model
|
||||||
|
# e.g. if the model is torchvision.models.resnet18, then the final value of self._model
|
||||||
|
# will be the ``ResNet`` object.
|
||||||
self._root_children = list(module.children())
|
self._root_children = list(module.children())
|
||||||
self._model = module
|
self._model = module
|
||||||
|
|
||||||
|
# store the children to keep the module hierarchy
|
||||||
layer_spec = LayerSpec(module.__class__, *modified_args, **modified_kwargs)
|
layer_spec = LayerSpec(module.__class__, *modified_args, **modified_kwargs)
|
||||||
layer_spec.set_children(module.children())
|
layer_spec.set_children(module.children())
|
||||||
|
|
||||||
|
# store the layer spec in this context
|
||||||
|
module_id = id(module)
|
||||||
self._layer_spec_dict[module_id] = layer_spec
|
self._layer_spec_dict[module_id] = layer_spec
|
||||||
|
|
||||||
|
# convert all torch.nn.Parameter to colossalai.tensor.ColoParameter
|
||||||
name_list = []
|
name_list = []
|
||||||
for name, param in module.named_parameters():
|
for name, param in module.named_parameters():
|
||||||
if isinstance(param, ColoTensor):
|
if isinstance(param, ColoParameter):
|
||||||
continue
|
continue
|
||||||
name_list.append((name, param))
|
name_list.append((name, param))
|
||||||
|
|
||||||
for name, param in name_list:
|
for name, param in name_list:
|
||||||
delattr(module, name)
|
delattr(module, name)
|
||||||
setattr(module, name, ColoTensor.from_torch_tensor(param))
|
setattr(module, name, ColoParameter.from_torch_tensor(tensor=param.data, requires_grad=param.requires_grad))
|
||||||
|
|
||||||
def to_layer_list(self, exec_seq=None):
|
def to_layer_list(self, exec_seq=None):
|
||||||
"""
|
"""
|
||||||
|
@ -100,7 +122,6 @@ class PipelinableContext(InsertPostInitMethodToModuleSubClasses):
|
||||||
if id(module) == id(child_in_container):
|
if id(module) == id(child_in_container):
|
||||||
children_name.append(name)
|
children_name.append(name)
|
||||||
break
|
break
|
||||||
|
|
||||||
else:
|
else:
|
||||||
self._layer_spec_list.append(layer_spec)
|
self._layer_spec_list.append(layer_spec)
|
||||||
for name, module in self._model.named_modules():
|
for name, module in self._model.named_modules():
|
||||||
|
@ -110,10 +131,16 @@ class PipelinableContext(InsertPostInitMethodToModuleSubClasses):
|
||||||
|
|
||||||
else:
|
else:
|
||||||
front_funcs_list = []
|
front_funcs_list = []
|
||||||
|
named_modules = dict(self._model.named_modules())
|
||||||
for index, element in enumerate(exec_seq):
|
for index, element in enumerate(exec_seq):
|
||||||
if isinstance(element, str):
|
if isinstance(element, str):
|
||||||
module = dict(self._model.named_modules())[element]
|
assert element in named_modules, f'Found invalid module name {element}, please check if you spell the module name correctly.'
|
||||||
|
|
||||||
|
# get the layer spec based on the module ID
|
||||||
|
module = named_modules[element]
|
||||||
layer_spec = self._layer_spec_dict[id(module)]
|
layer_spec = self._layer_spec_dict[id(module)]
|
||||||
|
|
||||||
|
# check whether there are functions which should be executed before this module
|
||||||
if len(front_funcs_list) != 0:
|
if len(front_funcs_list) != 0:
|
||||||
func_key = (layer_spec, "front")
|
func_key = (layer_spec, "front")
|
||||||
if func_key not in self._func_dict:
|
if func_key not in self._func_dict:
|
||||||
|
@ -121,6 +148,7 @@ class PipelinableContext(InsertPostInitMethodToModuleSubClasses):
|
||||||
for f in front_funcs_list:
|
for f in front_funcs_list:
|
||||||
self._func_dict[func_key].append(f)
|
self._func_dict[func_key].append(f)
|
||||||
front_funcs_list = []
|
front_funcs_list = []
|
||||||
|
|
||||||
func_key = (layer_spec, "behind")
|
func_key = (layer_spec, "behind")
|
||||||
self._layer_spec_list.append(layer_spec)
|
self._layer_spec_list.append(layer_spec)
|
||||||
elif isinstance(element, tuple) and element[1] == "front":
|
elif isinstance(element, tuple) and element[1] == "front":
|
||||||
|
@ -172,70 +200,6 @@ class PipelinableContext(InsertPostInitMethodToModuleSubClasses):
|
||||||
|
|
||||||
return pipeline_model
|
return pipeline_model
|
||||||
|
|
||||||
def load_policy(self, policy):
|
|
||||||
self._policy = policy
|
|
||||||
|
|
||||||
|
|
||||||
def _build_kwargs_for_module(function, kw_dict):
|
|
||||||
"""
|
|
||||||
Generally, the first argument of module.forward is an input tensor come from the previous layer.
|
|
||||||
Therefore, we just filter the kwargs from second element of the dictionary.
|
|
||||||
"""
|
|
||||||
sig = inspect.signature(function)
|
|
||||||
if len(sig.parameters) <= 1:
|
|
||||||
return None
|
|
||||||
args_name_list = list(sig.parameters.keys())
|
|
||||||
kw_dict = {k: v for k, v in kw_dict.items() if k in args_name_list[1:]}
|
|
||||||
return kw_dict
|
|
||||||
|
|
||||||
|
|
||||||
def _build_kwargs_for_function(function, kw_dict):
|
|
||||||
sig = inspect.signature(function)
|
|
||||||
kw_dict = {k: v for k, v in kw_dict.items() if k in sig.parameters}
|
|
||||||
if len(kw_dict) == 0:
|
|
||||||
return None
|
|
||||||
return kw_dict
|
|
||||||
|
|
||||||
|
|
||||||
def _exec_func_with_kwargs(func, kw_dict, input_tensor, kwargs):
|
|
||||||
"""
|
|
||||||
We suppose the callable object passed to to_layer_list method in two purpose:
|
|
||||||
a. use the callable object to modify input tensor, such as \
|
|
||||||
lambda x: torch.flatten(x, 1)
|
|
||||||
b. use the callable object to modify kwargs value, such as \
|
|
||||||
def foo(attention_mask=None):
|
|
||||||
if attention_mask is not None:
|
|
||||||
batch_size = input_ids.shape[0]
|
|
||||||
attention_mask = attention_mask.view(batch_size, -1)
|
|
||||||
return attention_mask
|
|
||||||
"""
|
|
||||||
|
|
||||||
if kw_dict is not None:
|
|
||||||
rst = func(**kw_dict)
|
|
||||||
if isinstance(rst, tuple):
|
|
||||||
for i, k in enumerate(kw_dict.keys()):
|
|
||||||
kwargs[k] = rst[i]
|
|
||||||
else:
|
|
||||||
for k in kw_dict.keys():
|
|
||||||
kwargs[k] = rst
|
|
||||||
return input_tensor
|
|
||||||
return func(input_tensor)
|
|
||||||
|
|
||||||
|
|
||||||
def _exec_funcs_with_kwargs(func_dict, func_key, input_tensor, kwargs):
|
|
||||||
|
|
||||||
assert func_key in func_dict, f"{func_key} is not in the function_dict."
|
|
||||||
funcs_to_exec = func_dict[func_key]
|
|
||||||
if isinstance(funcs_to_exec, list):
|
|
||||||
for f in funcs_to_exec:
|
|
||||||
f_kwargs = _build_kwargs_for_function(f, kwargs)
|
|
||||||
input_tensor = _exec_func_with_kwargs(f, f_kwargs, input_tensor, kwargs)
|
|
||||||
else:
|
|
||||||
f_kwargs = _build_kwargs_for_function(funcs_to_exec, kwargs)
|
|
||||||
input_tensor = _exec_func_with_kwargs(funcs_to_exec, f_kwargs, input_tensor, kwargs)
|
|
||||||
|
|
||||||
return input_tensor
|
|
||||||
|
|
||||||
|
|
||||||
class PipelinableModel(torch.nn.Module):
|
class PipelinableModel(torch.nn.Module):
|
||||||
|
|
||||||
|
@ -250,16 +214,16 @@ class PipelinableModel(torch.nn.Module):
|
||||||
for module in self._module_list:
|
for module in self._module_list:
|
||||||
|
|
||||||
if id(module) in self._front_func_dict:
|
if id(module) in self._front_func_dict:
|
||||||
input_tensor = _exec_funcs_with_kwargs(self._front_func_dict, id(module), input_tensor, kwargs)
|
input_tensor = exec_funcs_with_kwargs(self._front_func_dict, id(module), input_tensor, kwargs)
|
||||||
|
|
||||||
if isinstance(module, CheckpointModule):
|
if isinstance(module, CheckpointModule):
|
||||||
forward_func = module._forward
|
forward_func = module._forward
|
||||||
else:
|
else:
|
||||||
forward_func = module.forward
|
forward_func = module.forward
|
||||||
if input_tensor is None:
|
if input_tensor is None:
|
||||||
module_kwargs = _build_kwargs_for_function(forward_func, kwargs)
|
module_kwargs = build_kwargs_for_function(forward_func, kwargs)
|
||||||
else:
|
else:
|
||||||
module_kwargs = _build_kwargs_for_module(forward_func, kwargs)
|
module_kwargs = build_kwargs_for_module(forward_func, kwargs)
|
||||||
if module_kwargs is not None and input_tensor is not None:
|
if module_kwargs is not None and input_tensor is not None:
|
||||||
if isinstance(module, CheckpointModule):
|
if isinstance(module, CheckpointModule):
|
||||||
convert_kwargs_to_args = []
|
convert_kwargs_to_args = []
|
||||||
|
@ -288,57 +252,9 @@ class PipelinableModel(torch.nn.Module):
|
||||||
input_tensor = module(input_tensor)
|
input_tensor = module(input_tensor)
|
||||||
|
|
||||||
if id(module) in self._behind_func_dict:
|
if id(module) in self._behind_func_dict:
|
||||||
input_tensor = _exec_funcs_with_kwargs(self._behind_func_dict, id(module), input_tensor, kwargs)
|
input_tensor = exec_funcs_with_kwargs(self._behind_func_dict, id(module), input_tensor, kwargs)
|
||||||
|
|
||||||
return input_tensor
|
return input_tensor
|
||||||
|
|
||||||
|
|
||||||
class LayerSpec:
|
|
||||||
|
|
||||||
def __init__(self, typename, *module_args, **module_kwargs):
|
|
||||||
self.typename = typename
|
|
||||||
self.module_args = module_args
|
|
||||||
self.module_kwargs = module_kwargs
|
|
||||||
self.children = None
|
|
||||||
self._param_count = 0
|
|
||||||
|
|
||||||
if not issubclass(typename, torch.nn.Module):
|
|
||||||
raise RuntimeError('LayerSpec only supports torch.nn.Module types.')
|
|
||||||
|
|
||||||
def __repr__(self):
|
|
||||||
return call_to_str(self.typename.__name__, self.module_args, self.module_kwargs)
|
|
||||||
|
|
||||||
@property
|
|
||||||
def param_count(self):
|
|
||||||
return self._param_count
|
|
||||||
|
|
||||||
def build(self):
|
|
||||||
"""Build the stored specification."""
|
|
||||||
|
|
||||||
recovered_args = []
|
|
||||||
for obj in self.module_args:
|
|
||||||
if isinstance(obj, LayerSpec):
|
|
||||||
obj = obj.build()
|
|
||||||
recovered_args.append(obj)
|
|
||||||
recovered_args = tuple(recovered_args)
|
|
||||||
|
|
||||||
recovered_kwargs = {}
|
|
||||||
for k, v in self.module_kwargs.items():
|
|
||||||
if isinstance(v, LayerSpec):
|
|
||||||
v = v.build()
|
|
||||||
recovered_kwargs[k] = v
|
|
||||||
|
|
||||||
return self.typename(*recovered_args, **recovered_kwargs)
|
|
||||||
|
|
||||||
def set_children(self, children):
|
|
||||||
self.children = children
|
|
||||||
|
|
||||||
def count_params(self):
|
|
||||||
self._param_count = 0
|
|
||||||
layer = self.build()
|
|
||||||
for param in layer.parameters():
|
|
||||||
self._param_count += param.numel()
|
|
||||||
return self._param_count
|
|
||||||
|
|
||||||
def reset_param_count(self):
|
|
||||||
self._param_count = 0
|
|
|
@ -0,0 +1,207 @@
|
||||||
|
import heapq
|
||||||
|
import inspect
|
||||||
|
|
||||||
|
from colossalai.logging import get_dist_logger
|
||||||
|
from typing import List
|
||||||
|
|
||||||
|
def _binary_partition(weights: List, start: int, end: int):
|
||||||
|
"""Returns the binary partition position of `weights`, given the start
|
||||||
|
position `st` and the end position `ed`.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
weights (list): A python list to be binary partitioned
|
||||||
|
start (int): the start position of the binary partition
|
||||||
|
end (int): the end position of the binary partition
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
int: the binary partition position of `weights`
|
||||||
|
"""
|
||||||
|
w_sum = weights[end - 1]
|
||||||
|
prefix = 0
|
||||||
|
if start > 0:
|
||||||
|
w_sum -= weights[start - 1]
|
||||||
|
prefix = weights[start - 1]
|
||||||
|
minimum = float("inf")
|
||||||
|
for idx in range(start + 1, end):
|
||||||
|
front = weights[idx - 1] - prefix
|
||||||
|
diff = abs(w_sum - 2 * front)
|
||||||
|
if diff < minimum:
|
||||||
|
pos = idx
|
||||||
|
minimum = diff
|
||||||
|
|
||||||
|
return start, pos, end
|
||||||
|
|
||||||
|
|
||||||
|
def _heap_addition(weights: List, intervals: int, add_cnt: int):
|
||||||
|
"""
|
||||||
|
"""
|
||||||
|
|
||||||
|
def _heap_push(heap, st, ed):
|
||||||
|
value = weights[ed - 1]
|
||||||
|
if st > 0:
|
||||||
|
value -= weights[st - 1]
|
||||||
|
heapq.heappush(heap, (-value, st, ed))
|
||||||
|
|
||||||
|
ret_intervals = []
|
||||||
|
heap = []
|
||||||
|
|
||||||
|
for st, ed in intervals:
|
||||||
|
_heap_push(heap, st, ed)
|
||||||
|
|
||||||
|
while add_cnt > 0:
|
||||||
|
_, st, ed = heapq.heappop(heap)
|
||||||
|
if ed - st == 1:
|
||||||
|
ret_intervals.append((st, ed))
|
||||||
|
else:
|
||||||
|
l, m, r = _binary_partition(weights, st, ed)
|
||||||
|
_heap_push(heap, l, m)
|
||||||
|
_heap_push(heap, m, r)
|
||||||
|
add_cnt -= 1
|
||||||
|
|
||||||
|
while heap:
|
||||||
|
_, st, ed = heapq.heappop(heap)
|
||||||
|
ret_intervals.append((st, ed))
|
||||||
|
|
||||||
|
ret_intervals.sort()
|
||||||
|
return ret_intervals
|
||||||
|
|
||||||
|
|
||||||
|
def _calc_partitions(weights, value):
|
||||||
|
prev = 0
|
||||||
|
prefix = 0
|
||||||
|
num_block = 0
|
||||||
|
intervals = []
|
||||||
|
|
||||||
|
for idx, w in enumerate(weights):
|
||||||
|
if weights[idx] - prefix > value:
|
||||||
|
intervals.append((prev, idx))
|
||||||
|
prev = idx
|
||||||
|
prefix = weights[idx - 1]
|
||||||
|
num_block += 1
|
||||||
|
|
||||||
|
intervals.append((prev, len(weights)))
|
||||||
|
return num_block + 1, intervals
|
||||||
|
|
||||||
|
|
||||||
|
def _binary_search(weights, num):
|
||||||
|
length = len(weights)
|
||||||
|
prefix = [1 if w == 0 else w for w in weights]
|
||||||
|
for i in range(1, length):
|
||||||
|
prefix[i] += prefix[i - 1]
|
||||||
|
|
||||||
|
lower_bound = max(weights)
|
||||||
|
upper_bound = prefix[length - 1]
|
||||||
|
|
||||||
|
while upper_bound > lower_bound:
|
||||||
|
mid = (upper_bound + lower_bound) // 2
|
||||||
|
number, _ = _calc_partitions(prefix, mid)
|
||||||
|
if number <= num:
|
||||||
|
upper_bound = mid
|
||||||
|
else:
|
||||||
|
lower_bound = mid + 1
|
||||||
|
|
||||||
|
num_block, intervals = _calc_partitions(prefix, upper_bound)
|
||||||
|
if num_block < num:
|
||||||
|
intervals = _heap_addition(prefix, intervals, num - num_block)
|
||||||
|
|
||||||
|
return intervals
|
||||||
|
|
||||||
|
|
||||||
|
def partition_uniform(num_items, pipeline_parallel_size, num_chunks):
|
||||||
|
assert num_items % num_chunks == 0, \
|
||||||
|
"Layer length should be divided by the number of chunks, otherwise parameter method is recomended"
|
||||||
|
|
||||||
|
logger = get_dist_logger()
|
||||||
|
parts = [[] for _ in range(pipeline_parallel_size)]
|
||||||
|
partition_items = num_items // num_chunks
|
||||||
|
for idx in range(num_chunks):
|
||||||
|
base_idx = idx * partition_items
|
||||||
|
chunk_size = partition_items // pipeline_parallel_size
|
||||||
|
left = pipeline_parallel_size - partition_items % pipeline_parallel_size
|
||||||
|
if chunk_size == 0:
|
||||||
|
logger.warning("Some nodes in Pipeline have no requests")
|
||||||
|
|
||||||
|
for p in range(pipeline_parallel_size):
|
||||||
|
st = base_idx
|
||||||
|
base_idx += chunk_size + (p >= left)
|
||||||
|
parts[p].append((st, base_idx))
|
||||||
|
|
||||||
|
return parts
|
||||||
|
|
||||||
|
|
||||||
|
def partition_balanced(weights, pipeline_parallel_size, num_chunks):
|
||||||
|
num_total = pipeline_parallel_size * num_chunks
|
||||||
|
num_items = len(weights)
|
||||||
|
if num_items <= num_total:
|
||||||
|
return partition_uniform(num_items, pipeline_parallel_size, num_chunks)
|
||||||
|
|
||||||
|
intervals = _binary_search(weights, num_total)
|
||||||
|
|
||||||
|
current = 0
|
||||||
|
parts = [[] for _ in range(pipeline_parallel_size)]
|
||||||
|
for inter in intervals:
|
||||||
|
parts[current].append(inter)
|
||||||
|
current = (current + 1) % pipeline_parallel_size
|
||||||
|
|
||||||
|
return parts
|
||||||
|
|
||||||
|
|
||||||
|
def build_kwargs_for_module(function, kw_dict):
|
||||||
|
"""
|
||||||
|
Generally, the first argument of module.forward is an input tensor come from the previous layer.
|
||||||
|
Therefore, we just filter the kwargs from second element of the dictionary.
|
||||||
|
"""
|
||||||
|
sig = inspect.signature(function)
|
||||||
|
if len(sig.parameters) <= 1:
|
||||||
|
return None
|
||||||
|
args_name_list = list(sig.parameters.keys())
|
||||||
|
kw_dict = {k: v for k, v in kw_dict.items() if k in args_name_list[1:]}
|
||||||
|
return kw_dict
|
||||||
|
|
||||||
|
|
||||||
|
def build_kwargs_for_function(function, kw_dict):
|
||||||
|
sig = inspect.signature(function)
|
||||||
|
kw_dict = {k: v for k, v in kw_dict.items() if k in sig.parameters}
|
||||||
|
if len(kw_dict) == 0:
|
||||||
|
return None
|
||||||
|
return kw_dict
|
||||||
|
|
||||||
|
|
||||||
|
def exec_func_with_kwargs(func, kw_dict, input_tensor, kwargs):
|
||||||
|
"""
|
||||||
|
We suppose the callable object passed to to_layer_list method in two purpose:
|
||||||
|
a. use the callable object to modify input tensor, such as \
|
||||||
|
lambda x: torch.flatten(x, 1)
|
||||||
|
b. use the callable object to modify kwargs value, such as \
|
||||||
|
def foo(attention_mask=None):
|
||||||
|
if attention_mask is not None:
|
||||||
|
batch_size = input_ids.shape[0]
|
||||||
|
attention_mask = attention_mask.view(batch_size, -1)
|
||||||
|
return attention_mask
|
||||||
|
"""
|
||||||
|
|
||||||
|
if kw_dict is not None:
|
||||||
|
rst = func(**kw_dict)
|
||||||
|
if isinstance(rst, tuple):
|
||||||
|
for i, k in enumerate(kw_dict.keys()):
|
||||||
|
kwargs[k] = rst[i]
|
||||||
|
else:
|
||||||
|
for k in kw_dict.keys():
|
||||||
|
kwargs[k] = rst
|
||||||
|
return input_tensor
|
||||||
|
return func(input_tensor)
|
||||||
|
|
||||||
|
|
||||||
|
def exec_funcs_with_kwargs(func_dict, func_key, input_tensor, kwargs):
|
||||||
|
|
||||||
|
assert func_key in func_dict, f"{func_key} is not in the function_dict."
|
||||||
|
funcs_to_exec = func_dict[func_key]
|
||||||
|
if isinstance(funcs_to_exec, list):
|
||||||
|
for f in funcs_to_exec:
|
||||||
|
f_kwargs = build_kwargs_for_function(f, kwargs)
|
||||||
|
input_tensor = exec_func_with_kwargs(f, f_kwargs, input_tensor, kwargs)
|
||||||
|
else:
|
||||||
|
f_kwargs = build_kwargs_for_function(funcs_to_exec, kwargs)
|
||||||
|
input_tensor = exec_func_with_kwargs(funcs_to_exec, f_kwargs, input_tensor, kwargs)
|
||||||
|
|
||||||
|
return input_tensor
|
|
@ -6,7 +6,6 @@ from pathlib import Path
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from colossalai.context.config import Config
|
from colossalai.context.config import Config
|
||||||
from colossalai.builder import build_ophooks
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.cpu
|
@pytest.mark.cpu
|
||||||
|
|
|
@ -17,11 +17,14 @@ from colossalai.logging import get_dist_logger
|
||||||
from colossalai.nn import CrossEntropyLoss
|
from colossalai.nn import CrossEntropyLoss
|
||||||
from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR
|
from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR
|
||||||
from colossalai.utils import is_using_pp, get_dataloader
|
from colossalai.utils import is_using_pp, get_dataloader
|
||||||
from colossalai.utils.model.pipelinable import PipelinableContext
|
from colossalai.pipeline.pipelinable import PipelinableContext
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
|
from torchvision.datasets import CIFAR10
|
||||||
from titans.dataloader.cifar10 import build_cifar
|
from torchvision.transforms import transforms
|
||||||
from titans.model.vit import vit_tiny_patch4_32
|
try:
|
||||||
|
from titans.model.vit import vit_tiny_patch4_32
|
||||||
|
except:
|
||||||
|
pass
|
||||||
|
|
||||||
BATCH_SIZE = 4
|
BATCH_SIZE = 4
|
||||||
NUM_EPOCHS = 60
|
NUM_EPOCHS = 60
|
||||||
|
@ -49,7 +52,14 @@ def run_trainer(rank, world_size, port):
|
||||||
|
|
||||||
# craete dataloaders
|
# craete dataloaders
|
||||||
root = Path(os.environ['DATA'])
|
root = Path(os.environ['DATA'])
|
||||||
train_dataloader, test_dataloader = build_cifar(BATCH_SIZE, root, pad_if_needed=True, crop=32, resize=32)
|
transform_train = transforms.Compose([
|
||||||
|
transforms.RandomCrop(224, padding=4, pad_if_needed=True),
|
||||||
|
transforms.AutoAugment(policy=transforms.AutoAugmentPolicy.CIFAR10),
|
||||||
|
transforms.ToTensor(),
|
||||||
|
transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
|
||||||
|
])
|
||||||
|
train_dataset = CIFAR10(root=root, train=True, download=True, transform=transform_train)
|
||||||
|
train_dataloader = get_dataloader(dataset=train_dataset, shuffle=True, batch_size=BATCH_SIZE, pin_memory=True)
|
||||||
|
|
||||||
# create loss function
|
# create loss function
|
||||||
criterion = CrossEntropyLoss(label_smoothing=0.1)
|
criterion = CrossEntropyLoss(label_smoothing=0.1)
|
||||||
|
|
|
@ -1,7 +1,7 @@
|
||||||
import torch
|
import torch
|
||||||
import torch.multiprocessing as mp
|
import torch.multiprocessing as mp
|
||||||
|
|
||||||
from colossalai.utils.model.pipelinable import PipelinableContext
|
from colossalai.pipeline.pipelinable import PipelinableContext
|
||||||
|
|
||||||
from colossalai.testing import rerun_on_exception
|
from colossalai.testing import rerun_on_exception
|
||||||
|
|
||||||
|
@ -33,7 +33,7 @@ def run_pipelinable(rank):
|
||||||
model = MLP()
|
model = MLP()
|
||||||
|
|
||||||
assert pipelinable.policy == "balanced"
|
assert pipelinable.policy == "balanced"
|
||||||
pipelinable.load_policy("uniform")
|
pipelinable.policy = "uniform"
|
||||||
assert pipelinable.policy == "uniform"
|
assert pipelinable.policy == "uniform"
|
||||||
pipelinable.to_layer_list()
|
pipelinable.to_layer_list()
|
||||||
|
|
|
@ -1,2 +0,0 @@
|
||||||
from .layers import *
|
|
||||||
from .resnet import VanillaResNet
|
|
|
@ -1,3 +0,0 @@
|
||||||
from .basic_block import ResNetBasicBlock
|
|
||||||
from .bottleneck import ResNetBottleneck
|
|
||||||
from .reslayer import ResLayer
|
|
|
@ -1,64 +0,0 @@
|
||||||
#!/usr/bin/env python
|
|
||||||
# -*- encoding: utf-8 -*-
|
|
||||||
|
|
||||||
from typing import Optional, Callable
|
|
||||||
|
|
||||||
import torch.nn as nn
|
|
||||||
from torch import Tensor
|
|
||||||
|
|
||||||
from colossalai.registry import LAYERS
|
|
||||||
from .conv import conv3x3
|
|
||||||
|
|
||||||
|
|
||||||
@LAYERS.register_module
|
|
||||||
class ResNetBasicBlock(nn.Module):
|
|
||||||
"""Basic ResNet block
|
|
||||||
"""
|
|
||||||
expansion: int = 1
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
inplanes: int,
|
|
||||||
planes: int,
|
|
||||||
stride: int = 1,
|
|
||||||
downsample: Optional[nn.Module] = None,
|
|
||||||
groups: int = 1,
|
|
||||||
base_width: int = 64,
|
|
||||||
dilation: int = 1,
|
|
||||||
norm_layer: Optional[Callable[..., nn.Module]] = None
|
|
||||||
) -> None:
|
|
||||||
super().__init__()
|
|
||||||
if norm_layer is None:
|
|
||||||
norm_layer = nn.BatchNorm2d
|
|
||||||
if groups != 1 or base_width != 64:
|
|
||||||
raise ValueError(
|
|
||||||
'BasicBlock only supports groups=1 and base_width=64')
|
|
||||||
if dilation > 1:
|
|
||||||
raise NotImplementedError(
|
|
||||||
"Dilation > 1 not supported in BasicBlock")
|
|
||||||
# Both self.conv1 and self.downsample layers downsample the input when stride != 1
|
|
||||||
self.conv1 = conv3x3(inplanes, planes, stride)
|
|
||||||
self.bn1 = norm_layer(planes)
|
|
||||||
self.relu = nn.ReLU(inplace=True)
|
|
||||||
self.conv2 = conv3x3(planes, planes)
|
|
||||||
self.bn2 = norm_layer(planes)
|
|
||||||
self.downsample = downsample
|
|
||||||
self.stride = stride
|
|
||||||
|
|
||||||
def forward(self, x: Tensor) -> Tensor:
|
|
||||||
identity = x
|
|
||||||
|
|
||||||
out = self.conv1(x)
|
|
||||||
out = self.bn1(out)
|
|
||||||
out = self.relu(out)
|
|
||||||
|
|
||||||
out = self.conv2(out)
|
|
||||||
out = self.bn2(out)
|
|
||||||
|
|
||||||
if self.downsample is not None:
|
|
||||||
identity = self.downsample(x)
|
|
||||||
|
|
||||||
out += identity
|
|
||||||
out = self.relu(out)
|
|
||||||
|
|
||||||
return out
|
|
|
@ -1,69 +0,0 @@
|
||||||
#!/usr/bin/env python
|
|
||||||
# -*- encoding: utf-8 -*-
|
|
||||||
|
|
||||||
from typing import Optional, Callable
|
|
||||||
|
|
||||||
import torch.nn as nn
|
|
||||||
from torch import Tensor
|
|
||||||
|
|
||||||
from colossalai.registry import LAYERS
|
|
||||||
from .conv import conv3x3, conv1x1
|
|
||||||
|
|
||||||
|
|
||||||
@LAYERS.register_module
|
|
||||||
class ResNetBottleneck(nn.Module):
|
|
||||||
# Bottleneck in torchvision places the stride for downsampling at 3x3 convolution(self.conv2)
|
|
||||||
# while original implementation places the stride at the first 1x1 convolution(self.conv1)
|
|
||||||
# according to "Deep residual learning for image recognition"https://arxiv.org/abs/1512.03385.
|
|
||||||
# This variant is also known as ResNet V1.5 and improves accuracy according to
|
|
||||||
# https://ngc.nvidia.com/catalog/model-scripts/nvidia:resnet_50_v1_5_for_pytorch.
|
|
||||||
|
|
||||||
expansion: int = 4
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
inplanes: int,
|
|
||||||
planes: int,
|
|
||||||
stride: int = 1,
|
|
||||||
downsample: Optional[nn.Module] = None,
|
|
||||||
groups: int = 1,
|
|
||||||
base_width: int = 64,
|
|
||||||
dilation: int = 1,
|
|
||||||
norm_layer: Optional[Callable[..., nn.Module]] = None
|
|
||||||
) -> None:
|
|
||||||
super().__init__()
|
|
||||||
if norm_layer is None:
|
|
||||||
norm_layer = nn.BatchNorm2d
|
|
||||||
width = int(planes * (base_width / 64.)) * groups
|
|
||||||
# Both self.conv2 and self.downsample layers downsample the input when stride != 1
|
|
||||||
self.conv1 = conv1x1(inplanes, width)
|
|
||||||
self.bn1 = norm_layer(width)
|
|
||||||
self.conv2 = conv3x3(width, width, stride, groups, dilation)
|
|
||||||
self.bn2 = norm_layer(width)
|
|
||||||
self.conv3 = conv1x1(width, planes * self.expansion)
|
|
||||||
self.bn3 = norm_layer(planes * self.expansion)
|
|
||||||
self.relu = nn.ReLU(inplace=True)
|
|
||||||
self.downsample = downsample
|
|
||||||
self.stride = stride
|
|
||||||
|
|
||||||
def forward(self, x: Tensor) -> Tensor:
|
|
||||||
identity = x
|
|
||||||
|
|
||||||
out = self.conv1(x)
|
|
||||||
out = self.bn1(out)
|
|
||||||
out = self.relu(out)
|
|
||||||
|
|
||||||
out = self.conv2(out)
|
|
||||||
out = self.bn2(out)
|
|
||||||
out = self.relu(out)
|
|
||||||
|
|
||||||
out = self.conv3(out)
|
|
||||||
out = self.bn3(out)
|
|
||||||
|
|
||||||
if self.downsample is not None:
|
|
||||||
identity = self.downsample(x)
|
|
||||||
|
|
||||||
out += identity
|
|
||||||
out = self.relu(out)
|
|
||||||
|
|
||||||
return out
|
|
|
@ -1,15 +0,0 @@
|
||||||
#!/usr/bin/env python
|
|
||||||
# -*- encoding: utf-8 -*-
|
|
||||||
|
|
||||||
import torch.nn as nn
|
|
||||||
|
|
||||||
|
|
||||||
def conv3x3(in_planes: int, out_planes: int, stride: int = 1, groups: int = 1, dilation: int = 1) -> nn.Conv2d:
|
|
||||||
"""3x3 convolution with padding"""
|
|
||||||
return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
|
|
||||||
padding=dilation, groups=groups, bias=False, dilation=dilation)
|
|
||||||
|
|
||||||
|
|
||||||
def conv1x1(in_planes: int, out_planes: int, stride: int = 1) -> nn.Conv2d:
|
|
||||||
"""1x1 convolution"""
|
|
||||||
return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
|
|
|
@ -1,63 +0,0 @@
|
||||||
#!/usr/bin/env python
|
|
||||||
# -*- encoding: utf-8 -*-
|
|
||||||
|
|
||||||
import torch.nn as nn
|
|
||||||
|
|
||||||
from colossalai.registry import LAYERS
|
|
||||||
from .conv import conv1x1
|
|
||||||
|
|
||||||
|
|
||||||
@LAYERS.register_module
|
|
||||||
class ResLayer(nn.Module):
|
|
||||||
|
|
||||||
def __init__(self,
|
|
||||||
block_type: str,
|
|
||||||
norm_layer_type: str,
|
|
||||||
inplanes: int,
|
|
||||||
planes: int,
|
|
||||||
blocks: int,
|
|
||||||
groups: int,
|
|
||||||
base_width: int,
|
|
||||||
stride: int = 1,
|
|
||||||
dilation: int = 1,
|
|
||||||
dilate: bool = False,
|
|
||||||
):
|
|
||||||
super().__init__()
|
|
||||||
self.block = LAYERS.get_module(block_type)
|
|
||||||
self.norm_layer = LAYERS.get_module(norm_layer_type)
|
|
||||||
self.inplanes = inplanes
|
|
||||||
self.planes = planes
|
|
||||||
self.blocks = blocks
|
|
||||||
self.groups = groups
|
|
||||||
self.dilation = dilation
|
|
||||||
self.base_width = base_width
|
|
||||||
self.dilate = dilate
|
|
||||||
self.stride = stride
|
|
||||||
self.layer = self._make_layer()
|
|
||||||
|
|
||||||
def _make_layer(self):
|
|
||||||
norm_layer = self.norm_layer
|
|
||||||
downsample = None
|
|
||||||
previous_dilation = self.dilation
|
|
||||||
if self.dilate:
|
|
||||||
self.dilation *= self.stride
|
|
||||||
self.stride = 1
|
|
||||||
if self.stride != 1 or self.inplanes != self.planes * self.block.expansion:
|
|
||||||
downsample = nn.Sequential(
|
|
||||||
conv1x1(self.inplanes, self.planes * self.block.expansion, self.stride),
|
|
||||||
norm_layer(self.planes * self.block.expansion),
|
|
||||||
)
|
|
||||||
|
|
||||||
layers = []
|
|
||||||
layers.append(self.block(self.inplanes, self.planes, self.stride, downsample, self.groups,
|
|
||||||
self.base_width, previous_dilation, norm_layer))
|
|
||||||
self.inplanes = self.planes * self.block.expansion
|
|
||||||
for _ in range(1, self.blocks):
|
|
||||||
layers.append(self.block(self.inplanes, self.planes, groups=self.groups,
|
|
||||||
base_width=self.base_width, dilation=self.dilation,
|
|
||||||
norm_layer=norm_layer))
|
|
||||||
|
|
||||||
return nn.Sequential(*layers)
|
|
||||||
|
|
||||||
def forward(self, x):
|
|
||||||
return self.layer(x)
|
|
|
@ -1,163 +0,0 @@
|
||||||
#!/usr/bin/env python
|
|
||||||
# -*- encoding: utf-8 -*-
|
|
||||||
|
|
||||||
from typing import List, Optional
|
|
||||||
|
|
||||||
import torch
|
|
||||||
import torch.nn as nn
|
|
||||||
from torch import Tensor
|
|
||||||
|
|
||||||
from colossalai.registry import LAYERS
|
|
||||||
from colossalai.registry import MODELS
|
|
||||||
from colossalai.nn.model import ModelFromConfig
|
|
||||||
|
|
||||||
|
|
||||||
@MODELS.register_module
|
|
||||||
class VanillaResNet(ModelFromConfig):
|
|
||||||
"""ResNet from
|
|
||||||
`"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
num_cls: int,
|
|
||||||
block_type: str,
|
|
||||||
layers: List[int],
|
|
||||||
norm_layer_type: str = 'BatchNorm2d',
|
|
||||||
in_channels: int = 3,
|
|
||||||
groups: int = 1,
|
|
||||||
width_per_group: int = 64,
|
|
||||||
zero_init_residual: bool = False,
|
|
||||||
replace_stride_with_dilation: Optional[List[bool]] = None,
|
|
||||||
dilations=(1, 1, 1, 1)
|
|
||||||
) -> None:
|
|
||||||
super().__init__()
|
|
||||||
|
|
||||||
self.inplanes = 64
|
|
||||||
self.zero_init_residual = zero_init_residual
|
|
||||||
self.blocks = layers
|
|
||||||
self.block_expansion = LAYERS.get_module(block_type).expansion
|
|
||||||
self.dilations = dilations
|
|
||||||
self.reslayer_common_cfg = dict(
|
|
||||||
type='ResLayer',
|
|
||||||
block_type=block_type,
|
|
||||||
norm_layer_type=norm_layer_type,
|
|
||||||
groups=groups,
|
|
||||||
base_width=width_per_group
|
|
||||||
)
|
|
||||||
|
|
||||||
if replace_stride_with_dilation is None:
|
|
||||||
# each element in the tuple indicates if we should replace
|
|
||||||
# the 2x2 stride with a dilated convolution instead
|
|
||||||
replace_stride_with_dilation = [False, False, False]
|
|
||||||
|
|
||||||
if len(replace_stride_with_dilation) != 3:
|
|
||||||
raise ValueError("replace_stride_with_dilation should be None "
|
|
||||||
"or a 3-element tuple, got {}".format(replace_stride_with_dilation))
|
|
||||||
|
|
||||||
self.layers_cfg = [
|
|
||||||
# conv1
|
|
||||||
dict(type='Conv2d',
|
|
||||||
in_channels=in_channels,
|
|
||||||
out_channels=self.inplanes,
|
|
||||||
kernel_size=7,
|
|
||||||
stride=2,
|
|
||||||
padding=3,
|
|
||||||
bias=False),
|
|
||||||
# bn1
|
|
||||||
dict(
|
|
||||||
type=norm_layer_type,
|
|
||||||
num_features=self.inplanes
|
|
||||||
),
|
|
||||||
# relu
|
|
||||||
dict(
|
|
||||||
type='ReLU',
|
|
||||||
inplace=True
|
|
||||||
),
|
|
||||||
# maxpool
|
|
||||||
dict(
|
|
||||||
type='MaxPool2d',
|
|
||||||
kernel_size=3,
|
|
||||||
stride=2,
|
|
||||||
padding=1
|
|
||||||
),
|
|
||||||
# layer 1
|
|
||||||
dict(
|
|
||||||
inplanes=self.inplanes,
|
|
||||||
planes=64,
|
|
||||||
blocks=self.blocks[0],
|
|
||||||
dilation=self.dilations[0],
|
|
||||||
**self.reslayer_common_cfg
|
|
||||||
),
|
|
||||||
# layer 2
|
|
||||||
dict(
|
|
||||||
inplanes=64 * self.block_expansion,
|
|
||||||
planes=128,
|
|
||||||
blocks=self.blocks[1],
|
|
||||||
stride=2,
|
|
||||||
dilate=replace_stride_with_dilation[0],
|
|
||||||
dilation=self.dilations[1],
|
|
||||||
**self.reslayer_common_cfg
|
|
||||||
),
|
|
||||||
# layer 3
|
|
||||||
dict(
|
|
||||||
inplanes=128 * self.block_expansion,
|
|
||||||
planes=256,
|
|
||||||
blocks=layers[2],
|
|
||||||
stride=2,
|
|
||||||
dilate=replace_stride_with_dilation[1],
|
|
||||||
dilation=self.dilations[2],
|
|
||||||
**self.reslayer_common_cfg
|
|
||||||
),
|
|
||||||
# layer 4
|
|
||||||
dict(
|
|
||||||
inplanes=256 * self.block_expansion,
|
|
||||||
planes=512,
|
|
||||||
blocks=layers[3], stride=2,
|
|
||||||
dilate=replace_stride_with_dilation[2],
|
|
||||||
dilation=self.dilations[3],
|
|
||||||
**self.reslayer_common_cfg
|
|
||||||
),
|
|
||||||
# avg pool
|
|
||||||
dict(
|
|
||||||
type='AdaptiveAvgPool2d',
|
|
||||||
output_size=(1, 1)
|
|
||||||
),
|
|
||||||
# flatten
|
|
||||||
dict(
|
|
||||||
type='LambdaWrapper',
|
|
||||||
func=lambda mod, x: torch.flatten(x, 1)
|
|
||||||
),
|
|
||||||
# linear
|
|
||||||
dict(
|
|
||||||
type='Linear',
|
|
||||||
in_features=512 * self.block_expansion,
|
|
||||||
out_features=num_cls
|
|
||||||
)
|
|
||||||
]
|
|
||||||
|
|
||||||
def forward(self, x: Tensor):
|
|
||||||
for layer in self.layers:
|
|
||||||
x = layer(x)
|
|
||||||
return x
|
|
||||||
|
|
||||||
def init_weights(self):
|
|
||||||
for m in self.modules():
|
|
||||||
if isinstance(m, nn.Conv2d):
|
|
||||||
nn.init.kaiming_normal_(
|
|
||||||
m.weight, mode='fan_out', nonlinearity='relu')
|
|
||||||
elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
|
|
||||||
nn.init.constant_(m.weight, 1)
|
|
||||||
nn.init.constant_(m.bias, 0)
|
|
||||||
|
|
||||||
# Zero-initialize the last BN in each residual branch,
|
|
||||||
# so that the residual branch starts with zeros, and each residual block behaves like an identity.
|
|
||||||
# This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
|
|
||||||
if self.zero_init_residual:
|
|
||||||
for m in self.modules():
|
|
||||||
if isinstance(m, LAYERS.get_module('ResNetBottleneck')):
|
|
||||||
# type: ignore[arg-type]
|
|
||||||
nn.init.constant_(m.bn3.weight, 0)
|
|
||||||
elif isinstance(m, LAYERS.get_module('ResNetBasicBlock')):
|
|
||||||
# type: ignore[arg-type]
|
|
||||||
nn.init.constant_(m.bn2.weight, 0)
|
|
|
@ -1,21 +0,0 @@
|
||||||
import os
|
|
||||||
import model
|
|
||||||
from pathlib import Path
|
|
||||||
|
|
||||||
BATCH_SIZE = 128
|
|
||||||
IMG_SIZE = 224
|
|
||||||
DIM = 768
|
|
||||||
NUM_CLASSES = 10
|
|
||||||
NUM_ATTN_HEADS = 12
|
|
||||||
NUM_MICRO_BATCHES = 2
|
|
||||||
|
|
||||||
# resnet 18
|
|
||||||
model = dict(type='VanillaResNet',
|
|
||||||
block_type='ResNetBasicBlock',
|
|
||||||
layers=[2, 2, 2, 2],
|
|
||||||
num_cls=10)
|
|
||||||
|
|
||||||
parallel = dict(
|
|
||||||
pipeline=dict(size=4),
|
|
||||||
tensor=dict(size=1, mode=None)
|
|
||||||
)
|
|
|
@ -1,43 +0,0 @@
|
||||||
import os.path as osp
|
|
||||||
|
|
||||||
import pytest
|
|
||||||
import torch
|
|
||||||
import torch.multiprocessing as mp
|
|
||||||
|
|
||||||
from colossalai.builder.pipeline import build_pipeline_model_from_cfg
|
|
||||||
from colossalai.core import global_context
|
|
||||||
from colossalai.initialize import launch
|
|
||||||
from colossalai.logging import get_dist_logger
|
|
||||||
from functools import partial
|
|
||||||
from colossalai.utils import free_port
|
|
||||||
from colossalai.testing import rerun_on_exception
|
|
||||||
|
|
||||||
DIR_PATH = osp.dirname(osp.realpath(__file__))
|
|
||||||
CONFIG_PATH = osp.join(DIR_PATH, 'resnet_config.py')
|
|
||||||
|
|
||||||
|
|
||||||
def run_partition(rank, world_size, port):
|
|
||||||
launch(config=CONFIG_PATH, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
|
|
||||||
logger = get_dist_logger()
|
|
||||||
logger.info('finished initialization')
|
|
||||||
|
|
||||||
# build model
|
|
||||||
model = build_pipeline_model_from_cfg(global_context.config.model, 1, verbose=True)
|
|
||||||
assert isinstance(model, torch.nn.Module)
|
|
||||||
logger.info('model is created')
|
|
||||||
|
|
||||||
global_context.destroy()
|
|
||||||
logger.info('training finished')
|
|
||||||
torch.cuda.empty_cache()
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.dist
|
|
||||||
@rerun_on_exception(exception_type=mp.ProcessRaisedException, pattern=".*Address already in use.*")
|
|
||||||
def test_partition():
|
|
||||||
world_size = 4
|
|
||||||
run_func = partial(run_partition, world_size=world_size, port=free_port())
|
|
||||||
mp.spawn(run_func, nprocs=world_size)
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
|
||||||
test_partition()
|
|
|
@ -8,27 +8,45 @@ from pathlib import Path
|
||||||
import colossalai
|
import colossalai
|
||||||
import pytest
|
import pytest
|
||||||
import torch
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
import torch.multiprocessing as mp
|
import torch.multiprocessing as mp
|
||||||
from colossalai.builder import build_pipeline_model_from_cfg
|
|
||||||
from colossalai.core import global_context as gpc
|
from colossalai.core import global_context as gpc
|
||||||
from colossalai.engine.schedule import PipelineSchedule
|
from colossalai.context import ParallelMode
|
||||||
from colossalai.initialize import launch
|
from colossalai.initialize import launch
|
||||||
from colossalai.utils import free_port, get_dataloader, print_rank_0
|
from colossalai.utils import free_port, get_dataloader, print_rank_0
|
||||||
from colossalai.testing import rerun_on_exception
|
from colossalai.testing import rerun_on_exception
|
||||||
from torchvision import transforms
|
from torchvision import transforms
|
||||||
from torchvision.datasets import CIFAR10
|
from torchvision.datasets import CIFAR10
|
||||||
|
from torchvision.models import resnet18
|
||||||
|
|
||||||
BATCH_SIZE = 4
|
|
||||||
|
|
||||||
DIR_PATH = osp.dirname(osp.realpath(__file__))
|
BATCH_SIZE = 8
|
||||||
CONFIG_PATH = osp.join(DIR_PATH, './resnet_config.py')
|
|
||||||
|
|
||||||
|
CONFIG=dict(
|
||||||
|
NUM_MICRO_BATCHES=2,
|
||||||
|
parallel = dict(
|
||||||
|
pipeline=dict(size=2),
|
||||||
|
tensor=dict(size=1, mode=None)
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
def run_schedule(rank, world_size, port):
|
def run_schedule(rank, world_size, port):
|
||||||
launch(config=CONFIG_PATH, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
|
launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
|
||||||
|
|
||||||
# build model
|
# build model
|
||||||
model = build_pipeline_model_from_cfg(gpc.config.model, 1)
|
model = resnet18(num_classes=10)
|
||||||
|
|
||||||
|
if gpc.get_local_rank(ParallelMode.PIPELINE) == 0:
|
||||||
|
model = nn.Sequential(model.conv1, model.bn1, model.relu, model.maxpool, model.layer1, model.layer2)
|
||||||
|
elif gpc.get_local_rank(ParallelMode.PIPELINE) == 1:
|
||||||
|
|
||||||
|
class Flatten(nn.Module):
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
return torch.flatten(x, 1)
|
||||||
|
|
||||||
|
model = nn.Sequential(model.layer3, model.layer4, model.avgpool, Flatten(), model.fc)
|
||||||
|
|
||||||
print_rank_0('model is created')
|
print_rank_0('model is created')
|
||||||
|
|
||||||
train_dataset = CIFAR10(root=Path(os.environ['DATA']),
|
train_dataset = CIFAR10(root=Path(os.environ['DATA']),
|
||||||
|
@ -69,7 +87,7 @@ def run_schedule(rank, world_size, port):
|
||||||
@pytest.mark.dist
|
@pytest.mark.dist
|
||||||
@rerun_on_exception(exception_type=mp.ProcessRaisedException, pattern=".*Address already in use.*")
|
@rerun_on_exception(exception_type=mp.ProcessRaisedException, pattern=".*Address already in use.*")
|
||||||
def test_pipeline_schedule():
|
def test_pipeline_schedule():
|
||||||
world_size = 4
|
world_size = 2
|
||||||
run_func = partial(run_schedule, world_size=world_size, port=free_port())
|
run_func = partial(run_schedule, world_size=world_size, port=free_port())
|
||||||
mp.spawn(run_func, nprocs=world_size)
|
mp.spawn(run_func, nprocs=world_size)
|
||||||
|
|
||||||
|
|
|
@ -19,7 +19,7 @@ from colossalai.testing import rerun_on_exception, skip_if_not_enough_gpus
|
||||||
|
|
||||||
|
|
||||||
def build_pipeline(model):
|
def build_pipeline(model):
|
||||||
from colossalai.builder.pipeline import partition_uniform
|
from colossalai.pipeline.utils import partition_uniform
|
||||||
|
|
||||||
pipeline_size = gpc.get_world_size(ParallelMode.PIPELINE)
|
pipeline_size = gpc.get_world_size(ParallelMode.PIPELINE)
|
||||||
pipeline_rank = gpc.get_local_rank(ParallelMode.PIPELINE)
|
pipeline_rank = gpc.get_local_rank(ParallelMode.PIPELINE)
|
||||||
|
|
|
@ -19,7 +19,7 @@ from colossalai.testing import rerun_on_exception, skip_if_not_enough_gpus
|
||||||
|
|
||||||
|
|
||||||
def build_pipeline(model):
|
def build_pipeline(model):
|
||||||
from colossalai.builder.pipeline import partition_uniform
|
from colossalai.pipeline.utils import partition_uniform
|
||||||
|
|
||||||
pipeline_size = gpc.get_world_size(ParallelMode.PIPELINE)
|
pipeline_size = gpc.get_world_size(ParallelMode.PIPELINE)
|
||||||
pipeline_rank = gpc.get_local_rank(ParallelMode.PIPELINE)
|
pipeline_rank = gpc.get_local_rank(ParallelMode.PIPELINE)
|
||||||
|
|
|
@ -19,7 +19,7 @@ from colossalai.testing import rerun_on_exception, skip_if_not_enough_gpus
|
||||||
|
|
||||||
|
|
||||||
def build_pipeline(model):
|
def build_pipeline(model):
|
||||||
from colossalai.builder.pipeline import partition_uniform
|
from colossalai.pipeline.utils import partition_uniform
|
||||||
|
|
||||||
pipeline_size = gpc.get_world_size(ParallelMode.PIPELINE)
|
pipeline_size = gpc.get_world_size(ParallelMode.PIPELINE)
|
||||||
pipeline_rank = gpc.get_local_rank(ParallelMode.PIPELINE)
|
pipeline_rank = gpc.get_local_rank(ParallelMode.PIPELINE)
|
||||||
|
|
|
@ -19,7 +19,7 @@ from colossalai.testing import rerun_on_exception, skip_if_not_enough_gpus
|
||||||
|
|
||||||
|
|
||||||
def build_pipeline(model):
|
def build_pipeline(model):
|
||||||
from colossalai.builder.pipeline import partition_uniform
|
from colossalai.pipeline.utils import partition_uniform
|
||||||
|
|
||||||
pipeline_size = gpc.get_world_size(ParallelMode.PIPELINE)
|
pipeline_size = gpc.get_world_size(ParallelMode.PIPELINE)
|
||||||
pipeline_rank = gpc.get_local_rank(ParallelMode.PIPELINE)
|
pipeline_rank = gpc.get_local_rank(ParallelMode.PIPELINE)
|
||||||
|
|
Loading…
Reference in New Issue