[pipeline] refactor the pipeline module (#1087)

* [pipeline] refactor the pipeline module

* polish code
pull/1098/head
Frank Lee 2022-06-10 11:27:38 +08:00 committed by GitHub
parent bad5d4c0a1
commit 2b2dc1c86b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
29 changed files with 366 additions and 1127 deletions

View File

@ -1,12 +1,5 @@
from .builder import (build_schedule, build_lr_scheduler, build_model,
build_optimizer, build_layer, build_loss, build_hooks,
build_dataset, build_transform, build_data_sampler,
build_gradient_handler, build_ophooks)
from .pipeline import build_pipeline_model, build_pipeline_model_from_cfg
from .builder import build_from_config, build_from_registry, build_gradient_handler
__all__ = [
'build_schedule', 'build_lr_scheduler', 'build_model', 'build_optimizer',
'build_layer', 'build_loss', 'build_hooks', 'build_dataset',
'build_transform', 'build_data_sampler', 'build_gradient_handler',
'build_pipeline_model', 'build_pipeline_model_from_cfg', 'build_ophooks'
'build_gradient_handler', 'build_from_config', 'build_from_registry'
]

View File

@ -2,7 +2,6 @@
# -*- encoding: utf-8 -*-
import inspect
from collections.abc import Iterable
from colossalai.registry import *
@ -64,84 +63,6 @@ def build_from_registry(config, registry: Registry):
return obj
def build_layer(config):
"""Returns a layer object of :class:`nn.Module` constructed from `config`.
Args:
config (dict or :class:`colossalai.context.Config`): A python dict or
a :class:`colossalai.context.Config` object containing information
used in the construction of the ``LAYERS``.
Returns:
An object of :class:`torch.nn.Module`
"""
return build_from_registry(config, LAYERS)
def build_loss(config):
"""Returns a loss function object of :class:`torch.autograd.Function` constructed
from `config`.
Args:
config (dict or :class:`colossalai.context.Config`): A python dict or
a :class:`colossalai.context.Config` object containing information
used in the construction of the ``LOSSES``.
Returns:
An object of :class:`torch.nn.modules.loss._Loss`
"""
return build_from_registry(config, LOSSES)
def build_model(config):
"""Returns a model object of :class:`nn.Module` constructed from `config`.
Args:
config (dict or :class:`colossalai.context.Config`): A python dict or
a :class:`colossalai.context.Config` object containing information
used in the construction of the ``MODELS``.
Returns:
An object of :class:`torch.nn.Module`
"""
return build_from_registry(config, MODELS)
def build_dataset(config):
"""Returns a dataset object of :class:`torch.utils.data.Dataset` constructed
from `config`.
Args:
config (dict or :class:`colossalai.context.Config`): A python dict or
a :class:`colossalai.context.Config` object containing information
used in the construction of the ``DATASETS``.
Returns:
An object of :class:`torch.utils.data.Dataset`
"""
return build_from_registry(config, DATASETS)
def build_optimizer(config, model):
"""Returns an optimizer object of :class:`torch.optim.Optimizer` constructed from `config`,
'model' and 'params'.
Args:
config (dict or :class:`colossalai.context.Config`): A python dict or
a :class:`colossalai.context.Config` object containing information
used in the construction of the ``OPTIMIZERS``.
model (:class:`nn.Module`): A model containing parameters for the optimizer
Returns:
An object of :class:`torch.optim.Optimizer`
"""
config_ = config.copy()
config_['params'] = model.parameters()
return build_from_registry(config_, OPTIMIZERS)
def build_gradient_handler(config, model, optimizer):
"""Returns a gradient handler object of :class:`BaseGradientHandler` constructed from `config`,
`model` and `optimizer`.
@ -160,100 +81,3 @@ def build_gradient_handler(config, model, optimizer):
config_['model'] = model
config_['optimizer'] = optimizer
return build_from_registry(config_, GRADIENT_HANDLER)
def build_hooks(config, trainer):
"""Returns a hook object of :class:`BaseHook` constructed from `config` and `trainer`.
Args:
config (dict or :class:`colossalai.context.Config`): A python dict or
a :class:`colossalai.context.Config` object containing information
used in the construction of the ``HOOKS``.
Returns:
An object of :class:`colossalai.trainer.hooks.BaseHook`
"""
config_ = config.copy()
config_['trainer'] = trainer
return build_from_registry(config_, HOOKS)
def build_ophooks(config):
"""Returns a hook object of :class:`BaseOpHook` constructed from `config`.
Args:
config (dict or :class:`colossalai.context.Config`): A python dict or
a :class:`colossalai.context.Config` object containing information
used in the construction of the ``OPHOOKS``.
Returns:
An object of :class:`colossalai.trainer.hooks.BaseOpHook`
"""
config_ = config.copy()
return build_from_registry(config_, OPHOOKS)
def build_transform(config):
"""Returns a transformation object of :class:`torchvision.transforms` constructed
from `config`.
Args:
config (dict or :class:`colossalai.context.Config`): A python dict or
a :class:`colossalai.context.Config` object containing information
used in the construction of the ``TRANSFORMS``.
Returns:
An object of :class:`torchvision.transforms`
"""
return build_from_registry(config, TRANSFORMS)
def build_data_sampler(config, dataset):
"""Returns a data sampler object of :class:`colossalai.nn.data.sampler.BaseSampler`
constructed from `config`.
Args:
config (dict or :class:`colossalai.context.Config`): A python dict or
a :class:`colossalai.context.Config` object containing information
used in the construction of the ``DATA_SAMPLERS``.
dataset (:class:`torch.utils.data.Dataset`): An object of
:class:`torch.utils.data.Dataset` containing information
used in the construction of the return object
Returns:
An object of :class:`colossalai.utils.data_sampler.BaseSampler`
"""
config_ = config.copy()
config_['dataset'] = dataset
return build_from_registry(config_, DATA_SAMPLERS)
def build_lr_scheduler(config, optimizer):
"""Returns a learning rate scheduler object of :class:`torch.optim.lr_scheduler`
constructed from `config`, `optimizer`, `total_steps` and `num_steps_per_epoch`.
Args:
config (dict or :class:`colossalai.context.Config`): A python dict or
a :class:`colossalai.context.Config` object containing information
used in the construction of the ``lr_schedule``.
optimizer (:class:`torch.optim.Optimizer`): An optimizer object containing
parameters for the learning rate scheduler.
Returns:
An object of :class:`torch.optim.lr_scheduler`
"""
config_ = config.copy()
config_['optimizer'] = optimizer
return build_from_registry(config_, LR_SCHEDULERS)
def build_schedule(config):
"""Returns a schedule of :class:`colossalai.engine.schedule.BaseSchedule`.
Args:
config (dict or :class:`colossalai.context.Config`): A python dict or
a :class:`colossalai.context.Config` object containing information
used in the construction of the ``Schedule``.
Returns:
An object of :class:`colossalai.engine.schedule.BaseSchedule`
"""
return build_from_registry(config, SCHEDULE)

View File

@ -1,266 +0,0 @@
import copy
import heapq
from colossalai.builder import build_model, build_layer
from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.logging import get_dist_logger
import torch.nn as nn
def _binary_partition(weights, st, ed):
"""Returns the binary partition position of `weights`, given the start
position `st` and the end position `ed`.
Args:
weights (list): A python list to be binary partitioned
st (int): the start position of the binary partition
ed (int): the end position of the binary partition
Returns:
int: the binary partition position of `weights`
"""
w_sum = weights[ed - 1]
prefix = 0
if st > 0:
w_sum -= weights[st - 1]
prefix = weights[st - 1]
minimum = float("inf")
for idx in range(st + 1, ed):
front = weights[idx - 1] - prefix
diff = abs(w_sum - 2 * front)
if diff < minimum:
pos = idx
minimum = diff
return st, pos, ed
def _heap_addition(weights, intervals, add_cnt):
"""
"""
def _heap_push(heap, st, ed):
value = weights[ed - 1]
if st > 0:
value -= weights[st - 1]
heapq.heappush(heap, (-value, st, ed))
ret_intervals = []
heap = []
for st, ed in intervals:
_heap_push(heap, st, ed)
while add_cnt > 0:
_, st, ed = heapq.heappop(heap)
if ed - st == 1:
ret_intervals.append((st, ed))
else:
l, m, r = _binary_partition(weights, st, ed)
_heap_push(heap, l, m)
_heap_push(heap, m, r)
add_cnt -= 1
while heap:
_, st, ed = heapq.heappop(heap)
ret_intervals.append((st, ed))
ret_intervals.sort()
return ret_intervals
def _calc_partitions(weights, value):
prev = 0
prefix = 0
num_block = 0
intervals = []
for idx, w in enumerate(weights):
if weights[idx] - prefix > value:
intervals.append((prev, idx))
prev = idx
prefix = weights[idx - 1]
num_block += 1
intervals.append((prev, len(weights)))
return num_block + 1, intervals
def _binary_search(weights, num):
length = len(weights)
prefix = [1 if w == 0 else w for w in weights]
for i in range(1, length):
prefix[i] += prefix[i - 1]
lower_bound = max(weights)
upper_bound = prefix[length - 1]
while upper_bound > lower_bound:
mid = (upper_bound + lower_bound) // 2
number, _ = _calc_partitions(prefix, mid)
if number <= num:
upper_bound = mid
else:
lower_bound = mid + 1
num_block, intervals = _calc_partitions(prefix, upper_bound)
if num_block < num:
intervals = _heap_addition(prefix, intervals, num - num_block)
return intervals
def partition_uniform(num_items, pipeline_parallel_size, num_chunks):
assert num_items % num_chunks == 0, \
"Layer length should be divided by the number of chunks, otherwise parameter method is recomended"
logger = get_dist_logger()
parts = [[] for _ in range(pipeline_parallel_size)]
partition_items = num_items // num_chunks
for idx in range(num_chunks):
base_idx = idx * partition_items
chunk_size = partition_items // pipeline_parallel_size
left = pipeline_parallel_size - partition_items % pipeline_parallel_size
if chunk_size == 0:
logger.warning("Some nodes in Pipeline have no requests")
for p in range(pipeline_parallel_size):
st = base_idx
base_idx += chunk_size + (p >= left)
parts[p].append((st, base_idx))
return parts
def partition_balanced(weights, pipeline_parallel_size, num_chunks):
num_total = pipeline_parallel_size * num_chunks
num_items = len(weights)
if num_items <= num_total:
return partition_uniform(num_items, pipeline_parallel_size, num_chunks)
intervals = _binary_search(weights, num_total)
current = 0
parts = [[] for _ in range(pipeline_parallel_size)]
for inter in intervals:
parts[current].append(inter)
current = (current + 1) % pipeline_parallel_size
return parts
def count_layer_params(layers):
"""Count the number of parameters in each layer
"""
param_counts = [0] * len(layers)
for idx, cfg in enumerate(layers):
layer = build_layer(cfg)
params = filter(lambda p: p.requires_grad, layer.parameters())
param_counts[idx] = sum(p.numel() for p in params)
return param_counts
def build_pipeline_model_from_cfg(config,
num_chunks: int = 1,
partition_method: str = 'parameter',
verbose: bool = False):
"""An initializer to split the model into different stages for pipeline parallelism.
An example for the model config is shown below. The class VisionTransformerFromConfig should
inherit colossalai.nn.model.ModelFromConfig to allow this initializer to build model from a sequence
of layer configurations.
::
model_config = dict(
type='VisionTransformerFromConfig',
embedding_cfg=dict(...),
...
)
Args:
config (dict): Configuration of the model.
num_chunks (int, optional): The number of chunks you want to have on the current stage.
This value should be 1 in most cases unless you are using virtual pipeline parallelism.
partition_method (str, optional): This parameter determines how you want to split your model
layers into stages, you can set it as 'layer' or 'parameter'.
verbose (bool, optional): Whether to print the logs.
"""
ori_model = build_model(config)
layers = ori_model.layers_cfg
layer_length = len(layers)
logger = get_dist_logger()
if verbose:
logger.info(f"The total length of layers is {layer_length}", ranks=[0])
pipeline_parallel_size = gpc.get_world_size(ParallelMode.PIPELINE)
pipeline_rank = gpc.get_local_rank(ParallelMode.PIPELINE)
method = partition_method.lower()
# Make a partition
if method == 'layer':
num_layers = len(layers)
parts = partition_uniform(num_layers, pipeline_parallel_size, num_chunks)
elif method == 'parameter':
param_counts = count_layer_params(layers)
# print_rank_0(param_counts)
parts = partition_balanced(param_counts, pipeline_parallel_size, num_chunks)
else:
raise ValueError("Method should be a pre-set string in [layer, parameter]")
# Display the partition
if verbose:
log_str = 'Layer allocation after partitioning: \n'
for stage in range(pipeline_parallel_size):
num_layers = 0
for st, ed in parts[stage]:
num_layers += ed - st
log_str += f'\n===== stage={stage}, layers={num_layers} =====\n'
for st, ed in parts[stage]:
for idx, layer in enumerate(layers[st:ed]):
log_str += f'\t{idx + st:2d}: {layer}\n'
logger.info(log_str, ranks=[0])
# Save the partition
interval = parts[pipeline_rank]
models = []
for st, ed in interval:
model = copy.deepcopy(ori_model)
model.build_from_cfg(st, ed)
models.append(model)
return nn.ModuleList(models) if len(models) > 1 else models[0]
def build_pipeline_model(layers: nn.Sequential, num_chunks: int = 1, verbose: bool = False):
"""An intializer to split the model into different stages for pipeline parallelism.
Note that `layer` must be `torch.nn.Sequential`.
Args:
layers (`torch.nn.Sequential`): Layers of model
num_chunks: The number of chunks you want to have on the current stage. This value should be 1
in most cases unless you are using virtual pipeline parallelism.
verbose (bool, optional): Whether to print the logs.
"""
pipeline_parallel_size = gpc.get_world_size(ParallelMode.PIPELINE)
pipeline_rank = gpc.get_local_rank(ParallelMode.PIPELINE)
partitions = partition_uniform(len(layers), pipeline_parallel_size, num_chunks)
module_list = []
for start, end in partitions[pipeline_rank]:
module_list.append(
nn.Sequential(*[nn.Identity() for _ in range(start)], *layers[start:end],
*[nn.Identity() for _ in range(len(layers) - end)]))
if verbose:
logger = get_dist_logger()
logger.info(f'Total {len(layers)} layers', ranks=[0])
for rank, part in enumerate(partitions):
log_str = f'===== stage={rank} =====\n'
for chunk, (start, end) in enumerate(part):
log_str += f'===== chunk={chunk}, layer=[{start}-{end}] =====\n'
log_str += '\n'.join([str(layer) for layer in layers[start:end]]) + '\n'
logger.info(log_str, ranks=[0])
return nn.ModuleList(module_list) if len(module_list) > 1 else module_list[0]

View File

@ -2,6 +2,5 @@ from .layer import *
from .loss import *
from .lr_scheduler import *
from .metric import *
from .model import *
from .optimizer import *
from ._ops import *

View File

@ -1,4 +1,3 @@
from .lambda_wrapper import LambdaWrapper
from .pipeline_wrapper import PipelineSharedModuleWrapper
__all__ = ['LambdaWrapper', 'PipelineSharedModuleWrapper']
__all__ = ['PipelineSharedModuleWrapper']

View File

@ -1,36 +0,0 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
import torch.nn as nn
from colossalai.builder import build_layer
from colossalai.registry import LAYERS
@LAYERS.register_module
class LambdaWrapper(nn.Module):
"""Wrap a function to nn.Module, which takes a config of layers and can fully access them.
Args:
func (``Callable``): User customed function.
layers_cfg (dict, optional): Config of layers, defaults to None.
"""
def __init__(self, func, layers_cfg: dict = None):
super().__init__()
self.func = func
self.layers = self._build_layers(layers_cfg)
def _build_layers(self, layers_cfg: dict):
if layers_cfg is None:
return None
else:
layers = []
for cfg in layers_cfg:
layer = build_layer(cfg)
layers.append(layer)
return layers
def forward(self, *args, **kwargs):
return self.func(self, *args, **kwargs)

View File

@ -1,3 +0,0 @@
from .model_from_config import ModelFromConfig
__all__ = ['ModelFromConfig']

View File

@ -1,37 +0,0 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
from abc import ABC, abstractmethod
import torch.nn as nn
from colossalai.builder import build_layer
class ModelFromConfig(nn.Module, ABC):
def __init__(self):
super(ModelFromConfig, self).__init__()
self.layers = nn.ModuleList()
self.layers_cfg = []
def build_from_cfg(self, start=None, end=None):
assert hasattr(self, 'layers_cfg'), 'Cannot find attribute layers_cfg from the module, please check the ' \
'spelling and if you have initialized this variable'
if start is None:
start = 0
if end is None:
end = len(self.layers_cfg)
for cfg in self.layers_cfg[start: end]:
layer = build_layer(cfg)
self.layers.append(layer)
@abstractmethod
def init_weights(self):
pass
def state_dict_for_save_checkpoint(self, destination=None, prefix='',
keep_vars=False):
"""Use this function to override the state dict for
saving checkpoints."""
return self.state_dict(destination, prefix, keep_vars)

View File

@ -0,0 +1,4 @@
from .pipelinable import PipelinableContext, PipelinableModel
from .layer_sepc import LayerSpec
__all__ = ['PipelinableModel', 'PipelinableContext', 'LayerSpec']

View File

@ -0,0 +1,55 @@
import torch
from colossalai.utils.model.utils import call_to_str
class LayerSpec:
"""
"""
def __init__(self, typename, *module_args, **module_kwargs):
self.typename = typename
self.module_args = module_args
self.module_kwargs = module_kwargs
self.children = None
self._param_count = 0
if not issubclass(typename, torch.nn.Module):
raise RuntimeError('LayerSpec only supports torch.nn.Module types.')
def __repr__(self):
return call_to_str(self.typename.__name__, self.module_args, self.module_kwargs)
@property
def param_count(self):
return self._param_count
def build(self):
"""Build the stored specification."""
recovered_args = []
for obj in self.module_args:
if isinstance(obj, LayerSpec):
obj = obj.build()
recovered_args.append(obj)
recovered_args = tuple(recovered_args)
recovered_kwargs = {}
for k, v in self.module_kwargs.items():
if isinstance(v, LayerSpec):
v = v.build()
recovered_kwargs[k] = v
return self.typename(*recovered_args, **recovered_kwargs)
def set_children(self, children):
self.children = children
def count_params(self):
self._param_count = 0
layer = self.build()
for param in layer.parameters():
self._param_count += param.numel()
return self._param_count
def reset_param_count(self):
self._param_count = 0

View File

@ -1,26 +1,34 @@
import torch
import inspect
from colossalai.utils.model.utils import InsertPostInitMethodToModuleSubClasses, call_to_str
from colossalai.builder.pipeline import partition_uniform, partition_balanced
from colossalai.utils.model.utils import InsertPostInitMethodToModuleSubClasses
from .utils import partition_uniform, partition_balanced, build_kwargs_for_function, build_kwargs_for_module, exec_func_with_kwargs, exec_funcs_with_kwargs
from colossalai.nn.layer.utils import CheckpointModule
from colossalai.tensor import ColoTensor
from colossalai.tensor import ColoParameter
from .layer_sepc import LayerSpec
class PipelinableContext(InsertPostInitMethodToModuleSubClasses):
"""
A context manager to split the model into pipeline stages.
"""
def __init__(self):
def __init__(self, policy: str="balanced"):
super().__init__()
self._layer_spec_dict = {}
self._root_children = None
self._model = None
self._layer_spec_list = []
self._func_dict = {}
self._policy = "balanced"
self._policy = policy
@property
def policy(self):
return self._policy
@policy.setter
def policy(self, policy: str):
self._policy = policy
@property
def layers_count(self):
return len(self._layer_spec_list)
@ -33,7 +41,6 @@ class PipelinableContext(InsertPostInitMethodToModuleSubClasses):
"""
The Callback function when entering the context
"""
# reserve rng states
self.cpu_rng_state = torch.get_rng_state()
self.cuda_rng_state = torch.cuda.get_rng_state()
@ -52,35 +59,50 @@ class PipelinableContext(InsertPostInitMethodToModuleSubClasses):
The function to call at the end of the constructor of each module.
NOTE() The module may be passed to this function multiple times.
"""
module_id = id(module)
# iterate over the positional arguments
# to check if an argument is a torch Module
# if found any torch Module, replace it with its layer spec
# for storage purpose
modified_args = []
for obj in args:
if issubclass(obj.__class__, torch.nn.modules.module.Module):
obj = self._layer_spec_dict[id(obj)]
modified_args.append(obj)
for arg in args:
if isinstance(arg, torch.nn.Module):
arg = self._layer_spec_dict[id(arg)]
modified_args.append(arg)
# to the same for the keyword arguments
modified_kwargs = {}
for k, v in kwargs.items():
if issubclass(v.__class__, torch.nn.modules.module.Module):
if isinstance(v, torch.nn.Module):
v = self._layer_spec_dict[id(v)]
# (lyl)TODO: analyse ColoTensor as well
modified_kwargs[k] = v
modified_args = tuple(modified_args)
# keep track of the module children
# as torch.nn.Module.__init__ is called from inner module to outer module,
# the final value of self._model will be the outermost model
# e.g. if the model is torchvision.models.resnet18, then the final value of self._model
# will be the ``ResNet`` object.
self._root_children = list(module.children())
self._model = module
# store the children to keep the module hierarchy
layer_spec = LayerSpec(module.__class__, *modified_args, **modified_kwargs)
layer_spec.set_children(module.children())
# store the layer spec in this context
module_id = id(module)
self._layer_spec_dict[module_id] = layer_spec
# convert all torch.nn.Parameter to colossalai.tensor.ColoParameter
name_list = []
for name, param in module.named_parameters():
if isinstance(param, ColoTensor):
if isinstance(param, ColoParameter):
continue
name_list.append((name, param))
for name, param in name_list:
delattr(module, name)
setattr(module, name, ColoTensor.from_torch_tensor(param))
setattr(module, name, ColoParameter.from_torch_tensor(tensor=param.data, requires_grad=param.requires_grad))
def to_layer_list(self, exec_seq=None):
"""
@ -100,7 +122,6 @@ class PipelinableContext(InsertPostInitMethodToModuleSubClasses):
if id(module) == id(child_in_container):
children_name.append(name)
break
else:
self._layer_spec_list.append(layer_spec)
for name, module in self._model.named_modules():
@ -110,10 +131,16 @@ class PipelinableContext(InsertPostInitMethodToModuleSubClasses):
else:
front_funcs_list = []
named_modules = dict(self._model.named_modules())
for index, element in enumerate(exec_seq):
if isinstance(element, str):
module = dict(self._model.named_modules())[element]
assert element in named_modules, f'Found invalid module name {element}, please check if you spell the module name correctly.'
# get the layer spec based on the module ID
module = named_modules[element]
layer_spec = self._layer_spec_dict[id(module)]
# check whether there are functions which should be executed before this module
if len(front_funcs_list) != 0:
func_key = (layer_spec, "front")
if func_key not in self._func_dict:
@ -121,6 +148,7 @@ class PipelinableContext(InsertPostInitMethodToModuleSubClasses):
for f in front_funcs_list:
self._func_dict[func_key].append(f)
front_funcs_list = []
func_key = (layer_spec, "behind")
self._layer_spec_list.append(layer_spec)
elif isinstance(element, tuple) and element[1] == "front":
@ -172,70 +200,6 @@ class PipelinableContext(InsertPostInitMethodToModuleSubClasses):
return pipeline_model
def load_policy(self, policy):
self._policy = policy
def _build_kwargs_for_module(function, kw_dict):
"""
Generally, the first argument of module.forward is an input tensor come from the previous layer.
Therefore, we just filter the kwargs from second element of the dictionary.
"""
sig = inspect.signature(function)
if len(sig.parameters) <= 1:
return None
args_name_list = list(sig.parameters.keys())
kw_dict = {k: v for k, v in kw_dict.items() if k in args_name_list[1:]}
return kw_dict
def _build_kwargs_for_function(function, kw_dict):
sig = inspect.signature(function)
kw_dict = {k: v for k, v in kw_dict.items() if k in sig.parameters}
if len(kw_dict) == 0:
return None
return kw_dict
def _exec_func_with_kwargs(func, kw_dict, input_tensor, kwargs):
"""
We suppose the callable object passed to to_layer_list method in two purpose:
a. use the callable object to modify input tensor, such as \
lambda x: torch.flatten(x, 1)
b. use the callable object to modify kwargs value, such as \
def foo(attention_mask=None):
if attention_mask is not None:
batch_size = input_ids.shape[0]
attention_mask = attention_mask.view(batch_size, -1)
return attention_mask
"""
if kw_dict is not None:
rst = func(**kw_dict)
if isinstance(rst, tuple):
for i, k in enumerate(kw_dict.keys()):
kwargs[k] = rst[i]
else:
for k in kw_dict.keys():
kwargs[k] = rst
return input_tensor
return func(input_tensor)
def _exec_funcs_with_kwargs(func_dict, func_key, input_tensor, kwargs):
assert func_key in func_dict, f"{func_key} is not in the function_dict."
funcs_to_exec = func_dict[func_key]
if isinstance(funcs_to_exec, list):
for f in funcs_to_exec:
f_kwargs = _build_kwargs_for_function(f, kwargs)
input_tensor = _exec_func_with_kwargs(f, f_kwargs, input_tensor, kwargs)
else:
f_kwargs = _build_kwargs_for_function(funcs_to_exec, kwargs)
input_tensor = _exec_func_with_kwargs(funcs_to_exec, f_kwargs, input_tensor, kwargs)
return input_tensor
class PipelinableModel(torch.nn.Module):
@ -250,16 +214,16 @@ class PipelinableModel(torch.nn.Module):
for module in self._module_list:
if id(module) in self._front_func_dict:
input_tensor = _exec_funcs_with_kwargs(self._front_func_dict, id(module), input_tensor, kwargs)
input_tensor = exec_funcs_with_kwargs(self._front_func_dict, id(module), input_tensor, kwargs)
if isinstance(module, CheckpointModule):
forward_func = module._forward
else:
forward_func = module.forward
if input_tensor is None:
module_kwargs = _build_kwargs_for_function(forward_func, kwargs)
module_kwargs = build_kwargs_for_function(forward_func, kwargs)
else:
module_kwargs = _build_kwargs_for_module(forward_func, kwargs)
module_kwargs = build_kwargs_for_module(forward_func, kwargs)
if module_kwargs is not None and input_tensor is not None:
if isinstance(module, CheckpointModule):
convert_kwargs_to_args = []
@ -288,57 +252,9 @@ class PipelinableModel(torch.nn.Module):
input_tensor = module(input_tensor)
if id(module) in self._behind_func_dict:
input_tensor = _exec_funcs_with_kwargs(self._behind_func_dict, id(module), input_tensor, kwargs)
input_tensor = exec_funcs_with_kwargs(self._behind_func_dict, id(module), input_tensor, kwargs)
return input_tensor
class LayerSpec:
def __init__(self, typename, *module_args, **module_kwargs):
self.typename = typename
self.module_args = module_args
self.module_kwargs = module_kwargs
self.children = None
self._param_count = 0
if not issubclass(typename, torch.nn.Module):
raise RuntimeError('LayerSpec only supports torch.nn.Module types.')
def __repr__(self):
return call_to_str(self.typename.__name__, self.module_args, self.module_kwargs)
@property
def param_count(self):
return self._param_count
def build(self):
"""Build the stored specification."""
recovered_args = []
for obj in self.module_args:
if isinstance(obj, LayerSpec):
obj = obj.build()
recovered_args.append(obj)
recovered_args = tuple(recovered_args)
recovered_kwargs = {}
for k, v in self.module_kwargs.items():
if isinstance(v, LayerSpec):
v = v.build()
recovered_kwargs[k] = v
return self.typename(*recovered_args, **recovered_kwargs)
def set_children(self, children):
self.children = children
def count_params(self):
self._param_count = 0
layer = self.build()
for param in layer.parameters():
self._param_count += param.numel()
return self._param_count
def reset_param_count(self):
self._param_count = 0

View File

@ -0,0 +1,207 @@
import heapq
import inspect
from colossalai.logging import get_dist_logger
from typing import List
def _binary_partition(weights: List, start: int, end: int):
"""Returns the binary partition position of `weights`, given the start
position `st` and the end position `ed`.
Args:
weights (list): A python list to be binary partitioned
start (int): the start position of the binary partition
end (int): the end position of the binary partition
Returns:
int: the binary partition position of `weights`
"""
w_sum = weights[end - 1]
prefix = 0
if start > 0:
w_sum -= weights[start - 1]
prefix = weights[start - 1]
minimum = float("inf")
for idx in range(start + 1, end):
front = weights[idx - 1] - prefix
diff = abs(w_sum - 2 * front)
if diff < minimum:
pos = idx
minimum = diff
return start, pos, end
def _heap_addition(weights: List, intervals: int, add_cnt: int):
"""
"""
def _heap_push(heap, st, ed):
value = weights[ed - 1]
if st > 0:
value -= weights[st - 1]
heapq.heappush(heap, (-value, st, ed))
ret_intervals = []
heap = []
for st, ed in intervals:
_heap_push(heap, st, ed)
while add_cnt > 0:
_, st, ed = heapq.heappop(heap)
if ed - st == 1:
ret_intervals.append((st, ed))
else:
l, m, r = _binary_partition(weights, st, ed)
_heap_push(heap, l, m)
_heap_push(heap, m, r)
add_cnt -= 1
while heap:
_, st, ed = heapq.heappop(heap)
ret_intervals.append((st, ed))
ret_intervals.sort()
return ret_intervals
def _calc_partitions(weights, value):
prev = 0
prefix = 0
num_block = 0
intervals = []
for idx, w in enumerate(weights):
if weights[idx] - prefix > value:
intervals.append((prev, idx))
prev = idx
prefix = weights[idx - 1]
num_block += 1
intervals.append((prev, len(weights)))
return num_block + 1, intervals
def _binary_search(weights, num):
length = len(weights)
prefix = [1 if w == 0 else w for w in weights]
for i in range(1, length):
prefix[i] += prefix[i - 1]
lower_bound = max(weights)
upper_bound = prefix[length - 1]
while upper_bound > lower_bound:
mid = (upper_bound + lower_bound) // 2
number, _ = _calc_partitions(prefix, mid)
if number <= num:
upper_bound = mid
else:
lower_bound = mid + 1
num_block, intervals = _calc_partitions(prefix, upper_bound)
if num_block < num:
intervals = _heap_addition(prefix, intervals, num - num_block)
return intervals
def partition_uniform(num_items, pipeline_parallel_size, num_chunks):
assert num_items % num_chunks == 0, \
"Layer length should be divided by the number of chunks, otherwise parameter method is recomended"
logger = get_dist_logger()
parts = [[] for _ in range(pipeline_parallel_size)]
partition_items = num_items // num_chunks
for idx in range(num_chunks):
base_idx = idx * partition_items
chunk_size = partition_items // pipeline_parallel_size
left = pipeline_parallel_size - partition_items % pipeline_parallel_size
if chunk_size == 0:
logger.warning("Some nodes in Pipeline have no requests")
for p in range(pipeline_parallel_size):
st = base_idx
base_idx += chunk_size + (p >= left)
parts[p].append((st, base_idx))
return parts
def partition_balanced(weights, pipeline_parallel_size, num_chunks):
num_total = pipeline_parallel_size * num_chunks
num_items = len(weights)
if num_items <= num_total:
return partition_uniform(num_items, pipeline_parallel_size, num_chunks)
intervals = _binary_search(weights, num_total)
current = 0
parts = [[] for _ in range(pipeline_parallel_size)]
for inter in intervals:
parts[current].append(inter)
current = (current + 1) % pipeline_parallel_size
return parts
def build_kwargs_for_module(function, kw_dict):
"""
Generally, the first argument of module.forward is an input tensor come from the previous layer.
Therefore, we just filter the kwargs from second element of the dictionary.
"""
sig = inspect.signature(function)
if len(sig.parameters) <= 1:
return None
args_name_list = list(sig.parameters.keys())
kw_dict = {k: v for k, v in kw_dict.items() if k in args_name_list[1:]}
return kw_dict
def build_kwargs_for_function(function, kw_dict):
sig = inspect.signature(function)
kw_dict = {k: v for k, v in kw_dict.items() if k in sig.parameters}
if len(kw_dict) == 0:
return None
return kw_dict
def exec_func_with_kwargs(func, kw_dict, input_tensor, kwargs):
"""
We suppose the callable object passed to to_layer_list method in two purpose:
a. use the callable object to modify input tensor, such as \
lambda x: torch.flatten(x, 1)
b. use the callable object to modify kwargs value, such as \
def foo(attention_mask=None):
if attention_mask is not None:
batch_size = input_ids.shape[0]
attention_mask = attention_mask.view(batch_size, -1)
return attention_mask
"""
if kw_dict is not None:
rst = func(**kw_dict)
if isinstance(rst, tuple):
for i, k in enumerate(kw_dict.keys()):
kwargs[k] = rst[i]
else:
for k in kw_dict.keys():
kwargs[k] = rst
return input_tensor
return func(input_tensor)
def exec_funcs_with_kwargs(func_dict, func_key, input_tensor, kwargs):
assert func_key in func_dict, f"{func_key} is not in the function_dict."
funcs_to_exec = func_dict[func_key]
if isinstance(funcs_to_exec, list):
for f in funcs_to_exec:
f_kwargs = build_kwargs_for_function(f, kwargs)
input_tensor = exec_func_with_kwargs(f, f_kwargs, input_tensor, kwargs)
else:
f_kwargs = build_kwargs_for_function(funcs_to_exec, kwargs)
input_tensor = exec_func_with_kwargs(funcs_to_exec, f_kwargs, input_tensor, kwargs)
return input_tensor

View File

@ -6,7 +6,6 @@ from pathlib import Path
import pytest
from colossalai.context.config import Config
from colossalai.builder import build_ophooks
@pytest.mark.cpu

View File

@ -17,11 +17,14 @@ from colossalai.logging import get_dist_logger
from colossalai.nn import CrossEntropyLoss
from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR
from colossalai.utils import is_using_pp, get_dataloader
from colossalai.utils.model.pipelinable import PipelinableContext
from colossalai.pipeline.pipelinable import PipelinableContext
from tqdm import tqdm
from titans.dataloader.cifar10 import build_cifar
from titans.model.vit import vit_tiny_patch4_32
from torchvision.datasets import CIFAR10
from torchvision.transforms import transforms
try:
from titans.model.vit import vit_tiny_patch4_32
except:
pass
BATCH_SIZE = 4
NUM_EPOCHS = 60
@ -49,7 +52,14 @@ def run_trainer(rank, world_size, port):
# craete dataloaders
root = Path(os.environ['DATA'])
train_dataloader, test_dataloader = build_cifar(BATCH_SIZE, root, pad_if_needed=True, crop=32, resize=32)
transform_train = transforms.Compose([
transforms.RandomCrop(224, padding=4, pad_if_needed=True),
transforms.AutoAugment(policy=transforms.AutoAugmentPolicy.CIFAR10),
transforms.ToTensor(),
transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])
train_dataset = CIFAR10(root=root, train=True, download=True, transform=transform_train)
train_dataloader = get_dataloader(dataset=train_dataset, shuffle=True, batch_size=BATCH_SIZE, pin_memory=True)
# create loss function
criterion = CrossEntropyLoss(label_smoothing=0.1)

View File

@ -1,7 +1,7 @@
import torch
import torch.multiprocessing as mp
from colossalai.utils.model.pipelinable import PipelinableContext
from colossalai.pipeline.pipelinable import PipelinableContext
from colossalai.testing import rerun_on_exception
@ -33,7 +33,7 @@ def run_pipelinable(rank):
model = MLP()
assert pipelinable.policy == "balanced"
pipelinable.load_policy("uniform")
pipelinable.policy = "uniform"
assert pipelinable.policy == "uniform"
pipelinable.to_layer_list()

View File

@ -1,2 +0,0 @@
from .layers import *
from .resnet import VanillaResNet

View File

@ -1,3 +0,0 @@
from .basic_block import ResNetBasicBlock
from .bottleneck import ResNetBottleneck
from .reslayer import ResLayer

View File

@ -1,64 +0,0 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
from typing import Optional, Callable
import torch.nn as nn
from torch import Tensor
from colossalai.registry import LAYERS
from .conv import conv3x3
@LAYERS.register_module
class ResNetBasicBlock(nn.Module):
"""Basic ResNet block
"""
expansion: int = 1
def __init__(
self,
inplanes: int,
planes: int,
stride: int = 1,
downsample: Optional[nn.Module] = None,
groups: int = 1,
base_width: int = 64,
dilation: int = 1,
norm_layer: Optional[Callable[..., nn.Module]] = None
) -> None:
super().__init__()
if norm_layer is None:
norm_layer = nn.BatchNorm2d
if groups != 1 or base_width != 64:
raise ValueError(
'BasicBlock only supports groups=1 and base_width=64')
if dilation > 1:
raise NotImplementedError(
"Dilation > 1 not supported in BasicBlock")
# Both self.conv1 and self.downsample layers downsample the input when stride != 1
self.conv1 = conv3x3(inplanes, planes, stride)
self.bn1 = norm_layer(planes)
self.relu = nn.ReLU(inplace=True)
self.conv2 = conv3x3(planes, planes)
self.bn2 = norm_layer(planes)
self.downsample = downsample
self.stride = stride
def forward(self, x: Tensor) -> Tensor:
identity = x
out = self.conv1(x)
out = self.bn1(out)
out = self.relu(out)
out = self.conv2(out)
out = self.bn2(out)
if self.downsample is not None:
identity = self.downsample(x)
out += identity
out = self.relu(out)
return out

View File

@ -1,69 +0,0 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
from typing import Optional, Callable
import torch.nn as nn
from torch import Tensor
from colossalai.registry import LAYERS
from .conv import conv3x3, conv1x1
@LAYERS.register_module
class ResNetBottleneck(nn.Module):
# Bottleneck in torchvision places the stride for downsampling at 3x3 convolution(self.conv2)
# while original implementation places the stride at the first 1x1 convolution(self.conv1)
# according to "Deep residual learning for image recognition"https://arxiv.org/abs/1512.03385.
# This variant is also known as ResNet V1.5 and improves accuracy according to
# https://ngc.nvidia.com/catalog/model-scripts/nvidia:resnet_50_v1_5_for_pytorch.
expansion: int = 4
def __init__(
self,
inplanes: int,
planes: int,
stride: int = 1,
downsample: Optional[nn.Module] = None,
groups: int = 1,
base_width: int = 64,
dilation: int = 1,
norm_layer: Optional[Callable[..., nn.Module]] = None
) -> None:
super().__init__()
if norm_layer is None:
norm_layer = nn.BatchNorm2d
width = int(planes * (base_width / 64.)) * groups
# Both self.conv2 and self.downsample layers downsample the input when stride != 1
self.conv1 = conv1x1(inplanes, width)
self.bn1 = norm_layer(width)
self.conv2 = conv3x3(width, width, stride, groups, dilation)
self.bn2 = norm_layer(width)
self.conv3 = conv1x1(width, planes * self.expansion)
self.bn3 = norm_layer(planes * self.expansion)
self.relu = nn.ReLU(inplace=True)
self.downsample = downsample
self.stride = stride
def forward(self, x: Tensor) -> Tensor:
identity = x
out = self.conv1(x)
out = self.bn1(out)
out = self.relu(out)
out = self.conv2(out)
out = self.bn2(out)
out = self.relu(out)
out = self.conv3(out)
out = self.bn3(out)
if self.downsample is not None:
identity = self.downsample(x)
out += identity
out = self.relu(out)
return out

View File

@ -1,15 +0,0 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
import torch.nn as nn
def conv3x3(in_planes: int, out_planes: int, stride: int = 1, groups: int = 1, dilation: int = 1) -> nn.Conv2d:
"""3x3 convolution with padding"""
return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
padding=dilation, groups=groups, bias=False, dilation=dilation)
def conv1x1(in_planes: int, out_planes: int, stride: int = 1) -> nn.Conv2d:
"""1x1 convolution"""
return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)

View File

@ -1,63 +0,0 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
import torch.nn as nn
from colossalai.registry import LAYERS
from .conv import conv1x1
@LAYERS.register_module
class ResLayer(nn.Module):
def __init__(self,
block_type: str,
norm_layer_type: str,
inplanes: int,
planes: int,
blocks: int,
groups: int,
base_width: int,
stride: int = 1,
dilation: int = 1,
dilate: bool = False,
):
super().__init__()
self.block = LAYERS.get_module(block_type)
self.norm_layer = LAYERS.get_module(norm_layer_type)
self.inplanes = inplanes
self.planes = planes
self.blocks = blocks
self.groups = groups
self.dilation = dilation
self.base_width = base_width
self.dilate = dilate
self.stride = stride
self.layer = self._make_layer()
def _make_layer(self):
norm_layer = self.norm_layer
downsample = None
previous_dilation = self.dilation
if self.dilate:
self.dilation *= self.stride
self.stride = 1
if self.stride != 1 or self.inplanes != self.planes * self.block.expansion:
downsample = nn.Sequential(
conv1x1(self.inplanes, self.planes * self.block.expansion, self.stride),
norm_layer(self.planes * self.block.expansion),
)
layers = []
layers.append(self.block(self.inplanes, self.planes, self.stride, downsample, self.groups,
self.base_width, previous_dilation, norm_layer))
self.inplanes = self.planes * self.block.expansion
for _ in range(1, self.blocks):
layers.append(self.block(self.inplanes, self.planes, groups=self.groups,
base_width=self.base_width, dilation=self.dilation,
norm_layer=norm_layer))
return nn.Sequential(*layers)
def forward(self, x):
return self.layer(x)

View File

@ -1,163 +0,0 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
from typing import List, Optional
import torch
import torch.nn as nn
from torch import Tensor
from colossalai.registry import LAYERS
from colossalai.registry import MODELS
from colossalai.nn.model import ModelFromConfig
@MODELS.register_module
class VanillaResNet(ModelFromConfig):
"""ResNet from
`"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_.
"""
def __init__(
self,
num_cls: int,
block_type: str,
layers: List[int],
norm_layer_type: str = 'BatchNorm2d',
in_channels: int = 3,
groups: int = 1,
width_per_group: int = 64,
zero_init_residual: bool = False,
replace_stride_with_dilation: Optional[List[bool]] = None,
dilations=(1, 1, 1, 1)
) -> None:
super().__init__()
self.inplanes = 64
self.zero_init_residual = zero_init_residual
self.blocks = layers
self.block_expansion = LAYERS.get_module(block_type).expansion
self.dilations = dilations
self.reslayer_common_cfg = dict(
type='ResLayer',
block_type=block_type,
norm_layer_type=norm_layer_type,
groups=groups,
base_width=width_per_group
)
if replace_stride_with_dilation is None:
# each element in the tuple indicates if we should replace
# the 2x2 stride with a dilated convolution instead
replace_stride_with_dilation = [False, False, False]
if len(replace_stride_with_dilation) != 3:
raise ValueError("replace_stride_with_dilation should be None "
"or a 3-element tuple, got {}".format(replace_stride_with_dilation))
self.layers_cfg = [
# conv1
dict(type='Conv2d',
in_channels=in_channels,
out_channels=self.inplanes,
kernel_size=7,
stride=2,
padding=3,
bias=False),
# bn1
dict(
type=norm_layer_type,
num_features=self.inplanes
),
# relu
dict(
type='ReLU',
inplace=True
),
# maxpool
dict(
type='MaxPool2d',
kernel_size=3,
stride=2,
padding=1
),
# layer 1
dict(
inplanes=self.inplanes,
planes=64,
blocks=self.blocks[0],
dilation=self.dilations[0],
**self.reslayer_common_cfg
),
# layer 2
dict(
inplanes=64 * self.block_expansion,
planes=128,
blocks=self.blocks[1],
stride=2,
dilate=replace_stride_with_dilation[0],
dilation=self.dilations[1],
**self.reslayer_common_cfg
),
# layer 3
dict(
inplanes=128 * self.block_expansion,
planes=256,
blocks=layers[2],
stride=2,
dilate=replace_stride_with_dilation[1],
dilation=self.dilations[2],
**self.reslayer_common_cfg
),
# layer 4
dict(
inplanes=256 * self.block_expansion,
planes=512,
blocks=layers[3], stride=2,
dilate=replace_stride_with_dilation[2],
dilation=self.dilations[3],
**self.reslayer_common_cfg
),
# avg pool
dict(
type='AdaptiveAvgPool2d',
output_size=(1, 1)
),
# flatten
dict(
type='LambdaWrapper',
func=lambda mod, x: torch.flatten(x, 1)
),
# linear
dict(
type='Linear',
in_features=512 * self.block_expansion,
out_features=num_cls
)
]
def forward(self, x: Tensor):
for layer in self.layers:
x = layer(x)
return x
def init_weights(self):
for m in self.modules():
if isinstance(m, nn.Conv2d):
nn.init.kaiming_normal_(
m.weight, mode='fan_out', nonlinearity='relu')
elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
nn.init.constant_(m.weight, 1)
nn.init.constant_(m.bias, 0)
# Zero-initialize the last BN in each residual branch,
# so that the residual branch starts with zeros, and each residual block behaves like an identity.
# This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
if self.zero_init_residual:
for m in self.modules():
if isinstance(m, LAYERS.get_module('ResNetBottleneck')):
# type: ignore[arg-type]
nn.init.constant_(m.bn3.weight, 0)
elif isinstance(m, LAYERS.get_module('ResNetBasicBlock')):
# type: ignore[arg-type]
nn.init.constant_(m.bn2.weight, 0)

View File

@ -1,21 +0,0 @@
import os
import model
from pathlib import Path
BATCH_SIZE = 128
IMG_SIZE = 224
DIM = 768
NUM_CLASSES = 10
NUM_ATTN_HEADS = 12
NUM_MICRO_BATCHES = 2
# resnet 18
model = dict(type='VanillaResNet',
block_type='ResNetBasicBlock',
layers=[2, 2, 2, 2],
num_cls=10)
parallel = dict(
pipeline=dict(size=4),
tensor=dict(size=1, mode=None)
)

View File

@ -1,43 +0,0 @@
import os.path as osp
import pytest
import torch
import torch.multiprocessing as mp
from colossalai.builder.pipeline import build_pipeline_model_from_cfg
from colossalai.core import global_context
from colossalai.initialize import launch
from colossalai.logging import get_dist_logger
from functools import partial
from colossalai.utils import free_port
from colossalai.testing import rerun_on_exception
DIR_PATH = osp.dirname(osp.realpath(__file__))
CONFIG_PATH = osp.join(DIR_PATH, 'resnet_config.py')
def run_partition(rank, world_size, port):
launch(config=CONFIG_PATH, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
logger = get_dist_logger()
logger.info('finished initialization')
# build model
model = build_pipeline_model_from_cfg(global_context.config.model, 1, verbose=True)
assert isinstance(model, torch.nn.Module)
logger.info('model is created')
global_context.destroy()
logger.info('training finished')
torch.cuda.empty_cache()
@pytest.mark.dist
@rerun_on_exception(exception_type=mp.ProcessRaisedException, pattern=".*Address already in use.*")
def test_partition():
world_size = 4
run_func = partial(run_partition, world_size=world_size, port=free_port())
mp.spawn(run_func, nprocs=world_size)
if __name__ == '__main__':
test_partition()

View File

@ -8,27 +8,45 @@ from pathlib import Path
import colossalai
import pytest
import torch
import torch.nn as nn
import torch.multiprocessing as mp
from colossalai.builder import build_pipeline_model_from_cfg
from colossalai.core import global_context as gpc
from colossalai.engine.schedule import PipelineSchedule
from colossalai.context import ParallelMode
from colossalai.initialize import launch
from colossalai.utils import free_port, get_dataloader, print_rank_0
from colossalai.testing import rerun_on_exception
from torchvision import transforms
from torchvision.datasets import CIFAR10
from torchvision.models import resnet18
BATCH_SIZE = 4
DIR_PATH = osp.dirname(osp.realpath(__file__))
CONFIG_PATH = osp.join(DIR_PATH, './resnet_config.py')
BATCH_SIZE = 8
CONFIG=dict(
NUM_MICRO_BATCHES=2,
parallel = dict(
pipeline=dict(size=2),
tensor=dict(size=1, mode=None)
)
)
def run_schedule(rank, world_size, port):
launch(config=CONFIG_PATH, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
# build model
model = build_pipeline_model_from_cfg(gpc.config.model, 1)
model = resnet18(num_classes=10)
if gpc.get_local_rank(ParallelMode.PIPELINE) == 0:
model = nn.Sequential(model.conv1, model.bn1, model.relu, model.maxpool, model.layer1, model.layer2)
elif gpc.get_local_rank(ParallelMode.PIPELINE) == 1:
class Flatten(nn.Module):
def forward(self, x):
return torch.flatten(x, 1)
model = nn.Sequential(model.layer3, model.layer4, model.avgpool, Flatten(), model.fc)
print_rank_0('model is created')
train_dataset = CIFAR10(root=Path(os.environ['DATA']),
@ -69,7 +87,7 @@ def run_schedule(rank, world_size, port):
@pytest.mark.dist
@rerun_on_exception(exception_type=mp.ProcessRaisedException, pattern=".*Address already in use.*")
def test_pipeline_schedule():
world_size = 4
world_size = 2
run_func = partial(run_schedule, world_size=world_size, port=free_port())
mp.spawn(run_func, nprocs=world_size)

View File

@ -19,7 +19,7 @@ from colossalai.testing import rerun_on_exception, skip_if_not_enough_gpus
def build_pipeline(model):
from colossalai.builder.pipeline import partition_uniform
from colossalai.pipeline.utils import partition_uniform
pipeline_size = gpc.get_world_size(ParallelMode.PIPELINE)
pipeline_rank = gpc.get_local_rank(ParallelMode.PIPELINE)

View File

@ -19,7 +19,7 @@ from colossalai.testing import rerun_on_exception, skip_if_not_enough_gpus
def build_pipeline(model):
from colossalai.builder.pipeline import partition_uniform
from colossalai.pipeline.utils import partition_uniform
pipeline_size = gpc.get_world_size(ParallelMode.PIPELINE)
pipeline_rank = gpc.get_local_rank(ParallelMode.PIPELINE)

View File

@ -19,7 +19,7 @@ from colossalai.testing import rerun_on_exception, skip_if_not_enough_gpus
def build_pipeline(model):
from colossalai.builder.pipeline import partition_uniform
from colossalai.pipeline.utils import partition_uniform
pipeline_size = gpc.get_world_size(ParallelMode.PIPELINE)
pipeline_rank = gpc.get_local_rank(ParallelMode.PIPELINE)

View File

@ -19,7 +19,7 @@ from colossalai.testing import rerun_on_exception, skip_if_not_enough_gpus
def build_pipeline(model):
from colossalai.builder.pipeline import partition_uniform
from colossalai.pipeline.utils import partition_uniform
pipeline_size = gpc.get_world_size(ParallelMode.PIPELINE)
pipeline_rank = gpc.get_local_rank(ParallelMode.PIPELINE)